gusucode.com > Matlab脸部识别程序源码 > code/code/learn_perceptron.m
function [weights,mse,acc] = learn_perceptron(data,targets,rate,threshold,init_method,random_seed,plotflag,k)
% LEARN_PERCEPTRON  Learn the weights of a perceptron (linear) classifier
% by online gradient descent on the mean squared error of a rescaled
% sigmoid unit.
% A. Student, ICS 175A
%
% Inputs
%   data:        N x (d+1) matrix of training data (padded with 1's)
%   targets:     N x 1 vector of target values (+1 or -1)
%   rate:        learning rate for the perceptron algorithm (e.g., rate = 0.001)
%   threshold:   if the reduction in MSE from one iteration to the next is
%                *less* than threshold, halt learning (e.g., threshold = 0.000001)
%   init_method: method used to initialize the weights (1 = random, 2 = half
%                way between 2 random points in each group, 3 = half way
%                between the centroids in each group); default 1
%   random_seed: integer used to "seed" the random number generator for
%                init methods 1 or 2 (useful to recreate a particular run
%                exactly); default 1234
%   plotflag:    1 means plotting is turned on, default value is 0
%   k:           how many iterations between plotting (e.g., k = 100)
%
% Outputs
%   weights: 1 x (d+1) row vector of learned weights
%   mse:     mean squared error for the learned weights
%   acc:     classification accuracy for the learned weights
%            (percentage, between 0 and 100)

% Upper bound on the MSE; lets us detect divergence of the algorithm.
MSEThresh = 1e+30;

[N, d] = size(data);

% ---- error checking ---------------------------------------------------
if nargin < 4
    error('The function takes at least 4 arguments (data, targets, rate, threshold)');
end
if size(data,1) ~= size(targets,1)
    error('The number of rows in the first two arguments (data, targets) does not match!');
end

% ---- defaults for the optional arguments ------------------------------
% exist(name,'var') restricts the lookup to workspace variables; a bare
% exist(name) would also be satisfied by any function of that name on
% the MATLAB path and silently skip the default.
if ~exist('k','var'),           k = 100;            end
if ~exist('plotflag','var'),    plotflag = 0;       end
if ~exist('random_seed','var'), random_seed = 1234; end
if ~exist('init_method','var'), init_method = 1;    end

% The data matrix is padded with a column of 1's for the bias term.
d = d - 1;   % d = "real" dimensionality of the data

% ---- initialize the weights -------------------------------------------
w = initialize_weights175(data,targets,init_method,random_seed);

iteration = 0;
mse_hist  = [];   % MSE after each pass through the data
cerror    = [];   % classification-error history (from perceptron_error)

% NOTE: the short-circuit operators || and && are essential here.  While
% iteration < 2 the history entries mse_hist(iteration) and
% mse_hist(iteration-1) do not exist yet; the element-wise | and & would
% still evaluate them and crash on the very first test of the condition.
while iteration < 2 || ...
      ( abs(mse_hist(iteration) - mse_hist(iteration-1)) > threshold && ...
        mse_hist(iteration) < MSEThresh )

    iteration = iteration + 1;

    % One online pass, cycling through all of the examples.
    for i = 1:N
        net = w * data(i,:)';   % pre-activation ("net") input
        o   = sig(net);         % unthresholded sigmoid output
        % Delta rule: the sigmoid derivative must be evaluated at the
        % net input, not at the output o -- dsig already applies sig to
        % its argument, so dsig(o) would squash the input twice.
        w = w + rate * (targets(i) - o) * dsig(net) * data(i,:);
    end

    % Record the errors obtained with the current parameter values.
    [cerror(iteration), mse_hist(iteration)] = perceptron_error(w, data, targets);

    % Visualize the decision boundary every k iterations if requested.
    if plotflag == 1 && mod(iteration - 1, k) == 0
        t = strcat('Decision boundary at iteration # ', num2str(iteration));
        weightplot175(data, targets, w, t);
        pause(0.0001);
    end
end

if mse_hist(iteration) >= MSEThresh
    error('The perceptron algorithm has diverged');
end

weights  = w;
acc_hist = 100 - cerror;   % accuracy history, one entry per iteration

% Plot the full MSE and accuracy learning curves BEFORE collapsing the
% histories into the scalar return values; plotting afterwards would
% draw a single point instead of a curve.
if plotflag == 1
    figure(2);
    subplot(2, 1, 1);
    plot(mse_hist, 'b-');
    xlabel('iteration');
    ylabel('MSE');
    subplot(2, 1, 2);
    plot(acc_hist, 'b-');
    xlabel('iteration');
    ylabel('Accuracy');
end

mse = mse_hist(iteration);
acc = acc_hist(iteration);


function s = sig(x)
% SIG  Rescaled sigmoid squashing function mapping R -> (-1, +1).
s = 2 ./ (1 + exp(-x)) - 1;


function ds = dsig(x)
% DSIG  Derivative of the rescaled sigmoid, as a function of the NET
% input x:  d/dx sig(x) = .5 * (sig(x) + 1) * (1 - sig(x)).
ds = .5 .* (sig(x) + 1) .* (1 - sig(x));