(Unofficial) PyTorch implementation of R-BERT: [Enriching Pre-trained Language Model with Entity Information for Relation Classification](https://arxiv.org/abs/1905.08284)
- Get three vectors from BERT (see the sketch after this list).
  - [CLS] token vector
  - averaged entity_1 vector
  - averaged entity_2 vector
- Pass each vector to a fully-connected layer.
  - dropout -> tanh -> fc-layer
- Concatenate the three vectors.
- Pass the concatenated vector to a fully-connected layer.
  - dropout -> fc-layer
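
The four steps above amount to the following forward pass. This is a minimal sketch, not this repo's exact code; the names (`FCLayer`, `RBERTSketch`, `e1_mask`, `e2_mask`) are illustrative.

```python
import torch
import torch.nn as nn
from transformers import BertModel


class FCLayer(nn.Module):
    """dropout -> tanh -> fc-layer (tanh skipped for the final classifier)."""

    def __init__(self, in_dim, out_dim, dropout=0.1, use_tanh=True):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(in_dim, out_dim)
        self.use_tanh = use_tanh

    def forward(self, x):
        x = self.dropout(x)
        if self.use_tanh:
            x = torch.tanh(x)
        return self.linear(x)


class RBERTSketch(nn.Module):
    def __init__(self, num_labels, dropout=0.1):
        super().__init__()
        self.bert = BertModel.from_pretrained("bert-base-uncased")
        hidden = self.bert.config.hidden_size
        self.cls_fc = FCLayer(hidden, hidden, dropout)
        self.entity_fc = FCLayer(hidden, hidden, dropout)  # used for both entities
        self.classifier = FCLayer(3 * hidden, num_labels, dropout, use_tanh=False)

    @staticmethod
    def entity_average(hidden_states, entity_mask):
        # hidden_states: (B, L, H); entity_mask: (B, L), 1 on the entity span.
        mask = entity_mask.unsqueeze(-1).float()        # (B, L, 1)
        summed = (hidden_states * mask).sum(dim=1)      # (B, H)
        return summed / mask.sum(dim=1).clamp(min=1.0)  # mean, no div-by-zero

    def forward(self, input_ids, attention_mask, e1_mask, e2_mask):
        hidden_states = self.bert(input_ids, attention_mask=attention_mask)[0]
        cls_vec = self.cls_fc(hidden_states[:, 0])  # steps 1-2: [CLS] vector
        e1_vec = self.entity_fc(self.entity_average(hidden_states, e1_mask))
        e2_vec = self.entity_fc(self.entity_average(hidden_states, e2_mask))
        concat = torch.cat([cls_vec, e1_vec, e2_vec], dim=-1)  # step 3
        return self.classifier(concat)                         # step 4: logits
```
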
- Exactly the SAME conditions as written in the paper.
  - Averaging over the `entity_1` and `entity_2` hidden state vectors, respectively (including the $ and # marker tokens); see the toy illustration after this list.
  - Dropout and Tanh before the fully-connected layer.
  - No [SEP] token at the end of the sequence. (If you want to add the [SEP] token, give the `--add_sep_token` option.)
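
As a toy illustration of that averaging (made-up token positions, not the repo's preprocessing code): entity_1 is wrapped in `$` markers, entity_2 in `#` markers, and each mask covers the whole marked span, markers included.

```python
# Hypothetical tokenized sentence with entity markers already inserted.
tokens = ["[CLS]", "the", "$", "author", "$", "wrote", "a", "#", "novel", "#", "."]
e1_mask = [1 if 2 <= i <= 4 else 0 for i in range(len(tokens))]  # "$ author $"
e2_mask = [1 if 7 <= i <= 9 else 0 for i in range(len(tokens))]  # "# novel #"
```
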
- perl (for evaluating the official F1 score)
- python>=3.6
- torch==1.6.0
- transformers==3.3.1

```bash
$ python3 main.py --do_train --do_eval
```

- Predictions will be written to `proposed_answers.txt` in the `eval` directory (format sketched below).
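
If you want to peek at that file, each line pairs a test-sentence id with a predicted relation in the tab-separated answer-key format the official scorer reads; the ids and labels below are made-up examples.

```python
# Hypothetical peek at the first few predictions.
with open("eval/proposed_answers.txt") as f:
    for line in f.readlines()[:3]:
        print(line.rstrip())  # e.g. "8001\tComponent-Whole(e2,e1)"
```
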
```bash
$ python3 official_eval.py
# macro-averaged F1 = 88.29%
```

- Evaluate based on the official evaluation perl script.
  - MACRO-averaged F1 score (except the `Other` relation)
- You can see the detailed result in `result.txt` in the `eval` directory.
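
For intuition, the scoring rule (per-relation F1 where the direction must also match, `Other` excluded, then macro-averaged) can be sketched in Python as below; the perl script remains the authoritative scorer.

```python
from collections import defaultdict


def base(label):
    """'Cause-Effect(e1,e2)' -> 'Cause-Effect'."""
    return label.split("(")[0]


def macro_f1_excluding_other(golds, preds):
    tp, n_pred, n_gold = defaultdict(int), defaultdict(int), defaultdict(int)
    for g, p in zip(golds, preds):
        if base(p) != "Other":
            n_pred[base(p)] += 1
        if base(g) != "Other":
            n_gold[base(g)] += 1
        if g == p and base(g) != "Other":  # the direction must match too
            tp[base(g)] += 1
    f1s = []
    for r in n_gold:  # macro-average over the 9 relation types
        prec = tp[r] / n_pred[r] if n_pred[r] else 0.0
        rec = tp[r] / n_gold[r] if n_gold[r] else 0.0
        f1s.append(2 * prec * rec / (prec + rec) if prec + rec else 0.0)
    return sum(f1s) / len(f1s)
```
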
```bash
$ python3 predict.py --input_file {INPUT_FILE_PATH} --output_file {OUTPUT_FILE_PATH} --model_dir {SAVED_CKPT_PATH}
```
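
The input file holds one sentence per line with the two entities marked up in SemEval style; the file name and sentence below are hypothetical, so check any sample input file in the repo for the exact format.

```python
# Hypothetical input file: one sentence per line, entities wrapped in
# <e1>...</e1> / <e2>...</e2> tags as in the SemEval-2010 Task 8 data.
with open("pred_in.txt", "w") as f:
    f.write("The <e1>author</e1> of the <e2>novel</e2> won an award.\n")
```

Running the command above on such a file should write one predicted relation label per line to `{OUTPUT_FILE_PATH}`.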