Reducing the SentencePiece Vocabulary Size of Pretrained NLP Models

Useful for fine-tuning on a subset of available languages

Jan 18, 2021 · 766 words · 4 minute read nlp

Motivation

Q: Why and when would we want to trim down the vocabulary size of a pretrained model?

A: When a large portion of the vocabulary isn’t used in your downstream task, it makes sense to get rid of the redundant part to make the model faster and smaller.

For example, Google’s multilingual version of T5, mT5, was pretrained on 101 languages. Imagine we only use English, Japanese, and Chinese in our downstream text generation task. We would waste a lot of time and space processing the rows of the embedding matrix and the LM head that correspond to tokens that never appear in the dataset.

In this post, I’ll demonstrate how to reduce the vocabulary size of a trained SentencePiece model. SentencePiece is used in XLNet, ALBERT, Marian, and T5. Other types of tokenizers are not covered.

Specifically, we’ll shrink the vocabulary of the mt5-small pretrained model. All tokens that are not used in the Chinese part of the XNLI dataset will be removed. As a result, the vocabulary size will go down from 250K to below 31K — an 87.6% reduction.

References

The solution presented in this post comes from these two notebooks:

I’ve created an example showcasing the mechanism behind the code and verified the result. The complete notebook can be accessed here on GitHub.

Code Walkthrough

Download the Pretrained Model

We use huggingface/transformers to download the pretrained tokenizer (a SentencePiece model):

from pathlib import Path
import shutil

from transformers import MT5Tokenizer

Path("cache/").mkdir(exist_ok=True)
if Path("cache/mt5-small").exists():
    shutil.rmtree("cache/mt5-small")

tokenizer = MT5Tokenizer.from_pretrained("google/mt5-small")
tokenizer.save_pretrained("cache/mt5-small")

Download the XNLI dataset

Again, we use the datasets library from huggingface to download the Chinese part of the XNLI dataset:

from datasets import load_dataset

dataset = load_dataset("xnli", "zh")

Collect the Tokens

Since we want to keep all the tokens that appear in this downstream dataset, we need to get a list of them:

from itertools import chain
from tqdm import tqdm

def tokenize_data(data, batch_size=1024):
    # Encode the texts in batches and add every token id we encounter
    # to the global `seen` set
    global seen
    for i in tqdm(range(0, len(data), batch_size)):
        encoded = tokenizer.batch_encode_plus(
            data[i:(i + batch_size)], return_attention_mask=False
        )
        seen = seen.union(set(chain.from_iterable(encoded["input_ids"])))

seen = set()
for subset in ("train", "test", "validation"):
    print(subset)
    tokenize_data(dataset[subset]["hypothesis"])
    tokenize_data(dataset[subset]["premise"])

# You can also add some additional (meta) tokens:
seen = seen.union(set(tokenizer.encode("mnli premise: hypothesis: <unk>")))

Load the SentencePiece Model

Here we load the pretrained SentencePiece model into memory (as a Protocol Buffers object):

from sentencepiece import sentencepiece_model_pb2 as model

m = model.ModelProto()
m.ParseFromString(open("cache/mt5-small/spiece.model", 'rb').read())

# There are some reserved places for special tokens
for i, piece in enumerate(m.pieces[:320]):
    if i % 20 == 0:
        print(i, piece.piece)

We can see that the first 259 tokens are reserved for functional tokens. It might be a good idea to keep them.
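If you want to check that cutoff programmatically, you can look at each piece’s `type` field. This is a small sketch that is not part of the original notebooks; it relies on the SentencePiece proto definition, in which type value 1 marks a normal subword and other values mark functional pieces (control, unknown, and byte-fallback tokens):

NORMAL_TYPE = 1  # SentencePiece.Type.NORMAL in sentencepiece_model.proto
num_functional = 0
for piece in m.pieces:
    if piece.type == NORMAL_TYPE:
        break
    num_functional += 1
# Should print 259, matching the observation above
print(num_functional)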

Shrink the SentencePiece Model

Because m.pieces is a Protocol Buffers field, we cannot simply point it to a new list. Instead, we need to use the field’s methods to manipulate its content:

# Pop every piece off the repeated field, keeping only the reserved
# pieces (ids below 259) and the pieces that appear in the dataset
kept_pieces, i = [], len(m.pieces) - 1
while len(m.pieces):
    piece = m.pieces.pop()
    if i < 259 or i in seen:
        kept_pieces.append(piece)
    i -= 1
# The pieces were collected back to front; restore the original order
kept_pieces = list(reversed(kept_pieces))

# Backup the old model
Path("cache/mt5-small/spiece.model").rename("cache/mt5-small/spiece.model.old")
# Put the retained pieces back and write the new model to disk
m.pieces.extend(kept_pieces)
with open("cache/mt5-small/spiece.model", 'wb') as f:
    f.write(m.SerializeToString())

We’ll also need to keep track of the retained token ids, so we know which rows to keep in the embedding matrix and the LM head.

import json

kept_ids = sorted(list(seen.union(set(range(259)))))
print(len(kept_ids))
with open("cache/mt5-small/kept_ids.json", 'w') as f:
    json.dump(kept_ids, f)
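
Note that the shrunk SentencePiece model assigns id i to the piece that had id kept_ids[i] in the original vocabulary, because kept_ids is sorted and the loop above preserved the piece order. If you ever need to translate between the two vocabularies, the mapping is straightforward (a small illustration, not from the original notebooks):

# Map an old token id to its id in the shrunk vocabulary
old_to_new = {old_id: new_id for new_id, old_id in enumerate(kept_ids)}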

Verification

We can verify that our new SentencePiece model tokenizes our dataset correctly (by encoding sentences and then decoding them back):

import random
tokenizer = MT5Tokenizer.from_pretrained("cache/mt5-small")

for i in random.sample(range(100), k=10):
    # the space placements are slightly different from the original
    converted = tokenizer.decode(
        tokenizer.encode(dataset["train"]["hypothesis"][i]), skip_special_tokens=True
    ).replace(" ", "")
    assert converted == dataset["train"]["hypothesis"][i].replace(" ", "")

Moving Further

Now we know how to create a new SentencePiece model with a smaller vocabulary. The next step would be to use it to fine-tune an NLP model. We’ll need to modify the embedding matrix and the LM head (if the LM head’s weights and the embedding matrix are not tied together). A full treatment is beyond the scope of this post, since the NLP model can be implemented in any deep learning framework (PyTorch, TensorFlow, MXNet, etc.).
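
Still, as a rough illustration, here is a minimal PyTorch sketch of that step. It is not part of the original notebooks; the attribute names (model.shared, model.lm_head) and the untied LM head are specific to the transformers implementation of mT5, so adapt it to your own framework and model:

import json

import torch
from transformers import MT5ForConditionalGeneration

with open("cache/mt5-small/kept_ids.json") as f:
    kept_ids = json.load(f)
index = torch.LongTensor(kept_ids)

model = MT5ForConditionalGeneration.from_pretrained("google/mt5-small")

# Keep only the embedding rows of the retained tokens; row i of the new
# matrix corresponds to id i in the shrunk SentencePiece model
new_shared = torch.nn.Embedding(len(kept_ids), model.config.d_model)
new_shared.weight.data = model.shared.weight.data[index].clone()
model.set_input_embeddings(new_shared)

# mt5-small does not tie the LM head to the embedding matrix,
# so its rows have to be sliced separately
new_lm_head = torch.nn.Linear(model.config.d_model, len(kept_ids), bias=False)
new_lm_head.weight.data = model.lm_head.weight.data[index].clone()
model.lm_head = new_lm_head

model.config.vocab_size = len(kept_ids)
model.save_pretrained("cache/mt5-small")

After this, the shrunk model can be fine-tuned together with the shrunk tokenizer saved earlier.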

Multilingual pretrained models are not the only use case for this technique. It can also be useful when the upstream model was trained on a corpus covering several domains, while you are only interested in one of them in the downstream task. Just make sure you retain all the tokens needed in the downstream task; otherwise, performance might suffer.
