From 6625ce223e69d4e4dc844cafa5ea9f1253b231c9 Mon Sep 17 00:00:00 2001 From: Yuqing Wei Date: Mon, 27 Jun 2022 20:38:48 +0800 Subject: [PATCH 1/3] access control plugin --- registry/Dockerfile | 9 + registry/access_control/README.md | 74 +++++ registry/access_control/__init__.py | 9 + registry/access_control/access.py | 49 +++ registry/access_control/auth.py | 126 ++++++++ registry/access_control/config.py | 23 ++ registry/access_control/db_rbac.py | 111 +++++++ registry/access_control/interface.py | 46 +++ registry/access_control/models.py | 94 ++++++ registry/access_control/scripts/schema.sql | 13 + registry/access_control/scripts/test_data.sql | 4 + registry/api.py | 95 ++++++ registry/common/database.py | 153 +++++++++ registry/main.py | 26 ++ registry/requirements.txt | 5 + ui/src/api/api.tsx | 227 ++++++++------ ui/src/api/mock/userrole.json | 63 ---- ui/src/components/roleManagementForm.tsx | 66 ++-- ui/src/components/userRoles.tsx | 293 ++++++++++-------- ui/src/models/model.ts | 3 + 20 files changed, 1165 insertions(+), 324 deletions(-) create mode 100644 registry/Dockerfile create mode 100644 registry/access_control/README.md create mode 100644 registry/access_control/__init__.py create mode 100644 registry/access_control/access.py create mode 100644 registry/access_control/auth.py create mode 100644 registry/access_control/config.py create mode 100644 registry/access_control/db_rbac.py create mode 100644 registry/access_control/interface.py create mode 100644 registry/access_control/models.py create mode 100644 registry/access_control/scripts/schema.sql create mode 100644 registry/access_control/scripts/test_data.sql create mode 100644 registry/api.py create mode 100644 registry/common/database.py create mode 100644 registry/main.py create mode 100644 registry/requirements.txt delete mode 100644 ui/src/api/mock/userrole.json diff --git a/registry/Dockerfile b/registry/Dockerfile new file mode 100644 index 000000000..d2647021d --- /dev/null +++ 
b/registry/Dockerfile @@ -0,0 +1,9 @@ +FROM python:3.9 + +COPY ./ /usr/src + +WORKDIR /usr/src +RUN pip install -r requirements.txt + +# Start web server +CMD [ "uvicorn","main:app","--host", "0.0.0.0", "--port", "80" ] diff --git a/registry/access_control/README.md b/registry/access_control/README.md new file mode 100644 index 000000000..3d6f4d9dc --- /dev/null +++ b/registry/access_control/README.md @@ -0,0 +1,74 @@ +# Feathr Registry Access Control Gateway Specifications + +## Registry API with Access Control Gateway + +**Access Control Gateway** is an access control **Plugin** component of the feature registry API. It can work with different types of backend registries. When a user enables this component, registry requests are validated in the gateway as shown in the flow chart below: + +```mermaid +flowchart TD + A[Get Registry API Request] --> B{Is Id Token Valid?}; + B -- No --> D[Return 401]; + B -- Yes --> C{Have Permission?}; + C -- No --> F[Return 403]; + C -- Yes --> E[Call Downstream API*]; + E --> G{API Service Available?} + G -- No --> I[Return 503] + G -- Yes --> H[Return API Results] +``` + +If the access control plugin is NOT enabled, the flow starts from **Call Downstream API*** + +## Access Control Registry API + +- For all **get** requests, check **read** permission for the target project. +- For all **post** requests, check **write** permission for the target project. +- For all **access control management** requests, check **manage** permission for the target project. +- For feature-level queries, verify access to the feature's parent project. +- Registry API requests and responses are passed through transparently. + +## Management Rules + +### Initialize `userroles` table + +Users first need to create a `userroles` table with [schema.sql](scripts/schema.sql). The process is similar to the SQL Registry `bacpac` initialization. 
+ +### Initialize `userroles` records + +In the current version, users need to manually insert the initial admin records into the `userroles` SQL table. +When the `create_registry` and `create_project` APIs are enabled, the default admin role will be assigned to the creator. +Admins can add or delete roles on the management UI page or through the management API. + +### Environment Settings + +| Variable| Description| +|---|---| +| CONNECTION_STR| Connection string of the SQL database that hosts the access control tables| +| API_BASE| Base path prefix of the API (`/api/v1` by default)| +|REGISTRY_URL| The downstream Registry API endpoint| +| AAD_INSTANCE | Set to "https://login.microsoftonline.com" by default | +| AAD_TENANT_ID| Used together with AAD_INSTANCE to build the auth URL| +|API_AUDIENCE| Used as the audience when decoding JWT tokens| + +## Notes + +The status of supported scenarios is tracked below: + +- General Foundations: + - [x] Access Control Abstract Class + - [x] API Spec Contents for Access Control Management APIs + - [x] API Spec Contents for Registry API Access Control + - [x] Separate Registry API and Access Control into different implementations + - [ ] A Dockerfile containing all required components for deployment +- SQL Implementation: + - [x] `userroles` table CRUD through FastAPI + - [x] `userroles` table schema & test data, which can be used to make a `.bacpac` file for SQL table initialization. 
+ - [x] Initialize default Project Admin role for project creator + - [ ] Initialize default Global Admin Role for workspace creator +- UI Experience + - [x] Hidden page `../management` for global admin to make CUD requests to `userroles` table + - [x] Use id token in Management API Request headers to identify requestor +- Future Enhancements: + - [ ] Functional in Feathr Client + - [ ] Support AAD Groups + - [ ] Support Other OAuth Providers + \ No newline at end of file diff --git a/registry/access_control/__init__.py b/registry/access_control/__init__.py new file mode 100644 index 000000000..695071b7d --- /dev/null +++ b/registry/access_control/__init__.py @@ -0,0 +1,9 @@ +__all__ = ["auth", "access", "models", "interface", "db_rbac"] + + +from access_control.auth import * +from access_control.access import * +from access_control.interface import RBAC +from access_control.models import * +from access_control.db_rbac import DbRBAC +from common.database import DbConnection, connect diff --git a/registry/access_control/access.py b/registry/access_control/access.py new file mode 100644 index 000000000..7f9a04638 --- /dev/null +++ b/registry/access_control/access.py @@ -0,0 +1,49 @@ +from typing import Any +from fastapi import Depends, HTTPException, status +from access_control.db_rbac import DbRBAC + +from access_control.models import AccessType, User +from access_control.auth import authorize + +""" +All Access Validation Functions. Used as FastAPI Dependencies. 
+""" + +rbac = DbRBAC() + + +class ForbiddenAccess(HTTPException): + def __init__(self, detail: Any = None) -> None: + super().__init__(status_code=status.HTTP_403_FORBIDDEN, + detail=detail, headers={"WWW-Authenticate": "Bearer"}) + + +def get_user(user: User = Depends(authorize)) -> User: + return user + + +def project_read_access(project: str, user: User = Depends(authorize)) -> User: + return _project_access(project, user, AccessType.READ) + + +def project_write_access(project: str, user: User = Depends(authorize)) -> User: + return _project_access(project, user, AccessType.WRITE) + + +def project_manage_access(project: str, user: User = Depends(authorize)) -> User: + return _project_access(project, user, AccessType.MANAGE) + + +def _project_access(project: str, user: User, access: str): + if rbac.validate_project_access_users(project, user.preferred_username, access): + return user + else: + raise ForbiddenAccess( + f"{access} privileges for project {project} required for user {user.preferred_username}") + + +def global_admin_access(user: User = Depends(authorize)): + if user.preferred_username in rbac.global_admin: + return user + else: + raise ForbiddenAccess('Admin privileges required') diff --git a/registry/access_control/auth.py b/registry/access_control/auth.py new file mode 100644 index 000000000..0139508c0 --- /dev/null +++ b/registry/access_control/auth.py @@ -0,0 +1,126 @@ +import base64 +import logging +import requests +import rsa +from typing import Any, Mapping, Optional +from fastapi import HTTPException, Request, status +from fastapi.security import OAuth2AuthorizationCodeBearer +import jwt +from jwt.exceptions import ExpiredSignatureError, PyJWKError + +import access_control.config as config +from access_control.models import User + + +log = logging.getLogger() +BEARER_TOKEN = "BEARER " + + +class InvalidAuthorization(HTTPException): + def __init__(self, detail: Any = None) -> None: + super().__init__(status_code=status.HTTP_401_UNAUTHORIZED, + 
detail=detail, headers={"WWW-Authenticate": "Bearer"}) + + +class AzureADAuth(OAuth2AuthorizationCodeBearer): + # cached AAD jwt keys + aad_jwt_keys_cache: dict = {} + + def __init__(self, aad_instance: str = config.AAD_INSTANCE, aad_tenant: str = config.AAD_TENANT_ID): + self.base_auth_url: str = f"{aad_instance}/{aad_tenant}" + super(AzureADAuth, self).__init__( + authorizationUrl=f"{self.base_auth_url}/oauth2/v2.0/authorize", + tokenUrl=f"{self.base_auth_url}/oauth2/v2.0/token", + refreshUrl=f"{self.base_auth_url}/oauth2/v2.0/token", + scheme_name="oauth2", + scopes={ + f'api://{config.API_CLIENT_ID}/access_as_user': 'Access API as user', + } + ) + + async def __call__(self, request: Request) -> User: + bearer_token: str = request.headers.get("authorization") + token = bearer_token[len(BEARER_TOKEN):] + decoded_token = self._decode_token(token) + return self._get_user_from_token(decoded_token) + + @staticmethod + def _get_user_from_token(decoded_token: Mapping) -> User: + try: + user_id = decoded_token['oid'] + except Exception as e: + logging.debug(e) + raise InvalidAuthorization( + detail='Unable to extract user details from token') + + return User( + id=user_id, + name=decoded_token.get('name', ''), + preferred_username=decoded_token.get('preferred_username', ''), + roles=decoded_token.get('roles', []) + ) + + @staticmethod + def _get_key_id(token: str) -> Optional[str]: + headers = jwt.get_unverified_header(token) + return headers['kid'] if headers and 'kid' in headers else None + + @staticmethod + def _ensure_b64padding(key: str) -> str: + """ + The base64 encoded keys are not always correctly padded, so pad with the right number of = + """ + key = key.encode('utf-8') + missing_padding = len(key) % 4 + for _ in range(missing_padding): + key = key + b'=' + return key + + def _cache_aad_keys(self) -> None: + """ + Cache all AAD JWT keys - so we don't have to make a web call each auth request + """ + response = requests.get( + 
f"{self.base_auth_url}/v2.0/.well-known/openid-configuration") + aad_metadata = response.json() if response.ok else None + jwks_uri = aad_metadata['jwks_uri'] if aad_metadata and 'jwks_uri' in aad_metadata else None + if jwks_uri: + response = requests.get(jwks_uri) + keys = response.json() if response.ok else None + if keys and 'keys' in keys: + for key in keys['keys']: + n = int.from_bytes(base64.urlsafe_b64decode( + self._ensure_b64padding(key['n'])), "big") + e = int.from_bytes(base64.urlsafe_b64decode( + self._ensure_b64padding(key['e'])), "big") + pub_key = rsa.PublicKey(n, e) + # Cache the PEM formatted public key. + AzureADAuth.aad_jwt_keys_cache[key['kid']] = pub_key.save_pkcs1( + ) + + def _get_token_key(self, key_id: str) -> str: + if key_id not in AzureADAuth.aad_jwt_keys_cache: + self._cache_aad_keys() + return AzureADAuth.aad_jwt_keys_cache[key_id] + + def _decode_token(self, token: str) -> Mapping: + key_id = self._get_key_id(token) + if not key_id: + raise InvalidAuthorization('The token does not contain kid') + key = self._get_token_key(key_id) + try: + decode = jwt.decode(token, key=key, algorithms=[ + 'RS256'], audience=config.API_AUDIENCE) + return decode + except ExpiredSignatureError as e: + logging.debug(f'The token signature has expired: {e}') + raise InvalidAuthorization('The token signature has expired') + except PyJWKError as e: + logging.debug(f'Invalid token: {e}') + raise InvalidAuthorization('The token is invalid') + except Exception as e: + logging.debug(f'Unexpected error: {e}') + raise InvalidAuthorization('Unable to decode token') + + +authorize = AzureADAuth() diff --git a/registry/access_control/config.py b/registry/access_control/config.py new file mode 100644 index 000000000..d9175ee1e --- /dev/null +++ b/registry/access_control/config.py @@ -0,0 +1,23 @@ +from starlette.config import Config + + +config = Config(".env") + +# API Settings +API_BASE: str = config("API_BASE", default = "/api/v1") + +# Authentication 
+API_CLIENT_ID: str = config( + "API_CLIENT_ID", default="db8dc4b0-202e-450c-b38d-7396ad9631a5") +AAD_TENANT_ID: str = config( + "AAD_TENANT_ID", default="72f988bf-86f1-41af-91ab-2d7cd011db47") +AAD_INSTANCE: str = config( + "AAD_INSTANCE", default="https://login.microsoftonline.com") +API_AUDIENCE: str = config( + "API_AUDIENCE", default="db8dc4b0-202e-450c-b38d-7396ad9631a5") + +# SQL Database +CONNECTION_STR: str = config("CONNECTION_STR", default= "") + +# Downstream API Endpoint +REGISTRY_URL: str = config("REGISTRY_URL", default= "https://feathr-sql-registry.azurewebsites.net/api/v1") diff --git a/registry/access_control/db_rbac.py b/registry/access_control/db_rbac.py new file mode 100644 index 000000000..53dccd51f --- /dev/null +++ b/registry/access_control/db_rbac.py @@ -0,0 +1,111 @@ +from access_control import config +from common.database import connect +from access_control.models import AccessType, UserRole, RoleType, SUPER_ADMIN_SCOPE +from access_control.interface import RBAC +import os +import logging + + +class DbRBAC(RBAC): + def __init__(self): + if not os.environ["CONNECTION_STR"]: + os.environ["CONNECTION_STR"] = config.CONNECTION_STR + self.conn = connect() + # cached user role records + self._refresh_cache() + + def _refresh_cache(self): + #TODO: add user role cache refresh interval, e.g. 
daily + self.userroles = self._get_userroles() + self.global_admin = self._get_global_admin_users() + + def _get_userroles(self) -> list[UserRole]: + """query all the active user role records in SQL table + """ + rows = self.conn.query( + fr"""select record_id, project_name, user_name, role_name, create_by, create_reason, create_time, delete_by, delete_reason, delete_time + from userroles + where delete_reason is null""") + ret = [] + for row in rows: + r = UserRole(**row) + ret.append(UserRole(**row)) + logging.info(f"{ret.__len__} user roles are get.") + return ret + + def _get_global_admin_users(self) -> list[str]: + return [u.user_name for u in self.userroles if (u.project_name == SUPER_ADMIN_SCOPE and u.role_name == RoleType.ADMIN)] + + def validate_project_access_users(self, project: str, user: str, access: str = AccessType.READ) -> bool: + for u in self.userroles: + if (u.user_name == user and u.project_name in [project, SUPER_ADMIN_SCOPE] and (access in u.access)): + return True + return False + + def get_userroles_by_user(self, user_name: str, role_name: str = None) -> list[UserRole]: + """query the active user role of certain user + """ + query = fr"""select record_id, project_name, user_name, role_name, create_by, create_reason, create_time, delete_by, delete_reason, delete_time + from userroles + where delete_reason is null and user_name ='{user_name}'""" + if role_name: + query += fr"and role_name = '{role_name}'" + rows = self.conn.query(query) + ret = [] + for row in rows: + ret.append(UserRole(**row)) + return ret + + def get_userroles_by_project(self, project_name: str, role_name: str = None) -> list[UserRole]: + """query the active user role of certain project. 
+ """ + query = fr"""select record_id, project_name, user_name, role_name, create_by, create_reason, create_time, delete_reason, delete_time + from userroles + where delete_reason is null and project_name ='{project_name}'""" + if role_name: + query += fr"and role_name = '{role_name}'" + rows = self.conn.query(query) + ret = [] + for row in rows: + ret.append(UserRole(**row)) + return ret + + def add_userrole(self, project_name: str, user_name: str, role_name: str, create_reason: str, by: str): + """insert new user role relationship into sql table + """ + # check if record already exist + for u in self.userroles: + if u.project_name == project_name and u.user_name == user_name and u.role_name == role_name: + logging.warning( + f"User {user_name} already have {role_name} role of {project_name}.") + return True + + # insert new record + query = f"""insert into userroles (project_name, user_name, role_name, create_by, create_reason, create_time) + values ('{project_name}','{user_name}','{role_name}','{by}' ,'{create_reason}', getutcdate())""" + self.conn.update(query) + self._refresh_cache() + return + + def delete_userrole(self, project_name: str, user_name: str, role_name: str, delete_reason: str, by: str): + """mark existing user role relationship as deleted with reason + """ + query = fr"""UPDATE userroles SET + [delete_by] = '{by}', + [delete_reason] = '{delete_reason}', + [delete_time] = getutcdate() + WHERE [user_name] = '{user_name}' and [project_name] = '{project_name}' and [role_name] = '{role_name}' + and [delete_time] is null""" + self.conn.update(query) + self._refresh_cache() + return + + def init_userrole(self, creator_name: str, project_name: str): + """initialize user role relationship when a new project is created + TODO: Add init user role to every new project call + """ + create_by = "system" + create_reason = "creator of project, get admin by default." 
+ query = fr"""insert into userroles (project_name, user_name, role_name, create_by, create_reason, create_time) + values ('{project_name}','{creator_name}','{RoleType.ADMIN}','{create_by}','{create_reason}', getutcdate())""" + return self.conn.update(query) diff --git a/registry/access_control/interface.py b/registry/access_control/interface.py new file mode 100644 index 000000000..e3dbbcbbf --- /dev/null +++ b/registry/access_control/interface.py @@ -0,0 +1,46 @@ +from abc import ABC, abstractmethod +from access_control.models import UserRole + + +class RBAC(ABC): + @abstractmethod + def _get_userroles(self) -> list[UserRole]: + """Get List of All User Role Records + """ + pass + + @abstractmethod + def add_userrole(self, userrole: UserRole): + """Add a Role to a User + """ + pass + + @abstractmethod + def delete_userrole(self, userrole: UserRole): + """Delete a Role of a User + """ + pass + + @abstractmethod + def init_userrole(self, creator_name: str, project_name: str): + """Default User Role Relationship when a new project is created + """ + pass + + @abstractmethod + def get_userroles_by_user(self, user_name: str) -> list[UserRole]: + """Get List of All User Role Records for a User + """ + pass + + @abstractmethod + def get_userroles_by_project(self, project_name: str) -> list[UserRole]: + """Get List of All User Role Records for a Project + """ + pass + + @abstractmethod + def validate_project_access_users(self, project: str, user: str, access: str) -> bool: + """A Common function Validate if a user has certain project access + """ + pass diff --git a/registry/access_control/models.py b/registry/access_control/models.py new file mode 100644 index 000000000..b526d4bc0 --- /dev/null +++ b/registry/access_control/models.py @@ -0,0 +1,94 @@ +from typing import List, Optional +from pydantic import BaseModel +from datetime import datetime +from enum import Enum +from numpy import number + + +class User(BaseModel): + id: str + name: str + preferred_username: str + 
roles: List[str] + + +SUPER_ADMIN_SCOPE = "global" + + +class AccessType(str, Enum): + READ = "read", + WRITE = "write", + MANAGE = "manage", + + +class RoleType(str, Enum): + ADMIN = "admin", + CONSUMER = "consumer", + PRODUCER = "producer", + DEFAULT = "default", + + +RoleAccessMapping = { + RoleType.ADMIN: ["read", "write", "manage"], + RoleType.CONSUMER: ["read"], + RoleType.PRODUCER: ["read", "write"], + RoleType.DEFAULT: [] +} + + +class UserRole(): + def __init__(self, + record_id: number, + project_name: str, + user_name: str, + role_name: str, + create_by: str, + create_reason: str, + create_time: datetime, + delete_by: Optional[str] = None, + delete_reason: Optional[str] = None, + delete_time: Optional[datetime] = None, + **kwargs): + self.record_id = record_id + self.project_name = project_name.lower() + self.user_name = user_name.lower() + self.role_name = role_name.lower() + self.create_by = create_by.lower() + self.create_reason = create_reason + self.create_time = create_time + self.delete_by = delete_by + self.delete_reason = delete_reason + self.delete_time = delete_time + self.access = RoleAccessMapping[RoleType(self.role_name)] + + def to_dict(self) -> dict: + return { + "id": str(self.record_id), + "scope": self.project_name, + "userName": self.user_name, + "roleName": str(self.role_name), + "createBy": self.create_by, + "createReason": self.create_reason, + "createTime": str(self.create_time), + "deleteBy": str(self.delete_by), + "deleteReason": self.delete_reason, + "deleteTime": str(self.delete_time), + "access": self.access + } + + +class Access(): + def __init__(self, + record_id: number, + project_name: str, + access_name: str) -> None: + self.record_id = record_id + self.project_name = project_name + self.access_name = access_name + + def to_dict(self) -> dict: + return { + "record_id": str(self.record_id), + "project_name": self.project_name, + "access": self.access_name, + } diff --git a/registry/access_control/scripts/schema.sql 
b/registry/access_control/scripts/schema.sql new file mode 100644 index 000000000..ce9a7659a --- /dev/null +++ b/registry/access_control/scripts/schema.sql @@ -0,0 +1,13 @@ +create table userroles +( + record_id int IDENTITY(1,1), + project_name varchar(50) not null, + user_name varchar(50) not null, + role_name varchar(50) not null, + create_by varchar(50) not null, + create_reason varchar(50) not null, + create_time datetime not null, + delete_by varchar(50), + delete_reason varchar(50), + delete_time datetime, +) \ No newline at end of file diff --git a/registry/access_control/scripts/test_data.sql b/registry/access_control/scripts/test_data.sql new file mode 100644 index 000000000..a048c232a --- /dev/null +++ b/registry/access_control/scripts/test_data.sql @@ -0,0 +1,4 @@ +insert into userroles (project_name, user_name, role_name, create_by, create_reason, create_time) values ('global', 'abc@microsoft.com','admin', 'test_data@microsoft.com', 'test data', getutcdate()) +insert into userroles (project_name, user_name, role_name, create_by, create_reason, create_time) values ('global', 'efg@microsoft.com','admin', 'test_data@microsoft.com', 'test data', getutcdate()) +insert into userroles (project_name, user_name, role_name, create_by, create_reason, create_time) values ('feathr_ci_registry_12_33_182947', 'efg@microsoft.com','admin','test_data@microsoft.com', 'test data', getutcdate()) +insert into userroles (project_name, user_name, role_name, create_by, create_reason, create_time) values ('feathr_ci_registry_12_33_182947', 'hij@microsoft.com','consumer', 'test_data@microsoft.com', 'test data', getutcdate()) \ No newline at end of file diff --git a/registry/api.py b/registry/api.py new file mode 100644 index 000000000..7520a064d --- /dev/null +++ b/registry/api.py @@ -0,0 +1,95 @@ +import json +from typing import Optional +from fastapi import APIRouter, Depends +import requests +from access_control.access import global_admin_access, get_user, 
project_read_access, project_write_access +from access_control.models import User +from access_control.db_rbac import DbRBAC +from access_control import config + +router = APIRouter() +rbac = DbRBAC() +registry_url = config.REGISTRY_URL + +@router.get('/projects', name="Get a list of Project Names [No Auth Required]") +async def get_projects() -> list[str]: + response = requests.get(registry_url + "/projects").content.decode('utf-8') + return json.loads(response) + + +@router.get('/projects/{project}', name="Get My Project [Read Access Required]") +async def get_project(project: str, requestor: User = Depends(project_read_access)): + response = requests.get(registry_url + "/projects/" + project).content.decode('utf-8') + return json.loads(response) + + + +@router.get("/projects/{project}/datasources", name="Get data sources of my project [Read Access Required]") +def get_project_datasources(project: str, requestor: User = Depends(project_read_access)) -> list: + response = requests.get(registry_url + "/projects/" + project + "/datasources").content.decode('utf-8') + return json.loads(response) + + +@router.get("/projects/{project}/features", name="Get features under my project [Read Access Required]") +def get_project_features(project: str, keyword: Optional[str] = None, requestor: User = Depends(project_read_access)) -> list: + response = requests.get(registry_url + "/projects/" + project + "/features").content.decode('utf-8') + return json.loads(response) + + +@router.get("/features/{feature}/{project}", name="Get a single feature by feature Id [Read Access Required]") +def get_feature(feature: str, requestor: User = Depends(project_read_access)) -> dict: + response = requests.get(registry_url + "/features/" + feature).content.decode('utf-8') + return json.loads(response) + + +@router.get("/features/{feature}/lineage/{project}", name="Get Feature Lineage [Read Access Required]") +def get_feature_lineage(feature: str, requestor: User = 
Depends(project_read_access)) -> dict: + response = requests.get(registry_url + "/features/" + feature + "/lineage").content.decode('utf-8') + return json.loads(response) + + +@router.post("/projects", name="Create new project with definition [Auth Required]") +def new_project(definition: dict, requestor: User = Depends(get_user)) -> dict: + rbac.init_userrole(requestor, definition["name"]) + response = requests.post(url = registry_url + "/projects", params=definition).content.decode('utf-8') + return json.loads(response) + +@router.post("/projects/{project}/datasources", name="Create new data source of my project [Write Access Required]") +def new_project_datasource(project: str, definition: dict, requestor: User = Depends(project_write_access)) -> dict: + response = requests.post(url = registry_url + "/projects/" + project + '/datasources', params=definition).content.decode('utf-8') + return json.loads(response) + +@router.post("/projects/{project}/anchors", name="Create new anchors of my project [Write Access Required]") +def new_project_anchor(project: str, definition: dict, requestor: User = Depends(project_write_access)) -> dict: + response = requests.post(url = registry_url + "/projects/" + project + '/datasources', params=definition).content.decode('utf-8') + return json.loads(response) + + +@router.post("/projects/{project}/anchors/{anchor}/features", name="Create new anchor features of my project [Write Access Required]") +def new_project_anchor_feature(project: str, anchor: str, definition: dict, requestor: User = Depends(project_write_access)) -> dict: + response = requests.post(url = registry_url + "/projects/" + project + '/anchors/' + anchor + '/features', params=definition).content.decode('utf-8') + return json.loads(response) + + +@router.post("/projects/{project}/derivedfeatures", name="Create new derived features of my project [Write Access Required]") +def new_project_derived_feature(project: str,definition: dict, requestor: User = 
Depends(project_write_access)) -> dict: + response = requests.post(url = registry_url + "/projects/" + project + '/derivedfeatures', params=definition).content.decode('utf-8') + return json.loads(response) + +# Below are access control management APIs + + +@router.get("/userroles", name="List all active user role records [Global Admin Required]") +def get_userroles(requestor: User = Depends(global_admin_access)) -> list: + return list([r.to_dict() for r in rbac.userroles]) + + +@router.post("/users/{user}/userroles/add", name="Add a new user role [Global Admin Required]") +def add_userrole(project: str, user: str, role: str, reason: str, requestor: User = Depends(global_admin_access)): + return rbac.add_userrole(project, user, role, reason, requestor.preferred_username) + + +@router.delete("/users/{user}/userroles/delete", name="Delete a user role [Global Admin Required]") +def delete_userrole(project: str, user: str, role: str, reason: str, requestor: User = Depends(global_admin_access)): + return rbac.delete_userrole(project, user, role, reason, requestor.preferred_username) + diff --git a/registry/common/database.py b/registry/common/database.py new file mode 100644 index 000000000..8dcbd30b0 --- /dev/null +++ b/registry/common/database.py @@ -0,0 +1,153 @@ +from abc import ABC, abstractmethod +from contextlib import contextmanager +import logging +import threading +import os +import pymssql + +# TODO: refactor common utils in registry folder +# This is a copy of database.py from sql-registry folder +# And add a `update` function to execute insert and set commands + +providers = [] + + +class DbConnection(ABC): + @abstractmethod + def query(self, sql: str, *args, **kwargs) -> list[dict]: + pass + + +def quote(id): + if isinstance(id, str): + return f"'{id}'" + else: + return ",".join([f"'{i}'" for i in id]) + + +def parse_conn_str(s: str) -> dict: + """ + TODO: Not a sound and safe implementation, but useful enough in this case + as the connection string is 
provided by users themselves. + """ + parts = dict([p.strip().split("=", 1) + for p in s.split(";") if len(p.strip()) > 0]) + server = parts["Server"].split(":")[1].split(",")[0] + return { + "host": server, + "database": parts["Initial Catalog"], + "user": parts["User ID"], + "password": parts["Password"], + # "charset": "utf-8", ## For unknown reason this causes connection failure + } + + +class MssqlConnection(DbConnection): + @staticmethod + def connect(autocommit=True): + conn_str = os.environ["CONNECTION_STR"] + if "Server=" not in conn_str: + logging.debug( + "`CONNECTION_STR` is not in ADO connection string format") + return None + params = parse_conn_str(conn_str) + if not autocommit: + params["autocommit"] = False + return MssqlConnection(params) + + def __init__(self, params): + self.params = params + self.make_connection() + self.mutex = threading.Lock() + + def make_connection(self): + self.conn = pymssql.connect(**self.params) + + def query(self, sql: str, *args, **kwargs) -> list[dict]: + """ + Make SQL query and return result + """ + logging.debug(f"SQL: `{sql}`") + # NOTE: Only one cursor is allowed at the same time + retry = 0 + while True: + try: + with self.mutex: + c = self.conn.cursor(as_dict=True) + c.execute(sql, *args, **kwargs) + return c.fetchall() + except pymssql.OperationalError: + logging.warning("Database error, retrying...") + # Reconnect + self.make_connection() + retry += 1 + if retry >= 3: + # Stop retrying + raise + pass + + def update(self, sql: str, *args, **kwargs): + retry = 0 + while True: + try: + with self.mutex: + c = self.conn.cursor(as_dict=True) + c.execute(sql, *args, **kwargs) + self.conn.commit() + return True + except pymssql.OperationalError: + logging.warning("Database error, retrying...") + # Reconnect + self.make_connection() + retry += 1 + if retry >= 3: + # Stop retrying + raise + pass + + @contextmanager + def transaction(self): + """ + Start a transaction so we can run multiple SQL in one batch. 
+ User should use `with` with the returned value, look into db_registry.py for more real usage. + + NOTE: `self.query` and `self.execute` will use a different MSSQL connection so any change made + in this transaction will *not* be visible in these calls. + + The minimal implementation could look like this if the underlying engine doesn't support transaction. + ``` + @contextmanager + def transaction(self): + try: + c = self.create_or_get_connection(...) + yield c + finally: + c.close(...) + ``` + """ + conn = None + cursor = None + try: + # As one MssqlConnection has only one connection, we need to create a new one to disable `autocommit` + conn = MssqlConnection.connect(autocommit=False).conn + cursor = conn.cursor(as_dict=True) + yield cursor + except Exception as e: + logging.warning(f"Exception: {e}") + if conn: + conn.rollback() + raise e + finally: + if conn: + conn.commit() + + +providers.append(MssqlConnection) + + +def connect(*args, **kargs): + for p in providers: + ret = p.connect(*args, **kargs) + if ret is not None: + return ret + raise RuntimeError("Cannot connect to database") diff --git a/registry/main.py b/registry/main.py new file mode 100644 index 000000000..8a3ee2093 --- /dev/null +++ b/registry/main.py @@ -0,0 +1,26 @@ +import uvicorn +from fastapi import FastAPI +from starlette.middleware.cors import CORSMiddleware +from access_control import config +from api import router as api_router + +rp = config.API_BASE + +def get_application() -> FastAPI: + application = FastAPI() + # Enables CORS + application.add_middleware(CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + application.include_router(prefix=rp, router=api_router) + return application + + +app = get_application() + +if __name__ == "__main__": + uvicorn.run("main:app", host="localhost", port=8000, reload=True) diff --git a/registry/requirements.txt b/registry/requirements.txt new file mode 100644 index 
000000000..cbdd7f02b --- /dev/null +++ b/registry/requirements.txt @@ -0,0 +1,5 @@ +pymssql +fastapi +uvicorn +pyjwt +pydantic \ No newline at end of file diff --git a/ui/src/api/api.tsx b/ui/src/api/api.tsx index 92a091e3d..6201eb17f 100644 --- a/ui/src/api/api.tsx +++ b/ui/src/api/api.tsx @@ -1,167 +1,200 @@ import Axios from "axios"; -import { DataSource, EnvConfig, Feature, FeatureLineage, Role, UserRole } from "../models/model"; -import { InteractionRequiredAuthError, PublicClientApplication } from "@azure/msal-browser"; -import mockUserRole from "./mock/userrole.json"; +import { + DataSource, + Feature, + FeatureLineage, + Role, + UserRole, +} from "../models/model"; +import {PublicClientApplication,} from "@azure/msal-browser"; import { getMsalConfig } from "../utils/utils"; const msalInstance = getMsalConfig(); const getApiBaseUrl = () => { - let endpoint = process.env.REACT_APP_API_ENDPOINT - if (!endpoint || endpoint === '') { + let endpoint = process.env.REACT_APP_API_ENDPOINT; + if (!endpoint || endpoint === "") { endpoint = window.location.protocol + "//" + window.location.host; } return endpoint + "/api/v1"; -} +}; export const fetchDataSources = async (project: string) => { - const token = await getIdToken(msalInstance); - return Axios - .get(`${ getApiBaseUrl() }/projects/${ project }/datasources?code=${ token }`, - { headers: {} }) + const axios = await authAxios(msalInstance); + return axios + .get(`${getApiBaseUrl()}/projects/${project}/datasources`, { + headers: {}, + }) .then((response) => { return response.data; - }) + }); }; export const fetchProjects = async () => { - const token = await getIdToken(msalInstance); - return Axios - .get<[]>(`${ getApiBaseUrl() }/projects?code=${ token }`, - { - headers: {} - }) - .then((response) => { - return response.data; - }) + const axios = await authAxios(msalInstance); + return axios.get<[]>(`${getApiBaseUrl()}/projects`, { + headers: {}, + }).then((response) => { + return response.data; + }); }; 
-export const fetchFeatures = async (project: string, page: number, limit: number, keyword: string) => { - const token = await getIdToken(msalInstance); - return Axios - .get(`${ getApiBaseUrl() }/projects/${ project }/features?code=${ token }`, - { - params: { 'keyword': keyword, 'page': page, 'limit': limit }, - headers: {} - }) +export const fetchFeatures = async ( + project: string, + page: number, + limit: number, + keyword: string +) => { + const axios = await authAxios(msalInstance); + return axios + .get(`${getApiBaseUrl()}/projects/${project}/features`, { + params: { keyword: keyword, page: page, limit: limit }, + headers: {}, + }) .then((response) => { return response.data; - }) + }); }; export const fetchFeature = async (project: string, featureId: string) => { - const token = await getIdToken(msalInstance); - return Axios - .get(`${ getApiBaseUrl() }/features/${ featureId }?code=${ token }`, {}) + const axios = await authAxios(msalInstance); + return axios + .get(`${getApiBaseUrl()}/features/${featureId}/${project}`, {}) .then((response) => { return response.data; - }) + }); }; export const fetchProjectLineages = async (project: string) => { - const token = await getIdToken(msalInstance); - return Axios - .get(`${ getApiBaseUrl() }/projects/${ project }?code=${ token }`, {}) + const axios = await authAxios(msalInstance); + return axios + .get(`${getApiBaseUrl()}/projects/${project}`, {}) .then((response) => { return response.data; - }) + }); }; export const fetchFeatureLineages = async (project: string) => { - const token = await getIdToken(msalInstance); - return Axios - .get(`${ getApiBaseUrl() }/features/lineage/${ project }?code=${ token }`, {}) + const axios = await authAxios(msalInstance); + return axios + .get(`${getApiBaseUrl()}/features/lineage/${project}`, {}) .then((response) => { return response.data; - }) + }); }; // Following are place-holder code export const createFeature = async (feature: Feature) => { - const token = await 
getIdToken(msalInstance); - return Axios - .post(`${ getApiBaseUrl() }/features?code=${ token }`, feature, - { - headers: { "Content-Type": "application/json;" }, - params: {}, - }).then((response) => { + const axios = await authAxios(msalInstance); + return axios + .post(`${getApiBaseUrl()}/features`, feature, { + headers: { "Content-Type": "application/json;" }, + params: {}, + }) + .then((response) => { return response; - }).catch((error) => { + }) + .catch((error) => { return error.response; }); -} +}; export const updateFeature = async (feature: Feature, id: string) => { - const token = await getIdToken(msalInstance); + const axios = await authAxios(msalInstance); feature.guid = id; - return await Axios.put(`${ getApiBaseUrl() }/features/${ id }?code=${ token }`, feature, - { + return await axios + .put(`${getApiBaseUrl()}/features/${id}`, feature, { headers: { "Content-Type": "application/json;" }, params: {}, - }).then((response) => { - return response - }).catch((error) => { - return error.response - }); + }) + .then((response) => { + return response; + }) + .catch((error) => { + return error.response; + }); }; export const listUserRole = async () => { - let data: UserRole[] = mockUserRole - return data + const token = await getIdToken(msalInstance); + console.log(token) + const axios = await authAxios(msalInstance); + return await axios + .get(`${getApiBaseUrl()}/userroles`, {}) + .then((response) => { + return response.data; + }); }; export const getUserRole = async (userName: string) => { - const token = await getIdToken(msalInstance); - return await Axios - .get(`${ getApiBaseUrl() }/user/${ userName }/userroles?code=${ token }`, {}) + const axios = await authAxios(msalInstance); + return await axios + .get(`${getApiBaseUrl()}/user/${userName}/userroles`, {}) .then((response) => { return response.data; - }) -} + }); +}; export const addUserRole = async (role: Role) => { - const token = await getIdToken(msalInstance); - return await Axios - .post(`${ 
getApiBaseUrl() }/user/${ role.userName }/userroles/new?code=${ token }`, role, - { - headers: { "Content-Type": "application/json;" }, - params: {}, - }).then((response) => { + const axios = await authAxios(msalInstance); + return await axios + .post(`${getApiBaseUrl()}/users/${role.userName}/userroles/add`, role, { + headers: { "Content-Type": "application/json;" }, + params: { + project: role.scope, + role: role.roleName, + reason: role.reason, + }, + }) + .then((response) => { return response; - }).catch((error) => { + }) + .catch((error) => { return error.response; }); -} +}; -export const deleteUserRole = async (role: Role) => { - const token = await getIdToken(msalInstance); - return await Axios - .post(`${ getApiBaseUrl() }/user/${ role.userName }/userroles/delete?code=${ token }`, role, - { - headers: { "Content-Type": "application/json;" }, - params: {}, - }).then((response) => { +export const deleteUserRole = async (userrole: UserRole) => { + const axios = await authAxios(msalInstance); + const reason = "Delete from management UI."; + return await axios + .delete(`${getApiBaseUrl()}/users/${userrole.userName}/userroles/delete`, { + headers: { "Content-Type": "application/json;" }, + params: { + project: userrole.scope, + role: userrole.roleName, + reason: reason, + }, + }) + .then((response) => { return response; - }).catch((error) => { + }) + .catch((error) => { return error.response; }); -} +}; -export const getIdToken = async (msalInstance: PublicClientApplication): Promise => { +export const getIdToken = async ( + msalInstance: PublicClientApplication +): Promise => { const activeAccount = msalInstance.getActiveAccount(); // This will only return a non-null value if you have logic somewhere else that calls the setActiveAccount API const accounts = msalInstance.getAllAccounts(); const request = { scopes: ["User.Read"], - account: activeAccount || accounts[0] + account: activeAccount || accounts[0], }; // Silently acquire an token for a given set of 
scopes. Will use cached token if available, otherwise will attempt to acquire a new token from the network via refresh token. - await msalInstance.acquireTokenSilent(request).then(response => { - return response.idToken - }).catch(error => { - if (error instanceof InteractionRequiredAuthError) { - msalInstance.acquireTokenPopup(request).then(response => { - return response.idToken - }); - } - }) - return "" -} + // A known issue may cause token expire: https://github.com/AzureAD/microsoft-authentication-library-for-js/issues/4206 + const authResult = await msalInstance.acquireTokenSilent(request); + return authResult.idToken; +}; + +export const authAxios = async (msalInstance: PublicClientApplication) => { + const token = await getIdToken(msalInstance); + return Axios.create({ + headers: { + Authorization: "Bearer " + token, + "Content-Type": "application/json", + }, + baseURL: getApiBaseUrl(), + }); +}; diff --git a/ui/src/api/mock/userrole.json b/ui/src/api/mock/userrole.json deleted file mode 100644 index 20535b8ec..000000000 --- a/ui/src/api/mock/userrole.json +++ /dev/null @@ -1,63 +0,0 @@ -[ - { - "id": 1, - "scope": "Global", - "userName": "edwinc@microsoft.com", - "roleName": "Admin", - "permissions": [ - "Read", - "Write", - "Management" - ], - "createReason": "Resource Owner", - "createTime": "2022/5/15" - }, - { - "id": 1, - "scope": "Global", - "userName": "yuqwe@microsoft.com", - "roleName": "Admin", - "permissions": [ - "Read", - "Write", - "Management" - ], - "createReason": "Test Purpose", - "createTime": "2022/5/16" - }, - { - "id": 2, - "scope": "Project A: Frontend Datasets", - "userName": "blairch@microsoft.com", - "roleName": "Producer", - "permissions": [ - "Read", - "Write" - ], - "createReason": "Project Owner", - "createTime": "2022/5/16" - }, - { - "id": 3, - "scope": "Project B: Backend Datasets", - "userName": "xuchen@microsoft.com", - "roleName": "Producer", - "permissions": [ - "Read", - "Write" - ], - "createReason": "Project Owner", 
- "createTime": "2022/5/16" - }, - { - "id": 4, - "scope": "Project B: Backend Datasets", - "userName": "yihgu@microsoft.com", - "roleName": "Consumer", - "permissions": [ - "Read" - ], - "createReason": "Data Engineering", - "createTime": "2022/5/17" - } -] \ No newline at end of file diff --git a/ui/src/components/roleManagementForm.tsx b/ui/src/components/roleManagementForm.tsx index fa8a6a4ac..ae336fd8a 100644 --- a/ui/src/components/roleManagementForm.tsx +++ b/ui/src/components/roleManagementForm.tsx @@ -1,8 +1,8 @@ -import React, { CSSProperties, useEffect, useState } from 'react'; -import { BackTop, Button, Form, Input, Select, Space } from 'antd'; +import React, { CSSProperties, useEffect, useState } from "react"; +import { BackTop, Button, Form, Input, Select, Space } from "antd"; import { Navigate } from "react-router-dom"; -import { addUserRole} from '../api'; -import { UpCircleOutlined } from '@ant-design/icons' +import { addUserRole } from "../api"; +import { UpCircleOutlined } from "@ant-design/icons"; import { Role, UserRole } from "../models/model"; type RoleManagementFormProps = { @@ -11,11 +11,14 @@ type RoleManagementFormProps = { userRole?: UserRole; }; -const Admin = "Admin" -const Producer = "Producer" -const Consumer = "Consumer" +const Admin = "admin"; +const Producer = "producer"; +const Consumer = "consumer"; -const RoleManagementForm: React.FC = ({ editMode, userRole }) => { +const RoleManagementForm: React.FC = ({ + editMode, + userRole, +}) => { const [fireRedirect] = useState(false); const [createLoading, setCreateLoading] = useState(false); @@ -33,9 +36,9 @@ const RoleManagementForm: React.FC = ({ editMode, userR const roleForm: Role = form.getFieldsValue(); await addUserRole(roleForm); setCreateLoading(false); - } + }; - const styling: CSSProperties = { width: "92%" } + const styling: CSSProperties = { width: "92%" }; return ( <>
= ({ editMode, userR - + - - - - + + + - - + + + - { fireRedirect && () } + {fireRedirect && } ); }; -export default RoleManagementForm +export default RoleManagementForm; diff --git a/ui/src/components/userRoles.tsx b/ui/src/components/userRoles.tsx index 2fe18f005..537653a8e 100644 --- a/ui/src/components/userRoles.tsx +++ b/ui/src/components/userRoles.tsx @@ -1,148 +1,167 @@ -import React, { useCallback, useEffect, useState } from 'react'; -import { useNavigate } from 'react-router-dom'; -import { Button, Modal, PageHeader, Row, Space, Table, Tag } from "antd"; +import React, { useCallback, useEffect, useState } from "react"; +import { useNavigate } from "react-router-dom"; +import { + Button, + Menu, + message, + PageHeader, + Popconfirm, + Row, + Space, + Table, + Tag, +} from "antd"; import { UserRole } from "../models/model"; -import { listUserRole } from "../api"; +import { deleteUserRole, listUserRole } from "../api"; const UserRoles: React.FC = () => { - const navigate = useNavigate(); - const [visible, setVisible] = React.useState(false); - const [confirmLoading, setConfirmLoading] = React.useState(false); - const [modalText, setModalText] = React.useState('Content of the modal'); + const navigate = useNavigate(); - const showModal = () => { - setVisible(true); - setModalText("This Role Assignment will be deleted."); - }; - const handleOk = () => { - setModalText('The modal will be closed after two seconds'); - setConfirmLoading(true); - setTimeout(() => { - setVisible(false); - setConfirmLoading(false); - }, 2000); - }; + const onDelete = async (row: UserRole) => { + console.log( + `The [${row.roleName}] Role of [${row.userName}] user role delete request is sent.` + ); + const res = await deleteUserRole(row); + if (res.status === 200) { + message.success(`Role ${row.roleName} of user ${row.userName} deleted`); + } else { + message.error("Failed to delete userrole."); + } + setLoading(false); + fetchData(); + }; - const handleCancel = () => { - 
console.log('Clicked cancel button'); - setVisible(false); - }; - const columns = [ - { - title:
Scope
, - dataIndex: 'scope', - key: 'scope', - align: 'center' as 'center', - }, - { - title:
User
, - dataIndex: 'userName', - key: 'userName', - align: 'center' as 'center', - }, - { - title:
Role
, - dataIndex: 'roleName', - key: 'roleName', - align: 'center' as 'center', - }, - { - title:
Permissions
, - key: 'permissions', - dataIndex: 'permissions', - render: (tags: any[]) => ( - <> - {tags.map(tag => { - let color = tag.length > 5 ? 'red' : 'green'; - if (tag === 'Write') color = 'blue' - return ( - - {tag.toUpperCase()} - - ); - })} - - ), - }, - { - title:
Create Reason
, - dataIndex: 'createReason', - key: 'createReason', - align: 'center' as 'center', - }, - { - title:
Create Time
, - dataIndex: 'createTime', - key: 'createTime', - align: 'center' as 'center', - }, - { - title: 'Action', - key: 'action', - render: () => ( - - - -

{modalText}

-
-
- ), - }, - ]; - const [page, setPage] = useState(1); - const [, setLoading] = useState(false); - const [tableData, setTableData] = useState(); + const columns = [ + { + title:
Scope (Project/Global)
, + dataIndex: "scope", + key: "scope", + align: "center" as "center", + }, + { + title:
Role
, + dataIndex: "roleName", + key: "roleName", + align: "center" as "center", + }, + { + title:
User
, + dataIndex: "userName", + key: "userName", + align: "center" as "center", + }, + { + title:
Permissions
, + key: "access", + dataIndex: "access", + render: (tags: any[]) => ( + <> + {tags.map((tag) => { + let color = tag.length > 5 ? "red" : "green"; + if (tag === "write") color = "blue"; + return ( + + {tag.toUpperCase()} + + ); + })} + + ), + }, + { + title:
Create By
, + dataIndex: "createBy", + key: "createBy", + align: "center" as "center", + }, + { + title:
Create Reason
, + dataIndex: "createReason", + key: "createReason", + align: "center" as "center", + }, + { + title:
Create Time
, + dataIndex: "createTime", + key: "createTime", + align: "center" as "center", + }, + { + title: "Action", + key: "action", + render: (userName: string, row: UserRole) => ( + + + + { + onDelete(row); + }} + > + Delete + + + + + ), + }, + ]; + const [page, setPage] = useState(1); + const [, setLoading] = useState(false); + const [tableData, setTableData] = useState(); - const fetchData = useCallback(async () => { - setLoading(true); - const result = await listUserRole(); - console.log(result); - setPage(page); - setTableData(result); - setLoading(false); - }, [page]) + const fetchData = useCallback(async () => { + setLoading(true); + const result = await listUserRole(); + console.log(result); + setPage(page); + setTableData(result); + setLoading(false); + }, [page]); - const onClickRoleAssign = () => { - navigate('/role-management'); - return; - } + const onClickRoleAssign = () => { + navigate("/role-management"); + return; + }; - useEffect(() => { - fetchData() - }, [fetchData]) + useEffect(() => { + fetchData(); + }, [fetchData]); - return ( -
- - -
- <> -

- Below is the mock data for now. Will connect with Management APIs. -

- -
-
-
- - - - ; - - ); -} + return ( +
+ + +
+ <> +

+ This page is protected by Feathr Access Control. Only Global Admin can retrieve management details and grant or delete user roles. +

+ +
+
+
+ + + +
; + + ); +}; export default UserRoles; diff --git a/ui/src/models/model.ts b/ui/src/models/model.ts index 71937ad73..8746c98ee 100644 --- a/ui/src/models/model.ts +++ b/ui/src/models/model.ts @@ -63,10 +63,13 @@ export interface UserRole { scope: string; userName: string; roleName: string; + createBy: string; createTime: string; createReason: string; + deleteBy: string; deleteTime?: any; deleteReason?: any; + access?: string; } export interface Role { From c4be2552cb5ff355811eb638dfa538e0c50feffd Mon Sep 17 00:00:00 2001 From: Yuqing Wei Date: Fri, 1 Jul 2022 18:02:52 +0800 Subject: [PATCH 2/3] add new router, enhance config and 401 exception --- registry/access_control/.env | 4 ++++ registry/access_control/access.py | 9 +++++++++ registry/access_control/auth.py | 12 +++++++----- registry/access_control/config.py | 7 ++++--- registry/access_control/db_rbac.py | 3 +-- registry/api.py | 15 +++++++++++---- ui/src/api/api.tsx | 5 +++-- 7 files changed, 39 insertions(+), 16 deletions(-) create mode 100644 registry/access_control/.env diff --git a/registry/access_control/.env b/registry/access_control/.env new file mode 100644 index 000000000..11635dc52 --- /dev/null +++ b/registry/access_control/.env @@ -0,0 +1,4 @@ +API_BASE=/api/v1 +AAD_TENANT_ID = common +REGISTRY_URL=https://feathr-sql-registry.azurewebsites.net/api/v1 +CONNECTION_STR= \ No newline at end of file diff --git a/registry/access_control/access.py b/registry/access_control/access.py index 7f9a04638..90aae57a8 100644 --- a/registry/access_control/access.py +++ b/registry/access_control/access.py @@ -47,3 +47,12 @@ def global_admin_access(user: User = Depends(authorize)): return user else: raise ForbiddenAccess('Admin privileges required') + +def validate_project_access_for_feature(feature:str, user:str, access:str): + project = _get_project_from_feature(feature) + _project_access(project, user, access) + + +def _get_project_from_feature(feature: str): + feature_delimiter = "__request_features__" + return 
feature.split(feature_delimiter)[0] \ No newline at end of file diff --git a/registry/access_control/auth.py b/registry/access_control/auth.py index 0139508c0..2528964cb 100644 --- a/registry/access_control/auth.py +++ b/registry/access_control/auth.py @@ -40,9 +40,12 @@ def __init__(self, aad_instance: str = config.AAD_INSTANCE, aad_tenant: str = co async def __call__(self, request: Request) -> User: bearer_token: str = request.headers.get("authorization") - token = bearer_token[len(BEARER_TOKEN):] - decoded_token = self._decode_token(token) - return self._get_user_from_token(decoded_token) + if bearer_token: + token = bearer_token[len(BEARER_TOKEN):] + decoded_token = self._decode_token(token) + return self._get_user_from_token(decoded_token) + else: + raise InvalidAuthorization(detail='No authorization token was found') @staticmethod def _get_user_from_token(decoded_token: Mapping) -> User: @@ -50,8 +53,7 @@ def _get_user_from_token(decoded_token: Mapping) -> User: user_id = decoded_token['oid'] except Exception as e: logging.debug(e) - raise InvalidAuthorization( - detail='Unable to extract user details from token') + raise InvalidAuthorization(detail='Unable to extract user details from token') return User( id=user_id, diff --git a/registry/access_control/config.py b/registry/access_control/config.py index d9175ee1e..c2b0688db 100644 --- a/registry/access_control/config.py +++ b/registry/access_control/config.py @@ -1,7 +1,8 @@ +import os from starlette.config import Config - -config = Config(".env") +env_file = os.path.join("registry", "access_control", ".env") +config = Config(os.path.abspath(env_file)) # API Settings API_BASE: str = config("API_BASE", default = "/api/v1") @@ -10,7 +11,7 @@ API_CLIENT_ID: str = config( "API_CLIENT_ID", default="db8dc4b0-202e-450c-b38d-7396ad9631a5") AAD_TENANT_ID: str = config( - "AAD_TENANT_ID", default="72f988bf-86f1-41af-91ab-2d7cd011db47") + "AAD_TENANT_ID", default="common") AAD_INSTANCE: str = config( "AAD_INSTANCE", 
default="https://login.microsoftonline.com") API_AUDIENCE: str = config( diff --git a/registry/access_control/db_rbac.py b/registry/access_control/db_rbac.py index 53dccd51f..285fd7cd6 100644 --- a/registry/access_control/db_rbac.py +++ b/registry/access_control/db_rbac.py @@ -5,10 +5,9 @@ import os import logging - class DbRBAC(RBAC): def __init__(self): - if not os.environ["CONNECTION_STR"]: + if not os.environ.get("CONNECTION_STR"): os.environ["CONNECTION_STR"] = config.CONNECTION_STR self.conn = connect() # cached user role records diff --git a/registry/api.py b/registry/api.py index 7520a064d..3e1f6d513 100644 --- a/registry/api.py +++ b/registry/api.py @@ -2,7 +2,7 @@ from typing import Optional from fastapi import APIRouter, Depends import requests -from access_control.access import global_admin_access, get_user, project_read_access, project_write_access +from access_control.access import * from access_control.models import User from access_control.db_rbac import DbRBAC from access_control import config @@ -23,7 +23,6 @@ async def get_project(project: str, requestor: User = Depends(project_read_acces return json.loads(response) - @router.get("/projects/{project}/datasources", name="Get data sources of my project [Read Access Required]") def get_project_datasources(project: str, requestor: User = Depends(project_read_access)) -> list: response = requests.get(registry_url + "/projects/" + project + "/datasources").content.decode('utf-8') @@ -36,11 +35,19 @@ def get_project_features(project: str, keyword: Optional[str] = None, requestor: return json.loads(response) -@router.get("/features/{feature}/{project}", name="Get a single feature by feature Id [Read Access Required]") -def get_feature(feature: str, requestor: User = Depends(project_read_access)) -> dict: +@router.get("/features/{feature}", name="Get a single feature by feature Id [Read Access Required]") +def get_feature(feature: str, project: str, requestor: User = Depends(project_read_access)) -> dict: 
response = requests.get(registry_url + "/features/" + feature).content.decode('utf-8') return json.loads(response) +# To make sure the consistent experience of Registry API and Access Control Plugin. +# Even if user doesn't provide the project name, the API can still work. +@router.get("/features/{feature}", name="Get a single feature by feature Id [Read Access Required]") +def get_feature(feature: str, requestor: User = Depends(get_user)) -> dict: + response = requests.get(registry_url + "/features/" + feature).content.decode('utf-8') + ret = json.loads(response) + validate_project_access_for_feature(ret["qualified_name"], requestor, AccessType.READ) + return ret @router.get("/features/{feature}/lineage/{project}", name="Get Feature Lineage [Read Access Required]") def get_feature_lineage(feature: str, requestor: User = Depends(project_read_access)) -> dict: diff --git a/ui/src/api/api.tsx b/ui/src/api/api.tsx index 6201eb17f..bcb25bd21 100644 --- a/ui/src/api/api.tsx +++ b/ui/src/api/api.tsx @@ -58,7 +58,9 @@ export const fetchFeatures = async ( export const fetchFeature = async (project: string, featureId: string) => { const axios = await authAxios(msalInstance); return axios - .get(`${getApiBaseUrl()}/features/${featureId}/${project}`, {}) + .get(`${getApiBaseUrl()}/features/${featureId}`, { + params: { project: project} + }) .then((response) => { return response.data; }); @@ -116,7 +118,6 @@ export const updateFeature = async (feature: Feature, id: string) => { export const listUserRole = async () => { const token = await getIdToken(msalInstance); - console.log(token) const axios = await authAxios(msalInstance); return await axios .get(`${getApiBaseUrl()}/userroles`, {}) From 488c185981caf8a2195ed10178538e0034f5ef6a Mon Sep 17 00:00:00 2001 From: Yuqing Wei Date: Mon, 4 Jul 2022 14:24:35 +0800 Subject: [PATCH 3/3] fix comments update sql query & enhance config --- registry/access_control/.env | 9 +++-- registry/access_control/access.py | 2 +- 
registry/access_control/auth.py | 2 +- registry/access_control/config.py | 21 ++++++----- registry/access_control/db_rbac.py | 57 +++++++++++++++++------------- 5 files changed, 51 insertions(+), 40 deletions(-) diff --git a/registry/access_control/.env b/registry/access_control/.env index 11635dc52..12f0c90bd 100644 --- a/registry/access_control/.env +++ b/registry/access_control/.env @@ -1,4 +1,7 @@ -API_BASE=/api/v1 +API_BASE = /api/v1 +API_CLIENT_ID = db8dc4b0-202e-450c-b38d-7396ad9631a5 AAD_TENANT_ID = common -REGISTRY_URL=https://feathr-sql-registry.azurewebsites.net/api/v1 -CONNECTION_STR= \ No newline at end of file +AAD_INSTANCE = https://login.microsoftonline.com +API_AUDIENCE = db8dc4b0-202e-450c-b38d-7396ad9631a5 +REGISTRY_URL = https://feathr-sql-registry.azurewebsites.net/api/v1 +CONNECTION_STR = \ No newline at end of file diff --git a/registry/access_control/access.py b/registry/access_control/access.py index 90aae57a8..b067670db 100644 --- a/registry/access_control/access.py +++ b/registry/access_control/access.py @@ -43,7 +43,7 @@ def _project_access(project: str, user: User, access: str): def global_admin_access(user: User = Depends(authorize)): - if user.preferred_username in rbac.global_admin: + if user.preferred_username in rbac.get_global_admin_users(): return user else: raise ForbiddenAccess('Admin privileges required') diff --git a/registry/access_control/auth.py b/registry/access_control/auth.py index 2528964cb..2cd11012f 100644 --- a/registry/access_control/auth.py +++ b/registry/access_control/auth.py @@ -122,7 +122,7 @@ def _decode_token(self, token: str) -> Mapping: raise InvalidAuthorization('The token is invalid') except Exception as e: logging.debug(f'Unexpected error: {e}') - raise InvalidAuthorization('Unable to decode token') + raise InvalidAuthorization('Unable to decode token, error: {e}') authorize = AzureADAuth() diff --git a/registry/access_control/config.py b/registry/access_control/config.py index c2b0688db..502a9ec66 
100644 --- a/registry/access_control/config.py +++ b/registry/access_control/config.py @@ -4,21 +4,20 @@ env_file = os.path.join("registry", "access_control", ".env") config = Config(os.path.abspath(env_file)) +def _get_config(key:str, default:str = "", config:Config = config): + return os.environ.get(key) or config.get(key, default=default) + # API Settings -API_BASE: str = config("API_BASE", default = "/api/v1") +API_BASE: str = _get_config("API_BASE", default="/api/v1") # Authentication -API_CLIENT_ID: str = config( - "API_CLIENT_ID", default="db8dc4b0-202e-450c-b38d-7396ad9631a5") -AAD_TENANT_ID: str = config( - "AAD_TENANT_ID", default="common") -AAD_INSTANCE: str = config( - "AAD_INSTANCE", default="https://login.microsoftonline.com") -API_AUDIENCE: str = config( - "API_AUDIENCE", default="db8dc4b0-202e-450c-b38d-7396ad9631a5") +API_CLIENT_ID: str = _get_config("API_CLIENT_ID", default="db8dc4b0-202e-450c-b38d-7396ad9631a5") +AAD_TENANT_ID: str = _get_config("AAD_TENANT_ID", default="common") +AAD_INSTANCE: str = _get_config("AAD_INSTANCE", default="https://login.microsoftonline.com") +API_AUDIENCE: str = _get_config("API_AUDIENCE", default="db8dc4b0-202e-450c-b38d-7396ad9631a5") # SQL Database -CONNECTION_STR: str = config("CONNECTION_STR", default= "") +CONNECTION_STR: str = _get_config("CONNECTION_STR", default= "") # Downstream API Endpoint -REGISTRY_URL: str = config("REGISTRY_URL", default= "https://feathr-sql-registry.azurewebsites.net/api/v1") +REGISTRY_URL: str = _get_config("REGISTRY_URL", default= "https://feathr-sql-registry.azurewebsites.net/api/v1") diff --git a/registry/access_control/db_rbac.py b/registry/access_control/db_rbac.py index 285fd7cd6..eed5269d3 100644 --- a/registry/access_control/db_rbac.py +++ b/registry/access_control/db_rbac.py @@ -1,3 +1,4 @@ +from linecache import cache from access_control import config from common.database import connect from access_control.models import AccessType, UserRole, RoleType, SUPER_ADMIN_SCOPE @@ 
-10,13 +11,11 @@ def __init__(self): if not os.environ.get("CONNECTION_STR"): os.environ["CONNECTION_STR"] = config.CONNECTION_STR self.conn = connect() - # cached user role records - self._refresh_cache() + self.get_userroles() - def _refresh_cache(self): - #TODO: add user role cache refresh interval, e.g. daily + def get_userroles(self): + # Cache is not supported in cluster, make sure every operation read from database. self.userroles = self._get_userroles() - self.global_admin = self._get_global_admin_users() def _get_userroles(self) -> list[UserRole]: """query all the active user role records in SQL table @@ -27,15 +26,16 @@ def _get_userroles(self) -> list[UserRole]: where delete_reason is null""") ret = [] for row in rows: - r = UserRole(**row) ret.append(UserRole(**row)) logging.info(f"{ret.__len__} user roles are get.") return ret - def _get_global_admin_users(self) -> list[str]: + def get_global_admin_users(self) -> list[str]: + self.get_userroles() return [u.user_name for u in self.userroles if (u.project_name == SUPER_ADMIN_SCOPE and u.role_name == RoleType.ADMIN)] def validate_project_access_users(self, project: str, user: str, access: str = AccessType.READ) -> bool: + self.get_userroles() for u in self.userroles: if (u.user_name == user and u.project_name in [project, SUPER_ADMIN_SCOPE] and (access in u.access)): return True @@ -46,10 +46,12 @@ def get_userroles_by_user(self, user_name: str, role_name: str = None) -> list[U """ query = fr"""select record_id, project_name, user_name, role_name, create_by, create_reason, create_time, delete_by, delete_reason, delete_time from userroles - where delete_reason is null and user_name ='{user_name}'""" + where delete_reason is null and user_name ='%s'""" if role_name: - query += fr"and role_name = '{role_name}'" - rows = self.conn.query(query) + query += fr"and role_name = '%s'" + rows = self.conn.query(query%(user_name, role_name)) + else: + rows = self.conn.query(query%(user_name)) ret = [] for row in rows: 
ret.append(UserRole(**row)) @@ -60,10 +62,12 @@ def get_userroles_by_project(self, project_name: str, role_name: str = None) -> """ query = fr"""select record_id, project_name, user_name, role_name, create_by, create_reason, create_time, delete_reason, delete_time from userroles - where delete_reason is null and project_name ='{project_name}'""" + where delete_reason is null and project_name ='%s'""" if role_name: - query += fr"and role_name = '{role_name}'" - rows = self.conn.query(query) + query += fr"and role_name = '%s'" + rows = self.conn.query(query%(project_name, role_name)) + else: + rows = self.conn.query(query%(project_name)) ret = [] for row in rows: ret.append(UserRole(**row)) @@ -73,6 +77,7 @@ def add_userrole(self, project_name: str, user_name: str, role_name: str, create """insert new user role relationship into sql table """ # check if record already exist + self.get_userroles() for u in self.userroles: if u.project_name == project_name and u.user_name == user_name and u.role_name == role_name: logging.warning( @@ -80,23 +85,25 @@ def add_userrole(self, project_name: str, user_name: str, role_name: str, create return True # insert new record - query = f"""insert into userroles (project_name, user_name, role_name, create_by, create_reason, create_time) - values ('{project_name}','{user_name}','{role_name}','{by}' ,'{create_reason}', getutcdate())""" - self.conn.update(query) - self._refresh_cache() + query = fr"""insert into userroles (project_name, user_name, role_name, create_by, create_reason, create_time) + values ('%s','%s','%s','%s' ,'%s', getutcdate())""" + self.conn.update(query%(project_name, user_name, role_name, by, create_reason)) + logging.info(f"Userrole added with query: {query%(project_name, user_name, role_name, by, create_reason)}") + self.get_userroles() return def delete_userrole(self, project_name: str, user_name: str, role_name: str, delete_reason: str, by: str): """mark existing user role relationship as deleted with reason """ 
query = fr"""UPDATE userroles SET - [delete_by] = '{by}', - [delete_reason] = '{delete_reason}', + [delete_by] = '%s', + [delete_reason] = '%s', [delete_time] = getutcdate() - WHERE [user_name] = '{user_name}' and [project_name] = '{project_name}' and [role_name] = '{role_name}' + WHERE [user_name] = '%s' and [project_name] = '%s' and [role_name] = '%s' and [delete_time] is null""" - self.conn.update(query) - self._refresh_cache() + self.conn.update(query%(by, delete_reason, user_name, project_name, role_name)) + logging.info(f"Userrole removed with query: {query%(by, delete_reason, user_name, project_name, role_name)}") + self.get_userroles() return def init_userrole(self, creator_name: str, project_name: str): @@ -106,5 +113,7 @@ def init_userrole(self, creator_name: str, project_name: str): create_by = "system" create_reason = "creator of project, get admin by default." query = fr"""insert into userroles (project_name, user_name, role_name, create_by, create_reason, create_time) - values ('{project_name}','{creator_name}','{RoleType.ADMIN}','{create_by}','{create_reason}', getutcdate())""" - return self.conn.update(query) + values ('%s','%s','%s','%s','%s', getutcdate())""" + self.conn.update(query%(project_name, creator_name, RoleType.ADMIN, create_by, create_reason)) + logging.info(f"Userrole initialized with query: {query%(project_name, creator_name, RoleType.ADMIN, create_by, create_reason)}") + return self.get_userroles()