Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions test/prototype/mx_formats/test_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from torch.utils._triton import has_triton

from torchao.prototype.mx_formats.constants import (
DTYPE_FP4,
DTYPE_FP6_E2M3,
DTYPE_FP6_E3M2,
F4_E2M1_EXP_BIAS,
Expand Down Expand Up @@ -335,11 +334,13 @@ def test_fp4_triton_unscaled_cast():
def test_fp4_triton_scaled_cast():
size = (256,)
orig_vals = torch.randn(size, dtype=torch.float, device="cuda") * 100
mxtensor_ref = MXTensor.to_mx(orig_vals, block_size=32, elem_dtype=DTYPE_FP4)
mxtensor_ref = MXTensor.to_mx(
orig_vals, block_size=32, elem_dtype=torch.float4_e2m1fn_x2
)
mxtensor_triton = MXTensor.to_mx(
orig_vals,
block_size=32,
elem_dtype=DTYPE_FP4,
elem_dtype=torch.float4_e2m1fn_x2,
use_fp4_custom_triton_dequant_kernel=True,
)

Expand Down
57 changes: 43 additions & 14 deletions test/prototype/mx_formats/test_mx_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
import torch.nn as nn

from torchao.prototype.mx_formats.config import (
MXGemmKernelChoice,
MXInferenceLinearConfig,
MXLinearConfig,
MXLinearRecipeName,
)
from torchao.prototype.mx_formats.constants import (
DTYPE_FP4,
DTYPE_FP6_E2M3,
DTYPE_FP6_E3M2,
SUPPORTED_ELEM_DTYPES,
Expand All @@ -29,15 +29,14 @@
from torchao.quantization import quantize_
from torchao.quantization.utils import compute_error
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_7,
TORCH_VERSION_AT_LEAST_2_8,
is_sm_at_least_89,
is_sm_at_least_100,
)

torch.manual_seed(2)

if not TORCH_VERSION_AT_LEAST_2_7:
if not TORCH_VERSION_AT_LEAST_2_8:
pytest.skip("Unsupported PyTorch version", allow_module_level=True)


Expand All @@ -51,19 +50,28 @@ def run_around_tests():
torch._dynamo.reset()


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize(
"elem_dtype",
(
elem_dtypes = (
[
# test each dtype
(torch.float8_e4m3fn, torch.float8_e4m3fn, torch.float8_e4m3fn),
(DTYPE_FP6_E3M2, DTYPE_FP6_E3M2, DTYPE_FP6_E3M2),
(DTYPE_FP6_E2M3, DTYPE_FP6_E2M3, DTYPE_FP6_E2M3),
(DTYPE_FP4, DTYPE_FP4, DTYPE_FP4),
(torch.float4_e2m1fn_x2, torch.float4_e2m1fn_x2, torch.float4_e2m1fn_x2),
# only test one type of mixed-dtype overrides, to save testing time
(torch.float8_e4m3fn, DTYPE_FP4, DTYPE_FP4),
),
(torch.float8_e4m3fn, torch.float4_e2m1fn_x2, torch.float4_e2m1fn_x2),
]
if TORCH_VERSION_AT_LEAST_2_8
else [
# test each dtype
(torch.float8_e4m3fn, torch.float8_e4m3fn, torch.float8_e4m3fn),
(DTYPE_FP6_E3M2, DTYPE_FP6_E3M2, DTYPE_FP6_E3M2),
(DTYPE_FP6_E2M3, DTYPE_FP6_E2M3, DTYPE_FP6_E2M3),
]
)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("elem_dtype", elem_dtypes)
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("input_shape", [(128, 256), (1, 128, 256), (1, 1, 128, 256)])
@pytest.mark.parametrize("use_fp8_dim1_cast_triton_kernel", [False, True])
Expand Down Expand Up @@ -155,7 +163,7 @@ def test_linear_eager_emulated_vs_real_gemm(recipe_name, mkn):

elem_dtype = torch.float8_e4m3fn
if recipe_name == MXLinearRecipeName.MXFP4_CUTLASS:
elem_dtype = DTYPE_FP4
elem_dtype = torch.float4_e2m1fn_x2

config_emulated = MXLinearConfig(block_size=32, elem_dtype=elem_dtype)
config_real = MXLinearConfig.from_recipe_name(recipe_name)
Expand Down Expand Up @@ -375,12 +383,21 @@ def test_inference_print_str():
assert "kernel=emulated" in s


test_dtypes = (
[torch.float8_e4m3fn, torch.float4_e2m1fn_x2]
if TORCH_VERSION_AT_LEAST_2_8
else [
torch.float8_e4m3fn,
]
)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
)
@pytest.mark.skipif(not is_sm_at_least_100, reason="Reqs sm100")
@pytest.mark.parametrize("elem_dtype", [torch.float8_e4m3fn])
@pytest.mark.parametrize("elem_dtype", [torch.float8_e4m3fn, torch.float4_e2m1fn_x2])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("compile", [True, False])
@torch.no_grad()
Expand All @@ -394,7 +411,16 @@ def test_inference_subclass(elem_dtype, bias: bool, compile: bool):

m = nn.Linear(32, 128, bias=bias, dtype=torch.bfloat16, device="cuda")
m_mx = copy.deepcopy(m)
config = MXFPInferenceConfig()
kernel_choice = (
MXGemmKernelChoice.CUTLASS
if elem_dtype == torch.float4_e2m1fn_x2
else MXGemmKernelChoice.CUBLAS
)
config = MXFPInferenceConfig(
activation_dtype=elem_dtype,
weight_dtype=elem_dtype,
gemm_kernel_choice=kernel_choice,
)
quantize_(m_mx, config=config)
if compile:
m_mx = torch.compile(m_mx, fullgraph=True)
Expand All @@ -403,4 +429,7 @@ def test_inference_subclass(elem_dtype, bias: bool, compile: bool):
y_ref = m(x)
y_mx = m_mx(x)
sqnr = compute_error(y_ref, y_mx)
assert sqnr >= 25.0, f"Got a sqnr of {sqnr} for {elem_dtype} and bias={bias}"
SQNR_THRESHOLD = 25.0 if elem_dtype == torch.float8_e4m3fn else 15.0
assert sqnr >= SQNR_THRESHOLD, (
f"Got a sqnr of {sqnr} for {elem_dtype} and bias={bias}"
)
15 changes: 10 additions & 5 deletions test/prototype/mx_formats/test_mx_mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,14 @@

from torchao.float8.float8_utils import compute_error
from torchao.ops import mx_fp4_bf16
from torchao.prototype.mx_formats.mx_tensor import DTYPE_FP4, MXTensor
from torchao.prototype.mx_formats.mx_tensor import MXTensor
from torchao.prototype.mx_formats.utils import to_blocked
from torchao.utils import TORCH_VERSION_AT_LEAST_2_7, is_sm_at_least_100
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_8,
is_sm_at_least_100,
)

