@@ -24,6 +24,7 @@ an efficient work-stealing scheduling algorithm to run a taskflow.
24
24
class Executor {
25
25
26
26
friend class Subflow ;
27
+ friend class cudaFlow ;
27
28
28
29
struct Worker {
29
30
size_t id;
@@ -271,17 +272,21 @@ class Executor {
271
272
void _invoke_dynamic_work_external (Node*, Graph&, bool );
272
273
void _invoke_condition_work (Worker&, Node*);
273
274
void _invoke_module_work (Worker&, Node*);
274
-
275
- #ifdef TF_ENABLE_CUDA
276
- void _invoke_cudaflow_work (Worker&, Node*);
277
- void _invoke_cudaflow_work_internal (Worker&, Node*);
278
- #endif
279
-
280
275
void _set_up_topology (Topology*);
281
276
void _tear_down_topology (Topology*);
282
277
void _increment_topology ();
283
278
void _decrement_topology ();
284
279
void _decrement_topology_and_notify ();
280
+
281
+ #ifdef TF_ENABLE_CUDA
282
+ void _invoke_cudaflow_work (Worker&, Node*);
283
+
284
+ template <typename P>
285
+ void _invoke_cudaflow_work_internal (Worker&, cudaFlow&, P&&);
286
+
287
+ template <typename P>
288
+ void _invoke_cudaflow_work_external (cudaFlow&, P&&);
289
+ #endif
285
290
};
286
291
287
292
@@ -993,12 +998,80 @@ inline void Executor::_invoke_condition_work(Worker& worker, Node* node) {
993
998
#ifdef TF_ENABLE_CUDA
994
999
// Procedure: _invoke_cudaflow_work
995
1000
inline void Executor::_invoke_cudaflow_work(Worker& worker, Node* node) {

  // Runs a cudaFlow task node on the given worker: the node's callable
  // fills in a cudaFlow, which is then launched once unless the callable
  // already joined it. Observer prologue/epilogue bracket the whole task.
  _observer_prologue(worker, node);

  assert(worker.domain == node->domain());

  // fetch the cudaFlow payload of the node and start from an empty graph
  auto& handle = nstd::get<Node::cudaFlowWork>(node->_handle);
  handle.graph.clear();

  // build the cudaFlow over the node's graph and let the callable populate
  // it (the trailing 0 is presumably the device index -- TODO confirm)
  cudaFlow flow(*this, handle.graph, 0);
  handle.work(flow);

  // still joinable: the callable did not join on its own, so run it here
  if(flow._joinable) {
    // returns false on the first call and true on the second, which makes
    // the launch loop of the internal routine execute exactly once
    auto run_once = [remaining=1] () mutable { return remaining-- == 0; };
    _invoke_cudaflow_work_internal(worker, flow, run_once);
    flow._joinable = false;
  }

  _observer_epilogue(worker, node);
}
1000
1025
1001
1026
// Procedure: _invoke_cudaflow_work_internal
1027
+ template <typename P>
1028
+ void Executor::_invoke_cudaflow_work_internal (
1029
+ Worker& w, cudaFlow& cf, P&& predicate
1030
+ ) {
1031
+
1032
+ if (cf.empty ()) {
1033
+ return ;
1034
+ }
1035
+
1036
+ cudaScopedDevice ctx (cf._device );
1037
+
1038
+ auto s = _cuda_devices[cf._device ].streams [w.id - _id_offset[w.domain ]];
1039
+
1040
+ // transforms cudaFlow to a native cudaGraph under the specified device
1041
+ // and launches the graph through a given or an internal device stream
1042
+ // TODO: need to leverage cudaGraphExecUpdate for changes between
1043
+ // successive offload calls; right now, we assume the graph
1044
+ // is not changed (only update parameter is allowed)
1045
+ cf._graph ._create_native_graph ();
1046
+
1047
+ while (!predicate ()) {
1048
+
1049
+ TF_CHECK_CUDA (
1050
+ cudaGraphLaunch (cf._graph ._native_exec_handle , s),
1051
+ " failed to launch cudaFlow on device " , cf._device
1052
+ );
1053
+
1054
+ TF_CHECK_CUDA (
1055
+ cudaStreamSynchronize (s),
1056
+ " failed to synchronize cudaFlow on device " , cf._device
1057
+ );
1058
+ }
1059
+
1060
+ cf._graph ._destroy_native_graph ();
1061
+ }
1062
+
1063
+ // Procedure: _invoke_cudaflow_work_external
1064
+ template <typename P>
1065
+ void Executor::_invoke_cudaflow_work_external (cudaFlow& cf, P&& predicate) {
1066
+
1067
+ auto w = _per_thread ().worker ;
1068
+
1069
+ assert (w && w->executor == this );
1070
+
1071
+ _invoke_cudaflow_work_internal (*w, cf, std::forward<P>(predicate));
1072
+ }
1073
+
1074
+ /* // Procedure: _invoke_cudaflow_work_internal
1002
1075
inline void Executor::_invoke_cudaflow_work_internal(Worker& w, Node* node) {
1003
1076
1004
1077
assert(w.domain == node->domain());
@@ -1007,7 +1080,7 @@ inline void Executor::_invoke_cudaflow_work_internal(Worker& w, Node* node) {
1007
1080
1008
1081
h.graph.clear();
1009
1082
1010
- cudaFlow cf (h.graph , [repeat=1 ] () mutable { return repeat-- == 0 ; });
1083
+ cudaFlow cf(*this, h.graph, 0 , [repeat=1] () mutable { return repeat-- == 0; });
1011
1084
1012
1085
h.work(cf);
1013
1086
@@ -1021,34 +1094,23 @@ inline void Executor::_invoke_cudaflow_work_internal(Worker& w, Node* node) {
1021
1094
1022
1095
cudaScopedDevice ctx(d);
1023
1096
1024
- auto s = cf._stream ? *(cf._stream ) :
1025
- _cuda_devices[d].streams [w.id - _id_offset[w.domain ]];
1026
-
1027
- h.graph ._make_native_graph ();
1097
+ auto s = _cuda_devices[d].streams[w.id - _id_offset[w.domain]];
1028
1098
1029
- cudaGraphExec_t exec ;
1099
+ h.graph._create_native_graph() ;
1030
1100
1031
- TF_CHECK_CUDA (
1032
- cudaGraphInstantiate (&exec, h.graph ._native_handle , nullptr , nullptr , 0 ),
1033
- " failed to create an executable cudaGraph"
1034
- );
1035
-
1036
1101
while(!cf._predicate()) {
1037
1102
TF_CHECK_CUDA(
1038
- cudaGraphLaunch (exec, s), " failed to launch cudaGraph on stream " , s
1103
+ cudaGraphLaunch(h.graph._native_exec_handle, s),
1104
+ "failed to launch cudaGraph on stream ", s
1039
1105
);
1040
1106
1041
1107
TF_CHECK_CUDA(
1042
1108
cudaStreamSynchronize(s), "failed to synchronize stream ", s
1043
1109
);
1044
1110
}
1045
1111
1046
- TF_CHECK_CUDA (
1047
- cudaGraphExecDestroy (exec), " failed to destroy an executable cudaGraph"
1048
- );
1049
-
1050
- h.graph .clear_native_graph ();
1051
- }
1112
+ h.graph._destroy_native_graph();
1113
+ }*/
1052
1114
#endif
1053
1115
1054
1116
// Procedure: _invoke_module_work
@@ -1283,6 +1345,34 @@ inline void Subflow::detach() {
1283
1345
_joinable = false ;
1284
1346
}
1285
1347
1348
+ // ----------------------------------------------------------------------------
1349
+ // cudaFlow
1350
+ // ----------------------------------------------------------------------------
1351
+
1352
+ #ifdef TF_ENABLE_CUDA
1353
+
1354
+ template <typename P>
1355
+ void cudaFlow::offload (P&& predicate) {
1356
+
1357
+ if (!_joinable) {
1358
+ TF_THROW (" cudaFlow already joined" );
1359
+ }
1360
+
1361
+ _executor._invoke_cudaflow_work_external (*this , std::forward<P>(predicate));
1362
+ }
1363
+
1364
+ template <typename P>
1365
+ void cudaFlow::join (P&& predicate) {
1366
+
1367
+ if (!_joinable) {
1368
+ TF_THROW (" cudaFlow already joined" );
1369
+ }
1370
+
1371
+ _executor._invoke_cudaflow_work_external (*this , std::forward<P>(predicate));
1372
+ _joinable = false ;
1373
+ }
1374
+
1375
+ #endif
1286
1376
1287
1377
} // end of namespace tf -----------------------------------------------------
1288
1378
0 commit comments