# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# mypy: ignore-errors
# This test takes a long time to run
import copy
import gc
import tempfile
import unittest
import warnings
from pathlib import Path

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)
from torch.testing._internal import common_utils
from torch.testing._internal.common_quantization import TestHelperModules
from torch.testing._internal.common_utils import TestCase

from torchao import quantize_
from torchao._models.llama.model import Transformer, prepare_inputs_for_model
from torchao._models.llama.tokenizer import get_tokenizer
from torchao.dtypes import (
    AffineQuantizedTensor,
    Int4CPULayout,
    Int4XPULayout,
    PlainLayout,
    TensorCoreTiledLayout,
)
from torchao.quantization import (
    Int4TilePackedTo4dTensor,
    IntxUnpackedToInt8Tensor,
    LinearActivationQuantizedTensor,
    PerGroup,
)
from torchao.quantization.quant_api import (
    Float8DynamicActivationFloat8WeightConfig,
    Float8StaticActivationFloat8WeightConfig,
    Float8WeightOnlyConfig,
    FPXWeightOnlyConfig,
    GemliteUIntXWeightOnlyConfig,
    Int4DynamicActivationInt4WeightConfig,
    Int4WeightOnlyConfig,
    Int8DynamicActivationInt4WeightConfig,
    Int8DynamicActivationInt8WeightConfig,
    Int8DynamicActivationIntxWeightConfig,
    Int8WeightOnlyConfig,
    IntxWeightOnlyConfig,
    ModuleFqnToConfig,
    Quantizer,
    TwoStepQuantizer,
    UIntXWeightOnlyConfig,
    _replace_with_custom_fn_if_matches_filter,
)
from torchao.quantization.quant_primitives import MappingType
from torchao.quantization.utils import compute_error
from torchao.testing.utils import skip_if_rocm
from torchao.utils import (
    is_sm_at_least_89,
    is_sm_at_least_90,
    torch_version_at_least,
    unwrap_tensor_subclass,
)

# Probe for the optional `gemlite` package. `has_gemlite` is consulted by
# test_workflow_e2e_numerics to bail out of GemliteUIntXWeightOnlyConfig cases
# when the package is missing.
try:
    import gemlite  # noqa: F401

    has_gemlite = True
except ModuleNotFoundError:
    has_gemlite = False


def dynamic_quant(model, example_inputs):
    """Export ``model`` and run the full PT2E XNNPACK dynamic-quant flow on it.

    Args:
        model: eager ``torch.nn.Module`` to capture.
        example_inputs: tuple of positional example inputs used for export.

    Returns:
        The converted graph module produced by ``convert_pt2e``.
    """
    graph_module = torch.export.export(model, example_inputs, strict=True).module()
    dynamic_config = get_symmetric_quantization_config(is_dynamic=True)
    xnnpack = XNNPACKQuantizer().set_global(dynamic_config)
    observed = prepare_pt2e(graph_module, xnnpack)
    return convert_pt2e(observed)


def capture_and_prepare(model, example_inputs):
    """Export ``model``, insert XNNPACK dynamic-quant observers, and calibrate once.

    Args:
        model: eager ``torch.nn.Module`` to capture.
        example_inputs: tuple of positional inputs, used both for export and
            for the single calibration forward pass.

    Returns:
        The prepared (observed) graph module, ready to be passed to
        ``convert_pt2e``.
    """
    # prepare_pt2e operates on a GraphModule, not on the ExportedProgram that
    # torch.export.export returns (which is also not directly callable), so
    # unwrap it with .module() exactly as dynamic_quant does.
    m = torch.export.export(model, example_inputs, strict=True).module()
    quantizer = XNNPACKQuantizer().set_global(
        get_symmetric_quantization_config(is_dynamic=True)
    )
    m = prepare_pt2e(m, quantizer)
    # TODO: we can run the weight observer in convert_pt2e so that user don't need to run this
    m(*example_inputs)
    return m


