Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 39b75e6

Browse files
committed
Implemented assemble scheduling command + don't sort sparse accelerator if performing reduction
1 parent f0232d3 commit 39b75e6

10 files changed

Lines changed: 675 additions & 553 deletions

File tree

include/taco/index_notation/index_notation.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -640,6 +640,8 @@ class IndexStmt : public util::IntrusivePtr<const IndexStmtNode> {
640640
/// integer number of iterations
641641
/// Preconditions: unrollFactor is a positive nonzero integer
642642
IndexStmt unroll(IndexVar i, size_t unrollFactor) const;
643+
644+
IndexStmt assemble(TensorVar result, AssembleStrategy strategy) const;
643645
};
644646

645647
/// Check if two index statements are isomorphic.

include/taco/index_notation/transformations.h

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ class ForAllReplace;
2121
class AddSuchThatPredicates;
2222
class Parallelize;
2323
class TopoReorder;
24+
class SetAssembleStrategy;
2425

2526
/// A transformation is an optimization that transforms a statement in the
2627
/// concrete index notation into a new statement that computes the same result
@@ -34,6 +35,7 @@ class Transformation {
3435
Transformation(Parallelize);
3536
Transformation(TopoReorder);
3637
Transformation(AddSuchThatPredicates);
38+
Transformation(SetAssembleStrategy);
3739

3840
IndexStmt apply(IndexStmt stmt, std::string *reason = nullptr) const;
3941

@@ -108,6 +110,7 @@ class Precompute : public TransformationInterface {
108110
/// Print a precompute command.
109111
std::ostream &operator<<(std::ostream &, const Precompute &);
110112

113+
111114
/// Replaces all occurrences of directly nested forall nodes of pattern with
112115
/// directly nested loops of replacement
113116
class ForAllReplace : public TransformationInterface {
@@ -129,6 +132,10 @@ class ForAllReplace : public TransformationInterface {
129132
std::shared_ptr<Content> content;
130133
};
131134

135+
/// Print a ForAllReplace command.
136+
std::ostream &operator<<(std::ostream &, const ForAllReplace &);
137+
138+
132139
/// Adds a SuchThat node if it does not exist and adds the given IndexVarRels
133140
class AddSuchThatPredicates : public TransformationInterface {
134141
public:
@@ -147,6 +154,9 @@ class AddSuchThatPredicates : public TransformationInterface {
147154
std::shared_ptr<Content> content;
148155
};
149156

157+
std::ostream& operator<<(std::ostream&, const AddSuchThatPredicates&);
158+
159+
150160
/// The parallelize optimization tags a Forall as parallelized
151161
/// after checking for preconditions
152162
class Parallelize : public TransformationInterface {
@@ -169,13 +179,28 @@ class Parallelize : public TransformationInterface {
169179
std::shared_ptr<Content> content;
170180
};
171181

172-
/// Print a ForAllReplace command.
173-
std::ostream &operator<<(std::ostream &, const ForAllReplace &);
174-
175182
/// Print a parallelize command.
176183
std::ostream& operator<<(std::ostream&, const Parallelize&);
177184

178-
std::ostream& operator<<(std::ostream&, const AddSuchThatPredicates&);
185+
186+
class SetAssembleStrategy : public TransformationInterface {
187+
public:
188+
SetAssembleStrategy(TensorVar result, AssembleStrategy strategy);
189+
190+
TensorVar getResult() const;
191+
AssembleStrategy getAssembleStrategy() const;
192+
193+
IndexStmt apply(IndexStmt stmt, std::string *reason = nullptr) const;
194+
195+
void print(std::ostream &os) const;
196+
197+
private:
198+
struct Content;
199+
std::shared_ptr<Content> content;
200+
};
201+
202+
/// Print a SetAssembleStrategy command.
203+
std::ostream &operator<<(std::ostream &, const SetAssembleStrategy&);
179204

180205
// Autoscheduling functions
181206

@@ -202,12 +227,11 @@ IndexStmt reorderLoopsTopologically(IndexStmt stmt);
202227
*/
203228
IndexStmt scalarPromote(IndexStmt stmt);
204229

205-
IndexStmt insertAttributeQueries(IndexStmt stmt);
206-
207230
/**
208231
* Insert where statements with temporaries into the following statements kinds:
209232
* 1. The result is scattered into but does not support random insert.
210233
*/
211234
IndexStmt insertTemporaries(IndexStmt stmt);
235+
212236
}
213237
#endif

include/taco/ir_tags.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,12 @@ enum class BoundType {
2727
MinExact, MinConstraint, MaxExact, MaxConstraint
2828
};
2929
extern const char *BoundType_NAMES[];
30+
31+
enum class AssembleStrategy {
32+
Append, Insert
33+
};
34+
extern const char *AssembleStrategy_NAMES[];
35+
3036
}
3137

3238
#endif //TACO_IR_TAGS_H

include/taco/lower/lowerer_impl.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#ifndef TACO_LOWERER_IMPL_H
22
#define TACO_LOWERER_IMPL_H
33

4+
#include <utility>
45
#include <vector>
56
#include <map>
67
#include <set>
@@ -348,7 +349,7 @@ class LowererImpl : public util::Uncopyable {
348349

349350
/// Returns true iff the temporary used in the where statement is dense and sparse iteration over that
350351
/// temporary can be automatically supported by the compiler.
351-
bool canAccelerateDenseTemp(Where where);
352+
std::pair<bool,bool> canAccelerateDenseTemp(Where where);
352353

353354
/// Initializes a temporary workspace
354355
std::vector<ir::Stmt> codeToInitializeTemporary(Where where);

src/index_notation/index_notation.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1560,6 +1560,17 @@ IndexStmt IndexStmt::unroll(IndexVar i, size_t unrollFactor) const {
15601560
return UnrollLoop(i, unrollFactor).rewrite(*this);
15611561
}
15621562

1563+
IndexStmt IndexStmt::assemble(TensorVar result,
1564+
AssembleStrategy strategy) const {
1565+
string reason;
1566+
IndexStmt transformed =
1567+
SetAssembleStrategy(result, strategy).apply(*this, &reason);
1568+
if (!transformed.defined()) {
1569+
taco_uerror << reason;
1570+
}
1571+
return transformed;
1572+
}
1573+
15631574
std::ostream& operator<<(std::ostream& os, const IndexStmt& expr) {
15641575
if (!expr.defined()) return os << "IndexStmt()";
15651576
IndexNotationPrinter printer(os);

0 commit comments

Comments
 (0)