Boris/embedding examples #49

Merged
merged 4 commits into from Dec 2, 2021
130 changes: 130 additions & 0 deletions examples/embeddings/Classification.ipynb

Large diffs are not rendered by default.

262 changes: 262 additions & 0 deletions examples/embeddings/Clustering.ipynb

Large diffs are not rendered by default.

396 changes: 396 additions & 0 deletions examples/embeddings/Code_search.ipynb
@@ -0,0 +1,396 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Code search\n",
"\n",
"We index our own openai-python code repository, and show how it can be searched. We implement a simple version of file parsing and extracting of functions from python files. The dataset is created in the [Obtain_dataset Notebook](Obtain_dataset.ipynb)."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total number of py files: 40\n",
"Total number of functions extracted: 64\n"
]
}
],
"source": [
"import os\n",
"from glob import glob\n",
"import pandas as pd\n",
"\n",
"def get_function_name(code):\n",
" \"\"\"\n",
" Extract function name from a line beginning with \"def \"\n",
" \"\"\"\n",
" assert code.startswith(\"def \")\n",
" return code[len(\"def \"): code.index(\"(\")]\n",
"\n",
"def get_until_no_space(all_lines, i) -> str:\n",
" \"\"\"\n",
" Get all lines until a line outside the function definition is found.\n",
" \"\"\"\n",
" ret = [all_lines[i]]\n",
" for j in range(i + 1, i + 10000):\n",
" if j < len(all_lines):\n",
" if len(all_lines[j]) == 0 or all_lines[j][0] in [\" \", \"\\t\", \")\"]:\n",
" ret.append(all_lines[j])\n",
" else:\n",
" break\n",
" return \"\\n\".join(ret)\n",
"\n",
"def get_functions(filepath):\n",
" \"\"\"\n",
" Get all functions in a Python file.\n",
" \"\"\"\n",
" whole_code = open(filepath).read().replace(\"\\r\", \"\\n\")\n",
" all_lines = whole_code.split(\"\\n\")\n",
" for i, l in enumerate(all_lines):\n",
" if l.startswith(\"def \"):\n",
" code = get_until_no_space(all_lines, i)\n",
" function_name = get_function_name(code)\n",
" yield {\"code\": code, \"function_name\": function_name, \"filepath\": filepath}\n",
"\n",
"\n",
"# get user root directory\n",
"root_dir = os.path.expanduser(\"~\")\n",
"\n",
"# path to code repository directory\n",
"code_root = root_dir + \"/openai-python\"\n",
"code_files = [y for x in os.walk(code_root) for y in glob(os.path.join(x[0], '*.py'))]\n",
"print(\"Total number of py files:\", len(code_files))\n",
"all_funcs = []\n",
"for code_file in code_files:\n",
" funcs = list(get_functions(code_file))\n",
" for func in funcs:\n",
" all_funcs.append(func)\n",
"\n",
"print(\"Total number of functions extracted:\", len(all_funcs))\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For code search models we use babbage-code-search-code to obtain embeddings for code snippets, and code-search-text to embed natural language queries."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>code</th>\n",
" <th>function_name</th>\n",
" <th>filepath</th>\n",
" <th>code_embedding</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>def semantic_search(engine, query, documents):...</td>\n",
" <td>semantic_search</td>\n",
" <td>/examples/semanticsearch/semanticsearch.py</td>\n",
" <td>[-0.038976121693849564, -0.0031428150832653046...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>def main():\\n parser = argparse.ArgumentPar...</td>\n",
" <td>main</td>\n",
" <td>/examples/semanticsearch/semanticsearch.py</td>\n",
" <td>[-0.024289356544613838, -0.017748363316059113,...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>def get_candidates(\\n prompt: str,\\n sto...</td>\n",
" <td>get_candidates</td>\n",
" <td>/examples/codex/backtranslation.py</td>\n",
" <td>[-0.04161201789975166, -0.0169310811907053, 0....</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>def rindex(lst: List, value: str) -&gt; int:\\n ...</td>\n",
" <td>rindex</td>\n",
" <td>/examples/codex/backtranslation.py</td>\n",
" <td>[-0.027255680412054062, -0.007931121625006199,...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>def eval_candidate(\\n candidate_answer: str...</td>\n",
" <td>eval_candidate</td>\n",
" <td>/examples/codex/backtranslation.py</td>\n",
" <td>[-0.00999179296195507, -0.01640152558684349, 0...</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" code function_name \\\n",
"0 def semantic_search(engine, query, documents):... semantic_search \n",
"1 def main():\\n parser = argparse.ArgumentPar... main \n",
"2 def get_candidates(\\n prompt: str,\\n sto... get_candidates \n",
"3 def rindex(lst: List, value: str) -> int:\\n ... rindex \n",
"4 def eval_candidate(\\n candidate_answer: str... eval_candidate \n",
"\n",
" filepath \\\n",
"0 /examples/semanticsearch/semanticsearch.py \n",
"1 /examples/semanticsearch/semanticsearch.py \n",
"2 /examples/codex/backtranslation.py \n",
"3 /examples/codex/backtranslation.py \n",
"4 /examples/codex/backtranslation.py \n",
"\n",
" code_embedding \n",
"0 [-0.038976121693849564, -0.0031428150832653046... \n",
"1 [-0.024289356544613838, -0.017748363316059113,... \n",
"2 [-0.04161201789975166, -0.0169310811907053, 0.... \n",
"3 [-0.027255680412054062, -0.007931121625006199,... \n",
"4 [-0.00999179296195507, -0.01640152558684349, 0... "
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from utils import get_embedding\n",
"\n",
"df = pd.DataFrame(all_funcs)\n",
"df['code_embedding'] = df['code'].apply(lambda x: get_embedding(x, engine='babbage-code-search-code'))\n",
"df['filepath'] = df['filepath'].apply(lambda x: x.replace(code_root, \"\"))\n",
"df.to_csv(\"output/code_search_openai-python.csv\", index=False)\n",
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/openai/tests/test_endpoints.py:test_completions_multiple_prompts score=0.681\n",
"def test_completions_multiple_prompts():\n",
" result = openai.Completion.create(\n",
" prompt=[\"This was a test\", \"This was another test\"], n=5, engine=\"ada\"\n",
" )\n",
" assert len(result.choices) == 10\n",
"\n",
"----------------------------------------------------------------------\n",
"/openai/tests/test_endpoints.py:test_completions score=0.675\n",
"def test_completions():\n",
" result = openai.Completion.create(prompt=\"This was a test\", n=5, engine=\"ada\")\n",
" assert len(result.choices) == 5\n",
"\n",
"\n",
"----------------------------------------------------------------------\n",
"/openai/tests/test_api_requestor.py:test_requestor_sets_request_id score=0.635\n",
"def test_requestor_sets_request_id(mocker: MockerFixture) -> None:\n",
" # Fake out 'requests' and confirm that the X-Request-Id header is set.\n",
"\n",
" got_headers = {}\n",
"\n",
" def fake_request(self, *args, **kwargs):\n",
" nonlocal got_headers\n",
"----------------------------------------------------------------------\n"
]
}
],
"source": [
"from utils import cosine_similarity\n",
"\n",
"def search_functions(df, code_query, n=3, pprint=True, n_lines=7):\n",
" embedding = get_embedding(code_query, engine='babbage-code-search-text')\n",
" df['similarities'] = df.code_embedding.apply(lambda x: cosine_similarity(x, embedding))\n",
"\n",
" res = df.sort_values('similarities', ascending=False).head(n)\n",
" if pprint:\n",
" for r in res.iterrows():\n",
" print(r[1].filepath+\":\"+r[1].function_name + \" score=\" + str(round(r[1].similarities, 3)))\n",
" print(\"\\n\".join(r[1].code.split(\"\\n\")[:n_lines]))\n",
" print('-'*70)\n",
" return res\n",
"res = search_functions(df, 'Completions API tests', n=3)\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/openai/validators.py:format_inferrer_validator score=0.655\n",
"def format_inferrer_validator(df):\n",
" \"\"\"\n",
" This validator will infer the likely fine-tuning format of the data, and display it to the user if it is classification.\n",
" It will also suggest to use ada, --no_packing and explain train/validation split benefits.\n",
" \"\"\"\n",
" ft_type = infer_task_type(df)\n",
" immediate_msg = None\n",
"----------------------------------------------------------------------\n",
"/openai/validators.py:long_examples_validator score=0.649\n",
"def long_examples_validator(df):\n",
" \"\"\"\n",
" This validator will suggest to the user to remove examples that are too long.\n",
" \"\"\"\n",
" immediate_msg = None\n",
" optional_msg = None\n",
" optional_fn = None\n",
"----------------------------------------------------------------------\n",
"/openai/validators.py:non_empty_completion_validator score=0.646\n",
"def non_empty_completion_validator(df):\n",
" \"\"\"\n",
" This validator will ensure that no completion is empty.\n",
" \"\"\"\n",
" necessary_msg = None\n",
" necessary_fn = None\n",
" immediate_msg = None\n",
"----------------------------------------------------------------------\n"
]
}
],
"source": [
"res = search_functions(df, 'fine-tuning input data validation logic', n=3)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/openai/validators.py:common_completion_suffix_validator score=0.665\n",
"def common_completion_suffix_validator(df):\n",
" \"\"\"\n",
" This validator will suggest to add a common suffix to the completion if one doesn't already exist in case of classification or conditional generation.\n",
" \"\"\"\n",
" error_msg = None\n",
" immediate_msg = None\n",
" optional_msg = None\n",
" optional_fn = None\n",
"\n",
" ft_type = infer_task_type(df)\n",
"----------------------------------------------------------------------\n",
"/openai/validators.py:get_outfnames score=0.66\n",
"def get_outfnames(fname, split):\n",
" suffixes = [\"_train\", \"_valid\"] if split else [\"\"]\n",
" i = 0\n",
" while True:\n",
" index_suffix = f\" ({i})\" if i > 0 else \"\"\n",
" candidate_fnames = [\n",
" fname.split(\".\")[0] + \"_prepared\" + suffix + index_suffix + \".jsonl\"\n",
" for suffix in suffixes\n",
" ]\n",
" if not any(os.path.isfile(f) for f in candidate_fnames):\n",
"----------------------------------------------------------------------\n"
]
}
],
"source": [
"res = search_functions(df, 'find common suffix', n=2, n_lines=10)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/openai/cli.py:tools_register score=0.651\n",
"def tools_register(parser):\n",
" subparsers = parser.add_subparsers(\n",
" title=\"Tools\", help=\"Convenience client side tools\"\n",
" )\n",
"\n",
" def help(args):\n",
" parser.print_help()\n",
"\n",
" parser.set_defaults(func=help)\n",
"\n",
" sub = subparsers.add_parser(\"fine_tunes.prepare_data\")\n",
" sub.add_argument(\n",
" \"-f\",\n",
" \"--file\",\n",
" required=True,\n",
" help=\"JSONL, JSON, CSV, TSV, TXT or XLSX file containing prompt-completion examples to be analyzed.\"\n",
" \"This should be the local file path.\",\n",
" )\n",
" sub.add_argument(\n",
" \"-q\",\n",
"----------------------------------------------------------------------\n"
]
}
],
"source": [
"res = search_functions(df, 'Command line interface for fine-tuning', n=1, n_lines=20)"
]
}
],
"metadata": {
"interpreter": {
"hash": "be4b5d5b73a21c599de40d6deb1129796d12dc1cc33a738f7bac13269cfcafe8"
},
"kernelspec": {
"display_name": "Python 3.7.3 64-bit ('base': conda)",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
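
Note for readers following along: the notebook imports get_embedding and cosine_similarity from a local utils module that is not included in this diff. Below is a minimal sketch of what such helpers typically look like. It assumes the pre-1.0 openai Python client and NumPy; the names and signatures are inferred from how the notebook calls them, not taken from the repository's actual implementation.

import numpy as np
import openai

def get_embedding(text, engine):
    # Newlines can hurt embedding quality, so collapse them before the API call.
    text = text.replace("\n", " ")
    # The pre-1.0 client returns a dict-like object with a "data" list of embeddings.
    return openai.Embedding.create(input=[text], engine=engine)["data"][0]["embedding"]

def cosine_similarity(a, b):
    # Cosine of the angle between the two embedding vectors.
    a, b = np.array(a), np.array(b)
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))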
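
The embeddings are also persisted to output/code_search_openai-python.csv. Because pandas serializes the embedding lists as strings when writing CSV, a later session would need to parse them back into lists before calling search_functions again. A small sketch, assuming the column name written by the notebook:

import pandas as pd
from ast import literal_eval

# Reload the CSV written above; the embedding column comes back as strings,
# so convert each entry into a Python list before computing similarities.
df = pd.read_csv("output/code_search_openai-python.csv")
df["code_embedding"] = df["code_embedding"].apply(literal_eval)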