diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index d65867642887..4dd5ab91477a 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -57,7 +57,7 @@ jobs: run: | pip install -r requirements.txt --progress-bar off --upgrade if [ "${{ matrix.nnx_enabled }}" == "true" ]; then - pip install --upgrade git+https://github.com/google/flax.git + pip install --upgrade flax>=0.11.1 fi pip uninstall -y keras keras-nightly pip install -e "." --progress-bar off --upgrade @@ -147,4 +147,4 @@ jobs: pip uninstall -y keras keras-nightly pip install -e "." --progress-bar off --upgrade - name: Run pre-commit - run: pre-commit run --all-files --hook-stage manual + run: pre-commit run --all-files --hook-stage manual \ No newline at end of file diff --git a/.github/workflows/tpu-tests-jax.yml b/.github/workflows/tpu-tests-jax.yml new file mode 100644 index 000000000000..543379f9030c --- /dev/null +++ b/.github/workflows/tpu-tests-jax.yml @@ -0,0 +1,104 @@ +# name: TPU Tests + +# on: +# push: +# branches: [ master ] +# pull_request: +# release: +# types: [created] + +# # Only basic permissions are needed now. +# permissions: +# contents: read + +# jobs: +# test-in-container: +# name: Test in Custom Container +# runs-on: linux-x86-ct6e-44-1tpu + +# # With the correct IAM policies applied to the runner's underlying service accounts, +# # the runner can now pull this private image directly without any in-workflow auth. +# container: +# image: us-central1-docker.pkg.dev/gtech-rmi-dev/keras-docker-images/keras-jax-tpu-amd64:latest +# # Options are still needed for the container to access the host's TPU hardware. +# options: --privileged --network host + +# steps: +# - name: Checkout Repository +# uses: actions/checkout@v4 +# # This makes your code available inside the container's workspace. + +# - name: Run Verification and Tests +# run: | +# echo "Successfully running inside the private container from GAR!" +# echo "Verifying JAX installation..." +# python3 -c "import jax; print(f'JAX backend: {jax.default_backend()}'); print(f'JAX devices: {jax.devices()}')" +# pip install grain +# pytest keras --ignore keras/src/applications \ +# --ignore keras/src/layers/merging/merging_test.py \ +# --ignore keras/src/trainers/data_adapters/py_dataset_adapter_test.py \ +# --ignore keras/src/backend/jax/distribution_lib_test.py \ +# --ignore keras/src/distribution/distribution_lib_test.py \ +# --cov=keras \ +# --cov-config=pyproject.toml + + + +name: Keras Tests on TPU Runner using JAX Backend + +on: + push: + branches: [ master ] + pull_request: + release: + types: [created] + +# Only basic permissions are needed now. +permissions: + contents: read + +jobs: + test-in-container: + name: Run Keras tests on TPU runner using JAX Backend + runs-on: linux-x86-ct6e-44-1tpu + + container: + # The container image is now set to python:3.10-slim + image: python:3.10-slim + # Options are still needed for the container to access the host's TPU hardware. + options: --privileged --network host + + steps: + - name: Checkout Repository + uses: actions/checkout@v4 + + - name: Install System Dependencies + run: | + apt-get update && apt-get install -y --no-install-recommends \ + git \ + sudo \ + && rm -rf /var/lib/apt/lists/* + + - name: Install Dependencies + run: | + pip install --no-cache-dir -U pip setuptools && \ + pip install --no-cache-dir -U psutil && \ + pip install --no-cache-dir -r requirements-jax-tpu.txt && \ + pip uninstall -y keras keras-nightly + + - name: Set Keras Backend + run: echo "KERAS_BACKEND=jax" >> $GITHUB_ENV + + - name: Run Verification and Tests + run: | + echo "Successfully running inside the public python container!" + echo "Verifying JAX installation..." + python3 -c "import jax; print(f'JAX backend: {jax.default_backend()}'); print(f'JAX devices: {jax.devices()[0].device_kind}')" + + pytest keras --ignore keras/src/applications \ + --ignore keras/src/layers/merging/merging_test.py \ + --ignore keras/src/trainers/data_adapters/py_dataset_adapter_test.py \ + --ignore keras/src/backend/jax/distribution_lib_test.py \ + --ignore keras/src/distribution/distribution_lib_test.py \ + --cov=keras \ + --cov-config=pyproject.toml diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 000000000000..324f3798742f --- /dev/null +++ b/Dockerfile @@ -0,0 +1,28 @@ +FROM --platform=linux/amd64 python:3.10-slim + +ENV KERAS_HOME=/github/workspace/.github/workflows/config/jax \ + KERAS_BACKEND=jax + +RUN apt-get update && apt-get install -y --no-install-recommends \ + git \ + sudo \ + && rm -rf /var/lib/apt/lists/* + +# Copy the entire codebase into the container +COPY . /github/workspace +WORKDIR /github/workspace + +# Create and activate venv, install pip/setuptools/psutil, then run tests +# RUN cd ./keras/src/github/keras && \ +RUN pip install --no-cache-dir -U pip setuptools && \ + pip install --no-cache-dir -U psutil && \ + pip install --no-cache-dir -r requirements-jax-tpu.txt && \ + pip uninstall -y keras keras-nightly + # python3 -c 'import jax;print(jax.__version__);print(jax.default_backend())' && \ + # python3 -c 'import jax;assert jax.default_backend().lower() == "tpu"' && \ + # pytest keras --ignore keras/src/applications \ + # --ignore keras/src/layers/merging/merging_test.py \ + # --cov=keras \ + # --cov-config=pyproject.toml + +CMD ["/bin/bash"] diff --git a/conftest.py b/conftest.py index 0ade560a1bdf..cdf5af99569b 100644 --- a/conftest.py +++ b/conftest.py @@ -58,4 +58,4 @@ def pytest_collection_modifyitems(config, items): def skip_if_backend(given_backend, reason): - return pytest.mark.skipif(backend() == given_backend, reason=reason) + return pytest.mark.skipif(backend() == given_backend, reason=reason) \ No newline at end of file diff --git a/keras/src/backend/common/dtypes_test.py b/keras/src/backend/common/dtypes_test.py index 7750dcecdd11..97cedf361900 100644 --- a/keras/src/backend/common/dtypes_test.py +++ b/keras/src/backend/common/dtypes_test.py @@ -245,4 +245,4 @@ def test_invalid_float8_dtype(self): with self.assertRaisesRegex( ValueError, "There is no implicit conversions from float8 dtypes" ): - dtypes.result_type("float8_e5m2", "bfloat16") + dtypes.result_type("float8_e5m2", "bfloat16") \ No newline at end of file diff --git a/requirements-common-jax-tpu.txt b/requirements-common-jax-tpu.txt new file mode 100644 index 000000000000..cd5b52def930 --- /dev/null +++ b/requirements-common-jax-tpu.txt @@ -0,0 +1,31 @@ +pre-commit +#namex>=0.0.8 +ruff +pytest +#numpy +scipy +scikit-learn +pillow +pandas +#absl-py +#requests +#h5py +#ml-dtypes +#protobuf +tensorboard-plugin-profile +#rich +build +#optree +pytest-cov +#packaging +# for tree_test.py +dm_tree +coverage +# for onnx_test.py +onnxruntime +# TODO(https://github.com/keras-team/keras/issues/21390) +# > 0.3.1 breaks LSTM model export in torch backend. +onnxscript<=0.3.1 +openvino +# for grain_dataset_adapter_test.py +grain diff --git a/requirements-common.txt b/requirements-common.txt index d1f3616ab07a..4e27a80d69ea 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -20,12 +20,8 @@ pytest-cov packaging # for tree_test.py dm_tree -coverage +coverage!=7.6.5 # 7.6.5 breaks CI # for onnx_test.py onnxruntime -# TODO(https://github.com/keras-team/keras/issues/21390) -# > 0.3.1 breaks LSTM model export in torch backend. -onnxscript<=0.3.1 openvino -# for grain_dataset_adapter_test.py grain diff --git a/requirements-jax-tpu.txt b/requirements-jax-tpu.txt new file mode 100644 index 000000000000..c572fe82fe83 --- /dev/null +++ b/requirements-jax-tpu.txt @@ -0,0 +1,14 @@ +# Tensorflow cpu-only version (needed for testing). +tensorflow-cpu~=2.18.1 +tf2onnx + +# Torch cpu-only version (needed for testing). +--extra-index-url https://download.pytorch.org/whl/cpu +torch==2.6.0 + +# Jax with cuda support. +--find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html +jax[tpu] +flax + +-r requirements-common-jax-tpu.txt diff --git a/requirements-tensorflow-tpu.txt b/requirements-tensorflow-tpu.txt new file mode 100644 index 000000000000..d2bec8f5c2e4 --- /dev/null +++ b/requirements-tensorflow-tpu.txt @@ -0,0 +1,14 @@ +#tensorflow==2.18.0 +#--find-links https://storage.googleapis.com/libtpu-tf-releases/index.html +#tensorflow-tpu==2.18.0 + +#tf2onnx + +# Torch cpu-only version (needed for testing). +--extra-index-url https://download.pytorch.org/whl/cpu +torch==2.6.0 + +# Jax cpu-only version (needed for testing). +jax + +-r requirements-common.txt