transformer.ipynb - Colab 07/10/24, 2:22 PM
Text Prediction with Pre-trained Transformer Models
import torch
import string
from transformers import BertTokenizer, BertForMaskedLM
from transformers import XLNetTokenizer, XLNetLMHeadModel
from transformers import XLMRobertaTokenizer, XLMRobertaForMaskedLM
from transformers import BartTokenizer, BartForConditionalGeneration
from transformers import ElectraTokenizer, ElectraForMaskedLM
from transformers import RobertaTokenizer, RobertaForMaskedLM
# Load each tokenizer together with its pre-trained language-model head.
# Every model is switched to evaluation mode with .eval() because this
# notebook only runs inference and never trains.
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertForMaskedLM.from_pretrained('bert-base-uncased').eval()

xlnet_tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
xlnet_model = XLNetLMHeadModel.from_pretrained('xlnet-base-cased').eval()

xlmroberta_tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-base')
xlmroberta_model = XLMRobertaForMaskedLM.from_pretrained('xlm-roberta-base').eval()

bart_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
bart_model = BartForConditionalGeneration.from_pretrained('facebook/bart-large').eval()

electra_tokenizer = ElectraTokenizer.from_pretrained('google/electra-small-generator')
electra_model = ElectraForMaskedLM.from_pretrained('google/electra-small-generator').eval()

roberta_tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
roberta_model = RobertaForMaskedLM.from_pretrained('roberta-base').eval()

# Number of highest-scoring vocabulary entries taken from each model at the
# masked position; decode() then filters them and keeps the first few.
top_k = 10
def decode(tokenizer, pred_idx, top_clean):
    """Convert candidate vocabulary ids into a newline-separated word list.

    Each id is decoded on its own, stripped of surrounding whitespace and
    discarded when it is empty, is the ``[PAD]`` marker, or is a run of
    characters found in ``string.punctuation``. WordPiece continuation
    markers (``##``) are removed from the surviving tokens.

    Args:
        tokenizer: Object exposing ``decode(ids, skip_special_tokens=True)``.
        pred_idx: Iterable of candidate token ids, in the order they should
            be reported.
        top_clean: Maximum number of cleaned tokens to return.

    Returns:
        The first ``top_clean`` surviving tokens joined with ``'\\n'``
        (an empty string when nothing survives).
    """
    tokens = []
    for idx in pred_idx:
        token = tokenizer.decode([idx], skip_special_tokens=True).strip()
        if not token:
            continue
        # '[PAD]' is compared for equality rather than being concatenated
        # onto the punctuation string: a substring test against
        # punctuation + '[PAD]' would also reject real words such as
        # 'A', 'P', 'D' or 'PAD'.
        if token == '[PAD]' or token in string.punctuation:
            continue
        tokens.append(token.replace('##', ''))
    return '\n'.join(tokens[:top_clean])
def encode(tokenizer, text_sentence, add_special_tokens=True):
    """Tokenize a sentence and locate its mask token.

    The generic ``<mask>`` placeholder is replaced by the tokenizer's own
    mask token before encoding.

    Args:
        tokenizer: Object exposing ``mask_token``, ``mask_token_id`` and
            ``encode(text, add_special_tokens=...)``.
        text_sentence: Input text containing a ``<mask>`` placeholder.
        add_special_tokens: Forwarded unchanged to ``tokenizer.encode``.

    Returns:
        A tuple ``(input_ids, mask_idx)``: ``input_ids`` is a tensor holding
        the encoded sentence as a single row, and ``mask_idx`` is the column
        of the first mask token in that row.

    Raises:
        ValueError: If the encoded sentence contains no mask token.
    """
    text_sentence = text_sentence.replace('<mask>', tokenizer.mask_token)

    # When the mask is the final word, add a trailing full stop.
    # NOTE(review): presumably so the models propose a word rather than
    # end-of-sentence punctuation - confirm with the notebook author.
    words = text_sentence.split()
    if words and words[-1] == tokenizer.mask_token:
        text_sentence += ' .'

    input_ids = torch.tensor(
        [tokenizer.encode(text_sentence, add_special_tokens=add_special_tokens)]
    )

    # torch.where on the 2-D tensor yields (row_indices, column_indices);
    # only the column of the first match is needed.
    mask_positions = torch.where(input_ids == tokenizer.mask_token_id)[1].tolist()
    if not mask_positions:
        raise ValueError(
            "text_sentence must contain a '<mask>' placeholder "
            f"(or the tokenizer's own mask token {tokenizer.mask_token!r})"
        )
    return input_ids, mask_positions[0]
def _masked_lm_predictions(tokenizer, model, text_sentence, top_clean):
    """Return one model's cleaned top predictions for the masked position."""
    input_ids, mask_idx = encode(tokenizer, text_sentence)
    with torch.no_grad():
        scores = model(input_ids)[0]
    candidates = scores[0, mask_idx, :].topk(top_k).indices.tolist()
    return decode(tokenizer, candidates, top_clean)


