Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit e83df50

Browse files
committed
IPEX fix missing dtype from randn
1 parent cc6101e commit e83df50

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

modules/intel/ipex/hijacks.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -322,9 +322,9 @@ def torch_randn(*args, device=None, dtype=None, **kwargs):
322322
if dtype is bytes:
323323
dtype = None
324324
if check_cuda(device):
325-
return original_torch_randn(*args, device=return_xpu(device), **kwargs)
325+
return original_torch_randn(*args, device=return_xpu(device), dtype=dtype, **kwargs)
326326
else:
327-
return original_torch_randn(*args, device=device, **kwargs)
327+
return original_torch_randn(*args, device=device, dtype=dtype, **kwargs)
328328

329329
original_torch_ones = torch.ones
330330
@wraps(torch.ones)

0 commit comments

Comments
 (0)