
Problem encountered running base_demo #550

@NicholasEinstein

Description


System Info

Traceback (most recent call last):
  File "/home/ubuntu/haize/CogVLM/basic_demo/cli_demo_hf.py", line 137, in
    outputs = model.generate(**inputs, **gen_kwargs)
  File "/home/ubuntu/anaconda3/envs/cogagent/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/cogagent/lib/python3.10/site-packages/transformers/generation/utils.py", line 1527, in generate
    result = self._greedy_search(
  File "/home/ubuntu/anaconda3/envs/cogagent/lib/python3.10/site-packages/transformers/generation/utils.py", line 2411, in _greedy_search
    outputs = self(
  File "/home/ubuntu/anaconda3/envs/cogagent/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/cogagent/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/ubuntu/.cache/huggingface/modules/transformers_modules/cogvlm-chat-hf/modeling_cogvlm.py", line 660, in forward
    outputs = self.model(
  File "/home/ubuntu/anaconda3/envs/cogagent/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/.cache/huggingface/modules/transformers_modules/cogvlm-chat-hf/modeling_cogvlm.py", line 428, in forward
    images_features = self.encode_images(images)
  File "/home/ubuntu/.cache/huggingface/modules/transformers_modules/cogvlm-chat-hf/modeling_cogvlm.py", line 400, in encode_images
    images_features = self.vision(images)
  File "/home/ubuntu/anaconda3/envs/cogagent/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/.cache/huggingface/modules/transformers_modules/cogvlm-chat-hf/visual.py", line 134, in forward
    x = torch.cat((boi, x, eoi), dim=1)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:5 and cuda:7! (when checking argument for argument tensors in method wrapper_CUDA_cat)

I found that this error occurs when running the base_demo of the 17B CogVLM on multiple 3090 GPUs. After investigation, the cause is that boi and eoi in visual.py do not follow x as it moves between devices. When x ends up on a GPU other than the initial one, boi and eoi remain on the initial GPU, so torch.cat raises this error.
Fix: in visual.py, locate these lines
boi = self.boi.expand(x.shape[0], -1, -1)
eoi = self.eoi.expand(x.shape[0], -1, -1)
and change them to:
boi = self.boi.expand(x.shape[0], -1, -1).to(x.device)
eoi = self.eoi.expand(x.shape[0], -1, -1).to(x.device)
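The fix can be checked with a minimal CPU-only sketch of the same cat pattern (the tensor shapes here are illustrative, not CogVLM's actual dimensions): torch.cat requires all inputs on the same device, and calling .to(x.device) on the expanded markers guarantees they land wherever x currently lives.

```python
import torch

batch, seq, hidden = 2, 3, 4

# x stands in for the vision features; in the multi-GPU case it may have
# been moved to a different device than the module's parameters.
x = torch.randn(batch, seq, hidden)

# boi/eoi stand in for self.boi / self.eoi: learned (1, 1, hidden) markers.
boi_param = torch.zeros(1, 1, hidden)
eoi_param = torch.ones(1, 1, hidden)

# Expand to the batch size, then move to x's device so cat cannot fail
# with a cross-device error.
boi = boi_param.expand(x.shape[0], -1, -1).to(x.device)
eoi = eoi_param.expand(x.shape[0], -1, -1).to(x.device)

out = torch.cat((boi, x, eoi), dim=1)
print(out.shape)  # (batch, seq + 2, hidden)
```

Note that .to(x.device) is a no-op when the tensors already share a device, so the change is safe for the single-GPU case as well.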

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Reproduction

The error occurs when running the base_demo of the 17B CogVLM on multiple 3090 GPUs; it does not occur on a single GPU.

Expected behavior

Alternatively, do the authors have a better solution?
