Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Async for comprehension #5276

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 23 additions & 23 deletions Lib/test/test_asyncgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,16 +513,15 @@ def __anext__(self):
return self.yielded
self.check_async_iterator_anext(MyAsyncIterWithTypesCoro)

# TODO: RUSTPYTHON: async for gen expression compilation
# def test_async_gen_aiter(self):
# async def gen():
# yield 1
# yield 2
# g = gen()
# async def consume():
# return [i async for i in aiter(g)]
# res = self.loop.run_until_complete(consume())
# self.assertEqual(res, [1, 2])
def test_async_gen_aiter(self):
async def gen():
yield 1
yield 2
g = gen()
async def consume():
return [i async for i in aiter(g)]
res = self.loop.run_until_complete(consume())
self.assertEqual(res, [1, 2])

# TODO: RUSTPYTHON, NameError: name 'aiter' is not defined
@unittest.expectedFailure
Expand Down Expand Up @@ -1569,22 +1568,23 @@ async def main():
self.assertIn('unhandled exception during asyncio.run() shutdown',
message['message'])

# TODO: RUSTPYTHON: async for gen expression compilation
# def test_async_gen_expression_01(self):
# async def arange(n):
# for i in range(n):
# await asyncio.sleep(0.01)
# yield i
# TODO: RUSTPYTHON; TypeError: object async_generator can't be used in 'await' expression
@unittest.expectedFailure
def test_async_gen_expression_01(self):
async def arange(n):
for i in range(n):
await asyncio.sleep(0.01)
yield i

# def make_arange(n):
# # This syntax is legal starting with Python 3.7
# return (i * 2 async for i in arange(n))
def make_arange(n):
# This syntax is legal starting with Python 3.7
return (i * 2 async for i in arange(n))

# async def run():
# return [i async for i in make_arange(10)]
async def run():
return [i async for i in make_arange(10)]

# res = self.loop.run_until_complete(run())
# self.assertEqual(res, [i * 2 for i in range(10)])
res = self.loop.run_until_complete(run())
self.assertEqual(res, [i * 2 for i in range(10)])

# TODO: RUSTPYTHON: async for gen expression compilation
# def test_async_gen_expression_02(self):
Expand Down
74 changes: 38 additions & 36 deletions Lib/test/test_grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,44 +418,46 @@ def test_var_annot_simple_exec(self):
gns['__annotations__']

# TODO: RUSTPYTHON
# def test_var_annot_custom_maps(self):
# # tests with custom locals() and __annotations__
# ns = {'__annotations__': CNS()}
# exec('X: int; Z: str = "Z"; (w): complex = 1j', ns)
# self.assertEqual(ns['__annotations__']['x'], int)
# self.assertEqual(ns['__annotations__']['z'], str)
# with self.assertRaises(KeyError):
# ns['__annotations__']['w']
# nonloc_ns = {}
# class CNS2:
# def __init__(self):
# self._dct = {}
# def __setitem__(self, item, value):
# nonlocal nonloc_ns
# self._dct[item] = value
# nonloc_ns[item] = value
# def __getitem__(self, item):
# return self._dct[item]
# exec('x: int = 1', {}, CNS2())
# self.assertEqual(nonloc_ns['__annotations__']['x'], int)
@unittest.expectedFailure
def test_var_annot_custom_maps(self):
# tests with custom locals() and __annotations__
ns = {'__annotations__': CNS()}
exec('X: int; Z: str = "Z"; (w): complex = 1j', ns)
self.assertEqual(ns['__annotations__']['x'], int)
self.assertEqual(ns['__annotations__']['z'], str)
with self.assertRaises(KeyError):
ns['__annotations__']['w']
nonloc_ns = {}
class CNS2:
def __init__(self):
self._dct = {}
def __setitem__(self, item, value):
nonlocal nonloc_ns
self._dct[item] = value
nonloc_ns[item] = value
def __getitem__(self, item):
return self._dct[item]
exec('x: int = 1', {}, CNS2())
self.assertEqual(nonloc_ns['__annotations__']['x'], int)

