614
技術社區[雲棲]
TensorFlow教程之完整教程 2.5 TensorFlow運作方式入門
本文檔為TensorFlow參考文檔,本轉載已得到TensorFlow中文社區授權。
本篇教程的目的,是向大家展示如何利用TensorFlow使用(經典)MNIST數據集訓練並評估一個用於識別手寫數字的簡易前饋神經網絡(feed-forward neural network)。我們的目標讀者,是有興趣使用TensorFlow的資深機器學習人士。
因此,撰寫該係列教程並不是為了教大家機器學習領域的基礎知識。
在學習本教程之前,請確保您已按照安裝TensorFlow教程中的要求,完成了安裝。
教程使用的文件
本教程引用如下文件:
文件 | 目的 |
---|---|
mnist.py |
構建一個完全連接(fully connected)的MINST模型所需的代碼。 |
fully_connected_feed.py |
利用下載的數據集訓練構建好的MNIST模型的主要代碼,以數據反饋字典(feed dictionary)的形式作為輸入模型。 |
隻需要直接運行fully_connected_feed.py
文件,就可以開始訓練:
python fully_connected_feed.py
準備數據
MNIST是機器學習領域的一個經典問題,指的是讓機器查看一係列大小為28x28像素的手寫數字灰度圖像,並判斷這些圖像代表0-9中的哪一個數字。
下載
在run_training()
方法的一開始,input_data.read_data_sets()
函數會確保你的本地訓練文件夾中,已經下載了正確的數據,然後將這些數據解壓並返回一個含有DataSet
實例的字典。
data_sets = input_data.read_data_sets(FLAGS.train_dir, FLAGS.fake_data)
注意:fake_data
標記是用於單元測試的,讀者可以不必理會。
數據集 | 目的 |
---|---|
data_sets.train |
55000個圖像和標簽(labels),作為主要訓練集。 |
data_sets.validation |
5000個圖像和標簽,用於迭代驗證訓練準確度。 |
data_sets.test |
10000個圖像和標簽,用於最終測試訓練準確度(trained accuracy)。 |
輸入與占位符(Inputs and Placeholders)
placeholder_inputs()
函數將生成兩個tf.placeholder
操作,定義傳入圖表中的shape參數,shape參數中包括batch_size
值,後續還會將實際的訓練用例傳入圖表。
images_placeholder = tf.placeholder(tf.float32, shape=(batch_size,
IMAGE_PIXELS))
labels_placeholder = tf.placeholder(tf.int32, shape=(batch_size))
在訓練循環(training loop)的後續步驟中,傳入的整個圖像和標簽數據集會被切片,以符合每一個操作所設置的batch_size
值,占位符操作將會填補以符合這個batch_size
值。然後使用feed_dict
參數,將數據傳入sess.run()
函數。
構建圖表 (Build the Graph)
在為數據創建占位符之後,就可以運行mnist.py
文件,經過三階段的模式函數操作:inference()
, loss()
,和training()
。圖表就構建完成了。
1.inference()
—— 盡可能地構建好圖表,滿足促使神經網絡向前反饋並做出預測的要求。
2.loss()
—— 往inference圖表中添加生成損失(loss)所需要的操作(ops)。
3.training()
—— 往損失圖表中添加計算並應用梯度(gradients)所需的操作。

