Fix torch BiLSTM dispatch and engage cuDNN#22874
Conversation
There was a problem hiding this comment.
Code Review
This pull request implements a fused bidirectional cuDNN LSTM for the Torch backend and updates the unidirectional LSTM to correctly handle the flat tuple returned by torch._VF.lstm. Feedback indicates that the bidirectional_lstm docstring is missing required Arguments: and code example sections as specified in the Keras API design guidelines. Additionally, the reviewer pointed out several redundant tensor operations and device movements that can be simplified by using the dtype parameter within convert_to_tensor.
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #22874 +/- ##
==========================================
- Coverage 84.54% 75.40% -9.15%
==========================================
Files 463 463
Lines 68004 68036 +32
Branches 11142 11143 +1
==========================================
- Hits 57493 51300 -6193
- Misses 7586 14134 +6548
+ Partials 2925 2602 -323
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
|
Please look at the error on torch. |
|
Hi, thanks for the review! I fixed the error with the torch test. |
|
So one test is failing on Torch: This test is using BidiLSTM. But feel free to adjust the atol and rtol on that assert. Ignore the argpartition failure, that's fixed in master, so you can rebase to get the fix. |
f95d466 to
c9ad2dc
Compare
|
HI thanks for the help! I just updated the tolerances and rebased. |
hertschuh
left a comment
There was a problem hiding this comment.
Thanks for adding this!
Description
Fixes a silent cuDNN dispatch failure in the torch LSTM backend (which I introduced) and adds a fused bidirectional cuDNN path mirroring #22791 (JAX equivalent).
The bug
torch._VF.lstmreturns a flat(output, h_n, c_n)tuple. In #22470 (March 2026, "Rewrite torch LSTM to use functional API with CPU fallback") I unpacked it as the nestedoutputs, (h_n, c_n)form, which raisesValueError: too many values to unpackat runtime. The exception is caught by thetry: ... except Exception: passin the surroundinglstm(), silently falling back to_fallback_lstm(the pure-Python per-timestep loop). CPU tests pass because_fallback_lstmis functionally correct; CPU never exercises the cuDNN path; no GPU CI checks LSTM perf, so the regression went unnoticed.Net effect: every keras-on-torch user with
LSTM(use_cudnn='auto')on GPU has been running through the slow fallback since #22470 merged. This PR restores the cuDNN path and also adds the bidirectional fused path that was previously aNotImplementedErrorstub.Changes
_cudnn_lstm(unidirectional,rnn.py:651): fix unpacking tooutputs, h_n, c_n = torch._VF.lstm(...).bidirectional_lstm: implement the fused path usingtorch._VF.lstm(..., bidirectional=True), mirroring Fuse JAX Bidirectional LSTM into a single cuDNN call #22791's JAX design.Benchmark (NVIDIA A100-SXM4-40GB, torch 2.10.0+cu128, B=32, L=128, H=256, 2 BiLSTM layers,
train_on_batch)The kernel-launch drop is mechanical evidence that cuDNN engages: master was unrolling the per-timestep Python loop in
_fallback_lstm; this PR routes through cuDNN's fused kernel. Confirmed cuDNN engagement on GPU via direct invocation ofBidirectional._call_fused_lstmand inspection of the installed source.Contributor Agreement