閱讀513 返回首頁    go 阿裏雲 go 技術社區[雲棲]


TensorFlow教程之進階指南 3.2 變量:創建、初始化、保存和加載

本文檔為TensorFlow參考文檔,本轉載已得到TensorFlow中文社區授權。



當訓練模型時,用變量來存儲和更新參數。變量包含張量 (Tensor)存放於內存的緩存區。建模時它們需要被明確地初始化,模型訓練後它們必須被存儲到磁盤。這些變量的值可在之後模型訓練和分析是被加載。

本文檔描述以下兩個TensorFlow類。點擊以下鏈接可查看完整的API文檔:

創建

當創建一個變量時,你將一個張量作為初始值傳入構造函數Variable()。TensorFlow提供了一係列操作符來初始化張量,初始值是常量或是隨機值

注意,所有這些操作符都需要你指定張量的shape。那個形狀自動成為變量的shape。變量的shape通常是固定的,但TensorFlow提供了高級的機製來重新調整其行列數。

# Create two variables.
weights = tf.Variable(tf.random_normal([784, 200], stddev=0.35),
                      name="weights")
biases = tf.Variable(tf.zeros([200]), name="biases")

調用tf.Variable()添加一些操作(Op, operation)到graph:

  • 一個Variable操作存放變量的值。
  • 一個初始化op將變量設置為初始值。這事實上是一個tf.assign操作.
  • 初始值的操作,例如示例中對biases變量的zeros操作也被加入了graph。

tf.Variable的返回值是Python的tf.Variable類的一個實例。

初始化

變量的初始化必須在模型的其它操作運行之前先明確地完成。最簡單的方法就是添加一個給所有變量初始化的操作,並在使用模型之前首先運行那個操作。

你或者可以從檢查點文件中重新獲取變量值,詳見下文。

使用tf.initialize_all_variables()添加一個操作對變量做初始化。記得在完全構建好模型並加載之後再運行那個操作。

# Create two variables.
weights = tf.Variable(tf.random_normal([784, 200], stddev=0.35),
                      name="weights")
biases = tf.Variable(tf.zeros([200]), name="biases")
...
# Add an op to initialize the variables.
init_op = tf.initialize_all_variables()

# Later, when launching the model
with tf.Session() as sess:
  # Run the init operation.
  sess.run(init_op)
  ...
  # Use the model
  ...

由另一個變量初始化

你有時候會需要用另一個變量的初始化值給當前變量初始化。由於tf.initialize_all_variables()是並行地初始化所有變量,所以在有這種需求的情況下需要小心。

用其它變量的值初始化一個新的變量時,使用其它變量的initialized_value()屬性。你可以直接把已初始化的值作為新變量的初始值,或者把它當做tensor計算得到一個值賦予新變量。

# Create a variable with a random value.
weights = tf.Variable(tf.random_normal([784, 200], stddev=0.35),
                      name="weights")
# Create another variable with the same value as 'weights'.
w2 = tf.Variable(weights.initialized_value(), name="w2")
# Create another variable with twice the value of 'weights'
w_twice = tf.Variable(weights.initialized_value() * 0.2, name="w_twice")

自定義初始化

tf.initialize_all_variables()函數便捷地添加一個op來初始化模型的所有變量。你也可以給它傳入一組變量進行初始化。詳情請見Variables Documentation,包括檢查變量是否被初始化。

保存和加載

最簡單的保存和恢複模型的方法是使用tf.train.Saver對象。構造器給graph的所有變量,或是定義在列表裏的變量,添加saverestoreops。saver對象提供了方法來運行這些ops,定義檢查點文件的讀寫路徑。

檢查點文件

變量存儲在二進製文件裏,主要包含從變量名到tensor值的映射關係。

當你創建一個Saver對象時,你可以選擇性地為檢查點文件中的變量挑選變量名。默認情況下,將每個變量Variable.name屬性的值。

保存變量

tf.train.Saver()創建一個Saver來管理模型中的所有變量。

# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add an op to initialize the variables.
init_op = tf.initialize_all_variables()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, initialize the variables, do some work, save the
# variables to disk.
with tf.Session() as sess:
  sess.run(init_op)
  # Do some work with the model.
  ..
  # Save the variables to disk.
  save_path = saver.save(sess, "/tmp/model.ckpt")
  print "Model saved in file: ", save_path

恢複變量

用同一個Saver對象來恢複變量。注意,當你從文件中恢複變量時,不需要事先對它們做初始化。

# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
  # Restore variables from disk.
  saver.restore(sess, "/tmp/model.ckpt")
  print "Model restored."
  # Do some work with the model
  ...

選擇存儲和恢複哪些變量

如果你不給tf.train.Saver()傳入任何參數,那麼saver將處理graph中的所有變量。其中每一個變量都以變量創建時傳入的名稱被保存。

有時候在檢查點文件中明確定義變量的名稱很有用。舉個例子,你也許已經訓練得到了一個模型,其中有個變量命名為"weights",你想把它的值恢複到一個新的變量"params"中。

有時候僅保存和恢複模型的一部分變量很有用。再舉個例子,你也許訓練得到了一個5層神經網絡,現在想訓練一個6層的新模型,可以將之前5層模型的參數導入到新模型的前5層中。

你可以通過給tf.train.Saver()構造函數傳入Python字典,很容易地定義需要保持的變量及對應名稱:鍵對應使用的名稱,值對應被管理的變量。

注意:

  • 如果需要保存和恢複模型變量的不同子集,可以創建任意多個saver對象。同一個變量可被列入多個saver對象中,隻有當saver的restore()函數被運行時,它的值才會發生改變。
  • 如果你僅在session開始時恢複模型變量的一個子集,你需要對剩下的變量執行初始化op。詳情請見tf.initialize_variables()
# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add ops to save and restore only 'v2' using the name "my_v2"
saver = tf.train.Saver({"my_v2": v2})
# Use the saver object normally after that.
...

最後更新:2017-08-22 16:04:22

  上一篇:go  TensorFlow教程之進階指南 3.3 TensorBoard:可視化學習
  下一篇:go  TensorFlow教程之進階指南 3.1 總覽