diff --git a/src/databricks/sql/parameters/__init__.py b/src/databricks/sql/parameters/__init__.py index 3c39cf2bd..c05beb9e6 100644 --- a/src/databricks/sql/parameters/__init__.py +++ b/src/databricks/sql/parameters/__init__.py @@ -12,4 +12,6 @@ TimestampNTZParameter, TinyIntParameter, DecimalParameter, + MapParameter, + ArrayParameter, ) diff --git a/src/databricks/sql/parameters/native.py b/src/databricks/sql/parameters/native.py index 8a436355f..b7c448254 100644 --- a/src/databricks/sql/parameters/native.py +++ b/src/databricks/sql/parameters/native.py @@ -1,12 +1,13 @@ import datetime import decimal from enum import Enum, auto -from typing import Optional, Sequence +from typing import Optional, Sequence, Any from databricks.sql.exc import NotSupportedError from databricks.sql.thrift_api.TCLIService.ttypes import ( TSparkParameter, TSparkParameterValue, + TSparkParameterValueArg, ) import datetime @@ -54,7 +55,17 @@ class DatabricksSupportedType(Enum): TAllowedParameterValue = Union[ - str, int, float, datetime.datetime, datetime.date, bool, decimal.Decimal, None + str, + int, + float, + datetime.datetime, + datetime.date, + bool, + decimal.Decimal, + None, + list, + dict, + tuple, ] @@ -82,6 +93,7 @@ class DbsqlParameterBase: CAST_EXPR: str name: Optional[str] + value: Any def as_tspark_param(self, named: bool) -> TSparkParameter: """Returns a TSparkParameter object that can be passed to the DBR thrift server.""" @@ -98,6 +110,10 @@ def as_tspark_param(self, named: bool) -> TSparkParameter: def _tspark_param_value(self): return TSparkParameterValue(stringValue=str(self.value)) + def _tspark_value_arg(self): + """Returns a TSparkParameterValueArg object that can be passed to the DBR thrift server.""" + return TSparkParameterValueArg(value=str(self.value), type=self._cast_expr()) + def _cast_expr(self): return self.CAST_EXPR @@ -428,6 +444,99 @@ def __init__(self, value: int, name: Optional[str] = None): CAST_EXPR = DatabricksSupportedType.TINYINT.name +class ArrayParameter(DbsqlParameterBase): + """Wrap a Python `Sequence` that will be bound to a Databricks SQL ARRAY type.""" + + def __init__(self, value: Sequence[Any], name: Optional[str] = None): + """ + :value: + The value to bind for this parameter. This will be casted to a ARRAY. + :name: + If None, your query must contain a `?` marker. Like: + + ```sql + SELECT * FROM table WHERE field = ? + ``` + If not None, your query should contain a named parameter marker. Like: + ```sql + SELECT * FROM table WHERE field = :my_param + ``` + + The `name` argument to this function would be `my_param`. + """ + self.name = name + self.value = [dbsql_parameter_from_primitive(val) for val in value] + + def as_tspark_param(self, named: bool = False) -> TSparkParameter: + """Returns a TSparkParameter object that can be passed to the DBR thrift server.""" + + tsp = TSparkParameter(type=self._cast_expr()) + tsp.arguments = [val._tspark_value_arg() for val in self.value] + + if named: + tsp.name = self.name + tsp.ordinal = False + elif not named: + tsp.ordinal = True + return tsp + + def _tspark_value_arg(self): + """Returns a TSparkParameterValueArg object that can be passed to the DBR thrift server.""" + tva = TSparkParameterValueArg(type=self._cast_expr()) + tva.arguments = [val._tspark_value_arg() for val in self.value] + return tva + + CAST_EXPR = DatabricksSupportedType.ARRAY.name + + +class MapParameter(DbsqlParameterBase): + """Wrap a Python `dict` that will be bound to a Databricks SQL MAP type.""" + + def __init__(self, value: dict, name: Optional[str] = None): + """ + :value: + The value to bind for this parameter. This will be casted to a MAP. + :name: + If None, your query must contain a `?` marker. Like: + + ```sql + SELECT * FROM table WHERE field = ? + ``` + If not None, your query should contain a named parameter marker. Like: + ```sql + SELECT * FROM table WHERE field = :my_param + ``` + + The `name` argument to this function would be `my_param`. + """ + self.name = name + self.value = [ + dbsql_parameter_from_primitive(item) + for key, val in value.items() + for item in (key, val) + ] + + def as_tspark_param(self, named: bool = False) -> TSparkParameter: + """Returns a TSparkParameter object that can be passed to the DBR thrift server.""" + + tsp = TSparkParameter(type=self._cast_expr()) + tsp.arguments = [val._tspark_value_arg() for val in self.value] + if named: + tsp.name = self.name + tsp.ordinal = False + elif not named: + tsp.ordinal = True + return tsp + + def _tspark_value_arg(self): + """Returns a TSparkParameterValueArg object that can be passed to the DBR thrift server.""" + tva = TSparkParameterValueArg(type=self._cast_expr()) + tva.arguments = [val._tspark_value_arg() for val in self.value] + return tva + + CAST_EXPR = DatabricksSupportedType.MAP.name + + class DecimalParameter(DbsqlParameterBase): """Wrap a Python `Decimal` that will be bound to a Databricks SQL DECIMAL type.""" @@ -543,23 +652,26 @@ def dbsql_parameter_from_primitive( # havoc. We can't use TYPE_INFERRENCE_MAP because mypy doesn't trust # its logic - if type(value) is int: + if isinstance(value, bool): + return BooleanParameter(value=value, name=name) + elif isinstance(value, int): return dbsql_parameter_from_int(value, name=name) - elif type(value) is str: + elif isinstance(value, str): return StringParameter(value=value, name=name) - elif type(value) is float: + elif isinstance(value, float): return FloatParameter(value=value, name=name) - elif type(value) is datetime.datetime: + elif isinstance(value, datetime.datetime): return TimestampParameter(value=value, name=name) - elif type(value) is datetime.date: + elif isinstance(value, datetime.date): return DateParameter(value=value, name=name) - elif type(value) is bool: - return BooleanParameter(value=value, name=name) - elif type(value) is decimal.Decimal: + elif isinstance(value, decimal.Decimal): return DecimalParameter(value=value, name=name) + elif isinstance(value, dict): + return MapParameter(value=value, name=name) + elif isinstance(value, Sequence) and not isinstance(value, str): + return ArrayParameter(value=value, name=name) elif value is None: return VoidParameter(value=value, name=name) - else: raise NotSupportedError( f"Could not infer parameter type from value: {value} - {type(value)} \n" @@ -581,6 +693,8 @@ def dbsql_parameter_from_primitive( TimestampNTZParameter, TinyIntParameter, DecimalParameter, + ArrayParameter, + MapParameter, ] diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 186f13dd6..0ce2fa169 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -5,10 +5,10 @@ import decimal from abc import ABC, abstractmethod from collections import OrderedDict, namedtuple -from collections.abc import Iterable +from collections.abc import Mapping from decimal import Decimal from enum import Enum -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union, Sequence import re import lz4.frame @@ -429,7 +429,7 @@ def user_friendly_error_message(self, no_retry_reason, attempt, elapsed): # Taken from PyHive class ParamEscaper: _DATE_FORMAT = "%Y-%m-%d" - _TIME_FORMAT = "%H:%M:%S.%f" + _TIME_FORMAT = "%H:%M:%S.%f %z" _DATETIME_FORMAT = "{} {}".format(_DATE_FORMAT, _TIME_FORMAT) def escape_args(self, parameters): @@ -458,13 +458,22 @@ def escape_string(self, item): return "'{}'".format(item.replace("\\", "\\\\").replace("'", "\\'")) def escape_sequence(self, item): - l = map(str, map(self.escape_item, item)) - return "(" + ",".join(l) + ")" + l = map(self.escape_item, item) + l = list(map(str, l)) + return "ARRAY(" + ",".join(l) + ")" + + def escape_mapping(self, item): + l = map( + self.escape_item, + (element for key, value in item.items() for element in (key, value)), + ) + l = list(map(str, l)) + return "MAP(" + ",".join(l) + ")" def escape_datetime(self, item, format, cutoff=0): dt_str = item.strftime(format) formatted = dt_str[:-cutoff] if cutoff and format.endswith(".%f") else dt_str - return "'{}'".format(formatted) + return "'{}'".format(formatted.strip()) def escape_decimal(self, item): return str(item) @@ -476,14 +485,16 @@ def escape_item(self, item): return self.escape_number(item) elif isinstance(item, str): return self.escape_string(item) - elif isinstance(item, Iterable): - return self.escape_sequence(item) elif isinstance(item, datetime.datetime): return self.escape_datetime(item, self._DATETIME_FORMAT) elif isinstance(item, datetime.date): return self.escape_datetime(item, self._DATE_FORMAT) elif isinstance(item, decimal.Decimal): return self.escape_decimal(item) + elif isinstance(item, Sequence): + return self.escape_sequence(item) + elif isinstance(item, Mapping): + return self.escape_mapping(item) else: raise exc.ProgrammingError("Unsupported object {}".format(item)) diff --git a/tests/e2e/test_complex_types.py b/tests/e2e/test_complex_types.py index 4af4f7b67..c8a3a0781 100644 --- a/tests/e2e/test_complex_types.py +++ b/tests/e2e/test_complex_types.py @@ -1,5 +1,6 @@ import pytest from numpy import ndarray +from typing import Sequence from tests.e2e.test_driver import PySQLPytestTestCase @@ -14,50 +15,73 @@ def table_fixture(self, connection_details): # Create the table cursor.execute( """ - CREATE TABLE IF NOT EXISTS pysql_e2e_test_complex_types_table ( + CREATE TABLE IF NOT EXISTS pysql_test_complex_types_table ( array_col ARRAY, map_col MAP, - struct_col STRUCT - ) + struct_col STRUCT, + array_array_col ARRAY>, + array_map_col ARRAY>, + map_array_col MAP> + ) USING DELTA """ ) # Insert a record cursor.execute( """ - INSERT INTO pysql_e2e_test_complex_types_table + INSERT INTO pysql_test_complex_types_table VALUES ( ARRAY('a', 'b', 'c'), MAP('a', 1, 'b', 2, 'c', 3), - NAMED_STRUCT('field1', 'a', 'field2', 1) + NAMED_STRUCT('field1', 'a', 'field2', 1), + ARRAY(ARRAY('a','b','c')), + ARRAY(MAP('a', 1, 'b', 2, 'c', 3)), + MAP('a', ARRAY('a', 'b', 'c'), 'b', ARRAY('d', 'e')) ) """ ) yield # Clean up the table after the test - cursor.execute("DROP TABLE IF EXISTS pysql_e2e_test_complex_types_table") + cursor.execute("DELETE FROM pysql_test_complex_types_table") @pytest.mark.parametrize( "field,expected_type", - [("array_col", ndarray), ("map_col", list), ("struct_col", dict)], + [ + ("array_col", ndarray), + ("map_col", list), + ("struct_col", dict), + ("array_array_col", ndarray), + ("array_map_col", ndarray), + ("map_array_col", list), + ], ) def test_read_complex_types_as_arrow(self, field, expected_type, table_fixture): """Confirms the return types of a complex type field when reading as arrow""" with self.cursor() as cursor: result = cursor.execute( - "SELECT * FROM pysql_e2e_test_complex_types_table LIMIT 1" + "SELECT * FROM pysql_test_complex_types_table LIMIT 1" ).fetchone() assert isinstance(result[field], expected_type) - @pytest.mark.parametrize("field", [("array_col"), ("map_col"), ("struct_col")]) + @pytest.mark.parametrize( + "field", + [ + ("array_col"), + ("map_col"), + ("struct_col"), + ("array_array_col"), + ("array_map_col"), + ("map_array_col"), + ], + ) def test_read_complex_types_as_string(self, field, table_fixture): """Confirms the return type of a complex type that is returned as a string""" with self.cursor( extra_params={"_use_arrow_native_complex_types": False} ) as cursor: result = cursor.execute( - "SELECT * FROM pysql_e2e_test_complex_types_table LIMIT 1" + "SELECT * FROM pysql_test_complex_types_table LIMIT 1" ).fetchone() assert isinstance(result[field], str) diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index 440d4efb3..d0c721109 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -856,7 +856,9 @@ def test_closing_a_closed_connection_doesnt_fail(self, caplog): raise KeyboardInterrupt("Simulated interrupt") finally: if conn is not None: - assert not conn.open, "Connection should be closed after KeyboardInterrupt" + assert ( + not conn.open + ), "Connection should be closed after KeyboardInterrupt" def test_cursor_close_properly_closes_operation(self): """Test that Cursor.close() properly closes the active operation handle on the server.""" @@ -883,7 +885,9 @@ def test_cursor_close_properly_closes_operation(self): raise KeyboardInterrupt("Simulated interrupt") finally: if cursor is not None: - assert not cursor.open, "Cursor should be closed after KeyboardInterrupt" + assert ( + not cursor.open + ), "Cursor should be closed after KeyboardInterrupt" def test_nested_cursor_context_managers(self): """Test that nested cursor context managers properly close operations on the server.""" diff --git a/tests/e2e/test_parameterized_queries.py b/tests/e2e/test_parameterized_queries.py index d346ad5c6..79def9b72 100644 --- a/tests/e2e/test_parameterized_queries.py +++ b/tests/e2e/test_parameterized_queries.py @@ -5,8 +5,11 @@ from typing import Dict, List, Type, Union from unittest.mock import patch +import time +import numpy as np import pytest import pytz +from numpy.random.mtrand import Sequence from databricks.sql.parameters.native import ( BigIntegerParameter, @@ -26,6 +29,8 @@ TimestampParameter, TinyIntParameter, VoidParameter, + ArrayParameter, + MapParameter, ) from tests.e2e.test_driver import PySQLPytestTestCase @@ -50,6 +55,8 @@ class Primitive(Enum): DOUBLE = 3.14 FLOAT = 3.15 SMALLINT = 51 + ARRAYS = ["a", "b", "c"] + MAPS = {"a": 1, "b": 2, "c": 3} class PrimitiveExtra(Enum): @@ -103,6 +110,8 @@ class TestParameterizedQueries(PySQLPytestTestCase): Primitive.BOOL: "boolean_col", Primitive.DATE: "date_col", Primitive.TIMESTAMP: "timestamp_col", + Primitive.ARRAYS: "array_col", + Primitive.MAPS: "map_col", Primitive.NONE: "null_col", } @@ -134,7 +143,11 @@ def inline_table(self, connection_details): string_col STRING, boolean_col BOOLEAN, date_col DATE, - timestamp_col TIMESTAMP + timestamp_col TIMESTAMP, + array_col ARRAY, + map_col MAP, + array_map_col ARRAY>, + map_array_col MAP> ) USING DELTA """ @@ -155,7 +168,7 @@ def patch_server_supports_native_params(self, supports_native_params: bool = Tru finally: pass - def _inline_roundtrip(self, params: dict, paramstyle: ParamStyle): + def _inline_roundtrip(self, params: dict, paramstyle: ParamStyle, target_column): """This INSERT, SELECT, DELETE dance is necessary because simply selecting ``` "SELECT %(param)s" @@ -166,7 +179,6 @@ def _inline_roundtrip(self, params: dict, paramstyle: ParamStyle): :paramstyle: This is a no-op but is included to make the test-code easier to read. """ - target_column = self._get_inline_table_column(params.get("p")) INSERT_QUERY = f"INSERT INTO pysql_e2e_inline_param_test_table (`{target_column}`) VALUES (%(p)s)" SELECT_QUERY = f"SELECT {target_column} `col` FROM pysql_e2e_inline_param_test_table LIMIT 1" DELETE_QUERY = "DELETE FROM pysql_e2e_inline_param_test_table" @@ -212,7 +224,11 @@ def _get_one_result( if approach == ParameterApproach.INLINE: # inline mode always uses ParamStyle.PYFORMAT # inline mode doesn't support positional parameters - return self._inline_roundtrip(params, paramstyle=ParamStyle.PYFORMAT) + return self._inline_roundtrip( + params, + paramstyle=ParamStyle.PYFORMAT, + target_column=self._get_inline_table_column(params.get("p")), + ) elif approach == ParameterApproach.NATIVE: # native mode can use either ParamStyle.NAMED or ParamStyle.PYFORMAT # native mode can use either ParameterStructure.NAMED or ParameterStructure.POSITIONAL @@ -229,10 +245,73 @@ def _eq(self, actual, expected: Primitive): If primitive is Primitive.DOUBLE than an extra quantize step is performed before making the assertion. """ + actual_parsed = actual + expected_parsed = expected.value + if expected in (Primitive.DOUBLE, Primitive.FLOAT): - return self._quantize(actual) == self._quantize(expected.value) + actual_parsed = self._quantize(actual) + expected_parsed = self._quantize(expected.value) + elif expected == Primitive.ARRAYS: + actual_parsed = actual.tolist() + elif expected == Primitive.MAPS: + expected_parsed = list(expected.value.items()) + + return actual_parsed == expected_parsed + + def _parse_to_common_type(self, value): + """ + Function to convert the :value passed into a common python datatype for comparison + + Convertion fyi + MAP Datatype on server is returned as a list of tuples + Ex: + {"a":1,"b":2} -> [("a",1),("b",2)] - return actual == expected.value + ARRAY Datatype on server is returned as a numpy array + Ex: + ["a","b","c"] -> np.array(["a","b","c"],dtype=object) + + Primitive datatype on server is returned as a numpy primitive + Ex: + 1 -> np.int64(1) + 2 -> np.int32(2) + """ + if value is None: + return None + elif isinstance(value, (Sequence, np.ndarray)) and not isinstance( + value, (str, bytes) + ): + return tuple(value) + elif isinstance(value, dict): + return tuple(value.items()) + elif isinstance(value, np.generic): + return value.item() + else: + return value + + def _recursive_compare(self, actual, expected): + """ + Function to compare the :actual and :expected values, recursively checks and ensures that all the data matches till the leaf level + + Note: Complex datatype like MAP is not returned as a dictionary but as a list of tuples + """ + actual_parsed = self._parse_to_common_type(actual) + expected_parsed = self._parse_to_common_type(expected) + + # Check if types are the same + if type(actual_parsed) != type(expected_parsed): + return False + + # Handle lists or tuples + if isinstance(actual_parsed, (list, tuple)): + if len(actual_parsed) != len(expected_parsed): + return False + return all( + self._recursive_compare(o1, o2) + for o1, o2 in zip(actual_parsed, expected_parsed) + ) + + return actual_parsed == expected_parsed @pytest.mark.parametrize("primitive", Primitive) @pytest.mark.parametrize( @@ -278,6 +357,8 @@ def test_primitive_single( (Primitive.SMALLINT, SmallIntParameter), (PrimitiveExtra.TIMESTAMP_NTZ, TimestampNTZParameter), (PrimitiveExtra.TINYINT, TinyIntParameter), + (Primitive.ARRAYS, ArrayParameter), + (Primitive.MAPS, MapParameter), ], ) def test_dbsqlparameter_single( @@ -361,6 +442,58 @@ def test_readme_example(self): assert len(result) == 10 assert result[0].p == "foo" + @pytest.mark.parametrize( + "col_name,data", + [ + ("array_map_col", [{"a": 1, "b": 2}, {"c": 3, "d": 4}]), + ("map_array_col", {1: ["a", "b"], 2: ["c", "d"]}), + ], + ) + def test_inline_recursive_complex_type(self, col_name, data): + params = {"p": data} + result = self._inline_roundtrip( + params=params, paramstyle=ParamStyle.PYFORMAT, target_column=col_name + ) + assert self._recursive_compare(result.col, data) + + @pytest.mark.parametrize( + "description,data", + [ + ("ARRAY>", [{"a": 1, "b": 2}, {"c": 3, "d": 4}]), + ("MAP>", {1: ["a", "b"], 2: ["c", "d"]}), + ("ARRAY>", [[1, 2, 3], [1, 2, 3]]), + ( + "ARRAY>>", + [[[1, 2, 3], [1, 2, 3]], [[1, 2, 3], [1, 2, 3]]], + ), + ( + "MAP>", + {"a": {"b": "c", "d": "e"}, "f": {"g": "h", "i": "j"}}, + ), + ], + ) + @pytest.mark.parametrize( + "paramstyle,parameter_structure", + [ + (ParamStyle.NONE, ParameterStructure.POSITIONAL), + (ParamStyle.PYFORMAT, ParameterStructure.NAMED), + (ParamStyle.NAMED, ParameterStructure.NAMED), + ], + ) + def test_native_recursive_complex_type( + self, description, data, paramstyle, parameter_structure + ): + if paramstyle == ParamStyle.NONE: + params = [data] + else: + params = {"p": data} + result = self._native_roundtrip( + parameters=params, + paramstyle=paramstyle, + parameter_structure=parameter_structure, + ) + assert self._recursive_compare(result.col, data) + class TestInlineParameterSyntax(PySQLPytestTestCase): """The inline parameter approach uses pyformat markers""" diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index cf8779a21..588b0d70e 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -532,7 +532,7 @@ def test_execute_parameter_passthrough(self): ("SELECT %(x)s", "SELECT NULL", {"x": None}), ("SELECT %(int_value)d", "SELECT 48", {"int_value": 48}), ("SELECT %(float_value).2f", "SELECT 48.20", {"float_value": 48.2}), - ("SELECT %(iter)s", "SELECT (1,2,3,4,5)", {"iter": [1, 2, 3, 4, 5]}), + ("SELECT %(iter)s", "SELECT ARRAY(1,2,3,4,5)", {"iter": [1, 2, 3, 4, 5]}), ( "SELECT %(datetime)s", "SELECT '2022-02-01 10:23:00.000000'", @@ -758,7 +758,7 @@ def test_cursor_close_handles_exception(self): mock_backend = Mock() mock_connection = Mock() mock_op_handle = Mock() - + mock_backend.close_command.side_effect = Exception("Test error") cursor = client.Cursor(mock_connection, mock_backend) @@ -767,78 +767,80 @@ def test_cursor_close_handles_exception(self): cursor.close() mock_backend.close_command.assert_called_once_with(mock_op_handle) - + self.assertIsNone(cursor.active_op_handle) - + self.assertFalse(cursor.open) def test_cursor_context_manager_handles_exit_exception(self): """Test that cursor's context manager handles exceptions during __exit__.""" mock_backend = Mock() mock_connection = Mock() - + cursor = client.Cursor(mock_connection, mock_backend) original_close = cursor.close cursor.close = Mock(side_effect=Exception("Test error during close")) - + try: with cursor: raise ValueError("Test error inside context") except ValueError: pass - + cursor.close.assert_called_once() def test_connection_close_handles_cursor_close_exception(self): """Test that _close handles exceptions from cursor.close() properly.""" cursors_closed = [] - + def mock_close_with_exception(): cursors_closed.append(1) raise Exception("Test error during close") - + cursor1 = Mock() cursor1.close = mock_close_with_exception - + def mock_close_normal(): cursors_closed.append(2) - + cursor2 = Mock() cursor2.close = mock_close_normal - + mock_backend = Mock() mock_session_handle = Mock() - + try: for cursor in [cursor1, cursor2]: try: cursor.close() except Exception: pass - + mock_backend.close_session(mock_session_handle) except Exception as e: self.fail(f"Connection close should handle exceptions: {e}") - - self.assertEqual(cursors_closed, [1, 2], "Both cursors should have close called") + + self.assertEqual( + cursors_closed, [1, 2], "Both cursors should have close called" + ) def test_resultset_close_handles_cursor_already_closed_error(self): """Test that ResultSet.close() handles CursorAlreadyClosedError properly.""" result_set = client.ResultSet.__new__(client.ResultSet) result_set.thrift_backend = Mock() - result_set.thrift_backend.CLOSED_OP_STATE = 'CLOSED' + result_set.thrift_backend.CLOSED_OP_STATE = "CLOSED" result_set.connection = Mock() result_set.connection.open = True - result_set.op_state = 'RUNNING' + result_set.op_state = "RUNNING" result_set.has_been_closed_server_side = False result_set.command_id = Mock() class MockRequestError(Exception): def __init__(self): self.args = ["Error message", CursorAlreadyClosedError()] - + result_set.thrift_backend.close_command.side_effect = MockRequestError() - + original_close = client.ResultSet.close try: try: @@ -854,11 +856,13 @@ def __init__(self): finally: result_set.has_been_closed_server_side = True result_set.op_state = result_set.thrift_backend.CLOSED_OP_STATE - - result_set.thrift_backend.close_command.assert_called_once_with(result_set.command_id) - + + result_set.thrift_backend.close_command.assert_called_once_with( + result_set.command_id + ) + assert result_set.has_been_closed_server_side is True - + assert result_set.op_state == result_set.thrift_backend.CLOSED_OP_STATE finally: pass diff --git a/tests/unit/test_param_escaper.py b/tests/unit/test_param_escaper.py index 925fcea58..9b6b9c246 100644 --- a/tests/unit/test_param_escaper.py +++ b/tests/unit/test_param_escaper.py @@ -120,24 +120,38 @@ def test_escape_date(self): assert pe.escape_datetime(INPUT, FORMAT) == OUTPUT def test_escape_sequence_integer(self): - assert pe.escape_sequence([1, 2, 3, 4]) == "(1,2,3,4)" + assert pe.escape_sequence([1, 2, 3, 4]) == "ARRAY(1,2,3,4)" def test_escape_sequence_float(self): - assert pe.escape_sequence([1.1, 2.2, 3.3, 4.4]) == "(1.1,2.2,3.3,4.4)" + assert pe.escape_sequence([1.1, 2.2, 3.3, 4.4]) == "ARRAY(1.1,2.2,3.3,4.4)" def test_escape_sequence_string(self): assert ( pe.escape_sequence(["his", "name", "was", "robert", "palmer"]) - == "('his','name','was','robert','palmer')" + == "ARRAY('his','name','was','robert','palmer')" ) def test_escape_sequence_sequence_of_strings(self): - # This is not valid SQL. INPUT = [["his", "name"], ["was", "robert"], ["palmer"]] - OUTPUT = "(('his','name'),('was','robert'),('palmer'))" + OUTPUT = "ARRAY(ARRAY('his','name'),ARRAY('was','robert'),ARRAY('palmer'))" assert pe.escape_sequence(INPUT) == OUTPUT + def test_escape_map_string_int(self): + INPUT = {"a": 1, "b": 2} + OUTPUT = "MAP('a',1,'b',2)" + assert pe.escape_mapping(INPUT) == OUTPUT + + def test_escape_map_string_sequence_of_floats(self): + INPUT = {"a": [1.1, 2.2, 3.3], "b": [4.4, 5.5, 6.6]} + OUTPUT = "MAP('a',ARRAY(1.1,2.2,3.3),'b',ARRAY(4.4,5.5,6.6))" + assert pe.escape_mapping(INPUT) == OUTPUT + + def test_escape_sequence_of_map_int_string(self): + INPUT = [{1: "a", 2: "foo"}, {3: "b", 4: "bar"}] + OUTPUT = "ARRAY(MAP(1,'a',2,'foo'),MAP(3,'b',4,'bar'))" + assert pe.escape_sequence(INPUT) == OUTPUT + class TestFullQueryEscaping(object): def test_simple(self): diff --git a/tests/unit/test_parameters.py b/tests/unit/test_parameters.py index eec921e4d..249730789 100644 --- a/tests/unit/test_parameters.py +++ b/tests/unit/test_parameters.py @@ -21,10 +21,14 @@ TimestampParameter, TinyIntParameter, VoidParameter, + MapParameter, + ArrayParameter, ) from databricks.sql.parameters.native import ( TDbsqlParameter, + TSparkParameter, TSparkParameterValue, + TSparkParameterValueArg, dbsql_parameter_from_primitive, ) from databricks.sql.thrift_api.TCLIService import ttypes @@ -112,6 +116,8 @@ class Primitive(Enum): DOUBLE = 3.14 FLOAT = 3.15 SMALLINT = 51 + ARRAY = [1, 2, 3] + MAP = {"a": 1, "b": 2} class TestDbsqlParameter: @@ -131,6 +137,8 @@ class TestDbsqlParameter: (TimestampParameter, Primitive.TIMESTAMP, "TIMESTAMP"), (TimestampNTZParameter, Primitive.TIMESTAMP, "TIMESTAMP_NTZ"), (TinyIntParameter, Primitive.INT, "TINYINT"), + (MapParameter, Primitive.MAP, "MAP"), + (ArrayParameter, Primitive.ARRAY, "ARRAY"), ), ) def test_cast_expression( @@ -166,6 +174,99 @@ def test_tspark_param_value(self, t: TDbsqlParameter, prim): else: assert output == TSparkParameterValue(stringValue=str(prim.value)) + @pytest.mark.parametrize( + "base_type,input,expected_output", + [ + ( + ArrayParameter, + [1, 2, 3], + TSparkParameter( + ordinal=True, + name=None, + type="ARRAY", + value=None, + arguments=[ + TSparkParameterValueArg(type="INT", value="1", arguments=None), + TSparkParameterValueArg(type="INT", value="2", arguments=None), + TSparkParameterValueArg(type="INT", value="3", arguments=None), + ], + ), + ), + ( + MapParameter, + {"a": 1, "b": 2}, + TSparkParameter( + ordinal=True, + name=None, + type="MAP", + value=None, + arguments=[ + TSparkParameterValueArg( + type="STRING", value="a", arguments=None + ), + TSparkParameterValueArg(type="INT", value="1", arguments=None), + TSparkParameterValueArg( + type="STRING", value="b", arguments=None + ), + TSparkParameterValueArg(type="INT", value="2", arguments=None), + ], + ), + ), + ( + ArrayParameter, + [{"a": 1, "b": 2}, {"c": 3, "d": 4}], + TSparkParameter( + ordinal=True, + name=None, + type="ARRAY", + value=None, + arguments=[ + TSparkParameterValueArg( + type="MAP", + value=None, + arguments=[ + TSparkParameterValueArg( + type="STRING", value="a", arguments=None + ), + TSparkParameterValueArg( + type="INT", value="1", arguments=None + ), + TSparkParameterValueArg( + type="STRING", value="b", arguments=None + ), + TSparkParameterValueArg( + type="INT", value="2", arguments=None + ), + ], + ), + TSparkParameterValueArg( + type="MAP", + value=None, + arguments=[ + TSparkParameterValueArg( + type="STRING", value="c", arguments=None + ), + TSparkParameterValueArg( + type="INT", value="3", arguments=None + ), + TSparkParameterValueArg( + type="STRING", value="d", arguments=None + ), + TSparkParameterValueArg( + type="INT", value="4", arguments=None + ), + ], + ), + ], + ), + ), + ], + ) + def test_complex_type_tspark_param(self, base_type, input, expected_output): + p = base_type(input) + tsp = p.as_tspark_param() + assert tsp == expected_output + def test_tspark_param_named(self): p = dbsql_parameter_from_primitive(Primitive.INT.value, name="p") tsp = p.as_tspark_param(named=True) @@ -192,6 +293,8 @@ def test_tspark_param_ordinal(self): (FloatParameter, Primitive.FLOAT), (VoidParameter, Primitive.NONE), (TimestampParameter, Primitive.TIMESTAMP), + (MapParameter, Primitive.MAP), + (ArrayParameter, Primitive.ARRAY), ), ) def test_inference(self, _type: TDbsqlParameter, prim: Primitive): diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 7fe318446..458ea9a82 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -86,7 +86,9 @@ def test_make_request_checks_thrift_status_code(self): def _make_type_desc(self, type): return ttypes.TTypeDesc( - types=[ttypes.TTypeEntry(primitiveEntry=ttypes.TPrimitiveTypeEntry(type=type))] + types=[ + ttypes.TTypeEntry(primitiveEntry=ttypes.TPrimitiveTypeEntry(type=type)) + ] ) def _make_fake_thrift_backend(self):