Thanks to visit codestin.com
Credit goes to github.com

Skip to content

[WIP]: Add polar method for complex tensors. #35563

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
Closed

[WIP]: Add polar method for complex tensors. #35563

wants to merge 6 commits into from

Conversation

kostekIV
Copy link
Contributor

Hi @dylanbespalko its still work in progress but I would like to ask for your opinion. I was trying to follow your instruction and started with method polar. So from input it is producing its polar form. Currently magnitude is stored as real value and angle as imaginary. First I would like to ask you if this is like you imagined it should be? Some problems that I run into and not sure how to tackle them. Polar form makes sense not only for complex type tensor but also for other like float or double but it is hard implement it because I understand it like that if tensor is computed by kernel function its dtype will match dtype of tensor that was used as input to kernel. I was thinking of casting such tensors to complex in file aten/src/ATen/native/UnaryOps.cpp before passing them to polar implementations. Also I had to add function polar in vec256_base.h which is just a placeholder to make it compile, not sure how to do it better here.

Related to #35312

@dr-ci
Copy link

dr-ci bot commented Mar 27, 2020

💊 Build failures summary and remediations

As of commit 5adb0b5 (more details on the Dr. CI page):


  • 2/2 failures introduced in this PR

🕵️ 2 new failures recognized by patterns

The following build failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_windows_vs2019_py36_cpu_test2 (1/2)

Step: "Test" (full log | pattern match details | 🔁 rerun) <confirmed not flaky by 2 failures>

AssertionError: Not within tolerance rtol=1e-07 atol=1e-07 at input[3] (0.0 vs. 3.141592653589793) and 1 other locations (40.00%)
  File "C:\Users\circleci\project\build\win_tmp\build\torch\testing\_internal\common_utils.py", line 930, in assertEqual 
    assertTensorsEqual(x, y) 
  File "C:\Users\circleci\project\build\win_tmp\build\torch\testing\_internal\common_utils.py", line 890, in assertTensorsEqual 
    atol=atol, rtol=rtol, message=message) 
  File "C:\Users\circleci\project\build\win_tmp\build\torch\testing\_internal\common_utils.py", line 930, in assertEqual 
    assertTensorsEqual(x, y) 
  File "C:\Users\circleci\project\build\win_tmp\build\torch\testing\_internal\common_utils.py", line 892, in assertTensorsEqual 
    torch.testing.assert_allclose(a, b, atol=atol, rtol=rtol, equal_nan=True, msg=message) 
  File "C:\Users\circleci\project\build\win_tmp\build\torch\testing\__init__.py", line 60, in assert_allclose 
    raise AssertionError(msg) 
AssertionError: Not within tolerance rtol=1e-07 atol=1e-07 at input[3] (0.0 vs. 3.141592653589793) and 1 other locations (40.00%) 
 
---------------------------------------------------------------------- 
Ran 5 tests in 1.159s 
 
FAILED (failures=2) 
 
Generating XML reports... 
Generated XML report: test-reports\python-unittest\TEST-TestComplexTensor-20200504182619.xml 
Traceback (most recent call last): 
  File "run_test.py", line 673, in <module> 

See CircleCI build pytorch_windows_vs2019_py36_cuda10.1_test2 (2/2)

Step: "Test" (full log | pattern match details | 🔁 rerun) <confirmed not flaky by 2 failures>

