diff --git a/feathr_project/feathr/utils/job_utils.py b/feathr_project/feathr/utils/job_utils.py
index d9c73c355..e03645f71 100644
--- a/feathr_project/feathr/utils/job_utils.py
+++ b/feathr_project/feathr/utils/job_utils.py
@@ -31,7 +31,7 @@ def get_result_pandas_df(
     Returns:
         pandas DataFrame
     """
-    return get_result_df(client, data_format, res_url, local_cache_path)
+    return get_result_df(client=client, data_format=data_format, res_url=res_url, local_cache_path=local_cache_path)
 
 
 def get_result_spark_df(
@@ -56,12 +56,19 @@ def get_result_spark_df(
     Returns:
         Spark DataFrame
     """
-    return get_result_df(client, data_format, res_url, local_cache_path, spark=spark)
+    return get_result_df(
+        client=client,
+        data_format=data_format,
+        res_url=res_url,
+        local_cache_path=local_cache_path,
+        spark=spark,
+    )
 
 
 def get_result_df(
     client: FeathrClient,
     data_format: str = None,
+    format: str = None,
     res_url: str = None,
     local_cache_path: str = None,
     spark: SparkSession = None,
@@ -72,6 +79,7 @@
         client: Feathr client
         data_format: Format to read the downloaded files. Currently support `parquet`, `delta`, `avro`, and `csv`.
             Default to use client's job tags if exists.
+        format: An alias for `data_format` (for backward compatibility).
         res_url: Result URL to download files from. Note that this will not block the job so you need to make sure
             the job is finished and the result URL contains actual data. Default to use client's job tags if exists.
         local_cache_path (optional): Specify the absolute download directory. if the user does not provide this,
@@ -82,6 +90,9 @@
     Returns:
         Either Spark or pandas DataFrame.
     """
+    if format is not None:
+        data_format = format
+
     if data_format is None:
         # May use data format from the job tags
         if client.get_job_tags() and client.get_job_tags().get(OUTPUT_FORMAT):
diff --git a/feathr_project/test/unit/utils/test_job_utils.py b/feathr_project/test/unit/utils/test_job_utils.py
index 0909fb56e..4a0d835e5 100644
--- a/feathr_project/test/unit/utils/test_job_utils.py
+++ b/feathr_project/test/unit/utils/test_job_utils.py
@@ -26,7 +26,12 @@ def test__get_result_pandas_df(mocker: MockerFixture):
     res_url = "some_res_url"
     local_cache_path = "some_local_cache_path"
     get_result_pandas_df(client, data_format, res_url, local_cache_path)
-    mocked_get_result_df.assert_called_once_with(client, data_format, res_url, local_cache_path)
+    mocked_get_result_df.assert_called_once_with(
+        client=client,
+        data_format=data_format,
+        res_url=res_url,
+        local_cache_path=local_cache_path,
+    )
 
 
 def test__get_result_spark_df(mocker: MockerFixture):
@@ -38,7 +43,13 @@
     res_url = "some_res_url"
     local_cache_path = "some_local_cache_path"
     get_result_spark_df(spark, client, data_format, res_url, local_cache_path)
-    mocked_get_result_df.assert_called_once_with(client, data_format, res_url, local_cache_path, spark=spark)
+    mocked_get_result_df.assert_called_once_with(
+        client=client,
+        data_format=data_format,
+        res_url=res_url,
+        local_cache_path=local_cache_path,
+        spark=spark,
+    )
 
 
 @pytest.mark.parametrize(
@@ -226,3 +237,41 @@
     )
     assert isinstance(df, DataFrame)
     assert df.count() == expected_count
+
+
+@pytest.mark.parametrize(
+    "format,output_filename,expected_count", [
+        ("csv", "output.csv", 5),
+    ]
+)
+def test__get_result_df__arg_alias(
+    workspace_dir: str,
+    format: str,
+    output_filename: str,
+    expected_count: int,
+):
+    """Test get_result_df returns pandas DataFrame with the argument alias `format` instead of using `data_format`"""
+    for spark_runtime in ["local", "databricks", "azure_synapse"]:
+        # Note: make sure the output file exists in the test_user_workspace
+        res_url = str(Path(workspace_dir, "mock_results", output_filename))
+        local_cache_path = res_url
+
+        # Mock client
+        client = MagicMock()
+        client.spark_runtime = spark_runtime
+
+        # Mock feathr_spark_launcher.download_result
+        if client.spark_runtime == "databricks":
+            res_url = f"dbfs:/{res_url}"
+        if client.spark_runtime == "azure_synapse" and format == "delta":
+            # TODO currently pass the delta table test on Synapse result due to the delta table package bug.
+            continue
+
+        df = get_result_df(
+            client=client,
+            format=format,
+            res_url=res_url,
+            local_cache_path=local_cache_path,
+        )
+        assert isinstance(df, pd.DataFrame)
+        assert len(df) == expected_count
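
Usage sketch: after this change, callers can pass `format` as a backward-compatible alias for `data_format`; when both are given, `format` wins (the alias overwrites `data_format` before the lookup). A minimal example, assuming a configured Feathr workspace and a finished job whose output is already available; the config path and result URL below are illustrative, not part of this patch.

    from feathr import FeathrClient
    from feathr.utils.job_utils import get_result_df

    # Illustrative setup: a valid Feathr config and a result URL produced
    # by a job that has already finished.
    client = FeathrClient(config_path="./feathr_config.yaml")
    res_url = "abfss://results@mystorage.dfs.core.windows.net/output"

    # Existing call style still works unchanged:
    df_new = get_result_df(client=client, data_format="csv", res_url=res_url)

    # Equivalent call through the new alias; if both `format` and
    # `data_format` are passed, `format` takes precedence.
    df_old = get_result_df(client=client, format="csv", res_url=res_url)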