gusucode.com > 小波神经网络的MATLAB原代码程序 > waveletneural/wavneural.m
clc clear %step 1========================= %定义输入样本; t=0:0.01:1.5; x=-sin(2*pi*t); targ=[0 0 1 1 0 0 ]; eta=0.02;aerfa=0.935; %初始化连接权wjh(输出层和隐层的连接权);whi(隐层和输出层的连接权); %假设小波函数节点数为:H个;样本数为P; %输出节点数为:J个;输入节点数为:I个; H=15;P=2; I=length(t); J=length(targ); %初始化小波参数 b=rand(H,1); a=rand(H,1); %初始化权系数; whi=rand(I,H); wjh=rand(H,J); %阈值初始化; b1=rand(H,1); b2=rand(J,1); p=0; %保存的误差; Err_NetOut=[]; flag=1;count=0; while flag>0 flag=0; count=count+1; %step 2================================= xhp1=0; for h=1:H for i=1:I xhp1=xhp1+whi(i,h)*x(i); end ixhp(h)=xhp1+b1(h); xhp1=0; end for h=1:H oxhp(h)=fai((ixhp(h)-b(h))/a(h)); end %step 3==================================== ixjp1=0; for j=1:J for h=1:H ixjp1=ixjp1+wjh(h,j)*oxhp(h); end ixjp(j)=ixjp1+b2(j); ixjp1=0; end for i=1:J oxjp(i)=fnn(ixjp(i)); end %step 6==保存每次误差===== wuchayy=1/2*sumsqr(oxjp-targ); %E_x=1/2*sumsqr(x); Err_NetOut=[Err_NetOut wuchayy];%保存每次的误差; %Err_rate=Err_NetOut/E_x; %Err_rate %oxjp %求detaj ,detab2================================== for j=1:J detaj(j)=-(oxjp(j)-targ(j))*oxjp(j)*(1-oxjp(j)); end for j=1:J for h=1:H detawjh(h,j)=eta*detaj(j)*oxhp(h); end end detab2=eta*detaj; %求detah, detawhi detab1 detab detaa;======================== sum=0; for h=1:H for j=1:J sum=detaj(j)*wjh(h,j)*diffai((ixhp(h)-b(h))/a(h))/a(h)+sum; end detah(h)=sum; sum=0; end for h=1:H for i=1:I detawhi(i,h)=eta*detah(h)*x(i); end end detab1=eta*detah; detab=-eta*detah; for h=1:H detaa(h)=-eta*detah(h)*((ixhp(h)-b(h))/a(h)); end %引入动量因子aerfa,修正各个系数========================================== wjh=wjh+(1+aerfa)*detawjh; whi=whi+(1+aerfa)*detawhi; a=a+(1+aerfa)*detaa'; b=b+(1+aerfa)*detab'; b1=b1+(1+aerfa)*detab1'; b2=b2+(1+aerfa)*detab2'; %====================================================== %引入修正算法!! %判断所有的样本是否计算完================================== p=p+1; if p~=P flag=flag+1; else if Err_NetOut(end)>0.05 flag=flag+1; else figure; plot(Err_NetOut); title('误差曲线'); disp('目标达到'); %disp(oxjp); end end if count>2000 figure; plot(Err_NetOut); title('误差曲线'); disp('目标未达到'); disp(oxjp); break; end end