Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 48 additions & 9 deletions registry/sql-registry/main.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import os
from typing import Optional
from uuid import UUID
from fastapi import APIRouter, FastAPI, HTTPException
from starlette.middleware.cors import CORSMiddleware
from registry import *
from registry.db_registry import DbRegistry
from registry.models import EntityType
from registry.models import AnchorDef, AnchorFeatureDef, DerivedFeatureDef, EntityType, ProjectDef, SourceDef, to_snake

rp = "/"
try:
Expand All @@ -21,11 +22,12 @@

# Enables CORS
app.add_middleware(CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)


@router.get("/projects")
def get_projects() -> list[str]:
Expand Down Expand Up @@ -54,7 +56,8 @@ def get_project_features(project: str, keyword: Optional[str] = None) -> list:
features = registry.get_entities(feature_ids)
return list([e.to_dict() for e in features])
else:
efs = registry.search_entity(keyword, [EntityType.AnchorFeature, EntityType.DerivedFeature])
efs = registry.search_entity(
keyword, [EntityType.AnchorFeature, EntityType.DerivedFeature])
feature_ids = [ef.id for ef in efs]
features = registry.get_entities(feature_ids)
return list([e.to_dict() for e in features])
Expand All @@ -64,7 +67,8 @@ def get_project_features(project: str, keyword: Optional[str] = None) -> list:
def get_feature(feature: str) -> dict:
e = registry.get_entity(feature)
if e.entity_type not in [EntityType.DerivedFeature, EntityType.AnchorFeature]:
raise HTTPException(status_code=404, detail=f"Feature {feature} not found")
raise HTTPException(
status_code=404, detail=f"Feature {feature} not found")
return e


Expand All @@ -74,4 +78,39 @@ def get_feature_lineage(feature: str) -> dict:
return lineage.to_dict()


app.include_router(prefix = rp, router=router)
@router.post("/projects")
def new_project(definition: dict) -> UUID:
id = registry.create_project(ProjectDef(**to_snake(definition)))
return {"guid": str(id)}


@router.post("/projects/{project}/datasources")
def new_project_datasource(project: str, definition: dict) -> UUID:
project_id = registry.get_entity_id(project)
id = registry.create_project_datasource(project_id, SourceDef(**to_snake(definition)))
return {"guid": str(id)}


@router.post("/projects/{project}/anchors")
def new_project_anchor(project: str, definition: dict) -> UUID:
project_id = registry.get_entity_id(project)
id = registry.create_project_anchor(project_id, AnchorDef(**to_snake(definition)))
return {"guid": str(id)}


@router.post("/projects/{project}/anchors/{anchor}/features")
def new_project_anchor_feature(project: str, anchor: str, definition: dict) -> UUID:
project_id = registry.get_entity_id(project)
anchor_id = registry.get_entity_id(anchor)
id = registry.create_project_anchor_feature(project_id, anchor_id, AnchorFeatureDef(**to_snake(definition)))
return {"guid": str(id)}


@router.post("/projects/{project}/derivedfeatures")
def new_project_derived_feature(project: str, definition: dict) -> UUID:
project_id = registry.get_entity_id(project)
id = registry.create_project_derived_feature(project_id, DerivedFeatureDef(**to_snake(definition)))
return {"guid": str(id)}


app.include_router(prefix=rp, router=router)
52 changes: 45 additions & 7 deletions registry/sql-registry/registry/database.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from abc import ABC, abstractmethod
from contextlib import contextmanager
import logging
import threading
from distutils.log import debug, warn
import os
Expand All @@ -9,7 +11,7 @@

class DbConnection(ABC):
@abstractmethod
def execute(self, sql: str, *args, **kwargs) -> list[dict]:
def query(self, sql: str, *args, **kwargs) -> list[dict]:
pass

def quote(id):
Expand Down Expand Up @@ -38,12 +40,15 @@ def parse_conn_str(s: str) -> dict:

class MssqlConnection(DbConnection):
@staticmethod
def connect(*args, **kwargs):
def connect(autocommit = True):
conn_str = os.environ["CONNECTION_STR"]
if "Server=" not in conn_str:
debug("`CONNECTION_STR` is not in ADO connection string format")
return None
return MssqlConnection(parse_conn_str(conn_str))
params = parse_conn_str(conn_str)
if not autocommit:
params["autocommit"] = False
return MssqlConnection(params)

def __init__(self, params):
self.params = params
Expand All @@ -53,8 +58,11 @@ def __init__(self, params):
def make_connection(self):
self.conn = pymssql.connect(**self.params)

def execute(self, sql: str, *args, **kwargs) -> list[dict]:
debug(f"SQL: `{sql}`")
def query(self, sql: str, *args, **kwargs) -> list[dict]:
"""
Make SQL query and return result
"""
warn(f"SQL: `{sql}`")
# NOTE: Only one cursor is allowed at the same time
retry = 0
while True:
Expand All @@ -73,13 +81,43 @@ def execute(self, sql: str, *args, **kwargs) -> list[dict]:
raise
pass

@contextmanager
def transaction(self):
"""
Do NOT use self.query inside this block as they may reconnect
The minimal implementation could look like this if the provider doesn't support transaction
```
@contextmanager
def transaction(self):
try:
c = self.create_or_get_connection(...)
yield c
finally:
c.close(...)
```
"""
conn = None
cursor = None
try:
conn = MssqlConnection.connect(autocommit=False).conn
cursor = conn.cursor(as_dict=True)
yield cursor
except Exception as e:
logging.warning(f"Exception: {e}")
if conn:
conn.rollback()
raise e
finally:
if conn:
conn.commit()


providers.append(MssqlConnection)


def connect():
def connect(*args, **kargs):
for p in providers:
ret = p.connect()
ret = p.connect(*args, **kargs)
if ret is not None:
return ret
raise RuntimeError("Cannot connect to database")
Loading