Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Merged
7 changes: 7 additions & 0 deletions bigframes/ml/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,13 @@ def evaluate(self, input_data: Optional[bpd.DataFrame] = None):

return self._session.read_gbq(sql)

def arima_evaluate(self, show_all_candidate_models: bool = False):
sql = self._model_manipulation_sql_generator.ml_arima_evaluate(
show_all_candidate_models
)

return self._session.read_gbq(sql)

def centroids(self) -> bpd.DataFrame:
assert self._model.model_type == "KMEANS"

Expand Down
25 changes: 25 additions & 0 deletions bigframes/ml/forecasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,31 @@ def score(
input_data = X.join(y, how="outer")
return self._bqml_model.evaluate(input_data)

def summary(
self,
show_all_candidate_models: bool = False,
) -> bpd.DataFrame:
"""Summary of the evaluation metrics of the time series model.

.. note::

Output matches that of the BigQuery ML.ARIMA_EVALUATE function.
See: https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-arima-evaluate
for the outputs relevant to this model type.

Args:
show_all_candidate_models (bool, default to False):
Whether to show evaluation metrics or an error message for either
all candidate models or for only the best model with the lowest
AIC. Default to False.

Returns:
bigframes.dataframe.DataFrame: A DataFrame as evaluation result.
"""
if not self._bqml_model:
raise RuntimeError("A model must be fitted before score")
return self._bqml_model.arima_evaluate(show_all_candidate_models)

def to_gbq(self, model_name: str, replace: bool = False) -> ARIMAPlus:
"""Save the model to BigQuery.

Expand Down
6 changes: 6 additions & 0 deletions bigframes/ml/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,12 @@ def ml_evaluate(self, source_df: Optional[bpd.DataFrame] = None) -> str:
return f"""SELECT * FROM ML.EVALUATE(MODEL `{self._model_name}`,
({source_sql}))"""

# ML evaluation TVFs
def ml_arima_evaluate(self, show_all_candidate_models: bool = False) -> str:
"""Encode ML.ARMIA_EVALUATE for BQML"""
return f"""SELECT * FROM ML.ARIMA_EVALUATE(MODEL `{self._model_name}`,
STRUCT({show_all_candidate_models} AS show_all_candidate_models))"""

def ml_centroids(self) -> str:
"""Encode ML.CENTROIDS for BQML"""
return f"""SELECT * FROM ML.CENTROIDS(MODEL `{self._model_name}`)"""
Expand Down
35 changes: 33 additions & 2 deletions tests/system/large/ml/test_forecasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,20 @@

from bigframes.ml import forecasting

ARIMA_EVALUATE_OUTPUT_COL = [
"non_seasonal_p",
"non_seasonal_d",
"non_seasonal_q",
"log_likelihood",
"AIC",
"variance",
"seasonal_periods",
"has_holiday_effect",
"has_spikes_and_dips",
"has_step_changes",
"error_message",
]


def test_arima_plus_model_fit_score(
time_series_df_default_index, dataset_id, new_time_series_df
Expand All @@ -42,7 +56,24 @@ def test_arima_plus_model_fit_score(
pd.testing.assert_frame_equal(result, expected, check_exact=False, rtol=0.1)

# save, load to ensure configuration was kept
reloaded_model = model.to_gbq(f"{dataset_id}.temp_configured_model", replace=True)
reloaded_model = model.to_gbq(f"{dataset_id}.temp_arima_plus_model", replace=True)
assert (
f"{dataset_id}.temp_arima_plus_model" in reloaded_model._bqml_model.model_name
)


def test_arima_plus_model_fit_summary(time_series_df_default_index, dataset_id):
model = forecasting.ARIMAPlus()
X_train = time_series_df_default_index[["parsed_date"]]
y_train = time_series_df_default_index[["total_visits"]]
model.fit(X_train, y_train)

result = model.summary()
assert result.shape == (1, 12)
assert all(column in result.columns for column in ARIMA_EVALUATE_OUTPUT_COL)

# save, load to ensure configuration was kept
reloaded_model = model.to_gbq(f"{dataset_id}.temp_arima_plus_model", replace=True)
assert (
f"{dataset_id}.temp_configured_model" in reloaded_model._bqml_model.model_name
f"{dataset_id}.temp_arima_plus_model" in reloaded_model._bqml_model.model_name
)
40 changes: 40 additions & 0 deletions tests/system/small/ml/test_forecasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,20 @@

from bigframes.ml import forecasting

ARIMA_EVALUATE_OUTPUT_COL = [
"non_seasonal_p",
"non_seasonal_d",
"non_seasonal_q",
"log_likelihood",
"AIC",
"variance",
"seasonal_periods",
"has_holiday_effect",
"has_spikes_and_dips",
"has_step_changes",
"error_message",
]


def test_model_predict_default(time_series_arima_plus_model: forecasting.ARIMAPlus):
utc = pytz.utc
Expand Down Expand Up @@ -104,6 +118,24 @@ def test_model_score(
)


def test_model_summary(
time_series_arima_plus_model: forecasting.ARIMAPlus, new_time_series_df
):
result = time_series_arima_plus_model.summary()
assert result.shape == (1, 12)
assert all(column in result.columns for column in ARIMA_EVALUATE_OUTPUT_COL)


def test_model_summary_show_all_candidates(
time_series_arima_plus_model: forecasting.ARIMAPlus, new_time_series_df
):
result = time_series_arima_plus_model.summary(
show_all_candidate_models=True,
)
assert result.shape[0] > 1
assert all(column in result.columns for column in ARIMA_EVALUATE_OUTPUT_COL)


def test_model_score_series(
time_series_arima_plus_model: forecasting.ARIMAPlus, new_time_series_df
):
Expand All @@ -126,3 +158,11 @@ def test_model_score_series(
rtol=0.1,
check_index_type=False,
)


def test_model_summary_series(
time_series_arima_plus_model: forecasting.ARIMAPlus, new_time_series_df
):
result = time_series_arima_plus_model.summary()
assert result.shape == (1, 12)
assert all(column in result.columns for column in ARIMA_EVALUATE_OUTPUT_COL)
13 changes: 13 additions & 0 deletions tests/unit/ml/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,19 @@ def test_ml_evaluate_produces_correct_sql(
)


def test_ml_arima_evaluate_produces_correct_sql(
model_manipulation_sql_generator: ml_sql.ModelManipulationSqlGenerator,
):
sql = model_manipulation_sql_generator.ml_arima_evaluate(
show_all_candidate_models=True
)
assert (
sql
== """SELECT * FROM ML.ARIMA_EVALUATE(MODEL `my_project_id.my_dataset_id.my_model_id`,
STRUCT(True AS show_all_candidate_models))"""
)


def test_ml_evaluate_no_source_produces_correct_sql(
model_manipulation_sql_generator: ml_sql.ModelManipulationSqlGenerator,
):
Expand Down