Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit e6be8a4

Browse files
Suharsh Sivakumar, tensorflower-gardener
authored and committed
Add StatsPublisherFactory and NoOpStatsPublisher to distributed session.
Change: 134707078
1 parent 8f4087c commit e6be8a4

6 files changed

Lines changed: 221 additions & 82 deletions

File tree

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "tensorflow/core/common_runtime/stats_publisher_interface.h"
17+
18+
namespace tensorflow {
19+
20+
namespace {
21+
// NoOpStatsPublisher provides a dummy/no-op implementation of
22+
// StatsPublisherInterface.
//
// Both publishing hooks return immediately without reading their arguments,
// and GetProfileHandler() always returns a null handler.
23+
class NoOpStatsPublisher : public StatsPublisherInterface {
24+
public:
25+
NoOpStatsPublisher(){};
26+
27+
// Discards `step_stats`; nothing is published.
void PublishStatsProto(const StepStats& step_stats) override { return; }
28+
29+
// Discards `graph_defs`; nothing is published.
void PublishGraphProto(
30+
const std::vector<const GraphDef*>& graph_defs) override {
31+
return;
32+
}
33+
34+
// Always returns nullptr regardless of `step`, `execution_count` or `ropts`,
// so callers must be prepared for a null ProfileHandler.
std::unique_ptr<ProfileHandler> GetProfileHandler(
35+
uint64 step, int64 execution_count, const RunOptions& ropts) override {
36+
return nullptr;
37+
}
38+
39+
~NoOpStatsPublisher() override {}
40+
};
41+
} // namespace
42+
43+
// Factory that returns a freshly allocated NoOpStatsPublisher, owned by the
// returned unique_ptr. `session`, `bopts` and `sopts` are accepted but not
// read by this function.
std::unique_ptr<StatsPublisherInterface> CreateNoOpStatsPublisher(
44+
const string& session, const BuildGraphOptions& bopts,
45+
const SessionOptions& sopts) {
46+
return std::unique_ptr<StatsPublisherInterface>(new NoOpStatsPublisher);
47+
}
48+
49+
} // namespace tensorflow

tensorflow/core/common_runtime/stats_publisher_interface.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,14 @@ class StatsPublisherInterface {
5151
virtual ~StatsPublisherInterface() {}
5252
};
5353

54+
// Callable type for factories that build a StatsPublisherInterface. It takes
// a string, the BuildGraphOptions and the SessionOptions, and returns an
// owning pointer to the new publisher.
typedef std::function<std::unique_ptr<StatsPublisherInterface>(
55+
const string&, const BuildGraphOptions&, const SessionOptions&)>
56+
StatsPublisherFactory;
57+
58+
// Declaration of a publisher factory whose parameter list and return type
// match StatsPublisherFactory, so it can be passed wherever that type is
// expected.
std::unique_ptr<StatsPublisherInterface> CreateNoOpStatsPublisher(
59+
const string& session, const BuildGraphOptions& bopts,
60+
const SessionOptions& sopts);
61+
5462
} // namespace tensorflow
5563

5664
#endif // THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_STATS_PUBLISHER_INTERFACE_H_

tensorflow/core/distributed_runtime/master_session.cc

Lines changed: 124 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@ limitations under the License.
2121

2222
#include "tensorflow/core/common_runtime/device_set.h"
2323
#include "tensorflow/core/common_runtime/process_util.h"
24+
#include "tensorflow/core/common_runtime/profile_handler.h"
2425
#include "tensorflow/core/common_runtime/simple_graph_execution_state.h"
26+
#include "tensorflow/core/common_runtime/stats_publisher_interface.h"
2527
#include "tensorflow/core/distributed_runtime/master_env.h"
2628
#include "tensorflow/core/distributed_runtime/master_session_interface.h"
2729
#include "tensorflow/core/distributed_runtime/scheduler.h"
@@ -75,9 +77,10 @@ class MasterSession : public MasterSessionInterface {
7577
//
7678
// The caller takes ownership of all remote devices.
7779
MasterSession(const SessionOptions& options, const MasterEnv* env,
78-
std::vector<Device*>* remote_devs);
80+
std::vector<Device*>* remote_devs,
81+
StatsPublisherFactory stats_publisher_factory);
7982

