diff --git a/images/vscode_attention.gif b/images/vscode_attention.gif new file mode 100644 index 0000000..1bb6f20 Binary files /dev/null and b/images/vscode_attention.gif differ diff --git a/notebooks/analyze_attention.ipynb b/notebooks/analyze_attention.ipynb new file mode 100644 index 0000000..ee66d2c --- /dev/null +++ b/notebooks/analyze_attention.ipynb @@ -0,0 +1,746 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[](https://github.com/lab-ml/python_autocomplete)\n", + "[](https://colab.research.google.com/github/lab-ml/python_autocomplete/blob/master/notebooks/analyze_attention.ipynb)\n", + "\n", + "# Analyze attention in Python Autocomplete model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install labml labml_python_autocomplete" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import string\n", + "import json\n", + "\n", + "import torch\n", + "from torch import nn\n", + "\n", + "import numpy as np\n", + "\n", + "from labml import experiment, logger, lab, analytics\n", + "from labml_helpers.module import Module\n", + "from labml.analytics import ModelProbe\n", + "from labml.logger import Text, Style, inspect\n", + "from labml.utils.pytorch import get_modules\n", + "from labml.utils.cache import cache\n", + "from labml_helpers.datasets.text import TextDataset\n", + "\n", + "from python_autocomplete.train import Configs\n", + "from python_autocomplete.evaluate import Predictor\n", + "from python_autocomplete.evaluate.beam_search import NextWordPredictionComplete" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We load the model from a training run. For this demo I'm loading from a run I trained at home.\n", + "\n", + "[](https://web.lab-ml.com/run?uuid=39b03a1e454011ebbaff2b26e3148b3d)\n", + "\n", + "If you have a locally trained model load it directly with:\n", + "\n", + "```python\n", + "run_uuid = 'RUN_UUID'\n", + "checkpoint = None # Get latest checkpoint\n", + "```\n", + "\n", + "`load_bundle` will download an archive with a saved checkpoint (pretrained model)." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
Run a6cff3706ec411ebadd9bf753b33bae6 exists\n", + "Checkpoint 702925824 exists\n", + "Extract bundle...[DONE]\t1,613.73ms\n", + "" + ], + "text/plain": [ + "
Prepare model...\n", + " Prepare n_tokens...\n", + " Prepare tokenizer...[DONE]\t7.86ms\n", + " Prepare n_tokens...[DONE]\t13.67ms\n", + " Prepare transformer...[DONE]\t2.67ms\n", + " Prepare ffn...[DONE]\t1.20ms\n", + " Prepare device...\n", + " Prepare device_info...[DONE]\t1.80ms\n", + " Prepare device...[DONE]\t4.01ms\n", + "Prepare model...[DONE]\t219.50ms\n", + "" + ], + "text/plain": [ + "
Selected experiment = source_code run = a6cff3706ec411ebadd9bf753b33bae6 checkpoint = 702,925,824\n", + "Loading checkpoint...[DONE]\t404.82ms\n", + "\n", + "Notebook Experiment: c674962e784e11ec8390acde48001122\n", + "\t[dirty]: \"analyze attentions\"\n", + "\tloaded from: a6cff3706ec411ebadd9bf753b33bae6" + ], + "text/plain": [ + "
Prepare state_updater...[DONE]\t3.44ms\n", + "" + ], + "text/plain": [ + "
0: \"h\"\n", + "1: \"n\"\n", + "2: \", \"\n", + "3: \"c\"\n", + "4: \"n\"\n", + "Total 5 item(s)" + ], + "text/plain": [ + "
dtype: torch.float32\n", + "shape: [1, 1097]\n", + "min: 0.000000 max: 0.611870 mean: 0.000912 std: 0.020015\n", + "[\n", + " [0.000000, 0.000000, 0.000010, ..., 0.000000 ]\n", + "]" + ], + "text/plain": [ + "
0: \"transformer.layers.0.self_attn.softmax\"\n", + "1: \"transformer.layers.1.self_attn.softmax\"\n", + "2: \"transformer.layers.2.self_attn.softmax\"\n", + "3: \"transformer.layers.3.self_attn.softmax\"\n", + "4: \"transformer.layers.4.self_attn.softmax\"\n", + "5: \"transformer.layers.5.self_attn.softmax\"\n", + "Total 6 item(s)" + ], + "text/plain": [ + "
0: 279\n", + "1: 279\n", + "2: 1\n", + "3: 8\n", + "Total 4 item(s)" + ], + "text/plain": [ + "
dtype: torch.float32\n", + "shape: [6, 8, 279, 279]\n", + "min: 0.000000 max: 1.000000 mean: 0.003584 std: 0.026890\n", + "[\n", + " [\n", + " [[1.000000, 0.000000, 0.000000, ..., 0.000000], [0.993492, 0.006508, 0.000000, ..., 0.000000], [0.991854, 0.008065, 0.000081, ..., 0.000000], ..., [0.003070, 0.000385, 0.003941, ..., 0.000002]], \n", + " [[1.000000, 0.000000, 0.000000, ..., 0.000000], [0.996190, 0.003810, 0.000000, ..., 0.000000], [0.761333, 0.231494, 0.007174, ..., 0.000000], ..., [0.007664, 0.000290, 0.000628, ..., 0.000009]], \n", + " [[1.000000, 0.000000, 0.000000, ..., 0.000000], [0.999664, 0.000336, 0.000000, ..., 0.000000], [0.780615, 0.219290, 0.000096, ..., 0.000000], ..., [0.001629, 0.000160, 0.000875, ..., 0.000001]], \n", + " ..., \n", + " [[1.000000, 0.000000, 0.000000, ..., 0.000000], [0.999970, 0.000030, 0.000000, ..., 0.000000], [0.349430, 0.616987, 0.033583, ..., 0.000000], ..., [0.000142, 0.000123, 0.000777, ..., 0.000005]]\n", + " ], \n", + " [\n", + " [[1.000000, 0.000000, 0.000000, ..., 0.000000], [0.999963, 0.000037, 0.000000, ..., 0.000000], [0.893664, 0.099487, 0.006848, ..., 0.000000], ..., [0.000250, 0.000073, 0.000148, ..., 0.000001]], \n", + " [[1.000000, 0.000000, 0.000000, ..., 0.000000], [0.999436, 0.000564, 0.000000, ..., 0.000000], [0.733821, 0.216545, 0.049634, ..., 0.000000 ..." + ], + "text/plain": [ + "
Run a6cff3706ec411ebadd9bf753b33bae6 exists\n", + "Checkpoint 702925824 exists\n", + "Extract bundle...[DONE]\t1,534.94ms\n", + "" + ], + "text/plain": [ + "
Prepare model...\n", " Prepare n_tokens...\n", - " Prepare tokenizer...[DONE]\t3.41ms\n", - " Prepare n_tokens...[DONE]\t9.35ms\n", - " Prepare transformer...[DONE]\t1.55ms\n", - " Prepare ffn...[DONE]\t1.28ms\n", + " Prepare tokenizer...[DONE]\t4.84ms\n", + " Prepare n_tokens...[DONE]\t10.12ms\n", + " Prepare transformer...[DONE]\t2.95ms\n", + " Prepare ffn...[DONE]\t1.67ms\n", " Prepare device...\n", - " Prepare device_info...[DONE]\t70.64ms\n", - " Prepare device...[DONE]\t72.66ms\n", - "Prepare model...[DONE]\t1,830.13ms\n", + " Prepare device_info...[DONE]\t1.84ms\n", + " Prepare device...[DONE]\t4.15ms\n", + "Prepare model...[DONE]\t147.69ms\n", "" ], "text/plain": [ @@ -306,11 +326,11 @@ { "data": { "text/html": [ - "
Selected experiment = source_code run = a6cff3706ec411ebadd9bf753b33bae6 checkpoint = 702925824\n", - "Loading checkpoint...[DONE]\t107.99ms\n", + "Selected experiment = source_code run = a6cff3706ec411ebadd9bf753b33bae6 checkpoint = 702,925,824\n", + "Loading checkpoint...[DONE]\t389.90ms\n", "\n", - "Notebook Experiment: 294627686f4f11eb9fbb718949888c4e\n", - "[clean]: \"cleanup\"\n", + "Notebook Experiment: 07e6549a74ed11ec97e2acde48001122\n", + "\t[dirty]: \"serve without debug\"\n", "\tloaded from: a6cff3706ec411ebadd9bf753b33bae6" ], "text/plain": [ @@ -323,7 +343,7 @@ { "data": { "text/plain": [ - "" + " " ] }, "execution_count": 10, @@ -352,7 +372,7 @@ { "data": { "text/html": [ - " Prepare state_updater...[DONE]\t4.58ms\n", + "Prepare state_updater...[DONE]\t4.30ms\n", "" ], "text/plain": [ @@ -385,6 +405,22 @@ "_ = conf.model.eval()" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Setup probing to extract attentions" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "probe = ModelProbe(conf.model)" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -394,7 +430,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 20, "metadata": {}, "outputs": [], "source": [ @@ -422,7 +458,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 21, "metadata": {}, "outputs": [], "source": [ @@ -434,26 +470,26 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 53.5 ms, sys: 193 µs, total: 53.6 ms\n", - "Wall time: 52.2 ms\n" + "CPU times: user 193 ms, sys: 47.4 ms, total: 241 ms\n", + "Wall time: 218 ms\n" ] }, { "data": { "text/plain": [ - "[(0.07585373724739952, 'super'),\n", - " (0.0010850475927322429, '\"\"\"\\n '),\n", - " (0.0007989557343535125, ' ')]" + "[(0.07585359086023313, 'super'),\n", + " (0.0010850478390376612, '\"\"\"\\n '),\n", + " (0.0007989550358615816, ' ')]" ] }, - "execution_count": 33, + "execution_count": 22, "metadata": {}, "output_type": "execute_result" } @@ -482,7 +518,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.5" + "version": "3.7.5" } }, "nbformat": 4, diff --git a/python_autocomplete/evaluate/__init__.py b/python_autocomplete/evaluate/__init__.py index 294bef7..98d0158 100644 --- a/python_autocomplete/evaluate/__init__.py +++ b/python_autocomplete/evaluate/__init__.py @@ -1,164 +1,16 @@ -from heapq import heappush, heappop -from typing import Any, Tuple, List, Optional, NamedTuple +from typing import Any, Tuple, List, NamedTuple import torch import torch.nn from torch import nn +from labml import monit from labml_helpers.module import Module -from python_autocomplete.dataset import Tokenizer, ID_CHARS +from python_autocomplete.dataset import Tokenizer +from python_autocomplete.evaluate.beam_search import PredictionComplete, BeamSearch, BeamSearchSimple +from python_autocomplete.evaluate.beam_search_lengthy import BeamSearchLengthy from python_autocomplete.train import StateUpdater -EPS_PROB = 1e-6 -MIN_BEAM_PROB = 1e-4 - - -class PredictionComplete: - def __call__(self, text, token_str: str): - raise NotImplementedError - - -class NextWordPredictionComplete(PredictionComplete): - def __init__(self, rest: str, min_length: int): - self.min_length = min_length - self.rest = rest - - def __call__(self, text, token_str: str): - if len(text) - 
len(self.rest) < self.min_length: - return False - - prev_is_id = text[-1] in ID_CHARS - last_is_id = token_str[-1] in ID_CHARS - - return prev_is_id != last_is_id - - -class BeamSearch: - def __init__(self, beam_size: int, prediction_complete: PredictionComplete, - max_beam_size: int, rest: str, - state_updater: 'StateUpdater', - probs: Optional[List[float]], - is_token_by_token: bool): - self.is_token_by_token = is_token_by_token - self.state_updater = state_updater - self.prediction_complete = prediction_complete - self.max_beam_size = max_beam_size - self.rest = rest - - if probs is None: - probs = [1 / beam_size] * beam_size - assert len(probs) == beam_size - self.probs = probs - - self.result_heap = [] - self.text = [''] * beam_size - self.beam_heap = [] - - @staticmethod - def is_substr(original, token_str): - if not original: - return True - - n = min(len(original), len(token_str)) - return original[:n] == token_str[:n] - - def add_prediction(self, prob: float, beam_idx: int, token_str: str, state): - if len(self.result_heap) == self.max_beam_size: - if self.result_heap[0][0] > prob - EPS_PROB: - return False - heappop(self.result_heap) - - state = self.state_updater.get_from_batch(state, beam_idx) - text = self.text[beam_idx] + token_str - heappush(self.result_heap, (prob, (text, state))) - - return True - - def add_prediction_before_token(self, prob: float, beam_idx: int, state): - if len(self.result_heap) == self.max_beam_size: - if self.result_heap[0][0] > prob - EPS_PROB: - return False - heappop(self.result_heap) - - state = self.state_updater.get_from_batch(state, beam_idx) - text = self.text[beam_idx] - heappush(self.result_heap, (prob, (text, state))) - - return True - - def add_beam(self, prob: float, beam_idx: int, token: int): - if self.result_heap and self.result_heap[0][0] > prob - EPS_PROB: - return False - - if prob < MIN_BEAM_PROB: - return False - - if len(self.beam_heap) == self.max_beam_size: - if self.beam_heap[0][0] > prob - EPS_PROB: - return False - heappop(self.beam_heap) - - heappush(self.beam_heap, (prob, (beam_idx, token))) - - return True - - def next_batch(self, prompt: torch.Tensor, state: Any, itos: List[str]): - if not self.beam_heap: - return None, None - - new_prompt = [] - new_state = [] - - texts = self.text - self.text = [] - self.probs = [] - - for prob, (b, token) in self.beam_heap: - token = prompt.new_tensor([token]) - if self.is_token_by_token: - new_prompt.append(token) - else: - new_prompt.append(torch.cat((prompt[1:, b], token))) - new_state.append(self.state_updater.get_from_batch(state, b)) - self.probs.append(prob) - self.text.append(texts[b] + itos[token]) - - new_prompt = torch.stack(new_prompt, dim=1) - new_state = self.state_updater.make_batch(new_state) - - self.beam_heap = [] - - return new_prompt, new_state - - def update(self, next_token, itos: List[str], state, old_state): - self.beam_heap = [] - - for b, text in enumerate(self.text): - text = self.text[b] - if len(text) >= len(self.rest): - check_rest = None - else: - check_rest = self.rest[len(text):] - - tokens = next_token[b] - sort_idx = torch.argsort(tokens) - - for i in reversed(range(len(tokens))): - token = sort_idx[i] - token_str = itos[token] - if not self.is_substr(check_rest, token_str): - continue - - if self.prediction_complete(text, token_str): - if not self.add_prediction_before_token(self.probs[b], b, old_state): - break - else: - break - # if not self.add_prediction(self.probs[b] * tokens[token].item(), b, token_str, state): - # break - if not 
self.add_beam(self.probs[b] * tokens[token].item(), b, token): - break - class Prediction(NamedTuple): prob: float @@ -199,13 +51,21 @@ def get_next_word(self, prompt: torch.Tensor, state: Any, rest: str, probs: List prediction_complete: PredictionComplete, max_beam_size: int) -> \ List[Prediction]: - beam = BeamSearch(prompt.shape[1], prediction_complete, max_beam_size, rest, self.state_updater, - probs, self.is_token_by_token) + beam = BeamSearchSimple(beam_size=prompt.shape[1], + prediction_complete=prediction_complete, + max_beam_size=max_beam_size, + rest=rest, + state_updater=self.state_updater, + probs=probs, + is_token_by_token=self.is_token_by_token, + itos=self.tokenizer.itos) for _ in range(10): - next_token, new_state = self._get_predictions(prompt, state) - beam.update(next_token, self.tokenizer.itos, new_state, state) - prompt, state = beam.next_batch(prompt, new_state, self.tokenizer.itos) + with monit.section('Predict', is_silent=True): + next_token, new_state = self._get_predictions(prompt, state) + with monit.section('Beam', is_silent=True): + beam.update(next_token, new_state, state) + prompt, state = beam.next_batch(prompt, new_state) if prompt is None: break diff --git a/python_autocomplete/evaluate/beam_search.py b/python_autocomplete/evaluate/beam_search.py new file mode 100644 index 0000000..0bc98fd --- /dev/null +++ b/python_autocomplete/evaluate/beam_search.py @@ -0,0 +1,190 @@ +from heapq import heappush, heappop +from typing import Any, List, Optional + +import torch +import torch.nn + +from python_autocomplete.dataset import ID_CHARS +from python_autocomplete.train import StateUpdater + +EPS_PROB = 1e-6 +MIN_BEAM_PROB = 1e-4 + + +class PredictionComplete: + def __call__(self, text, token_str: str): + raise NotImplementedError + + +class NextWordPredictionComplete(PredictionComplete): + def __init__(self, rest: str, min_length: int): + self.min_length = min_length + self.rest = rest + + def __call__(self, text, token_str: str): + if len(text) - len(self.rest) < self.min_length: + return False + + prev_is_id = text[-1] in ID_CHARS + last_is_id = token_str[-1] in ID_CHARS + + return prev_is_id != last_is_id + + +class NextWordNewLinePredictionComplete(PredictionComplete): + def __init__(self, rest: str, min_length: int): + self.min_length = min_length + self.rest = rest + + def __call__(self, text, token_str: str): + if len(text) - len(self.rest) < self.min_length: + return False + + if '\n' in token_str: + return True + + prev_is_id = text[-1] in ID_CHARS + last_is_id = token_str[-1] in ID_CHARS + + return prev_is_id != last_is_id + + +class BeamSearch: + def __init__(self): + pass + + def next_batch(self, prompt: torch.Tensor, state: Any): + raise NotImplementedError + + def update(self, next_token, state, old_state): + raise NotImplementedError + + +class BeamSearchSimple(BeamSearch): + def __init__(self, *, beam_size: int, prediction_complete: PredictionComplete, + max_beam_size: int, rest: str, + state_updater: 'StateUpdater', + probs: Optional[List[float]], + is_token_by_token: bool, + itos: List[str]): + super().__init__() + self.itos = itos + self.is_token_by_token = is_token_by_token + self.state_updater = state_updater + self.prediction_complete = prediction_complete + self.max_beam_size = max_beam_size + self.rest = rest + + if probs is None: + probs = [1 / beam_size] * beam_size + assert len(probs) == beam_size + self.probs = probs + + self.result_heap = [] + self.text = [''] * beam_size + self.beam_heap = [] + + @staticmethod + def is_substr(original, 
token_str): + if not original: + return True + + n = min(len(original), len(token_str)) + return original[:n] == token_str[:n] + + def add_prediction(self, prob: float, beam_idx: int, token_str: str, state): + if len(self.result_heap) == self.max_beam_size: + if self.result_heap[0][0] > prob - EPS_PROB: + return False + heappop(self.result_heap) + + state = self.state_updater.get_from_batch(state, beam_idx) + text = self.text[beam_idx] + token_str + heappush(self.result_heap, (prob, (text, state))) + + return True + + def add_prediction_before_token(self, prob: float, beam_idx: int, state): + if len(self.result_heap) == self.max_beam_size: + if self.result_heap[0][0] > prob - EPS_PROB: + return False + heappop(self.result_heap) + + state = self.state_updater.get_from_batch(state, beam_idx) + text = self.text[beam_idx] + heappush(self.result_heap, (prob, (text, state))) + + return True + + def add_beam(self, prob: float, beam_idx: int, token: int): + if self.result_heap and self.result_heap[0][0] > prob - EPS_PROB: + return False + + if prob < MIN_BEAM_PROB: + return False + + if len(self.beam_heap) == self.max_beam_size: + if self.beam_heap[0][0] > prob - EPS_PROB: + return False + heappop(self.beam_heap) + + heappush(self.beam_heap, (prob, (beam_idx, token))) + + return True + + def next_batch(self, prompt: torch.Tensor, state: Any): + if not self.beam_heap: + return None, None + + new_prompt = [] + new_state = [] + + texts = self.text + self.text = [] + self.probs = [] + + for prob, (b, token) in self.beam_heap: + token = prompt.new_tensor([token]) + if self.is_token_by_token: + new_prompt.append(token) + else: + new_prompt.append(torch.cat((prompt[1:, b], token))) + new_state.append(self.state_updater.get_from_batch(state, b)) + self.probs.append(prob) + self.text.append(texts[b] + self.itos[token]) + + new_prompt = torch.stack(new_prompt, dim=1) + new_state = self.state_updater.make_batch(new_state) + + self.beam_heap = [] + + return new_prompt, new_state + + def update(self, next_token, state, old_state): + self.beam_heap = [] + + for b, text in enumerate(self.text): + text = self.text[b] + if len(text) >= len(self.rest): + check_rest = None + else: + check_rest = self.rest[len(text):] + + tokens = next_token[b] + sort_idx = torch.argsort(tokens) + + for i in reversed(range(len(tokens))): + token = sort_idx[i] + token_str = self.itos[token] + if not self.is_substr(check_rest, token_str): + continue + + if self.prediction_complete(text, token_str): + if not self.add_prediction_before_token(self.probs[b], b, old_state): + break + else: + break + # if not self.add_prediction(self.probs[b] * tokens[token].item(), b, token_str, state): + # break + if not self.add_beam(self.probs[b] * tokens[token].item(), b, token): + break diff --git a/python_autocomplete/evaluate/beam_search_lengthy.py b/python_autocomplete/evaluate/beam_search_lengthy.py new file mode 100644 index 0000000..c589316 --- /dev/null +++ b/python_autocomplete/evaluate/beam_search_lengthy.py @@ -0,0 +1,151 @@ +from heapq import heappush, heappop +from typing import Any, List, Optional + +import torch +import torch.nn + +from python_autocomplete.evaluate.beam_search import PredictionComplete, BeamSearch +from python_autocomplete.train import StateUpdater + +EPS_SCORE = 1e-6 +MIN_BEAM_PROB = 1e-4 + + +class BeamSearchLengthy(BeamSearch): + """ + Use a score instead of probability to make longer predictions + """ + def __init__(self, *, beam_size: int, prediction_complete: PredictionComplete, + max_beam_size: int, rest: str, + 
state_updater: 'StateUpdater', + probs: Optional[List[float]], + is_token_by_token: bool, + itos: List[str]): + super().__init__() + self.itos = itos + self.is_token_by_token = is_token_by_token + self.state_updater = state_updater + self.prediction_complete = prediction_complete + self.max_beam_size = max_beam_size + self.rest = rest + + if probs is None: + probs = [1 / beam_size] * beam_size + assert len(probs) == beam_size + self.probs = probs + + self.result_heap = [] + self.text = [''] * beam_size + self.beam_heap = [] + + @staticmethod + def is_prefix(original, token_str): + """ + Whether `token_str` is a prefix of `original` + """ + if not original: + return True + + n = min(len(original), len(token_str)) + return original[:n] == token_str[:n] + + @staticmethod + def _get_score(text: str, prob: float): + return len(text) ** 0.2 * prob + + def _add_prediction_before_token(self, prob: float, beam_idx: int, state): + """ + Add a prediction before the last token to the results + """ + text = self.text[beam_idx] + score = self._get_score(text, prob) + + if len(self.result_heap) == self.max_beam_size: + if self.result_heap[0][0] > score - EPS_SCORE: + return + heappop(self.result_heap) + + state = self.state_updater.get_from_batch(state, beam_idx) + heappush(self.result_heap, (score, (text, state))) + + def _add_beam(self, prob: float, beam_idx: int, token: int): + """Add to the beam""" + text = self.text[beam_idx] + self.itos[token] + score = self._get_score(text, prob) + + # if self.result_heap and self.result_heap[0][0] > score - EPS_SCORE: + # return False + + if prob < MIN_BEAM_PROB: + return False + + if len(self.beam_heap) == self.max_beam_size: + if self.beam_heap[0][0] > score - EPS_SCORE: + return False + heappop(self.beam_heap) + + heappush(self.beam_heap, (score, (beam_idx, prob, token))) + + return True + + def next_batch(self, prompt: torch.Tensor, state: Any): + """Get the next batch (beam)""" + if not self.beam_heap: + return None, None + + new_prompt = [] + new_state = [] + + texts = self.text + self.text = [] + self.probs = [] + + for _, (b, prob, token) in self.beam_heap: + token = prompt.new_tensor([token]) + if self.is_token_by_token: + new_prompt.append(token) + else: + new_prompt.append(torch.cat((prompt[1:, b], token))) + new_state.append(self.state_updater.get_from_batch(state, b)) + self.probs.append(prob) + self.text.append(texts[b] + self.itos[token]) + + new_prompt = torch.stack(new_prompt, dim=1) + new_state = self.state_updater.make_batch(new_state) + + self.beam_heap = [] + + return new_prompt, new_state + + def update(self, next_token, state, old_state): + """Update beam search with the sampled data""" + self.beam_heap = [] + + for b, text in enumerate(self.text): + text = self.text[b] + if len(text) >= len(self.rest): + check_rest = None + else: + check_rest = self.rest[len(text):] + + tokens = next_token[b] + sort_idx = torch.argsort(tokens) + added_to_results = False + + for i in reversed(range(len(tokens))): + token = sort_idx[i] + token_str = self.itos[token] + if not self.is_prefix(check_rest, token_str): + continue + + if not added_to_results and self.prediction_complete(text, token_str): + added_to_results = True + self._add_prediction_before_token(self.probs[b], b, old_state) + + if check_rest and len(token_str) <= len(check_rest): + p = 1.0 + else: + p = tokens[token].item() + + if not self._add_beam(self.probs[b] * p, b, token): + break diff --git a/python_autocomplete/evaluate/eval_sample.py b/python_autocomplete/evaluate/eval_sample.py index 
e98da7e..d7d25d6 100644 --- a/python_autocomplete/evaluate/eval_sample.py +++ b/python_autocomplete/evaluate/eval_sample.py @@ -2,7 +2,8 @@ from labml import logger, lab, monit from labml.logger import Text, Style -from python_autocomplete.evaluate import NextWordPredictionComplete, Predictor +from python_autocomplete.evaluate import Predictor +from python_autocomplete.evaluate.beam_search import NextWordPredictionComplete from python_autocomplete.evaluate.factory import get_predictor diff --git a/python_autocomplete/evaluate/factory.py b/python_autocomplete/evaluate/factory.py index 259325b..eb7a921 100644 --- a/python_autocomplete/evaluate/factory.py +++ b/python_autocomplete/evaluate/factory.py @@ -1,4 +1,4 @@ -from labml import experiment +from labml import experiment, lab from labml.utils.pytorch import get_modules from python_autocomplete.evaluate import Predictor from python_autocomplete.train import Configs @@ -16,14 +16,15 @@ def load_experiment() -> Configs: # And for latest checkpoint # checkpoint = None - run_uuid = 'a6cff3706ec411ebadd9bf753b33bae6' # bpe - checkpoint = None - # run_uuid, checkpoint = experiment.load_bundle( - # lab.get_path() / 'saved_checkpoint.tar.gz', - # url='https://github.com/lab-ml/python_autocomplete/releases/download/0.0.4/transformer_checkpoint.tar.gz') + # run_uuid = 'a6cff3706ec411ebadd9bf753b33bae6' # bpe + # checkpoint = None + run_uuid, checkpoint = experiment.load_bundle( + lab.get_path() / 'saved_checkpoint.tar.gz', + url='https://github.com/lab-ml/python_autocomplete/releases/download/0.0.5/bundle.tar.gz') conf_dict = experiment.load_configs(run_uuid) conf_dict['text.is_load_data'] = False + conf_dict['device.cuda_device'] = 1 experiment.configs(conf, conf_dict) experiment.add_pytorch_models(get_modules(conf)) experiment.load(run_uuid, checkpoint) diff --git a/python_autocomplete/serve.py b/python_autocomplete/serve.py index 5889ff5..7084b85 100644 --- a/python_autocomplete/serve.py +++ b/python_autocomplete/serve.py @@ -5,7 +5,7 @@ from flask import Flask, request, jsonify from labml import monit -from python_autocomplete.evaluate import NextWordPredictionComplete +from python_autocomplete.evaluate.beam_search import NextWordPredictionComplete from python_autocomplete.evaluate.factory import get_predictor app = Flask('python_autocomplete') @@ -29,20 +29,21 @@ def autocomplete(): if acquired: stripped, prompt = predictor.rstrip(prefix) rest = prefix[len(stripped):] - prediction_complete = NextWordPredictionComplete(rest, 5) + prediction_complete = NextWordPredictionComplete(rest, 15) prompt = torch.tensor(prompt, dtype=torch.long).unsqueeze(-1) predictions = predictor.get_next_word(prompt, None, rest, [1.], prediction_complete, 5) - predictions.sort(key=lambda x: -x[0]) + predictions.sort(key=lambda x: -x.prob) results = [pred.text[len(rest):] for pred in predictions] + probs = [pred.prob for pred in predictions] lock.release() s.message = f'{json.dumps(prefix[-5:])} -> {json.dumps(results)}' - return jsonify({'success': True, 'prediction': results}) + return jsonify({'success': True, 'prediction': results, 'probs': probs}) else: monit.fail() return jsonify({'success': False}) if __name__ == '__main__': - app.run(host='0.0.0.0', port=5000, debug=True) + app.run(host='0.0.0.0', port=5000, debug=False) diff --git a/readme.md b/readme.md index 5bbd05f..fd21aaa 100644 --- a/readme.md +++ b/readme.md @@ -5,6 +5,12 @@ # Python Autocomplete ++
+ +[The full length Python autocompletion Video](https://www.youtube.com/watch?v=ZFzxBPBUh0M) and a [Twitter thread describing how it works](https://twitter.com/labmlai/status/1367444214963838978) + This is a learning/demo project to show how deep learning can be used to auto complete Python code. You can experiment with LSTM and Transformer models. We also have built a simple VSCode extension to try out the trained models. diff --git a/setup.py b/setup.py index bc2ca82..be415ff 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setuptools.setup( name='labml_python_autocomplete', - version='0.0.5', + version='0.0.7', author="Varuna Jayasiri", author_email="vpjayasiri@gmail.com", description="A simple model that learns to predict Python source code", @@ -21,7 +21,7 @@ 'test', 'test.*')), install_requires=['labml>=0.4.103', 'labml_helpers>=0.4.75', - 'labml_nn>=0.4.88' + 'labml_nn>=0.4.88', 'torch', 'einops', 'numpy'], diff --git a/vscode_extension/editor-settings.json b/vscode_extension/editor-settings.json new file mode 100644 index 0000000..78d80a8 --- /dev/null +++ b/vscode_extension/editor-settings.json @@ -0,0 +1,11 @@ +{ + "editor.quickSuggestions": { + "other": true, + "comments": true, + "strings": true + }, + "editor.quickSuggestionsDelay": 1, + "editor.wordBasedSuggestions": false, + "editor.acceptSuggestionOnCommitCharacter": false, + "editor.autoClosingBrackets": "never" +} diff --git a/vscode_extension/package-lock.json b/vscode_extension/package-lock.json index 045a368..8449668 100644 --- a/vscode_extension/package-lock.json +++ b/vscode_extension/package-lock.json @@ -1,47 +1,8 @@ { "name": "python-autocomplete", "version": "0.0.1", - "lockfileVersion": 2, + "lockfileVersion": 1, "requires": true, - "packages": { - "": { - "name": "python-autocomplete", - "version": "0.0.1", - "devDependencies": { - "@types/node": "*", - "@types/vscode": "*", - "typescript": "*" - }, - "engines": { - "vscode": "^1.32.0" - } - }, - "node_modules/@types/node": { - "version": "12.12.35", - "resolved": "https://registry.npmjs.org/@types/node/-/node-12.12.35.tgz", - "integrity": "sha512-ASYsaKecA7TUsDrqIGPNk3JeEox0z/0XR/WsJJ8BIX/9+SkMSImQXKWfU/yBrSyc7ZSE/NPqLu36Nur0miCFfQ==", - "dev": true - }, - "node_modules/@types/vscode": { - "version": "1.33.0", - "resolved": "https://registry.npmjs.org/@types/vscode/-/vscode-1.33.0.tgz", - "integrity": "sha512-JSmGiValbrcG5g20jjCfKakLiuWyrcjVezj+SEAEZ4klXQktE5EtowuGlkLVqbkiBK4iY5wy/4yW8OjecuHnjQ==", - "dev": true - }, - "node_modules/typescript": { - "version": "4.0.2", - "resolved": "https://registry.npmjs.org/typescript/-/typescript-4.0.2.tgz", - "integrity": "sha512-e4ERvRV2wb+rRZ/IQeb3jm2VxBsirQLpQhdxplZ2MEzGvDkkMmPglecnNDfSUBivMjP93vRbngYYDQqQ/78bcQ==", - "dev": true, - "bin": { - "tsc": "bin/tsc", - "tsserver": "bin/tsserver" - }, - "engines": { - "node": ">=4.2.0" - } - } - }, "dependencies": { "@types/node": { "version": "12.12.35", diff --git a/vscode_extension/src/extension.ts b/vscode_extension/src/extension.ts index 4c0240f..e1f04bf 100644 --- a/vscode_extension/src/extension.ts +++ b/vscode_extension/src/extension.ts @@ -42,13 +42,57 @@ function getPrompt(document: vscode.TextDocument, position: vscode.Position) { return text } -function cleanupPredictions(predictions: string[]): string[] { +function removeNewLine(predictions: string[]): string[] { let res = [] for(let p of predictions) { let nl = p.indexOf('\n') if (nl !== -1) { p = p.substr(0, nl) } + res.push(p) + } + + return res +} + +function trimRight(predictions: string[]): string[] { 
+ let res = [] + + for(let p of predictions) { + p = p.trimRight() + res.push(p) + } + + return res +} + +function removeDuplicates(predictions: string[]): string[] { + let set = new Set<string>
() + + for(let p of predictions) { + if(p !== '') { + set.add(p) + } + } + + let res = [] + for(let p of set) { + res.push(p) + } + + return res +} + +function removeSuffix(predictions: string[], document: vscode.TextDocument, position: vscode.Position): string[] { + const line = document.lineAt(position).text + const text = line.substr(position.character) + let res = [] + + for(let p of predictions) { + let suffix = p.indexOf(text[0]) + if (suffix !== -1) { + p = p.substr(0, suffix) + } if(p !== '') { res.push(p) } @@ -97,6 +141,7 @@ export function activate(context: vscode.ExtensionContext) { async provideCompletionItems(document: vscode.TextDocument, position: vscode.Position, token: vscode.CancellationToken, context: vscode.CompletionContext) { const prompt = getPrompt(document, position) let response + const fetchTime = new Date().getTime() try { response = await fetch(prompt) @@ -110,8 +155,19 @@ } let predictions: string[] = response.prediction + let probs: number[] = response.probs + for(let i = 0; i < probs.length - 1; ++i) { + if(probs[i] > probs[i + 1] * 4) { + predictions = predictions.slice(0, i + 1) + break + } + } + const nl = hasNewLine(predictions) - predictions = cleanupPredictions(predictions) + predictions = removeNewLine(predictions) + predictions = removeSuffix(predictions, document, position) + predictions = trimRight(predictions) + predictions = removeDuplicates(predictions) if (predictions.length === 0) { // If at end of a line just predict new line, to avoid annoying default vscode predictions @@ -120,8 +176,9 @@ simpleCompletion.kind = vscode.CompletionItemKind.Text simpleCompletion.command = { command: 'editor.action.triggerSuggest', title: 'Re-trigger completions...' } return [simpleCompletion] + } else { + return [] } - return [] } // Add any word prefix from text (because thats how vscode works) @@ -132,6 +189,7 @@ predictions = addPrefix(prefix, predictions) } + console.log(`Fetching ${new Date().getTime() - fetchTime}ms`) return getCompletions(predictions, nl) } })
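The `BeamSearchLengthy` class added above scores beams with `len(text) ** 0.2 * prob` instead of the raw probability, which is what lets longer completions survive the beam. A minimal sketch of that trade-off, using made-up candidate strings and probabilities rather than real model output:

```python
# Sketch of the length-weighted scoring used by BeamSearchLengthy._get_score.
# The two candidate completions and their probabilities below are hypothetical,
# chosen only to illustrate the effect.

def get_score(text: str, prob: float) -> float:
    # Same formula as _get_score: a mild length bonus (len ** 0.2) times the beam probability.
    return len(text) ** 0.2 * prob

candidates = [
    ("self", 0.20),             # short completion, higher probability
    ("self.encoder(x)", 0.16),  # longer completion, lower probability
]

for text, prob in candidates:
    print(f"{text!r}: prob={prob:.2f} score={get_score(text, prob):.3f}")

# 'self' scores ~0.264 while 'self.encoder(x)' scores ~0.275, so the longer
# completion wins the beam slot despite its lower raw probability.
```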