Add LibTorch Stable ABI infrastructure (#1946) #1956
This commit implements the complete infrastructure for migrating apex to the LibTorch Stable ABI, enabling extensions to work across PyTorch versions without recompilation.

**csrc/stable_abi_utils.h** (NEW)
- Custom MemoryFormat contiguity checking workaround
- Implements is_contiguous() for ChannelsLast/ChannelsLast3d layouts
- Addresses stable ABI limitation: Tensor::is_contiguous(MemoryFormat) not supported
- Error checking macros (STD_TORCH_CHECK, etc.)
- Boxed calling convention helpers for IValue stack manipulation
- Type conversion utilities (scalar_type_name, etc.)
- Device and CUDA stream management utilities
- Common tensor validation functions

**csrc/type_shim.h** (MODIFIED)
- Added dual-build support via TORCH_STABLE_ONLY conditional compilation
- Created apex_internal namespace for cross-compatible types
- Updated all type dispatch macros (DISPATCH_FLOAT_AND_HALF, etc.)
- Replaced AT_ERROR with APEX_ERROR macro supporting both modes

**csrc/multi_tensor_apply.cuh** (MODIFIED)
- Updated to support both stable and traditional Tensor types
- Created apex_tensor namespace with type aliases
- Added is_contiguous_any_format() using custom MemoryFormat workaround
- Conditional CUDA stream/device guard management
- Updated function signatures to use apex_tensor::Tensor

**setup.py** (MODIFIED)
- Added USE_STABLE_ABI flag detection from TORCH_STABLE_ONLY environment variable
- Created prepare_stable_abi_sources() to substitute .cpp → _stable.cpp
- Created add_stable_abi_compile_args() to inject -DTORCH_STABLE_ONLY flag
- Added StableCUDAExtension() and StableCppExtension() wrapper functions
- Updated ALL 35+ extension definitions to use stable wrappers

Traditional build (default):
```bash
python setup.py install
```

Stable ABI build:
```bash
TORCH_STABLE_ONLY=1 python setup.py install
```

- Stable ABI's Tensor::is_contiguous() doesn't support MemoryFormat parameter
- Solution: Custom implementation in stable_abi_utils.h checks ChannelsLast/ChannelsLast3d
- Used in multi_tensor_apply.cuh via is_contiguous_any_format() helper

- 35+ extension .cpp files need conversion to _stable.cpp versions
- Each requires manual PYBIND11 → boxed calling convention conversion
- Conversion pattern documented in issue #1946

- Issue: #1946
- Stable ABI docs: https://docs.pytorch.org/docs/stable/notes/libtorch_stable_abi.html
- Flash-attention example: Dao-AILab/flash-attention@b3846b0
**csrc/multi_tensor_apply.cuh**

```diff
@@ -1,14 +1,49 @@
+#ifdef TORCH_STABLE_ONLY
+#include <torch/csrc/stable/tensor.h>
+#include <torch/csrc/stable/accelerator.h>
+#include <torch/headeronly/types.h>
+#include "stable_abi_utils.h"
+#else
 #include <ATen/ATen.h>
 #include <ATen/AccumulateType.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/cuda/Exceptions.h>
-#include <assert.h>
 #include <c10/cuda/CUDAGuard.h>
+#endif
+
+#include <assert.h>
 
 // #include <iostream>
 
 // This header is the one-stop shop for all your multi-tensor apply needs.
 
+// Namespace aliases for dual-build support
+#ifdef TORCH_STABLE_ONLY
+namespace apex_tensor {
+using Tensor = torch::stable::Tensor;
+using MemoryFormat = apex::stable::MemoryFormat;
+namespace device = torch::headeronly;
+
+inline bool is_contiguous_any_format(const Tensor& t) {
+  return apex::stable::is_contiguous(t, MemoryFormat::Contiguous) ||
+         apex::stable::is_contiguous(t, MemoryFormat::ChannelsLast) ||
+         apex::stable::is_contiguous(t, MemoryFormat::ChannelsLast3d);
+}
+}
+#else
+namespace apex_tensor {
+using Tensor = at::Tensor;
+using MemoryFormat = at::MemoryFormat;
+namespace device = at;
+
+inline bool is_contiguous_any_format(const Tensor& t) {
+  return t.is_contiguous() ||
+         t.is_contiguous(at::MemoryFormat::ChannelsLast) ||
+         t.is_contiguous(at::MemoryFormat::ChannelsLast3d);
+}
+}
+#endif
+
 // TODO: Kernel arg size limit may be <4KB for some other cards (ie Jetson)
 constexpr int depth_to_max_tensors[6] = {110, 64, 48, 36, 30, 24};
 constexpr int depth_to_max_blocks[6] = {320, 320, 320, 320, 320, 320};
@@ -30,21 +65,19 @@ __global__ void multi_tensor_apply_kernel(int64_t chunk_size, volatile int* noop
 }
 
 template <int depth, typename T, typename... ArgTypes>
-void multi_tensor_apply(int64_t block_size, int64_t chunk_size, const at::Tensor& noop_flag,
-                        const std::vector<std::vector<at::Tensor>>& tensor_lists, T callable, ArgTypes... args) {
+void multi_tensor_apply(int64_t block_size, int64_t chunk_size, const apex_tensor::Tensor& noop_flag,
+                        const std::vector<std::vector<apex_tensor::Tensor>>& tensor_lists, T callable, ArgTypes... args) {
   TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth");
   int len0 = tensor_lists[0].size();
   TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0");
   auto ref_device = tensor_lists[0][0].device();
-  TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda");
+  TORCH_CHECK(ref_device.type() == apex_tensor::device::kCUDA, "expected input to be on cuda");
   for (int l = 0; l < tensor_lists.size(); l++)  // No range-based for because I need indices
   {
     TORCH_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists");
     for (int t = 0; t < tensor_lists[l].size(); t++) {
       // TODO: Print which tensor fails.
-      bool contiguous_memory = tensor_lists[l][t].is_contiguous();
-      contiguous_memory = (contiguous_memory || tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast) ||
-                           tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast3d));
+      bool contiguous_memory = apex_tensor::is_contiguous_any_format(tensor_lists[l][t]);
       TORCH_CHECK(contiguous_memory, "A tensor was not contiguous.");
       TORCH_CHECK(tensor_lists[l][t].device() == ref_device, "A tensor was not on the same device as the first tensor");
       TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch");
@@ -55,8 +88,16 @@ void multi_tensor_apply(int64_t block_size, int64_t chunk_size, const at::Tensor
 
   TensorListMetadata<depth> tl;
 
+#ifdef TORCH_STABLE_ONLY
+  // Stable ABI: device guard and stream management
+  auto device = tensor_lists[0][0].device();
+  // TODO: stable ABI device guard - for now assume correct device context
+  cudaStream_t stream = nullptr;  // Use default stream for stable ABI
```
Suggested change:

```diff
-  cudaStream_t stream = nullptr;  // Use default stream for stable ABI
+  // TODO: stable ABI stream management - currently uses default stream.
+  // WARNING: This may cause incorrect synchronization if a non-default stream is active.
+  // If stable ABI provides a way to get the current stream, use it here.
+  cudaStream_t stream = nullptr;  // Currently uses default stream for stable ABI
```
**csrc/stable_abi_utils.h**

```diff
@@ -0,0 +1,270 @@
+#pragma once
+
+#ifdef TORCH_STABLE_ONLY
+
+// Stable ABI headers
+#include <torch/csrc/stable/tensor.h>
+#include <torch/csrc/stable/library.h>
+#include <torch/csrc/stable/ivalue.h>
+#include <torch/csrc/stable/accelerator.h>
+#include <torch/csrc/stable/dispatcher.h>
+#include <torch/headeronly/types.h>
+
+namespace apex {
+namespace stable {
+
+// ============================================================================
+// MemoryFormat Contiguity Checking Workaround
+// ============================================================================
+// The stable ABI's Tensor::is_contiguous() doesn't support MemoryFormat
+// parameter. This provides a workaround for checking different memory layouts.
+
+enum class MemoryFormat {
+  Contiguous,
+  ChannelsLast,
+  ChannelsLast3d,
+  Preserve
+};
+
+// Check if a tensor is contiguous in a specific memory format
+inline bool is_contiguous(const torch::stable::Tensor& tensor, MemoryFormat format) {
```
Collaborator
Given that torch::stable is still under development, wouldn't it be reasonable to wait until this is implemented upstream?
```diff
+  using namespace torch::stable;
+
+  // For standard contiguous check, use the stable ABI method
+  if (format == MemoryFormat::Contiguous) {
+    return tensor.is_contiguous();
+  }
+
+  // For ChannelsLast and ChannelsLast3d, we need custom logic
+  // Get tensor properties
+  auto sizes = tensor.sizes();
+  auto strides = tensor.strides();
+  int64_t ndim = tensor.dim();
+
+  if (format == MemoryFormat::ChannelsLast) {
+    // NCHW format requires ndim == 4
```
Suggested change:

```diff
-    // NCHW format requires ndim == 4
+    // ChannelsLast (NHWC) format requires ndim == 4
```
Copilot AI, Nov 24, 2025
Misleading comment. The comment says "NCDHW format requires ndim == 5" but this code is checking for ChannelsLast3d format, which is NDHWC, not NCDHW. NCDHW is the standard contiguous format for 5D tensors. The comment should say "ChannelsLast3d (NDHWC) format requires ndim == 5" for clarity.
Suggested change:

```diff
-    // NCDHW format requires ndim == 5
+    // ChannelsLast3d (NDHWC) format requires ndim == 5
```
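
The diff view cuts off before the bodies of the ChannelsLast and ChannelsLast3d branches, so the PR's actual logic is not visible here. As a rough illustration only, the sketch below shows one common way to test these layouts from sizes and strides alone: visit the dimensions from innermost to outermost in the layout's order (C, W, H, N for NHWC; C, W, H, D, N for NDHWC) and check that each stride equals the product of the sizes already visited. The helper names (`matches_dim_order`, `is_channels_last_2d`, `is_channels_last_3d`) and the use of `std::vector` in place of `torch::stable::Tensor` are assumptions made for this example, not the PR's API. Skipping size-1 dimensions is also a choice made here; whether the PR does the same is not shown.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper (not from the PR): returns true when `strides` describe
// a dense layout whose dimensions, listed innermost-first, follow `order`.
inline bool matches_dim_order(const std::vector<int64_t>& sizes,
                              const std::vector<int64_t>& strides,
                              const std::vector<std::size_t>& order) {
  assert(sizes.size() == strides.size());
  if (sizes.size() != order.size()) {
    return false;  // wrong rank for this memory format
  }
  int64_t expected = 1;
  for (std::size_t d : order) {
    if (sizes[d] == 1) {
      continue;  // a size-1 dimension never contributes to addressing
    }
    if (strides[d] != expected) {
      return false;
    }
    expected *= sizes[d];
  }
  return true;
}

// Logical dims are (N, C, H, W); NHWC memory order is C innermost, then W, H, N.
inline bool is_channels_last_2d(const std::vector<int64_t>& sizes,
                                const std::vector<int64_t>& strides) {
  return matches_dim_order(sizes, strides, {1, 3, 2, 0});
}

// Logical dims are (N, C, D, H, W); NDHWC memory order is C, W, H, D, N.
inline bool is_channels_last_3d(const std::vector<int64_t>& sizes,
                                const std::vector<int64_t>& strides) {
  return matches_dim_order(sizes, strides, {1, 4, 3, 2, 0});
}
```

For a tensor of logical shape (N=2, C=3, H=4, W=5), NHWC strides are (60, 1, 15, 3), which this check accepts, while the default contiguous strides (60, 20, 5, 1) are rejected.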
Copilot AI, Nov 24, 2025
[nitpick] Potential buffer overflow in error message handling. The STD_TORCH_CHECK macro uses a fixed 1024-byte buffer with snprintf. If the formatted error message exceeds this size, it will be truncated silently. Consider either: 1) using a larger buffer size, 2) using dynamic allocation with std::string and stream formatting, or 3) documenting the message length limitation clearly.
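
One of the options raised in the comment above, building the message in a `std::string`, could look roughly like the sketch below. Note that `snprintf` bounds its writes to the buffer size it is given, so the failure mode with a fixed buffer is a silently shortened message rather than an out-of-bounds write. `format_message` is a hypothetical helper name introduced for this example; how it would plug into STD_TORCH_CHECK is not shown in this PR view and is left out.

```cpp
#include <cassert>
#include <cstdarg>
#include <cstddef>
#include <cstdio>
#include <string>

// Hypothetical helper (not from the PR): formats printf-style arguments into a
// std::string that is sized to fit, so long messages are never truncated.
inline std::string format_message(const char* fmt, ...) {
  assert(fmt != nullptr);

  va_list args;
  va_start(args, fmt);
  va_list args_copy;
  va_copy(args_copy, args);  // the argument list is consumed once per pass

  // First pass: a null buffer with size 0 only reports the required length.
  const int needed = std::vsnprintf(nullptr, 0, fmt, args);
  va_end(args);
  if (needed < 0) {
    va_end(args_copy);
    return "<formatting error>";
  }

  // Second pass: write into a string of exactly the required length.
  std::string out(static_cast<std::size_t>(needed), '\0');
  std::vsnprintf(&out[0], out.size() + 1, fmt, args_copy);
  va_end(args_copy);
  return out;
}
```

The two-pass pattern trades one extra formatting call for an exact allocation, which seems acceptable on an error path that is not performance sensitive.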
How much would this preserve the current semantics?
The stable ABI path uses `nullptr` (the default stream) rather than the user's current stream.
Would that mean the change potentially affects the behavior of multi_tensor_apply?