推理(Inference)
inference()
函數會盡可能地構建圖表,做到返回包含了預測結果(output prediction)的Tensor。
它接受圖像占位符為輸入,在此基礎上借助ReLu(Rectified Linear Units)激活函數,構建一對完全連接層(layers),以及一個有著十個節點(node)、指明了輸出logtis模型的線性層。
每一層都創建於一個唯一的tf.name_scope
之下,創建於該作用域之下的所有元素都將帶有其前綴。
with tf.name_scope('hidden1') as scope:
在定義的作用域中,每一層所使用的權重和偏差都在tf.Variable
實例中生成,並且包含了各自期望的shape。
weights = tf.Variable(
tf.truncated_normal([IMAGE_PIXELS, hidden1_units],
stddev=1.0 / math.sqrt(float(IMAGE_PIXELS))),
name='weights')
biases = tf.Variable(tf.zeros([hidden1_units]),
name='biases')
例如,當這些層是在hidden1
作用域下生成時,賦予權重變量的獨特名稱將會是"hidden1/weights
"。
每個變量在構建時,都會獲得初始化操作(initializer ops)。
在這種最常見的情況下,通過tf.truncated_normal
函數初始化權重變量,給賦予的shape則是一個二維tensor,其中第一個維度代表該層中權重變量所連接(connect from)的單元數量,第二個維度代表該層中權重變量所連接到的(connect to)單元數量。對於名叫hidden1
的第一層,相應的維度則是[IMAGE_PIXELS, hidden1_units]
,因為權重變量將圖像輸入連接到了hidden1
層。tf.truncated_normal
初始函數將根據所得到的均值和標準差,生成一個隨機分布。
然後,通過tf.zeros
函數初始化偏差變量(biases),確保所有偏差的起始值都是0,而它們的shape則是其在該層中所接到的(connect to)單元數量。
圖表的三個主要操作,分別是兩個tf.nn.relu
操作,它們中嵌入了隱藏層所需的tf.matmul
;以及logits模型所需的另外一個tf.matmul
。三者依次生成,各自的tf.Variable
實例則與輸入占位符或下一層的輸出tensor所連接。
hidden1 = tf.nn.relu(tf.matmul(images, weights) + biases)
hidden2 = tf.nn.relu(tf.matmul(hidden1, weights) + biases)
logits = tf.matmul(hidden2, weights) + biases
最後,程序會返回包含了輸出結果的logits
Tensor。
損失(Loss)
loss()
函數通過添加所需的損失操作,進一步構建圖表。
首先,labels_placeholer
中的值,將被編碼為一個含有1-hot values的Tensor。例如,如果類標識符為“3”,那麼該值就會被轉換為: [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]
batch_size = tf.size(labels)
labels = tf.expand_dims(labels, 1)
indices = tf.expand_dims(tf.range(0, batch_size, 1), 1)
concated = tf.concat(1, [indices, labels])
onehot_labels = tf.sparse_to_dense(
concated, tf.pack([batch_size, NUM_CLASSES]), 1.0, 0.0)
之後,又添加一個tf.nn.softmax_cross_entropy_with_logits
操作,用來比較inference()
函數與1-hot標簽所輸出的logits Tensor。
cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits,
onehot_labels,
name='xentropy')
然後,使用tf.reduce_mean
函數,計算batch維度(第一維度)下交叉熵(cross entropy)的平均值,將將該值作為總損失。
loss = tf.reduce_mean(cross_entropy, name='xentropy_mean')
最後,程序會返回包含了損失值的Tensor。
注意:交叉熵是信息理論中的概念,可以讓我們描述如果基於已有事實,相信神經網絡所做的推測最壞會導致什麼結果。更多詳情,請查閱博文《可視化信息理論
訓練
training()
函數添加了通過梯度下降(gradient descent)將損失最小化所需的操作。
首先,該函數從loss()
函數中獲取損失Tensor,將其交給tf.scalar_summary
,後者在與SummaryWriter
(見下文)配合使用時,可以向事件文件(events file)中生成匯總值(summary values)。在本篇教程中,每次寫入匯總值時,它都會釋放損失Tensor的當前值(snapshot value)。
tf.scalar_summary(loss.op.name, loss)
接下來,我們實例化一個tf.train.GradientDescentOptimizer
,負責按照所要求的學習效率(learning rate)應用梯度下降法(gradients)。
optimizer = tf.train.GradientDescentOptimizer(FLAGS.learning_rate)
之後,我們生成一個變量用於保存全局訓練步驟(global training step)的數值,並使用minimize()
函數更新係統中的三角權重(triangle weights)、增加全局步驟的操作。根據慣例,這個操作被稱為 train_op
,是TensorFlow會話(session)誘發一個完整訓練步驟所必須運行的操作(見下文)。
global_step = tf.Variable(0, name='global_step', trainable=False)
train_op = optimizer.minimize(loss, global_step=global_step)
最後,程序返回包含了訓練操作(training op)輸出結果的Tensor。
訓練模型
一旦圖表構建完畢,就通過fully_connected_feed.py
文件中的用戶代碼進行循環地迭代式訓練和評估。
圖表
在run_training()
這個函數的一開始,是一個Python語言中的with
命令,這個命令表明所有已經構建的操作都要與默認的tf.Graph
全局實例關聯起來。
with tf.Graph().as_default():
tf.Graph
實例是一係列可以作為整體執行的操作。TensorFlow的大部分場景隻需要依賴默認圖表一個實例即可。
利用多個圖表的更加複雜的使用場景也是可能的,但是超出了本教程的範圍。
會話
完成全部的構建準備、生成全部所需的操作之後,我們就可以創建一個tf.Session
,用於運行圖表。
sess = tf.Session()
另外,也可以利用with
代碼塊生成Session
,限製作用域:
with tf.Session() as sess:
Session
函數中沒有傳入參數,表明該代碼將會依附於(如果還沒有創建會話,則會創建新的會話)默認的本地會話。
生成會話之後,所有tf.Variable
實例都會立即通過調用各自初始化操作中的sess.run()
函數進行初始化。
init = tf.initialize_all_variables()
sess.run(init)
sess.run()
方法將會運行圖表中與作為參數傳入的操作相對應的完整子集。在初次調用時,init
操作隻包含了變量初始化程序tf.group
。圖表的其他部分不會在這裏,而是在下麵的訓練循環運行。
訓練循環
完成會話中變量的初始化之後,就可以開始訓練了。
訓練的每一步都是通過用戶代碼控製,而能實現有效訓練的最簡單循環就是:
for step in xrange(max_steps):
sess.run(train_op)
但是,本教程中的例子要更為複雜一點,原因是我們必須把輸入的數據根據每一步的情況進行切分,以匹配之前生成的占位符。
向圖表提供反饋
執行每一步時,我們的代碼會生成一個反饋字典(feed dictionary),其中包含對應步驟中訓練所要使用的例子,這些例子的哈希鍵就是其所代表的占位符操作。
fill_feed_dict
函數會查詢給定的DataSet
,索要下一批次batch_size
的圖像和標簽,與占位符相匹配的Tensor則會包含下一批次的圖像和標簽。
images_feed, labels_feed = data_set.next_batch(FLAGS.batch_size)
然後,以占位符為哈希鍵,創建一個Python字典對象,鍵值則是其代表的反饋Tensor。
feed_dict = {
images_placeholder: images_feed,
labels_placeholder: labels_feed,
}
這個字典隨後作為feed_dict
參數,傳入sess.run()
函數中,為這一步的訓練提供輸入樣例。
檢查狀態
在運行sess.run
函數時,要在代碼中明確其需要獲取的兩個值:[train_op, loss]
。
for step in xrange(FLAGS.max_steps):
feed_dict = fill_feed_dict(data_sets.train,
images_placeholder,
labels_placeholder)
_, loss_value = sess.run([train_op, loss],
feed_dict=feed_dict)
因為要獲取這兩個值,sess.run()
會返回一個有兩個元素的元組。其中每一個Tensor
對象,對應了返回的元組中的numpy數組,而這些數組中包含了當前這步訓練中對應Tensor的值。由於train_op
並不會產生輸出,其在返回的元祖中的對應元素就是None
,所以會被拋棄。但是,如果模型在訓練中出現偏差,loss
Tensor的值可能會變成NaN,所以我們要獲取它的值,並記錄下來。
假設訓練一切正常,沒有出現NaN,訓練循環會每隔100個訓練步驟,就打印一行簡單的狀態文本,告知用戶當前的訓練狀態。
if step % 100 == 0:
print 'Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration)
狀態可視化
為了釋放TensorBoard所使用的事件文件(events file),所有的即時數據(在這裏隻有一個)都要在圖表構建階段合並至一個操作(op)中。
summary_op = tf.merge_all_summaries()
在創建好會話(session)之後,可以實例化一個tf.train.SummaryWriter
,用於寫入包含了圖表本身和即時數據具體值的事件文件。
summary_writer = tf.train.SummaryWriter(FLAGS.train_dir,
graph_def=sess.graph_def)
最後,每次運行summary_op
時,都會往事件文件中寫入最新的即時數據,函數的輸出會傳入事件文件讀寫器(writer)的add_summary()
函數。。
summary_str = sess.run(summary_op, feed_dict=feed_dict)
summary_writer.add_summary(summary_str, step)
事件文件寫入完畢之後,可以就訓練文件夾打開一個TensorBoard,查看即時數據的情況。
保存檢查點(checkpoint)
為了得到可以用來後續恢複模型以進一步訓練或評估的檢查點文件(checkpoint file),我們實例化一個tf.train.Saver
。
saver = tf.train.Saver()
在訓練循環中,將定期調用saver.save()
方法,向訓練文件夾中寫入包含了當前所有可訓練變量值得檢查點文件。
saver.save(sess, FLAGS.train_dir, global_step=step)
這樣,我們以後就可以使用saver.restore()
方法,重載模型的參數,繼續訓練。
saver.restore(sess, FLAGS.train_dir)
評估模型
每隔一千個訓練步驟,我們的代碼會嚐試使用訓練數據集與測試數據集,對模型進行評估。do_eval
函數會被調用三次,分別使用訓練數據集、驗證數據集合測試數據集。
print 'Training Data Eval:'
do_eval(sess,
eval_correct,
images_placeholder,
labels_placeholder,
data_sets.train)
print 'Validation Data Eval:'
do_eval(sess,
eval_correct,
images_placeholder,
labels_placeholder,
data_sets.validation)
print 'Test Data Eval:'
do_eval(sess,
eval_correct,
images_placeholder,
labels_placeholder,
data_sets.test)
注意,更複雜的使用場景通常是,先隔絕
data_sets.test
測試數據集,隻有在大量的超參數優化調整(hyperparameter tuning)之後才進行檢查。但是,由於MNIST問題比較簡單,我們在這裏一次性評估所有的數據。
構建評估圖表(Eval Graph)
在打開默認圖表(Graph)之前,我們應該先調用get_data(train=False)
函數,抓取測試數據集。
test_all_images, test_all_labels = get_data(train=False)
在進入訓練循環之前,我們應該先調用mnist.py
文件中的evaluation
函數,傳入的logits和標簽參數要與loss
函數的一致。這樣做事為了先構建Eval操作。
eval_correct = mnist.evaluation(logits, labels_placeholder)
evaluation
函數會生成tf.nn.in_top_k
操作,如果在K個最有可能的預測中可以發現真的標簽,那麼這個操作就會將模型輸出標記為正確。在本文中,我們把K的值設置為1,也就是隻有在預測是真的標簽時,才判定它是正確的。
eval_correct = tf.nn.in_top_k(logits, labels, 1)
評估圖表的輸出(Eval Output)
之後,我們可以創建一個循環,往其中添加feed_dict
,並在調用sess.run()
函數時傳入eval_correct
操作,目的就是用給定的數據集評估模型。
for step in xrange(steps_per_epoch):
feed_dict = fill_feed_dict(data_set,
images_placeholder,
labels_placeholder)
true_count += sess.run(eval_correct, feed_dict=feed_dict)
true_count
變量會累加所有in_top_k
操作判定為正確的預測之和。接下來,隻需要將正確測試的總數,除以例子總數,就可以得出準確率了。
precision = float(true_count) / float(num_examples)
print ' Num examples: %d Num correct: %d Precision @ 1: %0.02f' % (
num_examples, true_count, precision)
最後更新:2017-08-22 15:36:00