得益于反向传播算法,神经网络计算导数时非常方便,下面代码中演示如何使用LibTorch进行自动微分求导。

进行自动微分运算需要调用函数

torch::autograd::grad(
outputs, // 为某个可微函数的输出 y=f(x) 中的 y
inputs, // 为某个可微函数的输入 y=f(x) 中的 x
grad_outputs,// 雅克比矩阵(此处计算 f'(x),故设置为1,且与x形状相同 )
retain_graph,// 默认值与 create_graph 相同,这里设置为 true即可
create_graph,// 需要设置为 true 以计算高阶导数
allow_unused // 设置为 false 即可
)

在本文示例中,我们计算 \(y=x^2+x\) 在 \(x = 0.1, 0.3, 0.5\) 处的函数值、一阶导数和二阶导数值,根据我们学到的数学知识,很容易计算出下列数据

\(x\) 0.1 0.3 0.5
\(y\) 0.11 0.39 0.75
\(y'\) 1.20 1.60 2.00
\(y''\) 2.00 2.00 2.00

而在LibTorch中调用自动微分计算导数的代码如下所示

#include <iostream>
#include <torch/torch.h> int main(int argc, char* atgv[])
{
std::cout.setf(std::ios::scientific);
std::cout.precision(7); std::vector<float> vec{0.1, 0.3, 0.5};
torch::Tensor x = torch::from_blob(vec.data(), {3}, torch::kFloat).requires_grad_(true);
torch::Tensor y = x * x + x; // y= x^2 + x
auto weight = torch::ones_like(x); std::cout << "x = ";
for (int i = 0; i < 3; ++i)
std::cout << x[i].item<float>() << " ";
std::cout << std::endl; std::cout << "y = "; // 0.11 0.39 0.75
for (int i = 0; i < 3; ++i)
std::cout << y[i].item<float>() << " ";
std::cout << std::endl; // 计算输出一阶导数(y' = 2x + 1)
auto dydx = torch::autograd::grad({y}, {x}, {weight}, true, true, false);
std::cout << "dydx = "; // 1.2 1.6 2.0
for (int i = 0; i < 3; ++i)
std::cout << dydx[0][i].item<float>() << " ";
std::cout << std::endl; // 计算输出二阶导数(y''= 2)
auto d2ydx2 = torch::autograd::grad({dydx[0]}, {x}, {weight});
std::cout << "d2ydx2 = "; // 2.0 2.0 2.0
for (int i = 0; i < 3; ++i)
std::cout << d2ydx2[0][i].item<float>() << " ";
std::cout << std::endl; return 0;
}

计算结果如下图所示,与我们手动计算的结果一致。

最新文章

  1. PHP之session与cookie
  2. 图文相关性 flickr数据实验结论_1
  3. 查看Linux内核版本命令
  4. mybatis中的mapxml的语法
  5. 自定义CSS博客(转)
  6. 6.Inout双向端口信号处理方法
  7. c# 计算1-100之间的所有质数(素数)的和
  8. C++多态性与C#的比较
  9. AutoCompleteTextView 与sqlite绑定实现记住用户输入的内容并自动提示
  10. 提示框插件SweetAlert
  11. Socket编程之聊天程序 - 模拟Fins/ModBus协议通信过程
  12. ASP.Net MVC C#画图 页面调用
  13. JFreeChart绘制折线图实例
  14. c/c++ linux 进程间通信系列5,使用信号量
  15. JavaScript 函数调用和this指针
  16. js中的“==”和“===”的区别
  17. 【C++】C++中const与constexpr的比较
  18. python找递归目录中文件,并移动到一个单独文件夹中,同时记录原始文件路径信息
  19. pytorch学习资料链接
  20. [C#]使用Windows Form开发的百度网盘搜索工具

热门文章

  1. Linux shell脚本算术运算和逻辑运算
  2. word processing in nlp with tensorflow
  3. 一文读懂数仓中的pg_stat
  4. Overfitting &amp; Train Set &amp; Test Set
  5. Java 中的对象池实现
  6. ApiDay001 __02 Java_StringBuilder
  7. Solution -「Hdu3037」Saving Beans
  8. Java的学习日常
  9. springmvc源码笔记-HandlerMethodReturnValueHandler
  10. 枚举子集为什么是 O(3^n) 的