pytorch中的词向量的使用

在pytorch我们使用nn.embedding进行词嵌入的工作。

具体用法就是:

import torch
word_to_ix={'hello':0,'world':1}
embeds = torch.nn.Embedding(2,5)
hello_idx=torch.LongTensor([word_to_ix['hello']])
hello_embed = embeds(hello_idx)
print(hello_embed)
print(embeds.weight) tensor([[ 0.6584, 0.2991, -1.2654, 0.9369, 0.6088]], grad_fn=<EmbeddingBackward>) Parameter containing:
tensor([[ 0.6584, 0.2991, -1.2654, 0.9369, 0.6088],
[ 0.1922, 1.5374, 0.5737, -0.8007, -0.4896]], requires_grad=True)

在torch.nn.Embedding的源代码中,它是这么解释,

This module is often used to store word embeddings and retrieve them using indices.

The input to the module is a list of indices, and the output is the corresponding

word embeddings.

对于这个,我的理解是这样的torch.nn.Embedding 是一个矩阵类,当我传入参数之后,我可以得到一个矩阵对象,比如上面代码中的

embeds = torch.nn.Embedding(2,5) 通过这个代码,我就获得了一个两行三列的矩阵对象embeds。这个时候,矩阵对象embeds的输入就是一个索引列表(当然这个列表

应该是longtensor格式,得到的结果就是对应索引的词向量)

我们这里有一点需要格外注意,在上面的结果中,有个这个东西 requires_grad=True

我在开始接触pytorch的时候,对embedding的一个疑惑就是它是如何定义自动更新的。因为现在我们得到的这个词向量是随机初始化的结果,

在后续神经网络反向传递过程中,这个参数是需要更新的。

这里我想要点出一点来,就是词向量在这里是使用标准正态分布进行的初始化。我们可以通过查看源代码来进行验证。

在源代码中

if _weight is None:
self.weight = Parameter(torch.Tensor(num_embeddings, embedding_dim)) ##定义一个Parameter对象
self.reset_parameters() #随后对这个对象进行初始化
...
... def reset_parameters(self): #标准正态进行初始化
init.normal_(self.weight)
if self.padding_idx is not None:
with torch.no_grad():
self.weight[self.padding_idx].fill_(0)

最新文章

  1. nginx 添加nginx-http-concat模块
  2. xcode 消除警告
  3. JS回到顶部代码小记
  4. GIT用法总结
  5. WPF:行列显示
  6. js ajax post提交 ie和火狐、谷歌提交的编码不一致,导致中文乱码
  7. iOS修改声明为readonly的属性值
  8. ecshop在nginx下实现负载均衡
  9. js 联系电话验证实现
  10. cocos2d-x 3.0来做一个简单的游戏教程 win32平台 vs2012 详解献给刚開始学习的人们!
  11. jsp web JavaBean MVC 架构 EL表达式 EL函数 JSTL
  12. 类似fabric主机管理demo
  13. 用C# (.NET Core) 实现迭代器设计模式
  14. Spring常用配置示例
  15. Vue 2.3、2.4 知识点小结
  16. arcgis 获得工具箱工具的个数
  17. MVC校验方式【六】
  18. C#语言————格式化数值结果表
  19. python笔记之强制函数以关键字参数传参
  20. JVM系列三:JVM参数设置

热门文章

  1. 树莓派直连线连接PC
  2. intellijidea课程 intellijidea神器使用技巧1-4 idea安装
  3. agc001E - BBQ Hard(dp 组合数)
  4. 由Asp.Net客户端控件生成的服务器端控件
  5. java compiler没有1.8怎么办
  6. aaS软件的必要特征分析,一定是多租户特性吗
  7. vim复制粘贴到系统剪贴板
  8. MYSQL:随机抽取一条数据库记录
  9. Socket的基本使用步骤
  10. sublime打开txt文件乱码的问题