gusucode.com > 用mushrooms数据对模式识别课程讲述的各种模式分类方法matlab源码程序 > pattern-recognition-simulation/KNN_function.m
function result_class=KNN_function(traing_example_original,test_example,K) %该函数用来进行KNN分类; %输入参数: traing_example_original:训练样本,带标签 % test_example:测试样本,带标签 % K:表示K近邻 %输出参数: result_class是一个m*1列向量,包括类别信息 m=size(test_example,1); %获得测试样本的个数 n=size(traing_example_original,1);%获得训练样本的个数 %初始化距离矩阵,oushi_distance行表示测试样本,列表示训练样本,行列交点表示两者的距离 oushi_distance=zeros(m,n); %计算一个测试样本与每个训练样本的欧氏距离 for i=1:m for j=1:n oushi_distance(i,j)=sum(((test_example(i,2:end)-traing_example_original(j,2:end)).^2)'); end end for i=1:m temp=oushi_distance(i,:); temp=sort(temp); Kmin_oushi_distance(1,1:K)=temp(1,1:K);%找到最小的K个近邻 for p = 1:K index_column = find(oushi_distance(i,:)==Kmin_oushi_distance(1,p));%在oushi_distance(i,:)中找出最小的距离对应的列号 if p == 1 k_traing_example_original = traing_example_original(index_column(1,1:end),:); else k_traing_example_original = [k_traing_example_original;traing_example_original(index_column(1,1:end),:)]; end end k_traing_example_original = k_traing_example_original(1:K,:); index_column_class1 = find(k_traing_example_original(:,1)==1); k1 = size(index_column_class1,1); if k1>=K-k1 result_class(i,1)=1; else result_class(i,1)=2; end end