NotImplementedError: Could not run 'aten::index.Tensor' with arguments from the 'SparseCUDA' backend. #152226

Open
ringohoffman opened this issue Apr 25, 2025 · 3 comments
Labels
module: sparse Related to torch.sparse triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments


ringohoffman commented Apr 25, 2025

🚀 The feature, motivation and pitch

I want to make vectorized selections on a sparse tensor, but `aten::index.Tensor` is not implemented for the SparseCUDA backend.

import torch

device = torch.device("cuda:0")
indices = torch.tensor(
    [
        [0, 1, 2, 3],
        [1, 2, 3, 4],
        [2, 3, 4, 5]
    ],
    device=device,
)
values = torch.tensor(
    [10.0, 20.0, 30.0, 40.0],
    device=device,
)
size = (4, 5, 6)

sparse_tensor = torch.sparse_coo_tensor(indices, values, size)

query_indices = torch.tensor([
    [0, 1],
    [1, 2],
    [2, 3],
], device=device)

dense_tensor = sparse_tensor.to_dense()
result = dense_tensor[tuple(query_indices)]
# tensor([10., 20.], device='cuda:0')

sparse_result = sparse_tensor[tuple(query_indices)]
# NotImplementedError: Could not run 'aten::index.Tensor' with arguments from the 'SparseCUDA' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::index.Tensor' is only available for these backends: [CPU, CUDA, HIP, MPS, IPU, XPU, HPU, VE, MTIA, PrivateUse1, PrivateUse2, PrivateUse3, Meta, FPGA, MAIA, Vulkan, Metal, QuantizedCPU, QuantizedCUDA, QuantizedHIP, QuantizedMPS, QuantizedIPU, QuantizedXPU, QuantizedHPU, QuantizedVE, QuantizedMTIA, QuantizedPrivateUse1, QuantizedPrivateUse2, QuantizedPrivateUse3, QuantizedMeta, CustomRNGKeyId, MkldnnCPU, SparseCsrCPU, SparseCsrCUDA, SparseCsrHIP, SparseCsrMPS, SparseCsrIPU, SparseCsrXPU, SparseCsrHPU, SparseCsrVE, SparseCsrMTIA, SparseCsrPrivateUse1, SparseCsrPrivateUse2, SparseCsrPrivateUse3, SparseCsrMeta, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradHIP, AutogradXLA, AutogradMPS, AutogradIPU, AutogradXPU, AutogradHPU, AutogradVE, AutogradLazy, AutogradMTIA, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, AutogradMeta, AutogradNestedTensor, Tracer, AutocastCPU, AutocastXPU, AutocastMPS, AutocastCUDA, FuncTorchBatched, BatchedNestedTensor, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PreDispatch, PythonDispatcher].
$ python -m torch.utils.collect_env
Collecting environment information...
PyTorch version: 2.6.0+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.22.1
Libc version: glibc-2.35

Python version: 3.10.16 (main, Dec 11 2024, 16:24:50) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-136-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.8.61
CUDA_MODULE_LOADING set to: LAZY

Alternatives

No response

Additional context

No response

cc @alexsamardzic @nikitaved @pearu @cpuhrsch @amjames @bhosmer @jcaip

@ringohoffman (Contributor, Author) commented:

@pytorchbot label "module: sparse"

@pytorch-bot bot added the module: sparse label Apr 25, 2025
@malfet added the triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) label Apr 25, 2025
@ringohoffman (Contributor, Author) commented:

This is an implementation I came up with:

def sparse_index(sparse_tensor: torch.Tensor, query_indices: torch.Tensor) -> torch.Tensor:
    """Return the values of the sparse tensor at the given indices.

    Args:
        sparse_tensor: A sparse tensor with `ndim` dimensions.
        query_indices: The indices of the sparse tensor to query, of shape (`ndim`, `num_queries`).

    Returns:
        The values of the sparse tensor at the given indices of shape (`num_queries`,).
    """
    assert sparse_tensor.is_sparse

    sparse_tensor = sparse_tensor.coalesce()

    ndim, num_query_indices = query_indices.shape
    strides = torch.empty((ndim, 1), dtype=torch.long, device=query_indices.device)
    strides[-1] = 1
    for i in range(ndim - 2, -1, -1):
        strides[i] = strides[i + 1] * sparse_tensor.shape[i + 1]

    linear_nonzero_indices = (sparse_tensor.indices() * strides).sum(dim=0)
    linear_query_indices = (query_indices * strides).sum(dim=0)

    # the indices of coalesced sparse tensors are sorted in lexicographical order: https://pytorch.org/docs/stable/sparse.html#uncoalesced-sparse-coo-tensors
    # so we can perform binary search on it without needing to sort it
    insertion_indices = torch.searchsorted(linear_nonzero_indices, linear_query_indices).clamp_max(linear_nonzero_indices.numel() - 1)
    is_nonzero_indices = linear_nonzero_indices.index_select(dim=0, index=insertion_indices) == linear_query_indices

    query_values = torch.empty(num_query_indices, dtype=sparse_tensor.dtype, device=linear_query_indices.device)
    query_values[~is_nonzero_indices] = 0
    query_values[is_nonzero_indices] = sparse_tensor.values()[insertion_indices[is_nonzero_indices]]
    return query_values
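
As a quick sanity check, the helper can be exercised on the tensors from the repro above, querying two stored entries plus one unstored (zero) entry. This sketch runs on CPU so it works without a GPU (the same calls work on CUDA), and it restates the helper in condensed form purely so the snippet is self-contained; the logic is the same linearize-then-`searchsorted` approach as the comment above:

```python
import torch

def sparse_index(sparse_tensor: torch.Tensor, query_indices: torch.Tensor) -> torch.Tensor:
    # Condensed restatement of the helper above: linearize the coalesced
    # (lexicographically sorted) indices, then binary-search the queries.
    sparse_tensor = sparse_tensor.coalesce()
    ndim, num_query_indices = query_indices.shape
    strides = torch.empty((ndim, 1), dtype=torch.long, device=query_indices.device)
    strides[-1] = 1
    for i in range(ndim - 2, -1, -1):
        strides[i] = strides[i + 1] * sparse_tensor.shape[i + 1]
    linear_nonzero = (sparse_tensor.indices() * strides).sum(dim=0)
    linear_query = (query_indices * strides).sum(dim=0)
    insertion = torch.searchsorted(linear_nonzero, linear_query).clamp_max(linear_nonzero.numel() - 1)
    hit = linear_nonzero.index_select(0, insertion) == linear_query
    out = torch.zeros(num_query_indices, dtype=sparse_tensor.dtype, device=linear_query.device)
    out[hit] = sparse_tensor.values()[insertion[hit]]
    return out

# The sparse tensor from the repro (on CPU here)
indices = torch.tensor([
    [0, 1, 2, 3],
    [1, 2, 3, 4],
    [2, 3, 4, 5],
])
values = torch.tensor([10.0, 20.0, 30.0, 40.0])
sparse_tensor = torch.sparse_coo_tensor(indices, values, (4, 5, 6))

# Query (0, 1, 2) and (1, 2, 3), which are stored, and (0, 0, 0), which is not
query_indices = torch.tensor([
    [0, 1, 0],
    [1, 2, 0],
    [2, 3, 0],
])

result = sparse_index(sparse_tensor, query_indices)
print(result.tolist())  # [10.0, 20.0, 0.0]
```

The unstored entry comes back as an explicit zero, matching what `to_dense()` indexing would return, without ever materializing the dense tensor.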
