Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 5094aac

Browse files
authored
Merge branch 'master' into patch-2
2 parents 500de6b + 6f404be commit 5094aac

2 files changed

Lines changed: 6 additions & 3 deletions

File tree

advanced_source/torch_script_custom_ops.rst

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ Registration is very simple. For our case, we need to write:
170170
.. code-block:: cpp
171171
172172
static auto registry =
173-
torch::jit::RegisterOperators("my_ops::warp_perspective", &warp_perspective);
173+
torch::RegisterOperators("my_ops::warp_perspective", &warp_perspective);
174174
175175
somewhere in the global scope of our ``op.cpp`` file. This creates a global
176176
variable ``registry``, which will register our operator with TorchScript in its
@@ -188,7 +188,7 @@ operator name are separated by two colons (``::``).
188188
.. code-block:: cpp
189189
190190
static auto registry =
191-
torch::jit::RegisterOperators("my_ops::warp_perspective", &warp_perspective)
191+
torch::RegisterOperators("my_ops::warp_perspective", &warp_perspective)
192192
.op("my_ops::another_op", &another_op)
193193
.op("my_ops::and_another_op", &and_another_op);
194194
@@ -982,7 +982,7 @@ custom TorchScript operator as a string. For this, use
982982
}
983983
984984
static auto registry =
985-
torch::jit::RegisterOperators("my_ops::warp_perspective", &warp_perspective);
985+
torch::RegisterOperators("my_ops::warp_perspective", &warp_perspective);
986986
"""
987987
988988
torch.utils.cpp_extension.load_inline(

beginner_source/data_loading_tutorial.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,9 @@ def __len__(self):
136136
return len(self.landmarks_frame)
137137

138138
def __getitem__(self, idx):
139+
if torch.is_tensor(idx):
140+
idx = idx.tolist()
141+
139142
img_name = os.path.join(self.root_dir,
140143
self.landmarks_frame.iloc[idx, 0])
141144
image = io.imread(img_name)

0 commit comments

Comments
 (0)