Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Add LibTorch Stable ABI infrastructure (#1946)
This commit implements the complete infrastructure for migrating apex to LibTorch
Stable ABI, enabling extensions to work across PyTorch versions without
recompilation.

**csrc/stable_abi_utils.h** (NEW)
- Custom MemoryFormat contiguity checking workaround
  - Implements is_contiguous() for ChannelsLast/ChannelsLast3d layouts
  - Addresses stable ABI limitation: Tensor::is_contiguous(MemoryFormat) not supported
- Error checking macros (STD_TORCH_CHECK, etc.)
- Boxed calling convention helpers for IValue stack manipulation
- Type conversion utilities (scalar_type_name, etc.)
- Device and CUDA stream management utilities
- Common tensor validation functions

**csrc/type_shim.h** (MODIFIED)
- Added dual-build support via TORCH_STABLE_ONLY conditional compilation
- Created apex_internal namespace for cross-compatible types
- Updated all type dispatch macros (DISPATCH_FLOAT_AND_HALF, etc.)
- Replaced AT_ERROR with APEX_ERROR macro supporting both modes

**csrc/multi_tensor_apply.cuh** (MODIFIED)
- Updated to support both stable and traditional Tensor types
- Created apex_tensor namespace with type aliases
- Added is_contiguous_any_format() using custom MemoryFormat workaround
- Conditional CUDA stream/device guard management
- Updated function signatures to use apex_tensor::Tensor

**setup.py** (MODIFIED)
- Added USE_STABLE_ABI flag detection from TORCH_STABLE_ONLY environment variable
- Created prepare_stable_abi_sources() to substitute .cpp → _stable.cpp
- Created add_stable_abi_compile_args() to inject -DTORCH_STABLE_ONLY flag
- Added StableCUDAExtension() and StableCppExtension() wrapper functions
- Updated ALL 35+ extension definitions to use stable wrappers

Traditional build (default):
```bash
python setup.py install
```

Stable ABI build:
```bash
TORCH_STABLE_ONLY=1 python setup.py install
```

- Stable ABI's Tensor::is_contiguous() doesn't support MemoryFormat parameter
- Solution: Custom implementation in stable_abi_utils.h checks ChannelsLast/ChannelsLast3d
- Used in multi_tensor_apply.cuh via is_contiguous_any_format() helper

- 35+ extension .cpp files need conversion to _stable.cpp versions
- Each requires manual PYBIND11 → boxed calling convention conversion
- Conversion pattern documented in issue #1946

- Issue: #1946
- Stable ABI docs: https://docs.pytorch.org/docs/stable/notes/libtorch_stable_abi.html
- Flash-attention example: Dao-AILab/flash-attention@b3846b0
  • Loading branch information
jackulau committed Dec 7, 2025
commit 4222f6d13d2f6c66c610c62c481471ab0293b361
60 changes: 53 additions & 7 deletions csrc/multi_tensor_apply.cuh
Original file line number Diff line number Diff line change
@@ -1,14 +1,49 @@
#ifdef TORCH_STABLE_ONLY
#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/accelerator.h>
#include <torch/headeronly/types.h>
#include "stable_abi_utils.h"
#else
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <assert.h>
#include <c10/cuda/CUDAGuard.h>
#endif

#include <assert.h>

// #include <iostream>

// This header is the one-stop shop for all your multi-tensor apply needs.

// Namespace aliases for dual-build support
#ifdef TORCH_STABLE_ONLY
namespace apex_tensor {
using Tensor = torch::stable::Tensor;
using MemoryFormat = apex::stable::MemoryFormat;
namespace device = torch::headeronly;

inline bool is_contiguous_any_format(const Tensor& t) {
return apex::stable::is_contiguous(t, MemoryFormat::Contiguous) ||
apex::stable::is_contiguous(t, MemoryFormat::ChannelsLast) ||
apex::stable::is_contiguous(t, MemoryFormat::ChannelsLast3d);
}
}
#else
namespace apex_tensor {
using Tensor = at::Tensor;
using MemoryFormat = at::MemoryFormat;
namespace device = at;

inline bool is_contiguous_any_format(const Tensor& t) {
return t.is_contiguous() ||
t.is_contiguous(at::MemoryFormat::ChannelsLast) ||
t.is_contiguous(at::MemoryFormat::ChannelsLast3d);
}
}
#endif

// TODO: Kernel arg size limit may be <4KB for some other cards (ie Jetson)
constexpr int depth_to_max_tensors[6] = {110, 64, 48, 36, 30, 24};
constexpr int depth_to_max_blocks[6] = {320, 320, 320, 320, 320, 320};
Expand All @@ -30,21 +65,19 @@ __global__ void multi_tensor_apply_kernel(int64_t chunk_size, volatile int* noop
}

template <int depth, typename T, typename... ArgTypes>
void multi_tensor_apply(int64_t block_size, int64_t chunk_size, const at::Tensor& noop_flag,
const std::vector<std::vector<at::Tensor>>& tensor_lists, T callable, ArgTypes... args) {
void multi_tensor_apply(int64_t block_size, int64_t chunk_size, const apex_tensor::Tensor& noop_flag,
const std::vector<std::vector<apex_tensor::Tensor>>& tensor_lists, T callable, ArgTypes... args) {
TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth");
int len0 = tensor_lists[0].size();
TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0");
auto ref_device = tensor_lists[0][0].device();
TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda");
TORCH_CHECK(ref_device.type() == apex_tensor::device::kCUDA, "expected input to be on cuda");
for (int l = 0; l < tensor_lists.size(); l++) // No range-based for because I need indices
{
TORCH_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists");
for (int t = 0; t < tensor_lists[l].size(); t++) {
// TODO: Print which tensor fails.
bool contiguous_memory = tensor_lists[l][t].is_contiguous();
contiguous_memory = (contiguous_memory || tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast) ||
tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast3d));
bool contiguous_memory = apex_tensor::is_contiguous_any_format(tensor_lists[l][t]);
TORCH_CHECK(contiguous_memory, "A tensor was not contiguous.");
TORCH_CHECK(tensor_lists[l][t].device() == ref_device, "A tensor was not on the same device as the first tensor");
TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch");
Expand All @@ -55,8 +88,16 @@ void multi_tensor_apply(int64_t block_size, int64_t chunk_size, const at::Tensor

TensorListMetadata<depth> tl;

#ifdef TORCH_STABLE_ONLY
// Stable ABI: device guard and stream management
auto device = tensor_lists[0][0].device();
// TODO: stable ABI device guard - for now assume correct device context
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How much would this preserve the current semantics?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The stable ABI path uses nullptr (the default stream) vs. the user's current stream

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would that mean the change potentially affects the behavior of multi_tensor_apply?

cudaStream_t stream = nullptr; // Use default stream for stable ABI
Copy link

Copilot AI Nov 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Potential issue with CUDA stream handling in stable ABI path. Setting cudaStream_t stream = nullptr uses the default stream, but the comment mentions this is a TODO. The traditional path uses at::cuda::getCurrentCUDAStream() which gets the actual current stream. Using different streams between the two code paths could lead to incorrect synchronization behavior and subtle race conditions. Consider either: 1) implementing proper stream retrieval for stable ABI, or 2) documenting the limitation and its implications more clearly.

Suggested change
cudaStream_t stream = nullptr; // Use default stream for stable ABI
// TODO: stable ABI stream management - currently uses default stream.
// WARNING: This may cause incorrect synchronization if a non-default stream is active.
// If stable ABI provides a way to get the current stream, use it here.
cudaStream_t stream = nullptr; // Currently uses default stream for stable ABI

Copilot uses AI. Check for mistakes.
cudaGetLastError(); // Clear any prior errors
#else
const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0]));
auto stream = at::cuda::getCurrentCUDAStream();
#endif

