Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 1edc80a

Browse files
remove changes to filters
Signed-off-by: varun-edachali-dbx <[email protected]>
1 parent d59b351 commit 1edc80a

File tree

2 files changed

+105
-69
lines changed

2 files changed

+105
-69
lines changed

src/databricks/sql/backend/filters.py

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,36 +9,27 @@
99
List,
1010
Optional,
1111
Any,
12-
Dict,
1312
Callable,
14-
TypeVar,
15-
Generic,
1613
cast,
17-
TYPE_CHECKING,
1814
)
1915

20-
from databricks.sql.backend.types import ExecuteResponse, CommandId
21-
from databricks.sql.backend.sea.models.base import ResultData
2216
from databricks.sql.backend.sea.backend import SeaDatabricksClient
17+
from databricks.sql.backend.types import ExecuteResponse
2318

24-
if TYPE_CHECKING:
25-
from databricks.sql.result_set import ResultSet, SeaResultSet
19+
from databricks.sql.result_set import ResultSet, SeaResultSet
2620

2721
logger = logging.getLogger(__name__)
2822

2923

3024
class ResultSetFilter:
3125
"""
32-
A general-purpose filter for result sets that can be applied to any backend.
33-
34-
This class provides methods to filter result sets based on various criteria,
35-
similar to the client-side filtering in the JDBC connector.
26+
A general-purpose filter for result sets.
3627
"""
3728

3829
@staticmethod
3930
def _filter_sea_result_set(
40-
result_set: "SeaResultSet", filter_func: Callable[[List[Any]], bool]
41-
) -> "SeaResultSet":
31+
result_set: SeaResultSet, filter_func: Callable[[List[Any]], bool]
32+
) -> SeaResultSet:
4233
"""
4334
Filter a SEA result set using the provided filter function.
4435
@@ -49,15 +40,13 @@ def _filter_sea_result_set(
4940
Returns:
5041
A filtered SEA result set
5142
"""
43+
5244
# Get all remaining rows
5345
all_rows = result_set.results.remaining_rows()
5446

5547
# Filter rows
5648
filtered_rows = [row for row in all_rows if filter_func(row)]
5749

58-
# Import SeaResultSet here to avoid circular imports
59-
from databricks.sql.result_set import SeaResultSet
60-
6150
# Reuse the command_id from the original result set
6251
command_id = result_set.command_id
6352

@@ -73,10 +62,13 @@ def _filter_sea_result_set(
7362
)
7463

7564
# Create a new ResultData object with filtered data
65+
7666
from databricks.sql.backend.sea.models.base import ResultData
7767

7868
result_data = ResultData(data=filtered_rows, external_links=None)
7969

