Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 98c2096

Browse files
jroeschrkimballjunrushaotkonoligecomaniac
authored
[Diagnostics][Relay][InferType] Refactor InferType to work on whole module, and use new diagnostics. (apache#6274)
* Refactor the type checker to use diagnostics Although this patch is very large and seemingly disjoint the fixes are required to get it working for the entire stack. I started with first changing InferType to use the diagnostics, these weren't yet in the pass manager so this required changes to module and module pass. InferType wasn't actually written correctly as a pass requring refactoring there, then in order to add spans to AST it required turning on AnnotateSpans which in term required changes to the parser, and module to make it possible to use the errors. These changes to parse and module required changes to diagnostics and InferType. Althought seemingly disconnected there are hidden cycles between the components which require simultaneous change in order to remove the old error reporting. A huge change due to this patch is that the module no longer implicitly type checks functions which are added. * Apply suggestions from code review Co-authored-by: Robert Kimball <[email protected]> Co-authored-by: Junru Shao <[email protected]> * Apply suggestions from code review Co-authored-by: Tristan Konolige <[email protected]> * Clean up parser * CR feedback * Apply Bobs suggestions * Fix up Python interface for diagnostics * Fix test_ir_parser and formatting * Fix cpplint * Fix lint * Fix format * More lint * Fix format * Kill dead doc comment * Fix documentation comment * Rebase fixups * Add docs for type.h * Fix parser.cc * Fix unittests * Fix black * Skip previously typechecked functions * fix ACL * Fix numerous issues * Add repr method * Fix issue with Pytest, I am ready to cry * Fix the rest of tests * Kill dead code * Fix dignostic tests * Fix more tests * fix more tests (VeriSilicon#11) * Fix diagnostic.py deinit bug * Fix deinit issue * Format * Tweak disabling of override * Format * Fix BYOC * Fix TensorArray stuff * Fix PyTorch * Format * Format Co-authored-by: Robert Kimball <[email protected]> Co-authored-by: Junru Shao <[email protected]> Co-authored-by: Tristan Konolige <[email protected]> Co-authored-by: Cody Yu <[email protected]> Co-authored-by: Zhi <[email protected]>
1 parent f73a1f6 commit 98c2096

128 files changed

Lines changed: 4274 additions & 2594 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ if(MSVC)
9797
add_definitions(-D_CRT_SECURE_NO_WARNINGS)
9898
add_definitions(-D_SCL_SECURE_NO_WARNINGS)
9999
add_definitions(-D_ENABLE_EXTENDED_ALIGNED_STORAGE)
100+
add_definitions(-DNOMINMAX)
100101
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /EHsc")
101102
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /MP")
102103
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /bigobj")

Makefile

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,11 @@ jvminstall:
135135
mvn install -P$(JVM_PKG_PROFILE) -Dcxx="$(CXX)" \
136136
-Dcflags="$(PKG_CFLAGS)" -Dldflags="$(PKG_LDFLAGS)" \
137137
-Dcurrent_libdir="$(ROOTDIR)/$(OUTPUTDIR)" $(JVM_TEST_ARGS))
138+
format:
139+
./tests/lint/git-clang-format.sh -i origin/master
140+
black .
141+
cd rust; which cargo && cargo fmt --all; cd ..
142+
138143

139144
# clean rule
140145
clean:

docker/install/ubuntu_install_arm_compute_lib.sh

100644100755
File mode changed.

docker/install/ubuntu_install_ethosn_driver_stack.sh

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,4 +57,3 @@ git checkout "$repo_revision"
5757

5858
cd "driver"
5959
scons install_prefix="$install_path" install
60-

include/tvm/ir/diagnostic.h

Lines changed: 262 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,262 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file diagnostic.h
22+
* \brief A new diagnostic interface for TVM error reporting.
23+
*
24+
* A prototype of the new diagnostic reporting interface for TVM.
25+
*
26+
* Eventually we hope to promote this file to the top-level and
27+
* replace the existing errors.h.
28+
*/
29+
30+
#ifndef TVM_IR_DIAGNOSTIC_H_
31+
#define TVM_IR_DIAGNOSTIC_H_
32+
33+
#include <tvm/ir/module.h>
34+
#include <tvm/ir/span.h>
35+
#include <tvm/parser/source_map.h>
36+
#include <tvm/runtime/container.h>
37+
#include <tvm/runtime/object.h>
38+
#include <tvm/support/logging.h>
39+
40+
#include <fstream>
41+
#include <string>
42+
#include <utility>
43+
#include <vector>
44+
45+
namespace tvm {
46+
47+
using tvm::parser::SourceMap;
48+
using tvm::runtime::TypedPackedFunc;
49+
50+
extern const char* kTVM_INTERNAL_ERROR_MESSAGE;
51+
52+
#define ICHECK_INDENT " "
53+
54+
#define ICHECK_BINARY_OP(name, op, x, y) \
55+
if (dmlc::LogCheckError _check_err = dmlc::LogCheck##name(x, y)) \
56+
dmlc::LogMessageFatal(__FILE__, __LINE__).stream() \
57+
<< kTVM_INTERNAL_ERROR_MESSAGE << std::endl \
58+
<< ICHECK_INDENT << "Check failed: " << #x " " #op " " #y << *(_check_err.str) << ": "
59+
60+
#define ICHECK(x) \
61+
if (!(x)) \
62+
dmlc::LogMessageFatal(__FILE__, __LINE__).stream() \
63+
<< kTVM_INTERNAL_ERROR_MESSAGE << ICHECK_INDENT << "Check failed: " #x << " == false: "
64+
65+
#define ICHECK_LT(x, y) ICHECK_BINARY_OP(_LT, <, x, y)
66+
#define ICHECK_GT(x, y) ICHECK_BINARY_OP(_GT, >, x, y)
67+
#define ICHECK_LE(x, y) ICHECK_BINARY_OP(_LE, <=, x, y)
68+
#define ICHECK_GE(x, y) ICHECK_BINARY_OP(_GE, >=, x, y)
69+
#define ICHECK_EQ(x, y) ICHECK_BINARY_OP(_EQ, ==, x, y)
70+
#define ICHECK_NE(x, y) ICHECK_BINARY_OP(_NE, !=, x, y)
71+
#define ICHECK_NOTNULL(x) \
72+
((x) == nullptr ? dmlc::LogMessageFatal(__FILE__, __LINE__).stream() \
73+
<< kTVM_INTERNAL_ERROR_MESSAGE << __INDENT << "Check not null: " #x \
74+
<< ' ', \
75+
(x) : (x)) // NOLINT(*)
76+
77+
/*! \brief The diagnostic level, controls the printing of the message. */
78+
enum class DiagnosticLevel : int {
79+
kBug = 10,
80+
kError = 20,
81+
kWarning = 30,
82+
kNote = 40,
83+
kHelp = 50,
84+
};
85+
86+
class DiagnosticBuilder;
87+
88+
/*! \brief A compiler diagnostic. */
89+
class Diagnostic;
90+
91+
/*! \brief A compiler diagnostic message. */
92+
class DiagnosticNode : public Object {
93+
public:
94+
/*! \brief The level. */
95+
DiagnosticLevel level;
96+
/*! \brief The span at which to report an error. */
97+
Span span;
98+
/*! \brief The diagnostic message. */
99+
String message;
100+
101+
// override attr visitor
102+
void VisitAttrs(AttrVisitor* v) {
103+
v->Visit("level", &level);
104+
v->Visit("span", &span);
105+
v->Visit("message", &message);
106+
}
107+
108+
bool SEqualReduce(const DiagnosticNode* other, SEqualReducer equal) const {
109+
return equal(this->level, other->level) && equal(this->span, other->span) &&
110+
equal(this->message, other->message);
111+
}
112+
113+
static constexpr const char* _type_key = "Diagnostic";
114+
TVM_DECLARE_FINAL_OBJECT_INFO(DiagnosticNode, Object);
115+
};
116+
117+
class Diagnostic : public ObjectRef {
118+
public:
119+
TVM_DLL Diagnostic(DiagnosticLevel level, Span span, const std::string& message);
120+
121+
static DiagnosticBuilder Bug(Span span);
122+
static DiagnosticBuilder Error(Span span);
123+
static DiagnosticBuilder Warning(Span span);
124+
static DiagnosticBuilder Note(Span span);
125+
static DiagnosticBuilder Help(Span span);
126+
127+
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Diagnostic, ObjectRef, DiagnosticNode);
128+
};
129+
130+
/*!
131+
* \brief A wrapper around std::stringstream to build a diagnostic.
132+
*/
133+
class DiagnosticBuilder {
134+
public:
135+
/*! \brief The level. */
136+
DiagnosticLevel level;
137+
138+
/*! \brief The source name. */
139+
SourceName source_name;
140+
141+
/*! \brief The span of the diagnostic. */
142+
Span span;
143+
144+
template <typename T>
145+
DiagnosticBuilder& operator<<(const T& val) { // NOLINT(*)
146+
stream_ << val;
147+
return *this;
148+
}
149+
150+
DiagnosticBuilder() : level(DiagnosticLevel::kError), source_name(), span(Span()) {}
151+
152+
DiagnosticBuilder(const DiagnosticBuilder& builder)
153+
: level(builder.level), source_name(builder.source_name), span(builder.span) {}
154+
155+
DiagnosticBuilder(DiagnosticLevel level, Span span) : level(level), span(span) {}
156+
157+
operator Diagnostic() { return Diagnostic(this->level, this->span, this->stream_.str()); }
158+
159+
private:
160+
std::stringstream stream_;
161+
friend class Diagnostic;
162+
};
163+
164+
/*!
165+
* \brief A diagnostic context for recording errors against a source file.
166+
*/
167+
class DiagnosticContext;
168+
169+
/*! \brief Display diagnostics in a given display format.
170+
*
171+
* A diagnostic renderer is responsible for converting the
172+
* raw diagnostics into consumable output.
173+
*
174+
* For example the terminal renderer will render a sequence
175+
* of compiler diagnostics to std::out and std::err in
176+
* a human readable form.
177+
*/
178+
class DiagnosticRendererNode : public Object {
179+
public:
180+
TypedPackedFunc<void(DiagnosticContext ctx)> renderer;
181+
182+
// override attr visitor
183+
void VisitAttrs(AttrVisitor* v) {}
184+
185+
static constexpr const char* _type_key = "DiagnosticRenderer";
186+
TVM_DECLARE_FINAL_OBJECT_INFO(DiagnosticRendererNode, Object);
187+
};
188+
189+
class DiagnosticRenderer : public ObjectRef {
190+
public:
191+
TVM_DLL DiagnosticRenderer(TypedPackedFunc<void(DiagnosticContext ctx)> render);
192+
TVM_DLL DiagnosticRenderer()
193+
: DiagnosticRenderer(TypedPackedFunc<void(DiagnosticContext ctx)>()) {}
194+
195+
void Render(const DiagnosticContext& ctx);
196+
197+
DiagnosticRendererNode* operator->() {
198+
CHECK(get() != nullptr);
199+
return static_cast<DiagnosticRendererNode*>(get_mutable());
200+
}
201+
202+
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(DiagnosticRenderer, ObjectRef, DiagnosticRendererNode);
203+
};
204+
205+
class DiagnosticContextNode : public Object {
206+
public:
207+
/*! \brief The Module to report against. */
208+
IRModule module;
209+
210+
/*! \brief The set of diagnostics to report. */
211+
Array<Diagnostic> diagnostics;
212+
213+
/*! \brief The renderer set for the context. */
214+
DiagnosticRenderer renderer;
215+
216+
void VisitAttrs(AttrVisitor* v) {
217+
v->Visit("module", &module);
218+
v->Visit("diagnostics", &diagnostics);
219+
}
220+
221+
bool SEqualReduce(const DiagnosticContextNode* other, SEqualReducer equal) const {
222+
return equal(module, other->module) && equal(diagnostics, other->diagnostics);
223+
}
224+
225+
static constexpr const char* _type_key = "DiagnosticContext";
226+
TVM_DECLARE_FINAL_OBJECT_INFO(DiagnosticContextNode, Object);
227+
};
228+
229+
class DiagnosticContext : public ObjectRef {
230+
public:
231+
TVM_DLL DiagnosticContext(const IRModule& module, const DiagnosticRenderer& renderer);
232+
TVM_DLL static DiagnosticContext Default(const IRModule& source_map);
233+
234+
/*! \brief Emit a diagnostic.
235+
* \param diagnostic The diagnostic to emit.
236+
*/
237+
void Emit(const Diagnostic& diagnostic);
238+
239+
/*! \brief Emit a diagnostic and then immediately attempt to render all errors.
240+
*
241+
* \param diagnostic The diagnostic to emit.
242+
*
243+
* Note: this will raise an exception if you would like to instead continue execution
244+
* use the Emit method instead.
245+
*/
246+
void EmitFatal(const Diagnostic& diagnostic);
247+
248+
/*! \brief Render the errors and raise a DiagnosticError exception. */
249+
void Render();
250+
251+
DiagnosticContextNode* operator->() {
252+
CHECK(get() != nullptr);
253+
return static_cast<DiagnosticContextNode*>(get_mutable());
254+
}
255+
256+
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(DiagnosticContext, ObjectRef, DiagnosticContextNode);
257+
};
258+
259+
DiagnosticRenderer TerminalRenderer(std::ostream& ostream);
260+
261+
} // namespace tvm
262+
#endif // TVM_IR_DIAGNOSTIC_H_

include/tvm/ir/module.h

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include <tvm/ir/function.h>
3030
#include <tvm/ir/type.h>
3131
#include <tvm/node/container.h>
32+
#include <tvm/parser/source_map.h>
3233

3334
#include <string>
3435
#include <unordered_map>
@@ -53,14 +54,17 @@ class IRModuleNode : public Object {
5354
Map<GlobalVar, BaseFunc> functions;
5455
/*! \brief A map from global type vars to ADT type data. */
5556
Map<GlobalTypeVar, TypeData> type_definitions;
57+
/*! \brief The source map for the module. */
58+
parser::SourceMap source_map;
5659

57-
IRModuleNode() {}
60+
IRModuleNode() : source_map() {}
5861

5962
void VisitAttrs(AttrVisitor* v) {
6063
v->Visit("functions", &functions);
6164
v->Visit("type_definitions", &type_definitions);
6265
v->Visit("global_var_map_", &global_var_map_);
6366
v->Visit("global_type_var_map_", &global_type_var_map_);
67+
v->Visit("source_map", &source_map);
6468
}
6569

6670
TVM_DLL bool SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const;
@@ -280,12 +284,14 @@ class IRModule : public ObjectRef {
280284
* \param functions Functions in the module.
281285
* \param type_definitions Type definitions in the module.
282286
* \param import_set Set of imported files in the module
287+
* \param map The module source map.
283288
*/
284289
TVM_DLL explicit IRModule(Map<GlobalVar, BaseFunc> functions,
285290
Map<GlobalTypeVar, TypeData> type_definitions = {},
286-
std::unordered_set<String> import_set = {});
291+
std::unordered_set<String> import_set = {}, parser::SourceMap map = {});
292+
287293
/*! \brief default constructor */
288-
IRModule() : IRModule(Map<GlobalVar, BaseFunc>()) {}
294+
IRModule() : IRModule(Map<GlobalVar, BaseFunc>({})) {}
289295
/*!
290296
* \brief constructor
291297
* \param n The object pointer.

include/tvm/ir/span.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ class Span : public ObjectRef {
114114
TVM_DLL Span(SourceName source_name, int line, int end_line, int column, int end_column);
115115

116116
/*! \brief Merge two spans into one which captures the combined regions. */
117-
TVM_DLL Span Merge(const Span& other);
117+
TVM_DLL Span Merge(const Span& other) const;
118118

119119
TVM_DEFINE_OBJECT_REF_METHODS(Span, ObjectRef, SpanNode);
120120
};

include/tvm/ir/transform.h

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
#ifndef TVM_IR_TRANSFORM_H_
5757
#define TVM_IR_TRANSFORM_H_
5858

59+
#include <tvm/ir/diagnostic.h>
5960
#include <tvm/ir/error.h>
6061
#include <tvm/ir/module.h>
6162
#include <tvm/node/container.h>
@@ -84,23 +85,19 @@ using TraceFunc =
8485
*/
8586
class PassContextNode : public Object {
8687
public:
87-
/*!
88-
* \brief The error reporter used to notify users why an optimization fails.
89-
*/
90-
ErrorReporter err_reporter;
91-
9288
/*! \brief The default optimization level. */
9389
int opt_level{2};
9490

9591
/*! \brief The list of required passes. */
9692
Array<String> required_pass;
9793
/*! \brief The list of disabled passes. */
9894
Array<String> disabled_pass;
99-
/*! \brief Trace function to be invoked before and after each pass. */
100-
TraceFunc trace_func;
101-
95+
/*! \brief The diagnostic context. */
96+
mutable Optional<DiagnosticContext> diag_ctx;
10297
/*! \brief Pass specific configurations. */
10398
Map<String, ObjectRef> config;
99+
/*! \brief Trace function to be invoked before and after each pass. */
100+
TraceFunc trace_func;
104101

105102
PassContextNode() = default;
106103

@@ -139,6 +136,7 @@ class PassContextNode : public Object {
139136
v->Visit("required_pass", &required_pass);
140137
v->Visit("disabled_pass", &disabled_pass);
141138
v->Visit("config", &config);
139+
v->Visit("diag_ctx", &diag_ctx);
142140
}
143141

144142
static constexpr const char* _type_key = "transform.PassContext";

0 commit comments

Comments
 (0)