diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml
index 91bf584..fd2b500 100644
--- a/.github/workflows/build_and_test.yml
+++ b/.github/workflows/build_and_test.yml
@@ -8,7 +8,7 @@ jobs:
     runs-on: ubuntu-20.04
     strategy:
       matrix:
-        python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
+        python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12"]
 
     steps:
     - uses: actions/checkout@v2
diff --git a/README.rst b/README.rst
index c9c09f8..2d03543 100644
--- a/README.rst
+++ b/README.rst
@@ -61,7 +61,7 @@ To use the project:
     # as well as neuron activations and input saliency values.
 
     # To view the input saliency
-    output.saliency()
+    output.primary_attributions()
 
 This does the following:
 
@@ -70,17 +70,17 @@ This does the following:
 3. The model returns an ecco ``OutputSeq`` object. This object holds the output sequence, but also a lot of data generated by the generation run, including the input sequence and input saliency values. If we set ``activations=True`` in ``from_pretrained()``, then this would also contain neuron activation values.
 4. ``output`` can now produce various interactive explorables. Examples include:
 
-- ``output.saliency()`` to generate input saliency explorable [`Input Saliency Colab Notebook `_]
+- ``output.primary_attributions()`` to generate input saliency explorable [`Input Saliency Colab Notebook `_]
 - ``output.run_nmf()`` to explore non-negative matrix factorization of neuron activations [`Neuron Activation Colab Notebook `_]
 
 .. code-block:: python
 
     # To view the input saliency explorable
-    output.saliency()
+    output.primary_attributions()
 
     # to view input saliency with more details (a bar and % value for each token)
-    output.saliency(style="detailed")
+    output.primary_attributions(style="detailed")
 
     # output.activations contains the neuron activation values. it has the shape: (layer, neuron, token position)
 
diff --git a/notebooks/Language_Models_and_Ecco_PyData_Khobar.ipynb b/notebooks/Language_Models_and_Ecco_PyData_Khobar.ipynb
index c761441..5a402c9 100644
--- a/notebooks/Language_Models_and_Ecco_PyData_Khobar.ipynb
+++ b/notebooks/Language_Models_and_Ecco_PyData_Khobar.ipynb
@@ -87,13 +87,6 @@
         "id": "Ixn-FHcdYL_M",
         "outputId": "21f706c7-2229-4c4d-c616-5c75a3e269c9"
       },
-      "source": [
-        "text = \" it was a matter of\"\r\n",
-        "\r\n",
-        "# Generate one token\r\n",
-        "output_1 = lm.generate(text, generate=1, do_sample=False)"
-      ],
-      "execution_count": 6,
       "outputs": [
         {
           "output_type": "display_data",
@@ -225,6 +218,12 @@
             "tags": []
           }
         }
+      ],
+      "source": [
+        "text = \" it was a matter of\"\n",
+        "\n",
+        "# Generate one token\n",
+        "output_1 = lm.generate(text, generate=1, do_sample=False, output_hidden_states=True)"
       ]
     },
     {
@@ -237,12 +236,6 @@
         "id": "JzA265wyYL4K",
         "outputId": "54f3826b-3e9c-438e-db7c-9a635c1cc4f6"
      },
-      "source": [
-        "# Show the top 10 candidate output tokens for position #5. \r\n",
-        "# Layer 5 is the last layer in the model.\r\n",
-        "output_1.layer_predictions(position=5, layer=5, topk=10)"
-      ],
-      "execution_count": 7,
       "outputs": [
         {
           "output_type": "stream",
@@ -363,6 +356,11 @@
             "tags": []
           }
         }
