From 4ee4eb8dd6476a742854576df22140e11fe2a7b8 Mon Sep 17 00:00:00 2001 From: wdx727 Date: Thu, 18 Sep 2025 10:17:45 +0800 Subject: [PATCH] Fix auto deduce of class member --- codon/parser/ast/expr.h | 5 + codon/parser/ast/stmt.h | 1 + codon/parser/visitors/typecheck/class.cpp | 72 +++++++++- codon/parser/visitors/typecheck/typecheck.cpp | 136 ++++++++++++++++++ codon/parser/visitors/typecheck/typecheck.h | 35 +++++ 5 files changed, 244 insertions(+), 5 deletions(-) diff --git a/codon/parser/ast/expr.h b/codon/parser/ast/expr.h index b4ebd738..11c02998 100644 --- a/codon/parser/ast/expr.h +++ b/codon/parser/ast/expr.h @@ -23,6 +23,7 @@ namespace codon::ast { void accept(VISITOR &visitor) override; \ std::string toString(int) const override; \ friend class TypecheckVisitor; \ + friend class AutoDeduceMembersTypecheckVisitor; \ template friend struct CallbackASTVisitor; \ friend struct ReplacingCallbackASTVisitor; \ inline decltype(auto) match_members() const { return std::tie(__VA_ARGS__); } \ @@ -51,6 +52,8 @@ struct Expr : public AcceptorExtend { void setDone() { done = true; } Expr *getOrigExpr() const { return origExpr; } void setOrigExpr(Expr *orig) { origExpr = orig; } + Expr *getTypeExpr() const { return typeExpr; } + void setTypeExpr(Expr *type) { typeExpr = type; } static const char NodeId; SERIALIZE(Expr, BASE(ASTNode), /*type,*/ done, origExpr); @@ -69,6 +72,8 @@ struct Expr : public AcceptorExtend { bool done; /// Original (pre-transformation) expression Expr *origExpr; + /// the expression of type + Expr *typeExpr{nullptr}; }; /// Function signature parameter helper node (name: type = defaultValue). diff --git a/codon/parser/ast/stmt.h b/codon/parser/ast/stmt.h index b8a85de2..58f213f2 100644 --- a/codon/parser/ast/stmt.h +++ b/codon/parser/ast/stmt.h @@ -21,6 +21,7 @@ namespace codon::ast { void accept(VISITOR &visitor) override; \ std::string toString(int) const override; \ friend class TypecheckVisitor; \ + friend class AutoDeduceMembersTypecheckVisitor; \ template friend struct CallbackASTVisitor; \ friend struct ReplacingCallbackASTVisitor; \ inline decltype(auto) match_members() const { return std::tie(__VA_ARGS__); } \ diff --git a/codon/parser/visitors/typecheck/class.cpp b/codon/parser/visitors/typecheck/class.cpp index ab014e03..8282382b 100644 --- a/codon/parser/visitors/typecheck/class.cpp +++ b/codon/parser/visitors/typecheck/class.cpp @@ -479,6 +479,61 @@ std::vector TypecheckVisitor::parseBaseClasses( return asts; } +Expr *TypecheckVisitor::inferMemberType(std::string member, FunctionStmt *f) { + if (f->items.empty()) + return nullptr; + AutoDeduceMembersTypecheckVisitor v(ctx, f->items); + return inferMemberType(f->items[0].name, member, f->getSuite(), v); +} +Expr *TypecheckVisitor::inferMemberType(std::string self, std::string member, + Stmt *stmt, + AutoDeduceMembersTypecheckVisitor &v) { + Expr *rlt = nullptr; + if (auto suite = cast(stmt)) { + for (auto *s : *suite) + if (auto typ = inferMemberType(self, member, s, v)) + rlt = typ; + } else if (auto whileLoop = cast(stmt)) { + if (auto typ = inferMemberType(self, member, whileLoop->getSuite(), v)) + rlt = typ; + if (auto typ = inferMemberType(self, member, whileLoop->getElse(), v)) + rlt = typ; + } else if (auto forLoop = cast(stmt)) { + if (auto typ = inferMemberType(self, member, forLoop->getSuite(), v)) + rlt = typ; + if (auto typ = inferMemberType(self, member, forLoop->getElse(), v)) + rlt = typ; + } else if (auto ifStmt = cast(stmt)) { + if (auto typ = inferMemberType(self, member, ifStmt->getIf(), v)) + rlt = typ; + if (auto typ = inferMemberType(self, member, ifStmt->getElse(), v)) + rlt = typ; + } else if (auto matchStmt = cast(stmt)) { + for (auto &c : matchStmt->items) + if (auto typ = inferMemberType(self, member, c.getSuite(), v)) + rlt = typ; + } else if (auto tryStmt = cast(stmt)) { + if (auto typ = inferMemberType(self, member, tryStmt->getSuite(), v)) + rlt = typ; + if (auto typ = inferMemberType(self, member, tryStmt->getElse(), v)) + rlt = typ; + if (auto typ = inferMemberType(self, member, tryStmt->getFinally(), v)) + rlt = typ; + } else if (auto assignStmt = cast(stmt)) { + auto rhs = clone(assignStmt->getRhs()); + rhs->accept(v); + if (auto lhs = cast(assignStmt->getLhs())) { + if (auto idExpr = cast(lhs->getExpr())) { + if (idExpr->getValue() == self && lhs->getMember() == member) + rlt = rhs->getTypeExpr(); + } + } else if (auto lhs = cast(assignStmt->getLhs())) { + v.addVar(lhs->value, rhs->getTypeExpr()); + } + } + return rlt; +} + /// Find the first __init__ with self parameter and use it to deduce class members. /// Each deduced member will be treated as generic. /// @example @@ -493,13 +548,16 @@ std::vector TypecheckVisitor::parseBaseClasses( /// @return the transformed init and the pointer to the original function. bool TypecheckVisitor::autoDeduceMembers(ClassStmt *stmt, std::vector &args) { std::set members; + std::unordered_map member2type; for (const auto &sp : getClassMethods(stmt->suite)) if (auto f = cast(sp)) { if (f->name == "__init__") if (const auto b = f->getAttribute(Attr::ClassDeduce)) { - for (const auto &m : b->values) + for (const auto &m : b->values) { members.insert(m); + member2type[m] = inferMemberType(m, f); + } } } if (!members.empty()) { @@ -507,10 +565,14 @@ bool TypecheckVisitor::autoDeduceMembers(ClassStmt *stmt, std::vector &ar if (auto aa = stmt->getAttribute(Attr::ClassMagic)) std::erase(aa->values, "init"); for (auto m : members) { - auto genericName = fmt::format("T_{}", m); - args.emplace_back(genericName, N(TYPE_TYPE), N("NoneType"), - Param::Generic); - args.emplace_back(m, N(genericName)); + if (auto typ = member2type[m]) { + args.emplace_back(m, typ); + } else { + auto genericName = fmt::format("T_{}", m); + args.emplace_back(genericName, N(TYPE_TYPE), N("NoneType"), + Param::Generic); + args.emplace_back(m, N(genericName)); + } } return true; } diff --git a/codon/parser/visitors/typecheck/typecheck.cpp b/codon/parser/visitors/typecheck/typecheck.cpp index 4abe828a..52bc3a3f 100644 --- a/codon/parser/visitors/typecheck/typecheck.cpp +++ b/codon/parser/visitors/typecheck/typecheck.cpp @@ -1924,4 +1924,140 @@ ir::PyFunction TypecheckVisitor::cythonizeFunction(const std::string &name) { return {"", ""}; } +void AutoDeduceMembersTypecheckVisitor::visit(BoolExpr *exp) { + exp->setTypeExpr(N("bool")); +} + +void AutoDeduceMembersTypecheckVisitor::visit(IntExpr *exp) { + exp->setTypeExpr(N("int")); +} + +void AutoDeduceMembersTypecheckVisitor::visit(FloatExpr *exp) { + exp->setTypeExpr(N("float")); +} + +void AutoDeduceMembersTypecheckVisitor::visit(StringExpr *exp) { + exp->setTypeExpr(N("str")); +} + +void AutoDeduceMembersTypecheckVisitor::visit(IdExpr *exp) { + auto val = exp->getValue(); + auto it = std::find_if(args.begin(), args.end(), + [val](Param &arg) { return arg.name == val; }); + if (it != args.end()) { + exp->setTypeExpr(it->getType()); + } +} + +void AutoDeduceMembersTypecheckVisitor::visit(TupleExpr *exp) { + std::vector items; + for (auto e : exp->items) { + e->accept(*this); + items.push_back(e->getTypeExpr()); + } + auto tupleExpr = N(items); + auto idExpr = N("Tuple"); + exp->setTypeExpr(N(idExpr, tupleExpr)); +} + +void AutoDeduceMembersTypecheckVisitor::visit(ListExpr *exp) { + if (exp->items.empty()) + return; + exp->items[0]->accept(*this); + auto listExpr = N("List"); + exp->setTypeExpr(N(listExpr, exp->items[0]->getTypeExpr())); +} + +void AutoDeduceMembersTypecheckVisitor::visit(SetExpr *exp) { + if (exp->items.empty()) + return; + exp->items[0]->accept(*this); + auto setExpr = N("set"); + exp->setTypeExpr(N(setExpr, exp->items[0]->getTypeExpr())); +} + +void AutoDeduceMembersTypecheckVisitor::visit(DictExpr *exp) { + std::vector items; + for (auto e : exp->items) { + e->accept(*this); + items.push_back(e->getTypeExpr()); + } + auto tupleExpr = N(items); + auto dictExpr = N("Dict"); + exp->setTypeExpr(N(dictExpr, tupleExpr)); +} + +void AutoDeduceMembersTypecheckVisitor::visit(UnaryExpr *exp) { + if (exp->op == "not") { + exp->setTypeExpr(N("bool")); + } else if (exp->op == "+" || exp->op == "-" || exp->op == "~") { + exp->expr->accept(*this); + exp->setTypeExpr(exp->expr->getTypeExpr()); + } else { + exp->setTypeExpr(N(exp->op)); + } +} + +void AutoDeduceMembersTypecheckVisitor::visit(BinaryExpr *exp) { + exp->lexpr->accept(*this); + exp->rexpr->accept(*this); + exp->setTypeExpr( + mergeTypeExpr(exp->op, exp->lexpr->getTypeExpr(), exp->rexpr->getTypeExpr())); +} + +void AutoDeduceMembersTypecheckVisitor::visit(RangeExpr *exp) { + exp->setTypeExpr(N("range")); +} + +void AutoDeduceMembersTypecheckVisitor::visit(GeneratorExpr *exp) { + if (exp->kind == GeneratorExpr::ListGenerator || + exp->kind == GeneratorExpr::SetGenerator) { + if (auto forExpr = cast(exp->loops)) { + if (auto iter = cast(forExpr->getIter())) { + if (auto idExpr = cast(iter->expr)) { + if (idExpr->value == "range") { + if (forExpr->getSuite()->items.size() == 1) { + if (auto expStm = cast(forExpr->getSuite()->items[0])) { + if (auto varExpr = cast(forExpr->getVar())) { + addVar(varExpr->value, N("int")); + expStm->expr->accept(*this); + auto typ = exp->kind == GeneratorExpr::ListGenerator + ? N("List") + : N("set"); + exp->setTypeExpr(N(typ, expStm->expr->getTypeExpr())); + } + } + } + } + } + } + } + } +} + +Expr *AutoDeduceMembersTypecheckVisitor::mergeTypeExpr(std::string op, Expr *l, + Expr *r) { + if (l == nullptr || r == nullptr) { + return nullptr; + } else if (op == "==" || op == "!=" || op == ">" || op == ">=" || op == "<" || + op == "<=" || op == "is" || op == "is not" || op == "in" || + op == "not in" || op == "and" || op == "or") { + return N("bool"); + } else if (op == "&" || op == "|" || op == "^" || op == "<<" || op == ">>" || + op == "//") { + return N("int"); + } else if (op == "/") { + return N("float"); + } else if (op == "+" || op == "-" || op == "*" || op == "%" || op == "**") { + if (l->toString() == r->toString()) { + return l; + } else if (l->toString() == "'float" || r->toString() == "'float") { + return N("float"); + } else { + return N("int"); + } + } + return l; +} + } // namespace codon::ast diff --git a/codon/parser/visitors/typecheck/typecheck.h b/codon/parser/visitors/typecheck/typecheck.h index cdf6efb9..a9774eb6 100644 --- a/codon/parser/visitors/typecheck/typecheck.h +++ b/codon/parser/visitors/typecheck/typecheck.h @@ -16,6 +16,8 @@ namespace codon::ast { +class AutoDeduceMembersTypecheckVisitor; + /** * Visitor that infers expression types and performs type-guided transformations. * @@ -201,6 +203,8 @@ class TypecheckVisitor : public ReplacingCallbackASTVisitor { const std::string &, const Expr *, types::ClassType *); bool autoDeduceMembers(ClassStmt *, std::vector &); + Expr *inferMemberType(std::string, FunctionStmt *); + Expr *inferMemberType(std::string, std::string, Stmt *, AutoDeduceMembersTypecheckVisitor &); static std::vector getClassMethods(Stmt *s); void transformNestedClasses(const ClassStmt *, std::vector &, std::vector &, std::vector &); @@ -484,4 +488,35 @@ class TypecheckVisitor : public ReplacingCallbackASTVisitor { // types::Type *getType(const std::string &); }; +// A simpler typechecker to infer the member type in advance +// based on the initializing right-hand side values. +// TODO: support method calls. +class AutoDeduceMembersTypecheckVisitor : public ASTVisitor { +public: + AutoDeduceMembersTypecheckVisitor(std::shared_ptr ctx, std::vector &args) + : ctx(ctx), args(args) {} + void addVar(std::string name, Expr *typ) { args.emplace_back(name, typ); } +private: + template Tn *N(Ts &&...args) { + Tn *t = ctx->cache->N(std::forward(args)...); + return t; + } + std::shared_ptr ctx; + std::vector args; + void visit(BoolExpr *) override; + void visit(IntExpr *) override; + void visit(FloatExpr *) override; + void visit(StringExpr *) override; + void visit(IdExpr *) override; + void visit(TupleExpr *) override; + void visit(ListExpr *) override; + void visit(SetExpr *) override; + void visit(DictExpr *) override; + void visit(UnaryExpr *) override; + void visit(BinaryExpr *) override; + void visit(RangeExpr *) override; + void visit(GeneratorExpr *) override; + Expr *mergeTypeExpr(std::string, Expr *, Expr *); +}; + } // namespace codon::ast