Thanks for visiting codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 148 additions & 0 deletions .github/workflows/rocm-perf.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
name: ROCm DLM Performance Evaluations

on:
  push:

jobs:
  build-and-test-jax-perf:
    runs-on: mi-250
    strategy:
      matrix:
        python-version: ["3.10"]
        rocm-version: ["6.4.1"]

    env:
      # Directory (relative to the workspace) that receives the MaxText
      # checkout; unique per run id, run number and run attempt.
      WORKSPACE_DIR: ${{ format(
        'jax_rocm_perf_{0}_{1}_{2}',
        github.run_id,
        github.run_number,
        github.run_attempt
        ) }}
      PYTHON_VERSION: ${{ matrix.python-version }}
      ROCM_VERSION: ${{ matrix.rocm-version }}

    steps:
      - name: Clean up old workdirs
        run: |
          ls -l
          # Files from earlier runs may be owned by another user (presumably
          # root, written from inside containers); chown them back through a
          # throwaway container so the rm below can delete them.
          docker run -v "$(pwd):/rocm-jax" ubuntu bash -c "chown -R $UID /rocm-jax/* || true"
          rm -rf * || true
          ls -l

      - name: Print system info
        run: |
          whoami
          printenv
          df -h
          rocm-smi || true

      - name: Checkout source
        uses: actions/checkout@v4

      - name: Build plugin wheels
        run: |
          python3 build/ci_build \
            --compiler clang \
            --python-versions "$PYTHON_VERSION" \
            --rocm-version "$ROCM_VERSION" \
            dist_wheels

      - name: Copy wheels for Docker build context
        run: |
          mkdir -p wheelhouse
          cp ./jax_rocm_plugin/wheelhouse/*.whl ./wheelhouse/

      - name: Build JAX docker image
        run: |
          python3 build/ci_build \
            --rocm-version "$ROCM_VERSION" \
            build_dockers \
            --filter ubu22

      - name: Checkout MaxText source
        uses: actions/checkout@v4
        with:
          repository: ROCm/maxtext
          ref: rv_jax
          path: ${{ env.WORKSPACE_DIR }}/maxtext

      - name: Launch container
        run: |
          # Drop a container left behind by an earlier run that never reached
          # its cleanup step; otherwise `docker run --name` fails on the name.
          docker rm -f maxtext_container 2>/dev/null || true
          # ${ROCM_VERSION//.} removes the dots, e.g. 6.4.1 -> 641.
          docker run -d --name maxtext_container \
            --network=host \
            --device=/dev/kfd \
            --device=/dev/dri \
            --ipc=host \
            --shm-size=64G \
            --group-add=video \
            --cap-add=SYS_PTRACE \
            --security-opt seccomp=unconfined \
            -v "$(pwd)/${{ env.WORKSPACE_DIR }}/maxtext:/maxtext" \
            -w /maxtext \
            "jax-ubu22.rocm${ROCM_VERSION//.}" \
            tail -f /dev/null

      - name: Install git inside the container
        run: |
          docker exec maxtext_container bash -c "apt-get update && apt-get install -y git"

      - name: Install requirements and show pip list
        run: |
          docker exec maxtext_container bash -c "pip install -r requirements.txt && pip list"

      # Textual patch of the installed jaxlib: the "jax_rocm60_plugin" entry
      # becomes a list that also names "jax_rocm7_plugin", and
      # `_PLUGIN_MODULE_NAME[...]` lookups are prefixed with `*` (unpacked).
      # NOTE(review): depends on the exact source of the installed jaxlib;
      # re-verify when the jaxlib version changes.
      - name: Patch jaxlib/plugin_support.py in container
        run: |
          docker exec maxtext_container bash -c '
            JAXLIB_SITE=$(pip show jaxlib | grep Location | awk "{print \$2}")
            echo "$JAXLIB_SITE"

            sed -i \
              "s|\"jax_rocm60_plugin\"|[\"jax_rocm60_plugin\", \"jax_rocm7_plugin\"]|g" \
              "$JAXLIB_SITE/jaxlib/plugin_support.py"

            sed -i \
              "s|_PLUGIN_MODULE_NAME\\[|*_PLUGIN_MODULE_NAME\\[|g" \
              "$JAXLIB_SITE/jaxlib/plugin_support.py"

            grep -A 10 "_PLUGIN_MODULE_NAME" \
              "$JAXLIB_SITE/jaxlib/plugin_support.py"
          '

      - name: Run MaxText training and save logs
        id: train
        run: |
          # The default `bash -e` step shell does not enable pipefail, so a
          # failing `docker exec` would otherwise be hidden behind `tee`.
          set -o pipefail
          failed=()
          for config in \
            MaxText/configs/models/gpu/llama2_7b_rocm.yml \
            MaxText/configs/models/gpu/gemma_2b_rocm.yml \
            MaxText/configs/models/gpu/gpt3_6b_rocm.yml \
            MaxText/configs/models/gpu/mixtral_8x1b_rocm.yml; do
            model_name=$(basename "$config" _rocm.yml)
            echo "Running $model_name"
            extra_env=()
            if [[ "$model_name" == "mixtral_8x1b" ]]; then
              extra_env=(-e XLA_PYTHON_CLIENT_MEM_FRACTION=0.95)
            fi
            # Keep going after a failure so every model still gets a log; the
            # step fails below if any model failed.
            if ! docker exec "${extra_env[@]}" maxtext_container \
                python3 -m MaxText.train "$config" \
                | tee "logs_${model_name}.log"; then
              failed+=("$model_name")
            fi
          done
          if (( ${#failed[@]} > 0 )); then
            echo "::error::MaxText training failed for: ${failed[*]}"
            exit 1
          fi

      - name: Analyze logs to compute median step time
        # Summarize whatever logs exist even when some model failed, but skip
        # this when training never ran (earlier failure) or the run was cancelled.
        if: ${{ !cancelled() && (steps.train.outcome == 'success' || steps.train.outcome == 'failure') }}
        run: |
          pip install numpy
          python3 build/analyze_maxtext_logs.py
          cat summary.json

      - name: Upload logs and summary
        if: ${{ !cancelled() && (steps.train.outcome == 'success' || steps.train.outcome == 'failure') }}
        uses: actions/upload-artifact@v4
        with:
          name: training-results
          path: |
            logs_*.log
            summary.json

      - name: Cleanup container
        if: always()
        run: |
          docker stop maxtext_container || true
          docker rm maxtext_container || true
33 changes: 33 additions & 0 deletions build/analyze_maxtext_logs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""Summarize MaxText training logs into per-model step-time statistics.

Every ``logs_<model>.log`` file in the current working directory is scanned
for the ``completed step: <n>, seconds: <t>`` lines that MaxText prints, and
``summary.json`` is written mapping each model name to its per-step timings
plus min / 25th percentile / median / mean / 75th percentile / max step time
and the number of steps counted. Logs without any step lines are omitted.

Run from the directory that holds the logs:

    python3 build/analyze_maxtext_logs.py
"""

# pylint: disable=import-error
import json
import re
import glob
import numpy as np

LOG_GLOB = "logs_*.log"
LOG_PREFIX = "logs_"
LOG_SUFFIX = ".log"
SUMMARY_PATH = "summary.json"

# Group 1: step number as logged by MaxText; group 2: seconds for that step.
STEP_LINE_RE = re.compile(r"completed step: (\d+), seconds: ([\d.]+)")


def model_name_from_log(path):
    """Return ``<model>`` for a log file named ``logs_<model>.log``.

    Only the leading prefix and the trailing suffix are stripped, so model
    names that contain ``logs_`` or ``.log`` elsewhere are preserved.
    """
    name = path
    if name.startswith(LOG_PREFIX):
        name = name[len(LOG_PREFIX):]
    if name.endswith(LOG_SUFFIX):
        name = name[: -len(LOG_SUFFIX)]
    return name


def parse_step_times(lines):
    """Collect ``(step, seconds)`` pairs from an iterable of log lines.

    Args:
        lines: Iterable of text lines (e.g. an open log file).

    Returns:
        A list of ``(int step, float seconds)`` tuples in log order. Lines
        that do not contain a step-completion record are ignored.

    Raises:
        ValueError: if a matched seconds field is not a valid float
            (e.g. ``1.2.3``).
    """
    steps = []
    for line in lines:
        match = STEP_LINE_RE.search(line)
        if match:
            steps.append((int(match.group(1)), float(match.group(2))))
    return steps


def summarize_step_times(steps):
    """Build the summary entry for one model.

    Args:
        steps: Non-empty list of ``(step, seconds)`` pairs as returned by
            :func:`parse_step_times`.

    Returns:
        A dict with the per-step records (``{"step", "time"}``, using the
        step numbers from the log), min/q25/median/mean/q75/max step time
        rounded to 3 decimals, and ``steps_counted``.
    """
    times_np = np.array([seconds for _, seconds in steps])
    return {
        "steps": [{"step": step, "time": seconds} for step, seconds in steps],
        "min_step_time": round(float(np.min(times_np)), 3),
        "q25_step_time": round(float(np.percentile(times_np, 25)), 3),
        "median_step_time": round(float(np.median(times_np)), 3),
        "mean_step_time": round(float(np.mean(times_np)), 3),
        "q75_step_time": round(float(np.percentile(times_np, 75)), 3),
        "max_step_time": round(float(np.max(times_np)), 3),
        "steps_counted": len(steps),
    }


def build_summary(log_paths):
    """Map model name -> summary for every log with at least one step line.

    Paths are processed in sorted order so the resulting JSON is stable.
    Undecodable bytes are replaced rather than aborting the whole analysis.
    """
    summary = {}
    for path in sorted(log_paths):
        with open(path, encoding="utf-8", errors="replace") as f:
            steps = parse_step_times(f)
        if steps:
            summary[model_name_from_log(path)] = summarize_step_times(steps)
    return summary


def main():
    """Summarize ``logs_*.log`` in the current directory into summary.json."""
    summary = build_summary(glob.glob(LOG_GLOB))
    with open(SUMMARY_PATH, "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2)


if __name__ == "__main__":
    main()