danbev (Member) commented Nov 4, 2025

This is a work in progress to add support for GPU sampling.

The motivation for this feature is to enable sampling to be performed directly on the GPU as part of the computation graph being executed, so that some or all of the sampling work happens on the device.

For example, the GPU sampler chain might select/sample a token directly, in which case only the sampled token needs to be transferred from device memory to host memory.

It is also possible for the GPU samplers to perform filtering of the logits, or compute and filter the probability distribution, in which case only the filtered logits or probabilities need to be transferred back to system memory for further processing by CPU samplers.

Currently, GPU sampling works in a similar manner to pooling: it is a function that is called by build_graph, and the sampler operations become part of the model's computation graph.

GPU samplers can be configured by creating sampler chains, where each sampler chain is associated with a specific sequence id:

    struct llama_sampler_chain_params params = llama_sampler_chain_default_params();
    struct llama_sampler * chain = llama_sampler_chain_init(params);
    llama_sampler_chain_add(chain, llama_sampler_gpu_init_greedy());
    std::vector<llama_sampler_seq_config> sampler_configs = {
        { 0, chain }
    };

The struct is defined as:

    struct llama_sampler_seq_config {
        llama_seq_id           seq_id;
        struct llama_sampler * sampler;
    };

These sampler configs are then passed as context params:

    llama_context_params cparams = llama_context_default_params();
    cparams.samplers = sampler_configs.data();
    cparams.n_samplers = sampler_configs.size();

When the model graph is built the GPU samplers will be called to enable them to add their operations to the graph:

    ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
        std::unique_ptr<llm_graph_context> llm;
        ...

        // add GPU sampling layers (if any)
        llm->build_sampling(*this, params);

The llama_sampler_i interface has been extended with four new methods in the API. They are currently all named with a _ggml suffix to indicate that they are for GPU sampling (and possibly other devices like NPUs in the future):

        void (*init_ggml)     (struct llama_sampler * smpl,
                               ggml_backend_buffer_type_t buft);

        void (*set_input_ggml)(struct llama_sampler * smpl,
                               ggml_context * ctx,
                               ggml_cgraph * gf);

        void (*apply_ggml)    (struct llama_sampler * smpl,
                               ggml_context * ctx,
                               ggml_cgraph * gf,
                               llama_sampler_ggml_data * ggml_data);

        void (*accept_ggml)   (struct llama_sampler * smpl,
                               ggml_context * ctx,
                               ggml_cgraph * gf,
                               struct ggml_tensor * selected_token);

The init_ggml function allows GPU samplers to create any input tensors they might need. The passed-in ggml_backend_buffer_type_t should be used when creating these tensors so that they live on the same backend as the output logits. This avoids splits in the computation graph that would require data transfers between different backends.
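
For illustration, here is a minimal sketch (not code from this PR) of what an init_ggml implementation could look like for a hypothetical sampler that needs a single scalar input tensor; the struct and function names are illustrative and cleanup/error handling is omitted:

```c++
// Hypothetical sampler state; the real samplers in this PR may be organized differently.
struct gpu_example_sampler_ctx {
    struct ggml_context  * ctx = nullptr; // holds the tensor metadata
    ggml_backend_buffer_t  buf = nullptr; // device buffer backing the tensors
    struct ggml_tensor   * inp = nullptr; // filled in later by set_input_ggml
};

static void gpu_example_init_ggml(struct llama_sampler * smpl, ggml_backend_buffer_type_t buft) {
    auto * sctx = (gpu_example_sampler_ctx *) smpl->ctx;

    struct ggml_init_params params = {
        /*.mem_size   =*/ ggml_tensor_overhead() * 8,
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ true, // the tensor data is allocated in the backend buffer below
    };
    sctx->ctx = ggml_init(params);

    // create the input tensor and allocate it with the same buffer type as the
    // output logits so that the scheduler does not have to split the graph
    sctx->inp = ggml_new_tensor_1d(sctx->ctx, GGML_TYPE_F32, 1);
    ggml_set_name(sctx->inp, "smpl_example_inp");
    sctx->buf = ggml_backend_alloc_ctx_tensors_from_buft(sctx->ctx, buft);
}
```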

The set_input_ggml function is called after the computation graph has been scheduled but before it is computed. This allows the GPU sampler to set the input data for the tensors it created in init_ggml.
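
And a matching set_input_ggml sketch for the same hypothetical sampler, writing a host-side value into the device tensor right before the graph is computed:

```c++
static void gpu_example_set_input_ggml(struct llama_sampler * smpl,
                                       struct ggml_context * ctx,
                                       struct ggml_cgraph * gf) {
    auto * sctx = (gpu_example_sampler_ctx *) smpl->ctx;

    const float value = 0.8f; // e.g. a temperature or an RNG draw computed on the host
    ggml_backend_tensor_set(sctx->inp, &value, 0, sizeof(value));
}
```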

The apply_ggml function is where a GPU sampler adds its operations to the graph: when the graph is built, each configured sampler's apply_ggml function is called, allowing it to add operations/nodes to the computation graph.
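
As an example of the kind of operations a sampler can add, here is a sketch of an apply_ggml for a greedy sampler; the llama_sampler_ggml_data fields used here (logits in, selected_token out) are assumptions, since the struct layout is not shown in this description:

```c++
static void gpu_greedy_apply_ggml(struct llama_sampler * smpl,
                                  struct ggml_context * ctx,
                                  struct ggml_cgraph * gf,
                                  struct llama_sampler_ggml_data * ggml_data) {
    // argmax over the vocabulary dimension selects the most likely token on-device
    struct ggml_tensor * selected = ggml_argmax(ctx, ggml_data->logits);
    ggml_build_forward_expand(gf, selected);
    ggml_data->selected_token = selected;
}
```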

The accept_ggml function allows GPU samplers to update their tensor state if needed.

This enables sampling to happen fully or partially on the GPU. The samplers could sample a single token, in which case only that token is transferred from device memory to host memory after llama_decode has been called. The sampled token can then be retrieved using:

    llama_token id = llama_get_sampled_token_ith(test_ctx.ctx, index);

It is also possible to run a GPU sampler that only filters the logits, in which case only the filtered logits are transferred back to the host and sampling can proceed on the CPU with the normal (CPU) sampler chain. In this case the CPU samplers are configured as usual, but they will now operate on already filtered logits.

Similar to the above handling of logits, it is possible for a GPU sampler to compute the full probability distribution and transfer that to the host. The CPU samplers can then operate on those probabilities.
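
A sketch of the host-side handling for this hybrid case, using the new getters introduced later in this PR (the cpu_chain sampler and the batch index are assumed to be set up elsewhere):

```c++
const int32_t idx = 0; // position in the batch

const uint32_t      n      = llama_get_sampled_logits_count_ith(ctx, idx);
const float       * logits = llama_get_sampled_logits_ith(ctx, idx);
const llama_token * ids    = llama_get_sampled_token_ids_ith(ctx, idx);

// hand the filtered logits to a regular CPU sampler chain
std::vector<llama_token_data> cur(n);
for (uint32_t i = 0; i < n; ++i) {
    cur[i] = { ids[i], logits[i], 0.0f };
}
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };

llama_sampler_apply(cpu_chain, &cur_p);
const llama_token token = cur_p.data[cur_p.selected].id;
```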

Building and running the tests

Download a model for testing:

$ cd models && wget https://huggingface.co/ggml-org/models/resolve/main/tinyllamas/stories15M-q4_0.gguf

Building the test:

$ cmake --build build --target test-gpu-sampling -j8

Running all tests:

$ env LLAMACPP_TEST_MODELFILE=../models/stories15M-q4_0.gguf \
    ctest --test-dir build -R '^test-gpu-sampling$' -V

The following individual tests are available:

$ ctest --test-dir build -N -R test-gpu-sampling-
  Test 35: test-gpu-sampling-greedy
  Test 36: test-gpu-sampling-temp
  Test 37: test-gpu-sampling-softmax
  Test 38: test-gpu-sampling-top_k
  Test 39: test-gpu-sampling-top_p
  Test 40: test-gpu-sampling-mul_seq

Total Tests: 6

These can be run individually, for example:

$ env LLAMACPP_TEST_MODELFILE=../models/stories15M-q4_0.gguf \
    ctest --test-dir build -R 'test-gpu-sampling-temp' -V

llama-cli

Initial support for llama-cli has been added and can be used as follows:

$ export GGML_SCHED_DEBUG=2
$ ./build/bin/llama-cli -m models/Qwen2.5-VL-3B-Instruct-Q8_0.gguf \
    -p "What is the Capital of Sweden?" \
    --gpu-sampling \
    --gpu-dist \
    -ngl 99 \
    -no-cnv \
    -n 20 \
    --no-warmup

(To print the backend scheduler's assignments, add -v/--verbose to the above command in combination with GGML_SCHED_DEBUG.)

llama-server

GPU sampling can be enabled using the following global configuration command line options:

$ ./build-gpu-sampling/bin/llama-server --help
...
----- sampling params -----
...
--gpu-sampling                          enable GPU sampling (default: disabled)
--gpu-dist                              perform final sampling on GPU (default: disabled)

Usage:

$ export GGML_SCHED_DEBUG=2
$ ./build/bin/llama-server \
      -m models/Qwen2.5-VL-3B-Instruct-Q8_0.gguf \
      --gpu-sampling \
      --temp 0.8 \
      --top-k 40 \
      -ngl 50

(To print the backend scheduler's assignments, add -v/--verbose to the above command in combination with GGML_SCHED_DEBUG.)

It is then possible to send GPU sampling parameters in a request as follows:

curl --request POST \
    --url http://localhost:8080/completion \
    --header "Content-Type: application/json" \
    --data '{"prompt": "What is the capital of Sweden?","n_predict": 20, "top_k": 40, "gpu_dist": true}'

The gpu_dist option will cause the dist GPU sampler to sample a token. Without it, the CPU samplers will process the filtered tokens that the GPU sampler produced (this currently needs more work on the CPU sampler side).

To enable testing with the webui, the following settings have been added:
[screenshot: gpu-settings]

TODO

  • Allocate GPU sampler tensors on the same backend as the logits (dev_output.dev)
  • Allow GPU samplers to pre-allocate state tensors
  • Integrate GPU samplers with llama-cli
  • Set/unset GPU samplers
  • Integrate GPU samplers with llama-server
  • Add more tests/assertions for the gpu samplers to check more cases
  • penalties samplers
  • Add support for operations like ggml_top_k (supporting vocabulary-sized tensors) in all backends
  • Add ggml_cumsum operation to all backends. This is being done as part of Add ops needed for new hybrid models: SOFTPLUS, EXPM1, TRI, SOLVE_TRI, CUMSUM #17063
  • Consistent and clearer naming of GPU (device sampling) functions and data types.

Implemented GPU samplers

  • temp
  • logit_bias
  • top_k (Not fully supported on all backends, see note below regarding argsort)
  • greedy
  • dist sampler

Remaining GPU samplers

The list below shows the CPU samplers that currently exist. Not all of these may be appropriate as GPU samplers.

  • top_p
  • min_p
  • typical
  • temp_ext
  • xtc
  • top_n_sigma
  • mirostat/mirostat_v2
  • penalties
  • dry
  • infill

I think we should have support in all backends for the operations that the GPU samplers use. At the moment this is not the case, and if the target backend device (the same device that holds the logits tensor) does not support an operation, a warning similar to this is printed:

Warning: backend does not support argsort operation required for top-k sampling
CPU backend will be used instead which defeats the purpose of having GPU samplers

📝 Note:
ARGSORT is not supported for arbitrary column width on Metal at the moment

       case GGML_OP_ARGSORT:
           // TODO: Support arbitrary column width
           return op->src[0]->ne[0] <= 1024;

So on macOS, samplers that use ARGSORT currently don't work. And for GPU samplers the dimension can be as large as the model's vocab size, for example:

(lldb) p op->src[0]->ne[0]
(int64_t) 32000
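
A sketch of the kind of capability check behind the warning above; the exact call site in this PR may differ, but ggml exposes ggml_backend_dev_supports_op for this purpose:

```c++
// Build a throw-away argsort node over the logits and ask the device whether
// it can execute it; if not, top-k style samplers should fall back or warn.
static bool device_supports_argsort(ggml_backend_dev_t dev,
                                    struct ggml_context * ctx,
                                    struct ggml_tensor * logits) {
    struct ggml_tensor * sorted = ggml_argsort(ctx, logits, GGML_SORT_ORDER_DESC);
    return ggml_backend_dev_supports_op(dev, sorted);
}
```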

am17an (Collaborator) commented Nov 5, 2025

One place this would be useful immediately is the diffusion-cli. I'm happy to test this when it's ready

This commit adds a new cumulative sum (cumsum) operation to the ggml
library.

The motivation for this is to be able to implement the GPU distribution
sampler. I notice that there is work underway to add cumsum in other PRs,
so this commit can probably be removed once those are merged.
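
As a rough illustration of how cumsum enables a distribution sampler
without sorting (a sketch under assumptions, not code from this commit):
softmax produces pseudo-probabilities, cumsum turns them into a CDF, and
comparing the CDF against a uniform random value yields the sampled index.
```c++
// Sketch: inverse-CDF sampling on-device. The rng_uniform tensor is assumed
// to be a scalar in [0, 1) filled on the host in set_input_ggml.
static struct ggml_tensor * build_dist_sample(
        struct ggml_context * ctx,
        struct ggml_tensor  * logits,        // [n_vocab] logits for one position
        struct ggml_tensor  * rng_uniform) { // [1] uniform random value
    struct ggml_tensor * probs = ggml_soft_max(ctx, logits); // pseudo-probabilities
    struct ggml_tensor * cdf   = ggml_cumsum(ctx, probs);    // cumulative distribution

    // r - cdf[i] is positive exactly for the entries still below the random
    // threshold; counting them gives the sampled token index.
    struct ggml_tensor * below = ggml_step(ctx, ggml_add(ctx, ggml_neg(ctx, cdf), rng_uniform));
    return ggml_sum(ctx, below); // token index as a scalar (f32)
}
```
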
This commit adds initial support for GPU sampling in llama-cli.

Options:
```console
$ ./build/bin/llama-cli --help
----- sampling params -----
...
--gpu-sampling                          enable GPU sampling (default: disabled)
--gpu-top-k N                           GPU top-k sampling (default: 40, <= 0 = disabled)
--gpu-top-p-approx-k N                  GPU top-p approximation using top-k (default: 0, 0 = disabled)
--gpu-temp N                            GPU temperature (default: 0.80, 0.0 = disabled, greedy sampling)
--gpu-softmax                           add GPU softmax to sampling chain (default: disabled)
--gpu-dist                              add GPU dist (final sampling) to sampling chain (default: disabled)
```

Usage:
```console
$ ./build/bin/llama-cli -m models/Qwen2.5-VL-3B-Instruct-Q8_0.gguf \
    -p "What is the Capital of Sweden?" \
    --gpu-sampling \
    --gpu-temp 0 \
    --gpu-top-k 20 \
    --gpu-dist \
    -ngl 99 \
    -no-cnv \
    -n 20 \
    --no-warmup
```
This commit adds initial support for GPU sampling in llama-server.

GPU sampling for llama-cli was pretty straightforward as there is
basically just one GPU sampler that needs to be configured. Recall that
the GPU samplers need to be configured before the context is created as
they are added to the model's computation graph; they are not something
that is processed after the computation has completed but are part of
it.

It is possible to have a GPU sampler per sequence, and llama-server
supports multiple slots, so currently it is possible to configure GPU
samplers per slot:
```console
$ ./build-gpu-sampling/bin/llama-server --help
...
----- sampling params -----
...
--gpu-sampling                          enable GPU sampling (default: disabled)
--gpu-top-k N                           GPU top-k sampling (default: 40, <= 0 = disabled)
--gpu-top-p-approx-k N                  GPU top-p approximation using top-k (default: 0, 0 = disabled)
--gpu-temp N                            GPU temperature (default: 0.80, 0.0 = disabled, greedy sampling)
--gpu-softmax                           add GPU softmax to sampling chain (default: disabled)
--gpu-dist                              add GPU dist (final sampling) to sampling chain (default: disabled)
--gpu-slot SLOT_ID:CONFIG               configure GPU sampling for a specific slot
                                        format: SLOT_ID:top-k=N,temp=F,dist=BOOL
                                        example: --gpu-slot 0:top-k=20,temp=0.8,dist=true --gpu-slot
                                        1:top-k=40,temp=0.5 --gpu-slot 2:none
```
The options with the --gpu- prefix configure the default GPU sampling
and are used for all slots unless overridden by --gpu-slot.

Usage:
```console
./build/bin/llama-server \
      -m models/Qwen2.5-VL-3B-Instruct-Q8_0.gguf \
      --gpu-sampling \
      --gpu-slot 0:top-k=20,temp=0.8,dist=true \
      --gpu-slot 1:top-k=40,temp=0.1,dist=false \
      --gpu-slot 2:none \
      --gpu-slot 3:none \
      -ngl 99 \
```
Optionally, verbose logging can be enabled and the environment variable
GGML_SCHED_DEBUG=2 set to see the scheduler's splits and verify that the
GPU is being used for sampling.
This commit introduces the function `llama_set_ggml_sampler` that allows
setting or unsetting (by passing in null) the GPU sampler for a specific
sequence ID in the llama context.

The motivation for this is to enable llama-server to manage this
dynamically (still not exactly sure how this will work but will look
into that next).
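
A minimal usage sketch; the exact signature is not shown here, so the
parameter order (context, sequence id, sampler) is an assumption:
```c++
// Attach a greedy GPU sampler chain to sequence 0, then detach it again.
struct llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
llama_sampler_chain_add(chain, llama_sampler_gpu_init_greedy());

llama_set_ggml_sampler(ctx, /*seq_id=*/0, chain);   // set the GPU sampler for sequence 0
// ... decode ...
llama_set_ggml_sampler(ctx, /*seq_id=*/0, nullptr); // unset again (CPU-only sampling)
```
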
This commit removes the per-slot configuration options for GPU sampling
and now only includes global flags to enable GPU sampling and GPU
final sampling (dist).

The motivation for this is that now that we can set the GPU samplers
directly, we can reuse the CPU sampler options in the request. We still
have the gpu_dist request flag so that it is possible to specify that
the GPU should do the final sampling step, since there might be use
cases where this is not desired. For example, one might want to only do
filtering on the GPU and then do the final sampling on the CPU (though
this is not fully supported yet and needs more work).
This commit adds new settings to the chat configuration for enabling GPU
sampling and optionally GPU distribution sampling.

The motivation for this is to enable testing of the GPU sampling features
directly from the web UI.
ORippler (Contributor) left a comment:

Not sure if I have a strong opinion on this, but removing hybrid sampling would reduce the complexity a bit I think (basically, if we always set --gpu-dist we only have two states: either full GPU sampling or full CPU sampling, with no in-between).

Comment on lines +1504 to +1518
add_opt(common_arg(
    {"--gpu-sampling"},
    "enable GPU sampling (default: disabled)",
    [](common_params & params) {
        params.sampling.gpu_sampling = true;
    }
).set_sparam());
add_opt(common_arg(
    {"--gpu-dist"},
    "perform final sampling on GPU (default: disabled)",
    [](common_params & params) {
        params.sampling.gpu_dist = true;
        params.sampling.gpu_sampling = true;
    }
).set_sparam());
Contributor:

I would have presumed --gpu-sampling to enable --gpu-dist as well; maybe rename to something like "--gpu-sampling" and "--gpu-pre-sampling"? Or "--gpu-sampling" and "--hybrid-sampling"?

Or put a comment somewhere that dist is short-hand for distribution, which would make it clearer that we are sampling the final distribution of pseudo-probabilities on the GPU.

Member Author (danbev):

Yeah, it is the naming that is unclear here and, like you mentioned, dist is intended to mean distribution sampler. I've added an item in the todo list regarding naming and I'll revisit these options in the next couple of days, as I did not spend much time on them when adding them.

Comment on lines +116 to +118
const float * sampled_probs = llama_get_sampled_probs_ith(ctx, idx);
const float * sampled_logits = llama_get_sampled_logits_ith(ctx, idx);
const llama_token * sampled_ids = llama_get_sampled_token_ids_ith(ctx, idx);
Contributor:

It's not obvious to me that llama_get_sampled refers to GPU-offloaded sampling, and llama_get_logits refers to non-offloaded sampling. Maybe it makes sense to add a boolean flag to llama_context that can be queried?

Member Author (danbev):

I agree. I think we should try to name them specifically indicating that these are related to device sampling. Similar to my previous comment and the todo item about some consistent naming for them. Perhaps something like llama_get_device_sampled_probs_ith?

Comment on lines +969 to +991
// Get the GPU sampled token for the ith token.
// Returns LLAMA_TOKEN_NULL if no token was sampled.
LLAMA_API llama_token llama_get_sampled_token_ith(struct llama_context * ctx, int32_t i);

// Get the GPU sampled probabilities for the ith token
// The index matches llama_get_sampled_token_ith().
// Returns NULL if no probabilities were generated.
LLAMA_API float * llama_get_sampled_probs_ith(struct llama_context * ctx, int32_t i);

// Get the GPU sampled logits for the ith token
// Returns NULL if no logits were sampled.
LLAMA_API float * llama_get_sampled_logits_ith(struct llama_context * ctx, int32_t i);

// Get the GPU sampled token ids associated with the sampled logits for the ith token
// Returns NULL if no logits were sampled.
LLAMA_API llama_token * llama_get_sampled_token_ids_ith(struct llama_context * ctx, int32_t i);

// Get the number of GPU sampled logits for the ith token.
LLAMA_API uint32_t llama_get_sampled_logits_count_ith(struct llama_context * ctx, int32_t i);

// Get the number of GPU sampled probabilities for the ith token.
LLAMA_API uint32_t llama_get_sampled_probs_count_ith(struct llama_context * ctx, int32_t i);

Contributor:

To me it's not obvious that llama_get_sampled should resolve to GPU sampling, and llama_get should resolve to CPU sampling

Comment on lines +251 to +255
static void llama_sampler_gpu_dist_apply_ggml(
        struct llama_sampler * smpl,
        struct ggml_context * ctx,
        struct ggml_cgraph * gf,
        struct llama_sampler_ggml_data * ggml_data) {
Contributor:

Quite genius how you avoid sorting by transforming to pseudo-probabilities + computing the CDF!

Contributor:

Shouldn't we adjust the can_batch_with function to allow batching only when sampling parameters are equal? Or am I misunderstanding how slot batching works in the server?

Member Author (danbev):

I intend to take a closer look at llama-server in the next few days as the ability to set/reset the gpu/device samplers was only added yesterday, and I need to look into can_batch_with as I missed that completely.

Member:

The existing can_batch_with should be compatible with the current design, so I don't think it needs extra logic to guarantee that the samplers are equal. The idea is that when we construct the graph, in build_sampling we will add a sampler chain for every separate sequence that participates in the ubatch, so even if the samplers are different for the different slots, this should still work.

This commit adds an initial implementation of a GPU-based logit bias
sampler.
danbev (Member Author) commented Nov 13, 2025

> Not sure if I have a strong opinion on this but removing hybrid sampling would reduce the complexity a bit I think (basically if we always set --gpu-dist we only have two states (either full gpu sampling or full cpu sampling, and no in-between).

My thoughts are that I think we should keep the hybrid approach even though it does come with some additional complexity like you say. I think there could be use cases where one might want to perform some sampling like temp/logit_bias/top-k on the device, and then have only a smaller set of logits copied to host memory, while still enabling other CPU samplers, including grammars, to process the logits.

This might turn out to be an incorrect assumption and not something anyone wants to use, but it feels safer to keep the ability to do hybrid sampling.

This commit changes the logging level of the message indicating that the
GPU sampler has selected a token from INFO to DEBUG.

The motivation for this is that when this is logged at INFO level, it
will interfere with the sampled token output. But this can be useful
when debugging to see if the GPU sampler is actually selecting tokens.
This commit adds support for logit bias in the GPU sampling
initialization. This also requires the GPU sampler initialization to
have access to the model, which it did not need before; this means that
the GPU sampler initialization in main.cpp needs to be done after the
model and context have been loaded/initialized, and for llama-server
the model needs to be passed into the GPU sampler initialization.

With these changes it is possible to run main using:
```console
./build/bin/llama-cli -m models/Qwen2.5-VL-3B-Instruct-Q8_0.gguf \
    --no-warmup --prompt '"What is the Capital of France?"' \
    --gpu-sampling \
    --logit-bias 3888+100 \
    --top-k 40 \
    --gpu-dist \
    --seed 88 \
    -n 20 \
    -no-cnv \
    --verbose-prompt
```

And the server can include logit bias in its requests as well:
```console
curl --request POST \
    --url http://localhost:8080/completion \
    --header "Content-Type: application/json" \
    --data '{"prompt": "What is the capital of Sweden?",
             "n_predict": 20,
             "top_k": 40,
             "gpu_dist": true,
             "logit_bias": {"3888": 100}}'
```
This commit disables tests that require full backend operation support
for GPU sampling. The tests can still be run manually.