#include "presets.hpp"
#include "common.hpp"
#include "ggml.h"
#include "set.hpp"
#include <cstdint>
#include <sycl/sycl.hpp>
using namespace sycl;

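// GGML_OP_SET for F32 tensors on the SYCL backend: the result of
// ggml_set(ctx, a, b, nb1, nb2, nb3, offset) is tensor a with tensor b copied into the
// view of a described by the byte strides nb1..nb3 and the byte offset.
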
// Internal function: perform the element-wise set for one work-item
inline void set_f32(const float* src, float* dst,
                    const int64_t ne0, const int64_t ne1,
                    const int64_t ne2, const int64_t ne3,
                    const int64_t nb[3], const int64_t src_nb[3],
                    const int64_t offset_elem,
                    const nd_item<1>& item)
{
    const size_t idx = item.get_global_id(0);
    const size_t total = ne0 * ne1 * ne2 * ne3;
    if (idx >= total) return;

    // Convert the linear index into 4D indices
    const size_t i3   = idx / (ne2 * ne1 * ne0);
    const size_t rem  = idx % (ne2 * ne1 * ne0);
    const size_t i2   = rem / (ne1 * ne0);
    const size_t rem2 = rem % (ne1 * ne0);
    const size_t i1   = rem2 / ne0;
    const size_t i0   = rem2 % ne0;

    // Compute source and destination offsets (strides are in float elements) and copy
    dst[i0 + i1*nb[0] + i2*nb[1] + i3*nb[2] + offset_elem] =
        src[i0 + i1*src_nb[0] + i2*src_nb[1] + i3*src_nb[2]];
}
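
// Example of the decomposition above: with ne = {4, 3, 2, 1}, a linear idx of 17 maps to
// i3 = 17/24 = 0, rem = 17, i2 = 17/12 = 1, rem2 = 5, i1 = 5/4 = 1, i0 = 5%4 = 1.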

// Main function: prepare the GPU queue and launch the parallel_for
void ggml_sycl_op_set(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
    const ggml_tensor* src0 = dst->src[0];
    const ggml_tensor* src1 = dst->src[1];

    // Ensure shapes and types are compatible
    GGML_ASSERT(ggml_are_same_shape(src0, dst));
    GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
    GGML_ASSERT(dst->type == src0->type && src0->type == src1->type && dst->type == GGML_TYPE_F32);

    // op_params layout (as filled in by ggml_set): { nb1, nb2, nb3, offset, inplace },
    // with strides and offset given in bytes; convert them to float elements here
    const int32_t* opts = (const int32_t*) dst->op_params;
    const int64_t nb[3] = {opts[0]/sizeof(float), opts[1]/sizeof(float), opts[2]/sizeof(float)};
    const int64_t offset_elem = opts[3] / sizeof(float);
    const bool inplace = opts[4] != 0;

    float* dst_ptr = (float*) dst->data;
    const float* src0_ptr = (const float*) src0->data;
    const float* src1_ptr = (const float*) src1->data;

    queue_ptr stream = ctx.stream();

    // Copy src0 to dst if not inplace
    if (!inplace)
        stream->memcpy(dst_ptr, src0_ptr, ggml_nbytes(dst));

    const int64_t ne[4] = {src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3]};
    const int64_t src_nb[3] = {src1->nb[1]/sizeof(float), src1->nb[2]/sizeof(float), src1->nb[3]/sizeof(float)};

    const size_t total_threads = ne[0]*ne[1]*ne[2]*ne[3];
    // Round the global range up to a multiple of the block size; the kernel bounds-checks idx
    const size_t grid_size = ((total_threads + SYCL_SET_BLOCK_SIZE - 1) / SYCL_SET_BLOCK_SIZE) * SYCL_SET_BLOCK_SIZE;

    // Launch the kernel: copy src1 into the view of dst selected by nb[] and offset_elem
    stream->parallel_for(
        nd_range<1>(range<1>(grid_size), range<1>(SYCL_SET_BLOCK_SIZE)),
        [=](nd_item<1> item) {
            set_f32(src1_ptr, dst_ptr,
                    ne[0], ne[1], ne[2], ne[3],
                    nb, src_nb, offset_elem, item);
        }
    );
}
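
// Note: a companion "set.hpp" header (not shown in this diff) is assumed to declare the
// entry point above so the backend dispatcher can call it. A minimal sketch, assuming the
// usual ggml-sycl conventions:
//
//   #pragma once
//   #include "common.hpp"
//   void ggml_sycl_op_set(ggml_backend_sycl_context & ctx, ggml_tensor * dst);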