Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Backward function for torch.cdist#17173

Closed
ifedan wants to merge 6 commits into
pytorch:masterfrom
ifedan:cdist_backward_function
Closed

Backward function for torch.cdist#17173
ifedan wants to merge 6 commits into
pytorch:masterfrom
ifedan:cdist_backward_function

Conversation

@ifedan
Copy link
Copy Markdown
Contributor

@ifedan ifedan commented Feb 15, 2019

No description provided.

Copy link
Copy Markdown
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ifedan has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Copy link
Copy Markdown
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ifedan has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

_(aten, _pad_packed_sequence) \
_(aten, _pdist_backward) \
_(aten, _pdist_forward) \
_(aten, _cdist_backward) \
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

don't think you actually need this.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed

Comment thread test/test_autograd.py Outdated
f_args_tensor = deepcopy(unpack_variables(f_args_variable))
run_functional_checks(self, "test_cdist", "cdist", f,
True, f_args_variable, f_args_tensor)
self.assertTrue(gradcheck(f, f_args_variable, eps=1e-6, atol=PRECISION))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you don't need this, it's already checked in run_functional_checks, it's just there in the above test as explained in the comments.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed

REGISTER_DISPATCH(cdist_backward_stub, &cdist_backward_kernel_impl);

}} // at::native
}} // at::native No newline at end of file
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I think there should be a newline here.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

Comment thread test/test_autograd.py
False, f_args_variable, f_args_tensor)
self.assertTrue(gradcheck(lambda a, b: torch.cat((a, b)), f_args_variable, eps=1e-6, atol=PRECISION))

def test_cdist(self):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you move this above or below all the cat tests?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

Comment thread test/test_autograd.py
f = lambda a, b: torch.cdist(a, b, p)
f_args_tensor = deepcopy(unpack_variables(f_args_variable))
run_functional_checks(self, "test_cdist", "cdist", f,
True, f_args_variable, f_args_tensor)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm...this exposed a pretty serious bug in our testing. This test is running double backwards tests and passing, but it shouldn't be because you didn't write a double backwards function.

I believe what is happening is this:

  1. because you don't have an entry in derivatives.yaml for your backwards function (i.e. a double backwards definition), it is assumed to be composed of differentiable functions
  2. it's not actually composed of differentiable functions
  3. the testing doesn't catch this because it only looks at differentiable outputs. But your function doesn't create any differentiable outputs.

This should be testable, though. Basically, this should only be valid if the numerical gradient of the function is 0. Can you add checks for that? (You should probably do it in a separate commit, there may be other functions that exhibit the same problem).

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added test to EXCLUDE_GRADGRADCHECK_BY_TEST_NAME list

Copy link
Copy Markdown
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ifedan has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Copy link
Copy Markdown
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ifedan has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Copy link
Copy Markdown
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ifedan has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Copy link
Copy Markdown
Contributor

@gchanan gchanan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you fix the underlying issue with gradgradcheck? Which is basically that if none of the outputs require_grad, we don't actually check them. We should instead check that their numerical gradients are 0.

Copy link
Copy Markdown
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ifedan has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@gchanan gchanan closed this Mar 15, 2019
@gchanan gchanan reopened this Mar 15, 2019
Copy link
Copy Markdown
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ifedan has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Copy link
Copy Markdown
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ifedan is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Copy link
Copy Markdown
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ifedan is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Copy link
Copy Markdown
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ifedan is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

zdevito pushed a commit to zdevito/ATen that referenced this pull request Mar 21, 2019
Summary: Pull Request resolved: pytorch/pytorch#17173

Differential Revision: D14111482

Pulled By: ifedan

fbshipit-source-id: d72cfd53c29d0f8cf5f8ad1148d14f3d5abd938e
@ezyang
Copy link
Copy Markdown
Contributor

ezyang commented Mar 21, 2019

This diff fails lint

laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 24, 2026
Summary: Pull Request resolved: pytorch#17173

Differential Revision: D14111482

Pulled By: ifedan

fbshipit-source-id: d72cfd53c29d0f8cf5f8ad1148d14f3d5abd938e
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants