import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

# Hyperparameter grid for the manual sweep performed at the bottom of the file.
batch_size_list = [100, 1000, 10000]
lr_list = [.01, .001, .0001, .00001]
shuffle = [True, False]  # NOTE(review): defined but never used by the sweep below
def get_num_correct(preds, labels):
return preds.argmax(dim=1).eq(labels).sum().item() train_set = torchvision.datasets.FashionMNIST(
root='./data/FashionMNIST',
train=True,
download=True,
transform=transforms.Compose([transforms.ToTensor()])
) # data_loader = torch.utils.data.DataLoader(train_set,batch_size=batch_size,shuffle=True) # shuffle=True class Network(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5) self.fc1 = nn.Linear(in_features=12 * 4 * 4, out_features=120)
self.fc2 = nn.Linear(in_features=120, out_features=60)
self.out = nn.Linear(in_features=60, out_features=10) def forward(self, t):
# (1) input layer
t = t # (2) hidden conv layer
t = self.conv1(t)
t = F.relu(t)
t = F.max_pool2d(t, kernel_size=2, stride=2) # (3) hidden conv layer
t = self.conv2(t)
t = F.relu(t)
t = F.max_pool2d(t, kernel_size=2, stride=2) # (4) hidden Linear layer
t = t.reshape(-1, 12 * 4 * 4) # -1表示对行没约束,反正是12*4*4列
t = self.fc1(t)
t = F.relu(t)
# (5) hidden Linear layer
t = self.fc2(t)
t = F.relu(t)
# (6) output layer
t = self.out(t)
# t=F.softmax(t,dim=1) #此处不使用softmax函数,因为在训练中我们使用了交叉熵损失函数,而在torch.nn函数类中,已经在其输入中隐式的
# 执行了一个softmax操作,这里我们只返回最后一个线性变换的结果,也即是 return t,也即意味着我们的网络将使用softmax操作进行训练,但在
# 训练完成后,将不需要额外的计算操纵。 return t network = Network() for batch_size in batch_size_list:
for lr in lr_list:
network = Network() data_loader = torch.utils.data.DataLoader(
train_set, batch_size=batch_size
)
optimizer = optim.Adam(
network.parameters(), lr=lr
) images, labels = next(iter(data_loader))
grid = torchvision.utils.make_grid(images) comment=f' batch_size={batch_size} lr={lr}'
tb = SummaryWriter(comment=comment)
tb.add_image('images', grid)
tb.add_graph(network, images) for epoch in range(5):
total_loss = 0
total_correct = 0
for batch in data_loader:
images, labels = batch # Get Batch
preds = network(images) # Pass Batch
loss = F.cross_entropy(preds, labels) # Calculate Loss
optimizer.zero_grad() # Zero Gradients
loss.backward() # Calculate Gradients
optimizer.step() # Update Weights total_loss += loss.item() * batch_size#这里上述用的是mini-batch训练方法,一个batch得loss会被平均,所以乘以size得到总和
total_correct += get_num_correct(preds, labels) tb.add_scalar(
'Loss', total_loss, epoch
)
tb.add_scalar(
'Number Correct', total_correct, epoch
)
tb.add_scalar(
'Accuracy', total_correct / len(train_set), epoch
) for name, param in network.named_parameters():
tb.add_histogram(name, param, epoch)
tb.add_histogram(f'{name}.grad', param.grad, epoch) print(
"epoch", epoch
,"total_correct:", total_correct
,"loss:", total_loss
)
tb.close() f'''
from itertools import product parameters=dict(lr=[.01,.001],batch_size=[10,100,1000],shuffle=[True,False])
# for i,j in parameters.items():
# print(i,j,sep='\t')
para_values=[value for value in parameters.values()]
for lr,batch_size,shuffle in product(*para_values):#这里的星号告诉乘积函数把列表中的每个值作为参数,而不是把列表本身作为参数来对待
comment=f' batch_size={batch_size} lr={lr} shuffle={shuffle}'
print(lr,batch_size,shuffle)
'''

  

最新文章

  1. Quartz —— Spring 环境下的使用
  2. java_method_下拉框成json
  3. nfs基本配置
  4. perl备忘
  5. Android的消息处理机制,handler,message,looper(一)
  6. C# 构造函数的使用方法
  7. bzoj1084: [SCOI2005]最大子矩阵
  8. cf443A Anton and Letters
  9. dojo的TabContainer添加ContentPane假设closable,怎么不闭幕后予以销毁ContentPane
  10. 第八讲:I/O虚拟化
  11. winform 实现类似于TrackBar的自定义滑动条,功能更全
  12. python递归
  13. scrapy使用指南
  14. form提交xml文件
  15. XML一
  16. ODPS SQL <for 数据操作语言DML>
  17. Eclipse中一些真正常用的快捷键
  18. Shiro 基础教程
  19. linux bash的重定向
  20. Spring Boot开发之流水无情(二)

热门文章

  1. Mybatis 实现批量插入和批量删除源码实例
  2. px,rem,em 通过媒体查询统一的代码
  3. redis 指定db库导入导出数据
  4. 制作Unity中的单位血条
  5. 不仅仅是一把瑞士军刀 —— Apifox的野望和不足
  6. 1s 创建100G文件,最快的方法是?
  7. 我用 CSS3 实现了一个超炫的 3D 加载动画
  8. 在django中使用orm来操作MySQL数据库的建表,增删改
  9. Linux获取本机公网IP,调整双节点主从服务的RPC调用逻辑
  10. 【在下版本,有何贵干?】Dockerfile中 RUN yum -y install vim失败Cannot prepare internal mirrorlist: No URLs in mirrorlist