@@ -16,6 +16,7 @@ limitations under the License.
1616#include " tensorflow/core/grappler/costs/graph_properties.h"
1717
1818#include " absl/types/optional.h"
19+ #include " tensorflow/core/common_runtime/function.h"
1920#include " tensorflow/core/framework/common_shape_fns.h"
2021#include " tensorflow/core/framework/function.pb.h"
2122#include " tensorflow/core/framework/node_def_util.h"
@@ -590,6 +591,20 @@ bool IsWhiteListedOpTypeForEvaluateNode(const string& op_type) {
590591 return kOpTpeWhitelist ->find (op_type) != kOpTpeWhitelist ->end ();
591592}
592593
594+ // Negative shape size of '-1' represents unknown, while negative shape sizes
595+ // less than -1 represent unknown symbolic shapes (e.g. the shape of [-5, 5, -1,
596+ // -5] really means [x, 5, ?, x]). Before we can output the tensors as shapes,
597+ // we need to normalize them: mark all values <-1 as "unknown" (-1).
598+ static void NormalizeShapeForOutput (TensorShapeProto* shape) {
599+ for (int i = 0 ; i < shape->dim_size (); i++) {
600+ if (shape->dim (i).size () < -1 ) {
601+ VLOG (2 ) << " Normalizing dimension: " << i << " from "
602+ << shape->dim (i).size () << " to -1" ;
603+ shape->mutable_dim (i)->set_size (-1 );
604+ }
605+ }
606+ }
607+
593608// Processes symbolic shapes.
594609// Each symbolic shape or dimension is represented by a handle. Unlike the TF
595610// shape refiner which creates new handles every time it processes an unknown
@@ -722,7 +737,8 @@ class SymbolicShapeRefiner {
722737 return it->second .inference_context .get ();
723738 }
724739
725- // Forward the shapes from the function input nodes to
740+ // Forward the shapes from the function input nodes, PartitionedCalls or
741+ // StatefulPartitionedCall to
726742 // the argument nodes (which are Placeholder nodes), then
727743 // perform shape inference on the function body.
728744 //
@@ -732,18 +748,20 @@ class SymbolicShapeRefiner {
732748 // In the event of an error, UpdateNode will simply set `function_node`'s
733749 // output shape to be Unknown.
734750 Status UpdateFunction (const NodeDef* function_node) {
735- auto it = fun_to_grappler_function_item_.find (function_node->op ());
751+ NameAttrList function;
752+ TF_RETURN_IF_ERROR (NameAndAttrsFromFunctionCall (*function_node, &function));
753+ auto it = fun_to_grappler_function_item_.find (function.name ());
736754 if (it == fun_to_grappler_function_item_.end ()) {
737755 return errors::InvalidArgument (
738- function_node-> op (),
756+ function. name (),
739757 " was not previously added to SymbolicShapeRefiner." );
740758 }
741759
742760 const absl::optional<GrapplerFunctionItem>& maybe_grappler_function_item =
743761 it->second ;
744762 if (!maybe_grappler_function_item.has_value ()) {
745763 VLOG (3 ) << " Skip failed to instantiate function call: function_name="
746- << function_node-> op ();
764+ << function. name ();
747765
748766 auto * ctx = GetNodeContext (function_node);
749767 auto * ic = ctx->inference_context .get ();
@@ -789,11 +807,7 @@ class SymbolicShapeRefiner {
789807 const auto & handle = input_ic->output (output_port_num);
790808 input_ic->ShapeHandleToProto (handle, &proto);
791809 // There may be dim.size < -1 in SymbolicShapeRefiner. Change those to -1.
792- for (int i = 0 ; i < proto.dim_size (); i++) {
793- if (proto.dim (i).size () < -1 ) {
794- proto.mutable_dim (i)->set_size (-1 );
795- }
796- }
810+ NormalizeShapeForOutput (&proto);
797811
798812 AttrValue output_attr;
799813 output_attr.mutable_list ()->add_shape ()->Swap (&proto);
@@ -870,8 +884,9 @@ class SymbolicShapeRefiner {
870884 out_tensor.ToString (), " has invalid position " , out_tensor.index (),
871885 " (output_properties.size() = " , output_properties.size (), " )." );
872886 }
873- auto const & outprop = output_properties[out_tensor.index ()];
874- const TensorShapeProto& shape = outprop.shape ();
887+ auto & outprop = output_properties[out_tensor.index ()];
888+ TensorShapeProto shape = outprop.shape ();
889+ NormalizeShapeForOutput (&shape);
875890 ShapeHandle out;
876891 TF_RETURN_IF_ERROR (ic->MakeShapeFromShapeProto (shape, &out));
877892 ic->set_output (output, out);
@@ -1196,15 +1211,14 @@ class SymbolicShapeRefiner {
11961211 return true ;
11971212 }
11981213
1199- Status AddFunction (const NodeDef* function_node) {
1200- auto it = fun_to_grappler_function_item_.find (function_node-> op ());
1214+ Status AddFunction (const NodeDef* function_node, NameAttrList function ) {
1215+ auto it = fun_to_grappler_function_item_.find (function. name ());
12011216 if (it != fun_to_grappler_function_item_.end ()) {
12021217 return Status::OK ();
12031218 }
12041219
12051220 const FunctionDef* function_def =
1206- CHECK_NOTNULL (function_library_.Find (function_node->op ()));
1207-
1221+ CHECK_NOTNULL (function_library_.Find (function.name ()));
12081222 GrapplerFunctionItem grappler_function_item;
12091223 Status function_instantiated =
12101224 MakeGrapplerFunctionItem (*function_def, function_library_,
@@ -1242,10 +1256,15 @@ class SymbolicShapeRefiner {
12421256
12431257 Status AddNode (const NodeDef* node) {
12441258 NodeContext& node_ctx = node_to_context_[node];
1245- TF_RETURN_IF_ERROR (function_library_.LookUp (node->op (), &node_ctx.op_data ));
1259+ NameAttrList function;
1260+ TF_RETURN_IF_ERROR (NameAndAttrsFromFunctionCall (*node, &function));
1261+
1262+ // For PartitionedCall, op_data represents the function info.
1263+ TF_RETURN_IF_ERROR (
1264+ function_library_.LookUp (function.name (), &node_ctx.op_data ));
12461265
12471266 if (node_ctx.op_data ->is_function_op ) {
1248- TF_RETURN_IF_ERROR (AddFunction (node));
1267+ TF_RETURN_IF_ERROR (AddFunction (node, function ));
12491268 }
12501269
12511270 TF_RETURN_IF_ERROR (InOutTypesForNode (*node, node_ctx.op_data ->op_def ,
@@ -2525,13 +2544,7 @@ Status GraphProperties::AnnotateOutputShapes(GraphDef* output_graph_def,
25252544 TensorShapeProto* proto = attr_output_shape.mutable_list ()->add_shape ();
25262545 *proto = tensor_property.shape ();
25272546 if (!allow_symbolic_shapes) {
2528- // There may be dim.size < -1 in SymbolicShapeRefiner. Change those to
2529- // -1.
2530- for (int i = 0 ; i < proto->dim_size (); i++) {
2531- if (proto->dim (i).size () < -1 ) {
2532- proto->mutable_dim (i)->set_size (-1 );
2533- }
2534- }
2547+ NormalizeShapeForOutput (proto);
25352548 }
25362549 }
25372550 (*node->mutable_attr ())[" _output_shapes" ] = attr_output_shape;
0 commit comments