From b259c09fe255bd327977280d4f0c2e5eb74cc54b Mon Sep 17 00:00:00 2001 From: caikun-pjlab Date: Fri, 21 Mar 2025 09:54:09 +0800 Subject: [PATCH 01/16] use torch.compile for internvit --- lmdeploy/pytorch/models/internvl.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lmdeploy/pytorch/models/internvl.py b/lmdeploy/pytorch/models/internvl.py index 27bf3d5320..78d90c9625 100644 --- a/lmdeploy/pytorch/models/internvl.py +++ b/lmdeploy/pytorch/models/internvl.py @@ -318,6 +318,7 @@ def pixel_shuffle(self, x, scale_factor=0.5): x = x.permute(0, 2, 1, 3).contiguous() return x + @torch.compile(mode="max-autotune-no-cudagraphs") def extract_feature(self, pixel_values): """extract vision feature.""" assert self.select_layer == -1 @@ -350,6 +351,7 @@ def forward( ): if inputs_embeds is None and pixel_values is not None: # extract feature + torch._dynamo.mark_dynamic(pixel_values, 0) vit_embeds = self.extract_feature(pixel_values) lang_embeds = self.language_model.get_input_embeddings()(input_ids) lang_embeds.masked_scatter_(image_mask[..., None], vit_embeds) From 66aa3ba6251fe70b7cf5ececb5c710e36e22e298 Mon Sep 17 00:00:00 2001 From: jinminxi104 Date: Sat, 22 Mar 2025 15:34:32 +0000 Subject: [PATCH 02/16] enable micro-batch for internvl --- lmdeploy/pytorch/models/internvl.py | 44 +++++++++++++++++++++++++---- 1 file changed, 38 insertions(+), 6 deletions(-) diff --git a/lmdeploy/pytorch/models/internvl.py b/lmdeploy/pytorch/models/internvl.py index 78d90c9625..64b4fc5372 100644 --- a/lmdeploy/pytorch/models/internvl.py +++ b/lmdeploy/pytorch/models/internvl.py @@ -205,15 +205,33 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: self.ls1 = nn.Parameter(torch.empty(self.embed_dim, dtype=dtype, device=device)) self.ls2 = nn.Parameter(torch.empty(self.embed_dim, dtype=dtype, device=device)) - def forward( - self, - hidden_states: torch.Tensor, - ): - """forward.""" - hidden_states = hidden_states + self.attn(self.norm1(hidden_states).to(hidden_states.dtype)) * self.ls1 + def enable_micro_batch(func): + """Decorator to enable micro-batch computation.""" + def wrapper(self, hidden_states, *args, **kwargs): + if isinstance(hidden_states, list): + # Apply forward computation to each micro-batch + return [func(self, hs, *args, **kwargs) for hs in hidden_states] + else: + # If not a list, directly apply the forward computation + return func(self, hidden_states, *args, **kwargs) + return wrapper + @enable_micro_batch + def _attn(self, hidden_states): + hidden_states = hidden_states + self.attn(self.norm1(hidden_states).to(hidden_states[0].dtype)) * self.ls1 + return hidden_states + + @enable_micro_batch + def _mlp(self, hidden_states): hidden_states = hidden_states + self.mlp(self.norm2(hidden_states).to(hidden_states.dtype)) * self.ls2 + return hidden_states + def forward( + self, + hidden_states, + ): + hidden_states = self._attn(hidden_states) + hidden_states = self._mlp(hidden_states) return hidden_states @@ -226,6 +244,20 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: self.layers = nn.ModuleList( [InternVisionEncoderLayer(config, dtype=dtype, device=device) for idx in range(config.num_hidden_layers)]) + def split_inputs_embeds(num_splits=2): + """Decorator to split inputs_embeds along the 0th dimension into a specified number of chunks.""" + def decorator(func): + def wrapper(self, *args, **kwargs): + inputs_embeds = kwargs.get('inputs_embeds', None) + if inputs_embeds is not None: + split_embeds = list(torch.chunk(inputs_embeds, 
num_splits, dim=0)) + kwargs['inputs_embeds'] = split_embeds + results = func(self, *args, **kwargs) + return torch.cat(results, dim=0) + return wrapper + return decorator + + @split_inputs_embeds(num_splits=2) def forward( self, inputs_embeds, From 0462005de22e8f88a438f6289e1b0daf994220e0 Mon Sep 17 00:00:00 2001 From: caikun-pjlab Date: Mon, 24 Mar 2025 15:02:23 +0800 Subject: [PATCH 03/16] enable reorder_for_compute_comm_overlap --- lmdeploy/pytorch/models/internvl.py | 1 + requirements/runtime_cuda.txt | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/lmdeploy/pytorch/models/internvl.py b/lmdeploy/pytorch/models/internvl.py index 64b4fc5372..dbcb79ab7d 100644 --- a/lmdeploy/pytorch/models/internvl.py +++ b/lmdeploy/pytorch/models/internvl.py @@ -383,6 +383,7 @@ def forward( ): if inputs_embeds is None and pixel_values is not None: # extract feature + torch._inductor.config.reorder_for_compute_comm_overlap = True torch._dynamo.mark_dynamic(pixel_values, 0) vit_embeds = self.extract_feature(pixel_values) lang_embeds = self.language_model.get_input_embeddings()(input_ids) diff --git a/requirements/runtime_cuda.txt b/requirements/runtime_cuda.txt index 88c6270877..9169c6ff30 100644 --- a/requirements/runtime_cuda.txt +++ b/requirements/runtime_cuda.txt @@ -16,8 +16,8 @@ safetensors sentencepiece shortuuid tiktoken -torch<=2.5.1,>=2.0.0 -torchvision<=0.20.1,>=0.15.0 +torch<=2.6.0,>=2.0.0 +torchvision<=0.21.0,>=0.15.0 transformers -triton<=3.1.0,>=3.0.0; sys_platform == "linux" +triton<=3.2.0,>=3.0.0; sys_platform == "linux" uvicorn From 92dbc6c047d121415abdba32b4cebebe708bd4ad Mon Sep 17 00:00:00 2001 From: caikun-pjlab Date: Tue, 25 Mar 2025 17:20:31 +0800 Subject: [PATCH 04/16] refactor code --- .../pytorch/backends/cuda/graph_runner.py | 31 +++++++++++++++++++ lmdeploy/pytorch/models/internvl.py | 20 ++---------- 2 files changed, 34 insertions(+), 17 deletions(-) diff --git a/lmdeploy/pytorch/backends/cuda/graph_runner.py b/lmdeploy/pytorch/backends/cuda/graph_runner.py index 7f26f59980..6872fa427c 100644 --- a/lmdeploy/pytorch/backends/cuda/graph_runner.py +++ b/lmdeploy/pytorch/backends/cuda/graph_runner.py @@ -4,7 +4,9 @@ import torch from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig +from lmdeploy.pytorch.distributed import get_world_rank from lmdeploy.pytorch.model_inputs import StepContext +from lmdeploy.pytorch.models.internvl import InternVLChatModel, InternVisionModel from lmdeploy.pytorch.models.utils.cudagraph import CudaGraphMeta from lmdeploy.utils import get_logger @@ -102,6 +104,18 @@ def __del__(self): del self._graph +def split_batch(func, param_name, num_splits=2): + """Decorator to split along the 0th dimension into a specified number of chunks.""" + def decorator(*args, **kwargs): + inputs = kwargs.get(param_name, None) + if inputs is not None: + split_inputs = list(torch.chunk(inputs, num_splits, dim=0)) + kwargs[param_name] = split_inputs + results = func(*args, **kwargs) + return torch.cat(results, dim=0) + return decorator + + class CUDAGraphRunner(GraphRunner): """cuda graph runner.""" @@ -112,7 +126,17 @@ def __init__(self, model: torch.nn.Module, model_config: ModelConfig, cache_conf self.max_tokens = cache_config.max_prefill_token_num self.num_blocks = cache_config.num_gpu_blocks + self.use_compile = False self.enable_graph = self.check_enable_graph() + if not self.backend_config.eager_mode and isinstance(self.model, InternVLChatModel): + world_size, _ = get_world_rank() + if world_size > 1: + 
torch._inductor.config.reorder_for_compute_comm_overlap = True + if isinstance(self.model.vision_model, InternVisionModel): + self.model.vision_model.encoder.forward = split_batch(self.model.vision_model.encoder.forward, "inputs_embeds") + self.model.extract_feature = torch.compile(self.model.extract_feature, mode="max-autotune") + self.use_compile = True + self.compiled = False self.graph_pool_handle = torch.cuda.graph_pool_handle() self._runner_map: Dict[Any, CUDASingleGraphRunner] = dict() @@ -135,6 +159,13 @@ def get_graph_key(self, input_ids: torch.Tensor, position_ids: torch.Tensor, pas def __call__(self, **kwargs): """call.""" + if self.use_compile and not self.compiled: + for tensor_name, dynamic_dims in self.model.compile_dynamic_args.items(): + tensor = kwargs.get(tensor_name, None) + if tensor is None: + continue + torch._dynamo.mark_dynamic(tensor, dynamic_dims) + self.compiled = True enable_graph = self.enable_graph(**kwargs) if not enable_graph: diff --git a/lmdeploy/pytorch/models/internvl.py b/lmdeploy/pytorch/models/internvl.py index dbcb79ab7d..e248a3a81b 100644 --- a/lmdeploy/pytorch/models/internvl.py +++ b/lmdeploy/pytorch/models/internvl.py @@ -244,20 +244,6 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: self.layers = nn.ModuleList( [InternVisionEncoderLayer(config, dtype=dtype, device=device) for idx in range(config.num_hidden_layers)]) - def split_inputs_embeds(num_splits=2): - """Decorator to split inputs_embeds along the 0th dimension into a specified number of chunks.""" - def decorator(func): - def wrapper(self, *args, **kwargs): - inputs_embeds = kwargs.get('inputs_embeds', None) - if inputs_embeds is not None: - split_embeds = list(torch.chunk(inputs_embeds, num_splits, dim=0)) - kwargs['inputs_embeds'] = split_embeds - results = func(self, *args, **kwargs) - return torch.cat(results, dim=0) - return wrapper - return decorator - - @split_inputs_embeds(num_splits=2) def forward( self, inputs_embeds, @@ -338,6 +324,9 @@ def __init__(self, self.input_processor = InternVLInputProcessor(self.config, dtype) + # for torch.compile, will call torch._dynamo.mark_dynamic to reduce recompile + self.compile_dynamic_args = {"pixel_values": [0]} + def pixel_shuffle(self, x, scale_factor=0.5): n, w, h, c = x.size() # N, W, H, C --> N, W, H * scale, C // scale @@ -350,7 +339,6 @@ def pixel_shuffle(self, x, scale_factor=0.5): x = x.permute(0, 2, 1, 3).contiguous() return x - @torch.compile(mode="max-autotune-no-cudagraphs") def extract_feature(self, pixel_values): """extract vision feature.""" assert self.select_layer == -1 @@ -383,8 +371,6 @@ def forward( ): if inputs_embeds is None and pixel_values is not None: # extract feature - torch._inductor.config.reorder_for_compute_comm_overlap = True - torch._dynamo.mark_dynamic(pixel_values, 0) vit_embeds = self.extract_feature(pixel_values) lang_embeds = self.language_model.get_input_embeddings()(input_ids) lang_embeds.masked_scatter_(image_mask[..., None], vit_embeds) From d995bdbf9bb38e8621c5004b54d9cb2eb688f004 Mon Sep 17 00:00:00 2001 From: caikun-pjlab Date: Wed, 26 Mar 2025 19:41:29 +0800 Subject: [PATCH 05/16] refactor code --- .../pytorch/backends/cuda/graph_runner.py | 53 ++++++++++--------- lmdeploy/pytorch/decorators.py | 26 +++++++++ lmdeploy/pytorch/models/internvl.py | 11 +--- 3 files changed, 56 insertions(+), 34 deletions(-) create mode 100644 lmdeploy/pytorch/decorators.py diff --git a/lmdeploy/pytorch/backends/cuda/graph_runner.py b/lmdeploy/pytorch/backends/cuda/graph_runner.py 
index 6872fa427c..4c6150f531 100644 --- a/lmdeploy/pytorch/backends/cuda/graph_runner.py +++ b/lmdeploy/pytorch/backends/cuda/graph_runner.py @@ -4,6 +4,7 @@ import torch from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig +from lmdeploy.pytorch.decorators import split_batch from lmdeploy.pytorch.distributed import get_world_rank from lmdeploy.pytorch.model_inputs import StepContext from lmdeploy.pytorch.models.internvl import InternVLChatModel, InternVisionModel @@ -104,18 +105,6 @@ def __del__(self): del self._graph -def split_batch(func, param_name, num_splits=2): - """Decorator to split along the 0th dimension into a specified number of chunks.""" - def decorator(*args, **kwargs): - inputs = kwargs.get(param_name, None) - if inputs is not None: - split_inputs = list(torch.chunk(inputs, num_splits, dim=0)) - kwargs[param_name] = split_inputs - results = func(*args, **kwargs) - return torch.cat(results, dim=0) - return decorator - - class CUDAGraphRunner(GraphRunner): """cuda graph runner.""" @@ -126,8 +115,12 @@ def __init__(self, model: torch.nn.Module, model_config: ModelConfig, cache_conf self.max_tokens = cache_config.max_prefill_token_num self.num_blocks = cache_config.num_gpu_blocks - self.use_compile = False self.enable_graph = self.check_enable_graph() + + self.graph_pool_handle = torch.cuda.graph_pool_handle() + self._runner_map: Dict[Any, CUDASingleGraphRunner] = dict() + + self.compile_vit = False if not self.backend_config.eager_mode and isinstance(self.model, InternVLChatModel): world_size, _ = get_world_rank() if world_size > 1: @@ -135,12 +128,9 @@ def __init__(self, model: torch.nn.Module, model_config: ModelConfig, cache_conf if isinstance(self.model.vision_model, InternVisionModel): self.model.vision_model.encoder.forward = split_batch(self.model.vision_model.encoder.forward, "inputs_embeds") self.model.extract_feature = torch.compile(self.model.extract_feature, mode="max-autotune") - self.use_compile = True + self.compile_vit = True self.compiled = False - self.graph_pool_handle = torch.cuda.graph_pool_handle() - self._runner_map: Dict[Any, CUDASingleGraphRunner] = dict() - def check_enable_graph(self): """check enable graph.""" if self.backend_config.eager_mode: @@ -157,15 +147,30 @@ def get_graph_key(self, input_ids: torch.Tensor, position_ids: torch.Tensor, pas new_num_tokens = next_power_of_2(num_tokens) return (new_num_tokens, is_decoding) + def _is_decoding(self, attn_metadata: Any, **kwargs): + return attn_metadata.is_decoding + + def _mark_dynamic_once(self, **kwargs): + """call torch._dynamo.mark_dynamic to avoid recompile""" + if self.compiled: + return + + for tensor_name, dynamic_dims in self.model.compile_dynamic_args.items(): + tensor = kwargs.get(tensor_name, None) + if tensor is None: + continue + torch._dynamo.mark_dynamic(tensor, dynamic_dims) + self.compiled = True + def __call__(self, **kwargs): """call.""" - if self.use_compile and not self.compiled: - for tensor_name, dynamic_dims in self.model.compile_dynamic_args.items(): - tensor = kwargs.get(tensor_name, None) - if tensor is None: - continue - torch._dynamo.mark_dynamic(tensor, dynamic_dims) - self.compiled = True + if self.backend_config.eager_mode: + return self.model(**kwargs) + + if self.compile_vit and not self._is_decoding(**kwargs): + self._mark_dynamic_once(**kwargs) + return self.model(**kwargs) + enable_graph = self.enable_graph(**kwargs) if not enable_graph: diff --git a/lmdeploy/pytorch/decorators.py b/lmdeploy/pytorch/decorators.py new file mode 100644 
index 0000000000..4b28604b1b --- /dev/null +++ b/lmdeploy/pytorch/decorators.py @@ -0,0 +1,26 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + + +def enable_micro_batch(func): + """Decorator to enable micro-batch computation.""" + def wrapper(self, hidden_states, *args, **kwargs): + if isinstance(hidden_states, list): + # Apply forward computation to each micro-batch + return [func(self, hs, *args, **kwargs) for hs in hidden_states] + else: + # If not a list, directly apply the forward computation + return func(self, hidden_states, *args, **kwargs) + return wrapper + + +def split_batch(func, param_name, num_splits=2): + """Decorator to split along the 0th dimension into a specified number of chunks.""" + def wrapper(*args, **kwargs): + inputs = kwargs.get(param_name, None) + if inputs is not None: + split_inputs = list(torch.chunk(inputs, num_splits, dim=0)) + kwargs[param_name] = split_inputs + results = func(*args, **kwargs) + return torch.cat(results, dim=0) + return wrapper \ No newline at end of file diff --git a/lmdeploy/pytorch/models/internvl.py b/lmdeploy/pytorch/models/internvl.py index e248a3a81b..50df6e54d7 100644 --- a/lmdeploy/pytorch/models/internvl.py +++ b/lmdeploy/pytorch/models/internvl.py @@ -7,6 +7,7 @@ from torch import nn from transformers.configuration_utils import PretrainedConfig +from lmdeploy.pytorch.decorators import enable_micro_batch from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor @@ -205,16 +206,6 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: self.ls1 = nn.Parameter(torch.empty(self.embed_dim, dtype=dtype, device=device)) self.ls2 = nn.Parameter(torch.empty(self.embed_dim, dtype=dtype, device=device)) - def enable_micro_batch(func): - """Decorator to enable micro-batch computation.""" - def wrapper(self, hidden_states, *args, **kwargs): - if isinstance(hidden_states, list): - # Apply forward computation to each micro-batch - return [func(self, hs, *args, **kwargs) for hs in hidden_states] - else: - # If not a list, directly apply the forward computation - return func(self, hidden_states, *args, **kwargs) - return wrapper @enable_micro_batch def _attn(self, hidden_states): From 76729bfc3a8d2e0d8ad735eefad85c7870d13f09 Mon Sep 17 00:00:00 2001 From: jinminxi104 Date: Tue, 1 Apr 2025 15:57:56 +0000 Subject: [PATCH 06/16] refine code --- lmdeploy/pytorch/backends/cuda/graph_runner.py | 11 ++++++----- lmdeploy/pytorch/models/internvl.py | 2 +- .../{decorators.py => models/utils/micro_batch.py} | 0 3 files changed, 7 insertions(+), 6 deletions(-) rename lmdeploy/pytorch/{decorators.py => models/utils/micro_batch.py} (100%) diff --git a/lmdeploy/pytorch/backends/cuda/graph_runner.py b/lmdeploy/pytorch/backends/cuda/graph_runner.py index 4c6150f531..df28e2ecd0 100644 --- a/lmdeploy/pytorch/backends/cuda/graph_runner.py +++ b/lmdeploy/pytorch/backends/cuda/graph_runner.py @@ -4,12 +4,13 @@ import torch from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig -from lmdeploy.pytorch.decorators import split_batch +from lmdeploy.pytorch.models.utils.micro_batch import split_batch from lmdeploy.pytorch.distributed import get_world_rank from lmdeploy.pytorch.model_inputs import StepContext from lmdeploy.pytorch.models.internvl import InternVLChatModel, InternVisionModel from 
lmdeploy.pytorch.models.utils.cudagraph import CudaGraphMeta from lmdeploy.utils import get_logger +from packaging import version from ..graph_runner import GraphRunner @@ -123,13 +124,13 @@ def __init__(self, model: torch.nn.Module, model_config: ModelConfig, cache_conf self.compile_vit = False if not self.backend_config.eager_mode and isinstance(self.model, InternVLChatModel): world_size, _ = get_world_rank() - if world_size > 1: + if version.parse(torch.__version__) >= version.parse('2.6.0') and torch.__world_size > 1: torch._inductor.config.reorder_for_compute_comm_overlap = True if isinstance(self.model.vision_model, InternVisionModel): self.model.vision_model.encoder.forward = split_batch(self.model.vision_model.encoder.forward, "inputs_embeds") self.model.extract_feature = torch.compile(self.model.extract_feature, mode="max-autotune") self.compile_vit = True - self.compiled = False + self.has_compiled_vit = False def check_enable_graph(self): """check enable graph.""" @@ -152,7 +153,7 @@ def _is_decoding(self, attn_metadata: Any, **kwargs): def _mark_dynamic_once(self, **kwargs): """call torch._dynamo.mark_dynamic to avoid recompile""" - if self.compiled: + if self.has_compiled_vit: return for tensor_name, dynamic_dims in self.model.compile_dynamic_args.items(): @@ -160,7 +161,7 @@ def _mark_dynamic_once(self, **kwargs): if tensor is None: continue torch._dynamo.mark_dynamic(tensor, dynamic_dims) - self.compiled = True + self.has_compiled_vit = True def __call__(self, **kwargs): """call.""" diff --git a/lmdeploy/pytorch/models/internvl.py b/lmdeploy/pytorch/models/internvl.py index 50df6e54d7..e75e565931 100644 --- a/lmdeploy/pytorch/models/internvl.py +++ b/lmdeploy/pytorch/models/internvl.py @@ -7,7 +7,7 @@ from torch import nn from transformers.configuration_utils import PretrainedConfig -from lmdeploy.pytorch.decorators import enable_micro_batch +from lmdeploy.pytorch.models.utils.micro_batch import enable_micro_batch from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor diff --git a/lmdeploy/pytorch/decorators.py b/lmdeploy/pytorch/models/utils/micro_batch.py similarity index 100% rename from lmdeploy/pytorch/decorators.py rename to lmdeploy/pytorch/models/utils/micro_batch.py From 6c150064d77677923dde11f3633938e552f929df Mon Sep 17 00:00:00 2001 From: jinminxi104 Date: Tue, 1 Apr 2025 16:15:49 +0000 Subject: [PATCH 07/16] lint --- lmdeploy/pytorch/backends/cuda/graph_runner.py | 13 +++++++------ lmdeploy/pytorch/models/internvl.py | 5 ++--- lmdeploy/pytorch/models/utils/micro_batch.py | 9 +++++++-- 3 files changed, 16 insertions(+), 11 deletions(-) diff --git a/lmdeploy/pytorch/backends/cuda/graph_runner.py b/lmdeploy/pytorch/backends/cuda/graph_runner.py index df28e2ecd0..9665b2956b 100644 --- a/lmdeploy/pytorch/backends/cuda/graph_runner.py +++ b/lmdeploy/pytorch/backends/cuda/graph_runner.py @@ -2,15 +2,15 @@ from typing import Any, Dict, List, Tuple import torch +from packaging import version from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig -from lmdeploy.pytorch.models.utils.micro_batch import split_batch from lmdeploy.pytorch.distributed import get_world_rank from lmdeploy.pytorch.model_inputs import StepContext -from lmdeploy.pytorch.models.internvl import InternVLChatModel, InternVisionModel +from lmdeploy.pytorch.models.internvl import InternVisionModel, 
InternVLChatModel from lmdeploy.pytorch.models.utils.cudagraph import CudaGraphMeta +from lmdeploy.pytorch.models.utils.micro_batch import split_batch from lmdeploy.utils import get_logger -from packaging import version from ..graph_runner import GraphRunner @@ -127,8 +127,9 @@ def __init__(self, model: torch.nn.Module, model_config: ModelConfig, cache_conf if version.parse(torch.__version__) >= version.parse('2.6.0') and torch.__world_size > 1: torch._inductor.config.reorder_for_compute_comm_overlap = True if isinstance(self.model.vision_model, InternVisionModel): - self.model.vision_model.encoder.forward = split_batch(self.model.vision_model.encoder.forward, "inputs_embeds") - self.model.extract_feature = torch.compile(self.model.extract_feature, mode="max-autotune") + self.model.vision_model.encoder.forward = split_batch(self.model.vision_model.encoder.forward, + 'inputs_embeds') + self.model.extract_feature = torch.compile(self.model.extract_feature, mode='max-autotune') self.compile_vit = True self.has_compiled_vit = False @@ -152,7 +153,7 @@ def _is_decoding(self, attn_metadata: Any, **kwargs): return attn_metadata.is_decoding def _mark_dynamic_once(self, **kwargs): - """call torch._dynamo.mark_dynamic to avoid recompile""" + """call torch._dynamo.mark_dynamic to avoid recompile.""" if self.has_compiled_vit: return diff --git a/lmdeploy/pytorch/models/internvl.py b/lmdeploy/pytorch/models/internvl.py index e75e565931..43df1dd54c 100644 --- a/lmdeploy/pytorch/models/internvl.py +++ b/lmdeploy/pytorch/models/internvl.py @@ -7,9 +7,9 @@ from torch import nn from transformers.configuration_utils import PretrainedConfig -from lmdeploy.pytorch.models.utils.micro_batch import enable_micro_batch from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager +from lmdeploy.pytorch.models.utils.micro_batch import enable_micro_batch from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor from lmdeploy.pytorch.nn import LayerNorm, RMSNorm from lmdeploy.pytorch.nn.linear import build_colwise_linear, build_qkv_proj, build_rowwise_linear @@ -206,7 +206,6 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: self.ls1 = nn.Parameter(torch.empty(self.embed_dim, dtype=dtype, device=device)) self.ls2 = nn.Parameter(torch.empty(self.embed_dim, dtype=dtype, device=device)) - @enable_micro_batch def _attn(self, hidden_states): hidden_states = hidden_states + self.attn(self.norm1(hidden_states).to(hidden_states[0].dtype)) * self.ls1 @@ -316,7 +315,7 @@ def __init__(self, self.input_processor = InternVLInputProcessor(self.config, dtype) # for torch.compile, will call torch._dynamo.mark_dynamic to reduce recompile - self.compile_dynamic_args = {"pixel_values": [0]} + self.compile_dynamic_args = {'pixel_values': [0]} def pixel_shuffle(self, x, scale_factor=0.5): n, w, h, c = x.size() diff --git a/lmdeploy/pytorch/models/utils/micro_batch.py b/lmdeploy/pytorch/models/utils/micro_batch.py index 4b28604b1b..322efb2e86 100644 --- a/lmdeploy/pytorch/models/utils/micro_batch.py +++ b/lmdeploy/pytorch/models/utils/micro_batch.py @@ -4,6 +4,7 @@ def enable_micro_batch(func): """Decorator to enable micro-batch computation.""" + def wrapper(self, hidden_states, *args, **kwargs): if isinstance(hidden_states, list): # Apply forward computation to each micro-batch @@ -11,11 +12,14 @@ def wrapper(self, hidden_states, *args, **kwargs): else: # If not a list, directly apply the 
forward computation return func(self, hidden_states, *args, **kwargs) + return wrapper def split_batch(func, param_name, num_splits=2): - """Decorator to split along the 0th dimension into a specified number of chunks.""" + """Decorator to split along the 0th dimension into a specified number of + chunks.""" + def wrapper(*args, **kwargs): inputs = kwargs.get(param_name, None) if inputs is not None: @@ -23,4 +27,5 @@ def wrapper(*args, **kwargs): kwargs[param_name] = split_inputs results = func(*args, **kwargs) return torch.cat(results, dim=0) - return wrapper \ No newline at end of file + + return wrapper From 7a6cc209ae43f5c27de91adcde6a70e82eaa6e2d Mon Sep 17 00:00:00 2001 From: caikun-pjlab Date: Wed, 2 Apr 2025 12:22:55 +0800 Subject: [PATCH 08/16] use torch.compile when torch.version ge 2.5.0 --- lmdeploy/pytorch/backends/cuda/graph_runner.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/lmdeploy/pytorch/backends/cuda/graph_runner.py b/lmdeploy/pytorch/backends/cuda/graph_runner.py index 9665b2956b..bd90f6412e 100644 --- a/lmdeploy/pytorch/backends/cuda/graph_runner.py +++ b/lmdeploy/pytorch/backends/cuda/graph_runner.py @@ -122,9 +122,11 @@ def __init__(self, model: torch.nn.Module, model_config: ModelConfig, cache_conf self._runner_map: Dict[Any, CUDASingleGraphRunner] = dict() self.compile_vit = False - if not self.backend_config.eager_mode and isinstance(self.model, InternVLChatModel): + torch_version = version.parse(torch.__version__) + if not self.backend_config.eager_mode and isinstance( + self.model, InternVLChatModel) and torch_version >= version.parse('2.5.0'): world_size, _ = get_world_rank() - if version.parse(torch.__version__) >= version.parse('2.6.0') and torch.__world_size > 1: + if torch_version >= version.parse('2.6.0') and world_size > 1: torch._inductor.config.reorder_for_compute_comm_overlap = True if isinstance(self.model.vision_model, InternVisionModel): self.model.vision_model.encoder.forward = split_batch(self.model.vision_model.encoder.forward, From ea79d157082105a3279b8f2295b5f782d33355e7 Mon Sep 17 00:00:00 2001 From: Lyu Han Date: Tue, 1 Apr 2025 00:39:44 +0800 Subject: [PATCH 09/16] set cmake policy minimum version as 3.5 (#3376) --- .github/workflows/unit-test.yml | 1 + generate.sh | 1 + 2 files changed, 2 insertions(+) diff --git a/.github/workflows/unit-test.yml b/.github/workflows/unit-test.yml index 5456a3d668..5a39caa859 100644 --- a/.github/workflows/unit-test.yml +++ b/.github/workflows/unit-test.yml @@ -72,6 +72,7 @@ jobs: -DUSE_NVTX=ON \ -DSM=80 \ -DCMAKE_CUDA_ARCHITECTURES=80 \ + -DCMAKE_POLICY_VERSION_MINIMUM=3.5 \ -DBUILD_TEST=OFF make -j$(nproc) && make install - name: Install lmdeploy diff --git a/generate.sh b/generate.sh index 0c25b8cbf2..5e21d50885 100755 --- a/generate.sh +++ b/generate.sh @@ -14,4 +14,5 @@ cmake ${builder} .. 
\ -DBUILD_PY_FFI=ON \ -DBUILD_MULTI_GPU=ON \ -DCMAKE_CUDA_FLAGS="-lineinfo" \ + -DCMAKE_POLICY_VERSION_MINIMUM=3.5 \ -DUSE_NVTX=ON From b111a9a72d43203e6581b47cbc746b8c31468d66 Mon Sep 17 00:00:00 2001 From: caikun-pjlab Date: Thu, 3 Apr 2025 15:50:05 +0800 Subject: [PATCH 10/16] refactor micro batch --- lmdeploy/pytorch/models/internvl.py | 4 +-- lmdeploy/pytorch/models/utils/micro_batch.py | 37 ++++++++++++++------ 2 files changed, 28 insertions(+), 13 deletions(-) diff --git a/lmdeploy/pytorch/models/internvl.py b/lmdeploy/pytorch/models/internvl.py index 43df1dd54c..cf443cea8b 100644 --- a/lmdeploy/pytorch/models/internvl.py +++ b/lmdeploy/pytorch/models/internvl.py @@ -206,12 +206,12 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: self.ls1 = nn.Parameter(torch.empty(self.embed_dim, dtype=dtype, device=device)) self.ls2 = nn.Parameter(torch.empty(self.embed_dim, dtype=dtype, device=device)) - @enable_micro_batch + @enable_micro_batch(param_name='hidden_states') def _attn(self, hidden_states): hidden_states = hidden_states + self.attn(self.norm1(hidden_states).to(hidden_states[0].dtype)) * self.ls1 return hidden_states - @enable_micro_batch + @enable_micro_batch(param_name='hidden_states') def _mlp(self, hidden_states): hidden_states = hidden_states + self.mlp(self.norm2(hidden_states).to(hidden_states.dtype)) * self.ls2 return hidden_states diff --git a/lmdeploy/pytorch/models/utils/micro_batch.py b/lmdeploy/pytorch/models/utils/micro_batch.py index 322efb2e86..1d951be7eb 100644 --- a/lmdeploy/pytorch/models/utils/micro_batch.py +++ b/lmdeploy/pytorch/models/utils/micro_batch.py @@ -1,19 +1,32 @@ # Copyright (c) OpenMMLab. All rights reserved. +import functools + import torch -def enable_micro_batch(func): - """Decorator to enable micro-batch computation.""" +def enable_micro_batch(param_name): + """Decorator factory to enable micro-batch computation.""" - def wrapper(self, hidden_states, *args, **kwargs): - if isinstance(hidden_states, list): - # Apply forward computation to each micro-batch - return [func(self, hs, *args, **kwargs) for hs in hidden_states] - else: - # If not a list, directly apply the forward computation - return func(self, hidden_states, *args, **kwargs) + def decorator(func): - return wrapper + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + inputs = kwargs.get(param_name, None) + if isinstance(inputs, list): + # Apply forward computation to each micro-batch + results = [] + for input in inputs: + kwargs[param_name] = input + result = func(self, *args, **kwargs) + results.append(result) + return results + else: + # If not a list, directly apply the forward computation + return func(self, *args, **kwargs) + + return wrapper + + return decorator def split_batch(func, param_name, num_splits=2): @@ -26,6 +39,8 @@ def wrapper(*args, **kwargs): split_inputs = list(torch.chunk(inputs, num_splits, dim=0)) kwargs[param_name] = split_inputs results = func(*args, **kwargs) - return torch.cat(results, dim=0) + return torch.cat(results, dim=0) + else: + return func(*args, **kwargs) return wrapper From 7e65ad32e66959bf23bbfc9a998bf68a7f5bff60 Mon Sep 17 00:00:00 2001 From: caikun-pjlab Date: Thu, 3 Apr 2025 17:28:54 +0800 Subject: [PATCH 11/16] refactor torch.compile --- .../pytorch/backends/cuda/graph_runner.py | 52 +++++-------------- lmdeploy/pytorch/models/internvl.py | 34 ++++++++++-- 2 files changed, 44 insertions(+), 42 deletions(-) diff --git a/lmdeploy/pytorch/backends/cuda/graph_runner.py 
b/lmdeploy/pytorch/backends/cuda/graph_runner.py index bd90f6412e..c8e51b6e01 100644 --- a/lmdeploy/pytorch/backends/cuda/graph_runner.py +++ b/lmdeploy/pytorch/backends/cuda/graph_runner.py @@ -2,14 +2,10 @@ from typing import Any, Dict, List, Tuple import torch -from packaging import version from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig -from lmdeploy.pytorch.distributed import get_world_rank from lmdeploy.pytorch.model_inputs import StepContext -from lmdeploy.pytorch.models.internvl import InternVisionModel, InternVLChatModel from lmdeploy.pytorch.models.utils.cudagraph import CudaGraphMeta -from lmdeploy.pytorch.models.utils.micro_batch import split_batch from lmdeploy.utils import get_logger from ..graph_runner import GraphRunner @@ -120,20 +116,7 @@ def __init__(self, model: torch.nn.Module, model_config: ModelConfig, cache_conf self.graph_pool_handle = torch.cuda.graph_pool_handle() self._runner_map: Dict[Any, CUDASingleGraphRunner] = dict() - - self.compile_vit = False - torch_version = version.parse(torch.__version__) - if not self.backend_config.eager_mode and isinstance( - self.model, InternVLChatModel) and torch_version >= version.parse('2.5.0'): - world_size, _ = get_world_rank() - if torch_version >= version.parse('2.6.0') and world_size > 1: - torch._inductor.config.reorder_for_compute_comm_overlap = True - if isinstance(self.model.vision_model, InternVisionModel): - self.model.vision_model.encoder.forward = split_batch(self.model.vision_model.encoder.forward, - 'inputs_embeds') - self.model.extract_feature = torch.compile(self.model.extract_feature, mode='max-autotune') - self.compile_vit = True - self.has_compiled_vit = False + self.has_try_compile_model: bool = False def check_enable_graph(self): """check enable graph.""" @@ -142,6 +125,16 @@ def check_enable_graph(self): return getattr(self.model, 'support_cuda_graph', _false) + def _try_compile_model_once(self): + if self.has_try_compile_model: + return + + if hasattr(self.model, 'compile_model'): + method = getattr(self.model, 'compile_model') + method() + + self.has_try_compile_model = True + def get_graph_key(self, input_ids: torch.Tensor, position_ids: torch.Tensor, past_key_values: List, attn_metadata: Any, inputs_embeds: torch.Tensor, **kwargs): """get graph key.""" @@ -151,29 +144,10 @@ def get_graph_key(self, input_ids: torch.Tensor, position_ids: torch.Tensor, pas new_num_tokens = next_power_of_2(num_tokens) return (new_num_tokens, is_decoding) - def _is_decoding(self, attn_metadata: Any, **kwargs): - return attn_metadata.is_decoding - - def _mark_dynamic_once(self, **kwargs): - """call torch._dynamo.mark_dynamic to avoid recompile.""" - if self.has_compiled_vit: - return - - for tensor_name, dynamic_dims in self.model.compile_dynamic_args.items(): - tensor = kwargs.get(tensor_name, None) - if tensor is None: - continue - torch._dynamo.mark_dynamic(tensor, dynamic_dims) - self.has_compiled_vit = True - def __call__(self, **kwargs): """call.""" - if self.backend_config.eager_mode: - return self.model(**kwargs) - - if self.compile_vit and not self._is_decoding(**kwargs): - self._mark_dynamic_once(**kwargs) - return self.model(**kwargs) + if not self.backend_config.eager_mode: + self._try_compile_model_once() enable_graph = self.enable_graph(**kwargs) diff --git a/lmdeploy/pytorch/models/internvl.py b/lmdeploy/pytorch/models/internvl.py index cf443cea8b..1eaab747eb 100644 --- a/lmdeploy/pytorch/models/internvl.py +++ b/lmdeploy/pytorch/models/internvl.py @@ -4,12 +4,15 @@ import torch 
import torch.nn.functional as F +from packaging import version from torch import nn from transformers.configuration_utils import PretrainedConfig +from lmdeploy.pytorch.backends.selector import get_backend +from lmdeploy.pytorch.distributed import get_world_rank from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager -from lmdeploy.pytorch.models.utils.micro_batch import enable_micro_batch +from lmdeploy.pytorch.models.utils.micro_batch import enable_micro_batch, split_batch from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor from lmdeploy.pytorch.nn import LayerNorm, RMSNorm from lmdeploy.pytorch.nn.linear import build_colwise_linear, build_qkv_proj, build_rowwise_linear @@ -314,8 +317,32 @@ def __init__(self, self.input_processor = InternVLInputProcessor(self.config, dtype) - # for torch.compile, will call torch._dynamo.mark_dynamic to reduce recompile - self.compile_dynamic_args = {'pixel_values': [0]} + self.compile_vit = False + if get_backend().get_name() == 'cuda': + self.compile_model() + + def compile_model(self): + torch_version = version.parse(torch.__version__) + if torch_version < version.parse('2.5.0'): + return + + world_size, _ = get_world_rank() + if torch_version >= version.parse('2.6.0') and world_size > 1: + torch._inductor.config.reorder_for_compute_comm_overlap = True + if isinstance(self.vision_model, InternVisionModel): + self.vision_model.encoder.forward = split_batch(self.vision_model.encoder.forward, 'inputs_embeds') + + self.extract_feature = torch.compile(self.extract_feature, mode='max-autotune') + self.compile_vit = True + self.has_compiled_vit = False + + def _mark_dynamic_once(self, pixel_values, dims): + """call torch._dynamo.mark_dynamic to avoid recompile.""" + if not self.compile_vit or self.has_compiled_vit or pixel_values is None: + return + + torch._dynamo.mark_dynamic(pixel_values, dims) + self.has_compiled_vit = True def pixel_shuffle(self, x, scale_factor=0.5): n, w, h, c = x.size() @@ -361,6 +388,7 @@ def forward( ): if inputs_embeds is None and pixel_values is not None: # extract feature + self._mark_dynamic_once(pixel_values, [0]) vit_embeds = self.extract_feature(pixel_values) lang_embeds = self.language_model.get_input_embeddings()(input_ids) lang_embeds.masked_scatter_(image_mask[..., None], vit_embeds) From 0222d47a32ebddd4800e958226780f871c590203 Mon Sep 17 00:00:00 2001 From: caikun-pjlab Date: Sat, 5 Apr 2025 16:08:12 +0000 Subject: [PATCH 12/16] make micro_batch more general --- .../pytorch/backends/cuda/graph_runner.py | 3 ++- lmdeploy/pytorch/models/internvl.py | 11 ++++---- lmdeploy/pytorch/models/utils/micro_batch.py | 27 ++++++++++++++----- 3 files changed, 28 insertions(+), 13 deletions(-) diff --git a/lmdeploy/pytorch/backends/cuda/graph_runner.py b/lmdeploy/pytorch/backends/cuda/graph_runner.py index c8e51b6e01..bc5b582d05 100644 --- a/lmdeploy/pytorch/backends/cuda/graph_runner.py +++ b/lmdeploy/pytorch/backends/cuda/graph_runner.py @@ -3,6 +3,7 @@ import torch +from lmdeploy.pytorch.backends.selector import get_backend from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig from lmdeploy.pytorch.model_inputs import StepContext from lmdeploy.pytorch.models.utils.cudagraph import CudaGraphMeta @@ -146,7 +147,7 @@ def get_graph_key(self, input_ids: torch.Tensor, position_ids: torch.Tensor, pas def __call__(self, **kwargs): """call.""" - if not self.backend_config.eager_mode: + if 
not self.backend_config.eager_mode and get_backend().get_name() == 'cuda': self._try_compile_model_once() enable_graph = self.enable_graph(**kwargs) diff --git a/lmdeploy/pytorch/models/internvl.py b/lmdeploy/pytorch/models/internvl.py index 1eaab747eb..c6940a69d2 100644 --- a/lmdeploy/pytorch/models/internvl.py +++ b/lmdeploy/pytorch/models/internvl.py @@ -8,7 +8,6 @@ from torch import nn from transformers.configuration_utils import PretrainedConfig -from lmdeploy.pytorch.backends.selector import get_backend from lmdeploy.pytorch.distributed import get_world_rank from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager @@ -209,12 +208,12 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: self.ls1 = nn.Parameter(torch.empty(self.embed_dim, dtype=dtype, device=device)) self.ls2 = nn.Parameter(torch.empty(self.embed_dim, dtype=dtype, device=device)) - @enable_micro_batch(param_name='hidden_states') + @enable_micro_batch(param_name='hidden_states', index=0) def _attn(self, hidden_states): hidden_states = hidden_states + self.attn(self.norm1(hidden_states).to(hidden_states[0].dtype)) * self.ls1 return hidden_states - @enable_micro_batch(param_name='hidden_states') + @enable_micro_batch(param_name='hidden_states', index=0) def _mlp(self, hidden_states): hidden_states = hidden_states + self.mlp(self.norm2(hidden_states).to(hidden_states.dtype)) * self.ls2 return hidden_states @@ -318,8 +317,6 @@ def __init__(self, self.input_processor = InternVLInputProcessor(self.config, dtype) self.compile_vit = False - if get_backend().get_name() == 'cuda': - self.compile_model() def compile_model(self): torch_version = version.parse(torch.__version__) @@ -330,7 +327,9 @@ def compile_model(self): if torch_version >= version.parse('2.6.0') and world_size > 1: torch._inductor.config.reorder_for_compute_comm_overlap = True if isinstance(self.vision_model, InternVisionModel): - self.vision_model.encoder.forward = split_batch(self.vision_model.encoder.forward, 'inputs_embeds') + self.vision_model.encoder.forward = split_batch(self.vision_model.encoder.forward, + 'inputs_embeds', + index=0) self.extract_feature = torch.compile(self.extract_feature, mode='max-autotune') self.compile_vit = True diff --git a/lmdeploy/pytorch/models/utils/micro_batch.py b/lmdeploy/pytorch/models/utils/micro_batch.py index 1d951be7eb..aba340b7c8 100644 --- a/lmdeploy/pytorch/models/utils/micro_batch.py +++ b/lmdeploy/pytorch/models/utils/micro_batch.py @@ -4,19 +4,26 @@ import torch -def enable_micro_batch(param_name): +def enable_micro_batch(param_name, index=-1): """Decorator factory to enable micro-batch computation.""" def decorator(func): @functools.wraps(func) def wrapper(self, *args, **kwargs): - inputs = kwargs.get(param_name, None) + if index != -1 and len(args) > index: + inputs = args[index] + else: + inputs = kwargs.get(param_name, None) + if isinstance(inputs, list): # Apply forward computation to each micro-batch results = [] for input in inputs: - kwargs[param_name] = input + if index != -1 and len(args) > index: + args = args[0:index] + (input, ) + args[index + 1:] + else: + kwargs[param_name] = input result = func(self, *args, **kwargs) results.append(result) return results @@ -29,15 +36,23 @@ def wrapper(self, *args, **kwargs): return decorator -def split_batch(func, param_name, num_splits=2): +def split_batch(func, param_name, index=-1, num_splits=2): """Decorator to split 
along the 0th dimension into a specified number of chunks.""" def wrapper(*args, **kwargs): - inputs = kwargs.get(param_name, None) + if index != -1 and len(args) > index: + inputs = args[index] + else: + inputs = kwargs.get(param_name, None) + if inputs is not None: split_inputs = list(torch.chunk(inputs, num_splits, dim=0)) - kwargs[param_name] = split_inputs + if index != -1 and len(args) > index: + args = args[0:index] + (split_inputs, ) + args[index + 1:] + else: + kwargs[param_name] = split_inputs + results = func(*args, **kwargs) return torch.cat(results, dim=0) else: From a4c43b4f7f225558fa2b07d5c357a2ab898146e9 Mon Sep 17 00:00:00 2001 From: yaofengchen <67218893+yao-fengchen@users.noreply.github.com> Date: Fri, 11 Apr 2025 19:21:20 +0800 Subject: [PATCH 13/16] update ascend doc (#3420) --- docs/en/get_started/ascend/get_started.md | 10 +++++ docs/en/supported_models/supported_models.md | 37 ++++++++++--------- docs/zh_cn/get_started/ascend/get_started.md | 10 +++++ .../supported_models/supported_models.md | 36 ++++++++++-------- 4 files changed, 60 insertions(+), 33 deletions(-) diff --git a/docs/en/get_started/ascend/get_started.md b/docs/en/get_started/ascend/get_started.md index be39f179b3..004d0adb20 100644 --- a/docs/en/get_started/ascend/get_started.md +++ b/docs/en/get_started/ascend/get_started.md @@ -158,6 +158,16 @@ lmdeploy lite auto_awq $HF_MODEL --work-dir $WORK_DIR --device npu Please check [supported_models](../../supported_models/supported_models.md) before use this feature. +### w8a8 SMOOTH_QUANT + +Run the following commands to quantize weights on Atlas 800T A2. + +```bash +lmdeploy lite smooth_quant $HF_MODEL --work-dir $WORK_DIR --device npu +``` + +Please check [supported_models](../../supported_models/supported_models.md) before use this feature. + ### int8 KV-cache Quantization Ascend backend has supported offline int8 KV-cache Quantization on eager mode. 
diff --git a/docs/en/supported_models/supported_models.md b/docs/en/supported_models/supported_models.md index 176c10df46..2ac7a3ee6e 100644 --- a/docs/en/supported_models/supported_models.md +++ b/docs/en/supported_models/supported_models.md @@ -115,20 +115,23 @@ The following tables detail the models supported by LMDeploy's TurboMind engine ## PyTorchEngine on Huawei Ascend Platform -| Model | Size | Type | FP16/BF16(eager) | FP16/BF16(graph) | W4A16(eager) | -| :------------: | :------: | :--: | :--------------: | :--------------: | :----------: | -| Llama2 | 7B - 70B | LLM | Yes | Yes | Yes | -| Llama3 | 8B | LLM | Yes | Yes | Yes | -| Llama3.1 | 8B | LLM | Yes | Yes | Yes | -| InternLM2 | 7B - 20B | LLM | Yes | Yes | Yes | -| InternLM2.5 | 7B - 20B | LLM | Yes | Yes | Yes | -| InternLM3 | 8B | LLM | Yes | Yes | Yes | -| Mixtral | 8x7B | LLM | Yes | Yes | No | -| QWen1.5-MoE | A2.7B | LLM | Yes | - | No | -| QWen2(.5) | 7B | LLM | Yes | Yes | No | -| QWen2-MoE | A14.57B | LLM | Yes | - | No | -| DeepSeek-V2 | 16B | LLM | No | Yes | No | -| InternVL(v1.5) | 2B-26B | MLLM | Yes | - | Yes | -| InternVL2 | 1B-40B | MLLM | Yes | Yes | Yes | -| CogVLM2-chat | 19B | MLLM | Yes | No | - | -| GLM4V | 9B | MLLM | Yes | No | - | +| Model | Size | Type | FP16/BF16(eager) | FP16/BF16(graph) | W8A8(graph) | W4A16(eager) | +| :------------: | :------: | :--: | :--------------: | :--------------: | :---------: | :----------: | +| Llama2 | 7B - 70B | LLM | Yes | Yes | Yes | Yes | +| Llama3 | 8B | LLM | Yes | Yes | Yes | Yes | +| Llama3.1 | 8B | LLM | Yes | Yes | Yes | Yes | +| InternLM2 | 7B - 20B | LLM | Yes | Yes | Yes | Yes | +| InternLM2.5 | 7B - 20B | LLM | Yes | Yes | Yes | Yes | +| InternLM3 | 8B | LLM | Yes | Yes | Yes | Yes | +| Mixtral | 8x7B | LLM | Yes | Yes | No | No | +| QWen1.5-MoE | A2.7B | LLM | Yes | - | No | No | +| QWen2(.5) | 7B | LLM | Yes | Yes | Yes | Yes | +| QWen2-VL | 2B, 7B | MLLM | Yes | Yes | - | - | +| QWen2.5-VL | 3B - 72B | MLLM | Yes | Yes | - | - | +| QWen2-MoE | A14.57B | LLM | Yes | - | No | No | +| DeepSeek-V2 | 16B | LLM | No | Yes | No | No | +| InternVL(v1.5) | 2B-26B | MLLM | Yes | - | Yes | Yes | +| InternVL2 | 1B-40B | MLLM | Yes | Yes | Yes | Yes | +| InternVL2.5 | 1B-78B | MLLM | Yes | Yes | Yes | Yes | +| CogVLM2-chat | 19B | MLLM | Yes | No | - | - | +| GLM4V | 9B | MLLM | Yes | No | - | - | diff --git a/docs/zh_cn/get_started/ascend/get_started.md b/docs/zh_cn/get_started/ascend/get_started.md index 9161d6646d..8dad380a52 100644 --- a/docs/zh_cn/get_started/ascend/get_started.md +++ b/docs/zh_cn/get_started/ascend/get_started.md @@ -154,6 +154,16 @@ lmdeploy lite auto_awq $HF_MODEL --work-dir $WORK_DIR --device npu 支持的模型列表请参考[支持的模型](../../supported_models/supported_models.md)。 +### w8a8 SMOOTH_QUANT + +运行下面的代码可以在Atlas 800T A2上对权重进行W8A8量化。 + +```bash +lmdeploy lite smooth_quant $HF_MODEL --work-dir $WORK_DIR --device npu +``` + +支持的模型列表请参考[支持的模型](../../supported_models/supported_models.md)。 + ### int8 KV-cache 量化 昇腾后端现在支持了在eager模式下的离线int8 KV-cache量化。 diff --git a/docs/zh_cn/supported_models/supported_models.md b/docs/zh_cn/supported_models/supported_models.md index c5cae84e47..07ca1c93bd 100644 --- a/docs/zh_cn/supported_models/supported_models.md +++ b/docs/zh_cn/supported_models/supported_models.md @@ -115,19 +115,23 @@ ## PyTorchEngine 华为昇腾平台 -| Model | Size | Type | FP16/BF16(eager) | FP16/BF16(graph) | W4A16(eager) | -| :------------: | :------: | :--: | :--------------: | :--------------: | :----------: | -| Llama2 | 7B - 70B | LLM | Yes | Yes | Yes 
| -| Llama3 | 8B | LLM | Yes | Yes | Yes | -| Llama3.1 | 8B | LLM | Yes | Yes | Yes | -| InternLM2 | 7B - 20B | LLM | Yes | Yes | Yes | -| InternLM2.5 | 7B - 20B | LLM | Yes | Yes | Yes | -| Mixtral | 8x7B | LLM | Yes | Yes | No | -| QWen1.5-MoE | A2.7B | LLM | Yes | - | No | -| QWen2(.5) | 7B | LLM | Yes | Yes | No | -| QWen2-MoE | A14.57B | LLM | Yes | - | No | -| DeepSeek-V2 | 16B | LLM | No | Yes | No | -| InternVL(v1.5) | 2B-26B | MLLM | Yes | - | Yes | -| InternVL2 | 1B-40B | MLLM | Yes | Yes | Yes | -| CogVLM2-chat | 19B | MLLM | Yes | No | - | -| GLM4V | 9B | MLLM | Yes | No | - | +| Model | Size | Type | FP16/BF16(eager) | FP16/BF16(graph) | W8A8(graph) | W4A16(eager) | +| :------------: | :------: | :--: | :--------------: | :--------------: | :---------: | :----------: | +| Llama2 | 7B - 70B | LLM | Yes | Yes | Yes | Yes | +| Llama3 | 8B | LLM | Yes | Yes | Yes | Yes | +| Llama3.1 | 8B | LLM | Yes | Yes | Yes | Yes | +| InternLM2 | 7B - 20B | LLM | Yes | Yes | Yes | Yes | +| InternLM2.5 | 7B - 20B | LLM | Yes | Yes | Yes | Yes | +| InternLM3 | 8B | LLM | Yes | Yes | Yes | Yes | +| Mixtral | 8x7B | LLM | Yes | Yes | No | No | +| QWen1.5-MoE | A2.7B | LLM | Yes | - | No | No | +| QWen2(.5) | 7B | LLM | Yes | Yes | Yes | Yes | +| QWen2-VL | 2B, 7B | MLLM | Yes | Yes | - | - | +| QWen2.5-VL | 3B - 72B | MLLM | Yes | Yes | - | - | +| QWen2-MoE | A14.57B | LLM | Yes | - | No | No | +| DeepSeek-V2 | 16B | LLM | No | Yes | No | No | +| InternVL(v1.5) | 2B-26B | MLLM | Yes | - | Yes | Yes | +| InternVL2 | 1B-40B | MLLM | Yes | Yes | Yes | Yes | +| InternVL2.5 | 1B-78B | MLLM | Yes | Yes | Yes | Yes | +| CogVLM2-chat | 19B | MLLM | Yes | No | - | - | +| GLM4V | 9B | MLLM | Yes | No | - | - | From d131cc77f69c5649aa0ee3070cc31bfeb4852be9 Mon Sep 17 00:00:00 2001 From: q yao Date: Mon, 14 Apr 2025 17:38:01 +0800 Subject: [PATCH 14/16] find port (#3429) * find port * safe --- lmdeploy/pytorch/engine/executor/dist_utils.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/lmdeploy/pytorch/engine/executor/dist_utils.py b/lmdeploy/pytorch/engine/executor/dist_utils.py index bf57b86893..91f437c87d 100644 --- a/lmdeploy/pytorch/engine/executor/dist_utils.py +++ b/lmdeploy/pytorch/engine/executor/dist_utils.py @@ -6,14 +6,17 @@ import torch.distributed as dist -def find_available_port() -> bool: +def find_available_port(start_port: int = 0) -> int: """find available port.""" - port = 29500 - while True: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - if s.connect_ex(('localhost', port)) != 0: - return port - port += 1 + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + try: + s.bind(('127.0.0.1', start_port)) + port = s.getsockname()[1] + return port + except socket.error as e: + if start_port == 0: + raise RuntimeError('Failed to find available port.') from e + return find_available_port(0) def setup_master_addr(addr: str, port: str): From 231a3236751d879844a36dfcfc28fbf312fff2ee Mon Sep 17 00:00:00 2001 From: Lyu Han Date: Mon, 14 Apr 2025 17:58:43 +0800 Subject: [PATCH 15/16] bump version to v0.7.3 (#3416) * bump version to v0.7.3 * add rayexecutor release timeout (#3403) * add release timeout * change log level --------- Co-authored-by: q yao --- docs/en/get_started/installation.md | 2 +- docs/zh_cn/get_started/installation.md | 2 +- .../pytorch/engine/executor/ray_executor.py | 26 +++++++++++++++---- lmdeploy/version.py | 2 +- 4 files changed, 24 insertions(+), 8 deletions(-) diff --git a/docs/en/get_started/installation.md 
b/docs/en/get_started/installation.md index 2c5a9b007a..05a624428e 100644 --- a/docs/en/get_started/installation.md +++ b/docs/en/get_started/installation.md @@ -23,7 +23,7 @@ pip install lmdeploy The default prebuilt package is compiled on **CUDA 12**. If CUDA 11+ (>=11.3) is required, you can install lmdeploy by: ```shell -export LMDEPLOY_VERSION=0.7.2.post1 +export LMDEPLOY_VERSION=0.7.3 export PYTHON_VERSION=38 pip install https://github.com/InternLM/lmdeploy/releases/download/v${LMDEPLOY_VERSION}/lmdeploy-${LMDEPLOY_VERSION}+cu118-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux2014_x86_64.whl --extra-index-url https://download.pytorch.org/whl/cu118 ``` diff --git a/docs/zh_cn/get_started/installation.md b/docs/zh_cn/get_started/installation.md index c60297a3c2..1bfc924912 100644 --- a/docs/zh_cn/get_started/installation.md +++ b/docs/zh_cn/get_started/installation.md @@ -23,7 +23,7 @@ pip install lmdeploy 默认的预构建包是在 **CUDA 12** 上编译的。如果需要 CUDA 11+ (>=11.3),你可以使用以下命令安装 lmdeploy: ```shell -export LMDEPLOY_VERSION=0.7.2.post1 +export LMDEPLOY_VERSION=0.7.3 export PYTHON_VERSION=38 pip install https://github.com/InternLM/lmdeploy/releases/download/v${LMDEPLOY_VERSION}/lmdeploy-${LMDEPLOY_VERSION}+cu118-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux2014_x86_64.whl --extra-index-url https://download.pytorch.org/whl/cu118 ``` diff --git a/lmdeploy/pytorch/engine/executor/ray_executor.py b/lmdeploy/pytorch/engine/executor/ray_executor.py index 782911f767..bb62544691 100644 --- a/lmdeploy/pytorch/engine/executor/ray_executor.py +++ b/lmdeploy/pytorch/engine/executor/ray_executor.py @@ -230,13 +230,17 @@ def __init__(self, logger.info('Warming up distribute environment, this might take long time, please waiting...') ray.get([worker.warmup_dist.remote() for worker in self.workers]) - def collective_rpc(self, method: str, args: Tuple[Any] = None, kwargs: Dict[str, Any] = None): + def collective_rpc(self, + method: str, + args: Tuple[Any] = None, + kwargs: Dict[str, Any] = None, + timeout: float = None): """collective rpc.""" if args is None: args = list() if kwargs is None: kwargs = dict() - return ray.get([getattr(worker, method).remote(*args, **kwargs) for worker in self.workers]) + return ray.get([getattr(worker, method).remote(*args, **kwargs) for worker in self.workers], timeout=timeout) def build_model(self): """build model.""" @@ -293,9 +297,21 @@ def stop(self): def release(self): """release.""" - self.collective_rpc('release') - for worker in self.workers: - ray.kill(worker) + if self.dp == 1: + try: + self.collective_rpc('release', timeout=5.0) + logger.debug('RayExecutor workers released.') + except ray.exceptions.GetTimeoutError: + logger.info('Ray release timeout.') + + try: + self.collective_rpc('exit') + logger.debug('RayExecutor workers exited.') + except ray.exceptions.RayActorError as e: + logger.debug(f'ray actor exit: {e}') + else: + [ray.kill(worker) for worker in self.workers] + ray.util.remove_placement_group(self.placement_group) def _compile_dag(self): diff --git a/lmdeploy/version.py b/lmdeploy/version.py index 106176c7a7..b001595fdd 100644 --- a/lmdeploy/version.py +++ b/lmdeploy/version.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. 
from typing import Tuple -__version__ = '0.7.2.post1' +__version__ = '0.7.3' short_version = __version__ From 8734bbf28cf187605f36f049557684947d504bdd Mon Sep 17 00:00:00 2001 From: Lantian Zhang <50076473+DoorKickers@users.noreply.github.com> Date: Wed, 16 Apr 2025 15:30:59 +0800 Subject: [PATCH 16/16] modify ascend dockerfile to support direct run lmdeploy serve (#3436) --- docker/Dockerfile_aarch64_ascend | 47 +++++++++++++++++++++++++++++++- 1 file changed, 46 insertions(+), 1 deletion(-) diff --git a/docker/Dockerfile_aarch64_ascend b/docker/Dockerfile_aarch64_ascend index 5ed842061c..c502826d1f 100644 --- a/docker/Dockerfile_aarch64_ascend +++ b/docker/Dockerfile_aarch64_ascend @@ -110,7 +110,8 @@ RUN echo "source /usr/local/Ascend/ascend-toolkit/set_env.sh" >> ~/.bashrc && \ RUN --mount=type=cache,target=/root/.cache/pip \ pip3 install torch==2.3.1 torchvision==0.18.1 torch-npu==2.3.1 && \ pip3 install transformers timm && \ - pip3 install dlinfer-ascend + pip3 install dlinfer-ascend && \ + pip3 install partial_json_parser shortuuid # lmdeploy FROM build_temp as copy_temp @@ -122,3 +123,47 @@ WORKDIR /opt/lmdeploy RUN --mount=type=cache,target=/root/.cache/pip \ LMDEPLOY_TARGET_DEVICE=ascend pip3 install -v --no-build-isolation -e . + +ENV ASCEND_TOOLKIT_HOME=/usr/local/Ascend/ascend-toolkit/latest +ENV LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:${ASCEND_TOOLKIT_HOME}/lib64/plugin/opskernel:${ASCEND_TOOLKIT_HOME}/lib64/plugin/nnengine:${ASCEND_TOOLKIT_HOME}/opp/built-in/op_impl/ai_core/tbe/op_tiling/lib/linux/$(arch):$LD_LIBRARY_PATH +ENV LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/tools/aml/lib64:${ASCEND_TOOLKIT_HOME}/tools/aml/lib64/plugin:$LD_LIBRARY_PATH +ENV PYTHONPATH=${ASCEND_TOOLKIT_HOME}/python/site-packages:${ASCEND_TOOLKIT_HOME}/opp/built-in/op_impl/ai_core/tbe:$PYTHONPATH +ENV PATH=${ASCEND_TOOLKIT_HOME}/bin:${ASCEND_TOOLKIT_HOME}/compiler/ccec_compiler/bin:${ASCEND_TOOLKIT_HOME}/tools/ccec_compiler/bin:$PATH +ENV ASCEND_AICPU_PATH=${ASCEND_TOOLKIT_HOME} +ENV ASCEND_OPP_PATH=${ASCEND_TOOLKIT_HOME}/opp +ENV TOOLCHAIN_HOME=${ASCEND_TOOLKIT_HOME}/toolkit +ENV ASCEND_HOME_PATH=${ASCEND_TOOLKIT_HOME} + +ENV ATB_HOME_PATH=/usr/local/Ascend/nnal/atb/latest/atb/cxx_abi_0 +ENV LD_LIBRARY_PATH=${ATB_HOME_PATH}/lib:${ATB_HOME_PATH}/examples:${ATB_HOME_PATH}/tests/atbopstest:$LD_LIBRARY_PATH +ENV PATH=${ATB_HOME_PATH}/bin:$PATH + +ENV ATB_STREAM_SYNC_EVERY_KERNEL_ENABLE=0 +ENV ATB_STREAM_SYNC_EVERY_RUNNER_ENABLE=0 +ENV ATB_STREAM_SYNC_EVERY_OPERATION_ENABLE=0 +ENV ATB_OPSRUNNER_SETUP_CACHE_ENABLE=1 +ENV ATB_OPSRUNNER_KERNEL_CACHE_TYPE=3 +ENV ATB_OPSRUNNER_KERNEL_CACHE_LOCAL_COUNT=1 +ENV ATB_OPSRUNNER_KERNEL_CACHE_GLOABL_COUNT=5 +ENV ATB_OPSRUNNER_KERNEL_CACHE_TILING_SIZE=10240 +ENV ATB_WORKSPACE_MEM_ALLOC_ALG_TYPE=1 +ENV ATB_WORKSPACE_MEM_ALLOC_GLOBAL=0 +ENV ATB_COMPARE_TILING_EVERY_KERNEL=0 +ENV ATB_HOST_TILING_BUFFER_BLOCK_NUM=128 +ENV ATB_DEVICE_TILING_BUFFER_BLOCK_NUM=32 +ENV ATB_SHARE_MEMORY_NAME_SUFFIX="" +ENV ATB_LAUNCH_KERNEL_WITH_TILING=1 +ENV ATB_MATMUL_SHUFFLE_K_ENABLE=1 +ENV ATB_RUNNER_POOL_SIZE=64 + +ENV ASDOPS_HOME_PATH=${ATB_HOME_PATH} +ENV ASDOPS_MATMUL_PP_FLAG=1 +ENV ASDOPS_LOG_LEVEL=ERROR +ENV ASDOPS_LOG_TO_STDOUT=0 +ENV ASDOPS_LOG_TO_FILE=1 +ENV ASDOPS_LOG_TO_FILE_FLUSH=0 +ENV ASDOPS_LOG_TO_BOOST_TYPE=atb +ENV ASDOPS_LOG_PATH=~ +ENV ASDOPS_TILING_PARSE_CACHE_DISABLE=0 + +ENV LCCL_DETERMINISTIC=0
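
Below is a minimal, self-contained sketch of how the `enable_micro_batch` / `split_batch` helpers introduced in this series (final form in PATCH 12, living in `lmdeploy/pytorch/models/utils/micro_batch.py`) are meant to compose; it assumes a build with these patches applied is importable. The `ToyEncoder` module, its dimensions, and the input shapes are hypothetical stand-ins for `InternVisionEncoder`; only the decorator signatures and the wrapping pattern are taken from the patches.

```python
# Sketch only: ToyEncoder and all shapes are hypothetical; the two helpers and
# their signatures come from lmdeploy/pytorch/models/utils/micro_batch.py.
import torch
from torch import nn

from lmdeploy.pytorch.models.utils.micro_batch import enable_micro_batch, split_batch


class ToyEncoder(nn.Module):
    """Stand-in for InternVisionEncoder: each stage accepts either a tensor or a
    list of micro-batch tensors, mirroring the _attn/_mlp split in internvl.py."""

    def __init__(self, dim: int = 8):
        super().__init__()
        self.proj1 = nn.Linear(dim, dim)
        self.proj2 = nn.Linear(dim, dim)

    @enable_micro_batch(param_name='hidden_states', index=0)
    def _stage1(self, hidden_states):
        # Runs once per micro-batch when hidden_states arrives as a list.
        return hidden_states + self.proj1(hidden_states)

    @enable_micro_batch(param_name='hidden_states', index=0)
    def _stage2(self, hidden_states):
        return hidden_states + self.proj2(hidden_states)

    def forward(self, inputs_embeds):
        hidden_states = self._stage1(inputs_embeds)
        hidden_states = self._stage2(hidden_states)
        return hidden_states


encoder = ToyEncoder()
# split_batch chunks `inputs_embeds` along dim 0 into num_splits micro-batches,
# passes the list through forward, and concatenates the per-chunk outputs,
# the same way the graph runner / compile_model wraps vision_model.encoder.forward.
encoder.forward = split_batch(encoder.forward, 'inputs_embeds', num_splits=2)

x = torch.randn(4, 16, 8)  # (batch, seq, dim); hypothetical sizes
out = encoder.forward(inputs_embeds=x)
assert out.shape == x.shape
```

Splitting the vision batch into micro-batches this way hands the compiler independent chunks of work, which is presumably what `torch._inductor.config.reorder_for_compute_comm_overlap` (enabled earlier in the series for multi-GPU runs on torch >= 2.6) exploits to overlap the vision tower's tensor-parallel communication with computation.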