diff --git a/bigframes/ml/llm.py b/bigframes/ml/llm.py
index 5beb54a32d..8c01159113 100644
--- a/bigframes/ml/llm.py
+++ b/bigframes/ml/llm.py
@@ -19,6 +19,8 @@
 from typing import cast, Literal, Optional, Union
 import warnings
 
+from google.cloud import bigquery
+
 import bigframes
 from bigframes import clients, constants
 from bigframes.core import blocks, log_adapter
@@ -113,6 +115,26 @@ def _create_bqml_model(self):
             session=self.session, connection_name=self.connection_name, options=options
         )
 
+    @classmethod
+    def _from_bq(
+        cls, session: bigframes.Session, model: bigquery.Model
+    ) -> PaLM2TextGenerator:
+        assert model.model_type == "MODEL_TYPE_UNSPECIFIED"
+        assert "remoteModelInfo" in model._properties
+        assert "endpoint" in model._properties["remoteModelInfo"]
+        assert "connection" in model._properties["remoteModelInfo"]
+
+        # Parse the remote model endpoint
+        bqml_endpoint = model._properties["remoteModelInfo"]["endpoint"]
+        model_connection = model._properties["remoteModelInfo"]["connection"]
+        model_endpoint = bqml_endpoint.split("/")[-1]
+
+        text_generator_model = cls(
+            session=session, model_name=model_endpoint, connection_name=model_connection
+        )
+        text_generator_model._bqml_model = core.BqmlModel(session, model)
+        return text_generator_model
+
     def predict(
         self,
         X: Union[bpd.DataFrame, bpd.Series],
@@ -200,6 +222,21 @@ def predict(
 
         return df
 
+    def to_gbq(self, model_name: str, replace: bool = False) -> PaLM2TextGenerator:
+        """Save the model to BigQuery.
+
+        Args:
+            model_name (str):
+                the name of the model.
+            replace (bool, default False):
+                whether to replace if the model already exists. Default to False.
+
+        Returns:
+            PaLM2TextGenerator: saved model."""
+
+        new_model = self._bqml_model.copy(model_name, replace)
+        return new_model.session.read_gbq_model(model_name)
+
 
 @log_adapter.class_logger
 class PaLM2TextEmbeddingGenerator(base.Predictor):
@@ -271,6 +308,26 @@ def _create_bqml_model(self):
             session=self.session, connection_name=self.connection_name, options=options
         )
 
+    @classmethod
+    def _from_bq(
+        cls, session: bigframes.Session, model: bigquery.Model
+    ) -> PaLM2TextEmbeddingGenerator:
+        assert model.model_type == "MODEL_TYPE_UNSPECIFIED"
+        assert "remoteModelInfo" in model._properties
+        assert "endpoint" in model._properties["remoteModelInfo"]
+        assert "connection" in model._properties["remoteModelInfo"]
+
+        # Parse the remote model endpoint
+        bqml_endpoint = model._properties["remoteModelInfo"]["endpoint"]
+        model_connection = model._properties["remoteModelInfo"]["connection"]
+        model_endpoint = bqml_endpoint.split("/")[-1]
+
+        embedding_generator_model = cls(
+            session=session, model_name=model_endpoint, connection_name=model_connection
+        )
+        embedding_generator_model._bqml_model = core.BqmlModel(session, model)
+        return embedding_generator_model
+
     def predict(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
         """Predict the result from input DataFrame.
 
@@ -307,3 +364,20 @@ def predict(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
         )
 
         return df
+
+    def to_gbq(
+        self, model_name: str, replace: bool = False
+    ) -> PaLM2TextEmbeddingGenerator:
+        """Save the model to BigQuery.
+
+        Args:
+            model_name (str):
+                the name of the model.
+            replace (bool, default False):
+                whether to replace if the model already exists. Default to False.
+
+        Returns:
+            PaLM2TextEmbeddingGenerator: saved model."""
+
+        new_model = self._bqml_model.copy(model_name, replace)
+        return new_model.session.read_gbq_model(model_name)
diff --git a/bigframes/ml/loader.py b/bigframes/ml/loader.py
index 805747c49b..4ffde43543 100644
--- a/bigframes/ml/loader.py
+++ b/bigframes/ml/loader.py
@@ -28,6 +28,7 @@
     forecasting,
     imported,
     linear_model,
+    llm,
     pipeline,
 )
 
@@ -47,6 +48,15 @@
     }
 )
 
+_BQML_ENDPOINT_TYPE_MAPPING = MappingProxyType(
+    {
+        llm._TEXT_GENERATOR_BISON_ENDPOINT: llm.PaLM2TextGenerator,
+        llm._TEXT_GENERATOR_BISON_32K_ENDPOINT: llm.PaLM2TextGenerator,
+        llm._EMBEDDING_GENERATOR_GECKO_ENDPOINT: llm.PaLM2TextEmbeddingGenerator,
+        llm._EMBEDDING_GENERATOR_GECKO_MULTILINGUAL_ENDPOINT: llm.PaLM2TextEmbeddingGenerator,
+    }
+)
+
 
 def from_bq(
     session: bigframes.Session, bq_model: bigquery.Model
@@ -62,6 +72,8 @@
     ensemble.RandomForestClassifier,
     imported.TensorFlowModel,
     imported.ONNXModel,
+    llm.PaLM2TextGenerator,
+    llm.PaLM2TextEmbeddingGenerator,
     pipeline.Pipeline,
 ]:
     """Load a BQML model to BigQuery DataFrames ML.
@@ -84,6 +96,17 @@ def _model_from_bq(session: bigframes.Session, bq_model: bigquery.Model):
         return _BQML_MODEL_TYPE_MAPPING[bq_model.model_type]._from_bq(  # type: ignore
             session=session, model=bq_model
         )
+    if (
+        bq_model.model_type == "MODEL_TYPE_UNSPECIFIED"
+        and "remoteModelInfo" in bq_model._properties
+        and "endpoint" in bq_model._properties["remoteModelInfo"]
+    ):
+        # Parse the remote model endpoint
+        bqml_endpoint = bq_model._properties["remoteModelInfo"]["endpoint"]
+        endpoint_model = bqml_endpoint.split("/")[-1]
+        return _BQML_ENDPOINT_TYPE_MAPPING[endpoint_model]._from_bq(  # type: ignore
+            session=session, model=bq_model
+        )
 
     raise NotImplementedError(
         f"Model type {bq_model.model_type} is not yet supported by BigQuery DataFrames. {constants.FEEDBACK_LINK}"
diff --git a/tests/system/small/ml/test_llm.py b/tests/system/small/ml/test_llm.py
index fd1b803eea..805cee4fec 100644
--- a/tests/system/small/ml/test_llm.py
+++ b/tests/system/small/ml/test_llm.py
@@ -17,11 +17,37 @@
 from bigframes.ml import llm
 
 
-def test_create_text_generator_model(palm2_text_generator_model):
+def test_create_text_generator_model(
+    palm2_text_generator_model, dataset_id, bq_connection
+):
     # Model creation doesn't return error
     assert palm2_text_generator_model is not None
     assert palm2_text_generator_model._bqml_model is not None
 
+    # save, load to ensure configuration was kept
+    reloaded_model = palm2_text_generator_model.to_gbq(
+        f"{dataset_id}.temp_text_model", replace=True
+    )
+    assert f"{dataset_id}.temp_text_model" == reloaded_model._bqml_model.model_name
+    assert reloaded_model.model_name == "text-bison"
+    assert reloaded_model.connection_name == bq_connection
+
+
+def test_create_text_generator_32k_model(
+    palm2_text_generator_32k_model, dataset_id, bq_connection
+):
+    # Model creation doesn't return error
+    assert palm2_text_generator_32k_model is not None
+    assert palm2_text_generator_32k_model._bqml_model is not None
+
+    # save, load to ensure configuration was kept
+    reloaded_model = palm2_text_generator_32k_model.to_gbq(
+        f"{dataset_id}.temp_text_model", replace=True
+    )
+    assert f"{dataset_id}.temp_text_model" == reloaded_model._bqml_model.model_name
+    assert reloaded_model.model_name == "text-bison-32k"
+    assert reloaded_model.connection_name == bq_connection
+
 
 @pytest.mark.flaky(retries=2, delay=120)
 def test_create_text_generator_model_default_session(
@@ -152,19 +178,39 @@ def test_text_generator_predict_with_params_success(
     assert all(series.str.len() > 20)
 
 
-def test_create_embedding_generator_model(palm2_embedding_generator_model):
+def test_create_embedding_generator_model(
+    palm2_embedding_generator_model, dataset_id, bq_connection
+):
     # Model creation doesn't return error
     assert palm2_embedding_generator_model is not None
     assert palm2_embedding_generator_model._bqml_model is not None
 
+    # save, load to ensure configuration was kept
+    reloaded_model = palm2_embedding_generator_model.to_gbq(
+        f"{dataset_id}.temp_embedding_model", replace=True
+    )
+    assert f"{dataset_id}.temp_embedding_model" == reloaded_model._bqml_model.model_name
+    assert reloaded_model.model_name == "textembedding-gecko"
+    assert reloaded_model.connection_name == bq_connection
+
 
 def test_create_embedding_generator_multilingual_model(
     palm2_embedding_generator_multilingual_model,
+    dataset_id,
+    bq_connection,
 ):
     # Model creation doesn't return error
     assert palm2_embedding_generator_multilingual_model is not None
     assert palm2_embedding_generator_multilingual_model._bqml_model is not None
 
+    # save, load to ensure configuration was kept
+    reloaded_model = palm2_embedding_generator_multilingual_model.to_gbq(
+        f"{dataset_id}.temp_embedding_model", replace=True
+    )
+    assert f"{dataset_id}.temp_embedding_model" == reloaded_model._bqml_model.model_name
+    assert reloaded_model.model_name == "textembedding-gecko-multilingual"
+    assert reloaded_model.connection_name == bq_connection
+
 
 def test_create_text_embedding_generator_model_defaults(bq_connection):
     import bigframes.pandas as bpd
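For context, a minimal usage sketch (not part of the patch) of the save/load round trip these changes enable. The project, connection, dataset, and model names below are hypothetical placeholders; only `to_gbq`, `read_gbq_model`, and the constructor parameters shown in the diff are assumed.

```python
import bigframes.pandas as bpd
from bigframes.ml.llm import PaLM2TextGenerator

# Create a remote PaLM2 text model; the connection name is a placeholder.
model = PaLM2TextGenerator(
    model_name="text-bison",
    connection_name="my-project.us.my-connection",
)

# Persist the model to BigQuery; to_gbq returns the model reloaded from the saved copy.
saved = model.to_gbq("my_dataset.my_palm2_model", replace=True)
assert saved.model_name == "text-bison"

# Later, read_gbq_model dispatches on the remote endpoint
# (via _BQML_ENDPOINT_TYPE_MAPPING in loader.py) and returns a PaLM2TextGenerator.
reloaded = bpd.read_gbq_model("my_dataset.my_palm2_model")
assert isinstance(reloaded, PaLM2TextGenerator)
```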