function detector = trainRCNNObjectDetector(groundTruth, network, options, varargin)
%trainRCNNObjectDetector Train an R-CNN deep learning object detector
% Use of this function requires that you have the Neural Network Toolbox,
% the Parallel Computing Toolbox, the Statistics and Machine Learning
% Toolbox, and a CUDA-capable NVIDIA(TM) GPU with compute capability 3.0 or
% higher.
%
% detector = trainRCNNObjectDetector(groundTruth, network, options) trains
% an R-CNN (Regions with CNN features) based object detector using deep
% learning. An R-CNN detector can be trained to detect multiple object
% classes.
%
% Inputs:
% -------
% groundTruth - a table with 2 or more columns. The first column must
%               contain image file names. The images can be grayscale or
%               true color, and can be in any format supported by IMREAD.
%               The remaining columns must contain M-by-4 matrices of [x,
%               y, width, height] bounding boxes specifying object
%               locations within each image. Each column represents a
%               single object class, e.g. person, car, dog. The table
%               variable names define the object class names. You can use
%               the trainingImageLabeler app to create this table.
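%
%               For example, a minimal single-class table can be built
%               directly (the file names and boxes below are hypothetical):
%
%               imageFilename = {'stopSign1.jpg'; 'stopSign2.jpg'};
%               stopSign = {[100 50 30 30]; [120 60 28 28]};
%               groundTruth = table(imageFilename, stopSign);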
%
% network     - a SeriesNetwork or an array of Layer objects defining the
%               pre-trained network. The network is trained to classify
%               object classes defined in the input groundTruth. The
%               SeriesNetwork and Layer objects are available in the Neural
%               Network Toolbox. See help for SeriesNetwork and
%               nnet.cnn.layer for more details.
%
% options     - training options returned by the trainingOptions
%               function from Neural Network Toolbox. The training options
%               define the training parameters of the neural network. See
%               help for trainingOptions for more details. For fine-tuning
%               a pre-trained network for detection, it is recommended to
%               lower the initial learning rate to avoid changing the model
%               parameters too rapidly. For example, use the following
%               syntax to adjust the learning rate:
%
%               options = trainingOptions('sgdm', ...
%                          'InitialLearnRate', 1e-6);
%
%               rcnn = trainRCNNObjectDetector(groundTruth, network, options);
%
%               Setting a 'CheckpointPath' using the trainingOptions
%               function is also recommended because network training may
%               take a few hours.
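%
%               For example, checkpoints can be saved to a temporary
%               folder (tempdir here is only an illustrative choice):
%
%               options = trainingOptions('sgdm', ...
%                          'InitialLearnRate', 1e-6, ...
%                          'CheckpointPath', tempdir);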
%
% [...] = trainRCNNObjectDetector(..., Name, Value) specifies additional
% name-value pair arguments described below:
%
% 'PositiveOverlapRange' A two-element vector that specifies a range of
%                        bounding box overlap ratios between 0 and 1.
%                        Region proposals that overlap with ground truth
%                        bounding boxes within the specified range are used
%                        as positive training samples.
%
%                        Default: [0.5 1]
%
% 'NegativeOverlapRange' A two-element vector that specifies a range of
%                        bounding box overlap ratios between 0 and 1.
%                        Region proposals that overlap with ground truth
%                        bounding boxes within the specified range are used
%                        as negative training samples.
%
%                        Default: [0.1 0.5]
%
% 'NumStrongestRegions'  The maximum number of strongest region proposals
%                        to use for generating training samples. Reduce
%                        this value to speed up processing at the cost of
%                        training accuracy. Set this value to inf to use
%                        all region proposals.
%
%                        Default: 2000
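%
% For example, to train with a tighter positive overlap range and fewer
% region proposals (the values below are illustrative only):
%
%    rcnn = trainRCNNObjectDetector(groundTruth, network, options, ...
%        'PositiveOverlapRange', [0.6 1], 'NumStrongestRegions', 1000);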
%
% [...] = trainRCNNObjectDetector(..., 'RegionProposalFcn', proposalFcn)
% optionally trains an R-CNN detector using a custom region proposal
% function, proposalFcn. If a custom region proposal function is not
% specified, a variant of the Edge Boxes algorithm is used automatically. A
% custom proposalFcn must have the following functional form:
%
%    [bboxes, scores] = proposalFcn(I)
%
% where the input I is an image defined in the groundTruth table. The
% function must return rectangular bounding boxes in an M-by-4 array. Each
% row of bboxes contains a four-element vector, [x, y, width, height]. This
% vector specifies the upper-left corner and size of a bounding box in
% pixels. The function must also return a score for each bounding box in an
% M-by-1 vector. Higher score values indicate that the bounding box is more
% likely to contain an object. The scores are used to select the strongest
% N regions, where N is defined by the value of 'NumStrongestRegions'.
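%
% For example, a minimal custom proposal function might look like the
% following sketch (the grid spacing and the fixed 64-pixel box size are
% arbitrary illustrative choices, not part of the required interface):
%
%    function [bboxes, scores] = myProposalFcn(I)
%        % Propose fixed-size boxes on a coarse grid over the image.
%        [h, w, ~] = size(I);
%        [x, y] = meshgrid(1:32:max(1, w-64), 1:32:max(1, h-64));
%        bboxes = [x(:), y(:), repmat(64, numel(x), 2)];
%        % Without a learned objectness measure, score all boxes equally.
%        scores = ones(size(bboxes, 1), 1);
%    end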
%
% Notes:
% ------
% - trainRCNNObjectDetector supports parallel computing using
%   multiple MATLAB workers. Enable parallel computing using the 
%   <a href="matlab:preferences('Computer Vision System Toolbox')">preferences dialog</a>. 
%
% - This implementation of R-CNN does not train an SVM classifier for each
%   object class. 
%
% - The overlap ratio used in 'PositiveOverlapRange' and
%   'NegativeOverlapRange' is defined as area(A intersect B) / area(A union B),
%   where A and B are bounding boxes.
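%   For example, this ratio (also known as intersection-over-union) can
%   be computed with the bboxOverlapRatio function:
%
%      iou = bboxOverlapRatio([10 10 50 50], [30 30 50 50]);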
% 
% - Use the trainingOptions function to enable or disable verbose printing.
%
% - When the network is a SeriesNetwork, the network layers are
%   automatically adjusted to support the number of object classes defined
%   within the groundTruth training data plus an extra "Background" class.
% 
% - When the network is an array of Layer objects, the network must have a
%   classification layer that supports the number of object classes plus a
%   background class. Use this input type when you want to customize the
%   learning rates of each layer. You may also use this type of input to
%   resume training from a previous training session. This can be useful if
%   the network requires additional rounds of fine-tuning or if you wish to
%   train with additional training data.
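%
%   For example, a sketch of resuming training from a previously trained
%   detector (this assumes a detector named rcnn saved from an earlier
%   session; its Network property holds the trained SeriesNetwork):
%
%      layers = rcnn.Network.Layers;
%      rcnn = trainRCNNObjectDetector(groundTruth, layers, options);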
% 
% Example - Train a stop sign detector
% ------------------------------------
% load('rcnnStopSigns.mat', 'stopSigns', 'layers')
%
% % Add fullpath to image files
% stopSigns.imageFilename = fullfile(toolboxdir('vision'),'visiondata', ...
%     stopSigns.imageFilename);
%
% % Set network training options to use mini-batch size of 32 to reduce GPU
% % memory usage. Lower the InitialLearningRate to reduce the rate at which
% % network parameters are changed. This is beneficial when fine-tuning a
% % pre-trained network and prevents the network from changing too rapidly.
% options = trainingOptions('sgdm', ...
%     'MiniBatchSize', 32, ...
%     'InitialLearnRate', 1e-6, ...
%     'MaxEpochs', 10);
%
% % Train the R-CNN detector. Training can take a few minutes to complete.
% rcnn = trainRCNNObjectDetector(stopSigns, layers, options, 'NegativeOverlapRange', [0 0.3]);
%
% % Test the R-CNN detector on a test image.
% img = imread('stopSignTest.jpg');
%
% [bbox, score, label] = detect(rcnn, img, 'MiniBatchSize', 32);
%
% % Display strongest detection result
% [score, idx] = max(score);
%
% bbox = bbox(idx, :);
% annotation = sprintf('%s: (Confidence = %f)', label(idx), score);
%
% detectedImg = insertObjectAnnotation(img, 'rectangle', bbox, annotation);
%
% figure
% imshow(detectedImg)
%
% % <a href="matlab:showdemo('DeepLearningRCNNObjectDetectionExample')">Learn more about training an R-CNN Object Detector.</a> 
%
% See also rcnnObjectDetector, SeriesNetwork, trainingOptions, trainNetwork,
%          nnet.cnn.layer, trainingImageLabeler, trainCascadeObjectDetector.

% References:
% -----------
% Girshick, Ross, et al. "Rich feature hierarchies for accurate object
% detection and semantic segmentation." Proceedings of the IEEE conference
% on computer vision and pattern recognition. 2014.
%
% Girshick, Ross. "Fast r-cnn." Proceedings of the IEEE International
% Conference on Computer Vision. 2015.
%
% Zitnick, C. Lawrence, and Piotr Dollár. "Edge boxes: Locating object
% proposals from edges." Computer Vision - ECCV 2014. Springer International
% Publishing, 2014. 391-405.

vision.internal.requiresStatisticsToolbox(mfilename);
vision.internal.requiresNeuralToolbox(mfilename);
vision.internal.requiresCUDAComputeCapability30(mfilename);

params = parseInputs(groundTruth, network, options, varargin{:});

if params.IsNetwork
    % Automatically adjust the network layers for the detection task.
    layers = rcnnObjectDetector.initializeRCNNLayers(...
        network, params.NumClasses);
else
    layers = network;
end

layers = removeRandCropAugmentationIfNeeded(layers);

detector = rcnnObjectDetector.train(groundTruth, layers, options, params);

%--------------------------------------------------------------------------
function params = parseInputs(groundTruth, network, options, varargin)
checkGroundTruth(groundTruth);
checkNetwork(network);
checkTrainingOptions(options);

p = inputParser;
p.addParameter('RegionProposalFcn', @rcnnObjectDetector.proposeRegions);
p.addParameter('UseParallel', vision.internal.useParallelPreference());
p.addParameter('PositiveOverlapRange', [0.5 1]);
p.addParameter('NegativeOverlapRange', [0.1 0.5]); 
p.addParameter('NumStrongestRegions', 2000);
p.parse(varargin{:});

userInput = p.Results;

rcnnObjectDetector.checkRegionProposalFcn(userInput.RegionProposalFcn);

useParallel = vision.internal.inputValidation.validateUseParallel(userInput.UseParallel);

checkOverlapRatio(userInput.PositiveOverlapRange, 'PositiveOverlapRange');
checkOverlapRatio(userInput.NegativeOverlapRange, 'NegativeOverlapRange');

checkStrongestRegions(userInput.NumStrongestRegions);

if isa(network, 'SeriesNetwork')
    params.IsNetwork = true;       
    checkNetworkLayers(network.Layers);
elseif isa(network, 'nnet.cnn.layer.Layer')    
    params.IsNetwork = false;    
    checkNetworkLayers(network);
    checkNetworkClassificationLayer(network, groundTruth);    
end

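% The first table column holds image file names; each remaining column
% corresponds to one object class.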
params.NumClasses                    = width(groundTruth) - 1;
params.PositiveOverlapRange          = double(userInput.PositiveOverlapRange);
params.NegativeOverlapRange          = double(userInput.NegativeOverlapRange);
params.RegionProposalFcn             = userInput.RegionProposalFcn;
params.UsingDefaultRegionProposalFcn = ismember('RegionProposalFcn', p.UsingDefaults);
params.NumStrongestRegions           = double(userInput.NumStrongestRegions);
params.UseParallel                   = useParallel;
params.BackgroundLabel               = getBackgroundLabel(groundTruth);

assertPositiveAndNegativeOverlapRatioDoNotOverlap(params);

%--------------------------------------------------------------------------
function label = getBackgroundLabel(groundTruth)
% Use a background label that is not included in the ground truth.
label = 'Background';
while ismember(label, groundTruth.Properties.VariableNames)
    label = sprintf('%s_%d', label, randi(9));
end

%--------------------------------------------------------------------------
function checkOverlapRatio(range, name)
validateattributes(range, {'numeric'}, ...
    {'vector', 'numel', 2, 'real', 'finite', 'nonsparse', 'nonnegative', 'increasing', '<=', 1},...
    mfilename, name);

%--------------------------------------------------------------------------
function assertPositiveAndNegativeOverlapRatioDoNotOverlap(params)
% positive and negative ranges should not overlap
if params.PositiveOverlapRange(1) < params.NegativeOverlapRange(2)
    error(message('vision:rcnn:rangesOverlap'));
end

%--------------------------------------------------------------------------
function checkTrainingOptions(options)
validateattributes(options, {'nnet.cnn.TrainingOptionsSGDM'}, {}, mfilename);

if options.MiniBatchSize < 4
    error(message('vision:rcnn:miniBatchSizeTooSmall'));
end

%--------------------------------------------------------------------------
function checkGroundTruth(gt)
validateattributes(gt, {'table'},{'nonempty'}, mfilename, 'groundTruth',1);

if width(gt) < 2 
    error(message('vision:rcnn:groundTruthTableWidthLessThanTwo'));
end

%--------------------------------------------------------------------------
function checkNetwork(network)
validateattributes(network,{'SeriesNetwork', 'nnet.cnn.layer.Layer'},{});

%--------------------------------------------------------------------------
% Remove the randcrop augmentation, if present. This augmentation is not
% useful when training an R-CNN network.
%--------------------------------------------------------------------------
function layers = removeRandCropAugmentationIfNeeded(layers)

if ismember('randcrop', layers(1).DataAugmentation)
    warning(message('vision:rcnn:removingRandCrop'));
    
    augmentations = cellstr(layers(1).DataAugmentation);
    idx = strcmpi('randcrop', augmentations);
    
    % remove randcrop data augmentation
    augmentations(idx) = [];
    
    if isempty(augmentations)
        augmentations = 'none';
    end        
    
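    % Recreate the image input layer; its DataAugmentation property is
    % read-only after construction, so it cannot be modified in place.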
    layers(1) = imageInputLayer(layers(1).InputSize, ...
        'Name', layers(1).Name,...
        'DataAugmentation', augmentations, ...
        'Normalization', layers(1).Normalization);
end

%--------------------------------------------------------------------------
function checkNetworkLayers(layers)

% First layer must be an image input layer
if ~(numel(layers) >= 1 && isa(layers(1), 'nnet.cnn.layer.ImageInputLayer'))
    error(message('vision:rcnn:firstLayerNotImageInputLayer'));
end

% Last two layers must be a softmax layer followed by a classification layer
if ~(numel(layers) >= 3 && ...
        isa(layers(end), 'nnet.cnn.layer.ClassificationOutputLayer') && ...
        isa(layers(end-1), 'nnet.cnn.layer.SoftmaxLayer'))
    
    error(message('vision:rcnn:lastLayerNotClassificationLayer'));
end

%--------------------------------------------------------------------------
function checkNetworkClassificationLayer(layers, groundTruth)

% The classification layer must support N classes plus a background class
processedLayers = nnet.cnn.layer.Layer.inferParameters(layers);
if processedLayers(end).OutputSize ~= width(groundTruth)
    error(message('vision:rcnn:notEnoughObjectClasses'));
end

%--------------------------------------------------------------------------
function checkStrongestRegions(N)
if isinf(N)
    % OK, use all regions.
else
    validateattributes(N, ...
        {'numeric'},...
        {'scalar', 'real', 'positive', 'integer', 'nonempty', 'finite', 'nonsparse'}, ...
        mfilename, 'NumStrongestRegions');
end