@@ -21,6 +21,7 @@ class ForAllReplace;
2121class AddSuchThatPredicates ;
2222class Parallelize ;
2323class TopoReorder ;
24+ class SetAssembleStrategy ;
2425
2526// / A transformation is an optimization that transforms a statement in the
2627// / concrete index notation into a new statement that computes the same result
@@ -34,6 +35,7 @@ class Transformation {
3435 Transformation (Parallelize);
3536 Transformation (TopoReorder);
3637 Transformation (AddSuchThatPredicates);
38+ Transformation (SetAssembleStrategy);
3739
3840 IndexStmt apply (IndexStmt stmt, std::string *reason = nullptr ) const ;
3941
@@ -108,6 +110,7 @@ class Precompute : public TransformationInterface {
108110// / Print a precompute command.
109111std::ostream &operator <<(std::ostream &, const Precompute &);
110112
113+
111114// / Replaces all occurrences of directly nested forall nodes of pattern with
112115// / directly nested loops of replacement
113116class ForAllReplace : public TransformationInterface {
@@ -129,6 +132,10 @@ class ForAllReplace : public TransformationInterface {
129132 std::shared_ptr<Content> content;
130133};
131134
135+ // / Print a ForAllReplace command.
136+ std::ostream &operator <<(std::ostream &, const ForAllReplace &);
137+
138+
132139// / Adds a SuchThat node if it does not exist and adds the given IndexVarRels
133140class AddSuchThatPredicates : public TransformationInterface {
134141public:
@@ -147,6 +154,9 @@ class AddSuchThatPredicates : public TransformationInterface {
147154 std::shared_ptr<Content> content;
148155};
149156
157+ std::ostream& operator <<(std::ostream&, const AddSuchThatPredicates&);
158+
159+
150160// / The parallelize optimization tags a Forall as parallelized
151161// / after checking for preconditions
152162class Parallelize : public TransformationInterface {
@@ -169,13 +179,28 @@ class Parallelize : public TransformationInterface {
169179 std::shared_ptr<Content> content;
170180};
171181
172- // / Print a ForAllReplace command.
173- std::ostream &operator <<(std::ostream &, const ForAllReplace &);
174-
175182// / Print a parallelize command.
176183std::ostream& operator <<(std::ostream&, const Parallelize&);
177184
178- std::ostream& operator <<(std::ostream&, const AddSuchThatPredicates&);
185+
186+ class SetAssembleStrategy : public TransformationInterface {
187+ public:
188+ SetAssembleStrategy (TensorVar result, AssembleStrategy strategy);
189+
190+ TensorVar getResult () const ;
191+ AssembleStrategy getAssembleStrategy () const ;
192+
193+ IndexStmt apply (IndexStmt stmt, std::string *reason = nullptr ) const ;
194+
195+ void print (std::ostream &os) const ;
196+
197+ private:
198+ struct Content ;
199+ std::shared_ptr<Content> content;
200+ };
201+
202+ // / Print a SetAssembleStrategy command.
203+ std::ostream &operator <<(std::ostream &, const SetAssembleStrategy&);
179204
180205// Autoscheduling functions
181206
@@ -202,12 +227,11 @@ IndexStmt reorderLoopsTopologically(IndexStmt stmt);
202227 */
203228IndexStmt scalarPromote (IndexStmt stmt);
204229
205- IndexStmt insertAttributeQueries (IndexStmt stmt);
206-
207230/* *
208231 * Insert where statements with temporaries into the following statements kinds:
209232 * 1. The result is a is scattered into but does not support random insert.
210233 */
211234IndexStmt insertTemporaries (IndexStmt stmt);
235+
212236}
213237#endif
0 commit comments