if not TORCH_VERSION_AT_LEAST_2_7:
if not TORCH_VERSION_AT_LEAST_2_8:
pytest.skip("Unsupported PyTorch version", allow_module_level=True)


Expand All @@ -25,7 +28,7 @@ def run_matrix_test(M: int, K: int, N: int, format) -> float:
a = torch.rand((M, K), dtype=dtype, device=device)
b = torch.rand((N, K), dtype=dtype, device=device)

fmt = torch.float8_e4m3fn if format == "fp8" else DTYPE_FP4
fmt = torch.float8_e4m3fn if format == "fp8" else torch.float4_e2m1fn_x2
mx_func = (
partial(torch._scaled_mm, out_dtype=torch.bfloat16)
if format == "fp8"
Expand Down Expand Up @@ -75,7 +78,9 @@ def run_matrix_test(M: int, K: int, N: int, format) -> float:
],
ids=lambda x: f"{x[0]}x{x[1]}x{x[2]}",
)
@pytest.mark.parametrize("format", ["fp8", "fp4"])
@pytest.mark.parametrize(
"format", ["fp8", "fp4"] if TORCH_VERSION_AT_LEAST_2_8 else ["fp8"]
)
def test_matrix_multiplication(size, format):
M, K, N = size
sqnr = run_matrix_test(M, K, N, format)
Expand Down
7 changes: 3 additions & 4 deletions test/prototype/mx_formats/test_mx_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

