閱讀856 返回首頁    go 阿裏雲 go 技術社區[雲棲]


TensorFlow中的那些高級API

1.png

TensorFlow擁有很多庫,比如KerasTFLearnSonnet,對於模型訓練來說,使用這些庫比使用低級功能更簡單。盡管Keras的API目前正在添加到TensorFlow中去,但TensorFlow本身就提供了一些高級構件,而且最新的1.3版本中也引入了一些新的構件。

在這篇文章中,我們將看到一個使用了這些最新的高級構件的例子,包括Estimator(估算器)、Experiment(實驗)和Dataset(數據集)。值得注意的是,你可以獨立地使用Experiment和Dataset。我在這裏假設你已經了解TensorFlow的基礎知識;如果沒有的話,那麼TensorFlow官網上提供的教程值得學習。

2.png
Experiment、Estimator和DataSet框架以及它們之間的交互。

我們在本文中將使用MNIST作為數據集。這是一個使用起來很簡單的數據集,可以從TensorFlow官網獲取到。你可以在這個gist中找到完整的代碼示例。使用這些框架的其中一個好處是,我們不需要直接處理會話

Estimator(估算器)類

Estimator類代表了一個模型,以及如何對這個模型進行訓練和評估。我們可以像下麵這段代碼創建一個Estimator:

return tf.estimator.Estimator(
    model_fn=model_fn,  # First-class function
    params=params,  # HParams
    config=run_config  # RunConfig
)

要創建Estimator,需要傳入一個模型函數、一組參數和一些配置。

  • 傳入的**參數**應該是模型超參數的一個集合。這可以是一個dictionary,但是我們將在這個例子中把它表示成一個HParams對象,就像namedtuple一樣。

  • 傳入的**配置**用於指定如何運行訓練和評估,以及在哪裏存儲結果。這個配置是一個RunConfig對象,該對象會把模型運行環境相關的信息告訴Estimator。

  • 模型函數是一個Python函數,它根據給定的輸入構建模型。

模型函數

模型函數是一個Python函數,並作為一級函數傳遞給Estimator。稍後我們會看到,TensorFlow在其他地方也使用了一級函數。將模型表示為一個函數的好處是可以通過實例化函數來多次創建模型。模型可以在訓練過程中用不同的輸入重新創建,例如,在訓練過程中運行驗證測試。

模型函數把**輸入特征**作為參數,將相應的**標簽**作為張量。它也能以某種方式來告知用戶模型是在訓練、評估或是在執行推理。模型函數的最後一個參數是**超參數**集合,它們與傳遞給Estimator的超參數集合相同。模型函數返回一個**EstimatorSpec**對象,該對象定義了一個完整的模型。

EstimatorSpec對象用於對操作進行預測、損失、訓練和評估,因此,它定義了一個用於訓練、評估和推理的完整的模型圖。由於EstimatorSpec隻可用於常規的TensorFlow操作,因此,我們可以使用像TF-Slim這樣的框架來定義模型。

Experiment(實驗)類

Experiment類定義了如何訓練模型,它與Estimator完美地集成在一起。我們可以像如下代碼創建一個Experiment對象:

experiment = tf.contrib.learn.Experiment(
    estimator=estimator,  # Estimator
    train_input_fn=train_input_fn,  # First-class function
    eval_input_fn=eval_input_fn,  # First-class function
    train_steps=params.train_steps,  # Minibatch steps
    min_eval_frequency=params.min_eval_frequency,  # Eval frequency
    train_monitors=[train_input_hook],  # Hooks for training
    eval_hooks=[eval_input_hook],  # Hooks for evaluation
    eval_steps=None  # Use evaluation feeder until its empty
)

以下幾種情況會把Experiment對象作為輸入:

  • 一個**estimator**(例如我們上麵定義的)。

  • 作為一級函數**訓練和評估數據**。這裏使用了與前麵提到的模型函數相同的概念。如果需要的話,通過傳入函數而不是操作,可以重新創建輸入圖。稍後我們還會談到這個。

  • 訓練和評估hook(鉤子)。鉤子可用於保存或監視特定的內容,或者在圖或會話中設置某些操作。例如,我們將其傳入到操作中,幫助初始化數據加載器。

  • 描述需要訓練多久以及何時評估的各種參數。

一旦定義了experiment,我們就可以像下麵這段代碼那樣使用learn_runner.run來運行它訓練和評估模型:

learn_runner.run(
    experiment_fn=experiment_fn,  # First-class function
    run_config=run_config,  # RunConfig
    schedule="train_and_evaluate",  # What to run
    hparams=params  # HParams
)

與模型函數和數據函數一樣,learn_runner將一個創建experiment的函數作為參數傳入。

Dataset(數據集)類

我們將使用Dataset類和相應的Iterator來表示數據的訓練和評估,以及創建在訓練過程中迭代數據的數據饋送器。 在本示例中,我們將使用在Tensorflow中可用的MNIST數據,並為其構建一個Dataset包裝。例如,我們將把訓練輸入數據表示為:

