原因是因为checkpoint设置好的确是保存了相关字段。但是其中设置的train_dataset却已经走过了epoch轮,当你再继续训练时候,train_dataset是从第一个load_data开始。

# -*- coding:utf-8 -*-
import os
import numpy as np
import torch
import cv2
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.tensorboard import SummaryWriter
import torch.optim as optim
from matplotlib import pyplot as plt
import os
from PIL import Image
os.environ ['KMP_DUPLICATE_LIB_OK'] ='True'
import sys
hello_pytorch_DIR = os.path.abspath(os.path.dirname(__file__)+os.path.sep+".."+os.path.sep+"..")
sys.path.append(hello_pytorch_DIR)
fmap_block = list()
import torch.nn.functional as F
grad_block = list()
from model.lenet import LeNet
from tools.my_dataset import RMBDataset
BASE_DIR = os.path.dirname(os.path.abspath(__file__)) torch.manual_seed(1) # 设置随机种子
rmb_label = {"1": 0, "100": 1}
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool1 = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.pool2 = nn.MaxPool2d(2, 2)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 2) def forward(self, x):
x = self.pool1(F.relu(self.conv1(x)))
x = self.pool1(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
# 参数设置
MAX_EPOCH = 10
BATCH_SIZE = 16
LR = 0.01
log_interval = 10
val_interval = 1
checkpoint_interval=5 # ============================ step 1/5 数据 ============================ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
split_dir = os.path.abspath(os.path.join(BASE_DIR, "rmb_split"))
if not os.path.exists(split_dir):
raise Exception(r"数据 {} 不存在, 回到lesson-06\1_split_dataset.py生成数据".format(split_dir))
train_dir = os.path.join(split_dir, "train") train_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor()
])
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
net = Net()
criterion = nn.CrossEntropyLoss() # 选择损失函数
# ============================ step 4/5 优化器 ============================
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9) # 选择优化器 checkpointdict = torch.load('./checkpoint4.pkl')
net.load_state_dict(checkpointdict["model_state_dict"])
optimizer.load_state_dict(checkpointdict["optimizer_state_dict"])
startepoch = checkpointdict["epoch"]
# ============================ step 5/5 训练 ============================
train_curve = list()
iter_count = 0 for epoch in range(startepoch+1,MAX_EPOCH):
loss_mean = 0.
correct = 0.
total = 0.
for counti in range(6):
for i, data in enumerate(train_loader):
if counti <5:
continue
else:
iter_count += 1
# forward
inputs, labels = data
outputs = net(inputs)
# backward
optimizer.zero_grad()
loss = criterion(outputs, labels)
loss.backward()
# update weights
optimizer.step()
# 统计分类情况
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).squeeze().sum().numpy()
# 打印训练信息
loss_mean += loss.item()
train_curve.append(loss.item())
if (i+1) % log_interval == 0:
loss_mean = loss_mean / log_interval
print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))
loss_mean = 0.
# if ((epoch + 1) % checkpoint_interval == 0):
# checkpoint = {"model_state_dict": net.state_dict(),
# "optimizer_state_dict": optimizer.state_dict(),
# "epoch": epoch}
# path_checkpoint = './checkpoint{}.pkl'.format(epoch)
# torch.save(checkpoint, path_checkpoint)
# if ((epoch + 1) % 5 == 0):
# print("退出")
# break

最新文章

  1. 在线课程笔记&mdash;.NET基础
  2. springmvc的类型转换
  3. ThinkPHP跨控制器调用方法
  4. js中接口的声明与实现
  5. MySQL-使用tcpdump排查MySQLl数据库tps飙升的问题
  6. 《算法导论》习题解答 Chapter 22.1-7(关联矩阵的性质)
  7. 数据库对象(视图,序列,索引,同义词)【weber出品必属精品】
  8. DirectUI实现原理
  9. Boolean对象 识记
  10. OpenGL ES着色器语言之静态使用(static use)和预处理
  11. python学习日记(包——package)
  12. jQuery初识之选择器、样式操作和筛选器(模态框和菜单示例)
  13. 使用InternalsVisibleTo给assembly添加“友元assembly”
  14. 部署在sae上的servlet程序出现is not a javax.servlet.Servlet 错误
  15. eclipse如何将项目上传到码云
  16. Linux常用软件启动、停止、重启命令
  17. springcloud的turbine集成zookeeper
  18. ABP框架学习
  19. [luogu2114][起床困难综合症]
  20. Arrays常用方法

热门文章

  1. fzu2198 快来快来数一数
  2. 营业额统计 HYSBZ - 1588
  3. HDU 3416 Marriage Match IV (最短路径&amp;&amp;最大流)
  4. ElasticSearch入门到筋痛
  5. 牛客网多校第4场 J Hash Function 【思维+并查集建边】
  6. JVM升华篇
  7. 7816协议时序和采用UART模拟7816时序与智能卡APDU指令协议
  8. jira 优先级过滤
  9. 1GB === 1000MB &amp; 1GB === 1024MB
  10. Immutable.js 实现原理