gusucode.com > Matlab脸部识别程序源码 > code/code/learn_perceptron.m

    function [weights,mse,acc] = learn_perceptron(data,targets,rate,threshold,init_method,random_seed,plotflag,k)
% function [weights,mse,acc] = learn_perceptron(data,targets,rate,threshold,init_method,random_seed,plotflag,k)
%
%  Learn the weights for a perceptron (sigmoid-output linear) classifier by
%    stochastic gradient descent on its mean squared error.
%                            A. Student, ICS 175A
%
%  Inputs
%     data: N x (d+1) matrix of training data (expected to be padded with a
%                column of 1's for the bias term)
%     targets: N x 1 vector of target values (+1 or -1)
%     rate: learning rate for the perceptron algorithm (e.g., rate = 0.001)
%     threshold: if the absolute change in MSE from one pass over the data to
%                the next is not greater than threshold, then halt learning
%                (e.g., threshold = 0.000001)
%     init_method: method used to initialize the weights (1 = random, 2 = half
%                way between 2 random points in each group, 3 = half way between
%                the centroids in each group); default value is 1
%     random_seed:  this is an integer used to "seed" the random number generator
%                   for either methods 1 or 2 for initialization (this is useful
%                   to be able to recreate a particular run exactly);
%                   default value is 1234
%     plotflag: 1 means plotting is turned on, default value is 0
%     k: how many iterations between plotting, default value is 100
%
%     Any optional argument (init_method, random_seed, plotflag, k) may be
%     omitted or passed as [] to select its default value.
%
%  Outputs
%     weights: 1 x (d+1) row vector of learned weights
%     mse: mean squared error for learned weights
%     acc: classification accuracy for learned weights (percentage, between 0 and 100)
%
%  Errors
%     Raised when fewer than 4 arguments are given, when data and targets
%     have a different number of rows, or when the MSE reaches the divergence
%     bound (or becomes NaN).

% Upper bound on the MSE; reaching it is treated as divergence of the algorithm
MSEThresh = 1e+30;

% error checking (done before touching any argument so that the intended
% message is shown when arguments are missing)
if nargin < 4
    error('The function takes at least 4 arguments (data, targets, rate, threshold)');
end
if size(data,1) ~= size(targets,1)
    error('The number of rows in the first two arguments (data, targets) does not match!');
end

% fill in defaults for optional arguments that were omitted or passed as [].
% nargin is used instead of exist('name'), which would also match files or
% functions on the path that happen to share the name.
if nargin < 8 || isempty(k)
    k = 100;
end
if nargin < 7 || isempty(plotflag)
    plotflag = 0;
end
if nargin < 6 || isempty(random_seed)
    random_seed = 1234;
end
if nargin < 5 || isempty(init_method)
    init_method = 1;
end

% number of training examples (rows of data)
N = size(data, 1);

% init the weights
w = initialize_weights175(data,targets,init_method,random_seed);

% per-iteration histories; kept separate from the scalar outputs so that
% the full curves are still available for plotting after training
mse_hist = [];
cerr_hist = [];

iteration = 0;
% Always run at least two passes (two MSE values are needed to measure the
% change); afterwards keep going while the MSE is still changing by more
% than threshold and has not hit the divergence bound.
while iteration < 2 || ...
        ( ...
        (abs(mse_hist(iteration) - mse_hist(iteration-1)) > threshold) && ...
        (mse_hist(iteration) < MSEThresh) ...
        )
    iteration = iteration + 1;

    % cycle through all of the examples
    for i = 1:N
        x = data(i, :);
        % net input (pre-activation) and sigmoid output for this example
        net = w * x';
        o = sig(net);
        % gradient step on the squared error; the sigmoid derivative has to
        % be evaluated at the net input, not at the already-squashed output
        w = w + rate * (targets(i) - o) * dsig(net) * x;
    end

    % calculate the errors using current parameter values
    [cerr_hist(iteration), mse_hist(iteration)] = perceptron_error(w, data, targets);

    % visualize the decision boundary if needed
    if plotflag == 1 && mod(iteration - 1, k) == 0
        % sprintf keeps the space before the number (strcat strips the
        % trailing whitespace of character-array inputs)
        t = sprintf('Decision boundary at iteration # %d', iteration);
        weightplot175(data, targets, w, t);

        pause(0.0001);
    end
end

% A NaN MSE also terminates the loop above (every comparison with NaN is
% false), so it must be reported as divergence explicitly.
final_mse = mse_hist(iteration);
if isnan(final_mse) || final_mse >= MSEThresh
    error('The perceptron algorithm has diverged');
end

acc_hist = 100 - cerr_hist;

% create the plots of the MSE and Accuracy Vs. iteration number
% (uses the full histories, not just the final values)
if (plotflag == 1)
    figure(2);
    subplot(2, 1, 1);
    plot(mse_hist,'b-');
    xlabel('iteration');
    ylabel('MSE');

    subplot(2, 1, 2);
    plot(acc_hist,'b-');
    xlabel('iteration');
    ylabel('Accuracy');
end

% outputs: final weights and the metrics of the last iteration
weights = w;
mse = final_mse;
acc = acc_hist(iteration);

function s = sig(x)
% SIG  Rescaled sigmoid, evaluated element-wise.
%   s = sig(x) returns 2/(1+exp(-x)) - 1 for every element of x, i.e. the
%   logistic function stretched and shifted so its values lie between
%   -1 and +1 (sig(0) = 0).
  decay = exp(-x);
  logistic = 1 ./ (1 + decay);
  s = 2 .* logistic - 1;

function ds = dsig(x)
% DSIG  Derivative of the rescaled sigmoid, evaluated element-wise at x.
%   With s = sig(x) = 2/(1+exp(-x)) - 1, the derivative with respect to x
%   is 0.5 * (1 + s) * (1 - s). Note that x is the sigmoid's input, not
%   its output.
  s = sig(x);
  ds = 0.5 .* (s + 1) .* (1 - s);