Pytorch中神经网络包中最核心的是autograd包,我们先来简单地学习它,然后训练我们第一个神经网络。

autograd包为所有在tensor上的运算提供了自动求导的支持,这是一个逐步运行的框架,也就意味着后向传播过程是按照你的代码定义的,并且单个循环可以不同

我们通过一些简单例子来了解

Tensor

torch.tensor是这个包的基础类,如果你设置.requires_grads为True,它就会开始跟踪上面的所有运算。如果你做完了运算使用.backward(),所有的梯度就会自动运算,tesor的梯度将会累加到.grad这个属性。

若要停止tensor的历史纪录,可以使用.detch()将它从历史计算中分离出来,防止未来的计算被跟踪。

为了防止追踪历史(并且使用内存),你也可以将代码块包含在with torch.no_grad():中。这对于评估模型时是很有用的,因为模型也许拥有可训练的参数使用了requires_grad=True,但是这种情况下我们不需要梯度。

还有一个类对autograd的实现非常重要,——Function

Tensor和Function是相互关联的并一起组成非循环图,它编码了所有计算的历史,每个tensor拥有一个属性.grad_fn,该属性引用已创建tensor的Function。(除了用户自己创建的tensor,它们的.grad_fn为None)。

如果你想计算导数,可以在一个Tensor上调用.backward()。如果Tensor是一个标量(也就是只包含一个元素数据),你不需要为backward指明任何参数,但是拥有多个元素的情况下,你需要指定一个匹配维度的gradient参数。

import torch

创建一个tensor并设置rquires_grad=True来追踪上面的计算

x=torch.ones(2,2,requires_grad=True)
print(x) out:
tensor([[ 1., 1.],
[ 1., 1.]])

执行一个tensor运算

y=x+2
print(y)

out:
tensor([[ 3., 3.],
[ 3., 3.]])

y是通过运算的结果建立的,所以它有grad_fn

print(y.grad_fn)

out:
<AddBackward0 object at 0x000001EDFE054D30>

在y上进行进一步的运算

z=y*y*3
out=z.mean()
print(z,out)\ out:
tensor([[ 27., 27.],
[ 27., 27.]]) tensor(27.)

.requires_grad_(...)可以用内建方式改变tensor的requires_grad标志位。如果没有给定,输入标志默认为False

a=torch.randn(2,2)
a=((a*3)/(a-1))
print(a.requires_grad)
a.requires_grad_(True)
print(a.requires_grad)
b=(a*a).sum()
print(b.grad_fn)

out:
False
True
<SumBackward0 object at 0x000001EDFE054940>

Gradients

我们开始反向传播,因为out包含单一标量,out.backward()相当于out.backward(torch.tensor(1)).

out.backward()

打印梯度d(out)/dx

print(x.grad)

out:
tensor([[ 4.5000, 4.5000],
[ 4.5000, 4.5000]])

你应该得到一个4.5的矩阵。可以简单手动计算一下这一结果。

你可以使用autograd做许多疯狂的事情

x=torch.randn(3,requires_grad=True)
y=x*2
while y.data.norm()<1000:
y=y*2
print(y)

out:
tensor([  980.8958,  1180.4403,   614.2102])
gradients=torch.tensor([0.1,1.0,0.0001],dtype=torch.float)
y.backward(gradients)
print(x.grad)

out:
tensor([  102.4000,  1024.0000,     0.1024])

你可以将语句包含在with torch.no_grad()从Tensor的历史停止自动求导

print(x.requires_grad)
print((x**2).requires_grad)
with torch.no_grad():
print((x**2).requires_grad) out:
True
True
False

   

最新文章

  1. 去掉IE下input的叉号
  2. System.Data.Entity 无法引用的问题
  3. telnet登录路由器启动服务的shell脚本
  4. Jquery 禁用 a 标签 onclick 事件30秒后可用
  5. Rediss_基本介绍
  6. sql存储过程的创建
  7. c++ stl algorithm: std::find, std::find_if
  8. Vivado完成综合_实现_生成比特流后发出提醒声音-原创☺
  9. jdbc连接数据库并打印的简单例子
  10. CIA402状态转换图
  11. sklearn交叉验证-【老鱼学sklearn】
  12. java SSM 解决跨域问题
  13. Fastreport.net 如何在开发MVC应用程序时使用报表
  14. Gitlab CR
  15. 【BZOJ3309】DZY Loves Math(莫比乌斯反演)
  16. js后退
  17. MQ的订阅模式
  18. linux禁止非法用户试探登录
  19. python+正态分布+蒙特卡洛预测男女身高概率!
  20. WCF - 服务实例管理模式

热门文章

  1. git详情、git工作流程、常用命令、忽略文件、分支操作、gitee远程仓库使用
  2. 数据库-mysql索引篇
  3. netty系列之:netty中的自动解码器ReplayingDecoder
  4. 聊聊 node 如何优雅地获取 mac 系统版本
  5. idea打开service窗口
  6. 浅尝Spring注解开发_Servlet3.0与SpringMVC
  7. 【深入理解计算机系统CSAPP】第六章 存储器层次结构
  8. 【多线程】线程优先级 Priority
  9. SpringBoot线程池
  10. 将MySQL查询结果导出到Excel