Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 19d9668

Browse files
committed
SK-1308 Fix bug in Get method in Python SDK
- Fix bug when calling Get method with options tokens as true. - Remove redundant import statements from client file. - Add missing validation case for Get method.
1 parent 295f0ce commit 19d9668

File tree

5 files changed

+85
-118
lines changed

5 files changed

+85
-118
lines changed

skyflow/errors/_skyflow_errors.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ class SkyflowErrorMessages(Enum):
7272
INVALID_TOKEN_TYPE = "Token key has value of type %s, expected string"
7373
REDACTION_WITH_TOKENS_NOT_SUPPORTED = "Redaction cannot be used when tokens are true in options"
7474
TOKENS_GET_COLUMN_NOT_SUPPORTED = "Column_name or column_values cannot be used with tokens in options"
75+
BOTH_IDS_AND_COLUMN_DETAILS_SPECIFIED = "Both skyflow ids and column details (name and/or values) are specified in payload"
76+
7577
PARTIAL_SUCCESS = "Server returned errors, check SkyflowError.data for more"
7678

7779
VAULT_ID_INVALID_TYPE = "Expected Vault ID to be str, got %s"

skyflow/vault/_client.py

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,32 +4,19 @@
44
import json
55
import types
66
import requests
7+
import asyncio
78
from skyflow.vault._insert import getInsertRequestBody, processResponse, convertResponse
89
from skyflow.vault._update import sendUpdateRequests, createUpdateResponseBody
9-
from skyflow.vault._config import Configuration, GetOptions
10-
from skyflow.vault._config import InsertOptions, ConnectionConfig, UpdateOptions
10+
from skyflow.vault._config import Configuration, ConnectionConfig, DeleteOptions, DetokenizeOptions, GetOptions, InsertOptions, UpdateOptions, QueryOptions
1111
from skyflow.vault._connection import createRequest
1212
from skyflow.vault._detokenize import sendDetokenizeRequests, createDetokenizeResponseBody
1313
from skyflow.vault._get_by_id import sendGetByIdRequests, createGetResponseBody
1414
from skyflow.vault._get import sendGetRequests
15-
import asyncio
16-
from skyflow.errors._skyflow_errors import SkyflowError, SkyflowErrorCodes, SkyflowErrorMessages
17-
from skyflow._utils import log_info, InfoMessages, InterfaceName, getMetrics
18-
from skyflow.vault._token import tokenProviderWrapper
19-
20-
from ._delete import deleteProcessResponse
21-
from ._insert import getInsertRequestBody, processResponse, convertResponse
22-
from ._update import sendUpdateRequests, createUpdateResponseBody
23-
from ._config import Configuration, DeleteOptions, DetokenizeOptions, InsertOptions, ConnectionConfig, UpdateOptions, QueryOptions
24-
from ._connection import createRequest
25-
from ._detokenize import sendDetokenizeRequests, createDetokenizeResponseBody
26-
from ._get_by_id import sendGetByIdRequests, createGetResponseBody
27-
from ._get import sendGetRequests
28-
import asyncio
15+
from skyflow.vault._delete import deleteProcessResponse
16+
from skyflow.vault._query import getQueryRequestBody, getQueryResponse
2917
from skyflow.errors._skyflow_errors import SkyflowError, SkyflowErrorCodes, SkyflowErrorMessages
3018
from skyflow._utils import log_info, log_error, InfoMessages, InterfaceName, getMetrics
31-
from ._token import tokenProviderWrapper
32-
from ._query import getQueryRequestBody, getQueryResponse
19+
from skyflow.vault._token import tokenProviderWrapper
3320

