diff --git a/.dockerignore b/.dockerignore index 15a1166..ac1bde8 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,4 +1,6 @@ **/__pycache__ +.github +.pytest_cache .venv credentials db diff --git a/Dockerfile b/Dockerfile index c1d723d..6d4ec37 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,59 +1,62 @@ -FROM python:3.12 -# FROM python:3.12-alpine +# FROM python:3.13-slim +FROM python:3.13-alpine # can't install aerospike==8.0.0 on alpine -# RUN apk update && \ -# apk upgrade && \ -# apk add --no-cache git && \ -# apk add --no-cache --virtual \ -# .build-deps \ -# gcc \ -# g++ \ -# musl-dev \ -# libffi-dev \ -# postgresql-dev \ -# build-base \ -# bash linux-headers \ -# libuv libuv-dev \ -# openssl openssl-dev \ -# lua5.1 lua5.1-dev \ -# zlib zlib-dev \ -# python3-dev \ -# exiftool - -ENV DEBIAN_FRONTEND=noninteractive - -RUN apt update && \ - apt upgrade -y && \ - apt install -y \ - build-essential \ - libssl-dev \ +RUN apk update && \ + apk upgrade && \ + apk add --no-cache git && \ + apk add --no-cache --virtual \ + .build-deps \ + gcc \ + g++ \ + musl-dev \ libffi-dev \ - git \ - jq \ - libpq-dev \ + postgresql-dev \ + build-base \ + bash linux-headers \ + libuv libuv-dev \ + openssl openssl-dev \ + lua5.1 lua5.1-dev \ + zlib zlib-dev \ python3-dev \ - libpng-dev \ - libjpeg-dev \ - libtiff-dev \ - libwebp-dev \ - imagemagick \ - libimage-exiftool-perl \ - ffmpeg + exiftool \ + jq + +# ENV DEBIAN_FRONTEND=noninteractive + +# RUN apt update && \ +# apt upgrade -y && \ +# apt install -y \ +# build-essential \ +# libssl-dev \ +# libffi-dev \ +# git \ +# jq \ +# libpq-dev \ +# python3-dev \ +# libpng-dev \ +# libjpeg-dev \ +# libtiff-dev \ +# libwebp-dev \ +# imagemagick \ +# libimage-exiftool-perl \ +# ffmpeg RUN rm -rf /var/lib/apt/lists/* WORKDIR /app COPY . /app +RUN chmod +x docker-exec.sh + RUN mkdir "images" -RUN rm -rf .venv && \ - rm -rf credentials RUN wget https://go.dev/dl/go1.22.5.linux-amd64.tar.gz && \ tar -xvf go1.22.5.linux-amd64.tar.gz -C /usr/local && \ rm go1.22.5.linux-amd64.tar.gz ENV GOROOT=/usr/local/go +# redefine $HOME to the default so that it stops complaining +ENV HOME="/root" ENV GOPATH=$HOME/go ENV PATH=$GOPATH/bin:$GOROOT/bin:$PATH @@ -67,6 +70,4 @@ ENV PATH="/opt/.venv/bin:$PATH" ENV PORT=80 ENV ENVIRONMENT=DEV -CMD jq '.fullchain' -r /etc/certs/cert.json > fullchain.pem && \ - jq '.privkey' -r /etc/certs/cert.json > privkey.pem && \ - gunicorn -w 2 -k uvicorn.workers.UvicornWorker --certfile fullchain.pem --keyfile privkey.pem -b 0.0.0.0:443 --timeout 1200 server:app +CMD ["./docker-exec.sh"] diff --git a/account/account.py b/account/account.py index bc38a8a..5d3d679 100644 --- a/account/account.py +++ b/account/account.py @@ -13,7 +13,7 @@ from shared.exceptions.http_error import BadRequest, Conflict, HttpError, HttpErrorHandler, Unauthorized from shared.hashing import Hashable from shared.models.auth import AuthToken, Scope -from shared.server import Request +from shared.models.server import Request from shared.sql import SqlInterface diff --git a/account/router.py b/account/router.py index eed29f5..4c7e8ae 100644 --- a/account/router.py +++ b/account/router.py @@ -6,15 +6,15 @@ from shared.config.constants import environment from shared.datetime import datetime from shared.exceptions.http_error import BadRequest -from shared.server import Request +from shared.models.server import Request from .account import Account, auth from .models import ChangeHandle, CreateAccountRequest, FinalizeAccountRequest, OtpFinalizeRequest, OtpRemoveEmailRequest, OtpRemoveRequest, OtpRequest app = APIRouter( - prefix='/v1/account', - tags=['account'], + prefix = '/v1/account', + tags = ['account'], ) account = Account() @@ -108,19 +108,7 @@ async def v1BotLogin(body: BotLoginRequest) -> LoginResponse : return await auth.botLogin(body.token) -@app.post('/bot/renew', response_model=BotCreateResponse) -async def v1BotRenew(req: Request) -> BotCreateResponse : - await req.user.verify_scope(Scope.internal) - return await auth.createBot(req.user, BotType.internal) - - @app.get('/bot/create', response_model=BotCreateResponse) async def v1BotCreate(req: Request) -> BotCreateResponse : await req.user.verify_scope(Scope.user) return await auth.createBot(req.user, BotType.bot) - - -@app.get('/bot/internal', response_model=BotCreateResponse) -async def v1BotCreateInternal(req: Request) -> BotCreateResponse : - await req.user.verify_scope(Scope.admin) - return await auth.createBot(req.user, BotType.internal) diff --git a/authenticator/authenticator.py b/authenticator/authenticator.py index 57d52be..670d7b2 100644 --- a/authenticator/authenticator.py +++ b/authenticator/authenticator.py @@ -7,6 +7,7 @@ from typing import Any, Awaitable, Callable, Dict, List, Optional, Self, Tuple from uuid import UUID, uuid4 +import aerospike import pyotp import ujson as json from argon2 import PasswordHasher as Argon2 @@ -26,7 +27,7 @@ from shared.exceptions.http_error import BadRequest, Conflict, FailedLogin, HttpError, InternalServerError, NotFound, UnprocessableEntity from shared.hashing import Hashable from shared.models import InternalUser -from shared.models.auth import AuthState, AuthToken, KhUser, Scope, TokenMetadata +from shared.models.auth import AuthState, AuthToken, Scope, TokenMetadata, _KhUser from shared.sql import SqlInterface from shared.timing import timed from shared.utilities.json import json_stream @@ -89,7 +90,18 @@ BotLoginSerializer: AvroSerializer = AvroSerializer(BotLogin) BotLoginDeserializer: AvroDeserializer = AvroDeserializer(BotLogin) +token_kvs: KeyValueStore = KeyValueStore('kheina', 'token') +try : + KeyValueStore._client.index_integer_create( # type: ignore + 'kheina', + 'token', + 'user_id', + 'kheina_token_user_id_idx', + ) + +except aerospike.exception.IndexFoundError : + pass class BotTypeMap(SqlInterface): @AsyncLRU(maxsize=0) @@ -130,7 +142,6 @@ async def get_id(self: Self, key: BotType) -> int : class Authenticator(SqlInterface, Hashable) : EmailRegex = re_compile(r'^(?P[A-Z0-9._%+-]+)@(?P[A-Z0-9.-]+\.[A-Z]{2,})$', flags=IGNORECASE) - KVS: KeyValueStore def __init__(self) : Hashable.__init__(self) @@ -151,9 +162,6 @@ def __init__(self) : 'id': 0, } - if not getattr(Authenticator, 'KVS', None) : - Authenticator.KVS = KeyValueStore('kheina', 'token') - def _validateEmail(self, email: str) -> Dict[str, str] : e = Authenticator.EmailRegex.search(email) @@ -201,18 +209,18 @@ async def generate_token(self, user_id: int, token_data: dict, ttl: Optional[int start = self._calc_timestamp(issued) end = start + self._key_refresh_interval self._active_private_key = { - 'key': None, + 'key': None, 'algorithm': self._token_algorithm, - 'issued': 0, - 'start': start, - 'end': end, - 'id': 0, + 'issued': 0, + 'start': start, + 'end': end, + 'id': 0, } private_key = self._active_private_key['key'] = Ed25519PrivateKey.generate() public_key = private_key.public_key().public_bytes( - encoding=serialization.Encoding.DER, - format=serialization.PublicFormat.SubjectPublicKeyInfo, + encoding = serialization.Encoding.DER, + format = serialization.PublicFormat.SubjectPublicKeyInfo, ) signature = private_key.sign(public_key) @@ -237,10 +245,10 @@ async def generate_token(self, user_id: int, token_data: dict, ttl: Optional[int # put the new key into the public keyring self._public_keyring[(self._token_algorithm, key_id)] = { - 'key': b64encode(public_key).decode(), + 'key': b64encode(public_key).decode(), 'signature': b64encode(signature).decode(), - 'issued': pk_issued, - 'expires': pk_expires, + 'issued': pk_issued, + 'expires': pk_expires, } guid: UUID = uuid4() @@ -255,16 +263,22 @@ async def generate_token(self, user_id: int, token_data: dict, ttl: Optional[int ]) token_info: TokenMetadata = TokenMetadata( - version=self._token_version.encode(), - state=AuthState.active, - issued=datetime.fromtimestamp(issued), - expires=datetime.fromtimestamp(expires), - key_id=key_id, - user_id=user_id, - algorithm=self._token_algorithm, - fingerprint=token_data.get('fp', '').encode(), + version = self._token_version.encode(), + state = AuthState.active, + issued = datetime.fromtimestamp(issued), + expires = datetime.fromtimestamp(expires), + key_id = key_id, + user_id = user_id, + algorithm = self._token_algorithm, + fingerprint = token_data.get('fp', '').encode(), + ) + await token_kvs.put_async( + guid.bytes, + token_info, + ttl or self._token_expires_interval, + # additional bins for querying active logins + { 'user_id': user_id }, ) - await Authenticator.KVS.put_async(guid.bytes, token_info, ttl or self._token_expires_interval) version = self._token_version.encode() content = b64encode(version) + b'.' + b64encode(load) @@ -272,12 +286,12 @@ async def generate_token(self, user_id: int, token_data: dict, ttl: Optional[int token = content + b'.' + b64encode(signature) return TokenResponse( - version=self._token_version, - algorithm=self._token_algorithm, # type: ignore - key_id=key_id, - issued=issued, # type: ignore - expires=expires, # type: ignore - token=token.decode(), + version = self._token_version, + algorithm = self._token_algorithm, # type: ignore + key_id = key_id, + issued = issued, # type: ignore + expires = expires, # type: ignore + token = token.decode(), ) @@ -440,7 +454,7 @@ async def login(self, email: str, password: str, otp: Optional[str], token_data: ) - async def createBot(self, user: KhUser, bot_type: BotType) -> BotCreateResponse : + async def createBot(self, user: _KhUser, bot_type: BotType) -> BotCreateResponse : if type(bot_type) is not BotType : # this should never run, thanks to pydantic/fastapi. just being extra careful. raise BadRequest('bot_type must be a BotType value.') @@ -709,11 +723,11 @@ async def create(self, handle: str, name: str, email: str, password: str, token_ raise InternalServerError('an error occurred during user creation.', logdata={ 'refid': refid }) - async def create_otp(self: Self, user: KhUser) -> str : + async def create_otp(self: Self, user: _KhUser) -> str : return pyotp.random_base32() - async def add_otp(self: Self, user: KhUser, email: str, otp_secret: str, otp: str) -> OtpAddedResponse : + async def add_otp(self: Self, user: _KhUser, email: str, otp_secret: str, otp: str) -> OtpAddedResponse : if not pyotp.TOTP(otp_secret).verify(otp) : raise BadRequest('failed to add OTP', email=email, user=user) diff --git a/avro_schema_repository/schema_repository.py b/avro_schema_repository/schema_repository.py index 1c3750f..4fd3dfc 100644 --- a/avro_schema_repository/schema_repository.py +++ b/avro_schema_repository/schema_repository.py @@ -1,18 +1,17 @@ -from typing import List +from hashlib import sha1 import ujson from avrofastapi.schema import AvroSchema -from shared.base64 import b64encode from shared.caching import AerospikeCache from shared.caching.key_value_store import KeyValueStore -from shared.crc import CRC from shared.exceptions.http_error import HttpErrorHandler, NotFound from shared.sql import SqlInterface -KVS: KeyValueStore = KeyValueStore('kheina', 'avro_schemas', local_TTL=60) -crc: CRC = CRC(64) +AvroMarker: bytes = b'\xC3\x01' +kvs: KeyValueStore = KeyValueStore('kheina', 'avro_schemas', local_TTL=60) +key_format: str = '{fingerprint}' def int_to_bytes(integer: int) -> bytes : @@ -23,24 +22,29 @@ def int_from_bytes(bytestring: bytes) -> int : return int.from_bytes(bytestring, 'little') +def crc(value: bytes) -> int : + return int.from_bytes(sha1(value).digest()[:8]) + + class SchemaRepository(SqlInterface) : @HttpErrorHandler('retrieving schema') - @AerospikeCache('kheina', 'avro_schemas', '{fingerprint}', _kvs=KVS) + @AerospikeCache('kheina', 'avro_schemas', key_format, _kvs=kvs) async def getSchema(self, fingerprint: bytes) -> bytes : """ - returns the avro schema as a json encoded string + returns the avro schema as a json encoded byte string """ fp: int = int_from_bytes(fingerprint) - data: List[bytes] = await self.query_async(""" + data: list[bytes] = await self.query_async(""" SELECT schema FROM kheina.public.avro_schemas WHERE fingerprint = %s; - """, + """, ( # because crc returns unsigned, we "convert" to signed - (fp - 9223372036854775808,), - fetch_one=True, + fp - 9223372036854775808, + ), + fetch_one = True, ) if not data : @@ -51,8 +55,11 @@ async def getSchema(self, fingerprint: bytes) -> bytes : @HttpErrorHandler('saving schema') async def addSchema(self, schema: AvroSchema) -> bytes : + """ + returns the schema fingerprint as a bytestring + """ data: bytes = ujson.dumps(schema).encode() - fingerprint: int = crc(data) + fp: int = crc(data) await self.query_async(""" INSERT INTO kheina.public.avro_schemas @@ -61,14 +68,15 @@ async def addSchema(self, schema: AvroSchema) -> bytes : (%s, %s) ON CONFLICT ON CONSTRAINT avro_schemas_pkey DO UPDATE SET - schema = %s; - """, + schema = excluded.schema; + """, ( # because crc returns unsigned, we "convert" to signed - (fingerprint - 9223372036854775808, data, data), - commit=True, + fp - 9223372036854775808, + data, + ), + commit = True, ) - fp: bytes = int_to_bytes(fingerprint) - KVS.put(b64encode(fp).decode(), schema) - - return fp + fingerprint: bytes = int_to_bytes(fp) + await kvs.put_async(key_format.format(fingerprint=fingerprint), schema) + return fingerprint diff --git a/configs/configs.py b/configs/configs.py index 5386571..98d4681 100644 --- a/configs/configs.py +++ b/configs/configs.py @@ -1,37 +1,32 @@ -from asyncio import ensure_future +from asyncio import Task, create_task from collections.abc import Iterable from datetime import datetime +from enum import Enum from random import randrange from re import Match, Pattern from re import compile as re_compile -from typing import Any, Optional, Self, Type +from typing import Literal, Optional, Self -from avrofastapi.schema import convert_schema -from avrofastapi.serialization import AvroDeserializer, AvroSerializer, Schema, parse_avro_schema -from cache import AsyncLRU +import aerospike from patreon import API as PatreonApi -from pydantic import BaseModel -from avro_schema_repository.schema_repository import SchemaRepository from shared.auth import KhUser from shared.caching import AerospikeCache from shared.caching.key_value_store import KeyValueStore from shared.config.constants import environment from shared.config.credentials import fetch from shared.exceptions.http_error import BadRequest, HttpErrorHandler, NotFound -from shared.models import Undefined +from shared.models import PostId from shared.sql import SqlInterface from shared.timing import timed +from users.repository import Repository as Users -from .models import OTP, BannerStore, ConfigsResponse, ConfigType, CostsStore, CssProperty, Funding, OtpType, UserConfig, UserConfigKeyFormat, UserConfigRequest, UserConfigResponse +from .models import OTP, BannerStore, BlockBehavior, Blocking, BlockingBehavior, ConfigsResponse, ConfigType, CostsStore, CssProperty, CssValue, Funding, OtpType, Store, Theme, UserConfigKeyFormat, UserConfigResponse, UserConfigType -repo: SchemaRepository = SchemaRepository() - PatreonClient: PatreonApi = PatreonApi(fetch('creator_access_token', str)) KVS: KeyValueStore = KeyValueStore('kheina', 'configs', local_TTL=60) -UserConfigSerializer: AvroSerializer = AvroSerializer(UserConfig) -AvroMarker: bytes = b'\xC3\x01' +users: Users = Users() ColorRegex: Pattern = re_compile(r'^(?:#(?P[a-f0-9]{8}|[a-f0-9]{6})|(?P[a-z0-9-]+))$') PropValidators: dict[CssProperty, Pattern] = { CssProperty.background_attachment: re_compile(r'^(?:scroll|fixed|local)(?:,\s*(?:scroll|fixed|local))*$'), @@ -43,30 +38,14 @@ class Configs(SqlInterface) : - UserConfigFingerprint: bytes - Serializers: dict[ConfigType, tuple[AvroSerializer, bytes]] - SerializerTypeMap: dict[ConfigType, Type[BaseModel]] = { - ConfigType.banner: BannerStore, - ConfigType.costs: CostsStore, + SerializerTypeMap: dict[Enum, type[Store]] = { + ConfigType.banner: BannerStore, + ConfigType.costs: CostsStore, + UserConfigType.blocking: Blocking, + UserConfigType.block_behavior: BlockBehavior, + UserConfigType.theme: Theme, } - async def startup(self) -> bool : - Configs.Serializers = { - ConfigType.banner: (AvroSerializer(BannerStore), await repo.addSchema(convert_schema(BannerStore))), - ConfigType.costs: (AvroSerializer(CostsStore), await repo.addSchema(convert_schema(CostsStore))), - } - self.UserConfigFingerprint = await repo.addSchema(convert_schema(UserConfig)) - assert self.Serializers.keys() == set(ConfigType.__members__.values()), 'Did you forget to add serializers for a config?' - assert self.SerializerTypeMap.keys() == set(ConfigType.__members__.values()), 'Did you forget to add serializers for a config?' - return True - - - @AsyncLRU(maxsize=32) - @staticmethod - async def getSchema(fingerprint: bytes) -> Schema : - return parse_avro_schema((await repo.getSchema(fingerprint)).decode()) - - @HttpErrorHandler('retrieving patreon campaign info') @AerospikeCache('kheina', 'configs', 'patreon-campaign-funds', TTL_minutes=10, _kvs=KVS) async def getFunding(self) -> int : @@ -78,24 +57,25 @@ async def getFunding(self) -> int : @HttpErrorHandler('retrieving config') - async def getConfigs(self, configs: Iterable[ConfigType]) -> dict[ConfigType, Any] : + async def getConfigs(self: Self, configs: Iterable[ConfigType]) -> dict[ConfigType, Store] : keys = list(configs) if not keys : return { } - cached = await KVS.get_many_async(keys) + cached = await KVS.get_many_async(keys, Store) + found: dict[ConfigType, Store] = { } misses: list[ConfigType] = [] - for k, v in list(cached.items()) : - if v is not Undefined : + for k, v in cached.items() : + if isinstance(v, Store) : + found[k] = v continue misses.append(k) - del cached[k] if not misses : - return cached + return found data: Optional[list[tuple[str, bytes]]] = await self.query_async(""" SELECT key, bytes @@ -111,63 +91,61 @@ async def getConfigs(self, configs: Iterable[ConfigType]) -> dict[ConfigType, An raise NotFound('no data was found for the provided config.') for k, v in data : - v: bytes = bytes(v) - assert v[:2] == AvroMarker - config: ConfigType = ConfigType(k) - deserializer: AvroDeserializer = AvroDeserializer( - read_model = self.SerializerTypeMap[config], - write_model = await Configs.getSchema(v[2:10]), - ) - value = cached[config] = deserializer(v[10:]) - ensure_future(KVS.put_async(config, value)) + value = found[config] = await self.SerializerTypeMap[config].deserialize(bytes(v)) + create_task(KVS.put_async(config, value)) - return cached + return found async def allConfigs(self: Self) -> ConfigsResponse : - funds = ensure_future(self.getFunding()) - configs = await self.getConfigs(self.SerializerTypeMap.keys()) + funds = create_task(self.getFunding()) + configs = await self.getConfigs([ + ConfigType.banner, + ConfigType.costs, + ]) + banner = configs[ConfigType.banner] + assert isinstance(banner, BannerStore), f'banner is not the expected type of BannerStore, got: {type(banner)}' + costs = configs[ConfigType.costs] + assert isinstance(costs, CostsStore), f'costs is not the expected type of CostsStore, got: {type(costs)}' return ConfigsResponse( - banner = configs[ConfigType.banner].banner, + banner = banner.banner, funding = Funding( funds = await funds, - costs = configs[ConfigType.costs].costs, + costs = costs.costs, ), ) @HttpErrorHandler('updating config') - async def updateConfig(self, user: KhUser, config: ConfigType, value: BaseModel) -> None : - serializer: tuple[AvroSerializer, bytes] = self.Serializers[config] - data: bytes = AvroMarker + serializer[1] + serializer[0](value) + async def updateConfig(self: Self, user: KhUser, config: Store) -> None : await self.query_async(""" - INSERT INTO kheina.public.configs + insert into kheina.public.configs (key, bytes, updated_by) - VALUES - (%s, %s, %s) - ON CONFLICT ON CONSTRAINT configs_pkey DO - UPDATE SET - updated = now(), - bytes = %s, - updated_by = %s; - """, - ( - config, data, user.user_id, - data, user.user_id, + values + ( %s, %s, %s) + on conflict on constraint configs_pkey do + update set + updated = now(), + bytes = excluded.bytes, + updated_by = excluded.updated_by + where key = excluded.key; + """, ( + config.key(), + await config.serialize(), + user.user_id, ), - commit=True, + commit = True, ) - print(config, value) - await KVS.put_async(config, value) + await KVS.put_async(config.key(), config) @staticmethod - def _validateColors(css_properties: Optional[dict[CssProperty, str]]) -> Optional[dict[str, str | int]] : + def _validateColors(css_properties: Optional[dict[CssProperty, str]]) -> Optional[dict[str, CssValue | int | str]] : if not css_properties : return None - output: dict[str, str | int] = { } + output: dict[str, CssValue | int | str] = { } # color input is very strict for prop, value in css_properties.items() : @@ -197,8 +175,8 @@ def _validateColors(css_properties: Optional[dict[CssProperty, str]]) -> Optiona else : c: str = match.group('var').replace('-', '_') - if c in CssProperty._member_map_ : - output[color.value] = c + if c in CssValue._member_map_ : + output[color.value] = CssValue(c) else : raise BadRequest(f'{value} is not a valid color. value must be in the form "#xxxxxx", "#xxxxxxxx", or the name of another color variable (without the preceding deshes)') @@ -207,61 +185,128 @@ def _validateColors(css_properties: Optional[dict[CssProperty, str]]) -> Optiona @HttpErrorHandler('saving user config') - async def setUserConfig(self, user: KhUser, value: UserConfigRequest) -> None : - user_config: UserConfig = UserConfig( - blocking_behavior=value.blocking_behavior, - blocked_tags=list(map(list, value.blocked_tags)) if value.blocked_tags else None, - # TODO: internal tokens need to be added so that we can convert handles to user ids - blocked_users=None, - wallpaper=value.wallpaper, - css_properties=Configs._validateColors(value.css_properties), - ) - - data: bytes = AvroMarker + self.UserConfigFingerprint + UserConfigSerializer(user_config) - config_key: str = UserConfigKeyFormat.format(user_id=user.user_id) - await self.query_async(""" - INSERT INTO kheina.public.configs - (key, bytes, updated_by) - VALUES - (%s, %s, %s) - ON CONFLICT ON CONSTRAINT configs_pkey DO - UPDATE SET - updated = now(), - bytes = %s, - updated_by = %s; - """, ( - config_key, data, user.user_id, - data, user.user_id, - ), - commit=True, + async def setUserConfig( + self: Self, + user: KhUser, + blocking_behavior: BlockingBehavior | None = None, + blocked_tags: list[set[str]] | None = None, + blocked_users: list[str] | None = None, + wallpaper: PostId | None | Literal[False] = False, + css_properties: dict[CssProperty, str] | None | Literal[False] = False, + ) -> None : + stores: list[Store] = [] + + if blocking_behavior : + stores.append(BlockBehavior( + behavior = blocking_behavior, + )) + + if blocked_tags is not None or blocked_users is not None : + blocking = await self._getUserConfig(user.user_id, Blocking) + + if blocked_tags is not None : + blocking.tags = list(map(list, blocked_tags)) + + if blocked_users is not None : + blocking.users = list((await users._handles_to_user_ids(blocked_users)).values()) + + if len(blocking.users) != len(blocked_users) : + raise BadRequest('could not find users for some or all of the provided handles') + + stores.append(blocking) + + if wallpaper is not False or css_properties is not False : + theme = await self._getUserConfig(user.user_id, Theme) + + if wallpaper is not False : + theme.wallpaper = wallpaper + + if css_properties is not False : + theme.css_properties = self._validateColors(css_properties) + + stores.append(theme) + + if not stores : + raise BadRequest('must submit at least one config to update') + + query: list[str] = [] + params: list[int | str | bytes] = [] + + for store in stores : + query.append('(%s, %s, %s, %s)') + params += [ + user.user_id, + store.key(), + store.type_(), + await store.serialize(), + ] + + await self.query_async(f""" + insert into kheina.public.user_configs + (user_id, key, type, data) + values + {','.join(query)} + on conflict on constraint user_configs_pkey do + update set + type = excluded.type, + data = excluded.data + where user_configs.user_id = excluded.user_id + and user_configs.key = excluded.key; + """, + tuple(params), + commit = True, ) - await KVS.put_async(config_key, user_config) + for store in stores : + create_task(KVS.put_async( + UserConfigKeyFormat.format( + user_id = user.user_id, + key = store.key(), + ), + store, + )) + + + async def _getUserConfig[T: Store](self: Self, user_id: int, type_: type[T]) -> T : + try : + return await KVS.get_async( + UserConfigKeyFormat.format( + user_id = user_id, + key = type_.key(), + ), + type = type_, + ) + except aerospike.exception.RecordNotFound : + pass - @AerospikeCache('kheina', 'configs', UserConfigKeyFormat, _kvs=KVS) - async def _getUserConfig(self, user_id: int) -> UserConfig : data: list[bytes] = await self.query_async(""" - SELECT bytes - FROM kheina.public.configs - WHERE key = %s; - """, - (UserConfigKeyFormat.format(user_id=user_id),), - fetch_one=True, + select data + from kheina.public.user_configs + where user_id = %s + and key = %s; + """, ( + user_id, + type_.key(), + ), + fetch_one = True, ) if not data : - return UserConfig() + return type_() - value: bytes = bytes(data[0]) - assert value[:2] == AvroMarker - - deserializer: AvroDeserializer[UserConfig] = AvroDeserializer(read_model=UserConfig, write_model=await Configs.getSchema(value[2:10])) - return deserializer(value[10:]) + await KVS.put_async( + UserConfigKeyFormat.format( + user_id = user_id, + key = type_.key(), + ), + res := await type_.deserialize(data[0]), + ) + return res @timed - async def _getUserOTP(self: Self, user_id: int) -> Optional[list[OTP]] : + async def _getUserOTP(self: Self, user_id: int) -> list[OTP] : data: list[tuple[datetime, str]] = await self.query_async(""" select created, 'totp' from kheina.auth.otp @@ -273,7 +318,7 @@ async def _getUserOTP(self: Self, user_id: int) -> Optional[list[OTP]] : ) if not data : - return None + return [] return [ OTP( @@ -285,35 +330,55 @@ async def _getUserOTP(self: Self, user_id: int) -> Optional[list[OTP]] : @HttpErrorHandler('retrieving user config') - async def getUserConfig(self, user: KhUser) -> UserConfigResponse : - user_config: UserConfig = await self._getUserConfig(user.user_id) - - return UserConfigResponse( - blocking_behavior = user_config.blocking_behavior, - blocked_tags = list(map(set, user_config.blocked_tags)) if user_config.blocked_tags else [], - # TODO: convert user ids to handles - blocked_users = None, - wallpaper = user_config.wallpaper.decode() if user_config.wallpaper else None, - otp = await self._getUserOTP(user.user_id), + @timed + async def getUserConfig(self: Self, user: KhUser) -> UserConfigResponse : + data: list[tuple[str, int, bytes]] = await self.query_async(""" + select key, type, data + from kheina.public.user_configs + where user_configs.user_id = %s; + """, ( + user.user_id, + ), + fetch_all = True, ) + res = UserConfigResponse() + otp: Task[list[OTP]] = create_task(self._getUserOTP(user.user_id)) + if data : + for key, type_, value in data : + t: type[Store] = Configs.SerializerTypeMap[UserConfigType(type_)] + match v := await t.deserialize(value) : + case BlockBehavior() : + res.blocking_behavior = v.behavior + + case Blocking() : + res.blocked_tags = v.tags + res.blocked_users = [i.handle for i in (await users._get_users(v.users)).values()] + + case Theme() : + if v.wallpaper or v.css_properties : + res.theme = v + + res.otp = await otp + return res + @HttpErrorHandler('retrieving custom theme') - async def getUserTheme(self, user: KhUser) -> str : - user_config: UserConfig = await self._getUserConfig(user.user_id) + async def getUserTheme(self: Self, user: KhUser) -> str : + theme: Theme = await self._getUserConfig(user.user_id, Theme) - if not user_config.css_properties : + if not theme.css_properties : return '' css_properties: str = '' - for key, value in user_config.css_properties.items() : + for key, value in theme.css_properties.items() : name = key.replace("_", "-") if isinstance(value, int) : css_properties += f'--{name}:#{value:08x} !important;' - elif isinstance(value, CssProperty) : + elif isinstance(value, CssValue) : css_properties += f'--{name}:var(--{value.value.replace("_", "-")}) !important;' else : diff --git a/configs/models.py b/configs/models.py index 179b8d5..952f340 100644 --- a/configs/models.py +++ b/configs/models.py @@ -1,45 +1,51 @@ from datetime import datetime -from enum import Enum, unique -from typing import Dict, List, Literal, Optional, Set, Union +from enum import Enum, IntEnum, unique +from typing import Any, Literal, Optional, Self from avrofastapi.schema import AvroInt -from pydantic import BaseModel, conbytes +from pydantic import BaseModel, Field from shared.models import PostId +from shared.models.config import Store +from shared.sql.query import Table -UserConfigKeyFormat: str = 'user.{user_id}' +UserConfigKeyFormat: Literal['user.{user_id}.{key}'] = 'user.{user_id}.{key}' -class BannerStore(BaseModel) : +@unique +class ConfigType(str, Enum) : # str so literals work + banner = 'banner' + costs = 'costs' + + +class BannerStore(Store) : banner: Optional[str] + @staticmethod + def type_() -> ConfigType : + return ConfigType.banner + -class CostsStore(BaseModel) : +class CostsStore(Store) : costs: int - -@unique -class ConfigType(str, Enum) : # str so literals work - banner = 'banner' - costs = 'costs' + @staticmethod + def type_() -> ConfigType : + return ConfigType.costs class UpdateBannerRequest(BaseModel) : config: Literal[ConfigType.banner] - value: BannerStore + value: BannerStore class UpdateCostsRequest(BaseModel) : config: Literal[ConfigType.costs] - value: CostsStore - + value: CostsStore -UpdateConfigRequest = Union[UpdateBannerRequest, UpdateCostsRequest] - -class SaveSchemaResponse(BaseModel) : - fingerprint: str +UpdateConfigRequest = UpdateBannerRequest | UpdateCostsRequest class Funding(BaseModel) : @@ -52,6 +58,7 @@ class ConfigsResponse(BaseModel) : funding: Funding +@unique class BlockingBehavior(Enum) : hide = 'hide' omit = 'omit' @@ -63,6 +70,8 @@ class CssProperty(Enum) : background_repeat = 'background_repeat' background_size = 'background_size' + +class CssValue(Enum) : transition = 'transition' fadetime = 'fadetime' warning = 'warning' @@ -104,20 +113,55 @@ class CssProperty(Enum) : notification_bg = 'notification_bg' -class UserConfig(BaseModel) : - blocking_behavior: Optional[BlockingBehavior] = None - blocked_tags: Optional[List[List[str]]] = None - blocked_users: Optional[List[int]] = None - wallpaper: Optional[conbytes(min_length=8, max_length=8)] = None - css_properties: Optional[Dict[str, Union[CssProperty, AvroInt, str]]] = None +@unique +class UserConfigType(IntEnum) : + blocking = 0 + block_behavior = 1 + theme = 2 + + +class Blocking(Store) : + tags: list[list[str]] = [] + users: list[int] = [] + + @staticmethod + def type_() -> UserConfigType : + return UserConfigType.blocking + + +class BlockBehavior(Store) : + behavior: BlockingBehavior = BlockingBehavior.hide + + @staticmethod + def type_() -> UserConfigType : + return UserConfigType.block_behavior + + +class Theme(Store) : + wallpaper: Optional[PostId] = None + css_properties: Optional[dict[str, CssValue | AvroInt | str]] = None + + @staticmethod + def type_() -> UserConfigType : + return UserConfigType.theme class UserConfigRequest(BaseModel) : + field_mask: list[str] blocking_behavior: Optional[BlockingBehavior] - blocked_tags: Optional[List[Set[str]]] - blocked_users: Optional[List[str]] - wallpaper: Optional[PostId] - css_properties: Optional[Dict[CssProperty, str]] + blocked_tags: Optional[list[set[str]]] + blocked_users: Optional[list[str]] + wallpaper: Optional[PostId] + css_properties: Optional[dict[CssProperty, str]] + + def values(self: Self) -> dict[str, Any] : + values = { } + + for f in self.field_mask : + if f in self.__fields_set__ : + values[f] = getattr(self, f) + + return values @unique @@ -132,8 +176,18 @@ class OTP(BaseModel) : class UserConfigResponse(BaseModel) : - blocking_behavior: Optional[BlockingBehavior] - blocked_tags: Optional[List[Set[str]]] - blocked_users: Optional[List[str]] - wallpaper: Optional[str] - otp: Optional[list[OTP]] + blocking_behavior: BlockingBehavior = BlockingBehavior.hide + blocked_tags: list[list[str]] = [] + blocked_users: list[str] = [] + theme: Optional[Theme] = None + otp: list[OTP] = [] + + +class Config(BaseModel) : + __table_name__ = Table('kheina.public.configs') + + key: str = Field(description='orm:"pk"') + created: datetime = Field(description='orm:"gen"') + updated: datetime = Field(description='orm:"gen"') + updated_by: int + bytes_: Optional[bytes] = Field(None, description='orm:"col[bytes]"') diff --git a/configs/router.py b/configs/router.py index a0ab2ce..1e68288 100644 --- a/configs/router.py +++ b/configs/router.py @@ -4,7 +4,7 @@ from fastapi.responses import PlainTextResponse from shared.auth import Scope -from shared.server import Request +from shared.models.server import Request from .configs import Configs from .models import ConfigsResponse, UpdateConfigRequest, UserConfigRequest, UserConfigResponse @@ -17,24 +17,11 @@ configs: Configs = Configs() -@app.on_event('startup') -async def startup() : - await configs.startup() - - @app.on_event('shutdown') async def shutdown() : await configs.close() -################################################## INTERNAL ################################################## -# @app.get('/i1/user/{user_id}', response_model=UserConfig) -# async def i1UserConfig(req: Request, user_id: int) -> UserConfig : -# await req.user.verify_scope(Scope.internal) -# return await configs._getUserConfig(user_id) - - -################################################## PUBLIC ################################################## @app.get('s', response_model=ConfigsResponse) async def v1Configs() -> ConfigsResponse : return await configs.allConfigs() @@ -45,7 +32,7 @@ async def v1UpdateUserConfig(req: Request, body: UserConfigRequest) -> None : await req.user.verify_scope(Scope.user) await configs.setUserConfig( req.user, - body, + **body.values(), ) @@ -67,11 +54,11 @@ async def v1UserTheme(req: Request) -> PlainTextResponse : ) -@app.patch('', status_code=204) -async def v1UpdateConfig(req: Request, body: UpdateConfigRequest) -> None : - await req.user.verify_scope(Scope.mod) - await configs.updateConfig( - req.user, - body.config, - body.value, - ) +# @app.patch('', status_code=204) +# async def v1UpdateConfig(req: Request, body: UpdateConfigRequest) -> None : +# await req.user.verify_scope(Scope.mod) +# await configs.updateConfig( +# req.user, +# body.config, +# body.value, +# ) diff --git a/db/10/00-add-notification-subscriptions.sql b/db/10/00-add-notification-subscriptions.sql new file mode 100644 index 0000000..38e243f --- /dev/null +++ b/db/10/00-add-notification-subscriptions.sql @@ -0,0 +1,73 @@ +begin; + +create or replace function generated_created() returns trigger as +$$ +begin + + new.created = now(); + return new; + +end; +$$ +language plpgsql; + +drop table if exists public.subscriptions; +create table public.subscriptions ( + sub_id uuid unique not null, + user_id bigint not null + references public.users (user_id) + on update cascade + on delete cascade, + sub_info bytea unique not null, + primary key (user_id, sub_id) +); + +drop table if exists public.notifications; +create table public.notifications ( + id uuid not null unique, + user_id bigint not null + references public.users (user_id) + on update cascade + on delete cascade, + type smallint not null, + created timestamptz not null, + data bytea not null, + primary key (user_id, id) +); + +create index if not exists notifications_user_created_idx on public.notifications (user_id, created); + +create or replace trigger generated_created before insert on public.notifications + for each row execute procedure generated_created(); + +create or replace trigger immutable_columns before update on public.notifications + for each row execute procedure public.immutable_columns('id', 'user_id', 'type', 'created', 'data'); + +drop function if exists public.register_subscription; +create or replace function public.register_subscription(sid uuid, uid bigint, sinfo bytea) returns void as +$$ +begin + + update public.subscriptions + set sub_id = sid, + sub_info = sinfo + where subscriptions.user_id = uid + and ( + subscriptions.sub_id = sid + or subscriptions.sub_info = sinfo + ); + + if found then + return; + end if; + + insert into public.subscriptions + (sub_id, user_id, sub_info) + values + (sid, uid, sinfo); + +end; +$$ +language plpgsql; + +commit; diff --git a/db/11/00-create-user-configs-table.sql b/db/11/00-create-user-configs-table.sql new file mode 100644 index 0000000..395f904 --- /dev/null +++ b/db/11/00-create-user-configs-table.sql @@ -0,0 +1,17 @@ +begin; + +alter table public.configs drop column if exists value; + +drop table if exists public.user_configs; +create table public.user_configs ( + user_id bigint not null + references public.users (user_id) + on update cascade + on delete cascade, + key text not null, + type smallint not null, + data bytea not null, + primary key (user_id, key) +); + +commit; diff --git a/docker-exec.sh b/docker-exec.sh new file mode 100644 index 0000000..a9eda82 --- /dev/null +++ b/docker-exec.sh @@ -0,0 +1,5 @@ +#!/bin/sh + +jq '.fullchain' -r /etc/certs/cert.json > fullchain.pem && \ +jq '.privkey' -r /etc/certs/cert.json > privkey.pem && \ +gunicorn -w 2 -k uvicorn.workers.UvicornWorker --certfile fullchain.pem --keyfile privkey.pem -b 0.0.0.0:443 --timeout 1200 server:app diff --git a/emojis/repository.py b/emojis/repository.py index a34c774..ba54ccf 100644 --- a/emojis/repository.py +++ b/emojis/repository.py @@ -11,7 +11,7 @@ from shared.models import PostId from shared.models._shared import InternalUser, UserPortable from shared.sql import SqlInterface -from users.repository import Users +from users.repository import Repository as Users from .models import Emoji, InternalEmoji diff --git a/emojis/router.py b/emojis/router.py index d8d70f6..cc0ed48 100644 --- a/emojis/router.py +++ b/emojis/router.py @@ -3,7 +3,7 @@ from fastapi import APIRouter from shared.models.auth import Scope -from shared.server import Request +from shared.models.server import Request from shared.timing import timed from .emoji import Emojis diff --git a/init.py b/init.py index 2c5580e..76338b4 100644 --- a/init.py +++ b/init.py @@ -3,7 +3,6 @@ import re import shutil import time -from dataclasses import dataclass from os import environ, listdir, remove from os.path import isdir, isfile, join from secrets import token_bytes @@ -25,6 +24,7 @@ from shared.config.credentials import decryptCredentialFile, fetch from shared.datetime import datetime from shared.logging import TerminalAgent +from shared.models.encryption import Keys from shared.sql import SqlInterface @@ -91,12 +91,17 @@ def nukeCache() -> None : is_flag=True, default=False, ) +@click.option( + '-l', + '--lock', + default=None, +) @click.option( '-f', '--file', default='', ) -async def execSql(unlock: bool = False, file: str = '') -> None : +async def execSql(unlock: bool = False, file: str = '', lock: Optional[int] = None) -> None : """ connects to the database and runs all files stored under the db folder folders under db are sorted numberically and run in descending order @@ -112,9 +117,14 @@ async def execSql(unlock: bool = False, file: str = '') -> None : async with sql.pool.connection() as conn : async with conn.cursor() as cur : sqllock = None - if not unlock and isfile('sql.lock') : + + if lock is not None : + sqllock = int(lock) + + if not unlock and sqllock is None and isfile('sql.lock') : sqllock = int(open('sql.lock').read().strip()) - click.echo(f'==> sql.lock: {sqllock}') + + click.echo(f'==> sql.lock: {sqllock}') if file : if not isfile(file) : @@ -319,45 +329,24 @@ async def updatePassword() -> LoginRequest : return acct -@dataclass -class Keys : - aes: AESGCM - ed25519: Ed25519PrivateKey - associated_data: bytes - - def encrypt(self, data: bytes) -> bytes : - nonce = token_bytes(12) - return b'.'.join(map(b64encode, [nonce, self.aes.encrypt(nonce, data, self.associated_data), self.ed25519.sign(data)])) - - def _generate_keys() -> Keys : + keys = Keys.generate() + if isfile('credentials/aes.key') : remove('credentials/aes.key') if isfile('credentials/ed25519.pub') : remove('credentials/ed25519.pub') - aesbytes = AESGCM.generate_key(256) - aeskey = AESGCM(aesbytes) - ed25519priv = Ed25519PrivateKey.generate() + data = keys.dump() with open('credentials/aes.key', 'wb') as file : - file.write(b'.'.join(map(b64encode, [aesbytes, ed25519priv.sign(aesbytes)]))) + file.write(data['aes'].encode()) - pub = ed25519priv.public_key().public_bytes( - encoding=serialization.Encoding.DER, - format=serialization.PublicFormat.SubjectPublicKeyInfo, - ) with open('credentials/ed25519.pub', 'wb') as file : - nonce = token_bytes(12) - aeskey.encrypt - file.write(b'.'.join(map(b64encode, [nonce, aeskey.encrypt(nonce, pub, aesbytes), ed25519priv.sign(pub)]))) - - return Keys( - aes=aeskey, - ed25519=ed25519priv, - associated_data=pub, - ) + file.write(data['pub'].encode()) + + return keys def writeAesFile(file: BinaryIO, contents: bytes) : diff --git a/k8s.yml b/k8s.yml index 8ebd7ae..a4ae2ef 100644 --- a/k8s.yml +++ b/k8s.yml @@ -14,7 +14,7 @@ spec: spec: containers: - name: fuzzly-backend - image: us-central1-docker.pkg.dev/kheinacom/fuzzly-repo/fuzzly-backend@sha256:e757a3b0c5912818f7131e7441db814adcda5d0a0f7c899abd64a5ebfafe0e99 + image: us-central1-docker.pkg.dev/kheinacom/fuzzly-repo/fuzzly-backend@sha256:53b0898536ff726a2557ea55cb9bdc99cdadeb310185af21093eb8ad88d1a452 env: - name: pod_ip valueFrom: diff --git a/notifications/__init__.py b/notifications/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/notifications/models.py b/notifications/models.py new file mode 100644 index 0000000..ac521ab --- /dev/null +++ b/notifications/models.py @@ -0,0 +1,146 @@ +from datetime import datetime +from enum import Enum, IntEnum +from typing import Literal, Optional, Self +from uuid import UUID + +from pydantic import BaseModel, Field, validator + +from posts.models import Post +from shared.models import UserPortable +from shared.models.config import Store +from shared.sql.query import Table + + +class ServerKey(BaseModel) : + application_server_key: str + + +class SubscriptionInfo(Store) : + endpoint: str + expirationTime: Optional[int] + keys: dict[str, str] + + @classmethod + def type_(cls) -> Enum : + raise NotImplementedError + + +class Subscription(BaseModel) : + __table_name__ = Table('kheina.public.subscriptions') + + sub_id: UUID = Field(description='orm:"pk"') + """ + sub_id refers to the guid from an auth token. this way, on log out or expiration, a subscription can be removed from the database proactively + """ + + user_id: int = Field(description='orm:"pk"') + subscription_info: bytes = Field(description='orm:"col[sub_info]"') + + +class NotificationType(IntEnum) : + post = 0 + user = 1 + interact = 2 + + +class InternalNotification(BaseModel) : + __table_name__ = Table('kheina.public.notifications') + + id: UUID = Field(description='orm:"pk"') + user_id: int = Field(description='orm:"pk"') + type_: int = Field(description='orm:"col[type]"') + created: datetime = Field(description='orm:"gen"') + data: bytes + + @validator('type_') + def isValidType(cls, value) : + if value not in NotificationType.__members__.values() : + raise KeyError('notification type must exist in the notification enum') + + return value + + def type(self: Self) -> NotificationType : + return NotificationType(self.type_) + + +class Notification(BaseModel) : + type: str + event: Enum + + +class InteractNotificationEvent(Enum) : + favorite = 'favorite' + reply = 'reply' + repost = 'repost' + + +class InternalInteractNotification(Store) : + """ + an interact notification represents a user taking an action on a post + """ + event: InteractNotificationEvent + post_id: int + user_id: int + + @classmethod + def type_(cls) -> NotificationType : + return NotificationType.interact + + +class InteractNotification(Notification) : + """ + an interact notification represents a user taking an action on a post + """ + id: UUID + type: Literal['interact'] = 'interact' + event: InteractNotificationEvent + created: datetime + user: UserPortable + post: Post + + +class PostNotificationEvent(Enum) : + mention = 'mention' + tagged = 'tagged' + + +class InternalPostNotification(Store) : + event: PostNotificationEvent + post_id: int + + @classmethod + def type_(cls) -> NotificationType : + return NotificationType.post + + +class PostNotification(Notification) : + id: UUID + type: Literal['post'] = 'post' + event: PostNotificationEvent + created: datetime + post: Post + + +class UserNotificationEvent(Enum) : + follow = 'follow' + + +class InternalUserNotification(Store) : + event: UserNotificationEvent + user_id: int + """ + this is the user that performed the action, NOT the user being notified, + that is stored in the larger InternalNotification object + """ + + @classmethod + def type_(cls) -> NotificationType : + return NotificationType.user + + +class UserNotification(Notification) : + id: UUID + type: Literal['user'] = 'user' + event: UserNotificationEvent + created: datetime + user: UserPortable diff --git a/notifications/notifications.py b/notifications/notifications.py new file mode 100644 index 0000000..def535d --- /dev/null +++ b/notifications/notifications.py @@ -0,0 +1,98 @@ +from asyncio import Task, create_task +from datetime import datetime, timedelta +from typing import Self + +from posts.models import InternalPost, Post +from posts.repository import Repository as Posts +from shared.auth import KhUser +from shared.exceptions.http_error import HttpErrorHandler +from shared.models import InternalUser, PostId, UserPortable +from shared.sql.query import Field, Operator, Order, Value, Where +from users.repository import Repository as Users + +from .models import InteractNotification, InternalInteractNotification, InternalNotification, InternalPostNotification, InternalUserNotification, NotificationType, PostNotification, UserNotification +from .repository import Notifier + + +posts: Posts = Posts() +users: Users = Users() + + +class Notifications(Notifier) : + + @HttpErrorHandler('fetching notifications') + async def fetchNotifications(self: Self, user: KhUser) -> list[InteractNotification | PostNotification | UserNotification] : + data: list[InternalNotification] = await self.where( + InternalNotification, + Where( + Field('notifications', 'user_id'), + Operator.equal, + Value(user.user_id), + ), + order = [ + (Field('notifications', 'created'), Order.ascending), + ], + limit = 100, + ) + + if not data : + return [] + + inotifications: list[tuple[InternalNotification, InternalInteractNotification | InternalPostNotification | InternalUserNotification]] = [] + post_ids: list[PostId] = [] + user_ids: list[int] = [] + for n in data : + match n.type() : + case NotificationType.interact : + inotifications.append((n, notif := await InternalInteractNotification.deserialize(n.data))) + post_ids.append(PostId(notif.post_id)) + user_ids.append(notif.user_id) + + case NotificationType.post : + inotifications.append((n, notif := await InternalPostNotification.deserialize(n.data))) + post_ids.append(PostId(notif.post_id)) + + case NotificationType.user : + inotifications.append((n, notif := await InternalUserNotification.deserialize(n.data))) + user_ids.append(notif.user_id) + + iposts: Task[dict[PostId, InternalPost]] = create_task(posts._get_posts(post_ids)) + iusers: Task[dict[int, InternalUser]] = create_task(users._get_users(user_ids)) + + posts_task: Task[list[Post]] = create_task(posts.posts(user, list((await iposts).values()))) + all_users: dict[int, UserPortable] = await users.portables(user, list((await iusers).values())) + all_posts: dict[PostId, Post] = { + p.post_id: p + for p in await posts_task + } + + notifications: list[InteractNotification | PostNotification | UserNotification] = [] + + for n, i in inotifications : + match i : + case InternalInteractNotification() : + notifications.append(InteractNotification( + id = n.id, + event = i.event, + created = n.created, + user = all_users[i.user_id], + post = all_posts[PostId(i.post_id)], + )) + + case InternalPostNotification() : + notifications.append(PostNotification( + id = n.id, + event = i.event, + created = n.created, + post = all_posts[PostId(i.post_id)], + )) + + case InternalUserNotification() : + notifications.append(UserNotification( + id = n.id, + event = i.event, + created = n.created, + user = all_users[i.user_id], + )) + + return notifications diff --git a/notifications/repository.py b/notifications/repository.py new file mode 100644 index 0000000..72bfe3c --- /dev/null +++ b/notifications/repository.py @@ -0,0 +1,315 @@ +from typing import Optional, Self +from urllib.parse import urlparse +from uuid import UUID + +import aerospike +import ujson +from aiohttp import ClientResponse, ClientSession, ClientTimeout +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePublicKey +from py_vapid import Vapid02 +from pydantic import BaseModel +from pywebpush import WebPusher as _WebPusher + +from configs.models import Config +from posts.models import InternalPost, Post +from shared.auth import KhUser, tokenMetadata +from shared.base64 import b64encode +from shared.caching import AerospikeCache +from shared.caching.key_value_store import KeyValueStore +from shared.config.credentials import fetch +from shared.datetime import datetime +from shared.exceptions.http_error import HttpErrorHandler, InternalServerError +from shared.models import PostId, UserPortable +from shared.models.auth import AuthState, TokenMetadata +from shared.models.encryption import Keys +from shared.sql import SqlInterface +from shared.sql.query import Field, Operator, Order, Value, Where +from shared.timing import timed +from shared.utilities import uuid7 +from shared.utilities.json import json_stream + +from .models import InteractNotification, InternalInteractNotification, InternalNotification, InternalPostNotification, InternalUserNotification, NotificationType, PostNotification, ServerKey, Subscription, SubscriptionInfo, UserNotification + + +@timed +async def getTokenMetadata(guid: UUID) -> Optional[TokenMetadata] : + try : + return await tokenMetadata(guid) + + except aerospike.exception.RecordNotFound : + return None + + +class WebPusher(_WebPusher) : + @timed + async def send_async(self, *args, **kwargs) -> ClientResponse | str : + # this is pretty much copied as-is, but with a couple changes to fix issues + timeout = ClientTimeout(kwargs.pop("timeout", 10000)) + curl = kwargs.pop("curl", False) + + params = self._prepare_send_data(*args, **kwargs) + endpoint = params.pop("endpoint") + + if curl : + encoded_data = params["data"] + headers = params["headers"] + return self.as_curl(endpoint, encoded_data=encoded_data, headers=headers) + + if self.aiohttp_session : + resp = await self.aiohttp_session.post(endpoint, timeout=timeout, **params) + + else: + async with ClientSession() as session : + resp = await session.post(endpoint, timeout=timeout, **params) + + return resp + + +kvs: KeyValueStore = KeyValueStore('kheina', 'notifications') + + +class Notifier(SqlInterface) : + + keys: Keys + subInfoFingerprint: bytes + serializerTypeMap: dict[NotificationType, type[BaseModel]] = { + NotificationType.post: InternalPostNotification, + NotificationType.user: InternalUserNotification, + NotificationType.interact: InternalInteractNotification, + } + + async def startup(self) -> None : + if getattr(Notifier, 'keys', None) is None : + Notifier.keys = Keys.load(**fetch('notifications', dict[str, str])) + + + @timed + @AerospikeCache('kheina', 'notifications', 'vapid-config', _kvs=kvs) + async def getVapidPem(self: Self) -> bytes : + async with self.transaction() as t : + vapid = Vapid02() + vapid_config = Config( + key = 'vapid-config', + created = datetime.zero(), + updated = datetime.zero(), + updated_by = 0, + bytes_ = None, + ) + + try : + vapid_config = await t.select(vapid_config) + assert vapid_config.bytes_ + return Notifier.keys.decrypt(vapid_config.bytes_) + + except KeyError : + pass + + vapid.generate_keys() + vapid_config.bytes_ = Notifier.keys.encrypt(vapid.private_pem()) + await t.insert(vapid_config) + return vapid.private_pem() + + + async def getVapid(self: Self) -> Vapid02 : + pk_pem = await self.getVapidPem() + return Vapid02.from_pem(pk_pem) + + + async def getApplicationServerKey(self: Self) -> ServerKey : + vapid = await self.getVapid() + pub = vapid.public_key + assert isinstance(pub, EllipticCurvePublicKey) + return ServerKey( + application_server_key = b64encode(pub.public_bytes( + serialization.Encoding.X962, + serialization.PublicFormat.UncompressedPoint, + )).decode(), + ) + + + @HttpErrorHandler('registering subscription info', exclusions=['self', 'sub_info']) + async def registerSubInfo(self: Self, user: KhUser, sub_info: SubscriptionInfo) -> None : + assert user.token, 'this should always be populated when the user is authenticated' + data: bytes = await sub_info.serialize() + await self.query_async(""" + select kheina.public.register_subscription(%s::uuid, %s, %s); + """, ( + user.token.guid, + user.user_id, + Notifier.keys.encrypt(data), + ), + commit = True, + ) + await kvs.remove_async(f'sub_info={user.user_id}') + + + @timed + async def unregisterSubInfo(self: Self, user_id: int, sub_ids: list[UUID]) -> None : + await self.query_async(""" + delete from kheina.public.subscriptions + where subscriptions.sub_id = any(%s); + """, ( + sub_ids, + ), + commit = True, + ) + await kvs.remove_async(f'sub_info={user_id}') + + + @timed + @AerospikeCache('kheina', 'notifications', 'sub_info={user_id}', _kvs=kvs) + async def getSubInfo(self: Self, user_id: int) -> dict[UUID, SubscriptionInfo] : + sub_info: dict[UUID, SubscriptionInfo] = { } + subs: list[Subscription] = await self.where(Subscription, Where( + Field('subscriptions', 'user_id'), + Operator.equal, + Value(user_id), + )) + + for s in subs : + sub = Notifier.keys.decrypt(s.subscription_info) + sub_info[s.sub_id] = await SubscriptionInfo.deserialize(sub) + + return sub_info + + + async def vapidHeaders(self: Self, sub_info: SubscriptionInfo) -> dict[str, str] : + url = urlparse(sub_info.endpoint) + claim = { + 'sub': 'mailto:help@kheina.com', + 'aud': f'{url.scheme}://{url.netloc}', + 'exp': int(datetime.now().timestamp()) + 1440, # 1 hour, I guess? + } + vapid = await self.getVapid() + return vapid.sign(claim) + + + @timed + async def _send(self: Self, user_id: int, data: dict) -> int : + unregister: list[UUID] = [] + successes: int = 0 + subs = await self.getSubInfo(user_id) + for sub_id, sub in subs.items() : + # sub_id is the token guid of the token that created the subscription + # check that it's still active before sending the notification + token = await getTokenMetadata(sub_id) + if not token or token.state != AuthState.active : + unregister.append(sub_id) + continue + + res = await WebPusher( + sub.dict(), + verbose = True, + ).send_async( + data = ujson.dumps(json_stream(data)), + headers = await self.vapidHeaders(sub), + content_encoding = "aes128gcm", + ) + + if not isinstance(res, ClientResponse) : + raise TypeError(f'expected response to be ClientResponse, got {type(res)}') + + if res.status < 300 : + successes += 1 + + elif res.status == 410 : + unregister.append(sub_id) + + else : + raise InternalServerError('unexpected error occurred while sending notification', status=res.status, content=await res.text()) + + if unregister : + await self.unregisterSubInfo(user_id, unregister) + + self.logger.debug({ + 'message': 'sent notification', + 'successes': successes, + 'failures': len(unregister), + 'to': user_id, + 'notification': data, + }) + + return successes + + + @timed.root + async def sendNotification( + self: Self, + user_id: int, + data: InternalInteractNotification | InternalPostNotification | InternalUserNotification, + **kwargs: UserPortable | Post, + ) -> None : + """ + creates, persists and then sends the given notification to the provided user_id. + kwargs must include the user and/or post of the notification's user_id/post_id in the form of + ``` + await sendNotification(..., user=UserPortable(...), post=Post(...)) + ``` + """ + try : + inotification = await self.insert(InternalNotification( + id = uuid7(), + user_id = user_id, + type_ = data.type_(), + created = datetime.zero(), + data = await data.serialize(), + )) + + self.logger.debug({ + 'message': 'notification', + 'to': user_id, + 'notification': { + 'type': type(data), + 'type_enm': data.type_(), + **data.dict(), + }, + }) + + match data : + case InternalInteractNotification() : + user, post = kwargs.get('user'), kwargs.get('post') + assert isinstance(user, UserPortable) and isinstance(post, Post), 'interact notifications must include user and post kwargs' + notification = InteractNotification( + id = inotification.id, + event = data.event, + created = inotification.created, + user = user, + post = post, + ) + await self._send(user_id, notification.dict()) + + case InternalPostNotification() : + post = kwargs.get('post') + assert isinstance(post, Post), 'post notifications must include a post kwarg' + notification = PostNotification( + id = inotification.id, + event = data.event, + created = inotification.created, + post = post, + ) + await self._send(user_id, notification.dict()) + + case InternalUserNotification() : + user = kwargs.get('user') + assert isinstance(user, UserPortable), 'user notifications must include a user kwarg' + notification = UserNotification( + id = inotification.id, + event = data.event, + created = inotification.created, + user = user, + ) + await self._send(user_id, notification.dict()) + + except : + # since this function will almost always be run async using ensure_future, handle errors internally + self.logger.exception('failed to send notification') + + + @HttpErrorHandler('sending some random cunt a notif') + @timed + async def debugSendNotification(self: Self, user_id: int, data: dict) -> int : + return await self._send(user_id, data) + + +notifier: Notifier = Notifier() diff --git a/notifications/router.py b/notifications/router.py new file mode 100644 index 0000000..1483141 --- /dev/null +++ b/notifications/router.py @@ -0,0 +1,61 @@ +from fastapi import APIRouter + +from shared.datetime import datetime +from shared.models.auth import Scope +from shared.models.server import Request +from shared.timing import timed + +from .models import InteractNotification, PostNotification, ServerKey, SubscriptionInfo, UserNotification +from .notifications import Notifications + + +notifier = Notifications() + + +notificationsRouter = APIRouter( + prefix='/notifications', +) + + +@notificationsRouter.on_event('startup') +async def startup() -> None : + await notifier.startup() + + +@notificationsRouter.get('/register', response_model=ServerKey) +@timed.root +async def v1GetServerKey(req: Request) -> ServerKey : + """ + only auth required + """ + await req.user.authenticated() + return await notifier.getApplicationServerKey() + + +@notificationsRouter.put('/register', response_model=None) +@timed.root +async def v1RegisterNotificationTarget(req: Request, body: SubscriptionInfo) -> None : + await req.user.authenticated() + await notifier.registerSubInfo(req.user, body) + + +@notificationsRouter.get('') +@timed.root +async def v1GetNotifications(req: Request) -> list[InteractNotification | PostNotification | UserNotification] : + await req.user.authenticated() + return await notifier.fetchNotifications(req.user) + + +@notificationsRouter.post('', status_code=201) +@timed.root +async def v1SendThisBitchAVibe(req: Request, body: dict) -> int : + await req.user.verify_scope(Scope.admin) + return await notifier.debugSendNotification(req.user.user_id, body) + + +app = APIRouter( + prefix='/v1', + tags=['notifications'], +) + +app.include_router(notificationsRouter) diff --git a/posts/blocking.py b/posts/blocking.py index f35196f..05367a2 100644 --- a/posts/blocking.py +++ b/posts/blocking.py @@ -1,12 +1,13 @@ -from typing import Dict, Iterable, Self, Set, Tuple +from typing import Iterable, Optional, Self from configs.configs import Configs -from configs.models import UserConfig +from configs.models import Blocking from shared.auth import KhUser from shared.caching import ArgsCache -from shared.models import InternalUser from shared.timing import timed +from .models import Rating + configs = Configs() @@ -29,23 +30,29 @@ def dict(self: Self) : def __init__(self: 'BlockTree') : - self.tags: Set[str] = set() - self.match: Dict[str, BlockTree] = { } - self.nomatch: Dict[str, BlockTree] = { } + self.tags: set[str | int] = set() + self.match: dict[str | int, BlockTree] = { } + self.nomatch: dict[str | int, BlockTree] = { } - def populate(self: Self, tags: Iterable[Iterable[str]]) : + def populate(self: Self, tags: Iterable[Iterable[str | int]]) : for tag_set in tags : tree: BlockTree = self for tag in tag_set : match = True - if tag.startswith('-') : - match = False - tag = tag[1:] + if isinstance(tag, str) : + if tag.startswith('-') : + match = False + tag = tag[1:] + + elif isinstance(tag, int) : + if tag < 0 : + match = False + tag *= -1 - branch: Dict[str, BlockTree] + branch: dict[str | int, BlockTree] if match : if not tree.match : @@ -65,7 +72,7 @@ def populate(self: Self, tags: Iterable[Iterable[str]]) : tree = branch[tag] - def blocked(self: Self, tags: Iterable[str]) -> bool : + def blocked(self: Self, tags: Iterable[str | int]) -> bool : if not self.match and not self.nomatch : return False @@ -93,27 +100,27 @@ def _blocked(self: Self, tree: 'BlockTree') -> bool : @ArgsCache(30) -async def fetch_block_tree(user: KhUser) -> Tuple[BlockTree, UserConfig] : +async def fetch_block_tree(user: KhUser) -> tuple[BlockTree, Optional[set[int]]] : tree: BlockTree = BlockTree() if not user.token : # TODO: create and return a default config - return tree, UserConfig() + return tree, None - # TODO: return underlying UserConfig here, once internal tokens are implemented - user_config: UserConfig = await configs._getUserConfig(user.user_id) - tree.populate(user_config.blocked_tags or []) - return tree, user_config + config: Blocking = await configs._getUserConfig(user.user_id, Blocking) + tree.populate(config.tags or []) + return tree, set(config.users) if config.users else None @timed -async def is_post_blocked(user: KhUser, uploader: InternalUser, tags: Iterable[str]) -> bool : - block_tree, user_config = await fetch_block_tree(user) +async def is_post_blocked(user: KhUser, uploader: int, rating: Rating, tags: Iterable[str]) -> bool : + block_tree, blocked_users = await fetch_block_tree(user) - if user_config.blocked_users and uploader.user_id in user_config.blocked_users : + if blocked_users and uploader in blocked_users : return True - tags: Set[str] = set(tags) - tags.add('@' + uploader.handle) # TODO: user ids need to be added here instead of just handle, once changeable handles are added + tags: set[str | int] = set(tags) # TODO: convert handles to user_ids (int) + tags.add(uploader) + tags.add(f'rating:{rating.name}') return block_tree.blocked(tags) diff --git a/posts/models.py b/posts/models.py index 0abd2be..460b64f 100644 --- a/posts/models.py +++ b/posts/models.py @@ -158,7 +158,7 @@ class Post(OmitModel) : post_id: PostId title: Optional[str] description: Optional[str] - user: UserPortable + user: Optional[UserPortable] score: Optional[Score] rating: Rating parent_id: Optional[PostId] @@ -169,6 +169,7 @@ class Post(OmitModel) : media: Optional[Media] tags: Optional[TagGroups] blocked: bool + locked: bool = False replies: Optional[list['Post']] = None """ None implies "not retrieved" whereas [] means no replies found @@ -273,7 +274,7 @@ class InternalScore(BaseModel) : total: int -######################### uploader things +################################################## uploader things ################################################## class UpdateRequest(BaseModel) : diff --git a/posts/posts.py b/posts/posts.py index 0224680..0bd15b5 100644 --- a/posts/posts.py +++ b/posts/posts.py @@ -4,7 +4,7 @@ from typing import Iterable, Optional, Self from sets.models import InternalSet, SetId -from sets.repository import Sets +from sets.repository import Repository as Sets from shared.auth import KhUser from shared.caching import AerospikeCache, ArgsCache from shared.datetime import datetime @@ -13,13 +13,13 @@ from shared.timing import timed from .models import InternalPost, Post, PostId, PostSort, Privacy, Rating, Score, SearchResults -from .repository import PostKVS, Posts, privacy_map, rating_map, users # type: ignore +from .repository import PostKVS, Repository, privacy_map, rating_map, users # type: ignore sets = Sets() -class Posts(Posts) : +class Posts(Repository) : @staticmethod def _normalize_tag(tag: str) : @@ -544,6 +544,8 @@ async def _fetch_posts(self: Self, sort: PostSort, tags: Optional[tuple[str, ... order = [(Field('set_post', 'index'), order)], alias = 'order', ), + ).group( + Field('set_post', 'index'), ).order( Field('set_post', 'index'), order, @@ -656,11 +658,6 @@ async def _getComments(self: Self, post_id: PostId, sort: PostSort, count: int, Operator.equal, Value(await privacy_map.get_id(Privacy.public)), ), - Where( - Field('posts', 'locked'), - Operator.equal, - Value(False), - ), ).union( Query( Table('kheina.public.posts'), @@ -679,11 +676,6 @@ async def _getComments(self: Self, post_id: PostId, sort: PostSort, count: int, Operator.equal, Value(await privacy_map.get_id(Privacy.public)), ), - Where( - Field('posts', 'locked'), - Operator.equal, - Value(False), - ), ), ), recursive = True, diff --git a/posts/repository.py b/posts/repository.py index f411e80..c7e76a0 100644 --- a/posts/repository.py +++ b/posts/repository.py @@ -1,7 +1,6 @@ -from asyncio import Task, ensure_future +from asyncio import Task, create_task from collections import defaultdict -from dataclasses import dataclass -from typing import Callable, Mapping, Optional, Self, Tuple, Union +from typing import Callable, Iterable, Mapping, Optional, Self from cache import AsyncLRU @@ -11,13 +10,15 @@ from shared.datetime import datetime from shared.exceptions.http_error import BadRequest, NotFound from shared.maps import privacy_map -from shared.models import InternalUser, Undefined, UserPortable +from shared.models import InternalUser, UserPortable from shared.sql import SqlInterface from shared.sql.query import CTE, Field, Join, JoinType, Operator, Order, Query, Table, Value, Where from shared.timing import timed +from shared.utilities import ensure_future from tags.models import InternalTag, Tag, TagGroup -from tags.repository import TagKVS, Tags -from users.repository import Users +from tags.repository import Repository as Tags +from tags.repository import TagKVS +from users.repository import Repository as Users from .blocking import is_post_blocked from .models import InternalPost, InternalScore, Media, MediaFlag, MediaType, Post, PostId, PostSize, Privacy, Rating, Score, Thumbnail @@ -36,7 +37,7 @@ class RatingMap(SqlInterface) : @timed @AsyncLRU(maxsize=0) async def get(self, key: int) -> Rating : - data: Tuple[str] = await self.query_async(""" + data: tuple[str] = await self.query_async(""" SELECT rating FROM kheina.public.ratings WHERE ratings.rating_id = %s @@ -53,7 +54,7 @@ async def get(self, key: int) -> Rating : @timed @AsyncLRU(maxsize=0) async def get_id(self, key: str | Rating) -> int : - data: Tuple[int] = await self.query_async(""" + data: tuple[int] = await self.query_async(""" SELECT rating_id FROM kheina.public.ratings WHERE ratings.rating = %s @@ -76,7 +77,7 @@ class MediaTypeMap(SqlInterface) : @timed @AsyncLRU(maxsize=0) async def get(self, key: int) -> MediaType : - data: Tuple[str, str] = await self.query_async(""" + data: tuple[str, str] = await self.query_async(""" SELECT file_type, mime_type FROM kheina.public.media_type WHERE media_type.media_type_id = %s @@ -94,7 +95,7 @@ async def get(self, key: int) -> MediaType : @timed @AsyncLRU(maxsize=0) async def get_id(self, mime: str) -> int : - data: Tuple[int] = await self.query_async(""" + data: tuple[int] = await self.query_async(""" SELECT media_type_id FROM kheina.public.media_type WHERE media_type.mime_type = %s @@ -110,18 +111,12 @@ async def get_id(self, mime: str) -> int : media_type_map: MediaTypeMap = MediaTypeMap() -@dataclass -class UserCombined: - portable: UserPortable - internal: InternalUser - - -class Posts(SqlInterface) : +class Repository(SqlInterface) : def parse_response( self: Self, data: list[ - Tuple[ + tuple[ int, # 0 post_id str, # 1 title str, # 2 description @@ -148,20 +143,6 @@ def parse_response( posts: list[InternalPost] = [] for row in data : - # media: Optional[InternalMedia] = None - # if row[7] and row[8] and row[16] : - # media = InternalMedia( - # post_id = row[0], - # filename = row[7], - # type = row[8], - # crc = row[15], - # updated = row[16], - # size = PostSize( - # width = row[9], - # height = row[10], - # ) if row[9] and row[10] else None, - # ) - post = InternalPost( post_id = row[0], title = row[1], @@ -176,15 +157,14 @@ def parse_response( width = row[9], height = row[10], ) if row[9] and row[10] else None, - user_id = row[11], - privacy = row[12], - thumbhash = row[13], - locked = row[14], - crc = row[15], - media_updated = row[16], - content_length = row[17], - thumbnails = row[18], # type: ignore - + user_id = row[11], + privacy = row[12], + thumbhash = row[13], + locked = row[14], + crc = row[15], + media_updated = row[16], + content_length = row[17], + thumbnails = row[18], # type: ignore include_in_results = row[19], ) posts.append(post) @@ -195,7 +175,7 @@ def parse_response( def internal_select(self: Self, query: Query) -> Callable[[ list[ - Tuple[ + tuple[ int, # 0 post_id str, # 1 title str, # 2 description @@ -311,10 +291,45 @@ async def _get_post(self: Self, post_id: PostId) -> InternalPost : @timed - async def parents(self: Self, user: KhUser, ipost: InternalPost) -> Optional[Post] : - if not ipost.parent : - return None + async def _get_posts(self: Self, post_ids: Iterable[PostId]) -> dict[PostId, InternalPost] : + if not post_ids : + return { } + + cached = await PostKVS.get_many_async(post_ids) + found: dict[PostId, InternalPost] = { } + misses: list[PostId] = [] + + for k, v in cached.items() : + if v is None or isinstance(v, InternalPost) : + found[k] = v + continue + + misses.append(k) + if not misses : + return found + + posts: dict[PostId, InternalPost] = found + data: list[InternalPost] = await self.where( + InternalPost, + Where( + Field('internal_posts', 'post_id'), + Operator.equal, + Value(misses, functions = ['any']), + ), + ) + + for post in data : + post_id = PostId(post.post_id) + ensure_future(PostKVS.put_async(post_id, post)) + posts[post_id] = post + + return posts + + + @timed + @AerospikeCache('kheina', 'posts', 'parents={parent}', _kvs=PostKVS) + async def _parents(self: Self, parent: int) -> list[InternalPost] : cte = Query( Table('post_ids', cte=True), ).cte( @@ -330,7 +345,7 @@ async def parents(self: Self, user: KhUser, ipost: InternalPost) -> Optional[Pos Where( Field('posts', 'post_id'), Operator.equal, - Value(ipost.parent), + Value(parent), ), ).union( Query( @@ -387,9 +402,16 @@ async def parents(self: Self, user: KhUser, ipost: InternalPost) -> Optional[Pos ), ), ) - parser = self.internal_select(query := self.CteQuery(cte)) - iposts = parser(await self.query_async(query, fetch_all=True)) + return parser(await self.query_async(query, fetch_all=True)) + + + @timed + async def parents(self: Self, user: KhUser, ipost: InternalPost) -> Optional[Post] : + if not ipost.parent : + return None + + iposts = await self._parents(ipost.parent) posts = await self.posts(user, iposts) assert len(posts) == 1 return posts[0] @@ -398,18 +420,45 @@ async def parents(self: Self, user: KhUser, ipost: InternalPost) -> Optional[Pos @timed async def post(self: Self, user: KhUser, ipost: InternalPost) -> Post : post_id: PostId = PostId(ipost.post_id) - parent: Task[Optional[Post]] = ensure_future(self.parents(user, ipost)) - upl: Task[InternalUser] = ensure_future(users._get_user(ipost.user_id)) - tags_task: Task[list[InternalTag]] = ensure_future(tagger._fetch_tags_by_post(post_id)) - score: Task[Optional[Score]] = ensure_future(self.getScore(user, post_id)) + parent: Task[Optional[Post]] = create_task(self.parents(user, ipost)) + upl: Task[InternalUser] = create_task(users._get_user(ipost.user_id)) + tags_task: Task[list[InternalTag]] = create_task(tagger._fetch_tags_by_post(post_id)) + score: Task[Optional[Score]] = create_task(self.getScore(user, post_id)) uploader: InternalUser = await upl - upl_portable: Task[UserPortable] = ensure_future(users.portable(user, uploader)) + upl_portable: Task[UserPortable] = create_task(users.portable(user, uploader)) itags: list[InternalTag] = await tags_task - tags: Task[list[Tag]] = ensure_future(tagger.tags(user, itags)) - blocked: Task[bool] = ensure_future(is_post_blocked(user, uploader, [t.name for t in itags])) + tags: Task[list[Tag]] = create_task(tagger.tags(user, itags)) + blocked: Task[bool] = create_task(is_post_blocked(user, ipost.user_id, await rating_map.get(ipost.rating), (t.name for t in itags))) + + post = Post( + post_id = post_id, + title = None, + description = None, + user = None, + score = await score, + rating = await rating_map.get(ipost.rating), + parent = await parent, + parent_id = PostId(ipost.parent) if ipost.parent else None, + privacy = await privacy_map.get(ipost.privacy), + created = ipost.created, + updated = ipost.updated, + media = None, + tags = tagger.groups(await tags), + blocked = await blocked, + replies = None, + ) + + if ipost.locked : + post.locked = True + + if not await user.verify_scope(Scope.mod, False) and ipost.user_id != user.user_id : + return post # we don't want any other fields populated + + post.title = ipost.title + post.description = ipost.description + post.user = await upl_portable - media: Optional[Media] = None if ipost.filename and ipost.media_type and ipost.size and ipost.content_length and ipost.thumbnails : flags: list[MediaFlag] = [] @@ -417,7 +466,7 @@ async def post(self: Self, user: KhUser, ipost: InternalPost) -> Post : if itag.group == TagGroup.system : flags.append(MediaFlag[itag.name]) - media = Media( + post.media = Media( post_id = PostId(ipost.post_id), crc = ipost.crc, filename = ipost.filename, @@ -443,23 +492,7 @@ async def post(self: Self, user: KhUser, ipost: InternalPost) -> Post : ], ) - return Post( - post_id = post_id, - title = ipost.title, - description = ipost.description, - user = await upl_portable, - score = await score, - rating = await rating_map.get(ipost.rating), - parent = await parent, - parent_id = PostId(ipost.parent) if ipost.parent else None, - privacy = await privacy_map.get(ipost.privacy), - created = ipost.created, - updated = ipost.updated, - media = media, - tags = tagger.groups(await tags), - blocked = await blocked, - replies = None, - ) + return post @timed @@ -493,20 +526,22 @@ async def scores_many(self: Self, post_ids: list[PostId]) -> dict[PostId, Option return { } cached = await ScoreKVS.get_many_async(post_ids) + found: dict[PostId, Optional[InternalScore]] = { } misses: list[PostId] = [] - for k, v in list(cached.items()) : - if v is not Undefined : + for k, v in cached.items() : + if v is None or isinstance(v, InternalScore) : + found[k] = v continue misses.append(k) - cached[k] = None + found[k] = None if not misses : - return cached + return found - scores: dict[PostId, Optional[InternalScore]] = cached - data: list[Tuple[int, int, int]] = await self.query_async(""" + scores: dict[PostId, Optional[InternalScore]] = found + data: list[tuple[int, int, int]] = await self.query_async(""" SELECT post_scores.post_id, post_scores.upvotes, @@ -519,9 +554,6 @@ async def scores_many(self: Self, post_ids: list[PostId]) -> dict[PostId, Option fetch_all = True, ) - if not data : - return scores - for post_id, up, down in data : post_id = PostId(post_id) score: InternalScore = InternalScore( @@ -530,7 +562,9 @@ async def scores_many(self: Self, post_ids: list[PostId]) -> dict[PostId, Option total = up + down, ) scores[post_id] = score - ensure_future(ScoreKVS.put_async(post_id, score)) + + for k, v in scores.items() : + ensure_future(ScoreKVS.put_async(k, v)) return scores @@ -538,7 +572,7 @@ async def scores_many(self: Self, post_ids: list[PostId]) -> dict[PostId, Option @timed @AerospikeCache('kheina', 'votes', '{user_id}|{post_id}', _kvs=VoteKVS) async def _get_vote(self: Self, user_id: int, post_id: PostId) -> int : - data: Optional[Tuple[bool]] = await self.query_async(""" + data: Optional[tuple[bool]] = await self.query_async(""" SELECT upvote FROM kheina.public.post_votes @@ -566,20 +600,21 @@ async def votes_many(self: Self, user_id: int, post_ids: list[PostId]) -> dict[P PostId(k[k.rfind('|') + 1:]): v for k, v in (await VoteKVS.get_many_async([f'{user_id}|{post_id}' for post_id in post_ids])).items() } + found: dict[PostId, int] = { } misses: list[PostId] = [] - for k, v in list(cached.items()) : - if v is not Undefined : + for k, v in cached.items() : + if isinstance(v, int) : + found[k] = v continue misses.append(k) - cached[k] = None if not misses : - return cached + return found - votes: dict[PostId, int] = cached - data: list[Tuple[int, int]] = await self.query_async(""" + votes: dict[PostId, int] = found + data: list[tuple[int, int]] = await self.query_async(""" SELECT post_votes.post_id, post_votes.upvote @@ -593,13 +628,12 @@ async def votes_many(self: Self, user_id: int, post_ids: list[PostId]) -> dict[P fetch_all = True, ) - if not data : - return votes - for post_id, upvote in data : post_id = PostId(post_id) vote: int = 1 if upvote else -1 votes[post_id] = vote + + for post_id, vote in votes.items() : ensure_future(VoteKVS.put_async(f'{user_id}|{post_id}', vote)) return votes @@ -607,8 +641,8 @@ async def votes_many(self: Self, user_id: int, post_ids: list[PostId]) -> dict[P @timed async def getScore(self: Self, user: KhUser, post_id: PostId) -> Optional[Score] : - score_task: Task[Optional[InternalScore]] = ensure_future(self._get_score(post_id)) - vote: Task[int] = ensure_future(self._get_vote(user.user_id, post_id)) + score_task: Task[Optional[InternalScore]] = create_task(self._get_score(post_id)) + vote: Task[int] = create_task(self._get_vote(user.user_id, post_id)) score = await score_task @@ -668,6 +702,11 @@ def _validateVote(self: Self, vote: Optional[bool]) -> None : @timed async def _vote(self: Self, user: KhUser, post_id: PostId, upvote: Optional[bool]) -> Score : + ipost: InternalPost = await self._get_post(post_id) + + if ipost.locked : + raise BadRequest('cannot vote on a post that has been locked', post=ipost) + self._validateVote(upvote) async with self.transaction() as transaction : await transaction.query_async(""" @@ -686,7 +725,7 @@ async def _vote(self: Self, user: KhUser, post_id: PostId, upvote: Optional[bool ), ) - data: Tuple[int, int, datetime] = await transaction.query_async(""" + data: tuple[int, int, datetime] = await transaction.query_async(""" SELECT COUNT(post_votes.upvote), SUM(post_votes.upvote::int), posts.created FROM kheina.public.posts LEFT JOIN kheina.public.post_votes @@ -751,14 +790,14 @@ async def _vote(self: Self, user: KhUser, post_id: PostId, upvote: Optional[bool @timed - async def _uploaders(self: Self, user: KhUser, iposts: list[InternalPost]) -> dict[int, UserCombined] : + async def _uploaders(self: Self, user: KhUser, iposts: list[InternalPost]) -> dict[int, UserPortable] : """ returns populated user objects for every uploader id provided :return: dict in the form user id -> populated User object """ uploader_ids: list[int] = list(set(map(lambda x : x.user_id, iposts))) - users_task: Task[dict[int, InternalUser]] = ensure_future(users._get_users(uploader_ids)) + users_task: Task[dict[int, InternalUser]] = create_task(users._get_users(uploader_ids)) following: Mapping[int, Optional[bool]] if await user.authenticated(False) : @@ -770,16 +809,14 @@ async def _uploaders(self: Self, user: KhUser, iposts: list[InternalPost]) -> di iusers: dict[int, InternalUser] = await users_task return { - user_id: UserCombined( - internal = iuser, - portable = UserPortable( - name = iuser.name, - handle = iuser.handle, - privacy = users._validate_privacy(await privacy_map.get(iuser.privacy)), - icon = iuser.icon, - verified = iuser.verified, - following = following[user_id], - ), + user_id: + UserPortable( + name = iuser.name, + handle = iuser.handle, + privacy = users._validate_privacy(await privacy_map.get(iuser.privacy)), + icon = iuser.icon, + verified = iuser.verified, + following = following[user_id], ) for user_id, iuser in iusers.items() } @@ -805,7 +842,7 @@ async def _scores(self: Self, user: KhUser, iposts: list[InternalPost]) -> dict[ # but put all of them in the dict scores[post_id] = None - iscores_task: Task[dict[PostId, Optional[InternalScore]]] = ensure_future(self.scores_many(post_ids)) + iscores_task: Task[dict[PostId, Optional[InternalScore]]] = create_task(self.scores_many(post_ids)) user_votes: dict[PostId, int] if await user.authenticated(False) : @@ -836,21 +873,22 @@ async def _tags_many(self: Self, post_ids: list[PostId]) -> dict[PostId, list[In cached = { PostId(k[k.rfind('.') + 1:]): v - for k, v in (await VoteKVS.get_many_async([f'post.{post_id}' for post_id in post_ids])).items() + for k, v in (await TagKVS.get_many_async([f'post.{post_id}' for post_id in post_ids])).items() } + found: dict[PostId, list[InternalTag]] = { } misses: list[PostId] = [] - for k, v in list(cached.items()) : - if v is not Undefined : + for k, v in cached.items() : + if isinstance(v, list) and all(map(lambda x : isinstance(x, InternalTag), v)) : + found[k] = v continue misses.append(k) - del cached[k] if not misses : - return cached + return found - tags: dict[PostId, list[InternalTag]] = defaultdict(list, cached) + tags: dict[PostId, list[InternalTag]] = defaultdict(list, found) data: list[tuple[int, str, str, bool, Optional[int]]] = await self.query_async(""" SELECT tag_post.post_id, @@ -881,8 +919,8 @@ async def _tags_many(self: Self, post_ids: list[PostId]) -> dict[PostId, list[In description = None, # in this case, we don't care about this field )) - for post_id, t in tags.items() : - ensure_future(TagKVS.put_async(f'post.{post_id}', t)) + for post_id in post_ids : + ensure_future(TagKVS.put_async(f'post.{post_id}', tags[post_id])) return tags @@ -894,14 +932,17 @@ async def posts(self: Self, user: KhUser, iposts: list[InternalPost], assign_par assign_parents = True will assign any posts found with a matching parent id to the `parent` field of the resulting Post object assign_parents = False will instead assign these posts to the `replies` field of the resulting Post object """ - uploaders_task: Task[dict[int, UserCombined]] = ensure_future(self._uploaders(user, iposts)) - scores_task: Task[dict[PostId, Optional[Score]]] = ensure_future(self._scores(user, iposts)) + + # TODO: at some point, we should make this even faster by joining the uploaders and tag owners tasks + + uploaders_task: Task[dict[int, UserPortable]] = create_task(self._uploaders(user, iposts)) + scores_task: Task[dict[PostId, Optional[Score]]] = create_task(self._scores(user, iposts)) tags: dict[PostId, list[InternalTag]] = await self._tags_many(list(map(lambda x : PostId(x.post_id), iposts))) - at_task: Task[list[Tag]] = ensure_future(tagger.tags(user, [t for l in tags.values() for t in l])) - uploaders: dict[int, UserCombined] = await uploaders_task + at_task: Task[list[Tag]] = create_task(tagger.tags(user, [t for l in tags.values() for t in l])) + uploaders: dict[int, UserPortable] = await uploaders_task scores: dict[PostId, Optional[Score]] = await scores_task - all_tags: dict[str, Tag] = { + all_tags: dict[str, Tag] = { tag.tag: tag for tag in await at_task } @@ -928,9 +969,43 @@ async def posts(self: Self, user: KhUser, iposts: list[InternalPost], assign_par if itag.name in MediaFlag.__members__ : flags.append(MediaFlag[itag.name]) - media: Optional[Media] = None + post = all_posts[post_id] = Post( + post_id = post_id, + title = None, + description = None, + user = None, + score = scores[post_id], + rating = await rating_map.get(ipost.rating), + privacy = await privacy_map.get(ipost.privacy), + media = None, + created = ipost.created, + updated = ipost.updated, + parent_id = parent_id, + + # only the first call retrieves blocked info, all the rest should be cached and not actually await + blocked = await is_post_blocked(user, ipost.user_id, await rating_map.get(ipost.rating), tag_names), + tags = tagger.groups(post_tags) + ) + + if not assign_parents : + # this way, when assign_parents = true, post.replies can be omitted by being unassigned + post.replies = [] + + if ipost.include_in_results : + posts.append(post) + + if ipost.locked : + post.locked = True + + if not await user.verify_scope(Scope.mod, False) and ipost.user_id != user.user_id : + continue # we don't want any other fields populated + + post.title = ipost.title + post.description = ipost.description + post.user = uploaders[ipost.user_id] + if ipost.filename and ipost.media_type and ipost.size and ipost.content_length and ipost.thumbnails : - media = Media( + post.media = Media( post_id = post_id, crc = ipost.crc, filename = ipost.filename, @@ -956,31 +1031,6 @@ async def posts(self: Self, user: KhUser, iposts: list[InternalPost], assign_par ], ) - post = all_posts[post_id] = Post( - post_id = post_id, - title = ipost.title, - description = ipost.description, - user = uploaders[ipost.user_id].portable, - score = scores[post_id], - rating = await rating_map.get(ipost.rating), - privacy = await privacy_map.get(ipost.privacy), - media = media, - created = ipost.created, - updated = ipost.updated, - parent_id = parent_id, - - # only the first call retrieves blocked info, all the rest should be cached and not actually await - blocked = await is_post_blocked(user, uploaders[ipost.user_id].internal, tag_names), - tags = tagger.groups(post_tags) - ) - - if not assign_parents : - # this way, when assign_parents = true, post.replies can be omitted by being unassigned - post.replies = [] - - if ipost.include_in_results : - posts.append(post) - if assign_parents : for post_id, parent in parents.items() : if parent not in all_posts : diff --git a/posts/router.py b/posts/router.py index f735c2c..cefcabc 100644 --- a/posts/router.py +++ b/posts/router.py @@ -1,17 +1,17 @@ from asyncio import ensure_future from html import escape -from typing import Literal, Optional, Union +from typing import Literal, Optional from uuid import uuid4 import aiofiles -from fastapi import APIRouter, File, Form, UploadFile +from fastapi import APIRouter, File, Form, Response, UploadFile from shared.backblaze import B2Interface from shared.config.constants import Environment, environment from shared.exceptions.http_error import UnprocessableDetail, UnprocessableEntity from shared.models import Privacy, convert_path_post_id from shared.models.auth import Scope -from shared.server import Request, Response +from shared.models.server import Request from shared.timing import timed from shared.utilities.units import Byte from users.users import Users @@ -302,7 +302,7 @@ async def v1Rss(req: Request) -> Response : mime_type = post.media.type.mime_type, length = post.media.length, ) if post.media else '', - ) for post in timeline + ) for post in timeline if post.user ]), ), ) diff --git a/posts/uploader.py b/posts/uploader.py index 151bdc6..9d9f46d 100644 --- a/posts/uploader.py +++ b/posts/uploader.py @@ -1,6 +1,7 @@ import json from asyncio import Task, create_subprocess_exec, ensure_future from enum import Enum +from hashlib import sha1 from io import BytesIO from os import path, remove from secrets import token_bytes @@ -16,11 +17,11 @@ from wand import resource from wand.image import Image +from notifications.repository import Notifier from shared.auth import KhUser, Scope from shared.backblaze import B2Interface, MimeType -from shared.base64 import b64decode +from shared.base64 import b64decode, b64encode from shared.caching.key_value_store import KeyValueStore -from shared.crc import CRC from shared.datetime import datetime from shared.exceptions.http_error import BadGateway, BadRequest, Forbidden, HttpErrorHandler, InternalServerError, NotFound from shared.models import InternalUser @@ -29,11 +30,15 @@ from shared.utilities import flatten, int_from_bytes from shared.utilities.units import Byte from tags.models import InternalTag -from tags.repository import CountKVS, Tags -from users.repository import UserKVS, Users +from tags.repository import CountKVS +from tags.repository import Repository as Tags +from users.repository import Repository as Users +from users.repository import UserKVS from .models import Coordinates, InternalPost, Media, MediaFlag, Post, PostId, PostSize, Privacy, Rating, Thumbnail -from .repository import PostKVS, Posts, VoteKVS, media_type_map, privacy_map, rating_map +from .repository import PostKVS +from .repository import Repository as Posts +from .repository import VoteKVS, media_type_map, privacy_map, rating_map from .scoring import confidence from .scoring import controversial as calc_cont from .scoring import hot as calc_hot @@ -44,27 +49,59 @@ resource.limits.set_resource_limit('disk', Byte.gigabyte.value * 100) UnpublishedPrivacies: Set[Privacy] = { Privacy.unpublished, Privacy.draft } -posts = Posts() -users = Users() -tagger = Tags() -_crc = CRC(32) +posts: Posts = Posts() +users: Users = Users() +tagger: Tags = Tags() +notifier: Notifier = Notifier() @timed def crc(value: bytes) -> int : - return _crc(value) + # return int.from_bytes(sha1(value).digest()[:8], signed=True) + return int.from_bytes(sha1(value).digest()[:4]) + + +@timed +async def extract_frame(file_on_disk: str, filename: str) -> str : + await FFmpeg().input( + file_on_disk, + accurate_seek = None, + ss = '0', + ).output( + (screenshot := f'images/{uuid4().hex}_{filename}.webp'), + { 'frames:v': '1' }, + ).execute() + return screenshot + + +@timed +async def validate_image(file_on_disk: str) -> None : + with Image(file=open(file_on_disk, 'rb')) : + pass + + +@timed +async def validate_video(file_on_disk: str) -> None : + # ffmpeg -v error -i file.avi -f null - + await FFmpeg().input( + file_on_disk, + v = 'error', + ).output( + '-', + f = 'null', + ).execute() class Uploader(SqlInterface, B2Interface) : def __init__(self: Self) -> None : + B2Interface.__init__(self, max_retries=5) SqlInterface.__init__( self, conversions={ Enum: lambda x: x.name, }, ) - B2Interface.__init__(self, max_retries=5) self.thumbnail_sizes: list[int] = [ 1200, 800, @@ -195,7 +232,13 @@ async def createPost(self: Self, user: KhUser) -> Post : for _ in range(100) : post_id = PostId.generate() - data = await t.query_async("SELECT count(1) FROM kheina.public.posts WHERE post_id = %s;", (post_id.int(),), fetch_one=True) + data = await t.query_async(""" + SELECT count(1) FROM kheina.public.posts WHERE post_id = %s; + """, ( + post_id.int(), + ), + fetch_one = True, + ) if not data[0] : break @@ -223,7 +266,7 @@ async def createPost(self: Self, user: KhUser) -> Post : user.user_id, user.user_id, ), - fetch_one=True, + fetch_one = True, ) post_id = PostId(data[0]) @@ -274,6 +317,7 @@ async def createPostWithFields( internal_post_id: int post_id: PostId + notify: bool = False async with self.transaction() as transaction : for _ in range(100) : @@ -297,13 +341,18 @@ async def createPostWithFields( post = await transaction.insert(post) if privacy : - await self._update_privacy(user, post_id, privacy, transaction=transaction, commit=False) + notify = await self._update_privacy(user, post_id, privacy, transaction=transaction, commit=False) post.privacy = await privacy_map.get_id(privacy) await transaction.commit() await PostKVS.put_async(post_id, post) + if notify : + """ + TODO: check for mentions in the post, and notify users that they were mentioned + """ + return await posts.post(user, post) @@ -374,18 +423,28 @@ async def insert_thumbnail(self: Self, t: Transaction, post_id: PostId, crc: int @timed - async def upload_thumbnail(self: Self, run: str, t: Transaction, post_id: PostId, crc: int, image: Image, size: int, ext: str) -> Thumbnail : - mime: MimeType = MimeType[ext] - url: str = f'{post_id}/{crc}/thumbnails/{size}.{ext}' - data: bytes = self.get_image_data(image.convert(mime.type())) - await self.upload_async(data, url, mime) - th = await self.insert_thumbnail(t, post_id, crc, mime, size, f'{size}.{ext}', len(data), image.size[0], image.size[1]) - self.logger.debug({ - 'run': run, - 'post': post_id, - 'message': f'uploaded thumbnail {mime.name}({size}) image to cdn', - }) - return th + async def upload_thumbnails(self: Self, run: str, t: Transaction, post_id: PostId, crc: int, image: Image, formats: list[tuple[int, str]]) -> list[Thumbnail] : + """ + formats is a tuple of size and file extension, used to resize each thumbnail and upload it + """ + ths: list[Thumbnail] = [] + # query: list[str] = [] + # params: list[Any] = [] + + for size, ext in sorted(formats, key=lambda x : x[0], reverse=True) : + image = self.convert_image(image, size) + mime: MimeType = MimeType[ext] + url: str = f'{post_id}/{crc}/thumbnails/{size}.{ext}' + data: bytes = self.get_image_data(image.convert(mime.type())) + await self.upload_async(data, url, mime) + ths.append(await self.insert_thumbnail(t, post_id, crc, mime, size, f'{size}.{ext}', len(data), image.size[0], image.size[1])) + self.logger.debug({ + 'run': run, + 'post': post_id, + 'message': f'uploaded thumbnail {mime.name}({size}) image to cdn', + }) + + return ths @timed @@ -422,17 +481,25 @@ async def uploadImage( emoji_name: Optional[str] = None, web_resize: Optional[int] = None, ) -> Media : - run: str = uuid4().hex + start: datetime = datetime.now() + run: str = uuid4().hex # validate it's an actual photo try : - with Image(file=open(file_on_disk, 'rb')) : - pass + await validate_image(file_on_disk) except Exception as e : self.delete_file(file_on_disk) raise BadRequest('Uploaded file is not an image.', err=e) + self.logger.debug({ + 'run': run, + 'post': post_id, + 'elapsed': datetime.now() - start, + 'file_on_disk': file_on_disk, + 'message': 'validated input image file', + }) + rev: int mime_type: MimeType @@ -461,6 +528,7 @@ async def uploadImage( self.logger.debug({ 'run': run, 'post': post_id, + 'elapsed': datetime.now() - start, 'file_on_disk': file_on_disk, 'content_type': mime_type, 'filename': filename, @@ -474,7 +542,9 @@ async def uploadImage( self.logger.debug({ 'run': run, - 'thumbhash': thumbhash, + 'post': post_id, + 'elapsed': datetime.now() - start, + 'thumbhash': b64encode(thumbhash).decode(), }) async with self.transaction() as transaction : @@ -523,6 +593,7 @@ async def uploadImage( self.logger.debug({ 'run': run, 'post': post_id, + 'elapsed': datetime.now() - start, 'message': 'resized for web', }) @@ -556,14 +627,14 @@ async def uploadImage( ( %s, %s, %s, %s, %s, %s, %s, %s) on conflict (post_id) do update set updated = now(), - type = %s, - filename = %s, - length = %s, - thumbhash = %s, - width = %s, - height = %s, - crc = %s - WHERE media.post_id = %s + type = excluded.type, + filename = excluded.filename, + length = excluded.length, + thumbhash = excluded.thumbhash, + width = excluded.width, + height = excluded.height, + crc = excluded.crc + WHERE media.post_id = excluded.post_id RETURNING media.updated; """, ( post_id.int(), @@ -574,18 +645,8 @@ async def uploadImage( image_size.width, image_size.height, rev, - - media_type, - filename, - content_length, - thumbhash, - image_size.width, - image_size.height, - rev, - - post_id.int(), ), - fetch_one=True, + fetch_one = True, ) updated: datetime = upd[0] @@ -602,6 +663,7 @@ async def uploadImage( self.logger.debug({ 'run': run, 'post': post_id, + 'elapsed': datetime.now() - start, 'message': 'deleted old file from cdn', }) @@ -612,37 +674,21 @@ async def uploadImage( self.logger.debug({ 'run': run, 'post': post_id, + 'elapsed': datetime.now() - start, 'message': 'uploaded fullsize image to cdn', }) # upload thumbnails - thumbnails: list[Thumbnail] = [] - + thumbnails: list[Thumbnail] with Image(file=open(file_on_disk, 'rb')) as image : - for i, size in enumerate(self.thumbnail_sizes) : - image = self.convert_image(image, size) - - if not i : - # jpeg thumbnail - thumbnails.append(await self.upload_thumbnail( - run, - transaction, - post_id, - rev, - image, - size, - 'jpg', - )) - - thumbnails.append(await self.upload_thumbnail( - run, - transaction, - post_id, - rev, - image, - size, - 'webp', - )) + thumbnails: list[Thumbnail] = await self.upload_thumbnails( + run, + transaction, + post_id, + rev, + image, + [(s, 'webp') for s in self.thumbnail_sizes] + [(self.thumbnail_sizes[0], 'jpg')], + ) del image @@ -729,9 +775,10 @@ async def updatePostMetadata( if not update and not update_privacy : raise BadRequest('no params were provided.') + notify: bool = False async with self.transaction() as t : if update_privacy and privacy : - await self._update_privacy(user, post_id, privacy, t, commit = False) + notify = await self._update_privacy(user, post_id, privacy, t, commit = False) post.privacy = await privacy_map.get_id(privacy) if update : @@ -740,6 +787,11 @@ async def updatePostMetadata( await PostKVS.put_async(post_id, post) await t.commit() + if notify : + """ + TODO: check for mentions and tags in the post, and notify users that they were mentioned, tagged, or a post matched one of their tag sets + """ + @timed async def _update_privacy( @@ -750,12 +802,17 @@ async def _update_privacy( transaction: Optional[Transaction] = None, commit: bool = True, ) -> bool : + """ + returns True if the post was published, false otherwise + """ if privacy == Privacy.unpublished : raise BadRequest('post privacy cannot be updated to unpublished.') if not transaction : transaction = self.transaction() + published: bool = False + async with transaction as t : data = await t.query_async(""" SELECT privacy.type @@ -786,6 +843,7 @@ async def _update_privacy( vote_task: Optional[Task] = None if old_privacy in UnpublishedPrivacies and privacy not in UnpublishedPrivacies : + published = True await t.query_async(""" INSERT INTO kheina.public.post_votes (user_id, post_id, upvote) @@ -873,19 +931,7 @@ async def _update_privacy( if vote_task : await vote_task - return True - - - @HttpErrorHandler('updating post privacy') - @timed - async def updatePrivacy(self: Self, user: KhUser, post_id: PostId, privacy: Privacy) : - success = await self._update_privacy(user, post_id, privacy) - - if await PostKVS.exists_async(post_id) : - # we need the created and updated values set by db, so just remove - ensure_future(PostKVS.remove_async(post_id)) - - return success + return published async def getImage(self: Self, ipost: InternalPost, coordinates: Coordinates) -> Image : @@ -1081,21 +1127,25 @@ async def uploadVideo( filename: str, post_id: PostId, ) -> Media : - run: str = uuid4().hex + start: datetime = datetime.now() + run: str = uuid4().hex # validate it's an actual video try : - await FFmpeg().input( - file_on_disk, - ).output( - '-', - f = 'null', - ).execute() + await validate_video(file_on_disk) except Exception as e : self.delete_file(file_on_disk) raise BadRequest('Uploaded file is not a video.', err=e) + self.logger.debug({ + 'run': run, + 'post': post_id, + 'elapsed': datetime.now() - start, + 'file_on_disk': file_on_disk, + 'message': 'validated input video file', + }) + rev: int mime_type: MimeType @@ -1115,19 +1165,13 @@ async def uploadVideo( try : # extract the first frame of the video to use for thumbnails/hash - await FFmpeg().input( - file_on_disk, - accurate_seek = None, - ss = '0', - ).output( - (screenshot := f'images/{uuid4().hex}_{filename}.webp'), - { 'frames:v': '1' }, - ).execute() + screenshot = await extract_frame(file_on_disk, filename) post: InternalPost = await posts._get_post(post_id) self.logger.debug({ 'run': run, 'post': post_id, + 'elapsed': datetime.now() - start, 'file_on_disk': file_on_disk, 'content_type': mime_type, 'filename': filename, @@ -1140,7 +1184,9 @@ async def uploadVideo( self.logger.debug({ 'run': run, - 'thumbhash': thumbhash, + 'post': post_id, + 'elapsed': datetime.now() - start, + 'thumbhash': b64encode(thumbhash).decode(), }) async with self.transaction() as transaction : @@ -1164,7 +1210,7 @@ async def uploadVideo( old_filename: Optional[str] = data[0] old_crc: Optional[int] = data[1] - image_size: PostSize + image_size: PostSize del data await self.purgeSystemTags(run, transaction, post_id) @@ -1195,14 +1241,15 @@ async def uploadVideo( audio = True continue - # since there can be empty audio streams, we need to do a further check of the audio stream itself - media = await self.parse_audio_stream(file_on_disk) - if media.get('RMS level dB') != '-inf' : - query.append("(tag_to_id('audio'), %s, 0)") - params.append(post_id.int()) - flags.append(MediaFlag.audio) + if audio : + # since there can be empty audio streams, we need to do a further check of the audio stream itself + media = await self.parse_audio_stream(file_on_disk) + if (rms := media.get('RMS level dB')) and rms != '-inf' : + query.append("(tag_to_id('audio'), %s, 0)") + params.append(post_id.int()) + flags.append(MediaFlag.audio) - del media + del media if not query or not params : raise BadRequest('no media streams found!') @@ -1236,14 +1283,14 @@ async def uploadVideo( ( %s, %s, %s, %s, %s, %s, %s, %s) on conflict (post_id) do update set updated = now(), - type = %s, - filename = %s, - length = %s, - thumbhash = %s, - width = %s, - height = %s, - crc = %s - WHERE media.post_id = %s + type = excluded.type, + filename = excluded.filename, + length = excluded.length, + thumbhash = excluded.thumbhash, + width = excluded.width, + height = excluded.height, + crc = excluded.crc + WHERE media.post_id = excluded.post_id RETURNING media.updated; """, ( post_id.int(), @@ -1254,75 +1301,44 @@ async def uploadVideo( image_size.width, image_size.height, rev, - - media_type, - filename, - content_length, - thumbhash, - image_size.width, - image_size.height, - rev, - - post_id.int(), ), fetch_one = True, ) updated: datetime = upd[0] if old_filename : - old_url: str - - if old_crc : - old_url = f'{post_id}/{old_crc}/{old_filename}' - - else : - old_url = f'{post_id}/{old_filename}' - + old_url: str = f'{post_id}/{old_crc}/{old_filename}' if old_crc else f'{post_id}/{old_filename}' await self.delete_file_async(old_url) self.logger.debug({ 'run': run, 'post': post_id, + 'elapsed': datetime.now() - start, + 'url': old_url, 'message': 'deleted old file from cdn', }) - url: str = f'{post_id}/{rev}/{filename}' - # upload fullsize + url: str = f'{post_id}/{rev}/{filename}' await self.upload_async(open(file_on_disk, 'rb').read(), url, content_type = mime_type) self.logger.debug({ 'run': run, 'post': post_id, + 'elapsed': datetime.now() - start, + 'url': url, 'message': 'uploaded fullsize image to cdn', }) # upload thumbnails - thumbnails: list[Thumbnail] = [] - + thumbnails: list[Thumbnail] with Image(file=open(screenshot, 'rb')) as image : - for i, size in enumerate(self.thumbnail_sizes) : - image = self.convert_image(image, size) - - if not i : - # jpeg thumbnail - thumbnails.append(await self.upload_thumbnail( - run, - transaction, - post_id, - rev, - image, - size, - 'jpg', - )) - - thumbnails.append(await self.upload_thumbnail( - run, - transaction, - post_id, - rev, - image, - size, - 'webp', - )) + thumbnails: list[Thumbnail] = await self.upload_thumbnails( + run, + transaction, + post_id, + rev, + image, + [(s, 'webp') for s in self.thumbnail_sizes] + [(self.thumbnail_sizes[0], 'jpg')], + ) del image diff --git a/reporting/mod_actions.py b/reporting/mod_actions.py index e4f17df..3170546 100644 --- a/reporting/mod_actions.py +++ b/reporting/mod_actions.py @@ -10,8 +10,7 @@ from pydantic import BaseModel from avro_schema_repository.schema_repository import SchemaRepository -from posts.models import InternalPost -from posts.repository import Posts +from posts.repository import Repository as Posts from shared.auth import KhUser, Scope from shared.caching import AerospikeCache from shared.caching.key_value_store import KeyValueStore @@ -20,12 +19,12 @@ from shared.exceptions.http_error import BadRequest, Conflict, NotFound from shared.models import PostId, UserPortable from shared.sql import SqlInterface -from shared.sql.query import Field, Operator, Order, Query, Update, Value, Where -from users.repository import Users +from shared.sql.query import Field, Operator, Order, Query, Table, Update, Value, Where +from users.repository import Repository as Users from .models.actions import ActionType, BanAction, ForceUpdateAction, InternalActionType, InternalBanAction, InternalModAction, ModAction, RemovePostAction from .models.bans import Ban, InternalBan, InternalBanType, InternalIpBan -from .repository import Reporting +from .repository import Repository from .repository import kvs as reporting_kvs @@ -34,7 +33,7 @@ posts: Posts = Posts() AvroMarker: bytes = b'\xC3\x01' kvs: KeyValueStore = KeyValueStore('kheina', 'actions') -reporting: Reporting = Reporting() +reporting: Repository = Repository() class ModActions(SqlInterface) : @@ -113,10 +112,11 @@ async def action(self: Self, user: KhUser, iaction: InternalModAction) -> ModAct async def ban(self: Self, user: KhUser, iban: InternalBan) -> Ban : + iuser = await users._get_user(iban.user_id) return Ban( ban_id = iban.ban_id, ban_type = iban.ban_type.to_type(), - user = await self.user_portable(user, iban.user_id), + user = await users.portable(user, iuser), created = iban.created, completed = iban.completed, reason = iban.reason, @@ -238,7 +238,7 @@ async def create(self: Self, user: KhUser, response: str, action: ModAction) -> match action.action : case ForceUpdateAction() : await t.query_async( - Query(InternalPost.__table_name__).update( + Query(Table('kheina.public.posts')).update( Update('locked', Value(True)), ).where( Where( @@ -252,7 +252,7 @@ async def create(self: Self, user: KhUser, response: str, action: ModAction) -> case RemovePostAction() : await t.query_async( - Query(InternalPost.__table_name__).update( + Query(Table('kheina.public.posts')).update( Update('locked', Value(True)), ).where( Where( @@ -276,8 +276,8 @@ async def create(self: Self, user: KhUser, response: str, action: ModAction) -> completed = completed, reason = action.reason, )) - await kvs.put_async(f'ban={iban.user_id}', iban) - await kvs.put_async(f'user_id={user_id}', iaction) + await kvs.put_async(f'active_ban={user_id}', iban, action.action.duration) + await kvs.remove_async(f'user_bans={user_id}') await t.commit() @@ -453,26 +453,29 @@ async def _active_action(self: Self, post_id: PostId) -> Optional[InternalModAct ) - @AerospikeCache('kheina', 'actions', 'active_action={post_id}', _kvs=kvs) + @AerospikeCache('kheina', 'actions', 'post_actions={post_id}', _kvs=kvs) async def _actions(self: Self, post_id: PostId) -> list[InternalModAction] : - data: list[tuple[int, int, Optional[int], Optional[int], Optional[int], datetime, Optional[datetime], str, int, bytes]] = await self.query_async(Query(InternalModAction.__table_name__).select( - Field('mod_actions', 'action_id'), - Field('mod_actions', 'report_id'), - Field('mod_actions', 'post_id'), - Field('mod_actions', 'user_id'), - Field('mod_actions', 'assignee'), - Field('mod_actions', 'created'), - Field('mod_actions', 'completed'), - Field('mod_actions', 'reason'), - Field('mod_actions', 'action_type'), - Field('mod_actions', 'action'), - ).where( - Where( - Field('mod_actions', 'post_id'), - Operator.equal, - Value(post_id.int()), + data: list[tuple[int, int, Optional[int], Optional[int], Optional[int], datetime, Optional[datetime], str, int, bytes]] = await self.query_async( + Query(InternalModAction.__table_name__).select( + Field('mod_actions', 'action_id'), + Field('mod_actions', 'report_id'), + Field('mod_actions', 'post_id'), + Field('mod_actions', 'user_id'), + Field('mod_actions', 'assignee'), + Field('mod_actions', 'created'), + Field('mod_actions', 'completed'), + Field('mod_actions', 'reason'), + Field('mod_actions', 'action_type'), + Field('mod_actions', 'action'), + ).where( + Where( + Field('mod_actions', 'post_id'), + Operator.equal, + Value(post_id.int()), + ), ), - ), fetch_all = True) + fetch_all = True, + ) if not data : return [] @@ -498,3 +501,54 @@ async def actions(self: Self, user: KhUser, post_id: PostId) -> list[ModAction] await self.action(user, iaction) for iaction in await self._actions(post_id) ] + + + @AerospikeCache('kheina', 'actions', 'user_actions={user_id}', _kvs=kvs) + async def _user_actions(self: Self, user_id: int) -> list[InternalModAction] : + data: list[tuple[int, int, Optional[int], Optional[int], Optional[int], datetime, Optional[datetime], str, int, bytes]] = await self.query_async( + Query(InternalModAction.__table_name__).select( + Field('mod_actions', 'action_id'), + Field('mod_actions', 'report_id'), + Field('mod_actions', 'post_id'), + Field('mod_actions', 'user_id'), + Field('mod_actions', 'assignee'), + Field('mod_actions', 'created'), + Field('mod_actions', 'completed'), + Field('mod_actions', 'reason'), + Field('mod_actions', 'action_type'), + Field('mod_actions', 'action'), + ).where( + Where( + Field('mod_actions', 'user_id'), + Operator.equal, + Value(user_id), + ), + ), + fetch_all = True, + ) + + if not data : + return [] + + return [ + InternalModAction( + action_id = row[0], + report_id = row[1], + post_id = row[2], + user_id = row[3], + assignee = row[4], + created = row[5], + completed = row[6], + reason = row[7], + action_type = InternalActionType(row[8]), + action = bytes(row[9]), + ) + for row in data + ] + + async def user_actions(self: Self, user: KhUser, handle: str) -> list[ModAction] : + user_id: int = await users._handle_to_user_id(handle) + return [ + await self.action(user, iaction) + for iaction in await self._user_actions(user_id) + ] diff --git a/reporting/models/__init__.py b/reporting/models/__init__.py index c4c3ec2..1c18d8d 100644 --- a/reporting/models/__init__.py +++ b/reporting/models/__init__.py @@ -30,5 +30,5 @@ class CreateActionRequest(BaseModel) : action: RemovePostAction | ForceUpdateAction | BanActionInput | None -class ReportReponseRequest(BaseModel) : +class CloseReponseRequest(BaseModel) : response: str diff --git a/reporting/models/actions.py b/reporting/models/actions.py index 951ce81..d8c5b1b 100644 --- a/reporting/models/actions.py +++ b/reporting/models/actions.py @@ -32,8 +32,7 @@ def internal(self: Self) -> InternalActionType : # these two enums must contain the same values -assert set(InternalActionType.__members__.keys()) == set(ActionType.__members__.keys()) -assert set(InternalActionType.__members__.keys()) == set(map(lambda x : x.value, ActionType.__members__.values())) +assert set(InternalActionType.__members__.keys()) == set(ActionType.__members__.keys()) == set(map(lambda x : x.value, ActionType.__members__.values())) class InternalModAction(BaseModel) : diff --git a/reporting/models/bans.py b/reporting/models/bans.py index 990eaad..a96a563 100644 --- a/reporting/models/bans.py +++ b/reporting/models/bans.py @@ -57,7 +57,7 @@ class InternalIpBan(BaseModel) : class Ban(BaseModel) : ban_id: int ban_type: BanType - user: Optional[UserPortable] + user: UserPortable created: datetime completed: datetime reason: str diff --git a/reporting/models/mod_queue.py b/reporting/models/mod_queue.py index 0329032..5c35cde 100644 --- a/reporting/models/mod_queue.py +++ b/reporting/models/mod_queue.py @@ -17,6 +17,6 @@ class InternalModQueueEntry(BaseModel) : class ModQueueEntry(BaseModel) : - queue_id: int = Field(description='orm:"pk;gen"') + queue_id: int assignee: Optional[UserPortable] report: Report diff --git a/reporting/reporting.py b/reporting/reporting.py index 65d5508..c655304 100644 --- a/reporting/reporting.py +++ b/reporting/reporting.py @@ -10,13 +10,13 @@ from .models.actions import BanAction, ForceUpdateAction, ModAction, RemovePostAction from .models.bans import Ban from .models.reports import BaseReport, Report -from .repository import Reporting # type: ignore +from .repository import Repository -mod_acitons = ModActions() +mod_actions = ModActions() -class Reporting(Reporting) : +class Reporting(Repository) : async def create(self: Self, user: KhUser, body: CreateRequest) -> Report : return await super().create( @@ -79,7 +79,7 @@ async def create_action(self: Self, user: KhUser, body: CreateActionRequest) -> case _ : raise BadRequest('unknown action object', body=body) - return await mod_acitons.create( + return await mod_actions.create( user, body.response, ModAction( @@ -95,8 +95,12 @@ async def create_action(self: Self, user: KhUser, body: CreateActionRequest) -> async def actions(self: Self, user: KhUser, post_id: PostId) -> list[ModAction] : - return await mod_acitons.actions(user, PostId(post_id)) + return await mod_actions.actions(user, post_id) + + + async def user_actions(self: Self, user: KhUser, handle: str) -> list[ModAction] : + return await mod_actions.user_actions(user, handle) async def bans(self: Self, user: KhUser, handle: str) -> list[Ban] : - return await mod_acitons.bans(user, handle) + return await mod_actions.bans(user, handle) diff --git a/reporting/repository.py b/reporting/repository.py index bc0bf37..6877286 100644 --- a/reporting/repository.py +++ b/reporting/repository.py @@ -8,7 +8,7 @@ from pydantic import BaseModel from avro_schema_repository.schema_repository import SchemaRepository -from posts.repository import Posts +from posts.repository import Repository as Posts from shared.auth import KhUser, Scope from shared.caching import AerospikeCache from shared.caching.key_value_store import KeyValueStore @@ -17,7 +17,7 @@ from shared.models import UserPortable from shared.sql import SqlInterface from shared.sql.query import Field, Operator, Order, Query, Value, Where -from users.repository import Users +from users.repository import Repository as Users from .models.mod_queue import InternalModQueueEntry, ModQueueEntry from .models.reports import BaseReport, BaseReportHistory, CopyrightReport, HistoryMask, InternalReport, InternalReportType, Report @@ -30,7 +30,7 @@ kvs: KeyValueStore = KeyValueStore('kheina', 'reports') -class Reporting(SqlInterface) : +class Repository(SqlInterface) : _report_type_map: dict[InternalReportType, type[BaseModel]] = { InternalReportType.other: BaseReport, @@ -43,7 +43,7 @@ class Reporting(SqlInterface) : } def __init__(self, *args: Any, **kwargs: Any) : - assert set(Reporting._report_type_map.keys()) == set(InternalReportType.__members__.values()) + assert set(Repository._report_type_map.keys()) == set(InternalReportType.__members__.values()) super().__init__(*args, conversions={ IntEnum: lambda x: x.value }, **kwargs) @@ -55,14 +55,14 @@ async def _get_schema(fingerprint: bytes) -> Schema: @AsyncLRU(maxsize=0) async def _get_serializer(self: Self, report_type: InternalReportType) -> tuple[bytes, AvroSerializer] : - model = Reporting._report_type_map[report_type] + model = Repository._report_type_map[report_type] return AvroMarker + await repo.addSchema(convert_schema(model)), AvroSerializer(model) async def _get_deserializer(self: Self, report_type: InternalReportType, fp: bytes) -> AvroDeserializer : assert fp[:2] == AvroMarker - model = Reporting._report_type_map[report_type] - return AvroDeserializer(read_model=model, write_model=await Reporting._get_schema(fp[2:10])) + model = Repository._report_type_map[report_type] + return AvroDeserializer(read_model=model, write_model=await Repository._get_schema(fp[2:10])) async def user_portable(self: Self, user: KhUser, user_id: Optional[int]) -> Optional[UserPortable] : @@ -165,7 +165,7 @@ async def update_report(self: Self, user: KhUser, report: Report) -> None : prev = report_data.dict() self.logger.debug({ 'incoming data': data, - 'prev': prev, + 'prev': prev, }) for k, v in data.items() : if prev.get(k) == v : @@ -224,22 +224,24 @@ async def assign_self(self: Self, user: KhUser, queue_id: int) -> None : raise Conflict('another moderator has assigned this report to themselves') - async def close_response(self: Self, user: KhUser, queue_id: int, response: str) -> Report : + async def close_response(self: Self, user: KhUser, report_id: int, response: str) -> Report : ireport: InternalReport async with self.transaction() as t : - data: Optional[tuple[int, Optional[int]]] = await t.query_async(""" + data: Optional[tuple[int]] = await t.query_async(""" delete from kheina.public.mod_queue - where mod_queue.queue_id = %s - returning mod_queue.report_id, mod_queue.assignee; - """, ( - queue_id, - ), fetch_one = True) + where mod_queue.report_id = %s + returning mod_queue.assignee; + """, ( + report_id, + ), + fetch_one = True, + ) if not data : - raise NotFound('provided queue entry does not exist') + raise NotFound('provided report does not exist') - if data[1] != user.user_id : + if data[0] != user.user_id : raise BadRequest('cannot close a report that is assigned to someone else') ireport = await self._read(data[0]) diff --git a/reporting/router.py b/reporting/router.py index 9219899..52a1a07 100644 --- a/reporting/router.py +++ b/reporting/router.py @@ -1,10 +1,11 @@ -from fastapi import APIRouter, Request +from fastapi import APIRouter from shared.auth import Scope -from shared.models import PostId +from shared.models import PostId, convert_path_post_id +from shared.models.server import Request from shared.timing import timed -from .models import CreateActionRequest, CreateRequest, ReportReponseRequest +from .models import CloseReponseRequest, CreateActionRequest, CreateRequest from .models.actions import ModAction from .models.bans import Ban from .models.mod_queue import ModQueueEntry @@ -13,25 +14,25 @@ reportRouter = APIRouter( - prefix='/report', + prefix = '/report', ) reportsRouter = APIRouter( - prefix='/reports', + prefix = '/reports', ) actionRouter = APIRouter( - prefix='/action', + prefix = '/action', ) actionsRouter = APIRouter( - prefix='/actions', + prefix = '/actions', ) queueRouter = APIRouter( - prefix='/mod', + prefix = '/mod', ) bansRouter = APIRouter( - prefix='/bans', + prefix = '/bans', ) @@ -66,6 +67,13 @@ async def v1List(req: Request) -> list[Report] : return await reporting.list_(req.user) +@reportRouter.delete('/{report_id}') +@timed.root +async def v1CloseWithoutAction(req: Request, report_id: int, body: CloseReponseRequest) -> Report : + await req.user.verify_scope(Scope.mod) + return await reporting.close_response(req.user, report_id, body.response) + + ######################### queue ######################### @@ -83,13 +91,6 @@ async def v1AssignSelf(req: Request, queue_id: int) -> None : return await reporting.assign_self(req.user, queue_id) -@queueRouter.patch('/{queue_id}') -@timed.root -async def v1CloseWithoutAction(req: Request, queue_id: int, body: ReportReponseRequest) -> Report : - await req.user.verify_scope(Scope.mod) - return await reporting.close_response(req.user, queue_id, body.response) - - ######################### actions ######################### @actionRouter.put('') @@ -103,7 +104,14 @@ async def v1CloseWithAction(req: Request, body: CreateActionRequest) -> ModActio @timed.root async def v1Actions(req: Request, post_id: PostId) -> list[ModAction] : await req.user.verify_scope(Scope.mod) - return await reporting.actions(req.user, post_id) + return await reporting.actions(req.user, convert_path_post_id(post_id)) + + +@actionsRouter.get('/user/{handle}') +@timed.root +async def v1UserActions(req: Request, handle: str) -> list[ModAction] : + await req.user.verify_scope(Scope.mod) + return await reporting.user_actions(req.user, handle) ######################### bans ######################### diff --git a/requirements.lock b/requirements.lock index af9adde..c92c62f 100644 --- a/requirements.lock +++ b/requirements.lock @@ -25,6 +25,7 @@ fastapi==0.115.6 frozenlist==1.5.0 gunicorn==23.0.0 h11==0.14.0 +http_ece==1.2.1 httpcore==1.0.7 httptools==0.6.4 httpx==0.27.2 @@ -42,6 +43,7 @@ propcache==0.2.0 psycopg-binary==3.2.3 psycopg-pool==3.2.4 psycopg==3.2.3 +py-vapid==1.9.2 pycparser==2.22 pycryptodome==3.21.0 pydantic==1.10.19 @@ -53,18 +55,21 @@ pyotp==2.9.0 python-dotenv==1.0.1 python-ffmpeg==2.0.12 python-multipart==0.0.20 +pywebpush==2.0.3 PyYAML==6.0.2 requests==2.32.3 rich==13.9.4 scipy==1.14.1 setuptools==75.6.0 shellingham==1.5.4 +six==1.17.0 sniffio==1.3.1 starlette==0.41.3 typer==0.15.1 typing_extensions==4.12.2 ujson==5.5.0 urllib3==2.2.3 +uuid7==0.1.0 uvicorn==0.32.0 uvloop==0.21.0 Wand==0.6.13 diff --git a/sample-creds.json b/sample-creds.json index e96a18c..3ee5ff7 100644 --- a/sample-creds.json +++ b/sample-creds.json @@ -27,6 +27,11 @@ "key": "kw7vBGROerYZjIb7i05SrAIrYV18Xn43ijIahWrY", "key_name": "localhost-minio-key" }, + "notifications": { + "aes": "XHn0bd38L5mQgISKBtP2FbIJQnYZd0yWfwXFn1FBtfQ.pD795CKO9aaq2T5MuIqxOwgTKli-d0gMOaSlhl5GR1Mv9GhOGLct-3jgsqGhHc6uQwJXki7Jj1pjbj2fliOyAQ", + "pub": "wtzY_IHxeR7ZMbar.TsLHfu8KxrkW99Ou8ySJNB_nxF6vK33SyldQ1HIWhTWcgT5LuYd4Ekra7an0i-jsQDxFNSnnnxCVWKv0.vmTisEaOEAQajFyWGr65unDlD9ZJXOSIQmmfh0lpxUSS-x4-GFICLcdITGSNtE7f6JbMPSGbp7GygGZ5NYdPCw", + "priv": "8ax1-3veA74sd5nt.2Bv8pW2CEobi6HWO9800C3aBZ5iAubhq1SxPZeN-aYkrJ3LMU4CRjOFL4mVEzQTL.cmGWjKew0VFJVw7gEu2n2PhQAkWulNUJQkwx2-0NOGP4E-NIi_tQjYAWE-F_cMtZd3nRHmjJYy37IPZc3x27Bg" + }, "db": { "user": "kheina", "password": "password", @@ -50,4 +55,4 @@ ], "ip_salt": "570e755552e4ef95f7ee5ce9ad0ff4e38a64b858" } -} \ No newline at end of file +} diff --git a/server.py b/server.py index bce3124..2541c03 100644 --- a/server.py +++ b/server.py @@ -11,6 +11,7 @@ from account.router import app as account from configs.router import app as configs from emojis.router import app as emoji +from notifications.router import app as notifications from posts.router import app as posts from probe.router import probes from reporting.router import app as reporting @@ -165,3 +166,4 @@ def root() -> ServiceInfo : app.include_router(users) app.include_router(emoji) app.include_router(reporting) +app.include_router(notifications) diff --git a/sets/repository.py b/sets/repository.py index 7ef3d4d..e1521aa 100644 --- a/sets/repository.py +++ b/sets/repository.py @@ -3,16 +3,17 @@ from typing import Optional, Self, Tuple, Union from posts.models import InternalPost, Post, PostId, Privacy -from posts.repository import Posts, privacy_map +from posts.repository import Repository as Posts +from posts.repository import privacy_map from shared.auth import KhUser, Scope -from shared.caching import AerospikeCache, ArgsCache +from shared.caching import AerospikeCache from shared.caching.key_value_store import KeyValueStore from shared.datetime import datetime from shared.exceptions.http_error import NotFound from shared.hashing import Hashable from shared.models import InternalUser, UserPrivacy from shared.sql import SqlInterface -from users.repository import Users +from users.repository import Repository as Users from .models import InternalSet, Set, SetId @@ -23,9 +24,9 @@ posts = Posts() -class Sets(SqlInterface, Hashable) : +class Repository(SqlInterface, Hashable) : - def __init__(self: 'Sets') -> None : + def __init__(self) -> None : SqlInterface.__init__( self, conversions={ @@ -132,7 +133,7 @@ async def set(self: Self, iset: InternalSet, user: KhUser) -> Set : count=iset.count, title=iset.title, description=iset.description, - privacy=Sets._validate_privacy(await privacy_map.get(iset.privacy)), + privacy=Repository._validate_privacy(await privacy_map.get(iset.privacy)), created=iset.created, updated=iset.updated, first=first_post, diff --git a/sets/router.py b/sets/router.py index 0e468ad..e428e2d 100644 --- a/sets/router.py +++ b/sets/router.py @@ -1,8 +1,7 @@ from fastapi import APIRouter -from pydantic import conint from shared.models._shared import PostId -from shared.server import Request +from shared.models.server import Request from shared.timing import timed from .models import AddPostToSetRequest, CreateSetRequest, PostSet, Set, SetId, UpdateSetRequest diff --git a/sets/sets.py b/sets/sets.py index dd00b66..bf589b8 100644 --- a/sets/sets.py +++ b/sets/sets.py @@ -5,17 +5,18 @@ from psycopg.errors import UniqueViolation from posts.models import InternalPost, MediaType, Post, PostId, PostSize, Privacy, Rating -from posts.repository import Posts, privacy_map +from posts.repository import Repository as Posts +from posts.repository import privacy_map from shared.auth import KhUser, Scope -from shared.caching import AerospikeCache, ArgsCache +from shared.caching import ArgsCache from shared.datetime import datetime from shared.exceptions.http_error import BadRequest, Conflict, HttpErrorHandler, NotFound from shared.models.user import UserPrivacy from shared.timing import timed -from users.repository import Users +from users.repository import Repository as Users from .models import InternalSet, PostSet, Set, SetId, SetNeighbors, UpdateSetRequest -from .repository import SetKVS, SetNotFound, Sets # type: ignore +from .repository import Repository, SetKVS, SetNotFound """ @@ -44,7 +45,7 @@ users = Users() -class Sets(Sets) : +class Sets(Repository) : @staticmethod async def _verify_authorized(user: KhUser, iset: InternalSet) -> bool : diff --git a/shared/auth/__init__.py b/shared/auth/__init__.py index 82f5bd7..df19735 100644 --- a/shared/auth/__init__.py +++ b/shared/auth/__init__.py @@ -12,12 +12,11 @@ from cryptography.hazmat.primitives.serialization import load_der_public_key from fastapi import Request -from authenticator.authenticator import AuthAlgorithm, Authenticator, AuthState, PublicKeyResponse, Scope, TokenMetadata -from shared.models.auth import AuthToken, KhUser # type: ignore +from authenticator.authenticator import AuthAlgorithm, Authenticator, AuthState, PublicKeyResponse, Scope, TokenMetadata, token_kvs +from shared.models.auth import AuthToken, _KhUser from ..base64 import b64decode, b64encode from ..caching import ArgsCache -from ..caching.key_value_store import KeyValueStore from ..datetime import datetime from ..exceptions.http_error import Forbidden, Unauthorized from ..utilities import int_from_bytes @@ -26,14 +25,13 @@ authenticator = Authenticator() ua_strip = re_compile(r'\/\d+(?:\.\d+)*') -KVS: KeyValueStore = KeyValueStore('kheina', 'token') class InvalidToken(ValueError) : pass -class KhUser(KhUser) : +class KhUser(_KhUser) : async def authenticated(self, raise_error: bool = True) -> bool : if self.banned : if raise_error : @@ -53,7 +51,10 @@ async def verify_scope(self, scope: Scope, raise_error: bool = True) -> bool : await self.authenticated(raise_error) if scope not in self.scope : - raise Forbidden('User is not authorized to access this resource.', user=self) + if raise_error : + raise Forbidden('User is not authorized to access this resource.', user=self) + + return False return True @@ -104,7 +105,8 @@ async def v1token(token: str) -> AuthToken : if datetime.now() > expires : raise Unauthorized('Key has expired.') - token_info_task = ensure_future(KVS.get_async(guid.bytes, TokenMetadata)) + token_info_task = ensure_future(tokenMetadata(guid.bytes)) + token_info: TokenMetadata try : public_key = await _fetchPublicKey(key_id, algorithm) @@ -133,6 +135,7 @@ async def v1token(token: str) -> AuthToken : expires = expires, data = json.loads(data), token_string = token, + metadata = token_info, ) @@ -151,15 +154,28 @@ async def verifyToken(token: str) -> AuthToken : raise InvalidToken('The given token uses a version that is unable to be decoded.') +async def tokenMetadata(guid: bytes | UUID) -> TokenMetadata : + if isinstance(guid, UUID) : + guid = guid.bytes + + token = await token_kvs.get_async(guid, TokenMetadata) + + # though the kvs should only retain the token for as long as it's active, check the expiration anyway + if token.expires <= datetime.now() : + token.state = AuthState.inactive + + return token + + async def deactivateAuthToken(token: str, guid: Optional[bytes] = None) -> None : atoken = await verifyToken(token) if not guid : - return await KVS.remove_async(atoken.guid.bytes) + return await token_kvs.remove_async(atoken.guid.bytes) - tm = await KVS.get_async(guid, TokenMetadata) + tm = await tokenMetadata(guid) if tm.user_id == atoken.user_id : - return await KVS.remove_async(guid) + return await token_kvs.remove_async(guid) async def retrieveAuthToken(request: Request) -> AuthToken : diff --git a/shared/caching/__init__.py b/shared/caching/__init__.py index 4fa33fd..0a52983 100644 --- a/shared/caching/__init__.py +++ b/shared/caching/__init__.py @@ -181,14 +181,25 @@ def wrapper(*args: Tuple[Hashable], **kwargs:Dict[str, Hashable]) -> Any : return decorator -def deepTypecheck(type_: type | tuple, instance: Any) -> bool : +def deepTypecheck(type_: type | tuple[type, ...], instance: Any) -> bool : + """ + returns true if instance is an instance of type_ + """ + match instance : + case list() | tuple() : + return all(map(partial(deepTypecheck, type_), instance)) + + if type(instance) is type_ : + return True + + type_ = getattr(type_, '__args__', type_) + if isinstance(type_, tuple) : if type(instance) not in type_ : - return False + return False else : - t = getattr(type_, '__origin__', type_) - if type(instance) is not t : + if type(instance) is not getattr(type_, '__origin__', type_) : return False if di := getattr(instance, '__dict__', None) : @@ -196,10 +207,6 @@ def deepTypecheck(type_: type | tuple, instance: Any) -> bool : if di.keys() != dt.keys() : return False - match instance : - case list() | tuple() : - return all(map(partial(deepTypecheck, type_.__args__), instance)) # type: ignore - return True @@ -267,7 +274,6 @@ async def wrapper(*args: Hashable, **kwargs: Hashable) -> Any : await decorator.kvs.put_async(key, data, TTL) else : - if not deepTypecheck(return_type, data) : data = await func(*args, **kwargs) diff --git a/shared/caching/key_value_store.py b/shared/caching/key_value_store.py index 4753c56..61c2aa4 100644 --- a/shared/caching/key_value_store.py +++ b/shared/caching/key_value_store.py @@ -4,28 +4,27 @@ from copy import copy from functools import partial from time import time -from typing import Any, Iterable, Optional, Set, Tuple, Type, TypeVar, Union +from typing import Any, Callable, Iterable, Optional, Self, Union import aerospike from ..config.constants import environment from ..config.credentials import fetch -from ..models import Undefined +from ..models._shared import Undefined from ..timing import timed from ..utilities import __clear_cache__, coerse -T = TypeVar('T') KeyType = Union[str, bytes, int] class KeyValueStore : _client = None - def __init__(self: 'KeyValueStore', namespace: str, set: str, local_TTL: float = 1) : + def __init__(self: Self, namespace: str, set: str, local_TTL: float = 1) : if not KeyValueStore._client and not environment.is_test() : config = { - 'hosts': fetch('aerospike.hosts', list[Tuple[str, int]]), + 'hosts': fetch('aerospike.hosts', list[tuple[str, int]]), 'policies': fetch('aerospike.policies', dict[str, Any]), } KeyValueStore._client = aerospike.client(config).connect() @@ -39,14 +38,17 @@ def __init__(self: 'KeyValueStore', namespace: str, set: str, local_TTL: float = @timed - def put(self: 'KeyValueStore', key: KeyType, data: Any, TTL: int = 0) : - KeyValueStore._client.put( # type: ignore + def put(self: Self, key: KeyType, data: Any, TTL: int = 0, bins: dict[str, Any] = { }) -> None : + KeyValueStore._client.put( # type: ignore (self._namespace, self._set, key), - { 'data': data }, - meta={ + { + 'data': data, + **bins, + }, + meta = { 'ttl': TTL, }, - policy={ + policy = { 'max_retries': 3, }, ) @@ -54,16 +56,21 @@ def put(self: 'KeyValueStore', key: KeyType, data: Any, TTL: int = 0) : @timed - async def put_async(self: 'KeyValueStore', key: KeyType, data: Any, TTL: int = 0) : + async def put_async(self: Self, key: KeyType, data: Any, TTL: int = 0, bins: dict[str, Any] = { }) -> None : with ThreadPoolExecutor() as threadpool : - return await get_event_loop().run_in_executor(threadpool, partial(self.put, key, data, TTL)) + return await get_event_loop().run_in_executor(threadpool, partial(self.put, key, data, TTL, bins)) - def _get(self: 'KeyValueStore', key: KeyType, type: Optional[Type[T]] = None) -> T : + def _get[T](self: Self, key: KeyType, type: Optional[type[T]] = None) -> T : if key in self._cache : return copy(self._cache[key][1]) - _, _, data = KeyValueStore._client.get((self._namespace, self._set, key)) # type: ignore + try : + _, _, data = KeyValueStore._client.get((self._namespace, self._set, key)) # type: ignore + + except aerospike.exception.RecordNotFound : + raise aerospike.exception.RecordNotFound(f'Record not found: {(self._namespace, self._set, key)}') + self._cache[key] = (time() + self._local_TTL, data['data']) if type : @@ -73,35 +80,33 @@ def _get(self: 'KeyValueStore', key: KeyType, type: Optional[Type[T]] = None) -> @timed - def get(self: 'KeyValueStore', key: KeyType, type: Optional[Type[T]] = None) -> T : + def get[T](self: Self, key: KeyType, type: Optional[type[T]] = None) -> T : __clear_cache__(self._cache, time) return self._get(key, type) @timed - async def get_async(self: 'KeyValueStore', key: KeyType, type: Optional[Type[T]] = None) -> T : + async def get_async[T](self: Self, key: KeyType, type: Optional[type[T]] = None) -> T : async with self._get_lock : __clear_cache__(self._cache, time) with ThreadPoolExecutor() as threadpool : - try : - return await get_event_loop().run_in_executor(threadpool, partial(self._get, key, type)) - - except aerospike.exception.RecordNotFound : - raise aerospike.exception.RecordNotFound(f'Record not found: {(self._namespace, self._set, key)}') + return await get_event_loop().run_in_executor(threadpool, partial(self._get, key, type)) - def _get_many[T: KeyType](self: 'KeyValueStore', k: Iterable[T]) -> dict[T, Any] : - keys: dict[T, T] = { v: v for v in k } - remote_keys: Set[T] = keys.keys() - self._cache.keys() + def _get_many[T, K: KeyType](self: Self, k: Iterable[K], type: Optional[type[T]] = None) -> dict[K, T | type[Undefined]] : + # this weird ass dict is so that we can "convert" the returned aerospike keytype back to K + keys: dict[K, K] = { v: v for v in k } + remote_keys: set[K] = keys.keys() - self._cache.keys() + values: dict[K, Any] if remote_keys : data: list[Tuple[Any, Any, Any]] = KeyValueStore._client.get_many(list(map(lambda k : (self._namespace, self._set, k), remote_keys))) # type: ignore - data_map: dict[T, Any] = { } + data_map: dict[K, Any] = { } exp: float = time() + self._local_TTL for datum in data : - key: T = keys[datum[0][2]] + key: K = keys[datum[0][2]] # filter on the metadata, since it will always be populated if datum[1] : @@ -112,7 +117,7 @@ def _get_many[T: KeyType](self: 'KeyValueStore', k: Iterable[T]) -> dict[T, Any] else : data_map[key] = Undefined - return { + values = { **data_map, **{ key: copy(self._cache[key][1]) @@ -120,51 +125,64 @@ def _get_many[T: KeyType](self: 'KeyValueStore', k: Iterable[T]) -> dict[T, Any] }, } - # only local cache is required - return { - key: self._cache[key][1] - for key in keys.keys() - } + else : + # only local cache is required + values = { + key: copy(self._cache[key][1]) + for key in keys.keys() + } + + if type : + return { + k: coerse(v, type) if v is not Undefined else v + for k, v in values.items() + } + + return values @timed - def get_many[T: KeyType](self: 'KeyValueStore', keys: Iterable[T]) -> dict[T, Any] : + def get_many[T, K: KeyType](self: Self, keys: Iterable[K], type: Optional[type[T]] = None) -> dict[K, T | type[Undefined]] : __clear_cache__(self._cache, time) - return self._get_many(keys) + return self._get_many(keys, type) @timed - async def get_many_async[T: KeyType](self: 'KeyValueStore', keys: Iterable[T]) -> dict[T, Any] : + async def get_many_async[T, K: KeyType](self: Self, keys: Iterable[K], type: Optional[type[T]] = None) -> dict[K, T | type[Undefined]] : async with self._get_many_lock : with ThreadPoolExecutor() as threadpool : - return await get_event_loop().run_in_executor(threadpool, partial(self.get_many, keys)) + return await get_event_loop().run_in_executor(threadpool, partial(self.get_many, keys, type)) @timed - def remove(self: 'KeyValueStore', key: KeyType) -> None : + def remove(self: Self, key: KeyType) -> None : + try : + self._client.remove( # type: ignore + (self._namespace, self._set, key), + policy = { + 'max_retries': 3, + }, + ) + + except aerospike.exception.RecordNotFound : + pass + if key in self._cache : del self._cache[key] - self._client.remove( # type: ignore - (self._namespace, self._set, key), - policy={ - 'max_retries': 3, - }, - ) - @timed - async def remove_async(self: 'KeyValueStore', key: KeyType) -> None : + async def remove_async(self: Self, key: KeyType) -> None : with ThreadPoolExecutor() as threadpool : return await get_event_loop().run_in_executor(threadpool, partial(self.remove, key)) @timed - def exists(self: 'KeyValueStore', key: KeyType) -> bool : + def exists(self: Self, key: KeyType) -> bool : try : - _, meta = self._client.exists( # type: ignore + _, meta = self._client.exists( # type: ignore (self._namespace, self._set, key), - policy={ + policy = { 'max_retries': 3, }, ) @@ -176,10 +194,42 @@ def exists(self: 'KeyValueStore', key: KeyType) -> bool : @timed - async def exists_async(self: 'KeyValueStore', key: KeyType) -> bool : + async def exists_async(self: Self, key: KeyType) -> bool : with ThreadPoolExecutor() as threadpool : return await get_event_loop().run_in_executor(threadpool, partial(self.exists, key)) - def truncate(self: 'KeyValueStore') -> None : + @timed + def where[T](self: Self, *predicates: aerospike.predicates, type: Optional[type[T]] = None) -> list[T] : + results: list[T] = [] + func: Callable[[Any], None] + + if type : + def func(data: Any) -> None : + results.append(coerse(data[2]['data'], type)) + + else : + def func(data: Any) -> None : + results.append(copy(data[2]['data'])) + + KeyValueStore._client.query( # type: ignore + self._namespace, + self._set, + ).select( + 'data', + ).where( + *predicates, + ).foreach( + func, + ) + return results + + + @timed + async def where_async[T](self: Self, *predicates: aerospike.predicates, type: Optional[type[T]] = None) -> list[T] : + with ThreadPoolExecutor() as threadpool : + return await get_event_loop().run_in_executor(threadpool, partial(self.where, *predicates, type=type)) + + + def truncate(self: Self) -> None : self._client.truncate(self._namespace, self._set, 0) # type: ignore diff --git a/shared/exceptions/base_error.py b/shared/exceptions/base_error.py index d964e8a..b6ae3d9 100644 --- a/shared/exceptions/base_error.py +++ b/shared/exceptions/base_error.py @@ -23,7 +23,8 @@ def __init__(self, *args: Any, refid: Optional[UUID | str] = None, logdata: dict if 'refid' in logdata : del logdata['refid'] - self.logdata: dict[str, Any] = { + self.__dict__: dict[str, Any] = { **logdata, **kwargs, + **self.__dict__, } diff --git a/shared/exceptions/http_error.py b/shared/exceptions/http_error.py index b8061e4..95c35f3 100644 --- a/shared/exceptions/http_error.py +++ b/shared/exceptions/http_error.py @@ -127,26 +127,26 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any : match e : case NotImplementedError() : - raise NotImplemented( # noqa: F901 + raise NotImplemented( f'{message} has not been implemented.', - refid = refid, + refid = refid, logdata = logdata, - err = e, + err = e, ) case ClientError() : raise ServiceUnavailable( f'{ServiceUnavailable.__name__}: received an invalid response from an upstream server while {message}.', - refid = refid, + refid = refid, logdata = logdata, - err = e, + err = e, ) raise InternalServerError( f'an unexpected error occurred while {message}.', - refid = refid, + refid = refid, logdata = logdata, - err = e, + err = e, ) markcoroutinefunction(wrapper) diff --git a/shared/logging.py b/shared/logging.py index bdd6556..bf674db 100644 --- a/shared/logging.py +++ b/shared/logging.py @@ -1,12 +1,15 @@ import json import logging +from dataclasses import is_dataclass from enum import Enum, unique -from logging import ERROR, INFO, Logger, getLevelName +from logging import DEBUG, ERROR, INFO, Logger, getLevelName from sys import stderr, stdout from traceback import format_tb from types import ModuleType from typing import Any, Callable, Optional, Self, TextIO +from pydantic import BaseModel + from .config.constants import environment from .config.repo import name as repo_name from .config.repo import short_hash @@ -227,23 +230,22 @@ def emit(self, record: logging.LogRecord) -> None : if record.args and isinstance(record.msg, str) : record.msg = record.msg % tuple(map(str, map(json_stream, record.args))) - if record.exc_info : - e: BaseException = record.exc_info[1] # type: ignore - refid = getattr(e, 'refid', None) - errorinfo: dict[str, Any] = { - 'error': f'{getFullyQualifiedClassName(e)}: {e}', - 'stacktrace': list(map(str.strip, format_tb(record.exc_info[2]))), - 'refid': refid.hex if refid else None, - **json_stream(getattr(e, 'logdata', { })), - } + if is_dataclass(record.msg) or isinstance(record.msg, BaseModel) : + record.msg = record.msg.__dict__ + + if record.exc_info and record.exc_info[1] : + e: BaseException = record.exc_info[1] + msg: dict[str, Any] if isinstance(record.msg, dict) : - errorinfo.update(json_stream(record.msg)) + msg = json_stream(record.msg) else : - errorinfo['message'] = record.msg + msg = { 'message': record.msg } + + msg.update(json_stream(e)) try : - self.agent.log_struct(errorinfo, severity=record.levelno) + self.agent.log_struct(msg, severity=record.levelno) except : # noqa: E722 # we really, really do not want to fail-crash here. diff --git a/shared/models/auth.py b/shared/models/auth.py index e89333b..40f23fd 100644 --- a/shared/models/auth.py +++ b/shared/models/auth.py @@ -6,12 +6,30 @@ from pydantic import BaseModel +@unique +class AuthState(IntEnum) : + active = 0 + inactive = 1 + + +class TokenMetadata(BaseModel) : + state: AuthState + key_id: int + user_id: int + version: bytes + algorithm: str + expires: datetime + issued: datetime + fingerprint: bytes + + class AuthToken(NamedTuple) : user_id: int expires: datetime guid: UUID data: dict[str, Any] token_string: str + metadata: TokenMetadata @unique @@ -27,7 +45,7 @@ def all_included_scopes(self: 'Scope') -> list['Scope'] : return [v for v in Scope.__members__.values() if Scope.user.value <= v.value <= self.value] or [self] -class KhUser(NamedTuple) : +class _KhUser(NamedTuple) : user_id: int = -1 token: Optional[AuthToken] = None scope: set[Scope] = set() @@ -50,23 +68,6 @@ class PublicKeyResponse(BaseModel) : expires: datetime -@unique -class AuthState(IntEnum) : - active = 0 - inactive = 1 - - -class TokenMetadata(BaseModel) : - state: AuthState - key_id: int - user_id: int - version: bytes - algorithm: str - expires: datetime - issued: datetime - fingerprint: bytes - - @unique class AuthAlgorithm(Enum) : ed25519 = 'ed25519' diff --git a/shared/models/config.py b/shared/models/config.py new file mode 100644 index 0000000..9ca0d68 --- /dev/null +++ b/shared/models/config.py @@ -0,0 +1,62 @@ +from abc import ABCMeta, abstractmethod +from enum import Enum +from typing import Callable, Self + +from avrofastapi import schema, serialization +from avrofastapi.serialization import AvroDeserializer, AvroSerializer, Schema, parse_avro_schema +from cache import AsyncLRU +from pydantic import BaseModel + +from avro_schema_repository.schema_repository import AvroMarker, SchemaRepository +from shared.caching import ArgsCache +from shared.utilities.json import json_stream + + +repo: SchemaRepository = SchemaRepository() + + +@AsyncLRU(maxsize=32) +async def getSchema(fingerprint: bytes) -> Schema : + return parse_avro_schema((await repo.getSchema(fingerprint)).decode()) + + +def _convert_schema(model: type[BaseModel], error: bool = False, conversions: dict[type, Callable[[schema.AvroSchemaGenerator, type], schema.AvroSchema] | schema.AvroSchema] = { }) -> schema.AvroSchema : + generator: schema.AvroSchemaGenerator = schema.AvroSchemaGenerator(model, error, conversions) + return json_stream(generator.schema()) + + +serialization.convert_schema = schema.convert_schema = _convert_schema + + +class Store(BaseModel, metaclass=ABCMeta) : + + @classmethod + @ArgsCache(float('inf')) + async def fingerprint(cls: type[Self]) -> bytes : + return await repo.addSchema(schema.convert_schema(cls)) + + @classmethod + @ArgsCache(float('inf')) + def serializer(cls: type[Self]) -> AvroSerializer : + return AvroSerializer(cls) + + async def serialize(self: Self) -> bytes : + return AvroMarker + await self.fingerprint() + self.serializer()(self) + + @classmethod + async def deserialize(cls: type[Self], data: bytes) -> Self : + assert data[:2] == AvroMarker + deserializer: AvroDeserializer = AvroDeserializer( + read_model = cls, + write_model = await getSchema(data[2:10]), + ) + return deserializer(data[10:]) + + @staticmethod + @abstractmethod + def type_() -> Enum : + pass + + @classmethod + def key(cls: type[Self]) -> str : + return cls.type_().name diff --git a/shared/models/encryption.py b/shared/models/encryption.py new file mode 100644 index 0000000..a287a28 --- /dev/null +++ b/shared/models/encryption.py @@ -0,0 +1,106 @@ +from dataclasses import dataclass +from secrets import token_bytes +from typing import Optional, Self +from xmlrpc.client import boolean + +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey, Ed25519PublicKey +from cryptography.hazmat.primitives.asymmetric.types import PublicKeyTypes +from cryptography.hazmat.primitives.ciphers.aead import AESGCM +from cryptography.hazmat.primitives.serialization import load_der_public_key + +from ..base64 import b64decode, b64encode + + +@dataclass +class Keys : + aes: AESGCM + _aes_bytes: bytes + ed25519: Optional[Ed25519PrivateKey] + pub: Ed25519PublicKey + associated_data: bytes + + def encrypt(self: Self, data: bytes) -> bytes : + if not self.ed25519 : + raise ValueError('can only encrypt data with private keys') + + nonce = token_bytes(12) + return b'.'.join(map(b64encode, [nonce, self.aes.encrypt(nonce, data, self.associated_data), self.ed25519.sign(data)])) + + def decrypt(self: Self, data: bytes) -> bytes : + nonce: bytes; encrypted: bytes; sig: bytes + nonce, encrypted, sig = map(b64decode, b''.join(data.split()).split(b'.', 3)) + + decrypted: bytes = self.aes.decrypt(nonce, encrypted, self.associated_data) + self.pub.verify(sig, decrypted) + return decrypted + + @staticmethod + def _encode_pub(pub: Ed25519PublicKey) -> bytes : + return pub.public_bytes( + encoding = serialization.Encoding.DER, + format = serialization.PublicFormat.SubjectPublicKeyInfo, + ) + + @staticmethod + def generate() -> 'Keys' : + aesbytes = AESGCM.generate_key(256) + aeskey = AESGCM(aesbytes) + ed25519priv = Ed25519PrivateKey.generate() + + return Keys( + aes = aeskey, + _aes_bytes = aesbytes, + ed25519 = ed25519priv, + pub = ed25519priv.public_key(), + associated_data = Keys._encode_pub(ed25519priv.public_key()), + ) + + @staticmethod + def load(aes: str, pub: str, priv: Optional[str] = None) -> 'Keys' : + aesbytes: bytes; aes_sig: bytes + aesbytes, aes_sig = map(b64decode, aes.split('.', 2)) + + aeskey = AESGCM(aesbytes) + + nonce: bytes; pub_encrypted: bytes; pub_sig: bytes + nonce, pub_encrypted, pub_sig = map(b64decode, pub.split('.', 3)) + + pub_decrypted: bytes = aeskey.decrypt(nonce, pub_encrypted, aesbytes) + pub_key: PublicKeyTypes = load_der_public_key(pub_decrypted, backend=default_backend()) + assert isinstance(pub_key, Ed25519PublicKey) + pub_key.verify(pub_sig, pub_decrypted) + pub_key.verify(aes_sig, aesbytes) + + associated_data: bytes = Keys._encode_pub(pub_key) + + pk: Optional[Ed25519PrivateKey] = None + if priv : + priv_encrypted: bytes; priv_sig: bytes + nonce, priv_encrypted, priv_sig = map(b64decode, priv.split('.', 3)) + priv_decrypted: bytes = aeskey.decrypt(nonce, priv_encrypted, associated_data) + pub_key.verify(priv_sig, priv_decrypted) + pk = Ed25519PrivateKey.from_private_bytes(priv_decrypted) + + return Keys( + aes = aeskey, + _aes_bytes = aesbytes, + ed25519 = pk, + pub = pub_key, + associated_data = associated_data, + ) + + def dump(self: Self, priv: boolean = False) -> dict[str, str] : + if not self.ed25519 : + raise ValueError('can only dump keys that contain private keys') + + data = { + 'aes': b'.'.join(map(b64encode, [self._aes_bytes, self.ed25519.sign(self._aes_bytes)])).decode(), + 'pub': b'.'.join(map(b64encode, [(nonce := token_bytes(12)), self.aes.encrypt(nonce, self.associated_data, self._aes_bytes), self.ed25519.sign(self.associated_data)])).decode(), + } + + if priv : + data['priv'] = b'.'.join(map(b64encode, [(nonce := token_bytes(12)), self.aes.encrypt(nonce, (pb := self.ed25519.private_bytes_raw()), self.associated_data), self.ed25519.sign(pb)])).decode() + + return data diff --git a/shared/models/server.py b/shared/models/server.py new file mode 100644 index 0000000..e402c93 --- /dev/null +++ b/shared/models/server.py @@ -0,0 +1,15 @@ +from fastapi import Request as _req + +from ..auth import KhUser + + +class Request(_req) : + @property + def auth(self) -> KhUser : + assert 'auth' in self.scope, 'AuthenticationMiddleware must be installed to access request.auth' + return self.scope['auth'] + + @property + def user(self) -> KhUser : + assert 'user' in self.scope, 'KhAuthMiddleware must be installed to access request.user' + return self.scope['user'] diff --git a/shared/server/__init__.py b/shared/server/__init__.py index 3384718..e69de29 100644 --- a/shared/server/__init__.py +++ b/shared/server/__init__.py @@ -1,105 +0,0 @@ -from typing import Iterable - -from fastapi import FastAPI, Request -from fastapi.responses import Response -from starlette.middleware.exceptions import ExceptionMiddleware - -from ..auth import KhUser -from ..config.constants import environment -from ..exceptions.base_error import BaseError -from ..exceptions.handler import jsonErrorHandler - - -NoContentResponse = Response(None, status_code=204) - - -class Request(Request) : - @property - def user(self) -> KhUser : - return super().user - - -def ServerApp( - auth: bool = True, - auth_required: bool = True, - cors: bool = True, - max_age: int = 86400, - custom_headers: bool = True, - allowed_hosts: Iterable[str] = [ - 'localhost', - '127.0.0.1', - '*.fuzz.ly', - 'fuzz.ly', - ], - allowed_origins: Iterable[str] = [ - 'localhost', - '127.0.0.1', - 'dev.fuzz.ly', - 'fuzz.ly', - ], - allowed_methods: Iterable[str] = [ - 'GET', - 'POST', - ], - allowed_headers: Iterable[str] = [ - 'accept', - 'accept-language', - 'authorization', - 'cache-control', - 'content-encoding', - 'content-language', - 'content-length', - 'content-security-policy', - 'content-type', - 'cookie', - 'host', - 'location', - 'referer', - 'referrer-policy', - 'set-cookie', - 'user-agent', - 'www-authenticate', - 'x-frame-options', - 'x-xss-protection', - ], - exposed_headers: Iterable[str] = [ - 'authorization', - 'cache-control', - 'content-type', - 'cookie', - 'set-cookie', - 'www-authenticate', - ], -) -> FastAPI : - app = FastAPI() - app.add_middleware(ExceptionMiddleware, handlers={ Exception: jsonErrorHandler }, debug=False) - app.add_exception_handler(BaseError, jsonErrorHandler) - - allowed_protocols = ['http', 'https'] if environment.is_local() else ['https'] - - if custom_headers : - from ..server.middleware import CustomHeaderMiddleware, HeadersToSet - exposed_headers = list(exposed_headers) + list(HeadersToSet.keys()) - app.middleware('http')(CustomHeaderMiddleware) - - if cors : - from ..server.middleware.cors import KhCorsMiddleware - app.add_middleware( - KhCorsMiddleware, - allowed_origins = set(allowed_origins), - allowed_protocols = set(allowed_protocols), - allowed_headers = list(allowed_headers), - allowed_methods = list(allowed_methods), - exposed_headers = list(exposed_headers), - max_age = max_age, - ) - - if allowed_hosts : - from starlette.middleware.trustedhost import TrustedHostMiddleware - app.add_middleware(TrustedHostMiddleware, allowed_hosts=list(allowed_hosts)) - - if auth : - from ..server.middleware.auth import KhAuthMiddleware - app.add_middleware(KhAuthMiddleware, required=auth_required) - - return app diff --git a/shared/sql/__init__.py b/shared/sql/__init__.py index a67d0d3..2f026c8 100644 --- a/shared/sql/__init__.py +++ b/shared/sql/__init__.py @@ -1,21 +1,23 @@ from dataclasses import dataclass from dataclasses import field as dataclass_field -from enum import Enum +from enum import Enum, IntEnum from functools import lru_cache, partial from re import compile from types import TracebackType -from typing import Any, Awaitable, Callable, Dict, List, Optional, Protocol, Self, Tuple, Type, Union +from typing import Any, Awaitable, Callable, Optional, Protocol, Self, Union +from uuid import UUID from psycopg import AsyncClientCursor, AsyncConnection, AsyncCursor, Binary, OperationalError from psycopg_pool import AsyncConnectionPool from pydantic import BaseModel from pydantic.fields import ModelField +from ..config.constants import environment from ..config.credentials import fetch -from ..logging import Logger, getLogger +from ..logging import DEBUG, INFO, Logger, getLogger from ..models import PostId from ..timing import timed -from .query import Field, Insert, Operator, Query, Table, Update, Value, Where +from .query import Field, Insert, Operator, Order, Query, Table, Update, Value, Where _orm_regex = compile(r'orm:"([^\n]*?)(? None : - self.logger: Logger = getLogger() + def __init__(self: Self, long_query_metric: float = 1, conversions: dict[type, Callable] = { }) -> None : + self.logger: Logger = getLogger(level=DEBUG if environment.is_local() else INFO) self._long_query = long_query_metric - self._conversions: Dict[type, Callable] = { + self._conversions: dict[type, Callable] = { tuple: list, bytes: Binary, + IntEnum: lambda x : x.value, Enum: lambda x : x.name, + UUID: str, PostId: PostId.int, **conversions, } + if getattr(SqlInterface, 'db', None) is None : + SqlInterface.db = fetch('db', dict[str, str]) + async def open(self: Self) : if getattr(SqlInterface, 'pool', None) is None : @@ -116,6 +136,7 @@ async def query_async( sql, params = sql.build() params = tuple(map(self._convert_item, params)) + self.logger.debug({ 'sql': sql, 'params': params }) for _ in range(attempts) : async with SqlInterface.pool.connection() as conn : @@ -161,7 +182,7 @@ async def close(self: Self) -> int : @staticmethod - def _table_name(model: Union[BaseModel, Type[BaseModel]]) -> Table : + def _table_name(model: Union[BaseModel, type[BaseModel]]) -> Table : if not hasattr(model, '__table_name__') : raise AttributeError('model must be defined with the __table_name__ attribute') @@ -225,11 +246,11 @@ async def insert[T: BaseModel](self: Self, model: T, query: Optional[AwaitableQu map[subtype.field:column,field:column2] - maps a subtype's field to columns. separate nested fields by periods. """ table: Table = self._table_name(model) - d: Dict[str, Any] = model.dict() - paths: List[Tuple[str, ...]] = [] - vals: List[Value] = [] - cols: List[str] = [] - ret: List[str] = [] + d: dict[str, Any] = model.dict() + paths: list[tuple[str, ...]] = [] + vals: list[Value] = [] + cols: list[str] = [] + ret: list[str] = [] for key, field in model.__fields__.items() : attrs = SqlInterface._orm_attr_parser(field) @@ -279,7 +300,7 @@ async def insert[T: BaseModel](self: Self, model: T, query: Optional[AwaitableQu query = partial(self.query_async, commit=True) assert query - data: Tuple[Any, ...] = await query(sql, fetch_one=bool(ret)) + data: tuple[Any, ...] = await query(sql, fetch_one=bool(ret)) for i, path in enumerate(paths) : v2 = d @@ -306,7 +327,7 @@ def _read_convert(value: Any) -> Any : @staticmethod - def _assign_field_values[T: BaseModel](model: Type[T], data: Tuple[Any, ...]) -> T : + def _assign_field_values[T: BaseModel](model: type[T], data: tuple[Any, ...]) -> T : i = 0 d: dict = { } for key, field in model.__fields__.items() : @@ -386,7 +407,7 @@ async def select[T: BaseModel](self: Self, model: T, query: Optional[AwaitableQu query = partial(self.query_async, commit=False) assert query - data: Tuple[Any, ...] = await query(sql, fetch_one=True) + data: tuple[Any, ...] = await query(sql, fetch_one=True) if not data : raise KeyError('value does not exist in database') @@ -406,11 +427,11 @@ async def update[T: BaseModel](self: Self, model: T, query: Optional[AwaitableQu """ table: Table = self._table_name(model) sql: Query = Query(table) - d: Dict[str, Any] = model.dict() - paths: List[Tuple[str, ...]] = [] - vals: List[Value] = [] - cols: List[str] = [] - ret: List[str] = [] + d: dict[str, Any] = model.dict() + paths: list[tuple[str, ...]] = [] + vals: list[Value] = [] + cols: list[str] = [] + ret: list[str] = [] _, t = str(table).rsplit('.', 1) pk = 0 @@ -468,7 +489,7 @@ async def update[T: BaseModel](self: Self, model: T, query: Optional[AwaitableQu query = partial(self.query_async, commit=True) assert query - data: Tuple[Any, ...] = await query(sql, fetch_one=bool(ret)) + data: tuple[Any, ...] = await query(sql, fetch_one=bool(ret)) for i, path in enumerate(paths) : v2 = d @@ -520,11 +541,14 @@ async def delete(self: Self, model: BaseModel, query: Optional[AwaitableQuery] = await query(sql) - async def where[T: BaseModel](self: Self, model: Type[T], *where: Where, query: Optional[AwaitableQuery] = None) -> List[T] : + async def where[T: BaseModel](self: Self, model: type[T], *where: Where, order: list[tuple[Field, Order]] = [], limit: Optional[int] = None, query: Optional[AwaitableQuery] = None) -> list[T] : table = self._table_name(model) sql = Query(table).where(*where) _, t = str(table).rsplit('.', 1) + for o in order : + sql.order(*o) + for _, field in model.__fields__.items() : attrs = SqlInterface._orm_attr_parser(field) if attrs.ignore : @@ -537,11 +561,14 @@ async def where[T: BaseModel](self: Self, model: Type[T], *where: Where, query: else : sql.select(Field(t, attrs.column or field.name)) + if limit : + sql.limit(limit) + if not query : query = partial(self.query_async, commit=False) assert query - data: List[Tuple[Any, ...]] = await query(sql, fetch_all=True) + data: list[tuple[Any, ...]] = await query(sql, fetch_all=True) return [SqlInterface._assign_field_values(model, row) for row in data] @@ -572,7 +599,11 @@ async def __aenter__(self: Self) : return self - async def __aexit__(self: Self, exc_type: Optional[Type[BaseException]], exc_obj: Optional[BaseException], exc_tb: Optional[TracebackType]) : + async def __await__(self: Self) : + pass + + + async def __aexit__(self: Self, exc_type: Optional[type[BaseException]], exc_obj: Optional[BaseException], exc_tb: Optional[TracebackType]) : if not self.nested : if self.conn : await self.conn.__aexit__(exc_type, exc_obj, exc_tb) @@ -609,6 +640,7 @@ async def query_async( assert self.conn params = tuple(map(self._sql._convert_item, params)) + self._sql.logger.debug({ 'sql': sql, 'params': params }) try : # TODO: convert fuzzly's Query implementation into a psycopg composable diff --git a/shared/timing/__init__.py b/shared/timing/__init__.py index 3345b32..a690c0d 100644 --- a/shared/timing/__init__.py +++ b/shared/timing/__init__.py @@ -54,6 +54,34 @@ def __repr__(self: Self) -> str : ')' ) + @staticmethod + def parse(json: dict[str, dict[str, int | float | dict]]) -> 'Execution' : + assert len(json) == 1 + k, v = json.items().__iter__().__next__() + return Execution( + name = k, + )._parse( + v, + ) + + def _parse(self: Self, json: dict[str, int | float | dict]) -> 'Execution' : + for k, v in json.items() : + match v : + case float() : + self.total = Time(v) + + case int() : + self.count = v + + case _ : + self.nested[k] = Execution( + name = k, + )._parse( + v, + ) + + return self + def record(self: Self, time: float) : self.total = Time(self.total + time) self.count += 1 diff --git a/shared/utilities/__init__.py b/shared/utilities/__init__.py index 4e5b95c..d5b6984 100644 --- a/shared/utilities/__init__.py +++ b/shared/utilities/__init__.py @@ -1,9 +1,14 @@ +from asyncio import Task, create_task from collections import OrderedDict +from contextvars import Context from math import ceil from time import time -from typing import Any, Callable, Generator, Hashable, Iterable, Optional, Tuple, Type, TypeVar +from types import CoroutineType +from typing import Any, Callable, Generator, Hashable, Iterable, Optional, Tuple, Type +from uuid import UUID from pydantic import parse_obj_as +from uuid_extensions import uuid7 as _uuid7 def __clear_cache__(cache: OrderedDict[Hashable, Tuple[float, Any]], t: Callable[[], float] = time) -> None : @@ -73,3 +78,25 @@ def coerse[T](obj: Any, type: Type[T]) -> T : :raises: pydantic.ValidationError on failure """ return parse_obj_as(type, obj) + + +def uuid7() -> UUID : + guid = _uuid7() + assert isinstance(guid, UUID) + return guid + + +background_tasks: set[Task] = set() +def ensure_future[T](fut: CoroutineType[Any, Any, T], name: str | None = None, context: Context | None = None) -> Task[T] : + """ + `utilities.ensure_future` differs from `asyncio.ensure_future` in that this utility function stores a strong + reference to the created task so that it will not get garbage collected before completion. + + `utilities.ensure_future` should be used whenever a task needs to be completed, but not within the context of + a request. Otherwise, `asyncio.create_task` should be used. + """ + # from https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task + + background_tasks.add(task := create_task(fut, name=name, context=context)) + task.add_done_callback(background_tasks.discard) + return task diff --git a/shared/utilities/json.py b/shared/utilities/json.py index 1ac1af7..7d52d63 100644 --- a/shared/utilities/json.py +++ b/shared/utilities/json.py @@ -1,23 +1,21 @@ -from datetime import datetime +from datetime import datetime, timedelta from decimal import Decimal from enum import Enum, IntEnum +from hashlib import sha1 from traceback import format_tb -from typing import Any, Callable +from typing import Any, Callable, TypeVar from uuid import UUID from pydantic import BaseModel from ..base64 import b64decode -from ..crc import CRC -from ..models.auth import AuthToken, KhUser +from ..models.auth import AuthToken, _KhUser from . import getFullyQualifiedClassName -crc = CRC(32) - - -_conversions: dict[type, Callable] = { +_conversions: dict[type[T := TypeVar('T')], Callable[[T], Any]] = { datetime: str, + timedelta: timedelta.total_seconds, Decimal: float, float: float, int: int, @@ -31,7 +29,7 @@ IntEnum: lambda x : x.name, Enum: lambda x : x.name, UUID: lambda x : x.hex, - KhUser: lambda x : { + _KhUser: lambda x : { 'user_id': x.user_id, 'scope': json_stream(x.scope), 'token': json_stream(x.token) if x.token else None, @@ -44,7 +42,7 @@ 'token': { 'len': len(x.token_string), 'version': int(b64decode(x.token_string[:x.token_string.find('.')]).decode()), - 'hash': f'{crc(x.token_string.encode()):x}', + 'hash': sha1(x.token_string.encode()).hexdigest(), }, }, BaseModel: lambda x : json_stream(x.dict()), diff --git a/tags/repository.py b/tags/repository.py index ca9c6e5..68a7af2 100644 --- a/tags/repository.py +++ b/tags/repository.py @@ -12,7 +12,7 @@ from shared.sql import SqlInterface from shared.timing import timed from shared.utilities import flatten -from users.repository import Users +from users.repository import Repository as Users from .models import InternalTag, Tag, TagGroup, TagGroups, TagPortable @@ -23,7 +23,7 @@ users = Users() -class Tags(SqlInterface) : +class Repository(SqlInterface) : # TODO: figure out a way that we can increase this TTL (updating inheritance won't be reflected in cache) @timed @@ -138,11 +138,15 @@ async def _get_tag_counts(self, tags: Iterable[str]) -> dict[str, int] : """ counts = await CountKVS.get_many_async(tags) + found: dict[str, int] = { } for k, v in counts.items() : - if v is Undefined : - counts[k] = await self._populate_tag_cache(k) + if isinstance(v, int) : + found[k] = v - return counts + else : + found[k] = await self._populate_tag_cache(k) + + return found async def _get_tag_count(self, tag: str) -> int : diff --git a/tags/router.py b/tags/router.py index d1ecb94..dd3ec28 100644 --- a/tags/router.py +++ b/tags/router.py @@ -3,7 +3,7 @@ from posts.models import PostId from shared.auth import Scope from shared.exceptions.http_error import Forbidden -from shared.server import Request +from shared.models.server import Request from shared.timing import timed from .models import InheritRequest, LookupRequest, RemoveInheritance, Tag, TagGroups, TagsRequest, UpdateRequest diff --git a/tags/tagger.py b/tags/tagger.py index bb3be62..ca85c90 100644 --- a/tags/tagger.py +++ b/tags/tagger.py @@ -6,7 +6,7 @@ from psycopg.errors import NotNullViolation, UniqueViolation from posts.models import InternalPost, PostId, Privacy -from posts.repository import Posts +from posts.repository import Repository as Posts from shared.auth import KhUser, Scope from shared.caching import AerospikeCache, SimpleCache from shared.exceptions.http_error import BadRequest, Conflict, Forbidden, HttpErrorHandler, NotFound @@ -16,14 +16,14 @@ from shared.utilities import flatten from .models import InternalTag, Tag, TagGroup, TagGroups -from .repository import TagKVS, Tags, users +from .repository import Repository, TagKVS, users posts = Posts() Misc: TagGroup = TagGroup('misc') -class Tagger(Tags) : +class Tagger(Repository) : def _validateDescription(self, description: str) : if len(description) > 1000 : diff --git a/users/repository.py b/users/repository.py index 5e7c2ea..d0aa06e 100644 --- a/users/repository.py +++ b/users/repository.py @@ -1,6 +1,8 @@ from asyncio import ensure_future +from collections import defaultdict +from curses.panel import bottom_panel from datetime import datetime -from typing import Iterable, Optional, Self, Union +from typing import Iterable, Mapping, Optional, Self, Union from cache import AsyncLRU @@ -9,12 +11,13 @@ from shared.caching.key_value_store import KeyValueStore from shared.exceptions.http_error import BadRequest, NotFound from shared.maps import privacy_map -from shared.models import Badge, InternalUser, Privacy, Undefined, User, UserPortable, UserPrivacy, Verified +from shared.models import Badge, InternalUser, PostId, Privacy, User, UserPortable, UserPrivacy, Verified from shared.sql import SqlInterface from shared.timing import timed UserKVS: KeyValueStore = KeyValueStore('kheina', 'users', local_TTL=60) +handleKVS: KeyValueStore = KeyValueStore('kheina', 'user_handle_map', local_TTL=60) FollowKVS: KeyValueStore = KeyValueStore('kheina', 'following') @@ -95,7 +98,7 @@ async def get_id(self: Self, key: Badge) -> int : badge_map: BadgeMap = BadgeMap() -class Users(SqlInterface) : +class Repository(SqlInterface) : def _clean_text(self: Self, text: str) -> Optional[str] : text = text.strip() @@ -199,20 +202,35 @@ async def _get_users(self, user_ids: Iterable[int]) -> dict[int, InternalUser] : if not user_ids : return { } - cached = await UserKVS.get_many_async(user_ids) + cached = await UserKVS.get_many_async(map(str, user_ids)) + found: dict[int, InternalUser] = { } misses: list[int] = [] - for k, v in list(cached.items()) : - if v is not Undefined : + for k, v in cached.items() : + if isinstance(v, InternalUser) : + found[int(k)] = v continue - misses.append(k) - del cached[k] + misses.append(int(k)) if not misses : - return cached - - data: list[tuple] = await self.query_async(""" + return found + + data: list[tuple[ + int, + str, + str, + int, + PostId | None, + str | None, + datetime, + str | None, + PostId | None, + bool, + bool, + bool, + list[int], + ]] = await self.query_async(""" SELECT users.user_id, users.display_name, @@ -236,13 +254,13 @@ async def _get_users(self, user_ids: Iterable[int]) -> dict[int, InternalUser] : """, ( misses, ), - fetch_all=True, + fetch_all = True, ) if not data : - raise NotFound('not all users could be found.', user_ids=user_ids, misses=misses, cached=cached, data=data) + raise NotFound('not all users could be found.', user_ids=user_ids, misses=misses, found=found, data=data) - users: dict[int, InternalUser] = cached + users: dict[int, InternalUser] = found for datum in data : verified: Optional[Verified] = None @@ -276,7 +294,8 @@ async def _get_users(self, user_ids: Iterable[int]) -> dict[int, InternalUser] : return users - @AerospikeCache('kheina', 'user_handle_map', '{handle}', local_TTL=60) + @timed + @AerospikeCache('kheina', 'user_handle_map', '{handle}', local_TTL=60, _kvs=handleKVS) async def _handle_to_user_id(self: Self, handle: str) -> int : data = await self.query_async(""" SELECT @@ -286,7 +305,7 @@ async def _handle_to_user_id(self: Self, handle: str) -> int : """, ( handle.lower(), ), - fetch_one=True, + fetch_one = True, ) if not data : @@ -295,6 +314,46 @@ async def _handle_to_user_id(self: Self, handle: str) -> int : return data[0] + @timed + async def _handles_to_user_ids(self: Self, handles: Iterable[str]) -> dict[str, int] : + handles = list(handles) + + if not handles : + return { } + + cached = await handleKVS.get_many_async(handles) + found: dict[str, int] = { } + misses: list[str] = [] + + for k, v in cached.items() : + if isinstance(v, int) : + found[k] = v + continue + + misses.append(k) + + if not misses : + return found + + data: list[tuple[str, int]] = await self.query_async(""" + SELECT + users.handle, + users.user_id + FROM kheina.public.users + WHERE users.handle = any(%s); + """, ( + misses, + ), + fetch_all = True, + ) + + for datum in data : + found[datum[0]] = datum[1] + ensure_future(handleKVS.put_async(datum[0], datum[1])) + + return found + + async def _get_user_by_handle(self: Self, handle: str) -> InternalUser : user_id: int = await self._handle_to_user_id(handle.lower()) return await self._get_user(user_id) @@ -337,17 +396,18 @@ async def following_many(self: Self, user_id: int, targets: list[int]) -> dict[i int(k[k.rfind('|') + 1:]): v for k, v in (await FollowKVS.get_many_async([f'{user_id}|{t}' for t in targets])).items() } + found: dict[int, bool] = { } misses: list[int] = [] - for k, v in list(cached.items()) : - if v is not Undefined : + for k, v in cached.items() : + if isinstance(v, bool) : + found[k] = v continue misses.append(k) - cached[k] = None if not misses : - return cached + return found data: list[tuple[int, int]] = await self.query_async(""" SELECT following.follows, count(1) @@ -359,10 +419,10 @@ async def following_many(self: Self, user_id: int, targets: list[int]) -> dict[i user_id, misses, ), - fetch_all=True, + fetch_all = True, ) - return_value: dict[int, bool] = cached + return_value: dict[int, bool] = found for target, following in data : following = bool(following) @@ -411,7 +471,18 @@ async def portables(self: Self, user: KhUser, iusers: Iterable[InternalUser]) -> returns a map of user id -> UserPortable """ - following = await self.following_many(user.user_id, [iuser.user_id for iuser in iusers]) + iusers = list(iusers) + if not iusers : + return { } + + following: Mapping[int, Optional[bool]] + + if await user.authenticated(False) : + following = await self.following_many(user.user_id, [iuser.user_id for iuser in iusers]) + + else : + following = defaultdict(lambda : None) + return { iuser.user_id: UserPortable( name = iuser.name, diff --git a/users/router.py b/users/router.py index a9d187f..a5c7473 100644 --- a/users/router.py +++ b/users/router.py @@ -5,8 +5,8 @@ from shared.exceptions.http_error import HttpErrorHandler from shared.models import Badge, User from shared.models.auth import Scope +from shared.models.server import Request from shared.models.user import SetMod, SetVerified, UpdateSelf -from shared.server import Request from shared.timing import timed from .users import Users @@ -22,14 +22,6 @@ users: Users = Users() -################################################## INTERNAL ################################################## -# @app.get('/i1/{user_id}', response_model=InternalUser) -# async def i1User(req: Request, user_id: int) : -# await req.user.verify_scope(Scope.internal) -# return await users._get_user(user_id) - - -################################################## PUBLIC ################################################## @userRouter.get('/self', response_model=User) @timed.root @HttpErrorHandler("retrieving user's own profile") @@ -105,18 +97,18 @@ async def v1User(req: Request, handle: str) : return await users.getUser(req.user, handle) -@userRouter.put('/{handle}/follow', status_code=204) +@userRouter.put('/{handle}/follow', response_model=bool) @timed.root -async def v1FollowUser(req: Request, handle: str) : +async def v1FollowUser(req: Request, handle: str) -> bool : await req.user.authenticated() - await users.followUser(req.user, handle.lower()) + return await users.followUser(req.user, handle.lower()) -@userRouter.delete('/{handle}/follow', status_code=204) +@userRouter.delete('/{handle}/follow', response_model=bool) @timed.root -async def v1UnfollowUser(req: Request, handle: str) : +async def v1UnfollowUser(req: Request, handle: str) -> bool : await req.user.authenticated() - await users.unfollowUser(req.user, handle.lower()) + return await users.unfollowUser(req.user, handle.lower()) app = APIRouter( diff --git a/users/users.py b/users/users.py index c0c1a41..c7a34d9 100644 --- a/users/users.py +++ b/users/users.py @@ -1,68 +1,93 @@ -from asyncio import Task, ensure_future -from typing import List, Optional +from asyncio import Task, create_task +from typing import Optional, Self +from notifications.models import InternalUserNotification, UserNotificationEvent +from notifications.repository import notifier from shared.auth import KhUser from shared.caching import SimpleCache from shared.exceptions.http_error import BadRequest, HttpErrorHandler, NotFound from shared.models import Badge, InternalUser, User, UserPrivacy, Verified +from shared.models._shared import UserPortable +from shared.timing import timed +from shared.utilities import ensure_future -from .repository import FollowKVS, UserKVS, Users, badge_map, privacy_map # type: ignore +from .repository import FollowKVS, Repository, UserKVS, badge_map, privacy_map -class Users(Users) : +class Users(Repository) : @HttpErrorHandler('retrieving user') - async def getUser(self: 'Users', user: KhUser, handle: str) -> User : + async def getUser(self: Self, user: KhUser, handle: str) -> User : iuser: InternalUser = await self._get_user_by_handle(handle) return await self.user(user, iuser) - async def followUser(self: 'Users', user: KhUser, handle: str) -> None : - user_id: int = await self._handle_to_user_id(handle.lower()) + @timed + async def followUser(self: Self, user: KhUser, handle: str) -> bool : + user_id: int = await self._handle_to_user_id(handle.lower()) following: bool = await self.following(user.user_id, user_id) if following : raise BadRequest('you are already following this user.') + portable: Task[UserPortable] = create_task(self.portable( + KhUser(user_id=user_id), + await self._get_user(user.user_id), + )) + await self.query_async(""" INSERT INTO kheina.public.following (user_id, follows) VALUES (%s, %s); - """, - (user.user_id, user_id), - commit=True, + """, ( + user.user_id, + user_id, + ), + commit = True, ) - FollowKVS.put(f'{user.user_id}|{user_id}', True) + ensure_future(FollowKVS.put_async(f'{user.user_id}|{user_id}', following := True)) + ensure_future(notifier.sendNotification( + user_id, + InternalUserNotification( + event = UserNotificationEvent.follow, + user_id = user.user_id, + ), + user = await portable, + )) + return following - async def unfollowUser(self: 'Users', user: KhUser, handle: str) -> None : - user_id: int = await self._handle_to_user_id(handle.lower()) + async def unfollowUser(self: Self, user: KhUser, handle: str) -> bool : + user_id: int = await self._handle_to_user_id(handle.lower()) following: bool = await self.following(user.user_id, user_id) if following is False : - raise BadRequest('you are already not following this user.') + raise BadRequest('you are not currently following this user.') await self.query_async(""" DELETE FROM kheina.public.following WHERE following.user_id = %s AND following.follows = %s - """, - (user.user_id, user_id), - commit=True, + """, ( + user.user_id, + user_id, + ), + commit = True, ) - FollowKVS.put(f'{user.user_id}|{user_id}', False) + ensure_future(FollowKVS.put_async(f'{user.user_id}|{user_id}', following := False)) + return following - async def getSelf(self: 'Users', user: KhUser) -> User : + async def getSelf(self: Self, user: KhUser) -> User : iuser: InternalUser = await self._get_user(user.user_id) return await self.user(user, iuser) @HttpErrorHandler('updating user profile') - async def updateSelf(self: 'Users', user: KhUser, name: Optional[str], privacy: Optional[UserPrivacy], website: Optional[str], description: Optional[str]) -> None : + async def updateSelf(self: Self, user: KhUser, name: Optional[str], privacy: Optional[UserPrivacy], website: Optional[str], description: Optional[str]) -> None : iuser: InternalUser = await self._get_user(user.user_id) if not any([name, privacy, website, description]) : @@ -87,7 +112,7 @@ async def updateSelf(self: 'Users', user: KhUser, name: Optional[str], privacy: @HttpErrorHandler('fetching all users') - async def getUsers(self: 'Users', user: KhUser) : + async def getUsers(self: Self, user: KhUser) : # TODO: this function desperately needs to be reworked data = await self.query_async(""" SELECT @@ -119,7 +144,7 @@ async def getUsers(self: 'Users', user: KhUser) : users.admin, users.verified; """, - fetch_all=True, + fetch_all = True, ) return [ @@ -138,24 +163,26 @@ async def getUsers(self: 'Users', user: KhUser) : Verified.artist if row[10] else None ) ), - following=None, + following = None, ) for row in data ] @HttpErrorHandler('setting mod') - async def setMod(self: 'Users', handle: str, mod: bool) -> None : + async def setMod(self: Self, handle: str, mod: bool) -> None : user_id: int = await self._handle_to_user_id(handle.lower()) - user_task: Task[InternalUser] = ensure_future(self._get_user(user_id)) + user_task: Task[InternalUser] = create_task(self._get_user(user_id)) await self.query_async(""" UPDATE kheina.public.users SET mod = %s WHERE users.user_id = %s - """, - (mod, user_id), - commit=True, + """, ( + mod, + user_id, + ), + commit = True, ) user: Optional[InternalUser] = await user_task @@ -165,13 +192,13 @@ async def setMod(self: 'Users', handle: str, mod: bool) -> None : @SimpleCache(60) - async def fetchBadges(self: 'Users') -> List[Badge] : + async def fetchBadges(self: Self) -> list[Badge] : return await badge_map.all() @HttpErrorHandler('adding badge to self') - async def addBadge(self: 'Users', user: KhUser, badge: Badge) -> None : - iuser_task: Task[InternalUser] = ensure_future(self._get_user(user.user_id)) + async def addBadge(self: Self, user: KhUser, badge: Badge) -> None : + iuser_task: Task[InternalUser] = create_task(self._get_user(user.user_id)) try : badge_id: int = await badge_map.get_id(badge) @@ -188,9 +215,11 @@ async def addBadge(self: 'Users', user: KhUser, badge: Badge) -> None : (user_id, badge_id) VALUES (%s, %s); - """, - (user.user_id, badge_id), - commit=True, + """, ( + user.user_id, + badge_id, + ), + commit = True, ) iuser.badges.append(badge) @@ -198,8 +227,8 @@ async def addBadge(self: 'Users', user: KhUser, badge: Badge) -> None : @HttpErrorHandler('removing badge from self') - async def removeBadge(self: 'Users', user: KhUser, badge: Badge) -> None : - iuser_task: Task[InternalUser] = ensure_future(self._get_user(user.user_id)) + async def removeBadge(self: Self, user: KhUser, badge: Badge) -> None : + iuser_task: Task[InternalUser] = create_task(self._get_user(user.user_id)) try : badge_id: int = await badge_map.get_id(badge) @@ -218,39 +247,44 @@ async def removeBadge(self: 'Users', user: KhUser, badge: Badge) -> None : DELETE FROM kheina.public.user_badge WHERE user_id = %s AND badge_id = %s; - """, - (user.user_id, badge_id), - commit=True, + """, ( + user.user_id, + badge_id, + ), + commit = True, ) UserKVS.put(str(user.user_id), iuser) @HttpErrorHandler('creating badge') - async def createBadge(self: 'Users', badge: Badge) -> None : + async def createBadge(self: Self, badge: Badge) -> None : await self.query_async(""" INSERT INTO kheina.public.badges (emoji, label) VALUES (%s, %s); - """, - (badge.emoji, badge.label), - commit=True, + """, ( + badge.emoji, + badge.label, + ), + commit = True, ) @HttpErrorHandler('verifying user') - async def verifyUser(self: 'Users', handle: str, verified: Verified) -> None : + async def verifyUser(self: Self, handle: str, verified: Verified) -> None : user_id: int = await self._handle_to_user_id(handle.lower()) - user_task: Task[InternalUser] = ensure_future(self._get_user(user_id)) + user_task: Task[InternalUser] = create_task(self._get_user(user_id)) await self.query_async(f""" UPDATE kheina.public.users set {'verified' if verified == Verified.artist else verified.name} = true WHERE users.user_id = %s; - """, - (user_id,), - commit=True, + """, ( + user_id, + ), + commit = True, ) user: Optional[InternalUser] = await user_task