class XNNPackDynamicQuantizer(TwoStepQuantizer):
    """Two-step quantizer that applies the PT2E XNNPACK dynamic-quant flow to
    each ``torch.nn.Linear`` of a model individually.

    ``prepare`` replaces every linear with the module returned by
    ``capture_and_prepare``; ``convert`` then runs ``convert_pt2e`` on every
    ``torch.fx.GraphModule`` submodule.
    """

    def prepare(self, model: torch.nn.Module) -> torch.nn.Module:
        """Replace each ``nn.Linear`` in ``model`` with a prepared module and return ``model``."""
        _replace_with_custom_fn_if_matches_filter(
            model,
            # torch.export.export requires its example args to be a tuple; the
            # trailing comma matters, since a parenthesized tensor alone is
            # just the tensor.
            lambda linear_mod: capture_and_prepare(
                linear_mod, (torch.randn(1, linear_mod.in_features),)
            ),
            lambda mod, fqn: isinstance(mod, torch.nn.Linear),
        )
        return model

    def convert(self, model: torch.nn.Module) -> torch.nn.Module:
        """Run ``convert_pt2e`` on every ``torch.fx.GraphModule`` in ``model`` and return ``model``."""
        _replace_with_custom_fn_if_matches_filter(
            model,
            lambda linear_mod: convert_pt2e(linear_mod),
            lambda mod, fqn: isinstance(mod, torch.fx.GraphModule),
        )
        return model


class TorchCompileDynamicQuantizer(Quantizer):
    """Single-step quantizer applying int8 dynamic-activation / int8-weight
    quantization through ``quantize_``."""

    def quantize(self, model: torch.nn.Module) -> torch.nn.Module:
        """Quantize ``model`` with ``Int8DynamicActivationInt8WeightConfig`` and return it."""
        int8_config = Int8DynamicActivationInt8WeightConfig()
        quantize_(model, int8_config)
        # quantize_ updates ``model`` itself; the same object is handed back.
        return model


class ToyLinearModel(torch.nn.Module):
    """Two stacked linear layers, ``m -> n`` followed by ``n -> k``.

    Both layers are cast to float32 and share the same ``bias`` setting.
    """

    def __init__(self, m=64, n=32, k=64, bias=False):
        super().__init__()
        # Built left to right, so linear1 is created (and registered) first.
        self.linear1, self.linear2 = (
            torch.nn.Linear(d_in, d_out, bias=bias).to(torch.float)
            for d_in, d_out in ((m, n), (n, k))
        )

    def example_inputs(self, batch_size=1, dtype=torch.float, device="cpu"):
        """Return a 1-tuple holding a random ``(batch_size, m)`` input tensor."""
        feature_dim = self.linear1.in_features
        sample = torch.randn(batch_size, feature_dim, dtype=dtype, device=device)
        return (sample,)

    def forward(self, x):
        return self.linear2(self.linear1(x))


def _get_ref_change_linear_weights_to_woqtensors(deprecated_tenosr_subclass):
    def _ref_change_linear_weights_to_woqtensors(model, filter_fn=None, **kwargs):
        """
        The deprecated implementation for weight only quant API, used as a reference for
        numerics and performance
        """
        from torchao.quantization.quant_api import _get_subclass_inserter, _is_linear

        filter_fn = kwargs.pop("filter_fn", _is_linear)

        _replace_with_custom_fn_if_matches_filter(
            model,
            _get_subclass_inserter(
                deprecated_tenosr_subclass, enable_parametrization=True, **kwargs
            ),
            filter_fn,
        )

    return _ref_change_linear_weights_to_woqtensors