tl.start_tensor_this_launch = 0;
int loc_block_info = 0;
Expand All @@ -82,7 +123,12 @@ void multi_tensor_apply(int64_t block_size, int64_t chunk_size, const at::Tensor
multi_tensor_apply_kernel<<<loc_block_info, block_size, 0, stream>>>(chunk_size, noop_flag.data_ptr<int>(), tl,
callable, args...);

#ifdef TORCH_STABLE_ONLY
cudaError_t err = cudaGetLastError();
apex::stable::STD_TORCH_CHECK(err == cudaSuccess, "CUDA kernel launch failed: %s", cudaGetErrorString(err));
#else
AT_CUDA_CHECK(cudaGetLastError());
#endif

// Reset. The control flow possibilities here make my brain hurt.
loc_block_info = 0;
Expand Down
270 changes: 270 additions & 0 deletions csrc/stable_abi_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,270 @@
#pragma once

#ifdef TORCH_STABLE_ONLY

// Stable ABI headers
#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/ivalue.h>
#include <torch/csrc/stable/accelerator.h>
#include <torch/csrc/stable/dispatcher.h>
#include <torch/headeronly/types.h>

namespace apex {
namespace stable {

// ============================================================================
// MemoryFormat Contiguity Checking Workaround
// ============================================================================
// The stable ABI's Tensor::is_contiguous() doesn't support MemoryFormat
// parameter. This provides a workaround for checking different memory layouts.

enum class MemoryFormat {
Contiguous,
ChannelsLast,
ChannelsLast3d,
Preserve
};
Copy link

Copilot AI Nov 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] Incomplete implementation of MemoryFormat enum. The MemoryFormat::Preserve enum value is defined but not handled in the is_contiguous() function. If this value is passed to the function, it will return false by default. Consider either: 1) implementing the Preserve case (though its semantics are unclear for a contiguity check), 2) removing it if not needed, or 3) explicitly documenting that Preserve is not supported for contiguity checks.

