Tags: pytorch/pytorch

viable/strict/1780570258

Fix reflection_pad3d out-of-bounds memory access (#185614)

## Fix out-of-bounds memory access in reflection_pad3d

Fixes #145258

### Problem

`reflection_pad3d` read `padding[0]` through `padding[5]` before validating that the array had 6 elements, causing out-of-bounds reads and crashes on invalid input.

```python
# Before: crash or undefined behavior
torch.nn.functional.reflection_pad3d(tensor, (1,))

# After: clear error
# RuntimeError: padding size is expected to be 6, but got: 1
```

### Root Cause

The validation check existed but ran after accessing array elements:

```cpp
int64_t pad_left = padding[0];  // Accesses padding[0-5] first
// ...
check_valid_input<3>(input, padding);  // Validates too late
```

### Solution

Moved `check_valid_input<3>` ahead of the first array access. Also updated the `reflection_pad3d_backward` error message to report the actual padding size.
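
The reordering can be illustrated outside of C++ with a minimal Python sketch. The function name `unpack_padding_3d` is hypothetical and is not part of PyTorch; it only shows the "validate the length, then index" order, reusing the error message quoted in this description.

```python
from typing import Sequence, Tuple


def unpack_padding_3d(padding: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical helper: validate the padding length, then read its elements."""
    # Check the size before any element is read, so a short sequence
    # never reaches the unpacking below.
    if len(padding) != 6:
        raise RuntimeError(
            f"padding size is expected to be 6, but got: {len(padding)}"
        )
    pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back = padding
    return pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back
```

With this order, a call such as `unpack_padding_3d((1,))` fails with the descriptive error instead of touching elements that do not exist.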

### Testing

Added tests for invalid padding sizes: empty (), length 1, 5, and 7. All existing tests pass.

### Security

This issue can be reproduced using AddressSanitizer. While this is not a critical security vulnerability, fixing this improves stability and provides clear error messages instead of mysterious crashes when users make mistakes.

### Previous Behavior

reflection_pad3d assumed that padding always contained 6 values and directly accessed padding[0] through padding[5]. When a shorter tuple was provided, this could result in out-of-bounds memory access, leading to undefined behavior such as segmentation faults, crashes, or other unpredictable results.

Examples:

```python
torch.nn.functional.reflection_pad3d(torch.randn(1, 1, 3, 3, 3), ())
torch.nn.functional.reflection_pad3d(torch.randn(1, 1, 3, 3, 3), (1,))
torch.nn.functional.reflection_pad3d(torch.randn(1, 1, 3, 3, 3), (1, 1, 1, 1, 1))
```

Possible outcomes included:

* Segmentation faults
* Out-of-bounds memory reads
* Undefined behavior
* Unpredictable crashes or incorrect results

### New Behavior

The implementation now validates the length of the padding tuple before accessing its elements. If the tuple does not contain exactly 6 values, a clear error is raised:

```text
RuntimeError: padding size is expected to be 6, but got: <actual_size>
```

Examples:

```text
RuntimeError: padding size is expected to be 6, but got: 0
RuntimeError: padding size is expected to be 6, but got: 1
RuntimeError: padding size is expected to be 6, but got: 5
```

---

@pytorchbot label "module: nn" "module: cpp" "module: error checking" "module: crash" "topic: fuzzer"

Pull Request resolved: #185614
Approved by: https://github.com/soulitzer

viable/strict/1780565739

[MTIA] Short-circuit `signbit` for unsigned dtypes (#185985)

Summary:

`signbit` on unsigned integer types (`bool`, `uint8`, `uint16`, `uint32`, `uint64`) should trivially return all-`False`.

Replaced the generic `register_pointwise` for `aten.signbit` in `caffe2/torch/_inductor/lowering.py` with a custom `register_lowering` that short-circuits for unsigned dtypes using `full_like(x, False, dtype=torch.bool)`, following the same pattern as `isinf` and `isnan`. For signed types, the original pointwise codegen path is preserved.

Added `test_signbit_unsigned_dtypes` in `test_torchinductor.py`.
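
As a rough illustration of the short-circuit idea, here is a pure-Python sketch that does not use Inductor at all. The name `signbit_sketch` and the use of plain strings for dtypes are assumptions made for this example; the real lowering operates on Inductor IR nodes.

```python
import math
from typing import List, Sequence

# Dtype names taken from the summary above; strings stand in for torch dtypes.
UNSIGNED_DTYPES = {"bool", "uint8", "uint16", "uint32", "uint64"}


def signbit_sketch(values: Sequence[float], dtype: str) -> List[bool]:
    """Hypothetical model of the lowering: constant result for unsigned dtypes."""
    if dtype in UNSIGNED_DTYPES:
        # Short-circuit: no element can be negative, so skip per-element work.
        return [False] * len(values)
    # Generic path: copysign exposes the sign bit, including for -0.0.
    return [math.copysign(1.0, v) < 0 for v in values]
```

For example, `signbit_sketch([0, 1, 255], "uint8")` takes the constant path, while a signed or floating dtype still inspects every element.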

Pull Request resolved: #185985
Approved by: https://github.com/malfet, https://github.com/jansel

viable/strict/1780559766

Revert "Fix index ops on expanded tensors in Inductor (#184488)"

This reverts commit 65fd92b.

Reverted #184488 on behalf of https://github.com/atalman due to reverted internally ([comment](#184488 (comment)))

viable/strict/1780552603

[python] Expose at::from_blob as torch._from_blob (#185850)

Expose `at::from_blob` to Python as a private function for advanced use
cases that need to wrap externally managed memory (e.g. from C libraries,
custom allocators, or device pointers) where the Python buffer protocol
isn't available. Unlike `torch.frombuffer`, this takes a raw integer
address and does not hold a reference to any Python object -- the caller
is responsible for keeping the underlying memory alive.

The function is deliberately not registered in native_functions.yaml
(no autograd or JIT semantics) and is underscore-prefixed to signal no
BC guarantees. There are absolutely no stability guarantees on this
function -- the signature, behavior, and existence of `_from_blob` may
change or be removed without notice in any future release.
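
The ownership caveat can be demonstrated with the standard library alone. The sketch below is an analogy built on `ctypes`, not a call to `torch._from_blob` (whose signature is not given here): it wraps memory from a raw integer address, and the resulting view holds no reference to the object that owns the memory.

```python
import ctypes

# The owner of the memory: a C array of four floats.
owner = (ctypes.c_float * 4)(1.0, 2.0, 3.0, 4.0)

# A raw integer address; it carries no reference to `owner`.
address = ctypes.addressof(owner)

# Re-wrap the same memory from the address alone. The view shares storage
# with `owner` but does not keep it alive.
view = (ctypes.c_float * 4).from_address(address)

view[0] = 10.0            # a write through the view...
assert owner[0] == 10.0   # ...is visible through the owner
owner[3] = -1.0           # and a write through the owner...
assert view[3] == -1.0    # ...is visible through the view
```

If `owner` were freed while `view` was still in use, reads through `view` would touch invalid memory, which is the same lifetime responsibility the commit message places on callers.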

Test Plan:

```
python -m pytest test/test_tensor_creation_ops.py -k TestFromBlob -v
```

Tests cover basic creation across numeric dtypes, 2D reshape, custom
strides, bidirectional shared memory verification, explicit device/dtype,
default dtype behavior, and invalid device rejection.

Authored with Claude.
Pull Request resolved: #185850
Approved by: https://github.com/zou3519

trunk/404390f51707db63fc0842d1163f99cd969f2044

Fix typos in comments and docstrings across torch (#185714)

Correct spelling and grammar errors in comments and docstrings across
64 files spanning dynamo, inductor, functorch, distributed, fx, nn,
autograd, JIT, quantization, and utility modules. These are
documentation-only changes with no functional impact.

Authored with Claude (typo_terminator2).

Pull Request resolved: #185714
Approved by: https://github.com/zou3519

trunk/136810cc9ce08999dc335f503bfe046d06100b74

Rename distributed collective ops to _single naming scheme (#186123)

Align the public torch.distributed collective APIs with the naming scheme
used by torchcomms' TorchCommBackend, where the single-tensor variants are
suffixed with `_single`. `all_gather_into_tensor` is renamed to
`all_gather_single` and `reduce_scatter_tensor` to `reduce_scatter_single`.

The previous names are kept as thin wrappers that delegate to the new
functions and are marked deprecated via FutureWarning, so existing code keeps
working. The new names are wired into the Dynamo / non-strict-export
collective remaps so they are traceable under torch.compile just like the old
names. Direct test usages are updated to the new names.
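
A deprecated thin wrapper of this kind might look roughly like the following. This is a simplified sketch: the function bodies and the two-argument signatures are placeholders for illustration and do not match the real `torch.distributed` signatures.

```python
import warnings


def all_gather_single(output, inputs):
    """Stand-in for the renamed collective; it only copies inputs into output."""
    output[:] = inputs
    return output


def all_gather_into_tensor(output, inputs):
    """Deprecated alias that warns and then delegates to the new name."""
    warnings.warn(
        "all_gather_into_tensor is deprecated; use all_gather_single instead.",
        FutureWarning,
        stacklevel=2,  # attribute the warning to the caller, not the wrapper
    )
    return all_gather_single(output, inputs)
```

Because the alias only forwards its arguments, existing callers get the same result plus a `FutureWarning` that names the replacement.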

To review, start with distributed_c10d.py (the rename and the deprecated
aliases), then the remap plumbing in _functional_collectives.py,
_dynamo/variables/functions.py and _export/non_strict_utils.py, then the docs
and test updates.

Authored by Claude.
Pull Request resolved: #186123
Approved by: https://github.com/tushar00jain, https://github.com/kapilsh

trunk/66530c8f420465e0406fbfe8bcd37f5e5bfbf325

Revert "Add _single ProcessGroup methods and deprecate the _base aliases (#186134)"

This reverts commit 97ee046.

Reverted #186134 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](#186125 (comment)))

trunk/9841bc06254fbebf8817514cc86a3637f6f5e013

[dynamo/tvm] Handle ImportError for TVM backend (#185893)

The `tvm` backend did a bare `import tvm`, so users without Apache TVM installed hit a cryptic `ModuleNotFoundError`. This wraps the imports in try/except and re-raises an `ImportError` pointing to the install docs.

Why: match the OpenXLA backend, which already does this; see `xla_backend_helper` in `torchxla.py`.

https://github.com/pytorch/pytorch/blob/cd2d0c4c2291d60ed0f0285f9d26f9ab00aa90bb/torch/_dynamo/backends/torchxla.py#L32-L38
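
The wrap-and-re-raise pattern can be sketched generically as follows. The helper name `load_optional_backend` and the message wording are assumptions for illustration; the actual change edits the TVM backend's own imports.

```python
import importlib


def load_optional_backend(module_name: str, install_hint: str):
    """Hypothetical helper: import an optional dependency or fail with guidance."""
    try:
        return importlib.import_module(module_name)
    except ImportError as exc:
        # ModuleNotFoundError is a subclass of ImportError, so it is caught here
        # and chained to a message that tells the user what to install.
        raise ImportError(
            f"The '{module_name}' backend requires the '{module_name}' package. "
            f"{install_hint}"
        ) from exc
```

Chaining with `from exc` keeps the original `ModuleNotFoundError` in the traceback while the top-level message points at the install instructions.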

Pull Request resolved: #185893
Approved by: https://github.com/jansel

trunk/9680b3aaf628b22b9162507ec972cabd1c8725bf

[AOTAutograd] Fix create_graph=True silently losing requires_grad on gradient outputs (#181606)

When `torch.autograd.grad(create_graph=True)` or
`.backward(create_graph=True)` is used on the output of a
torch.compiled function, the gradient tensors silently have
`requires_grad=False` and `grad_fn=None`. This breaks MAML,
WGAN-GP, and physics-informed neural networks that need
higher-order gradients.

Root cause: `autograd.Function.forward` runs in a no-grad context,
so tensors saved for backward lack `requires_grad`. When the backward
runs with grad enabled (from `create_graph=True`), the `needs_grad`
check only looked at saved tensors and found none requiring grad,
skipping the `_double_backward` wrapper entirely.

Two fixes:
1. Expand `needs_grad` to also check `metadata.input_info` — if
   the original forward inputs required grad and grad is currently
   enabled, the backward should go through `_double_backward`.
2. In `_double_backward`, when no `all_args` tensor requires grad
   (because saved tensors were detached), inject a dummy
   `requires_grad=True` tensor so `CompiledFunctionBackward.apply`
   produces outputs with `requires_grad=True` and a proper `grad_fn`.

After the fix, gradient outputs from `create_graph=True` correctly
have `requires_grad=True`. Attempting actual double backward still
raises "does not currently support double backward" — but now this
happens at the right time instead of silently producing wrong metadata.
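
The first fix can be modelled in a few lines of plain Python. This is only a sketch of the decision logic as described above: `TensorInfo`, `needs_grad_before`, and `needs_grad_after` are hypothetical names, and the exact form of the real check in AOTAutograd is an assumption.

```python
from dataclasses import dataclass
from typing import Sequence


@dataclass
class TensorInfo:
    """Hypothetical stand-in that records only whether a tensor requires grad."""
    requires_grad: bool


def needs_grad_before(saved: Sequence[TensorInfo], grad_enabled: bool) -> bool:
    # Earlier behaviour as described: only the saved tensors are consulted.
    return grad_enabled and any(t.requires_grad for t in saved)


def needs_grad_after(
    saved: Sequence[TensorInfo],
    forward_inputs: Sequence[TensorInfo],
    grad_enabled: bool,
) -> bool:
    # Expanded check: the original forward inputs count as well, because
    # tensors saved under no-grad have lost their requires_grad flag.
    return grad_enabled and (
        any(t.requires_grad for t in saved)
        or any(t.requires_grad for t in forward_inputs)
    )
```

With saved tensors that were detached by the no-grad forward, the earlier check returns `False` even under `create_graph=True`, while the expanded check returns `True` whenever a forward input required grad.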

Authored with Claude.
Pull Request resolved: #181606
Approved by: https://github.com/aorenste

trunk/3873f013fc214e04c31e91ea9a4b6ddba95999ae

[TMA] Skips the types that are not in _TMA_SUPPORTED_DTYPES when enabling TMA (#185223)

Fixes #185222
Dtypes that lack a CUtensorMapDataType enum entry (e.g. torch.bool / tl.int1) cannot be used with TMA tensor descriptors. This PR skips TMA codegen for such dtypes, falling back to regular block pointers instead of crashing at runtime.
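
The fallback decision can be sketched as a simple allow-list check. Everything in this example is illustrative: the set contents, the string dtype names, and the function `choose_load_strategy` are assumptions, not the actual `_TMA_SUPPORTED_DTYPES` or Inductor codegen.

```python
# Illustrative allow-list only; the real _TMA_SUPPORTED_DTYPES may differ.
TMA_SUPPORTED_DTYPES_SKETCH = frozenset({"float16", "bfloat16", "float32", "int32"})


def choose_load_strategy(dtype_name: str, tma_enabled: bool) -> str:
    """Hypothetical selector: use TMA only for dtypes on the allow-list."""
    if tma_enabled and dtype_name in TMA_SUPPORTED_DTYPES_SKETCH:
        return "tma_descriptor"
    # Unsupported dtypes (for example bool) fall back instead of failing later.
    return "block_pointer"
```

The point of the check is that the decision is made at codegen time, so an unsupported dtype never reaches the tensor descriptor path.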

Pull Request resolved: #185223
Approved by: https://github.com/jansel, https://github.com/mlazos