Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 37 additions & 5 deletions skyflow/vault/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
'''
import json
import types
import typing
import requests
import asyncio
from concurrent.futures import Future, ThreadPoolExecutor
from skyflow.vault._insert import getInsertRequestBody, processResponse, convertResponse
from skyflow.vault._update import sendUpdateRequests, createUpdateResponseBody
from skyflow.vault._config import Configuration, ConnectionConfig, DeleteOptions, DetokenizeOptions, GetOptions, InsertOptions, UpdateOptions, QueryOptions
Expand Down Expand Up @@ -86,7 +88,7 @@ def detokenize(self, records: dict, options: DetokenizeOptions = DetokenizeOptio
self.storedToken = tokenProviderWrapper(
self.storedToken, self.tokenProvider, interface)
url = self._get_complete_vault_url() + '/detokenize'
responses = asyncio.run(sendDetokenizeRequests(
responses = run_coro(sendDetokenizeRequests(
records, url, self.storedToken, options))
result, partial = createDetokenizeResponseBody(records, responses, options)
if partial:
Expand All @@ -105,7 +107,7 @@ def get(self, records, options: GetOptions = GetOptions()):
self.storedToken = tokenProviderWrapper(
self.storedToken, self.tokenProvider, interface)
url = self._get_complete_vault_url()
responses = asyncio.run(sendGetRequests(
responses = run_coro(sendGetRequests(
records, options, url, self.storedToken))
result, partial = createGetResponseBody(responses)
if partial:
Expand All @@ -124,7 +126,7 @@ def get_by_id(self, records):
self.storedToken = tokenProviderWrapper(
self.storedToken, self.tokenProvider, interface)
url = self._get_complete_vault_url()
responses = asyncio.run(sendGetByIdRequests(
responses = run_coro(sendGetByIdRequests(
records, url, self.storedToken))
result, partial = createGetResponseBody(responses)
if partial:
Expand Down Expand Up @@ -201,7 +203,7 @@ def update(self, updateInput, options: UpdateOptions = UpdateOptions()):
self.storedToken = tokenProviderWrapper(
self.storedToken, self.tokenProvider, interface)
url = self._get_complete_vault_url()
responses = asyncio.run(sendUpdateRequests(
responses = run_coro(sendUpdateRequests(
updateInput, options, url, self.storedToken))
result, partial = createUpdateResponseBody(responses)
if partial:
Expand Down Expand Up @@ -290,4 +292,34 @@ def delete(self, records: dict, options: DeleteOptions = DeleteOptions()):

else:
log_info(InfoMessages.DELETE_DATA_SUCCESS.value, interface)
return result
return result


T = typing.TypeVar('T')

def run_coro(coro: typing.Coroutine[typing.Any, typing.Any, T]) -> T:
"""
Run a coroutine in a thread pool. This avoids the RuntimeError that occurs
when calling asyncio.run() from a thread that already has an event loop.

Note that this isn't performant, since it create a new thread with a new
event loop for each call.

Args:
coro: The coroutine to run.

Returns:
The result of the coroutine.
"""

try:
asyncio.get_running_loop()
except RuntimeError:
return asyncio.run(coro)

with ThreadPoolExecutor() as executor:
# Must run asyncio.run in a thread. If we don't we'll get the following
# error:
# RuntimeError: asyncio.run() cannot be called from a running event loop
future: Future[T] = executor.submit(asyncio.run, coro)
return future.result()