AssertionError: Not within tolerance rtol=1e-07 atol=1e-07 at input[3] (0.0 vs. 3.141592653589793) and 1 other locations (40.00%)
  File "C:\Users\circleci\project\build\win_tmp\build\torch\testing\_internal\common_utils.py", line 930, in assertEqual 
    assertTensorsEqual(x, y) 
  File "C:\Users\circleci\project\build\win_tmp\build\torch\testing\_internal\common_utils.py", line 890, in assertTensorsEqual 
    atol=atol, rtol=rtol, message=message) 
  File "C:\Users\circleci\project\build\win_tmp\build\torch\testing\_internal\common_utils.py", line 930, in assertEqual 
    assertTensorsEqual(x, y) 
  File "C:\Users\circleci\project\build\win_tmp\build\torch\testing\_internal\common_utils.py", line 892, in assertTensorsEqual 
    torch.testing.assert_allclose(a, b, atol=atol, rtol=rtol, equal_nan=True, msg=message) 
  File "C:\Users\circleci\project\build\win_tmp\build\torch\testing\__init__.py", line 60, in assert_allclose 
    raise AssertionError(msg) 
AssertionError: Not within tolerance rtol=1e-07 atol=1e-07 at input[3] (0.0 vs. 3.141592653589793) and 1 other locations (40.00%) 
 
---------------------------------------------------------------------- 
Ran 5 tests in 1.158s 
 
FAILED (failures=2) 
 
Generating XML reports... 
Generated XML report: test-reports\python-unittest\TEST-TestComplexTensor-20200504191906.xml 
Traceback (most recent call last): 
  File "run_test.py", line 673, in <module> 

This comment was automatically generated by Dr. CI (expand for details).Follow this link to opt-out of these comments for your Pull Requests.

Please report bugs/suggestions on the GitHub issue tracker.

See how this bot performed.

This comment has been revised 33 times.