70+
from databricks.sql.result_set import SeaResultSet
71+
8072
# Create a new SeaResultSet with the filtered data
8173
filtered_result_set = SeaResultSet(
8274
connection=result_set.connection,
@@ -91,11 +83,11 @@ def _filter_sea_result_set(
9183

9284
@staticmethod
9385
def filter_by_column_values(
94-
result_set: "ResultSet",
86+
result_set: ResultSet,
9587
column_index: int,
9688
allowed_values: List[str],
9789
case_sensitive: bool = False,
98-
) -> "ResultSet":
90+
) -> ResultSet:
9991
"""
10092
Filter a result set by values in a specific column.
10193
@@ -108,6 +100,7 @@ def filter_by_column_values(
108100
Returns:
109101
A filtered result set
110102
"""
103+
111104
# Convert to uppercase for case-insensitive comparison if needed
112105
if not case_sensitive:
113106
allowed_values = [v.upper() for v in allowed_values]
@@ -138,8 +131,8 @@ def filter_by_column_values(
138131

139132
@staticmethod
140133
def filter_tables_by_type(
141-
result_set: "ResultSet", table_types: Optional[List[str]] = None
142-
) -> "ResultSet":
134+
result_set: ResultSet, table_types: Optional[List[str]] = None
135+
) -> ResultSet:
143136
"""
144137
Filter a result set of tables by the specified table types.
145138
@@ -154,6 +147,7 @@ def filter_tables_by_type(
154147
Returns:
155148
A filtered result set containing only tables of the specified types
156149
"""
150+
157151
# Default table types if none specified
158152
DEFAULT_TABLE_TYPES = ["TABLE", "VIEW", "SYSTEM TABLE"]
159153
valid_types = (

tests/unit/test_filters.py

Lines changed: 90 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,6 @@
44

55
import unittest
66
from unittest.mock import MagicMock, patch
7-
import sys
8-
from typing import List, Dict, Any
9-
10-
# Add the necessary path to import the filter module
11-
sys.path.append("/home/varun.edachali/conn/databricks-sql-python/src")
127

138
from databricks.sql.backend.filters import ResultSetFilter
149

@@ -20,24 +15,39 @@ def setUp(self):
2015
"""Set up test fixtures."""
2116
# Create a mock SeaResultSet
2217
self.mock_sea_result_set = MagicMock()
23-
self.mock_sea_result_set._response = {
24-
"result": {
25-
"data_array": [
26-
["catalog1", "schema1", "table1", "TABLE", ""],
27-
["catalog1", "schema1", "table2", "VIEW", ""],
28-
["catalog1", "schema1", "table3", "SYSTEM TABLE", ""],
29-
["catalog1", "schema1", "table4", "EXTERNAL TABLE", ""],
30-
],
31-
"row_count": 4,
32-
}
33-
}
18+
19+
# Set up the remaining_rows method on the results attribute
20+
self.mock_sea_result_set.results = MagicMock()
21+
self.mock_sea_result_set.results.remaining_rows.return_value = [
22+
["catalog1", "schema1", "table1", "owner1", "2023-01-01", "TABLE", ""],
23+
["catalog1", "schema1", "table2", "owner1", "2023-01-01", "VIEW", ""],
24+
[
25+
"catalog1",
26+
"schema1",
27+
"table3",
28+
"owner1",
29+
"2023-01-01",
30+
"SYSTEM TABLE",
31+
"",
32+
],
33+
[
34+
"catalog1",
35+
"schema1",
36+
"table4",
37+
"owner1",
38+
"2023-01-01",
39+
"EXTERNAL TABLE",
40+
"",
41+
],
42+
]
3443

3544
# Set up the connection and other required attributes
3645
self.mock_sea_result_set.connection = MagicMock()
3746
self.mock_sea_result_set.backend = MagicMock()
3847
self.mock_sea_result_set.buffer_size_bytes = 1000
3948
self.mock_sea_result_set.arraysize = 100
4049
self.mock_sea_result_set.statement_id = "test-statement-id"
50+
self.mock_sea_result_set.lz4_compressed = False
4151

4252
# Create a mock CommandId
4353
from databricks.sql.backend.types import CommandId, BackendType
@@ -50,70 +60,102 @@ def setUp(self):
5060
("catalog_name", "string", None, None, None, None, True),
5161
("schema_name", "string", None, None, None, None, True),
5262
("table_name", "string", None, None, None, None, True),
63+
("owner", "string", None, None, None, None, True),
64+
("creation_time", "string", None, None, None, None, True),
5365
("table_type", "string", None, None, None, None, True),
5466
("remarks", "string", None, None, None, None, True),
5567
]
5668
self.mock_sea_result_set.has_been_closed_server_side = False
69+
self.mock_sea_result_set._arrow_schema_bytes = None
5770

58-
def test_filter_tables_by_type(self):
59-
"""Test filtering tables by type."""
60-
# Test with specific table types
61-
table_types = ["TABLE", "VIEW"]
71+
def test_filter_by_column_values(self):
72+
"""Test filtering by column values with various options."""
73+
# Case 1: Case-sensitive filtering
74+
allowed_values = ["table1", "table3"]
6275

63-
# Make the mock_sea_result_set appear to be a SeaResultSet
6476
with patch("databricks.sql.backend.filters.isinstance", return_value=True):
6577
with patch(
6678
"databricks.sql.result_set.SeaResultSet"
6779
) as mock_sea_result_set_class:
68-
# Set up the mock to return a new mock when instantiated
6980
mock_instance = MagicMock()
7081
mock_sea_result_set_class.return_value = mock_instance
7182

72-
result = ResultSetFilter.filter_tables_by_type(
73-
self.mock_sea_result_set, table_types
83+
# Call filter_by_column_values on the table_name column (index 2)
84+
result = ResultSetFilter.filter_by_column_values(
85+
self.mock_sea_result_set, 2, allowed_values, case_sensitive=True
7486
)
7587

7688
# Verify the filter was applied correctly
7789
mock_sea_result_set_class.assert_called_once()
7890

79-
def test_filter_tables_by_type_case_insensitive(self):
80-
"""Test filtering tables by type with case insensitivity."""
81-
# Test with lowercase table types
82-
table_types = ["table", "view"]
91+
# Check the filtered data passed to the constructor
92+
args, kwargs = mock_sea_result_set_class.call_args
93+
result_data = kwargs.get("result_data")
94+
self.assertIsNotNone(result_data)
95+
self.assertEqual(len(result_data.data), 2)
96+
self.assertIn(result_data.data[0][2], allowed_values)
97+
self.assertIn(result_data.data[1][2], allowed_values)
8398

84-
# Make the mock_sea_result_set appear to be a SeaResultSet
99+
# Case 2: Case-insensitive filtering
100+
mock_sea_result_set_class.reset_mock()
85101
with patch("databricks.sql.backend.filters.isinstance", return_value=True):
86102
with patch(
87103
"databricks.sql.result_set.SeaResultSet"
88104
) as mock_sea_result_set_class:
89-
# Set up the mock to return a new mock when instantiated
90105
mock_instance = MagicMock()
91106
mock_sea_result_set_class.return_value = mock_instance
92107

93-
result = ResultSetFilter.filter_tables_by_type(
94-
self.mock_sea_result_set, table_types
108+
# Call filter_by_column_values with case-insensitive matching
109+
result = ResultSetFilter.filter_by_column_values(
110+
self.mock_sea_result_set,
111+
2,
112+
["TABLE1", "TABLE3"],
113+
case_sensitive=False,
95114
)
96-
97-
# Verify the filter was applied correctly
98115
mock_sea_result_set_class.assert_called_once()
99116

100-
def test_filter_tables_by_type_default(self):
101-
"""Test filtering tables by type with default types."""
102-
# Make the mock_sea_result_set appear to be a SeaResultSet
103-
with patch("databricks.sql.backend.filters.isinstance", return_value=True):
104-
with patch(
105-
"databricks.sql.result_set.SeaResultSet"
106-
) as mock_sea_result_set_class:
107-
# Set up the mock to return a new mock when instantiated
108-
mock_instance = MagicMock()
109-
mock_sea_result_set_class.return_value = mock_instance
117+
# Case 3: Unsupported result set type
118+
mock_unsupported_result_set = MagicMock()
119+
with patch("databricks.sql.backend.filters.isinstance", return_value=False):
120+
with patch("databricks.sql.backend.filters.logger") as mock_logger:
121+
result = ResultSetFilter.filter_by_column_values(
122+
mock_unsupported_result_set, 0, ["value"], True
123+
)
124+
mock_logger.warning.assert_called_once()
125+
self.assertEqual(result, mock_unsupported_result_set)
126+
127+
def test_filter_tables_by_type(self):
128+
"""Test filtering tables by type with various options."""
129+
# Case 1: Specific table types
130+
table_types = ["TABLE", "VIEW"]
110131

111-
result = ResultSetFilter.filter_tables_by_type(
112-
self.mock_sea_result_set, None
132+
with patch("databricks.sql.backend.filters.isinstance", return_value=True):
133+
with patch.object(
134+
ResultSetFilter, "filter_by_column_values"
135+
) as mock_filter:
136+
ResultSetFilter.filter_tables_by_type(
137+
self.mock_sea_result_set, table_types
113138
)
139+
args, kwargs = mock_filter.call_args
140+
self.assertEqual(args[0], self.mock_sea_result_set)
141+
self.assertEqual(args[1], 5) # Table type column index
142+
self.assertEqual(args[2], table_types)
143+
self.assertEqual(kwargs.get("case_sensitive"), True)
114144

115-
# Verify the filter was applied correctly
116-
mock_sea_result_set_class.assert_called_once()
145+
# Case 2: Default table types (None or empty list)
146+
with patch("databricks.sql.backend.filters.isinstance", return_value=True):
147+
with patch.object(
148+
ResultSetFilter, "filter_by_column_values"
149+
) as mock_filter:
150+
# Test with None
151+
ResultSetFilter.filter_tables_by_type(self.mock_sea_result_set, None)
152+
args, kwargs = mock_filter.call_args
153+
self.assertEqual(args[2], ["TABLE", "VIEW", "SYSTEM TABLE"])
154+
155+
# Test with empty list
156+
ResultSetFilter.filter_tables_by_type(self.mock_sea_result_set, [])
157+
args, kwargs = mock_filter.call_args
158+
self.assertEqual(args[2], ["TABLE", "VIEW", "SYSTEM TABLE"])
117159

118160

119161
if __name__ == "__main__":

0 commit comments

Comments
 (0)