Deep Learning 17: text classification with BERT using PyTorch


If you are a big fan of PyTorch and NLP, you must try the PyTorch-based BERT implementation! If you have your own dataset and want to try a state-of-the-art model, BERT is a good choice.
Please check the code of that implementation for a closer look. In this post, I will show you how to apply a pre-trained BERT model to your own data for classification.

Where to start?

They provide a nice example script in the repository. Briefly, the code loads the data, loads the pre-trained model, fine-tunes the network, and then outputs the accuracy (or other metric scores) on the development set. You might want to read the example code line by line; I believe it is very good practice.

Prepare Code

To do classification, we need only two Python script files from the repository's examples folder: a main script and a utility script. We use the main script to fine-tune and predict, and the utility script provides helper functions such as data loading and processing.
As noted in the README, the classification code supports various datasets.

We have different data loading functions for them because their formats differ. We will focus on SST-2, as it is very easy to generate our own data in a similar format.
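As a rough illustration of how little work this takes, here is a minimal sketch assuming your raw data is a list of (text, class name) pairs. The helper to_sst2_rows is hypothetical and not part of the repository: it assigns label ids 0, 1, 2, ... to the class names and collapses tabs and newlines, which would otherwise break the one-sample-per-line file.

```python
from typing import Dict, Iterable, List, Tuple


def to_sst2_rows(
    records: Iterable[Tuple[str, str]],
) -> Tuple[List[Tuple[str, int]], Dict[str, int]]:
    """Convert (text, class_name) pairs into (sentence, label_id) rows.

    Label ids are assigned 0, 1, 2, ... in sorted order of the class
    names, so the mapping is the same on every run.
    """
    records = list(records)
    class_names = sorted({class_name for _, class_name in records})
    label_map = {name: i for i, name in enumerate(class_names)}
    rows = []
    for text, class_name in records:
        # Tabs and newlines would break the tab-separated, one-sample-per-line
        # layout, so collapse every whitespace run into a single space.
        sentence = " ".join(text.split())
        rows.append((sentence, label_map[class_name]))
    return rows, label_map
```

For example, to_sst2_rows([("Great\tmovie!", "pos"), ("Boring plot.", "neg")]) returns the rows [("Great movie!", 1), ("Boring plot.", 0)] and the mapping {"neg": 0, "pos": 1}. Keep the mapping around so you can translate predicted label ids back into class names.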

To run the code, simply execute:

python \
--task_name SST-2 \
--do_train \
--do_eval \
--do_lower_case \
--data_dir YOUR_DATA_DIR \
--bert_model bert-base-uncased \
--max_seq_length 128 \
--train_batch_size 32 \
--learning_rate 2e-5 \
--num_train_epochs 3.0 \
--output_dir YOUR_OUTPUT_DIR
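To get a feel for what --train_batch_size and --num_train_epochs mean for your own data, here is a back-of-the-envelope sketch that counts how many optimizer updates a run performs. The helper count_updates is my own and not something the script exposes; it assumes one update per mini-batch (no gradient accumulation), that the last partial batch of each epoch is kept, and that a fractional epoch count is truncated to whole epochs.

```python
import math


def count_updates(num_examples: int, train_batch_size: int = 32,
                  num_train_epochs: float = 3.0) -> int:
    """Estimate how many optimizer updates a fine-tuning run performs.

    Assumptions: one update per mini-batch (no gradient accumulation),
    the last, smaller batch of each epoch is kept, and a fractional
    epoch count is truncated to whole epochs.
    """
    batches_per_epoch = math.ceil(num_examples / train_batch_size)
    return batches_per_epoch * int(num_train_epochs)
```

With 2,000 training samples and the settings in the command above, count_updates(2000) gives 63 batches per epoch times 3 epochs, i.e. 189 updates.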

The --bert_model argument is the BERT model you want to restore. It can be one of the predefined model names (see the README file) or the path to a directory containing your own fine-tuned BERT model!

Prepare data

Note that we will fix the task name to SST-2. Put all the data under YOUR_DATA_DIR, which should contain two files: train.tsv and dev.tsv. The code treats the dev file as the test data, so please note that train.tsv is the data you want to train on and dev.tsv is the data you want to evaluate on. After the script finishes, it outputs the accuracy on the dev.tsv data. Note that you can also change details in the
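If all your labeled data sits in one place, a minimal sketch like the following shuffles it and writes the two files YOUR_DATA_DIR needs, using the file format described in the next paragraph. The helper name write_split, the 10% dev fraction, and the seed are my own assumptions, and the sentences are assumed to contain no tabs or newlines already.

```python
import os
import random
from typing import Sequence, Tuple


def write_split(rows: Sequence[Tuple[str, int]], data_dir: str,
                dev_fraction: float = 0.1, seed: int = 42) -> Tuple[int, int]:
    """Shuffle rows and write them to data_dir/train.tsv and data_dir/dev.tsv.

    Each file starts with the header "sentence<TAB>label", followed by one
    sample per line. Returns (number of train rows, number of dev rows).
    """
    rows = list(rows)
    # A fixed seed makes the split reproducible between runs.
    random.Random(seed).shuffle(rows)
    n_dev = int(len(rows) * dev_fraction)
    splits = {"dev.tsv": rows[:n_dev], "train.tsv": rows[n_dev:]}
    os.makedirs(data_dir, exist_ok=True)
    for filename, split_rows in splits.items():
        path = os.path.join(data_dir, filename)
        with open(path, "w", encoding="utf-8") as f:
            f.write("sentence\tlabel\n")
            for sentence, label in split_rows:
                f.write(f"{sentence}\t{label}\n")
    return len(splits["train.tsv"]), len(splits["dev.tsv"])
```

Keep in mind that dev.tsv doubles as the test set here, so it should not overlap with train.tsv; splitting one shuffled list guarantees that.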

In train.tsv, the first line (the header) is, separated by a tab:
sentence \t label
Each remaining line is one sample: the actual sentence, then a tab, followed by its label (labels start from 0, then 1, 2, ...). The dev.tsv file uses the same format.
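Before training, it can save time to check that your files really follow this format. Here is a minimal sketch that skips the header and validates every line; the function read_sst2_tsv is hypothetical and imitates, but is not, the repository's data loader. If you use labels beyond 0 and 1, also check the label list the SST-2 processor declares: as far as I know it lists only "0" and "1", so it may need extending.

```python
import csv
from typing import List, Tuple


def read_sst2_tsv(path: str) -> List[Tuple[str, int]]:
    """Read a sentence<TAB>label file, skipping the header line.

    Raises ValueError on a wrong header, or on lines that do not have
    exactly two tab-separated fields with a non-negative integer label.
    """
    samples = []
    with open(path, encoding="utf-8", newline="") as f:
        # QUOTE_NONE: quote characters inside sentences are kept verbatim.
        reader = csv.reader(f, delimiter="\t", quoting=csv.QUOTE_NONE)
        header = next(reader)
        if header != ["sentence", "label"]:
            raise ValueError(f"unexpected header: {header}")
        for line_no, fields in enumerate(reader, start=2):
            if len(fields) != 2 or not fields[1].isdigit():
                raise ValueError(f"line {line_no}: bad sample {fields!r}")
            samples.append((fields[0], int(fields[1])))
    return samples
```

Running it on both train.tsv and dev.tsv catches stray tabs, missing labels, and a forgotten header before you spend time on fine-tuning.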


How it performs

A progress bar shows how training is going.

After training for the number of epochs you set with --num_train_epochs, the script evaluates the model on dev.tsv and prints the accuracy.
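As far as I can tell, the reported accuracy is simply the fraction of dev samples whose highest-scoring class matches the gold label. A plain-Python sketch of that computation follows; accuracy_from_logits is my own name and this is an illustration, not the script's code.

```python
from typing import Sequence


def accuracy_from_logits(logits: Sequence[Sequence[float]],
                         labels: Sequence[int]) -> float:
    """Fraction of samples whose highest-scoring class equals the gold label."""
    if len(logits) != len(labels):
        raise ValueError("logits and labels must have the same length")
    if not labels:
        raise ValueError("no samples to evaluate")
    correct = 0
    for scores, label in zip(logits, labels):
        # argmax over the class scores gives the predicted label id
        predicted = max(range(len(scores)), key=lambda i: scores[i])
        correct += int(predicted == label)
    return correct / len(labels)
```

For instance, with scores [[0.1, 2.0], [1.5, -0.3], [0.2, 0.1], [-1.0, 1.0]] and labels [1, 0, 1, 1], the predictions are 1, 0, 0, 1, so the accuracy is 0.75.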

When using your own dataset, the accuracy seems to be very sensitive to the learning rate and the number of epochs. If you have a small dataset, say only two thousand samples, I suggest trying a smaller learning rate such as 1e-5.
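Because of this sensitivity, it is worth trying a few settings. Here is a minimal sketch that builds one command per (learning rate, epochs) pair, reusing the flags from the command above and giving each run its own output directory. The helper name sweep_commands and the grid values are my own assumptions, and the script path is left as a parameter: pass in the main script from the examples folder.

```python
import itertools
from typing import List, Sequence


def sweep_commands(script: str, data_dir: str, output_root: str,
                   learning_rates: Sequence[float] = (1e-5, 2e-5, 3e-5),
                   epochs: Sequence[float] = (2.0, 3.0, 4.0)) -> List[List[str]]:
    """Build one fine-tuning command per (learning rate, epochs) pair.

    Every run writes to its own output directory so results do not
    overwrite each other.
    """
    commands = []
    for lr, n_epochs in itertools.product(learning_rates, epochs):
        commands.append([
            "python", script,
            "--task_name", "SST-2",
            "--do_train", "--do_eval", "--do_lower_case",
            "--data_dir", data_dir,
            "--bert_model", "bert-base-uncased",
            "--max_seq_length", "128",
            "--train_batch_size", "32",
            "--learning_rate", str(lr),
            "--num_train_epochs", str(n_epochs),
            "--output_dir", f"{output_root}/lr{lr}_ep{n_epochs}",
        ])
    return commands
```

Each command can then be launched with subprocess.run(cmd, check=True), and you keep whichever setting gives the best accuracy on dev.tsv. Since dev.tsv is also your test data here, be aware that picking settings this way makes the final dev accuracy somewhat optimistic.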

Have fun with BERT!

Published by Irene

Keep calm and update blog.
