Pytorch之数据处理
2024-10-21 07:27:26
使用TensorDataset和DataLoader来简化
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
train_ds = TensorDataset(x_train, y_train)
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)
valid_ds = TensorDataset(x_valid, y_valid)
valid_dl = DataLoader(valid_ds, batch_size=bs * 2)
def get_data(train_ds, valid_ds, bs):
return (
DataLoader(train_ds, batch_size=bs, shuffle=True),
DataLoader(valid_ds, batch_size=bs * 2),
)
- 一般在训练模型时加上model.train(),这样会正常使用Batch Normalization和 Dropout
- 测试的时候一般选择model.eval(),这样就不会使用Batch Normalization和 Dropout
import numpy as np
def fit(steps, model, loss_func, opt, train_dl, valid_dl):
for step in range(steps):
model.train()
for xb, yb in train_dl:
loss_batch(model, loss_func, xb, yb, opt)
model.eval()
with torch.no_grad():
losses, nums = zip(
*[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl]
)
val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)
print('当前step:'+str(step), '验证集损失:'+str(val_loss))
from torch import optim
def get_model():
model = Mnist_NN()
return model, optim.SGD(model.parameters(), lr=0.001)
def loss_batch(model, loss_func, xb, yb, opt=None):
loss = loss_func(model(xb), yb)
if opt is not None:
loss.backward()
opt.step()
opt.zero_grad()
return loss.item(), len(xb)
三行搞定!
train_dl, valid_dl = get_data(train_ds, valid_ds, bs)
model, opt = get_model()
fit(25, model, loss_func, opt, train_dl, valid_dl)
最新文章
- 准备NOIP2017 编辑距离问题 模板
- 提高AdoQuery的速度
- 随机生成字符串-php-js
- Objective-C与C style语言的简单类比
- linux命令-shopt
- 向架构师进军---&;gt;怎样编写软件架构文档
- Bloom Filter概念和原理
- BZOJ3916: [Baltic2014]friends
- GParted: GNOME Partition Editor, sharp weapon to modify disk partitions.
- 通过css3实现的动画导航菜单代码
- 上下文管理器——with语句的实现
- PAT A1140 Look-and-say Sequence (20 分)——数学题
- ZOJ 3690 Choosing number(矩阵)
- Package CJK Error: Invalid character code. 问题解决方法--xelatex和pdflatex编译的转换
- Ubuntu 16.04服务器 配置
- Leetcode 784
- shell 脚本 实现随机数
- 利用flume+kafka+storm+mysql构建大数据实时系统
- [EMWIN]关于 GUI_GetPixelIndex 使用的问题
- 利用ffmpeg一步一步编程实现摄像头采集编码推流直播系统
热门文章
- 在GCP的Kubernetes上安装dapr
- 【学习笔记】XR872 Audio 驱动框架分析
- (原创)【B4A】一步一步入门02:可视化界面设计器、控件的使用
- k210 cpu、asm、rust、smpboot、ipi
- spring cloud alibaba sentinel 运行及简单使用
- [代码审计基础 02]-SQL注入和预编译和预编译绕过
- JZOJ 2933. 【NOIP2012模拟8.7】找位置
- 基于电商直播SDK快速实现一个淘宝直播APP【内附源码】
- 容忍和污点Taint和Toleration
- Word19 撰写企业质量管理论文office真题