# TODO: RUSTPYTHON
# def test_var_annot_refleak(self):
# # complex case: custom locals plus custom __annotations__
# # this was causing refleak
# cns = CNS()
# nonloc_ns = {'__annotations__': cns}
# class CNS2:
# def __init__(self):
# self._dct = {'__annotations__': cns}
# def __setitem__(self, item, value):
# nonlocal nonloc_ns
# self._dct[item] = value
# nonloc_ns[item] = value
# def __getitem__(self, item):
# return self._dct[item]
# exec('X: str', {}, CNS2())
# self.assertEqual(nonloc_ns['__annotations__']['x'], str)
@unittest.expectedFailure
def test_var_annot_refleak(self):
# complex case: custom locals plus custom __annotations__
# this was causing refleak
cns = CNS()
nonloc_ns = {'__annotations__': cns}
class CNS2:
def __init__(self):
self._dct = {'__annotations__': cns}
def __setitem__(self, item, value):
nonlocal nonloc_ns
self._dct[item] = value
nonloc_ns[item] = value
def __getitem__(self, item):
return self._dct[item]
exec('X: str', {}, CNS2())
self.assertEqual(nonloc_ns['__annotations__']['x'], str)


def test_var_annot_rhs(self):
Expand Down
78 changes: 55 additions & 23 deletions compiler/codegen/src/compile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2629,24 +2629,30 @@ impl Compiler {
compile_element: &dyn Fn(&mut Self) -> CompileResult<()>,
) -> CompileResult<()> {
let prev_ctx = self.ctx;
let is_async = generators.iter().any(|g| g.is_async);

self.ctx = CompileContext {
loop_data: None,
in_class: prev_ctx.in_class,
func: FunctionContext::Function,
func: if is_async {
FunctionContext::AsyncFunction
} else {
FunctionContext::Function
},
};

// We must have at least one generator:
assert!(!generators.is_empty());

let flags = bytecode::CodeFlags::NEW_LOCALS | bytecode::CodeFlags::IS_OPTIMIZED;
let flags = if is_async {
flags | bytecode::CodeFlags::IS_COROUTINE
} else {
flags
};

// Create magnificent function <listcomp>:
self.push_output(
bytecode::CodeFlags::NEW_LOCALS | bytecode::CodeFlags::IS_OPTIMIZED,
1,
1,
0,
name.to_owned(),
);
self.push_output(flags, 1, 1, 0, name.to_owned());
let arg0 = self.varname(".0")?;

let return_none = init_collection.is_none();
Expand All @@ -2657,13 +2663,11 @@ impl Compiler {

let mut loop_labels = vec![];
for generator in generators {
if generator.is_async {
unimplemented!("async for comprehensions");
}

let loop_block = self.new_block();
let after_block = self.new_block();

// emit!(self, Instruction::SetupLoop);

if loop_labels.is_empty() {
// Load iterator onto stack (passed as first argument):
emit!(self, Instruction::LoadFast(arg0));
Expand All @@ -2672,20 +2676,36 @@ impl Compiler {
self.compile_expression(&generator.iter)?;

// Get iterator / turn item into an iterator
emit!(self, Instruction::GetIter);
if generator.is_async {
emit!(self, Instruction::GetAIter);
} else {
emit!(self, Instruction::GetIter);
}
}

loop_labels.push((loop_block, after_block));

self.switch_to_block(loop_block);
emit!(
self,
Instruction::ForIter {
target: after_block,
}
);

self.compile_store(&generator.target)?;
if generator.is_async {
emit!(
self,
Instruction::SetupExcept {
handler: after_block,
}
);
emit!(self, Instruction::GetANext);
self.emit_constant(ConstantData::None);
emit!(self, Instruction::YieldFrom);
self.compile_store(&generator.target)?;
emit!(self, Instruction::PopBlock);
} else {
emit!(
self,
Instruction::ForIter {
target: after_block,
}
);
self.compile_store(&generator.target)?;
}

// Now evaluate the ifs:
for if_condition in &generator.ifs {
Expand All @@ -2701,6 +2721,9 @@ impl Compiler {

// End of for loop:
self.switch_to_block(after_block);
if is_async {
emit!(self, Instruction::EndAsyncFor);
}
}

if return_none {
Expand Down Expand Up @@ -2737,10 +2760,19 @@ impl Compiler {
self.compile_expression(&generators[0].iter)?;

// Get iterator / turn item into an iterator
emit!(self, Instruction::GetIter);
if is_async {
emit!(self, Instruction::GetAIter);
} else {
emit!(self, Instruction::GetIter);
};

// Call just created <listcomp> function:
emit!(self, Instruction::CallFunctionPositional { nargs: 1 });
if is_async {
emit!(self, Instruction::GetAwaitable);
self.emit_constant(ConstantData::None);
emit!(self, Instruction::YieldFrom);
}
Ok(())
}

Expand Down
32 changes: 31 additions & 1 deletion vm/src/frame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,10 @@ impl ExecutingFrame<'_> {
let mut arg_state = bytecode::OpArgState::default();
loop {
let idx = self.lasti() as usize;
// eprintln!(
// "location: {:?} {}",
// self.code.locations[idx], self.code.source_path
// );
self.update_lasti(|i| *i += 1);
let bytecode::CodeUnit { op, arg } = instrs[idx];
let arg = arg_state.extend(arg);
Expand Down Expand Up @@ -993,6 +997,9 @@ impl ExecutingFrame<'_> {
Ok(None)
}
bytecode::Instruction::GetANext => {
#[cfg(debug_assertions)] // remove when GetANext is fully implemented
let orig_stack_len = self.state.stack.len();

let aiter = self.top_value();
let awaitable = if aiter.class().is(vm.ctx.types.async_generator) {
vm.call_special_method(aiter, identifier!(vm, __anext__), ())?
Expand Down Expand Up @@ -1030,6 +1037,8 @@ impl ExecutingFrame<'_> {
})?
};
self.push_value(awaitable);
#[cfg(debug_assertions)]
debug_assert_eq!(orig_stack_len + 1, self.state.stack.len());
Ok(None)
}
bytecode::Instruction::EndAsyncFor => {
Expand Down Expand Up @@ -1238,6 +1247,7 @@ impl ExecutingFrame<'_> {
fn unwind_blocks(&mut self, vm: &VirtualMachine, reason: UnwindReason) -> FrameResult {
// First unwind all existing blocks on the block stack:
while let Some(block) = self.current_block() {
// eprintln!("unwinding block: {:.60?} {:.60?}", block.typ, reason);
match block.typ {
BlockType::Loop => match reason {
UnwindReason::Break { target } => {
Expand Down Expand Up @@ -1935,6 +1945,7 @@ impl ExecutingFrame<'_> {
}

fn push_block(&mut self, typ: BlockType) {
// eprintln!("block pushed: {:.60?} {}", typ, self.state.stack.len());
self.state.blocks.push(Block {
typ,
level: self.state.stack.len(),
Expand All @@ -1944,6 +1955,12 @@ impl ExecutingFrame<'_> {
#[track_caller]
fn pop_block(&mut self) -> Block {
let block = self.state.blocks.pop().expect("No more blocks to pop!");
// eprintln!(
// "block popped: {:.60?} {} -> {} ",
// block.typ,
// self.state.stack.len(),
// block.level
// );
#[cfg(debug_assertions)]
if self.state.stack.len() < block.level {
dbg!(&self);
Expand All @@ -1965,6 +1982,11 @@ impl ExecutingFrame<'_> {
#[inline]
#[track_caller] // not a real track_caller but push_value is not very useful
fn push_value(&mut self, obj: PyObjectRef) {
// eprintln!(
// "push_value {} / len: {} +1",
// obj.class().name(),
// self.state.stack.len()
// );
match self.state.stack.try_push(obj) {
Ok(()) => {}
Err(_e) => self.fatal("tried to push value onto stack but overflowed max_stackdepth"),
Expand All @@ -1975,7 +1997,14 @@ impl ExecutingFrame<'_> {
#[track_caller] // not a real track_caller but pop_value is not very useful
fn pop_value(&mut self) -> PyObjectRef {
match self.state.stack.pop() {
Some(x) => x,
Some(x) => {
// eprintln!(
// "pop_value {} / len: {}",
// x.class().name(),
// self.state.stack.len()
// );
x
}
None => self.fatal("tried to pop value but there was nothing on the stack"),
}
}
Expand All @@ -2002,6 +2031,7 @@ impl ExecutingFrame<'_> {
}

#[inline]
#[track_caller]
fn nth_value(&self, depth: u32) -> &PyObject {
let stack = &self.state.stack;
&stack[stack.len() - depth as usize - 1]
Expand Down