def get_all_predictions(text_sentence, top_clean=5):
    """Predict the masked word of a sentence with every loaded model.

    Args:
        text_sentence: Text containing a ``<mask>`` placeholder.
        top_clean: Maximum number of cleaned predictions kept per model.

    Returns:
        A dict mapping ``'bert'``, ``'xlnet'``, ``'xlm'``, ``'bart'``,
        ``'electra'`` and ``'roberta'`` (in that insertion order) to a
        newline-separated string of predicted tokens, as produced by
        ``decode``.
    """
    predictions = {}

    # ========================= BERT =================================
    predictions['bert'] = _masked_lm_predictions(
        bert_tokenizer, bert_model, text_sentence, top_clean)

    # ========================= XLNET =================================
    # XLNet is encoded without special tokens and needs two extra inputs.
    input_ids, mask_idx = encode(xlnet_tokenizer, text_sentence, False)
    seq_len = input_ids.shape[1]
    # Setting column mask_idx to 1 marks the masked position in perm_mask
    # for every row (see the XLNet documentation for perm_mask semantics).
    perm_mask = torch.zeros((1, seq_len, seq_len), dtype=torch.float)
    perm_mask[:, :, mask_idx] = 1.0
    # A single prediction target: the masked position.
    target_mapping = torch.zeros((1, 1, seq_len), dtype=torch.float)
    target_mapping[0, 0, mask_idx] = 1.0
    with torch.no_grad():
        predict = xlnet_model(
            input_ids, perm_mask=perm_mask, target_mapping=target_mapping)[0]
    # Only one target was requested, so its scores sit at index 0.
    predictions['xlnet'] = decode(
        xlnet_tokenizer, predict[0, 0, :].topk(top_k).indices.tolist(), top_clean)

    # ========================= XLM ROBERTA =================================
    predictions['xlm'] = _masked_lm_predictions(
        xlmroberta_tokenizer, xlmroberta_model, text_sentence, top_clean)

    # ========================= BART =================================
    predictions['bart'] = _masked_lm_predictions(
        bart_tokenizer, bart_model, text_sentence, top_clean)

    # ========================= ELECTRA =================================
    predictions['electra'] = _masked_lm_predictions(
        electra_tokenizer, electra_model, text_sentence, top_clean)

    # ========================= ROBERTA =================================
    predictions['roberta'] = _masked_lm_predictions(
        roberta_tokenizer, roberta_model, text_sentence, top_clean)

    return predictions
Some weights of the model checkpoint at bert-base-uncased were not used when in
- This IS expected if you are initializing BertForMaskedLM from the checkpoint
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpo
Some weights of the model checkpoint at xlm-roberta-base were not used when ini
- This IS expected if you are initializing XLMRobertaForMaskedLM from the check
- This IS NOT expected if you are initializing XLMRobertaForMaskedLM from the c
pytorch_model.bin: 100% 1.02G/1.02G [00:00<00:00, 4.73MB/s]
tokenizer_config.json: 100% 48.0/48.0 [00:00<00:00, 901B/s]
vocab.txt: 100% 232k/232k [00:00<00:00, 3.72MB/s]
tokenizer.json: 100% 466k/466k [00:00<00:00, 7.95MB/s]
config.json: 100% 662/662 [00:00<00:00, 14.4kB/s]
pytorch_model.bin: 100% 54.2M/54.2M [00:01<00:00, 33.2MB/s]
tokenizer_config.json: 100% 25.0/25.0 [00:00<00:00, 392B/s]
vocab.json: 100% 899k/899k [00:00<00:00, 16.6MB/s]
merges.txt: 100% 456k/456k [00:00<00:00, 4.88MB/s]
tokenizer.json: 100% 1.36M/1.36M [00:00<00:00, 23.1MB/s]
config.json: 100% 481/481 [00:00<00:00, 5.57kB/s]
model.safetensors: 100% 499M/499M [00:08<00:00, 37.2MB/s]
https://colab.research.google.com/drive/1GHFch9YpGkKgVo5TMxtQP7pHIeLg2QHc#scrollTo=PG5DaRp5j88o Page 3 of 5
transformer.ipynb - Colab 07/10/24, 2:22 PM
# Demo: ask every model to fill in the masked word, then print each
# model's suggestions under an upper-cased heading.
text = "I am feeling great <mask>."
predictions = get_all_predictions(text)
for name, suggestions in predictions.items():
    print(f"{name.upper()} Predictions:\n{suggestions}\n")
BERT Predictions:
for
about
with
now
and
XLNET Predictions:
awful
great
terrible
today
terrific
XLM Predictions:
in
i
at
for
and
BART Predictions:
and
today
…
physically
!!!
ELECTRA Predictions:
this
about
and
for
with
ROBERTA Predictions:
right
about
this
and
so
Start coding or generate with AI.
https://colab.research.google.com/drive/1GHFch9YpGkKgVo5TMxtQP7pHIeLg2QHc#scrollTo=PG5DaRp5j88o Page 4 of 5
transformer.ipynb - Colab 07/10/24, 2:22 PM
https://colab.research.google.com/drive/1GHFch9YpGkKgVo5TMxtQP7pHIeLg2QHc#scrollTo=PG5DaRp5j88o Page 5 of 5