%imageCategoryClassifier Predict image category.
%   imageCategoryClassifier is returned by trainImageCategoryClassifier
%   function. It contains an SVM classifier trained to recognize an image
%   category. Use of the imageCategoryClassifier requires that you have
%   the Statistics and Machine Learning Toolbox.
%
%   imageCategoryClassifier methods:
%      predict  - Predict image category
%      evaluate - Returns prediction results and confusion matrix for input image sets
%
%   imageCategoryClassifier properties:
%      Labels        - A cell array of category labels
%      NumCategories - Number of trained categories
%
%   Notes:
%   ------
%   - imageCategoryClassifier supports parallel computing using
%     multiple MATLAB workers. Enable parallel computing using the
%     <a href="matlab:preferences('Computer Vision System Toolbox')">preferences dialog</a>.
%
%   Example
%   -------
%   % Load two image categories
%   setDir = fullfile(toolboxdir('vision'),'visiondata','imageSets');
%   imds = imageDatastore(setDir, 'IncludeSubfolders', true, 'LabelSource', 'foldernames');
%
%   % Split data into a training and test set. Pick 30% of images from
%   % each label for training and the remainder (70%) for testing.
%   [trainingSet, testSet] = splitEachLabel(imds, 0.3, 'randomize');
%
%   % Create bag of visual words
%   bag = bagOfFeatures(trainingSet);
%
%   % Train a classifier
%   categoryClassifier = trainImageCategoryClassifier(trainingSet, bag);
%
%   % Evaluate the classifier on test images and display the confusion matrix
%   confMatrix = evaluate(categoryClassifier, testSet)
%
%   % Average accuracy
%   mean(diag(confMatrix))
%
%   % You can apply the newly trained classifier to categorize new images
%   img = imread(fullfile(setDir, 'cups', 'bigMug.jpg'));
%   [labelIdx, scores] = predict(categoryClassifier, img);
%   % Display the string label
%   categoryClassifier.Labels(labelIdx)
%
%   See also imageDatastore, bagOfFeatures, trainImageCategoryClassifier,
%            fitcecoc

% Copyright 2014 MathWorks, Inc.

% References:
% Gabriella Csurka, Christopher R. Dance, Lixin Fan, Jutta Willamowski,
% Cedric Bray "Visual Categorization with Bag of Keypoints",
% Workshop on Statistical Learning in Computer Vision, ECCV, 2004.

classdef imageCategoryClassifier < vision.internal.EnforceScalarHandle

    properties
        % A cell array of category labels
        Labels;
    end

    properties (GetAccess = public, SetAccess = private)
        % Number of trained categories
        NumCategories;
    end

    properties (Access = private)
        % Bag of features object used during the training
        Bag;
        % Multi-class classifier produced using fitcecoc function
        Classifier
        % Options passed into fitcecoc
        LearnerOptions
    end

    %-----------------------------------------------------------------------
    methods

        %------------------------------------------------------------------
        function [label, score] = predict(this, img, varargin)
            %predict Predict image category
            %
            %   [labelIdx, score] = predict(categoryClassifier, I) returns
            %   the predicted label index and score. labelIdx corresponds to
            %   the index of an image set used to train the bag of features.
            %   The 1-by-N score vector provides negated average binary loss
            %   per class of an SVM multi-class classifier that uses the
            %   error correcting output codes (ECOC) approach. N is the
            %   number of classes. labelIdx corresponds to the class with
            %   lowest average binary loss.
            %
            %   [labelIdx, score] = predict(categoryClassifier, imds)
            %   returns M-by-1 predicted labelIdx indices and M-by-N scores
            %   for M images in the ImageDatastore object, imds. N is the
            %   number of classes.
            %
            %   [...] = predict(...) specifies additional name-value pairs
            %   described below:
            %
            %   'Verbose'      Set true to display progress information.
            %
            %                  Default: true
            %
            %   Example
            %   -------
            %   % Load two image categories
            %   setDir = fullfile(toolboxdir('vision'),'visiondata','imageSets');
            %   imds = imageDatastore(setDir, 'IncludeSubfolders', true, 'LabelSource', 'foldernames');
            %
            %   % Split data into training and test sets.
            %   % Pick 30% of images from
            %   % each label for training data and the remainder (70%) for testing.
            %   [trainingSet, testSet] = splitEachLabel(imds, 0.3, 'randomize');
            %
            %   % Create bag of visual words
            %   bag = bagOfFeatures(trainingSet);
            %
            %   % Train a classifier
            %   categoryClassifier = trainImageCategoryClassifier(trainingSet, bag);
            %
            %   % Predict category label for one of the images in testSet
            %   img = readimage(testSet, 1);
            %   [labelIdx, score] = predict(categoryClassifier, img);
            %   categoryClassifier.Labels(labelIdx)

            vision.internal.requiresStatisticsToolbox(mfilename);

            isImageDs  = isa(img, 'matlab.io.datastore.ImageDatastore');
            isImageSet = isa(img, 'imageSet');

            params = imageCategoryClassifier.parseCommonParameters(...
                isImageSet || isImageDs, varargin{:});

            if isImageSet || isImageDs
                printer = vision.internal.MessagePrinter.configure(params.Verbose);

                this.printPredictHeader(printer);

                numImages = numel([img.Files]);

                if isImageSet
                    numImgSets = numel(img);

                    % Display image set information
                    printer.printMessage('vision:imageCategoryClassifier:predictImageSets', numImgSets);
                    this.printCategories(printer, numImgSets, ...
                        'vision:imageCategoryClassifier:imageSetDescription',...
                        {img.Description});

                    [labelClass, scoreClass] = getPredictionOutputClasses(this, img(1));

                    label = zeros(numImages, 1, labelClass);
                    score = zeros(numImages, this.NumCategories, scoreClass);

                    outIdx = 1;
                    for i=1:numImgSets
                        imgSet = img(i);
                        count = imgSet.Count;
                        printer.printMessageNoReturn('vision:imageCategoryClassifier:predictStart',count,i);
                        indices = outIdx:outIdx+count-1;
                        [label(indices), score(indices, :)] = ...
                            this.predictScalarImageSet(imgSet, params.UseParallel);
                        outIdx = outIdx+count;
                        printer.printMessage('vision:imageCategoryClassifier:predictDone');
                    end
                else % image datastore
                    printer.printMessageNoReturn('vision:imageCategoryClassifier:predictStartDS',numImages);
                    [label, score] = this.predictScalarImageSet(img, params.UseParallel);
                    printer.printMessage('vision:imageCategoryClassifier:predictDone');
                end

                printer.printMessage('vision:imageCategoryClassifier:predictFinished').linebreak;
            else
                vision.internal.inputValidation.validateImage(img,'I');
                [label, score] = this.predictImage(img);
            end
        end % end of predict

        %------------------------------------------------------------------
        function [confMat, knownLabel, predictedLabel, score] = evaluate(this, testSets, varargin)
            %evaluate Evaluate the classifier on a collection of images
            %   confMat = evaluate(classifier, imds) returns a normalized
            %   confusion matrix, confMat. Row indices of confMat correspond
            %   to known labels, while columns correspond to the predicted
            %   labels. classifier is an imageCategoryClassifier returned
            %   by trainImageCategoryClassifier. imds is an ImageDatastore
            %   object.
            %
            %   [confMat, knownLabelIdx, predictedLabelIdx, score] =
            %   evaluate(classifier, imds) additionally returns
            %   knownLabelIdx, predictedLabelIdx, and M-by-N score. M is the
            %   total number of images in imds and N is the number of image
            %   categories. Each predictedLabelIdx index corresponds to the
            %   class with largest value in score output.
            %
            %   [...] = evaluate(..., Name, Value) specifies additional
            %   name-value pairs described below:
            %
            %   'Verbose'      Set true to display progress information.
            %
            %                  Default: true
            %
            %   Example
            %   -------
            %   setDir = fullfile(toolboxdir('vision'),'visiondata','imageSets');
            %   imds = imageDatastore(setDir, 'IncludeSubfolders', true, 'LabelSource', 'foldernames');
            %
            %   % Split data into training and test sets. Pick 30% of images from
            %   % each label for training and the remainder (70%) for testing.
            %   [trainingSet, testSet] = splitEachLabel(imds, 0.3, 'randomize');
            %
            %   % Create bag of visual words
            %   bag = bagOfFeatures(trainingSet);
            %
            %   % Train a classifier
            %   categoryClassifier = trainImageCategoryClassifier(trainingSet, bag);
            %
            %   % Evaluate the classifier on test images and display the confusion matrix
            %   confMatrix = evaluate(categoryClassifier, testSet)
            %
            %   % Average accuracy
            %   mean(diag(confMatrix))

            vision.internal.requiresStatisticsToolbox(mfilename);

            varName = 'imds';
            validateattributes(testSets,...
                {'imageSet', 'matlab.io.datastore.ImageDatastore'}, ...
                {'nonempty','vector'}, mfilename, varName);

            isImageSet = isa(testSets, 'imageSet');

            params = imageCategoryClassifier.parseCommonParameters(true, varargin{:});

            printer = vision.internal.MessagePrinter.configure(params.Verbose);

            this.printEvaluateHeader(printer);

            numImages = numel([testSets.Files]);

            if isImageSet
                % Check num categories against testSets
                if imageCategoryClassifier.getNumCategories(testSets) ~= this.NumCategories
                    error(message('vision:imageCategoryClassifier:testSetsAndClassifierMustBeCompatible'));
                end

                if isanyempty(testSets)
                    error(message('vision:dims:expectedNonemptyElements', varName));
                end

                [labelClass, scoreClass] = getPredictionOutputClasses(this, testSets(1));

                % preallocate
                knownLabel     = zeros(numImages, 1, labelClass);
                predictedLabel = zeros(numImages, 1, labelClass);
                score          = zeros(numImages, this.NumCategories, scoreClass);

                testIndex = 1;
                for categoryIndex = 1:this.NumCategories
                    imageSet = testSets(categoryIndex);

                    % process each image set
                    printer.printMessageNoReturn('vision:imageCategoryClassifier:evalCategory',imageSet.Count,categoryIndex);

                    [predicted, categoryScore] = this.predict(imageSet, ...
                        'UseParallel', params.UseParallel,'Verbose',false);

                    actual = repmat(categoryIndex, imageSet.Count, 1);

                    fillIdx = testIndex:testIndex+imageSet.Count-1;

                    knownLabel(fillIdx, :)     = actual;
                    predictedLabel(fillIdx, :) = predicted;
                    score(fillIdx, :)          = categoryScore;

                    testIndex = testIndex + imageSet.Count;

                    printer.printMessage('vision:imageCategoryClassifier:evalCategoryDone');
                end
            else % image datastore
                % Check num categories against testSets
                if imageCategoryClassifier.getNumCategories(testSets) ~= this.NumCategories
                    error(message('vision:imageCategoryClassifier:testSetsAndClassifierMustBeCompatibleDS'));
                end

                printer.printMessageNoReturn('vision:imageCategoryClassifier:evalCategoryDS',numImages);

                [predictedLabel, score] = this.predict(testSets, ...
                    'UseParallel', params.UseParallel,'Verbose',false);

                knownLabel = cast(categorical(testSets.Labels), 'like', predictedLabel);

                printer.printMessage('vision:imageCategoryClassifier:evalCategoryDone');
            end

            printer.linebreak;
            printer.printMessage('vision:imageCategoryClassifier:evalFinished').linebreak;
            printer.printMessage('vision:imageCategoryClassifier:evalDispConfMat').linebreak;

            % Display the results as a confusion matrix
            confMat = confusionmat(knownLabel, predictedLabel);
            confMat = bsxfun(@rdivide,confMat,sum(confMat,2)); % sum rows to get totals for actual labels

            this.printConfusionMatrix(printer, confMat);

            printer.printMessage('vision:imageCategoryClassifier:evalAvgAccuracy',...
                sprintf('%.2f',mean(diag(confMat)))).linebreak;
        end

        % -----------------------------------------------------------------
        function set.Labels(this, labels)
            validateattributes(labels, {'cell'},{'vector'},...
                mfilename,'Labels');

            if ~iscellstr(labels)
                error(message('vision:imageCategoryClassifier:descriptionMustBeAllStrings'));
            end

            if this.NumCategories ~= numel(labels) %#ok<MCSUP>
                error(message('vision:imageCategoryClassifier:numCategoriesMustMatch'));
            end

            this.Labels = labels;
        end

        %------------------------------------------------------------------
        function s = saveobj(this)
            s.Labels         = this.Labels;
            s.NumCategories  = this.NumCategories;
            s.Classifier     = this.Classifier;
            s.LearnerOptions = this.LearnerOptions;

            % Invoke customized saveobj for bagOfFeatures
            s.Bag = saveobj(this.Bag);
        end

    end % end public methods

    %======================================================================
    methods (Hidden, Static)

        % -----------------------------------------------------------------
        function this = create(imgSet, bag, varargin)
            this = imageCategoryClassifier(imgSet, bag, varargin{:});
        end

        % -----------------------------------------------------------------
        function params = parseCommonParameters(isImageSet, varargin)
            parser = inputParser();
            parser.addParameter('Verbose', true);
            parser.addParameter('UseParallel', vision.internal.useParallelPreference());

            parser.parse(varargin{:});

            vision.internal.inputValidation.validateLogical(parser.Results.Verbose,'Verbose');

            useParallel = vision.internal.inputValidation.validateUseParallel(parser.Results.UseParallel);

            params.Verbose     = logical(parser.Results.Verbose);
            params.UseParallel = logical(useParallel);

            % warn about ignored options
            if ~isImageSet
                wasVerboseSpecified     = ~any(strcmp(parser.UsingDefaults,'Verbose'));
                wasUseParallelSpecified = ~any(strcmp(parser.UsingDefaults,'UseParallel'));

                if wasVerboseSpecified || wasUseParallelSpecified
                    warning(message('vision:imageCategoryClassifier:ignoreVerboseAndParallel'));
                end
            end
        end

        %------------------------------------------------------------------
        function params = parseInputs(varargin)
            parser = inputParser();
            parser.addParameter('Verbose', true);
            parser.addParameter('SVMOptions', []);
            parser.addParameter('LearnerOptions', templateSVM(), @imageCategoryClassifier.checkTemplateSVM);
            parser.addParameter('UseParallel', vision.internal.useParallelPreference());

            parser.parse(varargin{:});

            vision.internal.inputValidation.validateLogical(parser.Results.Verbose,'Verbose');

            useParallel = vision.internal.inputValidation.validateUseParallel(parser.Results.UseParallel);

            params.Verbose        = logical(parser.Results.Verbose);
            params.SVMOptions     = parser.Results.SVMOptions;
            params.LearnerOptions = parser.Results.LearnerOptions;
            params.UseParallel    = logical(useParallel);
        end

        %------------------------------------------------------------------
        function tf = checkTemplateSVM(template)
            validateattributes(template, {'classreg.learning.FitTemplate'},...
                {'scalar'}, mfilename);

            if ~strcmp(template.Method, 'SVM')
                error(message('vision:imageCategoryClassifier:mustBeSVMTemplate'));
            end

            tf = true;
        end

        %------------------------------------------------------------------
        function this = loadobj(s)
            this = imageCategoryClassifier();

            % Invoke customized loadobj for bagOfFeatures
            this.Bag = bagOfFeatures.loadobj(s.Bag);

            % Set remaining properties
            this.Labels         = s.Labels;
            this.NumCategories  = s.NumCategories;
            this.LearnerOptions = s.LearnerOptions;

            if isa(s.Classifier, 'ClassificationECOC')
                % Before R2015a, the full model was saved. Use the compact
                % version now.
                this.Classifier = compact(s.Classifier); % removes training vectors.
            else
                this.Classifier = s.Classifier;
            end
        end

        %------------------------------------------------------------------
        function tf = isImageSet(in)
            tf = isa(in, 'imageSet');
        end

        %------------------------------------------------------------------
        function n = getNumCategories(imgSetOrDs)
            n = max(1, height(countEachLabel(imgSetOrDs)));
        end
    end

    %======================================================================
    methods (Access = protected)

        %------------------------------------------------------------------
        % Constructor
        function this = imageCategoryClassifier(imgSet, bag, varargin)
            if nargin ~= 0
                params = imageCategoryClassifier.parseInputs(varargin{:});

                varName = 'imds';
                validateattributes(bag, {'bagOfFeatures'}, {'nonempty'}, mfilename, 'bag');
                validateattributes(imgSet, {'imageSet', 'matlab.io.datastore.ImageDatastore'}, ...
                    {'nonempty','vector'}, mfilename, varName);

                if imageCategoryClassifier.isImageSet(imgSet)
                    if isanyempty(imgSet)
                        error(message('vision:dims:expectedNonemptyElements', varName));
                    end

                    this.NumCategories = numel(imgSet);
                    this.setLabelsFromImageSetDescriptions(imgSet);

                    if this.NumCategories < 2
                        error(message('vision:imageCategoryClassifier:atLeastTwoElementSet'));
                    end
                else % image datastore
                    undefined = isundefined(categorical(imgSet.Labels));
                    if isempty(imgSet.Labels) || any(undefined)
                        error(message('vision:imageCategoryClassifier:undefinedLabels'));
                    end

                    this.setNumCategoriesAndLabelsFromImageDs(imgSet);

                    if this.NumCategories < 2
                        error(message('vision:imageCategoryClassifier:atLeastTwoUniqueLabels'));
                    end
                end

                this.LearnerOptions = params.LearnerOptions;

                this.Bag = bag;

                printer = vision.internal.MessagePrinter.configure(params.Verbose);

                this.printHeader(printer);

                % Train an error-correcting output codes (ECOC) SVM classifier
                [fvectors, labels] = this.createFeatureVectors(imgSet, printer, params);

                this.trainEcocClassifier(fvectors, labels, params.UseParallel);

                this.printFooter(printer);
            end
        end % end of Constructor

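        % Sketch (illustrative only, not used by the class): the constructor
        % stores params.LearnerOptions, which parseInputs accepts only when
        % it is an SVM template. One way to supply such a template through
        % trainImageCategoryClassifier is shown below. The BoxConstraint and
        % KernelFunction values are arbitrary choices for illustration, and
        % the image folder is the toolbox sample data used in the class help.
        %
        %   setDir = fullfile(toolboxdir('vision'),'visiondata','imageSets');
        %   imds = imageDatastore(setDir, 'IncludeSubfolders', true, 'LabelSource', 'foldernames');
        %   trainingSet = splitEachLabel(imds, 0.3, 'randomize');
        %   bag = bagOfFeatures(trainingSet);
        %
        %   % Build an SVM learner template and pass it as 'LearnerOptions'
        %   opts = templateSVM('BoxConstraint', 1.1, 'KernelFunction', 'gaussian');
        %   classifier = trainImageCategoryClassifier(trainingSet, bag, ...
        %       'LearnerOptions', opts);
        %
        % A template for any other learner type (for example templateTree)
        % would be rejected by checkTemplateSVM.
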
        %------------------------------------------------------------------
        function [fvectors, labels] = createFeatureVectors(this, imgSet, ...
                printer, params)

            numImages = numel([imgSet.Files]);

            if imageCategoryClassifier.isImageSet(imgSet)
                % TODO refactor and move to internal

                % initialize outputs
                fvectors = zeros(numImages, this.Bag.VocabularySize);
                labels   = zeros(numImages, 1);

                % Note that the loop could be avoided since encode() can handle
                % an array of image sets, but then we wouldn't be able to print
                % out progress until the entire encoding process was finished.
                outIdx = 1;
                for cIdx=1:this.NumCategories
                    printer.printMessageNoReturn('vision:imageCategoryClassifier:encodingFeatures',cIdx);

                    count = imgSet(cIdx).Count;

                    % Encode each category
                    fvectors(outIdx:outIdx+count-1,:) = this.Bag.encode(imgSet(cIdx), ...
                        'UseParallel', params.UseParallel,'Verbose',false);

                    labels(outIdx:outIdx+count-1,:) = cIdx;

                    outIdx = outIdx+count;

                    printer.printMessage('vision:imageCategoryClassifier:encodingFeaturesDone');
                end
            else % image data store
                printer.printMessageNoReturn('vision:imageCategoryClassifier:encodingFeaturesDS',numImages);

                fvectors = this.Bag.encode(imgSet, 'UseParallel', params.UseParallel, 'Verbose', false);

                printer.printMessage('vision:imageCategoryClassifier:encodingFeaturesDone');

                % output labels as indices
                labels = double(categorical(imgSet.Labels));
            end

            printer.linebreak;
        end % createFeatureVectors

        %------------------------------------------------------------------
        function trainEcocClassifier(this, fvectors, labels, useParallel)
            opts = statset('UseParallel', useParallel);

            this.Classifier = fitcecoc(fvectors, labels, ...
                'Learners', this.LearnerOptions, ...
                'Coding', 'onevsall', ...
                'Prior', 'uniform', ...
                'Options', opts);

            % Remove training data from classifier
            this.Classifier = compact(this.Classifier);
        end % trainEcocClassifier

        %------------------------------------------------------------------
        function setNumCategoriesAndLabelsFromImageDs(this, imgSet)
            labels = categorical(imgSet.Labels);
            numCategories = numel(categories(labels));

            % image datastore. Store labels as strings to preserve
            % behavior of string only labels. This may be enhanced in
            % the future to support other label types.
            labels = categories(labels);
            labels = reshape(labels,1,[]); % preserve row vector

            this.NumCategories = numCategories;
            this.Labels = labels;
        end

        %------------------------------------------------------------------
        function setLabelsFromImageSetDescriptions(this, imgSet)
            % Set labels based on Description property of the image sets.
            % Use index value as label if an image set has an empty
            % Description.
            labels = {imgSet(:).Description};
            empty  = strcmpi(labels,'');

            indexAsLabel = arrayfun(@(x)sprintf('%d',x),...
                1:this.NumCategories,'UniformOutput',false);

            labels(empty) = indexAsLabel(empty);

            this.Labels = labels;
        end

        %------------------------------------------------------------------
        % predict for a single image set.
        function [label, score] = predictScalarImageSet(this, imgSet, useParallel)
            validateattributes(imgSet, ...
                {'imageSet', 'matlab.io.datastore.ImageDatastore'}, ...
                {'scalar'}, mfilename);

            featureVectors = this.Bag.encode(imgSet, 'Verbose', false, ...
                'UseParallel', useParallel);

            % we are using negated average binary loss as a score
            opts = statset('UseParallel', useParallel);
            [label, score] = predict(this.Classifier, featureVectors,...
                'Options', opts);
        end

        %------------------------------------------------------------------
        % predict for a single image
        function [label, score] = predictImage(this, img)
            featureVector = this.Bag.encode(img);

            % we are using negated average binary loss as a score
            [label, score] = predict(this.Classifier, featureVector);
        end % end of predictImage

        %------------------------------------------------------------------
        function [labelClass, scoreClass] = getPredictionOutputClasses(this, imgSet)
            numImages = numel([imgSet.Files]);
            if numImages > 1
                % get the output class of labels and scores by calling
                % the classifier.
                [label, score] = predict(this, imgSet.readimage(1));
                labelClass = class(label);
                scoreClass = class(score);
            else
                labelClass = 'double';
                scoreClass = 'double';
            end
        end
    end

    %======================================================================
    % Verbose printing methods
    methods (Hidden, Access = protected)

        %------------------------------------------------------------------
        function printHeader(this, printer)
            printer.linebreak;
            printer.printMessage('vision:imageCategoryClassifier:startTrainingTitle', this.NumCategories);
            printer.print('--------------------------------------------------------\n');
            this.printCategories(printer, this.NumCategories, ...
                'vision:imageCategoryClassifier:categoryDescription', ...
                this.Labels);
        end

        %------------------------------------------------------------------
        function printEvaluateHeader(this, printer)
            printer.linebreak;
            printer.printMessage('vision:imageCategoryClassifier:startEval',this.NumCategories);
            printer.print('-------------------------------------------------------\n\n');
            this.printCategories(printer,...
                this.NumCategories, ...
                'vision:imageCategoryClassifier:categoryDescription',...
                this.Labels);
        end

        %------------------------------------------------------------------
        function printPredictHeader(this, printer)
            printer.linebreak;
            printer.printMessage('vision:imageCategoryClassifier:predictTitle');
            printer.print('---------------------------------------------------------------------\n\n');

            % Display the categories for which this classifier is trained
            this.printCategories(printer, this.NumCategories,...
                'vision:imageCategoryClassifier:categoryDescription', ...
                this.Labels);
        end

        %------------------------------------------------------------------
        function printCategories(~,printer,numSets, msgID, labels)
            for i = 1:numSets
                printer.printMessage(msgID, i, labels{i});
            end
            printer.linebreak;
        end

        %------------------------------------------------------------------
        function printFooter(~, printer)
            cmd = printer.makeHyperlink('evaluate','help imageCategoryClassifier/evaluate');
            printer.printMessage('vision:imageCategoryClassifier:finishedTraining',cmd);
            printer.linebreak;
        end

        %------------------------------------------------------------------
        function printConfusionMatrix(this,printer,confMat)
            % exit early if not verbose
            if ~isa(printer,'vision.internal.VerbosePrinter')
                return
            end

            labels = this.Labels;

            % column widths for data elements
            minColWidth = 4; % accommodates "0.00" output (%.2f)

            maxLabelWidth = max(cellfun(@(x)numel(x),labels))+1;

            % define format for row and column headings
            fmt = sprintf('%%-%ds ', max(maxLabelWidth,numel('KNOWN')));

            % print column heading
            colHeading = sprintf([fmt '%s '],'KNOWN','|');
            sz = numel(colHeading);
            for j = 1:numel(labels)
                colWidth = max(numel(labels{j}), minColWidth);
                format = sprintf('%%-%ds ', colWidth);
                colHeading = sprintf(['%s' format], colHeading, labels{j});
            end

            printer.linebreak;

            % center PREDICTED over row data
            centeringIdx = floor((numel(colHeading)-sz)/2 + sz) - 5; % -5 for numel('PREDICTED')/2
            printer.print('%sPREDICTED\n',repmat(' ',1,centeringIdx));
            printer.print('%s\n',colHeading);

            % add "----" between
            % column headings and data
            printer.print('%s\n',repmat('-',1,numel(colHeading)));

            % print rows of the table
            for i = 1:numel(labels)
                % print row heading
                printer.print([fmt '%s '],labels{i},'|');

                % print the data
                for j = 1:numel(labels)
                    colWidth = max(numel(labels{j}),minColWidth);
                    format = sprintf('%%-%d.2f ',colWidth);
                    printer.print(format,confMat(i,j));
                end
                printer.linebreak;
            end
            printer.linebreak;
        end
    end
end
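
% Sketch (illustrative only, not part of the class): how the evaluate
% method turns raw confusion counts into the row-normalized matrix it
% returns, and how the reported average accuracy follows from it. The
% counts below are made-up values for a hypothetical two-category test
% set; rows are known labels and columns are predicted labels.
%
%   counts = [8 2; 1 9];
%
%   % Divide each row by its total, as evaluate does with bsxfun
%   confMat = bsxfun(@rdivide, counts, sum(counts, 2));
%   % confMat is [0.8 0.2; 0.1 0.9], so every row sums to 1
%
%   % Average accuracy is the mean of the per-class accuracies on the diagonal
%   avgAccuracy = mean(diag(confMat));
%   % avgAccuracy is 0.85
%
% Because each row is normalized separately, a category with few test
% images weighs as much in the average as one with many.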