tensorflow中有很多需要变量共享的场合,比如在多个GPU上训练网络时网络参数和训练数据就需要共享。

tf通过 tf.get_variable() 可以建立或者获取一个共享的变量。 tf.get_variable函数的作用从tf的注释里就可以看出来-- ‘Gets an existing variable with this name or create a new one’。

与 tf.get_variable 函数相对的还有一个 tf.Variable 函数,两者的区别是:

  • tf.Variable定义变量的时候会自动检测命名冲突并自行处理,例如已经定义了一个名称是 ‘wg_1’的变量,再使用tf.Variable定义名称是‘wg_1’的变量,会自动把后一个变量的名称更改为‘wg_1_0’,实际相当于创建了两个变量,tf.Variable不可以创建共享变量。
  • tf.get_variable定义变量的时候不会自动处理命名冲突,如果遇到重名的变量并且创建该变量时没有设置为共享变量,tf会直接报错。

变量可以共享之后还有一个问题就是当模型很大很复杂的时候,变量和操作的数量也比较庞大,为了方便对这些变量进行管理,维护条理清晰的graph结构,tf建立了一套共享机制,通过 变量作用域(命名空间,variable_scope)实现对变量的共享和管理。例如,cnn的每一层中,均有weights和biases这两个变量,通过tf.variable_scope()为每一卷积层命名,就可以防止变量命名重复。

与 tf.variable_scope相对的还有一个 tf.name_scope 函数,两者的区别是:

  • tf.name_scope 主要用于管理一个图(graph)里面的各种操作,返回的是一个以scope_name命名的context manager。一个graph会维护一个name_space的堆,每一个namespace下面可以定义各种op或者子namespace,实现一种层次化有条理的管理,避免各个op之间命名冲突。
  • tf.variable_scope 一般与tf.name_scope()配合使用,用于管理一个图(graph)中变量的名字,避免变量之间的命名冲突,tf.variable_scope允许在一个variable_scope下面共享变量。
# coding: utf-8
import tensorflow as tf # 定义的基本等价
v1 = tf.get_variable("v", shape=[1], initializer= tf.constant_initializer(1.0))
v2 = tf.Variable(tf.constant(1.0, shape=[1]), name="v") with tf.variable_scope("abc"):
v3=tf.get_variable("v",[1],initializer=tf.constant_initializer(1.0)) # 在变量作用域内定义变量,不同变量作用域内的变量命名可以相同
with tf.variable_scope("xyz"):
v4=tf.get_variable("v",[1],initializer=tf.constant_initializer(1.0)) with tf.variable_scope("xyz", reuse=True):
v5 = tf.get_variable("v")
v6 = tf.get_variable("v",[1]) with tf.variable_scope("foo"):
v7 = tf.get_variable("v", [1]) # 通过 tf.get_variable_scope().reuse_variables() 设置以下的变量是共享变量;
# 如果不加,v8的定义会由于重名而报错
tf.get_variable_scope().reuse_variables()
v8 = tf.get_variable("v", [1])
assert v7 is v8 with tf.variable_scope("foo_1") as foo_scope:
v = tf.get_variable("v", [1])
with tf.variable_scope(foo_scope):
w = tf.get_variable("w", [1])
with tf.variable_scope(foo_scope, reuse=True):
v1 = tf.get_variable("v", [1])
w1 = tf.get_variable("w", [1])
assert v1 is v
assert w1 is w with tf.variable_scope("foo1"):
with tf.name_scope("bar1"):
v_1 = tf.get_variable("v", [1])
x_1 = 1.0 + v_1
assert v_1.name == "foo1/v:0"
assert x_1.op.name == "foo1/bar1/add" print v1==v2 # False
print v3==v4 # False 不同变量作用域中
print v3.name # abc/v:0
print v4==v5 # 输出为True
print v5==v6 # True

最新文章

  1. android 获取网络类型名称2G 3G 4G wifi
  2. git学习心得总结
  3. Java关键字用法及区别
  4. go语言选择语句 switch case
  5. 如何用iframe标签以及Javascript制作时钟?
  6. [转]Java中继承、多态、重载和重写介绍
  7. ural 1283. Dwarf
  8. PosPal银豹收银系统
  9. webservice soapclient报错Error fetching http headers
  10. USACO 2014 Open Silver Fairphoto
  11. 深入浅出 JavaScript 数组 v0.5
  12. 将Magento后台汉化的方法
  13. window 安装 Protobuf
  14. Windows 下如何安装配置Snort视频教程
  15. [Node.js]expressjs简单测试连接mysql
  16. Android动画之二:View Animation
  17. css实现的交互运动
  18. python3操作redis
  19. 关于leal和mov
  20. 如何通过Chrome远程调试android设备上的Web网站

热门文章

  1. (转)SpringBoot非官方教程 | 第三篇:SpringBoot用JdbcTemplates访问Mysql
  2. Selenium IDE脚本录制步骤简介
  3. 【转】Deep Learning(深度学习)学习笔记整理系列之(四)
  4. JavaScript Ajax上传文件miniupload.js
  5. Codeforces Round #524 (Div. 2) Solution
  6. P1879 [USACO06NOV]玉米田Corn Fields(状压dp)
  7. 20145106 java实验一
  8. 20145204 《Java程序设计》第9周学习总结
  9. HDU 2896 病毒侵袭(AC自动机)题解
  10. 【eclipse】运行maven项目clean tomcat7:run报错