@@ -946,7 +946,8 @@ without having to convert to a single pointer:
946946 Accessor objects have a relatively high level interface, with ``.size() `` and
947947``.stride() `` methods and multi-dimensional indexing. The ``.accessor<> ``
948948interface is designed to access data efficiently on cpu tensors. The equivalent
949- for cuda tensors is the ``packed_accessor<> ``, which produces a Packed Accessor.
949+ for cuda tensors is either ``packed_accessor64<> `` or ``packed_accessor32<> ``, which
950+ produces a Packed Accessor with 64-bit or 32-bit integer indexing.
950951
951952The fundamental difference with Accessor is that a Packed Accessor copies size
952953and stride data inside of its structure instead of pointing to it. It allows us
@@ -957,34 +958,34 @@ We can design a function that takes Packed Accessors instead of pointers.
957958.. code-block :: cpp
958959
959960 __global__ void lltm_cuda_forward_kernel(
960- const torch::PackedTensorAccessor <scalar_t,3,torch::RestrictPtrTraits,size_t > gates,
961- const torch::PackedTensorAccessor <scalar_t,2,torch::RestrictPtrTraits,size_t > old_cell,
962- torch::PackedTensorAccessor <scalar_t,2,torch::RestrictPtrTraits,size_t > new_h,
963- torch::PackedTensorAccessor <scalar_t,2,torch::RestrictPtrTraits,size_t > new_cell,
964- torch::PackedTensorAccessor <scalar_t,2,torch::RestrictPtrTraits,size_t > input_gate,
965- torch::PackedTensorAccessor <scalar_t,2,torch::RestrictPtrTraits,size_t > output_gate,
966- torch::PackedTensorAccessor <scalar_t,2,torch::RestrictPtrTraits,size_t > candidate_cell)
961+ const torch::PackedTensorAccessor32 <scalar_t,3,torch::RestrictPtrTraits> gates,
962+ const torch::PackedTensorAccessor32 <scalar_t,2,torch::RestrictPtrTraits> old_cell,
963+ torch::PackedTensorAccessor32 <scalar_t,2,torch::RestrictPtrTraits> new_h,
964+ torch::PackedTensorAccessor32 <scalar_t,2,torch::RestrictPtrTraits> new_cell,
965+ torch::PackedTensorAccessor32 <scalar_t,2,torch::RestrictPtrTraits> input_gate,
966+ torch::PackedTensorAccessor32 <scalar_t,2,torch::RestrictPtrTraits> output_gate,
967+ torch::PackedTensorAccessor32 <scalar_t,2,torch::RestrictPtrTraits> candidate_cell)
967968
 968969 Let's decompose the template used here. The first two arguments ``scalar_t `` and
