
@charleshofer
Collaborator

@charleshofer charleshofer commented Aug 20, 2025

It's helpful for XLA developers to be able to keep their local XLA in a directory other than this repository. Getting the JAX development build to work with a different XLA path would normally require editing the Makefile. That process is a little cumbersome, and it's easier to set the path through an option to stack.py. This change adds that option. It also adds an argument for building the wheels with a local copy of JAX, similar to the new XLA option, and some make targets for building jaxlib.
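To make the proposed option concrete, here is a minimal sketch of how a `develop` subcommand might accept `--xla-dir` and turn it into the Bazel override used by the Makefile. It is not the actual stack.py code: the parser layout and the `xla_bazel_option` helper are hypothetical. Only the `--xla-dir` flag, the `../xla` default, and the `--override_repository=xla=...` form come from this thread.

```python
import argparse

# Default mirrors the DEFAULT_XLA_DIR constant quoted from stack.py in this review.
DEFAULT_XLA_DIR = "../xla"


def parse_args(argv=None):
    """Parse a hypothetical, trimmed-down `stack.py develop` command line."""
    parser = argparse.ArgumentParser(prog="stack.py")
    sub = parser.add_subparsers(dest="command", required=True)
    develop = sub.add_parser("develop", help="Set up a development environment")
    develop.add_argument(
        "--xla-dir",
        default=DEFAULT_XLA_DIR,
        help="Local XLA checkout that Bazel should use instead of the pinned one",
    )
    return parser.parse_args(argv)


def xla_bazel_option(xla_dir):
    """Build the option text that the generated Makefile passes to build.py."""
    return f'--bazel_options="--override_repository=xla={xla_dir}"'
```

With this in place, `python3 stack.py develop --xla-dir=/your/path` leaves `args.xla_dir == "/your/path"`, and omitting the flag falls back to `../xla`.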

@i-chaochen i-chaochen left a comment

Thanks a ton! One more description in README.md would be perfect!

Contributor

@Ruturaj4 Ruturaj4 left a comment

Yeah, perhaps just a line in README.md showing the usage of --xla-dir.

@charleshofer charleshofer force-pushed the add-xla-path-option-stack branch 3 times, most recently from 1ced563 to 0970e6b Compare August 20, 2025 21:27
@charleshofer
Collaborator Author

> Yeah, perhaps just a line in README.md showing the usage of --xla-dir.

Added some stuff to the README to cover this

@charleshofer charleshofer force-pushed the add-xla-path-option-stack branch from 0970e6b to cdaac33 Compare August 20, 2025 21:30
README.md Outdated
This sets up a development environment that allows you to build the ROCm
plugins via `make`. You can tune the environment and build process to
fit your specific build needs by editing the `jax_rocm_plugin/Makefile`
that `stack.py` creates for you.
@i-chaochen i-chaochen Aug 20, 2025

So IIUC, the development procedure will be:

  1. python stack.py develop // this git clones jax within rocm-jax. It seems I still cannot use my local jax? What if my changes are in jax? Or does my local jax need to be under rocm-jax?
  2. Then I go to rocm-jax/jax_rocm_plugin and edit the Makefile generated in step 1.
  3. Replace --bazel_options="--override_repository=xla=/my/local/xla/path" for jax_rocm_plugin and jax_rocm_pjrt.
  4. make clean dist in rocm-jax/jax_rocm_plugin // this will build my local xla with jax.

And if I make any changes, for example some crazy modifications to jax.lax.abs() in my local jax (which is in rocm-jax/jax) together with my local xla, do I need to redo the above 4 steps? And does step 1 always git clone rocm/jax-0.6?

If so, could you describe this step by step, please? Or, I'm wondering, could we just run a magic rocm_install.sh once to build jax/xla without the Makefile...

Contributor

> 1. ... seems I still cannot use my local jax?

The JAX checkout is only used for running JAX tests, or if you also need to build jaxlib. There might be a slight problem with overriding the jax checkout directory, since the JAX test runners are in rocm/rocm-jax and assume jax is in ../jax. One way to solve this would be to create some path_mapping file in the repo root that specifies the JAX checkout path, and have the runners read that file to find a non-default location. It's not much fun to implement, given that the use case has unclear utility for the team (at least to me; we've never needed it), but it's doable, of course.
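A minimal sketch of the path_mapping idea, assuming a JSON file in the repo root with a `jax` key. The file name, its format, and `resolve_jax_dir` are hypothetical; nothing like this exists in rocm-jax today. Runners would call the function instead of hardcoding `../jax`, and the fallback keeps today's behavior when no file is present.

```python
import json
import os

# Hypothetical file name and format, used only for illustration.
PATH_MAPPING_FILE = "path_mapping.json"
DEFAULT_JAX_DIR = "../jax"


def resolve_jax_dir(repo_root):
    """Return the JAX checkout the test runners should use.

    Falls back to the hardcoded ../jax default when the mapping file is
    missing or has no "jax" entry.
    """
    mapping_path = os.path.join(repo_root, PATH_MAPPING_FILE)
    if not os.path.isfile(mapping_path):
        return DEFAULT_JAX_DIR
    with open(mapping_path, encoding="utf-8") as f:
        mapping = json.load(f)
    return mapping.get("jax", DEFAULT_JAX_DIR)
```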

On the other hand, you can transfer your work to the default ../jax checkout via a working branch with (cd ./jax && git checkout my_working_branch). Would that work for you? This would spare the effort of writing and testing code for a non-default jax location.

> 2 ... 4

Now you just run python3 stack.py develop --xla-dir=/your/path and then go directly to step 4.

Contributor

OK, I might be wrong that jax is only used for tests. The plugin build system references some things in @jax//jaxlib, for example, so it's required to build the plugin.

Another issue with overriding the default repo checkout is the set of patches that we have to apply to the checkout. These patches aren't present in vanilla rocm/jax, which is basically a fork of the relevant upstream with changes to tests only.

Based on that, to save effort, I would prefer not to override the jax checkout location at all. @i-chaochen, can you transfer your changes to the default checkout using a working branch when you need it?

cc @charleshofer (I see that the PR was already approved without any discussion, so to prevent a premature merge, I'll request changes.)

@i-chaochen i-chaochen Aug 21, 2025

> Based on that, to save effort, I would prefer not to override the jax checkout location at all. @i-chaochen, can you transfer your changes to the default checkout using a working branch when you need it?

I'm not sure I fully get it. Just to be clear about what's being asked: for example, someone wants to make changes in both jax and xla for ROCm (say, make jax.lax.abs() always return 0 and pass some info to XLA HLO), so they need to build jax and xla together, based on the pinned jax commit?

And is the work based on upstream (Google) jax and rocm/xla-0.6?

Contributor

> for example, I want to make changes in both the jax and xla code bases (jax.lax.abs() to always return 0 and pass some info to XLA HLO), so I need to build jax/xla together

The best way to do this is to:

  1. Check out this rocm/rocm-jax meta-repo.
  2. Set up the current container with ./tools/docker_dev_setup.sh (options apply).
  3. Run python3 stack.py develop (possibly with the --jaxlib argument that I proposed in a different comment).
  4. Switch ./jax (checkout of rocm/jax) and ./xla (checkout of rocm/xla) to the branches you need.
  5. Run pip install -e ./jax to install JAX in editable mode.
  6. Modify whatever you want in ./jax and ./xla.
  7. Build all of JAX from scratch with (cd ./jax_rocm_plugin && make clean dist install).
  8. Test your changes; rinse and repeat from step 6.
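The steps above could be scripted roughly as follows. This is a hypothetical helper, not part of rocm-jax: it assumes it runs from the root of the rocm/rocm-jax checkout inside a container already set up by steps 1 and 2, and the branch names are placeholders. It only prints the commands unless `dry_run=False`; editing (step 6) and testing (step 8) stay manual. `make -C jax_rocm_plugin ...` is equivalent to `(cd ./jax_rocm_plugin && make ...)`.

```python
import shlex
import subprocess


def dev_workflow(jax_branch, xla_branch):
    """Return steps 3, 4, 5 and 7 of the list above as argv lists."""
    return [
        ["python3", "stack.py", "develop"],                             # step 3
        ["git", "-C", "jax", "checkout", jax_branch],                   # step 4
        ["git", "-C", "xla", "checkout", xla_branch],                   # step 4
        ["pip", "install", "-e", "./jax"],                              # step 5
        ["make", "-C", "jax_rocm_plugin", "clean", "dist", "install"],  # step 7
    ]


def run(commands, dry_run=True):
    """Print each command; also execute it (stopping on failure) unless dry_run."""
    for cmd in commands:
        print(shlex.join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)
```

After editing code in step 6, only the last command needs to be rerun.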

I don't know exactly how and when the upstream jax repo is used, but the above process will work for most, if not all, dev use cases.

My question to you was: instead of overriding the default checkout path ./jax, can you transfer your changes from your other copy of the jax repo to the default ./jax checkout (made in step 3) using git? Is it clear now?

It's easy to override the XLA checkout path, but overriding the jax checkout path might not be, because of the patches that must be applied and the test runners.

Contributor

@i-chaochen right! Doing so would also tie the compiler version to the ROCm version used, which might cause issues when compiling for older ROCm versions. It's much safer to always know upfront that the same compiler is used in all cases.

Collaborator Author

rocm/jax is only used for test cases (I'm hoping we can phase this out and just pull from jax-ml/jax) and for staging PRs from AMD developers into jax-ml/jax. Mostly, it stores test cases that we skip because we have yet to fix the underlying bug; we eventually unskip each test case once it's fixed.

jax-ml/jax is pulled in via Bazel (https://github.com/ROCm/rocm-jax/blob/master/jax_rocm_plugin/third_party/jax/workspace.bzl#L12) and is only used to build the kernels (AKA jax_rocm7_plugin) wheel (https://github.com/ROCm/rocm-jax/blob/master/jax_rocm_plugin/jaxlib_ext/tools/BUILD.bazel#L26). Bazel applies a handful of patches to the kernel code when it pulls jax-ml/jax (https://github.com/ROCm/rocm-jax/blob/master/jax_rocm_plugin/third_party/jax/workspace.bzl#L14). That kernel code is mostly shared with Nvidia, changes to it are few and far between, and those changes almost always make their way into jax-ml/jax at some point, at which point we can remove the patch file. So using patch files here works just fine.

That said, there is a valid use-case for wanting to modify the kernel code in jax-ml/jax and test it out. That's just not something that we've done a whole lot, so we haven't got any docs for it yet. Here's how you would do it:

  1. Clone rocm-jax and run stack.py develop to generate your Makefile and get ./jax.
  2. In the jax_rocm_plugin rule in your Makefile, add --bazel_options="--override_repository=jax=<path to the JAX repo from step 1>" \
  3. cd ./jax
  4. Make your changes to the wheel kernel code (this code is almost all in jaxlib/gpu and jaxlib/rocm).
  5. cd ../jax_rocm_plugin
  6. make clean dist
  7. pip install wheelhouse/jax_rocm7_plugin_
  8. Test out your changes in the Python interpreter or with unit tests (cd ../ && pytest jax/tests/).
  9. Repeat steps 3-8 as often as you need. You do not need to run python stack.py develop again unless you want to reset your Makefile to the default.
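Step 2 can also be done programmatically. The sketch below is a hypothetical helper that edits the generated Makefile text, assuming the `jax_rocm_plugin` recipe starts with a `python3 ./build/build.py build \` line followed by backslash-continued options, as in the Makefile excerpts in this thread. Like step 2 itself, it relies on build.py accepting an additional `--bazel_options` flag next to the existing XLA one.

```python
import re


def add_jax_override(makefile_text, jax_dir, rule="jax_rocm_plugin"):
    """Return makefile_text with a jax --override_repository option added to `rule`.

    Only the first `build.py build` line inside the rule is touched; the new
    option is inserted right after it, keeping the backslash continuation.
    """
    option = f'            --bazel_options="--override_repository=jax={jax_dir}" \\'
    out = []
    in_rule = False
    done = False
    for line in makefile_text.splitlines():
        out.append(line)
        if re.match(rf"{re.escape(rule)}\s*:", line):
            in_rule = True  # entering the target rule
        elif line and not line[0].isspace():
            in_rule = False  # any other top-level line ends the rule
        elif in_rule and not done and "build.py build" in line:
            out.append(option)
            done = True
    if not done:
        raise ValueError(f"no build.py call found in rule {rule!r}")
    return "\n".join(out) + "\n"
```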

Another process that has no documentation (again, because we just haven't done it much) is how to install locally built jax and jaxlib wheels alongside the plugin wheels. This is where you'd modify jax.lax.abs to do weird stuff with the HLO. The steps are basically what Aleksei already outlined. The only addition I'd make is that to install jaxlib, you'd have to do something like python build/build.py build --wheels=jaxlib from inside the ./jax directory. Realistically, though, we only do this for debug purposes, because problems with the jax wheel will usually impact Nvidia and TPU the same as us, and jaxlib (with the exception of loading up the plugins) is used for running JAX operators on CPU instead of GPUs/TPUs.

Collaborator Author

I'm going to try and respond to all the questions.

> @i-chaochen Thanks! Yes, this is clear to me! I hope these instructions could go into README.md. This would be super helpful for all!

Yes. It's pretty clear to me that we need better documentation on the build and dev process for JAX. I think "How do I do X?" instructions for modifying the jax and jaxlib wheels, and for building the pjrt and plugin wheels with local XLA and jax repos, would be helpful.

> @i-chaochen so just to make sure I got the full picture: if I want to make changes to upstream jax/xla for ROCm, there's no way to do it, because rocm-jax is pinning jax 0.6?

No. You'd get a local clone of the jax repo (feel free to use the one in rocm-jax by adding a new remote and pulling the branch you need), make your changes, and then override where Bazel pulls from. Setting --bazel_options="--override_repository=..." in the Makefile always overrides the JAX and XLA pins in jax_rocm_plugin/third_party.

> @Arech8 Yes, you can use whatever XLA (including upstream) you want to build JAX against. Note that in many cases this would require rebuilding jaxlib too; that's why I asked @charleshofer to add that option.

> @Arech8 there was some narrow case where upstream JAX is always used instead of the local checkout for some part of the build, but I don't understand exactly when and how this happens. @charleshofer, can you please clarify if you know?

You will get your local JAX for building the plugin wheel (jax_rocm7_plugin) if you override it with --bazel_options="--override_repository=...". You can set this to whatever path you want. You always get your local jax from ./jax when you run tests with the test scripts.

Contributor

@charleshofer thanks a ton for the great explanation; I really wish this were saved somewhere!
Regarding jaxlib, can you please look at my other comment, where I proposed generating an updated Makefile that also builds jaxlib, should a user want it? There's some additional rationale for that, and it's very simple to implement. Perhaps you could consider it?

Collaborator Author

@i-chaochen and @Arech8 Let me know if I missed any questions or if the build process still isn't clear. It's a long thread and I probably missed some.

):
"""Clone jax repo for jax test case source code"""

if not os.path.exists("./jax"):

I guess we need an option to specify the jax directory as well?

JAX_REPL_URL = "https://github.com/rocm/jax"
XLA_REPL_URL = "https://github.com/rocm/xla"

DEFAULT_XLA_DIR = "../xla"
Contributor

@Arech8 Arech8 Aug 21, 2025

Can you please implement these two important things: refactor the --bazel_options value into a single variable, and add an option to build jaxlib?

--bazel_options value problem

There's a very frequent need to modify the options (for example, to build everything in debug mode), and making the same change in 2 places (potentially 3, see below) is silly. So instead of:

...
jax_rocm_plugin:
	python3 ./build/build.py build \
            ...
            --bazel_options="--override_repository=xla=%(xla_dir)s" \
            ...

jax_rocm_pjrt:
	python3 ./build/build.py build \
            ...
            --bazel_options="--override_repository=xla=%(xla_dir)s" \
            ...
...

there should be

...
BAZEL_OPTIONS="--override_repository=xla=%(xla_dir)s"
...
jax_rocm_plugin:
	python3 ./build/build.py build \
            ...
            --bazel_options=${BAZEL_OPTIONS} \
            ...

jax_rocm_pjrt:
	python3 ./build/build.py build \
            ...
            --bazel_options=${BAZEL_OPTIONS} \
            ...
...
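A sketch of how stack.py might render the proposed layout, with the override written once in `BAZEL_OPTIONS`. The template is hypothetical and trimmed; the `...` lines stand for the same elided options as in the excerpt above.

```python
# Hypothetical, trimmed template using the same %(xla_dir)s substitution
# style as the stack.py excerpt; `...` stands for the elided build.py options.
MAKEFILE_TEMPLATE = """\
BAZEL_OPTIONS="--override_repository=xla=%(xla_dir)s"

jax_rocm_plugin:
\tpython3 ./build/build.py build \\
            ...
            --bazel_options=${BAZEL_OPTIONS} \\
            ...

jax_rocm_pjrt:
\tpython3 ./build/build.py build \\
            ...
            --bazel_options=${BAZEL_OPTIONS} \\
            ...
"""


def render_makefile(xla_dir="../xla"):
    """Render the Makefile; the XLA path now appears exactly once."""
    return MAKEFILE_TEMPLATE % {"xla_dir": xla_dir}
```

Because Make keeps the double quotes as part of the variable's value, each recipe line expands to the same `--bazel_options="--override_repository=xla=..."` text the original rules spelled out, so the shell receives identical arguments, and any later change to the option only has to be made in one place.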

Sometimes there's a need to build jaxlib

There are several use-cases to build jaxlib too:

  1. "namespace patch": am I right that we still cannot use the upstream jaxlib as is, and that the patch must always be applied to it? If so, the problem is that during development, packages can be installed and removed easily and sometimes unintentionally (by installing a package that depends on an incompatible version). If the patched jaxlib is removed and then installed back, the reinstalled version will always be the original, unpatched one, causing weird failures. Hence, it's much simpler and more robust to be able to easily reinstall a properly built jaxlib with the same make install that the pjrt and plugin installation requires.
  2. Sometimes one needs to work with a different JAX/jaxlib version than the repo intends: for example, to bump the JAX version one might want to start with the current one, or to check some patch, or for some other reason. So building jaxlib from source might be a better and simpler way to ensure a coherent setup than messing with upstream.
  3. One might need to debug a weird issue where a potentially incompatible jaxlib is suspected.

So, even if (1) isn't a concern (i.e., if we could always use the upstream jaxlib without any modifications), (2) and (3) still stand.

It would be nice to also build jaxlib if a --jaxlib binary flag is passed:

  • a binary --jaxlib CLI flag triggers building jaxlib.
  • add a --jax-dir CLI argument to override the default ../jax path, so @i-chaochen could use his own checkout of JAX. If I'm not mistaken, we don't need rocm/jax for anything except running tests, so there's likely no reason to use this flag outside the --jaxlib context. However, overriding the jax checkout location will also require modifying the test runners, which are actually in rocm/rocm-jax and have the ../jax path hardcoded. (Anyway, I'm personally fine with jax being in its default location, but I really need the jaxlib compilation feature.)
  • if args.jaxlib is true, add and modify the following in the Makefile:
dist: jax_rocm_plugin jax_rocm_pjrt jaxlib

jaxlib:
	(cd %(jax_dir)s && python3 ./build/build.py build \
            ...
            --wheels=jaxlib \
            --bazel_options=${BAZEL_OPTIONS} \
            ...)

clean:
	rm -rf dist ; rm -rf %(jax_dir)s/dist

install: dist
	pip install --force-reinstall dist/* && (cd %(jax_dir)s && pip install --force-reinstall dist/*)

(Regarding the install rule: for some reason pip didn't work correctly for me when I invoked it with a ../jax/dist/* argument, but the workaround with cd worked.)

Collaborator Author

The bazel options thing makes sense. I will add that.

Regarding (1), yes. This will be the case until the jax/jaxlib 0.7.1 release.

(2) and (3) make sense to me, too. I can modify the Makefile to include a jaxlib target. I'd prefer it not be included in the dist target, though. There are use cases for building our own jaxlib for debug purposes and when jaxlib is suspect, but I don't want us to get in the habit of depending on it.

I'd also prefer that we always include it in the Makefile rather than making it optional with a --jaxlib flag. You can build it by running make jaxlib. Maybe we'd give it its own clean and install targets, too. Bloating stack.py with options is piling build config on top of config. The plugin build scripts are already about 5 layers deep and in need of a refactor. Not being able to say "no" to every individual's request for just one more config option is part of why we're here.

Is that agreeable?
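A hypothetical sketch of the compromise proposed here: stack.py always writes the jaxlib rules, but neither `dist` nor `install` depends on them, so they only run when invoked explicitly with `make jaxlib`. The `jaxlib_clean` and `jaxlib_install` names are made up for illustration, and the other build.py options from the earlier excerpt are omitted.

```python
# Hypothetical template: only the `jaxlib` target name comes from this thread.
JAXLIB_RULES_TEMPLATE = """\
.PHONY: jaxlib jaxlib_clean jaxlib_install

# Not a prerequisite of dist or install; run `make jaxlib` explicitly.
jaxlib:
\t(cd %(jax_dir)s && python3 ./build/build.py build --wheels=jaxlib)

jaxlib_clean:
\trm -rf %(jax_dir)s/dist

jaxlib_install: jaxlib
\t(cd %(jax_dir)s && pip install --force-reinstall dist/*)
"""


def render_jaxlib_rules(jax_dir="../jax"):
    """Fill in the JAX checkout path with the same %-formatting stack.py uses."""
    return JAXLIB_RULES_TEMPLATE % {"jax_dir": jax_dir}
```

The `cd` into the JAX checkout before `pip install` mirrors the workaround described for the install rule above.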

Contributor

Hmm, OK. If it's always in the Makefile, just not part of the dist rule or any other rule used by default in a "normal" dev process, that's fine too and makes sense to me. Totally agree on that.

A note on "5 layers deep": while this is technically true, for me, as a build system user (note the BDD wording 😁), there are no 5 layers. I'm only interfacing with stack.py and don't want to look any deeper (though I've had to, because I sometimes need to compile jaxlib). But this was just context; I'm totally fine with the solution you proposed, thank you.

Collaborator Author

Ugh, something I hadn't thought about until just now. If you're building inside a container (and you 100% should be), using your own JAX and your own XLA inside the container is going to get hairy if you ran stack.py develop outside the container and the directories you passed to it are on your host. Maybe we could mount those? But then we'd have to remap those paths to the mounted paths. The container mounts everything in rocm-jax, and that works with no sweat.
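One way the remapping could work, as a hypothetical sketch: given a mapping from host directories to container mount points (for example, the pairs passed to `docker run -v host_dir:container_dir`), translate each directory handed to stack.py into its in-container path, preferring the longest matching mount. `remap_to_container` and the example paths are assumptions, not existing rocm-jax code.

```python
import posixpath


def remap_to_container(host_path, mounts):
    """Translate a host path into the path it has inside the container.

    mounts maps host directories to container mount points. The longest
    matching host directory wins; a path under no mount raises ValueError.
    """
    host_path = posixpath.normpath(host_path)
    best = None
    for host_dir, container_dir in mounts.items():
        host_dir = posixpath.normpath(host_dir)
        prefix = host_dir.rstrip("/") + "/"
        if host_path == host_dir or host_path.startswith(prefix):
            if best is None or len(host_dir) > len(best[0]):
                best = (host_dir, container_dir)
    if best is None:
        raise ValueError(f"{host_path} is not under any mounted directory")
    rel = posixpath.relpath(host_path, best[0])
    return posixpath.normpath(posixpath.join(best[1], rel))
```

Matching on a trailing `/` keeps a sibling such as `/home/me/src/xla-old` from being mistaken for a path inside `/home/me/src/xla`.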

Collaborator Author

@charleshofer charleshofer Aug 22, 2025

The wrinkle with keeping it all in rocm-jax/jax is that you might want different JAX repos for building jax/jaxlib, building the plugin kernels wheel, and running the tests.

Contributor

@alekstheod how many times per day do you usually have to build jax when triaging bugs?

It depends. If I have a JAX-related bug, then often. But any wrapper still just calls bazel with different flags, and that's it. It's easier to add these flags to a config in .bazelrc to simplify the command. I still don't see a reason for a wrapper script.

Contributor

No, it's not easier, by far. That assumes a fixed working machine, which isn't a correct assumption. Even moving to a different checkout would require changing this over and over again, making iteration times unnecessarily long. And then there's remembering and entering the full bazel command-line arguments. What for, when you can just run make clean dist install and have JAX up and running after that?
Just because you don't need something doesn't mean others don't.

You don't need to enter the full bazel command if you put the args in your config; you just need to add that config to your build command. I'm just saying that whatever we do with an external script, we could do in Bazel itself.
For installation, you can do:
bazel run --config=rocm_clang_official --config=devenv //jax:jaxlib_install if you have your flags in the devenv config and a jaxlib_install target. This will build and install the jaxlib package on your machine.
You could also do
bazel run --config=rocm_clang_official --config=devenv //jax:setup_devenv to install whatever packages are required.

This guarantees that you install what you built. An external script doesn't guarantee that, so you might install things built from a previous checkout. Basically, it's orthogonal to the Bazel philosophy...

Contributor

Excuse me, are you trying to make the point that doing all of this, including modifying .bazelrc, is simpler than just calling make?


@charleshofer
Collaborator Author

Updated the PR. There are options for using your own JAX and XLA. I also added make targets to make it easier to build the jaxlib wheel.

@Arech8
Contributor

Arech8 commented Aug 27, 2025

@charleshofer did you push changes?

@charleshofer charleshofer force-pushed the add-xla-path-option-stack branch from cdaac33 to 74746d9 Compare August 27, 2025 16:12
@charleshofer charleshofer changed the title Add option to override XLA path in Makefile Improve stack.py develop options Aug 27, 2025
@charleshofer charleshofer force-pushed the add-xla-path-option-stack branch from 74746d9 to ea872ed Compare August 27, 2025 16:14
@charleshofer
Collaborator Author

Pushed now

Adds a few options to the stack.py script that allow developers to use
local copies of XLA and JAX for building wheels. This also adds targets
in the Makefile for building jaxlib.
@charleshofer charleshofer force-pushed the add-xla-path-option-stack branch from ea872ed to 991dfc9 Compare August 28, 2025 22:16
@i-chaochen
i-chaochen commented Sep 2, 2025

Since #101 is merged, there's one more question I want to clarify. I'm aware that you pinned an XLA commit for the jax/xla build; if I use my local XLA repo, I guess I don't need to change this commit anymore?

@Arech8
Contributor

Arech8 commented Sep 2, 2025

@i-chaochen The --override_repository flag you should use to set your XLA checkout location is a native Bazel feature, so there's no need to do anything related to commit pinning. It completely overrides the repo checkout location for Bazel. Essentially, if the flag is not specified, Bazel fetches the repo sources into $(bazel info output_base)/external/xla, but if you specify it, $(bazel info output_base)/external/xla becomes a symlink to the directory you set. Hence the dependency is overridden for the whole build system.
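To check this on a given machine, a small sketch: ask Bazel for its output base and test whether `external/xla` there is a symlink. The helper names are hypothetical, the `external/xla` location is taken from the comment above, and newer Bazel versions using Bzlmod may name the repository directory differently, in which case the function simply returns `None`.

```python
import os
import subprocess


def bazel_output_base(workspace_dir):
    """Ask Bazel for the output base of the workspace in workspace_dir."""
    result = subprocess.run(
        ["bazel", "info", "output_base"],
        cwd=workspace_dir,
        capture_output=True,
        text=True,
        check=True,
    )
    return result.stdout.strip()


def overridden_repo_target(output_base, repo="xla"):
    """Return the directory external/<repo> links to, or None if it is not a symlink."""
    path = os.path.join(output_base, "external", repo)
    if os.path.islink(path):
        return os.path.realpath(path)
    return None
```

For example, `overridden_repo_target(bazel_output_base("jax_rocm_plugin"))`, assuming `jax_rocm_plugin` is the Bazel workspace the plugin is built from.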

@charleshofer charleshofer merged commit bd2fab3 into master Sep 4, 2025
7 checks passed
@charleshofer charleshofer deleted the add-xla-path-option-stack branch September 4, 2025 22:06

Labels

None yet

Projects

None yet


7 participants