3421
class Client:
3522
def __init__(self, config: Configuration):
@@ -109,7 +96,7 @@ def get(self, records, options: GetOptions = GetOptions()):
10996
self.storedToken, self.tokenProvider, interface)
11097
url = self._get_complete_vault_url()
11198
responses = asyncio.run(sendGetRequests(
112-
records, options,url, self.storedToken))
99+
records, options, url, self.storedToken))
113100
result, partial = createGetResponseBody(responses)
114101
if partial:
115102
raise SkyflowError(SkyflowErrorCodes.PARTIAL_SUCCESS,

skyflow/vault/_get.py

Lines changed: 29 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111

1212
interface = InterfaceName.GET.value
1313

14-
1514
def getGetRequestBody(data, options: GetOptions):
15+
requestBody = {}
1616
ids = None
1717
if "ids" in data:
1818
ids = data["ids"]
@@ -25,25 +25,28 @@ def getGetRequestBody(data, options: GetOptions):
2525
idType = str(type(id))
2626
raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_ID_TYPE.value % (
2727
idType), interface=interface)
28+
requestBody["skyflow_ids"] = ids
2829
try:
2930
table = data["table"]
3031
except KeyError:
3132
raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT,
3233
SkyflowErrorMessages.TABLE_KEY_ERROR, interface=interface)
3334
if not isinstance(table, str):
3435
tableType = str(type(table))
35-
3636
raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_TABLE_TYPE.value % (
3737
tableType), interface=interface)
38+
else:
39+
requestBody["tableName"] = table
3840

39-
if options.tokens and data.get("redaction"):
40-
raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT,
41-
SkyflowErrorMessages.REDACTION_WITH_TOKENS_NOT_SUPPORTED, interface=interface)
42-
if options.tokens and (data.get('columnName') or data.get('columnValues')):
43-
raise SkyflowError(SkyflowErrorCodes.TOKENS_GET_COLUMN_NOT_SUPPORTED,
41+
if options.tokens:
42+
if data.get("redaction"):
43+
raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT,
44+
SkyflowErrorMessages.REDACTION_WITH_TOKENS_NOT_SUPPORTED, interface=interface)
45+
if (data.get('columnName') or data.get('columnValues')):
46+
raise SkyflowError(SkyflowErrorCodes.TOKENS_GET_COLUMN_NOT_SUPPORTED,
4447
SkyflowErrorMessages.TOKENS_GET_COLUMN_NOT_SUPPORTED, interface=interface)
45-
46-
if not options.tokens:
48+
requestBody["tokenization"] = options.tokens
49+
else:
4750
try:
4851
redaction = data["redaction"]
4952
except KeyError:
@@ -53,6 +56,8 @@ def getGetRequestBody(data, options: GetOptions):
5356
redactionType = str(type(redaction))
5457
raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_REDACTION_TYPE.value % (
5558
redactionType), interface=interface)
59+
else:
60+
requestBody["redaction"] = redaction.value
5661

5762
columnName = None
5863
if "columnName" in data:
@@ -69,13 +74,17 @@ def getGetRequestBody(data, options: GetOptions):
6974
columnValuesType = str(type(columnValues))
7075
raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT, SkyflowErrorMessages.INVALID_COLUMN_VALUE.value % (
7176
columnValuesType), interface=interface)
77+
else:
78+
requestBody["column_name"] = columnName
79+
requestBody["column_values"] = columnValues
7280

7381
if (ids is None and (columnName is None or columnValues is None)):
7482
raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT,
75-
SkyflowErrorMessages.UNIQUE_COLUMN_OR_IDS_KEY_ERROR.value, interface=interface)
76-
return ids, table, redaction.value, columnName, columnValues
77-
return ids, table, "DEFAULT", None, None
78-
83+
SkyflowErrorMessages.UNIQUE_COLUMN_OR_IDS_KEY_ERROR, interface=interface)
84+
elif (ids != None and (columnName != None or columnValues != None)):
85+
raise SkyflowError(SkyflowErrorCodes.INVALID_INPUT,
86+
SkyflowErrorMessages.BOTH_IDS_AND_COLUMN_DETAILS_SPECIFIED, interface=interface)
87+
return requestBody
7988

8089
async def sendGetRequests(data, options: GetOptions, url, token):
8190
tasks = []
@@ -97,27 +106,22 @@ async def sendGetRequests(data, options: GetOptions, url, token):
97106

