diff --git a/images/vscode_attention.gif b/images/vscode_attention.gif new file mode 100644 index 0000000..1bb6f20 Binary files /dev/null and b/images/vscode_attention.gif differ diff --git a/notebooks/analyze_attention.ipynb b/notebooks/analyze_attention.ipynb new file mode 100644 index 0000000..ee66d2c --- /dev/null +++ b/notebooks/analyze_attention.ipynb @@ -0,0 +1,746 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[![Github](https://img.shields.io/github/stars/lab-ml/python_autocomplete?style=social)](https://github.com/lab-ml/python_autocomplete)\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lab-ml/python_autocomplete/blob/master/notebooks/analyze_attention.ipynb)\n", + "\n", + "# Analyze attention in Python Autocomplete model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install labml labml_python_autocomplete" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import string\n", + "import json\n", + "\n", + "import torch\n", + "from torch import nn\n", + "\n", + "import numpy as np\n", + "\n", + "from labml import experiment, logger, lab, analytics\n", + "from labml_helpers.module import Module\n", + "from labml.analytics import ModelProbe\n", + "from labml.logger import Text, Style, inspect\n", + "from labml.utils.pytorch import get_modules\n", + "from labml.utils.cache import cache\n", + "from labml_helpers.datasets.text import TextDataset\n", + "\n", + "from python_autocomplete.train import Configs\n", + "from python_autocomplete.evaluate import Predictor\n", + "from python_autocomplete.evaluate.beam_search import NextWordPredictionComplete" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We load the model from a training run. For this demo I'm loading from a run I trained at home.\n", + "\n", + "[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://web.lab-ml.com/run?uuid=39b03a1e454011ebbaff2b26e3148b3d)\n", + "\n", + "If you have a locally trained model load it directly with:\n", + "\n", + "```python\n", + "run_uuid = 'RUN_UUID'\n", + "checkpoint = None # Get latest checkpoint\n", + "```\n", + "\n", + "`load_bundle` will download an archive with a saved checkpoint (pretrained model)." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
Run a6cff3706ec411ebadd9bf753b33bae6 exists\n",
+       "Checkpoint 702925824 exists\n",
+       "Extract bundle...[DONE]\t1,613.73ms\n",
+       "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# run_uuid = 'a6cff3706ec411ebadd9bf753b33bae6'\n", + "# checkpoint = None\n", + "\n", + "run_uuid, checkpoint = experiment.load_bundle(\n", + " lab.get_path() / 'saved_checkpoint.tar.gz',\n", + " url='https://github.com/lab-ml/python_autocomplete/releases/download/0.0.5/bundle.tar.gz')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We initialize `Configs` object defined in [`train.py`](https://github.com/lab-ml/python_autocomplete/blob/master/python_autocomplete/train.py)." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "conf = Configs()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Create a new experiment in evaluation mode. In evaluation mode a new training run is not created. " + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "experiment.evaluate()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load custom configurations/hyper-parameters used in the training run." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'epochs': 32,\n", + " 'is_token_by_token': True,\n", + " 'mem_len': 256,\n", + " 'model': 'transformer_xl_model',\n", + " 'n_layers': 6,\n", + " 'optimizer.learning_rate': 0.000125,\n", + " 'optimizer.optimizer': 'AdamW',\n", + " 'state_updater': 'transformer_memory',\n", + " 'text.batch_size': 12,\n", + " 'text.is_shuffle': False,\n", + " 'text.seq_len': 256,\n", + " 'text.tokenizer': 'bpe'}" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "custom_conf = experiment.load_configs(run_uuid)\n", + "custom_conf" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Set the custom configurations" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "# custom_conf['device.use_cuda'] = False" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
"
+      ],
+      "text/plain": [
+       ""
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "experiment.configs(conf, custom_conf)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Set models for saving and loading. This will load `conf.model` from the specified run."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "
Prepare model...\n",
+       "  Prepare n_tokens...\n",
+       "    Prepare tokenizer...[DONE]\t7.86ms\n",
+       "  Prepare n_tokens...[DONE]\t13.67ms\n",
+       "  Prepare transformer...[DONE]\t2.67ms\n",
+       "  Prepare ffn...[DONE]\t1.20ms\n",
+       "  Prepare device...\n",
+       "    Prepare device_info...[DONE]\t1.80ms\n",
+       "  Prepare device...[DONE]\t4.01ms\n",
+       "Prepare model...[DONE]\t219.50ms\n",
+       "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "experiment.add_pytorch_models({'model': conf.model})" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Specify which run to load from" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "experiment.load(run_uuid, checkpoint)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Start the experiment" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/html": [ + "
Selected experiment = source_code run = a6cff3706ec411ebadd9bf753b33bae6 checkpoint = 702,925,824\n",
+       "Loading checkpoint...[DONE]\t404.82ms\n",
+       "\n",
+       "Notebook Experiment: c674962e784e11ec8390acde48001122\n",
+       "\t[dirty]: \"analyze attentions\"\n",
+       "\tloaded from: a6cff3706ec411ebadd9bf753b33bae6
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "experiment.start()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Initialize the `Predictor` defined in [`evaluate.py`](https://github.com/lab-ml/python_autocomplete/blob/master/python_autocomplete/evaluate.py).\n", + "\n", + "We load `stoi` and `itos` from cache, so that we don't have to read the dataset to generate them. `stoi` maps each character to an integer index and `itos` is the reverse map from integer index to character. These indices are used in the model embeddings for each character." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
Prepare state_updater...[DONE]\t3.44ms\n",
+       "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "p = Predictor(conf.model, conf.text.tokenizer,\n", + " state_updater=conf.state_updater,\n", + " is_token_by_token=conf.is_token_by_token)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Set model to evaluation mode" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "_ = conf.model.eval()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Setup probing to extract attentions" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "probe = ModelProbe(conf.model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "A Python prompt to test completion." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "PROMPT = \"\"\"from typing import Optional, Tuple\n", + "\n", + "import torch\n", + "from torch import nn\n", + "\n", + "from labml_nn.lstm import LSTM\n", + "from python_autocomplete.models import AutoregressiveModel\n", + "\n", + "\n", + "class LstmModel(AutoregressiveModel):\n", + " def __init__(self, *,\n", + " n_tokens: int,\n", + " embedding_size: int,\n", + " hidden_size: int,\n", + " n_layers: int):\n", + " super().__init__()\n", + "\n", + " self.embedding = nn.Embedding(n_tokens, embedding_size)\n", + " self.lstm = LSTM(input_size=embedding_size,\n", + " hidden_size=hidden_size,\n", + " n_layers=n_layers)\n", + " self.fc = nn.Linear(hidden_size, n_tokens)\n", + "\n", + " def __call__(self, x: torch.Tensor, state: Optional[Tuple[torch.Tensor, torch.Tensor]]):\n", + " # shape of x is [seq, batch, feat]\n", + " x = self.embedding(x)\n", + " out, (hn, cn) = self.lstm(x, state)\n", + " logits = self.fc(out)\n", + "\n", + " return logits, (hn, cn)\n", + "\"\"\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Get a token. `get_token` predicts character by character greedily (no beam search) until it finds an end-of-token character (a non-alphanumeric character)." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "stripped, prompt = p.rstrip(PROMPT)\n", + "rest = PROMPT[len(stripped):]\n", + "prediction_complete = NextWordPredictionComplete(rest, 5)\n", + "prompt = torch.tensor(prompt, dtype=torch.long).unsqueeze(-1)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Let's analyze attentions" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
0: \"h\"\n",
+       "1: \"n\"\n",
+       "2: \", \"\n",
+       "3: \"c\"\n",
+       "4: \"n\"\n",
+       "Total 5 item(s)
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "tokens = [p.tokenizer.itos[i[0]] for i in prompt]\n", + "inspect(tokens[-5:])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's run the Transformer-XL model without cached memory to get the full attention matrix." + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
dtype: torch.float32\n",
+       "shape: [1, 1097]\n",
+       "min: 0.000000 max: 0.611870 mean: 0.000912 std: 0.020015\n",
+       "[\n",
+       " [0.000000, 0.000000, 0.000010, ..., 0.000000 ]\n",
+       "]
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "inspect(p._get_predictions(prompt, None)[0])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We capture the outputs after the [attention softmax](https://nn.labml.ai/transformers/mha.html#section-34)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
0: \"transformer.layers.0.self_attn.softmax\"\n",
+       "1: \"transformer.layers.1.self_attn.softmax\"\n",
+       "2: \"transformer.layers.2.self_attn.softmax\"\n",
+       "3: \"transformer.layers.3.self_attn.softmax\"\n",
+       "4: \"transformer.layers.4.self_attn.softmax\"\n",
+       "5: \"transformer.layers.5.self_attn.softmax\"\n",
+       "Total 6 item(s)
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "inspect(probe.forward_output['*softmax*'])" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "attn = probe.forward_output['*softmax*'].get_list()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Attentions have shape `[source, destination, batch, heads]`" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
0: 279\n",
+       "1: 279\n",
+       "2: 1\n",
+       "3: 8\n",
+       "Total 4 item(s)
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "inspect(attn[0].shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "attn_maps = torch.stack([a.permute(2, 3, 0, 1)[0] for a in attn])" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
dtype: torch.float32\n",
+       "shape: [6, 8, 279, 279]\n",
+       "min: 0.000000 max: 1.000000 mean: 0.003584 std: 0.026890\n",
+       "[\n",
+       " [\n",
+       "  [[1.000000, 0.000000, 0.000000, ..., 0.000000], [0.993492, 0.006508, 0.000000, ..., 0.000000], [0.991854, 0.008065, 0.000081, ..., 0.000000], ..., [0.003070, 0.000385, 0.003941, ..., 0.000002]], \n",
+       "  [[1.000000, 0.000000, 0.000000, ..., 0.000000], [0.996190, 0.003810, 0.000000, ..., 0.000000], [0.761333, 0.231494, 0.007174, ..., 0.000000], ..., [0.007664, 0.000290, 0.000628, ..., 0.000009]], \n",
+       "  [[1.000000, 0.000000, 0.000000, ..., 0.000000], [0.999664, 0.000336, 0.000000, ..., 0.000000], [0.780615, 0.219290, 0.000096, ..., 0.000000], ..., [0.001629, 0.000160, 0.000875, ..., 0.000001]], \n",
+       "  ..., \n",
+       "  [[1.000000, 0.000000, 0.000000, ..., 0.000000], [0.999970, 0.000030, 0.000000, ..., 0.000000], [0.349430, 0.616987, 0.033583, ..., 0.000000], ..., [0.000142, 0.000123, 0.000777, ..., 0.000005]]\n",
+       " ], \n",
+       " [\n",
+       "  [[1.000000, 0.000000, 0.000000, ..., 0.000000], [0.999963, 0.000037, 0.000000, ..., 0.000000], [0.893664, 0.099487, 0.006848, ..., 0.000000], ..., [0.000250, 0.000073, 0.000148, ..., 0.000001]], \n",
+       "  [[1.000000, 0.000000, 0.000000, ..., 0.000000], [0.999436, 0.000564, 0.000000, ..., 0.000000], [0.733821, 0.216545, 0.049634, ..., 0.000000 ... 
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "inspect(attn_maps)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "torch.save( attn_maps, 'attentions.pt')" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "with open('tokens.json', 'w') as f:\n", + " f.write(json.dumps({'src': tokens, 'dst': tokens}))" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "analytics.init_inline_viz()" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "analytics.text_attention(torch.sum(attn_maps, dim=[0,1]), tokens, tokens)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/notebooks/evaluate.ipynb b/notebooks/evaluate.ipynb index fac7250..b0bf355 100644 --- a/notebooks/evaluate.ipynb +++ b/notebooks/evaluate.ipynb @@ -73,7 +73,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 47, "metadata": {}, "outputs": [], "source": [ @@ -82,15 +82,19 @@ "import torch\n", "from torch import nn\n", "\n", + "import numpy as np\n", + "\n", "from labml import experiment, logger, lab\n", "from labml_helpers.module import Module\n", - "from labml.logger import Text, Style\n", + "from labml.analytics import ModelProbe\n", + "from labml.logger import Text, Style, inspect\n", "from labml.utils.pytorch import get_modules\n", "from labml.utils.cache import cache\n", "from labml_helpers.datasets.text import TextDataset\n", "\n", "from python_autocomplete.train import Configs\n", - "from python_autocomplete.evaluate import Predictor, NextWordPredictionComplete" + "from python_autocomplete.evaluate import Predictor\n", + "from python_autocomplete.evaluate.beam_search import NextWordPredictionComplete" ] }, { @@ -115,14 +119,30 @@ "cell_type": "code", "execution_count": 2, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "
Run a6cff3706ec411ebadd9bf753b33bae6 exists\n",
+       "Checkpoint 702925824 exists\n",
+       "Extract bundle...[DONE]\t1,534.94ms\n",
+       "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ - "run_uuid = 'a6cff3706ec411ebadd9bf753b33bae6'\n", - "checkpoint = None\n", + "# run_uuid = 'a6cff3706ec411ebadd9bf753b33bae6'\n", + "# checkpoint = None\n", "\n", - "# run_uuid, checkpoint = experiment.load_bundle(\n", - "# lab.get_path() / 'saved_checkpoint.tar.gz',\n", - "# url='https://github.com/lab-ml/python_autocomplete/releases/download/0.0.4/transformer_checkpoint.tar.gz')" + "run_uuid, checkpoint = experiment.load_bundle(\n", + " lab.get_path() / 'saved_checkpoint.tar.gz',\n", + " url='https://github.com/lab-ml/python_autocomplete/releases/download/0.0.5/bundle.tar.gz')" ] }, { @@ -251,14 +271,14 @@ "text/html": [ "
Prepare model...\n",
        "  Prepare n_tokens...\n",
-       "    Prepare tokenizer...[DONE]\t3.41ms\n",
-       "  Prepare n_tokens...[DONE]\t9.35ms\n",
-       "  Prepare transformer...[DONE]\t1.55ms\n",
-       "  Prepare ffn...[DONE]\t1.28ms\n",
+       "    Prepare tokenizer...[DONE]\t4.84ms\n",
+       "  Prepare n_tokens...[DONE]\t10.12ms\n",
+       "  Prepare transformer...[DONE]\t2.95ms\n",
+       "  Prepare ffn...[DONE]\t1.67ms\n",
        "  Prepare device...\n",
-       "    Prepare device_info...[DONE]\t70.64ms\n",
-       "  Prepare device...[DONE]\t72.66ms\n",
-       "Prepare model...[DONE]\t1,830.13ms\n",
+       "    Prepare device_info...[DONE]\t1.84ms\n",
+       "  Prepare device...[DONE]\t4.15ms\n",
+       "Prepare model...[DONE]\t147.69ms\n",
        "
" ], "text/plain": [ @@ -306,11 +326,11 @@ { "data": { "text/html": [ - "
Selected experiment = source_code run = a6cff3706ec411ebadd9bf753b33bae6 checkpoint = 702925824\n",
-       "Loading checkpoint...[DONE]\t107.99ms\n",
+       "
Selected experiment = source_code run = a6cff3706ec411ebadd9bf753b33bae6 checkpoint = 702,925,824\n",
+       "Loading checkpoint...[DONE]\t389.90ms\n",
        "\n",
-       "Notebook Experiment: 294627686f4f11eb9fbb718949888c4e\n",
-       "[clean]: \"cleanup\"\n",
+       "Notebook Experiment: 07e6549a74ed11ec97e2acde48001122\n",
+       "\t[dirty]: \"serve without debug\"\n",
        "\tloaded from: a6cff3706ec411ebadd9bf753b33bae6
" ], "text/plain": [ @@ -323,7 +343,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 10, @@ -352,7 +372,7 @@ { "data": { "text/html": [ - "
Prepare state_updater...[DONE]\t4.58ms\n",
+       "
Prepare state_updater...[DONE]\t4.30ms\n",
        "
" ], "text/plain": [ @@ -385,6 +405,22 @@ "_ = conf.model.eval()" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Setup probing to extract attentions" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "probe = ModelProbe(conf.model)" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -394,7 +430,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 20, "metadata": {}, "outputs": [], "source": [ @@ -422,7 +458,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 21, "metadata": {}, "outputs": [], "source": [ @@ -434,26 +470,26 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 53.5 ms, sys: 193 µs, total: 53.6 ms\n", - "Wall time: 52.2 ms\n" + "CPU times: user 193 ms, sys: 47.4 ms, total: 241 ms\n", + "Wall time: 218 ms\n" ] }, { "data": { "text/plain": [ - "[(0.07585373724739952, 'super'),\n", - " (0.0010850475927322429, '\"\"\"\\n '),\n", - " (0.0007989557343535125, ' ')]" + "[(0.07585359086023313, 'super'),\n", + " (0.0010850478390376612, '\"\"\"\\n '),\n", + " (0.0007989550358615816, ' ')]" ] }, - "execution_count": 33, + "execution_count": 22, "metadata": {}, "output_type": "execute_result" } @@ -482,7 +518,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.5" + "version": "3.7.5" } }, "nbformat": 4, diff --git a/python_autocomplete/evaluate/__init__.py b/python_autocomplete/evaluate/__init__.py index 294bef7..98d0158 100644 --- a/python_autocomplete/evaluate/__init__.py +++ b/python_autocomplete/evaluate/__init__.py @@ -1,164 +1,16 @@ -from heapq import heappush, heappop -from typing import Any, Tuple, List, Optional, NamedTuple +from typing import Any, Tuple, List, NamedTuple import torch import torch.nn from torch import nn +from labml import monit from labml_helpers.module import Module -from python_autocomplete.dataset import Tokenizer, ID_CHARS +from python_autocomplete.dataset import Tokenizer +from python_autocomplete.evaluate.beam_search import PredictionComplete, BeamSearch, BeamSearchSimple +from python_autocomplete.evaluate.beam_search_lengthy import BeamSearchLengthy from python_autocomplete.train import StateUpdater -EPS_PROB = 1e-6 -MIN_BEAM_PROB = 1e-4 - - -class PredictionComplete: - def __call__(self, text, token_str: str): - raise NotImplementedError - - -class NextWordPredictionComplete(PredictionComplete): - def __init__(self, rest: str, min_length: int): - self.min_length = min_length - self.rest = rest - - def __call__(self, text, token_str: str): - if len(text) - len(self.rest) < self.min_length: - return False - - prev_is_id = text[-1] in ID_CHARS - last_is_id = token_str[-1] in ID_CHARS - - return prev_is_id != last_is_id - - -class BeamSearch: - def __init__(self, beam_size: int, prediction_complete: PredictionComplete, - max_beam_size: int, rest: str, - state_updater: 'StateUpdater', - probs: Optional[List[float]], - is_token_by_token: bool): - self.is_token_by_token = is_token_by_token - self.state_updater = state_updater - self.prediction_complete = prediction_complete - self.max_beam_size = max_beam_size - self.rest = rest - - if probs is None: - probs = [1 / beam_size] * beam_size - assert len(probs) == beam_size - self.probs = probs - - self.result_heap = [] - self.text = [''] * beam_size - self.beam_heap = [] - - @staticmethod - def 
is_substr(original, token_str): - if not original: - return True - - n = min(len(original), len(token_str)) - return original[:n] == token_str[:n] - - def add_prediction(self, prob: float, beam_idx: int, token_str: str, state): - if len(self.result_heap) == self.max_beam_size: - if self.result_heap[0][0] > prob - EPS_PROB: - return False - heappop(self.result_heap) - - state = self.state_updater.get_from_batch(state, beam_idx) - text = self.text[beam_idx] + token_str - heappush(self.result_heap, (prob, (text, state))) - - return True - - def add_prediction_before_token(self, prob: float, beam_idx: int, state): - if len(self.result_heap) == self.max_beam_size: - if self.result_heap[0][0] > prob - EPS_PROB: - return False - heappop(self.result_heap) - - state = self.state_updater.get_from_batch(state, beam_idx) - text = self.text[beam_idx] - heappush(self.result_heap, (prob, (text, state))) - - return True - - def add_beam(self, prob: float, beam_idx: int, token: int): - if self.result_heap and self.result_heap[0][0] > prob - EPS_PROB: - return False - - if prob < MIN_BEAM_PROB: - return False - - if len(self.beam_heap) == self.max_beam_size: - if self.beam_heap[0][0] > prob - EPS_PROB: - return False - heappop(self.beam_heap) - - heappush(self.beam_heap, (prob, (beam_idx, token))) - - return True - - def next_batch(self, prompt: torch.Tensor, state: Any, itos: List[str]): - if not self.beam_heap: - return None, None - - new_prompt = [] - new_state = [] - - texts = self.text - self.text = [] - self.probs = [] - - for prob, (b, token) in self.beam_heap: - token = prompt.new_tensor([token]) - if self.is_token_by_token: - new_prompt.append(token) - else: - new_prompt.append(torch.cat((prompt[1:, b], token))) - new_state.append(self.state_updater.get_from_batch(state, b)) - self.probs.append(prob) - self.text.append(texts[b] + itos[token]) - - new_prompt = torch.stack(new_prompt, dim=1) - new_state = self.state_updater.make_batch(new_state) - - self.beam_heap = [] - - return new_prompt, new_state - - def update(self, next_token, itos: List[str], state, old_state): - self.beam_heap = [] - - for b, text in enumerate(self.text): - text = self.text[b] - if len(text) >= len(self.rest): - check_rest = None - else: - check_rest = self.rest[len(text):] - - tokens = next_token[b] - sort_idx = torch.argsort(tokens) - - for i in reversed(range(len(tokens))): - token = sort_idx[i] - token_str = itos[token] - if not self.is_substr(check_rest, token_str): - continue - - if self.prediction_complete(text, token_str): - if not self.add_prediction_before_token(self.probs[b], b, old_state): - break - else: - break - # if not self.add_prediction(self.probs[b] * tokens[token].item(), b, token_str, state): - # break - if not self.add_beam(self.probs[b] * tokens[token].item(), b, token): - break - class Prediction(NamedTuple): prob: float @@ -199,13 +51,21 @@ def get_next_word(self, prompt: torch.Tensor, state: Any, rest: str, probs: List prediction_complete: PredictionComplete, max_beam_size: int) -> \ List[Prediction]: - beam = BeamSearch(prompt.shape[1], prediction_complete, max_beam_size, rest, self.state_updater, - probs, self.is_token_by_token) + beam = BeamSearchSimple(beam_size=prompt.shape[1], + prediction_complete=prediction_complete, + max_beam_size=max_beam_size, + rest=rest, + state_updater=self.state_updater, + probs=probs, + is_token_by_token=self.is_token_by_token, + itos=self.tokenizer.itos) for _ in range(10): - next_token, new_state = self._get_predictions(prompt, state) - beam.update(next_token, 
self.tokenizer.itos, new_state, state) - prompt, state = beam.next_batch(prompt, new_state, self.tokenizer.itos) + with monit.section('Predict', is_silent=True): + next_token, new_state = self._get_predictions(prompt, state) + with monit.section('Beam', is_silent=True): + beam.update(next_token, new_state, state) + prompt, state = beam.next_batch(prompt, new_state) if prompt is None: break diff --git a/python_autocomplete/evaluate/beam_search.py b/python_autocomplete/evaluate/beam_search.py new file mode 100644 index 0000000..0bc98fd --- /dev/null +++ b/python_autocomplete/evaluate/beam_search.py @@ -0,0 +1,190 @@ +from heapq import heappush, heappop +from typing import Any, List, Optional + +import torch +import torch.nn + +from python_autocomplete.dataset import ID_CHARS +from python_autocomplete.train import StateUpdater + +EPS_PROB = 1e-6 +MIN_BEAM_PROB = 1e-4 + + +class PredictionComplete: + def __call__(self, text, token_str: str): + raise NotImplementedError + + +class NextWordPredictionComplete(PredictionComplete): + def __init__(self, rest: str, min_length: int): + self.min_length = min_length + self.rest = rest + + def __call__(self, text, token_str: str): + if len(text) - len(self.rest) < self.min_length: + return False + + prev_is_id = text[-1] in ID_CHARS + last_is_id = token_str[-1] in ID_CHARS + + return prev_is_id != last_is_id + + +class NextWordNewLinePredictionComplete(PredictionComplete): + def __init__(self, rest: str, min_length: int): + self.min_length = min_length + self.rest = rest + + def __call__(self, text, token_str: str): + if len(text) - len(self.rest) < self.min_length: + return False + + if '\n' in token_str: + return True + + prev_is_id = text[-1] in ID_CHARS + last_is_id = token_str[-1] in ID_CHARS + + return prev_is_id != last_is_id + + +class BeamSearch: + def __init__(self): + pass + + def next_batch(self, prompt: torch.Tensor, state: Any): + raise NotImplementedError + + def update(self, next_token, state, old_state): + raise NotImplementedError + + +class BeamSearchSimple(BeamSearch): + def __init__(self, *, beam_size: int, prediction_complete: PredictionComplete, + max_beam_size: int, rest: str, + state_updater: 'StateUpdater', + probs: Optional[List[float]], + is_token_by_token: bool, + itos: List[str]): + super().__init__() + self.itos = itos + self.is_token_by_token = is_token_by_token + self.state_updater = state_updater + self.prediction_complete = prediction_complete + self.max_beam_size = max_beam_size + self.rest = rest + + if probs is None: + probs = [1 / beam_size] * beam_size + assert len(probs) == beam_size + self.probs = probs + + self.result_heap = [] + self.text = [''] * beam_size + self.beam_heap = [] + + @staticmethod + def is_substr(original, token_str): + if not original: + return True + + n = min(len(original), len(token_str)) + return original[:n] == token_str[:n] + + def add_prediction(self, prob: float, beam_idx: int, token_str: str, state): + if len(self.result_heap) == self.max_beam_size: + if self.result_heap[0][0] > prob - EPS_PROB: + return False + heappop(self.result_heap) + + state = self.state_updater.get_from_batch(state, beam_idx) + text = self.text[beam_idx] + token_str + heappush(self.result_heap, (prob, (text, state))) + + return True + + def add_prediction_before_token(self, prob: float, beam_idx: int, state): + if len(self.result_heap) == self.max_beam_size: + if self.result_heap[0][0] > prob - EPS_PROB: + return False + heappop(self.result_heap) + + state = self.state_updater.get_from_batch(state, beam_idx) + 
text = self.text[beam_idx] + heappush(self.result_heap, (prob, (text, state))) + + return True + + def add_beam(self, prob: float, beam_idx: int, token: int): + if self.result_heap and self.result_heap[0][0] > prob - EPS_PROB: + return False + + if prob < MIN_BEAM_PROB: + return False + + if len(self.beam_heap) == self.max_beam_size: + if self.beam_heap[0][0] > prob - EPS_PROB: + return False + heappop(self.beam_heap) + + heappush(self.beam_heap, (prob, (beam_idx, token))) + + return True + + def next_batch(self, prompt: torch.Tensor, state: Any): + if not self.beam_heap: + return None, None + + new_prompt = [] + new_state = [] + + texts = self.text + self.text = [] + self.probs = [] + + for prob, (b, token) in self.beam_heap: + token = prompt.new_tensor([token]) + if self.is_token_by_token: + new_prompt.append(token) + else: + new_prompt.append(torch.cat((prompt[1:, b], token))) + new_state.append(self.state_updater.get_from_batch(state, b)) + self.probs.append(prob) + self.text.append(texts[b] + self.itos[token]) + + new_prompt = torch.stack(new_prompt, dim=1) + new_state = self.state_updater.make_batch(new_state) + + self.beam_heap = [] + + return new_prompt, new_state + + def update(self, next_token, state, old_state): + self.beam_heap = [] + + for b, text in enumerate(self.text): + text = self.text[b] + if len(text) >= len(self.rest): + check_rest = None + else: + check_rest = self.rest[len(text):] + + tokens = next_token[b] + sort_idx = torch.argsort(tokens) + + for i in reversed(range(len(tokens))): + token = sort_idx[i] + token_str = self.itos[token] + if not self.is_substr(check_rest, token_str): + continue + + if self.prediction_complete(text, token_str): + if not self.add_prediction_before_token(self.probs[b], b, old_state): + break + else: + break + # if not self.add_prediction(self.probs[b] * tokens[token].item(), b, token_str, state): + # break + if not self.add_beam(self.probs[b] * tokens[token].item(), b, token): + break diff --git a/python_autocomplete/evaluate/beam_search_lengthy.py b/python_autocomplete/evaluate/beam_search_lengthy.py new file mode 100644 index 0000000..c589316 --- /dev/null +++ b/python_autocomplete/evaluate/beam_search_lengthy.py @@ -0,0 +1,151 @@ +from heapq import heappush, heappop +from typing import Any, List, Optional + +import torch +import torch.nn + +from python_autocomplete.evaluate.beam_search import PredictionComplete, BeamSearch +from python_autocomplete.train import StateUpdater + +EPS_SCORE = 1e-6 +MIN_BEAM_PROB = 1e-4 + + +class BeamSearchLengthy(BeamSearch): + """ + Use a score instead of probability to make longer predictions + """ + def __init__(self, *, beam_size: int, prediction_complete: PredictionComplete, + max_beam_size: int, rest: str, + state_updater: 'StateUpdater', + probs: Optional[List[float]], + is_token_by_token: bool, + itos: List[str]): + super().__init__() + self.itos = itos + self.is_token_by_token = is_token_by_token + self.state_updater = state_updater + self.prediction_complete = prediction_complete + self.max_beam_size = max_beam_size + self.rest = rest + + if probs is None: + probs = [1 / beam_size] * beam_size + assert len(probs) == beam_size + self.probs = probs + + self.result_heap = [] + self.text = [''] * beam_size + self.beam_heap = [] + + @staticmethod + def is_prefix(original, token_str): + """ + Whether `token_str` is a prefix of `original` + """ + if not original: + return True + + n = min(len(original), len(token_str)) + return original[:n] == token_str[:n] + + @staticmethod + def _get_score(text: 
str, prob: float): + return len(text) ** 0.2 * prob + + def _add_prediction_before_token(self, prob: float, beam_idx: int, state): + """ + Add a prediction before the last token to the results + """ + text = self.text[beam_idx] + score = self._get_score(text, prob) + + if len(self.result_heap) == self.max_beam_size: + if self.result_heap[0][0] > score - EPS_SCORE: + return + heappop(self.result_heap) + + state = self.state_updater.get_from_batch(state, beam_idx) + heappush(self.result_heap, (score, (text, state))) + + def _add_beam(self, prob: float, beam_idx: int, token: int): + """Add to the beam""" + text = self.text[beam_idx] + self.itos[token] + score = self._get_score(text, prob) + + # if self.result_heap and self.result_heap[0][0] > score - EPS_SCORE: + # return False + + if prob < MIN_BEAM_PROB: + return False + + if len(self.beam_heap) == self.max_beam_size: + if self.beam_heap[0][0] > score - EPS_SCORE: + return False + heappop(self.beam_heap) + + heappush(self.beam_heap, (score, (beam_idx, prob, token))) + + return True + + def next_batch(self, prompt: torch.Tensor, state: Any): + """Get the next batch (beam)""" + if not self.beam_heap: + return None, None + + new_prompt = [] + new_state = [] + + texts = self.text + self.text = [] + self.probs = [] + + for _, (b, prob, token) in self.beam_heap: + token = prompt.new_tensor([token]) + if self.is_token_by_token: + new_prompt.append(token) + else: + new_prompt.append(torch.cat((prompt[1:, b], token))) + new_state.append(self.state_updater.get_from_batch(state, b)) + self.probs.append(prob) + self.text.append(texts[b] + self.itos[token]) + + new_prompt = torch.stack(new_prompt, dim=1) + new_state = self.state_updater.make_batch(new_state) + + self.beam_heap = [] + + return new_prompt, new_state + + def update(self, next_token, state, old_state): + """Update beam search with the sampled data""" + self.beam_heap = [] + + for b, text in enumerate(self.text): + text = self.text[b] + if len(text) >= len(self.rest): + check_rest = None + else: + check_rest = self.rest[len(text):] + + tokens = next_token[b] + sort_idx = torch.argsort(tokens) + added_to_results = False + + for i in reversed(range(len(tokens))): + token = sort_idx[i] + token_str = self.itos[token] + if not self.is_prefix(check_rest, token_str): + continue + + if not added_to_results and self.prediction_complete(text, token_str): + added_to_results = True + self._add_prediction_before_token(self.probs[b], b, old_state) + + if check_rest and len(token_str) <= len(check_rest): + p = 1.0 + else: + p = tokens[token].item() + + if not self._add_beam(self.probs[b] * p, b, token): + break diff --git a/python_autocomplete/evaluate/eval_sample.py b/python_autocomplete/evaluate/eval_sample.py index e98da7e..d7d25d6 100644 --- a/python_autocomplete/evaluate/eval_sample.py +++ b/python_autocomplete/evaluate/eval_sample.py @@ -2,7 +2,8 @@ from labml import logger, lab, monit from labml.logger import Text, Style -from python_autocomplete.evaluate import NextWordPredictionComplete, Predictor +from python_autocomplete.evaluate import Predictor +from python_autocomplete.evaluate.beam_search import NextWordPredictionComplete from python_autocomplete.evaluate.factory import get_predictor diff --git a/python_autocomplete/evaluate/factory.py b/python_autocomplete/evaluate/factory.py index 259325b..eb7a921 100644 --- a/python_autocomplete/evaluate/factory.py +++ b/python_autocomplete/evaluate/factory.py @@ -1,4 +1,4 @@ -from labml import experiment +from labml import experiment, lab from 
labml.utils.pytorch import get_modules from python_autocomplete.evaluate import Predictor from python_autocomplete.train import Configs @@ -16,14 +16,15 @@ def load_experiment() -> Configs: # And for latest checkpoint # checkpoint = None - run_uuid = 'a6cff3706ec411ebadd9bf753b33bae6' # bpe - checkpoint = None - # run_uuid, checkpoint = experiment.load_bundle( - # lab.get_path() / 'saved_checkpoint.tar.gz', - # url='https://github.com/lab-ml/python_autocomplete/releases/download/0.0.4/transformer_checkpoint.tar.gz') + # run_uuid = 'a6cff3706ec411ebadd9bf753b33bae6' # bpe + # checkpoint = None + run_uuid, checkpoint = experiment.load_bundle( + lab.get_path() / 'saved_checkpoint.tar.gz', + url='https://github.com/lab-ml/python_autocomplete/releases/download/0.0.5/bundle.tar.gz') conf_dict = experiment.load_configs(run_uuid) conf_dict['text.is_load_data'] = False + conf_dict['device.cuda_device'] = 1 experiment.configs(conf, conf_dict) experiment.add_pytorch_models(get_modules(conf)) experiment.load(run_uuid, checkpoint) diff --git a/python_autocomplete/serve.py b/python_autocomplete/serve.py index 5889ff5..7084b85 100644 --- a/python_autocomplete/serve.py +++ b/python_autocomplete/serve.py @@ -5,7 +5,7 @@ from flask import Flask, request, jsonify from labml import monit -from python_autocomplete.evaluate import NextWordPredictionComplete +from python_autocomplete.evaluate.beam_search import NextWordPredictionComplete from python_autocomplete.evaluate.factory import get_predictor app = Flask('python_autocomplete') @@ -29,20 +29,21 @@ def autocomplete(): if acquired: stripped, prompt = predictor.rstrip(prefix) rest = prefix[len(stripped):] - prediction_complete = NextWordPredictionComplete(rest, 5) + prediction_complete = NextWordPredictionComplete(rest, 15) prompt = torch.tensor(prompt, dtype=torch.long).unsqueeze(-1) predictions = predictor.get_next_word(prompt, None, rest, [1.], prediction_complete, 5) - predictions.sort(key=lambda x: -x[0]) + predictions.sort(key=lambda x: -x.prob) results = [pred.text[len(rest):] for pred in predictions] + probs = [pred.prob for pred in predictions] lock.release() s.message = f'{json.dumps(prefix[-5:])} -> {json.dumps(results)}' - return jsonify({'success': True, 'prediction': results}) + return jsonify({'success': True, 'prediction': results, 'probs': probs}) else: monit.fail() return jsonify({'success': False}) if __name__ == '__main__': - app.run(host='0.0.0.0', port=5000, debug=True) + app.run(host='0.0.0.0', port=5000, debug=False) diff --git a/readme.md b/readme.md index 5bbd05f..fd21aaa 100644 --- a/readme.md +++ b/readme.md @@ -5,6 +5,12 @@ # Python Autocomplete +

