@@ -1,5 +1,3 @@
-from typing import Optional
-
import os
from functools import wraps
from contextlib import nullcontext
@@ -18,28 +16,27 @@

# pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return

-class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods
-    def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument
-        if isinstance(device_ids, list) and len(device_ids) > 1:
-            errors.log.error("IPEX backend doesn't support DataParallel on multiple XPU devices")
-        return module.to(devices.device)

def return_null_context(*args, **kwargs): # pylint: disable=unused-argument
    return nullcontext()

+
@property
def is_cuda(self):
    return self.device.type == "xpu" or self.device.type == "cuda"

+
def check_device_type(device, device_type: str) -> bool:
    if device is None or type(device) not in {str, int, torch.device}:
        return False
    else:
        return bool(torch.device(device).type == device_type)

+
def check_cuda(device) -> bool:
    return bool(isinstance(device, int) or check_device_type(device, "cuda"))

+
def return_xpu(device): # keep the device instance type, aka return string if the input is string
    return devices.device if device is None else f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device(f"xpu:{device.index}" if device.index is not None else "xpu") if isinstance(device, torch.device) else "xpu"

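Note: return_xpu compresses several cases into a single conditional expression. A rough, unrolled sketch of the same mapping (illustrative only; `devices.device` is the module-level default device the code above already relies on):

def return_xpu_unrolled(device):  # illustrative equivalent, not part of this commit
    if device is None:
        return devices.device                          # fall back to the default XPU device
    if isinstance(device, str) and ":" in device:
        return f"xpu:{device.split(':')[-1]}"          # "cuda:1" -> "xpu:1" (stays a string)
    if isinstance(device, int):
        return f"xpu:{device}"                         # 1 -> "xpu:1"
    if isinstance(device, torch.device):
        return torch.device(f"xpu:{device.index}" if device.index is not None else "xpu")
    return "xpu"                                       # e.g. plain "cuda" -> "xpu"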
@@ -107,13 +104,15 @@ def functional_pad(input, pad, mode='constant', value=None):
    else:
        return original_functional_pad(input, pad, mode=mode, value=value)

+
# Diffusers FreeU
original_fft_fftn = torch.fft.fftn
@wraps(torch.fft.fftn)
def fft_fftn(input, s=None, dim=None, norm=None, *, out=None):
    return_dtype = input.dtype
    return original_fft_fftn(input.to(dtype=torch.float32), s=s, dim=dim, norm=norm, out=out).to(dtype=return_dtype)

+
# Diffusers FreeU
original_fft_ifftn = torch.fft.ifftn
@wraps(torch.fft.ifftn)
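The fft_fftn wrapper above (and fft_ifftn below) upcasts its input to float32 before calling the original op and casts the result back to the caller's dtype, presumably because the FFT kernels cannot run in the reduced precision FreeU feeds them. A minimal sketch of that upcast-and-restore pattern, under the same assumption:

def run_in_fp32(op, tensor, *args, **kwargs):  # illustrative helper, not part of this commit
    # compute in float32, then restore the caller's dtype on the result
    return op(tensor.to(dtype=torch.float32), *args, **kwargs).to(dtype=tensor.dtype)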
@@ -131,6 +130,7 @@ def from_numpy(ndarray):
    else:
        return original_from_numpy(ndarray)

+
original_as_tensor = torch.as_tensor
@wraps(torch.as_tensor)
def as_tensor(data, dtype=None, device=None):
@@ -156,6 +156,7 @@ def torch_tensor(data, *args, dtype=None, device=None, **kwargs):
        dtype = torch.float32
    return original_torch_tensor(data, *args, dtype=dtype, device=device, **kwargs)

+
torch.Tensor.original_Tensor_to = torch.Tensor.to
@wraps(torch.Tensor.to)
def Tensor_to(self, device=None, *args, **kwargs):
@@ -169,6 +170,7 @@ def Tensor_to(self, device=None, *args, **kwargs):
        device = torch.float32
    return self.original_Tensor_to(device, *args, **kwargs)

+
original_Tensor_cuda = torch.Tensor.cuda
@wraps(torch.Tensor.cuda)
def Tensor_cuda(self, device=None, *args, **kwargs):
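The `device = torch.float32` assignment visible in the Tensor_to hunk reflects that Tensor.to also accepts a dtype as its first positional argument, so the hijack's `device` parameter can actually hold a dtype; the governing condition is outside this hunk, but presumably it downgrades a dtype the hardware cannot handle (such as float64) to float32. An illustrative call path, with that assumption spelled out:

x = torch.ones(2)
y = x.to(torch.float64)  # once torch.Tensor.to is rebound to Tensor_to, this call enters the
                         # hijack with device=torch.float64; on hardware without fp64 support
                         # the hijack presumably substitutes torch.float32 before delegating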
@@ -177,6 +179,7 @@ def Tensor_cuda(self, device=None, *args, **kwargs):
    else:
        return original_Tensor_cuda(self, device, *args, **kwargs)

+
original_Tensor_pin_memory = torch.Tensor.pin_memory
@wraps(torch.Tensor.pin_memory)
def Tensor_pin_memory(self, device=None, *args, **kwargs):
@@ -185,6 +188,7 @@ def Tensor_pin_memory(self, device=None, *args, **kwargs):
    else:
        return original_Tensor_pin_memory(self, device, *args, **kwargs)

+
original_UntypedStorage_init = torch.UntypedStorage.__init__
@wraps(torch.UntypedStorage.__init__)
def UntypedStorage_init(*args, device=None, **kwargs):
@@ -193,6 +197,7 @@ def UntypedStorage_init(*args, device=None, **kwargs):
    else:
        return original_UntypedStorage_init(*args, device=device, **kwargs)

+
if torch_version[0] > 2 or (torch_version[0] == 2 and torch_version[1] >= 4):
    original_UntypedStorage_to = torch.UntypedStorage.to
    @wraps(torch.UntypedStorage.to)
@@ -210,6 +215,7 @@ def UntypedStorage_cuda(self, device=None, non_blocking=False, **kwargs):
    else:
        return original_UntypedStorage_cuda(self, device=device, non_blocking=non_blocking, **kwargs)

+
original_torch_empty = torch.empty
@wraps(torch.empty)
def torch_empty(*args, device=None, **kwargs):
@@ -218,6 +224,7 @@ def torch_empty(*args, device=None, **kwargs):
    else:
        return original_torch_empty(*args, device=device, **kwargs)

+
original_torch_randn = torch.randn
@wraps(torch.randn)
def torch_randn(*args, device=None, dtype=None, **kwargs):
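torch_empty and torch_randn above, together with torch_ones, torch_zeros, torch_full, torch_linspace, and torch_eye below, all share one shape: if the requested device looks like CUDA, rewrite it with return_xpu, otherwise pass it through untouched. A minimal sketch of that shared pattern (the `if` branch is inferred from the visible `else` branches and from the torch.cuda wrappers later in this diff; not part of the commit):

def make_device_hijack(original):  # illustrative factory
    @wraps(original)
    def hijacked(*args, device=None, **kwargs):
        if check_cuda(device):                                            # CUDA-style device requested
            return original(*args, device=return_xpu(device), **kwargs)   # redirect to XPU
        else:
            return original(*args, device=device, **kwargs)               # pass through unchanged
    return hijacked

torch_randn additionally threads a dtype argument through, which the sketch omits.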
@@ -228,6 +235,7 @@ def torch_randn(*args, device=None, dtype=None, **kwargs):
    else:
        return original_torch_randn(*args, device=device, dtype=dtype, **kwargs)

+
original_torch_ones = torch.ones
@wraps(torch.ones)
def torch_ones(*args, device=None, **kwargs):
@@ -236,6 +244,7 @@ def torch_ones(*args, device=None, **kwargs):
    else:
        return original_torch_ones(*args, device=device, **kwargs)

+
original_torch_zeros = torch.zeros
@wraps(torch.zeros)
def torch_zeros(*args, device=None, **kwargs):
@@ -244,6 +253,7 @@ def torch_zeros(*args, device=None, **kwargs):
    else:
        return original_torch_zeros(*args, device=device, **kwargs)

+
original_torch_full = torch.full
@wraps(torch.full)
def torch_full(*args, device=None, **kwargs):
@@ -252,6 +262,7 @@ def torch_full(*args, device=None, **kwargs):
    else:
        return original_torch_full(*args, device=device, **kwargs)

+
original_torch_linspace = torch.linspace
@wraps(torch.linspace)
def torch_linspace(*args, device=None, **kwargs):
@@ -260,6 +271,7 @@ def torch_linspace(*args, device=None, **kwargs):
    else:
        return original_torch_linspace(*args, device=device, **kwargs)

+
original_torch_eye = torch.eye
@wraps(torch.eye)
def torch_eye(*args, device=None, **kwargs):
@@ -268,6 +280,7 @@ def torch_eye(*args, device=None, **kwargs):
    else:
        return original_torch_eye(*args, device=device, **kwargs)

+
original_torch_load = torch.load
@wraps(torch.load)
def torch_load(f, map_location=None, *args, **kwargs):
@@ -276,27 +289,31 @@ def torch_load(f, map_location=None, *args, **kwargs):
    else:
        return original_torch_load(f, *args, map_location=map_location, **kwargs)

+
@wraps(torch.cuda.synchronize)
def torch_cuda_synchronize(device=None):
    if check_cuda(device):
        return torch.xpu.synchronize(return_xpu(device))
    else:
        return torch.xpu.synchronize(device)

+
@wraps(torch.cuda.device)
def torch_cuda_device(device):
    if check_cuda(device):
        return torch.xpu.device(return_xpu(device))
    else:
        return torch.xpu.device(device)

+
@wraps(torch.cuda.set_device)
def torch_cuda_set_device(device):
    if check_cuda(device):
        torch.xpu.set_device(return_xpu(device))
    else:
        torch.xpu.set_device(device)

+
# torch.Generator has to be a class for isinstance checks
original_torch_Generator = torch.Generator
class torch_Generator(original_torch_Generator):
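The comment above explains the design choice: code elsewhere performs isinstance checks against torch.Generator, so the hijack must be a class that subclasses the saved original rather than a factory function. A small illustration of why that matters (illustrative only, not part of this commit):

def accepts_generator(gen) -> bool:  # e.g. downstream code validating an RNG argument
    return isinstance(gen, torch.Generator)

# A torch_Generator instance passes this check both before and after torch.Generator is
# rebound, because the subclass relationship to original_torch_Generator is preserved;
# if torch.Generator were replaced by a plain function, isinstance() would raise TypeError.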
@@ -335,7 +352,6 @@ def ipex_hijacks():
    torch._C.Generator = torch_Generator

    torch.backends.cuda.sdp_kernel = return_null_context
-    torch.nn.DataParallel = DummyDataParallel
    torch.UntypedStorage.is_cuda = is_cuda
    torch.amp.autocast_mode.autocast.__init__ = autocast_init

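Per the hunk header, these assignments live inside ipex_hijacks(), which installs the wrappers defined above over the CUDA-facing torch entry points. A hedged sketch of how such a module is typically wired in at startup (the import path and the surrounding lines are assumptions, not taken from this commit):

import torch
from modules.intel.ipex import ipex_hijacks  # hypothetical import path

ipex_hijacks()                         # install the wrappers before any model code touches torch.cuda
x = torch.randn(4, device="cuda")      # with the hijacks installed, torch_randn redirects this to "xpu"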