@@ -13,13 +13,14 @@ Custom SYCL Operators
1313 .. grid-item-card :: :octicon:`list-unordered;1em;` Prerequisites
1414 :class-card: card-prerequisites
1515
16- * PyTorch 2.8 or later
16+ * PyTorch 2.8 or later for Linux
17+ * PyTorch 2.10 or later for Windows
1718 * Basic understanding of SYCL programming
1819
1920.. note ::
2021
2122 ``SYCL `` serves as the backend programming language for Intel GPUs (device label ``xpu ``). For configuration details, see:
22- `Getting Started on Intel GPUs <https://docs.pytorch.org/docs/main/notes/get_start_xpu.html >`_. The Intel Compiler, which comes bundled with Intel Deep Learning Essentials, handles ``SYCL `` compilation. Ensure you install and activate the compiler environment prior to executing the code examples in this tutorial.
23+ `Getting Started on Intel GPUs <https://docs.pytorch.org/docs/main/notes/get_start_xpu.html >`_. The Intel Compiler, which comes bundled with `Intel Deep Learning Essentials <https://www.intel.com/content/www/us/en/developer/articles/tool/pytorch-prerequisites-for-intel-gpus.html >`_, handles ``SYCL `` compilation. Ensure you install and activate the compiler environment prior to executing the code examples in this tutorial.
2324
2425PyTorch offers a large library of operators that work on Tensors (e.g. torch.add, torch.sum, etc.).
2526However, you may wish to bring a new custom operator to PyTorch. This tutorial demonstrates the
@@ -40,52 +41,71 @@ Follow the structure to create a custom SYCL operator:
4041 Setting up the Build System
4142---------------------------
4243
43- If you need to compile **SYCL ** code (for example, ``.sycl `` files ), use `torch.utils.cpp_extension.SyclExtension <https://docs.pytorch.org/docs/stable/cpp_extension.html#torch.utils.cpp_extension.SyclExtension >`_.
44+ If you need to compile **SYCL ** code (noting that the extension should be ``.sycl ``), use `torch.utils.cpp_extension.SyclExtension <https://docs.pytorch.org/docs/stable/cpp_extension.html#torch.utils.cpp_extension.SyclExtension >`_.
4445The setup process is very similar to C++/CUDA, except the compilation arguments need to be adjusted for SYCL.
4546
4647Using ``sycl_extension `` is as straightforward as writing the following ``setup.py ``:
4748
4849.. code-block :: python
4950
50- import os
51- import torch
52- import glob
53- from setuptools import find_packages, setup
54- from torch.utils.cpp_extension import SyclExtension, BuildExtension
55-
56- library_name = " sycl_extension"
57- py_limited_api = True
58- extra_compile_args = {
59- " cxx" : [" -O3" ,
60- " -fdiagnostics-color=always" ,
61- " -DPy_LIMITED_API=0x03090000" ],
62- " sycl" : [" -O3" ]
63- }
64-
65- assert (torch.xpu.is_available()), " XPU is not available, please check your environment"
66- # Source files collection
67- this_dir = os.path.dirname(os.path.curdir)
68- extensions_dir = os.path.join(this_dir, library_name)
69- sources = list (glob.glob(os.path.join(extensions_dir, " *.sycl" )))
70- # Construct extension
71- ext_modules = [
72- SyclExtension(
73- f " { library_name} ._C " ,
74- sources,
75- extra_compile_args = extra_compile_args,
76- py_limited_api = py_limited_api,
77- )
78- ]
79- setup(
80- name = library_name,
81- packages = find_packages(),
82- ext_modules = ext_modules,
83- install_requires = [" torch" ],
84- description = " Simple Example of PyTorch Sycl extensions" ,
85- cmdclass = {" build_ext" : BuildExtension},
86- options = {" bdist_wheel" : {" py_limited_api" : " cp39" }} if py_limited_api else {},
87- )
88-
51+ import os
52+ import torch
53+ import glob
54+ import platform
55+ from setuptools import find_packages, setup
56+ from torch.utils.cpp_extension import SyclExtension, BuildExtension
57+
58+ library_name = " sycl_extension"
59+ py_limited_api = True
60+
61+ IS_WINDOWS = (platform.system() == ' Windows' )
62+
63+ if IS_WINDOWS :
64+ cxx_args = [
65+ " /O2" ,
66+ " /std:c++17" ,
67+ " /DPy_LIMITED_API=0x03090000" ,
68+ ]
69+ sycl_args = [" /O2" , " /std:c++17" ]
70+ else :
71+ cxx_args = [
72+ " -O3" ,
73+ " -fdiagnostics-color=always" ,
74+ " -DPy_LIMITED_API=0x03090000"
75+ ]
76+ sycl_args = [" -O3" ]
77+
78+ extra_compile_args = {
79+ " cxx" : cxx_args,
80+ " sycl" : sycl_args
81+ }
82+
83+ assert (torch.xpu.is_available()), " XPU is not available, please check your environment"
84+
85+ # Source files collection
86+ this_dir = os.path.dirname(os.path.curdir)
87+ extensions_dir = os.path.join(this_dir, library_name)
88+ sources = list (glob.glob(os.path.join(extensions_dir, " *.sycl" )))
89+
90+ # Construct extension
91+ ext_modules = [
92+ SyclExtension(
93+ f " { library_name} ._C " ,
94+ sources,
95+ extra_compile_args = extra_compile_args,
96+ py_limited_api = py_limited_api,
97+ )
98+ ]
99+
100+ setup(
101+ name = library_name,
102+ packages = find_packages(),
103+ ext_modules = ext_modules,
104+ install_requires = [" torch" ],
105+ description = " Simple Example of PyTorch Sycl extensions" ,
106+ cmdclass = {" build_ext" : BuildExtension},
107+ options = {" bdist_wheel" : {" py_limited_api" : " cp39" }} if py_limited_api else {},
108+ )
89109
90110 Defining the custom op and adding backend implementations
91111---------------------------------------------------------
@@ -101,82 +121,109 @@ in a separate ``TORCH_LIBRARY_IMPL`` block:
101121
102122.. code-block :: cpp
103123
104- #include <c10/xpu/XPUStream.h>
105- #include <sycl/sycl.hpp>
106- #include <ATen/Operators.h>
107- #include <torch/all.h>
108- #include <torch/library.h>
109-
110- namespace sycl_extension {
111- // MulAdd Kernel: result = a * b + c
112- static void muladd_kernel(
113- int numel, const float* a, const float* b, float c, float* result,
114- const sycl::nd_item<1>& item) {
115- int idx = item.get_global_id(0);
116- if (idx < numel) {
117- result[idx] = a[idx] * b[idx] + c;
118- }
119- }
120-
121- class MulAddKernelFunctor {
122- public:
123- MulAddKernelFunctor(int _numel, const float* _a, const float* _b, float _c, float* _result)
124- : numel(_numel), a(_a), b(_b), c(_c), result(_result) {}
125- void operator()(const sycl::nd_item<1>& item) const {
126- muladd_kernel(numel, a, b, c, result, item);
127- }
128-
129- private:
130- int numel;
131- const float* a;
132- const float* b;
133- float c;
134- float* result;
135- };
136-
137- at::Tensor mymuladd_xpu(const at::Tensor& a, const at::Tensor& b, double c) {
138- TORCH_CHECK(a.sizes() == b.sizes(), "a and b must have the same shape");
139- TORCH_CHECK(a.dtype() == at::kFloat, "a must be a float tensor");
140- TORCH_CHECK(b.dtype() == at::kFloat, "b must be a float tensor");
141- TORCH_CHECK(a.device().is_xpu(), "a must be an XPU tensor");
142- TORCH_CHECK(b.device().is_xpu(), "b must be an XPU tensor");
143-
144- at::Tensor a_contig = a.contiguous();
145- at::Tensor b_contig = b.contiguous();
146- at::Tensor result = at::empty_like(a_contig);
147-
148- const float* a_ptr = a_contig.data_ptr<float>();
149- const float* b_ptr = b_contig.data_ptr<float>();
150- float* res_ptr = result.data_ptr<float>();
151- int numel = a_contig.numel();
152-
153- sycl::queue& queue = c10::xpu::getCurrentXPUStream().queue();
154- constexpr int threads = 256;
155- int blocks = (numel + threads - 1) / threads;
156-
157- queue.submit([&](sycl::handler& cgh) {
158- cgh.parallel_for<MulAddKernelFunctor>(
159- sycl::nd_range<1>(blocks * threads, threads),
160- MulAddKernelFunctor(numel, a_ptr, b_ptr, static_cast<float>(c), res_ptr)
161- );
162- });
163-
164- return result;
165- }
166- // Defines the operators
167- TORCH_LIBRARY(sycl_extension, m) {
124+ #include <c10/xpu/XPUStream.h>
125+ #include <sycl/sycl.hpp>
126+ #include <ATen/Operators.h>
127+ #include <torch/all.h>
128+ #include <torch/library.h>
129+
130+
131+ #include <Python.h>
132+
133+ namespace sycl_extension {
134+
135+ // ==========================================================
136+ // 1. Kernel
137+ // ==========================================================
138+ static void muladd_kernel(
139+ int numel, const float* a, const float* b, float c, float* result,
140+ const sycl::nd_item<1>& item) {
141+ int idx = item.get_global_id(0);
142+ if (idx < numel) {
143+ result[idx] = a[idx] * b[idx] + c;
144+ }
145+ }
146+
147+ class MulAddKernelFunctor {
148+ public:
149+ MulAddKernelFunctor(int _numel, const float* _a, const float* _b, float _c, float* _result)
150+ : numel(_numel), a(_a), b(_b), c(_c), result(_result) {}
151+ void operator()(const sycl::nd_item<1>& item) const {
152+ muladd_kernel(numel, a, b, c, result, item);
153+ }
154+
155+ private:
156+ int numel;
157+ const float* a;
158+ const float* b;
159+ float c;
160+ float* result;
161+ };
162+
163+ // ==========================================================
164+ // 2. Wrapper
165+ // ==========================================================
166+ at::Tensor mymuladd_xpu(const at::Tensor& a, const at::Tensor& b, double c) {
167+ TORCH_CHECK(a.sizes() == b.sizes(), "a and b must have the same shape");
168+ TORCH_CHECK(a.dtype() == at::kFloat, "a must be a float tensor");
169+ TORCH_CHECK(b.dtype() == at::kFloat, "b must be a float tensor");
170+ TORCH_CHECK(a.device().is_xpu(), "a must be an XPU tensor");
171+ TORCH_CHECK(b.device().is_xpu(), "b must be an XPU tensor");
172+
173+ at::Tensor a_contig = a.contiguous();
174+ at::Tensor b_contig = b.contiguous();
175+ at::Tensor result = at::empty_like(a_contig);
176+
177+ const float* a_ptr = a_contig.data_ptr<float>();
178+ const float* b_ptr = b_contig.data_ptr<float>();
179+ float* res_ptr = result.data_ptr<float>();
180+ int numel = a_contig.numel();
181+
182+ sycl::queue& queue = c10::xpu::getCurrentXPUStream().queue();
183+ constexpr int threads = 256;
184+ int blocks = (numel + threads - 1) / threads;
185+
186+ queue.submit([&](sycl::handler& cgh) {
187+ cgh.parallel_for<MulAddKernelFunctor>(
188+ sycl::nd_range<1>(blocks * threads, threads),
189+ MulAddKernelFunctor(numel, a_ptr, b_ptr, static_cast<float>(c), res_ptr)
190+ );
191+ });
192+
193+ return result;
194+ }
195+
196+ // ==========================================================
197+ // 3. Registration
198+ // ==========================================================
199+ TORCH_LIBRARY(sycl_extension, m) {
168200 m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor");
169- }
170-
171- // ==================================================
172- // Register SYCL Implementations to Torch Library
173- // ==================================================
174- TORCH_LIBRARY_IMPL(sycl_extension, XPU, m) {
175- m.impl("mymuladd", &mymuladd_xpu);
176- }
177-
178- } // namespace sycl_extension
179-
201+ }
202+
203+ TORCH_LIBRARY_IMPL(sycl_extension, XPU, m) {
204+ m.impl("mymuladd", &mymuladd_xpu);
205+ }
206+
207+ } // namespace sycl_extension
208+
209+ // ==========================================================
210+ // 4. Windows Linker
211+ // ==========================================================
212+ extern "C" {
213+ #ifdef _WIN32
214+ __declspec(dllexport)
215+ #endif
216+ PyObject* PyInit__C(void) {
217+ static struct PyModuleDef moduledef = {
218+ PyModuleDef_HEAD_INIT,
219+ "_C",
220+ "XPU Extension Shim",
221+ -1,
222+ NULL
223+ };
224+ return PyModule_Create(&moduledef);
225+ }
226+ }
180227
181228
182229 Create a Python Interface
@@ -201,26 +248,39 @@ Create ``sycl_extension/__init__.py`` file to make the package importable:
201248
202249.. code-block :: python
203250
204- import ctypes
205- from pathlib import Path
251+ import ctypes
252+ import platform
253+ from pathlib import Path
206254
207- import torch
255+ import torch
256+
257+ current_dir = Path(__file__ ).parent.parent
258+ build_dir = current_dir / " build"
259+
260+ if platform.system() == ' Windows' :
261+ file_pattern = " **/*.pyd"
262+ else :
263+ file_pattern = " **/*.so"
264+
265+ lib_files = list (build_dir.glob(file_pattern))
266+
267+ if not lib_files:
268+ current_package_dir = Path(__file__ ).parent
269+ lib_files = list (current_package_dir.glob(file_pattern))
208270
209- current_dir = Path(__file__ ).parent.parent
210- build_dir = current_dir / " build"
211- so_files = list (build_dir.glob(" **/*.so" ))
271+ assert len (lib_files) > 0 , f " Could not find any { file_pattern} file in { build_dir} or { current_dir} "
272+ lib_file = lib_files[0 ]
212273
213- assert len (so_files) == 1 , f " Expected one _C*.so file, found { len (so_files)} "
214274
215- with torch._ops.dl_open_guard():
216- loaded_lib = ctypes.CDLL(so_files[ 0 ] )
275+ with torch._ops.dl_open_guard():
276+ loaded_lib = ctypes.CDLL(str (lib_file) )
217277
218- from . import ops
278+ from . import ops
219279
220- __all__ = [
221- " loaded_lib" ,
222- " ops" ,
223- ]
280+ __all__ = [
281+ " loaded_lib" ,
282+ " ops" ,
283+ ]
224284
225285 Testing SYCL extension operator
226286-------------------------------
0 commit comments