Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit e3961e2

Browse files
committed
WIP: perf
Signed-off-by: Edwin Yu <[email protected]>
1 parent 40f752b commit e3961e2

2 files changed

Lines changed: 413 additions & 253 deletions

File tree

packages/server/server_tests/memmachine_server/episodic_memory/extra_memory/segment_linker/test_sqlalchemy_segment_linker.py

Lines changed: 37 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -605,7 +605,7 @@ async def test_delete_by_episodes(
605605

606606

607607
@pytest.mark.asyncio
608-
async def test_delete_by_episodes_decrements_ref_count(
608+
async def test_delete_by_episodes_orphans_derivative(
609609
partition: SQLAlchemySegmentLinkerPartition,
610610
) -> None:
611611
ep = uuid4()
@@ -616,8 +616,8 @@ async def test_delete_by_episodes_decrements_ref_count(
616616

617617
await partition.delete_segments_by_episodes([ep])
618618

619-
orphans = list(await partition.get_orphaned_derivatives())
620-
assert deriv in orphans
619+
await partition.mark_orphaned_derivatives_for_purging()
620+
assert deriv in await partition.get_derivatives_pending_purge()
621621

622622

623623
@pytest.mark.asyncio
@@ -663,19 +663,16 @@ async def test_orphan_lifecycle(
663663
await partition.register_segments({seg: [deriv]})
664664

665665
# No orphans yet
666-
assert list(await partition.get_orphaned_derivatives()) == []
666+
await partition.mark_orphaned_derivatives_for_purging()
667+
assert deriv not in await partition.get_derivatives_pending_purge()
667668

668669
# Delete segment -> derivative becomes orphaned
669670
await partition.delete_segments_by_episodes([ep])
670-
orphans = list(await partition.get_orphaned_derivatives())
671-
assert deriv in orphans
671+
await partition.mark_orphaned_derivatives_for_purging()
672672

673-
# Mark for purging
674-
marked = list(await partition.mark_orphaned_derivatives_for_purging([deriv]))
675-
assert deriv in marked
676-
677-
# No longer shows as orphaned (state=P)
678-
assert list(await partition.get_orphaned_derivatives()) == []
673+
# Derivative is now pending purge (state=P)
674+
pending = await partition.get_derivatives_pending_purge()
675+
assert deriv in pending
679676

680677
# Purge
681678
await partition.purge_derivatives([deriv])
@@ -694,8 +691,8 @@ async def test_mark_orphaned_ignores_non_orphans(
694691
deriv = uuid4()
695692
await partition.register_segments({seg: [deriv]})
696693

697-
marked = list(await partition.mark_orphaned_derivatives_for_purging([deriv]))
698-
assert marked == []
694+
await partition.mark_orphaned_derivatives_for_purging()
695+
assert deriv not in await partition.get_derivatives_pending_purge()
699696

700697

701698
@pytest.mark.asyncio
@@ -892,32 +889,27 @@ async def test_concurrent_orphan_mark_does_not_crash(
892889
await partition.register_segments({seg: [deriv]})
893890
await partition.delete_segments_by_episodes([ep])
894891

895-
orphans = list(await partition.get_orphaned_derivatives())
896-
assert deriv in orphans
897-
898892
errors: list[Exception] = []
899893

900-
async def mark() -> list[UUID]:
894+
async def mark() -> None:
901895
try:
902896
part = _get_partition(engine)
903-
return list(await part.mark_orphaned_derivatives_for_purging([deriv]))
897+
await part.mark_orphaned_derivatives_for_purging()
904898
except Exception as e:
905899
errors.append(e)
906-
return []
907900

908-
results = await asyncio.gather(mark(), mark())
909-
all_marked = [uuid for result in results for uuid in result]
901+
await asyncio.gather(mark(), mark())
910902

911903
assert errors == []
912904
# At least one should have marked it.
913-
assert deriv in all_marked
905+
assert deriv in await partition.get_derivatives_pending_purge()
914906

915907

916908
@pytest.mark.asyncio
917909
async def test_concurrent_delete_overlapping_episodes_shared_derivative(
918910
linker: SQLAlchemySegmentLinker,
919911
) -> None:
920-
"""Deleting two episodes that share a derivative concurrently should correctly decrement ref_count."""
912+
"""Deleting two episodes that share a derivative concurrently should orphan the derivative."""
921913
import asyncio
922914

923915
engine = linker._engine
@@ -929,16 +921,16 @@ async def test_concurrent_delete_overlapping_episodes_shared_derivative(
929921
deriv = uuid4()
930922
await partition.register_segments({seg1: [deriv], seg2: [deriv]})
931923

932-
# Both segments link to deriv (ref_count=2). Delete both episodes concurrently.
924+
# Both segments link to deriv. Delete both episodes concurrently.
933925
async def delete_ep(ep: UUID) -> None:
934926
part = _get_partition(engine)
935927
await part.delete_segments_by_episodes([ep])
936928

937929
await asyncio.gather(delete_ep(ep1), delete_ep(ep2))
938930

939-
# Derivative should be orphaned (ref_count=0).
940-
orphans = list(await partition.get_orphaned_derivatives())
941-
assert deriv in orphans
931+
# Derivative should be orphaned (owner_segment_uuid IS NULL).
932+
await partition.mark_orphaned_derivatives_for_purging()
933+
assert deriv in await partition.get_derivatives_pending_purge()
942934

943935

944936
# --- PostgreSQL-only concurrency tests ---
@@ -949,7 +941,7 @@ async def delete_ep(ep: UUID) -> None:
949941
async def test_pg_concurrent_orphan_mark_exactly_once(
950942
pg_linker: SQLAlchemySegmentLinker,
951943
) -> None:
952-
"""On PG, FOR UPDATE ensures only one of two concurrent markers wins."""
944+
"""On PG, FOR UPDATE SKIP LOCKED ensures only one of two concurrent markers wins."""
953945
import asyncio
954946

955947
engine = pg_linker._engine
@@ -961,18 +953,14 @@ async def test_pg_concurrent_orphan_mark_exactly_once(
961953
await partition.register_segments({seg: [deriv]})
962954
await partition.delete_segments_by_episodes([ep])
963955

964-
orphans = list(await partition.get_orphaned_derivatives())
965-
assert deriv in orphans
966-
967-
async def mark() -> list[UUID]:
956+
async def mark() -> None:
968957
part = _get_partition(engine)
969-
return list(await part.mark_orphaned_derivatives_for_purging([deriv]))
958+
await part.mark_orphaned_derivatives_for_purging()
970959

971-
results = await asyncio.gather(mark(), mark())
972-
all_marked = [uuid for result in results for uuid in result]
960+
await asyncio.gather(mark(), mark())
973961

974-
# Exactly one should have marked it due to FOR UPDATE serialization.
975-
assert all_marked.count(deriv) == 1
962+
# Exactly one should have marked it due to SKIP LOCKED serialization.
963+
assert deriv in await partition.get_derivatives_pending_purge()
976964

977965

978966
@pytest.mark.integration
@@ -1070,34 +1058,30 @@ async def test_pg_orphan_relink_race(
10701058

10711059
# Orphan it.
10721060
await partition.delete_segments_by_episodes([ep])
1073-
orphans = list(await partition.get_orphaned_derivatives())
1074-
assert deriv in orphans
10751061

10761062
# Now, concurrently re-link and try to mark for purging.
1077-
relinked = asyncio.Event()
1078-
10791063
async def relinker() -> None:
10801064
part = _get_partition(engine)
10811065
new_seg = _seg(ts_offset_seconds=10)
10821066
await part.register_segments({new_seg: [deriv]}, active=[deriv])
1083-
relinked.set()
10841067

1085-
async def marker() -> list[UUID]:
1068+
async def marker() -> None:
10861069
# Wait a tiny bit to let relinker likely win, but it's a race so either outcome is valid.
10871070
await asyncio.sleep(0.01)
10881071
part = _get_partition(engine)
1089-
return list(await part.mark_orphaned_derivatives_for_purging([deriv]))
1072+
await part.mark_orphaned_derivatives_for_purging()
10901073

1091-
_, marked = await asyncio.gather(relinker(), marker())
1074+
await asyncio.gather(relinker(), marker())
10921075

1093-
# If relinker won the race, deriv has ref_count > 0 and should NOT be marked.
1076+
# If relinker won the race, deriv has owner != NULL and should NOT be in pending.
10941077
# If marker won the race, deriv was still orphaned and gets marked.
10951078
# Either way, no crash.
1096-
if not marked:
1079+
pending = await partition.get_derivatives_pending_purge()
1080+
if deriv not in pending:
10971081
# Relinker won — derivative should still be retrievable.
10981082
result = await partition.get_segments_by_derivatives([deriv])
10991083
assert deriv in result
1100-
# If marked, the derivative was legitimately marked before relinking.
1084+
# If in pending, the derivative was legitimately marked before relinking.
11011085

11021086

11031087
@pytest.mark.integration
@@ -1134,8 +1118,8 @@ async def delete_ep(ep: UUID) -> None:
11341118
result = await partition.get_segments_by_derivatives([deriv])
11351119
assert deriv not in result
11361120

1137-
orphans = list(await partition.get_orphaned_derivatives())
1138-
assert deriv in orphans
1121+
await partition.mark_orphaned_derivatives_for_purging()
1122+
assert deriv in await partition.get_derivatives_pending_purge()
11391123

11401124

11411125
@pytest.mark.integration
@@ -1155,8 +1139,8 @@ async def test_pg_concurrent_purge_and_register_new_derivative(
11551139
old_deriv = uuid4()
11561140
await partition.register_segments({seg: [old_deriv]})
11571141
await partition.delete_segments_by_episodes([ep])
1158-
marked = list(await partition.mark_orphaned_derivatives_for_purging([old_deriv]))
1159-
assert old_deriv in marked
1142+
await partition.mark_orphaned_derivatives_for_purging()
1143+
assert old_deriv in await partition.get_derivatives_pending_purge()
11601144

11611145
errors: list[Exception] = []
11621146

0 commit comments

Comments
 (0)