Merged
5 changes: 3 additions & 2 deletions .github/workflows/go.yaml
@@ -42,14 +42,15 @@ jobs:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.24.x"
go-version: "1.25.x"

- name: Install Gopjrt C library gomlx_xlabuilder and PJRT plugin
shell: bash
run: |
curl -sSf https://raw.githubusercontent.com/gomlx/gopjrt/main/cmd/install_linux_amd64.sh | bash
(cd ./cmd/gopjrt_installer && go build -o /tmp/gopjrt_installer . && sudo /tmp/gopjrt_installer -plugin=linux -version=latest -path=/usr/local)
sudo ln -sf /usr/local/lib/libpjrt* /usr/lib/x86_64-linux-gnu/
sudo ln -sf /usr/local/include/gomlx /usr/include/

- name: PreTest
run: |
go test . -test.v
33 changes: 10 additions & 23 deletions README.md
@@ -9,7 +9,7 @@

## Why use GoPJRT ?

GoPJRT leverages [OpenXLA](https://openxla.org/) to compile, optimize and **accelerate numeric
GoPJRT leverages [OpenXLA](https://openxla.org/) to compile, optimize, and **accelerate numeric
computations** (with large data) from Go using various [backends supported by OpenXLA](https://opensource.googleblog.com/2024/03/pjrt-plugin-to-accelerate-machine-learning.html): CPU, GPUs (Nvidia, AMD ROCm*, Intel*, Apple Metal*) and TPU*.
It can be used to power Machine Learning frameworks (e.g. [GoMLX](https://github.com/gomlx/gomlx)), image processing, scientific
computation, game AIs, etc.
@@ -19,8 +19,6 @@ And because [Jax](https://docs.jax.dev/en/latest/), [TensorFlow](https://www.ten
and probably TensorFlow and PyTorch as well.
See [example 2 in xlabuilder/README.md](https://github.com/gomlx/gopjrt/blob/main/xlabuilder/README.md#example-2).

(*) Not tested or partially supported by the hardware vendor.

GoPJRT aims to be minimalist and robust: it provides well-maintained, extensible Go wrappers for
[OpenXLA PJRT](https://openxla.org/#pjrt).

@@ -61,7 +59,7 @@ development of **GoPJRT**, [github.com/gomlx/stablehlo](https://github.com/gomlx
> Small ones are debuggable, or can be used to probe which operations are being used behind the scenes,
> but definitely not friendly.

A "PJRT Plugin" is a dynamically linked library (`.so` file in Linux or `.dylib` in Darwin).
A "PJRT Plugin" is a dynamically linked library (`.so` file in Linux, or optionally `.dylib` in Darwin, or `.dll` in Windows).
Typically, there is one plugin per hardware you are supporting. E.g.: there are PJRT plugins
for CPU (Linux/amd64 for now, but likely it could be compiled for other CPUs -- SIMD/AVX are well-supported),
for TPUs (Google's accelerator),
@@ -111,7 +109,7 @@ The `pjrt` package includes the following main concepts:
methods to transfer it to/from the host memory. They are the inputs and outputs of `LoadedExecutable.Execute`.

PJRT plugins by default are loaded after the program is started (using `dlopen`).
But there is also the option to pre-link the CPU PJRT plugin in your program.
But there is also the option to pre-link the CPU PJRT plugin in your program (for now, this option only works on Linux/amd64).
For that, import (as `_`) one of the following packages:

- `github.com/gomlx/gopjrt/pjrt/cpu/static`: pre-link the CPU PJRT statically, so you don't need to distribute
@@ -131,31 +129,22 @@ It's been compiled for Macs before—I don't have easy access to an Apple Mac to
## Installing

GoPJRT requires a C library installed for XlaBuilder and one or more "PJRT plugin" modules (the thing that actually does the JIT compilation
of your computation graph). To facilitate, it provides an interactive and self-explanatory installer (it comes with lots of help messages):
of your computation graph). To facilitate, it provides an interactive and self-explanatory installer:

```bash
go run github.com/gomlx/gopjrt/cmd/gopjt_installer
go run github.com/gomlx/gopjrt/cmd/gopjrt_installer@latest
```

You can also pass the flags directly to skip the interactive mode (so it can be used in scripts like Dockerfiles).

> [!NOTE]
> For now it only works for Linux/amd64 (or Windows+WSL) and Nvidia CUDA.
> I managed to write for Darwin (macOS) before, but not having easy access to a Mac to maintain it, eventually I dropped it.
> I would also love to support AMD ROCm, but again, I don't have easy access to hardwre to test/maintain it.
> For now it works for (1) CPU PJRT on linux/amd64 (or Windows+WSL); (2) Nvidia CUDA PJRT on Linux/amd64; (3) CPU PJRT on Darwin (macOS).
> I would love to support AMD ROCm, Apple Metal (GPU), Intel, and others, but I don't have easy access to hardware to test/maintain them.
> If you feel like contributing or donating hardware/cloud credits, please contact me.

There are also some older bash install scripts under [`github.com/gomlx/gopjrt/cmd`](https://github.com/gomlx/gopjrt/tree/main/cmd),
but they are deprecated and eventually will be removed in a few versions. Let me know if you need them.

## Building C/C++ dependencies

If you want to build from scratch (both `xlabuilder` and `pjrt` dependencies), go to the `c/` subdirectory
and run `bazel.sh`.
It uses [Bazel](https://bazel.build/) because of its dependency on OpenXLA/XLA.
If you are not on one of the supported platforms, you will need to create a `xla_configure.OS_ARCH.bazelrc`
file.

## PJRT Plugins for other devices or platforms.

See [docs/devel.md](https://github.com/gomlx/gopjrt/blob/main/docs/devel.md#pjrt-plugins) for hints on how to compile a plugin
@@ -165,8 +154,8 @@ Also, see [this blog post](https://opensource.googleblog.com/2024/03/pjrt-plugin

## FAQ

* **When is feature X from PJRT or XlaBuilder going to be supported ?**
Yes, GoPJRT doesn't wrap everything—although it does cover the most common operations.
* **When is feature X from PJRT going to be supported ?**
GoPJRT doesn't wrap everything—although it does cover the most common operations.
The simple ops and structs are auto-generated. But many require hand-writing.
Please, if it is useful to your project, create an issue; I'm happy to add it. I focused on the needs of GoMLX,
but the idea is that it can serve other purposes, and I'm happy to support it.
@@ -212,15 +201,13 @@ Environment variables that help control or debug how GoPJRT works:

## Running Tests

All tests support the following build tags to pre-link the CPU plugin (as opposed to `dlopen` the plugin) -- select at most one of them:
All tests support (on Linux) the following build tags to pre-link the CPU plugin (as opposed to `dlopen` the plugin) -- select at most one of them:

* `--tags pjrt_cpu_static`: link (preload) the CPU PJRT plugin from the static library (`.a`) version.
Slowest to build (but runs at the same speed).
* `--tags pjrt_cpu_dynamic`: link (preload) the CPU PJRT plugin from the dynamic library (`.so`) version.
Faster to build, but deployments must ship the `libpjrt_c_api_cpu_dynamic.so` file along with the binary.

For Darwin (macOS), for the time being it is hardcoded with static linking, so avoid using these tags.

## Acknowledgements
This project uses the following components from the [OpenXLA project](https://openxla.org/):

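The README note above lists three supported combinations: CPU PJRT on linux/amd64 (including Windows+WSL), Nvidia CUDA PJRT on linux/amd64, and CPU PJRT on Darwin. Here is a minimal sketch of that matrix in Go. It assumes the installer's plugin names: `linux`, as passed by the CI workflow, plus `cuda13`, `cuda12` and `darwin`, registered in the new `init()` functions. `supportedPlugins` is a hypothetical helper, not part of GoPJRT. A Go binary running under WSL is a Linux binary, so `runtime.GOOS` is `"linux"` there, which is why WSL falls under the Linux case.

```go
package main

import (
	"fmt"
	"runtime"
)

// supportedPlugins is a hypothetical helper (not part of GoPJRT) listing the
// installer plugin names usable on a given GOOS/GOARCH, following the README note.
// Under Windows+WSL, Go programs are built for GOOS == "linux".
func supportedPlugins(goos, goarch string) []string {
	switch {
	case goos == "linux" && goarch == "amd64":
		// CPU PJRT plus the CUDA PJRT variants (cuda12 is marked deprecated).
		return []string{"linux", "cuda13", "cuda12"}
	case goos == "darwin" && goarch == "arm64":
		// CPU PJRT only; darwin.go is built only for darwin/arm64.
		return []string{"darwin"}
	default:
		// No prebuilt PJRT plugin offered by the installer.
		return nil
	}
}

func main() {
	fmt.Printf("%s/%s: %v\n", runtime.GOOS, runtime.GOARCH,
		supportedPlugins(runtime.GOOS, runtime.GOARCH))
}
```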
62 changes: 42 additions & 20 deletions cmd/gopjrt_installer/cuda.go
@@ -1,3 +1,5 @@
//go:build (linux && amd64) || all

package main

import (
@@ -13,6 +15,18 @@ import (
"github.com/pkg/errors"
)

func init() {
for _, plugin := range []string{"cuda13", "cuda12"} {
pluginInstallers[plugin] = CudaInstall
pluginValidators[plugin] = CudaValidateVersion
}
pluginValues = append(pluginValues, "cuda13", "cuda12")
pluginDescriptions = append(pluginDescriptions,
"CUDA PJRT (for Linux/amd64, using CUDA 13)",
"CUDA PJRT (for Linux/amd64, using CUDA 12, deprecated)")
pluginPriorities = append(pluginPriorities, 10, 11)
}

var pipPackageLinuxAMD64 = regexp.MustCompile(`-manylinux.*x86_64`)

// CudaInstall installs the cuda PJRT from the Jax PIP packages, using pypi.org distributed files.
@@ -21,32 +35,41 @@ var pipPackageLinuxAMD64 = regexp.MustCompile(`-manylinux.*x86_64`)
// - Version exists
// - Author email is from the Jax team
// - Downloaded files sha256 match the ones on pypi.org
func CudaInstall() error {
func CudaInstall(plugin, version, installPath string) error {
// Create the target directory.
installPath := ReplaceTildeInDir(*flagPath)
installPath = ReplaceTildeInDir(installPath)
if err := os.MkdirAll(installPath, 0755); err != nil {
return errors.Wrapf(err, "failed to create install directory in %s", installPath)
}

version, err := CudaInstallPJRT(installPath)
var err error
version, err = CudaInstallPJRT(plugin, version, installPath)
if err != nil {
return err
}

err = CudaInstallNvidiaLibraries(*flagPlugin, version, installPath)
if err != nil {
if err := CudaInstallNvidiaLibraries(plugin, version, installPath); err != nil {
return err
}

cudaVersion := "13"
if *flagPlugin == "cuda12" {
if plugin == "cuda12" {
cudaVersion = "12"
}
fmt.Printf("\n✅ Installed \"cuda\" PJRT and Nvidia libraries based on Jax version %s and CUDA version %s\n\n", version, cudaVersion)
return nil
}

func CudaInstallPJRT(installPath string) (version string, err error) {
// CudaInstallPJRT installs the cuda PJRT from the Jax PIP packages, using pypi.org distributed files.
//
// Checks performed:
// - Version exists
// - Author email is from the Jax team
// - Downloaded files sha256 match the ones on pypi.org
//
// Returns the version that was installed -- it can be different if the requested version was "latest", in which case it
// is translated to the actual version.
func CudaInstallPJRT(plugin, version, installPath string) (string, error) {
// Make the directory that will hold the PJRT files.
pjrtDir := filepath.Join(installPath, "/lib/gomlx/pjrt")
pjrtOutputPath := path.Join(pjrtDir, "pjrt_c_api_cuda_plugin.so")
@@ -55,16 +78,15 @@ func CudaInstallPJRT(installPath string) (version string, err error) {
}

// Get CUDA PJRT wheel from pypi.org
info, packageName, err := CudaGetPJRTPipInfo(*flagPlugin)
info, packageName, err := CudaGetPJRTPipInfo(plugin)
if err != nil {
return "", errors.WithMessagef(err, "can't fetch pypi.org information for %s", *flagPlugin)
return "", errors.WithMessagef(err, "can't fetch pypi.org information for %s", plugin)
}
if info.Info.AuthorEmail != "[email protected]" {
return "", errors.Errorf("package %s is not from Jax team, something is very suspicious!?", packageName)
}

// Translate "latest" to the actual version if needed.
version = *flagVersion
if version == "latest" {
version = info.Info.Version
}
@@ -74,12 +96,12 @@ func CudaInstallPJRT(installPath string) (version string, err error) {
versions := slices.Collect(maps.Keys(info.Releases))
slices.Sort(versions)
return "", errors.Errorf("version %q not found for %q (from pip package %q) -- latest is %q and existing versions are: %s",
*flagVersion, *flagPlugin, packageName, info.Info.Version, strings.Join(versions, ", "))
version, plugin, packageName, info.Info.Version, strings.Join(versions, ", "))
}

releaseInfo, err := PipSelectRelease(releaseInfos, pipPackageLinuxAMD64)
if err != nil {
return "", errors.Wrapf(err, "failed to find release for %s, version %s", *flagPlugin, *flagVersion)
return "", errors.Wrapf(err, "failed to find release for %s, version %s", plugin, version)
}
if releaseInfo.PackageType != "bdist_wheel" {
return "", errors.Errorf("release %s is not a \"binary wheel\" type", releaseInfo.Filename)
@@ -97,30 +119,30 @@ func CudaInstallPJRT(installPath string) (version string, err error) {
if err != nil {
return "", errors.Wrapf(err, "failed to extract CUDA PJRT file from %q wheel", packageName)
}
fmt.Printf("- Installed %s %s to %s\n", *flagPlugin, version, pjrtOutputPath)
fmt.Printf("- Installed %s %s to %s\n", plugin, version, pjrtOutputPath)
return version, nil
}

// CudaValidateVersion checks whether the cuda version selected by "-version" exists.
func CudaValidateVersion() error {
func CudaValidateVersion(plugin, version string) error {
// "latest" is always valid.
if *flagVersion == "latest" {
if version == "latest" {
return nil
}

info, packageName, err := CudaGetPJRTPipInfo(*flagPlugin)
info, packageName, err := CudaGetPJRTPipInfo(plugin)
if err != nil {
return errors.WithMessagef(err, "can't fetch pypi.org information for %s", *flagPlugin)
return errors.WithMessagef(err, "can't fetch pypi.org information for %s", plugin)
}
if info.Info.AuthorEmail != "[email protected]" {
return errors.Errorf("package %s is not from Jax team, something is very suspicious!?", packageName)
}

if _, ok := info.Releases[*flagVersion]; !ok {
if _, ok := info.Releases[version]; !ok {
versions := slices.Collect(maps.Keys(info.Releases))
slices.Sort(versions)
return errors.Errorf("version %s not found for %s (from pip package %q) -- existing versions: %s",
*flagVersion, *flagPlugin, packageName, strings.Join(versions, ", "))
version, plugin, packageName, strings.Join(versions, ", "))
}

// Version found.
@@ -130,7 +152,7 @@ func CudaValidateVersion() error {
// CudaGetPJRTPipInfo returns the JSON info for the PIP package that corresponds to the plugin.
func CudaGetPJRTPipInfo(plugin string) (*PipPackageInfo, string, error) {
var packageName string
switch *flagPlugin {
switch plugin {
case "cuda12":
packageName = "jax-cuda12-pjrt"
case "cuda13":
119 changes: 119 additions & 0 deletions cmd/gopjrt_installer/darwin.go
@@ -0,0 +1,119 @@
//go:build (darwin && arm64) || all

package main

import (
"fmt"
"os"
"path/filepath"
"runtime"
"strings"

"github.com/pkg/errors"
)

func init() {
for _, plugin := range []string{"darwin"} {
pluginInstallers[plugin] = DarwinInstall
pluginValidators[plugin] = DarwinValidateVersion
}
pluginValues = append(pluginValues, "darwin")
pluginDescriptions = append(pluginDescriptions, "CPU PJRT (darwin/arm64)")
pluginPriorities = append(pluginPriorities, 3)
installPathSuggestions = append(installPathSuggestions, "/usr/local/", "~/Library/Application Support/GoMLX")
}

// DarwinValidateVersion checks whether the Gopjrt release selected by version exists and provides the Darwin PJRT asset.
func DarwinValidateVersion(plugin, version string) error {
// "latest" is always valid.
if version == "latest" {
return nil
}

_, err := DarwinGetDownloadURL(plugin, version)
if err != nil {
return errors.WithMessagef(err, "can't fetch PJRT plugin from Gopjrt version %q, see "+
"https://github.com/gomlx/gopjrt/releases for a list of release versions to choose from", version)
}
return err
}

// DarwinGetDownloadURL returns the download URL for the given version and plugin.
func DarwinGetDownloadURL(plugin, version string) (url string, err error) {
var assets []string
assets, err = GitHubDownloadReleaseAssets(version)
if err != nil {
return
}

var wantAsset string
switch plugin {
case "darwin":
wantAsset = "gopjrt_cpu_darwin_arm64.tar.gz"
default:
err = errors.Errorf("version validation not implemented for plugin %q in version %s", plugin, version)
return
}
for _, assetURL := range assets {
if strings.HasSuffix(assetURL, wantAsset) {
return assetURL, nil
}
}
return "", errors.Errorf("plugin %q version %q doesn't seem to have the required asset (%q) -- assets found: %v", plugin, version, wantAsset, assets)
}

// DarwinInstall installs the plugin assets into the target directory.
func DarwinInstall(plugin, version, installPath string) error {
var err error
if version == "latest" || version == "" {
version, err = GitHubGetLatestVersion()
if err != nil {
return err
}
}
assetURL, err := DarwinGetDownloadURL(plugin, version)
if err != nil {
return err
}
assetName := filepath.Base(assetURL)

// Create the target directory.
installPath = ReplaceTildeInDir(installPath)
if strings.Contains(installPath, "Application Support/GoMLX") {
// In the user's Application Support directory, the subdirectory name is uppercase ("PJRT").
installPath = filepath.Join(installPath, "PJRT")
} else {
// E.g.: installPath = "/usr/local" -> installPath = "/usr/local/lib/gomlx/pjrt"
installPath = filepath.Join(installPath, "lib", "gomlx", "pjrt")
}
if err := os.MkdirAll(installPath, 0755); err != nil {
return errors.Wrap(err, "failed to create install directory")
}

// Download the asset to a temporary file.
sha256hash := "" // TODO: no hash for github releases. Is there a way to get them (or get a hardcoded table for all versions?)
downloadedFile, inCache, err := DownloadURLToTemp(assetURL, fmt.Sprintf("%s_%s", version, assetName), sha256hash)
if err != nil {
return err
}
if !inCache {
defer func() { ReportError(os.Remove(downloadedFile)) }()
}

// Extract files
fmt.Printf("- Extracting files in %s to %s\n", downloadedFile, installPath)
extractedFiles, err := Untar(downloadedFile, installPath)
if err != nil {
return err
}
if len(extractedFiles) == 0 {
return errors.Errorf("failed to extract files from %s", downloadedFile)
}
fmt.Printf("- Extracted %d file(s):\n", len(extractedFiles))
for _, file := range extractedFiles {
fmt.Printf(" - %s\n", file)
}

fmt.Printf("\n✅ Installed Gopjrt %s \"cpu\" PJRT to %s (%s/%s)\n\n", version, installPath, runtime.GOOS, runtime.GOARCH)
return nil
}
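`DarwinInstall` chooses its target directory and download asset with two small rules. A path under `Application Support/GoMLX` gets a `PJRT` subdirectory, and any other prefix gets `lib/gomlx/pjrt`. The asset is the first release URL ending in `gopjrt_cpu_darwin_arm64.tar.gz`. Below is a minimal sketch of both rules as pure functions. It assumes any leading `~` has already been expanded, presumably what `ReplaceTildeInDir` handles in the real code. `pjrtDir`, `pickAsset` and the `example.com` URLs are hypothetical.

```go
package main

import (
	"fmt"
	"path/filepath"
	"strings"
)

// pjrtDir mirrors the directory rule in DarwinInstall. It assumes installPath
// has already had any leading "~" expanded.
func pjrtDir(installPath string) string {
	if strings.Contains(installPath, "Application Support/GoMLX") {
		// Per-user location: plugins go into a "PJRT" subdirectory.
		return filepath.Join(installPath, "PJRT")
	}
	// System-style prefix, e.g. /usr/local -> /usr/local/lib/gomlx/pjrt.
	return filepath.Join(installPath, "lib", "gomlx", "pjrt")
}

// pickAsset mirrors DarwinGetDownloadURL: return the first asset URL that ends
// with the wanted file name, or an error listing what was found.
func pickAsset(assets []string, want string) (string, error) {
	for _, assetURL := range assets {
		if strings.HasSuffix(assetURL, want) {
			return assetURL, nil
		}
	}
	return "", fmt.Errorf("asset %q not found, assets: %v", want, assets)
}

func main() {
	fmt.Println(pjrtDir("/usr/local/"))
	fmt.Println(pjrtDir("/Users/me/Library/Application Support/GoMLX"))
	// Hypothetical URLs; real ones come from the GitHub release listing.
	assets := []string{
		"https://example.com/release/other_asset.tar.gz",
		"https://example.com/release/gopjrt_cpu_darwin_arm64.tar.gz",
	}
	fmt.Println(pickAsset(assets, "gopjrt_cpu_darwin_arm64.tar.gz"))
}
```

`filepath.Join` also cleans the result, so a trailing slash in the suggested `/usr/local/` prefix does not produce a double separator.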