tensorflow中共享变量 tf.get_variable 和命名空间 tf.variable_scope
2024-08-22 22:50:22
tensorflow中有很多需要变量共享的场合,比如在多个GPU上训练网络时网络参数和训练数据就需要共享。
tf通过 tf.get_variable() 可以建立或者获取一个共享的变量。 tf.get_variable函数的作用从tf的注释里就可以看出来-- ‘Gets an existing variable with this name or create a new one’。
与 tf.get_variable 函数相对的还有一个 tf.Variable 函数,两者的区别是:
- tf.Variable定义变量的时候会自动检测命名冲突并自行处理,例如已经定义了一个名称是 ‘wg_1’的变量,再使用tf.Variable定义名称是‘wg_1’的变量,会自动把后一个变量的名称更改为‘wg_1_0’,实际相当于创建了两个变量,tf.Variable不可以创建共享变量。
- tf.get_variable定义变量的时候不会自动处理命名冲突,如果遇到重名的变量并且创建该变量时没有设置为共享变量,tf会直接报错。
变量可以共享之后还有一个问题就是当模型很大很复杂的时候,变量和操作的数量也比较庞大,为了方便对这些变量进行管理,维护条理清晰的graph结构,tf建立了一套共享机制,通过 变量作用域(命名空间,variable_scope)实现对变量的共享和管理。例如,cnn的每一层中,均有weights和biases这两个变量,通过tf.variable_scope()为每一卷积层命名,就可以防止变量命名重复。
与 tf.variable_scope相对的还有一个 tf.name_scope 函数,两者的区别是:
- tf.name_scope 主要用于管理一个图(graph)里面的各种操作,返回的是一个以scope_name命名的context manager。一个graph会维护一个name_space的堆,每一个namespace下面可以定义各种op或者子namespace,实现一种层次化有条理的管理,避免各个op之间命名冲突。
- tf.variable_scope 一般与tf.name_scope()配合使用,用于管理一个图(graph)中变量的名字,避免变量之间的命名冲突,tf.variable_scope允许在一个variable_scope下面共享变量。
# coding: utf-8
import tensorflow as tf
# 定义的基本等价
v1 = tf.get_variable("v", shape=[1], initializer= tf.constant_initializer(1.0))
v2 = tf.Variable(tf.constant(1.0, shape=[1]), name="v")
with tf.variable_scope("abc"):
v3=tf.get_variable("v",[1],initializer=tf.constant_initializer(1.0))
# 在变量作用域内定义变量,不同变量作用域内的变量命名可以相同
with tf.variable_scope("xyz"):
v4=tf.get_variable("v",[1],initializer=tf.constant_initializer(1.0))
with tf.variable_scope("xyz", reuse=True):
v5 = tf.get_variable("v")
v6 = tf.get_variable("v",[1])
with tf.variable_scope("foo"):
v7 = tf.get_variable("v", [1])
# 通过 tf.get_variable_scope().reuse_variables() 设置以下的变量是共享变量;
# 如果不加,v8的定义会由于重名而报错
tf.get_variable_scope().reuse_variables()
v8 = tf.get_variable("v", [1])
assert v7 is v8
with tf.variable_scope("foo_1") as foo_scope:
v = tf.get_variable("v", [1])
with tf.variable_scope(foo_scope):
w = tf.get_variable("w", [1])
with tf.variable_scope(foo_scope, reuse=True):
v1 = tf.get_variable("v", [1])
w1 = tf.get_variable("w", [1])
assert v1 is v
assert w1 is w
with tf.variable_scope("foo1"):
with tf.name_scope("bar1"):
v_1 = tf.get_variable("v", [1])
x_1 = 1.0 + v_1
assert v_1.name == "foo1/v:0"
assert x_1.op.name == "foo1/bar1/add"
print v1==v2 # False
print v3==v4 # False 不同变量作用域中
print v3.name # abc/v:0
print v4==v5 # 输出为True
print v5==v6 # True
最新文章
- android 获取网络类型名称2G 3G 4G wifi
- git学习心得总结
- Java关键字用法及区别
- go语言选择语句 switch case
- 如何用iframe标签以及Javascript制作时钟?
- [转]Java中继承、多态、重载和重写介绍
- ural 1283. Dwarf
- PosPal银豹收银系统
- webservice soapclient报错Error fetching http headers
- USACO 2014 Open Silver Fairphoto
- 深入浅出 JavaScript 数组 v0.5
- 将Magento后台汉化的方法
- window 安装 Protobuf
- Windows 下如何安装配置Snort视频教程
- [Node.js]expressjs简单测试连接mysql
- Android动画之二:View Animation
- css实现的交互运动
- python3操作redis
- 关于leal和mov
- 如何通过Chrome远程调试android设备上的Web网站
热门文章
- (转)SpringBoot非官方教程 | 第三篇:SpringBoot用JdbcTemplates访问Mysql
- Selenium IDE脚本录制步骤简介
- 【转】Deep Learning(深度学习)学习笔记整理系列之(四)
- JavaScript Ajax上传文件miniupload.js
- Codeforces Round #524 (Div. 2) Solution
- P1879 [USACO06NOV]玉米田Corn Fields(状压dp)
- 20145106 java实验一
- 20145204 《Java程序设计》第9周学习总结
- HDU 2896 病毒侵袭(AC自动机)题解
- 【eclipse】运行maven项目clean tomcat7:run报错