@@ -21,7 +21,9 @@ limitations under the License.
2121
2222#include " tensorflow/core/common_runtime/device_set.h"
2323#include " tensorflow/core/common_runtime/process_util.h"
24+ #include " tensorflow/core/common_runtime/profile_handler.h"
2425#include " tensorflow/core/common_runtime/simple_graph_execution_state.h"
26+ #include " tensorflow/core/common_runtime/stats_publisher_interface.h"
2527#include " tensorflow/core/distributed_runtime/master_env.h"
2628#include " tensorflow/core/distributed_runtime/master_session_interface.h"
2729#include " tensorflow/core/distributed_runtime/scheduler.h"
@@ -75,9 +77,10 @@ class MasterSession : public MasterSessionInterface {
7577 //
7678 // The caller takes ownership of all remote devices.
7779 MasterSession (const SessionOptions& options, const MasterEnv* env,
78- std::vector<Device*>* remote_devs);
80+ std::vector<Device*>* remote_devs,
81+ StatsPublisherFactory stats_publisher_factory);
7982
80- // Initialize the Session for "def". Must be called before Extend(),
83+ // Initialize the MasterSession for "def". Must be called before Extend(),
8184 // Run(), or Close().
8285 //
8386 // The callee may clear "def".
@@ -130,6 +133,8 @@ class MasterSession : public MasterSessionInterface {
130133 // The device set used by this session.
131134 DeviceSet devices_;
132135
136+ StatsPublisherFactory stats_publisher_factory_;
137+
133138 std::atomic_ulong last_access_time_usec_;
134139
135140 mutex mu_;
@@ -173,35 +178,37 @@ class MasterSession : public MasterSessionInterface {
173178 TF_DISALLOW_COPY_AND_ASSIGN (MasterSession);
174179};
175180
// ReffedClientGraph wraps SimpleClientGraph in a reference counted object.
// This way, MasterSession can clear up the cache mapping Run requests to
// compiled graphs while the compiled graph is still being used.
179184//
180185// TODO(zhifengc): Cleanup this class. It's becoming messy.
181186class MasterSession ::ReffedClientGraph : public core::RefCounted {
182187 public:
183188 ReffedClientGraph (const string& handle, const BuildGraphOptions& bopts,
184189 std::unique_ptr<SimpleClientGraph> cg,
185- const GraphOptions& graph_opts)
190+ const SessionOptions& session_opts,
191+ StatsPublisherFactory stats_publisher_factory)
186192 : session_handle_(handle),
187193 client_graph_ (std::move(cg)),
188194 bopts_(bopts),
189- graph_opts_(graph_opts ) {
195+ session_opts_(session_opts ) {
190196 VLOG (1 ) << " Created ReffedClientGraph for node with "
191197 << client_graph_->graph .num_node_ids ();
192198
193- const string key =
194- strings::StrCat (" {" , str_util::Join (bopts.feed_endpoints , " ," ), " },{" ,
195- str_util::Join (bopts.target_nodes , " ," ), " },{" ,
196- str_util::Join (bopts.fetch_endpoints , " ," ), " }" );
197- // TODO(mrry): Publish information about the graph (such as
198- // timelines, the pruned graph, statistics, etc.).
199+ stats_publisher_ = stats_publisher_factory (handle, bopts, session_opts);
199200 }
200201
201202 ~ReffedClientGraph () override { DeregisterPartitions (); }
202203
203204 const SimpleClientGraph* client_graph () { return client_graph_.get (); }
204205
206+ std::unique_ptr<ProfileHandler> GetProfileHandler (uint64 step,
207+ int64 execution_count,
208+ const RunOptions& ropts) {
209+ return stats_publisher_->GetProfileHandler (step, execution_count, ropts);
210+ }
211+
205212 // Turn RPC logging on or off, both at the WorkerCache used by this
206213 // master process, and at each remote worker in use for the current
207214 // partitions.
@@ -297,8 +304,9 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
297304 // Post-processing of any runtime statistics gathered during execution.
298305 void ProcessStats (const MasterEnv* env, int64 step_id, PerStepState* pss,
299306 SimpleGraphExecutionState* execution_state,
300- RunStepResponse* resp);
301- void ProcessDeviceStats (SimpleGraphExecutionState* execution_state,
307+ ProfileHandler* ph, RunStepResponse* resp);
308+ void ProcessDeviceStats (ProfileHandler* ph,
309+ SimpleGraphExecutionState* execution_state,
302310 const DeviceStepStats& ds, bool is_rpc);
303311
304312 string DetailText (const NodeDef& def, const NodeExecStats& ns) {
@@ -323,7 +331,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
323331 const std::unique_ptr<SimpleClientGraph> client_graph_;
324332 std::unordered_set<const Node*> nodes_needing_input_mapping_;
325333 BuildGraphOptions bopts_;
326- const GraphOptions graph_opts_ ;
334+ const SessionOptions session_opts_ ;
327335
328336 // Graph partitioned into per-location subgraphs.
329337 struct Part {
@@ -365,6 +373,8 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
365373 // init_result_ remembers the initialization error if any.
366374 Status init_result_ GUARDED_BY (mu_);
367375
376+ std::unique_ptr<StatsPublisherInterface> stats_publisher_;
377+
368378 // Send/Recv nodes that are the result of client-added
369379 // feeds and fetches must be tracked so that the tensors
370380 // can be added to the local rendezvous.
@@ -391,6 +401,12 @@ Status MasterSession::ReffedClientGraph::RegisterPartitions(
391401 init_started_ = true ;
392402 mu_.unlock ();
393403 Status s = DoRegisterPartitions (env, popts, func_def_lib);
404+ std::vector<const GraphDef*> graph_defs;
405+ graph_defs.reserve (partitions_.size ());
406+ for (const Part& part : partitions_) {
407+ graph_defs.push_back (&part.gdef );
408+ }
409+ stats_publisher_->PublishGraphProto (graph_defs);
394410 mu_.lock ();
395411 init_result_ = s;
396412 init_done_.Notify ();
@@ -504,7 +520,7 @@ Status MasterSession::ReffedClientGraph::DoRegisterPartitions(
504520 Call* c = &calls[i];
505521 c->req .set_session_handle (session_handle_);
506522 *c->req .mutable_graph_def () = part.gdef ;
507- *c->req .mutable_graph_options () = graph_opts_ ;
523+ *c->req .mutable_graph_options () = session_opts_. config . graph_options () ;
508524 VLOG (2 ) << " Register " << part.gdef .DebugString ();
509525 auto cb = [c](const Status& s) {
510526 c->status = s;
@@ -800,35 +816,97 @@ void MasterSession::ReffedClientGraph::CleanupPartitionsAsync(
800816
801817void MasterSession::ReffedClientGraph::ProcessStats (
802818 const MasterEnv* env, int64 step_id, PerStepState* pss,
803- SimpleGraphExecutionState* execution_state, RunStepResponse* resp) {
819+ SimpleGraphExecutionState* execution_state, ProfileHandler* ph,
820+ RunStepResponse* resp) {
804821 if (!pss->collect_costs && !pss->collect_timeline ) return ;
805822
806823 // Out-of-band logging data is collected now, during post-processing.
807824 if (pss->collect_timeline ) {
808- // TODO(suharshs): Can these two RPCs be combined?
809825 SetRPCLogging (env, false );
810826 RetrieveLogs (env, step_id, &pss->rpc_stats );
811827 }
812- for (int i = 0 ; i < partitions_.size (); ++i) {
828+ for (size_t i = 0 ; i < partitions_.size (); ++i) {
813829 const StepStats& ss = pss->step_stats [i];
814830 if (pss->collect_costs ) {
815831 execution_state->UpdateCostsFromStats (ss);
816832 }
833+ if (ph) {
834+ for (const auto & ds : ss.dev_stats ()) {
835+ ProcessDeviceStats (ph, execution_state, ds, false /* is_rpc*/ );
836+ }
837+ }
838+ }
839+ if (ph) {
840+ for (const auto & ds : pss->rpc_stats .dev_stats ()) {
841+ ProcessDeviceStats (ph, execution_state, ds, true /* is_rpc*/ );
842+ }
843+ ph->StepDone (pss->start_micros , pss->end_micros ,
844+ Microseconds (0 ) /* cleanup_time*/ , 0 /* total_runops*/ ,
845+ Status::OK ());
817846 }
818847 // Assemble all stats for this timeline into a merged StepStats.
819848 StepStats step_stats_proto;
820849 if (pss->collect_timeline ) {
821850 step_stats_proto = pss->rpc_stats ;
822- for (int i = 0 ; i < partitions_.size (); ++i) {
851+ for (size_t i = 0 ; i < partitions_.size (); ++i) {
823852 const StepStats& ss = pss->step_stats [i];
824853 step_stats_proto.MergeFrom (ss);
825854 }
826- // TODO(suharshs): handle timeline_step when adding timeline support.
827- resp->mutable_metadata ()->mutable_step_stats ()->Swap (&step_stats_proto);
855+ stats_publisher_->PublishStatsProto (step_stats_proto);
856+ // Copy the stats back, but only for on-demand profiling to avoid slowing
857+ // down calls that trigger the automatic profiling.
858+ if (session_opts_.config .graph_options ().timeline_step () <= 0 ) {
859+ resp->mutable_metadata ()->mutable_step_stats ()->Swap (&step_stats_proto);
860+ }
828861 }
829862}
830863
831- // Makes async calls to workers without waiting deregistering subgraphs.
864+ void MasterSession::ReffedClientGraph::ProcessDeviceStats (
865+ ProfileHandler* ph, SimpleGraphExecutionState* execution_state,
866+ const DeviceStepStats& ds, bool is_rpc) {
867+ const string& dev_name = ds.device ();
868+ VLOG (1 ) << " Device " << dev_name << " reports stats for "
869+ << ds.node_stats_size () << " nodes" ;
870+ for (const auto & ns : ds.node_stats ()) {
871+ if (is_rpc) {
872+ // We don't have access to a good Node pointer, so we rely on
873+ // sufficient data being present in the NodeExecStats.
874+ ph->RecordOneOp (dev_name, ns, true /* is_copy*/ , " " , ns.node_name (),
875+ ns.timeline_label ());
876+ } else {
877+ NodeDef ndef;
878+ Status s = execution_state->GlobalNodeDefByName (ns.node_name (), &ndef);
879+ const bool found_node_in_graph = s.ok ();
880+ if (!found_node_in_graph && ns.timeline_label ().empty ()) {
881+ // The counter incrementing is not thread-safe. But we don't really
882+ // care.
883+ // TODO(zhengxq): we should implement a LOG_FIRST_N and LOG_EVERY_N for
884+ // more general usage.
885+ static int log_counter = 0 ;
886+ if (log_counter < 10 ) {
887+ log_counter++;
888+ LOG (WARNING) << " Failed to find node " << ns.node_name ()
889+ << " for dev " << dev_name;
890+ }
891+ continue ;
892+ }
893+ string optype = found_node_in_graph ? ndef.op () : ns.node_name ();
894+ string details;
895+ if (!ns.timeline_label ().empty ()) {
896+ details = ns.timeline_label ();
897+ } else if (found_node_in_graph) {
898+ details = DetailText (ndef, ns);
899+ } else {
900+ // Leave details string empty
901+ }
902+ ph->RecordOneOp (dev_name, ns, false /* is_copy*/ , ns.node_name (), optype,
903+ details);
904+ }
905+ }
906+ }
907+
908+ // Asynchronously deregisters subgraphs on the workers, without waiting for the
909+ // result.
832910void MasterSession::ReffedClientGraph::DeregisterPartitions () {
833911 struct Call {
834912 DeregisterGraphRequest req;
@@ -857,7 +935,6 @@ void BuildBuildGraphOptions(const RunStepRequest& req,
857935 opts->feed_endpoints .push_back (feed.name ());
858936 }
859937 for (const auto & fetch : req.fetch ()) {
860- // TODO(touts): handle ref:
861938 opts->fetch_endpoints .push_back (fetch);
862939 }
863940 for (const auto & target : req.target ()) {
@@ -901,10 +978,12 @@ string BuildGraphOptionsString(const BuildGraphOptions& opts) {
901978}
902979
903980MasterSession::MasterSession (const SessionOptions& opt, const MasterEnv* env,
904- std::vector<Device*>* remote_devs)
981+ std::vector<Device*>* remote_devs,
982+ StatsPublisherFactory stats_publisher_factory)
905983 : session_opts_(opt),
906984 env_ (env),
907985 handle_(strings::FpToString(random::New64())),
986+ stats_publisher_factory_(std::move(stats_publisher_factory)),
908987 graph_version_(0 ),
909988 runs_(5 ),
910989 cancellation_manager_(new CancellationManager) {
@@ -944,8 +1023,8 @@ void MasterSession::UpdateLastAccessTime() {
9441023Status MasterSession::Create (GraphDef* graph_def) {
9451024 if (session_opts_.config .graph_options ().place_pruned_graph ()) {
9461025 // TODO(b/29900832): Fix this or remove the option.
947- return errors::Unimplemented (
948- " MasterSession does not support the place_pruned_graph option." ) ;
1026+ LOG (WARNING) << " Distributed session does not support the "
1027+ " place_pruned_graph option." ;
9491028 }
9501029
9511030 SimpleGraphExecutionStateOptions options;
@@ -1014,7 +1093,7 @@ Status MasterSession::StartStep(const RunStepRequest& req,
10141093 TF_RETURN_IF_ERROR (execution_state_->BuildGraph (*opts, &client_graph));
10151094 auto entry =
10161095 new ReffedClientGraph (handle_, *opts, std::move (client_graph),
1017- session_opts_. config . graph_options () );
1096+ session_opts_, stats_publisher_factory_ );
10181097 iter = runs_.insert ({hash, entry}).first ;
10191098 auto obs_iter = obsolete_.find (hash);
10201099 if (obs_iter != obsolete_.end ()) {
@@ -1121,9 +1200,14 @@ Status MasterSession::DoRunWithLocalExecution(CallOptions* opts,
11211200 const uint64 step_id = (random::New64 () & ((1uLL << 56 ) - 1 )) | (1uLL << 56 );
11221201 TRACEPRINTF (" stepid %llu" , step_id);
11231202
1203+ std::unique_ptr<ProfileHandler> ph;
11241204 pss.collect_timeline = req->options ().trace_level () == RunOptions::FULL_TRACE;
11251205 pss.collect_costs = (0 == (count % CostFrequency (count)));
1126- pss.collect_rpcs = false ;
1206+ ph = rcg->GetProfileHandler (step_id, count, req->options ());
1207+ if (ph) {
1208+ pss.collect_timeline = true ;
1209+ pss.collect_rpcs = true ;
1210+ }
11271211
11281212 TF_RETURN_IF_ERROR (rcg->RunPartitions (env_, step_id, count,
11291213 execution_state_.get (), &pss, opts,
@@ -1132,11 +1216,14 @@ Status MasterSession::DoRunWithLocalExecution(CallOptions* opts,
11321216 pss.end_micros = Env::Default ()->NowMicros ();
11331217
11341218 // Schedule post-processing and cleanup to be done asynchronously.
1135- rcg->ProcessStats (env_, step_id, &pss, execution_state_.get (), resp);
1136- rcg->CleanupPartitionsAsync (step_id, [](const Status& s) {
1219+ rcg->Ref ();
1220+ rcg->ProcessStats (env_, step_id, &pss, execution_state_.get (), ph.get (),
1221+ resp);
1222+ rcg->CleanupPartitionsAsync (step_id, [rcg](const Status& s) {
11371223 if (!s.ok ()) {
11381224 LOG (ERROR) << " Cleanup partition error: " << s;
11391225 }
1226+ rcg->Unref ();
11401227 });
11411228 return Status::OK ();
11421229}
@@ -1161,10 +1248,11 @@ Status MasterSession::Close() {
11611248
11621249namespace internal {
11631250
1164- MasterSessionInterface* NewMasterSession (const SessionOptions& options,
1165- const MasterEnv* env,
1166- std::vector<Device*>* remote_devs) {
1167- return new MasterSession (options, env, remote_devs);
1251+ MasterSessionInterface* NewMasterSession (
1252+ const SessionOptions& options, const MasterEnv* env,
1253+ std::vector<Device*>* remote_devs,
1254+ StatsPublisherFactory stats_publisher_factory) {
1255+ return new MasterSession (options, env, remote_devs, stats_publisher_factory);
11681256}
11691257
11701258} // end namespace internal
0 commit comments