98107
validatedRecords = []
99108
for record in records:
100-
ids, table, redaction, columnName, columnValues = getGetRequestBody(record, options)
101-
validatedRecords.append((ids, table, redaction, columnName, columnValues))
109+
requestBody = getGetRequestBody(record, options)
110+
validatedRecords.append(requestBody)
102111
async with ClientSession() as session:
103112
for record in validatedRecords:
104-
ids, table, redaction, columnName, columnValues = record
105113
headers = {
106114
"Authorization": "Bearer " + token,
107115
"sky-metadata": json.dumps(getMetrics())
108116
}
109-
params = {"redaction": redaction}
110-
111-
if ids is not None:
112-
params["skyflow_ids"] = ids
113-
if columnName is not None:
114-
params["column_name"] = columnName
115-
params["column_values"] = columnValues
117+
table = record.pop("tableName")
118+
params = record
119+
if options.tokens:
120+
params["tokenization"] = json.dumps(record["tokenization"])
116121
task = asyncio.ensure_future(
117-
get(url, headers, params, session, record[1], options.tokens)
122+
get(url, headers, params, session, table)
118123
)
119124
tasks.append(task)
120125
await asyncio.gather(*tasks)
121126
await session.close()
122-
123127
return tasks

skyflow/vault/_get_by_id.py

Lines changed: 1 addition & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -11,27 +11,6 @@
1111

1212
interface = InterfaceName.GET_BY_ID.value
1313

14-
def encrypt_data(data, token):
15-
if token:
16-
key = Fernet.generate_key()
17-
fernet = Fernet(key)
18-
encrypted_data = data.copy()
19-
fields = encrypted_data["records"][0]["fields"]
20-
for record in encrypted_data["records"]:
21-
fields = record["fields"]
22-
for key, value in fields.items():
23-
if isinstance(value, str):
24-
encrypted_value = fernet.encrypt(value.encode()).decode()
25-
fields[key] = encrypted_value
26-
27-
serialized_data = json.dumps(encrypted_data)
28-
encrypted_bytes = serialized_data.encode()
29-
30-
return encrypted_bytes
31-
else:
32-
return data, None
33-
34-
3514
def getGetByIdRequestBody(data):
3615
try:
3716
ids = data["ids"]
@@ -98,21 +77,13 @@ async def sendGetByIdRequests(data, url, token):
9877
await session.close()
9978
return tasks
10079

101-
102-
async def get(url, headers, params, session, table,token=False):
80+
async def get(url, headers, params, session, table):
10381
async with session.get(url + "/" + table, headers=headers, params=params, ssl=False) as response:
10482
try:
105-
response_data = await response.text()
106-
107-
if token:
108-
data = json.loads(response_data)
109-
return (encrypt_data(data,token), response.status, table, response.headers['x-request-id'])
110-
11183
return (await response.read(), response.status, table, response.headers['x-request-id'])
11284
except KeyError:
11385
return (await response.read(), response.status, table)
11486

