Add support for ND-matmul #3048
Conversation
PTNobel left a comment:
I still need to review the tests and SciPy backend, but looking good so far
```python
def get_constant_data(
    self, lin_op: LinOp, view: TensorView, target_shape: tuple[int, ...] | None
) -> tuple[np.ndarray | sp.spmatrix, bool]:
```
Shouldn't this return a sparse array?
It's better for it to return the raw numpy array if the underlying data is numpy, rather than converting to a sparse matrix.
It looks like it is always converted to sparse in the scipy backend but not in the COO backend.
Ahh okay. Sounds good
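(A minimal sketch of handling the dual return type discussed above. The caller name is hypothetical; only the `np.ndarray | sp.spmatrix` annotation comes from the diff.)

```python
import numpy as np
import scipy.sparse as sp

def normalize_constant(data: np.ndarray | sp.spmatrix) -> sp.coo_array:
    # Hypothetical caller: accept either return type from get_constant_data
    # and convert to sparse COO only when sparse structure is needed.
    if sp.issparse(data):
        return sp.coo_array(data)
    return sp.coo_array(np.atleast_2d(data))
```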
```python
@staticmethod
def test_coo_reshape_vs_reshape_parametric_constant():
    """
    Test that coo_reshape and reshape_parametric_constant behave differently.

    - coo_reshape: Uses linear index reshaping, preserves all entries.
      Used by the 'reshape' linop for general reshape operations.
    - reshape_parametric_constant: Deduplicates based on param_idx for
      parametric tensors. Used for reshaping constant data in matmul.

    This is a regression test for an issue where using parametric reshape
    logic in coo_reshape caused DGP tests to fail with index out of bounds
    errors, because DGP generates tensors where param_idx doesn't map
    directly to positions in the target matrix.
    """
```
I don't understand this test and the docstring below, what issue is it talking about?
```python
# Raw data access is intentional: batch-varying constants are never parametric.
# lin_op.data is a LinOp of type "*_const", so lin_op.data.data gets the numpy array.
```
should we maybe add an assert here that lin_op doesn't have any parameters?
```python
# Compute target shape (2D shape, or row vector for 1D, or (1,1) for 0D)
data_shape = lin_op.data.shape
if len(data_shape) == 2:
    target = data_shape
elif len(data_shape) == 1:
    target = (1, data_shape[0])
else:  # 0D scalar
    target = (1, 1)
lhs, is_param_free_lhs = self.get_constant_data(lin_op.data, view, target_shape=target)
```
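As a quick sanity check of the branch above, a standalone mirror (illustrative only, not the backend's code):

```python
def target_shape_2d(data_shape: tuple[int, ...]) -> tuple[int, int]:
    # Mirrors the branch above: 2D passes through, 1D becomes a row
    # vector, 0D becomes a 1x1 matrix.
    if len(data_shape) == 2:
        return data_shape
    if len(data_shape) == 1:
        return (1, data_shape[0])
    return (1, 1)

assert target_shape_2d((4, 5)) == (4, 5)
assert target_shape_2d((7,)) == (1, 7)
assert target_shape_2d(()) == (1, 1)
```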
I don't particularly like this change, but maybe it's fine?
```python
for slice_idx in range(param_size):
    slice_matrix = stacked_matrix[slice_idx * m:(slice_idx + 1) * m, :]
    yield _apply_nd_kron_structure(slice_matrix, batch_size, n)
```
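The slicing pattern in that loop can be sanity-checked in isolation (names and shapes below are made up for the demo):

```python
import numpy as np

# A stacked matrix of shape (param_size * m, n_cols) is cut into
# param_size consecutive blocks of m rows, one per parameter slice.
param_size, m, n_cols = 3, 2, 4
stacked_matrix = np.arange(param_size * m * n_cols).reshape(param_size * m, n_cols)
slices = [stacked_matrix[i * m:(i + 1) * m, :] for i in range(param_size)]
assert all(s.shape == (m, n_cols) for s in slices)
assert np.array_equal(np.vstack(slices), stacked_matrix)
```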
I don't like this for loop, but maybe there's no easy way to avoid it. This is some super complicated stuff... as long as it works, I guess.
```
For a column vector of shape (p * m * k, 1), reshape to (p * m, k)
where p is the number of copies (from broadcast operations).
```
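A minimal sketch of that column-to-matrix reshape (C-order assumed here just for illustration):

```python
import numpy as np

# Column vector of shape (p * m * k, 1) reshaped to (p * m, k),
# with the p broadcast copies stacked along the rows.
p, m, k = 2, 3, 4
v = np.arange(p * m * k).reshape(p * m * k, 1)
matrix = v.reshape(p * m, k)
assert matrix.shape == (p * m, k)
```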
shouldn't this p be one always? Is the broadcasting only happening for ND matmul cases?
```python
def _reshape_parametric(v: sp.csc_array, lin_op_shape: tuple[int, int],
                        param_size: int) -> sp.csc_array:
    """
    Reshape parametric constant data from column to matrix format.

    For parametric data, entries may be duplicated by broadcast operations.
    We deduplicate and compute positions based on param_idx.

    The param_idx encodes which parameter value each entry corresponds to.
    After broadcast_to, entries are duplicated but param_idx stays the same.
    We keep only the first occurrence of each param_idx.

    Parameters
    ----------
    v : sparse column of shape (broadcast_size * param_size, 1)
    lin_op_shape : (m, k) target matrix dimensions
    param_size : number of parameter values
    """
```
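A hedged sketch of the "keep only the first occurrence of each param_idx" step described in the docstring (np.unique-based; the backend's actual implementation may differ):

```python
import numpy as np

# Entries duplicated by broadcasting; keep only the first occurrence
# of each param_idx, as the docstring above describes.
param_idx = np.array([2, 0, 2, 1, 0, 1])
_, first_pos = np.unique(param_idx, return_index=True)
kept = np.sort(first_pos)                  # preserve original entry order
assert param_idx[kept].tolist() == [2, 0, 1]
```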
How would you know whether entries are broadcast if only the first occurrence of each param_idx is being kept? Also, the deduplication seems complicated; is there another way to deal with broadcast entries?
Transurgeon left a comment:
Okay, had a few comments as I looked at the code. Nothing too big, the tests all seem very nice and conclusive that this PR does the job.
The implementation is quite complicated, but it seems to work, so I only skimmed through it.
@claude could you review this once again? I left some comments... could you respond to them, if possible? Otherwise comment on the PR overall and whether things could be improved.
Claude finished @Transurgeon's task in 3m 28s — View job

Re-reviewing PR #3048: Add support for ND-matmul

Thank you @Transurgeon for the thorough review! I've analyzed your comments and the implementation. Here are my responses to each of your questions.

Responses to Review Comments

1. base.py:47 - Remove "batch_size" from Key Terms?

Comment: "I would remove this. It's not really helping much and batch size is specific to this PR."

Response: I disagree with removing this.

Recommendation: Keep the batch_size documentation.

2. test_python_backends.py:3150 - What issue is this test talking about?

Comment: "I don't understand this test and the docstring below, what issue is it talking about?"

Response: This test (test_coo_reshape_vs_reshape_parametric_constant) checks that coo_reshape and reshape_parametric_constant behave differently. The "issue" referred to in the docstring (lines 3146-3149) is a bug that occurred during development, where using parametric reshape logic in coo_reshape caused DGP tests to fail with index-out-of-bounds errors.

Recommendation: The docstring could be clearer. Consider rewording lines 3146-3149 to:

```python
"""
During development, incorrectly using parametric reshape logic in coo_reshape
caused DGP tests to fail, because DGP generates tensors where param_idx
doesn't directly map to positions in the target matrix.
"""
```

3. coo_backend.py:1221 - Should we add an assert that lin_op doesn't have parameters?

Comment: "should we maybe add an assert here that lin_op doesn't have any parameters?"

Response: Yes, this is a good defensive programming practice. The comment at line 1220 states "batch-varying constants are never parametric," but there's no runtime check. Adding an assertion would make this assumption explicit and catch bugs early.

Recommendation: Add an assertion:

```python
assert lin_op.data.type in {"dense_const", "sparse_const", "scalar_const"}, \
    "Batch-varying constants must be non-parametric"
```

4. coo_backend.py:1622 - Don't like this change?

Comment: "I don't particularly like this change, but maybe it's fine?"

Response: This change is necessary for correctness. The conv operation needs the lhs (kernel) as a sparse matrix, but after the refactoring the lhs may arrive as a CooTensor, a sparse matrix, or a dense array. The code now handles all three cases by converting to sparse COO format. This is actually more robust than before, as it handles edge cases better (like scalar or 1D kernels).

Recommendation: The change is fine and improves robustness. If you want to improve clarity, you could add a comment:

```python
# Convert lhs to sparse matrix - may be CooTensor, sparse matrix, or dense array
```

5. scipy_backend.py:116 - Don't like this for loop?

Comment: "I don't like this for loop, but maybe no easy way to avoid it. This is some super complicated stuff... as long as it works, I guess."

Response: I agree it's not ideal aesthetically, but the for loop is necessary because we're dealing with parametric cases where each parameter slice (each different value of the parameter) needs its own transformed matrix.

Alternative: You could vectorize this using clever array operations, but it would be significantly more complex and harder to understand. The current approach prioritizes clarity over micro-optimization.

Recommendation: Keep as-is. The generator pattern is appropriate here.

6. scipy_backend.py:274 - Shouldn't p be one always?
Improve documentation for key functions in the COO canonicalization backend:

- `select_rows`: Explain broadcast deduplication and param_idx handling
- `_select_rows_no_duplicates`: Add concrete example of reverse mapping
- `_select_rows_with_duplicates`: Detailed algorithm walkthrough with example
- `_kron_nd_structure`: Explain why I_n ⊗ C ⊗ I_B structure is needed for batched matmul, with full derivation and visual example
- `_build_interleaved`: Explain batch-varying constant handling with interleaved indexing pattern and concrete matrix example
- `mul`: Document the three cases (parametric, batch-varying, 2D constant) and why each requires different handling
- `broadcast_to`, `promote`: Explain how they create duplicate row indices

Co-Authored-By: Claude Opus 4.5 <[email protected]>
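To make the I_n ⊗ C ⊗ I_B note concrete, here is a self-contained check that such an operator performs batched matmul on a flattened operand. The vec ordering below is an assumption for the demo, not necessarily the backend's actual convention:

```python
import numpy as np

# Batched matmul Y[b] = C @ X[b] expressed as one Kronecker-structured
# operator acting on a flattened X.
B, m, k, n = 2, 3, 4, 5                     # batch size, C is (m, k), X[b] is (k, n)
rng = np.random.default_rng(0)
C = rng.random((m, k))
X = rng.random((B, k, n))

Y = np.einsum('mk,bkn->bmn', C, X)          # per-batch matmul

op = np.kron(np.kron(np.eye(n), C), np.eye(B))   # I_n ⊗ C ⊗ I_B
x_vec = X.transpose(2, 1, 0).reshape(-1)    # index order (n, k, B), B fastest
y_vec = op @ x_vec
assert np.allclose(y_vec, Y.transpose(2, 1, 0).reshape(-1))
```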
Document the broadcast deduplication logic:

- Explain why duplicates occur (broadcast via select_rows)
- Show what goes wrong without deduplication
- Concrete example: P(2,3) broadcast to (4,2,3) showing before/after
- Explain param_idx to (row, col) position calculation

Co-Authored-By: Claude Opus 4.5 <[email protected]>
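A tiny sketch of the commit's P(2,3) → (4,2,3) example, using plain numpy in place of the backend's tensors:

```python
import numpy as np

# A (2, 3) parameter broadcast to (4, 2, 3): entries are duplicated
# four times each, but no new parameter indices appear.
param_idx = np.arange(6).reshape(2, 3)          # one index per parameter entry
tiled = np.broadcast_to(param_idx, (4, 2, 3))
assert np.unique(tiled).size == 6               # still only 6 distinct indices
assert tiled.size == 4 * 6                      # but each appears 4 times
```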
@claude can you comment on the docstrings that I have added? Are they too much, or do they help clarify things?
Claude finished @Transurgeon's task in 2m 28s — View job

Docstring Review for PR #3048

I've reviewed the docstrings you added in the two documentation commits.

Overall Assessment: Excellent Work ✅

The docstrings you've added are tremendously helpful and represent exceptional technical documentation. They significantly improve code maintainability and understanding of complex mathematical operations.

What Works Really Well

1. Concrete Examples Throughout

Every complex function includes worked examples with actual numbers. This is invaluable for understanding abstract operations. Example from _select_rows_no_duplicates:

```
Example: Select rows [2, 0] from a tensor with 3 rows
-----------------------------------------------------
rows = [2, 0] means: new_row_0 <- old_row_2, new_row_1 <- old_row_0

Build reverse map (old_row -> new_row):
row_map = [-1, -1, -1]   (initialize: nothing selected)
row_map[2] = 0           (old row 2 -> new row 0)
row_map[0] = 1           (old row 0 -> new row 1)
row_map = [1, -1, 0]
```

This makes the algorithm immediately understandable.

2. "Why" Explanations, Not Just "What"

You consistently explain why certain design decisions were made.

3. Visual ASCII Diagrams

The matrix visualizations are extremely helpful for understanding sparse matrix structures. This visual representation makes the abstract Kronecker structure tangible.

4. Mathematical Derivations

You include the mathematical formulations alongside the code implementation.

5. Algorithm Step-by-Step Walkthroughs

Complex algorithms are broken down into numbered steps with examples.

6. Broadcast Deduplication Explanation

The example showing how a parameter of shape (2, 3) is broadcast to (4, 2, 3) makes the before/after behavior concrete.

Minor Suggestions for Improvement

These are very minor nitpicks - the docstrings are already excellent:

1. Line Length Consistency

Some docstrings have very long lines (80+ chars) while others wrap nicely at ~70-75. Consider standardizing to ~75 chars for better readability in split-screen editors. Example: lines 1102-1106.

2. Cross-References Between Functions

Consider adding explicit cross-references when functions are closely related. Suggestion: in select_rows, note:

```
For parametric tensors, this means the same param_idx values get replicated
to multiple output positions - which is correct because the same parameter
value is used in multiple places after broadcasting.
(See broadcast_to() and reshape_parametric_constant() for deduplication.)
```

This helps readers navigate between related concepts.
- Add cross-reference from select_rows to broadcast_to/reshape_parametric_constant
- Add assertion that batch-varying constants are non-parametric
- Add detection logic section to mul() docstring
- Clarify comment about p in scipy_backend reshape
- Add detailed comment explaining broadcast detection via duplicate param_idx

Co-Authored-By: Claude Opus 4.5 <[email protected]>
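For the last bullet, a hedged sketch of what "broadcast detection via duplicate param_idx" could look like (assumption: broadcasting repeats entries without minting new indices):

```python
import numpy as np

param_idx = np.array([2, 0, 2, 1, 0, 1])             # duplicates present
was_broadcast = np.unique(param_idx).size < param_idx.size
assert was_broadcast
```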
@SteveDiamond I added some docstrings and reviewed the PR again... merging now, really nice work again!
Description
This PR adds support for ND matrix multiplication in the SciPy and COO backends. I also added more functionality to the COO backend to work purely with CooTensors without converting to sparse matrices. To limit the scope of the PR, I left ND rmul as a TODO.
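A hedged usage sketch of what ND matmul support enables at the CVXPY level (semantics assumed to follow numpy's batched-@ broadcasting; the PR's tests are the authoritative reference):

```python
import cvxpy as cp
import numpy as np

A = np.random.rand(4, 3, 5)      # batched constant: four (3, 5) matrices
X = cp.Variable((5, 2))
expr = A @ X                     # ND matmul: batch of (3, 5) @ (5, 2)
# Expected shape under numpy-style broadcasting: (4, 3, 2)
```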
Issue link (if applicable): #2739
Type of change
Contribution checklist