Q: Why and when would we want to trim down the vocabulary size of a pretrained model?
A: When a large portion of the vocabulary isn’t used in your downstream task, it makes sense to drop the unused part of the vocabulary to make the model smaller and faster.
For example, Google’s multilingual version of T5, mT5, was pretrained on 101 languages. Suppose we only use English, Japanese, and Chinese in our downstream text generation task. We would waste a lot of time and space processing the rows in the embedding matrix and the LM head that correspond to tokens that never appear in the dataset.
In this post, I’ll demonstrate how to reduce the vocabulary size of a trained SentencePiece model. SentencePiece is used in XLNet, ALBERT, Marian, and T5. Other types of tokenizers are not covered.
Specifically, we’ll shrink the vocabulary of the mt5-small pretrained model. All tokens that are not used in the Chinese part of the XNLI dataset will be removed. As a result, the vocabulary size will go down from 250K to below 31K, an 87.6% reduction.
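To get a feel for what this reduction buys, here is a back-of-the-envelope calculation. The vocabulary sizes are the rounded figures quoted above; the hidden size of 512 and the untied LM head are assumptions about mt5-small that should be checked against the model's config before relying on the numbers.

```python
# Rounded vocabulary sizes quoted in the text
old_vocab = 250_000
new_vocab = 31_000

# Assumptions (not stated in the post): hidden size and an untied LM head
hidden_size = 512
n_vocab_matrices = 2  # input embedding + LM head


def vocab_params(vocab_size):
    """Parameters held by the vocabulary-dependent matrices."""
    return vocab_size * hidden_size * n_vocab_matrices


reduction = 1 - new_vocab / old_vocab
saved = vocab_params(old_vocab) - vocab_params(new_vocab)

print(f"vocabulary reduction: {reduction:.1%}")
print(f"parameters removed:   {saved:,}")
```

Under these assumptions, the bulk of the savings comes from the two vocabulary-sized matrices rather than from the Transformer layers themselves.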
The solution presented in this post comes from these two notebooks:
I’ve created an example that showcases the mechanism behind the code and verifies the result. The complete notebook can be accessed here on GitHub.
Download the Pretrained Model
We use huggingface/transformers to download the pretrained tokenizer (a SentencePiece model):
    from pathlib import Path
    import shutil

    from transformers import MT5Tokenizer

    Path("cache/").mkdir(exist_ok=True)
    if Path("cache/mt5-small").exists():
        shutil.rmtree("cache/mt5-small")
    tokenizer = MT5Tokenizer.from_pretrained("google/mt5-small")
    tokenizer.save_pretrained("cache/mt5-small")
Download the XNLI dataset
Again, we use the datasets library from huggingface to download the Chinese part of the XNLI dataset:
    from datasets import load_dataset

    dataset = load_dataset("xnli", "zh")
Collect the Tokens
Since we want to keep all the tokens that appear in this downstream dataset, we need to get a list of them:
    from itertools import chain

    from tqdm import tqdm

    # IDs of every token that appears in the downstream dataset
    seen = set()

    def tokenize_data(data, batch_size=1024):
        """Tokenize `data` in batches and add the token IDs to `seen`."""
        for i in tqdm(range(0, len(data), batch_size)):
            batch_ids = tokenizer.batch_encode_plus(
                data[i:(i + batch_size)], return_attention_mask=False
            )["input_ids"]
            seen.update(chain.from_iterable(batch_ids))

    for subset in ("train", "test", "validation"):
        print(subset)
        tokenize_data(dataset[subset]["hypothesis"])
        tokenize_data(dataset[subset]["premise"])

    # You can also add some additional (meta) tokens:
    seen.update(tokenizer.encode("mnli premise: hypothesis: <unk>"))
Load the SentencePiece Model
Here we load the pretrained SentencePiece model into memory (as a Protocol Buffers object):
    from sentencepiece import sentencepiece_model_pb2 as model

    m = model.ModelProto()
    with open("cache/mt5-small/spiece.model", "rb") as f:
        m.ParseFromString(f.read())

    # There are some reserved places for special tokens
    for i, piece in enumerate(m.pieces[:320]):
        if i % 20 == 0:
            print(i, piece.piece)
We can see that the first 259 tokens are reserved for functional tokens. It might be a good idea to keep them.
Shrink the SentencePiece Model
m.pieces is a Protocol Buffers repeated field, so we cannot simply point it to a new list. Instead, we need to use the field’s methods to manipulate its content:
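    # Pop every piece off the end, keeping the reserved ones and those seen in the dataset
    kept_pieces, i = [], len(m.pieces) - 1
    while len(m.pieces):
        piece = m.pieces.pop()
        if i < 259 or i in seen:
            kept_pieces.append(piece)
        i -= 1
    # Pieces were collected last-to-first; restore the original order
    kept_pieces = list(reversed(kept_pieces))
    # Put the retained pieces back into the (now empty) field
    m.pieces.extend(kept_pieces)

    # Backup the old model
    Path("cache/mt5-small/spiece.model").rename("cache/mt5-small/spiece.model.old")
    # Write the new model to disk
    with open("cache/mt5-small/spiece.model", "wb") as f:
        f.write(m.SerializeToString())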
We’ll also need to keep track of the tokens that are retained, so we can know which rows to keep in the embedding matrix and the LM head.
    import json

    kept_ids = sorted(seen.union(set(range(259))))
    print(len(kept_ids))
    with open("cache/mt5-small/kept_ids.json", "w") as f:
        json.dump(kept_ids, f)
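Because the retained pieces keep their relative order, a token's new ID should simply be its position in kept_ids (assuming every ID in kept_ids corresponds to a piece in the original model). The sketch below illustrates that mapping with a made-up kept_ids list; old_to_new and remap are hypothetical helper names, not part of any library.

```python
# Toy stand-in for the kept_ids list saved above (sorted original token IDs)
kept_ids = [0, 1, 2, 7, 42, 1000]

# Hypothetical helper: original ID -> ID in the shrunken vocabulary
old_to_new = {old_id: new_id for new_id, old_id in enumerate(kept_ids)}


def remap(old_ids):
    """Translate a sequence of original token IDs into new token IDs."""
    return [old_to_new[old_id] for old_id in old_ids]


print(remap([42, 7, 1]))
```

A mapping like this can be handy if you already have data tokenized with the original vocabulary and want to reuse it without re-tokenizing.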
We can verify that our new SentencePiece model can correctly tokenize our dataset (by encoding and then decoding sentences):
    import random

    tokenizer = MT5Tokenizer.from_pretrained("cache/mt5-small")
    for i in random.sample(range(100), k=10):
        # the space placements are slightly different from the original
        converted = tokenizer.decode(
            tokenizer.encode(dataset["train"]["hypothesis"][i]),
            skip_special_tokens=True
        ).replace(" ", "")
        assert converted == dataset["train"]["hypothesis"][i].replace(" ", "")
Now we know how to create a new SentencePiece model with a smaller vocabulary size. The next step would be to use it to fine-tune an NLP model. We’ll need to modify the embedding matrix and the LM head (if the LM head’s weight and the embedding matrix are not tied together). That is beyond the scope of this post, since the NLP model can be implemented in any of the deep learning frameworks (PyTorch, TensorFlow, MXNet, etc.).
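Although the full procedure depends on the framework, the core idea is just row selection. Here is a minimal, framework-agnostic sketch that uses NumPy arrays as stand-ins for the weights. The array names, shapes, and the toy kept_ids list are all illustrative assumptions; a real model would load its weights from a checkpoint and the IDs from kept_ids.json.

```python
import numpy as np

# Toy stand-ins: a 10-token vocabulary with 4-dimensional embeddings
rng = np.random.default_rng(0)
embedding = rng.normal(size=(10, 4))  # shape: (old_vocab, hidden)
lm_head = rng.normal(size=(10, 4))    # shape: (old_vocab, hidden), assumed untied

# In practice this list would be loaded from kept_ids.json
kept_ids = [0, 1, 2, 5, 8]

# Row i of each new matrix is the row of the i-th kept token
new_embedding = embedding[kept_ids]
new_lm_head = lm_head[kept_ids]

print(new_embedding.shape)
```

Since kept_ids is sorted in the same order as the retained pieces, the row order of the new matrices lines up with the token IDs produced by the shrunken tokenizer.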
Multilingual pretrained models are not the only use case for this technique. It can also be useful when the upstream model is trained on a corpus covering several domains, while you are only interested in one of those domains in the downstream task. Just make sure you retain all the tokens the downstream task needs; otherwise, the performance might suffer.
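One simple way to guard against dropped tokens is to tokenize any new text with the original tokenizer and compare the resulting IDs against the retained set. The sketch below uses toy ID lists; find_missing_ids is a hypothetical helper, not a library function.

```python
def find_missing_ids(token_ids, kept_ids):
    """Return the original token IDs in `token_ids` that were not retained."""
    return sorted(set(token_ids) - set(kept_ids))


# Toy data: IDs produced by the *original* tokenizer on some new text
kept_ids = [0, 1, 2, 7, 42, 1000]
new_text_ids = [7, 42, 99, 1000, 512]

missing = find_missing_ids(new_text_ids, kept_ids)
print(missing)
```

If the returned list is non-empty, the shrunken tokenizer will have to represent that text differently (for example, by falling back to other pieces), which is where the performance drop could come from.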