Eval bug: Program crashes during long input inference when batch size is set to 16384 #14325

Open
@zts9989

Description

Name and Version

build/bin/llama-cli --version
version: b5731
built with cc (Debian 12.2.0-14) 12.2.0 for x86_64-linux-gnu

Operating systems

Linux

GGML backends

CUDA

Hardware

5950X + 2xA6000

Models

DeepSeek-R1-0528-Q4_K_M

Problem description & steps to reproduce

It appears that under certain conditions, the element counts and byte sizes used by the CUDA copy operations need to be widened from int32 to int64.
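
As a rough cross-check (not llama.cpp code), the element count reported later in the log output, 2147483648, lines up with a single copy at -ub 16384 of 131072 elements per token; assuming two-byte F16 elements, that is 4 GiB, well past INT_MAX. The per-token count and the element size below are assumptions derived only from the log.

CUDA
// Arithmetic sketch only; the per-token element count and F16 element size
// are assumptions based on the abort message, not the actual tensor shape.
#include <climits>
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t n_ubatch           = 16384;                   // -ub 16384
    const int64_t elements_per_token = 2147483648LL / n_ubatch; // 131072, from the abort message
    const int64_t bytes_per_element  = 2;                       // assuming an F16 tensor

    const int64_t ne     = n_ubatch * elements_per_token;       // 2147483648 > INT_MAX
    const int64_t nbytes = ne * bytes_per_element;              // 4294967296 bytes, also > INT_MAX

    printf("ne = %lld, nbytes = %lld, INT_MAX = %d\n",
           (long long) ne, (long long) nbytes, INT_MAX);
    return 0;
}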

I asked DeepSeek to help me translate this bug report, and it also proposed a change directly: bypass the byte-size GGML_ASSERTs and add a check on the element count instead.

C
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection_for_this_node) {
    const int64_t ne = ggml_nelements(src0);
    GGML_ASSERT(ne == ggml_nelements(src1));

    //GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX);
    //GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX);
    if (ne > INT_MAX) {
        GGML_ABORT("ggml_cuda_cpy: number of elements (%ld) exceeds maximum supported value (%d)\n", ne, INT_MAX);
    }

    const int64_t ne00 = src0->ne[0];
    const int64_t ne01 = src0->ne[1];
    const int64_t ne02 = src0->ne[2];

This modification lets processing continue a bit further before aborting, but it does not fundamentally resolve the issue. DeepSeek stated that a proper fix would require representing both the element count and the tensor byte sizes as int64.
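
For illustration, here is a minimal sketch of what 64-bit indexing could look like for the simple contiguous case, using a grid-stride loop so the element count is not capped at INT_MAX. This is a hypothetical kernel with made-up names, not the actual cpy.cu code; the real copy path also has non-contiguous and type-converting variants whose strides (nb00, nb01, ...) would need the same 64-bit treatment.

CUDA
// Hypothetical sketch, not the llama.cpp kernel: a contiguous F16 -> F16 copy
// indexed with int64_t so more than INT_MAX elements can be handled.
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>

static __global__ void cpy_f16_f16_i64(const half * src, half * dst, const int64_t ne) {
    const int64_t stride = (int64_t) blockDim.x * gridDim.x;
    for (int64_t i = (int64_t) blockIdx.x * blockDim.x + threadIdx.x; i < ne; i += stride) {
        dst[i] = src[i];
    }
}

static void launch_cpy_i64(const half * src, half * dst, int64_t ne, cudaStream_t stream) {
    const int block_size = 256;
    // Cap the grid size; the grid-stride loop covers any remaining elements.
    const int64_t blocks_needed = (ne + block_size - 1) / block_size;
    const int grid_size = (int) (blocks_needed < 65535 ? blocks_needed : 65535);
    cpy_f16_f16_i64<<<grid_size, block_size, 0, stream>>>(src, dst, ne);
}

An alternative that avoids touching the kernels at all would be to split oversized copies into chunks smaller than INT_MAX on the host side; either way the hard abort above would no longer be reached.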

Reproduction Steps:

Configure context length: -c 163840
Process input sequence length: ~35500 tokens
Set batch size: -b 16384 -ub 16384

build/bin/llama-server -m /data/DeepSeek-R1-0528-Q4_K_M-00001-of-00009.gguf -fa --temp 0.6 --top-p 0.95 -s 3047 --no-warmup -ngl 160 -c 163840 --host 0.0.0.0 -ot exps=CPU -b 16384 -ub 16384

First Bad Commit

No response

Relevant log output

<|User|>Hello<|Assistant|>Hi there<|end▁of▁sentence|><|User|>How are you?<|Assistant|>'
main: server is listening on http://0.0.0.0:8080 - starting the main loop
srv  update_slots: all slots are idle
slot launch_slot_: id  0 | task 0 | processing task
slot update_slots: id  0 | task 0 | new prompt, n_ctx_slot = 163840, n_keep = 0, n_prompt_tokens = 35500
slot update_slots: id  0 | task 0 | kv cache rm [0, end)
slot update_slots: id  0 | task 0 | prompt processing progress, n_past = 16384, n_tokens = 16384, progress = 0.461521
slot update_slots: id  0 | task 0 | kv cache rm [16384, end)
slot update_slots: id  0 | task 0 | prompt processing progress, n_past = 32768, n_tokens = 16384, progress = 0.923042
/data/llama.cpp/llama.cpp-b5731/ggml/src/ggml-cuda/cpy.cu:561: GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX) failed
/data/llama.cpp/llama.cpp-b5731/build/bin/libggml-base.so(+0x13e18)[0x7fb54b4a4e18]
/data/llama.cpp/llama.cpp-b5731/build/bin/libggml-base.so(ggml_print_backtrace+0x1e4)[0x7fb54b4a51e4]
/data/llama.cpp/llama.cpp-b5731/build/bin/libggml-base.so(ggml_abort+0xd6)[0x7fb54b4a5316]
/data/llama.cpp/llama.cpp-b5731/build/bin/libggml-cuda.so(_Z13ggml_cuda_cpyR25ggml_backend_cuda_contextPK11ggml_tensorPS1_b+0xb69)[0x7fb549c70d99]
/data/llama.cpp/llama.cpp-b5731/build/bin/libggml-cuda.so(+0xad5d2)[0x7fb549cad5d2]
/data/llama.cpp/llama.cpp-b5731/build/bin/libggml-base.so(ggml_backend_sched_graph_compute_async+0x453)[0x7fb54b4ba5d3]
/data/llama.cpp/llama.cpp-b5731/build/bin/libllama.so(_ZN13llama_context13graph_computeEP11ggml_cgraphb+0x99)[0x7fb54b5eb8f9]
/data/llama.cpp/llama.cpp-b5731/build/bin/libllama.so(_ZN13llama_context14process_ubatchERK12llama_ubatch14llm_graph_typeP22llama_memory_context_iR11ggml_status+0xf3)[0x7fb54b5ebb53]
/data/llama.cpp/llama.cpp-b5731/build/bin/libllama.so(_ZN13llama_context6decodeERK11llama_batch+0x276)[0x7fb54b5efda6]
/data/llama.cpp/llama.cpp-b5731/build/bin/libllama.so(llama_decode+0xb)[0x7fb54b5f0f5b]
build/bin/llama-server(+0xbf4ad)[0x55d40531e4ad]
build/bin/llama-server(+0x8730a)[0x55d4052e630a]
build/bin/llama-server(+0x50835)[0x55d4052af835]
/lib/x86_64-linux-gnu/libc.so.6(+0x2724a)[0x7fb54af6624a]
/lib/x86_64-linux-gnu/libc.so.6(__libc_start_main+0x85)[0x7fb54af66305]
build/bin/llama-server(+0x52551)[0x55d4052b1551]
Aborted


Set batch size: -b 8192 -ub 8192 (the same assert fires, here with a longer prompt)
<|User|>Hello<|Assistant|>Hi there<|end▁of▁sentence|><|User|>How are you?<|Assistant|>'
main: server is listening on http://0.0.0.0:8080 - starting the main loop
srv  update_slots: all slots are idle
slot launch_slot_: id  0 | task 0 | processing task
slot update_slots: id  0 | task 0 | new prompt, n_ctx_slot = 163840, n_keep = 0, n_prompt_tokens = 157832
slot update_slots: id  0 | task 0 | kv cache rm [0, end)
slot update_slots: id  0 | task 0 | prompt processing progress, n_past = 8192, n_tokens = 8192, progress = 0.051903
slot update_slots: id  0 | task 0 | kv cache rm [8192, end)
slot update_slots: id  0 | task 0 | prompt processing progress, n_past = 16384, n_tokens = 8192, progress = 0.103807
slot update_slots: id  0 | task 0 | kv cache rm [16384, end)
slot update_slots: id  0 | task 0 | prompt processing progress, n_past = 24576, n_tokens = 8192, progress = 0.155710
slot update_slots: id  0 | task 0 | kv cache rm [24576, end)
slot update_slots: id  0 | task 0 | prompt processing progress, n_past = 32768, n_tokens = 8192, progress = 0.207613
slot update_slots: id  0 | task 0 | kv cache rm [32768, end)
slot update_slots: id  0 | task 0 | prompt processing progress, n_past = 40960, n_tokens = 8192, progress = 0.259516
slot update_slots: id  0 | task 0 | kv cache rm [40960, end)
slot update_slots: id  0 | task 0 | prompt processing progress, n_past = 49152, n_tokens = 8192, progress = 0.311420
slot update_slots: id  0 | task 0 | kv cache rm [49152, end)
slot update_slots: id  0 | task 0 | prompt processing progress, n_past = 57344, n_tokens = 8192, progress = 0.363323
slot update_slots: id  0 | task 0 | kv cache rm [57344, end)
slot update_slots: id  0 | task 0 | prompt processing progress, n_past = 65536, n_tokens = 8192, progress = 0.415226
/data/llama.cpp/llama.cpp-b5731/ggml/src/ggml-cuda/cpy.cu:561: GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX) failed
/data/llama.cpp/llama.cpp-b5731/build/bin/libggml-base.so(+0x13e18)[0x7f522f6bae18]
/data/llama.cpp/llama.cpp-b5731/build/bin/libggml-base.so(ggml_print_backtrace+0x1e4)[0x7f522f6bb1e4]
/data/llama.cpp/llama.cpp-b5731/build/bin/libggml-base.so(ggml_abort+0xd6)[0x7f522f6bb316]
/data/llama.cpp/llama.cpp-b5731/build/bin/libggml-cuda.so(_Z13ggml_cuda_cpyR25ggml_backend_cuda_contextPK11ggml_tensorPS1_b+0xb69)[0x7f522de70d99]
/data/llama.cpp/llama.cpp-b5731/build/bin/libggml-cuda.so(+0xad5d2)[0x7f522dead5d2]
/data/llama.cpp/llama.cpp-b5731/build/bin/libggml-base.so(ggml_backend_sched_graph_compute_async+0x453)[0x7f522f6d05d3]
/data/llama.cpp/llama.cpp-b5731/build/bin/libllama.so(_ZN13llama_context13graph_computeEP11ggml_cgraphb+0x99)[0x7f522f8018f9]
/data/llama.cpp/llama.cpp-b5731/build/bin/libllama.so(_ZN13llama_context14process_ubatchERK12llama_ubatch14llm_graph_typeP22llama_memory_context_iR11ggml_status+0xf3)[0x7f522f801b53]
/data/llama.cpp/llama.cpp-b5731/build/bin/libllama.so(_ZN13llama_context6decodeERK11llama_batch+0x276)[0x7f522f805da6]

DeepSeek-modified version:
<|User|>Hello<|Assistant|>Hi there<|end▁of▁sentence|><|User|>How are you?<|Assistant|>'
main: server is listening on http://0.0.0.0:8080 - starting the main loop
srv  update_slots: all slots are idle
slot launch_slot_: id  0 | task 0 | processing task
slot update_slots: id  0 | task 0 | new prompt, n_ctx_slot = 163840, n_keep = 0, n_prompt_tokens = 158350
slot update_slots: id  0 | task 0 | kv cache rm [0, end)
slot update_slots: id  0 | task 0 | prompt processing progress, n_past = 16384, n_tokens = 16384, progress = 0.103467
slot update_slots: id  0 | task 0 | kv cache rm [16384, end)
slot update_slots: id  0 | task 0 | prompt processing progress, n_past = 32768, n_tokens = 16384, progress = 0.206934
slot update_slots: id  0 | task 0 | kv cache rm [32768, end)
slot update_slots: id  0 | task 0 | prompt processing progress, n_past = 49152, n_tokens = 16384, progress = 0.310401
slot update_slots: id  0 | task 0 | kv cache rm [49152, end)
slot update_slots: id  0 | task 0 | prompt processing progress, n_past = 65536, n_tokens = 16384, progress = 0.413868
slot update_slots: id  0 | task 0 | kv cache rm [65536, end)
slot update_slots: id  0 | task 0 | prompt processing progress, n_past = 81920, n_tokens = 16384, progress = 0.517335
slot update_slots: id  0 | task 0 | kv cache rm [81920, end)
slot update_slots: id  0 | task 0 | prompt processing progress, n_past = 98304, n_tokens = 16384, progress = 0.620802
slot update_slots: id  0 | task 0 | kv cache rm [98304, end)
slot update_slots: id  0 | task 0 | prompt processing progress, n_past = 114688, n_tokens = 16384, progress = 0.724269
slot update_slots: id  0 | task 0 | kv cache rm [114688, end)
slot update_slots: id  0 | task 0 | prompt processing progress, n_past = 131072, n_tokens = 16384, progress = 0.827736
/data/llama.cpp/llama.cpp-b5731/ggml/src/ggml-cuda/cpy.cu:564: ggml_cuda_cpy: number of elements (2147483648) exceeds maximum supported value (2147483647)

/data/llama.cpp/llama.cpp-b5731/build/bin/libggml-base.so(+0x13e18)[0x7f2ad87b7e18]
/data/llama.cpp/llama.cpp-b5731/build/bin/libggml-base.so(ggml_print_backtrace+0x1e4)[0x7f2ad87b81e4]
/data/llama.cpp/llama.cpp-b5731/build/bin/libggml-base.so(ggml_abort+0xd6)[0x7f2ad87b8316]
/data/llama.cpp/llama.cpp-b5731/build/bin/libggml-cuda.so(_Z13ggml_cuda_cpyR25ggml_backend_cuda_contextPK11ggml_tensorPS1_b+0xb3a)[0x7f2ad6e70d6a]
/data/llama.cpp/llama.cpp-b5731/build/bin/libggml-cuda.so(+0xad5a2)[0x7f2ad6ead5a2]
/data/llama.cpp/llama.cpp-b5731/build/bin/libggml-base.so(ggml_backend_sched_graph_compute_async+0x453)[0x7f2ad87cd5d3]
/data/llama.cpp/llama.cpp-b5731/build/bin/libllama.so(_ZN13llama_context13graph_computeEP11ggml_cgraphb+0x99)[0x7f2ad88fe8f9]
/data/llama.cpp/llama.cpp-b5731/build/bin/libllama.so(_ZN13llama_context14process_ubatchERK12llama_ubatch14llm_graph_typeP22llama_memory_context_iR11ggml_status+0xf3)[0x7f2ad88feb53]
/data/llama.cpp/llama.cpp-b5731/build/bin/libllama.so(_ZN13llama_context6decodeERK11llama_batch+0x276)[0x7f2ad8902da6]
/data/llama.cpp/llama.cpp-b5731/build/bin/libllama.so(llama_decode+0xb)[0x7f2ad8903f5b]
/data/llama.cpp/llama.cpp-b5731/build/bin/llama-server(+0xbf4ad)[0x564afd2c74ad]
/data/llama.cpp/llama.cpp-b5731/build/bin/llama-server(+0x8730a)[0x564afd28f30a]
