From 55a1eae19b4eb5307847ad221c46c6043159b877 Mon Sep 17 00:00:00 2001 From: Yuqing Wei Date: Wed, 3 Aug 2022 18:09:53 +0800 Subject: [PATCH 1/2] extend rbac registry apis to post ones --- registry/access_control/api.py | 14 ++++----- registry/access_control/rbac/db_rbac.py | 39 ++++++++++++++++++++----- 2 files changed, 38 insertions(+), 15 deletions(-) diff --git a/registry/access_control/api.py b/registry/access_control/api.py index a47c9233c..3e7343e47 100644 --- a/registry/access_control/api.py +++ b/registry/access_control/api.py @@ -66,29 +66,29 @@ def get_feature_lineage(feature: str, requestor: User = Depends(get_user)) -> di @router.post("/projects", name="Create new project with definition [Auth Required]") def new_project(definition: dict, requestor: User = Depends(get_user)) -> dict: - rbac.init_userrole(requestor, definition["name"]) - response = requests.post(url=f"{registry_url}/projects", params=definition, + rbac.init_userrole(requestor.username, definition["name"]) + response = requests.post(url=f"{registry_url}/projects", json=definition, headers=get_api_header(requestor)).content.decode('utf-8') return json.loads(response) @router.post("/projects/{project}/datasources", name="Create new data source of my project [Write Access Required]") def new_project_datasource(project: str, definition: dict, requestor: User = Depends(project_write_access)) -> dict: - response = requests.post(url=f"{registry_url}/projects/{project}/datasources", params=definition, headers=get_api_header( + response = requests.post(url=f"{registry_url}/projects/{project}/datasources", json=definition, headers=get_api_header( requestor)).content.decode('utf-8') return json.loads(response) @router.post("/projects/{project}/anchors", name="Create new anchors of my project [Write Access Required]") def new_project_anchor(project: str, definition: dict, requestor: User = Depends(project_write_access)) -> dict: - response = requests.post(url=f"{registry_url}/projects/{project}/anchors", params=definition, headers=get_api_header( + response = requests.post(url=f"{registry_url}/projects/{project}/anchors", json=definition, headers=get_api_header( requestor)).content.decode('utf-8') return json.loads(response) @router.post("/projects/{project}/anchors/{anchor}/features", name="Create new anchor features of my project [Write Access Required]") def new_project_anchor_feature(project: str, anchor: str, definition: dict, requestor: User = Depends(project_write_access)) -> dict: - response = requests.post(url=f"{registry_url}/projects/{project}/anchors/{anchor}/features", params=definition, headers=get_api_header( + response = requests.post(url=f"{registry_url}/projects/{project}/anchors/{anchor}/features", json=definition, headers=get_api_header( requestor)).content.decode('utf-8') return json.loads(response) @@ -96,12 +96,10 @@ def new_project_anchor_feature(project: str, anchor: str, definition: dict, requ @router.post("/projects/{project}/derivedfeatures", name="Create new derived features of my project [Write Access Required]") def new_project_derived_feature(project: str, definition: dict, requestor: User = Depends(project_write_access)) -> dict: response = requests.post(url=f"{registry_url}/projects/{project}/derivedfeatures", - params=definition, headers=get_api_header(requestor)).content.decode('utf-8') + json=definition, headers=get_api_header(requestor)).content.decode('utf-8') return json.loads(response) # Below are access control management APIs - - @router.get("/userroles", name="List all active user role records [Project Manage Access Required]") def get_userroles(requestor: User = Depends(get_user)) -> list: return rbac.list_userroles(requestor.username) diff --git a/registry/access_control/rbac/db_rbac.py b/registry/access_control/rbac/db_rbac.py index f822a4b32..e0f1f37ed 100644 --- a/registry/access_control/rbac/db_rbac.py +++ b/registry/access_control/rbac/db_rbac.py @@ -1,3 +1,5 @@ +from fastapi import HTTPException, status +from typing import Any from rbac import config from rbac.database import connect from rbac.models import AccessType, UserRole, RoleType, SUPER_ADMIN_SCOPE @@ -5,6 +7,11 @@ import os import logging +class BadRequest(HTTPException): + def __init__(self, detail: Any = None) -> None: + super().__init__(status_code=status.HTTP_400_BAD_REQUEST, + detail=detail, headers={"WWW-Authenticate": "Bearer"}) + class DbRBAC(RBAC): def __init__(self): @@ -122,16 +129,34 @@ def delete_userrole(self, project_name: str, user_name: str, role_name: str, del self.get_userroles() return - def init_userrole(self, creator_name: str, project_name: str): - """initialize user role relationship when a new project is created - TODO: project name cannot be `global`. + def init_userrole(self, creator_name: str, project_name:str): + """Project name validation and project admin initialization + """ + # project name cannot be `global` + if project_name.casefold() == SUPER_ADMIN_SCOPE.casefold(): + raise BadRequest(f"{SUPER_ADMIN_SCOPE} is keyword for Global Admin (admin of all projects), please try other project name.") + else: + # check if project already exist (have valid rbac records) + # no 400 exception to align the registry api behaviors + query = fr"""select project_name, user_name, role_name, create_by, create_reason, create_time, delete_reason, delete_time + from userroles + where delete_reason is null and project_name ='%s'""" + rows = self.conn.query(query%(project_name)) + if len(rows) > 0: + logging.warning(f"{project_name} already exist, please pick another name.") + return + else: + # initialize project admin if project not exist: + self.init_project_admin(creator_name, project_name) + + + def init_project_admin(self, creator_name: str, project_name: str): + """initialize the creator as project admin when a new project is created """ create_by = "system" create_reason = "creator of project, get admin by default." query = fr"""insert into userroles (project_name, user_name, role_name, create_by, create_reason, create_time) values ('%s','%s','%s','%s','%s', getutcdate())""" - self.conn.update(query % (project_name, creator_name, - RoleType.ADMIN.value, create_by, create_reason)) - logging.info( - f"Userrole initialized with query: {query%(project_name, creator_name, RoleType.ADMIN.value, create_by, create_reason)}") + self.conn.update(query % (project_name, creator_name, RoleType.ADMIN.value, create_by, create_reason)) + logging.info(f"Userrole initialized with query: {query%(project_name, creator_name, RoleType.ADMIN.value, create_by, create_reason)}") return self.get_userroles() From e190d377dfe2bba43f8787a2c95827cabb28b3ee Mon Sep 17 00:00:00 2001 From: Yuqing Wei Date: Thu, 4 Aug 2022 14:33:41 +0800 Subject: [PATCH 2/2] update auth.py to support live.com CLI tokens --- registry/access_control/rbac/auth.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/registry/access_control/rbac/auth.py b/registry/access_control/rbac/auth.py index 8b7af4803..3a4403bb8 100644 --- a/registry/access_control/rbac/auth.py +++ b/registry/access_control/rbac/auth.py @@ -66,8 +66,12 @@ def _get_user_from_token(decoded_token: Mapping) -> User: elif aad_app_key in decoded_token: appid = decoded_token.get(aad_app_key) # Azure CLI User Impersonation token - if decoded_token.get("scp") == str(UserType.USER_IMPERSONATION): - username = decoded_token.get("upn") + if decoded_token.get("scp") == str(UserType.USER_IMPERSONATION.value): + if "upn" in decoded_token: + username = decoded_token.get("upn") + # live.com account token doesn't have upn + else: + username = decoded_token.get("email") type = UserType.USER_IMPERSONATION # Other AAD App token else: