Commit 60e776a

ZhaoqiongZ, svekars, and AlannaBurke authored
add windows support tutorial for sycl extension (#3699)
update the tutorial with windows support

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @ZailiWang @leslie-fang-intel @Xia-Weiwen @sekahler2 @CaoE @zhuhaozhe @Valentine233

---------

Co-authored-by: Svetlana Karslioglu <[email protected]>
Co-authored-by: Alanna Burke <[email protected]>
1 parent f165429 commit 60e776a

1 file changed

Lines changed: 191 additions & 131 deletions

File tree

advanced_source/cpp_custom_ops_sycl.rst

@@ -13,13 +13,14 @@ Custom SYCL Operators
     .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
        :class-card: card-prerequisites
 
-       * PyTorch 2.8 or later
+       * PyTorch 2.8 or later for Linux
+       * PyTorch 2.10 or later for Windows
        * Basic understanding of SYCL programming
 
 .. note::
 
    ``SYCL`` serves as the backend programming language for Intel GPUs (device label ``xpu``). For configuration details, see:
-   `Getting Started on Intel GPUs <https://docs.pytorch.org/docs/main/notes/get_start_xpu.html>`_. The Intel Compiler, which comes bundled with Intel Deep Learning Essentials, handles ``SYCL`` compilation. Ensure you install and activate the compiler environment prior to executing the code examples in this tutorial.
+   `Getting Started on Intel GPUs <https://docs.pytorch.org/docs/main/notes/get_start_xpu.html>`_. The Intel Compiler, which comes bundled with `Intel Deep Learning Essentials <https://www.intel.com/content/www/us/en/developer/articles/tool/pytorch-prerequisites-for-intel-gpus.html>`_, handles ``SYCL`` compilation. Ensure you install and activate the compiler environment prior to executing the code examples in this tutorial.
 
 PyTorch offers a large library of operators that work on Tensors (e.g. torch.add, torch.sum, etc).
 However, you may wish to bring a new custom operator to PyTorch. This tutorial demonstrates the
@@ -40,52 +41,71 @@ Follow the structure to create a custom SYCL operator:
 Setting up the Build System
 ---------------------------
 
-If you need to compile **SYCL** code (for example, ``.sycl`` files), use `torch.utils.cpp_extension.SyclExtension <https://docs.pytorch.org/docs/stable/cpp_extension.html#torch.utils.cpp_extension.SyclExtension>`_.
+If you need to compile **SYCL** code (noting that the extension should be ``.sycl``), use `torch.utils.cpp_extension.SyclExtension <https://docs.pytorch.org/docs/stable/cpp_extension.html#torch.utils.cpp_extension.SyclExtension>`_.
 The setup process is very similar to C++/CUDA, except the compilation arguments need to be adjusted for SYCL.
 
 Using ``sycl_extension`` is as straightforward as writing the following ``setup.py``:
 
 .. code-block:: python
 
-    import os
-    import torch
-    import glob
-    from setuptools import find_packages, setup
-    from torch.utils.cpp_extension import SyclExtension, BuildExtension
-
-    library_name = "sycl_extension"
-    py_limited_api = True
-    extra_compile_args = {
-        "cxx": ["-O3",
-                "-fdiagnostics-color=always",
-                "-DPy_LIMITED_API=0x03090000"],
-        "sycl": ["-O3" ]
-    }
-
-    assert(torch.xpu.is_available()), "XPU is not available, please check your environment"
-    # Source files collection
-    this_dir = os.path.dirname(os.path.curdir)
-    extensions_dir = os.path.join(this_dir, library_name)
-    sources = list(glob.glob(os.path.join(extensions_dir, "*.sycl")))
-    # Construct extension
-    ext_modules = [
-        SyclExtension(
-            f"{library_name}._C",
-            sources,
-            extra_compile_args=extra_compile_args,
-            py_limited_api=py_limited_api,
-        )
-    ]
-    setup(
-        name=library_name,
-        packages=find_packages(),
-        ext_modules=ext_modules,
-        install_requires=["torch"],
-        description="Simple Example of PyTorch Sycl extensions",
-        cmdclass={"build_ext": BuildExtension},
-        options={"bdist_wheel": {"py_limited_api": "cp39"}} if py_limited_api else {},
-    )
+    import os
+    import torch
+    import glob
+    import platform
+    from setuptools import find_packages, setup
+    from torch.utils.cpp_extension import SyclExtension, BuildExtension
+
+    library_name = "sycl_extension"
+    py_limited_api = True
+
+    IS_WINDOWS = (platform.system() == 'Windows')
+
+    if IS_WINDOWS:
+        cxx_args = [
+            "/O2",
+            "/std:c++17",
+            "/DPy_LIMITED_API=0x03090000",
+        ]
+        sycl_args = ["/O2", "/std:c++17"]
+    else:
+        cxx_args = [
+            "-O3",
+            "-fdiagnostics-color=always",
+            "-DPy_LIMITED_API=0x03090000"
+        ]
+        sycl_args = ["-O3"]
+
+    extra_compile_args = {
+        "cxx": cxx_args,
+        "sycl": sycl_args
+    }
+
+    assert(torch.xpu.is_available()), "XPU is not available, please check your environment"
+
+    # Source files collection
+    this_dir = os.path.dirname(os.path.curdir)
+    extensions_dir = os.path.join(this_dir, library_name)
+    sources = list(glob.glob(os.path.join(extensions_dir, "*.sycl")))
+
+    # Construct extension
+    ext_modules = [
+        SyclExtension(
+            f"{library_name}._C",
+            sources,
+            extra_compile_args=extra_compile_args,
+            py_limited_api=py_limited_api,
+        )
+    ]
+
+    setup(
+        name=library_name,
+        packages=find_packages(),
+        ext_modules=ext_modules,
+        install_requires=["torch"],
+        description="Simple Example of PyTorch Sycl extensions",
+        cmdclass={"build_ext": BuildExtension},
+        options={"bdist_wheel": {"py_limited_api": "cp39"}} if py_limited_api else {},
+    )
 
 Defining the custom op and adding backend implementations
 ---------------------------------------------------------
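
[Editor's note] The new ``setup.py`` above compiles only files matching ``*.sycl`` inside the package directory; everything else is ignored by the glob. A minimal, torch-free sketch of that source-collection step (directory layout and file names here are hypothetical, for illustration only):

```python
import glob
import os
import tempfile

# Throwaway layout mimicking the tutorial's package: a "sycl_extension"
# directory containing one .sycl source plus an unrelated file.
with tempfile.TemporaryDirectory() as this_dir:
    extensions_dir = os.path.join(this_dir, "sycl_extension")
    os.makedirs(extensions_dir)
    for name in ("muladd.sycl", "README.md"):
        with open(os.path.join(extensions_dir, name), "w") as f:
            f.write("")

    # Same pattern the diff's setup.py uses: only *.sycl files are compiled.
    sources = list(glob.glob(os.path.join(extensions_dir, "*.sycl")))
    assert [os.path.basename(s) for s in sources] == ["muladd.sycl"]
```
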
@@ -101,82 +121,109 @@ in a separate ``TORCH_LIBRARY_IMPL`` block:
 
 .. code-block:: cpp
 
-    #include <c10/xpu/XPUStream.h>
-    #include <sycl/sycl.hpp>
-    #include <ATen/Operators.h>
-    #include <torch/all.h>
-    #include <torch/library.h>
-
-    namespace sycl_extension {
-    // MulAdd Kernel: result = a * b + c
-    static void muladd_kernel(
-        int numel, const float* a, const float* b, float c, float* result,
-        const sycl::nd_item<1>& item) {
-        int idx = item.get_global_id(0);
-        if (idx < numel) {
-            result[idx] = a[idx] * b[idx] + c;
-        }
-    }
-
-    class MulAddKernelFunctor {
-    public:
-        MulAddKernelFunctor(int _numel, const float* _a, const float* _b, float _c, float* _result)
-            : numel(_numel), a(_a), b(_b), c(_c), result(_result) {}
-        void operator()(const sycl::nd_item<1>& item) const {
-            muladd_kernel(numel, a, b, c, result, item);
-        }
-
-    private:
-        int numel;
-        const float* a;
-        const float* b;
-        float c;
-        float* result;
-    };
-
-    at::Tensor mymuladd_xpu(const at::Tensor& a, const at::Tensor& b, double c) {
-        TORCH_CHECK(a.sizes() == b.sizes(), "a and b must have the same shape");
-        TORCH_CHECK(a.dtype() == at::kFloat, "a must be a float tensor");
-        TORCH_CHECK(b.dtype() == at::kFloat, "b must be a float tensor");
-        TORCH_CHECK(a.device().is_xpu(), "a must be an XPU tensor");
-        TORCH_CHECK(b.device().is_xpu(), "b must be an XPU tensor");
-
-        at::Tensor a_contig = a.contiguous();
-        at::Tensor b_contig = b.contiguous();
-        at::Tensor result = at::empty_like(a_contig);
-
-        const float* a_ptr = a_contig.data_ptr<float>();
-        const float* b_ptr = b_contig.data_ptr<float>();
-        float* res_ptr = result.data_ptr<float>();
-        int numel = a_contig.numel();
-
-        sycl::queue& queue = c10::xpu::getCurrentXPUStream().queue();
-        constexpr int threads = 256;
-        int blocks = (numel + threads - 1) / threads;
-
-        queue.submit([&](sycl::handler& cgh) {
-            cgh.parallel_for<MulAddKernelFunctor>(
-                sycl::nd_range<1>(blocks * threads, threads),
-                MulAddKernelFunctor(numel, a_ptr, b_ptr, static_cast<float>(c), res_ptr)
-            );
-        });
-
-        return result;
-    }
-    // Defines the operators
-    TORCH_LIBRARY(sycl_extension, m) {
+    #include <c10/xpu/XPUStream.h>
+    #include <sycl/sycl.hpp>
+    #include <ATen/Operators.h>
+    #include <torch/all.h>
+    #include <torch/library.h>
+
+
+    #include <Python.h>
+
+    namespace sycl_extension {
+
+    // ==========================================================
+    // 1. Kernel
+    // ==========================================================
+    static void muladd_kernel(
+        int numel, const float* a, const float* b, float c, float* result,
+        const sycl::nd_item<1>& item) {
+        int idx = item.get_global_id(0);
+        if (idx < numel) {
+            result[idx] = a[idx] * b[idx] + c;
+        }
+    }
+
+    class MulAddKernelFunctor {
+    public:
+        MulAddKernelFunctor(int _numel, const float* _a, const float* _b, float _c, float* _result)
+            : numel(_numel), a(_a), b(_b), c(_c), result(_result) {}
+        void operator()(const sycl::nd_item<1>& item) const {
+            muladd_kernel(numel, a, b, c, result, item);
+        }
+
+    private:
+        int numel;
+        const float* a;
+        const float* b;
+        float c;
+        float* result;
+    };
+
+    // ==========================================================
+    // 2. Wrapper
+    // ==========================================================
+    at::Tensor mymuladd_xpu(const at::Tensor& a, const at::Tensor& b, double c) {
+        TORCH_CHECK(a.sizes() == b.sizes(), "a and b must have the same shape");
+        TORCH_CHECK(a.dtype() == at::kFloat, "a must be a float tensor");
+        TORCH_CHECK(b.dtype() == at::kFloat, "b must be a float tensor");
+        TORCH_CHECK(a.device().is_xpu(), "a must be an XPU tensor");
+        TORCH_CHECK(b.device().is_xpu(), "b must be an XPU tensor");
+
+        at::Tensor a_contig = a.contiguous();
+        at::Tensor b_contig = b.contiguous();
+        at::Tensor result = at::empty_like(a_contig);
+
+        const float* a_ptr = a_contig.data_ptr<float>();
+        const float* b_ptr = b_contig.data_ptr<float>();
+        float* res_ptr = result.data_ptr<float>();
+        int numel = a_contig.numel();
+
+        sycl::queue& queue = c10::xpu::getCurrentXPUStream().queue();
+        constexpr int threads = 256;
+        int blocks = (numel + threads - 1) / threads;
+
+        queue.submit([&](sycl::handler& cgh) {
+            cgh.parallel_for<MulAddKernelFunctor>(
+                sycl::nd_range<1>(blocks * threads, threads),
+                MulAddKernelFunctor(numel, a_ptr, b_ptr, static_cast<float>(c), res_ptr)
+            );
+        });
+
+        return result;
+    }
+
+    // ==========================================================
+    // 3. Registration
+    // ==========================================================
+    TORCH_LIBRARY(sycl_extension, m) {
         m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor");
-    }
-
-    // ==================================================
-    // Register SYCL Implementations to Torch Library
-    // ==================================================
-    TORCH_LIBRARY_IMPL(sycl_extension, XPU, m) {
-        m.impl("mymuladd", &mymuladd_xpu);
-    }
-
-    } // namespace sycl_extension
-
+    }
+
+    TORCH_LIBRARY_IMPL(sycl_extension, XPU, m) {
+        m.impl("mymuladd", &mymuladd_xpu);
+    }
+
+    } // namespace sycl_extension
+
+    // ==========================================================
+    // 4. Windows Linker
+    // ==========================================================
+    extern "C" {
+    #ifdef _WIN32
+    __declspec(dllexport)
+    #endif
+    PyObject* PyInit__C(void) {
+        static struct PyModuleDef moduledef = {
+            PyModuleDef_HEAD_INIT,
+            "_C",
+            "XPU Extension Shim",
+            -1,
+            NULL
+        };
+        return PyModule_Create(&moduledef);
+    }
+    }
 
 
 Create a Python Interface
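
[Editor's note] The wrapper in the diff sizes the launch with ``blocks = (numel + threads - 1) / threads``, the usual ceiling division, so the global range covers every element; the ``if (idx < numel)`` guard in the kernel then masks the overhang work-items. A quick pure-Python check of that sizing logic (``grid_size`` is a hypothetical helper name, not part of the tutorial):

```python
# Ceiling-division grid sizing, mirroring the wrapper in the diff above:
# enough work-groups of `threads` items each to cover `numel` elements.
def grid_size(numel, threads=256):
    blocks = (numel + threads - 1) // threads
    return blocks * threads  # global range handed to sycl::nd_range

for numel in (1, 255, 256, 257, 1000):
    g = grid_size(numel)
    assert g >= numel          # every element gets a work-item
    assert g - numel < 256     # overhang is less than one work-group
    assert g % 256 == 0        # global range divisible by the local range
```
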
@@ -201,26 +248,39 @@ Create ``sycl_extension/__init__.py`` file to make the package importable:
 
 .. code-block:: python
 
-    import ctypes
-    from pathlib import Path
+    import ctypes
+    import platform
+    from pathlib import Path
 
-    import torch
+    import torch
+
+    current_dir = Path(__file__).parent.parent
+    build_dir = current_dir / "build"
+
+    if platform.system() == 'Windows':
+        file_pattern = "**/*.pyd"
+    else:
+        file_pattern = "**/*.so"
+
+    lib_files = list(build_dir.glob(file_pattern))
+
+    if not lib_files:
+        current_package_dir = Path(__file__).parent
+        lib_files = list(current_package_dir.glob(file_pattern))
 
-    current_dir = Path(__file__).parent.parent
-    build_dir = current_dir / "build"
-    so_files = list(build_dir.glob("**/*.so"))
+    assert len(lib_files) > 0, f"Could not find any {file_pattern} file in {build_dir} or {current_dir}"
+    lib_file = lib_files[0]
 
-    assert len(so_files) == 1, f"Expected one _C*.so file, found {len(so_files)}"
 
-    with torch._ops.dl_open_guard():
-        loaded_lib = ctypes.CDLL(so_files[0])
+    with torch._ops.dl_open_guard():
+        loaded_lib = ctypes.CDLL(str(lib_file))
 
-    from . import ops
+    from . import ops
 
-    __all__ = [
-        "loaded_lib",
-        "ops",
-    ]
+    __all__ = [
+        "loaded_lib",
+        "ops",
+    ]
 
 Testing SYCL extension operator
 -------------------
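
[Editor's note] The section this hunk introduces tests the operator. One common way such a test establishes expected values without XPU hardware is a pure-Python reference of the op's semantics (elementwise ``result = a * b + c``). A minimal sketch (``mymuladd_ref`` is a hypothetical name, not the tutorial's test code):

```python
# Hypothetical CPU reference for mymuladd's semantics (a * b + c,
# elementwise), against which the XPU kernel's output could be compared.
def mymuladd_ref(a, b, c):
    assert len(a) == len(b), "a and b must have the same shape"
    return [x * y + c for x, y in zip(a, b)]

out = mymuladd_ref([1.0, 2.0, 3.0], [4.0, 5.0, 6.0], 0.5)
assert out == [4.5, 10.5, 18.5]
```
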
