Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 1b57826

Browse files
Fix _Recv op caching for multi-output port ops in VirtualScheduler.
PiperOrigin-RevId: 162970793
1 parent 62bced8 commit 1b57826

3 files changed

Lines changed: 169 additions & 10 deletions

File tree

tensorflow/core/grappler/costs/virtual_scheduler.cc

Lines changed: 45 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,32 @@ Costs CombineCosts(const Costs& left, const Costs& right) {
5656
<< " max_per_op_streaming=" << result.max_per_op_streaming;
5757
return result;
5858
}
59+
60+
// Key to the cached _Recv ops map, and its hash and predicate structures.
61+
struct RecvNodeDescriptor {
62+
const NodeDef* node;
63+
const int port_num;
64+
const string& device;
65+
66+
RecvNodeDescriptor(const NodeDef* node_, const int port_num_,
67+
const string& device_)
68+
: node(node_), port_num(port_num_), device(device_) {}
69+
};
70+
71+
struct RecvNodeDescritorHash {
72+
std::size_t operator()(const RecvNodeDescriptor& recv_node) const {
73+
return std::hash<const NodeDef*>()(recv_node.node) ^
74+
std::hash<int>()(recv_node.port_num) ^
75+
std::hash<string>()(recv_node.device);
76+
}
77+
};
78+
79+
struct RecvNodeDescriptorEqual {
80+
bool operator()(const RecvNodeDescriptor& a,
81+
const RecvNodeDescriptor& b) const {
82+
return a.node == b.node && a.port_num == b.port_num && a.device == b.device;
83+
}
84+
};
5985
} // namespace
6086

