博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
pytorch(二)
阅读量:4099 次
发布时间:2019-05-25

本文共 2222 字,大约阅读时间需要 7 分钟。

学习莫烦pytorch视频,部分代码进行注释

#classification.pyimport numpy as npimport torchimport torch.nn.functional as Ffrom torch.autograd import Variableimport matplotlib.pyplot as pltimport mathimport pdbn_data = torch.ones(100,2)#pdb.set_trace()x0 = torch.normal(2*n_data, 1)#生成均值为2,方差为1的tensor#print(x0)y0 = torch.zeros(100)x1 = torch.normal(-2*n_data, 1)y1 = torch.ones(100)x=torch.cat((x0,x1),0).type(torch.FloatTensor)#32bit floating#0:竖着拼 1:横着拼y=torch.cat((y0,y1),0).type(torch.LongTensor)#64bit integer#只有一维#print(y0)x, y =Variable(x), Variable(y)#print(y)#plt.scatter(x.data.numpy()[:,0], x.data.numpy()[:,1], c=y.data.numpy(), s=100, lw=0,cmap='RdYlGn')#c:颜色,s:控制点大小,lw:加不加一样,,cmap控制颜色的数组#plt.show()class Net(torch.nn.Module):    def __init__(self, n_feature, n_hidden, n_output):        super(Net, self).__init__()        self.hidden=torch.nn.Linear(n_feature,n_hidden)        self.output=torch.nn.Linear(n_hidden,n_output)    def forward(self, x):        x=F.relu(self.hidden(x))        x=self.output(x)        return xdef myLoss(pre, y):    s=0    #pdb.set_trace()    for i in range(200):        p1=pre[i][0]        p2=pre[i][1]        #p11=math.exp(p1)/(math.exp(p1)+math.exp(p2))        #p22=1-p11        tmp = -pre[i][y[i]] +math.log(math.exp(p1)+math.exp(p2))         s=s+tmp    return s/200net = Net(2,10,2)'''net = torch.nn.Sequential(    torch.nn.Linear(2,10),    torch.nn.ReLU(),    torch.nn.Linear(10,2),    )    不用class,直接定义,用法一样'''optimizer = torch.optim.SGD(net.parameters(), lr=0.1)loss_func=torch.nn.CrossEntropyLoss()plt.ion()for i in range(20):    pre = net(x)    loss = loss_func(pre, y)#这里的loss不是直接用的,具体实现参见myLoss    #myloss = myLoss(pre.data.numpy(),y.data.numpy())    #print(pre)    #print(loss)    #print(myloss)    #break    optimizer.zero_grad()    loss.backward()    optimizer.step()    if i%2==0:        plt.cla()        _,prediction=torch.max(F.softmax(pre),1)        pred_y=prediction.data.numpy().squeeze()        target_y=y.data.numpy()        plt.scatter(x.data.numpy()[:,0], x.data.numpy()[:,1], c=pred_y,s=100,lw=0,cmap='RdYlGn')        ac=sum(pred_y==target_y)/200        plt.text(1.5, -4, 'Accuracy=%.2f'%ac, fontdict={
'size':20, 'color':'red'}) plt.pause(0.5)plt.ioff()plt.show()

转载地址:http://xwksi.baihongyu.com/

你可能感兴趣的文章
听说玩这些游戏能提升编程能力?
查看>>
7 年工作经验,面试官竟然还让我写算法题???
查看>>
被 Zoom 逼疯的歪果仁,造出了视频会议机器人,同事已笑疯丨开源
查看>>
上古语言从入门到精通:COBOL 教程登上 GitHub 热榜
查看>>
再见,Eclipse...
查看>>
如果你还不了解 RTC,那我强烈建议你看看这个!
查看>>
沙雕程序员在无聊的时候,都搞出了哪些好玩的小玩意...
查看>>
程序员用 AI 修复百年前的老北京视频后,火了!
查看>>
漫话:为什么你下载小电影的时候进度总是卡在 99% 就不动了?
查看>>
我去!原来大神都是这样玩转「多线程与高并发」的...
查看>>
当你无聊时,可以玩玩 GitHub 上这个开源项目...
查看>>
B 站爆红的数学视频,竟是用这个 Python 开源项目做的!
查看>>
安利 10 个让你爽到爆的 IDEA 必备插件!
查看>>
自学编程的八大误区!克服它!
查看>>
GitHub 上的一个开源项目,可快速生成一款属于自己的手写字体!
查看>>
早知道这些免费 API,我就可以不用到处爬数据了!
查看>>
Java各种集合类的合并(数组、List、Set、Map)
查看>>
JS中各种数组遍历方式的性能对比
查看>>
Mysql复制表以及复制数据库
查看>>
进程管理(一)
查看>>