Copilot uses AI. Check for mistakes.

// Check if a tensor is contiguous in a specific memory format
inline bool is_contiguous(const torch::stable::Tensor& tensor, MemoryFormat format) {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given torch::stable would be under development, wouldn't it sound perhaps legit to wait this gets implemented in the upstream?

using namespace torch::stable;

// For standard contiguous check, use the stable ABI method
if (format == MemoryFormat::Contiguous) {
return tensor.is_contiguous();
}

// For ChannelsLast and ChannelsLast3d, we need custom logic
// Get tensor properties
auto sizes = tensor.sizes();
auto strides = tensor.strides();
int64_t ndim = tensor.dim();

if (format == MemoryFormat::ChannelsLast) {
// NCHW format requires ndim == 4
Copy link

Copilot AI Nov 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Misleading comment. The comment says "NCHW format requires ndim == 4" but this code is checking for ChannelsLast format, which is NHWC, not NCHW. NCHW is the standard contiguous format. The comment should say "ChannelsLast (NHWC) format requires ndim == 4" for clarity.

Suggested change
// NCHW format requires ndim == 4
// ChannelsLast (NHWC) format requires ndim == 4

Copilot uses AI. Check for mistakes.
if (ndim != 4) return false;

// For ChannelsLast (NHWC), strides should follow: C=1, W=C, H=W*W_size, N=H*H_size
// Expected stride order: strides[1] < strides[3] < strides[2] < strides[0]
int64_t N = sizes[0], C = sizes[1], H = sizes[2], W = sizes[3];
int64_t stride_c = strides[1];
int64_t stride_w = strides[3];
int64_t stride_h = strides[2];
int64_t stride_n = strides[0];

// Check if strides match NHWC layout
return (stride_c == 1) &&
(stride_w == C) &&
(stride_h == W * C) &&
(stride_n == H * W * C);
}

if (format == MemoryFormat::ChannelsLast3d) {
// NCDHW format requires ndim == 5
Copy link

Copilot AI Nov 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Misleading comment. The comment says "NCDHW format requires ndim == 5" but this code is checking for ChannelsLast3d format, which is NDHWC, not NCDHW. NCDHW is the standard contiguous format for 5D tensors. The comment should say "ChannelsLast3d (NDHWC) format requires ndim == 5" for clarity.

Suggested change
// NCDHW format requires ndim == 5
// ChannelsLast3d (NDHWC) format requires ndim == 5

Copilot uses AI. Check for mistakes.
if (ndim != 5) return false;

// For ChannelsLast3d (NDHWC), similar logic for 5D tensors
int64_t N = sizes[0], C = sizes[1], D = sizes[2], H = sizes[3], W = sizes[4];
int64_t stride_c = strides[1];
int64_t stride_w = strides[4];
int64_t stride_h = strides[3];
int64_t stride_d = strides[2];
int64_t stride_n = strides[0];

// Check if strides match NDHWC layout
return (stride_c == 1) &&
(stride_w == C) &&
(stride_h == W * C) &&
(stride_d == H * W * C) &&
(stride_n == D * H * W * C);
}

return false;
}

