本文共 2222 字,大约阅读时间需要 7 分钟。
学习莫烦 PyTorch 教程视频时写的练习代码，并对部分代码添加了注释。
# classification.py
# Binary classification of two Gaussian point clouds with a tiny MLP,
# following the Morvan PyTorch tutorial.
import math
import pdb

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from torch.autograd import Variable


def make_data(n_per_class=100):
    """Build a two-cluster toy dataset.

    Class-0 points are drawn with mean 2 and standard deviation 1 in each
    coordinate; class-1 points with mean -2 and standard deviation 1.

    Args:
        n_per_class: number of points generated for each of the two classes.

    Returns:
        (x, y): ``x`` is a float32 tensor of shape (2 * n_per_class, 2) with
        the class-0 rows first; ``y`` is an int64 tensor of shape
        (2 * n_per_class,) holding the labels 0/1.
    """
    n_data = torch.ones(n_per_class, 2)
    x0 = torch.normal(2 * n_data, 1)  # mean 2, std 1
    y0 = torch.zeros(n_per_class)
    x1 = torch.normal(-2 * n_data, 1)  # mean -2, std 1
    y1 = torch.ones(n_per_class)
    # dim 0 stacks samples vertically (dim 1 would join columns side by side)
    x = torch.cat((x0, x1), 0).type(torch.FloatTensor)  # 32-bit float
    # CrossEntropyLoss expects a 1-D int64 tensor of class indices
    y = torch.cat((y0, y1), 0).type(torch.LongTensor)
    return x, y


class Net(torch.nn.Module):
    """One-hidden-layer MLP: Linear -> ReLU -> Linear, returning raw logits."""

    def __init__(self, n_feature, n_hidden, n_output):
        super().__init__()
        self.hidden = torch.nn.Linear(n_feature, n_hidden)
        self.output = torch.nn.Linear(n_hidden, n_output)

    def forward(self, x):
        x = F.relu(self.hidden(x))
        x = self.output(x)
        return x


def myLoss(pre, y):
    """Hand-written mean softmax cross-entropy, for comparison with torch.

    Computes, averaged over samples,
    ``-pre[i][y[i]] + log(sum_j exp(pre[i][j]))``, which is what
    ``torch.nn.CrossEntropyLoss`` returns for logits ``pre`` and labels ``y``.

    Args:
        pre: 2-D array-like of raw scores (logits), one row per sample and
            one column per class, e.g. ``pred.detach().numpy()``.
        y: indexable sequence of integer class labels; ``y[i]`` is the label
            of row ``i`` of ``pre``.

    Returns:
        The mean cross-entropy over all rows of ``pre``.

    Raises:
        ValueError: if ``pre`` contains no rows.
    """
    n = len(pre)
    if n == 0:
        raise ValueError("myLoss needs at least one sample")
    total = 0.0
    for i, row in enumerate(pre):
        # Factor out the row maximum so math.exp never overflows on large logits.
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(p - m) for p in row))
        total += log_sum_exp - row[y[i]]
    return total / n


def accuracy(pred_y, target_y):
    """Return the fraction of positions where ``pred_y`` equals ``target_y``."""
    return float(np.mean(pred_y == target_y))


def main(n_steps=20, lr=0.1):
    """Train ``Net`` on the toy data and animate its predictions.

    Every second step the scatter plot is redrawn, coloured by predicted
    class, with the current accuracy printed on the figure.
    """
    x, y = make_data(100)
    net = Net(2, 10, 2)
    # Equivalent model without a custom class; it is used the same way:
    # net = torch.nn.Sequential(
    #     torch.nn.Linear(2, 10),
    #     torch.nn.ReLU(),
    #     torch.nn.Linear(10, 2),
    # )
    optimizer = torch.optim.SGD(net.parameters(), lr=lr)
    # Built-in loss; myLoss spells out the same computation by hand.
    loss_func = torch.nn.CrossEntropyLoss()

    plt.ion()
    for step in range(n_steps):
        pre = net(x)
        loss = loss_func(pre, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 2 == 0:
            plt.cla()
            # Softmax is monotonic, so argmax over raw logits picks the same class.
            pred_y = torch.argmax(pre, dim=1).detach().numpy()
            target_y = y.numpy()
            points = x.numpy()
            plt.scatter(points[:, 0], points[:, 1], c=pred_y, s=100, lw=0, cmap='RdYlGn')
            ac = accuracy(pred_y, target_y)
            plt.text(1.5, -4, 'Accuracy=%.2f' % ac, fontdict={'size': 20, 'color': 'red'})
            plt.pause(0.5)
    plt.ioff()
    plt.show()


if __name__ == "__main__":
    main()
转载地址:http://xwksi.baihongyu.com/