Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions cmd/gopjrt_installer/cuda.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,20 @@ func CudaInstallNvidiaLibraries(plugin, version, installPath string) error {
if err := os.Symlink(ptxasPath, ptxasLinkPath); err != nil {
return errors.Wrapf(err, "failed to create symbolic link to ptxas in %s", ptxasLinkPath)
}

// Symlink the cuBLAS libraries into the "lib" subdirectory of installPath, since Nvidia is not able to find them from the configured SDK path.
switch plugin {
case "cuda13":
libsPath := filepath.Join(installPath, "lib")
libCublasPath := "./gomlx/nvidia/cu13/lib"
for _, srcName := range []string{"libcublasLt.so.13", "libcublas.so.13"} {
dstPath := filepath.Join(libsPath, filepath.Base(srcName))
srcPath := filepath.Join(libCublasPath, srcName)
if err := os.Symlink(srcPath, dstPath); err != nil {
return errors.Wrapf(err, "failed to create symbolic link to %s in %s", srcPath, dstPath)
}
}
}
return nil
}

Expand Down
6 changes: 6 additions & 0 deletions docs/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Gopjrt Changelog

# Next

- Package `cmd/gopjrt_installer`:
- Link `libcublasLt.so.13` and `libcublas.so.13` to the `lib` subdirectory of the install directory given.
Nvidia needs them for some models, but cannot find them within the provided SDK path.

# v0.9.1 2025/11/07: More multi-device support; updated CPU PJRT; dropped static CPU PJRT linking.

- Package `pjrt`:
Expand Down
13 changes: 7 additions & 6 deletions pjrt/cuda_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,13 @@ func cudaNVidiaPath(plugin *Plugin) (nvidiaExpectedPath string, found bool) {
return nvidiaExpectedPath, err == nil && fi.IsDir()
}

// XlaFlagsEnv is the name of the environment variable to set XLA_FLAGS.
const XlaFlagsEnv = "XLA_FLAGS"

// cudaSetCUDADir adds the --xla_gpu_cuda_data_dir flag, pointing at nvidiaPath, to the XLA_FLAGS environment variable.
func cudaSetCUDADir(nvidiaPath string) {
existingXLAFlags := os.Getenv(XlaFlagsEnv)
const (
XLAFlagsEnv = "XLA_FLAGS"
)

existingXLAFlags := os.Getenv(XLAFlagsEnv)
var newValue string
if existingXLAFlags != "" && !strings.Contains(existingXLAFlags, "--xla_gpu_cuda_data_dir") {
newValue = fmt.Sprintf("%s --xla_gpu_cuda_data_dir=%s", existingXLAFlags, nvidiaPath)
Expand All @@ -122,9 +123,9 @@ func cudaSetCUDADir(nvidiaPath string) {
newValue = fmt.Sprintf("--xla_gpu_cuda_data_dir=%s", nvidiaPath)
}
if newValue != "" {
err := os.Setenv(XlaFlagsEnv, newValue)
err := os.Setenv(XLAFlagsEnv, newValue)
if err != nil {
klog.Warningf("Failed to set %q environment variable to %q: %v", XlaFlagsEnv, newValue, err)
klog.Warningf("Failed to set %q environment variable to %q: %v", XLAFlagsEnv, newValue, err)
}
}
}
Loading