2 changes: 1 addition & 1 deletion .github/workflows/test.yml
@@ -3,7 +3,7 @@ run-name: ${{ github.actor }} is running tests
on:
push:
branches:
- '*'
- main
pull_request:
branches:
- main
39 changes: 39 additions & 0 deletions migration_scripts/v2.py
@@ -0,0 +1,39 @@
# TODO: some smarter way of handling migrations

import apsw
import apsw.bestpractice

apsw.bestpractice.apply(apsw.bestpractice.recommended)

from millipds import static_config

with apsw.Connection(static_config.MAIN_DB_PATH) as con:
version_now, *_ = con.execute("SELECT db_version FROM config").fetchone()

assert version_now == 1

con.execute(
"""
CREATE TABLE did_cache(
did TEXT PRIMARY KEY NOT NULL,
doc TEXT,
created_at INTEGER NOT NULL,
expires_at INTEGER NOT NULL
)
"""
)

con.execute(
"""
CREATE TABLE handle_cache(
handle TEXT PRIMARY KEY NOT NULL,
did TEXT,
created_at INTEGER NOT NULL,
expires_at INTEGER NOT NULL
)
"""
)

con.execute("UPDATE config SET db_version=2")

print("v1 -> v2 Migration successful")
1 change: 1 addition & 0 deletions pyproject.toml
@@ -24,6 +24,7 @@ dependencies = [
"pyjwt[crypto]",
"cryptography",
"aiohttp",
"aiodns", # goes faster, apparently
"aiohttp-middlewares", # cors
"docopt",
"apsw",
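The "goes faster, apparently" note refers to DNS resolution: aiodns lets aiohttp resolve hostnames on the event loop via its c-ares backed AsyncResolver instead of a thread pool. Whether millipds wires this up explicitly is not shown in this diff (recent aiohttp/aiodns combinations may pick it up automatically), so the following is only an illustrative sketch of the explicit opt-in:

import aiohttp

async def make_client() -> aiohttp.ClientSession:
    # AsyncResolver requires the aiodns package; depending on the installed
    # aiohttp/aiodns versions it may already be the default resolver.
    resolver = aiohttp.AsyncResolver()
    connector = aiohttp.TCPConnector(resolver=resolver)
    return aiohttp.ClientSession(connector=connector)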
6 changes: 4 additions & 2 deletions src/millipds/__main__.py
@@ -67,7 +67,8 @@
from getpass import getpass

from docopt import docopt
import aiohttp
from .ssrf import get_ssrf_safe_client


import cbrrr

@@ -234,7 +235,8 @@ def main():
elif args["run"]:

async def run_service_with_client():
async with aiohttp.ClientSession() as client:
# TODO: option to use regular unsafe client for local dev testing
async with get_ssrf_safe_client() as client:
await service.run(
db=db,
client=client,
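The plain aiohttp.ClientSession is swapped for get_ssrf_safe_client from a new .ssrf module, which is not included in this hunk. Purely as a guess at the general shape under that assumption (the class and error handling below are illustrative, not the real module), an SSRF-safe client can be built from a resolver wrapper that refuses to hand back private, loopback, or link-local addresses:

import ipaddress
import aiohttp
from aiohttp.resolver import DefaultResolver

class _SSRFSafeResolver(DefaultResolver):
    async def resolve(self, host, port=0, family=0):
        results = await super().resolve(host, port, family)
        for res in results:
            addr = ipaddress.ip_address(res["host"])
            if addr.is_private or addr.is_loopback or addr.is_link_local:
                raise aiohttp.ClientConnectionError(f"blocked address for {host}")
        return results

def get_ssrf_safe_client() -> aiohttp.ClientSession:
    # note: a real implementation also has to worry about literal-IP URLs
    # and redirects, which never hit the resolver
    connector = aiohttp.TCPConnector(resolver=_SSRFSafeResolver())
    return aiohttp.ClientSession(connector=connector)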
8 changes: 8 additions & 0 deletions src/millipds/app_util.py
@@ -5,6 +5,7 @@
from aiohttp import web

from . import database
from .did import DIDResolver

MILLIPDS_DB = web.AppKey("MILLIPDS_DB", database.Database)
MILLIPDS_AIOHTTP_CLIENT = web.AppKey(
@@ -16,6 +17,7 @@
MILLIPDS_FIREHOSE_QUEUES_LOCK = web.AppKey(
"MILLIPDS_FIREHOSE_QUEUES_LOCK", asyncio.Lock
)
MILLIPDS_DID_RESOLVER = web.AppKey("MILLIPDS_DID_RESOLVER", DIDResolver)


# these helpers are useful for conciseness and type hinting
@@ -35,13 +37,19 @@ def get_firehose_queues_lock(req: web.Request):
return req.app[MILLIPDS_FIREHOSE_QUEUES_LOCK]


def get_did_resolver(req: web.Request):
return req.app[MILLIPDS_DID_RESOLVER]


__all__ = [
"MILLIPDS_DB",
"MILLIPDS_AIOHTTP_CLIENT",
"MILLIPDS_FIREHOSE_QUEUES",
"MILLIPDS_FIREHOSE_QUEUES_LOCK",
"MILLIPDS_DID_RESOLVER",
"get_db",
"get_client",
"get_firehose_queues",
"get_firehose_queues_lock",
"get_did_resolver",
]
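MILLIPDS_DID_RESOLVER and get_did_resolver follow the same typed AppKey pattern as the existing keys. The actual service setup that populates the key is outside this diff, so the wiring below is a hypothetical sketch (attach_did_resolver and some_handler are made-up names; the DIDResolver constructor and helpers match the code in this PR):

from aiohttp import web
import aiohttp

from millipds.app_util import MILLIPDS_DID_RESOLVER, get_db, get_did_resolver
from millipds.did import DIDResolver

def attach_did_resolver(app: web.Application, client: aiohttp.ClientSession) -> None:
    # did.py leaves slow-server protection to timeouts configured on the
    # ClientSession itself, e.g. aiohttp.ClientTimeout(total=30)
    app[MILLIPDS_DID_RESOLVER] = DIDResolver(client)

async def some_handler(request: web.Request) -> web.Response:
    resolver = get_did_resolver(request)  # typed access via the AppKey helper
    doc = await resolver.resolve_with_db_cache(get_db(request), "did:web:example.com")
    return web.json_response(doc)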
27 changes: 15 additions & 12 deletions src/millipds/appview_proxy.py
@@ -12,14 +12,6 @@
logger = logging.getLogger(__name__)


# TODO: this should be done via actual DID resolution, not hardcoded!
SERVICE_ROUTES = {
"did:web:api.bsky.chat#bsky_chat": "https://api.bsky.chat",
"did:web:discover.bsky.app#bsky_fg": "https://discover.bsky.app",
"did:plc:ar7c4by46qjdydhdevvrndac#atproto_labeler": "https://mod.bsky.app",
}


@authenticated
async def service_proxy(request: web.Request, service: Optional[str] = None):
"""
@@ -30,11 +22,22 @@ async def service_proxy(request: web.Request, service: Optional[str] = None):
logger.info(f"proxying lxm {lxm}")
db = get_db(request)
if service:
service_did = service.partition("#")[0]
service_route = SERVICE_ROUTES.get(service)
if service_route is None:
service_did, _, fragment = service.partition("#")
fragment = "#" + fragment
did_doc = await get_did_resolver(request).resolve_with_db_cache(
db, service_did
)
if did_doc is None:
return web.HTTPInternalServerError(
f"unable to resolve service {service!r}"
)
for service in did_doc.get("service", []):
if service.get("id") == fragment:
service_route = service["serviceEndpoint"]
break
else:
return web.HTTPBadRequest(text=f"unable to resolve service {service!r}")
else:
else: # fall thru to assuming bsky appview
service_did = db.config["bsky_appview_did"]
service_route = db.config["bsky_appview_pfx"]

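The hardcoded SERVICE_ROUTES table is gone: the handler now resolves the service DID and scans the DID document's "service" array for an entry whose id matches the requested fragment. Stripped of the web plumbing, the lookup amounts to the helper below (the function name is illustrative; the example document mirrors the old api.bsky.chat route and is trimmed to the fields the handler reads):

from typing import Optional

def find_service_endpoint(did_doc: dict, fragment: str) -> Optional[str]:
    for svc in did_doc.get("service", []):
        if svc.get("id") == fragment:
            return svc["serviceEndpoint"]
    return None

example_doc = {
    "id": "did:web:api.bsky.chat",
    "service": [
        {"id": "#bsky_chat", "serviceEndpoint": "https://api.bsky.chat"},
    ],
}
assert find_service_endpoint(example_doc, "#bsky_chat") == "https://api.bsky.chat"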
39 changes: 33 additions & 6 deletions src/millipds/database.py
@@ -58,19 +58,21 @@ class Database:
def __init__(self, path: str = static_config.MAIN_DB_PATH) -> None:
logger.info(f"opening database at {path}")
self.path = path
util.mkdirs_for_file(path)
if "/" in path:
util.mkdirs_for_file(path)
self.con = self.new_con()
self.pw_hasher = argon2.PasswordHasher()

try:
config_exists = self.con.execute(
"SELECT count(*) FROM sqlite_master WHERE type='table' AND name='config'"
).fetchone()[0]

if config_exists:
if self.config["db_version"] != static_config.MILLIPDS_DB_VERSION:
raise Exception(
"unrecognised db version (TODO: db migrations?!)"
)

except apsw.SQLError as e: # no such table, so lets create it
if "no such table" not in str(e):
raise
else:
with self.con:
self._init_tables()

@@ -216,6 +218,31 @@ def _init_tables(self):
"""
)

# we cache failures too, represented as a null doc (with shorter TTL)
# timestamps are unix timestamp ints, in seconds
self.con.execute(
"""
CREATE TABLE did_cache(
did TEXT PRIMARY KEY NOT NULL,
doc TEXT,
created_at INTEGER NOT NULL,
expires_at INTEGER NOT NULL
)
"""
)

# likewise, a null did represents a failed resolution
self.con.execute(
"""
CREATE TABLE handle_cache(
handle TEXT PRIMARY KEY NOT NULL,
did TEXT,
created_at INTEGER NOT NULL,
expires_at INTEGER NOT NULL
)
"""
)

def update_config(
self,
pds_pfx: Optional[str] = None,
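did_cache is consumed by the new DIDResolver below; handle_cache uses the same negative-caching convention (a NULL did marks a failed resolution and gets a shorter TTL), but its consumer is not part of this diff. A hypothetical sketch of how a handle resolver would presumably use it, mirroring resolve_with_db_cache (function names are assumptions):

import time
from typing import Optional

from millipds.database import Database

def cached_handle_lookup(db: Database, handle: str) -> Optional[tuple]:
    now = int(time.time())
    row = db.con.execute(
        "SELECT did FROM handle_cache WHERE handle=? AND ?<expires_at",
        (handle, now),
    ).fetchone()
    return row  # None = cache miss; (None,) = cached failure; (did,) = hit

def cache_handle_result(db: Database, handle: str, did: Optional[str], ttl: int, error_ttl: int) -> None:
    now = int(time.time())
    expires_at = now + (error_ttl if did is None else ttl)
    db.con.execute(
        "INSERT OR REPLACE INTO handle_cache (handle, did, created_at, expires_at) VALUES (?, ?, ?, ?)",
        (handle, did, now, expires_at),
    )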
180 changes: 180 additions & 0 deletions src/millipds/did.py
@@ -0,0 +1,180 @@
import aiohttp
import asyncio
from typing import Dict, Callable, Any, Awaitable, Optional
import re
import json
import time
import logging

from .database import Database
from . import util
from . import static_config

logger = logging.getLogger(__name__)

DIDDoc = Dict[str, Any]

"""
Security considerations for DID resolution:

- SSRF - not handled here!!! - caller must pass in an "SSRF safe" ClientSession
- Overly long DID strings (handled here via a hard limit (2KiB))
- Overly long DID document responses (handled here via a hard limit (64KiB))
- Servers that are slow to respond (handled via timeouts configured in the ClientSession)
- Non-canonically-encoded DIDs (handled here via strict regex - for now we don't support percent-encoding at all)

"""


class DIDResolver:
DID_LENGTH_LIMIT = 2048
DIDDOC_LENGTH_LIMIT = 0x10000

def __init__(
self,
session: aiohttp.ClientSession,
plc_directory_host: str = static_config.PLC_DIRECTORY_HOST,
) -> None:
self.session: aiohttp.ClientSession = session
self.plc_directory_host: str = plc_directory_host
self.did_methods: Dict[str, Callable[[str], Awaitable[DIDDoc]]] = {
"web": self.resolve_did_web,
"plc": self.resolve_did_plc,
}

self._concurrent_query_locks = util.PartitionedLock()

# keep stats for logging
self.hits = 0
self.misses = 0

# note: the uncached methods raise exceptions on failure, but this one returns None
async def resolve_with_db_cache(
self, db: Database, did: str
) -> Optional[DIDDoc]:
"""
If we fired off two concurrent queries for the same DID, the second would
be a waste of resources. By using a per-DID locking scheme, we ensure that
any subsequent queries wait for the first one to complete - by which time
the cache will be primed and the second query can return the cached result.

TODO: maybe consider an in-memory cache, too? Probably not worth it.
"""
async with self._concurrent_query_locks.get_lock(did):
# try the db first
now = int(time.time())
row = db.con.execute(
"SELECT doc FROM did_cache WHERE did=? AND ?<expires_at",
(did, now),
).fetchone()

# cache hit
if row is not None:
self.hits += 1
doc = row[0]
return None if doc is None else json.loads(doc)

# cache miss
self.misses += 1
logger.info(
f"DID cache miss for {did}. Total hits: {self.hits}, Total misses: {self.misses}"
)
try:
doc = await self.resolve_uncached(did)
logger.info(f"Successfully resolved {did}")
except Exception as e:
logger.exception(f"Error resolving {did}: {e}")
doc = None

# update "now" because resolution might've taken a while
now = int(time.time())
expires_at = now + (
static_config.DID_CACHE_ERROR_TTL
if doc is None
else static_config.DID_CACHE_TTL
)

# update the cache (note: we cache failures too, but with a shorter TTL)
# TODO: if current doc is None, only replace if the existing entry is also None
db.con.execute(
"INSERT OR REPLACE INTO did_cache (did, doc, created_at, expires_at) VALUES (?, ?, ?, ?)",
(
did,
None if doc is None else util.compact_json(doc),
now,
expires_at,
),
)

return doc

async def resolve_uncached(self, did: str) -> DIDDoc:
if len(did) > self.DID_LENGTH_LIMIT:
raise ValueError("DID too long for atproto")
scheme, method, *_ = did.split(":")
if scheme != "did":
raise ValueError("not a valid DID")
resolver = self.did_methods.get(method)
if resolver is None:
raise ValueError(f"Unsupported DID method: {method}")
return await resolver(did)

# 64k ought to be enough for anyone!
async def _get_json_with_limit(self, url: str, limit: int) -> DIDDoc:
async with self.session.get(url) as r:
r.raise_for_status()
try:
await r.content.readexactly(limit)
raise ValueError("DID document too large")
except asyncio.IncompleteReadError as e:
# this is actually the happy path
return json.loads(e.partial)

async def resolve_did_web(self, did: str) -> DIDDoc:
# TODO: support port numbers on localhost?
if not re.match(r"^did:web:[a-z0-9\.\-]+$", did):
raise ValueError("Invalid did:web")
host = did.rpartition(":")[2]

return await self._get_json_with_limit(
f"https://{host}/.well-known/did.json", self.DIDDOC_LENGTH_LIMIT
)

async def resolve_did_plc(self, did: str) -> DIDDoc:
if not re.match(r"^did:plc:[a-z2-7]+$", did): # base32-sortable
raise ValueError("Invalid did:plc")

return await self._get_json_with_limit(
f"{self.plc_directory_host}/{did}", self.DIDDOC_LENGTH_LIMIT
)


async def main() -> None:
# TODO: move these tests into a proper pytest file

async with aiohttp.ClientSession() as session:
TEST_DIDWEB = "did:web:retr0.id" # TODO: don't rely on external infra
resolver = DIDResolver(session)
print(await resolver.resolve_uncached(TEST_DIDWEB))
print(
await resolver.resolve_uncached("did:plc:vwzwgnygau7ed7b7wt5ux7y2")
)

db = Database(":memory:")
a = resolver.resolve_with_db_cache(db, TEST_DIDWEB)
b = resolver.resolve_with_db_cache(db, TEST_DIDWEB)
res_a, res_b = await asyncio.gather(a, b)
assert res_a == res_b

# if not for _concurrent_query_locks, we'd have 2 misses and 0 hits
# (because the second query would start before the first one finishes
# and primes the cache)
assert resolver.hits == 1
assert resolver.misses == 1

# check that the WeakValueDictionary is doing its thing (i.e. no memory leaks)
assert list(resolver._concurrent_query_locks._locks.keys()) == []


if __name__ == "__main__":
asyncio.run(main())
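util.PartitionedLock is used for the per-DID locking described in resolve_with_db_cache, and the final assertion in main() expects its _locks mapping to be a WeakValueDictionary so that locks for idle keys are garbage-collected. The util change itself is not in this diff; a minimal sketch consistent with that usage (get_lock(key) returning an asyncio.Lock) could look like:

import asyncio
import weakref

class PartitionedLock:
    def __init__(self) -> None:
        # weak values: once no caller holds a key's lock, it can be collected,
        # which is what the "no memory leaks" assertion above checks
        self._locks: "weakref.WeakValueDictionary[str, asyncio.Lock]" = (
            weakref.WeakValueDictionary()
        )

    def get_lock(self, key: str) -> asyncio.Lock:
        lock = self._locks.get(key)
        if lock is None:
            lock = asyncio.Lock()
            self._locks[key] = lock
        return lock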