单向LSTM

import torch.nn as nn
import torch seq_len = 20
batch_size = 64
embedding_dim = 100
num_embeddings = 300
hidden_size = 128
number_layer = 3 input = torch.randint(low=0,high=256,size=[batch_size,seq_len]) #[64,20] embedding = nn.Embedding(num_embeddings,embedding_dim) input_embeded = embedding(input) #[64,20,100] #转置,变换batch_size 和seq_len
# input_embeded = input_embeded.transpose(0,1)
# input_embeded = input_embeded.permute(1,0,2)
#实例化lstm lstm = nn.LSTM(input_size=embedding_dim,hidden_size=hidden_size,batch_first=True,num_layers=number_layer) output,(h_n,c_n) = lstm(input_embeded)
print(output.size()) #[64,20,128] [batch_size,seq_len,hidden_size]
print(h_n.size()) #[3,64,128] [number_layer,batch_size,hidden_size]
print(c_n.size()) #同上 #获取最后时间步的output
output_last = output[:,-1,:]
#获取最后一层的h_n
h_n_last = h_n[-1] print(output_last.size())
print(h_n_last.size())
#最后的output等于最后一层的h_n
print(output_last.eq(h_n_last))

D:\anaconda\python.exe C:/Users/liuxinyu/Desktop/pytorch_test/day4/LSTM练习.py
torch.Size([64, 20, 128])
torch.Size([3, 64, 128])
torch.Size([3, 64, 128])
torch.Size([64, 128])
torch.Size([64, 128])
tensor([[True, True, True, ..., True, True, True],
[True, True, True, ..., True, True, True],
[True, True, True, ..., True, True, True],
...,
[True, True, True, ..., True, True, True],
[True, True, True, ..., True, True, True],
[True, True, True, ..., True, True, True]])

Process finished with exit code 0

  双向LSTM

import torch.nn as nn
import torch seq_len = 20
batch_size = 64
embedding_dim = 100
num_embeddings = 300
hidden_size = 128
number_layer = 3 input = torch.randint(low=0,high=256,size=[batch_size,seq_len]) #[64,20] embedding = nn.Embedding(num_embeddings,embedding_dim) input_embeded = embedding(input) #[64,20,100] #转置,变换batch_size 和seq_len
# input_embeded = input_embeded.transpose(0,1)
# input_embeded = input_embeded.permute(1,0,2)
#实例化lstm lstm = nn.LSTM(input_size=embedding_dim,hidden_size=hidden_size,batch_first=True,num_layers=number_layer,bidirectional=True) output,(h_n,c_n) = lstm(input_embeded)
print(output.size()) #[64,20,128*2] [batch_size,seq_len,hidden_size]
print(h_n.size()) #[3*2,64,128] [number_layer,batch_size,hidden_size]
print(c_n.size()) #同上 #获取反向的最后一个output
output_last = output[:,0,-128:]
#获反向最后一层的h_n
h_n_last = h_n[-1] print(output_last.size())
print(h_n_last.size())
# 反向最后的output等于最后一层的h_n
print(output_last.eq(h_n_last)) #获取正向的最后一个output
output_last = output[:,-1,:128]
#获取正向最后一层的h_n
h_n_last = h_n[-2]
# 反向最后的output等于最后一层的h_n
print(output_last.eq(h_n_last))

D:\anaconda\python.exe C:/Users/liuxinyu/Desktop/pytorch_test/day4/双向LSTM练习.py
torch.Size([64, 20, 256])
torch.Size([6, 64, 128])
torch.Size([6, 64, 128])
torch.Size([64, 128])
torch.Size([64, 128])
tensor([[True, True, True, ..., True, True, True],
[True, True, True, ..., True, True, True],
[True, True, True, ..., True, True, True],
...,
[True, True, True, ..., True, True, True],
[True, True, True, ..., True, True, True],
[True, True, True, ..., True, True, True]])
tensor([[True, True, True, ..., True, True, True],
[True, True, True, ..., True, True, True],
[True, True, True, ..., True, True, True],
...,
[True, True, True, ..., True, True, True],
[True, True, True, ..., True, True, True],
[True, True, True, ..., True, True, True]])

Process finished with exit code 0

  

最新文章

  1. e.target.files[0]层层剖析
  2. VMware下利用ubuntu13.04建立嵌入式开发环境之一
  3. 【Python自动化运维之路Day6】
  4. HTTP 错误 500.21 - Internal Server Error 处理程序“ExtensionlessUrlHandler-Integrated-4.0”在其模块列表中有一个错误模块“ManagedPipelineHandler”
  5. 使用Jvisualvm监控JVM的内存、CPU、线程
  6. 将Excel中数据导入数据库(一)
  7. 一起刷LeetCode5-Longest Palindromic Substring
  8. Careercup - Google面试题 - 5732809947742208
  9. 【HTML】Advanced3:Tables: Columns, Headers, and Footers
  10. java基础(二十一)IO流(四)
  11. java调用C++ DLL库方法
  12. 在CG/HLSL中访问着色器属性(Properties)
  13. ubuntu15.04安装hexo
  14. 找唯一不出现三次而出现1次的数子O(n)位运算算法
  15. 【JAVASCRIPT】React + Redux
  16. OAuth2.0介绍
  17. 项目配置linux上, 配置文件访问不到
  18. rabbitMQ rabbitmq-server -detached rabbitmq-server -detached rabbitmq-server -detached
  19. 使用 UICollectionView 实现网格化视图效果
  20. iOS文本文件的编码检测

热门文章

  1. 常见Web安全漏洞--------CSRF
  2. 非常诡异的IIS下由配置文件加上svg的mime头导致整个网站的静态文件访问报错误
  3. 理解MapReduce计算构架
  4. [POJ1835]宇航员<模拟>
  5. Effective Java要点笔记
  6. Java序列化机制剖析
  7. 无法像程序语言那样写SQL查询语句,提示“数据库中已存在名为 '#temp1' 的对象。”
  8. Keil5新建STM32工程(库函数版本)
  9. mpvue中使用flyjs全局拦截
  10. es elasticsearch 6/7 设置内存方法