Improve stack.py develop options #98
Branch updated from 239b5bf to f1d7731.
i-chaochen left a comment:
Thanks a ton! One more description in README.md would be perfect!
Ruturaj4 left a comment:
Yeah, perhaps just a line in README.md describing the usage of `--xla-dir`.
Branch updated from 1ced563 to 0970e6b.
Added some stuff to the README to cover this.
Branch updated from 0970e6b to cdaac33.
README.md (outdated)
This sets up a development environment that allows you to build the ROCm
plugins via `make`. You can tune the environment and build process to
fit your specific build needs by editing the `jax_rocm_plugin/Makefile`
that `stack.py` creates for you.
So, IIUC, the development procedure will be:
1. `python stack.py develop` // this is to `git clone jax` within `rocm-jax`. It seems I still cannot use my local jax? What if my changes are in `jax`? Or does my local `jax` need to be under `rocm-jax`?
2. Then I go to `rocm-jax/jax_rocm_plugin` and edit the `Makefile` generated in step 1.
3. Replace `--bazel_options="--override_repository=xla=/my/local/xla/path"` for `jax_rocm_plugin` and `jax_rocm_pjrt`.
4. Run `make clean dist` in `rocm-jax/jax_rocm_plugin` // this will build my local xla with jax.

And if I made any changes, for example some crazy modifications to `jax.lax.abs()` in my local jax (which is in `rocm-jax/jax`) together with my local xla, do I need to redo the above 4 steps? And does step 1 always `git clone` rocm/jax-0.6?
If so, could you describe it step by step like this, please? Or could we just run a magic `rocm_install.sh` once to build jax/xla without the Makefile...
> 1. ... seems I still cannot use my local jax?
The JAX checkout is only used for running JAX tests, or if you also need to build jaxlib. There might be a slight problem with overriding the jax checkout directory, since the JAX test runners are in rocm/rocm-jax and they assume jax is in `../jax`. One way to solve this would be to create a path-mapping file in the root with the JAX checkout path, and then have the runners read this file to find a non-default location. Not really fun to implement, given that the use case has unclear (at least to me) utility for the team (I've never needed it), but doable, of course.
On the other hand, you can transfer your work to the default `../jax` checkout via a working branch with `(cd ./jax && git checkout my_working_branch)`. Would that work for you? It would spare the effort of writing and testing the code for a non-default jax location...
> 2. ... 4.
Now you just run `python3 stack.py develop --xla-dir=/your/path` and then go directly to step 4.
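The shortened flow is just two commands. A minimal sketch of them as a Python helper, assuming it is run from the root of a `rocm-jax` checkout; `dev_build_commands` and `run` are hypothetical names, and `make -C jax_rocm_plugin` is used as the equivalent of `cd jax_rocm_plugin && make clean dist`. By default it only prints the commands.

```python
import subprocess
from typing import List


def dev_build_commands(xla_dir: str) -> List[List[str]]:
    """Commands for a development build against a local XLA checkout (hypothetical helper)."""
    return [
        # Regenerate jax_rocm_plugin/Makefile so Bazel uses the local XLA checkout.
        ["python3", "stack.py", "develop", f"--xla-dir={xla_dir}"],
        # Same as: cd jax_rocm_plugin && make clean dist
        ["make", "-C", "jax_rocm_plugin", "clean", "dist"],
    ]


def run(commands: List[List[str]], dry_run: bool = True) -> None:
    """Print each command; execute it only when dry_run is False."""
    for cmd in commands:
        print(" ".join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)  # stop at the first failing step
```

Calling `run(dev_build_commands("/your/path"))` prints the two steps; passing `dry_run=False` executes them in order.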
OK, I might be wrong that jax is only used for tests. The plugin build system references some things in `@jax//jaxlib`, for example, so it's required to build the plugin.
Another issue with overriding the default repo checkout is the set of patches that we have to apply to the checkout. These patches aren't present in the vanilla rocm/jax, which is basically a fork of the relevant upstream with changes in tests only.
Based on that, to save effort, I would prefer not to override the jax checkout location at all. @i-chaochen, can you transfer your changes to the default checkout using a working branch when you need it?
cc @charleshofer (I see that the PR is already approved without any discussion, so to prevent a premature merge, I'll request changes.)
> Based on that, to save effort, I would prefer not to override the jax checkout location at all. @i-chaochen, can you transfer your changes to the default checkout using a working branch when you need it?

I'm not sure I fully get it, so just to be clear about what's being asked: for example, someone wants to make changes in both jax and xla for ROCm (make `jax.lax.abs()` always return 0 and pass some info to XLA HLO), so they need to build jax/xla together, based on the pinned jax commit?
And is the work based on the upstream (Google) jax and rocm/xla-0.6?
> For example, I want to make changes in both the jax and xla code bases (`jax.lax.abs()` to always return 0 and pass some info to XLA HLO), so I need to build jax/xla together.
The best way to do this is to:
1. Check out this `rocm/rocm-jax` meta-repo.
2. Set up the current container with `./tools/docker_dev_setup.sh` (options apply).
3. Run `python3 stack.py develop` (possibly with the `--jaxlib` argument that I proposed in a different comment).
4. Switch `./jax` (checkout of `rocm/jax`) and `./xla` (checkout of `rocm/xla`) to the branches you need.
5. Run `pip install -e ./jax` to install JAX in editable mode.
6. Modify whatever you want in `./jax` and `./xla`.
7. Build all of JAX from scratch with `(cd ./jax_rocm_plugin && make clean dist install)`.
8. Test your changes; rinse and repeat from step 6.
I don't know exactly how and when the upstream jax repo is used, but for most if not all dev use cases the above process will work.
The question to you was: instead of overriding the default checkout path `./jax`, can you transfer your changes from your other copy of the jax repo to the default `./jax` checkout (made at step 3) using git? Is it clear now?
It's easy to override the XLA checkout path, but it might not be easy to override the jax checkout path due to the patches that must be applied and the test runners.
@i-chaochen Right! Also, doing so would tie the compiler version to the ROCm version used, which might cause issues when compiling for older ROCm versions. It's much safer to always know upfront that the same compiler is used in all cases.
rocm/jax is only used for test cases (I'm hoping we can phase this out and just pull from jax-ml/jax) and for staging PRs from AMD developers into jax-ml/jax. Mostly, it stores test cases that we skip because we have yet to fix the underlying bug; we'll eventually unskip each of those test cases when the bug gets fixed.
jax-ml/jax is pulled in via Bazel (https://github.com/ROCm/rocm-jax/blob/master/jax_rocm_plugin/third_party/jax/workspace.bzl#L12) and is only used to build the kernels wheel (AKA jax_rocm7_plugin) (https://github.com/ROCm/rocm-jax/blob/master/jax_rocm_plugin/jaxlib_ext/tools/BUILD.bazel#L26). Bazel applies a handful of patches to the kernel code when it pulls jax-ml/jax (https://github.com/ROCm/rocm-jax/blob/master/jax_rocm_plugin/third_party/jax/workspace.bzl#L14). That kernel code is mostly shared with Nvidia, changes to it are few and far between, and those changes almost always make their way into jax-ml/jax at some point, at which time we can remove the patch file. So, using patch files here works just fine.
That said, there is a valid use case for wanting to modify the kernel code in jax-ml/jax and test it out. We just haven't done that a whole lot, so we don't have any docs for it yet. Here's how you would do it:
1. Clone rocm-jax and run `stack.py develop` to generate your Makefile and get `./jax`.
2. In the `jax_rocm_plugin` rule in your Makefile, add `--bazel_options="--override_repository=jax=<path to your JAX repo from step 1>" \`
3. `cd ./jax`
4. Make your changes to the wheel kernel code (this code is almost all in `jaxlib/gpu` and `jaxlib/rocm`).
5. `cd ../jax_rocm_plugin`
6. `make clean dist`
7. `pip install wheelhouse/jax_rocm7_plugin_`
8. Test out your changes in the Python interpreter or with unit tests (`cd ../ && pytest jax/tests/`).
9. Repeat steps 3-8 as much as you need. You do not need to run `python stack.py develop` again unless you want to reset your Makefile back to the default.
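Step 2 adds one `--bazel_options` line per overridden repository, in the same form as the existing XLA override. A minimal sketch that produces those Makefile continuation lines from a name-to-path mapping; `override_lines` is a hypothetical helper, and resolving each path with `os.path.abspath` is a defensive assumption so that the line stays valid regardless of the directory `make` is run from.

```python
import os
from typing import Dict, List


def override_lines(checkouts: Dict[str, str]) -> List[str]:
    """One Makefile continuation line per Bazel repository override (hypothetical helper).

    `checkouts` maps a Bazel repository name (e.g. "jax", "xla") to a local checkout path.
    """
    lines = []
    for repo, path in sorted(checkouts.items()):
        abs_path = os.path.abspath(os.path.expanduser(path))  # avoid cwd-relative paths
        lines.append(f'--bazel_options="--override_repository={repo}={abs_path}" \\')
    return lines


for line in override_lines({"jax": "/src/jax", "xla": "/src/xla"}):
    print(line)
```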
Another process that doesn't have documentation (also because we just haven't done it much) is how to install locally built jax and jaxlib wheels alongside the plugin wheels. This is where you'd modify `jax.lax.abs` to do weird stuff with the HLO. The steps are basically what Aleksei already outlined. The only addition I'd make is that to install jaxlib, you'd have to do something like `python build/build.py build --wheels=jaxlib` from inside the `./jax` directory. Realistically, we only do this for debugging, though, because problems with the jax wheel will usually impact Nvidia and TPU the same as us, and jaxlib (with the exception of loading up the plugins) is used for running JAX operators on CPU instead of GPUs/TPUs.
I'm going to try and respond to all the questions.
> @i-chaochen Thanks! Yes, this is clear to me! I hope these instructions can go into README.md. This will be super helpful for all!
Yes. It's pretty clear to me that we need better documentation on the build and dev process for JAX. I think having instructions like "How do I do X?" for modifying the jax and jaxlib wheels, building pjrt and plugin wheels with local XLA and jax repos would be helpful.
> @i-chaochen so just to make sure I got the full picture: if I want to make changes on upstream jax/xla for ROCm, we have no way to do it? Because rocm-jax is pinning jax 0.6?
No. You'd get a local clone of the jax repo (feel free to use the one in rocm-jax by adding a new remote and pulling the branch you need), make your changes, and then override where Bazel pulls from. Setting `--bazel_options="--override_repository=..."` in the Makefile will always override the JAX and XLA pins in `jax_rocm_plugin/third_party`.
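Reusing the `./jax` checkout with a branch from another clone is a standard git-remote workflow. A minimal sketch that assembles the commands; the remote name `mywork`, the branch name, and the helper name are placeholders, not anything defined by rocm-jax.

```python
from typing import List


def import_branch_commands(other_clone: str, branch: str,
                           checkout: str = "./jax",
                           remote: str = "mywork") -> List[List[str]]:
    """git commands that bring `branch` from another clone into `checkout` (hypothetical helper)."""
    git = ["git", "-C", checkout]  # run every command inside the default checkout
    return [
        git + ["remote", "add", remote, other_clone],            # register the other clone
        git + ["fetch", remote, branch],                          # fetch just that branch
        git + ["checkout", "-b", branch, f"{remote}/{branch}"],   # local branch starting from it
    ]


for cmd in import_branch_commands("/home/me/jax", "my_working_branch"):
    print(" ".join(cmd))
```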
> @Arech8 Yes, you can use whatever XLA (including upstream) you want to build JAX on. Note that in many cases this would require rebuilding jaxlib too; that's why I asked @charleshofer to add that option.

> @Arech8 there was some narrow case where upstream JAX is always used instead of the local checkout for some part of the build, but I don't understand exactly when and how this happens. @charleshofer, can you please clarify if you know?
You will get your local JAX for building the plugin wheel (jax_rocm7_plugin) if you override `--bazel_options="--override_repository=..."`. You can set this to whatever path you want. You always get your local jax from `./jax` when you run tests with the test scripts.
@charleshofer thanks a ton for the great explanation, I really wish this was saved somewhere!
Regarding jaxlib: can you please look at my other comment, where I proposed generating an updated Makefile that builds jaxlib too, should a user want it? There's some additional rationale for that, and it's very simple to implement. Perhaps you could consider it?
@i-chaochen and @Arech8 Let me know if I missed any questions or if the build process still isn't clear. It's a long thread and I probably missed some.
):
    """Clone jax repo for jax test case source code"""

    if not os.path.exists("./jax"):
I guess we need an option to specify the jax directory as well?
JAX_REPL_URL = "https://github.com/rocm/jax"
XLA_REPL_URL = "https://github.com/rocm/xla"

DEFAULT_XLA_DIR = "../xla"
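One way such directory options could be exposed is with `argparse`, falling back to the default constant shown above. This is a sketch, not the actual `stack.py` code: the `develop` subcommand wiring is illustrative, `--xla-dir` matches the flag discussed in this thread, and `--jax-dir` with a `../jax` default is an assumption mirroring the XLA default.

```python
import argparse

DEFAULT_XLA_DIR = "../xla"
DEFAULT_JAX_DIR = "../jax"  # assumption: mirrors the XLA default


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(prog="stack.py")
    sub = parser.add_subparsers(dest="command", required=True)
    develop = sub.add_parser("develop", help="set up a development environment")
    develop.add_argument("--xla-dir", default=DEFAULT_XLA_DIR,
                         help="local XLA checkout to use instead of the pinned one")
    develop.add_argument("--jax-dir", default=DEFAULT_JAX_DIR,
                         help="local JAX checkout to use instead of the pinned one")
    return parser


args = build_parser().parse_args(["develop", "--xla-dir=/src/xla"])
print(args.xla_dir, args.jax_dir)
```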
Can you please implement these two important things: refactor --bazel_options value to a single variable, and add an option to build jaxlib?
--bazel_options value problem
There's a very frequent need to modify the options (for example, to build everything in debug mode), and making the same change in 2 places (potentially 3, see below) is silly. So instead of:
...
jax_rocm_plugin:
python3 ./build/build.py build \
...
--bazel_options="--override_repository=xla=%(xla_dir)s" \
...
jax_rocm_pjrt:
python3 ./build/build.py build \
...
--bazel_options="--override_repository=xla=%(xla_dir)s" \
...
...
there should be
...
BAZEL_OPTIONS="--override_repository=xla=%(xla_dir)s"
...
jax_rocm_plugin:
python3 ./build/build.py build \
...
--bazel_options=${BAZEL_OPTIONS} \
...
jax_rocm_pjrt:
python3 ./build/build.py build \
...
--bazel_options=${BAZEL_OPTIONS} \
...
...
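The refactor can be sketched from the generator's side, assuming (as the `%(xla_dir)s` placeholders suggest) that `stack.py` renders the Makefile with Python `%`-formatting. The template is trimmed to the lines under discussion; the real rules carry more `build.py` arguments that are elided above. Recipe lines start with tabs because Make requires them, and the override now appears exactly once.

```python
# Trimmed, hypothetical template: only the BAZEL_OPTIONS wiring is shown.
MAKEFILE_TEMPLATE = """\
BAZEL_OPTIONS="--override_repository=xla=%(xla_dir)s"

jax_rocm_plugin:
\tpython3 ./build/build.py build \\
\t\t--bazel_options=${BAZEL_OPTIONS}

jax_rocm_pjrt:
\tpython3 ./build/build.py build \\
\t\t--bazel_options=${BAZEL_OPTIONS}
"""


def render_makefile(xla_dir: str) -> str:
    # %-formatting with a dict fills every %(xla_dir)s placeholder.
    return MAKEFILE_TEMPLATE % {"xla_dir": xla_dir}


print(render_makefile("/src/xla"))
```

Adding a debug flag then means editing the single `BAZEL_OPTIONS` line instead of every rule.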
Sometimes there's a need to build jaxlib
There are several use cases for building jaxlib too:
1. "Namespace patch": do I get it right that we still cannot use the upstream `jaxlib` as it is, and the patch must always be applied to it? If the answer is yes, the problem is that during development, packages can be installed and removed easily and sometimes unintentionally (by installing a package that depends on an incompatible version). If the patched `jaxlib` is removed but then installed back, the reinstalled version will always be the original non-patched version, causing weird failures. Hence, it's just way simpler and more robust to be able to easily install a properly built `jaxlib` back again with the same `make install` that installing the pjrt and plugin wheels requires.
2. Sometimes one needs to work with a different JAX/jaxlib version than the repo originally intends. For example, to bump the JAX version one might want to start with the current one, or check some patch, or something else. So, building `jaxlib` from source might be a better and simpler way to ensure a coherent setup than messing with the upstream.
3. One might need to debug a weird issue where a potentially incompatible `jaxlib` is a suspect.
So, even if (1) isn't a concern (i.e. if we could always use the upstream jaxlib without any modifications), (2) and (3) still stand.
It would be nice to also build jaxlib if a binary `--jaxlib` flag is passed:
- A binary `--jaxlib` CLI flag triggers building jaxlib.
- Add a `--jax-dir` CLI argument to override the default `../jax` path, so @i-chaochen could use his own checkout of JAX. If I'm not mistaken, we don't need `rocm/jax` for anything except running tests, so there's likely no reason to use this flag outside the `--jaxlib` context. However, overriding the `jax` checkout location would also require modifying the test runners, which are actually in `rocm/rocm-jax` and have the `../jax` path hardcoded... (Anyway, I'm personally fine with `jax` being in its default location, but I really need the `jaxlib` compilation feature.)
- If `args.jaxlib` is true, add and modify the following in the `Makefile`:
dist: jax_rocm_plugin jax_rocm_pjrt jaxlib
jaxlib:
(cd %(jax_dir)s && python3 ./build/build.py build \
...
--wheels=jaxlib \
--bazel_options=${BAZEL_OPTIONS} \
...)
clean:
rm -rf dist ; rm -rf %(jax_dir)s/dist
install: dist
pip install --force-reinstall dist/* && (cd %(jax_dir)s && pip install --force-reinstall dist/*)
(regarding install rule - for some reason pip didn't work correctly for me, when I invoked it with a ../jax/dist/* argument, but the workaround with cd worked)
The bazel options thing makes sense. I will add that.
Regarding (1), yes. This will be the case until the jax/jaxlib 0.7.1 release.
(2) and (3) make sense to me, too. I can modify the Makefile to include a jaxlib target. I'd prefer it not be included in the dist target, though. There are use cases for building our own jaxlib for debugging and when jaxlib is a suspect, but I don't want us to get in the habit of depending on it.
I'd also prefer that we always include it in the Makefile rather than making it optional with a `--jaxlib` flag. You can build it by running `make jaxlib`. Maybe we'd give it its own clean and install targets, too. Bloating stack.py with options is piling build config on top of config. The plugin build scripts are already about 5 layers deep and in need of a refactor. Not being able to say "no" to every individual request for just one more config option is part of why we're here.
Is that agreeable?
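A sketch of that arrangement, assuming the same `%`-formatted Makefile generation: the jaxlib rules are always emitted but nothing depends on them, so `make dist` is unchanged and jaxlib is built only via `make jaxlib`. The target names `jaxlib_clean` and `jaxlib_install` are hypothetical, and the rules reuse the `BAZEL_OPTIONS` variable and `dist/` output location from the proposal above.

```python
# Hypothetical jaxlib rules; %(jax_dir)s is filled in by stack.py.
JAXLIB_TARGETS = """\
# Built only on request (`make jaxlib`); deliberately not a prerequisite of dist.
jaxlib:
\t(cd %(jax_dir)s && python3 ./build/build.py build --wheels=jaxlib \\
\t\t--bazel_options=${BAZEL_OPTIONS})

jaxlib_clean:
\trm -rf %(jax_dir)s/dist

jaxlib_install: jaxlib
\t(cd %(jax_dir)s && pip install --force-reinstall dist/*)
"""


def render_jaxlib_targets(jax_dir: str) -> str:
    return JAXLIB_TARGETS % {"jax_dir": jax_dir}


print(render_jaxlib_targets("../jax"))
```

The install rule keeps the `cd` workaround from the proposal rather than passing `../jax/dist/*` to pip directly.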
Hmm, OK, if it's always in the Makefile, just not part of the dist rule or any other rule used by default in a "normal" dev process, this is fine too and makes sense to me. I totally agree on that.
A note regarding "5 layers deep": while this is technically true, for me, as a build system user (note the BDD wording 😁), there are no 5 layers. I only interface with stack.py and don't want to look any deeper (but I had to, because I sometimes need to compile jaxlib). That was just context, though; I'm totally fine with the solution you proposed, thank you.
Ugh, something I hadn't thought about until just now. If you're building inside a container (and you 100% should be), using your own JAX and your own XLA inside the container is going to get hairy if you ran `stack.py develop` outside of the container and pointed it at directories on your host. Maybe we could mount those? But then we'd have to remap those paths to the mounted paths. The container mounts everything in rocm-jax, and that works, no sweat.
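The remapping step could look like the sketch below, assuming each host checkout is bind-mounted somewhere in the container. The mount points in the usage example (`/rocm-jax`, `/mnt/xla`) are made-up values, `to_container_path` is a hypothetical helper, and both host and container are assumed to use POSIX paths. The longest matching host prefix wins, so nested mounts resolve to the most specific mount.

```python
import os
from typing import Dict


def to_container_path(host_path: str, mounts: Dict[str, str]) -> str:
    """Translate a host path into the path it has inside the container (hypothetical helper).

    `mounts` maps host directories to their mount points in the container.
    """
    host_path = os.path.abspath(host_path)
    # Check the most specific (longest) host directory first.
    for host_dir in sorted(mounts, key=lambda d: len(os.path.abspath(d)), reverse=True):
        prefix = os.path.abspath(host_dir)
        if host_path == prefix or host_path.startswith(prefix + os.sep):
            rel = os.path.relpath(host_path, prefix)
            return mounts[host_dir] if rel == "." else os.path.join(mounts[host_dir], rel)
    raise ValueError(f"{host_path} is not under any mounted directory")


mounts = {"/home/me/rocm-jax": "/rocm-jax", "/home/me/src/xla": "/mnt/xla"}
print(to_container_path("/home/me/src/xla/xla/service", mounts))
```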
The wrinkle with keeping it all on rocm-jax/jax is that you might want different JAX repos for building jax/jaxlib, building the plugin kernels wheel, and running the tests.
@alekstheod how many times per day do you usually have to build jax when triaging bugs?
Depends. If I have a bug that's jax-related, then often. But any wrapper still just calls bazel with different flags, and that's it. It's easier to just add these flags to a config in .bazelrc to simplify the command. I still don't see a reason for a wrapper script.
No, it's not easier, by far. That already assumes a fixed working machine, which isn't a correct assumption. Moving even to a different checkout would require changing this over and over again, making iteration times unnecessarily long. And then there's remembering and entering the full bazel command-line arguments. What for, when you can just run `make clean dist install` and have JAX up and running after that?
If you don't have a need for something, it doesn't mean the same for others.
You don't need to enter the full bazel command if you put the args in your config; you just need to add this config to your build command. I'm just saying that whatever we do with an external script, we could do in Bazel itself.
For install you can do:
`bazel run --config=rocm_clang_official --config=devenv //jax:jaxlib_install`
if you have your flags inside the `devenv` config and a `jaxlib_install` target. This will build and install the jaxlib package on your machine.
You could also do
`bazel run --config=rocm_clang_official --config=devenv //jax:setup_devenv`
to install whatever packages are required.
This guarantees that you install what you built. An external script doesn't guarantee that, so you might install things built from the previous checkout. Basically, it's orthogonal to the Bazel philosophy...
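For comparison, the `.bazelrc` approach comes down to lines of the form `build:<config> <flag>`, which `--config=<config>` then pulls in (and `bazel run` inherits `build` options). A minimal sketch that renders such a `devenv` fragment with repository overrides, assuming the overrides are the only per-machine settings; it does not create the `//jax:jaxlib_install` or `//jax:setup_devenv` targets mentioned above, which would still have to be written.

```python
import os
from typing import Dict


def devenv_bazelrc(overrides: Dict[str, str], config: str = "devenv") -> str:
    """Render a .bazelrc fragment defining --config=<config> (hypothetical helper)."""
    lines = [
        f"build:{config} --override_repository={repo}={os.path.abspath(path)}"
        for repo, path in sorted(overrides.items())
    ]
    return "\n".join(lines) + "\n"


print(devenv_bazelrc({"xla": "/src/xla", "jax": "/src/jax"}), end="")
```

Note that the paths are baked into the file, which is the per-machine, per-checkout editing the reply above objects to.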
Excuse me, are you trying to make the point that doing all of this, including modifying .bazelrc, is simpler than just calling make?
Updated the PR. There are now options for using your own JAX and XLA. I also added make targets to make it easier to build the jaxlib wheel.
@charleshofer did you push changes?
Branch updated from cdaac33 to 74746d9.
Branch updated from 74746d9 to ea872ed.
Pushed now.
Adds a few options to the stack.py script that allow developers to use local copies of XLA and JAX for building wheels. This also adds targets in the Makefile for building jaxlib.
Branch updated from ea872ed to 991dfc9.
Since #101 is merged, there's one more question I want to clarify. I'm aware that you pinned an XLA commit for the jax/xla build; if I use my local XLA repo, I guess I no longer need to change this pinned commit?
@i-chaochen
It's helpful for XLA developers to be able to keep their local XLA in a directory other than this repository. Getting the JAX development build to work with a different XLA path would normally require editing the Makefile. This process is a little cumbersome, and it's easier to set this through an option on the `stack.py` script. This change adds that option. It also adds an argument for building the wheels with a local copy of JAX, similar to the new XLA option, and some make targets for building jaxlib.