Commit 2842f5e

cleanup
1 parent cc51f79 commit 2842f5e

File tree

1 file changed: +24 -8 lines changed


modules/intel/ipex/hijacks.py

Lines changed: 24 additions & 8 deletions
@@ -1,5 +1,3 @@
-from typing import Optional
-
 import os
 from functools import wraps
 from contextlib import nullcontext
@@ -18,28 +16,27 @@

 # pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return

-class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods
-    def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument
-        if isinstance(device_ids, list) and len(device_ids) > 1:
-            errors.log.error("IPEX backend doesn't support DataParallel on multiple XPU devices")
-        return module.to(devices.device)

 def return_null_context(*args, **kwargs): # pylint: disable=unused-argument
     return nullcontext()

+
 @property
 def is_cuda(self):
     return self.device.type == "xpu" or self.device.type == "cuda"

+
 def check_device_type(device, device_type: str) -> bool:
     if device is None or type(device) not in {str, int, torch.device}:
         return False
     else:
         return bool(torch.device(device).type == device_type)

+
 def check_cuda(device) -> bool:
     return bool(isinstance(device, int) or check_device_type(device, "cuda"))

+
 def return_xpu(device): # keep the device instance type, aka return string if the input is string
     return devices.device if device is None else f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device(f"xpu:{device.index}" if device.index is not None else "xpu") if isinstance(device, torch.device) else "xpu"

@@ -107,13 +104,15 @@ def functional_pad(input, pad, mode='constant', value=None):
     else:
         return original_functional_pad(input, pad, mode=mode, value=value)

+
 # Diffusers FreeU
 original_fft_fftn = torch.fft.fftn
 @wraps(torch.fft.fftn)
 def fft_fftn(input, s=None, dim=None, norm=None, *, out=None):
     return_dtype = input.dtype
     return original_fft_fftn(input.to(dtype=torch.float32), s=s, dim=dim, norm=norm, out=out).to(dtype=return_dtype)

+
 # Diffusers FreeU
 original_fft_ifftn = torch.fft.ifftn
 @wraps(torch.fft.ifftn)
@@ -131,6 +130,7 @@ def from_numpy(ndarray):
     else:
         return original_from_numpy(ndarray)

+
 original_as_tensor = torch.as_tensor
 @wraps(torch.as_tensor)
 def as_tensor(data, dtype=None, device=None):
@@ -156,6 +156,7 @@ def torch_tensor(data, *args, dtype=None, device=None, **kwargs):
         dtype = torch.float32
     return original_torch_tensor(data, *args, dtype=dtype, device=device, **kwargs)

+
 torch.Tensor.original_Tensor_to = torch.Tensor.to
 @wraps(torch.Tensor.to)
 def Tensor_to(self, device=None, *args, **kwargs):
@@ -169,6 +170,7 @@ def Tensor_to(self, device=None, *args, **kwargs):
         device = torch.float32
     return self.original_Tensor_to(device, *args, **kwargs)

+
 original_Tensor_cuda = torch.Tensor.cuda
 @wraps(torch.Tensor.cuda)
 def Tensor_cuda(self, device=None, *args, **kwargs):
@@ -177,6 +179,7 @@ def Tensor_cuda(self, device=None, *args, **kwargs):
     else:
         return original_Tensor_cuda(self, device, *args, **kwargs)

+
 original_Tensor_pin_memory = torch.Tensor.pin_memory
 @wraps(torch.Tensor.pin_memory)
 def Tensor_pin_memory(self, device=None, *args, **kwargs):
@@ -185,6 +188,7 @@ def Tensor_pin_memory(self, device=None, *args, **kwargs):
     else:
         return original_Tensor_pin_memory(self, device, *args, **kwargs)

+
 original_UntypedStorage_init = torch.UntypedStorage.__init__
 @wraps(torch.UntypedStorage.__init__)
 def UntypedStorage_init(*args, device=None, **kwargs):
@@ -193,6 +197,7 @@ def UntypedStorage_init(*args, device=None, **kwargs):
     else:
         return original_UntypedStorage_init(*args, device=device, **kwargs)

+
 if torch_version[0] > 2 or (torch_version[0] == 2 and torch_version[1] >= 4):
     original_UntypedStorage_to = torch.UntypedStorage.to
     @wraps(torch.UntypedStorage.to)
@@ -210,6 +215,7 @@ def UntypedStorage_cuda(self, device=None, non_blocking=False, **kwargs):
     else:
         return original_UntypedStorage_cuda(self, device=device, non_blocking=non_blocking, **kwargs)

+
 original_torch_empty = torch.empty
 @wraps(torch.empty)
 def torch_empty(*args, device=None, **kwargs):
@@ -218,6 +224,7 @@ def torch_empty(*args, device=None, **kwargs):
     else:
         return original_torch_empty(*args, device=device, **kwargs)

+
 original_torch_randn = torch.randn
 @wraps(torch.randn)
 def torch_randn(*args, device=None, dtype=None, **kwargs):
@@ -228,6 +235,7 @@ def torch_randn(*args, device=None, dtype=None, **kwargs):
     else:
         return original_torch_randn(*args, device=device, dtype=dtype, **kwargs)

+
 original_torch_ones = torch.ones
 @wraps(torch.ones)
 def torch_ones(*args, device=None, **kwargs):
@@ -236,6 +244,7 @@ def torch_ones(*args, device=None, **kwargs):
     else:
         return original_torch_ones(*args, device=device, **kwargs)

+
 original_torch_zeros = torch.zeros
 @wraps(torch.zeros)
 def torch_zeros(*args, device=None, **kwargs):
@@ -244,6 +253,7 @@ def torch_zeros(*args, device=None, **kwargs):
     else:
         return original_torch_zeros(*args, device=device, **kwargs)

+
 original_torch_full = torch.full
 @wraps(torch.full)
 def torch_full(*args, device=None, **kwargs):
@@ -252,6 +262,7 @@ def torch_full(*args, device=None, **kwargs):
     else:
         return original_torch_full(*args, device=device, **kwargs)

+
 original_torch_linspace = torch.linspace
 @wraps(torch.linspace)
 def torch_linspace(*args, device=None, **kwargs):
@@ -260,6 +271,7 @@ def torch_linspace(*args, device=None, **kwargs):
     else:
         return original_torch_linspace(*args, device=device, **kwargs)

+
 original_torch_eye = torch.eye
 @wraps(torch.eye)
 def torch_eye(*args, device=None, **kwargs):
@@ -268,6 +280,7 @@ def torch_eye(*args, device=None, **kwargs):
     else:
         return original_torch_eye(*args, device=device, **kwargs)

+
 original_torch_load = torch.load
 @wraps(torch.load)
 def torch_load(f, map_location=None, *args, **kwargs):
@@ -276,27 +289,31 @@ def torch_load(f, map_location=None, *args, **kwargs):
     else:
         return original_torch_load(f, *args, map_location=map_location, **kwargs)

+
 @wraps(torch.cuda.synchronize)
 def torch_cuda_synchronize(device=None):
     if check_cuda(device):
         return torch.xpu.synchronize(return_xpu(device))
     else:
         return torch.xpu.synchronize(device)

+
 @wraps(torch.cuda.device)
 def torch_cuda_device(device):
     if check_cuda(device):
         return torch.xpu.device(return_xpu(device))
     else:
         return torch.xpu.device(device)

+
 @wraps(torch.cuda.set_device)
 def torch_cuda_set_device(device):
     if check_cuda(device):
         torch.xpu.set_device(return_xpu(device))
     else:
         torch.xpu.set_device(device)

+
 # torch.Generator has to be a class for isinstance checks
 original_torch_Generator = torch.Generator
 class torch_Generator(original_torch_Generator):
@@ -335,7 +352,6 @@ def ipex_hijacks():
     torch._C.Generator = torch_Generator

     torch.backends.cuda.sdp_kernel = return_null_context
-    torch.nn.DataParallel = DummyDataParallel
     torch.UntypedStorage.is_cuda = is_cuda
     torch.amp.autocast_mode.autocast.__init__ = autocast_init

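For orientation: nearly every wrapper touched in this diff funnels through check_cuda() and return_xpu(), which translate a CUDA device specifier into its XPU equivalent while preserving the argument's type (string in, string out; torch.device in, torch.device out). Below is a minimal standalone sketch of that translation, with a hard-coded DEFAULT_DEVICE standing in for the module-internal devices.device (that stand-in is an assumption of the sketch, not part of the diff):

# Standalone sketch of the device-redirect helpers shown in the diff above.
# DEFAULT_DEVICE stands in for devices.device, which is internal to the module.
import torch

DEFAULT_DEVICE = torch.device("xpu")

def check_device_type(device, device_type: str) -> bool:
    if device is None or type(device) not in {str, int, torch.device}:
        return False
    return bool(torch.device(device).type == device_type)

def check_cuda(device) -> bool:
    return bool(isinstance(device, int) or check_device_type(device, "cuda"))

def return_xpu(device):
    # Keep the device instance type: str in -> str out, torch.device in -> torch.device out.
    if device is None:
        return DEFAULT_DEVICE
    if isinstance(device, str) and ":" in device:
        return f"xpu:{device.split(':')[-1]}"
    if isinstance(device, int):
        return f"xpu:{device}"
    if isinstance(device, torch.device):
        return torch.device(f"xpu:{device.index}" if device.index is not None else "xpu")
    return "xpu"

print(return_xpu("cuda:1"))              # -> "xpu:1"
print(return_xpu(0))                     # -> "xpu:0"
print(return_xpu(torch.device("cuda")))  # -> device(type='xpu')
print(check_cuda("cpu"))                 # -> False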
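The factory-function hijacks (torch_empty, torch_randn, torch_zeros, and so on) all follow the same shape: keep a reference to the original, wrap it with functools.wraps, redirect the device argument when it points at CUDA, and assign the wrapper back onto torch inside ipex_hijacks(). A rough, portable illustration of that pattern follows; it redirects to "cpu" instead of "xpu" only so it runs on machines without an XPU, and it inlines a simplified device check, so it is not the module's code:

import torch
from functools import wraps

original_torch_zeros = torch.zeros

@wraps(torch.zeros)
def torch_zeros(*args, device=None, **kwargs):
    # The real hijack calls check_cuda(device) and return_xpu(device);
    # this sketch inlines a simpler check and redirects to CPU for portability.
    if device is not None and torch.device(device).type == "cuda":
        device = "cpu"
    return original_torch_zeros(*args, device=device, **kwargs)

torch.zeros = torch_zeros              # what ipex_hijacks() does for each wrapper
x = torch.zeros(2, 2, device="cuda")   # transparently lands on the redirect target
print(x.device)                        # -> cpu in this sketch
torch.zeros = original_torch_zeros     # restore the original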