Run Keras and JAX workloads on cloud TPUs and GPUs with a simple decorator. No infrastructure management required.
```python
import keras_remote

@keras_remote.run(accelerator="v3-8")
def train_model():
    import keras
    model = keras.Sequential([...])
    model.fit(x_train, y_train)
    return model.history.history["loss"][-1]

# Executes on TPU v3-8, returns the result
final_loss = train_model()
```

- Features
- Installation
- Quick Start
- Usage Examples
- Configuration
- Supported Accelerators
- Monitoring
- Troubleshooting
- Contributing
- License
- Simple decorator API — Add `@keras_remote.run()` to any function to execute it remotely
- Automatic infrastructure — No manual VM provisioning or teardown required
- Result serialization — Functions return actual values, not just logs
- Container caching — Subsequent runs start in 2-4 minutes after initial build
- Built-in monitoring — View job status and logs in Google Cloud Console
- Automatic cleanup — Resources are released when jobs complete
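The "result serialization" behavior above can be illustrated with a minimal sketch: the worker serializes the function's return value and the client deserializes it, so the caller gets a real Python object back rather than log output. The `run_remotely` helper below is a hypothetical stand-in for the remote transport, not the library's actual implementation:

```python
import pickle

def run_remotely(fn, *args, **kwargs):
    """Toy stand-in for remote execution: serialize the function's
    return value as a worker would, then restore it client-side."""
    result = fn(*args, **kwargs)    # would run on the remote worker
    payload = pickle.dumps(result)  # worker ships these bytes back
    return pickle.loads(payload)    # client restores the original object

def train_model():
    return {"loss": [0.9, 0.4, 0.1]}

history = run_remotely(train_model)
print(history["loss"][-1])  # 0.1
```

Any picklable return type (floats, dicts, NumPy arrays) survives this round trip, which is why the examples below can return a loss value directly.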
Install the core package to use the `@keras_remote.run()` decorator in your code:

```bash
git clone https://github.com/keras-team/keras-remote.git
cd keras-remote
pip install -e .
```

This is sufficient if your infrastructure (GKE cluster, Artifact Registry, etc.) is already provisioned.
Install with the `cli` extra to also get the `keras-remote` command for managing infrastructure:

```bash
git clone https://github.com/keras-team/keras-remote.git
cd keras-remote
pip install -e ".[cli]"
```

This adds the `keras-remote up`, `keras-remote down`, `keras-remote status`, and `keras-remote config` commands for provisioning and tearing down cloud resources.
- Python 3.11+
- Google Cloud SDK (`gcloud`)
  - Run `gcloud auth login` and `gcloud auth application-default login`
- Pulumi CLI (required for `[cli]` install only)
- A Google Cloud project with billing enabled
Run the CLI setup command:
```bash
keras-remote up
```

This will interactively:
- Prompt for your GCP project ID
- Let you choose an accelerator type (CPU, GPU, or TPU)
- Enable required APIs (Cloud Build, Artifact Registry, Cloud Storage, GKE)
- Create the Artifact Registry repository
- Provision a GKE cluster with optional accelerator node pools
- Configure Docker authentication and kubectl access
You can also run non-interactively:
```bash
keras-remote up --project=my-project --accelerator=t4 --yes
```

To view current infrastructure state:

```bash
keras-remote status
```

To view configuration:

```bash
keras-remote config
```

Add to your shell profile (`~/.bashrc`, `~/.zshrc`, etc.):
```bash
export KERAS_REMOTE_PROJECT="your-project-id"
export KERAS_REMOTE_ZONE="us-central1-a"  # Optional
```

```python
import keras_remote

@keras_remote.run(accelerator="v3-8")
def hello_tpu():
    import jax
    return f"Running on {jax.devices()}"

result = hello_tpu()
print(result)
```

```python
import keras_remote

@keras_remote.run(accelerator="v3-8")
def compute(x, y):
    return x + y

result = compute(5, 7)
print(f"Result: {result}")  # Output: Result: 12
```

```python
import keras_remote

@keras_remote.run(accelerator="v3-8")
def train_model():
    import keras
    import numpy as np

    model = keras.Sequential([
        keras.layers.Dense(64, activation="relu", input_shape=(10,)),
        keras.layers.Dense(1)
    ])
    model.compile(optimizer="adam", loss="mse")

    x_train = np.random.randn(1000, 10)
    y_train = np.random.randn(1000, 1)
    history = model.fit(x_train, y_train, epochs=5, verbose=0)
    return history.history["loss"][-1]

final_loss = train_model()
print(f"Final loss: {final_loss}")
```

Create a `requirements.txt` in your project directory:

```
tensorflow-datasets
pillow
scikit-learn
```
Keras Remote automatically detects and installs dependencies on the remote worker.
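As a rough sketch of what that detection involves (the library's actual mechanism is not shown in this document), a worker could parse `requirements.txt`, skipping blank lines and comments, and hand the resulting specs to `pip`:

```python
from pathlib import Path

def read_requirements(path="requirements.txt"):
    """Collect dependency specs from a requirements file,
    skipping blank lines and # comments."""
    p = Path(path)
    if not p.exists():
        return []
    lines = p.read_text().splitlines()
    return [ln.strip() for ln in lines
            if ln.strip() and not ln.strip().startswith("#")]
```

The returned list would then be installed on the remote worker (e.g. `pip install <spec> ...`) before your function runs.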
Skip container build time by using prebuilt images:
```python
@keras_remote.run(
    accelerator="v3-8",
    container_image="us-docker.pkg.dev/my-project/keras-remote/prebuilt:v1.0"
)
def train():
    ...
```

See `examples/Dockerfile.prebuilt` for a template.
| Variable | Required | Default | Description |
|---|---|---|---|
| `KERAS_REMOTE_PROJECT` | Yes | — | Google Cloud project ID |
| `KERAS_REMOTE_ZONE` | No | `us-central1-a` | Default compute zone |
| `KERAS_REMOTE_CLUSTER` | No | — | GKE cluster name |
```python
@keras_remote.run(
    accelerator="v3-8",    # Required: TPU/GPU type
    container_image=None,  # Custom container URI
    zone=None,             # Override default zone
    project=None,          # Override default project
    cluster=None,          # GKE cluster name
    namespace="default"    # Kubernetes namespace
)
```

Note: each accelerator and topology requires setting up its own node pool as a prerequisite.
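The parameter defaults suggest a simple resolution order: explicit decorator arguments win, and the `KERAS_REMOTE_*` environment variables fill in anything left as `None`. A hypothetical sketch of that lookup (`resolve_config` is illustrative, not the library's actual code):

```python
import os

def resolve_config(accelerator, zone=None, project=None, cluster=None):
    """Merge per-call overrides with KERAS_REMOTE_* environment defaults."""
    if not accelerator:
        raise ValueError("accelerator is required")
    return {
        "accelerator": accelerator,
        "project": project or os.environ.get("KERAS_REMOTE_PROJECT"),
        "zone": zone or os.environ.get("KERAS_REMOTE_ZONE", "us-central1-a"),
        "cluster": cluster or os.environ.get("KERAS_REMOTE_CLUSTER"),
    }
```

For example, `resolve_config("v3-8", zone="europe-west4-a")` keeps the explicit zone while the project still comes from `KERAS_REMOTE_PROJECT`.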
| Type | Configurations |
|---|---|
| TPU v2 | v2-8, v2-32 |
| TPU v3 | v3-8, v3-32 |
| TPU v5 Litepod | v5litepod-1, v5litepod-4, v5litepod-8 |
| TPU v5p | v5p-8, v5p-16 |
| TPU v6e | v6e-8, v6e-16 |
| Type | Aliases |
|---|---|
| NVIDIA T4 | t4, nvidia-tesla-t4 |
| NVIDIA L4 | l4, nvidia-l4 |
| NVIDIA V100 | v100, nvidia-tesla-v100 |
| NVIDIA A100 | a100, nvidia-tesla-a100 |
| NVIDIA H100 | h100, nvidia-h100-80gb |
For multi-GPU configurations on GKE, append the count: `a100x4`, `l4x2`, etc.
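The count suffix is easy to interpret mechanically: split the alias into a base GPU type and a device count, defaulting to one GPU when no `xN` suffix is present. This parser is an illustrative sketch only; the library's own handling may differ:

```python
import re

def parse_gpu_spec(spec):
    """Split an accelerator alias like 'a100x4' into (base_type, count)."""
    m = re.fullmatch(r"([a-z0-9-]+?)(?:x(\d+))?", spec)
    if m is None:
        raise ValueError(f"unrecognized accelerator spec: {spec!r}")
    base, count = m.group(1), m.group(2)
    return base, int(count) if count else 1
```

So `parse_gpu_spec("a100x4")` yields `("a100", 4)`, while a plain alias such as `"t4"` yields `("t4", 1)`.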
- Cloud Build: console.cloud.google.com/cloud-build/builds
- GKE Workloads: console.cloud.google.com/kubernetes/workload
```bash
# List GKE jobs
kubectl get jobs -n default
```

To provision resources manually instead of using `keras-remote up`, first set your project ID:

```bash
export KERAS_REMOTE_PROJECT="your-project-id"
```

Enable required APIs and create the Artifact Registry repository:
```bash
gcloud services enable cloudbuild.googleapis.com \
  artifactregistry.googleapis.com storage.googleapis.com \
  container.googleapis.com --project=$KERAS_REMOTE_PROJECT

gcloud artifacts repositories create keras-remote \
  --repository-format=docker \
  --location=us \
  --project=$KERAS_REMOTE_PROJECT
```

Grant required IAM roles:
```bash
gcloud projects add-iam-policy-binding $KERAS_REMOTE_PROJECT \
  --member="user:[email protected]" \
  --role="roles/storage.admin"
```

Check Cloud Build logs:

```bash
gcloud builds list --project=$KERAS_REMOTE_PROJECT --limit=5
```

Enable verbose client-side logging:

```python
import logging
logging.basicConfig(level=logging.INFO)
```

```bash
# Check authentication
gcloud auth list

# Check project
echo $KERAS_REMOTE_PROJECT

# Check APIs
gcloud services list --enabled --project=$KERAS_REMOTE_PROJECT \
  | grep -E "(cloudbuild|artifactregistry|storage|container)"

# Check Artifact Registry
gcloud artifacts repositories describe keras-remote \
  --location=us --project=$KERAS_REMOTE_PROJECT
```

Remove all Keras Remote resources to avoid charges:
```bash
keras-remote down
```

This removes:
- GKE cluster and accelerator node pools (via Pulumi)
- Artifact Registry repository and container images
- Cloud Storage buckets (jobs and builds)
Use `--yes` to skip the confirmation prompt.
Contributions are welcome. Please read our contributing guidelines before submitting pull requests.
All contributions must follow our Code of Conduct.
This project is licensed under the Apache License 2.0. See LICENSE for details.
Maintained by the Keras team at Google.