在net.py里面构造网络,网络的结构为输入为28*28,第一层隐藏层的输出为300, 第二层输出的输出为100, 最后一层的输出层为10,

net.py

import torch
from torch import nn class Batch_Net(nn.Module):
def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
super(Batch_Net, self).__init__()
self.layer_1 = nn.Sequential(nn.Linear(in_dim, n_hidden_1), nn.BatchNorm1d(n_hidden_1), nn.ReLU(True))
self.layer_2 = nn.Sequential(nn.Linear(n_hidden_1, n_hidden_2), nn.BatchNorm1d(n_hidden_2), nn.ReLU(True))
self.output = nn.Sequential(nn.Linear(n_hidden_2, out_dim)) def forward(self, x):
x = self.layer_1(x)
x = self.layer_2(x)
x = self.output(x)
return x

main.py 进行网络的训练

import torch
from torch import nn, optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, transforms import net batch_size = 128 # 每一个batch_size的大小
learning_rate = 1e-2 # 学习率的大小
num_epoches = 20 # 迭代的epoch值
# 表示data将数据变成0, 1之间,0.5, 0.5表示减去均值处以标准差
data_tf = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]) # 表示均值和标准差
# 获得训练集的数据
train_dataset = datasets.MNIST(root='./data', train=True, transform=data_tf, download=True)
# 获得测试集的数据
test_dataset = datasets.MNIST(root='./data', train=False, transform=data_tf, download=True)
# 获得训练集的可迭代队列
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# 获得测试集的可迭代队列
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
# 构造模型的网络
model = net.Batch_Net(28*28, 300, 100, 10)
if torch.cuda.is_available(): # 如果有cuda就将模型放在GPU上
model.cuda() criterion = nn.CrossEntropyLoss() # 构造交叉损失函数
optimizer = optim.SGD(model.parameters(), lr=learning_rate) # 构造模型的优化器 for epoch in range(num_epoches): # 迭代的epoch
train_loss = 0 # 训练的损失值
test_loss = 0 # 测试的损失值
eval_acc = 0 # 测试集的准确率
for data in train_loader: # 获得一个batch的样本
img, label = data # 获得图片和标签
img = img.view(img.size(0), -1) # 将图片进行img的转换
if torch.cuda.is_available(): # 如果存在torch
img = Variable(img).cuda() # 将图片放在torch上
label = Variable(label).cuda() # 将标签放在torch上
else:
img = Variable(img) # 构造img的变量
label = Variable(label)
optimizer.zero_grad() # 消除optimizer的梯度
out = model.forward(img) # 进行前向传播
loss = criterion(out, label) # 计算损失值
loss.backward() # 进行损失值的后向传播
optimizer.step() # 进行优化器的优化
train_loss += loss.data #
for data in test_loader:
img, label = data
img = img.view(img.size(0), -1)
if torch.cuda.is_available():
img = Variable(img, volatile=True).cuda()
label = Variable(label, volatile=True).cuda()
else:
img = Variable(img, volatile=True)
label = Variable(label, volatile=True)
out = model.forward(img)
loss = criterion(out, label)
test_loss += loss.data
top_p, top_class = out.topk(1, dim=1) # 获得输出的每一个样本的最大损失
equals = top_class == label.view(*top_class.shape) # 判断两组样本的标签是否相等
accuracy = torch.mean(equals.type(torch.FloatTensor)) # 计算准确率
eval_acc += accuracy
print('train_loss{:.6f}, test_loss{:.6f}, Acc:{:.6f}'.format(train_loss / len(train_loader), test_loss / len(test_loader), eval_acc / len(test_loader)))

最新文章

  1. php获取网卡MAC地址源码
  2. 设置 TabBarItem 选中时的图片及文字颜色
  3. JSON和JSONP的区别
  4. Java利用MessageDigest提供的MD5算法加密字符串或文件
  5. Android开发-API指南-应用程序开发基础
  6. Mybatis 中常用的java类型与jdbc类型
  7. 在DataTable中更新、删除数据
  8. Swagger 生成 ASP.NET Web API
  9. LabView 下载与安装
  10. 多平台Client TCP通讯组件
  11. storm.yaml 配置项
  12. ICE第四篇-----python版本
  13. 【推荐】地推统计结算工具SDK,手机开发首选
  14. Java数据库学习之模糊查询(like )
  15. elasticsearch(6) 映射和分析
  16. 20165213 Exp4 恶意代码分析
  17. SpringBoot 2.0 报错: Failed to configure a DataSource: 'url' attribute is not specified and no embe
  18. 如何用MarsEdit快速插入源代码
  19. ES标准
  20. Asp.Net MVC学习总结之过滤器详解(转载)

热门文章

  1. PHP 多维数组将下标从0开始
  2. oracle的listagg函数
  3. Hexo NexT主题内加入动态背景
  4. rabbimq 生产消费者
  5. Collection 和 Collections的区别
  6. 【Struts2】拦截器
  7. Mysql(七):视图、触发器、事务、存储过程、函数
  8. ffmpeg 命令行 杂记
  9. 团队第二次作业:需求分析&系统设计
  10. PAT乙级1038