Fix dense Embedding to work with double backward #9078

Closed
kshitij12345 wants to merge 9 commits into pytorch:master from kshitij12345:master

@kshitij12345
Collaborator

Fixes : #6469

  1. ATen/native/native_functions.yaml had dispatch variants for embedding_dense_backward; however, embedding_backward called it explicitly, which led to the error.

  2. For CUDA tensors, the function used to crash when dereferencing the data pointer of indices.

Both issues have been fixed and checked (on CUDA and CPU) against:

  1. As mentioned in the issue
import torch

class Test(torch.nn.Module):
    
    def __init__(self):
        super(Test,self).__init__()
        self.embd = torch.nn.Embedding(1000, 100)
        self.dense = torch.nn.Linear(100, 1)
    
    def forward(self, inp):
        inp = self.embd(inp)
        return self.dense(inp)

test = Test()
#test.cuda()
inp = torch.tensor([0,1,2,1,1])
out = test(inp)
raw_loss = out.mean(dim=0)

loss_grad = torch.autograd.grad(outputs=raw_loss,
                         inputs=list(test.parameters()),
                         retain_graph=True, create_graph=True, only_inputs=True)
norm = sum([param.norm()**2 for param in loss_grad])
loss = raw_loss + norm

loss.backward(retain_graph=True)

print(test.embd.weight.grad)

  2. Test Script
import torch
import time
start = time.time()
l = [1,1]*100 
input = torch.tensor([[1,0],[1,0]],device='cpu')
embedding_matrix = torch.tensor([[1.0,3.0],[2.0,4]],requires_grad=True,device='cpu')

sq = embedding_matrix * embedding_matrix
emb = torch.nn.functional.embedding(input, sq,scale_grad_by_freq=False)

print('Embedding Matrix')
print(embedding_matrix)
print('-----------------')

#prod = torch.cumprod(emb,1)
sum_ = emb.sum()#prod.sum()

loss_grad, = torch.autograd.grad(outputs=sum_,inputs=embedding_matrix,create_graph=True)

print('Gradient')
print(loss_grad)
print('-----------------')

sum2_ = sum_ + loss_grad.sum()
print(sum2_)
sum2_.backward()

print(embedding_matrix.grad)
print(time.time() - start)
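The two scripts above check double backward by hand. As a complementary sketch (not part of the PR description), the same property can be checked numerically with `torch.autograd.gradcheck` and `torch.autograd.gradgradcheck`. The tensor sizes and the helper name `embed` are illustrative assumptions, and the check assumes a PyTorch build in which dense embedding supports double backward.

```python
import torch
import torch.nn.functional as F
from torch.autograd import gradcheck, gradgradcheck

torch.manual_seed(0)
indices = torch.tensor([[1, 0], [1, 0]])
# Double precision keeps the finite-difference comparison meaningful.
weight = torch.randn(2, 2, dtype=torch.double, requires_grad=True)

def embed(w):
    # Hypothetical helper: a dense embedding lookup as a function of the weight only.
    return F.embedding(indices, w)

# First-order check: analytic gradient vs. finite differences.
first_order_ok = gradcheck(embed, (weight,))
# Second-order check: differentiates through the backward pass itself.
second_order_ok = gradgradcheck(embed, (weight,))
print(first_order_ok, second_order_ok)
```

Both calls raise an error describing the mismatch if a check fails, so a clean run is itself the signal.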

@soumith changed the title from "fix embedding_dense_backward (#6469)" to "Fix dense Embedding to work with double backward" on Jul 1, 2018
@ssnl
Collaborator

ssnl commented Jul 1, 2018

shouldn't the solution be adding a custom double backward, rather than slowing down backward with autograd ops?

Member

@colesbury colesbury left a comment

What @ssnl said. This looks like it will slow down embedding significantly. We should just define a derivative for embedding_backward.

@kshitij12345
Collaborator Author

Sure, I'll try that and update once done.

@weiyangfb
Contributor

weiyangfb commented Jul 10, 2018

@kshitij12345 this needs a rebase now

@kshitij12345
Collaborator Author

@weiyangfb sure will do that.

@t-vi
Collaborator

t-vi commented Jul 12, 2018

Wouldn't it still be better to have a fused index_add_ + mul for the backward than implementing a specific double backward? I'd think that it is probably a bit less code and more similar to what we do for other ops.
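For context, a minimal sketch of what an index_add_-based dense backward computes: each row of the output gradient is scatter-added into the weight row it was read from. The sizes and variable names are illustrative assumptions. The "mul" part of the suggestion (presumably the scaling used for scale_grad_by_freq) is left out, and this is not the fused kernel under discussion.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_embeddings, dim = 5, 3
indices = torch.tensor([0, 1, 2, 1, 1])
weight = torch.randn(num_embeddings, dim, requires_grad=True)

out = F.embedding(indices, weight)
grad_out = torch.randn_like(out)

# Reference gradient as computed by autograd's own embedding backward.
(reference,) = torch.autograd.grad(out, weight, grad_out)

# Manual backward: scatter-add each output-gradient row into the row it came from.
manual = torch.zeros(num_embeddings, dim)
manual.index_add_(0, indices, grad_out)

matches = torch.allclose(manual, reference)
print(matches)
```

Because index_add_ is itself differentiable, a backward written this way would get double backward from autograd without extra code, which is the appeal of the approach.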

@kshitij12345
Collaborator Author

@ssnl I have tried, and here is my opinion from what I understand:

In the derivatives.yaml,

- name: embedding(Tensor weight, Tensor indices, int64_t padding_idx, bool scale_grad_by_freq, bool sparse)
  weight: embedding_backward(grad, indices, weight.size(0), padding_idx, scale_grad_by_freq, sparse)

- name: _embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int64_t mode, bool sparse)
  weight: _embedding_bag_backward(grad, indices, offsets, result1, result2, result3, weight.size(0), scale_grad_by_freq, mode, sparse)

We'll need to pass in weight as an argument to embedding_backward for the embedding_double_backward to update gradient on weight, but that will also require minor changes in embedding_bag_sparse_backward as it calls embedding_backward.

Also embedding_backward calls embedding_dense_backward and embedding_sparse_backward. Thus each will need its own version of double backward.

So I believe that as @t-vi suggested, we should opt for fused index_add_ + mul.
Please let me know what you think.
Thank You.

@ssnl
Collaborator

ssnl commented Jul 15, 2018

@t-vi @kshitij12345 I'm fine if you want to implement fused index_add_ + mul and also write a backward for that. But that would be considerably more work than just writing a custom double backward for this.

@kshitij12345
Collaborator Author

Oh, I will take a look and update you on it.

@kshitij12345
Collaborator Author

I installed pytorch from two source codes:

Commit 6dcaa47 from your repo (kshitij12345/pytorch:master)
Commit ae1a972 from original pytorch:master (Which is the last commit your repo has from the pytorch:master)
So, there shouldn't be any problem with the single backward call for both. But I get the mentioned error from the former one while the latter one was fine.

It seems that your commit causes this error (btw, I'm not sure). I'll double check and will have a look and let you know if I found anything.

In that case can you please share a minimal code snippet that produces the given error, so even I can check.
Thank You.

@pooryapzm

I am trying to write a minimal code to reproduce the given error.
Meanwhile, you can simply reproduce the error by training a simple model with OpenNMT. If you train the very example in their README, you will get the same error (there is no need for double gradients there).
As far as I understand, the issue is with the (single) backward of the embedding table in both encoder and decoder.
https://github.com/OpenNMT/OpenNMT-py

@pooryapzm

pooryapzm commented Sep 26, 2018

Here is the minimal code to replicate the error. Sorry, but you need OpenNMT to import Embeddings.

import torch

from onmt.modules.embeddings import Embeddings

class Test(torch.nn.Module):

    def __init__(self):
        super(Test, self).__init__()
        self.dense = torch.nn.Linear(100, 1)
        self.oembd = Embeddings(word_vec_size=100,
                   position_encoding=False,
                   dropout=0.3,
                   word_padding_idx=1,
                   word_vocab_size=1000)

    def forward(self, inp):
        inp = self.oembd(inp)
        return self.dense(inp)

test = Test()
# test.cuda()
inp = torch.tensor([1, 1, 2, 1, 1])
inp=inp.unsqueeze(0).unsqueeze(-1)
out = test(inp)
raw_loss = out.mean(dim=1)

loss_grad = torch.autograd.grad(outputs=raw_loss,
                                inputs=list(test.parameters(recurse=False)),
                                retain_graph=True, create_graph=True, only_inputs=True)

If you run this code with your version of PyTorch (I mean commit 6dcaa47), you'll see the error, while the original version (ae1a972) doesn't raise an error.

@kshitij12345
Collaborator Author

kshitij12345 commented Sep 26, 2018

@pooryapzm Indeed, the problem was in my part of the code. Sorry. I have fixed it. Please let me know if it works for you as well. I have checked the following code:

import torch
torch.manual_seed(6)

from onmt.modules.embeddings import Embeddings

class Test(torch.nn.Module):

    def __init__(self):
        super(Test, self).__init__()
        self.dense = torch.nn.Linear(100, 1)
        self.oembd = Embeddings(word_vec_size=100,
                   position_encoding=False,
                   dropout=0.3,feat_merge=None,
                   word_padding_idx=1,
                   word_vocab_size=1000)

    def forward(self, inp):
        inp = self.oembd(inp)
        
        return self.dense(inp[0])

test = Test()
test.cuda()
inp = torch.tensor([1, 1, 2, 1, 1, 2],device='cuda')
inp=inp.unsqueeze(0).unsqueeze(-1)
out = test(inp)
raw_loss = out.mean(dim=1)
loss_grad = torch.autograd.grad(outputs=raw_loss,
                                inputs=list(test.parameters()),
                                retain_graph=True, create_graph=True, only_inputs=True)

norm = sum([param.norm()**2 for param in loss_grad])

loss = raw_loss + norm

loss.backward()
print("Successful")

@pooryapzm

pooryapzm commented Oct 1, 2018

@kshitij12345 Awesome. It works with my code as well.
Thanks very much.

@t-vi
Collaborator

t-vi commented Oct 14, 2018

I'm afraid I have to own up to the fact that the indexAdd approach isn't good after all.
There is a pretty serious drawback: dense embedding's backward is currently deterministic (in particular on CUDA), and it would become non-deterministic when switched to indexAdd.
As such, I have come to the conclusion that the best way to fix this is to implement a double backward explicitly.
@kshitij12345 I'm sorry about putting you on a wrong track here.
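To illustrate what an explicit double backward amounts to, here is a Python sketch using `torch.autograd.Function`. The class name `EmbeddingDenseBackward` and its argument order are hypothetical, and scale_grad_by_freq is ignored. The forward uses index_add_ only for brevity, so it does not reproduce the deterministic kernel. The key observation is that the dense backward is a scatter-add over grad_output, so its derivative with respect to grad_output is a gather over the same indices, with padding positions zeroed.

```python
import torch

class EmbeddingDenseBackward(torch.autograd.Function):
    """Hypothetical stand-in for the dense embedding backward, with its own backward."""

    @staticmethod
    def forward(ctx, grad_output, indices, num_weights, padding_idx):
        ctx.save_for_backward(indices)
        ctx.padding_idx = padding_idx
        flat_idx = indices.reshape(-1)
        flat_grad = grad_output.reshape(-1, grad_output.size(-1))
        # Positions that hit padding_idx contribute nothing to grad_weight.
        keep = (flat_idx != padding_idx).unsqueeze(-1).to(flat_grad.dtype)
        grad_weight = flat_grad.new_zeros(num_weights, flat_grad.size(-1))
        grad_weight.index_add_(0, flat_idx, flat_grad * keep)
        return grad_weight

    @staticmethod
    def backward(ctx, gg_weight):
        (indices,) = ctx.saved_tensors
        flat_idx = indices.reshape(-1)
        # Double backward is a gather: each position reads the row it contributed to.
        gathered = gg_weight.index_select(0, flat_idx)
        gathered = gathered.masked_fill((flat_idx == ctx.padding_idx).unsqueeze(-1), 0)
        # One gradient per forward input; the non-tensor inputs get None.
        return gathered.view(*indices.shape, gg_weight.size(-1)), None, None, None

torch.manual_seed(0)
indices = torch.tensor([[0, 1, 2], [1, 1, 3]])
grad_output = torch.randn(2, 3, 4, dtype=torch.double, requires_grad=True)

# Numerically verify the hand-written backward; index 3 plays the role of padding_idx.
double_backward_ok = torch.autograd.gradcheck(
    lambda g: EmbeddingDenseBackward.apply(g, indices, 5, 3), (grad_output,)
)
print(double_backward_ok)
```

Passing -1 as padding_idx in this sketch disables the masking, since no valid index equals -1.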

@kshitij12345
Collaborator Author

@t-vi Oh, no worries, I'll try the double backward approach. If I have any doubts or get stuck, I will ask for your help. Thank you again for guiding me.

However, I am confused about what you meant by non-deterministic. If you mean non-determinism in time, the benchmark shows that the running time for embedding with index_add_ has a lower standard deviation, on both CUDA and CPU. Just curious to understand and know more.

@t-vi
Collaborator

t-vi commented Oct 16, 2018

index_add_ does not produce deterministic results, as the order of addition is unspecified (see #12217 for a tiny bit more context).
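The underlying reason is that floating-point addition is not associative, so accumulating the same values in a different order can change the result. A tiny pure-Python sketch of that effect (the values are chosen only for illustration and have nothing to do with the PR's kernels):

```python
# The same three values, summed in two different orders.
values = [0.1, 0.2, 0.3]

left_to_right = (values[0] + values[1]) + values[2]
right_to_left = values[0] + (values[1] + values[2])

# The two results differ in the last bit even though the inputs are identical.
orders_agree = left_to_right == right_to_left
print(left_to_right, right_to_left, orders_agree)
```

When many threads accumulate into the same row with an unspecified order, as a parallel scatter-add may do, this kind of difference can vary from run to run.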

Comment thread aten/src/ATen/native/Embedding.cpp Outdated

This comment was marked as off-topic.

@kshitij12345
Collaborator Author

@colesbury , please have a look.

@t-vi
Collaborator

t-vi commented Oct 25, 2018

That looks simple enough for me to be a bit embarrassed to have suggested the weighted index_add_ first.
You're the better expert for embeddings than me, but:

  • Would having a second derivative also make sense for sparse?
  • We are certain that embeddings are always vectors, right? If not, it might make sense to pass the weight sizes to the double backward instead of deriving the shape from the indices there. (But I don't know if that would be good or not.)

This comment was marked as off-topic.

@kshitij12345
Collaborator Author

kshitij12345 commented Oct 27, 2018

Even I am embarrassed, for my first attempt at contributing here is taking so long with all these mistakes (also because, in my excitement, I started working on the master branch and sent the PR from master).

As for the second point, in the little experience I have, embeddings have always been vectors in all the courses that I have taken and the literature that I have seen. (It would be interesting to see if that is not always the case.)

@kshitij12345
Collaborator Author

@colesbury @ssnl please review.

Comment thread test/test_nn.py Outdated
Member

can we add a few more tests with other parameters being tested as well, like padding_idx, max_norm, etc?

Collaborator Author

There already are tests for those parameters. Will add a test for double_backward in padding_idx.

Member

I meant a test here for double backwards

Collaborator Author

@kshitij12345 kshitij12345 Dec 13, 2018

From what I see here, the normalization is independently applied before embedding. So I believe the test for max_norm should be independent like it is. Do let me know if I am missing something.

As for the padding_idx, I have extended the already present test to check for double backwards as well. Please review.

Thank You.
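A sketch of what a double-backward check with padding_idx could look like (this is not the test added to test/test_nn.py; the sizes, the squared loss, and the gradient-norm penalty are assumptions chosen so that the second backward actually depends on the weight). It assumes a PyTorch build in which dense embedding supports double backward.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
padding_idx = 0
indices = torch.tensor([0, 1, 2, 1])
weight = torch.randn(4, 3, dtype=torch.double, requires_grad=True)

out = F.embedding(indices, weight, padding_idx=padding_idx)
# A squared loss makes the first gradient depend on the weight itself.
(grad,) = torch.autograd.grad(out.pow(2).sum(), weight, create_graph=True)

# Penalise the gradient norm, then backpropagate through the first backward.
penalty = grad.pow(2).sum()
penalty.backward()

# The padding row should receive no gradient from either backward pass.
padding_row_is_zero = bool(torch.all(weight.grad[padding_idx] == 0))
print(padding_row_is_zero)
```

Row 1 is looked up twice, so its first gradient is 4 * weight[1] and the penalty gradient for that row works out to 32 * weight[1], which gives a simple closed form to compare against.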

@kshitij12345
Collaborator Author

@colesbury , please review

@soumith
Collaborator

soumith commented Mar 29, 2019

@pytorchbot rebase this please

@pytorchbot
Collaborator

There's nothing to do! This branch is already up to date with master (1240327).

(To learn more about this bot, see Bot commands.)

Collaborator

@soumith soumith left a comment

thanks a lot for your contribution Kshitij, it looks like this is finally good to go :)

Contributor

@facebook-github-bot facebook-github-bot left a comment

@soumith has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Development

Successfully merging this pull request may close these issues.

[PyTorch] Dense embedding doesn't work with double backward