Function: tf.get_collection
get_collection(
    key,
    scope=None
)
Defined in: tensorflow/python/framework/ops.py.
See the guide: Building Graphs > Graph Collections
Wraps Graph.get_collection() using the default graph.
Args:
    key: The key for the collection. For example, the GraphKeys class contains many standard names for collections.
    scope: (Optional.) If supplied, the resulting list is filtered to include only items whose name attribute matches scope using re.match. Items without a name attribute are never returned if a scope is supplied. The choice of re.match means that a scope without special tokens filters by prefix.
Returns:
    The list of values in the collection with the given name, or an empty list if no value has been added to that collection. The list contains the values in the order in which they were collected.
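For example, the scope argument filters the result by name. A minimal sketch (the scope names 'encoder' and 'decoder' here are made up for illustration):
import tensorflow as tf

with tf.variable_scope('encoder'):
    w1 = tf.get_variable('w1', shape=[3, 3])
with tf.variable_scope('decoder'):
    w2 = tf.get_variable('w2', shape=[3, 3])

# No scope: every global variable in the graph.
print(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))
# scope='encoder' keeps only items whose name matches via re.match,
# i.e. only encoder/w1:0 here.
print(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='encoder'))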
Function: tf.get_collection_ref
get_collection_ref(key)
Defined in: tensorflow/python/framework/ops.py.
See the guide: Building Graphs > Graph Collections
Wraps Graph.get_collection_ref() using the default graph.
Args:
    key: The key for the collection. For example, the GraphKeys class contains many standard names for collections.
Returns:
    The list of values in the collection with the given name, or an empty list if no value has been added to that collection. Note that this returns the collection list itself, which can be modified in place to change the collection.
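The difference from tf.get_collection is that the returned list is live, not a copy. A minimal sketch (the collection key 'my_losses' is made up for illustration):
import tensorflow as tf

tf.add_to_collection('my_losses', tf.constant(1.0))
ref = tf.get_collection_ref('my_losses')   # the live list
snapshot = tf.get_collection('my_losses')  # a copy
ref.append(tf.constant(2.0))               # mutates the collection itself
print(len(tf.get_collection('my_losses')))  # 2
print(len(snapshot))                         # still 1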
# From the code in 'My-TensorFlow-tutorials-master/02 CIFAR10/cifar10.py'
variables = tf.get_collection(tf.GraphKeys.VARIABLES)
for i in variables:
    print(i)
>>> <tf.Variable 'conv1/weights:0' shape=(3, 3, 3, 96) dtype=float32_ref>
<tf.Variable 'conv1/biases:0' shape=(96,) dtype=float32_ref>
<tf.Variable 'conv2/weights:0' shape=(3, 3, 96, 64) dtype=float32_ref>
<tf.Variable 'conv2/biases:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'local3/weights:0' shape=(16384, 384) dtype=float32_ref>
<tf.Variable 'local3/biases:0' shape=(384,) dtype=float32_ref>
<tf.Variable 'local4/weights:0' shape=(384, 192) dtype=float32_ref>
<tf.Variable 'local4/biases:0' shape=(192,) dtype=float32_ref>
<tf.Variable 'softmax_linear/softmax_linear:0' shape=(192, 10) dtype=float32_ref>
<tf.Variable 'softmax_linear/biases:0' shape=(10,) dtype=float32_ref>
tf.get_collection lists every value that has been stored in the collection named by key.
tf.GraphKeys exposes many standard collection names as class attributes, e.g. VARIABLES (a deprecated alias of GLOBAL_VARIABLES, containing all variables) and REGULARIZATION_LOSSES.
A concrete use of tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES):
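Collections are just named lists attached to the graph, so custom keys work too, with tf.add_to_collection as the writer side. A minimal sketch (the key 'my_collection' is made up for illustration):
import tensorflow as tf

tf.add_to_collection('my_collection', tf.constant(3.0))
tf.add_to_collection('my_collection', tf.constant(5.0))
# Values come back in the order they were added.
print(tf.get_collection('my_collection'))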
import numpy as np
import tensorflow as tf
from tensorflow.contrib import layers
from tensorflow.examples.tutorials.mnist import input_data

def easier_network(x, reg):
    """ A network built with tf.contrib.layers, with input `x`. """
    with tf.variable_scope('EasyNet'):
        out = layers.flatten(x)
        out = layers.fully_connected(out,
                num_outputs=200,
                weights_initializer=layers.xavier_initializer(uniform=True),
                weights_regularizer=layers.l2_regularizer(scale=reg),
                activation_fn=tf.nn.tanh)
        out = layers.fully_connected(out,
                num_outputs=200,
                weights_initializer=layers.xavier_initializer(uniform=True),
                weights_regularizer=layers.l2_regularizer(scale=reg),
                activation_fn=tf.nn.tanh)
        out = layers.fully_connected(out,
                num_outputs=10,  # Because there are ten digits!
                weights_initializer=layers.xavier_initializer(uniform=True),
                weights_regularizer=layers.l2_regularizer(scale=reg),
                activation_fn=None)
    return out
# FLAGS (providing data_dir and regu) is assumed to be defined elsewhere,
# e.g. via tf.app.flags.
def main(_):
    mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
    x = tf.placeholder(tf.float32, [None, 784])
    y_ = tf.placeholder(tf.float32, [None, 10])

    # Make a network with regularization
    y_conv = easier_network(x, FLAGS.regu)
    weights = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'EasyNet')
    print("")
    for w in weights:
        shp = w.get_shape().as_list()
        print("- {} shape:{} size:{}".format(w.name, shp, np.prod(shp)))
    print("")
    reg_ws = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES, 'EasyNet')
    for w in reg_ws:
        shp = w.get_shape().as_list()
        print("- {} shape:{} size:{}".format(w.name, shp, np.prod(shp)))
    print("")

    # Make the loss function `loss_fn` with regularization.
    cross_entropy = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y_conv))
    loss_fn = cross_entropy + tf.reduce_sum(reg_ws)
    train_step = tf.train.AdamOptimizer(1e-4).minimize(loss_fn)

main()
>>> - EasyNet/fully_connected/weights:0 shape:[784, 200] size:156800
- EasyNet/fully_connected/biases:0 shape:[200] size:200
- EasyNet/fully_connected_1/weights:0 shape:[200, 200] size:40000
- EasyNet/fully_connected_1/biases:0 shape:[200] size:200
- EasyNet/fully_connected_2/weights:0 shape:[200, 10] size:2000
- EasyNet/fully_connected_2/biases:0 shape:[10] size:10
- EasyNet/fully_connected/kernel/Regularizer/l2_regularizer:0 shape:[] size:1.0
- EasyNet/fully_connected_1/kernel/Regularizer/l2_regularizer:0 shape:[] size:1.0
- EasyNet/fully_connected_2/kernel/Regularizer/l2_regularizer:0 shape:[] size:1.0
As the output above shows, every regularization loss created in the graph is collected under tf.GraphKeys.REGULARIZATION_LOSSES.
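TF 1.x also ships convenience wrappers around this collection; a minimal sketch of replacing the manual tf.reduce_sum(reg_ws) in main() above:
# Equivalent to tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES, 'EasyNet'):
reg_losses = tf.losses.get_regularization_losses(scope='EasyNet')
# Or get one summed scalar, ready to add to the data loss:
total_reg = tf.losses.get_regularization_loss(scope='EasyNet')
loss_fn = cross_entropy + total_reg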