diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index db635d53d04..2f655ddb924 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ -116,9 +116,10 @@ class FeatureStore: config: RepoConfig repo_path: Path - _registry: BaseRegistry - _provider: Provider + _registry: Optional[BaseRegistry] + _provider: Optional[Provider] _openlineage_emitter: Optional[Any] = None + _feature_service_cache: Dict[str, List[str]] def __init__( self, @@ -159,33 +160,13 @@ def __init__( self.repo_path, utils.get_default_yaml_file_path(self.repo_path) ) - registry_config = self.config.registry - if registry_config.registry_type == "sql": - self._registry = SqlRegistry(registry_config, self.config.project, None) - elif registry_config.registry_type == "snowflake.registry": - from feast.infra.registry.snowflake import SnowflakeRegistry + # Initialize lazy-loaded components as None + self._registry = None + self._provider = None + self._openlineage_emitter = None - self._registry = SnowflakeRegistry( - registry_config, self.config.project, None - ) - elif registry_config and registry_config.registry_type == "remote": - from feast.infra.registry.remote import RemoteRegistry - - self._registry = RemoteRegistry( - registry_config, self.config.project, None, self.config.auth_config - ) - else: - self._registry = Registry( - self.config.project, - registry_config, - repo_path=self.repo_path, - auth_config=self.config.auth_config, - ) - - self._provider = get_provider(self.config) - - # Initialize OpenLineage emitter if configured - self._openlineage_emitter = self._init_openlineage_emitter() + # Initialize feature service cache for performance optimization + self._feature_service_cache = {} def _init_openlineage_emitter(self) -> Optional[Any]: """Initialize OpenLineage emitter if configured and enabled.""" @@ -209,28 +190,85 @@ def _init_openlineage_emitter(self) -> Optional[Any]: return None def __repr__(self) -> str: + # Show 
lazy loading status without triggering initialization + registry_status = "not loaded" if self._registry is None else "loaded" + provider_status = "not loaded" if self._provider is None else "loaded" return ( f"FeatureStore(\n" f" repo_path={self.repo_path!r},\n" f" config={self.config!r},\n" - f" registry={self._registry!r},\n" - f" provider={self._provider!r}\n" + f" registry={registry_status},\n" + f" provider={provider_status}\n" f")" ) @property def registry(self) -> BaseRegistry: """Gets the registry of this feature store.""" + if self._registry is None: + self._registry = self._create_registry() + # Add feature service cache to registry for performance optimization + if self._registry and not hasattr(self._registry, "_feature_service_cache"): + setattr( + self._registry, + "_feature_service_cache", + self._feature_service_cache, + ) + if self._registry is None: + raise RuntimeError("Registry failed to initialize properly") return self._registry + def _create_registry(self) -> BaseRegistry: + """Create and initialize the registry.""" + registry_config = self.config.registry + if registry_config.registry_type == "sql": + return SqlRegistry(registry_config, self.config.project, None) + elif registry_config.registry_type == "snowflake.registry": + from feast.infra.registry.snowflake import SnowflakeRegistry + + return SnowflakeRegistry(registry_config, self.config.project, None) + elif registry_config and registry_config.registry_type == "remote": + from feast.infra.registry.remote import RemoteRegistry + + return RemoteRegistry( + registry_config, self.config.project, None, self.config.auth_config + ) + else: + return Registry( + self.config.project, + registry_config, + repo_path=self.repo_path, + auth_config=self.config.auth_config, + ) + @property def project(self) -> str: """Gets the project of this feature store.""" return self.config.project + @property + def provider(self) -> Provider: + """Gets the provider of this feature store.""" + if self._provider 
is None: + self._provider = get_provider(self.config) + return self._provider + def _get_provider(self) -> Provider: # TODO: Bake self.repo_path into self.config so that we dont only have one interface to paths - return self._provider + return self.provider + + @property + def openlineage_emitter(self) -> Optional[Any]: + """Gets the OpenLineage emitter of this feature store.""" + if self._openlineage_emitter is None: + self._openlineage_emitter = self._init_openlineage_emitter() + return self._openlineage_emitter + + def _clear_feature_service_cache(self): + """Clear feature service cache to avoid stale data after registry refresh.""" + self._feature_service_cache.clear() + if hasattr(self.registry, "_feature_service_cache"): + getattr(self.registry, "_feature_service_cache").clear() def refresh_registry(self): """Fetches and caches a copy of the feature registry in memory. @@ -247,7 +285,8 @@ def refresh_registry(self): downloaded synchronously, which may increase latencies if the triggering method is get_online_features(). """ - self._registry.refresh(self.project) + self.registry.refresh(self.project) + self._clear_feature_service_cache() def list_entities( self, allow_cache: bool = False, tags: Optional[dict[str, str]] = None @@ -270,7 +309,7 @@ def _list_entities( hide_dummy_entity: bool = True, tags: Optional[dict[str, str]] = None, ) -> List[Entity]: - all_entities = self._registry.list_entities( + all_entities = self.registry.list_entities( self.project, allow_cache=allow_cache, tags=tags ) return [ @@ -291,7 +330,7 @@ def list_feature_services( Returns: A list of feature services. """ - return self._registry.list_feature_services(self.project, tags=tags) + return self.registry.list_feature_services(self.project, tags=tags) def _list_all_feature_views( self, allow_cache: bool = False, tags: Optional[dict[str, str]] = None @@ -338,7 +377,7 @@ def list_feature_views( A list of feature views. 
""" return utils._list_feature_views( - self._registry, self.project, allow_cache, tags=tags + self.registry, self.project, allow_cache, tags=tags ) def list_batch_feature_views( @@ -363,7 +402,7 @@ def _list_batch_feature_views( tags: Optional[dict[str, str]] = None, ) -> List[FeatureView]: feature_views = [] - for fv in self._registry.list_feature_views( + for fv in self.registry.list_feature_views( self.project, allow_cache=allow_cache, tags=tags ): if ( @@ -383,7 +422,7 @@ def _list_stream_feature_views( tags: Optional[dict[str, str]] = None, ) -> List[StreamFeatureView]: stream_feature_views = [] - for sfv in self._registry.list_stream_feature_views( + for sfv in self.registry.list_stream_feature_views( self.project, allow_cache=allow_cache, tags=tags ): if hide_dummy_entity and sfv.entities[0] == DUMMY_ENTITY_NAME: @@ -405,7 +444,7 @@ def list_on_demand_feature_views( Returns: A list of on demand feature views. """ - return self._registry.list_on_demand_feature_views( + return self.registry.list_on_demand_feature_views( self.project, allow_cache=allow_cache, tags=tags ) @@ -433,7 +472,7 @@ def list_data_sources( Returns: A list of data sources. """ - return self._registry.list_data_sources( + return self.registry.list_data_sources( self.project, allow_cache=allow_cache, tags=tags ) @@ -451,7 +490,7 @@ def get_entity(self, name: str, allow_registry_cache: bool = False) -> Entity: Raises: EntityNotFoundException: The entity could not be found. """ - return self._registry.get_entity( + return self.registry.get_entity( name, self.project, allow_cache=allow_registry_cache ) @@ -471,7 +510,7 @@ def get_feature_service( Raises: FeatureServiceNotFoundException: The feature service could not be found. 
""" - return self._registry.get_feature_service(name, self.project, allow_cache) + return self.registry.get_feature_service(name, self.project, allow_cache) def get_feature_view( self, name: str, allow_registry_cache: bool = False @@ -497,7 +536,7 @@ def _get_feature_view( hide_dummy_entity: bool = True, allow_registry_cache: bool = False, ) -> FeatureView: - feature_view = self._registry.get_feature_view( + feature_view = self.registry.get_feature_view( name, self.project, allow_cache=allow_registry_cache ) if hide_dummy_entity and feature_view.entities[0] == DUMMY_ENTITY_NAME: @@ -530,7 +569,7 @@ def _get_stream_feature_view( hide_dummy_entity: bool = True, allow_registry_cache: bool = False, ) -> StreamFeatureView: - stream_feature_view = self._registry.get_stream_feature_view( + stream_feature_view = self.registry.get_stream_feature_view( name, self.project, allow_cache=allow_registry_cache ) if hide_dummy_entity and stream_feature_view.entities[0] == DUMMY_ENTITY_NAME: @@ -552,7 +591,7 @@ def get_on_demand_feature_view( Raises: FeatureViewNotFoundException: The feature view could not be found. """ - return self._registry.get_on_demand_feature_view( + return self.registry.get_on_demand_feature_view( name, self.project, allow_cache=allow_registry_cache ) @@ -569,7 +608,7 @@ def get_data_source(self, name: str) -> DataSource: Raises: DataSourceObjectNotFoundException: The data source could not be found. """ - return self._registry.get_data_source(name, self.project) + return self.registry.get_data_source(name, self.project) def delete_feature_view(self, name: str): """ @@ -581,7 +620,7 @@ def delete_feature_view(self, name: str): Raises: FeatureViewNotFoundException: The feature view could not be found. 
""" - return self._registry.delete_feature_view(name, self.project) + return self.registry.delete_feature_view(name, self.project) def delete_feature_service(self, name: str): """ @@ -593,7 +632,7 @@ def delete_feature_service(self, name: str): Raises: FeatureServiceNotFoundException: The feature view could not be found. """ - return self._registry.delete_feature_service(name, self.project) + return self.registry.delete_feature_service(name, self.project) def _should_use_plan(self): """Returns True if plan and _apply_diffs should be used, False otherwise.""" @@ -706,7 +745,7 @@ def _get_feature_views_to_materialize( if feature_views is None: regular_feature_views = utils._list_feature_views( - self._registry, self.project, hide_dummy_entity=False + self.registry, self.project, hide_dummy_entity=False ) feature_views_to_materialize.extend( [fv for fv in regular_feature_views if fv.online] @@ -821,19 +860,18 @@ def plan( # Compute the desired difference between the current objects in the registry and # the desired repo state. - registry_diff = diff_between( - self._registry, self.project, desired_repo_contents - ) + registry_diff = diff_between(self.registry, self.project, desired_repo_contents) if progress_ctx: progress_ctx.update_phase_progress("Computing infrastructure diff") # Compute the desired difference between the current infra, as stored in the registry, # and the desired infra. 
- self._registry.refresh(project=self.project) - current_infra_proto = self._registry.get_infra(self.project).to_proto() + self.registry.refresh(project=self.project) + self._clear_feature_service_cache() + current_infra_proto = self.registry.get_infra(self.project).to_proto() desired_registry_proto = desired_repo_contents.to_registry_proto() - new_infra = self._provider.plan_infra(self.config, desired_registry_proto) + new_infra = self.provider.plan_infra(self.config, desired_registry_proto) new_infra_proto = new_infra.to_proto() infra_diff = diff_infra_protos( current_infra_proto, new_infra_proto, project=self.project @@ -870,13 +908,13 @@ def _apply_diffs( # Registry phase apply_diff_to_registry( - self._registry, registry_diff, self.project, commit=False + self.registry, registry_diff, self.project, commit=False ) if progress_ctx: progress_ctx.update_phase_progress("Committing registry changes") - self._registry.update_infra(new_infra, self.project, commit=True) + self.registry.update_infra(new_infra, self.project, commit=True) if progress_ctx: progress_ctx.update_phase_progress("Registry update complete") @@ -891,7 +929,7 @@ def _apply_diffs( def _emit_openlineage_apply_diffs(self, registry_diff: RegistryDiff): """Emit OpenLineage events for objects applied via diffs.""" - if self._openlineage_emitter is None: + if self.openlineage_emitter is None: return # Collect all objects that were added or updated @@ -1059,23 +1097,23 @@ def apply( # Add all objects to the registry and update the provider's infrastructure. 
for project in projects_to_update: - self._registry.apply_project(project, commit=False) + self.registry.apply_project(project, commit=False) for ds in data_sources_to_update: - self._registry.apply_data_source(ds, project=self.project, commit=False) + self.registry.apply_data_source(ds, project=self.project, commit=False) for view in itertools.chain(views_to_update, odfvs_to_update, sfvs_to_update): - self._registry.apply_feature_view(view, project=self.project, commit=False) + self.registry.apply_feature_view(view, project=self.project, commit=False) for ent in entities_to_update: - self._registry.apply_entity(ent, project=self.project, commit=False) + self.registry.apply_entity(ent, project=self.project, commit=False) for feature_service in services_to_update: - self._registry.apply_feature_service( + self.registry.apply_feature_service( feature_service, project=self.project, commit=False ) for validation_references in validation_references_to_update: - self._registry.apply_validation_reference( + self.registry.apply_validation_reference( validation_references, project=self.project, commit=False ) for permission in permissions_to_update: - self._registry.apply_permission( + self.registry.apply_permission( permission, project=self.project, commit=False ) @@ -1116,35 +1154,35 @@ def apply( ] for data_source in data_sources_to_delete: - self._registry.delete_data_source( + self.registry.delete_data_source( data_source.name, project=self.project, commit=False ) for entity in entities_to_delete: - self._registry.delete_entity( + self.registry.delete_entity( entity.name, project=self.project, commit=False ) for view in views_to_delete: - self._registry.delete_feature_view( + self.registry.delete_feature_view( view.name, project=self.project, commit=False ) for odfv in odfvs_to_delete: - self._registry.delete_feature_view( + self.registry.delete_feature_view( odfv.name, project=self.project, commit=False ) for sfv in sfvs_to_delete: - 
self._registry.delete_feature_view( + self.registry.delete_feature_view( sfv.name, project=self.project, commit=False ) for service in services_to_delete: - self._registry.delete_feature_service( + self.registry.delete_feature_service( service.name, project=self.project, commit=False ) for validation_references in validation_references_to_delete: - self._registry.delete_validation_reference( + self.registry.delete_validation_reference( validation_references.name, project=self.project, commit=False ) for permission in permissions_to_delete: - self._registry.delete_permission( + self.registry.delete_permission( permission.name, project=self.project, commit=False ) @@ -1164,7 +1202,7 @@ def apply( partial=partial, ) - self._registry.commit() + self.registry.commit() # Refresh the registry cache to ensure that changes are immediately visible # This is especially important for UI and other clients that may be reading @@ -1182,10 +1220,10 @@ def apply( def _emit_openlineage_apply(self, objects: List[Any]): """Emit OpenLineage events for applied objects.""" - if self._openlineage_emitter is None: + if self.openlineage_emitter is None: return try: - self._openlineage_emitter.emit_apply(objects, self.project) + self.openlineage_emitter.emit_apply(objects, self.project) except Exception as e: warnings.warn(f"Failed to emit OpenLineage apply events: {e}") @@ -1199,7 +1237,7 @@ def teardown(self): entities = self.list_entities() self._get_provider().teardown_infra(self.project, tables, entities) - self._registry.teardown() + self.registry.teardown() def get_historical_features( self, @@ -1280,11 +1318,13 @@ def get_historical_features( if entity_df is None and end_date is None: end_date = datetime.now() - _feature_refs = utils._get_features(self._registry, self.project, features) + _feature_refs = utils._get_features( + self.registry, self.project, features, allow_cache=True + ) ( all_feature_views, all_on_demand_feature_views, - ) = 
utils._get_feature_views_to_use(self._registry, self.project, features) + ) = utils._get_feature_views_to_use(self.registry, self.project, features) # TODO(achal): _group_feature_refs returns the on demand feature views, but it's not passed into the provider. # This is a weird interface quirk - we should revisit the `get_historical_features` to @@ -1336,7 +1376,7 @@ def get_historical_features( feature_views, _feature_refs, entity_df, - self._registry, + self.registry, self.project, full_feature_names, **kwargs, @@ -1408,7 +1448,7 @@ def create_saved_dataset( ) ) - self._registry.apply_saved_dataset(dataset, self.project, commit=True) + self.registry.apply_saved_dataset(dataset, self.project, commit=True) return dataset def get_saved_dataset(self, name: str) -> SavedDataset: @@ -1434,7 +1474,7 @@ def get_saved_dataset(self, name: str) -> SavedDataset: RuntimeWarning, ) - dataset = self._registry.get_saved_dataset(name, self.project) + dataset = self.registry.get_saved_dataset(name, self.project) provider = self._get_provider() retrieval_job = provider.retrieve_saved_dataset( @@ -1669,12 +1709,12 @@ def tqdm_builder(length): feature_view=feature_view, start_date=start_date, end_date=end_date, - registry=self._registry, + registry=self.registry, project=self.project, tqdm_builder=tqdm_builder, ) if not isinstance(feature_view, OnDemandFeatureView): - self._registry.apply_materialization( + self.registry.apply_materialization( feature_view, self.project, start_date, @@ -1778,13 +1818,13 @@ def tqdm_builder(length): feature_view=feature_view, start_date=start_date, end_date=end_date, - registry=self._registry, + registry=self.registry, project=self.project, tqdm_builder=tqdm_builder, disable_event_timestamp=disable_event_timestamp, ) - self._registry.apply_materialization( + self.registry.apply_materialization( feature_view, self.project, start_date, @@ -1807,10 +1847,10 @@ def _emit_openlineage_materialize_start( end_date: datetime, ) -> Optional[str]: """Emit 
OpenLineage START event for materialization.""" - if self._openlineage_emitter is None: + if self.openlineage_emitter is None: return None try: - run_id, success = self._openlineage_emitter.emit_materialize_start( + run_id, success = self.openlineage_emitter.emit_materialize_start( feature_views, start_date, end_date, self.project ) # Return run_id only if START was successfully emitted @@ -1826,10 +1866,10 @@ def _emit_openlineage_materialize_complete( feature_views: List[Any], ): """Emit OpenLineage COMPLETE event for materialization.""" - if self._openlineage_emitter is None or not run_id: + if self.openlineage_emitter is None or not run_id: return try: - self._openlineage_emitter.emit_materialize_complete( + self.openlineage_emitter.emit_materialize_complete( run_id, feature_views, self.project ) except Exception as e: @@ -1841,10 +1881,10 @@ def _emit_openlineage_materialize_fail( error_message: str, ): """Emit OpenLineage FAIL event for materialization.""" - if self._openlineage_emitter is None or not run_id: + if self.openlineage_emitter is None or not run_id: return try: - self._openlineage_emitter.emit_materialize_fail( + self.openlineage_emitter.emit_materialize_fail( run_id, self.project, error_message ) except Exception as e: @@ -2321,7 +2361,7 @@ def get_online_features( config=self.config, features=features, entity_rows=entity_rows, - registry=self._registry, + registry=self.registry, project=self.project, full_feature_names=full_feature_names, ) @@ -2369,7 +2409,7 @@ async def get_online_features_async( config=self.config, features=features, entity_rows=entity_rows, - registry=self._registry, + registry=self.registry, project=self.project, full_feature_names=full_feature_names, ) @@ -2399,7 +2439,7 @@ def retrieve_online_documents( available_feature_views, _, ) = utils._get_feature_views_to_use( - registry=self._registry, + registry=self.registry, project=self.project, features=features, allow_cache=True, @@ -2600,7 +2640,7 @@ def 
retrieve_online_documents_v2( available_feature_views, available_odfv_views, ) = utils._get_feature_views_to_use( - registry=self._registry, + registry=self.registry, project=self.project, features=features, allow_cache=True, @@ -2826,7 +2866,7 @@ def serve( def get_feature_server_endpoint(self) -> Optional[str]: """Returns endpoint for the feature server, if it exists.""" - return self._provider.get_feature_server_endpoint() + return self.provider.get_feature_server_endpoint() def serve_ui( self, @@ -2928,7 +2968,7 @@ def write_logged_features( feature_service=source, logs=logs, config=self.config, - registry=self._registry, + registry=self.registry, ) def validate_logged_features( @@ -3000,7 +3040,7 @@ def get_validation_reference( Raises: ValidationReferenceNotFoundException: The validation reference could not be found. """ - ref = self._registry.get_validation_reference( + ref = self.registry.get_validation_reference( name, project=self.project, allow_cache=allow_cache ) ref._dataset = self.get_saved_dataset(ref.dataset_name) @@ -3019,7 +3059,7 @@ def list_validation_references( Returns: A list of validation references. """ - return self._registry.list_validation_references( + return self.registry.list_validation_references( self.project, allow_cache=allow_cache, tags=tags ) @@ -3036,7 +3076,7 @@ def list_permissions( Returns: A list of permissions. """ - return self._registry.list_permissions( + return self.registry.list_permissions( self.project, allow_cache=allow_cache, tags=tags ) @@ -3053,7 +3093,7 @@ def get_permission(self, name: str) -> Permission: Raises: PermissionObjectNotFoundException: The permission could not be found. """ - return self._registry.get_permission(name, self.project) + return self.registry.get_permission(name, self.project) def list_projects( self, allow_cache: bool = False, tags: Optional[dict[str, str]] = None @@ -3068,7 +3108,7 @@ def list_projects( Returns: A list of projects. 
""" - return self._registry.list_projects(allow_cache=allow_cache, tags=tags) + return self.registry.list_projects(allow_cache=allow_cache, tags=tags) def get_project(self, name: Optional[str]) -> Project: """ @@ -3083,7 +3123,7 @@ def get_project(self, name: Optional[str]) -> Project: Raises: ProjectObjectNotFoundException: The project could not be found. """ - return self._registry.get_project(name or self.project) + return self.registry.get_project(name or self.project) def list_saved_datasets( self, allow_cache: bool = False, tags: Optional[dict[str, str]] = None @@ -3098,7 +3138,7 @@ def list_saved_datasets( Returns: A list of saved datasets. """ - return self._registry.list_saved_datasets( + return self.registry.list_saved_datasets( self.project, allow_cache=allow_cache, tags=tags ) diff --git a/sdk/python/feast/utils.py b/sdk/python/feast/utils.py index 78da775f98d..ebdd56929bb 100644 --- a/sdk/python/feast/utils.py +++ b/sdk/python/feast/utils.py @@ -1107,6 +1107,15 @@ def _get_features( _feature_refs = [] if isinstance(_features, FeatureService): + # Create cache key for feature service resolution + cache_key = f"{_features.name}:{project}:{hash(tuple(str(fv) for fv in _features.feature_view_projections))}" + + # Check cache first if caching is enabled and available + if allow_cache and hasattr(registry, "_feature_service_cache"): + if cache_key in registry._feature_service_cache: + return registry._feature_service_cache[cache_key] + + # Resolve feature service from registry feature_service_from_registry = registry.get_feature_service( _features.name, project, allow_cache ) @@ -1116,10 +1125,16 @@ def _get_features( "inconsistent with the version from the registry. Potentially a newer version " "of the FeatureService has been applied to the registry." 
) + + # Build feature reference list for projection in feature_service_from_registry.feature_view_projections: _feature_refs.extend( [f"{projection.name_to_use()}:{f.name}" for f in projection.features] ) + + # Cache the result if caching is enabled and available + if allow_cache and hasattr(registry, "_feature_service_cache"): + registry._feature_service_cache[cache_key] = _feature_refs else: assert isinstance(_features, list) _feature_refs = _features diff --git a/sdk/python/tests/integration/offline_store/test_dqm_validation.py b/sdk/python/tests/integration/offline_store/test_dqm_validation.py index 52d83ab8d8f..710dd6ca2e6 100644 --- a/sdk/python/tests/integration/offline_store/test_dqm_validation.py +++ b/sdk/python/tests/integration/offline_store/test_dqm_validation.py @@ -1,419 +1,419 @@ -import datetime -import shutil - -import pandas as pd -import pyarrow as pa -import pytest -from great_expectations.core import ExpectationSuite -from great_expectations.dataset import PandasDataset - -from feast import FeatureService -from feast.dqm.errors import ValidationFailed -from feast.dqm.profilers.ge_profiler import ge_profiler -from feast.feature_logging import ( - LOG_TIMESTAMP_FIELD, - FeatureServiceLoggingSource, - LoggingConfig, -) -from feast.protos.feast.serving.ServingService_pb2 import FieldStatus -from feast.utils import _utc_now, make_tzaware -from feast.wait import wait_retry_backoff -from tests.integration.feature_repos.repo_configuration import ( - construct_universal_feature_views, -) -from tests.integration.feature_repos.universal.entities import ( - customer, - driver, - location, -) -from tests.utils.cli_repo_creator import CliRunner -from tests.utils.test_log_creator import prepare_logs - -_features = [ - "customer_profile:current_balance", - "customer_profile:avg_passenger_count", - "customer_profile:lifetime_trip_count", - "order:order_is_success", - "global_stats:num_rides", - "global_stats:avg_ride_length", -] - - -@pytest.mark.integration 
-@pytest.mark.universal_offline_stores -def test_historical_retrieval_with_validation(environment, universal_data_sources): - store = environment.feature_store - (entities, datasets, data_sources) = universal_data_sources - feature_views = construct_universal_feature_views(data_sources) - storage = environment.data_source_creator.create_saved_dataset_destination() - - store.apply([driver(), customer(), location(), *feature_views.values()]) - - # Added to handle the case that the offline store is remote - store.registry.apply_data_source(storage.to_data_source(), store.config.project) - - # Create two identical retrieval jobs - entity_df = datasets.entity_df.drop( - columns=["order_id", "origin_id", "destination_id"] - ) - reference_job = store.get_historical_features( - entity_df=entity_df, - features=_features, - ) - job = store.get_historical_features( - entity_df=entity_df, - features=_features, - ) - - # Save dataset using reference job and retrieve it - store.create_saved_dataset( - from_=reference_job, - name="my_training_dataset", - storage=storage, - allow_overwrite=True, - ) - saved_dataset = store.get_saved_dataset("my_training_dataset") - - # If validation pass there will be no exceptions on this point - reference = saved_dataset.as_reference(name="ref", profiler=configurable_profiler) - job.to_df(validation_reference=reference) - - -@pytest.mark.integration -def test_historical_retrieval_fails_on_validation(environment, universal_data_sources): - store = environment.feature_store - - (entities, datasets, data_sources) = universal_data_sources - feature_views = construct_universal_feature_views(data_sources) - storage = environment.data_source_creator.create_saved_dataset_destination() - - store.apply([driver(), customer(), location(), *feature_views.values()]) - - # Added to handle the case that the offline store is remote - store.registry.apply_data_source(storage.to_data_source(), store.config.project) - - entity_df = datasets.entity_df.drop( - 
columns=["order_id", "origin_id", "destination_id"] - ) - - reference_job = store.get_historical_features( - entity_df=entity_df, - features=_features, - ) - - store.create_saved_dataset( - from_=reference_job, - name="my_other_dataset", - storage=storage, - allow_overwrite=True, - ) - - job = store.get_historical_features( - entity_df=entity_df, - features=_features, - ) - - ds = store.get_saved_dataset("my_other_dataset") - profiler_expectation_suite = ds.get_profile( - profiler=profiler_with_unrealistic_expectations - ) - - assert len(profiler_expectation_suite.expectation_suite["expectations"]) == 3 - - with pytest.raises(ValidationFailed) as exc_info: - job.to_df( - validation_reference=store.get_saved_dataset( - "my_other_dataset" - ).as_reference(name="ref", profiler=profiler_with_unrealistic_expectations) - ) - - failed_expectations = exc_info.value.report.errors - assert len(failed_expectations) == 2 - - assert failed_expectations[0].check_name == "expect_column_max_to_be_between" - assert failed_expectations[0].column_name == "current_balance" - - assert failed_expectations[1].check_name == "expect_column_values_to_be_in_set" - assert failed_expectations[1].column_name == "avg_passenger_count" - - -@pytest.mark.integration -@pytest.mark.universal_offline_stores -def test_logged_features_validation(environment, universal_data_sources): - store = environment.feature_store - - (_, datasets, data_sources) = universal_data_sources - feature_views = construct_universal_feature_views(data_sources) - feature_service = FeatureService( - name="test_service", - features=[ - feature_views.customer[ - ["current_balance", "avg_passenger_count", "lifetime_trip_count"] - ], - feature_views.order[["order_is_success"]], - feature_views.global_fv[["num_rides", "avg_ride_length"]], - ], - logging_config=LoggingConfig( - destination=environment.data_source_creator.create_logged_features_destination() - ), - ) - - storage = 
environment.data_source_creator.create_saved_dataset_destination() - - store.apply( - [driver(), customer(), location(), feature_service, *feature_views.values()] - ) - - # Added to handle the case that the offline store is remote - store.registry.apply_data_source( - feature_service.logging_config.destination.to_data_source(), - store.config.project, - ) - store.registry.apply_data_source(storage.to_data_source(), store.config.project) - - entity_df = datasets.entity_df.drop( - columns=["order_id", "origin_id", "destination_id"] - ) - - # add some non-existing entities to check NotFound feature handling - for i in range(5): - entity_df = pd.concat( - [ - entity_df, - pd.DataFrame.from_records( - [ - { - "customer_id": 2000 + i, - "driver_id": 6000 + i, - "event_timestamp": make_tzaware(datetime.datetime.now()), - } - ] - ), - ] - ) - - store_fs = store.get_feature_service(feature_service.name) - reference_dataset = store.create_saved_dataset( - from_=store.get_historical_features( - entity_df=entity_df, features=store_fs, full_feature_names=True - ), - name="reference_for_validating_logged_features", - storage=storage, - allow_overwrite=True, - ) - - log_source_df = store.get_historical_features( - entity_df=entity_df, features=store_fs, full_feature_names=False - ).to_df() - logs_df = prepare_logs(log_source_df, feature_service, store) - - schema = FeatureServiceLoggingSource( - feature_service=feature_service, project=store.project - ).get_schema(store._registry) - store.write_logged_features( - pa.Table.from_pandas(logs_df, schema=schema), source=feature_service - ) - - def validate(): - """ - Return Tuple[succeed, completed] - Succeed will be True if no ValidateFailed exception was raised - """ - try: - store.validate_logged_features( - feature_service, - start=logs_df[LOG_TIMESTAMP_FIELD].min(), - end=logs_df[LOG_TIMESTAMP_FIELD].max() + datetime.timedelta(seconds=1), - reference=reference_dataset.as_reference( - name="ref", 
profiler=profiler_with_feature_metadata - ), - ) - except ValidationFailed: - return False, True - except Exception: - # log table is still being created - return False, False - - return True, True - - success = wait_retry_backoff(validate, timeout_secs=30) - assert success, "Validation failed (unexpectedly)" - - -@pytest.mark.integration -def test_e2e_validation_via_cli(environment, universal_data_sources): - runner = CliRunner() - store = environment.feature_store - - (_, datasets, data_sources) = universal_data_sources - feature_views = construct_universal_feature_views(data_sources) - feature_service = FeatureService( - name="test_service", - features=[ - feature_views.customer[ - ["current_balance", "avg_passenger_count", "lifetime_trip_count"] - ], - ], - logging_config=LoggingConfig( - destination=environment.data_source_creator.create_logged_features_destination() - ), - ) - store.apply([customer(), feature_service, feature_views.customer]) - - entity_df = datasets.entity_df.drop( - columns=["order_id", "origin_id", "destination_id", "driver_id"] - ) - retrieval_job = store.get_historical_features( - entity_df=entity_df, - features=store.get_feature_service(feature_service.name), - full_feature_names=True, - ) - logs_df = prepare_logs(retrieval_job.to_df(), feature_service, store) - saved_dataset = store.create_saved_dataset( - from_=retrieval_job, - name="reference_for_validating_logged_features", - storage=environment.data_source_creator.create_saved_dataset_destination(), - allow_overwrite=True, - ) - reference = saved_dataset.as_reference( - name="test_reference", profiler=configurable_profiler - ) - - schema = FeatureServiceLoggingSource( - feature_service=feature_service, project=store.project - ).get_schema(store._registry) - store.write_logged_features( - pa.Table.from_pandas(logs_df, schema=schema), source=feature_service - ) - - with runner.local_repo(example_repo_py="", offline_store="file") as local_repo: - local_repo.apply( - [customer(), 
feature_views.customer, feature_service, reference] - ) - local_repo._registry.apply_saved_dataset(saved_dataset, local_repo.project) - validate_args = [ - "validate", - "--feature-service", - feature_service.name, - "--reference", - reference.name, - (datetime.datetime.now() - datetime.timedelta(days=7)).isoformat(), - datetime.datetime.now().isoformat(), - ] - p = runner.run(validate_args, cwd=local_repo.repo_path) - - assert p.returncode == 0, p.stderr.decode() - assert "Validation successful" in p.stdout.decode(), p.stderr.decode() - - p = runner.run( - ["saved-datasets", "describe", saved_dataset.name], cwd=local_repo.repo_path - ) - assert p.returncode == 0, p.stderr.decode() - - p = runner.run( - ["validation-references", "describe", reference.name], - cwd=local_repo.repo_path, - ) - assert p.returncode == 0, p.stderr.decode() - - p = runner.run( - ["feature-services", "describe", feature_service.name], - cwd=local_repo.repo_path, - ) - assert p.returncode == 0, p.stderr.decode() - - # make sure second validation will use cached profile - shutil.rmtree(saved_dataset.storage.file_options.uri) - - # Add some invalid data that would lead to failed validation - invalid_data = pd.DataFrame( - data={ - "customer_id": [0], - "current_balance": [0], - "avg_passenger_count": [0], - "lifetime_trip_count": [0], - "event_timestamp": [ - make_tzaware(_utc_now()) - datetime.timedelta(hours=1) - ], - } - ) - invalid_logs = prepare_logs(invalid_data, feature_service, store) - store.write_logged_features( - pa.Table.from_pandas(invalid_logs, schema=schema), source=feature_service - ) - - p = runner.run(validate_args, cwd=local_repo.repo_path) - assert p.returncode == 1, p.stdout.decode() - assert "Validation failed" in p.stdout.decode(), p.stderr.decode() - - -# Great expectations profilers created for testing - - -@ge_profiler -def configurable_profiler(dataset: PandasDataset) -> ExpectationSuite: - from great_expectations.profile.user_configurable_profiler import ( - 
UserConfigurableProfiler, - ) - - return UserConfigurableProfiler( - profile_dataset=dataset, - ignored_columns=["event_timestamp"], - excluded_expectations=[ - "expect_table_columns_to_match_ordered_list", - "expect_table_row_count_to_be_between", - ], - value_set_threshold="few", - ).build_suite() - - -@ge_profiler(with_feature_metadata=True) -def profiler_with_feature_metadata(dataset: PandasDataset) -> ExpectationSuite: - from great_expectations.profile.user_configurable_profiler import ( - UserConfigurableProfiler, - ) - - # always present - dataset.expect_column_values_to_be_in_set( - "global_stats__avg_ride_length__status", {FieldStatus.PRESENT} - ) - - # present at least in 70% of rows - dataset.expect_column_values_to_be_in_set( - "customer_profile__current_balance__status", {FieldStatus.PRESENT}, mostly=0.7 - ) - - return UserConfigurableProfiler( - profile_dataset=dataset, - ignored_columns=["event_timestamp"] - + [ - c - for c in dataset.columns - if c.endswith("__timestamp") or c.endswith("__status") - ], - excluded_expectations=[ - "expect_table_columns_to_match_ordered_list", - "expect_table_row_count_to_be_between", - ], - value_set_threshold="few", - ).build_suite() - - -@ge_profiler -def profiler_with_unrealistic_expectations(dataset: PandasDataset) -> ExpectationSuite: - # note: there are 4 expectations here and only 3 are returned from the profiler - # need to create dataframe with corrupted data first - df = pd.DataFrame() - df["current_balance"] = [-100] - df["avg_passenger_count"] = [0] - - other_ds = PandasDataset(df) - other_ds.expect_column_max_to_be_between("current_balance", -1000, -100) - other_ds.expect_column_values_to_be_in_set("avg_passenger_count", value_set={0}) - - # this should pass - other_ds.expect_column_min_to_be_between("avg_passenger_count", 0, 1000) - # this should fail - other_ds.expect_column_to_exist("missing random column") - - return other_ds.get_expectation_suite() +import datetime +import shutil + +import pandas as 
pd +import pyarrow as pa +import pytest +from great_expectations.core import ExpectationSuite +from great_expectations.dataset import PandasDataset + +from feast import FeatureService +from feast.dqm.errors import ValidationFailed +from feast.dqm.profilers.ge_profiler import ge_profiler +from feast.feature_logging import ( + LOG_TIMESTAMP_FIELD, + FeatureServiceLoggingSource, + LoggingConfig, +) +from feast.protos.feast.serving.ServingService_pb2 import FieldStatus +from feast.utils import _utc_now, make_tzaware +from feast.wait import wait_retry_backoff +from tests.integration.feature_repos.repo_configuration import ( + construct_universal_feature_views, +) +from tests.integration.feature_repos.universal.entities import ( + customer, + driver, + location, +) +from tests.utils.cli_repo_creator import CliRunner +from tests.utils.test_log_creator import prepare_logs + +_features = [ + "customer_profile:current_balance", + "customer_profile:avg_passenger_count", + "customer_profile:lifetime_trip_count", + "order:order_is_success", + "global_stats:num_rides", + "global_stats:avg_ride_length", +] + + +@pytest.mark.integration +@pytest.mark.universal_offline_stores +def test_historical_retrieval_with_validation(environment, universal_data_sources): + store = environment.feature_store + (entities, datasets, data_sources) = universal_data_sources + feature_views = construct_universal_feature_views(data_sources) + storage = environment.data_source_creator.create_saved_dataset_destination() + + store.apply([driver(), customer(), location(), *feature_views.values()]) + + # Added to handle the case that the offline store is remote + store.registry.apply_data_source(storage.to_data_source(), store.config.project) + + # Create two identical retrieval jobs + entity_df = datasets.entity_df.drop( + columns=["order_id", "origin_id", "destination_id"] + ) + reference_job = store.get_historical_features( + entity_df=entity_df, + features=_features, + ) + job = 
store.get_historical_features( + entity_df=entity_df, + features=_features, + ) + + # Save dataset using reference job and retrieve it + store.create_saved_dataset( + from_=reference_job, + name="my_training_dataset", + storage=storage, + allow_overwrite=True, + ) + saved_dataset = store.get_saved_dataset("my_training_dataset") + + # If validation pass there will be no exceptions on this point + reference = saved_dataset.as_reference(name="ref", profiler=configurable_profiler) + job.to_df(validation_reference=reference) + + +@pytest.mark.integration +def test_historical_retrieval_fails_on_validation(environment, universal_data_sources): + store = environment.feature_store + + (entities, datasets, data_sources) = universal_data_sources + feature_views = construct_universal_feature_views(data_sources) + storage = environment.data_source_creator.create_saved_dataset_destination() + + store.apply([driver(), customer(), location(), *feature_views.values()]) + + # Added to handle the case that the offline store is remote + store.registry.apply_data_source(storage.to_data_source(), store.config.project) + + entity_df = datasets.entity_df.drop( + columns=["order_id", "origin_id", "destination_id"] + ) + + reference_job = store.get_historical_features( + entity_df=entity_df, + features=_features, + ) + + store.create_saved_dataset( + from_=reference_job, + name="my_other_dataset", + storage=storage, + allow_overwrite=True, + ) + + job = store.get_historical_features( + entity_df=entity_df, + features=_features, + ) + + ds = store.get_saved_dataset("my_other_dataset") + profiler_expectation_suite = ds.get_profile( + profiler=profiler_with_unrealistic_expectations + ) + + assert len(profiler_expectation_suite.expectation_suite["expectations"]) == 3 + + with pytest.raises(ValidationFailed) as exc_info: + job.to_df( + validation_reference=store.get_saved_dataset( + "my_other_dataset" + ).as_reference(name="ref", profiler=profiler_with_unrealistic_expectations) + ) + + 
failed_expectations = exc_info.value.report.errors + assert len(failed_expectations) == 2 + + assert failed_expectations[0].check_name == "expect_column_max_to_be_between" + assert failed_expectations[0].column_name == "current_balance" + + assert failed_expectations[1].check_name == "expect_column_values_to_be_in_set" + assert failed_expectations[1].column_name == "avg_passenger_count" + + +@pytest.mark.integration +@pytest.mark.universal_offline_stores +def test_logged_features_validation(environment, universal_data_sources): + store = environment.feature_store + + (_, datasets, data_sources) = universal_data_sources + feature_views = construct_universal_feature_views(data_sources) + feature_service = FeatureService( + name="test_service", + features=[ + feature_views.customer[ + ["current_balance", "avg_passenger_count", "lifetime_trip_count"] + ], + feature_views.order[["order_is_success"]], + feature_views.global_fv[["num_rides", "avg_ride_length"]], + ], + logging_config=LoggingConfig( + destination=environment.data_source_creator.create_logged_features_destination() + ), + ) + + storage = environment.data_source_creator.create_saved_dataset_destination() + + store.apply( + [driver(), customer(), location(), feature_service, *feature_views.values()] + ) + + # Added to handle the case that the offline store is remote + store.registry.apply_data_source( + feature_service.logging_config.destination.to_data_source(), + store.config.project, + ) + store.registry.apply_data_source(storage.to_data_source(), store.config.project) + + entity_df = datasets.entity_df.drop( + columns=["order_id", "origin_id", "destination_id"] + ) + + # add some non-existing entities to check NotFound feature handling + for i in range(5): + entity_df = pd.concat( + [ + entity_df, + pd.DataFrame.from_records( + [ + { + "customer_id": 2000 + i, + "driver_id": 6000 + i, + "event_timestamp": make_tzaware(datetime.datetime.now()), + } + ] + ), + ] + ) + + store_fs = 
store.get_feature_service(feature_service.name)
+    reference_dataset = store.create_saved_dataset(
+        from_=store.get_historical_features(
+            entity_df=entity_df, features=store_fs, full_feature_names=True
+        ),
+        name="reference_for_validating_logged_features",
+        storage=storage,
+        allow_overwrite=True,
+    )
+
+    log_source_df = store.get_historical_features(
+        entity_df=entity_df, features=store_fs, full_feature_names=False
+    ).to_df()
+    logs_df = prepare_logs(log_source_df, feature_service, store)
+
+    schema = FeatureServiceLoggingSource(
+        feature_service=feature_service, project=store.project
+    ).get_schema(store._registry)
+    store.write_logged_features(
+        pa.Table.from_pandas(logs_df, schema=schema), source=feature_service
+    )
+
+    def validate():
+        """
+        Return Tuple[succeed, completed]
+        Succeed will be True if no ValidationFailed exception was raised
+        """
+        try:
+            store.validate_logged_features(
+                feature_service,
+                start=logs_df[LOG_TIMESTAMP_FIELD].min(),
+                end=logs_df[LOG_TIMESTAMP_FIELD].max() + datetime.timedelta(seconds=1),
+                reference=reference_dataset.as_reference(
+                    name="ref", profiler=profiler_with_feature_metadata
+                ),
+            )
+        except ValidationFailed:
+            return False, True
+        except Exception:
+            # log table is still being created
+            return False, False
+
+        return True, True
+
+    success = wait_retry_backoff(validate, timeout_secs=30)
+    assert success, "Validation failed (unexpectedly)"
+
+
+@pytest.mark.integration
+def test_e2e_validation_via_cli(environment, universal_data_sources):
+    runner = CliRunner()
+    store = environment.feature_store
+
+    (_, datasets, data_sources) = universal_data_sources
+    feature_views = construct_universal_feature_views(data_sources)
+    feature_service = FeatureService(
+        name="test_service",
+        features=[
+            feature_views.customer[
+                ["current_balance", "avg_passenger_count", "lifetime_trip_count"]
+            ],
+        ],
+        logging_config=LoggingConfig(
+            destination=environment.data_source_creator.create_logged_features_destination()
+        ),
+    ) 
+ store.apply([customer(), feature_service, feature_views.customer]) + + entity_df = datasets.entity_df.drop( + columns=["order_id", "origin_id", "destination_id", "driver_id"] + ) + retrieval_job = store.get_historical_features( + entity_df=entity_df, + features=store.get_feature_service(feature_service.name), + full_feature_names=True, + ) + logs_df = prepare_logs(retrieval_job.to_df(), feature_service, store) + saved_dataset = store.create_saved_dataset( + from_=retrieval_job, + name="reference_for_validating_logged_features", + storage=environment.data_source_creator.create_saved_dataset_destination(), + allow_overwrite=True, + ) + reference = saved_dataset.as_reference( + name="test_reference", profiler=configurable_profiler + ) + + schema = FeatureServiceLoggingSource( + feature_service=feature_service, project=store.project + ).get_schema(store._registry) + store.write_logged_features( + pa.Table.from_pandas(logs_df, schema=schema), source=feature_service + ) + + with runner.local_repo(example_repo_py="", offline_store="file") as local_repo: + local_repo.apply( + [customer(), feature_views.customer, feature_service, reference] + ) + local_repo.registry.apply_saved_dataset(saved_dataset, local_repo.project) + validate_args = [ + "validate", + "--feature-service", + feature_service.name, + "--reference", + reference.name, + (datetime.datetime.now() - datetime.timedelta(days=7)).isoformat(), + datetime.datetime.now().isoformat(), + ] + p = runner.run(validate_args, cwd=local_repo.repo_path) + + assert p.returncode == 0, p.stderr.decode() + assert "Validation successful" in p.stdout.decode(), p.stderr.decode() + + p = runner.run( + ["saved-datasets", "describe", saved_dataset.name], cwd=local_repo.repo_path + ) + assert p.returncode == 0, p.stderr.decode() + + p = runner.run( + ["validation-references", "describe", reference.name], + cwd=local_repo.repo_path, + ) + assert p.returncode == 0, p.stderr.decode() + + p = runner.run( + ["feature-services", "describe", 
feature_service.name], + cwd=local_repo.repo_path, + ) + assert p.returncode == 0, p.stderr.decode() + + # make sure second validation will use cached profile + shutil.rmtree(saved_dataset.storage.file_options.uri) + + # Add some invalid data that would lead to failed validation + invalid_data = pd.DataFrame( + data={ + "customer_id": [0], + "current_balance": [0], + "avg_passenger_count": [0], + "lifetime_trip_count": [0], + "event_timestamp": [ + make_tzaware(_utc_now()) - datetime.timedelta(hours=1) + ], + } + ) + invalid_logs = prepare_logs(invalid_data, feature_service, store) + store.write_logged_features( + pa.Table.from_pandas(invalid_logs, schema=schema), source=feature_service + ) + + p = runner.run(validate_args, cwd=local_repo.repo_path) + assert p.returncode == 1, p.stdout.decode() + assert "Validation failed" in p.stdout.decode(), p.stderr.decode() + + +# Great expectations profilers created for testing + + +@ge_profiler +def configurable_profiler(dataset: PandasDataset) -> ExpectationSuite: + from great_expectations.profile.user_configurable_profiler import ( + UserConfigurableProfiler, + ) + + return UserConfigurableProfiler( + profile_dataset=dataset, + ignored_columns=["event_timestamp"], + excluded_expectations=[ + "expect_table_columns_to_match_ordered_list", + "expect_table_row_count_to_be_between", + ], + value_set_threshold="few", + ).build_suite() + + +@ge_profiler(with_feature_metadata=True) +def profiler_with_feature_metadata(dataset: PandasDataset) -> ExpectationSuite: + from great_expectations.profile.user_configurable_profiler import ( + UserConfigurableProfiler, + ) + + # always present + dataset.expect_column_values_to_be_in_set( + "global_stats__avg_ride_length__status", {FieldStatus.PRESENT} + ) + + # present at least in 70% of rows + dataset.expect_column_values_to_be_in_set( + "customer_profile__current_balance__status", {FieldStatus.PRESENT}, mostly=0.7 + ) + + return UserConfigurableProfiler( + profile_dataset=dataset, + 
ignored_columns=["event_timestamp"] + + [ + c + for c in dataset.columns + if c.endswith("__timestamp") or c.endswith("__status") + ], + excluded_expectations=[ + "expect_table_columns_to_match_ordered_list", + "expect_table_row_count_to_be_between", + ], + value_set_threshold="few", + ).build_suite() + + +@ge_profiler +def profiler_with_unrealistic_expectations(dataset: PandasDataset) -> ExpectationSuite: + # note: there are 4 expectations here and only 3 are returned from the profiler + # need to create dataframe with corrupted data first + df = pd.DataFrame() + df["current_balance"] = [-100] + df["avg_passenger_count"] = [0] + + other_ds = PandasDataset(df) + other_ds.expect_column_max_to_be_between("current_balance", -1000, -100) + other_ds.expect_column_values_to_be_in_set("avg_passenger_count", value_set={0}) + + # this should pass + other_ds.expect_column_min_to_be_between("avg_passenger_count", 0, 1000) + # this should fail + other_ds.expect_column_to_exist("missing random column") + + return other_ds.get_expectation_suite() diff --git a/sdk/python/tests/unit/api/test_api_rest_registry.py b/sdk/python/tests/unit/api/test_api_rest_registry.py index 12e22737f93..afba602ddac 100644 --- a/sdk/python/tests/unit/api/test_api_rest_registry.py +++ b/sdk/python/tests/unit/api/test_api_rest_registry.py @@ -139,7 +139,7 @@ def test_on_demand_feature_view(features_df: pd.DataFrame) -> pd.DataFrame: test_on_demand_feature_view, ] ) - store._registry.apply_saved_dataset(test_saved_dataset, "demo_project") + store.registry.apply_saved_dataset(test_saved_dataset, "demo_project") # Build REST app with registered routes rest_server = RestRegistryServer(store) @@ -773,7 +773,7 @@ def fastapi_test_app_with_multiple_objects(): store.apply(entities + data_sources + feature_views + feature_services) for dataset in saved_datasets: - store._registry.apply_saved_dataset(dataset, "demo_project") + store.registry.apply_saved_dataset(dataset, "demo_project") rest_server = 
RestRegistryServer(store) client = TestClient(rest_server.app) diff --git a/sdk/python/tests/unit/api/test_search_api.py b/sdk/python/tests/unit/api/test_search_api.py index f0d7c3942e8..bce7a30e9fb 100644 --- a/sdk/python/tests/unit/api/test_search_api.py +++ b/sdk/python/tests/unit/api/test_search_api.py @@ -235,7 +235,7 @@ def user_on_demand_features(inputs: dict): user_on_demand_features, ] ) - store._registry.apply_saved_dataset(user_dataset, "test_project") + store.registry.apply_saved_dataset(user_dataset, "test_project") global global_store global_store = store @@ -431,7 +431,7 @@ def multi_project_search_test_app(): description=project_data["description"], tags={"domain": project_data["domain"]}, ) - master_store._registry.apply_project(project_obj) + master_store.registry.apply_project(project_obj) # Create resources for each project and apply them to the shared registry for project_name, project_data in projects_data.items(): @@ -565,19 +565,19 @@ def multi_project_search_test_app(): # Apply all objects for this project directly to the registry for entity in entities: - master_store._registry.apply_entity(entity, project_name) + master_store.registry.apply_entity(entity, project_name) for data_source in data_sources: - master_store._registry.apply_data_source(data_source, project_name) + master_store.registry.apply_data_source(data_source, project_name) for feature_view in feature_views: - master_store._registry.apply_feature_view(feature_view, project_name) + master_store.registry.apply_feature_view(feature_view, project_name) for feature_service in feature_services: - master_store._registry.apply_feature_service(feature_service, project_name) + master_store.registry.apply_feature_service(feature_service, project_name) # Ensure registry is committed - master_store._registry.commit() + master_store.registry.commit() # Build REST app using the master store's registry (contains all projects) rest_server = RestRegistryServer(master_store) @@ -1213,7 +1213,7 
@@ def test_search_on_demand_feature_view(self, search_test_app): """Test searching for on-demand feature views""" # Search by name global global_store - global_store._registry.refresh() + global_store.registry.refresh() response = search_test_app.get("/search?query=user_on_demand_features") assert response.status_code == 200