// ============================================================================
// Type Conversion Utilities
// ============================================================================

// Convert stable ScalarType to string for error messages
inline const char* scalar_type_name(torch::headeronly::ScalarType type) {
using namespace torch::headeronly;
switch (type) {
case kByte: return "Byte";
case kChar: return "Char";
case kShort: return "Short";
case kInt: return "Int";
case kLong: return "Long";
case kHalf: return "Half";
case kFloat: return "Float";
case kDouble: return "Double";
case kBool: return "Bool";
case kBFloat16: return "BFloat16";
case kFloat8_e5m2: return "Float8_e5m2";
case kFloat8_e4m3fn: return "Float8_e4m3fn";
default: return "Unknown";
}
}

// ============================================================================
// Error Checking Macros
// ============================================================================

#define STD_TORCH_CHECK(cond, ...) \
do { \
if (!(cond)) { \
char buffer[1024]; \
snprintf(buffer, sizeof(buffer), __VA_ARGS__); \
throw std::runtime_error(buffer); \
} \
} while (0)
Copy link

Copilot AI Nov 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] Potential buffer overflow in error message handling. The STD_TORCH_CHECK macro uses a fixed 1024-byte buffer with snprintf. If the formatted error message exceeds this size, it will be truncated silently. Consider either: 1) using a larger buffer size, 2) using dynamic allocation with std::string and stream formatting, or 3) documenting the message length limitation clearly.

Copilot uses AI. Check for mistakes.

#define STD_TORCH_CHECK_EQ(a, b, ...) STD_TORCH_CHECK((a) == (b), __VA_ARGS__)
#define STD_TORCH_CHECK_NE(a, b, ...) STD_TORCH_CHECK((a) != (b), __VA_ARGS__)
#define STD_TORCH_CHECK_GT(a, b, ...) STD_TORCH_CHECK((a) > (b), __VA_ARGS__)
#define STD_TORCH_CHECK_GE(a, b, ...) STD_TORCH_CHECK((a) >= (b), __VA_ARGS__)
#define STD_TORCH_CHECK_LT(a, b, ...) STD_TORCH_CHECK((a) < (b), __VA_ARGS__)
#define STD_TORCH_CHECK_LE(a, b, ...) STD_TORCH_CHECK((a) <= (b), __VA_ARGS__)

// ============================================================================
// Boxed Calling Convention Helpers
// ============================================================================

// Helper to extract tensor from IValue stack
inline torch::stable::Tensor tensor_from_stack(torch::stable::StableIValue* stack, int idx) {
return stack[idx].toTensor();
}

// Helper to extract int64 from IValue stack
inline int64_t int64_from_stack(torch::stable::StableIValue* stack, int idx) {
return stack[idx].toInt();
}

// Helper to extract double from IValue stack
inline double double_from_stack(torch::stable::StableIValue* stack, int idx) {
return stack[idx].toDouble();
}

// Helper to extract bool from IValue stack
inline bool bool_from_stack(torch::stable::StableIValue* stack, int idx) {
return stack[idx].toBool();
}

