-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Expand file tree
/
Copy pathmulti_tensor_apply.py
More file actions
27 lines (22 loc) · 938 Bytes
/
multi_tensor_apply.py
File metadata and controls
27 lines (22 loc) · 938 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
class MultiTensorApply(object):
    """Thin wrapper that forwards a fixed chunk size to a multi-tensor ``op``.

    Constructing an instance probes whether the ``amp_C`` compiled extension
    can be imported and records the outcome on the *class*, so the flag is
    shared by every instance and reflects the most recent probe. Calling an
    instance raises ``RuntimeError`` when that probe failed; otherwise it
    forwards ``(chunk_size, noop_flag_buffer, tensor_lists, *args)`` to ``op``
    and returns its result.
    """

    # Result of the most recent ``amp_C`` import probe (class-wide).
    available = False
    # Not read or written anywhere in this class; kept so that external code
    # referencing ``MultiTensorApply.warned`` keeps working.
    warned = False
    # ImportError captured by the most recent failed probe, or None if no
    # probe has failed yet.
    import_err = None

    def __init__(self, chunk_size):
        """Store ``chunk_size`` and probe whether ``amp_C`` can be imported.

        Args:
            chunk_size: value passed as the first positional argument to every
                ``op`` invoked through ``__call__``.
        """
        # Set unconditionally so the attribute always exists; availability is
        # enforced separately by ``check_avail``.
        self.chunk_size = chunk_size
        try:
            import amp_C  # noqa: F401  -- imported only to test availability
        except ImportError as err:
            MultiTensorApply.available = False
            MultiTensorApply.import_err = err
        else:
            MultiTensorApply.available = True

    def check_avail(self):
        """Raise unless the most recent ``amp_C`` probe succeeded.

        Raises:
            RuntimeError: if ``amp_C`` is unavailable. The original
                ImportError (when one was captured) is chained as
                ``__cause__`` and its text is included in the message.
        """
        if not MultiTensorApply.available:
            err = MultiTensorApply.import_err
            # Build a single string message; passing the exception as a second
            # positional argument would render the error as a tuple.
            raise RuntimeError(
                "Attempted to call MultiTensorApply method, but MultiTensorApply "
                "is not available, possibly because Apex was installed without "
                "--cpp_ext --cuda_ext. Original import error message: "
                "{}".format(err)
            ) from err

    def __call__(self, op, noop_flag_buffer, tensor_lists, *args):
        """Invoke ``op`` with this instance's chunk size prepended.

        Args:
            op: callable invoked as
                ``op(chunk_size, noop_flag_buffer, tensor_lists, *args)``;
                presumably an ``amp_C`` multi-tensor kernel (TODO confirm
                against callers).
            noop_flag_buffer: forwarded to ``op`` unchanged.
            tensor_lists: forwarded to ``op`` unchanged.
            *args: extra positional arguments forwarded to ``op``.

        Returns:
            Whatever ``op`` returns.

        Raises:
            RuntimeError: if ``amp_C`` is unavailable (see ``check_avail``).
        """
        self.check_avail()
        return op(self.chunk_size, noop_flag_buffer, tensor_lists, *args)