% Source: gusucode.com > Vision toolbox MATLAB source code > vision/+vision/+internal/+rcnn/batchReadAndCrop.m

    function [miniBatchData, miniBatchResponse] = batchReadAndCrop(...
    imds, tbl, trainingSamples, regionResizer, randomSelector, ...
    start, stop, numPos, numPosSamples, currentMiniBatchSize)
% batchReadAndCrop Read images start..stop-1 and crop one R-CNN mini-batch.
%
%   [miniBatchData, miniBatchResponse] = batchReadAndCrop(imds, tbl, ...
%       trainingSamples, regionResizer, randomSelector, start, stop, ...
%       numPos, numPosSamples, currentMiniBatchSize)
%
%   For every image index i in start:stop-1 the image is read with
%   readimage(imds, i), its channel count is matched to
%   regionResizer.ImageSize, and region proposals from
%   tbl.RegionProposalBoxes{i} are cropped with regionResizer.cropAndResize:
%   all proposals selected by trainingSamples.Positive{i}, and a random
%   subset of those selected by trainingSamples.Negative{i}. From the pooled
%   crops, at most numPos positives and at most
%   currentMiniBatchSize - numPos negatives are drawn at random.
%
%   Inputs (as used in this function):
%     imds                 - image datastore read via readimage(imds, i).
%     tbl                  - table; RegionProposalBoxes{i} holds one box
%                            per row.
%     trainingSamples      - struct with cell fields Positive, Negative and
%                            Labels; Positive{i}/Negative{i} index rows of
%                            RegionProposalBoxes{i} and entries of Labels{i}.
%     regionResizer        - object providing ImageSize and cropAndResize.
%     randomSelector       - object providing randperm(n, k).
%     start, stop          - half-open image index range [start, stop).
%     numPos               - number of positive samples wanted.
%     numPosSamples        - population size used to draw the positives;
%                            presumably the total number of positive crops
%                            in [start, stop) -- TODO confirm against caller.
%     currentMiniBatchSize - total mini-batch size; must be >= 4.
%
%   Outputs:
%     miniBatchData     - selected positive crops followed by the selected
%                         negative crops, concatenated along dimension 4.
%     miniBatchResponse - nnet.internal.cnn.util.dummify applied to the
%                         labels of the selected crops, in the same order.

networkImageSize = regionResizer.ImageSize;

% mini-batch size smaller than 4 will have no positive samples!
assert(currentMiniBatchSize >= 4);

numNeg = currentMiniBatchSize - numPos;

% Gather samples from each image in the range.
numImages   = stop - start;
posSamples  = cell(numImages,1);
negSamples  = cell(numImages,1);
posResponse = cell(numImages,1);
negResponse = cell(numImages,1);

% Use ceil rather than round. round can under-fill the mini-batch, for
% example numNeg = 1 with numImages = 3 yields 0 negatives per image. Any
% surplus is trimmed below, where nidx keeps at most numNeg negatives.
negSamplesPerImage = ceil(numNeg/numImages);

for k = 1:numImages
    i = start + k - 1;

    samples = tbl.RegionProposalBoxes{i};
    labels  = trainingSamples.Labels{i};

    I = readimage(imds, i);
    I = localConvertImageToMatchNumberOfNetworkImageChannels(I, networkImageSize);

    % Crop every positive proposal in this image.
    posIdx = trainingSamples.Positive{i};
    posSamples{k}  = regionResizer.cropAndResize(I, samples(posIdx,:));
    posResponse{k} = labels(posIdx);

    % Only crop out enough negative samples to fill the current mini-batch.
    negIdx    = trainingSamples.Negative{i};
    bb        = samples(negIdx,:);
    negLabels = labels(negIdx);
    N = size(bb,1);

    id = randomSelector.randperm(N, min(N, negSamplesPerImage));
    negSamples{k}  = regionResizer.cropAndResize(I, bb(id,:));
    % Keep the responses aligned with the negatives that were actually
    % cropped. Storing all negative labels would shift later images'
    % responses relative to their crops after concatenation.
    negResponse{k} = negLabels(id);
end

posSamples = cat(4, posSamples{:});
negSamples = cat(4, negSamples{:});
numNegSamples = size(negSamples,4);

posResponse = vertcat(posResponse{:});
negResponse = vertcat(negResponse{:});

% There may be more positive/negative samples than we need, so randomly
% sample enough to fill the mini-batch.
pidx = randomSelector.randperm(numPosSamples, min(numPos, numPosSamples));
nidx = randomSelector.randperm(numNegSamples, min(numNeg, numNegSamples));

miniBatchData = cat(4, posSamples(:,:,:,pidx), negSamples(:,:,:,nidx));

% Data in the mini-batch need not be shuffled. Training responses are
% averaged over all mini-batch samples, so order does not matter.
miniBatchResponse = nnet.internal.cnn.util.dummify(...
    [posResponse(pidx); negResponse(nidx)]);

%--------------------------------------------------------------------------
function I = localConvertImageToMatchNumberOfNetworkImageChannels(I, imageSize)
% Make the channel count of image I agree with the network input size.
%
%   imageSize is treated as describing a 3-channel input only when it has
%   three elements and its last element equals 3.
%   - A single-plane image (2-D array) fed to a 3-channel network is
%     replicated into three identical planes.
%   - A multi-plane image (more than 2 dimensions) fed to any other network
%     is converted with rgb2gray.
%   In every other case I is returned unchanged.

netWantsThreeChannels = (numel(imageSize) == 3) && (imageSize(end) == 3);
imageHasPlanes        = ndims(I) > 2;

if netWantsThreeChannels && ~imageHasPlanes
    % Grayscale image, colour network: stack the plane three times.
    I = repmat(I, [1 1 3]);
elseif imageHasPlanes && ~netWantsThreeChannels
    % Colour image, non-colour network: collapse to a single plane.
    I = rgb2gray(I);
end