gopjrt (Installing)
gopjrt leverages OpenXLA to compile, optimize and accelerate numeric
computations (with large data) from Go using various backends supported by OpenXLA: CPU, GPUs (NVidia, Intel*, Apple Metal*) and TPU*.
It can be used to power Machine Learning frameworks (e.g. GoMLX), image processing, scientific
computation, game AIs, etc.
And because Jax, TensorFlow and optionally PyTorch run on XLA,
it is possible to run Jax functions in Go with gopjrt, and probably TensorFlow and PyTorch as well.
See example 2 below.
(*) Not tested yet. Please let me know if it works for you, or if you can lend access to such hardware (e.g. a virtual machine) for a while: I would love to verify that it works there.
gopjrt aims to be minimalist and robust: it provides well maintained, extensible Go wrappers for
OpenXLA PJRT and OpenXLA XlaBuilder libraries.
It is not very ergonomic (error handling everywhere), and the expectation is that others will create a
friendlier API on top of gopjrt -- the same way Jax is a friendlier API
on top of XLA/PJRT.
One such friendlier API is GoMLX, a Go machine learning framework, but gopjrt may be used as a standalone,
for lower level access to XLA and other accelerator use cases -- like running Jax functions in Go.
It provides 2 independent packages (often used together, but not necessarily):
The pjrt package loads PJRT plugins -- implementations of PJRT for specific hardware (CPU, GPUs, TPUs, etc.) in the form of dynamically linked libraries -- and provides an API to compile and execute "programs".
"Programs" for PJRT are specified as "StableHLO serialized proto-buffers" (HloModuleProto more specifically).
This is an intermediary representation (IR) not usually written directly by humans that can be output by,
for instance, a Jax/PyTorch/Tensorflow program, or using the xlabuilder package described below.
It includes the following main concepts:
- Client: first thing created after loading a plugin. It seems one can create a singleton Client per plugin; it's not very clear to me why one would create more than one Client.
- LoadedExecutable: created when one calls Client.Compile on an HLO program. It's the compiled/optimized/accelerated code ready to run.
- Buffer: represents a buffer with the input/output data for the computations in the accelerators. There are methods to transfer it to/from the host memory. Buffers are the inputs and outputs of LoadedExecutable.Execute.
While it uses CGO to dynamically load the plugin and call its C API, pjrt doesn't require anything other than the plugin
to be installed.
The project release includes 2 plugins, one for CPU (linux-x86) compiled from XLA source code, and one for GPUs provided in the Jax distributed binaries -- both for linux/x86-64 architecture (help with Mac wanted!). But there are instructions to build your own CPU plugin (e.g.: for a different architecture), or GPU (XLA seems to have code to support ROCm, but I'm not sure of the status). And it should work with binary plugins provided by others -- see plugins references in PJRT blog post.
The xlabuilder package provides a Go API for building accelerated computations using the XLA operations.
The output of building the computation using xlabuilder is a StableHLO(-ish)
program that can be directly used with PJRT (and the pjrt package above).
Again it aims to be minimalist, robust and well maintained, albeit not necessarily very ergonomic.
Main concepts:
- XlaBuilder: builder object, used to keep track of the operations being added.
- XlaComputation: created with XlaBuilder.Build(...) and represents the finished program, ready to be used by PJRT (or saved to disk). It is also used to represent sub-routines/functions -- see XlaBuilder.CreateSubBuilder and the Call method.
- Literal: represents constants in the program. It has some similarities with a pjrt.Buffer, but Literal is only used during the creation of the program. It is usually better to avoid large constants in a program and instead feed them as pjrt.Buffer inputs to the program during its execution.
See examples below.
The xlabuilder package includes a separate C project that generates a libgomlx_xlabuilder.so dynamic library
(~13Mb for linux/x86-64) and associated *.h files, that need to be installed. A tar.gz is included in the release
for linux/x86-64 architecture (help for Macs wanted!).
But one can also build it from scratch for different platforms -- it uses Bazel due to its dependencies on OpenXLA/XLA.
Notice that there are alternatives to using XlaBuilder:
- JAX/TensorFlow can output the HLO of JIT compiled functions, that can be fed directly to PJRT (see example 2).
- Use GoMLX.
- One can use XlaBuilder during development, and then save the output (see XlaComputation.SerializedHLO). Then, in production, only the pjrt package is needed to execute it.
Example 1: f(x) = x*x+1 with xlabuilder and pjrt
- This is a trivial example. XLA/PJRT really shines when doing large number crunching tasks.
- The package github.com/janpfeifer/must simply converts errors to panics.
builder := xlabuilder.New("x*x+1")
x := must.M1(xlabuilder.Parameter(builder, "x", 0, xlabuilder.MakeShape(dtypes.F32))) // Scalar float32.
fX := must.M1(xlabuilder.Mul(x, x))
one := must.M1(xlabuilder.ScalarOne(builder, dtypes.Float32))
fX = must.M1(xlabuilder.Add(fX, one))
// Get computation created.
comp := must.M1(builder.Build(fX))
//fmt.Printf("HloModule proto:\n%s\n\n", comp.TextHLO())
// PJRT plugin and create a client.
plugin := must.M1(pjrt.GetPlugin(*flagPluginName))
fmt.Printf("Loaded %s\n", plugin)
client := must.M1(plugin.NewClient(nil))
// Compile program.
loadedExec := must.M1(client.Compile().WithComputation(comp).Done())
fmt.Printf("Compiled program: name=%s, #outputs=%d\n", loadedExec.Name, loadedExec.NumOutputs)
// Test values:
inputs := []float32{0.1, 1, 3, 4, 5}
fmt.Printf("f(x) = x^2 + 1:\n")
for _, input := range inputs {
	inputBuffer := must.M1(pjrt.ScalarToBuffer(client, input))
	outputBuffers := must.M1(loadedExec.Execute(inputBuffer).Done())
	output := must.M1(pjrt.BufferToScalar[float32](outputBuffers[0]))
	fmt.Printf("\tf(x=%g) = %g\n", input, output)
}
// Destroy the client and leave.
must.M1(client.Destroy())
Example 2: running a Jax function in Go
First we create the HLO program in Jax/Python (see Jax documentation).
(You can do this with Google's Colab without having to install anything)
import os
import jax
def f(x):
    return x*x+1
# Note: jax.xla_computation may be deprecated or unavailable in newer JAX releases.
comp = jax.xla_computation(f)(3.)
print(comp.as_hlo_text())
hlo_proto = comp.as_hlo_module()
with open('hlo.pb', 'wb') as file:
    file.write(hlo_proto.as_serialized_hlo_module_proto())
Then download the hlo.pb file and do:
- (The package github.com/janpfeifer/must simply converts errors to panics.)
hloBlob := must.M1(os.ReadFile("hlo.pb"))
// PJRT plugin and create a client.
plugin := must.M1(pjrt.GetPlugin(*flagPluginName))
fmt.Printf("Loaded %s\n", plugin)
client := must.M1(plugin.NewClient(nil))
loadedExec := must.M1(client.Compile().WithHLO(hloBlob).Done())
// Test values:
inputs := []float32{0.1, 1, 3, 4, 5}
fmt.Printf("f(x) = x^2 + 1:\n")
for _, input := range inputs {
	inputBuffer := must.M1(pjrt.ScalarToBuffer(client, input))
	outputBuffers := must.M1(loadedExec.Execute(inputBuffer).Done())
	output := must.M1(pjrt.BufferToScalar[float32](outputBuffers[0]))
	fmt.Printf("\tf(x=%g) = %g\n", input, output)
}
Example 3: Mandelbrot Set Notebook
The notebook includes both the "regular" Go implementation and the corresponding implementation using XlaBuilder
and execution with PJRT for comparison, with some benchmarks.
gopjrt requires a C library installed and a plugin module. For Linux (*), run the following script to install under /usr/local/{lib,include}:
curl -sSf https://raw.githubusercontent.com/gomlx/gopjrt/main/cmd/install.sh | bash
For CUDA (NVidia GPU) support, also run:
curl -sSf https://raw.githubusercontent.com/gomlx/gopjrt/main/cmd/install_cuda.sh | bash
That's it. The next sections explain things in more detail, for those interested in special cases.
(*) It would be awesome if someone could build a mac/arm64 version.
The two scripts cmd/install.sh and cmd/install_cuda.sh can be configured to install in an arbitrary
directory (by setting GOPJRT_INSTALL_DIR) and not to use sudo (by setting GOPJRT_NOSUDO). You may need
to adjust LD_LIBRARY_PATH if the installation directory is not standard, and set PJRT_PLUGIN_LIBRARY_PATH
to tell gopjrt where to find the plugins.
There are two parts that need installing: (1) the XLA Builder library (it's a C++ wrapper); (2) PJRT plugins for the accelerator devices you want to support.
The releases come with a prebuilt (1) XLA Builder library for linux/amd64 and (2) the PJRT plugin for CPU,
again only for linux/amd64. One can download it from the latest release in GitHub, or use the
cmd/install.sh script, which does exactly that.
This is the C wrapper (for the C++ XlaBuilder library) needed by Go.
For linux/amd64 you can download the file gomlx_xlabuilder-linux-amd64.tar.gz
and "untar" it to /usr/local -- or some area of your system that is visible to the C compiler (includes) and the loader (either in LD_LIBRARY_PATH
or /etc/ld.so.conf).
You can use the tool under cmd/install.sh (it also installs the PJRT CPU plugin).
Or manually, with something like the following in Linux:
cd /usr/local
gopjrt_release_download_url="$(curl -s -L -I 'https://github.com/gomlx/gopjrt/releases/latest' | egrep -i '^location: ' | awk '{print $2}' | sed 's|/tag/|/download/| ; s/\r$//')"
echo "Downloading GOPJRT/XlaBuilder library from ${gopjrt_release_download_url}/gomlx_xlabuilder-linux-amd64.tar.gz"
sudo printf "\tsudo authorized\n"
curl -L "${gopjrt_release_download_url}/gomlx_xlabuilder-linux-amd64.tar.gz" | sudo tar xzv
For other base systems, you can build it from scratch (see the github.com/gomlx/gopjrt/c directory):
if things work, it's generally straightforward, and on a modern computer it takes only a few minutes.
But for different platforms, XLA can be tricky to configure.
The recommended location for plugins is /usr/local/lib/gomlx/pjrt, but the pjrt package
will automatically search in all standard library locations (configured in /etc/ld.so.conf).
Alternatively, one can set the directory(ies) to search for plugins by setting the environment variable
PJRT_PLUGIN_LIBRARY_PATH.
The release comes with a CPU plugin pre-compiled for the linux/x86-64 platform. The file is called
pjrt_c_api_cpu_plugin.so.gz. Please uncompress the file and move it to your plugin directory -- e.g.:
/usr/local/lib/gomlx/pjrt.
You can use the tool under cmd/install.sh. Or manually, with something like the following in Linux:
gopjrt_release_download_url="$(curl -s -L -I 'https://github.com/gomlx/gopjrt/releases/latest' | egrep -i '^location: ' | awk '{print $2}' | sed 's|/tag/|/download/| ; s/\r$//')"
echo "Downloading PJRT CPU plugin from ${gopjrt_release_download_url}/pjrt_c_api_cpu_plugin.so.gz"
sudo printf "\tsudo authorized\n"
sudo mkdir -p /usr/local/lib/gomlx/pjrt
cd /usr/local/lib/gomlx/pjrt
curl -L "${gopjrt_release_download_url}/pjrt_c_api_cpu_plugin.so.gz" | gunzip | sudo bash -c 'cat > pjrt_c_api_cpu_plugin.so'
NVidia licenses are complicated (I don't fully understand them), so I hesitate to provide a prebuilt plugin and dependencies.
But there is a simple way to achieve it, by linking the files from a Jax installation.
And a script to facilitate it in cmd/install_cuda.sh.
Manually, you should do the following:
Create and activate a virtual environment (venv) for Python. Probably a Conda environment would also work.
Then install Jax for Cuda and its dependencies:
pip install -U "jax[cuda12]"
Now we want to link 2 things: (1) the CUDA PJRT plugin; (2) the various NVidia drivers.
Assuming the virtual environment (Python's venv) is active, $VIRTUAL_ENV should be pointing
to its installation. Check that this is the case by inspecting $VIRTUAL_ENV.
And then do (change the target directories to your preference):
sudo mkdir -p /usr/local/lib/gomlx/pjrt
sudo ln -sf ${VIRTUAL_ENV}/lib/python3.12/site-packages/jax_plugins/xla_cuda12/xla_cuda_plugin.so \
/usr/local/lib/gomlx/pjrt/pjrt_c_api_cuda_plugin.so
sudo ln -sf ${VIRTUAL_ENV}/lib/python3.12/site-packages/nvidia \
/usr/local/lib/gomlx/nvidia
Mac support should be doable in a similar way, but I don't own a Mac. Contributions would be most welcome.
See docs/devel.md for hints on how to compile a plugin from OpenXLA/XLA sources.
Also, see this blog post with the link and references to the Intel and Apple hardware plugins.
This is only required if the XlaBuilder library (xlabuilder package) is used.
The release comes with the XlaBuilder library pre-compiled for the linux/x86-64 platform. The file is called
gomlx_xlabuilder-linux-amd64.tar.gz and it includes two subdirectories, lib/ and include/, with the files
required to compile Go's xlabuilder package.
The suggested location is to "untar" (decompress and unpack) this file to /usr/local.
Change the path to the file in the command below:
cd /usr/local
sudo tar xzvf gomlx_xlabuilder-linux-amd64.tar.gz
Finally, you want to make sure that your environment variable LD_LIBRARY_PATH includes /usr/local/lib.
Or that your system library paths in /etc/ld.so.conf include /usr/local/lib.
- When is feature X from PJRT or XlaBuilder going to be supported?
  Yes, gopjrt doesn't wrap everything -- although it does cover the most common operations. The simple ops and structs are auto-generated, but many require hand-writing. If a missing feature would be useful to your project, please create an issue; I'm happy to add it. I focused on the needs of GoMLX, but the idea is that it can serve other purposes, and I'm happy to support them.
- Why not split it into smaller packages?
  Because of golang/go#13467: C APIs cannot be exported across packages, even within the same repo. Even a function as simple as func Add(a, b C.int) C.int in one package cannot be called from another. So we need to wrap everything, and more than that, one cannot create separate sub-packages to handle separate concerns. This is also the reason the library chelper.go is copied in both the pjrt and xlabuilder packages.
- Why does PJRT spit out so much logging? Can we disable it?
  This is a great question ... imagine if every library we use decided they also want to clutter our stderr?
  I have an open question in Abseil about it.
  It may be some issue with Abseil Logging, which also has this other issue
  of not allowing two different linked programs/libraries to call its initialization (see Issue #1656).
  A hacky workaround is duplicating fd 2 and assigning it to Go's os.Stderr, and then closing fd 2, so PJRT plugins have nowhere to log. This hack is encoded in the function pjrt.SuppressAbseilLoggingHack(): just call it before calling pjrt.GetPlugin. But it may have unintended consequences if some other library depends on fd 2 to work, or if a real exceptional situation needs to be reported and is not.
- Google Drive Directory with Design Docs: Some links are outdated or redirected, but very valuable information.
- How to use the PJRT C API? #openxla/xla/issues/7038: discussion of folks trying to use PJRT in their projects. Some examples leveraging some of the XLA C++ library.
- How to use PJRT C API v.2 #openxla/xla/issues/7038.
- PJRT C API README.md: a collection of links to other documents.
- Public Design Document.
- Gemini helped quite a bit parsing/understanding things -- despite the hallucinations -- other AIs may help as well.
This project utilizes the following components from the OpenXLA project:
- This project includes a (slightly modified) copy of OpenXLA's pjrt_c_api.h file.
- OpenXLA PJRT CPU Plugin: This plugin enables execution of XLA computations on the CPU.
- OpenXLA PJRT CUDA Plugin: This plugin enables execution of XLA computations on NVIDIA GPUs.
We gratefully acknowledge the OpenXLA team for their valuable work in developing and maintaining these plugins.
gopjrt is licensed under the Apache 2.0 license.
The OpenXLA project, including the pjrt_c_api.h file and the CPU and CUDA plugins, is licensed under the Apache 2.0 license.
The CUDA plugin also utilizes the NVIDIA CUDA Toolkit, which is subject to NVIDIA's licensing terms and must be installed by the user.
For more information about OpenXLA, please visit their website at openxla.org, or the github page at github.com/openxla/xla