from torchao.prototype.mx_formats.config import MXGemmKernelChoice
from torchao.prototype.mx_formats.constants import (
DTYPE_FP4,
DTYPE_FP6_E2M3,
DTYPE_FP6_E3M2,
SUPPORTED_ELEM_DTYPES,
Expand Down Expand Up @@ -363,7 +362,7 @@ def test_exponent_nan_out(elem_dtype, pack_fp6):
if pack_fp6:
data_bits = data_bits.reshape(-1, block_size)
data_bits = pack_uint6(data_bits)
elif elem_dtype == DTYPE_FP4:
elif elem_dtype == torch.float4_e2m1fn_x2:
data_bits = torch.tensor(
[0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.uint8, device="cuda"
) # noqa: E501
Expand Down Expand Up @@ -407,7 +406,7 @@ def test_block_sizes(elem_dtype, B):
"""
Smoke test for various block sizes
"""
if B == 1 and elem_dtype == DTYPE_FP4:
if B == 1 and elem_dtype == torch.float4_e2m1fn_x2:
pytest.skip("unsupported configuration")
elif B % 4 != 0 and elem_dtype in [DTYPE_FP6_E2M3, DTYPE_FP6_E3M2]:
pytest.skip("unsupported configuration")
Expand All @@ -422,7 +421,7 @@ def test_transpose(elem_dtype, fp4_triton):
"""
Verify that transposing an MX tensor works
"""
if elem_dtype != DTYPE_FP4 and fp4_triton:
if elem_dtype != torch.float4_e2m1fn_x2 and fp4_triton:
pytest.skip("unsupported configuration")

M, K = 128, 256
Expand Down
16 changes: 8 additions & 8 deletions torchao/prototype/mx_formats/README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# MX training and inference with native PyTorch

This is a workflow for e2e training and inference with MX dtypes from the [MX OCP spec](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf)
This is a workflow for e2e training and inference with MX dtypes from the [MX OCP spec](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf)
in native PyTorch. We are currently in prototype and are actively working on optimizing these workflows on the NVIDIA B200 hardware.

## Overall status
Expand Down Expand Up @@ -34,8 +34,8 @@ gemm_kernel_choice = MXGemmKernelChoice.CUBLAS

m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda()
config = MXLinearConfig(
elem_dtype=torch.float8_e4m3fn,
block_size=32,
elem_dtype=torch.float8_e4m3fn,
block_size=32,
gemm_kernel_choice=gemm_kernel_choice,
)
quantize_(m, config)
Expand All @@ -55,8 +55,8 @@ from torchao.prototype.mx_formats import MXInferenceLinearConfig, MXGemmKernelCh
m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda()
gemm_kernel_choice = MXGemmKernelChoice.CUBLAS
config = MXInferenceLinearConfig(
elem_dtype=torch.float8_e4m3fn,
block_size=32,
elem_dtype=torch.float8_e4m3fn,
block_size=32,
gemm_kernel_choice=gemm_kernel_choice,
)
quantize_(m, config=config)
Expand All @@ -71,10 +71,10 @@ only `torch.float32` and `torch.bfloat16` are supported as high precision format
```python
from torchao.prototype.mx_formats.mx_tensor import MXTensor
# Note: MX int8 is not implemented yet
from torchao.prototype.mx_formats.constants import DTYPE_FP6_E2M3, DTYPE_FP6_E3M2, DTYPE_FP4
from torchao.prototype.mx_formats.constants import DTYPE_FP6_E2M3, DTYPE_FP6_E3M2
x = torch.randn(32, 32, device='cuda')

# elem_dtype can be torch.float8_e4m3fn, torch.float8_e5m2, DTYPE_FP6_E2M3, DTYPE_FP6_E3M2, DTYPE_FP4
# elem_dtype can be torch.float8_e4m3fn, torch.float8_e5m2, DTYPE_FP6_E2M3, DTYPE_FP6_E3M2, torch.float4_e2m1fn_x2
elem_dtype = torch.float8_e4m3fn

# high precision to MX, block size defaults to 32
Expand All @@ -88,7 +88,7 @@ x_hp = x_mx.to_dtype(torch.float)

## mxfp8 gemm

On NVIDIA B200 machines, we use the cuBLAS mxfp8 gemm exposed via the `torch._scaled_mm` op.
On NVIDIA B200 machines, we use the cuBLAS mxfp8 gemm exposed via the `torch._scaled_mm` op.
We observe a speedup of **2x to 3x** vs the bf16 baseline on common shapes. To reproduce this
on supported hardware, you can run the following command:

Expand Down
4 changes: 2 additions & 2 deletions torchao/prototype/mx_formats/benchmarks/bench_qdq.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

from torchao.prototype.mx_formats import config
from torchao.prototype.mx_formats.constants import ( # noqa: E501
DTYPE_FP4,
SUPPORTED_ELEM_DTYPES,
)
from torchao.prototype.mx_formats.mx_tensor import MXTensor
Expand All @@ -44,7 +43,8 @@ def run(profile_folder: Optional[str] = None):
)

if (
elem_dtype != DTYPE_FP4 and use_fp4_custom_triton_dequant_kernel # noqa: E501
elem_dtype != torch.float4_e2m1fn_x2
and use_fp4_custom_triton_dequant_kernel # noqa: E501
):
# custom_triton_kernels only works for fp4
continue
Expand Down
8 changes: 4 additions & 4 deletions torchao/prototype/mx_formats/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

from torchao.core.config import AOBaseConfig
from torchao.prototype.mx_formats.constants import (
DTYPE_FP4,
DTYPE_FP6_E2M3,
DTYPE_FP6_E3M2,
DTYPE_TO_SHORT_STR,
Expand Down Expand Up @@ -53,7 +52,7 @@ def _validate_gemm_kernel_choice(gemm_kernel_choice, block_size, elem_dtype):
assert block_size == 32, (
f"block_size must be 32 to use the CUTLASS MX gemm kernels, got {block_size}"
)
valid_dtypes = [torch.float8_e4m3fn, DTYPE_FP4]
valid_dtypes = [torch.float8_e4m3fn, torch.float4_e2m1fn_x2]
assert elem_dtype in valid_dtypes, (
f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {elem_dtype}"
)
Expand Down Expand Up @@ -126,10 +125,11 @@ def from_recipe_name(
elif recipe_name is MXLinearRecipeName.MXFP8_CUBLAS:
return MXLinearConfig(gemm_kernel_choice=MXGemmKernelChoice.CUBLAS)
elif recipe_name is MXLinearRecipeName.MXFP4_EMULATED:
return MXLinearConfig(elem_dtype=DTYPE_FP4)
return MXLinearConfig(elem_dtype=torch.float4_e2m1fn_x2)
elif recipe_name is MXLinearRecipeName.MXFP4_CUTLASS:
return MXLinearConfig(
elem_dtype=DTYPE_FP4, gemm_kernel_choice=MXGemmKernelChoice.CUTLASS
elem_dtype=torch.float4_e2m1fn_x2,
gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,
)
else:
raise AssertionError(f"unknown recipe_name {recipe_name}")
Expand Down
12 changes: 9 additions & 3 deletions torchao/prototype/mx_formats/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
# LICENSE file in the root directory of this source tree.
import torch

from torchao.utils import TORCH_VERSION_AT_LEAST_2_8

# This is conceptually an enum of non-core dtypes
# TODO(future PR): change to a cleaner way to represent this without
# regressing torch.compile and while keeping things readable.
DTYPE_FP4 = "fp4_e2m1"
DTYPE_FP6_E3M2 = "fp6_e3m2"
DTYPE_FP6_E2M3 = "fp6_e2m3"

Expand All @@ -19,16 +20,21 @@
torch.float8_e5m2,
DTYPE_FP6_E2M3,
DTYPE_FP6_E3M2,
DTYPE_FP4,
]
SUPPORTED_ELEM_DTYPES = (
SUPPORTED_ELEM_DTYPES + [torch.float4_e2m1fn_x2]
if TORCH_VERSION_AT_LEAST_2_8
else SUPPORTED_ELEM_DTYPES
)

DTYPE_TO_SHORT_STR = {
torch.float8_e4m3fn: "f8e4m3",
torch.float8_e5m2: "f8e5m2",
DTYPE_FP6_E2M3: "f6e2m3",
DTYPE_FP6_E3M2: "f6e3m2",
DTYPE_FP4: "f4e2m1",
}
if TORCH_VERSION_AT_LEAST_2_8:
DTYPE_TO_SHORT_STR[torch.float4_e2m1fn_x2] = "f4e2m1"

F8E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max # 448.0
F8E5M2_MAX = torch.finfo(torch.float8_e5m2).max # 57344.0
Expand Down
5 changes: 2 additions & 3 deletions torchao/prototype/mx_formats/fp_format_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import torch

from torchao.prototype.mx_formats.constants import (
DTYPE_FP4,
DTYPE_FP6_E2M3,
DTYPE_FP6_E3M2,
)
Expand Down Expand Up @@ -494,7 +493,7 @@ def run(dtype):
headers = ["orig_val", "formula", "s_enc", "e_enc", "m_enc", "note"]
results = []

if dtype == DTYPE_FP4:
if dtype == torch.float4_e2m1fn_x2:
results = float4_e2m1_interesting_values
elif dtype == DTYPE_FP6_E3M2:
results = float6_e3m2_interesting_values
Expand Down Expand Up @@ -539,6 +538,6 @@ def run(dtype):
torch.float8_e5m2,
DTYPE_FP6_E3M2,
DTYPE_FP6_E2M3,
DTYPE_FP4,
torch.float4_e2m1fn_x2,
):
run(dtype)
Loading