Support N-D tensors in Bilinear#5764
Conversation
|
@pytorchbot retest this please |
|
@pytorchbot retest this please |
| namespace at { namespace native { | ||
|
|
||
| Tensor bilinear(const Tensor& input1, const Tensor& input2, const Tensor& weight, const Tensor& bias) { | ||
| auto b_input1 = input1.unsqueeze(-2).unsqueeze(-2); |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| Shape: | ||
| - Input: :math:`(N, \text{in1_features})`, :math:`(N, \text{in2_features})` | ||
| - Output: :math:`(N, \text{out_features})` | ||
| where :math:`*` means any number of additional dimensions. All but the last |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| auto b_input2 = input2.unsqueeze(-2).unsqueeze(-1); | ||
|
|
||
| auto output = at::matmul(at::matmul(b_input1, weight), b_input2); | ||
| output = output.squeeze(-1).squeeze(-2).sum(-1); |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
|
|
||
| Tensor bilinear(const Tensor& input1, const Tensor& input2, const Tensor& weight, const Tensor& bias) { | ||
| if (input1.dim() != input2.dim()) { | ||
| throw std::runtime_error("Inputs should have the same number of dimensions"); |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| } | ||
| for (int64_t i = 0; i < input1.dim() - 1; i++) { | ||
| if (input1.size(i) != input2.size(i)) { | ||
| throw std::runtime_error("Batch dimensions of inputs do not match"); |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| if (input1.dim() != input2.dim()) { | ||
| throw std::runtime_error("Inputs should have the same number of dimensions"); | ||
| } | ||
| AT_ASSERT(input1.dim() == input2.dim(), "bilinear(): input dimensions do not match: got %d and %d", |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| throw std::runtime_error("Bias sizes does not match weight size"); | ||
| AT_ASSERT(input1.size(i) == input2.size(i), | ||
| "bilinear(): input batch dimensions do not match at dim %d: got %d and %d", | ||
| i, input1.size(i), input2.size(i)); |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
|
@pytorchbot retest this please |
|
I like richard's enthusiasm :D |
* support n-d inputs in bilinear and move to aten * support n-d inputs in bilinear and move to aten * add asserts to bilinear inputs * address comments * cast int64_t in asserts
Closes #5601.