
Commit 2ccd3c8

SNOW-592647 consolidate definitions and resolve circular dependency issues (snowflakedb#1158)
1 parent 2d78cb0 commit 2ccd3c8

File tree

10 files changed: +115 -99 lines changed

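Most of the hunks below apply one pattern: instead of binding a name at import time (for example `from .compat import BASE_EXCEPTION_CLASS`), a module imports its sibling as a module object (`from . import compat`) and resolves the attribute only when it is actually used. During a circular import a sibling module may be only partially initialized, so the late lookup avoids import-time failures. A minimal sketch of the idea with hypothetical modules alpha.py and beta.py inside one package (not part of the connector):

# alpha.py (hypothetical)
from . import beta  # binds the (possibly partially initialized) module object


def describe() -> str:
    # The attribute is looked up only when describe() runs, long after both
    # modules finished importing, so the import cycle is harmless.
    return f"beta says {beta.VALUE}"


# beta.py (hypothetical)
from . import alpha  # the same cycle in the other direction

VALUE = "hello"

# By contrast, `from .alpha import describe` here could raise ImportError,
# because alpha may not have defined `describe` yet when beta is first loaded.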

DESCRIPTION.md

Lines changed: 1 addition & 1 deletion
@@ -11,9 +11,9 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-conne
 
 
 - v2.7.9(Unreleased)
-`
 
 - Fixed a bug where errors raised during get_results_from_sfqid() were missing errno
+- Fixed a bug where empty results containing GEOGRAPHY type raised IndexError
 
 
 - v2.7.8(May 28,2022)

src/snowflake/connector/compat.py

Lines changed: 2 additions & 2 deletions
@@ -16,7 +16,7 @@
 import urllib.request
 from typing import Any
 
-from snowflake.connector.constants import UTF8
+from . import constants
 
 IS_LINUX = platform.system() == "Linux"
 IS_WINDOWS = platform.system() == "Windows"
@@ -111,7 +111,7 @@ def PKCS5_PAD(value: bytes, block_size: int) -> bytes:
        [
            value,
            (block_size - len(value) % block_size)
-            * chr(block_size - len(value) % block_size).encode(UTF8),
+            * chr(block_size - len(value) % block_size).encode(constants.UTF8),
        ]
    )
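The second hunk only swaps the UTF8 constant reference; the padding math is unchanged. A small usage sketch, assuming the list shown in the hunk is joined onto the value by the surrounding b"".join(...) call:

# PKCS5/PKCS7-style padding: pad to the next multiple of block_size,
# with each pad byte holding the pad length.
padded = PKCS5_PAD(b"hello", 8)
# len(b"hello") == 5, so the pad length is 3 and three 0x03 bytes are appended.
assert padded == b"hello\x03\x03\x03"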

src/snowflake/connector/constants.py

Lines changed: 65 additions & 20 deletions
@@ -6,7 +6,16 @@
 
 from collections import defaultdict
 from enum import Enum, auto, unique
-from typing import Any, DefaultDict, NamedTuple
+from typing import Any, Callable, DefaultDict, NamedTuple
+
+from .options import installed_pandas
+from .options import pyarrow as pa
+
+if installed_pandas:
+    DataType = pa.DataType
+else:
+    DataType = None
 
 DBAPI_TYPE_STRING = 0
 DBAPI_TYPE_BINARY = 1
@@ -17,25 +26,61 @@
 class FieldType(NamedTuple):
     name: str
     dbapi_type: list[int]
-
-
-FIELD_TYPES: list[FieldType] = [
-    FieldType(name="FIXED", dbapi_type=[DBAPI_TYPE_NUMBER]),
-    FieldType(name="REAL", dbapi_type=[DBAPI_TYPE_NUMBER]),
-    FieldType(name="TEXT", dbapi_type=[DBAPI_TYPE_STRING]),
-    FieldType(name="DATE", dbapi_type=[DBAPI_TYPE_TIMESTAMP]),
-    FieldType(name="TIMESTAMP", dbapi_type=[DBAPI_TYPE_TIMESTAMP]),
-    FieldType(name="VARIANT", dbapi_type=[DBAPI_TYPE_BINARY]),
-    FieldType(name="TIMESTAMP_LTZ", dbapi_type=[DBAPI_TYPE_TIMESTAMP]),
-    FieldType(name="TIMESTAMP_TZ", dbapi_type=[DBAPI_TYPE_TIMESTAMP]),
-    FieldType(name="TIMESTAMP_NTZ", dbapi_type=[DBAPI_TYPE_TIMESTAMP]),
-    FieldType(name="OBJECT", dbapi_type=[DBAPI_TYPE_BINARY]),
-    FieldType(name="ARRAY", dbapi_type=[DBAPI_TYPE_BINARY]),
-    FieldType(name="BINARY", dbapi_type=[DBAPI_TYPE_BINARY]),
-    FieldType(name="TIME", dbapi_type=[DBAPI_TYPE_TIMESTAMP]),
-    FieldType(name="BOOLEAN", dbapi_type=[]),
-    FieldType(name="GEOGRAPHY", dbapi_type=[DBAPI_TYPE_STRING]),
-]
+    pa_type: Callable[[], DataType]
+
+
+# This type mapping holds column type definitions.
+# Be careful to not change the ordering as the index is what Snowflake
+# gives to as schema
+FIELD_TYPES: tuple[FieldType] = (
+    FieldType(name="FIXED", dbapi_type=[DBAPI_TYPE_NUMBER], pa_type=lambda: pa.int64()),
+    FieldType(
+        name="REAL", dbapi_type=[DBAPI_TYPE_NUMBER], pa_type=lambda: pa.float64()
+    ),
+    FieldType(name="TEXT", dbapi_type=[DBAPI_TYPE_STRING], pa_type=lambda: pa.string()),
+    FieldType(
+        name="DATE", dbapi_type=[DBAPI_TYPE_TIMESTAMP], pa_type=lambda: pa.date64()
+    ),
+    FieldType(
+        name="TIMESTAMP",
+        dbapi_type=[DBAPI_TYPE_TIMESTAMP],
+        pa_type=lambda: pa.time64("ns"),
+    ),
+    FieldType(
+        name="VARIANT", dbapi_type=[DBAPI_TYPE_BINARY], pa_type=lambda: pa.string()
+    ),
+    FieldType(
+        name="TIMESTAMP_LTZ",
+        dbapi_type=[DBAPI_TYPE_TIMESTAMP],
+        pa_type=lambda: pa.timestamp("ns"),
+    ),
+    FieldType(
+        name="TIMESTAMP_TZ",
+        dbapi_type=[DBAPI_TYPE_TIMESTAMP],
+        pa_type=lambda: pa.timestamp("ns"),
+    ),
+    FieldType(
+        name="TIMESTAMP_NTZ",
+        dbapi_type=[DBAPI_TYPE_TIMESTAMP],
+        pa_type=lambda: pa.timestamp("ns"),
+    ),
+    FieldType(
+        name="OBJECT", dbapi_type=[DBAPI_TYPE_BINARY], pa_type=lambda: pa.string()
+    ),
+    FieldType(
+        name="ARRAY", dbapi_type=[DBAPI_TYPE_BINARY], pa_type=lambda: pa.string()
+    ),
+    FieldType(
+        name="BINARY", dbapi_type=[DBAPI_TYPE_BINARY], pa_type=lambda: pa.binary()
+    ),
+    FieldType(
+        name="TIME", dbapi_type=[DBAPI_TYPE_TIMESTAMP], pa_type=lambda: pa.time64("ns")
+    ),
+    FieldType(name="BOOLEAN", dbapi_type=[], pa_type=lambda: pa.bool_()),
+    FieldType(
+        name="GEOGRAPHY", dbapi_type=[DBAPI_TYPE_STRING], pa_type=lambda: pa.string()
+    ),
+)
 
 FIELD_NAME_TO_ID: DefaultDict[Any, int] = defaultdict(int)
 FIELD_ID_TO_NAME: DefaultDict[int, str] = defaultdict(str)
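Each pa_type field is a zero-argument callable rather than a pyarrow object, so importing constants.py never touches pyarrow: pa is evaluated only when a callable is invoked, and when pyarrow is absent, pa is the stand-in object from options.py that raises only on attribute access. A hedged sketch of how a consumer materializes the types lazily, mirroring the result_batch.py change further down; the local name pa_types is illustrative:

from snowflake.connector.constants import FIELD_TYPES

# Nothing pyarrow-related runs until the callables are invoked, so this line is
# the first point that requires the pandas/pyarrow extra to be installed.
pa_types = [field_type.pa_type() for field_type in FIELD_TYPES]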

src/snowflake/connector/cursor.py

Lines changed: 2 additions & 2 deletions
@@ -31,8 +31,8 @@
 from snowflake.connector.result_batch import create_batches_from_response
 from snowflake.connector.result_set import ResultSet
 
+from . import compat
 from .bind_upload_agent import BindUploadAgent, BindUploadError
-from .compat import BASE_EXCEPTION_CLASS
 from .constants import (
     FIELD_NAME_TO_ID,
     PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT,
@@ -271,7 +271,7 @@ def __init__(
     def __del__(self) -> None:  # pragma: no cover
         try:
             self.close()
-        except BASE_EXCEPTION_CLASS as e:
+        except compat.BASE_EXCEPTION_CLASS as e:
             if logger.getEffectiveLevel() <= logging.INFO:
                 logger.info(e)
src/snowflake/connector/errors.py

Lines changed: 3 additions & 0 deletions
@@ -120,6 +120,7 @@ def telemetry_msg(self) -> str | None:
 
     def generate_telemetry_exception_data(self) -> dict[str, str]:
         """Generate the data to send through telemetry."""
+
         telemetry_data = {
             TelemetryField.KEY_DRIVER_TYPE.value: CLIENT_NAME,
             TelemetryField.KEY_DRIVER_VERSION.value: SNOWFLAKE_CONNECTOR_VERSION,
@@ -146,6 +147,7 @@ def send_exception_telemetry(
         telemetry_data: dict[str, str],
     ) -> None:
         """Send telemetry data by in-band telemetry if it is enabled, otherwise send through out-of-band telemetry."""
+
         if (
             connection is not None
             and connection.telemetry_enabled
@@ -164,6 +166,7 @@ def send_exception_telemetry(
             logger.debug("Cursor failed to log to telemetry.", exc_info=True)
         elif connection is None:
             # Send with out-of-band telemetry
+
             telemetry_oob = TelemetryService.get_instance()
             telemetry_oob.log_general_exception(self.__class__.__name__, telemetry_data)

src/snowflake/connector/options.py

Lines changed: 2 additions & 2 deletions
@@ -12,7 +12,7 @@
 
 import pkg_resources
 
-from .errors import MissingDependencyError
+from . import errors
 
 logger = getLogger(__name__)
 
@@ -35,7 +35,7 @@ class MissingOptionalDependency:
     _dep_name = "not set"
 
     def __getattr__(self, item):
-        raise MissingDependencyError(self._dep_name)
+        raise errors.MissingDependencyError(self._dep_name)
 
 
 class MissingPandas(MissingOptionalDependency):
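Because MissingOptionalDependency raises only from __getattr__, the errors module is not needed until some code actually touches a missing package, which is what allows options.py to import errors as a module object instead of pulling MissingDependencyError out of it at import time. A hedged usage sketch from the caller's side (assumes pandas is not installed, so options exposes a MissingPandas stand-in under the name pandas):

from snowflake.connector.options import installed_pandas, pandas

if not installed_pandas:
    try:
        pandas.DataFrame  # attribute access triggers MissingOptionalDependency.__getattr__
    except Exception as err:  # MissingDependencyError in practice
        print(err)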

src/snowflake/connector/result_batch.py

Lines changed: 12 additions & 33 deletions
@@ -14,7 +14,7 @@
 
 from .arrow_context import ArrowConverterContext
 from .compat import OK, UNAUTHORIZED, urlparse
-from .constants import IterUnit
+from .constants import FIELD_TYPES, IterUnit
 from .errorcode import ER_FAILED_TO_CONVERT_ROW_TO_PYTHON_TYPE, ER_NO_PYARROW
 from .errors import Error, InterfaceError, NotSupportedError, ProgrammingError
 from .network import (
@@ -25,6 +25,7 @@
     raise_okta_unauthorized_error,
 )
 from .options import installed_pandas, pandas
+from .options import pyarrow as pa
 from .secret_detector import SecretDetector
 from .time_util import DecorrelateJitterBackoff, TimerContextManager
 from .vendored import requests
@@ -40,20 +41,13 @@
     from .cursor import ResultMetadata, SnowflakeCursor
     from .vendored.requests import Response
 
-if installed_pandas:
-    from pyarrow import DataType, Table
-    from pyarrow import binary as pa_bin
-    from pyarrow import bool_ as pa_bool
-    from pyarrow import date64 as pa_date64
-    from pyarrow import field
-    from pyarrow import float64 as pa_flt64
-    from pyarrow import int64 as pa_int64
-    from pyarrow import schema
-    from pyarrow import string as pa_str
-    from pyarrow import time64 as pa_time64
-    from pyarrow import timestamp as pa_ts
-else:
-    DataType, Table = None, None
+if installed_pandas:
+    DataType = pa.DataType
+    Table = pa.Table
+else:
+    DataType = None
+    Table = None
 
 # emtpy pyarrow type array corresponding to FIELD_TYPES
 FIELD_TYPE_TO_PA_TYPE: list[DataType] = []
@@ -655,26 +649,11 @@ def _create_empty_table(self) -> Table:
         """Returns emtpy Arrow table based on schema"""
         if installed_pandas:
             # initialize pyarrow type array corresponding to FIELD_TYPES
-            FIELD_TYPE_TO_PA_TYPE = [
-                pa_int64(),
-                pa_flt64(),
-                pa_str(),
-                pa_date64(),
-                pa_time64("ns"),
-                pa_str(),
-                pa_ts("ns"),
-                pa_ts("ns"),
-                pa_ts("ns"),
-                pa_str(),
-                pa_str(),
-                pa_bin(),
-                pa_time64("ns"),
-                pa_bool(),
-            ]
+            FIELD_TYPE_TO_PA_TYPE = [e.pa_type() for e in FIELD_TYPES]
         fields = [
-            field(s.name, FIELD_TYPE_TO_PA_TYPE[s.type_code]) for s in self.schema
+            pa.field(s.name, FIELD_TYPE_TO_PA_TYPE[s.type_code]) for s in self.schema
         ]
-        return schema(fields).empty_table()
+        return pa.schema(fields).empty_table()
 
     def to_arrow(self, connection: SnowflakeConnection | None = None) -> Table:
         """Returns this batch as a pyarrow Table"""

test/integ/test_cursor.py

Lines changed: 17 additions & 25 deletions
@@ -572,41 +572,33 @@ def test_variant(conn, db_parameters):
 
 
 @pytest.mark.skipolddriver
-def test_geography(conn, db_parameters):
+def test_geography(conn_cnx):
     """Variant including JSON object."""
     name_geo = random_string(5, "test_geography_")
-    with conn() as cnx:
-        cnx.cursor().execute(
-            f"""\
-create table {name_geo} (geo geography)
-"""
-        )
-        cnx.cursor().execute(
-            f"""\
-insert into {name_geo} values ('POINT(0 0)'), ('LINESTRING(1 1, 2 2)')
-"""
-        )
-    expected_data = [
-        {"coordinates": [0, 0], "type": "Point"},
-        {"coordinates": [[1, 1], [2, 2]], "type": "LineString"},
-    ]
-
-    try:
-        with conn() as cnx:
-            c = cnx.cursor()
-            c.execute("alter session set GEOGRAPHY_OUTPUT_FORMAT='geoJson'")
+    with conn_cnx(
+        session_parameters={
+            "GEOGRAPHY_OUTPUT_FORMAT": "geoJson",
+        },
+    ) as cnx:
+        with cnx.cursor() as cur:
+            cur.execute(f"create temporary table {name_geo} (geo geography)")
+            cur.execute(
+                f"insert into {name_geo} values ('POINT(0 0)'), ('LINESTRING(1 1, 2 2)')"
+            )
+        expected_data = [
+            {"coordinates": [0, 0], "type": "Point"},
+            {"coordinates": [[1, 1], [2, 2]], "type": "LineString"},
+        ]
 
+        with cnx.cursor() as cur:
             # Test with GEOGRAPHY return type
-            result = c.execute(f"select * from {name_geo}")
+            result = cur.execute(f"select * from {name_geo}")
             metadata = result.description
             assert FIELD_ID_TO_NAME[metadata[0].type_code] == "GEOGRAPHY"
             data = result.fetchall()
             for raw_data in data:
                 row = json.loads(raw_data[0])
                 assert row in expected_data
-    finally:
-        with conn() as cnx:
-            cnx.cursor().execute(f"drop table {name_geo}")
 
 
 def test_invalid_bind_data_type(conn_cnx):
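The rewrite also swaps the permanent table plus try/finally drop for a temporary table. Snowflake drops temporary tables automatically when the session ends, so the explicit cleanup block is no longer needed, and the output format is set through the conn_cnx fixture's session_parameters instead of a separate ALTER SESSION statement.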

test/integ/test_dbapi.py

Lines changed: 10 additions & 13 deletions
@@ -17,7 +17,6 @@
 import snowflake.connector
 import snowflake.connector.dbapi
 from snowflake.connector import dbapi, errorcode, errors
-from snowflake.connector.compat import BASE_EXCEPTION_CLASS
 
 from ..randomize import random_string
 
@@ -273,20 +272,18 @@ def test_close(db_parameters):
     # errorcode.ER_CURSOR_IS_CLOSED),'cursor.close() called twice in a row')
 
     # calling cursor.execute after connection is closed should raise an error
-    try:
+    with pytest.raises(errors.Error) as e:
         cur.execute(f"create or replace table {TABLE1} (name string)")
-    except BASE_EXCEPTION_CLASS as error:
-        assert (
-            error.errno == errorcode.ER_CURSOR_IS_CLOSED
-        ), "cursor.execute() called twice in a row"
+    assert (
+        e.value.errno == errorcode.ER_CURSOR_IS_CLOSED
+    ), "cursor.execute() called twice in a row"
 
-    # try to create a cursor on a closed connection
-    try:
-        con.cursor()
-    except BASE_EXCEPTION_CLASS as error:
-        assert (
-            error.errno == errorcode.ER_CONNECTION_IS_CLOSED
-        ), "tried to create a cursor on a closed cursor"
+    # try to create a cursor on a closed connection
+    with pytest.raises(errors.Error) as e:
+        con.cursor()
+    assert (
+        e.value.errno == errorcode.ER_CONNECTION_IS_CLOSED
+    ), "tried to create a cursor on a closed cursor"
 
 
 def test_execute(conn_local):
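Besides removing the BASE_EXCEPTION_CLASS import, the pytest.raises rewrite tightens these tests: with the old try/except, a call that failed to raise simply skipped the assert and the test passed silently, whereas pytest.raises fails the test if nothing is raised. A generic sketch of the difference (not connector code):

import pytest


def quiet() -> int:
    return 42  # never raises


# Old pattern: the assert inside `except` never runs, so the check is silently skipped.
try:
    quiet()
except Exception as error:
    assert error is not None

# New pattern: pytest fails with "DID NOT RAISE" when no exception occurs.
with pytest.raises(Exception):
    quiet()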

test/integ/test_put_get_user_stage.py

Lines changed: 1 addition & 1 deletion
@@ -436,7 +436,7 @@ def _put_list_rm_files_in_stage(tmpdir, conn_cnx, elem):
 
     from io import open
 
-    from snowflake.connector.compat import UTF8
+    from snowflake.connector.constants import UTF8
 
     tmp_dir = str(tmpdir.mkdir("data"))
     data_file = os.path.join(tmp_dir, data_file_name)