class TestQuantFlow(TestCase):
    """End-to-end tests for torchao quantization flows.

    Covers ``quantize_`` with the various configs, the legacy quantizer
    classes, ``ModuleFqnToConfig`` routing (including regex keys), and
    save/load and device movement of quantized models.
    """

    GPU_DEVICES = (["cuda"] if torch.cuda.is_available() else []) + (
        ["xpu"] if torch.xpu.is_available() else []
    )

    def test_dynamic_quant_gpu_singleline(self):
        m = ToyLinearModel().eval()
        example_inputs = m.example_inputs()
        quantize_(m, Int8DynamicActivationInt8WeightConfig())
        m(*example_inputs)
        # AssertionError: Expecting input to have dtype torch.float32, but got dtype: torch.float64
        # While executing %choose_qparams_tensor_1 : [num_users=2] = call_function[target=torch.ops.quantized_decomposed.choose_qparams.tensor](args = (%arg0_3, -128, 127, 0.000244140625, torch.int8), kwargs = {})
        # m = torch.compile(m, mode="max-autotune")
        # print(example_inputs[0].dtype)
        # compiled = m(*example_inputs)
        # torch.testing.assert_close(quantized, compiled, atol=0, rtol=0)

    @unittest.skip("skipping for now due to torch.compile error")
    def test_dynamic_quant_gpu_unified_api_unified_impl(self):
        quantizer = XNNPackDynamicQuantizer()
        m = ToyLinearModel().eval()
        example_inputs = m.example_inputs()
        m = quantizer.prepare(m)
        m = quantizer.convert(m)
        quantized = m(*example_inputs)
        # AssertionError: Expecting input to have dtype torch.float32, but got dtype: torch.float64
        # While executing %choose_qparams_tensor_1 : [num_users=2] = call_function[target=torch.ops.quantized_decomposed.choose_qparams.tensor](args = (%arg0_3, -128, 127, 0.000244140625, torch.int8), kwargs = {})
        m = torch.compile(m, mode="max-autotune")
        # print(example_inputs[0].dtype)
        compiled = m(*example_inputs)
        torch.testing.assert_close(quantized, compiled, atol=0, rtol=0)

    @unittest.skip(
        "FAILED test/quantization/test_quant_api.py::TestQuantFlow::test_dynamic_quant_gpu_unified_api_eager_mode_impl - AssertionError: Tensor-likes are not equal!"
    )
    def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self):
        quantizer = TorchCompileDynamicQuantizer()
        m = ToyLinearModel().eval()
        example_inputs = m.example_inputs()
        m = quantizer.quantize(m)
        quantized = m(*example_inputs)
        m = torch.compile(m, mode="max-autotune")
        compiled = m(*example_inputs)
        torch.testing.assert_close(quantized, compiled, atol=0, rtol=0)

    @unittest.skipIf(not torch.xpu.is_available(), "Need XPU available")
    @unittest.skipIf(not torch_version_at_least("2.8.0"), "only works for torch 2.8+")
    def test_int4_wo_quant_save_load(self):
        """Quantize on CPU, round-trip the state_dict through disk, run on XPU."""
        m = ToyLinearModel().eval().cpu()

        def api(model):
            quantize_(model, Int4WeightOnlyConfig(layout=Int4XPULayout(), version=1))
            unwrap_tensor_subclass(model)

        api(m)

        example_inputs = m.example_inputs()
        ref = m(*example_inputs)
        with tempfile.NamedTemporaryFile() as f:
            torch.save(m.state_dict(), f)
            f.seek(0)
            state_dict = torch.load(f)

        m2 = ToyLinearModel().eval().cpu()
        api(m2)

        m2.load_state_dict(state_dict)
        m2 = m2.to(device="xpu")
        example_inputs = map(lambda x: x.xpu(), example_inputs)
        res = m2(*example_inputs)

        torch.testing.assert_close(ref, res.cpu())

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_int8_wo_quant_save_load(self):
        """Quantize on CPU, round-trip the state_dict through disk, run on CUDA."""
        m = ToyLinearModel().eval().cpu()

        def api(model):
            quantize_(model, Int8WeightOnlyConfig())
            unwrap_tensor_subclass(model)

        api(m)

        example_inputs = m.example_inputs()
        ref = m(*example_inputs)
        with tempfile.NamedTemporaryFile() as f:
            torch.save(m.state_dict(), f)
            f.seek(0)
            state_dict = torch.load(f)

        m2 = ToyLinearModel().eval().cpu()
        api(m2)

        m2.load_state_dict(state_dict)
        m2 = m2.to(device="cuda")
        example_inputs = map(lambda x: x.cuda(), example_inputs)
        res = m2(*example_inputs)

        # TODO: figure out why ROCm has a larger error
        atol, rtol = (1e-2, 1e-2) if torch.version.hip else (None, None)
        torch.testing.assert_close(ref, res.cpu(), atol=atol, rtol=rtol)

    def test_8da4w_quantizer(self):
        from torchao.quantization.linear_quant_modules import Int8DynActInt4WeightLinear
        from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

        quantizer = Int8DynActInt4WeightQuantizer(groupsize=32)
        m = ToyLinearModel().eval()
        example_inputs = m.example_inputs()
        m = quantizer.quantize(m)
        assert isinstance(m.linear1, Int8DynActInt4WeightLinear)
        assert isinstance(m.linear2, Int8DynActInt4WeightLinear)
        m(*example_inputs)

    def test_8da4w_quantizer_linear_bias(self):
        from torchao.quantization.linear_quant_modules import Int8DynActInt4WeightLinear
        from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

        quantizer = Int8DynActInt4WeightQuantizer(groupsize=32)
        m = ToyLinearModel(bias=True).eval()
        example_inputs = m.example_inputs()
        m = quantizer.quantize(m)
        assert isinstance(m.linear1, Int8DynActInt4WeightLinear)
        assert isinstance(m.linear2, Int8DynActInt4WeightLinear)
        m(*example_inputs)

    @unittest.skip("skipping until we get checkpoints for gpt-fast")
    def test_quantizer_int4_weight_only(self):
        from torchao._models._eval import TransformerEvalWrapper
        from torchao.quantization.linear_quant_modules import Int4WeightOnlyQuantizer

        precision = torch.bfloat16
        device = "cuda"
        checkpoint_path = Path("../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
        model = Transformer.from_name(checkpoint_path.parent.name)
        checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
        model.load_state_dict(checkpoint, assign=True)
        model = model.to(dtype=precision, device=device)
        model.eval()
        tokenizer_path = checkpoint_path.parent / "tokenizer.model"
        assert tokenizer_path.is_file(), tokenizer_path
        tokenizer = get_tokenizer(  # pyre-ignore[28]
            tokenizer_path,
            "Llama-2-7b-chat-hf",
        )
        groupsize = 64
        quantizer = Int4WeightOnlyQuantizer(
            groupsize,
        )
        model = quantizer.quantize(model).cuda()
        result = TransformerEvalWrapper(
            model,
            tokenizer,
            model.config.block_size,
            prepare_inputs_for_model,
            device,
        ).run_eval(
            ["wikitext"],
            1,
        )
        assert result["results"]["wikitext"]["word_perplexity,none"] < 8.24, (
            f"accuracy regressed from 8.23 to {result['results']['wikitext']['word_perplexity,none']}"
        )

    @unittest.skip("skipping until we get checkpoints for gpt-fast")
    def test_eval_wrapper(self):
        from torchao._models._eval import TransformerEvalWrapper

        precision = torch.bfloat16
        device = "cuda"
        checkpoint_path = Path("../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
        model = Transformer.from_name(checkpoint_path.parent.name)
        checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
        model.load_state_dict(checkpoint, assign=True)
        model = model.to(dtype=precision, device=device)
        model.eval()
        tokenizer_path = checkpoint_path.parent / "tokenizer.model"
        assert tokenizer_path.is_file(), tokenizer_path
        tokenizer = get_tokenizer(  # pyre-ignore[28]
            tokenizer_path,
            "Llama-2-7b-chat-hf",
        )
        result = TransformerEvalWrapper(
            model,
            tokenizer,
            model.config.block_size,
            prepare_inputs_for_model,
            device,
        ).run_eval(
            ["wikitext"],
            1,
        )
        assert result["results"]["wikitext"]["word_perplexity,none"] < 7.77, (
            f"accuracy regressed from 7.76 to {result['results']['wikitext']['word_perplexity,none']}"
        )

    # EVAL IS CURRENTLY BROKEN FOR LLAMA 3, VERY LOW ACCURACY
    @unittest.skip("skipping until we get checkpoints for gpt-fast")
    def test_eval_wrapper_llama3(self):
        from torchao._models._eval import TransformerEvalWrapper

        precision = torch.bfloat16
        device = "cuda"
        checkpoint_path = Path(
            ".../gpt-fast/checkpoints/meta-llama/Meta-Llama-3-8B/model.pth"
        )
        model = Transformer.from_name(checkpoint_path.parent.name)
        checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
        model.load_state_dict(checkpoint, assign=True)
        model = model.to(dtype=precision, device=device)
        model.eval()
        tokenizer_path = checkpoint_path.parent / "tokenizer.model"
        assert tokenizer_path.is_file(), tokenizer_path
        tokenizer = get_tokenizer(  # pyre-ignore[28]
            tokenizer_path,
            "Meta-Llama-3-8B",
        )
        result = TransformerEvalWrapper(
            model,
            tokenizer,
            model.config.block_size,
            prepare_inputs_for_model,
            device,
        ).run_eval(
            ["wikitext"],
            1,
        )
        assert result["results"]["wikitext"]["word_perplexity,none"] < 8.24, (
            f"accuracy regressed from 8.23 to {result['results']['wikitext']['word_perplexity,none']}"
        )

    # TODO: move to a separate test file
    @common_utils.parametrize(
        "mapping_type", [MappingType.SYMMETRIC, MappingType.SYMMETRIC_NO_CLIPPING_ERR]
    )
    def test_quantized_tensor_subclass_8da4w(self, mapping_type):
        """The tensor-subclass 8da4w path must match the module-swap
        Int8DynActInt4WeightQuantizer reference bit-exactly."""
        group_size = 32
        m = ToyLinearModel().eval()
        m_copy = copy.deepcopy(m)
        example_inputs = m.example_inputs()
        quantize_(
            m,
            Int8DynamicActivationInt4WeightConfig(
                group_size=group_size, mapping_type=mapping_type
            ),
        )

        assert isinstance(m.linear1.weight, LinearActivationQuantizedTensor)
        assert isinstance(m.linear2.weight, LinearActivationQuantizedTensor)
        assert isinstance(
            m.linear1.weight.original_weight_tensor, AffineQuantizedTensor
        )
        assert isinstance(
            m.linear2.weight.original_weight_tensor, AffineQuantizedTensor
        )

        # reference
        from torchao.quantization.linear_quant_modules import Int8DynActInt4WeightLinear
        from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

        quantizer = Int8DynActInt4WeightQuantizer(
            groupsize=group_size, mapping_type=mapping_type
        )
        m_copy = quantizer.quantize(m_copy)
        assert isinstance(m_copy.linear1, Int8DynActInt4WeightLinear)
        assert isinstance(m_copy.linear2, Int8DynActInt4WeightLinear)

        res = m(*example_inputs)
        ref = m_copy(*example_inputs)
        self.assertTrue(torch.equal(res, ref))

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_quantized_tensor_subclass_save_load(self):
        m = ToyLinearModel().eval().to(torch.bfloat16)
        m_copy = copy.deepcopy(m)
        example_inputs = m.example_inputs(dtype=torch.bfloat16)

        quantize_(m, Int8WeightOnlyConfig())
        ref = m(*example_inputs)
        with tempfile.NamedTemporaryFile() as f:
            torch.save(m.state_dict(), f)
            f.seek(0)
            state_dict = torch.load(f)

        # assign=True so the unquantized copy adopts the quantized tensors
        m_copy.load_state_dict(state_dict, assign=True)

        res = m_copy(*example_inputs)
        self.assertEqual(res, ref)

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_int8wo_quantized_model_to_device(self):
        m = ToyLinearModel().eval().to(torch.bfloat16)
        example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cpu")

        quantize_(m, Int8WeightOnlyConfig())
        ref = m(*example_inputs)

        example_inputs_cuda = (example_inputs[0].to("cuda"),)
        m.to(device="cuda")
        cuda_res = m(*example_inputs_cuda)
        self.assertEqual(cuda_res.cpu(), ref)

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_quantized_tensor_subclass_save_load_map_location(self):
        m = ToyLinearModel().eval().to(dtype=torch.bfloat16, device="cuda")
        example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cuda")

        quantize_(m, Int8WeightOnlyConfig())
        ref = m(*example_inputs)
        with tempfile.NamedTemporaryFile() as f:
            torch.save(m.state_dict(), f)
            f.seek(0)
            state_dict = torch.load(f.name, map_location="cpu", mmap=True)

        # Build the skeleton on the meta device so no real weights are allocated
        # before the checkpoint is assigned in.
        with torch.device("meta"):
            m_copy = ToyLinearModel().eval()

        m_copy.load_state_dict(state_dict, assign=True)
        m_copy.to(dtype=torch.bfloat16, device="cuda")

        res = m_copy(*example_inputs)
        self.assertEqual(res, ref)

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_quantized_model_streaming(self):
        """Passing ``device=`` to quantize_ should peak at less CUDA memory than
        moving the whole model to CUDA first and then quantizing."""

        def reset_memory():
            gc.collect()
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()

        reset_memory()
        m = ToyLinearModel()
        quantize_(m.to(device="cuda"), Int8WeightOnlyConfig())
        memory_baseline = torch.cuda.max_memory_allocated()

        del m
        reset_memory()
        m = ToyLinearModel()
        quantize_(m, Int8WeightOnlyConfig(), device="cuda")
        memory_streaming = torch.cuda.max_memory_allocated()

        for param in m.parameters():
            assert param.is_cuda
        self.assertLess(memory_streaming, memory_baseline)

    @common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half])
    @common_utils.parametrize("x_dim", [2, 3])
    @common_utils.parametrize("use_hqq", [True, False])
    def test_int4wo_cpu(self, dtype, x_dim, use_hqq):
        device = "cpu"
        m = ToyLinearModel().eval().to(dtype).to(device)
        example_inputs = m.example_inputs(dtype=dtype, device=device)
        if x_dim == 3:
            example_inputs = (example_inputs[0].unsqueeze(0),)

        with torch.no_grad():
            quantize_(
                m,
                Int4WeightOnlyConfig(
                    group_size=32, layout=Int4CPULayout(), use_hqq=use_hqq, version=1
                ),
            )
            # ensure the expected op is in the code
            _, code = torch._inductor.utils.run_and_get_code(
                torch.compile(m, fullgraph=True, dynamic=True),
                *example_inputs,
            )
            assert "_weight_int4pack_mm_for_cpu" in code[0]
            assert "aten.mm.default" not in code[0]

    # TODO(#1690): move to new config names
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @common_utils.parametrize(
        "config",
        [
            Int4WeightOnlyConfig(version=1),
            Float8WeightOnlyConfig(),
            Float8DynamicActivationFloat8WeightConfig(),
            Float8StaticActivationFloat8WeightConfig(scale=torch.tensor([1.0])),
            Int4DynamicActivationInt4WeightConfig(),
            Int8DynamicActivationInt8WeightConfig(),
            Int8DynamicActivationInt4WeightConfig(),
            Int8WeightOnlyConfig(),
            FPXWeightOnlyConfig(ebits=4, mbits=3),
            GemliteUIntXWeightOnlyConfig(),
            UIntXWeightOnlyConfig(dtype=torch.uint4),
        ],
    )
    @skip_if_rocm("ROCm enablement in progress")
    def test_workflow_e2e_numerics(self, config):
        """
        Simple e2e test of each parametrized quantization config. It compares
        the output of a quantized single-linear model against a bfloat16
        baseline (float16 for gemlite) and requires an SQNR of at least 16.5.
        """
        # Use self.skipTest (which raises SkipTest). ``return unittest.skip(...)``
        # only returns a decorator, so the test would silently "pass" instead
        # of being reported as skipped.
        if (
            isinstance(
                config,
                (
                    Float8DynamicActivationFloat8WeightConfig,
                    Float8StaticActivationFloat8WeightConfig,
                ),
            )
            and not is_sm_at_least_89()
        ):
            self.skipTest("requires CUDA capability 8.9 or greater")
        elif (
            isinstance(config, Int4DynamicActivationInt4WeightConfig)
            and is_sm_at_least_90()
        ):
            self.skipTest("only supported on CUDA capability 8.9, not greater")
        elif isinstance(config, GemliteUIntXWeightOnlyConfig) and not has_gemlite:
            self.skipTest("gemlite not available")

        # scale has to be moved to cuda here because the parametrization init
        # code happens before gating for cuda availability
        if isinstance(config, Float8StaticActivationFloat8WeightConfig):
            config.scale = config.scale.to("cuda")

        dtype = torch.bfloat16
        if isinstance(config, GemliteUIntXWeightOnlyConfig):
            dtype = torch.float16

        # set up inputs
        x = torch.randn(128, 128, device="cuda", dtype=dtype)
        # TODO(future): model in float32 leads to error: https://gist.github.com/vkuzo/63b3bcd7818393021a6e3fb4ccf3c469
        # is that expected?
        m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().to(dtype)
        m_q = copy.deepcopy(m_ref)

        # quantize
        quantize_(m_q, config)

        with torch.no_grad():
            y_ref = m_ref(x)
            y_q = m_q(x)

        sqnr = compute_error(y_ref, y_q)
        assert sqnr >= 16.5, f"SQNR {sqnr} is too low"

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_module_fqn_to_config_default(self):
        """Modules without an explicit entry fall back to the ``_default`` config."""
        config1 = Int4WeightOnlyConfig(group_size=32, version=1)
        config2 = Int8WeightOnlyConfig()
        config = ModuleFqnToConfig({"_default": config1, "linear2": config2})
        model = ToyLinearModel().cuda().to(dtype=torch.bfloat16)
        example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16)
        quantize_(model, config)
        model(*example_inputs)
        assert isinstance(model.linear1.weight, AffineQuantizedTensor)
        assert isinstance(model.linear1.weight._layout, TensorCoreTiledLayout)
        assert isinstance(model.linear2.weight, AffineQuantizedTensor)
        assert isinstance(model.linear2.weight._layout, PlainLayout)

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_module_fqn_to_config_module_name(self):
        config1 = Int4WeightOnlyConfig(group_size=32, version=1)
        config2 = Int8WeightOnlyConfig()
        config = ModuleFqnToConfig({"linear1": config1, "linear2": config2})
        model = ToyLinearModel().cuda().to(dtype=torch.bfloat16)
        example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16)
        quantize_(model, config)
        model(*example_inputs)
        assert isinstance(model.linear1.weight, AffineQuantizedTensor)
        assert isinstance(model.linear1.weight._layout, TensorCoreTiledLayout)
        assert isinstance(model.linear2.weight, AffineQuantizedTensor)
        assert isinstance(model.linear2.weight._layout, PlainLayout)

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_module_fqn_to_config_regex_basic(self):
        config1 = Int4WeightOnlyConfig(
            group_size=32, int4_packing_format="tile_packed_to_4d"
        )
        config = ModuleFqnToConfig({"re:linear.*": config1})
        model = ToyLinearModel().cuda().to(dtype=torch.bfloat16)
        example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16)
        quantize_(model, config)
        model(*example_inputs)
        assert isinstance(model.linear1.weight, Int4TilePackedTo4dTensor)
        assert isinstance(model.linear2.weight, Int4TilePackedTo4dTensor)

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_module_fqn_to_config_regex_precedence(self):
        """Testing that full path config takes precedence over
        regex config in ModuleFqnToConfig
        """
        config1 = Int4WeightOnlyConfig(
            group_size=32, int4_packing_format="tile_packed_to_4d"
        )
        config2 = IntxWeightOnlyConfig()
        config = ModuleFqnToConfig({"linear1": config1, "re:linear.*": config2})
        model = ToyLinearModel().cuda().to(dtype=torch.bfloat16)
        example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16)
        quantize_(model, config)
        model(*example_inputs)
        assert isinstance(model.linear1.weight, Int4TilePackedTo4dTensor)
        assert isinstance(model.linear2.weight, IntxUnpackedToInt8Tensor)

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_module_fqn_to_config_regex_precedence2(self):
        """Testing that full path config takes precedence over
        regex config in ModuleFqnToConfig. The order of `re:linear.*` and
        `linear1` is swapped to make sure that the `linear1` config takes
        precedence even when it comes after `re:linear.*`.
        """
        config1 = Int4WeightOnlyConfig(
            group_size=32, int4_packing_format="tile_packed_to_4d"
        )
        config2 = IntxWeightOnlyConfig()
        config = ModuleFqnToConfig({"re:linear.*": config2, "linear1": config1})
        model = ToyLinearModel().cuda().to(dtype=torch.bfloat16)
        example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16)
        quantize_(model, config)
        model(*example_inputs)
        assert isinstance(model.linear1.weight, Int4TilePackedTo4dTensor)
        assert isinstance(model.linear2.weight, IntxUnpackedToInt8Tensor)

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_module_fqn_to_config_regex_fullmatch(self):
        """Testing that we will only match the fqns that fully
        match the regex
        """

        class M(torch.nn.Module):
            def __init__(self, dtype, device):
                super().__init__()
                self.dtype = dtype
                self.device = device
                self.linear1 = torch.nn.Linear(32, 64, dtype=dtype, device=device)
                self.not_full_match_linear2 = torch.nn.Linear(
                    64, 32, dtype=dtype, device=device
                )
                self.linear3_full_match = torch.nn.Linear(
                    32, 32, dtype=dtype, device=device
                )

            def forward(self, x):
                x = self.linear1(x)
                x = self.not_full_match_linear2(x)
                x = self.linear3_full_match(x)
                # previously a bare ``return`` dropped the output
                return x

            def example_inputs(self):
                return (torch.randn(1, 32, dtype=self.dtype, device=self.device),)

        config1 = Int4WeightOnlyConfig(
            group_size=32, int4_packing_format="tile_packed_to_4d"
        )
        config2 = IntxWeightOnlyConfig()
        config = ModuleFqnToConfig({"re:linear.*": config2, "linear1": config1})
        model = M(dtype=torch.bfloat16, device="cuda")
        example_inputs = model.example_inputs()
        quantize_(model, config)
        model(*example_inputs)
        assert isinstance(model.linear1.weight, Int4TilePackedTo4dTensor)
        # since fqn does not fully match `linear.*`, it should not be quantized
        assert not isinstance(
            model.not_full_match_linear2.weight, IntxUnpackedToInt8Tensor
        )
        # linear3_full_match matches `linear.*`, so should be quantized
        assert isinstance(model.linear3_full_match.weight, IntxUnpackedToInt8Tensor)

    def test_module_fqn_to_config_embedding_linear(self):
        """ModuleFqnToConfig can route embedding and linear modules to
        different configs when filter_fn admits both module types."""
        weight_dtype = torch.int8
        granularity = PerGroup(8)
        mapping_type = MappingType.SYMMETRIC
        embedding_config = IntxWeightOnlyConfig(
            weight_dtype=weight_dtype,
            granularity=granularity,
            mapping_type=mapping_type,
        )
        # example model linear is Linear(16, 8)
        linear_config = Int8DynamicActivationIntxWeightConfig(
            weight_dtype=torch.int4,
            weight_granularity=PerGroup(16),
        )

        config = ModuleFqnToConfig({"emb": embedding_config, "linear": linear_config})
        indices = torch.randint(0, 10, (32,))
        indices = indices.unsqueeze(0)
        example_inputs = (indices,)
        model = TestHelperModules.EmbeddingConvLinearModule().eval()
        model(*example_inputs)
        quantize_(
            model,
            config,
            filter_fn=lambda x, fqn: isinstance(x, torch.nn.Linear)
            or isinstance(x, torch.nn.Embedding),
        )
        model(*example_inputs)

        assert isinstance(model.emb.weight, IntxUnpackedToInt8Tensor)
        assert isinstance(model.linear.weight, IntxUnpackedToInt8Tensor)

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_module_fqn_to_config_skip(self):
        """A ``None`` config entry leaves that module unquantized."""
        config1 = Int4WeightOnlyConfig(group_size=32, version=1)
        config = ModuleFqnToConfig({"_default": config1, "linear2": None})
        model = ToyLinearModel().cuda().to(dtype=torch.bfloat16)
        example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16)
        quantize_(model, config)
        model(*example_inputs)
        assert isinstance(model.linear1.weight, AffineQuantizedTensor)
        assert isinstance(model.linear1.weight._layout, TensorCoreTiledLayout)
        assert not isinstance(model.linear2.weight, AffineQuantizedTensor)

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_int4wo_cuda_serialization(self):
        config = Int4WeightOnlyConfig(group_size=32, version=1)
        model = ToyLinearModel().cuda().to(dtype=torch.bfloat16)
        # quantize in cuda
        quantize_(model, config)
        example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16)
        model(*example_inputs)
        with tempfile.NamedTemporaryFile() as ckpt:
            # save checkpoint in cuda
            torch.save(model.state_dict(), ckpt)
            # load checkpoint on cpu then move checkpoint to cuda
            # This is what torchtune does: https://github.com/pytorch/torchtune/blob/v0.6.1/torchtune/training/checkpointing/_utils.py#L253
            sd = torch.load(ckpt.name, weights_only=False, map_location="cpu")
            for k, v in sd.items():
                sd[k] = v.to("cuda")
            # load state_dict in cuda
            model.load_state_dict(sd, assign=True)

    def test_config_deprecation(self):
        """
        Test that old config functions like `int4_weight_only` trigger deprecation warnings.
        """
        from torchao.quantization import (
            float8_dynamic_activation_float8_weight,
            float8_static_activation_float8_weight,
            float8_weight_only,
            fpx_weight_only,
            gemlite_uintx_weight_only,
            int4_dynamic_activation_int4_weight,
            int4_weight_only,
            int8_dynamic_activation_int4_weight,
            int8_dynamic_activation_int8_weight,
            int8_weight_only,
            uintx_weight_only,
        )

        # Reset deprecation warning state, otherwise we won't log warnings here
        warnings.resetwarnings()

        # Map from deprecated API to the positional args needed to instantiate it.
        # Each value must be a tuple: ``(torch.randn(3))`` without the trailing
        # comma is a bare tensor and would be spread into 3 positional args.
        deprecated_apis_to_args = {
            float8_dynamic_activation_float8_weight: (),
            float8_static_activation_float8_weight: (torch.randn(3),),
            float8_weight_only: (),
            fpx_weight_only: (3, 2),
            gemlite_uintx_weight_only: (),
            int4_dynamic_activation_int4_weight: (),
            int4_weight_only: (),
            int8_dynamic_activation_int4_weight: (),
            int8_dynamic_activation_int8_weight: (),
            int8_weight_only: (),
            uintx_weight_only: (torch.uint4,),
        }

        with warnings.catch_warnings(record=True) as _warnings:
            # Call each deprecated API twice
            for cls, args in deprecated_apis_to_args.items():
                cls(*args)
                cls(*args)

            # Each API should warn only once, even though it was called twice
            self.assertEqual(len(_warnings), len(deprecated_apis_to_args))
            for w in _warnings:
                self.assertIn(
                    "is deprecated and will be removed in a future release",
                    str(w.message),
                )


# Expand the @common_utils.parametrize-decorated methods of TestQuantFlow into
# concrete, individually named test cases.
common_utils.instantiate_parametrized_tests(TestQuantFlow)


# Allow running this file directly with `python <file>`.
if __name__ == "__main__":
    unittest.main()