+      ],
+      "source": [
+        "# Show the top 10 candidate output tokens for position #5. \n",
\n", + "# Layer 5 is the last layer in the model.\n", + "output_1.layer_predictions(position=5, layer=5, topk=10)" ] }, { @@ -415,11 +413,6 @@ "id": "RdC5MO1yvTvC", "outputId": "86193189-42d6-4d77-a28e-e8f665c2c83e" }, - "source": [ - "# Compare the rankings of \"Principle\" and \"Principal\" across layers\n", - "output_1.rankings_watch(watch=[7989, 10033], position=5)" - ], - "execution_count": null, "outputs": [ { "output_type": "display_data", @@ -433,6 +426,10 @@ "tags": [] } } + ], + "source": [ + "# Compare the rankings of \"Principle\" and \"Principal\" across layers\n", + "output_1.rankings_watch(watch=[7989, 10033], position=5)" ] }, { @@ -457,11 +454,6 @@ }, "outputId": "d5ac06ef-8157-4b89-debe-6d3400b5f848" }, - "source": [ - "text = \" Heathrow airport is located in\"\n", - "output_2 = lm.generate(text, generate=5, do_sample=False)" - ], - "execution_count": 6, "outputs": [ { "output_type": "display_data", @@ -681,6 +673,10 @@ "tags": [] } } + ], + "source": [ + "text = \" Heathrow airport is located in\"\n", + "output_2 = lm.generate(text, generate=5, do_sample=False)" ] }, { @@ -702,11 +698,6 @@ "id": "2E-8avmf8OWc", "outputId": "df611a7f-5ad4-4f8e-8bcf-8df553245d5a" }, - "source": [ - "text = \" Heathrow airport is located in the city of\"\n", - "output_2 = lm.generate(text, generate=1, do_sample=False)" - ], - "execution_count": 7, "outputs": [ { "output_type": "display_data", @@ -838,6 +829,10 @@ "tags": [] } } + ], + "source": [ + "text = \" Heathrow airport is located in the city of\"\n", + "output_2 = lm.generate(text, generate=1, do_sample=False, output_hidden_states=True)" ] }, { @@ -850,11 +845,6 @@ "id": "Ws7tKJ_K8Wcg", "outputId": "f045a616-d379-4eff-ab35-818b5bc7fd57" }, - "source": [ - "# What other tokens were possible to output in place of \"London\"?\n", - "output_2.layer_predictions(position=9, layer=5, topk=30)" - ], - "execution_count": null, "outputs": [ { "output_type": "stream", @@ -975,6 +965,10 @@ "tags": [] } } + ], + "source": [ + "# What other tokens were possible to output in place of \"London\"?\n", + "output_2.layer_predictions(position=9, layer=5, topk=30)" ] }, { @@ -987,12 +981,6 @@ "id": "HrkMe1rI8kCF", "outputId": "60d9d4f8-7aef-4fa7-edee-d42aa5994db9" }, - "source": [ - "# Now that the model has selcted the tokens \"London . \\n\"\n", - "# How did each layer rank these tokens during processing?\n", - "output_2.rankings()" - ], - "execution_count": null, "outputs": [ { "output_type": "display_data", @@ -1006,6 +994,11 @@ "tags": [] } } + ], + "source": [ + "# Now that the model has selcted the tokens \"London . \\n\"\n", + "# How did each layer rank these tokens during processing?\n", + "output_2.rankings()" ] }, { @@ -1046,12 +1039,6 @@ "id": "X72IPVlJ8_8r", "outputId": "dba57cae-c14a-4270-bcbe-1fa06bda06a9" }, - "source": [ - "text= \"The countries of the European Union are:\\n1. Austria\\n2. Belgium\\n3. Bulgaria\\n4.\"\n", - "\n", - "output_3 = lm.generate(text, generate=20, do_sample=True)" - ], - "execution_count": 8, "outputs": [ { "output_type": "display_data", @@ -1601,6 +1588,11 @@ "tags": [] } } + ], + "source": [ + "text= \"The countries of the European Union are:\\n1. Austria\\n2. Belgium\\n3. 
+        "\n",
+        "output_3 = lm.generate(text, generate=20, do_sample=True, output_hidden_states=True, attribution=[\"grad_x_input\"])"
       ]
     },
     {
@@ -1639,6 +1631,9 @@
             "tags": []
           }
         }
+      ],
+      "source": [
+        "output_3.rankings()"
       ]
     },
     {
@@ -1660,10 +1655,6 @@
         "id": "D-nKpXW5FPlY",
         "outputId": "a476e098-5e50-4f51-9362-990ebb52ade9"
       },
-      "source": [
-        "output_3.saliency()"
-      ],
-      "execution_count": null,
       "outputs": [
         {
           "output_type": "display_data",
@@ -1780,6 +1771,9 @@
             "tags": []
           }
         }
