
Some tips about using Google's TPU


About one month ago, I submitted a request to Google Research Cloud to use TPUs for free. Fortunately, I received the approval yesterday. The approval lets me use 5 regular Cloud TPUs and 100 preemptible Cloud TPUs for free for 30 days, after only submitting my GCP project name.
Then I had to change my previous TensorFlow program to make it run on TPUs. I can't just change tf.device('/gpu:0') to tf.device('/tpu:0') in the code to run training on a Google TPU. There are actually many documents about how to modify the code for this, such as TPUEstimator, Using TPUs, etc.

Here are some tips about porting code for TPUs:

1. We can only use TPUEstimator for training

classifier = tf.contrib.tpu.TPUEstimator(
                model_fn = model_wrapper,
                config = run_config,
                use_tpu = FLAGS.use_tpu,
                train_batch_size = 64,
                batch_axis = [0, 0],
                params = {'optimizer': opt})

Pay attention to 'batch_axis'. It tells the TPU pod to split both the features and the labels along dimension 0, since I use the 'NHWC' data format.
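
With the estimator built, training is launched in the same way as with a regular Estimator. A minimal sketch, assuming the data_input_fn from the next tip and a train_steps flag of my own:

classifier.train(input_fn = data_input_fn, max_steps = FLAGS.train_steps)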

2. model_fn and data_input_fn in TPUEstimator take more arguments than those of a regular tf.estimator.Estimator. We need to fetch some arguments (such as 'batch_size') from params.

def data_input_fn(params):
    batch = params['batch_size']
...
def model_fn(features, labels, mode, config, params):
...
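
To make the difference concrete, here is a minimal sketch of such a model_fn. The tiny network and loss are only placeholders, not my real model; the 'optimizer' entry comes from the params dict passed to TPUEstimator above:

def model_fn(features, labels, mode, config, params):
    # Placeholder network and loss; the real model goes here.
    logits = tf.layers.dense(tf.layers.flatten(features), 10)
    loss = tf.losses.sparse_softmax_cross_entropy(labels = labels, logits = logits)
    # Wrap the optimizer so gradients are aggregated across all TPU shards.
    optimizer = tf.contrib.tpu.CrossShardOptimizer(params['optimizer'])
    train_op = optimizer.minimize(loss, global_step = tf.train.get_global_step())
    # TPUEstimator needs a TPUEstimatorSpec instead of a plain EstimatorSpec.
    return tf.contrib.tpu.TPUEstimatorSpec(mode = mode, loss = loss, train_op = train_op)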

3. TPUs don't support operations like

images = tf.contrib.image.rotate(images, tf.random_uniform([1], minval = -math.pi / 4.0, maxval = math.pi / 4.0))

So try to avoid using them on the TPU; one possible workaround is sketched below.
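
A workaround I would try (a sketch on my part, assuming _parse_function yields an (image, label) pair) is to keep such augmentation inside the tf.data pipeline, which runs on the host CPU, so the TPU itself never sees the unsupported op:

import math

def _augment(image, label):
    # This runs inside dataset.map() on the host CPU, not on the TPU device.
    angle = tf.random_uniform([], minval = -math.pi / 4.0, maxval = math.pi / 4.0)
    return tf.contrib.image.rotate(image, angle), label

dataset = dataset.map(_augment)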

4. Use tf.data carefully, or the TPU will report data-shape errors: the TPU requires statically known shapes, which is why the remainder of the last batch has to be dropped. The code below has run correctly for me so far:

dataset = files.apply(tf.contrib.data.parallel_interleave(
    tf.data.TFRecordDataset, sloppy = True, cycle_length = buff_size))
dataset = dataset.map(_parse_function)
dataset = dataset.repeat()
dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))
dataset = dataset.shuffle(batch_size * buff_size)

iterator = dataset.make_initializable_iterator()

5. Because TPUEstimator manages the session itself, we can't initialize the tf.data iterator with an explicit session.run(), so a little trick is needed:

def data_input_fn():
    ...
    tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, it.initializer)
    ...
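
Putting tips 4 and 5 together, here is a minimal sketch of such an input_fn (the GCS path is a placeholder for my own data, and _parse_function is the parser from tip 4):

def data_input_fn(params):
    batch_size = params['batch_size']
    dataset = tf.data.TFRecordDataset('gs://my-project/data/train.tfrecord')
    dataset = dataset.map(_parse_function)
    dataset = dataset.repeat()
    dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))
    iterator = dataset.make_initializable_iterator()
    # The Estimator's default scaffold runs tf.tables_initializer() when the
    # session is created, so registering the iterator initializer here stands
    # in for the explicit session.run(iterator.initializer) we cannot call.
    tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)
    return iterator.get_next()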

6. When training on a Cloud TPU from a GCP VM instance, TensorFlow only supports loading datasets from, and storing the model into, Google Cloud Storage (GCS).

run_config = tf.contrib.tpu.RunConfig(
            master = master,
            evaluation_master = master,
            model_dir = 'gs://my-project/models/',
            session_config = tf.ConfigProto(
                allow_soft_placement = True, log_device_placement = True),
            tpu_config = tf.contrib.tpu.TPUConfig(
                FLAGS.iterations, FLAGS.num_shards)  # iterations_per_loop, num_shards
        )
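
The 'master' address above is what points the Estimator at the TPU. A sketch of how it can be resolved (the flag names here are my own, not required ones):

tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
    FLAGS.tpu_name, zone = FLAGS.tpu_zone, project = FLAGS.gcp_project)
master = tpu_cluster_resolver.get_master()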

7. There aren't any hooks for TPUEstimator in TensorFlow 1.9 yet, so I can't see any report on the console after launching a TPU program. I hope Google improves this as soon as possible.

Author: 董昊 (Robin Dong)
Original article: Some tips about using Google's TPU. Thanks to the original author for sharing.