Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit ec9e078

Browse files
cheshiretensorflower-gardener
authored andcommitted
Fix handling of PartitionedCall in ShapeInference pass. Previously, the pass
assumed that the function call can only happen from the function input node, and could not perform inter-procedural shape propagation through functions called using PartitionedCall. Note that in order to use this from a grappler pass P, P needs to override `UsesFunctionLibrary` and return `true` from it. PiperOrigin-RevId: 308886562 Change-Id: I1d64c07d9c5fbab2365725e2c213b95f2b21ae01
1 parent 6007324 commit ec9e078

4 files changed

Lines changed: 113 additions & 25 deletions

File tree

tensorflow/core/grappler/costs/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ tf_cc_test(
9191
deps = [
9292
":graph_properties",
9393
"//tensorflow/cc:cc_ops",
94+
"//tensorflow/cc:functional_ops",
9495
"//tensorflow/cc:scope",
9596
"//tensorflow/core:framework",
9697
"//tensorflow/core:lib",

tensorflow/core/grappler/costs/graph_properties.cc

Lines changed: 37 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ limitations under the License.
1616
#include "tensorflow/core/grappler/costs/graph_properties.h"
1717

1818
#include "absl/types/optional.h"
19+
#include "tensorflow/core/common_runtime/function.h"
1920
#include "tensorflow/core/framework/common_shape_fns.h"
2021
#include "tensorflow/core/framework/function.pb.h"
2122
#include "tensorflow/core/framework/node_def_util.h"
@@ -590,6 +591,20 @@ bool IsWhiteListedOpTypeForEvaluateNode(const string& op_type) {
590591
return kOpTpeWhitelist->find(op_type) != kOpTpeWhitelist->end();
591592
}
592593

594+
// Negative shape size of '-1' represents unknown, while negative shape sizes
595+
// less than -1 represent unknown symbolic shapes (e.g. the shape of [-5, 5, -1,
596+
// -5] really means [x, 5, ?, x]). Before we can output the tensors as shapes,
597+
// we need to normalize them: mark all values <-1 as "unknown" (-1).
598+
static void NormalizeShapeForOutput(TensorShapeProto* shape) {
599+
for (int i = 0; i < shape->dim_size(); i++) {
600+
if (shape->dim(i).size() < -1) {
601+
VLOG(2) << "Normalizing dimension: " << i << " from "
602+
<< shape->dim(i).size() << " to -1";
603+
shape->mutable_dim(i)->set_size(-1);
604+
}
605+
}
606+
}
607+
593608
// Processes symbolic shapes.
594609
// Each symbolic shape or dimension is represented by a handle. Unlike the TF
595610
// shape refiner which creates new handles every time it processes an unknown
@@ -722,7 +737,8 @@ class SymbolicShapeRefiner {
722737
return it->second.inference_context.get();
723738
}
724739

725-
// Forward the shapes from the function input nodes to
740+
// Forward the shapes from the function input nodes, PartitionedCalls or
741+
// StatefulPartitionedCall to
726742
// the argument nodes (which are Placeholder nodes), then
727743
// perform shape inference on the function body.
728744
//
@@ -732,18 +748,20 @@ class SymbolicShapeRefiner {
732748
// In the event of an error, UpdateNode will simply set `function_node`'s
733749
// output shape to be Unknown.
734750
Status UpdateFunction(const NodeDef* function_node) {
735-
auto it = fun_to_grappler_function_item_.find(function_node->op());
751+
NameAttrList function;
752+
TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(*function_node, &function));
753+
auto it = fun_to_grappler_function_item_.find(function.name());
736754
if (it == fun_to_grappler_function_item_.end()) {
737755
return errors::InvalidArgument(
738-
function_node->op(),
756+
function.name(),
739757
" was not previously added to SymbolicShapeRefiner.");
740758
}
741759

742760
const absl::optional<GrapplerFunctionItem>& maybe_grappler_function_item =
743761
it->second;
744762
if (!maybe_grappler_function_item.has_value()) {
745763
VLOG(3) << "Skip failed to instantiate function call: function_name="
746-
<< function_node->op();
764+
<< function.name();
747765

748766
auto* ctx = GetNodeContext(function_node);
749767
auto* ic = ctx->inference_context.get();
@@ -789,11 +807,7 @@ class SymbolicShapeRefiner {
789807
const auto& handle = input_ic->output(output_port_num);
790808
input_ic->ShapeHandleToProto(handle, &proto);
791809
// There may be dim.size < -1 in SymbolicShapeRefiner. Change those to -1.
792-
for (int i = 0; i < proto.dim_size(); i++) {
793-
if (proto.dim(i).size() < -1) {
794-
proto.mutable_dim(i)->set_size(-1);
795-
}
796-
}
810+
NormalizeShapeForOutput(&proto);
797811

798812
AttrValue output_attr;
799813
output_attr.mutable_list()->add_shape()->Swap(&proto);
@@ -870,8 +884,9 @@ class SymbolicShapeRefiner {
870884
out_tensor.ToString(), " has invalid position ", out_tensor.index(),
871885
" (output_properties.size() = ", output_properties.size(), ").");
872886
}
873-
auto const& outprop = output_properties[out_tensor.index()];
874-
const TensorShapeProto& shape = outprop.shape();
887+
auto& outprop = output_properties[out_tensor.index()];
888+
TensorShapeProto shape = outprop.shape();
889+
NormalizeShapeForOutput(&shape);
875890
ShapeHandle out;
876891
TF_RETURN_IF_ERROR(ic->MakeShapeFromShapeProto(shape, &out));
877892
ic->set_output(output, out);
@@ -1196,15 +1211,14 @@ class SymbolicShapeRefiner {
11961211
return true;
11971212
}
11981213

1199-
Status AddFunction(const NodeDef* function_node) {
1200-
auto it = fun_to_grappler_function_item_.find(function_node->op());
1214+
Status AddFunction(const NodeDef* function_node, NameAttrList function) {
1215+
auto it = fun_to_grappler_function_item_.find(function.name());
12011216
if (it != fun_to_grappler_function_item_.end()) {
12021217
return Status::OK();
12031218
}
12041219

12051220
const FunctionDef* function_def =
1206-
CHECK_NOTNULL(function_library_.Find(function_node->op()));
1207-
1221+
CHECK_NOTNULL(function_library_.Find(function.name()));
12081222
GrapplerFunctionItem grappler_function_item;
12091223
Status function_instantiated =
12101224
MakeGrapplerFunctionItem(*function_def, function_library_,
@@ -1242,10 +1256,15 @@ class SymbolicShapeRefiner {
12421256

12431257
Status AddNode(const NodeDef* node) {
12441258
NodeContext& node_ctx = node_to_context_[node];
1245-
TF_RETURN_IF_ERROR(function_library_.LookUp(node->op(), &node_ctx.op_data));
1259+
NameAttrList function;
1260+
TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(*node, &function));
1261+
1262+
// For PartitionedCall, op_data represents the function info.
1263+
TF_RETURN_IF_ERROR(
1264+
function_library_.LookUp(function.name(), &node_ctx.op_data));
12461265

12471266
if (node_ctx.op_data->is_function_op) {
1248-
TF_RETURN_IF_ERROR(AddFunction(node));
1267+
TF_RETURN_IF_ERROR(AddFunction(node, function));
12491268
}
12501269

12511270
TF_RETURN_IF_ERROR(InOutTypesForNode(*node, node_ctx.op_data->op_def,
@@ -2525,13 +2544,7 @@ Status GraphProperties::AnnotateOutputShapes(GraphDef* output_graph_def,
25252544
TensorShapeProto* proto = attr_output_shape.mutable_list()->add_shape();
25262545
*proto = tensor_property.shape();
25272546
if (!allow_symbolic_shapes) {
2528-
// There may be dim.size < -1 in SymbolicShapeRefiner. Change those to
2529-
// -1.
2530-
for (int i = 0; i < proto->dim_size(); i++) {
2531-
if (proto->dim(i).size() < -1) {
2532-
proto->mutable_dim(i)->set_size(-1);
2533-
}
2534-
}
2547+
NormalizeShapeForOutput(proto);
25352548
}
25362549
}
25372550
(*node->mutable_attr())["_output_shapes"] = attr_output_shape;

tensorflow/core/grappler/costs/graph_properties_test.cc

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ limitations under the License.
1616
#include "tensorflow/core/grappler/costs/graph_properties.h"
1717

1818
#include "tensorflow/cc/framework/scope.h"
19+
#include "tensorflow/cc/ops/functional_ops.h"
1920
#include "tensorflow/cc/ops/standard_ops.h"
2021
#include "tensorflow/core/framework/graph_def_util.h"
2122
#include "tensorflow/core/framework/node_def_builder.h"
@@ -2002,6 +2003,79 @@ TEST_F(GraphPropertiesTest, ShapeAnnotationWithoutInferenceFn) {
20022003
EXPECT_EQ("float: [10,100]", PropToString(prop));
20032004
}
20042005

2006+
TEST_F(GraphPropertiesTest, PartitionedCallOp) {
2007+
Scope root = Scope::NewRootScope().ExitOnError();
2008+
FunctionDefLibrary library;
2009+
FunctionDef called_func = FunctionDefHelper::Create(
2010+
"identity_function",
2011+
/*in_def=*/{"arg0: int32"},
2012+
/*out_def=*/{"ret0: int32"},
2013+
/*attr_def=*/{},
2014+
{{{"Identity"}, "Identity", {"arg0"}, {{"T", DT_INT32}}}},
2015+
/*ret_def=*/{{"ret0", "Identity:output:0"}});
2016+
*library.add_function() = called_func;
2017+
TF_CHECK_OK(root.graph()->AddFunctionLibrary(library));
2018+
2019+
Output in = ops::Const(root, {3, 1, 2, 0});
2020+
NameAttrList b_name_attr;
2021+
b_name_attr.set_name("identity_function");
2022+
ops::PartitionedCall call(root.WithOpName("identity_call"), {in}, {DT_INT32},
2023+
b_name_attr);
2024+
2025+
GrapplerItem item;
2026+
TF_CHECK_OK(root.ToGraphDef(&item.graph));
2027+
2028+
GraphProperties properties(item);
2029+
TF_CHECK_OK(properties.InferStatically(
2030+
/*assume_valid_feeds=*/true,
2031+
/*aggressive_shape_inference=*/false,
2032+
/*include_tensor_values=*/true));
2033+
2034+
EXPECT_EQ("int32: [4]",
2035+
PropToString(properties.GetOutputProperties("identity_call")[0]));
2036+
}
2037+
2038+
TEST_F(GraphPropertiesTest, NonTrivialInputPartitionedCallOp) {
2039+
auto f = FunctionDefHelper::Create(
2040+
// Name
2041+
"FunctionWhichAdds",
2042+
// Inputs
2043+
{"arg0: int32", "arg1: int32"},
2044+
// Outputs
2045+
{"ret0: int32"},
2046+
/*attr_def=*/{},
2047+
// Nodes
2048+
{{{"a"}, "Add", {"arg0", "arg1"}, {{"T", DT_INT32}}}},
2049+
/*ret_def=*/{{"ret0", "a:z:0"}});
2050+
2051+
FunctionDefLibrary function_lib;
2052+
function_lib.add_function()->Swap(&f);
2053+
tensorflow::Scope root = tensorflow::Scope::NewRootScope();
2054+
TF_CHECK_OK(root.graph()->AddFunctionLibrary(function_lib));
2055+
2056+
PartialTensorShape input_shape({2, 2, -1});
2057+
Output in1 =
2058+
ops::Placeholder(root, DT_INT32, ops::Placeholder::Shape(input_shape));
2059+
Output in2 =
2060+
ops::Placeholder(root, DT_INT32, ops::Placeholder::Shape(input_shape));
2061+
NameAttrList b_name_attr;
2062+
b_name_attr.set_name("FunctionWhichAdds");
2063+
ops::PartitionedCall call(root.WithOpName("add_call"), {in1, in2}, {DT_INT32},
2064+
b_name_attr);
2065+
2066+
GrapplerItem item;
2067+
TF_CHECK_OK(root.ToGraphDef(&item.graph));
2068+
2069+
GraphProperties properties(item);
2070+
TF_CHECK_OK(properties.InferStatically(
2071+
/*assume_valid_feeds=*/true,
2072+
/*aggressive_shape_inference=*/false,
2073+
/*include_tensor_values=*/true));
2074+
2075+
EXPECT_EQ("int32: [2,2,-1]",
2076+
PropToString(properties.GetOutputProperties("add_call")[0]));
2077+
}
2078+
20052079
TEST_F(GraphPropertiesTest, ShapeAnnotatedFunctionOp) {
20062080
// A function, which we cannot infer output shape statically.
20072081
auto f = FunctionDefHelper::Create(

tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ class ScopedAllocatorOptimizer : public GraphOptimizer {
4343

4444
string name() const override { return "scoped_allocator_optimizer"; }
4545

46-
bool UsesFunctionLibrary() const override { return false; }
46+
bool UsesFunctionLibrary() const override { return true; }
4747

4848
Status Optimize(Cluster* cluster, const GrapplerItem& item,
4949
GraphDef* optimized_graph) override;

0 commit comments

Comments
 (0)