From 7ab42086d1e28dd902cfa1a109cdba84a0e28f82 Mon Sep 17 00:00:00 2001 From: Ibraheem Ahmed Date: Sat, 28 Jun 2025 01:31:53 -0400 Subject: [PATCH 01/65] update papaya (#928) --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index e693b92f0..4a579ac71 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,7 +19,7 @@ hashbrown = "0.15" hashlink = "0.10" indexmap = "2" intrusive-collections = "0.9.7" -papaya = "0.2.2" +papaya = "0.2.3" parking_lot = "0.12" portable-atomic = "1" rustc-hash = "2" From fc00eba89e5dcaa5edba51c41aa5f309b5cb126b Mon Sep 17 00:00:00 2001 From: Lukas Wirth Date: Wed, 2 Jul 2025 14:34:02 +0200 Subject: [PATCH 02/65] Fix `heap_size` option not being preserved in tracked impls (#930) --- components/salsa-macros/src/options.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/components/salsa-macros/src/options.rs b/components/salsa-macros/src/options.rs index 00ac4331a..7664d6eae 100644 --- a/components/salsa-macros/src/options.rs +++ b/components/salsa-macros/src/options.rs @@ -527,7 +527,7 @@ impl quote::ToTokens for Options { tokens.extend(quote::quote! { revisions = #revisions, }); } if let Some(heap_size_fn) = heap_size_fn { - tokens.extend(quote::quote! { heap_size_fn = #heap_size_fn, }); + tokens.extend(quote::quote! { heap_size = #heap_size_fn, }); } if let Some(self_ty) = self_ty { tokens.extend(quote::quote! 
{ self_ty = #self_ty, }); From d28d66bf1390037f38abfe9cfcf9cd1ee4eb6f74 Mon Sep 17 00:00:00 2001 From: Lukas Wirth Date: Fri, 4 Jul 2025 22:47:07 +0200 Subject: [PATCH 03/65] fix: Fix phantom data usage in salsa structs affecting auto traits (#932) --- components/salsa-macro-rules/src/setup_interned_struct.rs | 2 +- components/salsa-macro-rules/src/setup_tracked_fn.rs | 2 +- components/salsa-macro-rules/src/setup_tracked_struct.rs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/components/salsa-macro-rules/src/setup_interned_struct.rs b/components/salsa-macro-rules/src/setup_interned_struct.rs index 193436224..103f4d210 100644 --- a/components/salsa-macro-rules/src/setup_interned_struct.rs +++ b/components/salsa-macro-rules/src/setup_interned_struct.rs @@ -81,7 +81,7 @@ macro_rules! setup_interned_struct { #[derive(Copy, Clone, PartialEq, Eq, Hash)] $vis struct $Struct< $($db_lt_arg)? >( $Id, - std::marker::PhantomData < & $interior_lt salsa::plumbing::interned::Value <$StructWithStatic> > + std::marker::PhantomData &$interior_lt ()> ); #[allow(clippy::all)] diff --git a/components/salsa-macro-rules/src/setup_tracked_fn.rs b/components/salsa-macro-rules/src/setup_tracked_fn.rs index c424c4660..850b3e58d 100644 --- a/components/salsa-macro-rules/src/setup_tracked_fn.rs +++ b/components/salsa-macro-rules/src/setup_tracked_fn.rs @@ -99,7 +99,7 @@ macro_rules! 
setup_tracked_fn { #[derive(Copy, Clone)] struct $InternedData<$db_lt>( salsa::Id, - std::marker::PhantomData<&$db_lt $zalsa::interned::Value<$Configuration>>, + std::marker::PhantomData &$db_lt ()>, ); static $INTERN_CACHE: $zalsa::IngredientCache<$zalsa::interned::IngredientImpl<$Configuration>> = diff --git a/components/salsa-macro-rules/src/setup_tracked_struct.rs b/components/salsa-macro-rules/src/setup_tracked_struct.rs index ab3512ad7..8b54fdd22 100644 --- a/components/salsa-macro-rules/src/setup_tracked_struct.rs +++ b/components/salsa-macro-rules/src/setup_tracked_struct.rs @@ -104,7 +104,7 @@ macro_rules! setup_tracked_struct { #[derive(Copy, Clone, PartialEq, Eq, Hash)] $vis struct $Struct<$db_lt>( salsa::Id, - std::marker::PhantomData < & $db_lt salsa::plumbing::tracked_struct::Value < $Struct<'static> > > + std::marker::PhantomData &$db_lt ()> ); #[allow(clippy::all)] From dba66f1a37acca014c2402f231ed5b361bd7d8fe Mon Sep 17 00:00:00 2001 From: Ibraheem Ahmed Date: Fri, 18 Jul 2025 00:55:50 -0400 Subject: [PATCH 04/65] Use `inventory` for static ingredient registration (#934) * use `inventory` for static ingredient registration * remove unnecessary synchronization from memo tables * use global ingredient caches for database-independent ingredients * add manual ingredient registration API * remove static ingredient index optimization when manual registration is in use * fix atomic imports * simplify ingredient caches --- .github/workflows/test.yml | 2 + Cargo.toml | 7 +- .../src/setup_accumulator_impl.rs | 13 +- .../src/setup_input_struct.rs | 17 +- .../src/setup_interned_struct.rs | 15 +- .../salsa-macro-rules/src/setup_tracked_fn.rs | 68 ++-- .../src/setup_tracked_struct.rs | 17 +- components/salsa-macros/src/db.rs | 9 +- components/salsa-macros/src/fn_util.rs | 2 +- components/salsa-macros/src/hygiene.rs | 12 +- components/salsa-macros/src/supertype.rs | 4 +- components/salsa-macros/src/tracked_fn.rs | 8 +- components/salsa-macros/src/tracked_impl.rs 
| 2 +- src/accumulator.rs | 13 +- src/database.rs | 7 +- src/function.rs | 35 +- src/function/memo.rs | 2 +- src/ingredient.rs | 30 +- src/ingredient_cache.rs | 202 ++++++++++++ src/input.rs | 13 +- src/input/input_field.rs | 6 +- src/interned.rs | 39 +-- src/lib.rs | 15 +- src/memo_ingredient_indices.rs | 43 ++- src/salsa_struct.rs | 2 +- src/storage.rs | 64 +++- src/sync.rs | 96 +----- src/table/memo.rs | 169 +++------- src/tracked_struct.rs | 29 +- src/tracked_struct/tracked_field.rs | 6 +- src/views.rs | 40 ++- src/zalsa.rs | 309 +++++++----------- src/zalsa_local.rs | 2 +- tests/accumulate-chain.rs | 2 + tests/accumulate-custom-debug.rs | 2 + tests/accumulate-dag.rs | 2 + tests/accumulate-execution-order.rs | 2 + tests/accumulate-from-tracked-fn.rs | 2 + tests/accumulate-no-duplicates.rs | 2 + tests/accumulate-reuse-workaround.rs | 2 + tests/accumulate-reuse.rs | 2 + tests/accumulate.rs | 2 + tests/accumulated_backdate.rs | 2 + tests/backtrace.rs | 34 +- tests/check_auto_traits.rs | 2 + tests/compile_fail.rs | 2 + tests/cycle.rs | 2 + tests/cycle_accumulate.rs | 2 + tests/cycle_fallback_immediate.rs | 2 + tests/cycle_initial_call_back_into_cycle.rs | 2 + tests/cycle_initial_call_query.rs | 2 + tests/cycle_maybe_changed_after.rs | 2 + tests/cycle_output.rs | 2 + tests/cycle_recovery_call_back_into_cycle.rs | 2 + tests/cycle_recovery_call_query.rs | 2 + tests/cycle_regression_455.rs | 2 + tests/cycle_result_dependencies.rs | 2 + tests/cycle_tracked.rs | 2 + tests/cycle_tracked_own_input.rs | 2 + tests/dataflow.rs | 2 + tests/debug.rs | 2 + tests/debug_db_contents.rs | 2 + tests/deletion-cascade.rs | 2 + tests/deletion-drops.rs | 2 + tests/deletion.rs | 2 + tests/derive_update.rs | 2 + tests/durability.rs | 2 + tests/elided-lifetime-in-tracked-fn.rs | 2 + ...truct_changes_but_fn_depends_on_field_y.rs | 2 + ...input_changes_but_fn_depends_on_field_y.rs | 2 + tests/hash_collision.rs | 2 + tests/hello_world.rs | 2 + tests/input_default.rs | 2 + 
tests/input_field_durability.rs | 2 + tests/input_setter_preserves_durability.rs | 2 + tests/intern_access_in_different_revision.rs | 2 + tests/interned-revisions.rs | 2 + tests/interned-structs.rs | 2 + tests/interned-structs_self_ref.rs | 20 +- tests/lru.rs | 2 + tests/manual_registration.rs | 92 ++++++ tests/memory-usage.rs | 2 + tests/mutate_in_place.rs | 2 + tests/override_new_get_set.rs | 2 + ...ng-tracked-struct-outside-of-tracked-fn.rs | 2 + tests/parallel/main.rs | 2 + tests/preverify-struct-with-leaked-data-2.rs | 2 + tests/preverify-struct-with-leaked-data.rs | 2 + tests/return_mode.rs | 2 + tests/singleton.rs | 2 + ...the-key-is-created-in-the-current-query.rs | 2 + tests/synthetic_write.rs | 2 + tests/tracked-struct-id-field-bad-eq.rs | 2 + tests/tracked-struct-id-field-bad-hash.rs | 2 + tests/tracked-struct-unchanged-in-new-rev.rs | 2 + tests/tracked-struct-value-field-bad-eq.rs | 2 + tests/tracked-struct-value-field-not-eq.rs | 2 + tests/tracked_assoc_fn.rs | 2 + tests/tracked_fn_constant.rs | 2 + .../tracked_fn_high_durability_dependency.rs | 1 + tests/tracked_fn_interned_lifetime.rs | 2 + tests/tracked_fn_multiple_args.rs | 2 + tests/tracked_fn_no_eq.rs | 2 + tests/tracked_fn_on_input.rs | 2 + ...racked_fn_on_input_with_high_durability.rs | 1 + tests/tracked_fn_on_interned.rs | 2 + tests/tracked_fn_on_interned_enum.rs | 2 + tests/tracked_fn_on_tracked.rs | 2 + tests/tracked_fn_on_tracked_specify.rs | 2 + tests/tracked_fn_orphan_escape_hatch.rs | 2 + tests/tracked_fn_read_own_entity.rs | 2 + tests/tracked_fn_read_own_specify.rs | 2 + tests/tracked_fn_return_ref.rs | 2 + tests/tracked_method.rs | 2 + tests/tracked_method_inherent_return_deref.rs | 2 + tests/tracked_method_inherent_return_ref.rs | 2 + tests/tracked_method_on_tracked_struct.rs | 2 + tests/tracked_method_trait_return_ref.rs | 2 + tests/tracked_method_with_self_ty.rs | 2 + tests/tracked_struct.rs | 2 + tests/tracked_struct_db1_lt.rs | 2 + tests/tracked_struct_disambiguates.rs | 2 + 
tests/tracked_struct_durability.rs | 2 + tests/tracked_struct_manual_update.rs | 2 + tests/tracked_struct_mixed_tracked_fields.rs | 2 + tests/tracked_struct_recreate_new_revision.rs | 2 + tests/tracked_struct_with_interned_query.rs | 2 + tests/tracked_with_intern.rs | 2 + tests/tracked_with_struct_db.rs | 2 + tests/tracked_with_struct_ord.rs | 2 + 130 files changed, 1020 insertions(+), 610 deletions(-) create mode 100644 src/ingredient_cache.rs create mode 100644 tests/manual_registration.rs diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 803d4105e..acb692572 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -55,6 +55,8 @@ jobs: run: cargo clippy --workspace --all-targets -- -D warnings - name: Test run: cargo nextest run --workspace --all-targets --no-fail-fast + - name: Test Manual Registration + run: cargo nextest run --workspace --tests --no-fail-fast --no-default-features --features macros - name: Test docs run: cargo test --workspace --doc - name: Check (without default features) diff --git a/Cargo.toml b/Cargo.toml index 4a579ac71..54eac8531 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,13 +19,15 @@ hashbrown = "0.15" hashlink = "0.10" indexmap = "2" intrusive-collections = "0.9.7" -papaya = "0.2.3" parking_lot = "0.12" portable-atomic = "1" rustc-hash = "2" smallvec = "1" tracing = { version = "0.1", default-features = false, features = ["std"] } +# Automatic ingredient registration. +inventory = { version = "0.3.20", optional = true } + # parallel map rayon = { version = "1.10.0", optional = true } @@ -36,7 +38,8 @@ thin-vec = "0.2.13" shuttle = { version = "0.8.0", optional = true } [features] -default = ["salsa_unstable", "rayon", "macros"] +default = ["salsa_unstable", "rayon", "macros", "inventory"] +inventory = ["dep:inventory"] shuttle = ["dep:shuttle"] # FIXME: remove `salsa_unstable` before 1.0. 
salsa_unstable = [] diff --git a/components/salsa-macro-rules/src/setup_accumulator_impl.rs b/components/salsa-macro-rules/src/setup_accumulator_impl.rs index 788d5cf76..7842067e7 100644 --- a/components/salsa-macro-rules/src/setup_accumulator_impl.rs +++ b/components/salsa-macro-rules/src/setup_accumulator_impl.rs @@ -21,14 +21,21 @@ macro_rules! setup_accumulator_impl { use salsa::plumbing as $zalsa; use salsa::plumbing::accumulator as $zalsa_struct; + impl $zalsa::HasJar for $Struct { + type Jar = $zalsa_struct::JarImpl<$Struct>; + const KIND: $zalsa::JarKind = $zalsa::JarKind::Struct; + } + + $zalsa::register_jar! { + $zalsa::ErasedJar::erase::<$Struct>() + } + fn $ingredient(zalsa: &$zalsa::Zalsa) -> &$zalsa_struct::IngredientImpl<$Struct> { static $CACHE: $zalsa::IngredientCache<$zalsa_struct::IngredientImpl<$Struct>> = $zalsa::IngredientCache::new(); $CACHE.get_or_create(zalsa, || { - zalsa - .lookup_jar_by_type::<$zalsa_struct::JarImpl<$Struct>>() - .get_or_create() + zalsa.lookup_jar_by_type::<$zalsa_struct::JarImpl<$Struct>>() }) } diff --git a/components/salsa-macro-rules/src/setup_input_struct.rs b/components/salsa-macro-rules/src/setup_input_struct.rs index a2a402ba9..38988c546 100644 --- a/components/salsa-macro-rules/src/setup_input_struct.rs +++ b/components/salsa-macro-rules/src/setup_input_struct.rs @@ -74,6 +74,15 @@ macro_rules! setup_input_struct { type $Configuration = $Struct; + impl $zalsa::HasJar for $Struct { + type Jar = $zalsa_struct::JarImpl<$Configuration>; + const KIND: $zalsa::JarKind = $zalsa::JarKind::Struct; + } + + $zalsa::register_jar! { + $zalsa::ErasedJar::erase::<$Struct>() + } + impl $zalsa_struct::Configuration for $Configuration { const LOCATION: $zalsa::Location = $zalsa::Location { file: file!(), @@ -101,14 +110,14 @@ macro_rules! 
setup_input_struct { $zalsa::IngredientCache::new(); CACHE.get_or_create(zalsa, || { - zalsa.lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>().get_or_create() + zalsa.lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>() }) } pub fn ingredient_mut(db: &mut dyn $zalsa::Database) -> (&mut $zalsa_struct::IngredientImpl, &mut $zalsa::Runtime) { let zalsa_mut = db.zalsa_mut(); zalsa_mut.new_revision(); - let index = zalsa_mut.lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>().get_or_create(); + let index = zalsa_mut.lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>(); let (ingredient, runtime) = zalsa_mut.lookup_ingredient_mut(index); let ingredient = ingredient.assert_type_mut::<$zalsa_struct::IngredientImpl>(); (ingredient, runtime) @@ -149,8 +158,8 @@ macro_rules! setup_input_struct { impl $zalsa::SalsaStructInDb for $Struct { type MemoIngredientMap = $zalsa::MemoIngredientSingletonIndex; - fn lookup_or_create_ingredient_index(aux: &$zalsa::Zalsa) -> $zalsa::IngredientIndices { - aux.lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>().get_or_create().into() + fn lookup_ingredient_index(aux: &$zalsa::Zalsa) -> $zalsa::IngredientIndices { + aux.lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>().into() } #[inline] diff --git a/components/salsa-macro-rules/src/setup_interned_struct.rs b/components/salsa-macro-rules/src/setup_interned_struct.rs index 103f4d210..3a62355cd 100644 --- a/components/salsa-macro-rules/src/setup_interned_struct.rs +++ b/components/salsa-macro-rules/src/setup_interned_struct.rs @@ -92,6 +92,15 @@ macro_rules! setup_interned_struct { type $Configuration = $StructWithStatic; + impl<$($db_lt_arg)?> $zalsa::HasJar for $Struct<$($db_lt_arg)?> { + type Jar = $zalsa_struct::JarImpl<$Configuration>; + const KIND: $zalsa::JarKind = $zalsa::JarKind::Struct; + } + + $zalsa::register_jar! 
{ + $zalsa::ErasedJar::erase::<$StructWithStatic>() + } + type $StructDataIdent<$db_lt> = ($($field_ty,)*); /// Key to use during hash lookups. Each field is some type that implements `Lookup` @@ -149,7 +158,7 @@ macro_rules! setup_interned_struct { let zalsa = db.zalsa(); CACHE.get_or_create(zalsa, || { - zalsa.lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>().get_or_create() + zalsa.lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>() }) } } @@ -181,8 +190,8 @@ macro_rules! setup_interned_struct { impl< $($db_lt_arg)? > $zalsa::SalsaStructInDb for $Struct< $($db_lt_arg)? > { type MemoIngredientMap = $zalsa::MemoIngredientSingletonIndex; - fn lookup_or_create_ingredient_index(aux: &$zalsa::Zalsa) -> $zalsa::IngredientIndices { - aux.lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>().get_or_create().into() + fn lookup_ingredient_index(aux: &$zalsa::Zalsa) -> $zalsa::IngredientIndices { + aux.lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>().into() } #[inline] diff --git a/components/salsa-macro-rules/src/setup_tracked_fn.rs b/components/salsa-macro-rules/src/setup_tracked_fn.rs index 850b3e58d..477070714 100644 --- a/components/salsa-macro-rules/src/setup_tracked_fn.rs +++ b/components/salsa-macro-rules/src/setup_tracked_fn.rs @@ -91,6 +91,16 @@ macro_rules! setup_tracked_fn { struct $Configuration; + $zalsa::register_jar! { + $zalsa::ErasedJar::erase::<$fn_name>() + } + + #[allow(non_local_definitions)] + impl $zalsa::HasJar for $fn_name { + type Jar = $fn_name; + const KIND: $zalsa::JarKind = $zalsa::JarKind::TrackedFn; + } + static $FN_CACHE: $zalsa::IngredientCache<$zalsa::function::IngredientImpl<$Configuration>> = $zalsa::IngredientCache::new(); @@ -108,7 +118,7 @@ macro_rules! 
setup_tracked_fn { impl $zalsa::SalsaStructInDb for $InternedData<'_> { type MemoIngredientMap = $zalsa::MemoIngredientSingletonIndex; - fn lookup_or_create_ingredient_index(aux: &$zalsa::Zalsa) -> $zalsa::IngredientIndices { + fn lookup_ingredient_index(aux: &$zalsa::Zalsa) -> $zalsa::IngredientIndices { $zalsa::IngredientIndices::empty() } @@ -155,27 +165,19 @@ macro_rules! setup_tracked_fn { impl $Configuration { fn fn_ingredient(db: &dyn $Db) -> &$zalsa::function::IngredientImpl<$Configuration> { let zalsa = db.zalsa(); - $FN_CACHE.get_or_create(zalsa, || { - let jar_entry = zalsa.lookup_jar_by_type::<$Configuration>(); - - // If the ingredient has already been inserted, we know that the downcaster - // has also been registered. This is a fast-path for multi-database use cases - // that bypass the ingredient cache and will always execute this closure. - if let Some(index) = jar_entry.get() { - return index; - } - - ::zalsa_register_downcaster(db); - jar_entry.get_or_create() - }) + $FN_CACHE + .get_or_create(zalsa, || zalsa.lookup_jar_by_type::<$fn_name>()) + .get_or_init(|| ::zalsa_register_downcaster(db)) } pub fn fn_ingredient_mut(db: &mut dyn $Db) -> &mut $zalsa::function::IngredientImpl { - ::zalsa_register_downcaster(db); + let view = ::zalsa_register_downcaster(db); let zalsa_mut = db.zalsa_mut(); - let index = zalsa_mut.lookup_jar_by_type::<$Configuration>().get_or_create(); + let index = zalsa_mut.lookup_jar_by_type::<$fn_name>(); let (ingredient, _) = zalsa_mut.lookup_ingredient_mut(index); - ingredient.assert_type_mut::<$zalsa::function::IngredientImpl>() + let ingredient = ingredient.assert_type_mut::<$zalsa::function::IngredientImpl>(); + ingredient.get_or_init(|| view); + ingredient } $zalsa::macro_if! { $needs_interner => @@ -184,8 +186,7 @@ macro_rules! 
setup_tracked_fn { ) -> &$zalsa::interned::IngredientImpl<$Configuration> { let zalsa = db.zalsa(); $INTERN_CACHE.get_or_create(zalsa, || { - ::zalsa_register_downcaster(db); - zalsa.lookup_jar_by_type::<$Configuration>().get_or_create().successor(0) + zalsa.lookup_jar_by_type::<$fn_name>().successor(0) }) } } @@ -248,42 +249,31 @@ macro_rules! setup_tracked_fn { } } - impl $zalsa::Jar for $Configuration { - fn create_dependencies(zalsa: &$zalsa::Zalsa) -> $zalsa::IngredientIndices - where - Self: Sized - { - $zalsa::macro_if! { - if $needs_interner { - $zalsa::IngredientIndices::empty() - } else { - <$InternedData as $zalsa::SalsaStructInDb>::lookup_or_create_ingredient_index(zalsa) - } - } - } - + #[allow(non_local_definitions)] + impl $zalsa::Jar for $fn_name { fn create_ingredients( - zalsa: &$zalsa::Zalsa, + zalsa: &mut $zalsa::Zalsa, first_index: $zalsa::IngredientIndex, - struct_index: $zalsa::IngredientIndices, ) -> Vec> { let struct_index: $zalsa::IngredientIndices = $zalsa::macro_if! { if $needs_interner { first_index.successor(0).into() } else { - struct_index + // Note that struct ingredients are created before tracked functions, + // so this cannot panic. + <$InternedData as $zalsa::SalsaStructInDb>::lookup_ingredient_index(zalsa) } }; $zalsa::macro_if! { $needs_interner => - let intern_ingredient = <$zalsa::interned::IngredientImpl<$Configuration>>::new( + let mut intern_ingredient = <$zalsa::interned::IngredientImpl<$Configuration>>::new( first_index.successor(0) ); } let intern_ingredient_memo_types = $zalsa::macro_if! { if $needs_interner { - Some($zalsa::Ingredient::memo_table_types(&intern_ingredient)) + Some($zalsa::Ingredient::memo_table_types_mut(&mut intern_ingredient)) } else { None } @@ -303,7 +293,6 @@ macro_rules! setup_tracked_fn { first_index, memo_ingredient_indices, $lru, - zalsa.views().downcaster_for::(), ); $zalsa::macro_if! { if $needs_interner { @@ -386,6 +375,7 @@ macro_rules! 
setup_tracked_fn { $zalsa::return_mode_expression!(($return_mode, __, __), $output_ty, result,) }) } + // The struct needs be last in the macro expansion in order to make the tracked // function's ident be identified as a function, not a struct, during semantic highlighting. // for more details, see https://github.com/salsa-rs/salsa/pull/612. diff --git a/components/salsa-macro-rules/src/setup_tracked_struct.rs b/components/salsa-macro-rules/src/setup_tracked_struct.rs index 8b54fdd22..07131735e 100644 --- a/components/salsa-macro-rules/src/setup_tracked_struct.rs +++ b/components/salsa-macro-rules/src/setup_tracked_struct.rs @@ -107,8 +107,8 @@ macro_rules! setup_tracked_struct { std::marker::PhantomData &$db_lt ()> ); - #[allow(clippy::all)] #[allow(dead_code)] + #[allow(clippy::all)] const _: () = { use salsa::plumbing as $zalsa; use $zalsa::tracked_struct as $zalsa_struct; @@ -116,6 +116,15 @@ macro_rules! setup_tracked_struct { type $Configuration = $Struct<'static>; + impl<$db_lt> $zalsa::HasJar for $Struct<$db_lt> { + type Jar = $zalsa_struct::JarImpl<$Configuration>; + const KIND: $zalsa::JarKind = $zalsa::JarKind::Struct; + } + + $zalsa::register_jar! { + $zalsa::ErasedJar::erase::<$Struct<'static>>() + } + impl $zalsa_struct::Configuration for $Configuration { const LOCATION: $zalsa::Location = $zalsa::Location { file: file!(), @@ -188,7 +197,7 @@ macro_rules! setup_tracked_struct { $zalsa::IngredientCache::new(); CACHE.get_or_create(zalsa, || { - zalsa.lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>().get_or_create() + zalsa.lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>() }) } } @@ -210,8 +219,8 @@ macro_rules! 
setup_tracked_struct { impl $zalsa::SalsaStructInDb for $Struct<'_> { type MemoIngredientMap = $zalsa::MemoIngredientSingletonIndex; - fn lookup_or_create_ingredient_index(aux: &$zalsa::Zalsa) -> $zalsa::IngredientIndices { - aux.lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>().get_or_create().into() + fn lookup_ingredient_index(aux: &$zalsa::Zalsa) -> $zalsa::IngredientIndices { + aux.lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>().into() } #[inline] diff --git a/components/salsa-macros/src/db.rs b/components/salsa-macros/src/db.rs index 478ebea5d..12ee48917 100644 --- a/components/salsa-macros/src/db.rs +++ b/components/salsa-macros/src/db.rs @@ -110,7 +110,7 @@ impl DbMacro { let trait_name = &input.ident; input.items.push(parse_quote! { #[doc(hidden)] - fn zalsa_register_downcaster(&self); + fn zalsa_register_downcaster(&self) -> salsa::plumbing::DatabaseDownCaster; }); let comment = format!(" Downcast a [`dyn Database`] to a [`dyn {trait_name}`]"); @@ -135,10 +135,11 @@ impl DbMacro { }; input.items.push(parse_quote! { + #[cold] + #[inline(never)] #[doc(hidden)] - #[inline(always)] - fn zalsa_register_downcaster(&self) { - salsa::plumbing::views(self).add(::downcast); + fn zalsa_register_downcaster(&self) -> salsa::plumbing::DatabaseDownCaster { + salsa::plumbing::views(self).add(::downcast) } }); input.items.push(parse_quote! { diff --git a/components/salsa-macros/src/fn_util.rs b/components/salsa-macros/src/fn_util.rs index 619d0fd97..d06d0a7d1 100644 --- a/components/salsa-macros/src/fn_util.rs +++ b/components/salsa-macros/src/fn_util.rs @@ -15,7 +15,7 @@ pub fn input_ids(hygiene: &Hygiene, sig: &syn::Signature, skip: usize) -> Vec syn::Ident { + pub(crate) fn ident(&self, text: impl AsRef) -> syn::Ident { // Make the default be `foo_` rather than `foo` -- this helps detect // cases where people wrote `foo` instead of `#foo` or `$foo` in the generated code. 
- let mut buffer = format!("{text}_"); + let mut buffer = format!("{}_", text.as_ref()); while self.user_tokens.contains(&buffer) { buffer.push('_'); @@ -61,4 +61,12 @@ impl Hygiene { syn::Ident::new(&buffer, proc_macro2::Span::call_site()) } + + /// Generates an identifier similar to `text` but distinct from any identifiers + /// that appear in the user's code. + /// + /// The identifier must be unique relative to the `scope` identifier. + pub(crate) fn scoped_ident(&self, scope: &syn::Ident, text: &str) -> syn::Ident { + self.ident(format!("{scope}_{text}")) + } } diff --git a/components/salsa-macros/src/supertype.rs b/components/salsa-macros/src/supertype.rs index 5b433bd86..d1c6c70b8 100644 --- a/components/salsa-macros/src/supertype.rs +++ b/components/salsa-macros/src/supertype.rs @@ -72,8 +72,8 @@ fn enum_impl(enum_item: syn::ItemEnum) -> syn::Result { type MemoIngredientMap = zalsa::MemoIngredientIndices; #[inline] - fn lookup_or_create_ingredient_index(__zalsa: &zalsa::Zalsa) -> zalsa::IngredientIndices { - zalsa::IngredientIndices::merge([ #( <#variant_types as zalsa::SalsaStructInDb>::lookup_or_create_ingredient_index(__zalsa) ),* ]) + fn lookup_ingredient_index(__zalsa: &zalsa::Zalsa) -> zalsa::IngredientIndices { + zalsa::IngredientIndices::merge([ #( <#variant_types as zalsa::SalsaStructInDb>::lookup_ingredient_index(__zalsa) ),* ]) } #[inline] diff --git a/components/salsa-macros/src/tracked_fn.rs b/components/salsa-macros/src/tracked_fn.rs index f2f078934..9f7ec0879 100644 --- a/components/salsa-macros/src/tracked_fn.rs +++ b/components/salsa-macros/src/tracked_fn.rs @@ -132,10 +132,10 @@ impl Macro { inner_fn.sig.ident = self.hygiene.ident("inner"); let zalsa = self.hygiene.ident("zalsa"); - let Configuration = self.hygiene.ident("Configuration"); - let InternedData = self.hygiene.ident("InternedData"); - let FN_CACHE = self.hygiene.ident("FN_CACHE"); - let INTERN_CACHE = self.hygiene.ident("INTERN_CACHE"); + let Configuration = 
self.hygiene.scoped_ident(fn_name, "Configuration"); + let InternedData = self.hygiene.scoped_ident(fn_name, "InternedData"); + let FN_CACHE = self.hygiene.scoped_ident(fn_name, "FN_CACHE"); + let INTERN_CACHE = self.hygiene.scoped_ident(fn_name, "INTERN_CACHE"); let inner = &inner_fn.sig.ident; let function_type = function_type(&item); diff --git a/components/salsa-macros/src/tracked_impl.rs b/components/salsa-macros/src/tracked_impl.rs index 9ad9c2c43..2a07ae0c6 100644 --- a/components/salsa-macros/src/tracked_impl.rs +++ b/components/salsa-macros/src/tracked_impl.rs @@ -99,7 +99,7 @@ impl Macro { }); let InnerTrait = self.hygiene.ident("InnerTrait"); - let inner_fn_name = self.hygiene.ident(&fn_item.sig.ident.to_string()); + let inner_fn_name = self.hygiene.ident(fn_item.sig.ident.to_string()); let AssociatedFunctionArguments { self_token, diff --git a/src/accumulator.rs b/src/accumulator.rs index b3103c317..3b1358c60 100644 --- a/src/accumulator.rs +++ b/src/accumulator.rs @@ -10,7 +10,7 @@ use accumulated::{Accumulated, AnyAccumulated}; use crate::cycle::CycleHeads; use crate::function::VerifyResult; use crate::ingredient::{Ingredient, Jar}; -use crate::plumbing::{IngredientIndices, ZalsaLocal}; +use crate::plumbing::ZalsaLocal; use crate::sync::Arc; use crate::table::memo::MemoTableTypes; use crate::zalsa::{IngredientIndex, Zalsa}; @@ -44,9 +44,8 @@ impl Default for JarImpl { impl Jar for JarImpl { fn create_ingredients( - _zalsa: &Zalsa, + _zalsa: &mut Zalsa, first_index: IngredientIndex, - _dependencies: IngredientIndices, ) -> Vec> { vec![Box::new(>::new(first_index))] } @@ -64,7 +63,7 @@ pub struct IngredientImpl { impl IngredientImpl { /// Find the accumulator ingredient for `A` in the database, if any. 
pub fn from_zalsa(zalsa: &Zalsa) -> Option<&Self> { - let index = zalsa.lookup_jar_by_type::>().get_or_create(); + let index = zalsa.lookup_jar_by_type::>(); let ingredient = zalsa.lookup_ingredient(index).assert_type::(); Some(ingredient) } @@ -115,7 +114,11 @@ impl Ingredient for IngredientImpl { A::DEBUG_NAME } - fn memo_table_types(&self) -> Arc { + fn memo_table_types(&self) -> &Arc { + unreachable!("accumulator does not allocate pages") + } + + fn memo_table_types_mut(&mut self) -> &mut Arc { unreachable!("accumulator does not allocate pages") } } diff --git a/src/database.rs b/src/database.rs index 594deb0a1..b840398ff 100644 --- a/src/database.rs +++ b/src/database.rs @@ -1,6 +1,7 @@ use std::any::Any; use std::borrow::Cow; +use crate::views::DatabaseDownCaster; use crate::zalsa::{IngredientIndex, ZalsaDatabase}; use crate::{Durability, Revision}; @@ -80,9 +81,11 @@ pub trait Database: Send + AsDynDatabase + Any + ZalsaDatabase { crate::attach::attach(self, || op(self)) } + #[cold] + #[inline(never)] #[doc(hidden)] - #[inline(always)] - fn zalsa_register_downcaster(&self) { + fn zalsa_register_downcaster(&self) -> DatabaseDownCaster { + self.zalsa().views().downcaster_for::() // The no-op downcaster is special cased in view caster construction. } diff --git a/src/function.rs b/src/function.rs index 76b2abf7d..ceb006feb 100644 --- a/src/function.rs +++ b/src/function.rs @@ -3,6 +3,7 @@ use std::any::Any; use std::fmt; use std::ptr::NonNull; use std::sync::atomic::Ordering; +use std::sync::OnceLock; pub(crate) use sync::SyncGuard; use crate::accumulator::accumulated_map::{AccumulatedMap, InputAccumulatedValues}; @@ -129,7 +130,7 @@ pub struct IngredientImpl { /// /// The supplied database must be be the same as the database used to construct the [`Views`] /// instances that this downcaster was derived from. 
- view_caster: DatabaseDownCaster, + view_caster: OnceLock>, sync_table: SyncTable, @@ -156,18 +157,30 @@ where index: IngredientIndex, memo_ingredient_indices: as SalsaStructInDb>::MemoIngredientMap, lru: usize, - view_caster: DatabaseDownCaster, ) -> Self { Self { index, memo_ingredient_indices, lru: lru::Lru::new(lru), deleted_entries: Default::default(), - view_caster, + view_caster: OnceLock::new(), sync_table: SyncTable::new(index), } } + /// Set the view-caster for this tracked function ingredient, if it has + /// not already been initialized. + #[inline] + pub fn get_or_init( + &self, + view_caster: impl FnOnce() -> DatabaseDownCaster, + ) -> &Self { + // Note that we must set this lazily as we don't have access to the database + // type when ingredients are registered into the `Zalsa`. + self.view_caster.get_or_init(view_caster); + self + } + #[inline] pub fn database_key_index(&self, key: Id) -> DatabaseKeyIndex { DatabaseKeyIndex::new(self.index, key) @@ -226,6 +239,12 @@ where fn memo_ingredient_index(&self, zalsa: &Zalsa, id: Id) -> MemoIngredientIndex { self.memo_ingredient_indices.get_zalsa_id(zalsa, id) } + + fn view_caster(&self) -> &DatabaseDownCaster { + self.view_caster + .get() + .expect("tracked function ingredients cannot be accessed before calling `init`") + } } impl Ingredient for IngredientImpl @@ -248,7 +267,7 @@ where cycle_heads: &mut CycleHeads, ) -> VerifyResult { // SAFETY: The `db` belongs to the ingredient as per caller invariant - let db = unsafe { self.view_caster.downcast_unchecked(db) }; + let db = unsafe { self.view_caster().downcast_unchecked(db) }; self.maybe_changed_after(db, input, revision, cycle_heads) } @@ -339,7 +358,11 @@ where C::DEBUG_NAME } - fn memo_table_types(&self) -> Arc { + fn memo_table_types(&self) -> &Arc { + unreachable!("function does not allocate pages") + } + + fn memo_table_types_mut(&mut self) -> &mut Arc { unreachable!("function does not allocate pages") } @@ -352,7 +375,7 @@ where db: &'db dyn 
Database, key_index: Id, ) -> (Option<&'db AccumulatedMap>, InputAccumulatedValues) { - let db = self.view_caster.downcast(db); + let db = self.view_caster().downcast(db); self.accumulated_map(db, key_index) } } diff --git a/src/function/memo.rs b/src/function/memo.rs index 8f8952e5b..8f8393fc6 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -466,7 +466,7 @@ mod _memory_usage { impl SalsaStructInDb for DummyStruct { type MemoIngredientMap = MemoIngredientSingletonIndex; - fn lookup_or_create_ingredient_index(_: &Zalsa) -> IngredientIndices { + fn lookup_ingredient_index(_: &Zalsa) -> IngredientIndices { unimplemented!() } diff --git a/src/ingredient.rs b/src/ingredient.rs index ff4837694..796a6e12f 100644 --- a/src/ingredient.rs +++ b/src/ingredient.rs @@ -6,7 +6,6 @@ use crate::cycle::{ empty_cycle_heads, CycleHeads, CycleRecoveryStrategy, IterationCount, ProvisionalStatus, }; use crate::function::VerifyResult; -use crate::plumbing::IngredientIndices; use crate::runtime::Running; use crate::sync::Arc; use crate::table::memo::MemoTableTypes; @@ -16,35 +15,20 @@ use crate::zalsa_local::QueryOriginRef; use crate::{Database, DatabaseKeyIndex, Id, Revision}; /// A "jar" is a group of ingredients that are added atomically. +/// /// Each type implementing jar can be added to the database at most once. pub trait Jar: Any { - /// This creates the ingredient dependencies of this jar. We need to split this from `create_ingredients()` - /// because while `create_ingredients()` is called, a lock on the ingredient map is held (to guarantee - /// atomicity), so other ingredients could not be created. - /// - /// Only tracked fns use this. - fn create_dependencies(_zalsa: &Zalsa) -> IngredientIndices - where - Self: Sized, - { - IngredientIndices::empty() - } - /// Create the ingredients given the index of the first one. + /// /// All subsequent ingredients will be assigned contiguous indices. 
fn create_ingredients( - zalsa: &Zalsa, + zalsa: &mut Zalsa, first_index: IngredientIndex, - dependencies: IngredientIndices, - ) -> Vec> - where - Self: Sized; + ) -> Vec>; /// This returns the [`TypeId`] of the ID struct, that is, the struct that wraps `salsa::Id` /// and carry the name of the jar. - fn id_struct_type_id() -> TypeId - where - Self: Sized; + fn id_struct_type_id() -> TypeId; } pub struct Location { @@ -151,7 +135,9 @@ pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { ); } - fn memo_table_types(&self) -> Arc; + fn memo_table_types(&self) -> &Arc; + + fn memo_table_types_mut(&mut self) -> &mut Arc; fn fmt_index(&self, index: crate::Id, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { fmt_index(self.debug_name(), index, fmt) diff --git a/src/ingredient_cache.rs b/src/ingredient_cache.rs new file mode 100644 index 000000000..8b9ebe76b --- /dev/null +++ b/src/ingredient_cache.rs @@ -0,0 +1,202 @@ +pub use imp::IngredientCache; + +#[cfg(feature = "inventory")] +mod imp { + use crate::plumbing::Ingredient; + use crate::sync::atomic::{self, AtomicU32, Ordering}; + use crate::zalsa::Zalsa; + use crate::IngredientIndex; + + use std::marker::PhantomData; + + /// Caches an ingredient index. + /// + /// Note that all ingredients are statically registered with `inventory`, so their + /// indices should be stable across any databases. + pub struct IngredientCache + where + I: Ingredient, + { + ingredient_index: AtomicU32, + phantom: PhantomData I>, + } + + impl Default for IngredientCache + where + I: Ingredient, + { + fn default() -> Self { + Self::new() + } + } + + impl IngredientCache + where + I: Ingredient, + { + const UNINITIALIZED: u32 = u32::MAX; + + /// Create a new cache + pub const fn new() -> Self { + Self { + ingredient_index: atomic::AtomicU32::new(Self::UNINITIALIZED), + phantom: PhantomData, + } + } + + /// Get a reference to the ingredient in the database. 
+ /// + /// If the ingredient index is not already in the cache, it will be loaded and cached. + pub fn get_or_create<'db>( + &self, + zalsa: &'db Zalsa, + load_index: impl Fn() -> IngredientIndex, + ) -> &'db I { + let mut ingredient_index = self.ingredient_index.load(Ordering::Acquire); + if ingredient_index == Self::UNINITIALIZED { + ingredient_index = self.get_or_create_index_slow(load_index).as_u32(); + }; + + zalsa + .lookup_ingredient(IngredientIndex::from_unchecked(ingredient_index)) + .assert_type() + } + + #[cold] + #[inline(never)] + fn get_or_create_index_slow( + &self, + load_index: impl Fn() -> IngredientIndex, + ) -> IngredientIndex { + let ingredient_index = load_index(); + + // It doesn't matter if we overwrite any stores, as `create_index` should + // always return the same index when the `inventory` feature is enabled. + self.ingredient_index + .store(ingredient_index.as_u32(), Ordering::Release); + + ingredient_index + } + } +} + +#[cfg(not(feature = "inventory"))] +mod imp { + use crate::nonce::Nonce; + use crate::plumbing::Ingredient; + use crate::sync::atomic::{AtomicU64, Ordering}; + use crate::zalsa::{StorageNonce, Zalsa}; + use crate::IngredientIndex; + + use std::marker::PhantomData; + use std::mem; + + /// Caches an ingredient index. + /// + /// With manual registration, ingredient indices can vary across databases, + /// but we can retain most of the benefit by optimizing for the the case of + /// a single database. + pub struct IngredientCache + where + I: Ingredient, + { + // A packed representation of `Option<(Nonce, IngredientIndex)>`. + // + // This allows us to replace a lock in favor of an atomic load. This works thanks to `Nonce` + // having a niche, which means the entire type can fit into an `AtomicU64`. 
+ cached_data: AtomicU64, + phantom: PhantomData I>, + } + + impl Default for IngredientCache + where + I: Ingredient, + { + fn default() -> Self { + Self::new() + } + } + + impl IngredientCache + where + I: Ingredient, + { + const UNINITIALIZED: u64 = 0; + + /// Create a new cache + pub const fn new() -> Self { + Self { + cached_data: AtomicU64::new(Self::UNINITIALIZED), + phantom: PhantomData, + } + } + + /// Get a reference to the ingredient in the database. + /// + /// If the ingredient is not already in the cache, it will be created. + #[inline(always)] + pub fn get_or_create<'db>( + &self, + zalsa: &'db Zalsa, + create_index: impl Fn() -> IngredientIndex, + ) -> &'db I { + let index = self.get_or_create_index(zalsa, create_index); + zalsa.lookup_ingredient(index).assert_type::() + } + + pub fn get_or_create_index( + &self, + zalsa: &Zalsa, + create_index: impl Fn() -> IngredientIndex, + ) -> IngredientIndex { + const _: () = assert!( + mem::size_of::<(Nonce, IngredientIndex)>() == mem::size_of::() + ); + + let cached_data = self.cached_data.load(Ordering::Acquire); + if cached_data == Self::UNINITIALIZED { + return self.get_or_create_index_slow(zalsa, create_index); + }; + + // Unpack our `u64` into the nonce and index. + let index = IngredientIndex::from_unchecked(cached_data as u32); + + // SAFETY: We've checked against `UNINITIALIZED` (0) above and so the upper bits must be non-zero. + let nonce = crate::nonce::Nonce::::from_u32(unsafe { + std::num::NonZeroU32::new_unchecked((cached_data >> u32::BITS) as u32) + }); + + // The data was cached for a different database, we have to ensure the ingredient was + // created in ours. 
+ if zalsa.nonce() != nonce { + return create_index(); + } + + index + } + + #[cold] + #[inline(never)] + fn get_or_create_index_slow( + &self, + zalsa: &Zalsa, + create_index: impl Fn() -> IngredientIndex, + ) -> IngredientIndex { + let index = create_index(); + let nonce = zalsa.nonce().into_u32().get() as u64; + let packed = (nonce << u32::BITS) | (index.as_u32() as u64); + debug_assert_ne!(packed, IngredientCache::::UNINITIALIZED); + + // Discard the result, whether we won over the cache or not doesn't matter. + _ = self.cached_data.compare_exchange( + IngredientCache::::UNINITIALIZED, + packed, + Ordering::Release, + Ordering::Relaxed, + ); + + // Use our locally computed index regardless of which one was cached. + index + } + } +} diff --git a/src/input.rs b/src/input.rs index fe72c7e16..c13d23e21 100644 --- a/src/input.rs +++ b/src/input.rs @@ -56,9 +56,8 @@ impl Default for JarImpl { impl Jar for JarImpl { fn create_ingredients( - _zalsa: &Zalsa, + _zalsa: &mut Zalsa, struct_index: crate::zalsa::IngredientIndex, - _dependencies: crate::memo_ingredient_indices::IngredientIndices, ) -> Vec> { let struct_ingredient: IngredientImpl = IngredientImpl::new(struct_index); @@ -117,7 +116,7 @@ impl IngredientImpl { fields, revisions, durabilities, - memos: Default::default(), + memos: MemoTable::new(self.memo_table_types()), }) }); @@ -238,8 +237,12 @@ impl Ingredient for IngredientImpl { C::DEBUG_NAME } - fn memo_table_types(&self) -> Arc { - self.memo_table_types.clone() + fn memo_table_types(&self) -> &Arc { + &self.memo_table_types + } + + fn memo_table_types_mut(&mut self) -> &mut Arc { + &mut self.memo_table_types } /// Returns memory usage information about any inputs. 
diff --git a/src/input/input_field.rs b/src/input/input_field.rs index 5e1df4874..f0e4856c8 100644 --- a/src/input/input_field.rs +++ b/src/input/input_field.rs @@ -76,7 +76,11 @@ where C::FIELD_DEBUG_NAMES[self.field_index] } - fn memo_table_types(&self) -> Arc { + fn memo_table_types(&self) -> &Arc { + unreachable!("input fields do not allocate pages") + } + + fn memo_table_types_mut(&mut self) -> &mut Arc { unreachable!("input fields do not allocate pages") } } diff --git a/src/interned.rs b/src/interned.rs index 2138494c0..4350e77a6 100644 --- a/src/interned.rs +++ b/src/interned.rs @@ -15,7 +15,7 @@ use crate::durability::Durability; use crate::function::VerifyResult; use crate::id::{AsId, FromId}; use crate::ingredient::Ingredient; -use crate::plumbing::{IngredientIndices, Jar, ZalsaLocal}; +use crate::plumbing::{Jar, ZalsaLocal}; use crate::revision::AtomicRevision; use crate::sync::{Arc, Mutex, OnceLock}; use crate::table::memo::{MemoTable, MemoTableTypes, MemoTableWithTypesMut}; @@ -224,9 +224,8 @@ impl Default for JarImpl { impl Jar for JarImpl { fn create_ingredients( - _zalsa: &Zalsa, + _zalsa: &mut Zalsa, first_index: IngredientIndex, - _dependencies: IngredientIndices, ) -> Vec> { vec![Box::new(IngredientImpl::::new(first_index)) as _] } @@ -416,7 +415,6 @@ where // Fill up the table for the first few revisions without attempting garbage collection. if !self.revision_queue.is_primed() { return self.intern_id_cold( - db, key, zalsa, zalsa_local, @@ -530,16 +528,16 @@ where // Insert the new value into the ID map. shard.key_map.insert_unique(hash, new_id, hasher); - // Free the memos associated with the previous interned value. - // // SAFETY: We hold the lock for the shard containing the value, and the // value has not been interned in the current revision, so no references to // it can exist. 
- let mut memo_table = unsafe { std::mem::take(&mut *value.memos.get()) }; + let memo_table = unsafe { &mut *value.memos.get() }; + // Free the memos associated with the previous interned value. + // // SAFETY: The memo table belongs to a value that we allocated, so it has the // correct type. - unsafe { self.clear_memos(zalsa, &mut memo_table, new_id) }; + unsafe { self.clear_memos(zalsa, memo_table, new_id) }; if value_shared.is_reusable::() { // Move the value to the front of the LRU list. @@ -553,16 +551,7 @@ where } // If we could not find any stale slots, we are forced to allocate a new one. - self.intern_id_cold( - db, - key, - zalsa, - zalsa_local, - assemble, - shard, - shard_index, - hash, - ) + self.intern_id_cold(key, zalsa, zalsa_local, assemble, shard, shard_index, hash) } /// The cold path for interning a value, allocating a new slot. @@ -571,7 +560,6 @@ where #[allow(clippy::too_many_arguments)] fn intern_id_cold<'db, Key>( &'db self, - _db: &'db dyn crate::Database, key: Key, zalsa: &Zalsa, zalsa_local: &ZalsaLocal, @@ -598,7 +586,7 @@ where let id = zalsa_local.allocate(zalsa, self.ingredient_index, |id| Value:: { shard: shard_index as u16, link: LinkedListLink::new(), - memos: UnsafeCell::new(MemoTable::default()), + memos: UnsafeCell::new(MemoTable::new(self.memo_table_types())), // SAFETY: We call `from_internal_data` to restore the correct lifetime before access. fields: UnsafeCell::new(unsafe { self.to_internal_data(assemble(id, key)) }), shared: UnsafeCell::new(ValueShared { @@ -696,6 +684,9 @@ where }; std::mem::forget(table_guard); + + // Reset the table after having dropped any memos. + memo_table.reset(); } // Hashes the value by its fields. 
@@ -849,8 +840,12 @@ where C::DEBUG_NAME } - fn memo_table_types(&self) -> Arc { - self.memo_table_types.clone() + fn memo_table_types(&self) -> &Arc { + &self.memo_table_types + } + + fn memo_table_types_mut(&mut self) -> &mut Arc { + &mut self.memo_table_types } /// Returns memory usage information about any interned values. diff --git a/src/lib.rs b/src/lib.rs index 83e600771..171452250 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,13 +14,11 @@ mod function; mod hash; mod id; mod ingredient; +mod ingredient_cache; mod input; mod interned; mod key; mod memo_ingredient_indices; -mod nonce; -#[cfg(feature = "rayon")] -mod parallel; mod return_mode; mod revision; mod runtime; @@ -34,6 +32,12 @@ mod views; mod zalsa; mod zalsa_local; +#[cfg(not(feature = "inventory"))] +mod nonce; + +#[cfg(feature = "rayon")] +mod parallel; + #[cfg(feature = "rayon")] pub use parallel::{join, par_map}; #[cfg(feature = "macros")] @@ -90,6 +94,7 @@ pub mod plumbing { pub use crate::durability::Durability; pub use crate::id::{AsId, FromId, FromIdWithDb, Id}; pub use crate::ingredient::{Ingredient, Jar, Location}; + pub use crate::ingredient_cache::IngredientCache; pub use crate::key::DatabaseKeyIndex; pub use crate::memo_ingredient_indices::{ IngredientIndices, MemoIngredientIndices, MemoIngredientMap, MemoIngredientSingletonIndex, @@ -102,8 +107,10 @@ pub mod plumbing { pub use crate::tracked_struct::TrackedStructInDb; pub use crate::update::helper::{Dispatch as UpdateDispatch, Fallback as UpdateFallback}; pub use crate::update::{always_update, Update}; + pub use crate::views::DatabaseDownCaster; pub use crate::zalsa::{ - transmute_data_ptr, views, IngredientCache, IngredientIndex, Zalsa, ZalsaDatabase, + register_jar, transmute_data_ptr, views, ErasedJar, HasJar, IngredientIndex, JarKind, + Zalsa, ZalsaDatabase, }; pub use crate::zalsa_local::ZalsaLocal; diff --git a/src/memo_ingredient_indices.rs b/src/memo_ingredient_indices.rs index 8d3e4322c..ba1dcf45d 100644 --- 
a/src/memo_ingredient_indices.rs +++ b/src/memo_ingredient_indices.rs @@ -49,11 +49,11 @@ pub trait NewMemoIngredientIndices { /// /// The memo types must be correct. unsafe fn create( - zalsa: &Zalsa, + zalsa: &mut Zalsa, struct_indices: IngredientIndices, ingredient: IngredientIndex, memo_type: MemoEntryType, - intern_ingredient_memo_types: Option>, + intern_ingredient_memo_types: Option<&mut Arc>, ) -> Self; } @@ -62,34 +62,39 @@ impl NewMemoIngredientIndices for MemoIngredientIndices { /// /// The memo types must be correct. unsafe fn create( - zalsa: &Zalsa, + zalsa: &mut Zalsa, struct_indices: IngredientIndices, ingredient: IngredientIndex, memo_type: MemoEntryType, - _intern_ingredient_memo_types: Option>, + _intern_ingredient_memo_types: Option<&mut Arc>, ) -> Self { debug_assert!( _intern_ingredient_memo_types.is_none(), "intern ingredient can only have a singleton memo ingredient" ); + let Some(&last) = struct_indices.indices.last() else { unreachable!("Attempting to construct struct memo mapping for non tracked function?") }; + let mut indices = Vec::new(); indices.resize( (last.as_u32() as usize) + 1, MemoIngredientIndex::from_usize((u32::MAX - 1) as usize), ); + for &struct_ingredient in &struct_indices.indices { - let memo_types = zalsa - .lookup_ingredient(struct_ingredient) - .memo_table_types(); + let memo_ingredient_index = + zalsa.next_memo_ingredient_index(struct_ingredient, ingredient); + indices[struct_ingredient.as_u32() as usize] = memo_ingredient_index; - let mi = zalsa.next_memo_ingredient_index(struct_ingredient, ingredient); - memo_types.set(mi, &memo_type); + let (struct_ingredient, _) = zalsa.lookup_ingredient_mut(struct_ingredient); + let memo_types = Arc::get_mut(struct_ingredient.memo_table_types_mut()) + .expect("memo tables are not shared until database initialization is complete"); - indices[struct_ingredient.as_u32() as usize] = mi; + memo_types.set(memo_ingredient_index, memo_type); } + MemoIngredientIndices { indices: 
indices.into_boxed_slice(), } @@ -146,25 +151,27 @@ impl MemoIngredientMap for MemoIngredientSingletonIndex { impl NewMemoIngredientIndices for MemoIngredientSingletonIndex { #[inline] unsafe fn create( - zalsa: &Zalsa, + zalsa: &mut Zalsa, indices: IngredientIndices, ingredient: IngredientIndex, memo_type: MemoEntryType, - intern_ingredient_memo_types: Option>, + intern_ingredient_memo_types: Option<&mut Arc>, ) -> Self { let &[struct_ingredient] = &*indices.indices else { unreachable!("Attempting to construct struct memo mapping from enum?") }; + let memo_ingredient_index = zalsa.next_memo_ingredient_index(struct_ingredient, ingredient); let memo_types = intern_ingredient_memo_types.unwrap_or_else(|| { - zalsa - .lookup_ingredient(struct_ingredient) - .memo_table_types() + let (struct_ingredient, _) = zalsa.lookup_ingredient_mut(struct_ingredient); + struct_ingredient.memo_table_types_mut() }); - let mi = zalsa.next_memo_ingredient_index(struct_ingredient, ingredient); - memo_types.set(mi, &memo_type); - Self(mi) + Arc::get_mut(memo_types) + .expect("memo tables are not shared until database initialization is complete") + .set(memo_ingredient_index, memo_type); + + Self(memo_ingredient_index) } } diff --git a/src/salsa_struct.rs b/src/salsa_struct.rs index 80ba48795..725b308a4 100644 --- a/src/salsa_struct.rs +++ b/src/salsa_struct.rs @@ -16,7 +16,7 @@ pub trait SalsaStructInDb: Sized { /// While implementors of this trait may call [`crate::zalsa::JarEntry::get_or_create`] /// to create the ingredient, they aren't required to. For example, supertypes recursively /// call [`crate::zalsa::JarEntry::get_or_create`] for their variants and combine them. - fn lookup_or_create_ingredient_index(zalsa: &Zalsa) -> IngredientIndices; + fn lookup_ingredient_index(zalsa: &Zalsa) -> IngredientIndices; /// Plumbing to support nested salsa supertypes. 
/// diff --git a/src/storage.rs b/src/storage.rs index 19dd55a40..a8c2abec0 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -3,7 +3,7 @@ use std::marker::PhantomData; use std::panic::RefUnwindSafe; use crate::sync::{Arc, Condvar, Mutex}; -use crate::zalsa::{Zalsa, ZalsaDatabase}; +use crate::zalsa::{ErasedJar, HasJar, Zalsa, ZalsaDatabase}; use crate::zalsa_local::{self, ZalsaLocal}; use crate::{Database, Event, EventKind}; @@ -42,8 +42,15 @@ impl Default for StorageHandle { impl StorageHandle { pub fn new(event_callback: Option>) -> Self { + Self::with_jars(event_callback, Vec::new()) + } + + fn with_jars( + event_callback: Option>, + jars: Vec, + ) -> Self { Self { - zalsa_impl: Arc::new(Zalsa::new::(event_callback)), + zalsa_impl: Arc::new(Zalsa::new::(event_callback, jars)), coordinate: CoordinateDrop(Arc::new(Coordinate { clones: Mutex::new(1), cvar: Default::default(), @@ -115,6 +122,11 @@ impl Storage { } } + /// Returns a builder for database storage. + pub fn builder() -> StorageBuilder { + StorageBuilder::default() + } + /// Convert this instance of [`Storage`] into a [`StorageHandle`]. /// /// This will discard the local state of this [`Storage`], thereby returning a value that @@ -168,6 +180,54 @@ impl Storage { // ANCHOR_END: cancel_other_workers } +/// A builder for a [`Storage`] instance. +/// +/// This type can be created with the [`Storage::builder`] function. +pub struct StorageBuilder { + jars: Vec, + event_callback: Option>, + _db: PhantomData, +} + +impl Default for StorageBuilder { + fn default() -> Self { + Self { + jars: Vec::new(), + event_callback: None, + _db: PhantomData, + } + } +} + +impl StorageBuilder { + /// Set a callback for salsa events. + /// + /// The `event_callback` function will be invoked by the salsa runtime at various points during execution. + pub fn event_callback( + mut self, + callback: Box, + ) -> Self { + self.event_callback = Some(callback); + self + } + + /// Manually register an ingredient. 
+ /// + /// Manual ingredient registration is necessary when the `inventory` feature is disabled. + pub fn ingredient(mut self) -> Self { + self.jars.push(ErasedJar::erase::()); + self + } + + /// Construct the [`Storage`] using the provided builder options. + pub fn build(self) -> Storage { + Storage { + handle: StorageHandle::with_jars(self.event_callback, self.jars), + zalsa_local: ZalsaLocal::new(), + } + } +} + #[allow(clippy::undocumented_unsafe_blocks)] // TODO(#697) document safety unsafe impl ZalsaDatabase for T { #[inline(always)] diff --git a/src/sync.rs b/src/sync.rs index b9e2d258d..e3472d2da 100644 --- a/src/sync.rs +++ b/src/sync.rs @@ -5,40 +5,6 @@ pub mod shim { pub use shuttle::sync::*; pub use shuttle::{thread, thread_local}; - pub mod papaya { - use std::hash::{BuildHasher, Hash}; - use std::marker::PhantomData; - - pub struct HashMap(super::Mutex>); - - impl Default for HashMap { - fn default() -> Self { - Self(super::Mutex::default()) - } - } - - pub struct LocalGuard<'a>(PhantomData<&'a ()>); - - impl HashMap - where - K: Eq + Hash, - V: Clone, - S: BuildHasher, - { - pub fn guard(&self) -> LocalGuard<'_> { - LocalGuard(PhantomData) - } - - pub fn get(&self, key: &K, _guard: &LocalGuard<'_>) -> Option { - self.0.lock().get(key).cloned() - } - - pub fn insert(&self, key: K, value: V, _guard: &LocalGuard<'_>) { - self.0.lock().insert(key, value); - } - } - } - /// A wrapper around shuttle's `Mutex` to mirror parking-lot's API. #[derive(Default, Debug)] pub struct Mutex(shuttle::sync::Mutex); @@ -57,24 +23,6 @@ pub mod shim { } } - /// A wrapper around shuttle's `RwLock` to mirror parking-lot's API. 
- #[derive(Default, Debug)] - pub struct RwLock(shuttle::sync::RwLock); - - impl RwLock { - pub fn read(&self) -> RwLockReadGuard<'_, T> { - self.0.read().unwrap() - } - - pub fn write(&self) -> RwLockWriteGuard<'_, T> { - self.0.write().unwrap() - } - - pub fn get_mut(&mut self) -> &mut T { - self.0.get_mut().unwrap() - } - } - /// A wrapper around shuttle's `Condvar` to mirror parking-lot's API. #[derive(Default, Debug)] pub struct Condvar(shuttle::sync::Condvar); @@ -164,7 +112,7 @@ pub mod shim { #[cfg(not(feature = "shuttle"))] pub mod shim { - pub use parking_lot::{Mutex, MutexGuard, RwLock}; + pub use parking_lot::{Mutex, MutexGuard}; pub use std::sync::*; pub use std::{thread, thread_local}; @@ -173,48 +121,6 @@ pub mod shim { pub use std::sync::atomic::*; } - pub mod papaya { - use std::hash::{BuildHasher, Hash}; - - pub use papaya::LocalGuard; - - pub struct HashMap(papaya::HashMap); - - impl Default for HashMap { - fn default() -> Self { - Self( - papaya::HashMap::builder() - .capacity(256) // A relatively large capacity to hopefully avoid resizing. - .resize_mode(papaya::ResizeMode::Blocking) - .hasher(S::default()) - .build(), - ) - } - } - - impl HashMap - where - K: Eq + Hash, - V: Clone, - S: BuildHasher, - { - #[inline] - pub fn guard(&self) -> LocalGuard<'_> { - self.0.guard() - } - - #[inline] - pub fn get(&self, key: &K, guard: &LocalGuard<'_>) -> Option { - self.0.get(key, guard).cloned() - } - - #[inline] - pub fn insert(&self, key: K, value: V, guard: &LocalGuard<'_>) { - self.0.insert(key, value, guard); - } - } - } - /// A wrapper around parking-lot's `Condvar` to mirror shuttle's API. 
pub struct Condvar(parking_lot::Condvar); diff --git a/src/table/memo.rs b/src/table/memo.rs index 2ee10134f..63d3671ad 100644 --- a/src/table/memo.rs +++ b/src/table/memo.rs @@ -3,19 +3,32 @@ use std::fmt::Debug; use std::mem; use std::ptr::{self, NonNull}; -use portable_atomic::hint::spin_loop; -use thin_vec::ThinVec; - use crate::sync::atomic::{AtomicPtr, Ordering}; -use crate::sync::{OnceLock, RwLock}; use crate::{zalsa::MemoIngredientIndex, zalsa_local::QueryOriginRef}; /// The "memo table" stores the memoized results of tracked function calls. /// Every tracked function must take a salsa struct as its first argument /// and memo tables are attached to those salsa structs as auxiliary data. -#[derive(Default)] pub(crate) struct MemoTable { - memos: RwLock>, + memos: Box<[MemoEntry]>, +} + +impl MemoTable { + /// Create a `MemoTable` with slots for memos from the provided `MemoTableTypes`. + pub fn new(types: &MemoTableTypes) -> Self { + Self { + memos: (0..types.len()).map(|_| MemoEntry::default()).collect(), + } + } + + /// Reset any memos in the table. + /// + /// Note that the memo entries should be freed manually before calling this function. 
+ pub fn reset(&mut self) { + for memo in &mut self.memos { + *memo = MemoEntry::default(); + } + } } pub trait Memo: Any + Send + Sync { @@ -50,13 +63,8 @@ struct MemoEntry { atomic_memo: AtomicPtr, } -#[derive(Default)] -pub struct MemoEntryType { - data: OnceLock, -} - #[derive(Clone, Copy, Debug)] -struct MemoEntryTypeData { +pub struct MemoEntryType { /// The `type_id` of the erased memo type `M` type_id: TypeId, @@ -89,17 +97,10 @@ impl MemoEntryType { #[inline] pub fn of() -> Self { Self { - data: OnceLock::from(MemoEntryTypeData { - type_id: TypeId::of::(), - to_dyn_fn: Self::to_dyn_fn::(), - }), + type_id: TypeId::of::(), + to_dyn_fn: Self::to_dyn_fn::(), } } - - #[inline] - fn load(&self) -> Option<&MemoEntryTypeData> { - self.data.get() - } } /// Dummy placeholder type that we use when erasing the memo type `M` in [`MemoEntryData`][]. @@ -127,43 +128,21 @@ impl Memo for DummyMemo { #[derive(Default)] pub struct MemoTableTypes { - types: boxcar::Vec, + types: Vec, } impl MemoTableTypes { pub(crate) fn set( - &self, + &mut self, memo_ingredient_index: MemoIngredientIndex, - memo_type: &MemoEntryType, + memo_type: MemoEntryType, ) { - let memo_ingredient_index = memo_ingredient_index.as_usize(); - - // Try to create our entry if it has not already been created. - if memo_ingredient_index >= self.types.count() { - while self.types.push(MemoEntryType::default()) < memo_ingredient_index {} - } - - loop { - let Some(memo_entry_type) = self.types.get(memo_ingredient_index) else { - // It's possible that someone else began pushing to our index but has not - // completed the entry's initialization yet, as `boxcar` is lock-free. This - // is extremely unlikely given initialization is just a handful of instructions. - // Additionally, this function is generally only called on startup, so we can - // just spin here. 
- spin_loop(); - continue; - }; + self.types + .insert(memo_ingredient_index.as_usize(), memo_type); + } - memo_entry_type - .data - .set( - *memo_type.data.get().expect( - "cannot provide an empty `MemoEntryType` for `MemoEntryType::set()`", - ), - ) - .expect("memo type should only be set once"); - break; - } + pub fn len(&self) -> usize { + self.types.len() } /// # Safety @@ -204,59 +183,25 @@ impl MemoTableWithTypes<'_> { assert_eq!( self.types .types - .get(memo_ingredient_index.as_usize()) - .and_then(MemoEntryType::load)? + .get(memo_ingredient_index.as_usize())? .type_id, TypeId::of::(), "inconsistent type-id for `{memo_ingredient_index:?}`" ); - // If the memo slot is already occupied, it must already have the - // right type info etc, and we only need the read-lock. - if let Some(MemoEntry { atomic_memo }) = self + // The memo table is pre-sized on creation based on the corresponding `MemoTableTypes`. + let MemoEntry { atomic_memo } = self .memos .memos - .read() .get(memo_ingredient_index.as_usize()) - { - let old_memo = - atomic_memo.swap(MemoEntryType::to_dummy(memo).as_ptr(), Ordering::AcqRel); - - let old_memo = NonNull::new(old_memo); + .expect("accessed memo table with invalid index"); - // SAFETY: `type_id` check asserted above - return old_memo.map(|old_memo| unsafe { MemoEntryType::from_dummy(old_memo) }); - } - - // Otherwise we need the write lock. - self.insert_cold(memo_ingredient_index, memo) - } - - #[cold] - fn insert_cold( - self, - memo_ingredient_index: MemoIngredientIndex, - memo: NonNull, - ) -> Option> { - let memo_ingredient_index = memo_ingredient_index.as_usize(); - let mut memos = self.memos.memos.write(); - - // Grow the table if needed. 
- if memos.len() <= memo_ingredient_index { - let additional_len = memo_ingredient_index - memos.len() + 1; - memos.reserve(additional_len); - while memos.len() <= memo_ingredient_index { - memos.push(MemoEntry::default()); - } - } + let old_memo = atomic_memo.swap(MemoEntryType::to_dummy(memo).as_ptr(), Ordering::AcqRel); - let old_entry = mem::replace( - memos[memo_ingredient_index].atomic_memo.get_mut(), - MemoEntryType::to_dummy(memo).as_ptr(), - ); + let old_memo = NonNull::new(old_memo); - // SAFETY: The `TypeId` is asserted in `insert()`. - NonNull::new(old_entry).map(|memo| unsafe { MemoEntryType::from_dummy(memo) }) + // SAFETY: `type_id` check asserted above + old_memo.map(|old_memo| unsafe { MemoEntryType::from_dummy(old_memo) }) } #[inline] @@ -264,13 +209,8 @@ impl MemoTableWithTypes<'_> { self, memo_ingredient_index: MemoIngredientIndex, ) -> Option> { - let read = self.memos.memos.read(); - let memo = read.get(memo_ingredient_index.as_usize())?; - let type_ = self - .types - .types - .get(memo_ingredient_index.as_usize()) - .and_then(MemoEntryType::load)?; + let memo = self.memos.memos.get(memo_ingredient_index.as_usize())?; + let type_ = self.types.types.get(memo_ingredient_index.as_usize())?; assert_eq!( type_.type_id, TypeId::of::(), @@ -284,13 +224,12 @@ impl MemoTableWithTypes<'_> { #[cfg(feature = "salsa_unstable")] pub(crate) fn memory_usage(&self) -> Vec { let mut memory_usage = Vec::new(); - let memos = self.memos.memos.read(); - for (index, memo) in memos.iter().enumerate() { + for (index, memo) in self.memos.memos.iter().enumerate() { let Some(memo) = NonNull::new(memo.atomic_memo.load(Ordering::Acquire)) else { continue; }; - let Some(type_) = self.types.types.get(index).and_then(MemoEntryType::load) else { + let Some(type_) = self.types.types.get(index) else { continue; }; @@ -317,12 +256,7 @@ impl MemoTableWithTypesMut<'_> { memo_ingredient_index: MemoIngredientIndex, f: impl FnOnce(&mut M), ) { - let Some(type_) = self - .types - 
.types - .get(memo_ingredient_index.as_usize()) - .and_then(MemoEntryType::load) - else { + let Some(type_) = self.types.types.get(memo_ingredient_index.as_usize()) else { return; }; assert_eq!( @@ -331,13 +265,13 @@ impl MemoTableWithTypesMut<'_> { "inconsistent type-id for `{memo_ingredient_index:?}`" ); - // If the memo slot is already occupied, it must already have the - // right type info etc, and we only need the read-lock. - let memos = self.memos.memos.get_mut(); - let Some(MemoEntry { atomic_memo }) = memos.get_mut(memo_ingredient_index.as_usize()) + // The memo table is pre-sized on creation based on the corresponding `MemoTableTypes`. + let Some(MemoEntry { atomic_memo }) = + self.memos.memos.get_mut(memo_ingredient_index.as_usize()) else { return; }; + let Some(memo) = NonNull::new(*atomic_memo.get_mut()) else { return; }; @@ -357,7 +291,7 @@ impl MemoTableWithTypesMut<'_> { #[inline] pub unsafe fn drop(&mut self) { let types = self.types.types.iter(); - for ((_, type_), memo) in std::iter::zip(types, self.memos.memos.get_mut()) { + for (type_, memo) in std::iter::zip(types, &mut self.memos.memos) { // SAFETY: The types match as per our constructor invariant. unsafe { memo.take(type_) }; } @@ -371,12 +305,12 @@ impl MemoTableWithTypesMut<'_> { &mut self, mut f: impl FnMut(MemoIngredientIndex, Box), ) { - let memos = self.memos.memos.get_mut(); - memos + self.memos + .memos .iter_mut() .zip(self.types.types.iter()) .enumerate() - .filter_map(|(index, (memo, (_, type_)))| { + .filter_map(|(index, (memo, type_))| { // SAFETY: The types match as per our constructor invariant. let memo = unsafe { memo.take(type_)? }; Some((MemoIngredientIndex::from_usize(index), memo)) @@ -393,7 +327,6 @@ impl MemoEntry { unsafe fn take(&mut self, type_: &MemoEntryType) -> Option> { let memo = mem::replace(self.atomic_memo.get_mut(), ptr::null_mut()); let memo = NonNull::new(memo)?; - let type_ = type_.load()?; // SAFETY: Our preconditions. 
Some(unsafe { Box::from_raw((type_.to_dyn_fn)(memo).as_ptr()) }) } diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index ef7f9926f..fd93250aa 100644 --- a/src/tracked_struct.rs +++ b/src/tracked_struct.rs @@ -110,9 +110,8 @@ impl Default for JarImpl { impl Jar for JarImpl { fn create_ingredients( - _zalsa: &Zalsa, + _zalsa: &mut Zalsa, struct_index: crate::zalsa::IngredientIndex, - _dependencies: crate::memo_ingredient_indices::IngredientIndices, ) -> Vec> { let struct_ingredient = >::new(struct_index); @@ -444,7 +443,7 @@ where // lifetime erase for storage fields: unsafe { mem::transmute::, C::Fields<'static>>(fields) }, revisions: C::new_revisions(current_deps.changed_at), - memos: Default::default(), + memos: MemoTable::new(self.memo_table_types()), }; while let Some(id) = self.free_list.pop() { @@ -601,11 +600,11 @@ where // Note that we hold the lock and have exclusive access to the tracked struct data, // so there should be no live instances of IDs from the previous generation. We clear // the memos and return a new ID here as if we have allocated a new slot. - let mut table = data.take_memo_table(); + let memo_table = data.memo_table_mut(); // SAFETY: The memo table belongs to a value that we allocated, so it has the // correct type. - unsafe { self.clear_memos(zalsa, &mut table, id) }; + unsafe { self.clear_memos(zalsa, memo_table, id) }; id = id .next_generation() @@ -674,11 +673,11 @@ where // SAFETY: We have acquired the write lock let data = unsafe { &mut *data_raw }; - let mut memo_table = data.take_memo_table(); + let memo_table = data.memo_table_mut(); // SAFETY: The memo table belongs to a value that we allocated, so it // has the correct type. 
- unsafe { self.clear_memos(zalsa, &mut memo_table, id) }; + unsafe { self.clear_memos(zalsa, memo_table, id) }; // now that all cleanup has occurred, make available for re-use self.free_list.push(id); @@ -724,6 +723,9 @@ where }; mem::forget(table_guard); + + // Reset the table after having dropped any memos. + memo_table.reset(); } /// Return reference to the field data ignoring dependency tracking. @@ -849,8 +851,12 @@ where C::DEBUG_NAME } - fn memo_table_types(&self) -> Arc { - self.memo_table_types.clone() + fn memo_table_types(&self) -> &Arc { + &self.memo_table_types + } + + fn memo_table_types_mut(&mut self) -> &mut Arc { + &mut self.memo_table_types } /// Returns memory usage information about any tracked structs. @@ -891,13 +897,12 @@ where unsafe { mem::transmute::<&C::Fields<'static>, &C::Fields<'_>>(&self.fields) } } - fn take_memo_table(&mut self) -> MemoTable { + fn memo_table_mut(&mut self) -> &mut MemoTable { // This fn is only called after `updated_at` has been set to `None`; // this ensures that there is no concurrent access // (and that the `&mut self` is accurate...). assert!(self.updated_at.load().is_none()); - - mem::take(&mut self.memos) + &mut self.memos } fn read_lock(&self, current_revision: Revision) { diff --git a/src/tracked_struct/tracked_field.rs b/src/tracked_struct/tracked_field.rs index 5ec38c680..ad3e871e8 100644 --- a/src/tracked_struct/tracked_field.rs +++ b/src/tracked_struct/tracked_field.rs @@ -82,7 +82,11 @@ where C::TRACKED_FIELD_NAMES[self.field_index] } - fn memo_table_types(&self) -> Arc { + fn memo_table_types(&self) -> &Arc { + unreachable!("tracked field does not allocate pages") + } + + fn memo_table_types_mut(&mut self) -> &mut Arc { unreachable!("tracked field does not allocate pages") } } diff --git a/src/views.rs b/src/views.rs index a14852898..01a0a2de5 100644 --- a/src/views.rs +++ b/src/views.rs @@ -80,16 +80,16 @@ impl Views { } /// Add a new downcaster from `dyn Database` to `dyn DbView`. 
- pub fn add(&self, func: DatabaseDownCasterSig) { - let target_type_id = TypeId::of::(); - if self - .view_casters - .iter() - .any(|(_, u)| u.target_type_id == target_type_id) - { - return; + pub fn add( + &self, + func: DatabaseDownCasterSig, + ) -> DatabaseDownCaster { + if let Some(view) = self.try_downcaster_for() { + return view; } + self.view_casters.push(ViewCaster::new::(func)); + DatabaseDownCaster(self.source_type_id, func) } /// Retrieve an downcaster function from `dyn Database` to `dyn DbView`. @@ -98,23 +98,31 @@ impl Views { /// /// If the underlying type of `db` is not the same as the database type this upcasts was created for. pub fn downcaster_for(&self) -> DatabaseDownCaster { + self.try_downcaster_for().unwrap_or_else(|| { + panic!( + "No downcaster registered for type `{}` in `Views`", + std::any::type_name::(), + ) + }) + } + + /// Retrieve an downcaster function from `dyn Database` to `dyn DbView`, if it exists. + #[inline] + pub fn try_downcaster_for(&self) -> Option> { let view_type_id = TypeId::of::(); - for (_idx, view) in self.view_casters.iter() { + for (_, view) in self.view_casters.iter() { if view.target_type_id == view_type_id { // SAFETY: We are unerasing the type erased function pointer having made sure the - // TypeId matches. - return DatabaseDownCaster(self.source_type_id, unsafe { + // `TypeId` matches. 
+ return Some(DatabaseDownCaster(self.source_type_id, unsafe { std::mem::transmute::>( view.cast, ) - }); + })); } } - panic!( - "No downcaster registered for type `{}` in `Views`", - std::any::type_name::(), - ); + None } } diff --git a/src/zalsa.rs b/src/zalsa.rs index cc23881cd..c1c46296d 100644 --- a/src/zalsa.rs +++ b/src/zalsa.rs @@ -1,18 +1,13 @@ use std::any::{Any, TypeId}; use std::hash::BuildHasherDefault; -use std::marker::PhantomData; -use std::mem; -use std::num::NonZeroU32; use std::panic::RefUnwindSafe; +use hashbrown::HashMap; use rustc_hash::FxHashMap; use crate::hash::TypeIdHasher; use crate::ingredient::{Ingredient, Jar}; -use crate::nonce::{Nonce, NonceGenerator}; use crate::runtime::Runtime; -use crate::sync::atomic::{AtomicU64, Ordering}; -use crate::sync::{papaya, Mutex, RwLock}; use crate::table::memo::MemoTableWithTypes; use crate::table::Table; use crate::views::Views; @@ -62,13 +57,14 @@ pub unsafe trait ZalsaDatabase: Any { pub fn views(db: &Db) -> &Views { db.zalsa().views() } - /// Nonce type representing the underlying database storage. #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] +#[cfg(not(feature = "inventory"))] pub struct StorageNonce; // Generator for storage nonces. -static NONCE: NonceGenerator = NonceGenerator::new(); +#[cfg(not(feature = "inventory"))] +static NONCE: crate::nonce::NonceGenerator = crate::nonce::NonceGenerator::new(); /// An ingredient index identifies a particular [`Ingredient`] in the database. /// @@ -83,10 +79,16 @@ impl IngredientIndex { /// This reserves one bit for an optional tag. const MAX_INDEX: u32 = 0x7FFF_FFFF; - /// Create an ingredient index from a `usize`. - pub(crate) fn from(v: usize) -> Self { - assert!(v <= Self::MAX_INDEX as usize); - Self(v as u32) + /// Create an ingredient index from a `u32`. 
+ pub(crate) fn from(v: u32) -> Self { + assert!(v <= Self::MAX_INDEX); + Self(v) + } + + /// Create an ingredient index from a `u32`, without performing validating + /// that the index is valid. + pub(crate) fn from_unchecked(v: u32) -> Self { + Self(v) } /// Convert the ingredient index back into a `u32`. @@ -134,28 +136,24 @@ impl MemoIngredientIndex { pub struct Zalsa { views_of: Views, - nonce: Nonce, + #[cfg(not(feature = "inventory"))] + nonce: crate::nonce::Nonce, /// Map from the [`IngredientIndex::as_usize`][] of a salsa struct to a list of /// [ingredient-indices](`IngredientIndex`) for tracked functions that have this salsa struct /// as input. - memo_ingredient_indices: RwLock>>, + memo_ingredient_indices: Vec>, /// Map from the type-id of an `impl Jar` to the index of its first ingredient. - jar_map: papaya::HashMap>, - - /// The write-lock for `jar_map`. - jar_map_lock: Mutex<()>, + jar_map: HashMap>, /// A map from the `IngredientIndex` to the `TypeId` of its ID struct. /// /// Notably this is not the reverse mapping of `jar_map`. - ingredient_to_id_struct_type_id_map: RwLock>, + ingredient_to_id_struct_type_id_map: FxHashMap, /// Vector of ingredients. - /// - /// Immutable unless the mutex on `ingredients_map` is held. - ingredients_vec: boxcar::Vec>, + ingredients_vec: Vec>, /// Indices of ingredients that require reset when a new revision starts. 
ingredients_requiring_reset: boxcar::Vec, @@ -177,22 +175,43 @@ impl RefUnwindSafe for Zalsa {} impl Zalsa { pub(crate) fn new( event_callback: Option>, + jars: Vec, ) -> Self { - Self { + let mut zalsa = Self { views_of: Views::new::(), - nonce: NONCE.nonce(), - jar_map: papaya::HashMap::default(), - jar_map_lock: Mutex::default(), + jar_map: HashMap::default(), ingredient_to_id_struct_type_id_map: Default::default(), - ingredients_vec: boxcar::Vec::new(), + ingredients_vec: Vec::new(), ingredients_requiring_reset: boxcar::Vec::new(), runtime: Runtime::default(), memo_ingredient_indices: Default::default(), event_callback, + #[cfg(not(feature = "inventory"))] + nonce: NONCE.nonce(), + }; + + // Collect and initialize all registered ingredients. + #[cfg(feature = "inventory")] + let mut jars = inventory::iter::() + .copied() + .chain(jars) + .collect::>(); + + #[cfg(not(feature = "inventory"))] + let mut jars = jars; + + // Ensure structs are initialized before tracked functions. + jars.sort_by_key(|jar| jar.kind); + + for jar in jars { + zalsa.insert_jar(jar); } + + zalsa } - pub(crate) fn nonce(&self) -> Nonce { + #[cfg(not(feature = "inventory"))] + pub(crate) fn nonce(&self) -> crate::nonce::Nonce { self.nonce } @@ -218,7 +237,7 @@ impl Zalsa { } #[inline] - pub(crate) fn lookup_ingredient(&self, index: IngredientIndex) -> &dyn Ingredient { + pub fn lookup_ingredient(&self, index: IngredientIndex) -> &dyn Ingredient { let index = index.as_u32() as usize; self.ingredients_vec .get(index) @@ -231,7 +250,7 @@ impl Zalsa { struct_ingredient_index: IngredientIndex, memo_ingredient_index: MemoIngredientIndex, ) -> IngredientIndex { - self.memo_ingredient_indices.read()[struct_ingredient_index.as_u32() as usize] + self.memo_ingredient_indices[struct_ingredient_index.as_u32() as usize] [memo_ingredient_index.as_usize()] } @@ -239,7 +258,7 @@ impl Zalsa { pub(crate) fn ingredients(&self) -> impl Iterator { self.ingredients_vec .iter() - .map(|(_, ingredient)| 
ingredient.as_ref()) + .map(|ingredient| ingredient.as_ref()) } /// Starts unwinding the stack if the current revision is cancelled. @@ -259,11 +278,11 @@ impl Zalsa { } pub(crate) fn next_memo_ingredient_index( - &self, + &mut self, struct_ingredient_index: IngredientIndex, ingredient_index: IngredientIndex, ) -> MemoIngredientIndex { - let mut memo_ingredients = self.memo_ingredient_indices.write(); + let memo_ingredients = &mut self.memo_ingredient_indices; let idx = struct_ingredient_index.as_u32() as usize; let memo_ingredients = if let Some(memo_ingredients) = memo_ingredients.get_mut(idx) { memo_ingredients @@ -291,7 +310,6 @@ impl Zalsa { let ingredient_index = self.ingredient_index(id); *self .ingredient_to_id_struct_type_id_map - .read() .get(&ingredient_index) .expect("should have the ingredient index available") } @@ -299,44 +317,36 @@ impl Zalsa { /// **NOT SEMVER STABLE** #[doc(hidden)] #[inline] - pub fn lookup_jar_by_type(&self) -> JarEntry<'_, J> { + pub fn lookup_jar_by_type(&self) -> IngredientIndex { let jar_type_id = TypeId::of::(); - let guard = self.jar_map.guard(); - - match self.jar_map.get(&jar_type_id, &guard) { - Some(index) => JarEntry::Occupied(index), - None => JarEntry::Vacant { - guard, - zalsa: self, - _jar: PhantomData, - }, - } - } - #[cold] - #[inline(never)] - fn add_or_lookup_jar_by_type(&self, guard: &papaya::LocalGuard<'_>) -> IngredientIndex { - let jar_type_id = TypeId::of::(); - let dependencies = J::create_dependencies(self); + *self.jar_map.get(&jar_type_id).unwrap_or_else(|| { + panic!( + "ingredient `{}` was not registered", + std::any::type_name::() + ) + }) + } - let jar_map_lock = self.jar_map_lock.lock(); + fn insert_jar(&mut self, jar: ErasedJar) { + let jar_type_id = (jar.type_id)(); - let index = IngredientIndex::from(self.ingredients_vec.count()); + let index = IngredientIndex::from(self.ingredients_vec.len() as u32); - // Someone made it earlier than us. 
- if let Some(index) = self.jar_map.get(&jar_type_id, guard) { - return index; - }; + if self.jar_map.contains_key(&jar_type_id) { + return; + } - let ingredients = J::create_ingredients(self, index, dependencies); + let ingredients = (jar.create_ingredients)(self, index); for ingredient in ingredients { let expected_index = ingredient.ingredient_index(); - if ingredient.requires_reset_for_new_revision() { self.ingredients_requiring_reset.push(expected_index); } - let actual_index = self.ingredients_vec.push(ingredient); + self.ingredients_vec.push(ingredient); + + let actual_index = self.ingredients_vec.len() - 1; assert_eq!( expected_index.as_u32() as usize, actual_index, @@ -347,17 +357,10 @@ impl Zalsa { ); } - // Insert the index after all ingredients are inserted to avoid exposing - // partially initialized jars to readers. - self.jar_map.insert(jar_type_id, index, guard); - - drop(jar_map_lock); + self.jar_map.insert(jar_type_id, index); self.ingredient_to_id_struct_type_id_map - .write() - .insert(index, J::id_struct_type_id()); - - index + .insert(index, (jar.id_struct_type_id)()); } /// **NOT SEMVER STABLE** @@ -434,139 +437,69 @@ impl Zalsa { } } -pub enum JarEntry<'a, J> { - Occupied(IngredientIndex), - Vacant { - zalsa: &'a Zalsa, - guard: papaya::LocalGuard<'a>, - _jar: PhantomData, - }, -} - -impl JarEntry<'_, J> -where - J: Jar, -{ - #[inline] - pub fn get(&self) -> Option { - match *self { - JarEntry::Occupied(index) => Some(index), - JarEntry::Vacant { .. } => None, - } - } - - #[inline] - pub fn get_or_create(&self) -> IngredientIndex { - match self { - JarEntry::Occupied(index) => *index, - JarEntry::Vacant { zalsa, guard, _jar } => zalsa.add_or_lookup_jar_by_type::(guard), - } - } -} - -/// Caches a pointer to an ingredient in a database. -/// Optimized for the case of a single database. -pub struct IngredientCache -where - I: Ingredient, -{ - // A packed representation of `Option<(Nonce, IngredientIndex)>`. 
- // - // This allows us to replace a lock in favor of an atomic load. This works thanks to `Nonce` - // having a niche, which means the entire type can fit into an `AtomicU64`. - cached_data: AtomicU64, - phantom: PhantomData I>, +/// A type-erased `Jar`, used for ingredient registration. +#[derive(Clone, Copy)] +pub struct ErasedJar { + kind: JarKind, + type_id: fn() -> TypeId, + id_struct_type_id: fn() -> TypeId, + create_ingredients: fn(&mut Zalsa, IngredientIndex) -> Vec>, } -impl Default for IngredientCache -where - I: Ingredient, -{ - fn default() -> Self { - Self::new() - } +/// The kind of an `Jar`. +/// +/// Note that the ordering of the variants is important. Struct ingredients must be +/// initialized before tracked functions, as tracked function ingredients depend on +/// their input struct. +#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Debug)] +pub enum JarKind { + /// An input/tracked/interned struct. + Struct, + + /// A tracked function. + TrackedFn, } -impl IngredientCache -where - I: Ingredient, -{ - const UNINITIALIZED: u64 = 0; - - /// Create a new cache - pub const fn new() -> Self { +impl ErasedJar { + /// Performs type-erasure of a given ingredient. + pub const fn erase() -> Self { Self { - cached_data: AtomicU64::new(Self::UNINITIALIZED), - phantom: PhantomData, + kind: I::KIND, + type_id: TypeId::of::, + create_ingredients: ::create_ingredients, + id_struct_type_id: ::id_struct_type_id, } } +} - /// Get a reference to the ingredient in the database. - /// If the ingredient is not already in the cache, it will be created. - #[inline(always)] - pub fn get_or_create<'db>( - &self, - zalsa: &'db Zalsa, - create_index: impl Fn() -> IngredientIndex, - ) -> &'db I { - let index = self.get_or_create_index(zalsa, create_index); - zalsa.lookup_ingredient(index).assert_type::() - } +/// A salsa ingredient that can be registered in the database. +/// +/// This trait is implemented for tracked functions and salsa structs. 
+pub trait HasJar { + /// The [`Jar`] associated with this ingredient. + type Jar: Jar; - /// Get a reference to the ingredient in the database. - /// If the ingredient is not already in the cache, it will be created. - #[inline(always)] - pub fn get_or_create_index( - &self, - zalsa: &Zalsa, - create_index: impl Fn() -> IngredientIndex, - ) -> IngredientIndex { - const _: () = assert!( - mem::size_of::<(Nonce, IngredientIndex)>() == mem::size_of::() - ); - let cached_data = self.cached_data.load(Ordering::Acquire); - if cached_data == Self::UNINITIALIZED { - #[cold] - #[inline(never)] - fn get_or_create_index_slow( - this: &IngredientCache, - zalsa: &Zalsa, - create_index: impl Fn() -> IngredientIndex, - ) -> IngredientIndex { - let index = create_index(); - let nonce = zalsa.nonce().into_u32().get() as u64; - let packed = (nonce << u32::BITS) | (index.as_u32() as u64); - debug_assert_ne!(packed, IngredientCache::::UNINITIALIZED); - - // Discard the result, whether we won over the cache or not does not matter - // we know that something has been cached now - _ = this.cached_data.compare_exchange( - IngredientCache::::UNINITIALIZED, - packed, - Ordering::Release, - Ordering::Acquire, - ); - // and we already have our index computed so we can just use that - index - } + /// The [`JarKind`] for `Self::Jar`. + const KIND: JarKind; +} - return get_or_create_index_slow(self, zalsa, create_index); - }; +// Collect jars statically at compile-time if supported. 
+#[cfg(feature = "inventory")] +inventory::collect!(ErasedJar); - // unpack our u64 - // SAFETY: We've checked against `UNINITIALIZED` (0) above and so the upper bits must be non-zero - let nonce = Nonce::::from_u32(unsafe { - NonZeroU32::new_unchecked((cached_data >> u32::BITS) as u32) - }); - let mut index = IngredientIndex(cached_data as u32); +#[cfg(feature = "inventory")] +pub use inventory::submit as register_jar; - if zalsa.nonce() != nonce { - index = create_index(); - } - index - } +#[cfg(not(feature = "inventory"))] +#[macro_export] +#[doc(hidden)] +macro_rules! register_jar { + ($($_:tt)*) => {}; } +#[cfg(not(feature = "inventory"))] +pub use crate::register_jar; + /// Given a wide pointer `T`, extracts the data pointer (typed as `U`). /// /// # Safety diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs index 80e24e7f9..51c28c9c5 100644 --- a/src/zalsa_local.rs +++ b/src/zalsa_local.rs @@ -754,7 +754,7 @@ impl QueryOrigin { QueryOriginKind::Assigned => { // SAFETY: `data.index` is initialized when the tag is `QueryOriginKind::Assigned`. let index = unsafe { self.data.index }; - let ingredient_index = IngredientIndex::from(self.metadata as usize); + let ingredient_index = IngredientIndex::from(self.metadata); QueryOriginRef::Assigned(DatabaseKeyIndex::new(ingredient_index, index)) } diff --git a/tests/accumulate-chain.rs b/tests/accumulate-chain.rs index d51e67c14..18d4bb56a 100644 --- a/tests/accumulate-chain.rs +++ b/tests/accumulate-chain.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Test that when having nested tracked functions //! we don't drop any values when accumulating. 
diff --git a/tests/accumulate-custom-debug.rs b/tests/accumulate-custom-debug.rs index a4c078ab7..180156042 100644 --- a/tests/accumulate-custom-debug.rs +++ b/tests/accumulate-custom-debug.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + mod common; use expect_test::expect; diff --git a/tests/accumulate-dag.rs b/tests/accumulate-dag.rs index 6786d0f8e..41d9b3908 100644 --- a/tests/accumulate-dag.rs +++ b/tests/accumulate-dag.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + mod common; use expect_test::expect; diff --git a/tests/accumulate-execution-order.rs b/tests/accumulate-execution-order.rs index 28aeb5d7b..1a0d3e233 100644 --- a/tests/accumulate-execution-order.rs +++ b/tests/accumulate-execution-order.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Demonstrates that accumulation is done in the order //! in which things were originally executed. diff --git a/tests/accumulate-from-tracked-fn.rs b/tests/accumulate-from-tracked-fn.rs index 5fba8a688..67e591688 100644 --- a/tests/accumulate-from-tracked-fn.rs +++ b/tests/accumulate-from-tracked-fn.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Accumulate values from within a tracked function. //! Then mutate the values so that the tracked function re-executes. //! Check that we accumulate the appropriate, new values. diff --git a/tests/accumulate-no-duplicates.rs b/tests/accumulate-no-duplicates.rs index 0907c469d..8d21281e5 100644 --- a/tests/accumulate-no-duplicates.rs +++ b/tests/accumulate-no-duplicates.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Test that we don't get duplicate accumulated values mod common; diff --git a/tests/accumulate-reuse-workaround.rs b/tests/accumulate-reuse-workaround.rs index 915a14c1c..43c5bb3ce 100644 --- a/tests/accumulate-reuse-workaround.rs +++ b/tests/accumulate-reuse-workaround.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Demonstrates the workaround of wrapping calls to //! 
`accumulated` in a tracked function to get better //! reuse. diff --git a/tests/accumulate-reuse.rs b/tests/accumulate-reuse.rs index b7d918870..1e6194de6 100644 --- a/tests/accumulate-reuse.rs +++ b/tests/accumulate-reuse.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Accumulator re-use test. //! //! Tests behavior when a query's only inputs diff --git a/tests/accumulate.rs b/tests/accumulate.rs index dcacfb7ad..54022a15e 100644 --- a/tests/accumulate.rs +++ b/tests/accumulate.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + mod common; use common::{LogDatabase, LoggerDatabase}; use expect_test::expect; diff --git a/tests/accumulated_backdate.rs b/tests/accumulated_backdate.rs index ce8f6580f..45759d1ba 100644 --- a/tests/accumulated_backdate.rs +++ b/tests/accumulated_backdate.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Tests that accumulated values are correctly accounted for //! when backdating a value. diff --git a/tests/backtrace.rs b/tests/backtrace.rs index c64fd8ae3..74124c1ab 100644 --- a/tests/backtrace.rs +++ b/tests/backtrace.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + use expect_test::expect; use salsa::{Backtrace, Database, DatabaseImpl}; use test_log::test; @@ -71,15 +73,15 @@ fn backtrace_works() { expect![[r#" query stacktrace: 0: query_e(Id(0)) - at tests/backtrace.rs:30 + at tests/backtrace.rs:32 1: query_d(Id(0)) - at tests/backtrace.rs:25 + at tests/backtrace.rs:27 2: query_c(Id(0)) - at tests/backtrace.rs:20 + at tests/backtrace.rs:22 3: query_b(Id(0)) - at tests/backtrace.rs:15 + at tests/backtrace.rs:17 4: query_a(Id(0)) - at tests/backtrace.rs:10 + at tests/backtrace.rs:12 "#]] .assert_eq(&backtrace); @@ -87,15 +89,15 @@ fn backtrace_works() { expect![[r#" query stacktrace: 0: query_e(Id(1)) -> (R1, Durability::LOW) - at tests/backtrace.rs:30 + at tests/backtrace.rs:32 1: query_d(Id(1)) -> (R1, Durability::HIGH) - at tests/backtrace.rs:25 + at tests/backtrace.rs:27 2: query_c(Id(1)) -> (R1, 
Durability::HIGH) - at tests/backtrace.rs:20 + at tests/backtrace.rs:22 3: query_b(Id(1)) -> (R1, Durability::HIGH) - at tests/backtrace.rs:15 + at tests/backtrace.rs:17 4: query_a(Id(1)) -> (R1, Durability::HIGH) - at tests/backtrace.rs:10 + at tests/backtrace.rs:12 "#]] .assert_eq(&backtrace); @@ -103,12 +105,12 @@ fn backtrace_works() { expect![[r#" query stacktrace: 0: query_e(Id(2)) - at tests/backtrace.rs:30 + at tests/backtrace.rs:32 1: query_cycle(Id(2)) - at tests/backtrace.rs:43 + at tests/backtrace.rs:45 cycle heads: query_cycle(Id(2)) -> IterationCount(0) 2: query_f(Id(2)) - at tests/backtrace.rs:38 + at tests/backtrace.rs:40 "#]] .assert_eq(&backtrace); @@ -116,12 +118,12 @@ fn backtrace_works() { expect![[r#" query stacktrace: 0: query_e(Id(3)) -> (R1, Durability::LOW) - at tests/backtrace.rs:30 + at tests/backtrace.rs:32 1: query_cycle(Id(3)) -> (R1, Durability::HIGH, iteration = IterationCount(0)) - at tests/backtrace.rs:43 + at tests/backtrace.rs:45 cycle heads: query_cycle(Id(3)) -> IterationCount(0) 2: query_f(Id(3)) -> (R1, Durability::HIGH) - at tests/backtrace.rs:38 + at tests/backtrace.rs:40 "#]] .assert_eq(&backtrace); } diff --git a/tests/check_auto_traits.rs b/tests/check_auto_traits.rs index 6e9c62c62..0d314a83c 100644 --- a/tests/check_auto_traits.rs +++ b/tests/check_auto_traits.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Test that auto trait impls exist as expected. use std::panic::UnwindSafe; diff --git a/tests/compile_fail.rs b/tests/compile_fail.rs index 3b4a37fcf..c43c0b4db 100644 --- a/tests/compile_fail.rs +++ b/tests/compile_fail.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + #[rustversion::all(stable, since(1.84))] #[test] fn compile_fail() { diff --git a/tests/cycle.rs b/tests/cycle.rs index f18cf92af..28266f2c5 100644 --- a/tests/cycle.rs +++ b/tests/cycle.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Test cases for fixpoint iteration cycle resolution. //! //! 
These test cases use a generic query setup that allows constructing arbitrary dependency diff --git a/tests/cycle_accumulate.rs b/tests/cycle_accumulate.rs index d547b5760..fa31845d9 100644 --- a/tests/cycle_accumulate.rs +++ b/tests/cycle_accumulate.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + use std::collections::HashSet; mod common; diff --git a/tests/cycle_fallback_immediate.rs b/tests/cycle_fallback_immediate.rs index b22767202..374978d81 100644 --- a/tests/cycle_fallback_immediate.rs +++ b/tests/cycle_fallback_immediate.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! It is possible to omit the `cycle_fn`, only specifying `cycle_result` in which case //! an immediate fallback value is used as the cycle handling opposed to doing a fixpoint resolution. diff --git a/tests/cycle_initial_call_back_into_cycle.rs b/tests/cycle_initial_call_back_into_cycle.rs index 9dfe39a92..326fd46c7 100644 --- a/tests/cycle_initial_call_back_into_cycle.rs +++ b/tests/cycle_initial_call_back_into_cycle.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Calling back into the same cycle from your cycle initial function will trigger another cycle. #[salsa::tracked] diff --git a/tests/cycle_initial_call_query.rs b/tests/cycle_initial_call_query.rs index 4c52fff27..cb10e77e1 100644 --- a/tests/cycle_initial_call_query.rs +++ b/tests/cycle_initial_call_query.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! It's possible to call a Salsa query from within a cycle initial fn. #[salsa::tracked] diff --git a/tests/cycle_maybe_changed_after.rs b/tests/cycle_maybe_changed_after.rs index 2759c65ff..6ee42d3a5 100644 --- a/tests/cycle_maybe_changed_after.rs +++ b/tests/cycle_maybe_changed_after.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Tests for incremental validation for queries involved in a cycle. 
mod common; diff --git a/tests/cycle_output.rs b/tests/cycle_output.rs index 8a4d13e95..975c8a44d 100644 --- a/tests/cycle_output.rs +++ b/tests/cycle_output.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Test tracked struct output from a query in a cycle. mod common; use common::{HasLogger, LogDatabase, Logger}; diff --git a/tests/cycle_recovery_call_back_into_cycle.rs b/tests/cycle_recovery_call_back_into_cycle.rs index a4dc5e250..af7c10219 100644 --- a/tests/cycle_recovery_call_back_into_cycle.rs +++ b/tests/cycle_recovery_call_back_into_cycle.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Calling back into the same cycle from your cycle recovery function _can_ work out, as long as //! the overall cycle still converges. diff --git a/tests/cycle_recovery_call_query.rs b/tests/cycle_recovery_call_query.rs index a768017c8..dcc31abeb 100644 --- a/tests/cycle_recovery_call_query.rs +++ b/tests/cycle_recovery_call_query.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! It's possible to call a Salsa query from within a cycle recovery fn. #[salsa::tracked] diff --git a/tests/cycle_regression_455.rs b/tests/cycle_regression_455.rs index 5beff8d3d..99c193ab9 100644 --- a/tests/cycle_regression_455.rs +++ b/tests/cycle_regression_455.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + use salsa::{Database, Setter}; #[salsa::tracked] diff --git a/tests/cycle_result_dependencies.rs b/tests/cycle_result_dependencies.rs index e7071a029..8e025f998 100644 --- a/tests/cycle_result_dependencies.rs +++ b/tests/cycle_result_dependencies.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + use salsa::{Database, Setter}; #[salsa::input] diff --git a/tests/cycle_tracked.rs b/tests/cycle_tracked.rs index 30bf513f1..b9ef6ed14 100644 --- a/tests/cycle_tracked.rs +++ b/tests/cycle_tracked.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Tests for cycles where the cycle head is stored on a tracked struct //! 
and that tracked struct is freed in a later revision. diff --git a/tests/cycle_tracked_own_input.rs b/tests/cycle_tracked_own_input.rs index e8e520f4c..17e8b815e 100644 --- a/tests/cycle_tracked_own_input.rs +++ b/tests/cycle_tracked_own_input.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Test for cycle handling where a tracked struct created in the first revision //! is stored in the final value of the cycle but isn't recreated in the second //! iteration of the creating query. diff --git a/tests/dataflow.rs b/tests/dataflow.rs index f973e970e..960cc33f5 100644 --- a/tests/dataflow.rs +++ b/tests/dataflow.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Test case for fixpoint iteration cycle resolution. //! //! This test case is intended to simulate a (very simplified) version of a real dataflow analysis diff --git a/tests/debug.rs b/tests/debug.rs index 03b59dcab..da184b40a 100644 --- a/tests/debug.rs +++ b/tests/debug.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Test that `DeriveWithDb` is correctly derived. use expect_test::expect; diff --git a/tests/debug_db_contents.rs b/tests/debug_db_contents.rs index 30efb1736..6ab8b212e 100644 --- a/tests/debug_db_contents.rs +++ b/tests/debug_db_contents.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + #[salsa::interned(debug)] struct InternedStruct<'db> { name: String, diff --git a/tests/deletion-cascade.rs b/tests/deletion-cascade.rs index 1e02c42f0..6a17fe93e 100644 --- a/tests/deletion-cascade.rs +++ b/tests/deletion-cascade.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Delete cascade: //! //! * when we delete memoized data, also delete outputs from that data diff --git a/tests/deletion-drops.rs b/tests/deletion-drops.rs index 52b0b5120..6ce9e6a2c 100644 --- a/tests/deletion-drops.rs +++ b/tests/deletion-drops.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Basic deletion test: //! //! 
* entities not created in a revision are deleted, as is any memoized data keyed on them. diff --git a/tests/deletion.rs b/tests/deletion.rs index c7c415e2d..6202ea03e 100644 --- a/tests/deletion.rs +++ b/tests/deletion.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Basic deletion test: //! //! * entities not created in a revision are deleted, as is any memoized data keyed on them. diff --git a/tests/derive_update.rs b/tests/derive_update.rs index e07d17838..4ce04cb76 100644 --- a/tests/derive_update.rs +++ b/tests/derive_update.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Test that the `Update` derive works as expected #[derive(salsa::Update)] diff --git a/tests/durability.rs b/tests/durability.rs index 3a8e244e0..a39e30655 100644 --- a/tests/durability.rs +++ b/tests/durability.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Tests that code using the builder's durability methods compiles. use salsa::{Database, Durability, Setter}; diff --git a/tests/elided-lifetime-in-tracked-fn.rs b/tests/elided-lifetime-in-tracked-fn.rs index 4979c2ee9..81aa3a5fc 100644 --- a/tests/elided-lifetime-in-tracked-fn.rs +++ b/tests/elided-lifetime-in-tracked-fn.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Test that a `tracked` fn on a `salsa::input` //! compiles and executes successfully. diff --git a/tests/expect_reuse_field_x_of_a_tracked_struct_changes_but_fn_depends_on_field_y.rs b/tests/expect_reuse_field_x_of_a_tracked_struct_changes_but_fn_depends_on_field_y.rs index cc0f09808..7a29d804b 100644 --- a/tests/expect_reuse_field_x_of_a_tracked_struct_changes_but_fn_depends_on_field_y.rs +++ b/tests/expect_reuse_field_x_of_a_tracked_struct_changes_but_fn_depends_on_field_y.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Test that if field X of a tracked struct changes but not field Y, //! functions that depend on X re-execute, but those depending only on Y do not //! compiles and executes successfully. 
diff --git a/tests/expect_reuse_field_x_of_an_input_changes_but_fn_depends_on_field_y.rs b/tests/expect_reuse_field_x_of_an_input_changes_but_fn_depends_on_field_y.rs index c4a6ba56e..218d875b0 100644 --- a/tests/expect_reuse_field_x_of_an_input_changes_but_fn_depends_on_field_y.rs +++ b/tests/expect_reuse_field_x_of_an_input_changes_but_fn_depends_on_field_y.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Test that if field X of an input changes but not field Y, //! functions that depend on X re-execute, but those depending only on Y do not //! compiles and executes successfully. diff --git a/tests/hash_collision.rs b/tests/hash_collision.rs index 4efadfa33..f37c2aed3 100644 --- a/tests/hash_collision.rs +++ b/tests/hash_collision.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + use std::hash::Hash; #[test] diff --git a/tests/hello_world.rs b/tests/hello_world.rs index 4a083648b..561fce916 100644 --- a/tests/hello_world.rs +++ b/tests/hello_world.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Test that a `tracked` fn on a `salsa::input` //! compiles and executes successfully. diff --git a/tests/input_default.rs b/tests/input_default.rs index 5a4d2bd54..1fef27039 100644 --- a/tests/input_default.rs +++ b/tests/input_default.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Tests that fields attributed with `#[default]` are initialized with `Default::default()`. use salsa::Durability; diff --git a/tests/input_field_durability.rs b/tests/input_field_durability.rs index b65a512e0..de15b666b 100644 --- a/tests/input_field_durability.rs +++ b/tests/input_field_durability.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Tests that code using the builder's durability methods compiles. 
use salsa::Durability; diff --git a/tests/input_setter_preserves_durability.rs b/tests/input_setter_preserves_durability.rs index 529418bf3..27edc06cf 100644 --- a/tests/input_setter_preserves_durability.rs +++ b/tests/input_setter_preserves_durability.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + use salsa::plumbing::ZalsaDatabase; use salsa::{Durability, Setter}; use test_log::test; diff --git a/tests/intern_access_in_different_revision.rs b/tests/intern_access_in_different_revision.rs index ff50ae7a5..ab8957cae 100644 --- a/tests/intern_access_in_different_revision.rs +++ b/tests/intern_access_in_different_revision.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + use salsa::{Durability, Setter}; #[salsa::interned(no_lifetime)] diff --git a/tests/interned-revisions.rs b/tests/interned-revisions.rs index f48393b8a..225f24d4f 100644 --- a/tests/interned-revisions.rs +++ b/tests/interned-revisions.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Test that a `tracked` fn on a `salsa::input` //! compiles and executes successfully. diff --git a/tests/interned-structs.rs b/tests/interned-structs.rs index da9ec6ae5..931b1ab67 100644 --- a/tests/interned-structs.rs +++ b/tests/interned-structs.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Test that a `tracked` fn on a `salsa::input` //! compiles and executes successfully. diff --git a/tests/interned-structs_self_ref.rs b/tests/interned-structs_self_ref.rs index 69a19fbf4..556c49607 100644 --- a/tests/interned-structs_self_ref.rs +++ b/tests/interned-structs_self_ref.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Test that a `tracked` fn on a `salsa::input` //! compiles and executes successfully. 
@@ -35,7 +37,18 @@ struct InternedString<'db>( const _: () = { use salsa::plumbing as zalsa_; use zalsa_::interned as zalsa_struct_; + type Configuration_ = InternedString<'static>; + + impl<'db> zalsa_::HasJar for InternedString<'db> { + type Jar = zalsa_struct_::JarImpl; + const KIND: zalsa_::JarKind = zalsa_::JarKind::Struct; + } + + zalsa_::register_jar! { + zalsa_::ErasedJar::erase::>() + } + #[derive(Clone)] struct StructData<'db>(String, InternedString<'db>); @@ -87,9 +100,7 @@ const _: () = { let zalsa = db.zalsa(); CACHE.get_or_create(zalsa, || { - zalsa - .lookup_jar_by_type::>() - .get_or_create() + zalsa.lookup_jar_by_type::>() }) } } @@ -115,9 +126,8 @@ const _: () = { impl zalsa_::SalsaStructInDb for InternedString<'_> { type MemoIngredientMap = zalsa_::MemoIngredientSingletonIndex; - fn lookup_or_create_ingredient_index(aux: &Zalsa) -> salsa::plumbing::IngredientIndices { + fn lookup_ingredient_index(aux: &Zalsa) -> salsa::plumbing::IngredientIndices { aux.lookup_jar_by_type::>() - .get_or_create() .into() } diff --git a/tests/lru.rs b/tests/lru.rs index e1c11e504..1d417267a 100644 --- a/tests/lru.rs +++ b/tests/lru.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Test that a `tracked` fn with lru options //! compiles and executes successfully. 
diff --git a/tests/manual_registration.rs b/tests/manual_registration.rs new file mode 100644 index 000000000..16d87dc59 --- /dev/null +++ b/tests/manual_registration.rs @@ -0,0 +1,92 @@ +#![cfg(not(feature = "inventory"))] + +mod ingredients { + #[salsa::input] + pub(super) struct MyInput { + field: u32, + } + + #[salsa::tracked] + pub(super) struct MyTracked<'db> { + pub(super) field: u32, + } + + #[salsa::interned] + pub(super) struct MyInterned<'db> { + pub(super) field: u32, + } + + #[salsa::tracked] + pub(super) fn intern<'db>(db: &'db dyn salsa::Database, input: MyInput) -> MyInterned<'db> { + MyInterned::new(db, input.field(db)) + } + + #[salsa::tracked] + pub(super) fn track<'db>(db: &'db dyn salsa::Database, input: MyInput) -> MyTracked<'db> { + MyTracked::new(db, input.field(db)) + } +} + +#[salsa::db] +#[derive(Clone, Default)] +pub struct DatabaseImpl { + storage: salsa::Storage, +} + +#[salsa::db] +impl salsa::Database for DatabaseImpl {} + +#[test] +fn single_database() { + let db = DatabaseImpl { + storage: salsa::Storage::builder() + .ingredient::() + .ingredient::() + .ingredient::() + .ingredient::>() + .ingredient::>() + .build(), + }; + + let input = ingredients::MyInput::new(&db, 1); + + let tracked = ingredients::track(&db, input); + let interned = ingredients::intern(&db, input); + + assert_eq!(tracked.field(&db), 1); + assert_eq!(interned.field(&db), 1); +} + +#[test] +fn multiple_databases() { + let db1 = DatabaseImpl { + storage: salsa::Storage::builder() + .ingredient::() + .ingredient::() + .ingredient::>() + .build(), + }; + + let input = ingredients::MyInput::new(&db1, 1); + let interned = ingredients::intern(&db1, input); + assert_eq!(interned.field(&db1), 1); + + // Create a second database with different ingredient indices. 
+ let db2 = DatabaseImpl { + storage: salsa::Storage::builder() + .ingredient::() + .ingredient::() + .ingredient::() + .ingredient::>() + .ingredient::>() + .build(), + }; + + let input = ingredients::MyInput::new(&db2, 2); + let interned = ingredients::intern(&db2, input); + assert_eq!(interned.field(&db2), 2); + + let input = ingredients::MyInput::new(&db2, 3); + let tracked = ingredients::track(&db2, input); + assert_eq!(tracked.field(&db2), 3); +} diff --git a/tests/memory-usage.rs b/tests/memory-usage.rs index a990ff6a3..f9fca29ab 100644 --- a/tests/memory-usage.rs +++ b/tests/memory-usage.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + use expect_test::expect; #[salsa::input] diff --git a/tests/mutate_in_place.rs b/tests/mutate_in_place.rs index 047373ee5..5327df419 100644 --- a/tests/mutate_in_place.rs +++ b/tests/mutate_in_place.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Test that a setting a field on a `#[salsa::input]` //! overwrites and returns the old value. diff --git a/tests/override_new_get_set.rs b/tests/override_new_get_set.rs index 367decf04..222ba7b46 100644 --- a/tests/override_new_get_set.rs +++ b/tests/override_new_get_set.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Test that the `constructor` macro overrides //! the `new` method's name and `get` and `set` //! change the name of the getter and setter of the fields. diff --git a/tests/panic-when-creating-tracked-struct-outside-of-tracked-fn.rs b/tests/panic-when-creating-tracked-struct-outside-of-tracked-fn.rs index 32b444c7f..dfc4f972b 100644 --- a/tests/panic-when-creating-tracked-struct-outside-of-tracked-fn.rs +++ b/tests/panic-when-creating-tracked-struct-outside-of-tracked-fn.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Test that creating a tracked struct outside of a //! tracked function panics with an assert message. 
diff --git a/tests/parallel/main.rs b/tests/parallel/main.rs index da5e5e4a1..e14780424 100644 --- a/tests/parallel/main.rs +++ b/tests/parallel/main.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + mod setup; mod signal; diff --git a/tests/preverify-struct-with-leaked-data-2.rs b/tests/preverify-struct-with-leaked-data-2.rs index d7e3f8f9f..f3d6c05d5 100644 --- a/tests/preverify-struct-with-leaked-data-2.rs +++ b/tests/preverify-struct-with-leaked-data-2.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Test that a `tracked` fn on a `salsa::input` //! compiles and executes successfully. diff --git a/tests/preverify-struct-with-leaked-data.rs b/tests/preverify-struct-with-leaked-data.rs index 5c0f84954..6af7c6e80 100644 --- a/tests/preverify-struct-with-leaked-data.rs +++ b/tests/preverify-struct-with-leaked-data.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Test that a `tracked` fn on a `salsa::input` //! compiles and executes successfully. diff --git a/tests/return_mode.rs b/tests/return_mode.rs index 34f676288..fdf116002 100644 --- a/tests/return_mode.rs +++ b/tests/return_mode.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + use salsa::Database; #[salsa::input] diff --git a/tests/singleton.rs b/tests/singleton.rs index 381db9b74..5c7d2c71e 100644 --- a/tests/singleton.rs +++ b/tests/singleton.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Basic Singleton struct test: //! //! Singleton structs are created only once. Subsequent `get`s and `new`s after creation return the same `Id`. diff --git a/tests/specify-only-works-if-the-key-is-created-in-the-current-query.rs b/tests/specify-only-works-if-the-key-is-created-in-the-current-query.rs index a407aee62..44d1cf5e6 100644 --- a/tests/specify-only-works-if-the-key-is-created-in-the-current-query.rs +++ b/tests/specify-only-works-if-the-key-is-created-in-the-current-query.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! 
Test that `specify` only works if the key is a tracked struct created in the current query. //! compilation succeeds but execution panics #![allow(warnings)] diff --git a/tests/synthetic_write.rs b/tests/synthetic_write.rs index 9e3c2f305..ccd6d0266 100644 --- a/tests/synthetic_write.rs +++ b/tests/synthetic_write.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Test that a constant `tracked` fn (has no inputs) //! compiles and executes successfully. #![allow(warnings)] diff --git a/tests/tracked-struct-id-field-bad-eq.rs b/tests/tracked-struct-id-field-bad-eq.rs index 44deec6cb..d2d1a9c58 100644 --- a/tests/tracked-struct-id-field-bad-eq.rs +++ b/tests/tracked-struct-id-field-bad-eq.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Test an id field whose `PartialEq` impl is always true. use salsa::{Database, Setter}; diff --git a/tests/tracked-struct-id-field-bad-hash.rs b/tests/tracked-struct-id-field-bad-hash.rs index f60aa4798..bcd765bf5 100644 --- a/tests/tracked-struct-id-field-bad-hash.rs +++ b/tests/tracked-struct-id-field-bad-hash.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Test for a tracked struct where an untracked field has a //! very poorly chosen hash impl (always returns 0). //! diff --git a/tests/tracked-struct-unchanged-in-new-rev.rs b/tests/tracked-struct-unchanged-in-new-rev.rs index e4633740f..da782eb79 100644 --- a/tests/tracked-struct-unchanged-in-new-rev.rs +++ b/tests/tracked-struct-unchanged-in-new-rev.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + use salsa::{Database as Db, Setter}; use test_log::test; diff --git a/tests/tracked-struct-value-field-bad-eq.rs b/tests/tracked-struct-value-field-bad-eq.rs index 3a02d63c5..f05cfb59e 100644 --- a/tests/tracked-struct-value-field-bad-eq.rs +++ b/tests/tracked-struct-value-field-bad-eq.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Test a field whose `PartialEq` impl is always true. //! This can result in us getting different results than //! 
if we were to execute from scratch. diff --git a/tests/tracked-struct-value-field-not-eq.rs b/tests/tracked-struct-value-field-not-eq.rs index e37d4af9e..451099d1e 100644 --- a/tests/tracked-struct-value-field-not-eq.rs +++ b/tests/tracked-struct-value-field-not-eq.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Test a field whose `PartialEq` impl is always true. //! This can our "last changed" data to be wrong //! but we *should* always reflect the final values. diff --git a/tests/tracked_assoc_fn.rs b/tests/tracked_assoc_fn.rs index 18e6b953d..c369740f9 100644 --- a/tests/tracked_assoc_fn.rs +++ b/tests/tracked_assoc_fn.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Test that a `tracked` fn on a `salsa::input` //! compiles and executes successfully. #![allow(warnings)] diff --git a/tests/tracked_fn_constant.rs b/tests/tracked_fn_constant.rs index c6753ebf4..04681fcde 100644 --- a/tests/tracked_fn_constant.rs +++ b/tests/tracked_fn_constant.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Test that a constant `tracked` fn (has no inputs) //! compiles and executes successfully. 
#![allow(warnings)] diff --git a/tests/tracked_fn_high_durability_dependency.rs b/tests/tracked_fn_high_durability_dependency.rs index a05be178f..17b7efe70 100644 --- a/tests/tracked_fn_high_durability_dependency.rs +++ b/tests/tracked_fn_high_durability_dependency.rs @@ -1,3 +1,4 @@ +#![cfg(feature = "inventory")] #![allow(warnings)] use salsa::plumbing::HasStorage; diff --git a/tests/tracked_fn_interned_lifetime.rs b/tests/tracked_fn_interned_lifetime.rs index 99b33af4b..0cce1c321 100644 --- a/tests/tracked_fn_interned_lifetime.rs +++ b/tests/tracked_fn_interned_lifetime.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + #[salsa::interned] struct Interned<'db> { field: i32, diff --git a/tests/tracked_fn_multiple_args.rs b/tests/tracked_fn_multiple_args.rs index 7c014356c..5bed18089 100644 --- a/tests/tracked_fn_multiple_args.rs +++ b/tests/tracked_fn_multiple_args.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Test that a `tracked` fn on multiple salsa struct args //! compiles and executes successfully. diff --git a/tests/tracked_fn_no_eq.rs b/tests/tracked_fn_no_eq.rs index 6f223b791..5a1a4ed0d 100644 --- a/tests/tracked_fn_no_eq.rs +++ b/tests/tracked_fn_no_eq.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + mod common; use common::LogDatabase; diff --git a/tests/tracked_fn_on_input.rs b/tests/tracked_fn_on_input.rs index e588a40a9..a177f4d46 100644 --- a/tests/tracked_fn_on_input.rs +++ b/tests/tracked_fn_on_input.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Test that a `tracked` fn on a `salsa::input` //! compiles and executes successfully. 
#![allow(warnings)] diff --git a/tests/tracked_fn_on_input_with_high_durability.rs b/tests/tracked_fn_on_input_with_high_durability.rs index fbf122dea..07c9aee9e 100644 --- a/tests/tracked_fn_on_input_with_high_durability.rs +++ b/tests/tracked_fn_on_input_with_high_durability.rs @@ -1,3 +1,4 @@ +#![cfg(feature = "inventory")] #![allow(warnings)] use common::{EventLoggerDatabase, HasLogger, LogDatabase, Logger}; diff --git a/tests/tracked_fn_on_interned.rs b/tests/tracked_fn_on_interned.rs index b551b880d..78a8a9fdf 100644 --- a/tests/tracked_fn_on_interned.rs +++ b/tests/tracked_fn_on_interned.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Test that a `tracked` fn on a `salsa::interned` //! compiles and executes successfully. diff --git a/tests/tracked_fn_on_interned_enum.rs b/tests/tracked_fn_on_interned_enum.rs index 63fa03b67..df8109196 100644 --- a/tests/tracked_fn_on_interned_enum.rs +++ b/tests/tracked_fn_on_interned_enum.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Test that a `tracked` fn on a `salsa::interned` //! compiles and executes successfully. diff --git a/tests/tracked_fn_on_tracked.rs b/tests/tracked_fn_on_tracked.rs index 967bbd558..802e76fa7 100644 --- a/tests/tracked_fn_on_tracked.rs +++ b/tests/tracked_fn_on_tracked.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Test that a `tracked` fn on a `salsa::input` //! compiles and executes successfully. diff --git a/tests/tracked_fn_on_tracked_specify.rs b/tests/tracked_fn_on_tracked_specify.rs index 70e4997a2..42d7f58c8 100644 --- a/tests/tracked_fn_on_tracked_specify.rs +++ b/tests/tracked_fn_on_tracked_specify.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Test that a `tracked` fn on a `salsa::input` //! compiles and executes successfully. 
#![allow(warnings)] diff --git a/tests/tracked_fn_orphan_escape_hatch.rs b/tests/tracked_fn_orphan_escape_hatch.rs index f94e93949..25d2d7f4c 100644 --- a/tests/tracked_fn_orphan_escape_hatch.rs +++ b/tests/tracked_fn_orphan_escape_hatch.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Test that a `tracked` fn on a `salsa::input` //! compiles and executes successfully. #![allow(warnings)] diff --git a/tests/tracked_fn_read_own_entity.rs b/tests/tracked_fn_read_own_entity.rs index 11ae72c18..a62c82794 100644 --- a/tests/tracked_fn_read_own_entity.rs +++ b/tests/tracked_fn_read_own_entity.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Test that a `tracked` fn on a `salsa::input` //! compiles and executes successfully. diff --git a/tests/tracked_fn_read_own_specify.rs b/tests/tracked_fn_read_own_specify.rs index a96cba356..26b9a2f56 100644 --- a/tests/tracked_fn_read_own_specify.rs +++ b/tests/tracked_fn_read_own_specify.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + use expect_test::expect; mod common; use common::LogDatabase; diff --git a/tests/tracked_fn_return_ref.rs b/tests/tracked_fn_return_ref.rs index 918f33b37..c7691c1b9 100644 --- a/tests/tracked_fn_return_ref.rs +++ b/tests/tracked_fn_return_ref.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + use salsa::Database; #[salsa::input] diff --git a/tests/tracked_method.rs b/tests/tracked_method.rs index d5953deaa..f3ee7d798 100644 --- a/tests/tracked_method.rs +++ b/tests/tracked_method.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Test that a `tracked` fn on a `salsa::input` //! compiles and executes successfully. 
#![allow(warnings)] diff --git a/tests/tracked_method_inherent_return_deref.rs b/tests/tracked_method_inherent_return_deref.rs index 2477b5a1d..ac02bf92d 100644 --- a/tests/tracked_method_inherent_return_deref.rs +++ b/tests/tracked_method_inherent_return_deref.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + use salsa::Database; #[salsa::input] diff --git a/tests/tracked_method_inherent_return_ref.rs b/tests/tracked_method_inherent_return_ref.rs index 564bb31ff..2d22336b3 100644 --- a/tests/tracked_method_inherent_return_ref.rs +++ b/tests/tracked_method_inherent_return_ref.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + use salsa::Database; #[salsa::input] diff --git a/tests/tracked_method_on_tracked_struct.rs b/tests/tracked_method_on_tracked_struct.rs index f7cb9f6da..e99a649b5 100644 --- a/tests/tracked_method_on_tracked_struct.rs +++ b/tests/tracked_method_on_tracked_struct.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + use salsa::Database; #[derive(Debug, PartialEq, Eq, Hash)] diff --git a/tests/tracked_method_trait_return_ref.rs b/tests/tracked_method_trait_return_ref.rs index ec74cf3ae..e632a7ce5 100644 --- a/tests/tracked_method_trait_return_ref.rs +++ b/tests/tracked_method_trait_return_ref.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + use salsa::Database; #[salsa::input] diff --git a/tests/tracked_method_with_self_ty.rs b/tests/tracked_method_with_self_ty.rs index 8f8b0678d..a541af989 100644 --- a/tests/tracked_method_with_self_ty.rs +++ b/tests/tracked_method_with_self_ty.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Test that a `tracked` fn with `Self` in its signature or body on a `salsa::input` //! compiles and executes successfully. 
#![allow(warnings)] diff --git a/tests/tracked_struct.rs b/tests/tracked_struct.rs index a1bba36c6..e148bffd7 100644 --- a/tests/tracked_struct.rs +++ b/tests/tracked_struct.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + mod common; use salsa::{Database, Setter}; diff --git a/tests/tracked_struct_db1_lt.rs b/tests/tracked_struct_db1_lt.rs index e5de757ca..277666118 100644 --- a/tests/tracked_struct_db1_lt.rs +++ b/tests/tracked_struct_db1_lt.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Test that tracked structs with lifetimes not named `'db` //! compile successfully. diff --git a/tests/tracked_struct_disambiguates.rs b/tests/tracked_struct_disambiguates.rs index 663fedf42..db0435a20 100644 --- a/tests/tracked_struct_disambiguates.rs +++ b/tests/tracked_struct_disambiguates.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Test that disambiguation works, that is when we have a revision where we track multiple structs //! that have the same hash, we can still differentiate between them. #![allow(warnings)] diff --git a/tests/tracked_struct_durability.rs b/tests/tracked_struct_durability.rs index 7dfd87284..3608a5e78 100644 --- a/tests/tracked_struct_durability.rs +++ b/tests/tracked_struct_durability.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + /// Test that high durabilities can't cause "access tracked struct from previous revision" panic. 
/// /// The test models a situation where we have two File inputs (0, 1), where `File(0)` has LOW diff --git a/tests/tracked_struct_manual_update.rs b/tests/tracked_struct_manual_update.rs index 10c91117a..e9dca14a3 100644 --- a/tests/tracked_struct_manual_update.rs +++ b/tests/tracked_struct_manual_update.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + mod common; use std::sync::atomic::{AtomicBool, Ordering}; diff --git a/tests/tracked_struct_mixed_tracked_fields.rs b/tests/tracked_struct_mixed_tracked_fields.rs index a5630c631..50cd4cc72 100644 --- a/tests/tracked_struct_mixed_tracked_fields.rs +++ b/tests/tracked_struct_mixed_tracked_fields.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + mod common; use salsa::{Database, Setter}; diff --git a/tests/tracked_struct_recreate_new_revision.rs b/tests/tracked_struct_recreate_new_revision.rs index c1786a2dd..d6a0a6e8f 100644 --- a/tests/tracked_struct_recreate_new_revision.rs +++ b/tests/tracked_struct_recreate_new_revision.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Test that re-creating a `tracked` struct after it was deleted in a previous //! revision doesn't panic. #![allow(warnings)] diff --git a/tests/tracked_struct_with_interned_query.rs b/tests/tracked_struct_with_interned_query.rs index 4476cceaa..c41605310 100644 --- a/tests/tracked_struct_with_interned_query.rs +++ b/tests/tracked_struct_with_interned_query.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + mod common; use salsa::Setter; diff --git a/tests/tracked_with_intern.rs b/tests/tracked_with_intern.rs index 1b3a381d3..a2ab5d6aa 100644 --- a/tests/tracked_with_intern.rs +++ b/tests/tracked_with_intern.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Test that a setting a field on a `#[salsa::input]` //! overwrites and returns the old value. 
diff --git a/tests/tracked_with_struct_db.rs b/tests/tracked_with_struct_db.rs index 9af7b5500..323f9facf 100644 --- a/tests/tracked_with_struct_db.rs +++ b/tests/tracked_with_struct_db.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Test that a setting a field on a `#[salsa::input]` //! overwrites and returns the old value. diff --git a/tests/tracked_with_struct_ord.rs b/tests/tracked_with_struct_ord.rs index 51fef9cc6..1b3c6a79e 100644 --- a/tests/tracked_with_struct_ord.rs +++ b/tests/tracked_with_struct_ord.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "inventory")] + //! Test that `PartialOrd` and `Ord` can be derived for tracked structs use salsa::{Database, DatabaseImpl}; From 962e0b924e3bae97a0fe84f328a541281ca5457e Mon Sep 17 00:00:00 2001 From: Ibraheem Ahmed Date: Mon, 21 Jul 2025 02:34:48 -0400 Subject: [PATCH 05/65] optimize page access (#940) --- src/table.rs | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/src/table.rs b/src/table.rs index 3f08841d7..7132de1ae 100644 --- a/src/table.rs +++ b/src/table.rs @@ -383,13 +383,10 @@ impl Page { #[inline] fn assert_type(&self) -> PageView<'_, T> { - assert_eq!( - self.slot_type_id, - TypeId::of::(), - "page has slot type `{:?}` but `{:?}` was expected", - self.slot_type_name, - std::any::type_name::(), - ); + if self.slot_type_id != TypeId::of::() { + type_assert_failed::(self); + } + PageView(self, PhantomData) } @@ -403,6 +400,17 @@ impl Page { } } +/// This function is explicitly outlined to avoid debug machinery in the hot-path. +#[cold] +#[inline(never)] +fn type_assert_failed(page: &Page) -> ! 
{ + panic!( + "page has slot type `{:?}` but `{:?}` was expected", + page.slot_type_name, + std::any::type_name::(), + ) +} + impl Drop for Page { fn drop(&mut self) { let len = *self.allocated.get_mut(); From 0e1df67ec4f09e46a4f2bfbb6e6c8025f2686715 Mon Sep 17 00:00:00 2001 From: Ibraheem Ahmed Date: Mon, 21 Jul 2025 12:09:11 -0400 Subject: [PATCH 06/65] Avoid dynamic dispatch to access memo tables (#941) * avoid dynamic dispatch to access memo tables * update comment --- .../src/setup_input_struct.rs | 10 +++++ .../src/setup_interned_struct.rs | 10 +++++ .../salsa-macro-rules/src/setup_tracked_fn.rs | 10 +++++ .../src/setup_tracked_struct.rs | 10 +++++ components/salsa-macros/src/supertype.rs | 13 +++++++ src/function/memo.rs | 13 +++++-- src/lib.rs | 1 + src/salsa_struct.rs | 15 +++++++- src/table.rs | 38 ++++++++++++++++--- src/table/memo.rs | 4 +- src/tracked_struct.rs | 5 +-- src/zalsa.rs | 10 ++--- tests/interned-structs_self_ref.rs | 14 +++++++ 13 files changed, 133 insertions(+), 20 deletions(-) diff --git a/components/salsa-macro-rules/src/setup_input_struct.rs b/components/salsa-macro-rules/src/setup_input_struct.rs index 38988c546..06315e07a 100644 --- a/components/salsa-macro-rules/src/setup_input_struct.rs +++ b/components/salsa-macro-rules/src/setup_input_struct.rs @@ -170,6 +170,16 @@ macro_rules! setup_input_struct { $zalsa::None } } + + #[inline] + unsafe fn memo_table( + zalsa: &$zalsa::Zalsa, + id: $zalsa::Id, + current_revision: $zalsa::Revision, + ) -> $zalsa::MemoTableWithTypes<'_> { + // SAFETY: Guaranteed by caller. 
+ unsafe { zalsa.table().memos::<$zalsa_struct::Value<$Configuration>>(id, current_revision) } + } } impl $Struct { diff --git a/components/salsa-macro-rules/src/setup_interned_struct.rs b/components/salsa-macro-rules/src/setup_interned_struct.rs index 3a62355cd..19aeaa53a 100644 --- a/components/salsa-macro-rules/src/setup_interned_struct.rs +++ b/components/salsa-macro-rules/src/setup_interned_struct.rs @@ -202,6 +202,16 @@ macro_rules! setup_interned_struct { $zalsa::None } } + + #[inline] + unsafe fn memo_table( + zalsa: &$zalsa::Zalsa, + id: $zalsa::Id, + current_revision: $zalsa::Revision, + ) -> $zalsa::MemoTableWithTypes<'_> { + // SAFETY: Guaranteed by caller. + unsafe { zalsa.table().memos::<$zalsa_struct::Value<$Configuration>>(id, current_revision) } + } } unsafe impl< $($db_lt_arg)? > $zalsa::Update for $Struct< $($db_lt_arg)? > { diff --git a/components/salsa-macro-rules/src/setup_tracked_fn.rs b/components/salsa-macro-rules/src/setup_tracked_fn.rs index 477070714..b05030241 100644 --- a/components/salsa-macro-rules/src/setup_tracked_fn.rs +++ b/components/salsa-macro-rules/src/setup_tracked_fn.rs @@ -130,6 +130,16 @@ macro_rules! setup_tracked_fn { None } } + + #[inline] + unsafe fn memo_table( + zalsa: &$zalsa::Zalsa, + id: $zalsa::Id, + current_revision: $zalsa::Revision, + ) -> $zalsa::MemoTableWithTypes<'_> { + // SAFETY: Guaranteed by caller. + unsafe { zalsa.table().memos::<$zalsa::interned::Value<$Configuration>>(id, current_revision) } + } } impl $zalsa::AsId for $InternedData<'_> { diff --git a/components/salsa-macro-rules/src/setup_tracked_struct.rs b/components/salsa-macro-rules/src/setup_tracked_struct.rs index 07131735e..0b5c115bf 100644 --- a/components/salsa-macro-rules/src/setup_tracked_struct.rs +++ b/components/salsa-macro-rules/src/setup_tracked_struct.rs @@ -231,6 +231,16 @@ macro_rules! 
setup_tracked_struct { $zalsa::None } } + + #[inline] + unsafe fn memo_table( + zalsa: &$zalsa::Zalsa, + id: $zalsa::Id, + current_revision: $zalsa::Revision, + ) -> $zalsa::MemoTableWithTypes<'_> { + // SAFETY: Guaranteed by caller. + unsafe { zalsa.table().memos::<$zalsa_struct::Value<$Configuration>>(id, current_revision) } + } } impl $zalsa::TrackedStructInDb for $Struct<'_> { diff --git a/components/salsa-macros/src/supertype.rs b/components/salsa-macros/src/supertype.rs index d1c6c70b8..ebf7f4516 100644 --- a/components/salsa-macros/src/supertype.rs +++ b/components/salsa-macros/src/supertype.rs @@ -89,6 +89,19 @@ fn enum_impl(enum_item: syn::ItemEnum) -> syn::Result { None } } + + #[inline] + unsafe fn memo_table( + zalsa: &zalsa::Zalsa, + id: zalsa::Id, + current_revision: zalsa::Revision, + ) -> zalsa::MemoTableWithTypes<'_> { + // Note that we need to use `dyn_memos` here, as the `Id` could map to any variant + // of the supertype enum. + // + // SAFETY: Guaranteed by caller. + unsafe { zalsa.table().dyn_memos(id, current_revision) } + } } }; diff --git a/src/function/memo.rs b/src/function/memo.rs index 8f8393fc6..37e600070 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -30,7 +30,7 @@ impl IngredientImpl { let static_memo = unsafe { transmute::>, NonNull>>(memo) }; let old_static_memo = zalsa - .memo_table_for(id) + .memo_table_for::>(id) .insert(memo_ingredient_index, static_memo)?; // SAFETY: The table stores 'static memos (to support `Any`), the memos are in fact valid // for `'db` though as we delay their dropping to the end of a revision. 
@@ -48,7 +48,9 @@ impl IngredientImpl { id: Id, memo_ingredient_index: MemoIngredientIndex, ) -> Option<&'db Memo<'db, C>> { - let static_memo = zalsa.memo_table_for(id).get(memo_ingredient_index)?; + let static_memo = zalsa + .memo_table_for::>(id) + .get(memo_ingredient_index)?; // SAFETY: The table stores 'static memos (to support `Any`), the memos are in fact valid // for `'db` though as we delay their dropping to the end of a revision. Some(unsafe { transmute::<&Memo<'static, C>, &'db Memo<'db, C>>(static_memo.as_ref()) }) @@ -451,8 +453,9 @@ mod _memory_usage { use crate::cycle::CycleRecoveryStrategy; use crate::ingredient::Location; use crate::plumbing::{IngredientIndices, MemoIngredientSingletonIndex, SalsaStructInDb}; + use crate::table::memo::MemoTableWithTypes; use crate::zalsa::Zalsa; - use crate::{CycleRecoveryAction, Database, Id}; + use crate::{CycleRecoveryAction, Database, Id, Revision}; use std::any::TypeId; use std::num::NonZeroUsize; @@ -473,6 +476,10 @@ mod _memory_usage { fn cast(_: Id, _: TypeId) -> Option { unimplemented!() } + + unsafe fn memo_table(_: &Zalsa, _: Id, _: Revision) -> MemoTableWithTypes<'_> { + unimplemented!() + } } struct DummyConfiguration; diff --git a/src/lib.rs b/src/lib.rs index 171452250..cbc919377 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -104,6 +104,7 @@ pub mod plumbing { pub use crate::runtime::{stamp, Runtime, Stamp}; pub use crate::salsa_struct::SalsaStructInDb; pub use crate::storage::{HasStorage, Storage}; + pub use crate::table::memo::MemoTableWithTypes; pub use crate::tracked_struct::TrackedStructInDb; pub use crate::update::helper::{Dispatch as UpdateDispatch, Fallback as UpdateFallback}; pub use crate::update::{always_update, Update}; diff --git a/src/salsa_struct.rs b/src/salsa_struct.rs index 725b308a4..cb3307e65 100644 --- a/src/salsa_struct.rs +++ b/src/salsa_struct.rs @@ -1,8 +1,9 @@ use std::any::TypeId; use crate::memo_ingredient_indices::{IngredientIndices, MemoIngredientMap}; +use 
crate::table::memo::MemoTableWithTypes; use crate::zalsa::Zalsa; -use crate::Id; +use crate::{Id, Revision}; pub trait SalsaStructInDb: Sized { type MemoIngredientMap: MemoIngredientMap; @@ -62,4 +63,16 @@ pub trait SalsaStructInDb: Sized { /// Why `TypeId` and not `IngredientIndex`? Because it's cheaper and easier: the `TypeId` is readily /// available at compile time, while the `IngredientIndex` requires a runtime lookup. fn cast(id: Id, type_id: TypeId) -> Option; + + /// Return the memo table associated with `id`. + /// + /// # Safety + /// + /// The parameter `current_revision` must be the current revision of the owner of database + /// owning this table. + unsafe fn memo_table( + zalsa: &Zalsa, + id: Id, + current_revision: Revision, + ) -> MemoTableWithTypes<'_>; } diff --git a/src/table.rs b/src/table.rs index 7132de1ae..84415a4a0 100644 --- a/src/table.rs +++ b/src/table.rs @@ -34,7 +34,7 @@ pub struct Table { /// /// Implementors of this trait need to make sure that their type is unique with respect to /// their owning ingredient as the allocation strategy relies on this. -pub(crate) unsafe trait Slot: Any + Send + Sync { +pub unsafe trait Slot: Any + Send + Sync { /// Access the [`MemoTable`][] for this slot. /// /// # Safety condition @@ -220,17 +220,42 @@ impl Table { PageIndex::new(self.pages.push(Page::new::(ingredient, memo_types))) } - /// Get the memo table associated with `id` + /// Get the memo table associated with `id` for the concrete type `T`. /// - /// # Safety condition + /// # Safety /// - /// The parameter `current_revision` MUST be the current revision - /// of the owner of database owning this table. - pub(crate) unsafe fn memos( + /// The parameter `current_revision` must be the current revision of the database + /// owning this table. + /// + /// # Panics + /// + /// If `page` is out of bounds or the type `T` is incorrect. 
+ pub unsafe fn memos( &self, id: Id, current_revision: Revision, ) -> MemoTableWithTypes<'_> { + let (page, slot) = split_id(id); + let page = self.pages[page.0].assert_type::(); + let slot = &page.data()[slot.0]; + + // SAFETY: The caller is required to pass the `current_revision`. + let memos = unsafe { slot.memos(current_revision) }; + + // SAFETY: The `Page` keeps the correct memo types. + unsafe { page.0.memo_types.attach_memos(memos) } + } + + /// Get the memo table associated with `id`. + /// + /// Unlike `Table::memos`, this does not require a concrete type, and instead uses dynamic + /// dispatch. + /// + /// # Safety + /// + /// The parameter `current_revision` must be the current revision of the owner of database + /// owning this table. + pub unsafe fn dyn_memos(&self, id: Id, current_revision: Revision) -> MemoTableWithTypes<'_> { let (page, slot) = split_id(id); let page = &self.pages[page.0]; // SAFETY: We supply a proper slot pointer and the caller is required to pass the `current_revision`. @@ -373,6 +398,7 @@ impl Page { slot.0 < len, "out of bounds access `{slot:?}` (maximum slot `{len}`)" ); + // SAFETY: We have checked that the resulting pointer will be within bounds. unsafe { self.data diff --git a/src/table/memo.rs b/src/table/memo.rs index 63d3671ad..3d5f9bc17 100644 --- a/src/table/memo.rs +++ b/src/table/memo.rs @@ -9,7 +9,7 @@ use crate::{zalsa::MemoIngredientIndex, zalsa_local::QueryOriginRef}; /// The "memo table" stores the memoized results of tracked function calls. /// Every tracked function must take a salsa struct as its first argument /// and memo tables are attached to those salsa structs as auxiliary data. 
-pub(crate) struct MemoTable { +pub struct MemoTable { memos: Box<[MemoEntry]>, } @@ -168,7 +168,7 @@ impl MemoTableTypes { } } -pub(crate) struct MemoTableWithTypes<'a> { +pub struct MemoTableWithTypes<'a> { types: &'a MemoTableTypes, memos: &'a MemoTable, } diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index fd93250aa..e662fa902 100644 --- a/src/tracked_struct.rs +++ b/src/tracked_struct.rs @@ -954,9 +954,8 @@ where { #[inline(always)] unsafe fn memos(&self, current_revision: Revision) -> &crate::table::memo::MemoTable { - // Acquiring the read lock here with the current revision - // ensures that there is no danger of a race - // when deleting a tracked struct. + // Acquiring the read lock here with the current revision to ensure that there + // is no danger of a race when deleting a tracked struct. self.read_lock(current_revision); &self.memos } diff --git a/src/zalsa.rs b/src/zalsa.rs index c1c46296d..9f5a1fb3c 100644 --- a/src/zalsa.rs +++ b/src/zalsa.rs @@ -7,6 +7,7 @@ use rustc_hash::FxHashMap; use crate::hash::TypeIdHasher; use crate::ingredient::{Ingredient, Jar}; +use crate::plumbing::SalsaStructInDb; use crate::runtime::Runtime; use crate::table::memo::MemoTableWithTypes; use crate::table::Table; @@ -225,15 +226,14 @@ impl Zalsa { /// Returns the [`Table`] used to store the value of salsa structs #[inline] - pub(crate) fn table(&self) -> &Table { + pub fn table(&self) -> &Table { self.runtime.table() } /// Returns the [`MemoTable`][] for the salsa struct with the given id - pub(crate) fn memo_table_for(&self, id: Id) -> MemoTableWithTypes<'_> { - let table = self.table(); - // SAFETY: We are supplying the correct current revision - unsafe { table.memos(id, self.current_revision()) } + pub(crate) fn memo_table_for(&self, id: Id) -> MemoTableWithTypes<'_> { + // SAFETY: We are supplying the correct current revision. 
+ unsafe { T::memo_table(self, id, self.current_revision()) } } #[inline] diff --git a/tests/interned-structs_self_ref.rs b/tests/interned-structs_self_ref.rs index 556c49607..3ff12a09c 100644 --- a/tests/interned-structs_self_ref.rs +++ b/tests/interned-structs_self_ref.rs @@ -139,6 +139,20 @@ const _: () = { None } } + + #[inline] + unsafe fn memo_table( + zalsa: &zalsa_::Zalsa, + id: zalsa_::Id, + current_revision: zalsa_::Revision, + ) -> zalsa_::MemoTableWithTypes<'_> { + // SAFETY: Guaranteed by caller. + unsafe { + zalsa + .table() + .memos::>(id, current_revision) + } + } } unsafe impl zalsa_::Update for InternedString<'_> { From 53cd6b15ba3b6d6e405c043fd4c2eb187358d812 Mon Sep 17 00:00:00 2001 From: Ibraheem Ahmed Date: Mon, 21 Jul 2025 12:10:10 -0400 Subject: [PATCH 07/65] remove bounds and type checks from `IngredientCache` (#937) --- .../src/setup_accumulator_impl.rs | 10 ++-- .../src/setup_input_struct.rs | 10 ++-- .../src/setup_interned_struct.rs | 11 +++-- .../salsa-macro-rules/src/setup_tracked_fn.rs | 21 +++++--- .../src/setup_tracked_struct.rs | 10 ++-- src/ingredient.rs | 25 +++++++++- src/ingredient_cache.rs | 49 ++++++++++++++++--- src/tracked_struct.rs | 24 ++++----- src/zalsa.rs | 24 +++++++-- src/zalsa_local.rs | 6 ++- tests/interned-structs_self_ref.rs | 11 +++-- 11 files changed, 155 insertions(+), 46 deletions(-) diff --git a/components/salsa-macro-rules/src/setup_accumulator_impl.rs b/components/salsa-macro-rules/src/setup_accumulator_impl.rs index 7842067e7..3edbe9c2f 100644 --- a/components/salsa-macro-rules/src/setup_accumulator_impl.rs +++ b/components/salsa-macro-rules/src/setup_accumulator_impl.rs @@ -34,9 +34,13 @@ macro_rules! 
setup_accumulator_impl { static $CACHE: $zalsa::IngredientCache<$zalsa_struct::IngredientImpl<$Struct>> = $zalsa::IngredientCache::new(); - $CACHE.get_or_create(zalsa, || { - zalsa.lookup_jar_by_type::<$zalsa_struct::JarImpl<$Struct>>() - }) + // SAFETY: `lookup_jar_by_type` returns a valid ingredient index, and the only + // ingredient created by our jar is the struct ingredient. + unsafe { + $CACHE.get_or_create(zalsa, || { + zalsa.lookup_jar_by_type::<$zalsa_struct::JarImpl<$Struct>>() + }) + } } impl $zalsa::Accumulator for $Struct { diff --git a/components/salsa-macro-rules/src/setup_input_struct.rs b/components/salsa-macro-rules/src/setup_input_struct.rs index 06315e07a..fed0594a6 100644 --- a/components/salsa-macro-rules/src/setup_input_struct.rs +++ b/components/salsa-macro-rules/src/setup_input_struct.rs @@ -109,9 +109,13 @@ macro_rules! setup_input_struct { static CACHE: $zalsa::IngredientCache<$zalsa_struct::IngredientImpl<$Configuration>> = $zalsa::IngredientCache::new(); - CACHE.get_or_create(zalsa, || { - zalsa.lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>() - }) + // SAFETY: `lookup_jar_by_type` returns a valid ingredient index, and the only + // ingredient created by our jar is the struct ingredient. + unsafe { + CACHE.get_or_create(zalsa, || { + zalsa.lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>() + }) + } } pub fn ingredient_mut(db: &mut dyn $zalsa::Database) -> (&mut $zalsa_struct::IngredientImpl, &mut $zalsa::Runtime) { diff --git a/components/salsa-macro-rules/src/setup_interned_struct.rs b/components/salsa-macro-rules/src/setup_interned_struct.rs index 19aeaa53a..9b121fdc1 100644 --- a/components/salsa-macro-rules/src/setup_interned_struct.rs +++ b/components/salsa-macro-rules/src/setup_interned_struct.rs @@ -157,9 +157,14 @@ macro_rules! 
setup_interned_struct { $zalsa::IngredientCache::new(); let zalsa = db.zalsa(); - CACHE.get_or_create(zalsa, || { - zalsa.lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>() - }) + + // SAFETY: `lookup_jar_by_type` returns a valid ingredient index, and the only + // ingredient created by our jar is the struct ingredient. + unsafe { + CACHE.get_or_create(zalsa, || { + zalsa.lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>() + }) + } } } diff --git a/components/salsa-macro-rules/src/setup_tracked_fn.rs b/components/salsa-macro-rules/src/setup_tracked_fn.rs index b05030241..a79007592 100644 --- a/components/salsa-macro-rules/src/setup_tracked_fn.rs +++ b/components/salsa-macro-rules/src/setup_tracked_fn.rs @@ -175,9 +175,13 @@ macro_rules! setup_tracked_fn { impl $Configuration { fn fn_ingredient(db: &dyn $Db) -> &$zalsa::function::IngredientImpl<$Configuration> { let zalsa = db.zalsa(); - $FN_CACHE - .get_or_create(zalsa, || zalsa.lookup_jar_by_type::<$fn_name>()) - .get_or_init(|| ::zalsa_register_downcaster(db)) + + // SAFETY: `lookup_jar_by_type` returns a valid ingredient index, and the first + // ingredient created by our jar is the function ingredient. + unsafe { + $FN_CACHE.get_or_create(zalsa, || zalsa.lookup_jar_by_type::<$fn_name>()) + } + .get_or_init(|| ::zalsa_register_downcaster(db)) } pub fn fn_ingredient_mut(db: &mut dyn $Db) -> &mut $zalsa::function::IngredientImpl { @@ -195,9 +199,14 @@ macro_rules! setup_tracked_fn { db: &dyn $Db, ) -> &$zalsa::interned::IngredientImpl<$Configuration> { let zalsa = db.zalsa(); - $INTERN_CACHE.get_or_create(zalsa, || { - zalsa.lookup_jar_by_type::<$fn_name>().successor(0) - }) + + // SAFETY: `lookup_jar_by_type` returns a valid ingredient index, and the second + // ingredient created by our jar is the interned ingredient (given `needs_interner`). 
+ unsafe { + $INTERN_CACHE.get_or_create(zalsa, || { + zalsa.lookup_jar_by_type::<$fn_name>().successor(0) + }) + } } } } diff --git a/components/salsa-macro-rules/src/setup_tracked_struct.rs b/components/salsa-macro-rules/src/setup_tracked_struct.rs index 0b5c115bf..5545a44fd 100644 --- a/components/salsa-macro-rules/src/setup_tracked_struct.rs +++ b/components/salsa-macro-rules/src/setup_tracked_struct.rs @@ -196,9 +196,13 @@ macro_rules! setup_tracked_struct { static CACHE: $zalsa::IngredientCache<$zalsa_struct::IngredientImpl<$Configuration>> = $zalsa::IngredientCache::new(); - CACHE.get_or_create(zalsa, || { - zalsa.lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>() - }) + // SAFETY: `lookup_jar_by_type` returns a valid ingredient index, and the only + // ingredient created by our jar is the struct ingredient. + unsafe { + CACHE.get_or_create(zalsa, || { + zalsa.lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>() + }) + } } } diff --git a/src/ingredient.rs b/src/ingredient.rs index 796a6e12f..fb567948b 100644 --- a/src/ingredient.rs +++ b/src/ingredient.rs @@ -177,7 +177,8 @@ pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { } impl dyn Ingredient { - /// Equivalent to the `downcast` methods on `any`. + /// Equivalent to the `downcast` method on `Any`. + /// /// Because we do not have dyn-upcasting support, we need this workaround. pub fn assert_type(&self) -> &T { assert_eq!( @@ -192,7 +193,27 @@ impl dyn Ingredient { unsafe { transmute_data_ptr(self) } } - /// Equivalent to the `downcast` methods on `any`. + /// Equivalent to the `downcast` methods on `Any`. + /// + /// Because we do not have dyn-upcasting support, we need this workaround. + /// + /// # Safety + /// + /// The contained value must be of type `T`. 
+ pub unsafe fn assert_type_unchecked(&self) -> &T { + debug_assert_eq!( + self.type_id(), + TypeId::of::(), + "ingredient `{self:?}` is not of type `{}`", + std::any::type_name::() + ); + + // SAFETY: Guaranteed by caller. + unsafe { transmute_data_ptr(self) } + } + + /// Equivalent to the `downcast` method on `Any`. + /// /// Because we do not have dyn-upcasting support, we need this workaround. pub fn assert_type_mut(&mut self) -> &mut T { assert_eq!( diff --git a/src/ingredient_cache.rs b/src/ingredient_cache.rs index 8b9ebe76b..6e36c2cff 100644 --- a/src/ingredient_cache.rs +++ b/src/ingredient_cache.rs @@ -47,7 +47,12 @@ mod imp { /// Get a reference to the ingredient in the database. /// /// If the ingredient index is not already in the cache, it will be loaded and cached. - pub fn get_or_create<'db>( + /// + /// # Safety + /// + /// The `IngredientIndex` returned by the closure must reference a valid ingredient of + /// type `I` in the provided zalsa database. + pub unsafe fn get_or_create<'db>( &self, zalsa: &'db Zalsa, load_index: impl Fn() -> IngredientIndex, @@ -57,9 +62,21 @@ mod imp { ingredient_index = self.get_or_create_index_slow(load_index).as_u32(); }; - zalsa - .lookup_ingredient(IngredientIndex::from_unchecked(ingredient_index)) - .assert_type() + // SAFETY: `ingredient_index` is initialized from a valid `IngredientIndex`. + let ingredient_index = unsafe { IngredientIndex::new_unchecked(ingredient_index) }; + + // SAFETY: There are a two cases here: + // - The `create_index` closure was called due to the data being uncached. In this + // case, the caller guarantees the index is in-bounds and has the correct type. + // - The index was cached. While the current database might not be the same database + // the ingredient was initially loaded from, the `inventory` feature is enabled, so + // ingredient indices are stable across databases. Thus the index is still in-bounds + // and has the correct type. 
+ unsafe { + zalsa + .lookup_ingredient_unchecked(ingredient_index) + .assert_type_unchecked() + } } #[cold] @@ -134,14 +151,30 @@ mod imp { /// Get a reference to the ingredient in the database. /// /// If the ingredient is not already in the cache, it will be created. + /// + /// # Safety + /// + /// The `IngredientIndex` returned by the closure must reference a valid ingredient of + /// type `I` in the provided zalsa database. #[inline(always)] - pub fn get_or_create<'db>( + pub unsafe fn get_or_create<'db>( &self, zalsa: &'db Zalsa, create_index: impl Fn() -> IngredientIndex, ) -> &'db I { let index = self.get_or_create_index(zalsa, create_index); - zalsa.lookup_ingredient(index).assert_type::() + + // SAFETY: There are a two cases here: + // - The `create_index` closure was called due to the data being uncached for the + // provided database. In this case, the caller guarantees the index is in-bounds + // and has the correct type. + // - We verified the index was cached for the same database, by the nonce check. + // Thus the initial safety argument still applies. + unsafe { + zalsa + .lookup_ingredient_unchecked(index) + .assert_type_unchecked::() + } } pub fn get_or_create_index( @@ -159,7 +192,9 @@ mod imp { }; // Unpack our `u64` into the nonce and index. - let index = IngredientIndex::from_unchecked(cached_data as u32); + // + // SAFETY: The lower bits of `cached_data` are initialized from a valid `IngredientIndex`. + let index = unsafe { IngredientIndex::new_unchecked(cached_data as u32) }; // SAFETY: We've checked against `UNINITIALIZED` (0) above and so the upper bits must be non-zero. 
let nonce = crate::nonce::Nonce::::from_u32(unsafe { diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index e662fa902..ba8637558 100644 --- a/src/tracked_struct.rs +++ b/src/tracked_struct.rs @@ -975,19 +975,19 @@ mod tests { let mut d = DisambiguatorMap::default(); // set up all 4 permutations of differing field values let h1 = IdentityHash { - ingredient_index: IngredientIndex::from(0), + ingredient_index: IngredientIndex::new(0), hash: 0, }; let h2 = IdentityHash { - ingredient_index: IngredientIndex::from(1), + ingredient_index: IngredientIndex::new(1), hash: 0, }; let h3 = IdentityHash { - ingredient_index: IngredientIndex::from(0), + ingredient_index: IngredientIndex::new(0), hash: 1, }; let h4 = IdentityHash { - ingredient_index: IngredientIndex::from(1), + ingredient_index: IngredientIndex::new(1), hash: 1, }; assert_eq!(d.disambiguate(h1), Disambiguator(0)); @@ -1005,42 +1005,42 @@ mod tests { let mut d = IdentityMap::default(); // set up all 8 permutations of differing field values let i1 = Identity { - ingredient_index: IngredientIndex::from(0), + ingredient_index: IngredientIndex::new(0), hash: 0, disambiguator: Disambiguator(0), }; let i2 = Identity { - ingredient_index: IngredientIndex::from(1), + ingredient_index: IngredientIndex::new(1), hash: 0, disambiguator: Disambiguator(0), }; let i3 = Identity { - ingredient_index: IngredientIndex::from(0), + ingredient_index: IngredientIndex::new(0), hash: 1, disambiguator: Disambiguator(0), }; let i4 = Identity { - ingredient_index: IngredientIndex::from(1), + ingredient_index: IngredientIndex::new(1), hash: 1, disambiguator: Disambiguator(0), }; let i5 = Identity { - ingredient_index: IngredientIndex::from(0), + ingredient_index: IngredientIndex::new(0), hash: 0, disambiguator: Disambiguator(1), }; let i6 = Identity { - ingredient_index: IngredientIndex::from(1), + ingredient_index: IngredientIndex::new(1), hash: 0, disambiguator: Disambiguator(1), }; let i7 = Identity { - ingredient_index: 
IngredientIndex::from(0), + ingredient_index: IngredientIndex::new(0), hash: 1, disambiguator: Disambiguator(1), }; let i8 = Identity { - ingredient_index: IngredientIndex::from(1), + ingredient_index: IngredientIndex::new(1), hash: 1, disambiguator: Disambiguator(1), }; diff --git a/src/zalsa.rs b/src/zalsa.rs index 9f5a1fb3c..751548972 100644 --- a/src/zalsa.rs +++ b/src/zalsa.rs @@ -81,14 +81,18 @@ impl IngredientIndex { const MAX_INDEX: u32 = 0x7FFF_FFFF; /// Create an ingredient index from a `u32`. - pub(crate) fn from(v: u32) -> Self { + pub(crate) fn new(v: u32) -> Self { assert!(v <= Self::MAX_INDEX); Self(v) } /// Create an ingredient index from a `u32`, without performing validating /// that the index is valid. - pub(crate) fn from_unchecked(v: u32) -> Self { + /// + /// # Safety + /// + /// The index must be less than or equal to `IngredientIndex::MAX_INDEX`. + pub(crate) unsafe fn new_unchecked(v: u32) -> Self { Self(v) } @@ -236,6 +240,7 @@ impl Zalsa { unsafe { T::memo_table(self, id, self.current_revision()) } } + /// Returns the ingredient at the given index, or panics if it is out-of-bounds. #[inline] pub fn lookup_ingredient(&self, index: IngredientIndex) -> &dyn Ingredient { let index = index.as_u32() as usize; @@ -245,6 +250,19 @@ impl Zalsa { .as_ref() } + /// Returns the ingredient at the given index. + /// + /// # Safety + /// + /// The index must be in-bounds. + #[inline] + pub unsafe fn lookup_ingredient_unchecked(&self, index: IngredientIndex) -> &dyn Ingredient { + let index = index.as_u32() as usize; + + // SAFETY: Guaranteed by caller. 
+ unsafe { self.ingredients_vec.get_unchecked(index).as_ref() } + } + pub(crate) fn ingredient_index_for_memo( &self, struct_ingredient_index: IngredientIndex, @@ -331,7 +349,7 @@ impl Zalsa { fn insert_jar(&mut self, jar: ErasedJar) { let jar_type_id = (jar.type_id)(); - let index = IngredientIndex::from(self.ingredients_vec.len() as u32); + let index = IngredientIndex::new(self.ingredients_vec.len() as u32); if self.jar_map.contains_key(&jar_type_id) { return; diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs index 51c28c9c5..9b89ab405 100644 --- a/src/zalsa_local.rs +++ b/src/zalsa_local.rs @@ -754,7 +754,11 @@ impl QueryOrigin { QueryOriginKind::Assigned => { // SAFETY: `data.index` is initialized when the tag is `QueryOriginKind::Assigned`. let index = unsafe { self.data.index }; - let ingredient_index = IngredientIndex::from(self.metadata); + + // SAFETY: `metadata` is initialized from a valid `IngredientIndex` when the tag + // is `QueryOriginKind::Assigned`. + let ingredient_index = unsafe { IngredientIndex::new_unchecked(self.metadata) }; + QueryOriginRef::Assigned(DatabaseKeyIndex::new(ingredient_index, index)) } diff --git a/tests/interned-structs_self_ref.rs b/tests/interned-structs_self_ref.rs index 3ff12a09c..55eb8c06f 100644 --- a/tests/interned-structs_self_ref.rs +++ b/tests/interned-structs_self_ref.rs @@ -99,9 +99,14 @@ const _: () = { zalsa_::IngredientCache::new(); let zalsa = db.zalsa(); - CACHE.get_or_create(zalsa, || { - zalsa.lookup_jar_by_type::>() - }) + + // SAFETY: `lookup_jar_by_type` returns a valid ingredient index, and the only + // ingredient created by our jar is the struct ingredient. 
+ unsafe { + CACHE.get_or_create(zalsa, || { + zalsa.lookup_jar_by_type::>() + }) + } } } impl zalsa_::AsId for InternedString<'_> { From bb0831a64046d423f2517b2173d991239d4f5022 Mon Sep 17 00:00:00 2001 From: Ibraheem Ahmed Date: Tue, 22 Jul 2025 14:45:02 -0400 Subject: [PATCH 08/65] Outline all tracing events (#942) * outline all tracing events * outline log events --- src/database_impl.rs | 2 +- src/function/backdate.rs | 2 +- src/function/execute.rs | 14 ++++---- src/function/fetch.rs | 7 ++-- src/function/maybe_changed_after.rs | 18 +++++----- src/function/memo.rs | 12 +++---- src/function/specify.rs | 2 +- src/lib.rs | 1 + src/runtime.rs | 8 ++--- src/tracing.rs | 54 +++++++++++++++++++++++++++++ src/tracked_struct.rs | 8 ++--- src/zalsa.rs | 16 ++++++--- src/zalsa_local.rs | 15 +++++--- 13 files changed, 116 insertions(+), 43 deletions(-) create mode 100644 src/tracing.rs diff --git a/src/database_impl.rs b/src/database_impl.rs index c1eda125a..8b8c9bd25 100644 --- a/src/database_impl.rs +++ b/src/database_impl.rs @@ -16,7 +16,7 @@ impl Default for DatabaseImpl { // Default behavior: tracing debug log the event. 
storage: Storage::new(if tracing::enabled!(Level::DEBUG) { Some(Box::new(|event| { - tracing::debug!("salsa_event({:?})", event) + crate::tracing::debug!("salsa_event({:?})", event) })) } else { None diff --git a/src/function/backdate.rs b/src/function/backdate.rs index 873041597..f735f577b 100644 --- a/src/function/backdate.rs +++ b/src/function/backdate.rs @@ -34,7 +34,7 @@ where if revisions.durability >= old_memo.revisions.durability && C::values_equal(old_value, value) { - tracing::debug!( + crate::tracing::debug!( "{index:?} value is equal, back-dating to {:?}", old_memo.revisions.changed_at, ); diff --git a/src/function/execute.rs b/src/function/execute.rs index 13ecca561..9013ee7fe 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -29,7 +29,7 @@ where let database_key_index = active_query.database_key_index; let id = database_key_index.key_index(); - tracing::info!("{:?}: executing query", database_key_index); + crate::tracing::info!("{:?}: executing query", database_key_index); let zalsa = db.zalsa(); zalsa.event(&|| { @@ -169,7 +169,7 @@ where }; // SAFETY: The `LRU` does not run mid-execution, so the value remains filled let last_provisional_value = unsafe { last_provisional_value.unwrap_unchecked() }; - tracing::debug!( + crate::tracing::debug!( "{database_key_index:?}: execute: \ I am a cycle head, comparing last provisional value with new value" ); @@ -196,7 +196,7 @@ where ) { crate::CycleRecoveryAction::Iterate => {} crate::CycleRecoveryAction::Fallback(fallback_value) => { - tracing::debug!( + crate::tracing::debug!( "{database_key_index:?}: execute: user cycle_fn says to fall back" ); new_value = fallback_value; @@ -220,7 +220,7 @@ where }); cycle_heads.update_iteration_count(database_key_index, iteration_count); revisions.update_iteration_count(iteration_count); - tracing::debug!( + crate::tracing::debug!( "{database_key_index:?}: execute: iterate again, revisions: {revisions:#?}" ); opt_last_provisional = 
Some(self.insert_memo( @@ -236,7 +236,7 @@ where continue; } - tracing::debug!( + crate::tracing::debug!( "{database_key_index:?}: execute: fixpoint iteration has a final value" ); cycle_heads.remove(&database_key_index); @@ -247,7 +247,9 @@ where } } - tracing::debug!("{database_key_index:?}: execute: result.revisions = {revisions:#?}"); + crate::tracing::debug!( + "{database_key_index:?}: execute: result.revisions = {revisions:#?}" + ); break (new_value, revisions); } diff --git a/src/function/fetch.rs b/src/function/fetch.rs index 6c7819f81..252bba124 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -15,8 +15,9 @@ where zalsa.unwind_if_revision_cancelled(zalsa_local); let database_key_index = self.database_key_index(id); + #[cfg(debug_assertions)] - let _span = tracing::debug_span!("fetch", query = ?database_key_index).entered(); + let _span = crate::tracing::debug_span!("fetch", query = ?database_key_index).entered(); let memo = self.refresh_memo(db, zalsa, zalsa_local, id); // SAFETY: We just refreshed the memo so it is guaranteed to contain a value now. 
@@ -169,7 +170,7 @@ where ); }), CycleRecoveryStrategy::Fixpoint => { - tracing::debug!( + crate::tracing::debug!( "hit cycle at {database_key_index:#?}, \ inserting and returning fixpoint initial value" ); @@ -183,7 +184,7 @@ where )) } CycleRecoveryStrategy::FallbackImmediate => { - tracing::debug!( + crate::tracing::debug!( "hit a `FallbackImmediate` cycle at {database_key_index:#?}" ); let active_query = diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index 9d0ca4c44..7df0a41fe 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -53,7 +53,9 @@ where loop { let database_key_index = self.database_key_index(id); - tracing::debug!("{database_key_index:?}: maybe_changed_after(revision = {revision:?})"); + crate::tracing::debug!( + "{database_key_index:?}: maybe_changed_after(revision = {revision:?})" + ); // Check if we have a verified version: this is the hot path. let memo_guard = self.get_memo_from_table_for(zalsa, id, memo_ingredient_index); @@ -117,7 +119,7 @@ where return Some(VerifyResult::unchanged()); } CycleRecoveryStrategy::Fixpoint => { - tracing::debug!( + crate::tracing::debug!( "hit cycle at {database_key_index:?} in `maybe_changed_after`, returning fixpoint initial value", ); cycle_heads.push_initial(database_key_index); @@ -132,7 +134,7 @@ where return Some(VerifyResult::Changed); }; - tracing::debug!( + crate::tracing::debug!( "{database_key_index:?}: maybe_changed_after_cold, successful claim, \ revision = {revision:?}, old_memo = {old_memo:#?}", old_memo = old_memo.tracing_debug() @@ -194,7 +196,7 @@ where database_key_index: DatabaseKeyIndex, memo: &Memo<'_, C>, ) -> ShallowUpdate { - tracing::debug!( + crate::tracing::debug!( "{database_key_index:?}: shallow_verify_memo(memo = {memo:#?})", memo = memo.tracing_debug() ); @@ -207,7 +209,7 @@ where } let last_changed = zalsa.last_changed_revision(memo.revisions.durability); - tracing::debug!( + crate::tracing::debug!( 
"{database_key_index:?}: check_durability(memo = {memo:#?}, last_changed={:?} <= verified_at={:?}) = {:?}", last_changed, verified_at, @@ -263,7 +265,7 @@ where database_key_index: DatabaseKeyIndex, memo: &Memo<'_, C>, ) -> bool { - tracing::trace!( + crate::tracing::trace!( "{database_key_index:?}: validate_provisional(memo = {memo:#?})", memo = memo.tracing_debug() ); @@ -324,7 +326,7 @@ where database_key_index: DatabaseKeyIndex, memo: &Memo<'_, C>, ) -> bool { - tracing::trace!( + crate::tracing::trace!( "{database_key_index:?}: validate_same_iteration(memo = {memo:#?})", memo = memo.tracing_debug() ); @@ -377,7 +379,7 @@ where database_key_index: DatabaseKeyIndex, cycle_heads: &mut CycleHeads, ) -> VerifyResult { - tracing::debug!( + crate::tracing::debug!( "{database_key_index:?}: deep_verify_memo(old_memo = {old_memo:#?})", old_memo = old_memo.tracing_debug() ); diff --git a/src/function/memo.rs b/src/function/memo.rs index 37e600070..d6a872b69 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -153,7 +153,7 @@ impl<'db, C: Configuration> Memo<'db, C> { } else { // all our cycle heads are complete; re-fetch // and we should get a non-provisional memo. - tracing::debug!( + crate::tracing::debug!( "Retrying provisional memo {database_key_index:?} after awaiting cycle heads." ); true @@ -180,7 +180,7 @@ impl<'db, C: Configuration> Memo<'db, C> { #[inline(never)] fn block_on_heads_cold(zalsa: &Zalsa, heads: &CycleHeads) -> bool { - let _entered = tracing::debug_span!("block_on_heads").entered(); + let _entered = crate::tracing::debug_span!("block_on_heads").entered(); let mut cycle_heads = TryClaimCycleHeadsIter::new(zalsa, heads); let mut all_cycles = true; @@ -209,7 +209,7 @@ impl<'db, C: Configuration> Memo<'db, C> { /// Unlike `block_on_heads`, this code does not block on any cycle head. Instead it returns `false` if /// claiming all cycle heads failed because one of them is running on another thread. 
pub(super) fn try_claim_heads(&self, zalsa: &Zalsa, zalsa_local: &ZalsaLocal) -> bool { - let _entered = tracing::debug_span!("try_claim_heads").entered(); + let _entered = crate::tracing::debug_span!("try_claim_heads").entered(); if self.all_cycles_on_stack(zalsa_local) { return true; } @@ -419,7 +419,7 @@ impl<'me> Iterator for TryClaimCycleHeadsIter<'me> { ProvisionalStatus::Final { .. } | ProvisionalStatus::FallbackImmediate => { // This cycle is already finalized, so we don't need to wait on it; // keep looping through cycle heads. - tracing::trace!("Dependent cycle head {head:?} has been finalized."); + crate::tracing::trace!("Dependent cycle head {head:?} has been finalized."); Some(TryClaimHeadsResult::Finalized) } ProvisionalStatus::Provisional { .. } => { @@ -427,11 +427,11 @@ impl<'me> Iterator for TryClaimCycleHeadsIter<'me> { WaitForResult::Cycle { .. } => { // We hit a cycle blocking on the cycle head; this means this query actively // participates in the cycle and some other query is blocked on this thread. 
- tracing::debug!("Waiting for {head:?} results in a cycle"); + crate::tracing::debug!("Waiting for {head:?} results in a cycle"); Some(TryClaimHeadsResult::Cycle) } WaitForResult::Running(running) => { - tracing::debug!("Ingredient {head:?} is running: {running:?}"); + crate::tracing::debug!("Ingredient {head:?} is running: {running:?}"); Some(TryClaimHeadsResult::Running(RunningCycleHead { inner: running, diff --git a/src/function/specify.rs b/src/function/specify.rs index 10a85b513..3bc71c565 100644 --- a/src/function/specify.rs +++ b/src/function/specify.rs @@ -83,7 +83,7 @@ where revisions, }; - tracing::debug!( + crate::tracing::debug!( "specify: about to add memo {:#?} for key {:?}", memo.tracing_debug(), key diff --git a/src/lib.rs b/src/lib.rs index cbc919377..7bc94eec4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -26,6 +26,7 @@ mod salsa_struct; mod storage; mod sync; mod table; +mod tracing; mod tracked_struct; mod update; mod views; diff --git a/src/runtime.rs b/src/runtime.rs index 8366e16f1..bc2859a7e 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -86,7 +86,7 @@ impl Running<'_> { }) }); - tracing::debug!( + crate::tracing::debug!( "block_on: thread {thread_id:?} is blocking on {database_key:?} in thread {other_id:?}", ); @@ -180,7 +180,7 @@ impl Runtime { } pub(crate) fn set_cancellation_flag(&self) { - tracing::trace!("set_cancellation_flag"); + crate::tracing::trace!("set_cancellation_flag"); self.revision_canceled.store(true, Ordering::Release); } @@ -206,7 +206,7 @@ impl Runtime { let r_old = self.current_revision(); let r_new = r_old.next(); self.revisions[0] = r_new; - tracing::debug!("new_revision: {r_old:?} -> {r_new:?}"); + crate::tracing::debug!("new_revision: {r_old:?} -> {r_new:?}"); r_new } @@ -236,7 +236,7 @@ impl Runtime { let dg = self.dependency_graph.lock(); if dg.depends_on(other_id, thread_id) { - tracing::debug!("block_on: cycle detected for {database_key:?} in thread {thread_id:?} on {other_id:?}"); + 
crate::tracing::debug!("block_on: cycle detected for {database_key:?} in thread {thread_id:?} on {other_id:?}"); return BlockResult::Cycle { same_thread: false }; } diff --git a/src/tracing.rs b/src/tracing.rs new file mode 100644 index 000000000..47f95d00e --- /dev/null +++ b/src/tracing.rs @@ -0,0 +1,54 @@ +//! Wrappers around `tracing` macros that avoid inlining debug machinery into the hot path, +//! as tracing events are typically only enabled for debugging purposes. + +macro_rules! trace { + ($($x:tt)*) => { + crate::tracing::event!(TRACE, $($x)*) + }; +} + +macro_rules! info { + ($($x:tt)*) => { + crate::tracing::event!(INFO, $($x)*) + }; +} + +macro_rules! debug { + ($($x:tt)*) => { + crate::tracing::event!(DEBUG, $($x)*) + }; +} + +macro_rules! debug_span { + ($($x:tt)*) => { + crate::tracing::span!(DEBUG, $($x)*) + }; +} + +macro_rules! event { + ($level:ident, $($x:tt)*) => {{ + let event = { + #[cold] #[inline(never)] || { ::tracing::event!(::tracing::Level::$level, $($x)*) } + }; + + if ::tracing::enabled!(::tracing::Level::$level) { + event(); + } + }}; +} + +macro_rules! span { + ($level:ident, $($x:tt)*) => {{ + let span = { + #[cold] #[inline(never)] || { ::tracing::span!(::tracing::Level::$level, $($x)*) } + }; + + if ::tracing::enabled!(::tracing::Level::$level) { + span() + } else { + ::tracing::Span::none() + } + }}; +} + +pub(crate) use {debug, debug_span, event, info, span, trace}; diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index ba8637558..977b214d1 100644 --- a/src/tracked_struct.rs +++ b/src/tracked_struct.rs @@ -397,7 +397,7 @@ where if let Some(id) = zalsa_local.tracked_struct_id(&identity) { // The struct already exists in the intern map. 
let index = self.database_key_index(id); - tracing::trace!("Reuse tracked struct {id:?}", id = index); + crate::tracing::trace!("Reuse tracked struct {id:?}", id = index); zalsa_local.add_output(index); // SAFETY: The `id` was present in the interned map, so the value must be initialized. @@ -423,7 +423,7 @@ where // in the struct map. let id = self.allocate(zalsa, zalsa_local, current_revision, ¤t_deps, fields); let key = self.database_key_index(id); - tracing::trace!("Allocated new tracked struct {key:?}"); + crate::tracing::trace!("Allocated new tracked struct {key:?}"); zalsa_local.add_output(key); zalsa_local.store_tracked_struct_id(identity, id); FromId::from_id(id) @@ -453,7 +453,7 @@ where // If the generation would overflow, we are forced to leak the slot. Note that this // shouldn't be a problem in general as sufficient bits are reserved for the generation. let Some(id) = id.next_generation() else { - tracing::info!( + crate::tracing::info!( "leaking tracked struct {:?} due to generation overflow", self.database_key_index(id) ); @@ -553,7 +553,7 @@ where // the unlikely case that the ID is already at its maximum generation, we are forced to leak // the previous slot and allocate a new value. 
if id.generation() == u32::MAX { - tracing::info!( + crate::tracing::info!( "leaking tracked struct {:?} due to generation overflow", self.database_key_index(id) ); diff --git a/src/zalsa.rs b/src/zalsa.rs index 751548972..41ece3cae 100644 --- a/src/zalsa.rs +++ b/src/zalsa.rs @@ -414,7 +414,7 @@ impl Zalsa { #[doc(hidden)] pub fn new_revision(&mut self) -> Revision { let new_revision = self.runtime.new_revision(); - let _span = tracing::debug_span!("new_revision", ?new_revision).entered(); + let _span = crate::tracing::debug_span!("new_revision", ?new_revision).entered(); for (_, index) in self.ingredients_requiring_reset.iter() { let index = index.as_u32() as usize; @@ -432,7 +432,7 @@ impl Zalsa { /// **NOT SEMVER STABLE** #[doc(hidden)] pub fn evict_lru(&mut self) { - let _span = tracing::debug_span!("evict_lru").entered(); + let _span = crate::tracing::debug_span!("evict_lru").entered(); for (_, index) in self.ingredients_requiring_reset.iter() { let index = index.as_u32() as usize; self.ingredients_vec @@ -449,10 +449,18 @@ impl Zalsa { #[inline(always)] pub fn event(&self, event: &dyn Fn() -> crate::Event) { - if let Some(event_callback) = &self.event_callback { - event_callback(event()); + if self.event_callback.is_some() { + self.event_cold(event); } } + + // Avoid inlining, as events are typically only enabled for debugging purposes. + #[cold] + #[inline(never)] + pub fn event_cold(&self, event: &dyn Fn() -> crate::Event) { + let event_callback = self.event_callback.as_ref().unwrap(); + event_callback(event()); + } } /// A type-erased `Jar`, used for ingredient registration. 
diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs index 9b89ab405..294ab0843 100644 --- a/src/zalsa_local.rs +++ b/src/zalsa_local.rs @@ -4,7 +4,6 @@ use std::ptr::{self, NonNull}; use rustc_hash::FxHashMap; use thin_vec::ThinVec; -use tracing::debug; use crate::accumulator::accumulated_map::{AccumulatedMap, AtomicInputAccumulatedValues}; use crate::active_query::QueryStack; @@ -197,10 +196,13 @@ impl ZalsaLocal { accumulated_inputs: &AtomicInputAccumulatedValues, cycle_heads: &CycleHeads, ) { - debug!( + crate::tracing::debug!( "report_tracked_read(input={:?}, durability={:?}, changed_at={:?})", - input, durability, changed_at + input, + durability, + changed_at ); + self.with_query_stack_mut(|stack| { if let Some(top_query) = stack.last_mut() { top_query.add_read( @@ -223,10 +225,13 @@ impl ZalsaLocal { durability: Durability, changed_at: Revision, ) { - debug!( + crate::tracing::debug!( "report_tracked_read(input={:?}, durability={:?}, changed_at={:?})", - input, durability, changed_at + input, + durability, + changed_at ); + self.with_query_stack_mut(|stack| { if let Some(top_query) = stack.last_mut() { top_query.add_read_simple(input, durability, changed_at); From 8b6d12b596833f6200dd239145047c787633c39a Mon Sep 17 00:00:00 2001 From: Ibraheem Ahmed Date: Sun, 27 Jul 2025 15:14:48 -0400 Subject: [PATCH 09/65] remove extra bounds checks from memo table hot-paths (#938) --- src/input.rs | 4 +- src/interned.rs | 4 +- src/table/memo.rs | 102 +++++++++++++++++++++++++----------------- src/tracked_struct.rs | 4 +- 4 files changed, 71 insertions(+), 43 deletions(-) diff --git a/src/input.rs b/src/input.rs index c13d23e21..fe25d9b91 100644 --- a/src/input.rs +++ b/src/input.rs @@ -116,7 +116,9 @@ impl IngredientImpl { fields, revisions, durabilities, - memos: MemoTable::new(self.memo_table_types()), + // SAFETY: We only ever access the memos of a value that we allocated through + // our `MemoTableTypes`. 
+ memos: unsafe { MemoTable::new(self.memo_table_types()) }, }) }); diff --git a/src/interned.rs b/src/interned.rs index 4350e77a6..812512ff6 100644 --- a/src/interned.rs +++ b/src/interned.rs @@ -586,7 +586,9 @@ where let id = zalsa_local.allocate(zalsa, self.ingredient_index, |id| Value:: { shard: shard_index as u16, link: LinkedListLink::new(), - memos: UnsafeCell::new(MemoTable::new(self.memo_table_types())), + // SAFETY: We only ever access the memos of a value that we allocated through + // our `MemoTableTypes`. + memos: UnsafeCell::new(unsafe { MemoTable::new(self.memo_table_types()) }), // SAFETY: We call `from_internal_data` to restore the correct lifetime before access. fields: UnsafeCell::new(unsafe { self.to_internal_data(assemble(id, key)) }), shared: UnsafeCell::new(ValueShared { diff --git a/src/table/memo.rs b/src/table/memo.rs index 3d5f9bc17..b7bc5fb7d 100644 --- a/src/table/memo.rs +++ b/src/table/memo.rs @@ -15,7 +15,14 @@ pub struct MemoTable { impl MemoTable { /// Create a `MemoTable` with slots for memos from the provided `MemoTableTypes`. - pub fn new(types: &MemoTableTypes) -> Self { + /// + /// # Safety + /// + /// The created memo table must only be accessed with the same `MemoTableTypes`. + pub unsafe fn new(types: &MemoTableTypes) -> Self { + // Note that the safety invariant guarantees that any indices in-bounds for + // this table are also in-bounds for its `MemoTableTypes`, as `MemoTableTypes` + // is append-only. Self { memos: (0..types.len()).map(|_| MemoEntry::default()).collect(), } @@ -179,46 +186,51 @@ impl MemoTableWithTypes<'_> { memo_ingredient_index: MemoIngredientIndex, memo: NonNull, ) -> Option> { - // The type must already exist, we insert it when creating the memo ingredient. 
- assert_eq!( + let MemoEntry { atomic_memo } = self.memos.memos.get(memo_ingredient_index.as_usize())?; + + // SAFETY: Any indices that are in-bounds for the `MemoTable` are also in-bounds for its + // corresponding `MemoTableTypes`, by construction. + let type_ = unsafe { self.types .types - .get(memo_ingredient_index.as_usize())? - .type_id, - TypeId::of::(), - "inconsistent type-id for `{memo_ingredient_index:?}`" - ); - - // The memo table is pre-sized on creation based on the corresponding `MemoTableTypes`. - let MemoEntry { atomic_memo } = self - .memos - .memos - .get(memo_ingredient_index.as_usize()) - .expect("accessed memo table with invalid index"); + .get_unchecked(memo_ingredient_index.as_usize()) + }; - let old_memo = atomic_memo.swap(MemoEntryType::to_dummy(memo).as_ptr(), Ordering::AcqRel); + // Verify that the we are casting to the correct type. + if type_.type_id != TypeId::of::() { + type_assert_failed(memo_ingredient_index); + } - let old_memo = NonNull::new(old_memo); + let old_memo = atomic_memo.swap(MemoEntryType::to_dummy(memo).as_ptr(), Ordering::AcqRel); - // SAFETY: `type_id` check asserted above - old_memo.map(|old_memo| unsafe { MemoEntryType::from_dummy(old_memo) }) + // SAFETY: We asserted that the type is correct above. + NonNull::new(old_memo).map(|old_memo| unsafe { MemoEntryType::from_dummy(old_memo) }) } + /// Returns a pointer to the memo at the given index, if one has been inserted. 
#[inline] pub(crate) fn get( self, memo_ingredient_index: MemoIngredientIndex, ) -> Option> { - let memo = self.memos.memos.get(memo_ingredient_index.as_usize())?; - let type_ = self.types.types.get(memo_ingredient_index.as_usize())?; - assert_eq!( - type_.type_id, - TypeId::of::(), - "inconsistent type-id for `{memo_ingredient_index:?}`" - ); - let memo = NonNull::new(memo.atomic_memo.load(Ordering::Acquire))?; - // SAFETY: `type_id` check asserted above - Some(unsafe { MemoEntryType::from_dummy(memo) }) + let MemoEntry { atomic_memo } = self.memos.memos.get(memo_ingredient_index.as_usize())?; + + // SAFETY: Any indices that are in-bounds for the `MemoTable` are also in-bounds for its + // corresponding `MemoTableTypes`, by construction. + let type_ = unsafe { + self.types + .types + .get_unchecked(memo_ingredient_index.as_usize()) + }; + + // Verify that the we are casting to the correct type. + if type_.type_id != TypeId::of::() { + type_assert_failed(memo_ingredient_index); + } + + NonNull::new(atomic_memo.load(Ordering::Acquire)) + // SAFETY: We asserted that the type is correct above. + .map(|memo| unsafe { MemoEntryType::from_dummy(memo) }) } #[cfg(feature = "salsa_unstable")] @@ -256,27 +268,30 @@ impl MemoTableWithTypesMut<'_> { memo_ingredient_index: MemoIngredientIndex, f: impl FnOnce(&mut M), ) { - let Some(type_) = self.types.types.get(memo_ingredient_index.as_usize()) else { - return; - }; - assert_eq!( - type_.type_id, - TypeId::of::(), - "inconsistent type-id for `{memo_ingredient_index:?}`" - ); - - // The memo table is pre-sized on creation based on the corresponding `MemoTableTypes`. let Some(MemoEntry { atomic_memo }) = self.memos.memos.get_mut(memo_ingredient_index.as_usize()) else { return; }; + // SAFETY: Any indices that are in-bounds for the `MemoTable` are also in-bounds for its + // corresponding `MemoTableTypes`, by construction. 
+ let type_ = unsafe { + self.types + .types + .get_unchecked(memo_ingredient_index.as_usize()) + }; + + // Verify that the we are casting to the correct type. + if type_.type_id != TypeId::of::() { + type_assert_failed(memo_ingredient_index); + } + let Some(memo) = NonNull::new(*atomic_memo.get_mut()) else { return; }; - // SAFETY: `type_id` check asserted above + // SAFETY: We asserted that the type is correct above. f(unsafe { MemoEntryType::from_dummy(memo).as_mut() }); } @@ -319,6 +334,13 @@ impl MemoTableWithTypesMut<'_> { } } +/// This function is explicitly outlined to avoid debug machinery in the hot-path. +#[cold] +#[inline(never)] +fn type_assert_failed(memo_ingredient_index: MemoIngredientIndex) -> ! { + panic!("inconsistent type-id for `{memo_ingredient_index:?}`") +} + impl MemoEntry { /// # Safety /// diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index 977b214d1..f6f4ea440 100644 --- a/src/tracked_struct.rs +++ b/src/tracked_struct.rs @@ -443,7 +443,9 @@ where // lifetime erase for storage fields: unsafe { mem::transmute::, C::Fields<'static>>(fields) }, revisions: C::new_revisions(current_deps.changed_at), - memos: MemoTable::new(self.memo_table_types()), + // SAFETY: We only ever access the memos of a value that we allocated through + // our `MemoTableTypes`. 
+ memos: unsafe { MemoTable::new(self.memo_table_types()) }, }; while let Some(id) = self.free_list.pop() { From f3dc2f30f9a250618161e35600a00de7fe744953 Mon Sep 17 00:00:00 2001 From: Lukas Wirth Date: Wed, 30 Jul 2025 12:04:08 +0200 Subject: [PATCH 10/65] Retain backing allocation of `ActiveQuery::input_outputs` in `ActiveQuery::seed_iteration` (#948) --- src/active_query.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/active_query.rs b/src/active_query.rs index cb563132d..7789fe6de 100644 --- a/src/active_query.rs +++ b/src/active_query.rs @@ -73,7 +73,7 @@ impl ActiveQuery { untracked_read: bool, ) { assert!(self.input_outputs.is_empty()); - self.input_outputs = edges.iter().cloned().collect(); + self.input_outputs.extend(edges.iter().cloned()); self.durability = self.durability.min(durability); self.changed_at = self.changed_at.max(changed_at); self.untracked_read |= untracked_read; From 211bc158dfe432138cfa77c6e0a51b6a2bc82884 Mon Sep 17 00:00:00 2001 From: Lukas Wirth Date: Thu, 31 Jul 2025 17:56:13 +0200 Subject: [PATCH 11/65] Do manual trait casting (#922) * Do manual trait upcasting instead of downcasting * Remove another dynamic `zalsa` call * Rename UpCaster back to DownCaster * Address reviews --- .../src/setup_input_struct.rs | 19 +-- .../src/setup_interned_struct.rs | 15 +- .../salsa-macro-rules/src/setup_tracked_fn.rs | 37 +++-- .../src/setup_tracked_struct.rs | 16 ++- components/salsa-macros/src/db.rs | 22 ++- src/accumulator.rs | 3 +- src/database.rs | 57 ++++---- src/function.rs | 18 +-- src/function/accumulated.rs | 9 +- src/function/execute.rs | 34 ++--- src/function/fetch.rs | 15 +- src/function/maybe_changed_after.rs | 6 +- src/function/memo.rs | 2 +- src/ingredient.rs | 26 ++-- src/input.rs | 28 ++-- src/input/input_field.rs | 6 +- src/interned.rs | 39 +++-- src/key.rs | 8 +- src/lib.rs | 2 +- src/parallel.rs | 77 ++++++++-- src/storage.rs | 5 +- src/tracked_struct.rs | 32 ++--- src/tracked_struct/tracked_field.rs | 
10 +- src/views.rs | 136 +++++++++++------- src/zalsa.rs | 3 +- tests/debug_db_contents.rs | 9 +- tests/interned-structs.rs | 4 +- tests/interned-structs_self_ref.rs | 9 +- 28 files changed, 371 insertions(+), 276 deletions(-) diff --git a/components/salsa-macro-rules/src/setup_input_struct.rs b/components/salsa-macro-rules/src/setup_input_struct.rs index fed0594a6..d6d131abf 100644 --- a/components/salsa-macro-rules/src/setup_input_struct.rs +++ b/components/salsa-macro-rules/src/setup_input_struct.rs @@ -118,8 +118,7 @@ macro_rules! setup_input_struct { } } - pub fn ingredient_mut(db: &mut dyn $zalsa::Database) -> (&mut $zalsa_struct::IngredientImpl, &mut $zalsa::Runtime) { - let zalsa_mut = db.zalsa_mut(); + pub fn ingredient_mut(zalsa_mut: &mut $zalsa::Zalsa) -> (&mut $zalsa_struct::IngredientImpl, &mut $zalsa::Runtime) { zalsa_mut.new_revision(); let index = zalsa_mut.lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>(); let (ingredient, runtime) = zalsa_mut.lookup_ingredient_mut(index); @@ -208,8 +207,10 @@ macro_rules! setup_input_struct { // FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database` $Db: ?Sized + $zalsa::Database, { - let fields = $Configuration::ingredient_(db.zalsa()).field( - db.as_dyn_database(), + let (zalsa, zalsa_local) = db.zalsas(); + let fields = $Configuration::ingredient_(zalsa).field( + zalsa, + zalsa_local, self, $field_index, ); @@ -228,7 +229,8 @@ macro_rules! setup_input_struct { // FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database` $Db: ?Sized + $zalsa::Database, { - let (ingredient, revision) = $Configuration::ingredient_mut(db.as_dyn_database_mut()); + let zalsa = db.zalsa_mut(); + let (ingredient, revision) = $Configuration::ingredient_mut(zalsa); $zalsa::input::SetterImpl::new( revision, self, @@ -267,7 +269,8 @@ macro_rules! 
setup_input_struct { $(for<'__trivial_bounds> $field_ty: std::fmt::Debug),* { $zalsa::with_attached_database(|db| { - let fields = $Configuration::ingredient(db).leak_fields(db, this); + let zalsa = db.zalsa(); + let fields = $Configuration::ingredient_(zalsa).leak_fields(zalsa, this); let mut f = f.debug_struct(stringify!($Struct)); let f = f.field("[salsa id]", &$zalsa::AsId::as_id(&this)); $( @@ -296,11 +299,11 @@ macro_rules! setup_input_struct { // FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database` $Db: ?Sized + salsa::Database { - let zalsa = db.zalsa(); + let (zalsa, zalsa_local) = db.zalsas(); let current_revision = zalsa.current_revision(); let ingredient = $Configuration::ingredient_(zalsa); let (fields, revision, durabilities) = builder::builder_into_inner(self, current_revision); - ingredient.new_input(db.as_dyn_database(), fields, revision, durabilities) + ingredient.new_input(zalsa, zalsa_local, fields, revision, durabilities) } } diff --git a/components/salsa-macro-rules/src/setup_interned_struct.rs b/components/salsa-macro-rules/src/setup_interned_struct.rs index 9b121fdc1..b637586e5 100644 --- a/components/salsa-macro-rules/src/setup_interned_struct.rs +++ b/components/salsa-macro-rules/src/setup_interned_struct.rs @@ -149,15 +149,11 @@ macro_rules! setup_interned_struct { } impl $Configuration { - pub fn ingredient(db: &Db) -> &$zalsa_struct::IngredientImpl - where - Db: ?Sized + $zalsa::Database, + pub fn ingredient(zalsa: &$zalsa::Zalsa) -> &$zalsa_struct::IngredientImpl { static CACHE: $zalsa::IngredientCache<$zalsa_struct::IngredientImpl<$Configuration>> = $zalsa::IngredientCache::new(); - let zalsa = db.zalsa(); - // SAFETY: `lookup_jar_by_type` returns a valid ingredient index, and the only // ingredient created by our jar is the struct ingredient. unsafe { @@ -239,7 +235,8 @@ macro_rules! 
setup_interned_struct { $field_ty: $zalsa::interned::HashEqLike<$indexed_ty>, )* { - $Configuration::ingredient(db).intern(db.as_dyn_database(), + let (zalsa, zalsa_local) = db.zalsas(); + $Configuration::ingredient(zalsa).intern(zalsa, zalsa_local, StructKey::<$db_lt>($($field_id,)* std::marker::PhantomData::default()), |_, data| ($($zalsa::interned::Lookup::into_owned(data.$field_index),)*)) } @@ -250,7 +247,8 @@ macro_rules! setup_interned_struct { // FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database` $Db: ?Sized + $zalsa::Database, { - let fields = $Configuration::ingredient(db).fields(db.as_dyn_database(), self); + let zalsa = db.zalsa(); + let fields = $Configuration::ingredient(zalsa).fields(zalsa, self); $zalsa::return_mode_expression!( $field_option, $field_ty, @@ -262,7 +260,8 @@ macro_rules! setup_interned_struct { /// Default debug formatting for this struct (may be useful if you define your own `Debug` impl) pub fn default_debug_fmt(this: Self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { $zalsa::with_attached_database(|db| { - let fields = $Configuration::ingredient(db).fields(db.as_dyn_database(), this); + let zalsa = db.zalsa(); + let fields = $Configuration::ingredient(zalsa).fields(zalsa, this); let mut f = f.debug_struct(stringify!($Struct)); $( let f = f.field(stringify!($field_id), &fields.$field_index); diff --git a/components/salsa-macro-rules/src/setup_tracked_fn.rs b/components/salsa-macro-rules/src/setup_tracked_fn.rs index a79007592..77325a484 100644 --- a/components/salsa-macro-rules/src/setup_tracked_fn.rs +++ b/components/salsa-macro-rules/src/setup_tracked_fn.rs @@ -175,17 +175,21 @@ macro_rules! 
setup_tracked_fn { impl $Configuration { fn fn_ingredient(db: &dyn $Db) -> &$zalsa::function::IngredientImpl<$Configuration> { let zalsa = db.zalsa(); + Self::fn_ingredient_(db, zalsa) + } + #[inline] + fn fn_ingredient_<'z>(db: &dyn $Db, zalsa: &'z $zalsa::Zalsa) -> &'z $zalsa::function::IngredientImpl<$Configuration> { // SAFETY: `lookup_jar_by_type` returns a valid ingredient index, and the first // ingredient created by our jar is the function ingredient. unsafe { $FN_CACHE.get_or_create(zalsa, || zalsa.lookup_jar_by_type::<$fn_name>()) } - .get_or_init(|| ::zalsa_register_downcaster(db)) + .get_or_init(|| *::zalsa_register_downcaster(db)) } pub fn fn_ingredient_mut(db: &mut dyn $Db) -> &mut $zalsa::function::IngredientImpl { - let view = ::zalsa_register_downcaster(db); + let view = *::zalsa_register_downcaster(db); let zalsa_mut = db.zalsa_mut(); let index = zalsa_mut.lookup_jar_by_type::<$fn_name>(); let (ingredient, _) = zalsa_mut.lookup_ingredient_mut(index); @@ -199,7 +203,12 @@ macro_rules! setup_tracked_fn { db: &dyn $Db, ) -> &$zalsa::interned::IngredientImpl<$Configuration> { let zalsa = db.zalsa(); - + Self::intern_ingredient_(zalsa) + } + #[inline] + fn intern_ingredient_<'z>( + zalsa: &'z $zalsa::Zalsa + ) -> &'z $zalsa::interned::IngredientImpl<$Configuration> { // SAFETY: `lookup_jar_by_type` returns a valid ingredient index, and the second // ingredient created by our jar is the interned ingredient (given `needs_interner`). unsafe { @@ -257,12 +266,12 @@ macro_rules! setup_tracked_fn { $($cycle_recovery_fn)*(db, value, count, $($input_id),*) } - fn id_to_input<$db_lt>(db: &$db_lt Self::DbView, key: salsa::Id) -> Self::Input<$db_lt> { + fn id_to_input<$db_lt>(zalsa: &$db_lt $zalsa::Zalsa, key: salsa::Id) -> Self::Input<$db_lt> { $zalsa::macro_if! 
{ if $needs_interner { - $Configuration::intern_ingredient(db).data(db.as_dyn_database(), key).clone() + $Configuration::intern_ingredient_(zalsa).data(zalsa, key).clone() } else { - $zalsa::FromIdWithDb::from_id(key, db.zalsa()) + $zalsa::FromIdWithDb::from_id(key, zalsa) } } } @@ -340,9 +349,10 @@ macro_rules! setup_tracked_fn { ) -> Vec<&$db_lt A> { use salsa::plumbing as $zalsa; let key = $zalsa::macro_if! { - if $needs_interner { - $Configuration::intern_ingredient($db).intern_id($db.as_dyn_database(), ($($input_id),*), |_, data| data) - } else { + if $needs_interner {{ + let (zalsa, zalsa_local) = $db.zalsas(); + $Configuration::intern_ingredient($db).intern_id(zalsa, zalsa_local, ($($input_id),*), |_, data| data) + }} else { $zalsa::AsId::as_id(&($($input_id),*)) } }; @@ -380,14 +390,17 @@ macro_rules! setup_tracked_fn { } $zalsa::attach($db, || { + let (zalsa, zalsa_local) = $db.zalsas(); let result = $zalsa::macro_if! { if $needs_interner { { - let key = $Configuration::intern_ingredient($db).intern_id($db.as_dyn_database(), ($($input_id),*), |_, data| data); - $Configuration::fn_ingredient($db).fetch($db, key) + let key = $Configuration::intern_ingredient_(zalsa).intern_id(zalsa, zalsa_local, ($($input_id),*), |_, data| data); + $Configuration::fn_ingredient_($db, zalsa).fetch($db, zalsa, zalsa_local, key) } } else { - $Configuration::fn_ingredient($db).fetch($db, $zalsa::AsId::as_id(&($($input_id),*))) + { + $Configuration::fn_ingredient_($db, zalsa).fetch($db, zalsa, zalsa_local, $zalsa::AsId::as_id(&($($input_id),*))) + } } }; diff --git a/components/salsa-macro-rules/src/setup_tracked_struct.rs b/components/salsa-macro-rules/src/setup_tracked_struct.rs index 5545a44fd..f92b1ac5f 100644 --- a/components/salsa-macro-rules/src/setup_tracked_struct.rs +++ b/components/salsa-macro-rules/src/setup_tracked_struct.rs @@ -282,8 +282,9 @@ macro_rules! 
setup_tracked_struct { // FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database` $Db: ?Sized + $zalsa::Database, { - $Configuration::ingredient(db.as_dyn_database()).new_struct( - db.as_dyn_database(), + let (zalsa, zalsa_local) = db.zalsas(); + $Configuration::ingredient_(zalsa).new_struct( + zalsa,zalsa_local, ($($field_id,)*) ) } @@ -295,8 +296,8 @@ macro_rules! setup_tracked_struct { // FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database` $Db: ?Sized + $zalsa::Database, { - let db = db.as_dyn_database(); - let fields = $Configuration::ingredient(db).tracked_field(db, self, $relative_tracked_index); + let (zalsa, zalsa_local) = db.zalsas(); + let fields = $Configuration::ingredient_(zalsa).tracked_field(zalsa, zalsa_local, self, $relative_tracked_index); $crate::return_mode_expression!( $tracked_option, $tracked_ty, @@ -312,8 +313,8 @@ macro_rules! setup_tracked_struct { // FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database` $Db: ?Sized + $zalsa::Database, { - let db = db.as_dyn_database(); - let fields = $Configuration::ingredient(db).untracked_field(db, self); + let zalsa = db.zalsa(); + let fields = $Configuration::ingredient_(zalsa).untracked_field(zalsa, self); $crate::return_mode_expression!( $untracked_option, $untracked_ty, @@ -335,7 +336,8 @@ macro_rules! 
setup_tracked_struct { $(for<$db_lt> $field_ty: std::fmt::Debug),* { $zalsa::with_attached_database(|db| { - let fields = $Configuration::ingredient(db).leak_fields(db, this); + let zalsa = db.zalsa(); + let fields = $Configuration::ingredient_(zalsa).leak_fields(zalsa, this); let mut f = f.debug_struct(stringify!($Struct)); let f = f.field("[salsa id]", &$zalsa::AsId::as_id(&this)); $( diff --git a/components/salsa-macros/src/db.rs b/components/salsa-macros/src/db.rs index 12ee48917..2c49604ad 100644 --- a/components/salsa-macros/src/db.rs +++ b/components/salsa-macros/src/db.rs @@ -110,18 +110,14 @@ impl DbMacro { let trait_name = &input.ident; input.items.push(parse_quote! { #[doc(hidden)] - fn zalsa_register_downcaster(&self) -> salsa::plumbing::DatabaseDownCaster; + fn zalsa_register_downcaster(&self) -> &salsa::plumbing::DatabaseDownCaster; }); - let comment = format!(" Downcast a [`dyn Database`] to a [`dyn {trait_name}`]"); + let comment = format!(" downcast `Self` to a [`dyn {trait_name}`]"); input.items.push(parse_quote! { #[doc = #comment] - /// - /// # Safety - /// - /// The input database must be of type `Self`. #[doc(hidden)] - unsafe fn downcast(db: &dyn salsa::plumbing::Database) -> &dyn #trait_name where Self: Sized; + fn downcast(&self) -> &dyn #trait_name where Self: Sized; }); Ok(()) } @@ -138,17 +134,17 @@ impl DbMacro { #[cold] #[inline(never)] #[doc(hidden)] - fn zalsa_register_downcaster(&self) -> salsa::plumbing::DatabaseDownCaster { - salsa::plumbing::views(self).add(::downcast) + fn zalsa_register_downcaster(&self) -> &salsa::plumbing::DatabaseDownCaster { + salsa::plumbing::views(self).add::(unsafe { + ::std::mem::transmute(::downcast as fn(_) -> _) + }) } }); input.items.push(parse_quote! { #[doc(hidden)] #[inline(always)] - unsafe fn downcast(db: &dyn salsa::plumbing::Database) -> &dyn #TraitPath where Self: Sized { - debug_assert_eq!(db.type_id(), ::core::any::TypeId::of::()); - // SAFETY: The input database must be of type `Self`. 
- unsafe { &*salsa::plumbing::transmute_data_ptr::(db) } + fn downcast(&self) -> &dyn #TraitPath where Self: Sized { + self } }); Ok(()) diff --git a/src/accumulator.rs b/src/accumulator.rs index 3b1358c60..4bd1280a7 100644 --- a/src/accumulator.rs +++ b/src/accumulator.rs @@ -102,7 +102,8 @@ impl Ingredient for IngredientImpl { unsafe fn maybe_changed_after( &self, - _db: &dyn Database, + _zalsa: &crate::zalsa::Zalsa, + _db: crate::database::RawDatabase<'_>, _input: Id, _revision: Revision, _cycle_heads: &mut CycleHeads, diff --git a/src/database.rs b/src/database.rs index b840398ff..30178b2da 100644 --- a/src/database.rs +++ b/src/database.rs @@ -1,13 +1,39 @@ -use std::any::Any; use std::borrow::Cow; +use std::ptr::NonNull; use crate::views::DatabaseDownCaster; use crate::zalsa::{IngredientIndex, ZalsaDatabase}; use crate::{Durability, Revision}; +#[derive(Copy, Clone)] +pub struct RawDatabase<'db> { + pub(crate) ptr: NonNull<()>, + _marker: std::marker::PhantomData<&'db dyn Database>, +} + +impl<'db, Db: Database + ?Sized> From<&'db Db> for RawDatabase<'db> { + #[inline] + fn from(db: &'db Db) -> Self { + RawDatabase { + ptr: NonNull::from(db).cast(), + _marker: std::marker::PhantomData, + } + } +} + +impl<'db, Db: Database + ?Sized> From<&'db mut Db> for RawDatabase<'db> { + #[inline] + fn from(db: &'db mut Db) -> Self { + RawDatabase { + ptr: NonNull::from(db).cast(), + _marker: std::marker::PhantomData, + } + } +} + /// The trait implemented by all Salsa databases. /// You can create your own subtraits of this trait using the `#[salsa::db]`(`crate::db`) procedural macro. -pub trait Database: Send + AsDynDatabase + Any + ZalsaDatabase { +pub trait Database: Send + ZalsaDatabase + AsDynDatabase { /// Enforces current LRU limits, evicting entries if necessary. 
/// /// **WARNING:** Just like an ordinary write, this method triggers @@ -84,28 +110,27 @@ pub trait Database: Send + AsDynDatabase + Any + ZalsaDatabase { #[cold] #[inline(never)] #[doc(hidden)] - fn zalsa_register_downcaster(&self) -> DatabaseDownCaster { + fn zalsa_register_downcaster(&self) -> &DatabaseDownCaster { self.zalsa().views().downcaster_for::() // The no-op downcaster is special cased in view caster construction. } #[doc(hidden)] #[inline(always)] - unsafe fn downcast(db: &dyn Database) -> &dyn Database + fn downcast(&self) -> &dyn Database where Self: Sized, { // No-op - db + self } } /// Upcast to a `dyn Database`. /// -/// Only required because upcasts not yet stabilized (*grr*). +/// Only required because upcasting does not work for unsized generic parameters. pub trait AsDynDatabase { fn as_dyn_database(&self) -> &dyn Database; - fn as_dyn_database_mut(&mut self) -> &mut dyn Database; } impl AsDynDatabase for T { @@ -113,30 +138,12 @@ impl AsDynDatabase for T { fn as_dyn_database(&self) -> &dyn Database { self } - - #[inline(always)] - fn as_dyn_database_mut(&mut self) -> &mut dyn Database { - self - } } pub fn current_revision(db: &Db) -> Revision { db.zalsa().current_revision() } -impl dyn Database { - /// Upcasts `self` to the given view. - /// - /// # Panics - /// - /// If the view has not been added to the database (see [`crate::views::Views`]). 
- #[track_caller] - pub fn as_view(&self) -> &DbView { - let views = self.zalsa().views(); - views.downcaster_for().downcast(self) - } -} - #[cfg(feature = "salsa_unstable")] pub use memory_usage::IngredientInfo; diff --git a/src/function.rs b/src/function.rs index ceb006feb..7642d4bab 100644 --- a/src/function.rs +++ b/src/function.rs @@ -10,6 +10,7 @@ use crate::accumulator::accumulated_map::{AccumulatedMap, InputAccumulatedValues use crate::cycle::{ empty_cycle_heads, CycleHeads, CycleRecoveryAction, CycleRecoveryStrategy, ProvisionalStatus, }; +use crate::database::RawDatabase; use crate::function::delete::DeletedEntries; use crate::function::sync::{ClaimResult, SyncTable}; use crate::ingredient::{Ingredient, WaitForResult}; @@ -22,7 +23,7 @@ use crate::table::Table; use crate::views::DatabaseDownCaster; use crate::zalsa::{IngredientIndex, MemoIngredientIndex, Zalsa}; use crate::zalsa_local::QueryOriginRef; -use crate::{Database, Id, Revision}; +use crate::{Id, Revision}; mod accumulated; mod backdate; @@ -68,10 +69,9 @@ pub trait Configuration: Any { /// This invokes user code in form of the `Eq` impl. fn values_equal<'db>(old_value: &Self::Output<'db>, new_value: &Self::Output<'db>) -> bool; - // FIXME: This should take a `&Zalsa` /// Convert from the id used internally to the value that execute is expecting. /// This is a no-op if the input to the function is a salsa struct. - fn id_to_input(db: &Self::DbView, key: Id) -> Self::Input<'_>; + fn id_to_input(zalsa: &Zalsa, key: Id) -> Self::Input<'_>; /// Returns the size of any heap allocations in the output value, in bytes. fn heap_size(_value: &Self::Output<'_>) -> usize { @@ -124,7 +124,7 @@ pub struct IngredientImpl { /// Used to find memos to throw out when we have too many memoized values. lru: lru::Lru, - /// A downcaster from `dyn Database` to `C::DbView`. + /// An downcaster to `C::DbView`. 
/// /// # Safety /// @@ -261,7 +261,8 @@ where unsafe fn maybe_changed_after( &self, - db: &dyn Database, + _zalsa: &crate::zalsa::Zalsa, + db: RawDatabase<'_>, input: Id, revision: Revision, cycle_heads: &mut CycleHeads, @@ -370,12 +371,13 @@ where C::CYCLE_STRATEGY } - fn accumulated<'db>( + unsafe fn accumulated<'db>( &'db self, - db: &'db dyn Database, + db: RawDatabase<'db>, key_index: Id, ) -> (Option<&'db AccumulatedMap>, InputAccumulatedValues) { - let db = self.view_caster().downcast(db); + // SAFETY: The `db` belongs to the ingredient as per caller invariant + let db = unsafe { self.view_caster().downcast_unchecked(db) }; self.accumulated_map(db, key_index) } } diff --git a/src/function/accumulated.rs b/src/function/accumulated.rs index 47fe09a84..a65804e64 100644 --- a/src/function/accumulated.rs +++ b/src/function/accumulated.rs @@ -4,7 +4,7 @@ use crate::function::{Configuration, IngredientImpl}; use crate::hash::FxHashSet; use crate::zalsa::ZalsaDatabase; use crate::zalsa_local::QueryOriginRef; -use crate::{AsDynDatabase, DatabaseKeyIndex, Id}; +use crate::{DatabaseKeyIndex, Id}; impl IngredientImpl where @@ -37,9 +37,8 @@ where let mut output = vec![]; // First ensure the result is up to date - self.fetch(db, key); + self.fetch(db, zalsa, zalsa_local, key); - let db = db.as_dyn_database(); let db_key = self.database_key_index(key); let mut visited: FxHashSet = FxHashSet::default(); let mut stack: Vec = vec![db_key]; @@ -54,7 +53,9 @@ where let ingredient = zalsa.lookup_ingredient(k.ingredient_index()); // Extend `output` with any values accumulated by `k`. 
- let (accumulated_map, input) = ingredient.accumulated(db, k.key_index()); + // SAFETY: `db` owns the `ingredient` + let (accumulated_map, input) = + unsafe { ingredient.accumulated(db.into(), k.key_index()) }; if let Some(accumulated_map) = accumulated_map { accumulated_map.extend_with_accumulated(accumulator.index(), &mut output); } diff --git a/src/function/execute.rs b/src/function/execute.rs index 9013ee7fe..2690d1a5c 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -4,7 +4,7 @@ use crate::function::{Configuration, IngredientImpl}; use crate::sync::atomic::{AtomicBool, Ordering}; use crate::zalsa::{MemoIngredientIndex, Zalsa, ZalsaDatabase}; use crate::zalsa_local::{ActiveQueryGuard, QueryRevisions}; -use crate::{Event, EventKind, Id, Revision}; +use crate::{Event, EventKind, Id}; impl IngredientImpl where @@ -41,16 +41,11 @@ where let (new_value, mut revisions) = match C::CYCLE_STRATEGY { CycleRecoveryStrategy::Panic => { - Self::execute_query(db, active_query, opt_old_memo, zalsa.current_revision(), id) + Self::execute_query(db, zalsa, active_query, opt_old_memo, id) } CycleRecoveryStrategy::FallbackImmediate => { - let (mut new_value, mut revisions) = Self::execute_query( - db, - active_query, - opt_old_memo, - zalsa.current_revision(), - id, - ); + let (mut new_value, mut revisions) = + Self::execute_query(db, zalsa, active_query, opt_old_memo, id); if let Some(cycle_heads) = revisions.cycle_heads_mut() { // Did the new result we got depend on our own provisional value, in a cycle? @@ -77,7 +72,7 @@ where let active_query = db .zalsa_local() .push_query(database_key_index, IterationCount::initial()); - new_value = C::cycle_initial(db, C::id_to_input(db, id)); + new_value = C::cycle_initial(db, C::id_to_input(zalsa, id)); revisions = active_query.pop(); // We need to set `cycle_heads` and `verified_final` because it needs to propagate to the callers. // When verifying this, we will see we have fallback and mark ourselves verified. 
@@ -136,13 +131,8 @@ where let mut opt_last_provisional: Option<&Memo<'db, C>> = None; loop { let previous_memo = opt_last_provisional.or(opt_old_memo); - let (mut new_value, mut revisions) = Self::execute_query( - db, - active_query, - previous_memo, - zalsa.current_revision(), - id, - ); + let (mut new_value, mut revisions) = + Self::execute_query(db, zalsa, active_query, previous_memo, id); // Did the new result we got depend on our own provisional value, in a cycle? if let Some(cycle_heads) = revisions @@ -192,7 +182,7 @@ where db, &new_value, iteration_count.as_u32(), - C::id_to_input(db, id), + C::id_to_input(zalsa, id), ) { crate::CycleRecoveryAction::Iterate => {} crate::CycleRecoveryAction::Fallback(fallback_value) => { @@ -258,9 +248,9 @@ where #[inline] fn execute_query<'db>( db: &'db C::DbView, + zalsa: &'db Zalsa, active_query: ActiveQueryGuard<'db>, opt_old_memo: Option<&Memo<'db, C>>, - current_revision: Revision, id: Id, ) -> (C::Output<'db>, QueryRevisions) { if let Some(old_memo) = opt_old_memo { @@ -275,14 +265,16 @@ where // * ensure that tracked struct created during the previous iteration // (and are owned by the query) are alive even if the query in this iteration no longer creates them. // * ensure the final returned memo depends on all inputs from all iterations. - if old_memo.may_be_provisional() && old_memo.verified_at.load() == current_revision { + if old_memo.may_be_provisional() + && old_memo.verified_at.load() == zalsa.current_revision() + { active_query.seed_iteration(&old_memo.revisions); } } // Query was not previously executed, or value is potentially // stale, or value is absent. Let's execute! 
- let new_value = C::execute(db, C::id_to_input(db, id)); + let new_value = C::execute(db, C::id_to_input(zalsa, id)); (new_value, active_query.pop()) } diff --git a/src/function/fetch.rs b/src/function/fetch.rs index 252bba124..d6de9d9cb 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -2,7 +2,7 @@ use crate::cycle::{CycleHeads, CycleRecoveryStrategy, IterationCount}; use crate::function::memo::Memo; use crate::function::sync::ClaimResult; use crate::function::{Configuration, IngredientImpl, VerifyResult}; -use crate::zalsa::{MemoIngredientIndex, Zalsa, ZalsaDatabase}; +use crate::zalsa::{MemoIngredientIndex, Zalsa}; use crate::zalsa_local::{QueryRevisions, ZalsaLocal}; use crate::Id; @@ -10,8 +10,13 @@ impl IngredientImpl where C: Configuration, { - pub fn fetch<'db>(&'db self, db: &'db C::DbView, id: Id) -> &'db C::Output<'db> { - let (zalsa, zalsa_local) = db.zalsas(); + pub fn fetch<'db>( + &'db self, + db: &'db C::DbView, + zalsa: &'db Zalsa, + zalsa_local: &'db ZalsaLocal, + id: Id, + ) -> &'db C::Output<'db> { zalsa.unwind_if_revision_cancelled(zalsa_local); let database_key_index = self.database_key_index(id); @@ -175,7 +180,7 @@ where inserting and returning fixpoint initial value" ); let revisions = QueryRevisions::fixpoint_initial(database_key_index); - let initial_value = C::cycle_initial(db, C::id_to_input(db, id)); + let initial_value = C::cycle_initial(db, C::id_to_input(zalsa, id)); Some(self.insert_memo( zalsa, id, @@ -189,7 +194,7 @@ where ); let active_query = zalsa_local.push_query(database_key_index, IterationCount::initial()); - let fallback_value = C::cycle_initial(db, C::id_to_input(db, id)); + let fallback_value = C::cycle_initial(db, C::id_to_input(zalsa, id)); let mut revisions = active_query.pop(); revisions.set_cycle_heads(CycleHeads::initial(database_key_index)); // We need this for `cycle_heads()` to work. We will unset this in the outer `execute()`. 
diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index 7df0a41fe..20e82d1fa 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -7,7 +7,7 @@ use crate::key::DatabaseKeyIndex; use crate::sync::atomic::Ordering; use crate::zalsa::{MemoIngredientIndex, Zalsa, ZalsaDatabase}; use crate::zalsa_local::{QueryEdgeKind, QueryOriginRef, ZalsaLocal}; -use crate::{AsDynDatabase as _, Id, Revision}; +use crate::{Id, Revision}; /// Result of memo validation. pub enum VerifyResult { @@ -434,8 +434,6 @@ where return VerifyResult::Changed; } - let dyn_db = db.as_dyn_database(); - let mut inputs = InputAccumulatedValues::Empty; // Fully tracked inputs? Iterate over the inputs and check them, one by one. // @@ -447,7 +445,7 @@ where match edge.kind() { QueryEdgeKind::Input(dependency_index) => { match dependency_index.maybe_changed_after( - dyn_db, + db.into(), zalsa, old_memo.verified_at.load(), cycle_heads, diff --git a/src/function/memo.rs b/src/function/memo.rs index d6a872b69..a478b1d46 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -497,7 +497,7 @@ mod _memory_usage { unimplemented!() } - fn id_to_input(_: &Self::DbView, _: Id) -> Self::Input<'_> { + fn id_to_input(_: &Zalsa, _: Id) -> Self::Input<'_> { unimplemented!() } diff --git a/src/ingredient.rs b/src/ingredient.rs index fb567948b..12b8ebcba 100644 --- a/src/ingredient.rs +++ b/src/ingredient.rs @@ -5,6 +5,7 @@ use crate::accumulator::accumulated_map::{AccumulatedMap, InputAccumulatedValues use crate::cycle::{ empty_cycle_heads, CycleHeads, CycleRecoveryStrategy, IterationCount, ProvisionalStatus, }; +use crate::database::RawDatabase; use crate::function::VerifyResult; use crate::runtime::Running; use crate::sync::Arc; @@ -12,7 +13,7 @@ use crate::table::memo::MemoTableTypes; use crate::table::Table; use crate::zalsa::{transmute_data_mut_ptr, transmute_data_ptr, IngredientIndex, Zalsa}; use crate::zalsa_local::QueryOriginRef; 
-use crate::{Database, DatabaseKeyIndex, Id, Revision}; +use crate::{DatabaseKeyIndex, Id, Revision}; /// A "jar" is a group of ingredients that are added atomically. /// @@ -45,9 +46,10 @@ pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { /// # Safety /// /// The passed in database needs to be the same one that the ingredient was created with. - unsafe fn maybe_changed_after<'db>( - &'db self, - db: &'db dyn Database, + unsafe fn maybe_changed_after( + &self, + zalsa: &crate::zalsa::Zalsa, + db: crate::database::RawDatabase<'_>, input: Id, revision: Revision, cycle_heads: &mut CycleHeads, @@ -159,9 +161,13 @@ pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { /// What values were accumulated during the creation of the value at `key_index` /// (if any). - fn accumulated<'db>( + /// + /// # Safety + /// + /// The passed in database needs to be the same one that the ingredient was created with. + unsafe fn accumulated<'db>( &'db self, - db: &'db dyn Database, + db: RawDatabase<'db>, key_index: Id, ) -> (Option<&'db AccumulatedMap>, InputAccumulatedValues) { let _ = (db, key_index); @@ -171,7 +177,7 @@ pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { /// Returns memory usage information about any instances of the ingredient, /// if applicable. #[cfg(feature = "salsa_unstable")] - fn memory_usage(&self, _db: &dyn Database) -> Option> { + fn memory_usage(&self, _db: &dyn crate::Database) -> Option> { None } } @@ -179,7 +185,7 @@ pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { impl dyn Ingredient { /// Equivalent to the `downcast` method on `Any`. /// - /// Because we do not have dyn-upcasting support, we need this workaround. + /// Because we do not have dyn-downcasting support, we need this workaround. pub fn assert_type(&self) -> &T { assert_eq!( self.type_id(), @@ -195,7 +201,7 @@ impl dyn Ingredient { /// Equivalent to the `downcast` methods on `Any`. 
/// - /// Because we do not have dyn-upcasting support, we need this workaround. + /// Because we do not have dyn-downcasting support, we need this workaround. /// /// # Safety /// @@ -214,7 +220,7 @@ impl dyn Ingredient { /// Equivalent to the `downcast` method on `Any`. /// - /// Because we do not have dyn-upcasting support, we need this workaround. + /// Because we do not have dyn-downcasting support, we need this workaround. pub fn assert_type_mut(&mut self) -> &mut T { assert_eq!( Any::type_id(self), diff --git a/src/input.rs b/src/input.rs index fe25d9b91..af6648e73 100644 --- a/src/input.rs +++ b/src/input.rs @@ -19,7 +19,7 @@ use crate::sync::Arc; use crate::table::memo::{MemoTable, MemoTableTypes}; use crate::table::{Slot, Table}; use crate::zalsa::{IngredientIndex, Zalsa}; -use crate::{Database, Durability, Id, Revision, Runtime}; +use crate::{zalsa_local, Durability, Id, Revision, Runtime}; pub trait Configuration: Any { const DEBUG_NAME: &'static str; @@ -104,13 +104,12 @@ impl IngredientImpl { pub fn new_input( &self, - db: &dyn Database, + zalsa: &Zalsa, + zalsa_local: &zalsa_local::ZalsaLocal, fields: C::Fields, revisions: C::Revisions, durabilities: C::Durabilities, ) -> C::Struct { - let (zalsa, zalsa_local) = db.zalsas(); - let id = self.singleton.with_scope(|| { zalsa_local.allocate(zalsa, self.ingredient_index, |_| Value:: { fields, @@ -177,11 +176,11 @@ impl IngredientImpl { /// The caller is responsible for selecting the appropriate element. pub fn field<'db>( &'db self, - db: &'db dyn crate::Database, + zalsa: &'db Zalsa, + zalsa_local: &'db zalsa_local::ZalsaLocal, id: C::Struct, field_index: usize, ) -> &'db C::Fields { - let (zalsa, zalsa_local) = db.zalsas(); let field_ingredient_index = self.ingredient_index.successor(field_index); let id = id.as_id(); let value = Self::data(zalsa, id); @@ -197,17 +196,13 @@ impl IngredientImpl { #[cfg(feature = "salsa_unstable")] /// Returns all data corresponding to the input struct. 
- pub fn entries<'db>( - &'db self, - db: &'db dyn crate::Database, - ) -> impl Iterator> { - db.zalsa().table().slots_of::>() + pub fn entries<'db>(&'db self, zalsa: &'db Zalsa) -> impl Iterator> { + zalsa.table().slots_of::>() } /// Peek at the field values without recording any read dependency. /// Used for debug printouts. - pub fn leak_fields<'db>(&'db self, db: &'db dyn Database, id: C::Struct) -> &'db C::Fields { - let zalsa = db.zalsa(); + pub fn leak_fields<'db>(&'db self, zalsa: &'db Zalsa, id: C::Struct) -> &'db C::Fields { let id = id.as_id(); let value = Self::data(zalsa, id); &value.fields @@ -225,7 +220,8 @@ impl Ingredient for IngredientImpl { unsafe fn maybe_changed_after( &self, - _db: &dyn Database, + _zalsa: &crate::zalsa::Zalsa, + _db: crate::database::RawDatabase<'_>, _input: Id, _revision: Revision, _cycle_heads: &mut CycleHeads, @@ -249,9 +245,9 @@ impl Ingredient for IngredientImpl { /// Returns memory usage information about any inputs. #[cfg(feature = "salsa_unstable")] - fn memory_usage(&self, db: &dyn Database) -> Option> { + fn memory_usage(&self, db: &dyn crate::Database) -> Option> { let memory_usage = self - .entries(db) + .entries(db.zalsa()) // SAFETY: The memo table belongs to a value that we allocated, so it // has the correct type. .map(|value| unsafe { value.memory_usage(&self.memo_table_types) }) diff --git a/src/input/input_field.rs b/src/input/input_field.rs index f0e4856c8..82ed9889d 100644 --- a/src/input/input_field.rs +++ b/src/input/input_field.rs @@ -8,7 +8,7 @@ use crate::input::{Configuration, IngredientImpl, Value}; use crate::sync::Arc; use crate::table::memo::MemoTableTypes; use crate::zalsa::IngredientIndex; -use crate::{Database, Id, Revision}; +use crate::{Id, Revision}; /// Ingredient used to represent the fields of a `#[salsa::input]`. 
/// @@ -52,12 +52,12 @@ where unsafe fn maybe_changed_after( &self, - db: &dyn Database, + zalsa: &crate::zalsa::Zalsa, + _db: crate::database::RawDatabase<'_>, input: Id, revision: Revision, _cycle_heads: &mut CycleHeads, ) -> VerifyResult { - let zalsa = db.zalsa(); let value = >::data(zalsa, input); VerifyResult::changed_if(value.revisions[self.field_index] > revision) } diff --git a/src/interned.rs b/src/interned.rs index 812512ff6..e3aecd309 100644 --- a/src/interned.rs +++ b/src/interned.rs @@ -21,7 +21,7 @@ use crate::sync::{Arc, Mutex, OnceLock}; use crate::table::memo::{MemoTable, MemoTableTypes, MemoTableWithTypesMut}; use crate::table::Slot; use crate::zalsa::{IngredientIndex, Zalsa}; -use crate::{Database, DatabaseKeyIndex, Event, EventKind, Id, Revision}; +use crate::{DatabaseKeyIndex, Event, EventKind, Id, Revision}; /// Trait that defines the key properties of an interned struct. /// @@ -296,7 +296,8 @@ where /// the database ends up trying to intern or allocate a new value. pub fn intern<'db, Key>( &'db self, - db: &'db dyn crate::Database, + zalsa: &'db Zalsa, + zalsa_local: &'db ZalsaLocal, key: Key, assemble: impl FnOnce(Id, Key) -> C::Fields<'db>, ) -> C::Struct<'db> @@ -304,7 +305,7 @@ where Key: Hash, C::Fields<'db>: HashEqLike, { - FromId::from_id(self.intern_id(db, key, assemble)) + FromId::from_id(self.intern_id(zalsa, zalsa_local, key, assemble)) } /// Intern data to a unique reference. @@ -319,7 +320,8 @@ where /// the database ends up trying to intern or allocate a new value. pub fn intern_id<'db, Key>( &'db self, - db: &'db dyn crate::Database, + zalsa: &'db Zalsa, + zalsa_local: &'db ZalsaLocal, key: Key, assemble: impl FnOnce(Id, Key) -> C::Fields<'db>, ) -> crate::Id @@ -331,8 +333,6 @@ where // so instead we go with this and transmute the lifetime in the `eq` closure C::Fields<'db>: HashEqLike, { - let (zalsa, zalsa_local) = db.zalsas(); - // Record the current revision as active. 
let current_revision = zalsa.current_revision(); self.revision_queue.record(current_revision); @@ -735,8 +735,7 @@ where } /// Lookup the data for an interned value based on its ID. - pub fn data<'db>(&'db self, db: &'db dyn Database, id: Id) -> &'db C::Fields<'db> { - let zalsa = db.zalsa(); + pub fn data<'db>(&'db self, zalsa: &'db Zalsa, id: Id) -> &'db C::Fields<'db> { let value = zalsa.table().get::>(id); debug_assert!( @@ -761,12 +760,12 @@ where /// Lookup the fields from an interned struct. /// /// Note that this is not "leaking" since no dependency edge is required. - pub fn fields<'db>(&'db self, db: &'db dyn Database, s: C::Struct<'db>) -> &'db C::Fields<'db> { - self.data(db, AsId::as_id(&s)) + pub fn fields<'db>(&'db self, zalsa: &'db Zalsa, s: C::Struct<'db>) -> &'db C::Fields<'db> { + self.data(zalsa, AsId::as_id(&s)) } - pub fn reset(&mut self, db: &mut dyn Database) { - _ = db.zalsa_mut(); + pub fn reset(&mut self, zalsa_mut: &mut Zalsa) { + _ = zalsa_mut; for shard in self.shards.iter() { // We can clear the key maps now that we have cancelled all other handles. @@ -776,11 +775,8 @@ where #[cfg(feature = "salsa_unstable")] /// Returns all data corresponding to the interned struct. - pub fn entries<'db>( - &'db self, - db: &'db dyn crate::Database, - ) -> impl Iterator> { - db.zalsa().table().slots_of::>() + pub fn entries<'db>(&'db self, zalsa: &'db Zalsa) -> impl Iterator> { + zalsa.table().slots_of::>() } } @@ -798,13 +794,12 @@ where unsafe fn maybe_changed_after( &self, - db: &dyn Database, + zalsa: &crate::zalsa::Zalsa, + _db: crate::database::RawDatabase<'_>, input: Id, _revision: Revision, _cycle_heads: &mut CycleHeads, ) -> VerifyResult { - let zalsa = db.zalsa(); - // Record the current revision as active. let current_revision = zalsa.current_revision(); self.revision_queue.record(current_revision); @@ -852,7 +847,7 @@ where /// Returns memory usage information about any interned values. 
#[cfg(all(not(feature = "shuttle"), feature = "salsa_unstable"))] - fn memory_usage(&self, db: &dyn Database) -> Option> { + fn memory_usage(&self, db: &dyn crate::Database) -> Option> { use parking_lot::lock_api::RawMutex; for shard in self.shards.iter() { @@ -861,7 +856,7 @@ where } let memory_usage = self - .entries(db) + .entries(db.zalsa()) // SAFETY: The memo table belongs to a value that we allocated, so it // has the correct type. Additionally, we are holding the locks for all shards. .map(|value| unsafe { value.memory_usage(&self.memo_table_types) }) diff --git a/src/key.rs b/src/key.rs index 5883ef9cb..80904e978 100644 --- a/src/key.rs +++ b/src/key.rs @@ -3,7 +3,7 @@ use core::fmt; use crate::cycle::CycleHeads; use crate::function::VerifyResult; use crate::zalsa::{IngredientIndex, Zalsa}; -use crate::{Database, Id}; +use crate::Id; // ANCHOR: DatabaseKeyIndex /// An integer that uniquely identifies a particular query instance within the @@ -36,16 +36,18 @@ impl DatabaseKeyIndex { pub(crate) fn maybe_changed_after( &self, - db: &dyn Database, + db: crate::database::RawDatabase<'_>, zalsa: &Zalsa, last_verified_at: crate::Revision, cycle_heads: &mut CycleHeads, ) -> VerifyResult { // SAFETY: The `db` belongs to the ingredient unsafe { + // here, `db` has to be either the correct type already, or a subtype (as far as trait + // hierarchy is concerned) zalsa .lookup_ingredient(self.ingredient_index()) - .maybe_changed_after(db, self.key_index(), last_verified_at, cycle_heads) + .maybe_changed_after(zalsa, db, self.key_index(), last_verified_at, cycle_heads) } } diff --git a/src/lib.rs b/src/lib.rs index 7bc94eec4..2600d9a33 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -51,7 +51,7 @@ pub use self::accumulator::Accumulator; pub use self::active_query::Backtrace; pub use self::cancelled::Cancelled; pub use self::cycle::CycleRecoveryAction; -pub use self::database::{AsDynDatabase, Database}; +pub use self::database::Database; pub use 
self::database_impl::DatabaseImpl; pub use self::durability::Durability; pub use self::event::{Event, EventKind}; diff --git a/src/parallel.rs b/src/parallel.rs index 1d2504b77..8a0bde655 100644 --- a/src/parallel.rs +++ b/src/parallel.rs @@ -1,44 +1,91 @@ use rayon::iter::{FromParallelIterator, IntoParallelIterator, ParallelIterator}; -use crate::Database; +use crate::{database::RawDatabase, views::DatabaseDownCaster, Database}; pub fn par_map(db: &Db, inputs: impl IntoParallelIterator, op: F) -> C where - Db: Database + ?Sized, + Db: Database + ?Sized + Send, F: Fn(&Db, T) -> R + Sync + Send, T: Send, R: Send + Sync, C: FromParallelIterator, { + let views = db.zalsa().views(); + let caster = &views.downcaster_for::(); + let db_caster = &views.downcaster_for::(); inputs .into_par_iter() - .map_with(DbForkOnClone(db.fork_db()), |db, element| { - op(db.0.as_view(), element) - }) + .map_with( + DbForkOnClone(db.fork_db(), caster, db_caster), + |db, element| op(db.as_view(), element), + ) .collect() } -struct DbForkOnClone(Box); +struct DbForkOnClone<'views, Db: Database + ?Sized>( + RawDatabase<'static>, + &'views DatabaseDownCaster, + &'views DatabaseDownCaster, +); -impl Clone for DbForkOnClone { +// SAFETY: `T: Send` -> `&own T: Send`, `DbForkOnClone` is an owning pointer +unsafe impl Send for DbForkOnClone<'_, Db> {} + +impl DbForkOnClone<'_, Db> { + fn as_view(&self) -> &Db { + // SAFETY: The downcaster ensures that the pointer is valid for the lifetime of the view. 
+ unsafe { self.1.downcast_unchecked(self.0) } + } +} + +impl Drop for DbForkOnClone<'_, Db> { + fn drop(&mut self) { + // SAFETY: `caster` is derived from a `db` fitting for our database clone + let db = unsafe { self.1.downcast_mut_unchecked(self.0) }; + // SAFETY: `db` has been box allocated and leaked by `fork_db` + _ = unsafe { Box::from_raw(db) }; + } +} + +impl Clone for DbForkOnClone<'_, Db> { fn clone(&self) -> Self { - DbForkOnClone(self.0.fork_db()) + DbForkOnClone( + // SAFETY: `caster` is derived from a `db` fitting for our database clone + unsafe { self.2.downcast_unchecked(self.0) }.fork_db(), + self.1, + self.2, + ) } } -pub fn join(db: &Db, a: A, b: B) -> (RA, RB) +pub fn join(db: &Db, a: A, b: B) -> (RA, RB) where A: FnOnce(&Db) -> RA + Send, B: FnOnce(&Db) -> RB + Send, RA: Send, RB: Send, { + #[derive(Copy, Clone)] + struct AssertSend(T); + // SAFETY: We send owning pointers over, which are Send, given the `Db` type parameter above is Send + unsafe impl Send for AssertSend {} + + let caster = &db.zalsa().views().downcaster_for::(); // we need to fork eagerly, as `rayon::join_context` gives us no option to tell whether we get // moved to another thread before the closure is executed - let db_a = db.fork_db(); - let db_b = db.fork_db(); - rayon::join( - move || a(db_a.as_view::()), - move || b(db_b.as_view::()), - ) + let db_a = AssertSend(db.fork_db()); + let db_b = AssertSend(db.fork_db()); + let res = rayon::join( + // SAFETY: `caster` is derived from a `db` fitting for our database clone + move || a(unsafe { caster.downcast_unchecked({ db_a }.0) }), + // SAFETY: `caster` is derived from a `db` fitting for our database clone + move || b(unsafe { caster.downcast_unchecked({ db_b }.0) }), + ); + + // SAFETY: `db` has been box allocated and leaked by `fork_db` + // FIXME: Clean this mess up, RAII + _ = unsafe { Box::from_raw(caster.downcast_mut_unchecked(db_a.0)) }; + // SAFETY: `db` has been box allocated and leaked by `fork_db` + _ = unsafe { 
Box::from_raw(caster.downcast_mut_unchecked(db_b.0)) }; + res } diff --git a/src/storage.rs b/src/storage.rs index a8c2abec0..f63981e4f 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -2,6 +2,7 @@ use std::marker::PhantomData; use std::panic::RefUnwindSafe; +use crate::database::RawDatabase; use crate::sync::{Arc, Condvar, Mutex}; use crate::zalsa::{ErasedJar, HasJar, Zalsa, ZalsaDatabase}; use crate::zalsa_local::{self, ZalsaLocal}; @@ -245,8 +246,8 @@ unsafe impl ZalsaDatabase for T { } #[inline(always)] - fn fork_db(&self) -> Box { - Box::new(self.clone()) + fn fork_db(&self) -> RawDatabase<'static> { + Box::leak(Box::new(self.clone())).into() } } diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index f6f4ea440..ec240ebcb 100644 --- a/src/tracked_struct.rs +++ b/src/tracked_struct.rs @@ -23,7 +23,7 @@ use crate::sync::Arc; use crate::table::memo::{MemoTable, MemoTableTypes, MemoTableWithTypesMut}; use crate::table::{Slot, Table}; use crate::zalsa::{IngredientIndex, Zalsa}; -use crate::{Database, Durability, Event, EventKind, Id, Revision}; +use crate::{Durability, Event, EventKind, Id, Revision}; pub mod tracked_field; @@ -375,11 +375,10 @@ where pub fn new_struct<'db>( &'db self, - db: &'db dyn Database, + zalsa: &'db Zalsa, + zalsa_local: &'db ZalsaLocal, mut fields: C::Fields<'db>, ) -> C::Struct<'db> { - let (zalsa, zalsa_local) = db.zalsas(); - let identity_hash = IdentityHash { ingredient_index: self.ingredient_index, hash: crate::hash::hash(&C::untracked_fields(&fields)), @@ -734,11 +733,11 @@ where /// Used for debugging. pub fn leak_fields<'db>( &'db self, - db: &'db dyn Database, + zalsa: &'db Zalsa, s: C::Struct<'db>, ) -> &'db C::Fields<'db> { let id = AsId::as_id(&s); - let data = Self::data(db.zalsa().table(), id); + let data = Self::data(zalsa.table(), id); data.fields() } @@ -748,11 +747,11 @@ where /// The caller is responsible for selecting the appropriate element. 
pub fn tracked_field<'db>( &'db self, - db: &'db dyn crate::Database, + zalsa: &'db Zalsa, + zalsa_local: &'db ZalsaLocal, s: C::Struct<'db>, relative_tracked_index: usize, ) -> &'db C::Fields<'db> { - let (zalsa, zalsa_local) = db.zalsas(); let id = AsId::as_id(&s); let field_ingredient_index = self.ingredient_index.successor(relative_tracked_index); let data = Self::data(zalsa.table(), id); @@ -776,10 +775,9 @@ where /// The caller is responsible for selecting the appropriate element. pub fn untracked_field<'db>( &'db self, - db: &'db dyn crate::Database, + zalsa: &'db Zalsa, s: C::Struct<'db>, ) -> &'db C::Fields<'db> { - let zalsa = db.zalsa(); let id = AsId::as_id(&s); let data = Self::data(zalsa.table(), id); @@ -794,11 +792,8 @@ where #[cfg(feature = "salsa_unstable")] /// Returns all data corresponding to the tracked struct. - pub fn entries<'db>( - &'db self, - db: &'db dyn crate::Database, - ) -> impl Iterator> { - db.zalsa().table().slots_of::>() + pub fn entries<'db>(&'db self, zalsa: &'db Zalsa) -> impl Iterator> { + zalsa.table().slots_of::>() } } @@ -816,7 +811,8 @@ where unsafe fn maybe_changed_after( &self, - _db: &dyn Database, + _zalsa: &crate::zalsa::Zalsa, + _db: crate::database::RawDatabase<'_>, _input: Id, _revision: Revision, _cycle_heads: &mut CycleHeads, @@ -863,9 +859,9 @@ where /// Returns memory usage information about any tracked structs. #[cfg(feature = "salsa_unstable")] - fn memory_usage(&self, db: &dyn Database) -> Option> { + fn memory_usage(&self, db: &dyn crate::Database) -> Option> { let memory_usage = self - .entries(db) + .entries(db.zalsa()) // SAFETY: The memo table belongs to a value that we allocated, so it // has the correct type. 
.map(|value| unsafe { value.memory_usage(&self.memo_table_types) }) diff --git a/src/tracked_struct/tracked_field.rs b/src/tracked_struct/tracked_field.rs index ad3e871e8..587e473fa 100644 --- a/src/tracked_struct/tracked_field.rs +++ b/src/tracked_struct/tracked_field.rs @@ -7,7 +7,7 @@ use crate::sync::Arc; use crate::table::memo::MemoTableTypes; use crate::tracked_struct::{Configuration, Value}; use crate::zalsa::IngredientIndex; -use crate::{Database, Id}; +use crate::Id; /// Created for each tracked struct. /// @@ -55,14 +55,14 @@ where self.ingredient_index } - unsafe fn maybe_changed_after<'db>( - &'db self, - db: &'db dyn Database, + unsafe fn maybe_changed_after( + &self, + zalsa: &crate::zalsa::Zalsa, + _db: crate::database::RawDatabase<'_>, input: Id, revision: crate::Revision, _cycle_heads: &mut CycleHeads, ) -> VerifyResult { - let zalsa = db.zalsa(); let data = >::data(zalsa.table(), input); let field_changed_at = data.revisions[self.field_index]; VerifyResult::changed_if(field_changed_at > revision) diff --git a/src/views.rs b/src/views.rs index 01a0a2de5..d449779c3 100644 --- a/src/views.rs +++ b/src/views.rs @@ -1,10 +1,15 @@ -use std::any::{Any, TypeId}; +use std::{ + any::{Any, TypeId}, + marker::PhantomData, + mem, + ptr::NonNull, +}; -use crate::Database; +use crate::{database::RawDatabase, Database}; /// A `Views` struct is associated with some specific database type /// (a `DatabaseImpl` for some existential `U`). It contains functions -/// to downcast from `dyn Database` to `dyn DbView` for various traits `DbView` via this specific +/// to downcast to `dyn DbView` for various traits `DbView` via this specific /// database type. /// None of these types are known at compilation time, they are all checked /// dynamically through `TypeId` magic. @@ -13,6 +18,7 @@ pub struct Views { view_casters: boxcar::Vec, } +#[derive(Copy, Clone)] struct ViewCaster { /// The id of the target type `dyn DbView` that we can cast to. 
target_type_id: TypeId, @@ -20,50 +26,69 @@ struct ViewCaster { /// The name of the target type `dyn DbView` that we can cast to. type_name: &'static str, - /// Type-erased function pointer that downcasts from `dyn Database` to `dyn DbView`. + /// Type-erased function pointer that downcasts to `dyn DbView`. cast: ErasedDatabaseDownCasterSig, } impl ViewCaster { - fn new(func: unsafe fn(&dyn Database) -> &DbView) -> ViewCaster { + fn new(func: DatabaseDownCasterSig) -> ViewCaster { ViewCaster { target_type_id: TypeId::of::(), type_name: std::any::type_name::(), // SAFETY: We are type erasing for storage, taking care of unerasing before we call // the function pointer. cast: unsafe { - std::mem::transmute::, ErasedDatabaseDownCasterSig>( - func, - ) + mem::transmute::, ErasedDatabaseDownCasterSig>(func) }, } } } -type ErasedDatabaseDownCasterSig = unsafe fn(&dyn Database) -> *const (); -type DatabaseDownCasterSig = unsafe fn(&dyn Database) -> &DbView; +type ErasedDatabaseDownCasterSig = unsafe fn(RawDatabase<'_>) -> NonNull<()>; +type DatabaseDownCasterSig = unsafe fn(RawDatabase<'_>) -> NonNull; -pub struct DatabaseDownCaster(TypeId, DatabaseDownCasterSig); +#[repr(transparent)] +pub struct DatabaseDownCaster(ViewCaster, PhantomData DbView>); -impl DatabaseDownCaster { - pub fn downcast<'db>(&self, db: &'db dyn Database) -> &'db DbView { - assert_eq!( - self.0, - db.type_id(), - "Database type does not match the expected type for this `Views` instance" - ); - // SAFETY: We've asserted that the database is correct. - unsafe { (self.1)(db) } +impl Copy for DatabaseDownCaster {} +impl Clone for DatabaseDownCaster { + fn clone(&self) -> Self { + *self } +} +impl DatabaseDownCaster { + /// Downcast `db` to `DbView`. + /// + /// # Safety + /// + /// The caller must ensure that `db` is of the correct type. + #[inline] + pub unsafe fn downcast_unchecked<'db>(&self, db: RawDatabase<'db>) -> &'db DbView { + // SAFETY: The caller must ensure that `db` is of the correct type. 
+ // The returned pointer is live for `'db` due to construction of the downcaster functions. + unsafe { (self.unerased_downcaster())(db).as_ref() } + } /// Downcast `db` to `DbView`. /// /// # Safety /// /// The caller must ensure that `db` is of the correct type. - pub unsafe fn downcast_unchecked<'db>(&self, db: &'db dyn Database) -> &'db DbView { + #[inline] + pub unsafe fn downcast_mut_unchecked<'db>(&self, db: RawDatabase<'db>) -> &'db mut DbView { // SAFETY: The caller must ensure that `db` is of the correct type. - unsafe { (self.1)(db) } + // The returned pointer is live for `'db` due to construction of the downcaster functions. + unsafe { (self.unerased_downcaster())(db).as_mut() } + } + + #[inline] + fn unerased_downcaster(&self) -> DatabaseDownCasterSig { + // SAFETY: The type-erased function pointer is guaranteed to be ABI compatible for `DbView` + unsafe { + mem::transmute::>( + self.0.cast, + ) + } } } @@ -71,58 +96,63 @@ impl Views { pub(crate) fn new() -> Self { let source_type_id = TypeId::of::(); let view_casters = boxcar::Vec::new(); - // special case the no-op transformation, that way we skip out on reconstructing the wide pointer - view_casters.push(ViewCaster::new::(|db| db)); + view_casters.push(ViewCaster::new::(|db| db.ptr.cast::())); Self { source_type_id, view_casters, } } - /// Add a new downcaster from `dyn Database` to `dyn DbView`. - pub fn add( + /// Add a new downcaster to `dyn DbView`. 
+ pub fn add( &self, - func: DatabaseDownCasterSig, - ) -> DatabaseDownCaster { - if let Some(view) = self.try_downcaster_for() { - return view; + func: fn(NonNull) -> NonNull, + ) -> &DatabaseDownCaster { + assert_eq!(self.source_type_id, TypeId::of::()); + let target_type_id = TypeId::of::(); + if let Some((_, caster)) = self + .view_casters + .iter() + .find(|(_, u)| u.target_type_id == target_type_id) + { + // SAFETY: The type-erased function pointer is guaranteed to be valid for `DbView` + return unsafe { &*(&raw const *caster).cast::>() }; } - self.view_casters.push(ViewCaster::new::(func)); - DatabaseDownCaster(self.source_type_id, func) + // SAFETY: We are type erasing the function pointer for storage, and we will unerase it + // before we call it. + let caster = unsafe { + mem::transmute::) -> NonNull, DatabaseDownCasterSig>( + func, + ) + }; + let caster = ViewCaster::new::(caster); + let idx = self.view_casters.push(caster); + // SAFETY: The type-erased function pointer is guaranteed to be valid for `DbView` + unsafe { &*(&raw const self.view_casters[idx]).cast::>() } } - /// Retrieve an downcaster function from `dyn Database` to `dyn DbView`. + /// Retrieve an downcaster function to `dyn DbView`. /// /// # Panics /// - /// If the underlying type of `db` is not the same as the database type this upcasts was created for. - pub fn downcaster_for(&self) -> DatabaseDownCaster { - self.try_downcaster_for().unwrap_or_else(|| { - panic!( - "No downcaster registered for type `{}` in `Views`", - std::any::type_name::(), - ) - }) - } - - /// Retrieve an downcaster function from `dyn Database` to `dyn DbView`, if it exists. - #[inline] - pub fn try_downcaster_for(&self) -> Option> { + /// If the underlying type of `db` is not the same as the database type this downcasts was created for. 
+ pub fn downcaster_for(&self) -> &DatabaseDownCaster { let view_type_id = TypeId::of::(); for (_, view) in self.view_casters.iter() { if view.target_type_id == view_type_id { // SAFETY: We are unerasing the type erased function pointer having made sure the - // `TypeId` matches. - return Some(DatabaseDownCaster(self.source_type_id, unsafe { - std::mem::transmute::>( - view.cast, - ) - })); + // TypeId matches. + return unsafe { + &*((view as *const ViewCaster).cast::>()) + }; } } - None + panic!( + "No downcaster registered for type `{}` in `Views`", + std::any::type_name::(), + ); } } diff --git a/src/zalsa.rs b/src/zalsa.rs index 41ece3cae..1cc6ba5f5 100644 --- a/src/zalsa.rs +++ b/src/zalsa.rs @@ -5,6 +5,7 @@ use std::panic::RefUnwindSafe; use hashbrown::HashMap; use rustc_hash::FxHashMap; +use crate::database::RawDatabase; use crate::hash::TypeIdHasher; use crate::ingredient::{Ingredient, Jar}; use crate::plumbing::SalsaStructInDb; @@ -52,7 +53,7 @@ pub unsafe trait ZalsaDatabase: Any { /// Clone the database. 
#[doc(hidden)] - fn fork_db(&self) -> Box; + fn fork_db(&self) -> RawDatabase<'static>; } pub fn views(db: &Db) -> &Views { diff --git a/tests/debug_db_contents.rs b/tests/debug_db_contents.rs index 6ab8b212e..a253d8869 100644 --- a/tests/debug_db_contents.rs +++ b/tests/debug_db_contents.rs @@ -22,14 +22,15 @@ fn tracked_fn(db: &dyn salsa::Database, input: InputStruct) -> TrackedStruct<'_> #[test] fn execute() { + use salsa::plumbing::ZalsaDatabase; let db = salsa::DatabaseImpl::new(); let _ = InternedStruct::new(&db, "Salsa".to_string()); let _ = InternedStruct::new(&db, "Salsa2".to_string()); // test interned structs - let interned = InternedStruct::ingredient(&db) - .entries(&db) + let interned = InternedStruct::ingredient(db.zalsa()) + .entries(db.zalsa()) .collect::>(); assert_eq!(interned.len(), 2); @@ -40,7 +41,7 @@ fn execute() { let input = InputStruct::new(&db, 22); let inputs = InputStruct::ingredient(&db) - .entries(&db) + .entries(db.zalsa()) .collect::>(); assert_eq!(inputs.len(), 1); @@ -50,7 +51,7 @@ fn execute() { let computed = tracked_fn(&db, input).field(&db); assert_eq!(computed, 44); let tracked = TrackedStruct::ingredient(&db) - .entries(&db) + .entries(db.zalsa()) .collect::>(); assert_eq!(tracked.len(), 1); diff --git a/tests/interned-structs.rs b/tests/interned-structs.rs index 931b1ab67..a9db074c4 100644 --- a/tests/interned-structs.rs +++ b/tests/interned-structs.rs @@ -132,13 +132,13 @@ fn interning_boxed() { #[test] fn interned_structs_have_public_ingredients() { - use salsa::plumbing::AsId; + use salsa::plumbing::{AsId, ZalsaDatabase}; let db = salsa::DatabaseImpl::new(); let s = InternedString::new(&db, String::from("Hello, world!")); let underlying_id = s.0; - let data = InternedString::ingredient(&db).data(&db, underlying_id.as_id()); + let data = InternedString::ingredient(db.zalsa()).data(db.zalsa(), underlying_id.as_id()); assert_eq!(data, &(String::from("Hello, world!"),)); } diff --git a/tests/interned-structs_self_ref.rs 
b/tests/interned-structs_self_ref.rs index 55eb8c06f..3443f3ac2 100644 --- a/tests/interned-structs_self_ref.rs +++ b/tests/interned-structs_self_ref.rs @@ -181,7 +181,8 @@ const _: () = { String: zalsa_::interned::HashEqLike, { Configuration_::ingredient(db).intern( - db.as_dyn_database(), + db.zalsa(), + db.zalsa_local(), StructKey::<'db>(data, std::marker::PhantomData::default()), |id, data| { StructData( @@ -195,20 +196,20 @@ const _: () = { where Db_: ?Sized + zalsa_::Database, { - let fields = Configuration_::ingredient(db).fields(db.as_dyn_database(), self); + let fields = Configuration_::ingredient(db).fields(db.zalsa(), self); std::clone::Clone::clone((&fields.0)) } fn other(self, db: &'db Db_) -> InternedString<'db> where Db_: ?Sized + zalsa_::Database, { - let fields = Configuration_::ingredient(db).fields(db.as_dyn_database(), self); + let fields = Configuration_::ingredient(db).fields(db.zalsa(), self); std::clone::Clone::clone((&fields.1)) } #[doc = r" Default debug formatting for this struct (may be useful if you define your own `Debug` impl)"] pub fn default_debug_fmt(this: Self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { zalsa_::with_attached_database(|db| { - let fields = Configuration_::ingredient(db).fields(db.as_dyn_database(), this); + let fields = Configuration_::ingredient(db).fields(db.zalsa(), this); let mut f = f.debug_struct("InternedString"); let f = f.field("data", &fields.0); let f = f.field("other", &fields.1); From 0ca4f4ff9a780764b7da6d23734183ba2d46fced Mon Sep 17 00:00:00 2001 From: Ibraheem Ahmed Date: Thu, 31 Jul 2025 11:58:37 -0400 Subject: [PATCH 12/65] remove borrow checks from `ZalsaLocal` (#939) --- src/function/fetch.rs | 14 +- src/function/maybe_changed_after.rs | 67 +++--- src/function/memo.rs | 17 +- src/zalsa_local.rs | 323 +++++++++++++++++----------- 4 files changed, 253 insertions(+), 168 deletions(-) diff --git a/src/function/fetch.rs b/src/function/fetch.rs index d6de9d9cb..f3f79bfac 100644 --- 
a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -25,6 +25,7 @@ where let _span = crate::tracing::debug_span!("fetch", query = ?database_key_index).entered(); let memo = self.refresh_memo(db, zalsa, zalsa_local, id); + // SAFETY: We just refreshed the memo so it is guaranteed to contain a value now. let memo_value = unsafe { memo.value.as_ref().unwrap_unchecked() }; @@ -167,13 +168,16 @@ where } // no provisional value; create/insert/return initial provisional value return match C::CYCLE_STRATEGY { - CycleRecoveryStrategy::Panic => zalsa_local.with_query_stack(|stack| { - panic!( - "dependency graph cycle when querying {database_key_index:#?}, \ + // SAFETY: We do not access the query stack reentrantly. + CycleRecoveryStrategy::Panic => unsafe { + zalsa_local.with_query_stack_unchecked(|stack| { + panic!( + "dependency graph cycle when querying {database_key_index:#?}, \ set cycle_fn/cycle_initial to fixpoint iterate.\n\ Query stack:\n{stack:#?}", - ); - }), + ); + }) + }, CycleRecoveryStrategy::Fixpoint => { crate::tracing::debug!( "hit cycle at {database_key_index:#?}, \ diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index 20e82d1fa..dcc17bb22 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -108,13 +108,16 @@ where return None; } ClaimResult::Cycle { .. } => match C::CYCLE_STRATEGY { - CycleRecoveryStrategy::Panic => db.zalsa_local().with_query_stack(|stack| { - panic!( - "dependency graph cycle when validating {database_key_index:#?}, \ + // SAFETY: We do not access the query stack reentrantly. 
+ CycleRecoveryStrategy::Panic => unsafe { + db.zalsa_local().with_query_stack_unchecked(|stack| { + panic!( + "dependency graph cycle when validating {database_key_index:#?}, \ set cycle_fn/cycle_initial to fixpoint iterate.\n\ Query stack:\n{stack:#?}", - ); - }), + ); + }) + }, CycleRecoveryStrategy::FallbackImmediate => { return Some(VerifyResult::unchanged()); } @@ -336,32 +339,38 @@ where return true; } - zalsa_local.with_query_stack(|stack| { - cycle_heads.iter().all(|cycle_head| { - stack - .iter() - .rev() - .find(|query| query.database_key_index == cycle_head.database_key_index) - .map(|query| query.iteration_count()) - .or_else(|| { - // If this is a cycle head is owned by another thread that is blocked by this ingredient, - // check if it has the same iteration count. - let ingredient = zalsa - .lookup_ingredient(cycle_head.database_key_index.ingredient_index()); - let wait_result = - ingredient.wait_for(zalsa, cycle_head.database_key_index.key_index()); - - if !wait_result.is_cycle_with_other_thread() { - return None; - } + // SAFETY: We do not access the query stack reentrantly. + unsafe { + zalsa_local.with_query_stack_unchecked(|stack| { + cycle_heads.iter().all(|cycle_head| { + stack + .iter() + .rev() + .find(|query| query.database_key_index == cycle_head.database_key_index) + .map(|query| query.iteration_count()) + .or_else(|| { + // If this is a cycle head is owned by another thread that is blocked by this ingredient, + // check if it has the same iteration count. 
+ let ingredient = zalsa.lookup_ingredient( + cycle_head.database_key_index.ingredient_index(), + ); + let wait_result = ingredient + .wait_for(zalsa, cycle_head.database_key_index.key_index()); + + if !wait_result.is_cycle_with_other_thread() { + return None; + } - let provisional_status = ingredient - .provisional_status(zalsa, cycle_head.database_key_index.key_index())?; - provisional_status.iteration() - }) - == Some(cycle_head.iteration_count) + let provisional_status = ingredient.provisional_status( + zalsa, + cycle_head.database_key_index.key_index(), + )?; + provisional_status.iteration() + }) + == Some(cycle_head.iteration_count) + }) }) - }) + } } /// VerifyResult::Unchanged if the memo's value and `changed_at` time is up-to-date in the diff --git a/src/function/memo.rs b/src/function/memo.rs index a478b1d46..810e5b268 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -236,14 +236,17 @@ impl<'db, C: Configuration> Memo<'db, C> { return true; } - zalsa_local.with_query_stack(|stack| { - cycle_heads.iter().all(|cycle_head| { - stack - .iter() - .rev() - .any(|query| query.database_key_index == cycle_head.database_key_index) + // SAFETY: We do not access the query stack reentrantly. + unsafe { + zalsa_local.with_query_stack_unchecked(|stack| { + cycle_heads.iter().all(|cycle_head| { + stack + .iter() + .rev() + .any(|query| query.database_key_index == cycle_head.database_key_index) + }) }) - }) + } } /// Cycle heads that should be propagated to dependent queries. diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs index 294ab0843..8bc49f22b 100644 --- a/src/zalsa_local.rs +++ b/src/zalsa_local.rs @@ -1,4 +1,4 @@ -use std::cell::RefCell; +use std::cell::{RefCell, UnsafeCell}; use std::panic::UnwindSafe; use std::ptr::{self, NonNull}; @@ -32,14 +32,14 @@ pub struct ZalsaLocal { /// Stores the most recent page for a given ingredient. /// This is thread-local to avoid contention. 
- most_recent_pages: RefCell>, + most_recent_pages: UnsafeCell>, } impl ZalsaLocal { pub(crate) fn new() -> Self { ZalsaLocal { query_stack: RefCell::new(QueryStack::default()), - most_recent_pages: RefCell::new(FxHashMap::default()), + most_recent_pages: UnsafeCell::new(FxHashMap::default()), } } @@ -65,16 +65,17 @@ impl ZalsaLocal { .memo_table_types() .clone() }; + + // SAFETY: `ZalsaLocal` is `!Sync`, and we never expose a reference to this field, + // so we have exclusive access. + let most_recent_pages = unsafe { &mut *self.most_recent_pages.get() }; + // Find the most recent page, pushing a page if needed - let mut page = *self - .most_recent_pages - .borrow_mut() - .entry(ingredient) - .or_insert_with(|| { - zalsa - .table() - .fetch_or_push_page::(ingredient, memo_types) - }); + let mut page = *most_recent_pages.entry(ingredient).or_insert_with(|| { + zalsa + .table() + .fetch_or_push_page::(ingredient, memo_types) + }); loop { // Try to allocate an entry on that page @@ -89,7 +90,7 @@ impl ZalsaLocal { Err(v) => { value = v; page = zalsa.table().push_page::(ingredient, memo_types()); - self.most_recent_pages.borrow_mut().insert(ingredient, page); + most_recent_pages.insert(ingredient, page); } } } @@ -101,50 +102,76 @@ impl ZalsaLocal { database_key_index: DatabaseKeyIndex, iteration_count: IterationCount, ) -> ActiveQueryGuard<'_> { - let mut query_stack = self.query_stack.borrow_mut(); - query_stack.push_new_query(database_key_index, iteration_count); - ActiveQueryGuard { - local_state: self, - database_key_index, - #[cfg(debug_assertions)] - push_len: query_stack.len(), + // SAFETY: We do not access the query stack reentrantly. 
+ unsafe { + self.with_query_stack_unchecked_mut(|stack| { + stack.push_new_query(database_key_index, iteration_count); + + ActiveQueryGuard { + local_state: self, + database_key_index, + #[cfg(debug_assertions)] + push_len: stack.len(), + } + }) } } /// Executes a closure within the context of the current active query stacks (mutable). + /// + /// # Safety + /// + /// The closure cannot access the query stack reentrantly, whether mutable or immutable. #[inline(always)] - pub(crate) fn with_query_stack_mut( + pub(crate) unsafe fn with_query_stack_unchecked_mut( &self, - c: impl UnwindSafe + FnOnce(&mut QueryStack) -> R, + f: impl UnwindSafe + FnOnce(&mut QueryStack) -> R, ) -> R { - c(&mut self.query_stack.borrow_mut()) + // SAFETY: The caller guarantees that the query stack will not be accessed reentrantly. + // Additionally, `ZalsaLocal` is `!Sync`, and we never expose a reference to the query + // stack except through this method, so we have exclusive access. + unsafe { f(&mut self.query_stack.try_borrow_mut().unwrap_unchecked()) } } + /// Executes a closure within the context of the current active query stacks. + /// + /// # Safety + /// + /// No mutable references to the query stack can exist while the closure is executed. #[inline(always)] - pub(crate) fn with_query_stack(&self, c: impl UnwindSafe + FnOnce(&QueryStack) -> R) -> R { - c(&mut self.query_stack.borrow()) + pub(crate) unsafe fn with_query_stack_unchecked( + &self, + f: impl UnwindSafe + FnOnce(&QueryStack) -> R, + ) -> R { + // SAFETY: The caller guarantees that the query stack will not being accessed mutably. + // Additionally, `ZalsaLocal` is `!Sync`, and we never expose a reference to the query + // stack except through this method, so we have exclusive access. 
+ unsafe { f(&self.query_stack.try_borrow().unwrap_unchecked()) } } #[inline(always)] pub(crate) fn try_with_query_stack( &self, - c: impl UnwindSafe + FnOnce(&QueryStack) -> R, + f: impl UnwindSafe + FnOnce(&QueryStack) -> R, ) -> Option { self.query_stack .try_borrow() .ok() .as_ref() - .map(|stack| c(stack)) + .map(|stack| f(stack)) } /// Returns the index of the active query along with its *current* durability/changed-at /// information. As the query continues to execute, naturally, that information may change. pub(crate) fn active_query(&self) -> Option<(DatabaseKeyIndex, Stamp)> { - self.with_query_stack(|stack| { - stack - .last() - .map(|active_query| (active_query.database_key_index, active_query.stamp())) - }) + // SAFETY: We do not access the query stack reentrantly. + unsafe { + self.with_query_stack_unchecked(|stack| { + stack + .last() + .map(|active_query| (active_query.database_key_index, active_query.stamp())) + }) + } } /// Add an output to the current query's list of dependencies @@ -155,34 +182,43 @@ impl ZalsaLocal { index: IngredientIndex, value: A, ) -> Result<(), ()> { - self.with_query_stack_mut(|stack| { - if let Some(top_query) = stack.last_mut() { - top_query.accumulate(index, value); - Ok(()) - } else { - Err(()) - } - }) + // SAFETY: We do not access the query stack reentrantly. + unsafe { + self.with_query_stack_unchecked_mut(|stack| { + if let Some(top_query) = stack.last_mut() { + top_query.accumulate(index, value); + Ok(()) + } else { + Err(()) + } + }) + } } /// Add an output to the current query's list of dependencies pub(crate) fn add_output(&self, entity: DatabaseKeyIndex) { - self.with_query_stack_mut(|stack| { - if let Some(top_query) = stack.last_mut() { - top_query.add_output(entity) - } - }) + // SAFETY: We do not access the query stack reentrantly. 
+ unsafe { + self.with_query_stack_unchecked_mut(|stack| { + if let Some(top_query) = stack.last_mut() { + top_query.add_output(entity) + } + }) + } } /// Check whether `entity` is an output of the currently active query (if any) pub(crate) fn is_output_of_active_query(&self, entity: DatabaseKeyIndex) -> bool { - self.with_query_stack_mut(|stack| { - if let Some(top_query) = stack.last_mut() { - top_query.is_output(entity) - } else { - false - } - }) + // SAFETY: We do not access the query stack reentrantly. + unsafe { + self.with_query_stack_unchecked_mut(|stack| { + if let Some(top_query) = stack.last_mut() { + top_query.is_output(entity) + } else { + false + } + }) + } } /// Register that currently active query reads the given input @@ -203,18 +239,21 @@ impl ZalsaLocal { changed_at ); - self.with_query_stack_mut(|stack| { - if let Some(top_query) = stack.last_mut() { - top_query.add_read( - input, - durability, - changed_at, - has_accumulated, - accumulated_inputs, - cycle_heads, - ); - } - }) + // SAFETY: We do not access the query stack reentrantly. + unsafe { + self.with_query_stack_unchecked_mut(|stack| { + if let Some(top_query) = stack.last_mut() { + top_query.add_read( + input, + durability, + changed_at, + has_accumulated, + accumulated_inputs, + cycle_heads, + ); + } + }) + } } /// Register that currently active query reads the given input @@ -232,11 +271,14 @@ impl ZalsaLocal { changed_at ); - self.with_query_stack_mut(|stack| { - if let Some(top_query) = stack.last_mut() { - top_query.add_read_simple(input, durability, changed_at); - } - }) + // SAFETY: We do not access the query stack reentrantly. 
+ unsafe { + self.with_query_stack_unchecked_mut(|stack| { + if let Some(top_query) = stack.last_mut() { + top_query.add_read_simple(input, durability, changed_at); + } + }) + } } /// Register that the current query read an untracked value @@ -246,11 +288,14 @@ impl ZalsaLocal { /// * `current_revision`, the current revision #[inline(always)] pub(crate) fn report_untracked_read(&self, current_revision: Revision) { - self.with_query_stack_mut(|stack| { - if let Some(top_query) = stack.last_mut() { - top_query.add_untracked_read(current_revision); - } - }) + // SAFETY: We do not access the query stack reentrantly. + unsafe { + self.with_query_stack_unchecked_mut(|stack| { + if let Some(top_query) = stack.last_mut() { + top_query.add_untracked_read(current_revision); + } + }) + } } /// Update the top query on the stack to act as though it read a value @@ -258,11 +303,14 @@ impl ZalsaLocal { // FIXME: Use or remove this. #[allow(dead_code)] pub(crate) fn report_synthetic_read(&self, durability: Durability, revision: Revision) { - self.with_query_stack_mut(|stack| { - if let Some(top_query) = stack.last_mut() { - top_query.add_synthetic_read(durability, revision); - } - }) + // SAFETY: We do not access the query stack reentrantly. + unsafe { + self.with_query_stack_unchecked_mut(|stack| { + if let Some(top_query) = stack.last_mut() { + top_query.add_synthetic_read(durability, revision); + } + }) + } } /// Called when the active queries creates an index from the @@ -277,33 +325,42 @@ impl ZalsaLocal { /// * the disambiguator index #[track_caller] pub(crate) fn disambiguate(&self, key: IdentityHash) -> (Stamp, Disambiguator) { - self.with_query_stack_mut(|stack| { - let top_query = stack.last_mut().expect( - "cannot create a tracked struct disambiguator outside of a tracked function", - ); - let disambiguator = top_query.disambiguate(key); - (top_query.stamp(), disambiguator) - }) + // SAFETY: We do not access the query stack reentrantly. 
+ unsafe { + self.with_query_stack_unchecked_mut(|stack| { + let top_query = stack.last_mut().expect( + "cannot create a tracked struct disambiguator outside of a tracked function", + ); + let disambiguator = top_query.disambiguate(key); + (top_query.stamp(), disambiguator) + }) + } } #[track_caller] pub(crate) fn tracked_struct_id(&self, identity: &Identity) -> Option { - self.with_query_stack(|stack| { - let top_query = stack - .last() - .expect("cannot create a tracked struct ID outside of a tracked function"); - top_query.tracked_struct_ids().get(identity) - }) + // SAFETY: We do not access the query stack reentrantly. + unsafe { + self.with_query_stack_unchecked(|stack| { + let top_query = stack + .last() + .expect("cannot create a tracked struct ID outside of a tracked function"); + top_query.tracked_struct_ids().get(identity) + }) + } } #[track_caller] pub(crate) fn store_tracked_struct_id(&self, identity: Identity, id: Id) { - self.with_query_stack_mut(|stack| { - let top_query = stack - .last_mut() - .expect("cannot store a tracked struct ID outside of a tracked function"); - top_query.tracked_struct_ids_mut().insert(identity, id); - }) + // SAFETY: We do not access the query stack reentrantly. + unsafe { + self.with_query_stack_unchecked_mut(|stack| { + let top_query = stack + .last_mut() + .expect("cannot store a tracked struct ID outside of a tracked function"); + top_query.tracked_struct_ids_mut().insert(identity, id); + }) + } } #[cold] @@ -922,15 +979,18 @@ pub(crate) struct ActiveQueryGuard<'me> { impl ActiveQueryGuard<'_> { /// Initialize the tracked struct ids with the values from the prior execution. 
pub(crate) fn seed_tracked_struct_ids(&self, tracked_struct_ids: &[(Identity, Id)]) { - self.local_state.with_query_stack_mut(|stack| { - #[cfg(debug_assertions)] - assert_eq!(stack.len(), self.push_len); - let frame = stack.last_mut().unwrap(); - assert!(frame.tracked_struct_ids().is_empty()); - frame - .tracked_struct_ids_mut() - .clone_from_slice(tracked_struct_ids); - }) + // SAFETY: We do not access the query stack reentrantly. + unsafe { + self.local_state.with_query_stack_unchecked_mut(|stack| { + #[cfg(debug_assertions)] + assert_eq!(stack.len(), self.push_len); + let frame = stack.last_mut().unwrap(); + assert!(frame.tracked_struct_ids().is_empty()); + frame + .tracked_struct_ids_mut() + .clone_from_slice(tracked_struct_ids); + }) + } } /// Append the given `outputs` to the query's output list. @@ -943,23 +1003,29 @@ impl ActiveQueryGuard<'_> { QueryOriginRef::DerivedUntracked(_) ); - self.local_state.with_query_stack_mut(|stack| { - #[cfg(debug_assertions)] - assert_eq!(stack.len(), self.push_len); - let frame = stack.last_mut().unwrap(); - frame.seed_iteration(durability, changed_at, edges, untracked_read); - }) + // SAFETY: We do not access the query stack reentrantly. + unsafe { + self.local_state.with_query_stack_unchecked_mut(|stack| { + #[cfg(debug_assertions)] + assert_eq!(stack.len(), self.push_len); + let frame = stack.last_mut().unwrap(); + frame.seed_iteration(durability, changed_at, edges, untracked_read); + }) + } } /// Invoked when the query has successfully completed execution. fn complete(self) -> QueryRevisions { - let query = self.local_state.with_query_stack_mut(|stack| { - stack.pop_into_revisions( - self.database_key_index, - #[cfg(debug_assertions)] - self.push_len, - ) - }); + // SAFETY: We do not access the query stack reentrantly. 
+ let query = unsafe { + self.local_state.with_query_stack_unchecked_mut(|stack| { + stack.pop_into_revisions( + self.database_key_index, + #[cfg(debug_assertions)] + self.push_len, + ) + }) + }; std::mem::forget(self); query } @@ -975,12 +1041,15 @@ impl ActiveQueryGuard<'_> { impl Drop for ActiveQueryGuard<'_> { fn drop(&mut self) { - self.local_state.with_query_stack_mut(|stack| { - stack.pop( - self.database_key_index, - #[cfg(debug_assertions)] - self.push_len, - ); - }); + // SAFETY: We do not access the query stack reentrantly. + unsafe { + self.local_state.with_query_stack_unchecked_mut(|stack| { + stack.pop( + self.database_key_index, + #[cfg(debug_assertions)] + self.push_len, + ); + }) + }; } } From f303b6db56f6c95501c2f27c8ee8823983094568 Mon Sep 17 00:00:00 2001 From: Ibraheem Ahmed Date: Fri, 1 Aug 2025 12:50:02 -0400 Subject: [PATCH 13/65] optimize allocation fast-path (#949) --- src/input.rs | 6 ++++-- src/interned.rs | 3 +-- src/table.rs | 17 ++++++++++++----- src/tracked_struct.rs | 4 +++- src/zalsa_local.rs | 32 ++++++++++++++++++++++++++++---- 5 files changed, 48 insertions(+), 14 deletions(-) diff --git a/src/input.rs b/src/input.rs index af6648e73..81fa55b60 100644 --- a/src/input.rs +++ b/src/input.rs @@ -111,14 +111,16 @@ impl IngredientImpl { durabilities: C::Durabilities, ) -> C::Struct { let id = self.singleton.with_scope(|| { - zalsa_local.allocate(zalsa, self.ingredient_index, |_| Value:: { + let (id, _) = zalsa_local.allocate(zalsa, self.ingredient_index, |_| Value:: { fields, revisions, durabilities, // SAFETY: We only ever access the memos of a value that we allocated through // our `MemoTableTypes`. 
memos: unsafe { MemoTable::new(self.memo_table_types()) }, - }) + }); + + id }); FromIdWithDb::from_id(id, zalsa) diff --git a/src/interned.rs b/src/interned.rs index e3aecd309..cc3c3ac65 100644 --- a/src/interned.rs +++ b/src/interned.rs @@ -583,7 +583,7 @@ where .unwrap_or((Durability::MAX, Revision::max())); // Allocate the value slot. - let id = zalsa_local.allocate(zalsa, self.ingredient_index, |id| Value:: { + let (id, value) = zalsa_local.allocate(zalsa, self.ingredient_index, |id| Value:: { shard: shard_index as u16, link: LinkedListLink::new(), // SAFETY: We only ever access the memos of a value that we allocated through @@ -598,7 +598,6 @@ where }), }); - let value = zalsa.table().get::>(id); // SAFETY: We hold the lock for the shard containing the value. let value_shared = unsafe { &mut *value.shared.get() }; diff --git a/src/table.rs b/src/table.rs index 84415a4a0..62abf2dee 100644 --- a/src/table.rs +++ b/src/table.rs @@ -286,6 +286,8 @@ impl Table { .flat_map(|view| view.data()) } + #[cold] + #[inline(never)] pub(crate) fn fetch_or_push_page( &self, ingredient: IngredientIndex, @@ -299,6 +301,7 @@ impl Table { { return page; } + self.push_page::(ingredient, memo_types()) } @@ -311,22 +314,23 @@ impl Table { } } -impl<'p, T: Slot> PageView<'p, T> { +impl<'db, T: Slot> PageView<'db, T> { #[inline] - fn page_data(&self) -> &'p [PageDataEntry] { + fn page_data(&self) -> &'db [PageDataEntry] { let len = self.0.allocated.load(Ordering::Acquire); // SAFETY: `len` is the initialized length of the page unsafe { slice::from_raw_parts(self.0.data.cast::>().as_ptr(), len) } } #[inline] - fn data(&self) -> &'p [T] { + fn data(&self) -> &'db [T] { let len = self.0.allocated.load(Ordering::Acquire); // SAFETY: `len` is the initialized length of the page unsafe { slice::from_raw_parts(self.0.data.cast::().as_ptr(), len) } } - pub(crate) fn allocate(&self, page: PageIndex, value: V) -> Result + #[inline] + pub(crate) fn allocate(&self, page: PageIndex, value: V) -> 
Result<(Id, &'db T), V> where V: FnOnce(Id) -> T, { @@ -347,11 +351,14 @@ impl<'p, T: Slot> PageView<'p, T> { // interior unsafe { (*entry.get()).write(value(id)) }; + // SAFETY: We just initialized the value above. + let value = unsafe { (*entry.get()).assume_init_ref() }; + // Update the length (this must be done after initialization as otherwise an uninitialized // read could occur!) self.0.allocated.store(index + 1, Ordering::Release); - Ok(id) + Ok((id, value)) } } diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index ec240ebcb..9724093cf 100644 --- a/src/tracked_struct.rs +++ b/src/tracked_struct.rs @@ -477,7 +477,9 @@ where return id; } - zalsa_local.allocate::>(zalsa, self.ingredient_index, value) + let (id, _) = zalsa_local.allocate::>(zalsa, self.ingredient_index, value); + + id } /// Get mutable access to the data for `id` -- this holds a write lock for the duration diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs index 8bc49f22b..21e03713a 100644 --- a/src/zalsa_local.rs +++ b/src/zalsa_local.rs @@ -53,12 +53,36 @@ impl ZalsaLocal { /// Allocate a new id in `table` for the given ingredient /// storing `value`. Remembers the most recent page from this /// thread and attempts to reuse it. - pub(crate) fn allocate( + pub(crate) fn allocate<'db, T: Slot>( &self, - zalsa: &Zalsa, + zalsa: &'db Zalsa, ingredient: IngredientIndex, mut value: impl FnOnce(Id) -> T, - ) -> Id { + ) -> (Id, &'db T) { + // SAFETY: `ZalsaLocal` is `!Sync`, and we never expose a reference to this field, + // so we have exclusive access. + let most_recent_pages = unsafe { &mut *self.most_recent_pages.get() }; + + // Fast-path, we already have an unfilled page available. 
+ if let Some(&page) = most_recent_pages.get(&ingredient) { + let page_ref = zalsa.table().page::(page); + match page_ref.allocate(page, value) { + Ok((id, value)) => return (id, value), + Err(v) => value = v, + } + } + + self.allocate_cold(zalsa, ingredient, value) + } + + #[cold] + #[inline(never)] + pub(crate) fn allocate_cold<'db, T: Slot>( + &self, + zalsa: &'db Zalsa, + ingredient: IngredientIndex, + mut value: impl FnOnce(Id) -> T, + ) -> (Id, &'db T) { let memo_types = || { zalsa .lookup_ingredient(ingredient) @@ -82,7 +106,7 @@ impl ZalsaLocal { let page_ref = zalsa.table().page::(page); match page_ref.allocate(page, value) { // If successful, return - Ok(id) => return id, + Ok((id, value)) => return (id, value), // Otherwise, create a new page and try again // Note that we could try fetching a page again, but as we just filled one up From 679d82c4e739402bded893ae9fd968a63967b2b1 Mon Sep 17 00:00:00 2001 From: Lukas Wirth Date: Sat, 2 Aug 2025 10:41:17 +0200 Subject: [PATCH 14/65] Gate accumulator feature behind a feature flag (#946) --- .github/workflows/test.yml | 4 +- Cargo.toml | 12 ++- components/salsa-macro-rules/Cargo.toml | 3 + .../salsa-macro-rules/src/gate_accumulated.rs | 13 ++++ components/salsa-macro-rules/src/lib.rs | 2 + .../salsa-macro-rules/src/setup_tracked_fn.rs | 30 ++++---- src/active_query.rs | 73 ++++++++++++------- src/function.rs | 8 +- src/function/fetch.rs | 6 +- src/function/maybe_changed_after.rs | 52 ++++++++++--- src/function/specify.rs | 3 + src/ingredient.rs | 12 ++- src/lib.rs | 14 +++- src/zalsa_local.rs | 50 +++++++++---- tests/accumulate-chain.rs | 2 +- tests/accumulate-custom-debug.rs | 2 +- tests/accumulate-dag.rs | 2 +- tests/accumulate-execution-order.rs | 2 +- tests/accumulate-from-tracked-fn.rs | 2 +- tests/accumulate-no-duplicates.rs | 2 +- tests/accumulate-reuse-workaround.rs | 2 +- tests/accumulate-reuse.rs | 2 +- tests/accumulate.rs | 2 +- tests/accumulated_backdate.rs | 2 +- tests/cycle_accumulate.rs | 2 +- 
25 files changed, 215 insertions(+), 89 deletions(-) create mode 100644 components/salsa-macro-rules/src/gate_accumulated.rs diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index acb692572..57b0b1d91 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -55,12 +55,10 @@ jobs: run: cargo clippy --workspace --all-targets -- -D warnings - name: Test run: cargo nextest run --workspace --all-targets --no-fail-fast - - name: Test Manual Registration + - name: Test Manual Registration / no-default-features run: cargo nextest run --workspace --tests --no-fail-fast --no-default-features --features macros - name: Test docs run: cargo test --workspace --doc - - name: Check (without default features) - run: cargo check --workspace --no-default-features miri: name: Miri diff --git a/Cargo.toml b/Cargo.toml index 54eac8531..3f3470866 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,9 +38,10 @@ thin-vec = "0.2.13" shuttle = { version = "0.8.0", optional = true } [features] -default = ["salsa_unstable", "rayon", "macros", "inventory"] +default = ["salsa_unstable", "rayon", "macros", "inventory", "accumulator"] inventory = ["dep:inventory"] shuttle = ["dep:shuttle"] +accumulator = ["salsa-macro-rules/accumulator"] # FIXME: remove `salsa_unstable` before 1.0. 
salsa_unstable = [] macros = ["dep:salsa-macros"] @@ -82,11 +83,20 @@ harness = false [[bench]] name = "accumulator" harness = false +required-features = ["accumulator"] [[bench]] name = "dataflow" harness = false +[[example]] +name = "lazy-input" +required-features = ["accumulator"] + +[[example]] +name = "calc" +required-features = ["accumulator"] + [workspace] members = ["components/salsa-macro-rules", "components/salsa-macros"] diff --git a/components/salsa-macro-rules/Cargo.toml b/components/salsa-macro-rules/Cargo.toml index 5becae0f5..65770e10a 100644 --- a/components/salsa-macro-rules/Cargo.toml +++ b/components/salsa-macro-rules/Cargo.toml @@ -9,3 +9,6 @@ rust-version.workspace = true description = "Declarative macros for the salsa crate" [dependencies] + +[features] +accumulator = [] diff --git a/components/salsa-macro-rules/src/gate_accumulated.rs b/components/salsa-macro-rules/src/gate_accumulated.rs new file mode 100644 index 000000000..d2a86061a --- /dev/null +++ b/components/salsa-macro-rules/src/gate_accumulated.rs @@ -0,0 +1,13 @@ +#[cfg(feature = "accumulator")] +#[macro_export] +macro_rules! gate_accumulated { + ($($body:tt)*) => { + $($body)* + }; +} + +#[cfg(not(feature = "accumulator"))] +#[macro_export] +macro_rules! gate_accumulated { + ($($body:tt)*) => {}; +} diff --git a/components/salsa-macro-rules/src/lib.rs b/components/salsa-macro-rules/src/lib.rs index 897ff4cc5..8a53e6dd7 100644 --- a/components/salsa-macro-rules/src/lib.rs +++ b/components/salsa-macro-rules/src/lib.rs @@ -12,10 +12,12 @@ //! from a submodule is to use multiple crates, hence the existence //! of this crate. 
+mod gate_accumulated; mod macro_if; mod maybe_backdate; mod maybe_default; mod return_mode; +#[cfg(feature = "accumulator")] mod setup_accumulator_impl; mod setup_input_struct; mod setup_interned_struct; diff --git a/components/salsa-macro-rules/src/setup_tracked_fn.rs b/components/salsa-macro-rules/src/setup_tracked_fn.rs index 77325a484..50ca6a034 100644 --- a/components/salsa-macro-rules/src/setup_tracked_fn.rs +++ b/components/salsa-macro-rules/src/setup_tracked_fn.rs @@ -343,21 +343,23 @@ macro_rules! setup_tracked_fn { #[allow(non_local_definitions)] impl $fn_name { - pub fn accumulated<$db_lt, A: salsa::Accumulator>( - $db: &$db_lt dyn $Db, - $($input_id: $interned_input_ty,)* - ) -> Vec<&$db_lt A> { - use salsa::plumbing as $zalsa; - let key = $zalsa::macro_if! { - if $needs_interner {{ - let (zalsa, zalsa_local) = $db.zalsas(); - $Configuration::intern_ingredient($db).intern_id(zalsa, zalsa_local, ($($input_id),*), |_, data| data) - }} else { - $zalsa::AsId::as_id(&($($input_id),*)) - } - }; + $zalsa::gate_accumulated! { + pub fn accumulated<$db_lt, A: salsa::Accumulator>( + $db: &$db_lt dyn $Db, + $($input_id: $interned_input_ty,)* + ) -> Vec<&$db_lt A> { + use salsa::plumbing as $zalsa; + let key = $zalsa::macro_if! { + if $needs_interner {{ + let (zalsa, zalsa_local) = $db.zalsas(); + $Configuration::intern_ingredient($db).intern_id(zalsa, zalsa_local, ($($input_id),*), |_, data| data) + }} else { + $zalsa::AsId::as_id(&($($input_id),*)) + } + }; - $Configuration::fn_ingredient($db).accumulated_by::($db, key) + $Configuration::fn_ingredient($db).accumulated_by::($db, key) + } } $zalsa::macro_if! 
{ $is_specifiable => diff --git a/src/active_query.rs b/src/active_query.rs index 7789fe6de..dff64db3e 100644 --- a/src/active_query.rs +++ b/src/active_query.rs @@ -1,7 +1,9 @@ use std::{fmt, mem, ops}; -use crate::accumulator::accumulated_map::{ - AccumulatedMap, AtomicInputAccumulatedValues, InputAccumulatedValues, +#[cfg(feature = "accumulator")] +use crate::accumulator::{ + accumulated_map::{AccumulatedMap, AtomicInputAccumulatedValues, InputAccumulatedValues}, + Accumulator, }; use crate::cycle::{CycleHeads, IterationCount}; use crate::durability::Durability; @@ -11,7 +13,7 @@ use crate::runtime::Stamp; use crate::sync::atomic::AtomicBool; use crate::tracked_struct::{Disambiguator, DisambiguatorMap, IdentityHash, IdentityMap}; use crate::zalsa_local::{QueryEdge, QueryOrigin, QueryRevisions, QueryRevisionsExtra}; -use crate::{Accumulator, IngredientIndex, Revision}; +use crate::Revision; #[derive(Debug)] pub(crate) struct ActiveQuery { @@ -51,10 +53,12 @@ pub(crate) struct ActiveQuery { /// Stores the values accumulated to the given ingredient. /// The type of accumulated value is erased but known to the ingredient. + #[cfg(feature = "accumulator")] accumulated: AccumulatedMap, /// [`InputAccumulatedValues::Empty`] if any input read during the query's execution /// has any accumulated values. + #[cfg(feature = "accumulator")] accumulated_inputs: InputAccumulatedValues, /// Provisional cycle results that this query depends on. 
@@ -84,18 +88,21 @@ impl ActiveQuery { input: DatabaseKeyIndex, durability: Durability, changed_at: Revision, - has_accumulated: bool, - accumulated_inputs: &AtomicInputAccumulatedValues, cycle_heads: &CycleHeads, + #[cfg(feature = "accumulator")] has_accumulated: bool, + #[cfg(feature = "accumulator")] accumulated_inputs: &AtomicInputAccumulatedValues, ) { self.durability = self.durability.min(durability); self.changed_at = self.changed_at.max(changed_at); self.input_outputs.insert(QueryEdge::input(input)); - self.accumulated_inputs = self.accumulated_inputs.or_else(|| match has_accumulated { - true => InputAccumulatedValues::Any, - false => accumulated_inputs.load(), - }); self.cycle_heads.extend(cycle_heads); + #[cfg(feature = "accumulator")] + { + self.accumulated_inputs = self.accumulated_inputs.or_else(|| match has_accumulated { + true => InputAccumulatedValues::Any, + false => accumulated_inputs.load(), + }); + } } pub(super) fn add_read_simple( @@ -121,7 +128,8 @@ impl ActiveQuery { self.changed_at = self.changed_at.max(revision); } - pub(super) fn accumulate(&mut self, index: IngredientIndex, value: impl Accumulator) { + #[cfg(feature = "accumulator")] + pub(super) fn accumulate(&mut self, index: crate::IngredientIndex, value: impl Accumulator) { self.accumulated.accumulate(index, value); } @@ -169,10 +177,12 @@ impl ActiveQuery { untracked_read: false, disambiguator_map: Default::default(), tracked_struct_ids: Default::default(), - accumulated: Default::default(), - accumulated_inputs: Default::default(), cycle_heads: Default::default(), iteration_count, + #[cfg(feature = "accumulator")] + accumulated: Default::default(), + #[cfg(feature = "accumulator")] + accumulated_inputs: Default::default(), } } @@ -185,10 +195,12 @@ impl ActiveQuery { untracked_read, ref mut disambiguator_map, ref mut tracked_struct_ids, - ref mut accumulated, - accumulated_inputs, ref mut cycle_heads, iteration_count, + #[cfg(feature = "accumulator")] + ref mut accumulated, + 
#[cfg(feature = "accumulator")] + accumulated_inputs, } = self; let origin = if untracked_read { @@ -198,19 +210,22 @@ impl ActiveQuery { }; disambiguator_map.clear(); + #[cfg(feature = "accumulator")] + let accumulated_inputs = AtomicInputAccumulatedValues::new(accumulated_inputs); let verified_final = cycle_heads.is_empty(); let extra = QueryRevisionsExtra::new( + #[cfg(feature = "accumulator")] mem::take(accumulated), mem::take(tracked_struct_ids), mem::take(cycle_heads), iteration_count, ); - let accumulated_inputs = AtomicInputAccumulatedValues::new(accumulated_inputs); QueryRevisions { changed_at, durability, origin, + #[cfg(feature = "accumulator")] accumulated_inputs, verified_final: AtomicBool::new(verified_final), extra, @@ -226,17 +241,20 @@ impl ActiveQuery { untracked_read: _, disambiguator_map, tracked_struct_ids, - accumulated, - accumulated_inputs: _, cycle_heads, iteration_count, + #[cfg(feature = "accumulator")] + accumulated, + #[cfg(feature = "accumulator")] + accumulated_inputs: _, } = self; input_outputs.clear(); disambiguator_map.clear(); tracked_struct_ids.clear(); - accumulated.clear(); *cycle_heads = Default::default(); *iteration_count = IterationCount::initial(); + #[cfg(feature = "accumulator")] + accumulated.clear(); } fn reset_for( @@ -252,16 +270,17 @@ impl ActiveQuery { untracked_read, disambiguator_map, tracked_struct_ids, - accumulated, - accumulated_inputs, cycle_heads, iteration_count, + #[cfg(feature = "accumulator")] + accumulated, + #[cfg(feature = "accumulator")] + accumulated_inputs, } = self; *database_key_index = new_database_key_index; *durability = Durability::MAX; *changed_at = Revision::start(); *untracked_read = false; - *accumulated_inputs = Default::default(); *iteration_count = new_iteration_count; debug_assert!( input_outputs.is_empty(), @@ -279,10 +298,14 @@ impl ActiveQuery { cycle_heads.is_empty(), "`ActiveQuery::clear` or `ActiveQuery::into_revisions` should've been called" ); - debug_assert!( - 
accumulated.is_empty(), - "`ActiveQuery::clear` or `ActiveQuery::into_revisions` should've been called" - ); + #[cfg(feature = "accumulator")] + { + *accumulated_inputs = Default::default(); + debug_assert!( + accumulated.is_empty(), + "`ActiveQuery::clear` or `ActiveQuery::into_revisions` should've been called" + ); + } } } diff --git a/src/function.rs b/src/function.rs index 7642d4bab..75a046b8a 100644 --- a/src/function.rs +++ b/src/function.rs @@ -6,7 +6,6 @@ use std::sync::atomic::Ordering; use std::sync::OnceLock; pub(crate) use sync::SyncGuard; -use crate::accumulator::accumulated_map::{AccumulatedMap, InputAccumulatedValues}; use crate::cycle::{ empty_cycle_heads, CycleHeads, CycleRecoveryAction, CycleRecoveryStrategy, ProvisionalStatus, }; @@ -25,6 +24,7 @@ use crate::zalsa::{IngredientIndex, MemoIngredientIndex, Zalsa}; use crate::zalsa_local::QueryOriginRef; use crate::{Id, Revision}; +#[cfg(feature = "accumulator")] mod accumulated; mod backdate; mod delete; @@ -371,11 +371,15 @@ where C::CYCLE_STRATEGY } + #[cfg(feature = "accumulator")] unsafe fn accumulated<'db>( &'db self, db: RawDatabase<'db>, key_index: Id, - ) -> (Option<&'db AccumulatedMap>, InputAccumulatedValues) { + ) -> ( + Option<&'db crate::accumulator::accumulated_map::AccumulatedMap>, + crate::accumulator::accumulated_map::InputAccumulatedValues, + ) { // SAFETY: The `db` belongs to the ingredient as per caller invariant let db = unsafe { self.view_caster().downcast_unchecked(db) }; self.accumulated_map(db, key_index) diff --git a/src/function/fetch.rs b/src/function/fetch.rs index f3f79bfac..c6b1111bf 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -35,9 +35,11 @@ where database_key_index, memo.revisions.durability, memo.revisions.changed_at, + memo.cycle_heads(), + #[cfg(feature = "accumulator")] memo.revisions.accumulated().is_some(), + #[cfg(feature = "accumulator")] &memo.revisions.accumulated_inputs, - memo.cycle_heads(), ); memo_value @@ -221,7 +223,7 @@ where if 
let Some(old_memo) = opt_old_memo { if old_memo.value.is_some() { let mut cycle_heads = CycleHeads::default(); - if let VerifyResult::Unchanged(_) = + if let VerifyResult::Unchanged { .. } = self.deep_verify_memo(db, zalsa, old_memo, database_key_index, &mut cycle_heads) { if cycle_heads.is_empty() { diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index dcc17bb22..8f3a8762b 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -1,3 +1,4 @@ +#[cfg(feature = "accumulator")] use crate::accumulator::accumulated_map::InputAccumulatedValues; use crate::cycle::{CycleHeads, CycleRecoveryStrategy, IterationCount, ProvisionalStatus}; use crate::function::memo::Memo; @@ -18,7 +19,10 @@ pub enum VerifyResult { /// /// The inner value tracks whether the memo or any of its dependencies have an /// accumulated value. - Unchanged(InputAccumulatedValues), + Unchanged { + #[cfg(feature = "accumulator")] + accumulated: InputAccumulatedValues, + }, } impl VerifyResult { @@ -31,7 +35,10 @@ impl VerifyResult { } pub(crate) fn unchanged() -> Self { - Self::Unchanged(InputAccumulatedValues::Empty) + Self::Unchanged { + #[cfg(feature = "accumulator")] + accumulated: InputAccumulatedValues::Empty, + } } } @@ -71,7 +78,10 @@ where return if memo.revisions.changed_at > revision { VerifyResult::Changed } else { - VerifyResult::Unchanged(memo.revisions.accumulated_inputs.load()) + VerifyResult::Unchanged { + #[cfg(feature = "accumulator")] + accumulated: memo.revisions.accumulated_inputs.load(), + } }; } @@ -146,11 +156,18 @@ where // Check if the inputs are still valid. We can just compare `changed_at`. 
let deep_verify = self.deep_verify_memo(db, zalsa, old_memo, database_key_index, cycle_heads); - if let VerifyResult::Unchanged(accumulated_inputs) = deep_verify { + if let VerifyResult::Unchanged { + #[cfg(feature = "accumulator")] + accumulated: accumulated_inputs, + } = deep_verify + { return Some(if old_memo.revisions.changed_at > revision { VerifyResult::Changed } else { - VerifyResult::Unchanged(accumulated_inputs) + VerifyResult::Unchanged { + #[cfg(feature = "accumulator")] + accumulated: accumulated_inputs, + } }); } @@ -174,10 +191,13 @@ where return Some(if changed_at > revision { VerifyResult::Changed } else { - VerifyResult::Unchanged(match memo.revisions.accumulated() { - Some(_) => InputAccumulatedValues::Any, - None => memo.revisions.accumulated_inputs.load(), - }) + VerifyResult::Unchanged { + #[cfg(feature = "accumulator")] + accumulated: match memo.revisions.accumulated() { + Some(_) => InputAccumulatedValues::Any, + None => memo.revisions.accumulated_inputs.load(), + }, + } }); } @@ -443,6 +463,7 @@ where return VerifyResult::Changed; } + #[cfg(feature = "accumulator")] let mut inputs = InputAccumulatedValues::Empty; // Fully tracked inputs? Iterate over the inputs and check them, one by one. // @@ -460,9 +481,12 @@ where cycle_heads, ) { VerifyResult::Changed => return VerifyResult::Changed, - VerifyResult::Unchanged(input_accumulated) => { - inputs |= input_accumulated; + #[cfg(feature = "accumulator")] + VerifyResult::Unchanged { accumulated } => { + inputs |= accumulated; } + #[cfg(not(feature = "accumulator"))] + VerifyResult::Unchanged { .. 
} => {} } } QueryEdgeKind::Output(dependency_index) => { @@ -517,6 +541,7 @@ where // 1 and 3 if cycle_heads.is_empty() { old_memo.mark_as_verified(zalsa, database_key_index); + #[cfg(feature = "accumulator")] old_memo.revisions.accumulated_inputs.store(inputs); if is_provisional { @@ -527,7 +552,10 @@ where } } - VerifyResult::Unchanged(inputs) + VerifyResult::Unchanged { + #[cfg(feature = "accumulator")] + accumulated: inputs, + } } } } diff --git a/src/function/specify.rs b/src/function/specify.rs index 3bc71c565..37e25209e 100644 --- a/src/function/specify.rs +++ b/src/function/specify.rs @@ -1,3 +1,4 @@ +#[cfg(feature = "accumulator")] use crate::accumulator::accumulated_map::InputAccumulatedValues; use crate::function::memo::Memo; use crate::function::{Configuration, IngredientImpl}; @@ -66,6 +67,7 @@ where changed_at: current_deps.changed_at, durability: current_deps.durability, origin: QueryOrigin::assigned(active_query_key), + #[cfg(feature = "accumulator")] accumulated_inputs: Default::default(), verified_final: AtomicBool::new(true), extra: QueryRevisionsExtra::default(), @@ -124,6 +126,7 @@ where let database_key_index = self.database_key_index(key); memo.mark_as_verified(zalsa, database_key_index); + #[cfg(feature = "accumulator")] memo.revisions .accumulated_inputs .store(InputAccumulatedValues::Empty); diff --git a/src/ingredient.rs b/src/ingredient.rs index 12b8ebcba..e8766b5cf 100644 --- a/src/ingredient.rs +++ b/src/ingredient.rs @@ -1,7 +1,6 @@ use std::any::{Any, TypeId}; use std::fmt; -use crate::accumulator::accumulated_map::{AccumulatedMap, InputAccumulatedValues}; use crate::cycle::{ empty_cycle_heads, CycleHeads, CycleRecoveryStrategy, IterationCount, ProvisionalStatus, }; @@ -165,13 +164,20 @@ pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { /// # Safety /// /// The passed in database needs to be the same one that the ingredient was created with. 
+ #[cfg(feature = "accumulator")] unsafe fn accumulated<'db>( &'db self, db: RawDatabase<'db>, key_index: Id, - ) -> (Option<&'db AccumulatedMap>, InputAccumulatedValues) { + ) -> ( + Option<&'db crate::accumulator::accumulated_map::AccumulatedMap>, + crate::accumulator::accumulated_map::InputAccumulatedValues, + ) { let _ = (db, key_index); - (None, InputAccumulatedValues::Empty) + ( + None, + crate::accumulator::accumulated_map::InputAccumulatedValues::Empty, + ) } /// Returns memory usage information about any instances of the ingredient, diff --git a/src/lib.rs b/src/lib.rs index 2600d9a33..66c346b20 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,7 @@ #![deny(clippy::undocumented_unsafe_blocks)] #![forbid(unsafe_op_in_unsafe_fn)] +#[cfg(feature = "accumulator")] mod accumulator; mod active_query; mod attach; @@ -47,6 +48,7 @@ pub use salsa_macros::{accumulator, db, input, interned, tracked, Supertype, Upd #[cfg(feature = "salsa_unstable")] pub use self::database::IngredientInfo; +#[cfg(feature = "accumulator")] pub use self::accumulator::Accumulator; pub use self::active_query::Backtrace; pub use self::cancelled::Cancelled; @@ -68,7 +70,9 @@ pub use self::zalsa::IngredientIndex; pub use crate::attach::{attach, with_attached_database}; pub mod prelude { - pub use crate::{Accumulator, Database, Setter}; + #[cfg(feature = "accumulator")] + pub use crate::accumulator::Accumulator; + pub use crate::{Database, Setter}; } /// Internal names used by salsa macros. 
@@ -81,13 +85,16 @@ pub mod plumbing { pub use std::any::TypeId; pub use std::option::Option::{self, None, Some}; + #[cfg(feature = "accumulator")] + pub use salsa_macro_rules::setup_accumulator_impl; pub use salsa_macro_rules::{ - macro_if, maybe_backdate, maybe_default, maybe_default_tt, return_mode_expression, - return_mode_ty, setup_accumulator_impl, setup_input_struct, setup_interned_struct, + gate_accumulated, macro_if, maybe_backdate, maybe_default, maybe_default_tt, + return_mode_expression, return_mode_ty, setup_input_struct, setup_interned_struct, setup_tracked_assoc_fn_body, setup_tracked_fn, setup_tracked_method_body, setup_tracked_struct, unexpected_cycle_initial, unexpected_cycle_recovery, }; + #[cfg(feature = "accumulator")] pub use crate::accumulator::Accumulator; pub use crate::attach::{attach, with_attached_database}; pub use crate::cycle::{CycleRecoveryAction, CycleRecoveryStrategy}; @@ -116,6 +123,7 @@ pub mod plumbing { }; pub use crate::zalsa_local::ZalsaLocal; + #[cfg(feature = "accumulator")] pub mod accumulator { pub use crate::accumulator::{IngredientImpl, JarImpl}; } diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs index 21e03713a..8d58d7171 100644 --- a/src/zalsa_local.rs +++ b/src/zalsa_local.rs @@ -5,7 +5,11 @@ use std::ptr::{self, NonNull}; use rustc_hash::FxHashMap; use thin_vec::ThinVec; -use crate::accumulator::accumulated_map::{AccumulatedMap, AtomicInputAccumulatedValues}; +#[cfg(feature = "accumulator")] +use crate::accumulator::{ + accumulated_map::{AccumulatedMap, AtomicInputAccumulatedValues}, + Accumulator, +}; use crate::active_query::QueryStack; use crate::cycle::{empty_cycle_heads, CycleHeads, IterationCount}; use crate::durability::Durability; @@ -15,7 +19,7 @@ use crate::sync::atomic::AtomicBool; use crate::table::{PageIndex, Slot, Table}; use crate::tracked_struct::{Disambiguator, Identity, IdentityHash, IdentityMap}; use crate::zalsa::{IngredientIndex, Zalsa}; -use crate::{Accumulator, Cancelled, Id, Revision}; 
+use crate::{Cancelled, Id, Revision}; /// State that is specific to a single execution thread. /// @@ -201,6 +205,7 @@ impl ZalsaLocal { /// Add an output to the current query's list of dependencies /// /// Returns `Err` if not in a query. + #[cfg(feature = "accumulator")] pub(crate) fn accumulate( &self, index: IngredientIndex, @@ -252,9 +257,9 @@ impl ZalsaLocal { input: DatabaseKeyIndex, durability: Durability, changed_at: Revision, - has_accumulated: bool, - accumulated_inputs: &AtomicInputAccumulatedValues, cycle_heads: &CycleHeads, + #[cfg(feature = "accumulator")] has_accumulated: bool, + #[cfg(feature = "accumulator")] accumulated_inputs: &AtomicInputAccumulatedValues, ) { crate::tracing::debug!( "report_tracked_read(input={:?}, durability={:?}, changed_at={:?})", @@ -271,9 +276,11 @@ impl ZalsaLocal { input, durability, changed_at, + cycle_heads, + #[cfg(feature = "accumulator")] has_accumulated, + #[cfg(feature = "accumulator")] accumulated_inputs, - cycle_heads, ); } }) @@ -420,6 +427,7 @@ pub(crate) struct QueryRevisions { /// /// Note that this field could be in `QueryRevisionsExtra` as it is only relevant /// for accumulators, but we get it for free anyways due to padding. + #[cfg(feature = "accumulator")] pub(super) accumulated_inputs: AtomicInputAccumulatedValues, /// Are the `cycle_heads` verified to not be provisional anymore? 
@@ -439,10 +447,11 @@ impl QueryRevisions { let QueryRevisions { changed_at: _, durability: _, - accumulated_inputs: _, verified_final: _, origin, extra, + #[cfg(feature = "accumulator")] + accumulated_inputs: _, } = self; let mut memory = 0; @@ -472,19 +481,24 @@ pub(crate) struct QueryRevisionsExtra(Option>); impl QueryRevisionsExtra { pub fn new( - accumulated: AccumulatedMap, + #[cfg(feature = "accumulator")] accumulated: AccumulatedMap, tracked_struct_ids: IdentityMap, cycle_heads: CycleHeads, iteration: IterationCount, ) -> Self { - let inner = if tracked_struct_ids.is_empty() + #[cfg(feature = "accumulator")] + let acc = accumulated.is_empty(); + #[cfg(not(feature = "accumulator"))] + let acc = true; + let inner = if acc + && tracked_struct_ids.is_empty() && cycle_heads.is_empty() - && accumulated.is_empty() && iteration.is_initial() { None } else { Some(Box::new(QueryRevisionsExtraInner { + #[cfg(feature = "accumulator")] accumulated, cycle_heads, tracked_struct_ids: tracked_struct_ids.into_thin_vec(), @@ -498,6 +512,7 @@ impl QueryRevisionsExtra { #[derive(Debug)] struct QueryRevisionsExtraInner { + #[cfg(feature = "accumulator")] accumulated: AccumulatedMap, /// The ids of tracked structs created by this query. 
@@ -536,15 +551,18 @@ impl QueryRevisionsExtraInner { #[cfg(feature = "salsa_unstable")] fn allocation_size(&self) -> usize { let QueryRevisionsExtraInner { + #[cfg(feature = "accumulator")] accumulated, tracked_struct_ids, cycle_heads, iteration: _, } = self; - accumulated.allocation_size() - + cycle_heads.allocation_size() - + std::mem::size_of_val(tracked_struct_ids.as_slice()) + #[cfg(feature = "accumulator")] + let b = accumulated.allocation_size(); + #[cfg(not(feature = "accumulator"))] + let b = 0; + b + cycle_heads.allocation_size() + std::mem::size_of_val(tracked_struct_ids.as_slice()) } } @@ -555,7 +573,7 @@ const _: [(); std::mem::size_of::()] = [(); std::mem::size_of::< #[cfg(not(feature = "shuttle"))] #[cfg(target_pointer_width = "64")] const _: [(); std::mem::size_of::()] = - [(); std::mem::size_of::<[usize; 7]>()]; + [(); std::mem::size_of::<[usize; if cfg!(feature = "accumulator") { 7 } else { 3 }]>()]; impl QueryRevisions { pub(crate) fn fixpoint_initial(query: DatabaseKeyIndex) -> Self { @@ -563,9 +581,11 @@ impl QueryRevisions { changed_at: Revision::start(), durability: Durability::MAX, origin: QueryOrigin::fixpoint_initial(), + #[cfg(feature = "accumulator")] accumulated_inputs: Default::default(), verified_final: AtomicBool::new(false), extra: QueryRevisionsExtra::new( + #[cfg(feature = "accumulator")] AccumulatedMap::default(), IdentityMap::default(), CycleHeads::initial(query), @@ -575,6 +595,7 @@ impl QueryRevisions { } /// Returns a reference to the `AccumulatedMap` for this query, or `None` if the map is empty. 
+ #[cfg(feature = "accumulator")] pub(crate) fn accumulated(&self) -> Option<&AccumulatedMap> { self.extra .0 @@ -606,6 +627,7 @@ impl QueryRevisions { Some(extra) => extra.cycle_heads = cycle_heads, None => { self.extra = QueryRevisionsExtra::new( + #[cfg(feature = "accumulator")] AccumulatedMap::default(), IdentityMap::default(), cycle_heads, @@ -676,6 +698,7 @@ pub enum QueryOriginRef<'a> { impl<'a> QueryOriginRef<'a> { /// Indices for queries *read* by this query #[inline] + #[cfg(feature = "accumulator")] pub(crate) fn inputs(self) -> impl DoubleEndedIterator + use<'a> { let opt_edges = match self { QueryOriginRef::Derived(edges) | QueryOriginRef::DerivedUntracked(edges) => Some(edges), @@ -968,6 +991,7 @@ pub enum QueryEdgeKind { /// Returns the (tracked) inputs that were executed in computing this memoized value. /// /// These will always be in execution order. +#[cfg(feature = "accumulator")] pub(crate) fn input_edges( input_outputs: &[QueryEdge], ) -> impl DoubleEndedIterator + use<'_> { diff --git a/tests/accumulate-chain.rs b/tests/accumulate-chain.rs index 18d4bb56a..4a26cec76 100644 --- a/tests/accumulate-chain.rs +++ b/tests/accumulate-chain.rs @@ -1,4 +1,4 @@ -#![cfg(feature = "inventory")] +#![cfg(all(feature = "inventory", feature = "accumulator"))] //! Test that when having nested tracked functions //! we don't drop any values when accumulating. 
diff --git a/tests/accumulate-custom-debug.rs b/tests/accumulate-custom-debug.rs index 180156042..77b7d59f9 100644 --- a/tests/accumulate-custom-debug.rs +++ b/tests/accumulate-custom-debug.rs @@ -1,4 +1,4 @@ -#![cfg(feature = "inventory")] +#![cfg(all(feature = "inventory", feature = "accumulator"))] mod common; diff --git a/tests/accumulate-dag.rs b/tests/accumulate-dag.rs index 41d9b3908..507ff6bfb 100644 --- a/tests/accumulate-dag.rs +++ b/tests/accumulate-dag.rs @@ -1,4 +1,4 @@ -#![cfg(feature = "inventory")] +#![cfg(all(feature = "inventory", feature = "accumulator"))] mod common; diff --git a/tests/accumulate-execution-order.rs b/tests/accumulate-execution-order.rs index 1a0d3e233..c71732a55 100644 --- a/tests/accumulate-execution-order.rs +++ b/tests/accumulate-execution-order.rs @@ -1,4 +1,4 @@ -#![cfg(feature = "inventory")] +#![cfg(all(feature = "inventory", feature = "accumulator"))] //! Demonstrates that accumulation is done in the order //! in which things were originally executed. diff --git a/tests/accumulate-from-tracked-fn.rs b/tests/accumulate-from-tracked-fn.rs index 67e591688..34c69a69a 100644 --- a/tests/accumulate-from-tracked-fn.rs +++ b/tests/accumulate-from-tracked-fn.rs @@ -1,4 +1,4 @@ -#![cfg(feature = "inventory")] +#![cfg(all(feature = "inventory", feature = "accumulator"))] //! Accumulate values from within a tracked function. //! Then mutate the values so that the tracked function re-executes. diff --git a/tests/accumulate-no-duplicates.rs b/tests/accumulate-no-duplicates.rs index 8d21281e5..96bae0629 100644 --- a/tests/accumulate-no-duplicates.rs +++ b/tests/accumulate-no-duplicates.rs @@ -1,4 +1,4 @@ -#![cfg(feature = "inventory")] +#![cfg(all(feature = "inventory", feature = "accumulator"))] //! 
Test that we don't get duplicate accumulated values diff --git a/tests/accumulate-reuse-workaround.rs b/tests/accumulate-reuse-workaround.rs index 43c5bb3ce..db26b7e3c 100644 --- a/tests/accumulate-reuse-workaround.rs +++ b/tests/accumulate-reuse-workaround.rs @@ -1,4 +1,4 @@ -#![cfg(feature = "inventory")] +#![cfg(all(feature = "inventory", feature = "accumulator"))] //! Demonstrates the workaround of wrapping calls to //! `accumulated` in a tracked function to get better diff --git a/tests/accumulate-reuse.rs b/tests/accumulate-reuse.rs index 1e6194de6..6b63e2de9 100644 --- a/tests/accumulate-reuse.rs +++ b/tests/accumulate-reuse.rs @@ -1,4 +1,4 @@ -#![cfg(feature = "inventory")] +#![cfg(all(feature = "inventory", feature = "accumulator"))] //! Accumulator re-use test. //! diff --git a/tests/accumulate.rs b/tests/accumulate.rs index 54022a15e..0ad88a31f 100644 --- a/tests/accumulate.rs +++ b/tests/accumulate.rs @@ -1,4 +1,4 @@ -#![cfg(feature = "inventory")] +#![cfg(all(feature = "inventory", feature = "accumulator"))] mod common; use common::{LogDatabase, LoggerDatabase}; diff --git a/tests/accumulated_backdate.rs b/tests/accumulated_backdate.rs index 45759d1ba..efc8d2f85 100644 --- a/tests/accumulated_backdate.rs +++ b/tests/accumulated_backdate.rs @@ -1,4 +1,4 @@ -#![cfg(feature = "inventory")] +#![cfg(all(feature = "inventory", feature = "accumulator"))] //! Tests that accumulated values are correctly accounted for //! when backdating a value. 
diff --git a/tests/cycle_accumulate.rs b/tests/cycle_accumulate.rs index fa31845d9..e06fe033b 100644 --- a/tests/cycle_accumulate.rs +++ b/tests/cycle_accumulate.rs @@ -1,4 +1,4 @@ -#![cfg(feature = "inventory")] +#![cfg(all(feature = "inventory", feature = "accumulator"))] use std::collections::HashSet; From 5b411a290c719a977c09386f2f037feb2dededf9 Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Sat, 2 Aug 2025 15:34:21 +0200 Subject: [PATCH 15/65] refactor: Use `CycleHeadSet` in `maybe_update_after` (#953) --- src/accumulator.rs | 4 +-- src/cycle.rs | 47 +++++++++++++++++++---------- src/function.rs | 5 +-- src/function/fetch.rs | 4 +-- src/function/maybe_changed_after.rs | 12 ++++---- src/ingredient.rs | 5 +-- src/input.rs | 4 +-- src/input/input_field.rs | 4 +-- src/interned.rs | 4 +-- src/key.rs | 4 +-- src/tracked_struct.rs | 4 +-- src/tracked_struct/tracked_field.rs | 4 +-- 12 files changed, 59 insertions(+), 42 deletions(-) diff --git a/src/accumulator.rs b/src/accumulator.rs index 4bd1280a7..0e9feab62 100644 --- a/src/accumulator.rs +++ b/src/accumulator.rs @@ -7,7 +7,7 @@ use std::panic::UnwindSafe; use accumulated::{Accumulated, AnyAccumulated}; -use crate::cycle::CycleHeads; +use crate::cycle::CycleHeadKeys; use crate::function::VerifyResult; use crate::ingredient::{Ingredient, Jar}; use crate::plumbing::ZalsaLocal; @@ -106,7 +106,7 @@ impl Ingredient for IngredientImpl { _db: crate::database::RawDatabase<'_>, _input: Id, _revision: Revision, - _cycle_heads: &mut CycleHeads, + _cycle_heads: &mut CycleHeadKeys, ) -> VerifyResult { panic!("nothing should ever depend on an accumulator directly") } diff --git a/src/cycle.rs b/src/cycle.rs index 66f205448..f05983196 100644 --- a/src/cycle.rs +++ b/src/cycle.rs @@ -178,22 +178,6 @@ impl CycleHeads { } } - #[inline] - pub(crate) fn push_initial(&mut self, database_key_index: DatabaseKeyIndex) { - if let Some(existing) = self - .0 - .iter() - .find(|candidate| candidate.database_key_index == 
database_key_index) - { - assert_eq!(existing.iteration_count, IterationCount::initial()); - } else { - self.0.push(CycleHead { - database_key_index, - iteration_count: IterationCount::initial(), - }); - } - } - #[inline] pub(crate) fn extend(&mut self, other: &Self) { self.0.reserve(other.0.len()); @@ -247,6 +231,37 @@ pub(crate) fn empty_cycle_heads() -> &'static CycleHeads { EMPTY_CYCLE_HEADS.get_or_init(|| CycleHeads(ThinVec::new())) } +/// Set of cycle head database keys. +/// +/// Unlike [`CycleHeads`], this type doesn't track the iteration count +/// of each cycle head. +#[derive(Debug, Clone, Eq, PartialEq)] +pub struct CycleHeadKeys(Vec); + +impl CycleHeadKeys { + pub(crate) fn new() -> Self { + Self(Vec::new()) + } + + pub(crate) fn insert(&mut self, database_key_index: DatabaseKeyIndex) { + if !self.0.contains(&database_key_index) { + self.0.push(database_key_index); + } + } + + pub(crate) fn remove(&mut self, value: &DatabaseKeyIndex) -> bool { + let found = self.0.iter().position(|&head| head == *value); + let Some(found) = found else { return false }; + + self.0.swap_remove(found); + true + } + + pub(crate) fn is_empty(&self) -> bool { + self.0.is_empty() + } +} + #[derive(Debug, PartialEq, Eq)] pub enum ProvisionalStatus { Provisional { iteration: IterationCount }, diff --git a/src/function.rs b/src/function.rs index 75a046b8a..891e3dbad 100644 --- a/src/function.rs +++ b/src/function.rs @@ -7,7 +7,8 @@ use std::sync::OnceLock; pub(crate) use sync::SyncGuard; use crate::cycle::{ - empty_cycle_heads, CycleHeads, CycleRecoveryAction, CycleRecoveryStrategy, ProvisionalStatus, + empty_cycle_heads, CycleHeadKeys, CycleHeads, CycleRecoveryAction, CycleRecoveryStrategy, + ProvisionalStatus, }; use crate::database::RawDatabase; use crate::function::delete::DeletedEntries; @@ -265,7 +266,7 @@ where db: RawDatabase<'_>, input: Id, revision: Revision, - cycle_heads: &mut CycleHeads, + cycle_heads: &mut CycleHeadKeys, ) -> VerifyResult { // SAFETY: The `db` 
belongs to the ingredient as per caller invariant let db = unsafe { self.view_caster().downcast_unchecked(db) }; diff --git a/src/function/fetch.rs b/src/function/fetch.rs index c6b1111bf..88e6f7e8d 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -1,4 +1,4 @@ -use crate::cycle::{CycleHeads, CycleRecoveryStrategy, IterationCount}; +use crate::cycle::{CycleHeadKeys, CycleHeads, CycleRecoveryStrategy, IterationCount}; use crate::function::memo::Memo; use crate::function::sync::ClaimResult; use crate::function::{Configuration, IngredientImpl, VerifyResult}; @@ -222,7 +222,7 @@ where if let Some(old_memo) = opt_old_memo { if old_memo.value.is_some() { - let mut cycle_heads = CycleHeads::default(); + let mut cycle_heads = CycleHeadKeys::new(); if let VerifyResult::Unchanged { .. } = self.deep_verify_memo(db, zalsa, old_memo, database_key_index, &mut cycle_heads) { diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index 8f3a8762b..ac96edade 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -1,6 +1,6 @@ #[cfg(feature = "accumulator")] use crate::accumulator::accumulated_map::InputAccumulatedValues; -use crate::cycle::{CycleHeads, CycleRecoveryStrategy, IterationCount, ProvisionalStatus}; +use crate::cycle::{CycleHeadKeys, CycleRecoveryStrategy, IterationCount, ProvisionalStatus}; use crate::function::memo::Memo; use crate::function::sync::ClaimResult; use crate::function::{Configuration, IngredientImpl}; @@ -51,7 +51,7 @@ where db: &'db C::DbView, id: Id, revision: Revision, - cycle_heads: &mut CycleHeads, + cycle_heads: &mut CycleHeadKeys, ) -> VerifyResult { let (zalsa, zalsa_local) = db.zalsas(); let memo_ingredient_index = self.memo_ingredient_index(zalsa, id); @@ -108,7 +108,7 @@ where key_index: Id, revision: Revision, memo_ingredient_index: MemoIngredientIndex, - cycle_heads: &mut CycleHeads, + cycle_heads: &mut CycleHeadKeys, ) -> Option { let database_key_index = 
self.database_key_index(key_index); @@ -135,7 +135,7 @@ where crate::tracing::debug!( "hit cycle at {database_key_index:?} in `maybe_changed_after`, returning fixpoint initial value", ); - cycle_heads.push_initial(database_key_index); + cycle_heads.insert(database_key_index); return Some(VerifyResult::unchanged()); } }, @@ -406,7 +406,7 @@ where zalsa: &Zalsa, old_memo: &Memo<'_, C>, database_key_index: DatabaseKeyIndex, - cycle_heads: &mut CycleHeads, + cycle_heads: &mut CycleHeadKeys, ) -> VerifyResult { crate::tracing::debug!( "{database_key_index:?}: deep_verify_memo(old_memo = {old_memo:#?})", @@ -447,7 +447,7 @@ where // are tracked by the outer query. Nothing should have changed assuming that the // fixpoint initial function is deterministic. QueryOriginRef::FixpointInitial => { - cycle_heads.push_initial(database_key_index); + cycle_heads.insert(database_key_index); VerifyResult::unchanged() } QueryOriginRef::DerivedUntracked(_) => { diff --git a/src/ingredient.rs b/src/ingredient.rs index e8766b5cf..2ad7bb8ee 100644 --- a/src/ingredient.rs +++ b/src/ingredient.rs @@ -2,7 +2,8 @@ use std::any::{Any, TypeId}; use std::fmt; use crate::cycle::{ - empty_cycle_heads, CycleHeads, CycleRecoveryStrategy, IterationCount, ProvisionalStatus, + empty_cycle_heads, CycleHeadKeys, CycleHeads, CycleRecoveryStrategy, IterationCount, + ProvisionalStatus, }; use crate::database::RawDatabase; use crate::function::VerifyResult; @@ -51,7 +52,7 @@ pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { db: crate::database::RawDatabase<'_>, input: Id, revision: Revision, - cycle_heads: &mut CycleHeads, + cycle_heads: &mut CycleHeadKeys, ) -> VerifyResult; /// Returns information about the current provisional status of `input`. 
diff --git a/src/input.rs b/src/input.rs index 81fa55b60..58d769a6b 100644 --- a/src/input.rs +++ b/src/input.rs @@ -8,7 +8,7 @@ pub mod singleton; use input_field::FieldIngredientImpl; -use crate::cycle::CycleHeads; +use crate::cycle::CycleHeadKeys; use crate::function::VerifyResult; use crate::id::{AsId, FromId, FromIdWithDb}; use crate::ingredient::Ingredient; @@ -226,7 +226,7 @@ impl Ingredient for IngredientImpl { _db: crate::database::RawDatabase<'_>, _input: Id, _revision: Revision, - _cycle_heads: &mut CycleHeads, + _cycle_heads: &mut CycleHeadKeys, ) -> VerifyResult { // Input ingredients are just a counter, they store no data, they are immortal. // Their *fields* are stored in function ingredients elsewhere. diff --git a/src/input/input_field.rs b/src/input/input_field.rs index 82ed9889d..0d724b0ca 100644 --- a/src/input/input_field.rs +++ b/src/input/input_field.rs @@ -1,7 +1,7 @@ use std::fmt; use std::marker::PhantomData; -use crate::cycle::CycleHeads; +use crate::cycle::CycleHeadKeys; use crate::function::VerifyResult; use crate::ingredient::Ingredient; use crate::input::{Configuration, IngredientImpl, Value}; @@ -56,7 +56,7 @@ where _db: crate::database::RawDatabase<'_>, input: Id, revision: Revision, - _cycle_heads: &mut CycleHeads, + _cycle_heads: &mut CycleHeadKeys, ) -> VerifyResult { let value = >::data(zalsa, input); VerifyResult::changed_if(value.revisions[self.field_index] > revision) diff --git a/src/interned.rs b/src/interned.rs index cc3c3ac65..49663ee65 100644 --- a/src/interned.rs +++ b/src/interned.rs @@ -10,7 +10,7 @@ use crossbeam_utils::CachePadded; use intrusive_collections::{intrusive_adapter, LinkedList, LinkedListLink, UnsafeRef}; use rustc_hash::FxBuildHasher; -use crate::cycle::CycleHeads; +use crate::cycle::CycleHeadKeys; use crate::durability::Durability; use crate::function::VerifyResult; use crate::id::{AsId, FromId}; @@ -797,7 +797,7 @@ where _db: crate::database::RawDatabase<'_>, input: Id, _revision: Revision, - 
_cycle_heads: &mut CycleHeads, + _cycle_heads: &mut CycleHeadKeys, ) -> VerifyResult { // Record the current revision as active. let current_revision = zalsa.current_revision(); diff --git a/src/key.rs b/src/key.rs index 80904e978..9045e8337 100644 --- a/src/key.rs +++ b/src/key.rs @@ -1,6 +1,6 @@ use core::fmt; -use crate::cycle::CycleHeads; +use crate::cycle::CycleHeadKeys; use crate::function::VerifyResult; use crate::zalsa::{IngredientIndex, Zalsa}; use crate::Id; @@ -39,7 +39,7 @@ impl DatabaseKeyIndex { db: crate::database::RawDatabase<'_>, zalsa: &Zalsa, last_verified_at: crate::Revision, - cycle_heads: &mut CycleHeads, + cycle_heads: &mut CycleHeadKeys, ) -> VerifyResult { // SAFETY: The `db` belongs to the ingredient unsafe { diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index 9724093cf..2ec3e24dc 100644 --- a/src/tracked_struct.rs +++ b/src/tracked_struct.rs @@ -10,7 +10,7 @@ use crossbeam_queue::SegQueue; use thin_vec::ThinVec; use tracked_field::FieldIngredientImpl; -use crate::cycle::CycleHeads; +use crate::cycle::CycleHeadKeys; use crate::function::VerifyResult; use crate::id::{AsId, FromId}; use crate::ingredient::{Ingredient, Jar}; @@ -817,7 +817,7 @@ where _db: crate::database::RawDatabase<'_>, _input: Id, _revision: Revision, - _cycle_heads: &mut CycleHeads, + _cycle_heads: &mut CycleHeadKeys, ) -> VerifyResult { // Any change to a tracked struct results in a new ID generation. 
VerifyResult::unchanged() diff --git a/src/tracked_struct/tracked_field.rs b/src/tracked_struct/tracked_field.rs index 587e473fa..95ec32fa6 100644 --- a/src/tracked_struct/tracked_field.rs +++ b/src/tracked_struct/tracked_field.rs @@ -1,6 +1,6 @@ use std::marker::PhantomData; -use crate::cycle::CycleHeads; +use crate::cycle::CycleHeadKeys; use crate::function::VerifyResult; use crate::ingredient::Ingredient; use crate::sync::Arc; @@ -61,7 +61,7 @@ where _db: crate::database::RawDatabase<'_>, input: Id, revision: crate::Revision, - _cycle_heads: &mut CycleHeads, + _cycle_heads: &mut CycleHeadKeys, ) -> VerifyResult { let data = >::data(zalsa.table(), input); let field_changed_at = data.revisions[self.field_index]; From c3f86b8d02c32db4c36d8a4c40096daeaff64058 Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Sun, 3 Aug 2025 06:40:21 +0200 Subject: [PATCH 16/65] Upgrade dependencies (#956) --- Cargo.toml | 18 +++++++++--------- components/salsa-macros/Cargo.toml | 2 +- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 3f3470866..3b4eb3455 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,7 @@ salsa-macro-rules = { version = "0.23.0", path = "components/salsa-macro-rules" salsa-macros = { version = "0.23.0", path = "components/salsa-macros", optional = true } boxcar = "0.2.13" -crossbeam-queue = "0.3.11" +crossbeam-queue = "0.3.12" crossbeam-utils = "0.8.21" hashbrown = "0.15" hashlink = "0.10" @@ -33,9 +33,9 @@ rayon = { version = "1.10.0", optional = true } # Stuff we want Update impls for by default compact_str = { version = "0.9", optional = true } -thin-vec = "0.2.13" +thin-vec = "0.2.14" -shuttle = { version = "0.8.0", optional = true } +shuttle = { version = "0.8.1", optional = true } [features] default = ["salsa_unstable", "rayon", "macros", "inventory", "accumulator"] @@ -55,18 +55,18 @@ salsa-macros = { version = "=0.23.0", path = "components/salsa-macros" } [dev-dependencies] # examples -crossbeam-channel = 
"0.5.14" +crossbeam-channel = "0.5.15" dashmap = { version = "6", features = ["raw-api"] } -eyre = "0.6.8" +eyre = "0.6.12" notify-debouncer-mini = "0.4.1" -ordered-float = "4.2.1" +ordered-float = "5.0.0" # tests/benches annotate-snippets = "0.11.5" -codspeed-criterion-compat = { version = "2.6.0", default-features = false } -expect-test = "1.5.0" +codspeed-criterion-compat = { version = "3.0.5", default-features = false } +expect-test = "1.5.1" rustversion = "1.0" -test-log = { version = "0.2.11", features = ["trace"] } +test-log = { version = "0.2.18", features = ["trace"] } trybuild = "1.0" [target.'cfg(all(not(target_os = "windows"), not(target_os = "openbsd"), any(target_arch = "x86_64", target_arch = "aarch64", target_arch = "powerpc64")))'.dev-dependencies] diff --git a/components/salsa-macros/Cargo.toml b/components/salsa-macros/Cargo.toml index f992d4368..ea4efe078 100644 --- a/components/salsa-macros/Cargo.toml +++ b/components/salsa-macros/Cargo.toml @@ -14,5 +14,5 @@ proc-macro = true [dependencies] proc-macro2 = "1.0" quote = "1.0" -syn = { version = "2.0.101", features = ["full", "visit-mut"] } +syn = { version = "2.0.104", features = ["full", "visit-mut"] } synstructure = "0.13.2" From 86ca4a9d70e97dd5107e6111a09647dd10ff7535 Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Sun, 3 Aug 2025 10:49:53 +0200 Subject: [PATCH 17/65] Expose API to manually trigger cancellation (#959) --- src/database.rs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/database.rs b/src/database.rs index 30178b2da..a253ac5a0 100644 --- a/src/database.rs +++ b/src/database.rs @@ -59,6 +59,14 @@ pub trait Database: Send + ZalsaDatabase + AsDynDatabase { zalsa_mut.runtime_mut().report_tracked_write(durability); } + /// This method triggers cancellation. + /// If you invoke it while a snapshot exists, it + /// will block until that snapshot is dropped -- if that snapshot + /// is owned by the current thread, this could trigger deadlock. 
+ fn trigger_cancellation(&mut self) { + let _ = self.zalsa_mut(); + } + /// Reports that the query depends on some state unknown to salsa. /// /// Queries which report untracked reads will be re-executed in the next From d66fe331d546216132ace503512b94d5c68d2c50 Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Tue, 5 Aug 2025 09:37:27 +0200 Subject: [PATCH 18/65] allow reuse of cached provisional memos within the same cycle iteration during `maybe_changed_after` (#954) * Allow `validate_same_iteration` slow path for `maybe_changed_after` * Add logging to cycle_panic test and enable output capture for debugging * Add RUST_LOG=trace to CI for detailed tracing output * Possible fix * Discard changes to .github/workflows/test.yml * Delete logs/.eb4cbc62113c2d27a93f1a2e33d842313ed0aa05-audit.json * Delete logs/mcp-puppeteer-2025-08-02.log * docs * Add regression test * Discard changes to tests/parallel/cycle_panic.rs * Discard changes to Cargo.toml * Fix comment * Discard changes to Cargo.toml * Update src/function/maybe_changed_after.rs Co-authored-by: Carl Meyer --------- Co-authored-by: Carl Meyer --- src/function.rs | 2 +- src/function/maybe_changed_after.rs | 37 +++++++++++++++++++++++---- src/function/sync.rs | 4 +-- src/ingredient.rs | 7 +++--- src/runtime.rs | 6 ++--- tests/common/mod.rs | 4 +++ tests/cycle.rs | 39 +++++++++++++++++++++++++++++ 7 files changed, 84 insertions(+), 15 deletions(-) diff --git a/src/function.rs b/src/function.rs index 891e3dbad..e9ab57939 100644 --- a/src/function.rs +++ b/src/function.rs @@ -311,7 +311,7 @@ where fn wait_for<'me>(&'me self, zalsa: &'me Zalsa, key_index: Id) -> WaitForResult<'me> { match self.sync_table.try_claim(zalsa, key_index) { ClaimResult::Running(blocked_on) => WaitForResult::Running(blocked_on), - ClaimResult::Cycle { same_thread } => WaitForResult::Cycle { same_thread }, + ClaimResult::Cycle => WaitForResult::Cycle, ClaimResult::Claimed(_) => WaitForResult::Available, } } diff --git 
a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index ac96edade..c5422ffeb 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -267,7 +267,7 @@ where /// cycle heads have all been finalized. /// * provisional memos that have been created in the same revision and iteration and are part of the same cycle. #[inline] - pub(super) fn validate_may_be_provisional( + fn validate_may_be_provisional( &self, zalsa: &Zalsa, zalsa_local: &ZalsaLocal, @@ -342,7 +342,7 @@ where /// If this is a provisional memo, validate that it was cached in the same iteration of the /// same cycle(s) that we are still executing. If so, it is valid for reuse. This avoids /// runaway re-execution of the same queries within a fixpoint iteration. - pub(super) fn validate_same_iteration( + fn validate_same_iteration( &self, zalsa: &Zalsa, zalsa_local: &ZalsaLocal, @@ -369,15 +369,42 @@ where .find(|query| query.database_key_index == cycle_head.database_key_index) .map(|query| query.iteration_count()) .or_else(|| { - // If this is a cycle head is owned by another thread that is blocked by this ingredient, - // check if it has the same iteration count. + // If the cycle head isn't on our stack because: + // + // * another thread holds the lock on the cycle head (but it waits for the current query to complete) + // * we're in `maybe_changed_after` because `maybe_changed_after` doesn't modify the cycle stack + // + // check if the latest memo has the same iteration count. + + // However, we've to be careful to skip over fixpoint initial values: + // If the head is the memo we're trying to validate, always return `None` + // to force a re-execution of the query. This is necessary because the query + // has obviously not completed its iteration yet. + // + // This should be rare but the `cycle_panic` test fails on some platforms (mainly GitHub actions) + // without this check. 
What happens there is that: + // + // * query a blocks on query b + // * query b tries to claim a, fails to do so and inserts the fixpoint initial value + // * query b completes and has `a` as head. It returns its query result Salsa blocks query b from + // exiting inside `block_on` (or the thread would complete before the cycle iteration is complete) + // * query a resumes but panics because of the fixpoint iteration function + // * query b resumes. It rexecutes its own query which then tries to fetch a (which depends on itself because it's a fixpoint initial value). + // Without this check, `validate_same_iteration` would return `true` because the latest memo for `a` is the fixpoint initial value. + // But it should return `false` so that query b's thread re-executes `a` (which then also causes the panic). + // + // That's why we always return `None` if the cycle head is the same as the current database key index. + if cycle_head.database_key_index == database_key_index { + return None; + } + let ingredient = zalsa.lookup_ingredient( cycle_head.database_key_index.ingredient_index(), ); let wait_result = ingredient .wait_for(zalsa, cycle_head.database_key_index.key_index()); - if !wait_result.is_cycle_with_other_thread() { + if !wait_result.is_cycle() { return None; } diff --git a/src/function/sync.rs b/src/function/sync.rs index 28f088af4..bb514e114 100644 --- a/src/function/sync.rs +++ b/src/function/sync.rs @@ -20,7 +20,7 @@ pub(crate) enum ClaimResult<'a> { /// Can't claim the query because it is running on an other thread. Running(Running<'a>), /// Claiming the query results in a cycle. - Cycle { same_thread: bool }, + Cycle, /// Successfully claimed the query. 
Claimed(ClaimGuard<'a>), } @@ -62,7 +62,7 @@ impl SyncTable { write, ) { BlockResult::Running(blocked_on) => ClaimResult::Running(blocked_on), - BlockResult::Cycle { same_thread } => ClaimResult::Cycle { same_thread }, + BlockResult::Cycle => ClaimResult::Cycle, } } std::collections::hash_map::Entry::Vacant(vacant_entry) => { diff --git a/src/ingredient.rs b/src/ingredient.rs index 2ad7bb8ee..f117cb696 100644 --- a/src/ingredient.rs +++ b/src/ingredient.rs @@ -250,12 +250,11 @@ pub(crate) fn fmt_index(debug_name: &str, id: Id, fmt: &mut fmt::Formatter<'_>) pub enum WaitForResult<'me> { Running(Running<'me>), Available, - Cycle { same_thread: bool }, + Cycle, } impl WaitForResult<'_> { - /// Returns `true` if waiting for this input results in a cycle with another thread. - pub const fn is_cycle_with_other_thread(&self) -> bool { - matches!(self, WaitForResult::Cycle { same_thread: false }) + pub const fn is_cycle(&self) -> bool { + matches!(self, WaitForResult::Cycle) } } diff --git a/src/runtime.rs b/src/runtime.rs index bc2859a7e..6a4d1e8b8 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -51,7 +51,7 @@ pub(crate) enum BlockResult<'me> { /// /// The lock is hold by the current thread or there's another thread that is waiting on the current thread, /// and blocking this thread on the other thread would result in a deadlock/cycle. - Cycle { same_thread: bool }, + Cycle, } pub struct Running<'me>(Box>); @@ -230,14 +230,14 @@ impl Runtime { let thread_id = thread::current().id(); // Cycle in the same thread. 
if thread_id == other_id { - return BlockResult::Cycle { same_thread: true }; + return BlockResult::Cycle; } let dg = self.dependency_graph.lock(); if dg.depends_on(other_id, thread_id) { crate::tracing::debug!("block_on: cycle detected for {database_key:?} in thread {thread_id:?} on {other_id:?}"); - return BlockResult::Cycle { same_thread: false }; + return BlockResult::Cycle; } BlockResult::Running(Running(Box::new(BlockedOnInner { diff --git a/tests/common/mod.rs b/tests/common/mod.rs index f7aa79b31..df3fba477 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -34,6 +34,10 @@ pub trait LogDatabase: HasLogger + Database { self.logger().logs.lock().unwrap().push(string); } + fn clear_logs(&self) { + std::mem::take(&mut *self.logger().logs.lock().unwrap()); + } + /// Asserts what the (formatted) logs should look like, /// clearing the logged events. This takes `&mut self` because /// it is meant to be run from outside any tracked functions. diff --git a/tests/cycle.rs b/tests/cycle.rs index 28266f2c5..3c3687f3d 100644 --- a/tests/cycle.rs +++ b/tests/cycle.rs @@ -1025,3 +1025,42 @@ fn repeat_provisional_query() { "salsa_event(WillExecute { database_key: min_panic(Id(2)) })", ]"#]]); } + +#[test] +fn repeat_provisional_query_incremental() { + let mut db = ExecuteValidateLoggerDatabase::default(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MinIterate(a_in); + let b = Input::MinPanic(b_in); + let c = Input::MinPanic(c_in); + a_in.set_inputs(&mut db).to(vec![value(59), b.clone()]); + b_in.set_inputs(&mut db) + .to(vec![value(60), c.clone(), c.clone(), c]); + c_in.set_inputs(&mut db).to(vec![a.clone()]); + + a.assert_value(&db, 59); + + db.clear_logs(); + + c_in.set_inputs(&mut db).to(vec![a.clone()]); + + a.assert_value(&db, 59); + + // `min_panic(Id(2))` should only run twice: + // * Once before iterating + // * Once as part of iterating + // + // If it runs more than 
once before iterating, then this suggests that + // `validate_same_iteration` incorrectly returns `false`. + db.assert_logs(expect![[r#" + [ + "salsa_event(WillExecute { database_key: min_panic(Id(2)) })", + "salsa_event(WillExecute { database_key: min_panic(Id(1)) })", + "salsa_event(WillExecute { database_key: min_iterate(Id(0)) })", + "salsa_event(WillIterateCycle { database_key: min_iterate(Id(0)), iteration_count: IterationCount(1), fell_back: false })", + "salsa_event(WillExecute { database_key: min_panic(Id(1)) })", + "salsa_event(WillExecute { database_key: min_panic(Id(2)) })", + ]"#]]); +} From 5b2a97b56c0cf4834e9e18d7ed85b0c7bb707f2f Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Wed, 6 Aug 2025 08:50:13 +0200 Subject: [PATCH 19/65] refactor: Extract the cycle branches from `fetch` and `maybe_changed_after` (#955) * Extract the cycle branches from `fetch` and `maybe_changed_after` * Add `inline(never)` --- src/function/fetch.rs | 152 +++++++++++++++------------- src/function/maybe_changed_after.rs | 114 ++++++++++++--------- 2 files changed, 148 insertions(+), 118 deletions(-) diff --git a/src/function/fetch.rs b/src/function/fetch.rs index 88e6f7e8d..b65089b43 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -4,7 +4,7 @@ use crate::function::sync::ClaimResult; use crate::function::{Configuration, IngredientImpl, VerifyResult}; use crate::zalsa::{MemoIngredientIndex, Zalsa}; use crate::zalsa_local::{QueryRevisions, ZalsaLocal}; -use crate::Id; +use crate::{DatabaseKeyIndex, Id}; impl IngredientImpl where @@ -130,6 +130,7 @@ where let database_key_index = self.database_key_index(id); // Try to claim this query: if someone else has claimed it already, go back and start again. let claim_guard = match self.sync_table.try_claim(zalsa, id) { + ClaimResult::Claimed(guard) => guard, ClaimResult::Running(blocked_on) => { blocked_on.block_on(zalsa); @@ -146,75 +147,15 @@ where return None; } ClaimResult::Cycle { .. 
} => { - // check if there's a provisional value for this query - // Note we don't `validate_may_be_provisional` the memo here as we want to reuse an - // existing provisional memo if it exists - let memo_guard = self.get_memo_from_table_for(zalsa, id, memo_ingredient_index); - if let Some(memo) = memo_guard { - if memo.value.is_some() - && memo.revisions.cycle_heads().contains(&database_key_index) - { - let can_shallow_update = - self.shallow_verify_memo(zalsa, database_key_index, memo); - if can_shallow_update.yes() { - self.update_shallow( - zalsa, - database_key_index, - memo, - can_shallow_update, - ); - // SAFETY: memo is present in memo_map. - return unsafe { Some(self.extend_memo_lifetime(memo)) }; - } - } - } - // no provisional value; create/insert/return initial provisional value - return match C::CYCLE_STRATEGY { - // SAFETY: We do not access the query stack reentrantly. - CycleRecoveryStrategy::Panic => unsafe { - zalsa_local.with_query_stack_unchecked(|stack| { - panic!( - "dependency graph cycle when querying {database_key_index:#?}, \ - set cycle_fn/cycle_initial to fixpoint iterate.\n\ - Query stack:\n{stack:#?}", - ); - }) - }, - CycleRecoveryStrategy::Fixpoint => { - crate::tracing::debug!( - "hit cycle at {database_key_index:#?}, \ - inserting and returning fixpoint initial value" - ); - let revisions = QueryRevisions::fixpoint_initial(database_key_index); - let initial_value = C::cycle_initial(db, C::id_to_input(zalsa, id)); - Some(self.insert_memo( - zalsa, - id, - Memo::new(Some(initial_value), zalsa.current_revision(), revisions), - memo_ingredient_index, - )) - } - CycleRecoveryStrategy::FallbackImmediate => { - crate::tracing::debug!( - "hit a `FallbackImmediate` cycle at {database_key_index:#?}" - ); - let active_query = - zalsa_local.push_query(database_key_index, IterationCount::initial()); - let fallback_value = C::cycle_initial(db, C::id_to_input(zalsa, id)); - let mut revisions = active_query.pop(); - 
revisions.set_cycle_heads(CycleHeads::initial(database_key_index)); - // We need this for `cycle_heads()` to work. We will unset this in the outer `execute()`. - *revisions.verified_final.get_mut() = false; - Some(self.insert_memo( - zalsa, - id, - Memo::new(Some(fallback_value), zalsa.current_revision(), revisions), - memo_ingredient_index, - )) - } - }; + return Some(self.fetch_cold_cycle( + zalsa, + zalsa_local, + db, + id, + database_key_index, + memo_ingredient_index, + )); } - ClaimResult::Claimed(guard) => guard, }; // Now that we've claimed the item, check again to see if there's a "hot" value. @@ -272,4 +213,77 @@ where Some(memo) } + + #[cold] + #[inline(never)] + fn fetch_cold_cycle<'db>( + &'db self, + zalsa: &'db Zalsa, + zalsa_local: &'db ZalsaLocal, + db: &'db C::DbView, + id: Id, + database_key_index: DatabaseKeyIndex, + memo_ingredient_index: MemoIngredientIndex, + ) -> &'db Memo<'db, C> { + // check if there's a provisional value for this query + // Note we don't `validate_may_be_provisional` the memo here as we want to reuse an + // existing provisional memo if it exists + let memo_guard = self.get_memo_from_table_for(zalsa, id, memo_ingredient_index); + if let Some(memo) = memo_guard { + if memo.value.is_some() && memo.revisions.cycle_heads().contains(&database_key_index) { + let can_shallow_update = self.shallow_verify_memo(zalsa, database_key_index, memo); + if can_shallow_update.yes() { + self.update_shallow(zalsa, database_key_index, memo, can_shallow_update); + // SAFETY: memo is present in memo_map. + return unsafe { self.extend_memo_lifetime(memo) }; + } + } + } + + // no provisional value; create/insert/return initial provisional value + match C::CYCLE_STRATEGY { + // SAFETY: We do not access the query stack reentrantly. 
+ CycleRecoveryStrategy::Panic => unsafe { + zalsa_local.with_query_stack_unchecked(|stack| { + panic!( + "dependency graph cycle when querying {database_key_index:#?}, \ + set cycle_fn/cycle_initial to fixpoint iterate.\n\ + Query stack:\n{stack:#?}", + ); + }) + }, + CycleRecoveryStrategy::Fixpoint => { + crate::tracing::debug!( + "hit cycle at {database_key_index:#?}, \ + inserting and returning fixpoint initial value" + ); + let revisions = QueryRevisions::fixpoint_initial(database_key_index); + let initial_value = C::cycle_initial(db, C::id_to_input(zalsa, id)); + self.insert_memo( + zalsa, + id, + Memo::new(Some(initial_value), zalsa.current_revision(), revisions), + memo_ingredient_index, + ) + } + CycleRecoveryStrategy::FallbackImmediate => { + crate::tracing::debug!( + "hit a `FallbackImmediate` cycle at {database_key_index:#?}" + ); + let active_query = + zalsa_local.push_query(database_key_index, IterationCount::initial()); + let fallback_value = C::cycle_initial(db, C::id_to_input(zalsa, id)); + let mut revisions = active_query.pop(); + revisions.set_cycle_heads(CycleHeads::initial(database_key_index)); + // We need this for `cycle_heads()` to work. We will unset this in the outer `execute()`. + *revisions.verified_final.get_mut() = false; + self.insert_memo( + zalsa, + id, + Memo::new(Some(fallback_value), zalsa.current_revision(), revisions), + memo_ingredient_index, + ) + } + } + } } diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index c5422ffeb..93c971895 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -113,33 +113,18 @@ where let database_key_index = self.database_key_index(key_index); let _claim_guard = match self.sync_table.try_claim(zalsa, key_index) { + ClaimResult::Claimed(guard) => guard, ClaimResult::Running(blocked_on) => { blocked_on.block_on(zalsa); return None; } - ClaimResult::Cycle { .. 
} => match C::CYCLE_STRATEGY { - // SAFETY: We do not access the query stack reentrantly. - CycleRecoveryStrategy::Panic => unsafe { - db.zalsa_local().with_query_stack_unchecked(|stack| { - panic!( - "dependency graph cycle when validating {database_key_index:#?}, \ - set cycle_fn/cycle_initial to fixpoint iterate.\n\ - Query stack:\n{stack:#?}", - ); - }) - }, - CycleRecoveryStrategy::FallbackImmediate => { - return Some(VerifyResult::unchanged()); - } - CycleRecoveryStrategy::Fixpoint => { - crate::tracing::debug!( - "hit cycle at {database_key_index:?} in `maybe_changed_after`, returning fixpoint initial value", - ); - cycle_heads.insert(database_key_index); - return Some(VerifyResult::unchanged()); - } - }, - ClaimResult::Claimed(guard) => guard, + ClaimResult::Cycle { .. } => { + return Some(self.maybe_changed_after_cold_cycle( + db, + database_key_index, + cycle_heads, + )) + } }; // Load the current memo, if any. let Some(old_memo) = self.get_memo_from_table_for(zalsa, key_index, memo_ingredient_index) @@ -205,6 +190,36 @@ where Some(VerifyResult::Changed) } + #[cold] + #[inline(never)] + fn maybe_changed_after_cold_cycle<'db>( + &'db self, + db: &'db C::DbView, + database_key_index: DatabaseKeyIndex, + cycle_heads: &mut CycleHeadKeys, + ) -> VerifyResult { + match C::CYCLE_STRATEGY { + // SAFETY: We do not access the query stack reentrantly. 
+ CycleRecoveryStrategy::Panic => unsafe { + db.zalsa_local().with_query_stack_unchecked(|stack| { + panic!( + "dependency graph cycle when validating {database_key_index:#?}, \ + set cycle_fn/cycle_initial to fixpoint iterate.\n\ + Query stack:\n{stack:#?}", + ); + }) + }, + CycleRecoveryStrategy::FallbackImmediate => VerifyResult::unchanged(), + CycleRecoveryStrategy::Fixpoint => { + crate::tracing::debug!( + "hit cycle at {database_key_index:?} in `maybe_changed_after`, returning fixpoint initial value", + ); + cycle_heads.insert(database_key_index); + VerifyResult::unchanged() + } + } + } + /// `Some` if the memo's value and `changed_at` time is still valid in this revision. /// Does only a shallow O(1) check, doesn't walk the dependencies. /// @@ -455,32 +470,6 @@ where } match old_memo.revisions.origin.as_ref() { - QueryOriginRef::Assigned(_) => { - // If the value was assigned by another query, - // and that query were up-to-date, - // then we would have updated the `verified_at` field already. - // So the fact that we are here means that it was not specified - // during this revision or is otherwise stale. - // - // Example of how this can happen: - // - // Conditionally specified queries - // where the value is specified - // in rev 1 but not in rev 2. - VerifyResult::Changed - } - // Return `Unchanged` similar to the initial value that we insert - // when we hit the cycle. Any dependencies accessed when creating the fixpoint initial - // are tracked by the outer query. Nothing should have changed assuming that the - // fixpoint initial function is deterministic. - QueryOriginRef::FixpointInitial => { - cycle_heads.insert(database_key_index); - VerifyResult::unchanged() - } - QueryOriginRef::DerivedUntracked(_) => { - // Untracked inputs? Have to assume that it changed. 
- VerifyResult::Changed - } QueryOriginRef::Derived(edges) => { let is_provisional = old_memo.may_be_provisional(); @@ -584,6 +573,33 @@ where accumulated: inputs, } } + + QueryOriginRef::Assigned(_) => { + // If the value was assigned by another query, + // and that query were up-to-date, + // then we would have updated the `verified_at` field already. + // So the fact that we are here means that it was not specified + // during this revision or is otherwise stale. + // + // Example of how this can happen: + // + // Conditionally specified queries + // where the value is specified + // in rev 1 but not in rev 2. + VerifyResult::Changed + } + // Return `Unchanged` similar to the initial value that we insert + // when we hit the cycle. Any dependencies accessed when creating the fixpoint initial + // are tracked by the outer query. Nothing should have changed assuming that the + // fixpoint initial function is deterministic. + QueryOriginRef::FixpointInitial => { + cycle_heads.insert(database_key_index); + VerifyResult::unchanged() + } + QueryOriginRef::DerivedUntracked(_) => { + // Untracked inputs? Have to assume that it changed. + VerifyResult::Changed + } } } } From ea38537827cb9ffaaaf566af3dd2143c301cfdfc Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Wed, 6 Aug 2025 09:01:18 +0200 Subject: [PATCH 20/65] Add heap size support for salsa structs (#943) * Improve unstable size analysis support 1. Include an option `panic_if_missing` that will panic if there is an ingredient with no `heap_size()` defined, to ensure coverage. 2. Add `heap_size()` to tracked structs, interneds an inputs. 
* Make heap size a separate field, remove panic argument * Remove stale comment --------- Co-authored-by: Chayim Refael Friedman --- .../src/setup_input_struct.rs | 9 +++ .../src/setup_interned_struct.rs | 9 +++ .../salsa-macro-rules/src/setup_tracked_fn.rs | 4 +- .../src/setup_tracked_struct.rs | 9 +++ components/salsa-macros/src/input.rs | 4 +- components/salsa-macros/src/interned.rs | 5 +- components/salsa-macros/src/tracked_struct.rs | 7 ++- src/database.rs | 23 +++++++- src/function.rs | 4 +- src/function/memo.rs | 9 ++- src/input.rs | 7 +++ src/interned.rs | 7 +++ src/table/memo.rs | 3 +- src/tracked_struct.rs | 7 +++ .../input_struct_incompatibles.rs | 3 - .../input_struct_incompatibles.stderr | 6 -- .../interned_struct_incompatibles.rs | 5 -- .../interned_struct_incompatibles.stderr | 6 -- .../tracked_struct_incompatibles.rs | 5 -- .../tracked_struct_incompatibles.stderr | 6 -- tests/memory-usage.rs | 58 +++++++++++++------ 21 files changed, 136 insertions(+), 60 deletions(-) diff --git a/components/salsa-macro-rules/src/setup_input_struct.rs b/components/salsa-macro-rules/src/setup_input_struct.rs index d6d131abf..cc3871361 100644 --- a/components/salsa-macro-rules/src/setup_input_struct.rs +++ b/components/salsa-macro-rules/src/setup_input_struct.rs @@ -50,6 +50,9 @@ macro_rules! setup_input_struct { // If true, generate a debug impl. generate_debug_impl: $generate_debug_impl:tt, + // The function used to implement `C::heap_size`. + heap_size_fn: $($heap_size_fn:path)?, + // Annoyingly macro-rules hygiene does not extend to items defined in the macro. // We have the procedural macro generate names for those items that are // not used elsewhere in the user's code. @@ -98,6 +101,12 @@ macro_rules! setup_input_struct { type Revisions = [$zalsa::Revision; $N]; type Durabilities = [$zalsa::Durability; $N]; + + $( + fn heap_size(value: &Self::Fields) -> Option { + Some($heap_size_fn(value)) + } + )? 
} impl $Configuration { diff --git a/components/salsa-macro-rules/src/setup_interned_struct.rs b/components/salsa-macro-rules/src/setup_interned_struct.rs index b637586e5..ebf7f23cd 100644 --- a/components/salsa-macro-rules/src/setup_interned_struct.rs +++ b/components/salsa-macro-rules/src/setup_interned_struct.rs @@ -66,6 +66,9 @@ macro_rules! setup_interned_struct { // If true, generate a debug impl. generate_debug_impl: $generate_debug_impl:tt, + // The function used to implement `C::heap_size`. + heap_size_fn: $($heap_size_fn:path)?, + // Annoyingly macro-rules hygiene does not extend to items defined in the macro. // We have the procedural macro generate names for those items that are // not used elsewhere in the user's code. @@ -146,6 +149,12 @@ macro_rules! setup_interned_struct { )? type Fields<'a> = $StructDataIdent<'a>; type Struct<'db> = $Struct< $($db_lt_arg)? >; + + $( + fn heap_size(value: &Self::Fields<'_>) -> Option { + Some($heap_size_fn(value)) + } + )? } impl $Configuration { diff --git a/components/salsa-macro-rules/src/setup_tracked_fn.rs b/components/salsa-macro-rules/src/setup_tracked_fn.rs index 50ca6a034..de0c11a18 100644 --- a/components/salsa-macro-rules/src/setup_tracked_fn.rs +++ b/components/salsa-macro-rules/src/setup_tracked_fn.rs @@ -240,8 +240,8 @@ macro_rules! setup_tracked_fn { $($values_equal)+ $( - fn heap_size(value: &Self::Output<'_>) -> usize { - $heap_size_fn(value) + fn heap_size(value: &Self::Output<'_>) -> Option { + Some($heap_size_fn(value)) } )? diff --git a/components/salsa-macro-rules/src/setup_tracked_struct.rs b/components/salsa-macro-rules/src/setup_tracked_struct.rs index f92b1ac5f..cb69a08e1 100644 --- a/components/salsa-macro-rules/src/setup_tracked_struct.rs +++ b/components/salsa-macro-rules/src/setup_tracked_struct.rs @@ -88,6 +88,9 @@ macro_rules! setup_tracked_struct { // If true, generate a debug impl. generate_debug_impl: $generate_debug_impl:tt, + // The function used to implement `C::heap_size`. 
+ heap_size_fn: $($heap_size_fn:path)?, + // Annoyingly macro-rules hygiene does not extend to items defined in the macro. // We have the procedural macro generate names for those items that are // not used elsewhere in the user's code. @@ -185,6 +188,12 @@ macro_rules! setup_tracked_struct { )* false } } + + $( + fn heap_size(value: &Self::Fields<'_>) -> Option { + Some($heap_size_fn(value)) + } + )? } impl $Configuration { diff --git a/components/salsa-macros/src/input.rs b/components/salsa-macros/src/input.rs index 71799f2b0..b04176d68 100644 --- a/components/salsa-macros/src/input.rs +++ b/components/salsa-macros/src/input.rs @@ -65,7 +65,7 @@ impl crate::options::AllowedOptions for InputStruct { const REVISIONS: bool = false; - const HEAP_SIZE: bool = false; + const HEAP_SIZE: bool = true; const SELF_TY: bool = false; } @@ -112,6 +112,7 @@ impl Macro { let field_attrs = salsa_struct.field_attrs(); let is_singleton = self.args.singleton.is_some(); let generate_debug_impl = salsa_struct.generate_debug_impl(); + let heap_size_fn = self.args.heap_size_fn.iter(); let zalsa = self.hygiene.ident("zalsa"); let zalsa_struct = self.hygiene.ident("zalsa_struct"); @@ -140,6 +141,7 @@ impl Macro { num_fields: #num_fields, is_singleton: #is_singleton, generate_debug_impl: #generate_debug_impl, + heap_size_fn: #(#heap_size_fn)*, unused_names: [ #zalsa, #zalsa_struct, diff --git a/components/salsa-macros/src/interned.rs b/components/salsa-macros/src/interned.rs index 606701c6f..dd064af12 100644 --- a/components/salsa-macros/src/interned.rs +++ b/components/salsa-macros/src/interned.rs @@ -65,7 +65,7 @@ impl crate::options::AllowedOptions for InternedStruct { const REVISIONS: bool = true; - const HEAP_SIZE: bool = false; + const HEAP_SIZE: bool = true; const SELF_TY: bool = false; } @@ -131,6 +131,8 @@ impl Macro { (None, quote!(#struct_ident), static_lifetime) }; + let heap_size_fn = self.args.heap_size_fn.iter(); + let zalsa = self.hygiene.ident("zalsa"); let zalsa_struct = 
self.hygiene.ident("zalsa_struct"); let Configuration = self.hygiene.ident("Configuration"); @@ -161,6 +163,7 @@ impl Macro { field_attrs: [#([#(#field_unused_attrs),*]),*], num_fields: #num_fields, generate_debug_impl: #generate_debug_impl, + heap_size_fn: #(#heap_size_fn)*, unused_names: [ #zalsa, #zalsa_struct, diff --git a/components/salsa-macros/src/tracked_struct.rs b/components/salsa-macros/src/tracked_struct.rs index fe077da53..5768eb9cd 100644 --- a/components/salsa-macros/src/tracked_struct.rs +++ b/components/salsa-macros/src/tracked_struct.rs @@ -61,7 +61,7 @@ impl crate::options::AllowedOptions for TrackedStruct { const REVISIONS: bool = false; - const HEAP_SIZE: bool = false; + const HEAP_SIZE: bool = true; const SELF_TY: bool = false; } @@ -141,6 +141,8 @@ impl Macro { } }); + let heap_size_fn = self.args.heap_size_fn.iter(); + let num_tracked_fields = salsa_struct.num_tracked_fields(); let generate_debug_impl = salsa_struct.generate_debug_impl(); @@ -188,6 +190,9 @@ impl Macro { num_tracked_fields: #num_tracked_fields, generate_debug_impl: #generate_debug_impl, + + heap_size_fn: #(#heap_size_fn)*, + unused_names: [ #zalsa, #zalsa_struct, diff --git a/src/database.rs b/src/database.rs index a253ac5a0..933e137ac 100644 --- a/src/database.rs +++ b/src/database.rs @@ -172,17 +172,24 @@ mod memory_usage { let mut size_of_fields = 0; let mut size_of_metadata = 0; let mut instances = 0; + let mut heap_size_of_fields = None; for slot in ingredient.memory_usage(self)? 
{ instances += 1; size_of_fields += slot.size_of_fields; size_of_metadata += slot.size_of_metadata; + + if let Some(slot_heap_size) = slot.heap_size_of_fields { + heap_size_of_fields = + Some(heap_size_of_fields.unwrap_or_default() + slot_heap_size); + } } Some(IngredientInfo { count: instances, size_of_fields, size_of_metadata, + heap_size_of_fields, debug_name: ingredient.debug_name(), }) }) @@ -211,6 +218,11 @@ mod memory_usage { info.count += 1; info.size_of_fields += memo.output.size_of_fields; info.size_of_metadata += memo.output.size_of_metadata; + + if let Some(memo_heap_size) = memo.output.heap_size_of_fields { + info.heap_size_of_fields = + Some(info.heap_size_of_fields.unwrap_or_default() + memo_heap_size); + } } } } @@ -226,6 +238,7 @@ mod memory_usage { count: usize, size_of_metadata: usize, size_of_fields: usize, + heap_size_of_fields: Option, } impl IngredientInfo { @@ -234,11 +247,18 @@ mod memory_usage { self.debug_name } - /// Returns the total size of the fields of any instances of this ingredient, in bytes. + /// Returns the total stack size of the fields of any instances of this ingredient, in bytes. pub fn size_of_fields(&self) -> usize { self.size_of_fields } + /// Returns the total heap size of the fields of any instances of this ingredient, in bytes. + /// + /// Returns `None` if the ingredient doesn't specify a `heap_size` function. + pub fn heap_size_of_fields(&self) -> Option { + self.heap_size_of_fields + } + /// Returns the total size of Salsa metadata of any instances of this ingredient, in bytes. 
pub fn size_of_metadata(&self) -> usize { self.size_of_metadata @@ -255,6 +275,7 @@ mod memory_usage { pub(crate) debug_name: &'static str, pub(crate) size_of_metadata: usize, pub(crate) size_of_fields: usize, + pub(crate) heap_size_of_fields: Option, pub(crate) memos: Vec, } diff --git a/src/function.rs b/src/function.rs index e9ab57939..6d4ec5365 100644 --- a/src/function.rs +++ b/src/function.rs @@ -75,8 +75,8 @@ pub trait Configuration: Any { fn id_to_input(zalsa: &Zalsa, key: Id) -> Self::Input<'_>; /// Returns the size of any heap allocations in the output value, in bytes. - fn heap_size(_value: &Self::Output<'_>) -> usize { - 0 + fn heap_size(_value: &Self::Output<'_>) -> Option { + None } /// Invoked when we need to compute the value for the given key, either because we've never diff --git a/src/function/memo.rs b/src/function/memo.rs index 810e5b268..a6107060e 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -321,14 +321,19 @@ where #[cfg(feature = "salsa_unstable")] fn memory_usage(&self) -> crate::database::MemoInfo { let size_of = std::mem::size_of::>() + self.revisions.allocation_size(); - let heap_size = self.value.as_ref().map(C::heap_size).unwrap_or(0); + let heap_size = if let Some(value) = self.value.as_ref() { + C::heap_size(value) + } else { + Some(0) + }; crate::database::MemoInfo { debug_name: C::DEBUG_NAME, output: crate::database::SlotInfo { size_of_metadata: size_of - std::mem::size_of::>(), debug_name: std::any::type_name::>(), - size_of_fields: std::mem::size_of::>() + heap_size, + size_of_fields: std::mem::size_of::>(), + heap_size_of_fields: heap_size, memos: Vec::new(), }, } diff --git a/src/input.rs b/src/input.rs index 58d769a6b..464f79f10 100644 --- a/src/input.rs +++ b/src/input.rs @@ -40,6 +40,11 @@ pub trait Configuration: Any { /// A array of [`Durability`], one per each of the value fields. 
type Durabilities: Send + Sync + fmt::Debug + IndexMut; + + /// Returns the size of any heap allocations in the output value, in bytes. + fn heap_size(_value: &Self::Fields) -> Option { + None + } } pub struct JarImpl { @@ -307,6 +312,7 @@ where /// The `MemoTable` must belong to a `Value` of the correct type. #[cfg(feature = "salsa_unstable")] unsafe fn memory_usage(&self, memo_table_types: &MemoTableTypes) -> crate::database::SlotInfo { + let heap_size = C::heap_size(&self.fields); // SAFETY: The caller guarantees this is the correct types table. let memos = unsafe { memo_table_types.attach_memos(&self.memos) }; @@ -314,6 +320,7 @@ where debug_name: C::DEBUG_NAME, size_of_metadata: std::mem::size_of::() - std::mem::size_of::(), size_of_fields: std::mem::size_of::(), + heap_size_of_fields: heap_size, memos: memos.memory_usage(), } } diff --git a/src/interned.rs b/src/interned.rs index 49663ee65..cdf4d61b8 100644 --- a/src/interned.rs +++ b/src/interned.rs @@ -44,6 +44,11 @@ pub trait Configuration: Sized + 'static { /// The end user struct type Struct<'db>: Copy + FromId + AsId; + + /// Returns the size of any heap allocations in the output value, in bytes. + fn heap_size(_value: &Self::Fields<'_>) -> Option { + None + } } pub trait InternedData: Sized + Eq + Hash + Clone + Sync + Send {} @@ -199,6 +204,7 @@ where /// lock must be held for the shard containing the value. #[cfg(all(not(feature = "shuttle"), feature = "salsa_unstable"))] unsafe fn memory_usage(&self, memo_table_types: &MemoTableTypes) -> crate::database::SlotInfo { + let heap_size = C::heap_size(self.fields()); // SAFETY: The caller guarantees we hold the lock for the shard containing the value, so we // have at-least read-only access to the value's memos. 
let memos = unsafe { &*self.memos.get() }; @@ -209,6 +215,7 @@ where debug_name: C::DEBUG_NAME, size_of_metadata: std::mem::size_of::() - std::mem::size_of::>(), size_of_fields: std::mem::size_of::>(), + heap_size_of_fields: heap_size, memos: memos.memory_usage(), } } diff --git a/src/table/memo.rs b/src/table/memo.rs index b7bc5fb7d..7e4837aa1 100644 --- a/src/table/memo.rs +++ b/src/table/memo.rs @@ -14,7 +14,7 @@ pub struct MemoTable { } impl MemoTable { - /// Create a `MemoTable` with slots for memos from the provided `MemoTableTypes`. + /// Create a `MemoTable` with slots for memos from the provided `MemoTableTypes`. /// /// # Safety /// @@ -127,6 +127,7 @@ impl Memo for DummyMemo { debug_name: "dummy", size_of_metadata: 0, size_of_fields: 0, + heap_size_of_fields: None, memos: Vec::new(), }, } diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index 2ec3e24dc..00986107a 100644 --- a/src/tracked_struct.rs +++ b/src/tracked_struct.rs @@ -90,6 +90,11 @@ pub trait Configuration: Sized + 'static { old_fields: *mut Self::Fields<'db>, new_fields: Self::Fields<'db>, ) -> bool; + + /// Returns the size of any heap allocations in the output value, in bytes. + fn heap_size(_value: &Self::Fields<'_>) -> Option { + None + } } // ANCHOR_END: Configuration @@ -935,6 +940,7 @@ where /// The `MemoTable` must belong to a `Value` of the correct type. #[cfg(feature = "salsa_unstable")] unsafe fn memory_usage(&self, memo_table_types: &MemoTableTypes) -> crate::database::SlotInfo { + let heap_size = C::heap_size(self.fields()); // SAFETY: The caller guarantees this is the correct types table. 
let memos = unsafe { memo_table_types.attach_memos(&self.memos) }; @@ -942,6 +948,7 @@ where debug_name: C::DEBUG_NAME, size_of_metadata: mem::size_of::() - mem::size_of::>(), size_of_fields: mem::size_of::>(), + heap_size_of_fields: heap_size, memos: memos.memory_usage(), } } diff --git a/tests/compile-fail/input_struct_incompatibles.rs b/tests/compile-fail/input_struct_incompatibles.rs index 98cdb916d..31ca9abb8 100644 --- a/tests/compile-fail/input_struct_incompatibles.rs +++ b/tests/compile-fail/input_struct_incompatibles.rs @@ -25,7 +25,4 @@ struct InputWithTrackedField { field: u32, } -#[salsa::input(heap_size = size)] -struct InputWithHeapSize(u32); - fn main() {} diff --git a/tests/compile-fail/input_struct_incompatibles.stderr b/tests/compile-fail/input_struct_incompatibles.stderr index 9fe025275..a1b94e9aa 100644 --- a/tests/compile-fail/input_struct_incompatibles.stderr +++ b/tests/compile-fail/input_struct_incompatibles.stderr @@ -47,12 +47,6 @@ error: `#[tracked]` cannot be used with `#[salsa::input]` 25 | | field: u32, | |______________^ -error: `heap_size` option not allowed here - --> tests/compile-fail/input_struct_incompatibles.rs:28:16 - | -28 | #[salsa::input(heap_size = size)] - | ^^^^^^^^^ - error: cannot find attribute `tracked` in this scope --> tests/compile-fail/input_struct_incompatibles.rs:24:7 | diff --git a/tests/compile-fail/interned_struct_incompatibles.rs b/tests/compile-fail/interned_struct_incompatibles.rs index b8d504282..435335b18 100644 --- a/tests/compile-fail/interned_struct_incompatibles.rs +++ b/tests/compile-fail/interned_struct_incompatibles.rs @@ -39,9 +39,4 @@ struct InternedWithZeroRevisions { field: u32, } -#[salsa::interned(heap_size = size)] -struct AccWithHeapSize { - field: u32, -} - fn main() {} diff --git a/tests/compile-fail/interned_struct_incompatibles.stderr b/tests/compile-fail/interned_struct_incompatibles.stderr index 76ccc7f8b..482e38b46 100644 --- 
a/tests/compile-fail/interned_struct_incompatibles.stderr +++ b/tests/compile-fail/interned_struct_incompatibles.stderr @@ -41,12 +41,6 @@ error: `#[tracked]` cannot be used with `#[salsa::interned]` 34 | | field: u32, | |______________^ -error: `heap_size` option not allowed here - --> tests/compile-fail/interned_struct_incompatibles.rs:42:19 - | -42 | #[salsa::interned(heap_size = size)] - | ^^^^^^^^^ - error: cannot find attribute `tracked` in this scope --> tests/compile-fail/interned_struct_incompatibles.rs:33:7 | diff --git a/tests/compile-fail/tracked_struct_incompatibles.rs b/tests/compile-fail/tracked_struct_incompatibles.rs index eff1eebd1..5abd62dcc 100644 --- a/tests/compile-fail/tracked_struct_incompatibles.rs +++ b/tests/compile-fail/tracked_struct_incompatibles.rs @@ -33,9 +33,4 @@ struct TrackedStructWithRevisions { field: u32, } -#[salsa::tracked(heap_size = size)] -struct TrackedStructWithHeapSize { - field: u32, -} - fn main() {} diff --git a/tests/compile-fail/tracked_struct_incompatibles.stderr b/tests/compile-fail/tracked_struct_incompatibles.stderr index e27777ca0..928bbb126 100644 --- a/tests/compile-fail/tracked_struct_incompatibles.stderr +++ b/tests/compile-fail/tracked_struct_incompatibles.stderr @@ -39,9 +39,3 @@ error: `revisions` option not allowed here | 31 | #[salsa::tracked(revisions = 12)] | ^^^^^^^^^ - -error: `heap_size` option not allowed here - --> tests/compile-fail/tracked_struct_incompatibles.rs:36:18 - | -36 | #[salsa::tracked(heap_size = size)] - | ^^^^^^^^^ diff --git a/tests/memory-usage.rs b/tests/memory-usage.rs index f9fca29ab..30a669e7e 100644 --- a/tests/memory-usage.rs +++ b/tests/memory-usage.rs @@ -2,19 +2,19 @@ use expect_test::expect; -#[salsa::input] +#[salsa::input(heap_size = string_tuple_size_of)] struct MyInput { - field: u32, + field: String, } -#[salsa::tracked] +#[salsa::tracked(heap_size = string_tuple_size_of)] struct MyTracked<'db> { - field: u32, + field: String, } -#[salsa::interned] 
+#[salsa::interned(heap_size = string_tuple_size_of)] struct MyInterned<'db> { - field: u32, + field: String, } #[salsa::tracked] @@ -32,12 +32,16 @@ fn input_to_string<'db>(_db: &'db dyn salsa::Database) -> String { "a".repeat(1000) } -#[salsa::tracked(heap_size = string_heap_size)] +#[salsa::tracked(heap_size = string_size_of)] fn input_to_string_get_size<'db>(_db: &'db dyn salsa::Database) -> String { "a".repeat(1000) } -fn string_heap_size(x: &String) -> usize { +fn string_size_of(x: &String) -> usize { + x.capacity() +} + +fn string_tuple_size_of((x,): &(String,)) -> usize { x.capacity() } @@ -56,9 +60,9 @@ fn input_to_tracked_tuple<'db>( fn test() { let db = salsa::DatabaseImpl::new(); - let input1 = MyInput::new(&db, 1); - let input2 = MyInput::new(&db, 2); - let input3 = MyInput::new(&db, 3); + let input1 = MyInput::new(&db, "a".repeat(50)); + let input2 = MyInput::new(&db, "a".repeat(150)); + let input3 = MyInput::new(&db, "a".repeat(250)); let _tracked1 = input_to_tracked(&db, input1); let _tracked2 = input_to_tracked(&db, input2); @@ -79,32 +83,43 @@ fn test() { IngredientInfo { debug_name: "MyInput", count: 3, - size_of_metadata: 84, - size_of_fields: 12, + size_of_metadata: 96, + size_of_fields: 72, + heap_size_of_fields: Some( + 450, + ), }, IngredientInfo { debug_name: "MyTracked", count: 4, - size_of_metadata: 112, - size_of_fields: 16, + size_of_metadata: 128, + size_of_fields: 96, + heap_size_of_fields: Some( + 300, + ), }, IngredientInfo { debug_name: "MyInterned", count: 3, - size_of_metadata: 156, - size_of_fields: 12, + size_of_metadata: 168, + size_of_fields: 72, + heap_size_of_fields: Some( + 450, + ), }, IngredientInfo { debug_name: "input_to_string::interned_arguments", count: 1, size_of_metadata: 56, size_of_fields: 0, + heap_size_of_fields: None, }, IngredientInfo { debug_name: "input_to_string_get_size::interned_arguments", count: 1, size_of_metadata: 56, size_of_fields: 0, + heap_size_of_fields: None, }, ]"#]]; @@ -124,6 +139,7 @@ fn 
test() { count: 3, size_of_metadata: 192, size_of_fields: 24, + heap_size_of_fields: None, }, ), ( @@ -133,6 +149,7 @@ fn test() { count: 1, size_of_metadata: 40, size_of_fields: 24, + heap_size_of_fields: None, }, ), ( @@ -141,7 +158,10 @@ fn test() { debug_name: "alloc::string::String", count: 1, size_of_metadata: 40, - size_of_fields: 1024, + size_of_fields: 24, + heap_size_of_fields: Some( + 1000, + ), }, ), ( @@ -151,6 +171,7 @@ fn test() { count: 2, size_of_metadata: 192, size_of_fields: 16, + heap_size_of_fields: None, }, ), ( @@ -160,6 +181,7 @@ fn test() { count: 1, size_of_metadata: 132, size_of_fields: 16, + heap_size_of_fields: None, }, ), ]"#]]; From e999ae966357a25cfd50d047a0b96252b3a52826 Mon Sep 17 00:00:00 2001 From: Ibraheem Ahmed Date: Wed, 6 Aug 2025 13:33:44 -0400 Subject: [PATCH 21/65] consolidate memory usage information API (#964) --- src/database.rs | 78 +++++++++++++++++++++---------------------- tests/memory-usage.rs | 8 ++--- 2 files changed, 41 insertions(+), 45 deletions(-) diff --git a/src/database.rs b/src/database.rs index 933e137ac..46120eae4 100644 --- a/src/database.rs +++ b/src/database.rs @@ -164,52 +164,32 @@ mod memory_usage { use hashbrown::HashMap; impl dyn Database { - /// Returns information about any Salsa structs. - pub fn structs_info(&self) -> Vec { - self.zalsa() - .ingredients() - .filter_map(|ingredient| { - let mut size_of_fields = 0; - let mut size_of_metadata = 0; - let mut instances = 0; - let mut heap_size_of_fields = None; - - for slot in ingredient.memory_usage(self)? 
{ - instances += 1; - size_of_fields += slot.size_of_fields; - size_of_metadata += slot.size_of_metadata; - - if let Some(slot_heap_size) = slot.heap_size_of_fields { - heap_size_of_fields = - Some(heap_size_of_fields.unwrap_or_default() + slot_heap_size); - } - } - - Some(IngredientInfo { - count: instances, - size_of_fields, - size_of_metadata, - heap_size_of_fields, - debug_name: ingredient.debug_name(), - }) - }) - .collect() - } - - /// Returns information about any memoized Salsa queries. - /// - /// The returned map holds memory usage information for memoized values of a given query, keyed - /// by the query function name. - pub fn queries_info(&self) -> HashMap<&'static str, IngredientInfo> { + /// Returns memory usage information about ingredients in the database. + pub fn memory_usage(&self) -> DatabaseInfo { let mut queries = HashMap::new(); + let mut structs = Vec::new(); for input_ingredient in self.zalsa().ingredients() { let Some(input_info) = input_ingredient.memory_usage(self) else { continue; }; - for input in input_info { - for memo in input.memos { + let mut size_of_fields = 0; + let mut size_of_metadata = 0; + let mut count = 0; + let mut heap_size_of_fields = None; + + for input_slot in input_info { + count += 1; + size_of_fields += input_slot.size_of_fields; + size_of_metadata += input_slot.size_of_metadata; + + if let Some(slot_heap_size) = input_slot.heap_size_of_fields { + heap_size_of_fields = + Some(heap_size_of_fields.unwrap_or_default() + slot_heap_size); + } + + for memo in input_slot.memos { let info = queries.entry(memo.debug_name).or_insert(IngredientInfo { debug_name: memo.output.debug_name, ..Default::default() @@ -225,12 +205,30 @@ mod memory_usage { } } } + + structs.push(IngredientInfo { + count, + size_of_fields, + size_of_metadata, + heap_size_of_fields, + debug_name: input_ingredient.debug_name(), + }); } - queries + DatabaseInfo { structs, queries } } } + /// Memory usage information about ingredients in the Salsa 
database. + pub struct DatabaseInfo { + /// Information about any Salsa structs. + pub structs: Vec, + + /// Memory usage information for memoized values of a given query, keyed + /// by the query function name. + pub queries: HashMap<&'static str, IngredientInfo>, + } + /// Information about instances of a particular Salsa ingredient. #[derive(Default, Debug, PartialEq, Eq, PartialOrd, Ord)] pub struct IngredientInfo { diff --git a/tests/memory-usage.rs b/tests/memory-usage.rs index 30a669e7e..9a9433d3b 100644 --- a/tests/memory-usage.rs +++ b/tests/memory-usage.rs @@ -76,7 +76,7 @@ fn test() { let _string1 = input_to_string(&db); let _string2 = input_to_string_get_size(&db); - let structs_info = ::structs_info(&db); + let memory_usage = ::memory_usage(&db); let expected = expect![[r#" [ @@ -123,11 +123,9 @@ fn test() { }, ]"#]]; - expected.assert_eq(&format!("{structs_info:#?}")); + expected.assert_eq(&format!("{:#?}", memory_usage.structs)); - let mut queries_info = ::queries_info(&db) - .into_iter() - .collect::>(); + let mut queries_info = memory_usage.queries.into_iter().collect::>(); queries_info.sort(); let expected = expect![[r#" From b121ee46c4483ba74c19e933a3522bd548eb7343 Mon Sep 17 00:00:00 2001 From: Ibraheem Ahmed Date: Wed, 6 Aug 2025 13:55:51 -0400 Subject: [PATCH 22/65] remove allocation lock (#962) --- src/table.rs | 64 +++++++++++++++++++++------------------------- src/zalsa_local.rs | 13 +++++++--- 2 files changed, 39 insertions(+), 38 deletions(-) diff --git a/src/table.rs b/src/table.rs index 62abf2dee..c6d22118b 100644 --- a/src/table.rs +++ b/src/table.rs @@ -60,9 +60,11 @@ struct SlotVTable { /// [`Slot`] methods memos: SlotMemosFnRaw, memos_mut: SlotMemosMutFnRaw, + /// The type name of what is stored as entries in data. + type_name: fn() -> &'static str, /// A drop impl to call when the own page drops - /// SAFETY: The caller is required to supply a correct data pointer to a `Box>` and initialized length, - /// and correct memo types. 
+ /// SAFETY: The caller is required to supply a valid pointer to a `Box>`, and + /// the correct initialized length and memo types. drop_impl: unsafe fn(data: *mut (), initialized: usize, memo_types: &MemoTableTypes), } @@ -70,20 +72,23 @@ impl SlotVTable { const fn of() -> &'static Self { const { &Self { - drop_impl: |data, initialized, memo_types| - // SAFETY: The caller is required to supply a correct data pointer and initialized length - unsafe { - let data = Box::from_raw(data.cast::>()); + drop_impl: |data, initialized, memo_types| { + // SAFETY: The caller is required to provide a valid data pointer. + let data = unsafe { Box::from_raw(data.cast::>()) }; for i in 0..initialized { let item = data[i].get().cast::(); - memo_types.attach_memos_mut((*item).memos_mut()).drop(); - ptr::drop_in_place(item); + // SAFETY: The caller is required to provide a valid initialized length. + unsafe { + memo_types.attach_memos_mut((*item).memos_mut()).drop(); + ptr::drop_in_place(item); + } } }, layout: Layout::new::(), - // SAFETY: The signatures are compatible + type_name: std::any::type_name::, + // SAFETY: The signatures are ABI-compatible. memos: unsafe { mem::transmute::, SlotMemosFnRaw>(T::memos) }, - // SAFETY: The signatures are compatible + // SAFETY: The signatures are ABI-compatible. memos_mut: unsafe { mem::transmute::, SlotMemosMutFnRaw>(T::memos_mut) }, @@ -102,16 +107,6 @@ struct Page { /// Number of elements of `data` that are initialized. allocated: AtomicUsize, - /// The "allocation lock" is held when we allocate a new entry. - /// - /// It ensures that we can load the index, initialize it, and then update the length atomically - /// with respect to other allocations. - /// - /// We could avoid it if we wanted, we'd just have to be a bit fancier in our reasoning - /// (for example, the bounds check in `Page::get` no longer suffices to truly guarantee - /// that the data is initialized). 
- allocation_lock: Mutex<()>, - /// The potentially uninitialized data of this page. As we initialize new entries, we increment `allocated`. /// This is a box allocated `PageData` data: NonNull<()>, @@ -121,9 +116,6 @@ struct Page { /// The type id of what is stored as entries in data. // FIXME: Move this into SlotVTable once const stable slot_type_id: TypeId, - /// The type name of what is stored as entries in data. - // FIXME: Move this into SlotVTable once const stable - slot_type_name: &'static str, memo_types: Arc, } @@ -329,12 +321,17 @@ impl<'db, T: Slot> PageView<'db, T> { unsafe { slice::from_raw_parts(self.0.data.cast::().as_ptr(), len) } } + /// Allocate a value in this page. + /// + /// # Safety + /// + /// The caller must be the unique writer to this page, i.e. `allocate` cannot be called + /// concurrently by multiple threads. Concurrent readers however, are fine. #[inline] - pub(crate) fn allocate(&self, page: PageIndex, value: V) -> Result<(Id, &'db T), V> + pub(crate) unsafe fn allocate(&self, page: PageIndex, value: V) -> Result<(Id, &'db T), V> where V: FnOnce(Id) -> T, { - let _guard = self.0.allocation_lock.lock(); let index = self.0.allocated.load(Ordering::Acquire); if index >= PAGE_LEN { return Err(value); @@ -347,15 +344,14 @@ impl<'db, T: Slot> PageView<'db, T> { // SAFETY: `index` is also guaranteed to be in bounds as per the check above. let entry = unsafe { &*data.as_ptr().add(index) }; - // SAFETY: We acquired the allocation lock, so we have unique access to the UnsafeCell - // interior + // SAFETY: The caller guarantees we are the unique writer, and readers will not attempt to + // access this index until we have updated the length. unsafe { (*entry.get()).write(value(id)) }; // SAFETY: We just initialized the value above. let value = unsafe { (*entry.get()).assume_init_ref() }; - // Update the length (this must be done after initialization as otherwise an uninitialized - // read could occur!) 
+ // Update the length now that we have initialized the value. self.0.allocated.store(index + 1, Ordering::Release); Ok((id, value)) @@ -383,14 +379,12 @@ impl Page { }; Self { + ingredient, + memo_types, slot_vtable: SlotVTable::of::(), slot_type_id: TypeId::of::(), - slot_type_name: std::any::type_name::(), - ingredient, - allocated: Default::default(), - allocation_lock: Default::default(), + allocated: AtomicUsize::new(0), data: NonNull::from(Box::leak(data)).cast::<()>(), - memo_types, } } @@ -439,7 +433,7 @@ impl Page { fn type_assert_failed(page: &Page) -> ! { panic!( "page has slot type `{:?}` but `{:?}` was expected", - page.slot_type_name, + (page.slot_vtable.type_name)(), std::any::type_name::(), ) } diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs index 8d58d7171..59b2165b3 100644 --- a/src/zalsa_local.rs +++ b/src/zalsa_local.rs @@ -70,7 +70,10 @@ impl ZalsaLocal { // Fast-path, we already have an unfilled page available. if let Some(&page) = most_recent_pages.get(&ingredient) { let page_ref = zalsa.table().page::(page); - match page_ref.allocate(page, value) { + + // SAFETY: `ZalsaLocal` is `!Sync`, and we only insert a page into `most_recent_pages` + // if it was allocated by our thread, so we are the unique writer. + match unsafe { page_ref.allocate(page, value) } { Ok((id, value)) => return (id, value), Err(v) => value = v, } @@ -108,11 +111,15 @@ impl ZalsaLocal { loop { // Try to allocate an entry on that page let page_ref = zalsa.table().page::(page); - match page_ref.allocate(page, value) { + + // SAFETY: `ZalsaLocal` is `!Sync`, and we only insert a page into `most_recent_pages` + // if it was allocated by our thread, so we are the unique writer. + match unsafe { page_ref.allocate(page, value) } { // If successful, return Ok((id, value)) => return (id, value), - // Otherwise, create a new page and try again + // Otherwise, create a new page and try again. 
+ // // Note that we could try fetching a page again, but as we just filled one up // it is unlikely that there is a non-full one available. Err(v) => { From 1b8f7c079fb5cb2b8e1dd69a139bb677ff723651 Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Fri, 8 Aug 2025 09:49:55 +0200 Subject: [PATCH 23/65] Update tests for Rust 1.89 (#966) --- .../compile-fail/tracked_fn_return_ref.stderr | 20 ------------------- tests/compile_fail.rs | 2 +- 2 files changed, 1 insertion(+), 21 deletions(-) diff --git a/tests/compile-fail/tracked_fn_return_ref.stderr b/tests/compile-fail/tracked_fn_return_ref.stderr index ee8595edc..7015adb7c 100644 --- a/tests/compile-fail/tracked_fn_return_ref.stderr +++ b/tests/compile-fail/tracked_fn_return_ref.stderr @@ -9,26 +9,6 @@ help: consider using the `'db` lifetime 33 | ) -> ContainsRef<'db> { | +++++ -warning: elided lifetime has a name - --> tests/compile-fail/tracked_fn_return_ref.rs:33:6 - | -30 | fn tracked_fn_return_struct_containing_ref_elided_implicit<'db>( - | --- lifetime `'db` declared here -... -33 | ) -> ContainsRef { - | ^^^^^^^^^^^ this elided lifetime gets resolved as `'db` - | - = note: `#[warn(elided_named_lifetimes)]` on by default - -warning: elided lifetime has a name - --> tests/compile-fail/tracked_fn_return_ref.rs:43:18 - | -40 | fn tracked_fn_return_struct_containing_ref_elided_explicit<'db>( - | --- lifetime `'db` declared here -... 
-43 | ) -> ContainsRef<'_> { - | ^^ this elided lifetime gets resolved as `'db` - error: lifetime may not live long enough --> tests/compile-fail/tracked_fn_return_ref.rs:15:67 | diff --git a/tests/compile_fail.rs b/tests/compile_fail.rs index c43c0b4db..73f87ee52 100644 --- a/tests/compile_fail.rs +++ b/tests/compile_fail.rs @@ -1,6 +1,6 @@ #![cfg(feature = "inventory")] -#[rustversion::all(stable, since(1.84))] +#[rustversion::all(stable, since(1.89))] #[test] fn compile_fail() { let t = trybuild::TestCases::new(); From 22a4d9932bf628aa7bb68fa7a94fc66796b05c47 Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Fri, 8 Aug 2025 10:32:37 +0200 Subject: [PATCH 24/65] test: add parallel maybe changed after test (#963) * Add parallel `maybe_changed_after` shuttle test * Update tests/parallel/cycle_nested_deep_conditional_changed.rs Co-authored-by: Carl Meyer --------- Co-authored-by: Carl Meyer --- .../cycle_nested_deep_conditional_changed.rs | 151 ++++++++++++++++++ .../cycle_nested_three_threads_changed.rs | 123 ++++++++++++++ tests/parallel/main.rs | 2 + 3 files changed, 276 insertions(+) create mode 100644 tests/parallel/cycle_nested_deep_conditional_changed.rs create mode 100644 tests/parallel/cycle_nested_three_threads_changed.rs diff --git a/tests/parallel/cycle_nested_deep_conditional_changed.rs b/tests/parallel/cycle_nested_deep_conditional_changed.rs new file mode 100644 index 000000000..7c96d808d --- /dev/null +++ b/tests/parallel/cycle_nested_deep_conditional_changed.rs @@ -0,0 +1,151 @@ +//! Test a deeply nested-cycle scenario where cycles have changing query dependencies. +//! +//! The trick is that different threads call into the same cycle from different entry queries and +//! the cycle heads change over different iterations +//! +//! * Thread 1: `a` -> `b` -> `c` +//! * Thread 2: `b` +//! * Thread 3: `d` -> `c` +//! * Thread 4: `e` -> `c` +//! +//! `c` calls: +//! * `d` and `a` in the first few iterations +//! 
* `d`, `b` and `e` in the last iterations +//! +//! Specifically, the maybe_changed_after flow. +use crate::sync::thread; + +use salsa::CycleRecoveryAction; + +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, salsa::Update)] +struct CycleValue(u32); + +const MIN: CycleValue = CycleValue(0); +const MAX: CycleValue = CycleValue(3); + +#[salsa::input] +struct Input { + value: u32, +} + +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +fn query_a(db: &dyn salsa::Database, input: Input) -> CycleValue { + query_b(db, input) +} + +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +fn query_b(db: &dyn salsa::Database, input: Input) -> CycleValue { + let c_value = query_c(db, input); + CycleValue(c_value.0 + input.value(db).max(1)).min(MAX) +} + +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +fn query_c(db: &dyn salsa::Database, input: Input) -> CycleValue { + let d_value = query_d(db, input); + + if d_value > CycleValue(0) { + let e_value = query_e(db, input); + let b_value = query_b(db, input); + CycleValue(d_value.0.max(e_value.0).max(b_value.0)) + } else { + let a_value = query_a(db, input); + CycleValue(d_value.0.max(a_value.0)) + } +} + +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +fn query_d(db: &dyn salsa::Database, input: Input) -> CycleValue { + query_c(db, input) +} + +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +fn query_e(db: &dyn salsa::Database, input: Input) -> CycleValue { + query_c(db, input) +} + +fn cycle_fn( + _db: &dyn salsa::Database, + _value: &CycleValue, + _count: u32, + _input: Input, +) -> CycleRecoveryAction { + CycleRecoveryAction::Iterate +} + +fn initial(_db: &dyn salsa::Database, _input: Input) -> CycleValue { + MIN +} + +#[test_log::test] +fn the_test() { + use crate::sync; + use salsa::Setter as _; + sync::check(|| { + tracing::debug!("New run"); + + // This is a bit silly but it works around https://github.com/awslabs/shuttle/issues/192 + static 
INITIALIZE: sync::Mutex> = + sync::Mutex::new(None); + + fn get_db(f: impl FnOnce(&salsa::DatabaseImpl, Input)) -> (salsa::DatabaseImpl, Input) { + let mut shared = INITIALIZE.lock().unwrap(); + + if let Some((db, input)) = shared.as_ref() { + return (db.clone(), *input); + } + + let mut db = salsa::DatabaseImpl::default(); + + let input = Input::new(&db, 0); + + f(&db, input); + + input.set_value(&mut db).to(1); + + *shared = Some((db.clone(), input)); + + (db, input) + } + + let t1 = thread::spawn(move || { + let (db, input) = get_db(|db, input| { + query_a(db, input); + }); + + let _span = tracing::debug_span!("t1", thread_id = ?thread::current().id()).entered(); + + query_a(&db, input) + }); + let t2 = thread::spawn(move || { + let (db, input) = get_db(|db, input| { + query_b(db, input); + }); + + let _span = tracing::debug_span!("t4", thread_id = ?thread::current().id()).entered(); + query_b(&db, input) + }); + let t3 = thread::spawn(move || { + let (db, input) = get_db(|db, input| { + query_d(db, input); + }); + + let _span = tracing::debug_span!("t2", thread_id = ?thread::current().id()).entered(); + query_d(&db, input) + }); + let t4 = thread::spawn(move || { + let (db, input) = get_db(|db, input| { + query_e(db, input); + }); + + let _span = tracing::debug_span!("t3", thread_id = ?thread::current().id()).entered(); + query_e(&db, input) + }); + + let r_t1 = t1.join().unwrap(); + let r_t2 = t2.join().unwrap(); + let r_t3 = t3.join().unwrap(); + let r_t4 = t4.join().unwrap(); + + assert_eq!((r_t1, r_t2, r_t3, r_t4), (MAX, MAX, MAX, MAX)); + }); +} diff --git a/tests/parallel/cycle_nested_three_threads_changed.rs b/tests/parallel/cycle_nested_three_threads_changed.rs new file mode 100644 index 000000000..ccd92a407 --- /dev/null +++ b/tests/parallel/cycle_nested_three_threads_changed.rs @@ -0,0 +1,123 @@ +//! Test a nested-cycle scenario across three threads: +//! +//! ```text +//! Thread T1 Thread T2 Thread T3 +//! --------- --------- --------- +//! 
| | | +//! v | | +//! query_a() | | +//! ^ | v | +//! | +------------> query_b() | +//! | ^ | v +//! | | +------------> query_c() +//! | | | +//! +------------------+--------------------+ +//! ``` +//! +//! Specifically, the maybe_changed_after flow. + +use crate::sync; +use crate::sync::thread; + +use salsa::{CycleRecoveryAction, DatabaseImpl, Setter as _}; + +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, salsa::Update)] +struct CycleValue(u32); + +const MIN: CycleValue = CycleValue(0); +const MAX: CycleValue = CycleValue(3); + +#[salsa::input] +struct Input { + value: u32, +} + +// Signal 1: T1 has entered `query_a` +// Signal 2: T2 has entered `query_b` +// Signal 3: T3 has entered `query_c` + +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +fn query_a(db: &dyn salsa::Database, input: Input) -> CycleValue { + query_b(db, input) +} + +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +fn query_b(db: &dyn salsa::Database, input: Input) -> CycleValue { + let c_value = query_c(db, input); + CycleValue(c_value.0 + input.value(db)).min(MAX) +} + +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +fn query_c(db: &dyn salsa::Database, input: Input) -> CycleValue { + let a_value = query_a(db, input); + let b_value = query_b(db, input); + CycleValue(a_value.0.max(b_value.0)) +} + +fn cycle_fn( + _db: &dyn salsa::Database, + _value: &CycleValue, + _count: u32, + _input: Input, +) -> CycleRecoveryAction { + CycleRecoveryAction::Iterate +} + +fn initial(_db: &dyn salsa::Database, _input: Input) -> CycleValue { + MIN +} + +#[test_log::test] +fn the_test() { + crate::sync::check(move || { + // This is a bit silly but it works around https://github.com/awslabs/shuttle/issues/192 + static INITIALIZE: sync::Mutex> = + sync::Mutex::new(None); + + fn get_db(f: impl FnOnce(&salsa::DatabaseImpl, Input)) -> (salsa::DatabaseImpl, Input) { + let mut shared = INITIALIZE.lock().unwrap(); + + if let Some((db, input)) = shared.as_ref() { 
+ return (db.clone(), *input); + } + + let mut db = DatabaseImpl::default(); + + let input = Input::new(&db, 1); + + f(&db, input); + + input.set_value(&mut db).to(2); + + *shared = Some((db.clone(), input)); + + (db, input) + } + + let t1 = thread::spawn(|| { + let (db, input) = get_db(|db, input| { + query_a(db, input); + }); + + query_a(&db, input) + }); + let t2 = thread::spawn(|| { + let (db, input) = get_db(|db, input| { + query_b(db, input); + }); + query_b(&db, input) + }); + let t3 = thread::spawn(|| { + let (db, input) = get_db(|db, input| { + query_c(db, input); + }); + query_c(&db, input) + }); + + let r_t1 = t1.join().unwrap(); + let r_t2 = t2.join().unwrap(); + let r_t3 = t3.join().unwrap(); + + assert_eq!((r_t1, r_t2, r_t3), (MAX, MAX, MAX)); + }); +} diff --git a/tests/parallel/main.rs b/tests/parallel/main.rs index e14780424..a764a864c 100644 --- a/tests/parallel/main.rs +++ b/tests/parallel/main.rs @@ -8,7 +8,9 @@ mod cycle_a_t1_b_t2_fallback; mod cycle_ab_peeping_c; mod cycle_nested_deep; mod cycle_nested_deep_conditional; +mod cycle_nested_deep_conditional_changed; mod cycle_nested_three_threads; +mod cycle_nested_three_threads_changed; mod cycle_panic; mod cycle_provisional_depending_on_itself; mod parallel_cancellation; From 940f9c09a7d0a0903bc7edc4a3f1977382980f0b Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Fri, 8 Aug 2025 14:22:31 +0200 Subject: [PATCH 25/65] Fix `maybe_changed_after` runnaway for fixpoint queries (#961) * WIP: Fix maybe_changed_after with fixpoint * Remove `validate_same_iteration` from `deep_verify_memo` * Keep `VerifyResult` small * Simplify? 
* Add test for verifying dependencies if the outer query has cycles * Docs * Rename `MaybeChangeAfterCycleHeads` * Remove provisional fallback --- src/accumulator.rs | 5 +- src/cycle.rs | 31 ---- src/function.rs | 7 +- src/function/execute.rs | 2 + src/function/fetch.rs | 44 +++-- src/function/maybe_changed_after.rs | 261 ++++++++++++++++++++-------- src/ingredient.rs | 7 +- src/input.rs | 5 +- src/input/input_field.rs | 5 +- src/interned.rs | 7 +- src/key.rs | 5 +- src/tracked_struct.rs | 5 +- src/tracked_struct/tracked_field.rs | 5 +- tests/common/mod.rs | 2 + tests/cycle.rs | 73 ++++++++ 15 files changed, 323 insertions(+), 141 deletions(-) diff --git a/src/accumulator.rs b/src/accumulator.rs index 0e9feab62..62332b000 100644 --- a/src/accumulator.rs +++ b/src/accumulator.rs @@ -7,8 +7,7 @@ use std::panic::UnwindSafe; use accumulated::{Accumulated, AnyAccumulated}; -use crate::cycle::CycleHeadKeys; -use crate::function::VerifyResult; +use crate::function::{VerifyCycleHeads, VerifyResult}; use crate::ingredient::{Ingredient, Jar}; use crate::plumbing::ZalsaLocal; use crate::sync::Arc; @@ -106,7 +105,7 @@ impl Ingredient for IngredientImpl { _db: crate::database::RawDatabase<'_>, _input: Id, _revision: Revision, - _cycle_heads: &mut CycleHeadKeys, + _cycle_heads: &mut VerifyCycleHeads, ) -> VerifyResult { panic!("nothing should ever depend on an accumulator directly") } diff --git a/src/cycle.rs b/src/cycle.rs index f05983196..e044572fb 100644 --- a/src/cycle.rs +++ b/src/cycle.rs @@ -231,37 +231,6 @@ pub(crate) fn empty_cycle_heads() -> &'static CycleHeads { EMPTY_CYCLE_HEADS.get_or_init(|| CycleHeads(ThinVec::new())) } -/// Set of cycle head database keys. -/// -/// Unlike [`CycleHeads`], this type doesn't track the iteration count -/// of each cycle head. 
-#[derive(Debug, Clone, Eq, PartialEq)] -pub struct CycleHeadKeys(Vec); - -impl CycleHeadKeys { - pub(crate) fn new() -> Self { - Self(Vec::new()) - } - - pub(crate) fn insert(&mut self, database_key_index: DatabaseKeyIndex) { - if !self.0.contains(&database_key_index) { - self.0.push(database_key_index); - } - } - - pub(crate) fn remove(&mut self, value: &DatabaseKeyIndex) -> bool { - let found = self.0.iter().position(|&head| head == *value); - let Some(found) = found else { return false }; - - self.0.swap_remove(found); - true - } - - pub(crate) fn is_empty(&self) -> bool { - self.0.is_empty() - } -} - #[derive(Debug, PartialEq, Eq)] pub enum ProvisionalStatus { Provisional { iteration: IterationCount }, diff --git a/src/function.rs b/src/function.rs index 6d4ec5365..3e8674cf0 100644 --- a/src/function.rs +++ b/src/function.rs @@ -1,4 +1,4 @@ -pub(crate) use maybe_changed_after::VerifyResult; +pub(crate) use maybe_changed_after::{VerifyCycleHeads, VerifyResult}; use std::any::Any; use std::fmt; use std::ptr::NonNull; @@ -7,8 +7,7 @@ use std::sync::OnceLock; pub(crate) use sync::SyncGuard; use crate::cycle::{ - empty_cycle_heads, CycleHeadKeys, CycleHeads, CycleRecoveryAction, CycleRecoveryStrategy, - ProvisionalStatus, + empty_cycle_heads, CycleHeads, CycleRecoveryAction, CycleRecoveryStrategy, ProvisionalStatus, }; use crate::database::RawDatabase; use crate::function::delete::DeletedEntries; @@ -266,7 +265,7 @@ where db: RawDatabase<'_>, input: Id, revision: Revision, - cycle_heads: &mut CycleHeadKeys, + cycle_heads: &mut VerifyCycleHeads, ) -> VerifyResult { // SAFETY: The `db` belongs to the ingredient as per caller invariant let db = unsafe { self.view_caster().downcast_unchecked(db) }; diff --git a/src/function/execute.rs b/src/function/execute.rs index 2690d1a5c..5587b1d94 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -148,12 +148,14 @@ where // initial provisional value from there. 
let memo = self .get_memo_from_table_for(zalsa, id, memo_ingredient_index) + .filter(|memo| memo.verified_at.load() == zalsa.current_revision()) .unwrap_or_else(|| { unreachable!( "{database_key_index:#?} is a cycle head, \ but no provisional memo found" ) }); + debug_assert!(memo.may_be_provisional()); memo.value.as_ref() }; diff --git a/src/function/fetch.rs b/src/function/fetch.rs index b65089b43..32e2eb44a 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -1,7 +1,8 @@ -use crate::cycle::{CycleHeadKeys, CycleHeads, CycleRecoveryStrategy, IterationCount}; +use crate::cycle::{CycleHeads, CycleRecoveryStrategy, IterationCount}; +use crate::function::maybe_changed_after::VerifyCycleHeads; use crate::function::memo::Memo; use crate::function::sync::ClaimResult; -use crate::function::{Configuration, IngredientImpl, VerifyResult}; +use crate::function::{Configuration, IngredientImpl}; use crate::zalsa::{MemoIngredientIndex, Zalsa}; use crate::zalsa_local::{QueryRevisions, ZalsaLocal}; use crate::{DatabaseKeyIndex, Id}; @@ -163,15 +164,38 @@ where if let Some(old_memo) = opt_old_memo { if old_memo.value.is_some() { - let mut cycle_heads = CycleHeadKeys::new(); - if let VerifyResult::Unchanged { .. } = - self.deep_verify_memo(db, zalsa, old_memo, database_key_index, &mut cycle_heads) + let can_shallow_update = + self.shallow_verify_memo(zalsa, database_key_index, old_memo); + if can_shallow_update.yes() + && self.validate_may_be_provisional( + zalsa, + zalsa_local, + database_key_index, + old_memo, + true, + ) { - if cycle_heads.is_empty() { - // SAFETY: memo is present in memo_map and we have verified that it is - // still valid for the current revision. - return unsafe { Some(self.extend_memo_lifetime(old_memo)) }; - } + self.update_shallow(zalsa, database_key_index, old_memo, can_shallow_update); + + // SAFETY: memo is present in memo_map and we have verified that it is + // still valid for the current revision. 
+ return unsafe { Some(self.extend_memo_lifetime(old_memo)) }; + } + + let mut cycle_heads = VerifyCycleHeads::default(); + let verify_result = self.deep_verify_memo( + db, + zalsa, + old_memo, + database_key_index, + &mut cycle_heads, + can_shallow_update, + ); + + if verify_result.is_unchanged() && !cycle_heads.has_any() { + // SAFETY: memo is present in memo_map and we have verified that it is + // still valid for the current revision. + return unsafe { Some(self.extend_memo_lifetime(old_memo)) }; } // If this is a provisional memo from the same revision, await all its cycle heads because diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index 93c971895..54fce885d 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -1,6 +1,6 @@ #[cfg(feature = "accumulator")] use crate::accumulator::accumulated_map::InputAccumulatedValues; -use crate::cycle::{CycleHeadKeys, CycleRecoveryStrategy, IterationCount, ProvisionalStatus}; +use crate::cycle::{CycleRecoveryStrategy, IterationCount, ProvisionalStatus}; use crate::function::memo::Memo; use crate::function::sync::ClaimResult; use crate::function::{Configuration, IngredientImpl}; @@ -11,6 +11,7 @@ use crate::zalsa_local::{QueryEdgeKind, QueryOriginRef, ZalsaLocal}; use crate::{Id, Revision}; /// Result of memo validation. +#[derive(Debug)] pub enum VerifyResult { /// Memo has changed and needs to be recomputed. 
Changed, @@ -26,20 +27,40 @@ pub enum VerifyResult { } impl VerifyResult { - pub(crate) fn changed_if(changed: bool) -> Self { + pub(crate) const fn changed_if(changed: bool) -> Self { if changed { - Self::Changed + Self::changed() } else { Self::unchanged() } } - pub(crate) fn unchanged() -> Self { + pub(crate) const fn changed() -> Self { + Self::Changed + } + + pub(crate) const fn unchanged() -> Self { Self::Unchanged { #[cfg(feature = "accumulator")] accumulated: InputAccumulatedValues::Empty, } } + + #[inline] + #[cfg(feature = "accumulator")] + pub(crate) fn unchanged_with_accumulated(accumulated: InputAccumulatedValues) -> Self { + Self::Unchanged { accumulated } + } + + #[inline] + #[cfg(not(feature = "accumulator"))] + pub(crate) fn unchanged_with_accumulated() -> Self { + Self::unchanged() + } + + pub(crate) const fn is_unchanged(&self) -> bool { + matches!(self, Self::Unchanged { .. }) + } } impl IngredientImpl @@ -51,7 +72,7 @@ where db: &'db C::DbView, id: Id, revision: Revision, - cycle_heads: &mut CycleHeadKeys, + cycle_heads: &mut VerifyCycleHeads, ) -> VerifyResult { let (zalsa, zalsa_local) = db.zalsas(); let memo_ingredient_index = self.memo_ingredient_index(zalsa, id); @@ -68,7 +89,7 @@ where let memo_guard = self.get_memo_from_table_for(zalsa, id, memo_ingredient_index); let Some(memo) = memo_guard else { // No memo? Assume has changed. 
- return VerifyResult::Changed; + return VerifyResult::changed(); }; let can_shallow_update = self.shallow_verify_memo(zalsa, database_key_index, memo); @@ -76,12 +97,14 @@ where self.update_shallow(zalsa, database_key_index, memo, can_shallow_update); return if memo.revisions.changed_at > revision { - VerifyResult::Changed + VerifyResult::changed() } else { - VerifyResult::Unchanged { + VerifyResult::unchanged_with_accumulated( #[cfg(feature = "accumulator")] - accumulated: memo.revisions.accumulated_inputs.load(), - } + { + memo.revisions.accumulated_inputs.load() + }, + ) }; } @@ -108,7 +131,7 @@ where key_index: Id, revision: Revision, memo_ingredient_index: MemoIngredientIndex, - cycle_heads: &mut CycleHeadKeys, + cycle_heads: &mut VerifyCycleHeads, ) -> Option { let database_key_index = self.database_key_index(key_index); @@ -129,7 +152,7 @@ where // Load the current memo, if any. let Some(old_memo) = self.get_memo_from_table_for(zalsa, key_index, memo_ingredient_index) else { - return Some(VerifyResult::Changed); + return Some(VerifyResult::changed()); }; crate::tracing::debug!( @@ -138,21 +161,41 @@ where old_memo = old_memo.tracing_debug() ); - // Check if the inputs are still valid. We can just compare `changed_at`. - let deep_verify = - self.deep_verify_memo(db, zalsa, old_memo, database_key_index, cycle_heads); - if let VerifyResult::Unchanged { - #[cfg(feature = "accumulator")] - accumulated: accumulated_inputs, - } = deep_verify + let can_shallow_update = self.shallow_verify_memo(zalsa, database_key_index, old_memo); + if can_shallow_update.yes() + && self.validate_may_be_provisional( + zalsa, + db.zalsa_local(), + database_key_index, + old_memo, + // Don't conclude that the query is unchanged if the memo itself is still + // provisional (because all its cycle heads have the same iteration count + // as the cycle head memos in the database). 
+ // See https://github.com/salsa-rs/salsa/pull/961 + false, + ) { + self.update_shallow(zalsa, database_key_index, old_memo, can_shallow_update); + + return Some(VerifyResult::unchanged()); + } + + let deep_verify = self.deep_verify_memo( + db, + zalsa, + old_memo, + database_key_index, + cycle_heads, + can_shallow_update, + ); + + if deep_verify.is_unchanged() { + // Check if the inputs are still valid. We can just compare `changed_at`. return Some(if old_memo.revisions.changed_at > revision { - VerifyResult::Changed + VerifyResult::changed() } else { - VerifyResult::Unchanged { - #[cfg(feature = "accumulator")] - accumulated: accumulated_inputs, - } + // Returns unchanged but propagates the accumulated values + deep_verify }); } @@ -166,28 +209,32 @@ where // the cycle head returned *fixpoint initial* without validating its dependencies. // `in_cycle` tracks if the enclosing query is in a cycle. `deep_verify.cycle_heads` tracks // if **this query** encountered a cycle (which means there's some provisional value somewhere floating around). - if old_memo.value.is_some() && cycle_heads.is_empty() { + if old_memo.value.is_some() && !cycle_heads.has_any() { let active_query = db .zalsa_local() .push_query(database_key_index, IterationCount::initial()); let memo = self.execute(db, active_query, Some(old_memo)); let changed_at = memo.revisions.changed_at; - return Some(if changed_at > revision { - VerifyResult::Changed + // Always assume that a provisional value has changed. + // + // We don't know if a provisional value has actually changed. To determine whether a provisional + // value has changed, we need to iterate the outer cycle, which cannot be done here. 
+ return Some(if changed_at > revision || memo.may_be_provisional() { + VerifyResult::changed() } else { - VerifyResult::Unchanged { + VerifyResult::unchanged_with_accumulated( #[cfg(feature = "accumulator")] - accumulated: match memo.revisions.accumulated() { + match memo.revisions.accumulated() { Some(_) => InputAccumulatedValues::Any, None => memo.revisions.accumulated_inputs.load(), }, - } + ) }); } // Otherwise, nothing for it: have to consider the value to have changed. - Some(VerifyResult::Changed) + Some(VerifyResult::changed()) } #[cold] @@ -196,7 +243,7 @@ where &'db self, db: &'db C::DbView, database_key_index: DatabaseKeyIndex, - cycle_heads: &mut CycleHeadKeys, + cycle_heads: &mut VerifyCycleHeads, ) -> VerifyResult { match C::CYCLE_STRATEGY { // SAFETY: We do not access the query stack reentrantly. @@ -281,17 +328,20 @@ where /// * provisional memos that have been successfully marked as verified final, that is, its /// cycle heads have all been finalized. /// * provisional memos that have been created in the same revision and iteration and are part of the same cycle. + /// This check is skipped if `allow_non_finalized` is `false` as the memo itself is still not finalized. It's a provisional value. #[inline] - fn validate_may_be_provisional( + pub(super) fn validate_may_be_provisional( &self, zalsa: &Zalsa, zalsa_local: &ZalsaLocal, database_key_index: DatabaseKeyIndex, memo: &Memo<'_, C>, + allow_non_finalized: bool, ) -> bool { !memo.may_be_provisional() || self.validate_provisional(zalsa, database_key_index, memo) - || self.validate_same_iteration(zalsa, zalsa_local, database_key_index, memo) + || (allow_non_finalized + && self.validate_same_iteration(zalsa, zalsa_local, database_key_index, memo)) } /// Check if this memo's cycle heads have all been finalized. 
If so, mark it verified final and @@ -448,39 +498,28 @@ where zalsa: &Zalsa, old_memo: &Memo<'_, C>, database_key_index: DatabaseKeyIndex, - cycle_heads: &mut CycleHeadKeys, + cycle_heads: &mut VerifyCycleHeads, + can_shallow_update: ShallowUpdate, ) -> VerifyResult { crate::tracing::debug!( "{database_key_index:?}: deep_verify_memo(old_memo = {old_memo:#?})", old_memo = old_memo.tracing_debug() ); - let can_shallow_update = self.shallow_verify_memo(zalsa, database_key_index, old_memo); - if can_shallow_update.yes() - && self.validate_may_be_provisional( - zalsa, - db.zalsa_local(), - database_key_index, - old_memo, - ) - { - self.update_shallow(zalsa, database_key_index, old_memo, can_shallow_update); - - return VerifyResult::unchanged(); - } + debug_assert!(!cycle_heads.contains(database_key_index)); match old_memo.revisions.origin.as_ref() { QueryOriginRef::Derived(edges) => { - let is_provisional = old_memo.may_be_provisional(); - // If the value is from the same revision but is still provisional, consider it changed // because we're now in a new iteration. - if can_shallow_update == ShallowUpdate::Verified && is_provisional { - return VerifyResult::Changed; + if can_shallow_update == ShallowUpdate::Verified && old_memo.may_be_provisional() { + return VerifyResult::changed(); } #[cfg(feature = "accumulator")] let mut inputs = InputAccumulatedValues::Empty; + let mut child_cycle_heads = Vec::new(); + // Fully tracked inputs? Iterate over the inputs and check them, one by one. // // NB: It's important here that we are iterating the inputs in the order that @@ -490,13 +529,29 @@ where for &edge in edges { match edge.kind() { QueryEdgeKind::Input(dependency_index) => { - match dependency_index.maybe_changed_after( + debug_assert!(child_cycle_heads.is_empty()); + + // The `MaybeChangeAfterCycleHeads` is used as an out parameter and it's + // the caller's responsibility to pass an empty `heads`, which is what we do here. 
+ let mut inner_cycle_heads = VerifyCycleHeads { + heads: std::mem::take(&mut child_cycle_heads), + has_outer_cycles: cycle_heads.has_any(), + }; + + let input_result = dependency_index.maybe_changed_after( db.into(), zalsa, old_memo.verified_at.load(), - cycle_heads, - ) { - VerifyResult::Changed => return VerifyResult::Changed, + &mut inner_cycle_heads, + ); + + // Reuse the cycle head allocation. + child_cycle_heads = inner_cycle_heads.heads; + // Aggregate the cycle heads into the parent cycle heads + cycle_heads.append(&mut child_cycle_heads); + + match input_result { + VerifyResult::Changed => return VerifyResult::changed(), #[cfg(feature = "accumulator")] VerifyResult::Unchanged { accumulated } => { inputs |= accumulated; @@ -552,26 +607,23 @@ where // from cycle heads. We will handle our own memo (and the rest of our cycle) on a // future iteration; first the outer cycle head needs to verify itself. - cycle_heads.remove(&database_key_index); + cycle_heads.remove(database_key_index); // 1 and 3 - if cycle_heads.is_empty() { + if !cycle_heads.has_own() { old_memo.mark_as_verified(zalsa, database_key_index); #[cfg(feature = "accumulator")] old_memo.revisions.accumulated_inputs.store(inputs); - - if is_provisional { - old_memo - .revisions - .verified_final - .store(true, Ordering::Relaxed); - } + old_memo + .revisions + .verified_final + .store(true, Ordering::Relaxed); } - VerifyResult::Unchanged { + VerifyResult::unchanged_with_accumulated( #[cfg(feature = "accumulator")] - accumulated: inputs, - } + inputs, + ) } QueryOriginRef::Assigned(_) => { @@ -586,7 +638,7 @@ where // Conditionally specified queries // where the value is specified // in rev 1 but not in rev 2. - VerifyResult::Changed + VerifyResult::changed() } // Return `Unchanged` similar to the initial value that we insert // when we hit the cycle. Any dependencies accessed when creating the fixpoint initial @@ -598,7 +650,7 @@ where } QueryOriginRef::DerivedUntracked(_) => { // Untracked inputs? 
Have to assume that it changed. - VerifyResult::Changed + VerifyResult::changed() } } } @@ -625,3 +677,72 @@ impl ShallowUpdate { ) } } + +/// The cycles encountered while verifying if an ingredient has changed after a given revision. +/// +/// We use this as an out parameter to avoid increasing the size of [`VerifyResult`]. +/// The `heads` of a `MaybeChangeAfterCycleHeads` must be empty when +/// calling [`maybe_changed_after`]. The [`maybe_changed_after`] then collects all cycle heads +/// encountered while verifying this ingredient and its subtree. +/// +/// Note that `heads` only contains the cycle heads up to the point where [`maybe_changed_after`] +/// returned [`VerifyResult::Changed`]. Cycles that only manifest when verifying later dependencies +/// aren't included. +/// +/// [`maybe_changed_after`]: crate::ingredient::Ingredient::maybe_changed_after +#[derive(Debug, Default)] +pub struct VerifyCycleHeads { + heads: Vec, + + /// Whether the outer query (e.g. the parent query running `maybe_changed_after`) has encountered + /// any cycles to this point. + has_outer_cycles: bool, +} + +impl VerifyCycleHeads { + #[inline] + fn contains(&self, key: DatabaseKeyIndex) -> bool { + self.heads.contains(&key) + } + + #[inline] + fn insert(&mut self, key: DatabaseKeyIndex) { + if !self.heads.contains(&key) { + self.heads.push(key); + } + } + + fn remove(&mut self, key: DatabaseKeyIndex) -> bool { + let found = self.heads.iter().position(|&head| head == key); + let Some(found) = found else { return false }; + + self.heads.swap_remove(found); + true + } + + #[inline] + fn append(&mut self, heads: &mut Vec) { + if heads.is_empty() { + return; + } + + self.append_slow(heads); + } + + fn append_slow(&mut self, heads: &mut Vec) { + for key in heads.drain(..) { + self.insert(key); + } + } + + /// Returns `true` if this query or any of its dependencies has encountered a cycle or + /// if the outer query has encountered a cycle. 
+ pub fn has_any(&self) -> bool { + self.has_outer_cycles || !self.heads.is_empty() + } + + /// Returns `true` if this query has encountered a cycle. + fn has_own(&self) -> bool { + !self.heads.is_empty() + } +} diff --git a/src/ingredient.rs b/src/ingredient.rs index f117cb696..3e1e0f2f7 100644 --- a/src/ingredient.rs +++ b/src/ingredient.rs @@ -2,11 +2,10 @@ use std::any::{Any, TypeId}; use std::fmt; use crate::cycle::{ - empty_cycle_heads, CycleHeadKeys, CycleHeads, CycleRecoveryStrategy, IterationCount, - ProvisionalStatus, + empty_cycle_heads, CycleHeads, CycleRecoveryStrategy, IterationCount, ProvisionalStatus, }; use crate::database::RawDatabase; -use crate::function::VerifyResult; +use crate::function::{VerifyCycleHeads, VerifyResult}; use crate::runtime::Running; use crate::sync::Arc; use crate::table::memo::MemoTableTypes; @@ -52,7 +51,7 @@ pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { db: crate::database::RawDatabase<'_>, input: Id, revision: Revision, - cycle_heads: &mut CycleHeadKeys, + cycle_heads: &mut VerifyCycleHeads, ) -> VerifyResult; /// Returns information about the current provisional status of `input`. diff --git a/src/input.rs b/src/input.rs index 464f79f10..b48b369c8 100644 --- a/src/input.rs +++ b/src/input.rs @@ -8,8 +8,7 @@ pub mod singleton; use input_field::FieldIngredientImpl; -use crate::cycle::CycleHeadKeys; -use crate::function::VerifyResult; +use crate::function::{VerifyCycleHeads, VerifyResult}; use crate::id::{AsId, FromId, FromIdWithDb}; use crate::ingredient::Ingredient; use crate::input::singleton::{Singleton, SingletonChoice}; @@ -231,7 +230,7 @@ impl Ingredient for IngredientImpl { _db: crate::database::RawDatabase<'_>, _input: Id, _revision: Revision, - _cycle_heads: &mut CycleHeadKeys, + _cycle_heads: &mut VerifyCycleHeads, ) -> VerifyResult { // Input ingredients are just a counter, they store no data, they are immortal. // Their *fields* are stored in function ingredients elsewhere. 
diff --git a/src/input/input_field.rs b/src/input/input_field.rs index 0d724b0ca..9b352b561 100644 --- a/src/input/input_field.rs +++ b/src/input/input_field.rs @@ -1,8 +1,7 @@ use std::fmt; use std::marker::PhantomData; -use crate::cycle::CycleHeadKeys; -use crate::function::VerifyResult; +use crate::function::{VerifyCycleHeads, VerifyResult}; use crate::ingredient::Ingredient; use crate::input::{Configuration, IngredientImpl, Value}; use crate::sync::Arc; @@ -56,7 +55,7 @@ where _db: crate::database::RawDatabase<'_>, input: Id, revision: Revision, - _cycle_heads: &mut CycleHeadKeys, + _cycle_heads: &mut VerifyCycleHeads, ) -> VerifyResult { let value = >::data(zalsa, input); VerifyResult::changed_if(value.revisions[self.field_index] > revision) diff --git a/src/interned.rs b/src/interned.rs index cdf4d61b8..33723ac2c 100644 --- a/src/interned.rs +++ b/src/interned.rs @@ -10,9 +10,8 @@ use crossbeam_utils::CachePadded; use intrusive_collections::{intrusive_adapter, LinkedList, LinkedListLink, UnsafeRef}; use rustc_hash::FxBuildHasher; -use crate::cycle::CycleHeadKeys; use crate::durability::Durability; -use crate::function::VerifyResult; +use crate::function::{VerifyCycleHeads, VerifyResult}; use crate::id::{AsId, FromId}; use crate::ingredient::Ingredient; use crate::plumbing::{Jar, ZalsaLocal}; @@ -804,7 +803,7 @@ where _db: crate::database::RawDatabase<'_>, input: Id, _revision: Revision, - _cycle_heads: &mut CycleHeadKeys, + _cycle_heads: &mut VerifyCycleHeads, ) -> VerifyResult { // Record the current revision as active. let current_revision = zalsa.current_revision(); @@ -820,7 +819,7 @@ where // The slot was reused. if value_shared.id.generation() > input.generation() { - return VerifyResult::Changed; + return VerifyResult::changed(); } // Validate the value for the current revision to avoid reuse. 
diff --git a/src/key.rs b/src/key.rs index 9045e8337..47f750e7f 100644 --- a/src/key.rs +++ b/src/key.rs @@ -1,7 +1,6 @@ use core::fmt; -use crate::cycle::CycleHeadKeys; -use crate::function::VerifyResult; +use crate::function::{VerifyCycleHeads, VerifyResult}; use crate::zalsa::{IngredientIndex, Zalsa}; use crate::Id; @@ -39,7 +38,7 @@ impl DatabaseKeyIndex { db: crate::database::RawDatabase<'_>, zalsa: &Zalsa, last_verified_at: crate::Revision, - cycle_heads: &mut CycleHeadKeys, + cycle_heads: &mut VerifyCycleHeads, ) -> VerifyResult { // SAFETY: The `db` belongs to the ingredient unsafe { diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index 00986107a..116e96c00 100644 --- a/src/tracked_struct.rs +++ b/src/tracked_struct.rs @@ -10,8 +10,7 @@ use crossbeam_queue::SegQueue; use thin_vec::ThinVec; use tracked_field::FieldIngredientImpl; -use crate::cycle::CycleHeadKeys; -use crate::function::VerifyResult; +use crate::function::{VerifyCycleHeads, VerifyResult}; use crate::id::{AsId, FromId}; use crate::ingredient::{Ingredient, Jar}; use crate::key::DatabaseKeyIndex; @@ -822,7 +821,7 @@ where _db: crate::database::RawDatabase<'_>, _input: Id, _revision: Revision, - _cycle_heads: &mut CycleHeadKeys, + _cycle_heads: &mut VerifyCycleHeads, ) -> VerifyResult { // Any change to a tracked struct results in a new ID generation. 
VerifyResult::unchanged() diff --git a/src/tracked_struct/tracked_field.rs b/src/tracked_struct/tracked_field.rs index 95ec32fa6..0d565bcfd 100644 --- a/src/tracked_struct/tracked_field.rs +++ b/src/tracked_struct/tracked_field.rs @@ -1,7 +1,6 @@ use std::marker::PhantomData; -use crate::cycle::CycleHeadKeys; -use crate::function::VerifyResult; +use crate::function::{VerifyCycleHeads, VerifyResult}; use crate::ingredient::Ingredient; use crate::sync::Arc; use crate::table::memo::MemoTableTypes; @@ -61,7 +60,7 @@ where _db: crate::database::RawDatabase<'_>, input: Id, revision: crate::Revision, - _cycle_heads: &mut CycleHeadKeys, + _cycle_heads: &mut VerifyCycleHeads, ) -> VerifyResult { let data = >::data(zalsa.table(), input); let field_changed_at = data.revisions[self.field_index]; diff --git a/tests/common/mod.rs b/tests/common/mod.rs index df3fba477..3b9e7434f 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -41,6 +41,7 @@ pub trait LogDatabase: HasLogger + Database { /// Asserts what the (formatted) logs should look like, /// clearing the logged events. This takes `&mut self` because /// it is meant to be run from outside any tracked functions. + #[track_caller] fn assert_logs(&self, expected: expect_test::Expect) { let logs = std::mem::take(&mut *self.logger().logs.lock().unwrap()); expected.assert_eq(&format!("{logs:#?}")); @@ -49,6 +50,7 @@ pub trait LogDatabase: HasLogger + Database { /// Asserts the length of the logs, /// clearing the logged events. This takes `&mut self` because /// it is meant to be run from outside any tracked functions. 
+ #[track_caller] fn assert_logs_len(&self, expected: usize) { let logs = std::mem::take(&mut *self.logger().logs.lock().unwrap()); assert_eq!(logs.len(), expected, "Actual logs: {logs:#?}"); diff --git a/tests/cycle.rs b/tests/cycle.rs index 3c3687f3d..2eb9bac23 100644 --- a/tests/cycle.rs +++ b/tests/cycle.rs @@ -994,6 +994,79 @@ fn cycle_unchanged_nested_intertwined() { } } +/// Test that cycle heads from one dependency don't interfere with sibling verification. +/// +/// a:Ni(b, c, d) -> b:Ni(a) [cycle with a, unchanged] +/// \-> c:Np(v100) [no cycle, unchanged] +/// \-> d:Np(v200->v201) [no cycle, changes] +/// +/// When verifying a in a new revision: +/// 1. b goes through deep verification (detects b->a cycle, adds cycle heads, returns unchanged) +/// 2. c gets verified (should not be affected by b's cycle heads with the fix) +/// 3. d returns changed, causing a to re-execute +/// +/// Without the fix: cycle heads from b's verification remain in shared context and interfere with c +/// With the fix: c gets fresh cycle head context and verifies cleanly +#[test] +fn cycle_sibling_interference() { + let mut db = ExecuteValidateLoggerDatabase::default(); + let a_in = Inputs::new(&db, vec![]); // a = min_iterate(Id(0)) + let b_in = Inputs::new(&db, vec![]); // b = min_iterate(Id(1)) + let c_in = Inputs::new(&db, vec![]); // c = min_panic(Id(2)) + let d_in = Inputs::new(&db, vec![]); // d = min_panic(Id(3)) + let a = Input::MinIterate(a_in); + let b = Input::MinIterate(b_in); + let c = Input::MinPanic(c_in); + let d = Input::MinPanic(d_in); + + a_in.set_inputs(&mut db) + .to(vec![b.clone(), c.clone(), d.clone()]); // a depends on b, c, d (in that order) + b_in.set_inputs(&mut db).to(vec![a.clone()]); // b depends on a (forming a->b->a cycle) + c_in.set_inputs(&mut db).to(vec![value(100)]); // c is independent, no cycles + d_in.set_inputs(&mut db).to(vec![value(200)]); // d is independent, no cycles + + // First execution - this will establish the cycle and memos + 
// The cycle: a depends on b, b depends on a + // During fixpoint iteration, initial values are 255 + // a computes min(255, 100, 200) = 100 + // b computes min(100) = 100 + // Next iteration: a computes min(100, 100, 200) = 100 (converged) + a.assert_value(&db, 100); + b.assert_value(&db, 100); + c.assert_value(&db, 100); + d.assert_value(&db, 200); + + // Clear logs to prepare for the next revision + db.clear_logs(); + + // Change d's input to trigger a new revision + // This forces verification of all dependencies in the new revision + d_in.set_inputs(&mut db).to(vec![value(201)]); + + // Verify a - this should trigger: + // 1. b: deep verification (cycle detected, cycle heads added to context, but b unchanged) + // 2. c: verification (should be clean without cycle head interference) + // 3. d: changed, causing a to re-execute + a.assert_value(&db, 100); // min(255, 100, 201) = 100 + + // Query mapping: a=min_iterate(Id(0)), b=min_iterate(Id(1)), c=min_panic(Id(2)), d=min_panic(Id(3)) + // - c gets validated cleanly during verification of `a`. The fact that `a` and `b` form a cycle shouldn't prevent that + // - a re-executes (due to d changing) + // - b re-executes (as part of a-b cycle) + // - d re-executes (input changed) + // - cycle iteration continues + // - b re-executes again during cycle iteration + db.assert_logs(expect![[r#" + [ + "salsa_event(DidValidateMemoizedValue { database_key: min_panic(Id(2)) })", + "salsa_event(WillExecute { database_key: min_iterate(Id(0)) })", + "salsa_event(WillExecute { database_key: min_iterate(Id(1)) })", + "salsa_event(WillExecute { database_key: min_panic(Id(3)) })", + "salsa_event(WillIterateCycle { database_key: min_iterate(Id(0)), iteration_count: IterationCount(1), fell_back: false })", + "salsa_event(WillExecute { database_key: min_iterate(Id(1)) })", + ]"#]]); +} + /// Provisional query results in a cycle should still be cached within a single iteration. 
/// /// a:Ni(v59, b) -> b:Np(v60, c, c, c) -> c:Np(a) From 1ffb32f54cba6f3f3c6d99a2ebef2a47fc6f1b14 Mon Sep 17 00:00:00 2001 From: Ibraheem Ahmed Date: Mon, 11 Aug 2025 15:02:49 -0400 Subject: [PATCH 26/65] Initial persistent caching prototype (#967) * persistent caching prototype * move serialization arguments under `persist` attribute * move `serde` dependency under `persistence` feature * avoid serializing provisional memos * use exhaustive checking for manual `Serialize` implementations * update tests * remove distinction between ingredient `entries` and `instances` * avoid enabling `shuttle` feature in CI * serialize ingredients by index --- .github/workflows/test.yml | 5 +- Cargo.toml | 11 +- .../src/setup_input_struct.rs | 65 +++ .../src/setup_interned_struct.rs | 69 +++- .../salsa-macro-rules/src/setup_tracked_fn.rs | 80 +++- .../src/setup_tracked_struct.rs | 65 +++ components/salsa-macros/Cargo.toml | 4 + components/salsa-macros/src/accumulator.rs | 4 +- components/salsa-macros/src/input.rs | 13 +- components/salsa-macros/src/interned.rs | 13 +- components/salsa-macros/src/options.rs | 118 +++++- components/salsa-macros/src/salsa_struct.rs | 30 ++ components/salsa-macros/src/supertype.rs | 7 + components/salsa-macros/src/tracked_fn.rs | 9 +- components/salsa-macros/src/tracked_struct.rs | 14 +- src/accumulator.rs | 6 +- src/cycle.rs | 3 + src/database.rs | 222 +++++++++- src/durability.rs | 2 + src/function.rs | 225 ++++++++++- src/function/memo.rs | 122 +++++- src/id.rs | 16 +- src/ingredient.rs | 47 ++- src/input.rs | 285 ++++++++++++- src/input/input_field.rs | 16 +- src/input/singleton.rs | 2 +- src/interned.rs | 381 ++++++++++++++++-- src/key.rs | 1 + src/lib.rs | 19 + src/memo_ingredient_indices.rs | 4 + src/revision.rs | 63 +++ src/runtime.rs | 10 + src/salsa_struct.rs | 5 +- src/sync.rs | 4 - src/table.rs | 104 ++++- src/table/memo.rs | 2 +- src/tracked_struct.rs | 272 ++++++++++++- src/tracked_struct/tracked_field.rs | 6 +- src/zalsa.rs | 34 +- 
src/zalsa_local.rs | 86 +++- tests/compile-fail/incomplete_persistence.rs | 14 + .../incomplete_persistence.stderr | 100 +++++ tests/compile-fail/invalid_persist_options.rs | 56 +++ .../invalid_persist_options.stderr | 23 ++ tests/compile_fail.rs | 2 +- tests/debug_db_contents.rs | 3 + tests/interned-structs_self_ref.rs | 39 +- tests/memory-usage.rs | 18 +- tests/persistence.rs | 378 +++++++++++++++++ 49 files changed, 2955 insertions(+), 122 deletions(-) create mode 100644 tests/compile-fail/incomplete_persistence.rs create mode 100644 tests/compile-fail/incomplete_persistence.stderr create mode 100644 tests/compile-fail/invalid_persist_options.rs create mode 100644 tests/compile-fail/invalid_persist_options.stderr create mode 100644 tests/persistence.rs diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 57b0b1d91..6a43ba722 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -53,8 +53,11 @@ jobs: run: cargo fmt -- --check - name: Clippy run: cargo clippy --workspace --all-targets -- -D warnings + # TODO: Use something like cargo-hack for more robust feature configuration testing. + - name: Clippy / all-features + run: cargo clippy --workspace --all-targets --features persistence -- -D warnings - name: Test - run: cargo nextest run --workspace --all-targets --no-fail-fast + run: cargo nextest run --workspace --all-targets --features persistence --no-fail-fast - name: Test Manual Registration / no-default-features run: cargo nextest run --workspace --tests --no-fail-fast --no-default-features --features macros - name: Test docs diff --git a/Cargo.toml b/Cargo.toml index 3b4eb3455..847bca2b4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,7 @@ parking_lot = "0.12" portable-atomic = "1" rustc-hash = "2" smallvec = "1" +thin-vec = { version = "0.2.14", features = ["serde"] } tracing = { version = "0.1", default-features = false, features = ["std"] } # Automatic ingredient registration. 
@@ -33,18 +34,23 @@ rayon = { version = "1.10.0", optional = true } # Stuff we want Update impls for by default compact_str = { version = "0.9", optional = true } -thin-vec = "0.2.14" shuttle = { version = "0.8.1", optional = true } +# Persistent caching +erased-serde = { version = "0.4.6", optional = true } +serde = { version = "1.0.219", features = ["derive"], optional = true } + [features] default = ["salsa_unstable", "rayon", "macros", "inventory", "accumulator"] inventory = ["dep:inventory"] +persistence = ["dep:serde", "dep:erased-serde", "salsa-macros/persistence"] shuttle = ["dep:shuttle"] accumulator = ["salsa-macro-rules/accumulator"] +macros = ["dep:salsa-macros"] + # FIXME: remove `salsa_unstable` before 1.0. salsa_unstable = [] -macros = ["dep:salsa-macros"] # This interlocks the `salsa-macros` and `salsa` versions together # preventing scenarios where they could diverge in a given project @@ -68,6 +74,7 @@ expect-test = "1.5.1" rustversion = "1.0" test-log = { version = "0.2.18", features = ["trace"] } trybuild = "1.0" +serde_json = "1.0.140" [target.'cfg(all(not(target_os = "windows"), not(target_os = "openbsd"), any(target_arch = "x86_64", target_arch = "aarch64", target_arch = "powerpc64")))'.dev-dependencies] tikv-jemallocator = "0.6.0" diff --git a/components/salsa-macro-rules/src/setup_input_struct.rs b/components/salsa-macro-rules/src/setup_input_struct.rs index cc3871361..ce5318208 100644 --- a/components/salsa-macro-rules/src/setup_input_struct.rs +++ b/components/salsa-macro-rules/src/setup_input_struct.rs @@ -53,6 +53,15 @@ macro_rules! setup_input_struct { // The function used to implement `C::heap_size`. heap_size_fn: $($heap_size_fn:path)?, + // If `true`, `serialize_fn` and `deserialize_fn` have been provided. + persist: $persist:tt, + + // The path to the `serialize` function for the value's fields. + serialize_fn: $($serialize_fn:path)?, + + // The path to the `serialize` function for the value's fields. 
+ deserialize_fn: $($deserialize_fn:path)?, + // Annoyingly macro-rules hygiene does not extend to items defined in the macro. // We have the procedural macro generate names for those items that are // not used elsewhere in the user's code. @@ -93,6 +102,9 @@ macro_rules! setup_input_struct { }; const DEBUG_NAME: &'static str = stringify!($Struct); const FIELD_DEBUG_NAMES: &'static [&'static str] = &[$(stringify!($field_id)),*]; + + const PERSIST: bool = $persist; + type Singleton = $zalsa::macro_if! {if $is_singleton {$zalsa::input::Singleton} else {$zalsa::input::NotSingleton}}; type Struct = $Struct; @@ -107,6 +119,32 @@ macro_rules! setup_input_struct { Some($heap_size_fn(value)) } )? + + fn serialize( + fields: &Self::Fields, + serializer: S, + ) -> Result { + $zalsa::macro_if! { + if $persist { + $($serialize_fn(fields, serializer))? + } else { + panic!("attempted to serialize value not marked with `persist` attribute") + } + } + } + + fn deserialize<'de, D: $zalsa::serde::Deserializer<'de>>( + deserializer: D, + ) -> Result { + $zalsa::macro_if! { + if $persist { + $($deserialize_fn(deserializer))? + } else { + panic!("attempted to deserialize value not marked with `persist` attribute") + } + } + } + } impl $Configuration { @@ -174,6 +212,13 @@ macro_rules! setup_input_struct { aux.lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>().into() } + fn entries( + zalsa: &$zalsa::Zalsa + ) -> impl Iterator + '_ { + let ingredient_index = zalsa.lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>(); + <$Configuration>::ingredient_(zalsa).entries(zalsa).map(|(key, _)| key) + } + #[inline] fn cast(id: $zalsa::Id, type_id: $zalsa::TypeId) -> $zalsa::Option { if type_id == $zalsa::TypeId::of::<$Struct>() { @@ -194,6 +239,26 @@ macro_rules! setup_input_struct { } } + $zalsa::macro_if! 
{ $persist => + impl $zalsa::serde::Serialize for $Struct { + fn serialize(&self, serializer: S) -> Result + where + S: $zalsa::serde::Serializer, + { + $zalsa::serde::Serialize::serialize(&$zalsa::AsId::as_id(self), serializer) + } + } + + impl<'de> $zalsa::serde::Deserialize<'de> for $Struct { + fn deserialize(deserializer: D) -> Result + where + D: $zalsa::serde::Deserializer<'de>, + { + let id = $zalsa::Id::deserialize(deserializer)?; + Ok($zalsa::FromId::from_id(id)) + } + } + } impl $Struct { #[inline] pub fn $new_fn<$Db>(db: &$Db, $($required_field_id: $required_field_ty),*) -> Self diff --git a/components/salsa-macro-rules/src/setup_interned_struct.rs b/components/salsa-macro-rules/src/setup_interned_struct.rs index ebf7f23cd..e473e58b9 100644 --- a/components/salsa-macro-rules/src/setup_interned_struct.rs +++ b/components/salsa-macro-rules/src/setup_interned_struct.rs @@ -69,6 +69,15 @@ macro_rules! setup_interned_struct { // The function used to implement `C::heap_size`. heap_size_fn: $($heap_size_fn:path)?, + // If `true`, `serialize_fn` and `deserialize_fn` have been provided. + persist: $persist:tt, + + // The path to the `serialize` function for the value's fields. + serialize_fn: $($serialize_fn:path)?, + + // The path to the `serialize` function for the value's fields. + deserialize_fn: $($deserialize_fn:path)?, + // Annoyingly macro-rules hygiene does not extend to items defined in the macro. // We have the procedural macro generate names for those items that are // not used elsewhere in the user's code. @@ -144,9 +153,12 @@ macro_rules! setup_interned_struct { line: line!(), }; const DEBUG_NAME: &'static str = stringify!($Struct); + const PERSIST: bool = $persist; + $( const REVISIONS: ::core::num::NonZeroUsize = ::core::num::NonZeroUsize::new($revisions).unwrap(); )? + type Fields<'a> = $StructDataIdent<'a>; type Struct<'db> = $Struct< $($db_lt_arg)? >; @@ -155,11 +167,35 @@ macro_rules! setup_interned_struct { Some($heap_size_fn(value)) } )? 
+ + fn serialize( + fields: &Self::Fields<'_>, + serializer: S, + ) -> Result { + $zalsa::macro_if! { + if $persist { + $($serialize_fn(fields, serializer))? + } else { + panic!("attempted to serialize value not marked with `persist` attribute") + } + } + } + + fn deserialize<'de, D: $zalsa::serde::Deserializer<'de>>( + deserializer: D, + ) -> Result, D::Error> { + $zalsa::macro_if! { + if $persist { + $($deserialize_fn(deserializer))? + } else { + panic!("attempted to deserialize value not marked with `persist` attribute") + } + } + } } impl $Configuration { - pub fn ingredient(zalsa: &$zalsa::Zalsa) -> &$zalsa_struct::IngredientImpl - { + pub fn ingredient(zalsa: &$zalsa::Zalsa) -> &$zalsa_struct::IngredientImpl { static CACHE: $zalsa::IngredientCache<$zalsa_struct::IngredientImpl<$Configuration>> = $zalsa::IngredientCache::new(); @@ -204,6 +240,13 @@ macro_rules! setup_interned_struct { aux.lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>().into() } + fn entries( + zalsa: &$zalsa::Zalsa + ) -> impl Iterator + '_ { + let ingredient_index = zalsa.lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>(); + <$Configuration>::ingredient(zalsa).entries(zalsa).map(|(key, _)| key) + } + #[inline] fn cast(id: $zalsa::Id, type_id: $zalsa::TypeId) -> $zalsa::Option { if type_id == $zalsa::TypeId::of::<$Struct>() { @@ -224,6 +267,28 @@ macro_rules! setup_interned_struct { } } + $zalsa::macro_if! 
{ $persist => + impl<$($db_lt_arg)?> $zalsa::serde::Serialize for $Struct<$($db_lt_arg)?> { + fn serialize(&self, serializer: S) -> Result + where + S: $zalsa::serde::Serializer, + { + $zalsa::serde::Serialize::serialize(&$zalsa::AsId::as_id(self), serializer) + } + } + + impl<'de, $($db_lt_arg)?> $zalsa::serde::Deserialize<'de> for $Struct<$($db_lt_arg)?> { + fn deserialize(deserializer: D) -> Result + where + D: $zalsa::serde::Deserializer<'de>, + { + let id = $zalsa::Id::deserialize(deserializer)?; + Ok($zalsa::FromId::from_id(id)) + } + } + } + + unsafe impl< $($db_lt_arg)? > $zalsa::Update for $Struct< $($db_lt_arg)? > { unsafe fn maybe_update(old_pointer: *mut Self, new_value: Self) -> bool { if unsafe { *old_pointer } != new_value { diff --git a/components/salsa-macro-rules/src/setup_tracked_fn.rs b/components/salsa-macro-rules/src/setup_tracked_fn.rs index de0c11a18..51239802c 100644 --- a/components/salsa-macro-rules/src/setup_tracked_fn.rs +++ b/components/salsa-macro-rules/src/setup_tracked_fn.rs @@ -64,6 +64,9 @@ macro_rules! setup_tracked_fn { // The return mode for the function, see `salsa_macros::options::Option::returns` return_mode: $return_mode:tt, + // If true, the input and output values implement `serde::{Serialize, Deserialize}`. + persist: $persist:tt, + assert_return_type_is_update: {$($assert_return_type_is_update:tt)*}, $(self_ty: $self_ty:ty,)? @@ -122,6 +125,13 @@ macro_rules! setup_tracked_fn { $zalsa::IngredientIndices::empty() } + fn entries( + zalsa: &$zalsa::Zalsa + ) -> impl Iterator + '_ { + let ingredient_index = zalsa.lookup_jar_by_type::<$fn_name>().successor(0); + <$Configuration>::intern_ingredient(zalsa).entries(zalsa).map(|(key, _)| key) + } + #[inline] fn cast(id: $zalsa::Id, type_id: ::core::any::TypeId) -> Option { if type_id == ::core::any::TypeId::of::<$InternedData>() { @@ -162,13 +172,51 @@ macro_rules! setup_tracked_fn { line: line!(), }; const DEBUG_NAME: &'static str = concat!($(stringify!($self_ty), "::",)? 
stringify!($fn_name), "::interned_arguments"); + const PERSIST: bool = true; type Fields<$db_lt> = ($($interned_input_ty),*); type Struct<$db_lt> = $InternedData<$db_lt>; + + fn serialize( + fields: &Self::Fields<'_>, + serializer: S, + ) -> Result { + $zalsa::macro_if! { + if $persist { + $zalsa::serde::Serialize::serialize(fields, serializer) + } else { + panic!("attempted to serialize value not marked with `persist` attribute") + } + } + } + + fn deserialize<'de, D: $zalsa::serde::Deserializer<'de>>( + deserializer: D, + ) -> Result, D::Error> { + $zalsa::macro_if! { + if $persist { + $zalsa::serde::Deserialize::deserialize(deserializer) + } else { + panic!("attempted to deserialize value not marked with `persist` attribute") + } + } + } } } else { type $InternedData<$db_lt> = ($($interned_input_ty),*); + + $zalsa::macro_if! { $persist => + const fn query_input_is_persistable() + where + T: $zalsa::serde::Serialize + for<'de> $zalsa::serde::Deserialize<'de>, + { + } + + fn assert_query_input_is_persistable<$db_lt>() { + query_input_is_persistable::<$($interned_input_ty),*>(); + } + } } } @@ -200,11 +248,11 @@ macro_rules! setup_tracked_fn { $zalsa::macro_if! { $needs_interner => fn intern_ingredient( - db: &dyn $Db, + zalsa: &$zalsa::Zalsa, ) -> &$zalsa::interned::IngredientImpl<$Configuration> { - let zalsa = db.zalsa(); Self::intern_ingredient_(zalsa) } + #[inline] fn intern_ingredient_<'z>( zalsa: &'z $zalsa::Zalsa @@ -226,6 +274,7 @@ macro_rules! setup_tracked_fn { line: line!(), }; const DEBUG_NAME: &'static str = concat!($(stringify!($self_ty), "::", )? stringify!($fn_name)); + const PERSIST: bool = $persist; type DbView = dyn $Db; @@ -275,6 +324,31 @@ macro_rules! setup_tracked_fn { } } } + + fn serialize( + value: &Self::Output<'_>, + serializer: S, + ) -> Result { + $zalsa::macro_if! 
{ + if $persist { + $zalsa::serde::Serialize::serialize(value, serializer) + } else { + panic!("attempted to serialize value not marked with `persist` attribute") + } + } + } + + fn deserialize<'de, D: $zalsa::serde::Deserializer<'de>>( + deserializer: D, + ) -> Result, D::Error> { + $zalsa::macro_if! { + if $persist { + $zalsa::serde::Deserialize::deserialize(deserializer) + } else { + panic!("attempted to deserialize value not marked with `persist` attribute") + } + } + } } #[allow(non_local_definitions)] @@ -352,7 +426,7 @@ macro_rules! setup_tracked_fn { let key = $zalsa::macro_if! { if $needs_interner {{ let (zalsa, zalsa_local) = $db.zalsas(); - $Configuration::intern_ingredient($db).intern_id(zalsa, zalsa_local, ($($input_id),*), |_, data| data) + $Configuration::intern_ingredient(zalsa).intern_id(zalsa, zalsa_local, ($($input_id),*), |_, data| data) }} else { $zalsa::AsId::as_id(&($($input_id),*)) } diff --git a/components/salsa-macro-rules/src/setup_tracked_struct.rs b/components/salsa-macro-rules/src/setup_tracked_struct.rs index cb69a08e1..191659a07 100644 --- a/components/salsa-macro-rules/src/setup_tracked_struct.rs +++ b/components/salsa-macro-rules/src/setup_tracked_struct.rs @@ -91,6 +91,15 @@ macro_rules! setup_tracked_struct { // The function used to implement `C::heap_size`. heap_size_fn: $($heap_size_fn:path)?, + // If `true`, `serialize_fn` and `deserialize_fn` have been provided. + persist: $persist:tt, + + // The path to the `serialize` function for the value's fields. + serialize_fn: $($serialize_fn:path)?, + + // The path to the `serialize` function for the value's fields. + deserialize_fn: $($deserialize_fn:path)?, + // Annoyingly macro-rules hygiene does not extend to items defined in the macro. // We have the procedural macro generate names for those items that are // not used elsewhere in the user's code. @@ -143,6 +152,8 @@ macro_rules! 
setup_tracked_struct { $($relative_tracked_index,)* ]; + const PERSIST: bool = $persist; + type Fields<$db_lt> = ($($field_ty,)*); type Revisions = [$Revision; $N]; @@ -194,6 +205,31 @@ macro_rules! setup_tracked_struct { Some($heap_size_fn(value)) } )? + + fn serialize( + fields: &Self::Fields<'_>, + serializer: S, + ) -> Result { + $zalsa::macro_if! { + if $persist { + $($serialize_fn(fields, serializer))? + } else { + panic!("attempted to serialize value not marked with `persist` attribute") + } + } + } + + fn deserialize<'de, D: $zalsa::serde::Deserializer<'de>>( + deserializer: D, + ) -> Result, D::Error> { + $zalsa::macro_if! { + if $persist { + $($deserialize_fn(deserializer))? + } else { + panic!("attempted to deserialize value not marked with `persist` attribute") + } + } + } } impl $Configuration { @@ -236,6 +272,13 @@ macro_rules! setup_tracked_struct { aux.lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>().into() } + fn entries( + zalsa: &$zalsa::Zalsa + ) -> impl Iterator + '_ { + let ingredient_index = zalsa.lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>(); + <$Configuration>::ingredient_(zalsa).entries(zalsa).map(|(key, _)| key) + } + #[inline] fn cast(id: $zalsa::Id, type_id: $zalsa::TypeId) -> $zalsa::Option { if type_id == $zalsa::TypeId::of::<$Struct<'static>>() { @@ -262,6 +305,28 @@ macro_rules! setup_tracked_struct { } } + $zalsa::macro_if! 
{ $persist => + impl $zalsa::serde::Serialize for $Struct<'_> { + fn serialize(&self, serializer: S) -> Result + where + S: $zalsa::serde::Serializer, + { + $zalsa::serde::Serialize::serialize(&$zalsa::AsId::as_id(self), serializer) + } + } + + impl<'de> $zalsa::serde::Deserialize<'de> for $Struct<'_> { + fn deserialize(deserializer: D) -> Result + where + D: $zalsa::serde::Deserializer<'de>, + { + let id = $zalsa::Id::deserialize(deserializer)?; + Ok($zalsa::FromId::from_id(id)) + } + } + } + + unsafe impl Send for $Struct<'_> {} unsafe impl Sync for $Struct<'_> {} diff --git a/components/salsa-macros/Cargo.toml b/components/salsa-macros/Cargo.toml index ea4efe078..9bf6992ae 100644 --- a/components/salsa-macros/Cargo.toml +++ b/components/salsa-macros/Cargo.toml @@ -16,3 +16,7 @@ proc-macro2 = "1.0" quote = "1.0" syn = { version = "2.0.104", features = ["full", "visit-mut"] } synstructure = "0.13.2" + +[features] +default = [] +persistence = [] diff --git a/components/salsa-macros/src/accumulator.rs b/components/salsa-macros/src/accumulator.rs index 866674bf6..1c443a688 100644 --- a/components/salsa-macros/src/accumulator.rs +++ b/components/salsa-macros/src/accumulator.rs @@ -1,7 +1,7 @@ use proc_macro2::TokenStream; use crate::hygiene::Hygiene; -use crate::options::{AllowedOptions, Options}; +use crate::options::{AllowedOptions, AllowedPersistOptions, Options}; use crate::token_stream_with_error; // #[salsa::accumulator(jar = Jar0)] @@ -47,6 +47,8 @@ impl AllowedOptions for Accumulator { const REVISIONS: bool = false; const HEAP_SIZE: bool = false; const SELF_TY: bool = false; + // TODO: Support serializing accumulators. 
+ const PERSIST: AllowedPersistOptions = AllowedPersistOptions::Invalid; } struct StructMacro { diff --git a/components/salsa-macros/src/input.rs b/components/salsa-macros/src/input.rs index b04176d68..9b4901419 100644 --- a/components/salsa-macros/src/input.rs +++ b/components/salsa-macros/src/input.rs @@ -1,7 +1,7 @@ use proc_macro2::TokenStream; use crate::hygiene::Hygiene; -use crate::options::Options; +use crate::options::{AllowedOptions, AllowedPersistOptions, Options}; use crate::salsa_struct::{SalsaStruct, SalsaStructAllowedOptions}; use crate::token_stream_with_error; @@ -32,7 +32,7 @@ type InputArgs = Options; struct InputStruct; -impl crate::options::AllowedOptions for InputStruct { +impl AllowedOptions for InputStruct { const RETURNS: bool = false; const SPECIFY: bool = false; @@ -68,6 +68,8 @@ impl crate::options::AllowedOptions for InputStruct { const HEAP_SIZE: bool = true; const SELF_TY: bool = false; + + const PERSIST: AllowedPersistOptions = AllowedPersistOptions::AllowedValue; } impl SalsaStructAllowedOptions for InputStruct { @@ -114,6 +116,10 @@ impl Macro { let generate_debug_impl = salsa_struct.generate_debug_impl(); let heap_size_fn = self.args.heap_size_fn.iter(); + let persist = self.args.persist(); + let serialize_fn = salsa_struct.serialize_fn(); + let deserialize_fn = salsa_struct.deserialize_fn(); + let zalsa = self.hygiene.ident("zalsa"); let zalsa_struct = self.hygiene.ident("zalsa_struct"); let Configuration = self.hygiene.ident("Configuration"); @@ -142,6 +148,9 @@ impl Macro { is_singleton: #is_singleton, generate_debug_impl: #generate_debug_impl, heap_size_fn: #(#heap_size_fn)*, + persist: #persist, + serialize_fn: #(#serialize_fn)*, + deserialize_fn: #(#deserialize_fn)*, unused_names: [ #zalsa, #zalsa_struct, diff --git a/components/salsa-macros/src/interned.rs b/components/salsa-macros/src/interned.rs index dd064af12..8b72b3510 100644 --- a/components/salsa-macros/src/interned.rs +++ b/components/salsa-macros/src/interned.rs @@ 
-1,7 +1,7 @@ use proc_macro2::TokenStream; use crate::hygiene::Hygiene; -use crate::options::Options; +use crate::options::{AllowedOptions, AllowedPersistOptions, Options}; use crate::salsa_struct::{SalsaStruct, SalsaStructAllowedOptions}; use crate::{db_lifetime, token_stream_with_error}; @@ -32,7 +32,7 @@ type InternedArgs = Options; struct InternedStruct; -impl crate::options::AllowedOptions for InternedStruct { +impl AllowedOptions for InternedStruct { const RETURNS: bool = false; const SPECIFY: bool = false; @@ -68,6 +68,8 @@ impl crate::options::AllowedOptions for InternedStruct { const HEAP_SIZE: bool = true; const SELF_TY: bool = false; + + const PERSIST: AllowedPersistOptions = AllowedPersistOptions::AllowedValue; } impl SalsaStructAllowedOptions for InternedStruct { @@ -131,6 +133,10 @@ impl Macro { (None, quote!(#struct_ident), static_lifetime) }; + let persist = self.args.persist(); + let serialize_fn = salsa_struct.serialize_fn(); + let deserialize_fn = salsa_struct.deserialize_fn(); + let heap_size_fn = self.args.heap_size_fn.iter(); let zalsa = self.hygiene.ident("zalsa"); @@ -164,6 +170,9 @@ impl Macro { num_fields: #num_fields, generate_debug_impl: #generate_debug_impl, heap_size_fn: #(#heap_size_fn)*, + persist: #persist, + serialize_fn: #(#serialize_fn)*, + deserialize_fn: #(#deserialize_fn)*, unused_names: [ #zalsa, #zalsa_struct, diff --git a/components/salsa-macros/src/options.rs b/components/salsa-macros/src/options.rs index 7664d6eae..b7bdc807b 100644 --- a/components/salsa-macros/src/options.rs +++ b/components/salsa-macros/src/options.rs @@ -1,8 +1,8 @@ use std::marker::PhantomData; use syn::ext::IdentExt; -use syn::parenthesized; use syn::spanned::Spanned; +use syn::{parenthesized, token}; /// "Options" are flags that can be supplied to the various salsa related /// macros. They are listed like `(ref, no_eq, foo=bar)` etc. 
The commas @@ -50,6 +50,12 @@ pub(crate) struct Options { /// If this is `Some`, the value is the `non_update_return_type` identifier. pub non_update_return_type: Option, + /// The `persist` options indicates that the ingredient should be persisted with the database. + /// + /// If this is `Some`, the value is optional paths to custom serialization/deserialization + /// functions, based on `serde::{Serialize, Deserialize}`. + pub persist: Option, + /// The `db = ` option is used to indicate the db. /// /// If this is `Some`, the value is the ``. @@ -113,6 +119,21 @@ pub(crate) struct Options { phantom: PhantomData, } +impl Options { + pub fn persist(&self) -> bool { + cfg!(feature = "persistence") && self.persist.is_some() + } +} + +#[derive(Debug, Default, Clone)] +pub struct PersistOptions { + /// Path to a custom serialize function. + pub serialize_fn: Option, + + /// Path to a custom serialize function. + pub deserialize_fn: Option, +} + impl Default for Options { fn default() -> Self { Self { @@ -135,6 +156,7 @@ impl Default for Options { revisions: Default::default(), heap_size_fn: Default::default(), self_ty: Default::default(), + persist: Default::default(), } } } @@ -159,6 +181,23 @@ pub(crate) trait AllowedOptions { const REVISIONS: bool; const HEAP_SIZE: bool; const SELF_TY: bool; + const PERSIST: AllowedPersistOptions; +} + +pub(crate) enum AllowedPersistOptions { + AllowedIdent, + AllowedValue, + Invalid, +} + +impl AllowedPersistOptions { + fn allowed(&self) -> bool { + matches!(self, Self::AllowedIdent | Self::AllowedValue) + } + + fn allowed_value(&self) -> bool { + matches!(self, Self::AllowedValue) + } } type Equals = syn::Token![=]; @@ -247,6 +286,65 @@ impl syn::parse::Parse for Options { "`unsafe` options not allowed here", )); } + } else if ident == "persist" { + if !cfg!(feature = "persistence") { + return Err(syn::Error::new( + ident.span(), + "the `persist` option cannot be used when the `persistence` feature is disabled", + )); + } + + if 
!A::PERSIST.allowed() { + return Err(syn::Error::new( + ident.span(), + "`persist` option not allowed here", + )); + } + + if options.persist.is_some() { + return Err(syn::Error::new( + ident.span(), + "option `persist` provided twice", + )); + } + + let persist = options.persist.insert(PersistOptions::default()); + + if input.peek(token::Paren) { + let content; + parenthesized!(content in input); + + let parse_argument = |content| { + let ident = syn::Ident::parse(content)?; + let _ = Equals::parse(content)?; + let path = syn::Path::parse(content)?; + Ok((ident, path)) + }; + + for (ident, path) in content.parse_terminated(parse_argument, syn::Token![,])? { + if !A::PERSIST.allowed_value() { + return Err(syn::Error::new(ident.span(), "unexpected argument")); + } + + if ident == "serialize" { + if persist.serialize_fn.replace(path).is_some() { + return Err(syn::Error::new( + ident.span(), + "option `serialize` provided twice", + )); + } + } else if ident == "deserialize" { + if persist.deserialize_fn.replace(path).is_some() { + return Err(syn::Error::new( + ident.span(), + "option `deserialize` provided twice", + )); + } + } else { + return Err(syn::Error::new(ident.span(), "unexpected argument")); + } + } + } } else if ident == "singleton" { if A::SINGLETON { if let Some(old) = options.singleton.replace(ident) { @@ -476,6 +574,7 @@ impl quote::ToTokens for Options { revisions, heap_size_fn, self_ty, + persist, phantom: _, } = self; if let Some(returns) = returns { @@ -532,5 +631,22 @@ impl quote::ToTokens for Options { if let Some(self_ty) = self_ty { tokens.extend(quote::quote! { self_ty = #self_ty, }); } + if let Some(persist) = persist { + let mut args = proc_macro2::TokenStream::new(); + + if let Some(path) = &persist.serialize_fn { + args.extend(quote::quote! { serialize = #path, }); + } + + if let Some(path) = &persist.deserialize_fn { + args.extend(quote::quote! { deserialize = #path, }); + } + + if args.is_empty() { + tokens.extend(quote::quote! 
{ persist, }); + } else { + tokens.extend(quote::quote! { persist(#args), }); + } + } } } diff --git a/components/salsa-macros/src/salsa_struct.rs b/components/salsa-macros/src/salsa_struct.rs index 9caa31147..463aca0bc 100644 --- a/components/salsa-macros/src/salsa_struct.rs +++ b/components/salsa-macros/src/salsa_struct.rs @@ -431,6 +431,36 @@ where .enumerate() .filter(|(_, f)| !f.has_tracked_attr) } + + /// Returns the path to the `serialize` function as an optional iterator. + /// + /// This will be `None` if `persistable` returns `false`. + pub(crate) fn serialize_fn(&self) -> impl Iterator + '_ { + self.args + .persist + .clone() + .map(|persist| { + persist + .serialize_fn + .unwrap_or(parse_quote! { serde::Serialize::serialize }) + }) + .into_iter() + } + + /// Returns the path to the `deserialize` function as an optional iterator. + /// + /// This will be `None` if `persistable` returns `false`. + pub(crate) fn deserialize_fn(&self) -> impl Iterator + '_ { + self.args + .persist + .clone() + .map(|persist| { + persist + .deserialize_fn + .unwrap_or(parse_quote! 
{ serde::Deserialize::deserialize }) + }) + .into_iter() + } } impl<'s> SalsaField<'s> { diff --git a/components/salsa-macros/src/supertype.rs b/components/salsa-macros/src/supertype.rs index ebf7f4516..1c98deb23 100644 --- a/components/salsa-macros/src/supertype.rs +++ b/components/salsa-macros/src/supertype.rs @@ -76,6 +76,13 @@ fn enum_impl(enum_item: syn::ItemEnum) -> syn::Result { zalsa::IngredientIndices::merge([ #( <#variant_types as zalsa::SalsaStructInDb>::lookup_ingredient_index(__zalsa) ),* ]) } + fn entries( + zalsa: &zalsa::Zalsa + ) -> impl Iterator + '_ { + std::iter::empty() + #( .chain(<#variant_types as zalsa::SalsaStructInDb>::entries(zalsa)) )* + } + #[inline] fn cast(id: zalsa::Id, type_id: ::core::any::TypeId) -> Option { #( diff --git a/components/salsa-macros/src/tracked_fn.rs b/components/salsa-macros/src/tracked_fn.rs index 9f7ec0879..5c6fab7d2 100644 --- a/components/salsa-macros/src/tracked_fn.rs +++ b/components/salsa-macros/src/tracked_fn.rs @@ -4,7 +4,7 @@ use syn::spanned::Spanned; use syn::{Ident, ItemFn}; use crate::hygiene::Hygiene; -use crate::options::Options; +use crate::options::{AllowedOptions, AllowedPersistOptions, Options}; use crate::{db_lifetime, fn_util}; // Source: @@ -25,7 +25,7 @@ pub type FnArgs = Options; pub struct TrackedFn; -impl crate::options::AllowedOptions for TrackedFn { +impl AllowedOptions for TrackedFn { const RETURNS: bool = true; const SPECIFY: bool = true; @@ -61,6 +61,8 @@ impl crate::options::AllowedOptions for TrackedFn { const HEAP_SIZE: bool = true; const SELF_TY: bool = true; + + const PERSIST: AllowedPersistOptions = AllowedPersistOptions::AllowedIdent; } struct Macro { @@ -183,6 +185,8 @@ impl Macro { )); } + let persist = self.args.persist(); + // The path expression is responsible for emitting the primary span in the diagnostic we // want, so by uniformly using `output_ty.span()` we ensure that the diagnostic is emitted // at the return type in the original input. 
@@ -229,6 +233,7 @@ impl Macro { heap_size_fn: #(#heap_size_fn)*, lru: #lru, return_mode: #return_mode, + persist: #persist, assert_return_type_is_update: { #assert_return_type_is_update }, #self_ty unused_names: [ diff --git a/components/salsa-macros/src/tracked_struct.rs b/components/salsa-macros/src/tracked_struct.rs index 5768eb9cd..d6ee32d47 100644 --- a/components/salsa-macros/src/tracked_struct.rs +++ b/components/salsa-macros/src/tracked_struct.rs @@ -3,7 +3,7 @@ use syn::spanned::Spanned; use crate::db_lifetime; use crate::hygiene::Hygiene; -use crate::options::Options; +use crate::options::{AllowedOptions, AllowedPersistOptions, Options}; use crate::salsa_struct::{SalsaStruct, SalsaStructAllowedOptions}; /// For an entity struct `Foo` with fields `f1: T1, ..., fN: TN`, we generate... @@ -28,7 +28,7 @@ type TrackedArgs = Options; struct TrackedStruct; -impl crate::options::AllowedOptions for TrackedStruct { +impl AllowedOptions for TrackedStruct { const RETURNS: bool = false; const SPECIFY: bool = false; @@ -64,6 +64,8 @@ impl crate::options::AllowedOptions for TrackedStruct { const HEAP_SIZE: bool = true; const SELF_TY: bool = false; + + const PERSIST: AllowedPersistOptions = AllowedPersistOptions::AllowedValue; } impl SalsaStructAllowedOptions for TrackedStruct { @@ -141,6 +143,10 @@ impl Macro { } }); + let persist = self.args.persist(); + let serialize_fn = salsa_struct.serialize_fn(); + let deserialize_fn = salsa_struct.deserialize_fn(); + let heap_size_fn = self.args.heap_size_fn.iter(); let num_tracked_fields = salsa_struct.num_tracked_fields(); @@ -193,6 +199,10 @@ impl Macro { heap_size_fn: #(#heap_size_fn)*, + persist: #persist, + serialize_fn: #(#serialize_fn)*, + deserialize_fn: #(#deserialize_fn)*, + unused_names: [ #zalsa, #zalsa_struct, diff --git a/src/accumulator.rs b/src/accumulator.rs index 62332b000..76ecc4678 100644 --- a/src/accumulator.rs +++ b/src/accumulator.rs @@ -12,7 +12,7 @@ use crate::ingredient::{Ingredient, Jar}; use 
crate::plumbing::ZalsaLocal; use crate::sync::Arc; use crate::table::memo::MemoTableTypes; -use crate::zalsa::{IngredientIndex, Zalsa}; +use crate::zalsa::{IngredientIndex, JarKind, Zalsa}; use crate::{Database, Id, Revision}; mod accumulated; @@ -114,6 +114,10 @@ impl Ingredient for IngredientImpl { A::DEBUG_NAME } + fn jar_kind(&self) -> JarKind { + JarKind::Struct + } + fn memo_table_types(&self) -> &Arc { unreachable!("accumulator does not allocate pages") } diff --git a/src/cycle.rs b/src/cycle.rs index e044572fb..0558bda04 100644 --- a/src/cycle.rs +++ b/src/cycle.rs @@ -96,12 +96,14 @@ pub enum CycleRecoveryStrategy { /// fixpoint iteration is enabled for that query), and then is responsible for re-iterating the /// cycle until it converges. #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "persistence", derive(serde::Serialize, serde::Deserialize))] pub struct CycleHead { pub(crate) database_key_index: DatabaseKeyIndex, pub(crate) iteration_count: IterationCount, } #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Default)] +#[cfg_attr(feature = "persistence", derive(serde::Serialize, serde::Deserialize))] pub struct IterationCount(u8); impl IterationCount { @@ -131,6 +133,7 @@ impl IterationCount { /// plural in case of nested cycles) representing the cycles it is part of, and the current /// iteration count for each cycle head. This struct tracks these cycle heads. 
#[derive(Clone, Debug, Default)] +#[cfg_attr(feature = "persistence", derive(serde::Serialize, serde::Deserialize))] pub struct CycleHeads(ThinVec); impl CycleHeads { diff --git a/src/database.rs b/src/database.rs index 46120eae4..cddb14941 100644 --- a/src/database.rs +++ b/src/database.rs @@ -152,6 +152,225 @@ pub fn current_revision(db: &Db) -> Revision { db.zalsa().current_revision() } +#[cfg(feature = "persistence")] +mod persistence { + use crate::plumbing::Ingredient; + use crate::zalsa::Zalsa; + use crate::{Database, IngredientIndex, Runtime}; + + use std::fmt; + + use serde::de::{self, DeserializeSeed}; + use serde::ser::SerializeMap; + + impl dyn Database { + /// Returns a type implementing [`serde::Serialize`], that can be used to serialize the + /// current state of the database. + pub fn as_serialize(&mut self) -> impl serde::Serialize + '_ { + SerializeDatabase { + runtime: self.zalsa().runtime(), + ingredients: SerializeIngredients(self.zalsa()), + } + } + + /// Deserialize the database using a [`serde::Deserializer`]. + /// + /// This method will modify the database in-place based on the serialized data. 
+ pub fn deserialize<'db, D>(&mut self, deserializer: D) -> Result<(), D::Error> + where + D: serde::Deserializer<'db>, + { + DeserializeDatabase(self.zalsa_mut()).deserialize(deserializer) + } + } + + #[derive(serde::Serialize)] + #[serde(rename = "Database")] + pub struct SerializeDatabase<'db> { + pub runtime: &'db Runtime, + pub ingredients: SerializeIngredients<'db>, + } + + pub struct SerializeIngredients<'db>(pub &'db Zalsa); + + impl serde::Serialize for SerializeIngredients<'_> { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let SerializeIngredients(zalsa) = self; + + let mut ingredients = zalsa + .ingredients() + .filter(|ingredient| ingredient.should_serialize(zalsa)) + .collect::>(); + + // Ensure structs are serialized before tracked functions, as deserializing a + // memo requires its input struct to have been deserialized. + ingredients.sort_by_key(|ingredient| ingredient.jar_kind()); + + let mut map = serializer.serialize_map(Some(ingredients.len()))?; + for ingredient in ingredients { + map.serialize_entry( + &ingredient.ingredient_index().as_u32(), + &SerializeIngredient(ingredient, zalsa), + )?; + } + + map.end() + } + } + + struct SerializeIngredient<'db>(&'db dyn Ingredient, &'db Zalsa); + + impl serde::Serialize for SerializeIngredient<'_> { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let mut result = None; + let mut serializer = Some(serializer); + + // SAFETY: `::as_serialize` take `&mut self`. 
+ unsafe { + self.0.serialize(self.1, &mut |serialize| { + let serializer = serializer.take().expect( + "`Ingredient::serialize` must invoke the serialization callback only once", + ); + + result = Some(erased_serde::serialize(&serialize, serializer)) + }) + }; + + result.expect("`Ingredient::serialize` must invoke the serialization callback") + } + } + + #[derive(serde::Deserialize)] + #[serde(field_identifier, rename_all = "lowercase")] + enum DatabaseField { + Runtime, + Ingredients, + } + + pub struct DeserializeDatabase<'db>(pub &'db mut Zalsa); + + impl<'de> de::DeserializeSeed<'de> for DeserializeDatabase<'_> { + type Value = (); + + fn deserialize(self, deserializer: D) -> Result + where + D: de::Deserializer<'de>, + { + // Note that we have to deserialize using a manual visitor here because the + // `Deserialize` derive does not support fields that use `DeserializeSeed`. + deserializer.deserialize_struct("Database", &["runtime", "ingredients"], self) + } + } + + impl<'de> serde::de::Visitor<'de> for DeserializeDatabase<'_> { + type Value = (); + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("struct Database") + } + + fn visit_map(self, mut map: V) -> Result<(), V::Error> + where + V: serde::de::MapAccess<'de>, + { + let mut runtime = None; + let mut ingredients = None; + + while let Some(key) = map.next_key()? 
{ + match key { + DatabaseField::Runtime => { + if runtime.is_some() { + return Err(serde::de::Error::duplicate_field("runtime")); + } + + runtime = Some(map.next_value()?); + } + DatabaseField::Ingredients => { + if ingredients.is_some() { + return Err(serde::de::Error::duplicate_field("ingredients")); + } + + ingredients = Some(map.next_value_seed(DeserializeIngredients(self.0))?); + } + } + } + + let mut runtime = runtime.ok_or_else(|| serde::de::Error::missing_field("runtime"))?; + let () = ingredients.ok_or_else(|| serde::de::Error::missing_field("ingredients"))?; + + self.0.runtime_mut().deserialize_from(&mut runtime); + + Ok(()) + } + } + + struct DeserializeIngredients<'db>(&'db mut Zalsa); + + impl<'de> serde::de::Visitor<'de> for DeserializeIngredients<'_> { + type Value = (); + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a map") + } + + fn visit_map(self, mut access: M) -> Result + where + M: serde::de::MapAccess<'de>, + { + let DeserializeIngredients(zalsa) = self; + + while let Some(index) = access.next_key::()? { + let index = IngredientIndex::new(index); + + // Remove the ingredient temporarily, to avoid holding an overlapping mutable borrow + // to the ingredient as well as the database. + let mut ingredient = zalsa.take_ingredient(index); + + // Deserialize the ingredient. 
+ access.next_value_seed(DeserializeIngredient(&mut *ingredient, zalsa))?; + + zalsa.replace_ingredient(index, ingredient); + } + + Ok(()) + } + } + + impl<'de> serde::de::DeserializeSeed<'de> for DeserializeIngredients<'_> { + type Value = (); + + fn deserialize(self, deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_map(self) + } + } + + struct DeserializeIngredient<'db>(&'db mut dyn Ingredient, &'db mut Zalsa); + + impl<'de> serde::de::DeserializeSeed<'de> for DeserializeIngredient<'_> { + type Value = (); + + fn deserialize(self, deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let deserializer = &mut ::erase(deserializer); + + self.0 + .deserialize(self.1, deserializer) + .map_err(serde::de::Error::custom) + } + } +} + #[cfg(feature = "salsa_unstable")] pub use memory_usage::IngredientInfo; @@ -160,9 +379,10 @@ pub(crate) use memory_usage::{MemoInfo, SlotInfo}; #[cfg(feature = "salsa_unstable")] mod memory_usage { - use crate::Database; use hashbrown::HashMap; + use crate::Database; + impl dyn Database { /// Returns memory usage information about ingredients in the database. pub fn memory_usage(&self) -> DatabaseInfo { diff --git a/src/durability.rs b/src/durability.rs index 2f2b384d8..a3e33b1bc 100644 --- a/src/durability.rs +++ b/src/durability.rs @@ -17,6 +17,7 @@ /// configuration, the source from library crates, or other things /// that are unlikely to be edited. #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] +#[cfg_attr(feature = "persistence", derive(serde::Serialize, serde::Deserialize))] pub struct Durability(DurabilityVal); impl std::fmt::Debug for Durability { @@ -37,6 +38,7 @@ impl std::fmt::Debug for Durability { // We use an enum here instead of a u8 for niches. 
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] +#[cfg_attr(feature = "persistence", derive(serde::Serialize, serde::Deserialize))] enum DurabilityVal { Low = 0, Medium = 1, diff --git a/src/function.rs b/src/function.rs index 3e8674cf0..28b76a0cc 100644 --- a/src/function.rs +++ b/src/function.rs @@ -1,10 +1,11 @@ pub(crate) use maybe_changed_after::{VerifyCycleHeads, VerifyResult}; +pub(crate) use sync::SyncGuard; + use std::any::Any; use std::fmt; use std::ptr::NonNull; use std::sync::atomic::Ordering; use std::sync::OnceLock; -pub(crate) use sync::SyncGuard; use crate::cycle::{ empty_cycle_heads, CycleHeads, CycleRecoveryAction, CycleRecoveryStrategy, ProvisionalStatus, @@ -14,13 +15,13 @@ use crate::function::delete::DeletedEntries; use crate::function::sync::{ClaimResult, SyncTable}; use crate::ingredient::{Ingredient, WaitForResult}; use crate::key::DatabaseKeyIndex; -use crate::plumbing::MemoIngredientMap; +use crate::plumbing::{self, MemoIngredientMap}; use crate::salsa_struct::SalsaStructInDb; use crate::sync::Arc; use crate::table::memo::MemoTableTypes; use crate::table::Table; use crate::views::DatabaseDownCaster; -use crate::zalsa::{IngredientIndex, MemoIngredientIndex, Zalsa}; +use crate::zalsa::{IngredientIndex, JarKind, MemoIngredientIndex, Zalsa}; use crate::zalsa_local::QueryOriginRef; use crate::{Id, Revision}; @@ -43,6 +44,7 @@ pub type Memo = memo::Memo<'static, C>; pub trait Configuration: Any { const DEBUG_NAME: &'static str; const LOCATION: crate::ingredient::Location; + const PERSIST: bool; /// The database that this function is associated with. type DbView: ?Sized + crate::Database; @@ -96,6 +98,20 @@ pub trait Configuration: Any { count: u32, input: Self::Input<'db>, ) -> CycleRecoveryAction>; + + /// Serialize the output type using `serde`. + /// + /// Panics if the value is not persistable, i.e. `Configuration::PERSIST` is `false`. 
+ fn serialize(value: &Self::Output<'_>, serializer: S) -> Result + where + S: plumbing::serde::Serializer; + + /// Deserialize the output type using `serde`. + /// + /// Panics if the value is not persistable, i.e. `Configuration::PERSIST` is `false`. + fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> + where + D: plumbing::serde::Deserializer<'de>; } /// Function ingredients are the "workhorse" of salsa. @@ -359,6 +375,10 @@ where C::DEBUG_NAME } + fn jar_kind(&self) -> JarKind { + JarKind::TrackedFn + } + fn memo_table_types(&self) -> &Arc { unreachable!("function does not allocate pages") } @@ -384,6 +404,56 @@ where let db = unsafe { self.view_caster().downcast_unchecked(db) }; self.accumulated_map(db, key_index) } + + fn is_persistable(&self) -> bool { + C::PERSIST + } + + fn should_serialize(&self, zalsa: &Zalsa) -> bool { + if !C::PERSIST { + return false; + } + + // We only serialize the query if there are any memos associated with it. + for entry in as SalsaStructInDb>::entries(zalsa) { + let memo_ingredient_index = self.memo_ingredient_indices.get(entry.ingredient_index()); + + let memo = + self.get_memo_from_table_for(zalsa, entry.key_index(), memo_ingredient_index); + + if memo.is_some_and(|memo| memo.should_serialize()) { + return true; + } + } + + false + } + + #[cfg(feature = "persistence")] + unsafe fn serialize<'db>( + &'db self, + zalsa: &'db Zalsa, + f: &mut dyn FnMut(&dyn erased_serde::Serialize), + ) { + f(&persistence::SerializeIngredient { + zalsa, + ingredient: self, + }) + } + + #[cfg(feature = "persistence")] + fn deserialize( + &mut self, + zalsa: &mut Zalsa, + deserializer: &mut dyn erased_serde::Deserializer, + ) -> Result<(), erased_serde::Error> { + let deserialize = persistence::DeserializeIngredient { + zalsa, + ingredient: self, + }; + + serde::de::DeserializeSeed::deserialize(deserialize, deserializer) + } } impl std::fmt::Debug for IngredientImpl @@ -396,3 +466,152 @@ where .finish() } } + +#[cfg(feature = 
"persistence")] +mod persistence { + use super::{Configuration, IngredientImpl, Memo}; + use crate::plumbing::{Ingredient, MemoIngredientMap, SalsaStructInDb}; + use crate::zalsa::Zalsa; + use crate::{Id, IngredientIndex}; + + use serde::de; + use serde::ser::SerializeMap; + + use std::ptr::NonNull; + + pub struct SerializeIngredient<'db, C> + where + C: Configuration, + { + pub zalsa: &'db Zalsa, + pub ingredient: &'db IngredientImpl, + } + + impl serde::Serialize for SerializeIngredient<'_, C> + where + C: Configuration, + { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let Self { ingredient, zalsa } = self; + + let mut map = serializer.serialize_map(None)?; + + for struct_index in + as SalsaStructInDb>::lookup_ingredient_index(zalsa).iter() + { + let struct_ingredient = zalsa.lookup_ingredient(struct_index); + assert!( + struct_ingredient.is_persistable(), + "the input of a serialized tracked function must be serialized" + ); + } + + for entry in as SalsaStructInDb>::entries(zalsa) { + let memo_ingredient_index = ingredient + .memo_ingredient_indices + .get(entry.ingredient_index()); + + let memo = ingredient.get_memo_from_table_for( + zalsa, + entry.key_index(), + memo_ingredient_index, + ); + + if let Some(memo) = memo.filter(|memo| memo.should_serialize()) { + for edge in memo.revisions.origin.as_ref().edges() { + let dependency = zalsa.lookup_ingredient(edge.key().ingredient_index()); + + // TODO: This is not strictly necessary, we only need the transitive input + // dependencies of this query to serialize a valid memo. + assert!( + dependency.is_persistable(), + "attempted to serialize query `{}`, but dependency `{}` is not persistable", + ingredient.debug_name(), + dependency.debug_name() + ); + } + + // TODO: Group structs by ingredient index into a nested map. 
+ let key = format!( + "{}:{}", + entry.ingredient_index().as_u32(), + entry.key_index().as_bits() + ); + + map.serialize_entry(&key, memo)?; + } + } + + map.end() + } + } + + pub struct DeserializeIngredient<'db, C> + where + C: Configuration, + { + pub zalsa: &'db Zalsa, + pub ingredient: &'db mut IngredientImpl, + } + + impl<'de, C> de::DeserializeSeed<'de> for DeserializeIngredient<'_, C> + where + C: Configuration, + { + type Value = (); + + fn deserialize(self, deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_map(self) + } + } + + impl<'de, C> de::Visitor<'de> for DeserializeIngredient<'_, C> + where + C: Configuration, + { + type Value = (); + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a map") + } + + fn visit_map(self, mut access: M) -> Result + where + M: de::MapAccess<'de>, + { + let DeserializeIngredient { zalsa, ingredient } = self; + + while let Some((key, memo)) = access.next_entry::>()? { + let (ingredient_index, id) = key + .split_once(':') + .ok_or_else(|| de::Error::custom("invalid database key"))?; + + let ingredient_index = IngredientIndex::new( + ingredient_index.parse::().map_err(de::Error::custom)?, + ); + + let id = Id::from_bits(id.parse::().map_err(de::Error::custom)?); + + let memo_ingredient_index = + ingredient.memo_ingredient_indices.get(ingredient_index); + + // SAFETY: We provide the current revision. + let memo_table = unsafe { zalsa.table().dyn_memos(id, zalsa.current_revision()) }; + + memo_table.insert( + memo_ingredient_index, + // FIXME: Use `Box::into_non_null` once stable. + NonNull::from(Box::leak(Box::new(memo))), + ); + } + + Ok(()) + } + } +} diff --git a/src/function/memo.rs b/src/function/memo.rs index a6107060e..eb8fcec70 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -114,6 +114,13 @@ impl<'db, C: Configuration> Memo<'db, C> { } } + /// Returns `true` if this memo should be serialized. 
+ pub(super) fn should_serialize(&self) -> bool { + // TODO: Serialization is a good opportunity to prune old query results based on + // the `verified_at` revision. + self.value.is_some() && !self.may_be_provisional() + } + /// True if this may be a provisional cycle-iteration result. #[inline] pub(super) fn may_be_provisional(&self) -> bool { @@ -340,6 +347,97 @@ where } } +#[cfg(feature = "persistence")] +mod persistence { + use crate::function::memo::Memo; + use crate::function::Configuration; + use crate::revision::AtomicRevision; + use crate::zalsa_local::QueryRevisions; + + use serde::ser::SerializeStruct; + use serde::Deserialize; + + impl serde::Serialize for Memo<'_, C> + where + C: Configuration, + { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + struct SerializeValue<'me, 'db, C: Configuration>(&'me C::Output<'db>); + + impl serde::Serialize for SerializeValue<'_, '_, C> + where + C: Configuration, + { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + C::serialize(self.0, serializer) + } + } + + let Memo { + value, + verified_at, + revisions, + } = self; + + let value = value.as_ref().expect("attempted to serialize empty memo"); + + let mut s = serializer.serialize_struct("Memo", 3)?; + s.serialize_field("value", &SerializeValue::(value))?; + s.serialize_field("verified_at", &verified_at)?; + s.serialize_field("revisions", &revisions)?; + s.end() + } + } + + impl<'de, C> serde::Deserialize<'de> for Memo<'static, C> + where + C: Configuration, + { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + #[derive(Deserialize)] + pub struct DeserializeMemo { + #[serde(bound = "C: Configuration")] + value: DeserializeValue, + verified_at: AtomicRevision, + revisions: QueryRevisions, + } + + struct DeserializeValue(C::Output<'static>); + + impl<'de, C> serde::Deserialize<'de> for DeserializeValue + where + C: Configuration, + { + fn 
deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + C::deserialize(deserializer) + .map(DeserializeValue) + .map_err(serde::de::Error::custom) + } + } + + let memo = DeserializeMemo::::deserialize(deserializer)?; + + Ok(Memo { + value: Some(memo.value.0), + verified_at: memo.verified_at, + revisions: memo.revisions, + }) + } + } +} + pub(super) enum TryClaimHeadsResult<'me> { /// Claiming every cycle head results in a cycle head. Cycle, @@ -460,7 +558,7 @@ impl<'me> Iterator for TryClaimCycleHeadsIter<'me> { mod _memory_usage { use crate::cycle::CycleRecoveryStrategy; use crate::ingredient::Location; - use crate::plumbing::{IngredientIndices, MemoIngredientSingletonIndex, SalsaStructInDb}; + use crate::plumbing::{self, IngredientIndices, MemoIngredientSingletonIndex, SalsaStructInDb}; use crate::table::memo::MemoTableWithTypes; use crate::zalsa::Zalsa; use crate::{CycleRecoveryAction, Database, Id, Revision}; @@ -488,6 +586,10 @@ mod _memory_usage { unsafe fn memo_table(_: &Zalsa, _: Id, _: Revision) -> MemoTableWithTypes<'_> { unimplemented!() } + + fn entries(_: &Zalsa) -> impl Iterator + '_ { + std::iter::empty() + } } struct DummyConfiguration; @@ -495,11 +597,13 @@ mod _memory_usage { impl super::Configuration for DummyConfiguration { const DEBUG_NAME: &'static str = ""; const LOCATION: Location = Location { file: "", line: 0 }; + const PERSIST: bool = false; + const CYCLE_STRATEGY: CycleRecoveryStrategy = CycleRecoveryStrategy::Panic; + type DbView = dyn Database; type SalsaStruct<'db> = DummyStruct; type Input<'db> = (); type Output<'db> = NonZeroUsize; - const CYCLE_STRATEGY: CycleRecoveryStrategy = CycleRecoveryStrategy::Panic; fn values_equal<'db>(_: &Self::Output<'db>, _: &Self::Output<'db>) -> bool { unimplemented!() @@ -525,5 +629,19 @@ mod _memory_usage { ) -> CycleRecoveryAction> { unimplemented!() } + + fn serialize(_: &Self::Output<'_>, _: S) -> Result + where + S: plumbing::serde::Serializer, + { + unimplemented!() 
+ } + + fn deserialize<'de, D>(_: D) -> Result, D::Error> + where + D: plumbing::serde::Deserializer<'de>, + { + unimplemented!() + } } } diff --git a/src/id.rs b/src/id.rs index d8291f3c8..c8712c102 100644 --- a/src/id.rs +++ b/src/id.rs @@ -18,6 +18,7 @@ use crate::zalsa::Zalsa; /// As an end-user of `Salsa` you will generally not use `Id` directly, /// it is wrapped in new types. #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[cfg_attr(feature = "persistence", derive(serde::Serialize, serde::Deserialize))] pub struct Id { index: NonZeroU32, generation: u32, @@ -61,7 +62,7 @@ impl Id { #[doc(hidden)] #[track_caller] #[inline] - pub const unsafe fn from_bits(bits: u64) -> Self { + pub const unsafe fn from_bits_unchecked(bits: u64) -> Self { // SAFETY: Caller obligation. let index = unsafe { NonZeroU32::new_unchecked(bits as u32) }; let generation = (bits >> 32) as u32; @@ -69,6 +70,19 @@ impl Id { Id { index, generation } } + /// Create a `salsa::Id` from a `u64` value. + /// + /// This should only be used to recreate an `Id` together with `Id::as_u64`, + /// and may panic if the `Id` is invalid. + #[inline] + #[doc(hidden)] + pub const fn from_bits(bits: u64) -> Self { + let index = NonZeroU32::new(bits as u32).expect("attempted to create invalid `Id`"); + let generation = (bits >> 32) as u32; + + Id { index, generation } + } + /// Return a `u64` representation of this `Id`. 
#[inline] pub fn as_bits(self) -> u64 { diff --git a/src/ingredient.rs b/src/ingredient.rs index 3e1e0f2f7..4cd857962 100644 --- a/src/ingredient.rs +++ b/src/ingredient.rs @@ -10,7 +10,7 @@ use crate::runtime::Running; use crate::sync::Arc; use crate::table::memo::MemoTableTypes; use crate::table::Table; -use crate::zalsa::{transmute_data_mut_ptr, transmute_data_ptr, IngredientIndex, Zalsa}; +use crate::zalsa::{transmute_data_mut_ptr, transmute_data_ptr, IngredientIndex, JarKind, Zalsa}; use crate::zalsa_local::QueryOriginRef; use crate::{DatabaseKeyIndex, Id, Revision}; @@ -39,6 +39,7 @@ pub struct Location { pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { fn debug_name(&self) -> &'static str; fn location(&self) -> &'static Location; + fn jar_kind(&self) -> JarKind; /// Has the value for `input` in this ingredient changed after `revision`? /// @@ -48,7 +49,7 @@ pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { unsafe fn maybe_changed_after( &self, zalsa: &crate::zalsa::Zalsa, - db: crate::database::RawDatabase<'_>, + db: RawDatabase<'_>, input: Id, revision: Revision, cycle_heads: &mut VerifyCycleHeads, @@ -186,6 +187,48 @@ pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { fn memory_usage(&self, _db: &dyn crate::Database) -> Option> { None } + + /// Whether this ingredient will be persisted with the database. + fn is_persistable(&self) -> bool { + false + } + + /// Whether there is data to serialize for this ingredient. + /// + /// If this returns `false`, the ingredient will not be serialized, even if `is_persistable` + /// returns `true`. + fn should_serialize(&self, _zalsa: &Zalsa) -> bool { + false + } + + /// Serialize the ingredient. + /// + /// This function should invoke the provided callback with a reference to an object implementing [`erased_serde::Serialize`]. 
+ /// + /// # Safety + /// + /// While this method takes an immutable reference to the database, it can only be called when a + /// the serializer has exclusive access to the database. + // See for why this callback signature is necessary, instead + // of providing an `erased_serde::Serializer` directly. + #[cfg(feature = "persistence")] + unsafe fn serialize<'db>( + &'db self, + _zalsa: &'db Zalsa, + _f: &mut dyn FnMut(&dyn erased_serde::Serialize), + ) { + unimplemented!("called `serialize` on ingredient where `is_persistable` returns `false`") + } + + /// Deserialize the ingredient. + #[cfg(feature = "persistence")] + fn deserialize( + &mut self, + _zalsa: &mut Zalsa, + _deserializer: &mut dyn erased_serde::Deserializer, + ) -> Result<(), erased_serde::Error> { + unimplemented!("called `deserialize` on ingredient where `is_persistable` returns `false`") + } } impl dyn Ingredient { diff --git a/src/input.rs b/src/input.rs index b48b369c8..fd8fc018a 100644 --- a/src/input.rs +++ b/src/input.rs @@ -13,18 +13,21 @@ use crate::id::{AsId, FromId, FromIdWithDb}; use crate::ingredient::Ingredient; use crate::input::singleton::{Singleton, SingletonChoice}; use crate::key::DatabaseKeyIndex; -use crate::plumbing::Jar; +use crate::plumbing::{self, Jar, ZalsaLocal}; use crate::sync::Arc; use crate::table::memo::{MemoTable, MemoTableTypes}; use crate::table::{Slot, Table}; -use crate::zalsa::{IngredientIndex, Zalsa}; -use crate::{zalsa_local, Durability, Id, Revision, Runtime}; +use crate::zalsa::{IngredientIndex, JarKind, Zalsa}; +use crate::{Durability, Id, Revision, Runtime}; pub trait Configuration: Any { const DEBUG_NAME: &'static str; const FIELD_DEBUG_NAMES: &'static [&'static str]; const LOCATION: crate::ingredient::Location; + /// Whether this struct should be persisted with the database. + const PERSIST: bool; + /// The singleton state for this input if any. 
type Singleton: SingletonChoice + Send + Sync; @@ -35,15 +38,47 @@ pub trait Configuration: Any { type Fields: Send + Sync; /// A array of [`Revision`], one per each of the value fields. + #[cfg(feature = "persistence")] + type Revisions: Send + + Sync + + fmt::Debug + + IndexMut + + plumbing::serde::Serialize + + for<'de> plumbing::serde::Deserialize<'de>; + + #[cfg(not(feature = "persistence"))] type Revisions: Send + Sync + fmt::Debug + IndexMut; /// A array of [`Durability`], one per each of the value fields. + #[cfg(feature = "persistence")] + type Durabilities: Send + + Sync + + fmt::Debug + + IndexMut + + plumbing::serde::Serialize + + for<'de> plumbing::serde::Deserialize<'de>; + + #[cfg(not(feature = "persistence"))] type Durabilities: Send + Sync + fmt::Debug + IndexMut; /// Returns the size of any heap allocations in the output value, in bytes. fn heap_size(_value: &Self::Fields) -> Option { None } + + /// Serialize the fields using `serde`. + /// + /// Panics if the value is not persistable, i.e. `Configuration::PERSIST` is `false`. + fn serialize(value: &Self::Fields, serializer: S) -> Result + where + S: plumbing::serde::Serializer; + + /// Deserialize the fields using `serde`. + /// + /// Panics if the value is not persistable, i.e. `Configuration::PERSIST` is `false`. 
+ fn deserialize<'de, D>(deserializer: D) -> Result + where + D: plumbing::serde::Deserializer<'de>; } pub struct JarImpl { @@ -102,14 +137,14 @@ impl IngredientImpl { table.get_raw(id) } - pub fn database_key_index(&self, id: C::Struct) -> DatabaseKeyIndex { - DatabaseKeyIndex::new(self.ingredient_index, id.as_id()) + pub fn database_key_index(&self, id: Id) -> DatabaseKeyIndex { + DatabaseKeyIndex::new(self.ingredient_index, id) } pub fn new_input( &self, zalsa: &Zalsa, - zalsa_local: &zalsa_local::ZalsaLocal, + zalsa_local: &ZalsaLocal, fields: C::Fields, revisions: C::Revisions, durabilities: C::Durabilities, @@ -183,7 +218,7 @@ impl IngredientImpl { pub fn field<'db>( &'db self, zalsa: &'db Zalsa, - zalsa_local: &'db zalsa_local::ZalsaLocal, + zalsa_local: &'db ZalsaLocal, id: C::Struct, field_index: usize, ) -> &'db C::Fields { @@ -200,10 +235,15 @@ impl IngredientImpl { &value.fields } - #[cfg(feature = "salsa_unstable")] /// Returns all data corresponding to the input struct. - pub fn entries<'db>(&'db self, zalsa: &'db Zalsa) -> impl Iterator> { - zalsa.table().slots_of::>() + pub fn entries<'db>( + &'db self, + zalsa: &'db Zalsa, + ) -> impl Iterator)> + 'db { + zalsa + .table() + .slots_of::>() + .map(|(id, value)| (self.database_key_index(id), value)) } /// Peek at the field values without recording any read dependency. @@ -241,6 +281,10 @@ impl Ingredient for IngredientImpl { C::DEBUG_NAME } + fn jar_kind(&self) -> JarKind { + JarKind::Struct + } + fn memo_table_types(&self) -> &Arc { &self.memo_table_types } @@ -256,10 +300,45 @@ impl Ingredient for IngredientImpl { .entries(db.zalsa()) // SAFETY: The memo table belongs to a value that we allocated, so it // has the correct type. 
- .map(|value| unsafe { value.memory_usage(&self.memo_table_types) }) + .map(|(_, value)| unsafe { value.memory_usage(&self.memo_table_types) }) .collect(); + Some(memory_usage) } + + fn is_persistable(&self) -> bool { + C::PERSIST + } + + fn should_serialize(&self, zalsa: &Zalsa) -> bool { + C::PERSIST && self.entries(zalsa).next().is_some() + } + + #[cfg(feature = "persistence")] + unsafe fn serialize<'db>( + &'db self, + zalsa: &'db Zalsa, + f: &mut dyn FnMut(&dyn erased_serde::Serialize), + ) { + f(&persistence::SerializeIngredient { + zalsa, + _ingredient: self, + }) + } + + #[cfg(feature = "persistence")] + fn deserialize( + &mut self, + zalsa: &mut Zalsa, + deserializer: &mut dyn erased_serde::Deserializer, + ) -> Result<(), erased_serde::Error> { + let deserialize = persistence::DeserializeIngredient { + zalsa, + ingredient: self, + }; + + serde::de::DeserializeSeed::deserialize(deserialize, deserializer) + } } impl std::fmt::Debug for IngredientImpl { @@ -344,3 +423,187 @@ where &mut self.memos } } + +#[cfg(feature = "persistence")] +mod persistence { + use std::fmt; + + use serde::ser::SerializeMap; + use serde::{de, Deserialize}; + + use super::{Configuration, IngredientImpl, Value}; + use crate::plumbing::Ingredient; + use crate::table::memo::MemoTable; + use crate::zalsa::Zalsa; + use crate::Id; + + pub struct SerializeIngredient<'db, C> + where + C: Configuration, + { + pub zalsa: &'db Zalsa, + pub _ingredient: &'db IngredientImpl, + } + + impl serde::Serialize for SerializeIngredient<'_, C> + where + C: Configuration, + { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let Self { zalsa, .. 
} = self; + + let mut map = serializer.serialize_map(None)?; + + for (id, value) in zalsa.table().slots_of::>() { + map.serialize_entry(&id.as_bits(), value)?; + } + + map.end() + } + } + + impl serde::Serialize for Value + where + C: Configuration, + { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let mut map = serializer.serialize_map(None)?; + + struct SerializeFields<'db, C: Configuration>(&'db C::Fields); + + impl serde::Serialize for SerializeFields<'_, C> + where + C: Configuration, + { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + C::serialize(self.0, serializer) + } + } + + let Value { + fields, + revisions, + durabilities, + memos: _, + } = self; + + map.serialize_entry(&"durabilities", &durabilities)?; + map.serialize_entry(&"revisions", &revisions)?; + map.serialize_entry(&"fields", &SerializeFields::(fields))?; + + map.end() + } + } + + pub struct DeserializeIngredient<'db, C> + where + C: Configuration, + { + pub zalsa: &'db mut Zalsa, + pub ingredient: &'db mut IngredientImpl, + } + + impl<'de, C> de::DeserializeSeed<'de> for DeserializeIngredient<'_, C> + where + C: Configuration, + { + type Value = (); + + fn deserialize(self, deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_map(self) + } + } + + impl<'de, C> de::Visitor<'de> for DeserializeIngredient<'_, C> + where + C: Configuration, + { + type Value = (); + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a map") + } + + fn visit_map(self, mut access: M) -> Result + where + M: de::MapAccess<'de>, + { + let DeserializeIngredient { zalsa, ingredient } = self; + + while let Some((id, value)) = access.next_entry::>()? 
{ + let id = Id::from_bits(id); + let (page_idx, _) = crate::table::split_id(id); + + let value = Value:: { + fields: value.fields.0, + revisions: value.revisions, + durabilities: value.durabilities, + // SAFETY: We only ever access the memos of a value that we allocated through + // our `MemoTableTypes`. + memos: unsafe { MemoTable::new(ingredient.memo_table_types()) }, + }; + + // Force initialize the relevant page. + zalsa.table_mut().force_page::>( + page_idx, + ingredient.ingredient_index(), + ingredient.memo_table_types(), + ); + + // Initialize the slot. + // + // SAFETY: We have a mutable reference to the database. + let (allocated_id, _) = unsafe { + zalsa + .table() + .page(page_idx) + .allocate(page_idx, |_| value) + .unwrap_or_else(|_| panic!("serialized an invalid `Id`: {id:?}")) + }; + + assert_eq!( + allocated_id, id, + "values are serialized in allocation order" + ); + } + + Ok(()) + } + } + + #[derive(Deserialize)] + pub struct DeserializeValue { + durabilities: C::Durabilities, + revisions: C::Revisions, + #[serde(bound = "C: Configuration")] + fields: DeserializeFields, + } + + struct DeserializeFields(C::Fields); + + impl<'de, C> serde::Deserialize<'de> for DeserializeFields + where + C: Configuration, + { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + C::deserialize(deserializer) + .map(DeserializeFields) + .map_err(de::Error::custom) + } + } +} diff --git a/src/input/input_field.rs b/src/input/input_field.rs index 9b352b561..7bfeb507a 100644 --- a/src/input/input_field.rs +++ b/src/input/input_field.rs @@ -6,7 +6,7 @@ use crate::ingredient::Ingredient; use crate::input::{Configuration, IngredientImpl, Value}; use crate::sync::Arc; use crate::table::memo::MemoTableTypes; -use crate::zalsa::IngredientIndex; +use crate::zalsa::{IngredientIndex, JarKind, Zalsa}; use crate::{Id, Revision}; /// Ingredient used to represent the fields of a `#[salsa::input]`. 
@@ -75,6 +75,10 @@ where C::FIELD_DEBUG_NAMES[self.field_index] } + fn jar_kind(&self) -> JarKind { + JarKind::Struct + } + fn memo_table_types(&self) -> &Arc { unreachable!("input fields do not allocate pages") } @@ -82,6 +86,16 @@ where fn memo_table_types_mut(&mut self) -> &mut Arc { unreachable!("input fields do not allocate pages") } + + fn is_persistable(&self) -> bool { + // Input field dependencies are valid as long as the input is persistable. + C::PERSIST + } + + fn should_serialize(&self, _zalsa: &Zalsa) -> bool { + // However, they are never serialized directly. + false + } } impl std::fmt::Debug for FieldIngredientImpl diff --git a/src/input/singleton.rs b/src/input/singleton.rs index 09107575f..c069a7fd1 100644 --- a/src/input/singleton.rs +++ b/src/input/singleton.rs @@ -35,7 +35,7 @@ impl SingletonChoice for Singleton { 0 => None, // SAFETY: Our u64 is derived from an ID and thus safe to convert back. - id => Some(unsafe { Id::from_bits(id) }), + id => Some(unsafe { Id::from_bits_unchecked(id) }), } } } diff --git a/src/interned.rs b/src/interned.rs index 33723ac2c..547a5a67e 100644 --- a/src/interned.rs +++ b/src/interned.rs @@ -14,12 +14,12 @@ use crate::durability::Durability; use crate::function::{VerifyCycleHeads, VerifyResult}; use crate::id::{AsId, FromId}; use crate::ingredient::Ingredient; -use crate::plumbing::{Jar, ZalsaLocal}; +use crate::plumbing::{self, Jar, ZalsaLocal}; use crate::revision::AtomicRevision; use crate::sync::{Arc, Mutex, OnceLock}; use crate::table::memo::{MemoTable, MemoTableTypes, MemoTableWithTypesMut}; use crate::table::Slot; -use crate::zalsa::{IngredientIndex, Zalsa}; +use crate::zalsa::{IngredientIndex, JarKind, Zalsa}; use crate::{DatabaseKeyIndex, Event, EventKind, Id, Revision}; /// Trait that defines the key properties of an interned struct. @@ -28,9 +28,11 @@ use crate::{DatabaseKeyIndex, Event, EventKind, Id, Revision}; /// a struct. 
pub trait Configuration: Sized + 'static { const LOCATION: crate::ingredient::Location; - const DEBUG_NAME: &'static str; + /// Whether this struct should be persisted with the database. + const PERSIST: bool; + // The minimum number of revisions that must pass before a stale value is garbage collected. #[cfg(test)] const REVISIONS: NonZeroUsize = NonZeroUsize::new(3).unwrap(); @@ -48,6 +50,20 @@ pub trait Configuration: Sized + 'static { fn heap_size(_value: &Self::Fields<'_>) -> Option { None } + + /// Serialize the fields using `serde`. + /// + /// Panics if the value is not persistable, i.e. `Configuration::PERSIST` is `false`. + fn serialize(value: &Self::Fields<'_>, serializer: S) -> Result + where + S: plumbing::serde::Serializer; + + /// Deserialize the fields using `serde`. + /// + /// Panics if the value is not persistable, i.e. `Configuration::PERSIST` is `false`. + fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> + where + D: plumbing::serde::Deserializer<'de>; } pub trait InternedData: Sized + Eq + Hash + Clone + Sync + Send {} @@ -138,6 +154,7 @@ where /// Shared value data can only be read through the lock. #[repr(Rust, packed)] // Allow `durability` to be stored in the padding of the outer `Value` struct. +#[derive(Clone, Copy)] struct ValueShared { /// The interned ID for this value. /// @@ -604,6 +621,40 @@ where }), }); + // Insert the newly allocated ID. + self.insert_id(id, zalsa, shard, hash, value); + + let index = self.database_key_index(id); + + // Record a dependency on the newly interned value. + // + // Note that the ID is unique to this use of the interned slot, so it seems logical to use + // `Revision::start()` here. However, it is possible that the ID we read is different from + // the previous execution of this query if the previous slot has been reused. In that case, + // the query has changed without a corresponding input changing. 
Using `current_revision` + // for dependencies on interned values encodes the fact that interned IDs are not stable + // across revisions. + zalsa_local.report_tracked_read_simple(index, durability, current_revision); + + zalsa.event(&|| { + Event::new(EventKind::DidInternValue { + key: index, + revision: current_revision, + }) + }); + + id + } + + /// Inserts a newly interned value ID into the LRU list and key map. + fn insert_id( + &self, + id: Id, + zalsa: &Zalsa, + shard: &mut IngredientShard, + hash: u64, + value: &Value, + ) { // SAFETY: We hold the lock for the shard containing the value. let value_shared = unsafe { &mut *value.shared.get() }; @@ -627,27 +678,6 @@ where // SAFETY: We hold the lock for the shard containing the value. unsafe { self.hasher.hash_one(&*value.fields.get()) } }); - - let index = self.database_key_index(id); - - // Record a dependency on the newly interned value. - // - // Note that the ID is unique to this use of the interned slot, so it seems logical to use - // `Revision::start()` here. However, it is possible that the ID we read is different from - // the previous execution of this query if the previous slot has been reused. In that case, - // the query has changed without a corresponding input changing. Using `current_revision` - // for dependencies on interned values encodes the fact that interned IDs are not stable - // across revisions. - zalsa_local.report_tracked_read_simple(index, durability, current_revision); - - zalsa.event(&|| { - Event::new(EventKind::DidInternValue { - key: index, - revision: current_revision, - }) - }); - - id } /// Clears the given memo table. @@ -778,10 +808,40 @@ where } } - #[cfg(feature = "salsa_unstable")] /// Returns all data corresponding to the interned struct. 
- pub fn entries<'db>(&'db self, zalsa: &'db Zalsa) -> impl Iterator> { - zalsa.table().slots_of::>() + pub fn entries<'db>( + &'db self, + zalsa: &'db Zalsa, + ) -> impl Iterator)> + 'db { + // SAFETY: `should_lock` is `true` + unsafe { self.entries_inner(true, zalsa) } + } + + /// Returns all data corresponding to the interned struct. + /// + /// # Safety + /// + /// If `should_lock` is `false`, the caller *must* hold the locks for all shards + /// of the key map. + unsafe fn entries_inner<'db>( + &'db self, + should_lock: bool, + zalsa: &'db Zalsa, + ) -> impl Iterator)> + 'db { + // TODO: Grab all locks eagerly. + zalsa.table().slots_of::>().map(move |(_, value)| { + if should_lock { + // SAFETY: `value.shard` is guaranteed to be in-bounds for `self.shards`. + let _shard = unsafe { self.shards.get_unchecked(value.shard as usize) }.lock(); + } + + // SAFETY: The caller guarantees we hold the lock for the shard containing the value. + // + // Note that this ID includes the generation, unlike the ID provided by the table. + let id = unsafe { (*value.shared.get()).id }; + + (self.database_key_index(id), value) + }) } } @@ -842,6 +902,10 @@ where C::DEBUG_NAME } + fn jar_kind(&self) -> JarKind { + JarKind::Struct + } + fn memo_table_types(&self) -> &Arc { &self.memo_table_types } @@ -860,11 +924,13 @@ where unsafe { shard.raw().lock() }; } - let memory_usage = self - .entries(db.zalsa()) + // SAFETY: We hold the locks for all shards. + let entries = unsafe { self.entries_inner(false, db.zalsa()) }; + + let memory_usage = entries // SAFETY: The memo table belongs to a value that we allocated, so it // has the correct type. Additionally, we are holding the locks for all shards. 
- .map(|value| unsafe { value.memory_usage(&self.memo_table_types) }) + .map(|(_, value)| unsafe { value.memory_usage(&self.memo_table_types) }) .collect(); for shard in self.shards.iter() { @@ -874,6 +940,40 @@ where Some(memory_usage) } + + fn is_persistable(&self) -> bool { + C::PERSIST + } + + fn should_serialize(&self, zalsa: &Zalsa) -> bool { + C::PERSIST && self.entries(zalsa).next().is_some() + } + + #[cfg(feature = "persistence")] + unsafe fn serialize<'db>( + &'db self, + zalsa: &'db Zalsa, + f: &mut dyn FnMut(&dyn erased_serde::Serialize), + ) { + f(&persistence::SerializeIngredient { + zalsa, + _ingredient: self, + }) + } + + #[cfg(feature = "persistence")] + fn deserialize( + &mut self, + zalsa: &mut Zalsa, + deserializer: &mut dyn erased_serde::Deserializer, + ) -> Result<(), erased_serde::Error> { + let deserialize = persistence::DeserializeIngredient { + zalsa, + ingredient: self, + }; + + serde::de::DeserializeSeed::deserialize(deserialize, deserializer) + } } impl std::fmt::Debug for IngredientImpl @@ -1189,3 +1289,224 @@ impl Lookup for &Path { self.to_owned() } } + +#[cfg(feature = "persistence")] +mod persistence { + use std::cell::UnsafeCell; + use std::fmt; + use std::hash::BuildHasher; + + use intrusive_collections::LinkedListLink; + use serde::ser::SerializeMap; + use serde::{de, Deserialize}; + + use super::{Configuration, IngredientImpl, Value, ValueShared}; + use crate::plumbing::Ingredient; + use crate::table::memo::MemoTable; + use crate::zalsa::Zalsa; + use crate::{Durability, Id, Revision}; + + pub struct SerializeIngredient<'db, C> + where + C: Configuration, + { + pub zalsa: &'db Zalsa, + pub _ingredient: &'db IngredientImpl, + } + + impl serde::Serialize for SerializeIngredient<'_, C> + where + C: Configuration, + { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let Self { zalsa, .. 
} = self; + + let mut map = serializer.serialize_map(None)?; + + for (_, value) in zalsa.table().slots_of::>() { + // SAFETY: The safety invariant of `Ingredient::serialize` ensures we have exclusive access + // to the database. + let id = unsafe { (*value.shared.get()).id }; + + map.serialize_entry(&id.as_bits(), value)?; + } + + map.end() + } + } + + impl serde::Serialize for Value + where + C: Configuration, + { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let mut map = serializer.serialize_map(None)?; + + struct SerializeFields<'db, C: Configuration>(&'db C::Fields<'static>); + + impl serde::Serialize for SerializeFields<'_, C> + where + C: Configuration, + { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + C::serialize(self.0, serializer) + } + } + + let Value { + fields, + shared, + shard: _, + link: _, + memos: _, + } = self; + + // SAFETY: The safety invariant of `Ingredient::serialize` ensures we have exclusive access + // to the database. + let fields = unsafe { &*fields.get() }; + + // SAFETY: The safety invariant of `Ingredient::serialize` ensures we have exclusive access + // to the database. 
+ let ValueShared { + durability, + last_interned_at, + id: _, + } = unsafe { *shared.get() }; + + map.serialize_entry(&"durability", &durability)?; + map.serialize_entry(&"last_interned_at", &last_interned_at)?; + map.serialize_entry(&"fields", &SerializeFields::(fields))?; + + map.end() + } + } + + pub struct DeserializeIngredient<'db, C> + where + C: Configuration, + { + pub zalsa: &'db mut Zalsa, + pub ingredient: &'db mut IngredientImpl, + } + + impl<'de, C> de::DeserializeSeed<'de> for DeserializeIngredient<'_, C> + where + C: Configuration, + { + type Value = (); + + fn deserialize(self, deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_map(self) + } + } + + impl<'de, C> de::Visitor<'de> for DeserializeIngredient<'_, C> + where + C: Configuration, + { + type Value = (); + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a map") + } + + fn visit_map(self, mut access: M) -> Result + where + M: de::MapAccess<'de>, + { + let DeserializeIngredient { zalsa, ingredient } = self; + + while let Some((id, value)) = access.next_entry::>()? { + let id = Id::from_bits(id); + let (page_idx, _) = crate::table::split_id(id); + + // Determine the value shard. + let hash = ingredient.hasher.hash_one(&value.fields.0); + let shard_index = ingredient.shard(hash); + + // SAFETY: `shard_index` is guaranteed to be in-bounds for `self.shards`. + let shard = unsafe { &mut *ingredient.shards.get_unchecked(shard_index).lock() }; + + let value = Value:: { + shard: shard_index as u16, + link: LinkedListLink::new(), + // SAFETY: We only ever access the memos of a value that we allocated through + // our `MemoTableTypes`. 
+ memos: UnsafeCell::new(unsafe { + MemoTable::new(ingredient.memo_table_types()) + }), + fields: UnsafeCell::new(value.fields.0), + shared: UnsafeCell::new(ValueShared { + id, + durability: value.durability, + last_interned_at: value.last_interned_at, + }), + }; + + // Force initialize the relevant page. + zalsa.table_mut().force_page::>( + page_idx, + ingredient.ingredient_index(), + ingredient.memo_table_types(), + ); + + // Initialize the slot. + // + // SAFETY: We have a mutable reference to the database. + let (allocated_id, value) = unsafe { + zalsa + .table() + .page(page_idx) + .allocate(page_idx, |_| value) + .unwrap_or_else(|_| panic!("serialized an invalid `Id`: {id:?}")) + }; + + assert_eq!( + allocated_id, id, + "values are serialized in allocation order" + ); + + // Insert the newly allocated ID into our ingredient. + ingredient.insert_id(id, zalsa, shard, hash, value); + } + + Ok(()) + } + } + + #[derive(Deserialize)] + pub struct DeserializeValue { + durability: Durability, + last_interned_at: Revision, + #[serde(bound = "C: Configuration")] + fields: DeserializeFields, + } + + struct DeserializeFields(C::Fields<'static>); + + impl<'de, C> serde::Deserialize<'de> for DeserializeFields + where + C: Configuration, + { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + C::deserialize(deserializer) + .map(DeserializeFields) + .map_err(de::Error::custom) + } + } +} diff --git a/src/key.rs b/src/key.rs index 47f750e7f..bb5604c5f 100644 --- a/src/key.rs +++ b/src/key.rs @@ -10,6 +10,7 @@ use crate::Id; /// ordered and equatable but those orderings are arbitrary, and meant to be used /// only for inserting into maps and the like. 
#[derive(Copy, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "persistence", derive(serde::Serialize, serde::Deserialize))] pub struct DatabaseKeyIndex { key_index: Id, ingredient_index: IngredientIndex, diff --git a/src/lib.rs b/src/lib.rs index 66c346b20..d846cbd76 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -123,6 +123,25 @@ pub mod plumbing { }; pub use crate::zalsa_local::ZalsaLocal; + #[cfg(feature = "persistence")] + pub use serde; + + // A stub for `serde` used when persistence is disabled. + // + // We provide dummy types to avoid detecting features during macro expansion. + #[cfg(not(feature = "persistence"))] + pub mod serde { + pub trait Serializer { + type Ok; + type Error; + } + + pub trait Deserializer<'de> { + type Ok; + type Error; + } + } + #[cfg(feature = "accumulator")] pub mod accumulator { pub use crate::accumulator::{IngredientImpl, JarImpl}; diff --git a/src/memo_ingredient_indices.rs b/src/memo_ingredient_indices.rs index ba1dcf45d..a2df50dd1 100644 --- a/src/memo_ingredient_indices.rs +++ b/src/memo_ingredient_indices.rs @@ -42,6 +42,10 @@ impl IngredientIndices { indices: indices.into_boxed_slice(), } } + + pub fn iter(&self) -> impl Iterator + '_ { + self.indices.iter().copied() + } } pub trait NewMemoIngredientIndices { diff --git a/src/revision.rs b/src/revision.rs index 11cfd149c..852313e0d 100644 --- a/src/revision.rs +++ b/src/revision.rs @@ -54,6 +54,26 @@ impl Revision { } } +#[cfg(feature = "persistence")] +impl serde::Serialize for Revision { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serde::Serialize::serialize(&self.as_usize(), serializer) + } +} + +#[cfg(feature = "persistence")] +impl<'de> serde::Deserialize<'de> for Revision { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + serde::Deserialize::deserialize(deserializer).map(|generation| Self { generation }) + } +} + impl std::fmt::Debug for Revision { fn fmt(&self, fmt: &mut 
std::fmt::Formatter<'_>) -> std::fmt::Result { write!(fmt, "R{}", self.generation) @@ -65,6 +85,28 @@ pub(crate) struct AtomicRevision { data: AtomicUsize, } +#[cfg(feature = "persistence")] +impl serde::Serialize for AtomicRevision { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serde::Serialize::serialize(&self.data.load(Ordering::Relaxed), serializer) + } +} + +#[cfg(feature = "persistence")] +impl<'de> serde::Deserialize<'de> for AtomicRevision { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + serde::Deserialize::deserialize(deserializer).map(|data| Self { + data: AtomicUsize::new(data), + }) + } +} + impl From for AtomicRevision { fn from(value: Revision) -> Self { Self { @@ -97,6 +139,27 @@ impl AtomicRevision { pub(crate) struct OptionalAtomicRevision { data: AtomicUsize, } +#[cfg(feature = "persistence")] +impl serde::Serialize for OptionalAtomicRevision { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serde::Serialize::serialize(&self.data.load(Ordering::Relaxed), serializer) + } +} + +#[cfg(feature = "persistence")] +impl<'de> serde::Deserialize<'de> for OptionalAtomicRevision { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + serde::Deserialize::deserialize(deserializer).map(|data| Self { + data: AtomicUsize::new(data), + }) + } +} impl From for OptionalAtomicRevision { fn from(value: Revision) -> Self { diff --git a/src/runtime.rs b/src/runtime.rs index 6a4d1e8b8..ec3b091d5 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -11,10 +11,12 @@ use crate::{Cancelled, Event, EventKind, Revision}; mod dependency_graph; +#[cfg_attr(feature = "persistence", derive(serde::Serialize, serde::Deserialize))] pub struct Runtime { /// Set to true when the current revision has been canceled. /// This is done when we an input is being changed. The flag /// is set back to false once the input has been changed. 
+ #[cfg_attr(feature = "persistence", serde(skip))] revision_canceled: AtomicBool, /// Stores the "last change" revision for values of each duration. @@ -30,9 +32,11 @@ pub struct Runtime { /// The dependency graph tracks which runtimes are blocked on one /// another, waiting for queries to terminate. + #[cfg_attr(feature = "persistence", serde(skip))] dependency_graph: Mutex, /// Data for instances + #[cfg_attr(feature = "persistence", serde(skip))] table: Table, } @@ -263,4 +267,10 @@ impl Runtime { .lock() .unblock_runtimes_blocked_on(database_key, wait_result); } + + #[cfg(feature = "persistence")] + pub(crate) fn deserialize_from(&mut self, other: &mut Runtime) { + // The only field that is serialized is `revisions`. + self.revisions = other.revisions; + } } diff --git a/src/salsa_struct.rs b/src/salsa_struct.rs index cb3307e65..73010ef2b 100644 --- a/src/salsa_struct.rs +++ b/src/salsa_struct.rs @@ -3,7 +3,7 @@ use std::any::TypeId; use crate::memo_ingredient_indices::{IngredientIndices, MemoIngredientMap}; use crate::table::memo::MemoTableWithTypes; use crate::zalsa::Zalsa; -use crate::{Id, Revision}; +use crate::{DatabaseKeyIndex, Id, Revision}; pub trait SalsaStructInDb: Sized { type MemoIngredientMap: MemoIngredientMap; @@ -19,6 +19,9 @@ pub trait SalsaStructInDb: Sized { /// call [`crate::zalsa::JarEntry::get_or_create`] for their variants and combine them. fn lookup_ingredient_index(zalsa: &Zalsa) -> IngredientIndices; + /// Returns the IDs of any instances of this struct in the database. + fn entries(zalsa: &Zalsa) -> impl Iterator + '_; + /// Plumbing to support nested salsa supertypes. 
/// /// In the example below, there are two supertypes: `InnerEnum` and `OuterEnum`, diff --git a/src/sync.rs b/src/sync.rs index e3472d2da..2498cd1a7 100644 --- a/src/sync.rs +++ b/src/sync.rs @@ -77,10 +77,6 @@ pub mod shim { self.get().unwrap() } - pub fn set(&self, value: T) -> Result<(), T> { - self.set_with(|| value).map_err(|f| f()) - } - fn set_with(&self, f: F) -> Result<(), F> where F: FnOnce() -> T, diff --git a/src/table.rs b/src/table.rs index c6d22118b..53cf10cce 100644 --- a/src/table.rs +++ b/src/table.rs @@ -127,7 +127,7 @@ unsafe impl Send for Page /* where for M: Send */ {} // requires `Sync`.` unsafe impl Sync for Page /* where for M: Sync */ {} -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] pub struct PageIndex(usize); impl PageIndex { @@ -136,10 +136,15 @@ impl PageIndex { debug_assert!(idx < MAX_PAGES); Self(idx) } + + #[allow(dead_code)] + pub fn as_usize(&self) -> usize { + self.0 + } } #[derive(Copy, Clone, Debug)] -struct SlotIndex(usize); +pub struct SlotIndex(usize); impl SlotIndex { #[inline] @@ -192,6 +197,11 @@ impl Table { page_ref.page_data()[slot.0].get().cast::() } + /// Returns the number of pages that have been allocated. + pub fn page_count(&self) -> usize { + self.pages.count() + } + /// Gets a reference to the page which has slots of type `T` /// /// # Panics @@ -202,6 +212,51 @@ impl Table { self.pages[page.0].assert_type::() } + /// Force initialize the page at the given index. + /// + /// If the page at the provided index was created using `push_uninit_page`, it + /// will be initialized using the provided ingredient data. + /// + /// Otherwise, the page will be allocated. + /// + /// # Panics + /// + /// If `page` is out of bounds or the type `T` is incorrect. 
+ #[inline] + #[allow(dead_code)] + pub(crate) fn force_page( + &mut self, + page_idx: PageIndex, + ingredient: IngredientIndex, + memo_types: &Arc, + ) { + let page = self.pages.get_mut(page_idx.0); + + match page { + Some(page) => { + // Initialize the page if was created using `push_uninit_page`. + if page.slot_type_id == TypeId::of::() { + *page = Page::new::(ingredient, memo_types.clone()); + } + + // Ensure the page has the correct type. + page.assert_type::(); + } + + None => { + // Create dummy pages until we reach the page we want. + while self.page_count() < page_idx.as_usize() { + // We make sure not to claim any intermediary pages for ourselves, as they may + // be required by a different ingredient when it is deserialized. + self.push_uninit_page(); + } + + let allocated_idx = self.push_page::(ingredient, memo_types.clone()); + assert_eq!(allocated_idx, page_idx); + } + }; + } + /// Allocate a new page for the given ingredient and with slots of type `T` #[inline] pub(crate) fn push_page( @@ -212,6 +267,18 @@ impl Table { PageIndex::new(self.pages.push(Page::new::(ingredient, memo_types))) } + /// Allocate an uninitialized page. + #[inline] + #[allow(dead_code)] + pub(crate) fn push_uninit_page(&self) -> PageIndex { + // Note that `DummySlot` is a ZST, so the memory wasted by any pages of ingredients + // that were not serialized should be negligible. + PageIndex::new(self.pages.push(Page::new::( + IngredientIndex::new(0), + Arc::new(MemoTableTypes::default()), + ))) + } + /// Get the memo table associated with `id` for the concrete type `T`. 
/// /// # Safety @@ -270,12 +337,19 @@ impl Table { unsafe { page.memo_types.attach_memos_mut(memos) } } - #[cfg(feature = "salsa_unstable")] - pub(crate) fn slots_of(&self) -> impl Iterator + '_ { + pub(crate) fn slots_of(&self) -> impl Iterator + '_ { self.pages .iter() - .filter_map(|(_, page)| page.cast_type::()) - .flat_map(|view| view.data()) + .filter_map(|(page_index, page)| Some((page_index, page.cast_type::()?))) + .flat_map(move |(page_index, view)| { + view.data() + .iter() + .enumerate() + .map(move |(slot_index, value)| { + let id = make_id(PageIndex::new(page_index), SlotIndex::new(slot_index)); + (id, value) + }) + }) } #[cold] @@ -296,7 +370,6 @@ impl Table { self.push_page::(ingredient, memo_types()) } - pub(crate) fn record_unfilled_page(&self, ingredient: IngredientIndex, page: PageIndex) { self.non_full_pages .lock() @@ -417,7 +490,6 @@ impl Page { PageView(self, PhantomData) } - #[cfg(feature = "salsa_unstable")] fn cast_type(&self) -> Option> { if self.slot_type_id == TypeId::of::() { Some(PageView(self, PhantomData)) @@ -446,6 +518,20 @@ impl Drop for Page { } } +/// A placeholder type representing the slots of an uninitialized `Page`. +struct DummySlot; + +// SAFETY: The `DummySlot type is private. 
+unsafe impl Slot for DummySlot { + unsafe fn memos(&self, _: Revision) -> &MemoTable { + unreachable!() + } + + fn memos_mut(&mut self) -> &mut MemoTable { + unreachable!() + } +} + fn make_id(page: PageIndex, slot: SlotIndex) -> Id { let page = page.0 as u32; let slot = slot.0 as u32; @@ -454,7 +540,7 @@ fn make_id(page: PageIndex, slot: SlotIndex) -> Id { } #[inline] -fn split_id(id: Id) -> (PageIndex, SlotIndex) { +pub fn split_id(id: Id) -> (PageIndex, SlotIndex) { let index = id.index() as usize; let slot = index & PAGE_LEN_MASK; let page = index >> PAGE_LEN_BITS; diff --git a/src/table/memo.rs b/src/table/memo.rs index 7e4837aa1..3b6366934 100644 --- a/src/table/memo.rs +++ b/src/table/memo.rs @@ -64,7 +64,7 @@ pub trait Memo: Any + Send + Sync { /// Therefore, we hide the type by transmuting to `DummyMemo`; but we must then be very careful /// when freeing `MemoEntryData` values to transmute things back. See the `Drop` impl for /// [`MemoEntry`][] for details. -#[derive(Default)] +#[derive(Default, Debug)] struct MemoEntry { /// An [`AtomicPtr`][] to a `Box` for the erased memo type `M` atomic_memo: AtomicPtr, diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index 116e96c00..451533d07 100644 --- a/src/tracked_struct.rs +++ b/src/tracked_struct.rs @@ -14,14 +14,14 @@ use crate::function::{VerifyCycleHeads, VerifyResult}; use crate::id::{AsId, FromId}; use crate::ingredient::{Ingredient, Jar}; use crate::key::DatabaseKeyIndex; -use crate::plumbing::ZalsaLocal; +use crate::plumbing::{self, ZalsaLocal}; use crate::revision::OptionalAtomicRevision; use crate::runtime::Stamp; use crate::salsa_struct::SalsaStructInDb; use crate::sync::Arc; use crate::table::memo::{MemoTable, MemoTableTypes, MemoTableWithTypesMut}; use crate::table::{Slot, Table}; -use crate::zalsa::{IngredientIndex, Zalsa}; +use crate::zalsa::{IngredientIndex, JarKind, Zalsa}; use crate::{Durability, Event, EventKind, Id, Revision}; pub mod tracked_field; @@ -43,6 +43,9 @@ pub trait 
Configuration: Sized + 'static { /// The relative indices of any tracked fields. const TRACKED_FIELD_INDICES: &'static [usize]; + /// Whether this struct should be persisted with the database. + const PERSIST: bool; + /// A (possibly empty) tuple of the fields for this struct. type Fields<'db>: Send + Sync; @@ -50,6 +53,14 @@ pub trait Configuration: Sized + 'static { /// When a struct is re-recreated in a new revision, the corresponding /// entries for each field are updated to the new revision if their /// values have changed (or if the field is marked as `#[no_eq]`). + #[cfg(feature = "persistence")] + type Revisions: Send + + Sync + + Index + + plumbing::serde::Serialize + + for<'de> plumbing::serde::Deserialize<'de>; + + #[cfg(not(feature = "persistence"))] type Revisions: Send + Sync + Index; type Struct<'db>: Copy + FromId + AsId; @@ -94,6 +105,20 @@ pub trait Configuration: Sized + 'static { fn heap_size(_value: &Self::Fields<'_>) -> Option { None } + + /// Serialize the fields using `serde`. + /// + /// Panics if the value is not persistable, i.e. `Configuration::PERSIST` is `false`. + fn serialize(value: &Self::Fields<'_>, serializer: S) -> Result + where + S: plumbing::serde::Serializer; + + /// Deserialize the fields using `serde`. + /// + /// Panics if the value is not persistable, i.e. `Configuration::PERSIST` is `false`. + fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> + where + D: plumbing::serde::Deserializer<'de>; } // ANCHOR_END: Configuration @@ -179,6 +204,7 @@ where /// stored in the [`ActiveQuery`](`crate::active_query::ActiveQuery`) /// struct and later moved to the [`Memo`](`crate::function::memo::Memo`). 
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Copy, Clone)] +#[cfg_attr(feature = "persistence", derive(serde::Serialize, serde::Deserialize))] pub(crate) struct Identity { // Conceptually, this contains an `IdentityHash`, but using `IdentityHash` directly will grow the size // of this struct struct by a `std::mem::size_of::()` due to unusable padding. To avoid this increase @@ -321,6 +347,7 @@ where // ANCHOR_END: ValueStruct #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Copy, Clone)] +#[cfg_attr(feature = "persistence", derive(serde::Serialize, serde::Deserialize))] pub struct Disambiguator(u32); #[derive(Default, Debug)] @@ -796,10 +823,15 @@ where data.fields() } - #[cfg(feature = "salsa_unstable")] /// Returns all data corresponding to the tracked struct. - pub fn entries<'db>(&'db self, zalsa: &'db Zalsa) -> impl Iterator> { - zalsa.table().slots_of::>() + pub fn entries<'db>( + &'db self, + zalsa: &'db Zalsa, + ) -> impl Iterator)> + 'db { + zalsa + .table() + .slots_of::>() + .map(|(id, value)| (self.database_key_index(id), value)) } } @@ -855,6 +887,10 @@ where C::DEBUG_NAME } + fn jar_kind(&self) -> JarKind { + JarKind::Struct + } + fn memo_table_types(&self) -> &Arc { &self.memo_table_types } @@ -870,10 +906,45 @@ where .entries(db.zalsa()) // SAFETY: The memo table belongs to a value that we allocated, so it // has the correct type. 
- .map(|value| unsafe { value.memory_usage(&self.memo_table_types) }) + .map(|(_, value)| unsafe { value.memory_usage(&self.memo_table_types) }) .collect(); + Some(memory_usage) } + + fn is_persistable(&self) -> bool { + C::PERSIST + } + + fn should_serialize(&self, zalsa: &Zalsa) -> bool { + C::PERSIST && self.entries(zalsa).next().is_some() + } + + #[cfg(feature = "persistence")] + unsafe fn serialize<'db>( + &'db self, + zalsa: &'db Zalsa, + f: &mut dyn FnMut(&dyn erased_serde::Serialize), + ) { + f(&persistence::SerializeIngredient { + zalsa, + _ingredient: self, + }) + } + + #[cfg(feature = "persistence")] + fn deserialize( + &mut self, + zalsa: &mut Zalsa, + deserializer: &mut dyn erased_serde::Deserializer, + ) -> Result<(), erased_serde::Error> { + let deserialize = persistence::DeserializeIngredient { + zalsa, + ingredient: self, + }; + + serde::de::DeserializeSeed::deserialize(deserialize, deserializer) + } } impl std::fmt::Debug for IngredientImpl @@ -1072,3 +1143,192 @@ mod tests { }; } } + +#[cfg(feature = "persistence")] +mod persistence { + use std::fmt; + + use serde::ser::SerializeMap; + use serde::{de, Deserialize}; + + use super::{Configuration, IngredientImpl, Value}; + use crate::plumbing::Ingredient; + use crate::revision::OptionalAtomicRevision; + use crate::table::memo::MemoTable; + use crate::zalsa::Zalsa; + use crate::{Durability, Id}; + + pub struct SerializeIngredient<'db, C> + where + C: Configuration, + { + pub zalsa: &'db Zalsa, + pub _ingredient: &'db IngredientImpl, + } + + impl serde::Serialize for SerializeIngredient<'_, C> + where + C: Configuration, + { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let Self { zalsa, .. 
} = self; + + let mut map = serializer.serialize_map(None)?; + + for (id, value) in zalsa.table().slots_of::>() { + map.serialize_entry(&id.as_bits(), value)?; + } + + map.end() + } + } + + impl serde::Serialize for Value + where + C: Configuration, + { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let mut map = serializer.serialize_map(None)?; + + struct SerializeFields<'db, C: Configuration>(&'db C::Fields<'static>); + + impl serde::Serialize for SerializeFields<'_, C> + where + C: Configuration, + { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + C::serialize(self.0, serializer) + } + } + + let Value { + durability, + updated_at, + fields, + revisions, + memos: _, + } = self; + + map.serialize_entry(&"durability", &durability)?; + map.serialize_entry(&"updated_at", &updated_at)?; + map.serialize_entry(&"revisions", &revisions)?; + map.serialize_entry(&"fields", &SerializeFields::(fields))?; + + map.end() + } + } + + pub struct DeserializeIngredient<'db, C> + where + C: Configuration, + { + pub zalsa: &'db mut Zalsa, + pub ingredient: &'db mut IngredientImpl, + } + + impl<'de, C> de::DeserializeSeed<'de> for DeserializeIngredient<'_, C> + where + C: Configuration, + { + type Value = (); + + fn deserialize(self, deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_map(self) + } + } + + impl<'de, C> de::Visitor<'de> for DeserializeIngredient<'_, C> + where + C: Configuration, + { + type Value = (); + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a map") + } + + fn visit_map(self, mut access: M) -> Result + where + M: de::MapAccess<'de>, + { + let DeserializeIngredient { zalsa, ingredient } = self; + + while let Some((id, value)) = access.next_entry::>()? 
{ + let id = Id::from_bits(id); + let (page_idx, _) = crate::table::split_id(id); + + let value = Value:: { + updated_at: value.updated_at, + durability: value.durability, + fields: value.fields.0, + revisions: value.revisions, + // SAFETY: We only ever access the memos of a value that we allocated through + // our `MemoTableTypes`. + memos: unsafe { MemoTable::new(ingredient.memo_table_types()) }, + }; + + // Force initialize the relevant page. + zalsa.table_mut().force_page::>( + page_idx, + ingredient.ingredient_index(), + ingredient.memo_table_types(), + ); + + // Initialize the slot. + // + // SAFETY: We have a mutable reference to the database. + let (allocated_id, _) = unsafe { + zalsa + .table() + .page(page_idx) + .allocate(page_idx, |_| value) + .unwrap_or_else(|_| panic!("serialized an invalid `Id`: {id:?}")) + }; + + assert_eq!( + allocated_id, id, + "values are serialized in allocation order" + ); + } + + Ok(()) + } + } + + #[derive(Deserialize)] + pub struct DeserializeValue { + durability: Durability, + updated_at: OptionalAtomicRevision, + revisions: C::Revisions, + #[serde(bound = "C: Configuration")] + fields: DeserializeFields, + } + + struct DeserializeFields(C::Fields<'static>); + + impl<'de, C> serde::Deserialize<'de> for DeserializeFields + where + C: Configuration, + { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + C::deserialize(deserializer) + .map(DeserializeFields) + .map_err(de::Error::custom) + } + } +} diff --git a/src/tracked_struct/tracked_field.rs b/src/tracked_struct/tracked_field.rs index 0d565bcfd..20c59b81c 100644 --- a/src/tracked_struct/tracked_field.rs +++ b/src/tracked_struct/tracked_field.rs @@ -5,7 +5,7 @@ use crate::ingredient::Ingredient; use crate::sync::Arc; use crate::table::memo::MemoTableTypes; use crate::tracked_struct::{Configuration, Value}; -use crate::zalsa::IngredientIndex; +use crate::zalsa::{IngredientIndex, JarKind}; use crate::Id; /// Created for each tracked 
struct. @@ -81,6 +81,10 @@ where C::TRACKED_FIELD_NAMES[self.field_index] } + fn jar_kind(&self) -> JarKind { + JarKind::Struct + } + fn memo_table_types(&self) -> &Arc { unreachable!("tracked field does not allocate pages") } diff --git a/src/zalsa.rs b/src/zalsa.rs index 1cc6ba5f5..118c890d8 100644 --- a/src/zalsa.rs +++ b/src/zalsa.rs @@ -73,6 +73,7 @@ static NONCE: crate::nonce::NonceGenerator = crate::nonce::NonceGe /// The database contains a number of jars, and each jar contains a number of ingredients. /// Each ingredient is given a unique index as the database is being created. #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)] +#[cfg_attr(feature = "persistence", derive(serde::Serialize, serde::Deserialize))] pub struct IngredientIndex(u32); impl IngredientIndex { @@ -207,7 +208,10 @@ impl Zalsa { let mut jars = jars; // Ensure structs are initialized before tracked functions. - jars.sort_by_key(|jar| jar.kind); + // + // We also further sort by debug name, to maintain a consistent ordering across + // builds. + jars.sort_by(|a, b| a.kind.cmp(&b.kind).then(a.type_name().cmp(b.type_name()))); for jar in jars { zalsa.insert_jar(jar); @@ -235,6 +239,13 @@ impl Zalsa { self.runtime.table() } + /// Returns a mutable reference to the [`Table`] used to store the value of salsa structs + #[inline] + #[allow(dead_code)] + pub(crate) fn table_mut(&mut self) -> &mut Table { + self.runtime.table_mut() + } + /// Returns the [`MemoTable`][] for the salsa struct with the given id pub(crate) fn memo_table_for(&self, id: Id) -> MemoTableWithTypes<'_> { // SAFETY: We are supplying the correct current revision. 
@@ -273,7 +284,7 @@ impl Zalsa { [memo_ingredient_index.as_usize()] } - #[cfg(feature = "salsa_unstable")] + #[allow(unused)] pub(crate) fn ingredients(&self) -> impl Iterator { self.ingredients_vec .iter() @@ -396,6 +407,19 @@ impl Zalsa { (ingredient.as_mut(), &mut self.runtime) } + /// **NOT SEMVER STABLE** + #[doc(hidden)] + pub fn take_ingredient(&mut self, index: IngredientIndex) -> Box { + self.ingredients_vec.remove(index.as_u32() as usize) + } + + /// **NOT SEMVER STABLE** + #[doc(hidden)] + pub fn replace_ingredient(&mut self, index: IngredientIndex, ingredient: Box) { + self.ingredients_vec + .insert(index.as_u32() as usize, ingredient); + } + /// **NOT SEMVER STABLE** #[doc(hidden)] #[inline] @@ -469,6 +493,7 @@ impl Zalsa { pub struct ErasedJar { kind: JarKind, type_id: fn() -> TypeId, + type_name: fn() -> &'static str, id_struct_type_id: fn() -> TypeId, create_ingredients: fn(&mut Zalsa, IngredientIndex) -> Vec>, } @@ -493,10 +518,15 @@ impl ErasedJar { Self { kind: I::KIND, type_id: TypeId::of::, + type_name: std::any::type_name::, create_ingredients: ::create_ingredients, id_struct_type_id: ::id_struct_type_id, } } + + pub fn type_name(&self) -> &'static str { + (self.type_name)() + } } /// A salsa ingredient that can be registered in the database. diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs index 59b2165b3..bc82a2057 100644 --- a/src/zalsa_local.rs +++ b/src/zalsa_local.rs @@ -418,6 +418,7 @@ impl std::panic::RefUnwindSafe for ZalsaLocal {} /// Summarizes "all the inputs that a query used" /// and "all the outputs it has written to" #[derive(Debug)] +#[cfg_attr(feature = "persistence", derive(serde::Serialize, serde::Deserialize))] // #[derive(Clone)] cloning this is expensive, so we don't derive pub(crate) struct QueryRevisions { /// The most revision in which some input changed. 
@@ -435,6 +436,8 @@ pub(crate) struct QueryRevisions { /// Note that this field could be in `QueryRevisionsExtra` as it is only relevant /// for accumulators, but we get it for free anyways due to padding. #[cfg(feature = "accumulator")] + // TODO: Support serializing accumulators. + #[cfg_attr(feature = "persistence", serde(skip))] pub(super) accumulated_inputs: AtomicInputAccumulatedValues, /// Are the `cycle_heads` verified to not be provisional anymore? @@ -442,12 +445,33 @@ pub(crate) struct QueryRevisions { /// Note that this field could be in `QueryRevisionsExtra` as it is only /// relevant for queries that participate in a cycle, but we get it for /// free anyways due to padding. + #[cfg_attr(feature = "persistence", serde(with = "verified_final"))] pub(super) verified_final: AtomicBool, /// Lazily allocated state. pub(super) extra: QueryRevisionsExtra, } +#[cfg(feature = "persistence")] +// A workaround the fact that `shuttle` atomic types do not implement `serde::{Serialize, Deserialize}`. +mod verified_final { + use crate::sync::atomic::{AtomicBool, Ordering}; + + pub fn serialize(value: &AtomicBool, serializer: S) -> Result + where + S: serde::Serializer, + { + serde::Serialize::serialize(&value.load(Ordering::Relaxed), serializer) + } + + pub fn deserialize<'de, D>(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + serde::Deserialize::deserialize(deserializer).map(AtomicBool::new) + } +} + impl QueryRevisions { #[cfg(feature = "salsa_unstable")] pub(crate) fn allocation_size(&self) -> usize { @@ -484,6 +508,7 @@ impl QueryRevisions { /// In particular, not all queries create tracked structs, participate /// in cycles, or create accumulators. 
#[derive(Debug, Default)] +#[cfg_attr(feature = "persistence", derive(serde::Serialize, serde::Deserialize))] pub(crate) struct QueryRevisionsExtra(Option>); impl QueryRevisionsExtra { @@ -518,8 +543,11 @@ impl QueryRevisionsExtra { } #[derive(Debug)] +#[cfg_attr(feature = "persistence", derive(serde::Serialize, serde::Deserialize))] struct QueryRevisionsExtraInner { #[cfg(feature = "accumulator")] + // TODO: Support serializing accumulators. + #[cfg_attr(feature = "persistence", serde(skip))] accumulated: AccumulatedMap, /// The ids of tracked structs created by this query. @@ -539,6 +567,10 @@ struct QueryRevisionsExtraInner { /// previous revision. To handle this, `diff_outputs` compares /// the structs from the old/new revision and retains /// only entries that appeared in the new revision. + // + // TODO: We only need to serialize the IDs of tracked structs that + // are actually going to be serialized. Those that are not will + // be created with new IDs anyways. tracked_struct_ids: ThinVec<(Identity, Id)>, /// This result was computed based on provisional values from @@ -680,8 +712,9 @@ impl QueryRevisions { /// Tracks the way that a memoized value for a query was created. /// /// This is a read-only reference to a `PackedQueryOrigin`. -#[derive(Debug, Clone, Copy)] #[repr(u8)] +#[derive(Debug, Clone, Copy)] +#[cfg_attr(feature = "persistence", derive(serde::Serialize))] pub enum QueryOriginRef<'a> { /// The value was assigned as the output of another query (e.g., using `specify`). /// The `DatabaseKeyIndex` is the identity of the assigning query. 
@@ -909,6 +942,41 @@ impl QueryOrigin { } } +#[cfg(feature = "persistence")] +impl serde::Serialize for QueryOrigin { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + self.as_ref().serialize(serializer) + } +} + +#[cfg(feature = "persistence")] +impl<'de> serde::Deserialize<'de> for QueryOrigin { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + // Matches the signature of `QueryOriginRef`. + #[derive(serde::Deserialize)] + #[repr(u8)] + pub enum QueryOriginOwned { + Assigned(DatabaseKeyIndex) = QueryOriginKind::Assigned as u8, + Derived(Box<[QueryEdge]>) = QueryOriginKind::Derived as u8, + DerivedUntracked(Box<[QueryEdge]>) = QueryOriginKind::DerivedUntracked as u8, + FixpointInitial = QueryOriginKind::FixpointInitial as u8, + } + + Ok(match QueryOriginOwned::deserialize(deserializer)? { + QueryOriginOwned::Assigned(key) => QueryOrigin::assigned(key), + QueryOriginOwned::Derived(edges) => QueryOrigin::derived(edges), + QueryOriginOwned::DerivedUntracked(edges) => QueryOrigin::derived_untracked(edges), + QueryOriginOwned::FixpointInitial => QueryOrigin::fixpoint_initial(), + }) + } +} + impl Drop for QueryOrigin { fn drop(&mut self) { match self.kind { @@ -948,6 +1016,7 @@ impl std::fmt::Debug for QueryOrigin { /// the size of the type. Notably, this type is 12 bytes as opposed to the 16 byte /// `QueryEdgeKind`, which is meaningful as inputs and outputs are stored contiguously. #[derive(Copy, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "persistence", derive(serde::Serialize, serde::Deserialize))] pub struct QueryEdge { key: DatabaseKeyIndex, } @@ -967,18 +1036,21 @@ impl QueryEdge { } } - /// Returns the kind of this query edge. - pub fn kind(self) -> QueryEdgeKind { + /// Return the key of this query edge. + pub fn key(self) -> DatabaseKeyIndex { // Clear the tag to restore the original index. 
- let untagged = DatabaseKeyIndex::new( + DatabaseKeyIndex::new( self.key.ingredient_index().with_tag(false), self.key.key_index(), - ); + ) + } + /// Returns the kind of this query edge. + pub fn kind(self) -> QueryEdgeKind { if self.key.ingredient_index().tag() { - QueryEdgeKind::Output(untagged) + QueryEdgeKind::Output(self.key()) } else { - QueryEdgeKind::Input(untagged) + QueryEdgeKind::Input(self.key()) } } } diff --git a/tests/compile-fail/incomplete_persistence.rs b/tests/compile-fail/incomplete_persistence.rs new file mode 100644 index 000000000..c2ab310ee --- /dev/null +++ b/tests/compile-fail/incomplete_persistence.rs @@ -0,0 +1,14 @@ +#[salsa::tracked(persist)] +struct Persistable<'db> { + field: NotPersistable<'db>, +} + +#[salsa::tracked] +struct NotPersistable<'db> { + field: usize, +} + +#[salsa::tracked(persist)] +fn query(_db: &dyn salsa::Database, _input: NotPersistable<'_>) {} + +fn main() {} diff --git a/tests/compile-fail/incomplete_persistence.stderr b/tests/compile-fail/incomplete_persistence.stderr new file mode 100644 index 000000000..ded998277 --- /dev/null +++ b/tests/compile-fail/incomplete_persistence.stderr @@ -0,0 +1,100 @@ +error[E0277]: the trait bound `NotPersistable<'_>: Serialize` is not satisfied + --> tests/compile-fail/incomplete_persistence.rs:1:1 + | +1 | #[salsa::tracked(persist)] + | ^^^^^^^^^^^^^^^^^^^^^^^^^^ + | | + | the trait `Serialize` is not implemented for `NotPersistable<'_>` + | required by a bound introduced by this call + | + = note: for local types consider adding `#[derive(serde::Serialize)]` to your `NotPersistable<'_>` type + = note: for types from other crates check whether the crate offers a `serde` feature flag + = help: the following other types implement trait `Serialize`: + &'a T + &'a mut T + () + (T,) + (T0, T1) + (T0, T1, T2) + (T0, T1, T2, T3) + (T0, T1, T2, T3, T4) + and $N others + = note: required for `(NotPersistable<'_>,)` to implement `Serialize` + = note: this error originates in the macro 
`salsa::plumbing::setup_tracked_struct` which comes from the expansion of the attribute macro `salsa::tracked` (in Nightly builds, run with -Z macro-backtrace for more info) + +error[E0277]: the trait bound `NotPersistable<'_>: Deserialize<'_>` is not satisfied + --> tests/compile-fail/incomplete_persistence.rs:1:1 + | +1 | #[salsa::tracked(persist)] + | ^^^^^^^^^^^^^^^^^^^^^^^^^^ the trait `Deserialize<'_>` is not implemented for `NotPersistable<'_>` + | + = note: for local types consider adding `#[derive(serde::Deserialize)]` to your `NotPersistable<'_>` type + = note: for types from other crates check whether the crate offers a `serde` feature flag + = help: the following other types implement trait `Deserialize<'de>`: + &'a Path + &'a [u8] + &'a str + () + (T,) + (T0, T1) + (T0, T1, T2) + (T0, T1, T2, T3) + and $N others + = note: required for `(NotPersistable<'_>,)` to implement `Deserialize<'_>` + = note: this error originates in the macro `salsa::plumbing::setup_tracked_struct` which comes from the expansion of the attribute macro `salsa::tracked` (in Nightly builds, run with -Z macro-backtrace for more info) + +error[E0277]: the trait bound `NotPersistable<'db>: Serialize` is not satisfied + --> tests/compile-fail/incomplete_persistence.rs:12:45 + | +12 | fn query(_db: &dyn salsa::Database, _input: NotPersistable<'_>) {} + | ^^^^^^^^^^^^^^^^^^ the trait `Serialize` is not implemented for `NotPersistable<'db>` + | + = note: for local types consider adding `#[derive(serde::Serialize)]` to your `NotPersistable<'db>` type + = note: for types from other crates check whether the crate offers a `serde` feature flag + = help: the following other types implement trait `Serialize`: + &'a T + &'a mut T + () + (T,) + (T0, T1) + (T0, T1, T2) + (T0, T1, T2, T3) + (T0, T1, T2, T3, T4) + and $N others +note: required by a bound in `query_input_is_persistable` + --> tests/compile-fail/incomplete_persistence.rs:11:1 + | +11 | #[salsa::tracked(persist)] + | 
^^^^^^^^^^^^^^^^^^^^^^^^^^ + | | + | required by a bound in this function + | required by this bound in `query_input_is_persistable` + = note: this error originates in the macro `salsa::plumbing::setup_tracked_fn` which comes from the expansion of the attribute macro `salsa::tracked` (in Nightly builds, run with -Z macro-backtrace for more info) + +error[E0277]: the trait bound `for<'de> NotPersistable<'db>: Deserialize<'de>` is not satisfied + --> tests/compile-fail/incomplete_persistence.rs:12:45 + | +12 | fn query(_db: &dyn salsa::Database, _input: NotPersistable<'_>) {} + | ^^^^^^^^^^^^^^^^^^ the trait `for<'de> Deserialize<'de>` is not implemented for `NotPersistable<'db>` + | + = note: for local types consider adding `#[derive(serde::Deserialize)]` to your `NotPersistable<'db>` type + = note: for types from other crates check whether the crate offers a `serde` feature flag + = help: the following other types implement trait `Deserialize<'de>`: + &'a Path + &'a [u8] + &'a str + () + (T,) + (T0, T1) + (T0, T1, T2) + (T0, T1, T2, T3) + and $N others +note: required by a bound in `query_input_is_persistable` + --> tests/compile-fail/incomplete_persistence.rs:11:1 + | +11 | #[salsa::tracked(persist)] + | ^^^^^^^^^^^^^^^^^^^^^^^^^^ + | | + | required by a bound in this function + | required by this bound in `query_input_is_persistable` + = note: this error originates in the macro `salsa::plumbing::setup_tracked_fn` which comes from the expansion of the attribute macro `salsa::tracked` (in Nightly builds, run with -Z macro-backtrace for more info) diff --git a/tests/compile-fail/invalid_persist_options.rs b/tests/compile-fail/invalid_persist_options.rs new file mode 100644 index 000000000..4464532ff --- /dev/null +++ b/tests/compile-fail/invalid_persist_options.rs @@ -0,0 +1,56 @@ +#[salsa::input(persist)] +struct Input { + text: String, +} + +#[salsa::input(persist())] +struct Input2 { + text: String, +} + +#[salsa::input(persist(serialize = 
serde::Serialize::serialize))] +struct Input3 { + text: String, +} + +#[salsa::input(persist(deserialize = serde::Deserialize::deserialize))] +struct Input4 { + text: String, +} + +#[salsa::input(persist(serialize = serde::Serialize::serialize, deserialize = serde::Deserialize::deserialize))] +struct Input5 { + text: String, +} + +#[salsa::input(persist(serialize = serde::Serialize::serialize, serialize = serde::Serialize::serialize))] +struct InvalidInput { + text: String, +} + +#[salsa::input(persist(deserialize = serde::Deserialize::deserialize, deserialize = serde::Deserialize::deserialize))] +struct InvalidInput2 { + text: String, +} + +#[salsa::input(persist(not_an_option = std::convert::identity))] +struct InvalidInput3 { + text: String, +} + +#[salsa::tracked(persist)] +fn tracked_fn(db: &dyn salsa::Database, input: Input) -> String { + input.text(db) +} + +#[salsa::tracked(persist())] +fn tracked_fn2(db: &dyn salsa::Database, input: Input) -> String { + input.text(db) +} + +#[salsa::tracked(persist(serialize = serde::Serialize::serialize, deserialize = serde::Deserialize::deserialize))] +fn invalid_tracked_fn(db: &dyn salsa::Database, input: Input) -> String { + input.text(db) +} + +fn main() {} diff --git a/tests/compile-fail/invalid_persist_options.stderr b/tests/compile-fail/invalid_persist_options.stderr new file mode 100644 index 000000000..f616b6d59 --- /dev/null +++ b/tests/compile-fail/invalid_persist_options.stderr @@ -0,0 +1,23 @@ +error: option `serialize` provided twice + --> tests/compile-fail/invalid_persist_options.rs:26:65 + | +26 | #[salsa::input(persist(serialize = serde::Serialize::serialize, serialize = serde::Serialize::serialize))] + | ^^^^^^^^^ + +error: option `deserialize` provided twice + --> tests/compile-fail/invalid_persist_options.rs:31:71 + | +31 | #[salsa::input(persist(deserialize = serde::Deserialize::deserialize, deserialize = serde::Deserialize::deserialize))] + | ^^^^^^^^^^^ + +error: unexpected argument + --> 
tests/compile-fail/invalid_persist_options.rs:36:24 + | +36 | #[salsa::input(persist(not_an_option = std::convert::identity))] + | ^^^^^^^^^^^^^ + +error: unexpected argument + --> tests/compile-fail/invalid_persist_options.rs:51:26 + | +51 | #[salsa::tracked(persist(serialize = serde::Serialize::serialize, deserialize = serde::Deserialize::deserialize))] + | ^^^^^^^^^ diff --git a/tests/compile_fail.rs b/tests/compile_fail.rs index 73f87ee52..3648f756d 100644 --- a/tests/compile_fail.rs +++ b/tests/compile_fail.rs @@ -1,4 +1,4 @@ -#![cfg(feature = "inventory")] +#![cfg(all(feature = "inventory", feature = "persistence"))] #[rustversion::all(stable, since(1.89))] #[test] diff --git a/tests/debug_db_contents.rs b/tests/debug_db_contents.rs index a253d8869..4160a3e32 100644 --- a/tests/debug_db_contents.rs +++ b/tests/debug_db_contents.rs @@ -31,6 +31,7 @@ fn execute() { // test interned structs let interned = InternedStruct::ingredient(db.zalsa()) .entries(db.zalsa()) + .map(|(_, value)| value) .collect::>(); assert_eq!(interned.len(), 2); @@ -42,6 +43,7 @@ fn execute() { let inputs = InputStruct::ingredient(&db) .entries(db.zalsa()) + .map(|(_, value)| value) .collect::>(); assert_eq!(inputs.len(), 1); @@ -52,6 +54,7 @@ fn execute() { assert_eq!(computed, 44); let tracked = TrackedStruct::ingredient(&db) .entries(db.zalsa()) + .map(|(_, value)| value) .collect::>(); assert_eq!(tracked.len(), 1); diff --git a/tests/interned-structs_self_ref.rs b/tests/interned-structs_self_ref.rs index 3443f3ac2..01ff914c0 100644 --- a/tests/interned-structs_self_ref.rs +++ b/tests/interned-structs_self_ref.rs @@ -89,17 +89,28 @@ const _: () = { const DEBUG_NAME: &'static str = "InternedString"; type Fields<'a> = StructData<'a>; type Struct<'a> = InternedString<'a>; - } - impl Configuration_ { - pub fn ingredient(db: &Db) -> &zalsa_struct_::IngredientImpl + + const PERSIST: bool = false; + + fn serialize(value: &Self::Fields<'_>, serializer: S) -> Result + where + S: 
zalsa_::serde::Serializer, + { + panic!("attempted to serialize value not marked with `persist` attribute") + } + + fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> where - Db: ?Sized + zalsa_::Database, + D: zalsa_::serde::Deserializer<'de>, { + panic!("attempted to deserialize value not marked with `persist` attribute") + } + } + impl Configuration_ { + pub fn ingredient(zalsa: &zalsa_::Zalsa) -> &zalsa_struct_::IngredientImpl { static CACHE: zalsa_::IngredientCache> = zalsa_::IngredientCache::new(); - let zalsa = db.zalsa(); - // SAFETY: `lookup_jar_by_type` returns a valid ingredient index, and the only // ingredient created by our jar is the struct ingredient. unsafe { @@ -136,6 +147,14 @@ const _: () = { .into() } + fn entries(zalsa: &zalsa_::Zalsa) -> impl Iterator + '_ { + let ingredient_index = + zalsa.lookup_jar_by_type::>(); + ::ingredient(zalsa) + .entries(zalsa) + .map(|(key, _)| key) + } + #[inline] fn cast(id: zalsa_::Id, type_id: TypeId) -> Option { if type_id == TypeId::of::() { @@ -180,7 +199,7 @@ const _: () = { Db_: ?Sized + salsa::Database, String: zalsa_::interned::HashEqLike, { - Configuration_::ingredient(db).intern( + Configuration_::ingredient(db.zalsa()).intern( db.zalsa(), db.zalsa_local(), StructKey::<'db>(data, std::marker::PhantomData::default()), @@ -196,20 +215,20 @@ const _: () = { where Db_: ?Sized + zalsa_::Database, { - let fields = Configuration_::ingredient(db).fields(db.zalsa(), self); + let fields = Configuration_::ingredient(db.zalsa()).fields(db.zalsa(), self); std::clone::Clone::clone((&fields.0)) } fn other(self, db: &'db Db_) -> InternedString<'db> where Db_: ?Sized + zalsa_::Database, { - let fields = Configuration_::ingredient(db).fields(db.zalsa(), self); + let fields = Configuration_::ingredient(db.zalsa()).fields(db.zalsa(), self); std::clone::Clone::clone((&fields.1)) } #[doc = r" Default debug formatting for this struct (may be useful if you define your own `Debug` impl)"] pub fn 
default_debug_fmt(this: Self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { zalsa_::with_attached_database(|db| { - let fields = Configuration_::ingredient(db).fields(db.zalsa(), this); + let fields = Configuration_::ingredient(db.zalsa()).fields(db.zalsa(), this); let mut f = f.debug_struct("InternedString"); let f = f.field("data", &fields.0); let f = f.field("other", &fields.1); diff --git a/tests/memory-usage.rs b/tests/memory-usage.rs index 9a9433d3b..f62a961a8 100644 --- a/tests/memory-usage.rs +++ b/tests/memory-usage.rs @@ -89,15 +89,6 @@ fn test() { 450, ), }, - IngredientInfo { - debug_name: "MyTracked", - count: 4, - size_of_metadata: 128, - size_of_fields: 96, - heap_size_of_fields: Some( - 300, - ), - }, IngredientInfo { debug_name: "MyInterned", count: 3, @@ -107,6 +98,15 @@ fn test() { 450, ), }, + IngredientInfo { + debug_name: "MyTracked", + count: 4, + size_of_metadata: 128, + size_of_fields: 96, + heap_size_of_fields: Some( + 300, + ), + }, IngredientInfo { debug_name: "input_to_string::interned_arguments", count: 1, diff --git a/tests/persistence.rs b/tests/persistence.rs new file mode 100644 index 000000000..8848d79d3 --- /dev/null +++ b/tests/persistence.rs @@ -0,0 +1,378 @@ +#![cfg(all(feature = "persistence", feature = "inventory"))] + +mod common; +use common::LogDatabase; + +use expect_test::expect; + +#[salsa::input(persist)] +struct MyInput { + field: usize, +} + +#[salsa::interned(persist)] +struct MyInterned<'db> { + field: String, +} + +#[salsa::tracked(persist)] +struct MyTracked<'db> { + field: String, +} + +#[salsa::tracked(persist)] +fn unit_to_interned(db: &dyn salsa::Database) -> MyInterned<'_> { + MyInterned::new(db, "a".repeat(50)) +} + +#[salsa::tracked(persist)] +fn input_to_tracked(db: &dyn salsa::Database, input: MyInput) -> MyTracked<'_> { + MyTracked::new(db, "a".repeat(input.field(db))) +} + +#[salsa::tracked(persist)] +fn input_pair_to_string(db: &dyn salsa::Database, input1: MyInput, input2: MyInput) -> 
String { + "a".repeat(input1.field(db) + input2.field(db)) +} + +#[test] +fn everything() { + let mut db = common::LoggerDatabase::default(); + + let _input1 = MyInput::new(&db, 1); + let _input2 = MyInput::new(&db, 2); + + let serialized = + serde_json::to_string_pretty(&::as_serialize(&mut db)).unwrap(); + + let expected = expect![[r#" + { + "runtime": { + "revisions": [ + 1, + 1, + 1 + ] + }, + "ingredients": { + "0": { + "1": { + "durabilities": [ + "Low" + ], + "revisions": [ + 1 + ], + "fields": [ + 1 + ] + }, + "2": { + "durabilities": [ + "Low" + ], + "revisions": [ + 1 + ], + "fields": [ + 2 + ] + } + } + } + }"#]]; + + expected.assert_eq(&serialized); + + let input1 = MyInput::new(&db, 1); + let input2 = MyInput::new(&db, 2); + + let _out = unit_to_interned(&db); + let _out = input_to_tracked(&db, input1); + let _out = input_pair_to_string(&db, input1, input2); + + let serialized = + serde_json::to_string_pretty(&::as_serialize(&mut db)).unwrap(); + + let expected = expect![[r#" + { + "runtime": { + "revisions": [ + 1, + 1, + 1 + ] + }, + "ingredients": { + "0": { + "1": { + "durabilities": [ + "Low" + ], + "revisions": [ + 1 + ], + "fields": [ + 1 + ] + }, + "2": { + "durabilities": [ + "Low" + ], + "revisions": [ + 1 + ], + "fields": [ + 2 + ] + }, + "3": { + "durabilities": [ + "Low" + ], + "revisions": [ + 1 + ], + "fields": [ + 1 + ] + }, + "4": { + "durabilities": [ + "Low" + ], + "revisions": [ + 1 + ], + "fields": [ + 2 + ] + } + }, + "2": { + "2049": { + "durability": "High", + "last_interned_at": 1, + "fields": [ + "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" + ] + } + }, + "4": { + "3073": { + "durability": "Low", + "updated_at": 1, + "revisions": [], + "fields": [ + "a" + ] + } + }, + "6": { + "4097": { + "durability": "High", + "last_interned_at": 18446744073709551615, + "fields": [ + { + "index": 3, + "generation": 0 + }, + { + "index": 4, + "generation": 0 + } + ] + } + }, + "11": { + "1025": { + "durability": "High", + 
"last_interned_at": 18446744073709551615, + "fields": null + } + }, + "5": { + "6:4097": { + "value": "aaa", + "verified_at": 1, + "revisions": { + "changed_at": 1, + "durability": "Low", + "origin": { + "Derived": [ + { + "key": { + "key_index": { + "index": 3, + "generation": 0 + }, + "ingredient_index": 1 + } + }, + { + "key": { + "key_index": { + "index": 4, + "generation": 0 + }, + "ingredient_index": 1 + } + } + ] + }, + "verified_final": true, + "extra": null + } + } + }, + "7": { + "0:3": { + "value": { + "index": 3073, + "generation": 0 + }, + "verified_at": 1, + "revisions": { + "changed_at": 1, + "durability": "Low", + "origin": { + "Derived": [ + { + "key": { + "key_index": { + "index": 3, + "generation": 0 + }, + "ingredient_index": 1 + } + }, + { + "key": { + "key_index": { + "index": 3073, + "generation": 0 + }, + "ingredient_index": 2147483652 + } + } + ] + }, + "verified_final": true, + "extra": { + "tracked_struct_ids": [ + [ + { + "ingredient_index": 4, + "hash": 6073466998405137972, + "disambiguator": 0 + }, + { + "index": 3073, + "generation": 0 + } + ] + ], + "cycle_heads": [], + "iteration": 0 + } + } + } + }, + "10": { + "11:1025": { + "value": { + "index": 2049, + "generation": 0 + }, + "verified_at": 1, + "revisions": { + "changed_at": 1, + "durability": "High", + "origin": { + "Derived": [ + { + "key": { + "key_index": { + "index": 2049, + "generation": 0 + }, + "ingredient_index": 2 + } + } + ] + }, + "verified_final": true, + "extra": null + } + } + } + } + }"#]]; + + expected.assert_eq(&serialized); + + let mut db = common::EventLoggerDatabase::default(); + ::deserialize( + &mut db, + &mut serde_json::Deserializer::from_str(&serialized), + ) + .unwrap(); + + let _out = unit_to_interned(&db); + let _out = input_to_tracked(&db, input1); + let _out = input_pair_to_string(&db, input1, input2); + + // The structs are not recreated, and the queries are not reexecuted. 
+ db.assert_logs(expect![[r#" + [ + "DidSetCancellationFlag", + "WillCheckCancellation", + "WillCheckCancellation", + "WillCheckCancellation", + ]"#]]); +} + +#[test] +#[should_panic(expected = "is not persistable")] +fn invalid_dependency() { + #[salsa::interned] + struct MyInterned<'db> { + field: usize, + } + + #[salsa::tracked(persist)] + fn new_interned(db: &dyn salsa::Database) { + let _interned = MyInterned::new(db, 0); + } + + let mut db = common::LoggerDatabase::default(); + + new_interned(&db); + + let _serialized = + serde_json::to_string_pretty(&::as_serialize(&mut db)).unwrap(); +} + +#[test] +fn serialize_nothing() { + let mut db = common::LoggerDatabase::default(); + + let serialized = + serde_json::to_string_pretty(&::as_serialize(&mut db)).unwrap(); + + // Empty ingredients should not be serialized. + let expected = expect![[r#" + { + "runtime": { + "revisions": [ + 1, + 1, + 1 + ] + }, + "ingredients": {} + }"#]]; + + expected.assert_eq(&serialized); +} From be1b76b78b3f7530acbfcee6d15e49cdbd75bc7b Mon Sep 17 00:00:00 2001 From: Lukas Wirth Date: Tue, 12 Aug 2025 09:12:48 +0200 Subject: [PATCH 27/65] fix: Do not unnecessarily require `Debug` on fields for interned structs (#951) --- components/salsa-macro-rules/src/macro_if.rs | 8 +++ .../src/setup_interned_struct.rs | 66 ++++++++++++++----- tests/debug_bounds.rs | 51 ++++++++++++++ 3 files changed, 110 insertions(+), 15 deletions(-) create mode 100644 tests/debug_bounds.rs diff --git a/components/salsa-macro-rules/src/macro_if.rs b/components/salsa-macro-rules/src/macro_if.rs index e7d05beff..a8dcc49c9 100644 --- a/components/salsa-macro-rules/src/macro_if.rs +++ b/components/salsa-macro-rules/src/macro_if.rs @@ -22,4 +22,12 @@ macro_rules! 
macro_if { (if0 $n:literal { $($t:tt)* } else { $($f:tt)*}) => { $($f)* }; + + (iftt () { $($t:tt)* } else { $($f:tt)*}) => { + $($f)* + }; + + (iftt ($($tt:tt)+) { $($t:tt)* } else { $($f:tt)*}) => { + $($t)* + }; } diff --git a/components/salsa-macro-rules/src/setup_interned_struct.rs b/components/salsa-macro-rules/src/setup_interned_struct.rs index e473e58b9..e73737ae8 100644 --- a/components/salsa-macro-rules/src/setup_interned_struct.rs +++ b/components/salsa-macro-rules/src/setup_interned_struct.rs @@ -330,22 +330,58 @@ macro_rules! setup_interned_struct { ) } )* + } - /// Default debug formatting for this struct (may be useful if you define your own `Debug` impl) - pub fn default_debug_fmt(this: Self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - $zalsa::with_attached_database(|db| { - let zalsa = db.zalsa(); - let fields = $Configuration::ingredient(zalsa).fields(zalsa, this); - let mut f = f.debug_struct(stringify!($Struct)); - $( - let f = f.field(stringify!($field_id), &fields.$field_index); - )* - f.finish() - }).unwrap_or_else(|| { - f.debug_tuple(stringify!($Struct)) - .field(&$zalsa::AsId::as_id(&this)) - .finish() - }) + // Duplication can be dropped here once we no longer allow the `no_lifetime` hack + $zalsa::macro_if! { + iftt ($($db_lt_arg)?) 
{ + impl $Struct<'_> { + /// Default debug formatting for this struct (may be useful if you define your own `Debug` impl) + pub fn default_debug_fmt(this: Self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result + where + // rustc rejects trivial bounds, but it cannot see through higher-ranked bounds + // with its check :^) + $(for<$db_lt> $field_ty: std::fmt::Debug),* + { + $zalsa::with_attached_database(|db| { + let zalsa = db.zalsa(); + let fields = $Configuration::ingredient(zalsa).fields(zalsa, this); + let mut f = f.debug_struct(stringify!($Struct)); + $( + let f = f.field(stringify!($field_id), &fields.$field_index); + )* + f.finish() + }).unwrap_or_else(|| { + f.debug_tuple(stringify!($Struct)) + .field(&$zalsa::AsId::as_id(&this)) + .finish() + }) + } + } + } else { + impl $Struct { + /// Default debug formatting for this struct (may be useful if you define your own `Debug` impl) + pub fn default_debug_fmt(this: Self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result + where + // rustc rejects trivial bounds, but it cannot see through higher-ranked bounds + // with its check :^) + $(for<$db_lt> $field_ty: std::fmt::Debug),* + { + $zalsa::with_attached_database(|db| { + let zalsa = db.zalsa(); + let fields = $Configuration::ingredient(zalsa).fields(zalsa, this); + let mut f = f.debug_struct(stringify!($Struct)); + $( + let f = f.field(stringify!($field_id), &fields.$field_index); + )* + f.finish() + }).unwrap_or_else(|| { + f.debug_tuple(stringify!($Struct)) + .field(&$zalsa::AsId::as_id(&this)) + .finish() + }) + } + } } } }; diff --git a/tests/debug_bounds.rs b/tests/debug_bounds.rs new file mode 100644 index 000000000..36fcb2331 --- /dev/null +++ b/tests/debug_bounds.rs @@ -0,0 +1,51 @@ +#![cfg(feature = "inventory")] + +//! 
Test that debug and non-debug structs compile correctly + +#[derive(Ord, PartialOrd, Eq, PartialEq, Copy, Clone, Hash)] +struct NotDebug; +#[derive(Ord, PartialOrd, Eq, PartialEq, Copy, Clone, Hash, Debug)] +struct Debug; + +#[salsa::input(debug)] +struct DebugInput { + field: Debug, +} + +#[salsa::input] +struct NotDebugInput { + field: NotDebug, +} + +#[salsa::interned(debug)] +struct DebugInterned { + field: Debug, +} + +#[salsa::interned] +struct NotDebugInterned { + field: NotDebug, +} + +#[salsa::interned(no_lifetime, debug)] +struct DebugInternedNoLifetime { + field: Debug, +} + +#[salsa::interned(no_lifetime)] +struct NotDebugInternedNoLifetime { + field: NotDebug, +} + +#[salsa::tracked(debug)] +struct DebugTracked<'db> { + field: Debug, +} + +#[salsa::tracked] +struct NotDebugTracked<'db> { + field: NotDebug, +} + +#[test] +fn ok() {} From a2cd1b8a72d2958d9c96d21d06e4316d0357138f Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Tue, 12 Aug 2025 11:30:59 +0200 Subject: [PATCH 28/65] Remove jemalloc (#972) --- Cargo.toml | 3 --- benches/accumulator.rs | 2 -- benches/compare.rs | 2 -- benches/dataflow.rs | 2 -- benches/incremental.rs | 2 -- benches/shims/global_alloc_overwrite.rs | 29 ------------------------- 6 files changed, 40 deletions(-) delete mode 100644 benches/shims/global_alloc_overwrite.rs diff --git a/Cargo.toml b/Cargo.toml index 847bca2b4..bc28eab4f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -76,9 +76,6 @@ test-log = { version = "0.2.18", features = ["trace"] } trybuild = "1.0" serde_json = "1.0.140" -[target.'cfg(all(not(target_os = "windows"), not(target_os = "openbsd"), any(target_arch = "x86_64", target_arch = "aarch64", target_arch = "powerpc64")))'.dev-dependencies] -tikv-jemallocator = "0.6.0" - [[bench]] name = "compare" harness = false diff --git a/benches/accumulator.rs b/benches/accumulator.rs index 3a91ec19f..041c1f474 100644 --- a/benches/accumulator.rs +++ b/benches/accumulator.rs @@ -3,8 +3,6 @@ use std::hint::black_box; use 
codspeed_criterion_compat::{criterion_group, criterion_main, BatchSize, Criterion}; use salsa::Accumulator; -include!("shims/global_alloc_overwrite.rs"); - #[salsa::input] struct Input { expressions: usize, diff --git a/benches/compare.rs b/benches/compare.rs index c4e6b36f8..8d9dcdad6 100644 --- a/benches/compare.rs +++ b/benches/compare.rs @@ -6,8 +6,6 @@ use codspeed_criterion_compat::{ }; use salsa::Setter; -include!("shims/global_alloc_overwrite.rs"); - #[salsa::input] pub struct Input { #[returns(ref)] diff --git a/benches/dataflow.rs b/benches/dataflow.rs index f4f1aeaf1..24f5a16ee 100644 --- a/benches/dataflow.rs +++ b/benches/dataflow.rs @@ -8,8 +8,6 @@ use std::iter::IntoIterator; use codspeed_criterion_compat::{criterion_group, criterion_main, BatchSize, Criterion}; use salsa::{CycleRecoveryAction, Database as Db, Setter}; -include!("shims/global_alloc_overwrite.rs"); - /// A Use of a symbol. #[salsa::input] struct Use { diff --git a/benches/incremental.rs b/benches/incremental.rs index 872d9fa1a..7b2d8b559 100644 --- a/benches/incremental.rs +++ b/benches/incremental.rs @@ -3,8 +3,6 @@ use std::hint::black_box; use codspeed_criterion_compat::{criterion_group, criterion_main, BatchSize, Criterion}; use salsa::Setter; -include!("shims/global_alloc_overwrite.rs"); - #[salsa::input] struct Input { field: usize, diff --git a/benches/shims/global_alloc_overwrite.rs b/benches/shims/global_alloc_overwrite.rs deleted file mode 100644 index e3b5ea74f..000000000 --- a/benches/shims/global_alloc_overwrite.rs +++ /dev/null @@ -1,29 +0,0 @@ -#[cfg(all( - not(target_os = "windows"), - not(target_os = "openbsd"), - any( - target_arch = "x86_64", - target_arch = "aarch64", - target_arch = "powerpc64" - ) -))] -#[global_allocator] -static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc; - -// Disable decay after 10s because it can show up as *random* slow allocations -// in benchmarks. 
We don't need purging in benchmarks because it isn't important -// to give unallocated pages back to the OS. -// https://jemalloc.net/jemalloc.3.html#opt.dirty_decay_ms -#[cfg(all( - not(target_os = "windows"), - not(target_os = "openbsd"), - any( - target_arch = "x86_64", - target_arch = "aarch64", - target_arch = "powerpc64" - ) -))] -#[allow(non_upper_case_globals)] -#[export_name = "_rjem_malloc_conf"] -#[allow(unsafe_code)] -pub static _rjem_malloc_conf: &[u8] = b"dirty_decay_ms:-1,muzzy_decay_ms:-1\0"; From e5bd9eb673696702cef43ce63b54a8cecf9a9dfe Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Tue, 12 Aug 2025 11:48:58 +0200 Subject: [PATCH 29/65] refactor: Remove tracked structs from query outputs (#969) * refactor: Remove tracked structs from outputs * clean up * fix `IdentityMap::is_active` * Update persistence snapshot * Perf? * Other nit * Remove Deref * Split `diff_stale_outputs` * More short circuits * Remove shrink to fit? * Undo shrink-to-fit removal * Pass CompletedQuery to diff_outputs --------- Co-authored-by: Ibraheem Ahmed --- src/active_query.rs | 34 +++++-- src/function/diff_outputs.rs | 56 +++++------ src/function/execute.rs | 64 ++++++++---- src/function/fetch.rs | 14 ++- src/function/memo.rs | 11 ++- src/function/specify.rs | 33 ++++--- src/interned.rs | 4 +- src/table/memo.rs | 13 +-- src/tracked_struct.rs | 183 ++++++++++++++++++++++++----------- src/zalsa_local.rs | 44 ++++----- tests/cycle_output.rs | 12 +-- tests/memory-usage.rs | 4 +- tests/persistence.rs | 9 -- 13 files changed, 294 insertions(+), 187 deletions(-) diff --git a/src/active_query.rs b/src/active_query.rs index dff64db3e..cc5e4fc58 100644 --- a/src/active_query.rs +++ b/src/active_query.rs @@ -30,7 +30,6 @@ pub(crate) struct ActiveQuery { /// Inputs: Set of subqueries that were accessed thus far. /// Outputs: Tracks values written by this query. Could be... 
/// - /// * tracked structs created /// * invocations of `specify` /// * accumulators pushed to input_outputs: FxIndexSet, @@ -77,10 +76,14 @@ impl ActiveQuery { untracked_read: bool, ) { assert!(self.input_outputs.is_empty()); + self.input_outputs.extend(edges.iter().cloned()); self.durability = self.durability.min(durability); self.changed_at = self.changed_at.max(changed_at); self.untracked_read |= untracked_read; + + // Mark all tracked structs from the previous iteration as active. + self.tracked_struct_ids.mark_all_active(); } pub(super) fn add_read( @@ -139,10 +142,6 @@ impl ActiveQuery { } /// True if the given key was output by this query. - pub(super) fn is_output(&self, key: DatabaseKeyIndex) -> bool { - self.input_outputs.contains(&QueryEdge::output(key)) - } - pub(super) fn disambiguate(&mut self, key: IdentityHash) -> Disambiguator { self.disambiguator_map.disambiguate(key) } @@ -186,7 +185,7 @@ impl ActiveQuery { } } - fn top_into_revisions(&mut self) -> QueryRevisions { + fn top_into_revisions(&mut self) -> CompletedQuery { let &mut Self { database_key_index: _, durability, @@ -213,15 +212,17 @@ impl ActiveQuery { #[cfg(feature = "accumulator")] let accumulated_inputs = AtomicInputAccumulatedValues::new(accumulated_inputs); let verified_final = cycle_heads.is_empty(); + let (active_tracked_structs, stale_tracked_structs) = tracked_struct_ids.drain(); + let extra = QueryRevisionsExtra::new( #[cfg(feature = "accumulator")] mem::take(accumulated), - mem::take(tracked_struct_ids), + active_tracked_structs, mem::take(cycle_heads), iteration_count, ); - QueryRevisions { + let revisions = QueryRevisions { changed_at, durability, origin, @@ -229,6 +230,11 @@ impl ActiveQuery { accumulated_inputs, verified_final: AtomicBool::new(verified_final), extra, + }; + + CompletedQuery { + revisions, + stale_tracked_structs, } } @@ -370,7 +376,7 @@ impl QueryStack { &mut self, key: DatabaseKeyIndex, #[cfg(debug_assertions)] push_len: usize, - ) -> QueryRevisions { + ) 
-> CompletedQuery { #[cfg(debug_assertions)] assert_eq!(push_len, self.len(), "unbalanced push/pop"); debug_assert_ne!(self.len, 0, "too many pops"); @@ -395,6 +401,16 @@ impl QueryStack { } } +/// The state of a completed query. +pub(crate) struct CompletedQuery { + /// Inputs and outputs accumulated during query execution. + pub(crate) revisions: QueryRevisions, + + /// The keys of any tracked structs that were created in a previous execution of the + /// query but not the current one, and should be marked as stale. + pub(crate) stale_tracked_structs: Vec, +} + struct CapturedQuery { database_key_index: DatabaseKeyIndex, durability: Durability, diff --git a/src/function/diff_outputs.rs b/src/function/diff_outputs.rs index b1d17b75a..923a0fc88 100644 --- a/src/function/diff_outputs.rs +++ b/src/function/diff_outputs.rs @@ -1,63 +1,50 @@ +use crate::active_query::CompletedQuery; use crate::function::memo::Memo; use crate::function::{Configuration, IngredientImpl}; use crate::hash::FxIndexSet; use crate::zalsa::Zalsa; -use crate::zalsa_local::{output_edges, QueryOriginRef, QueryRevisions}; -use crate::{DatabaseKeyIndex, Event, EventKind, Id}; +use crate::zalsa_local::{output_edges, QueryOriginRef}; +use crate::{DatabaseKeyIndex, Event, EventKind}; impl IngredientImpl where C: Configuration, { - /// Compute the old and new outputs and invoke `remove_stale_output` - /// for each output that was generated before but is not generated now. - /// - /// This function takes a `&mut` reference to `revisions` to remove outputs - /// that no longer exist in this revision from [`QueryRevisions::tracked_struct_ids`]. + /// Compute the old and new outputs and invoke `remove_stale_output` for each output that + /// was generated before but is not generated now. 
pub(super) fn diff_outputs( &self, zalsa: &Zalsa, key: DatabaseKeyIndex, old_memo: &Memo<'_, C>, - revisions: &mut QueryRevisions, + completed_query: &CompletedQuery, ) { let (QueryOriginRef::Derived(edges) | QueryOriginRef::DerivedUntracked(edges)) = old_memo.revisions.origin.as_ref() else { return; }; - // Iterate over the outputs of the `old_memo` and put them into a hashset - // - // Ignore key_generation here, because we use the same tracked struct allocation for - // all generations with the same key_index and can't report it as stale - let mut old_outputs: FxIndexSet<_> = output_edges(edges) - .map(|a| (a.ingredient_index(), a.key_index().index())) - .collect(); - if old_outputs.is_empty() { - return; + // Note that tracked structs are not stored as direct query outputs, but they are still outputs + // that need to be reported as stale. + for output in &completed_query.stale_tracked_structs { + Self::report_stale_output(zalsa, key, *output); } - // Iterate over the outputs of the current query - // and remove elements from `old_outputs` when we find them - for new_output in revisions.origin.as_ref().outputs() { - old_outputs.swap_remove(&( - new_output.ingredient_index(), - new_output.key_index().index(), - )); + let mut stale_outputs = output_edges(edges).collect::>(); + + if stale_outputs.is_empty() { + return; } - // Remove the outputs that are no longer present in the current revision - // to prevent that the next revision is seeded with an id mapping that no longer exists. - if let Some(tracked_struct_ids) = revisions.tracked_struct_ids_mut() { - tracked_struct_ids - .retain(|(k, value)| !old_outputs.contains(&(k.ingredient_index(), value.index()))); - }; + // Preserve any outputs that were recreated in the current revision. 
+ for new_output in completed_query.revisions.origin.as_ref().outputs() { + stale_outputs.swap_remove(&new_output); + } - for (ingredient_index, key_index) in old_outputs { - // SAFETY: key_index acquired from valid output - let id = unsafe { Id::from_index(key_index) }; - Self::report_stale_output(zalsa, key, DatabaseKeyIndex::new(ingredient_index, id)); + // Any outputs that were created in a previous revision but not the current one are stale. + for output in stale_outputs { + Self::report_stale_output(zalsa, key, output); } } @@ -68,6 +55,7 @@ where output_key: output, }) }); + output.remove_stale_output(zalsa, key); } } diff --git a/src/function/execute.rs b/src/function/execute.rs index 5587b1d94..67cee969d 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -1,9 +1,10 @@ +use crate::active_query::CompletedQuery; use crate::cycle::{CycleRecoveryStrategy, IterationCount}; use crate::function::memo::Memo; use crate::function::{Configuration, IngredientImpl}; use crate::sync::atomic::{AtomicBool, Ordering}; use crate::zalsa::{MemoIngredientIndex, Zalsa, ZalsaDatabase}; -use crate::zalsa_local::{ActiveQueryGuard, QueryRevisions}; +use crate::zalsa_local::ActiveQueryGuard; use crate::{Event, EventKind, Id}; impl IngredientImpl @@ -39,15 +40,15 @@ where }); let memo_ingredient_index = self.memo_ingredient_index(zalsa, id); - let (new_value, mut revisions) = match C::CYCLE_STRATEGY { + let (new_value, mut completed_query) = match C::CYCLE_STRATEGY { CycleRecoveryStrategy::Panic => { Self::execute_query(db, zalsa, active_query, opt_old_memo, id) } CycleRecoveryStrategy::FallbackImmediate => { - let (mut new_value, mut revisions) = + let (mut new_value, mut completed_query) = Self::execute_query(db, zalsa, active_query, opt_old_memo, id); - if let Some(cycle_heads) = revisions.cycle_heads_mut() { + if let Some(cycle_heads) = completed_query.revisions.cycle_heads_mut() { // Did the new result we got depend on our own provisional value, in a cycle? 
if cycle_heads.contains(&database_key_index) { // Ignore the computed value, leave the fallback value there. @@ -73,14 +74,14 @@ where .zalsa_local() .push_query(database_key_index, IterationCount::initial()); new_value = C::cycle_initial(db, C::id_to_input(zalsa, id)); - revisions = active_query.pop(); + completed_query = active_query.pop(); // We need to set `cycle_heads` and `verified_final` because it needs to propagate to the callers. // When verifying this, we will see we have fallback and mark ourselves verified. - revisions.set_cycle_heads(cycle_heads); - revisions.verified_final = AtomicBool::new(false); + completed_query.revisions.set_cycle_heads(cycle_heads); + completed_query.revisions.verified_final = AtomicBool::new(false); } - (new_value, revisions) + (new_value, completed_query) } CycleRecoveryStrategy::Fixpoint => self.execute_maybe_iterate( db, @@ -97,16 +98,25 @@ where // really change, even if some of its inputs have. So we can // "backdate" its `changed_at` revision to be the same as the // old value. - self.backdate_if_appropriate(old_memo, database_key_index, &mut revisions, &new_value); + self.backdate_if_appropriate( + old_memo, + database_key_index, + &mut completed_query.revisions, + &new_value, + ); // Diff the new outputs with the old, to discard any no-longer-emitted // outputs and update the tracked struct IDs for seeding the next revision. 
- self.diff_outputs(zalsa, database_key_index, old_memo, &mut revisions); + self.diff_outputs(zalsa, database_key_index, old_memo, &completed_query); } self.insert_memo( zalsa, id, - Memo::new(Some(new_value), zalsa.current_revision(), revisions), + Memo::new( + Some(new_value), + zalsa.current_revision(), + completed_query.revisions, + ), memo_ingredient_index, ) } @@ -120,7 +130,7 @@ where zalsa: &'db Zalsa, id: Id, memo_ingredient_index: MemoIngredientIndex, - ) -> (C::Output<'db>, QueryRevisions) { + ) -> (C::Output<'db>, CompletedQuery) { let database_key_index = active_query.database_key_index; let mut iteration_count = IterationCount::initial(); let mut fell_back = false; @@ -131,11 +141,12 @@ where let mut opt_last_provisional: Option<&Memo<'db, C>> = None; loop { let previous_memo = opt_last_provisional.or(opt_old_memo); - let (mut new_value, mut revisions) = + let (mut new_value, mut completed_query) = Self::execute_query(db, zalsa, active_query, previous_memo, id); // Did the new result we got depend on our own provisional value, in a cycle? 
- if let Some(cycle_heads) = revisions + if let Some(cycle_heads) = completed_query + .revisions .cycle_heads_mut() .filter(|cycle_heads| cycle_heads.contains(&database_key_index)) { @@ -211,14 +222,21 @@ where }) }); cycle_heads.update_iteration_count(database_key_index, iteration_count); - revisions.update_iteration_count(iteration_count); + completed_query + .revisions + .update_iteration_count(iteration_count); crate::tracing::debug!( - "{database_key_index:?}: execute: iterate again, revisions: {revisions:#?}" + "{database_key_index:?}: execute: iterate again, revisions: {revisions:#?}", + revisions = &completed_query.revisions ); opt_last_provisional = Some(self.insert_memo( zalsa, id, - Memo::new(Some(new_value), zalsa.current_revision(), revisions), + Memo::new( + Some(new_value), + zalsa.current_revision(), + completed_query.revisions, + ), memo_ingredient_index, )); @@ -235,15 +253,19 @@ where if cycle_heads.is_empty() { // If there are no more cycle heads, we can mark this as verified. - revisions.verified_final.store(true, Ordering::Relaxed); + completed_query + .revisions + .verified_final + .store(true, Ordering::Relaxed); } } crate::tracing::debug!( - "{database_key_index:?}: execute: result.revisions = {revisions:#?}" + "{database_key_index:?}: execute: result.revisions = {revisions:#?}", + revisions = &completed_query.revisions ); - break (new_value, revisions); + break (new_value, completed_query); } } @@ -254,7 +276,7 @@ where active_query: ActiveQueryGuard<'db>, opt_old_memo: Option<&Memo<'db, C>>, id: Id, - ) -> (C::Output<'db>, QueryRevisions) { + ) -> (C::Output<'db>, CompletedQuery) { if let Some(old_memo) = opt_old_memo { // If we already executed this query once, then use the tracked-struct ids from the // previous execution as the starting point for the new one. 
diff --git a/src/function/fetch.rs b/src/function/fetch.rs index 32e2eb44a..87ff22db3 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -297,14 +297,20 @@ where let active_query = zalsa_local.push_query(database_key_index, IterationCount::initial()); let fallback_value = C::cycle_initial(db, C::id_to_input(zalsa, id)); - let mut revisions = active_query.pop(); - revisions.set_cycle_heads(CycleHeads::initial(database_key_index)); + let mut completed_query = active_query.pop(); + completed_query + .revisions + .set_cycle_heads(CycleHeads::initial(database_key_index)); // We need this for `cycle_heads()` to work. We will unset this in the outer `execute()`. - *revisions.verified_final.get_mut() = false; + *completed_query.revisions.verified_final.get_mut() = false; self.insert_memo( zalsa, id, - Memo::new(Some(fallback_value), zalsa.current_revision(), revisions), + Memo::new( + Some(fallback_value), + zalsa.current_revision(), + completed_query.revisions, + ), memo_ingredient_index, ) } diff --git a/src/function/memo.rs b/src/function/memo.rs index eb8fcec70..54df61420 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -321,8 +321,15 @@ impl crate::table::memo::Memo for Memo<'static, C> where C::Output<'static>: Send + Sync + Any, { - fn origin(&self) -> QueryOriginRef<'_> { - self.revisions.origin.as_ref() + fn remove_outputs(&self, zalsa: &Zalsa, executor: DatabaseKeyIndex) { + for stale_output in self.revisions.origin.as_ref().outputs() { + stale_output.remove_stale_output(zalsa, executor); + } + + for (identity, id) in self.revisions.tracked_struct_ids().into_iter().flatten() { + let key = DatabaseKeyIndex::new(identity.ingredient_index(), *id); + key.remove_stale_output(zalsa, executor); + } } #[cfg(feature = "salsa_unstable")] diff --git a/src/function/specify.rs b/src/function/specify.rs index 37e25209e..99539d375 100644 --- a/src/function/specify.rs +++ b/src/function/specify.rs @@ -1,5 +1,6 @@ #[cfg(feature = "accumulator")] use 
crate::accumulator::accumulated_map::InputAccumulatedValues; +use crate::active_query::CompletedQuery; use crate::function::memo::Memo; use crate::function::{Configuration, IngredientImpl}; use crate::revision::AtomicRevision; @@ -39,7 +40,7 @@ where // // Now, if We invoke Q3 first, We get one result for Q2, but if We invoke Q4 first, We get a different value. That's no good. let database_key_index = >::database_key_index(zalsa, key); - if !zalsa_local.is_output_of_active_query(database_key_index) { + if !zalsa_local.is_tracked_struct_of_active_query(database_key_index) { panic!("can only use `specify` on salsa structs created during the current tracked fn"); } @@ -63,26 +64,34 @@ where // - a result that is NOT verified and has untracked inputs, which will re-execute (and likely panic) let revision = zalsa.current_revision(); - let mut revisions = QueryRevisions { - changed_at: current_deps.changed_at, - durability: current_deps.durability, - origin: QueryOrigin::assigned(active_query_key), - #[cfg(feature = "accumulator")] - accumulated_inputs: Default::default(), - verified_final: AtomicBool::new(true), - extra: QueryRevisionsExtra::default(), + let mut completed_query = CompletedQuery { + revisions: QueryRevisions { + changed_at: current_deps.changed_at, + durability: current_deps.durability, + origin: QueryOrigin::assigned(active_query_key), + #[cfg(feature = "accumulator")] + accumulated_inputs: Default::default(), + verified_final: AtomicBool::new(true), + extra: QueryRevisionsExtra::default(), + }, + stale_tracked_structs: Vec::new(), }; let memo_ingredient_index = self.memo_ingredient_index(zalsa, key); if let Some(old_memo) = self.get_memo_from_table_for(zalsa, key, memo_ingredient_index) { - self.backdate_if_appropriate(old_memo, database_key_index, &mut revisions, &value); - self.diff_outputs(zalsa, database_key_index, old_memo, &mut revisions); + self.backdate_if_appropriate( + old_memo, + database_key_index, + &mut completed_query.revisions, + 
&value, + ); + self.diff_outputs(zalsa, database_key_index, old_memo, &completed_query); } let memo = Memo { value: Some(value), verified_at: AtomicRevision::from(revision), - revisions, + revisions: completed_query.revisions, }; crate::tracing::debug!( diff --git a/src/interned.rs b/src/interned.rs index 547a5a67e..ff25f26e9 100644 --- a/src/interned.rs +++ b/src/interned.rs @@ -714,9 +714,7 @@ where zalsa.event(&|| Event::new(EventKind::DidDiscard { key: executor })); - for stale_output in memo.origin().outputs() { - stale_output.remove_stale_output(zalsa, executor); - } + memo.remove_outputs(zalsa, executor); }) }; diff --git a/src/table/memo.rs b/src/table/memo.rs index 3b6366934..dd0d71e58 100644 --- a/src/table/memo.rs +++ b/src/table/memo.rs @@ -4,7 +4,9 @@ use std::mem; use std::ptr::{self, NonNull}; use crate::sync::atomic::{AtomicPtr, Ordering}; -use crate::{zalsa::MemoIngredientIndex, zalsa_local::QueryOriginRef}; +use crate::zalsa::MemoIngredientIndex; +use crate::zalsa::Zalsa; +use crate::DatabaseKeyIndex; /// The "memo table" stores the memoized results of tracked function calls. /// Every tracked function must take a salsa struct as its first argument @@ -39,8 +41,9 @@ impl MemoTable { } pub trait Memo: Any + Send + Sync { - /// Returns the `origin` of this memo - fn origin(&self) -> QueryOriginRef<'_>; + /// Removes the outputs that were created when this query ran. This includes + /// tracked structs and specified queries. + fn remove_outputs(&self, zalsa: &Zalsa, executor: DatabaseKeyIndex); /// Returns memory usage information about the memoized value. 
#[cfg(feature = "salsa_unstable")] @@ -115,9 +118,7 @@ impl MemoEntryType { struct DummyMemo; impl Memo for DummyMemo { - fn origin(&self) -> QueryOriginRef<'_> { - unreachable!("should not get here") - } + fn remove_outputs(&self, _zalsa: &Zalsa, _executor: DatabaseKeyIndex) {} #[cfg(feature = "salsa_unstable")] fn memory_usage(&self) -> crate::database::MemoInfo { diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index 451533d07..51e3d286d 100644 --- a/src/tracked_struct.rs +++ b/src/tracked_struct.rs @@ -7,6 +7,7 @@ use std::ops::Index; use std::{fmt, mem}; use crossbeam_queue::SegQueue; +use hashbrown::hash_table::Entry; use thin_vec::ThinVec; use tracked_field::FieldIngredientImpl; @@ -227,10 +228,11 @@ impl Identity { } /// Stores the data that (almost) uniquely identifies a tracked struct. -/// This includes the ingredient index of that struct type plus the hash of its untracked fields. -/// This is mapped to a disambiguator -- a value that starts as 0 but increments each round, -/// allowing for multiple tracked structs with the same hash and ingredient_index -/// created within the query to each have a unique id. +/// +/// This includes the ingredient index of that struct type plus the hash of its untracked +/// fields. This is mapped to a disambiguator -- a value that starts as 0 but increments +/// each round, allowing for multiple tracked structs with the same hash and `IngredientIndex` +/// created within the query to each have a unique ID. #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Copy, Clone)] pub struct IdentityHash { /// Index of the tracked struct ingredient. @@ -240,49 +242,120 @@ pub struct IdentityHash { hash: u64, } -/// A map from tracked struct keys (which include the hash + [Disambiguator]) to their -/// final [Id]. +/// A map from tracked struct [`Identity`] to their final [`Id`]. 
#[derive(Default, Debug)] pub(crate) struct IdentityMap { - // we use a hashtable here as our key contains its own hash (`Identity::hash`) - // so we do the hash wrangling ourselves - table: hashbrown::HashTable<(Identity, Id)>, -} - -impl Clone for IdentityMap { - fn clone(&self) -> Self { - Self { - table: self.table.clone(), - } - } + // We use a `HashTable` here as our key contains its own hash (`Identity::hash`), + // so we do the hash wrangling ourselves. + table: hashbrown::HashTable, } impl IdentityMap { - pub(crate) fn clone_from_slice(&mut self, source: &[(Identity, Id)]) { + /// Seeds the identity map with the IDs from a previous revision. + pub(crate) fn seed(&mut self, source: &[(Identity, Id)]) { self.table.clear(); - self.table.reserve(source.len(), |(k, _)| k.hash); + self.table + .reserve(source.len(), |entry| entry.identity.hash); + + for &(key, id) in source { + self.insert_entry(key, id, false); + } + } - for (key, id) in source { - self.insert(*key, *id); + // Mark all tracked structs in the map as created by the current query. + pub(crate) fn mark_all_active(&mut self) { + for entry in self.table.iter_mut() { + entry.active = true; } } + /// Insert a tracked struct identity into the map with the given ID. 
pub(crate) fn insert(&mut self, key: Identity, id: Id) -> Option { - let entry = self.table.find_mut(key.hash, |&(k, _)| k == key); + self.insert_entry(key, id, true) + } + + fn insert_entry(&mut self, key: Identity, id: Id, active: bool) -> Option { + let entry = self.table.entry( + key.hash, + |entry| entry.identity == key, + |entry| entry.identity.hash, + ); + match entry { - Some(occupied) => Some(mem::replace(&mut occupied.1, id)), - None => { - self.table - .insert_unique(key.hash, (key, id), |(k, _)| k.hash); + Entry::Vacant(entry) => { + entry.insert(TrackedEntry { + identity: key, + id, + active, + }); + None } + Entry::Occupied(mut entry) => { + let tracked = entry.get_mut(); + tracked.active = active; + + Some(std::mem::replace(&mut tracked.id, id)) + } } } - pub(crate) fn get(&self, key: &Identity) -> Option { + /// Reuses an existing identity if it already exists in the map, marking it as active. + /// + /// Returns the existing ID, or `None` if no ID for the given identity exists. + pub(crate) fn reuse(&mut self, key: &Identity) -> Option { + self.table + .find_mut(key.hash, |entry| key == &entry.identity) + .map(|entry| { + entry.active = true; + entry.id + }) + } + + /// Returns `true` if the given tracked struct key was created in the current query execution. + pub(crate) fn is_active(&self, key: DatabaseKeyIndex) -> bool { self.table - .find(key.hash, |&(k, _)| k == *key) - .map(|&(_, v)| v) + .iter() + .find(|entry| { + entry.id == key.key_index() + && entry.identity.ingredient_index() == key.ingredient_index() + }) + .is_some_and(|entry| entry.active) + } + + /// Drains the [`IdentityMap`] into a tuple of active and stale tracked structs. + /// + /// The first entry contains the identity and IDs of any tracked structs that were + /// created by the current execution of the query, while the second entry contains any + /// tracked structs that were created in a previous execution but not the current one. 
+ pub(crate) fn drain(&mut self) -> (ThinVec<(Identity, Id)>, Vec) { + if self.table.is_empty() { + return (ThinVec::new(), Vec::new()); + } + + let mut stale = Vec::new(); + let mut active = ThinVec::with_capacity(self.table.len()); + + for entry in self.table.drain() { + if entry.active { + active.push((entry.identity, entry.id)); + } else { + stale.push(DatabaseKeyIndex::new( + entry.identity.ingredient_index(), + entry.id, + )); + } + } + + // Removing a stale tracked struct ID shows up in the event logs, so make sure + // the order is stable here. + stale.sort_unstable_by(|a, b| { + a.ingredient_index() + .cmp(&b.ingredient_index()) + .then(a.key_index().cmp(&b.key_index())) + }); + + (active, stale) } pub(crate) fn is_empty(&self) -> bool { @@ -292,10 +365,23 @@ impl IdentityMap { pub(crate) fn clear(&mut self) { self.table.clear() } +} - pub(crate) fn into_thin_vec(self) -> ThinVec<(Identity, Id)> { - self.table.into_iter().collect() - } +/// A tracked struct entry stored in an [`IdentityMap`]. +#[derive(Debug)] +struct TrackedEntry { + /// The identity of the tracked struct. + identity: Identity, + + /// The current ID of the tracked struct. + id: Id, + + /// Whether or not this tracked struct was created by the current query. + /// + /// Entries where `active` is `false` represent tracked structs that were created + /// by a previous execution of the query, but not in the current one, and hence can + /// be collected. + active: bool, } // ANCHOR: ValueStruct @@ -428,7 +514,6 @@ where // The struct already exists in the intern map. let index = self.database_key_index(id); crate::tracing::trace!("Reuse tracked struct {id:?}", id = index); - zalsa_local.add_output(index); // SAFETY: The `id` was present in the interned map, so the value must be initialized. 
let update_result = @@ -454,7 +539,6 @@ where let id = self.allocate(zalsa, zalsa_local, current_revision, ¤t_deps, fields); let key = self.database_key_index(id); crate::tracing::trace!("Allocated new tracked struct {key:?}"); - zalsa_local.add_output(key); zalsa_local.store_tracked_struct_id(identity, id); FromId::from_id(id) } @@ -750,9 +834,7 @@ where zalsa.event(&|| Event::new(EventKind::DidDiscard { key: executor })); - for stale_output in memo.origin().outputs() { - stale_output.remove_stale_output(zalsa, executor); - } + memo.remove_outputs(zalsa, executor); }) }; @@ -859,17 +941,6 @@ where VerifyResult::unchanged() } - fn mark_validated_output( - &self, - _zalsa: &Zalsa, - _executor: DatabaseKeyIndex, - _output_key: crate::Id, - ) { - // we used to update `update_at` field but now we do it lazilly when data is accessed - // - // FIXME: delete this method - } - fn remove_stale_output( &self, zalsa: &Zalsa, @@ -1132,14 +1203,14 @@ mod tests { assert_eq!(d.insert(i7, Id::from_index(6)), None); assert_eq!(d.insert(i8, Id::from_index(7)), None); - assert_eq!(d.get(&i1), Some(Id::from_index(0))); - assert_eq!(d.get(&i2), Some(Id::from_index(1))); - assert_eq!(d.get(&i3), Some(Id::from_index(2))); - assert_eq!(d.get(&i4), Some(Id::from_index(3))); - assert_eq!(d.get(&i5), Some(Id::from_index(4))); - assert_eq!(d.get(&i6), Some(Id::from_index(5))); - assert_eq!(d.get(&i7), Some(Id::from_index(6))); - assert_eq!(d.get(&i8), Some(Id::from_index(7))); + assert_eq!(d.reuse(&i1), Some(Id::from_index(0))); + assert_eq!(d.reuse(&i2), Some(Id::from_index(1))); + assert_eq!(d.reuse(&i3), Some(Id::from_index(2))); + assert_eq!(d.reuse(&i4), Some(Id::from_index(3))); + assert_eq!(d.reuse(&i5), Some(Id::from_index(4))); + assert_eq!(d.reuse(&i6), Some(Id::from_index(5))); + assert_eq!(d.reuse(&i7), Some(Id::from_index(6))); + assert_eq!(d.reuse(&i8), Some(Id::from_index(7))); }; } } diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs index bc82a2057..74cd102fd 100644 --- 
a/src/zalsa_local.rs +++ b/src/zalsa_local.rs @@ -10,14 +10,14 @@ use crate::accumulator::{ accumulated_map::{AccumulatedMap, AtomicInputAccumulatedValues}, Accumulator, }; -use crate::active_query::QueryStack; +use crate::active_query::{CompletedQuery, QueryStack}; use crate::cycle::{empty_cycle_heads, CycleHeads, IterationCount}; use crate::durability::Durability; use crate::key::DatabaseKeyIndex; use crate::runtime::Stamp; use crate::sync::atomic::AtomicBool; use crate::table::{PageIndex, Slot, Table}; -use crate::tracked_struct::{Disambiguator, Identity, IdentityHash, IdentityMap}; +use crate::tracked_struct::{Disambiguator, Identity, IdentityHash}; use crate::zalsa::{IngredientIndex, Zalsa}; use crate::{Cancelled, Id, Revision}; @@ -243,16 +243,14 @@ impl ZalsaLocal { } } - /// Check whether `entity` is an output of the currently active query (if any) - pub(crate) fn is_output_of_active_query(&self, entity: DatabaseKeyIndex) -> bool { + /// Check whether `entity` is a tracked struct that was created by the currently active query (if any) + pub(crate) fn is_tracked_struct_of_active_query(&self, entity: DatabaseKeyIndex) -> bool { // SAFETY: We do not access the query stack reentrantly. unsafe { self.with_query_stack_unchecked_mut(|stack| { - if let Some(top_query) = stack.last_mut() { - top_query.is_output(entity) - } else { - false - } + stack + .last_mut() + .is_some_and(|top_query| top_query.tracked_struct_ids().is_active(entity)) }) } } @@ -379,11 +377,11 @@ impl ZalsaLocal { pub(crate) fn tracked_struct_id(&self, identity: &Identity) -> Option { // SAFETY: We do not access the query stack reentrantly. 
unsafe { - self.with_query_stack_unchecked(|stack| { + self.with_query_stack_unchecked_mut(|stack| { let top_query = stack - .last() + .last_mut() .expect("cannot create a tracked struct ID outside of a tracked function"); - top_query.tracked_struct_ids().get(identity) + top_query.tracked_struct_ids_mut().reuse(identity) }) } } @@ -514,7 +512,7 @@ pub(crate) struct QueryRevisionsExtra(Option>); impl QueryRevisionsExtra { pub fn new( #[cfg(feature = "accumulator")] accumulated: AccumulatedMap, - tracked_struct_ids: IdentityMap, + mut tracked_struct_ids: ThinVec<(Identity, Id)>, cycle_heads: CycleHeads, iteration: IterationCount, ) -> Self { @@ -529,11 +527,13 @@ impl QueryRevisionsExtra { { None } else { + tracked_struct_ids.shrink_to_fit(); + Some(Box::new(QueryRevisionsExtraInner { #[cfg(feature = "accumulator")] accumulated, cycle_heads, - tracked_struct_ids: tracked_struct_ids.into_thin_vec(), + tracked_struct_ids, iteration, })) }; @@ -626,7 +626,7 @@ impl QueryRevisions { extra: QueryRevisionsExtra::new( #[cfg(feature = "accumulator")] AccumulatedMap::default(), - IdentityMap::default(), + ThinVec::default(), CycleHeads::initial(query), IterationCount::initial(), ), @@ -638,7 +638,7 @@ impl QueryRevisions { pub(crate) fn accumulated(&self) -> Option<&AccumulatedMap> { self.extra .0 - .as_ref() + .as_deref() .map(|extra| &extra.accumulated) .filter(|map| !map.is_empty()) } @@ -668,7 +668,7 @@ impl QueryRevisions { self.extra = QueryRevisionsExtra::new( #[cfg(feature = "accumulator")] AccumulatedMap::default(), - IdentityMap::default(), + ThinVec::default(), cycle_heads, IterationCount::default(), ); @@ -1113,9 +1113,7 @@ impl ActiveQueryGuard<'_> { assert_eq!(stack.len(), self.push_len); let frame = stack.last_mut().unwrap(); assert!(frame.tracked_struct_ids().is_empty()); - frame - .tracked_struct_ids_mut() - .clone_from_slice(tracked_struct_ids); + frame.tracked_struct_ids_mut().seed(tracked_struct_ids); }) } } @@ -1142,7 +1140,7 @@ impl ActiveQueryGuard<'_> 
{ } /// Invoked when the query has successfully completed execution. - fn complete(self) -> QueryRevisions { + fn complete(self) -> CompletedQuery { // SAFETY: We do not access the query stack reentrantly. let query = unsafe { self.local_state.with_query_stack_unchecked_mut(|stack| { @@ -1157,11 +1155,11 @@ impl ActiveQueryGuard<'_> { query } - /// Pops an active query from the stack. Returns the [`QueryRevisions`] + /// Pops an active query from the stack. Returns the [`CompletedQuery`] /// which summarizes the other queries that were accessed during this /// query's execution. #[inline] - pub(crate) fn pop(self) -> QueryRevisions { + pub(crate) fn pop(self) -> CompletedQuery { self.complete() } } diff --git a/tests/cycle_output.rs b/tests/cycle_output.rs index 975c8a44d..c8e17ba6b 100644 --- a/tests/cycle_output.rs +++ b/tests/cycle_output.rs @@ -187,23 +187,23 @@ fn revalidate_with_change_after_output_read() { "salsa_event(DidValidateInternedValue { key: query_d::interned_arguments(Id(800)), revision: R2 })", "salsa_event(WillExecute { database_key: query_a(Id(0)) })", "salsa_event(WillExecute { database_key: query_d(Id(800)) })", - "salsa_event(WillDiscardStaleOutput { execute_key: query_a(Id(0)), output_key: Output(Id(403)) })", - "salsa_event(DidDiscard { key: Output(Id(403)) })", - "salsa_event(DidDiscard { key: read_value(Id(403)) })", "salsa_event(WillDiscardStaleOutput { execute_key: query_a(Id(0)), output_key: Output(Id(401)) })", "salsa_event(DidDiscard { key: Output(Id(401)) })", "salsa_event(DidDiscard { key: read_value(Id(401)) })", "salsa_event(WillDiscardStaleOutput { execute_key: query_a(Id(0)), output_key: Output(Id(402)) })", "salsa_event(DidDiscard { key: Output(Id(402)) })", "salsa_event(DidDiscard { key: read_value(Id(402)) })", + "salsa_event(WillDiscardStaleOutput { execute_key: query_a(Id(0)), output_key: Output(Id(403)) })", + "salsa_event(DidDiscard { key: Output(Id(403)) })", + "salsa_event(DidDiscard { key: read_value(Id(403)) })", 
"salsa_event(WillIterateCycle { database_key: query_b(Id(0)), iteration_count: IterationCount(1), fell_back: false })", "salsa_event(WillExecute { database_key: query_a(Id(0)) })", - "salsa_event(WillExecute { database_key: read_value(Id(403g1)) })", + "salsa_event(WillExecute { database_key: read_value(Id(401g1)) })", "salsa_event(WillIterateCycle { database_key: query_b(Id(0)), iteration_count: IterationCount(2), fell_back: false })", "salsa_event(WillExecute { database_key: query_a(Id(0)) })", - "salsa_event(WillExecute { database_key: read_value(Id(401g1)) })", + "salsa_event(WillExecute { database_key: read_value(Id(402g1)) })", "salsa_event(WillIterateCycle { database_key: query_b(Id(0)), iteration_count: IterationCount(3), fell_back: false })", "salsa_event(WillExecute { database_key: query_a(Id(0)) })", - "salsa_event(WillExecute { database_key: read_value(Id(402g1)) })", + "salsa_event(WillExecute { database_key: read_value(Id(403g1)) })", ]"#]]); } diff --git a/tests/memory-usage.rs b/tests/memory-usage.rs index f62a961a8..37d434292 100644 --- a/tests/memory-usage.rs +++ b/tests/memory-usage.rs @@ -167,7 +167,7 @@ fn test() { IngredientInfo { debug_name: "memory_usage::MyTracked", count: 2, - size_of_metadata: 192, + size_of_metadata: 168, size_of_fields: 16, heap_size_of_fields: None, }, @@ -177,7 +177,7 @@ fn test() { IngredientInfo { debug_name: "(memory_usage::MyTracked, memory_usage::MyTracked)", count: 1, - size_of_metadata: 132, + size_of_metadata: 108, size_of_fields: 16, heap_size_of_fields: None, }, diff --git a/tests/persistence.rs b/tests/persistence.rs index 8848d79d3..379f14523 100644 --- a/tests/persistence.rs +++ b/tests/persistence.rs @@ -246,15 +246,6 @@ fn everything() { }, "ingredient_index": 1 } - }, - { - "key": { - "key_index": { - "index": 3073, - "generation": 0 - }, - "ingredient_index": 2147483652 - } } ] }, From 918d35d873b2b73a0237536144ef4d22e8d57f27 Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Tue, 12 Aug 2025 12:09:41 
+0200 Subject: [PATCH 30/65] Make `thin-vec/serde` dependency dependent on `persistence` feature (#973) --- Cargo.toml | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index bc28eab4f..0998e2329 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,7 +23,7 @@ parking_lot = "0.12" portable-atomic = "1" rustc-hash = "2" smallvec = "1" -thin-vec = { version = "0.2.14", features = ["serde"] } +thin-vec = { version = "0.2.14" } tracing = { version = "0.1", default-features = false, features = ["std"] } # Automatic ingredient registration. @@ -44,7 +44,12 @@ serde = { version = "1.0.219", features = ["derive"], optional = true } [features] default = ["salsa_unstable", "rayon", "macros", "inventory", "accumulator"] inventory = ["dep:inventory"] -persistence = ["dep:serde", "dep:erased-serde", "salsa-macros/persistence"] +persistence = [ + "dep:serde", + "dep:erased-serde", + "salsa-macros/persistence", + "thin-vec/serde", +] shuttle = ["dep:shuttle"] accumulator = ["salsa-macro-rules/accumulator"] macros = ["dep:salsa-macros"] From 5aab823a76aef3805948ac6988bfbce27f1a3643 Mon Sep 17 00:00:00 2001 From: Ibraheem Ahmed Date: Wed, 13 Aug 2025 13:04:05 -0400 Subject: [PATCH 31/65] optimize `Id::hash` (#974) --- src/id.rs | 11 +++++++++-- src/key.rs | 2 +- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/id.rs b/src/id.rs index c8712c102..e50141dbb 100644 --- a/src/id.rs +++ b/src/id.rs @@ -1,5 +1,5 @@ use std::fmt::Debug; -use std::hash::Hash; +use std::hash::{Hash, Hasher}; use std::num::NonZeroU32; use crate::zalsa::Zalsa; @@ -17,7 +17,7 @@ use crate::zalsa::Zalsa; /// /// As an end-user of `Salsa` you will generally not use `Id` directly, /// it is wrapped in new types. 
-#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] #[cfg_attr(feature = "persistence", derive(serde::Serialize, serde::Deserialize))] pub struct Id { index: NonZeroU32, @@ -126,6 +126,13 @@ impl Id { } } +impl Hash for Id { + fn hash(&self, state: &mut H) { + // Convert to a `u64` to avoid dispatching multiple calls to `H::write`. + state.write_u64(self.as_bits()); + } +} + impl Debug for Id { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { if self.generation() == 0 { diff --git a/src/key.rs b/src/key.rs index bb5604c5f..fa947575f 100644 --- a/src/key.rs +++ b/src/key.rs @@ -1,4 +1,4 @@ -use core::fmt; +use std::fmt; use crate::function::{VerifyCycleHeads, VerifyResult}; use crate::zalsa::{IngredientIndex, Zalsa}; From 34882a129b9687519e5206d478828628fa9fd6db Mon Sep 17 00:00:00 2001 From: Ibraheem Ahmed Date: Wed, 13 Aug 2025 16:49:01 -0400 Subject: [PATCH 32/65] Flatten unserializable query dependencies (#975) * flatten unserializable query dependencies * avoid allocating intermediary serialized dependency edges --- .../salsa-macro-rules/src/setup_tracked_fn.rs | 2 +- src/accumulator.rs | 11 + src/function.rs | 100 +++-- src/function/memo.rs | 34 +- src/id.rs | 21 +- src/ingredient.rs | 23 +- src/input.rs | 13 +- src/input/input_field.rs | 19 +- src/interned.rs | 20 + src/revision.rs | 1 + src/tracked_struct.rs | 23 +- src/tracked_struct/tracked_field.rs | 24 +- src/zalsa_local.rs | 87 +++-- tests/interned-revisions.rs | 2 +- tests/persistence.rs | 348 +++++++++++++++--- 15 files changed, 608 insertions(+), 120 deletions(-) diff --git a/components/salsa-macro-rules/src/setup_tracked_fn.rs b/components/salsa-macro-rules/src/setup_tracked_fn.rs index 51239802c..e252f068f 100644 --- a/components/salsa-macro-rules/src/setup_tracked_fn.rs +++ b/components/salsa-macro-rules/src/setup_tracked_fn.rs @@ -172,7 +172,7 @@ macro_rules! 
setup_tracked_fn { line: line!(), }; const DEBUG_NAME: &'static str = concat!($(stringify!($self_ty), "::",)? stringify!($fn_name), "::interned_arguments"); - const PERSIST: bool = true; + const PERSIST: bool = $persist; type Fields<$db_lt> = ($($interned_input_ty),*); diff --git a/src/accumulator.rs b/src/accumulator.rs index 76ecc4678..b05aa64f1 100644 --- a/src/accumulator.rs +++ b/src/accumulator.rs @@ -8,11 +8,13 @@ use std::panic::UnwindSafe; use accumulated::{Accumulated, AnyAccumulated}; use crate::function::{VerifyCycleHeads, VerifyResult}; +use crate::hash::FxIndexSet; use crate::ingredient::{Ingredient, Jar}; use crate::plumbing::ZalsaLocal; use crate::sync::Arc; use crate::table::memo::MemoTableTypes; use crate::zalsa::{IngredientIndex, JarKind, Zalsa}; +use crate::zalsa_local::QueryEdge; use crate::{Database, Id, Revision}; mod accumulated; @@ -110,6 +112,15 @@ impl Ingredient for IngredientImpl { panic!("nothing should ever depend on an accumulator directly") } + fn collect_minimum_serialized_edges( + &self, + _zalsa: &Zalsa, + _edge: QueryEdge, + _serialized_edges: &mut FxIndexSet, + ) { + panic!("nothing should ever depend on an accumulator directly") + } + fn debug_name(&self) -> &'static str { A::DEBUG_NAME } diff --git a/src/function.rs b/src/function.rs index 28b76a0cc..711cdf723 100644 --- a/src/function.rs +++ b/src/function.rs @@ -13,6 +13,7 @@ use crate::cycle::{ use crate::database::RawDatabase; use crate::function::delete::DeletedEntries; use crate::function::sync::{ClaimResult, SyncTable}; +use crate::hash::FxIndexSet; use crate::ingredient::{Ingredient, WaitForResult}; use crate::key::DatabaseKeyIndex; use crate::plumbing::{self, MemoIngredientMap}; @@ -22,7 +23,7 @@ use crate::table::memo::MemoTableTypes; use crate::table::Table; use crate::views::DatabaseDownCaster; use crate::zalsa::{IngredientIndex, JarKind, MemoIngredientIndex, Zalsa}; -use crate::zalsa_local::QueryOriginRef; +use crate::zalsa_local::{QueryEdge, QueryOriginRef}; use 
crate::{Id, Revision}; #[cfg(feature = "accumulator")] @@ -277,7 +278,7 @@ where unsafe fn maybe_changed_after( &self, - _zalsa: &crate::zalsa::Zalsa, + _zalsa: &Zalsa, db: RawDatabase<'_>, input: Id, revision: Revision, @@ -288,6 +289,29 @@ where self.maybe_changed_after(db, input, revision, cycle_heads) } + fn collect_minimum_serialized_edges( + &self, + zalsa: &Zalsa, + edge: QueryEdge, + serialized_edges: &mut FxIndexSet, + ) { + let input = edge.key().key_index(); + + let Some(memo) = + self.get_memo_from_table_for(zalsa, input, self.memo_ingredient_index(zalsa, input)) + else { + return; + }; + + let origin = memo.revisions.origin.as_ref(); + + // Collect the minimum dependency tree. + for edge in origin.edges() { + let dependency = zalsa.lookup_ingredient(edge.key().ingredient_index()); + dependency.collect_minimum_serialized_edges(zalsa, *edge, serialized_edges) + } + } + /// Returns `final` only if the memo has the `verified_final` flag set and the cycle recovery strategy is not `FallbackImmediate`. /// /// Otherwise, the value is still provisional. 
For both final and provisional, it also @@ -470,8 +494,10 @@ where #[cfg(feature = "persistence")] mod persistence { use super::{Configuration, IngredientImpl, Memo}; - use crate::plumbing::{Ingredient, MemoIngredientMap, SalsaStructInDb}; + use crate::hash::FxIndexSet; + use crate::plumbing::{MemoIngredientMap, SalsaStructInDb}; use crate::zalsa::Zalsa; + use crate::zalsa_local::{QueryEdge, QueryOrigin, QueryOriginRef}; use crate::{Id, IngredientIndex}; use serde::de; @@ -499,16 +525,6 @@ mod persistence { let mut map = serializer.serialize_map(None)?; - for struct_index in - as SalsaStructInDb>::lookup_ingredient_index(zalsa).iter() - { - let struct_ingredient = zalsa.lookup_ingredient(struct_index); - assert!( - struct_ingredient.is_persistable(), - "the input of a serialized tracked function must be serialized" - ); - } - for entry in as SalsaStructInDb>::entries(zalsa) { let memo_ingredient_index = ingredient .memo_ingredient_indices @@ -521,18 +537,30 @@ mod persistence { ); if let Some(memo) = memo.filter(|memo| memo.should_serialize()) { - for edge in memo.revisions.origin.as_ref().edges() { - let dependency = zalsa.lookup_ingredient(edge.key().ingredient_index()); - - // TODO: This is not strictly necessary, we only need the transitive input - // dependencies of this query to serialize a valid memo. - assert!( - dependency.is_persistable(), - "attempted to serialize query `{}`, but dependency `{}` is not persistable", - ingredient.debug_name(), - dependency.debug_name() - ); - } + // Flatten the dependencies of this query down to the base inputs. 
+ let flattened_origin = match memo.revisions.origin.as_ref() { + QueryOriginRef::Derived(edges) => { + QueryOrigin::derived(flatten_edges(zalsa, edges)) + } + QueryOriginRef::DerivedUntracked(edges) => { + QueryOrigin::derived_untracked(flatten_edges(zalsa, edges)) + } + QueryOriginRef::Assigned(key) => { + let dependency = zalsa.lookup_ingredient(key.ingredient_index()); + assert!( + dependency.is_persistable(), + "specified query `{}` must be persistable", + dependency.debug_name() + ); + + QueryOrigin::assigned(key) + } + QueryOriginRef::FixpointInitial => unreachable!( + "`should_serialize` returns `false` for provisional queries" + ), + }; + + let memo = memo.with_origin(flattened_origin); // TODO: Group structs by ingredient index into a nested map. let key = format!( @@ -541,7 +569,7 @@ mod persistence { entry.key_index().as_bits() ); - map.serialize_entry(&key, memo)?; + map.serialize_entry(&key, &memo)?; } } @@ -549,6 +577,26 @@ mod persistence { } } + // Flatten the dependency edges before serialization. + fn flatten_edges(zalsa: &Zalsa, edges: &[QueryEdge]) -> FxIndexSet { + let mut flattened_edges = + FxIndexSet::with_capacity_and_hasher(edges.len(), Default::default()); + + for &edge in edges { + let dependency = zalsa.lookup_ingredient(edge.key().ingredient_index()); + + if dependency.is_persistable() { + // If the dependency will be serialized, we can serialize the edge directly. + flattened_edges.insert(edge); + } else { + // Otherwise, serialize the minimum edges necessary to cover the dependency. 
+ dependency.collect_minimum_serialized_edges(zalsa, edge, &mut flattened_edges); + } + } + + flattened_edges + } + pub struct DeserializeIngredient<'db, C> where C: Configuration, diff --git a/src/function/memo.rs b/src/function/memo.rs index 54df61420..4894cc642 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -359,12 +359,36 @@ mod persistence { use crate::function::memo::Memo; use crate::function::Configuration; use crate::revision::AtomicRevision; - use crate::zalsa_local::QueryRevisions; + use crate::zalsa_local::persistence::MappedQueryRevisions; + use crate::zalsa_local::{QueryOrigin, QueryRevisions}; use serde::ser::SerializeStruct; use serde::Deserialize; - impl serde::Serialize for Memo<'_, C> + /// A reference to the fields of a [`Memo`], with its [`QueryRevisions`] transformed. + pub(crate) struct MappedMemo<'memo, 'db, C: Configuration> { + value: Option<&'memo C::Output<'db>>, + verified_at: AtomicRevision, + revisions: MappedQueryRevisions<'memo>, + } + + impl<'db, C: Configuration> Memo<'db, C> { + pub(crate) fn with_origin(&self, origin: QueryOrigin) -> MappedMemo<'_, 'db, C> { + let Memo { + ref verified_at, + ref value, + ref revisions, + } = *self; + + MappedMemo { + value: value.as_ref(), + verified_at: AtomicRevision::from(verified_at.load()), + revisions: revisions.with_origin(origin), + } + } + } + + impl serde::Serialize for MappedMemo<'_, '_, C> where C: Configuration, { @@ -386,13 +410,15 @@ mod persistence { } } - let Memo { + let MappedMemo { value, verified_at, revisions, } = self; - let value = value.as_ref().expect("attempted to serialize empty memo"); + let value = value.expect( + "attempted to serialize memo where `Memo::should_serialize` returned `false`", + ); let mut s = serializer.serialize_struct("Memo", 3)?; s.serialize_field("value", &SerializeValue::(value))?; diff --git a/src/id.rs b/src/id.rs index e50141dbb..bc2565410 100644 --- a/src/id.rs +++ b/src/id.rs @@ -18,7 +18,6 @@ use crate::zalsa::Zalsa; /// As 
an end-user of `Salsa` you will generally not use `Id` directly, /// it is wrapped in new types. #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] -#[cfg_attr(feature = "persistence", derive(serde::Serialize, serde::Deserialize))] pub struct Id { index: NonZeroU32, generation: u32, @@ -133,6 +132,26 @@ impl Hash for Id { } } +#[cfg(feature = "persistence")] +impl serde::Serialize for Id { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serde::Serialize::serialize(&self.as_bits(), serializer) + } +} + +#[cfg(feature = "persistence")] +impl<'de> serde::Deserialize<'de> for Id { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + serde::Deserialize::deserialize(deserializer).map(Self::from_bits) + } +} + impl Debug for Id { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { if self.generation() == 0 { diff --git a/src/ingredient.rs b/src/ingredient.rs index 4cd857962..73377a3b7 100644 --- a/src/ingredient.rs +++ b/src/ingredient.rs @@ -6,12 +6,13 @@ use crate::cycle::{ }; use crate::database::RawDatabase; use crate::function::{VerifyCycleHeads, VerifyResult}; +use crate::hash::FxIndexSet; use crate::runtime::Running; use crate::sync::Arc; use crate::table::memo::MemoTableTypes; use crate::table::Table; use crate::zalsa::{transmute_data_mut_ptr, transmute_data_ptr, IngredientIndex, JarKind, Zalsa}; -use crate::zalsa_local::QueryOriginRef; +use crate::zalsa_local::{QueryEdge, QueryOriginRef}; use crate::{DatabaseKeyIndex, Id, Revision}; /// A "jar" is a group of ingredients that are added atomically. @@ -55,6 +56,20 @@ pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { cycle_heads: &mut VerifyCycleHeads, ) -> VerifyResult; + /// Collects the minimum edges necessary to serialize a given dependency edge on this ingredient, + /// without necessarily serializing the dependency edge itself. + /// + /// This generally only returns any transitive input dependencies, i.e. 
the leaves of the dependency + /// tree, as most other fine-grained dependencies are covered by the inputs. + /// + /// Note that any ingredients returned by this function must be persistable. + fn collect_minimum_serialized_edges( + &self, + zalsa: &Zalsa, + edge: QueryEdge, + serialized_edges: &mut FxIndexSet, + ); + /// Returns information about the current provisional status of `input`. /// /// Is it a provisional value or has it been finalized and in which iteration. @@ -217,7 +232,7 @@ pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { _zalsa: &'db Zalsa, _f: &mut dyn FnMut(&dyn erased_serde::Serialize), ) { - unimplemented!("called `serialize` on ingredient where `is_persistable` returns `false`") + unimplemented!("called `serialize` on ingredient where `should_serialize` returns `false`") } /// Deserialize the ingredient. @@ -227,7 +242,9 @@ pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { _zalsa: &mut Zalsa, _deserializer: &mut dyn erased_serde::Deserializer, ) -> Result<(), erased_serde::Error> { - unimplemented!("called `deserialize` on ingredient where `is_persistable` returns `false`") + unimplemented!( + "called `deserialize` on ingredient where `should_serialize` returns `false`" + ) } } diff --git a/src/input.rs b/src/input.rs index fd8fc018a..31b4cc8db 100644 --- a/src/input.rs +++ b/src/input.rs @@ -9,6 +9,7 @@ pub mod singleton; use input_field::FieldIngredientImpl; use crate::function::{VerifyCycleHeads, VerifyResult}; +use crate::hash::FxIndexSet; use crate::id::{AsId, FromId, FromIdWithDb}; use crate::ingredient::Ingredient; use crate::input::singleton::{Singleton, SingletonChoice}; @@ -18,6 +19,7 @@ use crate::sync::Arc; use crate::table::memo::{MemoTable, MemoTableTypes}; use crate::table::{Slot, Table}; use crate::zalsa::{IngredientIndex, JarKind, Zalsa}; +use crate::zalsa_local::QueryEdge; use crate::{Durability, Id, Revision, Runtime}; pub trait Configuration: Any { @@ -274,7 +276,16 @@ impl Ingredient for 
IngredientImpl { ) -> VerifyResult { // Input ingredients are just a counter, they store no data, they are immortal. // Their *fields* are stored in function ingredients elsewhere. - VerifyResult::unchanged() + panic!("nothing should ever depend on an input struct directly") + } + + fn collect_minimum_serialized_edges( + &self, + _zalsa: &Zalsa, + _edge: QueryEdge, + _serialized_edges: &mut FxIndexSet, + ) { + panic!("nothing should ever depend on an input struct directly") } fn debug_name(&self) -> &'static str { diff --git a/src/input/input_field.rs b/src/input/input_field.rs index 7bfeb507a..5b8f0706f 100644 --- a/src/input/input_field.rs +++ b/src/input/input_field.rs @@ -2,11 +2,13 @@ use std::fmt; use std::marker::PhantomData; use crate::function::{VerifyCycleHeads, VerifyResult}; +use crate::hash::FxIndexSet; use crate::ingredient::Ingredient; use crate::input::{Configuration, IngredientImpl, Value}; use crate::sync::Arc; use crate::table::memo::MemoTableTypes; use crate::zalsa::{IngredientIndex, JarKind, Zalsa}; +use crate::zalsa_local::QueryEdge; use crate::{Id, Revision}; /// Ingredient used to represent the fields of a `#[salsa::input]`. @@ -51,7 +53,7 @@ where unsafe fn maybe_changed_after( &self, - zalsa: &crate::zalsa::Zalsa, + zalsa: &Zalsa, _db: crate::database::RawDatabase<'_>, input: Id, revision: Revision, @@ -61,6 +63,21 @@ where VerifyResult::changed_if(value.revisions[self.field_index] > revision) } + fn collect_minimum_serialized_edges( + &self, + _zalsa: &Zalsa, + edge: QueryEdge, + serialized_edges: &mut FxIndexSet, + ) { + assert!( + C::PERSIST, + "the inputs of a persistable tracked function must be persistable" + ); + + // Input dependencies are the leaves of the minimum dependency tree. 
+ serialized_edges.insert(edge); + } + fn fmt_index(&self, index: crate::Id, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { write!( fmt, diff --git a/src/interned.rs b/src/interned.rs index ff25f26e9..c425b9e24 100644 --- a/src/interned.rs +++ b/src/interned.rs @@ -12,6 +12,7 @@ use rustc_hash::FxBuildHasher; use crate::durability::Durability; use crate::function::{VerifyCycleHeads, VerifyResult}; +use crate::hash::FxIndexSet; use crate::id::{AsId, FromId}; use crate::ingredient::Ingredient; use crate::plumbing::{self, Jar, ZalsaLocal}; @@ -20,6 +21,7 @@ use crate::sync::{Arc, Mutex, OnceLock}; use crate::table::memo::{MemoTable, MemoTableTypes, MemoTableWithTypesMut}; use crate::table::Slot; use crate::zalsa::{IngredientIndex, JarKind, Zalsa}; +use crate::zalsa_local::QueryEdge; use crate::{DatabaseKeyIndex, Event, EventKind, Id, Revision}; /// Trait that defines the key properties of an interned struct. @@ -896,6 +898,24 @@ where VerifyResult::unchanged() } + fn collect_minimum_serialized_edges( + &self, + _zalsa: &Zalsa, + edge: QueryEdge, + serialized_edges: &mut FxIndexSet, + ) { + if C::PERSIST { + // If the interned struct is being persisted, it may be reachable through transitive queries. + // Additionally, interned struct dependencies are impure in that garbage collection can + // invalidate a dependency without a base input necessarily being updated. Thus, we must + // preserve the transitive dependency on the interned struct. + serialized_edges.insert(edge); + } + + // Otherwise, the dependency is covered by the base inputs, as the interned struct itself is + // not being persisted. 
+ } + fn debug_name(&self) -> &'static str { C::DEBUG_NAME } diff --git a/src/revision.rs b/src/revision.rs index 852313e0d..22f1463d7 100644 --- a/src/revision.rs +++ b/src/revision.rs @@ -139,6 +139,7 @@ impl AtomicRevision { pub(crate) struct OptionalAtomicRevision { data: AtomicUsize, } + #[cfg(feature = "persistence")] impl serde::Serialize for OptionalAtomicRevision { fn serialize(&self, serializer: S) -> Result diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index 51e3d286d..6ad91321a 100644 --- a/src/tracked_struct.rs +++ b/src/tracked_struct.rs @@ -12,6 +12,7 @@ use thin_vec::ThinVec; use tracked_field::FieldIngredientImpl; use crate::function::{VerifyCycleHeads, VerifyResult}; +use crate::hash::FxIndexSet; use crate::id::{AsId, FromId}; use crate::ingredient::{Ingredient, Jar}; use crate::key::DatabaseKeyIndex; @@ -23,6 +24,7 @@ use crate::sync::Arc; use crate::table::memo::{MemoTable, MemoTableTypes, MemoTableWithTypesMut}; use crate::table::{Slot, Table}; use crate::zalsa::{IngredientIndex, JarKind, Zalsa}; +use crate::zalsa_local::QueryEdge; use crate::{Durability, Event, EventKind, Id, Revision}; pub mod tracked_field; @@ -937,8 +939,25 @@ where _revision: Revision, _cycle_heads: &mut VerifyCycleHeads, ) -> VerifyResult { - // Any change to a tracked struct results in a new ID generation. - VerifyResult::unchanged() + // Any change to a tracked struct results in a new ID generation, so there + // are no direct dependencies on the struct, only on its tracked fields. + panic!("nothing should ever depend on a tracked struct directly") + } + + fn collect_minimum_serialized_edges( + &self, + _zalsa: &Zalsa, + _edge: QueryEdge, + _serialized_edges: &mut FxIndexSet, + ) { + // Note that tracked structs are referenced by the identity map, but that + // only matters if we are serializing the creating query, in which case + // the dependency edge will be serialized directly. 
+ // + // TODO: We could flatten the identity map here if the tracked struct is being + // persisted, in order to more aggressively preserve the tracked struct IDs if + // the transitive query is re-executed. + panic!("nothing should ever depend on a tracked struct directly") } fn remove_stale_output( diff --git a/src/tracked_struct/tracked_field.rs b/src/tracked_struct/tracked_field.rs index 20c59b81c..d4e90e278 100644 --- a/src/tracked_struct/tracked_field.rs +++ b/src/tracked_struct/tracked_field.rs @@ -1,11 +1,13 @@ use std::marker::PhantomData; use crate::function::{VerifyCycleHeads, VerifyResult}; +use crate::hash::FxIndexSet; use crate::ingredient::Ingredient; use crate::sync::Arc; use crate::table::memo::MemoTableTypes; use crate::tracked_struct::{Configuration, Value}; -use crate::zalsa::{IngredientIndex, JarKind}; +use crate::zalsa::{IngredientIndex, JarKind, Zalsa}; +use crate::zalsa_local::QueryEdge; use crate::Id; /// Created for each tracked struct. @@ -67,6 +69,16 @@ where VerifyResult::changed_if(field_changed_at > revision) } + fn collect_minimum_serialized_edges( + &self, + _zalsa: &Zalsa, + _edge: QueryEdge, + _serialized_edges: &mut FxIndexSet, + ) { + // Tracked fields do not have transitive dependencies, and their dependencies are covered by + // the base inputs. + } + fn fmt_index(&self, index: crate::Id, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( fmt, @@ -92,6 +104,16 @@ where fn memo_table_types_mut(&mut self) -> &mut Arc { unreachable!("tracked field does not allocate pages") } + + fn is_persistable(&self) -> bool { + // Tracked field dependencies are valid as long as the tracked struct is persistable. + C::PERSIST + } + + fn should_serialize(&self, _zalsa: &Zalsa) -> bool { + // However, they are never serialized directly. 
+ false + } } impl std::fmt::Debug for FieldIngredientImpl diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs index 74cd102fd..ced5e9281 100644 --- a/src/zalsa_local.rs +++ b/src/zalsa_local.rs @@ -413,8 +413,7 @@ impl ZalsaLocal { // - neither can `query_stack` as we require the closures accessing it to be `UnwindSafe` impl std::panic::RefUnwindSafe for ZalsaLocal {} -/// Summarizes "all the inputs that a query used" -/// and "all the outputs it has written to" +/// Summarizes "all the inputs that a query used" and "all the outputs it has written to". #[derive(Debug)] #[cfg_attr(feature = "persistence", derive(serde::Serialize, serde::Deserialize))] // #[derive(Clone)] cloning this is expensive, so we don't derive @@ -434,8 +433,7 @@ pub(crate) struct QueryRevisions { /// Note that this field could be in `QueryRevisionsExtra` as it is only relevant /// for accumulators, but we get it for free anyways due to padding. #[cfg(feature = "accumulator")] - // TODO: Support serializing accumulators. - #[cfg_attr(feature = "persistence", serde(skip))] + #[cfg_attr(feature = "persistence", serde(skip))] // TODO: Support serializing accumulators pub(super) accumulated_inputs: AtomicInputAccumulatedValues, /// Are the `cycle_heads` verified to not be provisional anymore? @@ -443,33 +441,13 @@ pub(crate) struct QueryRevisions { /// Note that this field could be in `QueryRevisionsExtra` as it is only /// relevant for queries that participate in a cycle, but we get it for /// free anyways due to padding. - #[cfg_attr(feature = "persistence", serde(with = "verified_final"))] + #[cfg_attr(feature = "persistence", serde(with = "persistence::verified_final"))] pub(super) verified_final: AtomicBool, /// Lazily allocated state. pub(super) extra: QueryRevisionsExtra, } -#[cfg(feature = "persistence")] -// A workaround the fact that `shuttle` atomic types do not implement `serde::{Serialize, Deserialize}`. 
-mod verified_final { - use crate::sync::atomic::{AtomicBool, Ordering}; - - pub fn serialize(value: &AtomicBool, serializer: S) -> Result - where - S: serde::Serializer, - { - serde::Serialize::serialize(&value.load(Ordering::Relaxed), serializer) - } - - pub fn deserialize<'de, D>(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - serde::Deserialize::deserialize(deserializer).map(AtomicBool::new) - } -} - impl QueryRevisions { #[cfg(feature = "salsa_unstable")] pub(crate) fn allocation_size(&self) -> usize { @@ -1178,3 +1156,62 @@ impl Drop for ActiveQueryGuard<'_> { }; } } + +#[cfg(feature = "persistence")] +pub(crate) mod persistence { + use super::{QueryOrigin, QueryRevisions, QueryRevisionsExtra}; + use crate::sync::atomic::{AtomicBool, Ordering}; + use crate::{Durability, Revision}; + + /// A reference to the fields of [`QueryRevisions`], with its [`QueryOrigin`] transformed. + #[derive(serde::Serialize)] + pub(crate) struct MappedQueryRevisions<'a> { + changed_at: Revision, + durability: Durability, + origin: QueryOrigin, + #[serde(with = "verified_final")] + verified_final: AtomicBool, + extra: &'a QueryRevisionsExtra, + } + + impl QueryRevisions { + pub(crate) fn with_origin(&self, origin: QueryOrigin) -> MappedQueryRevisions<'_> { + let QueryRevisions { + changed_at, + durability, + ref verified_final, + ref extra, + #[cfg(feature = "accumulator")] + accumulated_inputs: _, // TODO: Support serializing accumulators + origin: _, + } = *self; + + MappedQueryRevisions { + changed_at, + durability, + extra, + origin, + verified_final: AtomicBool::new(verified_final.load(Ordering::Relaxed)), + } + } + } + + // A workaround the fact that `shuttle` atomic types do not implement `serde::{Serialize, Deserialize}`. 
+ pub(super) mod verified_final { + use crate::sync::atomic::{AtomicBool, Ordering}; + + pub fn serialize(value: &AtomicBool, serializer: S) -> Result + where + S: serde::Serializer, + { + serde::Serialize::serialize(&value.load(Ordering::Relaxed), serializer) + } + + pub fn deserialize<'de, D>(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + serde::Deserialize::deserialize(deserializer).map(AtomicBool::new) + } + } +} diff --git a/tests/interned-revisions.rs b/tests/interned-revisions.rs index 225f24d4f..bef1db61c 100644 --- a/tests/interned-revisions.rs +++ b/tests/interned-revisions.rs @@ -273,7 +273,7 @@ fn test_reuse_indirect() { #[salsa::tracked] fn intern_inner<'db>(db: &'db dyn Database, input: Input, value: usize) -> Interned<'db> { - let _i = input.field1(db); + let _i = input.field1(db); // Only low durability interned values are garbage collected. Interned::new(db, BadHash(value)) } diff --git a/tests/persistence.rs b/tests/persistence.rs index 379f14523..a1f424e3d 100644 --- a/tests/persistence.rs +++ b/tests/persistence.rs @@ -1,7 +1,9 @@ #![cfg(all(feature = "persistence", feature = "inventory"))] mod common; + use common::LogDatabase; +use salsa::{Database, Durability, Setter}; use expect_test::expect; @@ -159,7 +161,7 @@ fn everything() { ] } }, - "4": { + "3": { "3073": { "durability": "Low", "updated_at": 1, @@ -169,31 +171,25 @@ fn everything() { ] } }, - "6": { + "5": { "4097": { "durability": "High", "last_interned_at": 18446744073709551615, "fields": [ - { - "index": 3, - "generation": 0 - }, - { - "index": 4, - "generation": 0 - } + 3, + 4 ] } }, - "11": { + "17": { "1025": { "durability": "High", "last_interned_at": 18446744073709551615, "fields": null } }, - "5": { - "6:4097": { + "4": { + "5:4097": { "value": "aaa", "verified_at": 1, "revisions": { @@ -203,19 +199,13 @@ fn everything() { "Derived": [ { "key": { - "key_index": { - "index": 3, - "generation": 0 - }, + "key_index": 3, "ingredient_index": 1 } }, { 
"key": { - "key_index": { - "index": 4, - "generation": 0 - }, + "key_index": 4, "ingredient_index": 1 } } @@ -226,12 +216,9 @@ fn everything() { } } }, - "7": { + "6": { "0:3": { - "value": { - "index": 3073, - "generation": 0 - }, + "value": 3073, "verified_at": 1, "revisions": { "changed_at": 1, @@ -240,10 +227,7 @@ fn everything() { "Derived": [ { "key": { - "key_index": { - "index": 3, - "generation": 0 - }, + "key_index": 3, "ingredient_index": 1 } } @@ -254,14 +238,11 @@ fn everything() { "tracked_struct_ids": [ [ { - "ingredient_index": 4, + "ingredient_index": 3, "hash": 6073466998405137972, "disambiguator": 0 }, - { - "index": 3073, - "generation": 0 - } + 3073 ] ], "cycle_heads": [], @@ -270,12 +251,9 @@ fn everything() { } } }, - "10": { - "11:1025": { - "value": { - "index": 2049, - "generation": 0 - }, + "16": { + "17:1025": { + "value": 2049, "verified_at": 1, "revisions": { "changed_at": 1, @@ -284,10 +262,7 @@ fn everything() { "Derived": [ { "key": { - "key_index": { - "index": 2049, - "generation": 0 - }, + "key_index": 2049, "ingredient_index": 2 } } @@ -314,7 +289,7 @@ fn everything() { let _out = input_to_tracked(&db, input1); let _out = input_pair_to_string(&db, input1, input2); - // The structs are not recreated, and the queries are not reexecuted. + // The structs are not recreated, and the queries are not re-executed. db.assert_logs(expect![[r#" [ "DidSetCancellationFlag", @@ -325,21 +300,286 @@ fn everything() { } #[test] -#[should_panic(expected = "is not persistable")] -fn invalid_dependency() { - #[salsa::interned] - struct MyInterned<'db> { - field: usize, +fn partial_query() { + use salsa::plumbing::{FromId, ZalsaDatabase}; + + #[salsa::tracked(persist)] + fn query<'db>(db: &'db dyn salsa::Database, input: MyInput) -> usize { + inner_query(db, input) + 1 + } + + // Note that the inner query is not persisted, but we should still preserve the dependency on `input.field`. 
+ #[salsa::tracked] + fn inner_query<'db>(db: &'db dyn salsa::Database, input: MyInput) -> usize { + input.field(db) } + let mut db = common::EventLoggerDatabase::default(); + + let input = MyInput::new(&db, 0); + + let result = query(&db, input); + assert_eq!(result, 1); + + let serialized = + serde_json::to_string_pretty(&::as_serialize(&mut db)).unwrap(); + let expected = expect![[r#" + { + "runtime": { + "revisions": [ + 1, + 1, + 1 + ] + }, + "ingredients": { + "0": { + "1": { + "durabilities": [ + "Low" + ], + "revisions": [ + 1 + ], + "fields": [ + 0 + ] + } + }, + "11": { + "0:1": { + "value": 1, + "verified_at": 1, + "revisions": { + "changed_at": 1, + "durability": "Low", + "origin": { + "Derived": [ + { + "key": { + "key_index": 1, + "ingredient_index": 1 + } + } + ] + }, + "verified_final": true, + "extra": null + } + } + } + } + }"#]]; + expected.assert_eq(&serialized); + + let mut db = common::EventLoggerDatabase::default(); + ::deserialize( + &mut db, + &mut serde_json::Deserializer::from_str(&serialized), + ) + .unwrap(); + + // TODO: Expose a better way of recreating inputs after deserialization. + let (id, _) = MyInput::ingredient(&db) + .entries(db.zalsa()) + .next() + .expect("`MyInput` was persisted"); + let input = MyInput::from_id(id.key_index()); + + let result = query(&db, input); + assert_eq!(result, 1); + + // The query was not re-executed. + db.assert_logs(expect![[r#" + [ + "DidSetCancellationFlag", + "WillCheckCancellation", + ]"#]]); + + input.set_field(&mut db).to(1); + + let result = query(&db, input); + assert_eq!(result, 2); + + // The query was re-executed afer the input was updated. 
+ db.assert_logs(expect![[r#" + [ + "DidSetCancellationFlag", + "WillCheckCancellation", + "WillExecute { database_key: query(Id(0)) }", + "WillCheckCancellation", + "WillExecute { database_key: inner_query(Id(0)) }", + ]"#]]); +} + +#[test] +fn partial_query_interned() { + use salsa::plumbing::{AsId, FromId, ZalsaDatabase}; + #[salsa::tracked(persist)] - fn new_interned(db: &dyn salsa::Database) { - let _interned = MyInterned::new(db, 0); + fn intern<'db>(db: &'db dyn salsa::Database, input: MyInput, value: usize) -> MyInterned<'db> { + do_intern(db, input, value) + } + + // Note that the inner query is not persisted, but we should still preserve the dependency on `MyInterned`. + #[salsa::tracked] + fn do_intern<'db>( + db: &'db dyn salsa::Database, + input: MyInput, + value: usize, + ) -> MyInterned<'db> { + let _i = input.field(db); // Only low durability interned values are garbage collected. + MyInterned::new(db, value.to_string()) + } + + let mut db = common::EventLoggerDatabase::default(); + let input = MyInput::builder(0).durability(Durability::LOW).new(&db); + + // Intern `i0`. 
+ let i0 = intern(&db, input, 0); + assert_eq!(i0.field(&db), "0"); + + let serialized = + serde_json::to_string_pretty(&::as_serialize(&mut db)).unwrap(); + let expected = expect![[r#" + { + "runtime": { + "revisions": [ + 1, + 1, + 1 + ] + }, + "ingredients": { + "0": { + "1": { + "durabilities": [ + "Low" + ], + "revisions": [ + 1 + ], + "fields": [ + 0 + ] + } + }, + "2": { + "3073": { + "durability": "Low", + "last_interned_at": 1, + "fields": [ + "0" + ] + } + }, + "15": { + "1025": { + "durability": "High", + "last_interned_at": 18446744073709551615, + "fields": [ + 1, + 0 + ] + } + }, + "14": { + "15:1025": { + "value": 3073, + "verified_at": 1, + "revisions": { + "changed_at": 1, + "durability": "Low", + "origin": { + "Derived": [ + { + "key": { + "key_index": 1, + "ingredient_index": 1 + } + }, + { + "key": { + "key_index": 3073, + "ingredient_index": 2 + } + } + ] + }, + "verified_final": true, + "extra": null + } + } + } + } + }"#]]; + expected.assert_eq(&serialized); + + let mut db = common::EventLoggerDatabase::default(); + ::deserialize( + &mut db, + &mut serde_json::Deserializer::from_str(&serialized), + ) + .unwrap(); + + // TODO: Expose a better way of recreating inputs after deserialization. + let (id, _) = MyInput::ingredient(&db) + .entries(db.zalsa()) + .next() + .expect("`MyInput` was persisted"); + let input = MyInput::from_id(id.key_index()); + + // Re-intern `i0`. + let i0 = intern(&db, input, 0); + let i0_id = i0.as_id(); + assert_eq!(i0.field(&db), "0"); + + // The query was not re-executed. + db.assert_logs(expect![[r#" + [ + "DidSetCancellationFlag", + "WillCheckCancellation", + ]"#]]); + + // Get the garbage collector to consider `i0` stale. + for x in 1.. { + db.synthetic_write(Durability::LOW); + + let ix = intern(&db, input, x); + let ix_id = ix.as_id(); + + // We reused the slot of `i0`. + if ix_id.index() == i0_id.index() { + break; + } + } + + // Re-intern `i0` after is has been garbage collected. 
+ let i0 = intern(&db, input, 0); + + // The query was re-executed due to garbage collection, even though no inputs have changed + // and the inner query was not persisted. + assert_eq!(i0.field(&db), "0"); + assert_ne!(i0_id.index(), i0.as_id().index()); +} + +#[test] +#[should_panic(expected = "must be persistable")] +fn invalid_specified_dependency() { + #[salsa::tracked] + fn specify<'db>(db: &'db dyn salsa::Database) { + let tracked = MyTracked::new(db, "a".to_string()); + specified_query::specify(db, tracked, 2222); + } + + #[salsa::tracked(specify, persist)] + fn specified_query<'db>(_db: &'db dyn salsa::Database, _tracked: MyTracked<'db>) -> u32 { + 0 } let mut db = common::LoggerDatabase::default(); - new_interned(&db); + specify(&db); let _serialized = serde_json::to_string_pretty(&::as_serialize(&mut db)).unwrap(); From c380f1924b81d923a9bfb9b4a920874fd59e71cd Mon Sep 17 00:00:00 2001 From: Ibraheem Ahmed Date: Wed, 13 Aug 2025 21:47:27 -0400 Subject: [PATCH 33/65] fix assertion during interned deserialization (#978) --- src/interned.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/interned.rs b/src/interned.rs index c425b9e24..4ba3ea405 100644 --- a/src/interned.rs +++ b/src/interned.rs @@ -1492,7 +1492,8 @@ mod persistence { }; assert_eq!( - allocated_id, id, + allocated_id.index(), + id.index(), "values are serialized in allocation order" ); From 52a40577a72717c1b540881b805f8ec9831e1a72 Mon Sep 17 00:00:00 2001 From: Ibraheem Ahmed Date: Thu, 14 Aug 2025 11:19:23 -0400 Subject: [PATCH 34/65] avoid cycles during serialization (#977) --- src/accumulator.rs | 3 ++- src/function.rs | 32 +++++++++++++++++++++++++---- src/ingredient.rs | 3 ++- src/input.rs | 3 ++- src/input/input_field.rs | 6 ++++-- src/interned.rs | 3 ++- src/tracked_struct.rs | 3 ++- src/tracked_struct/tracked_field.rs | 3 ++- 8 files changed, 44 insertions(+), 12 deletions(-) diff --git a/src/accumulator.rs b/src/accumulator.rs index b05aa64f1..be27d3d96 100644 
--- a/src/accumulator.rs +++ b/src/accumulator.rs @@ -8,7 +8,7 @@ use std::panic::UnwindSafe; use accumulated::{Accumulated, AnyAccumulated}; use crate::function::{VerifyCycleHeads, VerifyResult}; -use crate::hash::FxIndexSet; +use crate::hash::{FxHashSet, FxIndexSet}; use crate::ingredient::{Ingredient, Jar}; use crate::plumbing::ZalsaLocal; use crate::sync::Arc; @@ -117,6 +117,7 @@ impl Ingredient for IngredientImpl { _zalsa: &Zalsa, _edge: QueryEdge, _serialized_edges: &mut FxIndexSet, + _visited_edges: &mut FxHashSet, ) { panic!("nothing should ever depend on an accumulator directly") } diff --git a/src/function.rs b/src/function.rs index 711cdf723..ae8564df5 100644 --- a/src/function.rs +++ b/src/function.rs @@ -13,7 +13,7 @@ use crate::cycle::{ use crate::database::RawDatabase; use crate::function::delete::DeletedEntries; use crate::function::sync::{ClaimResult, SyncTable}; -use crate::hash::FxIndexSet; +use crate::hash::{FxHashSet, FxIndexSet}; use crate::ingredient::{Ingredient, WaitForResult}; use crate::key::DatabaseKeyIndex; use crate::plumbing::{self, MemoIngredientMap}; @@ -294,6 +294,7 @@ where zalsa: &Zalsa, edge: QueryEdge, serialized_edges: &mut FxIndexSet, + visited_edges: &mut FxHashSet, ) { let input = edge.key().key_index(); @@ -305,10 +306,27 @@ where let origin = memo.revisions.origin.as_ref(); + visited_edges.insert(edge); + // Collect the minimum dependency tree. for edge in origin.edges() { + // Avoid forming cycles. + if visited_edges.contains(edge) { + continue; + } + + // Avoid flattening edges that we're going to serialize directly. 
+ if serialized_edges.contains(edge) { + continue; + } + let dependency = zalsa.lookup_ingredient(edge.key().ingredient_index()); - dependency.collect_minimum_serialized_edges(zalsa, *edge, serialized_edges) + dependency.collect_minimum_serialized_edges( + zalsa, + *edge, + serialized_edges, + visited_edges, + ) } } @@ -494,7 +512,7 @@ where #[cfg(feature = "persistence")] mod persistence { use super::{Configuration, IngredientImpl, Memo}; - use crate::hash::FxIndexSet; + use crate::hash::{FxHashSet, FxIndexSet}; use crate::plumbing::{MemoIngredientMap, SalsaStructInDb}; use crate::zalsa::Zalsa; use crate::zalsa_local::{QueryEdge, QueryOrigin, QueryOriginRef}; @@ -579,6 +597,7 @@ mod persistence { // Flatten the dependency edges before serialization. fn flatten_edges(zalsa: &Zalsa, edges: &[QueryEdge]) -> FxIndexSet { + let mut visited_edges = FxHashSet::default(); let mut flattened_edges = FxIndexSet::with_capacity_and_hasher(edges.len(), Default::default()); @@ -590,7 +609,12 @@ mod persistence { flattened_edges.insert(edge); } else { // Otherwise, serialize the minimum edges necessary to cover the dependency. 
- dependency.collect_minimum_serialized_edges(zalsa, edge, &mut flattened_edges); + dependency.collect_minimum_serialized_edges( + zalsa, + edge, + &mut flattened_edges, + &mut visited_edges, + ); } } diff --git a/src/ingredient.rs b/src/ingredient.rs index 73377a3b7..6493061b6 100644 --- a/src/ingredient.rs +++ b/src/ingredient.rs @@ -6,7 +6,7 @@ use crate::cycle::{ }; use crate::database::RawDatabase; use crate::function::{VerifyCycleHeads, VerifyResult}; -use crate::hash::FxIndexSet; +use crate::hash::{FxHashSet, FxIndexSet}; use crate::runtime::Running; use crate::sync::Arc; use crate::table::memo::MemoTableTypes; @@ -68,6 +68,7 @@ pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { zalsa: &Zalsa, edge: QueryEdge, serialized_edges: &mut FxIndexSet, + visited_edges: &mut FxHashSet, ); /// Returns information about the current provisional status of `input`. diff --git a/src/input.rs b/src/input.rs index 31b4cc8db..bf2d3f363 100644 --- a/src/input.rs +++ b/src/input.rs @@ -9,7 +9,7 @@ pub mod singleton; use input_field::FieldIngredientImpl; use crate::function::{VerifyCycleHeads, VerifyResult}; -use crate::hash::FxIndexSet; +use crate::hash::{FxHashSet, FxIndexSet}; use crate::id::{AsId, FromId, FromIdWithDb}; use crate::ingredient::Ingredient; use crate::input::singleton::{Singleton, SingletonChoice}; @@ -284,6 +284,7 @@ impl Ingredient for IngredientImpl { _zalsa: &Zalsa, _edge: QueryEdge, _serialized_edges: &mut FxIndexSet, + _visited_edges: &mut FxHashSet, ) { panic!("nothing should ever depend on an input struct directly") } diff --git a/src/input/input_field.rs b/src/input/input_field.rs index 5b8f0706f..f874142e8 100644 --- a/src/input/input_field.rs +++ b/src/input/input_field.rs @@ -2,7 +2,7 @@ use std::fmt; use std::marker::PhantomData; use crate::function::{VerifyCycleHeads, VerifyResult}; -use crate::hash::FxIndexSet; +use crate::hash::{FxHashSet, FxIndexSet}; use crate::ingredient::Ingredient; use crate::input::{Configuration, IngredientImpl, 
Value}; use crate::sync::Arc; @@ -68,10 +68,12 @@ where _zalsa: &Zalsa, edge: QueryEdge, serialized_edges: &mut FxIndexSet, + _visited_edges: &mut FxHashSet, ) { assert!( C::PERSIST, - "the inputs of a persistable tracked function must be persistable" + "the inputs of a persistable tracked function must be persistable: `{}` is not persistable", + C::DEBUG_NAME ); // Input dependencies are the leaves of the minimum dependency tree. diff --git a/src/interned.rs b/src/interned.rs index 4ba3ea405..548f61c31 100644 --- a/src/interned.rs +++ b/src/interned.rs @@ -12,7 +12,7 @@ use rustc_hash::FxBuildHasher; use crate::durability::Durability; use crate::function::{VerifyCycleHeads, VerifyResult}; -use crate::hash::FxIndexSet; +use crate::hash::{FxHashSet, FxIndexSet}; use crate::id::{AsId, FromId}; use crate::ingredient::Ingredient; use crate::plumbing::{self, Jar, ZalsaLocal}; @@ -903,6 +903,7 @@ where _zalsa: &Zalsa, edge: QueryEdge, serialized_edges: &mut FxIndexSet, + _visited_edges: &mut FxHashSet, ) { if C::PERSIST { // If the interned struct is being persisted, it may be reachable through transitive queries. 
diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index 6ad91321a..7ef998a4b 100644 --- a/src/tracked_struct.rs +++ b/src/tracked_struct.rs @@ -12,7 +12,7 @@ use thin_vec::ThinVec; use tracked_field::FieldIngredientImpl; use crate::function::{VerifyCycleHeads, VerifyResult}; -use crate::hash::FxIndexSet; +use crate::hash::{FxHashSet, FxIndexSet}; use crate::id::{AsId, FromId}; use crate::ingredient::{Ingredient, Jar}; use crate::key::DatabaseKeyIndex; @@ -949,6 +949,7 @@ where _zalsa: &Zalsa, _edge: QueryEdge, _serialized_edges: &mut FxIndexSet, + _visited_edges: &mut FxHashSet, ) { // Note that tracked structs are referenced by the identity map, but that // only matters if we are serializing the creating query, in which case diff --git a/src/tracked_struct/tracked_field.rs b/src/tracked_struct/tracked_field.rs index d4e90e278..abe435bea 100644 --- a/src/tracked_struct/tracked_field.rs +++ b/src/tracked_struct/tracked_field.rs @@ -1,7 +1,7 @@ use std::marker::PhantomData; use crate::function::{VerifyCycleHeads, VerifyResult}; -use crate::hash::FxIndexSet; +use crate::hash::{FxHashSet, FxIndexSet}; use crate::ingredient::Ingredient; use crate::sync::Arc; use crate::table::memo::MemoTableTypes; @@ -74,6 +74,7 @@ where _zalsa: &Zalsa, _edge: QueryEdge, _serialized_edges: &mut FxIndexSet, + _visited_edges: &mut FxHashSet, ) { // Tracked fields do not have transitive dependencies, and their dependencies are covered by // the base inputs. 
From 411f8448e8db6dab3246dfdab97bf3e08fb0d98d Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Tue, 19 Aug 2025 13:39:23 +0200 Subject: [PATCH 35/65] Update snapshot to fix nightly type rendering (#983) --- tests/memory-usage.rs | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/memory-usage.rs b/tests/memory-usage.rs index 37d434292..a3dbac80f 100644 --- a/tests/memory-usage.rs +++ b/tests/memory-usage.rs @@ -1,7 +1,5 @@ #![cfg(feature = "inventory")] -use expect_test::expect; - #[salsa::input(heap_size = string_tuple_size_of)] struct MyInput { field: String, @@ -56,8 +54,11 @@ fn input_to_tracked_tuple<'db>( ) } +#[rustversion::all(stable, since(1.91))] #[test] fn test() { + use expect_test::expect; + let db = salsa::DatabaseImpl::new(); let input1 = MyInput::new(&db, "a".repeat(50)); @@ -133,7 +134,7 @@ fn test() { ( "input_to_interned", IngredientInfo { - debug_name: "memory_usage::MyInterned", + debug_name: "memory_usage::MyInterned<'_>", count: 3, size_of_metadata: 192, size_of_fields: 24, @@ -165,7 +166,7 @@ fn test() { ( "input_to_tracked", IngredientInfo { - debug_name: "memory_usage::MyTracked", + debug_name: "memory_usage::MyTracked<'_>", count: 2, size_of_metadata: 168, size_of_fields: 16, @@ -175,7 +176,7 @@ fn test() { ( "input_to_tracked_tuple", IngredientInfo { - debug_name: "(memory_usage::MyTracked, memory_usage::MyTracked)", + debug_name: "(memory_usage::MyTracked<'_>, memory_usage::MyTracked<'_>)", count: 1, size_of_metadata: 108, size_of_fields: 16, From 0656eca815f68656410f4cd631a14b1ec4134ed1 Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Tue, 19 Aug 2025 13:59:11 +0200 Subject: [PATCH 36/65] fix: Delete not re-created tracked structs after fixpoint iteration (#979) * Fix tracked structs diffing in cycles * Proper fix * Clippy * Add regression test * Discard changes to src/function/maybe_changed_after.rs * Improve comemnt * Suppress clippy error in position where I don't control the types --- 
examples/calc/db.rs | 1 + src/active_query.rs | 13 ++-- src/function/diff_outputs.rs | 5 +- src/function/execute.rs | 22 ++++-- src/function/memo.rs | 2 +- src/tracked_struct.rs | 22 ++---- src/zalsa_local.rs | 10 +-- tests/cycle_tracked.rs | 131 +++++++++++++++++++++++++++++++++-- 8 files changed, 169 insertions(+), 37 deletions(-) diff --git a/examples/calc/db.rs b/examples/calc/db.rs index 63cc4fe12..05e06c0d0 100644 --- a/examples/calc/db.rs +++ b/examples/calc/db.rs @@ -48,6 +48,7 @@ impl CalcDatabaseImpl { } #[cfg(test)] + #[allow(unused)] pub fn take_logs(&self) -> Vec { let mut logs = self.logs.lock().unwrap(); if let Some(logs) = &mut *logs { diff --git a/src/active_query.rs b/src/active_query.rs index cc5e4fc58..00d0f5338 100644 --- a/src/active_query.rs +++ b/src/active_query.rs @@ -5,8 +5,6 @@ use crate::accumulator::{ accumulated_map::{AccumulatedMap, AtomicInputAccumulatedValues, InputAccumulatedValues}, Accumulator, }; -use crate::cycle::{CycleHeads, IterationCount}; -use crate::durability::Durability; use crate::hash::FxIndexSet; use crate::key::DatabaseKeyIndex; use crate::runtime::Stamp; @@ -14,6 +12,11 @@ use crate::sync::atomic::AtomicBool; use crate::tracked_struct::{Disambiguator, DisambiguatorMap, IdentityHash, IdentityMap}; use crate::zalsa_local::{QueryEdge, QueryOrigin, QueryRevisions, QueryRevisionsExtra}; use crate::Revision; +use crate::{ + cycle::{CycleHeads, IterationCount}, + Id, +}; +use crate::{durability::Durability, tracked_struct::Identity}; #[derive(Debug)] pub(crate) struct ActiveQuery { @@ -74,6 +77,7 @@ impl ActiveQuery { changed_at: Revision, edges: &[QueryEdge], untracked_read: bool, + active_tracked_ids: &[(Identity, Id)], ) { assert!(self.input_outputs.is_empty()); @@ -83,7 +87,8 @@ impl ActiveQuery { self.untracked_read |= untracked_read; // Mark all tracked structs from the previous iteration as active. 
- self.tracked_struct_ids.mark_all_active(); + self.tracked_struct_ids + .mark_all_active(active_tracked_ids.iter().copied()); } pub(super) fn add_read( @@ -408,7 +413,7 @@ pub(crate) struct CompletedQuery { /// The keys of any tracked structs that were created in a previous execution of the /// query but not the current one, and should be marked as stale. - pub(crate) stale_tracked_structs: Vec, + pub(crate) stale_tracked_structs: Vec<(Identity, Id)>, } struct CapturedQuery { diff --git a/src/function/diff_outputs.rs b/src/function/diff_outputs.rs index 923a0fc88..003310ae1 100644 --- a/src/function/diff_outputs.rs +++ b/src/function/diff_outputs.rs @@ -27,8 +27,9 @@ where // Note that tracked structs are not stored as direct query outputs, but they are still outputs // that need to be reported as stale. - for output in &completed_query.stale_tracked_structs { - Self::report_stale_output(zalsa, key, *output); + for (identity, id) in &completed_query.stale_tracked_structs { + let output = DatabaseKeyIndex::new(identity.ingredient_index(), *id); + Self::report_stale_output(zalsa, key, output); } let mut stale_outputs = output_edges(edges).collect::>(); diff --git a/src/function/execute.rs b/src/function/execute.rs index 67cee969d..d1651859e 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -3,6 +3,7 @@ use crate::cycle::{CycleRecoveryStrategy, IterationCount}; use crate::function::memo::Memo; use crate::function::{Configuration, IngredientImpl}; use crate::sync::atomic::{AtomicBool, Ordering}; +use crate::tracked_struct::Identity; use crate::zalsa::{MemoIngredientIndex, Zalsa, ZalsaDatabase}; use crate::zalsa_local::ActiveQueryGuard; use crate::{Event, EventKind, Id}; @@ -134,13 +135,25 @@ where let database_key_index = active_query.database_key_index; let mut iteration_count = IterationCount::initial(); let mut fell_back = false; + let zalsa_local = db.zalsa_local(); // Our provisional value from the previous iteration, when doing fixpoint 
iteration. // Initially it's set to None, because the initial provisional value is created lazily, // only when a cycle is actually encountered. let mut opt_last_provisional: Option<&Memo<'db, C>> = None; + let mut last_stale_tracked_ids: Vec<(Identity, Id)> = Vec::new(); + loop { let previous_memo = opt_last_provisional.or(opt_old_memo); + + // Tracked struct ids that existed in the previous revision + // but weren't recreated in the last iteration. It's important that we seed the next + // query with these ids because the query might re-create them as part of the next iteration. + // This is not only important to ensure that the re-created tracked structs have the same ids, + // it's also important to ensure that these tracked structs get removed + // if they aren't recreated when reaching the final iteration. + active_query.seed_tracked_struct_ids(&last_stale_tracked_ids); + let (mut new_value, mut completed_query) = Self::execute_query(db, zalsa, active_query, previous_memo, id); @@ -239,10 +252,9 @@ where ), memo_ingredient_index, )); + last_stale_tracked_ids = completed_query.stale_tracked_structs; - active_query = db - .zalsa_local() - .push_query(database_key_index, iteration_count); + active_query = zalsa_local.push_query(database_key_index, iteration_count); continue; } @@ -280,9 +292,7 @@ where if let Some(old_memo) = opt_old_memo { // If we already executed this query once, then use the tracked-struct ids from the // previous execution as the starting point for the new one. - if let Some(tracked_struct_ids) = old_memo.revisions.tracked_struct_ids() { - active_query.seed_tracked_struct_ids(tracked_struct_ids); - } + active_query.seed_tracked_struct_ids(old_memo.revisions.tracked_struct_ids()); // Copy over all inputs and outputs from a previous iteration. 
// This is necessary to: diff --git a/src/function/memo.rs b/src/function/memo.rs index 4894cc642..9671c83d1 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -326,7 +326,7 @@ where stale_output.remove_stale_output(zalsa, executor); } - for (identity, id) in self.revisions.tracked_struct_ids().into_iter().flatten() { + for (identity, id) in self.revisions.tracked_struct_ids() { let key = DatabaseKeyIndex::new(identity.ingredient_index(), *id); key.remove_stale_output(zalsa, executor); } diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index 7ef998a4b..cccb13fd1 100644 --- a/src/tracked_struct.rs +++ b/src/tracked_struct.rs @@ -255,19 +255,15 @@ pub(crate) struct IdentityMap { impl IdentityMap { /// Seeds the identity map with the IDs from a previous revision. pub(crate) fn seed(&mut self, source: &[(Identity, Id)]) { - self.table.clear(); - self.table - .reserve(source.len(), |entry| entry.identity.hash); - for &(key, id) in source { self.insert_entry(key, id, false); } } // Mark all tracked structs in the map as created by the current query. - pub(crate) fn mark_all_active(&mut self) { - for entry in self.table.iter_mut() { - entry.active = true; + pub(crate) fn mark_all_active(&mut self, items: impl IntoIterator) { + for (key, id) in items { + self.insert_entry(key, id, true); } } @@ -330,7 +326,8 @@ impl IdentityMap { /// The first entry contains the identity and IDs of any tracked structs that were /// created by the current execution of the query, while the second entry contains any /// tracked structs that were created in a previous execution but not the current one. 
- pub(crate) fn drain(&mut self) -> (ThinVec<(Identity, Id)>, Vec) { + #[expect(clippy::type_complexity)] + pub(crate) fn drain(&mut self) -> (ThinVec<(Identity, Id)>, Vec<(Identity, Id)>) { if self.table.is_empty() { return (ThinVec::new(), Vec::new()); } @@ -342,19 +339,14 @@ impl IdentityMap { if entry.active { active.push((entry.identity, entry.id)); } else { - stale.push(DatabaseKeyIndex::new( - entry.identity.ingredient_index(), - entry.id, - )); + stale.push((entry.identity, entry.id)); } } // Removing a stale tracked struct ID shows up in the event logs, so make sure // the order is stable here. stale.sort_unstable_by(|a, b| { - a.ingredient_index() - .cmp(&b.ingredient_index()) - .then(a.key_index().cmp(&b.key_index())) + (a.0.ingredient_index(), a.1).cmp(&(b.0.ingredient_index(), b.1)) }); (active, stale) diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs index ced5e9281..77387f72f 100644 --- a/src/zalsa_local.rs +++ b/src/zalsa_local.rs @@ -668,13 +668,13 @@ impl QueryRevisions { } } - /// Returns a reference to the `IdentityMap` for this query, or `None` if the map is empty. - pub fn tracked_struct_ids(&self) -> Option<&[(Identity, Id)]> { + /// Returns the ids of the tracked structs created when running this query. + pub fn tracked_struct_ids(&self) -> &[(Identity, Id)] { self.extra .0 .as_ref() .map(|extra| &*extra.tracked_struct_ids) - .filter(|tracked_struct_ids| !tracked_struct_ids.is_empty()) + .unwrap_or_default() } /// Returns a mutable reference to the `IdentityMap` for this query, or `None` if the map is empty. 
@@ -1090,7 +1090,6 @@ impl ActiveQueryGuard<'_> { #[cfg(debug_assertions)] assert_eq!(stack.len(), self.push_len); let frame = stack.last_mut().unwrap(); - assert!(frame.tracked_struct_ids().is_empty()); frame.tracked_struct_ids_mut().seed(tracked_struct_ids); }) } @@ -1105,6 +1104,7 @@ impl ActiveQueryGuard<'_> { previous.origin.as_ref(), QueryOriginRef::DerivedUntracked(_) ); + let tracked_ids = previous.tracked_struct_ids(); // SAFETY: We do not access the query stack reentrantly. unsafe { @@ -1112,7 +1112,7 @@ impl ActiveQueryGuard<'_> { #[cfg(debug_assertions)] assert_eq!(stack.len(), self.push_len); let frame = stack.last_mut().unwrap(); - frame.seed_iteration(durability, changed_at, edges, untracked_read); + frame.seed_iteration(durability, changed_at, edges, untracked_read, tracked_ids); }) } } diff --git a/tests/cycle_tracked.rs b/tests/cycle_tracked.rs index b9ef6ed14..27e52934d 100644 --- a/tests/cycle_tracked.rs +++ b/tests/cycle_tracked.rs @@ -1,8 +1,5 @@ #![cfg(feature = "inventory")] -//! Tests for cycles where the cycle head is stored on a tracked struct -//! and that tracked struct is freed in a later revision. - mod common; use crate::common::{EventLoggerDatabase, LogDatabase}; @@ -45,6 +42,7 @@ struct Node<'db> { #[salsa::input(debug)] struct GraphInput { simple: bool, + fixpoint_variant: usize, } #[salsa::tracked(returns(ref))] @@ -125,11 +123,13 @@ fn cycle_recover( CycleRecoveryAction::Iterate } +/// Tests for cycles where the cycle head is stored on a tracked struct +/// and that tracked struct is freed in a later revision. 
#[test] fn main() { let mut db = EventLoggerDatabase::default(); - let input = GraphInput::new(&db, false); + let input = GraphInput::new(&db, false, 0); let graph = create_graph(&db, input); let c = graph.find_node(&db, "c").unwrap(); @@ -192,3 +192,126 @@ fn main() { "WillCheckCancellation", ]"#]]); } + +#[salsa::tracked] +struct IterationNode<'db> { + #[returns(ref)] + name: String, + iteration: usize, +} + +/// A cyclic query that creates more tracked structs in later fixpoint iterations. +/// +/// The output depends on the input's fixpoint_variant: +/// - variant=0: Returns `[base]` (1 struct, no cycle) +/// - variant=1: Through fixpoint iteration, returns `[iter_0, iter_1, iter_2]` (3 structs) +/// - variant=2: Through fixpoint iteration, returns `[iter_0, iter_1]` (2 structs) +/// - variant>2: Through fixpoint iteration, returns `[iter_0, iter_1]` (2 structs, same as variant=2) +/// +/// When variant > 0, the query creates a cycle by calling itself. The fixpoint iteration +/// proceeds as follows: +/// 1. Initial: returns empty vector +/// 2. First iteration: returns `[iter_0]` +/// 3. Second iteration: returns `[iter_0, iter_1]` +/// 4. Third iteration (only for variant=1): returns `[iter_0, iter_1, iter_2]` +/// 5. Further iterations: no change, fixpoint reached +#[salsa::tracked(cycle_fn=cycle_recover_with_structs, cycle_initial=initial_with_structs)] +fn create_tracked_in_cycle<'db>( + db: &'db dyn Database, + input: GraphInput, +) -> Vec> { + // Check if we should create more nodes based on the input. + let variant = input.fixpoint_variant(db); + + if variant == 0 { + // Base case - no cycle, just return a single node. + vec![IterationNode::new(db, "base".to_string(), 0)] + } else { + // Create a cycle by calling ourselves. + let previous = create_tracked_in_cycle(db, input); + + // In later iterations, create additional tracked structs. + if previous.is_empty() { + // First iteration - initial returns empty. 
+ vec![IterationNode::new(db, "iter_0".to_string(), 0)] + } else { + // Limit based on variant: variant=1 allows 3 nodes, variant=2 allows 2 nodes. + let limit = if variant == 1 { 3 } else { 2 }; + + if previous.len() < limit { + // Subsequent iterations - add more nodes. + let mut nodes = previous; + nodes.push(IterationNode::new( + db, + format!("iter_{}", nodes.len()), + nodes.len(), + )); + nodes + } else { + // Reached the limit. + previous + } + } + } +} + +fn initial_with_structs(_db: &dyn Database, _input: GraphInput) -> Vec> { + vec![] +} + +#[allow(clippy::ptr_arg)] +fn cycle_recover_with_structs<'db>( + _db: &'db dyn Database, + _value: &Vec>, + _iteration: u32, + _input: GraphInput, +) -> CycleRecoveryAction>> { + CycleRecoveryAction::Iterate +} + +#[test] +fn test_cycle_with_fixpoint_structs() { + let mut db = EventLoggerDatabase::default(); + + // Create an input that will trigger the cyclic behavior. + let input = GraphInput::new(&db, false, 1); + + // Initial query - this will create structs across multiple iterations. + let nodes = create_tracked_in_cycle(&db, input); + assert_eq!(nodes.len(), 3); + // First iteration: previous is empty [], so we get [iter_0] + // Second iteration: previous is [iter_0], so we get [iter_0, iter_1] + // Third iteration: previous is [iter_0, iter_1], so we get [iter_0, iter_1, iter_2] + assert_eq!(nodes[0].name(&db), "iter_0"); + assert_eq!(nodes[1].name(&db), "iter_1"); + assert_eq!(nodes[2].name(&db), "iter_2"); + + // Clear logs to focus on the change. + db.clear_logs(); + + // Change the input to force re-execution with a different variant. + // This will create 2 tracked structs instead of 3 (one fewer than before). + input.set_fixpoint_variant(&mut db).to(2); + + // Re-query - this should handle the tracked struct changes properly. 
+ let nodes = create_tracked_in_cycle(&db, input); + assert_eq!(nodes.len(), 2); + assert_eq!(nodes[0].name(&db), "iter_0"); + assert_eq!(nodes[1].name(&db), "iter_1"); + + // Check the logs to ensure proper execution and struct management. + // We should see the third struct (iter_2) being discarded. + db.assert_logs(expect![[r#" + [ + "DidSetCancellationFlag", + "WillCheckCancellation", + "WillExecute { database_key: create_tracked_in_cycle(Id(0)) }", + "WillCheckCancellation", + "WillIterateCycle { database_key: create_tracked_in_cycle(Id(0)), iteration_count: IterationCount(1), fell_back: false }", + "WillCheckCancellation", + "WillIterateCycle { database_key: create_tracked_in_cycle(Id(0)), iteration_count: IterationCount(2), fell_back: false }", + "WillCheckCancellation", + "WillDiscardStaleOutput { execute_key: create_tracked_in_cycle(Id(0)), output_key: IterationNode(Id(402)) }", + "DidDiscard { key: IterationNode(Id(402)) }", + ]"#]]); +} From b92180bdcf8731a350a20d6ef7464d8244c4c599 Mon Sep 17 00:00:00 2001 From: Ibraheem Ahmed Date: Tue, 19 Aug 2025 16:40:38 -0400 Subject: [PATCH 37/65] outline cold path of `lookup_ingredient` (#984) --- src/zalsa.rs | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/zalsa.rs b/src/zalsa.rs index 118c890d8..9158fda0a 100644 --- a/src/zalsa.rs +++ b/src/zalsa.rs @@ -255,11 +255,7 @@ impl Zalsa { /// Returns the ingredient at the given index, or panics if it is out-of-bounds. #[inline] pub fn lookup_ingredient(&self, index: IngredientIndex) -> &dyn Ingredient { - let index = index.as_u32() as usize; - self.ingredients_vec - .get(index) - .unwrap_or_else(|| panic!("index `{index}` is uninitialized")) - .as_ref() + self.ingredients_vec[index.as_u32() as usize].as_ref() } /// Returns the ingredient at the given index. @@ -269,10 +265,12 @@ impl Zalsa { /// The index must be in-bounds. 
#[inline] pub unsafe fn lookup_ingredient_unchecked(&self, index: IngredientIndex) -> &dyn Ingredient { - let index = index.as_u32() as usize; - // SAFETY: Guaranteed by caller. - unsafe { self.ingredients_vec.get_unchecked(index).as_ref() } + unsafe { + self.ingredients_vec + .get_unchecked(index.as_u32() as usize) + .as_ref() + } } pub(crate) fn ingredient_index_for_memo( From fb4cb24cea0612c7844d6ef68bd6864f7dab83ef Mon Sep 17 00:00:00 2001 From: Ibraheem Ahmed Date: Tue, 19 Aug 2025 16:48:30 -0400 Subject: [PATCH 38/65] Persistent caching fixes (#982) * reduce `serde` overhead * avoid adding dependencies on interned values where garbage collection is disabled * add compatibility with non-self-describing `serde` formats * reuse edge traversal allocation * correctly deserialize singleton inputs --- src/active_query.rs | 4 +- src/cycle.rs | 1 + src/database.rs | 17 ++++- src/durability.rs | 22 +++++- src/function.rs | 60 +++++++++++---- src/function/memo.rs | 7 +- src/input.rs | 70 +++++++++-------- src/interned.rs | 67 +++++++++-------- src/key.rs | 26 ++++++- src/tracked_struct.rs | 47 ++++++------ src/zalsa.rs | 1 + src/zalsa_local.rs | 21 +++--- tests/persistence.rs | 171 ++++++++++++++++++++++-------------------- 13 files changed, 315 insertions(+), 199 deletions(-) diff --git a/src/active_query.rs b/src/active_query.rs index 00d0f5338..0b2231052 100644 --- a/src/active_query.rs +++ b/src/active_query.rs @@ -208,9 +208,9 @@ impl ActiveQuery { } = self; let origin = if untracked_read { - QueryOrigin::derived_untracked(input_outputs.drain(..)) + QueryOrigin::derived_untracked(input_outputs.drain(..).collect()) } else { - QueryOrigin::derived(input_outputs.drain(..)) + QueryOrigin::derived(input_outputs.drain(..).collect()) }; disambiguator_map.clear(); diff --git a/src/cycle.rs b/src/cycle.rs index 0558bda04..2ba1dcc8f 100644 --- a/src/cycle.rs +++ b/src/cycle.rs @@ -104,6 +104,7 @@ pub struct CycleHead { #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, 
Default)] #[cfg_attr(feature = "persistence", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "persistence", serde(transparent))] pub struct IterationCount(u8); impl IterationCount { diff --git a/src/database.rs b/src/database.rs index cddb14941..0df83b03b 100644 --- a/src/database.rs +++ b/src/database.rs @@ -160,7 +160,7 @@ mod persistence { use std::fmt; - use serde::de::{self, DeserializeSeed}; + use serde::de::{self, DeserializeSeed, SeqAccess}; use serde::ser::SerializeMap; impl dyn Database { @@ -275,6 +275,21 @@ mod persistence { formatter.write_str("struct Database") } + fn visit_seq(self, mut seq: V) -> Result<(), V::Error> + where + V: SeqAccess<'de>, + { + let mut runtime = seq + .next_element()? + .ok_or_else(|| de::Error::invalid_length(0, &self))?; + let () = seq + .next_element_seed(DeserializeIngredients(self.0))? + .ok_or_else(|| de::Error::invalid_length(1, &self))?; + + self.0.runtime_mut().deserialize_from(&mut runtime); + Ok(()) + } + fn visit_map(self, mut map: V) -> Result<(), V::Error> where V: serde::de::MapAccess<'de>, diff --git a/src/durability.rs b/src/durability.rs index a3e33b1bc..4691e9372 100644 --- a/src/durability.rs +++ b/src/durability.rs @@ -17,9 +17,28 @@ /// configuration, the source from library crates, or other things /// that are unlikely to be edited. 
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] -#[cfg_attr(feature = "persistence", derive(serde::Serialize, serde::Deserialize))] pub struct Durability(DurabilityVal); +#[cfg(feature = "persistence")] +impl serde::Serialize for Durability { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serde::Serialize::serialize(&(self.0 as u8), serializer) + } +} + +#[cfg(feature = "persistence")] +impl<'de> serde::Deserialize<'de> for Durability { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + u8::deserialize(deserializer).map(|value| Self(DurabilityVal::from(value))) + } +} + impl std::fmt::Debug for Durability { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { if f.alternate() { @@ -38,7 +57,6 @@ impl std::fmt::Debug for Durability { // We use an enum here instead of a u8 for niches. #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] -#[cfg_attr(feature = "persistence", derive(serde::Serialize, serde::Deserialize))] enum DurabilityVal { Low = 0, Medium = 1, diff --git a/src/function.rs b/src/function.rs index ae8564df5..47fdca40e 100644 --- a/src/function.rs +++ b/src/function.rs @@ -541,7 +541,26 @@ mod persistence { { let Self { ingredient, zalsa } = self; - let mut map = serializer.serialize_map(None)?; + let count = as SalsaStructInDb>::entries(zalsa) + .filter(|entry| { + let memo_ingredient_index = ingredient + .memo_ingredient_indices + .get(entry.ingredient_index()); + + let memo = ingredient.get_memo_from_table_for( + zalsa, + entry.key_index(), + memo_ingredient_index, + ); + + memo.is_some_and(|memo| memo.should_serialize()) + }) + .count(); + + let mut map = serializer.serialize_map(Some(count))?; + + let mut visited_edges = FxHashSet::default(); + let mut flattened_edges = FxIndexSet::default(); for entry in as SalsaStructInDb>::entries(zalsa) { let memo_ingredient_index = ingredient @@ -558,10 +577,24 @@ mod persistence { // Flatten the dependencies 
of this query down to the base inputs. let flattened_origin = match memo.revisions.origin.as_ref() { QueryOriginRef::Derived(edges) => { - QueryOrigin::derived(flatten_edges(zalsa, edges)) + collect_minimum_serialized_edges( + zalsa, + edges, + &mut visited_edges, + &mut flattened_edges, + ); + + QueryOrigin::derived(flattened_edges.drain(..).collect()) } QueryOriginRef::DerivedUntracked(edges) => { - QueryOrigin::derived_untracked(flatten_edges(zalsa, edges)) + collect_minimum_serialized_edges( + zalsa, + edges, + &mut visited_edges, + &mut flattened_edges, + ); + + QueryOrigin::derived_untracked(flattened_edges.drain(..).collect()) } QueryOriginRef::Assigned(key) => { let dependency = zalsa.lookup_ingredient(key.ingredient_index()); @@ -588,6 +621,8 @@ mod persistence { ); map.serialize_entry(&key, &memo)?; + + visited_edges.clear(); } } @@ -596,11 +631,12 @@ mod persistence { } // Flatten the dependency edges before serialization. - fn flatten_edges(zalsa: &Zalsa, edges: &[QueryEdge]) -> FxIndexSet { - let mut visited_edges = FxHashSet::default(); - let mut flattened_edges = - FxIndexSet::with_capacity_and_hasher(edges.len(), Default::default()); - + fn collect_minimum_serialized_edges( + zalsa: &Zalsa, + edges: &[QueryEdge], + visited_edges: &mut FxHashSet, + flattened_edges: &mut FxIndexSet, + ) { for &edge in edges { let dependency = zalsa.lookup_ingredient(edge.key().ingredient_index()); @@ -612,13 +648,11 @@ mod persistence { dependency.collect_minimum_serialized_edges( zalsa, edge, - &mut flattened_edges, - &mut visited_edges, + flattened_edges, + visited_edges, ); } } - - flattened_edges } pub struct DeserializeIngredient<'db, C> @@ -659,7 +693,7 @@ mod persistence { { let DeserializeIngredient { zalsa, ingredient } = self; - while let Some((key, memo)) = access.next_entry::>()? { + while let Some((key, memo)) = access.next_entry::<&str, Memo>()? 
{ let (ingredient_index, id) = key .split_once(':') .ok_or_else(|| de::Error::custom("invalid database key"))?; diff --git a/src/function/memo.rs b/src/function/memo.rs index 9671c83d1..dc346adcf 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -367,9 +367,9 @@ mod persistence { /// A reference to the fields of a [`Memo`], with its [`QueryRevisions`] transformed. pub(crate) struct MappedMemo<'memo, 'db, C: Configuration> { - value: Option<&'memo C::Output<'db>>, - verified_at: AtomicRevision, - revisions: MappedQueryRevisions<'memo>, + pub(crate) value: Option<&'memo C::Output<'db>>, + pub(crate) verified_at: AtomicRevision, + pub(crate) revisions: MappedQueryRevisions<'memo>, } impl<'db, C: Configuration> Memo<'db, C> { @@ -437,6 +437,7 @@ mod persistence { D: serde::Deserializer<'de>, { #[derive(Deserialize)] + #[serde(rename = "Memo")] pub struct DeserializeMemo { #[serde(bound = "C: Configuration")] value: DeserializeValue, diff --git a/src/input.rs b/src/input.rs index bf2d3f363..4c79cd80f 100644 --- a/src/input.rs +++ b/src/input.rs @@ -152,16 +152,16 @@ impl IngredientImpl { durabilities: C::Durabilities, ) -> C::Struct { let id = self.singleton.with_scope(|| { - let (id, _) = zalsa_local.allocate(zalsa, self.ingredient_index, |_| Value:: { - fields, - revisions, - durabilities, - // SAFETY: We only ever access the memos of a value that we allocated through - // our `MemoTableTypes`. - memos: unsafe { MemoTable::new(self.memo_table_types()) }, - }); - - id + zalsa_local + .allocate(zalsa, self.ingredient_index, |_| Value:: { + fields, + revisions, + durabilities, + // SAFETY: We only ever access the memos of a value that we allocated through + // our `MemoTableTypes`. 
+ memos: unsafe { MemoTable::new(self.memo_table_types()) }, + }) + .0 }); FromIdWithDb::from_id(id, zalsa) @@ -440,10 +440,11 @@ where mod persistence { use std::fmt; - use serde::ser::SerializeMap; + use serde::ser::{SerializeMap, SerializeStruct}; use serde::{de, Deserialize}; use super::{Configuration, IngredientImpl, Value}; + use crate::input::singleton::SingletonChoice; use crate::plumbing::Ingredient; use crate::table::memo::MemoTable; use crate::zalsa::Zalsa; @@ -467,7 +468,8 @@ mod persistence { { let Self { zalsa, .. } = self; - let mut map = serializer.serialize_map(None)?; + let count = zalsa.table().slots_of::>().count(); + let mut map = serializer.serialize_map(Some(count))?; for (id, value) in zalsa.table().slots_of::>() { map.serialize_entry(&id.as_bits(), value)?; @@ -485,21 +487,7 @@ mod persistence { where S: serde::Serializer, { - let mut map = serializer.serialize_map(None)?; - - struct SerializeFields<'db, C: Configuration>(&'db C::Fields); - - impl serde::Serialize for SerializeFields<'_, C> - where - C: Configuration, - { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - C::serialize(self.0, serializer) - } - } + let mut value = serializer.serialize_struct("Value", 3)?; let Value { fields, @@ -508,11 +496,25 @@ mod persistence { memos: _, } = self; - map.serialize_entry(&"durabilities", &durabilities)?; - map.serialize_entry(&"revisions", &revisions)?; - map.serialize_entry(&"fields", &SerializeFields::(fields))?; + value.serialize_field("durabilities", &durabilities)?; + value.serialize_field("revisions", &revisions)?; + value.serialize_field("fields", &SerializeFields::(fields))?; - map.end() + value.end() + } + } + + struct SerializeFields<'db, C: Configuration>(&'db C::Fields); + + impl serde::Serialize for SerializeFields<'_, C> + where + C: Configuration, + { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + C::serialize(self.0, serializer) } } @@ -577,13 +579,14 @@ 
mod persistence { // Initialize the slot. // // SAFETY: We have a mutable reference to the database. - let (allocated_id, _) = unsafe { + let allocated_id = ingredient.singleton.with_scope(|| unsafe { zalsa .table() .page(page_idx) .allocate(page_idx, |_| value) .unwrap_or_else(|_| panic!("serialized an invalid `Id`: {id:?}")) - }; + .0 + }); assert_eq!( allocated_id, id, @@ -596,6 +599,7 @@ mod persistence { } #[derive(Deserialize)] + #[serde(rename = "Value")] pub struct DeserializeValue { durabilities: C::Durabilities, revisions: C::Revisions, diff --git a/src/interned.rs b/src/interned.rs index 548f61c31..c7e1dc2a0 100644 --- a/src/interned.rs +++ b/src/interned.rs @@ -802,9 +802,9 @@ where pub fn reset(&mut self, zalsa_mut: &mut Zalsa) { _ = zalsa_mut; - for shard in self.shards.iter() { + for shard in self.shards.iter_mut() { // We can clear the key maps now that we have cancelled all other handles. - shard.lock().key_map.clear(); + shard.get_mut().key_map.clear(); } } @@ -905,16 +905,16 @@ where serialized_edges: &mut FxIndexSet, _visited_edges: &mut FxHashSet, ) { - if C::PERSIST { + if C::PERSIST && C::REVISIONS != IMMORTAL { // If the interned struct is being persisted, it may be reachable through transitive queries. // Additionally, interned struct dependencies are impure in that garbage collection can // invalidate a dependency without a base input necessarily being updated. Thus, we must - // preserve the transitive dependency on the interned struct. + // preserve the transitive dependency on the interned struct, if garbage collection is + // enabled. serialized_edges.insert(edge); } - // Otherwise, the dependency is covered by the base inputs, as the interned struct itself is - // not being persisted. + // Otherwise, the dependency is covered by the base inputs. 
} fn debug_name(&self) -> &'static str { @@ -976,7 +976,7 @@ where ) { f(&persistence::SerializeIngredient { zalsa, - _ingredient: self, + ingredient: self, }) } @@ -1316,7 +1316,7 @@ mod persistence { use std::hash::BuildHasher; use intrusive_collections::LinkedListLink; - use serde::ser::SerializeMap; + use serde::ser::{SerializeMap, SerializeStruct}; use serde::{de, Deserialize}; use super::{Configuration, IngredientImpl, Value, ValueShared}; @@ -1330,7 +1330,7 @@ mod persistence { C: Configuration, { pub zalsa: &'db Zalsa, - pub _ingredient: &'db IngredientImpl, + pub ingredient: &'db IngredientImpl, } impl serde::Serialize for SerializeIngredient<'_, C> @@ -1341,9 +1341,15 @@ mod persistence { where S: serde::Serializer, { - let Self { zalsa, .. } = self; + let Self { zalsa, ingredient } = *self; + + let count = ingredient + .shards + .iter() + .map(|shard| shard.lock().key_map.len()) + .sum(); - let mut map = serializer.serialize_map(None)?; + let mut map = serializer.serialize_map(Some(count))?; for (_, value) in zalsa.table().slots_of::>() { // SAFETY: The safety invariant of `Ingredient::serialize` ensures we have exclusive access @@ -1365,21 +1371,7 @@ mod persistence { where S: serde::Serializer, { - let mut map = serializer.serialize_map(None)?; - - struct SerializeFields<'db, C: Configuration>(&'db C::Fields<'static>); - - impl serde::Serialize for SerializeFields<'_, C> - where - C: Configuration, - { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - C::serialize(self.0, serializer) - } - } + let mut value = serializer.serialize_struct("Value,", 3)?; let Value { fields, @@ -1401,11 +1393,25 @@ mod persistence { id: _, } = unsafe { *shared.get() }; - map.serialize_entry(&"durability", &durability)?; - map.serialize_entry(&"last_interned_at", &last_interned_at)?; - map.serialize_entry(&"fields", &SerializeFields::(fields))?; + value.serialize_field("durability", &durability)?; + 
value.serialize_field("last_interned_at", &last_interned_at)?; + value.serialize_field("fields", &SerializeFields::(fields))?; - map.end() + value.end() + } + } + + struct SerializeFields<'db, C: Configuration>(&'db C::Fields<'static>); + + impl serde::Serialize for SerializeFields<'_, C> + where + C: Configuration, + { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + C::serialize(self.0, serializer) } } @@ -1507,6 +1513,7 @@ mod persistence { } #[derive(Deserialize)] + #[serde(rename = "Value")] pub struct DeserializeValue { durability: Durability, last_interned_at: Revision, diff --git a/src/key.rs b/src/key.rs index fa947575f..82d922565 100644 --- a/src/key.rs +++ b/src/key.rs @@ -10,7 +10,6 @@ use crate::Id; /// ordered and equatable but those orderings are arbitrary, and meant to be used /// only for inserting into maps and the like. #[derive(Copy, Clone, PartialEq, Eq, Hash)] -#[cfg_attr(feature = "persistence", derive(serde::Serialize, serde::Deserialize))] pub struct DatabaseKeyIndex { key_index: Id, ingredient_index: IngredientIndex, @@ -68,6 +67,31 @@ impl DatabaseKeyIndex { } } +#[cfg(feature = "persistence")] +impl serde::Serialize for DatabaseKeyIndex { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serde::Serialize::serialize(&(self.key_index, self.ingredient_index), serializer) + } +} + +#[cfg(feature = "persistence")] +impl<'de> serde::Deserialize<'de> for DatabaseKeyIndex { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let (key_index, ingredient_index) = serde::Deserialize::deserialize(deserializer)?; + + Ok(DatabaseKeyIndex { + key_index, + ingredient_index, + }) + } +} + impl fmt::Debug for DatabaseKeyIndex { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { crate::attach::with_attached_database(|db| { diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index cccb13fd1..66c0cade9 100644 --- a/src/tracked_struct.rs 
+++ b/src/tracked_struct.rs @@ -428,6 +428,7 @@ where #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Copy, Clone)] #[cfg_attr(feature = "persistence", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "persistence", serde(transparent))] pub struct Disambiguator(u32); #[derive(Default, Debug)] @@ -1231,7 +1232,7 @@ mod tests { mod persistence { use std::fmt; - use serde::ser::SerializeMap; + use serde::ser::{SerializeMap, SerializeStruct}; use serde::{de, Deserialize}; use super::{Configuration, IngredientImpl, Value}; @@ -1259,7 +1260,8 @@ mod persistence { { let Self { zalsa, .. } = self; - let mut map = serializer.serialize_map(None)?; + let count = zalsa.table().slots_of::>().count(); + let mut map = serializer.serialize_map(Some(count))?; for (id, value) in zalsa.table().slots_of::>() { map.serialize_entry(&id.as_bits(), value)?; @@ -1277,21 +1279,7 @@ mod persistence { where S: serde::Serializer, { - let mut map = serializer.serialize_map(None)?; - - struct SerializeFields<'db, C: Configuration>(&'db C::Fields<'static>); - - impl serde::Serialize for SerializeFields<'_, C> - where - C: Configuration, - { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - C::serialize(self.0, serializer) - } - } + let mut value = serializer.serialize_struct("Value", 4)?; let Value { durability, @@ -1301,12 +1289,26 @@ mod persistence { memos: _, } = self; - map.serialize_entry(&"durability", &durability)?; - map.serialize_entry(&"updated_at", &updated_at)?; - map.serialize_entry(&"revisions", &revisions)?; - map.serialize_entry(&"fields", &SerializeFields::(fields))?; + value.serialize_field("durability", &durability)?; + value.serialize_field("updated_at", &updated_at)?; + value.serialize_field("revisions", &revisions)?; + value.serialize_field("fields", &SerializeFields::(fields))?; - map.end() + value.end() + } + } + + struct SerializeFields<'db, C: Configuration>(&'db C::Fields<'static>); + + impl serde::Serialize for 
SerializeFields<'_, C> + where + C: Configuration, + { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + C::serialize(self.0, serializer) } } @@ -1391,6 +1393,7 @@ mod persistence { } #[derive(Deserialize)] + #[serde(rename = "Value")] pub struct DeserializeValue { durability: Durability, updated_at: OptionalAtomicRevision, diff --git a/src/zalsa.rs b/src/zalsa.rs index 9158fda0a..9fc139e64 100644 --- a/src/zalsa.rs +++ b/src/zalsa.rs @@ -74,6 +74,7 @@ static NONCE: crate::nonce::NonceGenerator = crate::nonce::NonceGe /// Each ingredient is given a unique index as the database is being created. #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)] #[cfg_attr(feature = "persistence", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "persistence", serde(transparent))] pub struct IngredientIndex(u32); impl IngredientIndex { diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs index 77387f72f..e332b516f 100644 --- a/src/zalsa_local.rs +++ b/src/zalsa_local.rs @@ -485,6 +485,7 @@ impl QueryRevisions { /// in cycles, or create accumulators. #[derive(Debug, Default)] #[cfg_attr(feature = "persistence", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "persistence", serde(transparent))] pub(crate) struct QueryRevisionsExtra(Option>); impl QueryRevisionsExtra { @@ -524,8 +525,7 @@ impl QueryRevisionsExtra { #[cfg_attr(feature = "persistence", derive(serde::Serialize, serde::Deserialize))] struct QueryRevisionsExtraInner { #[cfg(feature = "accumulator")] - // TODO: Support serializing accumulators. - #[cfg_attr(feature = "persistence", serde(skip))] + #[cfg_attr(feature = "persistence", serde(skip))] // TODO: Support serializing accumulators accumulated: AccumulatedMap, /// The ids of tracked structs created by this query. 
@@ -693,6 +693,7 @@ impl QueryRevisions { #[repr(u8)] #[derive(Debug, Clone, Copy)] #[cfg_attr(feature = "persistence", derive(serde::Serialize))] +#[cfg_attr(feature = "persistence", serde(rename = "QueryOrigin"))] pub enum QueryOriginRef<'a> { /// The value was assigned as the output of another query (e.g., using `specify`). /// The `DatabaseKeyIndex` is the identity of the assigning query. @@ -839,9 +840,7 @@ impl QueryOrigin { } /// Create a query origin of type `QueryOriginKind::Derived`, with the given edges. - pub fn derived(input_outputs: impl IntoIterator) -> QueryOrigin { - let input_outputs = input_outputs.into_iter().collect::>(); - + pub fn derived(input_outputs: Box<[QueryEdge]>) -> QueryOrigin { // Exceeding `u32::MAX` query edges should never happen in real-world usage. let length = u32::try_from(input_outputs.len()) .expect("exceeded more than `u32::MAX` query edges; this should never happen."); @@ -858,7 +857,7 @@ impl QueryOrigin { } /// Create a query origin of type `QueryOriginKind::DerivedUntracked`, with the given edges. - pub fn derived_untracked(input_outputs: impl IntoIterator) -> QueryOrigin { + pub fn derived_untracked(input_outputs: Box<[QueryEdge]>) -> QueryOrigin { let mut origin = QueryOrigin::derived(input_outputs); origin.kind = QueryOriginKind::DerivedUntracked; origin @@ -937,8 +936,9 @@ impl<'de> serde::Deserialize<'de> for QueryOrigin { D: serde::Deserializer<'de>, { // Matches the signature of `QueryOriginRef`. 
- #[derive(serde::Deserialize)] #[repr(u8)] + #[derive(serde::Deserialize)] + #[serde(rename = "QueryOrigin")] pub enum QueryOriginOwned { Assigned(DatabaseKeyIndex) = QueryOriginKind::Assigned as u8, Derived(Box<[QueryEdge]>) = QueryOriginKind::Derived as u8, @@ -964,9 +964,9 @@ impl Drop for QueryOrigin { let input_outputs = unsafe { self.data.input_outputs }; let length = self.metadata as usize; - // SAFETY: `input_outputs` and `self.metadata` form a valid slice when the - // tag is `QueryOriginKind::DerivedUntracked` or `QueryOriginKind::DerivedUntracked`, - // and we have `&mut self`. + // SAFETY: `input_outputs` and `self.metadata` form a valid slice when the tag is + // `QueryOriginKind::DerivedUntracked` or `QueryOriginKind::DerivedUntracked`, and + // we have `&mut self`. let _input_outputs: Box<[QueryEdge]> = unsafe { Box::from_raw(ptr::slice_from_raw_parts_mut( input_outputs.as_ptr(), @@ -995,6 +995,7 @@ impl std::fmt::Debug for QueryOrigin { /// `QueryEdgeKind`, which is meaningful as inputs and outputs are stored contiguously. 
#[derive(Copy, Clone, PartialEq, Eq, Hash)] #[cfg_attr(feature = "persistence", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "persistence", serde(transparent))] pub struct QueryEdge { key: DatabaseKeyIndex, } diff --git a/tests/persistence.rs b/tests/persistence.rs index a1f424e3d..3f28dce2f 100644 --- a/tests/persistence.rs +++ b/tests/persistence.rs @@ -12,6 +12,11 @@ struct MyInput { field: usize, } +#[salsa::input(persist, singleton)] +struct MySingleton { + field: usize, +} + #[salsa::interned(persist)] struct MyInterned<'db> { field: String, @@ -60,7 +65,7 @@ fn everything() { "0": { "1": { "durabilities": [ - "Low" + 0 ], "revisions": [ 1 @@ -71,7 +76,7 @@ fn everything() { }, "2": { "durabilities": [ - "Low" + 0 ], "revisions": [ 1 @@ -88,6 +93,7 @@ fn everything() { let input1 = MyInput::new(&db, 1); let input2 = MyInput::new(&db, 2); + let _singleton = MySingleton::new(&db, 1); let _out = unit_to_interned(&db); let _out = input_to_tracked(&db, input1); @@ -109,7 +115,7 @@ fn everything() { "0": { "1": { "durabilities": [ - "Low" + 0 ], "revisions": [ 1 @@ -120,7 +126,7 @@ fn everything() { }, "2": { "durabilities": [ - "Low" + 0 ], "revisions": [ 1 @@ -131,7 +137,7 @@ fn everything() { }, "3": { "durabilities": [ - "Low" + 0 ], "revisions": [ 1 @@ -142,7 +148,7 @@ fn everything() { }, "4": { "durabilities": [ - "Low" + 0 ], "revisions": [ 1 @@ -153,17 +159,30 @@ fn everything() { } }, "2": { - "2049": { - "durability": "High", + "1025": { + "durabilities": [ + 0 + ], + "revisions": [ + 1 + ], + "fields": [ + 1 + ] + } + }, + "4": { + "3073": { + "durability": 2, "last_interned_at": 1, "fields": [ "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" ] } }, - "3": { - "3073": { - "durability": "Low", + "5": { + "4097": { + "durability": 0, "updated_at": 1, "revisions": [], "fields": [ @@ -171,9 +190,9 @@ fn everything() { ] } }, - "5": { - "4097": { - "durability": "High", + "7": { + "5121": { + "durability": 2, "last_interned_at": 
18446744073709551615, "fields": [ 3, @@ -181,34 +200,30 @@ fn everything() { ] } }, - "17": { - "1025": { - "durability": "High", + "19": { + "2049": { + "durability": 2, "last_interned_at": 18446744073709551615, "fields": null } }, - "4": { - "5:4097": { + "6": { + "7:5121": { "value": "aaa", "verified_at": 1, "revisions": { "changed_at": 1, - "durability": "Low", + "durability": 0, "origin": { "Derived": [ - { - "key": { - "key_index": 3, - "ingredient_index": 1 - } - }, - { - "key": { - "key_index": 4, - "ingredient_index": 1 - } - } + [ + 3, + 1 + ], + [ + 4, + 1 + ] ] }, "verified_final": true, @@ -216,21 +231,19 @@ fn everything() { } } }, - "6": { + "8": { "0:3": { - "value": 3073, + "value": 4097, "verified_at": 1, "revisions": { "changed_at": 1, - "durability": "Low", + "durability": 0, "origin": { "Derived": [ - { - "key": { - "key_index": 3, - "ingredient_index": 1 - } - } + [ + 3, + 1 + ] ] }, "verified_final": true, @@ -238,11 +251,11 @@ fn everything() { "tracked_struct_ids": [ [ { - "ingredient_index": 3, + "ingredient_index": 5, "hash": 6073466998405137972, "disambiguator": 0 }, - 3073 + 4097 ] ], "cycle_heads": [], @@ -251,21 +264,19 @@ fn everything() { } } }, - "16": { - "17:1025": { - "value": 2049, + "18": { + "19:2049": { + "value": 3073, "verified_at": 1, "revisions": { "changed_at": 1, - "durability": "High", + "durability": 2, "origin": { "Derived": [ - { - "key": { - "key_index": 2049, - "ingredient_index": 2 - } - } + [ + 3073, + 4 + ] ] }, "verified_final": true, @@ -285,6 +296,8 @@ fn everything() { ) .unwrap(); + assert_eq!(MySingleton::get(&db).field(&db), 1); + let _out = unit_to_interned(&db); let _out = input_to_tracked(&db, input1); let _out = input_pair_to_string(&db, input1, input2); @@ -336,7 +349,7 @@ fn partial_query() { "0": { "1": { "durabilities": [ - "Low" + 0 ], "revisions": [ 1 @@ -346,21 +359,19 @@ fn partial_query() { ] } }, - "11": { + "13": { "0:1": { "value": 1, "verified_at": 1, "revisions": { "changed_at": 1, - 
"durability": "Low", + "durability": 0, "origin": { "Derived": [ - { - "key": { - "key_index": 1, - "ingredient_index": 1 - } - } + [ + 1, + 1 + ] ] }, "verified_final": true, @@ -454,7 +465,7 @@ fn partial_query_interned() { "0": { "1": { "durabilities": [ - "Low" + 0 ], "revisions": [ 1 @@ -464,18 +475,18 @@ fn partial_query_interned() { ] } }, - "2": { + "4": { "3073": { - "durability": "Low", + "durability": 0, "last_interned_at": 1, "fields": [ "0" ] } }, - "15": { + "17": { "1025": { - "durability": "High", + "durability": 2, "last_interned_at": 18446744073709551615, "fields": [ 1, @@ -483,27 +494,23 @@ fn partial_query_interned() { ] } }, - "14": { - "15:1025": { + "16": { + "17:1025": { "value": 3073, "verified_at": 1, "revisions": { "changed_at": 1, - "durability": "Low", + "durability": 0, "origin": { "Derived": [ - { - "key": { - "key_index": 1, - "ingredient_index": 1 - } - }, - { - "key": { - "key_index": 3073, - "ingredient_index": 2 - } - } + [ + 1, + 1 + ], + [ + 3073, + 4 + ] ] }, "verified_final": true, From a3ffa22cb26756473d56f867aedec3fd907c4dd9 Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Wed, 20 Aug 2025 09:15:36 +0200 Subject: [PATCH 39/65] fix: Runaway for unchanged queries participating in cycle (#981) * fix: Runaway for unchanged queries participating in cycle * Another regression test * Fix runaway situation * Discard changes to src/function/fetch.rs * Undo tracing changes * Move accumulated write outside of non-cycle branch * Short circuit if cycle head is executing * Inline * Update expected test output * Fix double execution * Simplify check in `validate_same_iteration` * Some more inline * Pass references --- src/cycle.rs | 23 ++- src/function.rs | 10 +- src/function/fetch.rs | 11 +- src/function/maybe_changed_after.rs | 211 +++++++++++++++++------- src/function/memo.rs | 1 + src/ingredient.rs | 11 +- tests/common/mod.rs | 1 + tests/cycle.rs | 239 ++++++++++++++++++++++++++++ tests/cycle_output.rs | 1 - 9 files changed, 433 
insertions(+), 75 deletions(-) diff --git a/src/cycle.rs b/src/cycle.rs index 2ba1dcc8f..12cb1cdc9 100644 --- a/src/cycle.rs +++ b/src/cycle.rs @@ -53,6 +53,7 @@ use thin_vec::{thin_vec, ThinVec}; use crate::key::DatabaseKeyIndex; use crate::sync::OnceLock; +use crate::Revision; /// The maximum number of times we'll fixpoint-iterate before panicking. /// @@ -237,16 +238,30 @@ pub(crate) fn empty_cycle_heads() -> &'static CycleHeads { #[derive(Debug, PartialEq, Eq)] pub enum ProvisionalStatus { - Provisional { iteration: IterationCount }, - Final { iteration: IterationCount }, + Provisional { + iteration: IterationCount, + verified_at: Revision, + }, + Final { + iteration: IterationCount, + verified_at: Revision, + }, FallbackImmediate, } impl ProvisionalStatus { pub(crate) const fn iteration(&self) -> Option { match self { - ProvisionalStatus::Provisional { iteration } => Some(*iteration), - ProvisionalStatus::Final { iteration } => Some(*iteration), + ProvisionalStatus::Provisional { iteration, .. } => Some(*iteration), + ProvisionalStatus::Final { iteration, .. } => Some(*iteration), + ProvisionalStatus::FallbackImmediate => None, + } + } + + pub(crate) const fn verified_at(&self) -> Option { + match self { + ProvisionalStatus::Provisional { verified_at, .. } => Some(*verified_at), + ProvisionalStatus::Final { verified_at, .. 
} => Some(*verified_at), ProvisionalStatus::FallbackImmediate => None, } } diff --git a/src/function.rs b/src/function.rs index 47fdca40e..58f773895 100644 --- a/src/function.rs +++ b/src/function.rs @@ -345,10 +345,16 @@ where if C::CYCLE_STRATEGY == CycleRecoveryStrategy::FallbackImmediate { ProvisionalStatus::FallbackImmediate } else { - ProvisionalStatus::Final { iteration } + ProvisionalStatus::Final { + iteration, + verified_at: memo.verified_at.load(), + } } } else { - ProvisionalStatus::Provisional { iteration } + ProvisionalStatus::Provisional { + iteration, + verified_at: memo.verified_at.load(), + } }) } diff --git a/src/function/fetch.rs b/src/function/fetch.rs index 87ff22db3..57b79a52a 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -1,3 +1,5 @@ +use rustc_hash::FxHashMap; + use crate::cycle::{CycleHeads, CycleRecoveryStrategy, IterationCount}; use crate::function::maybe_changed_after::VerifyCycleHeads; use crate::function::memo::Memo; @@ -172,7 +174,6 @@ where zalsa_local, database_key_index, old_memo, - true, ) { self.update_shallow(zalsa, database_key_index, old_memo, can_shallow_update); @@ -182,17 +183,19 @@ where return unsafe { Some(self.extend_memo_lifetime(old_memo)) }; } - let mut cycle_heads = VerifyCycleHeads::default(); + let mut cycle_heads = Vec::new(); + let mut participating_queries = FxHashMap::default(); + let verify_result = self.deep_verify_memo( db, zalsa, old_memo, database_key_index, - &mut cycle_heads, + &mut VerifyCycleHeads::new(&mut cycle_heads, &mut participating_queries), can_shallow_update, ); - if verify_result.is_unchanged() && !cycle_heads.has_any() { + if verify_result.is_unchanged() && cycle_heads.is_empty() { // SAFETY: memo is present in memo_map and we have verified that it is // still valid for the current revision. 
return unsafe { Some(self.extend_memo_lifetime(old_memo)) }; diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index 54fce885d..2ab1d18ee 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -1,9 +1,12 @@ +use rustc_hash::FxHashMap; + #[cfg(feature = "accumulator")] use crate::accumulator::accumulated_map::InputAccumulatedValues; use crate::cycle::{CycleRecoveryStrategy, IterationCount, ProvisionalStatus}; use crate::function::memo::Memo; use crate::function::sync::ClaimResult; use crate::function::{Configuration, IngredientImpl}; + use crate::key::DatabaseKeyIndex; use crate::sync::atomic::Ordering; use crate::zalsa::{MemoIngredientIndex, Zalsa, ZalsaDatabase}; @@ -11,7 +14,7 @@ use crate::zalsa_local::{QueryEdgeKind, QueryOriginRef, ZalsaLocal}; use crate::{Id, Revision}; /// Result of memo validation. -#[derive(Debug)] +#[derive(Debug, Copy, Clone)] pub enum VerifyResult { /// Memo has changed and needs to be recomputed. Changed, @@ -110,6 +113,7 @@ where if let Some(mcs) = self.maybe_changed_after_cold( zalsa, + zalsa_local, db, id, revision, @@ -124,9 +128,11 @@ where } #[inline(never)] + #[expect(clippy::too_many_arguments)] fn maybe_changed_after_cold<'db>( &'db self, zalsa: &Zalsa, + zalsa_local: &ZalsaLocal, db: &'db C::DbView, key_index: Id, revision: Revision, @@ -143,7 +149,7 @@ where } ClaimResult::Cycle { .. } => { return Some(self.maybe_changed_after_cold_cycle( - db, + zalsa_local, database_key_index, cycle_heads, )) @@ -163,21 +169,32 @@ where let can_shallow_update = self.shallow_verify_memo(zalsa, database_key_index, old_memo); if can_shallow_update.yes() - && self.validate_may_be_provisional( - zalsa, - db.zalsa_local(), - database_key_index, - old_memo, - // Don't conclude that the query is unchanged if the memo itself is still - // provisional (because all its cycle heads have the same iteration count - // as the cycle head memos in the database). 
- // See https://github.com/salsa-rs/salsa/pull/961 - false, - ) + && self.validate_may_be_provisional(zalsa, zalsa_local, database_key_index, old_memo) { self.update_shallow(zalsa, database_key_index, old_memo, can_shallow_update); - return Some(VerifyResult::unchanged()); + // If `validate_maybe_provisional` returns `true`, but only because all cycle heads are from the same iteration, + // carry over the cycle heads so that the caller verifies them. + if old_memo.may_be_provisional() { + for head in old_memo.cycle_heads() { + cycle_heads.insert_head(head.database_key_index); + } + } + + return Some(if old_memo.revisions.changed_at > revision { + VerifyResult::changed() + } else { + VerifyResult::unchanged_with_accumulated( + #[cfg(feature = "accumulator")] + { + old_memo.revisions.accumulated_inputs.load() + }, + ) + }); + } + + if let Some(cached) = cycle_heads.get_result(database_key_index) { + return Some(*cached); } let deep_verify = self.deep_verify_memo( @@ -210,9 +227,8 @@ where // `in_cycle` tracks if the enclosing query is in a cycle. `deep_verify.cycle_heads` tracks // if **this query** encountered a cycle (which means there's some provisional value somewhere floating around). if old_memo.value.is_some() && !cycle_heads.has_any() { - let active_query = db - .zalsa_local() - .push_query(database_key_index, IterationCount::initial()); + let active_query = + zalsa_local.push_query(database_key_index, IterationCount::initial()); let memo = self.execute(db, active_query, Some(old_memo)); let changed_at = memo.revisions.changed_at; @@ -239,16 +255,16 @@ where #[cold] #[inline(never)] - fn maybe_changed_after_cold_cycle<'db>( - &'db self, - db: &'db C::DbView, + fn maybe_changed_after_cold_cycle( + &self, + zalsa_local: &ZalsaLocal, database_key_index: DatabaseKeyIndex, cycle_heads: &mut VerifyCycleHeads, ) -> VerifyResult { match C::CYCLE_STRATEGY { // SAFETY: We do not access the query stack reentrantly. 
CycleRecoveryStrategy::Panic => unsafe { - db.zalsa_local().with_query_stack_unchecked(|stack| { + zalsa_local.with_query_stack_unchecked(|stack| { panic!( "dependency graph cycle when validating {database_key_index:#?}, \ set cycle_fn/cycle_initial to fixpoint iterate.\n\ @@ -261,8 +277,23 @@ where crate::tracing::debug!( "hit cycle at {database_key_index:?} in `maybe_changed_after`, returning fixpoint initial value", ); - cycle_heads.insert(database_key_index); - VerifyResult::unchanged() + cycle_heads.insert_head(database_key_index); + + // SAFETY: We don't access the query stack reentrantly. + let running = unsafe { + zalsa_local.with_query_stack_unchecked(|stack| { + stack + .iter() + .any(|query| query.database_key_index == database_key_index) + }) + }; + + // If the cycle head is being executed, consider this query as changed. + if running { + VerifyResult::changed() + } else { + VerifyResult::unchanged() + } } } } @@ -328,7 +359,6 @@ where /// * provisional memos that have been successfully marked as verified final, that is, its /// cycle heads have all been finalized. /// * provisional memos that have been created in the same revision and iteration and are part of the same cycle. - /// This check is skipped if `allow_non_finalized` is `false` as the memo itself is still not finalized. It's a provisional value. #[inline] pub(super) fn validate_may_be_provisional( &self, @@ -336,12 +366,10 @@ where zalsa_local: &ZalsaLocal, database_key_index: DatabaseKeyIndex, memo: &Memo<'_, C>, - allow_non_finalized: bool, ) -> bool { !memo.may_be_provisional() || self.validate_provisional(zalsa, database_key_index, memo) - || (allow_non_finalized - && self.validate_same_iteration(zalsa, zalsa_local, database_key_index, memo)) + || self.validate_same_iteration(zalsa, zalsa_local, database_key_index, memo) } /// Check if this memo's cycle heads have all been finalized. 
If so, mark it verified final and @@ -357,6 +385,9 @@ where "{database_key_index:?}: validate_provisional(memo = {memo:#?})", memo = memo.tracing_debug() ); + + let memo_verified_at = memo.verified_at.load(); + for cycle_head in memo.revisions.cycle_heads() { // Test if our cycle heads (with the same revision) are now finalized. let Some(kind) = zalsa @@ -365,15 +396,24 @@ where else { return false; }; + match kind { ProvisionalStatus::Provisional { .. } => return false, - ProvisionalStatus::Final { iteration } => { - // It's important to also account for the revision for the case where: + ProvisionalStatus::Final { + iteration, + verified_at, + } => { + // Only consider the cycle head if it is from the same revision as the memo + if verified_at != memo_verified_at { + return false; + } + + // It's important to also account for the iteration for the case where: // thread 1: `b` -> `a` (but only in the first iteration) // -> `c` -> `b` // thread 2: `a` -> `b` // - // If we don't account for the revision, then `a` (from iteration 0) will be finalized + // If we don't account for the iteration, then `a` (from iteration 0) will be finalized // because its cycle head `b` is now finalized, but `b` never pulled `a` in the last iteration. if iteration != cycle_head.iteration_count { return false; @@ -424,6 +464,15 @@ where return true; } + let verified_at = memo.verified_at.load(); + + // This is an optimization to avoid unnecessary re-execution within the same revision. + // Don't apply it when verifying memos from past revisions. We want them to re-execute + // to verify their cycle heads and all participating queries. + if verified_at != zalsa.current_revision() { + return false; + } + // SAFETY: We do not access the query stack reentrantly. 
unsafe { zalsa_local.with_query_stack_unchecked(|stack| { @@ -477,7 +526,12 @@ where zalsa, cycle_head.database_key_index.key_index(), )?; - provisional_status.iteration() + + if provisional_status.verified_at() == Some(verified_at) { + provisional_status.iteration() + } else { + None + } }) == Some(cycle_head.iteration_count) }) @@ -506,7 +560,7 @@ where old_memo = old_memo.tracing_debug() ); - debug_assert!(!cycle_heads.contains(database_key_index)); + debug_assert!(!cycle_heads.contains_head(database_key_index)); match old_memo.revisions.origin.as_ref() { QueryOriginRef::Derived(edges) => { @@ -534,8 +588,9 @@ where // The `MaybeChangeAfterCycleHeads` is used as an out parameter and it's // the caller's responsibility to pass an empty `heads`, which is what we do here. let mut inner_cycle_heads = VerifyCycleHeads { - heads: std::mem::take(&mut child_cycle_heads), has_outer_cycles: cycle_heads.has_any(), + heads: &mut child_cycle_heads, + participating_queries: cycle_heads.participating_queries, }; let input_result = dependency_index.maybe_changed_after( @@ -545,10 +600,8 @@ where &mut inner_cycle_heads, ); - // Reuse the cycle head allocation. - child_cycle_heads = inner_cycle_heads.heads; // Aggregate the cycle heads into the parent cycle heads - cycle_heads.append(&mut child_cycle_heads); + cycle_heads.append_heads(&mut child_cycle_heads); match input_result { VerifyResult::Changed => return VerifyResult::changed(), @@ -607,23 +660,30 @@ where // from cycle heads. We will handle our own memo (and the rest of our cycle) on a // future iteration; first the outer cycle head needs to verify itself. - cycle_heads.remove(database_key_index); + cycle_heads.remove_head(database_key_index); + + let result = VerifyResult::unchanged_with_accumulated( + #[cfg(feature = "accumulator")] + inputs, + ); + + // This value is only read once the memo is verified. It's therefore safe + // to write a non-final value here. 
+ #[cfg(feature = "accumulator")] + old_memo.revisions.accumulated_inputs.store(inputs); // 1 and 3 if !cycle_heads.has_own() { old_memo.mark_as_verified(zalsa, database_key_index); - #[cfg(feature = "accumulator")] - old_memo.revisions.accumulated_inputs.store(inputs); old_memo .revisions .verified_final .store(true, Ordering::Relaxed); + } else { + cycle_heads.insert_participating_query(database_key_index, result); } - VerifyResult::unchanged_with_accumulated( - #[cfg(feature = "accumulator")] - inputs, - ) + result } QueryOriginRef::Assigned(_) => { @@ -645,7 +705,7 @@ where // are tracked by the outer query. Nothing should have changed assuming that the // fixpoint initial function is deterministic. QueryOriginRef::FixpointInitial => { - cycle_heads.insert(database_key_index); + cycle_heads.insert_head(database_key_index); VerifyResult::unchanged() } QueryOriginRef::DerivedUntracked(_) => { @@ -690,29 +750,58 @@ impl ShallowUpdate { /// aren't included. /// /// [`maybe_changed_after`]: crate::ingredient::Ingredient::maybe_changed_after -#[derive(Debug, Default)] -pub struct VerifyCycleHeads { - heads: Vec, +#[derive(Debug)] +pub struct VerifyCycleHeads<'a> { + /// The cycle heads encountered while verifying this ingredient and its subtree. + heads: &'a mut Vec, + + /// The cached `maybe_changed_after` results for queries that participate in cycles but aren't a cycle head + /// themselves. We need to cache the results here to avoid calling `deep_verify_memo` repeatedly + /// for queries that have cyclic dependencies (b depends on a (iteration 0) and a depends on b(iteration 1)) + /// as well as to avoid a run-away situation if a query is dependet on a lot inside a single cycle. + participating_queries: &'a mut FxHashMap, /// Whether the outer query (e.g. the parent query running `maybe_changed_after`) has encountered /// any cycles to this point. 
has_outer_cycles: bool, } -impl VerifyCycleHeads { +impl<'a> VerifyCycleHeads<'a> { + pub(crate) fn new( + heads: &'a mut Vec, + participating_queries: &'a mut FxHashMap, + ) -> Self { + Self { + heads, + participating_queries, + has_outer_cycles: false, + } + } + + /// Returns `true` if this query or any of its dependencies depend on this cycle. #[inline] - fn contains(&self, key: DatabaseKeyIndex) -> bool { + fn contains_head(&self, key: DatabaseKeyIndex) -> bool { self.heads.contains(&key) } #[inline] - fn insert(&mut self, key: DatabaseKeyIndex) { + fn insert_head(&mut self, key: DatabaseKeyIndex) { if !self.heads.contains(&key) { self.heads.push(key); } } - fn remove(&mut self, key: DatabaseKeyIndex) -> bool { + #[inline] + fn remove_head(&mut self, key: DatabaseKeyIndex) -> bool { + if self.heads.is_empty() { + return false; + } + + self.remove_head_slow(key) + } + + #[cold] + fn remove_head_slow(&mut self, key: DatabaseKeyIndex) -> bool { let found = self.heads.iter().position(|&head| head == key); let Some(found) = found else { return false }; @@ -721,20 +810,30 @@ impl VerifyCycleHeads { } #[inline] - fn append(&mut self, heads: &mut Vec) { + fn append_heads(&mut self, heads: &mut Vec) { if heads.is_empty() { return; } - self.append_slow(heads); + self.append_heads_slow(heads); } - fn append_slow(&mut self, heads: &mut Vec) { - for key in heads.drain(..) { - self.insert(key); + #[cold] + fn append_heads_slow(&mut self, other: &mut Vec) { + for key in other.drain(..) { + self.insert_head(key); } } + fn insert_participating_query(&mut self, key: DatabaseKeyIndex, result: VerifyResult) { + self.participating_queries.insert(key, result); + } + + #[inline] + fn get_result(&self, key: DatabaseKeyIndex) -> Option<&VerifyResult> { + self.participating_queries.get(&key) + } + /// Returns `true` if this query or any of its dependencies has encountered a cycle or /// if the outer query has encountered a cycle. 
pub fn has_any(&self) -> bool { diff --git a/src/function/memo.rs b/src/function/memo.rs index dc346adcf..793f4832a 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -553,6 +553,7 @@ impl<'me> Iterator for TryClaimCycleHeadsIter<'me> { .provisional_status(self.zalsa, head_key_index) .unwrap_or(ProvisionalStatus::Provisional { iteration: IterationCount::initial(), + verified_at: Revision::start(), }); match cycle_head_kind { diff --git a/src/ingredient.rs b/src/ingredient.rs index 6493061b6..3cf36ae61 100644 --- a/src/ingredient.rs +++ b/src/ingredient.rs @@ -1,9 +1,7 @@ use std::any::{Any, TypeId}; use std::fmt; -use crate::cycle::{ - empty_cycle_heads, CycleHeads, CycleRecoveryStrategy, IterationCount, ProvisionalStatus, -}; +use crate::cycle::{empty_cycle_heads, CycleHeads, CycleRecoveryStrategy, ProvisionalStatus}; use crate::database::RawDatabase; use crate::function::{VerifyCycleHeads, VerifyResult}; use crate::hash::{FxHashSet, FxIndexSet}; @@ -76,11 +74,8 @@ pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { /// Is it a provisional value or has it been finalized and in which iteration. /// /// Returns `None` if `input` doesn't exist. - fn provisional_status(&self, zalsa: &Zalsa, input: Id) -> Option { - _ = (zalsa, input); - Some(ProvisionalStatus::Final { - iteration: IterationCount::initial(), - }) + fn provisional_status(&self, _zalsa: &Zalsa, _input: Id) -> Option { + unreachable!("provisional_status should only be called on cycle heads and only functions can be cycle heads"); } /// Returns the cycle heads for this ingredient. diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 3b9e7434f..46f1de86c 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -158,6 +158,7 @@ impl Default for ExecuteValidateLoggerDatabase { move |event| match event.kind { salsa::EventKind::WillExecute { .. } | salsa::EventKind::WillIterateCycle { .. } + | salsa::EventKind::DidValidateInternedValue { .. 
} | salsa::EventKind::DidValidateMemoizedValue { .. } => { logger.push_log(format!("salsa_event({:?})", event.kind)); } diff --git a/tests/cycle.rs b/tests/cycle.rs index 2eb9bac23..d226a9eb7 100644 --- a/tests/cycle.rs +++ b/tests/cycle.rs @@ -1137,3 +1137,242 @@ fn repeat_provisional_query_incremental() { "salsa_event(WillExecute { database_key: min_panic(Id(2)) })", ]"#]]); } + +/// Tests a situation where a query participating in a cycle gets called many times (think thousands of times). +/// +/// We want to avoid calling `deep_verify_memo` for that query over and over again. +/// This isn't an issue for regular queries because a non-cyclic query is guaranteed to be verified +/// after `maybe_changed_after` because: +/// * It can be shallow verified +/// * `deep_verify_memo` returns `unchanged` and it updates the `verified_at` revision. +/// * `deep_verify_memo` returns `changed` and Salsa re-executes the query. The query is verified once `execute` completes. +/// +/// The same guarantee doesn't exist for queries participating in cycles because: +/// +/// * Salsa update `verified_at` because it depends on the cycle head if the query didn't change. +/// * Salsa doesn't run `execute` because some inputs may not have been verified yet (which can lead to all sort of pancis). 
+#[test] +fn repeat_query_participating_in_cycle() { + #[salsa::input] + struct Input { + value: u32, + } + + #[salsa::interned] + struct Interned { + value: u32, + } + + #[salsa::tracked(cycle_fn=cycle_recover, cycle_initial=initial)] + fn head(db: &dyn Db, input: Input) -> u32 { + let a = query_a(db, input); + + a.min(2) + } + + fn initial(_db: &dyn Db, _input: Input) -> u32 { + 0 + } + + fn cycle_recover( + _db: &dyn Db, + _value: &u32, + _count: u32, + _input: Input, + ) -> CycleRecoveryAction { + CycleRecoveryAction::Iterate + } + + #[salsa::tracked] + fn query_a(db: &dyn Db, input: Input) -> u32 { + let _ = query_b(db, input); + + query_hot(db, input) + } + + #[salsa::tracked] + fn query_b(db: &dyn Db, input: Input) -> u32 { + let _ = query_c(db, input); + + query_hot(db, input) + } + + #[salsa::tracked] + fn query_c(db: &dyn Db, input: Input) -> u32 { + let _ = query_d(db, input); + + query_hot(db, input) + } + + #[salsa::tracked] + fn query_d(db: &dyn Db, input: Input) -> u32 { + query_hot(db, input) + } + + #[salsa::tracked] + fn query_hot(db: &dyn Db, input: Input) -> u32 { + let value = head(db, input); + + let _ = Interned::new(db, 2); + + let _ = input.value(db); + + value + 1 + } + + let mut db = ExecuteValidateLoggerDatabase::default(); + + let input = Input::new(&db, 1); + + assert_eq!(head(&db, input), 2); + + db.clear_logs(); + + input.set_value(&mut db).to(10); + + assert_eq!(head(&db, input), 2); + + // The interned value should only be validate once. We otherwise have a + // run-away situation where `deep_verify_memo` of `query_hot` is called over and over again. + // * First: when checking if `head` has changed + // * Second: when checking if `query_a` has changed + // * Third: when checking if `query_b` has changed + // * ... + // Ultimately, this can easily be more expensive than running the cycle head again. 
+ db.assert_logs(expect![[r#" + [ + "salsa_event(DidValidateInternedValue { key: Interned(Id(400)), revision: R2 })", + "salsa_event(WillExecute { database_key: head(Id(0)) })", + "salsa_event(WillExecute { database_key: query_a(Id(0)) })", + "salsa_event(WillExecute { database_key: query_b(Id(0)) })", + "salsa_event(WillExecute { database_key: query_c(Id(0)) })", + "salsa_event(WillExecute { database_key: query_d(Id(0)) })", + "salsa_event(WillExecute { database_key: query_hot(Id(0)) })", + "salsa_event(WillIterateCycle { database_key: head(Id(0)), iteration_count: IterationCount(1), fell_back: false })", + "salsa_event(WillExecute { database_key: query_a(Id(0)) })", + "salsa_event(WillExecute { database_key: query_b(Id(0)) })", + "salsa_event(WillExecute { database_key: query_c(Id(0)) })", + "salsa_event(WillExecute { database_key: query_d(Id(0)) })", + "salsa_event(WillExecute { database_key: query_hot(Id(0)) })", + "salsa_event(WillIterateCycle { database_key: head(Id(0)), iteration_count: IterationCount(2), fell_back: false })", + "salsa_event(WillExecute { database_key: query_a(Id(0)) })", + "salsa_event(WillExecute { database_key: query_b(Id(0)) })", + "salsa_event(WillExecute { database_key: query_c(Id(0)) })", + "salsa_event(WillExecute { database_key: query_d(Id(0)) })", + "salsa_event(WillExecute { database_key: query_hot(Id(0)) })", + ]"#]]); +} + +/// Tests a similar scenario as `repeat_query_participating_in_cycle` with the main difference +/// that `query_hot` is called before calling the next `query_xxx`. 
+#[test] +fn repeat_query_participating_in_cycle2() { + #[salsa::input] + struct Input { + value: u32, + } + + #[salsa::interned] + struct Interned { + value: u32, + } + + #[salsa::tracked(cycle_fn=cycle_recover, cycle_initial=initial)] + fn head(db: &dyn Db, input: Input) -> u32 { + let a = query_a(db, input); + + a.min(2) + } + + fn initial(_db: &dyn Db, _input: Input) -> u32 { + 0 + } + + fn cycle_recover( + _db: &dyn Db, + _value: &u32, + _count: u32, + _input: Input, + ) -> CycleRecoveryAction { + CycleRecoveryAction::Iterate + } + + #[salsa::tracked(cycle_fn=cycle_recover, cycle_initial=initial)] + fn query_a(db: &dyn Db, input: Input) -> u32 { + let _ = query_hot(db, input); + query_b(db, input) + } + + #[salsa::tracked] + fn query_b(db: &dyn Db, input: Input) -> u32 { + let _ = query_hot(db, input); + query_c(db, input) + } + + #[salsa::tracked] + fn query_c(db: &dyn Db, input: Input) -> u32 { + let _ = query_hot(db, input); + query_d(db, input) + } + + #[salsa::tracked] + fn query_d(db: &dyn Db, input: Input) -> u32 { + let _ = query_hot(db, input); + + let value = head(db, input); + let _ = input.value(db); + + value + 1 + } + + #[salsa::tracked] + fn query_hot(db: &dyn Db, input: Input) -> u32 { + let _ = Interned::new(db, 2); + + let _ = head(db, input); + + 1 + } + + let mut db = ExecuteValidateLoggerDatabase::default(); + + let input = Input::new(&db, 1); + + assert_eq!(head(&db, input), 2); + + db.clear_logs(); + + input.set_value(&mut db).to(10); + + assert_eq!(head(&db, input), 2); + + // `DidValidateInternedValue { key: Interned(Id(400)), revision: R2 }` should only be logged + // once per `maybe_changed_after` root-call (e.g. validating `head` shouldn't validate `query_hot` multiple times). + // + // This is important to avoid a run-away situation where a query is called many times within a cycle and + // Salsa would end up recusively validating the hot query over and over again. 
+ db.assert_logs(expect![[r#" + [ + "salsa_event(DidValidateInternedValue { key: Interned(Id(400)), revision: R2 })", + "salsa_event(WillExecute { database_key: head(Id(0)) })", + "salsa_event(DidValidateInternedValue { key: Interned(Id(400)), revision: R2 })", + "salsa_event(WillExecute { database_key: query_a(Id(0)) })", + "salsa_event(DidValidateInternedValue { key: Interned(Id(400)), revision: R2 })", + "salsa_event(WillExecute { database_key: query_hot(Id(0)) })", + "salsa_event(WillExecute { database_key: query_b(Id(0)) })", + "salsa_event(WillExecute { database_key: query_c(Id(0)) })", + "salsa_event(WillExecute { database_key: query_d(Id(0)) })", + "salsa_event(WillIterateCycle { database_key: head(Id(0)), iteration_count: IterationCount(1), fell_back: false })", + "salsa_event(WillExecute { database_key: query_a(Id(0)) })", + "salsa_event(WillExecute { database_key: query_hot(Id(0)) })", + "salsa_event(WillExecute { database_key: query_b(Id(0)) })", + "salsa_event(WillExecute { database_key: query_c(Id(0)) })", + "salsa_event(WillExecute { database_key: query_d(Id(0)) })", + "salsa_event(WillIterateCycle { database_key: head(Id(0)), iteration_count: IterationCount(2), fell_back: false })", + "salsa_event(WillExecute { database_key: query_a(Id(0)) })", + "salsa_event(WillExecute { database_key: query_hot(Id(0)) })", + "salsa_event(WillExecute { database_key: query_b(Id(0)) })", + "salsa_event(WillExecute { database_key: query_c(Id(0)) })", + "salsa_event(WillExecute { database_key: query_d(Id(0)) })", + ]"#]]); +} diff --git a/tests/cycle_output.rs b/tests/cycle_output.rs index c8e17ba6b..27ba304e3 100644 --- a/tests/cycle_output.rs +++ b/tests/cycle_output.rs @@ -184,7 +184,6 @@ fn revalidate_with_change_after_output_read() { "salsa_event(DidValidateMemoizedValue { database_key: read_value(Id(400)) })", "salsa_event(DidValidateInternedValue { key: query_d::interned_arguments(Id(800)), revision: R2 })", "salsa_event(WillExecute { database_key: 
query_b(Id(0)) })", - "salsa_event(DidValidateInternedValue { key: query_d::interned_arguments(Id(800)), revision: R2 })", "salsa_event(WillExecute { database_key: query_a(Id(0)) })", "salsa_event(WillExecute { database_key: query_d(Id(800)) })", "salsa_event(WillDiscardStaleOutput { execute_key: query_a(Id(0)), output_key: Output(Id(401)) })", From a0e7a0660c93136f23bf08b4f1604eee3d1f6b11 Mon Sep 17 00:00:00 2001 From: Ibraheem Ahmed Date: Fri, 22 Aug 2025 17:54:46 -0400 Subject: [PATCH 40/65] refactor `entries` API (#987) --- .../src/setup_input_struct.rs | 2 +- .../src/setup_interned_struct.rs | 2 +- .../salsa-macro-rules/src/setup_tracked_fn.rs | 2 +- .../src/setup_tracked_struct.rs | 2 +- src/input.rs | 41 +++++++++++++++--- src/interned.rs | 43 ++++++++++++++++--- src/tracked_struct.rs | 41 +++++++++++++++--- tests/debug_db_contents.rs | 26 +++++------ tests/interned-structs_self_ref.rs | 2 +- tests/persistence.rs | 18 ++++---- 10 files changed, 133 insertions(+), 46 deletions(-) diff --git a/components/salsa-macro-rules/src/setup_input_struct.rs b/components/salsa-macro-rules/src/setup_input_struct.rs index ce5318208..741f9393e 100644 --- a/components/salsa-macro-rules/src/setup_input_struct.rs +++ b/components/salsa-macro-rules/src/setup_input_struct.rs @@ -216,7 +216,7 @@ macro_rules! setup_input_struct { zalsa: &$zalsa::Zalsa ) -> impl Iterator + '_ { let ingredient_index = zalsa.lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>(); - <$Configuration>::ingredient_(zalsa).entries(zalsa).map(|(key, _)| key) + <$Configuration>::ingredient_(zalsa).entries(zalsa).map(|entry| entry.key()) } #[inline] diff --git a/components/salsa-macro-rules/src/setup_interned_struct.rs b/components/salsa-macro-rules/src/setup_interned_struct.rs index e73737ae8..1d27a33a2 100644 --- a/components/salsa-macro-rules/src/setup_interned_struct.rs +++ b/components/salsa-macro-rules/src/setup_interned_struct.rs @@ -244,7 +244,7 @@ macro_rules! 
setup_interned_struct { zalsa: &$zalsa::Zalsa ) -> impl Iterator + '_ { let ingredient_index = zalsa.lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>(); - <$Configuration>::ingredient(zalsa).entries(zalsa).map(|(key, _)| key) + <$Configuration>::ingredient(zalsa).entries(zalsa).map(|entry| entry.key()) } #[inline] diff --git a/components/salsa-macro-rules/src/setup_tracked_fn.rs b/components/salsa-macro-rules/src/setup_tracked_fn.rs index e252f068f..945021f3a 100644 --- a/components/salsa-macro-rules/src/setup_tracked_fn.rs +++ b/components/salsa-macro-rules/src/setup_tracked_fn.rs @@ -129,7 +129,7 @@ macro_rules! setup_tracked_fn { zalsa: &$zalsa::Zalsa ) -> impl Iterator + '_ { let ingredient_index = zalsa.lookup_jar_by_type::<$fn_name>().successor(0); - <$Configuration>::intern_ingredient(zalsa).entries(zalsa).map(|(key, _)| key) + <$Configuration>::intern_ingredient(zalsa).entries(zalsa).map(|entry| entry.key()) } #[inline] diff --git a/components/salsa-macro-rules/src/setup_tracked_struct.rs b/components/salsa-macro-rules/src/setup_tracked_struct.rs index 191659a07..92dc25974 100644 --- a/components/salsa-macro-rules/src/setup_tracked_struct.rs +++ b/components/salsa-macro-rules/src/setup_tracked_struct.rs @@ -276,7 +276,7 @@ macro_rules! setup_tracked_struct { zalsa: &$zalsa::Zalsa ) -> impl Iterator + '_ { let ingredient_index = zalsa.lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>(); - <$Configuration>::ingredient_(zalsa).entries(zalsa).map(|(key, _)| key) + <$Configuration>::ingredient_(zalsa).entries(zalsa).map(|entry| entry.key()) } #[inline] diff --git a/src/input.rs b/src/input.rs index 4c79cd80f..0ce0f5c87 100644 --- a/src/input.rs +++ b/src/input.rs @@ -238,14 +238,14 @@ impl IngredientImpl { } /// Returns all data corresponding to the input struct. 
- pub fn entries<'db>( - &'db self, - zalsa: &'db Zalsa, - ) -> impl Iterator)> + 'db { + pub fn entries<'db>(&'db self, zalsa: &'db Zalsa) -> impl Iterator> { zalsa .table() .slots_of::>() - .map(|(id, value)| (self.database_key_index(id), value)) + .map(|(id, value)| StructEntry { + value, + key: self.database_key_index(id), + }) } /// Peek at the field values without recording any read dependency. @@ -257,6 +257,35 @@ impl IngredientImpl { } } +/// An input struct entry. +pub struct StructEntry<'db, C> +where + C: Configuration, +{ + value: &'db Value, + key: DatabaseKeyIndex, +} + +impl<'db, C> StructEntry<'db, C> +where + C: Configuration, +{ + /// Returns the `DatabaseKeyIndex` for this entry. + pub fn key(&self) -> DatabaseKeyIndex { + self.key + } + + /// Returns the input struct. + pub fn as_struct(&self) -> C::Struct { + FromId::from_id(self.key.key_index()) + } + + #[cfg(feature = "salsa_unstable")] + pub fn value(&self) -> &'db Value { + self.value + } +} + impl Ingredient for IngredientImpl { fn location(&self) -> &'static crate::ingredient::Location { &C::LOCATION @@ -312,7 +341,7 @@ impl Ingredient for IngredientImpl { .entries(db.zalsa()) // SAFETY: The memo table belongs to a value that we allocated, so it // has the correct type. - .map(|(_, value)| unsafe { value.memory_usage(&self.memo_table_types) }) + .map(|entry| unsafe { entry.value.memory_usage(&self.memo_table_types) }) .collect(); Some(memory_usage) diff --git a/src/interned.rs b/src/interned.rs index c7e1dc2a0..77dda42b7 100644 --- a/src/interned.rs +++ b/src/interned.rs @@ -809,10 +809,7 @@ where } /// Returns all data corresponding to the interned struct. 
- pub fn entries<'db>( - &'db self, - zalsa: &'db Zalsa, - ) -> impl Iterator)> + 'db { + pub fn entries<'db>(&'db self, zalsa: &'db Zalsa) -> impl Iterator> { // SAFETY: `should_lock` is `true` unsafe { self.entries_inner(true, zalsa) } } @@ -827,7 +824,7 @@ where &'db self, should_lock: bool, zalsa: &'db Zalsa, - ) -> impl Iterator)> + 'db { + ) -> impl Iterator> { // TODO: Grab all locks eagerly. zalsa.table().slots_of::>().map(move |(_, value)| { if should_lock { @@ -840,11 +837,43 @@ where // Note that this ID includes the generation, unlike the ID provided by the table. let id = unsafe { (*value.shared.get()).id }; - (self.database_key_index(id), value) + StructEntry { + value, + key: self.database_key_index(id), + } }) } } +/// An interned struct entry. +pub struct StructEntry<'db, C> +where + C: Configuration, +{ + value: &'db Value, + key: DatabaseKeyIndex, +} + +impl<'db, C> StructEntry<'db, C> +where + C: Configuration, +{ + /// Returns the `DatabaseKeyIndex` for this entry. + pub fn key(&self) -> DatabaseKeyIndex { + self.key + } + + /// Returns the interned struct. + pub fn as_struct(&self) -> C::Struct<'_> { + FromId::from_id(self.key.key_index()) + } + + #[cfg(feature = "salsa_unstable")] + pub fn value(&self) -> &'db Value { + self.value + } +} + impl Ingredient for IngredientImpl where C: Configuration, @@ -949,7 +978,7 @@ where let memory_usage = entries // SAFETY: The memo table belongs to a value that we allocated, so it // has the correct type. Additionally, we are holding the locks for all shards. - .map(|(_, value)| unsafe { value.memory_usage(&self.memo_table_types) }) + .map(|entry| unsafe { entry.value.memory_usage(&self.memo_table_types) }) .collect(); for shard in self.shards.iter() { diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index 66c0cade9..78ebef5d9 100644 --- a/src/tracked_struct.rs +++ b/src/tracked_struct.rs @@ -901,14 +901,43 @@ where } /// Returns all data corresponding to the tracked struct. 
- pub fn entries<'db>( - &'db self, - zalsa: &'db Zalsa, - ) -> impl Iterator)> + 'db { + pub fn entries<'db>(&'db self, zalsa: &'db Zalsa) -> impl Iterator> { zalsa .table() .slots_of::>() - .map(|(id, value)| (self.database_key_index(id), value)) + .map(|(id, value)| StructEntry { + value, + key: self.database_key_index(id), + }) + } +} + +/// A tracked struct entry. +pub struct StructEntry<'db, C> +where + C: Configuration, +{ + value: &'db Value, + key: DatabaseKeyIndex, +} + +impl<'db, C> StructEntry<'db, C> +where + C: Configuration, +{ + /// Returns the `DatabaseKeyIndex` for this entry. + pub fn key(&self) -> DatabaseKeyIndex { + self.key + } + + /// Returns the tracked struct. + pub fn as_struct(&self) -> C::Struct<'_> { + FromId::from_id(self.key.key_index()) + } + + #[cfg(feature = "salsa_unstable")] + pub fn value(&self) -> &'db Value { + self.value } } @@ -990,7 +1019,7 @@ where .entries(db.zalsa()) // SAFETY: The memo table belongs to a value that we allocated, so it // has the correct type. 
- .map(|(_, value)| unsafe { value.memory_usage(&self.memo_table_types) }) + .map(|entry| unsafe { entry.value.memory_usage(&self.memo_table_types) }) .collect(); Some(memory_usage) diff --git a/tests/debug_db_contents.rs b/tests/debug_db_contents.rs index 4160a3e32..aaad544a6 100644 --- a/tests/debug_db_contents.rs +++ b/tests/debug_db_contents.rs @@ -25,38 +25,40 @@ fn execute() { use salsa::plumbing::ZalsaDatabase; let db = salsa::DatabaseImpl::new(); - let _ = InternedStruct::new(&db, "Salsa".to_string()); - let _ = InternedStruct::new(&db, "Salsa2".to_string()); + let interned1 = InternedStruct::new(&db, "Salsa".to_string()); + let interned2 = InternedStruct::new(&db, "Salsa2".to_string()); // test interned structs let interned = InternedStruct::ingredient(db.zalsa()) .entries(db.zalsa()) - .map(|(_, value)| value) .collect::>(); assert_eq!(interned.len(), 2); - assert_eq!(interned[0].fields().0, "Salsa"); - assert_eq!(interned[1].fields().0, "Salsa2"); + assert_eq!(interned[0].as_struct(), interned1); + assert_eq!(interned[1].as_struct(), interned2); + assert_eq!(interned[0].value().fields().0, "Salsa"); + assert_eq!(interned[1].value().fields().0, "Salsa2"); // test input structs - let input = InputStruct::new(&db, 22); + let input1 = InputStruct::new(&db, 22); let inputs = InputStruct::ingredient(&db) .entries(db.zalsa()) - .map(|(_, value)| value) .collect::>(); assert_eq!(inputs.len(), 1); - assert_eq!(inputs[0].fields().0, 22); + assert_eq!(inputs[0].as_struct(), input1); + assert_eq!(inputs[0].value().fields().0, 22); // test tracked structs - let computed = tracked_fn(&db, input).field(&db); - assert_eq!(computed, 44); + let tracked1 = tracked_fn(&db, input1); + assert_eq!(tracked1.field(&db), 44); + let tracked = TrackedStruct::ingredient(&db) .entries(db.zalsa()) - .map(|(_, value)| value) .collect::>(); assert_eq!(tracked.len(), 1); - assert_eq!(tracked[0].fields().0, computed); + assert_eq!(tracked[0].as_struct(), tracked1); + 
assert_eq!(tracked[0].value().fields().0, tracked1.field(&db)); } diff --git a/tests/interned-structs_self_ref.rs b/tests/interned-structs_self_ref.rs index 01ff914c0..4fa34b5c5 100644 --- a/tests/interned-structs_self_ref.rs +++ b/tests/interned-structs_self_ref.rs @@ -152,7 +152,7 @@ const _: () = { zalsa.lookup_jar_by_type::>(); ::ingredient(zalsa) .entries(zalsa) - .map(|(key, _)| key) + .map(|entry| entry.key()) } #[inline] diff --git a/tests/persistence.rs b/tests/persistence.rs index 3f28dce2f..b476bd192 100644 --- a/tests/persistence.rs +++ b/tests/persistence.rs @@ -314,7 +314,7 @@ fn everything() { #[test] fn partial_query() { - use salsa::plumbing::{FromId, ZalsaDatabase}; + use salsa::plumbing::ZalsaDatabase; #[salsa::tracked(persist)] fn query<'db>(db: &'db dyn salsa::Database, input: MyInput) -> usize { @@ -390,12 +390,11 @@ fn partial_query() { ) .unwrap(); - // TODO: Expose a better way of recreating inputs after deserialization. - let (id, _) = MyInput::ingredient(&db) + let input = MyInput::ingredient(&db) .entries(db.zalsa()) .next() - .expect("`MyInput` was persisted"); - let input = MyInput::from_id(id.key_index()); + .unwrap() + .as_struct(); let result = query(&db, input); assert_eq!(result, 1); @@ -425,7 +424,7 @@ fn partial_query() { #[test] fn partial_query_interned() { - use salsa::plumbing::{AsId, FromId, ZalsaDatabase}; + use salsa::plumbing::{AsId, ZalsaDatabase}; #[salsa::tracked(persist)] fn intern<'db>(db: &'db dyn salsa::Database, input: MyInput, value: usize) -> MyInterned<'db> { @@ -529,12 +528,11 @@ fn partial_query_interned() { ) .unwrap(); - // TODO: Expose a better way of recreating inputs after deserialization. - let (id, _) = MyInput::ingredient(&db) + let input = MyInput::ingredient(&db) .entries(db.zalsa()) .next() - .expect("`MyInput` was persisted"); - let input = MyInput::from_id(id.key_index()); + .unwrap() + .as_struct(); // Re-intern `i0`. 
let i0 = intern(&db, input, 0); From 3713cd7eb30821c0c086591832dd6f59f2af7fe7 Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Wed, 10 Sep 2025 12:18:30 -0700 Subject: [PATCH 41/65] Allow fallback to take longer than one iteration to converge (#991) --- book/src/cycles.md | 8 +++++--- src/event.rs | 1 - src/function/execute.rs | 16 ---------------- tests/cycle.rs | 16 ++++++++-------- tests/cycle_output.rs | 6 +++--- tests/cycle_recovery_call_back_into_cycle.rs | 1 - tests/cycle_tracked.rs | 6 +++--- tests/cycle_tracked_own_input.rs | 2 +- 8 files changed, 20 insertions(+), 36 deletions(-) diff --git a/book/src/cycles.md b/book/src/cycles.md index 8222d9eaf..2215b8ff3 100644 --- a/book/src/cycles.md +++ b/book/src/cycles.md @@ -21,11 +21,13 @@ fn initial(_db: &dyn KnobsDatabase) -> u32 { } ``` -If `query` becomes the head of a cycle (that is, `query` is executing and on the active query stack, it calls `query2`, `query2` calls `query3`, and `query3` calls `query` again -- there could be any number of queries involved in the cycle), the `initial_fn` will be called to generate an "initial" value for `query` in the fixed-point computation. (The initial value should usually be the "bottom" value in the partial order.) All queries in the cycle will compute a provisional result based on this initial value for the cycle head. That is, `query3` will compute a provisional result using the initial value for `query`, `query2` will compute a provisional result using this provisional value for `query3`. When `cycle2` returns its provisional result back to `cycle`, `cycle` will observe that it has received a provisional result from its own cycle, and will call the `cycle_fn` (with the current value and the number of iterations that have occurred so far). 
The `cycle_fn` can return `salsa::CycleRecoveryAction::Iterate` to indicate that the cycle should iterate again, or `salsa::CycleRecoveryAction::Fallback(value)` to indicate that the cycle should stop iterating and fall back to the value provided. +If `query` becomes the head of a cycle (that is, `query` is executing and on the active query stack, it calls `query2`, `query2` calls `query3`, and `query3` calls `query` again -- there could be any number of queries involved in the cycle), the `initial_fn` will be called to generate an "initial" value for `query` in the fixed-point computation. (The initial value should usually be the "bottom" value in the partial order.) All queries in the cycle will compute a provisional result based on this initial value for the cycle head. That is, `query3` will compute a provisional result using the initial value for `query`, `query2` will compute a provisional result using this provisional value for `query3`. When `cycle2` returns its provisional result back to `cycle`, `cycle` will observe that it has received a provisional result from its own cycle, and will call the `cycle_fn` (with the current value and the number of iterations that have occurred so far). The `cycle_fn` can return `salsa::CycleRecoveryAction::Iterate` to indicate that the cycle should iterate again, or `salsa::CycleRecoveryAction::Fallback(value)` to indicate that fixpoint iteration should resume starting with the given value (which should be a value that will converge quickly). -If the `cycle_fn` continues to return `Iterate`, the cycle will iterate until it converges: that is, until two successive iterations produce the same result. +The cycle will iterate until it converges: that is, until two successive iterations produce the same result. -If the `cycle_fn` returns `Fallback`, the cycle will iterate one last time and verify that the returned value is the same as the fallback value; that is, the fallback value results in a stable converged cycle. 
If not, Salsa will panic. It is not permitted to use a fallback value that does not converge, because this would leave the cycle in an unpredictable state, depending on the order of query execution. +If the `cycle_fn` returns `Fallback`, the cycle will still continue to iterate (using the given value as a new starting point), in order to verify that the fallback value results in a stable converged cycle. It is not permitted to use a fallback value that does not converge, because this would leave the cycle in an unpredictable state, depending on the order of query execution. + +If a cycle iterates more than 200 times, Salsa will panic rather than iterate forever. ## All potential cycle heads must set `cycle_fn` and `cycle_initial` diff --git a/src/event.rs b/src/event.rs index fbe8a784e..310c565f0 100644 --- a/src/event.rs +++ b/src/event.rs @@ -63,7 +63,6 @@ pub enum EventKind { /// The database-key for the cycle head. Implements `Debug`. database_key: DatabaseKeyIndex, iteration_count: IterationCount, - fell_back: bool, }, /// Indicates that `unwind_if_cancelled` was called and salsa will check if diff --git a/src/function/execute.rs b/src/function/execute.rs index d1651859e..723964851 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -134,7 +134,6 @@ where ) -> (C::Output<'db>, CompletedQuery) { let database_key_index = active_query.database_key_index; let mut iteration_count = IterationCount::initial(); - let mut fell_back = false; let zalsa_local = db.zalsa_local(); // Our provisional value from the previous iteration, when doing fixpoint iteration. @@ -192,16 +191,6 @@ where // If the new result is equal to the last provisional result, the cycle has // converged and we are done. if !C::values_equal(&new_value, last_provisional_value) { - if fell_back { - // We fell back to a value last iteration, but the fallback didn't result - // in convergence. 
We only have bad options here: continue iterating - // (ignoring the request to fall back), or forcibly use the fallback and - // leave the cycle in an inconsistent state (we'll be using a value for - // this query that it doesn't evaluate to, given its inputs). Maybe we'll - // have to go with the latter, but for now let's panic and see if real use - // cases need non-converging fallbacks. - panic!("{database_key_index:?}: execute: fallback did not converge"); - } // We are in a cycle that hasn't converged; ask the user's // cycle-recovery function what to do: match C::recover_from_cycle( @@ -216,10 +205,6 @@ where "{database_key_index:?}: execute: user cycle_fn says to fall back" ); new_value = fallback_value; - // We have to insert the fallback value for this query and then iterate - // one more time to fill in correct values for everything else in the - // cycle based on it; then we'll re-insert it as final value. - fell_back = true; } } // `iteration_count` can't overflow as we check it against `MAX_ITERATIONS` @@ -231,7 +216,6 @@ where Event::new(EventKind::WillIterateCycle { database_key: database_key_index, iteration_count, - fell_back, }) }); cycle_heads.update_iteration_count(database_key_index, iteration_count); diff --git a/tests/cycle.rs b/tests/cycle.rs index d226a9eb7..7a7e26a07 100644 --- a/tests/cycle.rs +++ b/tests/cycle.rs @@ -436,7 +436,7 @@ fn two_fallback_count() { /// /// Two-query cycle, falls back but fallback does not converge. 
#[test] -#[should_panic(expected = "fallback did not converge")] +#[should_panic(expected = "too many cycle iterations")] fn two_fallback_diverge() { let mut db = DbImpl::new(); let a_in = Inputs::new(&db, vec![]); @@ -1062,7 +1062,7 @@ fn cycle_sibling_interference() { "salsa_event(WillExecute { database_key: min_iterate(Id(0)) })", "salsa_event(WillExecute { database_key: min_iterate(Id(1)) })", "salsa_event(WillExecute { database_key: min_panic(Id(3)) })", - "salsa_event(WillIterateCycle { database_key: min_iterate(Id(0)), iteration_count: IterationCount(1), fell_back: false })", + "salsa_event(WillIterateCycle { database_key: min_iterate(Id(0)), iteration_count: IterationCount(1) })", "salsa_event(WillExecute { database_key: min_iterate(Id(1)) })", ]"#]]); } @@ -1093,7 +1093,7 @@ fn repeat_provisional_query() { "salsa_event(WillExecute { database_key: min_iterate(Id(0)) })", "salsa_event(WillExecute { database_key: min_panic(Id(1)) })", "salsa_event(WillExecute { database_key: min_panic(Id(2)) })", - "salsa_event(WillIterateCycle { database_key: min_iterate(Id(0)), iteration_count: IterationCount(1), fell_back: false })", + "salsa_event(WillIterateCycle { database_key: min_iterate(Id(0)), iteration_count: IterationCount(1) })", "salsa_event(WillExecute { database_key: min_panic(Id(1)) })", "salsa_event(WillExecute { database_key: min_panic(Id(2)) })", ]"#]]); @@ -1132,7 +1132,7 @@ fn repeat_provisional_query_incremental() { "salsa_event(WillExecute { database_key: min_panic(Id(2)) })", "salsa_event(WillExecute { database_key: min_panic(Id(1)) })", "salsa_event(WillExecute { database_key: min_iterate(Id(0)) })", - "salsa_event(WillIterateCycle { database_key: min_iterate(Id(0)), iteration_count: IterationCount(1), fell_back: false })", + "salsa_event(WillIterateCycle { database_key: min_iterate(Id(0)), iteration_count: IterationCount(1) })", "salsa_event(WillExecute { database_key: min_panic(Id(1)) })", "salsa_event(WillExecute { database_key: min_panic(Id(2)) 
})", ]"#]]); @@ -1248,13 +1248,13 @@ fn repeat_query_participating_in_cycle() { "salsa_event(WillExecute { database_key: query_c(Id(0)) })", "salsa_event(WillExecute { database_key: query_d(Id(0)) })", "salsa_event(WillExecute { database_key: query_hot(Id(0)) })", - "salsa_event(WillIterateCycle { database_key: head(Id(0)), iteration_count: IterationCount(1), fell_back: false })", + "salsa_event(WillIterateCycle { database_key: head(Id(0)), iteration_count: IterationCount(1) })", "salsa_event(WillExecute { database_key: query_a(Id(0)) })", "salsa_event(WillExecute { database_key: query_b(Id(0)) })", "salsa_event(WillExecute { database_key: query_c(Id(0)) })", "salsa_event(WillExecute { database_key: query_d(Id(0)) })", "salsa_event(WillExecute { database_key: query_hot(Id(0)) })", - "salsa_event(WillIterateCycle { database_key: head(Id(0)), iteration_count: IterationCount(2), fell_back: false })", + "salsa_event(WillIterateCycle { database_key: head(Id(0)), iteration_count: IterationCount(2) })", "salsa_event(WillExecute { database_key: query_a(Id(0)) })", "salsa_event(WillExecute { database_key: query_b(Id(0)) })", "salsa_event(WillExecute { database_key: query_c(Id(0)) })", @@ -1362,13 +1362,13 @@ fn repeat_query_participating_in_cycle2() { "salsa_event(WillExecute { database_key: query_b(Id(0)) })", "salsa_event(WillExecute { database_key: query_c(Id(0)) })", "salsa_event(WillExecute { database_key: query_d(Id(0)) })", - "salsa_event(WillIterateCycle { database_key: head(Id(0)), iteration_count: IterationCount(1), fell_back: false })", + "salsa_event(WillIterateCycle { database_key: head(Id(0)), iteration_count: IterationCount(1) })", "salsa_event(WillExecute { database_key: query_a(Id(0)) })", "salsa_event(WillExecute { database_key: query_hot(Id(0)) })", "salsa_event(WillExecute { database_key: query_b(Id(0)) })", "salsa_event(WillExecute { database_key: query_c(Id(0)) })", "salsa_event(WillExecute { database_key: query_d(Id(0)) })", - 
"salsa_event(WillIterateCycle { database_key: head(Id(0)), iteration_count: IterationCount(2), fell_back: false })", + "salsa_event(WillIterateCycle { database_key: head(Id(0)), iteration_count: IterationCount(2) })", "salsa_event(WillExecute { database_key: query_a(Id(0)) })", "salsa_event(WillExecute { database_key: query_hot(Id(0)) })", "salsa_event(WillExecute { database_key: query_b(Id(0)) })", diff --git a/tests/cycle_output.rs b/tests/cycle_output.rs index 27ba304e3..59b789aa4 100644 --- a/tests/cycle_output.rs +++ b/tests/cycle_output.rs @@ -195,13 +195,13 @@ fn revalidate_with_change_after_output_read() { "salsa_event(WillDiscardStaleOutput { execute_key: query_a(Id(0)), output_key: Output(Id(403)) })", "salsa_event(DidDiscard { key: Output(Id(403)) })", "salsa_event(DidDiscard { key: read_value(Id(403)) })", - "salsa_event(WillIterateCycle { database_key: query_b(Id(0)), iteration_count: IterationCount(1), fell_back: false })", + "salsa_event(WillIterateCycle { database_key: query_b(Id(0)), iteration_count: IterationCount(1) })", "salsa_event(WillExecute { database_key: query_a(Id(0)) })", "salsa_event(WillExecute { database_key: read_value(Id(401g1)) })", - "salsa_event(WillIterateCycle { database_key: query_b(Id(0)), iteration_count: IterationCount(2), fell_back: false })", + "salsa_event(WillIterateCycle { database_key: query_b(Id(0)), iteration_count: IterationCount(2) })", "salsa_event(WillExecute { database_key: query_a(Id(0)) })", "salsa_event(WillExecute { database_key: read_value(Id(402g1)) })", - "salsa_event(WillIterateCycle { database_key: query_b(Id(0)), iteration_count: IterationCount(3), fell_back: false })", + "salsa_event(WillIterateCycle { database_key: query_b(Id(0)), iteration_count: IterationCount(3) })", "salsa_event(WillExecute { database_key: query_a(Id(0)) })", "salsa_event(WillExecute { database_key: read_value(Id(403g1)) })", ]"#]]); diff --git a/tests/cycle_recovery_call_back_into_cycle.rs 
b/tests/cycle_recovery_call_back_into_cycle.rs index af7c10219..805a2be7b 100644 --- a/tests/cycle_recovery_call_back_into_cycle.rs +++ b/tests/cycle_recovery_call_back_into_cycle.rs @@ -37,7 +37,6 @@ fn converges() { } #[test] -#[should_panic(expected = "fallback did not converge")] fn diverges() { let db = DatabaseWithValue::new(3); diff --git a/tests/cycle_tracked.rs b/tests/cycle_tracked.rs index 27e52934d..154ba3370 100644 --- a/tests/cycle_tracked.rs +++ b/tests/cycle_tracked.rs @@ -165,7 +165,7 @@ fn main() { "WillExecute { database_key: cost_to_start(Id(401)) }", "WillCheckCancellation", "WillCheckCancellation", - "WillIterateCycle { database_key: cost_to_start(Id(403)), iteration_count: IterationCount(1), fell_back: false }", + "WillIterateCycle { database_key: cost_to_start(Id(403)), iteration_count: IterationCount(1) }", "WillCheckCancellation", "WillCheckCancellation", "WillCheckCancellation", @@ -307,9 +307,9 @@ fn test_cycle_with_fixpoint_structs() { "WillCheckCancellation", "WillExecute { database_key: create_tracked_in_cycle(Id(0)) }", "WillCheckCancellation", - "WillIterateCycle { database_key: create_tracked_in_cycle(Id(0)), iteration_count: IterationCount(1), fell_back: false }", + "WillIterateCycle { database_key: create_tracked_in_cycle(Id(0)), iteration_count: IterationCount(1) }", "WillCheckCancellation", - "WillIterateCycle { database_key: create_tracked_in_cycle(Id(0)), iteration_count: IterationCount(2), fell_back: false }", + "WillIterateCycle { database_key: create_tracked_in_cycle(Id(0)), iteration_count: IterationCount(2) }", "WillCheckCancellation", "WillDiscardStaleOutput { execute_key: create_tracked_in_cycle(Id(0)), output_key: IterationNode(Id(402)) }", "DidDiscard { key: IterationNode(Id(402)) }", diff --git a/tests/cycle_tracked_own_input.rs b/tests/cycle_tracked_own_input.rs index 17e8b815e..38218f1a7 100644 --- a/tests/cycle_tracked_own_input.rs +++ b/tests/cycle_tracked_own_input.rs @@ -117,7 +117,7 @@ fn main() { 
"WillExecute { database_key: infer_type_param(Id(400)) }", "WillCheckCancellation", "DidInternValue { key: Class(Id(c00)), revision: R2 }", - "WillIterateCycle { database_key: infer_class(Id(0)), iteration_count: IterationCount(1), fell_back: false }", + "WillIterateCycle { database_key: infer_class(Id(0)), iteration_count: IterationCount(1) }", "WillCheckCancellation", "WillExecute { database_key: infer_type_param(Id(400)) }", "WillCheckCancellation", From e257df12eabd566825ba53bb12d782560b9a4dcd Mon Sep 17 00:00:00 2001 From: Chayim Refael Friedman Date: Thu, 18 Sep 2025 11:57:22 +0300 Subject: [PATCH 42/65] Provide a method to attach a database even if it's different from the current attached one (#992) * Provide a method to attach a database even if it's different from the current attached one rust-analyzer needs this. * Fix Clippy on beta * Update compile fail test outputs --------- Co-authored-by: Lukas Wirth --- src/attach.rs | 49 +++++++++++++++++++ src/lib.rs | 2 +- src/zalsa.rs | 3 ++ .../incomplete_persistence.stderr | 8 +-- 4 files changed, 57 insertions(+), 5 deletions(-) diff --git a/src/attach.rs b/src/attach.rs index 671933b50..973da8959 100644 --- a/src/attach.rs +++ b/src/attach.rs @@ -79,6 +79,38 @@ impl Attached { op() } + #[inline] + fn attach_allow_change(&self, db: &Db, op: impl FnOnce() -> R) -> R + where + Db: ?Sized + Database, + { + struct DbGuard<'s> { + state: &'s Attached, + prev: Option>, + } + + impl<'s> DbGuard<'s> { + #[inline] + fn new(attached: &'s Attached, db: &dyn Database) -> Self { + let prev = attached.database.replace(Some(NonNull::from(db))); + Self { + state: attached, + prev, + } + } + } + + impl Drop for DbGuard<'_> { + #[inline] + fn drop(&mut self) { + self.state.database.set(self.prev); + } + } + + let _guard = DbGuard::new(self, db.as_dyn_database()); + op() + } + /// Access the "attached" database. Returns `None` if no database is attached. /// Databases are attached with `attach_database`. 
#[inline] @@ -104,6 +136,23 @@ where ) } +/// Attach the database to the current thread and execute `op`. +/// Allows a different database than currently attached. The original database +/// will be restored on return. +/// +/// **Note:** Switching databases can cause bugs. If you do not intend to switch +/// databases, prefer [`attach`] which will panic if you accidentally do. +#[inline] +pub fn attach_allow_change(db: &Db, op: impl FnOnce() -> R) -> R +where + Db: ?Sized + Database, +{ + ATTACHED.with( + #[inline] + |a| a.attach_allow_change(db, op), + ) +} + /// Access the "attached" database. Returns `None` if no database is attached. /// Databases are attached with `attach_database`. #[inline] diff --git a/src/lib.rs b/src/lib.rs index d846cbd76..8ab47379d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -67,7 +67,7 @@ pub use self::runtime::Runtime; pub use self::storage::{Storage, StorageHandle}; pub use self::update::Update; pub use self::zalsa::IngredientIndex; -pub use crate::attach::{attach, with_attached_database}; +pub use crate::attach::{attach, attach_allow_change, with_attached_database}; pub mod prelude { #[cfg(feature = "accumulator")] diff --git a/src/zalsa.rs b/src/zalsa.rs index 9fc139e64..ee3c68ce0 100644 --- a/src/zalsa.rs +++ b/src/zalsa.rs @@ -516,6 +516,9 @@ impl ErasedJar { pub const fn erase() -> Self { Self { kind: I::KIND, + // This is a false positive of the lint on beta, fixed on nightly. + // FIXME: Remove this when nightly stabilizes. 
+ #[allow(clippy::incompatible_msrv)] type_id: TypeId::of::, type_name: std::any::type_name::, create_ingredients: ::create_ingredients, diff --git a/tests/compile-fail/incomplete_persistence.stderr b/tests/compile-fail/incomplete_persistence.stderr index ded998277..f7082ecca 100644 --- a/tests/compile-fail/incomplete_persistence.stderr +++ b/tests/compile-fail/incomplete_persistence.stderr @@ -1,4 +1,4 @@ -error[E0277]: the trait bound `NotPersistable<'_>: Serialize` is not satisfied +error[E0277]: the trait bound `NotPersistable<'_>: serde::Serialize` is not satisfied --> tests/compile-fail/incomplete_persistence.rs:1:1 | 1 | #[salsa::tracked(persist)] @@ -22,7 +22,7 @@ error[E0277]: the trait bound `NotPersistable<'_>: Serialize` is not satisfied = note: required for `(NotPersistable<'_>,)` to implement `Serialize` = note: this error originates in the macro `salsa::plumbing::setup_tracked_struct` which comes from the expansion of the attribute macro `salsa::tracked` (in Nightly builds, run with -Z macro-backtrace for more info) -error[E0277]: the trait bound `NotPersistable<'_>: Deserialize<'_>` is not satisfied +error[E0277]: the trait bound `NotPersistable<'_>: serde::Deserialize<'de>` is not satisfied --> tests/compile-fail/incomplete_persistence.rs:1:1 | 1 | #[salsa::tracked(persist)] @@ -43,7 +43,7 @@ error[E0277]: the trait bound `NotPersistable<'_>: Deserialize<'_>` is not satis = note: required for `(NotPersistable<'_>,)` to implement `Deserialize<'_>` = note: this error originates in the macro `salsa::plumbing::setup_tracked_struct` which comes from the expansion of the attribute macro `salsa::tracked` (in Nightly builds, run with -Z macro-backtrace for more info) -error[E0277]: the trait bound `NotPersistable<'db>: Serialize` is not satisfied +error[E0277]: the trait bound `NotPersistable<'db>: serde::Serialize` is not satisfied --> tests/compile-fail/incomplete_persistence.rs:12:45 | 12 | fn query(_db: &dyn salsa::Database, _input: NotPersistable<'_>) 
{} @@ -71,7 +71,7 @@ note: required by a bound in `query_input_is_persistable` | required by this bound in `query_input_is_persistable` = note: this error originates in the macro `salsa::plumbing::setup_tracked_fn` which comes from the expansion of the attribute macro `salsa::tracked` (in Nightly builds, run with -Z macro-backtrace for more info) -error[E0277]: the trait bound `for<'de> NotPersistable<'db>: Deserialize<'de>` is not satisfied +error[E0277]: the trait bound `NotPersistable<'db>: serde::Deserialize<'de>` is not satisfied --> tests/compile-fail/incomplete_persistence.rs:12:45 | 12 | fn query(_db: &dyn salsa::Database, _input: NotPersistable<'_>) {} From 29ab321b45d00daa4315fa2a06f7207759a8c87e Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Wed, 24 Sep 2025 19:49:24 +0200 Subject: [PATCH 43/65] fix: Cleanup provisional cycle head memos when query panics (#993) * fix: Cleanup provisional cycle head memos on panic * Update test outputs --- src/function/execute.rs | 63 +++++++++++++++++-- src/function/sync.rs | 4 +- src/runtime.rs | 2 +- .../get-set-on-private-input-field.stderr | 4 +- .../input_struct_incompatibles.stderr | 4 +- .../interned_struct_incompatibles.stderr | 4 +- tests/compile-fail/span-input-setter.stderr | 6 +- .../tracked_fn_return_not_update.stderr | 12 ++-- .../tracked_impl_incompatibles.stderr | 16 ++--- .../tracked_struct_not_update.stderr | 4 +- tests/compile_fail.rs | 2 +- 11 files changed, 88 insertions(+), 33 deletions(-) diff --git a/src/function/execute.rs b/src/function/execute.rs index 723964851..7445d2f81 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -5,7 +5,7 @@ use crate::function::{Configuration, IngredientImpl}; use crate::sync::atomic::{AtomicBool, Ordering}; use crate::tracked_struct::Identity; use crate::zalsa::{MemoIngredientIndex, Zalsa, ZalsaDatabase}; -use crate::zalsa_local::ActiveQueryGuard; +use crate::zalsa_local::{ActiveQueryGuard, QueryRevisions}; use crate::{Event, EventKind, Id}; 
impl IngredientImpl @@ -141,6 +141,7 @@ where // only when a cycle is actually encountered. let mut opt_last_provisional: Option<&Memo<'db, C>> = None; let mut last_stale_tracked_ids: Vec<(Identity, Id)> = Vec::new(); + let _guard = ClearCycleHeadIfPanicking::new(self, zalsa, id, memo_ingredient_index); loop { let previous_memo = opt_last_provisional.or(opt_old_memo); @@ -210,6 +211,9 @@ where // `iteration_count` can't overflow as we check it against `MAX_ITERATIONS` // which is less than `u32::MAX`. iteration_count = iteration_count.increment().unwrap_or_else(|| { + tracing::warn!( + "{database_key_index:?}: execute: too many cycle iterations" + ); panic!("{database_key_index:?}: execute: too many cycle iterations") }); zalsa.event(&|| { @@ -222,10 +226,7 @@ where completed_query .revisions .update_iteration_count(iteration_count); - crate::tracing::debug!( - "{database_key_index:?}: execute: iterate again, revisions: {revisions:#?}", - revisions = &completed_query.revisions - ); + crate::tracing::info!("{database_key_index:?}: execute: iterate again...",); opt_last_provisional = Some(self.insert_memo( zalsa, id, @@ -297,3 +298,55 @@ where (new_value, active_query.pop()) } } + +/// Replaces any inserted memo with a fixpoint initial memo without a value if the current thread panics. +/// +/// A regular query doesn't insert any memo if it panics and the query +/// simply gets re-executed if any later called query depends on the panicked query (and will panic again unless the query isn't deterministic). +/// +/// Unfortunately, this isn't the case for cycle heads because Salsa first inserts the fixpoint initial memo and later inserts +/// provisional memos for every iteration. 
Detecting whether a query has previously panicked +/// in `fetch` (e.g., `validate_same_iteration`) and requires re-execution is probably possible but not very straightforward +/// and it's easy to get it wrong, which results in infinite loops where `Memo::provisional_retry` keeps retrying to get the latest `Memo` +/// but `fetch` doesn't re-execute the query for reasons. +/// +/// Specifically, a Memo can linger after a panic, which is then incorrectly returned +/// by `fetch_cold_cycle` because it passes the `shallow_verified_memo` check instead of inserting +/// a new fix point initial value if that happens. +/// +/// We could insert a fixpoint initial value here, but it seems unnecessary. +struct ClearCycleHeadIfPanicking<'a, C: Configuration> { + ingredient: &'a IngredientImpl, + zalsa: &'a Zalsa, + id: Id, + memo_ingredient_index: MemoIngredientIndex, +} + +impl<'a, C: Configuration> ClearCycleHeadIfPanicking<'a, C> { + fn new( + ingredient: &'a IngredientImpl, + zalsa: &'a Zalsa, + id: Id, + memo_ingredient_index: MemoIngredientIndex, + ) -> Self { + Self { + ingredient, + zalsa, + id, + memo_ingredient_index, + } + } +} + +impl Drop for ClearCycleHeadIfPanicking<'_, C> { + fn drop(&mut self) { + if std::thread::panicking() { + let revisions = + QueryRevisions::fixpoint_initial(self.ingredient.database_key_index(self.id)); + + let memo = Memo::new(None, self.zalsa.current_revision(), revisions); + self.ingredient + .insert_memo(self.zalsa, self.id, memo, self.memo_ingredient_index); + } + } +} diff --git a/src/function/sync.rs b/src/function/sync.rs index bb514e114..0a88844af 100644 --- a/src/function/sync.rs +++ b/src/function/sync.rs @@ -97,9 +97,11 @@ impl ClaimGuard<'_> { syncs.remove(&self.key_index).expect("key claimed twice?"); if anyone_waiting { + let database_key = DatabaseKeyIndex::new(self.sync_table.ingredient, self.key_index); self.zalsa.runtime().unblock_queries_blocked_on( - DatabaseKeyIndex::new(self.sync_table.ingredient, self.key_index), + 
database_key, if thread::panicking() { + tracing::info!("Unblocking queries blocked on {database_key:?} after a panick"); WaitResult::Panicked } else { WaitResult::Completed diff --git a/src/runtime.rs b/src/runtime.rs index ec3b091d5..8436c684d 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -90,7 +90,7 @@ impl Running<'_> { }) }); - crate::tracing::debug!( + crate::tracing::info!( "block_on: thread {thread_id:?} is blocking on {database_key:?} in thread {other_id:?}", ); diff --git a/tests/compile-fail/get-set-on-private-input-field.stderr b/tests/compile-fail/get-set-on-private-input-field.stderr index 40acd8c2d..0b6738bc7 100644 --- a/tests/compile-fail/get-set-on-private-input-field.stderr +++ b/tests/compile-fail/get-set-on-private-input-field.stderr @@ -1,7 +1,7 @@ error[E0624]: method `field` is private --> tests/compile-fail/get-set-on-private-input-field.rs:12:11 | -2 | #[salsa::input] + 2 | #[salsa::input] | --------------- private method defined here ... 12 | input.field(&db); @@ -10,7 +10,7 @@ error[E0624]: method `field` is private error[E0624]: method `set_field` is private --> tests/compile-fail/get-set-on-private-input-field.rs:13:11 | -2 | #[salsa::input] + 2 | #[salsa::input] | --------------- private method defined here ... 
13 | input.set_field(&mut db).to(23); diff --git a/tests/compile-fail/input_struct_incompatibles.stderr b/tests/compile-fail/input_struct_incompatibles.stderr index a1b94e9aa..b1fb3c542 100644 --- a/tests/compile-fail/input_struct_incompatibles.stderr +++ b/tests/compile-fail/input_struct_incompatibles.stderr @@ -55,7 +55,7 @@ error: cannot find attribute `tracked` in this scope | help: consider importing one of these attribute macros | -1 + use salsa::tracked; + 1 + use salsa::tracked; | -1 + use salsa_macros::tracked; + 1 + use salsa_macros::tracked; | diff --git a/tests/compile-fail/interned_struct_incompatibles.stderr b/tests/compile-fail/interned_struct_incompatibles.stderr index 482e38b46..bc4a10fd6 100644 --- a/tests/compile-fail/interned_struct_incompatibles.stderr +++ b/tests/compile-fail/interned_struct_incompatibles.stderr @@ -49,7 +49,7 @@ error: cannot find attribute `tracked` in this scope | help: consider importing one of these attribute macros | -1 + use salsa::tracked; + 1 + use salsa::tracked; | -1 + use salsa_macros::tracked; + 1 + use salsa_macros::tracked; | diff --git a/tests/compile-fail/span-input-setter.stderr b/tests/compile-fail/span-input-setter.stderr index b43d07df3..afd2b726d 100644 --- a/tests/compile-fail/span-input-setter.stderr +++ b/tests/compile-fail/span-input-setter.stderr @@ -11,10 +11,10 @@ error[E0308]: mismatched types note: method defined here --> tests/compile-fail/span-input-setter.rs:3:5 | -1 | #[salsa::input] + 1 | #[salsa::input] | --------------- -2 | pub struct MyInput { -3 | field: u32, + 2 | pub struct MyInput { + 3 | field: u32, | ^^^^^ help: consider mutably borrowing here | diff --git a/tests/compile-fail/tracked_fn_return_not_update.stderr b/tests/compile-fail/tracked_fn_return_not_update.stderr index 805f05916..98f4ebe13 100644 --- a/tests/compile-fail/tracked_fn_return_not_update.stderr +++ b/tests/compile-fail/tracked_fn_return_not_update.stderr @@ -7,18 +7,18 @@ error[E0369]: binary operation `==` cannot 
be applied to type `&NotUpdate` note: an implementation of `PartialEq` might be missing for `NotUpdate` --> tests/compile-fail/tracked_fn_return_not_update.rs:7:1 | -7 | struct NotUpdate; + 7 | struct NotUpdate; | ^^^^^^^^^^^^^^^^ must implement `PartialEq` help: consider annotating `NotUpdate` with `#[derive(PartialEq)]` | -7 + #[derive(PartialEq)] -8 | struct NotUpdate; + 7 + #[derive(PartialEq)] + 8 | struct NotUpdate; | error[E0599]: the function or associated item `maybe_update` exists for struct `UpdateDispatch`, but its trait bounds were not satisfied --> tests/compile-fail/tracked_fn_return_not_update.rs:10:56 | -7 | struct NotUpdate; + 7 | struct NotUpdate; | ---------------- doesn't satisfy `NotUpdate: PartialEq` or `NotUpdate: Update` ... 10 | fn tracked_fn<'db>(db: &'db dyn Db, input: MyInput) -> NotUpdate { @@ -40,6 +40,6 @@ note: the trait `Update` must be implemented | ^^^^^^^^^^^^^^^^^^^^^^^ help: consider annotating `NotUpdate` with `#[derive(PartialEq)]` | -7 + #[derive(PartialEq)] -8 | struct NotUpdate; + 7 + #[derive(PartialEq)] + 8 | struct NotUpdate; | diff --git a/tests/compile-fail/tracked_impl_incompatibles.stderr b/tests/compile-fail/tracked_impl_incompatibles.stderr index 3dd3e5868..43a23ff19 100644 --- a/tests/compile-fail/tracked_impl_incompatibles.stderr +++ b/tests/compile-fail/tracked_impl_incompatibles.stderr @@ -55,7 +55,7 @@ error: unexpected token error[E0119]: conflicting implementations of trait `Default` for type `MyTracked<'_>` --> tests/compile-fail/tracked_impl_incompatibles.rs:12:1 | -7 | impl<'db> std::default::Default for MyTracked<'db> { + 7 | impl<'db> std::default::Default for MyTracked<'db> { | -------------------------------------------------- first implementation here ... 
12 | impl<'db> std::default::Default for MyTracked<'db> { @@ -64,7 +64,7 @@ error[E0119]: conflicting implementations of trait `Default` for type `MyTracked error[E0119]: conflicting implementations of trait `Default` for type `MyTracked<'_>` --> tests/compile-fail/tracked_impl_incompatibles.rs:17:1 | -7 | impl<'db> std::default::Default for MyTracked<'db> { + 7 | impl<'db> std::default::Default for MyTracked<'db> { | -------------------------------------------------- first implementation here ... 17 | impl<'db> std::default::Default for MyTracked<'db> { @@ -73,7 +73,7 @@ error[E0119]: conflicting implementations of trait `Default` for type `MyTracked error[E0119]: conflicting implementations of trait `Default` for type `MyTracked<'_>` --> tests/compile-fail/tracked_impl_incompatibles.rs:22:1 | -7 | impl<'db> std::default::Default for MyTracked<'db> { + 7 | impl<'db> std::default::Default for MyTracked<'db> { | -------------------------------------------------- first implementation here ... 22 | impl<'db> std::default::Default for MyTracked<'db> { @@ -82,7 +82,7 @@ error[E0119]: conflicting implementations of trait `Default` for type `MyTracked error[E0119]: conflicting implementations of trait `Default` for type `MyTracked<'_>` --> tests/compile-fail/tracked_impl_incompatibles.rs:27:1 | -7 | impl<'db> std::default::Default for MyTracked<'db> { + 7 | impl<'db> std::default::Default for MyTracked<'db> { | -------------------------------------------------- first implementation here ... 
27 | impl<'db> std::default::Default for MyTracked<'db> { @@ -91,7 +91,7 @@ error[E0119]: conflicting implementations of trait `Default` for type `MyTracked error[E0119]: conflicting implementations of trait `Default` for type `MyTracked<'_>` --> tests/compile-fail/tracked_impl_incompatibles.rs:32:1 | -7 | impl<'db> std::default::Default for MyTracked<'db> { + 7 | impl<'db> std::default::Default for MyTracked<'db> { | -------------------------------------------------- first implementation here ... 32 | impl<'db> std::default::Default for MyTracked<'db> { @@ -100,7 +100,7 @@ error[E0119]: conflicting implementations of trait `Default` for type `MyTracked error[E0119]: conflicting implementations of trait `Default` for type `MyTracked<'_>` --> tests/compile-fail/tracked_impl_incompatibles.rs:37:1 | -7 | impl<'db> std::default::Default for MyTracked<'db> { + 7 | impl<'db> std::default::Default for MyTracked<'db> { | -------------------------------------------------- first implementation here ... 37 | impl<'db> std::default::Default for MyTracked<'db> { @@ -109,7 +109,7 @@ error[E0119]: conflicting implementations of trait `Default` for type `MyTracked error[E0119]: conflicting implementations of trait `Default` for type `MyTracked<'_>` --> tests/compile-fail/tracked_impl_incompatibles.rs:42:1 | -7 | impl<'db> std::default::Default for MyTracked<'db> { + 7 | impl<'db> std::default::Default for MyTracked<'db> { | -------------------------------------------------- first implementation here ... 
42 | impl<'db> std::default::Default for MyTracked<'db> { @@ -118,7 +118,7 @@ error[E0119]: conflicting implementations of trait `Default` for type `MyTracked error[E0119]: conflicting implementations of trait `Default` for type `MyTracked<'_>` --> tests/compile-fail/tracked_impl_incompatibles.rs:47:1 | -7 | impl<'db> std::default::Default for MyTracked<'db> { + 7 | impl<'db> std::default::Default for MyTracked<'db> { | -------------------------------------------------- first implementation here ... 47 | impl<'db> std::default::Default for MyTracked<'db> { diff --git a/tests/compile-fail/tracked_struct_not_update.stderr b/tests/compile-fail/tracked_struct_not_update.stderr index 293552db0..306029ee4 100644 --- a/tests/compile-fail/tracked_struct_not_update.stderr +++ b/tests/compile-fail/tracked_struct_not_update.stderr @@ -29,6 +29,6 @@ note: the trait `Update` must be implemented = note: this error originates in the attribute macro `salsa::tracked` (in Nightly builds, run with -Z macro-backtrace for more info) help: consider annotating `NotUpdate` with `#[derive(PartialEq)]` | -7 + #[derive(PartialEq)] -8 | struct NotUpdate; + 7 + #[derive(PartialEq)] + 8 | struct NotUpdate; | diff --git a/tests/compile_fail.rs b/tests/compile_fail.rs index 3648f756d..081a7eaab 100644 --- a/tests/compile_fail.rs +++ b/tests/compile_fail.rs @@ -1,6 +1,6 @@ #![cfg(all(feature = "inventory", feature = "persistence"))] -#[rustversion::all(stable, since(1.89))] +#[rustversion::all(stable, since(1.90))] #[test] fn compile_fail() { let t = trybuild::TestCases::new(); From 5330dd99b94f7a69d474a2854f7b2a0df328a7c2 Mon Sep 17 00:00:00 2001 From: Astavie Date: Fri, 26 Sep 2025 11:29:39 +0200 Subject: [PATCH 44/65] Add implementations for Lookup and HashEqLike for CompactString (#988) --- src/interned.rs | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/src/interned.rs b/src/interned.rs index 77dda42b7..544a8d0ee 100644 --- a/src/interned.rs +++ b/src/interned.rs @@ 
-1278,6 +1278,13 @@ impl Lookup for &str { } } +#[cfg(feature = "compact_str")] +impl Lookup for &str { + fn into_owned(self) -> compact_str::CompactString { + compact_str::CompactString::new(self) + } +} + impl HashEqLike<&str> for String { fn hash(&self, h: &mut H) { Hash::hash(self, &mut *h) @@ -1288,6 +1295,17 @@ impl HashEqLike<&str> for String { } } +#[cfg(feature = "compact_str")] +impl HashEqLike<&str> for compact_str::CompactString { + fn hash(&self, h: &mut H) { + Hash::hash(self, &mut *h) + } + + fn eq(&self, data: &&str) -> bool { + self == *data + } +} + impl> HashEqLike<&[A]> for Vec { fn hash(&self, h: &mut H) { Hash::hash(self, h); From 5c826b59da97e351cad3a9b4be715281ff93a703 Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Tue, 30 Sep 2025 12:48:35 +0100 Subject: [PATCH 45/65] Update codspeed action (#997) --- .github/workflows/test.yml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 6a43ba722..6325785eb 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -129,7 +129,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v5 - name: Setup Rust toolchain uses: dtolnay/rust-toolchain@master @@ -159,7 +159,8 @@ jobs: run: cargo codspeed build - name: "Run benchmarks" - uses: CodSpeedHQ/action@v3 + uses: CodSpeedHQ/action@v4 with: + mode: instrumentation run: cargo codspeed run token: ${{ secrets.CODSPEED_TOKEN }} From 4a26bf9e49a2bd111425ac48180e95ee5e9ad36d Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Tue, 30 Sep 2025 13:22:03 +0100 Subject: [PATCH 46/65] refactor: Push active query in execute (#996) * refactor: Push active query in execute * Remove inline from `execute_maybe_iterate` * Pass Zalsa and ZalsaLocal * Remove inline from `push_query` * Remove `id` from `execute_query` * Discard changes to src/zalsa_local.rs --- .github/workflows/test.yml | 3 +- src/function/execute.rs | 
54 +++++++++++++++++------------ src/function/fetch.rs | 6 +--- src/function/maybe_changed_after.rs | 6 ++-- 4 files changed, 36 insertions(+), 33 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 6325785eb..ee9aa06b8 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -152,7 +152,8 @@ jobs: target/ key: ${{ runner.os }}-cargo-${{ steps.rust-toolchain.outputs.cachekey }}-${{ hashFiles('**/Cargo.toml') }} restore-keys: | - ${{ runner.os }}-cargo-${{ steps.rust-toolchain.outputs.cachekey }}- + ${{ runner.os }}-cargo-${{ steps.rust-toolchain.outputs.cachekey }}-benchmark + ${{ runner.os }}-cargo-${{ steps.rust-toolchain.outputs.cachekey }} ${{ runner.os }}-cargo- - name: "Build benchmarks" diff --git a/src/function/execute.rs b/src/function/execute.rs index 7445d2f81..738df1247 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -2,11 +2,12 @@ use crate::active_query::CompletedQuery; use crate::cycle::{CycleRecoveryStrategy, IterationCount}; use crate::function::memo::Memo; use crate::function::{Configuration, IngredientImpl}; +use crate::plumbing::ZalsaLocal; use crate::sync::atomic::{AtomicBool, Ordering}; use crate::tracked_struct::Identity; -use crate::zalsa::{MemoIngredientIndex, Zalsa, ZalsaDatabase}; +use crate::zalsa::{MemoIngredientIndex, Zalsa}; use crate::zalsa_local::{ActiveQueryGuard, QueryRevisions}; -use crate::{Event, EventKind, Id}; +use crate::{DatabaseKeyIndex, Event, EventKind, Id}; impl IngredientImpl where @@ -25,14 +26,14 @@ where pub(super) fn execute<'db>( &'db self, db: &'db C::DbView, - active_query: ActiveQueryGuard<'db>, + zalsa: &'db Zalsa, + zalsa_local: &'db ZalsaLocal, + database_key_index: DatabaseKeyIndex, opt_old_memo: Option<&Memo<'db, C>>, ) -> &'db Memo<'db, C> { - let database_key_index = active_query.database_key_index; let id = database_key_index.key_index(); crate::tracing::info!("{:?}: executing query", database_key_index); - let zalsa = 
db.zalsa(); zalsa.event(&|| { Event::new(EventKind::WillExecute { @@ -42,12 +43,19 @@ where let memo_ingredient_index = self.memo_ingredient_index(zalsa, id); let (new_value, mut completed_query) = match C::CYCLE_STRATEGY { - CycleRecoveryStrategy::Panic => { - Self::execute_query(db, zalsa, active_query, opt_old_memo, id) - } + CycleRecoveryStrategy::Panic => Self::execute_query( + db, + zalsa, + zalsa_local.push_query(database_key_index, IterationCount::initial()), + opt_old_memo, + ), CycleRecoveryStrategy::FallbackImmediate => { - let (mut new_value, mut completed_query) = - Self::execute_query(db, zalsa, active_query, opt_old_memo, id); + let (mut new_value, mut completed_query) = Self::execute_query( + db, + zalsa, + zalsa_local.push_query(database_key_index, IterationCount::initial()), + opt_old_memo, + ); if let Some(cycle_heads) = completed_query.revisions.cycle_heads_mut() { // Did the new result we got depend on our own provisional value, in a cycle? @@ -71,9 +79,8 @@ where // Cycle participants that don't have a fallback will be discarded in // `validate_provisional()`. let cycle_heads = std::mem::take(cycle_heads); - let active_query = db - .zalsa_local() - .push_query(database_key_index, IterationCount::initial()); + let active_query = + zalsa_local.push_query(database_key_index, IterationCount::initial()); new_value = C::cycle_initial(db, C::id_to_input(zalsa, id)); completed_query = active_query.pop(); // We need to set `cycle_heads` and `verified_final` because it needs to propagate to the callers. 
@@ -86,10 +93,10 @@ where } CycleRecoveryStrategy::Fixpoint => self.execute_maybe_iterate( db, - active_query, opt_old_memo, zalsa, - id, + zalsa_local, + database_key_index, memo_ingredient_index, ), }; @@ -122,19 +129,18 @@ where ) } - #[inline] fn execute_maybe_iterate<'db>( &'db self, db: &'db C::DbView, - mut active_query: ActiveQueryGuard<'db>, opt_old_memo: Option<&Memo<'db, C>>, zalsa: &'db Zalsa, - id: Id, + zalsa_local: &'db ZalsaLocal, + database_key_index: DatabaseKeyIndex, memo_ingredient_index: MemoIngredientIndex, ) -> (C::Output<'db>, CompletedQuery) { - let database_key_index = active_query.database_key_index; + let id = database_key_index.key_index(); let mut iteration_count = IterationCount::initial(); - let zalsa_local = db.zalsa_local(); + let mut active_query = zalsa_local.push_query(database_key_index, iteration_count); // Our provisional value from the previous iteration, when doing fixpoint iteration. // Initially it's set to None, because the initial provisional value is created lazily, @@ -155,7 +161,7 @@ where active_query.seed_tracked_struct_ids(&last_stale_tracked_ids); let (mut new_value, mut completed_query) = - Self::execute_query(db, zalsa, active_query, previous_memo, id); + Self::execute_query(db, zalsa, active_query, previous_memo); // Did the new result we got depend on our own provisional value, in a cycle? if let Some(cycle_heads) = completed_query @@ -272,7 +278,6 @@ where zalsa: &'db Zalsa, active_query: ActiveQueryGuard<'db>, opt_old_memo: Option<&Memo<'db, C>>, - id: Id, ) -> (C::Output<'db>, CompletedQuery) { if let Some(old_memo) = opt_old_memo { // If we already executed this query once, then use the tracked-struct ids from the @@ -293,7 +298,10 @@ where // Query was not previously executed, or value is potentially // stale, or value is absent. Let's execute! 
- let new_value = C::execute(db, C::id_to_input(zalsa, id)); + let new_value = C::execute( + db, + C::id_to_input(zalsa, active_query.database_key_index.key_index()), + ); (new_value, active_query.pop()) } diff --git a/src/function/fetch.rs b/src/function/fetch.rs index 57b79a52a..a1b6658f6 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -232,11 +232,7 @@ where } } - let memo = self.execute( - db, - zalsa_local.push_query(database_key_index, IterationCount::initial()), - opt_old_memo, - ); + let memo = self.execute(db, zalsa, zalsa_local, database_key_index, opt_old_memo); Some(memo) } diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index 2ab1d18ee..4f69655cd 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -2,7 +2,7 @@ use rustc_hash::FxHashMap; #[cfg(feature = "accumulator")] use crate::accumulator::accumulated_map::InputAccumulatedValues; -use crate::cycle::{CycleRecoveryStrategy, IterationCount, ProvisionalStatus}; +use crate::cycle::{CycleRecoveryStrategy, ProvisionalStatus}; use crate::function::memo::Memo; use crate::function::sync::ClaimResult; use crate::function::{Configuration, IngredientImpl}; @@ -227,9 +227,7 @@ where // `in_cycle` tracks if the enclosing query is in a cycle. `deep_verify.cycle_heads` tracks // if **this query** encountered a cycle (which means there's some provisional value somewhere floating around). if old_memo.value.is_some() && !cycle_heads.has_any() { - let active_query = - zalsa_local.push_query(database_key_index, IterationCount::initial()); - let memo = self.execute(db, active_query, Some(old_memo)); + let memo = self.execute(db, zalsa, zalsa_local, database_key_index, Some(old_memo)); let changed_at = memo.revisions.changed_at; // Always assume that a provisional value has changed. 
From 9c3278f8f05e21b81d23d042048ea0d8daed5ea9 Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Tue, 30 Sep 2025 13:36:56 +0100 Subject: [PATCH 47/65] Replace unsafe unwrap with `expect` call (#998) --- src/function/execute.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/function/execute.rs b/src/function/execute.rs index 738df1247..9521a9dce 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -189,8 +189,10 @@ where debug_assert!(memo.may_be_provisional()); memo.value.as_ref() }; - // SAFETY: The `LRU` does not run mid-execution, so the value remains filled - let last_provisional_value = unsafe { last_provisional_value.unwrap_unchecked() }; + + let last_provisional_value = last_provisional_value.expect( + "`fetch_cold_cycle` should have inserted a provisional memo with Cycle::initial", + ); crate::tracing::debug!( "{database_key_index:?}: execute: \ I am a cycle head, comparing last provisional value with new value" From adf0556ed7aa163f6076ca7dbaed14f3e433d8f7 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Sun, 5 Oct 2025 08:11:53 +0200 Subject: [PATCH 48/65] chore: release v0.24.0 (#929) Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- CHANGELOG.md | 55 +++++++++++++++++++++++ Cargo.toml | 8 ++-- components/salsa-macro-rules/CHANGELOG.md | 19 ++++++++ components/salsa-macro-rules/Cargo.toml | 2 +- components/salsa-macros/CHANGELOG.md | 12 +++++ components/salsa-macros/Cargo.toml | 2 +- 6 files changed, 92 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c21cd9811..08630a95a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,61 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.24.0](https://github.com/salsa-rs/salsa/compare/salsa-v0.23.0...salsa-v0.24.0) - 2025-09-30 + +### Fixed + +- Cleanup provisional cycle 
head memos when query panics ([#993](https://github.com/salsa-rs/salsa/pull/993)) +- Runaway for unchanged queries participating in cycle ([#981](https://github.com/salsa-rs/salsa/pull/981)) +- Delete not re-created tracked structs after fixpoint iteration ([#979](https://github.com/salsa-rs/salsa/pull/979)) +- fix assertion during interned deserialization ([#978](https://github.com/salsa-rs/salsa/pull/978)) +- Do not unnecessarily require `Debug` on fields for interned structs ([#951](https://github.com/salsa-rs/salsa/pull/951)) +- Fix phantom data usage in salsa structs affecting auto traits ([#932](https://github.com/salsa-rs/salsa/pull/932)) + +### Other + +- Replace unsafe unwrap with `expect` call ([#998](https://github.com/salsa-rs/salsa/pull/998)) +- Push active query in execute ([#996](https://github.com/salsa-rs/salsa/pull/996)) +- Update codspeed action ([#997](https://github.com/salsa-rs/salsa/pull/997)) +- Add implementations for Lookup and HashEqLike for CompactString ([#988](https://github.com/salsa-rs/salsa/pull/988)) +- Provide a method to attach a database even if it's different from the current attached one ([#992](https://github.com/salsa-rs/salsa/pull/992)) +- Allow fallback to take longer than one iteration to converge ([#991](https://github.com/salsa-rs/salsa/pull/991)) +- refactor `entries` API ([#987](https://github.com/salsa-rs/salsa/pull/987)) +- Persistent caching fixes ([#982](https://github.com/salsa-rs/salsa/pull/982)) +- outline cold path of `lookup_ingredient` ([#984](https://github.com/salsa-rs/salsa/pull/984)) +- Update snapshot to fix nightly type rendering ([#983](https://github.com/salsa-rs/salsa/pull/983)) +- avoid cycles during serialization ([#977](https://github.com/salsa-rs/salsa/pull/977)) +- Flatten unserializable query dependencies ([#975](https://github.com/salsa-rs/salsa/pull/975)) +- optimize `Id::hash` ([#974](https://github.com/salsa-rs/salsa/pull/974)) +- Make `thin-vec/serde` dependency dependent on `persistence` 
feature ([#973](https://github.com/salsa-rs/salsa/pull/973)) +- Remove tracked structs from query outputs ([#969](https://github.com/salsa-rs/salsa/pull/969)) +- Remove jemalloc ([#972](https://github.com/salsa-rs/salsa/pull/972)) +- Initial persistent caching prototype ([#967](https://github.com/salsa-rs/salsa/pull/967)) +- Fix `maybe_changed_after` runnaway for fixpoint queries ([#961](https://github.com/salsa-rs/salsa/pull/961)) +- add parallel maybe changed after test ([#963](https://github.com/salsa-rs/salsa/pull/963)) +- Update tests for Rust 1.89 ([#966](https://github.com/salsa-rs/salsa/pull/966)) +- remove allocation lock ([#962](https://github.com/salsa-rs/salsa/pull/962)) +- consolidate memory usage information API ([#964](https://github.com/salsa-rs/salsa/pull/964)) +- Add heap size support for salsa structs ([#943](https://github.com/salsa-rs/salsa/pull/943)) +- Extract the cycle branches from `fetch` and `maybe_changed_after` ([#955](https://github.com/salsa-rs/salsa/pull/955)) +- allow reuse of cached provisional memos within the same cycle iteration during `maybe_changed_after` ([#954](https://github.com/salsa-rs/salsa/pull/954)) +- Expose API to manually trigger cancellation ([#959](https://github.com/salsa-rs/salsa/pull/959)) +- Upgrade dependencies ([#956](https://github.com/salsa-rs/salsa/pull/956)) +- Use `CycleHeadSet` in `maybe_update_after` ([#953](https://github.com/salsa-rs/salsa/pull/953)) +- Gate accumulator feature behind a feature flag ([#946](https://github.com/salsa-rs/salsa/pull/946)) +- optimize allocation fast-path ([#949](https://github.com/salsa-rs/salsa/pull/949)) +- remove borrow checks from `ZalsaLocal` ([#939](https://github.com/salsa-rs/salsa/pull/939)) +- Do manual trait casting ([#922](https://github.com/salsa-rs/salsa/pull/922)) +- Retain backing allocation of `ActiveQuery::input_outputs` in `ActiveQuery::seed_iteration` ([#948](https://github.com/salsa-rs/salsa/pull/948)) +- remove extra bounds checks from memo table 
hot-paths ([#938](https://github.com/salsa-rs/salsa/pull/938)) +- Outline all tracing events ([#942](https://github.com/salsa-rs/salsa/pull/942)) +- remove bounds and type checks from `IngredientCache` ([#937](https://github.com/salsa-rs/salsa/pull/937)) +- Avoid dynamic dispatch to access memo tables ([#941](https://github.com/salsa-rs/salsa/pull/941)) +- optimize page access ([#940](https://github.com/salsa-rs/salsa/pull/940)) +- Use `inventory` for static ingredient registration ([#934](https://github.com/salsa-rs/salsa/pull/934)) +- Fix `heap_size` option not being preserved in tracked impls ([#930](https://github.com/salsa-rs/salsa/pull/930)) +- update papaya ([#928](https://github.com/salsa-rs/salsa/pull/928)) + ## [0.23.0](https://github.com/salsa-rs/salsa/compare/salsa-v0.22.0...salsa-v0.23.0) - 2025-06-27 ### Added diff --git a/Cargo.toml b/Cargo.toml index 0998e2329..cc1cd0347 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "salsa" -version = "0.23.0" +version = "0.24.0" authors.workspace = true edition.workspace = true license.workspace = true @@ -9,8 +9,8 @@ rust-version.workspace = true description = "A generic framework for on-demand, incrementalized computation (experimental)" [dependencies] -salsa-macro-rules = { version = "0.23.0", path = "components/salsa-macro-rules" } -salsa-macros = { version = "0.23.0", path = "components/salsa-macros", optional = true } +salsa-macro-rules = { version = "0.24.0", path = "components/salsa-macro-rules" } +salsa-macros = { version = "0.24.0", path = "components/salsa-macros", optional = true } boxcar = "0.2.13" crossbeam-queue = "0.3.12" @@ -62,7 +62,7 @@ salsa_unstable = [] # which may ultimately result in odd issues due to the proc-macro # output mismatching with the declarative macro inputs [target.'cfg(any())'.dependencies] -salsa-macros = { version = "=0.23.0", path = "components/salsa-macros" } +salsa-macros = { version = "=0.24.0", path = "components/salsa-macros" } 
[dev-dependencies] # examples diff --git a/components/salsa-macro-rules/CHANGELOG.md b/components/salsa-macro-rules/CHANGELOG.md index 14d542557..85b8986c2 100644 --- a/components/salsa-macro-rules/CHANGELOG.md +++ b/components/salsa-macro-rules/CHANGELOG.md @@ -7,6 +7,25 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.24.0](https://github.com/salsa-rs/salsa/compare/salsa-macro-rules-v0.23.0...salsa-macro-rules-v0.24.0) - 2025-09-30 + +### Fixed + +- Do not unnecessarily require `Debug` on fields for interned structs ([#951](https://github.com/salsa-rs/salsa/pull/951)) +- Fix phantom data usage in salsa structs affecting auto traits ([#932](https://github.com/salsa-rs/salsa/pull/932)) + +### Other + +- refactor `entries` API ([#987](https://github.com/salsa-rs/salsa/pull/987)) +- Flatten unserializable query dependencies ([#975](https://github.com/salsa-rs/salsa/pull/975)) +- Initial persistent caching prototype ([#967](https://github.com/salsa-rs/salsa/pull/967)) +- Add heap size support for salsa structs ([#943](https://github.com/salsa-rs/salsa/pull/943)) +- Gate accumulator feature behind a feature flag ([#946](https://github.com/salsa-rs/salsa/pull/946)) +- Do manual trait casting ([#922](https://github.com/salsa-rs/salsa/pull/922)) +- remove bounds and type checks from `IngredientCache` ([#937](https://github.com/salsa-rs/salsa/pull/937)) +- Avoid dynamic dispatch to access memo tables ([#941](https://github.com/salsa-rs/salsa/pull/941)) +- Use `inventory` for static ingredient registration ([#934](https://github.com/salsa-rs/salsa/pull/934)) + ## [0.23.0](https://github.com/salsa-rs/salsa/compare/salsa-macro-rules-v0.22.0...salsa-macro-rules-v0.23.0) - 2025-06-27 ### Added diff --git a/components/salsa-macro-rules/Cargo.toml b/components/salsa-macro-rules/Cargo.toml index 65770e10a..96b85de23 100644 --- a/components/salsa-macro-rules/Cargo.toml +++ b/components/salsa-macro-rules/Cargo.toml @@ -1,6 
+1,6 @@ [package] name = "salsa-macro-rules" -version = "0.23.0" +version = "0.24.0" authors.workspace = true edition.workspace = true license.workspace = true diff --git a/components/salsa-macros/CHANGELOG.md b/components/salsa-macros/CHANGELOG.md index 251842904..20ad5fd76 100644 --- a/components/salsa-macros/CHANGELOG.md +++ b/components/salsa-macros/CHANGELOG.md @@ -7,6 +7,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.24.0](https://github.com/salsa-rs/salsa/compare/salsa-macros-v0.23.0...salsa-macros-v0.24.0) - 2025-09-30 + +### Other + +- Initial persistent caching prototype ([#967](https://github.com/salsa-rs/salsa/pull/967)) +- Add heap size support for salsa structs ([#943](https://github.com/salsa-rs/salsa/pull/943)) +- Upgrade dependencies ([#956](https://github.com/salsa-rs/salsa/pull/956)) +- Do manual trait casting ([#922](https://github.com/salsa-rs/salsa/pull/922)) +- Avoid dynamic dispatch to access memo tables ([#941](https://github.com/salsa-rs/salsa/pull/941)) +- Use `inventory` for static ingredient registration ([#934](https://github.com/salsa-rs/salsa/pull/934)) +- Fix `heap_size` option not being preserved in tracked impls ([#930](https://github.com/salsa-rs/salsa/pull/930)) + ## [0.23.0](https://github.com/salsa-rs/salsa/compare/salsa-macros-v0.22.0...salsa-macros-v0.23.0) - 2025-06-27 ### Added diff --git a/components/salsa-macros/Cargo.toml b/components/salsa-macros/Cargo.toml index 9bf6992ae..a317bf498 100644 --- a/components/salsa-macros/Cargo.toml +++ b/components/salsa-macros/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "salsa-macros" -version = "0.23.0" +version = "0.24.0" authors.workspace = true edition.workspace = true license.workspace = true From 8b0831f2a3544c03103c755f3a1c34fd8bab7c70 Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Thu, 9 Oct 2025 17:43:52 +0200 Subject: [PATCH 49/65] Add benchmark for a fixpoint iteration with nested cycles (#1001) * Add 
benchmark for a fixpoint iteration with nested cycles * Fix clippy warning --- benches/dataflow.rs | 43 ++++++++++++++++++- .../src/unexpected_cycle_recovery.rs | 4 +- 2 files changed, 44 insertions(+), 3 deletions(-) diff --git a/benches/dataflow.rs b/benches/dataflow.rs index 24f5a16ee..db099c6b2 100644 --- a/benches/dataflow.rs +++ b/benches/dataflow.rs @@ -167,5 +167,46 @@ fn dataflow(criterion: &mut Criterion) { }); } -criterion_group!(benches, dataflow); +/// Emulates a data flow problem of the form: +/// ```py +/// self.x0 = self.x1 + self.x2 + self.x3 + self.x4 +/// self.x1 = self.x0 + self.x2 + self.x3 + self.x4 +/// self.x2 = self.x0 + self.x1 + self.x3 + self.x4 +/// self.x3 = self.x0 + self.x1 + self.x2 + self.x4 +/// self.x4 = 0 +/// ``` +fn nested(criterion: &mut Criterion) { + criterion.bench_function("converge_diverge_nested", |b| { + b.iter_batched_ref( + || { + let mut db = salsa::DatabaseImpl::new(); + + let def_x0 = Definition::new(&db, None, 0); + let def_x1 = Definition::new(&db, None, 0); + let def_x2 = Definition::new(&db, None, 0); + let def_x3 = Definition::new(&db, None, 0); + let def_x4 = Definition::new(&db, None, 0); + + let use_x0 = Use::new(&db, vec![def_x1, def_x2, def_x3, def_x4]); + let use_x1 = Use::new(&db, vec![def_x0, def_x2, def_x3, def_x4]); + let use_x2 = Use::new(&db, vec![def_x0, def_x1, def_x3, def_x4]); + let use_x3 = Use::new(&db, vec![def_x0, def_x1, def_x2, def_x4]); + + def_x0.set_base(&mut db).to(Some(use_x0)); + def_x1.set_base(&mut db).to(Some(use_x1)); + def_x2.set_base(&mut db).to(Some(use_x2)); + def_x3.set_base(&mut db).to(Some(use_x3)); + + (db, def_x0) + }, + |(db, def_x0)| { + // All symbols converge on 0. 
+ assert_eq!(infer_definition(db, *def_x0), Type::Values(Box::from([0]))); + }, + BatchSize::LargeInput, + ); + }); +} + +criterion_group!(benches, dataflow, nested); criterion_main!(benches); diff --git a/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs b/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs index a1cd1e73f..8d56d54f3 100644 --- a/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs +++ b/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs @@ -5,7 +5,7 @@ macro_rules! unexpected_cycle_recovery { ($db:ident, $value:ident, $count:ident, $($other_inputs:ident),*) => {{ std::mem::drop($db); - std::mem::drop(($($other_inputs),*)); + std::mem::drop(($($other_inputs,)*)); panic!("cannot recover from cycle") }}; } @@ -14,7 +14,7 @@ macro_rules! unexpected_cycle_recovery { macro_rules! unexpected_cycle_initial { ($db:ident, $($other_inputs:ident),*) => {{ std::mem::drop($db); - std::mem::drop(($($other_inputs),*)); + std::mem::drop(($($other_inputs,)*)); panic!("no cycle initial value") }}; } From ef9f9329be6923acd050c8dddd172e3bc93e8051 Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Thu, 16 Oct 2025 11:23:03 +0200 Subject: [PATCH 50/65] Run fixpoint per strongly connected component (#999) * Run nested cycles in a single fixpoint iteration Fix serde attribute * Remove inline from `validate_same_iteration` * Nits * Move locking into sync table * More trying * More in progress work * More progress * Fix most parallel tests * More bugfixes * Short circuit in some cases * Short circuit in drop * Delete some unused code * A working solution * Simplify more * Avoid repeated query lookups in `transfer_lock` * Use recursion for unblocking * Fix hang in `maybe_changed_after` * Move claiming of transferred memos into a separate function * More aggressive use of attributes * Make re-entrant a const parameter * Smaller clean-ups * Only collect cycle heads one level deep * More cleanups * More docs * More comments * More 
documentation, cleanups * More documentation, cleanups * Remove inline attribute * Fix failing tracked structs test * Fix panic * Fix persistence test * Add test for panic in nested cycle * Allow cycle initial values same-stack * Try inlining fetch * Remove some inline attributes * Add safety comment * Clippy * Panic if `provisional_retry` runs too many times * Better handling of panics in cycles * Don't use const-generic for `REENTRANT` * More nit improvements * Remove `IterationCount::panicked` * Prefer outer most cycles in `outer_cycle` * Code review feedback * Iterate only once in panic test when running with miri --- Cargo.toml | 2 +- src/active_query.rs | 4 +- src/cancelled.rs | 1 + src/cycle.rs | 308 ++++++++++-- src/function.rs | 65 ++- src/function/execute.rs | 464 +++++++++++++----- src/function/fetch.rs | 82 ++-- src/function/maybe_changed_after.rs | 185 ++++--- src/function/memo.rs | 262 +++++----- src/function/sync.rs | 358 +++++++++++++- src/ingredient.rs | 53 +- src/key.rs | 2 +- src/runtime.rs | 140 +++++- src/runtime/dependency_graph.rs | 406 ++++++++++++++- src/tracing.rs | 16 +- src/zalsa_local.rs | 102 +++- tests/backtrace.rs | 6 +- tests/cycle.rs | 8 +- tests/cycle_tracked.rs | 2 +- tests/parallel/cycle_a_t1_b_t2.rs | 2 +- tests/parallel/cycle_a_t1_b_t2_fallback.rs | 11 +- tests/parallel/cycle_nested_deep.rs | 1 + .../parallel/cycle_nested_deep_conditional.rs | 2 +- .../cycle_nested_deep_conditional_changed.rs | 12 +- tests/parallel/cycle_nested_deep_panic.rs | 142 ++++++ tests/parallel/cycle_nested_three_threads.rs | 15 +- tests/parallel/main.rs | 3 +- 27 files changed, 2093 insertions(+), 561 deletions(-) create mode 100644 tests/parallel/cycle_nested_deep_panic.rs diff --git a/Cargo.toml b/Cargo.toml index cc1cd0347..9c419e339 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,7 +22,7 @@ intrusive-collections = "0.9.7" parking_lot = "0.12" portable-atomic = "1" rustc-hash = "2" -smallvec = "1" +smallvec = { version = "1", features = 
["const_new"] } thin-vec = { version = "0.2.14" } tracing = { version = "0.1", default-features = false, features = ["std"] } diff --git a/src/active_query.rs b/src/active_query.rs index 0b2231052..d830fece1 100644 --- a/src/active_query.rs +++ b/src/active_query.rs @@ -498,7 +498,7 @@ impl fmt::Display for Backtrace { if full { write!(fmt, " -> ({changed_at:?}, {durability:#?}")?; if !cycle_heads.is_empty() || !iteration_count.is_initial() { - write!(fmt, ", iteration = {iteration_count:?}")?; + write!(fmt, ", iteration = {iteration_count}")?; } write!(fmt, ")")?; } @@ -517,7 +517,7 @@ impl fmt::Display for Backtrace { } write!( fmt, - "{:?} -> {:?}", + "{:?} -> iteration = {}", head.database_key_index, head.iteration_count )?; } diff --git a/src/cancelled.rs b/src/cancelled.rs index 2f2f315d9..3c31bae5a 100644 --- a/src/cancelled.rs +++ b/src/cancelled.rs @@ -20,6 +20,7 @@ pub enum Cancelled { } impl Cancelled { + #[cold] pub(crate) fn throw(self) -> ! { // We use resume and not panic here to avoid running the panic // hook (that is, to avoid collecting and printing backtrace). diff --git a/src/cycle.rs b/src/cycle.rs index 12cb1cdc9..c9a9b82c1 100644 --- a/src/cycle.rs +++ b/src/cycle.rs @@ -44,14 +44,18 @@ //! result in a stable, converged cycle. If it does not (that is, if the result of another //! iteration of the cycle is not the same as the fallback value), we'll panic. //! -//! In nested cycle cases, the inner cycle head will iterate until its own cycle is resolved, but -//! the "final" value it then returns will still be provisional on the outer cycle head. The outer -//! cycle head may then iterate, which may result in a new set of iterations on the inner cycle, -//! for each iteration of the outer cycle. - +//! In nested cycle cases, the inner cycles are iterated as part of the outer cycle iteration. This helps +//! to significantly reduce the number of iterations needed to reach a fixpoint. For nested cycles, +//! 
the inner cycle heads will transfer their lock ownership to the outer cycle. This ensures +//! that, over time, the outer cycle will hold all necessary locks to complete the fixpoint iteration. +//! Without this, different threads would compete for the locks of inner cycle heads, leading to potential +//! hangs (but not deadlocks). + +use std::iter::FusedIterator; + use thin_vec::{thin_vec, ThinVec}; use crate::key::DatabaseKeyIndex; +use crate::sync::atomic::{AtomicBool, AtomicU8, Ordering}; use crate::sync::OnceLock; use crate::Revision; @@ -96,14 +100,47 @@ pub enum CycleRecoveryStrategy { /// would be the cycle head. It returns an "initial value" when the cycle is encountered (if /// fixpoint iteration is enabled for that query), and then is responsible for re-iterating the /// cycle until it converges. -#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +#[derive(Debug)] #[cfg_attr(feature = "persistence", derive(serde::Serialize, serde::Deserialize))] pub struct CycleHead { pub(crate) database_key_index: DatabaseKeyIndex, - pub(crate) iteration_count: IterationCount, + pub(crate) iteration_count: AtomicIterationCount, + + /// Marks a cycle head as removed within its `CycleHeads` container. + /// + /// Cycle heads are marked as removed when the memo from the last iteration (a provisional memo) + /// is used as the initial value for the next iteration. It's necessary to remove all but its own + /// head from the `CycleHeads` container, because the query might now depend on fewer cycles + /// (in case of conditional dependencies). However, we can't actually remove the cycle head + /// within `fetch_cold_cycle` because we only have a readonly memo. That's what `removed` is used for. 
+ #[cfg_attr(feature = "persistence", serde(skip))] + removed: AtomicBool, +} + +impl CycleHead { + pub const fn new( + database_key_index: DatabaseKeyIndex, + iteration_count: IterationCount, + ) -> Self { + Self { + database_key_index, + iteration_count: AtomicIterationCount(AtomicU8::new(iteration_count.0)), + removed: AtomicBool::new(false), + } + } } -#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Default)] +impl Clone for CycleHead { + fn clone(&self) -> Self { + Self { + database_key_index: self.database_key_index, + iteration_count: self.iteration_count.load().into(), + removed: self.removed.load(Ordering::Relaxed).into(), + } + } +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Default, PartialOrd, Ord)] #[cfg_attr(feature = "persistence", derive(serde::Serialize, serde::Deserialize))] #[cfg_attr(feature = "persistence", serde(transparent))] pub struct IterationCount(u8); @@ -131,11 +168,69 @@ impl IterationCount { } } +impl std::fmt::Display for IterationCount { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} + +#[derive(Debug)] +pub(crate) struct AtomicIterationCount(AtomicU8); + +impl AtomicIterationCount { + pub(crate) fn load(&self) -> IterationCount { + IterationCount(self.0.load(Ordering::Relaxed)) + } + + pub(crate) fn load_mut(&mut self) -> IterationCount { + IterationCount(*self.0.get_mut()) + } + + pub(crate) fn store(&self, value: IterationCount) { + self.0.store(value.0, Ordering::Release); + } + + pub(crate) fn store_mut(&mut self, value: IterationCount) { + *self.0.get_mut() = value.0; + } +} + +impl std::fmt::Display for AtomicIterationCount { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.load().fmt(f) + } +} + +impl From for AtomicIterationCount { + fn from(iteration_count: IterationCount) -> Self { + AtomicIterationCount(iteration_count.0.into()) + } +} + +#[cfg(feature = "persistence")] +impl serde::Serialize for AtomicIterationCount { + fn 
serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + self.load().serialize(serializer) + } +} + +#[cfg(feature = "persistence")] +impl<'de> serde::Deserialize<'de> for AtomicIterationCount { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + IterationCount::deserialize(deserializer).map(Into::into) + } +} + /// Any provisional value generated by any query in a cycle will track the cycle head(s) (can be /// plural in case of nested cycles) representing the cycles it is part of, and the current /// iteration count for each cycle head. This struct tracks these cycle heads. #[derive(Clone, Debug, Default)] -#[cfg_attr(feature = "persistence", derive(serde::Serialize, serde::Deserialize))] pub struct CycleHeads(ThinVec); impl CycleHeads { @@ -143,15 +238,30 @@ impl CycleHeads { self.0.is_empty() } - pub(crate) fn initial(database_key_index: DatabaseKeyIndex) -> Self { + pub(crate) fn initial( + database_key_index: DatabaseKeyIndex, + iteration_count: IterationCount, + ) -> Self { Self(thin_vec![CycleHead { database_key_index, - iteration_count: IterationCount::initial(), + iteration_count: iteration_count.into(), + removed: false.into() }]) } - pub(crate) fn iter(&self) -> std::slice::Iter<'_, CycleHead> { - self.0.iter() + pub(crate) fn iter(&self) -> CycleHeadsIterator<'_> { + CycleHeadsIterator { + inner: self.0.iter(), + } + } + + /// Iterates over all cycle heads that aren't equal to `own`. 
+ pub(crate) fn iter_not_eq( + &self, + own: DatabaseKeyIndex, + ) -> impl DoubleEndedIterator { + self.iter() + .filter(move |head| head.database_key_index != own) } pub(crate) fn contains(&self, value: &DatabaseKeyIndex) -> bool { @@ -159,17 +269,25 @@ impl CycleHeads { .any(|head| head.database_key_index == *value) } - pub(crate) fn remove(&mut self, value: &DatabaseKeyIndex) -> bool { - let found = self - .0 - .iter() - .position(|&head| head.database_key_index == *value); - let Some(found) = found else { return false }; - self.0.swap_remove(found); - true + /// Removes all cycle heads except `except` by marking them as removed. + /// + /// Note that the heads aren't actually removed. They're only marked as removed and will be + /// skipped when iterating. This is because we might not have a mutable reference. + pub(crate) fn remove_all_except(&self, except: DatabaseKeyIndex) { + for head in self.0.iter() { + if head.database_key_index == except { + continue; + } + + head.removed.store(true, Ordering::Release); + } } - pub(crate) fn update_iteration_count( + /// Updates the iteration count for the head `cycle_head_index` to `new_iteration_count`. + /// + /// Unlike [`update_iteration_count`], this method takes a `&mut self` reference. It should + /// be preferred if possible, as it avoids atomic operations. + pub(crate) fn update_iteration_count_mut( &mut self, cycle_head_index: DatabaseKeyIndex, new_iteration_count: IterationCount, @@ -179,7 +297,24 @@ impl CycleHeads { .iter_mut() .find(|cycle_head| cycle_head.database_key_index == cycle_head_index) { - cycle_head.iteration_count = new_iteration_count; + cycle_head.iteration_count.store_mut(new_iteration_count); + } + } + + /// Updates the iteration count for the head `cycle_head_index` to `new_iteration_count`. + /// + /// Unlike [`update_iteration_count_mut`], this method takes a `&self` reference. 
+ pub(crate) fn update_iteration_count( + &self, + cycle_head_index: DatabaseKeyIndex, + new_iteration_count: IterationCount, + ) { + if let Some(cycle_head) = self + .0 + .iter() + .find(|cycle_head| cycle_head.database_key_index == cycle_head_index) + { + cycle_head.iteration_count.store(new_iteration_count); } } @@ -188,15 +323,42 @@ impl CycleHeads { self.0.reserve(other.0.len()); for head in other { - if let Some(existing) = self - .0 - .iter() - .find(|candidate| candidate.database_key_index == head.database_key_index) - { - assert_eq!(existing.iteration_count, head.iteration_count); + debug_assert!(!head.removed.load(Ordering::Relaxed)); + self.insert(head.database_key_index, head.iteration_count.load()); + } + } + + pub(crate) fn insert( + &mut self, + database_key_index: DatabaseKeyIndex, + iteration_count: IterationCount, + ) -> bool { + if let Some(existing) = self + .0 + .iter_mut() + .find(|candidate| candidate.database_key_index == database_key_index) + { + let removed = existing.removed.get_mut(); + + if *removed { + *removed = false; + + true } else { - self.0.push(*head); + let existing_count = existing.iteration_count.load_mut(); + + assert_eq!( + existing_count, iteration_count, + "Can't merge cycle heads {:?} with different iteration counts ({existing_count:?}, {iteration_count:?})", + existing.database_key_index + ); + + false } + } else { + self.0 + .push(CycleHead::new(database_key_index, iteration_count)); + true } } @@ -206,6 +368,37 @@ impl CycleHeads { } } +#[cfg(feature = "persistence")] +impl serde::Serialize for CycleHeads { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + use serde::ser::SerializeSeq; + + let mut seq = serializer.serialize_seq(None)?; + for e in self { + if e.removed.load(Ordering::Relaxed) { + continue; + } + + seq.serialize_element(e)?; + } + seq.end() + } +} + +#[cfg(feature = "persistence")] +impl<'de> serde::Deserialize<'de> for CycleHeads { + fn deserialize(deserializer: D) -> 
Result + where + D: serde::Deserializer<'de>, + { + let vec: ThinVec = serde::Deserialize::deserialize(deserializer)?; + Ok(CycleHeads(vec)) + } +} + impl IntoIterator for CycleHeads { type Item = CycleHead; type IntoIter = as IntoIterator>::IntoIter; @@ -215,9 +408,44 @@ impl IntoIterator for CycleHeads { } } +pub struct CycleHeadsIterator<'a> { + inner: std::slice::Iter<'a, CycleHead>, +} + +impl<'a> Iterator for CycleHeadsIterator<'a> { + type Item = &'a CycleHead; + + fn next(&mut self) -> Option { + loop { + let next = self.inner.next()?; + + if next.removed.load(Ordering::Relaxed) { + continue; + } + + return Some(next); + } + } +} + +impl FusedIterator for CycleHeadsIterator<'_> {} +impl DoubleEndedIterator for CycleHeadsIterator<'_> { + fn next_back(&mut self) -> Option { + loop { + let next = self.inner.next_back()?; + + if next.removed.load(Ordering::Relaxed) { + continue; + } + + return Some(next); + } + } +} + impl<'a> std::iter::IntoIterator for &'a CycleHeads { type Item = &'a CycleHead; - type IntoIter = std::slice::Iter<'a, CycleHead>; + type IntoIter = CycleHeadsIterator<'a>; fn into_iter(self) -> Self::IntoIter { self.iter() @@ -248,21 +476,3 @@ pub enum ProvisionalStatus { }, FallbackImmediate, } - -impl ProvisionalStatus { - pub(crate) const fn iteration(&self) -> Option { - match self { - ProvisionalStatus::Provisional { iteration, .. } => Some(*iteration), - ProvisionalStatus::Final { iteration, .. } => Some(*iteration), - ProvisionalStatus::FallbackImmediate => None, - } - } - - pub(crate) const fn verified_at(&self) -> Option { - match self { - ProvisionalStatus::Provisional { verified_at, .. } => Some(*verified_at), - ProvisionalStatus::Final { verified_at, .. 
} => Some(*verified_at), - ProvisionalStatus::FallbackImmediate => None, - } - } -} diff --git a/src/function.rs b/src/function.rs index 58f773895..259dff14b 100644 --- a/src/function.rs +++ b/src/function.rs @@ -1,5 +1,5 @@ pub(crate) use maybe_changed_after::{VerifyCycleHeads, VerifyResult}; -pub(crate) use sync::SyncGuard; +pub(crate) use sync::{ClaimGuard, ClaimResult, Reentrancy, SyncGuard, SyncOwner, SyncTable}; use std::any::Any; use std::fmt; @@ -8,11 +8,11 @@ use std::sync::atomic::Ordering; use std::sync::OnceLock; use crate::cycle::{ - empty_cycle_heads, CycleHeads, CycleRecoveryAction, CycleRecoveryStrategy, ProvisionalStatus, + empty_cycle_heads, CycleHeads, CycleRecoveryAction, CycleRecoveryStrategy, IterationCount, + ProvisionalStatus, }; use crate::database::RawDatabase; use crate::function::delete::DeletedEntries; -use crate::function::sync::{ClaimResult, SyncTable}; use crate::hash::{FxHashSet, FxIndexSet}; use crate::ingredient::{Ingredient, WaitForResult}; use crate::key::DatabaseKeyIndex; @@ -92,7 +92,18 @@ pub trait Configuration: Any { /// Decide whether to iterate a cycle again or fallback. `value` is the provisional return /// value from the latest iteration of this cycle. `count` is the number of cycle iterations - /// we've already completed. + /// completed so far. + /// + /// # Iteration count semantics + /// + /// The `count` parameter isn't guaranteed to start from zero or to be contiguous: + /// + /// * **Initial value**: `count` may be non-zero on the first call for a given query if that + /// query becomes the outermost cycle head after a nested cycle completes a few iterations. In this case, + /// `count` continues from the nested cycle's iteration count rather than resetting to zero. + /// * **Non-contiguous values**: This function isn't called if this cycle is part of an outer cycle + /// and the value for this query remains unchanged for one iteration. 
But the outer cycle might + /// keep iterating because other heads keep changing. fn recover_from_cycle<'db>( db: &'db Self::DbView, value: &Self::Output<'db>, @@ -358,6 +369,41 @@ where }) } + fn set_cycle_iteration_count(&self, zalsa: &Zalsa, input: Id, iteration_count: IterationCount) { + let Some(memo) = + self.get_memo_from_table_for(zalsa, input, self.memo_ingredient_index(zalsa, input)) + else { + return; + }; + + memo.revisions + .set_iteration_count(Self::database_key_index(self, input), iteration_count); + } + + fn finalize_cycle_head(&self, zalsa: &Zalsa, input: Id) { + let Some(memo) = + self.get_memo_from_table_for(zalsa, input, self.memo_ingredient_index(zalsa, input)) + else { + return; + }; + + memo.revisions.verified_final.store(true, Ordering::Release); + } + + fn cycle_converged(&self, zalsa: &Zalsa, input: Id) -> bool { + let Some(memo) = + self.get_memo_from_table_for(zalsa, input, self.memo_ingredient_index(zalsa, input)) + else { + return true; + }; + + memo.revisions.cycle_converged() + } + + fn mark_as_transfer_target(&self, key_index: Id) -> Option { + self.sync_table.mark_as_transfer_target(key_index) + } + fn cycle_heads<'db>(&self, zalsa: &'db Zalsa, input: Id) -> &'db CycleHeads { self.get_memo_from_table_for(zalsa, input, self.memo_ingredient_index(zalsa, input)) .map(|memo| memo.cycle_heads()) @@ -372,9 +418,12 @@ where /// * [`WaitResult::Cycle`] Claiming the `key_index` results in a cycle because it's on the current's thread query stack or /// running on another thread that is blocked on this thread. 
fn wait_for<'me>(&'me self, zalsa: &'me Zalsa, key_index: Id) -> WaitForResult<'me> { - match self.sync_table.try_claim(zalsa, key_index) { + match self + .sync_table + .try_claim(zalsa, key_index, Reentrancy::Deny) + { ClaimResult::Running(blocked_on) => WaitForResult::Running(blocked_on), - ClaimResult::Cycle => WaitForResult::Cycle, + ClaimResult::Cycle { inner } => WaitForResult::Cycle { inner }, ClaimResult::Claimed(_) => WaitForResult::Available, } } @@ -435,10 +484,6 @@ where unreachable!("function does not allocate pages") } - fn cycle_recovery_strategy(&self) -> CycleRecoveryStrategy { - C::CYCLE_STRATEGY - } - #[cfg(feature = "accumulator")] unsafe fn accumulated<'db>( &'db self, diff --git a/src/function/execute.rs b/src/function/execute.rs index 9521a9dce..67f76e145 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -1,12 +1,18 @@ +use smallvec::SmallVec; + use crate::active_query::CompletedQuery; -use crate::cycle::{CycleRecoveryStrategy, IterationCount}; +use crate::cycle::{CycleHeads, CycleRecoveryStrategy, IterationCount}; use crate::function::memo::Memo; -use crate::function::{Configuration, IngredientImpl}; +use crate::function::sync::ReleaseMode; +use crate::function::{ClaimGuard, Configuration, IngredientImpl}; +use crate::ingredient::WaitForResult; use crate::plumbing::ZalsaLocal; use crate::sync::atomic::{AtomicBool, Ordering}; +use crate::sync::thread; use crate::tracked_struct::Identity; use crate::zalsa::{MemoIngredientIndex, Zalsa}; use crate::zalsa_local::{ActiveQueryGuard, QueryRevisions}; +use crate::{tracing, Cancelled}; use crate::{DatabaseKeyIndex, Event, EventKind, Id}; impl IngredientImpl @@ -26,12 +32,15 @@ where pub(super) fn execute<'db>( &'db self, db: &'db C::DbView, - zalsa: &'db Zalsa, + mut claim_guard: ClaimGuard<'db>, zalsa_local: &'db ZalsaLocal, - database_key_index: DatabaseKeyIndex, opt_old_memo: Option<&Memo<'db, C>>, ) -> &'db Memo<'db, C> { + let database_key_index = 
claim_guard.database_key_index(); + let zalsa = claim_guard.zalsa(); + let id = database_key_index.key_index(); + let memo_ingredient_index = self.memo_ingredient_index(zalsa, id); crate::tracing::info!("{:?}: executing query", database_key_index); @@ -40,7 +49,6 @@ where database_key: database_key_index, }) }); - let memo_ingredient_index = self.memo_ingredient_index(zalsa, id); let (new_value, mut completed_query) = match C::CYCLE_STRATEGY { CycleRecoveryStrategy::Panic => Self::execute_query( @@ -94,9 +102,8 @@ where CycleRecoveryStrategy::Fixpoint => self.execute_maybe_iterate( db, opt_old_memo, - zalsa, + &mut claim_guard, zalsa_local, - database_key_index, memo_ingredient_index, ), }; @@ -117,6 +124,7 @@ where // outputs and update the tracked struct IDs for seeding the next revision. self.diff_outputs(zalsa, database_key_index, old_memo, &completed_query); } + self.insert_memo( zalsa, id, @@ -133,25 +141,53 @@ where &'db self, db: &'db C::DbView, opt_old_memo: Option<&Memo<'db, C>>, - zalsa: &'db Zalsa, + claim_guard: &mut ClaimGuard<'db>, zalsa_local: &'db ZalsaLocal, - database_key_index: DatabaseKeyIndex, memo_ingredient_index: MemoIngredientIndex, ) -> (C::Output<'db>, CompletedQuery) { + claim_guard.set_release_mode(ReleaseMode::Default); + + let database_key_index = claim_guard.database_key_index(); + let zalsa = claim_guard.zalsa(); + let id = database_key_index.key_index(); - let mut iteration_count = IterationCount::initial(); - let mut active_query = zalsa_local.push_query(database_key_index, iteration_count); // Our provisional value from the previous iteration, when doing fixpoint iteration. - // Initially it's set to None, because the initial provisional value is created lazily, - // only when a cycle is actually encountered. - let mut opt_last_provisional: Option<&Memo<'db, C>> = None; + // This is different from `opt_old_memo` which might be from a different revision. 
+ let mut last_provisional_memo: Option<&Memo<'db, C>> = None; + + // TODO: Can we seed those somehow? let mut last_stale_tracked_ids: Vec<(Identity, Id)> = Vec::new(); - let _guard = ClearCycleHeadIfPanicking::new(self, zalsa, id, memo_ingredient_index); + let mut iteration_count = IterationCount::initial(); + + if let Some(old_memo) = opt_old_memo { + if old_memo.verified_at.load() == zalsa.current_revision() + && old_memo.cycle_heads().contains(&database_key_index) + { + let memo_iteration_count = old_memo.revisions.iteration(); + + // The `DependencyGraph` locking propagates panics when another thread is blocked on a panicking query. + // However, the locking doesn't handle the case where a thread fetches the result of a panicking + // cycle head query **after** all locks were released. That's what we do here. + // We could consider re-executing the entire cycle but: + // a) It's tricky to ensure that all queries participating in the cycle will re-execute + // (we can't rely on `iteration_count` being updated for nested cycles because the nested cycles may have completed successfully). + // b) It's guaranteed that this query will panic again anyway. + // That's why we simply propagate the panic here. It simplifies our lives and it also avoids duplicate panic messages. + if old_memo.value.is_none() { + tracing::warn!("Propagating panic for cycle head that panicked in an earlier execution in that revision"); + Cancelled::PropagatedPanic.throw(); + } + last_provisional_memo = Some(old_memo); + iteration_count = memo_iteration_count; + } + } - loop { - let previous_memo = opt_last_provisional.or(opt_old_memo); + let _poison_guard = + PoisonProvisionalIfPanicking::new(self, zalsa, id, memo_ingredient_index); + let mut active_query = zalsa_local.push_query(database_key_index, iteration_count); + let (new_value, completed_query) = loop { // Tracked struct ids that existed in the previous revision // but weren't recreated in the last iteration. 
It's important that we seed the next // query with these ids because the query might re-create them as part of the next iteration. @@ -160,118 +196,267 @@ where // if they aren't recreated when reaching the final iteration. active_query.seed_tracked_struct_ids(&last_stale_tracked_ids); - let (mut new_value, mut completed_query) = - Self::execute_query(db, zalsa, active_query, previous_memo); + let (mut new_value, mut completed_query) = Self::execute_query( + db, + zalsa, + active_query, + last_provisional_memo.or(opt_old_memo), + ); + + // If there are no cycle heads, break out of the loop (`cycle_heads_mut` returns `None` if the cycle head list is empty) + let Some(cycle_heads) = completed_query.revisions.cycle_heads_mut() else { + claim_guard.set_release_mode(ReleaseMode::SelfOnly); + break (new_value, completed_query); + }; + + // Take the cycle heads to not-fight-rust's-borrow-checker. + let mut cycle_heads = std::mem::take(cycle_heads); + let mut missing_heads: SmallVec<[(DatabaseKeyIndex, IterationCount); 1]> = + SmallVec::new_const(); + let mut max_iteration_count = iteration_count; + let mut depends_on_self = false; + + // Ensure that we resolve the latest cycle heads from any provisional value this query depended on during execution. + // This isn't required in a single-threaded execution, but it's not guaranteed that `cycle_heads` contains all cycles + // in a multi-threaded execution: + // + // t1: a -> b + // t2: c -> b (blocks on t1) + // t1: a -> b -> c (cycle, returns fixpoint initial with c(0) in heads) + // t1: a -> b (completes b, b has c(0) in its cycle heads, releases `b`, which resumes `t2`, and `retry_provisional` blocks on `c` (t2)) + // t2: c -> a (cycle, returns fixpoint initial for a with a(0) in heads) + // t2: completes c, `provisional_retry` blocks on `a` (t2) + // t1: a (completes `b` with `c` in heads) + // + // Note how `a` only depends on `c` but not `a`. 
This is because `a` only saw the initial value of `c` and wasn't updated when `c` completed. + // That's why we need to resolve the cycle heads recursively so `cycle_heads` contains all cycle heads at the moment this query completed. + for head in &cycle_heads { + max_iteration_count = max_iteration_count.max(head.iteration_count.load()); + depends_on_self |= head.database_key_index == database_key_index; + + let ingredient = + zalsa.lookup_ingredient(head.database_key_index.ingredient_index()); + + for nested_head in + ingredient.cycle_heads(zalsa, head.database_key_index.key_index()) + { + let nested_as_tuple = ( + nested_head.database_key_index, + nested_head.iteration_count.load(), + ); + + if !cycle_heads.contains(&nested_head.database_key_index) + && !missing_heads.contains(&nested_as_tuple) + { + missing_heads.push(nested_as_tuple); + } + } + } + + for (head_key, iteration_count) in missing_heads { + max_iteration_count = max_iteration_count.max(iteration_count); + depends_on_self |= head_key == database_key_index; + + cycle_heads.insert(head_key, iteration_count); + } + + let outer_cycle = outer_cycle(zalsa, zalsa_local, &cycle_heads, database_key_index); // Did the new result we got depend on our own provisional value, in a cycle? - if let Some(cycle_heads) = completed_query - .revisions - .cycle_heads_mut() - .filter(|cycle_heads| cycle_heads.contains(&database_key_index)) - { - let last_provisional_value = if let Some(last_provisional) = opt_last_provisional { - // We have a last provisional value from our previous time around the loop. - last_provisional.value.as_ref() + // If not, return because this query is not a cycle head. + if !depends_on_self { + // For as long as this query participates in any cycle, don't release its lock, instead + // transfer it to the outermost cycle head (if any). 
This prevents any other thread + // from claiming this query (all cycle heads are potential entry points to the same cycle), + // which would result in them competing for the same locks (we want the locks to converge to a single cycle head). + if let Some(outer_cycle) = outer_cycle { + claim_guard.set_release_mode(ReleaseMode::TransferTo(outer_cycle)); } else { - // This is our first time around the loop; a provisional value must have been - // inserted into the memo table when the cycle was hit, so let's pull our - // initial provisional value from there. - let memo = self - .get_memo_from_table_for(zalsa, id, memo_ingredient_index) - .filter(|memo| memo.verified_at.load() == zalsa.current_revision()) - .unwrap_or_else(|| { - unreachable!( - "{database_key_index:#?} is a cycle head, \ + claim_guard.set_release_mode(ReleaseMode::SelfOnly); + } + + completed_query.revisions.set_cycle_heads(cycle_heads); + break (new_value, completed_query); + } + + // Get the last provisional value for this query so that we can compare it with the new value + // to test if the cycle converged. + let last_provisional_value = if let Some(last_provisional) = last_provisional_memo { + // We have a last provisional value from our previous time around the loop. + last_provisional.value.as_ref() + } else { + // This is our first time around the loop; a provisional value must have been + // inserted into the memo table when the cycle was hit, so let's pull our + // initial provisional value from there. 
+ let memo = self + .get_memo_from_table_for(zalsa, id, memo_ingredient_index) + .unwrap_or_else(|| { + unreachable!( + "{database_key_index:#?} is a cycle head, \ but no provisional memo found" - ) - }); + ) + }); - debug_assert!(memo.may_be_provisional()); - memo.value.as_ref() - }; + debug_assert!(memo.may_be_provisional()); + memo.value.as_ref() + }; - let last_provisional_value = last_provisional_value.expect( - "`fetch_cold_cycle` should have inserted a provisional memo with Cycle::initial", - ); - crate::tracing::debug!( - "{database_key_index:?}: execute: \ - I am a cycle head, comparing last provisional value with new value" - ); - // If the new result is equal to the last provisional result, the cycle has - // converged and we are done. - if !C::values_equal(&new_value, last_provisional_value) { - // We are in a cycle that hasn't converged; ask the user's - // cycle-recovery function what to do: - match C::recover_from_cycle( - db, - &new_value, - iteration_count.as_u32(), - C::id_to_input(zalsa, id), - ) { - crate::CycleRecoveryAction::Iterate => {} - crate::CycleRecoveryAction::Fallback(fallback_value) => { - crate::tracing::debug!( - "{database_key_index:?}: execute: user cycle_fn says to fall back" - ); - new_value = fallback_value; - } - } - // `iteration_count` can't overflow as we check it against `MAX_ITERATIONS` - // which is less than `u32::MAX`. - iteration_count = iteration_count.increment().unwrap_or_else(|| { - tracing::warn!( - "{database_key_index:?}: execute: too many cycle iterations" + let last_provisional_value = last_provisional_value.expect( + "`fetch_cold_cycle` should have inserted a provisional memo with Cycle::initial", + ); + tracing::debug!( + "{database_key_index:?}: execute: \ + I am a cycle head, comparing last provisional value with new value" + ); + + let this_converged = C::values_equal(&new_value, last_provisional_value); + + // If this is the outermost cycle, use the maximum iteration count of all cycles. 
+ // This is important for when later iterations introduce new cycle heads (that then + // become the outermost cycle). We want to ensure that the iteration count keeps increasing + // for all queries or they won't be re-executed because `validate_same_iteration` would + // pass when we go from 1 -> 0 and then increment by 1 to 1). + iteration_count = if outer_cycle.is_none() { + max_iteration_count + } else { + // Otherwise keep the iteration count because outer cycles + // already have a cycle head with this exact iteration count (and we don't allow + // heads from different iterations). + iteration_count + }; + + if !this_converged { + // We are in a cycle that hasn't converged; ask the user's + // cycle-recovery function what to do: + match C::recover_from_cycle( + db, + &new_value, + iteration_count.as_u32(), + C::id_to_input(zalsa, id), + ) { + crate::CycleRecoveryAction::Iterate => {} + crate::CycleRecoveryAction::Fallback(fallback_value) => { + tracing::debug!( + "{database_key_index:?}: execute: user cycle_fn says to fall back" ); - panic!("{database_key_index:?}: execute: too many cycle iterations") - }); - zalsa.event(&|| { - Event::new(EventKind::WillIterateCycle { - database_key: database_key_index, - iteration_count, - }) - }); - cycle_heads.update_iteration_count(database_key_index, iteration_count); - completed_query - .revisions - .update_iteration_count(iteration_count); - crate::tracing::info!("{database_key_index:?}: execute: iterate again...",); - opt_last_provisional = Some(self.insert_memo( - zalsa, - id, - Memo::new( - Some(new_value), - zalsa.current_revision(), - completed_query.revisions, - ), - memo_ingredient_index, - )); - last_stale_tracked_ids = completed_query.stale_tracked_structs; - - active_query = zalsa_local.push_query(database_key_index, iteration_count); - - continue; + new_value = fallback_value; + } } - crate::tracing::debug!( - "{database_key_index:?}: execute: fixpoint iteration has a final value" + } + + if let 
Some(outer_cycle) = outer_cycle { + tracing::info!( + "Detected nested cycle {database_key_index:?}, iterate it as part of the outer cycle {outer_cycle:?}" ); - cycle_heads.remove(&database_key_index); - - if cycle_heads.is_empty() { - // If there are no more cycle heads, we can mark this as verified. - completed_query - .revisions - .verified_final - .store(true, Ordering::Relaxed); + + completed_query.revisions.set_cycle_heads(cycle_heads); + // Store whether this cycle has converged, so that the outer cycle can check it. + completed_query + .revisions + .set_cycle_converged(this_converged); + + // Transfer ownership of this query to the outer cycle, so that it can claim it + // and other threads don't compete for the same lock. + claim_guard.set_release_mode(ReleaseMode::TransferTo(outer_cycle)); + + break (new_value, completed_query); + } + + // If this is the outermost cycle, test if all inner cycles have converged as well. + let converged = this_converged + && cycle_heads.iter_not_eq(database_key_index).all(|head| { + let ingredient = + zalsa.lookup_ingredient(head.database_key_index.ingredient_index()); + + let converged = + ingredient.cycle_converged(zalsa, head.database_key_index.key_index()); + + if !converged { + tracing::debug!("inner cycle {database_key_index:?} has not converged"); + } + + converged + }); + + if converged { + tracing::debug!( + "{database_key_index:?}: execute: fixpoint iteration has a final value after {iteration_count:?} iterations" + ); + + // Set the nested cycles as verified. This is necessary because + // `validate_provisional` doesn't follow cycle heads recursively (and the memos now depend on all cycle heads). 
+ for head in cycle_heads.iter_not_eq(database_key_index) { + let ingredient = + zalsa.lookup_ingredient(head.database_key_index.ingredient_index()); + ingredient.finalize_cycle_head(zalsa, head.database_key_index.key_index()); } + + *completed_query.revisions.verified_final.get_mut() = true; + + break (new_value, completed_query); + } + + // The fixpoint iteration hasn't converged. Iterate again... + iteration_count = iteration_count.increment().unwrap_or_else(|| { + tracing::warn!("{database_key_index:?}: execute: too many cycle iterations"); + panic!("{database_key_index:?}: execute: too many cycle iterations") + }); + + zalsa.event(&|| { + Event::new(EventKind::WillIterateCycle { + database_key: database_key_index, + iteration_count, + }) + }); + + tracing::info!( + "{database_key_index:?}: execute: iterate again ({iteration_count:?})...", + ); + + // Update the iteration count of nested cycles. + for head in cycle_heads.iter_not_eq(database_key_index) { + let ingredient = + zalsa.lookup_ingredient(head.database_key_index.ingredient_index()); + + ingredient.set_cycle_iteration_count( + zalsa, + head.database_key_index.key_index(), + iteration_count, + ); } - crate::tracing::debug!( - "{database_key_index:?}: execute: result.revisions = {revisions:#?}", - revisions = &completed_query.revisions + // Update the iteration count of this cycle head, but only after restoring + // the cycle heads array (or this becomes a no-op). 
+ completed_query.revisions.set_cycle_heads(cycle_heads); + completed_query + .revisions + .update_iteration_count_mut(database_key_index, iteration_count); + + let new_memo = self.insert_memo( + zalsa, + id, + Memo::new( + Some(new_value), + zalsa.current_revision(), + completed_query.revisions, + ), + memo_ingredient_index, ); - break (new_value, completed_query); - } + last_provisional_memo = Some(new_memo); + + last_stale_tracked_ids = completed_query.stale_tracked_structs; + active_query = zalsa_local.push_query(database_key_index, iteration_count); + + continue; + }; + + tracing::debug!( + "{database_key_index:?}: execute_maybe_iterate: result.revisions = {revisions:#?}", + revisions = &completed_query.revisions + ); + + (new_value, completed_query) } #[inline] @@ -325,14 +510,14 @@ where /// a new fix point initial value if that happens. /// /// We could insert a fixpoint initial value here, but it seems unnecessary. -struct ClearCycleHeadIfPanicking<'a, C: Configuration> { +struct PoisonProvisionalIfPanicking<'a, C: Configuration> { ingredient: &'a IngredientImpl, zalsa: &'a Zalsa, id: Id, memo_ingredient_index: MemoIngredientIndex, } -impl<'a, C: Configuration> ClearCycleHeadIfPanicking<'a, C> { +impl<'a, C: Configuration> PoisonProvisionalIfPanicking<'a, C> { fn new( ingredient: &'a IngredientImpl, zalsa: &'a Zalsa, @@ -348,9 +533,9 @@ impl<'a, C: Configuration> ClearCycleHeadIfPanicking<'a, C> { } } -impl Drop for ClearCycleHeadIfPanicking<'_, C> { +impl Drop for PoisonProvisionalIfPanicking<'_, C> { fn drop(&mut self) { - if std::thread::panicking() { + if thread::panicking() { let revisions = QueryRevisions::fixpoint_initial(self.ingredient.database_key_index(self.id)); @@ -360,3 +545,44 @@ impl Drop for ClearCycleHeadIfPanicking<'_, C> { } } } + +/// Returns the key of any potential outer cycle head or `None` if there is no outer cycle. 
+/// +/// That is, any query that's currently blocked on the result computed by this query (claiming it results in a cycle). +fn outer_cycle( + zalsa: &Zalsa, + zalsa_local: &ZalsaLocal, + cycle_heads: &CycleHeads, + current_key: DatabaseKeyIndex, +) -> Option { + // First, look for the outer most cycle head on the same thread. + // Using the outer most over the inner most should reduce the need + // for transitive transfers. + // SAFETY: We don't call into with_query_stack recursively + if let Some(same_thread) = unsafe { + zalsa_local.with_query_stack_unchecked(|stack| { + stack + .iter() + .find(|active_query| { + cycle_heads.contains(&active_query.database_key_index) + && active_query.database_key_index != current_key + }) + .map(|active_query| active_query.database_key_index) + }) + } { + return Some(same_thread); + } + + // Check for any outer cycle head running on a different thread. + cycle_heads + .iter_not_eq(current_key) + .rfind(|head| { + let ingredient = zalsa.lookup_ingredient(head.database_key_index.ingredient_index()); + + matches!( + ingredient.wait_for(zalsa, head.database_key_index.key_index()), + WaitForResult::Cycle { inner: false } + ) + }) + .map(|head| head.database_key_index) +} diff --git a/src/function/fetch.rs b/src/function/fetch.rs index a1b6658f6..ef42708a7 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -4,7 +4,7 @@ use crate::cycle::{CycleHeads, CycleRecoveryStrategy, IterationCount}; use crate::function::maybe_changed_after::VerifyCycleHeads; use crate::function::memo::Memo; use crate::function::sync::ClaimResult; -use crate::function::{Configuration, IngredientImpl}; +use crate::function::{Configuration, IngredientImpl, Reentrancy}; use crate::zalsa::{MemoIngredientIndex, Zalsa}; use crate::zalsa_local::{QueryRevisions, ZalsaLocal}; use crate::{DatabaseKeyIndex, Id}; @@ -13,6 +13,7 @@ impl IngredientImpl where C: Configuration, { + #[inline] pub fn fetch<'db>( &'db self, db: &'db C::DbView, @@ -57,11 +58,19 @@ 
where id: Id, ) -> &'db Memo<'db, C> { let memo_ingredient_index = self.memo_ingredient_index(zalsa, id); + let mut retry_count = 0; loop { if let Some(memo) = self .fetch_hot(zalsa, id, memo_ingredient_index) .or_else(|| { - self.fetch_cold_with_retry(zalsa, zalsa_local, db, id, memo_ingredient_index) + self.fetch_cold_with_retry( + zalsa, + zalsa_local, + db, + id, + memo_ingredient_index, + &mut retry_count, + ) }) { return memo; @@ -95,7 +104,6 @@ where } } - #[inline(never)] fn fetch_cold_with_retry<'db>( &'db self, zalsa: &'db Zalsa, @@ -103,6 +111,7 @@ where db: &'db C::DbView, id: Id, memo_ingredient_index: MemoIngredientIndex, + retry_count: &mut u32, ) -> Option<&'db Memo<'db, C>> { let memo = self.fetch_cold(zalsa, zalsa_local, db, id, memo_ingredient_index)?; @@ -114,7 +123,7 @@ where // That is only correct for fixpoint cycles, though: `FallbackImmediate` cycles // never have provisional entries. if C::CYCLE_STRATEGY == CycleRecoveryStrategy::FallbackImmediate - || !memo.provisional_retry(zalsa, zalsa_local, self.database_key_index(id)) + || !memo.provisional_retry(zalsa, zalsa_local, self.database_key_index(id), retry_count) { Some(memo) } else { @@ -132,21 +141,21 @@ where ) -> Option<&'db Memo<'db, C>> { let database_key_index = self.database_key_index(id); // Try to claim this query: if someone else has claimed it already, go back and start again. 
- let claim_guard = match self.sync_table.try_claim(zalsa, id) { + let claim_guard = match self.sync_table.try_claim(zalsa, id, Reentrancy::Allow) { ClaimResult::Claimed(guard) => guard, ClaimResult::Running(blocked_on) => { blocked_on.block_on(zalsa); - let memo = self.get_memo_from_table_for(zalsa, id, memo_ingredient_index); + if C::CYCLE_STRATEGY == CycleRecoveryStrategy::FallbackImmediate { + let memo = self.get_memo_from_table_for(zalsa, id, memo_ingredient_index); - if let Some(memo) = memo { - // This isn't strictly necessary, but if this is a provisional memo for an inner cycle, - // await all outer cycle heads to give the thread driving it a chance to complete - // (we don't want multiple threads competing for the queries participating in the same cycle). - if memo.value.is_some() && memo.may_be_provisional() { - memo.block_on_heads(zalsa, zalsa_local); + if let Some(memo) = memo { + if memo.value.is_some() { + memo.block_on_heads(zalsa, zalsa_local); + } } } + return None; } ClaimResult::Cycle { .. } => { @@ -200,39 +209,10 @@ where // still valid for the current revision. return unsafe { Some(self.extend_memo_lifetime(old_memo)) }; } - - // If this is a provisional memo from the same revision, await all its cycle heads because - // we need to ensure that only one thread is iterating on a cycle at a given time. - // For example, if we have a nested cycle like so: - // ``` - // a -> b -> c -> b - // -> a - // - // d -> b - // ``` - // thread 1 calls `a` and `a` completes the inner cycle `b -> c` but hasn't finished the outer cycle `a` yet. - // thread 2 now calls `b`. We don't want that thread 2 iterates `b` while thread 1 is iterating `a` at the same time - // because it can result in thread b overriding provisional memos that thread a has accessed already and still relies upon. - // - // By waiting, we ensure that thread 1 completes a (based on a provisional value for `b`) and `b` - // becomes the new outer cycle, which thread 2 drives to completion. 
- if old_memo.may_be_provisional() - && old_memo.verified_at.load() == zalsa.current_revision() - { - // Try to claim all cycle heads of the provisional memo. If we can't because - // some head is running on another thread, drop our claim guard to give that thread - // a chance to take ownership of this query and complete it as part of its fixpoint iteration. - // We will then block on the cycle head and retry once all cycle heads completed. - if !old_memo.try_claim_heads(zalsa, zalsa_local) { - drop(claim_guard); - old_memo.block_on_heads(zalsa, zalsa_local); - return None; - } - } } } - let memo = self.execute(db, zalsa, zalsa_local, database_key_index, opt_old_memo); + let memo = self.execute(db, claim_guard, zalsa_local, opt_old_memo); Some(memo) } @@ -257,6 +237,19 @@ where let can_shallow_update = self.shallow_verify_memo(zalsa, database_key_index, memo); if can_shallow_update.yes() { self.update_shallow(zalsa, database_key_index, memo, can_shallow_update); + + if C::CYCLE_STRATEGY == CycleRecoveryStrategy::Fixpoint { + memo.revisions + .cycle_heads() + .remove_all_except(database_key_index); + } + + crate::tracing::debug!( + "hit cycle at {database_key_index:#?}, \ + returning last provisional value: {:#?}", + memo.revisions + ); + // SAFETY: memo is present in memo_map. return unsafe { self.extend_memo_lifetime(memo) }; } @@ -299,7 +292,10 @@ where let mut completed_query = active_query.pop(); completed_query .revisions - .set_cycle_heads(CycleHeads::initial(database_key_index)); + .set_cycle_heads(CycleHeads::initial( + database_key_index, + IterationCount::initial(), + )); // We need this for `cycle_heads()` to work. We will unset this in the outer `execute()`. 
*completed_query.revisions.verified_final.get_mut() = false; self.insert_memo( diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index 4f69655cd..698285055 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -2,10 +2,10 @@ use rustc_hash::FxHashMap; #[cfg(feature = "accumulator")] use crate::accumulator::accumulated_map::InputAccumulatedValues; -use crate::cycle::{CycleRecoveryStrategy, ProvisionalStatus}; -use crate::function::memo::Memo; +use crate::cycle::{CycleHeads, CycleRecoveryStrategy, ProvisionalStatus}; +use crate::function::memo::{Memo, TryClaimCycleHeadsIter, TryClaimHeadsResult}; use crate::function::sync::ClaimResult; -use crate::function::{Configuration, IngredientImpl}; +use crate::function::{Configuration, IngredientImpl, Reentrancy}; use crate::key::DatabaseKeyIndex; use crate::sync::atomic::Ordering; @@ -141,7 +141,10 @@ where ) -> Option { let database_key_index = self.database_key_index(key_index); - let _claim_guard = match self.sync_table.try_claim(zalsa, key_index) { + let claim_guard = match self + .sync_table + .try_claim(zalsa, key_index, Reentrancy::Deny) + { ClaimResult::Claimed(guard) => guard, ClaimResult::Running(blocked_on) => { blocked_on.block_on(zalsa); @@ -175,10 +178,8 @@ where // If `validate_maybe_provisional` returns `true`, but only because all cycle heads are from the same iteration, // carry over the cycle heads so that the caller verifies them. - if old_memo.may_be_provisional() { - for head in old_memo.cycle_heads() { - cycle_heads.insert_head(head.database_key_index); - } + for head in old_memo.cycle_heads() { + cycle_heads.insert_head(head.database_key_index); } return Some(if old_memo.revisions.changed_at > revision { @@ -227,7 +228,7 @@ where // `in_cycle` tracks if the enclosing query is in a cycle. 
`deep_verify.cycle_heads` tracks // if **this query** encountered a cycle (which means there's some provisional value somewhere floating around). if old_memo.value.is_some() && !cycle_heads.has_any() { - let memo = self.execute(db, zalsa, zalsa_local, database_key_index, Some(old_memo)); + let memo = self.execute(db, claim_guard, zalsa_local, Some(old_memo)); let changed_at = memo.revisions.changed_at; // Always assume that a provisional value has changed. @@ -323,12 +324,11 @@ where } let last_changed = zalsa.last_changed_revision(memo.revisions.durability); - crate::tracing::debug!( - "{database_key_index:?}: check_durability(memo = {memo:#?}, last_changed={:?} <= verified_at={:?}) = {:?}", + crate::tracing::trace!( + "{database_key_index:?}: check_durability({database_key_index:#?}, last_changed={:?} <= verified_at={:?}) = {:?}", last_changed, verified_at, last_changed <= verified_at, - memo = memo.tracing_debug() ); if last_changed <= verified_at { // No input of the suitable durability has changed since last verified. @@ -365,28 +365,48 @@ where database_key_index: DatabaseKeyIndex, memo: &Memo<'_, C>, ) -> bool { - !memo.may_be_provisional() - || self.validate_provisional(zalsa, database_key_index, memo) - || self.validate_same_iteration(zalsa, zalsa_local, database_key_index, memo) + if !memo.may_be_provisional() { + return true; + } + + let cycle_heads = memo.cycle_heads(); + + if cycle_heads.is_empty() { + return true; + } + + crate::tracing::trace!( + "{database_key_index:?}: validate_may_be_provisional(memo = {memo:#?})", + memo = memo.tracing_debug() + ); + + let verified_at = memo.verified_at.load(); + + self.validate_provisional(zalsa, database_key_index, memo, verified_at, cycle_heads) + || self.validate_same_iteration( + zalsa, + zalsa_local, + database_key_index, + verified_at, + cycle_heads, + ) } /// Check if this memo's cycle heads have all been finalized. If so, mark it verified final and /// return true, if not return false. 
- #[inline] fn validate_provisional( &self, zalsa: &Zalsa, database_key_index: DatabaseKeyIndex, memo: &Memo<'_, C>, + memo_verified_at: Revision, + cycle_heads: &CycleHeads, ) -> bool { crate::tracing::trace!( - "{database_key_index:?}: validate_provisional(memo = {memo:#?})", - memo = memo.tracing_debug() + "{database_key_index:?}: validate_provisional({database_key_index:?})", ); - let memo_verified_at = memo.verified_at.load(); - - for cycle_head in memo.revisions.cycle_heads() { + for cycle_head in cycle_heads { // Test if our cycle heads (with the same revision) are now finalized. let Some(kind) = zalsa .lookup_ingredient(cycle_head.database_key_index.ingredient_index()) @@ -413,7 +433,7 @@ where // // If we don't account for the iteration, then `a` (from iteration 0) will be finalized // because its cycle head `b` is now finalized, but `b` never pulled `a` in the last iteration. - if iteration != cycle_head.iteration_count { + if iteration != cycle_head.iteration_count.load() { return false; } @@ -449,92 +469,61 @@ where &self, zalsa: &Zalsa, zalsa_local: &ZalsaLocal, - database_key_index: DatabaseKeyIndex, - memo: &Memo<'_, C>, + memo_database_key_index: DatabaseKeyIndex, + memo_verified_at: Revision, + cycle_heads: &CycleHeads, ) -> bool { - crate::tracing::trace!( - "{database_key_index:?}: validate_same_iteration(memo = {memo:#?})", - memo = memo.tracing_debug() - ); - - let cycle_heads = memo.revisions.cycle_heads(); - if cycle_heads.is_empty() { - return true; - } - - let verified_at = memo.verified_at.load(); + crate::tracing::trace!("validate_same_iteration({memo_database_key_index:?})",); // This is an optimization to avoid unnecessary re-execution within the same revision. // Don't apply it when verifying memos from past revisions. We want them to re-execute // to verify their cycle heads and all participating queries. 
- if verified_at != zalsa.current_revision() { + if memo_verified_at != zalsa.current_revision() { return false; } - // SAFETY: We do not access the query stack reentrantly. - unsafe { - zalsa_local.with_query_stack_unchecked(|stack| { - cycle_heads.iter().all(|cycle_head| { + // Always return `false` for cycle initial values "unless" they are running in the same thread. + if cycle_heads + .iter() + .all(|head| head.database_key_index == memo_database_key_index) + { + // SAFETY: We do not access the query stack reentrantly. + let on_stack = unsafe { + zalsa_local.with_query_stack_unchecked(|stack| { stack .iter() .rev() - .find(|query| query.database_key_index == cycle_head.database_key_index) - .map(|query| query.iteration_count()) - .or_else(|| { - // If the cycle head isn't on our stack because: - // - // * another thread holds the lock on the cycle head (but it waits for the current query to complete) - // * we're in `maybe_changed_after` because `maybe_changed_after` doesn't modify the cycle stack - // - // check if the latest memo has the same iteration count. - - // However, we've to be careful to skip over fixpoint initial values: - // If the head is the memo we're trying to validate, always return `None` - // to force a re-execution of the query. This is necessary because the query - // has obviously not completed its iteration yet. - // - // This should be rare but the `cycle_panic` test fails on some platforms (mainly GitHub actions) - // without this check. What happens there is that: - // - // * query a blocks on query b - // * query b tries to claim a, fails to do so and inserts the fixpoint initial value - // * query b completes and has `a` as head. It returns its query result Salsa blocks query b from - // exiting inside `block_on` (or the thread would complete before the cycle iteration is complete) - // * query a resumes but panics because of the fixpoint iteration function - // * query b resumes. 
It rexecutes its own query which then tries to fetch a (which depends on itself because it's a fixpoint initial value). - // Without this check, `validate_same_iteration` would return `true` because the latest memo for `a` is the fixpoint initial value. - // But it should return `false` so that query b's thread re-executes `a` (which then also causes the panic). - // - // That's why we always return `None` if the cycle head is the same as the current database key index. - if cycle_head.database_key_index == database_key_index { - return None; - } + .any(|query| query.database_key_index == memo_database_key_index) + }) + }; - let ingredient = zalsa.lookup_ingredient( - cycle_head.database_key_index.ingredient_index(), - ); - let wait_result = ingredient - .wait_for(zalsa, cycle_head.database_key_index.key_index()); + return on_stack; + } - if !wait_result.is_cycle() { - return None; - } + let cycle_heads_iter = TryClaimCycleHeadsIter::new(zalsa, zalsa_local, cycle_heads); - let provisional_status = ingredient.provisional_status( - zalsa, - cycle_head.database_key_index.key_index(), - )?; + for cycle_head in cycle_heads_iter { + match cycle_head { + TryClaimHeadsResult::Cycle { + head_iteration_count, + memo_iteration_count: current_iteration_count, + verified_at: head_verified_at, + } => { + if head_verified_at != memo_verified_at { + return false; + } - if provisional_status.verified_at() == Some(verified_at) { - provisional_status.iteration() - } else { - None - } - }) - == Some(cycle_head.iteration_count) - }) - }) + if head_iteration_count != current_iteration_count { + return false; + } + } + _ => { + return false; + } + } } + + true } /// VerifyResult::Unchanged if the memo's value and `changed_at` time is up-to-date in the @@ -553,6 +542,12 @@ where cycle_heads: &mut VerifyCycleHeads, can_shallow_update: ShallowUpdate, ) -> VerifyResult { + // If the value is from the same revision but is still provisional, consider it changed + // because we're now in a new 
iteration. + if can_shallow_update == ShallowUpdate::Verified && old_memo.may_be_provisional() { + return VerifyResult::changed(); + } + crate::tracing::debug!( "{database_key_index:?}: deep_verify_memo(old_memo = {old_memo:#?})", old_memo = old_memo.tracing_debug() @@ -562,12 +557,6 @@ where match old_memo.revisions.origin.as_ref() { QueryOriginRef::Derived(edges) => { - // If the value is from the same revision but is still provisional, consider it changed - // because we're now in a new iteration. - if can_shallow_update == ShallowUpdate::Verified && old_memo.may_be_provisional() { - return VerifyResult::changed(); - } - #[cfg(feature = "accumulator")] let mut inputs = InputAccumulatedValues::Empty; let mut child_cycle_heads = Vec::new(); diff --git a/src/function/memo.rs b/src/function/memo.rs index 793f4832a..302ca73c3 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -3,10 +3,11 @@ use std::fmt::{Debug, Formatter}; use std::mem::transmute; use std::ptr::NonNull; -use crate::cycle::{empty_cycle_heads, CycleHead, CycleHeads, IterationCount, ProvisionalStatus}; +use crate::cycle::{ + empty_cycle_heads, CycleHeads, CycleHeadsIterator, IterationCount, ProvisionalStatus, +}; use crate::function::{Configuration, IngredientImpl}; -use crate::hash::FxHashSet; -use crate::ingredient::{Ingredient, WaitForResult}; +use crate::ingredient::WaitForResult; use crate::key::DatabaseKeyIndex; use crate::revision::AtomicRevision; use crate::runtime::Running; @@ -143,21 +144,23 @@ impl<'db, C: Configuration> Memo<'db, C> { zalsa: &Zalsa, zalsa_local: &ZalsaLocal, database_key_index: DatabaseKeyIndex, + retry_count: &mut u32, ) -> bool { - if self.revisions.cycle_heads().is_empty() { - return false; - } - - if !self.may_be_provisional() { - return false; - }; - if self.block_on_heads(zalsa, zalsa_local) { // If we get here, we are a provisional value of // the cycle head (either initial value, or from a later iteration) and should be // returned to caller to allow 
fixpoint iteration to proceed. false } else { + assert!( + *retry_count <= 20000, + "Provisional memo retry limit exceeded for {database_key_index:?}; \ + this usually indicates a bug in salsa's cycle caching/locking. \ + (retried {retry_count} times)", + ); + + *retry_count += 1; + // all our cycle heads are complete; re-fetch // and we should get a non-provisional memo. crate::tracing::debug!( @@ -176,33 +179,50 @@ impl<'db, C: Configuration> Memo<'db, C> { // IMPORTANT: If you make changes to this function, make sure to run `cycle_nested_deep` with // shuttle with at least 10k iterations. - // The most common case is that the entire cycle is running in the same thread. - // If that's the case, short circuit and return `true` immediately. - if self.all_cycles_on_stack(zalsa_local) { + let cycle_heads = self.cycle_heads(); + if cycle_heads.is_empty() { return true; } - // Otherwise, await all cycle heads, recursively. - return block_on_heads_cold(zalsa, self.cycle_heads()); + return block_on_heads_cold(zalsa, zalsa_local, cycle_heads); #[inline(never)] - fn block_on_heads_cold(zalsa: &Zalsa, heads: &CycleHeads) -> bool { + fn block_on_heads_cold( + zalsa: &Zalsa, + zalsa_local: &ZalsaLocal, + heads: &CycleHeads, + ) -> bool { let _entered = crate::tracing::debug_span!("block_on_heads").entered(); - let mut cycle_heads = TryClaimCycleHeadsIter::new(zalsa, heads); + let cycle_heads = TryClaimCycleHeadsIter::new(zalsa, zalsa_local, heads); let mut all_cycles = true; - while let Some(claim_result) = cycle_heads.next() { + for claim_result in cycle_heads { match claim_result { - TryClaimHeadsResult::Cycle => {} - TryClaimHeadsResult::Finalized => { - all_cycles = false; + TryClaimHeadsResult::Cycle { + memo_iteration_count: current_iteration_count, + head_iteration_count, + .. + } => { + // We need to refetch if the head now has a new iteration count. 
+ // This is to avoid a race between thread A and B: + // * thread A is in `blocks_on` (`retry_provisional`) for the memo `c`. It owns the lock for `e` + // * thread B owns `d` and calls `c`. `c` didn't depend on `e` in the first iteration. + // Thread B completes the first iteration (which bumps the iteration count on `c`). + // `c` now depends on E in the second iteration, introducing a new cycle head. + // Thread B transfers ownership of `c` to thread A (which awakes A). + // * Thread A now continues, there are no other cycle heads, so all queries result in a cycle. + // However, `d` has now a new iteration count, so it's important that we refetch `c`. + + if current_iteration_count != head_iteration_count { + all_cycles = false; + } } TryClaimHeadsResult::Available => { all_cycles = false; } TryClaimHeadsResult::Running(running) => { all_cycles = false; - running.block_on(&mut cycle_heads); + running.block_on(zalsa); } } } @@ -211,51 +231,6 @@ impl<'db, C: Configuration> Memo<'db, C> { } } - /// Tries to claim all cycle heads to see if they're finalized or available. - /// - /// Unlike `block_on_heads`, this code does not block on any cycle head. Instead it returns `false` if - /// claiming all cycle heads failed because one of them is running on another thread. 
- pub(super) fn try_claim_heads(&self, zalsa: &Zalsa, zalsa_local: &ZalsaLocal) -> bool { - let _entered = crate::tracing::debug_span!("try_claim_heads").entered(); - if self.all_cycles_on_stack(zalsa_local) { - return true; - } - - let cycle_heads = TryClaimCycleHeadsIter::new(zalsa, self.revisions.cycle_heads()); - - for claim_result in cycle_heads { - match claim_result { - TryClaimHeadsResult::Cycle - | TryClaimHeadsResult::Finalized - | TryClaimHeadsResult::Available => {} - TryClaimHeadsResult::Running(_) => { - return false; - } - } - } - - true - } - - fn all_cycles_on_stack(&self, zalsa_local: &ZalsaLocal) -> bool { - let cycle_heads = self.revisions.cycle_heads(); - if cycle_heads.is_empty() { - return true; - } - - // SAFETY: We do not access the query stack reentrantly. - unsafe { - zalsa_local.with_query_stack_unchecked(|stack| { - cycle_heads.iter().all(|cycle_head| { - stack - .iter() - .rev() - .any(|query| query.database_key_index == cycle_head.database_key_index) - }) - }) - } - } - /// Cycle heads that should be propagated to dependent queries. #[inline(always)] pub(super) fn cycle_heads(&self) -> &CycleHeads { @@ -473,118 +448,111 @@ mod persistence { } pub(super) enum TryClaimHeadsResult<'me> { - /// Claiming every cycle head results in a cycle head. - Cycle, - - /// The cycle head has been finalized. - Finalized, + /// Claiming the cycle head results in a cycle. + Cycle { + head_iteration_count: IterationCount, + memo_iteration_count: IterationCount, + verified_at: Revision, + }, /// The cycle head is not finalized, but it can be claimed. Available, /// The cycle head is currently executed on another thread. 
- Running(RunningCycleHead<'me>), -} - -pub(super) struct RunningCycleHead<'me> { - inner: Running<'me>, - ingredient: &'me dyn Ingredient, -} - -impl<'a> RunningCycleHead<'a> { - fn block_on(self, cycle_heads: &mut TryClaimCycleHeadsIter<'a>) { - let key_index = self.inner.database_key().key_index(); - self.inner.block_on(cycle_heads.zalsa); - - cycle_heads.queue_ingredient_heads(self.ingredient, key_index); - } + Running(Running<'me>), } /// Iterator to try claiming the transitive cycle heads of a memo. -struct TryClaimCycleHeadsIter<'a> { +pub(super) struct TryClaimCycleHeadsIter<'a> { zalsa: &'a Zalsa, - queue: Vec, - queued: FxHashSet, + zalsa_local: &'a ZalsaLocal, + cycle_heads: CycleHeadsIterator<'a>, } impl<'a> TryClaimCycleHeadsIter<'a> { - fn new(zalsa: &'a Zalsa, heads: &CycleHeads) -> Self { - let queue: Vec<_> = heads.iter().copied().collect(); - let queued: FxHashSet<_> = queue.iter().copied().collect(); - + pub(super) fn new( + zalsa: &'a Zalsa, + zalsa_local: &'a ZalsaLocal, + cycle_heads: &'a CycleHeads, + ) -> Self { Self { zalsa, - queue, - queued, + zalsa_local, + cycle_heads: cycle_heads.iter(), } } - - fn queue_ingredient_heads(&mut self, ingredient: &dyn Ingredient, key: Id) { - // Recursively wait for all cycle heads that this head depends on. It's important - // that we fetch those from the updated memo because the cycle heads can change - // between iterations and new cycle heads can be added if a query depeonds on - // some cycle heads depending on a specific condition being met - // (`a` calls `b` and `c` in iteration 0 but `c` and `d` in iteration 1 or later). - // IMPORTANT: It's critical that we get the cycle head from the latest memo - // here, in case the memo has become part of another cycle (we need to block on that too!). 
- self.queue.extend( - ingredient - .cycle_heads(self.zalsa, key) - .iter() - .copied() - .filter(|head| self.queued.insert(*head)), - ) - } } impl<'me> Iterator for TryClaimCycleHeadsIter<'me> { type Item = TryClaimHeadsResult<'me>; fn next(&mut self) -> Option { - let head = self.queue.pop()?; + let head = self.cycle_heads.next()?; let head_database_key = head.database_key_index; + let head_iteration_count = head.iteration_count.load(); + + // The most common case is that the head is already in the query stack. So let's check that first. + // SAFETY: We do not access the query stack reentrantly. + if let Some(current_iteration_count) = unsafe { + self.zalsa_local.with_query_stack_unchecked(|stack| { + stack + .iter() + .rev() + .find(|query| query.database_key_index == head_database_key) + .map(|query| query.iteration_count()) + }) + } { + crate::tracing::trace!( + "Waiting for {head_database_key:?} results in a cycle (because it is already in the query stack)" + ); + return Some(TryClaimHeadsResult::Cycle { + head_iteration_count, + memo_iteration_count: current_iteration_count, + verified_at: self.zalsa.current_revision(), + }); + } + let head_key_index = head_database_key.key_index(); let ingredient = self .zalsa .lookup_ingredient(head_database_key.ingredient_index()); - let cycle_head_kind = ingredient - .provisional_status(self.zalsa, head_key_index) - .unwrap_or(ProvisionalStatus::Provisional { - iteration: IterationCount::initial(), - verified_at: Revision::start(), - }); + match ingredient.wait_for(self.zalsa, head_key_index) { + WaitForResult::Cycle { .. } => { + // We hit a cycle blocking on the cycle head; this means this query actively + // participates in the cycle and some other query is blocked on this thread. 
+ crate::tracing::trace!("Waiting for {head_database_key:?} results in a cycle"); + + let provisional_status = ingredient + .provisional_status(self.zalsa, head_key_index) + .expect("cycle head memo to exist"); + let (current_iteration_count, verified_at) = match provisional_status { + ProvisionalStatus::Provisional { + iteration, + verified_at, + } + | ProvisionalStatus::Final { + iteration, + verified_at, + } => (iteration, verified_at), + ProvisionalStatus::FallbackImmediate => { + (IterationCount::initial(), self.zalsa.current_revision()) + } + }; - match cycle_head_kind { - ProvisionalStatus::Final { .. } | ProvisionalStatus::FallbackImmediate => { - // This cycle is already finalized, so we don't need to wait on it; - // keep looping through cycle heads. - crate::tracing::trace!("Dependent cycle head {head:?} has been finalized."); - Some(TryClaimHeadsResult::Finalized) + Some(TryClaimHeadsResult::Cycle { + memo_iteration_count: current_iteration_count, + head_iteration_count, + verified_at, + }) } - ProvisionalStatus::Provisional { .. } => { - match ingredient.wait_for(self.zalsa, head_key_index) { - WaitForResult::Cycle { .. } => { - // We hit a cycle blocking on the cycle head; this means this query actively - // participates in the cycle and some other query is blocked on this thread. 
- crate::tracing::debug!("Waiting for {head:?} results in a cycle"); - Some(TryClaimHeadsResult::Cycle) - } - WaitForResult::Running(running) => { - crate::tracing::debug!("Ingredient {head:?} is running: {running:?}"); + WaitForResult::Running(running) => { + crate::tracing::trace!("Ingredient {head_database_key:?} is running: {running:?}"); - Some(TryClaimHeadsResult::Running(RunningCycleHead { - inner: running, - ingredient, - })) - } - WaitForResult::Available => { - self.queue_ingredient_heads(ingredient, head_key_index); - Some(TryClaimHeadsResult::Available) - } - } + Some(TryClaimHeadsResult::Running(running)) } + WaitForResult::Available => Some(TryClaimHeadsResult::Available), } } } diff --git a/src/function/sync.rs b/src/function/sync.rs index 0a88844af..97a36262c 100644 --- a/src/function/sync.rs +++ b/src/function/sync.rs @@ -1,9 +1,13 @@ use rustc_hash::FxHashMap; +use std::collections::hash_map::OccupiedEntry; use crate::key::DatabaseKeyIndex; -use crate::runtime::{BlockResult, Running, WaitResult}; -use crate::sync::thread::{self, ThreadId}; +use crate::runtime::{ + BlockOnTransferredOwner, BlockResult, BlockTransferredResult, Running, WaitResult, +}; +use crate::sync::thread::{self}; use crate::sync::Mutex; +use crate::tracing; use crate::zalsa::Zalsa; use crate::{Id, IngredientIndex}; @@ -20,17 +24,36 @@ pub(crate) enum ClaimResult<'a> { /// Can't claim the query because it is running on an other thread. Running(Running<'a>), /// Claiming the query results in a cycle. - Cycle, + Cycle { + /// `true` if this is a cycle with an inner query. For example, if `a` transferred its ownership to + /// `b`. If the thread claiming `b` tries to claim `a`, then this results in a cycle except when calling + /// [`SyncTable::try_claim`] with [`Reentrant::Allow`]. + inner: bool, + }, /// Successfully claimed the query. 
Claimed(ClaimGuard<'a>), } pub(crate) struct SyncState { - id: ThreadId, + /// The thread id that currently owns this query (actively executing it or iterating it as part of a larger cycle). + id: SyncOwner, /// Set to true if any other queries are blocked, /// waiting for this query to complete. anyone_waiting: bool, + + /// Whether any other query has transferred its lock ownership to this query. + /// This is only an optimization so that the expensive unblocking of transferred queries + /// can be skipped if `false`. This field might be `true` in cases where queries *were* transferred + /// to this query, but have since then been transferred to another query (in a later iteration). + is_transfer_target: bool, + + /// Whether this query has been claimed by the query that currently owns it. + /// + /// If `a` has been transferred to `b` and the stack for t1 is `b -> a`, then `a` can be claimed + /// and `claimed_twice` is set to `true`. However, t2 won't be able to claim `a` because + /// it doesn't own `b`. + claimed_twice: bool, } impl SyncTable { @@ -41,14 +64,34 @@ impl SyncTable { } } - pub(crate) fn try_claim<'me>(&'me self, zalsa: &'me Zalsa, key_index: Id) -> ClaimResult<'me> { + /// Claims the given key index, or blocks if it is running on another thread. 
+ pub(crate) fn try_claim<'me>( + &'me self, + zalsa: &'me Zalsa, + key_index: Id, + reentrant: Reentrancy, + ) -> ClaimResult<'me> { let mut write = self.syncs.lock(); match write.entry(key_index) { std::collections::hash_map::Entry::Occupied(occupied_entry) => { + let id = match occupied_entry.get().id { + SyncOwner::Thread(id) => id, + SyncOwner::Transferred => { + return match self.try_claim_transferred(zalsa, occupied_entry, reentrant) { + Ok(claimed) => claimed, + Err(other_thread) => match other_thread.block(write) { + BlockResult::Cycle => ClaimResult::Cycle { inner: false }, + BlockResult::Running(running) => ClaimResult::Running(running), + }, + } + } + }; + let &mut SyncState { - id, ref mut anyone_waiting, + .. } = occupied_entry.into_mut(); + // NB: `Ordering::Relaxed` is sufficient here, // as there are no loads that are "gated" on this // value. Everything that is written is also protected @@ -62,22 +105,116 @@ impl SyncTable { write, ) { BlockResult::Running(blocked_on) => ClaimResult::Running(blocked_on), - BlockResult::Cycle => ClaimResult::Cycle, + BlockResult::Cycle => ClaimResult::Cycle { inner: false }, } } std::collections::hash_map::Entry::Vacant(vacant_entry) => { vacant_entry.insert(SyncState { - id: thread::current().id(), + id: SyncOwner::Thread(thread::current().id()), anyone_waiting: false, + is_transfer_target: false, + claimed_twice: false, }); ClaimResult::Claimed(ClaimGuard { key_index, zalsa, sync_table: self, + mode: ReleaseMode::Default, }) } } } + + #[cold] + #[inline(never)] + fn try_claim_transferred<'me>( + &'me self, + zalsa: &'me Zalsa, + mut entry: OccupiedEntry, + reentrant: Reentrancy, + ) -> Result, Box>> { + let key_index = *entry.key(); + let database_key_index = DatabaseKeyIndex::new(self.ingredient, key_index); + let thread_id = thread::current().id(); + + match zalsa + .runtime() + .block_transferred(database_key_index, thread_id) + { + BlockTransferredResult::ImTheOwner if reentrant.is_allow() => { + let 
SyncState { + id, claimed_twice, .. + } = entry.into_mut(); + debug_assert!(!*claimed_twice); + + *id = SyncOwner::Thread(thread_id); + *claimed_twice = true; + + Ok(ClaimResult::Claimed(ClaimGuard { + key_index, + zalsa, + sync_table: self, + mode: ReleaseMode::SelfOnly, + })) + } + BlockTransferredResult::ImTheOwner => Ok(ClaimResult::Cycle { inner: true }), + BlockTransferredResult::OwnedBy(other_thread) => { + entry.get_mut().anyone_waiting = true; + Err(other_thread) + } + BlockTransferredResult::Released => { + entry.insert(SyncState { + id: SyncOwner::Thread(thread_id), + anyone_waiting: false, + is_transfer_target: false, + claimed_twice: false, + }); + Ok(ClaimResult::Claimed(ClaimGuard { + key_index, + zalsa, + sync_table: self, + mode: ReleaseMode::Default, + })) + } + } + } + + /// Marks `key_index` as a transfer target. + /// + /// Returns the `SyncOwnerId` of the thread that currently owns this query. + /// + /// Note: The result of this method will immediately become stale unless the thread owning `key_index` + /// is currently blocked on this thread (claiming `key_index` from this thread results in a cycle). + pub(super) fn mark_as_transfer_target(&self, key_index: Id) -> Option { + let mut syncs = self.syncs.lock(); + syncs.get_mut(&key_index).map(|state| { + // We set `anyone_waiting` to true because it is used in `ClaimGuard::release` + // to exit early if the query doesn't need to release any locks. + // However, there are now dependent queries that need to be released, that's why we set `anyone_waiting` to true, + // so that `ClaimGuard::release` no longer exits early. + state.anyone_waiting = true; + state.is_transfer_target = true; + + state.id + }) + } +} + +#[derive(Copy, Clone, Debug)] +pub enum SyncOwner { + /// Query is owned by this thread + Thread(thread::ThreadId), + + /// The query's lock ownership has been transferred to another query. + /// E.g. 
if `a` transfers its ownership to `b`, then only the thread in the critical path + /// to complete b` can claim `a` (in most instances, only the thread owning `b` can claim `a`). + /// + /// The thread owning `a` is stored in the `DependencyGraph`. + /// + /// A query can be marked as `Transferred` even if it has since then been released by the owning query. + /// In that case, the query is effectively unclaimed and the `Transferred` state is stale. The reason + /// for this is that it avoids the need for locking each sync table when releasing the transferred queries. + Transferred, } /// Marks an active 'claim' in the synchronization map. The claim is @@ -87,33 +224,147 @@ pub(crate) struct ClaimGuard<'me> { key_index: Id, zalsa: &'me Zalsa, sync_table: &'me SyncTable, + mode: ReleaseMode, } -impl ClaimGuard<'_> { - fn remove_from_map_and_unblock_queries(&self) { +impl<'me> ClaimGuard<'me> { + pub(crate) const fn zalsa(&self) -> &'me Zalsa { + self.zalsa + } + + pub(crate) const fn database_key_index(&self) -> DatabaseKeyIndex { + DatabaseKeyIndex::new(self.sync_table.ingredient, self.key_index) + } + + pub(crate) fn set_release_mode(&mut self, mode: ReleaseMode) { + self.mode = mode; + } + + #[cold] + #[inline(never)] + fn release_panicking(&self) { let mut syncs = self.sync_table.syncs.lock(); + let state = syncs.remove(&self.key_index).expect("key claimed twice?"); + tracing::debug!( + "Release claim on {:?} due to panic", + self.database_key_index() + ); + + self.release(state, WaitResult::Panicked); + } + + #[inline(always)] + fn release(&self, state: SyncState, wait_result: WaitResult) { + let SyncState { + anyone_waiting, + is_transfer_target, + claimed_twice, + .. + } = state; + + if !anyone_waiting { + return; + } + + let runtime = self.zalsa.runtime(); + let database_key_index = self.database_key_index(); - let SyncState { anyone_waiting, .. 
} = - syncs.remove(&self.key_index).expect("key claimed twice?"); - - if anyone_waiting { - let database_key = DatabaseKeyIndex::new(self.sync_table.ingredient, self.key_index); - self.zalsa.runtime().unblock_queries_blocked_on( - database_key, - if thread::panicking() { - tracing::info!("Unblocking queries blocked on {database_key:?} after a panick"); - WaitResult::Panicked - } else { - WaitResult::Completed - }, - ) + if claimed_twice { + runtime.undo_transfer_lock(database_key_index); } + + if is_transfer_target { + runtime.unblock_transferred_queries_owned_by(database_key_index, wait_result); + } + + runtime.unblock_queries_blocked_on(database_key_index, wait_result); + } + + #[cold] + #[inline(never)] + fn release_self(&self) { + let mut syncs = self.sync_table.syncs.lock(); + let std::collections::hash_map::Entry::Occupied(mut state) = syncs.entry(self.key_index) + else { + panic!("key should only be claimed/released once"); + }; + + if state.get().claimed_twice { + state.get_mut().claimed_twice = false; + state.get_mut().id = SyncOwner::Transferred; + } else { + self.release(state.remove(), WaitResult::Completed); + } + } + + #[cold] + #[inline(never)] + pub(crate) fn transfer(&self, new_owner: DatabaseKeyIndex) { + let owner_ingredient = self.zalsa.lookup_ingredient(new_owner.ingredient_index()); + + // Get the owning thread of `new_owner`. + // The thread id is guaranteed to not be stale because `new_owner` must be blocked on `self_key` + // or `transfer_lock` will panic (at least in debug builds). 
+ let Some(new_owner_thread_id) = + owner_ingredient.mark_as_transfer_target(new_owner.key_index()) + else { + self.release( + self.sync_table + .syncs + .lock() + .remove(&self.key_index) + .expect("key should only be claimed/released once"), + WaitResult::Panicked, + ); + + panic!("new owner to be a locked query") + }; + + let mut syncs = self.sync_table.syncs.lock(); + + let self_key = self.database_key_index(); + tracing::debug!( + "Transferring lock ownership of {self_key:?} to {new_owner:?} ({new_owner_thread_id:?})" + ); + + let SyncState { + id, claimed_twice, .. + } = syncs + .get_mut(&self.key_index) + .expect("key should only be claimed/released once"); + + self.zalsa + .runtime() + .transfer_lock(self_key, new_owner, new_owner_thread_id); + + *id = SyncOwner::Transferred; + *claimed_twice = false; } } impl Drop for ClaimGuard<'_> { fn drop(&mut self) { - self.remove_from_map_and_unblock_queries() + if thread::panicking() { + self.release_panicking(); + return; + } + + match self.mode { + ReleaseMode::Default => { + let mut syncs = self.sync_table.syncs.lock(); + let state = syncs + .remove(&self.key_index) + .expect("key should only be claimed/released once"); + + self.release(state, WaitResult::Completed); + } + ReleaseMode::SelfOnly => { + self.release_self(); + } + ReleaseMode::TransferTo(new_owner) => { + self.transfer(new_owner); + } + } } } @@ -122,3 +373,60 @@ impl std::fmt::Debug for SyncTable { f.debug_struct("SyncTable").finish() } } + +/// Controls how the lock is released when the `ClaimGuard` is dropped. +#[derive(Copy, Clone, Debug, Default)] +pub(crate) enum ReleaseMode { + /// The default release mode. + /// + /// Releases the query for which this claim guard holds the lock and any queries that have + /// transferred ownership to this query. + #[default] + Default, + + /// Only releases the lock for this query. Any query that has transferred ownership to this query + /// will remain locked. 
+ /// + /// If this thread panics, the query will be released as normal (default mode). + SelfOnly, + + /// Transfers the ownership of the lock to the specified query. + /// + /// The query will remain locked and only the thread owning the transfer target will be resumed. + /// + /// The transfer target must be a query that's blocked on this query to guarantee that the transfer target doesn't complete + /// before the transfer is finished (which would leave this query locked forever). + /// + /// If this thread panics, the query will be released as normal (default mode). + TransferTo(DatabaseKeyIndex), +} + +impl std::fmt::Debug for ClaimGuard<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ClaimGuard") + .field("key_index", &self.key_index) + .field("mode", &self.mode) + .finish_non_exhaustive() + } +} + +/// Controls whether this thread can claim a query that transferred its ownership to a query +/// this thread currently holds the lock for. +/// +/// For example: if query `a` transferred its ownership to query `b`, and this thread holds +/// the lock for `b`, then this thread can also claim `a` — but only when using [`Self::Allow`]. +#[derive(Copy, Clone, PartialEq, Eq)] +pub(crate) enum Reentrancy { + /// Allow `try_claim` to reclaim a query's that transferred its ownership to a query + /// hold by this thread. + Allow, + + /// Only allow claiming queries that haven't been claimed by any thread. 
+ Deny, +} + +impl Reentrancy { + const fn is_allow(self) -> bool { + matches!(self, Reentrancy::Allow) + } +} diff --git a/src/ingredient.rs b/src/ingredient.rs index 3cf36ae61..9b377e4d1 100644 --- a/src/ingredient.rs +++ b/src/ingredient.rs @@ -1,7 +1,7 @@ use std::any::{Any, TypeId}; use std::fmt; -use crate::cycle::{empty_cycle_heads, CycleHeads, CycleRecoveryStrategy, ProvisionalStatus}; +use crate::cycle::{empty_cycle_heads, CycleHeads, IterationCount, ProvisionalStatus}; use crate::database::RawDatabase; use crate::function::{VerifyCycleHeads, VerifyResult}; use crate::hash::{FxHashSet, FxIndexSet}; @@ -93,9 +93,19 @@ pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { /// on an other thread, it's up to caller to block until the result becomes available if desired. /// A return value of [`WaitForResult::Cycle`] means that a cycle was encountered; the waited-on query is either already claimed /// by the current thread, or by a thread waiting on the current thread. - fn wait_for<'me>(&'me self, zalsa: &'me Zalsa, key_index: Id) -> WaitForResult<'me> { - _ = (zalsa, key_index); - WaitForResult::Available + fn wait_for<'me>(&'me self, _zalsa: &'me Zalsa, _key_index: Id) -> WaitForResult<'me> { + unreachable!( + "wait_for should only be called on cycle heads and only functions can be cycle heads" + ); + } + + /// Invoked when a query transfers its lock-ownership to `_key_index`. Returns the thread + /// owning the lock for `_key_index` or `None` if `_key_index` is not claimed. + /// + /// Note: The returned `SyncOwnerId` may be outdated as soon as this function returns **unless** + /// it's guaranteed that `_key_index` is blocked on the current thread. + fn mark_as_transfer_target(&self, _key_index: Id) -> Option { + unreachable!("mark_as_transfer_target should only be called on functions"); } /// Invoked when the value `output_key` should be marked as valid in the current revision. 
@@ -157,11 +167,27 @@ pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { } // Function ingredient methods - /// If this ingredient is a participant in a cycle, what is its cycle recovery strategy? - /// (Really only relevant to [`crate::function::FunctionIngredient`], - /// since only function ingredients push themselves onto the active query stack.) - fn cycle_recovery_strategy(&self) -> CycleRecoveryStrategy { - unreachable!("only function ingredients can be part of a cycle") + /// Tests if the (nested) cycle head `_input` has converged in the most recent iteration. + /// + /// Returns `false` if the Memo doesn't exist or if called on a non-cycle head. + fn cycle_converged(&self, _zalsa: &Zalsa, _input: Id) -> bool { + unreachable!("cycle_converged should only be called on cycle heads and only functions can be cycle heads"); + } + + /// Updates the iteration count for the (nested) cycle head `_input` to `iteration_count`. + /// + /// This is a no-op if the memo doesn't exist or if called on a Memo without cycle heads. + fn set_cycle_iteration_count( + &self, + _zalsa: &Zalsa, + _input: Id, + _iteration_count: IterationCount, + ) { + unreachable!("increment_iteration_count should only be called on cycle heads and only functions can be cycle heads"); + } + + fn finalize_cycle_head(&self, _zalsa: &Zalsa, _input: Id) { + unreachable!("finalize_cycle_head should only be called on cycle heads and only functions can be cycle heads"); } /// What were the inputs (if any) that were used to create the value at `key_index`. 
@@ -302,14 +328,9 @@ pub(crate) fn fmt_index(debug_name: &str, id: Id, fmt: &mut fmt::Formatter<'_>) write!(fmt, "{debug_name}({id:?})") } +#[derive(Debug)] pub enum WaitForResult<'me> { Running(Running<'me>), Available, - Cycle, -} - -impl WaitForResult<'_> { - pub const fn is_cycle(&self) -> bool { - matches!(self, WaitForResult::Cycle) - } + Cycle { inner: bool }, } diff --git a/src/key.rs b/src/key.rs index 82d922565..364015756 100644 --- a/src/key.rs +++ b/src/key.rs @@ -18,7 +18,7 @@ pub struct DatabaseKeyIndex { impl DatabaseKeyIndex { #[inline] - pub(crate) fn new(ingredient_index: IngredientIndex, key_index: Id) -> Self { + pub(crate) const fn new(ingredient_index: IngredientIndex, key_index: Id) -> Self { Self { key_index, ingredient_index, diff --git a/src/runtime.rs b/src/runtime.rs index 8436c684d..670d6d62f 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -1,6 +1,6 @@ use self::dependency_graph::DependencyGraph; use crate::durability::Durability; -use crate::function::SyncGuard; +use crate::function::{SyncGuard, SyncOwner}; use crate::key::DatabaseKeyIndex; use crate::sync::atomic::{AtomicBool, Ordering}; use crate::sync::thread::{self, ThreadId}; @@ -58,6 +58,57 @@ pub(crate) enum BlockResult<'me> { Cycle, } +pub(crate) enum BlockTransferredResult<'me> { + /// The current thread is the owner of the transferred query + /// and it can claim it if it wants to. + ImTheOwner, + + /// The query is owned/running on another thread. + OwnedBy(Box>), + + /// The query has transferred its ownership to another query previously but that query has + /// since then completed and released the lock. + Released, +} + +pub(super) struct BlockOnTransferredOwner<'me> { + dg: crate::sync::MutexGuard<'me, DependencyGraph>, + /// The query that we're trying to claim. + database_key: DatabaseKeyIndex, + /// The thread that currently owns the lock for the transferred query. + other_id: ThreadId, + /// The current thread that is trying to claim the transferred query. 
+ thread_id: ThreadId, +} + +impl<'me> BlockOnTransferredOwner<'me> { + /// Block on the other thread to complete the computation. + pub(super) fn block(self, query_mutex_guard: SyncGuard<'me>) -> BlockResult<'me> { + // Cycle in the same thread. + if self.thread_id == self.other_id { + return BlockResult::Cycle; + } + + if self.dg.depends_on(self.other_id, self.thread_id) { + crate::tracing::debug!( + "block_on: cycle detected for {:?} in thread {thread_id:?} on {:?}", + self.database_key, + self.other_id, + thread_id = self.thread_id + ); + return BlockResult::Cycle; + } + + BlockResult::Running(Running(Box::new(BlockedOnInner { + dg: self.dg, + query_mutex_guard, + database_key: self.database_key, + other_id: self.other_id, + thread_id: self.thread_id, + }))) + } +} + pub struct Running<'me>(Box>); struct BlockedOnInner<'me> { @@ -69,10 +120,6 @@ struct BlockedOnInner<'me> { } impl Running<'_> { - pub(crate) fn database_key(&self) -> DatabaseKeyIndex { - self.0.database_key - } - /// Blocks on the other thread to complete the computation. pub(crate) fn block_on(self, zalsa: &Zalsa) { let BlockedOnInner { @@ -210,7 +257,7 @@ impl Runtime { let r_old = self.current_revision(); let r_new = r_old.next(); self.revisions[0] = r_new; - crate::tracing::debug!("new_revision: {r_old:?} -> {r_new:?}"); + crate::tracing::info!("new_revision: {r_old:?} -> {r_new:?}"); r_new } @@ -253,9 +300,40 @@ impl Runtime { }))) } + /// Tries to claim ownership of a transferred query where `thread_id` is the current thread and `query` + /// is the query (that had its ownership transferred) to claim. + /// + /// For this operation to be reasonable, the caller must ensure that the sync table lock on `query` is not released + /// before this operation completes. 
+ pub(super) fn block_transferred( + &self, + query: DatabaseKeyIndex, + current_id: ThreadId, + ) -> BlockTransferredResult<'_> { + let dg = self.dependency_graph.lock(); + + let owner_thread = dg.thread_id_of_transferred_query(query, None); + + let Some(owner_thread_id) = owner_thread else { + // The query transferred its ownership but the owner has since then released the lock. + return BlockTransferredResult::Released; + }; + + if owner_thread_id == current_id || dg.depends_on(owner_thread_id, current_id) { + BlockTransferredResult::ImTheOwner + } else { + // Lock is owned by another thread, wait for it to be released. + BlockTransferredResult::OwnedBy(Box::new(BlockOnTransferredOwner { + dg, + database_key: query, + other_id: owner_thread_id, + thread_id: current_id, + })) + } + } + /// Invoked when this runtime completed computing `database_key` with - /// the given result `wait_result` (`wait_result` should be `None` if - /// computing `database_key` panicked and could not complete). + /// the given result `wait_result`. /// This function unblocks any dependent queries and allows them /// to continue executing. pub(crate) fn unblock_queries_blocked_on( @@ -268,6 +346,52 @@ impl Runtime { .unblock_runtimes_blocked_on(database_key, wait_result); } + /// Unblocks all transferred queries that are owned by `database_key` recursively. + /// + /// Invoked when a query completes that has been marked as transfer target (it has + /// queries that transferred their lock ownership to it) with the given `wait_result`. + /// + /// This function unblocks any dependent queries and allows them to continue executing. The + /// query `database_key` is not unblocked by this function. 
+ #[cold] + pub(crate) fn unblock_transferred_queries_owned_by( + &self, + database_key: DatabaseKeyIndex, + wait_result: WaitResult, + ) { + self.dependency_graph + .lock() + .unblock_runtimes_blocked_on_transferred_queries_owned_by(database_key, wait_result); + } + + /// Removes the ownership transfer of `query`'s lock if it exists. + /// + /// If `query` has transferred its lock ownership to another query, this function will remove that transfer, + /// so that `query` now owns its lock again. + #[cold] + pub(super) fn undo_transfer_lock(&self, query: DatabaseKeyIndex) { + self.dependency_graph.lock().undo_transfer_lock(query); + } + + /// Transfers ownership of the lock for `query` to `new_owner_key`. + /// + /// For this operation to be reasonable, the caller must ensure that the sync table lock on `query` is not released + /// and that `new_owner_key` is currently blocked on `query`. Otherwise, `new_owner_key` might + /// complete before the lock is transferred, leaving `query` locked forever. + pub(super) fn transfer_lock( + &self, + query: DatabaseKeyIndex, + new_owner_key: DatabaseKeyIndex, + new_owner_id: SyncOwner, + ) { + self.dependency_graph.lock().transfer_lock( + query, + thread::current().id(), + new_owner_key, + new_owner_id, + ); + } + #[cfg(feature = "persistence")] pub(crate) fn deserialize_from(&mut self, other: &mut Runtime) { // The only field that is serialized is `revisions`. 
diff --git a/src/runtime/dependency_graph.rs b/src/runtime/dependency_graph.rs index fd26c04fa..403f7c544 100644 --- a/src/runtime/dependency_graph.rs +++ b/src/runtime/dependency_graph.rs @@ -3,11 +3,16 @@ use std::pin::Pin; use rustc_hash::FxHashMap; use smallvec::SmallVec; +use crate::function::SyncOwner; use crate::key::DatabaseKeyIndex; use crate::runtime::dependency_graph::edge::EdgeCondvar; use crate::runtime::WaitResult; use crate::sync::thread::ThreadId; use crate::sync::MutexGuard; +use crate::tracing; + +type QueryDependents = FxHashMap>; +type TransferredDependents = FxHashMap>; #[derive(Debug, Default)] pub(super) struct DependencyGraph { @@ -15,16 +20,26 @@ pub(super) struct DependencyGraph { /// `K` is blocked on some query executing in the runtime `V`. /// This encodes a graph that must be acyclic (or else deadlock /// will result). - edges: FxHashMap, + edges: Edges, /// Encodes the `ThreadId` that are blocked waiting for the result /// of a given query. - query_dependents: FxHashMap>, + query_dependents: QueryDependents, /// When a key K completes which had dependent queries Qs blocked on it, /// it stores its `WaitResult` here. As they wake up, each query Q in Qs will /// come here to fetch their results. wait_results: FxHashMap, + + /// A `K -> Q` pair indicates that the query `K`'s lock is now owned by the query + /// `Q`. It's important that `transferred` always forms a tree (must be acyclic), + /// or else deadlock will result. + transferred: FxHashMap, + + /// A `K -> [Q]` pair indicates that the query `K` owns the locks of + /// `Q`. This is the reverse mapping of `transferred` to allow efficient unlocking + /// of all dependent queries when `K` completes. + transferred_dependents: TransferredDependents, } impl DependencyGraph { @@ -32,15 +47,7 @@ impl DependencyGraph { /// /// (i.e., there is a path from `from_id` to `to_id` in the graph.) 
pub(super) fn depends_on(&self, from_id: ThreadId, to_id: ThreadId) -> bool {
-        let mut p = from_id;
-        while let Some(q) = self.edges.get(&p).map(|edge| edge.blocked_on_id) {
-            if q == to_id {
-                return true;
-            }
-
-            p = q;
-        }
-        p == to_id
+        self.edges.depends_on(from_id, to_id)
     }
 
     /// Modifies the graph so that `from_id` is blocked
@@ -138,6 +145,381 @@ impl DependencyGraph {
         // notify the thread.
         edge.notify();
     }
+
+    /// Invoked when the query `database_key` completes and it owns the locks of other queries
+    /// (the queries transferred their locks to `database_key`).
+    pub(super) fn unblock_runtimes_blocked_on_transferred_queries_owned_by(
+        &mut self,
+        database_key: DatabaseKeyIndex,
+        wait_result: WaitResult,
+    ) {
+        fn unblock_recursive(
+            me: &mut DependencyGraph,
+            query: DatabaseKeyIndex,
+            wait_result: WaitResult,
+        ) {
+            me.transferred.remove(&query);
+
+            for query in me.transferred_dependents.remove(&query).unwrap_or_default() {
+                me.unblock_runtimes_blocked_on(query, wait_result);
+                unblock_recursive(me, query, wait_result);
+            }
+        }
+
+        // If `database_key` is `c` and it has been transferred to `b` earlier, remove its entry.
+        tracing::trace!(
+            "unblock_runtimes_blocked_on_transferred_queries_owned_by({database_key:?})"
+        );
+
+        if let Some((_, owner)) = self.transferred.remove(&database_key) {
+            // If this query previously transferred its lock ownership to another query, remove
+            // it from that query's dependents as it is now completing.
+            self.transferred_dependents
+                .get_mut(&owner)
+                .unwrap()
+                .remove(&database_key);
+        }
+
+        unblock_recursive(self, database_key, wait_result);
+    }
+
+    pub(super) fn undo_transfer_lock(&mut self, database_key: DatabaseKeyIndex) {
+        if let Some((_, owner)) = self.transferred.remove(&database_key) {
+            self.transferred_dependents
+                .get_mut(&owner)
+                .unwrap()
+                .remove(&database_key);
+        }
+    }
+
+    /// Recursively resolves the thread id that currently owns the lock for `database_key`.
+    ///
+    /// Returns `None` if `database_key` hasn't transferred its lock (or the transfer has since
+    /// been released); in that case the thread id must be looked up in the `SyncTable` instead.
+    pub(super) fn thread_id_of_transferred_query(
+        &self,
+        database_key: DatabaseKeyIndex,
+        ignore: Option<DatabaseKeyIndex>,
+    ) -> Option<ThreadId> {
+        let &(mut resolved_thread, owner) = self.transferred.get(&database_key)?;
+
+        let mut current_owner = owner;
+
+        while let Some(&(next_thread, next_key)) = self.transferred.get(&current_owner) {
+            if Some(next_key) == ignore {
+                break;
+            }
+            resolved_thread = next_thread;
+            current_owner = next_key;
+        }
+
+        Some(resolved_thread)
+    }
+
+    /// Modifies the graph so that the lock on `query` (currently owned by `current_thread`) is
+    /// transferred to `new_owner` (which is owned by `new_owner_id`).
+    pub(super) fn transfer_lock(
+        &mut self,
+        query: DatabaseKeyIndex,
+        current_thread: ThreadId,
+        new_owner: DatabaseKeyIndex,
+        new_owner_id: SyncOwner,
+    ) {
+        let new_owner_thread = match new_owner_id {
+            SyncOwner::Thread(thread) => thread,
+            SyncOwner::Transferred => {
+                // Skip over `query` to skip over any existing mapping from `new_owner` to `query` that may
+                // exist from previous transfers.
+                self.thread_id_of_transferred_query(new_owner, Some(query))
+                    .expect("new owner should be blocked on `query`")
+            }
+        };
+
+        debug_assert!(
+            new_owner_thread == current_thread || self.depends_on(new_owner_thread, current_thread),
+            "new owner {new_owner:?} ({new_owner_thread:?}) must be blocked on {query:?} ({current_thread:?})"
+        );
+
+        let thread_changed = match self.transferred.entry(query) {
+            std::collections::hash_map::Entry::Vacant(entry) => {
+                // Transfer `c -> b` and there's no existing entry for `c`.
+                entry.insert((new_owner_thread, new_owner));
+                current_thread != new_owner_thread
+            }
+            std::collections::hash_map::Entry::Occupied(mut entry) => {
+                // If we transfer to the same owner as before, return immediately as this is a no-op.
+                if entry.get() == &(new_owner_thread, new_owner) {
+                    return;
+                }
+
+                // `Transfer `c -> b` after a previous `c -> d` mapping.
+                // Update the owner and remove the query from the old owner's dependents.
+                let &(old_owner_thread, old_owner) = entry.get();
+
+                // For the example below, remove `d` from `b`'s dependents.
+                self.transferred_dependents
+                    .get_mut(&old_owner)
+                    .unwrap()
+                    .remove(&query);
+
+                entry.insert((new_owner_thread, new_owner));
+
+                // If we have `c -> a -> d` and we now insert a mapping `d -> c`, rewrite the mapping to
+                // `d -> c -> a` to avoid cycles.
+                //
+                // Or, starting with `e -> c -> a -> d -> b` insert `d -> c`. We need to rewrite the tree to
+                // ```
+                // e -> c -> a -> b
+                //      d /
+                // ```
+                //
+                //
+                // A cycle between transfers can occur when a later iteration has a different outermost query than
+                // a previous iteration. The second iteration then hits `cycle_initial` for a different head (e.g. for `c` where it previously was `d`).
+                let mut last_segment = self.transferred.entry(new_owner);
+
+                while let std::collections::hash_map::Entry::Occupied(mut entry) = last_segment {
+                    let source = *entry.key();
+                    let next_target = entry.get().1;
+
+                    // If it's `a -> d`, remove `a -> d` and insert an edge from `a -> b`
+                    if next_target == query {
+                        tracing::trace!(
+                            "Remap edge {source:?} -> {next_target:?} to {source:?} -> {old_owner:?} to prevent a cycle",
+                        );
+
+                        // Remove `a` from the dependents of `d` and remove the mapping from `a -> d`.
+                        self.transferred_dependents
+                            .get_mut(&query)
+                            .unwrap()
+                            .remove(&source);
+
+                        // if the old mapping was `c -> d` and we now insert `d -> c`, remove `d -> c`
+                        if old_owner == new_owner {
+                            entry.remove();
+                        } else {
+                            // otherwise (when `d` pointed to some other query, e.g. `b` in the example),
+                            // add an edge from `a` to `b`
+                            entry.insert((old_owner_thread, old_owner));
+                            self.transferred_dependents
+                                .get_mut(&old_owner)
+                                .unwrap()
+                                .push(source);
+                        }
+
+                        break;
+                    }
+
+                    last_segment = self.transferred.entry(next_target);
+                }
+
+                // We simply assume here that the thread has changed because we'd have to walk the entire
+                // transferred chain of `old_owner` to know if the thread has changed. This won't save us much
+                // compared to just updating all dependent threads.
+                true
+            }
+        };
+
+        // Register `c` as a dependent of `b`.
+        let all_dependents = self.transferred_dependents.entry(new_owner).or_default();
+        debug_assert!(!all_dependents.contains(&new_owner));
+        all_dependents.push(query);
+
+        if thread_changed {
+            tracing::debug!("Unblocking new owner of transfer target {new_owner:?}");
+            self.unblock_transfer_target(query, new_owner_thread);
+            self.update_transferred_edges(query, new_owner_thread);
+        }
+    }
+
+    /// Finds the one query in the dependents of the `source_query` (the one that is transferred to a new owner)
+    /// on which the `new_owner_id` thread blocks, and unblocks it, to ensure progress.
+    fn unblock_transfer_target(&mut self, source_query: DatabaseKeyIndex, new_owner_id: ThreadId) {
+        /// Finds the thread that's currently blocking the `new_owner_id` thread.
+        ///
+        /// Returns `Some` if there's such a thread where the first element is the query
+        /// that the thread is blocked on (key into `query_dependents`) and the second element
+        /// is the index in the list of blocked threads (index into the `query_dependents` value) for that query.
+ fn find_blocked_thread( + me: &DependencyGraph, + query: DatabaseKeyIndex, + new_owner_id: ThreadId, + ) -> Option<(DatabaseKeyIndex, usize)> { + if let Some(blocked_threads) = me.query_dependents.get(&query) { + for (i, id) in blocked_threads.iter().copied().enumerate() { + if id == new_owner_id || me.edges.depends_on(new_owner_id, id) { + return Some((query, i)); + } + } + } + + me.transferred_dependents + .get(&query) + .iter() + .copied() + .flatten() + .find_map(|dependent| find_blocked_thread(me, *dependent, new_owner_id)) + } + + if let Some((query, query_dependents_index)) = + find_blocked_thread(self, source_query, new_owner_id) + { + let blocked_threads = self.query_dependents.get_mut(&query).unwrap(); + + let thread_id = blocked_threads.swap_remove(query_dependents_index); + if blocked_threads.is_empty() { + self.query_dependents.remove(&query); + } + + self.unblock_runtime(thread_id, WaitResult::Completed); + } + } + + fn update_transferred_edges(&mut self, query: DatabaseKeyIndex, new_owner_thread: ThreadId) { + fn update_transferred_edges( + edges: &mut Edges, + query_dependents: &QueryDependents, + transferred_dependents: &TransferredDependents, + query: DatabaseKeyIndex, + new_owner_thread: ThreadId, + ) { + tracing::trace!("update_transferred_edges({query:?}"); + if let Some(dependents) = query_dependents.get(&query) { + for dependent in dependents.iter() { + let edge = edges.get_mut(dependent).unwrap(); + + tracing::trace!( + "Rewrite edge from {:?} to {new_owner_thread:?}", + edge.blocked_on_id + ); + edge.blocked_on_id = new_owner_thread; + debug_assert!( + !edges.depends_on(new_owner_thread, *dependent), + "Circular reference between blocked edges: {:#?}", + edges + ); + } + }; + + if let Some(dependents) = transferred_dependents.get(&query) { + for dependent in dependents { + update_transferred_edges( + edges, + query_dependents, + transferred_dependents, + *dependent, + new_owner_thread, + ) + } + } + } + + update_transferred_edges( + &mut 
self.edges, + &self.query_dependents, + &self.transferred_dependents, + query, + new_owner_thread, + ) + } +} + +#[derive(Debug, Default)] +struct Edges(FxHashMap); + +impl Edges { + fn depends_on(&self, from_id: ThreadId, to_id: ThreadId) -> bool { + let mut p = from_id; + while let Some(q) = self.0.get(&p).map(|edge| edge.blocked_on_id) { + if q == to_id { + return true; + } + + p = q; + } + p == to_id + } + + fn get_mut(&mut self, id: &ThreadId) -> Option<&mut edge::Edge> { + self.0.get_mut(id) + } + + fn contains_key(&self, id: &ThreadId) -> bool { + self.0.contains_key(id) + } + + fn insert(&mut self, id: ThreadId, edge: edge::Edge) { + self.0.insert(id, edge); + } + + fn remove(&mut self, id: &ThreadId) -> Option { + self.0.remove(id) + } +} + +#[derive(Debug)] +struct SmallSet(SmallVec<[T; N]>); + +impl SmallSet +where + T: PartialEq, +{ + const fn new() -> Self { + Self(SmallVec::new_const()) + } + + fn push(&mut self, value: T) { + debug_assert!(!self.0.contains(&value)); + + self.0.push(value); + } + + fn contains(&self, value: &T) -> bool { + self.0.contains(value) + } + + fn remove(&mut self, value: &T) -> bool { + if let Some(index) = self.0.iter().position(|x| x == value) { + self.0.swap_remove(index); + true + } else { + false + } + } + + fn iter(&self) -> std::slice::Iter<'_, T> { + self.0.iter() + } +} + +impl IntoIterator for SmallSet { + type Item = T; + type IntoIter = smallvec::IntoIter<[T; N]>; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + +impl<'a, T, const N: usize> IntoIterator for &'a SmallSet +where + T: PartialEq, +{ + type Item = &'a T; + type IntoIter = std::slice::Iter<'a, T>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +impl Default for SmallSet +where + T: PartialEq, +{ + fn default() -> Self { + Self::new() + } } mod edge { @@ -165,7 +547,7 @@ mod edge { /// Signalled whenever a query with dependents completes. /// Allows those dependents to check if they are ready to unblock. 
- // condvar: unsafe<'stack_frame> Pin<&'stack_frame Condvar>, + /// `condvar: unsafe<'stack_frame> Pin<&'stack_frame Condvar>` condvar: Pin<&'static EdgeCondvar>, } diff --git a/src/tracing.rs b/src/tracing.rs index 47f95d00e..6d3ae8851 100644 --- a/src/tracing.rs +++ b/src/tracing.rs @@ -7,6 +7,12 @@ macro_rules! trace { }; } +macro_rules! warn_event { + ($($x:tt)*) => { + crate::tracing::event!(WARN, $($x)*) + }; +} + macro_rules! info { ($($x:tt)*) => { crate::tracing::event!(INFO, $($x)*) @@ -25,6 +31,13 @@ macro_rules! debug_span { }; } +#[expect(unused_macros)] +macro_rules! info_span { + ($($x:tt)*) => { + crate::tracing::span!(INFO, $($x)*) + }; +} + macro_rules! event { ($level:ident, $($x:tt)*) => {{ let event = { @@ -51,4 +64,5 @@ macro_rules! span { }}; } -pub(crate) use {debug, debug_span, event, info, span, trace}; +#[expect(unused_imports)] +pub(crate) use {debug, debug_span, event, info, info_span, span, trace, warn_event as warn}; diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs index e332b516f..39d0c489c 100644 --- a/src/zalsa_local.rs +++ b/src/zalsa_local.rs @@ -1,4 +1,6 @@ use std::cell::{RefCell, UnsafeCell}; +use std::fmt; +use std::fmt::Formatter; use std::panic::UnwindSafe; use std::ptr::{self, NonNull}; @@ -11,7 +13,7 @@ use crate::accumulator::{ Accumulator, }; use crate::active_query::{CompletedQuery, QueryStack}; -use crate::cycle::{empty_cycle_heads, CycleHeads, IterationCount}; +use crate::cycle::{empty_cycle_heads, AtomicIterationCount, CycleHeads, IterationCount}; use crate::durability::Durability; use crate::key::DatabaseKeyIndex; use crate::runtime::Stamp; @@ -513,7 +515,8 @@ impl QueryRevisionsExtra { accumulated, cycle_heads, tracked_struct_ids, - iteration, + iteration: iteration.into(), + cycle_converged: false, })) }; @@ -521,7 +524,6 @@ impl QueryRevisionsExtra { } } -#[derive(Debug)] #[cfg_attr(feature = "persistence", derive(serde::Serialize, serde::Deserialize))] struct QueryRevisionsExtraInner { #[cfg(feature = 
"accumulator")] @@ -561,7 +563,12 @@ struct QueryRevisionsExtraInner { /// iterate again. cycle_heads: CycleHeads, - iteration: IterationCount, + iteration: AtomicIterationCount, + + /// Stores for nested cycle heads whether they've converged in the last iteration. + /// This value is always `false` for other queries. + #[cfg_attr(feature = "persistence", serde(skip))] + cycle_converged: bool, } impl QueryRevisionsExtraInner { @@ -573,6 +580,7 @@ impl QueryRevisionsExtraInner { tracked_struct_ids, cycle_heads, iteration: _, + cycle_converged: _, } = self; #[cfg(feature = "accumulator")] @@ -583,6 +591,44 @@ impl QueryRevisionsExtraInner { } } +impl fmt::Debug for QueryRevisionsExtraInner { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + struct FmtTrackedStructIds<'a>(&'a ThinVec<(Identity, Id)>); + + impl fmt::Debug for FmtTrackedStructIds<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + let mut f = f.debug_list(); + + if self.0.len() > 5 { + f.entries(&self.0[..5]); + f.finish_non_exhaustive() + } else { + f.entries(self.0); + f.finish() + } + } + } + + let mut f = f.debug_struct("QueryRevisionsExtraInner"); + + f.field("cycle_heads", &self.cycle_heads) + .field("iteration", &self.iteration) + .field("cycle_converged", &self.cycle_converged); + + #[cfg(feature = "accumulator")] + { + f.field("accumulated", &self.accumulated); + } + + f.field( + "tracked_struct_ids", + &FmtTrackedStructIds(&self.tracked_struct_ids), + ); + + f.finish() + } +} + #[cfg(not(feature = "shuttle"))] #[cfg(target_pointer_width = "64")] const _: [(); std::mem::size_of::()] = [(); std::mem::size_of::<[usize; 4]>()]; @@ -605,7 +651,7 @@ impl QueryRevisions { #[cfg(feature = "accumulator")] AccumulatedMap::default(), ThinVec::default(), - CycleHeads::initial(query), + CycleHeads::initial(query, IterationCount::initial()), IterationCount::initial(), ), } @@ -654,17 +700,55 @@ impl QueryRevisions { }; } - pub(crate) const fn iteration(&self) -> IterationCount { + 
pub(crate) fn cycle_converged(&self) -> bool { match &self.extra.0 { - Some(extra) => extra.iteration, + Some(extra) => extra.cycle_converged, + None => false, + } + } + + pub(crate) fn set_cycle_converged(&mut self, cycle_converged: bool) { + if let Some(extra) = &mut self.extra.0 { + extra.cycle_converged = cycle_converged + } + } + + pub(crate) fn iteration(&self) -> IterationCount { + match &self.extra.0 { + Some(extra) => extra.iteration.load(), None => IterationCount::initial(), } } + pub(crate) fn set_iteration_count( + &self, + database_key_index: DatabaseKeyIndex, + iteration_count: IterationCount, + ) { + let Some(extra) = &self.extra.0 else { + return; + }; + debug_assert!(extra.iteration.load() <= iteration_count); + + extra.iteration.store(iteration_count); + + extra + .cycle_heads + .update_iteration_count(database_key_index, iteration_count); + } + /// Updates the iteration count if this query has any cycle heads. Otherwise it's a no-op. - pub(crate) fn update_iteration_count(&mut self, iteration_count: IterationCount) { + pub(crate) fn update_iteration_count_mut( + &mut self, + cycle_head_index: DatabaseKeyIndex, + iteration_count: IterationCount, + ) { if let Some(extra) = &mut self.extra.0 { - extra.iteration = iteration_count + extra.iteration.store_mut(iteration_count); + + extra + .cycle_heads + .update_iteration_count_mut(cycle_head_index, iteration_count); } } diff --git a/tests/backtrace.rs b/tests/backtrace.rs index 74124c1ab..b611cac86 100644 --- a/tests/backtrace.rs +++ b/tests/backtrace.rs @@ -108,7 +108,7 @@ fn backtrace_works() { at tests/backtrace.rs:32 1: query_cycle(Id(2)) at tests/backtrace.rs:45 - cycle heads: query_cycle(Id(2)) -> IterationCount(0) + cycle heads: query_cycle(Id(2)) -> iteration = 0 2: query_f(Id(2)) at tests/backtrace.rs:40 "#]] @@ -119,9 +119,9 @@ fn backtrace_works() { query stacktrace: 0: query_e(Id(3)) -> (R1, Durability::LOW) at tests/backtrace.rs:32 - 1: query_cycle(Id(3)) -> (R1, Durability::HIGH, 
iteration = IterationCount(0)) + 1: query_cycle(Id(3)) -> (R1, Durability::HIGH, iteration = 0) at tests/backtrace.rs:45 - cycle heads: query_cycle(Id(3)) -> IterationCount(0) + cycle heads: query_cycle(Id(3)) -> iteration = 0 2: query_f(Id(3)) -> (R1, Durability::HIGH) at tests/backtrace.rs:40 "#]] diff --git a/tests/cycle.rs b/tests/cycle.rs index 7a7e26a07..5e46cc0be 100644 --- a/tests/cycle.rs +++ b/tests/cycle.rs @@ -95,18 +95,22 @@ impl Input { } } + #[track_caller] fn assert(&self, db: &dyn Db, expected: Value) { assert_eq!(self.eval(db), expected) } + #[track_caller] fn assert_value(&self, db: &dyn Db, expected: u8) { self.assert(db, Value::N(expected)) } + #[track_caller] fn assert_bounds(&self, db: &dyn Db) { self.assert(db, Value::OutOfBounds) } + #[track_caller] fn assert_count(&self, db: &dyn Db) { self.assert(db, Value::TooManyIterations) } @@ -893,7 +897,7 @@ fn cycle_unchanged() { /// /// If nothing in a nested cycle changed in the new revision, no part of the cycle should /// re-execute. -#[test] +#[test_log::test] fn cycle_unchanged_nested() { let mut db = ExecuteValidateLoggerDatabase::default(); let a_in = Inputs::new(&db, vec![]); @@ -978,7 +982,7 @@ fn cycle_unchanged_nested_intertwined() { e.assert_value(&db, 60); } - db.assert_logs_len(15 + i); + db.assert_logs_len(13 + i); // next revision, we change only A, which is not part of the cycle and the cycle does not // depend on. 
diff --git a/tests/cycle_tracked.rs b/tests/cycle_tracked.rs index 154ba3370..2e0c2cfd0 100644 --- a/tests/cycle_tracked.rs +++ b/tests/cycle_tracked.rs @@ -269,7 +269,7 @@ fn cycle_recover_with_structs<'db>( CycleRecoveryAction::Iterate } -#[test] +#[test_log::test] fn test_cycle_with_fixpoint_structs() { let mut db = EventLoggerDatabase::default(); diff --git a/tests/parallel/cycle_a_t1_b_t2.rs b/tests/parallel/cycle_a_t1_b_t2.rs index d9d5ca365..ad21b7963 100644 --- a/tests/parallel/cycle_a_t1_b_t2.rs +++ b/tests/parallel/cycle_a_t1_b_t2.rs @@ -62,7 +62,7 @@ fn initial(_db: &dyn KnobsDatabase) -> CycleValue { #[test_log::test] fn the_test() { crate::sync::check(|| { - tracing::debug!("New run"); + tracing::debug!("Starting new run"); let db_t1 = Knobs::default(); let db_t2 = db_t1.clone(); diff --git a/tests/parallel/cycle_a_t1_b_t2_fallback.rs b/tests/parallel/cycle_a_t1_b_t2_fallback.rs index 8005a9c23..b2d6631cc 100644 --- a/tests/parallel/cycle_a_t1_b_t2_fallback.rs +++ b/tests/parallel/cycle_a_t1_b_t2_fallback.rs @@ -55,11 +55,18 @@ fn the_test() { use crate::Knobs; crate::sync::check(|| { + tracing::debug!("Starting new run"); let db_t1 = Knobs::default(); let db_t2 = db_t1.clone(); - let t1 = thread::spawn(move || query_a(&db_t1)); - let t2 = thread::spawn(move || query_b(&db_t2)); + let t1 = thread::spawn(move || { + let _span = tracing::debug_span!("t1", thread_id = ?thread::current().id()).entered(); + query_a(&db_t1) + }); + let t2 = thread::spawn(move || { + let _span = tracing::debug_span!("t2", thread_id = ?thread::current().id()).entered(); + query_b(&db_t2) + }); let (r_t1, r_t2) = (t1.join(), t2.join()); diff --git a/tests/parallel/cycle_nested_deep.rs b/tests/parallel/cycle_nested_deep.rs index 7b7c2f42a..f2b355616 100644 --- a/tests/parallel/cycle_nested_deep.rs +++ b/tests/parallel/cycle_nested_deep.rs @@ -63,6 +63,7 @@ fn initial(_db: &dyn KnobsDatabase) -> CycleValue { #[test_log::test] fn the_test() { crate::sync::check(|| { + 
tracing::debug!("Starting new run"); let db_t1 = Knobs::default(); let db_t2 = db_t1.clone(); let db_t3 = db_t1.clone(); diff --git a/tests/parallel/cycle_nested_deep_conditional.rs b/tests/parallel/cycle_nested_deep_conditional.rs index 316612845..4eff75189 100644 --- a/tests/parallel/cycle_nested_deep_conditional.rs +++ b/tests/parallel/cycle_nested_deep_conditional.rs @@ -72,7 +72,7 @@ fn initial(_db: &dyn KnobsDatabase) -> CycleValue { #[test_log::test] fn the_test() { crate::sync::check(|| { - tracing::debug!("New run"); + tracing::debug!("Starting new run"); let db_t1 = Knobs::default(); let db_t2 = db_t1.clone(); let db_t3 = db_t1.clone(); diff --git a/tests/parallel/cycle_nested_deep_conditional_changed.rs b/tests/parallel/cycle_nested_deep_conditional_changed.rs index 7c96d808d..51d506456 100644 --- a/tests/parallel/cycle_nested_deep_conditional_changed.rs +++ b/tests/parallel/cycle_nested_deep_conditional_changed.rs @@ -81,7 +81,7 @@ fn the_test() { use crate::sync; use salsa::Setter as _; sync::check(|| { - tracing::debug!("New run"); + tracing::debug!("Starting new run"); // This is a bit silly but it works around https://github.com/awslabs/shuttle/issues/192 static INITIALIZE: sync::Mutex> = @@ -108,36 +108,36 @@ fn the_test() { } let t1 = thread::spawn(move || { + let _span = tracing::info_span!("t1", thread_id = ?thread::current().id()).entered(); let (db, input) = get_db(|db, input| { query_a(db, input); }); - let _span = tracing::debug_span!("t1", thread_id = ?thread::current().id()).entered(); - query_a(&db, input) }); let t2 = thread::spawn(move || { + let _span = tracing::info_span!("t2", thread_id = ?thread::current().id()).entered(); let (db, input) = get_db(|db, input| { query_b(db, input); }); - let _span = tracing::debug_span!("t4", thread_id = ?thread::current().id()).entered(); query_b(&db, input) }); let t3 = thread::spawn(move || { + let _span = tracing::info_span!("t3", thread_id = ?thread::current().id()).entered(); let (db, input) = 
get_db(|db, input| { query_d(db, input); }); - let _span = tracing::debug_span!("t2", thread_id = ?thread::current().id()).entered(); query_d(&db, input) }); let t4 = thread::spawn(move || { + let _span = tracing::info_span!("t4", thread_id = ?thread::current().id()).entered(); + let (db, input) = get_db(|db, input| { query_e(db, input); }); - let _span = tracing::debug_span!("t3", thread_id = ?thread::current().id()).entered(); query_e(&db, input) }); diff --git a/tests/parallel/cycle_nested_deep_panic.rs b/tests/parallel/cycle_nested_deep_panic.rs new file mode 100644 index 000000000..8b89f362a --- /dev/null +++ b/tests/parallel/cycle_nested_deep_panic.rs @@ -0,0 +1,142 @@ +// Shuttle doesn't like panics inside of its runtime. +#![cfg(not(feature = "shuttle"))] + +//! Tests that salsa doesn't get stuck after a panic in a nested cycle function. + +use crate::sync::thread; +use crate::{Knobs, KnobsDatabase}; +use std::fmt; +use std::panic::catch_unwind; + +use salsa::CycleRecoveryAction; + +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, salsa::Update)] +struct CycleValue(u32); + +const MIN: CycleValue = CycleValue(0); +const MAX: CycleValue = CycleValue(3); + +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +fn query_a(db: &dyn KnobsDatabase) -> CycleValue { + query_b(db) +} + +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +fn query_b(db: &dyn KnobsDatabase) -> CycleValue { + let c_value = query_c(db); + CycleValue(c_value.0 + 1).min(MAX) +} + +#[salsa::tracked] +fn query_c(db: &dyn KnobsDatabase) -> CycleValue { + let d_value = query_d(db); + + if d_value > CycleValue(0) { + let e_value = query_e(db); + let b_value = query_b(db); + CycleValue(d_value.0.max(e_value.0).max(b_value.0)) + } else { + let a_value = query_a(db); + CycleValue(d_value.0.max(a_value.0)) + } +} + +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +fn query_d(db: &dyn KnobsDatabase) -> CycleValue { + query_b(db) +} + 
+#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +fn query_e(db: &dyn KnobsDatabase) -> CycleValue { + query_c(db) +} + +fn cycle_fn( + _db: &dyn KnobsDatabase, + _value: &CycleValue, + _count: u32, +) -> CycleRecoveryAction { + CycleRecoveryAction::Iterate +} + +fn initial(_db: &dyn KnobsDatabase) -> CycleValue { + MIN +} + +fn run() { + tracing::debug!("Starting new run"); + let db_t1 = Knobs::default(); + let db_t2 = db_t1.clone(); + let db_t3 = db_t1.clone(); + let db_t4 = db_t1.clone(); + + let t1 = thread::spawn(move || { + let _span = tracing::debug_span!("t1", thread_id = ?thread::current().id()).entered(); + catch_unwind(|| { + db_t1.wait_for(1); + query_a(&db_t1) + }) + }); + let t2 = thread::spawn(move || { + let _span = tracing::debug_span!("t2", thread_id = ?thread::current().id()).entered(); + catch_unwind(|| { + db_t2.wait_for(1); + + query_b(&db_t2) + }) + }); + let t3 = thread::spawn(move || { + let _span = tracing::debug_span!("t3", thread_id = ?thread::current().id()).entered(); + catch_unwind(|| { + db_t3.signal(2); + query_d(&db_t3) + }) + }); + + let r_t1 = t1.join().unwrap(); + let r_t2 = t2.join().unwrap(); + let r_t3 = t3.join().unwrap(); + + assert_is_set_cycle_error(r_t1); + assert_is_set_cycle_error(r_t2); + assert_is_set_cycle_error(r_t3); + + // Pulling the cycle again at a later point should still result in a panic. 
+ assert_is_set_cycle_error(catch_unwind(|| query_d(&db_t4))); +} + +#[test_log::test] +fn the_test() { + let count = if cfg!(miri) { 1 } else { 200 }; + + for _ in 0..count { + run() + } +} + +#[track_caller] +fn assert_is_set_cycle_error(result: Result>) +where + T: fmt::Debug, +{ + let err = result.expect_err("expected an error"); + + if let Some(message) = err.downcast_ref::<&str>() { + assert!( + message.contains("set cycle_fn/cycle_initial to fixpoint iterate"), + "Expected error message to contain 'set cycle_fn/cycle_initial to fixpoint iterate', but got: {}", + message + ); + } else if let Some(message) = err.downcast_ref::() { + assert!( + message.contains("set cycle_fn/cycle_initial to fixpoint iterate"), + "Expected error message to contain 'set cycle_fn/cycle_initial to fixpoint iterate', but got: {}", + message + ); + } else if err.downcast_ref::().is_some() { + // This is okay, because Salsa throws a Cancelled::PropagatedPanic when a panic occurs in a query + // that it blocks on. 
+ } else { + std::panic::resume_unwind(err); + } +} diff --git a/tests/parallel/cycle_nested_three_threads.rs b/tests/parallel/cycle_nested_three_threads.rs index c761a80f4..22232bd85 100644 --- a/tests/parallel/cycle_nested_three_threads.rs +++ b/tests/parallel/cycle_nested_three_threads.rs @@ -76,9 +76,18 @@ fn the_test() { let db_t2 = db_t1.clone(); let db_t3 = db_t1.clone(); - let t1 = thread::spawn(move || query_a(&db_t1)); - let t2 = thread::spawn(move || query_b(&db_t2)); - let t3 = thread::spawn(move || query_c(&db_t3)); + let t1 = thread::spawn(move || { + let _span = tracing::info_span!("t1", thread_id = ?thread::current().id()).entered(); + query_a(&db_t1) + }); + let t2 = thread::spawn(move || { + let _span = tracing::info_span!("t2", thread_id = ?thread::current().id()).entered(); + query_b(&db_t2) + }); + let t3 = thread::spawn(move || { + let _span = tracing::info_span!("t3", thread_id = ?thread::current().id()).entered(); + query_c(&db_t3) + }); let r_t1 = t1.join().unwrap(); let r_t2 = t2.join().unwrap(); diff --git a/tests/parallel/main.rs b/tests/parallel/main.rs index a764a864c..6bc89d2a2 100644 --- a/tests/parallel/main.rs +++ b/tests/parallel/main.rs @@ -9,6 +9,7 @@ mod cycle_ab_peeping_c; mod cycle_nested_deep; mod cycle_nested_deep_conditional; mod cycle_nested_deep_conditional_changed; +mod cycle_nested_deep_panic; mod cycle_nested_three_threads; mod cycle_nested_three_threads_changed; mod cycle_panic; @@ -33,7 +34,7 @@ pub(crate) mod sync { pub use shuttle::thread; pub fn check(f: impl Fn() + Send + Sync + 'static) { - shuttle::check_pct(f, 1000, 50); + shuttle::check_pct(f, 2500, 50); } } From 9cfe41c343ff43f258520967081e32caad467bc0 Mon Sep 17 00:00:00 2001 From: Ben Beasley Date: Sun, 19 Oct 2025 11:17:05 +0100 Subject: [PATCH 51/65] Fix missing license files in published macros/macro-rules crates (#1009) --- components/salsa-macro-rules/LICENSE-APACHE | 1 + components/salsa-macro-rules/LICENSE-MIT | 1 + 
components/salsa-macros/LICENSE-APACHE | 1 + components/salsa-macros/LICENSE-MIT | 1 + 4 files changed, 4 insertions(+) create mode 120000 components/salsa-macro-rules/LICENSE-APACHE create mode 120000 components/salsa-macro-rules/LICENSE-MIT create mode 120000 components/salsa-macros/LICENSE-APACHE create mode 120000 components/salsa-macros/LICENSE-MIT diff --git a/components/salsa-macro-rules/LICENSE-APACHE b/components/salsa-macro-rules/LICENSE-APACHE new file mode 120000 index 000000000..1cd601d0a --- /dev/null +++ b/components/salsa-macro-rules/LICENSE-APACHE @@ -0,0 +1 @@ +../../LICENSE-APACHE \ No newline at end of file diff --git a/components/salsa-macro-rules/LICENSE-MIT b/components/salsa-macro-rules/LICENSE-MIT new file mode 120000 index 000000000..b2cfbdc7b --- /dev/null +++ b/components/salsa-macro-rules/LICENSE-MIT @@ -0,0 +1 @@ +../../LICENSE-MIT \ No newline at end of file diff --git a/components/salsa-macros/LICENSE-APACHE b/components/salsa-macros/LICENSE-APACHE new file mode 120000 index 000000000..1cd601d0a --- /dev/null +++ b/components/salsa-macros/LICENSE-APACHE @@ -0,0 +1 @@ +../../LICENSE-APACHE \ No newline at end of file diff --git a/components/salsa-macros/LICENSE-MIT b/components/salsa-macros/LICENSE-MIT new file mode 120000 index 000000000..b2cfbdc7b --- /dev/null +++ b/components/salsa-macros/LICENSE-MIT @@ -0,0 +1 @@ +../../LICENSE-MIT \ No newline at end of file From a4113cd472539fdbc44a4e6a139d0124211da921 Mon Sep 17 00:00:00 2001 From: Lukas Wirth Date: Mon, 20 Oct 2025 13:32:46 +0200 Subject: [PATCH 52/65] Simplify `WaitGroup` implementation (#958) * Simplify `WaitGroup` implementation * Slightly cheaper `get_mut` Co-authored-by: Ibraheem Ahmed --------- Co-authored-by: Ibraheem Ahmed --- src/storage.rs | 41 ++++++++++++++++++++--------------------- src/table.rs | 5 ++++- src/views.rs | 6 +++++- src/zalsa_local.rs | 4 ++-- 4 files changed, 31 insertions(+), 25 deletions(-) diff --git a/src/storage.rs b/src/storage.rs index 
f63981e4f..443b53221 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -25,8 +25,6 @@ pub struct StorageHandle { impl Clone for StorageHandle { fn clone(&self) -> Self { - *self.coordinate.clones.lock() += 1; - Self { zalsa_impl: self.zalsa_impl.clone(), coordinate: CoordinateDrop(Arc::clone(&self.coordinate)), @@ -53,7 +51,7 @@ impl StorageHandle { Self { zalsa_impl: Arc::new(Zalsa::new::(event_callback, jars)), coordinate: CoordinateDrop(Arc::new(Coordinate { - clones: Mutex::new(1), + coordinate_lock: Mutex::default(), cvar: Default::default(), })), phantom: PhantomData, @@ -95,17 +93,6 @@ impl Drop for Storage { } } -struct Coordinate { - /// Counter of the number of clones of actor. Begins at 1. - /// Incremented when cloned, decremented when dropped. - clones: Mutex, - cvar: Condvar, -} - -// We cannot panic while holding a lock to `clones: Mutex` and therefore we cannot enter an -// inconsistent state. -impl RefUnwindSafe for Coordinate {} - impl Default for Storage { fn default() -> Self { Self::new(None) @@ -168,12 +155,15 @@ impl Storage { .zalsa_impl .event(&|| Event::new(EventKind::DidSetCancellationFlag)); - let mut clones = self.handle.coordinate.clones.lock(); - while *clones != 1 { - clones = self.handle.coordinate.cvar.wait(clones); - } - // The ref count on the `Arc` should now be 1 - let zalsa = Arc::get_mut(&mut self.handle.zalsa_impl).unwrap(); + let mut coordinate_lock = self.handle.coordinate.coordinate_lock.lock(); + let zalsa = loop { + if Arc::strong_count(&self.handle.zalsa_impl) == 1 { + // SAFETY: The strong count is 1, and we never create any weak pointers, + // so we have a unique reference. 
+ break unsafe { &mut *(Arc::as_ptr(&self.handle.zalsa_impl).cast_mut()) }; + } + coordinate_lock = self.handle.coordinate.cvar.wait(coordinate_lock); + }; // cancellation is done, so reset the flag zalsa.runtime_mut().reset_cancellation_flag(); zalsa @@ -260,6 +250,16 @@ impl Clone for Storage { } } +/// A simplified `WaitGroup`, this is used together with `Arc` as the actual counter +struct Coordinate { + coordinate_lock: Mutex<()>, + cvar: Condvar, +} + +// We cannot panic while holding a lock to `clones: Mutex` and therefore we cannot enter an +// inconsistent state. +impl RefUnwindSafe for Coordinate {} + struct CoordinateDrop(Arc); impl std::ops::Deref for CoordinateDrop { @@ -272,7 +272,6 @@ impl std::ops::Deref for CoordinateDrop { impl Drop for CoordinateDrop { fn drop(&mut self) { - *self.0.clones.lock() -= 1; self.0.cvar.notify_all(); } } diff --git a/src/table.rs b/src/table.rs index 53cf10cce..5505c1c05 100644 --- a/src/table.rs +++ b/src/table.rs @@ -252,7 +252,10 @@ impl Table { } let allocated_idx = self.push_page::(ingredient, memo_types.clone()); - assert_eq!(allocated_idx, page_idx); + assert_eq!( + allocated_idx, page_idx, + "allocated index does not match requested index" + ); } }; } diff --git a/src/views.rs b/src/views.rs index d449779c3..d58f349f0 100644 --- a/src/views.rs +++ b/src/views.rs @@ -108,7 +108,11 @@ impl Views { &self, func: fn(NonNull) -> NonNull, ) -> &DatabaseDownCaster { - assert_eq!(self.source_type_id, TypeId::of::()); + assert_eq!( + self.source_type_id, + TypeId::of::(), + "mismatched source type" + ); let target_type_id = TypeId::of::(); if let Some((_, caster)) = self .view_casters diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs index 39d0c489c..7b0399178 100644 --- a/src/zalsa_local.rs +++ b/src/zalsa_local.rs @@ -1173,7 +1173,7 @@ impl ActiveQueryGuard<'_> { unsafe { self.local_state.with_query_stack_unchecked_mut(|stack| { #[cfg(debug_assertions)] - assert_eq!(stack.len(), self.push_len); + 
assert_eq!(stack.len(), self.push_len, "mismatched push and pop"); let frame = stack.last_mut().unwrap(); frame.tracked_struct_ids_mut().seed(tracked_struct_ids); }) @@ -1195,7 +1195,7 @@ impl ActiveQueryGuard<'_> { unsafe { self.local_state.with_query_stack_unchecked_mut(|stack| { #[cfg(debug_assertions)] - assert_eq!(stack.len(), self.push_len); + assert_eq!(stack.len(), self.push_len, "mismatched push and pop"); let frame = stack.last_mut().unwrap(); frame.seed_iteration(durability, changed_at, edges, untracked_read, tracked_ids); }) From ffa811dca2352c6c54e10346df363fdc0d51dd46 Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Wed, 22 Oct 2025 13:41:56 +0200 Subject: [PATCH 53/65] Remove experimental parallel feature (#1013) --- src/lib.rs | 5 - src/parallel.rs | 91 ------------ tests/parallel/main.rs | 3 - tests/parallel/parallel_cancellation.rs | 67 --------- tests/parallel/parallel_join.rs | 176 ------------------------ tests/parallel/parallel_map.rs | 100 -------------- 6 files changed, 442 deletions(-) delete mode 100644 src/parallel.rs delete mode 100644 tests/parallel/parallel_cancellation.rs delete mode 100644 tests/parallel/parallel_join.rs delete mode 100644 tests/parallel/parallel_map.rs diff --git a/src/lib.rs b/src/lib.rs index 8ab47379d..8c50c9052 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -37,11 +37,6 @@ mod zalsa_local; #[cfg(not(feature = "inventory"))] mod nonce; -#[cfg(feature = "rayon")] -mod parallel; - -#[cfg(feature = "rayon")] -pub use parallel::{join, par_map}; #[cfg(feature = "macros")] pub use salsa_macros::{accumulator, db, input, interned, tracked, Supertype, Update}; diff --git a/src/parallel.rs b/src/parallel.rs deleted file mode 100644 index 8a0bde655..000000000 --- a/src/parallel.rs +++ /dev/null @@ -1,91 +0,0 @@ -use rayon::iter::{FromParallelIterator, IntoParallelIterator, ParallelIterator}; - -use crate::{database::RawDatabase, views::DatabaseDownCaster, Database}; - -pub fn par_map(db: &Db, inputs: impl 
IntoParallelIterator, op: F) -> C -where - Db: Database + ?Sized + Send, - F: Fn(&Db, T) -> R + Sync + Send, - T: Send, - R: Send + Sync, - C: FromParallelIterator, -{ - let views = db.zalsa().views(); - let caster = &views.downcaster_for::(); - let db_caster = &views.downcaster_for::(); - inputs - .into_par_iter() - .map_with( - DbForkOnClone(db.fork_db(), caster, db_caster), - |db, element| op(db.as_view(), element), - ) - .collect() -} - -struct DbForkOnClone<'views, Db: Database + ?Sized>( - RawDatabase<'static>, - &'views DatabaseDownCaster, - &'views DatabaseDownCaster, -); - -// SAFETY: `T: Send` -> `&own T: Send`, `DbForkOnClone` is an owning pointer -unsafe impl Send for DbForkOnClone<'_, Db> {} - -impl DbForkOnClone<'_, Db> { - fn as_view(&self) -> &Db { - // SAFETY: The downcaster ensures that the pointer is valid for the lifetime of the view. - unsafe { self.1.downcast_unchecked(self.0) } - } -} - -impl Drop for DbForkOnClone<'_, Db> { - fn drop(&mut self) { - // SAFETY: `caster` is derived from a `db` fitting for our database clone - let db = unsafe { self.1.downcast_mut_unchecked(self.0) }; - // SAFETY: `db` has been box allocated and leaked by `fork_db` - _ = unsafe { Box::from_raw(db) }; - } -} - -impl Clone for DbForkOnClone<'_, Db> { - fn clone(&self) -> Self { - DbForkOnClone( - // SAFETY: `caster` is derived from a `db` fitting for our database clone - unsafe { self.2.downcast_unchecked(self.0) }.fork_db(), - self.1, - self.2, - ) - } -} - -pub fn join(db: &Db, a: A, b: B) -> (RA, RB) -where - A: FnOnce(&Db) -> RA + Send, - B: FnOnce(&Db) -> RB + Send, - RA: Send, - RB: Send, -{ - #[derive(Copy, Clone)] - struct AssertSend(T); - // SAFETY: We send owning pointers over, which are Send, given the `Db` type parameter above is Send - unsafe impl Send for AssertSend {} - - let caster = &db.zalsa().views().downcaster_for::(); - // we need to fork eagerly, as `rayon::join_context` gives us no option to tell whether we get - // moved to another thread 
before the closure is executed - let db_a = AssertSend(db.fork_db()); - let db_b = AssertSend(db.fork_db()); - let res = rayon::join( - // SAFETY: `caster` is derived from a `db` fitting for our database clone - move || a(unsafe { caster.downcast_unchecked({ db_a }.0) }), - // SAFETY: `caster` is derived from a `db` fitting for our database clone - move || b(unsafe { caster.downcast_unchecked({ db_b }.0) }), - ); - - // SAFETY: `db` has been box allocated and leaked by `fork_db` - // FIXME: Clean this mess up, RAII - _ = unsafe { Box::from_raw(caster.downcast_mut_unchecked(db_a.0)) }; - // SAFETY: `db` has been box allocated and leaked by `fork_db` - _ = unsafe { Box::from_raw(caster.downcast_mut_unchecked(db_b.0)) }; - res -} diff --git a/tests/parallel/main.rs b/tests/parallel/main.rs index 6bc89d2a2..859d14f47 100644 --- a/tests/parallel/main.rs +++ b/tests/parallel/main.rs @@ -14,9 +14,6 @@ mod cycle_nested_three_threads; mod cycle_nested_three_threads_changed; mod cycle_panic; mod cycle_provisional_depending_on_itself; -mod parallel_cancellation; -mod parallel_join; -mod parallel_map; #[cfg(not(feature = "shuttle"))] pub(crate) mod sync { diff --git a/tests/parallel/parallel_cancellation.rs b/tests/parallel/parallel_cancellation.rs deleted file mode 100644 index a82437d54..000000000 --- a/tests/parallel/parallel_cancellation.rs +++ /dev/null @@ -1,67 +0,0 @@ -// Shuttle doesn't like panics inside of its runtime. -#![cfg(not(feature = "shuttle"))] - -//! Test for thread cancellation. -use salsa::{Cancelled, Setter}; - -use crate::setup::{Knobs, KnobsDatabase}; - -#[salsa::input(debug)] -struct MyInput { - field: i32, -} - -#[salsa::tracked] -fn a1(db: &dyn KnobsDatabase, input: MyInput) -> MyInput { - db.signal(1); - db.wait_for(2); - dummy(db, input) -} - -#[salsa::tracked] -fn dummy(_db: &dyn KnobsDatabase, _input: MyInput) -> MyInput { - panic!("should never get here!") -} - -// Cancellation signalling test -// -// The pattern is as follows. 
-// -// Thread A Thread B -// -------- -------- -// a1 -// | wait for stage 1 -// signal stage 1 set input, triggers cancellation -// wait for stage 2 (blocks) triggering cancellation sends stage 2 -// | -// (unblocked) -// dummy -// panics - -#[test] -fn execute() { - let mut db = Knobs::default(); - - let input = MyInput::new(&db, 1); - - let thread_a = std::thread::spawn({ - let db = db.clone(); - move || a1(&db, input) - }); - - db.signal_on_did_cancel(2); - input.set_field(&mut db).to(2); - - // Assert thread A *should* was cancelled - let cancelled = thread_a - .join() - .unwrap_err() - .downcast::() - .unwrap(); - - // and inspect the output - expect_test::expect![[r#" - PendingWrite - "#]] - .assert_debug_eq(&cancelled); -} diff --git a/tests/parallel/parallel_join.rs b/tests/parallel/parallel_join.rs deleted file mode 100644 index f39e9a5fc..000000000 --- a/tests/parallel/parallel_join.rs +++ /dev/null @@ -1,176 +0,0 @@ -#![cfg(all(feature = "rayon", not(feature = "shuttle")))] - -// test for rayon-like join interactions. 
- -use std::sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, -}; - -use salsa::{Cancelled, Database, Setter, Storage}; - -use crate::signal::Signal; - -#[salsa::input] -struct ParallelInput { - a: u32, - b: u32, -} - -#[salsa::tracked] -fn tracked_fn(db: &dyn salsa::Database, input: ParallelInput) -> (u32, u32) { - salsa::join(db, |db| input.a(db) + 1, |db| input.b(db) - 1) -} - -#[salsa::tracked] -fn a1(db: &dyn KnobsDatabase, input: ParallelInput) -> (u32, u32) { - db.signal(1); - salsa::join( - db, - |db| { - db.wait_for(2); - input.a(db) + dummy(db) - }, - |db| { - db.wait_for(2); - input.b(db) + dummy(db) - }, - ) -} - -#[salsa::tracked] -fn dummy(_db: &dyn KnobsDatabase) -> u32 { - panic!("should never get here!") -} - -#[test] -#[cfg_attr(miri, ignore)] -fn execute() { - let db = salsa::DatabaseImpl::new(); - - let input = ParallelInput::new(&db, 10, 20); - - tracked_fn(&db, input); -} - -// we expect this to panic, as `salsa::par_map` needs to be called from a query. -#[test] -#[cfg_attr(miri, ignore)] -#[should_panic] -fn direct_calls_panic() { - let db = salsa::DatabaseImpl::new(); - - let input = ParallelInput::new(&db, 10, 20); - let (_, _) = salsa::join(&db, |db| input.a(db) + 1, |db| input.b(db) - 1); -} - -// Cancellation signalling test -// -// The pattern is as follows. 
-// -// Thread A Thread B -// -------- -------- -// a1 -// | wait for stage 1 -// signal stage 1 set input, triggers cancellation -// wait for stage 2 (blocks) triggering cancellation sends stage 2 -// | -// (unblocked) -// dummy -// panics - -#[test] -#[cfg_attr(miri, ignore)] -fn execute_cancellation() { - let mut db = Knobs::default(); - - let input = ParallelInput::new(&db, 10, 20); - - let thread_a = std::thread::spawn({ - let db = db.clone(); - move || a1(&db, input) - }); - - db.signal_on_did_cancel(2); - input.set_a(&mut db).to(30); - - // Assert thread A was cancelled - let cancelled = thread_a - .join() - .unwrap_err() - .downcast::() - .unwrap(); - - // and inspect the output - expect_test::expect![[r#" - PendingWrite - "#]] - .assert_debug_eq(&cancelled); -} - -#[salsa::db] -trait KnobsDatabase: Database { - fn signal(&self, stage: usize); - fn wait_for(&self, stage: usize); -} - -/// A copy of `tests\parallel\setup.rs` that does not assert, as the assert is incorrect for the -/// purposes of this test. 
-#[salsa::db] -struct Knobs { - storage: salsa::Storage, - signal: Arc, - signal_on_did_cancel: Arc, -} - -impl Knobs { - pub fn signal_on_did_cancel(&self, stage: usize) { - self.signal_on_did_cancel.store(stage, Ordering::Release); - } -} - -impl Clone for Knobs { - #[track_caller] - fn clone(&self) -> Self { - Self { - storage: self.storage.clone(), - signal: self.signal.clone(), - signal_on_did_cancel: self.signal_on_did_cancel.clone(), - } - } -} - -impl Default for Knobs { - fn default() -> Self { - let signal = >::default(); - let signal_on_did_cancel = Arc::new(AtomicUsize::new(0)); - - Self { - storage: Storage::new(Some(Box::new({ - let signal = signal.clone(); - let signal_on_did_cancel = signal_on_did_cancel.clone(); - move |event| { - if let salsa::EventKind::DidSetCancellationFlag = event.kind { - signal.signal(signal_on_did_cancel.load(Ordering::Acquire)); - } - } - }))), - signal, - signal_on_did_cancel, - } - } -} - -#[salsa::db] -impl salsa::Database for Knobs {} - -#[salsa::db] -impl KnobsDatabase for Knobs { - fn signal(&self, stage: usize) { - self.signal.signal(stage); - } - - fn wait_for(&self, stage: usize) { - self.signal.wait_for(stage); - } -} diff --git a/tests/parallel/parallel_map.rs b/tests/parallel/parallel_map.rs deleted file mode 100644 index f05b73363..000000000 --- a/tests/parallel/parallel_map.rs +++ /dev/null @@ -1,100 +0,0 @@ -#![cfg(all(feature = "rayon", not(feature = "shuttle")))] -// test for rayon-like parallel map interactions. 
- -use salsa::{Cancelled, Setter}; - -use crate::setup::{Knobs, KnobsDatabase}; - -#[salsa::input] -struct ParallelInput { - field: Vec, -} - -#[salsa::tracked] -fn tracked_fn(db: &dyn salsa::Database, input: ParallelInput) -> Vec { - salsa::par_map(db, input.field(db), |_db, field| field + 1) -} - -#[salsa::tracked] -fn a1(db: &dyn KnobsDatabase, input: ParallelInput) -> Vec { - db.signal(1); - salsa::par_map(db, input.field(db), |db, field| { - db.wait_for(2); - field + dummy(db) - }) -} - -#[salsa::tracked] -fn dummy(_db: &dyn KnobsDatabase) -> u32 { - panic!("should never get here!") -} - -#[test] -#[cfg_attr(miri, ignore)] -fn execute() { - let db = salsa::DatabaseImpl::new(); - - let counts = (1..=10).collect::>(); - let input = ParallelInput::new(&db, counts); - - tracked_fn(&db, input); -} - -// we expect this to panic, as `salsa::par_map` needs to be called from a query. -#[test] -#[cfg_attr(miri, ignore)] -#[should_panic] -fn direct_calls_panic() { - let db = salsa::DatabaseImpl::new(); - - let counts = (1..=10).collect::>(); - let input = ParallelInput::new(&db, counts); - let _: Vec = salsa::par_map(&db, input.field(&db), |_db, field| field + 1); -} - -// Cancellation signalling test -// -// The pattern is as follows. 
-// -// Thread A Thread B -// -------- -------- -// a1 -// | wait for stage 1 -// signal stage 1 set input, triggers cancellation -// wait for stage 2 (blocks) triggering cancellation sends stage 2 -// | -// (unblocked) -// dummy -// panics - -#[test] -#[cfg_attr(miri, ignore)] -fn execute_cancellation() { - let mut db = Knobs::default(); - - let counts = (1..=10).collect::>(); - let input = ParallelInput::new(&db, counts); - - let thread_a = std::thread::spawn({ - let db = db.clone(); - move || a1(&db, input) - }); - - let counts = (2..=20).collect::>(); - - db.signal_on_did_cancel(2); - input.set_field(&mut db).to(counts); - - // Assert thread A *should* was cancelled - let cancelled = thread_a - .join() - .unwrap_err() - .downcast::() - .unwrap(); - - // and inspect the output - expect_test::expect![[r#" - PendingWrite - "#]] - .assert_debug_eq(&cancelled); -} From 16d51d63d515aca6b646529a73c037633b7c1ec4 Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Thu, 23 Oct 2025 08:15:01 +0200 Subject: [PATCH 54/65] Fix hangs in multithreaded fixpoint iteration (#1010) * Fix race condition between releasing a transferred query's lock and the same query blocking on the outer head in `provisional_retry` * Fix infinite loop in `provisional_retry --- src/active_query.rs | 4 - src/function/execute.rs | 19 +++- src/function/fetch.rs | 46 +------- src/function/maybe_changed_after.rs | 4 +- src/function/memo.rs | 89 ++------------- src/function/sync.rs | 50 ++++++--- src/runtime.rs | 10 +- src/runtime/dependency_graph.rs | 68 +++++++---- tests/parallel/cycle_iteration_mismatch.rs | 124 +++++++++++++++++++++ tests/parallel/main.rs | 1 + 10 files changed, 243 insertions(+), 172 deletions(-) create mode 100644 tests/parallel/cycle_iteration_mismatch.rs diff --git a/src/active_query.rs b/src/active_query.rs index d830fece1..bb5987fcd 100644 --- a/src/active_query.rs +++ b/src/active_query.rs @@ -158,10 +158,6 @@ impl ActiveQuery { } } - pub(super) fn iteration_count(&self) -> 
IterationCount { - self.iteration_count - } - pub(crate) fn tracked_struct_ids(&self) -> &IdentityMap { &self.tracked_struct_ids } diff --git a/src/function/execute.rs b/src/function/execute.rs index 67f76e145..aa4339bef 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -28,6 +28,11 @@ where /// * `db`, the database. /// * `active_query`, the active stack frame for the query to execute. /// * `opt_old_memo`, the older memo, if any existed. Used for backdating. + /// + /// # Returns + /// The newly computed memo or `None` if this query is part of a larger cycle + /// and `execute` blocked on a cycle head running on another thread. In this case, + /// the memo is potentially outdated and needs to be refetched. #[inline(never)] pub(super) fn execute<'db>( &'db self, @@ -35,7 +40,7 @@ where mut claim_guard: ClaimGuard<'db>, zalsa_local: &'db ZalsaLocal, opt_old_memo: Option<&Memo<'db, C>>, - ) -> &'db Memo<'db, C> { + ) -> Option<&'db Memo<'db, C>> { let database_key_index = claim_guard.database_key_index(); let zalsa = claim_guard.zalsa(); @@ -80,7 +85,7 @@ where // We need to mark the memo as finalized so other cycle participants that have fallbacks // will be verified (participants that don't have fallbacks will not be verified). memo.revisions.verified_final.store(true, Ordering::Release); - return memo; + return Some(memo); } // If we're in the middle of a cycle and we have a fallback, use it instead. 
@@ -125,7 +130,7 @@ where self.diff_outputs(zalsa, database_key_index, old_memo, &completed_query); } - self.insert_memo( + let memo = self.insert_memo( zalsa, id, Memo::new( @@ -134,7 +139,13 @@ where completed_query.revisions, ), memo_ingredient_index, - ) + ); + + if claim_guard.drop() { + None + } else { + Some(memo) + } } fn execute_maybe_iterate<'db>( diff --git a/src/function/fetch.rs b/src/function/fetch.rs index ef42708a7..a3f3705f4 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -58,20 +58,11 @@ where id: Id, ) -> &'db Memo<'db, C> { let memo_ingredient_index = self.memo_ingredient_index(zalsa, id); - let mut retry_count = 0; + loop { if let Some(memo) = self .fetch_hot(zalsa, id, memo_ingredient_index) - .or_else(|| { - self.fetch_cold_with_retry( - zalsa, - zalsa_local, - db, - id, - memo_ingredient_index, - &mut retry_count, - ) - }) + .or_else(|| self.fetch_cold(zalsa, zalsa_local, db, id, memo_ingredient_index)) { return memo; } @@ -104,33 +95,6 @@ where } } - fn fetch_cold_with_retry<'db>( - &'db self, - zalsa: &'db Zalsa, - zalsa_local: &'db ZalsaLocal, - db: &'db C::DbView, - id: Id, - memo_ingredient_index: MemoIngredientIndex, - retry_count: &mut u32, - ) -> Option<&'db Memo<'db, C>> { - let memo = self.fetch_cold(zalsa, zalsa_local, db, id, memo_ingredient_index)?; - - // If we get back a provisional cycle memo, and it's provisional on any cycle heads - // that are claimed by a different thread, we can't propagate the provisional memo - // any further (it could escape outside the cycle); we need to block on the other - // thread completing fixpoint iteration of the cycle, and then we can re-query for - // our no-longer-provisional memo. - // That is only correct for fixpoint cycles, though: `FallbackImmediate` cycles - // never have provisional entries. 
- if C::CYCLE_STRATEGY == CycleRecoveryStrategy::FallbackImmediate - || !memo.provisional_retry(zalsa, zalsa_local, self.database_key_index(id), retry_count) - { - Some(memo) - } else { - None - } - } - fn fetch_cold<'db>( &'db self, zalsa: &'db Zalsa, @@ -151,7 +115,7 @@ where if let Some(memo) = memo { if memo.value.is_some() { - memo.block_on_heads(zalsa, zalsa_local); + memo.block_on_heads(zalsa); } } } @@ -212,9 +176,7 @@ where } } - let memo = self.execute(db, claim_guard, zalsa_local, opt_old_memo); - - Some(memo) + self.execute(db, claim_guard, zalsa_local, opt_old_memo) } #[cold] diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index 698285055..62839e865 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -228,7 +228,7 @@ where // `in_cycle` tracks if the enclosing query is in a cycle. `deep_verify.cycle_heads` tracks // if **this query** encountered a cycle (which means there's some provisional value somewhere floating around). if old_memo.value.is_some() && !cycle_heads.has_any() { - let memo = self.execute(db, claim_guard, zalsa_local, Some(old_memo)); + let memo = self.execute(db, claim_guard, zalsa_local, Some(old_memo))?; let changed_at = memo.revisions.changed_at; // Always assume that a provisional value has changed. 
@@ -500,7 +500,7 @@ where return on_stack; } - let cycle_heads_iter = TryClaimCycleHeadsIter::new(zalsa, zalsa_local, cycle_heads); + let cycle_heads_iter = TryClaimCycleHeadsIter::new(zalsa, cycle_heads); for cycle_head in cycle_heads_iter { match cycle_head { diff --git a/src/function/memo.rs b/src/function/memo.rs index 302ca73c3..2e84bc04f 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -14,7 +14,7 @@ use crate::runtime::Running; use crate::sync::atomic::Ordering; use crate::table::memo::MemoTableWithTypesMut; use crate::zalsa::{MemoIngredientIndex, Zalsa}; -use crate::zalsa_local::{QueryOriginRef, QueryRevisions, ZalsaLocal}; +use crate::zalsa_local::{QueryOriginRef, QueryRevisions}; use crate::{Event, EventKind, Id, Revision}; impl IngredientImpl { @@ -132,50 +132,12 @@ impl<'db, C: Configuration> Memo<'db, C> { !self.revisions.verified_final.load(Ordering::Relaxed) } - /// Invoked when `refresh_memo` is about to return a memo to the caller; if that memo is - /// provisional, and its cycle head is claimed by another thread, we need to wait for that - /// other thread to complete the fixpoint iteration, and then retry fetching our own memo. - /// - /// Return `true` if the caller should retry, `false` if the caller should go ahead and return - /// this memo to the caller. - #[inline(always)] - pub(super) fn provisional_retry( - &self, - zalsa: &Zalsa, - zalsa_local: &ZalsaLocal, - database_key_index: DatabaseKeyIndex, - retry_count: &mut u32, - ) -> bool { - if self.block_on_heads(zalsa, zalsa_local) { - // If we get here, we are a provisional value of - // the cycle head (either initial value, or from a later iteration) and should be - // returned to caller to allow fixpoint iteration to proceed. - false - } else { - assert!( - *retry_count <= 20000, - "Provisional memo retry limit exceeded for {database_key_index:?}; \ - this usually indicates a bug in salsa's cycle caching/locking. 
\ - (retried {retry_count} times)", - ); - - *retry_count += 1; - - // all our cycle heads are complete; re-fetch - // and we should get a non-provisional memo. - crate::tracing::debug!( - "Retrying provisional memo {database_key_index:?} after awaiting cycle heads." - ); - true - } - } - /// Blocks on all cycle heads (recursively) that this memo depends on. /// /// Returns `true` if awaiting all cycle heads results in a cycle. This means, they're all waiting /// for us to make progress. #[inline(always)] - pub(super) fn block_on_heads(&self, zalsa: &Zalsa, zalsa_local: &ZalsaLocal) -> bool { + pub(super) fn block_on_heads(&self, zalsa: &Zalsa) -> bool { // IMPORTANT: If you make changes to this function, make sure to run `cycle_nested_deep` with // shuttle with at least 10k iterations. @@ -184,16 +146,12 @@ impl<'db, C: Configuration> Memo<'db, C> { return true; } - return block_on_heads_cold(zalsa, zalsa_local, cycle_heads); + return block_on_heads_cold(zalsa, cycle_heads); #[inline(never)] - fn block_on_heads_cold( - zalsa: &Zalsa, - zalsa_local: &ZalsaLocal, - heads: &CycleHeads, - ) -> bool { + fn block_on_heads_cold(zalsa: &Zalsa, heads: &CycleHeads) -> bool { let _entered = crate::tracing::debug_span!("block_on_heads").entered(); - let cycle_heads = TryClaimCycleHeadsIter::new(zalsa, zalsa_local, heads); + let cycle_heads = TryClaimCycleHeadsIter::new(zalsa, heads); let mut all_cycles = true; for claim_result in cycle_heads { @@ -447,6 +405,7 @@ mod persistence { } } +#[derive(Debug)] pub(super) enum TryClaimHeadsResult<'me> { /// Claiming the cycle head results in a cycle. Cycle { @@ -465,19 +424,15 @@ pub(super) enum TryClaimHeadsResult<'me> { /// Iterator to try claiming the transitive cycle heads of a memo. 
pub(super) struct TryClaimCycleHeadsIter<'a> { zalsa: &'a Zalsa, - zalsa_local: &'a ZalsaLocal, + cycle_heads: CycleHeadsIterator<'a>, } impl<'a> TryClaimCycleHeadsIter<'a> { - pub(super) fn new( - zalsa: &'a Zalsa, - zalsa_local: &'a ZalsaLocal, - cycle_heads: &'a CycleHeads, - ) -> Self { + pub(super) fn new(zalsa: &'a Zalsa, cycle_heads: &'a CycleHeads) -> Self { Self { zalsa, - zalsa_local, + cycle_heads: cycle_heads.iter(), } } @@ -488,31 +443,7 @@ impl<'me> Iterator for TryClaimCycleHeadsIter<'me> { fn next(&mut self) -> Option { let head = self.cycle_heads.next()?; - let head_database_key = head.database_key_index; - let head_iteration_count = head.iteration_count.load(); - - // The most common case is that the head is already in the query stack. So let's check that first. - // SAFETY: We do not access the query stack reentrantly. - if let Some(current_iteration_count) = unsafe { - self.zalsa_local.with_query_stack_unchecked(|stack| { - stack - .iter() - .rev() - .find(|query| query.database_key_index == head_database_key) - .map(|query| query.iteration_count()) - }) - } { - crate::tracing::trace!( - "Waiting for {head_database_key:?} results in a cycle (because it is already in the query stack)" - ); - return Some(TryClaimHeadsResult::Cycle { - head_iteration_count, - memo_iteration_count: current_iteration_count, - verified_at: self.zalsa.current_revision(), - }); - } - let head_key_index = head_database_key.key_index(); let ingredient = self .zalsa @@ -543,7 +474,7 @@ impl<'me> Iterator for TryClaimCycleHeadsIter<'me> { Some(TryClaimHeadsResult::Cycle { memo_iteration_count: current_iteration_count, - head_iteration_count, + head_iteration_count: head.iteration_count.load(), verified_at, }) } diff --git a/src/function/sync.rs b/src/function/sync.rs index 97a36262c..02f1bffd0 100644 --- a/src/function/sync.rs +++ b/src/function/sync.rs @@ -273,11 +273,11 @@ impl<'me> ClaimGuard<'me> { runtime.undo_transfer_lock(database_key_index); } + 
runtime.unblock_queries_blocked_on(database_key_index, wait_result); + if is_transfer_target { runtime.unblock_transferred_queries_owned_by(database_key_index, wait_result); } - - runtime.unblock_queries_blocked_on(database_key_index, wait_result); } #[cold] @@ -299,7 +299,7 @@ impl<'me> ClaimGuard<'me> { #[cold] #[inline(never)] - pub(crate) fn transfer(&self, new_owner: DatabaseKeyIndex) { + pub(crate) fn transfer(&self, new_owner: DatabaseKeyIndex) -> bool { let owner_ingredient = self.zalsa.lookup_ingredient(new_owner.ingredient_index()); // Get the owning thread of `new_owner`. @@ -333,22 +333,27 @@ impl<'me> ClaimGuard<'me> { .get_mut(&self.key_index) .expect("key should only be claimed/released once"); - self.zalsa - .runtime() - .transfer_lock(self_key, new_owner, new_owner_thread_id); - *id = SyncOwner::Transferred; *claimed_twice = false; + + self.zalsa + .runtime() + .transfer_lock(self_key, new_owner, new_owner_thread_id, syncs) } -} -impl Drop for ClaimGuard<'_> { - fn drop(&mut self) { - if thread::panicking() { - self.release_panicking(); - return; - } + /// Drops the claim on the memo. + /// + /// Returns `true` if the lock was transferred to another query and + /// this thread blocked waiting for the new owner's lock to be released. + /// In that case, any computed memo need to be refetched because they may have + /// changed since `drop` was called. 
+ pub(crate) fn drop(mut self) -> bool { + let refetch = self.drop_impl(); + std::mem::forget(self); + refetch + } + fn drop_impl(&mut self) -> bool { match self.mode { ReleaseMode::Default => { let mut syncs = self.sync_table.syncs.lock(); @@ -357,17 +362,28 @@ impl Drop for ClaimGuard<'_> { .expect("key should only be claimed/released once"); self.release(state, WaitResult::Completed); + false } ReleaseMode::SelfOnly => { self.release_self(); + false } - ReleaseMode::TransferTo(new_owner) => { - self.transfer(new_owner); - } + ReleaseMode::TransferTo(new_owner) => self.transfer(new_owner), } } } +impl Drop for ClaimGuard<'_> { + fn drop(&mut self) { + if thread::panicking() { + self.release_panicking(); + return; + } + + self.drop_impl(); + } +} + impl std::fmt::Debug for SyncTable { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("SyncTable").finish() diff --git a/src/runtime.rs b/src/runtime.rs index 670d6d62f..48caf53ec 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -383,13 +383,17 @@ impl Runtime { query: DatabaseKeyIndex, new_owner_key: DatabaseKeyIndex, new_owner_id: SyncOwner, - ) { - self.dependency_graph.lock().transfer_lock( + guard: SyncGuard, + ) -> bool { + let dg = self.dependency_graph.lock(); + DependencyGraph::transfer_lock( + dg, query, thread::current().id(), new_owner_key, new_owner_id, - ); + guard, + ) } #[cfg(feature = "persistence")] diff --git a/src/runtime/dependency_graph.rs b/src/runtime/dependency_graph.rs index 403f7c544..9b8cbe221 100644 --- a/src/runtime/dependency_graph.rs +++ b/src/runtime/dependency_graph.rs @@ -3,7 +3,7 @@ use std::pin::Pin; use rustc_hash::FxHashMap; use smallvec::SmallVec; -use crate::function::SyncOwner; +use crate::function::{SyncGuard, SyncOwner}; use crate::key::DatabaseKeyIndex; use crate::runtime::dependency_graph::edge::EdgeCondvar; use crate::runtime::WaitResult; @@ -199,18 +199,23 @@ impl DependencyGraph { pub(super) fn thread_id_of_transferred_query( &self, 
database_key: DatabaseKeyIndex, - ignore: Option, + skip_over: Option, ) -> Option { let &(mut resolved_thread, owner) = self.transferred.get(&database_key)?; let mut current_owner = owner; while let Some(&(next_thread, next_key)) = self.transferred.get(¤t_owner) { - if Some(next_key) == ignore { - break; + current_owner = next_key; + + // Ignore the `skip_over` key. E.g. if we have `a -> b -> c` and we want to resolve `a` but are transferring `b` to `c`, then + // we don't want to resolve `a` to the owner of `c`. But for `a -> c -> b`, we want resolve `a` to the owner of `c` and not `b` + // (because `b` will be owned by `a`). + if Some(next_key) == skip_over { + continue; } + resolved_thread = next_thread; - current_owner = next_key; } Some(resolved_thread) @@ -218,29 +223,36 @@ impl DependencyGraph { /// Modifies the graph so that the lock on `query` (currently owned by `current_thread`) is /// transferred to `new_owner` (which is owned by `new_owner_id`). + /// + /// Note, this function will block if `new_owner` runs on a different thread, unless `new_owner` is blocked + /// on current thread after transferring the query ownership. + /// + /// Returns `true` if the transfer blocked on `new_owner` (in which case it might be necessary to refetch any previously computed memos). pub(super) fn transfer_lock( - &mut self, + mut me: MutexGuard, query: DatabaseKeyIndex, current_thread: ThreadId, new_owner: DatabaseKeyIndex, new_owner_id: SyncOwner, - ) { + guard: SyncGuard, + ) -> bool { + let dg = &mut *me; let new_owner_thread = match new_owner_id { SyncOwner::Thread(thread) => thread, SyncOwner::Transferred => { // Skip over `query` to skip over any existing mapping from `new_owner` to `query` that may // exist from previous transfers. 
- self.thread_id_of_transferred_query(new_owner, Some(query)) + dg.thread_id_of_transferred_query(new_owner, Some(query)) .expect("new owner should be blocked on `query`") } }; debug_assert!( - new_owner_thread == current_thread || self.depends_on(new_owner_thread, current_thread), + new_owner_thread == current_thread || dg.depends_on(new_owner_thread, current_thread), "new owner {new_owner:?} ({new_owner_thread:?}) must be blocked on {query:?} ({current_thread:?})" ); - let thread_changed = match self.transferred.entry(query) { + let thread_changed = match dg.transferred.entry(query) { std::collections::hash_map::Entry::Vacant(entry) => { // Transfer `c -> b` and there's no existing entry for `c`. entry.insert((new_owner_thread, new_owner)); @@ -249,7 +261,7 @@ impl DependencyGraph { std::collections::hash_map::Entry::Occupied(mut entry) => { // If we transfer to the same owner as before, return immediately as this is a no-op. if entry.get() == &(new_owner_thread, new_owner) { - return; + return false; } // `Transfer `c -> b` after a previous `c -> d` mapping. @@ -257,7 +269,7 @@ impl DependencyGraph { let &(old_owner_thread, old_owner) = entry.get(); // For the example below, remove `d` from `b`'s dependents.` - self.transferred_dependents + dg.transferred_dependents .get_mut(&old_owner) .unwrap() .remove(&query); @@ -273,10 +285,9 @@ impl DependencyGraph { // d / // ``` // - // // A cycle between transfers can occur when a later iteration has a different outer most query than // a previous iteration. The second iteration then hits `cycle_initial` for a different head, (e.g. for `c` where it previously was `d`). - let mut last_segment = self.transferred.entry(new_owner); + let mut last_segment = dg.transferred.entry(new_owner); while let std::collections::hash_map::Entry::Occupied(mut entry) = last_segment { let source = *entry.key(); @@ -289,19 +300,19 @@ impl DependencyGraph { ); // Remove `a` from the dependents of `d` and remove the mapping from `a -> d`. 
- self.transferred_dependents + dg.transferred_dependents .get_mut(&query) .unwrap() .remove(&source); - // if the old mapping was `c -> d` and we now insert `d -> c`, remove `d -> c` + // if the old mapping was `c -> d` and we now insert `d -> c`, remove `c -> d` if old_owner == new_owner { entry.remove(); } else { // otherwise (when `d` pointed to some other query, e.g. `b` in the example), // add an edge from `a` to `b` entry.insert((old_owner_thread, old_owner)); - self.transferred_dependents + dg.transferred_dependents .get_mut(&old_owner) .unwrap() .push(source); @@ -310,7 +321,7 @@ impl DependencyGraph { break; } - last_segment = self.transferred.entry(next_target); + last_segment = dg.transferred.entry(next_target); } // We simply assume here that the thread has changed because we'd have to walk the entire @@ -321,15 +332,30 @@ impl DependencyGraph { }; // Register `c` as a dependent of `b`. - let all_dependents = self.transferred_dependents.entry(new_owner).or_default(); + let all_dependents = dg.transferred_dependents.entry(new_owner).or_default(); debug_assert!(!all_dependents.contains(&new_owner)); all_dependents.push(query); if thread_changed { tracing::debug!("Unblocking new owner of transfer target {new_owner:?}"); - self.unblock_transfer_target(query, new_owner_thread); - self.update_transferred_edges(query, new_owner_thread); + dg.unblock_transfer_target(query, new_owner_thread); + dg.update_transferred_edges(query, new_owner_thread); + + // Block on the new owner, unless new owner is blocked on this query. + // This is necessary to avoid a race between `fetch` completing and `provisional_retry` blocking on the + // first cycle head. 
+ if current_thread != new_owner_thread + && !dg.depends_on(new_owner_thread, current_thread) + { + crate::tracing::info!( + "block_on: thread {current_thread:?} is blocking on {new_owner:?} in thread {new_owner_thread:?}", + ); + Self::block_on(me, current_thread, new_owner, new_owner_thread, guard); + return true; + } } + + false } /// Finds the one query in the dependents of the `source_query` (the one that is transferred to a new owner) diff --git a/tests/parallel/cycle_iteration_mismatch.rs b/tests/parallel/cycle_iteration_mismatch.rs new file mode 100644 index 000000000..17cc60108 --- /dev/null +++ b/tests/parallel/cycle_iteration_mismatch.rs @@ -0,0 +1,124 @@ +//! Test for iteration count mismatch bug where cycle heads have different iteration counts +//! +//! This test aims to reproduce the scenario where: +//! 1. A memo has multiple cycle heads with different iteration counts +//! 2. When validating, iteration counts mismatch causes re-execution +//! 3. After re-execution, the memo still has the same mismatched iteration counts + +use crate::sync::thread; +use crate::{Knobs, KnobsDatabase}; +use salsa::CycleRecoveryAction; + +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, salsa::Update)] +struct CycleValue(u32); + +const MIN: CycleValue = CycleValue(0); +const MAX: CycleValue = CycleValue(5); + +// Query A: First cycle head - will iterate multiple times +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +fn query_a(db: &dyn KnobsDatabase) -> CycleValue { + let b = query_b(db); + CycleValue(b.0 + 1).min(MAX) +} + +// Query B: Depends on C and D, creating complex dependencies +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +fn query_b(db: &dyn KnobsDatabase) -> CycleValue { + let c = query_c(db); + let d = query_d(db); + CycleValue(c.0.max(d.0) + 1).min(MAX) +} + +// Query C: Creates a cycle back to A +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +fn query_c(db: &dyn KnobsDatabase) -> CycleValue { + 
let a = query_a(db); + // Also depends on E to create more complex cycle structure + let e = query_e(db); + CycleValue(a.0.max(e.0)) +} + +// Query D: Part of a separate cycle with E +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +fn query_d(db: &dyn KnobsDatabase) -> CycleValue { + let e = query_e(db); + CycleValue(e.0 + 1).min(MAX) +} + +// Query E: Depends back on D and F +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +fn query_e(db: &dyn KnobsDatabase) -> CycleValue { + let d = query_d(db); + let f = query_f(db); + CycleValue(d.0.max(f.0) + 1).min(MAX) +} + +// Query F: Creates another cycle that might have different iteration count +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +fn query_f(db: &dyn KnobsDatabase) -> CycleValue { + // Create a cycle that depends on earlier queries + let b = query_b(db); + let e = query_e(db); + CycleValue(b.0.max(e.0)) +} + +fn cycle_fn( + _db: &dyn KnobsDatabase, + _value: &CycleValue, + _count: u32, +) -> CycleRecoveryAction { + CycleRecoveryAction::Iterate +} + +fn initial(_db: &dyn KnobsDatabase) -> CycleValue { + MIN +} + +#[test_log::test] +fn test_iteration_count_mismatch() { + crate::sync::check(|| { + tracing::debug!("Starting new run"); + let db_t1 = Knobs::default(); + let db_t2 = db_t1.clone(); + let db_t3 = db_t1.clone(); + let db_t4 = db_t1.clone(); + + // Thread 1: Starts with query_a - main cycle head + let t1 = thread::spawn(move || { + let _span = tracing::debug_span!("t1", thread_id = ?thread::current().id()).entered(); + query_a(&db_t1) + }); + + // Thread 2: Starts with query_d - separate cycle that will have different iteration + let t2 = thread::spawn(move || { + let _span = tracing::debug_span!("t2", thread_id = ?thread::current().id()).entered(); + query_d(&db_t2) + }); + + // Thread 3: Starts with query_f after others have started + let t3 = thread::spawn(move || { + let _span = tracing::debug_span!("t3", thread_id = ?thread::current().id()).entered(); + 
query_f(&db_t3) + }); + + // Thread 4: Queries b which depends on multiple cycles + let t4 = thread::spawn(move || { + let _span = tracing::debug_span!("t4", thread_id = ?thread::current().id()).entered(); + query_b(&db_t4) + }); + + let r_t1 = t1.join().unwrap(); + let r_t2 = t2.join().unwrap(); + let r_t3 = t3.join().unwrap(); + let r_t4 = t4.join().unwrap(); + + // All queries should converge to the same value + assert_eq!(r_t1, r_t2); + assert_eq!(r_t2, r_t3); + assert_eq!(r_t3, r_t4); + + // They should have computed a non-initial value + assert!(r_t1.0 > MIN.0); + }); +} diff --git a/tests/parallel/main.rs b/tests/parallel/main.rs index 859d14f47..1062d4899 100644 --- a/tests/parallel/main.rs +++ b/tests/parallel/main.rs @@ -6,6 +6,7 @@ mod signal; mod cycle_a_t1_b_t2; mod cycle_a_t1_b_t2_fallback; mod cycle_ab_peeping_c; +mod cycle_iteration_mismatch; mod cycle_nested_deep; mod cycle_nested_deep_conditional; mod cycle_nested_deep_conditional_changed; From d38145c29574758de7ffbe8a13cd4584c3b09161 Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Thu, 23 Oct 2025 08:36:56 +0200 Subject: [PATCH 55/65] Expose the query ID and the last provisional value to the cycle recovery function (#1012) * Expose the query ID and the last provisional value to the cycle recovery function * Mark cycle as converged if fallback value is the same as the last provisional * Make `cycle_fn` optional --- benches/dataflow.rs | 4 +++ book/src/cycles.md | 2 ++ .../salsa-macro-rules/src/setup_tracked_fn.rs | 6 +++-- .../src/unexpected_cycle_recovery.rs | 6 ++--- components/salsa-macros/src/tracked_fn.rs | 7 ++--- src/cycle.rs | 6 ++++- src/function.rs | 24 ++++++++++++----- src/function/execute.rs | 6 ++++- src/function/memo.rs | 2 ++ tests/backtrace.rs | 11 +------- tests/cycle.rs | 27 ++++--------------- tests/cycle_accumulate.rs | 2 ++ tests/cycle_initial_call_back_into_cycle.rs | 10 +------ tests/cycle_initial_call_query.rs | 10 +------ tests/cycle_maybe_changed_after.rs | 26 
+++--------------- tests/cycle_output.rs | 11 +------- tests/cycle_recovery_call_back_into_cycle.rs | 8 +++++- tests/cycle_recovery_call_query.rs | 2 ++ tests/cycle_regression_455.rs | 12 +-------- tests/cycle_tracked.rs | 25 +++-------------- tests/cycle_tracked_own_input.rs | 13 ++------- tests/dataflow.rs | 4 +++ tests/parallel/cycle_a_t1_b_t2.rs | 14 ++-------- tests/parallel/cycle_ab_peeping_c.rs | 14 ++-------- tests/parallel/cycle_iteration_mismatch.rs | 21 +++++---------- tests/parallel/cycle_nested_deep.rs | 20 ++++---------- .../parallel/cycle_nested_deep_conditional.rs | 20 ++++---------- .../cycle_nested_deep_conditional_changed.rs | 21 ++++----------- tests/parallel/cycle_nested_deep_panic.rs | 18 +++---------- tests/parallel/cycle_nested_three_threads.rs | 16 +++-------- .../cycle_nested_three_threads_changed.rs | 17 +++--------- tests/parallel/cycle_panic.rs | 8 +++++- .../cycle_provisional_depending_on_itself.rs | 15 +++-------- 33 files changed, 127 insertions(+), 281 deletions(-) diff --git a/benches/dataflow.rs b/benches/dataflow.rs index db099c6b2..d1acfd27b 100644 --- a/benches/dataflow.rs +++ b/benches/dataflow.rs @@ -76,6 +76,8 @@ fn def_cycle_initial(_db: &dyn Db, _def: Definition) -> Type { fn def_cycle_recover( _db: &dyn Db, + _id: salsa::Id, + _last_provisional_value: &Type, value: &Type, count: u32, _def: Definition, @@ -89,6 +91,8 @@ fn use_cycle_initial(_db: &dyn Db, _use: Use) -> Type { fn use_cycle_recover( _db: &dyn Db, + _id: salsa::Id, + _last_provisional_value: &Type, value: &Type, count: u32, _use: Use, diff --git a/book/src/cycles.md b/book/src/cycles.md index 2215b8ff3..2e2c6e7b8 100644 --- a/book/src/cycles.md +++ b/book/src/cycles.md @@ -21,6 +21,8 @@ fn initial(_db: &dyn KnobsDatabase) -> u32 { } ``` +The `cycle_fn` is optional. The default implementation always returns `Iterate`. 
+ If `query` becomes the head of a cycle (that is, `query` is executing and on the active query stack, it calls `query2`, `query2` calls `query3`, and `query3` calls `query` again -- there could be any number of queries involved in the cycle), the `initial_fn` will be called to generate an "initial" value for `query` in the fixed-point computation. (The initial value should usually be the "bottom" value in the partial order.) All queries in the cycle will compute a provisional result based on this initial value for the cycle head. That is, `query3` will compute a provisional result using the initial value for `query`, `query2` will compute a provisional result using this provisional value for `query3`. When `cycle2` returns its provisional result back to `cycle`, `cycle` will observe that it has received a provisional result from its own cycle, and will call the `cycle_fn` (with the current value and the number of iterations that have occurred so far). The `cycle_fn` can return `salsa::CycleRecoveryAction::Iterate` to indicate that the cycle should iterate again, or `salsa::CycleRecoveryAction::Fallback(value)` to indicate that fixpoint iteration should resume starting with the given value (which should be a value that will converge quickly). The cycle will iterate until it converges: that is, until two successive iterations produce the same result. diff --git a/components/salsa-macro-rules/src/setup_tracked_fn.rs b/components/salsa-macro-rules/src/setup_tracked_fn.rs index 945021f3a..961b5b4f8 100644 --- a/components/salsa-macro-rules/src/setup_tracked_fn.rs +++ b/components/salsa-macro-rules/src/setup_tracked_fn.rs @@ -308,11 +308,13 @@ macro_rules! 
setup_tracked_fn { fn recover_from_cycle<$db_lt>( db: &$db_lt dyn $Db, + id: salsa::Id, + last_provisional_value: &Self::Output<$db_lt>, value: &Self::Output<$db_lt>, - count: u32, + iteration_count: u32, ($($input_id),*): ($($interned_input_ty),*) ) -> $zalsa::CycleRecoveryAction> { - $($cycle_recovery_fn)*(db, value, count, $($input_id),*) + $($cycle_recovery_fn)*(db, id, last_provisional_value, value, iteration_count, $($input_id),*) } fn id_to_input<$db_lt>(zalsa: &$db_lt $zalsa::Zalsa, key: salsa::Id) -> Self::Input<$db_lt> { diff --git a/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs b/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs index 8d56d54f3..aa6161d28 100644 --- a/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs +++ b/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs @@ -3,10 +3,10 @@ // a macro because it can take a variadic number of arguments. #[macro_export] macro_rules! unexpected_cycle_recovery { - ($db:ident, $value:ident, $count:ident, $($other_inputs:ident),*) => {{ - std::mem::drop($db); + ($db:ident, $id:ident, $last_provisional_value:ident, $new_value:ident, $count:ident, $($other_inputs:ident),*) => {{ + let (_db, _id, _last_provisional_value, _new_value, _count) = ($db, $id, $last_provisional_value, $new_value, $count); std::mem::drop(($($other_inputs,)*)); - panic!("cannot recover from cycle") + salsa::CycleRecoveryAction::Iterate }}; } diff --git a/components/salsa-macros/src/tracked_fn.rs b/components/salsa-macros/src/tracked_fn.rs index 5c6fab7d2..12f9170c7 100644 --- a/components/salsa-macros/src/tracked_fn.rs +++ b/components/salsa-macros/src/tracked_fn.rs @@ -286,9 +286,10 @@ impl Macro { self.args.cycle_fn.as_ref().unwrap(), "must provide `cycle_initial` along with `cycle_fn`", )), - (None, Some(_), None) => Err(syn::Error::new_spanned( - self.args.cycle_initial.as_ref().unwrap(), - "must provide `cycle_fn` along with `cycle_initial`", + (None, Some(cycle_initial), None) => 
Ok(( + quote!((salsa::plumbing::unexpected_cycle_recovery!)), + quote!((#cycle_initial)), + quote!(Fixpoint), )), (None, None, Some(cycle_result)) => Ok(( quote!((salsa::plumbing::unexpected_cycle_recovery!)), diff --git a/src/cycle.rs b/src/cycle.rs index c9a9b82c1..09ec51525 100644 --- a/src/cycle.rs +++ b/src/cycle.rs @@ -70,7 +70,11 @@ pub enum CycleRecoveryAction { /// Iterate the cycle again to look for a fixpoint. Iterate, - /// Cut off iteration and use the given result value for this query. + /// Use the given value as the result for the current iteration instead + /// of the value computed by the query function. + /// + /// Returning `Fallback` doesn't stop the fixpoint iteration. It only + /// allows the iterate function to return a different value. Fallback(T), } diff --git a/src/function.rs b/src/function.rs index 259dff14b..1cf3e9478 100644 --- a/src/function.rs +++ b/src/function.rs @@ -94,20 +94,32 @@ pub trait Configuration: Any { /// value from the latest iteration of this cycle. `count` is the number of cycle iterations /// completed so far. /// - /// # Iteration count semantics + /// # Id /// - /// The `count` parameter isn't guaranteed to start from zero or to be contiguous: + /// The id can be used to uniquely identify the query instance. This can be helpful + /// if the cycle function has to re-identify a value it returned previously. /// - /// * **Initial value**: `count` may be non-zero on the first call for a given query if that + /// # Values + /// + /// The `last_provisional_value` is the value from the previous iteration of this cycle + /// and `value` is the new value that was computed in the current iteration. + /// + /// # Iteration count + /// + /// The `iteration` parameter isn't guaranteed to start from zero or to be contiguous: + /// + /// * **Initial value**: `iteration` may be non-zero on the first call for a given query if that /// query becomes the outermost cycle head after a nested cycle complete a few iterations. 
In this case, - /// `count` continues from the nested cycle's iteration count rather than resetting to zero. + /// `iteration` continues from the nested cycle's iteration count rather than resetting to zero. /// * **Non-contiguous values**: This function isn't called if this cycle is part of an outer cycle /// and the value for this query remains unchanged for one iteration. But the outer cycle might /// keep iterating because other heads keep changing. fn recover_from_cycle<'db>( db: &'db Self::DbView, - value: &Self::Output<'db>, - count: u32, + id: Id, + last_provisional_value: &Self::Output<'db>, + new_value: &Self::Output<'db>, + iteration: u32, input: Self::Input<'db>, ) -> CycleRecoveryAction>; diff --git a/src/function/execute.rs b/src/function/execute.rs index aa4339bef..3acfaadc8 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -320,7 +320,7 @@ where I am a cycle head, comparing last provisional value with new value" ); - let this_converged = C::values_equal(&new_value, last_provisional_value); + let mut this_converged = C::values_equal(&new_value, last_provisional_value); // If this is the outermost cycle, use the maximum iteration count of all cycles. 
// This is important for when later iterations introduce new cycle heads (that then @@ -341,6 +341,8 @@ where // cycle-recovery function what to do: match C::recover_from_cycle( db, + id, + last_provisional_value, &new_value, iteration_count.as_u32(), C::id_to_input(zalsa, id), @@ -351,6 +353,8 @@ where "{database_key_index:?}: execute: user cycle_fn says to fall back" ); new_value = fallback_value; + + this_converged = C::values_equal(&new_value, last_provisional_value); } } } diff --git a/src/function/memo.rs b/src/function/memo.rs index 2e84bc04f..8fe0c1dd8 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -557,6 +557,8 @@ mod _memory_usage { fn recover_from_cycle<'db>( _: &'db Self::DbView, + _: Id, + _: &Self::Output<'db>, _: &Self::Output<'db>, _: u32, _: Self::Input<'db>, diff --git a/tests/backtrace.rs b/tests/backtrace.rs index b611cac86..0adf517cd 100644 --- a/tests/backtrace.rs +++ b/tests/backtrace.rs @@ -42,7 +42,7 @@ fn query_f(db: &dyn Database, thing: Thing) -> String { query_cycle(db, thing) } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=cycle_initial)] +#[salsa::tracked(cycle_initial=cycle_initial)] fn query_cycle(db: &dyn Database, thing: Thing) -> String { let backtrace = query_cycle(db, thing); if backtrace.is_empty() { @@ -56,15 +56,6 @@ fn cycle_initial(_db: &dyn salsa::Database, _thing: Thing) -> String { String::new() } -fn cycle_fn( - _db: &dyn salsa::Database, - _value: &str, - _count: u32, - _thing: Thing, -) -> salsa::CycleRecoveryAction { - salsa::CycleRecoveryAction::Iterate -} - #[test] fn backtrace_works() { let db = DatabaseImpl::default(); diff --git a/tests/cycle.rs b/tests/cycle.rs index 5e46cc0be..0c4d686af 100644 --- a/tests/cycle.rs +++ b/tests/cycle.rs @@ -125,6 +125,8 @@ const MAX_ITERATIONS: u32 = 3; /// iterating again. 
fn cycle_recover( _db: &dyn Db, + _id: salsa::Id, + _last_provisional_value: &Value, value: &Value, count: u32, _inputs: Inputs, @@ -440,7 +442,6 @@ fn two_fallback_count() { /// /// Two-query cycle, falls back but fallback does not converge. #[test] -#[should_panic(expected = "too many cycle iterations")] fn two_fallback_diverge() { let mut db = DbImpl::new(); let a_in = Inputs::new(&db, vec![]); @@ -1167,7 +1168,7 @@ fn repeat_query_participating_in_cycle() { value: u32, } - #[salsa::tracked(cycle_fn=cycle_recover, cycle_initial=initial)] + #[salsa::tracked(cycle_initial=initial)] fn head(db: &dyn Db, input: Input) -> u32 { let a = query_a(db, input); @@ -1178,15 +1179,6 @@ fn repeat_query_participating_in_cycle() { 0 } - fn cycle_recover( - _db: &dyn Db, - _value: &u32, - _count: u32, - _input: Input, - ) -> CycleRecoveryAction { - CycleRecoveryAction::Iterate - } - #[salsa::tracked] fn query_a(db: &dyn Db, input: Input) -> u32 { let _ = query_b(db, input); @@ -1281,7 +1273,7 @@ fn repeat_query_participating_in_cycle2() { value: u32, } - #[salsa::tracked(cycle_fn=cycle_recover, cycle_initial=initial)] + #[salsa::tracked(cycle_initial=initial)] fn head(db: &dyn Db, input: Input) -> u32 { let a = query_a(db, input); @@ -1292,16 +1284,7 @@ fn repeat_query_participating_in_cycle2() { 0 } - fn cycle_recover( - _db: &dyn Db, - _value: &u32, - _count: u32, - _input: Input, - ) -> CycleRecoveryAction { - CycleRecoveryAction::Iterate - } - - #[salsa::tracked(cycle_fn=cycle_recover, cycle_initial=initial)] + #[salsa::tracked(cycle_initial=initial)] fn query_a(db: &dyn Db, input: Input) -> u32 { let _ = query_hot(db, input); query_b(db, input) diff --git a/tests/cycle_accumulate.rs b/tests/cycle_accumulate.rs index e06fe033b..8148e952d 100644 --- a/tests/cycle_accumulate.rs +++ b/tests/cycle_accumulate.rs @@ -50,6 +50,8 @@ fn cycle_initial(_db: &dyn LogDatabase, _file: File) -> Vec { fn cycle_fn( _db: &dyn LogDatabase, + _id: salsa::Id, + _last_provisional_value: &[u32], 
_value: &[u32], _count: u32, _file: File, diff --git a/tests/cycle_initial_call_back_into_cycle.rs b/tests/cycle_initial_call_back_into_cycle.rs index 326fd46c7..e56c4c4d1 100644 --- a/tests/cycle_initial_call_back_into_cycle.rs +++ b/tests/cycle_initial_call_back_into_cycle.rs @@ -7,7 +7,7 @@ fn initial_value(db: &dyn salsa::Database) -> u32 { query(db) } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=cycle_initial)] +#[salsa::tracked(cycle_initial=cycle_initial)] fn query(db: &dyn salsa::Database) -> u32 { let val = query(db); if val < 5 { @@ -21,14 +21,6 @@ fn cycle_initial(db: &dyn salsa::Database) -> u32 { initial_value(db) } -fn cycle_fn( - _db: &dyn salsa::Database, - _value: &u32, - _count: u32, -) -> salsa::CycleRecoveryAction { - salsa::CycleRecoveryAction::Iterate -} - #[test_log::test] #[should_panic(expected = "dependency graph cycle")] fn the_test() { diff --git a/tests/cycle_initial_call_query.rs b/tests/cycle_initial_call_query.rs index cb10e77e1..2212ef958 100644 --- a/tests/cycle_initial_call_query.rs +++ b/tests/cycle_initial_call_query.rs @@ -7,7 +7,7 @@ fn initial_value(_db: &dyn salsa::Database) -> u32 { 0 } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=cycle_initial)] +#[salsa::tracked(cycle_initial=cycle_initial)] fn query(db: &dyn salsa::Database) -> u32 { let val = query(db); if val < 5 { @@ -21,14 +21,6 @@ fn cycle_initial(db: &dyn salsa::Database) -> u32 { initial_value(db) } -fn cycle_fn( - _db: &dyn salsa::Database, - _value: &u32, - _count: u32, -) -> salsa::CycleRecoveryAction { - salsa::CycleRecoveryAction::Iterate -} - #[test_log::test] fn the_test() { let db = salsa::DatabaseImpl::default(); diff --git a/tests/cycle_maybe_changed_after.rs b/tests/cycle_maybe_changed_after.rs index 6ee42d3a5..8c00c484a 100644 --- a/tests/cycle_maybe_changed_after.rs +++ b/tests/cycle_maybe_changed_after.rs @@ -4,7 +4,7 @@ mod common; use crate::common::EventLoggerDatabase; -use salsa::{CycleRecoveryAction, Database, Durability, Setter}; 
+use salsa::{Database, Durability, Setter}; #[salsa::input(debug)] struct Input { @@ -17,7 +17,7 @@ struct Output<'db> { value: u32, } -#[salsa::tracked(cycle_fn=query_a_recover, cycle_initial=query_a_initial)] +#[salsa::tracked(cycle_initial=query_a_initial)] fn query_c<'db>(db: &'db dyn salsa::Database, input: Input) -> u32 { query_d(db, input) } @@ -40,21 +40,12 @@ fn query_a_initial(_db: &dyn Database, _input: Input) -> u32 { 0 } -fn query_a_recover( - _db: &dyn Database, - _output: &u32, - _count: u32, - _input: Input, -) -> CycleRecoveryAction { - CycleRecoveryAction::Iterate -} - /// Only the first iteration depends on `input.value`. It's important that the entire query /// reruns if `input.value` changes. That's why salsa has to carry-over the inputs and outputs /// from the previous iteration. #[test_log::test] fn first_iteration_input_only() { - #[salsa::tracked(cycle_fn=query_a_recover, cycle_initial=query_a_initial)] + #[salsa::tracked(cycle_initial=query_a_initial)] fn query_a<'db>(db: &'db dyn salsa::Database, input: Input) -> u32 { query_b(db, input) } @@ -126,7 +117,7 @@ fn nested_cycle_fewer_dependencies_in_first_iteration() { scope: Scope<'db>, } - #[salsa::tracked(cycle_fn=head_recover, cycle_initial=head_initial)] + #[salsa::tracked(cycle_initial=head_initial)] fn cycle_head<'db>(db: &'db dyn salsa::Database, input: Input) -> Option> { let b = cycle_outer(db, input); tracing::info!("query_b = {b:?}"); @@ -141,15 +132,6 @@ fn nested_cycle_fewer_dependencies_in_first_iteration() { None } - fn head_recover<'db>( - _db: &'db dyn Database, - _output: &Option>, - _count: u32, - _input: Input, - ) -> CycleRecoveryAction>> { - CycleRecoveryAction::Iterate - } - #[salsa::tracked] fn cycle_outer<'db>(db: &'db dyn salsa::Database, input: Input) -> Option> { cycle_participant(db, input) diff --git a/tests/cycle_output.rs b/tests/cycle_output.rs index 59b789aa4..02a3b569f 100644 --- a/tests/cycle_output.rs +++ b/tests/cycle_output.rs @@ -35,7 +35,7 @@ fn 
query_a(db: &dyn Db, input: InputValue) -> u32 { } } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=cycle_initial)] +#[salsa::tracked(cycle_initial=cycle_initial)] fn query_b(db: &dyn Db, input: InputValue) -> u32 { query_a(db, input) } @@ -44,15 +44,6 @@ fn cycle_initial(_db: &dyn Db, _input: InputValue) -> u32 { 0 } -fn cycle_fn( - _db: &dyn Db, - _value: &u32, - _count: u32, - _input: InputValue, -) -> salsa::CycleRecoveryAction { - salsa::CycleRecoveryAction::Iterate -} - #[salsa::tracked] fn query_c(db: &dyn Db, input: InputValue) -> u32 { input.value(db) diff --git a/tests/cycle_recovery_call_back_into_cycle.rs b/tests/cycle_recovery_call_back_into_cycle.rs index 805a2be7b..358f988ad 100644 --- a/tests/cycle_recovery_call_back_into_cycle.rs +++ b/tests/cycle_recovery_call_back_into_cycle.rs @@ -25,7 +25,13 @@ fn cycle_initial(_db: &dyn ValueDatabase) -> u32 { 0 } -fn cycle_fn(db: &dyn ValueDatabase, _value: &u32, _count: u32) -> salsa::CycleRecoveryAction { +fn cycle_fn( + db: &dyn ValueDatabase, + _id: salsa::Id, + _last_provisional_value: &u32, + _value: &u32, + _count: u32, +) -> salsa::CycleRecoveryAction { salsa::CycleRecoveryAction::Fallback(fallback_value(db)) } diff --git a/tests/cycle_recovery_call_query.rs b/tests/cycle_recovery_call_query.rs index dcc31abeb..37341a202 100644 --- a/tests/cycle_recovery_call_query.rs +++ b/tests/cycle_recovery_call_query.rs @@ -23,6 +23,8 @@ fn cycle_initial(_db: &dyn salsa::Database) -> u32 { fn cycle_fn( db: &dyn salsa::Database, + _id: salsa::Id, + _last_provisional_value: &u32, _value: &u32, _count: u32, ) -> salsa::CycleRecoveryAction { diff --git a/tests/cycle_regression_455.rs b/tests/cycle_regression_455.rs index 99c193ab9..a083cb996 100644 --- a/tests/cycle_regression_455.rs +++ b/tests/cycle_regression_455.rs @@ -7,21 +7,11 @@ fn memoized(db: &dyn Database, input: MyInput) -> u32 { memoized_a(db, MyTracked::new(db, input.field(db))) } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=cycle_initial)] 
+#[salsa::tracked(cycle_initial=cycle_initial)] fn memoized_a<'db>(db: &'db dyn Database, tracked: MyTracked<'db>) -> u32 { MyTracked::new(db, 0); memoized_b(db, tracked) } - -fn cycle_fn<'db>( - _db: &'db dyn Database, - _value: &u32, - _count: u32, - _input: MyTracked<'db>, -) -> salsa::CycleRecoveryAction { - salsa::CycleRecoveryAction::Iterate -} - fn cycle_initial(_db: &dyn Database, _input: MyTracked) -> u32 { 0 } diff --git a/tests/cycle_tracked.rs b/tests/cycle_tracked.rs index 2e0c2cfd0..5ee4e1620 100644 --- a/tests/cycle_tracked.rs +++ b/tests/cycle_tracked.rs @@ -4,7 +4,7 @@ mod common; use crate::common::{EventLoggerDatabase, LogDatabase}; use expect_test::expect; -use salsa::{CycleRecoveryAction, Database, Setter}; +use salsa::{Database, Setter}; #[derive(Clone, Debug, Eq, PartialEq, Hash, salsa::Update)] struct Graph<'db> { @@ -86,7 +86,7 @@ fn create_graph(db: &dyn salsa::Database, input: GraphInput) -> Graph<'_> { } /// Computes the minimum cost from the node with offset `0` to the given node. -#[salsa::tracked(cycle_fn=cycle_recover, cycle_initial=max_initial)] +#[salsa::tracked(cycle_initial=max_initial)] fn cost_to_start<'db>(db: &'db dyn Database, node: Node<'db>) -> usize { let mut min_cost = usize::MAX; let graph = create_graph(db, node.graph(db)); @@ -114,15 +114,6 @@ fn max_initial(_db: &dyn Database, _node: Node) -> usize { usize::MAX } -fn cycle_recover( - _db: &dyn Database, - _value: &usize, - _count: u32, - _inputs: Node, -) -> CycleRecoveryAction { - CycleRecoveryAction::Iterate -} - /// Tests for cycles where the cycle head is stored on a tracked struct /// and that tracked struct is freed in a later revision. #[test] @@ -215,7 +206,7 @@ struct IterationNode<'db> { /// 3. Second iteration: returns `[iter_0, iter_1]` /// 4. Third iteration (only for variant=1): returns `[iter_0, iter_1, iter_2]` /// 5. 
Further iterations: no change, fixpoint reached -#[salsa::tracked(cycle_fn=cycle_recover_with_structs, cycle_initial=initial_with_structs)] +#[salsa::tracked(cycle_initial=initial_with_structs)] fn create_tracked_in_cycle<'db>( db: &'db dyn Database, input: GraphInput, @@ -259,16 +250,6 @@ fn initial_with_structs(_db: &dyn Database, _input: GraphInput) -> Vec( - _db: &'db dyn Database, - _value: &Vec>, - _iteration: u32, - _input: GraphInput, -) -> CycleRecoveryAction>> { - CycleRecoveryAction::Iterate -} - #[test_log::test] fn test_cycle_with_fixpoint_structs() { let mut db = EventLoggerDatabase::default(); diff --git a/tests/cycle_tracked_own_input.rs b/tests/cycle_tracked_own_input.rs index 38218f1a7..79035bab5 100644 --- a/tests/cycle_tracked_own_input.rs +++ b/tests/cycle_tracked_own_input.rs @@ -11,7 +11,7 @@ mod common; use crate::common::{EventLoggerDatabase, LogDatabase}; use expect_test::expect; -use salsa::{CycleRecoveryAction, Database, Setter}; +use salsa::{Database, Setter}; #[salsa::input(debug)] struct ClassNode { @@ -52,7 +52,7 @@ impl Type<'_> { } } -#[salsa::tracked(cycle_fn=infer_class_recover, cycle_initial=infer_class_initial)] +#[salsa::tracked(cycle_initial=infer_class_initial)] fn infer_class<'db>(db: &'db dyn salsa::Database, node: ClassNode) -> Type<'db> { Type::Class(Class::new( db, @@ -85,15 +85,6 @@ fn infer_class_initial(_db: &'_ dyn Database, _node: ClassNode) -> Type<'_> { Type::Unknown } -fn infer_class_recover<'db>( - _db: &'db dyn Database, - _type: &Type<'db>, - _count: u32, - _inputs: ClassNode, -) -> CycleRecoveryAction> { - CycleRecoveryAction::Iterate -} - #[test] fn main() { let mut db = EventLoggerDatabase::default(); diff --git a/tests/dataflow.rs b/tests/dataflow.rs index 960cc33f5..793870322 100644 --- a/tests/dataflow.rs +++ b/tests/dataflow.rs @@ -77,6 +77,8 @@ fn def_cycle_initial(_db: &dyn Db, _def: Definition) -> Type { fn def_cycle_recover( _db: &dyn Db, + _id: salsa::Id, + _last_provisional_value: &Type, value: 
&Type, count: u32, _def: Definition, @@ -90,6 +92,8 @@ fn use_cycle_initial(_db: &dyn Db, _use: Use) -> Type { fn use_cycle_recover( _db: &dyn Db, + _id: salsa::Id, + _last_provisional_value: &Type, value: &Type, count: u32, _use: Use, diff --git a/tests/parallel/cycle_a_t1_b_t2.rs b/tests/parallel/cycle_a_t1_b_t2.rs index ad21b7963..6a434099e 100644 --- a/tests/parallel/cycle_a_t1_b_t2.rs +++ b/tests/parallel/cycle_a_t1_b_t2.rs @@ -15,8 +15,6 @@ use crate::sync::thread; use crate::{Knobs, KnobsDatabase}; -use salsa::CycleRecoveryAction; - #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, salsa::Update)] struct CycleValue(u32); @@ -26,7 +24,7 @@ const MAX: CycleValue = CycleValue(3); // Signal 1: T1 has entered `query_a` // Signal 2: T2 has entered `query_b` -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_a(db: &dyn KnobsDatabase) -> CycleValue { db.signal(1); @@ -36,7 +34,7 @@ fn query_a(db: &dyn KnobsDatabase) -> CycleValue { query_b(db) } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_b(db: &dyn KnobsDatabase) -> CycleValue { // Wait for Thread T1 to enter `query_a` before we continue. 
db.wait_for(1); @@ -47,14 +45,6 @@ fn query_b(db: &dyn KnobsDatabase) -> CycleValue { CycleValue(a_value.0 + 1).min(MAX) } -fn cycle_fn( - _db: &dyn KnobsDatabase, - _value: &CycleValue, - _count: u32, -) -> CycleRecoveryAction { - CycleRecoveryAction::Iterate -} - fn initial(_db: &dyn KnobsDatabase) -> CycleValue { MIN } diff --git a/tests/parallel/cycle_ab_peeping_c.rs b/tests/parallel/cycle_ab_peeping_c.rs index 134fe7429..8ed2b4fb6 100644 --- a/tests/parallel/cycle_ab_peeping_c.rs +++ b/tests/parallel/cycle_ab_peeping_c.rs @@ -9,8 +9,6 @@ use crate::sync::thread; use crate::{Knobs, KnobsDatabase}; -use salsa::CycleRecoveryAction; - #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, salsa::Update)] struct CycleValue(u32); @@ -18,7 +16,7 @@ const MIN: CycleValue = CycleValue(0); const MID: CycleValue = CycleValue(5); const MAX: CycleValue = CycleValue(10); -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=cycle_initial)] +#[salsa::tracked(cycle_initial=cycle_initial)] fn query_a(db: &dyn KnobsDatabase) -> CycleValue { let b_value = query_b(db); @@ -32,19 +30,11 @@ fn query_a(db: &dyn KnobsDatabase) -> CycleValue { b_value } -fn cycle_fn( - _db: &dyn KnobsDatabase, - _value: &CycleValue, - _count: u32, -) -> CycleRecoveryAction { - CycleRecoveryAction::Iterate -} - fn cycle_initial(_db: &dyn KnobsDatabase) -> CycleValue { MIN } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=cycle_initial)] +#[salsa::tracked(cycle_initial=cycle_initial)] fn query_b(db: &dyn KnobsDatabase) -> CycleValue { let a_value = query_a(db); diff --git a/tests/parallel/cycle_iteration_mismatch.rs b/tests/parallel/cycle_iteration_mismatch.rs index 17cc60108..61d1da01d 100644 --- a/tests/parallel/cycle_iteration_mismatch.rs +++ b/tests/parallel/cycle_iteration_mismatch.rs @@ -7,7 +7,6 @@ use crate::sync::thread; use crate::{Knobs, KnobsDatabase}; -use salsa::CycleRecoveryAction; #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, salsa::Update)] struct 
CycleValue(u32); @@ -16,14 +15,14 @@ const MIN: CycleValue = CycleValue(0); const MAX: CycleValue = CycleValue(5); // Query A: First cycle head - will iterate multiple times -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_a(db: &dyn KnobsDatabase) -> CycleValue { let b = query_b(db); CycleValue(b.0 + 1).min(MAX) } // Query B: Depends on C and D, creating complex dependencies -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_b(db: &dyn KnobsDatabase) -> CycleValue { let c = query_c(db); let d = query_d(db); @@ -31,7 +30,7 @@ fn query_b(db: &dyn KnobsDatabase) -> CycleValue { } // Query C: Creates a cycle back to A -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_c(db: &dyn KnobsDatabase) -> CycleValue { let a = query_a(db); // Also depends on E to create more complex cycle structure @@ -40,14 +39,14 @@ fn query_c(db: &dyn KnobsDatabase) -> CycleValue { } // Query D: Part of a separate cycle with E -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_d(db: &dyn KnobsDatabase) -> CycleValue { let e = query_e(db); CycleValue(e.0 + 1).min(MAX) } // Query E: Depends back on D and F -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_e(db: &dyn KnobsDatabase) -> CycleValue { let d = query_d(db); let f = query_f(db); @@ -55,7 +54,7 @@ fn query_e(db: &dyn KnobsDatabase) -> CycleValue { } // Query F: Creates another cycle that might have different iteration count -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_f(db: &dyn KnobsDatabase) -> CycleValue { // Create a cycle that depends on earlier queries let b = query_b(db); @@ -63,14 +62,6 @@ fn query_f(db: &dyn KnobsDatabase) -> CycleValue { 
CycleValue(b.0.max(e.0)) } -fn cycle_fn( - _db: &dyn KnobsDatabase, - _value: &CycleValue, - _count: u32, -) -> CycleRecoveryAction { - CycleRecoveryAction::Iterate -} - fn initial(_db: &dyn KnobsDatabase) -> CycleValue { MIN } diff --git a/tests/parallel/cycle_nested_deep.rs b/tests/parallel/cycle_nested_deep.rs index f2b355616..3d46bbbc5 100644 --- a/tests/parallel/cycle_nested_deep.rs +++ b/tests/parallel/cycle_nested_deep.rs @@ -9,26 +9,24 @@ use crate::sync::thread; use crate::{Knobs, KnobsDatabase}; -use salsa::CycleRecoveryAction; - #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, salsa::Update)] struct CycleValue(u32); const MIN: CycleValue = CycleValue(0); const MAX: CycleValue = CycleValue(3); -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_a(db: &dyn KnobsDatabase) -> CycleValue { query_b(db) } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_b(db: &dyn KnobsDatabase) -> CycleValue { let c_value = query_c(db); CycleValue(c_value.0 + 1).min(MAX) } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_c(db: &dyn KnobsDatabase) -> CycleValue { let d_value = query_d(db); let e_value = query_e(db); @@ -38,24 +36,16 @@ fn query_c(db: &dyn KnobsDatabase) -> CycleValue { CycleValue(d_value.0.max(e_value.0).max(b_value.0).max(a_value.0)) } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_d(db: &dyn KnobsDatabase) -> CycleValue { query_c(db) } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_e(db: &dyn KnobsDatabase) -> CycleValue { query_c(db) } -fn cycle_fn( - _db: &dyn KnobsDatabase, - _value: &CycleValue, - _count: u32, -) -> CycleRecoveryAction { - CycleRecoveryAction::Iterate -} - fn initial(_db: &dyn KnobsDatabase) -> CycleValue { MIN 
} diff --git a/tests/parallel/cycle_nested_deep_conditional.rs b/tests/parallel/cycle_nested_deep_conditional.rs index 4eff75189..544342e07 100644 --- a/tests/parallel/cycle_nested_deep_conditional.rs +++ b/tests/parallel/cycle_nested_deep_conditional.rs @@ -14,26 +14,24 @@ use crate::sync::thread; use crate::{Knobs, KnobsDatabase}; -use salsa::CycleRecoveryAction; - #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, salsa::Update)] struct CycleValue(u32); const MIN: CycleValue = CycleValue(0); const MAX: CycleValue = CycleValue(3); -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_a(db: &dyn KnobsDatabase) -> CycleValue { query_b(db) } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_b(db: &dyn KnobsDatabase) -> CycleValue { let c_value = query_c(db); CycleValue(c_value.0 + 1).min(MAX) } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_c(db: &dyn KnobsDatabase) -> CycleValue { let d_value = query_d(db); @@ -47,24 +45,16 @@ fn query_c(db: &dyn KnobsDatabase) -> CycleValue { } } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_d(db: &dyn KnobsDatabase) -> CycleValue { query_c(db) } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_e(db: &dyn KnobsDatabase) -> CycleValue { query_c(db) } -fn cycle_fn( - _db: &dyn KnobsDatabase, - _value: &CycleValue, - _count: u32, -) -> CycleRecoveryAction { - CycleRecoveryAction::Iterate -} - fn initial(_db: &dyn KnobsDatabase) -> CycleValue { MIN } diff --git a/tests/parallel/cycle_nested_deep_conditional_changed.rs b/tests/parallel/cycle_nested_deep_conditional_changed.rs index 51d506456..03423b09a 100644 --- a/tests/parallel/cycle_nested_deep_conditional_changed.rs +++ 
b/tests/parallel/cycle_nested_deep_conditional_changed.rs @@ -15,8 +15,6 @@ //! Specifically, the maybe_changed_after flow. use crate::sync::thread; -use salsa::CycleRecoveryAction; - #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, salsa::Update)] struct CycleValue(u32); @@ -28,18 +26,18 @@ struct Input { value: u32, } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_a(db: &dyn salsa::Database, input: Input) -> CycleValue { query_b(db, input) } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_b(db: &dyn salsa::Database, input: Input) -> CycleValue { let c_value = query_c(db, input); CycleValue(c_value.0 + input.value(db).max(1)).min(MAX) } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_c(db: &dyn salsa::Database, input: Input) -> CycleValue { let d_value = query_d(db, input); @@ -53,25 +51,16 @@ fn query_c(db: &dyn salsa::Database, input: Input) -> CycleValue { } } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_d(db: &dyn salsa::Database, input: Input) -> CycleValue { query_c(db, input) } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_e(db: &dyn salsa::Database, input: Input) -> CycleValue { query_c(db, input) } -fn cycle_fn( - _db: &dyn salsa::Database, - _value: &CycleValue, - _count: u32, - _input: Input, -) -> CycleRecoveryAction { - CycleRecoveryAction::Iterate -} - fn initial(_db: &dyn salsa::Database, _input: Input) -> CycleValue { MIN } diff --git a/tests/parallel/cycle_nested_deep_panic.rs b/tests/parallel/cycle_nested_deep_panic.rs index 8b89f362a..4356489c3 100644 --- a/tests/parallel/cycle_nested_deep_panic.rs +++ b/tests/parallel/cycle_nested_deep_panic.rs @@ -8,20 +8,18 @@ use crate::{Knobs, KnobsDatabase}; use 
std::fmt; use std::panic::catch_unwind; -use salsa::CycleRecoveryAction; - #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, salsa::Update)] struct CycleValue(u32); const MIN: CycleValue = CycleValue(0); const MAX: CycleValue = CycleValue(3); -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_a(db: &dyn KnobsDatabase) -> CycleValue { query_b(db) } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_b(db: &dyn KnobsDatabase) -> CycleValue { let c_value = query_c(db); CycleValue(c_value.0 + 1).min(MAX) @@ -41,24 +39,16 @@ fn query_c(db: &dyn KnobsDatabase) -> CycleValue { } } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_d(db: &dyn KnobsDatabase) -> CycleValue { query_b(db) } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_e(db: &dyn KnobsDatabase) -> CycleValue { query_c(db) } -fn cycle_fn( - _db: &dyn KnobsDatabase, - _value: &CycleValue, - _count: u32, -) -> CycleRecoveryAction { - CycleRecoveryAction::Iterate -} - fn initial(_db: &dyn KnobsDatabase) -> CycleValue { MIN } diff --git a/tests/parallel/cycle_nested_three_threads.rs b/tests/parallel/cycle_nested_three_threads.rs index 22232bd85..728fc3e70 100644 --- a/tests/parallel/cycle_nested_three_threads.rs +++ b/tests/parallel/cycle_nested_three_threads.rs @@ -17,8 +17,6 @@ use crate::sync::thread; use crate::{Knobs, KnobsDatabase}; -use salsa::CycleRecoveryAction; - #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, salsa::Update)] struct CycleValue(u32); @@ -29,7 +27,7 @@ const MAX: CycleValue = CycleValue(3); // Signal 2: T2 has entered `query_b` // Signal 3: T3 has entered `query_c` -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_a(db: &dyn KnobsDatabase) -> 
CycleValue { db.signal(1); db.wait_for(3); @@ -37,7 +35,7 @@ fn query_a(db: &dyn KnobsDatabase) -> CycleValue { query_b(db) } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_b(db: &dyn KnobsDatabase) -> CycleValue { db.wait_for(1); db.signal(2); @@ -47,7 +45,7 @@ fn query_b(db: &dyn KnobsDatabase) -> CycleValue { CycleValue(c_value.0 + 1).min(MAX) } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_c(db: &dyn KnobsDatabase) -> CycleValue { db.wait_for(2); db.signal(3); @@ -57,14 +55,6 @@ fn query_c(db: &dyn KnobsDatabase) -> CycleValue { CycleValue(a_value.0.max(b_value.0)) } -fn cycle_fn( - _db: &dyn KnobsDatabase, - _value: &CycleValue, - _count: u32, -) -> CycleRecoveryAction { - CycleRecoveryAction::Iterate -} - fn initial(_db: &dyn KnobsDatabase) -> CycleValue { MIN } diff --git a/tests/parallel/cycle_nested_three_threads_changed.rs b/tests/parallel/cycle_nested_three_threads_changed.rs index ccd92a407..626b3ef90 100644 --- a/tests/parallel/cycle_nested_three_threads_changed.rs +++ b/tests/parallel/cycle_nested_three_threads_changed.rs @@ -19,7 +19,7 @@ use crate::sync; use crate::sync::thread; -use salsa::{CycleRecoveryAction, DatabaseImpl, Setter as _}; +use salsa::{DatabaseImpl, Setter as _}; #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, salsa::Update)] struct CycleValue(u32); @@ -36,33 +36,24 @@ struct Input { // Signal 2: T2 has entered `query_b` // Signal 3: T3 has entered `query_c` -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_a(db: &dyn salsa::Database, input: Input) -> CycleValue { query_b(db, input) } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_b(db: &dyn salsa::Database, input: Input) -> CycleValue { let c_value = query_c(db, input); CycleValue(c_value.0 + 
input.value(db)).min(MAX) } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +#[salsa::tracked(cycle_initial=initial)] fn query_c(db: &dyn salsa::Database, input: Input) -> CycleValue { let a_value = query_a(db, input); let b_value = query_b(db, input); CycleValue(a_value.0.max(b_value.0)) } -fn cycle_fn( - _db: &dyn salsa::Database, - _value: &CycleValue, - _count: u32, - _input: Input, -) -> CycleRecoveryAction { - CycleRecoveryAction::Iterate -} - fn initial(_db: &dyn salsa::Database, _input: Input) -> CycleValue { MIN } diff --git a/tests/parallel/cycle_panic.rs b/tests/parallel/cycle_panic.rs index a713809b7..13c988f8f 100644 --- a/tests/parallel/cycle_panic.rs +++ b/tests/parallel/cycle_panic.rs @@ -18,7 +18,13 @@ fn query_b(db: &dyn KnobsDatabase) -> u32 { query_a(db) + 1 } -fn cycle_fn(_db: &dyn KnobsDatabase, _value: &u32, _count: u32) -> salsa::CycleRecoveryAction { +fn cycle_fn( + _db: &dyn KnobsDatabase, + _id: salsa::Id, + _last_provisional_value: &u32, + _value: &u32, + _count: u32, +) -> salsa::CycleRecoveryAction { panic!("cancel!") } diff --git a/tests/parallel/cycle_provisional_depending_on_itself.rs b/tests/parallel/cycle_provisional_depending_on_itself.rs index ba3645fd5..bb615210e 100644 --- a/tests/parallel/cycle_provisional_depending_on_itself.rs +++ b/tests/parallel/cycle_provisional_depending_on_itself.rs @@ -19,7 +19,6 @@ //! 3. 
`t1`: Iterates on `a`, finalizes the memo use crate::sync::thread; -use salsa::CycleRecoveryAction; use crate::setup::{Knobs, KnobsDatabase}; @@ -29,12 +28,12 @@ struct CycleValue(u32); const MIN: CycleValue = CycleValue(0); const MAX: CycleValue = CycleValue(1); -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=cycle_initial)] +#[salsa::tracked(cycle_initial=cycle_initial)] fn query_a(db: &dyn KnobsDatabase) -> CycleValue { query_b(db) } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=cycle_initial)] +#[salsa::tracked(cycle_initial=cycle_initial)] fn query_b(db: &dyn KnobsDatabase) -> CycleValue { // Wait for thread 2 to have entered `query_c`. tracing::debug!("Wait for signal 1 from thread 2"); @@ -55,7 +54,7 @@ fn query_b(db: &dyn KnobsDatabase) -> CycleValue { CycleValue(a_value.0 + 1).min(MAX) } -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=cycle_initial)] +#[salsa::tracked(cycle_initial=cycle_initial)] fn query_c(db: &dyn KnobsDatabase) -> CycleValue { tracing::debug!("query_c: signaling thread1 to call c"); db.signal(1); @@ -68,14 +67,6 @@ fn query_c(db: &dyn KnobsDatabase) -> CycleValue { b } -fn cycle_fn( - _db: &dyn KnobsDatabase, - _value: &CycleValue, - _count: u32, -) -> CycleRecoveryAction { - CycleRecoveryAction::Iterate -} - fn cycle_initial(_db: &dyn KnobsDatabase) -> CycleValue { MIN } From 25b3ef146cfa2615f4ec82760bd0c22b454d0a12 Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Fri, 24 Oct 2025 17:58:45 +0200 Subject: [PATCH 56/65] Fix cache invalidation when cycle head becomes non-head (#1014) * Fix cache invalidation when cycle head becomes non-head * Discard changes to src/function/fetch.rs * Inline comment --- src/cycle.rs | 15 +++++++++++++-- src/function.rs | 18 +++++++----------- src/function/execute.rs | 10 +++++++--- src/function/maybe_changed_after.rs | 28 ++++++++++++++++++++++++++-- src/function/memo.rs | 21 ++++++++++++++------- src/ingredient.rs | 14 ++++++-------- 6 files changed, 73 insertions(+), 33 deletions(-) diff 
--git a/src/cycle.rs b/src/cycle.rs index 09ec51525..fcbadf891 100644 --- a/src/cycle.rs +++ b/src/cycle.rs @@ -346,6 +346,7 @@ impl CycleHeads { if *removed { *removed = false; + existing.iteration_count.store_mut(iteration_count); true } else { @@ -468,11 +469,12 @@ pub(crate) fn empty_cycle_heads() -> &'static CycleHeads { EMPTY_CYCLE_HEADS.get_or_init(|| CycleHeads(ThinVec::new())) } -#[derive(Debug, PartialEq, Eq)] -pub enum ProvisionalStatus { +#[derive(Debug)] +pub enum ProvisionalStatus<'db> { Provisional { iteration: IterationCount, verified_at: Revision, + cycle_heads: &'db CycleHeads, }, Final { iteration: IterationCount, @@ -480,3 +482,12 @@ pub enum ProvisionalStatus { }, FallbackImmediate, } + +impl<'db> ProvisionalStatus<'db> { + pub(crate) fn cycle_heads(&self) -> &'db CycleHeads { + match self { + ProvisionalStatus::Provisional { cycle_heads, .. } => cycle_heads, + _ => empty_cycle_heads(), + } + } +} diff --git a/src/function.rs b/src/function.rs index 1cf3e9478..512c8ba70 100644 --- a/src/function.rs +++ b/src/function.rs @@ -7,10 +7,7 @@ use std::ptr::NonNull; use std::sync::atomic::Ordering; use std::sync::OnceLock; -use crate::cycle::{ - empty_cycle_heads, CycleHeads, CycleRecoveryAction, CycleRecoveryStrategy, IterationCount, - ProvisionalStatus, -}; +use crate::cycle::{CycleRecoveryAction, CycleRecoveryStrategy, IterationCount, ProvisionalStatus}; use crate::database::RawDatabase; use crate::function::delete::DeletedEntries; use crate::hash::{FxHashSet, FxIndexSet}; @@ -357,7 +354,11 @@ where /// /// Otherwise, the value is still provisional. For both final and provisional, it also /// returns the iteration in which this memo was created (always 0 except for cycle heads). 
- fn provisional_status(&self, zalsa: &Zalsa, input: Id) -> Option { + fn provisional_status<'db>( + &self, + zalsa: &'db Zalsa, + input: Id, + ) -> Option> { let memo = self.get_memo_from_table_for(zalsa, input, self.memo_ingredient_index(zalsa, input))?; @@ -377,6 +378,7 @@ where ProvisionalStatus::Provisional { iteration, verified_at: memo.verified_at.load(), + cycle_heads: memo.cycle_heads(), } }) } @@ -416,12 +418,6 @@ where self.sync_table.mark_as_transfer_target(key_index) } - fn cycle_heads<'db>(&self, zalsa: &'db Zalsa, input: Id) -> &'db CycleHeads { - self.get_memo_from_table_for(zalsa, input, self.memo_ingredient_index(zalsa, input)) - .map(|memo| memo.cycle_heads()) - .unwrap_or(empty_cycle_heads()) - } - /// Attempts to claim `key_index` without blocking. /// /// * [`WaitForResult::Running`] if the `key_index` is running on another thread. It's up to the caller to block on the other thread diff --git a/src/function/execute.rs b/src/function/execute.rs index 3acfaadc8..5e3c226be 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -248,9 +248,11 @@ where let ingredient = zalsa.lookup_ingredient(head.database_key_index.ingredient_index()); - for nested_head in - ingredient.cycle_heads(zalsa, head.database_key_index.key_index()) - { + let provisional_status = ingredient + .provisional_status(zalsa, head.database_key_index.key_index()) + .expect("cycle head memo must have been created during the execution"); + + for nested_head in provisional_status.cycle_heads() { let nested_as_tuple = ( nested_head.database_key_index, nested_head.iteration_count.load(), @@ -442,6 +444,8 @@ where // Update the iteration count of this cycle head, but only after restoring // the cycle heads array (or this becomes a no-op). 
+ // We don't call the same method on `cycle_heads` because that one doesn't update + // the `memo.iteration_count` completed_query.revisions.set_cycle_heads(cycle_heads); completed_query .revisions diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index 62839e865..4198631b9 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -484,8 +484,9 @@ where // Always return `false` for cycle initial values "unless" they are running in the same thread. if cycle_heads - .iter() - .all(|head| head.database_key_index == memo_database_key_index) + .iter_not_eq(memo_database_key_index) + .next() + .is_none() { // SAFETY: We do not access the query stack reentrantly. let on_stack = unsafe { @@ -508,6 +509,8 @@ where head_iteration_count, memo_iteration_count: current_iteration_count, verified_at: head_verified_at, + cycle_heads, + database_key_index: head_database_key, } => { if head_verified_at != memo_verified_at { return false; } @@ -516,6 +519,27 @@ where if head_iteration_count != current_iteration_count { return false; } + + // Check if the memo is still a cycle head and hasn't changed + // to a normal cycle participant. This is to force re-execution in + // a scenario like this: + // + // * There's a nested cycle with the outermost query A + // * B participates in the cycle and is a cycle head in the first few iterations + // * B becomes a non-cycle head in a later iteration + // * There's a query `C` that has `B` as its cycle head + // + // The crucial point is that `B` switches from being a cycle head to being a regular cycle participant. + // The issue with that is that `A` doesn't update `B`'s `iteration_count` when the iteration completes + // because it only does that for cycle heads (and collecting all queries participating in a query would be sort of expensive?).
+ // + // When we now pull `C` in a later iteration, `validate_same_iteration` iterates over all its cycle heads (`B`), + // and checks if the iteration count still matches, which is the case because `A` didn't update `B`'s iteration count. + // + // That's why we also check if `B` is still a cycle head in the current iteration. + if !cycle_heads.contains(&head_database_key) { + return false; + } } _ => { return false; diff --git a/src/function/memo.rs b/src/function/memo.rs index 8fe0c1dd8..200f83a4d 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -409,9 +409,11 @@ mod persistence { pub(super) enum TryClaimHeadsResult<'me> { /// Claiming the cycle head results in a cycle. Cycle { + database_key_index: DatabaseKeyIndex, head_iteration_count: IterationCount, memo_iteration_count: IterationCount, verified_at: Revision, + cycle_heads: &'me CycleHeads, }, /// The cycle head is not finalized, but it can be claimed. @@ -458,23 +460,28 @@ impl<'me> Iterator for TryClaimCycleHeadsIter<'me> { let provisional_status = ingredient .provisional_status(self.zalsa, head_key_index) .expect("cycle head memo to exist"); - let (current_iteration_count, verified_at) = match provisional_status { + let (current_iteration_count, verified_at, cycle_heads) = match provisional_status { ProvisionalStatus::Provisional { iteration, verified_at, - } - | ProvisionalStatus::Final { + cycle_heads, + } => (iteration, verified_at, cycle_heads), + ProvisionalStatus::Final { iteration, verified_at, - } => (iteration, verified_at), - ProvisionalStatus::FallbackImmediate => { - (IterationCount::initial(), self.zalsa.current_revision()) - } + } => (iteration, verified_at, empty_cycle_heads()), + ProvisionalStatus::FallbackImmediate => ( + IterationCount::initial(), + self.zalsa.current_revision(), + empty_cycle_heads(), + ), }; Some(TryClaimHeadsResult::Cycle { + database_key_index: head_database_key, memo_iteration_count: current_iteration_count, head_iteration_count:
head.iteration_count.load(), + cycle_heads, verified_at, }) } diff --git a/src/ingredient.rs b/src/ingredient.rs index 9b377e4d1..6fe525c4f 100644 --- a/src/ingredient.rs +++ b/src/ingredient.rs @@ -1,7 +1,7 @@ use std::any::{Any, TypeId}; use std::fmt; -use crate::cycle::{empty_cycle_heads, CycleHeads, IterationCount, ProvisionalStatus}; +use crate::cycle::{IterationCount, ProvisionalStatus}; use crate::database::RawDatabase; use crate::function::{VerifyCycleHeads, VerifyResult}; use crate::hash::{FxHashSet, FxIndexSet}; @@ -74,16 +74,14 @@ pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { /// Is it a provisional value or has it been finalized and in which iteration. /// /// Returns `None` if `input` doesn't exist. - fn provisional_status(&self, _zalsa: &Zalsa, _input: Id) -> Option { + fn provisional_status<'db>( + &self, + _zalsa: &'db Zalsa, + _input: Id, + ) -> Option> { unreachable!("provisional_status should only be called on cycle heads and only functions can be cycle heads"); } - /// Returns the cycle heads for this ingredient. - fn cycle_heads<'db>(&self, zalsa: &'db Zalsa, input: Id) -> &'db CycleHeads { - _ = (zalsa, input); - empty_cycle_heads() - } - /// Invoked when the current thread needs to wait for a result for the given `key_index`. /// This call doesn't block the current thread. Instead, it's up to the caller to block /// in case `key_index` is [running](`WaitForResult::Running`) on another thread. 
From e8ddb4dbf7f0adbfa951a6f6e793a2ce3b165355 Mon Sep 17 00:00:00 2001 From: Lukas Wirth Date: Sun, 26 Oct 2025 15:44:49 +0100 Subject: [PATCH 57/65] pref: Add `SyncTable::peek_claim` fast path for `function::Ingredient::wait_for` (#1011) --- src/function.rs | 4 +-- src/function/sync.rs | 82 ++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 78 insertions(+), 8 deletions(-) diff --git a/src/function.rs b/src/function.rs index 512c8ba70..434a895a5 100644 --- a/src/function.rs +++ b/src/function.rs @@ -428,11 +428,11 @@ where fn wait_for<'me>(&'me self, zalsa: &'me Zalsa, key_index: Id) -> WaitForResult<'me> { match self .sync_table - .try_claim(zalsa, key_index, Reentrancy::Deny) + .peek_claim(zalsa, key_index, Reentrancy::Deny) { ClaimResult::Running(blocked_on) => WaitForResult::Running(blocked_on), ClaimResult::Cycle { inner } => WaitForResult::Cycle { inner }, - ClaimResult::Claimed(_) => WaitForResult::Available, + ClaimResult::Claimed(()) => WaitForResult::Available, } } diff --git a/src/function/sync.rs b/src/function/sync.rs index 02f1bffd0..c9a74a307 100644 --- a/src/function/sync.rs +++ b/src/function/sync.rs @@ -20,7 +20,7 @@ pub(crate) struct SyncTable { ingredient: IngredientIndex, } -pub(crate) enum ClaimResult<'a> { +pub(crate) enum ClaimResult<'a, Guard = ClaimGuard<'a>> { /// Can't claim the query because it is running on an other thread. Running(Running<'a>), /// Claiming the query results in a cycle. @@ -31,7 +31,7 @@ pub(crate) enum ClaimResult<'a> { inner: bool, }, /// Successfully claimed the query. - Claimed(ClaimGuard<'a>), + Claimed(Guard), } pub(crate) struct SyncState { @@ -87,10 +87,7 @@ impl SyncTable { } }; - let &mut SyncState { - ref mut anyone_waiting, - .. - } = occupied_entry.into_mut(); + let SyncState { anyone_waiting, .. 
} = occupied_entry.into_mut(); // NB: `Ordering::Relaxed` is sufficient here, // as there are no loads that are "gated" on this @@ -125,6 +122,51 @@ impl SyncTable { } } + /// Claims the given key index, or blocks if it is running on another thread. + pub(crate) fn peek_claim<'me>( + &'me self, + zalsa: &'me Zalsa, + key_index: Id, + reentrant: Reentrancy, + ) -> ClaimResult<'me, ()> { + let mut write = self.syncs.lock(); + match write.entry(key_index) { + std::collections::hash_map::Entry::Occupied(occupied_entry) => { + let id = match occupied_entry.get().id { + SyncOwner::Thread(id) => id, + SyncOwner::Transferred => { + return match self.peek_claim_transferred(zalsa, occupied_entry, reentrant) { + Ok(claimed) => claimed, + Err(other_thread) => match other_thread.block(write) { + BlockResult::Cycle => ClaimResult::Cycle { inner: false }, + BlockResult::Running(running) => ClaimResult::Running(running), + }, + } + } + }; + + let SyncState { anyone_waiting, .. } = occupied_entry.into_mut(); + + // NB: `Ordering::Relaxed` is sufficient here, + // as there are no loads that are "gated" on this + // value. Everything that is written is also protected + // by a lock that must be acquired. The role of this + // boolean is to decide *whether* to acquire the lock, + // not to gate future atomic reads. 
+ *anyone_waiting = true; + match zalsa.runtime().block( + DatabaseKeyIndex::new(self.ingredient, key_index), + id, + write, + ) { + BlockResult::Running(blocked_on) => ClaimResult::Running(blocked_on), + BlockResult::Cycle => ClaimResult::Cycle { inner: false }, + } + } + std::collections::hash_map::Entry::Vacant(_) => ClaimResult::Claimed(()), + } + } + #[cold] #[inline(never)] fn try_claim_transferred<'me>( @@ -179,6 +221,34 @@ impl SyncTable { } } + #[cold] + #[inline(never)] + fn peek_claim_transferred<'me>( + &'me self, + zalsa: &'me Zalsa, + mut entry: OccupiedEntry, + reentrant: Reentrancy, + ) -> Result, Box>> { + let key_index = *entry.key(); + let database_key_index = DatabaseKeyIndex::new(self.ingredient, key_index); + let thread_id = thread::current().id(); + + match zalsa + .runtime() + .block_transferred(database_key_index, thread_id) + { + BlockTransferredResult::ImTheOwner if reentrant.is_allow() => { + Ok(ClaimResult::Claimed(())) + } + BlockTransferredResult::ImTheOwner => Ok(ClaimResult::Cycle { inner: true }), + BlockTransferredResult::OwnedBy(other_thread) => { + entry.get_mut().anyone_waiting = true; + Err(other_thread) + } + BlockTransferredResult::Released => Ok(ClaimResult::Claimed(())), + } + } + /// Marks `key_index` as a transfer target. /// /// Returns the `SyncOwnerId` of the thread that currently owns this query. 
From cdd0b85516a52c18b8a6d17a2279a96ed6c3e198 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama <45118249+mtshiba@users.noreply.github.com> Date: Mon, 27 Oct 2025 21:27:09 +0900 Subject: [PATCH 58/65] Expose the Input query Id with cycle_initial (#1015) --- benches/dataflow.rs | 4 ++-- components/salsa-macro-rules/src/setup_tracked_fn.rs | 4 ++-- .../salsa-macro-rules/src/unexpected_cycle_recovery.rs | 2 +- src/function.rs | 6 +++++- src/function/execute.rs | 2 +- src/function/fetch.rs | 4 ++-- src/function/memo.rs | 6 +++++- tests/backtrace.rs | 2 +- tests/cycle.rs | 8 ++++---- tests/cycle_accumulate.rs | 2 +- tests/cycle_fallback_immediate.rs | 4 ++-- tests/cycle_initial_call_back_into_cycle.rs | 2 +- tests/cycle_initial_call_query.rs | 2 +- tests/cycle_maybe_changed_after.rs | 4 ++-- tests/cycle_output.rs | 2 +- tests/cycle_recovery_call_back_into_cycle.rs | 2 +- tests/cycle_recovery_call_query.rs | 2 +- tests/cycle_regression_455.rs | 2 +- tests/cycle_result_dependencies.rs | 2 +- tests/cycle_tracked.rs | 8 ++++++-- tests/cycle_tracked_own_input.rs | 2 +- tests/dataflow.rs | 4 ++-- tests/parallel/cycle_a_t1_b_t2.rs | 2 +- tests/parallel/cycle_a_t1_b_t2_fallback.rs | 4 ++-- tests/parallel/cycle_ab_peeping_c.rs | 2 +- tests/parallel/cycle_iteration_mismatch.rs | 2 +- tests/parallel/cycle_nested_deep.rs | 2 +- tests/parallel/cycle_nested_deep_conditional.rs | 2 +- tests/parallel/cycle_nested_deep_conditional_changed.rs | 2 +- tests/parallel/cycle_nested_deep_panic.rs | 2 +- tests/parallel/cycle_nested_three_threads.rs | 2 +- tests/parallel/cycle_nested_three_threads_changed.rs | 2 +- tests/parallel/cycle_panic.rs | 2 +- tests/parallel/cycle_provisional_depending_on_itself.rs | 2 +- 34 files changed, 57 insertions(+), 45 deletions(-) diff --git a/benches/dataflow.rs b/benches/dataflow.rs index d1acfd27b..a548c806a 100644 --- a/benches/dataflow.rs +++ b/benches/dataflow.rs @@ -70,7 +70,7 @@ fn infer_definition<'db>(db: &'db dyn Db, def: Definition) -> Type { } } -fn 
def_cycle_initial(_db: &dyn Db, _def: Definition) -> Type { +fn def_cycle_initial(_db: &dyn Db, _id: salsa::Id, _def: Definition) -> Type { Type::Bottom } @@ -85,7 +85,7 @@ fn def_cycle_recover( cycle_recover(value, count) } -fn use_cycle_initial(_db: &dyn Db, _use: Use) -> Type { +fn use_cycle_initial(_db: &dyn Db, _id: salsa::Id, _use: Use) -> Type { Type::Bottom } diff --git a/components/salsa-macro-rules/src/setup_tracked_fn.rs b/components/salsa-macro-rules/src/setup_tracked_fn.rs index 961b5b4f8..8ea4e5e33 100644 --- a/components/salsa-macro-rules/src/setup_tracked_fn.rs +++ b/components/salsa-macro-rules/src/setup_tracked_fn.rs @@ -302,8 +302,8 @@ macro_rules! setup_tracked_fn { $inner($db, $($input_id),*) } - fn cycle_initial<$db_lt>(db: &$db_lt Self::DbView, ($($input_id),*): ($($interned_input_ty),*)) -> Self::Output<$db_lt> { - $($cycle_recovery_initial)*(db, $($input_id),*) + fn cycle_initial<$db_lt>(db: &$db_lt Self::DbView, id: salsa::Id, ($($input_id),*): ($($interned_input_ty),*)) -> Self::Output<$db_lt> { + $($cycle_recovery_initial)*(db, id, $($input_id),*) } fn recover_from_cycle<$db_lt>( diff --git a/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs b/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs index aa6161d28..ff03c02a2 100644 --- a/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs +++ b/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs @@ -12,7 +12,7 @@ macro_rules! unexpected_cycle_recovery { #[macro_export] macro_rules! 
unexpected_cycle_initial { - ($db:ident, $($other_inputs:ident),*) => {{ + ($db:ident, $id:ident, $($other_inputs:ident),*) => {{ std::mem::drop($db); std::mem::drop(($($other_inputs,)*)); panic!("no cycle initial value") diff --git a/src/function.rs b/src/function.rs index 434a895a5..045825e19 100644 --- a/src/function.rs +++ b/src/function.rs @@ -85,7 +85,11 @@ pub trait Configuration: Any { fn execute<'db>(db: &'db Self::DbView, input: Self::Input<'db>) -> Self::Output<'db>; /// Get the cycle recovery initial value. - fn cycle_initial<'db>(db: &'db Self::DbView, input: Self::Input<'db>) -> Self::Output<'db>; + fn cycle_initial<'db>( + db: &'db Self::DbView, + id: Id, + input: Self::Input<'db>, + ) -> Self::Output<'db>; /// Decide whether to iterate a cycle again or fallback. `value` is the provisional return /// value from the latest iteration of this cycle. `count` is the number of cycle iterations diff --git a/src/function/execute.rs b/src/function/execute.rs index 5e3c226be..d299b0966 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -94,7 +94,7 @@ where let cycle_heads = std::mem::take(cycle_heads); let active_query = zalsa_local.push_query(database_key_index, IterationCount::initial()); - new_value = C::cycle_initial(db, C::id_to_input(zalsa, id)); + new_value = C::cycle_initial(db, id, C::id_to_input(zalsa, id)); completed_query = active_query.pop(); // We need to set `cycle_heads` and `verified_final` because it needs to propagate to the callers. // When verifying this, we will see we have fallback and mark ourselves verified. 
diff --git a/src/function/fetch.rs b/src/function/fetch.rs index a3f3705f4..14d7a93d7 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -236,7 +236,7 @@ where inserting and returning fixpoint initial value" ); let revisions = QueryRevisions::fixpoint_initial(database_key_index); - let initial_value = C::cycle_initial(db, C::id_to_input(zalsa, id)); + let initial_value = C::cycle_initial(db, id, C::id_to_input(zalsa, id)); self.insert_memo( zalsa, id, @@ -250,7 +250,7 @@ where ); let active_query = zalsa_local.push_query(database_key_index, IterationCount::initial()); - let fallback_value = C::cycle_initial(db, C::id_to_input(zalsa, id)); + let fallback_value = C::cycle_initial(db, id, C::id_to_input(zalsa, id)); let mut completed_query = active_query.pop(); completed_query .revisions diff --git a/src/function/memo.rs b/src/function/memo.rs index 200f83a4d..fd830ced3 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -558,7 +558,11 @@ mod _memory_usage { unimplemented!() } - fn cycle_initial<'db>(_: &'db Self::DbView, _: Self::Input<'db>) -> Self::Output<'db> { + fn cycle_initial<'db>( + _: &'db Self::DbView, + _: Id, + _: Self::Input<'db>, + ) -> Self::Output<'db> { unimplemented!() } diff --git a/tests/backtrace.rs b/tests/backtrace.rs index 0adf517cd..3cc5bbad0 100644 --- a/tests/backtrace.rs +++ b/tests/backtrace.rs @@ -52,7 +52,7 @@ fn query_cycle(db: &dyn Database, thing: Thing) -> String { } } -fn cycle_initial(_db: &dyn salsa::Database, _thing: Thing) -> String { +fn cycle_initial(_db: &dyn salsa::Database, _id: salsa::Id, _thing: Thing) -> String { String::new() } diff --git a/tests/cycle.rs b/tests/cycle.rs index 0c4d686af..dbe0bdc19 100644 --- a/tests/cycle.rs +++ b/tests/cycle.rs @@ -173,7 +173,7 @@ fn min_iterate<'db>(db: &'db dyn Db, inputs: Inputs) -> Value { fold_values(inputs.values(db), u8::min) } -fn min_initial(_db: &dyn Db, _inputs: Inputs) -> Value { +fn min_initial(_db: &dyn Db, _id: salsa::Id, _inputs: Inputs) -> 
Value { Value::N(255) } @@ -183,7 +183,7 @@ fn max_iterate<'db>(db: &'db dyn Db, inputs: Inputs) -> Value { fold_values(inputs.values(db), u8::max) } -fn max_initial(_db: &dyn Db, _inputs: Inputs) -> Value { +fn max_initial(_db: &dyn Db, _id: salsa::Id, _inputs: Inputs) -> Value { Value::N(0) } @@ -1175,7 +1175,7 @@ fn repeat_query_participating_in_cycle() { a.min(2) } - fn initial(_db: &dyn Db, _input: Input) -> u32 { + fn initial(_db: &dyn Db, _id: salsa::Id, _input: Input) -> u32 { 0 } @@ -1280,7 +1280,7 @@ fn repeat_query_participating_in_cycle2() { a.min(2) } - fn initial(_db: &dyn Db, _input: Input) -> u32 { + fn initial(_db: &dyn Db, _id: salsa::Id, _input: Input) -> u32 { 0 } diff --git a/tests/cycle_accumulate.rs b/tests/cycle_accumulate.rs index 8148e952d..49f1d06d9 100644 --- a/tests/cycle_accumulate.rs +++ b/tests/cycle_accumulate.rs @@ -44,7 +44,7 @@ fn check_file(db: &dyn LogDatabase, file: File) -> Vec { sorted_issues } -fn cycle_initial(_db: &dyn LogDatabase, _file: File) -> Vec { +fn cycle_initial(_db: &dyn LogDatabase, _id: salsa::Id, _file: File) -> Vec { vec![] } diff --git a/tests/cycle_fallback_immediate.rs b/tests/cycle_fallback_immediate.rs index 374978d81..64f872ad1 100644 --- a/tests/cycle_fallback_immediate.rs +++ b/tests/cycle_fallback_immediate.rs @@ -11,7 +11,7 @@ fn one_o_one(db: &dyn salsa::Database) -> u32 { val + 1 } -fn cycle_result(_db: &dyn salsa::Database) -> u32 { +fn cycle_result(_db: &dyn salsa::Database, _id: salsa::Id) -> u32 { 100 } @@ -38,7 +38,7 @@ fn two_queries2(db: &dyn salsa::Database) -> i32 { CALLS_COUNT.fetch_add(1, Ordering::Relaxed) } -fn two_queries_cycle_result(_db: &dyn salsa::Database) -> i32 { +fn two_queries_cycle_result(_db: &dyn salsa::Database, _id: salsa::Id) -> i32 { 1 } diff --git a/tests/cycle_initial_call_back_into_cycle.rs b/tests/cycle_initial_call_back_into_cycle.rs index e56c4c4d1..ab7a473a2 100644 --- a/tests/cycle_initial_call_back_into_cycle.rs +++ 
b/tests/cycle_initial_call_back_into_cycle.rs @@ -17,7 +17,7 @@ fn query(db: &dyn salsa::Database) -> u32 { } } -fn cycle_initial(db: &dyn salsa::Database) -> u32 { +fn cycle_initial(db: &dyn salsa::Database, _id: salsa::Id) -> u32 { initial_value(db) } diff --git a/tests/cycle_initial_call_query.rs b/tests/cycle_initial_call_query.rs index 2212ef958..b16b72711 100644 --- a/tests/cycle_initial_call_query.rs +++ b/tests/cycle_initial_call_query.rs @@ -17,7 +17,7 @@ fn query(db: &dyn salsa::Database) -> u32 { } } -fn cycle_initial(db: &dyn salsa::Database) -> u32 { +fn cycle_initial(db: &dyn salsa::Database, _id: salsa::Id) -> u32 { initial_value(db) } diff --git a/tests/cycle_maybe_changed_after.rs b/tests/cycle_maybe_changed_after.rs index 8c00c484a..f411404d5 100644 --- a/tests/cycle_maybe_changed_after.rs +++ b/tests/cycle_maybe_changed_after.rs @@ -36,7 +36,7 @@ fn query_d<'db>(db: &'db dyn salsa::Database, input: Input) -> u32 { } } -fn query_a_initial(_db: &dyn Database, _input: Input) -> u32 { +fn query_a_initial(_db: &dyn Database, _id: salsa::Id, _input: Input) -> u32 { 0 } @@ -128,7 +128,7 @@ fn nested_cycle_fewer_dependencies_in_first_iteration() { }) } - fn head_initial(_db: &dyn Database, _input: Input) -> Option> { + fn head_initial(_db: &dyn Database, _id: salsa::Id, _input: Input) -> Option> { None } diff --git a/tests/cycle_output.rs b/tests/cycle_output.rs index 02a3b569f..c4a9384e0 100644 --- a/tests/cycle_output.rs +++ b/tests/cycle_output.rs @@ -40,7 +40,7 @@ fn query_b(db: &dyn Db, input: InputValue) -> u32 { query_a(db, input) } -fn cycle_initial(_db: &dyn Db, _input: InputValue) -> u32 { +fn cycle_initial(_db: &dyn Db, _id: salsa::Id, _input: InputValue) -> u32 { 0 } diff --git a/tests/cycle_recovery_call_back_into_cycle.rs b/tests/cycle_recovery_call_back_into_cycle.rs index 358f988ad..4ab236565 100644 --- a/tests/cycle_recovery_call_back_into_cycle.rs +++ b/tests/cycle_recovery_call_back_into_cycle.rs @@ -21,7 +21,7 @@ fn query(db: &dyn 
ValueDatabase) -> u32 { } } -fn cycle_initial(_db: &dyn ValueDatabase) -> u32 { +fn cycle_initial(_db: &dyn ValueDatabase, _id: salsa::Id) -> u32 { 0 } diff --git a/tests/cycle_recovery_call_query.rs b/tests/cycle_recovery_call_query.rs index 37341a202..a227d6122 100644 --- a/tests/cycle_recovery_call_query.rs +++ b/tests/cycle_recovery_call_query.rs @@ -17,7 +17,7 @@ fn query(db: &dyn salsa::Database) -> u32 { } } -fn cycle_initial(_db: &dyn salsa::Database) -> u32 { +fn cycle_initial(_db: &dyn salsa::Database, _id: salsa::Id) -> u32 { 0 } diff --git a/tests/cycle_regression_455.rs b/tests/cycle_regression_455.rs index a083cb996..2957e5284 100644 --- a/tests/cycle_regression_455.rs +++ b/tests/cycle_regression_455.rs @@ -12,7 +12,7 @@ fn memoized_a<'db>(db: &'db dyn Database, tracked: MyTracked<'db>) -> u32 { MyTracked::new(db, 0); memoized_b(db, tracked) } -fn cycle_initial(_db: &dyn Database, _input: MyTracked) -> u32 { +fn cycle_initial(_db: &dyn Database, _id: salsa::Id, _input: MyTracked) -> u32 { 0 } diff --git a/tests/cycle_result_dependencies.rs b/tests/cycle_result_dependencies.rs index 8e025f998..d614f956e 100644 --- a/tests/cycle_result_dependencies.rs +++ b/tests/cycle_result_dependencies.rs @@ -12,7 +12,7 @@ fn has_cycle(db: &dyn Database, input: Input) -> i32 { has_cycle(db, input) } -fn cycle_result(db: &dyn Database, input: Input) -> i32 { +fn cycle_result(db: &dyn Database, _id: salsa::Id, input: Input) -> i32 { input.value(db) } diff --git a/tests/cycle_tracked.rs b/tests/cycle_tracked.rs index 5ee4e1620..1a5b82ee6 100644 --- a/tests/cycle_tracked.rs +++ b/tests/cycle_tracked.rs @@ -110,7 +110,7 @@ fn cost_to_start<'db>(db: &'db dyn Database, node: Node<'db>) -> usize { min_cost } -fn max_initial(_db: &dyn Database, _node: Node) -> usize { +fn max_initial(_db: &dyn Database, _id: salsa::Id, _node: Node) -> usize { usize::MAX } @@ -246,7 +246,11 @@ fn create_tracked_in_cycle<'db>( } } -fn initial_with_structs(_db: &dyn Database, _input: 
GraphInput) -> Vec> { +fn initial_with_structs( + _db: &dyn Database, + _id: salsa::Id, + _input: GraphInput, +) -> Vec> { vec![] } diff --git a/tests/cycle_tracked_own_input.rs b/tests/cycle_tracked_own_input.rs index 79035bab5..0359c2df2 100644 --- a/tests/cycle_tracked_own_input.rs +++ b/tests/cycle_tracked_own_input.rs @@ -81,7 +81,7 @@ fn infer_type_param<'db>(db: &'db dyn salsa::Database, node: TypeParamNode) -> T } } -fn infer_class_initial(_db: &'_ dyn Database, _node: ClassNode) -> Type<'_> { +fn infer_class_initial(_db: &'_ dyn Database, _id: salsa::Id, _node: ClassNode) -> Type<'_> { Type::Unknown } diff --git a/tests/dataflow.rs b/tests/dataflow.rs index 793870322..69c91d513 100644 --- a/tests/dataflow.rs +++ b/tests/dataflow.rs @@ -71,7 +71,7 @@ fn infer_definition<'db>(db: &'db dyn Db, def: Definition) -> Type { } } -fn def_cycle_initial(_db: &dyn Db, _def: Definition) -> Type { +fn def_cycle_initial(_db: &dyn Db, _id: salsa::Id, _def: Definition) -> Type { Type::Bottom } @@ -86,7 +86,7 @@ fn def_cycle_recover( cycle_recover(value, count) } -fn use_cycle_initial(_db: &dyn Db, _use: Use) -> Type { +fn use_cycle_initial(_db: &dyn Db, _id: salsa::Id, _use: Use) -> Type { Type::Bottom } diff --git a/tests/parallel/cycle_a_t1_b_t2.rs b/tests/parallel/cycle_a_t1_b_t2.rs index 6a434099e..95b2a3d28 100644 --- a/tests/parallel/cycle_a_t1_b_t2.rs +++ b/tests/parallel/cycle_a_t1_b_t2.rs @@ -45,7 +45,7 @@ fn query_b(db: &dyn KnobsDatabase) -> CycleValue { CycleValue(a_value.0 + 1).min(MAX) } -fn initial(_db: &dyn KnobsDatabase) -> CycleValue { +fn initial(_db: &dyn KnobsDatabase, _id: salsa::Id) -> CycleValue { MIN } diff --git a/tests/parallel/cycle_a_t1_b_t2_fallback.rs b/tests/parallel/cycle_a_t1_b_t2_fallback.rs index b2d6631cc..b49fa0448 100644 --- a/tests/parallel/cycle_a_t1_b_t2_fallback.rs +++ b/tests/parallel/cycle_a_t1_b_t2_fallback.rs @@ -41,11 +41,11 @@ fn query_b(db: &dyn KnobsDatabase) -> u32 { query_a(db) | OFFSET_B } -fn cycle_result_a(_db: &dyn 
KnobsDatabase) -> u32 { +fn cycle_result_a(_db: &dyn KnobsDatabase, _id: salsa::Id) -> u32 { FALLBACK_A } -fn cycle_result_b(_db: &dyn KnobsDatabase) -> u32 { +fn cycle_result_b(_db: &dyn KnobsDatabase, _id: salsa::Id) -> u32 { FALLBACK_B } diff --git a/tests/parallel/cycle_ab_peeping_c.rs b/tests/parallel/cycle_ab_peeping_c.rs index 8ed2b4fb6..c61f3c6ae 100644 --- a/tests/parallel/cycle_ab_peeping_c.rs +++ b/tests/parallel/cycle_ab_peeping_c.rs @@ -30,7 +30,7 @@ fn query_a(db: &dyn KnobsDatabase) -> CycleValue { b_value } -fn cycle_initial(_db: &dyn KnobsDatabase) -> CycleValue { +fn cycle_initial(_db: &dyn KnobsDatabase, _id: salsa::Id) -> CycleValue { MIN } diff --git a/tests/parallel/cycle_iteration_mismatch.rs b/tests/parallel/cycle_iteration_mismatch.rs index 61d1da01d..fa84bfb0d 100644 --- a/tests/parallel/cycle_iteration_mismatch.rs +++ b/tests/parallel/cycle_iteration_mismatch.rs @@ -62,7 +62,7 @@ fn query_f(db: &dyn KnobsDatabase) -> CycleValue { CycleValue(b.0.max(e.0)) } -fn initial(_db: &dyn KnobsDatabase) -> CycleValue { +fn initial(_db: &dyn KnobsDatabase, _id: salsa::Id) -> CycleValue { MIN } diff --git a/tests/parallel/cycle_nested_deep.rs b/tests/parallel/cycle_nested_deep.rs index 3d46bbbc5..72d4ebf74 100644 --- a/tests/parallel/cycle_nested_deep.rs +++ b/tests/parallel/cycle_nested_deep.rs @@ -46,7 +46,7 @@ fn query_e(db: &dyn KnobsDatabase) -> CycleValue { query_c(db) } -fn initial(_db: &dyn KnobsDatabase) -> CycleValue { +fn initial(_db: &dyn KnobsDatabase, _id: salsa::Id) -> CycleValue { MIN } diff --git a/tests/parallel/cycle_nested_deep_conditional.rs b/tests/parallel/cycle_nested_deep_conditional.rs index 544342e07..bf9a600b3 100644 --- a/tests/parallel/cycle_nested_deep_conditional.rs +++ b/tests/parallel/cycle_nested_deep_conditional.rs @@ -55,7 +55,7 @@ fn query_e(db: &dyn KnobsDatabase) -> CycleValue { query_c(db) } -fn initial(_db: &dyn KnobsDatabase) -> CycleValue { +fn initial(_db: &dyn KnobsDatabase, _id: salsa::Id) -> CycleValue { 
MIN } diff --git a/tests/parallel/cycle_nested_deep_conditional_changed.rs b/tests/parallel/cycle_nested_deep_conditional_changed.rs index 03423b09a..95122bebd 100644 --- a/tests/parallel/cycle_nested_deep_conditional_changed.rs +++ b/tests/parallel/cycle_nested_deep_conditional_changed.rs @@ -61,7 +61,7 @@ fn query_e(db: &dyn salsa::Database, input: Input) -> CycleValue { query_c(db, input) } -fn initial(_db: &dyn salsa::Database, _input: Input) -> CycleValue { +fn initial(_db: &dyn salsa::Database, _id: salsa::Id, _input: Input) -> CycleValue { MIN } diff --git a/tests/parallel/cycle_nested_deep_panic.rs b/tests/parallel/cycle_nested_deep_panic.rs index 4356489c3..92d192be5 100644 --- a/tests/parallel/cycle_nested_deep_panic.rs +++ b/tests/parallel/cycle_nested_deep_panic.rs @@ -49,7 +49,7 @@ fn query_e(db: &dyn KnobsDatabase) -> CycleValue { query_c(db) } -fn initial(_db: &dyn KnobsDatabase) -> CycleValue { +fn initial(_db: &dyn KnobsDatabase, _id: salsa::Id) -> CycleValue { MIN } diff --git a/tests/parallel/cycle_nested_three_threads.rs b/tests/parallel/cycle_nested_three_threads.rs index 728fc3e70..d56dfd22a 100644 --- a/tests/parallel/cycle_nested_three_threads.rs +++ b/tests/parallel/cycle_nested_three_threads.rs @@ -55,7 +55,7 @@ fn query_c(db: &dyn KnobsDatabase) -> CycleValue { CycleValue(a_value.0.max(b_value.0)) } -fn initial(_db: &dyn KnobsDatabase) -> CycleValue { +fn initial(_db: &dyn KnobsDatabase, _id: salsa::Id) -> CycleValue { MIN } diff --git a/tests/parallel/cycle_nested_three_threads_changed.rs b/tests/parallel/cycle_nested_three_threads_changed.rs index 626b3ef90..b9677ccc4 100644 --- a/tests/parallel/cycle_nested_three_threads_changed.rs +++ b/tests/parallel/cycle_nested_three_threads_changed.rs @@ -54,7 +54,7 @@ fn query_c(db: &dyn salsa::Database, input: Input) -> CycleValue { CycleValue(a_value.0.max(b_value.0)) } -fn initial(_db: &dyn salsa::Database, _input: Input) -> CycleValue { +fn initial(_db: &dyn salsa::Database, _id: salsa::Id, 
_input: Input) -> CycleValue { MIN } diff --git a/tests/parallel/cycle_panic.rs b/tests/parallel/cycle_panic.rs index 13c988f8f..ba05291a5 100644 --- a/tests/parallel/cycle_panic.rs +++ b/tests/parallel/cycle_panic.rs @@ -28,7 +28,7 @@ fn cycle_fn( panic!("cancel!") } -fn initial(_db: &dyn KnobsDatabase) -> u32 { +fn initial(_db: &dyn KnobsDatabase, _id: salsa::Id) -> u32 { 0 } diff --git a/tests/parallel/cycle_provisional_depending_on_itself.rs b/tests/parallel/cycle_provisional_depending_on_itself.rs index bb615210e..2c27becb3 100644 --- a/tests/parallel/cycle_provisional_depending_on_itself.rs +++ b/tests/parallel/cycle_provisional_depending_on_itself.rs @@ -67,7 +67,7 @@ fn query_c(db: &dyn KnobsDatabase) -> CycleValue { b } -fn cycle_initial(_db: &dyn KnobsDatabase) -> CycleValue { +fn cycle_initial(_db: &dyn KnobsDatabase, _id: salsa::Id) -> CycleValue { MIN } From 76e65b1890c68b75f4d41db17c067b4489e843ac Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama <45118249+mtshiba@users.noreply.github.com> Date: Wed, 29 Oct 2025 22:33:51 +0900 Subject: [PATCH 59/65] doc: Explain the motivation for breaking API changes made in #1012 and #1015 (#1016) * doc: Explain the motivation for breaking API changes made in #1012 and #1015 * Update book/src/cycles.md --------- Co-authored-by: Micha Reiser --- book/src/cycles.md | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/book/src/cycles.md b/book/src/cycles.md index 2e2c6e7b8..bd0675bdc 100644 --- a/book/src/cycles.md +++ b/book/src/cycles.md @@ -7,23 +7,23 @@ Salsa also supports recovering from query cycles via fixed-point iteration. Fixe In order to support fixed-point iteration for a query, provide the `cycle_fn` and `cycle_initial` arguments to `salsa::tracked`: ```rust -#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial_fn)] +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=cycle_initial)] fn query(db: &dyn salsa::Database) -> u32 { // ... 
} -fn cycle_fn(_db: &dyn KnobsDatabase, _value: &u32, _count: u32) -> salsa::CycleRecoveryAction { +fn cycle_fn(_db: &dyn KnobsDatabase, _id: salsa::Id, _last_provisional_value: &u32, _value: &u32, _count: u32) -> salsa::CycleRecoveryAction { salsa::CycleRecoveryAction::Iterate } -fn initial(_db: &dyn KnobsDatabase) -> u32 { +fn cycle_initial(_db: &dyn KnobsDatabase, _id: salsa::Id) -> u32 { 0 } ``` The `cycle_fn` is optional. The default implementation always returns `Iterate`. -If `query` becomes the head of a cycle (that is, `query` is executing and on the active query stack, it calls `query2`, `query2` calls `query3`, and `query3` calls `query` again -- there could be any number of queries involved in the cycle), the `initial_fn` will be called to generate an "initial" value for `query` in the fixed-point computation. (The initial value should usually be the "bottom" value in the partial order.) All queries in the cycle will compute a provisional result based on this initial value for the cycle head. That is, `query3` will compute a provisional result using the initial value for `query`, `query2` will compute a provisional result using this provisional value for `query3`. When `cycle2` returns its provisional result back to `cycle`, `cycle` will observe that it has received a provisional result from its own cycle, and will call the `cycle_fn` (with the current value and the number of iterations that have occurred so far). The `cycle_fn` can return `salsa::CycleRecoveryAction::Iterate` to indicate that the cycle should iterate again, or `salsa::CycleRecoveryAction::Fallback(value)` to indicate that fixpoint iteration should resume starting with the given value (which should be a value that will converge quickly). 
+If `query` becomes the head of a cycle (that is, `query` is executing and on the active query stack, it calls `query2`, `query2` calls `query3`, and `query3` calls `query` again -- there could be any number of queries involved in the cycle), the `cycle_initial` will be called to generate an "initial" value for `query` in the fixed-point computation. (The initial value should usually be the "bottom" value in the partial order.) All queries in the cycle will compute a provisional result based on this initial value for the cycle head. That is, `query3` will compute a provisional result using the initial value for `query`, `query2` will compute a provisional result using this provisional value for `query3`. When `query2` returns its provisional result back to `query`, `query` will observe that it has received a provisional result from its own cycle, and will call the `cycle_fn` (with the current value and the number of iterations that have occurred so far). The `cycle_fn` can return `salsa::CycleRecoveryAction::Iterate` to indicate that the cycle should iterate again, or `salsa::CycleRecoveryAction::Fallback(value)` to indicate that fixpoint iteration should continue with the given value (which should be a value that will converge quickly). The cycle will iterate until it converges: that is, until two successive iterations produce the same result. @@ -39,6 +39,11 @@ Consider a two-query cycle where `query_a` calls `query_b`, and `query_b` calls Fixed-point iteration is a powerful tool, but is also easy to misuse, potentially resulting in infinite iteration. To avoid this, ensure that all queries participating in fixpoint iteration are deterministic and monotone. +To guarantee convergence, you can leverage the `last_provisional_value` (3rd parameter) received by `cycle_fn`. +When the `cycle_fn` recalculates a value, you can implement a strategy that references the last provisional value to "join" values or "widen" them and return a fallback value.
This ensures monotonicity of the calculation and suppresses infinite oscillation of values between cycles. + +Also, in fixed-point iteration, it is advantageous to be able to identify which cycle head seeded a value. By embedding a `salsa::Id` (2nd parameter) in the initial value as a "cycle marker", the recovery function can detect self-originated recursion. + + ## Calling Salsa queries from within `cycle_fn` or `cycle_initial` It is permitted to call other Salsa queries from within the `cycle_fn` and `cycle_initial` functions. However, if these functions re-enter the same cycle, this can lead to unpredictable results. Take care which queries are called from within cycle-recovery functions, and avoid triggering further cycles. From 671c3dcba6ee94794876fd904606cd45a7b71599 Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Wed, 29 Oct 2025 20:29:46 +0100 Subject: [PATCH 60/65] Only use provisional values from the same revision (#1019) --- src/function/fetch.rs | 32 +++++++++++++++----------------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/src/function/fetch.rs b/src/function/fetch.rs index 14d7a93d7..f1c58eda1 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -195,26 +195,24 @@ where // existing provisional memo if it exists let memo_guard = self.get_memo_from_table_for(zalsa, id, memo_ingredient_index); if let Some(memo) = memo_guard { - if memo.value.is_some() && memo.revisions.cycle_heads().contains(&database_key_index) { - let can_shallow_update = self.shallow_verify_memo(zalsa, database_key_index, memo); - if can_shallow_update.yes() { - self.update_shallow(zalsa, database_key_index, memo, can_shallow_update); - - if C::CYCLE_STRATEGY == CycleRecoveryStrategy::Fixpoint { - memo.revisions - .cycle_heads() - .remove_all_except(database_key_index); - } + if memo.verified_at.load() == zalsa.current_revision() + && memo.value.is_some() + && memo.revisions.cycle_heads().contains(&database_key_index) + { + if C::CYCLE_STRATEGY ==
CycleRecoveryStrategy::Fixpoint { + memo.revisions + .cycle_heads() + .remove_all_except(database_key_index); + } - crate::tracing::debug!( - "hit cycle at {database_key_index:#?}, \ + crate::tracing::debug!( + "hit cycle at {database_key_index:#?}, \ returning last provisional value: {:#?}", - memo.revisions - ); + memo.revisions + ); - // SAFETY: memo is present in memo_map. - return unsafe { self.extend_memo_lifetime(memo) }; - } + // SAFETY: memo is present in memo_map. + return unsafe { self.extend_memo_lifetime(memo) }; } } From 46aa2cfadc91c798b3a1d5fefc2fe19a5ba379bc Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Fri, 31 Oct 2025 22:55:41 +0100 Subject: [PATCH 61/65] Update compile fail snapshots to match new rust stable output (#1020) * Update expected test output to match new rust stable output * Remove 1.90 constraint from compile_fail tests * Discard changes to tests/persistence.rs * Update snapshot on unix * Update with the correct rust version * Discard changes to tests/compile_fail.rs --- .../incomplete_persistence.stderr | 32 +++++++++++++++---- tests/compile-fail/span-tracked-getter.stderr | 2 +- ...ot-work-if-the-key-is-a-salsa-input.stderr | 9 ++++-- ...work-if-the-key-is-a-salsa-interned.stderr | 9 ++++-- .../tracked_method_incompatibles.stderr | 2 +- 5 files changed, 42 insertions(+), 12 deletions(-) diff --git a/tests/compile-fail/incomplete_persistence.stderr b/tests/compile-fail/incomplete_persistence.stderr index f7082ecca..a7a65b94c 100644 --- a/tests/compile-fail/incomplete_persistence.stderr +++ b/tests/compile-fail/incomplete_persistence.stderr @@ -4,9 +4,14 @@ error[E0277]: the trait bound `NotPersistable<'_>: serde::Serialize` is not sati 1 | #[salsa::tracked(persist)] | ^^^^^^^^^^^^^^^^^^^^^^^^^^ | | - | the trait `Serialize` is not implemented for `NotPersistable<'_>` + | unsatisfied trait bound | required by a bound introduced by this call | +help: the trait `Serialize` is not implemented for `NotPersistable<'_>` + --> 
tests/compile-fail/incomplete_persistence.rs:6:1 + | +6 | #[salsa::tracked] + | ^^^^^^^^^^^^^^^^^ = note: for local types consider adding `#[derive(serde::Serialize)]` to your `NotPersistable<'_>` type = note: for types from other crates check whether the crate offers a `serde` feature flag = help: the following other types implement trait `Serialize`: @@ -26,8 +31,13 @@ error[E0277]: the trait bound `NotPersistable<'_>: serde::Deserialize<'de>` is n --> tests/compile-fail/incomplete_persistence.rs:1:1 | 1 | #[salsa::tracked(persist)] - | ^^^^^^^^^^^^^^^^^^^^^^^^^^ the trait `Deserialize<'_>` is not implemented for `NotPersistable<'_>` + | ^^^^^^^^^^^^^^^^^^^^^^^^^^ unsatisfied trait bound + | +help: the trait `Deserialize<'_>` is not implemented for `NotPersistable<'_>` + --> tests/compile-fail/incomplete_persistence.rs:6:1 | +6 | #[salsa::tracked] + | ^^^^^^^^^^^^^^^^^ = note: for local types consider adding `#[derive(serde::Deserialize)]` to your `NotPersistable<'_>` type = note: for types from other crates check whether the crate offers a `serde` feature flag = help: the following other types implement trait `Deserialize<'de>`: @@ -47,8 +57,13 @@ error[E0277]: the trait bound `NotPersistable<'db>: serde::Serialize` is not sat --> tests/compile-fail/incomplete_persistence.rs:12:45 | 12 | fn query(_db: &dyn salsa::Database, _input: NotPersistable<'_>) {} - | ^^^^^^^^^^^^^^^^^^ the trait `Serialize` is not implemented for `NotPersistable<'db>` + | ^^^^^^^^^^^^^^^^^^ unsatisfied trait bound | +help: the trait `Serialize` is not implemented for `NotPersistable<'db>` + --> tests/compile-fail/incomplete_persistence.rs:6:1 + | + 6 | #[salsa::tracked] + | ^^^^^^^^^^^^^^^^^ = note: for local types consider adding `#[derive(serde::Serialize)]` to your `NotPersistable<'db>` type = note: for types from other crates check whether the crate offers a `serde` feature flag = help: the following other types implement trait `Serialize`: @@ -69,14 +84,19 @@ note: required by a 
bound in `query_input_is_persistable` | | | required by a bound in this function | required by this bound in `query_input_is_persistable` - = note: this error originates in the macro `salsa::plumbing::setup_tracked_fn` which comes from the expansion of the attribute macro `salsa::tracked` (in Nightly builds, run with -Z macro-backtrace for more info) + = note: this error originates in the macro `salsa::plumbing::setup_tracked_struct` which comes from the expansion of the attribute macro `salsa::tracked` (in Nightly builds, run with -Z macro-backtrace for more info) error[E0277]: the trait bound `NotPersistable<'db>: serde::Deserialize<'de>` is not satisfied --> tests/compile-fail/incomplete_persistence.rs:12:45 | 12 | fn query(_db: &dyn salsa::Database, _input: NotPersistable<'_>) {} - | ^^^^^^^^^^^^^^^^^^ the trait `for<'de> Deserialize<'de>` is not implemented for `NotPersistable<'db>` + | ^^^^^^^^^^^^^^^^^^ unsatisfied trait bound + | +help: the trait `for<'de> Deserialize<'de>` is not implemented for `NotPersistable<'db>` + --> tests/compile-fail/incomplete_persistence.rs:6:1 | + 6 | #[salsa::tracked] + | ^^^^^^^^^^^^^^^^^ = note: for local types consider adding `#[derive(serde::Deserialize)]` to your `NotPersistable<'db>` type = note: for types from other crates check whether the crate offers a `serde` feature flag = help: the following other types implement trait `Deserialize<'de>`: @@ -97,4 +117,4 @@ note: required by a bound in `query_input_is_persistable` | | | required by a bound in this function | required by this bound in `query_input_is_persistable` - = note: this error originates in the macro `salsa::plumbing::setup_tracked_fn` which comes from the expansion of the attribute macro `salsa::tracked` (in Nightly builds, run with -Z macro-backtrace for more info) + = note: this error originates in the macro `salsa::plumbing::setup_tracked_struct` which comes from the expansion of the attribute macro `salsa::tracked` (in Nightly builds, run with -Z 
macro-backtrace for more info) diff --git a/tests/compile-fail/span-tracked-getter.stderr b/tests/compile-fail/span-tracked-getter.stderr index fcf546c72..bc304a5c6 100644 --- a/tests/compile-fail/span-tracked-getter.stderr +++ b/tests/compile-fail/span-tracked-getter.stderr @@ -29,4 +29,4 @@ warning: variable does not need to be mutable | | | help: remove this `mut` | - = note: `#[warn(unused_mut)]` on by default + = note: `#[warn(unused_mut)]` (part of `#[warn(unused)]`) on by default diff --git a/tests/compile-fail/specify-does-not-work-if-the-key-is-a-salsa-input.stderr b/tests/compile-fail/specify-does-not-work-if-the-key-is-a-salsa-input.stderr index 580ea67bf..5c6420632 100644 --- a/tests/compile-fail/specify-does-not-work-if-the-key-is-a-salsa-input.stderr +++ b/tests/compile-fail/specify-does-not-work-if-the-key-is-a-salsa-input.stderr @@ -2,8 +2,13 @@ error[E0277]: the trait bound `MyInput: TrackedStructInDb` is not satisfied --> tests/compile-fail/specify-does-not-work-if-the-key-is-a-salsa-input.rs:15:1 | 15 | #[salsa::tracked(specify)] - | ^^^^^^^^^^^^^^^^^^^^^^^^^^ the trait `TrackedStructInDb` is not implemented for `MyInput` + | ^^^^^^^^^^^^^^^^^^^^^^^^^^ unsatisfied trait bound | +help: the trait `TrackedStructInDb` is not implemented for `MyInput` + --> tests/compile-fail/specify-does-not-work-if-the-key-is-a-salsa-input.rs:5:1 + | + 5 | #[salsa::input] + | ^^^^^^^^^^^^^^^ = help: the trait `TrackedStructInDb` is implemented for `MyTracked<'_>` note: required by a bound in `salsa::function::specify::>::specify_and_record` --> src/function/specify.rs @@ -13,4 +18,4 @@ note: required by a bound in `salsa::function::specify::: TrackedStructInDb, | ^^^^^^^^^^^^^^^^^ required by this bound in `salsa::function::specify::>::specify_and_record` - = note: this error originates in the macro `salsa::plumbing::setup_tracked_fn` which comes from the expansion of the attribute macro `salsa::tracked` (in Nightly builds, run with -Z macro-backtrace for more info) 
+ = note: this error originates in the macro `salsa::plumbing::setup_tracked_fn` which comes from the expansion of the attribute macro `salsa::input` (in Nightly builds, run with -Z macro-backtrace for more info) diff --git a/tests/compile-fail/specify-does-not-work-if-the-key-is-a-salsa-interned.stderr b/tests/compile-fail/specify-does-not-work-if-the-key-is-a-salsa-interned.stderr index 01a4b8f60..6c6ba51e0 100644 --- a/tests/compile-fail/specify-does-not-work-if-the-key-is-a-salsa-interned.stderr +++ b/tests/compile-fail/specify-does-not-work-if-the-key-is-a-salsa-interned.stderr @@ -2,8 +2,13 @@ error[E0277]: the trait bound `MyInterned<'_>: TrackedStructInDb` is not satisfi --> tests/compile-fail/specify-does-not-work-if-the-key-is-a-salsa-interned.rs:15:1 | 15 | #[salsa::tracked(specify)] - | ^^^^^^^^^^^^^^^^^^^^^^^^^^ the trait `TrackedStructInDb` is not implemented for `MyInterned<'_>` + | ^^^^^^^^^^^^^^^^^^^^^^^^^^ unsatisfied trait bound | +help: the trait `TrackedStructInDb` is not implemented for `MyInterned<'_>` + --> tests/compile-fail/specify-does-not-work-if-the-key-is-a-salsa-interned.rs:5:1 + | + 5 | #[salsa::interned] + | ^^^^^^^^^^^^^^^^^^ = help: the trait `TrackedStructInDb` is implemented for `MyTracked<'_>` note: required by a bound in `salsa::function::specify::>::specify_and_record` --> src/function/specify.rs @@ -13,4 +18,4 @@ note: required by a bound in `salsa::function::specify::: TrackedStructInDb, | ^^^^^^^^^^^^^^^^^ required by this bound in `salsa::function::specify::>::specify_and_record` - = note: this error originates in the macro `salsa::plumbing::setup_tracked_fn` which comes from the expansion of the attribute macro `salsa::tracked` (in Nightly builds, run with -Z macro-backtrace for more info) + = note: this error originates in the macro `salsa::plumbing::setup_tracked_fn` which comes from the expansion of the attribute macro `salsa::interned` (in Nightly builds, run with -Z macro-backtrace for more info) diff --git 
a/tests/compile-fail/tracked_method_incompatibles.stderr b/tests/compile-fail/tracked_method_incompatibles.stderr index 72a27a33b..5700eb556 100644 --- a/tests/compile-fail/tracked_method_incompatibles.stderr +++ b/tests/compile-fail/tracked_method_incompatibles.stderr @@ -52,7 +52,7 @@ warning: unused variable: `db` 9 | fn ref_self(&self, db: &dyn salsa::Database) {} | ^^ help: if this is intentional, prefix it with an underscore: `_db` | - = note: `#[warn(unused_variables)]` on by default + = note: `#[warn(unused_variables)]` (part of `#[warn(unused)]`) on by default warning: unused variable: `db` --> tests/compile-fail/tracked_method_incompatibles.rs:15:32 From c762869fd590855e444a957afbea355dec7f6028 Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Sat, 1 Nov 2025 01:36:55 +0100 Subject: [PATCH 62/65] Always increment iteration count (#1017) --- src/function/execute.rs | 39 ++++++++++++++++++++++------- src/function/fetch.rs | 18 ++++++++++++- src/function/maybe_changed_after.rs | 23 ----------------- src/function/memo.rs | 20 ++++++--------- src/zalsa_local.rs | 27 ++++++++++++++------ 5 files changed, 73 insertions(+), 54 deletions(-) diff --git a/src/function/execute.rs b/src/function/execute.rs index d299b0966..9d6758730 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -172,11 +172,7 @@ where let mut iteration_count = IterationCount::initial(); if let Some(old_memo) = opt_old_memo { - if old_memo.verified_at.load() == zalsa.current_revision() - && old_memo.cycle_heads().contains(&database_key_index) - { - let memo_iteration_count = old_memo.revisions.iteration(); - + if old_memo.verified_at.load() == zalsa.current_revision() { // The `DependencyGraph` locking propagates panics when another thread is blocked on a panicking query. // However, the locking doesn't handle the case where a thread fetches the result of a panicking // cycle head query **after** all locks were released. That's what we do here. 
@@ -189,8 +185,14 @@ where tracing::warn!("Propagating panic for cycle head that panicked in an earlier execution in that revision"); Cancelled::PropagatedPanic.throw(); } - last_provisional_memo = Some(old_memo); - iteration_count = memo_iteration_count; + + // Only use the last provisional memo if it was a cycle head in the last iteration. This is to + // force at least two executions. + if old_memo.cycle_heads().contains(&database_key_index) { + last_provisional_memo = Some(old_memo); + } + + iteration_count = old_memo.revisions.iteration(); } } @@ -216,6 +218,14 @@ where // If there are no cycle heads, break out of the loop (`cycle_heads_mut` returns `None` if the cycle head list is empty) let Some(cycle_heads) = completed_query.revisions.cycle_heads_mut() else { + iteration_count = iteration_count.increment().unwrap_or_else(|| { + tracing::warn!("{database_key_index:?}: execute: too many cycle iterations"); + panic!("{database_key_index:?}: execute: too many cycle iterations") + }); + completed_query + .revisions + .update_iteration_count_mut(database_key_index, iteration_count); + claim_guard.set_release_mode(ReleaseMode::SelfOnly); break (new_value, completed_query); }; @@ -289,6 +299,15 @@ where } completed_query.revisions.set_cycle_heads(cycle_heads); + + iteration_count = iteration_count.increment().unwrap_or_else(|| { + tracing::warn!("{database_key_index:?}: execute: too many cycle iterations"); + panic!("{database_key_index:?}: execute: too many cycle iterations") + }); + completed_query + .revisions + .update_iteration_count_mut(database_key_index, iteration_count); + break (new_value, completed_query); } @@ -555,8 +574,10 @@ impl<'a, C: Configuration> PoisonProvisionalIfPanicking<'a, C> { impl Drop for PoisonProvisionalIfPanicking<'_, C> { fn drop(&mut self) { if thread::panicking() { - let revisions = - QueryRevisions::fixpoint_initial(self.ingredient.database_key_index(self.id)); + let revisions = QueryRevisions::fixpoint_initial( + 
self.ingredient.database_key_index(self.id), + IterationCount::initial(), + ); let memo = Memo::new(None, self.zalsa.current_revision(), revisions); self.ingredient diff --git a/src/function/fetch.rs b/src/function/fetch.rs index f1c58eda1..588b08bb1 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -195,6 +195,9 @@ where // existing provisional memo if it exists let memo_guard = self.get_memo_from_table_for(zalsa, id, memo_ingredient_index); if let Some(memo) = memo_guard { + // Ideally, we'd use the last provisional memo even if it wasn't a cycle head in the last iteration + // but that would require inserting itself as a cycle head, which either requires clone + // on the value OR a concurrent `Vec` for cycle heads. if memo.verified_at.load() == zalsa.current_revision() && memo.value.is_some() && memo.revisions.cycle_heads().contains(&database_key_index) @@ -233,7 +236,20 @@ where "hit cycle at {database_key_index:#?}, \ inserting and returning fixpoint initial value" ); - let revisions = QueryRevisions::fixpoint_initial(database_key_index); + + let iteration = memo_guard + .and_then(|old_memo| { + if old_memo.verified_at.load() == zalsa.current_revision() + && old_memo.value.is_some() + { + Some(old_memo.revisions.iteration()) + } else { + None + } + }) + .unwrap_or(IterationCount::initial()); + let revisions = QueryRevisions::fixpoint_initial(database_key_index, iteration); + let initial_value = C::cycle_initial(db, id, C::id_to_input(zalsa, id)); self.insert_memo( zalsa, diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index 4198631b9..20440883e 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -509,8 +509,6 @@ where head_iteration_count, memo_iteration_count: current_iteration_count, verified_at: head_verified_at, - cycle_heads, - database_key_index: head_database_key, } => { if head_verified_at != memo_verified_at { return false; @@ -519,27 +517,6 @@ where if 
head_iteration_count != current_iteration_count { return false; } - - // Check if the memo is still a cycle head and hasn't changed - // to a normal cycle participant. This is to force re-execution in - // a scenario like this: - // - // * There's a nested cycle with the outermost query A - // * B participates in the cycle and is a cycle head in the first few iterations - // * B becomes a non-cycle head in a later iteration - // * There's a query `C` that has `B` as its cycle head - // - // The crucial point is that `B` switches from being a cycle head to being a regular cycle participant. - // The issue with that is that `A` doesn't update `B`'s `iteration_count `when the iteration completes - // because it only does that for cycle heads (and collecting all queries participating in a query would be sort of expensive?). - // - // When we now pull `C` in a later iteration, `validate_same_iteration` iterates over all its cycle heads (`B`), - // and check if the iteration count still matches. Which is the case because `A` didn't update `B`'s iteration count. - // - // That's why we also check if `B` is still a cycle head in the current iteration. - if !cycle_heads.contains(&head_database_key) { - return false; - } } _ => { return false; diff --git a/src/function/memo.rs b/src/function/memo.rs index fd830ced3..d8faf3e0b 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -409,11 +409,9 @@ mod persistence { pub(super) enum TryClaimHeadsResult<'me> { /// Claiming the cycle head results in a cycle. Cycle { - database_key_index: DatabaseKeyIndex, head_iteration_count: IterationCount, memo_iteration_count: IterationCount, verified_at: Revision, - cycle_heads: &'me CycleHeads, }, /// The cycle head is not finalized, but it can be claimed. 
@@ -460,28 +458,24 @@ impl<'me> Iterator for TryClaimCycleHeadsIter<'me> { let provisional_status = ingredient .provisional_status(self.zalsa, head_key_index) .expect("cycle head memo to exist"); - let (current_iteration_count, verified_at, cycle_heads) = match provisional_status { + let (current_iteration_count, verified_at) = match provisional_status { ProvisionalStatus::Provisional { iteration, verified_at, - cycle_heads, - } => (iteration, verified_at, cycle_heads), + cycle_heads: _, + } => (iteration, verified_at), ProvisionalStatus::Final { iteration, verified_at, - } => (iteration, verified_at, empty_cycle_heads()), - ProvisionalStatus::FallbackImmediate => ( - IterationCount::initial(), - self.zalsa.current_revision(), - empty_cycle_heads(), - ), + } => (iteration, verified_at), + ProvisionalStatus::FallbackImmediate => { + (IterationCount::initial(), self.zalsa.current_revision()) + } }; Some(TryClaimHeadsResult::Cycle { - database_key_index: head_database_key, memo_iteration_count: current_iteration_count, head_iteration_count: head.iteration_count.load(), - cycle_heads, verified_at, }) } diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs index 7b0399178..f43eb78eb 100644 --- a/src/zalsa_local.rs +++ b/src/zalsa_local.rs @@ -639,7 +639,7 @@ const _: [(); std::mem::size_of::()] = [(); std::mem::size_of::<[usize; if cfg!(feature = "accumulator") { 7 } else { 3 }]>()]; impl QueryRevisions { - pub(crate) fn fixpoint_initial(query: DatabaseKeyIndex) -> Self { + pub(crate) fn fixpoint_initial(query: DatabaseKeyIndex, iteration: IterationCount) -> Self { Self { changed_at: Revision::start(), durability: Durability::MAX, @@ -651,8 +651,8 @@ impl QueryRevisions { #[cfg(feature = "accumulator")] AccumulatedMap::default(), ThinVec::default(), - CycleHeads::initial(query, IterationCount::initial()), - IterationCount::initial(), + CycleHeads::initial(query, iteration), + iteration, ), } } @@ -743,12 +743,23 @@ impl QueryRevisions { cycle_head_index: 
DatabaseKeyIndex, iteration_count: IterationCount, ) { - if let Some(extra) = &mut self.extra.0 { - extra.iteration.store_mut(iteration_count); + match &mut self.extra.0 { + None => { + self.extra = QueryRevisionsExtra::new( + #[cfg(feature = "accumulator")] + AccumulatedMap::default(), + ThinVec::default(), + empty_cycle_heads().clone(), + iteration_count, + ); + } + Some(extra) => { + extra.iteration.store_mut(iteration_count); - extra - .cycle_heads - .update_iteration_count_mut(cycle_head_index, iteration_count); + extra + .cycle_heads + .update_iteration_count_mut(cycle_head_index, iteration_count); + } } } From 664750a6e588ed23a0d2d9105a02cb5993c8e178 Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Mon, 3 Nov 2025 21:44:22 +0100 Subject: [PATCH 63/65] Track cycle function dependencies as part of the cyclic query (#1018) * Track cycle function dependenciees as part of the cyclic query * Add regression test * Discard changes to src/function/backdate.rs * Update comment * Fix merge error * Refine comment --- src/active_query.rs | 4 ++ src/cycle.rs | 4 ++ src/function/execute.rs | 61 +++++++++++++++------ src/function/maybe_changed_after.rs | 5 +- src/zalsa_local.rs | 12 ++++ tests/cycle_recovery_dependencies.rs | 82 ++++++++++++++++++++++++++++ 6 files changed, 149 insertions(+), 19 deletions(-) create mode 100644 tests/cycle_recovery_dependencies.rs diff --git a/src/active_query.rs b/src/active_query.rs index bb5987fcd..c80cded3b 100644 --- a/src/active_query.rs +++ b/src/active_query.rs @@ -91,6 +91,10 @@ impl ActiveQuery { .mark_all_active(active_tracked_ids.iter().copied()); } + pub(super) fn take_cycle_heads(&mut self) -> CycleHeads { + std::mem::take(&mut self.cycle_heads) + } + pub(super) fn add_read( &mut self, input: DatabaseKeyIndex, diff --git a/src/cycle.rs b/src/cycle.rs index fcbadf891..3f6f70aa0 100644 --- a/src/cycle.rs +++ b/src/cycle.rs @@ -490,4 +490,8 @@ impl<'db> ProvisionalStatus<'db> { _ => empty_cycle_heads(), } } + + pub(crate) const 
fn is_provisional(&self) -> bool { + matches!(self, ProvisionalStatus::Provisional { .. }) + } } diff --git a/src/function/execute.rs b/src/function/execute.rs index 9d6758730..53bc640a2 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -56,20 +56,25 @@ where }); let (new_value, mut completed_query) = match C::CYCLE_STRATEGY { - CycleRecoveryStrategy::Panic => Self::execute_query( - db, - zalsa, - zalsa_local.push_query(database_key_index, IterationCount::initial()), - opt_old_memo, - ), + CycleRecoveryStrategy::Panic => { + let (new_value, active_query) = Self::execute_query( + db, + zalsa, + zalsa_local.push_query(database_key_index, IterationCount::initial()), + opt_old_memo, + ); + (new_value, active_query.pop()) + } CycleRecoveryStrategy::FallbackImmediate => { - let (mut new_value, mut completed_query) = Self::execute_query( + let (mut new_value, active_query) = Self::execute_query( db, zalsa, zalsa_local.push_query(database_key_index, IterationCount::initial()), opt_old_memo, ); + let mut completed_query = active_query.pop(); + if let Some(cycle_heads) = completed_query.revisions.cycle_heads_mut() { // Did the new result we got depend on our own provisional value, in a cycle? if cycle_heads.contains(&database_key_index) { @@ -198,9 +203,10 @@ where let _poison_guard = PoisonProvisionalIfPanicking::new(self, zalsa, id, memo_ingredient_index); - let mut active_query = zalsa_local.push_query(database_key_index, iteration_count); let (new_value, completed_query) = loop { + let active_query = zalsa_local.push_query(database_key_index, iteration_count); + // Tracked struct ids that existed in the previous revision // but weren't recreated in the last iteration. It's important that we seed the next // query with these ids because the query might re-create them as part of the next iteration. @@ -209,29 +215,32 @@ where // if they aren't recreated when reaching the final iteration. 
active_query.seed_tracked_struct_ids(&last_stale_tracked_ids); - let (mut new_value, mut completed_query) = Self::execute_query( + let (mut new_value, mut active_query) = Self::execute_query( db, zalsa, active_query, last_provisional_memo.or(opt_old_memo), ); - // If there are no cycle heads, break out of the loop (`cycle_heads_mut` returns `None` if the cycle head list is empty) - let Some(cycle_heads) = completed_query.revisions.cycle_heads_mut() else { + // Take the cycle heads to not-fight-rust's-borrow-checker. + let mut cycle_heads = active_query.take_cycle_heads(); + + // If there are no cycle heads, break out of the loop. + if cycle_heads.is_empty() { iteration_count = iteration_count.increment().unwrap_or_else(|| { tracing::warn!("{database_key_index:?}: execute: too many cycle iterations"); panic!("{database_key_index:?}: execute: too many cycle iterations") }); + + let mut completed_query = active_query.pop(); completed_query .revisions .update_iteration_count_mut(database_key_index, iteration_count); claim_guard.set_release_mode(ReleaseMode::SelfOnly); break (new_value, completed_query); - }; + } - // Take the cycle heads to not-fight-rust's-borrow-checker. - let mut cycle_heads = std::mem::take(cycle_heads); let mut missing_heads: SmallVec<[(DatabaseKeyIndex, IterationCount); 1]> = SmallVec::new_const(); let mut max_iteration_count = iteration_count; @@ -262,6 +271,11 @@ where .provisional_status(zalsa, head.database_key_index.key_index()) .expect("cycle head memo must have been created during the execution"); + // A query should only ever depend on other heads that are provisional. + // If this invariant is violated, it means that this query participates in a cycle, + // but it wasn't executed in the last iteration of said cycle. 
+ assert!(provisional_status.is_provisional()); + for nested_head in provisional_status.cycle_heads() { let nested_as_tuple = ( nested_head.database_key_index, @@ -298,6 +312,8 @@ where claim_guard.set_release_mode(ReleaseMode::SelfOnly); } + let mut completed_query = active_query.pop(); + *completed_query.revisions.verified_final.get_mut() = false; completed_query.revisions.set_cycle_heads(cycle_heads); iteration_count = iteration_count.increment().unwrap_or_else(|| { @@ -378,8 +394,17 @@ where this_converged = C::values_equal(&new_value, last_provisional_value); } } + + let new_cycle_heads = active_query.take_cycle_heads(); + for head in new_cycle_heads { + if !cycle_heads.contains(&head.database_key_index) { + panic!("Cycle recovery function for {database_key_index:?} introduced a cycle, depending on {:?}. This is not allowed.", head.database_key_index); + } + } } + let mut completed_query = active_query.pop(); + if let Some(outer_cycle) = outer_cycle { tracing::info!( "Detected nested cycle {database_key_index:?}, iterate it as part of the outer cycle {outer_cycle:?}" @@ -390,6 +415,7 @@ where completed_query .revisions .set_cycle_converged(this_converged); + *completed_query.revisions.verified_final.get_mut() = false; // Transfer ownership of this query to the outer cycle, so that it can claim it // and other threads don't compete for the same lock. @@ -428,9 +454,9 @@ where } *completed_query.revisions.verified_final.get_mut() = true; - break (new_value, completed_query); } + *completed_query.revisions.verified_final.get_mut() = false; // The fixpoint iteration hasn't converged. Iterate again... 
iteration_count = iteration_count.increment().unwrap_or_else(|| { @@ -484,7 +510,6 @@ where last_provisional_memo = Some(new_memo); last_stale_tracked_ids = completed_query.stale_tracked_structs; - active_query = zalsa_local.push_query(database_key_index, iteration_count); continue; }; @@ -503,7 +528,7 @@ where zalsa: &'db Zalsa, active_query: ActiveQueryGuard<'db>, opt_old_memo: Option<&Memo<'db, C>>, - ) -> (C::Output<'db>, CompletedQuery) { + ) -> (C::Output<'db>, ActiveQueryGuard<'db>) { if let Some(old_memo) = opt_old_memo { // If we already executed this query once, then use the tracked-struct ids from the // previous execution as the starting point for the new one. @@ -528,7 +553,7 @@ where C::id_to_input(zalsa, active_query.database_key_index.key_index()), ); - (new_value, active_query.pop()) + (new_value, active_query) } } diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index 20440883e..165a3fb02 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -592,7 +592,10 @@ where cycle_heads.append_heads(&mut child_cycle_heads); match input_result { - VerifyResult::Changed => return VerifyResult::changed(), + VerifyResult::Changed => { + cycle_heads.remove_head(database_key_index); + return VerifyResult::changed(); + } #[cfg(feature = "accumulator")] VerifyResult::Unchanged { accumulated } => { inputs |= accumulated; diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs index f43eb78eb..bde3b6b24 100644 --- a/src/zalsa_local.rs +++ b/src/zalsa_local.rs @@ -1213,6 +1213,18 @@ impl ActiveQueryGuard<'_> { } } + pub(crate) fn take_cycle_heads(&mut self) -> CycleHeads { + // SAFETY: We do not access the query stack reentrantly. 
+ unsafe { + self.local_state.with_query_stack_unchecked_mut(|stack| { + #[cfg(debug_assertions)] + assert_eq!(stack.len(), self.push_len); + let frame = stack.last_mut().unwrap(); + frame.take_cycle_heads() + }) + } + } + /// Invoked when the query has successfully completed execution. fn complete(self) -> CompletedQuery { // SAFETY: We do not access the query stack reentrantly. diff --git a/tests/cycle_recovery_dependencies.rs b/tests/cycle_recovery_dependencies.rs new file mode 100644 index 000000000..b26ce973b --- /dev/null +++ b/tests/cycle_recovery_dependencies.rs @@ -0,0 +1,82 @@ +#![cfg(feature = "inventory")] + +//! Queries or inputs read within the cycle recovery function +//! are tracked on the cycle function and don't "leak" into the +//! function calling the query with cycle handling. + +use expect_test::expect; +use salsa::Setter as _; + +use crate::common::LogDatabase; + +mod common; + +#[salsa::input] +struct Input { + value: u32, +} + +#[salsa::tracked] +fn entry(db: &dyn salsa::Database, input: Input) -> u32 { + query(db, input) +} + +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=cycle_initial)] +fn query(db: &dyn salsa::Database, input: Input) -> u32 { + let val = query(db, input); + if val < 5 { + val + 1 + } else { + val + } +} + +fn cycle_initial(_db: &dyn salsa::Database, _id: salsa::Id, _input: Input) -> u32 { + 0 +} + +fn cycle_fn( + db: &dyn salsa::Database, + _id: salsa::Id, + _last_provisional_value: &u32, + _value: &u32, + _count: u32, + input: Input, +) -> salsa::CycleRecoveryAction { + let _input = input.value(db); + salsa::CycleRecoveryAction::Iterate +} + +#[test_log::test] +fn the_test() { + let mut db = common::EventLoggerDatabase::default(); + + let input = Input::new(&db, 1); + assert_eq!(entry(&db, input), 5); + + db.assert_logs_len(15); + + input.set_value(&mut db).to(2); + + assert_eq!(entry(&db, input), 5); + db.assert_logs(expect![[r#" + [ + "DidSetCancellationFlag", + "WillCheckCancellation", + "WillCheckCancellation", 
+ "WillCheckCancellation", + "WillExecute { database_key: query(Id(0)) }", + "WillCheckCancellation", + "WillIterateCycle { database_key: query(Id(0)), iteration_count: IterationCount(1) }", + "WillCheckCancellation", + "WillIterateCycle { database_key: query(Id(0)), iteration_count: IterationCount(2) }", + "WillCheckCancellation", + "WillIterateCycle { database_key: query(Id(0)), iteration_count: IterationCount(3) }", + "WillCheckCancellation", + "WillIterateCycle { database_key: query(Id(0)), iteration_count: IterationCount(4) }", + "WillCheckCancellation", + "WillIterateCycle { database_key: query(Id(0)), iteration_count: IterationCount(5) }", + "WillCheckCancellation", + "DidValidateMemoizedValue { database_key: entry(Id(0)) }", + ]"#]]); +} From 05a9af7f554b64b8aadc2eeb6f2caf73d0408d09 Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Wed, 5 Nov 2025 13:37:23 +0100 Subject: [PATCH 64/65] Call `cycle_fn` for every iteration (#1021) * Call `cycle_fn` for every iteration * Update documentation * Clippy * Remove `CycleRecoveryAction` --- benches/dataflow.rs | 22 +++++----- book/src/cycles.md | 14 +++--- .../salsa-macro-rules/src/setup_tracked_fn.rs | 4 +- .../src/unexpected_cycle_recovery.rs | 4 +- src/cycle.rs | 31 +++---------- src/function.rs | 25 +++++++---- src/function/execute.rs | 43 +++++++------------ src/function/memo.rs | 8 ++-- src/lib.rs | 4 +- tests/cycle.rs | 20 +++++---- tests/cycle_accumulate.rs | 6 +-- tests/cycle_recovery_call_back_into_cycle.rs | 12 ++++-- tests/cycle_recovery_call_query.rs | 6 +-- tests/cycle_recovery_dependencies.rs | 6 +-- tests/dataflow.rs | 38 +++++++++------- tests/parallel/cycle_panic.rs | 4 +- 16 files changed, 116 insertions(+), 131 deletions(-) diff --git a/benches/dataflow.rs b/benches/dataflow.rs index a548c806a..cf20140f6 100644 --- a/benches/dataflow.rs +++ b/benches/dataflow.rs @@ -6,7 +6,7 @@ use std::collections::BTreeSet; use std::iter::IntoIterator; use codspeed_criterion_compat::{criterion_group, 
criterion_main, BatchSize, Criterion}; -use salsa::{CycleRecoveryAction, Database as Db, Setter}; +use salsa::{Database as Db, Setter}; /// A Use of a symbol. #[salsa::input] @@ -78,10 +78,10 @@ fn def_cycle_recover( _db: &dyn Db, _id: salsa::Id, _last_provisional_value: &Type, - value: &Type, + value: Type, count: u32, _def: Definition, -) -> CycleRecoveryAction { +) -> Type { cycle_recover(value, count) } @@ -93,24 +93,24 @@ fn use_cycle_recover( _db: &dyn Db, _id: salsa::Id, _last_provisional_value: &Type, - value: &Type, + value: Type, count: u32, _use: Use, -) -> CycleRecoveryAction { +) -> Type { cycle_recover(value, count) } -fn cycle_recover(value: &Type, count: u32) -> CycleRecoveryAction { - match value { - Type::Bottom => CycleRecoveryAction::Iterate, +fn cycle_recover(value: Type, count: u32) -> Type { + match &value { + Type::Bottom => value, Type::Values(_) => { if count > 4 { - CycleRecoveryAction::Fallback(Type::Top) + Type::Top } else { - CycleRecoveryAction::Iterate + value } } - Type::Top => CycleRecoveryAction::Iterate, + Type::Top => value, } } diff --git a/book/src/cycles.md b/book/src/cycles.md index bd0675bdc..023f5bb79 100644 --- a/book/src/cycles.md +++ b/book/src/cycles.md @@ -12,8 +12,8 @@ fn query(db: &dyn salsa::Database) -> u32 { // ... } -fn cycle_fn(_db: &dyn KnobsDatabase, _id: salsa::Id, _last_provisional_value: &u32, _value: &u32, _count: u32) -> salsa::CycleRecoveryAction { - salsa::CycleRecoveryAction::Iterate +fn cycle_fn(_db: &dyn KnobsDatabase, _id: salsa::Id, _last_provisional_value: &u32, value: u32, _count: u32) -> u32 { + value } fn cycle_initial(_db: &dyn KnobsDatabase, _id: salsa::Id) -> u32 { @@ -21,13 +21,11 @@ fn cycle_initial(_db: &dyn KnobsDatabase, _id: salsa::Id) -> u32 { } ``` -The `cycle_fn` is optional. The default implementation always returns `Iterate`. +The `cycle_fn` is optional. The default implementation always returns the computed `value`. 
-If `query` becomes the head of a cycle (that is, `query` is executing and on the active query stack, it calls `query2`, `query2` calls `query3`, and `query3` calls `query` again -- there could be any number of queries involved in the cycle), the `cycle_initial` will be called to generate an "initial" value for `query` in the fixed-point computation. (The initial value should usually be the "bottom" value in the partial order.) All queries in the cycle will compute a provisional result based on this initial value for the cycle head. That is, `query3` will compute a provisional result using the initial value for `query`, `query2` will compute a provisional result using this provisional value for `query3`. When `cycle2` returns its provisional result back to `cycle`, `cycle` will observe that it has received a provisional result from its own cycle, and will call the `cycle_fn` (with the current value and the number of iterations that have occurred so far). The `cycle_fn` can return `salsa::CycleRecoveryAction::Iterate` to indicate that the cycle should iterate again, or `salsa::CycleRecoveryAction::Fallback(value)` to indicate that fixpoint iteration should continue with the given value (which should be a value that will converge quickly). +If `query` becomes the head of a cycle (that is, `query` is executing and on the active query stack, it calls `query2`, `query2` calls `query3`, and `query3` calls `query` again -- there could be any number of queries involved in the cycle), the `cycle_initial` will be called to generate an "initial" value for `query` in the fixed-point computation. (The initial value should usually be the "bottom" value in the partial order.) All queries in the cycle will compute a provisional result based on this initial value for the cycle head. That is, `query3` will compute a provisional result using the initial value for `query`, `query2` will compute a provisional result using this provisional value for `query3`. 
When `cycle2` returns its provisional result back to `cycle`, `cycle` will observe that it has received a provisional result from its own cycle, and will call the `cycle_fn` (with the last provisional value, the newly computed value, and the number of iterations that have occurred so far). The `cycle_fn` can return the `value` parameter to continue iterating with the computed value, or return a different value (a fallback value) to continue iteration with that value instead. -The cycle will iterate until it converges: that is, until two successive iterations produce the same result. - -If the `cycle_fn` returns `Fallback`, the cycle will still continue to iterate (using the given value as a new starting point), in order to verify that the fallback value results in a stable converged cycle. It is not permitted to use a fallback value that does not converge, because this would leave the cycle in an unpredictable state, depending on the order of query execution. +The cycle will iterate until it converges: that is, until the value returned by `cycle_fn` equals the value from the previous iteration. If a cycle iterates more than 200 times, Salsa will panic rather than iterate forever. @@ -40,7 +38,7 @@ Consider a two-query cycle where `query_a` calls `query_b`, and `query_b` calls Fixed-point iteration is a powerful tool, but is also easy to misuse, potentially resulting in infinite iteration. To avoid this, ensure that all queries participating in fixpoint iteration are deterministic and monotone. To guarantee convergence, you can leverage the `last_provisional_value` (3rd parameter) received by `cycle_fn`. -When the `cycle_fn` recalculates a value, you can implement a strategy that references the last provisional value to "join" values ​​or "widen" it and return a fallback value. This ensures monotonicity of the calculation and suppresses infinite oscillation of values ​​between cycles. 
+When the `cycle_fn` receives a newly computed value, you can implement a strategy that references the last provisional value to "join" values or "widen" it and return a fallback value. This ensures monotonicity of the calculation and suppresses infinite oscillation of values between cycles. For example: Also, in fixed-point iteration, it is advantageous to be able to identify which cycle head seeded a value. By embedding a `salsa::Id` (2nd parameter) in the initial value as a "cycle marker", the recovery function can detect self-originated recursion. diff --git a/components/salsa-macro-rules/src/setup_tracked_fn.rs b/components/salsa-macro-rules/src/setup_tracked_fn.rs index 8ea4e5e33..1c3312372 100644 --- a/components/salsa-macro-rules/src/setup_tracked_fn.rs +++ b/components/salsa-macro-rules/src/setup_tracked_fn.rs @@ -310,10 +310,10 @@ macro_rules! setup_tracked_fn { db: &$db_lt dyn $Db, id: salsa::Id, last_provisional_value: &Self::Output<$db_lt>, - value: &Self::Output<$db_lt>, + value: Self::Output<$db_lt>, iteration_count: u32, ($($input_id),*): ($($interned_input_ty),*) - ) -> $zalsa::CycleRecoveryAction> { + ) -> Self::Output<$db_lt> { $($cycle_recovery_fn)*(db, id, last_provisional_value, value, iteration_count, $($input_id),*) } diff --git a/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs b/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs index ff03c02a2..fe002fa4e 100644 --- a/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs +++ b/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs @@ -4,9 +4,9 @@ #[macro_export] macro_rules! 
unexpected_cycle_recovery { ($db:ident, $id:ident, $last_provisional_value:ident, $new_value:ident, $count:ident, $($other_inputs:ident),*) => {{ - let (_db, _id, _last_provisional_value, _new_value, _count) = ($db, $id, $last_provisional_value, $new_value, $count); + let (_db, _id, _last_provisional_value, _count) = ($db, $id, $last_provisional_value, $count); std::mem::drop(($($other_inputs,)*)); - salsa::CycleRecoveryAction::Iterate + $new_value }}; } diff --git a/src/cycle.rs b/src/cycle.rs index 3f6f70aa0..0f12472b4 100644 --- a/src/cycle.rs +++ b/src/cycle.rs @@ -23,14 +23,12 @@ //! //! When a query observes that it has just computed a result which contains itself as a cycle head, //! it recognizes that it is responsible for resolving this cycle and calls its `cycle_fn` to -//! decide how to do so. The `cycle_fn` function is passed the provisional value just computed for -//! that query and the count of iterations so far, and must return either -//! `CycleRecoveryAction::Iterate` (which signals that the cycle head should re-iterate the cycle), -//! or `CycleRecoveryAction::Fallback` (which signals that the cycle head should replace its -//! computed value with the given fallback value). +//! decide what value to use. The `cycle_fn` function is passed the provisional value just computed +//! for that query and the count of iterations so far, and returns the value to use for this +//! iteration. This can be the computed value itself, or a different value (e.g., a fallback value). //! -//! If the cycle head ever observes that the provisional value it just recomputed is the same as -//! the provisional value from the previous iteration, the cycle has converged. The cycle head will +//! If the cycle head ever observes that the value returned by `cycle_fn` is the same as the +//! provisional value from the previous iteration, this cycle has converged. The cycle head will //! mark that value as final (by removing itself as cycle head) and return it. //! //! 
Other queries in the cycle will still have provisional values recorded, but those values should @@ -39,11 +37,6 @@ //! of its cycle heads have a final result, in which case it, too, can be marked final. (This is //! implemented in `shallow_verify_memo` and `validate_provisional`.) //! -//! If the `cycle_fn` returns a fallback value, the cycle head will replace its provisional value -//! with that fallback, and then iterate the cycle one more time. A fallback value is expected to -//! result in a stable, converged cycle. If it does not (that is, if the result of another -//! iteration of the cycle is not the same as the fallback value), we'll panic. -//! //! In nested cycle cases, the inner cycles are iterated as part of the outer cycle iteration. This helps //! to significantly reduce the number of iterations needed to reach a fixpoint. For nested cycles, //! the inner cycles head will transfer their lock ownership to the outer cycle. This ensures @@ -64,20 +57,6 @@ use crate::Revision; /// Should only be relevant in case of a badly configured cycle recovery. pub const MAX_ITERATIONS: IterationCount = IterationCount(200); -/// Return value from a cycle recovery function. -#[derive(Debug)] -pub enum CycleRecoveryAction { - /// Iterate the cycle again to look for a fixpoint. - Iterate, - - /// Use the given value as the result for the current iteration instead - /// of the value computed by the query function. - /// - /// Returning `Fallback` doesn't stop the fixpoint iteration. It only - /// allows the iterate function to return a different value. - Fallback(T), -} - /// Cycle recovery strategy: Is this query capable of recovering from /// a cycle that results from executing the function? If so, how? 
#[derive(Copy, Clone, Debug, PartialEq, Eq)] diff --git a/src/function.rs b/src/function.rs index 045825e19..b9878bc41 100644 --- a/src/function.rs +++ b/src/function.rs @@ -7,7 +7,7 @@ use std::ptr::NonNull; use std::sync::atomic::Ordering; use std::sync::OnceLock; -use crate::cycle::{CycleRecoveryAction, CycleRecoveryStrategy, IterationCount, ProvisionalStatus}; +use crate::cycle::{CycleRecoveryStrategy, IterationCount, ProvisionalStatus}; use crate::database::RawDatabase; use crate::function::delete::DeletedEntries; use crate::hash::{FxHashSet, FxIndexSet}; @@ -91,9 +91,11 @@ pub trait Configuration: Any { input: Self::Input<'db>, ) -> Self::Output<'db>; - /// Decide whether to iterate a cycle again or fallback. `value` is the provisional return - /// value from the latest iteration of this cycle. `count` is the number of cycle iterations - /// completed so far. + /// Decide what value to use for this cycle iteration. Takes ownership of the new value + /// and returns an owned value to use. + /// + /// The function is called for every iteration of the cycle head, regardless of whether the cycle + /// has converged (the values are equal). /// /// # Id /// @@ -112,17 +114,22 @@ pub trait Configuration: Any { /// * **Initial value**: `iteration` may be non-zero on the first call for a given query if that /// query becomes the outermost cycle head after a nested cycle complete a few iterations. In this case, /// `iteration` continues from the nested cycle's iteration count rather than resetting to zero. - /// * **Non-contiguous values**: This function isn't called if this cycle is part of an outer cycle - /// and the value for this query remains unchanged for one iteration. But the outer cycle might - /// keep iterating because other heads keep changing. + /// * **Non-contiguous values**: The iteration count can be non-contigious for cycle heads + /// that are only conditionally part of a cycle. 
+ /// + /// # Return value + /// + /// The function should return the value to use for this iteration. This can be the `value` + /// that was computed, or a different value (e.g., a fallback value). This cycle will continue + /// iterating until the returned value equals the previous iteration's value. fn recover_from_cycle<'db>( db: &'db Self::DbView, id: Id, last_provisional_value: &Self::Output<'db>, - new_value: &Self::Output<'db>, + value: Self::Output<'db>, iteration: u32, input: Self::Input<'db>, - ) -> CycleRecoveryAction>; + ) -> Self::Output<'db>; /// Serialize the output type using `serde`. /// diff --git a/src/function/execute.rs b/src/function/execute.rs index 53bc640a2..d07bb45f6 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -357,8 +357,6 @@ where I am a cycle head, comparing last provisional value with new value" ); - let mut this_converged = C::values_equal(&new_value, last_provisional_value); - // If this is the outermost cycle, use the maximum iteration count of all cycles. // This is important for when later iterations introduce new cycle heads (that then // become the outermost cycle). 
We want to ensure that the iteration count keeps increasing @@ -373,36 +371,25 @@ where iteration_count }; - if !this_converged { - // We are in a cycle that hasn't converged; ask the user's - // cycle-recovery function what to do: - match C::recover_from_cycle( - db, - id, - last_provisional_value, - &new_value, - iteration_count.as_u32(), - C::id_to_input(zalsa, id), - ) { - crate::CycleRecoveryAction::Iterate => {} - crate::CycleRecoveryAction::Fallback(fallback_value) => { - tracing::debug!( - "{database_key_index:?}: execute: user cycle_fn says to fall back" - ); - new_value = fallback_value; - - this_converged = C::values_equal(&new_value, last_provisional_value); - } - } + // We are in a cycle that hasn't converged; ask the user's + // cycle-recovery function what to do (it may return the same value or a different one): + new_value = C::recover_from_cycle( + db, + id, + last_provisional_value, + new_value, + iteration_count.as_u32(), + C::id_to_input(zalsa, id), + ); - let new_cycle_heads = active_query.take_cycle_heads(); - for head in new_cycle_heads { - if !cycle_heads.contains(&head.database_key_index) { - panic!("Cycle recovery function for {database_key_index:?} introduced a cycle, depending on {:?}. This is not allowed.", head.database_key_index); - } + let new_cycle_heads = active_query.take_cycle_heads(); + for head in new_cycle_heads { + if !cycle_heads.contains(&head.database_key_index) { + panic!("Cycle recovery function for {database_key_index:?} introduced a cycle, depending on {:?}. 
This is not allowed.", head.database_key_index); } } + let this_converged = C::values_equal(&new_value, last_provisional_value); let mut completed_query = active_query.pop(); if let Some(outer_cycle) = outer_cycle { diff --git a/src/function/memo.rs b/src/function/memo.rs index d8faf3e0b..f22af65fe 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -496,7 +496,7 @@ mod _memory_usage { use crate::plumbing::{self, IngredientIndices, MemoIngredientSingletonIndex, SalsaStructInDb}; use crate::table::memo::MemoTableWithTypes; use crate::zalsa::Zalsa; - use crate::{CycleRecoveryAction, Database, Id, Revision}; + use crate::{Database, Id, Revision}; use std::any::TypeId; use std::num::NonZeroUsize; @@ -564,11 +564,11 @@ mod _memory_usage { _: &'db Self::DbView, _: Id, _: &Self::Output<'db>, - _: &Self::Output<'db>, + value: Self::Output<'db>, _: u32, _: Self::Input<'db>, - ) -> CycleRecoveryAction> { - unimplemented!() + ) -> Self::Output<'db> { + value } fn serialize(_: &Self::Output<'_>, _: S) -> Result diff --git a/src/lib.rs b/src/lib.rs index 8c50c9052..d4409c4a9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -47,7 +47,7 @@ pub use self::database::IngredientInfo; pub use self::accumulator::Accumulator; pub use self::active_query::Backtrace; pub use self::cancelled::Cancelled; -pub use self::cycle::CycleRecoveryAction; + pub use self::database::Database; pub use self::database_impl::DatabaseImpl; pub use self::durability::Durability; @@ -92,7 +92,7 @@ pub mod plumbing { #[cfg(feature = "accumulator")] pub use crate::accumulator::Accumulator; pub use crate::attach::{attach, with_attached_database}; - pub use crate::cycle::{CycleRecoveryAction, CycleRecoveryStrategy}; + pub use crate::cycle::CycleRecoveryStrategy; pub use crate::database::{current_revision, Database}; pub use crate::durability::Durability; pub use crate::id::{AsId, FromId, FromIdWithDb, Id}; diff --git a/tests/cycle.rs b/tests/cycle.rs index dbe0bdc19..dd476ab76 100644 --- a/tests/cycle.rs +++ 
b/tests/cycle.rs @@ -7,7 +7,7 @@ mod common; use common::{ExecuteValidateLoggerDatabase, LogDatabase}; use expect_test::expect; -use salsa::{CycleRecoveryAction, Database as Db, DatabaseImpl as DbImpl, Durability, Setter}; +use salsa::{Database as Db, DatabaseImpl as DbImpl, Durability, Setter}; #[cfg(not(miri))] use test_log::test; @@ -122,24 +122,26 @@ const MAX_ITERATIONS: u32 = 3; /// Recover from a cycle by falling back to `Value::OutOfBounds` if the value is out of bounds, /// `Value::TooManyIterations` if we've iterated more than `MAX_ITERATIONS` times, or else -/// iterating again. +/// returning the computed value to continue iterating. fn cycle_recover( _db: &dyn Db, _id: salsa::Id, - _last_provisional_value: &Value, - value: &Value, + last_provisional_value: &Value, + value: Value, count: u32, _inputs: Inputs, -) -> CycleRecoveryAction { - if value +) -> Value { + if &value == last_provisional_value { + value + } else if value .to_value() .is_some_and(|val| val <= MIN_VALUE || val >= MAX_VALUE) { - CycleRecoveryAction::Fallback(Value::OutOfBounds) + Value::OutOfBounds } else if count > MAX_ITERATIONS { - CycleRecoveryAction::Fallback(Value::TooManyIterations) + Value::TooManyIterations } else { - CycleRecoveryAction::Iterate + value } } diff --git a/tests/cycle_accumulate.rs b/tests/cycle_accumulate.rs index 49f1d06d9..6377805b8 100644 --- a/tests/cycle_accumulate.rs +++ b/tests/cycle_accumulate.rs @@ -52,11 +52,11 @@ fn cycle_fn( _db: &dyn LogDatabase, _id: salsa::Id, _last_provisional_value: &[u32], - _value: &[u32], + value: Vec, _count: u32, _file: File, -) -> salsa::CycleRecoveryAction> { - salsa::CycleRecoveryAction::Iterate +) -> Vec { + value } #[test] diff --git a/tests/cycle_recovery_call_back_into_cycle.rs b/tests/cycle_recovery_call_back_into_cycle.rs index 4ab236565..77f7378e4 100644 --- a/tests/cycle_recovery_call_back_into_cycle.rs +++ b/tests/cycle_recovery_call_back_into_cycle.rs @@ -28,11 +28,15 @@ fn cycle_initial(_db: &dyn 
ValueDatabase, _id: salsa::Id) -> u32 { fn cycle_fn( db: &dyn ValueDatabase, _id: salsa::Id, - _last_provisional_value: &u32, - _value: &u32, + last_provisional_value: &u32, + value: u32, _count: u32, -) -> salsa::CycleRecoveryAction { - salsa::CycleRecoveryAction::Fallback(fallback_value(db)) +) -> u32 { + if &value == last_provisional_value { + value + } else { + fallback_value(db) + } } #[test] diff --git a/tests/cycle_recovery_call_query.rs b/tests/cycle_recovery_call_query.rs index a227d6122..dae4203d7 100644 --- a/tests/cycle_recovery_call_query.rs +++ b/tests/cycle_recovery_call_query.rs @@ -25,10 +25,10 @@ fn cycle_fn( db: &dyn salsa::Database, _id: salsa::Id, _last_provisional_value: &u32, - _value: &u32, + _value: u32, _count: u32, -) -> salsa::CycleRecoveryAction { - salsa::CycleRecoveryAction::Fallback(fallback_value(db)) +) -> u32 { + fallback_value(db) } #[test_log::test] diff --git a/tests/cycle_recovery_dependencies.rs b/tests/cycle_recovery_dependencies.rs index b26ce973b..fe93428e5 100644 --- a/tests/cycle_recovery_dependencies.rs +++ b/tests/cycle_recovery_dependencies.rs @@ -39,12 +39,12 @@ fn cycle_fn( db: &dyn salsa::Database, _id: salsa::Id, _last_provisional_value: &u32, - _value: &u32, + value: u32, _count: u32, input: Input, -) -> salsa::CycleRecoveryAction { +) -> u32 { let _input = input.value(db); - salsa::CycleRecoveryAction::Iterate + value } #[test_log::test] diff --git a/tests/dataflow.rs b/tests/dataflow.rs index 69c91d513..f91123ef0 100644 --- a/tests/dataflow.rs +++ b/tests/dataflow.rs @@ -7,7 +7,7 @@ use std::collections::BTreeSet; use std::iter::IntoIterator; -use salsa::{CycleRecoveryAction, Database as Db, Setter}; +use salsa::{Database as Db, Setter}; /// A Use of a symbol. 
#[salsa::input] @@ -78,12 +78,16 @@ fn def_cycle_initial(_db: &dyn Db, _id: salsa::Id, _def: Definition) -> Type { fn def_cycle_recover( _db: &dyn Db, _id: salsa::Id, - _last_provisional_value: &Type, - value: &Type, + last_provisional_value: &Type, + value: Type, count: u32, _def: Definition, -) -> CycleRecoveryAction { - cycle_recover(value, count) +) -> Type { + if &value == last_provisional_value { + value + } else { + cycle_recover(value, count) + } } fn use_cycle_initial(_db: &dyn Db, _id: salsa::Id, _use: Use) -> Type { @@ -93,25 +97,29 @@ fn use_cycle_initial(_db: &dyn Db, _id: salsa::Id, _use: Use) -> Type { fn use_cycle_recover( _db: &dyn Db, _id: salsa::Id, - _last_provisional_value: &Type, - value: &Type, + last_provisional_value: &Type, + value: Type, count: u32, _use: Use, -) -> CycleRecoveryAction { - cycle_recover(value, count) +) -> Type { + if &value == last_provisional_value { + value + } else { + cycle_recover(value, count) + } } -fn cycle_recover(value: &Type, count: u32) -> CycleRecoveryAction { - match value { - Type::Bottom => CycleRecoveryAction::Iterate, +fn cycle_recover(value: Type, count: u32) -> Type { + match &value { + Type::Bottom => value, Type::Values(_) => { if count > 4 { - CycleRecoveryAction::Fallback(Type::Top) + Type::Top } else { - CycleRecoveryAction::Iterate + value } } - Type::Top => CycleRecoveryAction::Iterate, + Type::Top => value, } } diff --git a/tests/parallel/cycle_panic.rs b/tests/parallel/cycle_panic.rs index ba05291a5..34cbb7ed2 100644 --- a/tests/parallel/cycle_panic.rs +++ b/tests/parallel/cycle_panic.rs @@ -22,9 +22,9 @@ fn cycle_fn( _db: &dyn KnobsDatabase, _id: salsa::Id, _last_provisional_value: &u32, - _value: &u32, + _value: u32, _count: u32, -) -> salsa::CycleRecoveryAction { +) -> u32 { panic!("cancel!") } From a885bb4c4c192741b8a17418fef81a71e33d111e Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Thu, 13 Nov 2025 10:17:44 +0100 Subject: [PATCH 65/65] Fix cycle head durability (#1024) --- 
src/function/execute.rs | 36 ++++++++---- src/zalsa_local.rs | 4 ++ tests/cycle_input_different_cycle_head.rs | 72 +++++++++++++++++++++++ 3 files changed, 101 insertions(+), 11 deletions(-) create mode 100644 tests/cycle_input_different_cycle_head.rs diff --git a/src/function/execute.rs b/src/function/execute.rs index d07bb45f6..558ace738 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -170,7 +170,7 @@ where // Our provisional value from the previous iteration, when doing fixpoint iteration. // This is different from `opt_old_memo` which might be from a different revision. - let mut last_provisional_memo: Option<&Memo<'db, C>> = None; + let mut last_provisional_memo_opt: Option<&Memo<'db, C>> = None; // TODO: Can we seed those somehow? let mut last_stale_tracked_ids: Vec<(Identity, Id)> = Vec::new(); @@ -194,7 +194,7 @@ where // Only use the last provisional memo if it was a cycle head in the last iteration. This is to // force at least two executions. if old_memo.cycle_heads().contains(&database_key_index) { - last_provisional_memo = Some(old_memo); + last_provisional_memo_opt = Some(old_memo); } iteration_count = old_memo.revisions.iteration(); @@ -219,7 +219,7 @@ where db, zalsa, active_query, - last_provisional_memo.or(opt_old_memo), + last_provisional_memo_opt.or(opt_old_memo), ); // Take the cycle heads to not-fight-rust's-borrow-checker. @@ -329,10 +329,7 @@ where // Get the last provisional value for this query so that we can compare it with the new value // to test if the cycle converged. - let last_provisional_value = if let Some(last_provisional) = last_provisional_memo { - // We have a last provisional value from our previous time around the loop. 
- last_provisional.value.as_ref() - } else { + let last_provisional_memo = last_provisional_memo_opt.unwrap_or_else(|| { // This is our first time around the loop; a provisional value must have been // inserted into the memo table when the cycle was hit, so let's pull our // initial provisional value from there. @@ -346,8 +343,10 @@ where }); debug_assert!(memo.may_be_provisional()); - memo.value.as_ref() - }; + memo + }); + + let last_provisional_value = last_provisional_memo.value.as_ref(); let last_provisional_value = last_provisional_value.expect( "`fetch_cold_cycle` should have inserted a provisional memo with Cycle::initial", @@ -389,9 +388,24 @@ where } } - let this_converged = C::values_equal(&new_value, last_provisional_value); let mut completed_query = active_query.pop(); + let value_converged = C::values_equal(&new_value, last_provisional_value); + + // It's important to force a re-execution of the cycle if `changed_at` or `durability` has changed + // to ensure the reduced durability and changed propagates to all queries depending on this head. 
+ let metadata_converged = last_provisional_memo.revisions.durability + == completed_query.revisions.durability + && last_provisional_memo.revisions.changed_at + == completed_query.revisions.changed_at + && last_provisional_memo + .revisions + .origin + .is_derived_untracked() + == completed_query.revisions.origin.is_derived_untracked(); + + let this_converged = value_converged && metadata_converged; + if let Some(outer_cycle) = outer_cycle { tracing::info!( "Detected nested cycle {database_key_index:?}, iterate it as part of the outer cycle {outer_cycle:?}" @@ -494,7 +508,7 @@ where memo_ingredient_index, ); - last_provisional_memo = Some(new_memo); + last_provisional_memo_opt = Some(new_memo); last_stale_tracked_ids = completed_query.stale_tracked_structs; diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs index bde3b6b24..8f0239e56 100644 --- a/src/zalsa_local.rs +++ b/src/zalsa_local.rs @@ -934,6 +934,10 @@ impl QueryOrigin { } } + pub fn is_derived_untracked(&self) -> bool { + matches!(self.kind, QueryOriginKind::DerivedUntracked) + } + /// Create a query origin of type `QueryOriginKind::Derived`, with the given edges. pub fn derived(input_outputs: Box<[QueryEdge]>) -> QueryOrigin { // Exceeding `u32::MAX` query edges should never happen in real-world usage. diff --git a/tests/cycle_input_different_cycle_head.rs b/tests/cycle_input_different_cycle_head.rs new file mode 100644 index 000000000..d7f75143c --- /dev/null +++ b/tests/cycle_input_different_cycle_head.rs @@ -0,0 +1,72 @@ +#![cfg(feature = "inventory")] + +//! Tests that the durability correctly propagates +//! to all cycle heads. 
+ +use salsa::Setter as _; + +#[test_log::test] +fn low_durability_cycle_enter_from_different_head() { + let mut db = MyDbImpl::default(); + // Start with 0, the same as returned by cycle initial + let input = Input::builder(0).new(&db); + db.input = Some(input); + + assert_eq!(query_a(&db), 0); // Prime the Db + + input.set_value(&mut db).to(10); + + assert_eq!(query_b(&db), 10); +} + +#[salsa::input] +struct Input { + value: u32, +} + +#[salsa::db] +trait MyDb: salsa::Database { + fn input(&self) -> Input; +} + +#[salsa::db] +#[derive(Clone, Default)] +struct MyDbImpl { + storage: salsa::Storage, + input: Option, +} + +#[salsa::db] +impl salsa::Database for MyDbImpl {} + +#[salsa::db] +impl MyDb for MyDbImpl { + fn input(&self) -> Input { + self.input.unwrap() + } +} + +#[salsa::tracked(cycle_initial=cycle_initial)] +fn query_a(db: &dyn MyDb) -> u32 { + query_b(db); + db.input().value(db) +} + +fn cycle_initial(_db: &dyn MyDb, _id: salsa::Id) -> u32 { + 0 +} + +#[salsa::interned] +struct Interned { + value: u32, +} + +#[salsa::tracked(cycle_initial=cycle_initial)] +fn query_b<'db>(db: &'db dyn MyDb) -> u32 { + query_c(db) +} + +#[salsa::tracked] +fn query_c(db: &dyn MyDb) -> u32 { + query_a(db) +}