how to use multiple async sessions with gather (answer: use TaskGroup, not gather) #9312
-
Hi, after migrating to 2.0 I'm getting an IllegalStateChangeError. What's interesting: the database is giving me correct responses and everything works fine, but I have these errors in the logs and my tests are failing. The code it complains about:

```python
engine = create_async_engine(settings.DB_URL, echo=settings.SQLA_ECHO)
async_session = async_sessionmaker(engine, expire_on_commit=False)


async def get_session() -> AsyncSession:
    async with async_session() as session:
        yield session
```
-
Hi, it seems your session is being shared by more than one task. Please ensure that you are not doing things like:

```python
asyncio.gather(func(session), func2(session))
```
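A runnable sketch of the fix (aiosqlite is used here only so the example runs anywhere; the pattern is the same for asyncpg): give each coroutine passed to gather() its own session.

```python
import asyncio

from sqlalchemy import text
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine

engine = create_async_engine("sqlite+aiosqlite://")
async_session = async_sessionmaker(engine)


async def func(session):
    return (await session.execute(text("select 1"))).scalar_one()


async def main():
    # one session per concurrent task: no shared mutable session state
    async with async_session() as s1, async_session() as s2:
        print(await asyncio.gather(func(s1), func(s2)))


asyncio.run(main())
```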
-
The change that makes these errors now get raised is described at https://docs.sqlalchemy.org/en/20/changelog/whatsnew_20.html#session-raises-proactively-when-illegal-concurrent-or-reentrant-access-is-detected
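A minimal sketch of the kind of access that now raises proactively (again using aiosqlite just to be runnable; depending on timing, the error may not fire on every run):

```python
import asyncio

from sqlalchemy import text
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine

engine = create_async_engine("sqlite+aiosqlite://")
async_session = async_sessionmaker(engine)


async def query(session):
    return (await session.execute(text("select 1"))).scalar_one()


async def main():
    async with async_session() as session:
        # the same session used by two concurrent tasks: SQLAlchemy 2.0
        # detects the concurrent access and raises IllegalStateChangeError
        await asyncio.gather(query(session), query(session))


asyncio.run(main())
```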
-
Is there a way to keep the performance of gather? I mean, can I avoid going from:

```python
asyncio.gather(dbCallOne(), dbCallTwo())
```

to:

```python
await dbCallOne()
await dbCallTwo()
```

Is there a way to call the two functions at the same time and stay clear of the IllegalStateChangeError?
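One way to keep gather()'s concurrency is to open a session inside each call rather than sharing one. A sketch reusing the question's dbCallOne/dbCallTwo names, assuming an async_sessionmaker named async_session and two hypothetical statements query_one/query_two:

```python
async def dbCallOne():
    # this task owns its session from open to close
    async with async_session() as session:
        return (await session.execute(query_one)).all()


async def dbCallTwo():
    async with async_session() as session:
        return (await session.execute(query_two)).all()


# still concurrent, but no session is shared between tasks
results = await asyncio.gather(dbCallOne(), dbCallTwo())
```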
-
To add to the discussion, there's an example on using asyncio.gather() with ORM statements in the SQLAlchemy asyncio examples (gather_orm_statements.py), which is where the run_out_of_band() helper quoted below comes from.
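The gist of that example, paraphrased as a sketch (the real run_out_of_band() in the SQLAlchemy examples does more, e.g. merging results back into the main session; the simplified body here is mine):

```python
async def run_out_of_band(async_sessionmaker, session, statement):
    # run the statement on its own throwaway session, so the main
    # "session" is never touched by more than one task at a time
    async with async_sessionmaker() as oob_session:
        result = await oob_session.execute(statement)
        return result.all()
```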
-
Thank you @stevanmilic! That's very interesting. I have one question related to the code:

```python
async with async_session() as session, session.begin():
    session.add_all([A(data="a_%d" % i) for i in range(100)])

    statements = [
        select(A).where(A.data == "a_%d" % random.choice(range(100)))
        for i in range(30)
    ]
    results = await asyncio.gather(
        *(
            run_out_of_band(async_session, session, statement)
            for statement in statements
        )
    )
```

When calling …
-
How can I use a different session for each task?
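A sketch using the async_session factory from the first post: create the session inside each task, so each task owns exactly one session (User and Organization stand in for your mapped classes):

```python
import asyncio

from sqlalchemy import select


async def run_query(statement):
    # the session is created, used, and closed entirely inside this task
    async with async_session() as session:
        return (await session.execute(statement)).all()


async def main():
    async with asyncio.TaskGroup() as tg:  # Python 3.11+
        t1 = tg.create_task(run_query(select(User)))
        t2 = tg.create_task(run_query(select(Organization)))
    users, orgs = t1.result(), t2.result()
```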
-
OK, first off, here is a complete demo of the pattern we have with the nested context managers. It illustrates the connect failing in one of them and reproduces the error:

```python
import asyncio
import random
from typing import cast

from sqlalchemy import Column
from sqlalchemy import create_engine
from sqlalchemy import ForeignKey
from sqlalchemy import func
from sqlalchemy import Integer
from sqlalchemy import select
from sqlalchemy import String
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.orm import declarative_base
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker

Base = declarative_base()


class Organization(Base):
    __tablename__ = "organization"

    org_id = Column(Integer, primary_key=True)
    name = Column(String)


class Transaction(Base):
    __tablename__ = "transaction"

    transaction_id = Column(Integer, primary_key=True)
    org_id = Column(ForeignKey("organization.org_id"))
    total_cost = Column(Integer)
    total_time = Column(Integer)


class User(Base):
    __tablename__ = "user"

    user_id = Column(Integer, primary_key=True)
    org_id = Column(ForeignKey("organization.org_id"))


engine_async = create_async_engine(
    "postgresql+asyncpg://scott:tiger@localhost/test"
)

# wrap the pool's connection creator so that connecting fails randomly
creator = engine_async.pool._creator


def my_creator():
    if random.randint(0, 1) == 1:
        raise Exception("boom")
    else:
        return creator()


engine_async.pool._creator = my_creator


def get_data(event, context):  # this is the lambda handler
    return asyncio.get_event_loop().run_until_complete(
        get_data_async(event, context)
    )


async def get_data_async(event, context):
    org_id = event["org_id"]

    async_session = cast(
        type[AsyncSession],
        sessionmaker(bind=engine_async, future=True, class_=AsyncSession),
    )

    async with async_session() as session1, async_session() as session2, async_session() as session3:
        users_count_query = select(func.count(User.user_id)).filter_by(
            org_id=org_id
        )
        org_name_query = select(Organization.name).filter_by(org_id=org_id)
        other_data_query = select(
            func.count(Transaction.transaction_id).label(
                "transactions_count"
            ),
            func.sum(Transaction.total_cost).label("total_cost"),
            func.sum(Transaction.total_time).label("total_time"),
        ).where(Transaction.org_id == org_id)

        # three sessions, three concurrent execute() calls
        tasks = [
            session1.execute(users_count_query),
            session2.execute(other_data_query),
            session3.execute(org_name_query),
        ]
        results = await asyncio.gather(*tasks)

        users_count, data, org_name = results
        data = data.fetchone()
        users_count = users_count.scalar_one()
        org_name = org_name.scalar_one()

        if random.randint(0, 1) == 1:
            raise Exception("request interrupted")

        rec = {
            "statusCode": 200,
            "users_count": users_count,
            "transactions_count": data.transactions_count,
            "total_cost": 0 if not data.total_cost else data.total_cost,
            "total_time": data.total_time,
            "org_name": org_name,
        }
        print(f"Returning rec: {rec}")
        return rec


def setup():
    e = create_engine("postgresql://scott:tiger@localhost/test")
    Base.metadata.drop_all(e)
    Base.metadata.create_all(e)
    with Session(e) as sess:
        org = Organization(org_id=1, name="organization one")
        sess.add(org)
        sess.add_all(
            [
                Transaction(
                    org_id=1,
                    total_cost=random.randint(0, 1000),
                    total_time=random.randint(0, 1000),
                )
                for i in range(100)
            ]
        )
        sess.add_all([User(org_id=1) for i in range(100)])
        sess.commit()
    e.dispose()


setup()

for i in range(10):
    get_data({"org_id": 1}, None)
```

that's a full test. Now let's demonstrate a completely minimal form of what is happening:
```python
import asyncio

from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.asyncio import create_async_engine

engine_async = create_async_engine(
    "postgresql+asyncpg://scott:tiger@localhost/test"
)


async def blow_up():
    raise Exception("boom")


async def get_data_async():
    async_session = AsyncSession
    session1 = async_session(engine_async)
    try:
        tasks = [
            session1.connection(),
            blow_up(),
        ]
        await asyncio.gather(*tasks)
    except Exception as err:
        # the connection() task is still mutating the session here
        print(
            f"We're in an exception, but the state of session1 is: "
            f"{session1.sync_session._transaction._state}"
        )
        await session1.__aexit__(type(err), err, None)


asyncio.run(get_data_async())
```

then let's look at asyncio.gather(). Python's documentation is explicit about what happens here: if return_exceptions is False (the default), the first raised exception is immediately propagated to the task that awaits on gather(), while the other awaitables in the sequence won't be cancelled and will continue to run.
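To see that behavior in isolation, here's a minimal sketch with no SQLAlchemy at all (slow_ok is a made-up stand-in for session1.connection()):

```python
import asyncio


async def slow_ok():
    await asyncio.sleep(0.1)
    print("slow_ok finished anyway")  # runs after gather() already raised


async def blow_up():
    raise Exception("boom")


async def main():
    try:
        await asyncio.gather(slow_ok(), blow_up())
    except Exception as err:
        print(f"gather raised: {err}")
    # keep the loop alive long enough to see the orphaned task complete
    await asyncio.sleep(0.2)


asyncio.run(main())
```

The slow_ok() task keeps running after gather() has already raised; apply that to a task that is midway through checking out a database connection and you get a session whose state is changing underneath your exception handler.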
We can see just what happens with our program above: before we get the error, the task for session1.connection() is still in progress, so the exception handler is looking at a session whose internal state is still changing. We are essentially creating a race between our error handling and a task that was never allowed to finish. We can guard against this problem by letting the tasks complete:

```python
async def blow_up():
    raise Exception("boom")


async def get_data_async():
    async_session = AsyncSession
    async with async_session(engine_async) as session1:
        tasks = [
            session1.connection(),
            blow_up(),
        ]
        # return_exceptions=True lets every task run to completion;
        # failures come back as values rather than interrupting gather()
        await asyncio.gather(*tasks, return_exceptions=True)


asyncio.run(get_data_async())
```

that way all the tasks are completed and not interrupted concurrently. Now, that's not very convenient, and we note in Python's documentation: "A more modern way to create and run tasks concurrently and wait for their completion is asyncio.TaskGroup." This seems to be exactly the kind of situation they had in mind.
let's try it out:

```python
import asyncio

from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.asyncio import create_async_engine

engine_async = create_async_engine(
    "postgresql+asyncpg://scott:tiger@localhost/test"
)


async def blow_up():
    raise Exception("boom")


async def get_data_async():
    async_session = AsyncSession
    async with async_session(engine_async) as session1:
        # TaskGroup (Python 3.11+) cancels sibling tasks on failure and
        # waits for all of them before re-raising as an ExceptionGroup
        async with asyncio.TaskGroup() as tg:
            task1 = tg.create_task(session1.connection())
            task2 = tg.create_task(blow_up())


asyncio.run(get_data_async())
```

now we get a nifty error that is fairly over the top (an ExceptionGroup traceback), but no issue in SQLAlchemy, because all tasks were allowed to finish.
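For completeness: TaskGroup surfaces the failures as an ExceptionGroup, which Python 3.11's except* syntax can unpack. A minimal sketch against the demo above (the handler body is just an illustration):

```python
try:
    asyncio.run(get_data_async())
except* Exception as eg:
    for exc in eg.exceptions:
        print(f"task failed: {exc}")
```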
let's try with the original test case:

```python
import asyncio
import random
from typing import cast

from sqlalchemy import Column
from sqlalchemy import create_engine
from sqlalchemy import ForeignKey
from sqlalchemy import func
from sqlalchemy import Integer
from sqlalchemy import select
from sqlalchemy import String
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.orm import declarative_base
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker

Base = declarative_base()


class Organization(Base):
    __tablename__ = "organization"

    org_id = Column(Integer, primary_key=True)
    name = Column(String)


class Transaction(Base):
    __tablename__ = "transaction"

    transaction_id = Column(Integer, primary_key=True)
    org_id = Column(ForeignKey("organization.org_id"))
    total_cost = Column(Integer)
    total_time = Column(Integer)


class User(Base):
    __tablename__ = "user"

    user_id = Column(Integer, primary_key=True)
    org_id = Column(ForeignKey("organization.org_id"))


engine_async = create_async_engine(
    "postgresql+asyncpg://scott:tiger@localhost/test"
)

creator = engine_async.pool._creator


def my_creator():
    if random.randint(0, 1) == 1:
        raise Exception("boom")
    else:
        return creator()


engine_async.pool._creator = my_creator


def get_data(event, context):  # this is the lambda handler
    return asyncio.get_event_loop().run_until_complete(
        get_data_async(event, context)
    )


async def get_data_async(event, context):
    org_id = event["org_id"]

    async_session = cast(
        type[AsyncSession],
        sessionmaker(bind=engine_async, future=True, class_=AsyncSession),
    )

    async with async_session() as session1, async_session() as session2, async_session() as session3:
        users_count_query = select(func.count(User.user_id)).filter_by(
            org_id=org_id
        )
        org_name_query = select(Organization.name).filter_by(org_id=org_id)
        other_data_query = select(
            func.count(Transaction.transaction_id).label(
                "transactions_count"
            ),
            func.sum(Transaction.total_cost).label("total_cost"),
            func.sum(Transaction.total_time).label("total_time"),
        ).where(Transaction.org_id == org_id)

        # the only change: TaskGroup instead of asyncio.gather()
        async with asyncio.TaskGroup() as tg:
            task1 = tg.create_task(session1.execute(users_count_query))
            task2 = tg.create_task(session2.execute(other_data_query))
            task3 = tg.create_task(session3.execute(org_name_query))

        users_count, data, org_name = (
            task1.result(),
            task2.result(),
            task3.result(),
        )

        data = data.fetchone()
        users_count = users_count.scalar_one()
        org_name = org_name.scalar_one()

        if random.randint(0, 1) == 1:
            raise Exception("request interrupted")

        rec = {
            "statusCode": 200,
            "users_count": users_count,
            "transactions_count": data.transactions_count,
            "total_cost": 0 if not data.total_cost else data.total_cost,
            "total_time": data.total_time,
            "org_name": org_name,
        }
        print(f"Returning rec: {rec}")
        return rec


def setup():
    e = create_engine("postgresql://scott:tiger@localhost/test")
    Base.metadata.drop_all(e)
    Base.metadata.create_all(e)
    with Session(e) as sess:
        org = Organization(org_id=1, name="organization one")
        sess.add(org)
        sess.add_all(
            [
                Transaction(
                    org_id=1,
                    total_cost=random.randint(0, 1000),
                    total_time=random.randint(0, 1000),
                )
                for i in range(100)
            ]
        )
        sess.add_all([User(org_id=1) for i in range(100)])
        sess.commit()
    e.dispose()


setup()

for i in range(10):
    get_data({"org_id": 1}, None)
```

and now we get an organized result.
so the moral of the story is: use TaskGroup, not gather(). Make sure every task is allowed to finish uninterrupted.
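As a reusable takeaway, a small helper (the name and shape are mine, not from this thread) that runs each statement on its own session inside a TaskGroup:

```python
import asyncio

from sqlalchemy.ext.asyncio import async_sessionmaker


async def gather_on_own_sessions(session_factory: async_sessionmaker, *statements):
    async def run_one(statement):
        # one session per task, owned start to finish by that task
        async with session_factory() as session:
            return (await session.execute(statement)).all()

    # if any task fails, TaskGroup cancels and *waits for* the siblings,
    # so no session is ever left mid-operation
    async with asyncio.TaskGroup() as tg:
        tasks = [tg.create_task(run_one(s)) for s in statements]
    return [t.result() for t in tasks]
```

Usage would look like `results = await gather_on_own_sessions(async_session, q1, q2, q3)`.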
-
Is this a bad solution to this? I'm finding out the hard way that Postgres is not kind to too many parallel updates, even with a pgbouncer, so I'm looking to reduce the number of sessions wherever possible.
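Not something covered above, but one common way to keep concurrency while capping how many sessions are open at once is an asyncio.Semaphore around the per-task session (a sketch; MAX_CONCURRENT and the async_session factory are assumptions):

```python
import asyncio

MAX_CONCURRENT = 5  # tune to what Postgres / pgbouncer will tolerate
limiter = asyncio.Semaphore(MAX_CONCURRENT)


async def run_update(statement):
    async with limiter:  # at most MAX_CONCURRENT sessions in flight
        async with async_session() as session, session.begin():
            await session.execute(statement)


async def main(statements):
    async with asyncio.TaskGroup() as tg:
        for s in statements:
            tg.create_task(run_update(s))
```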