diff --git a/bigframes/ml/llm.py b/bigframes/ml/llm.py index 2e5a9a1e5e..3cfc28e61f 100644 --- a/bigframes/ml/llm.py +++ b/bigframes/ml/llm.py @@ -24,12 +24,14 @@ from bigframes.ml import base, core, globals, utils import bigframes.pandas as bpd -_REMOTE_TEXT_GENERATOR_MODEL_CODE = "CLOUD_AI_LARGE_LANGUAGE_MODEL_V1" -_REMOTE_TEXT_GENERATOR_32K_MODEL_CODE = "text-bison-32k" +_REMOTE_TEXT_GENERATOR_MODEL_ENDPOINT = "text-bison" +_REMOTE_TEXT_GENERATOR_32K_MODEL_ENDPOINT = "text-bison-32k" _TEXT_GENERATE_RESULT_COLUMN = "ml_generate_text_llm_result" -_REMOTE_EMBEDDING_GENERATOR_MODEL_CODE = "CLOUD_AI_TEXT_EMBEDDING_MODEL_V1" -_REMOTE_EMBEDDING_GENERATOR_MUlTILINGUAL_MODEL_CODE = "textembedding-gecko-multilingual" +_REMOTE_EMBEDDING_GENERATOR_MODEL_ENDPOINT = "textembedding-gecko" +_REMOTE_EMBEDDING_GENERATOR_MUlTILINGUAL_MODEL_ENDPOINT = ( + "textembedding-gecko-multilingual" +) _EMBED_TEXT_RESULT_COLUMN = "text_embedding" @@ -88,14 +90,18 @@ def _create_bqml_model(self): connection_id=connection_name_parts[2], iam_role="aiplatform.user", ) - if self.model_name == "text-bison": + if self.model_name == _REMOTE_TEXT_GENERATOR_MODEL_ENDPOINT: options = { - "remote_service_type": _REMOTE_TEXT_GENERATOR_MODEL_CODE, + "endpoint": _REMOTE_TEXT_GENERATOR_MODEL_ENDPOINT, } - else: + elif self.model_name == _REMOTE_TEXT_GENERATOR_32K_MODEL_ENDPOINT: options = { - "endpoint": _REMOTE_TEXT_GENERATOR_32K_MODEL_CODE, + "endpoint": _REMOTE_TEXT_GENERATOR_32K_MODEL_ENDPOINT, } + else: + raise ValueError( + f"Model name {self.model_name} is not supported. We only support {_REMOTE_TEXT_GENERATOR_MODEL_ENDPOINT} and {_REMOTE_TEXT_GENERATOR_32K_MODEL_ENDPOINT}." + ) return self._bqml_model_factory.create_remote_model( session=self.session, connection_name=self.connection_name, options=options ) @@ -240,12 +246,16 @@ def _create_bqml_model(self): ) if self.model_name == "textembedding-gecko": options = { - "remote_service_type": _REMOTE_EMBEDDING_GENERATOR_MODEL_CODE, + "endpoint": _REMOTE_EMBEDDING_GENERATOR_MODEL_ENDPOINT, } - else: + elif self.model_name == _REMOTE_EMBEDDING_GENERATOR_MUlTILINGUAL_MODEL_ENDPOINT: options = { - "endpoint": _REMOTE_EMBEDDING_GENERATOR_MUlTILINGUAL_MODEL_CODE, + "endpoint": _REMOTE_EMBEDDING_GENERATOR_MUlTILINGUAL_MODEL_ENDPOINT, } + else: + raise ValueError( + f"Model name {self.model_name} is not supported. We only support {_REMOTE_EMBEDDING_GENERATOR_MODEL_ENDPOINT} and {_REMOTE_EMBEDDING_GENERATOR_MUlTILINGUAL_MODEL_ENDPOINT}." + ) return self._bqml_model_factory.create_remote_model( session=self.session, connection_name=self.connection_name, options=options