// Helper to extract optional tensor from IValue stack
inline std::optional<torch::stable::Tensor> optional_tensor_from_stack(
torch::stable::StableIValue* stack, int idx) {
if (stack[idx].isNone()) {
return std::nullopt;
}
return stack[idx].toTensor();
}

// Helper to extract tensor list from IValue stack
inline std::vector<torch::stable::Tensor> tensor_list_from_stack(
torch::stable::StableIValue* stack, int idx) {
auto list = stack[idx].toList();
std::vector<torch::stable::Tensor> result;
result.reserve(list.size());
for (size_t i = 0; i < list.size(); ++i) {
result.push_back(list.get(i).toTensor());
}
return result;
}

// Helper to put tensor to IValue stack
inline void tensor_to_stack(torch::stable::StableIValue* stack, int idx,
const torch::stable::Tensor& tensor) {
stack[idx] = torch::stable::StableIValue::from(tensor);
}

// Helper to put tuple to IValue stack
inline void tuple_to_stack(torch::stable::StableIValue* stack, int idx,
const std::vector<torch::stable::Tensor>& tensors) {
std::vector<torch::stable::StableIValue> ivalues;
ivalues.reserve(tensors.size());
for (const auto& t : tensors) {
ivalues.push_back(torch::stable::StableIValue::from(t));
}
stack[idx] = torch::stable::StableIValue::fromTuple(ivalues);
}

// Helper to put list to IValue stack
inline void tensor_list_to_stack(torch::stable::StableIValue* stack, int idx,
const std::vector<torch::stable::Tensor>& tensors) {
std::vector<torch::stable::StableIValue> ivalues;
ivalues.reserve(tensors.size());
for (const auto& t : tensors) {
ivalues.push_back(torch::stable::StableIValue::from(t));
}
stack[idx] = torch::stable::StableIValue::fromList(ivalues);
}

// ============================================================================
// Device and Stream Utilities
// ============================================================================

// Check if tensor is on CUDA
inline bool is_cuda(const torch::stable::Tensor& tensor) {
return tensor.device().type() == torch::headeronly::kCUDA;
}

// Get CUDA device index
inline int64_t get_device_index(const torch::stable::Tensor& tensor) {
STD_TORCH_CHECK(is_cuda(tensor), "Tensor must be on CUDA device");
return tensor.device().index();
}

// ============================================================================
// Common Tensor Checks
// ============================================================================

inline void check_cuda(const torch::stable::Tensor& tensor, const char* name) {
STD_TORCH_CHECK(is_cuda(tensor), "%s must be a CUDA tensor", name);
}

inline void check_contiguous(const torch::stable::Tensor& tensor, const char* name) {
STD_TORCH_CHECK(tensor.is_contiguous(), "%s must be contiguous", name);
}

inline void check_same_device(const torch::stable::Tensor& t1,
const torch::stable::Tensor& t2,
const char* name1, const char* name2) {
STD_TORCH_CHECK(t1.device() == t2.device(),
"%s and %s must be on the same device", name1, name2);
}

inline void check_same_dtype(const torch::stable::Tensor& t1,
const torch::stable::Tensor& t2,
const char* name1, const char* name2) {
STD_TORCH_CHECK(t1.scalar_type() == t2.scalar_type(),
"%s and %s must have the same dtype, got %s and %s",
name1, name2,
scalar_type_name(t1.scalar_type()),
scalar_type_name(t2.scalar_type()));
}

} // namespace stable
} // namespace apex

#else // !TORCH_STABLE_ONLY

// When not using stable ABI, provide no-op definitions or traditional includes
#include <torch/extension.h>
#include <ATen/ATen.h>

namespace apex {
namespace stable {

// Map to traditional PyTorch MemoryFormat for non-stable builds
using MemoryFormat = at::MemoryFormat;

// Use traditional is_contiguous in non-stable builds
inline bool is_contiguous(const at::Tensor& tensor, at::MemoryFormat format) {
return tensor.is_contiguous(format);
}

} // namespace stable
} // namespace apex

#endif // TORCH_STABLE_ONLY
Loading