pytorch中的nn.CrossEntropyLoss()
2024-09-01 22:43:53
nn.CrossEntropyLoss()这个损失函数和我们普通说的交叉熵还是有些区别。
$x$是模型生成的结果,$class$是数据对应的label
$loss(x,class)=-log(\frac{exp(x[class])}{\sum_j exp(x[j])})=-x[class]+log(\sum_j exp(x[j]))$
nn.CrossEntropyLoss()的使用方式参见如下代码
import torch
import torch.nn as nn # 表示模型的输出output(B,C)格式,B是batch,C是类别
output = torch.randn(2, 3, requires_grad = True) #batch_size设置为2,3分类
# 表示数据的标签label(B)格式,B是batch,其中的数值是位于[0,C-1]
label = torch.empty(2, dtype=torch.long).random_(3) # 0 - 2, 任意选取一个分类
print(output)
'''
tensor([[-1.1313, 0.5944, -1.5735],
[ 1.2037, -1.0548, -0.9253]], requires_grad=True)
'''
print(label)#tensor([0, 2])
loss = nn.CrossEntropyLoss()
#先对每个训练样本求损失,而后再求平均损失
print ('loss :', loss(output, label))#loss : tensor(2.1565, grad_fn=<NllLossBackward>)
最新文章
- 使用 Git Hooks 实现自动项目部署
- Host基本概念
- 基于 REST 的 Web 服务:基础
- asp.net中使用ueditor
- Linux下MySQL的备份与还原
- 第二百七十四、五、六天 how can I 坚持
- iOS开发之自定义控制器切换
- Sae 上传文件到Storage
- C#生成高清缩略图
- Swift3GCD
- 用dedecms做网站时,空间服务器选择IIS还是apache???
- php 快排
- Android使用统计图AChartEngine 来展示数据
- C++ cout格式化输出
- BZOJ2738 矩阵乘法(整体二分+树状数组)
- Liunx cp
- IP地址和域
- 关闭ubuntu dash 方法
- 对cnblogs.com的用户体验
- prisma graphql 工具基本使用