gusucode.com > vision工具箱matlab源码程序 > vision/+vision/+internal/+rcnn/BoundingBoxRegressionModel.m
% This class defines the bounding box regression model defined here: % % Girshick, Ross, et al. "Rich feature hierarchies for accurate object % detection and semantic segmentation." Proceedings of the IEEE conference % on computer vision and pattern recognition. 2014. classdef BoundingBoxRegressionModel < handle properties ModelX ModelY ModelW ModelH Lambda Params IsTrained end %====================================================================== methods function this = BoundingBoxRegressionModel(params) this.Lambda = params.Lambda; this.Params = {'Solver', 'dual', 'Regularization','ridge', 'ObservationsIn','columns'}; this.IsTrained = false; end %------------------------------------------------------------------ function [beta, bias] = getInitialWeights(~, featureLength) initValue = zeros(featureLength, 1); beta.X = initValue; beta.Y = initValue; beta.W = initValue; beta.H = initValue; bias.X = 0; bias.Y = 0; bias.W = 0; bias.H = 0; end %------------------------------------------------------------------ function update(this, X, Y) if this.IsTrained [beta, bias] = getWeights(this); else [beta, bias] = getInitialWeights(this, size(X,1)); end this.ModelX = fitrlinear(X, Y(:,1), 'Beta', beta.X, 'Bias', bias.X, 'Lambda', this.Lambda, this.Params{:}); this.ModelY = fitrlinear(X, Y(:,2), 'Beta', beta.Y, 'Bias', bias.Y, 'Lambda', this.Lambda, this.Params{:}); this.ModelW = fitrlinear(X, Y(:,3), 'Beta', beta.W, 'Bias', bias.W, 'Lambda', this.Lambda, this.Params{:}); this.ModelH = fitrlinear(X, Y(:,4), 'Beta', beta.H, 'Bias', bias.H, 'Lambda', this.Lambda, this.Params{:}); % After update is called once, the model is considered trained. this.IsTrained = true; end %------------------------------------------------------------------ function [beta, bias] = getWeights(this) beta.X = this.ModelX.Beta; beta.Y = this.ModelY.Beta; beta.W = this.ModelW.Beta; beta.H = this.ModelH.Beta; bias.X = this.ModelX.Bias; bias.Y = this.ModelY.Bias; bias.W = this.ModelW.Bias; bias.H = this.ModelH.Bias; end %------------------------------------------------------------------ function fit(this, X, Y) this.ModelX = fitrlinear(X, Y(:,1), 'Lambda', this.Lambda, this.Params{:}); this.ModelY = fitrlinear(X, Y(:,2), 'Lambda', this.Lambda, this.Params{:}); this.ModelW = fitrlinear(X, Y(:,3), 'Lambda', this.Lambda, this.Params{:}); this.ModelH = fitrlinear(X, Y(:,4), 'Lambda', this.Lambda, this.Params{:}); this.IsTrained = true; end function g = apply(this, features, P) if ~isempty(features) [x, y, w, h] = predict(this, features); % center of proposals px = P(:,1) + floor(P(:,3)/2); py = P(:,2) + floor(P(:,4)/2); % compute regression value of ground truth box gx = P(:,3).*x + px; % center position gy = P(:,4).*y + py; gw = P(:,3) .* exp(w); gh = P(:,4) .* exp(h); % convert to [x y w h] format g = [ gx - floor(gw/2) gy - floor(gh/2) gw gh]; else g = zeros(0,4); end g = round(g); end %------------------------------------------------------------------ function [tx, ty, tw, th] = predict(this, features) tx = predict(this.ModelX, features, 'ObservationsIn', 'columns'); ty = predict(this.ModelY, features, 'ObservationsIn', 'columns'); tw = predict(this.ModelW, features, 'ObservationsIn', 'columns'); th = predict(this.ModelH, features, 'ObservationsIn', 'columns'); end %------------------------------------------------------------------ function [x, y] = getTrainingSamples(~, data, th) P = getProposals(data); G = getGroundTruth(data); F = getRegionFeatures(data); [G, P, selected] = vision.internal.rcnn.BoundingBoxRegressionModel.selectBBoxesForTraining(G, P, th); x = F(:, selected); y = vision.internal.rcnn.BoundingBoxRegressionModel.generateRegressionTargets(G, P); end end %====================================================================== methods(Static) function y = generateRegressionTargets(G, P) % Create regression targets. % center of proposal px = P(:,1) + floor(P(:,3)/2); py = P(:,2) + floor(P(:,4)/2); % center of gt gx = G(:,1) + floor(G(:,3)/2); gy = G(:,2) + floor(G(:,4)/2); tx = (gx - px)./P(:,3); ty = (gy - py)./P(:,4); tw = log(G(:,3)./P(:,3)); th = log(G(:,4)./P(:,4)); y = [tx ty tw th]; % observations in columns end %------------------------------------------------------------------ function [G, P, L, selected] = selectBBoxesForTraining(G, P, L, th) % find proposals that overlap with gt by some threshold. these % are use for training. % Input L contains label of each ground truth. if isempty(G) iou = zeros(0,size(P,1)); elseif isempty(P) iou = zeros(size(G,1),0); else iou = bboxOverlapRatio(G, P, 'union'); end lower = th(1); upper = th(2); % find ground truth that overlaps the most with each % proposal. index of ground truth [v,i] = max(iou,[],1); selected = v >= lower & v <= upper; L = L(i,:); % assign ground truth label to each proposal L = L(selected,:); % select proposals within overlap range P = P(selected,:); G = G(i(selected), :); % create an array that maps groundTruth to proposal, i.e. P(i,:) is assigned to G(i,:). end end end