Deep Learning 19: Training MLM on any pre-trained BERT models

MLM (masked language modeling) is an important task for training a BERT model. In the original BERT paper, "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding", it is one of the two tasks used to pre-train BERT (the other is next sentence prediction). So if you have your own corpus, you can train MLM on top of any pre-trained BERT-style model, e.g., RoBERTa or SciBERT.
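To make the objective concrete: following the BERT paper, 15% of the ordinary (non-special) token positions are chosen as prediction targets; of those, 80% are replaced by the mask token, 10% by a random token, and 10% are left unchanged. The library-free sketch below implements that rule on a hypothetical toy vocabulary (VOCAB_SIZE, MASK_ID and SPECIAL_IDS are made up for illustration). It is not the Hugging Face collator's code, although DataCollatorForLanguageModeling applies the same 80/10/10 split by default and also uses -100 for positions that should not contribute to the loss.

```python
import random

# Hypothetical toy vocabulary: ids 0-4 are special tokens, 5-99 are ordinary tokens.
VOCAB_SIZE = 100
MASK_ID = 4
SPECIAL_IDS = {0, 1, 2, 3, 4}  # e.g. [CLS], [PAD], [SEP], [UNK], [MASK] in this toy vocab
IGNORE_INDEX = -100  # label value that PyTorch's cross-entropy loss skips by default


def mask_tokens(input_ids, mlm_probability=0.15, rng=random):
    """Return (masked_ids, labels) using BERT's 80/10/10 replacement rule."""
    masked = list(input_ids)
    labels = [IGNORE_INDEX] * len(input_ids)
    for i, tok in enumerate(input_ids):
        # Special tokens are never selected; others are selected with mlm_probability.
        if tok in SPECIAL_IDS or rng.random() >= mlm_probability:
            continue
        labels[i] = tok  # the loss asks the model to recover this original token
        r = rng.random()
        if r < 0.8:
            masked[i] = MASK_ID  # 80% of selected positions: use the mask token
        elif r < 0.9:
            # 10%: a random ordinary token (this sketch never draws special tokens)
            masked[i] = rng.randrange(len(SPECIAL_IDS), VOCAB_SIZE)
        # remaining 10%: leave the original token in place
    return masked, labels


# Usage: 64 ordinary tokens wrapped in two special tokens, with a fixed seed.
rng = random.Random(0)
ids = [0] + list(range(10, 74)) + [2]
masked, labels = mask_tokens(ids, rng=rng)
n_targets = sum(label != IGNORE_INDEX for label in labels)
print(f"{n_targets} of {len(ids)} positions selected for prediction")
```

On average about 0.15 × 64 ≈ 10 of the ordinary positions become targets; the exact count varies with the seed.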

Hugging Face Library and Input TSV File

The Hugging Face Transformers library supports a wide variety of pre-trained BERT models. First, let's prepare the corpus that will serve as the input file for MLM training: simply put the free text in the file, one example per line, and name it, say, MyData.tsv. Despite the .tsv extension, the file is read as plain text; nothing splits it on tabs.
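If your raw documents contain line breaks, they need to be flattened first, otherwise each fragment becomes a separate example. A minimal sketch with hypothetical helper names (to_line, write_corpus, read_corpus); read_corpus keeps every non-blank line, which is how the line-by-line dataset used later reads the file:

```python
def to_line(text):
    """Collapse newlines, tabs and repeated spaces so a document fits on one line."""
    return " ".join(text.split())


def write_corpus(texts, path="MyData.tsv"):
    """Write one cleaned document per line, skipping documents that end up empty."""
    with open(path, "w", encoding="utf-8") as f:
        for text in texts:
            line = to_line(text)
            if line:
                f.write(line + "\n")


def read_corpus(path="MyData.tsv"):
    """Read the file back as a line-by-line loader would: one example per non-blank line."""
    with open(path, encoding="utf-8") as f:
        return [line for line in f.read().splitlines() if line.strip()]
```

Call write_corpus(my_texts) with any iterable of strings (my_texts is a placeholder for your own data).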

MLM for regular BERT Models

For common BERT variants, the Hugging Face library provides dedicated classes. Let's take RoBERTa as an example. There are three classes we need to be familiar with:

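from transformers import RobertaTokenizerFast
from transformers import RobertaConfig
from transformers import RobertaForMaskedLM

# Reuse the pre-trained RoBERTa tokenizer
tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

# A small 6-layer RoBERTa; vocab_size should match the tokenizer's vocabulary
config = RobertaConfig(
    vocab_size=len(tokenizer),
    max_position_embeddings=514,
    num_attention_heads=12,
    num_hidden_layers=6,
    type_vocab_size=1,
)

# Building the model from a config gives randomly initialized weights
model = RobertaForMaskedLM(config=config)
# To start from the pre-trained weights instead, use:
# model = RobertaForMaskedLM.from_pretrained("roberta-base")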
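Why max_position_embeddings=514 rather than 512? In the Hugging Face RoBERTa implementation, position ids are offset by the padding index (1 in roberta-base): real tokens are numbered from 2 upward and padding tokens get position 1, so a 512-token sequence needs 514 embedding rows. The pure-Python function below (roberta_position_ids is a hypothetical name) reproduces that numbering for illustration:

```python
def roberta_position_ids(input_ids, padding_idx=1):
    """Sketch of RoBERTa-style position ids: real tokens are numbered from
    padding_idx + 1 upward, padding tokens get position padding_idx."""
    positions = []
    count = 0
    for tok in input_ids:
        if tok == padding_idx:
            positions.append(padding_idx)
        else:
            count += 1
            positions.append(padding_idx + count)
    return positions


# <s>, two content tokens, </s>, <pad>, <pad>  (in roberta-base: <s>=0, <pad>=1, </s>=2)
print(roberta_position_ids([0, 11, 12, 2, 1, 1]))  # [2, 3, 4, 5, 1, 1]

# A full 512-token sequence uses positions up to 513, hence 514 embedding rows.
print(max(roberta_position_ids([5] * 512)) + 1)  # 514
```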

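The tokenizer is specific to RoBERTa; if you use a different architecture, switch to its matching classes (e.g., BertTokenizerFast, BertConfig and BertForMaskedLM for BERT). Passing the config to RobertaForMaskedLM initializes a new RoBERTa with random weights; to keep training the pre-trained checkpoint instead, load it with RobertaForMaskedLM.from_pretrained("roberta-base"). Next, the data loading: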

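from transformers import DataCollatorForLanguageModeling, LineByLineTextDataset

# Selects 15% of the tokens in each batch as MLM prediction targets
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15
)

# Each non-empty line of the file becomes one example of at most 128 tokens
dataset = LineByLineTextDataset(
    tokenizer=tokenizer,
    file_path='MyData.tsv',
    block_size=128,
)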
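Under the hood, LineByLineTextDataset tokenizes each line with special tokens added and truncation to block_size, and DataCollatorForLanguageModeling pads each batch to its longest example before masking. The sketch below imitates those two steps on already-tokenized ids; encode_line and pad_batch are hypothetical helpers, real tokenization is skipped, and only the special-token ids are taken from the roberta-base vocabulary.

```python
BOS, PAD, EOS = 0, 1, 2  # ids of <s>, <pad>, </s> in the roberta-base vocabulary


def encode_line(content_ids, block_size=128):
    """Wrap already-tokenized content in <s> ... </s>, truncating the content
    so the whole example never exceeds block_size tokens."""
    body = content_ids[: block_size - 2]  # leave room for the two special tokens
    return [BOS] + body + [EOS]


def pad_batch(examples):
    """Pad every example to the longest one in the batch and build attention masks."""
    longest = max(len(ex) for ex in examples)
    input_ids, attention_mask = [], []
    for ex in examples:
        n_pad = longest - len(ex)
        input_ids.append(ex + [PAD] * n_pad)
        attention_mask.append([1] * len(ex) + [0] * n_pad)  # 0 marks padding
    return {"input_ids": input_ids, "attention_mask": attention_mask}


# Hypothetical content ids standing in for two tokenized lines of MyData.tsv
batch = pad_batch([encode_line([11, 12]), encode_line(list(range(10, 500)))])
print([len(row) for row in batch["input_ids"]])  # [128, 128]
```

Padding per batch rather than always to block_size keeps batches of short lines cheap; padding positions count as special tokens, so the collator never selects them for masking.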

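LineByLineTextDataset reads the file at file_path and treats each non-empty line as one example, truncated to block_size tokens (recent versions of transformers mark this class as deprecated in favor of the 🤗 Datasets library, but it still works). In the data collator, mlm_probability=0.15 selects 15% of the tokens as prediction targets, the same rate as in the BERT paper. Next, train and save the model: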

from transformers import Trainer, TrainingArguments

trained_path = 'trained_path/'  # output directory for checkpoints and the final model

training_args = TrainingArguments(
    output_dir=trained_path,
    overwrite_output_dir=True,
    num_train_epochs=1,
    per_device_train_batch_size=16,
    save_steps=10_000,
    save_total_limit=2,
    prediction_loss_only=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset,
)

print('Start training...')
# Start training
trainer.train()

# Save the final model, plus the tokenizer so both can be reloaded from one folder
trainer.save_model(trained_path)
tokenizer.save_pretrained(trained_path)
print('Finished training all...', trained_path)

After training finishes, the saved model is under trained_path. Later, you can load it with from_pretrained(trained_path), e.g., AutoModel.from_pretrained(trained_path), as the starting point for your own downstream tasks.
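With save_steps=10_000 and save_total_limit=2, the Trainer writes a checkpoint-<step> folder under output_dir every 10,000 optimizer steps and deletes older ones so that at most two remain. The rough arithmetic below (hypothetical helper names; it assumes a single device, no gradient accumulation, keeps the last partial batch, and ignores any checkpoint a given transformers version may write at the very end of training) estimates how many steps an epoch takes and which periodic checkpoints are left:

```python
import math


def optimizer_steps(num_examples, batch_size=16, num_epochs=1, n_devices=1):
    """Rough step count: one optimizer step per batch, last partial batch kept,
    no gradient accumulation."""
    batches_per_epoch = math.ceil(num_examples / (batch_size * n_devices))
    return batches_per_epoch * num_epochs


def kept_checkpoints(total_steps, save_steps=10_000, save_total_limit=2):
    """Periodic checkpoint folders still on disk once older ones are rotated out."""
    saved = [f"checkpoint-{s}" for s in range(save_steps, total_steps + 1, save_steps)]
    return saved[-save_total_limit:] if save_total_limit else saved


steps = optimizer_steps(1_000_000)  # e.g. a corpus of one million lines
print(steps, kept_checkpoints(steps))
# 62500 ['checkpoint-50000', 'checkpoint-60000']
```

On a small corpus (say 5,000 lines, about 313 steps) save_steps is never reached, which is why the explicit save_model call at the end matters.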

MLM for special BERT Models

The Hugging Face model hub also hosts many other BERT variants, e.g., SciBERT. Such models don't have dedicated tokenizer or config classes, but you can still train MLM on top of them. Taking the pre-trained SciBERT as an example, it can be initialized as follows:

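from transformers import AutoConfig, AutoTokenizer, AutoModelForMaskedLM

# Hub ID of the uncased SciBERT checkpoint with the scientific vocabulary
model_name = 'allenai/scibert_scivocab_uncased'

config = AutoConfig.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Loads the pre-trained SciBERT encoder with a masked-LM head on top
model = AutoModelForMaskedLM.from_pretrained(model_name)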

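The data loading and training code are exactly the same as in the previous section. The Auto* classes read the checkpoint's configuration and automatically pick the matching architecture (here, BERT), so no model-specific classes are needed.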
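Conceptually, the Auto* classes look at the model type recorded in the checkpoint's configuration and map it to a concrete class. The toy dispatcher below only illustrates that idea; MLM_CLASS_BY_MODEL_TYPE and resolve_mlm_class are hypothetical names, and the real library's lookup is far more complete than this two-entry table.

```python
import json

# Hypothetical, heavily simplified lookup table (class names as strings for illustration)
MLM_CLASS_BY_MODEL_TYPE = {
    "bert": "BertForMaskedLM",
    "roberta": "RobertaForMaskedLM",
}


def resolve_mlm_class(config_json):
    """Pick a masked-LM class name from the text of a checkpoint's config.json."""
    model_type = json.loads(config_json).get("model_type")
    if model_type not in MLM_CLASS_BY_MODEL_TYPE:
        raise ValueError(f"unsupported model_type: {model_type!r}")
    return MLM_CLASS_BY_MODEL_TYPE[model_type]


print(resolve_mlm_class('{"model_type": "bert", "num_hidden_layers": 12}'))  # BertForMaskedLM
```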

References

https://colab.research.google.com/github/huggingface/blog/blob/master/notebooks/01_how_to_train.ipynb

Published by Irene