+      ],
+      "source": [
+        "output_3.primary_attributions()"
       ]
     },
     {
@@ -1802,10 +1796,6 @@
         },
         "outputId": "e7208416-9b90-4573-fa4f-185fec23e896"
       },
-      "source": [
-        "output_3.saliency(style=\"detailed\")"
-      ],
-      "execution_count": null,
       "outputs": [
         {
           "output_type": "display_data",
@@ -1914,6 +1904,9 @@
             "tags": []
          }
        }
+      ],
+      "source": [
+        "output_3.primary_attributions(style=\"detailed\")"
       ]
     }
   ]
diff --git a/requirements.txt b/requirements.txt
index 3517f24..91e7907 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,11 +1,11 @@
 matplotlib>=3.3
 numpy>=1.19
 ipython>=7.16
-scikit-learn>=0.24.2,<2
-seaborn>=0.11
-transformers~=4.6
-pytest>=6.1.2
-setuptools>=49.6.0
-torch>=1.9.0,<3
-PyYAML>=6.0
-captum~=0.4.1
+scikit-learn>=0.25
+seaborn>=0.13
+transformers<4.47
+pytest>=7
+setuptools>=50
+torch>=2
+PyYAML>=6
+captum>=0.4
diff --git a/setup.py b/setup.py
index d332e51..f67ddf7 100644
--- a/setup.py
+++ b/setup.py
@@ -1,26 +1,14 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
-from __future__ import absolute_import
-from __future__ import print_function
-
-import io
 import re
-from glob import glob
-from os.path import basename
-from os.path import dirname
-from os.path import join
-from os.path import splitext
+from pathlib import Path
 
 from setuptools import find_packages
 from setuptools import setup
 
 
-def read(*names, **kwargs):
-    with io.open(
-        join(dirname(__file__), *names),
-        encoding=kwargs.get('encoding', 'utf8')
-    ) as fh:
-        return fh.read()
+def read(file):
+    return Path(file).read_text(encoding='utf-8')
 
 
 setup(
@@ -37,23 +25,20 @@ def read(*names, **kwargs):
     url='https://github.com/jalammar/ecco',
     packages=find_packages('src'),
     package_dir={'': 'src'},
-    py_modules=[splitext(basename(path))[0] for path in glob('src/*.py')],
+    py_modules=[p.stem for p in Path('src').glob('*.py')],
     include_package_data=True,
     zip_safe=False,
     classifiers=[
-        # complete classifier list: http://pypi.python.org/pypi?%3Aaction=list_classifiers
         'Development Status :: 5 - Production/Stable',
         'Intended Audience :: Developers',
         'License :: OSI Approved :: BSD License',
-        'Operating System :: Unix',
-        'Operating System :: POSIX',
-        'Operating System :: Microsoft :: Windows',
-        'Programming Language :: Python',
+        'Operating System :: OS Independent',
         'Programming Language :: Python :: 3.7',
         'Programming Language :: Python :: 3.8',
         'Programming Language :: Python :: 3.9',
         'Programming Language :: Python :: 3.10',
         'Programming Language :: Python :: 3.11',
+        'Programming Language :: Python :: 3.12',
         'Programming Language :: Python :: Implementation :: CPython',
         'Programming Language :: Python :: Implementation :: PyPy',
         'Topic :: Utilities',
@@ -65,21 +50,18 @@ def read(*names, **kwargs):
     keywords=[
         'Natural Language Processing', 'Explainable AI', 'keyword3',
     ],
-    python_requires='!=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*',
+    python_requires='>=3.7',
     install_requires=[
-        "transformers ~= 4.2",
-        "seaborn >= 0.11",
-        "scikit-learn>=0.23,<2",
-        "PyYAML>=6.0",
-        "captum ~= 0.4"
+        "transformers<4.47",
"seaborn>=0.13", + "scikit-learn>=0.25", + "PyYAML>=6", + "captum>=0.4" ], extras_require={ "dev": [ - "pytest>=6.1", + "pytest>=7", ], - # eg: - # 'rst': ['docutils>=0.11'], - # ':python_version=="2.6"': ['argparse'], }, entry_points={ 'console_scripts': [