6187
VirtualScheduler::VirtualScheduler(const GrapplerItem* grappler_item,
@@ -109,6 +135,11 @@ Status VirtualScheduler::Init() {
109135
name_to_node[node->name()] = node;
110136
}
111137

138+
// To reuse _Recv ops.
139+
std::unordered_map<RecvNodeDescriptor, const NodeDef*, RecvNodeDescritorHash,
140+
RecvNodeDescriptorEqual>
141+
cached_recv_nodes;
142+
112143
// Build node_map; for each node, create its NodeState and connect its inputs
113144
// and outputs.
114145
for (const auto* curr_node : nodes) {
@@ -131,12 +162,13 @@ Status VirtualScheduler::Init() {
131162
auto& input_node_state = GetNodeStateOrCreateIt(input_node);
132163
input_node_state.outputs[input_node_port_num].push_back(curr_node);
133164
} else {
134-
if (cached_recv_nodes_.count(input_node) > 0 &&
135-
cached_recv_nodes_[input_node].count(curr_node_device) > 0) {
165+
RecvNodeDescriptor recv_node(input_node, input_node_port_num,
166+
curr_node_device);
167+
auto it = cached_recv_nodes.find(recv_node);
168+
if (it != cached_recv_nodes.end()) {
136169
// Different device, but found an already-cached copy (a _Recv op);
137170
// connect the _Recv to curr_node.
138-
const auto* recv_op =
139-
cached_recv_nodes_[input_node][curr_node_device];
171+
const NodeDef* recv_op = it->second;
140172
// recv_op's output port is hard-coded to zero.
141173
curr_node_state.inputs.push_back(std::make_pair(recv_op, 0));
142174
auto& input_node_state = node_map_.at(recv_op);
@@ -156,7 +188,7 @@ Status VirtualScheduler::Init() {
156188
input_node_state.outputs[input_node_port_num].push_back(send);
157189

158190
// Cache the _Recv op for future use.
159-
cached_recv_nodes_[input_node][curr_node_device] = recv;
191+
cached_recv_nodes[recv_node] = recv;
160192
}
161193
}
162194
}
@@ -269,10 +301,16 @@ std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::CreateSendRecv(
269301
// input names, attrs, etc.
270302

271303
auto input_node_port_num = NodePosition(input_name);
304+
string src_name;
305+
if (input_node_port_num >= 0) {
306+
src_name = strings::StrCat(from->name(), ":", input_node_port_num);
307+
} else {
308+
src_name = strings::StrCat(from->name(), ":minus1");
309+
}
272310

273311
// _Send op.
274312
auto* send = new NodeDef();
275-
send->set_name("Send " + from->name() + " from " + DeviceName(from) + " to " +
313+
send->set_name("Send " + src_name + " from " + DeviceName(from) + " to " +
276314
DeviceName(to));
277315
send->set_op("_Send");
278316
send->add_input(from->name());
@@ -284,7 +322,7 @@ std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::CreateSendRecv(
284322

285323
// _Recv op.
286324
auto* recv = new NodeDef();
287-
recv->set_name("Recv " + from->name() + " on " + DeviceName(to));
325+
recv->set_name("Recv " + src_name + " on " + DeviceName(to));
288326
recv->set_op("_Recv");
289327
recv->add_input(send->name());
290328
recv->set_device(DeviceName(to));

tensorflow/core/grappler/costs/virtual_scheduler.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -254,9 +254,6 @@ class VirtualScheduler {
254254

255255
// Pool of NodeDefs for SendRecv and Identity ops created.
256256
std::vector<std::unique_ptr<NodeDef>> additional_nodes_;
257-
// Cache of nodes transferred to another device.
258-
std::unordered_map<const NodeDef*, std::unordered_map<string, const NodeDef*>>
259-
cached_recv_nodes_;
260257

261258
// Stats:
262259
std::map<string, int> op_counts_; // Op counts with key with input shape.

tensorflow/core/grappler/costs/virtual_scheduler_test.cc

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,15 @@ class TestVirtualScheduler : public VirtualScheduler {
3636
FRIEND_TEST(VirtualSchedulerTest, ControlDependency);
3737
FRIEND_TEST(VirtualSchedulerTest, ComplexDependency);
3838
FRIEND_TEST(VirtualSchedulerTest, Variable);
39+
FRIEND_TEST(VirtualSchedulerTest, InterDeviceTransfer);
3940
};
4041

4142
class VirtualSchedulerTest : public ::testing::Test {
4243
protected:
4344
NodeDef node1_, node2_, node3_, node4_, node5_, node6_;
4445

4546
const string kCPU0 = "/job:localhost/replica:0/task:0/cpu:0";
47+
const string kCPU1 = "/job:localhost/replica:0/task:0/cpu:1";
4648

4749
DeviceProperties GetDummyCPUDevice() {
4850
// Create CPU with 2 cores, 4 Ghz freq, 2 GB/s mem bandwidth.
@@ -74,6 +76,7 @@ class VirtualSchedulerTest : public ::testing::Test {
7476
// IMPORTANT: Device is not actually ever used in the test case since
7577
// force_cpu_type is defaulted to "Haswell"
7678
devices[kCPU0] = cpu_device;
79+
devices[kCPU1] = cpu_device;
7780
cluster_.reset(new VirtualCluster(devices));
7881
placer_.reset(new VirtualPlacer(cluster_.get()));
7982
}
@@ -642,6 +645,56 @@ versions {
642645
grappler_item_->fetch = {"while/Exit", "while/Exit_1"};
643646
}
644647

648+
void CreateGrapplerItemWithInterDeviceTransfers() {
649+
tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(kCPU0);
650+
651+
// Create a FusedBatchNorm op that has multiple output ports.
652+
auto x = ops::RandomUniform(
653+
s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
654+
auto scale =
655+
ops::RandomUniform(s.WithOpName("scale"), {depth_in_}, DT_FLOAT);
656+
auto offset =
657+
ops::RandomUniform(s.WithOpName("offset"), {depth_in_}, DT_FLOAT);
658+
auto mean = ops::RandomUniform(s.WithOpName("mean"), {0}, DT_FLOAT);
659+
auto var = ops::RandomUniform(s.WithOpName("var"), {0}, DT_FLOAT);
660+
661+
auto batch_norm = ops::FusedBatchNorm(
662+
s.WithOpName("bn"), x, scale, offset, mean, var,
663+
ops::FusedBatchNorm::IsTraining(true).Epsilon(0.1f));
664+
auto y = batch_norm.y;
665+
auto batch_mean = batch_norm.batch_mean;
666+
auto batch_var = batch_norm.batch_variance;
667+
// y1 and y2 take the same tensor, so there should be only 1 Send and Recv.
668+
auto y1 = ops::Identity(s.WithOpName("y1").WithDevice(kCPU1), y);
669+
auto y2 = ops::Identity(s.WithOpName("y2").WithDevice(kCPU1), y);
670+
// batch_mean1 and batch_var1 take different output ports, so each will
671+
// initiate Send/Recv.
672+
auto batch_mean1 = ops::Identity(
673+
s.WithOpName("batch_mean1").WithDevice(kCPU1), batch_mean);
674+
auto batch_var1 =
675+
ops::Identity(s.WithOpName("batch_var1").WithDevice(kCPU1), batch_var);
676+
// This is control dependency.
677+
auto control_dep = ops::NoOp(s.WithOpName("control_dep")
678+
.WithControlDependencies(y)
679+
.WithDevice(kCPU1));
680+
681+
GraphDef def;
682+
TF_CHECK_OK(s.ToGraphDef(&def));
683+
684+
grappler_item_.reset(new GrapplerItem);
685+
grappler_item_->id = "test_conv2d_graph";
686+
grappler_item_->graph = def;
687+
grappler_item_->fetch = {"y1", "y2", "batch_mean1", "batch_var1",
688+
"control_dep"};
689+
690+
dependency_["bn"] = {"x", "mean", "var"};
691+
dependency_["y1"] = {"bn"};
692+
dependency_["y2"] = {"bn"};
693+
dependency_["batch_mean1"] = {"bn"};
694+
dependency_["batch_var1"] = {"bn"};
695+
dependency_["control_dep"] = {"bn"};
696+
}
697+
645698
// Call this after creating grappler_item_ and setting up dependency_.
646699
void InitScheduler() {
647700
scheduler_.reset(new TestVirtualScheduler(
@@ -1236,5 +1289,76 @@ TEST_F(VirtualSchedulerTest, WhileLoop) {
12361289
EXPECT_EQ(1, num_exit_1);
12371290
}
12381291

1292+
TEST_F(VirtualSchedulerTest, InterDeviceTransfer) {
1293+
// Init.
1294+
CreateGrapplerItemWithInterDeviceTransfers();
1295+
InitScheduler();
1296+
1297+
// Run the scheduler.
1298+
auto ops_executed = RunScheduler("");
1299+
1300+
// Helper lambda to extract port num from _Send and _Recv op name.
1301+
auto get_port_num = [](const string& name) -> int {
1302+
if (name.find("bn:0") != std::string::npos) {
1303+
return 0;
1304+
} else if (name.find("bn:1") != std::string::npos) {
1305+
return 1;
1306+
} else if (name.find("bn:2") != std::string::npos) {
1307+
return 2;
1308+
} else if (name.find("bn:minus1") != std::string::npos) {
1309+
return -1;
1310+
}
1311+
return -999;
1312+
};
1313+
1314+
// Reorganize ops_executed for further testing.
1315+
std::unordered_map<string, int> op_count;
1316+
std::unordered_map<int, string> recv_op_names;
1317+
std::unordered_map<int, string> send_op_names;
1318+
for (const auto& x : ops_executed) {
1319+
const auto& name = x.first;
1320+
const auto& node_info = x.second;
1321+
const auto& op = node_info.op_info.op();
1322+
if (op == "_Recv") {
1323+
recv_op_names[get_port_num(name)] = name;
1324+
} else if (op == "_Send") {
1325+
send_op_names[get_port_num(name)] = name;
1326+
}
1327+
op_count[op]++;
1328+
}
1329+
1330+
// Same number of _Send and _Recv.
1331+
EXPECT_EQ(op_count.at("_Send"), op_count.at("_Recv"));
1332+
1333+
// Expect 4 Send and Recvs each: port 0, 1, and, 2, and control dependency.
1334+
EXPECT_EQ(op_count.at("_Recv"), 4);
1335+
EXPECT_EQ(op_count.at("_Send"), 4);
1336+
1337+
// Helper lambda for extracting output Tensor size.
1338+
auto get_output_size = [this, ops_executed](const string& name) -> int64 {
1339+
const auto& output_properties_ = ops_executed.at(name).op_info.outputs();
1340+
std::vector<OpInfo::TensorProperties> output_properties;
1341+
for (const auto& output_property : output_properties_) {
1342+
output_properties.push_back(output_property);
1343+
}
1344+
return scheduler_->CalculateOutputSize(output_properties, 0);
1345+
1346+
};
1347+
1348+
// Validate transfer size.
1349+
// Batchnorm output y is 4D vector: batch x width x width x depth.
1350+
int input_size = 4 * batch_size_ * width_ * height_ * depth_in_;
1351+
EXPECT_EQ(get_output_size(recv_op_names[0]), input_size);
1352+
EXPECT_EQ(get_output_size(send_op_names[0]), input_size);
1353+
// Mean and vars are 1-D vector with size depth_in_.
1354+
EXPECT_EQ(get_output_size(recv_op_names[1]), 4 * depth_in_);
1355+
EXPECT_EQ(get_output_size(send_op_names[1]), 4 * depth_in_);
1356+
EXPECT_EQ(get_output_size(recv_op_names[2]), 4 * depth_in_);
1357+
EXPECT_EQ(get_output_size(send_op_names[2]), 4 * depth_in_);
1358+
// Control dependency size is 4B.
1359+
EXPECT_EQ(get_output_size(recv_op_names[-1]), 4);
1360+
EXPECT_EQ(get_output_size(send_op_names[-1]), 4);
1361+
}
1362+
12391363
} // end namespace grappler
12401364
} // end namespace tensorflow

0 commit comments

Comments
 (0)