"clamp_min_cpu" not implemented for 'ComplexDouble' #73915

Closed
Fun772283153 opened this issue Mar 8, 2022 · 14 comments
Labels
module: complex (Related to complex number support in PyTorch) · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@Fun772283153

Fun772283153 commented Mar 8, 2022

πŸ› Describe the bug

import torch

a = torch.zeros([1], dtype=torch.complex128)
a[0] = -3.2427e-04 + 5.8708e-03j
b = torch.zeros([0], dtype=torch.complex128)
print(a)
print(torch.clamp_min(a, b))  # RuntimeError: "clamp_min_cpu" not implemented for 'ComplexDouble'

Versions

PyTorch version: 1.10.2
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 12.2.1 (x86_64)
GCC version: Could not collect
Clang version: 13.0.0 (clang-1300.0.29.30)
CMake version: Could not collect
Libc version: N/A

Python version: 3.9.7 (default, Sep 16 2021, 08:50:36) [Clang 10.0.0 ] (64-bit runtime)
Python platform: macOS-10.16-x86_64-i386-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.22.2
[pip3] numpydoc==1.1.0
[pip3] torch==1.10.2
[pip3] torchaudio==0.10.2
[pip3] torchvision==0.11.3
[conda] blas 1.0 mkl
[conda] ffmpeg 4.3 h0a44026_0 pytorch
[conda] mkl 2021.4.0 hecd8cb5_637
[conda] mkl-service 2.4.0 py39h9ed2024_0
[conda] mkl_fft 1.3.1 py39h4ab4a9b_0
[conda] mkl_random 1.2.2 py39hb2f4e1b_0
[conda] mypy_extensions 0.4.3 py39hecd8cb5_0
[conda] numpy 1.22.2 pypi_0 pypi
[conda] numpydoc 1.1.0 pyhd3eb1b0_1
[conda] pytorch 1.10.2 py3.9_0 pytorch
[conda] torchaudio 0.10.2 py39_cpu pytorch
[conda] torchvision 0.11.3 py39_cpu pytorch

cc @ezyang @anjali411 @dylanbespalko @mruberry @lezcano @nikitaved

@bdhirsh added the module: complex and triaged labels Mar 8, 2022
@ngimel
Collaborator

ngimel commented Mar 9, 2022

Comparison ops are not implemented for complex numbers; this is expected behavior, see #36444 (comment).

@ngimel ngimel closed this as completed Mar 9, 2022
@lezcano
Collaborator

lezcano commented Mar 9, 2022

In my opinion, this operation makes sense, and it could potentially be performed on the real and imaginary parts independently.

clamp is an operation that's used for stability, to clamp values that are too small or too large. "Too small" or "too large" can be understood in a metric sense (complex numbers do have a metric) as "I don't want to have any value whose norm is larger/smaller than X". The proposal above uses a variation of this idea via the 1-norm.

I wanted to use this operation for the backward of the LU decomposition, and in the end we had to work around it by implementing this operation manually.

While the general rule of "we don't support comparison operations for complex numbers" is indeed the correct one, we can consider implementing some operations that are not so "mathematical" in appropriate ways, as long as we document this behaviour accordingly.
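For concreteness, a per-component clamp along these lines could look roughly like the sketch below (clamp_min_complex is a made-up helper for illustration, not an existing PyTorch API, and not necessarily what was used for the LU backward):

import torch

def clamp_min_complex(x, min_val):
    # Illustrative sketch only: clamp the real and imaginary parts
    # independently, one of the interpretations discussed above.
    return torch.complex(x.real.clamp_min(min_val.real),
                         x.imag.clamp_min(min_val.imag))

a = torch.tensor([-3.2427e-04 + 5.8708e-03j], dtype=torch.complex128)
print(clamp_min_complex(a, 0 + 0j))  # tensor([0.0000+0.0059j], dtype=torch.complex128)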

@anjali411 wdyt?

@ezyang
Contributor

ezyang commented Mar 9, 2022

Does numpy implement clamp in this way? :)

@Fun772283153
Author

Here is the problem: I have some data whose type is complex, and I want to train a network on these data, as I wrote below.

[screenshot of the training code]

There is a runtime error when I run this code; I don't understand why the clamp_min function gets called in the forward pass.

@Fun772283153
Author

It seems that NumPy doesn't implement clamp for complex doubles either.

@ngimel
Collaborator

ngimel commented Mar 10, 2022

What semantics do you expect from the relu operation on complex inputs?

@Fun772283153
Author

What semantics do you expect from the relu operation on complex inputs?

Oh, I see; but does PyTorch have any non-linear activation function for complex inputs?

@lezcano
Collaborator

lezcano commented Mar 10, 2022

Any non-linear activation that does not depend on min/max certainly works for complex numbers (i.e., pretty much all of them except those with cut-offs and those that are variations of relu).

@SantaTitular

How do you work around this issue? By splitting into real and imaginary parts?

@mruberry
Collaborator

Clamping the real and imaginary parts of a complex tensor separately is possible; one option for "clamping" complex numbers is to compute a complex number with the same angle/phase but scaled magnitude. In this case, however, a hypothetical clamp_magnitude would accept a minimum and maximum magnitude.
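A rough sketch of what such a hypothetical clamp_magnitude could look like (not an existing PyTorch API), assuming the phase is kept and only the magnitude is clamped:

import torch

def clamp_magnitude(x, min=None, max=None):
    # Hypothetical helper: keep each element's phase, clamp its magnitude.
    return torch.polar(x.abs().clamp(min=min, max=max), x.angle())

x = torch.randn(3, dtype=torch.cfloat)
y = clamp_magnitude(x, min=0.5, max=2.0)  # |y| lies in [0.5, 2.0], phase unchanged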

@anjali411
Contributor

anjali411 commented Jul 12, 2022

@SantaTitular if you want to clamp the absolute value, then this is probably the recommended way.

a = torch.randn(4, dtype=torch.cfloat)
angle_ = a.angle()                           # phase of each element
clamped_abs = torch.clamp(a.abs(), min=...)  # clamp the magnitude
b = torch.polar(clamped_abs, angle_)         # rebuild the complex tensor

If you want to clamp the real or imag value for some reason, then that can easily be done by accessing the real and imag attributes on the tensor (they are views!) and clamping them.
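For example, something along these lines (a sketch; the in-place variant relies on .real and .imag being views into the complex tensor):

import torch

a = torch.randn(4, dtype=torch.cfloat)

# In-place: .real and .imag are views, so clamping them modifies `a` directly.
a.real.clamp_(min=0.0)
a.imag.clamp_(max=1.0)

# Out-of-place alternative that builds a new complex tensor.
b = torch.complex(a.real.clamp(min=0.0), a.imag.clamp(max=1.0))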

@SantaTitular

SantaTitular commented Jul 12, 2022

Thank you for the fast response @mruberry @anjali411 .

Indeed, both your ideas would work nicely (I'll use them as a reference for future applications). For my case, the application was a complex ReLU from Deep Complex Networks, similar to @Fun772283153. I was testing a simple MLP (below) and I think it is a nice workaround; I just have to get the loss function (BCE, CrossEntropyLoss) to accept complex numbers to test it out (@mruberry, not sure if you know something about it 😅; I already looked into #46642).

class Feedforward(torch.nn.Module):
    def __init__(self, input_size, hidden_size):
        super(Feedforward, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.fc1 = torch.nn.Linear(self.input_size, self.hidden_size).to(torch.complex128)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(self.hidden_size, 1).to(torch.complex128)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        hidden = self.fc1(x)
        # split ReLU: apply ReLU to the real and imaginary parts separately
        relu = self.relu(hidden.real) + 1j * self.relu(hidden.imag)
        output = self.fc2(relu)
        output = self.sigmoid(output)
        return output
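The split ReLU above could also be factored into a small reusable module, roughly along the lines of the CReLU variant from the Deep Complex Networks paper (a sketch; the class name and usage are just one way to organize it):

import torch

class CReLU(torch.nn.Module):
    # Applies ReLU to the real and imaginary parts separately.
    def forward(self, x):
        return torch.complex(torch.relu(x.real), torch.relu(x.imag))

# Usage in the model above: self.relu = CReLU(), then relu = self.relu(hidden)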

@mruberry
Collaborator

@anjali411 is definitely our expert on complex losses -- in fact she's preparing a blog post that includes a complex neural network now!

@SantaTitular

Damn, that's great news! I'm very interested in reading and commenting on it. @anjali411, please let me know when it is available (not sure if there is a way on GitHub to track/read about new blog posts).