969970``2 `` are the same as regular Accessor. The argument
970971``torch::RestrictPtrTraits `` indicates that the ``__restrict__ `` keyword must be
971- used. Finally, the argument `` size_t `` indicates that sizes and strides must be
972- stored in a ``size_t `` integer . This is important as by default `` int64_t `` is
973- used and can make the kernel slower.
972+ used. Note also that we've used the ``PackedTensorAccessor32 `` variant which stores the
973+ sizes and strides in an ``int32_t ``. This is important as using the 64-bit
974+ variant (``PackedTensorAccessor64 ``) can make the kernel slower.
974975
975976The function declaration becomes
976977
977978.. code-block :: cpp
978979
979980 template <typename scalar_t>
980981 __global__ void lltm_cuda_forward_kernel(
981- const torch::PackedTensorAccessor <scalar_t,3,torch::RestrictPtrTraits,size_t > gates,
982- const torch::PackedTensorAccessor <scalar_t,2,torch::RestrictPtrTraits,size_t > old_cell,
983- torch::PackedTensorAccessor <scalar_t,2,torch::RestrictPtrTraits,size_t > new_h,
984- torch::PackedTensorAccessor <scalar_t,2,torch::RestrictPtrTraits,size_t > new_cell,
985- torch::PackedTensorAccessor <scalar_t,2,torch::RestrictPtrTraits,size_t > input_gate,
986- torch::PackedTensorAccessor <scalar_t,2,torch::RestrictPtrTraits,size_t > output_gate,
987- torch::PackedTensorAccessor <scalar_t,2,torch::RestrictPtrTraits,size_t > candidate_cell) {
982+ const torch::PackedTensorAccessor32 <scalar_t,3,torch::RestrictPtrTraits> gates,
983+ const torch::PackedTensorAccessor32 <scalar_t,2,torch::RestrictPtrTraits> old_cell,
984+ torch::PackedTensorAccessor32 <scalar_t,2,torch::RestrictPtrTraits> new_h,
985+ torch::PackedTensorAccessor32 <scalar_t,2,torch::RestrictPtrTraits> new_cell,
986+ torch::PackedTensorAccessor32 <scalar_t,2,torch::RestrictPtrTraits> input_gate,
987+ torch::PackedTensorAccessor32 <scalar_t,2,torch::RestrictPtrTraits> output_gate,
988+ torch::PackedTensorAccessor32 <scalar_t,2,torch::RestrictPtrTraits> candidate_cell) {
988989 //batch index
989990 const int n = blockIdx.y;
990991 // column index
@@ -1000,7 +1001,7 @@ The function declaration becomes
10001001 }
10011002
10021003 The implementation is much more readable! This function is then called by
1003- creating Packed Accessors with the ``.packed_accessor <> `` method within the
1004+ creating Packed Accessors with the ``.packed_accessor32 <> `` method within the
10041005host function.
10051006
10061007.. code-block :: cpp
@@ -1029,13 +1030,13 @@ host function.
10291030
10301031 AT_DISPATCH_FLOATING_TYPES(gates.type(), "lltm_forward_cuda", ([&] {
10311032 lltm_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
1032- gates.packed_accessor <scalar_t,3,torch::RestrictPtrTraits,size_t >(),
1033- old_cell.packed_accessor <scalar_t,2,torch::RestrictPtrTraits,size_t >(),
1034- new_h.packed_accessor <scalar_t,2,torch::RestrictPtrTraits,size_t >(),
1035- new_cell.packed_accessor <scalar_t,2,torch::RestrictPtrTraits,size_t >(),
1036- input_gate.packed_accessor <scalar_t,2,torch::RestrictPtrTraits,size_t >(),
1037- output_gate.packed_accessor <scalar_t,2,torch::RestrictPtrTraits,size_t >(),
1038- candidate_cell.packed_accessor <scalar_t,2,torch::RestrictPtrTraits,size_t >());
1033+ gates.packed_accessor32 <scalar_t,3,torch::RestrictPtrTraits>(),
1034+ old_cell.packed_accessor32 <scalar_t,2,torch::RestrictPtrTraits>(),
1035+ new_h.packed_accessor32 <scalar_t,2,torch::RestrictPtrTraits>(),
1036+ new_cell.packed_accessor32 <scalar_t,2,torch::RestrictPtrTraits>(),
1037+ input_gate.packed_accessor32 <scalar_t,2,torch::RestrictPtrTraits>(),
1038+ output_gate.packed_accessor32 <scalar_t,2,torch::RestrictPtrTraits>(),
1039+ candidate_cell.packed_accessor32 <scalar_t,2,torch::RestrictPtrTraits>());
10391040 }));
10401041
10411042 return {new_h, new_cell, input_gate, output_gate, candidate_cell, X, gates};
@@ -1048,15 +1049,15 @@ on it:
10481049
10491050 template <typename scalar_t>
10501051 __global__ void lltm_cuda_backward_kernel(
1051- torch::PackedTensorAccessor <scalar_t,2,torch::RestrictPtrTraits,size_t > d_old_cell,
1052- torch::PackedTensorAccessor <scalar_t,3,torch::RestrictPtrTraits,size_t > d_gates,
1053- const torch::PackedTensorAccessor <scalar_t,2,torch::RestrictPtrTraits,size_t > grad_h,
1054- const torch::PackedTensorAccessor <scalar_t,2,torch::RestrictPtrTraits,size_t > grad_cell,
1055- const torch::PackedTensorAccessor <scalar_t,2,torch::RestrictPtrTraits,size_t > new_cell,
1056- const torch::PackedTensorAccessor <scalar_t,2,torch::RestrictPtrTraits,size_t > input_gate,
1057- const torch::PackedTensorAccessor <scalar_t,2,torch::RestrictPtrTraits,size_t > output_gate,
1058- const torch::PackedTensorAccessor <scalar_t,2,torch::RestrictPtrTraits,size_t > candidate_cell,
1059- const torch::PackedTensorAccessor <scalar_t,3,torch::RestrictPtrTraits,size_t > gate_weights) {
1052+ torch::PackedTensorAccessor32 <scalar_t,2,torch::RestrictPtrTraits> d_old_cell,
1053+ torch::PackedTensorAccessor32 <scalar_t,3,torch::RestrictPtrTraits> d_gates,
1054+ const torch::PackedTensorAccessor32 <scalar_t,2,torch::RestrictPtrTraits> grad_h,
1055+ const torch::PackedTensorAccessor32 <scalar_t,2,torch::RestrictPtrTraits> grad_cell,
1056+ const torch::PackedTensorAccessor32 <scalar_t,2,torch::RestrictPtrTraits> new_cell,
1057+ const torch::PackedTensorAccessor32 <scalar_t,2,torch::RestrictPtrTraits> input_gate,
1058+ const torch::PackedTensorAccessor32 <scalar_t,2,torch::RestrictPtrTraits> output_gate,
1059+ const torch::PackedTensorAccessor32 <scalar_t,2,torch::RestrictPtrTraits> candidate_cell,
1060+ const torch::PackedTensorAccessor32 <scalar_t,3,torch::RestrictPtrTraits> gate_weights) {
10601061 //batch index
10611062 const int n = blockIdx.y;
10621063 // column index
@@ -1102,15 +1103,15 @@ on it:
11021103
11031104 AT_DISPATCH_FLOATING_TYPES(X.type(), "lltm_forward_cuda", ([&] {
11041105 lltm_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(
1105- d_old_cell.packed_accessor <scalar_t,2,torch::RestrictPtrTraits,size_t >(),
1106- d_gates.packed_accessor <scalar_t,3,torch::RestrictPtrTraits,size_t >(),
1107- grad_h.packed_accessor <scalar_t,2,torch::RestrictPtrTraits,size_t >(),
1108- grad_cell.packed_accessor <scalar_t,2,torch::RestrictPtrTraits,size_t >(),
1109- new_cell.packed_accessor <scalar_t,2,torch::RestrictPtrTraits,size_t >(),
1110- input_gate.packed_accessor <scalar_t,2,torch::RestrictPtrTraits,size_t >(),
1111- output_gate.packed_accessor <scalar_t,2,torch::RestrictPtrTraits,size_t >(),
1112- candidate_cell.packed_accessor <scalar_t,2,torch::RestrictPtrTraits,size_t >(),
1113- gates.packed_accessor <scalar_t,3,torch::RestrictPtrTraits,size_t >());
1106+ d_old_cell.packed_accessor32 <scalar_t,2,torch::RestrictPtrTraits>(),
1107+ d_gates.packed_accessor32 <scalar_t,3,torch::RestrictPtrTraits>(),
1108+ grad_h.packed_accessor32 <scalar_t,2,torch::RestrictPtrTraits>(),
1109+ grad_cell.packed_accessor32 <scalar_t,2,torch::RestrictPtrTraits>(),
1110+ new_cell.packed_accessor32 <scalar_t,2,torch::RestrictPtrTraits>(),
1111+ input_gate.packed_accessor32 <scalar_t,2,torch::RestrictPtrTraits>(),
1112+ output_gate.packed_accessor32 <scalar_t,2,torch::RestrictPtrTraits>(),
1113+ candidate_cell.packed_accessor32 <scalar_t,2,torch::RestrictPtrTraits>(),
1114+ gates.packed_accessor32 <scalar_t,3,torch::RestrictPtrTraits>());
11141115 }));
11151116
11161117 auto d_gate_weights = d_gates.reshape({batch_size, 3*state_size});
0 commit comments