From f06f9abd6250eb109e31671121f6e0eb8ce29842 Mon Sep 17 00:00:00 2001 From: Ashley Xu Date: Wed, 3 Jan 2024 01:10:40 +0000 Subject: [PATCH 1/4] feat: add to_gbq() method for LLM models --- bigframes/ml/llm.py | 72 +++++++++++++++++++++++++++++++ bigframes/ml/loader.py | 22 ++++++++++ tests/system/small/ml/test_llm.py | 35 ++++++++++++++- 3 files changed, 127 insertions(+), 2 deletions(-) diff --git a/bigframes/ml/llm.py b/bigframes/ml/llm.py index 5beb54a32d..a5ee5e6e4a 100644 --- a/bigframes/ml/llm.py +++ b/bigframes/ml/llm.py @@ -19,6 +19,8 @@ from typing import cast, Literal, Optional, Union import warnings +from google.cloud import bigquery + import bigframes from bigframes import clients, constants from bigframes.core import blocks, log_adapter @@ -113,6 +115,23 @@ def _create_bqml_model(self): session=self.session, connection_name=self.connection_name, options=options ) + @classmethod + def _from_bq( + cls, session: bigframes.Session, model: bigquery.Model + ) -> PaLM2TextGenerator: + if ( + model.model_type == "MODEL_TYPE_UNSPECIFIED" + and model._properties["remoteModelInfo"]["endpoint"] is not None + ): + # Parse the remote model endpoint + bqml_endpoint = model._properties["remoteModelInfo"]["endpoint"] + endpoint_model = bqml_endpoint.split("/")[-1] + assert endpoint_model in _TEXT_GENERATOR_ENDPOINTS + + text_generator_model = cls(session=session) + text_generator_model._bqml_model = core.BqmlModel(session, model) + return text_generator_model + def predict( self, X: Union[bpd.DataFrame, bpd.Series], @@ -200,6 +219,23 @@ def predict( return df + def to_gbq(self, model_name: str, replace: bool = False) -> PaLM2TextGenerator: + """Save the model to BigQuery. + + Args: + model_name (str): + the name of the model. + replace (bool, default False): + whether to replace if the model already exists. Default to False. + + Returns: + PaLM2TextGenerator: saved model.""" + if not self._bqml_model: + raise RuntimeError("A model must be fitted before it can be saved") + + new_model = self._bqml_model.copy(model_name, replace) + return new_model.session.read_gbq_model(model_name) + @log_adapter.class_logger class PaLM2TextEmbeddingGenerator(base.Predictor): @@ -271,6 +307,23 @@ def _create_bqml_model(self): session=self.session, connection_name=self.connection_name, options=options ) + @classmethod + def _from_bq( + cls, session: bigframes.Session, model: bigquery.Model + ) -> PaLM2TextEmbeddingGenerator: + if ( + model.model_type == "MODEL_TYPE_UNSPECIFIED" + and model._properties["remoteModelInfo"]["endpoint"] is not None + ): + # Parse the remote model endpoint + bqml_endpoint = model._properties["remoteModelInfo"]["endpoint"] + endpoint_model = bqml_endpoint.split("/")[-1] + assert endpoint_model in _EMBEDDING_GENERATOR_ENDPOINTS + + embedding_generator_model = cls(session=session) + embedding_generator_model._bqml_model = core.BqmlModel(session, model) + return embedding_generator_model + def predict(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame: """Predict the result from input DataFrame. @@ -307,3 +360,22 @@ def predict(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame: ) return df + + def to_gbq( + self, model_name: str, replace: bool = False + ) -> PaLM2TextEmbeddingGenerator: + """Save the model to BigQuery. + + Args: + model_name (str): + the name of the model. + replace (bool, default False): + whether to replace if the model already exists. Default to False. + + Returns: + PaLM2TextEmbeddingGenerator: saved model.""" + if not self._bqml_model: + raise RuntimeError("A model must be fitted before it can be saved") + + new_model = self._bqml_model.copy(model_name, replace) + return new_model.session.read_gbq_model(model_name) diff --git a/bigframes/ml/loader.py b/bigframes/ml/loader.py index 805747c49b..f368e93e8d 100644 --- a/bigframes/ml/loader.py +++ b/bigframes/ml/loader.py @@ -28,6 +28,7 @@ forecasting, imported, linear_model, + llm, pipeline, ) @@ -47,6 +48,15 @@ } ) +_BQML_ENDPOINT_TYPE_MAPPING = MappingProxyType( + { + llm._TEXT_GENERATOR_BISON_ENDPOINT: llm.PaLM2TextGenerator, + llm._TEXT_GENERATOR_BISON_32K_ENDPOINT: llm.PaLM2TextGenerator, + llm._EMBEDDING_GENERATOR_GECKO_ENDPOINT: llm.PaLM2TextEmbeddingGenerator, + llm._EMBEDDING_GENERATOR_GECKO_MULTILINGUAL_ENDPOINT: llm.PaLM2TextEmbeddingGenerator, + } +) + def from_bq( session: bigframes.Session, bq_model: bigquery.Model @@ -62,6 +72,8 @@ def from_bq( ensemble.RandomForestClassifier, imported.TensorFlowModel, imported.ONNXModel, + llm.PaLM2TextGenerator, + llm.PaLM2TextEmbeddingGenerator, pipeline.Pipeline, ]: """Load a BQML model to BigQuery DataFrames ML. @@ -84,6 +96,16 @@ def _model_from_bq(session: bigframes.Session, bq_model: bigquery.Model): return _BQML_MODEL_TYPE_MAPPING[bq_model.model_type]._from_bq( # type: ignore session=session, model=bq_model ) + if ( + bq_model.model_type == "MODEL_TYPE_UNSPECIFIED" + and bq_model._properties["remoteModelInfo"]["endpoint"] is not None + ): + # Parse the remote model endpoint + bqml_endpoint = bq_model._properties["remoteModelInfo"]["endpoint"] + endpoint_model = bqml_endpoint.split("/")[-1] + return _BQML_ENDPOINT_TYPE_MAPPING[endpoint_model]._from_bq( # type: ignore + session=session, model=bq_model + ) raise NotImplementedError( f"Model type {bq_model.model_type} is not yet supported by BigQuery DataFrames. {constants.FEEDBACK_LINK}" diff --git a/tests/system/small/ml/test_llm.py b/tests/system/small/ml/test_llm.py index fd1b803eea..42d1dc3cad 100644 --- a/tests/system/small/ml/test_llm.py +++ b/tests/system/small/ml/test_llm.py @@ -17,11 +17,29 @@ from bigframes.ml import llm -def test_create_text_generator_model(palm2_text_generator_model): +def test_create_text_generator_model(palm2_text_generator_model, dataset_id): # Model creation doesn't return error assert palm2_text_generator_model is not None assert palm2_text_generator_model._bqml_model is not None + # save, load to ensure configuration was kept + reloaded_model = palm2_text_generator_model.to_gbq( + f"{dataset_id}.temp_text_model", replace=True + ) + assert f"{dataset_id}.temp_text_model" in reloaded_model._bqml_model.model_name + + +def test_create_text_generator_32k_model(palm2_text_generator_32k_model, dataset_id): + # Model creation doesn't return error + assert palm2_text_generator_32k_model is not None + assert palm2_text_generator_32k_model._bqml_model is not None + + # save, load to ensure configuration was kept + reloaded_model = palm2_text_generator_32k_model.to_gbq( + f"{dataset_id}.temp_text_model", replace=True + ) + assert f"{dataset_id}.temp_text_model" in reloaded_model._bqml_model.model_name + @pytest.mark.flaky(retries=2, delay=120) def test_create_text_generator_model_default_session( @@ -152,19 +170,32 @@ def test_text_generator_predict_with_params_success( assert all(series.str.len() > 20) -def test_create_embedding_generator_model(palm2_embedding_generator_model): +def test_create_embedding_generator_model(palm2_embedding_generator_model, dataset_id): # Model creation doesn't return error assert palm2_embedding_generator_model is not None assert palm2_embedding_generator_model._bqml_model is not None + # save, load to ensure configuration was kept + reloaded_model = palm2_embedding_generator_model.to_gbq( + f"{dataset_id}.temp_embedding_model", replace=True + ) + assert f"{dataset_id}.temp_embedding_model" in reloaded_model._bqml_model.model_name + def test_create_embedding_generator_multilingual_model( palm2_embedding_generator_multilingual_model, + dataset_id, ): # Model creation doesn't return error assert palm2_embedding_generator_multilingual_model is not None assert palm2_embedding_generator_multilingual_model._bqml_model is not None + # save, load to ensure configuration was kept + reloaded_model = palm2_embedding_generator_multilingual_model.to_gbq( + f"{dataset_id}.temp_embedding_model", replace=True + ) + assert f"{dataset_id}.temp_embedding_model" in reloaded_model._bqml_model.model_name + def test_create_text_embedding_generator_model_defaults(bq_connection): import bigframes.pandas as bpd From fd4d804f96bb22f2cfd52b9817ac0f635b0431ae Mon Sep 17 00:00:00 2001 From: Ashley Xu Date: Thu, 4 Jan 2024 20:07:31 +0000 Subject: [PATCH 2/4] address comments --- bigframes/ml/llm.py | 46 +++++++++++++++---------------- bigframes/ml/loader.py | 9 +++--- tests/system/small/ml/test_llm.py | 4 +-- 3 files changed, 28 insertions(+), 31 deletions(-) diff --git a/bigframes/ml/llm.py b/bigframes/ml/llm.py index a5ee5e6e4a..f4322272a6 100644 --- a/bigframes/ml/llm.py +++ b/bigframes/ml/llm.py @@ -119,16 +119,17 @@ def _create_bqml_model(self): def _from_bq( cls, session: bigframes.Session, model: bigquery.Model ) -> PaLM2TextGenerator: - if ( - model.model_type == "MODEL_TYPE_UNSPECIFIED" - and model._properties["remoteModelInfo"]["endpoint"] is not None - ): - # Parse the remote model endpoint - bqml_endpoint = model._properties["remoteModelInfo"]["endpoint"] - endpoint_model = bqml_endpoint.split("/")[-1] - assert endpoint_model in _TEXT_GENERATOR_ENDPOINTS - - text_generator_model = cls(session=session) + assert model.model_type == "MODEL_TYPE_UNSPECIFIED" + assert model._properties.get("remoteModelInfo").get("endpoint") is not None + + # Parse the remote model endpoint + bqml_endpoint = model._properties.get("remoteModelInfo").get("endpoint") + model_connection = model._properties.get("remoteModelInfo").get("connection") + model_endpoint = bqml_endpoint.split("/")[-1] + + text_generator_model = cls( + session=session, model_name=model_endpoint, connection_name=model_connection + ) text_generator_model._bqml_model = core.BqmlModel(session, model) return text_generator_model @@ -230,8 +231,6 @@ def to_gbq(self, model_name: str, replace: bool = False) -> PaLM2TextGenerator: Returns: PaLM2TextGenerator: saved model.""" - if not self._bqml_model: - raise RuntimeError("A model must be fitted before it can be saved") new_model = self._bqml_model.copy(model_name, replace) return new_model.session.read_gbq_model(model_name) @@ -311,16 +310,17 @@ def _create_bqml_model(self): def _from_bq( cls, session: bigframes.Session, model: bigquery.Model ) -> PaLM2TextEmbeddingGenerator: - if ( - model.model_type == "MODEL_TYPE_UNSPECIFIED" - and model._properties["remoteModelInfo"]["endpoint"] is not None - ): - # Parse the remote model endpoint - bqml_endpoint = model._properties["remoteModelInfo"]["endpoint"] - endpoint_model = bqml_endpoint.split("/")[-1] - assert endpoint_model in _EMBEDDING_GENERATOR_ENDPOINTS - - embedding_generator_model = cls(session=session) + assert model.model_type == "MODEL_TYPE_UNSPECIFIED" + assert model._properties.get("remoteModelInfo").get("endpoint") is not None + + # Parse the remote model endpoint + bqml_endpoint = model._properties.get("remoteModelInfo").get("endpoint") + model_connection = model._properties.get("remoteModelInfo").get("connection") + model_endpoint = bqml_endpoint.split("/")[-1] + + embedding_generator_model = cls( + session=session, model_name=model_endpoint, connection_name=model_connection + ) embedding_generator_model._bqml_model = core.BqmlModel(session, model) return embedding_generator_model @@ -374,8 +374,6 @@ def to_gbq( Returns: PaLM2TextEmbeddingGenerator: saved model.""" - if not self._bqml_model: - raise RuntimeError("A model must be fitted before it can be saved") new_model = self._bqml_model.copy(model_name, replace) return new_model.session.read_gbq_model(model_name) diff --git a/bigframes/ml/loader.py b/bigframes/ml/loader.py index f368e93e8d..c73107ae0e 100644 --- a/bigframes/ml/loader.py +++ b/bigframes/ml/loader.py @@ -96,12 +96,11 @@ def _model_from_bq(session: bigframes.Session, bq_model: bigquery.Model): return _BQML_MODEL_TYPE_MAPPING[bq_model.model_type]._from_bq( # type: ignore session=session, model=bq_model ) - if ( - bq_model.model_type == "MODEL_TYPE_UNSPECIFIED" - and bq_model._properties["remoteModelInfo"]["endpoint"] is not None - ): + if bq_model.model_type == "MODEL_TYPE_UNSPECIFIED" and bq_model._properties.get( + "remoteModelInfo" + ).get("endpoint"): # Parse the remote model endpoint - bqml_endpoint = bq_model._properties["remoteModelInfo"]["endpoint"] + bqml_endpoint = bq_model._properties.get("remoteModelInfo").get("endpoint") endpoint_model = bqml_endpoint.split("/")[-1] return _BQML_ENDPOINT_TYPE_MAPPING[endpoint_model]._from_bq( # type: ignore session=session, model=bq_model diff --git a/tests/system/small/ml/test_llm.py b/tests/system/small/ml/test_llm.py index 42d1dc3cad..1ab1167a1d 100644 --- a/tests/system/small/ml/test_llm.py +++ b/tests/system/small/ml/test_llm.py @@ -26,7 +26,7 @@ def test_create_text_generator_model(palm2_text_generator_model, dataset_id): reloaded_model = palm2_text_generator_model.to_gbq( f"{dataset_id}.temp_text_model", replace=True ) - assert f"{dataset_id}.temp_text_model" in reloaded_model._bqml_model.model_name + assert f"{dataset_id}.temp_text_model" == reloaded_model._bqml_model.model_name def test_create_text_generator_32k_model(palm2_text_generator_32k_model, dataset_id): @@ -38,7 +38,7 @@ def test_create_text_generator_32k_model(palm2_text_generator_32k_model, dataset reloaded_model = palm2_text_generator_32k_model.to_gbq( f"{dataset_id}.temp_text_model", replace=True ) - assert f"{dataset_id}.temp_text_model" in reloaded_model._bqml_model.model_name + assert f"{dataset_id}.temp_text_model" == reloaded_model._bqml_model.model_name @pytest.mark.flaky(retries=2, delay=120) From 7057bdda35acdf0a1c91867277e879e829db4d6c Mon Sep 17 00:00:00 2001 From: Ashley Xu Date: Fri, 5 Jan 2024 00:04:13 +0000 Subject: [PATCH 3/4] address additional comments --- bigframes/ml/llm.py | 27 +++++++++++++++++++++------ bigframes/ml/loader.py | 10 ++++++---- tests/system/small/ml/test_llm.py | 25 ++++++++++++++++++++----- 3 files changed, 47 insertions(+), 15 deletions(-) diff --git a/bigframes/ml/llm.py b/bigframes/ml/llm.py index f4322272a6..f444c5964c 100644 --- a/bigframes/ml/llm.py +++ b/bigframes/ml/llm.py @@ -120,11 +120,19 @@ def _from_bq( cls, session: bigframes.Session, model: bigquery.Model ) -> PaLM2TextGenerator: assert model.model_type == "MODEL_TYPE_UNSPECIFIED" - assert model._properties.get("remoteModelInfo").get("endpoint") is not None + assert model._properties["remoteModelInfo"]["endpoint"] is not None # Parse the remote model endpoint - bqml_endpoint = model._properties.get("remoteModelInfo").get("endpoint") - model_connection = model._properties.get("remoteModelInfo").get("connection") + if ( + "remoteModelInfo" in model._properties + and "endpoint" in model._properties["remoteModelInfo"] + ): + bqml_endpoint = model._properties["remoteModelInfo"]["endpoint"] + if ( + "remoteModelInfo" in model._properties + and "connection" in model._properties["remoteModelInfo"] + ): + model_connection = model._properties["remoteModelInfo"]["connection"] model_endpoint = bqml_endpoint.split("/")[-1] text_generator_model = cls( @@ -311,11 +319,18 @@ def _from_bq( cls, session: bigframes.Session, model: bigquery.Model ) -> PaLM2TextEmbeddingGenerator: assert model.model_type == "MODEL_TYPE_UNSPECIFIED" - assert model._properties.get("remoteModelInfo").get("endpoint") is not None # Parse the remote model endpoint - bqml_endpoint = model._properties.get("remoteModelInfo").get("endpoint") - model_connection = model._properties.get("remoteModelInfo").get("connection") + if ( + "remoteModelInfo" in model._properties + and "endpoint" in model._properties["remoteModelInfo"] + ): + bqml_endpoint = model._properties["remoteModelInfo"]["endpoint"] + if ( + "remoteModelInfo" in model._properties + and "connection" in model._properties["remoteModelInfo"] + ): + model_connection = model._properties["remoteModelInfo"]["connection"] model_endpoint = bqml_endpoint.split("/")[-1] embedding_generator_model = cls( diff --git a/bigframes/ml/loader.py b/bigframes/ml/loader.py index c73107ae0e..b64a22ee32 100644 --- a/bigframes/ml/loader.py +++ b/bigframes/ml/loader.py @@ -96,11 +96,13 @@ def _model_from_bq(session: bigframes.Session, bq_model: bigquery.Model): return _BQML_MODEL_TYPE_MAPPING[bq_model.model_type]._from_bq( # type: ignore session=session, model=bq_model ) - if bq_model.model_type == "MODEL_TYPE_UNSPECIFIED" and bq_model._properties.get( - "remoteModelInfo" - ).get("endpoint"): + if ( + bq_model.model_type == "MODEL_TYPE_UNSPECIFIED" + and bq_model._properties["remoteModelInfo"] + and bq_model._properties["remoteModelInfo"]["endpoint"] + ): # Parse the remote model endpoint - bqml_endpoint = bq_model._properties.get("remoteModelInfo").get("endpoint") + bqml_endpoint = bq_model._properties["remoteModelInfo"]["endpoint"] endpoint_model = bqml_endpoint.split("/")[-1] return _BQML_ENDPOINT_TYPE_MAPPING[endpoint_model]._from_bq( # type: ignore session=session, model=bq_model diff --git a/tests/system/small/ml/test_llm.py b/tests/system/small/ml/test_llm.py index 1ab1167a1d..805cee4fec 100644 --- a/tests/system/small/ml/test_llm.py +++ b/tests/system/small/ml/test_llm.py @@ -17,7 +17,9 @@ from bigframes.ml import llm -def test_create_text_generator_model(palm2_text_generator_model, dataset_id): +def test_create_text_generator_model( + palm2_text_generator_model, dataset_id, bq_connection +): # Model creation doesn't return error assert palm2_text_generator_model is not None assert palm2_text_generator_model._bqml_model is not None @@ -27,9 +29,13 @@ def test_create_text_generator_model(palm2_text_generator_model, dataset_id): f"{dataset_id}.temp_text_model", replace=True ) assert f"{dataset_id}.temp_text_model" == reloaded_model._bqml_model.model_name + assert reloaded_model.model_name == "text-bison" + assert reloaded_model.connection_name == bq_connection -def test_create_text_generator_32k_model(palm2_text_generator_32k_model, dataset_id): +def test_create_text_generator_32k_model( + palm2_text_generator_32k_model, dataset_id, bq_connection +): # Model creation doesn't return error assert palm2_text_generator_32k_model is not None assert palm2_text_generator_32k_model._bqml_model is not None @@ -39,6 +45,8 @@ def test_create_text_generator_32k_model(palm2_text_generator_32k_model, dataset f"{dataset_id}.temp_text_model", replace=True ) assert f"{dataset_id}.temp_text_model" == reloaded_model._bqml_model.model_name + assert reloaded_model.model_name == "text-bison-32k" + assert reloaded_model.connection_name == bq_connection @pytest.mark.flaky(retries=2, delay=120) @@ -170,7 +178,9 @@ def test_text_generator_predict_with_params_success( assert all(series.str.len() > 20) -def test_create_embedding_generator_model(palm2_embedding_generator_model, dataset_id): +def test_create_embedding_generator_model( + palm2_embedding_generator_model, dataset_id, bq_connection +): # Model creation doesn't return error assert palm2_embedding_generator_model is not None assert palm2_embedding_generator_model._bqml_model is not None @@ -179,12 +189,15 @@ def test_create_embedding_generator_model(palm2_embedding_generator_model, datas reloaded_model = palm2_embedding_generator_model.to_gbq( f"{dataset_id}.temp_embedding_model", replace=True ) - assert f"{dataset_id}.temp_embedding_model" in reloaded_model._bqml_model.model_name + assert f"{dataset_id}.temp_embedding_model" == reloaded_model._bqml_model.model_name + assert reloaded_model.model_name == "textembedding-gecko" + assert reloaded_model.connection_name == bq_connection def test_create_embedding_generator_multilingual_model( palm2_embedding_generator_multilingual_model, dataset_id, + bq_connection, ): # Model creation doesn't return error assert palm2_embedding_generator_multilingual_model is not None @@ -194,7 +207,9 @@ def test_create_embedding_generator_multilingual_model( reloaded_model = palm2_embedding_generator_multilingual_model.to_gbq( f"{dataset_id}.temp_embedding_model", replace=True ) - assert f"{dataset_id}.temp_embedding_model" in reloaded_model._bqml_model.model_name + assert f"{dataset_id}.temp_embedding_model" == reloaded_model._bqml_model.model_name + assert reloaded_model.model_name == "textembedding-gecko-multilingual" + assert reloaded_model.connection_name == bq_connection def test_create_text_embedding_generator_model_defaults(bq_connection): From 2bbb786f836f9567a705129a3a584e2f5ce5f0a5 Mon Sep 17 00:00:00 2001 From: Ashley Xu Date: Mon, 8 Jan 2024 21:53:48 +0000 Subject: [PATCH 4/4] address comments --- bigframes/ml/llm.py | 31 ++++++++++--------------------- bigframes/ml/loader.py | 4 ++-- 2 files changed, 12 insertions(+), 23 deletions(-) diff --git a/bigframes/ml/llm.py b/bigframes/ml/llm.py index f444c5964c..8c01159113 100644 --- a/bigframes/ml/llm.py +++ b/bigframes/ml/llm.py @@ -120,19 +120,13 @@ def _from_bq( cls, session: bigframes.Session, model: bigquery.Model ) -> PaLM2TextGenerator: assert model.model_type == "MODEL_TYPE_UNSPECIFIED" - assert model._properties["remoteModelInfo"]["endpoint"] is not None + assert "remoteModelInfo" in model._properties + assert "endpoint" in model._properties["remoteModelInfo"] + assert "connection" in model._properties["remoteModelInfo"] # Parse the remote model endpoint - if ( - "remoteModelInfo" in model._properties - and "endpoint" in model._properties["remoteModelInfo"] - ): - bqml_endpoint = model._properties["remoteModelInfo"]["endpoint"] - if ( - "remoteModelInfo" in model._properties - and "connection" in model._properties["remoteModelInfo"] - ): - model_connection = model._properties["remoteModelInfo"]["connection"] + bqml_endpoint = model._properties["remoteModelInfo"]["endpoint"] + model_connection = model._properties["remoteModelInfo"]["connection"] model_endpoint = bqml_endpoint.split("/")[-1] text_generator_model = cls( @@ -319,18 +313,13 @@ def _from_bq( cls, session: bigframes.Session, model: bigquery.Model ) -> PaLM2TextEmbeddingGenerator: assert model.model_type == "MODEL_TYPE_UNSPECIFIED" + assert "remoteModelInfo" in model._properties + assert "endpoint" in model._properties["remoteModelInfo"] + assert "connection" in model._properties["remoteModelInfo"] # Parse the remote model endpoint - if ( - "remoteModelInfo" in model._properties - and "endpoint" in model._properties["remoteModelInfo"] - ): - bqml_endpoint = model._properties["remoteModelInfo"]["endpoint"] - if ( - "remoteModelInfo" in model._properties - and "connection" in model._properties["remoteModelInfo"] - ): - model_connection = model._properties["remoteModelInfo"]["connection"] + bqml_endpoint = model._properties["remoteModelInfo"]["endpoint"] + model_connection = model._properties["remoteModelInfo"]["connection"] model_endpoint = bqml_endpoint.split("/")[-1] embedding_generator_model = cls( diff --git a/bigframes/ml/loader.py b/bigframes/ml/loader.py index b64a22ee32..4ffde43543 100644 --- a/bigframes/ml/loader.py +++ b/bigframes/ml/loader.py @@ -98,8 +98,8 @@ def _model_from_bq(session: bigframes.Session, bq_model: bigquery.Model): ) if ( bq_model.model_type == "MODEL_TYPE_UNSPECIFIED" - and bq_model._properties["remoteModelInfo"] - and bq_model._properties["remoteModelInfo"]["endpoint"] + and "remoteModelInfo" in bq_model._properties + and "endpoint" in bq_model._properties["remoteModelInfo"] ): # Parse the remote model endpoint bqml_endpoint = bq_model._properties["remoteModelInfo"]["endpoint"]