def main():
    """Parse the command line, then train or evaluate a tagging model.

    The function builds the argument parser, seeds the RNGs, creates the
    checkpoint directory, instantiates the model (``ZSTaggerModel`` for
    ``zs-crf``/``proto``, ``TaggerModel`` otherwise) and the data module
    (``TaggerDataModule`` for ACE, ``DocTaggerDataModule`` for KAIROS), and
    finally runs ``trainer.test`` (``--eval_only``) or ``trainer.fit``.

    Raises:
        NotImplementedError: if ``--dataset`` names a dataset that has no
            data module wired up (currently ``ERE``).
    """
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        choices=['crf', 'zs-crf', 'proto'],
    )
    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        choices=['ACE', 'ERE', 'KAIROS'],
    )
    parser.add_argument(
        '--task',
        type=str,
        default='trigger',
        choices=['trigger', 'arg'],
        help='whether to perform trigger extraction or argument extraction.',
    )
    parser.add_argument(
        "--pretrained_model",
        type=str,
        default='roberta-base',
    )
    parser.add_argument(
        '--tmp_dir',
        type=str,
        default='tag_data',
        help='temporary directory for saving the preprocessed data. If this exists, will directly read from dir.',
    )

    parser.add_argument(
        '--event_n',
        type=int,
        default=33,
    )
    parser.add_argument('--no_projection', action='store_true')
    parser.add_argument('--token_classification', action='store_true')
    parser.add_argument('--use_pl', action='store_true')
    parser.add_argument('--proj_dim', type=int, default=500)
    parser.add_argument('--use_transition', action='store_true')
    parser.add_argument('--reg_weight', type=float, default=0.5)
    parser.add_argument('--use_bilstm', action='store_true')
    parser.add_argument('--bilstm_hidden_size', type=int, default=100)
    parser.add_argument('--bilstm_dropout', type=float, default=0.5)
    parser.add_argument(
        "--ckpt_name",
        default=None,
        type=str,
        help="The output directory where the model checkpoints and predictions will be written.",
    )
    parser.add_argument(
        "--load_ckpt",
        default=None,
        type=str,
    )
    parser.add_argument(
        "--train_file",
        default=None,
        type=str,
        help="The input training file. If a data dir is specified, will look for the file there"
        "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
    )
    parser.add_argument(
        "--val_file",
        default=None,
        type=str,
        help="The input evaluation file. If a data dir is specified, will look for the file there"
        "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
    )
    parser.add_argument(
        '--test_file',
        type=str,
        default=None,
    )
    parser.add_argument("--train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.")
    parser.add_argument(
        "--eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation."
    )
    parser.add_argument(
        "--eval_only", action="store_true",
    )
    parser.add_argument("--bert_learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
    parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
    parser.add_argument(
        "--accumulate_grad_batches",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
    parser.add_argument("--bert_weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
    parser.add_argument("--bert_dropout", default=0.5, type=float, help="Dropout after BERT encoding.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
    parser.add_argument("--gradient_clip_val", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument(
        "--num_train_epochs", default=3, type=int, help="Total number of training epochs to perform."
    )
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
    )
    parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")

    # NOTE(review): no ``type=`` here, so a value given on the command line
    # reaches the Trainer as a string while the default is the int -1.
    parser.add_argument("--gpus", default=-1, help='-1 means train on all gpus')
    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
    )
    parser.add_argument("--threads", type=int, default=1, help="multiple threads for converting example to features")
    args = parser.parse_args()

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Set seed
    seed_everything(args.seed)

    logger.info("Training/evaluation parameters %s", args)

    # Fail early with a clear message: 'ERE' is an accepted --dataset choice
    # but no data module is wired up for it, which previously surfaced much
    # later as a NameError on the unbound ``dm``.
    data_module_classes = {
        'ACE': TaggerDataModule,
        'KAIROS': DocTaggerDataModule,
    }
    if args.dataset not in data_module_classes:
        raise NotImplementedError(
            'no data module is available for dataset {}'.format(args.dataset))

    if not args.ckpt_name:
        time_str = datetime.now().strftime('%m-%dT%H%M')
        args.ckpt_name = '{}_{}lr{}_{}'.format(
            args.model,
            args.train_batch_size * args.accumulate_grad_batches,
            args.learning_rate,
            time_str,
        )

    args.ckpt_dir = f'./checkpoints/{args.ckpt_name}'
    # exist_ok: re-using a checkpoint name (e.g. for --eval_only) must not crash.
    os.makedirs(args.ckpt_dir, exist_ok=True)

    checkpoint_callback = ModelCheckpoint(
        dirpath=args.ckpt_dir,
        save_top_k=1,
        monitor='val/f1',
        mode='max',
        save_weights_only=True,
        filename='best',
    )

    early_stop_callback = EarlyStopping(
        monitor='val/f1',
        min_delta=0.0,
        patience=5,
        verbose=False,
        mode='max',
    )

    lr_logger = LearningRateMonitor()
    tb_logger = TensorBoardLogger('logs/')

    if args.model in ['zs-crf', 'proto']:
        model = ZSTaggerModel(args)
    else:
        model = TaggerModel(args)
    dm = data_module_classes[args.dataset](args)

    if args.max_steps < 0:
        # Stored on ``args`` (shared with the model and data module above).
        args.max_epochs = args.min_epochs = args.num_train_epochs

    # NOTE(review): min_epochs is fixed at 10 although max_epochs follows
    # --num_train_epochs (default 3) — confirm this is intended.
    trainer_kwargs = dict(
        logger=tb_logger,
        min_epochs=10,
        max_epochs=args.num_train_epochs,
        gpus=args.gpus,
        checkpoint_callback=checkpoint_callback,
        accumulate_grad_batches=args.accumulate_grad_batches,
        gradient_clip_val=args.gradient_clip_val,
        num_sanity_val_steps=0,
        val_check_interval=1.0,  # use float to check every n epochs
        precision=16 if args.fp16 else 32,
        callbacks=[lr_logger, early_stop_callback],
        reload_dataloaders_every_epoch=True,
    )
    if args.max_steps > 0:
        # --max_steps used to be parsed and then silently ignored.  It is only
        # forwarded when positive so the default run is configured exactly as
        # before.
        trainer_kwargs['max_steps'] = args.max_steps
    trainer = Trainer(**trainer_kwargs)

    if args.load_ckpt:
        checkpoint = torch.load(args.load_ckpt, map_location=model.device)
        model.load_state_dict(checkpoint['state_dict'])

    if args.eval_only:
        dm.setup('test')
        trainer.test(model, datamodule=dm)  # also loads training dataloader
    else:
        dm.setup('fit')
        trainer.fit(model, datamodule=dm)
tags = [0,] * len(tokens) + + for event in ex['event_mentions']: + event_type = event['event_type'] + if event_type not in ontology_dict: + continue + for idx in range(event['trigger']['start'], event['trigger']['end']):# idx not inclusive + tags[idx] = ontology_dict[event_type]['i-label'] + + return tags + + def get_arg_word_tags(self, ex, role_mapping): + ''' + Returns a list of word tags, one for each trigger. + ''' + tokens = ex['tokens'] + tag_list = [] + + + for event in ex['event_mentions']: + tags = [0,] * len(tokens) + center_sent = event['trigger']['sent_idx'] + trigger = event['trigger']['text'] + for arg in event['arguments']: + ent_id = arg['entity_id'] + role = arg['role'] + if role not in role_mapping: + continue + span = find_ent_span(ex, ent_id) # not inclusive of span end + + # tags[span[0]] = role_mapping[role]['b-label'] + + # for idx in range(span[0]+1, span[1]): + # tags[idx] = role_mapping[role]['i-label'] + + for idx in range(span[0], span[1]): + tags[idx] = role_mapping[role]['i-label'] + + tag_list.append({ + 'word_tags': tags, + 'trigger': trigger, + 'center_sent': center_sent + }) + + return tag_list + + + def get_labels(self, ex, word_tags, pretrained_model, start=0, end=-1): + if pretrained_model.startswith('roberta'): + if start!=0: + raise NotImplementedError + bpe_tokens = self.tokenizer.tokenize(ex['sentence'], add_prefix_space=True) + + widx =-1 + bpe_tags = [] + token_lens = [1, ] * len(word_tags) + bpe2word = [] + for b in bpe_tokens: + if b[0] == WORD_START_CHAR: + widx +=1 + else: + token_lens[widx]+=1 + bpe_tags.append(word_tags[widx]) + bpe2word.append(widx) + + assert(len(bpe_tags) == len(bpe_tokens)) + + labels = [0, ] + bpe_tags[:MAX_LENGTH-2] + [0,] # 0 for token and 0 for token + bpe2word = [-1, ] + bpe2word + [-1, ] + return labels, bpe2word, token_lens + + elif pretrained_model.startswith('bert'): + words = ex['tokens'][start:end] + bpe_tags = [] + token_lens = [1, ] * len(word_tags) + bpe2word = [] + for widx, w in 
enumerate(words): + bpe_tokens = self.tokenizer.tokenize(w) + token_lens[widx] = len(bpe_tokens) + bpe_tags.extend([word_tags[widx],] * len(bpe_tokens)) + bpe2word.extend([widx,] * len(bpe_tokens)) + + + labels = [0, ] + bpe_tags[:MAX_LENGTH-2] + [0,] # 0 for [CLS] token and 0 for [SEP] token + bpe2word = [-1, ] + bpe2word + [-1, ] + return labels, bpe2word, token_lens + else: + raise NotImplementedError + + + def prepare_data(self): + if not os.path.exists(self.hparams.tmp_dir): + os.makedirs(self.hparams.tmp_dir) + + ontology_dict = load_ontology(self.hparams.dataset) + role_mapping = load_role_mapping(self.hparams.dataset) + + for split,f in [('train',self.hparams.train_file), ('val',self.hparams.val_file), ('test',self.hparams.test_file)]: + with open(f,'r') as reader, open(os.path.join(self.hparams.tmp_dir, '{}.jsonl'.format(split)), 'w') as writer: + for line in reader: + ex = json.loads(line.strip()) + if split =='train' and len(ex['tokens']) < 4: + # removing headers + continue + if self.hparams.task == 'trigger': + word_tags = self.get_trigger_word_tags(ex, ontology_dict) + # token tags from word tags + # chunking + start = 0 + # while (start < len(ex['tokens'])): + for chunk_idx in range( len(ex['sentences'])): + # chunk by sentence + use_ex = True + sentence_tokens, sentence_text = ex['sentences'][chunk_idx] + sent_len = len(sentence_tokens) + chunk = (start, start+sent_len) + word_tags_chunk = word_tags[chunk[0]: chunk[1]] + tokens_chunk = ex['tokens'][chunk[0]: chunk[1]] + + start += sent_len + if sent_len < 4 and split=='train': + use_ex = False + try: + assert(len(tokens_chunk) <= MAX_CONTEXT_LENGTH) + except AssertionError: + print(len(tokens_chunk), split) + # discard this super long sentence + use_ex= False + + if use_ex: + tokenized = self.tokenizer.tokenize(' '.join(tokens_chunk)) + try: + assert(len(tokenized) <= MAX_LENGTH -2) + except AssertionError: + print('Original {}, tokenized {}'.format(len(tokens_chunk), len(tokenized))) + continue + 
#add_prefix_space=True if self.hparams.pretrained_model.startswith('roberta') else False, + input_tokens = self.tokenizer.encode_plus(' '.join(tokens_chunk), + add_special_tokens=True, + max_length=MAX_LENGTH, + truncation=True, padding=False) + + # is_pretokenized does not work with bpe at this version of transformers + labels, bpe2word, token_lens = self.get_labels(ex, word_tags_chunk, + self.hparams.pretrained_model, + start=chunk[0], + end=chunk[1]) + assert(len(labels) == len(input_tokens['input_ids'])) + assert(len(token_lens) == len(word_tags_chunk)) + processed_ex = { + 'doc_key': ex['doc_id'], + 'chunk_idx' : chunk_idx, + 'input_token_ids':input_tokens['input_ids'], + 'input_attn_mask': input_tokens['attention_mask'], + 'labels': labels, # bpe level labels, not used in current version + 'bpe_mapping': bpe2word, + 'token_lens': token_lens, + 'word_tags': word_tags_chunk, + } + writer.write(json.dumps(processed_ex) + '\n') + else: + word_tag_list = self.get_arg_word_tags(ex, role_mapping) + sentences_n = len(ex['sentences']) + for evt_idx, word_tag_dict in enumerate(word_tag_list): + use_ex = True + # find a three sentence window + + center_idx = word_tag_dict['center_sent'] + word_tags = word_tag_dict['word_tags'] + trigger = word_tag_dict['trigger'] + + # window = (max(0, center_idx-1), min(sentences_n, center_idx+1)) + # # get the start and end idx of this window + # chunk = get_chunk(ex, window) + # if (chunk[1]-chunk[0]) < 4: + # use_ex = False + # if (chunk[1]-chunk[0])> MAX_CONTEXT_LENGTH: + # # use only one sentence + # window = (center_idx, center_idx) + # chunk = get_chunk(ex, window) + + window = (center_idx, center_idx) + chunk = get_chunk(ex, window) + word_tags_chunk = word_tags[chunk[0]: chunk[1]] + tokens_chunk = ex['tokens'][chunk[0]: chunk[1]] + try: + assert(len(tokens_chunk) <= MAX_CONTEXT_LENGTH) + except AssertionError: + print(len(tokens_chunk), split) + # discard this super long sentence + use_ex= False + + if use_ex: + tokenized = 
self.tokenizer.tokenize(' '.join(tokens_chunk)) + try: + assert(len(tokenized) <= MAX_LENGTH -2) + except AssertionError: + print('Original {}, tokenized {}'.format(len(tokens_chunk), len(tokenized))) + continue + #add_prefix_space=True if self.hparams.pretrained_model.startswith('roberta') else False, + + # [CLS] tokens [SEP] trigger [SEP] + input_tokens = self.tokenizer.encode_plus( ' '.join(tokens_chunk), trigger, + add_special_tokens=True, + max_length=MAX_LENGTH, + truncation=True, padding=False) + + trigger_token_len = len(self.tokenizer.tokenize(trigger)) + # is_pretokenized does not work with bpe at this version of transformers + labels, bpe2word, token_lens = self.get_labels(ex, word_tags_chunk, + self.hparams.pretrained_model, + start=chunk[0], + end=chunk[1]) + labels.extend([0,] * (trigger_token_len +1) ) # bpe level labels for "trigger [SEP]" + bpe2word.extend([-1,]* (trigger_token_len +1) ) + + assert(len(labels) == len(input_tokens['input_ids'])) + assert(len(token_lens) == len(word_tags_chunk)) + + + if len(token_lens) > 0: + processed_ex = { + 'doc_key': '{}:{}'.format(ex['doc_id'], evt_idx), + 'chunk_idx' : 0, + 'input_token_ids':input_tokens['input_ids'], + 'input_attn_mask': input_tokens['attention_mask'], + 'labels': labels, # bpe level labels, not used in current version + 'bpe_mapping': bpe2word, + 'token_lens': token_lens, + 'word_tags': word_tags_chunk, + } + writer.write(json.dumps(processed_ex) + '\n') + + + + + + + + + def train_dataloader(self): + print('reading from {}'.format(self.hparams.tmp_dir)) + if self.hparams.use_pl: + dataset = IEDataset(os.path.join(self.hparams.tmp_dir, 'pl_train.jsonl'), split='pl') + else: + dataset = IEDataset(os.path.join(self.hparams.tmp_dir,'train.jsonl'),split='train') + + dataloader = DataLoader(dataset, + pin_memory=True, num_workers=4, + collate_fn=adaptive_length_collate, + batch_size=self.hparams.train_batch_size, + shuffle=True) + return dataloader + + + def val_dataloader(self): + dataset = 
IEDataset(os.path.join(self.hparams.tmp_dir, 'val.jsonl'),split='val') + + dataloader = DataLoader(dataset, pin_memory=True, num_workers=4, + collate_fn=adaptive_length_collate, + batch_size=self.hparams.eval_batch_size, shuffle=False) + return dataloader + + def test_dataloader(self): + dataset = IEDataset(os.path.join(self.hparams.tmp_dir,'test.jsonl'),split='test') + + dataloader = DataLoader(dataset, pin_memory=True, num_workers=4, + collate_fn=adaptive_length_collate, + batch_size=self.hparams.eval_batch_size, shuffle=False) + + return dataloader + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--train-file',type=str,default='data/kairos/train.jsonl') + parser.add_argument('--val-file', type=str, default='data/kairos/dev.jsonl') + parser.add_argument('--test-file', type=str, default='data/kairos/test.jsonl') + parser.add_argument('--pretrained-model', type=str, default='bert-large-cased') + parser.add_argument('--task', default='arg') + parser.add_argument('--dataset', default='KAIROS') + parser.add_argument('--train_batch_size', type=int, default=2) + parser.add_argument('--eval_batch_size', type=int, default=4) + parser.add_argument('--tmp-dir', type=str, default='tag_kairos_arg') + parser.add_argument('--use-pl', action='store_true', default=False) + args = parser.parse_args() + + dm = DocTaggerDataModule(args=args) + dm.prepare_data() + + # training dataloader + dataloader = dm.train_dataloader() + + for idx, batch in enumerate(dataloader): + print(batch) + if idx==5: + break + + # val dataloader \ No newline at end of file diff --git a/zstagger/__init__.py b/zstagger/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/zstagger/convert_output_for_arg_ext.py b/zstagger/convert_output_for_arg_ext.py new file mode 100644 index 0000000..bebe0fc --- /dev/null +++ b/zstagger/convert_output_for_arg_ext.py @@ -0,0 +1,80 @@ +''' +This file takes a prediction from the tagger model and converts it into the 
def get_tag_mapping(ontology_dict):
    """Invert the ontology into a ``tag id -> event type`` lookup.

    Args:
        ontology_dict: mapping from event type name to a dict that contains
            at least the key ``'i-label'`` (the tag id used by the tagger).

    Returns:
        dict mapping each ``'i-label'`` value to its event type name.  If two
        event types share a tag, the one iterated last wins.
    """
    return {
        info['i-label']: event_type
        for event_type, info in ontology_dict.items()
    }


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--pred-file', type=str)
    parser.add_argument('--ref-file', type=str)
    parser.add_argument('--output-file', type=str)
    parser.add_argument('--dataset', type=str, default='ACE')
    args = parser.parse_args()

    ontology_dict = load_ontology(args.dataset)
    tag2event_type = get_tag_mapping(ontology_dict)

    # Predictions are one JSON object keyed by the example's 'sent_id'.
    with open(args.pred_file) as f:
        predictions = json.load(f)

    # The reference file is JSON-lines: one gold example per line.
    with open(args.ref_file) as f:
        gold_exs = [json.loads(line) for line in f]

    total_pred = 0
    total_gold = 0
    # ``with`` guarantees the output file is flushed and closed even when a
    # lookup below raises part-way through.
    with open(args.output_file, 'w') as writer:
        for ex in gold_exs:
            doc_key = ex['sent_id']
            pred_tags = predictions[doc_key]['pred_tags']
            ex['pred_tags'] = pred_tags
            pred_exs = []
            if sum(pred_tags) > 0:
                # List of (start, end, event_type) tuples.
                pred_triggers = get_pred_tgr_mentions(ex, tag2event_type)
                for start, end, evt_type in pred_triggers:
                    trigger_text = ' '.join(ex['tokens'][start:end])
                    pred_exs.append({
                        'event_type': evt_type,
                        'arguments': [],
                        'trigger': {
                            'start': start,
                            'end': end,
                            'text': trigger_text,
                        },
                    })
                    total_pred += 1
            total_gold += len(ex['event_mentions'])
            # Replace the gold mentions with the predicted triggers.
            ex['event_mentions'] = pred_exs
            writer.write(json.dumps(ex) + '\n')

    print(total_pred)
    print(total_gold)
Dataset + +# from ACE_data_module import MAX_LENGTH +# For ACE triggers +# MAX_LENGTH=200 +# For KAIROS +MAX_LENGTH=400 #this needs to be a hyperparameter + + +def adaptive_length_collate(batch): + ''' + 'doc_key': ex['sent_id'], + 'input_token_ids':input_tokens['input_ids'], + 'input_attn_mask': input_tokens['attention_mask'], + 'labels': labels, + 'bpe_mapping': bpe2word, + 'token_lens': token_lens, + 'word_tags': word_tags, + ''' + doc_keys = [ex['doc_key'] for ex in batch] + token_lens = [ex['token_lens'] for ex in batch] + max_len = min(max([sum(ex['input_attn_mask']) for ex in batch]), MAX_LENGTH) + batch_size = len(batch) + + input_token_ids = torch.ones((batch_size, max_len), dtype=torch.long) + input_attn_mask = torch.zeros((batch_size, max_len), dtype=torch.bool) + bpe_mapping = torch.ones((batch_size, max_len), dtype=torch.long) * (-1) + labels = torch.zeros((batch_size, max_len), dtype=torch.long) + word_lengths = [] + + for i in range(batch_size): + ex = batch[i] + word_lengths.append(len(ex['word_tags'])) + l = min(sum(ex['input_attn_mask']), MAX_LENGTH) + input_token_ids[i, :l] = torch.LongTensor(ex['input_token_ids'][:MAX_LENGTH]) + input_attn_mask[i, :l] = torch.BoolTensor(ex['input_attn_mask'][:MAX_LENGTH]) + labels[i, :l] = torch.LongTensor(ex['labels'][:MAX_LENGTH]) + bpe_mapping[i, :l] = torch.LongTensor(ex['bpe_mapping'][:MAX_LENGTH]) + + + max_word_len = min(max(word_lengths), MAX_LENGTH) + word_tags = torch.ones((batch_size, max_word_len), dtype=torch.long) *(-1) + for i in range(batch_size): + ex = batch[i] + l = min(len(ex['word_tags']), MAX_LENGTH) + word_tags[i, :l] = torch.LongTensor(ex['word_tags'][:MAX_LENGTH]) + + chunk_idx = [ex['chunk_idx'] if 'chunk_idx' in ex else 0 for ex in batch ] + return { + 'input_token_ids': input_token_ids, + 'input_attn_mask': input_attn_mask, + 'labels': labels, + 'doc_key': doc_keys, + 'bpe_mapping': bpe_mapping, + 'token_lens': token_lens, + 'word_lengths': torch.LongTensor(word_lengths), + 
'word_tags': word_tags, + 'chunk_idx': chunk_idx + } + + +class IEDataset(Dataset): + def __init__(self, input_file, split): + super().__init__() + self.examples = [] + self.pos_examples = [] + self.neg_examples = [] + with open(input_file, 'r') as f: + for line in f: + ex = json.loads(line.strip()) + if sum(ex['word_tags']) == 0: + # no event mention + self.neg_examples.append(ex) + else: + self.pos_examples.append(ex) + + self.examples.append(ex) + + if split == 'train' and len(self.neg_examples) > 2*len(self.pos_examples): + # downsample negatives + + def seed_random(): + t = int( time.time() * 1000.0 ) + random.seed( ((t & 0xff000000) >> 24) + + ((t & 0x00ff0000) >> 8) + + ((t & 0x0000ff00) << 8) + + ((t & 0x000000ff) << 24) ) + seed_random() + K = len(self.pos_examples) + selected_negs = random.sample(self.neg_examples, k=K*2) + self.examples = self.pos_examples + selected_negs + + + def __len__(self): + return len(self.examples) + + def __getitem__(self, idx): + return self.examples[idx] + + diff --git a/zstagger/data_module.py b/zstagger/data_module.py new file mode 100644 index 0000000..5140003 --- /dev/null +++ b/zstagger/data_module.py @@ -0,0 +1,164 @@ +import os +import json +import argparse + +from transformers import AutoTokenizer +from torch.utils.data import DataLoader +import pytorch_lightning as pl + +from .data import IEDataset, adaptive_length_collate +from .utils import load_ontology +MAX_LENGTH=200 +WORD_START_CHAR='\u0120' + +class TaggerDataModule(pl.LightningDataModule): + def __init__(self, args): + super().__init__() + self.hparams = args + self.tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model) + + + def get_word_tags(self, ex, ontology_dict): + tokens = ex['tokens'] + tags = [0,] * len(tokens) + + for event in ex['event_mentions']: + event_type = event['event_type'] + for idx in range(event['trigger']['start'], event['trigger']['end']): + tags[idx] = ontology_dict[event_type]['i-label'] + + return tags + + def 
get_labels(self, ex, word_tags, pretrained_model): + if pretrained_model.startswith('roberta'): + bpe_tokens = self.tokenizer.tokenize(ex['sentence'], add_prefix_space=True) + + widx =-1 + bpe_tags = [] + token_lens = [1, ] * len(word_tags) + bpe2word = [] + for b in bpe_tokens: + if b[0] == WORD_START_CHAR: + widx +=1 + else: + token_lens[widx]+=1 + bpe_tags.append(word_tags[widx]) + bpe2word.append(widx) + + assert(len(bpe_tags) == len(bpe_tokens)) + + labels = [0, ] + bpe_tags[:MAX_LENGTH-2] + [0,] # 0 for token and 0 for token + bpe2word = [-1, ] + bpe2word + [-1, ] + return labels, bpe2word, token_lens + + elif pretrained_model.startswith('bert'): + words = ex['tokens'] + bpe_tags = [] + token_lens = [1, ] * len(word_tags) + bpe2word = [] + for widx, w in enumerate(words): + bpe_tokens = self.tokenizer.tokenize(w) + token_lens[widx] = len(bpe_tokens) + bpe_tags.extend([word_tags[widx],] * len(bpe_tokens)) + bpe2word.extend([widx,] * len(bpe_tokens)) + + + labels = [0, ] + bpe_tags[:MAX_LENGTH-2] + [0,] # 0 for [CLS] token and 0 for [SEP] token + bpe2word = [-1, ] + bpe2word + [-1, ] + return labels, bpe2word, token_lens + + + def prepare_data(self): + if not os.path.exists(self.hparams.tmp_dir): + print('preprocessing data to {}'.format(self.hparams.tmp_dir)) + + os.makedirs(self.hparams.tmp_dir) + + ontology_dict = load_ontology(self.hparams.dataset) + + for split,f in [('train',self.hparams.train_file), ('val',self.hparams.val_file), ('test',self.hparams.test_file)]: + with open(f,'r') as reader, open(os.path.join(self.hparams.tmp_dir, '{}.jsonl'.format(split)), 'w') as writer: + for line in reader: + ex = json.loads(line.strip()) + if split =='train' and len(ex['tokens']) < 4: + # removing headers + continue + word_tags = self.get_word_tags(ex, ontology_dict) + # token tags from word tags + + input_tokens = self.tokenizer.encode_plus(ex['sentence'], + add_prefix_space=True if self.hparams.pretrained_model.startswith('roberta') else False, + 
def train_dataloader(self):
    """Build the shuffled training ``DataLoader``.

    Reads ``pl_train.jsonl`` (with ``split='pl'``) when ``hparams.use_pl`` is
    set and ``train.jsonl`` (with ``split='train'``) otherwise; both files
    are looked up inside ``hparams.tmp_dir``.

    Returns:
        A ``DataLoader`` over an ``IEDataset`` that batches with
        ``adaptive_length_collate`` and ``hparams.train_batch_size``.
    """
    if self.hparams.use_pl:
        file_name, split = 'pl_train.jsonl', 'pl'
    else:
        # This path used to be hard-coded to 'tag_data/train.jsonl', which
        # ignored --tmp_dir and disagreed with the val/test loaders.
        file_name, split = 'train.jsonl', 'train'

    dataset = IEDataset(os.path.join(self.hparams.tmp_dir, file_name), split=split)

    return DataLoader(
        dataset,
        pin_memory=True,
        num_workers=4,
        collate_fn=adaptive_length_collate,
        batch_size=self.hparams.train_batch_size,
        shuffle=True,
    )
def token_lens_to_idxs(token_lens):
    """Map per-word piece counts to a flat gather index list and mask list.

    Every word gets ``max_token_len`` slots.  Real slots hold the running
    word-piece offset within the sequence and the weight ``1 / token_len``;
    unused slots hold index ``-1`` and weight ``0.0``.  Sequences with fewer
    than ``max_token_num`` words are padded with all-unused words.

    For example (a batch with a single sequence)::

        token lengths: [1, 1, 1, 3, 1]
        =>
        indices: [0,-1,-1,  1,-1,-1,  2,-1,-1,  3,4,5,  6,-1,-1]
        masks:   [1,0,0,    1,0,0,    1,0,0,    1/3,1/3,1/3,  1,0,0]

    The indices do not account for a leading special token; a caller that
    gathers from encoder output must shift them itself.

    A word with zero pieces contributes only unused slots (it used to raise
    ``ZeroDivisionError``), and an empty batch returns ``([], [], 0, 0)``.

    :param token_lens (list): one list of word-piece counts per sequence.
    :return: ``(idxs, masks, max_token_num, max_token_len)`` where ``idxs``
        and ``masks`` hold one flat list of length
        ``max_token_num * max_token_len`` per sequence.
    """
    max_token_num = max((len(seq) for seq in token_lens), default=0)
    max_token_len = max((max(seq, default=0) for seq in token_lens), default=0)

    idxs, masks = [], []
    for seq_token_lens in token_lens:
        seq_idxs, seq_masks = [], []
        offset = 0
        for token_len in seq_token_lens:
            unused = max_token_len - token_len
            seq_idxs.extend(range(offset, offset + token_len))
            seq_idxs.extend([-1] * unused)
            if token_len:
                # Equal weights so that a masked sum is the mean of the pieces.
                seq_masks.extend([1.0 / token_len] * token_len)
            seq_masks.extend([0.0] * unused)
            offset += token_len

        # Pad short sequences up to the longest one in the batch.
        missing = max_token_len * (max_token_num - len(seq_token_lens))
        seq_idxs.extend([-1] * missing)
        seq_masks.extend([0.0] * missing)

        idxs.append(seq_idxs)
        masks.append(seq_masks)
    return idxs, masks, max_token_num, max_token_len
torch.cat([bert_outputs, extra_bert_outputs], dim=2) # (batch, bpe_max_len, 2*hidden_dim) + bert_outputs = self.bert_pooler(bert_outputs) # (batch, bpe_max_len, hidden_dim) + + # average all pieces for multi-piece words + assert(len(token_lens) > 0) + idxs, masks, token_num, token_len = token_lens_to_idxs(token_lens) + idxs = input_ids.new(idxs).unsqueeze(-1).expand(batch_size, -1, self.bert_dim) + 1 # shift 1 for [CLS] + masks = bert_outputs.new(masks).unsqueeze(-1) + bert_outputs = torch.gather(bert_outputs, 1, idxs) * masks + bert_outputs = bert_outputs.view(batch_size, token_num, token_len, self.bert_dim) + bert_outputs = bert_outputs.sum(2) + + bert_outputs = self.bert_dropout(bert_outputs) + + return bert_outputs + + +class PrototypeNetworkHead(nn.Module): + def __init__(self,configs, feature_size, n_classes, class_vectors): + super(PrototypeNetworkHead, self).__init__() + self.configs = configs + C = n_classes + self.projection = MLP(dimensions=[feature_size, n_classes], activation='relu', dropout_prob=0.2) + self.transition = nn.Linear(C, C) # For CRF + # normalize class vectors + class_vectors = F.normalize(class_vectors, dim=1) + self.class_vectors = nn.Parameter(class_vectors, requires_grad=False) + null_vec = F.normalize(torch.rand(1,feature_size), dim=1) + self.null_vec = nn.Parameter(null_vec, requires_grad=False) # for class null + + def forward(self, features, lengths): + class_vectors = torch.cat([self.null_vec, self.class_vectors], dim=0) # train the class vectors as well + if self.configs.no_projection: + # for the complete zero-shot setting + final = torch.einsum('ijk,lk->ijl', features, class_vectors) + else: + projected_x = self.projection(features) + projected_class = self.projection(class_vectors) + final = torch.einsum('ijk,lk->ijl', projected_x, projected_class) + # CRF + batch, N, C = final.shape + + if self.configs.token_classification: + return final + + if self.configs.no_projection: + vals = final.view(batch, N, C, 1)[:, 1:N] + vals = 
class ZeroShotCollapsedTransitionCRFHead(nn.Module):
    """CRF head that scores tokens against projected reference vectors.

    Token features and the first ``C`` reference vectors are projected with
    the matrix ``M``; emission scores are their dot products.  Transition
    scores are "collapsed": only self transitions (i -> i) and transitions
    from/to class 0 (the null class) are scored, each through a learned
    diagonal matrix.

    Args:
        configs: namespace read for ``token_classification`` and
            ``use_transition``.
        feature_size: size of the incoming token features.
        n_classes: number of classes ``C``, counting the null class (row 0).
        proj_dim: number of columns of the projection ``M``.
        class_vectors: one row per non-null class, ``feature_size`` columns.
        max_classes: number of reference vectors ``m``; defaults to
            ``n_classes``.
    """

    def __init__(self, configs, feature_size, n_classes, proj_dim, class_vectors, max_classes=None):
        super(ZeroShotCollapsedTransitionCRFHead, self).__init__()
        self.configs = configs
        self.C = n_classes
        self.m = max_classes if max_classes else n_classes
        self.proj_dim = proj_dim
        self.feature_size = feature_size

        # frozen, L2-normalised class vectors plus a random vector for class null
        class_vectors = F.normalize(class_vectors, dim=1)
        self.class_vectors = nn.Parameter(class_vectors, requires_grad=False)
        null_vec = F.normalize(torch.rand(1, feature_size), dim=1)
        self.null_vec = nn.Parameter(null_vec, requires_grad=False)

        # Reference vectors: near-identity rows with a little noise
        noise_scale = 0.01
        self.reference_vecs = nn.Parameter(
            torch.eye(self.m, feature_size, dtype=torch.float)
            + torch.rand((self.m, feature_size)) * noise_scale)

        # Transition matrices (stored as their diagonals)
        self.self_transition_diag = nn.Parameter(torch.rand(proj_dim))
        self.null_transition_diag = nn.Parameter(torch.rand(proj_dim))

        # projection matrix
        self.register_buffer('M', torch.zeros((feature_size, proj_dim)))
        print('initializing projection M...')
        # The original discarded the return value here, which left M all-zero
        # (and every score zero) until update_params() happened to be called.
        self.M = self.update_projection()

    @staticmethod
    def _complete_qr_q(mat):
        """Return Q of the complete QR decomposition of ``mat``."""
        linalg = getattr(torch, 'linalg', None)
        if linalg is not None and hasattr(linalg, 'qr'):
            Q, _ = linalg.qr(mat, mode='complete')
        else:
            # older torch releases only ship the (now deprecated) torch.qr
            Q, _ = torch.qr(mat, some=False)
        return Q

    def update_projection(self):
        '''
        Compute the projection from the current class and reference vectors.

        Output:
        M: projection matrix (feature_size, proj_dim)
        '''
        class_vectors = torch.cat([self.null_vec, self.class_vectors], dim=0).detach()  # (C, feature_size)
        ref_vec = self.reference_vecs.detach()
        C = class_vectors.size(0)
        # TODO: compute modified reference vector
        mod_ref_vec = (ref_vec - torch.mean(ref_vec, dim=0)) * C / (C - 1)

        D = F.normalize(class_vectors, dim=1) - F.normalize(mod_ref_vec[:C, :], dim=1)
        # columns C.. of the complete Q are orthogonal to every row of D
        Q = self._complete_qr_q(D.transpose(0, 1))
        M = Q[:, C:(C + self.proj_dim)]  # take proj_dim columns from Q
        assert(M.size(0) == self.feature_size)
        assert(M.size(1) == self.proj_dim)
        return M

    def compute_transition(self, projected_ref, projected_x):
        '''
        projected_x: (batch, seq, proj_dim)
        projected_ref: (m, proj_dim)

        Output:
        transition matrix (batch, seq, m, m); only the diagonal, row 0 and
        column 0 are non-zero, and entry [0, 0] holds the null score.
        '''
        batch, seq, _ = projected_x.shape
        m = projected_ref.size(0)
        flat_x = projected_x.reshape(-1, self.proj_dim)

        # multiplying by a diagonal matrix is an element-wise scaling
        trans_ii = (flat_x * self.self_transition_diag).matmul(projected_ref.transpose(0, 1))
        trans_ii = trans_ii.reshape(batch, seq, m)
        trans_0 = (flat_x * self.null_transition_diag).matmul(projected_ref[0, :])
        trans_0 = trans_0.reshape(batch, seq, 1)

        full_transition = torch.zeros((batch, seq, m, m)).to(projected_ref.device)
        torch.diagonal(full_transition, dim1=2, dim2=3).copy_(trans_ii)
        # written after the diagonal so [0, 0] ends up as the null score,
        # exactly as in the original per-class loop
        full_transition[:, :, 0, :] = trans_0
        full_transition[:, :, :, 0] = trans_0
        return full_transition

    def update_params(self):
        """Recompute ``M`` from the current reference vectors."""
        self.M = self.update_projection()
        return

    def normalize_params(self):
        """L2-normalise every reference vector in place."""
        self.reference_vecs.data = F.normalize(self.reference_vecs, dim=1)
        return

    def regularize_params(self):
        '''
        Orthonormal regularization of the reference vectors.
        '''
        gram = self.reference_vecs.matmul(self.reference_vecs.t())
        identity = torch.eye(self.m).to(self.reference_vecs.device)
        return torch.norm(gram - identity, p=2)

    def forward(self, features, lengths):
        """Score ``features`` (batch, seq, feature_size).

        Returns the raw scores (batch, seq, C) when
        ``configs.token_classification`` is set, otherwise a
        ``LinearChainCRF`` distribution built with ``lengths``.
        """
        projected_x = torch.matmul(features, self.M)  # (batch, seq, proj_dim)
        projected_ref = torch.matmul(self.reference_vecs[:self.C, :], self.M)  # (C, proj_dim)

        final = torch.einsum('ijk,lk->ijl', projected_x, projected_ref)  # (batch, seq, C)

        if self.configs.token_classification:
            return final

        # CRF
        batch, N, C = final.shape
        if self.configs.use_transition:
            # only build the (batch, seq, C, C) tensor when it is actually used
            transition = self.compute_transition(projected_ref, projected_x)
            vals = final.view(batch, N, C, 1)[:, 1:N] + transition[:, 1:N, :, :]
        else:
            vals = final.view(batch, N, C, 1)[:, 1:N]
            vals = vals.expand(batch, N - 1, C, C).clone()
        vals[:, 0, :, :] += final.view(batch, N, 1, C)[:, 0]
        dist = LinearChainCRF(vals, lengths=lengths)
        return dist
encoder_size) + m: number of reference vectors + d: dimension of projection + + Output: + M: projection matrix (encoder_size, d) + ''' + class_vectors = torch.cat([self.class_vectors, self.null_vec], dim=0) + C, encoder_size = class_vectors.shape + D = F.normalize(class_vectors, dim=1) - F.normalize(self.reference_vecs.weight[:C], dim=1) + # QR decomposition of D + Q,R = torch.qr(D, some=False) # complete QR + M = Q[:, C: (C+self.proj_dim+1)] # take proj_dim columns from Q + + return M + + + + def compute_transition(self, C): + ''' + M: projection matrix (encoder_size, d) + transition_mat: 2d*c *c + + + Output: + transition matrix + ''' + projected_ref = torch.matmul(self.reference_vecs, self.M) # (m, proj_dim) + # compute self-attention + cross_attn = torch.einsum("ij,ij->ii", projected_ref, projected_ref) # (m, m) + attn_scores= F.softmax(cross_attn, dim=1) #(m,m) + attn_vecs = torch.einsum("ik,ij->ikj", attn_scores, projected_ref).sum(dim=1) # (m, proj_dim) + transition = torch.zeros((C, C)).to(self.reference_vecs.device) + for i in range(C): + for j in range(C): + score = torch.dot(self.transition_mat, torch.cat([attn_vecs[i,:], attn_vecs[j,:]], dim=0)) + transition[i,j] = score + + return transition + + + def update_params(self): + self.M = self.update_projection() + self.transition = self.compute_transition(self.C) + return + + def forward(self, features, lengths): + + x = features # (batch, seq, encoder_size) + projected_x = torch.matmul(x, self.M) # (batch, seq, proj_dim) + projected_ref = torch.matmul(self.reference_vecs[:self.C, :], self.M) # (m, proj_dim) + final = torch.einsum('ijk,lk->ijl', projected_x, projected_ref) # (batch, seq, m) + # CRF + batch, N, C = final.shape + vals = final.view(batch, N, C, 1)[:, 1:N] + self.transition.view(1, 1, C, C) + vals[:, 0, :, :] += final.view(batch, N, 1, C)[:, 0] + dist = LinearChainCRF(vals, lengths=lengths) + return dist diff --git a/zstagger/model.py b/zstagger/model.py new file mode 100644 index 
class TaggerModel(pl.LightningModule):
    """BERT encoder followed by a linear-chain CRF head for IO tagging.

    ``args`` supplies every hyper-parameter read below through
    ``self.hparams`` (``pretrained_model``, ``event_n``, ``task``,
    ``dataset``, ``val_file``, ``test_file``, ``ckpt_name`` and the
    optimizer settings).
    """

    def __init__(self, args):
        super().__init__()
        self.hparams = args

        self.config = AutoConfig.from_pretrained(args.pretrained_model)
        self.tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model)
        self.n_classes = self.hparams.event_n + 1  # IO tagging
        # network internals
        self.encoder = BERTEncoder(args, self.config.hidden_size)
        self.feature_size = self.config.hidden_size

        self.head = CRFFeatureHead(self.hparams, self.feature_size, self.n_classes)

    def forward(self, inputs, stage='training'):
        """Return the CRF loss when ``stage == 'training'``, else predicted tags."""
        lengths = inputs['word_lengths']
        features = self.encoder(input_ids=inputs['input_token_ids'],
                                attention_mask=inputs['input_attn_mask'],
                                token_lens=inputs['token_lens'])  # (batch, word_seq_len, hidden_dim)

        dist = self.head(features, lengths)

        if stage == 'training':
            assert(torch.max(lengths) == inputs['word_tags'].size(1))
            assert(torch.max(inputs['word_tags']) < self.n_classes)  # should be [0, C-1]
            label_ = LinearChainCRF.struct.to_parts(inputs['word_tags'], self.n_classes,
                                                    lengths=lengths).type_as(dist.log_potentials)
            loss = -dist.log_prob(label_).mean()
            return loss

        # Compute predictions
        argmax = dist.argmax
        preds = dist.from_event(argmax)[0]  # (batch, seq)
        return preds

    def training_step(self, batch, batch_idx):
        """Compute the training loss for one batch and log it as ``train/loss``."""
        loss = self.forward(batch, stage='training')
        log = {
            'train/loss': loss,
        }
        return {
            'loss': loss,
            'log': log
        }

    # ------------------------------------------------------------------
    # shared evaluation helpers (validation and test used to duplicate these)
    # ------------------------------------------------------------------
    def _predict_step(self, batch):
        """Run inference on one batch and return plain-python outputs."""
        preds = self.forward(batch, stage='testing')
        return {
            "doc_key": batch['doc_key'],
            "chunk_idx": batch['chunk_idx'],
            "preds": preds.cpu().tolist(),
            "word_tags": batch['word_tags'].cpu().tolist(),
            "word_lengths": batch["word_lengths"].cpu().tolist(),
        }

    @staticmethod
    def _combine_predictions(outputs):
        """Merge per-chunk step outputs into one tag sequence per doc_key."""
        predictions = defaultdict(list)
        for batch_dict in outputs:
            for i, doc_key in enumerate(batch_dict['doc_key']):
                length = batch_dict['word_lengths'][i]
                predictions[doc_key].append({
                    'pred_tags': batch_dict['preds'][i][:length],
                    'labels': batch_dict['word_tags'][i][:length],
                    'chunk_idx': batch_dict['chunk_idx'][i],
                })

        combined_predictions = {}  # doc_key -> predictions
        for doc_key, pred_list in predictions.items():
            # sort by chunk_idx and merge
            ordered = sorted(pred_list, key=lambda x: x['chunk_idx'])
            combined_predictions[doc_key] = {
                'pred_tags': [tag for pred in ordered for tag in pred['pred_tags']],
                'labels': [label for pred in ordered for label in pred['labels']],
            }
        return combined_predictions

    def _attach_gold(self, combined_predictions, data_file):
        """Copy gold annotations from the JSON-lines ``data_file`` into the predictions."""
        id_field = 'sent_id' if self.hparams.dataset == 'ACE' else 'doc_id'

        if self.hparams.task == 'trigger':
            with open(data_file, 'r') as f:
                for line in f:
                    ex = json.loads(line)
                    combined_predictions[ex[id_field]]['event_mentions'] = ex['event_mentions']
            return

        # argument task: doc_key is "<id>:<event index>"; index the keys once
        # instead of rescanning every key for every line of the file
        keys_by_id = defaultdict(list)
        for doc_key in combined_predictions:
            key_id, evt_idx = doc_key.split(':')
            keys_by_id[key_id].append((doc_key, evt_idx))

        with open(data_file, 'r') as f:
            for line in f:
                ex = json.loads(line)
                for doc_key, evt_idx in keys_by_id.get(ex[id_field], ()):
                    entry = combined_predictions[doc_key]
                    entry['event_mentions'] = [ex['event_mentions'][int(evt_idx)], ]
                    entry['entity_mentions'] = ex['entity_mentions']
                    if self.hparams.dataset != 'ACE':
                        entry['sent_lens'] = [len(sent[0]) for sent in ex['sentences']]

    def _dump_predictions(self, combined_predictions):
        """Write the combined predictions next to the checkpoint."""
        with open('checkpoints/{}/predictions.json'.format(self.hparams.ckpt_name), 'w') as f:
            json.dump(combined_predictions, f)

    def _score(self, combined_predictions):
        """Evaluate F1 for the configured task and build the result dict."""
        if self.hparams.task == 'trigger':
            ontology_dict = load_ontology(self.hparams.dataset)
            tgr_id_f1, tgr_f1 = evaluate_trigger_f1(combined_predictions, ontology_dict)
            log = {
                'val/id_f1': torch.Tensor([tgr_id_f1, ]),
                'val/f1': torch.Tensor([tgr_f1, ])
            }
            return {
                'f1': tgr_f1,
                'log': log
            }

        role_mapping = load_role_mapping(self.hparams.dataset)
        arg_id_f1, arg_f1 = evaluate_arg_f1(combined_predictions, role_mapping)
        log = {
            'val/id_f1': arg_id_f1,
            'val/f1': arg_f1
        }
        return {
            'f1': arg_f1,
            'log': log
        }

    # ------------------------------------------------------------------
    # Lightning hooks
    # ------------------------------------------------------------------
    def validation_step(self, batch, batch_idx):
        return self._predict_step(batch)

    def validation_epoch_end(self, outputs):
        """Score the validation set; predictions are dumped for the argument task only."""
        combined_predictions = self._combine_predictions(outputs)
        self._attach_gold(combined_predictions, self.hparams.val_file)
        if self.hparams.task != 'trigger':
            self._dump_predictions(combined_predictions)
        return self._score(combined_predictions)

    def test_step(self, batch, batch_idx):
        return self._predict_step(batch)

    def test_epoch_end(self, outputs):
        """Score the test set and always dump the predictions.

        NOTE(review): metrics are still reported under the ``val/`` keys, as
        before; renaming them could break existing monitors.
        """
        combined_predictions = self._combine_predictions(outputs)
        self._attach_gold(combined_predictions, self.hparams.test_file)
        self._dump_predictions(combined_predictions)
        return self._score(combined_predictions)

    def configure_optimizers(self):
        """AdamW with separate encoder/head groups and a linear warmup schedule."""
        self.train_len = len(self.train_dataloader())
        if self.hparams.max_steps > 0:
            t_total = self.hparams.max_steps
            self.hparams.num_train_epochs = self.hparams.max_steps // self.train_len // self.hparams.accumulate_grad_batches + 1
        else:
            t_total = self.train_len // self.hparams.accumulate_grad_batches * self.hparams.num_train_epochs

        logger.info('{} training steps in total.. '.format(t_total))

        # Prepare optimizer and schedule (linear warmup and decay)
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in self.encoder.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": self.hparams.bert_weight_decay,
                "lr": self.hparams.bert_learning_rate,
            },
            {
                "params": [p for n, p in self.encoder.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
                "lr": self.hparams.bert_learning_rate,
            },
            {
                "params": [p for n, p in self.head.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": self.hparams.weight_decay
            },
            {
                "params": [p for n, p in self.head.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0
            }
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)
        # step the schedule every optimizer step rather than once per epoch
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total)
        scheduler_dict = {
            'scheduler': scheduler,
            'interval': 'step',
            'name': 'linear-schedule',
        }

        return [optimizer, ], [scheduler_dict, ]
def find_head(arg_start, arg_end, doc):
    """Follow head links from ``arg_start`` while they stay inside the span.

    Starting at token ``arg_start``, repeatedly move to the token's head as
    long as that head index lies in ``[arg_start, arg_end]``; stop when the
    head leaves the span or a token is its own head.

    Returns the final index twice, as a ``(start, end)`` pair.
    """
    position = arg_start
    while True:
        parent = doc[position].head.i
        if parent < arg_start or parent > arg_end:
            break  # head lies outside the argument span
        if parent == position:
            break  # self is the head
        position = parent

    return (position, position)
# (start, end, type)
def get_pred_tgr_mentions(ex, tag2event_type):
    """Decode predicted IO tags into a set of ``(start, end, event_type)`` spans.

    ``end`` is exclusive.  When ``ex`` has a ``'bpe_mapping'`` the tags are
    per word piece and only the first piece of every word is read (pieces
    mapped to ``-1`` are skipped); otherwise ``ex['pred_tags']`` is read as
    one tag per word.  Tag ``0`` means "outside"; consecutive equal tags
    form one span.
    """
    pred_mentions = set()
    prev_tag = 0
    cur_start = 0
    cur_end = 0
    if 'bpe_mapping' in ex:  # need to map bpe tokens to words
        bpe_mapping = ex['bpe_mapping']
        for i in range(len(ex['input_ids'])):
            pred_tag = ex['pred_tags'][i]
            word_idx = bpe_mapping[i]
            # at i == 0 there is no previous piece; the original indexed
            # [-1] here and so compared against the LAST piece instead
            prev_word_idx = bpe_mapping[i - 1] if i > 0 else None
            if word_idx == -1 or word_idx == prev_word_idx:
                continue
            # predicting the beginning of a word
            if pred_tag > 0:
                if prev_tag != pred_tag:
                    if prev_tag > 0:
                        # end the previous span
                        pred_mentions.add((cur_start, cur_end, tag2event_type[prev_tag]))
                    # the beginning of a new span
                    cur_start = word_idx
                    cur_end = word_idx + 1
                else:
                    cur_end += 1  # continue the current span
            elif prev_tag > 0:
                # end of a span
                pred_mentions.add((cur_start, cur_end, tag2event_type[prev_tag]))
            # NOTE(review): source indentation was lost; prev_tag is assumed
            # to advance only at word starts -- confirm
            prev_tag = pred_tag
    else:
        for i, pred_tag in enumerate(ex['pred_tags']):
            if pred_tag > 0:
                if prev_tag != pred_tag:
                    if prev_tag > 0:
                        # close the prev tag
                        pred_mentions.add((cur_start, cur_end, tag2event_type[prev_tag]))
                    cur_start = i
                    cur_end = i + 1
                else:
                    cur_end = i + 1
            elif prev_tag > 0:
                pred_mentions.add((cur_start, cur_end, tag2event_type[prev_tag]))
            prev_tag = pred_tag

    # flush a span that is still open when the sequence ends; it was
    # previously dropped, losing every trigger on the last word
    if prev_tag > 0:
        pred_mentions.add((cur_start, cur_end, tag2event_type[prev_tag]))

    return pred_mentions
def get_pred_arg_mentions_bio(ex, tag2role, b_tags, i_tags):
    '''
    Return a set of predicted args as (start, end, role), end exclusive.
    Predicted tags are on the word level.

    A B-tag always opens a new span.  An I-tag extends the open span only
    when it carries the same role; any other tag (O, an I-tag with no open
    span or a different role, an unknown tag) closes the open span.
    '''
    mentions = set()
    cur_start = 0
    cur_end = 0
    cur_role = None  # role of the open span, None when no span is open

    for i, tag in enumerate(ex['pred_tags']):
        if tag in b_tags:  # begin a new span
            if cur_role is not None:
                # close the prev span
                mentions.add((cur_start, cur_end, cur_role))
            cur_start = i
            cur_end = i + 1
            cur_role = tag2role[tag]
        elif tag in i_tags and cur_role is not None and cur_role == tag2role[tag]:
            # labeling is correct, continue the span
            cur_end = i + 1
        else:
            # O tag or a wrongly placed I-tag: keep what was collected so far
            # (a wrong I-tag used to discard the open span) and forget the
            # role so a later I-tag cannot re-extend an already closed span
            if cur_role is not None:
                mentions.add((cur_start, cur_end, cur_role))
            cur_role = None

    # last span
    if cur_role is not None:
        mentions.add((cur_start, cur_end, cur_role))

    return mentions
def get_tag_mapping(role_mapping):
    """Invert a role -> labels mapping.

    Returns ``(tag2role, b_tags, i_tags)``: a dict from every label id to
    its role, the set of ``'b-label'`` ids and the set of ``'i-label'`` ids.
    """
    tag2role, b_tags, i_tags = {}, set(), set()
    for role, labels in role_mapping.items():
        if 'b-label' in labels:
            begin_tag = labels['b-label']
            b_tags.add(begin_tag)
            tag2role[begin_tag] = role
        # NOTE(review): source indentation was lost; 'i-label' is assumed to
        # be read for every role, including those without a 'b-label' -- confirm
        inside_tag = labels['i-label']
        i_tags.add(inside_tag)
        tag2role[inside_tag] = role
    return tag2role, b_tags, i_tags
arg_rec * 100.0, arg_f * 100.0)) + + return arg_id_f, arg_f + + + + + + +def evaluate_trigger_f1(predictions, ontology_dict): + + def get_tag_mapping(ontology_dict): + tag2event_type = {} + for et in ontology_dict: + tag = ontology_dict[et]['i-label'] + tag2event_type[tag] = et + + return tag2event_type + + tag2event_type = get_tag_mapping(ontology_dict) + gold_cnt = 0 + trigger_idn_cnt = 0 + trigger_cls_cnt = 0 + pred_cnt = 0 + + for ex in predictions.values(): + gold_mentions = set() + for event_dict in ex['event_mentions']: + gold_mentions.add((event_dict['trigger']['start'], event_dict['trigger']['end'], event_dict['event_type'])) + gold_cnt += len(gold_mentions) + + pred_mentions = get_pred_tgr_mentions(ex,tag2event_type) + + pred_cnt += len(pred_mentions) + + for tup in pred_mentions: + start, end, evt_type = tup + gold_idn = {item for item in gold_mentions if item[0]==start and item[1]==end} + + if gold_idn: + trigger_idn_cnt +=1 + gold_cls = {item for item in gold_idn if item[2] == evt_type} + if gold_cls: + trigger_cls_cnt+=1 + + + tgr_id_prec, tgr_id_rec, tgr_id_f = compute_f1( + pred_cnt, gold_cnt, trigger_idn_cnt) + tgr_prec, tgr_rec, tgr_f = compute_f1( + pred_cnt, gold_cnt, trigger_cls_cnt) + + + print('Trigger identification: P: {:.2f}, R: {:.2f}, F: {:.2f}'.format( + tgr_id_prec * 100.0, tgr_id_rec * 100.0, tgr_id_f * 100.0)) + print('Trigger: P: {:.2f}, R: {:.2f}, F: {:.2f}'.format( + tgr_prec * 100.0, tgr_rec * 100.0, tgr_f * 100.0)) + + return tgr_id_f, tgr_f + + diff --git a/zstagger/zs_model.py b/zstagger/zs_model.py new file mode 100644 index 0000000..08b1a74 --- /dev/null +++ b/zstagger/zs_model.py @@ -0,0 +1,305 @@ +import os +import torch +import logging +import json +from collections import defaultdict + +import pytorch_lightning as pl +from transformers import AutoModel, AutoTokenizer, AutoConfig +from transformers import AdamW, get_linear_schedule_with_warmup +from torch_struct import LinearChainCRF +import torch.nn.functional as F + 
class ZSTaggerModel(pl.LightningModule):
    """Trigger tagger built from a ``BERTEncoder`` and a class-vector head.

    The head is either a ``ZeroShotCollapsedTransitionCRFHead``
    (``--model zs-crf``) or a ``PrototypeNetworkHead`` (``--model proto``);
    both are initialised from class vectors loaded from disk.
    """

    def __init__(self, args):
        """Build encoder and head from the parsed command-line ``args``.

        Raises:
            ValueError: if ``args.dataset`` or ``args.model`` is not one of
                the values this class knows how to handle.
        """
        super().__init__()
        self.hparams = args

        self.config = AutoConfig.from_pretrained(args.pretrained_model)
        self.tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model)
        self.n_classes = self.hparams.event_n + 1  # IO tagging
        # network internals
        self.encoder = BERTEncoder(args, self.config.hidden_size)
        self.feature_size = self.config.hidden_size

        if self.hparams.dataset == 'ACE':
            class_vectors = torch.load('all_class_vec.pt')
        elif self.hparams.dataset == 'KAIROS':
            class_vectors = torch.load('all_class_vec_KAIROS.pt')
            print('loading KAIROS vectors...')
        else:
            # Previously fell through and failed later on an unbound local.
            raise ValueError(
                'unsupported dataset {!r}; expected ACE or KAIROS'.format(self.hparams.dataset))

        assert class_vectors.shape[0] == self.hparams.event_n, \
            'class vector file has {} rows but event_n is {}'.format(
                class_vectors.shape[0], self.hparams.event_n)

        if self.hparams.model == 'zs-crf':
            self.head = ZeroShotCollapsedTransitionCRFHead(self.hparams, self.feature_size,
                                                           self.n_classes,
                                                           self.hparams.proj_dim,
                                                           class_vectors,)
        elif self.hparams.model == 'proto':
            self.head = PrototypeNetworkHead(self.hparams, self.feature_size, self.n_classes, class_vectors)
        else:
            # Previously left self.head undefined, failing only in forward().
            raise ValueError(
                'unsupported model {!r}; expected zs-crf or proto'.format(self.hparams.model))

    def forward(self, inputs, stage='training'):
        """Return the loss when ``stage == 'training'``, else predicted tags.

        ``inputs`` is a batch dict; this method reads ``word_lengths``,
        ``input_token_ids``, ``input_attn_mask``, ``token_lens`` and, for
        training, ``word_tags``.
        """
        lengths = inputs['word_lengths']
        features = self.encoder(input_ids=inputs['input_token_ids'],
                                attention_mask=inputs['input_attn_mask'],
                                token_lens=inputs['token_lens'])  # (batch,word_seq_len, hidden_dim)

        if self.hparams.token_classification:
            scores = self.head(features, lengths)  # (batch, seq, C)
            batch_size, _, C = scores.shape
            if stage == 'training':
                labels = inputs['word_tags']  # batch, seq
                # Positions tagged -1 are excluded from the loss.
                loss = F.cross_entropy(scores.reshape(-1, C), labels.reshape(-1),
                                       ignore_index=-1, reduction='sum')
                return loss / batch_size  # per sequence loss
            return torch.argmax(scores, dim=2)

        dist = self.head(features, lengths)
        if stage == 'training':
            assert(torch.max(lengths) == inputs['word_tags'].size(1))
            assert(torch.max(inputs['word_tags']) < self.n_classes)  # should be [0, C-1]
            label_ = LinearChainCRF.struct.to_parts(inputs['word_tags'], self.n_classes,
                                                    lengths=lengths).type_as(dist.log_potentials)

            loss = -dist.log_prob(label_).mean()
            if hasattr(self.head, 'regularize_params'):
                reg = self.head.regularize_params()
                loss += self.hparams.reg_weight * reg
            return loss

        # Compute predictions
        argmax = dist.argmax
        preds = dist.from_event(argmax)[0]  # (batch, seq)
        return preds

    def training_step(self, batch, batch_idx):
        """Compute the training loss for one batch.

        Batch keys listed by the original author: ``doc_key``,
        ``input_token_ids``, ``input_attn_mask``, ``labels``,
        ``bpe_mapping``, ``token_lens``, ``word_tags``.
        """
        loss = self.forward(batch, stage='training')

        log = {
            'train/loss': loss,
        }
        return {
            'loss': loss,
            'log': log
        }

    def _eval_step(self, batch):
        """Predict tags for one batch and package them as plain Python lists."""
        preds = self.forward(batch, stage='testing')
        return {
            "doc_key": batch['doc_key'],
            "chunk_idx": batch['chunk_idx'],
            "preds": preds.cpu().tolist(),
            "word_tags": batch['word_tags'].cpu().tolist(),
            "word_lengths": batch["word_lengths"].cpu().tolist(),
        }

    def _combine_chunk_predictions(self, outputs):
        """Merge per-chunk step outputs into one tag sequence per document.

        Returns a dict ``doc_key -> {'pred_tags': [...], 'labels': [...]}``
        with chunks concatenated in ascending ``chunk_idx`` order.
        """
        per_doc = defaultdict(list)
        for batch_dict in outputs:
            for i, doc_key in enumerate(batch_dict['doc_key']):
                length = batch_dict['word_lengths'][i]  # drop padding positions
                per_doc[doc_key].append({
                    'pred_tags': batch_dict['preds'][i][:length],
                    'labels': batch_dict['word_tags'][i][:length],
                    'chunk_idx': batch_dict['chunk_idx'][i],
                })

        combined = {}  # doc_key -> predictions
        for doc_key, chunks in per_doc.items():
            ordered = sorted(chunks, key=lambda x: x['chunk_idx'])
            combined[doc_key] = {
                'pred_tags': [tag for chunk in ordered for tag in chunk['pred_tags']],
                'labels': [label for chunk in ordered for label in chunk['labels']],
            }
        return combined

    def _attach_gold_mentions(self, combined_predictions, data_file):
        """Copy gold ``event_mentions`` from ``data_file`` onto the predictions.

        Documents listed in the file but absent from ``combined_predictions``
        (e.g. when only a few batches were evaluated) are skipped with a
        warning instead of raising ``KeyError``.
        """
        id_field = 'sent_id' if self.hparams.dataset == 'ACE' else 'doc_id'
        skipped = 0
        with open(data_file, 'r') as f:
            for line in f:
                ex = json.loads(line)
                entry = combined_predictions.get(ex[id_field])
                if entry is None:
                    skipped += 1
                    continue
                entry['event_mentions'] = ex['event_mentions']
        if skipped:
            logger.warning('%d documents in %s have no predictions and were skipped.',
                           skipped, data_file)

    def validation_step(self, batch, batch_idx):
        """Predict tags for one validation batch."""
        return self._eval_step(batch)

    def validation_epoch_end(self, outputs):
        """Aggregate validation outputs and report trigger F1 scores."""
        ontology_dict = load_ontology(self.hparams.dataset)

        combined_predictions = self._combine_chunk_predictions(outputs)
        self._attach_gold_mentions(combined_predictions, self.hparams.val_file)

        tgr_id_f1, tgr_f1 = evaluate_trigger_f1(combined_predictions, ontology_dict)

        log = {
            'val/tgr_id_f1': torch.Tensor([tgr_id_f1, ]),
            'val/tgr_f1': torch.Tensor([tgr_f1, ])
        }
        return {
            'f1': tgr_f1,
            'log': log
        }

    def test_step(self, batch, batch_idx):
        """Predict tags for one test batch."""
        return self._eval_step(batch)

    def test_epoch_end(self, outputs):
        """Aggregate test outputs, report trigger F1 and dump predictions to JSON."""
        ontology_dict = load_ontology(self.hparams.dataset)

        combined_predictions = self._combine_chunk_predictions(outputs)
        self._attach_gold_mentions(combined_predictions, self.hparams.test_file)

        tgr_id_f1, tgr_f1 = evaluate_trigger_f1(combined_predictions, ontology_dict)

        out_path = 'checkpoints/{}/predictions.json'.format(self.hparams.ckpt_name)
        # Make sure the checkpoint directory exists before writing into it.
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        with open(out_path, 'w') as f:
            json.dump(combined_predictions, f)

        log = {
            'test/tgr_id_f1': torch.Tensor([tgr_id_f1, ]),
            'test/tgr_f1': torch.Tensor([tgr_f1, ])
        }
        return {
            'f1': tgr_f1,
            'log': log
        }

    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure=None, on_tpu=False,
                       using_native_amp=False, using_lbfgs=False):
        """Step the optimizer, then let the head run its optional hooks.

        ``normalize_params`` is called after every step and ``update_params``
        on every 20th batch index, each only if the head defines it.
        """
        optimizer.step(closure=optimizer_closure)
        if hasattr(self.head, 'normalize_params'):
            self.head.normalize_params()

        if batch_idx % 20 == 0 and hasattr(self.head, 'update_params'):
            self.head.update_params()
        return

    def configure_optimizers(self):
        """Create AdamW with separate encoder/head parameter groups and a linear warmup schedule."""
        self.train_len = len(self.train_dataloader())
        if self.hparams.max_steps > 0:
            t_total = self.hparams.max_steps
            self.hparams.num_train_epochs = self.hparams.max_steps // self.train_len // self.hparams.accumulate_grad_batches + 1
        else:
            t_total = self.train_len // self.hparams.accumulate_grad_batches * self.hparams.num_train_epochs

        logger.info('{} training steps in total.. '.format(t_total))

        # Prepare optimizer and schedule (linear warmup and decay)
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in self.encoder.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": self.hparams.bert_weight_decay,
                "lr": self.hparams.bert_learning_rate,
            },
            {
                "params": [p for n, p in self.encoder.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
                "lr": self.hparams.bert_learning_rate,
            },
            # Head groups carry no "lr" entry, so they use the optimizer default below.
            {
                "params": [p for n, p in self.head.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": self.hparams.weight_decay
            },
            {
                "params": [p for n, p in self.head.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0
            }
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)
        # 'interval': 'step' asks for the scheduler to be stepped per optimizer step.
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total)
        scheduler_dict = {
            'scheduler': scheduler,
            'interval': 'step',
            'name': 'linear-schedule',
        }

        return [optimizer, ], [scheduler_dict, ]
BATCH_SIZE = 16


def get_vector_labels(ex: Dict, class_vectors: torch.FloatTensor, idx2event: Dict[int, str],
                      pos_thres: float = 0.65, unk_thres: float = 0.4, keyword_bonus: float = 0.1):
    '''
    Assign pseudo labels to a single instance.

    Every token is compared (cosine similarity) with all class vectors. The
    best-scoring class is accepted when its score, plus ``keyword_bonus`` if
    the token is one of that class's keywords, reaches ``pos_thres``; it is
    recorded as ``'unknown'`` when it only reaches ``unk_thres``. Stopwords
    are never labelled.

    Relies on the module-level ``ontology_dict`` created by the script entry
    point below.

    :ex: dict with 'vec', 'tokens', 'sent_id', 'sentence' and 'event_mentions'.
    :class_vectors: one row per event class; row r is looked up as label r + 1 in idx2event.
    :return: ``(res, (total_gold, total_pred, total_matched))``; ``res`` is
        None when nothing was labelled.
    '''
    vec = ex['vec']  # type: torch.FloatTensor
    sim = F.normalize(vec, dim=1) @ F.normalize(class_vectors, dim=1).t()
    max_sim_score, max_idx = torch.max(sim, dim=1)
    max_class = max_idx + 1
    stop_words = set(stopwords.words('english'))
    expanded_keywords = {}  # event name -> inflected keywords, expanded once per call
    predicted_mentions = []
    total_pred = 0
    for i, (score, class_idx) in enumerate(zip(max_sim_score.tolist(), max_class.tolist())):
        evt_name = idx2event[class_idx]
        if evt_name not in expanded_keywords:
            expanded_keywords[evt_name] = expand_keywords_inflection(
                set(ontology_dict[evt_name]['keywords']))
        keywords = expanded_keywords[evt_name]
        text = ex['tokens'][i]
        if text in stop_words:  # stopwords are always O-label
            continue
        if text in keywords:
            score += keyword_bonus
        if score >= pos_thres:
            predicted_mentions.append({
                'evt_type': evt_name,
                'start': i,
                'score': score,
                'text': text
            })
            total_pred += 1
        elif score >= unk_thres:  # uncertain
            predicted_mentions.append({
                'evt_type': 'unknown',
                'start': i,
                'score': score,
                'text': text
            })

    if len(predicted_mentions) == 0:
        return None, (0, 0, 0)

    res = {
        'doc_key': ex['sent_id'],
        'sentence': ex['sentence'],
        'gold_mentions': [],
        'predicted_mentions': predicted_mentions,
    }
    total_matched = 0
    total_gold = 0
    for e in ex['event_mentions']:
        gold = {
            'evt_type': e['event_type'],
            'start': e['trigger']['start'],
            'score': max_sim_score[e['trigger']['start']].item(),
            'text': e['trigger']['text']
        }
        total_gold += 1
        for m in predicted_mentions:
            if m['start'] == gold['start'] and m['evt_type'] == gold['evt_type']:
                total_matched += 1
            if m['start'] == gold['start'] and m['evt_type'] == 'unknown':
                total_gold -= 1  # ignore this gold mention
        res['gold_mentions'].append(gold)

    return res, (total_gold, total_pred, total_matched)


def move_batch(batch: Dict, device: str = 'cuda:0') -> Dict:
    '''
    Return a copy of ``batch`` in which every value that has a ``dtype``
    attribute is moved to ``device``; all other values are passed through.
    '''
    return {
        k: (v.to(device) if hasattr(v, 'dtype') else v)
        for k, v in batch.items()
    }


def embed_instances(args, encoder: 'BERTEncoder', device: str = 'cuda:0') -> None:
    '''
    Encode ``<args.tmp_dir>/train.jsonl`` with ``encoder`` and save a dict
    doc_key -> per-word feature tensor to ``training_embedded_<ontology>.pt``.

    Raises FileNotFoundError if ``args.tmp_dir`` does not exist.
    '''
    if not os.path.exists(args.tmp_dir):
        raise FileNotFoundError(args.tmp_dir)

    dataset = IEDataset(os.path.join(args.tmp_dir, 'train.jsonl'), split='pl')
    dataloader = DataLoader(dataset,
                            pin_memory=True, num_workers=4,
                            collate_fn=adaptive_length_collate,
                            batch_size=BATCH_SIZE,
                            shuffle=True)

    encoder = encoder.to(device)
    encoder.eval()

    results = {}  # doc_key -> tensor of (word_seq_len, hidden_dim)
    with torch.no_grad():
        for batch in dataloader:
            batch = move_batch(batch, device)
            features = encoder(input_ids=batch['input_token_ids'],
                               attention_mask=batch['input_attn_mask'],
                               token_lens=batch['token_lens'])  # (batch,word_seq_len, hidden_dim)
            doc_key = batch['doc_key']
            lengths = batch['word_lengths']
            # removing padding
            for i in range(features.size(0)):
                results[doc_key[i]] = features[i, :lengths[i], :].cpu()

    torch.save(results, f'training_embedded_{args.ontology}.pt')
    print('training embedding saved ....')
    return


def assign_pl(pos_exs: List, neg_exs: List, class_vectors: torch.FloatTensor, idx2event: Dict[int, str],
              output_file: str = 'pl_training.jsonl', log_file: str = 'pl_label_errors.jsonl',
              pos_thres: float = 0.65, unk_thres: float = 0.4, keyword_bonus: float = 0.1):
    '''
    Write pseudo-labelled examples to ``output_file`` and examples whose
    pseudo labels disagree with the gold annotation to ``log_file``, then
    print the recall / precision / F1 of the pseudo labels.

    :pos_exs: examples labelled through ``get_vector_labels``.
    :neg_exs: examples kept as all-negative unless some token scores above unk_thres.
    :pos_thres: the instance has to score higher than this to be considered as positive
    :unk_thres: if the instance is higher than this and lower than pos_thres, then considered as unknown
    :keyword_bonus: if the keyword matches, then assign this bonus score
    '''
    total_gold = 0
    total_pred = 0
    total_matched = 0

    # Both files are closed even if labelling raises part-way through.
    with open(output_file, 'w') as writer, open(log_file, 'w') as error_logs:
        for ex in pos_exs:
            res, stats = get_vector_labels(ex, class_vectors, idx2event, pos_thres, unk_thres, keyword_bonus)
            if res is None or stats[1] == 0:
                continue
            writer.write(json.dumps(res) + '\n')
            total_gold += stats[0]
            total_pred += stats[1]
            total_matched += stats[2]

            if stats[0] != stats[2]:
                error_logs.write(json.dumps(res) + '\n')

        for ex in neg_exs:  # nothing is being predicted
            vec = ex['vec']
            sim = F.normalize(vec, dim=1) @ F.normalize(class_vectors, dim=1).t()
            max_sim_score, _ = torch.max(sim, dim=1)
            if torch.max(max_sim_score).item() > unk_thres:  # grey area
                continue

            # use as negative
            res = {
                'doc_key': ex['sent_id'],
                'sentence': ex['sentence'],
                'gold_mentions': [],
                'predicted_mentions': [],
            }
            for e in ex['event_mentions']:
                res['gold_mentions'].append({
                    'evt_type': e['event_type'],
                    'start': e['trigger']['start'],
                    'score': max_sim_score[e['trigger']['start']].item(),
                    'text': e['trigger']['text']
                })
            total_gold += len(res['gold_mentions'])

            if res['gold_mentions']:
                error_logs.write(json.dumps(res) + '\n')
            writer.write(json.dumps(res) + '\n')

    # check label quality; an empty denominator yields 0.0 instead of ZeroDivisionError
    recall = total_matched / total_gold if total_gold else 0.0
    precision = total_matched / total_pred if total_pred else 0.0
    f1 = 2 * recall * precision / (recall + precision) if (recall + precision) else 0.0

    print('R:{}, P:{}, F1:{}'.format(recall, precision, f1))

    return


def convert_pl_file(args, pl_file_name='pl_training.jsonl') -> None:
    '''
    Rewrite ``<args.tmp_dir>/train.jsonl`` as ``<args.tmp_dir>/pl_train.jsonl``,
    replacing ``word_tags`` with the pseudo labels read from ``pl_file_name``.

    Examples without an entry in the pseudo-label file are dropped. Tokens
    predicted as 'unknown' get tag -1; all unlabelled tokens get tag 0.
    Relies on the module-level ``ontology_dict``.
    '''
    pl_exs = {}
    with open(pl_file_name, 'r') as f:
        for line in f:
            ex = json.loads(line)
            pl_exs[ex['doc_key']] = ex['predicted_mentions']

    with open(os.path.join(args.tmp_dir, 'train.jsonl'), 'r') as f, \
            open(os.path.join(args.tmp_dir, 'pl_train.jsonl'), 'w') as writer:
        for line in f:
            ex = json.loads(line)
            if ex['doc_key'] not in pl_exs:
                continue
            pl_word_tags = [0, ] * len(ex['token_lens'])
            for e in pl_exs[ex['doc_key']]:
                if e['evt_type'] == 'unknown':
                    pl_word_tags[e['start']] = -1
                else:
                    pl_word_tags[e['start']] = ontology_dict[e['evt_type']]['i-label']
            processed_ex = {
                'doc_key': ex['doc_key'],
                'input_token_ids': ex['input_token_ids'],
                'input_attn_mask': ex['input_attn_mask'],
                'labels': ex['labels'],
                'bpe_mapping': ex['bpe_mapping'],
                'token_lens': ex['token_lens'],
                'word_tags': pl_word_tags,
            }
            writer.write(json.dumps(processed_ex) + '\n')
    return


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--ontology', type=str, default='KAIROS', choices=['ACE', 'KAIROS'])
    parser.add_argument('--keyword_n', type=int, default=3, help='number of keywords to use for each event type.')
    parser.add_argument('--tmp-dir', type=str, help='directory for preprocessed data. Will raise error if this directory does not exist.')
    parser.add_argument('--data_file', type=str, default='data/ace/pro_mttrig_id/json/train.oneie.json')
    parser.add_argument('--class-vectors-file', type=str)
    args = parser.parse_args()

    # NOTE: kept at module level on purpose; get_vector_labels and
    # convert_pl_file read ontology_dict as a global.
    ontology_dict = load_ontology(args.ontology)

    idx2event = {}
    for e in ontology_dict:
        eidx = ontology_dict[e]['i-label']
        idx2event[eidx] = e

    all_keywords = set()
    for e in ontology_dict:
        keywords = set(ontology_dict[e]['keywords'])
        keywords = expand_keywords_inflection(keywords)
        all_keywords.update(keywords)

    tokenizer = BertTokenizer.from_pretrained('bert-large-cased')  # NOTE: not used below in this script
    encoder = BERTEncoder(args, bert_dim=1024)

    if not os.path.exists(f'training_embedded_{args.ontology}.pt'):
        embed_instances(args, encoder)
    training_embedded = torch.load(f'training_embedded_{args.ontology}.pt')

    # collect positive and negative instances from the training data
    pos_exs = []
    neg_exs = []
    with open(args.data_file) as f:
        for line in f:
            ex = json.loads(line)
            doc_key = ex['sent_id']
            if doc_key not in training_embedded:
                continue
            ex['vec'] = training_embedded[doc_key]
            # An example is "positive" as soon as one token is a keyword.
            if any(token in all_keywords for token in ex['tokens']):
                pos_exs.append(ex)
            else:
                neg_exs.append(ex)

    # assign pseudo labels; --class-vectors-file overrides the default path when given
    class_vectors_path = args.class_vectors_file or 'all_class_vec_{}.pt'.format(args.ontology)
    class_vectors = torch.load(class_vectors_path)  # type: torch.FloatTensor
    assign_pl(pos_exs, neg_exs, class_vectors, idx2event, output_file='pl_training.jsonl', log_file='pl_label_errors.jsonl')
    # convert pl file to preprocessed file format
    convert_pl_file(args, pl_file_name='pl_training.jsonl')
# Event types that are always accepted by the script below, regardless of
# their substitution score.
RARE_TYPES = {'Justice:Appeal', 'Justice:Extradite', 'Justice:Acquit', 'Justice:Convict', "Justice:Sentence",
              "Justice.Acquit.Unspecified",
              "Medical.Vaccinate.Unspecified",
              "Movement.Transportation.IllegalTransportation",
              "ArtifactExistence.DamageDestroyDisableDismantle.DisableDefuse",
              "ArtifactExistence.DamageDestroyDisableDismantle.Dismantle"
              }


def get_label_occurrence(ex, evt_type):
    '''
    ex: a dict object with 'document' and 'event_mentions' key.
    evt_type: event type name from ontology.

    Only the first mention of evt_type in the example is used, which could
    be unsuitable when the type occurs more than once.

    return:
        {'document': ..., 'match': (trigger text, trigger start index)},
        or None when the example has no mention of evt_type.
    '''
    for e in ex['event_mentions']:
        if e['event_type'] == evt_type:
            return {
                'document': ex['document'],
                'match': (e['trigger']['text'], e['trigger']['start']),
            }
    return None


def get_keyword_occurrence(ex, keywords):
    '''
    ex: a Gigaword document with 'document' key.
    keywords: set of keywords

    Only the first whitespace-separated word found in keywords is matched.

    return:
        {'document': ..., 'match': (keyword, word index)}, or None when no
        word of the document is a keyword.
    '''
    for widx, w in enumerate(ex['document'].split()):
        if w in keywords:
            # only match one instance of one event type per sentence
            return {
                'document': ex['document'],
                'match': (w, widx),
            }
    return None


def get_substitute(m_dict_list, model, tokenizer, top_k=50, use_mask=False, strategy='first', bert_dim=1024):
    '''
    Run the masked LM over a batch of matches and read it at the match position.

    m_dict_list: List of match dictionary with 'document' and 'match' keys.
    model: a pretrained bert model (called with output_hidden_states=True).
    tokenizer: the matching tokenizer.
    top_k: number of highest-scoring vocabulary entries decoded per match.
    use_mask: replace the matched word by the mask token before encoding.
    strategy: 'first' use first subword token, 'mean' average tokens. Not
        implemented for batches; the first subword is always used.
    bert_dim: number of hidden dimensions gathered for the returned vectors.

    return:
        top_sub_list: per match, the decoded top_k substitutes split on whitespace.
        vec: CPU tensor (batch, bert_dim) read from the last hidden layer at
            the match position.

    NOTE(review): inputs are truncated to MAX_LENGTH; a match located beyond
    that limit would make the gather index out of range — confirm callers
    never pass such documents.
    '''
    def tokenize_instance(m_dict):
        words = m_dict['document'].split()
        match_idx = m_dict['match'][1]
        prefix = tokenizer.tokenize(' '.join(words[:match_idx]))
        mask_idx = len(prefix) + 1  # add [CLS]
        token_len = 1
        if use_mask:
            suffix = tokenizer.tokenize(' '.join(words[match_idx + 1:]))
            pieces = prefix + [tokenizer.mask_token, ] + suffix
        else:
            token_len = len(tokenizer.tokenize(words[match_idx]))
            suffix = tokenizer.tokenize(' '.join(words[match_idx:]))
            pieces = prefix + suffix
        encoded = tokenizer.encode(pieces,
                                   add_special_tokens=True,
                                   padding='max_length',
                                   truncation=True,
                                   max_length=MAX_LENGTH)
        return {
            'encoded': torch.LongTensor(encoded),
            'mask_idx': mask_idx,
            'token_len': token_len
        }

    tokenized_batch = [tokenize_instance(ex) for ex in m_dict_list]
    batch_size = len(tokenized_batch)
    mask_idx_batch = torch.LongTensor([ex['mask_idx'] for ex in tokenized_batch]).to(model.device)  # (batch)
    token_ids = torch.stack([ex['encoded'] for ex in tokenized_batch]).to(model.device)  # (batch, max_len)
    # Every sequence is padded to MAX_LENGTH; mask the padding so that it is
    # not attended to when computing the match representation.
    attention_mask = (token_ids != tokenizer.pad_token_id).long()

    outputs = model(token_ids, attention_mask=attention_mask, output_hidden_states=True)
    vocab_size = outputs[0].size(2)
    gather_idx = mask_idx_batch.unsqueeze(-1).unsqueeze(-1)  # (batch, 1, 1)
    prediction_scores = torch.gather(
        outputs[0], 1, gather_idx.expand(batch_size, 1, vocab_size)).squeeze(1)  # (batch, vocab)
    vec = torch.gather(
        outputs[1][-1], 1, gather_idx.expand(batch_size, 1, bert_dim)).squeeze(1)  # (batch, hidden_dim)
    vec = vec.cpu()

    # decode one by one
    top_scores = torch.argsort(prediction_scores, dim=1, descending=True)[:, :top_k]
    top_scores = top_scores.cpu()
    top_sub_list = [tokenizer.decode(top_scores[i, :]).split() for i in range(batch_size)]
    return top_sub_list, vec


def compute_sub_score(top_subs, match, keywords):
    '''
    Score how well the predicted substitutes support a match.

    Each substitute equal to the matched word adds 1; any other substitute
    that is a keyword adds 1 / rank (rank is 1-based).
    '''
    score = 0
    for rank, sub in enumerate(top_subs, start=1):
        if sub == match['match'][0]:
            score += 1
        elif sub in keywords:
            score += 1 / rank
    return score


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='gigaword')
    parser.add_argument('--data_file', type=str, default='data/ace/pro_mttrig_id/json/train.oneie.json')
    parser.add_argument('--ontology', type=str, default='KAIROS', choices=['ACE', 'KAIROS'])
    parser.add_argument('--keyword_n', type=int, default=3)
    parser.add_argument('--min_score', type=float, default=1.0)
    parser.add_argument('--keep_all', action='store_true')
    parser.add_argument('--from_labels', action='store_true')
    args = parser.parse_args()

    ontology_dict = load_ontology(args.ontology)
    docs = []
    if args.dataset == 'gigaword':
        gigaword = load_dataset('gigaword')
        docs = gigaword['train']
    elif args.dataset == 'training':
        docs = []
        with open(args.data_file, 'r') as f:
            for line in f:
                ex = json.loads(line)
                docs.append({
                    'document': ex['sentence'],
                    'event_mentions': ex['event_mentions']
                })

    tokenizer = BertTokenizer.from_pretrained('bert-large-cased')
    model = BertForMaskedLM.from_pretrained('bert-large-cased')

    model = model.to('cuda:0')
    BATCH_SIZE = 64

    event_vectors = {}
    for eidx, event in enumerate(ontology_dict):
        keywords = set(ontology_dict[event]['keywords'][:args.keyword_n])
        keywords = expand_keywords_inflection(keywords)
        matches = []
        for ex in docs:
            if args.from_labels:
                m_dict = get_label_occurrence(ex, event)
            else:
                m_dict = get_keyword_occurrence(ex, keywords)

            if m_dict:
                matches.append(m_dict)
            if len(matches) == 100:  # at most 100 occurrences per event type
                break
        if len(matches) == 0:
            print('{} has no occurrences'.format(event))
            continue

        vec_list = []
        accepted_m = []
        start_idx = 0
        with torch.no_grad():
            with tqdm(total=len(matches)) as pbar:
                while start_idx < len(matches):
                    match_list = matches[start_idx:start_idx + BATCH_SIZE]
                    top_sub_list, vec = get_substitute(match_list, model, tokenizer)
                    for i in range(len(match_list)):
                        score = compute_sub_score(top_sub_list[i], match_list[i], keywords)
                        if event in RARE_TYPES or len(keywords) == 1 or score > args.min_score:
                            vec_list.append(vec[i, :])
                            accepted_m.append(match_list[i])
                    start_idx += BATCH_SIZE
                    pbar.update(len(match_list))

        if len(vec_list) == 0:
            print('{} has no accepted occurrences'.format(event))
            continue
        if args.keep_all:
            # don't do average
            class_vector = torch.stack(vec_list, dim=0)
        else:
            class_vector = torch.stack(vec_list, dim=0).mean(dim=0)
        event_vectors[event] = class_vector

    with open('class_vectors_{}_{}.pkl'.format(args.ontology, args.dataset), 'wb') as f:
        torch.save(event_vectors, f)

    C = len(ontology_dict)
    # convert dictionary into single matrix for model input
    all_cv = torch.zeros((C + 1, 1024))
    for e in ontology_dict:
        if e not in event_vectors:
            # Events skipped above have no vector; keep their row at zero
            # instead of failing with KeyError after all the work is done.
            print('{} has no class vector; its row is left as zeros'.format(e))
            continue
        vector = event_vectors[e]
        if vector.dim() == 2:
            # --keep_all stores one row per occurrence; the matrix needs one row per class.
            vector = vector.mean(dim=0)
        idx = ontology_dict[e]['i-label']
        all_cv[idx, :] = vector
    # Row 0 is dropped before saving, so saved row r holds the class with i-label r + 1.
    torch.save(all_cv[1:, :], 'all_class_vec_{}.pt'.format(args.ontology))
+- v1 (April 12, 2021): Basic generation model for argument extraction. (This does not include the post-processing script.) +- v2: The event trigger detection model is under the branch `tapkey`. - -This page is currently under construction. diff --git a/aida_ontology_cleaned.csv b/ontology/aida_ontology_cleaned.csv similarity index 100% rename from aida_ontology_cleaned.csv rename to ontology/aida_ontology_cleaned.csv diff --git a/ontology/entity_types.json b/ontology/entity_types.json new file mode 100644 index 0000000..bf4943d --- /dev/null +++ b/ontology/entity_types.json @@ -0,0 +1,122 @@ +{ + "0":{ + "Type":"ABS", + "Output Value for Type":"abs", + "Definition":"Abstract, non-tangible artifacts such as software (e.g., programs, tool kits, apps, e-mail), measureable intellectual property, contracts, etc. (nb: does not include laws, which are LAW type)" + }, + "1":{ + "Type":"AML", + "Output Value for Type":"aml", + "Definition":"Animal, a non-human living organism which feeds on organic matter, typically having specialized sense organs and a nervous system and able to respond rapidly to stimuli" + }, + "2":{ + "Type":"BAL", + "Output Value for Type":"bal", + "Definition":"A ballot for an election, either the paper ballot used for voting, including both physical paper ballots and also the slate of candidates and ballot questions" + }, + "3":{ + "Type":"BOD", + "Output Value for Type":"bod", + "Definition":"An identifiable, living part of a human's or animal's body, such as a eye, ear, neck, leg, etc." + }, + "4":{ + "Type":"COM", + "Output Value for Type":"com", + "Definition":"A tangible product or article of trade for which someone pays or barters, or more generally, an artifact or a thing" + }, + "5":{ + "Type":"FAC", + "Output Value for Type":"fac", + "Definition":"A functional, primarily man-made structure. 
Facilities are artifacts falling under the domains of architecture and civil engineering, including more temporary human constructs, such as police lines and checkpoints." + }, + "6":{ + "Type":"GPE", + "Output Value for Type":"gpe", + "Definition":"Geopolitical entities such as countries, provinces, states, cities, towns, etc. GPEs are composite entities, consisting of a physical location, a government, and a population. All three of these elements must be present for an entity to be tagged as a GPE. A GPE entity may be a single geopolitical entity or a group." + }, + "7":{ + "Type":"INF", + "Output Value for Type":"inf", + "Definition":"An information object such as a field of study or topic of communication, including thoughts, opinions, etc." + }, + "8":{ + "Type":"LAW", + "Output Value for Type":"law", + "Definition":"A law, an act that is voted on by either a legislative body or an electorate, such as a law, referendum, act, regulation, statute, ordinance, etc." + }, + "9":{ + "Type":"LOC", + "Output Value for Type":"loc", + "Definition":"Geographical entities such as geographical areas and landmasses, bodies of water" + }, + "10":{ + "Type":"MHI", + "Output Value for Type":"mhi", + "Definition":"Any medical condition or health issue, to include everything from disease to broken bones to fever to general ill health, medical errors, even natural causes" + }, + "11":{ + "Type":"MON", + "Output Value for Type":"mon", + "Definition":"A monetary payment. The extent of a Money mention includes modifying quantifiers, the amount, and the currency unit, all of which can be optional." 
+ }, + "12":{ + "Type":"NAT", + "Output Value for Type":"nat", + "Definition":"Valuable materials or substances, such as minerals, forests, water, and fertile land, that are not man-made, occur naturally within the environment and can be used for economic gain" + }, + "13":{ + "Type":"ORG", + "Output Value for Type":"org", + "Definition":"Corporations, agencies, and other groups of people defined by an established organizational structure. An ORG entity may be a single organization or a group. A key feature of an ORG is that it can change members without changing identity." + }, + "14":{ + "Type":"PER", + "Output Value for Type":"per", + "Definition":"Person entities are limited to humans. A PER entity may be a single person or a group." + }, + "15":{ + "Type":"PLA", + "Output Value for Type":"pla", + "Definition":"Plants\/flora as well as edible fungi such as mushrooms; multicellular living organisms, typically growing in the earth and lacking the power of locomotion, ex. grass and crops such as wheat, beans, fruit, etc." + }, + "16":{ + "Type":"PTH", + "Output Value for Type":"pth", + "Definition":"An infectious microorganism or agent, such as a virus, bacterium, protozoan, prion, viroid, or fungus" + }, + "17":{ + "Type":"RES", + "Output Value for Type":"res", + "Definition":"The results of a voting event. This will cover general results as well as counted results." + }, + "18":{ + "Type":"SEN", + "Output Value for Type":"sen", + "Definition":"The judicial or court sentence in a Justice event, the punishment a judge gives to a defendant found guilty of a crime" + }, + "19":{ + "Type":"SID", + "Output Value for Type":"sid", + "Definition":"The different sides of a conflict, such as philosophical, cultural, ideological, religious, political, guiding philosophical movement or group orientation. 
This will encompass sides of the battle\/conflict, sports fans when salient, and other such affiliations, in addition to religions, political parties, and other philosophies." + }, + "20":{ + "Type":"TTL", + "Output Value for Type":"ttl", + "Definition":"A person\u2019s title or job role" + }, + "21":{ + "Type":"VAL", + "Output Value for Type":"val", + "Definition":"A numerical value or non-numerical value such as an informational property such as color or make or URL" + }, + "22":{ + "Type":"VEH", + "Output Value for Type":"veh", + "Definition":"A physical device primarily designed to move an object from one location to another, by (for example) carrying, flying, pulling, or pushing the transported object. Vehicle entities may or may not have their own power source." + }, + "23":{ + "Type":"WEA", + "Output Value for Type":"wea", + "Definition":"A physical device that is primarily used as an instrument for physically harming or destroying entities" + } + } \ No newline at end of file diff --git a/event_role_ACE.json b/ontology/event_role_ACE.json similarity index 100% rename from event_role_ACE.json rename to ontology/event_role_ACE.json diff --git a/event_role_KAIROS.json b/ontology/event_role_KAIROS.json similarity index 100% rename from event_role_KAIROS.json rename to ontology/event_role_KAIROS.json