Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 7826d3e

Browse files
DATA-1339 Set application name based on streamlit being in sys.modules (snowflakedb#1117)
Co-authored-by: Mark Keller <[email protected]>
1 parent 6d55366 commit 7826d3e

File tree

2 files changed

+51
-65
lines changed

2 files changed

+51
-65
lines changed

src/snowflake/connector/connection.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -268,8 +268,11 @@ def __init__(self, **kwargs):
268268

269269
self.heartbeat_thread = None
270270

271-
if "application" not in kwargs and ENV_VAR_PARTNER in os.environ.keys():
272-
kwargs["application"] = os.environ[ENV_VAR_PARTNER]
271+
if "application" not in kwargs:
272+
if ENV_VAR_PARTNER in os.environ.keys():
273+
kwargs["application"] = os.environ[ENV_VAR_PARTNER]
274+
elif "streamlit" in sys.modules:
275+
kwargs["application"] = "streamlit"
273276

274277
self.converter = None
275278
self.__set_error_attributes()

test/unit/test_connection.py

Lines changed: 46 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import json
88
import os
9+
import sys
910
from unittest.mock import patch
1011

1112
import pytest
@@ -19,45 +20,43 @@
1920
QueryStatus = None
2021

2122

22-
@patch("snowflake.connector.network.SnowflakeRestful._post_request")
23-
def test_connect_with_service_name(mockSnowflakeRestfulPostRequest):
24-
def mock_post_request(url, headers, json_body, **kwargs):
25-
global mock_cnt
26-
ret = None
27-
if mock_cnt == 0:
28-
# return from /v1/login-request
29-
ret = {
30-
"success": True,
31-
"message": None,
32-
"data": {
33-
"token": "TOKEN",
34-
"masterToken": "MASTER_TOKEN",
35-
"idToken": None,
36-
"parameters": [
37-
{"name": "SERVICE_NAME", "value": "FAKE_SERVICE_NAME"}
38-
],
39-
},
40-
}
41-
return ret
23+
def fake_connector() -> snowflake.connector.SnowflakeConnection:
24+
return snowflake.connector.connect(
25+
user="user",
26+
account="account",
27+
password="testpassword",
28+
database="TESTDB",
29+
warehouse="TESTWH",
30+
)
4231

43-
# POST requests mock
44-
mockSnowflakeRestfulPostRequest.side_effect = mock_post_request
4532

46-
global mock_cnt
47-
mock_cnt = 0
33+
@pytest.fixture
34+
def mock_post_requests(monkeypatch):
35+
request_body = {}
4836

49-
account = "testaccount"
50-
user = "testuser"
37+
def mock_post_request(request, url, headers, json_body, **kwargs):
38+
nonlocal request_body
39+
request_body.update(json.loads(json_body))
40+
return {
41+
"success": True,
42+
"message": None,
43+
"data": {
44+
"token": "TOKEN",
45+
"masterToken": "MASTER_TOKEN",
46+
"idToken": None,
47+
"parameters": [{"name": "SERVICE_NAME", "value": "FAKE_SERVICE_NAME"}],
48+
},
49+
}
5150

52-
# connection
53-
con = snowflake.connector.connect(
54-
account=account,
55-
user=user,
56-
password="testpassword",
57-
database="TESTDB",
58-
warehouse="TESTWH",
51+
monkeypatch.setattr(
52+
snowflake.connector.network.SnowflakeRestful, "_post_request", mock_post_request
5953
)
60-
assert con.service_name == "FAKE_SERVICE_NAME"
54+
55+
return request_body
56+
57+
58+
def test_connect_with_service_name(mock_post_requests):
59+
assert fake_connector().service_name == "FAKE_SERVICE_NAME"
6160

6261

6362
@pytest.mark.skip(reason="Mock doesn't work as expected.")
@@ -137,38 +136,22 @@ def test_is_still_running():
137136

138137

139138
@pytest.mark.skipolddriver
140-
@patch("snowflake.connector.network.SnowflakeRestful._post_request")
141-
def test_partner_env_var(mockSnowflakeRestfulPostRequest):
139+
def test_partner_env_var(mock_post_requests):
142140
PARTNER_NAME = "Amanda"
143141

144-
request_body = {}
142+
with patch.dict(os.environ, {ENV_VAR_PARTNER: PARTNER_NAME}):
143+
assert fake_connector().application == PARTNER_NAME
145144

146-
def mock_post_request(url, headers, json_body, **kwargs):
147-
nonlocal request_body
148-
request_body = json.loads(json_body)
149-
return {
150-
"success": True,
151-
"message": None,
152-
"data": {
153-
"token": "TOKEN",
154-
"masterToken": "MASTER_TOKEN",
155-
"idToken": None,
156-
"parameters": [{"name": "SERVICE_NAME", "value": "FAKE_SERVICE_NAME"}],
157-
},
158-
}
145+
assert (
146+
mock_post_requests["data"]["CLIENT_ENVIRONMENT"]["APPLICATION"] == PARTNER_NAME
147+
)
159148

160-
# POST requests mock
161-
mockSnowflakeRestfulPostRequest.side_effect = mock_post_request
162149

163-
with patch.dict(os.environ, {ENV_VAR_PARTNER: PARTNER_NAME}):
164-
# connection
165-
con = snowflake.connector.connect(
166-
user="user",
167-
account="account",
168-
password="testpassword",
169-
database="TESTDB",
170-
warehouse="TESTWH",
171-
)
172-
assert con.application == PARTNER_NAME
150+
@pytest.mark.skipolddriver
151+
def test_imported_module(mock_post_requests):
152+
with patch.dict(sys.modules, {"streamlit": "foo"}):
153+
assert fake_connector().application == "streamlit"
173154

174-
assert request_body["data"]["CLIENT_ENVIRONMENT"]["APPLICATION"] == PARTNER_NAME
155+
assert (
156+
mock_post_requests["data"]["CLIENT_ENVIRONMENT"]["APPLICATION"] == "streamlit"
157+
)

0 commit comments

Comments
 (0)