80-
// Initialize the Session for "def". Must be called before Extend(),
83+
// Initialize the MasterSession for "def". Must be called before Extend(),
8184
// Run(), or Close().
8285
//
8386
// The callee may clear "def".
@@ -130,6 +133,8 @@ class MasterSession : public MasterSessionInterface {
130133
// The device set used by this session.
131134
DeviceSet devices_;
132135

136+
StatsPublisherFactory stats_publisher_factory_;
137+
133138
std::atomic_ulong last_access_time_usec_;
134139

135140
mutex mu_;
@@ -173,35 +178,37 @@ class MasterSession : public MasterSessionInterface {
173178
TF_DISALLOW_COPY_AND_ASSIGN(MasterSession);
174179
};
175180

176-
// Session wraps SimpleClientGraph in a reference counted object. This way,
177-
// Session can clear up the cache mapping Run requests to compiled
178-
// graphs while the compiled graph is still being used.
181+
// ReffedClientGraph wraps SimpleClientGraph in a reference counted object.
182+
// This way, MasterSession can clear up the cache mapping Run requests to
183+
// compiled graphs while the compiled graph is still being used.
179184
//
180185
// TODO(zhifengc): Cleanup this class. It's becoming messy.
181186
class MasterSession::ReffedClientGraph : public core::RefCounted {
182187
public:
183188
ReffedClientGraph(const string& handle, const BuildGraphOptions& bopts,
184189
std::unique_ptr<SimpleClientGraph> cg,
185-
const GraphOptions& graph_opts)
190+
const SessionOptions& session_opts,
191+
StatsPublisherFactory stats_publisher_factory)
186192
: session_handle_(handle),
187193
client_graph_(std::move(cg)),
188194
bopts_(bopts),
189-
graph_opts_(graph_opts) {
195+
session_opts_(session_opts) {
190196
VLOG(1) << "Created ReffedClientGraph for node with "
191197
<< client_graph_->graph.num_node_ids();
192198

193-
const string key =
194-
strings::StrCat("{", str_util::Join(bopts.feed_endpoints, ","), "},{",
195-
str_util::Join(bopts.target_nodes, ","), "},{",
196-
str_util::Join(bopts.fetch_endpoints, ","), "}");
197-
// TODO(mrry): Publish information about the graph (such as
198-
// timelines, the pruned graph, statistics, etc.).
199+
stats_publisher_ = stats_publisher_factory(handle, bopts, session_opts);
199200
}
200201

201202
~ReffedClientGraph() override { DeregisterPartitions(); }
202203

203204
const SimpleClientGraph* client_graph() { return client_graph_.get(); }
204205

206+
// Forwards to stats_publisher_->GetProfileHandler() with the same arguments
// and returns whatever it returns.
std::unique_ptr<ProfileHandler> GetProfileHandler(uint64 step,
207+
int64 execution_count,
208+
const RunOptions& ropts) {
209+
return stats_publisher_->GetProfileHandler(step, execution_count, ropts);
210+
}
211+
205212
// Turn RPC logging on or off, both at the WorkerCache used by this
206213
// master process, and at each remote worker in use for the current
207214
// partitions.
@@ -297,8 +304,9 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
297304
// Post-processing of any runtime statistics gathered during execution.
298305
void ProcessStats(const MasterEnv* env, int64 step_id, PerStepState* pss,
299306
SimpleGraphExecutionState* execution_state,
300-
RunStepResponse* resp);
301-
void ProcessDeviceStats(SimpleGraphExecutionState* execution_state,
307+
ProfileHandler* ph, RunStepResponse* resp);
308+
void ProcessDeviceStats(ProfileHandler* ph,
309+
SimpleGraphExecutionState* execution_state,
302310
const DeviceStepStats& ds, bool is_rpc);
303311

304312
string DetailText(const NodeDef& def, const NodeExecStats& ns) {
@@ -323,7 +331,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
323331
const std::unique_ptr<SimpleClientGraph> client_graph_;
324332
std::unordered_set<const Node*> nodes_needing_input_mapping_;
325333
BuildGraphOptions bopts_;
326-
const GraphOptions graph_opts_;
334+
const SessionOptions session_opts_;
327335

328336
// Graph partitioned into per-location subgraphs.
329337
struct Part {
@@ -365,6 +373,8 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
365373
// init_result_ remembers the initialization error if any.
366374
Status init_result_ GUARDED_BY(mu_);
367375

376+
std::unique_ptr<StatsPublisherInterface> stats_publisher_;
377+
368378
// Send/Recv nodes that are the result of client-added
369379
// feeds and fetches must be tracked so that the tensors
370380
// can be added to the local rendezvous.
@@ -391,6 +401,12 @@ Status MasterSession::ReffedClientGraph::RegisterPartitions(
391401
init_started_ = true;
392402
mu_.unlock();
393403
Status s = DoRegisterPartitions(env, popts, func_def_lib);
404+
std::vector<const GraphDef*> graph_defs;
405+
graph_defs.reserve(partitions_.size());
406+
for (const Part& part : partitions_) {
407+
graph_defs.push_back(&part.gdef);
408+
}
409+
stats_publisher_->PublishGraphProto(graph_defs);
394410
mu_.lock();
395411
init_result_ = s;
396412
init_done_.Notify();
@@ -504,7 +520,7 @@ Status MasterSession::ReffedClientGraph::DoRegisterPartitions(
504520
Call* c = &calls[i];
505521
c->req.set_session_handle(session_handle_);
506522
*c->req.mutable_graph_def() = part.gdef;
507-
*c->req.mutable_graph_options() = graph_opts_;
523+
*c->req.mutable_graph_options() = session_opts_.config.graph_options();
508524
VLOG(2) << "Register " << part.gdef.DebugString();
509525
auto cb = [c](const Status& s) {
510526
c->status = s;
@@ -800,35 +816,97 @@ void MasterSession::ReffedClientGraph::CleanupPartitionsAsync(
800816

801817
void MasterSession::ReffedClientGraph::ProcessStats(
802818
const MasterEnv* env, int64 step_id, PerStepState* pss,
803-
SimpleGraphExecutionState* execution_state, RunStepResponse* resp) {
819+
SimpleGraphExecutionState* execution_state, ProfileHandler* ph,
820+
RunStepResponse* resp) {
804821
if (!pss->collect_costs && !pss->collect_timeline) return;
805822

806823
// Out-of-band logging data is collected now, during post-processing.
807824
if (pss->collect_timeline) {
808-
// TODO(suharshs): Can these two RPCs be combined?
809825
SetRPCLogging(env, false);
810826
RetrieveLogs(env, step_id, &pss->rpc_stats);
811827
}
812-
for (int i = 0; i < partitions_.size(); ++i) {
828+
for (size_t i = 0; i < partitions_.size(); ++i) {
813829
const StepStats& ss = pss->step_stats[i];
814830
if (pss->collect_costs) {
815831
execution_state->UpdateCostsFromStats(ss);
816832
}
833+
if (ph) {
834+
for (const auto& ds : ss.dev_stats()) {
835+
ProcessDeviceStats(ph, execution_state, ds, false /*is_rpc*/);
836+
}
837+
}
838+
}
839+
if (ph) {
840+
for (const auto& ds : pss->rpc_stats.dev_stats()) {
841+
ProcessDeviceStats(ph, execution_state, ds, true /*is_rpc*/);
842+
}
843+
ph->StepDone(pss->start_micros, pss->end_micros,
844+
Microseconds(0) /*cleanup_time*/, 0 /*total_runops*/,
845+
Status::OK());
817846
}
818847
// Assemble all stats for this timeline into a merged StepStats.
819848
StepStats step_stats_proto;
820849
if (pss->collect_timeline) {
821850
step_stats_proto = pss->rpc_stats;
822-
for (int i = 0; i < partitions_.size(); ++i) {
851+
for (size_t i = 0; i < partitions_.size(); ++i) {
823852
const StepStats& ss = pss->step_stats[i];
824853
step_stats_proto.MergeFrom(ss);
825854
}
826-
// TODO(suharshs): handle timeline_step when adding timeline support.
827-
resp->mutable_metadata()->mutable_step_stats()->Swap(&step_stats_proto);
855+
stats_publisher_->PublishStatsProto(step_stats_proto);
856+
// Copy the stats back, but only for on-demand profiling to avoid slowing
857+
// down calls that trigger the automatic profiling.
858+
if (session_opts_.config.graph_options().timeline_step() <= 0) {
859+
resp->mutable_metadata()->mutable_step_stats()->Swap(&step_stats_proto);
860+
}
828861
}
829862
}
830863

831-
// Makes async calls to workers without waiting deregistering subgraphs.
864+
// Reports every NodeExecStats entry in `ds` to the profile handler `ph` via
// ProfileHandler::RecordOneOp().
//
// `ph` is dereferenced unconditionally, so it must be non-null here.
// `execution_state` is only consulted when `is_rpc` is false, to look up the
// NodeDef for each node name.
// `is_rpc` selects between the two recording paths below.
void MasterSession::ReffedClientGraph::ProcessDeviceStats(
865+
ProfileHandler* ph, SimpleGraphExecutionState* execution_state,
866+
const DeviceStepStats& ds, bool is_rpc) {
867+
const string& dev_name = ds.device();
868+
VLOG(1) << "Device " << dev_name << " reports stats for "
869+
<< ds.node_stats_size() << " nodes";
870+
for (const auto& ns : ds.node_stats()) {
871+
if (is_rpc) {
872+
// We don't have access to a good Node pointer, so we rely on
873+
// sufficient data being present in the NodeExecStats.
// The node name is passed in the argument slot where the non-RPC branch
// passes the op type, and the timeline label is used as the details.
874+
ph->RecordOneOp(dev_name, ns, true /*is_copy*/, "", ns.node_name(),
875+
ns.timeline_label());
876+
} else {
877+
NodeDef ndef;
878+
Status s = execution_state->GlobalNodeDefByName(ns.node_name(), &ndef);
879+
const bool found_node_in_graph = s.ok();
// A node that is neither in the graph nor carries a timeline label is
// skipped; only the first few such misses are logged.
880+
if (!found_node_in_graph && ns.timeline_label().empty()) {
881+
// The counter incrementing is not thread-safe. But we don't really
882+
// care.
883+
// TODO(zhengxq): we should implement a LOG_FIRST_N and LOG_EVERY_N for
884+
// more general usage.
885+
static int log_counter = 0;
886+
if (log_counter < 10) {
887+
log_counter++;
888+
LOG(WARNING) << "Failed to find node " << ns.node_name()
889+
<< " for dev " << dev_name;
890+
}
891+
continue;
892+
}
// Fall back to the node name as the op type when the lookup failed.
893+
string optype = found_node_in_graph ? ndef.op() : ns.node_name();
// Details preference: the timeline label if present, otherwise DetailText()
// when the NodeDef was found.
894+
string details;
895+
if (!ns.timeline_label().empty()) {
896+
details = ns.timeline_label();
897+
} else if (found_node_in_graph) {
898+
details = DetailText(ndef, ns);
899+
} else {
900+
// Leave details string empty
901+
}
902+
ph->RecordOneOp(dev_name, ns, false /*is_copy*/, ns.node_name(), optype,
903+
details);
904+
}
905+
}
906+
}
907+
908+
// Asynchronously deregisters subgraphs on the workers, without waiting for the
909+
// result.
832910
void MasterSession::ReffedClientGraph::DeregisterPartitions() {
833911
struct Call {
834912
DeregisterGraphRequest req;
@@ -857,7 +935,6 @@ void BuildBuildGraphOptions(const RunStepRequest& req,
857935
opts->feed_endpoints.push_back(feed.name());
858936
}
859937
for (const auto& fetch : req.fetch()) {
860-
// TODO(touts): handle ref:
861938
opts->fetch_endpoints.push_back(fetch);
862939
}
863940
for (const auto& target : req.target()) {
@@ -901,10 +978,12 @@ string BuildGraphOptionsString(const BuildGraphOptions& opts) {
901978
}
902979

903980
MasterSession::MasterSession(const SessionOptions& opt, const MasterEnv* env,
904-
std::vector<Device*>* remote_devs)
981+
std::vector<Device*>* remote_devs,
982+
StatsPublisherFactory stats_publisher_factory)
905983
: session_opts_(opt),
906984
env_(env),
907985
handle_(strings::FpToString(random::New64())),
986+
stats_publisher_factory_(std::move(stats_publisher_factory)),
908987
graph_version_(0),
909988
runs_(5),
910989
cancellation_manager_(new CancellationManager) {
@@ -944,8 +1023,8 @@ void MasterSession::UpdateLastAccessTime() {
9441023
Status MasterSession::Create(GraphDef* graph_def) {
9451024
if (session_opts_.config.graph_options().place_pruned_graph()) {
9461025
// TODO(b/29900832): Fix this or remove the option.
947-
return errors::Unimplemented(
948-
"MasterSession does not support the place_pruned_graph option.");
1026+
LOG(WARNING) << "Distributed session does not support the "
1027+
"place_pruned_graph option.";
9491028
}
9501029

9511030
SimpleGraphExecutionStateOptions options;
@@ -1014,7 +1093,7 @@ Status MasterSession::StartStep(const RunStepRequest& req,
10141093
TF_RETURN_IF_ERROR(execution_state_->BuildGraph(*opts, &client_graph));
10151094
auto entry =
10161095
new ReffedClientGraph(handle_, *opts, std::move(client_graph),
1017-
session_opts_.config.graph_options());
1096+
session_opts_, stats_publisher_factory_);
10181097
iter = runs_.insert({hash, entry}).first;
10191098
auto obs_iter = obsolete_.find(hash);
10201099
if (obs_iter != obsolete_.end()) {
@@ -1121,9 +1200,14 @@ Status MasterSession::DoRunWithLocalExecution(CallOptions* opts,
11211200
const uint64 step_id = (random::New64() & ((1uLL << 56) - 1)) | (1uLL << 56);
11221201
TRACEPRINTF("stepid %llu", step_id);
11231202

1203+
std::unique_ptr<ProfileHandler> ph;
11241204
pss.collect_timeline = req->options().trace_level() == RunOptions::FULL_TRACE;
11251205
pss.collect_costs = (0 == (count % CostFrequency(count)));
1126-
pss.collect_rpcs = false;
1206+
ph = rcg->GetProfileHandler(step_id, count, req->options());
1207+
if (ph) {
1208+
pss.collect_timeline = true;
1209+
pss.collect_rpcs = true;
1210+
}
11271211

11281212
TF_RETURN_IF_ERROR(rcg->RunPartitions(env_, step_id, count,
11291213
execution_state_.get(), &pss, opts,
@@ -1132,11 +1216,14 @@ Status MasterSession::DoRunWithLocalExecution(CallOptions* opts,
11321216
pss.end_micros = Env::Default()->NowMicros();
11331217

11341218
// Schedule post-processing and cleanup to be done asynchronously.
1135-
rcg->ProcessStats(env_, step_id, &pss, execution_state_.get(), resp);
1136-
rcg->CleanupPartitionsAsync(step_id, [](const Status& s) {
1219+
rcg->Ref();
1220+
rcg->ProcessStats(env_, step_id, &pss, execution_state_.get(), ph.get(),
1221+
resp);
1222+
rcg->CleanupPartitionsAsync(step_id, [rcg](const Status& s) {
11371223
if (!s.ok()) {
11381224
LOG(ERROR) << "Cleanup partition error: " << s;
11391225
}
1226+
rcg->Unref();
11401227
});
11411228
return Status::OK();
11421229
}
@@ -1161,10 +1248,11 @@ Status MasterSession::Close() {
11611248

11621249
namespace internal {
11631250

1164-
MasterSessionInterface* NewMasterSession(const SessionOptions& options,
1165-
const MasterEnv* env,
1166-
std::vector<Device*>* remote_devs) {
1167-
return new MasterSession(options, env, remote_devs);
1251+
MasterSessionInterface* NewMasterSession(
1252+
const SessionOptions& options, const MasterEnv* env,
1253+
std::vector<Device*>* remote_devs,
1254+
StatsPublisherFactory stats_publisher_factory) {
1255+
return new MasterSession(options, env, remote_devs, stats_publisher_factory);
11681256
}
11691257

11701258
} // end namespace internal

0 commit comments

Comments
 (0)