more generic checks for hardware support
JohannesGaessler committed Oct 30, 2025
commit 7efb6acfe6cb552af17cb9555f90f378b1587c97
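
This commit replaces the per-architecture static_asserts with a constexpr supported() query on each tile shape, so kernels can ask at compile time which shapes the target implements. As a minimal, self-contained sketch of that pattern (the simplified tile struct and preferred_rows() helper below are illustrative stand-ins, not code from this PR):

// Minimal sketch of the shape-selection pattern; the tile struct here is a stand-in
// for ggml_cuda_mma::tile and only pretends that 16x8 and 32x8 are supported.
#include <cstdio>

template <int I, int J, typename T>
struct tile {
    static constexpr bool supported() {
        return (I == 16 && J == 8) || (I == 32 && J == 8);
    }
};

template <typename T>
constexpr int preferred_rows() {
    constexpr bool i16_ok = tile<16, 8, T>::supported();
    constexpr bool i32_ok = tile<32, 8, T>::supported();
    static_assert(i16_ok || i32_ok, "no supported tile shape for this type");
    return i16_ok ? 16 : 32; // prefer I == 16 when both are available
}

int main() {
    printf("preferred tile: <%d, 8>\n", preferred_rows<float>());
    return 0;
}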
81 changes: 66 additions & 15 deletions ggml/src/ggml-cuda/mma.cuh
@@ -77,6 +77,15 @@ namespace ggml_cuda_mma {
static constexpr int ne = I * J / 64;
T x[ne] = {0};

static constexpr __device__ bool supported() {
if (I == 64 && J == 2) return true;
if (I == 16 && J == 8) return true;
if (I == 32 && J == 4) return true;
if (I == 16 && J == 16) return true;
if (I == 32 && J == 32) return true;
return false;
}

static __device__ __forceinline__ int get_i(const int l) {
if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
return threadIdx.x % 16;
@@ -89,7 +98,7 @@ namespace ggml_cuda_mma {
} else if constexpr (I == 32 && J == 32) {
return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
NO_DEVICE_CODE;
return -1;
}
}
@@ -106,14 +115,19 @@ namespace ggml_cuda_mma {
} else if constexpr (I == 32 && J == 32) {
return threadIdx.x % 32;
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
NO_DEVICE_CODE;
return -1;
}
}
#elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
static constexpr int ne = I * J / 32;
T x[ne] = {0};

static constexpr __device__ bool supported() {
if (I == 32 && J == 8) return true;
return false;
}

static __device__ __forceinline__ int get_i(const int l) {
if constexpr (I == 32 && J == 8) {
#ifdef GGML_CUDA_MMA_NO_VOLTA_PERM
@@ -122,7 +136,7 @@ namespace ggml_cuda_mma {
return (l & 2) | (threadIdx.x & ~2);
#endif // GGML_CUDA_MMA_NO_VOLTA_PERM
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
NO_DEVICE_CODE;
return -1;
}
}
@@ -131,23 +145,36 @@ namespace ggml_cuda_mma {
if constexpr (I == 32 && J == 8) {
return (threadIdx.x & 2) | (l & (4 + 1));
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
NO_DEVICE_CODE;
return -1;
}
}
#else
static constexpr int ne = I * J / 32;
T x[ne] = {0};

static constexpr __device__ bool supported() {
if (I == 8 && J == 4) return true;
if (I == 8 && J == 8) return true;
if (I == 16 && J == 8) return true;
if (I == 16 && J == 16) return true;
if (I == 32 && J == 8) return true;
return false;
}

static __device__ __forceinline__ int get_i(const int l) {
if constexpr (I == 8 && (J == 4 || J == 8)) {
if constexpr (I == 8 && J == 4) {
return threadIdx.x / 4;
} else if constexpr ((I == 16 || I == 32) && J == 8) {
} else if constexpr (I == 8 && J == 8) {
return threadIdx.x / 4;
} else if constexpr (I == 16 && J == 8) {
return ((l / 2) * 8) | (threadIdx.x / 4);
} else if constexpr (I == 16 && J == 16) {
return (((l / 2) % 2) * 8) | (threadIdx.x / 4);
} else if constexpr (I == 32 && J == 8) {
return tile<16, 8, T>::get_i(l); // Memory layout simply repeated with same pattern in i direction.
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
NO_DEVICE_CODE;
return -1;
}
}
@@ -157,12 +184,14 @@
return threadIdx.x % 4;
} else if constexpr (I == 8 && J == 8) {
return (l * 4) | (threadIdx.x % 4);
} else if constexpr ((I == 16 || I == 32) && J == 8) {
} else if constexpr (I == 16 && J == 8) {
return ((threadIdx.x % 4) * 2) | (l % 2);
} else if constexpr (I == 16 && J == 16) {
return ((l / 4) * 8) | ((threadIdx.x % 4) * 2) | (l % 2);
} else if constexpr (I == 32 && J == 8) {
return tile<16, 8, T>::get_j(l); // Memory layout simply repeated with same pattern in i direction.
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
NO_DEVICE_CODE;
return -1;
}
}
@@ -178,6 +207,12 @@
static constexpr int ne = I == 8 && J == 8 ? I * J / (WARP_SIZE/4) : I * J / WARP_SIZE;
half2 x[ne] = {{0.0f, 0.0f}};

static constexpr __device__ bool supported() {
if (I == 8 && J == 8) return true;
if (I == 32 && J == 8) return true;
return false;
}

static __device__ __forceinline__ int get_i(const int l) {
if constexpr (I == 8 && J == 8) {
return ((threadIdx.x / 16) * 4) | (threadIdx.x % 4);
@@ -188,7 +223,7 @@
return threadIdx.x;
#endif // GGML_CUDA_MMA_NO_VOLTA_PERM
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
NO_DEVICE_CODE;
return -1;
}
}
@@ -197,14 +232,23 @@
if constexpr ((I == 8 || I == 32) && J == 8) {
return l;
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
NO_DEVICE_CODE;
return -1;
}
}
#else
static constexpr int ne = I * J / WARP_SIZE;
half2 x[ne] = {{0.0f, 0.0f}};

static constexpr __device__ bool supported() {
if (I == 8 && J == 4) return true;
if (I == 8 && J == 8) return true;
if (I == 16 && J == 8) return true;
if (I == 16 && J == 16) return true;
if (I == 32 && J == 8) return true;
return false;
}

static __device__ __forceinline__ int get_i(const int l) {
if constexpr (I == 8 && J == 8) {
return threadIdx.x / 4;
@@ -215,7 +259,7 @@
} else if constexpr (I == 32 && J == 8) {
return ((l / 4) * 16) | ((l % 2) * 8) | (threadIdx.x / 4);
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
NO_DEVICE_CODE;
return -1;
}
}
@@ -230,7 +274,7 @@
} else if constexpr (I == 32 && J == 8) {
return ((l & 2) * 2) | (threadIdx.x % 4);
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
NO_DEVICE_CODE;
return -1;
}
}
@@ -244,6 +288,13 @@
static constexpr int ne = I * J / WARP_SIZE;
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};

static constexpr __device__ bool supported() {
if (I == 8 && J == 8) return true;
if (I == 16 && J == 4) return true;
if (I == 16 && J == 8) return true;
return false;
}

static __device__ __forceinline__ int get_i(const int l) {
if constexpr (I == 8 && J == 8) {
return threadIdx.x / 4;
@@ -252,7 +303,7 @@
} else if constexpr (I == 16 && J == 8) {
return ((l % 2) * 8) | (threadIdx.x / 4);
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
NO_DEVICE_CODE;
return -1;
}
}
@@ -265,7 +316,7 @@
} else if constexpr (I == 16 && J == 8) {
return ((l / 2) * 4) | (threadIdx.x % 4);
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
NO_DEVICE_CODE;
return -1;
}
}
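
With the accessors above, an unsupported <I, J> shape no longer hard-fails compilation via static_assert; it compiles to NO_DEVICE_CODE, and callers are expected to check tile<I, J, T>::supported() at compile time. A hedged device-side sketch of that guard, assuming the real ggml_cuda_mma::tile and ggml's NO_DEVICE_CODE macro (the kernel name and parameters are illustrative only):

// Sketch of the caller-side guard made possible by supported(); it mirrors the
// pattern used by mul_mat_f in mmf.cuh below, with illustrative names.
template <typename T>
__global__ void example_mma_kernel(const T * x, float * dst) {
    constexpr bool i16_ok = tile<16, 8, T>::supported() && tile<16, 8, float>::supported();
    constexpr bool i32_ok = tile<32, 8, T>::supported() && tile<32, 8, float>::supported();

    if (!i16_ok && !i32_ok) {
        NO_DEVICE_CODE; // this architecture offers no usable tile shape for T
        return;
    }

    constexpr int I_preferred = i16_ok ? 16 : 32;
    typedef tile<I_preferred, 8, T>     tile_A;
    typedef tile<I_preferred, 8, float> tile_C;
    // ... load tile_A fragments from x, accumulate into tile_C, store to dst ...
    (void) x; (void) dst;
}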
133 changes: 34 additions & 99 deletions ggml/src/ggml-cuda/mmf.cuh
@@ -20,23 +20,27 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const int src1_ncols, bool mul_mat_id);

template <typename T, int rows_per_block, int cols_per_block, int nwarps, bool has_ids>
static __device__ void mul_mat_f_impl(
__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
static __global__ void mul_mat_f(
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst,
const int stride_col_id, const int stride_row_id,
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
typedef tile<32, 8, T> tile_A;
typedef tile< 8, 8, T> tile_B;
typedef tile<32, 8, float> tile_C;
#else
// In principle also possible to use tiles with I == 32, the performance difference is ~1%.
typedef tile<16, 8, T> tile_A;
typedef tile< 8, 8, T> tile_B;
typedef tile<16, 8, float> tile_C;
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported();
constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported();

if (!I_16_supported && !I_32_supported) {
NO_DEVICE_CODE;
return;
}

constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work, but 16 is ~1% faster.

typedef tile<I_preferred, 8, T> tile_A;
typedef tile<8, 8, T> tile_B;
typedef tile<I_preferred, 8, float> tile_C;

constexpr int warp_size = ggml_cuda_get_physical_warp_size();
constexpr int tile_k_padded = warp_size + 4;
@@ -238,43 +242,10 @@ static __device__ void mul_mat_f_impl(
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
}

template <typename T, int rows_per_block, int cols_per_block, int nwarps, bool has_ids>
__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
static __global__ void mul_mat_f(
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst,
const int stride_col_id, const int stride_row_id,
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
if constexpr (std::is_same_v<T, half2>) {
#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
mul_mat_f_impl<T, rows_per_block, cols_per_block, nwarps, has_ids>(
x, y, ids, dst, ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y,
stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
#else
NO_DEVICE_CODE;
#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
} else if constexpr (std::is_same_v<T, float> || std::is_same_v<T, nv_bfloat162>) {
#ifdef AMPERE_MMA_AVAILABLE
mul_mat_f_impl<T, rows_per_block, cols_per_block, nwarps, has_ids>(
x, y, ids, dst, ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y,
stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
#else
NO_DEVICE_CODE;
#endif // AMPERE_MMA_AVAILABLE
} else {
static_assert(std::is_same_v<T, void>, "bad type");
}
GGML_UNUSED_VARS(x, y, ids, dst, ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y,
stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
}

// This kernel is for larger batch sizes of mul_mat_id
template <typename T, int rows_per_block, int cols_per_block, int nwarps>
static __device__ void mul_mat_f_ids_impl(
__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
static __global__ void mul_mat_f_ids(
const T * __restrict__ x, const float * __restrict__ y,
const int32_t * __restrict__ ids_src_compact, const int32_t * __restrict__ ids_dst_compact,
const int32_t * __restrict__ expert_bounds, float * __restrict__ dst,
@@ -283,16 +254,19 @@ static __device__ void mul_mat_f_ids_impl(
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
const uint3 sis1_fd, const uint3 nch_fd) {
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
typedef tile<32, 8, T> tile_A;
typedef tile< 8, 8, T> tile_B;
typedef tile<32, 8, float> tile_C;
#else
// In principle also possible to use tiles with I == 32, the performance difference is ~1%.
typedef tile<16, 8, T> tile_A;
typedef tile< 8, 8, T> tile_B;
typedef tile<16, 8, float> tile_C;
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported();
constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported();

if (!I_16_supported && !I_32_supported) {
NO_DEVICE_CODE;
return;
}

constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work, but 16 is ~1% faster.

typedef tile<I_preferred, 8, T> tile_A;
typedef tile<8, 8, T> tile_B;
typedef tile<I_preferred, 8, float> tile_C;

constexpr int warp_size = ggml_cuda_get_physical_warp_size();
constexpr int tile_k_padded = warp_size + 4;
@@ -521,46 +495,6 @@ static __device__ void mul_mat_f_ids_impl(
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
}

template <typename T, int rows_per_block, int cols_per_block, int nwarps>
__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
static __global__ void mul_mat_f_ids(
const T * __restrict__ x, const float * __restrict__ y,
const int32_t * __restrict__ ids_src_compact, const int32_t * __restrict__ ids_dst_compact,
const int32_t * __restrict__ expert_bounds, float * __restrict__ dst,
const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst,
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
const uint3 sis1_fd, const uint3 nch_fd) {
if constexpr (std::is_same_v<T, half2>) {
#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
mul_mat_f_ids_impl<T, rows_per_block, cols_per_block, nwarps>(
x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst,
ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd);
#else
NO_DEVICE_CODE;
#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
} else if constexpr (std::is_same_v<T, float> || std::is_same_v<T, nv_bfloat162>) {
#ifdef AMPERE_MMA_AVAILABLE
mul_mat_f_ids_impl<T, rows_per_block, cols_per_block, nwarps>(
x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst,
ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd);
#else
NO_DEVICE_CODE;
#endif // AMPERE_MMA_AVAILABLE
} else {
static_assert(std::is_same_v<T, void>, "bad type");
}
GGML_UNUSED_VARS(
x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst,
ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd);
}

template<typename T, int cols_per_block, int nwarps>
static inline void mul_mat_f_switch_ids(
const T * x, const float * y, const int32_t * ids, float * dst,
@@ -618,7 +552,7 @@ void mul_mat_f_cuda(
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
cudaStream_t stream, const mmf_ids_data * ids_data) {
typedef tile<32, 8, T> tile_A_16;
typedef tile<16, 8, T> tile_A_16;
typedef tile<32, 8, T> tile_A_32;
typedef tile< 8, 8, T> tile_B;

@@ -630,7 +564,8 @@ void mul_mat_f_cuda(
const int64_t channel_ratio = nchannels_dst / nchannels_x;
const int64_t sample_ratio = nsamples_dst / nsamples_x;

const int device = ggml_cuda_get_device();
const int device = ggml_cuda_get_device();
const int cc = ggml_cuda_info().devices[device].cc;
const int warp_size = ggml_cuda_info().devices[device].warp_size;

int64_t nwarps_best = 1;
@@ -645,7 +580,7 @@ void mul_mat_f_cuda(
}

constexpr int rows_per_block = MMF_ROWS_PER_BLOCK;
const int nbytes_shared_iter = nwarps_best * (volta_mma_available ? tile_A_32::I : tile_A_16::I) * (warp_size + 4) * 4;
const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + 4) * 4;
const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4;
const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0;
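
On the host side, volta_mma_available is now called with the device's compute capability, so the shared-memory sizing in mul_mat_f_cuda is decided per device at runtime rather than by a compile-time macro. A small worked example of the nbytes_shared_iter formula, assuming nwarps_best = 4 and warp_size = 32 (illustrative values, not taken from the diff):

// Worked example of the per-iteration shared-memory sizing from mul_mat_f_cuda.
#include <cstdio>

int main() {
    const int nwarps_best = 4;   // assumed
    const int warp_size   = 32;  // assumed

    const int bytes_i16 = nwarps_best * 16 * (warp_size + 4) * 4; // tile_A_16::I == 16 ->  9216 B
    const int bytes_i32 = nwarps_best * 32 * (warp_size + 4) * 4; // tile_A_32::I == 32 -> 18432 B

    printf("smem per iteration: I=16 -> %d B, I=32 (Volta) -> %d B\n", bytes_i16, bytes_i32);
    return 0;
}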