@@ -559,9 +559,6 @@ def lstm(
559559 if mask is not None :
560560 raise NotImplementedError
561561
562- # Get device from inputs
563- device = get_device ()
564-
565562 # Convert to torch tensors (convert_to_tensor unwraps Variables)
566563 kernel = convert_to_tensor (kernel )
567564 recurrent_kernel = convert_to_tensor (recurrent_kernel )
@@ -580,16 +577,24 @@ def lstm(
580577 seq_dim = 1 if batch_first else 0
581578 inputs = torch .flip (inputs , dims = [seq_dim ])
582579
583- # Move all tensors to the same device
584- inputs = inputs .to (device )
585- initial_state_h = initial_state_h .to (device )
586- initial_state_c = initial_state_c .to (device )
587-
588- cudnn_supported = cudnn_ok (
589- activation ,
590- recurrent_activation ,
591- unroll ,
592- use_bias = bias is not None ,
580+ # cuDNN only runs on CUDA. Skip it when inputs aren't on CUDA, or when
581+ # we're inside a TorchScript / Dynamo trace -- the trace records device
582+ # transfers that then fail device-consistency validation downstream
583+ # (e.g. `torch.onnx.export` failing in `wrapper_CUDA_cat`).
584+ device = inputs .device
585+ cudnn_supported = (
586+ device .type == "cuda"
587+ and not torch .jit .is_tracing ()
588+ and not (
589+ hasattr (torch .compiler , "is_compiling" )
590+ and torch .compiler .is_compiling ()
591+ )
592+ and cudnn_ok (
593+ activation ,
594+ recurrent_activation ,
595+ unroll ,
596+ use_bias = bias is not None ,
597+ )
593598 )
594599
595600 if cudnn_supported :
@@ -647,8 +652,9 @@ def _cudnn_lstm(
647652
648653 params = prepare_lstm_params (kernel , recurrent_kernel , bias , device )
649654
650- # Use functional LSTM to maintain gradient flow through weight tensors
651- outputs , (h_n , c_n ) = torch ._VF .lstm (
655+ # Use functional LSTM to maintain gradient flow through weight tensors.
656+ # ``torch._VF.lstm`` returns a flat ``(output, h_n, c_n)`` tuple.
657+ outputs , h_n , c_n = torch ._VF .lstm (
652658 inputs ,
653659 (initial_state_h , initial_state_c ),
654660 params ,
@@ -872,5 +878,158 @@ def _cudnn_gru(
872878 return last_output , outputs , [h_n ]
873879
874880
875- def bidirectional_lstm (* args , ** kwargs ):
876- raise NotImplementedError
881+ def bidirectional_lstm (
882+ inputs ,
883+ fwd_initial_state_h ,
884+ fwd_initial_state_c ,
885+ bwd_initial_state_h ,
886+ bwd_initial_state_c ,
887+ mask ,
888+ fwd_kernel ,
889+ fwd_recurrent_kernel ,
890+ fwd_bias ,
891+ bwd_kernel ,
892+ bwd_recurrent_kernel ,
893+ bwd_bias ,
894+ activation ,
895+ recurrent_activation ,
896+ return_sequences = False ,
897+ unroll = False ,
898+ ):
899+ """Fused bidirectional cuDNN LSTM for the torch backend.
900+
901+ Runs forward and backward passes in a single
902+ ``torch._VF.lstm(..., bidirectional=True)`` call instead of dispatching
903+ two unidirectional LSTM calls. Backward outputs are returned in original
904+ time order, ready for the caller's ``merge_mode`` to consume directly.
905+
906+ Args:
907+ inputs: Input tensor of shape ``(batch, time, features)``.
908+ fwd_initial_state_h: Initial hidden state for the forward direction,
909+ shape ``(batch, hidden)``.
910+ fwd_initial_state_c: Initial cell state for the forward direction,
911+ shape ``(batch, hidden)``.
912+ bwd_initial_state_h: Initial hidden state for the backward direction,
913+ shape ``(batch, hidden)``.
914+ bwd_initial_state_c: Initial cell state for the backward direction,
915+ shape ``(batch, hidden)``.
916+ mask: Sequence mask. Only ``None`` is supported; otherwise
917+ ``NotImplementedError`` is raised so the caller can fall back to
918+ the two-pass path.
919+ fwd_kernel: Forward input kernel, shape ``(features, 4 * hidden)``.
920+ fwd_recurrent_kernel: Forward recurrent kernel, shape
921+ ``(hidden, 4 * hidden)``.
922+ fwd_bias: Forward bias, shape ``(4 * hidden,)`` or ``None``.
923+ bwd_kernel: Backward input kernel, shape ``(features, 4 * hidden)``.
924+ bwd_recurrent_kernel: Backward recurrent kernel, shape
925+ ``(hidden, 4 * hidden)``.
926+ bwd_bias: Backward bias, shape ``(4 * hidden,)`` or ``None``.
927+ activation: Output activation. Only ``tanh`` engages cuDNN.
928+ recurrent_activation: Gate activation. Only ``sigmoid`` engages
929+ cuDNN.
930+ return_sequences: If ``True``, return outputs at every timestep;
931+ otherwise only the last timestep.
932+ unroll: Not supported; cuDNN requires the rolled path.
933+
934+ Returns:
935+ A pair ``((fwd_last, fwd_outputs, [fwd_h_n, fwd_c_n]),
936+ (bwd_last, bwd_outputs, [bwd_h_n, bwd_c_n]))`` matching the JAX
937+ equivalent's return shape.
938+ """
939+ if mask is not None :
940+ raise NotImplementedError
941+ if not cudnn_ok (
942+ activation ,
943+ recurrent_activation ,
944+ unroll ,
945+ use_bias = fwd_bias is not None and bwd_bias is not None ,
946+ ):
947+ raise NotImplementedError
948+
949+ fwd_kernel = convert_to_tensor (fwd_kernel )
950+ fwd_recurrent_kernel = convert_to_tensor (fwd_recurrent_kernel )
951+ bwd_kernel = convert_to_tensor (bwd_kernel )
952+ bwd_recurrent_kernel = convert_to_tensor (bwd_recurrent_kernel )
953+
954+ compute_dtype = fwd_kernel .dtype
955+ inputs = convert_to_tensor (inputs , dtype = compute_dtype )
956+ fwd_h0 = convert_to_tensor (fwd_initial_state_h , dtype = compute_dtype )
957+ fwd_c0 = convert_to_tensor (fwd_initial_state_c , dtype = compute_dtype )
958+ bwd_h0 = convert_to_tensor (bwd_initial_state_h , dtype = compute_dtype )
959+ bwd_c0 = convert_to_tensor (bwd_initial_state_c , dtype = compute_dtype )
960+
961+ # cuDNN only runs on CUDA. Fall back to the two-pass path when inputs
962+ # aren't on CUDA, or when we're inside a TorchScript / Dynamo trace --
963+ # the trace records device transfers that then fail device-consistency
964+ # validation downstream (e.g. `torch.onnx.export` in `wrapper_CUDA_cat`).
965+ device = inputs .device
966+ if (
967+ device .type != "cuda"
968+ or torch .jit .is_tracing ()
969+ or (
970+ hasattr (torch .compiler , "is_compiling" )
971+ and torch .compiler .is_compiling ()
972+ )
973+ ):
974+ raise NotImplementedError
975+
976+ fwd_params = prepare_lstm_params (
977+ fwd_kernel , fwd_recurrent_kernel , fwd_bias , device
978+ )
979+ bwd_params = prepare_lstm_params (
980+ bwd_kernel , bwd_recurrent_kernel , bwd_bias , device
981+ )
982+
983+ # torch._VF.lstm with bidirectional=True expects 4 params per direction,
984+ # forward direction first, then backward.
985+ params = fwd_params + bwd_params
986+
987+ # cuDNN expects (num_layers * num_directions, batch, hidden) for h0/c0.
988+ h_0 = torch .stack ([fwd_h0 , bwd_h0 ], dim = 0 )
989+ c_0 = torch .stack ([fwd_c0 , bwd_c0 ], dim = 0 )
990+
991+ try :
992+ # ``torch._VF.lstm`` returns a flat ``(output, h_n, c_n)`` tuple.
993+ outputs , h_n , c_n = torch ._VF .lstm (
994+ inputs ,
995+ (h_0 , c_0 ),
996+ params ,
997+ True , # has_biases
998+ 1 , # num_layers
999+ 0.0 , # dropout
1000+ torch .is_grad_enabled (), # training
1001+ True , # bidirectional
1002+ True , # batch_first
1003+ )
1004+ except (RuntimeError , TypeError , ValueError ) as e :
1005+ raise NotImplementedError (
1006+ f"cuDNN bidirectional LSTM failed: { e } "
1007+ ) from e
1008+
1009+ # outputs: (batch, seq_len, 2 * hidden_size). First half is the forward
1010+ # direction, second half is the backward direction (in original time
1011+ # order, courtesy of cuDNN).
1012+ hidden_size = fwd_recurrent_kernel .shape [0 ]
1013+ y_fwd = outputs [..., :hidden_size ]
1014+ y_bwd = outputs [..., hidden_size :]
1015+
1016+ fwd_h_n , bwd_h_n = h_n [0 ], h_n [1 ]
1017+ fwd_c_n , bwd_c_n = c_n [0 ], c_n [1 ]
1018+
1019+ # Forward "last" is the last timestep of the forward sweep; backward
1020+ # "last" is the first timestep in original time order (i.e., the result
1021+ # after the full reverse sweep).
1022+ fwd_last = y_fwd [:, - 1 ]
1023+ bwd_last = y_bwd [:, 0 ]
1024+
1025+ if return_sequences :
1026+ fwd_outputs = y_fwd
1027+ bwd_outputs = y_bwd
1028+ else :
1029+ fwd_outputs = fwd_last .unsqueeze (1 )
1030+ bwd_outputs = bwd_last .unsqueeze (1 )
1031+
1032+ return (
1033+ (fwd_last , fwd_outputs , [fwd_h_n , fwd_c_n ]),
1034+ (bwd_last , bwd_outputs , [bwd_h_n , bwd_c_n ]),
1035+ )
0 commit comments