A port of 3D Gaussian Splatting to JAX. Fully differentiable and CUDA-accelerated.
Requires a working CUDA toolchain to install.
Installing directly from source with pip should build and install jaxsplat:
$ python -m venv venv && . venv/bin/activate
$ pip install git+https://github.com/yklcs/jaxsplat

The primary function of this library is jaxsplat.render:
img = jaxsplat.render(
means3d, # jax.Array (N, 3)
scales, # jax.Array (N, 3)
quats, # jax.Array (N, 4) normalized
colors, # jax.Array (N, 3)
opacities, # jax.Array (N, 1)
viewmat=viewmat, # jax.Array (4, 4)
background=background, # jax.Array (3,)
img_shape=img_shape, # tuple[int, int] = (H, W)
f=f, # tuple[float, float] = (fx, fy)
c=c, # tuple[int, int] = (cx, cy)
glob_scale=glob_scale, # float
clip_thresh=clip_thresh, # float
block_size=block_size, # int <= 16
)

The rendered output is differentiable with respect to means3d, scales, quats, colors, and opacities.
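Because the output is differentiable in those five arrays, fitting Gaussians to an image reduces to ordinary gradient descent. The following is a minimal sketch, not the approach used in /examples: init_gaussians, mse_loss, and sgd_step are hypothetical helpers, the initial value ranges are arbitrary, and render_fn is assumed to be jaxsplat.render with its remaining keyword arguments bound via functools.partial (shown in the trailing comment with placeholder values). Quaternions are re-normalized inside the loss because render expects normalized quats and plain gradient steps do not preserve unit length.

```python
import jax
import jax.numpy as jnp


def init_gaussians(key, n):
    """Hypothetical helper: random parameters with the shapes listed for jaxsplat.render."""
    k_means, k_scales, k_quats, k_colors = jax.random.split(key, 4)
    quats = jax.random.normal(k_quats, (n, 4))
    return {
        "means3d": jax.random.uniform(k_means, (n, 3), minval=-1.0, maxval=1.0),  # (N, 3)
        "scales": jax.random.uniform(k_scales, (n, 3), minval=0.01, maxval=0.05),  # (N, 3)
        "quats": quats / jnp.linalg.norm(quats, axis=-1, keepdims=True),  # (N, 4), unit norm
        "colors": jax.random.uniform(k_colors, (n, 3)),  # (N, 3), RGB in [0, 1)
        "opacities": jnp.full((n, 1), 0.5),  # (N, 1)
    }


def mse_loss(params, target, render_fn):
    """Mean squared error between the rendered image and a target image."""
    # Re-normalize quats: unconstrained gradient updates do not keep them unit length.
    quats = params["quats"] / jnp.linalg.norm(params["quats"], axis=-1, keepdims=True)
    img = render_fn(params["means3d"], params["scales"], quats,
                    params["colors"], params["opacities"])
    return jnp.mean((img - target) ** 2)


def sgd_step(params, target, render_fn, lr=1e-2):
    """One plain gradient-descent step; gradients are taken w.r.t. params only (argnums=0)."""
    loss, grads = jax.value_and_grad(mse_loss)(params, target, render_fn)
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss, grads


# With jaxsplat installed, render_fn would bind the non-differentiable arguments, e.g.
#   render_fn = functools.partial(
#       jaxsplat.render, viewmat=viewmat, background=background,
#       img_shape=(H, W), f=(fx, fy), c=(cx, cy),
#       glob_scale=1.0, clip_thresh=0.01, block_size=16)  # placeholder values
```

Passing render_fn as an ordinary argument keeps the helpers renderer-agnostic, so they can be checked with a stand-in function. In practice you would probably close over render_fn and wrap the step in jax.jit, or swap plain SGD for an optimizer library such as Optax.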
Alternatively, jaxsplat.project projects 3D Gaussians to 2D, and jaxsplat.rasterize sorts and rasterizes the projected 2D Gaussians.
jaxsplat.render calls jaxsplat.project and then jaxsplat.rasterize under the hood.
See /examples for examples, which can be run as follows:
$ python -m venv venv && . venv/bin/activate
$ pip install -r examples/requirements.txt
# Train Gaussians on a single image
$ python -m examples.single_image input.png

We use modified versions of gsplat's kernels. The original INRIA implementation uses a custom license and contains dynamically shaped tensors, which are harder to port to JAX/XLA.
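As a rough illustration of why dynamic shapes are a problem (generic JAX behavior, not code from jaxsplat or gsplat): under jax.jit, every array shape must be known at trace time, so an operation whose output size depends on the data cannot be compiled as-is. The usual workarounds keep a fixed shape, either by masking or by padding to a static maximum size. The function names below are hypothetical.

```python
import jax
import jax.numpy as jnp


def keep_positive_dynamic(x):
    # Boolean indexing: the output length depends on the *values* in x.
    # This works eagerly, but jax.jit rejects it because shapes must be static.
    return x[x > 0]


@jax.jit
def keep_positive_masked(x):
    # Static alternative 1: keep shape (N,) and zero out the rejected entries.
    return jnp.where(x > 0, x, 0.0)


@jax.jit
def positive_indices_padded(x):
    # Static alternative 2: a fixed-size index array of length N, padded with -1.
    (idx,) = jnp.nonzero(x > 0, size=x.shape[0], fill_value=-1)
    return idx
```

For x = jnp.array([0.5, -1.0, 2.0, -0.3]), calling jax.jit(keep_positive_dynamic)(x) raises an error at trace time, while both jitted alternatives return length-4 arrays: the masked values [0.5, 0.0, 2.0, 0.0] and the padded indices [0, 2, -1, -1].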