diff --git a/CMakeLists.txt b/CMakeLists.txt index c736f177..759a6fc4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -241,6 +241,7 @@ set(CODON_HPPFILES codon/cir/transform/folding/const_prop.h codon/cir/transform/folding/folding.h codon/cir/transform/folding/rule.h + codon/cir/transform/lowering/await.h codon/cir/transform/lowering/imperative.h codon/cir/transform/lowering/pipeline.h codon/cir/transform/manager.h @@ -347,6 +348,7 @@ set(CODON_CPPFILES codon/cir/transform/folding/const_fold.cpp codon/cir/transform/folding/const_prop.cpp codon/cir/transform/folding/folding.cpp + codon/cir/transform/lowering/await.cpp codon/cir/transform/lowering/imperative.cpp codon/cir/transform/lowering/pipeline.cpp codon/cir/transform/manager.cpp diff --git a/codon/cir/analyze/dataflow/capture.cpp b/codon/cir/analyze/dataflow/capture.cpp index db63ec74..f862d1b9 100644 --- a/codon/cir/analyze/dataflow/capture.cpp +++ b/codon/cir/analyze/dataflow/capture.cpp @@ -363,6 +363,8 @@ struct ExtractVars : public util::ConstVisitor { void visit(const FlowInstr *v) override { process(v->getValue()); } + void visit(const CoroHandleInstr *v) override {} + void visit(const dsl::CustomInstr *v) override { // TODO } @@ -649,6 +651,11 @@ struct CaptureTracker : public util::Operator { [&](DerivedSet &dset) { dset.result.returnCaptures = true; }); } + void handle(AwaitInstr *v) override { + forEachDSetOf(v->getValue(), + [&](DerivedSet &dset) { dset.result.returnCaptures = true; }); + } + void handle(ThrowInstr *v) override { forEachDSetOf(v->getValue(), [&](DerivedSet &dset) { dset.setExternCaptured(); }); } diff --git a/codon/cir/analyze/dataflow/cfg.cpp b/codon/cir/analyze/dataflow/cfg.cpp index ac92fd01..56a96ef1 100644 --- a/codon/cir/analyze/dataflow/cfg.cpp +++ b/codon/cir/analyze/dataflow/cfg.cpp @@ -413,6 +413,11 @@ void CFVisitor::visit(const YieldInstr *v) { defaultInsert(v); } +void CFVisitor::visit(const AwaitInstr *v) { + process(v->getValue()); + defaultInsert(v); +} + void 
CFVisitor::visit(const ThrowInstr *v) { if (v->getValue()) process(v->getValue()); @@ -426,6 +431,8 @@ void CFVisitor::visit(const FlowInstr *v) { defaultInsert(v); } +void CFVisitor::visit(const CoroHandleInstr *v) { defaultInsert(v); } + void CFVisitor::visit(const dsl::CustomInstr *v) { v->getCFBuilder()->buildCFNodes(this); } diff --git a/codon/cir/analyze/dataflow/cfg.h b/codon/cir/analyze/dataflow/cfg.h index 84ed2096..4f6de052 100644 --- a/codon/cir/analyze/dataflow/cfg.h +++ b/codon/cir/analyze/dataflow/cfg.h @@ -520,8 +520,10 @@ class CFVisitor : public util::ConstVisitor { void visit(const ContinueInstr *v) override; void visit(const ReturnInstr *v) override; void visit(const YieldInstr *v) override; + void visit(const AwaitInstr *v) override; void visit(const ThrowInstr *v) override; void visit(const FlowInstr *v) override; + void visit(const CoroHandleInstr *v) override; void visit(const dsl::CustomInstr *v) override; template void process(const NodeType *v) { diff --git a/codon/cir/analyze/module/side_effect.cpp b/codon/cir/analyze/module/side_effect.cpp index bbaf71af..eded3cd2 100644 --- a/codon/cir/analyze/module/side_effect.cpp +++ b/codon/cir/analyze/module/side_effect.cpp @@ -332,6 +332,10 @@ struct SideEfectAnalyzer : public util::ConstVisitor { set(v, max(Status::NO_CAPTURE, process(v->getValue()))); } + void visit(const AwaitInstr *v) override { + set(v, max(Status::NO_CAPTURE, process(v->getValue()))); + } + void visit(const ThrowInstr *v) override { process(v->getValue()); set(v, Status::UNKNOWN, Status::NO_CAPTURE); @@ -341,6 +345,8 @@ struct SideEfectAnalyzer : public util::ConstVisitor { set(v, max(process(v->getFlow()), process(v->getValue()))); } + void visit(const CoroHandleInstr *v) override { set(v, Status::PURE); } + void visit(const dsl::CustomInstr *v) override { set(v, v->getSideEffectStatus(/*local=*/true), v->getSideEffectStatus(/*local=*/false)); diff --git a/codon/cir/func.h b/codon/cir/func.h index 6732ca04..a352a295 100644 
--- a/codon/cir/func.h +++ b/codon/cir/func.h @@ -16,6 +16,8 @@ class Func : public AcceptorExtend { std::string unmangledName; /// whether the function is a generator bool generator; + /// whether the function is an async function + bool async; /// Parent type if func is a method, or null if not types::Type *parentType; @@ -36,7 +38,7 @@ class Func : public AcceptorExtend { /// @param name the function's name explicit Func(std::string name = "") : AcceptorExtend(nullptr, true, false, std::move(name)), generator(false), - parentType(nullptr) {} + async(false), parentType(nullptr) {} /// Re-initializes the function with a new type and names. /// @param newType the function's new type @@ -73,6 +75,12 @@ class Func : public AcceptorExtend { /// @param v the new value void setGenerator(bool v = true) { generator = v; } + /// @return true if the function is an async function + bool isAsync() const { return async; } + /// Sets the function's async flag. + /// @param v the new value + void setAsync(bool v = true) { async = v; } + /// @return the variable corresponding to the given argument name /// @param n the argument name Var *getArgVar(const std::string &n); diff --git a/codon/cir/instr.cpp b/codon/cir/instr.cpp index 6c97ecf6..052b2838 100644 --- a/codon/cir/instr.cpp +++ b/codon/cir/instr.cpp @@ -142,6 +142,24 @@ int YieldInInstr::doReplaceUsedType(const std::string &name, types::Type *newTyp return 0; } +const char AwaitInstr::NodeId = 0; + +int AwaitInstr::doReplaceUsedValue(id_t id, Value *newValue) { + if (value->getId() == id) { + value = newValue; + return 1; + } + return 0; +} + +int AwaitInstr::doReplaceUsedType(const std::string &name, types::Type *newType) { + if (type->getName() == name) { + type = newType; + return 1; + } + return 0; +} + const char TernaryInstr::NodeId = 0; int TernaryInstr::doReplaceUsedValue(id_t id, Value *newValue) { @@ -265,5 +283,12 @@ int FlowInstr::doReplaceUsedValue(id_t id, Value *newValue) { return replacements; } +const char 
CoroHandleInstr::NodeId = 0; + +types::Type *CoroHandleInstr::doGetType() const { + auto *M = getModule(); + return M->getPointerType(M->getByteType()); +} + } // namespace ir } // namespace codon diff --git a/codon/cir/instr.h b/codon/cir/instr.h index 2cb258ed..36124d17 100644 --- a/codon/cir/instr.h +++ b/codon/cir/instr.h @@ -331,7 +331,7 @@ class YieldInInstr : public AcceptorExtend { /// @param v the new value void setSuspending(bool v = true) { suspend = v; } - /// Sets the type being inspected + /// Sets the type. /// @param t the new type void setType(types::Type *t) { type = t; } @@ -477,6 +477,7 @@ class ReturnInstr : public AcceptorExtend { int doReplaceUsedValue(id_t id, Value *newValue) override; }; +/// Instr representing a yield statement. class YieldInstr : public AcceptorExtend { private: /// the value @@ -509,6 +510,40 @@ class YieldInstr : public AcceptorExtend { int doReplaceUsedValue(id_t id, Value *newValue) override; }; +/// Instr representing an await statement. +class AwaitInstr : public AcceptorExtend { +private: + /// the value + Value *value; + /// the type of the result + types::Type *type; + +public: + static const char NodeId; + + explicit AwaitInstr(Value *value, types::Type *type, std::string name = "") + : AcceptorExtend(std::move(name)), value(value), type(type) {} + + /// @return the value + Value *getValue() { return value; } + /// @return the value + const Value *getValue() const { return value; } + /// Sets the value. + /// @param v the new value + void setValue(Value *v) { value = v; } + + /// Sets the type. 
+ /// @param t the new type + void setType(types::Type *t) { type = t; } + +protected: + types::Type *doGetType() const override { return type; } + std::vector doGetUsedValues() const override { return {value}; } + std::vector doGetUsedTypes() const override { return {type}; } + int doReplaceUsedValue(id_t id, Value *newValue) override; + int doReplaceUsedType(const std::string &name, types::Type *newType) override; +}; + class ThrowInstr : public AcceptorExtend { private: /// the value @@ -573,5 +608,19 @@ class FlowInstr : public AcceptorExtend { int doReplaceUsedValue(id_t id, Value *newValue) override; }; +/// Instr for obtaining the coroutine handle of the enclosing function. +/// Not emitted by any front-end; instead used internally in IR lowering. +class CoroHandleInstr : public AcceptorExtend { +public: + static const char NodeId; + + /// Constructs a coroutine handle instruction. + /// @param name the name + explicit CoroHandleInstr(std::string name = "") : AcceptorExtend(std::move(name)) {} + +protected: + types::Type *doGetType() const override; +}; + } // namespace ir } // namespace codon diff --git a/codon/cir/llvm/llvisitor.cpp b/codon/cir/llvm/llvisitor.cpp index 84fb6169..678ba64d 100644 --- a/codon/cir/llvm/llvisitor.cpp +++ b/codon/cir/llvm/llvisitor.cpp @@ -647,16 +647,11 @@ llvm::Function *LLVMVisitor::createPyTryCatchWrapper(llvm::Function *func) { auto *unwindType = llvm::StructType::get(B->getInt64Ty()); // header only auto *unwindException = B->CreateExtractValue(caughtResult, 0); auto *unwindExceptionClass = B->CreateLoad( - B->getInt64Ty(), - B->CreateStructGEP( - unwindType, B->CreatePointerCast(unwindException, unwindType->getPointerTo()), - 0)); + B->getInt64Ty(), B->CreateStructGEP(unwindType, unwindException, 0)); unwindException = B->CreateExtractValue(caughtResult, 0); - auto *excType = llvm::StructType::get(getTypeInfoType(), B->getPtrTy()); - auto *excVal = - B->CreatePointerCast(B->CreateConstGEP1_64(B->getInt8Ty(), 
unwindException, - (uint64_t)seq_exc_offset()), - excType->getPointerTo()); + auto *excType = B->getPtrTy(); + auto *excVal = B->CreateConstGEP1_64(B->getInt8Ty(), unwindException, + (uint64_t)seq_exc_offset()); auto *loadedExc = B->CreateLoad(excType, excVal); - auto *objPtr = B->CreateExtractValue(loadedExc, 1); + auto *objPtr = loadedExc; /* exception slot now holds the object pointer itself, not a {typeinfo, ptr} struct */ @@ -1362,8 +1357,7 @@ llvm::FunctionCallee LLVMVisitor::makePersonalityFunc() { } llvm::FunctionCallee LLVMVisitor::makeExcAllocFunc() { - auto f = M->getOrInsertFunction("seq_alloc_exc", B->getPtrTy(), B->getInt32Ty(), - B->getPtrTy()); + auto f = M->getOrInsertFunction("seq_alloc_exc", B->getPtrTy(), B->getPtrTy()); auto *g = cast(f.getCallee()); g->setDoesNotThrow(); return f; @@ -1391,32 +1385,21 @@ llvm::StructType *LLVMVisitor::getPadType() { return llvm::StructType::get(B->getPtrTy(), B->getInt32Ty()); } -llvm::StructType *LLVMVisitor::getExceptionType() { - return llvm::StructType::get(getTypeInfoType(), B->getPtrTy()); -} - namespace { -int typeIdxLookup(const std::string &name) { - static std::unordered_map cache; - static int next = 1000; - if (name.empty()) +int typeIdxLookup(types::Type *type) { + if (!type) return 0; - auto it = cache.find(name); - if (it != cache.end()) { - return it->second; - } else { - const int myID = next++; - cache[name] = myID; - return myID; - } + auto *M = type->getModule(); + return M->getCache()->getRealizationId(type->getAstType()->getClass()); } } // namespace -llvm::GlobalVariable *LLVMVisitor::getTypeIdxVar(const std::string &name) { +llvm::GlobalVariable *LLVMVisitor::getTypeIdxVar(types::Type *type) { auto *typeInfoType = getTypeInfoType(); - const std::string typeVarName = "codon.typeidx." + (name.empty() ? "" : name); + const std::string name = type ? type->getName() : ""; + const std::string typeVarName = "codon.typeidx." + (type ?
name : ""); llvm::GlobalVariable *tidx = M->getGlobalVariable(typeVarName); - int idx = typeIdxLookup(name); + int idx = typeIdxLookup(type); if (!tidx) { tidx = new llvm::GlobalVariable( *M, typeInfoType, /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, @@ -1426,13 +1409,7 @@ llvm::GlobalVariable *LLVMVisitor::getTypeIdxVar(const std::string &name) { return tidx; } -llvm::GlobalVariable *LLVMVisitor::getTypeIdxVar(types::Type *catchType) { - return getTypeIdxVar(catchType ? catchType->getName() : ""); -} - -int LLVMVisitor::getTypeIdx(types::Type *catchType) { - return typeIdxLookup(catchType ? catchType->getName() : ""); -} +int LLVMVisitor::getTypeIdx(types::Type *catchType) { return typeIdxLookup(catchType); } llvm::Value *LLVMVisitor::call(llvm::FunctionCallee callee, llvm::ArrayRef args) { @@ -1578,7 +1555,6 @@ void LLVMVisitor::visit(const Module *x) { llvm::Value *elemSize = B->getInt64(M->getDataLayout().getTypeAllocSize(strType)); llvm::Value *allocSize = B->CreateMul(len, elemSize); llvm::Value *ptr = B->CreateCall(allocFunc, allocSize); - ptr = B->CreateBitCast(ptr, strType->getPointerTo()); llvm::Value *arr = llvm::UndefValue::get(arrType); arr = B->CreateInsertValue(arr, len, 0); arr = B->CreateInsertValue(arr, ptr, 1); @@ -1805,7 +1781,6 @@ void LLVMVisitor::visit(const InternalFunc *x) { B->getInt64(M->getDataLayout().getTypeAllocSize(llvmBaseType)); llvm::Value *allocSize = B->CreateMul(elemSize, args[0]); result = B->CreateCall(allocFunc, allocSize); - result = B->CreateBitCast(result, llvmBaseType->getPointerTo()); } else if (internalFuncMatches("__new__", x)) { @@ -1838,7 +1813,7 @@ void LLVMVisitor::visit(const InternalFunc *x) { B->getInt32(M->getDataLayout().getPrefTypeAlign(baseType).value()); llvm::Value *from = B->getFalse(); llvm::Value *ptr = B->CreateCall(coroPromise, {args[0], aln, from}); - result = B->CreateBitCast(ptr, baseType->getPointerTo()); + result = ptr; } } @@ -2051,8 +2026,9 @@ void LLVMVisitor::visit(const 
BodiedFunc *x) { } auto *startBlock = llvm::BasicBlock::Create(*context, "start", func); + const bool generator = x->isGenerator() || x->isAsync(); - if (x->isGenerator()) { + if (generator) { func->setPresplitCoroutine(); auto *generatorType = cast(returnType); seqassertn(generatorType, "{} is not a generator type", *returnType); @@ -2082,8 +2058,7 @@ void LLVMVisitor::visit(const BodiedFunc *x) { if (!cast(generatorType->getBase())) { coro.promise = B->CreateAlloca(getLLVMType(generatorType->getBase())); coro.promise->setName("coro.promise"); - llvm::Value *promiseRaw = B->CreateBitCast(coro.promise, B->getPtrTy()); - id = B->CreateCall(coroId, {B->getInt32(0), promiseRaw, nullPtr, nullPtr}); + id = B->CreateCall(coroId, {B->getInt32(0), coro.promise, nullPtr, nullPtr}); } else { id = B->CreateCall(coroId, {B->getInt32(0), nullPtr, nullPtr, nullPtr}); } @@ -2140,7 +2115,7 @@ void LLVMVisitor::visit(const BodiedFunc *x) { process(x->getBody()); B->SetInsertPoint(block); - if (x->isGenerator()) { + if (generator) { B->CreateBr(coro.exit); } else { if (cast(returnType)) { @@ -2251,12 +2226,7 @@ llvm::Type *LLVMVisitor::getLLVMType(types::Type *t) { } if (auto *x = cast(t)) { - auto *p = B->getPtrTy(); - if (x->isPolymorphic()) { - return llvm::StructType::get(*context, {p, p}); - } else { - return p; - } + return B->getPtrTy(); } if (auto *x = cast(t)) { @@ -2425,37 +2395,37 @@ llvm::DIType *LLVMVisitor::getDITypeHelper( if (auto *x = cast(t)) { auto *ref = db.builder->createReferenceType( llvm::dwarf::DW_TAG_reference_type, getDITypeHelper(x->getContents(), cache)); - if (x->isPolymorphic()) { - auto *p = B->getPtrTy(); - auto pointerSizeInBits = layout.getTypeAllocSizeInBits(p); - auto *rtti = db.builder->createBasicType("rtti", pointerSizeInBits, - llvm::dwarf::DW_ATE_address); - auto *structType = llvm::StructType::get(p, p); - auto *structLayout = layout.getStructLayout(structType); - auto *srcInfo = getSrcInfo(x); - llvm::DIFile *file = 
db.getFile(srcInfo->file); - std::vector members; - - llvm::DICompositeType *diType = db.builder->createStructType( - file, x->getName(), file, srcInfo->line, structLayout->getSizeInBits(), - /*AlignInBits=*/0, llvm::DINode::FlagZero, /*DerivedFrom=*/nullptr, - db.builder->getOrCreateArray(members)); - - members.push_back(db.builder->createMemberType( - diType, "data", file, srcInfo->line, pointerSizeInBits, - /*AlignInBits=*/0, structLayout->getElementOffsetInBits(0), - llvm::DINode::FlagZero, ref)); - - members.push_back(db.builder->createMemberType( - diType, "rtti", file, srcInfo->line, pointerSizeInBits, - /*AlignInBits=*/0, structLayout->getElementOffsetInBits(1), - llvm::DINode::FlagZero, rtti)); - - db.builder->replaceArrays(diType, db.builder->getOrCreateArray(members)); - return diType; - } else { - return ref; - } + // if (x->isPolymorphic()) { + // auto *p = B->getPtrTy(); + // auto pointerSizeInBits = layout.getTypeAllocSizeInBits(p); + // auto *rtti = db.builder->createBasicType("rtti", pointerSizeInBits, + // llvm::dwarf::DW_ATE_address); + // auto *structType = llvm::StructType::get(p, p); + // auto *structLayout = layout.getStructLayout(structType); + // auto *srcInfo = getSrcInfo(x); + // llvm::DIFile *file = db.getFile(srcInfo->file); + // std::vector members; + + // llvm::DICompositeType *diType = db.builder->createStructType( + // file, x->getName(), file, srcInfo->line, structLayout->getSizeInBits(), + // /*AlignInBits=*/0, llvm::DINode::FlagZero, /*DerivedFrom=*/nullptr, + // db.builder->getOrCreateArray(members)); + + // members.push_back(db.builder->createMemberType( + // diType, "data", file, srcInfo->line, pointerSizeInBits, + // /*AlignInBits=*/0, structLayout->getElementOffsetInBits(0), + // llvm::DINode::FlagZero, ref)); + + // members.push_back(db.builder->createMemberType( + // diType, "rtti", file, srcInfo->line, pointerSizeInBits, + // /*AlignInBits=*/0, structLayout->getElementOffsetInBits(1), + // llvm::DINode::FlagZero, rtti)); 
+ + // db.builder->replaceArrays(diType, db.builder->getOrCreateArray(members)); + // return diType; + // } else { + return ref; + // } } if (auto *x = cast(t)) { @@ -2810,9 +2780,9 @@ void LLVMVisitor::visit(const TryCatchFlow *x) { auto *excStateContinue = B->getInt8(TryCatchData::State::CONTINUE); auto *excStateRethrow = B->getInt8(TryCatchData::State::RETHROW); - llvm::StructType *padType = getPadType(); - llvm::StructType *unwindType = llvm::StructType::get(B->getInt64Ty()); // header only - llvm::StructType *excType = llvm::StructType::get(getTypeInfoType(), B->getPtrTy()); + auto *padType = getPadType(); + auto *unwindType = llvm::StructType::get(B->getInt64Ty()); // header only + auto *excType = B->getPtrTy(); if (isRoot) { tc.excFlag = B->CreateAlloca(B->getInt8Ty()); @@ -3061,10 +3031,7 @@ void LLVMVisitor::visit(const TryCatchFlow *x) { B->CreateStore(depthMax, tc.delegateDepth); llvm::Value *unwindExceptionClass = B->CreateLoad( - B->getInt64Ty(), - B->CreateStructGEP( - unwindType, B->CreatePointerCast(unwindException, unwindType->getPointerTo()), - 0)); + B->getInt64Ty(), B->CreateStructGEP(unwindType, unwindException, 0)); // check for foreign exceptions B->CreateCondBr( @@ -3078,15 +3045,9 @@ void LLVMVisitor::visit(const TryCatchFlow *x) { // reroute Codon exceptions B->SetInsertPoint(tc.exceptionRouteBlock); unwindException = B->CreateExtractValue(B->CreateLoad(padType, tc.catchStore), 0); - llvm::Value *excVal = - B->CreatePointerCast(B->CreateConstGEP1_64(B->getInt8Ty(), unwindException, - (uint64_t)seq_exc_offset()), - excType->getPointerTo()); - + llvm::Value *excVal = B->CreateConstGEP1_64(B->getInt8Ty(), unwindException, + (uint64_t)seq_exc_offset()); llvm::Value *loadedExc = B->CreateLoad(excType, excVal); - llvm::Value *objType = B->CreateExtractValue(loadedExc, 0); - objType = B->CreateExtractValue(objType, 0); - llvm::Value *objPtr = B->CreateExtractValue(loadedExc, 1); // set depth when catch-all entered auto *defaultRouteBlock = 
llvm::BasicBlock::Create(*context, "trycatch.fdepth", func); @@ -3097,6 +3058,7 @@ void LLVMVisitor::visit(const TryCatchFlow *x) { : tc.finallyBlock); B->SetInsertPoint(tc.exceptionRouteBlock); + auto *objType = B->CreateExtractValue(B->CreateLoad(padType, tc.catchStore), 1); llvm::SwitchInst *switchToCatchBlock = B->CreateSwitch(objType, defaultRouteBlock, (unsigned)handlersFull.size()); for (unsigned i = 0; i < handlersFull.size(); i++) { @@ -3118,15 +3080,14 @@ void LLVMVisitor::visit(const TryCatchFlow *x) { const Var *var = catches[i]->getVar(); if (var) { - llvm::Value *obj = B->CreateBitCast(objPtr, getLLVMType(catches[i]->getType())); llvm::Value *varPtr = getVar(var); seqassertn(varPtr, "could not get catch var"); - B->CreateStore(obj, varPtr); + B->CreateStore(loadedExc, varPtr); } B->CreateStore(excStateCaught, tc.excFlag); CatchData cd; - cd.exception = objPtr; + cd.exception = loadedExc; cd.typeId = objType; enterCatch(cd); process(catches[i]->getHandler()); @@ -3227,7 +3188,6 @@ void LLVMVisitor::codegenPipeline( B->getInt32(M->getDataLayout().getPrefTypeAlign(baseType).value()); llvm::Value *from = B->getFalse(); llvm::Value *promise = B->CreateCall(coroPromise, {iter, alignment, from}); - promise = B->CreateBitCast(promise, baseType->getPointerTo()); value = B->CreateLoad(baseType, promise); block = bodyBlock; @@ -3284,9 +3244,10 @@ void LLVMVisitor::visit(const ExtractInstr *x) { process(x->getVal()); B->SetInsertPoint(block); if (auto *refType = cast(memberedType)) { - if (refType->isPolymorphic()) - value = - B->CreateExtractValue(value, 0); // polymorphic ref type is tuple (data, rtti) + if (refType->isPolymorphic()) { + // polymorphic ref type is ref to (data, rtti) + value = B->CreateLoad(B->getPtrTy(), value); + } value = B->CreateLoad(getLLVMType(refType->getContents()), value); } value = B->CreateExtractValue(value, index); @@ -3304,8 +3265,10 @@ void LLVMVisitor::visit(const InsertInstr *x) { llvm::Value *rhs = value; 
B->SetInsertPoint(block); - if (refType->isPolymorphic()) - lhs = B->CreateExtractValue(lhs, 0); // polymorphic ref type is tuple (data, rtti) + if (refType->isPolymorphic()) { + // polymorphic ref type is ref to (data, rtti) + lhs = B->CreateLoad(B->getPtrTy(), lhs); + } llvm::Value *load = B->CreateLoad(getLLVMType(refType->getContents()), lhs); load = B->CreateInsertValue(load, rhs, index); B->CreateStore(load, lhs); @@ -3511,6 +3474,10 @@ void LLVMVisitor::visit(const YieldInstr *x) { } } +void LLVMVisitor::visit(const AwaitInstr *x) { + seqassertn(false, "await instruction not lowered"); +} + void LLVMVisitor::visit(const ThrowInstr *x) { if (DisableExceptions) { B->SetInsertPoint(block); @@ -3536,7 +3503,7 @@ void LLVMVisitor::visit(const ThrowInstr *x) { } B->SetInsertPoint(block); - llvm::Value *exc = B->CreateCall(excAllocFunc, {typ, obj}); + llvm::Value *exc = B->CreateCall(excAllocFunc, {obj}); call(throwFunc, exc); } @@ -3545,6 +3512,11 @@ void LLVMVisitor::visit(const FlowInstr *x) { process(x->getValue()); } +void LLVMVisitor::visit(const CoroHandleInstr *x) { + seqassertn(coro.handle, "no coroutine handle"); + value = coro.handle; +} + void LLVMVisitor::visit(const dsl::CustomInstr *x) { B->SetInsertPoint(block); value = x->getBuilder()->buildValue(this); diff --git a/codon/cir/llvm/llvisitor.h b/codon/cir/llvm/llvisitor.h index 875c47c1..c8845534 100644 --- a/codon/cir/llvm/llvisitor.h +++ b/codon/cir/llvm/llvisitor.h @@ -184,7 +184,6 @@ class LLVMVisitor : public util::ConstVisitor { llvm::StructType *getTypeInfoType(); llvm::StructType *getPadType(); llvm::StructType *getExceptionType(); - llvm::GlobalVariable *getTypeIdxVar(const std::string &name); llvm::GlobalVariable *getTypeIdxVar(types::Type *catchType); int getTypeIdx(types::Type *catchType = nullptr); @@ -445,8 +444,10 @@ class LLVMVisitor : public util::ConstVisitor { void visit(const ContinueInstr *) override; void visit(const ReturnInstr *) override; void visit(const YieldInstr *) 
override; + void visit(const AwaitInstr *) override; void visit(const ThrowInstr *) override; void visit(const FlowInstr *) override; + void visit(const CoroHandleInstr *) override; void visit(const dsl::CustomInstr *) override; }; diff --git a/codon/cir/transform/lowering/await.cpp b/codon/cir/transform/lowering/await.cpp new file mode 100644 index 00000000..6d87c1bb --- /dev/null +++ b/codon/cir/transform/lowering/await.cpp @@ -0,0 +1,71 @@ +// Copyright (C) 2022-2025 Exaloop Inc. + +#include "await.h" + +#include + +#include "codon/cir/util/cloning.h" +#include "codon/cir/util/irtools.h" +#include "codon/parser/visitors/typecheck/typecheck.h" + +namespace codon { +namespace ir { +namespace transform { +namespace lowering { +namespace { + +bool isFuture(const types::Type *type) { + return type->getName().rfind("std.asyncio.Future.", 0) == 0; +} + +bool isTask(const types::Type *type) { + return type->getName().rfind("std.asyncio.Task.", 0) == 0; +} + +const types::GeneratorType *isCoroutine(const types::Type *type) { + return cast(type); +} + +} // namespace + +const std::string AwaitLowering::KEY = "core-await-lowering"; + +void AwaitLowering::handle(AwaitInstr *v) { + auto *M = v->getModule(); + auto *value = v->getValue(); + auto *resultType = v->getType(); + auto *valueType = value->getType(); + util::CloneVisitor cv(M); + + if (isFuture(valueType) || isTask(valueType)) { + auto *coro = M->Nr(); + auto *addCallback = M->getOrRealizeMethod(valueType, "_add_done_callback", + {valueType, coro->getType()}); + seqassertn(addCallback, "add-callback method not found"); + auto *getResult = M->getOrRealizeMethod(valueType, "result", {valueType}); + seqassertn(getResult, "get-result method not found"); + + auto *series = M->Nr(); + auto *futureVar = + util::makeVar(cv.clone(value), series, cast(getParentFunc())); + series->push_back(util::call(addCallback, {M->Nr(futureVar), coro})); + series->push_back(M->Nr()); + + auto *replacement = + M->Nr(series, 
util::call(getResult, {M->Nr(futureVar)})); + v->replaceAll(replacement); + } else if (auto *genType = isCoroutine(valueType)) { + auto *var = M->Nr(genType->getBase(), /*global=*/false); + cast(getParentFunc())->push_back(var); + auto *replacement = M->Nr(cv.clone(value), M->Nr(), var); + v->replaceAll(replacement); + } else { + seqassertn(false, "unexpected value type '{}' in await instruction", + valueType->getName()); + } +} + +} // namespace lowering +} // namespace transform +} // namespace ir +} // namespace codon diff --git a/codon/cir/transform/lowering/await.h b/codon/cir/transform/lowering/await.h new file mode 100644 index 00000000..9df98d1e --- /dev/null +++ b/codon/cir/transform/lowering/await.h @@ -0,0 +1,22 @@ +// Copyright (C) 2022-2025 Exaloop Inc. + +#pragma once + +#include "codon/cir/transform/pass.h" + +namespace codon { +namespace ir { +namespace transform { +namespace lowering { + +class AwaitLowering : public OperatorPass { +public: + static const std::string KEY; + std::string getKey() const override { return KEY; } + void handle(AwaitInstr *v) override; +}; + +} // namespace lowering +} // namespace transform +} // namespace ir +} // namespace codon diff --git a/codon/cir/transform/manager.cpp b/codon/cir/transform/manager.cpp index 0734082d..6b7c2649 100644 --- a/codon/cir/transform/manager.cpp +++ b/codon/cir/transform/manager.cpp @@ -12,6 +12,7 @@ #include "codon/cir/analyze/module/global_vars.h" #include "codon/cir/analyze/module/side_effect.h" #include "codon/cir/transform/folding/folding.h" +#include "codon/cir/transform/lowering/await.h" #include "codon/cir/transform/lowering/imperative.h" #include "codon/cir/transform/lowering/pipeline.h" #include "codon/cir/transform/manager.h" @@ -156,6 +157,7 @@ void PassManager::registerStandardPasses(PassManager::Init init) { case Init::DEBUG: { registerPass(std::make_unique()); registerPass(std::make_unique()); + registerPass(std::make_unique()); registerPass(std::make_unique()); break; } @@ 
-207,6 +209,7 @@ void PassManager::registerStandardPasses(PassManager::Init init) { {seKey1, rdKey, cfgKey, globalKey, capKey}); // parallel + registerPass(std::make_unique()); registerPass(std::make_unique(), /*insertBefore=*/"", {}, {cfgKey, globalKey}); diff --git a/codon/cir/util/cloning.cpp b/codon/cir/util/cloning.cpp index 3bcd4bfd..0ebcc96e 100644 --- a/codon/cir/util/cloning.cpp +++ b/codon/cir/util/cloning.cpp @@ -88,6 +88,7 @@ void CloneVisitor::visit(const BodiedFunc *v) { } res->setUnmangledName(v->getUnmangledName()); res->setGenerator(v->isGenerator()); + res->setAsync(v->isAsync()); res->realize(cast(v->getType()), argNames); auto argIt1 = v->arg_begin(); @@ -114,6 +115,7 @@ void CloneVisitor::visit(const ExternalFunc *v) { argNames.push_back((*it)->getName()); res->setUnmangledName(v->getUnmangledName()); res->setGenerator(v->isGenerator()); + res->setAsync(v->isAsync()); res->realize(cast(v->getType()), argNames); auto argIt1 = v->arg_begin(); @@ -134,6 +136,7 @@ void CloneVisitor::visit(const InternalFunc *v) { argNames.push_back((*it)->getName()); res->setUnmangledName(v->getUnmangledName()); res->setGenerator(v->isGenerator()); + res->setAsync(v->isAsync()); res->realize(cast(v->getType()), argNames); auto argIt1 = v->arg_begin(); @@ -155,6 +158,7 @@ void CloneVisitor::visit(const LLVMFunc *v) { argNames.push_back((*it)->getName()); res->setUnmangledName(v->getUnmangledName()); res->setGenerator(v->isGenerator()); + res->setAsync(v->isAsync()); res->realize(cast(v->getType()), argNames); auto argIt1 = v->arg_begin(); @@ -320,12 +324,18 @@ void CloneVisitor::visit(const YieldInstr *v) { result = Nt(v, clone(v->getValue()), v->isFinal()); } +void CloneVisitor::visit(const AwaitInstr *v) { + result = Nt(v, clone(v->getValue()), v->getType()); +} + void CloneVisitor::visit(const ThrowInstr *v) { result = Nt(v, clone(v->getValue())); } void CloneVisitor::visit(const FlowInstr *v) { result = Nt(v, clone(v->getFlow()), clone(v->getValue())); } +void 
CloneVisitor::visit(const CoroHandleInstr *v) { result = Nt(v); } + void CloneVisitor::visit(const dsl::CustomInstr *v) { result = v->doClone(*this); } } // namespace util diff --git a/codon/cir/util/cloning.h b/codon/cir/util/cloning.h index 2dd0e2c0..27110d6b 100644 --- a/codon/cir/util/cloning.h +++ b/codon/cir/util/cloning.h @@ -68,8 +68,10 @@ class CloneVisitor : public ConstVisitor { void visit(const ContinueInstr *v) override; void visit(const ReturnInstr *v) override; void visit(const YieldInstr *v) override; + void visit(const AwaitInstr *v) override; void visit(const ThrowInstr *v) override; void visit(const FlowInstr *v) override; + void visit(const CoroHandleInstr *v) override; void visit(const dsl::CustomInstr *v) override; /// Clones a value, returning the previous value if other has already been cloned. diff --git a/codon/cir/util/format.cpp b/codon/cir/util/format.cpp index 91ba5f36..1e856c6e 100644 --- a/codon/cir/util/format.cpp +++ b/codon/cir/util/format.cpp @@ -278,6 +278,10 @@ class FormatVisitor : util::ConstVisitor { void visit(const YieldInstr *v) override { fmt::print(os, FMT_STRING("(yield {})"), makeFormatter(v->getValue())); } + void visit(const AwaitInstr *v) override { + fmt::print(os, FMT_STRING("(await {} {})"), makeFormatter(v->getType()), + makeFormatter(v->getValue())); + } void visit(const ThrowInstr *v) override { fmt::print(os, FMT_STRING("(throw {})"), makeFormatter(v->getValue())); } @@ -285,6 +289,7 @@ class FormatVisitor : util::ConstVisitor { fmt::print(os, FMT_STRING("(flow {} {})"), makeFormatter(v->getFlow()), makeFormatter(v->getValue())); } + void visit(const CoroHandleInstr *v) override { os << "(coro_handle)"; } void visit(const dsl::CustomInstr *v) override { v->doFormat(os); } void visit(const types::IntType *v) override { diff --git a/codon/cir/util/matching.cpp b/codon/cir/util/matching.cpp index f8f150f0..dd62f415 100644 --- a/codon/cir/util/matching.cpp +++ b/codon/cir/util/matching.cpp @@ -218,6 +218,11 @@ 
class MatchVisitor : public util::ConstVisitor { void handle(const YieldInstr *x, const YieldInstr *y) { result = process(x->getValue(), y->getValue()); } + VISIT(AwaitInstr); + void handle(const AwaitInstr *x, const AwaitInstr *y) { + result = + process(x->getType(), y->getType()) && process(x->getValue(), y->getValue()); + } VISIT(ThrowInstr); void handle(const ThrowInstr *x, const ThrowInstr *y) { result = process(x->getValue(), y->getValue()); @@ -227,6 +232,8 @@ class MatchVisitor : public util::ConstVisitor { result = process(x->getFlow(), y->getFlow()) && process(x->getValue(), y->getValue()); } + VISIT(CoroHandleInstr); + void handle(const CoroHandleInstr *x, const CoroHandleInstr *y) { result = true; } VISIT(dsl::CustomInstr); void handle(const dsl::CustomInstr *x, const dsl::CustomInstr *y) { result = x->match(y); diff --git a/codon/cir/util/operator.h b/codon/cir/util/operator.h index 2d3231a3..f53d3d61 100644 --- a/codon/cir/util/operator.h +++ b/codon/cir/util/operator.h @@ -117,8 +117,10 @@ class Operator : public Visitor { LAMBDA_VISIT(ContinueInstr); LAMBDA_VISIT(ReturnInstr); LAMBDA_VISIT(YieldInstr); + LAMBDA_VISIT(AwaitInstr); LAMBDA_VISIT(ThrowInstr); LAMBDA_VISIT(FlowInstr); + LAMBDA_VISIT(CoroHandleInstr); LAMBDA_VISIT(dsl::CustomInstr); template void process(Node *v) { v->accept(*this); } diff --git a/codon/cir/util/outlining.cpp b/codon/cir/util/outlining.cpp index 5ee05db9..e6d40aa2 100644 --- a/codon/cir/util/outlining.cpp +++ b/codon/cir/util/outlining.cpp @@ -176,6 +176,11 @@ struct Outliner : public Operator { invalid = true; } + void handle(AwaitInstr *v) override { + if (inRegion) + invalid = true; + } + void handle(YieldInInstr *v) override { if (inRegion) invalid = true; @@ -186,6 +191,11 @@ struct Outliner : public Operator { invalid = true; } + void handle(CoroHandleInstr *v) override { + if (inRegion) + invalid = true; + } + void handle(AssignInstr *v) override { if (inRegion) { auto *var = v->getLhs(); diff --git 
a/codon/cir/util/visitor.cpp b/codon/cir/util/visitor.cpp index 7f80d246..9e6af11f 100644 --- a/codon/cir/util/visitor.cpp +++ b/codon/cir/util/visitor.cpp @@ -46,8 +46,10 @@ void Visitor::visit(ContinueInstr *x) { defaultVisit(x); } void Visitor::visit(ReturnInstr *x) { defaultVisit(x); } void Visitor::visit(TypePropertyInstr *x) { defaultVisit(x); } void Visitor::visit(YieldInstr *x) { defaultVisit(x); } +void Visitor::visit(AwaitInstr *x) { defaultVisit(x); } void Visitor::visit(ThrowInstr *x) { defaultVisit(x); } void Visitor::visit(FlowInstr *x) { defaultVisit(x); } +void Visitor::visit(CoroHandleInstr *x) { defaultVisit(x); } void Visitor::visit(dsl::CustomInstr *x) { defaultVisit(x); } void Visitor::visit(types::Type *x) { defaultVisit(x); } void Visitor::visit(types::PrimitiveType *x) { defaultVisit(x); } @@ -109,8 +111,10 @@ void ConstVisitor::visit(const ContinueInstr *x) { defaultVisit(x); } void ConstVisitor::visit(const ReturnInstr *x) { defaultVisit(x); } void ConstVisitor::visit(const TypePropertyInstr *x) { defaultVisit(x); } void ConstVisitor::visit(const YieldInstr *x) { defaultVisit(x); } +void ConstVisitor::visit(const AwaitInstr *x) { defaultVisit(x); } void ConstVisitor::visit(const ThrowInstr *x) { defaultVisit(x); } void ConstVisitor::visit(const FlowInstr *x) { defaultVisit(x); } +void ConstVisitor::visit(const CoroHandleInstr *x) { defaultVisit(x); } void ConstVisitor::visit(const dsl::CustomInstr *x) { defaultVisit(x); } void ConstVisitor::visit(const types::Type *x) { defaultVisit(x); } void ConstVisitor::visit(const types::PrimitiveType *x) { defaultVisit(x); } diff --git a/codon/cir/util/visitor.h b/codon/cir/util/visitor.h index cc24f450..f8866847 100644 --- a/codon/cir/util/visitor.h +++ b/codon/cir/util/visitor.h @@ -87,8 +87,10 @@ class BreakInstr; class ContinueInstr; class ReturnInstr; class YieldInstr; +class AwaitInstr; class ThrowInstr; class FlowInstr; +class CoroHandleInstr; namespace util { @@ -146,8 +148,10 @@ class 
Visitor { VISIT(ContinueInstr); VISIT(ReturnInstr); VISIT(YieldInstr); + VISIT(AwaitInstr); VISIT(ThrowInstr); VISIT(FlowInstr); + VISIT(CoroHandleInstr); VISIT(dsl::CustomInstr); VISIT(types::Type); @@ -226,8 +230,10 @@ class ConstVisitor { CONST_VISIT(ContinueInstr); CONST_VISIT(ReturnInstr); CONST_VISIT(YieldInstr); + CONST_VISIT(AwaitInstr); CONST_VISIT(ThrowInstr); CONST_VISIT(FlowInstr); + CONST_VISIT(CoroHandleInstr); CONST_VISIT(dsl::CustomInstr); CONST_VISIT(types::Type); diff --git a/codon/compiler/error.h b/codon/compiler/error.h index 318c6143..9269e265 100644 --- a/codon/compiler/error.h +++ b/codon/compiler/error.h @@ -72,7 +72,6 @@ enum Error { FN_OUTSIDE_ERROR, FN_GLOBAL_ASSIGNED, FN_GLOBAL_NOT_FOUND, - FN_NO_DECORATORS, FN_BAD_LLVM, FN_REALIZE_BUILTIN, EXPECTED_LOOP, @@ -366,8 +365,6 @@ template std::string Emsg(Error e, const TA &...args) { fmt::runtime("name '{}' is assigned to before global declaration"), args...); case Error::FN_GLOBAL_NOT_FOUND: return fmt::format(fmt::runtime("no binding for {} '{}' found"), args...); - case Error::FN_NO_DECORATORS: - return fmt::format(fmt::runtime("class methods cannot be decorated"), args...); case Error::FN_BAD_LLVM: return fmt::format("invalid LLVM code"); case Error::FN_REALIZE_BUILTIN: diff --git a/codon/parser/ast/expr.h b/codon/parser/ast/expr.h index b4ebd738..112ee000 100644 --- a/codon/parser/ast/expr.h +++ b/codon/parser/ast/expr.h @@ -375,6 +375,7 @@ struct BinaryExpr : public AcceptorExtend { BinaryExpr(const BinaryExpr &, bool); std::string getOp() const { return op; } + void setOp(const std::string &o) { op = o; } Expr *getLhs() const { return lexpr; } Expr *getRhs() const { return rexpr; } bool isInPlace() const { return inPlace; } diff --git a/codon/parser/cache.cpp b/codon/parser/cache.cpp index c7db48b3..f7d37aa3 100644 --- a/codon/parser/cache.cpp +++ b/codon/parser/cache.cpp @@ -177,6 +177,37 @@ ir::types::Type *Cache::makeUnion(const std::vector &types) { return 
realizeType(tv.getStdLibType("Union")->getClass(), {argType}); } +size_t Cache::getRealizationId(types::ClassType *type) { + auto cv = TypecheckVisitor(typeCtx).getClassRealization(type); + return cv->id; +} + +std::vector Cache::getBaseRealizationIds(types::ClassType *type) { + auto r = TypecheckVisitor(typeCtx).getClassRealization(type); + std::vector baseIds; + for (const auto &t : r->bases) { + baseIds.push_back(getRealizationId(t.get())); + } + return baseIds; +} + +std::vector Cache::getChildRealizationIds(types::ClassType *type) { + auto cv = TypecheckVisitor(typeCtx).getClassRealization(type); + auto parentId = cv->id; + std::vector childIds; + for (const auto &[_, c] : classes) { + for (const auto &[_, r] : c.realizations) { + for (const auto &t : r->bases) { + if (getRealizationId(t.get()) == parentId) { + childIds.push_back(r->id); + break; + } + } + } + } + return childIds; +} + void Cache::parseCode(const std::string &code) { auto nodeOrErr = ast::parseCode(this, "", code, /*startLine=*/0); if (nodeOrErr) diff --git a/codon/parser/cache.h b/codon/parser/cache.h index 8067d18f..c64d706c 100644 --- a/codon/parser/cache.h +++ b/codon/parser/cache.h @@ -164,6 +164,8 @@ struct Cache { std::vector> fields; /// IR type pointer. codon::ir::types::Type *ir = nullptr; + // Bases (in MRO order) + std::vector> bases; /// Realization vtable (for each base class). /// Maps {base, function signature} to {thunk realization, thunk ID}. 
@@ -323,6 +325,10 @@ struct Cache { ir::types::Type *makeFunction(const std::vector &types); ir::types::Type *makeUnion(const std::vector &types); + size_t getRealizationId(types::ClassType *type); + std::vector getBaseRealizationIds(types::ClassType *type); + std::vector getChildRealizationIds(types::ClassType *type); + void parseCode(const std::string &code); static std::vector> diff --git a/codon/parser/common.cpp b/codon/parser/common.cpp index 4980d9fd..c57b323e 100644 --- a/codon/parser/common.cpp +++ b/codon/parser/common.cpp @@ -69,6 +69,9 @@ Filesystem::Filesystem(const std::string &argv0, const std::string &module0) for (auto loci : {"../lib/codon/stdlib", "../stdlib", "stdlib"}) { add_search_path(root / loci); } + for (auto loci : {"../lib/codon/plugins", "../plugins"}) { + add_search_path(root / loci); + } } } @@ -369,15 +372,15 @@ std::shared_ptr getImportFile(Cache *cache, const std::string &what, paths.emplace_back(fs->canonical(path)); } } - if (paths.empty()) { - // Load a plugin maybe - auto path = parentRelativeTo / what; - if (fs->exists(path / "plugin.toml") && - fs->exists(path / "stdlib" / "__init__.codon")) { + + auto checkPlugin = [&paths, &fs, &cache](const std::filesystem::path &path, + const std::string &what) { + if (fs->exists(path / what / "plugin.toml") && + fs->exists(path / what / "stdlib" / what / "__init__.codon")) { bool failed = false; - if (cache->compiler && !cache->compiler->isPluginLoaded(path)) { - LOG_REALIZE("Loading plugin {}", path); - llvm::handleAllErrors(cache->compiler->load(path), + if (cache->compiler && !cache->compiler->isPluginLoaded(path / what)) { + LOG_REALIZE("Loading plugin {}", path / what); + llvm::handleAllErrors(cache->compiler->load(path / what), [&failed](const codon::error::PluginErrorInfo &e) { codon::compilationError(e.getMessage(), /*file=*/"", /*line=*/0, /*col=*/0, @@ -388,8 +391,14 @@ std::shared_ptr getImportFile(Cache *cache, const std::string &what, }); } if (!failed) - 
paths.emplace_back(fs->canonical(path / "stdlib" / "__init__.codon")); + paths.emplace_back( + fs->canonical(path / what / "stdlib" / what / "__init__.codon")); } + }; + + if (paths.empty()) { + // Load a plugin maybe + checkPlugin(parentRelativeTo, what); } for (auto &p : fs->get_stdlib_paths()) { auto path = p / what; @@ -400,6 +409,9 @@ std::shared_ptr getImportFile(Cache *cache, const std::string &what, path = p / what / "__init__.codon"; if (fs->exists(path)) paths.emplace_back(fs->canonical(path)); + + // Load a plugin maybe + checkPlugin(p, what); } if (paths.empty()) diff --git a/codon/parser/visitors/translate/translate.cpp b/codon/parser/visitors/translate/translate.cpp index e5359e05..9f1a96a9 100644 --- a/codon/parser/visitors/translate/translate.cpp +++ b/codon/parser/visitors/translate/translate.cpp @@ -317,7 +317,9 @@ void TranslateVisitor::visit(CallExpr *expr) { arrayType->setAstType(expr->getType()->shared_from_this()); result = make(expr, arrayType, sz); return; - } else if (ei && startswith(ei->getValue(), "__internal__.yield_in_no_suspend")) { + } else if (ei && startswith(ei->getValue(), + getMangledMethod("std.internal.core", "Generator", + "_yield_in_no_suspend"))) { result = make(expr, getType(expr->getType()), false); return; } @@ -460,7 +462,7 @@ void TranslateVisitor::visit(ExprStmt *stmt) { auto ce = cast(stmt->getExpr()); if (ce && ((ei = cast(ce->getExpr()))) && ei->getValue() == - getMangledMethod("std.internal.core", "__internal__", "yield_final")) { + getMangledMethod("std.internal.core", "Generator", "_yield_final")) { result = make(stmt, transform((*ce)[0].value), true); ctx->getBase()->setGenerator(); } else { @@ -661,6 +663,12 @@ void TranslateVisitor::visit(ThrowStmt *stmt) { stmt->getExpr() ? 
transform(stmt->getExpr()) : nullptr); } +void TranslateVisitor::visit(AwaitStmt *stmt) { + auto type = TypecheckVisitor(ctx->cache->typeCtx) + .extractClassGeneric(stmt->getExpr()->getType()); + result = make(stmt, transform(stmt->getExpr()), getType(type)); +} + void TranslateVisitor::visit(FunctionStmt *stmt) { // Process all realizations. transformFunctionRealizations(stmt->getName(), stmt->hasAttribute(Attr::LLVM)); @@ -739,6 +747,8 @@ void TranslateVisitor::transformFunction(const types::FuncType *type, FunctionSt cast(func)->setBody(body); ctx->popBlock(); } + if (ast->isAsync()) + func->setAsync(); } void TranslateVisitor::transformLLVMFunction(types::FuncType *type, FunctionStmt *ast, diff --git a/codon/parser/visitors/translate/translate.h b/codon/parser/visitors/translate/translate.h index 8acc1b65..30a57907 100644 --- a/codon/parser/visitors/translate/translate.h +++ b/codon/parser/visitors/translate/translate.h @@ -58,6 +58,7 @@ class TranslateVisitor : public CallbackASTVisitor { void visit(IfStmt *) override; void visit(TryStmt *) override; void visit(ThrowStmt *) override; + void visit(AwaitStmt *) override; void visit(FunctionStmt *) override; void visit(ClassStmt *) override; void visit(CommentStmt *) override {} diff --git a/codon/parser/visitors/typecheck/access.cpp b/codon/parser/visitors/typecheck/access.cpp index 1df6dfc1..444c545f 100644 --- a/codon/parser/visitors/typecheck/access.cpp +++ b/codon/parser/visitors/typecheck/access.cpp @@ -176,10 +176,27 @@ void TypecheckVisitor::visit(DotExpr *expr) { wrapSide(N(extractType(expr->getExpr())->prettyString())); return; } + // Special case: cls.__mro__ + if (isTypeExpr(expr->getExpr()) && expr->getMember() == "__mro__") { + if (realize(expr->getExpr()->getType())) { + auto t = extractType(expr->getExpr())->getClass(); + std::vector items; + if (auto c = getClass(t)) { + const auto &mros = c->mro; + for (size_t i = 1; i < mros.size(); i++) { + auto mt = instantiateType(mros[i].get(), t); + 
seqassert(mt->canRealize(), "cannot realize {}", mt->debugString(2)); + items.push_back(N(mt->realizedName())); + } + } + resultExpr = wrapSide(N(items)); + } + return; + } if (isTypeExpr(expr->getExpr()) && expr->getMember() == "__repr__") { - resultExpr = transform( - N(N(getMangledFunc("std.internal.internal", "__type_repr__")), - expr->getExpr(), N(EllipsisExpr::PARTIAL))); + resultExpr = transform(N( + N(getMangledFunc("std.internal.types.type", "__type_repr__")), + expr->getExpr(), N(EllipsisExpr::PARTIAL))); return; } // Special case: expr.__is_static__ @@ -261,13 +278,13 @@ void TypecheckVisitor::visit(DotExpr *expr) { id = N(0); slf = expr->getExpr(); } - resultExpr = transform(N( - N(getMangledMethod("std.internal.core", "__internal__", - "class_thunk_dispatch")), - std::vector{ - CallArg{"slf", slf}, CallArg{"cls_id", id}, - CallArg{"F", N(bestMethod->ast->name)}, - CallArg{"", N(EllipsisExpr::PARTIAL)}})); + resultExpr = transform( + N(N(getMangledMethod("std.internal.core", "RTTIType", + "_thunk_dispatch")), + std::vector{ + CallArg{"slf", slf}, CallArg{"cls_id", id}, + CallArg{"F", N(bestMethod->ast->name)}, + CallArg{"", N(EllipsisExpr::PARTIAL)}})); return; } @@ -586,12 +603,12 @@ Expr *TypecheckVisitor::getClassMember(DotExpr *expr) { N(expr->getMember()))); } - // Case: transform `union.m` to `__internal__.get_union_method(union, "m", ...)` + // Case: transform `union.m` to `Union._member(union, "m", ...)` if (typ->getUnion()) { if (!typ->canRealize()) return nullptr; // delay! 
return transform(N( - N(N("__internal__"), "union_member"), + N(N("Union"), "_member"), std::vector{CallArg{"union", expr->getExpr()}, CallArg{"member", N(expr->getMember())}})); } diff --git a/codon/parser/visitors/typecheck/assign.cpp b/codon/parser/visitors/typecheck/assign.cpp index 99e0cb56..63ee3232 100644 --- a/codon/parser/visitors/typecheck/assign.cpp +++ b/codon/parser/visitors/typecheck/assign.cpp @@ -427,8 +427,7 @@ std::pair TypecheckVisitor::transformInplaceUpdate(AssignStmt *stm if (stmt->getLhs()->getType()->is("Capsule")) { return {true, transform(N( - N(N(N("__internal__.capsule_get_ptr"), - stmt->getLhs()), + N(N(N("Capsule._ptr"), stmt->getLhs()), N(0)), stmt->getRhs()))}; } diff --git a/codon/parser/visitors/typecheck/class.cpp b/codon/parser/visitors/typecheck/class.cpp index ab014e03..37e1849c 100644 --- a/codon/parser/visitors/typecheck/class.cpp +++ b/codon/parser/visitors/typecheck/class.cpp @@ -187,11 +187,12 @@ void TypecheckVisitor::visit(ClassStmt *stmt) { val->scope = {0}; registerGlobal(val->canonicalName); auto assign = N( - N(varName), a.getDefault(), + N(varName), transform(a.getDefault()), a.getType() ? 
cast(a.getType())->getIndex() : nullptr); assign->setUpdate(); varStmts.push_back(assign); cls.classVars[a.getName()] = varName; + ctx->add(a.getName(), val); } else if (!stmt->hasAttribute(Attr::Extend)) { std::string varName = a.getName(); args.emplace_back(varName, transformType(clean_clone(a.getType()), true), @@ -293,7 +294,7 @@ void TypecheckVisitor::visit(ClassStmt *stmt) { } // Add class methods - for (const auto &sp : getClassMethods(stmt->getSuite())) + for (const auto &sp : getClassMethods(stmt->getSuite())) { if (auto fp = cast(sp)) { for (auto *&dc : fp->decorators) { // Handle @setter setters @@ -310,6 +311,7 @@ void TypecheckVisitor::visit(ClassStmt *stmt) { sp->setAttribute(Attr::AutoGenerated); fnStmts.emplace_back(transform(sp)); } + } // After popping context block, record types and nested classes will disappear. // Store their references and re-add them to the context after popping @@ -436,6 +438,16 @@ std::vector TypecheckVisitor::parseBaseClasses( typ->hiddenGenerics.push_back(g); for (auto &g : clsTyp->hiddenGenerics) typ->hiddenGenerics.push_back(g); + + // Add class variables + for (auto &[varName, varCanonicalName] : cachedCls->classVars) { + // Handle class variables. 
Transform them later to allow self-references + auto newName = fmt::format("{}.{}", canonicalName, varName); + auto newCanonicalName = ctx->generateCanonicalName(newName); + getClass(typ)->classVars[varName] = varCanonicalName; + ctx->add(newName, ctx->forceFind(varCanonicalName)); + ctx->add(newCanonicalName, ctx->forceFind(varCanonicalName)); + } } // Add normal fields auto cls = getClass(canonicalName); @@ -683,7 +695,7 @@ Stmt *TypecheckVisitor::codegenMagic(const std::string &op, Expr *typExpr, ret = I("str"); stmts.emplace_back(N(N(NS(op), I("self")))); } else if (op == "repr_default") { - // def __repr__(self: T) -> str + // def __repr_default__(self: T) -> str fargs.emplace_back("self", clone(typExpr)); ret = I("str"); stmts.emplace_back(N(N(NS(op), I("self")))); diff --git a/codon/parser/visitors/typecheck/collections.cpp b/codon/parser/visitors/typecheck/collections.cpp index 2f34d4a7..ad84d872 100644 --- a/codon/parser/visitors/typecheck/collections.cpp +++ b/codon/parser/visitors/typecheck/collections.cpp @@ -140,6 +140,13 @@ void TypecheckVisitor::visit(GeneratorExpr *expr) { } else { expr->loops = transform(expr->getFinalSuite()); // assume: internal data will be changed + if (!expr->getFinalExpr()) { + // Case such as (0 for _ in static.range(2)) + // TODO: make this better. + E(Error::CUSTOM, expr, + "generator cannot be compiled. If using static tuple generator, use tuple(...) 
" + "instead."); + } unify(expr->getType(), instantiateType(getStdLibType("Generator"), {expr->getFinalExpr()->getType()})); if (realize(expr->getType())) diff --git a/codon/parser/visitors/typecheck/error.cpp b/codon/parser/visitors/typecheck/error.cpp index 049ff642..e266ca3d 100644 --- a/codon/parser/visitors/typecheck/error.cpp +++ b/codon/parser/visitors/typecheck/error.cpp @@ -125,8 +125,8 @@ void TypecheckVisitor::visit(TryStmt *stmt) { exceptionOK = true; break; } - if (!exceptionOK) - E(Error::CATCH_EXCEPTION_TYPE, c->getException(), t->prettyString()); + // if (!exceptionOK) + // E(Error::CATCH_EXCEPTION_TYPE, c->getException(), t->prettyString()); if (val) unify(val->getType(), extractType(c->getException())); } @@ -175,7 +175,7 @@ void TypecheckVisitor::visit(TryStmt *stmt) { /// Transform `raise` statements. /// @example -/// `raise exc` -> ```raise __internal__.set_header(exc, "fn", "file", line, col)``` +/// `raise exc` -> ```raise BaseException.set_header(exc, "fn", "file", line, col)``` void TypecheckVisitor::visit(ThrowStmt *stmt) { if (!stmt->expr) { stmt->setDone(); @@ -184,17 +184,17 @@ void TypecheckVisitor::visit(ThrowStmt *stmt) { stmt->expr = transform(stmt->getExpr()); if (!match(stmt->getExpr(), - M(M(getMangledMethod("std.internal.core", "__internal__", - "set_header")), + M(M(getMangledMethod("std.internal.types.error", + "BaseException", "_set_header")), M_))) { stmt->expr = transform(N( - N(getMangledMethod("std.internal.core", "__internal__", "set_header")), + N(getMangledMethod("std.internal.types.error", "BaseException", + "_set_header")), stmt->getExpr(), N(ctx->getBase()->name), N(stmt->getSrcInfo().file), N(stmt->getSrcInfo().line), N(stmt->getSrcInfo().col), stmt->getFrom() - ? N(N(N("__internal__"), "class_super"), - stmt->getFrom(), + ? 
N(N(N("Super"), "_super"), stmt->getFrom(), N(getMangledClass("std.internal.types.error", "BaseException"))) : N(N("NoneType")))); diff --git a/codon/parser/visitors/typecheck/function.cpp b/codon/parser/visitors/typecheck/function.cpp index e84ed40c..c26877e0 100644 --- a/codon/parser/visitors/typecheck/function.cpp +++ b/codon/parser/visitors/typecheck/function.cpp @@ -66,6 +66,7 @@ void TypecheckVisitor::visit(ReturnStmt *stmt) { if (!ctx->inFunction()) E(Error::FN_OUTSIDE_ERROR, stmt, "return"); + auto isAsync = ctx->getBase()->func->isAsync(); if (!stmt->expr && ctx->getBase()->func->hasAttribute(Attr::IsGenerator)) { stmt->setDone(); @@ -96,7 +97,13 @@ void TypecheckVisitor::visit(ReturnStmt *stmt) { ->getStatic() ->getNonStaticType() ->shared_from_this()); - unify(ctx->getBase()->returnType.get(), stmt->getExpr()->getType()); + + if (isAsync) { + unify(ctx->getBase()->returnType.get(), + instantiateType(getStdLibType("Coroutine"), {stmt->getExpr()->getType()})); + } else { + unify(ctx->getBase()->returnType.get(), stmt->getExpr()->getType()); + } } // If we are not within conditional block, ignore later statements in this function. @@ -112,11 +119,35 @@ void TypecheckVisitor::visit(ReturnStmt *stmt) { void TypecheckVisitor::visit(YieldStmt *stmt) { if (!ctx->inFunction()) E(Error::FN_OUTSIDE_ERROR, stmt, "yield"); + auto isAsync = ctx->getBase()->func->isAsync(); stmt->expr = transform(stmt->getExpr() ? stmt->getExpr() : N(N("NoneType"))); unify(ctx->getBase()->returnType.get(), - instantiateType(getStdLibType("Generator"), {stmt->getExpr()->getType()})); + instantiateType(getStdLibType(!isAsync ? "Generator" : "AsyncGenerator"), + {stmt->getExpr()->getType()})); + + if (stmt->getExpr()->isDone()) + stmt->setDone(); +} + +/// Typecheck await statements. 
+void TypecheckVisitor::visit(AwaitStmt *stmt) { + if (!ctx->inFunction()) + E(Error::FN_OUTSIDE_ERROR, stmt, "await"); + auto isAsync = ctx->getBase()->func->isAsync(); + if (!isAsync) + E(Error::FN_OUTSIDE_ERROR, stmt, "await"); + + stmt->expr = transform(stmt->getExpr()); + + if (auto c = stmt->getExpr()->getType()->getClass()) { + if (!c->is(getMangledClass("std.internal.core", "Coroutine")) && + !c->is(getMangledClass("std.asyncio", "Future")) && + !c->is(getMangledClass("std.asyncio", "Task"))) { + E(Error::EXPECTED_TYPE, stmt, "awaitable"); + } + } if (stmt->getExpr()->isDone()) stmt->setDone(); @@ -137,9 +168,6 @@ void TypecheckVisitor::visit(GlobalStmt *stmt) { resultStmt = N(); } /// Parse a function stub and create a corresponding generic function type. /// Also realize built-ins and extern C functions. void TypecheckVisitor::visit(FunctionStmt *stmt) { - if (stmt->isAsync()) - E(Error::CUSTOM, stmt, "async not yet supported"); - if (stmt->hasAttribute(Attr::Python)) { // Handle Python block resultStmt = @@ -504,7 +532,8 @@ void TypecheckVisitor::visit(FunctionStmt *stmt) { stmt->setAttribute(Attr::Module, ctx->moduleName.path); // Make function AST and cache it for later realization - auto f = N(canonicalName, ret, args, suite); + auto f = N(canonicalName, ret, args, suite, std::vector{}, + stmt->isAsync()); f->cloneAttributesFrom(stmt); auto &fn = ctx->cache->functions[canonicalName] = Cache::Function{ctx->getModulePath(), @@ -576,16 +605,45 @@ void TypecheckVisitor::visit(FunctionStmt *stmt) { // Parse remaining decorators for (auto i = stmt->decorators.size(); i-- > 0;) { if (stmt->decorators[i]) { - if (isClassMember) - E(Error::FN_NO_DECORATORS, stmt->decorators[i]); // Replace each decorator with `decorator(finalExpr)` in the reverse order finalExpr = N(stmt->decorators[i], finalExpr ? 
finalExpr : N(canonicalName)); } } if (finalExpr) { - resultStmt = N( - f, transform(N(N(stmt->getName()), finalExpr))); + auto a = N(N(stmt->getName()), finalExpr); + if (isClassMember) { // class method decorator + auto nctx = std::make_shared(ctx->cache); + *nctx = *ctx; + nctx->bases.pop_back(); + nctx->bases.erase(nctx->bases.begin() + 1, nctx->bases.end()); // global context + auto tv = TypecheckVisitor(nctx); + + auto defName = ctx->generateCanonicalName(stmt->getName()); + preamble->addStmt( + tv.transform(N(N(defName), nullptr, nullptr))); + registerGlobal(defName); + a->setUpdate(); + + cast(a->getLhs())->value = defName; + std::vector args; + for (auto arg : *stmt) { + if (startswith(arg.name, "**")) + args.push_back(N(N(arg.name))); + else if (startswith(arg.name, "*")) + args.push_back(N(N(arg.name))); + else + args.push_back(N(arg.name)); + } + Stmt *newFunc = N( + stmt->getName(), clone(stmt->getReturn()), clone(stmt->items), + N(N(N(N(defName), args))), + std::vector{}, stmt->isAsync()); + newFunc = transform(newFunc); + resultStmt = N(f, N(transform(a), newFunc)); + } else { + resultStmt = N(f, transform(a)); + } } else { resultStmt = f; } @@ -641,6 +699,25 @@ Stmt *TypecheckVisitor::transformLLVMDefinition(Stmt *codeStmt) { auto m = match(codeStmt, M(MVar(codeExpr))); seqassert(m, "invalid LLVM definition"); auto code = codeExpr->getValue(); + /// Remove docstring (if any) + size_t start = 0; + while (start < code.size() && std::isspace(code[start])) + start++; + if (startswith(code.substr(start), "\"\"\"")) { + start += 3; + bool found = false; + while (start < code.size() - 2) { + if (code[start] == '"' && code[start + 1] == '"' && code[start + 2] == '"') { + found = true; + start += 3; + break; + } + start++; + } + if (found) { + code = code.substr(start); + } + } std::vector items; std::string finalCode; diff --git a/codon/parser/visitors/typecheck/import.cpp b/codon/parser/visitors/typecheck/import.cpp index 591891d2..1da91d15 100644 --- 
a/codon/parser/visitors/typecheck/import.cpp +++ b/codon/parser/visitors/typecheck/import.cpp @@ -352,7 +352,8 @@ Stmt *TypecheckVisitor::transformNewImport(const ImportFile &file) { (ctx->isStdlibLoading || (ctx->isGlobal() && ctx->scope.size() == 1)); auto importVar = import.importVar = getTemporaryVar(fmt::format("import_{}", moduleID)); - LOG_TYPECHECK("[import] initializing {} ({})", importVar, import.loadedAtToplevel); + LOG_REALIZE("[import] initializing {} (location: {}, toplevel: {})", importVar, + file.path, import.loadedAtToplevel); // __name__ = [import name] Stmt *n = nullptr; diff --git a/codon/parser/visitors/typecheck/infer.cpp b/codon/parser/visitors/typecheck/infer.cpp index 7d95418a..0b8d6878 100644 --- a/codon/parser/visitors/typecheck/infer.cpp +++ b/codon/parser/visitors/typecheck/infer.cpp @@ -301,6 +301,13 @@ types::Type *TypecheckVisitor::realizeType(types::ClassType *type) { realization->type = rt; realization->id = ++ctx->cache->classRealizationCnt; + const auto &mros = getClass(realized)->mro; + for (size_t i = 1; i < mros.size(); i++) { + auto mt = instantiateType(mros[i].get(), realized); + seqassert(mt->canRealize(), "cannot realize {}", mt->debugString(2)); + realization->bases.push_back(mt); + } + // Create LLVM stub auto lt = makeIRType(realized); @@ -485,7 +492,10 @@ types::Type *TypecheckVisitor::realizeFunc(types::FuncType *type, bool force) { // Use NoneType as the return type when the return type is not specified and // function has no return statement if (!ast->getReturn() && isUnbound(type->getRetType())) { - unify(type->getRetType(), getStdLibType("NoneType")); + auto rt = getStdLibType("NoneType")->shared_from_this(); + if (ast->isAsync()) + rt = instantiateType(getStdLibType("Coroutine"), {rt.get()}); + unify(type->getRetType(), rt.get()); } } // Realize the return type @@ -507,7 +517,8 @@ types::Type *TypecheckVisitor::realizeFunc(types::FuncType *type, bool force) { args.emplace_back(varName, nullptr, nullptr, 
i.status); } r->ast = - N(r->type->realizedName(), nullptr, args, ctx->getBase()->suite); + N(r->type->realizedName(), nullptr, args, ctx->getBase()->suite, + std::vector{}, ast->isAsync()); r->ast->setSrcInfo(ast->getSrcInfo()); r->ast->cloneAttributesFrom(ast); @@ -609,6 +620,9 @@ ir::types::Type *TypecheckVisitor::makeIRType(types::ClassType *t) { } else if (t->name == "Generator") { seqassert(types.size() == 1, "bad generics/statics"); handle = module->unsafeGetGeneratorType(types[0]); + } else if (t->name == "Coroutine") { + seqassert(types.size() == 1, "bad generics/statics"); + handle = module->unsafeGetGeneratorType(types[0]); } else if (t->name == TYPE_OPTIONAL) { seqassert(types.size() == 1, "bad generics/statics"); handle = module->unsafeGetOptionalType(types[0]); diff --git a/codon/parser/visitors/typecheck/op.cpp b/codon/parser/visitors/typecheck/op.cpp index ad8555eb..1d3aaade 100644 --- a/codon/parser/visitors/typecheck/op.cpp +++ b/codon/parser/visitors/typecheck/op.cpp @@ -751,18 +751,18 @@ Expr *TypecheckVisitor::transformBinarySimple(const BinaryExpr *expr) { N(N(expr->getRhs(), "__bool__")), N(false))); } else { - return transform(N( - N(getMangledMethod("std.internal.core", "__internal__", "and_union")), - expr->getLhs(), expr->getRhs())); + return transform( + N(N(getMangledMethod("std.internal.core", "Union", "_and")), + expr->getLhs(), expr->getRhs())); } } else if (expr->getOp() == "||") { if (ctx->expectedType && ctx->expectedType->is("bool")) { return transform(N(expr->getLhs(), N(true), N(N(expr->getRhs(), "__bool__")))); } else { - return transform(N( - N(getMangledMethod("std.internal.core", "__internal__", "or_union")), - expr->getLhs(), expr->getRhs())); + return transform( + N(N(getMangledMethod("std.internal.core", "Union", "_or")), + expr->getLhs(), expr->getRhs())); } } else if (expr->getOp() == "not in") { return transform(N(N( diff --git a/codon/parser/visitors/typecheck/special.cpp b/codon/parser/visitors/typecheck/special.cpp 
index 3c5439db..48191e5e 100644 --- a/codon/parser/visitors/typecheck/special.cpp +++ b/codon/parser/visitors/typecheck/special.cpp @@ -21,15 +21,15 @@ namespace codon::ast { using namespace types; -/// Generate ASTs for all __internal__ functions that deal with vtable generation. +/// Generate ASTs for all internal functions that deal with vtable generation. /// Intended to be called once the typechecking is done. /// TODO: add JIT compatibility. void TypecheckVisitor::prepareVTables() { - // def class_get_thunk_id(F, T): + // def RTTIType._get_thunk_id(F, T): // return VID - auto fn = getFunction( - getMangledMethod("std.internal.core", "__internal__", "class_get_thunk_id")); + auto fn = + getFunction(getMangledMethod("std.internal.core", "RTTIType", "_get_thunk_id")); auto oldAst = fn->ast; // Keep iterating as thunks can generate more thunks. std::unordered_set cache; @@ -49,17 +49,16 @@ void TypecheckVisitor::prepareVTables() { } fn = getFunction( - getMangledMethod("std.internal.core", "__internal__", "class_populate_vtables")); + getMangledMethod("std.internal.core", "RTTIType", "_populate_vtables")); fn->ast->suite = generateClassPopulateVTablesAST(); auto typ = fn->realizations.begin()->second->getType(); typ->ast = fn->ast; LOG_REALIZE("[poly] {} : {}", typ->debugString(2), fn->ast->toString(2)); realizeFunc(typ, true); - // def class_base_derived_dist(B, D): + // def RTTIType._dist(B, D): // return Tuple[].__elemsize__ - fn = getFunction( - getMangledMethod("std.internal.core", "__internal__", "class_base_derived_dist")); + fn = getFunction(getMangledMethod("std.internal.core", "RTTIType", "_dist")); oldAst = fn->ast; for (const auto &real : fn->realizations | std::views::values) { fn->ast->suite = generateBaseDerivedDistAST(real->getType()); @@ -76,11 +75,10 @@ SuiteStmt *TypecheckVisitor::generateClassPopulateVTablesAST() { for (const auto &[r, real] : cls.realizations) { if (real->vtable.empty()) continue; - // 
__internal__.class_set_rtti_vtable(real.ID, size, real.type) + // RTTIType._init_vtable(size, real.type) suite->addStmt(N( - N(N(N("__internal__"), "class_set_rtti_vtable"), - N(real->id), N(ctx->cache->thunkIds.size() + 2), - N(r)))); + N(N(N("RTTIType"), "_init_vtable"), + N(ctx->cache->thunkIds.size() + 2), N(r)))); LOG_REALIZE("[poly] {} -> {}", r, real->id); for (const auto &[key, fn] : real->vtable) { auto id = in(ctx->cache->thunkIds, key); @@ -97,10 +95,10 @@ SuiteStmt *TypecheckVisitor::generateClassPopulateVTablesAST() { std::vector{N(N(TYPE_TUPLE), ids), N(fn->getRetType()->realizedName())}), N(fn->realizedName())); - suite->addStmt(N(N( - N(N("__internal__"), "class_set_rtti_vtable_fn"), - N(real->id), N(int64_t(*id)), - N(N(fnCall, "__raw__")), N(r)))); + suite->addStmt(N( + N(N(N("RTTIType"), "_set_vtable_fn"), + N(real->id), N(int64_t(*id)), + N(N(fnCall, "__raw__")), N(r)))); } } } @@ -116,12 +114,16 @@ SuiteStmt *TypecheckVisitor::generateBaseDerivedDistAST(FuncType *f) { } } + std::unordered_set alreadyDerived; + for (auto &m : getClass(baseTyp)->mro) + alreadyDerived.insert(m->name); + auto derivedTyp = extractFuncGeneric(f, 1)->getClass(); auto fields = getClassFields(derivedTyp); auto types = std::vector{}; auto found = false; for (auto &fld : fields) { - if (fld.baseClass == baseTyp->name) { + if (in(alreadyDerived, fld.baseClass)) { found = true; break; } else { @@ -165,7 +167,7 @@ FunctionStmt *TypecheckVisitor::generateThunkAST(const FuncType *fp, ClassType * // Thunk contents: // def _thunk...(self, ): // return ( - // __internal__.class_base_to_derived(self, , ), + // RTTIType._to_derived(self, , ), // ) std::vector fnArgs; fnArgs.emplace_back("self", N(base->realizedName()), nullptr); @@ -173,9 +175,9 @@ FunctionStmt *TypecheckVisitor::generateThunkAST(const FuncType *fp, ClassType * fnArgs.emplace_back(getUnmangledName((*fp->ast)[i].getName()), N(args[i]->realizedName()), nullptr); std::vector callArgs; - callArgs.emplace_back(N( - 
N(N("__internal__"), "class_base_to_derived"), N("self"), - N(base->realizedName()), N(derived->realizedName()))); + callArgs.emplace_back(N(N(N("RTTIType"), "_to_derived"), + N("self"), N(base->realizedName()), + N(derived->realizedName()))); for (size_t i = 1; i < args.size(); i++) callArgs.emplace_back(N(getUnmangledName((*fp->ast)[i].getName()))); @@ -187,10 +189,9 @@ FunctionStmt *TypecheckVisitor::generateThunkAST(const FuncType *fp, ClassType * thunkName, nullptr, fnArgs, N( // For debugging - N(N( - N(getMangledMethod("std.internal.core", "__internal__", - "class_thunk_debug")), - debugCallArgs)), + N(N(N(getMangledMethod( + "std.internal.core", "RTTIType", "_thunk_debug")), + debugCallArgs)), N(N(N(m->ast->getName()), callArgs)))); thunkAst->setAttribute(Attr::Inline); return cast(transform(thunkAst)); @@ -296,24 +297,23 @@ SuiteStmt *TypecheckVisitor::generateUnionNewAST(const FuncType *type) { auto unionType = type->funcParent->getUnion(); seqassert(unionType, "expected union, got {}", *(type->funcParent)); - Stmt *suite = N(N( - N(N("__internal__"), "new_union"), - N(type->ast->begin()->name), N(unionType->realizedName()))); + Stmt *suite = N(N(N(N("Union"), "_new"), + N(type->ast->begin()->name), + N(unionType->realizedName()))); return SuiteStmt::wrap(suite); } SuiteStmt *TypecheckVisitor::generateUnionTagAST(FuncType *type) { - // return __internal__.union_get_data(union, T0) + // return Union._get_data(union, T0) auto tag = getIntLiteral(extractFuncGeneric(type)); auto unionType = extractFuncArgType(type)->getUnion(); auto unionTypes = unionType->getRealizationTypes(); if (tag < 0 || tag >= unionTypes.size()) E(Error::CUSTOM, getSrcInfo(), "bad union tag"); auto selfVar = type->ast->begin()->name; - auto suite = N(N( - N(N(getMangledMethod("std.internal.core", "__internal__", - "union_get_data")), - N(selfVar), N(unionTypes[tag]->realizedName())))); + auto suite = N(N(N( + N(getMangledMethod("std.internal.core", "Union", "_get_data")), + N(selfVar), 
N(unionTypes[tag]->realizedName())))); return suite; } @@ -359,11 +359,11 @@ SuiteStmt *TypecheckVisitor::generateSpecialAst(types::FuncType *type) { return generateFunctionCallInternalAST(type); } else if (startswith(ast->name, "Union.__new__")) { return generateUnionNewAST(type); - } else if (startswith(ast->name, getMangledMethod("std.internal.core", "__internal__", - "get_union_tag"))) { + } else if (startswith(ast->name, + getMangledMethod("std.internal.core", "Union", "_tag"))) { return generateUnionTagAST(type); - } else if (startswith(ast->name, getMangledMethod("std.internal.core", "__internal__", - "namedkeys"))) { + } else if (startswith(ast->name, getMangledMethod("std.internal.core", "NamedTuple", + "_namedkeys"))) { return generateNamedKeysAST(type); } else if (startswith(ast->name, getMangledMethod("std.internal.core", "__magic__", "mul"))) { @@ -498,8 +498,8 @@ Expr *TypecheckVisitor::transformSuper() { auto typExpr = N(superTyp->getClass()->name); typExpr->setType(instantiateTypeVar(superTyp->getClass())); - return transform(N(N(N("__internal__"), "class_super"), - self, typExpr, N(1))); + return transform(N(N(N("Super"), "_super"), self, + typExpr, N(1))); } const auto &name = cands.front(); // the first inherited type @@ -517,7 +517,7 @@ Expr *TypecheckVisitor::transformSuper() { e->setType(superTyp->shared_from_this()); return e; } else { - // Case: reference types. Return `__internal__.class_super(self, T)` + // Case: reference types. 
Return `Super._super(self, T)` auto self = N(funcTyp->ast->begin()->name); self->setType(typ->shared_from_this()); return castToSuperClass(self, superTyp->getClass()); @@ -621,10 +621,10 @@ Expr *TypecheckVisitor::transformIsInstance(CallExpr *expr) { } if (tag == -1) return transform(N(false)); - return transform(N( - N(N(N("__internal__"), "union_get_tag"), - expr->begin()->getExpr()), - "==", N(tag))); + return transform( + N(N(N(N("Union"), "_get_tag"), + expr->begin()->getExpr()), + "==", N(tag))); } else if (typExpr->getType()->is("pyobj")) { if (typ->is("pyobj")) { return transform( @@ -1217,22 +1217,40 @@ TypecheckVisitor::populateStaticVarsLoop(Expr *iter, std::vector block; auto typ = extractFuncArgType(fn->getType())->getClass(); size_t idx = 0; - for (auto &f : getClassFields(typ)) { - std::vector stmts; - if (withIdx) { + if (typ->is("TypeWrap")) { // type passed! + for (auto &f : getClass(extractClassGeneric(typ))->classVars) { + std::vector stmts; + if (withIdx) { + stmts.push_back( + N(N(vars[0]), N(idx), + N(N("Literal"), N("int")))); + } + stmts.push_back( + N(N(vars[withIdx]), N(f.first), + N(N("Literal"), N("str")))); + stmts.push_back(N(N(vars[withIdx + 1]), N(f.second))); + auto b = N(stmts); + block.push_back(b); + idx++; + } + } else { + for (auto &f : getClassFields(typ)) { + std::vector stmts; + if (withIdx) { + stmts.push_back( + N(N(vars[0]), N(idx), + N(N("Literal"), N("int")))); + } stmts.push_back( - N(N(vars[0]), N(idx), - N(N("Literal"), N("int")))); + N(N(vars[withIdx]), N(f.name), + N(N("Literal"), N("str")))); + stmts.push_back( + N(N(vars[withIdx + 1]), + N(clone((*cast(iter))[0].value), f.name))); + auto b = N(stmts); + block.push_back(b); + idx++; } - stmts.push_back( - N(N(vars[withIdx]), N(f.name), - N(N("Literal"), N("str")))); - stmts.push_back( - N(N(vars[withIdx + 1]), - N(clone((*cast(iter))[0].value), f.name))); - auto b = N(stmts); - block.push_back(b); - idx++; } return block; } diff --git 
a/codon/parser/visitors/typecheck/typecheck.cpp b/codon/parser/visitors/typecheck/typecheck.cpp index 27acf5b3..df160ad5 100644 --- a/codon/parser/visitors/typecheck/typecheck.cpp +++ b/codon/parser/visitors/typecheck/typecheck.cpp @@ -386,10 +386,6 @@ void TypecheckVisitor::visit(ExprStmt *stmt) { stmt->setDone(); } -void TypecheckVisitor::visit(AwaitStmt *stmt) { - E(Error::CUSTOM, stmt, "await not yet supported"); -} - void TypecheckVisitor::visit(CustomStmt *stmt) { if (stmt->getSuite()) { auto fn = in(ctx->cache->customBlockStmts, stmt->getKeyword()); @@ -581,9 +577,9 @@ std::vector TypecheckVisitor::findMatchingMethods( /// expected `Optional[T]`, got `T` -> `Optional(expr)` /// expected `T`, got `Optional[T]` -> `unwrap(expr)` /// expected `Function`, got a function -> partialize function -/// expected `T`, got `Union[T...]` -> `__internal__.get_union(expr, T)` -/// expected `Union[T...]`, got `T` -> `__internal__.new_union(expr, -/// Union[T...])` expected base class, got derived -> downcast to base class +/// expected `T`, got `Union[T...]` -> `Union._get(expr, T)` +/// expected `Union[T...]`, got `T` -> `Union._new(expr, Union[T...])` +/// expected base class, got derived -> downcast to base class /// @param allowUnwrap allow optional unwrapping. 
bool TypecheckVisitor::wrapExpr(Expr **expr, Type *expectedType, FuncType *callee, bool allowUnwrap) { @@ -643,14 +639,28 @@ TypecheckVisitor::canWrapExpr(Type *exprType, Type *expectedType, FuncType *call exprClass->is("Capsule")) { type = extractClassGeneric(exprClass)->shared_from_this(); fn = [&](Expr *expr) -> Expr * { - return N(N("__internal__.capsule_get"), expr); + return N(N("Capsule._get"), expr); }; } else if (expectedClass && expectedClass->is("Capsule") && exprClass && !exprClass->is("Capsule")) { type = instantiateType(getStdLibType("Capsule"), std::vector{exprClass}); fn = [&](Expr *expr) -> Expr * { - return N(N("__internal__.capsule_make"), expr); + return N(N("Capsule._make"), expr); + }; + } + + else if (expectedClass && !expectedClass->is("Any") && exprClass && + exprClass->is("Any")) { + type = expectedClass->shared_from_this(); + fn = [this, type](Expr *expr) -> Expr * { + auto r = realize(type.get()); + seqassert(r, "not realizable"); + return N(N("Any.unwrap"), expr, N(r->realizedName())); }; + } else if (expectedClass && expectedClass->is("Any") && exprClass && + !exprClass->is("Any")) { + type = expectedClass->shared_from_this(); + fn = [&](Expr *expr) -> Expr * { return N(N("Any"), expr); }; } else if (expectedClass && expectedClass->is("Generator") && @@ -846,7 +856,7 @@ TypecheckVisitor::canWrapExpr(Type *exprType, Type *expectedType, FuncType *call else if (allowUnwrap && exprClass && exprType->getUnion() && expectedClass && !expectedClass->getUnion()) { - // Extract union types via __internal__.get_union + // Extract union types via Union._get if (auto t = realize(expectedClass)) { auto e = realize(exprType); if (!e) @@ -861,9 +871,9 @@ TypecheckVisitor::canWrapExpr(Type *exprType, Type *expectedType, FuncType *call if (ok) { type = t->shared_from_this(); fn = [this, type](Expr *expr) -> Expr * { - return N(N(getMangledMethod("std.internal.core", - "__internal__", "get_union")), - expr, N(type->realizedName())); + return N( + 
N(getMangledMethod("std.internal.core", "Union", "_get")), expr, + N(type->realizedName())); }; } } else { @@ -872,7 +882,7 @@ TypecheckVisitor::canWrapExpr(Type *exprType, Type *expectedType, FuncType *call } else if (exprClass && expectedClass && expectedClass->getUnion()) { - // Make union types via __internal__.new_union + // Make union types via Union._new if (!expectedClass->getUnion()->isSealed()) { if (!expectedClass->getUnion()->addType(exprClass)) E(error::Error::UNION_TOO_BIG, expectedClass->getSrcInfo(), @@ -882,7 +892,7 @@ TypecheckVisitor::canWrapExpr(Type *exprType, Type *expectedType, FuncType *call if (expectedClass->unify(exprClass, nullptr) == -1) { type = t->shared_from_this(); fn = [this, type](Expr *expr) -> Expr * { - return N(N(N("__internal__"), "new_union"), expr, + return N(N(N("Union"), "_new"), expr, N(type->realizedName())); }; } @@ -905,9 +915,8 @@ TypecheckVisitor::canWrapExpr(Type *exprType, Type *expectedType, FuncType *call // Super[T] to T type = extractClassGeneric(exprClass)->shared_from_this(); fn = [this](Expr *expr) -> Expr * { - return N(N(getMangledMethod("std.internal.core", "__internal__", - "class_super_change_rtti")), - expr); + return N( + N(getMangledMethod("std.internal.core", "Super", "_unwrap")), expr); }; } @@ -943,7 +952,7 @@ Expr *TypecheckVisitor::castToSuperClass(Expr *expr, ClassType *superTyp, realize(superTyp); auto typExpr = N(superTyp->realizedName()); return transform( - N(N(N("__internal__"), "class_super"), expr, typExpr)); + N(N(N("Super"), "_super"), expr, typExpr)); } /// Unpack a Tuple or KwTuple expression into (name, type) vector. 
diff --git a/codon/runtime/exc.cpp b/codon/runtime/exc.cpp index 34db8839..2bdab412 100644 --- a/codon/runtime/exc.cpp +++ b/codon/runtime/exc.cpp @@ -133,21 +133,30 @@ template static uintptr_t ReadType(const uint8_t *&p) { } } // namespace -struct OurExceptionType_t { +// Note: this should match Codon definition +struct TypeInfo { + seq_int_t id; + seq_int_t *parent_ids; + seq_str_t raw_name; + // other fields do not need to be included +}; + +struct RTTIObject { + void *data; + TypeInfo *type; +}; + +struct CodonBaseExceptionType { int type; }; -struct OurBaseException_t { - OurExceptionType_t type; // Seq exception type - void *obj; // Seq exception instance +struct CodonBaseException { + void *obj; Backtrace bt; _Unwind_Exception unwindException; }; -typedef struct OurBaseException_t OurException; - -struct SeqExcHeader_t { - seq_str_t type; +struct CodonExceptionHeader { seq_str_t msg; seq_str_t func; seq_str_t file; @@ -172,7 +181,7 @@ void seq_exc_init(int flags) { static void seq_delete_exc(_Unwind_Exception *expToDelete) { if (!expToDelete || expToDelete->exception_class != SEQ_EXCEPTION_CLASS) return; - auto *exc = (OurException *)((char *)expToDelete + seq_exc_offset()); + auto *exc = (CodonBaseException *)((char *)expToDelete + seq_exc_offset()); if (seq_flags & SEQ_FLAG_DEBUG) { exc->bt.free(); } @@ -187,11 +196,10 @@ static void seq_delete_unwind_exc(_Unwind_Reason_Code reason, static struct backtrace_state *state = nullptr; static std::mutex stateLock; -SEQ_FUNC void *seq_alloc_exc(int type, void *obj) { - const size_t size = sizeof(OurException); - auto *e = (OurException *)memset(seq_alloc(size), 0, size); +SEQ_FUNC void *seq_alloc_exc(void *obj) { + const size_t size = sizeof(CodonBaseException); + auto *e = (CodonBaseException *)memset(seq_alloc(size), 0, size); assert(e); - e->type.type = type; e->obj = obj; e->unwindException.exception_class = SEQ_EXCEPTION_CLASS; e->unwindException.exception_cleanup = seq_delete_unwind_exc; @@ -236,11 +244,12 
@@ static void print_from_last_dot(seq_str_t s, std::ostringstream &buf) { static std::function jitErrorCallback; SEQ_FUNC void seq_terminate(void *exc) { - auto *base = (OurBaseException_t *)((char *)exc + seq_exc_offset()); + auto *base = (CodonBaseException *)((char *)exc + seq_exc_offset()); void *obj = base->obj; - auto *hdr = (SeqExcHeader_t *)obj; + auto *hdr = *(CodonExceptionHeader **)obj; + auto tname = ((RTTIObject *)obj)->type->raw_name; - if (std::string(hdr->type.str, hdr->type.len) == "SystemExit") { + if (std::string(tname.str, tname.len) == "SystemExit") { seq_int_t status = *(seq_int_t *)(hdr + 1); exit((int)status); } @@ -250,7 +259,7 @@ SEQ_FUNC void seq_terminate(void *exc) { buf << codon::runtime::getCapturedOutput(); buf << "\033[1m"; - print_from_last_dot(hdr->type, buf); + print_from_last_dot(tname, buf); if (hdr->msg.len > 0) { buf << ": \033[0m"; buf.write(hdr->msg.str, hdr->msg.len); @@ -292,7 +301,7 @@ SEQ_FUNC void seq_terminate(void *exc) { auto *bt = &base->bt; std::string msg(hdr->msg.str, hdr->msg.len); std::string file(hdr->file.str, hdr->file.len); - std::string type(hdr->type.str, hdr->type.len); + std::string type(tname.str, tname.len); std::vector backtrace; if (seq_flags & SEQ_FLAG_DEBUG) { @@ -447,6 +456,21 @@ static uintptr_t readEncodedPointer(const uint8_t **data, uint8_t encoding) { return result; } +static bool isinstance(void *obj, seq_int_t type) { + auto *info = ((RTTIObject *)obj)->type; + if (info->id == type) + return true; + if (info->parent_ids) { + auto *p = info->parent_ids; + while (*p) { + if (*p++ == type) { + return true; + } + } + } + return false; +} + static bool handleActionValue(int64_t *resultAction, uint8_t TTypeEncoding, const uint8_t *ClassInfo, uintptr_t actionEntry, uint64_t exceptionClass, @@ -457,10 +481,7 @@ static bool handleActionValue(int64_t *resultAction, uint8_t TTypeEncoding, return ret; auto *excp = - (struct OurBaseException_t *)(((char *)exceptionObject) + seq_exc_offset()); - 
OurExceptionType_t *excpType = &(excp->type); - seq_int_t type = excpType->type; - + (struct CodonBaseException *)(((char *)exceptionObject) + seq_exc_offset()); const uint8_t *actionPos = (uint8_t *)actionEntry, *tempActionPos; int64_t typeOffset = 0, actionOffset; @@ -480,10 +501,11 @@ static bool handleActionValue(int64_t *resultAction, uint8_t TTypeEncoding, unsigned EncSize = getEncodingSize(TTypeEncoding); const uint8_t *EntryP = ClassInfo - typeOffset * EncSize; uintptr_t P = readEncodedPointer(&EntryP, TTypeEncoding); - auto *ThisClassInfo = reinterpret_cast(P); + auto *ThisClassInfo = reinterpret_cast(P); + auto ThisClassType = ThisClassInfo->type; // type=0 means catch-all - if (ThisClassInfo->type == 0 || ThisClassInfo->type == type) { - *resultAction = i + 1; + if (ThisClassType == 0 || isinstance(excp->obj, ThisClassType)) { + *resultAction = ThisClassType; ret = true; break; } @@ -625,7 +647,7 @@ SEQ_FUNC _Unwind_Reason_Code seq_personality(int version, _Unwind_Action actions } SEQ_FUNC int64_t seq_exc_offset() { - static OurBaseException_t dummy = {}; + static CodonBaseException dummy = {}; return (int64_t)((uintptr_t)&dummy - (uintptr_t)&(dummy.unwindException)); } diff --git a/codon/runtime/lib.h b/codon/runtime/lib.h index 99361181..c7d28e37 100644 --- a/codon/runtime/lib.h +++ b/codon/runtime/lib.h @@ -68,7 +68,7 @@ SEQ_FUNC void seq_gc_remove_roots(void *start, void *end); SEQ_FUNC void seq_gc_clear_roots(); SEQ_FUNC void seq_gc_exclude_static_roots(void *start, void *end); -SEQ_FUNC void *seq_alloc_exc(int type, void *obj); +SEQ_FUNC void *seq_alloc_exc(void *obj); SEQ_FUNC void seq_throw(void *exc); SEQ_FUNC _Unwind_Reason_Code seq_personality(int version, _Unwind_Action actions, uint64_t exceptionClass, diff --git a/docs/language/classes.md b/docs/language/classes.md index 7fde31bd..6c3f4411 100644 --- a/docs/language/classes.md +++ b/docs/language/classes.md @@ -283,23 +283,3 @@ print(foo.x, bar.x) # 1 2 foo.hello() # Foo bar.hello() # Bar 
``` - -### Exceptions - -Subclasses of exception classes like `Exception`, `ValueError`, etc. must use static inheritance -in order to be thrown and caught. Furthermore, when calling their parent class's constructor, exception -subclasses must supply their type name as the first argument. Here is an example: - -``` python -class MyException(Static[Exception]): - x: int - - def __init__(self, x: int): - super().__init__('MyException', 'my exception message') - self.x = x - -try: - raise MyException(42) -except MyException as e: - print('caught:', str(e), e.x) # caught: my exception message 42 -``` diff --git a/stdlib/asyncio.codon b/stdlib/asyncio.codon new file mode 100644 index 00000000..b4241559 --- /dev/null +++ b/stdlib/asyncio.codon @@ -0,0 +1,539 @@ +# Copyright (C) 2022-2025 Exaloop Inc. + +from threading import Lock +from time import time as _time, sleep as _sleep +import internal.gc as gc + +class _deque: + info: int + count: int + mask: int + a: Ptr[T] + T: type + + def __init__(self): + bits = 2 + self.info = bits + self.count = 0 + self.mask = (1 << bits) - 1 + self.a = Ptr[T](1 << bits) + + def __len__(self): + return self.count + + def __str__(self): + front = self.front + count = self.count + mask = self.mask + s = ', '.join([repr(self.a[(front + i) & mask]) for i in range(count)]) + return f'deque([{s}])' + + @property + def bits(self): + return self.info & 0x3F + + @bits.setter + def bits(self, value: int): + self.info &= ~0b111111 + self.info |= value + + @property + def front(self): + return self.info >> 6 + + @front.setter + def front(self, value: int): + self.info &= 0b111111 + self.info |= (value << 6) + + def resize(self, new_bits: int): + q = self + old_bits = q.bits + front = q.front + count = q.count + old_size = 1 << old_bits + new_size = 1 << new_bits + + if new_size < q.count: + i = 0 + while i < 64: + if (1 << i) > q.count: + break + i += 1 + new_bits = i + new_size = 1 << new_bits + + if new_bits == old_bits: + return old_bits + + a = 
q.a + if new_bits > old_bits: + a = Ptr[T](gc.realloc(a.as_byte(), + (1 << new_bits) * gc.sizeof(T), + (1 << old_bits) * gc.sizeof(T))) + + if front + count <= old_size: + if front + count > new_size: + str.memmove(a.as_byte(), + (a + new_size).as_byte(), + (front + count - new_size) * gc.sizeof(T)) + else: + str.memmove((a + (new_size - (old_size - front))).as_byte(), + (a + front).as_byte(), + (old_size - front) * gc.sizeof(T)) + front = new_size - (old_size - front) + + bits = new_bits + mask = (1 << bits) - 1 + if new_bits < old_bits: + a = Ptr[T](gc.realloc(a.as_byte(), + (1 << new_bits) * gc.sizeof(T), + (1 << old_bits) * gc.sizeof(T))) + + q.front = front + q.bits = new_bits + q.mask = mask + q.a = a + + def push(self, v: T): + bits = self.bits + if self.count == (1 << bits): + self.resize(bits + 1) + count = self.count + self.count += 1 + self.a[(count + self.front) & self.mask] = v + + def push_front(self, v: T): + bits = self.bits + if self.count == (1 << bits): + self.resize(bits + 1) + + count = self.count + self.count += 1 + front = self.front + new_front = front - 1 if front else (1 << self.bits) - 1 + self.front = new_front + self.a[new_front] = v + + def pop(self): + self.count -= 1 + return self.a[(self.count + self.front) & self.mask] + + def pop_front(self): + d = self.a[self.front] + self.front = (self.front + 1) & self.mask + self.count -= 1 + return d + + +_FUTURE_STATE_PENDING: Literal[int] = 0 +_FUTURE_STATE_FINISHED: Literal[int] = 1 +_FUTURE_STATE_EXCEPTION: Literal[int] = 2 +_FUTURE_STATE_CANCELLED: Literal[int] = 3 + +@tuple +class WorkItem: + data: cobj # Task or Future + coro: cobj # Coroutine handle + result_size: i32 # Size in bytes of enclosed Future/Task result + is_task: bool # True if `data` is a Task; False if Future + +@tuple +class Timer: + work: WorkItem + when: float + +class LoopCallback: + coro: cobj + next: Optional[LoopCallback] + +class EventLoop: + lock: Lock + worklist: _deque[WorkItem] + timers: Ptr[Timer] + 
timers_len: int + timers_cap: int + running: bool + stop_flag: bool + +class Future: + _result: R + _exception: Optional[BaseException] + _lock: Lock + _loop: EventLoop + _cancel_msg: str + _done_callbacks: Ptr[cobj] + _done_callbacks_len: int + _done_callbacks_cap: int + _state: int + R: type + +class Task(Static[Future[R]]): + _name: str + _coro: cobj + R: type + +@extend +class WorkItem: + def __new__(coro: Coroutine): + result_size = gc.sizeof(coro.T) + is_task = False + return WorkItem(cobj(), coro.__raw__(), i32(result_size), is_task) + + def __new__(coro: cobj, future: Future): + data = future.__raw__() + result_size = gc.sizeof(future.R) + is_task = False + return WorkItem(data, coro, i32(result_size), is_task) + + def __new__(coro: Coroutine, future: Future): + return WorkItem(coro.__raw__(), future) + + def __new__(task: Task): + data = task.__raw__() + coro = task._coro + result_size = gc.sizeof(task.R) + is_task = True + return WorkItem(data, coro, i32(result_size), is_task) + +class InvalidStateError(Exception): + def __init__(self, message: str = ''): + super().__init__(message) + +class CancelledError(Exception): + def __init__(self, message: str = ''): + super().__init__(message) + +async def _callback_wrapper(callback, *args): + callback(*args) + +@extend +class EventLoop: + def __init__(self): + TIMERS_CAP_INIT: Literal[int] = 8 # must be power of 2 + self.lock = Lock() + self.worklist = _deque[WorkItem]() + self.timers = Ptr[Timer](TIMERS_CAP_INIT) + self.timers_len = 0 + self.timers_cap = TIMERS_CAP_INIT + self.running = False + self.stop_flag = False + + def time(self): + return _time() + + def stop(self): + with self.lock: + self.stop_flag = True + + def _call_soon(self, work: WorkItem): + with self.lock: + self.worklist.push(work) + + def call_soon(self, callback, *args): + self._call_soon(WorkItem(_callback_wrapper(callback, *args))) + # TODO: return Handle + + def call_soon_threadsafe(self, callback, *args): + with self.lock: + return 
self.call_soon(callback, *args) + + def _call_later(self, work: WorkItem, delay: float): + t = Timer(work, self.time() + delay) + with self.lock: + self._timers_push(t) + + def call_later(self, delay: float, callback, *args): + self._call_later(WorkItem(_callback_wrapper(callback, *args)), delay) + # TODO: return TimerHandle + + def call_at(self, delay: float, callback, *args): + return self.call_later(delay, callback, *args) + + def run_forever(self): + self.running = True + + while True: + work: Optional[WorkItem] = None + now = self.time() + stop = False + + with self.lock: + while self.timers_len > 0 and self.timers[0].when <= now: + self.worklist.push(self._timers_pop().work) + + if len(self.worklist) > 0: + work = self.worklist.pop_front() + + stop = self.stop_flag + + if stop: + break + + if work is not None: + g = Generator[None](work.coro) + g.__resume__() + if work.data and g.__done__(): + str.memcpy(work.data, + g.__promise__().as_byte(), + int(work.result_size)) + else: + sleep_time = 0.01 # 10ms default + with self.lock: + if self.timers_len > 0: + dt = self.timers[0].when - now + if dt > 0 and dt < sleep_time: + sleep_time = dt + + if sleep_time > 0: + _sleep(sleep_time) + + self.running = False + + def _timers_reserve(self, new_cap: int): + old_cap = self.timers_cap + if new_cap <= old_cap: + return + + sz = gc.sizeof(Timer) + self.timers = gc.realloc( + self.timers.as_byte(), + new_cap * sz, old_cap * sz) + self.timers_cap = new_cap + + def _timers_swap(self, i: int, j: int): + timers = self.timers + tmp = timers[i] + timers[i] = timers[j] + timers[j] = tmp + + def _timers_push(self, t: Timer): + if self.timers_len == self.timers_cap: + self._timers_reserve(self.timers_cap * 2) + + i = self.timers_len + self.timers_len += 1 + timers = self.timers + + # Sift up + while i > 0: + parent = (i - 1) >> 1 + if timers[parent].when <= timers[i].when: + break + self._timers_swap(parent, i) + i = parent + + def _timers_pop(self): + if self.timers_len == 0: + 
return None + + timers = self.timers + out = timers[0] + self.timers_len -= 1 + timers_len = self.timers_len + + if timers_len > 0: + timers[0] = timers[timers_len] + i = 0 + while True: + left = 2*i + 1 + right = 2*i + 2 + smallest = i + + if (left < timers_len and timers[left].when < timers[smallest].when): + smallest = left + if (right < timers_len and timers[right].when < timers[smallest].when): + smallest = right + if smallest == i: + break + + self._timers_swap(i, smallest) + i = smallest + + return out + + def create_future(self, T: type = NoneType): + return Future[T](loop=self) + + def create_task(self, coro: Coroutine, name: Optional[str] = None): + task = Task(coro, loop=self, name=name) + work = WorkItem(task) + self._call_soon(work) + return task + + +_running_loop = EventLoop() + +def get_running_loop(): + if not _running_loop.running: + raise RuntimeError("no running event loop") + return _running_loop + +def create_task(coro, name: Optional[str] = None): + return get_running_loop().create_task(coro, name=name) + + +@extend +class Future: + def __init__(self, loop: Optional[EventLoop] = None): + self._exception = None + self._lock = Lock() + if loop is None: + self._loop = _running_loop + else: + self._loop = loop + self._cancel_msg = '' + self._done_callbacks = Ptr[cobj]() + self._done_callbacks_len = 0 + self._done_callbacks_cap = 0 + self._state = _FUTURE_STATE_PENDING + + def _result_size(self): + return gc.sizeof(R) + + def _reset_callbacks(self): + if self._done_callbacks_cap > 0: + gc.free(self._done_callbacks.as_byte()) + self._done_callbacks = Ptr[cobj]() + self._done_callbacks_len = 0 + self._done_callbacks_cap = 0 + + def _add_done_callback(self, coro: cobj): + lock = self._lock + lock.acquire() + + if self.done(): + lock.release() + self._loop._call_soon(WorkItem(coro, future=self)) + return + + n = self._done_callbacks_len + m = self._done_callbacks_cap + + if m == 0: + self._done_callbacks = Ptr[cobj](1) + self._done_callbacks_cap = 1 + 
elif n >= m: + new_m = m * 2 + sz = gc.sizeof(cobj) + self._done_callbacks = Ptr[cobj]( + gc.realloc( + self._done_callbacks.as_byte(), + new_m * sz, m * sz)) + self._done_callbacks_cap = new_m + + self._done_callbacks[n] = coro + self._done_callbacks_len += 1 + lock.release() + + def _schedule_callbacks(self, callbacks: Ptr[cobj], num_callbacks: int): + for i in range(num_callbacks): + self._loop._call_soon(callbacks[i]) + + def result(self): + with self._lock: + state = self._state + if state == _FUTURE_STATE_CANCELLED: + raise CancelledError(self._cancel_msg) + elif state == _FUTURE_STATE_EXCEPTION: + raise self._exception.__val__() + elif state == _FUTURE_STATE_PENDING: + raise InvalidStateError("Result is not set.") + else: + return self._result + + def set_result(self, result: R): + callbacks = Ptr[cobj]() + num_callbacks = 0 + + with self._lock: + if self.done(): + raise InvalidStateError("Invalid state") + + self._result = result + self._state = _FUTURE_STATE_FINISHED + + callbacks = self._done_callbacks + num_callbacks = self._done_callbacks_len + self._reset_callbacks() + + self._schedule_callbacks(callbacks, num_callbacks) + + def cancelled(self): + return self._state == _FUTURE_STATE_CANCELLED + + def done(self): + return self._state != _FUTURE_STATE_PENDING + + def cancel(self, msg: Optional[str] = None): + callbacks = Ptr[cobj]() + num_callbacks = 0 + + with self._lock: + if self.done(): + return False + + self._state = _FUTURE_STATE_CANCELLED + if msg is not None: + self._cancel_msg = msg + + callbacks = self._done_callbacks + num_callbacks = self._done_callbacks_len + self._reset_callbacks() + + self._schedule_callbacks(callbacks, num_callbacks) + return True + + def get_loop(self): + return self._loop + + def exception(self) -> Optional[BaseException]: + with self._lock: + state = self._state + if state == _FUTURE_STATE_CANCELLED: + raise CancelledError(self._cancel_msg) + elif state == _FUTURE_STATE_EXCEPTION: + return self._exception.__val__() + 
elif state == _FUTURE_STATE_PENDING: + raise InvalidStateError("Exception is not set.") + else: + return None + +_default_task_name_counter = 1 +def _default_task_name(): + global _default_task_name_counter + n = _default_task_name_counter + _default_task_name_counter += 1 + return f'Task-{n}' + +@extend +class Task: + def __init__(self, + coro: Coroutine[R], + loop: Optional[EventLoop] = None, + name: Optional[str] = None): + super().__init__(loop) + if name is None: + self._name = _default_task_name() + else: + self._name = name + self._coro = coro.__raw__() + + def get_name(self): + return self._name + + def set_name(self, value: str): + self._name = value + + def get_coro(self, T: type = NoneType) -> Coroutine[T]: + return Coroutine[T](self._coro) + + def _add_done_callback(self, coro: cobj): + super()._add_done_callback(coro) + + def result(self): + return super().result() + +def run(coro): + _running_loop.create_task(coro) + _running_loop.run_forever() diff --git a/stdlib/copy.codon b/stdlib/copy.codon index a9414ab0..b9bdb1c4 100644 --- a/stdlib/copy.codon +++ b/stdlib/copy.codon @@ -1,8 +1,8 @@ # Copyright (C) 2022-2025 Exaloop Inc. 
-class Error(Static[Exception]): +class Error(Exception): def __init__(self, message: str = ""): - super().__init__("copy.Error", message) + super().__init__(message) def copy(x): return x.__copy__() diff --git a/stdlib/datetime.codon b/stdlib/datetime.codon index edf76062..60b46700 100644 --- a/stdlib/datetime.codon +++ b/stdlib/datetime.codon @@ -418,7 +418,7 @@ class timedelta: _microseconds: int def _new(microseconds: int) -> timedelta: - return __internal__.tuple_cast_unsafe((microseconds,), timedelta) + return type._force_value_cast((microseconds,), timedelta) @inline def _accum(sofar: int, leftover: float, num: int, factor: int) -> Tuple[int, float]: diff --git a/stdlib/getopt.codon b/stdlib/getopt.codon index 2cacb155..e6e595ec 100644 --- a/stdlib/getopt.codon +++ b/stdlib/getopt.codon @@ -45,9 +45,9 @@ import os -class GetoptError(Static[Exception]): +class GetoptError(Exception): def __init__(self, message: str = ""): - super().__init__("GetoptError", message) + super().__init__(message) def long_has_args(opt: str, longopts: List[str]) -> Tuple[bool, str]: possibilities = [o for o in longopts if o.startswith(opt)] diff --git a/stdlib/internal/__init__.codon b/stdlib/internal/__init__.codon index cd3789a6..c85a3cfd 100644 --- a/stdlib/internal/__init__.codon +++ b/stdlib/internal/__init__.codon @@ -14,13 +14,20 @@ from internal.types.error import * from internal.types.intn import * from internal.types.float import * from internal.types.byte import * +from internal.types.type import * +from internal.types.tuple import * +from internal.types.rtti import * +from internal.types.function import * +from internal.types.union import * +from internal.types.any import * +from internal.types.ellipsis import * +from internal.types.import_ import * +from internal.types.ptr import * from internal.types.generator import * from internal.types.optional import * - import internal.c_stubs as _C from internal.format import * from internal.internal import * - from 
internal.types.slice import * from internal.types.range import * from internal.types.complex import * @@ -40,6 +47,7 @@ from internal.str import * from internal.sort import sorted from openmp import Ident as __OMPIdent, for_par, for_par as par +from asyncio import Future as _Future from internal.file import File, gzFile, open, gzopen from internal.gpu import _gpu_loop_outline_template from pickle import pickle, unpickle diff --git a/stdlib/internal/__init_test__.codon b/stdlib/internal/__init_test__.codon index 6312db14..c772919e 100644 --- a/stdlib/internal/__init_test__.codon +++ b/stdlib/internal/__init_test__.codon @@ -14,6 +14,15 @@ from internal.types.error import * from internal.types.intn import * from internal.types.float import * from internal.types.byte import * +from internal.types.type import * +from internal.types.tuple import * +from internal.types.rtti import * +from internal.types.function import * +from internal.types.union import * +from internal.types.any import * +from internal.types.ellipsis import * +from internal.types.import_ import * +from internal.types.ptr import * from internal.types.generator import * from internal.types.optional import * from internal.internal import * diff --git a/stdlib/internal/core.codon b/stdlib/internal/core.codon index 36195cf1..db318c59 100644 --- a/stdlib/internal/core.codon +++ b/stdlib/internal/core.codon @@ -5,11 +5,6 @@ @__noextend__ class type[T]: # __new__ / __init__: always done internally - # __repr__ - # __call__ - # __name__ - # __module__ - # __doc__ pass @tuple @@ -26,6 +21,7 @@ class TypeWrap[T]: @__internal__ class __internal__: pass + @__internal__ class __magic__: pass @@ -241,8 +237,6 @@ class Union[TU]: # compiler-generated def __new__(val): TU - def __call__(self, *args, **kwargs): - return __internal__.union_call(self, args, kwargs) @extend class Function: @@ -254,19 +248,26 @@ class Function: # dummy @__internal__ -class TypeTrait[T]: pass +class TypeTrait[T]: + pass + @__internal__ 
-class ByVal: pass +class ByVal: + pass + @__internal__ -class ByRef: pass +class ByRef: + pass @__internal__ class ClassVar[T]: pass @__internal__ -class RTTI: - id: int +class RTTIType: + data: Ptr[byte] # pointer to the data + typeinfo: Ptr[byte] # TypeInfo type information + @__internal__ @tuple @@ -299,11 +300,6 @@ class Import: %0 = insertvalue { {=bool} } undef, {=bool} %loaded, 0 ret { {=bool} } %0 - def _set_loaded(i: Ptr[Import]): - Ptr[bool](i.as_byte())[0] = True - - def __repr__(self) -> str: - return f"" def __ptr__(var): pass @@ -355,23 +351,6 @@ class NamedTuple: %0 = insertvalue { {=T} } undef, {=T} %args, 0 ret { {=T} } %0 - def __getitem__(self, key: Literal[str]): - return getattr(self, key) - - def __contains__(self, key: Literal[str]): - return hasattr(self, key) - - def get(self, key: Literal[str], default = NoneType()): - return __internal__.kwargs_get(self, key, default) - - def __keys__(self): - return __internal__.namedkeys(N) - - def __repr__(self): - keys = self.__keys__() - values = [v.__repr__() for v in self.args] - s = ', '.join(f"{keys[i]}: {values[i]}" for i in range(len(keys))) - return f"({s})" @__internal__ @tuple @@ -392,19 +371,6 @@ class Partial: %1 = insertvalue { {=T}, {=K} } %0, {=K} %kwargs, 1 ret { {=T}, {=K} } %1 - def __repr__(self): - return __magic__.repr_partial(self) - - def __call__(self, *args, **kwargs): - return self(*args, **kwargs) - - @property - def __fn_name__(self): - return F.__name__[16:-1] # chop off unrealized_type - - def __raw__(self): - # TODO: better error message - return F.T.__raw__() @__internal__ @tuple @@ -414,4 +380,23 @@ class Callable: T: type TR: type + +@tuple +@__internal__ +class Coroutine[T]: + @pure + @derives + @llvm + def __raw__(self) -> Ptr[byte]: + ret ptr %self + +@tuple +@__internal__ +class AsyncGenerator[T]: + @pure + @derives + @llvm + def __raw__(self) -> Ptr[byte]: + ret ptr %self + __codon__: Literal[bool] = True diff --git a/stdlib/internal/gpu.codon 
b/stdlib/internal/gpu.codon index 7e8c155c..c38d2ef7 100644 --- a/stdlib/internal/gpu.codon +++ b/stdlib/internal/gpu.codon @@ -36,13 +36,11 @@ cuModuleGetFunction = Function[[Ptr[CUfunction], CUmodule, cobj], CUresult](cobj cuModuleLoadData = Function[[Ptr[CUmodule], cobj], CUresult](cobj()) -class CUDAError(Static[Exception]): +class CUDAError(Exception): result: int - _pytype: ClassVar[cobj] = cobj() def __init__(self, result: int, message: str = ""): - super().__init__("CUDAError", message) + super().__init__(message) self.result = result - self.python_type = self.__class__._pytype @tuple @@ -402,75 +400,81 @@ def _tuple_from_gpu(args, gpu_args): a.__from_gpu__(g) _tuple_from_gpu(args[1:], gpu_args[1:]) -def kernel(fn): - def nvptx_function(name: str) -> CUfunction: - function = CUfunction() - result = CUresult() - - clean = ''.join(c if c.isalnum() else '_' for c in name) - clean_p = Ptr[byte](len(clean) + 1) - str.memcpy(clean_p, clean.ptr, len(clean)) - clean_p[len(clean)] = byte(0) - if not clean[0].isalpha(): clean_p[0] = byte('_') - - for m in modules[::-1]: - result = cuModuleGetFunction(__ptr__(function), m, clean_p) - if result == i32(CUDA_SUCCESS): - return function - elif result == i32(CUDA_ERROR_NOT_FOUND): - continue - else: - break - cuda_check(result) # this will raise an error - return CUfunction() - - def canonical_dim(dim): - if isinstance(dim, NoneType): - return (1, 1, 1) - elif isinstance(dim, int): - return (dim, 1, 1) - elif isinstance(dim, Tuple[int,int]): - return (dim[0], dim[1], 1) - elif isinstance(dim, Tuple[int,int,int]): - return dim - elif isinstance(dim, Dim3): - return (dim.x, dim.y, dim.z) - else: - compile_error("bad dimension argument") - def offsets(t): - @pure - @llvm - def offsetof(t: T, i: Literal[int], T: type, S: type) -> int: - %p = getelementptr {=T}, ptr null, i64 0, i32 {=i} - %s = ptrtoint ptr %p to i64 - ret i64 %s +def nvptx_function(name: str) -> CUfunction: + function = CUfunction() + result = CUresult() - if 
static.len(t) == 0: - return () + clean = ''.join(c if c.isalnum() else '_' for c in name) + clean_p = Ptr[byte](len(clean) + 1) + str.memcpy(clean_p, clean.ptr, len(clean)) + clean_p[len(clean)] = byte(0) + if not clean[0].isalpha(): clean_p[0] = byte('_') + + for m in modules[::-1]: + result = cuModuleGetFunction(__ptr__(function), m, clean_p) + if result == i32(CUDA_SUCCESS): + return function + elif result == i32(CUDA_ERROR_NOT_FOUND): + continue else: - T = type(t) - S = type(t[-1]) - return (*offsets(t[:-1]), offsetof(t, static.len(t) - 1, T, S)) - - def wrapper(*args, grid, block): - grid = canonical_dim(grid) - block = canonical_dim(block) - cache = AllocCache([]) - shared_mem = 0 - gpu_args = tuple(arg.__to_gpu__(cache) for arg in args) - kernel_ptr = nvptx_function(static.function.realized(fn, *gpu_args).__llvm_name__) - p = __ptr__(gpu_args).as_byte() - arg_ptrs = tuple((p + offset) for offset in offsets(gpu_args)) - cuda_check(cuLaunchKernel(kernel_ptr, - u32(grid[0]), u32(grid[1]), u32(grid[2]), - u32(block[0]), u32(block[1]), u32(block[2]), - u32(shared_mem), cobj(), - __ptr__(arg_ptrs).as_byte(), cobj())) - _tuple_from_gpu(args, gpu_args) - cache.free() - - return wrapper + break + cuda_check(result) # this will raise an error + return CUfunction() + + +def canonical_dim(dim): + if isinstance(dim, NoneType): + return (1, 1, 1) + elif isinstance(dim, int): + return (dim, 1, 1) + elif isinstance(dim, Tuple[int,int]): + return (dim[0], dim[1], 1) + elif isinstance(dim, Tuple[int,int,int]): + return dim + elif isinstance(dim, Dim3): + return (dim.x, dim.y, dim.z) + else: + compile_error("bad dimension argument") + + +def offsets(t): + @pure + @llvm + def offsetof(t: T, i: Literal[int], T: type, S: type) -> int: + %p = getelementptr {=T}, ptr null, i64 0, i32 {=i} + %s = ptrtoint ptr %p to i64 + ret i64 %s + + if static.len(t) == 0: + return () + else: + T = type(t) + S = type(t[-1]) + return (*offsets(t[:-1]), offsetof(t, static.len(t) - 1, T, S)) + + 
+def kernel_wrapper(*args, grid, block, fn): + grid = canonical_dim(grid) + block = canonical_dim(block) + cache = AllocCache([]) + shared_mem = 0 + gpu_args = tuple(arg.__to_gpu__(cache) for arg in args) + kernel_ptr = nvptx_function(static.function.realized(fn, *gpu_args).__llvm_name__) + p = __ptr__(gpu_args).as_byte() + arg_ptrs = tuple((p + offset) for offset in offsets(gpu_args)) + cuda_check(cuLaunchKernel(kernel_ptr, + u32(grid[0]), u32(grid[1]), u32(grid[2]), + u32(block[0]), u32(block[1]), u32(block[2]), + u32(shared_mem), cobj(), + __ptr__(arg_ptrs).as_byte(), cobj())) + _tuple_from_gpu(args, gpu_args) + cache.free() + + +def kernel(fn): + return kernel_wrapper(fn=fn, ...) + def _ptr_to_gpu(p: Ptr[T], n: int, cache: AllocCache, index_filter = lambda i: True, T: type): from internal.gc import atomic @@ -848,8 +852,8 @@ class Optional: return Optional[T](T.__from_gpu_new__(other.__val__())) @extend -class __internal__: - def class_to_gpu(obj, cache: AllocCache): +class type: + def _to_gpu(obj, cache: AllocCache): if isinstance(obj, Tuple): return tuple(a.__to_gpu__(cache) for a in obj) elif isinstance(obj, ByVal): @@ -862,7 +866,7 @@ class __internal__: Ptr[S](mem.__raw__())[0] = tuple(obj).__to_gpu__(cache) return _object_to_gpu(mem, cache) - def class_from_gpu(obj, other): + def _from_gpu(obj, other): if isinstance(obj, Tuple): _tuple_from_gpu(obj, other) elif isinstance(obj, ByVal): @@ -871,7 +875,7 @@ class __internal__: S = type(tuple(obj)) Ptr[S](obj.__raw__())[0] = S.__from_gpu_new__(tuple(_object_from_gpu(other))) - def class_from_gpu_new(other): + def _from_gpu_new(other): if isinstance(other, Tuple): return tuple(type(a).__from_gpu_new__(a) for a in other) elif isinstance(other, ByVal): diff --git a/stdlib/internal/internal.codon b/stdlib/internal/internal.codon index ad08ba02..7fe56111 100644 --- a/stdlib/internal/internal.codon +++ b/stdlib/internal/internal.codon @@ -1,70 +1,9 @@ # Copyright (C) 2022-2025 Exaloop Inc. 
-from internal.gc import ( - alloc, alloc_atomic, alloc_atomic_uncollectable, - free, sizeof, register_finalizer -) import internal.static as static - -__vtables__ = Ptr[Ptr[cobj]]() -__vtable_size__ = 0 - - @extend class __internal__: - def yield_final(val): - pass - - def yield_in_no_suspend(T: type) -> T: - pass - - @pure - @derives - @llvm - def class_raw_ptr(obj) -> Ptr[byte]: - ret ptr %obj - - @pure - @derives - @llvm - def class_raw_rtti_ptr(obj) -> Ptr[byte]: - %0 = extractvalue {ptr, ptr} %obj, 0 - ret ptr %0 - - @pure - @derives - @llvm - def class_raw_rtti_rtti(obj: T, T: type) -> Ptr[byte]: - %0 = extractvalue {ptr, ptr} %obj, 1 - ret ptr %0 - - def class_alloc(T: type) -> T: - """Allocates a new reference (class) type""" - sz = sizeof(tuple(T)) - obj = alloc_atomic(sz) if T.__contents_atomic__ else alloc(sz) - if static.has_rtti(T): - register_finalizer(obj) - rtti = RTTI(T.__id__).__raw__() - return __internal__.to_class_ptr_rtti((obj, rtti), T) - else: - register_finalizer(obj) - return __internal__.to_class_ptr(obj, T) - - def class_ctr(T: type, *args, **kwargs) -> T: - """Shorthand for `t = T.__new__(); t.__init__(*args, **kwargs); t`""" - return T(*args, **kwargs) - - def class_init_vtables(): - """ - Create a global vtable. - """ - global __vtables__ - sz = __vtable_size__ + 1 - p = alloc_atomic_uncollectable(sz * sizeof(Ptr[cobj])) - __vtables__ = Ptr[Ptr[cobj]](p) - __internal__.class_populate_vtables() - def _print(a): from C import seq_print(str) if hasattr(a, "__repr__"): @@ -72,304 +11,6 @@ class __internal__: else: seq_print(a.__str__()) - def class_populate_vtables() -> None: - """ - Populate content of vtables. Compiler generated. 
- Corresponds to: - for each realized class C: - __internal__.class_set_rtti_vtable(, + 1, T=C) - for each fn F in C's vtable: - __internal__.class_set_rtti_vtable_fn( - , , Function().__raw__(), T=C - ) - """ - pass - - def class_set_rtti_vtable(id: int, sz: int, T: type): - if not static.has_rtti(T): - compile_error("class is not polymorphic") - p = alloc_atomic_uncollectable((sz + 1) * sizeof(cobj)) - __vtables__[id] = Ptr[cobj](p) - __internal__.class_set_typeinfo(__vtables__[id], id) - - def class_set_rtti_vtable_fn(id: int, fid: int, f: cobj, T: type): - if not static.has_rtti(T): - compile_error("class is not polymorphic") - __vtables__[id][fid] = f - - def class_get_thunk_id(F: type, T: type) -> int: - """Compiler-generated""" - return 0 - - def class_thunk_debug(base, func, sig, *args): - # print("![thunk]!", base, func, sig, args[0].__raw__()) - pass - - @no_argument_wrap - def class_thunk_dispatch(slf, cls_id, *args, F: type): - if not static.has_rtti(type(slf)): - compile_error("class is not polymorphic") - - FR = type(static.function.realized(F, slf, *args)) - T = type(slf) - thunk_id = __internal__.class_get_thunk_id(FR, T) - - # Get RTTI table - if cls_id == 0: - rtti = __internal__.class_raw_rtti_rtti(slf) - cls_id = __internal__.to_class_ptr(rtti, RTTI).id - fptr = __vtables__[cls_id][thunk_id] - f = FR(fptr) - return f(slf, *args) - - def class_set_typeinfo(p: Ptr[cobj], typeinfo: T, T: type) -> None: - i = Ptr[T](1) - i[0] = typeinfo - p[0] = i.as_byte() - - def class_get_typeinfo(p) -> int: - c = Ptr[Ptr[cobj]](p.__raw__()) - vt = c[0] - return Ptr[int](vt[0])[0] - - @inline - def class_base_derived_dist(B: type, D: type) -> int: - """Calculates the byte distance of base class B and derived class D. 
Compiler generated.""" - return 0 - - @inline - def class_base_to_derived(b: B, B: type, D: type) -> D: - if not (static.has_rtti(D) and static.has_rtti(B)): - compile_error("classes are not polymorphic") - off = __internal__.class_base_derived_dist(B, D) - ptr = __internal__.class_raw_rtti_ptr(b) - off - pr = __internal__.class_raw_rtti_rtti(b) - return __internal__.to_class_ptr_rtti((ptr, pr), D) - - def class_copy(obj: T, T: type) -> T: - p = __internal__.class_alloc(T) - str.memcpy(p.__raw__(), obj.__raw__(), sizeof(tuple(T))) - return p - - def class_super(obj, B: type, use_super_type: Literal[int] = 0): - D = type(obj) - if not static.has_rtti(D): # static inheritance - return __internal__.to_class_ptr(obj.__raw__(), B) - else: - if not static.has_rtti(B): - compile_error("classes are not polymorphic") - off = __internal__.class_base_derived_dist(B, D) - ptr = __internal__.class_raw_rtti_ptr(obj) + off - pr = __internal__.class_raw_rtti_rtti(obj) - res = __internal__.to_class_ptr_rtti((ptr, pr), B) - if use_super_type: - # This is explicit super() - return __internal__.tuple_cast_unsafe((res, ), Super[B]) - else: - # Implicit super() just used for casting - return res - - def class_super_change_rtti(obj: Super[B], B: type): - ptr = __internal__.class_raw_rtti_ptr(obj.__obj__) - pr = RTTI(B.__id__).__raw__() - return __internal__.to_class_ptr_rtti((ptr, pr), B) - - @pure - @derives - @llvm - def _capsule_make_helper(val: Ptr[T], T: type) -> Capsule[T]: - %0 = insertvalue { ptr } undef, ptr %val, 0 - ret { ptr } %0 - - def capsule_make(val: T, T: type) -> Capsule[T]: - p = Ptr[T](1) - p[0] = val - return __internal__._capsule_make_helper(p) - - @pure - @derives - @llvm - def capsule_get_ptr(ref: Capsule[T], T: type) -> Ptr[T]: - %0 = extractvalue { ptr } %ref, 0 - %1 = getelementptr {=T}, ptr %0, i64 0 - ret ptr %1 - - @pure - @derives - @llvm - def capsule_get(ref: Capsule[T], T: type) -> T: - %0 = extractvalue { ptr } %ref, 0 - %1 = getelementptr {=T}, ptr 
%0, i64 0 - %2 = load {=T}, ptr %1 - ret {=T} %2 - - - # Unions - - def get_union_tag(u, tag: Literal[int]): # compiler-generated - pass - - @llvm - def union_set_tag(tag: byte, U: type) -> U: - %0 = insertvalue {=U} undef, i8 %tag, 0 - ret {=U} %0 - - @llvm - def union_get_data_ptr(ptr: Ptr[U], U: type, T: type) -> Ptr[T]: - %0 = getelementptr inbounds {=U}, ptr %ptr, i64 0, i32 1 - ret ptr %0 - - @llvm - def union_get_tag(u: U, U: type) -> byte: - %0 = extractvalue {=U} %u, 0 - ret i8 %0 - - def union_get_data(u, T: type) -> T: - return __internal__.union_get_data_ptr(__ptr__(u), T=T)[0] - - def union_make(tag: int, value, U: type) -> U: - u = __internal__.union_set_tag(byte(tag), U) - __internal__.union_get_data_ptr(__ptr__(u), T=type(value))[0] = value - return u - - def new_union(value, U: type) -> U: - for tag, T in static.vars_types(U, with_index=True): - if isinstance(value, T): - return __internal__.union_make(tag, value, U) - if isinstance(value, Union[T]): - return __internal__.union_make(tag, __internal__.get_union(value, T), U) - # TODO: make this static! 
- raise TypeError("invalid union constructor") - - def get_union(union, T: type) -> T: - for tag, TU in static.vars_types(union, with_index=True): - if isinstance(TU, T): - if __internal__.union_get_tag(union) == tag: - return __internal__.union_get_data(union, TU) - raise TypeError(f"invalid union getter for type '{T.__class__.__name__}'") - - def _union_member_helper(union, member: Literal[str]) -> Union: - for tag, T in static.vars_types(union, with_index=True): - if hasattr(T, member): - if __internal__.union_get_tag(union) == tag: - return getattr(__internal__.union_get_data(union, T), member) - raise TypeError(f"invalid union call '{member}'") - - def union_member(union, member: Literal[str]): - t = __internal__._union_member_helper(union, member) - if static.len(t) == 1: - return __internal__.get_union_tag(t, 0) - else: - return t - - def _union_call_helper(union, args, kwargs) -> Union: - for tag, T in static.vars_types(union, with_index=True): - if static.function.can_call(T, *args, **kwargs): - if __internal__.union_get_tag(union) == tag: - return __internal__.union_get_data(union, T)(*args, **kwargs) - elif hasattr(T, '__call__'): - if static.function.can_call(T.__call__, *args, **kwargs): - if __internal__.union_get_tag(union) == tag: - return __internal__.union_get_data(union, T).__call__(*args, **kwargs) - raise TypeError("cannot call union " + union.__class__.__name__) - - def union_call(union, args, kwargs): - t = __internal__._union_call_helper(union, args, kwargs) - if static.len(t) == 1: - return __internal__.get_union_tag(t, 0) - else: - return t - - def union_str(union): - for tag, T in static.vars_types(union, with_index=True): - if hasattr(T, '__str__'): - if __internal__.union_get_tag(union) == tag: - return __internal__.union_get_data(union, T).__str__() - elif hasattr(T, '__repr__'): - if __internal__.union_get_tag(union) == tag: - return __internal__.union_get_data(union, T).__repr__() - return '' - - # and/or - - def and_union(x, y): - 
if type(x) is type(y): - return y if x else x - else: - T = Union[type(x),type(y)] - return T(y) if x else T(x) - - def or_union(x, y): - if type(x) is type(y): - return x if x else y - else: - T = Union[type(x),type(y)] - return T(x) if x else T(y) - - - # Tuples - - def namedkeys(N: Literal[int]): - pass - - @pure - @derives - @llvm - def _tuple_getitem_llvm(t: T, idx: int, T: type, E: type) -> E: - %x = alloca {=T} - store {=T} %t, ptr %x - %p = getelementptr {=E}, ptr %x, i64 %idx - %v = load {=E}, ptr %p - ret {=E} %v - - def tuple_fix_index(idx: int, len: int) -> int: - if idx < 0: - idx += len - if idx < 0 or idx >= len: - raise IndexError("tuple index out of range") - return idx - - def tuple_getitem(t: T, idx: int, T: type, E: type) -> E: - return __internal__._tuple_getitem_llvm( - t, __internal__.tuple_fix_index(idx, static.len(t)), T, E - ) - - @pure - @derives - @llvm - def tuple_cast_unsafe(t, U: type) -> U: - ret {=U} %t - - @pure - @derives - @llvm - def fn_new(p: Ptr[byte], T: type) -> T: - ret ptr %p - - @pure - @derives - @llvm - def fn_raw(fn: T, T: type) -> Ptr[byte]: - ret ptr %fn - - @pure - @llvm - def int_sext(what, F: Literal[int], T: Literal[int]) -> Int[T]: - %0 = sext i{=F} %what to i{=T} - ret i{=T} %0 - - @pure - @llvm - def int_zext(what, F: Literal[int], T: Literal[int]) -> Int[T]: - %0 = zext i{=F} %what to i{=T} - ret i{=T} %0 - - @pure - @llvm - def int_trunc(what, F: Literal[int], T: Literal[int]) -> Int[T]: - %0 = trunc i{=F} %what to i{=T} - ret i{=T} %0 - def seq_assert(file: str, line: int, msg: str) -> AssertionError: s = f": {msg}" if msg else "" s = f"Assert failed{s} ({file}:{line.__repr__()})" @@ -381,201 +22,10 @@ class __internal__: s = f"\033[1;31mTEST FAILED:\033[0m {file} (line {line}){s}\n" seq_print(s) - def check_errno(prefix: str): - @pure - @C - def seq_check_errno() -> str: - pass - - msg = seq_check_errno() - if msg: - raise OSError(prefix + msg) - - @pure - @llvm - def opt_ref_new(T: type) -> Optional[T]: - 
ret ptr null - - @pure - @llvm - def opt_ref_new_rtti(T: type) -> Optional[T]: - ret { ptr, ptr } { ptr null, ptr null } - - @pure - @derives - @llvm - def opt_tuple_new_arg(what: T, T: type) -> Optional[T]: - %0 = insertvalue { i1, {=T} } { i1 true, {=T} undef }, {=T} %what, 1 - ret { i1, {=T} } %0 - - @pure - @derives - @llvm - def opt_ref_new_arg(what: T, T: type) -> Optional[T]: - ret ptr %what - - @pure - @derives - @llvm - def opt_ref_new_arg_rtti(what: T, T: type) -> Optional[T]: - ret { ptr, ptr } %what - - @pure - @llvm - def opt_tuple_bool(what: Optional[T], T: type) -> bool: - %0 = extractvalue { i1, {=T} } %what, 0 - %1 = zext i1 %0 to i8 - ret i8 %1 - - @pure - @llvm - def opt_ref_bool(what: Optional[T], T: type) -> bool: - %0 = icmp ne ptr %what, null - %1 = zext i1 %0 to i8 - ret i8 %1 - - @pure - def opt_ref_bool_rtti(what: Optional[T], T: type) -> bool: - return __internal__.class_raw_rtti_ptr(what) != cobj() - - @pure - @derives - @llvm - def opt_tuple_invert(what: Optional[T], T: type) -> T: - %0 = extractvalue { i1, {=T} } %what, 1 - ret {=T} %0 - - @pure - @derives - @llvm - def opt_ref_invert(what: Optional[T], T: type) -> T: - ret ptr %what - - @pure - @derives - @llvm - def opt_ref_invert_rtti(what: Optional[T], T: type) -> T: - ret { ptr, ptr } %what - - @pure - @derives - @llvm - def to_class_ptr(p: Ptr[byte], T: type) -> T: - ret ptr %p - - @pure - @derives - @llvm - def to_class_ptr_rtti(p: Tuple[Ptr[byte], Ptr[byte]], T: type) -> T: - ret { ptr, ptr } %p - - def _tuple_offsetof(x, field: Literal[int]) -> int: - @pure - @llvm - def _llvm_offsetof(T: type, idx: Literal[int], TE: type) -> int: - %a = alloca {=T} - %b = getelementptr inbounds {=T}, ptr %a, i64 0, i32 {=idx} - %base = ptrtoint ptr %a to i64 - %elem = ptrtoint ptr %b to i64 - %offset = sub i64 %elem, %base - ret i64 %offset - - return _llvm_offsetof(type(x), field, type(x[field])) - - def raw_type_str(p: Ptr[byte], name: str) -> str: - pstr = p.__repr__() - # '<[name] at 
[pstr]>' - total = 1 + name.len + 4 + pstr.len + 1 - buf = Ptr[byte](total) - where = 0 - buf[where] = byte(60) # '<' - where += 1 - str.memcpy(buf + where, name.ptr, name.len) - where += name.len - buf[where] = byte(32) # ' ' - where += 1 - buf[where] = byte(97) # 'a' - where += 1 - buf[where] = byte(116) # 't' - where += 1 - buf[where] = byte(32) # ' ' - where += 1 - str.memcpy(buf + where, pstr.ptr, pstr.len) - where += pstr.len - buf[where] = byte(62) # '>' - free(pstr.ptr) - return str(buf, total) - - def tuple_str(strs: Ptr[str], names: Ptr[str], n: int) -> str: - # special case of 1-element plain tuple: format as "(x,)" - if n == 1 and names[0].len == 0: - total = strs[0].len + 3 - buf = Ptr[byte](total) - buf[0] = byte(40) # '(' - str.memcpy(buf + 1, strs[0].ptr, strs[0].len) - buf[total - 2] = byte(44) # ',' - buf[total - 1] = byte(41) # ')' - return str(buf, total) - - total = 2 # one for each of '(' and ')' - i = 0 - while i < n: - total += strs[i].len - if names[i].len: - total += names[i].len + 2 # extra : and space - if i < n - 1: - total += 2 # ", " - i += 1 - buf = Ptr[byte](total) - where = 0 - buf[where] = byte(40) # '(' - where += 1 - i = 0 - while i < n: - s = names[i] - l = s.len - if l: - str.memcpy(buf + where, s.ptr, l) - where += l - buf[where] = byte(58) # ':' - where += 1 - buf[where] = byte(32) # ' ' - where += 1 - s = strs[i] - l = s.len - str.memcpy(buf + where, s.ptr, l) - where += l - if i < n - 1: - buf[where] = byte(44) # ',' - where += 1 - buf[where] = byte(32) # ' ' - where += 1 - i += 1 - buf[where] = byte(41) # ')' - return str(buf, total) - def undef(v, s): if not v: raise NameError(f"name '{s}' is not defined") - @__hidden__ - def set_header(e, func, file, line, col, cause): - if not isinstance(e, BaseException): - compile_error("exceptions must derive from BaseException") - - e.func = func - e.file = file - e.line = line - e.col = col - if cause is not None: - e.cause = cause - return e - - def kwargs_get(kw, key: 
Literal[str], default): - if hasattr(kw, key): - return getattr(kw, key) - else: - return default @extend class __magic__: @@ -587,17 +37,14 @@ class __magic__: # always present for reference types only def new(T: type) -> T: """Create a new reference (class) type""" - return __internal__.class_alloc(T) + return type._ref_new(T) # init is compiler-generated when init=True for reference types # def init(self, a1, ..., aN): ... # always present for reference types only def raw(obj) -> Ptr[byte]: - if static.has_rtti(type(obj)): - return __internal__.class_raw_rtti_ptr(obj) - else: - return __internal__.class_raw_ptr(obj) + return type._ref_raw(obj) # always present for reference types only def dict(slf) -> List[str]: @@ -636,9 +83,9 @@ class __magic__: # @dataclass parameter: container=True def getitem(slf, index: int): if static.len(slf) == 0: - __internal__.tuple_fix_index(index, 0) # raise exception + Tuple._fix_index(index, 0) # raise exception else: - return __internal__.tuple_getitem(slf, index, type(slf), static.tuple_type(slf, 0)) + return Tuple._getitem(slf, index, type(slf), static.tuple_type(slf, 0)) # @dataclass parameter: container=True def iter(slf): @@ -713,7 +160,7 @@ class __magic__: # @dataclass parameter: pickle=True def unpickle(src: Ptr[byte], T: type) -> T: if isinstance(T, ByVal): - return __internal__.tuple_cast_unsafe(tuple(type(t).__unpickle__(src) for t in static.vars_types(T)), T) + return type._force_value_cast(tuple(type(t).__unpickle__(src) for t in static.vars_types(T)), T) else: obj = T.__new__() for k, v in static.vars(obj): @@ -730,7 +177,7 @@ class __magic__: # @dataclass parameter: python=True def from_py(src: Ptr[byte], T: type) -> T: if isinstance(T, ByVal): - return __internal__.tuple_cast_unsafe(tuple( + return type._force_value_cast(tuple( type(t).__from_py__(pyobj._tuple_get(src, i)) for i, t in static.vars_types(T, with_index=True) ), T) @@ -742,15 +189,15 @@ class __magic__: # @dataclass parameter: gpu=True def to_gpu(slf, 
cache): - return __internal__.class_to_gpu(slf, cache) + return type._to_gpu(slf, cache) # @dataclass parameter: gpu=True def from_gpu(slf: T, other: T, T: type): - __internal__.class_from_gpu(slf, other) + type._from_gpu(slf, other) # @dataclass parameter: gpu=True def from_gpu_new(other: T, T: type) -> T: - return __internal__.class_from_gpu_new(other) + return type._from_gpu_new(other) # @dataclass parameter: repr=True def repr(slf) -> str: @@ -765,7 +212,7 @@ class __magic__: n[i] = "" else: n[i] = k - return __internal__.tuple_str(a.ptr, n.ptr, l) + return Tuple._str(a.ptr, n.ptr, l) # @dataclass parameter: repr=False def repr_default(slf) -> str: @@ -777,33 +224,6 @@ class __magic__: return slf.__repr_default__() return slf.__repr__() -@extend -class Function: - @pure - @overload - @llvm - def __new__(what: Ptr[byte]) -> Function[T, TR]: - ret ptr %what - - @overload - def __new__(what: Function[T, TR]) -> Function[T, TR]: - return what - - @pure - @llvm - def __raw__(self) -> Ptr[byte]: - ret ptr %self - - def __repr__(self) -> str: - return __internal__.raw_type_str(self.__raw__(), "function") - - @llvm - def __call_internal__(self: Function[T, TR], args: T) -> TR: - noop # compiler will populate this one - - def __call__(self, *args) -> TR: - return Function.__call_internal__(self, args) - @tuple class PyObject: refcnt: int @@ -814,122 +234,4 @@ class PyWrapper[T]: head: PyObject data: T -@extend -class RTTI: - def __new__() -> RTTI: - return __magic__.new(RTTI) - def __init__(self, i: int): - self.id = i - def __raw__(self): - return __internal__.class_raw_ptr(self) - -@extend -class ellipsis: - def __repr__(self): - return 'Ellipsis' - def __eq__(self, other: ellipsis): - return True - def __ne__(self, other: ellipsis): - return False - def __hash__(self): - return 269626442 # same as CPython - -__internal__.class_init_vtables() - -@extend -class Super: - def __repr__(self): - return f'' - -class __cast__: - def cast(obj: T, T: type) -> Generator[T]: - 
return obj.__iter__() - - @overload - def cast(obj: int) -> float: - return float(obj) - - @overload - def cast(obj: T, T: type) -> Optional[T]: - return Optional[T](obj) - - @overload - def cast(obj: Optional[T], T: type) -> T: - return obj.unwrap() - - @overload - def cast(obj: T, T: type) -> pyobj: - return obj.__to_py__() - - @overload - def cast(obj: pyobj, T: type) -> T: - return T.__from_py__(obj) - - # Function[[T...], R] - # ExternFunction[[T...], R] - # CodonFunction[[T...], R] - # Partial[foo, [T...], R] - - # function into partial (if not Function) / fn(foo) -> fn(foo(...)) - # empty partial (!!) into Function[] - # union extract - # any into Union[] - # derived to base - - def conv_float(obj: float) -> int: - return int(obj) - -def __type_repr__(T: type): - return f"" - -@extend -class TypeWrap: - def __new__(T: type) -> TypeWrap[T]: - return __internal__.tuple_cast_unsafe((), TypeWrap[T]) - - def __call_no_self__(*args, **kwargs) -> T: - return T(*args, **kwargs) - - def __call__(self, *args, **kwargs) -> T: - return T(*args, **kwargs) - - def __repr__(self): - return __type_repr__(T) - - @property - def __name__(self): - return T.__name__ - -@extend -class Capsule: - def __init__(self, val: T): - self.val[0] = val - -@extend -class Callable: - def __new__(fn: Function[[Ptr[byte], T], TR], data: Ptr[byte]) -> Callable[T, TR]: - return __internal__.tuple_cast_unsafe((fn, data), Callable[T, TR]) - - @overload - def __new__(fn: Function[[Ptr[byte], T], TR], data: Partial[M,PT,K,F], - T: type, TR: type, - M: Literal[str], PT: type, F: type, K: type) -> Callable[T, TR]: - p = Ptr[Partial[M,PT,K,F]](1) - p[0] = data - return Callable(fn, p.as_byte()) - - @overload - def __new__(fn: Function[[Ptr[byte], T], TR], data: Function[T, TR]) -> Callable[T, TR]: - return Callable(fn, data.__raw__()) - - @overload - def __new__(fn: Function[T, TR]) -> Callable[T, TR]: - def _wrap(data: Ptr[byte], args, f: type): - return f(data)(*args) - return Callable( - 
static.function.realized(_wrap(f=Function[T, TR], ...), Ptr[byte], T), - fn.__raw__() - ) - - def __call__(self, *args): - return self.fn.__call__(self.data, args) +RTTIType._init_vtables() diff --git a/stdlib/internal/python.codon b/stdlib/internal/python.codon index 79ebe0d1..b765b437 100644 --- a/stdlib/internal/python.codon +++ b/stdlib/internal/python.codon @@ -954,25 +954,25 @@ class _PyArg_Parser: return _PyArg_Parser(z, format, keywords, fname, o, z, z, z, o, o) @dataclass(init=False) -class PyError(Static[Exception]): +class PyError(Exception): pytype: pyobj def __init__(self, message: str): - super().__init__("PyError", message) + super().__init__(message) self.pytype = pyobj(cobj(), steal=True) @overload def __init__(self, message: str, pytype: pyobj): - super().__init__("PyError", message) + super().__init__(message) self.pytype = pytype @extend class pyobj: def __new__() -> pyobj: - return __internal__.class_alloc(pyobj) + return type._ref_new(pyobj) def __raw__(self) -> Ptr[byte]: - return __internal__.class_raw(self) + return type._ref_raw(self) def __init__(self, p: Ptr[byte], steal: bool = False): self.p = p @@ -1536,7 +1536,7 @@ def _____(): __pyenv__ # make it global! 
import internal.static as _S -class _PyWrapError(Static[PyError]): +class _PyWrapError(PyError): def __init__(self, message: str, pytype: pyobj = pyobj(cobj(), steal=True)): - super().__init__("_PyWrapError", message) + super().__init__(message, pytype) # PyError.__init__ no longer takes a typename argument self.pytype = pytype diff --git a/stdlib/internal/static.codon b/stdlib/internal/static.codon index c0d99061..02744672 100644 --- a/stdlib/internal/static.codon +++ b/stdlib/internal/static.codon @@ -13,13 +13,17 @@ def print(*args): def range(start: Literal[int], stop: Literal[int], step: Literal[int] = 1): import internal.types.range - return internal.types.range.range(start, stop, step) + if step == 0: + compile_error("range() step argument must not be zero") + # Avoid exception raising method here as exception raising depends on this method + return type._force_value_cast((start, stop, step), internal.types.range.range) @overload def range(stop: Literal[int]): import internal.types.range - return internal.types.range.range(0, stop, 1) + # Avoid exception raising method here as exception raising depends on this method + return type._force_value_cast((0, stop, 1), internal.types.range.range) def enumerate(tup): i = -1 diff --git a/stdlib/internal/types/any.codon b/stdlib/internal/types/any.codon new file mode 100644 index 00000000..42334922 --- /dev/null +++ b/stdlib/internal/types/any.codon @@ -0,0 +1,71 @@ +# Copyright (C) 2022-2025 Exaloop Inc.
+ +import internal.static as static + + +@__internal__ +class Any: + _data: Ptr[byte] + _typeinfo: TypeInfo + + def __str__(self): + if self._typeinfo.repr.__raw__() != cobj(): + return f"Any({self._typeinfo.repr(self._data)}, {self._typeinfo})" + return f"Any({self._typeinfo})" + + def __new__() -> Any: + return __magic__.new(Any) + + def __init__(self, obj): + T = type(obj) + self._typeinfo = TypeInfo(T) + if isinstance(T, ByVal): + p = Ptr[T](1) + p[0] = obj + self._data = p.as_byte() + else: + self._data = obj.__raw__().as_byte() + + def unwrap(self, T: type) -> T: + if TypeInfo(T) != self._typeinfo: + raise TypeError(f"Any.unwrap failed: requested {T}, but got {self._typeinfo.nice_name} instead") + if isinstance(T, ByVal): + p = type._force_cast(self._data, Ptr[T]) + return p[0] + else: + return type._force_cast(self._data, T) + + +@extend +class Capsule: + @pure + @derives + @llvm + def _make(val: Ptr[T], T: type) -> Capsule[T]: + %0 = insertvalue { ptr } undef, ptr %val, 0 + ret { ptr } %0 + + def make(val: T, T: type) -> Capsule[T]: + p = Ptr[T](1) + p[0] = val + return Capsule._make(p) + + @pure + @derives + @llvm + def _ptr(ref: Capsule[T], T: type) -> Ptr[T]: + %0 = extractvalue { ptr } %ref, 0 + %1 = getelementptr {=T}, ptr %0, i64 0 + ret ptr %1 + + @pure + @derives + @llvm + def _get(ref: Capsule[T], T: type) -> T: + %0 = extractvalue { ptr } %ref, 0 + %1 = getelementptr {=T}, ptr %0, i64 0 + %2 = load {=T}, ptr %1 + ret {=T} %2 + + def __init__(self, val: T): + self.val[0] = val diff --git a/stdlib/internal/types/ellipsis.codon b/stdlib/internal/types/ellipsis.codon new file mode 100644 index 00000000..9e440a80 --- /dev/null +++ b/stdlib/internal/types/ellipsis.codon @@ -0,0 +1,12 @@ +# Copyright (C) 2022-2025 Exaloop Inc. 
+ +@extend +class ellipsis: + def __repr__(self): + return 'Ellipsis' + def __eq__(self, other: ellipsis): + return True + def __ne__(self, other: ellipsis): + return False + def __hash__(self): + return 269626442 # same as CPython diff --git a/stdlib/internal/types/error.codon b/stdlib/internal/types/error.codon index c02e6149..e3474f75 100644 --- a/stdlib/internal/types/error.codon +++ b/stdlib/internal/types/error.codon @@ -4,7 +4,6 @@ # header type defined in runtime/exc.cpp. class BaseException: _pytype: ClassVar[cobj] = cobj() - typename: str message: str func: str file: str @@ -13,15 +12,14 @@ class BaseException: python_type: cobj cause: Optional[BaseException] - def __init__(self, typename: str, message: str = ""): - self.typename = typename + def __init__(self, message: str = ""): self.message = message self.func = "" self.file = "" self.line = 0 self.col = 0 self.python_type = BaseException._pytype - self.cause = __internal__.opt_ref_new(T=BaseException) + self.cause = Optional._ref_new(T=BaseException) def __str__(self): return self.message @@ -33,136 +31,141 @@ class BaseException: def __cause__(self): return self.cause -class Exception(Static[BaseException]): + @__hidden__ + def _set_header(e, func, file, line, col, cause): + # if not isinstance(e, BaseException): + # compile_error("exceptions must derive from BaseException") + + e.func = func + e.file = file + e.line = line + e.col = col + if cause is not None: + e.cause = cause + return e + + +class Exception(BaseException): _pytype: ClassVar[cobj] = cobj() - def __init__(self, typename: str, msg: str = ""): - super().__init__(typename, msg) + def __init__(self, message: str = ""): + super().__init__(message) if hasattr(self.__class__, "_pytype"): self.python_type = self.__class__._pytype -class NameError(Static[Exception]): +class NameError(Exception): _pytype: ClassVar[cobj] = cobj() def __init__(self, message: str = ""): - super().__init__("NameError", message) + super().__init__(message) 
self.python_type = self.__class__._pytype -class OSError(Static[Exception]): +class OSError(Exception): _pytype: ClassVar[cobj] = cobj() def __init__(self, message: str = ""): - super().__init__("OSError", message) + super().__init__(message) self.python_type = self.__class__._pytype -class IOError(Static[Exception]): +class IOError(Exception): _pytype: ClassVar[cobj] = cobj() def __init__(self, message: str = ""): - super().__init__("IOError", message) + super().__init__(message) self.python_type = self.__class__._pytype -class ValueError(Static[Exception]): +class ValueError(Exception): _pytype: ClassVar[cobj] = cobj() def __init__(self, message: str = ""): - super().__init__("ValueError", message) + super().__init__(message) self.python_type = self.__class__._pytype -class LookupError(Static[Exception]): +class LookupError(Exception): _pytype: ClassVar[cobj] = cobj() - def __init__(self, typename: str, message: str = ""): - super().__init__(typename, message) - self.python_type = self.__class__._pytype - def __init__(self, msg: str = ""): - super().__init__("LookupError", msg) + def __init__(self, message: str = ""): + super().__init__(message) self.python_type = self.__class__._pytype -class IndexError(Static[LookupError]): +class IndexError(LookupError): _pytype: ClassVar[cobj] = cobj() def __init__(self, message: str = ""): - super().__init__("IndexError", message) + super().__init__(message) self.python_type = self.__class__._pytype -class KeyError(Static[LookupError]): +class KeyError(LookupError): _pytype: ClassVar[cobj] = cobj() def __init__(self, message: str = ""): - super().__init__("KeyError", message) + super().__init__(message) self.python_type = self.__class__._pytype -class CError(Static[Exception]): +class CError(Exception): _pytype: ClassVar[cobj] = cobj() def __init__(self, message: str = ""): - super().__init__("CError", message) + super().__init__(message) self.python_type = self.__class__._pytype -class TypeError(Static[Exception]): +class 
TypeError(Exception): _pytype: ClassVar[cobj] = cobj() def __init__(self, message: str = ""): - super().__init__("TypeError", message) + super().__init__(message) self.python_type = self.__class__._pytype -class ArithmeticError(Static[Exception]): +class ArithmeticError(Exception): _pytype: ClassVar[cobj] = cobj() - def __init__(self, msg: str = ""): - super().__init__("ArithmeticError", msg) + def __init__(self, message: str = ""): + super().__init__(message) self.python_type = self.__class__._pytype -class ZeroDivisionError(Static[ArithmeticError]): +class ZeroDivisionError(ArithmeticError): _pytype: ClassVar[cobj] = cobj() - def __init__(self, typename: str, message: str = ""): - super().__init__(typename, message) - self.python_type = self.__class__._pytype def __init__(self, message: str = ""): - super().__init__("ZeroDivisionError", message) + super().__init__(message) self.python_type = self.__class__._pytype -class OverflowError(Static[ArithmeticError]): +class OverflowError(ArithmeticError): _pytype: ClassVar[cobj] = cobj() def __init__(self, message: str = ""): - super().__init__("OverflowError", message) + super().__init__(message) self.python_type = self.__class__._pytype -class AttributeError(Static[Exception]): +class AttributeError(Exception): _pytype: ClassVar[cobj] = cobj() def __init__(self, message: str = ""): - super().__init__("AttributeError", message) + super().__init__(message) self.python_type = self.__class__._pytype -class RuntimeError(Static[Exception]): +class RuntimeError(Exception): _pytype: ClassVar[cobj] = cobj() - def __init__(self, typename: str, message: str = ""): - super().__init__(typename, message) - self.python_type = self.__class__._pytype def __init__(self, message: str = ""): - super().__init__("RuntimeError", message) + super().__init__(message) self.python_type = self.__class__._pytype -class NotImplementedError(Static[RuntimeError]): +class NotImplementedError(RuntimeError): _pytype: ClassVar[cobj] = cobj() def 
__init__(self, message: str = ""): - super().__init__("NotImplementedError", message) + super().__init__(message) self.python_type = self.__class__._pytype -class StopIteration(Static[Exception]): +class StopIteration(Exception): _pytype: ClassVar[cobj] = cobj() def __init__(self, message: str = ""): - super().__init__("StopIteration", message) + super().__init__(message) self.python_type = self.__class__._pytype -class AssertionError(Static[Exception]): +class AssertionError(Exception): _pytype: ClassVar[cobj] = cobj() def __init__(self, message: str = ""): - super().__init__("AssertionError", message) + super().__init__(message) self.python_type = self.__class__._pytype -class EOFError(Static[Exception]): +class EOFError(Exception): _pytype: ClassVar[cobj] = cobj() def __init__(self, message: str = ""): - super().__init__("EOFError", message) + super().__init__(message) self.python_type = self.__class__._pytype -class SystemExit(Static[BaseException]): +class SystemExit(BaseException): _pytype: ClassVar[cobj] = cobj() _status: int def __init__(self, message: str = "", status: int = 0): - super().__init__("SystemExit", message) + super().__init__(message) self._status = status self.python_type = self.__class__._pytype @@ -173,6 +176,6 @@ class SystemExit(Static[BaseException]): def status(self): return self._status -class StaticCompileError(Static[Exception]): +class StaticCompileError(Exception): def __init__(self, message: str = ""): - super().__init__("StaticCompileError", message) + super().__init__(message) diff --git a/stdlib/internal/types/function.codon b/stdlib/internal/types/function.codon new file mode 100644 index 00000000..d1c62b4c --- /dev/null +++ b/stdlib/internal/types/function.codon @@ -0,0 +1,77 @@ +# Copyright (C) 2022-2025 Exaloop Inc. 
+ +import internal.static as static + +@extend +class Function: + @pure + @overload + @llvm + def __new__(what: Ptr[byte]) -> Function[T, TR]: + ret ptr %what + + @overload + def __new__(what: Function[T, TR]) -> Function[T, TR]: + return what + + @pure + @llvm + def __raw__(self) -> Ptr[byte]: + ret ptr %self + + def __repr__(self) -> str: + return Ptr._ptr_to_str(self.__raw__(), "function") + + @llvm + def __call_internal__(self: Function[T, TR], args: T) -> TR: + noop # compiler will populate this one + + def __call__(self, *args) -> TR: + return Function.__call_internal__(self, args) + + +@extend +class Callable: + def __new__(fn: Function[[Ptr[byte], T], TR], data: Ptr[byte]) -> Callable[T, TR]: + return type._force_value_cast((fn, data), Callable[T, TR]) + + @overload + def __new__(fn: Function[[Ptr[byte], T], TR], data: Partial[M,PT,K,F], + T: type, TR: type, + M: Literal[str], PT: type, F: type, K: type) -> Callable[T, TR]: + p = Ptr[Partial[M,PT,K,F]](1) + p[0] = data + return Callable(fn, p.as_byte()) + + @overload + def __new__(fn: Function[[Ptr[byte], T], TR], data: Function[T, TR]) -> Callable[T, TR]: + return Callable(fn, data.__raw__()) + + @overload + def __new__(fn: Function[T, TR]) -> Callable[T, TR]: + def _wrap(data: Ptr[byte], args, f: type): + return f(data)(*args) + return Callable( + static.function.realized(_wrap(f=Function[T, TR], ...), Ptr[byte], T), + fn.__raw__() + ) + + def __call__(self, *args): + return self.fn.__call__(self.data, args) + + +@extend +class Partial: + def __repr__(self): + return __magic__.repr_partial(self) + + def __call__(self, *args, **kwargs): + return self(*args, **kwargs) + + @property + def __fn_name__(self): + return F.__name__[16:-1] # chop off unrealized_type + + def __raw__(self): + # TODO: better error message + return F.T.__raw__() diff --git a/stdlib/internal/types/generator.codon b/stdlib/internal/types/generator.codon index 0bd049b3..39d98c9f 100644 --- a/stdlib/internal/types/generator.codon +++ 
b/stdlib/internal/types/generator.codon @@ -2,6 +2,14 @@ @extend class Generator: + def _yield_final(val): + """Compiler-generated.""" + pass + + def _yield_in_no_suspend(T: type) -> T: + """Compiler-generated.""" + pass + @__internal__ def __promise__(self) -> Ptr[T]: pass @@ -20,6 +28,7 @@ class Generator: return self @pure + @derives @llvm def __raw__(self) -> Ptr[byte]: ret ptr %self @@ -50,7 +59,7 @@ class Generator: ret {} {} def __repr__(self) -> str: - return __internal__.raw_type_str(self.__raw__(), "generator") + return Ptr._ptr_to_str(self.__raw__(), "generator") def send(self, what: T) -> T: p = self.__promise__() diff --git a/stdlib/internal/types/import_.codon b/stdlib/internal/types/import_.codon new file mode 100644 index 00000000..54805c26 --- /dev/null +++ b/stdlib/internal/types/import_.codon @@ -0,0 +1,10 @@ +# Copyright (C) 2022-2025 Exaloop Inc. + + +@extend +class Import: + def _set_loaded(i: Ptr[Import]): + Ptr[bool](i.as_byte())[0] = True + + def __repr__(self) -> str: + return f"" diff --git a/stdlib/internal/types/intn.codon b/stdlib/internal/types/intn.codon index 7a9599dc..7ad462fe 100644 --- a/stdlib/internal/types/intn.codon +++ b/stdlib/internal/types/intn.codon @@ -9,6 +9,24 @@ def _check_bitwidth(N: Literal[int]): @extend class Int: + @pure + @llvm + def _sext(what, F: Literal[int], T: Literal[int]) -> Int[T]: + %0 = sext i{=F} %what to i{=T} + ret i{=T} %0 + + @pure + @llvm + def _zext(what, F: Literal[int], T: Literal[int]) -> Int[T]: + %0 = zext i{=F} %what to i{=T} + ret i{=T} %0 + + @pure + @llvm + def _trunc(what, F: Literal[int], T: Literal[int]) -> Int[T]: + %0 = trunc i{=F} %what to i{=T} + ret i{=T} %0 + def __new__() -> Int[N]: _check_bitwidth(N) return Int[N](0) @@ -28,21 +46,21 @@ class Int: def __new__(what: Int[M], M: Literal[int]) -> Int[N]: _check_bitwidth(N) if N < M: - return __internal__.int_trunc(what, M, N) + return Int._trunc(what, M, N) elif N == M: return what else: - return __internal__.int_sext(what, M, 
N) + return Int._sext(what, M, N) @overload def __new__(what: UInt[M], M: Literal[int]) -> Int[N]: _check_bitwidth(N) if N < M: - return __internal__.int_trunc(what, M, N) + return Int._trunc(what, M, N) elif N == M: return Int[N](what) else: - return __internal__.int_sext(what, M, N) + return Int._sext(what, M, N) @overload def __new__(what: UInt[N]) -> Int[N]: @@ -58,11 +76,11 @@ class Int: def __new__(what: int) -> Int[N]: _check_bitwidth(N) if N < 64: - return __internal__.int_trunc(what, 64, N) + return Int._trunc(what, 64, N) elif N == 64: return what else: - return __internal__.int_sext(what, 64, N) + return Int._sext(what, 64, N) @overload def __new__(what: str) -> Int[N]: @@ -82,11 +100,11 @@ class Int: def __int__(self) -> int: if N > 64: - return __internal__.int_trunc(self, N, 64) + return Int._trunc(self, N, 64) elif N == 64: return self else: - return __internal__.int_sext(self, N, 64) + return Int._sext(self, N, 64) def __index__(self) -> int: return int(self) @@ -332,21 +350,21 @@ class UInt: def __new__(what: UInt[M], M: Literal[int]) -> UInt[N]: _check_bitwidth(N) if N < M: - return UInt[N](__internal__.int_trunc(what, M, N)) + return UInt[N](Int._trunc(what, M, N)) elif N == M: return what else: - return UInt[N](__internal__.int_zext(what, M, N)) + return UInt[N](Int._zext(what, M, N)) @overload def __new__(what: Int[M], M: Literal[int]) -> UInt[N]: _check_bitwidth(N) if N < M: - return UInt[N](__internal__.int_trunc(what, M, N)) + return UInt[N](Int._trunc(what, M, N)) elif N == M: return UInt[N](what) else: - return UInt[N](__internal__.int_sext(what, M, N)) + return UInt[N](Int._sext(what, M, N)) @overload def __new__(what: Int[N]) -> UInt[N]: @@ -371,11 +389,11 @@ class UInt: _check_bitwidth(N) if N < 64: - return UInt[N](__internal__.int_trunc(what, 64, N)) + return UInt[N](Int._trunc(what, 64, N)) elif N == 64: return convert(what) else: - return UInt[N](__internal__.int_sext(what, 64, N)) + return UInt[N](Int._sext(what, 64, N)) def 
__new__(what: str) -> UInt[N]: _check_bitwidth(N) @@ -383,11 +401,11 @@ class UInt: def __int__(self) -> int: if N > 64: - return __internal__.int_trunc(self, N, 64) + return Int._trunc(self, N, 64) elif N == 64: return Int[64](self) else: - return __internal__.int_zext(self, N, 64) + return Int._zext(self, N, 64) def __index__(self) -> int: return int(self) diff --git a/stdlib/internal/types/optional.codon b/stdlib/internal/types/optional.codon index 28d59470..8abdb6b7 100644 --- a/stdlib/internal/types/optional.codon +++ b/stdlib/internal/types/optional.codon @@ -1,47 +1,81 @@ # Copyright (C) 2022-2025 Exaloop Inc. @extend -class __internal__: +class Optional: @pure @llvm - def opt_tuple_new(T: type) -> Optional[T]: + def _tuple_new(T: type) -> Optional[T]: ret { i1, {=T} } { i1 false, {=T} undef } + @pure + @llvm + def _ref_new(T: type) -> Optional[T]: + ret ptr null -@extend -class Optional: def __new__() -> Optional[T]: if isinstance(T, ByVal): - return __internal__.opt_tuple_new(T) - elif static.has_rtti(T): - return __internal__.opt_ref_new_rtti(T) + return Optional._tuple_new(T) else: - return __internal__.opt_ref_new(T) + return Optional._ref_new(T) + + @pure + @derives + @llvm + def _tuple_new_arg(what: T, T: type) -> Optional[T]: + %0 = insertvalue { i1, {=T} } { i1 true, {=T} undef }, {=T} %what, 1 + ret { i1, {=T} } %0 + + @pure + @derives + @llvm + def _ref_new_arg(what: T, T: type) -> Optional[T]: + ret ptr %what @overload def __new__(what: T) -> Optional[T]: if isinstance(T, ByVal): - return __internal__.opt_tuple_new_arg(what, T) - elif static.has_rtti(T): - return __internal__.opt_ref_new_arg_rtti(what, T) + return Optional._tuple_new_arg(what, T) else: - return __internal__.opt_ref_new_arg(what, T) + return Optional._ref_new_arg(what, T) + + @pure + @llvm + def _tuple_bool(what: Optional[T], T: type) -> bool: + %0 = extractvalue { i1, {=T} } %what, 0 + %1 = zext i1 %0 to i8 + ret i8 %1 + + @pure + @llvm + def _ref_bool(what: Optional[T], T: type) 
-> bool: + %0 = icmp ne ptr %what, null + %1 = zext i1 %0 to i8 + ret i8 %1 def __has__(self) -> bool: if isinstance(T, ByVal): - return __internal__.opt_tuple_bool(self, T) - elif static.has_rtti(T): - return __internal__.opt_ref_bool_rtti(self, T) + return Optional._tuple_bool(self, T) else: - return __internal__.opt_ref_bool(self, T) + return Optional._ref_bool(self, T) + + @pure + @derives + @llvm + def _tuple_invert(what: Optional[T], T: type) -> T: + %0 = extractvalue { i1, {=T} } %what, 1 + ret {=T} %0 + + @pure + @derives + @llvm + def _ref_invert(what: Optional[T], T: type) -> T: + ret ptr %what def __val__(self) -> T: if isinstance(T, ByVal): - return __internal__.opt_tuple_invert(self, T) - elif static.has_rtti(T): - return __internal__.opt_ref_invert_rtti(self, T) + return Optional._tuple_invert(self, T) else: - return __internal__.opt_ref_invert(self, T) + return Optional._ref_invert(self, T) def __val_or__(self, default: T): if self.__has__(): diff --git a/stdlib/internal/types/ptr.codon b/stdlib/internal/types/ptr.codon index 280935f6..c6b5dd4f 100644 --- a/stdlib/internal/types/ptr.codon +++ b/stdlib/internal/types/ptr.codon @@ -190,6 +190,30 @@ class Ptr: def __repr__(self) -> str: return self.__format__("") + def _ptr_to_str(p: Ptr[byte], name: str) -> str: + pstr = p.__repr__() + # '<[name] at [pstr]>' + total = 1 + name.len + 4 + pstr.len + 1 + buf = Ptr[byte](total) + where = 0 + buf[where] = byte(60) # '<' + where += 1 + str.memcpy(buf + where, name.ptr, name.len) + where += name.len + buf[where] = byte(32) # ' ' + where += 1 + buf[where] = byte(97) # 'a' + where += 1 + buf[where] = byte(116) # 't' + where += 1 + buf[where] = byte(32) # ' ' + where += 1 + str.memcpy(buf + where, pstr.ptr, pstr.len) + where += pstr.len + buf[where] = byte(62) # '>' + free(pstr.ptr) + return str(buf, total) + ptr = Ptr Jar = Ptr[byte] diff --git a/stdlib/internal/types/rtti.codon b/stdlib/internal/types/rtti.codon new file mode 100644 index 00000000..acc279e0 
--- /dev/null +++ b/stdlib/internal/types/rtti.codon @@ -0,0 +1,132 @@ +# Copyright (C) 2022-2025 Exaloop Inc. + +import internal.static as static + + +__vtables__ = Ptr[Ptr[cobj]]() +__vtable_size__ = 0 + + +@extend +class RTTIType: + def _new(data: Ptr[byte], typeinfo: Ptr[byte], T: type) -> T: + """ + Creates a new RTTIType wrapper for data and casts it to T. + + Internal use only. No type checks are performed. + """ + + p = type._ref_new(RTTIType) + p.data = data + p.typeinfo = typeinfo + return type._force_cast(p, T) + + @inline + def _dist(B: type, D: type) -> int: + """Calculates the byte distance of base class B and derived class D. Compiler generated.""" + return 0 + + @inline + def _to_derived(b: B, B: type, D: type) -> D: + if not (static.has_rtti(D) and static.has_rtti(B)): + compile_error("classes are not polymorphic") + off = RTTIType._dist(B, D) + rtti = type._force_cast(b, RTTIType) + return RTTIType._new(rtti.data - off, rtti.typeinfo, D) + + + ### vTable & thunk setup. Mostly compiler-generated. + + def _init_vtables(): + """ + Create a global vtable. + """ + + from internal.gc import alloc_atomic_uncollectable, sizeof + + global __vtables__ + sz = __vtable_size__ + 1 + p = alloc_atomic_uncollectable(sz * sizeof(Ptr[Ptr[byte]])) + __vtables__ = Ptr[Ptr[Ptr[byte]]](p) + RTTIType._populate_vtables() + + def _populate_vtables(): + """ + Populate content of vtables. Compiler generated. 
+ Corresponds to: + for each realized class C: + _init_vtable(, + 1, T=C) + for each fn F in C's vtable: + _set_vtable_fn( + , , Function().__raw__(), T=C + ) + """ + pass + + def _init_vtable(sz: int, T: type): + from internal.gc import alloc_uncollectable, sizeof + + if not static.has_rtti(T): + compile_error("class is not polymorphic") + p = alloc_uncollectable((sz + 1) * sizeof(Ptr[byte])) # scanned (not atomic): slot 0 holds a pointer to a GC-allocated TypeInfo + id = T.__id__ + __vtables__[id] = Ptr[Ptr[byte]](p) + # Set typeinfo + p = TypeInfo(T) + __vtables__[id][0] = p.__raw__().as_byte() + + def _set_vtable_fn(id: int, fid: int, f: Ptr[byte], T: type): + if not static.has_rtti(T): + compile_error("class is not polymorphic") + __vtables__[id][fid] = f + + def _get_thunk_id(F: type, T: type) -> int: + """Compiler-generated""" + return 0 + + def _thunk_debug(base, func, sig, *args): + # print("![thunk]!", base, func, sig, args[0].__raw__()) + pass + + @no_argument_wrap + def _thunk_dispatch(slf, cls_id, *args, F: type): + if not static.has_rtti(type(slf)): + compile_error("class is not polymorphic") + + FR = type(static.function.realized(F, slf, *args)) + T = type(slf) + thunk_id = RTTIType._get_thunk_id(FR, T) + + # Get RTTI table + if cls_id == 0: + cls_id = type._force_cast(type._force_cast(slf, RTTIType).typeinfo, TypeInfo).id + fptr = __vtables__[cls_id][thunk_id] + f = FR(fptr) + return f(slf, *args) + + +@extend +class Super: + def __repr__(self): + return f'' + + def _super(obj, B: type, use_super_type: Literal[int] = 0): + D = type(obj) + if not static.has_rtti(D): # static inheritance + return type._force_cast(obj, B) + else: + if not static.has_rtti(B): + compile_error("classes are not polymorphic") + off = RTTIType._dist(B, D) + rtti = type._force_cast(obj, RTTIType) + res = RTTIType._new(rtti.data + off, rtti.typeinfo, B) + if use_super_type: + # This is explicit super() + return type._force_value_cast((res, ), Super[B]) + else: + # Implicit super() just used for casting + return res + + def _unwrap(obj:
Super[B], B: type) -> B: + rtti = type._force_cast(obj.__obj__, RTTIType) + return RTTIType._new(rtti.data, TypeInfo(B).__raw__(), B) diff --git a/stdlib/internal/types/slice.codon b/stdlib/internal/types/slice.codon index 284e6437..dd2cba05 100644 --- a/stdlib/internal/types/slice.codon +++ b/stdlib/internal/types/slice.codon @@ -28,7 +28,7 @@ class Slice: T: type = int, U: type = int, V: type = int) -> Slice[T, U, V]: - return __internal__.tuple_cast_unsafe((start, stop, step), Slice[T, U, V]) + return type._force_value_cast((start, stop, step), Slice[T, U, V]) def adjust_indices(self, length: int) -> Tuple[int, int, int, int]: if not (T is int or T is None) or not (U is int or U is None) or not (V is int or V is None): diff --git a/stdlib/internal/types/str.codon b/stdlib/internal/types/str.codon index 370c413b..338ef1e3 100644 --- a/stdlib/internal/types/str.codon +++ b/stdlib/internal/types/str.codon @@ -24,7 +24,7 @@ class str: @overload def __new__(what) -> str: if isinstance(what, Union): - return __internal__.union_str(what) + return Union._str(what) elif isinstance(what, type): return what.__repr__() elif hasattr(what, "__str__"): diff --git a/stdlib/internal/types/tuple.codon b/stdlib/internal/types/tuple.codon new file mode 100644 index 00000000..7d9b37ca --- /dev/null +++ b/stdlib/internal/types/tuple.codon @@ -0,0 +1,117 @@ +# Copyright (C) 2022-2025 Exaloop Inc. 
+ +import internal.static as static + +@extend +class Tuple: + def _fix_index(idx: int, len: int) -> int: + if idx < 0: + idx += len + if idx < 0 or idx >= len: + raise IndexError("tuple index out of range") + return idx + + def _getitem(t: T, idx: int, T: type, E: type) -> E: + @pure + @derives + @llvm + def llvm_helper(t: T, idx: int, T: type, E: type) -> E: + %x = alloca {=T} + store {=T} %t, ptr %x + %p = getelementptr {=E}, ptr %x, i64 %idx + %v = load {=E}, ptr %p + ret {=E} %v + + return llvm_helper(t, Tuple._fix_index(idx, static.len(t)), T, E) + + def _offsetof(x, field: Literal[int]) -> int: + @pure + @llvm + def llvm_helper(T: type, idx: Literal[int], TE: type) -> int: + %a = alloca {=T} + %b = getelementptr inbounds {=T}, ptr %a, i64 0, i32 {=idx} + %base = ptrtoint ptr %a to i64 + %elem = ptrtoint ptr %b to i64 + %offset = sub i64 %elem, %base + ret i64 %offset + + return llvm_helper(type(x), field, type(x[field])) + + def _str(strs: Ptr[str], names: Ptr[str], n: int) -> str: + # special case of 1-element plain tuple: format as "(x,)" + if n == 1 and names[0].len == 0: + total = strs[0].len + 3 + buf = Ptr[byte](total) + buf[0] = byte(40) # '(' + str.memcpy(buf + 1, strs[0].ptr, strs[0].len) + buf[total - 2] = byte(44) # ',' + buf[total - 1] = byte(41) # ')' + return str(buf, total) + + total = 2 # one for each of '(' and ')' + i = 0 + while i < n: + total += strs[i].len + if names[i].len: + total += names[i].len + 2 # extra : and space + if i < n - 1: + total += 2 # ", " + i += 1 + buf = Ptr[byte](total) + where = 0 + buf[where] = byte(40) # '(' + where += 1 + i = 0 + while i < n: + s = names[i] + l = s.len + if l: + str.memcpy(buf + where, s.ptr, l) + where += l + buf[where] = byte(58) # ':' + where += 1 + buf[where] = byte(32) # ' ' + where += 1 + s = strs[i] + l = s.len + str.memcpy(buf + where, s.ptr, l) + where += l + if i < n - 1: + buf[where] = byte(44) # ',' + where += 1 + buf[where] = byte(32) # ' ' + where += 1 + i += 1 + buf[where] = 
byte(41) # ')' + return str(buf, total) + + +@extend +class NamedTuple: + def __getitem__(self, key: Literal[str]): + return getattr(self, key) + + def __contains__(self, key: Literal[str]): + return hasattr(self, key) + + def _get(kw, key: Literal[str], default): + if hasattr(kw, key): + return getattr(kw, key) + else: + return default + + def get(self, key: Literal[str], default = NoneType()): + return NamedTuple._get(self, key, default) + + def _namedkeys(N: Literal[int]): + # Compiler generated + pass + + def __keys__(self): + return NamedTuple._namedkeys(N) + + def __repr__(self): + keys = self.__keys__() + values = [v.__repr__() for v in self.args] + s = ', '.join(f"{keys[i]}: {values[i]}" for i in range(len(keys))) + return f"({s})" diff --git a/stdlib/internal/types/type.codon b/stdlib/internal/types/type.codon new file mode 100644 index 00000000..4d93e87a --- /dev/null +++ b/stdlib/internal/types/type.codon @@ -0,0 +1,186 @@ +# Copyright (C) 2022-2025 Exaloop Inc. + +import internal.static as static + + +def __type_repr__(T: type): + return f"" + + +@extend +class type: + @pure + @derives + @llvm + def _force_cast(p, T: type) -> T: + """ + Casts p of any reference type to the reference type T. + + This method is intended for the internal typechecking usage and is completely unsafe. + No checks are performed; T is assumed to be a reference type. + Any violations will result in LLVM errors. + """ + ret ptr %p + + @pure + @derives + @llvm + def _force_value_cast(t, U: type) -> U: + """ + Casts t of any type to U. Used for casting tuples into named tuple types. + + This method is intended for the internal typechecking usage and is completely unsafe. + No checks are performed; t is assumed to be byte-compatible with U. + Any violations will result in LLVM errors. + """ + ret {=U} %t + + @pure + @derives + @llvm + def _ref_raw(obj) -> Ptr[byte]: + """ + Casts the reference type to a pointer.
+ """ + + ret ptr %obj + + def _ref_new(T: type) -> T: + """ + Allocates a new reference (class) object. + """ + + from internal.gc import alloc, alloc_atomic, sizeof, register_finalizer + + sz = sizeof(tuple(T)) + obj = alloc_atomic(sz) if T.__contents_atomic__ else alloc(sz) + register_finalizer(obj) + if static.has_rtti(T): + obj = RTTIType._new(obj, TypeInfo(T).__raw__(), Ptr[byte]) + return type._force_cast(obj, T) + + def _construct(T: type, *args, **kwargs) -> T: + """ + Shorthand for `t = T.__new__(); t.__init__(*args, **kwargs); t` + """ + + return T(*args, **kwargs) + + +@__internal__ +class TypeInfo: + id: int + _parent_ids: Ptr[int] + raw_name: str + nice_name: str + repr: Function[[Ptr[byte]], str] + + def __new__() -> TypeInfo: + return __magic__.new(TypeInfo) + + def __raw__(self) -> Ptr[byte]: + return __magic__.raw(self) + + def __init__(self, T: type): + if isinstance(T, TypeWrap): + self.__init__(T.T) + return + + self.id = T.__id__ + self.raw_name = T.__name__ + self.nice_name = f"{T}" + + self.repr = Function[[Ptr[byte]], str](cobj()) + # if hasattr(T, "__repr__") or hasattr(T, "__str__"): + # fn = static.function.realized(TypeInfo.wrap(T=T, ...), Ptr[byte]) + # if isinstance(fn.TR, str): + # self.repr = fn + + mro = T.__mro__ + num_mro: Literal[int] = static.len(mro) + + if num_mro > 0: + self._parent_ids = Ptr[int](num_mro + 1) + for i in static.range(num_mro): + self._parent_ids[i] = mro[i].T.__id__ + self._parent_ids[num_mro] = 0 + else: + self._parent_ids = Ptr[int]() + + def wrap(arg: Ptr[byte], T: type) -> str: + if isinstance(T, ByVal): + p = type._force_cast(arg, Ptr[T]) + return repr(p[0]) + else: + obj = type._force_cast(arg, T) + return repr(obj) + + def __str__(self): + return self.nice_name + + def __repr__(self): + return self.__str__() + + def __eq__(self, other: TypeInfo): + return self.id == other.id + + def __ne__(self, other: TypeInfo): + return self.id != other.id + + +@extend +class TypeWrap: + def __new__(T: type) -> 
TypeWrap[T]: + return type._force_value_cast((), TypeWrap[T]) + + def __call_no_self__(*args, **kwargs) -> T: + return T(*args, **kwargs) + + def __call__(self, *args, **kwargs) -> T: + return T(*args, **kwargs) + + def __repr__(self): + return __type_repr__(T) + + @property + def __name__(self): + return T.__name__ + + +class __cast__: + def cast(obj: T, T: type) -> Generator[T]: + return obj.__iter__() + + @overload + def cast(obj: int) -> float: + return float(obj) + + @overload + def cast(obj: T, T: type) -> Optional[T]: + return Optional[T](obj) + + @overload + def cast(obj: Optional[T], T: type) -> T: + return obj.unwrap() + + @overload + def cast(obj: T, T: type) -> pyobj: + return obj.__to_py__() + + @overload + def cast(obj: pyobj, T: type) -> T: + return T.__from_py__(obj) + + # Function[[T...], R] + # ExternFunction[[T...], R] + # CodonFunction[[T...], R] + # Partial[foo, [T...], R] + + # function into partial (if not Function) / fn(foo) -> fn(foo(...)) + # empty partial (!!) into Function[] + # union extract + # any into Union[] + # derived to base + + def conv_float(obj: float) -> int: + return int(obj) diff --git a/stdlib/internal/types/union.codon b/stdlib/internal/types/union.codon new file mode 100644 index 00000000..b341edd2 --- /dev/null +++ b/stdlib/internal/types/union.codon @@ -0,0 +1,108 @@ +# Copyright (C) 2022-2025 Exaloop Inc. 
+ +import internal.static as static + + +@extend +class Union: + def _tag(u, tag: Literal[int]): # compiler-generated + pass + + @llvm + def _set_tag(tag: byte, U: type) -> U: + %0 = insertvalue {=U} undef, i8 %tag, 0 + ret {=U} %0 + + @llvm + def _get_data_ptr(ptr: Ptr[U], U: type, T: type) -> Ptr[T]: + %0 = getelementptr inbounds {=U}, ptr %ptr, i64 0, i32 1 + ret ptr %0 + + @llvm + def _get_tag(u: U, U: type) -> byte: + %0 = extractvalue {=U} %u, 0 + ret i8 %0 + + def _get_data(u, T: type) -> T: + return Union._get_data_ptr(__ptr__(u), T=T)[0] + + def _make(tag: int, value, U: type) -> U: + u = Union._set_tag(byte(tag), U) + Union._get_data_ptr(__ptr__(u), T=type(value))[0] = value + return u + + def _new(value, U: type) -> U: + for tag, T in static.vars_types(U, with_index=True): + if isinstance(value, T): + return Union._make(tag, value, U) + if isinstance(value, Union[T]): + return Union._make(tag, Union._get(value, T), U) + # TODO: make this static! + raise TypeError("invalid union constructor") + + def _get(union, T: type) -> T: + for tag, TU in static.vars_types(union, with_index=True): + if isinstance(TU, T): + if Union._get_tag(union) == tag: + return Union._get_data(union, TU) + raise TypeError(f"invalid union getter for type '{T.__class__.__name__}'") + + def _member_helper(union, member: Literal[str]) -> Union: + for tag, T in static.vars_types(union, with_index=True): + if hasattr(T, member): + if Union._get_tag(union) == tag: + return getattr(Union._get_data(union, T), member) + raise TypeError(f"invalid union call '{member}'") + + def _member(union, member: Literal[str]): + t = Union._member_helper(union, member) + if static.len(t) == 1: + return Union._tag(t, 0) + else: + return t + + def _call_helper(union, args, kwargs) -> Union: + for tag, T in static.vars_types(union, with_index=True): + if static.function.can_call(T, *args, **kwargs): + if Union._get_tag(union) == tag: + return Union._get_data(union, T)(*args, **kwargs) + elif hasattr(T, 
'__call__'): + if static.function.can_call(T.__call__, *args, **kwargs): + if Union._get_tag(union) == tag: + return Union._get_data(union, T).__call__(*args, **kwargs) + raise TypeError("cannot call union " + union.__class__.__name__) + + def _call(union, args, kwargs): + t = Union._call_helper(union, args, kwargs) + if static.len(t) == 1: + return Union._tag(t, 0) + else: + return t + + def __call__(self, *args, **kwargs): + return Union._call(self, args, kwargs) + + def _str(union): + for tag, T in static.vars_types(union, with_index=True): + if hasattr(T, '__str__'): + if Union._get_tag(union) == tag: + return Union._get_data(union, T).__str__() + elif hasattr(T, '__repr__'): + if Union._get_tag(union) == tag: + return Union._get_data(union, T).__repr__() + return '' + + def _and(x, y): + if type(x) is type(y): + return y if x else x + else: + T = Union[type(x),type(y)] + return T(y) if x else T(x) + + def _or(x, y): + if type(x) is type(y): + return x if x else y + else: + T = Union[type(x),type(y)] + return T(x) if x else T(y) + diff --git a/stdlib/numpy/linalg/linalg.codon b/stdlib/numpy/linalg/linalg.codon index c751e120..218b5165 100644 --- a/stdlib/numpy/linalg/linalg.codon +++ b/stdlib/numpy/linalg/linalg.codon @@ -17,7 +17,7 @@ from ..util import cast, cdiv_int, coerce, eps, exp, free, inf, log, multirange, # Utilities # ############# -class LinAlgError(Static[Exception]): +class LinAlgError(Exception): def __init__(self, message: str = ''): super().__init__("numpy.linalg.LinAlgError", message) diff --git a/stdlib/numpy/ndarray.codon b/stdlib/numpy/ndarray.codon index 2ad27838..3b52f695 100644 --- a/stdlib/numpy/ndarray.codon +++ b/stdlib/numpy/ndarray.codon @@ -316,7 +316,7 @@ class ndarray[dtype, ndim: Literal[int]]: def __new__(shape: Tuple[ndim, int], strides: Tuple[ndim, int], data: Ptr[dtype]) -> ndarray[dtype, ndim]: - return __internal__.tuple_cast_unsafe((shape, strides, data), ndarray[dtype, ndim]) + return type._force_value_cast((shape, 
strides, data), ndarray[dtype, ndim]) def __new__(shape: Tuple[ndim, int], data: Ptr[dtype], fcontig: bool = False): strides = util.strides(shape, fcontig, dtype) diff --git a/stdlib/numpy/npio.codon b/stdlib/numpy/npio.codon index b4d21386..f9ae924b 100644 --- a/stdlib/numpy/npio.codon +++ b/stdlib/numpy/npio.codon @@ -2180,8 +2180,7 @@ def genfromtxt(fname, upper = False lower = False - BAD_CASE_SENSITIVE: Static[ - str] = "'case_sensitive' must be True, False, 'upper' or 'lower'" + BAD_CASE_SENSITIVE: Literal[str] = "'case_sensitive' must be True, False, 'upper' or 'lower'" if isinstance(case_sensitive, bool): if not case_sensitive: diff --git a/stdlib/numpy/util.codon b/stdlib/numpy/util.codon index 8cf57b75..ac8c83b7 100644 --- a/stdlib/numpy/util.codon +++ b/stdlib/numpy/util.codon @@ -9,13 +9,13 @@ from .npdatetime import datetime64, timedelta64, _promote as dt_promote, \ # Exceptions # ############## -class AxisError(Static[Exception]): +class AxisError(Exception): def __init__(self, message: str = ''): - super().__init__("numpy.AxisError", message) + super().__init__(message) -class TooHardError(Static[Exception]): +class TooHardError(Exception): def __init__(self, message: str = ''): - super().__init__("numpy.TooHardError", message) + super().__init__(message) ############## diff --git a/stdlib/re.codon b/stdlib/re.codon index 32803aaa..b6119705 100644 --- a/stdlib/re.codon +++ b/stdlib/re.codon @@ -82,11 +82,11 @@ def seq_re_purge() -> None: def seq_re_compile(pattern: str, flags: int) -> cobj: pass -class error(Static[Exception]): +class error(Exception): pattern: str def __init__(self, message: str = "", pattern: str = ""): - super().__init__("re.error", message) + super().__init__(message) self.pattern = pattern @property diff --git a/stdlib/statistics.codon b/stdlib/statistics.codon index 23ea4c7b..8d8bf1bf 100644 --- a/stdlib/statistics.codon +++ b/stdlib/statistics.codon @@ -16,9 +16,9 @@ from math import ( hypot as _hypot, ) -class 
StatisticsError(Static[Exception]): +class StatisticsError(Exception): def __init__(self, message: str = ""): - super().__init__("StatisticsError", message) + super().__init__(message) def median(data: List[T], T: type) -> float: """ diff --git a/test/core/exceptions.codon b/test/core/exceptions.codon index bd9c6160..690151b0 100644 --- a/test/core/exceptions.codon +++ b/test/core/exceptions.codon @@ -1,36 +1,36 @@ -class Exc1(Static[Exception]): +class Exc1(Exception): def __init__(self, msg: str): - super().__init__('Exc1', msg) + super().__init__(msg) def show(self): print self.message -class Exc2(Static[Exception]): +class Exc2(Exception): def __init__(self, msg: str): - super().__init__('Exc2', msg) + super().__init__(msg) def show(self): print self.message -class A(Static[Exception]): +class A(Exception): def __init__(self, msg: str): - super().__init__('A', msg) + super().__init__(msg) -class B(Static[Exception]): +class B(Exception): def __init__(self, msg: str): - super().__init__('B', msg) + super().__init__(msg) -class C(Static[Exception]): +class C(Exception): def __init__(self, msg: str): - super().__init__('C', msg) + super().__init__(msg) -class D(Static[Exception]): +class D(Exception): def __init__(self, msg: str): - super().__init__('D', msg) + super().__init__(msg) -class E(Static[Exception]): +class E(Exception): def __init__(self, msg: str): - super().__init__('E', msg) + super().__init__(msg) def foo1(x): if x: diff --git a/test/parser/typecheck/test_call.codon b/test/parser/typecheck/test_call.codon index e03a42b1..953380bc 100644 --- a/test/parser/typecheck/test_call.codon +++ b/test/parser/typecheck/test_call.codon @@ -602,6 +602,32 @@ Base.test(a) #: a.test #: base.test +class A: + n: int + def __init__(self, n: int): + print(f"A init, set {n}") + self.n = n + +class B(A): + def __init__(self, n: int): + print(f"B init, set {n}") + super().__init__(n) + +class C(B): + def __init__(self, n: int): + print(f"C init, set {n}") + 
super().__init__(n) + +print(B(42).n) +#: B init, set 42 +#: A init, set 42 +#: 42 +print(C(42).n) +#: C init, set 42 +#: B init, set 42 +#: A init, set 42 +#: 42 + #%% super_tuple,barebones @tuple class A[T]: diff --git a/test/parser/typecheck/test_class.codon b/test/parser/typecheck/test_class.codon index 5da7a80a..894839b0 100644 --- a/test/parser/typecheck/test_class.codon +++ b/test/parser/typecheck/test_class.codon @@ -569,3 +569,37 @@ class Tuple: pass ("abc",).inc_ref() # #683 #! 'Tuple' cannot be extended + + +#%% class_var_poly,barebones +class A: + apple = {"a": 1, "b": 2, "c": 3} + pear = 3 +class B(A): + pass +class C(B): + pass + +print(A.apple) +#: {'a': 1, 'b': 2, 'c': 3} +print(A.pear) +#: 3 +print(B.apple) +#: {'a': 1, 'b': 2, 'c': 3} +print(B.pear) +#: 3 +print(C.apple) +#: {'a': 1, 'b': 2, 'c': 3} +print(C.pear) +#: 3 +A.pear += 1 +print(A.pear) +#: 4 +print(B.pear) +#: 4 +print(C.pear) +#: 4 +B.apple["d"] = 4 +print(C.apple) +#: {'a': 1, 'b': 2, 'c': 3, 'd': 4} + diff --git a/test/parser/typecheck/test_error.codon b/test/parser/typecheck/test_error.codon index 90fb18c3..9a6401af 100644 --- a/test/parser/typecheck/test_error.codon +++ b/test/parser/typecheck/test_error.codon @@ -13,9 +13,9 @@ except AssertionError as e: print e.message[:23], e.message[-20:] #: Assert failed: hehe 1 ( test_error.codon:11) #%% try_throw,barebones -class MyError(Static[Exception]): +class MyError(Exception): def __init__(self, message: str): - super().__init__('MyError', message) + super().__init__(message) try: raise MyError("hello!") except MyError as e: diff --git a/test/parser/typecheck/test_infer.codon b/test/parser/typecheck/test_infer.codon index 19f4780f..5eb14241 100644 --- a/test/parser/typecheck/test_infer.codon +++ b/test/parser/typecheck/test_infer.codon @@ -679,7 +679,7 @@ def foo_int(x: int): def foo_str(x: str): print(f'{x} {x.__class__.__name__}') def foo(x): - print(f'{x} {int(__internal__.union_get_tag(x))} {x.__class__.__name__}') + print(f'{x} 
{int(Union._get_tag(x))} {x.__class__.__name__}') a: Union[int, str] = 5 foo_int(a) #: 5 int