diff --git a/generate.py b/generate.py
index 6d359e0..fe6a5c3 100755
--- a/generate.py
+++ b/generate.py
@@ -1,10 +1,16 @@
+# Modified generate.py with added quantization support
+
 import torch
-from typing import Optional
+from typing import Optional, List
+from lite_llama.utils.prompt_templates import get_prompter, get_image_token
+from lite_llama.generate_stream import GenerateStreamText
+from lite_llama.utils.image_process import vis_images
+
+import warnings
+
 warnings.filterwarnings("ignore", category=UserWarning, module="torch._utils")
 from lite_llama.utils.common import get_gpu_memory, detect_device, count_tokens, get_model_type
-from lite_llama.utils.prompt_templates import get_prompter
-from lite_llama.generate_stream import GenerateStreamText  # 导入 GenerateText 类
-import warnings
+from lite_llama.llava_generate_stream import LlavaGeneratorStream
 import sys, os, time
 from pathlib import Path
@@ -13,16 +19,20 @@ wd = Path(__file__).parent.parent.resolve()
 sys.path.append(str(wd))
 import psutil
+from lite_llama.utils.logger import log
+
+# New imports for quantization support
+from lite_llama.quantization.quant_manager import quantization_manager, QuantizationType
 process = psutil.Process(os.getpid())
-def report_resource_usage(ram_before, vram_before, gpu_type) -> None:
+def report_resource_usage(ram_before, vram_before) -> None:
     end_time = time.time()
     ram_after = process.memory_info().rss
-    vram_after = get_gpu_memory(gpu_type)
+    vram_after = get_gpu_memory(detect_device())
-    ram_used = (ram_after - ram_before) / (1024**3)  # Bytes to GB
+    ram_used = (ram_after - ram_before) / (1024 ** 3)  # Bytes to GB
     if vram_before is not None and vram_after is not None:
         vram_used = vram_after - vram_before
@@ -30,57 +40,57 @@ def report_resource_usage(ram_before, vram_before, gpu_type) -> None:
     else:
         vram_text = "Unavailable"
-    print(f"CPU RAM Used: {ram_used:.2f} GB")
-    print(f"GPU VRAM Used: {vram_text}")
-
-
-def main(
-    prompt: str = "Hello, my name is",
-    *,
-    temperature: float = 0.6,
-    top_p: float = 0.9,
-    max_seq_len: int = 2048,
-    max_gpu_num_blocks=40960,
-    max_gen_len: Optional[int] = 1024,
-    load_model: bool = True,
-    compiled_model: bool = False,
-    triton_weight: bool = True,
-    gpu_type: str = "nvidia",
-    checkpoint_path: Path = Path("checkpoints/lit-llama/7B/"),
-    quantize: Optional[str] = None,
+    log.info(f"CPU RAM Used: {ram_used:.2f} GB")
+    log.info(f"GPU VRAM Used: {vram_text}")
+
+
+def generate_llama(
+    prompt: str = "Hello, my name is",
+    quantization: Optional[str] = None,  # new parameter
+    *,
+    temperature: float = 0.6,
+    top_p: float = 0.9,
+    max_seq_len: int = 2048,
+    max_gpu_num_blocks=40960,
+    max_gen_len: Optional[int] = 1024,
+    compiled_model: bool = False,
+    gpu_type: str = "nvidia",
+    checkpoint_path: Path = Path("checkpoints/lit-llama/7B/"),
 ):
-    device = "cuda" if torch.cuda.is_available() else "cpu"
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
     assert checkpoint_path.is_dir(), checkpoint_path
     checkpoint_path = str(checkpoint_path)
+    # Detect the quantization type of the checkpoint
+    if quantization is None:
+        quantization = quantization_manager.detect_quantization_type(checkpoint_path)
+        if quantization != QuantizationType.NONE:
+            log.info(f"Auto-detected quantization type: {quantization}")
+
     if max_seq_len <= 1024:
         short_prompt = True
     else:
         short_prompt = False
-    model_prompter = get_prompter(
-        get_model_type(checkpoint_path), checkpoint_path, short_prompt
-    )
+    model_prompter = get_prompter(get_model_type(checkpoint_path), checkpoint_path, short_prompt)
+    # Start resource tracking
     ram_before = process.memory_info().rss
-
-    gpu_type = detect_device()
     vram_before = get_gpu_memory(gpu_type)
-    # Init LLM generator
-    start = time.perf_counter()
+    # Create the generator, passing the quantization setting
     generator = GenerateStreamText(
         checkpoints_dir=checkpoint_path,
         tokenizer_path=checkpoint_path,
         max_gpu_num_blocks=max_gpu_num_blocks,
         max_seq_len=max_seq_len,
-        load_model=load_model,
         compiled_model=compiled_model,
-        triton_weight=triton_weight,
         device=device,
+        quantization=quantization,  # new parameter
     )
     model_prompter.insert_prompt(prompt)
     prompts = [model_prompter.model_input]
+    # Call the generation function and start the stream generation
     stream = generator.text_completion_stream(
         prompts,
@@ -88,28 +98,297 @@
         top_p=top_p,
         max_gen_len=max_gen_len,
     )
-    end = time.perf_counter()
-    completion = ""  # Initialize to generate the result
-    # NOTE: After creating a generator, it can be iterated through a for loop
+    completion = ''
     text_msg = ""
+    start = time.perf_counter()
     for batch_completions in stream:
-        new_text = batch_completions[0]["generation"][len(completion) :]
-        completion = batch_completions[0]["generation"]
-        print(new_text, end="", flush=True)
+        new_text = batch_completions[0]['generation'][len(completion):]
+        completion = batch_completions[0]['generation']
+        print(new_text, end='', flush=True)
         text_msg += new_text
+    end = time.perf_counter()
     print("\n\n==================================\n")
-    print(
-        f"Time for inference: {(end - start):.2f} sec, {count_tokens(text_msg, generator.tokenizer) / (end - start):.2f} tokens/sec"
-    )
+    log.info(
+        f"Time for inference: {(end - start):.2f} sec, {count_tokens(text_msg, generator.tokenizer) / (end - start):.2f} tokens/sec")
     # Report resource usage
-    report_resource_usage(ram_before, vram_before, gpu_type)
+    report_resource_usage(ram_before, vram_before)
+
+
+def generate_llava(
+    prompt: str = "Hello, my name is",
+    checkpoint_path: Path = Path("checkpoints/lit-llama/7B/"),
+    figure_path: Path = Path("figures/lit-llama/"),
+    gpu_type: str = "nvidia",
+    quantization: Optional[str] = None,  # new parameter
+    temperature: float = 0.6,
+    top_p: float = 0.9,
+    max_seq_len: int = 2048,
+    max_gpu_num_blocks=None,
+    max_gen_len: Optional[int] = 512,
+    compiled_model: bool = False,
+):
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+    # Detect the quantization type of the checkpoint
+    if quantization is None:
+        quantization = quantization_manager.detect_quantization_type(str(checkpoint_path))
+        if quantization != QuantizationType.NONE:
+            log.info(f"Auto-detected quantization type: {quantization}")
+
+    if max_seq_len <= 1024:
+        short_prompt = True
+    else:
+        short_prompt = False
+
+    if not os.path.isfile(figure_path):
+        log.error(f"'{figure_path}' is not a valid image file path!")
+        return  # bail out early instead of continuing without an image
+    image_input = str(figure_path).strip()
+    image_items = [image_input]
+    image_num = len(image_items)
+    vis_images(image_items)
+    assert checkpoint_path.is_dir(), checkpoint_path
+    checkpoint_path = str(checkpoint_path)
+    model_prompter = get_prompter("llama", checkpoint_path, short_prompt)
+
+    # Start resource tracking
+    ram_before = process.memory_info().rss
+    vram_before = get_gpu_memory(gpu_type)
+
+    try:
+        generator = LlavaGeneratorStream(
+            checkpoints_dir=checkpoint_path,
+            tokenizer_path=checkpoint_path,
+            max_gpu_num_blocks=max_gpu_num_blocks,
+            max_seq_len=max_seq_len,
+            compiled_model=compiled_model,
+            device=device,
+            quantization=quantization,  # new parameter
+        )
+    except Exception as e:
+        log.error(f"Model loading failed: {e}")
+        sys.exit(1)
+
+    image_token = get_image_token()
+    model_prompter.insert_prompt(image_token * image_num + prompt)
+    prompts =
[model_prompter.model_input] + + try: + stream = generator.text_completion_stream( + prompts, + image_items, + temperature=temperature, + top_p=top_p, + max_gen_len=max_gen_len, + ) + except Exception as e: + log.error(f"Text Generation Failure: {e}") + + completion = '' + text_msg = "" + start = time.perf_counter() + + for batch_completions in stream: + next_text = batch_completions[0]['generation'][len(completion):] + completion = batch_completions[0]['generation'] + print(f"\033[91m{next_text}\033[0m", end='', flush=True) + text_msg += next_text + end = time.perf_counter() + + print("\n\n==================================\n") + log.info( + f"Time for inference: {(end - start):.2f} sec, {count_tokens(text_msg, generator.tokenizer) / (end - start):.2f} tokens/sec") + + # Report resource usage + report_resource_usage(ram_before, vram_before) if __name__ == "__main__": from jsonargparse import CLI torch.set_float32_matmul_precision("high") + + + def main( + prompt: str = "Hello, my name is", + checkpoint_path: Path = Path("checkpoints/lite-llama/7B/"), + figure_path: Optional[Path] = None, + quantization: Optional[str] = None, # 新增参数 + ): + """ + Generate text using lite_llama with optional quantization support + + Args: + prompt: Input prompt text + checkpoint_path: Path to model checkpoint directory + figure_path: Path to Image file for LLaVA generation, optional + quantization: Quantization method ('gptq', 'awq', 'smoothquant', or None for auto-detection) + """ + gpu_type = detect_device() + model_path = os.path.abspath(checkpoint_path) + + # 验证量化参数 + if quantization and quantization not in ['gptq', 'awq', 'smoothquant']: + log.error(f"不支持的量化方法: {quantization}") + log.info("支持的量化方法: gptq, awq, smoothquant") + return + + if figure_path: + generate_llava( + prompt=prompt, + checkpoint_path=Path(model_path), + figure_path=Path(figure_path), + gpu_type=gpu_type, + quantization=quantization, + ) + else: + generate_llama( + prompt=prompt, + checkpoint_path=Path(model_path), + gpu_type=gpu_type, + quantization=quantization + ) + + CLI(main) + + +# 新增量化推理的便捷函数 +def run_quantized_inference( + model_path: str, + prompt: str, + quantization_method: Optional[str] = None, + **kwargs +): + """ + 运行量化推理的便捷函数 + + Args: + model_path: 模型路径 + prompt: 输入提示 + quantization_method: 量化方法,None为自动检测 + **kwargs: 其他推理参数 + """ + + # 检查模型是否存在 + if not os.path.exists(model_path): + raise FileNotFoundError(f"模型路径不存在: {model_path}") + + # 获取模型类型 + model_type = get_model_type(model_path) + + # 设置默认参数 + default_params = { + 'temperature': 0.6, + 'top_p': 0.9, + 'max_seq_len': 2048, + 'max_gen_len': 1024, + 'compiled_model': False, + } + default_params.update(kwargs) + + if model_type == 'llava': + # LLaVA模型需要图像输入 + figure_path = kwargs.get('figure_path') + if not figure_path: + log.warning("LLaVA模型需要图像输入,将使用默认图像") + # 这里可以设置一个默认图像路径 + + generate_llava( + prompt=prompt, + checkpoint_path=Path(model_path), + figure_path=Path(figure_path) if figure_path else None, + quantization=quantization_method, + **default_params + ) + else: + generate_llama( + prompt=prompt, + checkpoint_path=Path(model_path), + quantization=quantization_method, + **default_params + ) + + +# 量化性能测试函数 +def benchmark_quantized_model( + model_path: str, + quantization_methods: Optional[List[str]] = None, + test_prompts: Optional[List[str]] = None, + num_runs: int = 3 +): + """ + 对量化模型进行性能基准测试 + + Args: + model_path: 模型路径 + quantization_methods: 要测试的量化方法列表 + test_prompts: 测试提示列表 + num_runs: 每个配置的运行次数 + """ + + if quantization_methods is None: + 
quantization_methods = ['gptq', 'awq', 'smoothquant', None] # None代表无量化 + + if test_prompts is None: + test_prompts = [ + "What is artificial intelligence?", + "Explain quantum computing in simple terms.", + "Write a short story about a robot." + ] + + results = {} + + for method in quantization_methods: + method_name = method or "no_quantization" + log.info(f"测试量化方法: {method_name}") + + method_results = [] + + for prompt in test_prompts: + prompt_results = [] + + for run in range(num_runs): + log.info(f"运行 {run + 1}/{num_runs}: {prompt[:50]}...") + + start_time = time.time() + try: + run_quantized_inference( + model_path=model_path, + prompt=prompt, + quantization_method=method, + max_gen_len=256 # 限制生成长度以便快速测试 + ) + end_time = time.time() + prompt_results.append(end_time - start_time) + + except Exception as e: + log.error(f"测试失败 ({method_name}, run {run + 1}): {e}") + prompt_results.append(float('inf')) + + method_results.append(prompt_results) + + results[method_name] = method_results + + # 打印结果摘要 + log.info("=" * 60) + log.info("基准测试结果摘要") + log.info("=" * 60) + + for method_name, method_results in results.items(): + avg_times = [] + for prompt_results in method_results: + valid_times = [t for t in prompt_results if t != float('inf')] + if valid_times: + avg_times.append(sum(valid_times) / len(valid_times)) + + if avg_times: + overall_avg = sum(avg_times) / len(avg_times) + log.info(f"{method_name:15}: {overall_avg:.2f}s 平均响应时间") + else: + log.info(f"{method_name:15}: 测试失败") + + return results \ No newline at end of file diff --git a/lite_llama/models/quantized_models.py b/lite_llama/models/quantized_models.py new file mode 100644 index 0000000..8139409 --- /dev/null +++ b/lite_llama/models/quantized_models.py @@ -0,0 +1,353 @@ +""" +Quantized Model Builder for lite_llama +Creates quantized versions of supported models +""" +import torch +import torch.nn as nn +from typing import Dict, Any, Optional, Union +import copy + +from .llama import LlamaModel, FusedAttention as LlamaAttention, FusedMLP as LlamaMLP +from .qwen2 import Qwen2Model, Qwen2Attention, FusedMLP as Qwen2MLP +from .qwen3 import Qwen3Model, Qwen3Attention, FusedMLP as Qwen3MLP +from .llava import LlavaLlama +from .model_config import LlamaConfig, Qwen2Config, Qwen3Config + +# Import quantized layers +from lite_llama.kernels.awq_linear import AWQLinear +from lite_llama.kernels.gptq_linear import GPTQLinear +from lite_llama.kernels.sq_linear import SmoothQuantLinear + +from ..quantization.quant_manager import QuantizationType + + +class QuantizedAttentionMixin: + """量化Attention层的Mixin""" + + def replace_linear_with_quantized(self, quantization_method: str, config: Dict[str, Any]): + """替换线性层为量化层""" + + if quantization_method == QuantizationType.GPTQ: + # 替换投影层为GPTQ量化层 + if hasattr(self, 'q_proj'): + self.q_proj = self._create_gptq_linear(self.q_proj, config) + if hasattr(self, 'k_proj'): + self.k_proj = self._create_gptq_linear(self.k_proj, config) + if hasattr(self, 'v_proj'): + self.v_proj = self._create_gptq_linear(self.v_proj, config) + if hasattr(self, 'o_proj'): + self.o_proj = self._create_gptq_linear(self.o_proj, config) + # 处理融合的kv_proj权重 + if hasattr(self, 'kv_proj_weight'): + # 需要特殊处理融合权重 + pass + + elif quantization_method == QuantizationType.AWQ: + # 替换为AWQ量化层 + if hasattr(self, 'q_proj'): + self.q_proj = self._create_awq_linear(self.q_proj, config) + if hasattr(self, 'k_proj'): + self.k_proj = self._create_awq_linear(self.k_proj, config) + if hasattr(self, 'v_proj'): + self.v_proj = 
self._create_awq_linear(self.v_proj, config) + if hasattr(self, 'o_proj'): + self.o_proj = self._create_awq_linear(self.o_proj, config) + + elif quantization_method == QuantizationType.SMOOTHQUANT: + # 替换为SmoothQuant量化层 + if hasattr(self, 'q_proj'): + self.q_proj = self._create_sq_linear(self.q_proj, config) + if hasattr(self, 'k_proj'): + self.k_proj = self._create_sq_linear(self.k_proj, config) + if hasattr(self, 'v_proj'): + self.v_proj = self._create_sq_linear(self.v_proj, config) + if hasattr(self, 'o_proj'): + self.o_proj = self._create_sq_linear(self.o_proj, config) + + def _create_gptq_linear(self, original_layer: nn.Linear, config: Dict[str, Any]) -> GPTQLinear: + """创建GPTQ量化线性层""" + gptq_layer = GPTQLinear( + in_features=original_layer.in_features, + out_features=original_layer.out_features, + bias=original_layer.bias is not None, + dtype=torch.float16, + bits=config.get('w_bit', 4), + groupsize=config.get('group_size', 128), + device=config.get('device', 'cuda') + ) + return gptq_layer + + def _create_awq_linear(self, original_layer: nn.Linear, config: Dict[str, Any]) -> AWQLinear: + """创建AWQ量化线性层""" + awq_layer = AWQLinear( + in_features=original_layer.in_features, + out_features=original_layer.out_features, + bias=original_layer.bias is not None, + group_size=config.get('group_size', 128), + wbits=config.get('w_bit', 4) + ) + return awq_layer + + def _create_sq_linear(self, original_layer: nn.Linear, config: Dict[str, Any]) -> SmoothQuantLinear: + """创建SmoothQuant量化线性层""" + from ..quantization.quant_config import SmoothQuantConfig + sq_config = SmoothQuantConfig(**config) + + sq_layer = SmoothQuantLinear( + in_features=original_layer.in_features, + out_features=original_layer.out_features, + bias=original_layer.bias is not None, + ) + return sq_layer + + +class QuantizedMLPMixin: + """量化MLP层的Mixin""" + + def replace_linear_with_quantized(self, quantization_method: str, config: Dict[str, Any]): + """替换线性层为量化层""" + + if quantization_method == QuantizationType.GPTQ: + self.gate_proj = self._create_gptq_linear(self.gate_proj, config) + self.up_proj = self._create_gptq_linear(self.up_proj, config) + self.down_proj = self._create_gptq_linear(self.down_proj, config) + + elif quantization_method == QuantizationType.AWQ: + self.gate_proj = self._create_awq_linear(self.gate_proj, config) + self.up_proj = self._create_awq_linear(self.up_proj, config) + self.down_proj = self._create_awq_linear(self.down_proj, config) + + elif quantization_method == QuantizationType.SMOOTHQUANT: + self.gate_proj = self._create_sq_linear(self.gate_proj, config) + self.up_proj = self._create_sq_linear(self.up_proj, config) + self.down_proj = self._create_sq_linear(self.down_proj, config) + + def _create_gptq_linear(self, original_layer: nn.Linear, config: Dict[str, Any]) -> GPTQLinear: + """创建GPTQ量化线性层""" + gptq_layer = GPTQLinear( + in_features=original_layer.in_features, + out_features=original_layer.out_features, + bias=original_layer.bias is not None, + dtype=torch.float16, + bits=config.get('w_bit', 4), + groupsize=config.get('group_size', 128), + device=config.get('device', 'cuda') + ) + return gptq_layer + + def _create_awq_linear(self, original_layer: nn.Linear, config: Dict[str, Any]) -> AWQLinear: + """创建AWQ量化线性层""" + awq_layer = AWQLinear( + in_features=original_layer.in_features, + out_features=original_layer.out_features, + bias=original_layer.bias is not None, + group_size=config.get('group_size', 128), + wbits=config.get('w_bit', 4) + ) + return awq_layer + + def _create_sq_linear(self, 
original_layer: nn.Linear, config: Dict[str, Any]) -> SmoothQuantLinear: + """创建SmoothQuant量化线性层""" + from ..quantization.quant_config import SmoothQuantConfig + + sq_layer = SmoothQuantLinear( + in_features=original_layer.in_features, + out_features=original_layer.out_features, + bias=original_layer.bias is not None, + ) + return sq_layer + + +# 创建量化版本的Attention层 +class QuantizedLlamaAttention(LlamaAttention, QuantizedAttentionMixin): + def __init__(self, config: LlamaConfig, quantization_method: str, quantization_config: Dict[str, Any]): + super().__init__(config) + self.replace_linear_with_quantized(quantization_method, quantization_config) + + +class QuantizedQwen2Attention(Qwen2Attention, QuantizedAttentionMixin): + def __init__(self, hidden_size: int, num_heads: int, num_kv_heads: int, + quantization_method: str, quantization_config: Dict[str, Any], dtype=torch.float16): + super().__init__(hidden_size, num_heads, num_kv_heads, dtype) + self.replace_linear_with_quantized(quantization_method, quantization_config) + + +class QuantizedQwen3Attention(Qwen3Attention, QuantizedAttentionMixin): + def __init__(self, config: Qwen3Config, quantization_method: str, quantization_config: Dict[str, Any]): + super().__init__(config) + self.replace_linear_with_quantized(quantization_method, quantization_config) + + +# 创建量化版本的MLP层 +class QuantizedLlamaMLP(LlamaMLP, QuantizedMLPMixin): + def __init__(self, config: LlamaConfig, quantization_method: str, quantization_config: Dict[str, Any]): + super().__init__(config) + self.replace_linear_with_quantized(quantization_method, quantization_config) + + +class QuantizedQwen2MLP(Qwen2MLP, QuantizedMLPMixin): + def __init__(self, config: Qwen2Config, quantization_method: str, quantization_config: Dict[str, Any]): + super().__init__(config) + self.replace_linear_with_quantized(quantization_method, quantization_config) + + +class QuantizedQwen3MLP(Qwen3MLP, QuantizedMLPMixin): + def __init__(self, config: Qwen3Config, quantization_method: str, quantization_config: Dict[str, Any]): + super().__init__(config) + self.replace_linear_with_quantized(quantization_method, quantization_config) + + +def create_quantized_model( + model_config: Union[LlamaConfig, Qwen2Config, Qwen3Config], + quantization_method: str, + quantization_config: Dict[str, Any], + device: str = "cuda" +) -> torch.nn.Module: + """创建量化模型""" + + model_type = model_config.model_type.lower() + + if model_type == "llama": + model = create_quantized_llama(model_config, quantization_method, quantization_config, device) + elif model_type == "qwen2": + model = create_quantized_qwen2(model_config, quantization_method, quantization_config, device) + elif model_type == "qwen3": + model = create_quantized_qwen3(model_config, quantization_method, quantization_config, device) + elif model_type == "llava": + model = create_quantized_llava(model_config, quantization_method, quantization_config, device) + else: + raise ValueError(f"不支持的模型类型: {model_type}") + + return model.to(device) + + +def create_quantized_llama( + config: LlamaConfig, + quantization_method: str, + quantization_config: Dict[str, Any], + device: str +) -> LlamaModel: + """创建量化的Llama模型""" + + # 创建基础模型 + model = LlamaModel(config) + + # 替换层为量化版本 + for i, layer in enumerate(model.layers): + # 替换attention + quantized_attention = QuantizedLlamaAttention( + config, quantization_method, quantization_config + ) + + # 复制权重信息(在实际加载时会被覆盖) + layer.self_attn = quantized_attention + + # 替换MLP + quantized_mlp = QuantizedLlamaMLP( + config, quantization_method, 
quantization_config + ) + layer.mlp = quantized_mlp + + # 替换lm_head如果需要 + if quantization_method in [QuantizationType.GPTQ, QuantizationType.AWQ]: + if quantization_method == QuantizationType.GPTQ: + quantized_lm_head = GPTQLinear( + in_features=model.lm_head.in_features, + out_features=model.lm_head.out_features, + bias=model.lm_head.bias is not None, + dtype=torch.float16, + bits=quantization_config.get('w_bit', 4), + groupsize=quantization_config.get('group_size', 128), + device=device + ) + else: # AWQ + quantized_lm_head = AWQLinear( + in_features=model.lm_head.in_features, + out_features=model.lm_head.out_features, + bias=model.lm_head.bias is not None, + group_size=quantization_config.get('group_size', 128), + wbits=quantization_config.get('w_bit', 4) + ) + + model.lm_head = quantized_lm_head + + return model + + +def create_quantized_qwen2( + config: Qwen2Config, + quantization_method: str, + quantization_config: Dict[str, Any], + device: str +) -> Qwen2Model: + """创建量化的Qwen2模型""" + + # 创建基础模型 + model = Qwen2Model(config) + + # 替换层为量化版本 + for i, layer in enumerate(model.layers): + # 替换attention + quantized_attention = QuantizedQwen2Attention( + config.hidden_size, config.num_heads, config.num_kv_heads, + quantization_method, quantization_config + ) + layer.self_attn = quantized_attention + + # 替换MLP + quantized_mlp = QuantizedQwen2MLP( + config, quantization_method, quantization_config + ) + layer.mlp = quantized_mlp + + return model + + +def create_quantized_qwen3( + config: Qwen3Config, + quantization_method: str, + quantization_config: Dict[str, Any], + device: str +) -> Qwen3Model: + """创建量化的Qwen3模型""" + + # 创建基础模型 + model = Qwen3Model(config) + + # 替换层为量化版本 + for i, layer in enumerate(model.layers): + # 替换attention + quantized_attention = QuantizedQwen3Attention( + config, quantization_method, quantization_config + ) + layer.self_attn = quantized_attention + + # 替换MLP + quantized_mlp = QuantizedQwen3MLP( + config, quantization_method, quantization_config + ) + layer.mlp = quantized_mlp + + return model + + +def create_quantized_llava( + config: Any, # LlavaConfig + quantization_method: str, + quantization_config: Dict[str, Any], + device: str +) -> LlavaLlama: + """创建量化的LLaVA模型""" + + # 创建基础模型 + model = LlavaLlama(config) + + # 量化language_model部分 + llama_config = model.llama_config + quantized_language_model = create_quantized_llama( + llama_config, quantization_method, quantization_config, device + ) + + model.language_model = quantized_language_model + + return model \ No newline at end of file diff --git a/lite_llama/quantization/__init__.py b/lite_llama/quantization/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/lite_llama/quantization/quant_config.py b/lite_llama/quantization/quant_config.py new file mode 100644 index 0000000..bef81d3 --- /dev/null +++ b/lite_llama/quantization/quant_config.py @@ -0,0 +1,57 @@ +from dataclasses import dataclass, field +from typing import List + + +@dataclass +class AWQConfig: + + """Configuration for AWQ quantization""" + w_bit: int = 4 # Weight quantization bits + group_size: int = 128 # Group size for quantization + zero_point: bool = True # Whether to use zero point + version: str = "GEMM" # GEMM or GEMV + calib_data_size: int = 128 # Calibration dataset size + search_scale: bool = False # Whether to search for optimal scales + auto_scale: bool = True # Automatic scaling + device: str = "cuda" + alpha: float = 0.5 + + +@dataclass +class GPTQConfig: + """GPTQ量化配置""" + w_bit: int = 4 + group_size: int = 64 # 
Smaller group size for a higher compression ratio
+    device: str = "cuda"
+    quantize_embedding: bool = True
+    quantize_lm_head: bool = True
+    adaptive_group_size: bool = True  # Adaptive group size
+    optimize_for_compression: bool = True  # Optimize for compression ratio
+
+
+
+@dataclass
+class SmoothQuantConfig:
+    """Configuration for SmoothQuant"""
+    alpha: float = 0.5  # Smoothing factor balance between act and weight
+    w_bit: int = 8  # Weight quantization bits
+    a_bit: int = 8  # Activation quantization bits
+    device: str = "cuda"
+    symmetric_weight: bool = True  # Use symmetric quantization for weights
+    symmetric_activation: bool = False  # Use asymmetric quantization for activations
+    per_channel_weight: bool = True  # Per-channel quantization for weights
+    per_token_activation: bool = True  # Per-token quantization for activations
+    calibration_samples: int = 128  # Number of calibration samples
+    smooth_layers: List[str] = field(default_factory=lambda: [
+        "q_proj", "k_proj", "v_proj", "o_proj",
+        "gate_proj", "up_proj", "down_proj"
+    ])
+
+@dataclass
+class QuantLayerConfig:
+    """Configuration for QuantLayer"""
+    quant_layers: List[str] = field(default_factory=lambda: [
+        "q_proj", "k_proj", "v_proj", "o_proj",
+        "gate_proj", "up_proj", "down_proj",
+        "kv_proj", "lm_head"
+    ])
\ No newline at end of file
diff --git a/lite_llama/quantization/quant_manager.py b/lite_llama/quantization/quant_manager.py
new file mode 100644
index 0000000..ab9ccdf
--- /dev/null
+++ b/lite_llama/quantization/quant_manager.py
@@ -0,0 +1,268 @@
+"""
+Quantization Manager for lite_llama
+Provides unified interface for GPTQ, AWQ, and SmoothQuant
+"""
+import os
+import json
+import torch
+import torch.nn as nn
+from typing import Dict, Optional, Union, Any, List
+from pathlib import Path
+from tqdm import tqdm
+
+from .awq import AWQ, quantize_awq
+from .gptq import GPTQ, quantize_gptq
+from .sq import SmoothQuantizer, apply_smoothquant
+from .quant_config import AWQConfig, GPTQConfig, SmoothQuantConfig, QuantLayerConfig
+
+# Import quantized linear layers
+from ..kernels.awq_linear import AWQLinear
+from ..kernels.gptq_linear import GPTQLinear
+from ..kernels.sq_linear import SmoothQuantLinear
+
+from ..utils.logger import get_logger
+log = get_logger(__name__)  # named `log` to match the call sites below
+
+class QuantizationType:
+    NONE = "none"
+    GPTQ = "gptq"
+    AWQ = "awq"
+    SMOOTHQUANT = "smoothquant"
+    INT4 = "int4"
+    INT8 = "int8"
+
+
+class QuantizationManager:
+    """Unified quantization manager"""
+
+    def __init__(self):
+        self.supported_methods = {
+            QuantizationType.GPTQ: self._load_gptq,
+            QuantizationType.AWQ: self._load_awq,
+            QuantizationType.SMOOTHQUANT: self._load_smoothquant,
+        }
+
+    def detect_quantization_type(self, model_path: str) -> str:
+        """Automatically detect the quantization type of the model"""
+        model_path = Path(model_path)
+
+        # Check for an explicit quantization config file first
+        quant_config_path = model_path / "quantization_config.json"
+        if quant_config_path.exists():
+            with open(quant_config_path, 'r') as f:
+                config = json.load(f)
+            return config.get("quantization_method", QuantizationType.NONE)
+
+        # Otherwise, inspect the keys of the first weight file
+        weight_files = list(model_path.glob("*.pth"))
+        if weight_files:
+            state_dict = torch.load(weight_files[0], map_location="cpu")
+
+            # Look for quantization-specific keys
+            for key in state_dict.keys():
+                if "qweight" in key and "qzeros" in key:
+                    if "qscales" in key:
+                        return QuantizationType.AWQ
+                    elif "scales" in key:
+                        return QuantizationType.GPTQ
+                elif "weight_scale" in key and "smoothing_factor" in key:
+                    return QuantizationType.SMOOTHQUANT
+
+        return QuantizationType.NONE
+
+    def quantize_model(
+        self,
+        model_path: str,
+        output_path: str,
+        method: str,
+        config: Optional[Dict] = None,
+        calibration_data: Optional[Any] = None,
+        model: Optional[torch.nn.Module] = None
+    ) -> str:
+        """Quantize a model checkpoint and save the result"""
+        log.info(f"Quantizing the model with the {method} method...")
+
+        # Load the original model state dict
+        model_path = Path(model_path)
+        weight_files = list(model_path.glob("*.pth"))
+        if not weight_files:
+            raise ValueError(f"No weight files found in {model_path}")
+
+        state_dict = torch.load(weight_files[0], map_location="cpu")
+
+        # Quantize according to the selected method
+        if method == QuantizationType.GPTQ:
+            config = config or {}
+            gptq_config = GPTQConfig(**config)
+            quantized_state_dict = quantize_gptq(
+                model_state_dict=state_dict,
+                target_layers=self._get_target_layers(state_dict),
+                device=gptq_config.device
+            )
+
+        elif method == QuantizationType.AWQ:
+            config = config or {}
+            awq_config = AWQConfig(**config)
+            quantized_state_dict = quantize_awq(
+                model_state_dict=state_dict,
+                model=model,
+                config=awq_config,
+                target_layers=self._get_target_layers(state_dict),
+                device=awq_config.device
+            )
+
+        elif method == QuantizationType.SMOOTHQUANT:
+            config = config or {}
+            config["smooth_layers"] = self._get_target_layers(state_dict)  # config is a plain dict here
+            sq_config = SmoothQuantConfig(**config)
+            quantized_state_dict = apply_smoothquant(
+                model_state_dict=state_dict,
+                config=sq_config,
+            )
+
+        else:
+            raise ValueError(f"Unsupported quantization method: {method}")
+
+        # Save the quantized model
+        output_path = Path(output_path)
+        output_path.mkdir(parents=True, exist_ok=True)
+
+        # Save the weights
+        torch.save(
+            quantized_state_dict,
+            output_path / f"{model_path.name}.pth",
+            _use_new_zipfile_serialization=True
+        )
+
+        # Copy the remaining config files
+        import shutil
+        for file in model_path.glob("*.json"):
+            if file.name != "quantization_config.json":
+                shutil.copy2(file, output_path)
+
+        # Copy the tokenizer files
+        for file in model_path.glob("tokenizer*"):
+            shutil.copy2(file, output_path)
+
+        # Save the quantization config
+        quant_config = {
+            "quantization_method": method,
+            "config": config,
+            "quantized_at": torch.cuda.get_device_name() if torch.cuda.is_available() else "cpu"
+        }
+
+        with open(output_path / "quantization_config.json", 'w') as f:
+            json.dump(quant_config, f, indent=2)
+
+        log.info(f"Quantization complete!
Saved to: {output_path}") + return str(output_path) + + def load_quantized_model( + self, + model_path: str, + model_config: Any, + device: str = "cuda" + ) -> torch.nn.Module: + """Load the quantized model""" + quant_type = self.detect_quantization_type(model_path) + + if quant_type == QuantizationType.NONE: + # 正常加载非量化模型 + return self._load_normal_model(model_path, model_config, device) + + if quant_type in self.supported_methods: + return self.supported_methods[quant_type](model_path, model_config, device) + else: + raise ValueError(f"Unsupported quantization types: {quant_type}") + + def _load_gptq(self, model_path: str, model_config: Any, device: str) -> torch.nn.Module: + """Load the GPTQ quantitative model""" + from ..models.quantized_models import create_quantized_model + + # 读取量化配置 + quant_config_path = Path(model_path) / "quantization_config.json" + with open(quant_config_path, 'r') as f: + quant_config = json.load(f) + + # 创建量化模型 + model = create_quantized_model( + model_config=model_config, + quantization_method=QuantizationType.GPTQ, + quantization_config=quant_config.get("config", {}), + device=device + ) + + # 加载量化权重 + weight_files = list(Path(model_path).glob("*.pth")) + state_dict = torch.load(weight_files[0], map_location=device) + model.load_state_dict(state_dict, strict=False) + + return model + + def _load_awq(self, model_path: str, model_config: Any, device: str) -> torch.nn.Module: + """Load the AWQ quantification model""" + from ..models.quantized_models import create_quantized_model + + # 读取量化配置 + quant_config_path = Path(model_path) / "quantization_config.json" + with open(quant_config_path, 'r') as f: + quant_config = json.load(f) + + # 创建量化模型 + model = create_quantized_model( + model_config=model_config, + quantization_method=QuantizationType.AWQ, + quantization_config=quant_config.get("config", {}), + device=device + ) + + # 加载量化权重 + weight_files = list(Path(model_path).glob("*.pth")) + state_dict = torch.load(weight_files[0], map_location=device) + model.load_state_dict(state_dict, strict=False) + + return model + + def _load_smoothquant(self, model_path: str, model_config: Any, device: str) -> torch.nn.Module: + """Load the SmoothQuant quantitative model""" + from ..models.quantized_models import create_quantized_model + + # 读取量化配置 + quant_config_path = Path(model_path) / "quantization_config.json" + with open(quant_config_path, 'r') as f: + quant_config = json.load(f) + + # 创建量化模型 + model = create_quantized_model( + model_config=model_config, + quantization_method=QuantizationType.SMOOTHQUANT, + quantization_config=quant_config.get("config", {}), + device=device + ) + + # 加载量化权重 + weight_files = list(Path(model_path).glob("*.pth")) + state_dict = torch.load(weight_files[0], map_location=device) + model.load_state_dict(state_dict, strict=False) + + return model + + def _load_normal_model(self, model_path: str, model_config: Any, device: str) -> torch.nn.Module: + """加载非量化模型 - 这里需要调用原有的模型加载逻辑""" + # 这里应该调用现有的模型加载逻辑 + # 需要根据具体的模型架构来实现 + pass + + def _get_target_layers(self, state_dict: Dict[str, torch.Tensor]) -> List[str]: + """Obtain the layers that need to be quantified""" + target_layers = [] + quant_layer = QuantLayerConfig() + for name in state_dict.keys(): + if any(pattern in name for pattern in quant_layer.quant_layers): + target_layers.append(name) + return target_layers + + +# 全局量化管理器实例 +quantization_manager = QuantizationManager() \ No newline at end of file diff --git a/lite_llama/quantization/utils.py b/lite_llama/quantization/utils.py new file 
mode 100644 index 0000000..f827814 --- /dev/null +++ b/lite_llama/quantization/utils.py @@ -0,0 +1,49 @@ +import torch + + +def pack_weight(weight): + """ + Pack two 4-bit values into one uint8 value consistently + + Args: + weight: Tensor of shape [out_features, in_features] with values in [0, 15] + + Returns: + packed: Tensor of shape [out_features, in_features//2] with packed values + """ + rows, cols = weight.shape + + # Ensure even number of columns for packing + if cols % 2 != 0: + weight = torch.cat([weight, torch.zeros(rows, 1, dtype=weight.dtype, device=weight.device)], dim=1) + cols += 1 + + # Pack: lower 4 bits from even indices, upper 4 bits from odd indices + # Format: [odd_value << 4] | even_value + packed = (weight[:, 0::2] & 0xF) | ((weight[:, 1::2] & 0xF) << 4) + + return packed.contiguous().to(torch.uint8) + + +def unpack_weight(packed_weight, original_cols): + """ + Unpack uint8 values back to two 4-bit values consistently + + Args: + packed_weight: Packed tensor of shape [out_features, packed_cols] + original_cols: Original number of columns before packing + + Returns: + unpacked: Tensor of shape [out_features, original_cols] with unpacked values + """ + rows, packed_cols = packed_weight.shape + + # Allocate unpacked tensor + unpacked = torch.zeros((rows, packed_cols * 2), dtype=torch.uint8, device=packed_weight.device) + + # Unpack: even positions get lower 4 bits, odd positions get upper 4 bits + unpacked[:, 0::2] = packed_weight & 0xF # Lower 4 bits + unpacked[:, 1::2] = (packed_weight >> 4) & 0xF # Upper 4 bits + + # Trim to original size + return unpacked[:, :original_cols].contiguous() \ No newline at end of file diff --git a/lite_llama/utils/common.py b/lite_llama/utils/common.py index 55dbbc9..ebb04b2 100644 --- a/lite_llama/utils/common.py +++ b/lite_llama/utils/common.py @@ -2,7 +2,9 @@ import time, os import subprocess from typing import List, Optional - +import torch +from contextlib import contextmanager +import functools def read_json(json_path): with open(json_path, "r") as json_file: @@ -37,7 +39,7 @@ def getProjectPath(): return os.path.abspath(os.path.join(script_path, "..")) -def get_gpu_memory(gpu_type="amd", device_id="0"): +def get_gpu_memory(gpu_type, device_id="0"): try: if gpu_type == "amd": result = subprocess.run( @@ -67,7 +69,7 @@ def get_gpu_memory(gpu_type="amd", device_id="0"): elif gpu_type == "cpu": return None except Exception as e: - from utils.logger import log + from lite_llama.utils.logger import log log.warning(f"Unable to fetch GPU memory: {e}") return None @@ -82,7 +84,7 @@ def count_tokens(texts: List[str], tokenizer) -> int: def get_model_type(checkpoint_path: str) -> str | None: - from utils.logger import log + from .logger import log model_type = ["llama", "falcon", "mpt", "qwen2", "llava"] @@ -94,3 +96,69 @@ def get_model_type(checkpoint_path: str) -> str | None: return m log.error(f"No model type found: {checkpoint_path}") return None + + +def check_model_compatibility(model_path): + """Check if the model is compatible for quantization""" + # Check if model path exists and contains .pth files + if not os.path.exists(model_path): + return False, f"Model path does not exist: {model_path}" + + pth_files = [f for f in os.listdir(model_path) if f.endswith('.pth')] + if not pth_files: + return False, f"No .pth files found in {model_path}" + + # Check if required config files exist + config_files = ["config.json", "tokenizer_config.json"] + missing_configs = [f for f in config_files if not os.path.exists(os.path.join(model_path, 
f))]
+    if missing_configs:
+        print(f"Warning: Missing config files: {missing_configs}")
+
+    return True, "Model is compatible"
+
+
+def get_model_info(model_path):
+    """Get basic information about the model"""
+    model_info = {
+        "model_name": os.path.basename(model_path),
+        "model_type": "unknown",
+        "size": 0.0
+    }
+
+    # Detect model type from path or config
+    model_name_lower = model_info["model_name"].lower()
+    if "llava" in model_name_lower:
+        model_info["model_type"] = "llava"
+    elif "qwen2" in model_name_lower:
+        model_info["model_type"] = "qwen2"
+    elif "llama" in model_name_lower:
+        model_info["model_type"] = "llama"
+
+    # Try to read from config.json
+    config_path = os.path.join(model_path, "config.json")
+    if os.path.exists(config_path):
+        try:
+            with open(config_path, 'r') as f:
+                config = json.load(f)
+            if "architectures" in config:
+                arch = config["architectures"][0].lower()
+                if "llava" in arch:
+                    model_info["model_type"] = "llava"
+                elif "qwen2" in arch:
+                    model_info["model_type"] = "qwen2"
+                elif "llama" in arch:
+                    model_info["model_type"] = "llama"
+        except Exception:
+            pass  # fall back to the name-based detection above
+
+    # Calculate total size
+    total_size = 0
+    for f in os.listdir(model_path):
+        if f.endswith('.pth'):
+            file_path = os.path.join(model_path, f)
+            total_size += os.path.getsize(file_path)
+
+    model_info["size"] = total_size / (1024 ** 3)  # Convert to GB
+
+    return model_info
+
diff --git a/quantize_lite_llama.py b/quantize_lite_llama.py
new file mode 100644
index 0000000..34a7661
--- /dev/null
+++ b/quantize_lite_llama.py
@@ -0,0 +1,362 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+quantize_lite_llama.py
+~~~~~~~~~~~~~~~~~~~~~~
+Utility script for quantizing models in lite_llama format.
+
+Supports three quantization methods: GPTQ, AWQ, and SmoothQuant.
+
+Usage
+-----
+# GPTQ quantization
+python quantize_lite_llama.py --model-path /path/to/model --output-path /path/to/output --method gptq --bits 4 --group-size 128
+
+# AWQ quantization
+python quantize_lite_llama.py --model-path /path/to/model --output-path /path/to/output --method awq --bits 4 --group-size 128 --calib-data /path/to/calib.txt
+
+# SmoothQuant quantization
+python quantize_lite_llama.py --model-path /path/to/model --output-path /path/to/output --method smoothquant --alpha 0.5
+"""
+
+import argparse
+import json
+import os
+import sys
+import time
+from pathlib import Path
+from typing import Optional, Dict, Any, List
+import torch
+from tqdm import tqdm
+
+# Add lite_llama to Python path
+sys.path.append(os.path.dirname(os.path.abspath(__file__)))
+
+from lite_llama.quantization.quant_manager import quantization_manager, QuantizationType
+from lite_llama.quantization.quant_config import AWQConfig, GPTQConfig, SmoothQuantConfig
+from lite_llama.utils.common import get_model_info, check_model_compatibility
+from lite_llama.utils.logger import get_logger
+from lite_llama.executor.model_executor import ModelExecutor
+from transformers import AutoTokenizer
+
+log = get_logger(__name__)  # named `log` to match the call sites below
+
+class CalibrationDataLoader:
+    """Calibration data loader"""
+
+    def __init__(self, data_path: str, tokenizer_path: str, max_samples: int = 128, max_length: int = 512):
+        self.data_path = data_path
+        self.max_samples = max_samples
+        self.max_length = max_length
+
+        # Load the tokenizer
+        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)
+        if self.tokenizer.pad_token is None:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+
+        # Load the calibration data
+        self.texts = self._load_calibration_data()
+
+    def _load_calibration_data(self) -> List[str]:
+        """Load calibration texts from disk"""
+        texts = []
+
+        if self.data_path.endswith('.txt'):
+            # Plain text file, one sample per line
+            with open(self.data_path,
'r', encoding='utf-8') as f: + texts = [line.strip() for line in f if line.strip()] + + elif self.data_path.endswith('.json'): + # JSON文件 + with open(self.data_path, 'r', encoding='utf-8') as f: + data = json.load(f) + if isinstance(data, list): + # 假设是文本列表 + texts = [item if isinstance(item, str) else item.get('text', '') for item in data] + else: + # 假设是包含文本字段的对象 + texts = [data.get('text', '')] + + elif self.data_path.endswith('.jsonl'): + # JSONL文件 + with open(self.data_path, 'r', encoding='utf-8') as f: + for line in f: + item = json.loads(line.strip()) + texts.append(item.get('text', '')) + + else: + raise ValueError(f"Unsupported file formats: {self.data_path}") + + # 限制样本数量 + texts = texts[:self.max_samples] + log.info(f"{len(texts)} calibration samples were loaded") + + return texts + + def __len__(self): + return len(self.texts) + + def __iter__(self): + """返回批次数据的迭代器""" + for text in self.texts: + # 编码文本 + encoding = self.tokenizer( + text, + return_tensors='pt', + max_length=self.max_length, + truncation=True, + padding=True + ) + + yield encoding + + +def create_default_calibration_data(tokenizer_path: str, num_samples: int = 32) -> List[str]: + """创建默认的校准数据""" + default_texts = [ + "The quick brown fox jumps over the lazy dog.", + "Artificial intelligence is transforming the world.", + "Machine learning models require careful optimization.", + "Deep neural networks can learn complex patterns.", + "Natural language processing enables human-computer interaction.", + "Computer vision systems can understand visual content.", + "Quantization reduces model size while maintaining accuracy.", + "Large language models demonstrate emergent capabilities.", + "Transformer architectures have revolutionized AI.", + "Self-attention mechanisms capture long-range dependencies." 
+ ] + + # 重复样本以达到所需数量 + texts = (default_texts * ((num_samples // len(default_texts)) + 1))[:num_samples] + log.info(f"Using the default calibration data, there are a total of {len(texts)} samples") + + return texts + + +def validate_quantization_config(method: str, config: Dict[str, Any]) -> Dict[str, Any]: + """验证和标准化量化配置""" + + if method == QuantizationType.GPTQ: + validated_config = { + 'w_bit': config.get('bits', 4), + 'group_size': config.get('group_size', 128), + 'device': config.get('device', 'cuda') + } + + # 验证参数范围 + if validated_config['w_bit'] not in [2, 3, 4, 8]: + raise ValueError(f"The number of bits not supported by GPTQ: {validated_config['w_bit']}") + + elif method == QuantizationType.AWQ: + validated_config = { + 'w_bit': config.get('bits', 4), + 'group_size': config.get('group_size', 128), + 'zero_point': config.get('zero_point', True), + 'search_scale': config.get('search_scale', False), + 'auto_scale': config.get('auto_scale', True), + 'alpha': config.get('alpha', 0.5), + 'device': config.get('device', 'cuda') + } + + if validated_config['w_bit'] not in [4, 8]: + raise ValueError(f"The number of bits not supported by AWQ: {validated_config['w_bit']}") + + elif method == QuantizationType.SMOOTHQUANT: + validated_config = { + 'alpha': config.get('alpha', 0.5), + 'w_bit': config.get('w_bits', 8), + 'a_bit': config.get('a_bits', 8), + 'symmetric_weight': config.get('symmetric_weight', True), + 'symmetric_activation': config.get('symmetric_activation', False), + 'per_channel_weight': config.get('per_channel_weight', True), + 'per_token_activation': config.get('per_token_activation', True), + 'calibration_samples': config.get('calibration_samples', 128), + 'device': config.get('device', 'cuda') + } + + if not (0.0 <= validated_config['alpha'] <= 1.0): + raise ValueError(f"The alpha parameter of SmoothQuant must be between 0 and 1: {validated_config['alpha']}") + + else: + raise ValueError(f"Unsupported quantitative methods: {method}") + + return validated_config + + +def main(): + parser = argparse.ArgumentParser(description="Quantify the model in lite_llama format") + + # 基本参数 + parser.add_argument("--model-path", type=str, required=True, + help="Input model path") + parser.add_argument("--output-path", type=str, required=True, + help="Output model path") + parser.add_argument("--method", type=str, required=True, + choices=['gptq', 'awq', 'smoothquant'], + help="Quantitative method") + + # 量化参数 + parser.add_argument("--bits", type=int, default=4, + help="Quantification bit number (default: 4)") + parser.add_argument("--group-size", type=int, default=128, + help="Group size (default: 128)") + + # AWQ特有参数 + parser.add_argument("--alpha", type=float, default=0.5, + help="The alpha parameter of AWQ/SmoothQuant (default: 0.5)") + parser.add_argument("--search-scale", action='store_true', + help="Does AWQ search for the optimal scaling factor") + parser.add_argument("--auto-scale", action='store_true', default=True, + help="Does AWQ scale automatically") + + # SmoothQuant特有参数 + parser.add_argument("--w-bits", type=int, default=8, + help="Weighted quantification number of bits (SmoothQuant, default: 8)") + parser.add_argument("--a-bits", type=int, default=8, + help="Activation quantization bit number (SmoothQuant, default: 8)") + + # 校准数据 + parser.add_argument("--calib-data", type=str, default=None, + help="Calibrate the data file path (.txt/.json/.jsonl)") + parser.add_argument("--calib-samples", type=int, default=128, + help="Calibration sample quantity (default: 128)") + 
parser.add_argument("--max-length", type=int, default=512, + help="The maximum length of the calibration data (default: 512)") + + # 其他参数 + parser.add_argument("--device", type=str, default="cuda", + choices=['cuda', 'cpu'], + help="device (default: cuda)") + parser.add_argument("--no-verify", action='store_true', + help="Skip quantitative validation") + + args = parser.parse_args() + + # 检查模型兼容性 + is_compatible, message = check_model_compatibility(args.model_path) + if not is_compatible: + log.error(f"The model compatibility check failed: {message}") + return 1 + + # 获取模型信息 + model_info = get_model_info(args.model_path) + log.info(f"Model information: {model_info}") + + # 准备量化配置 + config = { + 'bits': args.bits, + 'group_size': args.group_size, + 'alpha': args.alpha, + 'search_scale': args.search_scale, + 'auto_scale': args.auto_scale, + 'w_bits': args.w_bits, + 'a_bits': args.a_bits, + 'device': args.device, + 'calibration_samples': args.calib_samples + } + + # 验证配置 + try: + validated_config = validate_quantization_config(args.method, config) + log.info(f"Quantitative configuration: {validated_config}") + except ValueError as e: + log.error(f"Configuration verification failed: {e}") + return 1 + + # 准备校准数据 + calibration_data = None + model = None + + if args.method in ['awq', 'smoothquant']: + log.info("Prepare calibration data...") + + if args.calib_data: + # 使用用户提供的校准数据 + try: + calibration_data = CalibrationDataLoader( + args.calib_data, + args.model_path, + args.calib_samples, + args.max_length + ) + log.info(f"Load calibration data: {len(calibration_data)} samples") + except Exception as e: + log.error(f"Failed to load calibration data: {e}") + log.info("The default calibration data will be used") + calibration_data = create_default_calibration_data( + args.model_path, args.calib_samples + ) + else: + # 使用默认校准数据 + calibration_data = create_default_calibration_data( + args.model_path, args.calib_samples + ) + + # 如果需要,加载原始模型用于校准 + if args.method == 'awq': + log.info("Load the original model for AWQ calibration...") + try: + model_executor = ModelExecutor.build( + checkpoints_dir=args.model_path, + max_seq_len=2048, + max_gpu_num_blocks=None, + compiled_model=False, + device=args.device + ) + model = model_executor.model + log.info("The model has been loaded successfully.") + except Exception as e: + log.error(f"Model loading failed: {e}") + return 1 + + # 执行量化 + log.info(f"Quantifying the model using the {args.method.upper()} method...") + start_time = time.time() + + try: + output_path = quantization_manager.quantize_model( + model_path=args.model_path, + output_path=args.output_path, + method=args.method, + config=validated_config, + calibration_data=calibration_data, + model=model + ) + + quantization_time = time.time() - start_time + log.info(f"Quantification completed! 
Time consumption: {quantization_time:.2f}s") + log.info(f"The quantitative model saved to: {output_path}") + + except Exception as e: + log.error(f"Quantitative failure: {e}") + return 1 + + # 验证量化结果 + if not args.no_verify: + log.info("Verify the quantification results...") + try: + # 检测量化类型 + detected_type = quantization_manager.detect_quantization_type(output_path) + if detected_type == args.method: + log.info(f"The quantitative type verification has been passed: {detected_type}") + else: + log.warning(f"Quantization type mismatch: expected {args.method}, detected {detected_type}") + + # 检查文件大小 + original_size = sum(f.stat().st_size for f in Path(args.model_path).glob("*.pth")) + quantized_size = sum(f.stat().st_size for f in Path(output_path).glob("*.pth")) + compression_ratio = original_size / quantized_size if quantized_size > 0 else 1.0 + + log.info(f"Original model size: {original_size / (1024 ** 3):.2f} GB") + log.info(f"Quantitative model size: {quantized_size / (1024 ** 3):.2f} GB") + log.info(f"Compression ratio: {compression_ratio:.2f}x") + + except Exception as e: + log.warning(f"Quantitative verification failed: {e}") + + log.info("Quantitative task completion!") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) \ No newline at end of file
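For reviewers, a minimal end-to-end sketch of how the pieces added in this diff are meant to fit together: quantize an existing lite_llama checkpoint with the new QuantizationManager, then run streaming inference against the quantized copy via run_quantized_inference from generate.py. This is illustrative only; the checkpoint and output paths are placeholders, and it assumes the snippet is run from the repository root so that generate.py and lite_llama are importable.

# sketch_quantize_and_generate.py -- illustrative sketch, not part of this diff
from lite_llama.quantization.quant_manager import quantization_manager, QuantizationType
from generate import run_quantized_inference

# 1. Quantize a lite_llama checkpoint with GPTQ (4-bit weights, group size 128).
#    quantize_model() saves the packed weights plus quantization_config.json.
quantized_dir = quantization_manager.quantize_model(
    model_path="checkpoints/lite-llama/my-model",        # placeholder input checkpoint
    output_path="checkpoints/lite-llama/my-model-gptq",  # placeholder output directory
    method=QuantizationType.GPTQ,
    config={"w_bit": 4, "group_size": 128, "device": "cuda"},
)

# 2. Run streaming inference; quantization_method=None triggers auto-detection
#    via detect_quantization_type(), which reads quantization_config.json.
run_quantized_inference(
    model_path=quantized_dir,
    prompt="Explain weight-only quantization in one paragraph.",
    quantization_method=None,
    max_gen_len=256,
)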