From 4d1722928fe6fb86baae3cae7b91850b7e5ff08e Mon Sep 17 00:00:00 2001 From: Garrett Wu Date: Wed, 4 Oct 2023 19:28:46 +0000 Subject: [PATCH 1/6] feat: use default session connection --- bigframes/session.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/bigframes/session.py b/bigframes/session.py index ac48c977cb..0c67c779a8 100644 --- a/bigframes/session.py +++ b/bigframes/session.py @@ -97,6 +97,8 @@ _BIGQUERYCONNECTION_REGIONAL_ENDPOINT = "{location}-bigqueryconnection.googleapis.com" _BIGQUERYSTORAGE_REGIONAL_ENDPOINT = "{location}-bigquerystorage.googleapis.com" +_BIGFRAMES_DEFAULT_CONNECTION_ID = "bigframes-default-connection" + _MAX_CLUSTER_COLUMNS = 4 # TODO(swast): Need to connect to regional endpoints when performing remote @@ -321,7 +323,10 @@ def __init__( ), ) - self._bq_connection = context.bq_connection + self._bq_connection = ( + context.bq_connection + or f"{self.bqclient.project}.{self._location}.{_BIGFRAMES_DEFAULT_CONNECTION_ID}" + ) # Now that we're starting the session, don't allow the options to be # changed. From b0e92154158fab3b2e6634a75ea7f3781b981985 Mon Sep 17 00:00:00 2001 From: Garrett Wu Date: Wed, 4 Oct 2023 22:41:04 +0000 Subject: [PATCH 2/6] update docs --- bigframes/_config/bigquery_options.py | 4 +++- bigframes/ml/llm.py | 5 +++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/bigframes/_config/bigquery_options.py b/bigframes/_config/bigquery_options.py index ea1864ed5f..da71ac5efd 100644 --- a/bigframes/_config/bigquery_options.py +++ b/bigframes/_config/bigquery_options.py @@ -83,12 +83,14 @@ def project(self, value: Optional[str]): @property def bq_connection(self) -> Optional[str]: - """Name of the BigQuery connection to use. + """Name of the BigQuery connection to use. Should be of the form ... You should either have the connection already created in the location you have chosen, or you should have the Project IAM Admin role to enable the service to create the connection for you if you need it. + + If this option isn't provided, session will use its default project/location/connection_id as default connection. """ return self._bq_connection diff --git a/bigframes/ml/llm.py b/bigframes/ml/llm.py index c86e5fb3b6..f698fea4d1 100644 --- a/bigframes/ml/llm.py +++ b/bigframes/ml/llm.py @@ -38,8 +38,9 @@ class PaLM2TextGenerator(base.Predictor): session (bigframes.Session or None): BQ session to create the model. If None, use the global default session. connection_name (str or None): - connection to connect with remote service. str of the format ... - if None, use default connection in session context. + connection to connect with remote service. str of the format ... + if None, use default connection in session context. BigQuery DataFrame will try to create the connection and attach + permission if the connection isn't fully setup. """ def __init__( From 4ee895edb57b30ed1043478e2776b20347f53093 Mon Sep 17 00:00:00 2001 From: Garrett Wu Date: Tue, 10 Oct 2023 00:22:11 +0000 Subject: [PATCH 3/6] supplement partial connection_name --- bigframes/_config/bigquery_options.py | 2 +- bigframes/clients.py | 24 ++++++++ bigframes/ml/llm.py | 18 +++++- bigframes/remote_function.py | 71 ++++++++++------------ bigframes/session.py | 13 ++-- tests/system/small/ml/test_llm.py | 33 +++++++++- tests/system/small/test_remote_function.py | 30 +++++++++ tests/unit/test_clients.py | 45 ++++++++++++++ 8 files changed, 185 insertions(+), 51 deletions(-) create mode 100644 tests/unit/test_clients.py diff --git a/bigframes/_config/bigquery_options.py b/bigframes/_config/bigquery_options.py index da71ac5efd..eb56de826a 100644 --- a/bigframes/_config/bigquery_options.py +++ b/bigframes/_config/bigquery_options.py @@ -90,7 +90,7 @@ def bq_connection(self) -> Optional[str]: Admin role to enable the service to create the connection for you if you need it. - If this option isn't provided, session will use its default project/location/connection_id as default connection. + If this option isn't provided, or project or location aren't provided, session will use its default project/location/connection_id as default connection. """ return self._bq_connection diff --git a/bigframes/clients.py b/bigframes/clients.py index b60fcba04a..dcac611e8c 100644 --- a/bigframes/clients.py +++ b/bigframes/clients.py @@ -29,6 +29,8 @@ ) logger = logging.getLogger(__name__) +_BIGFRAMES_DEFAULT_CONNECTION_ID = "bigframes-default-connection" + class BqConnectionManager: """Manager to handle operations with BQ connections.""" @@ -162,3 +164,25 @@ def _get_service_account_if_connection_exists( pass return service_account + + +def get_connection_name_full( + connection_name: Optional[str], default_project: str, default_location: str +) -> str: + """Retrieve the full connection name of the form ... + Use default project, location or connection_id when any of them are missing.""" + if connection_name is None: + return ( + f"{default_project}.{default_location}.{_BIGFRAMES_DEFAULT_CONNECTION_ID}" + ) + + if connection_name.count(".") == 2: + return connection_name + + if connection_name.count(".") == 1: + return f"{default_project}.{connection_name}" + + if connection_name.count(".") == 0: + return f"{default_project}.{default_location}.{connection_name}" + + raise ValueError(f"Invalid connection name format: {connection_name}.") diff --git a/bigframes/ml/llm.py b/bigframes/ml/llm.py index f698fea4d1..a61dd34e6d 100644 --- a/bigframes/ml/llm.py +++ b/bigframes/ml/llm.py @@ -49,7 +49,14 @@ def __init__( connection_name: Optional[str] = None, ): self.session = session or bpd.get_global_session() - self.connection_name = connection_name or self.session._bq_connection + + connection_name = connection_name or self.session._bq_connection + self.connection_name = clients.get_connection_name_full( + connection_name, + default_project=self.session._project, + default_location=self.session._location, + ) + self._bq_connection_manager = clients.BqConnectionManager( self.session.bqconnectionclient, self.session.resourcemanagerclient ) @@ -181,7 +188,14 @@ def __init__( connection_name: Optional[str] = None, ): self.session = session or bpd.get_global_session() - self.connection_name = connection_name or self.session._bq_connection + + connection_name = connection_name or self.session._bq_connection + self.connection_name = clients.get_connection_name_full( + connection_name, + default_project=self.session._project, + default_location=self.session._location, + ) + self._bq_connection_manager = clients.BqConnectionManager( self.session.bqconnectionclient, self.session.resourcemanagerclient ) diff --git a/bigframes/remote_function.py b/bigframes/remote_function.py index 6fc2f8e59f..7b04ebd44e 100644 --- a/bigframes/remote_function.py +++ b/bigframes/remote_function.py @@ -695,9 +695,12 @@ def remote_function( persistent name. """ + import bigframes.pandas as bpd + + session = session or bpd.get_global_session() # A BigQuery client is required to perform BQ operations - if not bigquery_client and session: + if not bigquery_client: bigquery_client = session.bqclient if not bigquery_client: raise ValueError( @@ -706,7 +709,7 @@ def remote_function( ) # A BigQuery connection client is required to perform BQ connection operations - if not bigquery_connection_client and session: + if not bigquery_connection_client: bigquery_connection_client = session.bqconnectionclient if not bigquery_connection_client: raise ValueError( @@ -716,8 +719,7 @@ def remote_function( # A cloud functions client is required to perform cloud functions operations if not cloud_functions_client: - if session: - cloud_functions_client = session.cloudfunctionsclient + cloud_functions_client = session.cloudfunctionsclient if not cloud_functions_client: raise ValueError( "A cloud functions client must be provided, either directly or via session. " @@ -726,8 +728,7 @@ def remote_function( # A resource manager client is required to get/set IAM operations if not resource_manager_client: - if session: - resource_manager_client = session.resourcemanagerclient + resource_manager_client = session.resourcemanagerclient if not resource_manager_client: raise ValueError( "A resource manager client must be provided, either directly or via session. " @@ -740,11 +741,11 @@ def remote_function( dataset_ref = bigquery.DatasetReference.from_string( dataset, default_project=bigquery_client.project ) - elif session: + else: dataset_ref = bigquery.DatasetReference.from_string( session._session_dataset_id, default_project=bigquery_client.project ) - else: + if not dataset_ref: raise ValueError( "Project and dataset must be provided, either directly or via session. " f"{constants.FEEDBACK_LINK}" @@ -756,40 +757,30 @@ def remote_function( # A connection is required for BQ remote function # https://cloud.google.com/bigquery/docs/reference/standard-sql/remote-functions#create_a_remote_function - if not bigquery_connection and session: - bigquery_connection = session._bq_connection # type: ignore if not bigquery_connection: + bigquery_connection = session._bq_connection # type: ignore + + bigquery_connection = clients.get_connection_name_full( + bigquery_connection, + default_project=dataset_ref.project, + default_location=bq_location, + ) + # Guaranteed to be the form of .. + ( + gcp_project_id, + bq_connection_location, + bq_connection_id, + ) = bigquery_connection.split(".") + if gcp_project_id.casefold() != dataset_ref.project.casefold(): raise ValueError( - "BigQuery connection must be provided, either directly or via session. " - f"{constants.FEEDBACK_LINK}" + "The project_id does not match BigQuery connection gcp_project_id: " + f"{dataset_ref.project}." + ) + if bq_connection_location.casefold() != bq_location.casefold(): + raise ValueError( + "The location does not match BigQuery connection location: " + f"{bq_location}." ) - - # Check connection_id with `LOCATION.CONNECTION_ID` or `PROJECT_ID.LOCATION.CONNECTION_ID` format. - if bigquery_connection.count(".") == 1: - bq_connection_location, bq_connection_id = bigquery_connection.split(".") - if bq_connection_location.casefold() != bq_location.casefold(): - raise ValueError( - "The location does not match BigQuery connection location: " - f"{bq_location}." - ) - bigquery_connection = bq_connection_id - elif bigquery_connection.count(".") == 2: - ( - gcp_project_id, - bq_connection_location, - bq_connection_id, - ) = bigquery_connection.split(".") - if gcp_project_id.casefold() != dataset_ref.project.casefold(): - raise ValueError( - "The project_id does not match BigQuery connection gcp_project_id: " - f"{dataset_ref.project}." - ) - if bq_connection_location.casefold() != bq_location.casefold(): - raise ValueError( - "The location does not match BigQuery connection location: " - f"{bq_location}." - ) - bigquery_connection = bq_connection_id def wrapper(f): if not callable(f): @@ -808,7 +799,7 @@ def wrapper(f): dataset_ref.dataset_id, bigquery_client, bigquery_connection_client, - bigquery_connection, + bq_connection_id, resource_manager_client, ) diff --git a/bigframes/session.py b/bigframes/session.py index 0c67c779a8..a7cb78e3ff 100644 --- a/bigframes/session.py +++ b/bigframes/session.py @@ -97,8 +97,6 @@ _BIGQUERYCONNECTION_REGIONAL_ENDPOINT = "{location}-bigqueryconnection.googleapis.com" _BIGQUERYSTORAGE_REGIONAL_ENDPOINT = "{location}-bigquerystorage.googleapis.com" -_BIGFRAMES_DEFAULT_CONNECTION_ID = "bigframes-default-connection" - _MAX_CLUSTER_COLUMNS = 4 # TODO(swast): Need to connect to regional endpoints when performing remote @@ -323,10 +321,7 @@ def __init__( ), ) - self._bq_connection = ( - context.bq_connection - or f"{self.bqclient.project}.{self._location}.{_BIGFRAMES_DEFAULT_CONNECTION_ID}" - ) + self._bq_connection = context.bq_connection # Now that we're starting the session, don't allow the options to be # changed. @@ -355,10 +350,14 @@ def resourcemanagerclient(self): @property def _session_dataset_id(self): """A dataset for storing temporary objects local to the session - This is a workaround for BQML models and remote functions that do not + This is a workaround for remote functions that do not yet support session-temporary instances.""" return self._session_dataset.dataset_id + @property + def _project(self): + return self.bqclient.project + def _create_and_bind_bq_session(self): """Create a BQ session and bind the session id with clients to capture BQ activities: go/bigframes-transient-data""" diff --git a/tests/system/small/ml/test_llm.py b/tests/system/small/ml/test_llm.py index 7486277487..8e4477d4d9 100644 --- a/tests/system/small/ml/test_llm.py +++ b/tests/system/small/ml/test_llm.py @@ -26,7 +26,8 @@ def test_create_text_generator_model(palm2_text_generator_model): assert palm2_text_generator_model._bqml_model is not None -def test_create_text_generator_model_defaults(bq_connection): +@pytest.mark.flaky(retries=2, delay=120) +def test_create_text_generator_model_default_session(bq_connection, llm_text_pandas_df): import bigframes.pandas as bpd bpd.reset_session() @@ -36,6 +37,36 @@ def test_create_text_generator_model_defaults(bq_connection): model = llm.PaLM2TextGenerator() assert model is not None assert model._bqml_model is not None + assert model.connection_name.casefold() == "bigframes-dev.us.bigframes-rf-conn" + + llm_text_df = bpd.read_pandas(llm_text_pandas_df) + + df = model.predict(llm_text_df).to_pandas() + TestCase().assertSequenceEqual(df.shape, (3, 1)) + assert "ml_generate_text_llm_result" in df.columns + series = df["ml_generate_text_llm_result"] + assert all(series.str.len() > 20) + + +@pytest.mark.flaky(retries=2, delay=120) +def test_create_text_generator_model_default_connection(llm_text_pandas_df): + import bigframes.pandas as bpd + + llm_text_df = bpd.read_pandas(llm_text_pandas_df) + + model = llm.PaLM2TextGenerator() + assert model is not None + assert model._bqml_model is not None + assert ( + model.connection_name.casefold() + == "bigframes-dev.us.bigframes-default-connection" + ) + + df = model.predict(llm_text_df).to_pandas() + TestCase().assertSequenceEqual(df.shape, (3, 1)) + assert "ml_generate_text_llm_result" in df.columns + series = df["ml_generate_text_llm_result"] + assert all(series.str.len() > 20) # Marked as flaky only because BQML LLM is in preview, the service only has limited capacity, not stable enough. diff --git a/tests/system/small/test_remote_function.py b/tests/system/small/test_remote_function.py index 77fb81d2c9..3e52730358 100644 --- a/tests/system/small/test_remote_function.py +++ b/tests/system/small/test_remote_function.py @@ -465,6 +465,36 @@ def square(x): assert_pandas_df_equal_ignore_ordering(bf_result, pd_result) +@pytest.mark.flaky(retries=2, delay=120) +def test_remote_function_default_connection(scalars_dfs, dataset_id): + @rf.remote_function([int], int, dataset=dataset_id) + def square(x): + return x * x + + scalars_df, scalars_pandas_df = scalars_dfs + + bf_int64_col = scalars_df["int64_col"] + bf_int64_col_filter = bf_int64_col.notnull() + bf_int64_col_filtered = bf_int64_col[bf_int64_col_filter] + bf_result_col = bf_int64_col_filtered.apply(square) + bf_result = ( + bf_int64_col_filtered.to_frame().assign(result=bf_result_col).to_pandas() + ) + + pd_int64_col = scalars_pandas_df["int64_col"] + pd_int64_col_filter = pd_int64_col.notnull() + pd_int64_col_filtered = pd_int64_col[pd_int64_col_filter] + pd_result_col = pd_int64_col_filtered.apply(lambda x: x * x) + # TODO(shobs): Figure why pandas .apply() changes the dtype, i.e. + # pd_int64_col_filtered.dtype is Int64Dtype() + # pd_int64_col_filtered.apply(lambda x: x * x).dtype is int64. + # For this test let's force the pandas dtype to be same as bigframes' dtype. + pd_result_col = pd_result_col.astype(pd.Int64Dtype()) + pd_result = pd_int64_col_filtered.to_frame().assign(result=pd_result_col) + + assert_pandas_df_equal_ignore_ordering(bf_result, pd_result) + + @pytest.mark.flaky(retries=2, delay=120) def test_dataframe_applymap(session_with_bq_connection, scalars_dfs): def add_one(x): diff --git a/tests/unit/test_clients.py b/tests/unit/test_clients.py new file mode 100644 index 0000000000..acc624921c --- /dev/null +++ b/tests/unit/test_clients.py @@ -0,0 +1,45 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from bigframes import clients + + +def test_get_connection_name_full_none(): + connection_name = clients.get_connection_name_full( + None, default_project="default-project", default_location="us" + ) + assert connection_name == "default-project.us.bigframes-default-connection" + + +def test_get_connection_name_full_connection_id(): + connection_name = clients.get_connection_name_full( + "connection-id", default_project="default-project", default_location="us" + ) + assert connection_name == "default-project.us.connection-id" + + +def test_get_connection_name_full_location_connection_id(): + connection_name = clients.get_connection_name_full( + "eu.connection-id", default_project="default-project", default_location="us" + ) + assert connection_name == "default-project.eu.connection-id" + + +def test_get_connection_name_full_all(): + connection_name = clients.get_connection_name_full( + "my-project.eu.connection-id", + default_project="default-project", + default_location="us", + ) + assert connection_name == "my-project.eu.connection-id" From c79c86ebec154f42edd3a58f0d9d5c44bb254895 Mon Sep 17 00:00:00 2001 From: Garrett Wu Date: Tue, 10 Oct 2023 00:43:10 +0000 Subject: [PATCH 4/6] fix test --- tests/system/small/ml/test_llm.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/system/small/ml/test_llm.py b/tests/system/small/ml/test_llm.py index 8e4477d4d9..e6e051817a 100644 --- a/tests/system/small/ml/test_llm.py +++ b/tests/system/small/ml/test_llm.py @@ -52,6 +52,8 @@ def test_create_text_generator_model_default_session(bq_connection, llm_text_pan def test_create_text_generator_model_default_connection(llm_text_pandas_df): import bigframes.pandas as bpd + bpd.reset_session() + llm_text_df = bpd.read_pandas(llm_text_pandas_df) model = llm.PaLM2TextGenerator() From 62ac080020e232e6578b5688ae93a9615bc5fe72 Mon Sep 17 00:00:00 2001 From: Garrett Wu Date: Tue, 10 Oct 2023 01:42:05 +0000 Subject: [PATCH 5/6] fix test --- tests/system/small/ml/test_llm.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/system/small/ml/test_llm.py b/tests/system/small/ml/test_llm.py index e6e051817a..e546c09f97 100644 --- a/tests/system/small/ml/test_llm.py +++ b/tests/system/small/ml/test_llm.py @@ -50,9 +50,11 @@ def test_create_text_generator_model_default_session(bq_connection, llm_text_pan @pytest.mark.flaky(retries=2, delay=120) def test_create_text_generator_model_default_connection(llm_text_pandas_df): + from bigframes import _config import bigframes.pandas as bpd bpd.reset_session() + _config.options = _config.Options() # reset configs llm_text_df = bpd.read_pandas(llm_text_pandas_df) From aa42ff7c8088104f6b427074702456d8c0ca6c83 Mon Sep 17 00:00:00 2001 From: Garrett Wu Date: Wed, 11 Oct 2023 04:00:14 +0000 Subject: [PATCH 6/6] resolve comments --- bigframes/remote_function.py | 5 ----- tests/system/small/test_remote_function.py | 3 ++- tests/unit/test_clients.py | 12 ++++++++++++ 3 files changed, 14 insertions(+), 6 deletions(-) diff --git a/bigframes/remote_function.py b/bigframes/remote_function.py index 7b04ebd44e..37c7a2fc64 100644 --- a/bigframes/remote_function.py +++ b/bigframes/remote_function.py @@ -745,11 +745,6 @@ def remote_function( dataset_ref = bigquery.DatasetReference.from_string( session._session_dataset_id, default_project=bigquery_client.project ) - if not dataset_ref: - raise ValueError( - "Project and dataset must be provided, either directly or via session. " - f"{constants.FEEDBACK_LINK}" - ) bq_location, cloud_function_region = get_remote_function_locations( bigquery_client.location diff --git a/tests/system/small/test_remote_function.py b/tests/system/small/test_remote_function.py index 3e52730358..d024a57ded 100644 --- a/tests/system/small/test_remote_function.py +++ b/tests/system/small/test_remote_function.py @@ -20,6 +20,7 @@ import bigframes from bigframes import remote_function as rf +import bigframes.pandas as bpd from tests.system.utils import assert_pandas_df_equal_ignore_ordering @@ -467,7 +468,7 @@ def square(x): @pytest.mark.flaky(retries=2, delay=120) def test_remote_function_default_connection(scalars_dfs, dataset_id): - @rf.remote_function([int], int, dataset=dataset_id) + @bpd.remote_function([int], int, dataset=dataset_id) def square(x): return x * x diff --git a/tests/unit/test_clients.py b/tests/unit/test_clients.py index acc624921c..a90e5b0320 100644 --- a/tests/unit/test_clients.py +++ b/tests/unit/test_clients.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pytest + from bigframes import clients @@ -43,3 +45,13 @@ def test_get_connection_name_full_all(): default_location="us", ) assert connection_name == "my-project.eu.connection-id" + + +def test_get_connection_name_full_raise_value_error(): + + with pytest.raises(ValueError): + clients.get_connection_name_full( + "my-project.eu.connection-id.extra_field", + default_project="default-project", + default_location="us", + )