TensorFlow 2.1 with TPU in Practice

Case Study: Google QUEST Q&A Labeling Competition

Feb 13, 2020 · 2962 words · 14 minute read · nlp tensorflow kaggle

Executive Summary

  • TensorFlow has become much easier to use: As an experienced PyTorch developer who only knew a bit of TensorFlow 1.x, I was able to pick up TensorFlow 2.x in my spare time within 60 days and do competitive machine learning.
  • TPU has never been more accessible: The new interface to TPU in TensorFlow 2.1 works right out of the box in most cases and greatly reduces the development time required to make a model TPU-compatible. Using TPU drastically increases the iteration speed of experiments.
  • We present a case study of solving a Q&A labeling problem by fine-tuning the RoBERTa-base model from the huggingface/transformers library.
  • (TensorFlow 2.1 and TPU are also a very good fit for CV applications. A case study of solving an image classification problem will be published in about a month.)


I was granted free access to Cloud TPUs for 60 days via TensorFlow Research Cloud. It was for the TensorFlow 2.0 Question Answering competition. I chose to do this simpler Google QUEST Q&A Labeling competition first but unfortunately couldn’t find enough time to go back and do the original one (sorry!).

I was also granted $300 in credits for the TensorFlow 2.0 Question Answering competition and used them to develop a PyTorch baseline. They also covered the costs of the Compute Engine VM and Cloud Storage used to train models on TPU.


Google was handing out free TPU access to competitors in the TensorFlow 2.0 Question Answering competition, as an incentive for them to try out the newly added TPU support in TensorFlow 2.1 (then a release candidate). Because the preemptible GPUs on GCP were barely usable at the time, I decided to give it a shot. It all began with this tweet:

It turns out that the TensorFlow models in the huggingface/transformers library work on TPU without modification! I then proceeded to develop models using TensorFlow (TF) 2.1 for a simpler competition, Google QUEST Q&A Labeling.

I missed the post-processing trick in the QUEST competition because I spent most of my limited time wrestling with TF and TPU. After applying the post-processing trick, my final model would be somewhat competitive at around 65th place (silver medal) on the final leaderboard. The total training time of my 5-fold models using TPUv2 on Colab was about an hour. This is a satisfactory result in my opinion, given the time constraint.

TensorFlow 2.x

TensorFlow 2.x has become much more approachable, and the customizable training loops open up a swath of opportunities to do creative things. I think I’ll be able to re-implement the top solutions of the competition in TensorFlow without banging my head against the wall (or at least less frequently).

On the other hand, TF 2.x is still not as intuitive as PyTorch. Documentation and community support still leave much to be desired, and many search results still point to TF 1.x solutions that do not apply to TF 2.x.

As an example, I ran into this problem in which cuDNN failed to initialize:

One of the solutions is to limit the GPU memory usage, and here’s a confusingly long thread on how to do so:
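For reference, one commonly suggested approach in TF 2.x is to enable memory growth, so that TensorFlow allocates GPU memory on demand instead of reserving nearly all of it at startup. The sketch below is a minimal version of that idea; enable_gpu_memory_growth() is a hypothetical helper name, it must run before any GPU op executes, and whether it fixes a particular cuDNN error is not guaranteed. (A hard cap via tf.config.experimental.set_virtual_device_configuration is the other common option.)

```python
import tensorflow as tf


def enable_gpu_memory_growth() -> int:
    """Let TensorFlow grow GPU memory usage on demand; returns the GPU count."""
    gpus = tf.config.experimental.list_physical_devices("GPU")
    for gpu in gpus:
        # Raises RuntimeError if the GPU has already been initialized,
        # so call this at program start, before building any model.
        tf.config.experimental.set_memory_growth(gpu, True)
    return len(gpus)


num_gpus = enable_gpu_memory_growth()
print("GPUs with memory growth enabled:", num_gpus)
```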

TPU Support

Despite all the drawbacks, the TPU support in TF 2.1 is fantastic and has become my main reason for using TF 2.1+. It’s pretty impressive that the once extremely unstable TPU support in Keras has evolved into such a solid piece of engineering.

Although TensorFlow Research Cloud gave me access to multiple TPU units, I used only one of them, as I didn’t see the need for serious hyper-parameter optimization yet. The competition dataset is not ideal for TPU, as it is quite small (a few thousand examples). I had to limit the batch size to achieve the best performance (in terms of the evaluation metric), but training was still a lot faster than on my single local GTX 1070 GPU (a 4 to 8x speedup). TPUv2 is more than sufficient in this case (compared to TPUv3).

One potentially interesting comparison would be using two V100 GPUs, which combined are a little more expensive than TPUv2, with a bigger batch size to train the same model.

The TPU on Google Colab also supports TF 2.1 now. You can train models much faster with it than with any of the free GPUs Colab provides (currently the best offer is a single Tesla P100). Check this notebook for a concrete example:

(I know that PyTorch has its own TPU support now, but it was still quite hard to use the last time I checked, and it is not supported in Google Colab. Maybe I’ll take another look in the next few weeks.)

Case Study and Code Snippets

This section briefly describes my solution to the QUEST Q&A Labeling competition and discusses the parts of the code that I think are most helpful for those coming from PyTorch, as I did. It assumes that you already have a basic understanding of TensorFlow 2.x. If you’re not sure, please refer to the official tutorial Effective TensorFlow 2.

Source Code


  1. TF-Helper-Bot: this is a simple high-level wrapper of TensorFlow I wrote to improve code reusability.
  2. Input Formulation and TFRecords Preparation.
  3. TPU-compatible Data Loading.
  4. The Siamese Encoder Network.


TF-Helper-Bot is a simple high-level wrapper of TensorFlow and is basically a port of my other project, PyTorch-Helper-Bot (which is heavily inspired by the fast.ai library). It handles custom training loops, distributed training (TPU), metric evaluation, checkpoints, and some other useful stuff for you.


The central component of TF-Helper-Bot is the BaseBot class. It would normally be inherited and adapted for each new project. Think of it as a robot butler. Give her/him a model, an optimizer, a loss function, and other optional goodies via __init__(), and call train(). The robot will have your model trained and ready. You can also call eval() to do validation or testing, and call predict() to make predictions.

One important way to improve TensorFlow 2.x code performance is to use tf.function to compile a Python function into a TensorFlow graph. This is how TF-Helper-Bot does it (source code location):

(BaseBot uses dataclass to manage internal instance states. Any additional initialization code should go into the __post_init__() method.)

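from dataclasses import dataclass
from typing import Callable

import tensorflow as tf


@dataclass
class BaseBot:
    train_dataset: tf.data.Dataset
    # omitted...
    criterion: Callable
    model: tf.keras.Model
    optimizer: tf.keras.optimizers.Optimizer
    name: str = "basebot"
    # omitted...

    def __post_init__(self):
        # omitted...

        @tf.function
        def get_gradient(input_tensors, target):
            with tf.GradientTape() as tape:
                output = self.model(
                    input_tensors, training=True)
                loss_raw = self.criterion(
                    target, self._extract_prediction(output)
                )
                # Mixed precision (optimizer assumed to be a LossScaleOptimizer):
                # scale the loss up to avoid float16 gradient underflow
                loss_ = (
                    self.optimizer.get_scaled_loss(loss_raw)
                    if self.mixed_precision else loss_raw
                )
            gradients_ = tape.gradient(
                loss_, self.model.trainable_variables)
            if self.mixed_precision:
                # Undo the loss scaling before the optimizer step
                gradients_ = self.optimizer.get_unscaled_gradients(gradients_)
            return loss_raw, gradients_

        @tf.function
        def step_optimizer(gradients):
            self.optimizer.apply_gradients(
                zip(gradients, self.model.trainable_variables))

        @tf.function
        def predict_batch(input_tensors):
            return self.model(input_tensors, training=False)

        self._get_gradient = get_gradient
        self._step_optimizer = step_optimizer
        self._predict_batch = predict_batch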
  • The reason why it uses seemingly contrived nested functions is that decorating the class methods with tf.function doesn’t work for me. Please let me know if you have any working examples of that.
  • Even decorating the bland predict_batch can improve the performance significantly. Without this, I couldn’t get the inference kernel to finish within the time limit in another CV competition.
  • The get_gradient method supports mixed-precision training, which isn’t covered in this post.
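To see why compiled functions care about input shapes, here is a minimal sketch of the same closure pattern. The TinyBot class is hypothetical and not part of TF-Helper-Bot. Python statements inside a tf.function run only while the function is being traced, so a counter shows when TensorFlow builds a new graph: a call with the same shape and dtype reuses the graph, while a new batch size triggers retracing.

```python
import tensorflow as tf


class TinyBot:
    """Hypothetical minimal version of BaseBot's closure pattern."""

    def __init__(self, model: tf.keras.Model):
        self.model = model
        self.trace_count = 0

        @tf.function
        def predict_batch(input_tensors):
            # Python side effect: runs only while TensorFlow traces a new graph
            self.trace_count += 1
            return self.model(input_tensors, training=False)

        self._predict_batch = predict_batch

    def predict(self, input_tensors):
        return self._predict_batch(input_tensors)


model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
model(tf.zeros((1, 4)))  # build the weights eagerly, outside tf.function

bot = TinyBot(model)
bot.predict(tf.zeros((3, 4)))  # first call: traced
bot.predict(tf.ones((3, 4)))   # same shape and dtype: graph reused
bot.predict(tf.zeros((5, 4)))  # new batch size: traced again
print(bot.trace_count)  # 2
```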


For distributed training (a TPU has 8 cores), we need some specialized interfaces. We accommodate these by subclassing BaseBot (source code location):

from dataclasses import dataclass

import tensorflow as tf


@dataclass
class BaseDistributedBot(BaseBot):
    strategy: tf.distribute.Strategy = None

    def __post_init__(self):
        assert self.strategy is not None
        assert self.gradient_accumulation_steps == 1, (
            "Distribution mode doesn't support gradient accumulation"
        )
        super().__post_init__()  # defines _get_gradient and _step_optimizer

        @tf.function
        def train_one_step(input_tensor_list, target):
            loss, gradients = self._get_gradient(
                input_tensor_list[0], target)
            self._step_optimizer(gradients)
            return loss

        self._train_one_step = train_one_step

    def train_one_step(self, input_tensors, target):
        # Run the compiled step on every replica, then average the losses
        loss = self.strategy.experimental_run_v2(
            self._train_one_step,
            args=(input_tensors, target)
        )
        return self.strategy.reduce(
            tf.distribute.ReduceOp.MEAN, loss, axis=None
        )

The TPU requires the entire training step to be compiled into a graph (i.e., the function you pass to experimental_run_v2 must be a tf.function), and I couldn’t find a way to do gradient accumulation in such a situation. As a result, gradient accumulation was removed and the train_one_step() method has been simplified. One additional note: because _get_gradient and _step_optimizer are called inside the compiled train_one_step, they are automatically compiled as part of it, so they don’t really need their own tf.function decorators.

You’ll need to get a strategy object from tf_helper_bot.utils.prepare_tpu (source code location):

(If you’re using Cloud TPU instead of TPU from Colab, you’ll also need to set the TPU_NAME environment variable. Check the documentation for more details. Also, remember to allow access to Cloud APIs on your VM.)

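import tensorflow as tf


def prepare_tpu():
    try:
        tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection
        print('Running on TPU ', tpu.cluster_spec().as_dict()['worker'])
    except ValueError:
        tpu = None
    strategy = tf.distribute.get_strategy()
    if tpu:
        # Connect to the remote TPU and initialize it before creating the strategy
        tf.config.experimental_connect_to_cluster(tpu)
        tf.tpu.experimental.initialize_tpu_system(tpu)
        strategy = tf.distribute.experimental.TPUStrategy(tpu)
    return strategy, tpu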

If a TPU is found, it is initialized and a TPUStrategy is returned. Otherwise, the function returns the default strategy, which works with a single GPU or CPU. (I haven’t tested it with multiple GPUs yet.)
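A sketch of typical usage of the returned strategy, under a few assumptions: variables (model weights and optimizer state) are created inside strategy.scope(), and with a custom loop the global batch size is usually the per-replica batch size times strategy.num_replicas_in_sync. tf.distribute.get_strategy() stands in for prepare_tpu() so the snippet also runs on a CPU; the per-replica batch size, learning rate, and tiny Dense model are placeholders, not values from this solution.

```python
import tensorflow as tf

# Stand-in for `strategy, tpu = prepare_tpu()`; the default strategy has 1 replica
strategy = tf.distribute.get_strategy()

PER_REPLICA_BATCH_SIZE = 4  # hypothetical value
global_batch_size = PER_REPLICA_BATCH_SIZE * strategy.num_replicas_in_sync

with strategy.scope():
    # Anything that creates variables goes inside the scope so the strategy
    # can place (and, on TPU, replicate) them
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(768,)),
        tf.keras.layers.Dense(30),  # e.g., 30 target labels
    ])
    optimizer = tf.keras.optimizers.Adam(1e-5)

print(strategy.num_replicas_in_sync, global_batch_size)
```

On a TPUv2-8, num_replicas_in_sync would be 8, so each core would see one eighth of every global batch.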

Preparing TFRecord Files

The QUEST training dataset contains just over 6,000 question-answer pairs. There are only around 3,500 unique questions. Each question has a title and a body field.

I split each pair into two sequences: question and answer. This is how I formulate the question sequence:

<s> question </s><s> this is the title of a question </s><s> this is the body of a question </s>

And the answer sequence:

<s> answer </s><s> this is the body of an answer </s>

(I added some spaces to make the sequences more readable.)

<s> and </s> are the tokens RoBERTa uses to mark the start and end of a sentence. The <s> question </s> and <s> answer </s> headers help the encoder distinguish between the two types of sequences.
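A minimal sketch of assembling such a sequence at the token-id level, assuming the standard RoBERTa special-token ids (<s> = 0, <pad> = 1, </s> = 2) and that each segment (header, title, body) has already been tokenized. build_sequence() is a hypothetical helper, and its naive truncation is an assumption; real preprocessing would likely trim each segment more carefully. It also pads to a fixed length and produces the attention mask, i.e., the fixed-length ids and masks that the TFRecord files store.

```python
from typing import List, Tuple

# Assumed RoBERTa special-token ids; check them against your tokenizer
BOS_ID, PAD_ID, EOS_ID = 0, 1, 2


def wrap(ids: List[int]) -> List[int]:
    """Surround one segment's token ids with <s> ... </s>."""
    return [BOS_ID] + ids + [EOS_ID]


def build_sequence(segments: List[List[int]], max_len: int) -> Tuple[List[int], List[int]]:
    """Concatenate wrapped segments, truncate, and pad to max_len.

    Returns (input_ids, input_mask); the mask is 1 for real tokens, 0 for padding.
    """
    ids: List[int] = []
    for segment in segments:
        ids.extend(wrap(segment))
    ids = ids[:max_len]  # naive truncation: may cut off the final </s>
    mask = [1] * len(ids) + [0] * (max_len - len(ids))
    ids = ids + [PAD_ID] * (max_len - len(ids))
    return ids, mask


# Hypothetical token ids for the "question" header, a title, and a body
ids, mask = build_sequence([[100], [11, 12, 13], [21, 22]], max_len=16)
print(ids)   # [0, 100, 2, 0, 11, 12, 13, 2, 0, 21, 22, 2, 1, 1, 1, 1]
print(mask)  # [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0]
```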

The input data pipeline used by a TPU must consist of TensorFlow ops only (no arbitrary Python code), and the TPU cannot read from your local filesystem (it is connected to your VM over the network). The simplest way to create an input pipeline for TPU is to save your data into TFRecord files and store them in a Cloud Storage bucket. The TPU then reads directly from your bucket and runs the compiled pipeline.

(Theoretically, if your dataset fits into memory, you can create it locally in memory and send it over to TPU. But I haven’t found any good examples of this approach yet. Anyway, I think that always dumping your dataset into TFRecord files for more consistency is generally a good practice.)

This is how I convert the tokenized sequences and labels into TFRecord files (source code location):

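import tensorflow as tf


def to_example(input_dict, labels):
    # Assumes input_dict holds padded int lists keyed like the features below
    feature = {
        "input_ids_question": tf.train.Feature(
            int64_list=tf.train.Int64List(value=input_dict["input_ids_question"])),
        "input_mask_question": tf.train.Feature(
            int64_list=tf.train.Int64List(value=input_dict["input_mask_question"])),
        "input_ids_answer": tf.train.Feature(
            int64_list=tf.train.Int64List(value=input_dict["input_ids_answer"])),
        "input_mask_answer": tf.train.Feature(
            int64_list=tf.train.Int64List(value=input_dict["input_mask_answer"])),
        "labels": tf.train.Feature(
            float_list=tf.train.FloatList(value=labels)),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))


def _write_tfrecords(inputs, labels, output_filepath):
    with tf.io.TFRecordWriter(str(output_filepath)) as writer:
        for input_dict, labels_single in zip(inputs, labels):
            example = to_example(input_dict, labels_single)
            writer.write(example.SerializeToString())
    print("Wrote file {} containing {} records".format(
        output_filepath, len(inputs)))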

Data Loading

TensorFlow has a tf.data module specifically for creating input pipelines. It is especially useful when writing pipelines that will be compiled into graphs (which is required by TPU). The following is how I use it to load the data from TFRecord files, with some details omitted (source code location):

import tensorflow as tf

AUTOTUNE = tf.data.experimental.AUTOTUNE


def tfrecord_dataset(filename, batch_size, strategy, is_train: bool = True):
    opt = tf.data.Options()
    opt.experimental_deterministic = False

    # omitted...

    features_description = {
        "input_ids_question": tf.io.FixedLenFeature([max_q_len], tf.int64),
        "input_mask_question": tf.io.FixedLenFeature([max_q_len], tf.int64),
        "input_ids_answer": tf.io.FixedLenFeature([max_a_len], tf.int64),
        "input_mask_answer": tf.io.FixedLenFeature([max_a_len], tf.int64),
        "labels": tf.io.FixedLenFeature([30], tf.float32),
    }

    def _parse_function(example_proto):
        # Parse the input `tf.Example` proto using the dictionary above.
        example = tf.io.parse_single_example(
            example_proto, features_description)
        return (
            {
                'input_ids_question': tf.cast(example['input_ids_question'], tf.int32),
                'attention_mask_question': tf.cast(example['input_mask_question'], tf.int32),
                'input_ids_answer': tf.cast(example['input_ids_answer'], tf.int32),
                'attention_mask_answer': tf.cast(example['input_mask_answer'], tf.int32),
            },
            example['labels']
        )

    raw_dataset = tf.data.TFRecordDataset(
        filename, num_parallel_reads=4
    ).with_options(opt)  # allow non-deterministic order for throughput
    dataset = raw_dataset.map(
        _parse_function, num_parallel_calls=AUTOTUNE
    ).cache()  # keep parsed examples in memory (before shuffling)
    if is_train:
        dataset = dataset.shuffle(
            2048, reshuffle_each_iteration=True
        )
        # omitted...
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(AUTOTUNE)
    return dataset
  • We’ve padded the token sequences to a fixed length when creating TFRecords, so tf.io.FixedLenFeature is used to parse the feature. If we want to use sequences of variable lengths, we should use tf.io.FixedLenSequenceFeature. (More details later.)
  • tf.train.Example can only store 64-bit integers (Int64List), while TPUs work with 32-bit integers, so we cast the parsed features to tf.int32.
  • The cache() method is called to keep the parsed data in memory. This prevents the loader from reading and parsing the same files multiple times. (You don’t want to call cache() after shuffle(), as the cached order would be replayed every epoch, so the dataset would effectively be shuffled only once.)
  • This is one of the less tuned parts of the codebase; I mostly just followed the documentation and its recommendations. You could probably get better throughput by tinkering with the pipeline.
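To make these points concrete, here is a small self-contained round trip under assumed toy settings (a single input_ids feature of length 4 and two labels, rather than the five real features): it writes two tf.train.Example records to a temporary file, parses them back with FixedLenFeature, casts int64 to int32, and caches before shuffling. The names MAX_LEN, make_example, and parse are hypothetical.

```python
import os
import tempfile

import tensorflow as tf

MAX_LEN = 4  # hypothetical fixed (already padded) sequence length


def make_example(ids, labels):
    return tf.train.Example(features=tf.train.Features(feature={
        "input_ids": tf.train.Feature(int64_list=tf.train.Int64List(value=ids)),
        "labels": tf.train.Feature(float_list=tf.train.FloatList(value=labels)),
    }))


FEATURES = {
    "input_ids": tf.io.FixedLenFeature([MAX_LEN], tf.int64),
    "labels": tf.io.FixedLenFeature([2], tf.float32),
}


def parse(proto):
    example = tf.io.parse_single_example(proto, FEATURES)
    # Stored as int64 on disk; cast to int32 for the TPU
    return tf.cast(example["input_ids"], tf.int32), example["labels"]


path = os.path.join(tempfile.mkdtemp(), "toy.tfrecord")
with tf.io.TFRecordWriter(path) as writer:
    writer.write(make_example([0, 5, 6, 2], [0.5, 1.0]).SerializeToString())
    writer.write(make_example([0, 7, 2, 1], [0.0, 0.25]).SerializeToString())

dataset = (
    tf.data.TFRecordDataset(path)
    .map(parse)
    .cache()  # parse once, then serve from memory
    .shuffle(16, reshuffle_each_iteration=True)  # reshuffled every epoch
    .batch(2)
)
ids, labels = next(iter(dataset))
print(ids.dtype, ids.shape, labels.shape)
```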

Dynamic Batching?

One neat trick to increase the training speed in PyTorch is to group sequences with similar lengths, pick a batch of them from this group, and then pad to the maximum length of that batch. It is sometimes called a “sortish sampler”. This often greatly reduces the padding required and the average sequence length after padding.
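A minimal pure-Python sketch of the idea, written in the spirit of fast.ai's SortishSampler rather than as its actual implementation: shuffle the indices, sort within large chunks by length, cut each chunk into batches, and shuffle the batch order. The helper names and the chunk_mult parameter are hypothetical.

```python
import random
from typing import List, Sequence


def sortish_batches(lengths: Sequence[int], batch_size: int,
                    chunk_mult: int = 50, seed: int = 0) -> List[List[int]]:
    """Group example indices into batches of roughly similar lengths."""
    rng = random.Random(seed)
    indices = list(range(len(lengths)))
    rng.shuffle(indices)
    chunk_size = batch_size * chunk_mult
    batches: List[List[int]] = []
    for start in range(0, len(indices), chunk_size):
        # Sort only within a large random chunk: similar lengths, some randomness kept
        chunk = sorted(indices[start:start + chunk_size], key=lambda i: lengths[i])
        batches.extend(chunk[j:j + batch_size] for j in range(0, len(chunk), batch_size))
    rng.shuffle(batches)  # avoid always feeding the shortest batches first
    return batches


def pad_batch(seqs: List[List[int]], pad_id: int) -> List[List[int]]:
    """Pad sequences only up to the longest one in this batch."""
    max_len = max(len(s) for s in seqs)
    return [s + [pad_id] * (max_len - len(s)) for s in seqs]


lengths = [5, 50, 6, 48]
batches = sortish_batches(lengths, batch_size=2)
# Short examples (0, 2) end up together, as do long ones (1, 3)
print(sorted(sorted(b) for b in batches))  # [[0, 2], [1, 3]]
```

Sorting only within chunks (instead of the whole dataset) keeps some randomness in which examples share a batch, while still cutting most of the padding.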

However, this trick doesn’t seem to work well when training with TPU. Fixed-size input tensors run fastest on TPU in my experience. Nonetheless, I’ll describe how to do it below in case you are interested.

  1. Remove the padding in the TFRecord preparation script/function. (duh)
  2. Use tf.io.FixedLenSequenceFeature in features_description (code below).
  3. Replace batch() with padded_batch(batch_size, padded_shapes=None).
  4. Change the tf.function decorators to tf.function(experimental_relax_shapes=True) to reduce retracing (i.e., recompiling) when the input tensors change shape.
features_description = {
    "input_ids_question": tf.io.FixedLenSequenceFeature([], tf.int64, allow_missing=True),
    "input_mask_question": tf.io.FixedLenSequenceFeature([], tf.int64, allow_missing=True),
    "input_ids_answer": tf.io.FixedLenSequenceFeature([], tf.int64, allow_missing=True),
    "input_mask_answer": tf.io.FixedLenSequenceFeature([], tf.int64, allow_missing=True),
    "labels": tf.io.FixedLenFeature([30], tf.float32),
}
Distributed Dataset

Distributed training with a custom training loop requires converting a tf.data.Dataset instance into a distributed dataset:

# train_ds, valid_ds: tf.data.Dataset objects (e.g., from tfrecord_dataset())
train_dist_ds = strategy.experimental_distribute_dataset(train_ds)
valid_dist_ds = strategy.experimental_distribute_dataset(valid_ds)

If you’re using the Keras fit() API, you won’t need to do this conversion.
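Putting the pieces together, here is a minimal end-to-end sketch of a custom distributed training step, assuming a toy regression model and random data in place of RoBERTa and the QUEST TFRecords. It uses tf.distribute.get_strategy() so it runs on a CPU; with the TPUStrategy from prepare_tpu(), the same structure should shard each global batch across the cores. The method is experimental_run_v2 in TF 2.1 and was renamed run in TF 2.2, so the sketch picks whichever exists. Dividing the summed per-example loss by the global batch size (rather than averaging per replica) is a common convention so that gradients summed across replicas match single-device training.

```python
import tensorflow as tf

# Stand-in for `strategy, tpu = prepare_tpu()`; the default strategy runs on CPU/GPU
strategy = tf.distribute.get_strategy()
GLOBAL_BATCH_SIZE = 4  # hypothetical

# Toy regression data standing in for the parsed TFRecord dataset
features = tf.random.normal((16, 8))
targets = tf.random.normal((16, 1))
train_ds = tf.data.Dataset.from_tensor_slices((features, targets)).batch(GLOBAL_BATCH_SIZE)
train_dist_ds = strategy.experimental_distribute_dataset(train_ds)

with strategy.scope():
    model = tf.keras.Sequential([tf.keras.Input(shape=(8,)), tf.keras.layers.Dense(1)])
    optimizer = tf.keras.optimizers.SGD(0.01)

# TF 2.1 calls this experimental_run_v2; TF 2.2+ renamed it to run
run = getattr(strategy, "run", None) or strategy.experimental_run_v2


def step_fn(x, y):
    """Runs on each replica with that replica's slice of the global batch."""
    with tf.GradientTape() as tape:
        per_example_loss = tf.reduce_mean(tf.square(model(x, training=True) - y), axis=-1)
        # Scale by the global batch size so summing over replicas gives the mean
        loss = tf.reduce_sum(per_example_loss) / GLOBAL_BATCH_SIZE
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss


@tf.function
def train_step(x, y):
    per_replica_loss = run(step_fn, args=(x, y))
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_loss, axis=None)


losses = [float(train_step(x, y)) for x, y in train_dist_ds]
print(len(losses))  # 4
```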

The Siamese Encoder Network

Finally, let’s take a look at the neural network model that puts tags on the input question-answer pair. The tokenized question and answer are fed to the same RoBERTa encoder (i.e., a Siamese network). The hidden states of the encoder’s last layer are put through an average pooling layer. Here’s the code of the encoder (source code location):

(Reminder: we are using the RoBERTa model from huggingface/transformers here.)

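import tensorflow as tf
# transformers 2.x module path; newer versions moved these classes to
# transformers.models.roberta.modeling_tf_roberta
from transformers.modeling_tf_roberta import TFRobertaMainLayer, TFRobertaPreTrainedModel


class RobertaEncoder(TFRobertaPreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels
        self.roberta = TFRobertaMainLayer(config, name="roberta")
        self.pooling = AveragePooling()  # custom masked mean-pooling layer (not shown)

    def call(self, inputs, **kwargs):
        if "attention_mask" not in inputs:
            # No mask given: treat every token as real
            inputs["attention_mask"] = tf.ones(
                tf.shape(inputs["input_ids"])[:2], tf.int32
            )
        outputs = self.roberta(inputs, **kwargs)[0]  # last-layer hidden states
        return self.pooling(outputs, inputs["attention_mask"])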
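The AveragePooling layer itself is not shown in the post. A minimal sketch of what a masked average pooling layer could look like, assuming it averages the last hidden states over non-padding tokens only (the class name MaskedAveragePooling is hypothetical):

```python
import tensorflow as tf


class MaskedAveragePooling(tf.keras.layers.Layer):
    """Mean of hidden states over real (non-padding) tokens only."""

    def call(self, hidden_states, attention_mask):
        # hidden_states: (batch, seq_len, hidden); attention_mask: (batch, seq_len) of 0/1
        mask = tf.cast(attention_mask, hidden_states.dtype)[:, :, tf.newaxis]
        summed = tf.reduce_sum(hidden_states * mask, axis=1)
        # Guard against division by zero for an all-padding row
        counts = tf.maximum(tf.reduce_sum(mask, axis=1), 1.0)
        return summed / counts


hidden = tf.constant([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]])
mask = tf.constant([[1, 1, 0]])  # the third token is padding
print(MaskedAveragePooling()(hidden, mask).numpy())  # [[2. 3.]]
```

Masking matters here because the sequences are padded to a fixed length; a plain mean would let the padding positions dilute the representation.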

There are 30 types of target labels, and I split them into three categories:

  1. Those that are only related to the question.
  2. Those that are only related to the answer.
  3. Those that require information from both the question and the answer.

We create one classification head for each category. Each head only uses the relevant encoder output (e.g., the question head only uses the pooled question representation) to reduce over-fitting. I also found that applying context gating to the intermediate states slightly improves the accuracy.
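The post does not show the context gating code. One common formulation, used here as an assumption, multiplies the features by a learned sigmoid gate, y = x * sigmoid(Wx + b), so each feature is scaled by a value between 0 and 1. ContextGating is a hypothetical name:

```python
import tensorflow as tf


class ContextGating(tf.keras.layers.Layer):
    """y = x * sigmoid(W x + b): a learned per-feature gate in (0, 1)."""

    def __init__(self, units: int, **kwargs):
        super().__init__(**kwargs)
        # units must equal the feature size of x for the element-wise product
        self.gate = tf.keras.layers.Dense(units, activation="sigmoid")

    def call(self, x):
        return x * self.gate(x)


gating = ContextGating(8)
x = tf.random.normal((3, 8))  # e.g., pooled features for 3 examples
y = gating(x)
print(y.shape)  # (3, 8)
```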

The top-level model code, with some details omitted (source code location):

import tensorflow as tf


class DualRobertaModel(tf.keras.Model):
    def __init__(self, config, model_name, pretrained: bool = True):
        # omitted...
        if pretrained:
            self.roberta = RobertaEncoder.from_pretrained(
                model_name, config=config, name="roberta_question")
        else:
            self.roberta = RobertaEncoder(
                config=config, name="roberta_question")
        self.dropout = tf.keras.layers.Dropout(0.5)
        self.q_classifier = tf.keras.layers.Dense(
            # omitted...
        )
        self.a_classifier = tf.keras.layers.Dense(
            # omitted...
        )
        self.j_classifier = tf.keras.layers.Dense(
            # omitted...
        )
        # omitted...

    # omitted...

    def call(self, inputs, **kwargs):
        # The same encoder (shared weights) processes both sequences
        pooled_output_question = self.roberta(
            {
                "input_ids": inputs["input_ids_question"],
                "attention_mask": inputs["attention_mask_question"]
            }, **kwargs
        )
        pooled_output_answer = self.roberta(
            {
                "input_ids": inputs["input_ids_answer"],
                "attention_mask": inputs["attention_mask_answer"]
            }, **kwargs
        )
        # Joint features for the labels that need both question and answer
        combined = tf.concat(
            [
                pooled_output_question, pooled_output_answer,
                pooled_output_answer * pooled_output_question
            ],
            axis=1
        )
        q_logit = self.q_classifier(self.dropout(
            pooled_output_question,  # question-only features (gating omitted)
            training=kwargs.get("training", False)
        ))
        # omitted...
        logits = tf.concat(
            [q_logit, a_logit, j_logit],
            axis=1
        )
        # omitted...

PyTorch developers should find the above quite similar to a PyTorch model.

(This is a form of multitask learning. I’ve seen some of the top solutions choose to train two or three separate models instead. The latter approach also helps alleviate the problem of some questions appearing multiple times in the dataset.)

Wrapping Up

In this post, we briefly discussed how the learning curve of TensorFlow has been significantly improved in the 2.x release, and how the TPU has become more accessible than ever. We also presented a case study of solving a Q&A labeling problem by fine-tuning the RoBERTa-base model from the huggingface/transformers library, along with some code snippets that could be useful to those who are more familiar with PyTorch.

TensorFlow 2.1 and TPU are also a very good fit for CV applications. I have another CV project in the pipeline and it has been a nice development experience so far. I’ll probably publish another case study of solving an image classification problem in about a month.

Thanks for reading! I’d love to hear from you. If you find any details in the presented codebase confusing, please let me know in the comment section. I’ll add a section to this post or write a bonus post.

