如何使用 libtorch 实现 VGG16 网络?
2024-09-04 14:16:31
参考地址:https://ethereon.github.io/netscope/#/preset/vgg-16
按照上面的图来写即可。
论文地址:https://arxiv.org/pdf/1409.1556.pdf
// Define a new Module.
struct Net : torch::nn::Module {
Net() {
conv1_1 = torch::nn::Conv2d(torch::nn::Conv2dOptions(3, 64, { 3,3 }).padding(1));
conv1_2 = torch::nn::Conv2d(torch::nn::Conv2dOptions(64, 64, { 3,3 }).padding(1));
conv2_1 = torch::nn::Conv2d(torch::nn::Conv2dOptions(64, 128, { 3,3 }).padding(1));
conv2_2 = torch::nn::Conv2d(torch::nn::Conv2dOptions(128, 128, { 3,3 }).padding(1));
conv3_1 = torch::nn::Conv2d(torch::nn::Conv2dOptions(128, 256, { 3,3 }).padding(1));
conv3_2 = torch::nn::Conv2d(torch::nn::Conv2dOptions(256, 256, { 3,3 }).padding(1));
conv3_3 = torch::nn::Conv2d(torch::nn::Conv2dOptions(256, 256, { 3,3 }).padding(1));
conv4_1 = torch::nn::Conv2d(torch::nn::Conv2dOptions(256, 512, { 3,3 }).padding(1));
conv4_2 = torch::nn::Conv2d(torch::nn::Conv2dOptions(512, 512, { 3,3 }).padding(1));
conv4_3 = torch::nn::Conv2d(torch::nn::Conv2dOptions(512, 512, { 3,3 }).padding(1));
conv5_1 = torch::nn::Conv2d(torch::nn::Conv2dOptions(512, 512, { 3,3 }).padding(1));
conv5_2 = torch::nn::Conv2d(torch::nn::Conv2dOptions(512, 512, { 3,3 }).padding(1));
conv5_3 = torch::nn::Conv2d(torch::nn::Conv2dOptions(512, 512, { 3,3 }).padding(1));
fc1 = torch::nn::Linear(512*7*7,4096);
fc2 = torch::nn::Linear(4096, 4096);
fc3 = torch::nn::Linear(4096, 1000);
}
// Implement the Net's algorithm.
torch::Tensor forward(torch::Tensor x) {
x = conv1_1->forward(x);
x = torch::relu(x);
x = conv1_2->forward(x);
x = torch::relu(x);
x = torch::max_pool2d(x, { 2,2 }, { 2,2 });
x = conv2_1->forward(x);
x = torch::relu(x);
x = conv2_2->forward(x);
x = torch::relu(x);
x = torch::max_pool2d(x, { 2,2 }, { 2,2 });
x = conv3_1->forward(x);
x = torch::relu(x);
x = conv3_2->forward(x);
x = torch::relu(x);
x = conv3_3->forward(x);
x = torch::relu(x);
x = torch::max_pool2d(x, { 2,2 }, { 2,2 });
x = conv4_1->forward(x);
x = torch::relu(x);
x = conv4_2->forward(x);
x = torch::relu(x);
x = conv4_3->forward(x);
x = torch::relu(x);
x = torch::max_pool2d(x, { 2,2 }, { 2,2 });
x = conv5_1->forward(x);
x = torch::relu(x);
x = conv5_2->forward(x);
x = torch::relu(x);
x = conv5_3->forward(x);
x = torch::relu(x);
x = torch::max_pool2d(x, { 2,2 }, { 2,2 });
x = x.view({ x.size(0), -1 });//512x7x7 = 25088
x = fc1->forward(x);
x = torch::relu(x);
x = torch::dropout(x, 0.5, is_training());
x = fc2->forward(x);
x = torch::relu(x);
x = torch::dropout(x, 0.5, is_training());
x = fc3->forward(x);
x = torch::log_softmax(x, 1);
return x;
}
// Use one of many "standard library" modules.
torch::nn::Conv2d conv1_1{ nullptr };
torch::nn::Conv2d conv1_2{ nullptr };
torch::nn::Conv2d conv2_1{ nullptr };
torch::nn::Conv2d conv2_2{ nullptr };
torch::nn::Conv2d conv3_1{ nullptr };
torch::nn::Conv2d conv3_2{ nullptr };
torch::nn::Conv2d conv3_3{ nullptr };
torch::nn::Conv2d conv4_1{ nullptr };
torch::nn::Conv2d conv4_2{ nullptr };
torch::nn::Conv2d conv4_3{ nullptr };
torch::nn::Conv2d conv5_1{ nullptr };
torch::nn::Conv2d conv5_2{ nullptr };
torch::nn::Conv2d conv5_3{ nullptr };
torch::nn::Linear fc1{ nullptr };
torch::nn::Linear fc2{ nullptr };
torch::nn::Linear fc3{ nullptr };
};
最新文章
- H5案例分享:使用JS判断客户端、浏览器、操作系统类型
- Eclipse中全局搜索和更替
- [codeforces 339]C. Xenia and Weights
- kali 密码攻击
- JVM-运行时数据区
- WPF-控件-将ListBox条目水平排列
- sublime text高亮less
- react环境搭建
- TiDB:支持 MySQL 协议的分布式数据库解决方案
- 利用WebApi获取手机号码归属地
- 第二章实例:动态生成View控件例子---小球跟随手指滑动
- linux下Ftp环境的搭建
- SQL生成一年每一天的时间列表的几种方法
- Oracle经常用到的一些函数
- 基于open62541的opc ua 服务器开发实现(1)
- php中使用sphinx搜索引擎
- Windows10系统网络连接问题
- Ubuntu16.04安装YouCompleteMe
- beego+vue.js分离开发,结合发布,简单部署
- Beta阶段团队项目开发篇章2
热门文章
- [转]JavaScript放在<;head>;和<;body>;的区别
- java读properties文件 乱码
- 基于jQuery仿Flash横向切换焦点图
- Linux samba 服务的配置
- JavaScript 框架 jQuery 的下载和安装
- sama5d3 环境检测 adc测试
- man page用法
- java定时调度器解决方案分类及特性介绍
- 安全 流程服务器开新机器 内外网 iptables 安全组 用户安全root用户的使用.
- 集中精力的重要性(The Importance of Focus)