diff --git a/docker/Dockerfile_aarch64_ascend b/docker/Dockerfile_aarch64_ascend
index 5ed842061c..c502826d1f 100644
--- a/docker/Dockerfile_aarch64_ascend
+++ b/docker/Dockerfile_aarch64_ascend
@@ -110,7 +110,8 @@ RUN echo "source /usr/local/Ascend/ascend-toolkit/set_env.sh" >> ~/.bashrc && \
 RUN --mount=type=cache,target=/root/.cache/pip \
     pip3 install torch==2.3.1 torchvision==0.18.1 torch-npu==2.3.1 && \
     pip3 install transformers timm && \
-    pip3 install dlinfer-ascend
+    pip3 install dlinfer-ascend && \
+    pip3 install partial_json_parser shortuuid

 # lmdeploy
 FROM build_temp as copy_temp
@@ -122,3 +123,47 @@ WORKDIR /opt/lmdeploy

 RUN --mount=type=cache,target=/root/.cache/pip \
     LMDEPLOY_TARGET_DEVICE=ascend pip3 install -v --no-build-isolation -e .
+
+ENV ASCEND_TOOLKIT_HOME=/usr/local/Ascend/ascend-toolkit/latest
+ENV LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:${ASCEND_TOOLKIT_HOME}/lib64/plugin/opskernel:${ASCEND_TOOLKIT_HOME}/lib64/plugin/nnengine:${ASCEND_TOOLKIT_HOME}/opp/built-in/op_impl/ai_core/tbe/op_tiling/lib/linux/$(arch):$LD_LIBRARY_PATH
+ENV LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/tools/aml/lib64:${ASCEND_TOOLKIT_HOME}/tools/aml/lib64/plugin:$LD_LIBRARY_PATH
+ENV PYTHONPATH=${ASCEND_TOOLKIT_HOME}/python/site-packages:${ASCEND_TOOLKIT_HOME}/opp/built-in/op_impl/ai_core/tbe:$PYTHONPATH
+ENV PATH=${ASCEND_TOOLKIT_HOME}/bin:${ASCEND_TOOLKIT_HOME}/compiler/ccec_compiler/bin:${ASCEND_TOOLKIT_HOME}/tools/ccec_compiler/bin:$PATH
+ENV ASCEND_AICPU_PATH=${ASCEND_TOOLKIT_HOME}
+ENV ASCEND_OPP_PATH=${ASCEND_TOOLKIT_HOME}/opp
+ENV TOOLCHAIN_HOME=${ASCEND_TOOLKIT_HOME}/toolkit
+ENV ASCEND_HOME_PATH=${ASCEND_TOOLKIT_HOME}
+
+ENV ATB_HOME_PATH=/usr/local/Ascend/nnal/atb/latest/atb/cxx_abi_0
+ENV LD_LIBRARY_PATH=${ATB_HOME_PATH}/lib:${ATB_HOME_PATH}/examples:${ATB_HOME_PATH}/tests/atbopstest:$LD_LIBRARY_PATH
+ENV PATH=${ATB_HOME_PATH}/bin:$PATH
+
+ENV ATB_STREAM_SYNC_EVERY_KERNEL_ENABLE=0
+ENV ATB_STREAM_SYNC_EVERY_RUNNER_ENABLE=0
+ENV ATB_STREAM_SYNC_EVERY_OPERATION_ENABLE=0
+ENV ATB_OPSRUNNER_SETUP_CACHE_ENABLE=1
+ENV ATB_OPSRUNNER_KERNEL_CACHE_TYPE=3
+ENV ATB_OPSRUNNER_KERNEL_CACHE_LOCAL_COUNT=1
+ENV ATB_OPSRUNNER_KERNEL_CACHE_GLOABL_COUNT=5
+ENV ATB_OPSRUNNER_KERNEL_CACHE_TILING_SIZE=10240
+ENV ATB_WORKSPACE_MEM_ALLOC_ALG_TYPE=1
+ENV ATB_WORKSPACE_MEM_ALLOC_GLOBAL=0
+ENV ATB_COMPARE_TILING_EVERY_KERNEL=0
+ENV ATB_HOST_TILING_BUFFER_BLOCK_NUM=128
+ENV ATB_DEVICE_TILING_BUFFER_BLOCK_NUM=32
+ENV ATB_SHARE_MEMORY_NAME_SUFFIX=""
+ENV ATB_LAUNCH_KERNEL_WITH_TILING=1
+ENV ATB_MATMUL_SHUFFLE_K_ENABLE=1
+ENV ATB_RUNNER_POOL_SIZE=64
+
+ENV ASDOPS_HOME_PATH=${ATB_HOME_PATH}
+ENV ASDOPS_MATMUL_PP_FLAG=1
+ENV ASDOPS_LOG_LEVEL=ERROR
+ENV ASDOPS_LOG_TO_STDOUT=0
+ENV ASDOPS_LOG_TO_FILE=1
+ENV ASDOPS_LOG_TO_FILE_FLUSH=0
+ENV ASDOPS_LOG_TO_BOOST_TYPE=atb
+ENV ASDOPS_LOG_PATH=~
+ENV ASDOPS_TILING_PARSE_CACHE_DISABLE=0
+
+ENV LCCL_DETERMINISTIC=0
diff --git a/docs/en/get_started/ascend/get_started.md b/docs/en/get_started/ascend/get_started.md
index be39f179b3..004d0adb20 100644
--- a/docs/en/get_started/ascend/get_started.md
+++ b/docs/en/get_started/ascend/get_started.md
@@ -158,6 +158,16 @@ lmdeploy lite auto_awq $HF_MODEL --work-dir $WORK_DIR --device npu

 Please check [supported_models](../../supported_models/supported_models.md) before use this feature.

+### w8a8 SMOOTH_QUANT
+
+Run the following commands to quantize weights on Atlas 800T A2.
+
+```bash
+lmdeploy lite smooth_quant $HF_MODEL --work-dir $WORK_DIR --device npu
+```
+
+Please check [supported_models](../../supported_models/supported_models.md) before using this feature.
+
 ### int8 KV-cache Quantization

 Ascend backend has supported offline int8 KV-cache Quantization on eager mode.
diff --git a/docs/en/get_started/installation.md b/docs/en/get_started/installation.md
index 2c5a9b007a..05a624428e 100644
--- a/docs/en/get_started/installation.md
+++ b/docs/en/get_started/installation.md
@@ -23,7 +23,7 @@ pip install lmdeploy
 The default prebuilt package is compiled on **CUDA 12**. If CUDA 11+ (>=11.3) is required, you can install lmdeploy by:

 ```shell
-export LMDEPLOY_VERSION=0.7.2.post1
+export LMDEPLOY_VERSION=0.7.3
 export PYTHON_VERSION=38
 pip install https://github.com/InternLM/lmdeploy/releases/download/v${LMDEPLOY_VERSION}/lmdeploy-${LMDEPLOY_VERSION}+cu118-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux2014_x86_64.whl --extra-index-url https://download.pytorch.org/whl/cu118
 ```
diff --git a/docs/en/supported_models/supported_models.md b/docs/en/supported_models/supported_models.md
index 176c10df46..2ac7a3ee6e 100644
--- a/docs/en/supported_models/supported_models.md
+++ b/docs/en/supported_models/supported_models.md
@@ -115,20 +115,23 @@ The following tables detail the models supported by LMDeploy's TurboMind engine

 ## PyTorchEngine on Huawei Ascend Platform

-| Model          | Size     | Type | FP16/BF16(eager) | FP16/BF16(graph) | W4A16(eager) |
-| :------------: | :------: | :--: | :--------------: | :--------------: | :----------: |
-| Llama2         | 7B - 70B | LLM  | Yes              | Yes              | Yes          |
-| Llama3         | 8B       | LLM  | Yes              | Yes              | Yes          |
-| Llama3.1       | 8B       | LLM  | Yes              | Yes              | Yes          |
-| InternLM2      | 7B - 20B | LLM  | Yes              | Yes              | Yes          |
-| InternLM2.5    | 7B - 20B | LLM  | Yes              | Yes              | Yes          |
-| InternLM3      | 8B       | LLM  | Yes              | Yes              | Yes          |
-| Mixtral        | 8x7B     | LLM  | Yes              | Yes              | No           |
-| QWen1.5-MoE    | A2.7B    | LLM  | Yes              | -                | No           |
-| QWen2(.5)      | 7B       | LLM  | Yes              | Yes              | No           |
-| QWen2-MoE      | A14.57B  | LLM  | Yes              | -                | No           |
-| DeepSeek-V2    | 16B      | LLM  | No               | Yes              | No           |
-| InternVL(v1.5) | 2B-26B   | MLLM | Yes              | -                | Yes          |
-| InternVL2      | 1B-40B   | MLLM | Yes              | Yes              | Yes          |
-| CogVLM2-chat   | 19B      | MLLM | Yes              | No               | -            |
-| GLM4V          | 9B       | MLLM | Yes              | No               | -            |
+| Model          | Size     | Type | FP16/BF16(eager) | FP16/BF16(graph) | W8A8(graph) | W4A16(eager) |
+| :------------: | :------: | :--: | :--------------: | :--------------: | :---------: | :----------: |
+| Llama2         | 7B - 70B | LLM  | Yes              | Yes              | Yes         | Yes          |
+| Llama3         | 8B       | LLM  | Yes              | Yes              | Yes         | Yes          |
+| Llama3.1       | 8B       | LLM  | Yes              | Yes              | Yes         | Yes          |
+| InternLM2      | 7B - 20B | LLM  | Yes              | Yes              | Yes         | Yes          |
+| InternLM2.5    | 7B - 20B | LLM  | Yes              | Yes              | Yes         | Yes          |
+| InternLM3      | 8B       | LLM  | Yes              | Yes              | Yes         | Yes          |
+| Mixtral        | 8x7B     | LLM  | Yes              | Yes              | No          | No           |
+| QWen1.5-MoE    | A2.7B    | LLM  | Yes              | -                | No          | No           |
+| QWen2(.5)      | 7B       | LLM  | Yes              | Yes              | Yes         | Yes          |
+| QWen2-VL       | 2B, 7B   | MLLM | Yes              | Yes              | -           | -            |
+| QWen2.5-VL     | 3B - 72B | MLLM | Yes              | Yes              | -           | -            |
+| QWen2-MoE      | A14.57B  | LLM  | Yes              | -                | No          | No           |
+| DeepSeek-V2    | 16B      | LLM  | No               | Yes              | No          | No           |
+| InternVL(v1.5) | 2B-26B   | MLLM | Yes              | -                | Yes         | Yes          |
+| InternVL2      | 1B-40B   | MLLM | Yes              | Yes              | Yes         | Yes          |
+| InternVL2.5    | 1B-78B   | MLLM | Yes              | Yes              | Yes         | Yes          |
+| CogVLM2-chat   | 19B      | MLLM | Yes              | No               | -           | -            |
+| GLM4V          | 9B       | MLLM | Yes              | No               | -           | -            |
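Editor's note: for readers trying the new w8a8 path end to end, the quantized work directory produced by `lmdeploy lite smooth_quant` can then be served with the PyTorch engine on Ascend. The following is a minimal sketch based on the documented `pipeline` API and `PytorchEngineConfig(device_type='ascend')`; the model path is a placeholder, not something introduced by this PR.

```python
from lmdeploy import PytorchEngineConfig, pipeline

# Path is illustrative: point it at the $WORK_DIR written by the
# `lmdeploy lite smooth_quant ... --device npu` step documented above.
backend_config = PytorchEngineConfig(device_type='ascend')
pipe = pipeline('./internlm2_5-7b-chat-w8a8', backend_config=backend_config)
print(pipe(['Hi, please introduce yourself.']))
```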
diff --git a/docs/zh_cn/get_started/ascend/get_started.md b/docs/zh_cn/get_started/ascend/get_started.md
index 9161d6646d..8dad380a52 100644
--- a/docs/zh_cn/get_started/ascend/get_started.md
+++ b/docs/zh_cn/get_started/ascend/get_started.md
@@ -154,6 +154,16 @@ lmdeploy lite auto_awq $HF_MODEL --work-dir $WORK_DIR --device npu

 支持的模型列表请参考[支持的模型](../../supported_models/supported_models.md)。

+### w8a8 SMOOTH_QUANT
+
+运行下面的代码可以在Atlas 800T A2上对权重进行W8A8量化。
+
+```bash
+lmdeploy lite smooth_quant $HF_MODEL --work-dir $WORK_DIR --device npu
+```
+
+支持的模型列表请参考[支持的模型](../../supported_models/supported_models.md)。
+
 ### int8 KV-cache 量化

 昇腾后端现在支持了在eager模式下的离线int8 KV-cache量化。
diff --git a/docs/zh_cn/get_started/installation.md b/docs/zh_cn/get_started/installation.md
index c60297a3c2..1bfc924912 100644
--- a/docs/zh_cn/get_started/installation.md
+++ b/docs/zh_cn/get_started/installation.md
@@ -23,7 +23,7 @@ pip install lmdeploy
 默认的预构建包是在 **CUDA 12** 上编译的。如果需要 CUDA 11+ (>=11.3)，你可以使用以下命令安装 lmdeploy：

 ```shell
-export LMDEPLOY_VERSION=0.7.2.post1
+export LMDEPLOY_VERSION=0.7.3
 export PYTHON_VERSION=38
 pip install https://github.com/InternLM/lmdeploy/releases/download/v${LMDEPLOY_VERSION}/lmdeploy-${LMDEPLOY_VERSION}+cu118-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux2014_x86_64.whl --extra-index-url https://download.pytorch.org/whl/cu118
 ```
diff --git a/docs/zh_cn/supported_models/supported_models.md b/docs/zh_cn/supported_models/supported_models.md
index c5cae84e47..07ca1c93bd 100644
--- a/docs/zh_cn/supported_models/supported_models.md
+++ b/docs/zh_cn/supported_models/supported_models.md
@@ -115,19 +115,23 @@

 ## PyTorchEngine 华为昇腾平台

-| Model          | Size     | Type | FP16/BF16(eager) | FP16/BF16(graph) | W4A16(eager) |
-| :------------: | :------: | :--: | :--------------: | :--------------: | :----------: |
-| Llama2         | 7B - 70B | LLM  | Yes              | Yes              | Yes          |
-| Llama3         | 8B       | LLM  | Yes              | Yes              | Yes          |
-| Llama3.1       | 8B       | LLM  | Yes              | Yes              | Yes          |
-| InternLM2      | 7B - 20B | LLM  | Yes              | Yes              | Yes          |
-| InternLM2.5    | 7B - 20B | LLM  | Yes              | Yes              | Yes          |
-| Mixtral        | 8x7B     | LLM  | Yes              | Yes              | No           |
-| QWen1.5-MoE    | A2.7B    | LLM  | Yes              | -                | No           |
-| QWen2(.5)      | 7B       | LLM  | Yes              | Yes              | No           |
-| QWen2-MoE      | A14.57B  | LLM  | Yes              | -                | No           |
-| DeepSeek-V2    | 16B      | LLM  | No               | Yes              | No           |
-| InternVL(v1.5) | 2B-26B   | MLLM | Yes              | -                | Yes          |
-| InternVL2      | 1B-40B   | MLLM | Yes              | Yes              | Yes          |
-| CogVLM2-chat   | 19B      | MLLM | Yes              | No               | -            |
-| GLM4V          | 9B       | MLLM | Yes              | No               | -            |
+| Model          | Size     | Type | FP16/BF16(eager) | FP16/BF16(graph) | W8A8(graph) | W4A16(eager) |
+| :------------: | :------: | :--: | :--------------: | :--------------: | :---------: | :----------: |
+| Llama2         | 7B - 70B | LLM  | Yes              | Yes              | Yes         | Yes          |
+| Llama3         | 8B       | LLM  | Yes              | Yes              | Yes         | Yes          |
+| Llama3.1       | 8B       | LLM  | Yes              | Yes              | Yes         | Yes          |
+| InternLM2      | 7B - 20B | LLM  | Yes              | Yes              | Yes         | Yes          |
+| InternLM2.5    | 7B - 20B | LLM  | Yes              | Yes              | Yes         | Yes          |
+| InternLM3      | 8B       | LLM  | Yes              | Yes              | Yes         | Yes          |
+| Mixtral        | 8x7B     | LLM  | Yes              | Yes              | No          | No           |
+| QWen1.5-MoE    | A2.7B    | LLM  | Yes              | -                | No          | No           |
+| QWen2(.5)      | 7B       | LLM  | Yes              | Yes              | Yes         | Yes          |
+| QWen2-VL       | 2B, 7B   | MLLM | Yes              | Yes              | -           | -            |
+| QWen2.5-VL     | 3B - 72B | MLLM | Yes              | Yes              | -           | -            |
+| QWen2-MoE      | A14.57B  | LLM  | Yes              | -                | No          | No           |
+| DeepSeek-V2    | 16B      | LLM  | No               | Yes              | No          | No           |
+| InternVL(v1.5) | 2B-26B   | MLLM | Yes              | -                | Yes         | Yes          |
+| InternVL2      | 1B-40B   | MLLM | Yes              | Yes              | Yes         | Yes          |
+| InternVL2.5    | 1B-78B   | MLLM | Yes              | Yes              | Yes         | Yes          |
+| CogVLM2-chat   | 19B      | MLLM | Yes              | No               | -           | -            |
+| GLM4V          | 9B       | MLLM | Yes              | No               | -           | -            |
diff --git a/lmdeploy/pytorch/backends/cuda/graph_runner.py b/lmdeploy/pytorch/backends/cuda/graph_runner.py
index 7f26f59980..bc5b582d05 100644
--- a/lmdeploy/pytorch/backends/cuda/graph_runner.py
+++ b/lmdeploy/pytorch/backends/cuda/graph_runner.py
@@ -3,6 +3,7 @@

 import torch

+from lmdeploy.pytorch.backends.selector import get_backend
 from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig
 from lmdeploy.pytorch.model_inputs import StepContext
 from lmdeploy.pytorch.models.utils.cudagraph import CudaGraphMeta
@@ -116,6 +117,7 @@ def __init__(self, model: torch.nn.Module, model_config: ModelConfig, cache_conf

         self.graph_pool_handle = torch.cuda.graph_pool_handle()
         self._runner_map: Dict[Any, CUDASingleGraphRunner] = dict()
+        self.has_try_compile_model: bool = False

     def check_enable_graph(self):
         """check enable graph."""
@@ -124,6 +126,16 @@ def check_enable_graph(self):

         return getattr(self.model, 'support_cuda_graph', _false)

+    def _try_compile_model_once(self):
+        if self.has_try_compile_model:
+            return
+
+        if hasattr(self.model, 'compile_model'):
+            method = getattr(self.model, 'compile_model')
+            method()
+
+        self.has_try_compile_model = True
+
     def get_graph_key(self, input_ids: torch.Tensor, position_ids: torch.Tensor, past_key_values: List,
                       attn_metadata: Any, inputs_embeds: torch.Tensor, **kwargs):
        """get graph key."""
@@ -135,6 +147,9 @@ def get_graph_key(self, input_ids: torch.Tensor, position_ids: torch.Tensor, pas

    def __call__(self, **kwargs):
        """call."""
+        if not self.backend_config.eager_mode and get_backend().get_name() == 'cuda':
+            self._try_compile_model_once()
+
        enable_graph = self.enable_graph(**kwargs)

        if not enable_graph:
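Editor's note: the graph runner does not require a particular base class here; `_try_compile_model_once` simply duck-types a `compile_model` method on the wrapped model and invokes it once before the first non-eager CUDA call. Below is a minimal sketch of a model opting into that hook. `ToyModel` is hypothetical; the real opt-in in this PR is the InternVL change further down.

```python
import torch
from torch import nn


class ToyModel(nn.Module):
    """Hypothetical model exposing the duck-typed hook the graph runner probes for."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(32, 32)

    def compile_model(self):
        # Called at most once by _try_compile_model_once; compile whatever
        # submodule or method benefits from torch.compile.
        self.forward = torch.compile(self.forward, mode='max-autotune')

    def forward(self, x):
        return torch.relu(self.proj(x))
```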
diff --git a/lmdeploy/pytorch/engine/executor/dist_utils.py b/lmdeploy/pytorch/engine/executor/dist_utils.py
index bf57b86893..91f437c87d 100644
--- a/lmdeploy/pytorch/engine/executor/dist_utils.py
+++ b/lmdeploy/pytorch/engine/executor/dist_utils.py
@@ -6,14 +6,17 @@
 import torch.distributed as dist


-def find_available_port() -> bool:
+def find_available_port(start_port: int = 0) -> int:
     """find available port."""
-    port = 29500
-    while True:
-        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-            if s.connect_ex(('localhost', port)) != 0:
-                return port
-            port += 1
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        try:
+            s.bind(('127.0.0.1', start_port))
+            port = s.getsockname()[1]
+            return port
+        except socket.error as e:
+            if start_port == 0:
+                raise RuntimeError('Failed to find available port.') from e
+            return find_available_port(0)


 def setup_master_addr(addr: str, port: str):
diff --git a/lmdeploy/pytorch/engine/executor/ray_executor.py b/lmdeploy/pytorch/engine/executor/ray_executor.py
index 782911f767..bb62544691 100644
--- a/lmdeploy/pytorch/engine/executor/ray_executor.py
+++ b/lmdeploy/pytorch/engine/executor/ray_executor.py
@@ -230,13 +230,17 @@ def __init__(self,
             logger.info('Warming up distribute environment, this might take long time, please waiting...')
             ray.get([worker.warmup_dist.remote() for worker in self.workers])

-    def collective_rpc(self, method: str, args: Tuple[Any] = None, kwargs: Dict[str, Any] = None):
+    def collective_rpc(self,
+                       method: str,
+                       args: Tuple[Any] = None,
+                       kwargs: Dict[str, Any] = None,
+                       timeout: float = None):
         """collective rpc."""
         if args is None:
             args = list()
         if kwargs is None:
             kwargs = dict()
-        return ray.get([getattr(worker, method).remote(*args, **kwargs) for worker in self.workers])
+        return ray.get([getattr(worker, method).remote(*args, **kwargs) for worker in self.workers], timeout=timeout)

     def build_model(self):
         """build model."""
@@ -293,9 +297,21 @@ def stop(self):

     def release(self):
         """release."""
-        self.collective_rpc('release')
-        for worker in self.workers:
-            ray.kill(worker)
+        if self.dp == 1:
+            try:
+                self.collective_rpc('release', timeout=5.0)
+                logger.debug('RayExecutor workers released.')
+            except ray.exceptions.GetTimeoutError:
+                logger.info('Ray release timeout.')
+
+            try:
+                self.collective_rpc('exit')
+                logger.debug('RayExecutor workers exited.')
+            except ray.exceptions.RayActorError as e:
+                logger.debug(f'ray actor exit: {e}')
+        else:
+            [ray.kill(worker) for worker in self.workers]
+
         ray.util.remove_placement_group(self.placement_group)

     def _compile_dag(self):
diff --git a/lmdeploy/pytorch/models/internvl.py b/lmdeploy/pytorch/models/internvl.py
index 27bf3d5320..c6940a69d2 100644
--- a/lmdeploy/pytorch/models/internvl.py
+++ b/lmdeploy/pytorch/models/internvl.py
@@ -4,11 +4,14 @@

 import torch
 import torch.nn.functional as F
+from packaging import version
 from torch import nn
 from transformers.configuration_utils import PretrainedConfig

+from lmdeploy.pytorch.distributed import get_world_rank
 from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult
 from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
+from lmdeploy.pytorch.models.utils.micro_batch import enable_micro_batch, split_batch
 from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor
 from lmdeploy.pytorch.nn import LayerNorm, RMSNorm
 from lmdeploy.pytorch.nn.linear import build_colwise_linear, build_qkv_proj, build_rowwise_linear
@@ -205,15 +208,22 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device:
         self.ls1 = nn.Parameter(torch.empty(self.embed_dim, dtype=dtype, device=device))
         self.ls2 = nn.Parameter(torch.empty(self.embed_dim, dtype=dtype, device=device))

-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-    ):
-        """forward."""
-        hidden_states = hidden_states + self.attn(self.norm1(hidden_states).to(hidden_states.dtype)) * self.ls1
+    @enable_micro_batch(param_name='hidden_states', index=0)
+    def _attn(self, hidden_states):
+        hidden_states = hidden_states + self.attn(self.norm1(hidden_states).to(hidden_states[0].dtype)) * self.ls1
+        return hidden_states

+    @enable_micro_batch(param_name='hidden_states', index=0)
+    def _mlp(self, hidden_states):
         hidden_states = hidden_states + self.mlp(self.norm2(hidden_states).to(hidden_states.dtype)) * self.ls2
+        return hidden_states

+    def forward(
+        self,
+        hidden_states,
+    ):
+        hidden_states = self._attn(hidden_states)
+        hidden_states = self._mlp(hidden_states)
         return hidden_states


@@ -306,6 +316,33 @@ def __init__(self,

         self.input_processor = InternVLInputProcessor(self.config, dtype)

+        self.compile_vit = False
+
+    def compile_model(self):
+        torch_version = version.parse(torch.__version__)
+        if torch_version < version.parse('2.5.0'):
+            return
+
+        world_size, _ = get_world_rank()
+        if torch_version >= version.parse('2.6.0') and world_size > 1:
+            torch._inductor.config.reorder_for_compute_comm_overlap = True
+        if isinstance(self.vision_model, InternVisionModel):
+            self.vision_model.encoder.forward = split_batch(self.vision_model.encoder.forward,
+                                                            'inputs_embeds',
+                                                            index=0)
+
+        self.extract_feature = torch.compile(self.extract_feature, mode='max-autotune')
+        self.compile_vit = True
+        self.has_compiled_vit = False
+
+    def _mark_dynamic_once(self, pixel_values, dims):
+        """call torch._dynamo.mark_dynamic to avoid recompile."""
+        if not self.compile_vit or self.has_compiled_vit or pixel_values is None:
+            return
+
+        torch._dynamo.mark_dynamic(pixel_values, dims)
+        self.has_compiled_vit = True
+
     def pixel_shuffle(self, x, scale_factor=0.5):
         n, w, h, c = x.size()
         # N, W, H, C --> N, W, H * scale, C // scale
@@ -350,6 +387,7 @@ def forward(
    ):
        if inputs_embeds is None and pixel_values is not None:
            # extract feature
+            self._mark_dynamic_once(pixel_values, [0])
            vit_embeds = self.extract_feature(pixel_values)
            lang_embeds = self.language_model.get_input_embeddings()(input_ids)
            lang_embeds.masked_scatter_(image_mask[..., None], vit_embeds)
diff --git a/lmdeploy/pytorch/models/utils/micro_batch.py b/lmdeploy/pytorch/models/utils/micro_batch.py
new file mode 100644
index 0000000000..aba340b7c8
--- /dev/null
+++ b/lmdeploy/pytorch/models/utils/micro_batch.py
@@ -0,0 +1,61 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import functools
+
+import torch
+
+
+def enable_micro_batch(param_name, index=-1):
+    """Decorator factory to enable micro-batch computation."""
+
+    def decorator(func):
+
+        @functools.wraps(func)
+        def wrapper(self, *args, **kwargs):
+            if index != -1 and len(args) > index:
+                inputs = args[index]
+            else:
+                inputs = kwargs.get(param_name, None)
+
+            if isinstance(inputs, list):
+                # Apply forward computation to each micro-batch
+                results = []
+                for input in inputs:
+                    if index != -1 and len(args) > index:
+                        args = args[0:index] + (input, ) + args[index + 1:]
+                    else:
+                        kwargs[param_name] = input
+                    result = func(self, *args, **kwargs)
+                    results.append(result)
+                return results
+            else:
+                # If not a list, directly apply the forward computation
+                return func(self, *args, **kwargs)
+
+        return wrapper
+
+    return decorator
+
+
+def split_batch(func, param_name, index=-1, num_splits=2):
+    """Decorator to split along the 0th dimension into a specified number of
+    chunks."""
+
+    def wrapper(*args, **kwargs):
+        if index != -1 and len(args) > index:
+            inputs = args[index]
+        else:
+            inputs = kwargs.get(param_name, None)
+
+        if inputs is not None:
+            split_inputs = list(torch.chunk(inputs, num_splits, dim=0))
+            if index != -1 and len(args) > index:
+                args = args[0:index] + (split_inputs, ) + args[index + 1:]
+            else:
+                kwargs[param_name] = split_inputs
+
+            results = func(*args, **kwargs)
+            return torch.cat(results, dim=0)
+        else:
+            return func(*args, **kwargs)
+
+    return wrapper
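Editor's note: to make the intended usage of these helpers concrete, here is a small self-contained sketch mirroring how InternVL decorates `_attn`/`_mlp` with `enable_micro_batch` and wraps `encoder.forward` with `split_batch`. The toy module, its name, and the shapes are illustrative only.

```python
import torch
from torch import nn

from lmdeploy.pytorch.models.utils.micro_batch import enable_micro_batch, split_batch


class ToyBlock(nn.Module):
    """Toy stand-in for a ViT block; names and sizes are illustrative."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    @enable_micro_batch(param_name='hidden_states', index=0)
    def _ffn(self, hidden_states):
        # Runs once per micro-batch when hidden_states arrives as a list of chunks.
        return self.proj(hidden_states)

    def forward(self, hidden_states):
        return self._ffn(hidden_states)


block = ToyBlock()
# Chunk the batch dim in two before forward, concatenate the results afterwards.
block.forward = split_batch(block.forward, 'hidden_states', index=0)

x = torch.randn(4, 16, 8)
out = block.forward(x)  # two (2, 16, 8) chunks processed independently, then re-joined
assert out.shape == (4, 16, 8)
```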
diff --git a/lmdeploy/version.py b/lmdeploy/version.py
index 106176c7a7..b001595fdd 100644
--- a/lmdeploy/version.py
+++ b/lmdeploy/version.py
@@ -1,7 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from typing import Tuple

-__version__ = '0.7.2.post1'
+__version__ = '0.7.3'
 short_version = __version__


diff --git a/requirements/runtime_cuda.txt b/requirements/runtime_cuda.txt
index 88c6270877..9169c6ff30 100644
--- a/requirements/runtime_cuda.txt
+++ b/requirements/runtime_cuda.txt
@@ -16,8 +16,8 @@ safetensors
 sentencepiece
 shortuuid
 tiktoken
-torch<=2.5.1,>=2.0.0
-torchvision<=0.20.1,>=0.15.0
+torch<=2.6.0,>=2.0.0
+torchvision<=0.21.0,>=0.15.0
 transformers
-triton<=3.1.0,>=3.0.0; sys_platform == "linux"
+triton<=3.2.0,>=3.0.0; sys_platform == "linux"
 uvicorn
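Editor's note: the relaxed `torch` pin above matters for the new compile path: `compile_model` only activates on torch >= 2.5.0 (and enables compute/comm overlap on >= 2.6.0), while `_mark_dynamic_once` keeps the compiled ViT from re-specializing on each new batch size. Below is a standalone sketch of that `torch._dynamo.mark_dynamic` pattern, with a toy function standing in for `extract_feature`; it is an illustration, not code from the PR.

```python
import torch


def toy_extract_feature(pixel_values):
    # Stand-in for the real ViT feature extractor.
    return pixel_values.mean(dim=(1, 2, 3))


compiled = torch.compile(toy_extract_feature, mode='max-autotune')

x = torch.randn(6, 3, 448, 448)
# Mark dim 0 (the batch of image tiles) as dynamic so the first compilation
# already produces a graph that tolerates other batch sizes, instead of
# recompiling when, say, 13 tiles arrive on the next request.
torch._dynamo.mark_dynamic(x, 0)
out = compiled(x)
```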