Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 28bad7a

Browse files
authored
Fix torch BiLSTM dispatch and engage cuDNN (#22874)
* Fix torch BiLSTM dispatch and engage cuDNN
* Address review: docstring and remove redundant tensor conversions
* Skip cuDNN LSTM during ONNX trace
* Loosen ONNX LSTM export test tolerance to 1e-5
1 parent 0f3a15a commit 28bad7a

2 files changed

Lines changed: 181 additions & 17 deletions

File tree

keras/src/backend/torch/rnn.py

Lines changed: 176 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -559,9 +559,6 @@ def lstm(
559559
if mask is not None:
560560
raise NotImplementedError
561561

562-
# Get device from inputs
563-
device = get_device()
564-
565562
# Convert to torch tensors (convert_to_tensor unwraps Variables)
566563
kernel = convert_to_tensor(kernel)
567564
recurrent_kernel = convert_to_tensor(recurrent_kernel)
@@ -580,16 +577,24 @@ def lstm(
580577
seq_dim = 1 if batch_first else 0
581578
inputs = torch.flip(inputs, dims=[seq_dim])
582579

583-
# Move all tensors to the same device
584-
inputs = inputs.to(device)
585-
initial_state_h = initial_state_h.to(device)
586-
initial_state_c = initial_state_c.to(device)
587-
588-
cudnn_supported = cudnn_ok(
589-
activation,
590-
recurrent_activation,
591-
unroll,
592-
use_bias=bias is not None,
580+
# cuDNN only runs on CUDA. Skip it when inputs aren't on CUDA, or when
581+
# we're inside a TorchScript / Dynamo trace -- the trace records device
582+
# transfers that then fail device-consistency validation downstream
583+
# (e.g. `torch.onnx.export` failing in `wrapper_CUDA_cat`).
584+
device = inputs.device
585+
cudnn_supported = (
586+
device.type == "cuda"
587+
and not torch.jit.is_tracing()
588+
and not (
589+
hasattr(torch.compiler, "is_compiling")
590+
and torch.compiler.is_compiling()
591+
)
592+
and cudnn_ok(
593+
activation,
594+
recurrent_activation,
595+
unroll,
596+
use_bias=bias is not None,
597+
)
593598
)
594599

595600
if cudnn_supported:
@@ -647,8 +652,9 @@ def _cudnn_lstm(
647652

648653
params = prepare_lstm_params(kernel, recurrent_kernel, bias, device)
649654

650-
# Use functional LSTM to maintain gradient flow through weight tensors
651-
outputs, (h_n, c_n) = torch._VF.lstm(
655+
# Use functional LSTM to maintain gradient flow through weight tensors.
656+
# ``torch._VF.lstm`` returns a flat ``(output, h_n, c_n)`` tuple.
657+
outputs, h_n, c_n = torch._VF.lstm(
652658
inputs,
653659
(initial_state_h, initial_state_c),
654660
params,
@@ -872,5 +878,158 @@ def _cudnn_gru(
872878
return last_output, outputs, [h_n]
873879

874880

875-
def bidirectional_lstm(*args, **kwargs):
876-
raise NotImplementedError
881+
def bidirectional_lstm(
882+
inputs,
883+
fwd_initial_state_h,
884+
fwd_initial_state_c,
885+
bwd_initial_state_h,
886+
bwd_initial_state_c,
887+
mask,
888+
fwd_kernel,
889+
fwd_recurrent_kernel,
890+
fwd_bias,
891+
bwd_kernel,
892+
bwd_recurrent_kernel,
893+
bwd_bias,
894+
activation,
895+
recurrent_activation,
896+
return_sequences=False,
897+
unroll=False,
898+
):
899+
"""Fused bidirectional cuDNN LSTM for the torch backend.
900+
901+
Runs forward and backward passes in a single
902+
``torch._VF.lstm(..., bidirectional=True)`` call instead of dispatching
903+
two unidirectional LSTM calls. Backward outputs are returned in original
904+
time order, ready for the caller's ``merge_mode`` to consume directly.
905+
906+
Args:
907+
inputs: Input tensor of shape ``(batch, time, features)``.
908+
fwd_initial_state_h: Initial hidden state for the forward direction,
909+
shape ``(batch, hidden)``.
910+
fwd_initial_state_c: Initial cell state for the forward direction,
911+
shape ``(batch, hidden)``.
912+
bwd_initial_state_h: Initial hidden state for the backward direction,
913+
shape ``(batch, hidden)``.
914+
bwd_initial_state_c: Initial cell state for the backward direction,
915+
shape ``(batch, hidden)``.
916+
mask: Sequence mask. Only ``None`` is supported; otherwise
917+
``NotImplementedError`` is raised so the caller can fall back to
918+
the two-pass path.
919+
fwd_kernel: Forward input kernel, shape ``(features, 4 * hidden)``.
920+
fwd_recurrent_kernel: Forward recurrent kernel, shape
921+
``(hidden, 4 * hidden)``.
922+
fwd_bias: Forward bias, shape ``(4 * hidden,)`` or ``None``.
923+
bwd_kernel: Backward input kernel, shape ``(features, 4 * hidden)``.
924+
bwd_recurrent_kernel: Backward recurrent kernel, shape
925+
``(hidden, 4 * hidden)``.
926+
bwd_bias: Backward bias, shape ``(4 * hidden,)`` or ``None``.
927+
activation: Output activation. Only ``tanh`` engages cuDNN.
928+
recurrent_activation: Gate activation. Only ``sigmoid`` engages
929+
cuDNN.
930+
return_sequences: If ``True``, return outputs at every timestep;
931+
otherwise only the last timestep.
932+
unroll: Not supported; cuDNN requires the rolled path.
933+
934+
Returns:
935+
A pair ``((fwd_last, fwd_outputs, [fwd_h_n, fwd_c_n]),
936+
(bwd_last, bwd_outputs, [bwd_h_n, bwd_c_n]))`` matching the JAX
937+
equivalent's return shape.
938+
"""
939+
if mask is not None:
940+
raise NotImplementedError
941+
if not cudnn_ok(
942+
activation,
943+
recurrent_activation,
944+
unroll,
945+
use_bias=fwd_bias is not None and bwd_bias is not None,
946+
):
947+
raise NotImplementedError
948+
949+
fwd_kernel = convert_to_tensor(fwd_kernel)
950+
fwd_recurrent_kernel = convert_to_tensor(fwd_recurrent_kernel)
951+
bwd_kernel = convert_to_tensor(bwd_kernel)
952+
bwd_recurrent_kernel = convert_to_tensor(bwd_recurrent_kernel)
953+
954+
compute_dtype = fwd_kernel.dtype
955+
inputs = convert_to_tensor(inputs, dtype=compute_dtype)
956+
fwd_h0 = convert_to_tensor(fwd_initial_state_h, dtype=compute_dtype)
957+
fwd_c0 = convert_to_tensor(fwd_initial_state_c, dtype=compute_dtype)
958+
bwd_h0 = convert_to_tensor(bwd_initial_state_h, dtype=compute_dtype)
959+
bwd_c0 = convert_to_tensor(bwd_initial_state_c, dtype=compute_dtype)
960+
961+
# cuDNN only runs on CUDA. Fall back to the two-pass path when inputs
962+
# aren't on CUDA, or when we're inside a TorchScript / Dynamo trace --
963+
# the trace records device transfers that then fail device-consistency
964+
# validation downstream (e.g. `torch.onnx.export` in `wrapper_CUDA_cat`).
965+
device = inputs.device
966+
if (
967+
device.type != "cuda"
968+
or torch.jit.is_tracing()
969+
or (
970+
hasattr(torch.compiler, "is_compiling")
971+
and torch.compiler.is_compiling()
972+
)
973+
):
974+
raise NotImplementedError
975+
976+
fwd_params = prepare_lstm_params(
977+
fwd_kernel, fwd_recurrent_kernel, fwd_bias, device
978+
)
979+
bwd_params = prepare_lstm_params(
980+
bwd_kernel, bwd_recurrent_kernel, bwd_bias, device
981+
)
982+
983+
# torch._VF.lstm with bidirectional=True expects 4 params per direction,
984+
# forward direction first, then backward.
985+
params = fwd_params + bwd_params
986+
987+
# cuDNN expects (num_layers * num_directions, batch, hidden) for h0/c0.
988+
h_0 = torch.stack([fwd_h0, bwd_h0], dim=0)
989+
c_0 = torch.stack([fwd_c0, bwd_c0], dim=0)
990+
991+
try:
992+
# ``torch._VF.lstm`` returns a flat ``(output, h_n, c_n)`` tuple.
993+
outputs, h_n, c_n = torch._VF.lstm(
994+
inputs,
995+
(h_0, c_0),
996+
params,
997+
True, # has_biases
998+
1, # num_layers
999+
0.0, # dropout
1000+
torch.is_grad_enabled(), # training
1001+
True, # bidirectional
1002+
True, # batch_first
1003+
)
1004+
except (RuntimeError, TypeError, ValueError) as e:
1005+
raise NotImplementedError(
1006+
f"cuDNN bidirectional LSTM failed: {e}"
1007+
) from e
1008+
1009+
# outputs: (batch, seq_len, 2 * hidden_size). First half is the forward
1010+
# direction, second half is the backward direction (in original time
1011+
# order, courtesy of cuDNN).
1012+
hidden_size = fwd_recurrent_kernel.shape[0]
1013+
y_fwd = outputs[..., :hidden_size]
1014+
y_bwd = outputs[..., hidden_size:]
1015+
1016+
fwd_h_n, bwd_h_n = h_n[0], h_n[1]
1017+
fwd_c_n, bwd_c_n = c_n[0], c_n[1]
1018+
1019+
# Forward "last" is the last timestep of the forward sweep; backward
1020+
# "last" is the first timestep in original time order (i.e., the result
1021+
# after the full reverse sweep).
1022+
fwd_last = y_fwd[:, -1]
1023+
bwd_last = y_bwd[:, 0]
1024+
1025+
if return_sequences:
1026+
fwd_outputs = y_fwd
1027+
bwd_outputs = y_bwd
1028+
else:
1029+
fwd_outputs = fwd_last.unsqueeze(1)
1030+
bwd_outputs = bwd_last.unsqueeze(1)
1031+
1032+
return (
1033+
(fwd_last, fwd_outputs, [fwd_h_n, fwd_c_n]),
1034+
(bwd_last, bwd_outputs, [bwd_h_n, bwd_c_n]),
1035+
)

keras/src/export/onnx_test.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -101,9 +101,14 @@ def test_standard_model_export(self, model_type):
101101
ort_inputs = {
102102
k.name: v for k, v in zip(ort_session.get_inputs(), [ref_input])
103103
}
104+
# cuDNN-fused LSTM reference vs. unrolled ONNX graph differ by ~3e-6.
105+
atol = 1e-5 if model_type == "lstm" else 1e-6
106+
rtol = 1e-5 if model_type == "lstm" else 1e-6
104107
self.assertAllClose(
105108
ort_session.run(None, ort_inputs)[0],
106109
ref_output,
110+
atol=atol,
111+
rtol=rtol,
107112
tpu_atol=1e-3,
108113
tpu_rtol=1e-2,
109114
)

0 commit comments

Comments
 (0)