115-
11687
def createGetResponseBody(responses):
11788
result = {
11889
"records": [],

tests/vault/test_get.py

Lines changed: 47 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,14 @@
44
import unittest
55
import os
66

7-
from skyflow.errors._skyflow_errors import SkyflowError, SkyflowErrorCodes, SkyflowErrorMessages
8-
from skyflow.vault import Client, Configuration, RedactionType, GetOptions
9-
from skyflow.vault._get_by_id import encrypt_data
10-
from skyflow.service_account import generate_bearer_token
11-
from dotenv import dotenv_values
127
import warnings
138
import asyncio
149
import json
15-
from cryptography.fernet import Fernet
16-
10+
from dotenv import dotenv_values
11+
from skyflow.service_account import generate_bearer_token
12+
from skyflow.vault import Client, Configuration, RedactionType, GetOptions
13+
from skyflow.vault._get import getGetRequestBody
14+
from skyflow.errors._skyflow_errors import SkyflowError, SkyflowErrorCodes, SkyflowErrorMessages
1715

1816
class TestGet(unittest.TestCase):
1917

@@ -169,7 +167,6 @@ def testGetByIdInvalidColumnValues(self):
169167
self.assertEqual(
170168
e.message, SkyflowErrorMessages.INVALID_COLUMN_VALUE.value % (str) )
171169

172-
173170
def testGetByTokenAndRedaction(self):
174171
invalidData = {"records": [
175172
{"ids": ["123","456"],
@@ -184,8 +181,7 @@ def testGetByTokenAndRedaction(self):
184181
e.message, SkyflowErrorMessages.REDACTION_WITH_TOKENS_NOT_SUPPORTED.value)
185182

186183
def testGetByNoOptionAndRedaction(self):
187-
invalidData = {"records":[
188-
{"ids":["123","456"],"table":"newstripe"}]}
184+
invalidData = {"records":[{"ids":["123", "456"], "table":"newstripe"}]}
189185
options = GetOptions(False)
190186
try:
191187
self.client.get(invalidData,options=options)
@@ -196,10 +192,13 @@ def testGetByNoOptionAndRedaction(self):
196192
e.message,SkyflowErrorMessages.REDACTION_KEY_ERROR.value)
197193

198194
def testGetByOptionAndUniqueColumnRedaction(self):
199-
invalidData ={"records":[
200-
{"table":"newstripe","columnName":"card_number","columnValues":["456","980"],}
201-
]}
202-
195+
invalidData ={
196+
"records":[{
197+
"table":"newstripe",
198+
"columnName":"card_number",
199+
"columnValues":["456","980"],
200+
}]
201+
}
203202
options = GetOptions(True)
204203
try:
205204
self.client.get(invalidData, options=options)
@@ -210,9 +209,13 @@ def testGetByOptionAndUniqueColumnRedaction(self):
210209
e.message, SkyflowErrorMessages.TOKENS_GET_COLUMN_NOT_SUPPORTED.value)
211210

212211
def testInvalidRedactionTypeWithNoOption(self):
213-
invalidData = {"records": [
214-
{"ids": ["123","456"],
215-
"table": "stripe", "redaction": "invalid_redaction"}]}
212+
invalidData = {
213+
"records": [{
214+
"ids": ["123","456"],
215+
"table": "stripe",
216+
"redaction": "invalid_redaction"
217+
}]
218+
}
216219
options = GetOptions(False)
217220
try:
218221
self.client.get(invalidData, options=options)
@@ -221,36 +224,36 @@ def testInvalidRedactionTypeWithNoOption(self):
221224
self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value)
222225
self.assertEqual(e.message, SkyflowErrorMessages.INVALID_REDACTION_TYPE.value % (str))
223226

224-
def test_encrypt_data_with_token(self):
225-
data = {
227+
def testBothSkyflowIdsAndColumnDetailsPassed(self):
228+
invalidData = {
226229
"records": [
227230
{
228-
"fields": {
229-
"ids": ["123","456"],
230-
"table": "stripe",
231-
}
231+
"ids": ["123", "456"],
232+
"table": "stripe",
233+
"redaction": RedactionType.PLAIN_TEXT,
234+
"columnName": "email",
235+
"columnValues": ["[email protected]", "[email protected]"]
232236
}
233237
]
234238
}
235-
token = "secret_token"
236-
encrypted_bytes = encrypt_data(data, token)
237-
self.assertIsNotNone(encrypted_bytes)
238-
239-
def test_encrypt_data_without_token(self):
240-
data = {
241-
"records": [
242-
{
243-
"fields": {
244-
"ids": ["123", "456"],
245-
"table": "stripe",
246-
}
247-
}
248-
]
239+
options = GetOptions(False)
240+
try:
241+
self.client.get(invalidData, options=options)
242+
self.fail('Should have thrown an error')
243+
except SkyflowError as e:
244+
self.assertEqual(e.code, SkyflowErrorCodes.INVALID_INPUT.value)
245+
self.assertEqual(e.message, SkyflowErrorMessages.BOTH_IDS_AND_COLUMN_DETAILS_SPECIFIED.value)
246+
247+
def testGetRequestBodyReturnsRequestBodyWithIds(self):
248+
validData = {
249+
"records": [{
250+
"ids": ["123", "456"],
251+
"table": "stripe",
252+
}]
249253
}
250-
token = None
251-
encrypted_data, key = encrypt_data(data, token)
252-
self.assertEqual(encrypted_data, data)
253-
self.assertIsNone(key)
254-
255-
256-
254+
options = GetOptions(True)
255+
try:
256+
requestBody = getGetRequestBody(validData["records"][0], options)
257+
self.assertTrue(requestBody["tokenization"])
258+
except SkyflowError as e:
259+
self.fail('Should not have thrown an error')

0 commit comments

Comments
 (0)