@@ -129,6 +129,18 @@ template <> class Vec256<std::complex<double>> {
auto angle = _mm256_permute_pd(angle_(), 0x05); // angle 90-angle
return _mm256_and_pd(angle, real_mask); // angle 0
}
__m256d polar_() const {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The polar_() method is the key take-away here. It is essentially a private method for the Vec256 class.

Now we have a decision to make:

  1. We could modify the existing operator* and operator/ to convert to polar coordinates, perform the op, and then convert to back to cartesian coordinates. I think you will find that we waste too much time converting between polar and cart for a single math kernel. Note: You lose considerable performance by calling multiple AVX functions.
  2. Therefore, we might need to create a new dtype ComplexPolarFloat dtype so that we can keep the data in polar coordinates for multiple consecutive math ops. This will require significantly more work, but you are getting there.

Method 2 implies that we can have a different input dtype than the output dtype. PyTorch traditionally isn't setup to do this, but it has been attempted:
- People have tried to cast complex to float for the real(), imag(), abs(), angle() kernels. This is hard because its also changes the memory size.
- le kernel successfully casts the output byte data type to boolean (both are 1 byte in memory).
- #35524 is implementing the std::complex<T> data type internally in PyTorch. You could copy this code and create a std::complex_polar<T> C++ data type. Then you could do the same thing as the le_kernel and convert from std::complex<T> to std::complex_polar<T> when calling polar()

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In method 2 would we also need to add Vec256<std::complex_polar<T>>? I will look into le kernel, thanks for link.

Comment on lines +241 to +250
template <typename TYPE, std::enable_if_t<!c10::is_complex_t<TYPE>::value, int> = 0>
inline std::complex<TYPE> polar_impl (TYPE a) {
return std::complex<TYPE> (std::abs(a), std::arg(a));
}

template <typename TYPE, std::enable_if_t<c10::is_complex_t<TYPE>::value, int> = 0>
inline TYPE polar_impl (TYPE a) {
return TYPE(std::abs(a), std::arg(a));
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You might be able to use std::polar here, however there is no std::cart so you will need to do something like this anyways.

@dylanbespalko
Copy link
Contributor

dylanbespalko commented Mar 27, 2020

@kostekIV, @anjali411, @ezyang

LGTM. We should discuss the possibility of a std::complex_polar<T> and ScalarType::ComplexPolarFloat after #35524 is merged. I think the Vec256::polar() method will be a faster way to convert from ComplexFloat to ComplexPolarFloat than the built in to() function. Do you agree?

Comment on lines 319 to 326
Vec256<T> polar() const {
return *this;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, in a template class you need to re-implement all methods during specialization`.

I think the default case should return abs(x), not real(x).

If you are publicly adding a new kernel, you also need to specify the derivate so that Autograd works in derivatives.yaml. You would have to combine the derivative of abs() and angle() in some way.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, yes you're right. I didn't know about derivates.yaml, will look into it.

(0.0899 + 0.3232j),
(0.9718 + 0.1947j),
(0.2349 + 0.1463j)], dtype=torch.complex64)
>>> a.polar()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it make sense to represent polar representation of complex as a + bi?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably not, wasn't sure how to store it differently in complex tensor, I guess it may be resolved by ComplexPolarFloat dtype.

@kostekIV
Copy link
Contributor Author

@dylanbespalko I understand next steps to improve this solution are first create std::complex_polar<T> based on #35524 then implement Vec256 for it, add kernel polar which will return ComplexPolarFloat so under the hood it is std::complex_polar<T>, similarly for cart kernel. And wait with that until #35524 is merged and ComplexPolarFloat is agreed upon?

@dylanbespalko
Copy link
Contributor

@kostekIV,

Correct, you will likely need to submit a proposal to create:

  1. A ScalarType::ComplexPolarFloat PyTorch dtype
  2. A c10::complex_polar C++ type.

The use cases would be:

  1. Faster tensor multiplication (mul, div, ...)
  2. Faster matrix multiplication (mv, mm, ...)
  3. Can you think of more?

There will need to be a discussion because this could lead to Cylindrical coords, Spherical Coords, etc..

FYI, #11641 contains partial information on how the ComplexFloat dtype was added.

@kostekIV
Copy link
Contributor Author

@dylanbespalko

By proposal do you mean pull request or issue?
Similar to what you have already mentioned but also faster division of complex numbers, cant think of more uses straight from my head, I will try to dig into it a little more.

@dylanbespalko
Copy link
Contributor

@kostekIV

Sorry, I meant you should create a feature request. There are people at Fb that need to approve something with that large of a scope.

https://pytorch.org/docs/stable/community/contribution_guide.html

If I remember correctly, you create a new issue on GitHub and label as a Feature Request. Github should automatically apply this template

The people that made PyTorch will discuss it with you. They are usually the people that make big changes. I would also learn how to benchmark kernels

https://github.com/pytorch/pytorch/tree/master/benchmarks/operator_benchmark

Perhaps you could modify the mul kernel on your local machine to use polar coordinates:

  • How fast is mul in cartesian coordinates?
  • How fast is mul in polar coordinates?
  • What is the computational overhead of calling polar before mul?
  • Try a few tensor sizes.

@ezyang
Copy link
Contributor

ezyang commented Mar 30, 2020

@dylanbespalko lmk if you need help from me reviewing this

@ezyang
Copy link
Contributor

ezyang commented Mar 30, 2020

We should discuss the possibility of a std::complex_polar and ScalarType::ComplexPolarFloat after #35524 is merged

This seems reasonable, though whether or not we should invest in writing it (writing one of these classes is no small effort) depends on how serious we're going to go about polar.

@dylanbespalko
Copy link
Contributor

@kostekIV,

Looks like you have some CI failures specifically in the doc tests. Try running test/test_doc_coverage.py on your local machine. Some of the CI machines also failed to connect, so add another commit to restart the CI.

@dylanbespalko
Copy link
Contributor

@kostekIV

It looks like CI is failing on windows_vs2019

FAIL [0.004s]: test_polar_cpu_complex128 (__main__.TestTorchDeviceTypeCPU)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "C:\Users\circleci\project\build\win_tmp\build\torch\testing\_internal\common_device_type.py", line 224, in instantiated_test
    result = test(self, device_arg, dtype)
  File "C:\Users\circleci\project\build\win_tmp\build\torch\testing\_internal\common_device_type.py", line 431, in only_fn
    return fn(slf, device, *args, **kwargs)
  File "test_torch.py", line 15175, in test_polar
    self.assertEqual(polar_tensor, a.abs() + 1j * a.angle())
  File "C:\Users\circleci\project\build\win_tmp\build\torch\testing\_internal\common_utils.py", line 894, in assertEqual
    assertTensorsEqual(x, y)
  File "C:\Users\circleci\project\build\win_tmp\build\torch\testing\_internal\common_utils.py", line 860, in assertTensorsEqual
    self.assertLessEqual(max_err, prec, message)
AssertionError: tensor(1.5707, dtype=torch.float64) not less than or equal to 1e-05 : 

======================================================================
FAIL [0.004s]: test_polar_cpu_complex64 (__main__.TestTorchDeviceTypeCPU)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "C:\Users\circleci\project\build\win_tmp\build\torch\testing\_internal\common_device_type.py", line 224, in instantiated_test
    result = test(self, device_arg, dtype)
  File "C:\Users\circleci\project\build\win_tmp\build\torch\testing\_internal\common_device_type.py", line 431, in only_fn
    return fn(slf, device, *args, **kwargs)
  File "test_torch.py", line 15175, in test_polar
    self.assertEqual(polar_tensor, a.abs() + 1j * a.angle())
  File "C:\Users\circleci\project\build\win_tmp\build\torch\testing\_internal\common_utils.py", line 894, in assertEqual
    assertTensorsEqual(x, y)
  File "C:\Users\circleci\project\build\win_tmp\build\torch\testing\_internal\common_utils.py", line 860, in assertTensorsEqual
    self.assertLessEqual(max_err, prec, message)
AssertionError: tensor(1.5708) not less than or equal to 1e-05 : 

----------------------------------------------------------------------
Ran 4781 tests in 358.397s

FAILED (failures=2, skipped=201)

@kostekIV
Copy link
Contributor Author

kostekIV commented Apr 4, 2020

@kostekIV

It looks like CI is failing on windows_vs2019

FAIL [0.004s]: test_polar_cpu_complex128 (__main__.TestTorchDeviceTypeCPU)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "C:\Users\circleci\project\build\win_tmp\build\torch\testing\_internal\common_device_type.py", line 224, in instantiated_test
    result = test(self, device_arg, dtype)
  File "C:\Users\circleci\project\build\win_tmp\build\torch\testing\_internal\common_device_type.py", line 431, in only_fn
    return fn(slf, device, *args, **kwargs)
  File "test_torch.py", line 15175, in test_polar
    self.assertEqual(polar_tensor, a.abs() + 1j * a.angle())
  File "C:\Users\circleci\project\build\win_tmp\build\torch\testing\_internal\common_utils.py", line 894, in assertEqual
    assertTensorsEqual(x, y)
  File "C:\Users\circleci\project\build\win_tmp\build\torch\testing\_internal\common_utils.py", line 860, in assertTensorsEqual
    self.assertLessEqual(max_err, prec, message)
AssertionError: tensor(1.5707, dtype=torch.float64) not less than or equal to 1e-05 : 

======================================================================
FAIL [0.004s]: test_polar_cpu_complex64 (__main__.TestTorchDeviceTypeCPU)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "C:\Users\circleci\project\build\win_tmp\build\torch\testing\_internal\common_device_type.py", line 224, in instantiated_test
    result = test(self, device_arg, dtype)
  File "C:\Users\circleci\project\build\win_tmp\build\torch\testing\_internal\common_device_type.py", line 431, in only_fn
    return fn(slf, device, *args, **kwargs)
  File "test_torch.py", line 15175, in test_polar
    self.assertEqual(polar_tensor, a.abs() + 1j * a.angle())
  File "C:\Users\circleci\project\build\win_tmp\build\torch\testing\_internal\common_utils.py", line 894, in assertEqual
    assertTensorsEqual(x, y)
  File "C:\Users\circleci\project\build\win_tmp\build\torch\testing\_internal\common_utils.py", line 860, in assertTensorsEqual
    self.assertLessEqual(max_err, prec, message)
AssertionError: tensor(1.5708) not less than or equal to 1e-05 : 

----------------------------------------------------------------------
Ran 4781 tests in 358.397s

FAILED (failures=2, skipped=201)

I am looking into it, it might take some time for me as I dont have machine with windows, and on mac or linux I didn't succeed in reproducing it yet.

@dylanbespalko
Copy link
Contributor

@kostekIV,

The test case looks really good. I'm a little nervous about 1j * a.angle(). Obviously that should work, but you may me triggering an unrelated bug. You could try and create a tensor complex_j = torch.tensor([0 + 1j], dtype).

I also found out that they want to isolate the complex number test cases in tests/test_complex.py

I have had Windows failures int the past. I fix them by making changes and re-running CI. Not elegant, but it happens a lot.

@dylanbespalko
Copy link
Contributor

I've added #36029 to track any issues with complex unit testing.

@anjali411 anjali411 added the module: complex Related to complex number support in PyTorch label Apr 7, 2020
@anjali411 anjali411 self-requested a review April 7, 2020 16:51
@kostekIV
Copy link
Contributor Author

kostekIV commented May 2, 2020

@dylanbespalko @anjali411 Sorry for lack of word/work past few weeks. I am wondering what I should do with this pr now. Since polar complex most likely won't be a thing in near future should I proceed with cart function in similar format as is here with polar - so comlex input to complex output? Or maybe some other format, like input 2 float tensors output one complex for cartand complex input to 2 float tensors forpolar` function? I've tried to do something like it but with no luck so far.

@dylanbespalko
Copy link
Contributor

@kostekIV

Test Failures

In you test case, you have a line self.assertEqual(polar_tensor, a.abs() + 1j * a.angle()). Based on the error messages, a.abs() + 1j * a.angle() is triggering a bug in PyTorch that is unrelated to your polar() function. Maybe try hardcoding the expected answer in to the second argument of assertEqual. Also, I think assertEqual is for comparing integer numbers, not floating point.

Other things to try:

  • test_torch.py uses torch.isclose(x, y, abs_tol, rel_tol). I think this function is undocumented and has a history of problems with complex numbers. Update to the latest code to use this.
  • You can also switch to numpy (tensor.detach().numpy()) and use assert_allclose(x, y, abs_tol, rel_tol) from the numpy.test package.
  • I experienced a lot of problems on windows with std::complex because they use I in the preprocessor to define std::complex<T>(0, 1). That can produce naming conflicts with variables called I. Don't try to solve these bugs in this PR, they are painful.

Is polar() useful?

Well, the problem with creating a complex_polar type was that the hw in the CPU and GPU are optimized for cartesian coordinates and you can't overcome that. This polar() function allows you solve a math problem that is better defined in polar coordinates, while re-using the cartesian math functions. I know there are Calculus problems that would benefit from this. Don't worry if the print function doesn't look right, you can always hack that in Python if you need to.

Some future motivation

What about cylindrical coordinates, (p, phi, z)?
What about spherical coordinates (r, phi, theta)?

I have done some recent work on the FPGA where std::complex<T> doesn't work. This allowed me to write a template class Vec where I can use float[2], float[4], float[8], to define functions in R^2, R^4, R^8 spaces. In the long-run, we need to replace std::complex with float[2] to make things easier.

@dylanbespalko
Copy link
Contributor

@kostekIV,

My apologies, I don't think you can simply re-use the cartesian functions. They need to be redefined for polar coordinates. It's really hard to get these things going.

@facebook-github-bot
Copy link
Contributor

Hi @kostekIV!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

@github-actions
Copy link
Contributor

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label May 28, 2022
@github-actions github-actions bot closed this Jun 27, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
module: complex Related to complex number support in PyTorch open source Stale
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants