Fix torch BiLSTM dispatch and engage cuDNN #22874

Merged: hertschuh merged 4 commits into keras-team:master from MarcosAsh:torch-bidirectional-lstm-cudnn on May 14, 2026

@MarcosAsh
Contributor

Description

Fixes a silent cuDNN dispatch failure in the torch LSTM backend (which I introduced) and adds a fused bidirectional cuDNN path mirroring #22791 (JAX equivalent).

The bug

torch._VF.lstm returns a flat (output, h_n, c_n) tuple. In #22470 (March 2026, "Rewrite torch LSTM to use functional API with CPU fallback") I unpacked it in the nested form outputs, (h_n, c_n), which raises ValueError: too many values to unpack at runtime. The exception is caught by the try: ... except Exception: pass in the surrounding lstm(), so the call silently falls back to _fallback_lstm (the pure-Python per-timestep loop). The regression went unnoticed for three reasons: CPU tests pass because _fallback_lstm is functionally correct, the CPU never exercises the cuDNN path, and no GPU CI job checks LSTM performance.
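
To make the failure mode concrete, here is a minimal pure-Python sketch of the pattern described above. It does not call torch: `fake_vf_lstm`, `cudnn_path_buggy`, `cudnn_path_fixed`, and this simplified `lstm` are hypothetical stand-ins, and only the tuple shapes and the broad `except Exception: pass` mirror the description.

```python
def fake_vf_lstm():
    # Hypothetical stand-in for torch._VF.lstm: a flat 3-tuple, as described above.
    return "output", "h_n", "c_n"


def cudnn_path_buggy():
    # Nested unpacking expects exactly two items, the second being a pair.
    # A flat 3-tuple therefore raises ValueError (too many values to unpack).
    outputs, (h_n, c_n) = fake_vf_lstm()
    return outputs, h_n, c_n


def cudnn_path_fixed():
    # Flat unpacking matches the flat return value.
    outputs, h_n, c_n = fake_vf_lstm()
    return outputs, h_n, c_n


def lstm(cudnn_path):
    # Simplified dispatcher: try the fast path, fall back on any exception.
    try:
        return "cudnn", cudnn_path()
    except Exception:
        pass  # the ValueError is swallowed here, so nothing is logged
    return "fallback", ("output", "h_n", "c_n")


print(lstm(cudnn_path_buggy)[0])  # fallback
print(lstm(cudnn_path_fixed)[0])  # cudnn
```

Both calls return the same values, which is why correctness tests alone cannot tell the two paths apart; only the reported path (or, on a real GPU, the runtime) differs.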

Net effect: every keras-on-torch user with LSTM(use_cudnn='auto') on GPU has been running through the slow fallback since #22470 merged. This PR restores the cuDNN path and also adds the bidirectional fused path that was previously a NotImplementedError stub.

Changes

Benchmark (NVIDIA A100-SXM4-40GB, torch 2.10.0+cu128, B=32, L=128, H=256, 2 BiLSTM layers, train_on_batch)

| Metric | master | this PR | Change |
|---|---|---|---|
| BiLSTM warm step | 391 ms | 119.6 ms | 3.27x faster |
| cudaLaunchKernel / step | 16,459 | 2,078 | 8x fewer launches |

The kernel-launch drop is mechanical evidence that cuDNN engages: master was unrolling the per-timestep Python loop in _fallback_lstm; this PR routes through cuDNN's fused kernel. Confirmed cuDNN engagement on GPU via direct invocation of Bidirectional._call_fused_lstm and inspection of the installed source.
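
As a quick arithmetic check, the ratios quoted in the benchmark can be recomputed from the raw figures. The variable names below are illustrative; the numbers are the ones reported above.

```python
# Figures copied from the benchmark above.
master_ms, pr_ms = 391.0, 119.6
master_launches, pr_launches = 16_459, 2_078

speedup = master_ms / pr_ms
launch_ratio = master_launches / pr_launches

print(f"{speedup:.2f}x faster")               # 3.27x faster
print(f"{launch_ratio:.1f}x fewer launches")  # 7.9x fewer launches
```

The launch ratio comes out at roughly 7.9, which rounds to the 8x stated in the benchmark.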

Contributor Agreement

  • I am a human, and not a bot.
  • I will be responsible for responding to review comments in a timely manner.
  • I will work with the maintainers to push this PR forward until submission.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request implements a fused bidirectional cuDNN LSTM for the Torch backend and updates the unidirectional LSTM to correctly handle the flat tuple returned by torch._VF.lstm. Feedback indicates that the bidirectional_lstm docstring is missing required Arguments: and code example sections as specified in the Keras API design guidelines. Additionally, the reviewer pointed out several redundant tensor operations and device movements that can be simplified by using the dtype parameter within convert_to_tensor.

Comment thread keras/src/backend/torch/rnn.py
Comment thread keras/src/backend/torch/rnn.py Outdated
@codecov-commenter
codecov-commenter commented May 11, 2026

Codecov Report

❌ Patch coverage is 97.43590% with 1 line in your changes missing coverage. Please review.
✅ Project coverage is 75.40%. Comparing base (0f3a15a) to head (c9ad2dc).

| Files with missing lines | Patch % | Lines |
|---|---|---|
| keras/src/backend/torch/rnn.py | 97.43% | 1 Missing ⚠️ |

❗ There is a different number of reports uploaded between BASE (0f3a15a) and HEAD (c9ad2dc).

HEAD has 1 upload less than BASE

| Flag | BASE (0f3a15a) | HEAD (c9ad2dc) |
|---|---|---|
| keras-openvino | 1 | 0 |
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #22874      +/-   ##
==========================================
- Coverage   84.54%   75.40%   -9.15%     
==========================================
  Files         463      463              
  Lines       68004    68036      +32     
  Branches    11142    11143       +1     
==========================================
- Hits        57493    51300    -6193     
- Misses       7586    14134    +6548     
+ Partials     2925     2602     -323     
| Flag | Coverage Δ |
|---|---|
| keras | 75.20% <97.43%> (-9.14%) ⬇️ |
| keras-cpu | 74.51% <12.82%> (-3.18%) ⬇️ |
| keras-gpu | 69.30% <97.43%> (+0.03%) ⬆️ |
| keras-jax | 58.21% <2.56%> (-0.03%) ⬇️ |
| keras-numpy | 53.54% <2.56%> (-0.03%) ⬇️ |
| keras-openvino | ? |
| keras-tensorflow | 59.50% <2.56%> (+1.17%) ⬆️ |
| keras-torch | 58.78% <97.43%> (+0.05%) ⬆️ |
| keras-tpu | 57.02% <2.56%> (-0.03%) ⬇️ |

@hertschuh
Collaborator

Please look at the error on torch.

@MarcosAsh
Contributor Author

Hi, thanks for the review! I fixed the error with the torch test.

@hertschuh
Collaborator

So one test is failing on Torch:

FAILED keras/src/export/onnx_test.py::ExportONNXTest::test_standard_model_export_lstm - AssertionError: 
Not equal to tolerance rtol=1e-06, atol=1e-06
...
Max absolute difference among violations: 2.771616e-06
Max relative difference among violations: 0.00093703

This test is using BidiLSTM.

But feel free to adjust the atol and rtol on that assert.

Ignore the argpartition failure, that's fixed in master, so you can rebase to get the fix.
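
For context on why that assertion fails and what loosening it does, here is a small sketch of the element-wise criterion that NumPy documents for `assert_allclose`: `|actual - desired| <= atol + rtol * |desired|`. The helper name `within_tolerance`, the sample magnitude `desired`, and the loosened `atol=1e-5` are assumptions for illustration only; the values actually chosen in the PR are not stated in this thread.

```python
def within_tolerance(actual, desired, rtol, atol):
    # Element-wise criterion documented for numpy.testing.assert_allclose.
    return abs(actual - desired) <= atol + rtol * abs(desired)


# Hypothetical element: a small output value offset by the maximum
# absolute difference reported in the failing test log.
desired = 0.002958
actual = desired + 2.771616e-06

print(within_tolerance(actual, desired, rtol=1e-6, atol=1e-6))  # False
print(within_tolerance(actual, desired, rtol=1e-6, atol=1e-5))  # True
```

For outputs this small the `rtol` term contributes almost nothing, so `atol` is the knob that decides whether a difference of a few 1e-6 passes.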

@MarcosAsh MarcosAsh force-pushed the torch-bidirectional-lstm-cudnn branch from f95d466 to c9ad2dc Compare May 14, 2026 12:42
@MarcosAsh
Contributor Author

Hi, thanks for the help! I just updated the tolerances and rebased.

Collaborator

@hertschuh hertschuh left a comment

Thanks for adding this!

@google-ml-butler google-ml-butler Bot added the kokoro:force-run and ready to pull labels May 14, 2026
@hertschuh hertschuh merged commit 28bad7a into keras-team:master May 14, 2026
15 of 16 checks passed
Labels

kokoro:force-run, ready to pull, size:M

4 participants