+ +

+ +[The full length Python autocompletion Video](https://www.youtube.com/watch?v=ZFzxBPBUh0M) and a [Twitter thread describing how it works](https://twitter.com/labmlai/status/1367444214963838978) + This is a learning/demo project to show how deep learning can be used to auto complete Python code. You can experiment with LSTM and Transformer models. We also have built a simple VSCode extension to try out the trained models. diff --git a/setup.py b/setup.py index bc2ca82..be415ff 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setuptools.setup( name='labml_python_autocomplete', - version='0.0.5', + version='0.0.7', author="Varuna Jayasiri", author_email="vpjayasiri@gmail.com", description="A simple model that learns to predict Python source code", @@ -21,7 +21,7 @@ 'test', 'test.*')), install_requires=['labml>=0.4.103', 'labml_helpers>=0.4.75', - 'labml_nn>=0.4.88' + 'labml_nn>=0.4.88', 'torch', 'einops', 'numpy'], diff --git a/vscode_extension/editor-settings.json b/vscode_extension/editor-settings.json new file mode 100644 index 0000000..78d80a8 --- /dev/null +++ b/vscode_extension/editor-settings.json @@ -0,0 +1,11 @@ +{ + "editor.quickSuggestions": { + "other": true, + "comments": true, + "strings": true + }, + "editor.quickSuggestionsDelay": 1, + "editor.wordBasedSuggestions": false, + "editor.acceptSuggestionOnCommitCharacter": false, + "editor.autoClosingBrackets": "never" +} diff --git a/vscode_extension/package-lock.json b/vscode_extension/package-lock.json index 045a368..8449668 100644 --- a/vscode_extension/package-lock.json +++ b/vscode_extension/package-lock.json @@ -1,47 +1,8 @@ { "name": "python-autocomplete", "version": "0.0.1", - "lockfileVersion": 2, + "lockfileVersion": 1, "requires": true, - "packages": { - "": { - "name": "python-autocomplete", - "version": "0.0.1", - "devDependencies": { - "@types/node": "*", - "@types/vscode": "*", - "typescript": "*" - }, - "engines": { - "vscode": "^1.32.0" - } - }, - "node_modules/@types/node": { - "version": "12.12.35", - "resolved": "https://registry.npmjs.org/@types/node/-/node-12.12.35.tgz", - "integrity": "sha512-ASYsaKecA7TUsDrqIGPNk3JeEox0z/0XR/WsJJ8BIX/9+SkMSImQXKWfU/yBrSyc7ZSE/NPqLu36Nur0miCFfQ==", - "dev": true - }, - "node_modules/@types/vscode": { - "version": "1.33.0", - "resolved": "https://registry.npmjs.org/@types/vscode/-/vscode-1.33.0.tgz", - "integrity": "sha512-JSmGiValbrcG5g20jjCfKakLiuWyrcjVezj+SEAEZ4klXQktE5EtowuGlkLVqbkiBK4iY5wy/4yW8OjecuHnjQ==", - "dev": true - }, - "node_modules/typescript": { - "version": "4.0.2", - "resolved": "https://registry.npmjs.org/typescript/-/typescript-4.0.2.tgz", - "integrity": "sha512-e4ERvRV2wb+rRZ/IQeb3jm2VxBsirQLpQhdxplZ2MEzGvDkkMmPglecnNDfSUBivMjP93vRbngYYDQqQ/78bcQ==", - "dev": true, - "bin": { - "tsc": "bin/tsc", - "tsserver": "bin/tsserver" - }, - "engines": { - "node": ">=4.2.0" - } - } - }, "dependencies": { "@types/node": { "version": "12.12.35", diff --git a/vscode_extension/src/extension.ts b/vscode_extension/src/extension.ts index 4c0240f..e1f04bf 100644 --- a/vscode_extension/src/extension.ts +++ b/vscode_extension/src/extension.ts @@ -42,13 +42,57 @@ function getPrompt(document: vscode.TextDocument, position: vscode.Position) { return text } -function cleanupPredictions(predictions: string[]): string[] { +function removeNewLine(predictions: string[]): string[] { let res = [] for(let p of predictions) { let nl = p.indexOf('\n') if (nl !== -1) { p = p.substr(0, nl) } + res.push(p) + } + + return res +} + +function trimRight(predictions: string[]): string[] { 
+ let res = [] + + for(let p of predictions) { + p = p.trimRight() + res.push(p) + } + + return res +} + +function removeDuplicates(predictions: string[]): string[] { + let set = new Set() + + for(let p of predictions) { + if(p !== '') { + set.add(p) + } + } + + let res = [] + for(let p of set) { + res.push(p) + } + + return res +} + +function removeSuffix(predictions: string[], document: vscode.TextDocument, position: vscode.Position): string[] { + const line = document.lineAt(position).text + const text = line.substr(position.character) + let res = [] + + for(let p of predictions) { + let suffix = p.indexOf(text[0]) + if (suffix !== -1) { + p = p.substr(0, suffix) + } if(p !== '') { res.push(p) } @@ -97,6 +141,7 @@ export function activate(context: vscode.ExtensionContext) { async provideCompletionItems(document: vscode.TextDocument, position: vscode.Position, token: vscode.CancellationToken, context: vscode.CompletionContext) { const prompt = getPrompt(document, position) let response + const fetchTime = new Date().getTime() try { response = await fetch(prompt) @@ -110,8 +155,19 @@ export function activate(context: vscode.ExtensionContext) { } let predictions: string[] = response.prediction + let probs: number[] = response.probs + for(let i = 0; i < probs.length - 1; ++i) { + if(probs[i] > probs[i + 1] * 4) { + predictions = predictions.slice(0, i + 1) + break + } + } + const nl = hasNewLine(predictions) - predictions = cleanupPredictions(predictions) + predictions = removeNewLine(predictions) + predictions = removeSuffix(predictions, document, position) + predictions = trimRight(predictions) + predictions = removeDuplicates(predictions) if (predictions.length === 0) { // If at end of a line just predict new line, to avoid annoying default vscode predictions @@ -120,8 +176,9 @@ export function activate(context: vscode.ExtensionContext) { simpleCompletion.kind = vscode.CompletionItemKind.Text simpleCompletion.command = { command: 'editor.action.triggerSuggest', title: 'Re-trigger completions...' } return [simpleCompletion] + } else { + return [] } - return [] } // Add any word prefix from text (because thats how vscode works) @@ -132,6 +189,7 @@ export function activate(context: vscode.ExtensionContext) { predictions = addPrefix(prefix, predictions) } + console.log(`Featching ${new Date().getTime() - fetchTime}ms`) return getCompletions(predictions, nl) } })