|
6 | 6 |
|
7 | 7 | import json
|
8 | 8 | import os
|
| 9 | +import sys |
9 | 10 | from unittest.mock import patch
|
10 | 11 |
|
11 | 12 | import pytest
|
|
19 | 20 | QueryStatus = None
|
20 | 21 |
|
21 | 22 |
|
22 |
| -@patch("snowflake.connector.network.SnowflakeRestful._post_request") |
23 |
| -def test_connect_with_service_name(mockSnowflakeRestfulPostRequest): |
24 |
| - def mock_post_request(url, headers, json_body, **kwargs): |
25 |
| - global mock_cnt |
26 |
| - ret = None |
27 |
| - if mock_cnt == 0: |
28 |
| - # return from /v1/login-request |
29 |
| - ret = { |
30 |
| - "success": True, |
31 |
| - "message": None, |
32 |
| - "data": { |
33 |
| - "token": "TOKEN", |
34 |
| - "masterToken": "MASTER_TOKEN", |
35 |
| - "idToken": None, |
36 |
| - "parameters": [ |
37 |
| - {"name": "SERVICE_NAME", "value": "FAKE_SERVICE_NAME"} |
38 |
| - ], |
39 |
| - }, |
40 |
| - } |
41 |
| - return ret |
| 23 | +def fake_connector() -> snowflake.connector.SnowflakeConnection: |
| 24 | + return snowflake.connector.connect( |
| 25 | + user="user", |
| 26 | + account="account", |
| 27 | + password="testpassword", |
| 28 | + database="TESTDB", |
| 29 | + warehouse="TESTWH", |
| 30 | + ) |
42 | 31 |
|
43 |
| - # POST requests mock |
44 |
| - mockSnowflakeRestfulPostRequest.side_effect = mock_post_request |
45 | 32 |
|
46 |
| - global mock_cnt |
47 |
| - mock_cnt = 0 |
| 33 | +@pytest.fixture |
| 34 | +def mock_post_requests(monkeypatch): |
| 35 | + request_body = {} |
48 | 36 |
|
49 |
| - account = "testaccount" |
50 |
| - user = "testuser" |
| 37 | + def mock_post_request(request, url, headers, json_body, **kwargs): |
| 38 | + nonlocal request_body |
| 39 | + request_body.update(json.loads(json_body)) |
| 40 | + return { |
| 41 | + "success": True, |
| 42 | + "message": None, |
| 43 | + "data": { |
| 44 | + "token": "TOKEN", |
| 45 | + "masterToken": "MASTER_TOKEN", |
| 46 | + "idToken": None, |
| 47 | + "parameters": [{"name": "SERVICE_NAME", "value": "FAKE_SERVICE_NAME"}], |
| 48 | + }, |
| 49 | + } |
51 | 50 |
|
52 |
| - # connection |
53 |
| - con = snowflake.connector.connect( |
54 |
| - account=account, |
55 |
| - user=user, |
56 |
| - password="testpassword", |
57 |
| - database="TESTDB", |
58 |
| - warehouse="TESTWH", |
| 51 | + monkeypatch.setattr( |
| 52 | + snowflake.connector.network.SnowflakeRestful, "_post_request", mock_post_request |
59 | 53 | )
|
60 |
| - assert con.service_name == "FAKE_SERVICE_NAME" |
| 54 | + |
| 55 | + return request_body |
| 56 | + |
| 57 | + |
| 58 | +def test_connect_with_service_name(mock_post_requests): |
| 59 | + assert fake_connector().service_name == "FAKE_SERVICE_NAME" |
61 | 60 |
|
62 | 61 |
|
63 | 62 | @pytest.mark.skip(reason="Mock doesn't work as expected.")
|
@@ -137,38 +136,22 @@ def test_is_still_running():
|
137 | 136 |
|
138 | 137 |
|
139 | 138 | @pytest.mark.skipolddriver
|
140 |
| -@patch("snowflake.connector.network.SnowflakeRestful._post_request") |
141 |
| -def test_partner_env_var(mockSnowflakeRestfulPostRequest): |
| 139 | +def test_partner_env_var(mock_post_requests): |
142 | 140 | PARTNER_NAME = "Amanda"
|
143 | 141 |
|
144 |
| - request_body = {} |
| 142 | + with patch.dict(os.environ, {ENV_VAR_PARTNER: PARTNER_NAME}): |
| 143 | + assert fake_connector().application == PARTNER_NAME |
145 | 144 |
|
146 |
| - def mock_post_request(url, headers, json_body, **kwargs): |
147 |
| - nonlocal request_body |
148 |
| - request_body = json.loads(json_body) |
149 |
| - return { |
150 |
| - "success": True, |
151 |
| - "message": None, |
152 |
| - "data": { |
153 |
| - "token": "TOKEN", |
154 |
| - "masterToken": "MASTER_TOKEN", |
155 |
| - "idToken": None, |
156 |
| - "parameters": [{"name": "SERVICE_NAME", "value": "FAKE_SERVICE_NAME"}], |
157 |
| - }, |
158 |
| - } |
| 145 | + assert ( |
| 146 | + mock_post_requests["data"]["CLIENT_ENVIRONMENT"]["APPLICATION"] == PARTNER_NAME |
| 147 | + ) |
159 | 148 |
|
160 |
| - # POST requests mock |
161 |
| - mockSnowflakeRestfulPostRequest.side_effect = mock_post_request |
162 | 149 |
|
163 |
| - with patch.dict(os.environ, {ENV_VAR_PARTNER: PARTNER_NAME}): |
164 |
| - # connection |
165 |
| - con = snowflake.connector.connect( |
166 |
| - user="user", |
167 |
| - account="account", |
168 |
| - password="testpassword", |
169 |
| - database="TESTDB", |
170 |
| - warehouse="TESTWH", |
171 |
| - ) |
172 |
| - assert con.application == PARTNER_NAME |
| 150 | +@pytest.mark.skipolddriver |
| 151 | +def test_imported_module(mock_post_requests): |
| 152 | + with patch.dict(sys.modules, {"streamlit": "foo"}): |
| 153 | + assert fake_connector().application == "streamlit" |
173 | 154 |
|
174 |
| - assert request_body["data"]["CLIENT_ENVIRONMENT"]["APPLICATION"] == PARTNER_NAME |
| 155 | + assert ( |
| 156 | + mock_post_requests["data"]["CLIENT_ENVIRONMENT"]["APPLICATION"] == "streamlit" |
| 157 | + ) |
0 commit comments