Add support for bool in elementwise ops #321
Conversation
Signed-off-by: inocsin <[email protected]>
| auto self = args[0].ITensorOrFreeze(ctx); | ||
| auto other = args[1].ITensorOrFreeze(ctx); | ||
| auto mul = | ||
| nvinfer1::ILayer* mul = nullptr; |
There was a problem hiding this comment.
@inocsin Are these changes that would only be relevant to mul or other ops in general? Also did you write tests for this? I didnt see them in the commits
There was a problem hiding this comment.
Currently, I only find the Bool * Int operation in my models. I think other operation like Int / Bool or Int +/- Bool doesn't make any sense. I will add test case later.
There was a problem hiding this comment.
The input type of trtorch is set to float in default, I have to modify the TRTorch/core/conversion/conversion.cpp:150
auto trt_in = ctx->net->addInput(name.c_str(), ctx->input_type, dims.input_shape);
to support input type of Bool
There was a problem hiding this comment.
What does it mean to do bool * int? should that give you a bool out? like what does False * 8 mean?
There was a problem hiding this comment.
[True, False] * [8, 8] = [8, 0], this can be a mask operation, check demo graph here #327
There was a problem hiding this comment.
There are some changes that do not conform to C++ style guidelines:
diff --git a/workspace/core/conversion/converters/impl/element_wise.cpp b/tmp/changes.txt
index da582c9..f35c52e 100644
--- a/workspace/core/conversion/converters/impl/element_wise.cpp
+++ b/tmp/changes.txt
@@ -360,7 +360,7 @@ auto element_wise_registrations TRTORCH_UNUSED =
auto self = args[0].ITensorOrFreeze(ctx);
auto other = args[1].ITensorOrFreeze(ctx);
nvinfer1::ILayer* mul = nullptr;
- if (self->getType() ==nvinfer1::DataType::kBOOL || other->getType() == nvinfer1::DataType::kBOOL) {
+ if (self->getType() == nvinfer1::DataType::kBOOL || other->getType() == nvinfer1::DataType::kBOOL) {
auto self_id = ctx->net->addIdentity(*self);
auto other_id = ctx->net->addIdentity(*other);
if (self->getType() == nvinfer1::DataType::kBOOL) {
@@ -369,11 +369,15 @@ auto element_wise_registrations TRTORCH_UNUSED =
if (other->getType() == nvinfer1::DataType::kBOOL) {
other_id->getOutput(0)->setType(nvinfer1::DataType::kINT32);
}
- mul =
- add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self_id->getOutput(0), other_id->getOutput(0), util::node_info(n));
+ mul = add_elementwise(
+ ctx,
+ nvinfer1::ElementWiseOperation::kPROD,
+ self_id->getOutput(0),
+ other_id->getOutput(0),
+ util::node_info(n));
} else {
mul =
- add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self, other, util::node_info(n));
+ add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self, other, util::node_info(n));
}
TRTORCH_CHECK(mul, "Unable to create mul layer from node: " << *n);
ERROR: Some files do not conform to style guidelines|
We support aten::to operator now which should handle this. Check out https://github.com/NVIDIA/TRTorch/blob/master/tests/core/conversion/converters/test_cast.cpp for usage. |
Description
Adds support for bool tensors in element wise ops
Type of change
Please delete options that are not relevant and/or add your own.
Checklist: