Unofficial Keras implementation of XLNet. The embedding extraction and embedding extraction with memory demos show how to get the outputs of the last transformer layer from pre-trained checkpoints.
Install with pip:

```bash
pip install keras-xlnet
```

The demos fine-tune the base model on each GLUE task and reach roughly the following results:
| Task Name | Metrics | Approximate Results on Dev Set |
|---|---|---|
| CoLA | Matthews Corr. | 52 |
| SST-2 | Accuracy | 93 |
| MRPC | Accuracy/F1 | 86/89 |
| STS-B | Pearson Corr. / Spearman Corr. | 86/87 |
| QQP | Accuracy/F1 | 90/86 |
| MNLI | Accuracy (matched/mismatched) | 84/84 |
| QNLI | Accuracy | 86 |
| RTE | Accuracy | 64 |
| WNLI | Accuracy | 56 |
(On WNLI, the model predicts only 0s, so its score reflects the label distribution of the dev set.)
```python
import os
from keras_xlnet import Tokenizer, load_trained_model_from_checkpoint, ATTENTION_TYPE_BI

# Directory of the unzipped pre-trained checkpoint
checkpoint_path = '.../xlnet_cased_L-24_H-1024_A-16'

# SentencePiece tokenizer shipped with the checkpoint
tokenizer = Tokenizer(os.path.join(checkpoint_path, 'spiece.model'))
model = load_trained_model_from_checkpoint(
    config_path=os.path.join(checkpoint_path, 'xlnet_config.json'),
    checkpoint_path=os.path.join(checkpoint_path, 'xlnet_model.ckpt'),
    batch_size=16,         # maximum batch size
    memory_len=512,        # maximum length of the memories
    target_len=128,        # maximum length of each input segment
    in_train_phase=False,  # False: fine-tuning model; True: language model
    attention_type=ATTENTION_TYPE_BI,  # bidirectional attention
)
model.summary()
```

The arguments `batch_size`, `memory_len`, and `target_len` are the maximum sizes used to initialize the memories. If `in_train_phase` is `True`, the returned model is the one used to train a language model; otherwise, a model for fine-tuning is returned.
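Note that if memories are used, `shuffle` must be `False` in both `fit` and `fit_generator`: the memories carry hidden states from one batch to the next, so they depend on the order of the batches.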
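Because memories tie consecutive batches together, a long input has to be fed as ordered segments. The sketch below, which only assumes NumPy, splits one token sequence into `target_len`-sized chunks and builds the three fine-tuning inputs for each chunk. The helper `iter_segments` and its `pad_id` argument are hypothetical, and the memory-length bookkeeping (earlier tokens counted as cached, capped at `memory_len`) is an assumption about how the memory-length input is filled, not behavior confirmed by keras-xlnet.

```python
import numpy as np


def iter_segments(token_ids, target_len, memory_len, pad_id):
    """Yield (token_input, segment_input, memory_length_input) for consecutive
    target_len-sized chunks of one long sequence, strictly in order.

    Hypothetical helper. Assumptions: pad_id is supplied by the caller, all
    segment IDs are 0, and the memory-length input counts how many earlier
    tokens are cached, capped at memory_len.
    """
    cached = 0
    for start in range(0, len(token_ids), target_len):
        chunk = list(token_ids[start:start + target_len])
        chunk += [pad_id] * (target_len - len(chunk))  # right-pad the last chunk
        token_input = np.array([chunk])                # shape (1, target_len)
        segment_input = np.zeros_like(token_input)     # shape (1, target_len)
        memory_length_input = np.array([[min(cached, memory_len)]])  # shape (1, 1)
        yield token_input, segment_input, memory_length_input
        cached += target_len  # the whole segment is assumed to enter the memory


for token_input, _, memory_length_input in iter_segments(
        list(range(10)), target_len=4, memory_len=6, pad_id=0):
    print(token_input.tolist(), memory_length_input.tolist())
```

Each yielded triple could then be passed, in order, to Keras' `model.predict_on_batch(list(inputs))`; shuffling the batches would pair a segment with memories computed from unrelated text.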
When `in_train_phase` is `False`, the fine-tuning model has 3 inputs:

- IDs of tokens, with shape `(batch_size, target_len)`.
- IDs of segments, with shape `(batch_size, target_len)`.
- Length of memories, with shape `(batch_size, 1)`.

It has 1 output:

- The feature of each token, with shape `(batch_size, target_len, units)`.
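A minimal NumPy sketch of packing tokenized samples into these three inputs. `build_finetune_inputs` is a hypothetical helper; right-padding with a caller-supplied `pad_id`, truncating to `target_len`, and using segment ID 0 for every token are simplifying assumptions, not the library's own preprocessing.

```python
import numpy as np


def build_finetune_inputs(batch_token_ids, target_len, pad_id, memory_length=0):
    """Pack tokenized samples into [token IDs, segment IDs, memory length].

    Hypothetical helper: each sample is truncated or right-padded to
    target_len with the caller-supplied pad_id, every token gets segment
    ID 0, and all rows share the same memory length.
    """
    batch_size = len(batch_token_ids)
    token_input = np.full((batch_size, target_len), pad_id, dtype="int32")
    for row, ids in enumerate(batch_token_ids):
        ids = list(ids)[:target_len]  # drop tokens beyond target_len
        token_input[row, :len(ids)] = ids
    segment_input = np.zeros((batch_size, target_len), dtype="int32")
    memory_length_input = np.full((batch_size, 1), memory_length, dtype="int32")
    return [token_input, segment_input, memory_length_input]
```

For example, `build_finetune_inputs([[10, 11, 12], [20]], target_len=4, pad_id=0)` (the `pad_id` value is purely illustrative) yields token IDs `[[10, 11, 12, 0], [20, 0, 0, 0]]` and arrays of shapes `(2, 4)`, `(2, 4)`, and `(2, 1)`, which the fine-tuning model would map to features of shape `(2, 4, units)`.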
When `in_train_phase` is `True`, the language model has 4 inputs:

- IDs of tokens, with shape `(batch_size, target_len)`.
- IDs of segments, with shape `(batch_size, target_len)`.
- Length of memories, with shape `(batch_size, 1)`.
- Masks of tokens, with shape `(batch_size, target_len)`.

It has 1 output:

- The probability of each vocabulary token at each position, with shape `(batch_size, target_len, num_token)`.
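To turn the language-model output into token IDs, one option is to take the arg-max over the last (`num_token`) axis at the masked positions. This sketch assumes a non-zero mask value marks a position to predict; that convention and the helper name `predict_masked_tokens` are assumptions for illustration, not part of keras-xlnet.

```python
import numpy as np


def predict_masked_tokens(probabilities, masks):
    """Return the most probable token ID at every masked position.

    probabilities: shape (batch_size, target_len, num_token), like the
        language-model output.
    masks: shape (batch_size, target_len); assumed convention: a non-zero
        value marks a position to predict.
    Returns a list of (row, position, token_id) tuples.
    """
    probabilities = np.asarray(probabilities)
    best_ids = probabilities.argmax(axis=-1)         # shape (batch_size, target_len)
    rows, positions = np.nonzero(np.asarray(masks))  # indices of masked positions
    return [(int(r), int(p), int(best_ids[r, p])) for r, p in zip(rows, positions)]
```

Taking the arg-max is the simplest greedy readout; sampling from the distribution at each masked position would be an alternative when diverse predictions are wanted.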