Merged
202 changes: 101 additions & 101 deletions metrax_example.ipynb
@@ -30,14 +30,14 @@
},
{
"cell_type": "code",
"source": [
"!pip install google-metrax clu"
],
"execution_count": null,
"metadata": {
"id": "t-eEZKa8qHy7"
},
"execution_count": null,
"outputs": []
"outputs": [],
"source": [
"!pip install google-metrax"
]
},
{
"cell_type": "code",
@@ -135,6 +135,9 @@
},
{
"cell_type": "markdown",
"metadata": {
"id": "aZiutTU3qguP"
},
"source": [
"## Lifecycle of a Metrax Metric\n",
"\n",
@@ -144,10 +147,7 @@
"2. **Iteration/Batch Processing:** As you process data in batches or on different devices, you create a new metric state for the current batch/device using `Metric.from_model_output()` (functional API) or update the existing metric object with `.update()` (object-oriented API), passing the predictions, labels, and any relevant weights for that specific data slice.\n",
"3. **Merging/Updating:** For the functional API, you merge the newly created metric state for the current batch/device with the accumulated state in your dictionary using the `.merge()` method. For the object-oriented API, the `.update()` method directly modifies the state within the metric object in your dictionary. This step accumulates the necessary statistics across all processed data.\n",
"4. **Final Computation:** After processing all data, you call the `.compute()` method on the final, merged (functional API) or updated (object-oriented API) metric state in your dictionary. This performs the final calculations and returns the metric's value (e.g., a single floating-point number like AUCPR)."
],
"metadata": {
"id": "aZiutTU3qguP"
}
]
},
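The four-step lifecycle above can be sketched in plain Python. This is an illustrative stand-in, not metrax's actual implementation: a hand-rolled mean metric whose `empty`/`from_model_output`/`merge`/`compute` methods simply mirror the API names from the steps above, so it runs without metrax installed.

```python
# Illustrative sketch of the empty -> from_model_output -> merge -> compute
# lifecycle, using a hand-rolled mean metric in place of a metrax metric.
from dataclasses import dataclass


@dataclass
class MeanMetric:
    total: float = 0.0
    count: float = 0.0

    @classmethod
    def empty(cls):
        # Step 1: initialize an empty accumulator state.
        return cls()

    @classmethod
    def from_model_output(cls, values):
        # Step 2: build a state from one batch of model output.
        return cls(total=float(sum(values)), count=float(len(values)))

    def merge(self, other):
        # Step 3: combine accumulated statistics across batches/devices.
        return MeanMetric(self.total + other.total, self.count + other.count)

    def compute(self):
        # Step 4: finalize the metric value.
        return self.total / self.count


state = MeanMetric.empty()
for batch in ([1.0, 2.0], [3.0, 5.0]):
    state = state.merge(MeanMetric.from_model_output(batch))
print(state.compute())  # (1 + 2 + 3 + 5) / 4 = 2.75
```

The key property, which the real metrics share, is that `merge` is associative: accumulating per batch gives the same result as one full-batch pass.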
{
"cell_type": "markdown",
@@ -171,16 +171,21 @@
},
{
"cell_type": "markdown",
"metadata": {
"id": "TOT_PXRDv80K"
},
"source": [
"### Basic Usage: Unweighted Metrics\n",
"This first example demonstrates the core functional workflow without sample weights."
],
"metadata": {
"id": "TOT_PXRDv80K"
}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "HFHRQRmpxmys"
},
"outputs": [],
"source": [
"import metrax\n",
"\n",
@@ -191,12 +196,7 @@
" 'AUCPR': metrax.AUCPR,\n",
" 'AUCROC': metrax.AUCROC,\n",
"}"
],
"metadata": {
"id": "HFHRQRmpxmys"
},
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "code",
@@ -220,6 +220,11 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sYZe0JUQxiNt"
},
"outputs": [],
"source": [
"# --- Method 2: Iterative Merging by Batch (Unweighted) ---\n",
"print(\"\\n--- Method 2: Iterative Merging (Unweighted) ---\")\n",
@@ -239,25 +244,25 @@
"for name, metric_state in iterative_metrics.items():\n",
" iterative_results[name] = metric_state.compute()\n",
" print(f\"{name}: {iterative_results[name]}\")"
],
"metadata": {
"id": "sYZe0JUQxiNt"
},
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "B_HJqaJKv_73"
},
"source": [
"### Advanced Usage: Incorporating Sample Weights\n",
"Just like the NNX API, the functional API supports `sample_weights` in `from_model_output()` for metrics where it is applicable. The following example calculates AUCPR and AUCROC using the weighted data we prepared earlier."
],
"metadata": {
"id": "B_HJqaJKv_73"
}
]
},
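How `sample_weights` enter the accumulation can be sketched with a hand-rolled weighted accuracy; this is illustrative only (metrax's AUCPR/AUCROC accumulate different statistics), but the shape of `from_model_output(predictions, labels, sample_weights)` followed by `merge`/`compute` is the same.

```python
# Illustrative sketch: per-example weights scale each example's contribution
# to both the numerator and the normalizing total. A 0.5 decision threshold
# is assumed here, matching the default mentioned later for Precision/Recall.
from dataclasses import dataclass


@dataclass
class WeightedAccuracy:
    weighted_correct: float = 0.0
    weight_total: float = 0.0

    @classmethod
    def from_model_output(cls, predictions, labels, sample_weights):
        correct = [w * float((p >= 0.5) == bool(l))
                   for p, l, w in zip(predictions, labels, sample_weights)]
        return cls(float(sum(correct)), float(sum(sample_weights)))

    def merge(self, other):
        return WeightedAccuracy(
            self.weighted_correct + other.weighted_correct,
            self.weight_total + other.weight_total,
        )

    def compute(self):
        return self.weighted_correct / self.weight_total


m = WeightedAccuracy.from_model_output(
    predictions=[0.9, 0.2, 0.7],
    labels=[1, 0, 0],
    sample_weights=[1.0, 1.0, 2.0],
)
print(m.compute())  # (1 + 1 + 0) / 4 = 0.5
```

Note the third example: it is misclassified and carries weight 2.0, so it drags the weighted score down more than an unweighted one would.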
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "czPiI6dhxvmL"
},
"outputs": [],
"source": [
"import metrax\n",
"\n",
@@ -266,15 +271,15 @@
" 'AUCPR': metrax.AUCPR,\n",
" 'AUCROC': metrax.AUCROC,\n",
"}"
],
"metadata": {
"id": "czPiI6dhxvmL"
},
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "MSnndNonwDKq"
},
"outputs": [],
"source": [
"# --- Method 1: Full-Batch Calculation (Weighted) ---\n",
"print(\"--- Method 1: Full-Batch Calculation (Weighted) ---\")\n",
@@ -287,15 +292,15 @@
" )\n",
" full_batch_results_weighted[name] = metric_state.compute()\n",
" print(f\"{name}: {full_batch_results_weighted[name]}\")"
],
"metadata": {
"id": "MSnndNonwDKq"
},
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "e8BPLW6XxxVr"
},
"outputs": [],
"source": [
"# --- Method 2: Iterative Merging by Batch (Weighted) ---\n",
"print(\"\\n--- Method 2: Iterative Merging (Weighted) ---\")\n",
@@ -315,12 +320,7 @@
"for name, metric_state in iterative_metrics_weighted.items():\n",
" iterative_results_weighted[name] = metric_state.compute()\n",
" print(f\"{name}: {iterative_results_weighted[name]}\")"
],
"metadata": {
"id": "e8BPLW6XxxVr"
},
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "markdown",
@@ -341,17 +341,22 @@
},
{
"cell_type": "markdown",
"metadata": {
"id": "TokxcQvTvWua"
},
"source": [
"### Basic Usage: Unweighted Metrics\n",
"Let's start with the simplest use case: calculating metrics without any sample weights. This example demonstrates the core object-oriented workflow. Note that for Precision and Recall, metrax uses a default classification threshold of 0.5.\n",
"\n"
],
"metadata": {
"id": "TokxcQvTvWua"
}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "g0ULds_Sx06D"
},
"outputs": [],
"source": [
"import metrax.nnx\n",
"\n",
@@ -362,12 +367,7 @@
" 'AUCPR': metrax.nnx.AUCPR,\n",
" 'AUCROC': metrax.nnx.AUCROC,\n",
"}"
],
"metadata": {
"id": "g0ULds_Sx06D"
},
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "code",
@@ -395,6 +395,11 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "2Hz14HRfx5vM"
},
"outputs": [],
"source": [
"# --- Method 2: Iterative Updating by Batch (nnx) ---\n",
"print(\"\\n--- Method 2: Iterative Updating with nnx (Unweighted) ---\")\n",
@@ -410,27 +415,27 @@
"for name, metric_obj in iterative_metrics_nnx.items():\n",
" iterative_results_nnx[name] = metric_obj.compute()\n",
" print(f\"{name}: {iterative_results_nnx[name]}\")"
],
"metadata": {
"id": "2Hz14HRfx5vM"
},
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mhq1dM5Rvq-Q"
},
"source": [
"### Advanced Usage: Incorporating Sample Weights\n",
"In many real-world scenarios, you'll want to assign different importance to different examples. This is often done to handle class imbalance, where you might give more weight to examples from a rare class. `metrax.nnx` supports this through the `sample_weights` argument in the `.update()` method.\n",
"\n",
"The following example calculates AUCPR and AUCROC, which are metrics that support sample weights, using the weighted data we prepared earlier."
],
"metadata": {
"id": "mhq1dM5Rvq-Q"
}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mB2z33Jax8md"
},
"outputs": [],
"source": [
"import metrax.nnx\n",
"\n",
@@ -439,15 +444,15 @@
" 'AUCPR': metrax.nnx.AUCPR,\n",
" 'AUCROC': metrax.nnx.AUCROC,\n",
"}"
],
"metadata": {
"id": "mB2z33Jax8md"
},
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "HERtwSZbvs6-"
},
"outputs": [],
"source": [
"# --- Method 1: Full-Batch Calculation with Sample Weights ---\n",
"print(\"--- Method 1: Full-Batch Calculation with nnx (Weighted) ---\")\n",
@@ -467,15 +472,15 @@
"for name, metric_obj in full_batch_metrics_weighted.items():\n",
" full_batch_results_weighted[name] = metric_obj.compute()\n",
" print(f\"{name}: {full_batch_results_weighted[name]}\")"
],
"metadata": {
"id": "HERtwSZbvs6-"
},
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rTAmdmyHx-wi"
},
"outputs": [],
"source": [
"# --- Method 2: Iterative Updating with Sample Weights ---\n",
"print(\"\\n--- Method 2: Iterative Updating with nnx (Weighted) ---\")\n",
@@ -495,15 +500,13 @@
"for name, metric_obj in iterative_metrics_weighted.items():\n",
" iterative_results_weighted[name] = metric_obj.compute()\n",
" print(f\"{name}: {iterative_results_weighted[name]}\")"
],
"metadata": {
"id": "rTAmdmyHx-wi"
},
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5_MHjDaRzwBf"
},
"source": [
"## Scaling to Multiple Devices\n",
"\n",
@@ -521,13 +524,15 @@
"4. **`jit`-Compile the Function**: You write a function that looks like a normal, single-device calculation and decorate it with `@jax.jit`. When JAX's compiler sees that the inputs to this function are sharded arrays, it automatically generates a distributed version of the code, implicitly handling all cross-device communication.\n",
"\n",
"This method provides the fine-grained control that is essential for all modern JAX parallelism patterns."
],
"metadata": {
"id": "5_MHjDaRzwBf"
}
]
},
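The reason sharded metric computation works can be simulated in plain Python, without jax: each "device" reduces its own shard to a small partial state, and combining the partials gives exactly the single-pass answer. The jit-compiled version does the same thing, with the cross-device combine handled by compiler-inserted collectives.

```python
# Illustrative simulation of sharded metric accumulation: four pretend
# devices each reduce a shard, and merging partials matches a global pass.
from functools import reduce


def partial_state(shard):
    # Per-device reduction: local sum and count.
    return (sum(shard), len(shard))


def merge(a, b):
    # Cross-device combine (what jit does for you via collectives).
    return (a[0] + b[0], a[1] + b[1])


data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
shards = [data[i::4] for i in range(4)]  # pretend 4 devices

total, count = reduce(merge, (partial_state(s) for s in shards))
global_mean = total / count
print(global_mean)  # 4.5, identical to sum(data) / len(data)
```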
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "QkVcKJRPAbRE"
},
"outputs": [],
"source": [
"import jax\n",
"import numpy as np\n",
@@ -547,15 +552,15 @@
" labels=labels,\n",
" sample_weights=sample_weights\n",
" )"
],
"metadata": {
"id": "QkVcKJRPAbRE"
},
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "VGsGxzR3Afsv"
},
"outputs": [],
"source": [
"# Advanced SPMD Parallelism: jit + Mesh\n",
"def calculate_aucpr_mesh(predictions, labels, sample_weights):\n",
@@ -588,12 +593,7 @@
"\n",
" # The result is already a globally correct metric state, replicated on all devices.\n",
" return jitted_calculate(sharded_predictions, sharded_labels, sharded_weights)"
],
"metadata": {
"id": "VGsGxzR3Afsv"
},
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "code",