function detector = trainRCNNObjectDetector(groundTruth, network, options, varargin)
%trainRCNNObjectDetector Train an R-CNN deep learning object detector
% Use of this function requires that you have the Neural Network Toolbox,
% the Parallel Computing Toolbox, the Statistics and Machine Learning
% Toolbox, and a CUDA-capable NVIDIA(TM) GPU with compute capability 3.0 or
% higher.
%
% detector = trainRCNNObjectDetector(groundTruth, network, options) trains
% an R-CNN (Regions with CNN features) based object detector using deep
% learning. An R-CNN detector can be trained to detect multiple object
% classes.
%
% Inputs:
% -------
% groundTruth - a table with 2 or more columns. The first column must
%               contain image file names. The images can be grayscale or
%               true color, and can be in any format supported by IMREAD.
%               The remaining columns must contain M-by-4 matrices of [x,
%               y, width, height] bounding boxes specifying object
%               locations within each image. Each column represents a
%               single object class, e.g. person, car, dog. The table
%               variable names define the object class names. You can use
%               the trainingImageLabeler app to create this table.
%
% network     - a SeriesNetwork or an array of Layer objects defining the
%               pre-trained network. The network is trained to classify
%               object classes defined in the input groundTruth. The
%               SeriesNetwork and Layer objects are available in the Neural
%               Network Toolbox. See help for SeriesNetwork and nnet.cnn.layer
%               for more details.
%
% options     - training options returned by the trainingOptions
%               function from Neural Network Toolbox. The training options
%               define the training parameters of the neural network. See
%               help for trainingOptions for more details. For fine-tuning
%               a pre-trained network for detection, it is recommended to
%               lower the initial learning rate to avoid changing the model
%               parameters too rapidly. For example, use the following
%               syntax to adjust the learning rate:
%
%                 options = trainingOptions('sgdm', ...
%                     'InitialLearnRate', 1e-6);
%
%                 rcnn = trainRCNNObjectDetector(groundTruth, network, options);
%
%               Setting a 'CheckpointPath' using the trainingOptions is
%               also recommended because network training may take a few
%               hours.
%
% [...] = trainRCNNObjectDetector(..., Name, Value) specifies additional
% name-value pair arguments described below:
%
% 'PositiveOverlapRange' A two-element vector that specifies a range of
%                        bounding box overlap ratios between 0 and 1.
%                        Region proposals that overlap with ground truth
%                        bounding boxes within the specified range are used
%                        as positive training samples.
%
%                        Default: [0.5 1]
%
% 'NegativeOverlapRange' A two-element vector that specifies a range of
%                        bounding box overlap ratios between 0 and 1.
%                        Region proposals that overlap with ground truth
%                        bounding boxes within the specified range are used
%                        as negative training samples.
%
%                        Default: [0.1 0.5]
%
% 'NumStrongestRegions'  The maximum number of strongest region proposals
%                        to use for generating training samples. Reduce
%                        this value to speed up processing at the cost
%                        of training accuracy. Set this to inf to use all
%                        region proposals.
%
%                        Default: 2000
%
% [...] = trainRCNNObjectDetector(..., 'RegionProposalFcn', proposalFcn)
% optionally trains an R-CNN detector using a custom region proposal
% function, proposalFcn. If a custom region proposal function is not
% specified, a variant of the EdgeBoxes algorithm is automatically used. A
% custom proposalFcn must have the following functional form:
%
%    [bboxes, scores] = proposalFcn(I)
%
% where the input I is an image defined in the groundTruth table. The
% function must return rectangular bounding boxes in an M-by-4 array. Each
% row of bboxes contains a four-element vector, [x, y, width, height]. This
% vector specifies the upper-left corner and size of a bounding box in
% pixels. The function must also return a score for each bounding box in an
% M-by-1 vector. Higher score values indicate that the bounding box is more
% likely to contain an object. The scores are used to select the strongest
% N regions, where N is defined by the value of 'NumStrongestRegions'.
%
% Notes:
% ------
% - trainRCNNObjectDetector supports parallel computing using
%   multiple MATLAB workers. Enable parallel computing using the
%   <a href="matlab:preferences('Computer Vision System Toolbox')">preferences dialog</a>.
%
% - This implementation of R-CNN does not train an SVM classifier for each
%   object class.
%
% - The overlap ratio used in 'PositiveOverlapRange' and
%   'NegativeOverlapRange' is defined as area(A intersect B) / area(A union B),
%   where A and B are bounding boxes.
%
% - Use the trainingOptions function to enable or disable verbose printing.
%
% - When the network is a SeriesNetwork, the network layers are
%   automatically adjusted to support the number of object classes defined
%   within the groundTruth training data plus an extra "Background" class.
%
% - When the network is an array of Layer objects, the network must have a
%   classification layer that supports the number of object classes plus a
%   background class. Use this input type when you want to customize the
%   learning rates of each layer. You may also use this type of input to
%   resume training from a previous training session. This can be useful if
%   the network requires additional rounds of fine-tuning or if you wish to
%   train with additional training data.
%
% Example - Train a stop sign detector
% ------------------------------------
% load('rcnnStopSigns.mat', 'stopSigns', 'layers')
%
% % Add fullpath to image files
% stopSigns.imageFilename = fullfile(toolboxdir('vision'),'visiondata', ...
%     stopSigns.imageFilename);
%
% % Set network training options to use mini-batch size of 32 to reduce GPU
% % memory usage. Lower the InitialLearnRate to reduce the rate at which
% % network parameters are changed. This is beneficial when fine-tuning a
% % pre-trained network and prevents the network from changing too rapidly.
% options = trainingOptions('sgdm', ...
%     'MiniBatchSize', 32, ...
%     'InitialLearnRate', 1e-6, ...
%     'MaxEpochs', 10);
%
% % Train the R-CNN detector. Training can take a few minutes to complete.
% rcnn = trainRCNNObjectDetector(stopSigns, layers, options, 'NegativeOverlapRange', [0 0.3]);
%
% % Test the R-CNN detector on a test image.
% img = imread('stopSignTest.jpg');
%
% [bbox, score, label] = detect(rcnn, img, 'MiniBatchSize', 32);
%
% % Display strongest detection result
% [score, idx] = max(score);
%
% bbox = bbox(idx, :);
% annotation = sprintf('%s: (Confidence = %f)', label(idx), score);
%
% detectedImg = insertObjectAnnotation(img, 'rectangle', bbox, annotation);
%
% figure
% imshow(detectedImg)
%
% <a href="matlab:showdemo('DeepLearningRCNNObjectDetectionExample')">Learn more about training an R-CNN Object Detector.</a>
%
% See also rcnnObjectDetector, SeriesNetwork, trainingOptions, trainNetwork,
%          nnet.cnn.layer, trainingImageLabeler, trainCascadeObjectDetector.

% References:
% -----------
% Girshick, Ross, et al. "Rich feature hierarchies for accurate object
% detection and semantic segmentation." Proceedings of the IEEE conference
% on computer vision and pattern recognition. 2014.
%
% Girshick, Ross. "Fast r-cnn." Proceedings of the IEEE International
% Conference on Computer Vision. 2015.
%
% Zitnick, C. Lawrence, and Piotr Dollar. "Edge boxes: Locating object
% proposals from edges." Computer Vision-ECCV 2014. Springer International
% Publishing, 2014. 391-405.

vision.internal.requiresStatisticsToolbox(mfilename);
vision.internal.requiresNeuralToolbox(mfilename);
vision.internal.requiresCUDAComputeCapability30(mfilename);

params = parseInputs(groundTruth, network, options, varargin{:});

if params.IsNetwork
    % auto trim network for detection task
    layers = rcnnObjectDetector.initializeRCNNLayers(...
        network, params.NumClasses);
else
    layers = network;
end

layers = removeRandCropAugmentationIfNeeded(layers);

detector = rcnnObjectDetector.train(groundTruth, layers, options, params);

%--------------------------------------------------------------------------
function params = parseInputs(groundTruth, network, options, varargin)

checkGroundTruth(groundTruth);

checkNetwork(network);

checkTrainingOptions(options);

p = inputParser;
p.addParameter('RegionProposalFcn', @rcnnObjectDetector.proposeRegions);
p.addParameter('UseParallel', vision.internal.useParallelPreference());
p.addParameter('PositiveOverlapRange', [0.5 1]);
p.addParameter('NegativeOverlapRange', [0.1 0.5]);
p.addParameter('NumStrongestRegions', 2000);

p.parse(varargin{:});

userInput = p.Results;

rcnnObjectDetector.checkRegionProposalFcn(userInput.RegionProposalFcn);

useParallel = vision.internal.inputValidation.validateUseParallel(userInput.UseParallel);

checkOverlapRatio(userInput.PositiveOverlapRange, 'PositiveOverlapRange');
checkOverlapRatio(userInput.NegativeOverlapRange, 'NegativeOverlapRange');

checkStrongestRegions(p.Results.NumStrongestRegions);

if isa(network, 'SeriesNetwork')
    params.IsNetwork = true;
    checkNetworkLayers(network.Layers);
elseif isa(network, 'nnet.cnn.layer.Layer')
    params.IsNetwork = false;
    checkNetworkLayers(network);
    checkNetworkClassificationLayer(network, groundTruth);
end

params.NumClasses                    = width(groundTruth) - 1;
params.PositiveOverlapRange          = double(userInput.PositiveOverlapRange);
params.NegativeOverlapRange          = double(userInput.NegativeOverlapRange);
params.RegionProposalFcn             = userInput.RegionProposalFcn;
params.UsingDefaultRegionProposalFcn = ismember('RegionProposalFcn', p.UsingDefaults);
params.NumStrongestRegions           = double(userInput.NumStrongestRegions);
params.UseParallel                   = useParallel;
params.BackgroundLabel               = getBackgroundLabel(groundTruth);

assertPositiveAndNegativeOverlapRatioDoNotOverlap(params);
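
%--------------------------------------------------------------------------
% Worked example (illustrative only, not used by the toolbox) of the overlap
% ratio, area(A intersect B) / area(A union B), that the two overlap ranges
% validated in parseInputs refer to. The boxes A and B are made-up values.
% rectint returns the intersection area of [x, y, width, height] rectangles.
%
%   A = [0 0 10 10];                                 % ground truth box
%   B = [5 0 10 10];                                 % region proposal
%   interArea = rectint(A, B);                       % 5 * 10 = 50
%   unionArea = A(3)*A(4) + B(3)*B(4) - interArea;   % 100 + 100 - 50 = 150
%   ratio     = interArea / unionArea;               % 1/3
%
% Under the default ranges, a ratio of 1/3 lies inside 'NegativeOverlapRange'
% [0.1 0.5], so B would be treated as a negative sample for A. The defaults
% also pass the range check called at the end of parseInputs, because
% PositiveOverlapRange(1) = 0.5 is not less than NegativeOverlapRange(2) =
% 0.5. A pair such as [0.4 1] and [0.1 0.5] would fail that check.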
%--------------------------------------------------------------------------
function label = getBackgroundLabel(groundTruth)
% Use a background label that is not included in the ground truth.
label = 'Background';
while ismember(label, groundTruth.Properties.VariableNames)
    label = sprintf('%s_%d', label, randi(9));
end

%--------------------------------------------------------------------------
function checkOverlapRatio(range, name)
validateattributes(range, {'numeric'}, ...
    {'vector', 'numel', 2, 'real', 'finite', 'nonsparse', 'nonnegative', 'increasing', '<=', 1},...
    mfilename, name);

%--------------------------------------------------------------------------
function assertPositiveAndNegativeOverlapRatioDoNotOverlap(params)
% positive and negative ranges should not overlap
if params.PositiveOverlapRange(1) < params.NegativeOverlapRange(2)
    error(message('vision:rcnn:rangesOverlap'));
end

%--------------------------------------------------------------------------
function checkTrainingOptions(options)
validateattributes(options, {'nnet.cnn.TrainingOptionsSGDM'}, {}, mfilename);

if options.MiniBatchSize < 4
    error(message('vision:rcnn:miniBatchSizeTooSmall'));
end

%--------------------------------------------------------------------------
function checkGroundTruth(gt)
validateattributes(gt, {'table'},{'nonempty'}, mfilename, 'groundTruth',1);

if width(gt) < 2
    error(message('vision:rcnn:groundTruthTableWidthLessThanTwo'));
end

%--------------------------------------------------------------------------
function checkNetwork(network)
validateattributes(network,{'SeriesNetwork', 'nnet.cnn.layer.Layer'},{});

%--------------------------------------------------------------------------
% Remove the randcrop augmentation, if present. This augmentation is not
% useful when training an R-CNN network.
%--------------------------------------------------------------------------
function layers = removeRandCropAugmentationIfNeeded(layers)
if ismember('randcrop', layers(1).DataAugmentation)
    warning(message('vision:rcnn:removingRandCrop'));

    augmentations = cellstr(layers(1).DataAugmentation);

    idx = strcmpi('randcrop', augmentations);

    % remove randcrop data augmentation
    augmentations(idx) = [];

    if isempty(augmentations)
        augmentations = 'none';
    end

    layers(1) = imageInputLayer(layers(1).InputSize, ...
        'Name', layers(1).Name,...
        'DataAugmentation', augmentations, ...
        'Normalization', layers(1).Normalization);
end

%--------------------------------------------------------------------------
function checkNetworkLayers(layers)
% First layer must have an image input layer
if ~(numel(layers) >= 1 && isa(layers(1), 'nnet.cnn.layer.ImageInputLayer'))
    error(message('vision:rcnn:firstLayerNotImageInputLayer'));
end

% Last two layers must be softmax followed by a classification layer
if ~(numel(layers) >= 3 && ...
        isa(layers(end), 'nnet.cnn.layer.ClassificationOutputLayer') && ...
        isa(layers(end-1), 'nnet.cnn.layer.SoftmaxLayer'))
    error(message('vision:rcnn:lastLayerNotClassificationLayer'));
end

%--------------------------------------------------------------------------
function checkNetworkClassificationLayer(layers, groundTruth)
% The classification layer must support N classes plus a background class
processedLayers = nnet.cnn.layer.Layer.inferParameters(layers);
if processedLayers(end).OutputSize ~= width(groundTruth)
    error(message('vision:rcnn:notEnoughObjectClasses'));
end

%--------------------------------------------------------------------------
function checkStrongestRegions(N)
if isinf(N)
    % OK, use all regions.
else
    validateattributes(N, ...
        {'numeric'},...
        {'scalar', 'real', 'positive', 'integer', 'nonempty', 'finite', 'nonsparse'}, ...
        mfilename, 'NumStrongestRegions');
end
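
%--------------------------------------------------------------------------
% Illustrative sketch (not part of the toolbox) of a custom region proposal
% function with the [bboxes, scores] = proposalFcn(I) form described in the
% help text, whose scores are what 'NumStrongestRegions' ranks. The name
% exampleGridProposals, the 64-pixel window, the 32-pixel stride and the
% standard-deviation score are all assumptions made for illustration. Save
% the function in its own file to try it.
%
%   function [bboxes, scores] = exampleGridProposals(I)
%   win    = 64;                 % window size in pixels (assumed)
%   stride = 32;                 % step between windows in pixels (assumed)
%   [h, w, ~] = size(I);
%   bw = min(win, w);            % clamp the window to small images
%   bh = min(win, h);
%   [x, y] = meshgrid(1:stride:(w - bw + 1), 1:stride:(h - bh + 1));
%   bboxes = [x(:), y(:), repmat([bw, bh], numel(x), 1)];   % M-by-4
%   scores = zeros(size(bboxes, 1), 1);                     % M-by-1
%   for k = 1:size(bboxes, 1)
%       r = bboxes(k, :);
%       patch = double(I(r(2):r(2)+r(4)-1, r(1):r(1)+r(3)-1, :));
%       scores(k) = std(patch(:));   % local contrast as a crude score
%   end
%
% It could then be passed as a function handle:
%
%   rcnn = trainRCNNObjectDetector(groundTruth, network, options, ...
%       'RegionProposalFcn', @exampleGridProposals);
%
% A fixed grid like this is much weaker than the default EdgeBoxes variant.
% It is only meant to show the required input and output shapes.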