Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
/xla
/jax

# temporary working files
/_my
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what generates this _my folder?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nothing generates it; I just use it for work scripts and am tired of adding it to a local .gitignore on each new checkout.
It's good to have a dedicated dir for such stuff, dumps and whatnot...

# bazel external dependencies directory mapping for individual bazel workspaces that might be used
/jax_rocm_plugin/external

# vim droppings
*.swp

Expand Down
6 changes: 6 additions & 0 deletions DEVSETUP.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ targets,
- `install` will install the wheels in `dist/` with `pip`
- `test` runs the basic plugin unit tests
- `clean` will delete all wheels in `dist/`
- `refresh` is a shortcut for `clean dist install`
- and a set of dedicated targets to locally build `jaxlib` when necessary:
- `jaxlib_clean` to remove old wheels in `$jax_dir/dist/`,
- `jaxlib` to build and
- `jaxlib_install` to install the wheel
- `refresh_jaxlib` is a shortcut to `jaxlib_clean jaxlib jaxlib_install`

To build and install the plugins in your virtual environment, run
```shell
Expand Down
184 changes: 168 additions & 16 deletions stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,33 @@
# customize to a single arch for local dev builds to reduce compile time
AMDGPU_TARGETS ?= "$(shell rocminfo | grep -o -m 1 'gfx.*')"

# Use your local XLA for building jax_rocm_pjrt
XLA_OVERRIDE_OPTION="--override_repository=xla=%(xla_dir)s"
###### auxiliary vars. Note the absence of quotes around variable values, - these vars are expected to be put into other quoted vars
# Bazel options to build repos in a certain mode.
CFG_DEBUG=--config=debug --compilation_mode=dbg --strip=never --copt=-g3 --copt=-O0 --cxxopt=-g3 --cxxopt=-O0
CFG_RELEASE_WITH_SYM=--strip=never --copt=-g3 --cxxopt=-g3

# Sets '-fdebug-prefix-map=' compiler parameter to remap source file locations from bazel's reproducible builds
# sandbox /proc/self/cwd to correct local paths. Note, external dependencies support require 'external' symlink
# in a corresponding bazel workspace root
PLUGIN_SYMBOLS=--copt=-fdebug-prefix-map=/proc/self/cwd=%(this_repo_root)s/jax_rocm_plugin --cxxopt=-fdebug-prefix-map=/proc/self/cwd=%(this_repo_root)s/jax_rocm_plugin
JAXLIB_SYMBOLS=--copt=-fdebug-prefix-map=/proc/self/cwd=%(kernels_jax_path)s --cxxopt=-fdebug-prefix-map=/proc/self/cwd=%(kernels_jax_path)s

###### --bazel_options values, must be enquoted
# Defines a value for '--bazel_options' for each of 3 build types (pjrt, plugin + jaxlib).
# By default, uses local XLA for each wheel. Redefine to whatever option is needed for your case
ALL_BAZEL_OPTIONS="--override_repository=xla=%(xla_path)s%(custom_options)s"

# PLUGIN_BAZEL_OPTIONS and JAXLIB_BAZEL_OPTIONS define pjrt&plugin specific bazel options and jaxlib specific build options.
PLUGIN_BAZEL_OPTIONS="%(plugin_bazel_options)s"
JAXLIB_BAZEL_OPTIONS="%(jaxlib_bazel_options)s"

# Use your local JAX for building the kernels in jax_rocm_plugin
# KERNELS_JAX_OVERRIDE_OPTION="--override_repository=jax=../jax"
KERNELS_JAX_OVERRIDE_OPTION="%(kernels_jax_override)s"

###


.PHONY: test clean install dist

.default: dist
Expand All @@ -45,10 +65,12 @@
python3 ./build/build.py build \
--use_clang=true \
--wheels=jax-rocm-plugin \
--target_cpu_features=native \
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea!

--rocm_path=/opt/rocm/ \
--rocm_version=%(plugin_version)s \
--rocm_amdgpu_targets=${AMDGPU_TARGETS} \
--bazel_options=${XLA_OVERRIDE_OPTION} \
--bazel_options=${ALL_BAZEL_OPTIONS} \
--bazel_options=${PLUGIN_BAZEL_OPTIONS} \
--bazel_options=${KERNELS_JAX_OVERRIDE_OPTION} \
--verbose \
--clang_path=%(clang_path)s
Expand All @@ -58,10 +80,12 @@
python3 ./build/build.py build \
--use_clang=true \
--wheels=jax-rocm-pjrt \
--target_cpu_features=native \
--rocm_path=/opt/rocm/ \
--rocm_version=%(plugin_version)s \
--rocm_amdgpu_targets=${AMDGPU_TARGETS} \
--bazel_options=${XLA_OVERRIDE_OPTION} \
--bazel_options=${ALL_BAZEL_OPTIONS} \
--bazel_options=${PLUGIN_BAZEL_OPTIONS} \
--bazel_options=${KERNELS_JAX_OVERRIDE_OPTION} \
--verbose \
--clang_path=%(clang_path)s
Expand All @@ -75,6 +99,9 @@
pip install --force-reinstall dist/*


refresh: clean dist install


test:
python3 tests/test_plugin.py

Expand All @@ -85,21 +112,26 @@
# code is somehow making its way into jaxlib.

jaxlib:
(cd %(kernels_jax_dir)s && python3 ./build/build.py build \
(cd %(kernels_jax_path)s && python3 ./build/build.py build \
--target_cpu_features=native \
--use_clang=true \
--clang_path=%(clang_path)s \
--wheels=jaxlib \
--bazel_options=${XLA_OVERRIDE_OPTION} \
--wheels=jaxlib \
--bazel_options=${ALL_BAZEL_OPTIONS} \
--bazel_options=${JAXLIB_BAZEL_OPTIONS} \
--verbose \
)


jaxlib_clean:
rm -f %(kernels_jax_dir)s/dist/*
rm -f %(kernels_jax_path)s/dist/*


jaxlib_install:
pip install --force-reinstall %(kernels_jax_dir)s/dist/*
pip install --force-reinstall %(kernels_jax_path)s/dist/*


refresh_jaxlib: jaxlib_clean jaxlib jaxlib_install
"""


Expand All @@ -119,7 +151,6 @@ def find_clang():
# search /usr/lib/
top = "/usr/lib"
for root, dirs, files in os.walk(top):

# only walk llvm dirs
if root == top:
for d in dirs:
Expand All @@ -135,12 +166,108 @@ def find_clang():
return None


def _resolve_relative_paths(xla_dir: str, kernels_jax_dir: str) -> tuple[str, str, str]:
    """Turn the XLA and kernels-JAX directories into absolute paths.

    Absolute paths are needed to properly support symbolic (debug) information
    remapping. A relative input is resolved against
    ``<this repo root>/jax_rocm_plugin``, where the repo root is the directory
    containing this file (with symlinks resolved); an absolute input is
    returned unchanged.

    Args:
        xla_dir: Location of the XLA checkout, absolute or relative.
        kernels_jax_dir: Location of the JAX checkout whose parts (such as
            kernels) are needed for pjrt & plugin compilation. It may or may
            not coincide with the ``./jax`` tests directory. A falsy value
            (empty string / None) means no local JAX directory is used.

    Returns:
        A ``(this_repo_root, xla_path, kernels_jax_path)`` tuple.
        ``kernels_jax_path`` is ``None`` when ``kernels_jax_dir`` is falsy.

    Raises:
        AssertionError: If a resolved path is not an existing directory.
    """
    this_repo_root = os.path.dirname(os.path.realpath(__file__))

    def _absolute(path: str) -> str:
        # Relative paths are interpreted from the plugin's bazel workspace.
        if os.path.isabs(path):
            return path
        return os.path.abspath(f"{this_repo_root}/jax_rocm_plugin/{path}")

    xla_path = _absolute(xla_dir)
    assert os.path.isdir(
        xla_path
    ), f"XLA path (specified as '{xla_dir}') doesn't resolve to existing directory at '{xla_path}'"

    if not kernels_jax_dir:
        return this_repo_root, xla_path, None

    kernels_jax_path = _absolute(kernels_jax_dir)
    # The message names JAX (it used to say "XLA path", which pointed the user
    # at the wrong command-line argument).
    # pylint: disable=line-too-long
    assert os.path.isdir(
        kernels_jax_path
    ), f"JAX path (specified as '{kernels_jax_dir}') doesn't resolve to existing directory at '{kernels_jax_path}'"
    return this_repo_root, xla_path, kernels_jax_path


def _add_externals_symlink(this_repo_root: str, xla_path: str, kernels_jax_path: str):
    """Add an ``external`` symlink to each bazel workspace root.

    For every workspace (``<this_repo_root>/jax_rocm_plugin``, ``xla_path`` and,
    when given, ``kernels_jax_path``) a ``./external`` symlink is created that
    points to ``$(bazel info output_base)/external`` of that workspace.

    The function is best-effort: if ``bazel`` cannot be run, or the output
    base of a workspace cannot be queried, a message is printed and the
    corresponding symlink(s) are skipped.

    Args:
        this_repo_root: Absolute path of this repository's root.
        xla_path: Absolute path of the XLA workspace.
        kernels_jax_path: Absolute path of the JAX workspace, or a falsy value
            (None / empty string) when there is none.

    Raises:
        AssertionError: If a supplied path is not absolute.
    """
    assert os.path.isabs(this_repo_root) and os.path.isabs(xla_path)
    assert not kernels_jax_path or os.path.isabs(kernels_jax_path)

    plugin_workspace = f"{this_repo_root}/jax_rocm_plugin"

    def _bazel_stdout(args, cwd: str) -> str:
        """Run ``bazel <args>`` in `cwd` and return its stdout, right-stripped."""
        completed = subprocess.run(
            ["bazel", *args],
            cwd=cwd,
            check=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.DEVNULL,
        )
        return completed.stdout.decode("utf-8").rstrip()

    # checking 'bazel' is executable. We only support essentially bazelisk here.
    # Supporting individual bazel binaries installed by the upstream build system
    # when it can't find bazel is a TODO for the future.
    # Broad exceptions aren't a problem here
    # pylint: disable=broad-exception-caught
    try:
        v = _bazel_stdout(["--version"], plugin_workspace)
    except Exception as e:
        print(
            "WARNING: Bazelisk is NOT detected and a wrapper for specific bazel "
            "versions isn't implemented. Symlinks to '$(bazel info output_base)/external' "
            "will not be created in each bazel workspace root, you'll have to make them manually.\n"
            f"The error was: {e}"
        )
        return
    else:
        print(
            f"Bazelisk is detected (bazel=={v}), proceeding with creation of symlinks"
        )

    def _link(target: str, name: str):
        # lexists() is used instead of exists(): exists() follows symlinks and
        # reports False for a dangling link (one whose target directory is
        # missing), in which case os.symlink() would fail with FileExistsError.
        if os.path.lexists(name):
            print(f"Filesystem object {name} exists, skipping symlink creation.")
        else:
            os.symlink(target, name, target_is_directory=True)
            print(f"Created symlink '{name}'-->'{target}'")

    def _make_external(wrkspace: str):
        try:
            output_base = _bazel_stdout(["info", "output_base"], wrkspace)
        except Exception as e:
            print(f"Failed to query 'bazel info output_base' for '{wrkspace}':{e}")
            return
        _link(f"{output_base}/external", f"{wrkspace}/external")

    _make_external(plugin_workspace)
    _make_external(xla_path)  # not necessary, but useful for work on XLA only
    if kernels_jax_path:
        _make_external(kernels_jax_path)


# pylint: disable=too-many-arguments, too-many-positional-arguments, too-many-locals
def setup_development(
xla_ref: str,
xla_dir: str,
test_jax_ref: str,
kernels_jax_dir: str,
rebuild_makefile: bool = False,
fix_bazel_symbols: bool = False,
):
"""Clone jax and xla repos, and set up Makefile for developers"""

Expand All @@ -153,28 +280,43 @@ def setup_development(

# clone xla from source for building jax_rocm_plugin if the user didn't
# specify an existing XLA directory
if not os.path.exists("./xla") and xla_dir != DEFAULT_XLA_DIR:
if not os.path.exists("./xla") and xla_dir == DEFAULT_XLA_DIR:
cmd = ["git", "clone"]
cmd.extend(["--branch", xla_ref])
cmd.append(XLA_REPL_URL)
subprocess.check_call(cmd)

# create build/install/test script
makefile_path = "./jax_rocm_plugin/Makefile"
if rebuild_makefile or not os.path.exists(makefile_path):
if rebuild_makefile or not os.path.exists(makefile_path) or fix_bazel_symbols:
this_repo_root, xla_path, kernels_jax_path = _resolve_relative_paths(
xla_dir, kernels_jax_dir
)
if fix_bazel_symbols:
plugin_bazel_options = "${PLUGIN_SYMBOLS}"
jaxlib_bazel_options = "${JAXLIB_SYMBOLS}"
custom_options = " ${CFG_RELEASE_WITH_SYM}"
_add_externals_symlink(this_repo_root, xla_path, kernels_jax_path)
else: # not modifying the build unless asked
plugin_bazel_options, jaxlib_bazel_options, custom_options = "", "", ""

kvs = {
"clang_path": "/usr/lib/llvm-18/bin/clang",
"plugin_version": PLUGIN_NAMESPACE_VERSION,
"xla_dir": xla_dir,
"this_repo_root": this_repo_root,
"xla_path": xla_path,
"kernels_jax_path": kernels_jax_path,
"plugin_bazel_options": plugin_bazel_options,
"jaxlib_bazel_options": jaxlib_bazel_options,
"custom_options": custom_options,
# If the user wants to use their own JAX for building the plugin wheel
# that contains all the jaxlib kernel code (jax_rocm7_plugin), add that
# to the Makefile.
"kernels_jax_override": (
(" --override_repository=jax=%s" % kernels_jax_dir)
if kernels_jax_dir
("--override_repository=jax=%s" % kernels_jax_path)
if kernels_jax_path
else ""
),
"kernels_jax_dir": kernels_jax_dir if kernels_jax_dir else "",
}

clang_path = find_clang()
Expand Down Expand Up @@ -279,6 +421,15 @@ def parse_args():
default=DEFAULT_KERNELS_JAX_DIR,
)

dev.add_argument(
"--fix-bazel-symbols",
help="When this option is enabled, the script assumes you need to build "
"code in a release with symbolic info configuration to alleviate debugging. "
"The script enables respective bazel options and adds 'external' symbolic "
"links to corresponding workspaces pointing to bazel's dependencies storage.",
action="store_true",
)

doc_parser = subp.add_parser("docker")
doc_parser.add_argument(
"--rm",
Expand All @@ -300,6 +451,7 @@ def main():
test_jax_ref=args.jax_ref,
kernels_jax_dir=args.kernel_jax_dir,
rebuild_makefile=args.rebuild_makefile,
fix_bazel_symbols=args.fix_bazel_symbols,
)


Expand Down