-
Notifications
You must be signed in to change notification settings - Fork 2
Bugfix stack.py to re-enable ./xla checkout, +improvements #112
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
659d7e9
e3ade39
5f074ab
bdc4d7e
8c6cdad
1db44fc
65866fb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -26,13 +26,33 @@ | |
| # customize to a single arch for local dev builds to reduce compile time | ||
| AMDGPU_TARGETS ?= "$(shell rocminfo | grep -o -m 1 'gfx.*')" | ||
gulsumgudukbay marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| # Use your local XLA for building jax_rocm_pjrt | ||
| XLA_OVERRIDE_OPTION="--override_repository=xla=%(xla_dir)s" | ||
| ###### auxiliary vars. Note the absence of quotes around variable values, - these vars are expected to be put into other quoted vars | ||
| # Bazel options to build repos in a certain mode. | ||
| CFG_DEBUG=--config=debug --compilation_mode=dbg --strip=never --copt=-g3 --copt=-O0 --cxxopt=-g3 --cxxopt=-O0 | ||
| CFG_RELEASE_WITH_SYM=--strip=never --copt=-g3 --cxxopt=-g3 | ||
|
|
||
| # Sets '-fdebug-prefix-map=' compiler parameter to remap source file locations from bazel's reproducible builds | ||
| # sandbox /proc/self/cwd to correct local paths. Note, external dependencies support require 'external' symlink | ||
| # in a corresponding bazel workspace root | ||
| PLUGIN_SYMBOLS=--copt=-fdebug-prefix-map=/proc/self/cwd=%(this_repo_root)s/jax_rocm_plugin --cxxopt=-fdebug-prefix-map=/proc/self/cwd=%(this_repo_root)s/jax_rocm_plugin | ||
| JAXLIB_SYMBOLS=--copt=-fdebug-prefix-map=/proc/self/cwd=%(kernels_jax_path)s --cxxopt=-fdebug-prefix-map=/proc/self/cwd=%(kernels_jax_path)s | ||
|
|
||
| ###### --bazel_options values, must be enquoted | ||
| # Defines a value for '--bazel_options' for each of 3 build types (pjrt, plugin + jaxlib). | ||
| # By default, uses local XLA for each wheel. Redefine to whatever option is needed for your case | ||
| ALL_BAZEL_OPTIONS="--override_repository=xla=%(xla_path)s%(custom_options)s" | ||
|
|
||
| # PLUGIN_BAZEL_OPTIONS and JAXLIB_BAZEL_OPTIONS define pjrt&plugin specific bazel options and jaxlib specific build options. | ||
| PLUGIN_BAZEL_OPTIONS="%(plugin_bazel_options)s" | ||
| JAXLIB_BAZEL_OPTIONS="%(jaxlib_bazel_options)s" | ||
|
|
||
| # Use your local JAX for building the kernels in jax_rocm_plugin | ||
| # KERNELS_JAX_OVERRIDE_OPTION="--override_repository=jax=../jax" | ||
| KERNELS_JAX_OVERRIDE_OPTION="%(kernels_jax_override)s" | ||
|
|
||
| ### | ||
|
|
||
|
|
||
| .PHONY: test clean install dist | ||
|
|
||
| .default: dist | ||
|
|
@@ -45,10 +65,12 @@ | |
| python3 ./build/build.py build \ | ||
| --use_clang=true \ | ||
| --wheels=jax-rocm-plugin \ | ||
| --target_cpu_features=native \ | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good idea! |
||
| --rocm_path=/opt/rocm/ \ | ||
| --rocm_version=%(plugin_version)s \ | ||
| --rocm_amdgpu_targets=${AMDGPU_TARGETS} \ | ||
| --bazel_options=${XLA_OVERRIDE_OPTION} \ | ||
| --bazel_options=${ALL_BAZEL_OPTIONS} \ | ||
| --bazel_options=${PLUGIN_BAZEL_OPTIONS} \ | ||
| --bazel_options=${KERNELS_JAX_OVERRIDE_OPTION} \ | ||
| --verbose \ | ||
| --clang_path=%(clang_path)s | ||
|
|
@@ -58,10 +80,12 @@ | |
| python3 ./build/build.py build \ | ||
| --use_clang=true \ | ||
| --wheels=jax-rocm-pjrt \ | ||
| --target_cpu_features=native \ | ||
| --rocm_path=/opt/rocm/ \ | ||
| --rocm_version=%(plugin_version)s \ | ||
| --rocm_amdgpu_targets=${AMDGPU_TARGETS} \ | ||
| --bazel_options=${XLA_OVERRIDE_OPTION} \ | ||
| --bazel_options=${ALL_BAZEL_OPTIONS} \ | ||
| --bazel_options=${PLUGIN_BAZEL_OPTIONS} \ | ||
| --bazel_options=${KERNELS_JAX_OVERRIDE_OPTION} \ | ||
| --verbose \ | ||
| --clang_path=%(clang_path)s | ||
|
|
@@ -75,6 +99,9 @@ | |
| pip install --force-reinstall dist/* | ||
|
|
||
|
|
||
| refresh: clean dist install | ||
|
|
||
|
|
||
| test: | ||
| python3 tests/test_plugin.py | ||
|
|
||
|
|
@@ -85,21 +112,26 @@ | |
| # code is somehow making its way into jaxlib. | ||
|
|
||
| jaxlib: | ||
| (cd %(kernels_jax_dir)s && python3 ./build/build.py build \ | ||
| (cd %(kernels_jax_path)s && python3 ./build/build.py build \ | ||
| --target_cpu_features=native \ | ||
| --use_clang=true \ | ||
| --clang_path=%(clang_path)s \ | ||
| --wheels=jaxlib \ | ||
| --bazel_options=${XLA_OVERRIDE_OPTION} \ | ||
| --wheels=jaxlib \ | ||
| --bazel_options=${ALL_BAZEL_OPTIONS} \ | ||
| --bazel_options=${JAXLIB_BAZEL_OPTIONS} \ | ||
| --verbose \ | ||
| ) | ||
|
|
||
|
|
||
| jaxlib_clean: | ||
| rm -f %(kernels_jax_dir)s/dist/* | ||
| rm -f %(kernels_jax_path)s/dist/* | ||
|
|
||
|
|
||
| jaxlib_install: | ||
| pip install --force-reinstall %(kernels_jax_dir)s/dist/* | ||
| pip install --force-reinstall %(kernels_jax_path)s/dist/* | ||
|
|
||
|
|
||
| refresh_jaxlib: jaxlib_clean jaxlib jaxlib_install | ||
| """ | ||
|
|
||
|
|
||
|
|
@@ -119,7 +151,6 @@ def find_clang(): | |
| # search /usr/lib/ | ||
| top = "/usr/lib" | ||
| for root, dirs, files in os.walk(top): | ||
|
|
||
| # only walk llvm dirs | ||
| if root == top: | ||
| for d in dirs: | ||
|
|
@@ -135,12 +166,108 @@ def find_clang(): | |
| return None | ||
|
|
||
|
|
||
| def _resolve_relative_paths(xla_dir: str, kernels_jax_dir: str) -> tuple[str, str, str]: | ||
| """Transforms relative to absolute paths. This is needed to properly support | ||
| symbolic information remapping""" | ||
| this_repo_root = os.path.dirname(os.path.realpath(__file__)) | ||
|
|
||
| xla_path = ( | ||
| xla_dir | ||
| if os.path.isabs(xla_dir) | ||
| else os.path.abspath(f"{this_repo_root}/jax_rocm_plugin/{xla_dir}") | ||
| ) | ||
| assert os.path.isdir( | ||
| xla_path | ||
| ), f"XLA path (specified as '{xla_dir}') doesn't resolve to existing directory at '{xla_path}'" | ||
|
|
||
| if kernels_jax_dir: | ||
| kernels_jax_path = ( | ||
| kernels_jax_dir | ||
| if os.path.isabs(kernels_jax_dir) | ||
| else os.path.abspath(f"{this_repo_root}/jax_rocm_plugin/{kernels_jax_dir}") | ||
| ) | ||
| # pylint: disable=line-too-long | ||
| assert os.path.isdir( | ||
| kernels_jax_path | ||
| ), f"XLA path (specified as '{kernels_jax_dir}') doesn't resolve to existing directory at '{kernels_jax_path}'" | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am confused about this "kernels_jax_dir" variable. Could you please add a detailed explanation for this one?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I understand, thanks for the detailed explanation! |
||
| else: | ||
| kernels_jax_path = None | ||
| return this_repo_root, xla_path, kernels_jax_path | ||
|
|
||
|
|
||
def _create_dir_symlink(target: str, name: str) -> None:
    """Create a directory symlink ``name`` -> ``target`` unless ``name`` is taken.

    A dangling symlink at ``name`` (for example one left over from a previous
    bazel output_base that no longer exists) is removed and re-created so it
    points at ``target``. Any other existing filesystem object at ``name`` is
    left untouched and only reported.
    """
    # os.path.exists() follows symlinks and returns False for a dangling link,
    # so checking only exists() would let os.symlink() fail with FileExistsError.
    if os.path.islink(name) and not os.path.exists(name):
        os.unlink(name)
        print(f"Removed dangling symlink '{name}'")
    if os.path.lexists(name):
        print(f"Filesystem object {name} exists, skipping symlink creation.")
        return
    os.symlink(target, name, target_is_directory=True)
    print(f"Created symlink '{name}'-->'{target}'")


def _add_externals_symlink(this_repo_root: str, xla_path: str, kernels_jax_path: str):
    """Adds ./external symlink to $(bazel info output_base)/external into each path

    The link is created in ``<this_repo_root>/jax_rocm_plugin``, in ``xla_path``
    and, when given, in ``kernels_jax_path``. All paths must be absolute.
    ``kernels_jax_path`` may be None/empty, in which case it is skipped.

    Everything here is best-effort: if ``bazel`` cannot be run, or
    ``bazel info output_base`` fails for a workspace, a message is printed and
    that step is skipped instead of raising.
    """
    assert os.path.isabs(this_repo_root) and os.path.isabs(xla_path)
    assert not kernels_jax_path or os.path.isabs(kernels_jax_path)

    # checking 'bazel' is executable. We only support essentially bazelisk here.
    # Supporting individual bazel binaries installed by the upstream build system
    # when it can't find bazel is a TODO for the future.
    # Broad exceptions aren't a problem here
    # pylint: disable=broad-exception-caught
    try:
        v = (
            subprocess.run(
                ["bazel", "--version"],
                cwd=f"{this_repo_root}/jax_rocm_plugin",
                check=True,
                stdout=subprocess.PIPE,
                stderr=subprocess.DEVNULL,
            )
            .stdout.decode("utf-8")
            .rstrip()
        )
        print(
            f"Bazelisk is detected (bazel=={v}), proceeding with creation of symlinks"
        )
    except Exception as e:
        print(
            "WARNING: Bazelisk is NOT detected and a wrapper for specific bazel "
            "versions isn't implemented. Symlinks to '$(bazel info output_base)/external' "
            "will not be created in each bazel workspace root, you'll have to make them manually.\n"
            f"The error was: {e}"
        )
        return

    def _make_external(wrkspace: str):
        # Ask bazel where this workspace keeps its fetched dependencies, then
        # link <wrkspace>/external to <output_base>/external.
        try:
            output_base = (
                subprocess.run(
                    ["bazel", "info", "output_base"],
                    cwd=wrkspace,
                    check=True,
                    stdout=subprocess.PIPE,
                    stderr=subprocess.DEVNULL,
                )
                .stdout.decode("utf-8")
                .rstrip()
            )
        except Exception as e:
            print(f"Failed to query 'bazel info output_base' for '{wrkspace}':{e}")
            return
        _create_dir_symlink(f"{output_base}/external", f"{wrkspace}/external")

    _make_external(f"{this_repo_root}/jax_rocm_plugin")
    _make_external(xla_path)  # not necessary, but useful for work on XLA only
    if kernels_jax_path:
        _make_external(kernels_jax_path)
|
|
||
|
|
||
| # pylint: disable=too-many-arguments, too-many-positional-arguments, too-many-locals | ||
| def setup_development( | ||
| xla_ref: str, | ||
| xla_dir: str, | ||
| test_jax_ref: str, | ||
| kernels_jax_dir: str, | ||
| rebuild_makefile: bool = False, | ||
| fix_bazel_symbols: bool = False, | ||
| ): | ||
| """Clone jax and xla repos, and set up Makefile for developers""" | ||
|
|
||
|
|
@@ -153,28 +280,43 @@ def setup_development( | |
|
|
||
| # clone xla from source for building jax_rocm_plugin if the user didn't | ||
| # specify an existing XLA directory | ||
| if not os.path.exists("./xla") and xla_dir != DEFAULT_XLA_DIR: | ||
| if not os.path.exists("./xla") and xla_dir == DEFAULT_XLA_DIR: | ||
| cmd = ["git", "clone"] | ||
| cmd.extend(["--branch", xla_ref]) | ||
| cmd.append(XLA_REPL_URL) | ||
| subprocess.check_call(cmd) | ||
|
|
||
| # create build/install/test script | ||
| makefile_path = "./jax_rocm_plugin/Makefile" | ||
| if rebuild_makefile or not os.path.exists(makefile_path): | ||
| if rebuild_makefile or not os.path.exists(makefile_path) or fix_bazel_symbols: | ||
| this_repo_root, xla_path, kernels_jax_path = _resolve_relative_paths( | ||
| xla_dir, kernels_jax_dir | ||
| ) | ||
| if fix_bazel_symbols: | ||
| plugin_bazel_options = "${PLUGIN_SYMBOLS}" | ||
| jaxlib_bazel_options = "${JAXLIB_SYMBOLS}" | ||
| custom_options = " ${CFG_RELEASE_WITH_SYM}" | ||
| _add_externals_symlink(this_repo_root, xla_path, kernels_jax_path) | ||
| else: # not modifying the build unless asked | ||
| plugin_bazel_options, jaxlib_bazel_options, custom_options = "", "", "" | ||
|
|
||
| kvs = { | ||
| "clang_path": "/usr/lib/llvm-18/bin/clang", | ||
| "plugin_version": PLUGIN_NAMESPACE_VERSION, | ||
| "xla_dir": xla_dir, | ||
| "this_repo_root": this_repo_root, | ||
| "xla_path": xla_path, | ||
| "kernels_jax_path": kernels_jax_path, | ||
| "plugin_bazel_options": plugin_bazel_options, | ||
| "jaxlib_bazel_options": jaxlib_bazel_options, | ||
| "custom_options": custom_options, | ||
| # If the user wants to use their own JAX for building the plugin wheel | ||
| # that contains all the jaxlib kernel code (jax_rocm7_plugin), add that | ||
| # to the Makefile. | ||
| "kernels_jax_override": ( | ||
| (" --override_repository=jax=%s" % kernels_jax_dir) | ||
| if kernels_jax_dir | ||
| ("--override_repository=jax=%s" % kernels_jax_path) | ||
| if kernels_jax_path | ||
| else "" | ||
| ), | ||
| "kernels_jax_dir": kernels_jax_dir if kernels_jax_dir else "", | ||
| } | ||
|
|
||
| clang_path = find_clang() | ||
|
|
@@ -279,6 +421,15 @@ def parse_args(): | |
| default=DEFAULT_KERNELS_JAX_DIR, | ||
| ) | ||
|
|
||
| dev.add_argument( | ||
| "--fix-bazel-symbols", | ||
| help="When this option is enabled, the script assumes you need to build " | ||
| "code in a release with symbolic info configuration to alleviate debugging. " | ||
| "The script enables respective bazel options and adds 'external' symbolic " | ||
| "links to corresponding workspaces pointing to bazel's dependencies storage.", | ||
| action="store_true", | ||
| ) | ||
|
|
||
| doc_parser = subp.add_parser("docker") | ||
| doc_parser.add_argument( | ||
| "--rm", | ||
|
|
@@ -300,6 +451,7 @@ def main(): | |
| test_jax_ref=args.jax_ref, | ||
| kernels_jax_dir=args.kernel_jax_dir, | ||
| rebuild_makefile=args.rebuild_makefile, | ||
| fix_bazel_symbols=args.fix_bazel_symbols, | ||
| ) | ||
|
|
||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What generates this `_my` folder?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nothing generates it; I just use it for work scripts, and I'm tired of adding it to a local
`.gitignore` on each new checkout. It's good to have a dedicated directory for such stuff: dumps and whatnot.