# Define the training inputs
def get_train_inputs(batch_size, mnist_data):
    """Return the input function to get the training data.
    Args:
        batch_size (int): Batch size of training iterator that is returned
                          by the input function.
        mnist_data (Object): Object holding the loaded mnist data.
    Returns:
        (Input function, IteratorInitializerHook):
            - Function that returns (features, labels) when called.
            - Hook to initialise input iterator.
    """
    iterator_initializer_hook = IteratorInitializerHook()

    def train_inputs():
        """Returns training set as Operations.
        Returns:
            (features, labels) Operations that iterate over the dataset
            on every evaluation
        """
        with tf.name_scope('Training_data'):
            # Get Mnist data
            images = mnist_data.train.images.reshape([-1, 28, 28, 1])
            labels = mnist_data.train.labels
            # Define placeholders
            images_placeholder = tf.placeholder(
                images.dtype, images.shape)
            labels_placeholder = tf.placeholder(
                labels.dtype, labels.shape)
            # Build dataset iterator
            dataset = tf.contrib.data.Dataset.from_tensor_slices(
                (images_placeholder, labels_placeholder))
            dataset = dataset.repeat(None)  # Infinite iterations
            dataset = dataset.shuffle(buffer_size=10000)
            dataset = dataset.batch(batch_size)
            iterator = dataset.make_initializable_iterator()
            next_example, next_label = iterator.get_next()
            # Set runhook to initialize iterator
            iterator_initializer_hook.iterator_initializer_func = \
                lambda sess: sess.run(
                    iterator.initializer,
                    feed_dict={images_placeholder: images,
                               labels_placeholder: labels})
            # Return batched (features, labels)
            return next_example, next_label

    # Return function and hook
    return train_inputs, iterator_initializer_hook

調用這個get_train_inputs將返回一個一級函數,用於在TensorFlow圖中創建數據加載操作,以及返回一個用於初始化迭代器的Hook

本示例中使用的MNIST數據最初是一個Numpy數組。我們創建了一個占位符張量來獲取數據;使用占位符的目的是為了避免數據的複製。接下來,我們在from_tensor_slices的幫助下創建一個切片數據集。我們要確保該數據集可以運行無限次數,並且數據被重新洗牌並放入指定大小的批次中。

要迭代數據,就需要從數據集中創建一個迭代器。由於我們正在使用占位符,因此需要使用NumPy數據在相關會話中對占位符進行初始化。可以通過創建一個可初始化的迭代器來實現這個。在創建圖的時候,將創建一個自定義的IteratorInitializerHook對象來初始化迭代器:

class IteratorInitializerHook(tf.train.SessionRunHook):
    """Hook to initialise data iterator after Session is created."""

    def __init__(self):
        super(IteratorInitializerHook, self).__init__()
        self.iterator_initializer_func = None

    def after_create_session(self, session, coord):
        """Initialise the iterator after the session has been created."""
        self.iterator_initializer_func(session)

IteratorInitializerHook繼承自SessionRunHook。這個鉤子將在相關會話創建後立即調用after_create_session,並使用正確的數據初始化占位符。這個鉤子由我們的get_train_inputs函數返回,並在創建時傳遞給Experiment對象。

train_inputs函數返回的數據加載操作是TensorFlow的操作,該操作每次評估時都會返回一個新的批處理。

運行代碼

現在,我們已經定義了所有內容,可以使用下麵這個命令運行代碼了:

python mnist_estimator.py --model_dir ./mnist_training --data_dir ./mnist_data

如果不傳入參數,它將使用文件開頭的默認標誌來確定數據和模型保存的位置。

在訓練過程中,在終端上會輸出這段時間內的全​​局步驟、損失和準確性等信息。除此之外,Experiment和Estimator框架將記錄TensorBoard可視化的某些統計信息。如果我們運行這個命令:

tensorboard --logdir='./mnist_training'

那麼我們可以看到所有的訓練統計數據,如訓練損失、評估準確性、每個步驟的時間,以及模型圖。

3.png
TensorBoard可視化中的評估準確度

我寫這篇文章,是因為我在編寫代碼示例時,無法找到有關Tensorflow Estimator 、Experiment和Dataset框架太多的信息和示例。我希望這篇文章能向你簡要介紹一下這些框架是如何工作的,它們采用了什麼樣的抽象方法以及如何使用它們。如果你對使用這些框架感興趣,下麵我將介紹一些注意點和其他的文檔。

有關Estimator、Experiment和Dataset框架的注意點

文章原標題《Higher-Level APIs in TensorFlow》,作者:Peter Roelants,譯者:夏天,審校:主題曲。

文章為簡譯,更為詳細的內容,請查看原文需要爬梯,不方便的同學也可以下載下方的PDF附件,閱讀原文內容。

最後更新:2017-09-04 17:32:28

  上一篇:go  1
  下一篇:go  怎麼樣閱讀tomcat源碼