From a2ed39395950cc15b9682a2618f873df8451f1df Mon Sep 17 00:00:00 2001 From: dshani Date: Sat, 8 May 2021 15:18:26 +0100 Subject: [PATCH] added visualise to transformer and test --- models/base_transformer.py | 19 ++++--- test.py | 112 +++++++++++++++++++++++++++---------- 2 files changed, 93 insertions(+), 38 deletions(-) diff --git a/models/base_transformer.py b/models/base_transformer.py index 4f740c3..9e9b096 100644 --- a/models/base_transformer.py +++ b/models/base_transformer.py @@ -268,9 +268,9 @@ def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size, self.dropout = nn.Dropout(rate) - def forward(self, x, mask): + def forward(self, x, mask, visualise=False): # x shape is batch_size x input_seq_len - + latents = [] seq_len = x.size()[1] # adding embedding and position encoding @@ -283,9 +283,14 @@ def forward(self, x, mask): x = self.dropout(x) for i in range(self.num_layers): + latents.append(x) x = self.enc_layers[i](x, mask) - return x # (batch_size, input_seq_len, d_model) + if visualise: + return x, latents + + else: + return x # (batch_size, input_seq_len, d_model) class Decoder(nn.Module): @@ -394,9 +399,9 @@ def forward(self, x, enc_output, look_ahead_mask, padding_mask): class Transformer(nn.Module): - def __init__(self, num_layers = 4, num_heads = 4, dff = 256, - d_model = 64, input_vocab_size = 1500, target_vocab_size = 1500, - pe_input = 1500, pe_target = 1500, rate=0.1): + def __init__(self, num_layers=4, num_heads=4, dff=256, + d_model=64, input_vocab_size=1500, target_vocab_size=1500, + pe_input=1500, pe_target=1500, rate=0.1): super(Transformer, self).__init__() self.encoder = Encoder(num_layers, d_model, num_heads, dff, @@ -428,4 +433,4 @@ def forward(self, inp, tar, enc_mask, look_ahead_mask, dec_mask): out, attn = model(x, y, None, None, None) - print(out) \ No newline at end of file + print(out) diff --git a/test.py b/test.py index 7b11a1e..c3a5af3 100644 --- a/test.py +++ b/test.py @@ -4,7 +4,7 @@ import torch import torch.nn.functional as F -import numpy as np +import numpy as np import time from tokenizers import Tokenizer from common.preprocess import detokenize, tokenize @@ -16,7 +16,8 @@ from common.test_arguments import test_parser from hyperparams.loader import Loader from common.utils import to_devices, accuracy_fn, mask_after_stop, get_all_directions, get_pairs, get_directions - +from sklearn.manifold import TSNE +import matplotlib.pyplot as plt def greedy_search(x, y, y_tar, model, enc_mask=None): """Inference loop taking the most probable token at each step.""" @@ -38,29 +39,34 @@ def greedy_search(x, y, y_tar, model, enc_mask=None): return torch.cat(y_pred, dim=1) -def single_beam_search(x, y, y_tar, model, enc_mask=None, beam_length=2): +def single_beam_search(x, y, y_tar, model, enc_mask=None, beam_length=2, visualise=False): """ x : (seq_len) y : ([]) tensor of start token max_len maximum length greater than the length of x to consider """ - x_enc = model.encoder(x.unsqueeze(0), enc_mask.unsqueeze(0)) # (1, seq_len, d_model) - x_enc = x_enc.repeat(beam_length, 1, 1) # (beam, seq_len, d_model) + x_enc, x_latents = model.encoder(x.unsqueeze(0), enc_mask.unsqueeze(0), visualise) # (1, seq_len, d_model) + x_enc = x_enc.repeat(beam_length, 1, 1) # (beam, seq_len, d_model) + + if visualise: + y_inp, y_tar = y[:, :-1], y[:, 1:] + enc_mask_aux = base_transformer.create_mask(y_inp) + y_enc, y_latents = model.encoder(y_inp.unsqueeze(0), enc_mask_aux.unsqueeze(0), visualise) + decode = lambda y: F.log_softmax(model.final_layer(model.decoder(y, x_enc, None, None)[0]), dim=-1) - y = y.reshape(1, 1).repeat(beam_length, 1) # (beam, 1) + y = y.reshape(1, 1).repeat(beam_length, 1) # (beam, 1) log_p = torch.zeros(beam_length).to(y.device) for t in range(y_tar.size(0)): with torch.no_grad(): - # expand beams - y_pred = decode(y)[:, -1, :] # (beam, vocab) - new_log_p = (log_p.unsqueeze(-1) + y_pred).reshape(-1) # (beam * vocab) + y_pred = decode(y)[:, -1, :] # (beam, vocab) + new_log_p = (log_p.unsqueeze(-1) + y_pred).reshape(-1) # (beam * vocab) # trim beams - log_p, beam_idxs = torch.topk(new_log_p, beam_length) # (beam,) + log_p, beam_idxs = torch.topk(new_log_p, beam_length) # (beam,) beam_id, new_token = beam_idxs // y_pred.size(-1), beam_idxs % y_pred.size(-1) # update input @@ -68,23 +74,32 @@ def single_beam_search(x, y, y_tar, model, enc_mask=None, beam_length=2): best_beam = log_p.argmax() y = y[best_beam] + if visualise: + return y[1, :], x_latents, y_latents + else: + return y[1:] - return y[1:] - -def beam_search(x, y, y_tar, model, enc_mask=None, beam_length=2): +def beam_search(x, y, y_tar, model, enc_mask=None, beam_length=2, visualise=False): + latents_x = [] + latents_y = [] preds = [] for i in range(x.size(0)): enc_mask_i = enc_mask[i] if enc_mask is not None else None - preds.append(single_beam_search( + pred_i, x_latent, y_latent = single_beam_search( x[i], y[i], y_tar[i], model, - enc_mask=enc_mask_i, beam_length=beam_length) - ) - return torch.stack(preds, dim=0) + enc_mask=enc_mask_i, beam_length=beam_length, visualise=visualise) + preds.append(pred_i) + if visualise: + latents_x.append(x_latent) + latents_y.append(y_latent) + return torch.stack(preds, dim=0), latents_x, latents_y + else: + return torch.stack(preds, dim=0) def inference_step(x, y, model, logger, tokenizer, device, bleu=None, - teacher_forcing=False, pivot_mode=False, beam_length=1): + teacher_forcing=False, pivot_mode=False, beam_length=1, visualise=False): """ inference step. x: source language @@ -125,19 +140,27 @@ def inference_step(x, y, model, logger, tokenizer, device, bleu=None, if beam_length == 1: y_pred = greedy_search(x, y, y_tar, model, enc_mask=enc_mask) else: - y_pred = beam_search(x, y, y_tar, model, enc_mask=enc_mask, beam_length=beam_length) - - if not pivot_mode: - batch_acc = 0 - if bleu is not None: - bleu(y_pred, y_tar) - logger.log_examples(x, y_tar, y_pred, tokenizer) - return batch_acc + if visualise: + y_pred, latents_x, latents_y = beam_search(x, y, y_tar, model, enc_mask=enc_mask, + beam_length=beam_length, visualise=visualise) + else: + y_pred = beam_search(x, y, y_tar, model, enc_mask=enc_mask, beam_length=beam_length, + visualise=visualise) + + if visualise: + return latents_x, latents_y else: - return y_pred + if not pivot_mode: + batch_acc = 0 + if bleu is not None: + bleu(y_pred, y_tar) + logger.log_examples(x, y_tar, y_pred, tokenizer) + return batch_acc + else: + return y_pred -def test(device, params, test_dataloader, tokenizer, verbose=50): +def test(device, params, test_dataloader, tokenizer, verbose=50, visualise=False): """Test loop""" logger = logging.TestLogger(params) @@ -176,6 +199,33 @@ def test(device, params, test_dataloader, tokenizer, verbose=50): direction = params.langs[0] + '-' + params.langs[1] logger.log_results([direction, test_acc, test_bleu]) logger.dump_examples() + if visualise: + x_lat, y_lat = inference_step(x, y, model, logger, tokenizer, device, bleu=bleu, + teacher_forcing=params.teacher_forcing, + beam_length=params.beam_length, visualise=visualise) + x_latents = [] + y_latents = [] + layer_length = len(x_lat[0]) + for i in range(layer_length): + latents_i =[] + for j in range(len(x_lat)): + latents_i.append(x_lat[j][i]) + x_latents.append(latents_i) + + layer_length = len(y_lat[0]) + for i in range(layer_length): + latents_i = [] + for j in range(len(y_lat)): + latents_i.append(y_lat[j][i]) + y_latents.append(latents_i) + color = [1]*len(x_latents[0]) + [0]*len(y_latents[0]) + for i in range(len(x_latents)): + X = x_latents[i] + y_latents[i] + + tsne = TSNE() + transformed = tsne.fit_transform(X) + plt.scatter(transformed[:, 0], transformed[:, 1], color=color) + def multi_test(device, params, test_dataloader, tokenizer, verbose=50): @@ -190,12 +240,12 @@ def multi_test(device, params, test_dataloader, tokenizer, verbose=50): assert tokenizer is not None add_targets = preprocess.AddTargetTokens(params.langs, tokenizer) - pair_accs = {s+'-'+t : 0.0 for s, t in get_pairs(params.langs)} + pair_accs = {s + '-' + t: 0.0 for s, t in get_pairs(params.langs)} pair_bleus = {} for s, t in get_pairs(params.langs, excluded=params.excluded): _bleu = BLEU() _bleu.set_excluded_indices([0, 2]) - pair_bleus[s+'-'+t] = _bleu + pair_bleus[s + '-' + t] = _bleu test_acc = 0.0 start_ = time.time() @@ -287,7 +337,7 @@ def main(params): else: try: tokenizers = [Tokenizer.from_file(params.location + '/' + lang + '_tokenizer.json') for lang in - params.langs] + params.langs] except: tokenizers = None