4
4
5
5
import unittest
6
6
from unittest .mock import MagicMock , patch
7
- import sys
8
- from typing import List , Dict , Any
9
-
10
- # Add the necessary path to import the filter module
11
- sys .path .append ("/home/varun.edachali/conn/databricks-sql-python/src" )
12
7
13
8
from databricks .sql .backend .filters import ResultSetFilter
14
9
@@ -20,24 +15,39 @@ def setUp(self):
20
15
"""Set up test fixtures."""
21
16
# Create a mock SeaResultSet
22
17
self .mock_sea_result_set = MagicMock ()
23
- self .mock_sea_result_set ._response = {
24
- "result" : {
25
- "data_array" : [
26
- ["catalog1" , "schema1" , "table1" , "TABLE" , "" ],
27
- ["catalog1" , "schema1" , "table2" , "VIEW" , "" ],
28
- ["catalog1" , "schema1" , "table3" , "SYSTEM TABLE" , "" ],
29
- ["catalog1" , "schema1" , "table4" , "EXTERNAL TABLE" , "" ],
30
- ],
31
- "row_count" : 4 ,
32
- }
33
- }
18
+
19
+ # Set up the remaining_rows method on the results attribute
20
+ self .mock_sea_result_set .results = MagicMock ()
21
+ self .mock_sea_result_set .results .remaining_rows .return_value = [
22
+ ["catalog1" , "schema1" , "table1" , "owner1" , "2023-01-01" , "TABLE" , "" ],
23
+ ["catalog1" , "schema1" , "table2" , "owner1" , "2023-01-01" , "VIEW" , "" ],
24
+ [
25
+ "catalog1" ,
26
+ "schema1" ,
27
+ "table3" ,
28
+ "owner1" ,
29
+ "2023-01-01" ,
30
+ "SYSTEM TABLE" ,
31
+ "" ,
32
+ ],
33
+ [
34
+ "catalog1" ,
35
+ "schema1" ,
36
+ "table4" ,
37
+ "owner1" ,
38
+ "2023-01-01" ,
39
+ "EXTERNAL TABLE" ,
40
+ "" ,
41
+ ],
42
+ ]
34
43
35
44
# Set up the connection and other required attributes
36
45
self .mock_sea_result_set .connection = MagicMock ()
37
46
self .mock_sea_result_set .backend = MagicMock ()
38
47
self .mock_sea_result_set .buffer_size_bytes = 1000
39
48
self .mock_sea_result_set .arraysize = 100
40
49
self .mock_sea_result_set .statement_id = "test-statement-id"
50
+ self .mock_sea_result_set .lz4_compressed = False
41
51
42
52
# Create a mock CommandId
43
53
from databricks .sql .backend .types import CommandId , BackendType
@@ -50,70 +60,102 @@ def setUp(self):
50
60
("catalog_name" , "string" , None , None , None , None , True ),
51
61
("schema_name" , "string" , None , None , None , None , True ),
52
62
("table_name" , "string" , None , None , None , None , True ),
63
+ ("owner" , "string" , None , None , None , None , True ),
64
+ ("creation_time" , "string" , None , None , None , None , True ),
53
65
("table_type" , "string" , None , None , None , None , True ),
54
66
("remarks" , "string" , None , None , None , None , True ),
55
67
]
56
68
self .mock_sea_result_set .has_been_closed_server_side = False
69
+ self .mock_sea_result_set ._arrow_schema_bytes = None
57
70
58
- def test_filter_tables_by_type (self ):
59
- """Test filtering tables by type ."""
60
- # Test with specific table types
61
- table_types = ["TABLE " , "VIEW " ]
71
+ def test_filter_by_column_values (self ):
72
+ """Test filtering by column values with various options ."""
73
+ # Case 1: Case-sensitive filtering
74
+ allowed_values = ["table1 " , "table3 " ]
62
75
63
- # Make the mock_sea_result_set appear to be a SeaResultSet
64
76
with patch ("databricks.sql.backend.filters.isinstance" , return_value = True ):
65
77
with patch (
66
78
"databricks.sql.result_set.SeaResultSet"
67
79
) as mock_sea_result_set_class :
68
- # Set up the mock to return a new mock when instantiated
69
80
mock_instance = MagicMock ()
70
81
mock_sea_result_set_class .return_value = mock_instance
71
82
72
- result = ResultSetFilter .filter_tables_by_type (
73
- self .mock_sea_result_set , table_types
83
+ # Call filter_by_column_values on the table_name column (index 2)
84
+ result = ResultSetFilter .filter_by_column_values (
85
+ self .mock_sea_result_set , 2 , allowed_values , case_sensitive = True
74
86
)
75
87
76
88
# Verify the filter was applied correctly
77
89
mock_sea_result_set_class .assert_called_once ()
78
90
79
- def test_filter_tables_by_type_case_insensitive (self ):
80
- """Test filtering tables by type with case insensitivity."""
81
- # Test with lowercase table types
82
- table_types = ["table" , "view" ]
91
+ # Check the filtered data passed to the constructor
92
+ args , kwargs = mock_sea_result_set_class .call_args
93
+ result_data = kwargs .get ("result_data" )
94
+ self .assertIsNotNone (result_data )
95
+ self .assertEqual (len (result_data .data ), 2 )
96
+ self .assertIn (result_data .data [0 ][2 ], allowed_values )
97
+ self .assertIn (result_data .data [1 ][2 ], allowed_values )
83
98
84
- # Make the mock_sea_result_set appear to be a SeaResultSet
99
+ # Case 2: Case-insensitive filtering
100
+ mock_sea_result_set_class .reset_mock ()
85
101
with patch ("databricks.sql.backend.filters.isinstance" , return_value = True ):
86
102
with patch (
87
103
"databricks.sql.result_set.SeaResultSet"
88
104
) as mock_sea_result_set_class :
89
- # Set up the mock to return a new mock when instantiated
90
105
mock_instance = MagicMock ()
91
106
mock_sea_result_set_class .return_value = mock_instance
92
107
93
- result = ResultSetFilter .filter_tables_by_type (
94
- self .mock_sea_result_set , table_types
108
+ # Call filter_by_column_values with case-insensitive matching
109
+ result = ResultSetFilter .filter_by_column_values (
110
+ self .mock_sea_result_set ,
111
+ 2 ,
112
+ ["TABLE1" , "TABLE3" ],
113
+ case_sensitive = False ,
95
114
)
96
-
97
- # Verify the filter was applied correctly
98
115
mock_sea_result_set_class .assert_called_once ()
99
116
100
- def test_filter_tables_by_type_default (self ):
101
- """Test filtering tables by type with default types."""
102
- # Make the mock_sea_result_set appear to be a SeaResultSet
103
- with patch ("databricks.sql.backend.filters.isinstance" , return_value = True ):
104
- with patch (
105
- "databricks.sql.result_set.SeaResultSet"
106
- ) as mock_sea_result_set_class :
107
- # Set up the mock to return a new mock when instantiated
108
- mock_instance = MagicMock ()
109
- mock_sea_result_set_class .return_value = mock_instance
117
+ # Case 3: Unsupported result set type
118
+ mock_unsupported_result_set = MagicMock ()
119
+ with patch ("databricks.sql.backend.filters.isinstance" , return_value = False ):
120
+ with patch ("databricks.sql.backend.filters.logger" ) as mock_logger :
121
+ result = ResultSetFilter .filter_by_column_values (
122
+ mock_unsupported_result_set , 0 , ["value" ], True
123
+ )
124
+ mock_logger .warning .assert_called_once ()
125
+ self .assertEqual (result , mock_unsupported_result_set )
126
+
127
+ def test_filter_tables_by_type (self ):
128
+ """Test filtering tables by type with various options."""
129
+ # Case 1: Specific table types
130
+ table_types = ["TABLE" , "VIEW" ]
110
131
111
- result = ResultSetFilter .filter_tables_by_type (
112
- self .mock_sea_result_set , None
132
+ with patch ("databricks.sql.backend.filters.isinstance" , return_value = True ):
133
+ with patch .object (
134
+ ResultSetFilter , "filter_by_column_values"
135
+ ) as mock_filter :
136
+ ResultSetFilter .filter_tables_by_type (
137
+ self .mock_sea_result_set , table_types
113
138
)
139
+ args , kwargs = mock_filter .call_args
140
+ self .assertEqual (args [0 ], self .mock_sea_result_set )
141
+ self .assertEqual (args [1 ], 5 ) # Table type column index
142
+ self .assertEqual (args [2 ], table_types )
143
+ self .assertEqual (kwargs .get ("case_sensitive" ), True )
114
144
115
- # Verify the filter was applied correctly
116
- mock_sea_result_set_class .assert_called_once ()
145
+ # Case 2: Default table types (None or empty list)
146
+ with patch ("databricks.sql.backend.filters.isinstance" , return_value = True ):
147
+ with patch .object (
148
+ ResultSetFilter , "filter_by_column_values"
149
+ ) as mock_filter :
150
+ # Test with None
151
+ ResultSetFilter .filter_tables_by_type (self .mock_sea_result_set , None )
152
+ args , kwargs = mock_filter .call_args
153
+ self .assertEqual (args [2 ], ["TABLE" , "VIEW" , "SYSTEM TABLE" ])
154
+
155
+ # Test with empty list
156
+ ResultSetFilter .filter_tables_by_type (self .mock_sea_result_set , [])
157
+ args , kwargs = mock_filter .call_args
158
+ self .assertEqual (args [2 ], ["TABLE" , "VIEW" , "SYSTEM TABLE" ])
117
159
118
160
119
161
if __name__ == "__main__" :
0 commit comments