@@ -36,13 +36,15 @@ class TestVirtualScheduler : public VirtualScheduler {
3636 FRIEND_TEST (VirtualSchedulerTest, ControlDependency);
3737 FRIEND_TEST (VirtualSchedulerTest, ComplexDependency);
3838 FRIEND_TEST (VirtualSchedulerTest, Variable);
39+ FRIEND_TEST (VirtualSchedulerTest, InterDeviceTransfer);
3940};
4041
4142class VirtualSchedulerTest : public ::testing::Test {
4243 protected:
4344 NodeDef node1_, node2_, node3_, node4_, node5_, node6_;
4445
4546 const string kCPU0 = " /job:localhost/replica:0/task:0/cpu:0" ;
47+ const string kCPU1 = " /job:localhost/replica:0/task:0/cpu:1" ;
4648
4749 DeviceProperties GetDummyCPUDevice () {
4850 // Create CPU with 2 cores, 4 Ghz freq, 2 GB/s mem bandwidth.
@@ -74,6 +76,7 @@ class VirtualSchedulerTest : public ::testing::Test {
7476 // IMPORTANT: Device is not actually ever used in the test case since
7577 // force_cpu_type is defaulted to "Haswell"
7678 devices[kCPU0 ] = cpu_device;
79+ devices[kCPU1 ] = cpu_device;
7780 cluster_.reset (new VirtualCluster (devices));
7881 placer_.reset (new VirtualPlacer (cluster_.get ()));
7982 }
@@ -642,6 +645,56 @@ versions {
642645 grappler_item_->fetch = {" while/Exit" , " while/Exit_1" };
643646 }
644647
648+ void CreateGrapplerItemWithInterDeviceTransfers () {
649+ tensorflow::Scope s = tensorflow::Scope::NewRootScope ().WithDevice (kCPU0 );
650+
651+ // Create a FusedBatchNorm op that has multiple output ports.
652+ auto x = ops::RandomUniform (
653+ s.WithOpName (" x" ), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
654+ auto scale =
655+ ops::RandomUniform (s.WithOpName (" scale" ), {depth_in_}, DT_FLOAT);
656+ auto offset =
657+ ops::RandomUniform (s.WithOpName (" offset" ), {depth_in_}, DT_FLOAT);
658+ auto mean = ops::RandomUniform (s.WithOpName (" mean" ), {0 }, DT_FLOAT);
659+ auto var = ops::RandomUniform (s.WithOpName (" var" ), {0 }, DT_FLOAT);
660+
661+ auto batch_norm = ops::FusedBatchNorm (
662+ s.WithOpName (" bn" ), x, scale, offset, mean, var,
663+ ops::FusedBatchNorm::IsTraining (true ).Epsilon (0 .1f ));
664+ auto y = batch_norm.y ;
665+ auto batch_mean = batch_norm.batch_mean ;
666+ auto batch_var = batch_norm.batch_variance ;
667+ // y1 and y2 take the same tensor, so there should be only 1 Send and Recv.
668+ auto y1 = ops::Identity (s.WithOpName (" y1" ).WithDevice (kCPU1 ), y);
669+ auto y2 = ops::Identity (s.WithOpName (" y2" ).WithDevice (kCPU1 ), y);
670+ // batch_mean1 and batch_var1 take different output ports, so each will
671+ // initiate Send/Recv.
672+ auto batch_mean1 = ops::Identity (
673+ s.WithOpName (" batch_mean1" ).WithDevice (kCPU1 ), batch_mean);
674+ auto batch_var1 =
675+ ops::Identity (s.WithOpName (" batch_var1" ).WithDevice (kCPU1 ), batch_var);
676+ // This is control dependency.
677+ auto control_dep = ops::NoOp (s.WithOpName (" control_dep" )
678+ .WithControlDependencies (y)
679+ .WithDevice (kCPU1 ));
680+
681+ GraphDef def;
682+ TF_CHECK_OK (s.ToGraphDef (&def));
683+
684+ grappler_item_.reset (new GrapplerItem);
685+ grappler_item_->id = " test_conv2d_graph" ;
686+ grappler_item_->graph = def;
687+ grappler_item_->fetch = {" y1" , " y2" , " batch_mean1" , " batch_var1" ,
688+ " control_dep" };
689+
690+ dependency_[" bn" ] = {" x" , " mean" , " var" };
691+ dependency_[" y1" ] = {" bn" };
692+ dependency_[" y2" ] = {" bn" };
693+ dependency_[" batch_mean1" ] = {" bn" };
694+ dependency_[" batch_var1" ] = {" bn" };
695+ dependency_[" control_dep" ] = {" bn" };
696+ }
697+
645698 // Call this after creating grappler_item_ and setting up dependency_.
646699 void InitScheduler () {
647700 scheduler_.reset (new TestVirtualScheduler (
@@ -1236,5 +1289,76 @@ TEST_F(VirtualSchedulerTest, WhileLoop) {
12361289 EXPECT_EQ (1 , num_exit_1);
12371290}
12381291
1292+ TEST_F (VirtualSchedulerTest, InterDeviceTransfer) {
1293+ // Init.
1294+ CreateGrapplerItemWithInterDeviceTransfers ();
1295+ InitScheduler ();
1296+
1297+ // Run the scheduler.
1298+ auto ops_executed = RunScheduler (" " );
1299+
1300+ // Helper lambda to extract port num from _Send and _Recv op name.
1301+ auto get_port_num = [](const string& name) -> int {
1302+ if (name.find (" bn:0" ) != std::string::npos) {
1303+ return 0 ;
1304+ } else if (name.find (" bn:1" ) != std::string::npos) {
1305+ return 1 ;
1306+ } else if (name.find (" bn:2" ) != std::string::npos) {
1307+ return 2 ;
1308+ } else if (name.find (" bn:minus1" ) != std::string::npos) {
1309+ return -1 ;
1310+ }
1311+ return -999 ;
1312+ };
1313+
1314+ // Reorganize ops_executed for further testing.
1315+ std::unordered_map<string, int > op_count;
1316+ std::unordered_map<int , string> recv_op_names;
1317+ std::unordered_map<int , string> send_op_names;
1318+ for (const auto & x : ops_executed) {
1319+ const auto & name = x.first ;
1320+ const auto & node_info = x.second ;
1321+ const auto & op = node_info.op_info .op ();
1322+ if (op == " _Recv" ) {
1323+ recv_op_names[get_port_num (name)] = name;
1324+ } else if (op == " _Send" ) {
1325+ send_op_names[get_port_num (name)] = name;
1326+ }
1327+ op_count[op]++;
1328+ }
1329+
1330+ // Same number of _Send and _Recv.
1331+ EXPECT_EQ (op_count.at (" _Send" ), op_count.at (" _Recv" ));
1332+
1333+ // Expect 4 Send and Recvs each: port 0, 1, and, 2, and control dependency.
1334+ EXPECT_EQ (op_count.at (" _Recv" ), 4 );
1335+ EXPECT_EQ (op_count.at (" _Send" ), 4 );
1336+
1337+ // Helper lambda for extracting output Tensor size.
1338+ auto get_output_size = [this , ops_executed](const string& name) -> int64 {
1339+ const auto & output_properties_ = ops_executed.at (name).op_info .outputs ();
1340+ std::vector<OpInfo::TensorProperties> output_properties;
1341+ for (const auto & output_property : output_properties_) {
1342+ output_properties.push_back (output_property);
1343+ }
1344+ return scheduler_->CalculateOutputSize (output_properties, 0 );
1345+
1346+ };
1347+
1348+ // Validate transfer size.
1349+ // Batchnorm output y is 4D vector: batch x width x width x depth.
1350+ int input_size = 4 * batch_size_ * width_ * height_ * depth_in_;
1351+ EXPECT_EQ (get_output_size (recv_op_names[0 ]), input_size);
1352+ EXPECT_EQ (get_output_size (send_op_names[0 ]), input_size);
1353+ // Mean and vars are 1-D vector with size depth_in_.
1354+ EXPECT_EQ (get_output_size (recv_op_names[1 ]), 4 * depth_in_);
1355+ EXPECT_EQ (get_output_size (send_op_names[1 ]), 4 * depth_in_);
1356+ EXPECT_EQ (get_output_size (recv_op_names[2 ]), 4 * depth_in_);
1357+ EXPECT_EQ (get_output_size (send_op_names[2 ]), 4 * depth_in_);
1358+ // Control dependency size is 4B.
1359+ EXPECT_EQ (get_output_size (recv_op_names[-1 ]), 4 );
1360+ EXPECT_EQ (get_output_size (send_op_names[-1 ]), 4 );
1361+ }
1362+
12391363} // end namespace grappler
12401364} // end namespace tensorflow
0 commit comments