From 4ffc50730b431468ebfdd40aa26ddc3eb568541d Mon Sep 17 00:00:00 2001 From: dani <29378233+DanielleMiu@users.noreply.github.com> Date: Wed, 9 Apr 2025 02:14:47 -0400 Subject: [PATCH 1/5] meow --- account/account.py | 2 +- account/router.py | 18 +- authenticator/authenticator.py | 71 ++-- avro_schema_repository/schema_repository.py | 48 +-- configs/configs.py | 325 ++++++++++-------- configs/models.py | 119 +++++-- configs/router.py | 17 +- db/10/00-add-notification-subscriptions.sql | 54 +++ db/11/00-create-user-configs-table.sql | 17 + emojis/repository.py | 2 +- emojis/router.py | 2 +- init.py | 53 ++- notifications/__init__.py | 0 notifications/models.py | 143 ++++++++ notifications/repository.py | 286 ++++++++++++++++ notifications/router.py | 50 +++ posts/blocking.py | 47 +-- posts/models.py | 5 +- posts/posts.py | 18 +- posts/repository.py | 195 +++++------ posts/router.py | 8 +- posts/uploader.py | 350 ++++++++++---------- reporting/mod_actions.py | 112 +++++-- reporting/models/__init__.py | 3 +- reporting/models/actions.py | 3 +- reporting/models/bans.py | 2 +- reporting/models/mod_queue.py | 2 +- reporting/reporting.py | 16 +- reporting/repository.py | 32 +- reporting/router.py | 39 ++- requirements.lock | 5 + sample-creds.json | 7 +- server.py | 2 + sets/repository.py | 13 +- sets/router.py | 3 +- sets/sets.py | 11 +- shared/auth/__init__.py | 4 +- shared/caching/__init__.py | 10 +- shared/caching/key_value_store.py | 149 ++++++--- shared/exceptions/base_error.py | 3 +- shared/exceptions/http_error.py | 14 +- shared/logging.py | 28 +- shared/models/auth.py | 2 +- shared/models/config.py | 62 ++++ shared/models/encryption.py | 106 ++++++ shared/models/server.py | 15 + shared/server/__init__.py | 105 ------ shared/sql/__init__.py | 70 ++-- shared/timing/__init__.py | 28 ++ shared/utilities/__init__.py | 29 +- shared/utilities/json.py | 18 +- tags/repository.py | 14 +- tags/router.py | 2 +- tags/tagger.py | 6 +- users/repository.py | 113 +++++-- users/router.py | 10 +- users/users.py | 70 ++-- 57 files changed, 1992 insertions(+), 946 deletions(-) create mode 100644 db/10/00-add-notification-subscriptions.sql create mode 100644 db/11/00-create-user-configs-table.sql create mode 100644 notifications/__init__.py create mode 100644 notifications/models.py create mode 100644 notifications/repository.py create mode 100644 notifications/router.py create mode 100644 shared/models/config.py create mode 100644 shared/models/encryption.py create mode 100644 shared/models/server.py diff --git a/account/account.py b/account/account.py index bc38a8a..5d3d679 100644 --- a/account/account.py +++ b/account/account.py @@ -13,7 +13,7 @@ from shared.exceptions.http_error import BadRequest, Conflict, HttpError, HttpErrorHandler, Unauthorized from shared.hashing import Hashable from shared.models.auth import AuthToken, Scope -from shared.server import Request +from shared.models.server import Request from shared.sql import SqlInterface diff --git a/account/router.py b/account/router.py index eed29f5..4c7e8ae 100644 --- a/account/router.py +++ b/account/router.py @@ -6,15 +6,15 @@ from shared.config.constants import environment from shared.datetime import datetime from shared.exceptions.http_error import BadRequest -from shared.server import Request +from shared.models.server import Request from .account import Account, auth from .models import ChangeHandle, CreateAccountRequest, FinalizeAccountRequest, OtpFinalizeRequest, OtpRemoveEmailRequest, OtpRemoveRequest, OtpRequest app = APIRouter( - prefix='/v1/account', - tags=['account'], + prefix = '/v1/account', + tags = ['account'], ) account = Account() @@ -108,19 +108,7 @@ async def v1BotLogin(body: BotLoginRequest) -> LoginResponse : return await auth.botLogin(body.token) -@app.post('/bot/renew', response_model=BotCreateResponse) -async def v1BotRenew(req: Request) -> BotCreateResponse : - await req.user.verify_scope(Scope.internal) - return await auth.createBot(req.user, BotType.internal) - - @app.get('/bot/create', response_model=BotCreateResponse) async def v1BotCreate(req: Request) -> BotCreateResponse : await req.user.verify_scope(Scope.user) return await auth.createBot(req.user, BotType.bot) - - -@app.get('/bot/internal', response_model=BotCreateResponse) -async def v1BotCreateInternal(req: Request) -> BotCreateResponse : - await req.user.verify_scope(Scope.admin) - return await auth.createBot(req.user, BotType.internal) diff --git a/authenticator/authenticator.py b/authenticator/authenticator.py index 57d52be..b596ee0 100644 --- a/authenticator/authenticator.py +++ b/authenticator/authenticator.py @@ -26,7 +26,7 @@ from shared.exceptions.http_error import BadRequest, Conflict, FailedLogin, HttpError, InternalServerError, NotFound, UnprocessableEntity from shared.hashing import Hashable from shared.models import InternalUser -from shared.models.auth import AuthState, AuthToken, KhUser, Scope, TokenMetadata +from shared.models.auth import AuthState, AuthToken, Scope, TokenMetadata, _KhUser from shared.sql import SqlInterface from shared.timing import timed from shared.utilities.json import json_stream @@ -153,6 +153,13 @@ def __init__(self) : if not getattr(Authenticator, 'KVS', None) : Authenticator.KVS = KeyValueStore('kheina', 'token') + # create the index used to query active logins + KeyValueStore._client.index_integer_create( # type: ignore + 'kheina', + 'token', + 'user_id', + 'kheina_token_user_id_idx', + ) def _validateEmail(self, email: str) -> Dict[str, str] : @@ -201,18 +208,18 @@ async def generate_token(self, user_id: int, token_data: dict, ttl: Optional[int start = self._calc_timestamp(issued) end = start + self._key_refresh_interval self._active_private_key = { - 'key': None, + 'key': None, 'algorithm': self._token_algorithm, - 'issued': 0, - 'start': start, - 'end': end, - 'id': 0, + 'issued': 0, + 'start': start, + 'end': end, + 'id': 0, } private_key = self._active_private_key['key'] = Ed25519PrivateKey.generate() public_key = private_key.public_key().public_bytes( - encoding=serialization.Encoding.DER, - format=serialization.PublicFormat.SubjectPublicKeyInfo, + encoding = serialization.Encoding.DER, + format = serialization.PublicFormat.SubjectPublicKeyInfo, ) signature = private_key.sign(public_key) @@ -237,10 +244,10 @@ async def generate_token(self, user_id: int, token_data: dict, ttl: Optional[int # put the new key into the public keyring self._public_keyring[(self._token_algorithm, key_id)] = { - 'key': b64encode(public_key).decode(), + 'key': b64encode(public_key).decode(), 'signature': b64encode(signature).decode(), - 'issued': pk_issued, - 'expires': pk_expires, + 'issued': pk_issued, + 'expires': pk_expires, } guid: UUID = uuid4() @@ -255,16 +262,22 @@ async def generate_token(self, user_id: int, token_data: dict, ttl: Optional[int ]) token_info: TokenMetadata = TokenMetadata( - version=self._token_version.encode(), - state=AuthState.active, - issued=datetime.fromtimestamp(issued), - expires=datetime.fromtimestamp(expires), - key_id=key_id, - user_id=user_id, - algorithm=self._token_algorithm, - fingerprint=token_data.get('fp', '').encode(), + version = self._token_version.encode(), + state = AuthState.active, + issued = datetime.fromtimestamp(issued), + expires = datetime.fromtimestamp(expires), + key_id = key_id, + user_id = user_id, + algorithm = self._token_algorithm, + fingerprint = token_data.get('fp', '').encode(), + ) + await Authenticator.KVS.put_async( + guid.bytes, + token_info, + ttl or self._token_expires_interval, + # additional bins for querying active logins + { 'user_id': user_id }, ) - await Authenticator.KVS.put_async(guid.bytes, token_info, ttl or self._token_expires_interval) version = self._token_version.encode() content = b64encode(version) + b'.' + b64encode(load) @@ -272,12 +285,12 @@ async def generate_token(self, user_id: int, token_data: dict, ttl: Optional[int token = content + b'.' + b64encode(signature) return TokenResponse( - version=self._token_version, - algorithm=self._token_algorithm, # type: ignore - key_id=key_id, - issued=issued, # type: ignore - expires=expires, # type: ignore - token=token.decode(), + version = self._token_version, + algorithm = self._token_algorithm, # type: ignore + key_id = key_id, + issued = issued, # type: ignore + expires = expires, # type: ignore + token = token.decode(), ) @@ -440,7 +453,7 @@ async def login(self, email: str, password: str, otp: Optional[str], token_data: ) - async def createBot(self, user: KhUser, bot_type: BotType) -> BotCreateResponse : + async def createBot(self, user: _KhUser, bot_type: BotType) -> BotCreateResponse : if type(bot_type) is not BotType : # this should never run, thanks to pydantic/fastapi. just being extra careful. raise BadRequest('bot_type must be a BotType value.') @@ -709,11 +722,11 @@ async def create(self, handle: str, name: str, email: str, password: str, token_ raise InternalServerError('an error occurred during user creation.', logdata={ 'refid': refid }) - async def create_otp(self: Self, user: KhUser) -> str : + async def create_otp(self: Self, user: _KhUser) -> str : return pyotp.random_base32() - async def add_otp(self: Self, user: KhUser, email: str, otp_secret: str, otp: str) -> OtpAddedResponse : + async def add_otp(self: Self, user: _KhUser, email: str, otp_secret: str, otp: str) -> OtpAddedResponse : if not pyotp.TOTP(otp_secret).verify(otp) : raise BadRequest('failed to add OTP', email=email, user=user) diff --git a/avro_schema_repository/schema_repository.py b/avro_schema_repository/schema_repository.py index 1c3750f..4fd3dfc 100644 --- a/avro_schema_repository/schema_repository.py +++ b/avro_schema_repository/schema_repository.py @@ -1,18 +1,17 @@ -from typing import List +from hashlib import sha1 import ujson from avrofastapi.schema import AvroSchema -from shared.base64 import b64encode from shared.caching import AerospikeCache from shared.caching.key_value_store import KeyValueStore -from shared.crc import CRC from shared.exceptions.http_error import HttpErrorHandler, NotFound from shared.sql import SqlInterface -KVS: KeyValueStore = KeyValueStore('kheina', 'avro_schemas', local_TTL=60) -crc: CRC = CRC(64) +AvroMarker: bytes = b'\xC3\x01' +kvs: KeyValueStore = KeyValueStore('kheina', 'avro_schemas', local_TTL=60) +key_format: str = '{fingerprint}' def int_to_bytes(integer: int) -> bytes : @@ -23,24 +22,29 @@ def int_from_bytes(bytestring: bytes) -> int : return int.from_bytes(bytestring, 'little') +def crc(value: bytes) -> int : + return int.from_bytes(sha1(value).digest()[:8]) + + class SchemaRepository(SqlInterface) : @HttpErrorHandler('retrieving schema') - @AerospikeCache('kheina', 'avro_schemas', '{fingerprint}', _kvs=KVS) + @AerospikeCache('kheina', 'avro_schemas', key_format, _kvs=kvs) async def getSchema(self, fingerprint: bytes) -> bytes : """ - returns the avro schema as a json encoded string + returns the avro schema as a json encoded byte string """ fp: int = int_from_bytes(fingerprint) - data: List[bytes] = await self.query_async(""" + data: list[bytes] = await self.query_async(""" SELECT schema FROM kheina.public.avro_schemas WHERE fingerprint = %s; - """, + """, ( # because crc returns unsigned, we "convert" to signed - (fp - 9223372036854775808,), - fetch_one=True, + fp - 9223372036854775808, + ), + fetch_one = True, ) if not data : @@ -51,8 +55,11 @@ async def getSchema(self, fingerprint: bytes) -> bytes : @HttpErrorHandler('saving schema') async def addSchema(self, schema: AvroSchema) -> bytes : + """ + returns the schema fingerprint as a bytestring + """ data: bytes = ujson.dumps(schema).encode() - fingerprint: int = crc(data) + fp: int = crc(data) await self.query_async(""" INSERT INTO kheina.public.avro_schemas @@ -61,14 +68,15 @@ async def addSchema(self, schema: AvroSchema) -> bytes : (%s, %s) ON CONFLICT ON CONSTRAINT avro_schemas_pkey DO UPDATE SET - schema = %s; - """, + schema = excluded.schema; + """, ( # because crc returns unsigned, we "convert" to signed - (fingerprint - 9223372036854775808, data, data), - commit=True, + fp - 9223372036854775808, + data, + ), + commit = True, ) - fp: bytes = int_to_bytes(fingerprint) - KVS.put(b64encode(fp).decode(), schema) - - return fp + fingerprint: bytes = int_to_bytes(fp) + await kvs.put_async(key_format.format(fingerprint=fingerprint), schema) + return fingerprint diff --git a/configs/configs.py b/configs/configs.py index 5386571..618e4e4 100644 --- a/configs/configs.py +++ b/configs/configs.py @@ -1,37 +1,32 @@ -from asyncio import ensure_future +from asyncio import Task, ensure_future from collections.abc import Iterable from datetime import datetime +from enum import Enum from random import randrange from re import Match, Pattern from re import compile as re_compile -from typing import Any, Optional, Self, Type +from typing import Literal, Optional, Self -from avrofastapi.schema import convert_schema -from avrofastapi.serialization import AvroDeserializer, AvroSerializer, Schema, parse_avro_schema -from cache import AsyncLRU +import aerospike from patreon import API as PatreonApi -from pydantic import BaseModel -from avro_schema_repository.schema_repository import SchemaRepository from shared.auth import KhUser from shared.caching import AerospikeCache from shared.caching.key_value_store import KeyValueStore from shared.config.constants import environment from shared.config.credentials import fetch from shared.exceptions.http_error import BadRequest, HttpErrorHandler, NotFound -from shared.models import Undefined +from shared.models import PostId from shared.sql import SqlInterface from shared.timing import timed +from users.repository import Repository as Users -from .models import OTP, BannerStore, ConfigsResponse, ConfigType, CostsStore, CssProperty, Funding, OtpType, UserConfig, UserConfigKeyFormat, UserConfigRequest, UserConfigResponse +from .models import OTP, BannerStore, BlockBehavior, Blocking, BlockingBehavior, ConfigsResponse, ConfigType, CostsStore, CssProperty, Funding, OtpType, Store, Theme, UserConfigKeyFormat, UserConfigResponse, UserConfigType -repo: SchemaRepository = SchemaRepository() - PatreonClient: PatreonApi = PatreonApi(fetch('creator_access_token', str)) KVS: KeyValueStore = KeyValueStore('kheina', 'configs', local_TTL=60) -UserConfigSerializer: AvroSerializer = AvroSerializer(UserConfig) -AvroMarker: bytes = b'\xC3\x01' +users: Users = Users() ColorRegex: Pattern = re_compile(r'^(?:#(?P[a-f0-9]{8}|[a-f0-9]{6})|(?P[a-z0-9-]+))$') PropValidators: dict[CssProperty, Pattern] = { CssProperty.background_attachment: re_compile(r'^(?:scroll|fixed|local)(?:,\s*(?:scroll|fixed|local))*$'), @@ -43,30 +38,14 @@ class Configs(SqlInterface) : - UserConfigFingerprint: bytes - Serializers: dict[ConfigType, tuple[AvroSerializer, bytes]] - SerializerTypeMap: dict[ConfigType, Type[BaseModel]] = { - ConfigType.banner: BannerStore, - ConfigType.costs: CostsStore, + SerializerTypeMap: dict[Enum, type[Store]] = { + ConfigType.banner: BannerStore, + ConfigType.costs: CostsStore, + UserConfigType.blocking: Blocking, + UserConfigType.block_behavior: BlockBehavior, + UserConfigType.theme: Theme, } - async def startup(self) -> bool : - Configs.Serializers = { - ConfigType.banner: (AvroSerializer(BannerStore), await repo.addSchema(convert_schema(BannerStore))), - ConfigType.costs: (AvroSerializer(CostsStore), await repo.addSchema(convert_schema(CostsStore))), - } - self.UserConfigFingerprint = await repo.addSchema(convert_schema(UserConfig)) - assert self.Serializers.keys() == set(ConfigType.__members__.values()), 'Did you forget to add serializers for a config?' - assert self.SerializerTypeMap.keys() == set(ConfigType.__members__.values()), 'Did you forget to add serializers for a config?' - return True - - - @AsyncLRU(maxsize=32) - @staticmethod - async def getSchema(fingerprint: bytes) -> Schema : - return parse_avro_schema((await repo.getSchema(fingerprint)).decode()) - - @HttpErrorHandler('retrieving patreon campaign info') @AerospikeCache('kheina', 'configs', 'patreon-campaign-funds', TTL_minutes=10, _kvs=KVS) async def getFunding(self) -> int : @@ -78,24 +57,25 @@ async def getFunding(self) -> int : @HttpErrorHandler('retrieving config') - async def getConfigs(self, configs: Iterable[ConfigType]) -> dict[ConfigType, Any] : + async def getConfigs(self: Self, configs: Iterable[ConfigType]) -> dict[ConfigType, Store] : keys = list(configs) if not keys : return { } - cached = await KVS.get_many_async(keys) + cached = await KVS.get_many_async(keys, Store) + found: dict[ConfigType, Store] = { } misses: list[ConfigType] = [] - for k, v in list(cached.items()) : - if v is not Undefined : + for k, v in cached.items() : + if isinstance(v, Store) : + found[k] = v continue misses.append(k) - del cached[k] if not misses : - return cached + return found data: Optional[list[tuple[str, bytes]]] = await self.query_async(""" SELECT key, bytes @@ -111,55 +91,53 @@ async def getConfigs(self, configs: Iterable[ConfigType]) -> dict[ConfigType, An raise NotFound('no data was found for the provided config.') for k, v in data : - v: bytes = bytes(v) - assert v[:2] == AvroMarker - config: ConfigType = ConfigType(k) - deserializer: AvroDeserializer = AvroDeserializer( - read_model = self.SerializerTypeMap[config], - write_model = await Configs.getSchema(v[2:10]), - ) - value = cached[config] = deserializer(v[10:]) + value = found[config] = await self.SerializerTypeMap[config].deserialize(bytes(v)) ensure_future(KVS.put_async(config, value)) - return cached + return found async def allConfigs(self: Self) -> ConfigsResponse : funds = ensure_future(self.getFunding()) - configs = await self.getConfigs(self.SerializerTypeMap.keys()) + configs = await self.getConfigs([ + ConfigType.banner, + ConfigType.costs, + ]) + banner = configs[ConfigType.banner] + assert isinstance(banner, BannerStore), f'banner is not the expected type of BannerStore, got: {type(banner)}' + costs = configs[ConfigType.costs] + assert isinstance(costs, CostsStore), f'costs is not the expected type of CostsStore, got: {type(costs)}' return ConfigsResponse( - banner = configs[ConfigType.banner].banner, + banner = banner.banner, funding = Funding( funds = await funds, - costs = configs[ConfigType.costs].costs, + costs = costs.costs, ), ) @HttpErrorHandler('updating config') - async def updateConfig(self, user: KhUser, config: ConfigType, value: BaseModel) -> None : - serializer: tuple[AvroSerializer, bytes] = self.Serializers[config] - data: bytes = AvroMarker + serializer[1] + serializer[0](value) + async def updateConfig(self: Self, user: KhUser, config: Store) -> None : await self.query_async(""" - INSERT INTO kheina.public.configs + insert into kheina.public.configs (key, bytes, updated_by) - VALUES - (%s, %s, %s) - ON CONFLICT ON CONSTRAINT configs_pkey DO - UPDATE SET - updated = now(), - bytes = %s, - updated_by = %s; - """, - ( - config, data, user.user_id, - data, user.user_id, + values + ( %s, %s, %s) + on conflict on constraint configs_pkey do + update set + updated = now(), + bytes = excluded.bytes, + updated_by = excluded.updated_by + where key = excluded.key; + """, ( + config.key(), + await config.serialize(), + user.user_id, ), - commit=True, + commit = True, ) - print(config, value) - await KVS.put_async(config, value) + await KVS.put_async(config.key(), config) @staticmethod @@ -207,61 +185,128 @@ def _validateColors(css_properties: Optional[dict[CssProperty, str]]) -> Optiona @HttpErrorHandler('saving user config') - async def setUserConfig(self, user: KhUser, value: UserConfigRequest) -> None : - user_config: UserConfig = UserConfig( - blocking_behavior=value.blocking_behavior, - blocked_tags=list(map(list, value.blocked_tags)) if value.blocked_tags else None, - # TODO: internal tokens need to be added so that we can convert handles to user ids - blocked_users=None, - wallpaper=value.wallpaper, - css_properties=Configs._validateColors(value.css_properties), - ) - - data: bytes = AvroMarker + self.UserConfigFingerprint + UserConfigSerializer(user_config) - config_key: str = UserConfigKeyFormat.format(user_id=user.user_id) - await self.query_async(""" - INSERT INTO kheina.public.configs - (key, bytes, updated_by) - VALUES - (%s, %s, %s) - ON CONFLICT ON CONSTRAINT configs_pkey DO - UPDATE SET - updated = now(), - bytes = %s, - updated_by = %s; - """, ( - config_key, data, user.user_id, - data, user.user_id, - ), - commit=True, + async def setUserConfig( + self: Self, + user: KhUser, + blocking_behavior: BlockingBehavior | None = None, + blocked_tags: list[set[str]] | None = None, + blocked_users: list[str] | None = None, + wallpaper: PostId | None | Literal[False] = False, + css_properties: dict[CssProperty, str] | None | Literal[False] = False, + ) -> None : + stores: list[Store] = [] + + if blocking_behavior : + stores.append(BlockBehavior( + behavior = blocking_behavior, + )) + + if blocked_tags is not None or blocked_users is not None : + blocking = await self._getUserConfig(user.user_id, Blocking) + + if blocked_tags is not None : + blocking.tags = list(map(list, blocked_tags)) + + if blocked_users is not None : + blocking.users = list((await users._handles_to_user_ids(blocked_users)).values()) + + if len(blocking.users) != len(blocked_users) : + raise BadRequest('could not find users for some or all of the provided handles') + + stores.append(blocking) + + if wallpaper is not False or css_properties is not False : + theme = await self._getUserConfig(user.user_id, Theme) + + if wallpaper is not False : + theme.wallpaper = wallpaper + + if css_properties is not False : + theme.css_properties = self._validateColors(css_properties) + + stores.append(theme) + + if not stores : + raise BadRequest('must submit at least one config to update') + + query: list[str] = [] + params: list[int | str | bytes] = [] + + for store in stores : + query.append('(%s, %s, %s, %s)') + params += [ + user.user_id, + store.key(), + store.type_(), + await store.serialize(), + ] + + await self.query_async(f""" + insert into kheina.public.user_configs + (user_id, key, type, data) + values + {','.join(query)} + on conflict on constraint user_configs_pkey do + update set + type = excluded.type, + data = excluded.data + where user_configs.user_id = excluded.user_id + and user_configs.key = excluded.key; + """, + tuple(params), + commit = True, ) - await KVS.put_async(config_key, user_config) + for store in stores : + ensure_future(KVS.put_async( + UserConfigKeyFormat.format( + user_id = user.user_id, + key = store.key(), + ), + store, + )) + + + async def _getUserConfig[T: Store](self: Self, user_id: int, type_: type[T]) -> T : + try : + return await KVS.get_async( + UserConfigKeyFormat.format( + user_id = user_id, + key = type_.key(), + ), + type = type_, + ) + except aerospike.exception.RecordNotFound : + pass - @AerospikeCache('kheina', 'configs', UserConfigKeyFormat, _kvs=KVS) - async def _getUserConfig(self, user_id: int) -> UserConfig : data: list[bytes] = await self.query_async(""" - SELECT bytes - FROM kheina.public.configs - WHERE key = %s; - """, - (UserConfigKeyFormat.format(user_id=user_id),), - fetch_one=True, + select data + from kheina.public.user_configs + where user_id = %s + and key = %s; + """, ( + user_id, + type_.key(), + ), + fetch_one = True, ) if not data : - return UserConfig() + return type_() - value: bytes = bytes(data[0]) - assert value[:2] == AvroMarker - - deserializer: AvroDeserializer[UserConfig] = AvroDeserializer(read_model=UserConfig, write_model=await Configs.getSchema(value[2:10])) - return deserializer(value[10:]) + await KVS.put_async( + UserConfigKeyFormat.format( + user_id = user_id, + key = type_.key(), + ), + res := await type_.deserialize(data[0]), + ) + return res @timed - async def _getUserOTP(self: Self, user_id: int) -> Optional[list[OTP]] : + async def _getUserOTP(self: Self, user_id: int) -> list[OTP] : data: list[tuple[datetime, str]] = await self.query_async(""" select created, 'totp' from kheina.auth.otp @@ -273,7 +318,7 @@ async def _getUserOTP(self: Self, user_id: int) -> Optional[list[OTP]] : ) if not data : - return None + return [] return [ OTP( @@ -285,29 +330,51 @@ async def _getUserOTP(self: Self, user_id: int) -> Optional[list[OTP]] : @HttpErrorHandler('retrieving user config') - async def getUserConfig(self, user: KhUser) -> UserConfigResponse : - user_config: UserConfig = await self._getUserConfig(user.user_id) - - return UserConfigResponse( - blocking_behavior = user_config.blocking_behavior, - blocked_tags = list(map(set, user_config.blocked_tags)) if user_config.blocked_tags else [], - # TODO: convert user ids to handles - blocked_users = None, - wallpaper = user_config.wallpaper.decode() if user_config.wallpaper else None, - otp = await self._getUserOTP(user.user_id), + @timed + async def getUserConfig(self: Self, user: KhUser) -> UserConfigResponse : + data: list[tuple[str, int, bytes]] = await self.query_async(""" + select key, type, data + from kheina.public.user_configs + where user_configs.user_id = %s; + """, ( + user.user_id, + ), + fetch_all = True, ) + res = UserConfigResponse() + otp: Task[list[OTP]] = ensure_future(self._getUserOTP(user.user_id)) + print('==> data:', data) + if data : + for key, type_, value in data : + t = UserConfigType(type_) + + match v := await Configs.SerializerTypeMap[t].deserialize(value) : + case BlockBehavior() : + res.blocking_behavior = v.behavior + + case Blocking() : + res.blocked_tags = v.tags + res.blocked_users = [i.handle for i in (await users._get_users(v.users)).values()] + + case Theme() : + if v.wallpaper or v.css_properties : + res.theme = v + + res.otp = await otp + return res + @HttpErrorHandler('retrieving custom theme') - async def getUserTheme(self, user: KhUser) -> str : - user_config: UserConfig = await self._getUserConfig(user.user_id) + async def getUserTheme(self: Self, user: KhUser) -> str : + theme: Theme = await self._getUserConfig(user.user_id, Theme) - if not user_config.css_properties : + if not theme.css_properties : return '' css_properties: str = '' - for key, value in user_config.css_properties.items() : + for key, value in theme.css_properties.items() : name = key.replace("_", "-") if isinstance(value, int) : diff --git a/configs/models.py b/configs/models.py index 179b8d5..d8289bd 100644 --- a/configs/models.py +++ b/configs/models.py @@ -1,45 +1,51 @@ from datetime import datetime -from enum import Enum, unique -from typing import Dict, List, Literal, Optional, Set, Union +from enum import Enum, IntEnum, unique +from typing import Any, Literal, Optional, Self from avrofastapi.schema import AvroInt -from pydantic import BaseModel, conbytes +from pydantic import BaseModel, Field from shared.models import PostId +from shared.sql.query import Table +from shared.models.config import Store -UserConfigKeyFormat: str = 'user.{user_id}' +UserConfigKeyFormat: Literal['user.{user_id}.{key}'] = 'user.{user_id}.{key}' -class BannerStore(BaseModel) : +@unique +class ConfigType(str, Enum) : # str so literals work + banner = 'banner' + costs = 'costs' + + +class BannerStore(Store) : banner: Optional[str] + @staticmethod + def type_() -> ConfigType : + return ConfigType.banner -class CostsStore(BaseModel) : + +class CostsStore(Store) : costs: int - -@unique -class ConfigType(str, Enum) : # str so literals work - banner = 'banner' - costs = 'costs' + @staticmethod + def type_() -> ConfigType : + return ConfigType.costs class UpdateBannerRequest(BaseModel) : config: Literal[ConfigType.banner] - value: BannerStore + value: BannerStore class UpdateCostsRequest(BaseModel) : config: Literal[ConfigType.costs] - value: CostsStore + value: CostsStore -UpdateConfigRequest = Union[UpdateBannerRequest, UpdateCostsRequest] - - -class SaveSchemaResponse(BaseModel) : - fingerprint: str +UpdateConfigRequest = UpdateBannerRequest | UpdateCostsRequest class Funding(BaseModel) : @@ -52,6 +58,7 @@ class ConfigsResponse(BaseModel) : funding: Funding +@unique class BlockingBehavior(Enum) : hide = 'hide' omit = 'omit' @@ -104,20 +111,55 @@ class CssProperty(Enum) : notification_bg = 'notification_bg' -class UserConfig(BaseModel) : - blocking_behavior: Optional[BlockingBehavior] = None - blocked_tags: Optional[List[List[str]]] = None - blocked_users: Optional[List[int]] = None - wallpaper: Optional[conbytes(min_length=8, max_length=8)] = None - css_properties: Optional[Dict[str, Union[CssProperty, AvroInt, str]]] = None +@unique +class UserConfigType(IntEnum) : + blocking = 0 + block_behavior = 1 + theme = 2 + + +class Blocking(Store) : + tags: list[list[str]] = [] + users: list[int] = [] + + @staticmethod + def type_() -> UserConfigType : + return UserConfigType.blocking + + +class BlockBehavior(Store) : + behavior: BlockingBehavior = BlockingBehavior.hide + + @staticmethod + def type_() -> UserConfigType : + return UserConfigType.block_behavior + + +class Theme(Store) : + wallpaper: Optional[PostId] = None + css_properties: Optional[dict[str, CssProperty | AvroInt | str]] = None + + @staticmethod + def type_() -> UserConfigType : + return UserConfigType.theme class UserConfigRequest(BaseModel) : + field_mask: list[str] blocking_behavior: Optional[BlockingBehavior] - blocked_tags: Optional[List[Set[str]]] - blocked_users: Optional[List[str]] - wallpaper: Optional[PostId] - css_properties: Optional[Dict[CssProperty, str]] + blocked_tags: Optional[list[set[str]]] + blocked_users: Optional[list[str]] + wallpaper: Optional[PostId] + css_properties: Optional[dict[CssProperty, str]] + + def values(self: Self) -> dict[str, Any] : + values = { } + + for f in self.field_mask : + if f in self.__fields_set__ : + values[f] = getattr(self, f) + + return values @unique @@ -132,8 +174,19 @@ class OTP(BaseModel) : class UserConfigResponse(BaseModel) : - blocking_behavior: Optional[BlockingBehavior] - blocked_tags: Optional[List[Set[str]]] - blocked_users: Optional[List[str]] - wallpaper: Optional[str] - otp: Optional[list[OTP]] + blocking_behavior: BlockingBehavior = BlockingBehavior.hide + blocked_tags: list[list[str]] = [] + blocked_users: list[str] = [] + theme: Optional[Theme] = None + otp: list[OTP] = [] + + +class Config(BaseModel) : + __table_name__ = Table('kheina.public.configs') + + key: str = Field(description='orm:"pk"') + created: datetime = Field(description='orm:"gen"') + updated: datetime = Field(description='orm:"gen"') + updated_by: int + value: Optional[str] = None + bytes_: Optional[bytes] = Field(None, description='orm:"col[bytes]"') diff --git a/configs/router.py b/configs/router.py index a0ab2ce..8502629 100644 --- a/configs/router.py +++ b/configs/router.py @@ -4,7 +4,7 @@ from fastapi.responses import PlainTextResponse from shared.auth import Scope -from shared.server import Request +from shared.models.server import Request from .configs import Configs from .models import ConfigsResponse, UpdateConfigRequest, UserConfigRequest, UserConfigResponse @@ -17,24 +17,11 @@ configs: Configs = Configs() -@app.on_event('startup') -async def startup() : - await configs.startup() - - @app.on_event('shutdown') async def shutdown() : await configs.close() -################################################## INTERNAL ################################################## -# @app.get('/i1/user/{user_id}', response_model=UserConfig) -# async def i1UserConfig(req: Request, user_id: int) -> UserConfig : -# await req.user.verify_scope(Scope.internal) -# return await configs._getUserConfig(user_id) - - -################################################## PUBLIC ################################################## @app.get('s', response_model=ConfigsResponse) async def v1Configs() -> ConfigsResponse : return await configs.allConfigs() @@ -45,7 +32,7 @@ async def v1UpdateUserConfig(req: Request, body: UserConfigRequest) -> None : await req.user.verify_scope(Scope.user) await configs.setUserConfig( req.user, - body, + **body.values(), ) diff --git a/db/10/00-add-notification-subscriptions.sql b/db/10/00-add-notification-subscriptions.sql new file mode 100644 index 0000000..e5817f9 --- /dev/null +++ b/db/10/00-add-notification-subscriptions.sql @@ -0,0 +1,54 @@ +begin; + +drop table if exists public.subscriptions; +create table public.subscriptions ( + sub_id uuid unique not null, + user_id bigint not null + references public.users (user_id) + on update cascade + on delete cascade, + sub_info bytea unique not null, + primary key (user_id, sub_id) +); + +drop table if exists public.notifications; +create table public.notifications ( + id uuid not null unique, + user_id bigint not null + references public.users (user_id) + on update cascade + on delete cascade, + type smallint not null, + created timestamptz not null generated always as ('now'::timestamptz) stored, + data bytea not null, + primary key (user_id, id) +); + +drop function public.register_subscription; +create or replace function public.register_subscription(sid uuid, uid bigint, sinfo bytea) returns void as +$$ +begin + + update public.subscriptions + set sub_id = sid, + sub_info = sinfo + where subscriptions.user_id = uid + and ( + subscriptions.sub_id = sid + or subscriptions.sub_info = sinfo + ); + + if found then + return; + end if; + + insert into public.subscriptions + (sub_id, user_id, sub_info) + values + (sid, uid, sinfo); + +end; +$$ +language plpgsql; + +commit; diff --git a/db/11/00-create-user-configs-table.sql b/db/11/00-create-user-configs-table.sql new file mode 100644 index 0000000..2436edc --- /dev/null +++ b/db/11/00-create-user-configs-table.sql @@ -0,0 +1,17 @@ +begin; + +alter table public.configs drop column value; + +drop table if exists public.user_configs; +create table public.user_configs ( + user_id bigint not null + references public.users (user_id) + on update cascade + on delete cascade, + key text not null, + type smallint not null, + data bytea not null, + primary key (user_id, key) +); + +commit; diff --git a/emojis/repository.py b/emojis/repository.py index a34c774..ba54ccf 100644 --- a/emojis/repository.py +++ b/emojis/repository.py @@ -11,7 +11,7 @@ from shared.models import PostId from shared.models._shared import InternalUser, UserPortable from shared.sql import SqlInterface -from users.repository import Users +from users.repository import Repository as Users from .models import Emoji, InternalEmoji diff --git a/emojis/router.py b/emojis/router.py index d8d70f6..cc0ed48 100644 --- a/emojis/router.py +++ b/emojis/router.py @@ -3,7 +3,7 @@ from fastapi import APIRouter from shared.models.auth import Scope -from shared.server import Request +from shared.models.server import Request from shared.timing import timed from .emoji import Emojis diff --git a/init.py b/init.py index 2c5580e..76338b4 100644 --- a/init.py +++ b/init.py @@ -3,7 +3,6 @@ import re import shutil import time -from dataclasses import dataclass from os import environ, listdir, remove from os.path import isdir, isfile, join from secrets import token_bytes @@ -25,6 +24,7 @@ from shared.config.credentials import decryptCredentialFile, fetch from shared.datetime import datetime from shared.logging import TerminalAgent +from shared.models.encryption import Keys from shared.sql import SqlInterface @@ -91,12 +91,17 @@ def nukeCache() -> None : is_flag=True, default=False, ) +@click.option( + '-l', + '--lock', + default=None, +) @click.option( '-f', '--file', default='', ) -async def execSql(unlock: bool = False, file: str = '') -> None : +async def execSql(unlock: bool = False, file: str = '', lock: Optional[int] = None) -> None : """ connects to the database and runs all files stored under the db folder folders under db are sorted numberically and run in descending order @@ -112,9 +117,14 @@ async def execSql(unlock: bool = False, file: str = '') -> None : async with sql.pool.connection() as conn : async with conn.cursor() as cur : sqllock = None - if not unlock and isfile('sql.lock') : + + if lock is not None : + sqllock = int(lock) + + if not unlock and sqllock is None and isfile('sql.lock') : sqllock = int(open('sql.lock').read().strip()) - click.echo(f'==> sql.lock: {sqllock}') + + click.echo(f'==> sql.lock: {sqllock}') if file : if not isfile(file) : @@ -319,45 +329,24 @@ async def updatePassword() -> LoginRequest : return acct -@dataclass -class Keys : - aes: AESGCM - ed25519: Ed25519PrivateKey - associated_data: bytes - - def encrypt(self, data: bytes) -> bytes : - nonce = token_bytes(12) - return b'.'.join(map(b64encode, [nonce, self.aes.encrypt(nonce, data, self.associated_data), self.ed25519.sign(data)])) - - def _generate_keys() -> Keys : + keys = Keys.generate() + if isfile('credentials/aes.key') : remove('credentials/aes.key') if isfile('credentials/ed25519.pub') : remove('credentials/ed25519.pub') - aesbytes = AESGCM.generate_key(256) - aeskey = AESGCM(aesbytes) - ed25519priv = Ed25519PrivateKey.generate() + data = keys.dump() with open('credentials/aes.key', 'wb') as file : - file.write(b'.'.join(map(b64encode, [aesbytes, ed25519priv.sign(aesbytes)]))) + file.write(data['aes'].encode()) - pub = ed25519priv.public_key().public_bytes( - encoding=serialization.Encoding.DER, - format=serialization.PublicFormat.SubjectPublicKeyInfo, - ) with open('credentials/ed25519.pub', 'wb') as file : - nonce = token_bytes(12) - aeskey.encrypt - file.write(b'.'.join(map(b64encode, [nonce, aeskey.encrypt(nonce, pub, aesbytes), ed25519priv.sign(pub)]))) - - return Keys( - aes=aeskey, - ed25519=ed25519priv, - associated_data=pub, - ) + file.write(data['pub'].encode()) + + return keys def writeAesFile(file: BinaryIO, contents: bytes) : diff --git a/notifications/__init__.py b/notifications/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/notifications/models.py b/notifications/models.py new file mode 100644 index 0000000..17a6c7f --- /dev/null +++ b/notifications/models.py @@ -0,0 +1,143 @@ +from datetime import datetime +from enum import Enum, IntEnum +from typing import Literal, Optional, Self +from uuid import UUID + +from pydantic import BaseModel, Field, validator + +from posts.models import Post +from shared.models import UserPortable +from shared.models.config import Store +from shared.sql.query import Table + + +class ServerKey(BaseModel) : + application_server_key: str + + +class SubscriptionInfo(Store) : + endpoint: str + expirationTime: Optional[int] + keys: dict[str, str] + + @classmethod + def type_(cls) -> Enum : + raise NotImplementedError + + +class Subscription(BaseModel) : + __table_name__ = Table('kheina.public.subscriptions') + + sub_id: UUID = Field(description='orm:"pk"') + """ + sub_id refers to the guid from an auth token. this way, on log out or expiration, a subscription can be removed from the database proactively + """ + + user_id: int = Field(description='orm:"pk"') + subscription_info: bytes = Field(description='orm:"col[sub_info]"') + + +class NotificationType(IntEnum) : + post = 0 + user = 1 + interact = 2 + + +class InternalNotification(BaseModel) : + __table_name__ = Table('kheina.public.notifications') + + id: UUID = Field(description='orm:"pk"') + user_id: int = Field(description='orm:"pk"') + type_: int = Field(description='orm:"col[type]"') + created: datetime = Field(description='orm:"gen"') + data: bytes + + @validator('type_') + def isValidType(cls, value) : + if value not in NotificationType.__members__.values() : + raise KeyError('notification type must exist in the notification enum') + + return value + + def type(self: Self) -> NotificationType : + return NotificationType(self.type_) + + +class Notification(BaseModel) : + type: str + event: Enum + + +class InteractNotificationEvent(Enum) : + favorite = 'favorite' + reply = 'reply' + repost = 'repost' + + +class InternalInteractNotification(Store) : + """ + an interact notification represents a user taking an action on a post + """ + event: InteractNotificationEvent + post_id: int + user_id: int + + @classmethod + def type_(cls) -> NotificationType : + return NotificationType.interact + + +class InteractNotification(Notification) : + """ + an interact notification represents a user taking an action on a post + """ + type: Literal['interact'] = 'interact' + event: InteractNotificationEvent + created: datetime + user: UserPortable + post: Post + + +class PostNotificationEvent(Enum) : + mention = 'mention' + tagged = 'tagged' + + +class InternalPostNotification(Store) : + event: PostNotificationEvent + post_id: int + + @classmethod + def type_(cls) -> NotificationType : + return NotificationType.post + + +class PostNotification(Notification) : + type: Literal['post'] = 'post' + event: PostNotificationEvent + created: datetime + post: Post + + +class UserNotificationEvent(Enum) : + follow = 'follow' + + +class InternalUserNotification(Store) : + event: UserNotificationEvent + user_id: int + """ + this is the user that performed the action, NOT the user being notified, + that is stored in the larger InternalNotification object + """ + + @classmethod + def type_(cls) -> NotificationType : + return NotificationType.user + + +class UserNotification(Notification) : + type: Literal['user'] = 'user' + event: UserNotificationEvent + created: datetime + user: UserPortable diff --git a/notifications/repository.py b/notifications/repository.py new file mode 100644 index 0000000..ec80283 --- /dev/null +++ b/notifications/repository.py @@ -0,0 +1,286 @@ +from typing import Self +from urllib.parse import urlparse +from uuid import UUID + +import ujson +from aiohttp import ClientResponse, ClientSession, ClientTimeout +from cache import AsyncLRU +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePublicKey +from py_vapid import Vapid02 +from pydantic import BaseModel +from pywebpush import WebPusher as _WebPusher + +from configs.models import Config +from posts.models import Post +from shared.auth import KhUser +from shared.base64 import b64encode +from shared.caching import AerospikeCache +from shared.caching.key_value_store import KeyValueStore +from shared.config.credentials import fetch +from shared.datetime import datetime +from shared.exceptions.http_error import HttpErrorHandler, InternalServerError +from shared.models import UserPortable +from shared.models.encryption import Keys +from shared.sql import SqlInterface +from shared.sql.query import Field, Operator, Value, Where +from shared.timing import timed +from shared.utilities import uuid7 +from shared.utilities.json import json_stream + +from .models import InteractNotification, InternalInteractNotification, InternalNotification, InternalPostNotification, InternalUserNotification, NotificationType, PostNotification, ServerKey, Subscription, SubscriptionInfo, UserNotification + + +class WebPusher(_WebPusher) : + @timed + async def send_async(self, *args, **kwargs) -> ClientResponse | str : + # this is pretty much copied as-is, but with a couple changes to fix issues + timeout = ClientTimeout(kwargs.pop("timeout", 10000)) + curl = kwargs.pop("curl", False) + + params = self._prepare_send_data(*args, **kwargs) + endpoint = params.pop("endpoint") + + if curl : + encoded_data = params["data"] + headers = params["headers"] + return self.as_curl(endpoint, encoded_data=encoded_data, headers=headers) + + if self.aiohttp_session : + resp = await self.aiohttp_session.post(endpoint, timeout=timeout, **params) + + else: + async with ClientSession() as session : + resp = await session.post(endpoint, timeout=timeout, **params) + + return resp + + +kvs: KeyValueStore = KeyValueStore('kheina', 'notifications') + + +class Notifier(SqlInterface) : + + keys: Keys + subInfoFingerprint: bytes + serializerTypeMap: dict[NotificationType, type[BaseModel]] = { + NotificationType.post: InternalPostNotification, + NotificationType.user: InternalUserNotification, + NotificationType.interact: InternalInteractNotification, + } + + async def startup(self) -> None : + if getattr(Notifier, 'keys', None) is None : + Notifier.keys = Keys.load(**fetch('notifications', dict[str, str])) + + + @AerospikeCache('kheina', 'notifications', 'vapid-config', _kvs=kvs) + async def getVapidPem(self: Self) -> bytes : + async with self.transaction() as t : + vapid = Vapid02() + vapid_config = Config( + key = 'vapid-config', + created = datetime.zero(), + updated = datetime.zero(), + updated_by = 0, + bytes_ = None, + ) + + try : + vapid_config = await t.select(vapid_config) + assert vapid_config.bytes_ + return Notifier.keys.decrypt(vapid_config.bytes_) + + except KeyError : + pass + + vapid.generate_keys() + vapid_config.bytes_ = Notifier.keys.encrypt(vapid.private_pem()) + await t.insert(vapid_config) + return vapid.private_pem() + + + async def getVapid(self: Self) -> Vapid02 : + pk_pem = await self.getVapidPem() + return Vapid02.from_pem(pk_pem) + + + async def getApplicationServerKey(self: Self) -> ServerKey : + vapid = await self.getVapid() + pub = vapid.public_key + assert isinstance(pub, EllipticCurvePublicKey) + return ServerKey( + application_server_key = b64encode(pub.public_bytes( + serialization.Encoding.X962, + serialization.PublicFormat.UncompressedPoint, + )).decode(), + ) + + + @HttpErrorHandler('registering subscription info', exclusions=['self', 'sub_info']) + async def registerSubInfo(self: Self, user: KhUser, sub_info: SubscriptionInfo) -> None : + assert user.token, 'this should always be populated when the user is authenticated' + data: bytes = await sub_info.serialize() + await self.query_async(""" + select kheina.public.register_subscription(%s::uuid, %s, %s); + """, ( + user.token.guid, + user.user_id, + Notifier.keys.encrypt(data), + ), + commit = True, + ) + await kvs.remove_async(f'sub_info={user.user_id}') + + + @timed + async def unregisterSubInfo(self: Self, user_id: int, sub_ids: list[UUID]) -> None : + await self.query_async(""" + delete from kheina.public.subscriptions + where subscriptions.sub_id = any(%s); + """, ( + sub_ids, + ), + commit = True, + ) + await kvs.remove_async(f'sub_info={user_id}') + + + @timed + @AerospikeCache('kheina', 'notifications', 'sub_info={user_id}', _kvs=kvs) + async def getSubInfo(self: Self, user_id: int) -> dict[UUID, SubscriptionInfo] : + sub_info: dict[UUID, SubscriptionInfo] = { } + subs: list[Subscription] = await self.where(Subscription, Where( + Field('subscriptions', 'user_id'), + Operator.equal, + Value(user_id), + )) + + for s in subs : + sub = Notifier.keys.decrypt(s.subscription_info) + sub_info[s.sub_id] = await SubscriptionInfo.deserialize(sub) + + return sub_info + + + async def vapidHeaders(self: Self, sub_info: SubscriptionInfo) -> dict[str, str] : + url = urlparse(sub_info.endpoint) + claim = { + 'sub': 'mailto:help@kheina.com', + 'aud': f'{url.scheme}://{url.netloc}', + 'exp': int(datetime.now().timestamp()) + 1440, # 1 hour, I guess? + } + vapid = await self.getVapid() + return vapid.sign(claim) + + + @timed + async def _send(self: Self, user_id: int, data: dict) -> int : + subs = await self.getSubInfo(user_id) + unregister: list[UUID] = [] + successes: int = 0 + for sub_id, sub in subs.items() : + res = await WebPusher( + sub.dict(), + verbose = True, + ).send_async( + data = ujson.dumps(json_stream(data)), + headers = await self.vapidHeaders(sub), + content_encoding = "aes128gcm", + ) + + if not isinstance(res, ClientResponse) : + raise TypeError(f'expected response to be ClientResponse, got {type(res)}') + + if res.status < 300 : + successes += 1 + + elif res.status == 410 : + unregister.append(sub_id) + + else : + raise InternalServerError('unexpected error occurred while sending notification', status=res.status, content=await res.text()) + + if unregister : + await self.unregisterSubInfo(user_id, unregister) + + self.logger.debug({ + 'message': 'sent notification', + 'successes': successes, + 'failures': len(unregister), + 'to': user_id, + 'notification': data, + }) + + return successes + + + @timed + async def sendNotification( + self: Self, + user_id: int, + data: InternalInteractNotification | InternalPostNotification | InternalUserNotification, + **kwargs: UserPortable | Post, + ) -> int : + """ + creates, persists and then sends the given notification to the provided user_id. + kwargs must include the user and/or post of the notification's user_id/post_id in the form of + ``` + await sendNotification(..., user=UserPortable(...), post=Post(...)) + ``` + """ + type_: NotificationType = data.type_() + inotification = await self.insert(InternalNotification( + id = uuid7(), + user_id = user_id, + type_ = type_, + created = datetime.zero(), + data = await data.serialize(), + )) + + self.logger.debug({ + 'message': 'notification', + 'to': user_id, + 'notification': { + 'type': type(data), + 'type_enm': data.type_(), + **data.dict(), + }, + }) + + match data : + case InternalInteractNotification() : + user, post = kwargs.get('user'), kwargs.get('post') + assert isinstance(user, UserPortable) and isinstance(post, Post) + notification = InteractNotification( + event = data.event, + created = inotification.created, + user = user, + post = post, + ) + return await self._send(user_id, notification.dict()) + + case InternalPostNotification() : + post = kwargs.get('post') + assert isinstance(post, Post) + notification = PostNotification( + event = data.event, + created = inotification.created, + post = post, + ) + return await self._send(user_id, notification.dict()) + + case InternalUserNotification() : + user = kwargs.get('user') + assert isinstance(user, UserPortable) + notification = UserNotification( + event = data.event, + created = inotification.created, + user = user, + ) + return await self._send(user_id, notification.dict()) + + + @HttpErrorHandler('sending some random cunt a notif') + async def debugSendNotification(self: Self, user_id: int, data: dict) -> int : + return await self._send(user_id, data) diff --git a/notifications/router.py b/notifications/router.py new file mode 100644 index 0000000..4bc77f1 --- /dev/null +++ b/notifications/router.py @@ -0,0 +1,50 @@ +from fastapi import APIRouter + +from shared.models.auth import Scope +from shared.models.server import Request +from shared.timing import timed + +from .models import ServerKey, SubscriptionInfo +from .repository import Notifier + + +repo = Notifier() +notificationsRouter = APIRouter( + prefix='/notifications', +) + + +@notificationsRouter.on_event('startup') +async def startup() -> None : + await repo.startup() + + +@notificationsRouter.get('/register', response_model=ServerKey) +@timed.root +async def v1GetServerKey(req: Request) -> ServerKey : + """ + only auth required + """ + await req.user.authenticated() + return await repo.getApplicationServerKey() + + +@notificationsRouter.put('/register', response_model=None) +@timed.root +async def v1RegisterNotificationTarget(req: Request, body: SubscriptionInfo) -> None : + await req.user.authenticated() + await repo.registerSubInfo(req.user, body) + + +@notificationsRouter.post('', status_code=201) +async def v1SendThisBitchAVibe(req: Request, body: dict) -> int : + await req.user.verify_scope(Scope.admin) + return await repo.debugSendNotification(req.user.user_id, body) + + +app = APIRouter( + prefix='/v1', + tags=['notifications'], +) + +app.include_router(notificationsRouter) diff --git a/posts/blocking.py b/posts/blocking.py index f35196f..184e8df 100644 --- a/posts/blocking.py +++ b/posts/blocking.py @@ -1,7 +1,7 @@ -from typing import Dict, Iterable, Self, Set, Tuple +from typing import Iterable, Optional, Self from configs.configs import Configs -from configs.models import UserConfig +from configs.models import Blocking from shared.auth import KhUser from shared.caching import ArgsCache from shared.models import InternalUser @@ -29,23 +29,29 @@ def dict(self: Self) : def __init__(self: 'BlockTree') : - self.tags: Set[str] = set() - self.match: Dict[str, BlockTree] = { } - self.nomatch: Dict[str, BlockTree] = { } + self.tags: set[str | int] = set() + self.match: dict[str | int, BlockTree] = { } + self.nomatch: dict[str | int, BlockTree] = { } - def populate(self: Self, tags: Iterable[Iterable[str]]) : + def populate(self: Self, tags: Iterable[Iterable[str | int]]) : for tag_set in tags : tree: BlockTree = self for tag in tag_set : match = True - if tag.startswith('-') : - match = False - tag = tag[1:] + if isinstance(tag, str) : + if tag.startswith('-') : + match = False + tag = tag[1:] - branch: Dict[str, BlockTree] + elif isinstance(tag, int) : + if tag < 0 : + match = False + tag *= -1 + + branch: dict[str | int, BlockTree] if match : if not tree.match : @@ -65,7 +71,7 @@ def populate(self: Self, tags: Iterable[Iterable[str]]) : tree = branch[tag] - def blocked(self: Self, tags: Iterable[str]) -> bool : + def blocked(self: Self, tags: Iterable[str | int]) -> bool : if not self.match and not self.nomatch : return False @@ -93,27 +99,26 @@ def _blocked(self: Self, tree: 'BlockTree') -> bool : @ArgsCache(30) -async def fetch_block_tree(user: KhUser) -> Tuple[BlockTree, UserConfig] : +async def fetch_block_tree(user: KhUser) -> tuple[BlockTree, Optional[set[int]]] : tree: BlockTree = BlockTree() if not user.token : # TODO: create and return a default config - return tree, UserConfig() + return tree, None - # TODO: return underlying UserConfig here, once internal tokens are implemented - user_config: UserConfig = await configs._getUserConfig(user.user_id) - tree.populate(user_config.blocked_tags or []) - return tree, user_config + config: Blocking = await configs._getUserConfig(user.user_id, Blocking) + tree.populate(config.tags or []) + return tree, set(config.users) if config.users else None @timed async def is_post_blocked(user: KhUser, uploader: InternalUser, tags: Iterable[str]) -> bool : - block_tree, user_config = await fetch_block_tree(user) + block_tree, blocked_users = await fetch_block_tree(user) - if user_config.blocked_users and uploader.user_id in user_config.blocked_users : + if blocked_users and uploader.user_id in blocked_users : return True - tags: Set[str] = set(tags) - tags.add('@' + uploader.handle) # TODO: user ids need to be added here instead of just handle, once changeable handles are added + tags: set[str | int] = set(tags) # TODO: convert handles to user_ids (int) + tags.add(uploader.user_id) return block_tree.blocked(tags) diff --git a/posts/models.py b/posts/models.py index 0abd2be..460b64f 100644 --- a/posts/models.py +++ b/posts/models.py @@ -158,7 +158,7 @@ class Post(OmitModel) : post_id: PostId title: Optional[str] description: Optional[str] - user: UserPortable + user: Optional[UserPortable] score: Optional[Score] rating: Rating parent_id: Optional[PostId] @@ -169,6 +169,7 @@ class Post(OmitModel) : media: Optional[Media] tags: Optional[TagGroups] blocked: bool + locked: bool = False replies: Optional[list['Post']] = None """ None implies "not retrieved" whereas [] means no replies found @@ -273,7 +274,7 @@ class InternalScore(BaseModel) : total: int -######################### uploader things +################################################## uploader things ################################################## class UpdateRequest(BaseModel) : diff --git a/posts/posts.py b/posts/posts.py index 0224680..0bd15b5 100644 --- a/posts/posts.py +++ b/posts/posts.py @@ -4,7 +4,7 @@ from typing import Iterable, Optional, Self from sets.models import InternalSet, SetId -from sets.repository import Sets +from sets.repository import Repository as Sets from shared.auth import KhUser from shared.caching import AerospikeCache, ArgsCache from shared.datetime import datetime @@ -13,13 +13,13 @@ from shared.timing import timed from .models import InternalPost, Post, PostId, PostSort, Privacy, Rating, Score, SearchResults -from .repository import PostKVS, Posts, privacy_map, rating_map, users # type: ignore +from .repository import PostKVS, Repository, privacy_map, rating_map, users # type: ignore sets = Sets() -class Posts(Posts) : +class Posts(Repository) : @staticmethod def _normalize_tag(tag: str) : @@ -544,6 +544,8 @@ async def _fetch_posts(self: Self, sort: PostSort, tags: Optional[tuple[str, ... order = [(Field('set_post', 'index'), order)], alias = 'order', ), + ).group( + Field('set_post', 'index'), ).order( Field('set_post', 'index'), order, @@ -656,11 +658,6 @@ async def _getComments(self: Self, post_id: PostId, sort: PostSort, count: int, Operator.equal, Value(await privacy_map.get_id(Privacy.public)), ), - Where( - Field('posts', 'locked'), - Operator.equal, - Value(False), - ), ).union( Query( Table('kheina.public.posts'), @@ -679,11 +676,6 @@ async def _getComments(self: Self, post_id: PostId, sort: PostSort, count: int, Operator.equal, Value(await privacy_map.get_id(Privacy.public)), ), - Where( - Field('posts', 'locked'), - Operator.equal, - Value(False), - ), ), ), recursive = True, diff --git a/posts/repository.py b/posts/repository.py index f411e80..095e3c5 100644 --- a/posts/repository.py +++ b/posts/repository.py @@ -1,6 +1,7 @@ from asyncio import Task, ensure_future from collections import defaultdict from dataclasses import dataclass +from functools import partial from typing import Callable, Mapping, Optional, Self, Tuple, Union from cache import AsyncLRU @@ -16,8 +17,9 @@ from shared.sql.query import CTE, Field, Join, JoinType, Operator, Order, Query, Table, Value, Where from shared.timing import timed from tags.models import InternalTag, Tag, TagGroup -from tags.repository import TagKVS, Tags -from users.repository import Users +from tags.repository import Repository as Tags +from tags.repository import TagKVS +from users.repository import Repository as Users from .blocking import is_post_blocked from .models import InternalPost, InternalScore, Media, MediaFlag, MediaType, Post, PostId, PostSize, Privacy, Rating, Score, Thumbnail @@ -116,7 +118,7 @@ class UserCombined: internal: InternalUser -class Posts(SqlInterface) : +class Repository(SqlInterface) : def parse_response( self: Self, @@ -148,20 +150,6 @@ def parse_response( posts: list[InternalPost] = [] for row in data : - # media: Optional[InternalMedia] = None - # if row[7] and row[8] and row[16] : - # media = InternalMedia( - # post_id = row[0], - # filename = row[7], - # type = row[8], - # crc = row[15], - # updated = row[16], - # size = PostSize( - # width = row[9], - # height = row[10], - # ) if row[9] and row[10] else None, - # ) - post = InternalPost( post_id = row[0], title = row[1], @@ -176,15 +164,14 @@ def parse_response( width = row[9], height = row[10], ) if row[9] and row[10] else None, - user_id = row[11], - privacy = row[12], - thumbhash = row[13], - locked = row[14], - crc = row[15], - media_updated = row[16], - content_length = row[17], - thumbnails = row[18], # type: ignore - + user_id = row[11], + privacy = row[12], + thumbhash = row[13], + locked = row[14], + crc = row[15], + media_updated = row[16], + content_length = row[17], + thumbnails = row[18], # type: ignore include_in_results = row[19], ) posts.append(post) @@ -407,9 +394,29 @@ async def post(self: Self, user: KhUser, ipost: InternalPost) -> Post : upl_portable: Task[UserPortable] = ensure_future(users.portable(user, uploader)) itags: list[InternalTag] = await tags_task tags: Task[list[Tag]] = ensure_future(tagger.tags(user, itags)) - blocked: Task[bool] = ensure_future(is_post_blocked(user, uploader, [t.name for t in itags])) + blocked: Task[bool] = ensure_future(is_post_blocked(user, uploader, (t.name for t in itags))) + + post = Post( + post_id = post_id, + title = ipost.title, + description = ipost.description, + user = await upl_portable, + score = await score, + rating = await rating_map.get(ipost.rating), + parent = await parent, + parent_id = PostId(ipost.parent) if ipost.parent else None, + privacy = await privacy_map.get(ipost.privacy), + created = ipost.created, + updated = ipost.updated, + media = None, + tags = tagger.groups(await tags), + blocked = await blocked, + replies = None, + ) + + if ipost.locked : + post.locked = True - media: Optional[Media] = None if ipost.filename and ipost.media_type and ipost.size and ipost.content_length and ipost.thumbnails : flags: list[MediaFlag] = [] @@ -417,7 +424,7 @@ async def post(self: Self, user: KhUser, ipost: InternalPost) -> Post : if itag.group == TagGroup.system : flags.append(MediaFlag[itag.name]) - media = Media( + post.media = Media( post_id = PostId(ipost.post_id), crc = ipost.crc, filename = ipost.filename, @@ -443,23 +450,7 @@ async def post(self: Self, user: KhUser, ipost: InternalPost) -> Post : ], ) - return Post( - post_id = post_id, - title = ipost.title, - description = ipost.description, - user = await upl_portable, - score = await score, - rating = await rating_map.get(ipost.rating), - parent = await parent, - parent_id = PostId(ipost.parent) if ipost.parent else None, - privacy = await privacy_map.get(ipost.privacy), - created = ipost.created, - updated = ipost.updated, - media = media, - tags = tagger.groups(await tags), - blocked = await blocked, - replies = None, - ) + return post @timed @@ -493,19 +484,21 @@ async def scores_many(self: Self, post_ids: list[PostId]) -> dict[PostId, Option return { } cached = await ScoreKVS.get_many_async(post_ids) + found: dict[PostId, Optional[InternalScore]] = { } misses: list[PostId] = [] - for k, v in list(cached.items()) : - if v is not Undefined : + for k, v in cached.items() : + if v is None or isinstance(v, InternalScore) : + found[k] = v continue misses.append(k) - cached[k] = None + found[k] = None if not misses : - return cached + return found - scores: dict[PostId, Optional[InternalScore]] = cached + scores: dict[PostId, Optional[InternalScore]] = found data: list[Tuple[int, int, int]] = await self.query_async(""" SELECT post_scores.post_id, @@ -519,9 +512,6 @@ async def scores_many(self: Self, post_ids: list[PostId]) -> dict[PostId, Option fetch_all = True, ) - if not data : - return scores - for post_id, up, down in data : post_id = PostId(post_id) score: InternalScore = InternalScore( @@ -530,7 +520,9 @@ async def scores_many(self: Self, post_ids: list[PostId]) -> dict[PostId, Option total = up + down, ) scores[post_id] = score - ensure_future(ScoreKVS.put_async(post_id, score)) + + for k, v in scores.items() : + ensure_future(ScoreKVS.put_async(k, v)) return scores @@ -566,19 +558,20 @@ async def votes_many(self: Self, user_id: int, post_ids: list[PostId]) -> dict[P PostId(k[k.rfind('|') + 1:]): v for k, v in (await VoteKVS.get_many_async([f'{user_id}|{post_id}' for post_id in post_ids])).items() } + found: dict[PostId, int] = { } misses: list[PostId] = [] - for k, v in list(cached.items()) : - if v is not Undefined : + for k, v in cached.items() : + if isinstance(v, int) : + found[k] = v continue misses.append(k) - cached[k] = None if not misses : - return cached + return found - votes: dict[PostId, int] = cached + votes: dict[PostId, int] = found data: list[Tuple[int, int]] = await self.query_async(""" SELECT post_votes.post_id, @@ -593,13 +586,12 @@ async def votes_many(self: Self, user_id: int, post_ids: list[PostId]) -> dict[P fetch_all = True, ) - if not data : - return votes - for post_id, upvote in data : post_id = PostId(post_id) vote: int = 1 if upvote else -1 votes[post_id] = vote + + for post_id, vote in votes.items() : ensure_future(VoteKVS.put_async(f'{user_id}|{post_id}', vote)) return votes @@ -668,6 +660,11 @@ def _validateVote(self: Self, vote: Optional[bool]) -> None : @timed async def _vote(self: Self, user: KhUser, post_id: PostId, upvote: Optional[bool]) -> Score : + ipost: InternalPost = await self._get_post(post_id) + + if ipost.locked : + raise BadRequest('cannot vote on a post that has been locked', post=ipost) + self._validateVote(upvote) async with self.transaction() as transaction : await transaction.query_async(""" @@ -838,19 +835,20 @@ async def _tags_many(self: Self, post_ids: list[PostId]) -> dict[PostId, list[In PostId(k[k.rfind('.') + 1:]): v for k, v in (await VoteKVS.get_many_async([f'post.{post_id}' for post_id in post_ids])).items() } + found: dict[PostId, list[InternalTag]] = { } misses: list[PostId] = [] - for k, v in list(cached.items()) : - if v is not Undefined : + for k, v in cached.items() : + if isinstance(v, list) and all(map(lambda x : isinstance(x, InternalTag), v)) : + found[k] = v continue misses.append(k) - del cached[k] if not misses : - return cached + return found - tags: dict[PostId, list[InternalTag]] = defaultdict(list, cached) + tags: dict[PostId, list[InternalTag]] = defaultdict(list, found) data: list[tuple[int, str, str, bool, Optional[int]]] = await self.query_async(""" SELECT tag_post.post_id, @@ -928,9 +926,41 @@ async def posts(self: Self, user: KhUser, iposts: list[InternalPost], assign_par if itag.name in MediaFlag.__members__ : flags.append(MediaFlag[itag.name]) - media: Optional[Media] = None + post = all_posts[post_id] = Post( + post_id = post_id, + title = None, + description = None, + user = None, + score = scores[post_id], + rating = await rating_map.get(ipost.rating), + privacy = await privacy_map.get(ipost.privacy), + media = None, + created = ipost.created, + updated = ipost.updated, + parent_id = parent_id, + + # only the first call retrieves blocked info, all the rest should be cached and not actually await + blocked = await is_post_blocked(user, uploaders[ipost.user_id].internal, tag_names), + tags = tagger.groups(post_tags) + ) + + if not assign_parents : + # this way, when assign_parents = true, post.replies can be omitted by being unassigned + post.replies = [] + + if ipost.include_in_results : + posts.append(post) + + if ipost.locked : + post.locked = True + continue # we don't want any other fields populated + + post.title = ipost.title + post.description = ipost.description + post.user = uploaders[ipost.user_id].portable + if ipost.filename and ipost.media_type and ipost.size and ipost.content_length and ipost.thumbnails : - media = Media( + post.media = Media( post_id = post_id, crc = ipost.crc, filename = ipost.filename, @@ -956,31 +986,6 @@ async def posts(self: Self, user: KhUser, iposts: list[InternalPost], assign_par ], ) - post = all_posts[post_id] = Post( - post_id = post_id, - title = ipost.title, - description = ipost.description, - user = uploaders[ipost.user_id].portable, - score = scores[post_id], - rating = await rating_map.get(ipost.rating), - privacy = await privacy_map.get(ipost.privacy), - media = media, - created = ipost.created, - updated = ipost.updated, - parent_id = parent_id, - - # only the first call retrieves blocked info, all the rest should be cached and not actually await - blocked = await is_post_blocked(user, uploaders[ipost.user_id].internal, tag_names), - tags = tagger.groups(post_tags) - ) - - if not assign_parents : - # this way, when assign_parents = true, post.replies can be omitted by being unassigned - post.replies = [] - - if ipost.include_in_results : - posts.append(post) - if assign_parents : for post_id, parent in parents.items() : if parent not in all_posts : diff --git a/posts/router.py b/posts/router.py index f735c2c..cefcabc 100644 --- a/posts/router.py +++ b/posts/router.py @@ -1,17 +1,17 @@ from asyncio import ensure_future from html import escape -from typing import Literal, Optional, Union +from typing import Literal, Optional from uuid import uuid4 import aiofiles -from fastapi import APIRouter, File, Form, UploadFile +from fastapi import APIRouter, File, Form, Response, UploadFile from shared.backblaze import B2Interface from shared.config.constants import Environment, environment from shared.exceptions.http_error import UnprocessableDetail, UnprocessableEntity from shared.models import Privacy, convert_path_post_id from shared.models.auth import Scope -from shared.server import Request, Response +from shared.models.server import Request from shared.timing import timed from shared.utilities.units import Byte from users.users import Users @@ -302,7 +302,7 @@ async def v1Rss(req: Request) -> Response : mime_type = post.media.type.mime_type, length = post.media.length, ) if post.media else '', - ) for post in timeline + ) for post in timeline if post.user ]), ), ) diff --git a/posts/uploader.py b/posts/uploader.py index 151bdc6..9d9f46d 100644 --- a/posts/uploader.py +++ b/posts/uploader.py @@ -1,6 +1,7 @@ import json from asyncio import Task, create_subprocess_exec, ensure_future from enum import Enum +from hashlib import sha1 from io import BytesIO from os import path, remove from secrets import token_bytes @@ -16,11 +17,11 @@ from wand import resource from wand.image import Image +from notifications.repository import Notifier from shared.auth import KhUser, Scope from shared.backblaze import B2Interface, MimeType -from shared.base64 import b64decode +from shared.base64 import b64decode, b64encode from shared.caching.key_value_store import KeyValueStore -from shared.crc import CRC from shared.datetime import datetime from shared.exceptions.http_error import BadGateway, BadRequest, Forbidden, HttpErrorHandler, InternalServerError, NotFound from shared.models import InternalUser @@ -29,11 +30,15 @@ from shared.utilities import flatten, int_from_bytes from shared.utilities.units import Byte from tags.models import InternalTag -from tags.repository import CountKVS, Tags -from users.repository import UserKVS, Users +from tags.repository import CountKVS +from tags.repository import Repository as Tags +from users.repository import Repository as Users +from users.repository import UserKVS from .models import Coordinates, InternalPost, Media, MediaFlag, Post, PostId, PostSize, Privacy, Rating, Thumbnail -from .repository import PostKVS, Posts, VoteKVS, media_type_map, privacy_map, rating_map +from .repository import PostKVS +from .repository import Repository as Posts +from .repository import VoteKVS, media_type_map, privacy_map, rating_map from .scoring import confidence from .scoring import controversial as calc_cont from .scoring import hot as calc_hot @@ -44,27 +49,59 @@ resource.limits.set_resource_limit('disk', Byte.gigabyte.value * 100) UnpublishedPrivacies: Set[Privacy] = { Privacy.unpublished, Privacy.draft } -posts = Posts() -users = Users() -tagger = Tags() -_crc = CRC(32) +posts: Posts = Posts() +users: Users = Users() +tagger: Tags = Tags() +notifier: Notifier = Notifier() @timed def crc(value: bytes) -> int : - return _crc(value) + # return int.from_bytes(sha1(value).digest()[:8], signed=True) + return int.from_bytes(sha1(value).digest()[:4]) + + +@timed +async def extract_frame(file_on_disk: str, filename: str) -> str : + await FFmpeg().input( + file_on_disk, + accurate_seek = None, + ss = '0', + ).output( + (screenshot := f'images/{uuid4().hex}_{filename}.webp'), + { 'frames:v': '1' }, + ).execute() + return screenshot + + +@timed +async def validate_image(file_on_disk: str) -> None : + with Image(file=open(file_on_disk, 'rb')) : + pass + + +@timed +async def validate_video(file_on_disk: str) -> None : + # ffmpeg -v error -i file.avi -f null - + await FFmpeg().input( + file_on_disk, + v = 'error', + ).output( + '-', + f = 'null', + ).execute() class Uploader(SqlInterface, B2Interface) : def __init__(self: Self) -> None : + B2Interface.__init__(self, max_retries=5) SqlInterface.__init__( self, conversions={ Enum: lambda x: x.name, }, ) - B2Interface.__init__(self, max_retries=5) self.thumbnail_sizes: list[int] = [ 1200, 800, @@ -195,7 +232,13 @@ async def createPost(self: Self, user: KhUser) -> Post : for _ in range(100) : post_id = PostId.generate() - data = await t.query_async("SELECT count(1) FROM kheina.public.posts WHERE post_id = %s;", (post_id.int(),), fetch_one=True) + data = await t.query_async(""" + SELECT count(1) FROM kheina.public.posts WHERE post_id = %s; + """, ( + post_id.int(), + ), + fetch_one = True, + ) if not data[0] : break @@ -223,7 +266,7 @@ async def createPost(self: Self, user: KhUser) -> Post : user.user_id, user.user_id, ), - fetch_one=True, + fetch_one = True, ) post_id = PostId(data[0]) @@ -274,6 +317,7 @@ async def createPostWithFields( internal_post_id: int post_id: PostId + notify: bool = False async with self.transaction() as transaction : for _ in range(100) : @@ -297,13 +341,18 @@ async def createPostWithFields( post = await transaction.insert(post) if privacy : - await self._update_privacy(user, post_id, privacy, transaction=transaction, commit=False) + notify = await self._update_privacy(user, post_id, privacy, transaction=transaction, commit=False) post.privacy = await privacy_map.get_id(privacy) await transaction.commit() await PostKVS.put_async(post_id, post) + if notify : + """ + TODO: check for mentions in the post, and notify users that they were mentioned + """ + return await posts.post(user, post) @@ -374,18 +423,28 @@ async def insert_thumbnail(self: Self, t: Transaction, post_id: PostId, crc: int @timed - async def upload_thumbnail(self: Self, run: str, t: Transaction, post_id: PostId, crc: int, image: Image, size: int, ext: str) -> Thumbnail : - mime: MimeType = MimeType[ext] - url: str = f'{post_id}/{crc}/thumbnails/{size}.{ext}' - data: bytes = self.get_image_data(image.convert(mime.type())) - await self.upload_async(data, url, mime) - th = await self.insert_thumbnail(t, post_id, crc, mime, size, f'{size}.{ext}', len(data), image.size[0], image.size[1]) - self.logger.debug({ - 'run': run, - 'post': post_id, - 'message': f'uploaded thumbnail {mime.name}({size}) image to cdn', - }) - return th + async def upload_thumbnails(self: Self, run: str, t: Transaction, post_id: PostId, crc: int, image: Image, formats: list[tuple[int, str]]) -> list[Thumbnail] : + """ + formats is a tuple of size and file extension, used to resize each thumbnail and upload it + """ + ths: list[Thumbnail] = [] + # query: list[str] = [] + # params: list[Any] = [] + + for size, ext in sorted(formats, key=lambda x : x[0], reverse=True) : + image = self.convert_image(image, size) + mime: MimeType = MimeType[ext] + url: str = f'{post_id}/{crc}/thumbnails/{size}.{ext}' + data: bytes = self.get_image_data(image.convert(mime.type())) + await self.upload_async(data, url, mime) + ths.append(await self.insert_thumbnail(t, post_id, crc, mime, size, f'{size}.{ext}', len(data), image.size[0], image.size[1])) + self.logger.debug({ + 'run': run, + 'post': post_id, + 'message': f'uploaded thumbnail {mime.name}({size}) image to cdn', + }) + + return ths @timed @@ -422,17 +481,25 @@ async def uploadImage( emoji_name: Optional[str] = None, web_resize: Optional[int] = None, ) -> Media : - run: str = uuid4().hex + start: datetime = datetime.now() + run: str = uuid4().hex # validate it's an actual photo try : - with Image(file=open(file_on_disk, 'rb')) : - pass + await validate_image(file_on_disk) except Exception as e : self.delete_file(file_on_disk) raise BadRequest('Uploaded file is not an image.', err=e) + self.logger.debug({ + 'run': run, + 'post': post_id, + 'elapsed': datetime.now() - start, + 'file_on_disk': file_on_disk, + 'message': 'validated input image file', + }) + rev: int mime_type: MimeType @@ -461,6 +528,7 @@ async def uploadImage( self.logger.debug({ 'run': run, 'post': post_id, + 'elapsed': datetime.now() - start, 'file_on_disk': file_on_disk, 'content_type': mime_type, 'filename': filename, @@ -474,7 +542,9 @@ async def uploadImage( self.logger.debug({ 'run': run, - 'thumbhash': thumbhash, + 'post': post_id, + 'elapsed': datetime.now() - start, + 'thumbhash': b64encode(thumbhash).decode(), }) async with self.transaction() as transaction : @@ -523,6 +593,7 @@ async def uploadImage( self.logger.debug({ 'run': run, 'post': post_id, + 'elapsed': datetime.now() - start, 'message': 'resized for web', }) @@ -556,14 +627,14 @@ async def uploadImage( ( %s, %s, %s, %s, %s, %s, %s, %s) on conflict (post_id) do update set updated = now(), - type = %s, - filename = %s, - length = %s, - thumbhash = %s, - width = %s, - height = %s, - crc = %s - WHERE media.post_id = %s + type = excluded.type, + filename = excluded.filename, + length = excluded.length, + thumbhash = excluded.thumbhash, + width = excluded.width, + height = excluded.height, + crc = excluded.crc + WHERE media.post_id = excluded.post_id RETURNING media.updated; """, ( post_id.int(), @@ -574,18 +645,8 @@ async def uploadImage( image_size.width, image_size.height, rev, - - media_type, - filename, - content_length, - thumbhash, - image_size.width, - image_size.height, - rev, - - post_id.int(), ), - fetch_one=True, + fetch_one = True, ) updated: datetime = upd[0] @@ -602,6 +663,7 @@ async def uploadImage( self.logger.debug({ 'run': run, 'post': post_id, + 'elapsed': datetime.now() - start, 'message': 'deleted old file from cdn', }) @@ -612,37 +674,21 @@ async def uploadImage( self.logger.debug({ 'run': run, 'post': post_id, + 'elapsed': datetime.now() - start, 'message': 'uploaded fullsize image to cdn', }) # upload thumbnails - thumbnails: list[Thumbnail] = [] - + thumbnails: list[Thumbnail] with Image(file=open(file_on_disk, 'rb')) as image : - for i, size in enumerate(self.thumbnail_sizes) : - image = self.convert_image(image, size) - - if not i : - # jpeg thumbnail - thumbnails.append(await self.upload_thumbnail( - run, - transaction, - post_id, - rev, - image, - size, - 'jpg', - )) - - thumbnails.append(await self.upload_thumbnail( - run, - transaction, - post_id, - rev, - image, - size, - 'webp', - )) + thumbnails: list[Thumbnail] = await self.upload_thumbnails( + run, + transaction, + post_id, + rev, + image, + [(s, 'webp') for s in self.thumbnail_sizes] + [(self.thumbnail_sizes[0], 'jpg')], + ) del image @@ -729,9 +775,10 @@ async def updatePostMetadata( if not update and not update_privacy : raise BadRequest('no params were provided.') + notify: bool = False async with self.transaction() as t : if update_privacy and privacy : - await self._update_privacy(user, post_id, privacy, t, commit = False) + notify = await self._update_privacy(user, post_id, privacy, t, commit = False) post.privacy = await privacy_map.get_id(privacy) if update : @@ -740,6 +787,11 @@ async def updatePostMetadata( await PostKVS.put_async(post_id, post) await t.commit() + if notify : + """ + TODO: check for mentions and tags in the post, and notify users that they were mentioned, tagged, or a post matched one of their tag sets + """ + @timed async def _update_privacy( @@ -750,12 +802,17 @@ async def _update_privacy( transaction: Optional[Transaction] = None, commit: bool = True, ) -> bool : + """ + returns True if the post was published, false otherwise + """ if privacy == Privacy.unpublished : raise BadRequest('post privacy cannot be updated to unpublished.') if not transaction : transaction = self.transaction() + published: bool = False + async with transaction as t : data = await t.query_async(""" SELECT privacy.type @@ -786,6 +843,7 @@ async def _update_privacy( vote_task: Optional[Task] = None if old_privacy in UnpublishedPrivacies and privacy not in UnpublishedPrivacies : + published = True await t.query_async(""" INSERT INTO kheina.public.post_votes (user_id, post_id, upvote) @@ -873,19 +931,7 @@ async def _update_privacy( if vote_task : await vote_task - return True - - - @HttpErrorHandler('updating post privacy') - @timed - async def updatePrivacy(self: Self, user: KhUser, post_id: PostId, privacy: Privacy) : - success = await self._update_privacy(user, post_id, privacy) - - if await PostKVS.exists_async(post_id) : - # we need the created and updated values set by db, so just remove - ensure_future(PostKVS.remove_async(post_id)) - - return success + return published async def getImage(self: Self, ipost: InternalPost, coordinates: Coordinates) -> Image : @@ -1081,21 +1127,25 @@ async def uploadVideo( filename: str, post_id: PostId, ) -> Media : - run: str = uuid4().hex + start: datetime = datetime.now() + run: str = uuid4().hex # validate it's an actual video try : - await FFmpeg().input( - file_on_disk, - ).output( - '-', - f = 'null', - ).execute() + await validate_video(file_on_disk) except Exception as e : self.delete_file(file_on_disk) raise BadRequest('Uploaded file is not a video.', err=e) + self.logger.debug({ + 'run': run, + 'post': post_id, + 'elapsed': datetime.now() - start, + 'file_on_disk': file_on_disk, + 'message': 'validated input video file', + }) + rev: int mime_type: MimeType @@ -1115,19 +1165,13 @@ async def uploadVideo( try : # extract the first frame of the video to use for thumbnails/hash - await FFmpeg().input( - file_on_disk, - accurate_seek = None, - ss = '0', - ).output( - (screenshot := f'images/{uuid4().hex}_{filename}.webp'), - { 'frames:v': '1' }, - ).execute() + screenshot = await extract_frame(file_on_disk, filename) post: InternalPost = await posts._get_post(post_id) self.logger.debug({ 'run': run, 'post': post_id, + 'elapsed': datetime.now() - start, 'file_on_disk': file_on_disk, 'content_type': mime_type, 'filename': filename, @@ -1140,7 +1184,9 @@ async def uploadVideo( self.logger.debug({ 'run': run, - 'thumbhash': thumbhash, + 'post': post_id, + 'elapsed': datetime.now() - start, + 'thumbhash': b64encode(thumbhash).decode(), }) async with self.transaction() as transaction : @@ -1164,7 +1210,7 @@ async def uploadVideo( old_filename: Optional[str] = data[0] old_crc: Optional[int] = data[1] - image_size: PostSize + image_size: PostSize del data await self.purgeSystemTags(run, transaction, post_id) @@ -1195,14 +1241,15 @@ async def uploadVideo( audio = True continue - # since there can be empty audio streams, we need to do a further check of the audio stream itself - media = await self.parse_audio_stream(file_on_disk) - if media.get('RMS level dB') != '-inf' : - query.append("(tag_to_id('audio'), %s, 0)") - params.append(post_id.int()) - flags.append(MediaFlag.audio) + if audio : + # since there can be empty audio streams, we need to do a further check of the audio stream itself + media = await self.parse_audio_stream(file_on_disk) + if (rms := media.get('RMS level dB')) and rms != '-inf' : + query.append("(tag_to_id('audio'), %s, 0)") + params.append(post_id.int()) + flags.append(MediaFlag.audio) - del media + del media if not query or not params : raise BadRequest('no media streams found!') @@ -1236,14 +1283,14 @@ async def uploadVideo( ( %s, %s, %s, %s, %s, %s, %s, %s) on conflict (post_id) do update set updated = now(), - type = %s, - filename = %s, - length = %s, - thumbhash = %s, - width = %s, - height = %s, - crc = %s - WHERE media.post_id = %s + type = excluded.type, + filename = excluded.filename, + length = excluded.length, + thumbhash = excluded.thumbhash, + width = excluded.width, + height = excluded.height, + crc = excluded.crc + WHERE media.post_id = excluded.post_id RETURNING media.updated; """, ( post_id.int(), @@ -1254,75 +1301,44 @@ async def uploadVideo( image_size.width, image_size.height, rev, - - media_type, - filename, - content_length, - thumbhash, - image_size.width, - image_size.height, - rev, - - post_id.int(), ), fetch_one = True, ) updated: datetime = upd[0] if old_filename : - old_url: str - - if old_crc : - old_url = f'{post_id}/{old_crc}/{old_filename}' - - else : - old_url = f'{post_id}/{old_filename}' - + old_url: str = f'{post_id}/{old_crc}/{old_filename}' if old_crc else f'{post_id}/{old_filename}' await self.delete_file_async(old_url) self.logger.debug({ 'run': run, 'post': post_id, + 'elapsed': datetime.now() - start, + 'url': old_url, 'message': 'deleted old file from cdn', }) - url: str = f'{post_id}/{rev}/{filename}' - # upload fullsize + url: str = f'{post_id}/{rev}/{filename}' await self.upload_async(open(file_on_disk, 'rb').read(), url, content_type = mime_type) self.logger.debug({ 'run': run, 'post': post_id, + 'elapsed': datetime.now() - start, + 'url': url, 'message': 'uploaded fullsize image to cdn', }) # upload thumbnails - thumbnails: list[Thumbnail] = [] - + thumbnails: list[Thumbnail] with Image(file=open(screenshot, 'rb')) as image : - for i, size in enumerate(self.thumbnail_sizes) : - image = self.convert_image(image, size) - - if not i : - # jpeg thumbnail - thumbnails.append(await self.upload_thumbnail( - run, - transaction, - post_id, - rev, - image, - size, - 'jpg', - )) - - thumbnails.append(await self.upload_thumbnail( - run, - transaction, - post_id, - rev, - image, - size, - 'webp', - )) + thumbnails: list[Thumbnail] = await self.upload_thumbnails( + run, + transaction, + post_id, + rev, + image, + [(s, 'webp') for s in self.thumbnail_sizes] + [(self.thumbnail_sizes[0], 'jpg')], + ) del image diff --git a/reporting/mod_actions.py b/reporting/mod_actions.py index e4f17df..3170546 100644 --- a/reporting/mod_actions.py +++ b/reporting/mod_actions.py @@ -10,8 +10,7 @@ from pydantic import BaseModel from avro_schema_repository.schema_repository import SchemaRepository -from posts.models import InternalPost -from posts.repository import Posts +from posts.repository import Repository as Posts from shared.auth import KhUser, Scope from shared.caching import AerospikeCache from shared.caching.key_value_store import KeyValueStore @@ -20,12 +19,12 @@ from shared.exceptions.http_error import BadRequest, Conflict, NotFound from shared.models import PostId, UserPortable from shared.sql import SqlInterface -from shared.sql.query import Field, Operator, Order, Query, Update, Value, Where -from users.repository import Users +from shared.sql.query import Field, Operator, Order, Query, Table, Update, Value, Where +from users.repository import Repository as Users from .models.actions import ActionType, BanAction, ForceUpdateAction, InternalActionType, InternalBanAction, InternalModAction, ModAction, RemovePostAction from .models.bans import Ban, InternalBan, InternalBanType, InternalIpBan -from .repository import Reporting +from .repository import Repository from .repository import kvs as reporting_kvs @@ -34,7 +33,7 @@ posts: Posts = Posts() AvroMarker: bytes = b'\xC3\x01' kvs: KeyValueStore = KeyValueStore('kheina', 'actions') -reporting: Reporting = Reporting() +reporting: Repository = Repository() class ModActions(SqlInterface) : @@ -113,10 +112,11 @@ async def action(self: Self, user: KhUser, iaction: InternalModAction) -> ModAct async def ban(self: Self, user: KhUser, iban: InternalBan) -> Ban : + iuser = await users._get_user(iban.user_id) return Ban( ban_id = iban.ban_id, ban_type = iban.ban_type.to_type(), - user = await self.user_portable(user, iban.user_id), + user = await users.portable(user, iuser), created = iban.created, completed = iban.completed, reason = iban.reason, @@ -238,7 +238,7 @@ async def create(self: Self, user: KhUser, response: str, action: ModAction) -> match action.action : case ForceUpdateAction() : await t.query_async( - Query(InternalPost.__table_name__).update( + Query(Table('kheina.public.posts')).update( Update('locked', Value(True)), ).where( Where( @@ -252,7 +252,7 @@ async def create(self: Self, user: KhUser, response: str, action: ModAction) -> case RemovePostAction() : await t.query_async( - Query(InternalPost.__table_name__).update( + Query(Table('kheina.public.posts')).update( Update('locked', Value(True)), ).where( Where( @@ -276,8 +276,8 @@ async def create(self: Self, user: KhUser, response: str, action: ModAction) -> completed = completed, reason = action.reason, )) - await kvs.put_async(f'ban={iban.user_id}', iban) - await kvs.put_async(f'user_id={user_id}', iaction) + await kvs.put_async(f'active_ban={user_id}', iban, action.action.duration) + await kvs.remove_async(f'user_bans={user_id}') await t.commit() @@ -453,26 +453,29 @@ async def _active_action(self: Self, post_id: PostId) -> Optional[InternalModAct ) - @AerospikeCache('kheina', 'actions', 'active_action={post_id}', _kvs=kvs) + @AerospikeCache('kheina', 'actions', 'post_actions={post_id}', _kvs=kvs) async def _actions(self: Self, post_id: PostId) -> list[InternalModAction] : - data: list[tuple[int, int, Optional[int], Optional[int], Optional[int], datetime, Optional[datetime], str, int, bytes]] = await self.query_async(Query(InternalModAction.__table_name__).select( - Field('mod_actions', 'action_id'), - Field('mod_actions', 'report_id'), - Field('mod_actions', 'post_id'), - Field('mod_actions', 'user_id'), - Field('mod_actions', 'assignee'), - Field('mod_actions', 'created'), - Field('mod_actions', 'completed'), - Field('mod_actions', 'reason'), - Field('mod_actions', 'action_type'), - Field('mod_actions', 'action'), - ).where( - Where( - Field('mod_actions', 'post_id'), - Operator.equal, - Value(post_id.int()), + data: list[tuple[int, int, Optional[int], Optional[int], Optional[int], datetime, Optional[datetime], str, int, bytes]] = await self.query_async( + Query(InternalModAction.__table_name__).select( + Field('mod_actions', 'action_id'), + Field('mod_actions', 'report_id'), + Field('mod_actions', 'post_id'), + Field('mod_actions', 'user_id'), + Field('mod_actions', 'assignee'), + Field('mod_actions', 'created'), + Field('mod_actions', 'completed'), + Field('mod_actions', 'reason'), + Field('mod_actions', 'action_type'), + Field('mod_actions', 'action'), + ).where( + Where( + Field('mod_actions', 'post_id'), + Operator.equal, + Value(post_id.int()), + ), ), - ), fetch_all = True) + fetch_all = True, + ) if not data : return [] @@ -498,3 +501,54 @@ async def actions(self: Self, user: KhUser, post_id: PostId) -> list[ModAction] await self.action(user, iaction) for iaction in await self._actions(post_id) ] + + + @AerospikeCache('kheina', 'actions', 'user_actions={user_id}', _kvs=kvs) + async def _user_actions(self: Self, user_id: int) -> list[InternalModAction] : + data: list[tuple[int, int, Optional[int], Optional[int], Optional[int], datetime, Optional[datetime], str, int, bytes]] = await self.query_async( + Query(InternalModAction.__table_name__).select( + Field('mod_actions', 'action_id'), + Field('mod_actions', 'report_id'), + Field('mod_actions', 'post_id'), + Field('mod_actions', 'user_id'), + Field('mod_actions', 'assignee'), + Field('mod_actions', 'created'), + Field('mod_actions', 'completed'), + Field('mod_actions', 'reason'), + Field('mod_actions', 'action_type'), + Field('mod_actions', 'action'), + ).where( + Where( + Field('mod_actions', 'user_id'), + Operator.equal, + Value(user_id), + ), + ), + fetch_all = True, + ) + + if not data : + return [] + + return [ + InternalModAction( + action_id = row[0], + report_id = row[1], + post_id = row[2], + user_id = row[3], + assignee = row[4], + created = row[5], + completed = row[6], + reason = row[7], + action_type = InternalActionType(row[8]), + action = bytes(row[9]), + ) + for row in data + ] + + async def user_actions(self: Self, user: KhUser, handle: str) -> list[ModAction] : + user_id: int = await users._handle_to_user_id(handle) + return [ + await self.action(user, iaction) + for iaction in await self._user_actions(user_id) + ] diff --git a/reporting/models/__init__.py b/reporting/models/__init__.py index c4c3ec2..5eb970c 100644 --- a/reporting/models/__init__.py +++ b/reporting/models/__init__.py @@ -30,5 +30,6 @@ class CreateActionRequest(BaseModel) : action: RemovePostAction | ForceUpdateAction | BanActionInput | None -class ReportReponseRequest(BaseModel) : +class CloseReponseRequest(BaseModel) : + report_id: int response: str diff --git a/reporting/models/actions.py b/reporting/models/actions.py index 951ce81..d8c5b1b 100644 --- a/reporting/models/actions.py +++ b/reporting/models/actions.py @@ -32,8 +32,7 @@ def internal(self: Self) -> InternalActionType : # these two enums must contain the same values -assert set(InternalActionType.__members__.keys()) == set(ActionType.__members__.keys()) -assert set(InternalActionType.__members__.keys()) == set(map(lambda x : x.value, ActionType.__members__.values())) +assert set(InternalActionType.__members__.keys()) == set(ActionType.__members__.keys()) == set(map(lambda x : x.value, ActionType.__members__.values())) class InternalModAction(BaseModel) : diff --git a/reporting/models/bans.py b/reporting/models/bans.py index 990eaad..a96a563 100644 --- a/reporting/models/bans.py +++ b/reporting/models/bans.py @@ -57,7 +57,7 @@ class InternalIpBan(BaseModel) : class Ban(BaseModel) : ban_id: int ban_type: BanType - user: Optional[UserPortable] + user: UserPortable created: datetime completed: datetime reason: str diff --git a/reporting/models/mod_queue.py b/reporting/models/mod_queue.py index 0329032..5c35cde 100644 --- a/reporting/models/mod_queue.py +++ b/reporting/models/mod_queue.py @@ -17,6 +17,6 @@ class InternalModQueueEntry(BaseModel) : class ModQueueEntry(BaseModel) : - queue_id: int = Field(description='orm:"pk;gen"') + queue_id: int assignee: Optional[UserPortable] report: Report diff --git a/reporting/reporting.py b/reporting/reporting.py index 65d5508..c655304 100644 --- a/reporting/reporting.py +++ b/reporting/reporting.py @@ -10,13 +10,13 @@ from .models.actions import BanAction, ForceUpdateAction, ModAction, RemovePostAction from .models.bans import Ban from .models.reports import BaseReport, Report -from .repository import Reporting # type: ignore +from .repository import Repository -mod_acitons = ModActions() +mod_actions = ModActions() -class Reporting(Reporting) : +class Reporting(Repository) : async def create(self: Self, user: KhUser, body: CreateRequest) -> Report : return await super().create( @@ -79,7 +79,7 @@ async def create_action(self: Self, user: KhUser, body: CreateActionRequest) -> case _ : raise BadRequest('unknown action object', body=body) - return await mod_acitons.create( + return await mod_actions.create( user, body.response, ModAction( @@ -95,8 +95,12 @@ async def create_action(self: Self, user: KhUser, body: CreateActionRequest) -> async def actions(self: Self, user: KhUser, post_id: PostId) -> list[ModAction] : - return await mod_acitons.actions(user, PostId(post_id)) + return await mod_actions.actions(user, post_id) + + + async def user_actions(self: Self, user: KhUser, handle: str) -> list[ModAction] : + return await mod_actions.user_actions(user, handle) async def bans(self: Self, user: KhUser, handle: str) -> list[Ban] : - return await mod_acitons.bans(user, handle) + return await mod_actions.bans(user, handle) diff --git a/reporting/repository.py b/reporting/repository.py index bc0bf37..c6154ed 100644 --- a/reporting/repository.py +++ b/reporting/repository.py @@ -8,7 +8,7 @@ from pydantic import BaseModel from avro_schema_repository.schema_repository import SchemaRepository -from posts.repository import Posts +from posts.repository import Repository as Posts from shared.auth import KhUser, Scope from shared.caching import AerospikeCache from shared.caching.key_value_store import KeyValueStore @@ -17,7 +17,7 @@ from shared.models import UserPortable from shared.sql import SqlInterface from shared.sql.query import Field, Operator, Order, Query, Value, Where -from users.repository import Users +from users.repository import Repository as Users from .models.mod_queue import InternalModQueueEntry, ModQueueEntry from .models.reports import BaseReport, BaseReportHistory, CopyrightReport, HistoryMask, InternalReport, InternalReportType, Report @@ -30,7 +30,7 @@ kvs: KeyValueStore = KeyValueStore('kheina', 'reports') -class Reporting(SqlInterface) : +class Repository(SqlInterface) : _report_type_map: dict[InternalReportType, type[BaseModel]] = { InternalReportType.other: BaseReport, @@ -43,7 +43,7 @@ class Reporting(SqlInterface) : } def __init__(self, *args: Any, **kwargs: Any) : - assert set(Reporting._report_type_map.keys()) == set(InternalReportType.__members__.values()) + assert set(Repository._report_type_map.keys()) == set(InternalReportType.__members__.values()) super().__init__(*args, conversions={ IntEnum: lambda x: x.value }, **kwargs) @@ -55,14 +55,14 @@ async def _get_schema(fingerprint: bytes) -> Schema: @AsyncLRU(maxsize=0) async def _get_serializer(self: Self, report_type: InternalReportType) -> tuple[bytes, AvroSerializer] : - model = Reporting._report_type_map[report_type] + model = Repository._report_type_map[report_type] return AvroMarker + await repo.addSchema(convert_schema(model)), AvroSerializer(model) async def _get_deserializer(self: Self, report_type: InternalReportType, fp: bytes) -> AvroDeserializer : assert fp[:2] == AvroMarker - model = Reporting._report_type_map[report_type] - return AvroDeserializer(read_model=model, write_model=await Reporting._get_schema(fp[2:10])) + model = Repository._report_type_map[report_type] + return AvroDeserializer(read_model=model, write_model=await Repository._get_schema(fp[2:10])) async def user_portable(self: Self, user: KhUser, user_id: Optional[int]) -> Optional[UserPortable] : @@ -165,7 +165,7 @@ async def update_report(self: Self, user: KhUser, report: Report) -> None : prev = report_data.dict() self.logger.debug({ 'incoming data': data, - 'prev': prev, + 'prev': prev, }) for k, v in data.items() : if prev.get(k) == v : @@ -224,20 +224,22 @@ async def assign_self(self: Self, user: KhUser, queue_id: int) -> None : raise Conflict('another moderator has assigned this report to themselves') - async def close_response(self: Self, user: KhUser, queue_id: int, response: str) -> Report : + async def close_response(self: Self, user: KhUser, report_id: int, response: str) -> Report : ireport: InternalReport async with self.transaction() as t : data: Optional[tuple[int, Optional[int]]] = await t.query_async(""" delete from kheina.public.mod_queue - where mod_queue.queue_id = %s - returning mod_queue.report_id, mod_queue.assignee; - """, ( - queue_id, - ), fetch_one = True) + where mod_queue.report_id = %s + returning mod_queue.assignee; + """, ( + report_id, + ), + fetch_one = True, + ) if not data : - raise NotFound('provided queue entry does not exist') + raise NotFound('provided report does not exist') if data[1] != user.user_id : raise BadRequest('cannot close a report that is assigned to someone else') diff --git a/reporting/router.py b/reporting/router.py index 9219899..d9bf27f 100644 --- a/reporting/router.py +++ b/reporting/router.py @@ -1,10 +1,10 @@ from fastapi import APIRouter, Request from shared.auth import Scope -from shared.models import PostId +from shared.models import PostId, convert_path_post_id from shared.timing import timed -from .models import CreateActionRequest, CreateRequest, ReportReponseRequest +from .models import CloseReponseRequest, CreateActionRequest, CreateRequest from .models.actions import ModAction from .models.bans import Ban from .models.mod_queue import ModQueueEntry @@ -13,25 +13,25 @@ reportRouter = APIRouter( - prefix='/report', + prefix = '/report', ) reportsRouter = APIRouter( - prefix='/reports', + prefix = '/reports', ) actionRouter = APIRouter( - prefix='/action', + prefix = '/action', ) actionsRouter = APIRouter( - prefix='/actions', + prefix = '/actions', ) queueRouter = APIRouter( - prefix='/mod', + prefix = '/mod', ) bansRouter = APIRouter( - prefix='/bans', + prefix = '/bans', ) @@ -66,6 +66,13 @@ async def v1List(req: Request) -> list[Report] : return await reporting.list_(req.user) +@reportRouter.delete('/{report_id}') +@timed.root +async def v1CloseWithoutAction(req: Request, report_id: int, body: CloseReponseRequest) -> Report : + await req.user.verify_scope(Scope.mod) + return await reporting.close_response(req.user, report_id, body.response) + + ######################### queue ######################### @@ -83,13 +90,6 @@ async def v1AssignSelf(req: Request, queue_id: int) -> None : return await reporting.assign_self(req.user, queue_id) -@queueRouter.patch('/{queue_id}') -@timed.root -async def v1CloseWithoutAction(req: Request, queue_id: int, body: ReportReponseRequest) -> Report : - await req.user.verify_scope(Scope.mod) - return await reporting.close_response(req.user, queue_id, body.response) - - ######################### actions ######################### @actionRouter.put('') @@ -103,7 +103,14 @@ async def v1CloseWithAction(req: Request, body: CreateActionRequest) -> ModActio @timed.root async def v1Actions(req: Request, post_id: PostId) -> list[ModAction] : await req.user.verify_scope(Scope.mod) - return await reporting.actions(req.user, post_id) + return await reporting.actions(req.user, convert_path_post_id(post_id)) + + +@actionsRouter.get('/user/{handle}') +@timed.root +async def v1UserActions(req: Request, handle: str) -> list[ModAction] : + await req.user.verify_scope(Scope.mod) + return await reporting.user_actions(req.user, handle) ######################### bans ######################### diff --git a/requirements.lock b/requirements.lock index af9adde..c92c62f 100644 --- a/requirements.lock +++ b/requirements.lock @@ -25,6 +25,7 @@ fastapi==0.115.6 frozenlist==1.5.0 gunicorn==23.0.0 h11==0.14.0 +http_ece==1.2.1 httpcore==1.0.7 httptools==0.6.4 httpx==0.27.2 @@ -42,6 +43,7 @@ propcache==0.2.0 psycopg-binary==3.2.3 psycopg-pool==3.2.4 psycopg==3.2.3 +py-vapid==1.9.2 pycparser==2.22 pycryptodome==3.21.0 pydantic==1.10.19 @@ -53,18 +55,21 @@ pyotp==2.9.0 python-dotenv==1.0.1 python-ffmpeg==2.0.12 python-multipart==0.0.20 +pywebpush==2.0.3 PyYAML==6.0.2 requests==2.32.3 rich==13.9.4 scipy==1.14.1 setuptools==75.6.0 shellingham==1.5.4 +six==1.17.0 sniffio==1.3.1 starlette==0.41.3 typer==0.15.1 typing_extensions==4.12.2 ujson==5.5.0 urllib3==2.2.3 +uuid7==0.1.0 uvicorn==0.32.0 uvloop==0.21.0 Wand==0.6.13 diff --git a/sample-creds.json b/sample-creds.json index e96a18c..3ee5ff7 100644 --- a/sample-creds.json +++ b/sample-creds.json @@ -27,6 +27,11 @@ "key": "kw7vBGROerYZjIb7i05SrAIrYV18Xn43ijIahWrY", "key_name": "localhost-minio-key" }, + "notifications": { + "aes": "XHn0bd38L5mQgISKBtP2FbIJQnYZd0yWfwXFn1FBtfQ.pD795CKO9aaq2T5MuIqxOwgTKli-d0gMOaSlhl5GR1Mv9GhOGLct-3jgsqGhHc6uQwJXki7Jj1pjbj2fliOyAQ", + "pub": "wtzY_IHxeR7ZMbar.TsLHfu8KxrkW99Ou8ySJNB_nxF6vK33SyldQ1HIWhTWcgT5LuYd4Ekra7an0i-jsQDxFNSnnnxCVWKv0.vmTisEaOEAQajFyWGr65unDlD9ZJXOSIQmmfh0lpxUSS-x4-GFICLcdITGSNtE7f6JbMPSGbp7GygGZ5NYdPCw", + "priv": "8ax1-3veA74sd5nt.2Bv8pW2CEobi6HWO9800C3aBZ5iAubhq1SxPZeN-aYkrJ3LMU4CRjOFL4mVEzQTL.cmGWjKew0VFJVw7gEu2n2PhQAkWulNUJQkwx2-0NOGP4E-NIi_tQjYAWE-F_cMtZd3nRHmjJYy37IPZc3x27Bg" + }, "db": { "user": "kheina", "password": "password", @@ -50,4 +55,4 @@ ], "ip_salt": "570e755552e4ef95f7ee5ce9ad0ff4e38a64b858" } -} \ No newline at end of file +} diff --git a/server.py b/server.py index bce3124..2541c03 100644 --- a/server.py +++ b/server.py @@ -11,6 +11,7 @@ from account.router import app as account from configs.router import app as configs from emojis.router import app as emoji +from notifications.router import app as notifications from posts.router import app as posts from probe.router import probes from reporting.router import app as reporting @@ -165,3 +166,4 @@ def root() -> ServiceInfo : app.include_router(users) app.include_router(emoji) app.include_router(reporting) +app.include_router(notifications) diff --git a/sets/repository.py b/sets/repository.py index 7ef3d4d..e1521aa 100644 --- a/sets/repository.py +++ b/sets/repository.py @@ -3,16 +3,17 @@ from typing import Optional, Self, Tuple, Union from posts.models import InternalPost, Post, PostId, Privacy -from posts.repository import Posts, privacy_map +from posts.repository import Repository as Posts +from posts.repository import privacy_map from shared.auth import KhUser, Scope -from shared.caching import AerospikeCache, ArgsCache +from shared.caching import AerospikeCache from shared.caching.key_value_store import KeyValueStore from shared.datetime import datetime from shared.exceptions.http_error import NotFound from shared.hashing import Hashable from shared.models import InternalUser, UserPrivacy from shared.sql import SqlInterface -from users.repository import Users +from users.repository import Repository as Users from .models import InternalSet, Set, SetId @@ -23,9 +24,9 @@ posts = Posts() -class Sets(SqlInterface, Hashable) : +class Repository(SqlInterface, Hashable) : - def __init__(self: 'Sets') -> None : + def __init__(self) -> None : SqlInterface.__init__( self, conversions={ @@ -132,7 +133,7 @@ async def set(self: Self, iset: InternalSet, user: KhUser) -> Set : count=iset.count, title=iset.title, description=iset.description, - privacy=Sets._validate_privacy(await privacy_map.get(iset.privacy)), + privacy=Repository._validate_privacy(await privacy_map.get(iset.privacy)), created=iset.created, updated=iset.updated, first=first_post, diff --git a/sets/router.py b/sets/router.py index 0e468ad..e428e2d 100644 --- a/sets/router.py +++ b/sets/router.py @@ -1,8 +1,7 @@ from fastapi import APIRouter -from pydantic import conint from shared.models._shared import PostId -from shared.server import Request +from shared.models.server import Request from shared.timing import timed from .models import AddPostToSetRequest, CreateSetRequest, PostSet, Set, SetId, UpdateSetRequest diff --git a/sets/sets.py b/sets/sets.py index dd00b66..bf589b8 100644 --- a/sets/sets.py +++ b/sets/sets.py @@ -5,17 +5,18 @@ from psycopg.errors import UniqueViolation from posts.models import InternalPost, MediaType, Post, PostId, PostSize, Privacy, Rating -from posts.repository import Posts, privacy_map +from posts.repository import Repository as Posts +from posts.repository import privacy_map from shared.auth import KhUser, Scope -from shared.caching import AerospikeCache, ArgsCache +from shared.caching import ArgsCache from shared.datetime import datetime from shared.exceptions.http_error import BadRequest, Conflict, HttpErrorHandler, NotFound from shared.models.user import UserPrivacy from shared.timing import timed -from users.repository import Users +from users.repository import Repository as Users from .models import InternalSet, PostSet, Set, SetId, SetNeighbors, UpdateSetRequest -from .repository import SetKVS, SetNotFound, Sets # type: ignore +from .repository import Repository, SetKVS, SetNotFound """ @@ -44,7 +45,7 @@ users = Users() -class Sets(Sets) : +class Sets(Repository) : @staticmethod async def _verify_authorized(user: KhUser, iset: InternalSet) -> bool : diff --git a/shared/auth/__init__.py b/shared/auth/__init__.py index 82f5bd7..81b520c 100644 --- a/shared/auth/__init__.py +++ b/shared/auth/__init__.py @@ -13,7 +13,7 @@ from fastapi import Request from authenticator.authenticator import AuthAlgorithm, Authenticator, AuthState, PublicKeyResponse, Scope, TokenMetadata -from shared.models.auth import AuthToken, KhUser # type: ignore +from shared.models.auth import AuthToken, _KhUser from ..base64 import b64decode, b64encode from ..caching import ArgsCache @@ -33,7 +33,7 @@ class InvalidToken(ValueError) : pass -class KhUser(KhUser) : +class KhUser(_KhUser) : async def authenticated(self, raise_error: bool = True) -> bool : if self.banned : if raise_error : diff --git a/shared/caching/__init__.py b/shared/caching/__init__.py index 4fa33fd..cc08759 100644 --- a/shared/caching/__init__.py +++ b/shared/caching/__init__.py @@ -182,9 +182,16 @@ def wrapper(*args: Tuple[Hashable], **kwargs:Dict[str, Hashable]) -> Any : def deepTypecheck(type_: type | tuple, instance: Any) -> bool : + """ + returns true if instance is an instance of type_ + """ + + if instance is None and type(None) in getattr(type_, '__args__', tuple()) : + return True + if isinstance(type_, tuple) : if type(instance) not in type_ : - return False + return False else : t = getattr(type_, '__origin__', type_) @@ -267,7 +274,6 @@ async def wrapper(*args: Hashable, **kwargs: Hashable) -> Any : await decorator.kvs.put_async(key, data, TTL) else : - if not deepTypecheck(return_type, data) : data = await func(*args, **kwargs) diff --git a/shared/caching/key_value_store.py b/shared/caching/key_value_store.py index 4753c56..0591a80 100644 --- a/shared/caching/key_value_store.py +++ b/shared/caching/key_value_store.py @@ -4,28 +4,27 @@ from copy import copy from functools import partial from time import time -from typing import Any, Iterable, Optional, Set, Tuple, Type, TypeVar, Union +from typing import Any, Callable, Iterable, Optional, Self, Union import aerospike from ..config.constants import environment from ..config.credentials import fetch -from ..models import Undefined +from ..models._shared import Undefined from ..timing import timed from ..utilities import __clear_cache__, coerse -T = TypeVar('T') KeyType = Union[str, bytes, int] class KeyValueStore : _client = None - def __init__(self: 'KeyValueStore', namespace: str, set: str, local_TTL: float = 1) : + def __init__(self: Self, namespace: str, set: str, local_TTL: float = 1) : if not KeyValueStore._client and not environment.is_test() : config = { - 'hosts': fetch('aerospike.hosts', list[Tuple[str, int]]), + 'hosts': fetch('aerospike.hosts', list[tuple[str, int]]), 'policies': fetch('aerospike.policies', dict[str, Any]), } KeyValueStore._client = aerospike.client(config).connect() @@ -39,14 +38,17 @@ def __init__(self: 'KeyValueStore', namespace: str, set: str, local_TTL: float = @timed - def put(self: 'KeyValueStore', key: KeyType, data: Any, TTL: int = 0) : - KeyValueStore._client.put( # type: ignore + def put(self: Self, key: KeyType, data: Any, TTL: int = 0, bins: dict[str, Any] = { }) -> None : + KeyValueStore._client.put( # type: ignore (self._namespace, self._set, key), - { 'data': data }, - meta={ + { + 'data': data, + **bins, + }, + meta = { 'ttl': TTL, }, - policy={ + policy = { 'max_retries': 3, }, ) @@ -54,16 +56,21 @@ def put(self: 'KeyValueStore', key: KeyType, data: Any, TTL: int = 0) : @timed - async def put_async(self: 'KeyValueStore', key: KeyType, data: Any, TTL: int = 0) : + async def put_async(self: Self, key: KeyType, data: Any, TTL: int = 0, bins: dict[str, Any] = { }) -> None : with ThreadPoolExecutor() as threadpool : - return await get_event_loop().run_in_executor(threadpool, partial(self.put, key, data, TTL)) + return await get_event_loop().run_in_executor(threadpool, partial(self.put, key, data, TTL, bins)) - def _get(self: 'KeyValueStore', key: KeyType, type: Optional[Type[T]] = None) -> T : + def _get[T](self: Self, key: KeyType, type: Optional[type[T]] = None) -> T : if key in self._cache : return copy(self._cache[key][1]) - _, _, data = KeyValueStore._client.get((self._namespace, self._set, key)) # type: ignore + try : + _, _, data = KeyValueStore._client.get((self._namespace, self._set, key)) # type: ignore + + except aerospike.exception.RecordNotFound : + raise aerospike.exception.RecordNotFound(f'Record not found: {(self._namespace, self._set, key)}') + self._cache[key] = (time() + self._local_TTL, data['data']) if type : @@ -73,35 +80,32 @@ def _get(self: 'KeyValueStore', key: KeyType, type: Optional[Type[T]] = None) -> @timed - def get(self: 'KeyValueStore', key: KeyType, type: Optional[Type[T]] = None) -> T : + def get[T](self: Self, key: KeyType, type: Optional[type[T]] = None) -> T : __clear_cache__(self._cache, time) return self._get(key, type) @timed - async def get_async(self: 'KeyValueStore', key: KeyType, type: Optional[Type[T]] = None) -> T : + async def get_async[T](self: Self, key: KeyType, type: Optional[type[T]] = None) -> T : async with self._get_lock : __clear_cache__(self._cache, time) with ThreadPoolExecutor() as threadpool : - try : - return await get_event_loop().run_in_executor(threadpool, partial(self._get, key, type)) - - except aerospike.exception.RecordNotFound : - raise aerospike.exception.RecordNotFound(f'Record not found: {(self._namespace, self._set, key)}') + return await get_event_loop().run_in_executor(threadpool, partial(self._get, key, type)) - def _get_many[T: KeyType](self: 'KeyValueStore', k: Iterable[T]) -> dict[T, Any] : - keys: dict[T, T] = { v: v for v in k } - remote_keys: Set[T] = keys.keys() - self._cache.keys() + def _get_many[T, K: KeyType](self: Self, k: Iterable[K], type: Optional[type[T]] = None) -> dict[K, T | type[Undefined]] : + keys: dict[K, K] = { v: v for v in k } + remote_keys: set[K] = keys.keys() - self._cache.keys() + values: dict[K, Any] if remote_keys : data: list[Tuple[Any, Any, Any]] = KeyValueStore._client.get_many(list(map(lambda k : (self._namespace, self._set, k), remote_keys))) # type: ignore - data_map: dict[T, Any] = { } + data_map: dict[K, Any] = { } exp: float = time() + self._local_TTL for datum in data : - key: T = keys[datum[0][2]] + key: K = keys[datum[0][2]] # filter on the metadata, since it will always be populated if datum[1] : @@ -112,7 +116,7 @@ def _get_many[T: KeyType](self: 'KeyValueStore', k: Iterable[T]) -> dict[T, Any] else : data_map[key] = Undefined - return { + values = { **data_map, **{ key: copy(self._cache[key][1]) @@ -120,51 +124,64 @@ def _get_many[T: KeyType](self: 'KeyValueStore', k: Iterable[T]) -> dict[T, Any] }, } - # only local cache is required - return { - key: self._cache[key][1] - for key in keys.keys() - } + else : + # only local cache is required + values = { + key: copy(self._cache[key][1]) + for key in keys.keys() + } + + if type : + return { + k: coerse(v, type) + for k, v in values.items() + } + + return values @timed - def get_many[T: KeyType](self: 'KeyValueStore', keys: Iterable[T]) -> dict[T, Any] : + def get_many[T, K: KeyType](self: Self, keys: Iterable[K], type: Optional[type[T]] = None) -> dict[K, T | type[Undefined]] : __clear_cache__(self._cache, time) - return self._get_many(keys) + return self._get_many(keys, type) @timed - async def get_many_async[T: KeyType](self: 'KeyValueStore', keys: Iterable[T]) -> dict[T, Any] : + async def get_many_async[T, K: KeyType](self: Self, keys: Iterable[K], type: Optional[type[T]] = None) -> dict[K, T | type[Undefined]] : async with self._get_many_lock : with ThreadPoolExecutor() as threadpool : - return await get_event_loop().run_in_executor(threadpool, partial(self.get_many, keys)) + return await get_event_loop().run_in_executor(threadpool, partial(self.get_many, keys, type)) @timed - def remove(self: 'KeyValueStore', key: KeyType) -> None : + def remove(self: Self, key: KeyType) -> None : + try : + self._client.remove( # type: ignore + (self._namespace, self._set, key), + policy = { + 'max_retries': 3, + }, + ) + + except aerospike.exception.RecordNotFound : + pass + if key in self._cache : del self._cache[key] - self._client.remove( # type: ignore - (self._namespace, self._set, key), - policy={ - 'max_retries': 3, - }, - ) - @timed - async def remove_async(self: 'KeyValueStore', key: KeyType) -> None : + async def remove_async(self: Self, key: KeyType) -> None : with ThreadPoolExecutor() as threadpool : return await get_event_loop().run_in_executor(threadpool, partial(self.remove, key)) @timed - def exists(self: 'KeyValueStore', key: KeyType) -> bool : + def exists(self: Self, key: KeyType) -> bool : try : - _, meta = self._client.exists( # type: ignore + _, meta = self._client.exists( # type: ignore (self._namespace, self._set, key), - policy={ + policy = { 'max_retries': 3, }, ) @@ -176,10 +193,42 @@ def exists(self: 'KeyValueStore', key: KeyType) -> bool : @timed - async def exists_async(self: 'KeyValueStore', key: KeyType) -> bool : + async def exists_async(self: Self, key: KeyType) -> bool : with ThreadPoolExecutor() as threadpool : return await get_event_loop().run_in_executor(threadpool, partial(self.exists, key)) - def truncate(self: 'KeyValueStore') -> None : + @timed + def where[T](self: Self, *predicates: aerospike.predicates, type: Optional[type[T]] = None) -> list[T] : + results: list[T] = [] + func: Callable[[Any], None] + + if type : + def func(data: Any) -> None : + results.append(coerse(data[2]['data'], type)) + + else : + def func(data: Any) -> None : + results.append(copy(data[2]['data'])) + + KeyValueStore._client.query( # type: ignore + self._namespace, + self._set, + ).select( + 'data', + ).where( + *predicates, + ).foreach( + func, + ) + return results + + + @timed + async def where_async[T](self: Self, *predicates: aerospike.predicates, type: Optional[type[T]] = None) -> list[T] : + with ThreadPoolExecutor() as threadpool : + return await get_event_loop().run_in_executor(threadpool, partial(self.where, *predicates, type=type)) + + + def truncate(self: Self) -> None : self._client.truncate(self._namespace, self._set, 0) # type: ignore diff --git a/shared/exceptions/base_error.py b/shared/exceptions/base_error.py index d964e8a..b6ae3d9 100644 --- a/shared/exceptions/base_error.py +++ b/shared/exceptions/base_error.py @@ -23,7 +23,8 @@ def __init__(self, *args: Any, refid: Optional[UUID | str] = None, logdata: dict if 'refid' in logdata : del logdata['refid'] - self.logdata: dict[str, Any] = { + self.__dict__: dict[str, Any] = { **logdata, **kwargs, + **self.__dict__, } diff --git a/shared/exceptions/http_error.py b/shared/exceptions/http_error.py index b8061e4..95c35f3 100644 --- a/shared/exceptions/http_error.py +++ b/shared/exceptions/http_error.py @@ -127,26 +127,26 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any : match e : case NotImplementedError() : - raise NotImplemented( # noqa: F901 + raise NotImplemented( f'{message} has not been implemented.', - refid = refid, + refid = refid, logdata = logdata, - err = e, + err = e, ) case ClientError() : raise ServiceUnavailable( f'{ServiceUnavailable.__name__}: received an invalid response from an upstream server while {message}.', - refid = refid, + refid = refid, logdata = logdata, - err = e, + err = e, ) raise InternalServerError( f'an unexpected error occurred while {message}.', - refid = refid, + refid = refid, logdata = logdata, - err = e, + err = e, ) markcoroutinefunction(wrapper) diff --git a/shared/logging.py b/shared/logging.py index bdd6556..bf674db 100644 --- a/shared/logging.py +++ b/shared/logging.py @@ -1,12 +1,15 @@ import json import logging +from dataclasses import is_dataclass from enum import Enum, unique -from logging import ERROR, INFO, Logger, getLevelName +from logging import DEBUG, ERROR, INFO, Logger, getLevelName from sys import stderr, stdout from traceback import format_tb from types import ModuleType from typing import Any, Callable, Optional, Self, TextIO +from pydantic import BaseModel + from .config.constants import environment from .config.repo import name as repo_name from .config.repo import short_hash @@ -227,23 +230,22 @@ def emit(self, record: logging.LogRecord) -> None : if record.args and isinstance(record.msg, str) : record.msg = record.msg % tuple(map(str, map(json_stream, record.args))) - if record.exc_info : - e: BaseException = record.exc_info[1] # type: ignore - refid = getattr(e, 'refid', None) - errorinfo: dict[str, Any] = { - 'error': f'{getFullyQualifiedClassName(e)}: {e}', - 'stacktrace': list(map(str.strip, format_tb(record.exc_info[2]))), - 'refid': refid.hex if refid else None, - **json_stream(getattr(e, 'logdata', { })), - } + if is_dataclass(record.msg) or isinstance(record.msg, BaseModel) : + record.msg = record.msg.__dict__ + + if record.exc_info and record.exc_info[1] : + e: BaseException = record.exc_info[1] + msg: dict[str, Any] if isinstance(record.msg, dict) : - errorinfo.update(json_stream(record.msg)) + msg = json_stream(record.msg) else : - errorinfo['message'] = record.msg + msg = { 'message': record.msg } + + msg.update(json_stream(e)) try : - self.agent.log_struct(errorinfo, severity=record.levelno) + self.agent.log_struct(msg, severity=record.levelno) except : # noqa: E722 # we really, really do not want to fail-crash here. diff --git a/shared/models/auth.py b/shared/models/auth.py index e89333b..86641d0 100644 --- a/shared/models/auth.py +++ b/shared/models/auth.py @@ -27,7 +27,7 @@ def all_included_scopes(self: 'Scope') -> list['Scope'] : return [v for v in Scope.__members__.values() if Scope.user.value <= v.value <= self.value] or [self] -class KhUser(NamedTuple) : +class _KhUser(NamedTuple) : user_id: int = -1 token: Optional[AuthToken] = None scope: set[Scope] = set() diff --git a/shared/models/config.py b/shared/models/config.py new file mode 100644 index 0000000..e45dd34 --- /dev/null +++ b/shared/models/config.py @@ -0,0 +1,62 @@ +from abc import ABCMeta, abstractmethod +from enum import Enum +from typing import Callable, Self + +from avrofastapi import schema, serialization +from avrofastapi.serialization import AvroDeserializer, AvroSerializer, Schema, parse_avro_schema +from cache import AsyncLRU +from pydantic import BaseModel + +from avro_schema_repository.schema_repository import AvroMarker, SchemaRepository +from shared.caching import ArgsCache +from shared.utilities.json import json_stream + + +repo: SchemaRepository = SchemaRepository() + + +@AsyncLRU(maxsize=32) +async def getSchema(fingerprint: bytes) -> Schema : + return parse_avro_schema((await repo.getSchema(fingerprint)).decode()) + + +def _convert_schema(model: type[BaseModel], error: bool = False, conversions: dict[type, Callable[[schema.AvroSchemaGenerator, type], schema.AvroSchema] | schema.AvroSchema] = { }) -> schema.AvroSchema : + generator: schema.AvroSchemaGenerator = schema.AvroSchemaGenerator(model, error, conversions) + return json_stream(generator.schema()) + + +serialization.convert_schema = schema.convert_schema = _convert_schema + + +class Store(BaseModel, metaclass=ABCMeta) : + + @classmethod + @ArgsCache(float('inf')) + async def fingerprint(cls: type[Self]) -> bytes : + return await repo.addSchema(schema.convert_schema(cls)) + + @classmethod + @ArgsCache(float('inf')) + def serializer(cls: type[Self]) -> AvroSerializer : + return AvroSerializer(cls) + + async def serialize(self: Self) -> bytes : + return AvroMarker + await self.fingerprint() + self.serializer()(self) + + @classmethod + async def deserialize(cls: type[Self], data: bytes) -> Self : + assert data[:2] == AvroMarker + deserializer: AvroDeserializer = AvroDeserializer( + read_model = cls, + # write_model = await getSchema(data[2:10]), + ) + return deserializer(data[10:]) + + @staticmethod + @abstractmethod + def type_() -> Enum : + pass + + @classmethod + def key(cls: type[Self]) -> str : + return cls.type_().name diff --git a/shared/models/encryption.py b/shared/models/encryption.py new file mode 100644 index 0000000..a287a28 --- /dev/null +++ b/shared/models/encryption.py @@ -0,0 +1,106 @@ +from dataclasses import dataclass +from secrets import token_bytes +from typing import Optional, Self +from xmlrpc.client import boolean + +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey, Ed25519PublicKey +from cryptography.hazmat.primitives.asymmetric.types import PublicKeyTypes +from cryptography.hazmat.primitives.ciphers.aead import AESGCM +from cryptography.hazmat.primitives.serialization import load_der_public_key + +from ..base64 import b64decode, b64encode + + +@dataclass +class Keys : + aes: AESGCM + _aes_bytes: bytes + ed25519: Optional[Ed25519PrivateKey] + pub: Ed25519PublicKey + associated_data: bytes + + def encrypt(self: Self, data: bytes) -> bytes : + if not self.ed25519 : + raise ValueError('can only encrypt data with private keys') + + nonce = token_bytes(12) + return b'.'.join(map(b64encode, [nonce, self.aes.encrypt(nonce, data, self.associated_data), self.ed25519.sign(data)])) + + def decrypt(self: Self, data: bytes) -> bytes : + nonce: bytes; encrypted: bytes; sig: bytes + nonce, encrypted, sig = map(b64decode, b''.join(data.split()).split(b'.', 3)) + + decrypted: bytes = self.aes.decrypt(nonce, encrypted, self.associated_data) + self.pub.verify(sig, decrypted) + return decrypted + + @staticmethod + def _encode_pub(pub: Ed25519PublicKey) -> bytes : + return pub.public_bytes( + encoding = serialization.Encoding.DER, + format = serialization.PublicFormat.SubjectPublicKeyInfo, + ) + + @staticmethod + def generate() -> 'Keys' : + aesbytes = AESGCM.generate_key(256) + aeskey = AESGCM(aesbytes) + ed25519priv = Ed25519PrivateKey.generate() + + return Keys( + aes = aeskey, + _aes_bytes = aesbytes, + ed25519 = ed25519priv, + pub = ed25519priv.public_key(), + associated_data = Keys._encode_pub(ed25519priv.public_key()), + ) + + @staticmethod + def load(aes: str, pub: str, priv: Optional[str] = None) -> 'Keys' : + aesbytes: bytes; aes_sig: bytes + aesbytes, aes_sig = map(b64decode, aes.split('.', 2)) + + aeskey = AESGCM(aesbytes) + + nonce: bytes; pub_encrypted: bytes; pub_sig: bytes + nonce, pub_encrypted, pub_sig = map(b64decode, pub.split('.', 3)) + + pub_decrypted: bytes = aeskey.decrypt(nonce, pub_encrypted, aesbytes) + pub_key: PublicKeyTypes = load_der_public_key(pub_decrypted, backend=default_backend()) + assert isinstance(pub_key, Ed25519PublicKey) + pub_key.verify(pub_sig, pub_decrypted) + pub_key.verify(aes_sig, aesbytes) + + associated_data: bytes = Keys._encode_pub(pub_key) + + pk: Optional[Ed25519PrivateKey] = None + if priv : + priv_encrypted: bytes; priv_sig: bytes + nonce, priv_encrypted, priv_sig = map(b64decode, priv.split('.', 3)) + priv_decrypted: bytes = aeskey.decrypt(nonce, priv_encrypted, associated_data) + pub_key.verify(priv_sig, priv_decrypted) + pk = Ed25519PrivateKey.from_private_bytes(priv_decrypted) + + return Keys( + aes = aeskey, + _aes_bytes = aesbytes, + ed25519 = pk, + pub = pub_key, + associated_data = associated_data, + ) + + def dump(self: Self, priv: boolean = False) -> dict[str, str] : + if not self.ed25519 : + raise ValueError('can only dump keys that contain private keys') + + data = { + 'aes': b'.'.join(map(b64encode, [self._aes_bytes, self.ed25519.sign(self._aes_bytes)])).decode(), + 'pub': b'.'.join(map(b64encode, [(nonce := token_bytes(12)), self.aes.encrypt(nonce, self.associated_data, self._aes_bytes), self.ed25519.sign(self.associated_data)])).decode(), + } + + if priv : + data['priv'] = b'.'.join(map(b64encode, [(nonce := token_bytes(12)), self.aes.encrypt(nonce, (pb := self.ed25519.private_bytes_raw()), self.associated_data), self.ed25519.sign(pb)])).decode() + + return data diff --git a/shared/models/server.py b/shared/models/server.py new file mode 100644 index 0000000..e402c93 --- /dev/null +++ b/shared/models/server.py @@ -0,0 +1,15 @@ +from fastapi import Request as _req + +from ..auth import KhUser + + +class Request(_req) : + @property + def auth(self) -> KhUser : + assert 'auth' in self.scope, 'AuthenticationMiddleware must be installed to access request.auth' + return self.scope['auth'] + + @property + def user(self) -> KhUser : + assert 'user' in self.scope, 'KhAuthMiddleware must be installed to access request.user' + return self.scope['user'] diff --git a/shared/server/__init__.py b/shared/server/__init__.py index 3384718..e69de29 100644 --- a/shared/server/__init__.py +++ b/shared/server/__init__.py @@ -1,105 +0,0 @@ -from typing import Iterable - -from fastapi import FastAPI, Request -from fastapi.responses import Response -from starlette.middleware.exceptions import ExceptionMiddleware - -from ..auth import KhUser -from ..config.constants import environment -from ..exceptions.base_error import BaseError -from ..exceptions.handler import jsonErrorHandler - - -NoContentResponse = Response(None, status_code=204) - - -class Request(Request) : - @property - def user(self) -> KhUser : - return super().user - - -def ServerApp( - auth: bool = True, - auth_required: bool = True, - cors: bool = True, - max_age: int = 86400, - custom_headers: bool = True, - allowed_hosts: Iterable[str] = [ - 'localhost', - '127.0.0.1', - '*.fuzz.ly', - 'fuzz.ly', - ], - allowed_origins: Iterable[str] = [ - 'localhost', - '127.0.0.1', - 'dev.fuzz.ly', - 'fuzz.ly', - ], - allowed_methods: Iterable[str] = [ - 'GET', - 'POST', - ], - allowed_headers: Iterable[str] = [ - 'accept', - 'accept-language', - 'authorization', - 'cache-control', - 'content-encoding', - 'content-language', - 'content-length', - 'content-security-policy', - 'content-type', - 'cookie', - 'host', - 'location', - 'referer', - 'referrer-policy', - 'set-cookie', - 'user-agent', - 'www-authenticate', - 'x-frame-options', - 'x-xss-protection', - ], - exposed_headers: Iterable[str] = [ - 'authorization', - 'cache-control', - 'content-type', - 'cookie', - 'set-cookie', - 'www-authenticate', - ], -) -> FastAPI : - app = FastAPI() - app.add_middleware(ExceptionMiddleware, handlers={ Exception: jsonErrorHandler }, debug=False) - app.add_exception_handler(BaseError, jsonErrorHandler) - - allowed_protocols = ['http', 'https'] if environment.is_local() else ['https'] - - if custom_headers : - from ..server.middleware import CustomHeaderMiddleware, HeadersToSet - exposed_headers = list(exposed_headers) + list(HeadersToSet.keys()) - app.middleware('http')(CustomHeaderMiddleware) - - if cors : - from ..server.middleware.cors import KhCorsMiddleware - app.add_middleware( - KhCorsMiddleware, - allowed_origins = set(allowed_origins), - allowed_protocols = set(allowed_protocols), - allowed_headers = list(allowed_headers), - allowed_methods = list(allowed_methods), - exposed_headers = list(exposed_headers), - max_age = max_age, - ) - - if allowed_hosts : - from starlette.middleware.trustedhost import TrustedHostMiddleware - app.add_middleware(TrustedHostMiddleware, allowed_hosts=list(allowed_hosts)) - - if auth : - from ..server.middleware.auth import KhAuthMiddleware - app.add_middleware(KhAuthMiddleware, required=auth_required) - - return app diff --git a/shared/sql/__init__.py b/shared/sql/__init__.py index a67d0d3..ead297f 100644 --- a/shared/sql/__init__.py +++ b/shared/sql/__init__.py @@ -1,21 +1,23 @@ from dataclasses import dataclass from dataclasses import field as dataclass_field -from enum import Enum +from enum import Enum, IntEnum from functools import lru_cache, partial from re import compile from types import TracebackType -from typing import Any, Awaitable, Callable, Dict, List, Optional, Protocol, Self, Tuple, Type, Union +from typing import Any, Awaitable, Callable, Optional, Protocol, Self, Union +from uuid import UUID from psycopg import AsyncClientCursor, AsyncConnection, AsyncCursor, Binary, OperationalError from psycopg_pool import AsyncConnectionPool from pydantic import BaseModel from pydantic.fields import ModelField +from ..config.constants import environment from ..config.credentials import fetch -from ..logging import Logger, getLogger +from ..logging import DEBUG, INFO, Logger, getLogger from ..models import PostId from ..timing import timed -from .query import Field, Insert, Operator, Query, Table, Update, Value, Where +from .query import Field, Insert, Operator, Order, Query, Table, Update, Value, Where _orm_regex = compile(r'orm:"([^\n]*?)(? None : - self.logger: Logger = getLogger() + def __init__(self: Self, long_query_metric: float = 1, conversions: dict[type, Callable] = { }) -> None : + self.logger: Logger = getLogger(level=DEBUG if environment.is_local() else INFO) self._long_query = long_query_metric - self._conversions: Dict[type, Callable] = { + self._conversions: dict[type, Callable] = { tuple: list, bytes: Binary, + IntEnum: lambda x : x.value, Enum: lambda x : x.name, + UUID: str, PostId: PostId.int, **conversions, } + if getattr(SqlInterface, 'db', None) is None : + SqlInterface.db = fetch('db', dict[str, str]) + async def open(self: Self) : if getattr(SqlInterface, 'pool', None) is None : @@ -116,6 +123,7 @@ async def query_async( sql, params = sql.build() params = tuple(map(self._convert_item, params)) + self.logger.debug({ 'sql': sql, 'params': params }) for _ in range(attempts) : async with SqlInterface.pool.connection() as conn : @@ -161,7 +169,7 @@ async def close(self: Self) -> int : @staticmethod - def _table_name(model: Union[BaseModel, Type[BaseModel]]) -> Table : + def _table_name(model: Union[BaseModel, type[BaseModel]]) -> Table : if not hasattr(model, '__table_name__') : raise AttributeError('model must be defined with the __table_name__ attribute') @@ -225,11 +233,11 @@ async def insert[T: BaseModel](self: Self, model: T, query: Optional[AwaitableQu map[subtype.field:column,field:column2] - maps a subtype's field to columns. separate nested fields by periods. """ table: Table = self._table_name(model) - d: Dict[str, Any] = model.dict() - paths: List[Tuple[str, ...]] = [] - vals: List[Value] = [] - cols: List[str] = [] - ret: List[str] = [] + d: dict[str, Any] = model.dict() + paths: list[tuple[str, ...]] = [] + vals: list[Value] = [] + cols: list[str] = [] + ret: list[str] = [] for key, field in model.__fields__.items() : attrs = SqlInterface._orm_attr_parser(field) @@ -279,7 +287,7 @@ async def insert[T: BaseModel](self: Self, model: T, query: Optional[AwaitableQu query = partial(self.query_async, commit=True) assert query - data: Tuple[Any, ...] = await query(sql, fetch_one=bool(ret)) + data: tuple[Any, ...] = await query(sql, fetch_one=bool(ret)) for i, path in enumerate(paths) : v2 = d @@ -306,7 +314,7 @@ def _read_convert(value: Any) -> Any : @staticmethod - def _assign_field_values[T: BaseModel](model: Type[T], data: Tuple[Any, ...]) -> T : + def _assign_field_values[T: BaseModel](model: type[T], data: tuple[Any, ...]) -> T : i = 0 d: dict = { } for key, field in model.__fields__.items() : @@ -386,7 +394,7 @@ async def select[T: BaseModel](self: Self, model: T, query: Optional[AwaitableQu query = partial(self.query_async, commit=False) assert query - data: Tuple[Any, ...] = await query(sql, fetch_one=True) + data: tuple[Any, ...] = await query(sql, fetch_one=True) if not data : raise KeyError('value does not exist in database') @@ -406,11 +414,11 @@ async def update[T: BaseModel](self: Self, model: T, query: Optional[AwaitableQu """ table: Table = self._table_name(model) sql: Query = Query(table) - d: Dict[str, Any] = model.dict() - paths: List[Tuple[str, ...]] = [] - vals: List[Value] = [] - cols: List[str] = [] - ret: List[str] = [] + d: dict[str, Any] = model.dict() + paths: list[tuple[str, ...]] = [] + vals: list[Value] = [] + cols: list[str] = [] + ret: list[str] = [] _, t = str(table).rsplit('.', 1) pk = 0 @@ -468,7 +476,7 @@ async def update[T: BaseModel](self: Self, model: T, query: Optional[AwaitableQu query = partial(self.query_async, commit=True) assert query - data: Tuple[Any, ...] = await query(sql, fetch_one=bool(ret)) + data: tuple[Any, ...] = await query(sql, fetch_one=bool(ret)) for i, path in enumerate(paths) : v2 = d @@ -520,11 +528,14 @@ async def delete(self: Self, model: BaseModel, query: Optional[AwaitableQuery] = await query(sql) - async def where[T: BaseModel](self: Self, model: Type[T], *where: Where, query: Optional[AwaitableQuery] = None) -> List[T] : + async def where[T: BaseModel](self: Self, model: type[T], *where: Where, order: list[tuple[Field, Order]] = [], query: Optional[AwaitableQuery] = None) -> list[T] : table = self._table_name(model) sql = Query(table).where(*where) _, t = str(table).rsplit('.', 1) + for o in order : + sql.order(*o) + for _, field in model.__fields__.items() : attrs = SqlInterface._orm_attr_parser(field) if attrs.ignore : @@ -541,7 +552,7 @@ async def where[T: BaseModel](self: Self, model: Type[T], *where: Where, query: query = partial(self.query_async, commit=False) assert query - data: List[Tuple[Any, ...]] = await query(sql, fetch_all=True) + data: list[tuple[Any, ...]] = await query(sql, fetch_all=True) return [SqlInterface._assign_field_values(model, row) for row in data] @@ -572,7 +583,11 @@ async def __aenter__(self: Self) : return self - async def __aexit__(self: Self, exc_type: Optional[Type[BaseException]], exc_obj: Optional[BaseException], exc_tb: Optional[TracebackType]) : + async def __await__(self: Self) : + pass + + + async def __aexit__(self: Self, exc_type: Optional[type[BaseException]], exc_obj: Optional[BaseException], exc_tb: Optional[TracebackType]) : if not self.nested : if self.conn : await self.conn.__aexit__(exc_type, exc_obj, exc_tb) @@ -609,6 +624,7 @@ async def query_async( assert self.conn params = tuple(map(self._sql._convert_item, params)) + self._sql.logger.debug({ 'sql': sql, 'params': params }) try : # TODO: convert fuzzly's Query implementation into a psycopg composable diff --git a/shared/timing/__init__.py b/shared/timing/__init__.py index 3345b32..a690c0d 100644 --- a/shared/timing/__init__.py +++ b/shared/timing/__init__.py @@ -54,6 +54,34 @@ def __repr__(self: Self) -> str : ')' ) + @staticmethod + def parse(json: dict[str, dict[str, int | float | dict]]) -> 'Execution' : + assert len(json) == 1 + k, v = json.items().__iter__().__next__() + return Execution( + name = k, + )._parse( + v, + ) + + def _parse(self: Self, json: dict[str, int | float | dict]) -> 'Execution' : + for k, v in json.items() : + match v : + case float() : + self.total = Time(v) + + case int() : + self.count = v + + case _ : + self.nested[k] = Execution( + name = k, + )._parse( + v, + ) + + return self + def record(self: Self, time: float) : self.total = Time(self.total + time) self.count += 1 diff --git a/shared/utilities/__init__.py b/shared/utilities/__init__.py index 4e5b95c..d5b6984 100644 --- a/shared/utilities/__init__.py +++ b/shared/utilities/__init__.py @@ -1,9 +1,14 @@ +from asyncio import Task, create_task from collections import OrderedDict +from contextvars import Context from math import ceil from time import time -from typing import Any, Callable, Generator, Hashable, Iterable, Optional, Tuple, Type, TypeVar +from types import CoroutineType +from typing import Any, Callable, Generator, Hashable, Iterable, Optional, Tuple, Type +from uuid import UUID from pydantic import parse_obj_as +from uuid_extensions import uuid7 as _uuid7 def __clear_cache__(cache: OrderedDict[Hashable, Tuple[float, Any]], t: Callable[[], float] = time) -> None : @@ -73,3 +78,25 @@ def coerse[T](obj: Any, type: Type[T]) -> T : :raises: pydantic.ValidationError on failure """ return parse_obj_as(type, obj) + + +def uuid7() -> UUID : + guid = _uuid7() + assert isinstance(guid, UUID) + return guid + + +background_tasks: set[Task] = set() +def ensure_future[T](fut: CoroutineType[Any, Any, T], name: str | None = None, context: Context | None = None) -> Task[T] : + """ + `utilities.ensure_future` differs from `asyncio.ensure_future` in that this utility function stores a strong + reference to the created task so that it will not get garbage collected before completion. + + `utilities.ensure_future` should be used whenever a task needs to be completed, but not within the context of + a request. Otherwise, `asyncio.create_task` should be used. + """ + # from https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task + + background_tasks.add(task := create_task(fut, name=name, context=context)) + task.add_done_callback(background_tasks.discard) + return task diff --git a/shared/utilities/json.py b/shared/utilities/json.py index 1ac1af7..7d52d63 100644 --- a/shared/utilities/json.py +++ b/shared/utilities/json.py @@ -1,23 +1,21 @@ -from datetime import datetime +from datetime import datetime, timedelta from decimal import Decimal from enum import Enum, IntEnum +from hashlib import sha1 from traceback import format_tb -from typing import Any, Callable +from typing import Any, Callable, TypeVar from uuid import UUID from pydantic import BaseModel from ..base64 import b64decode -from ..crc import CRC -from ..models.auth import AuthToken, KhUser +from ..models.auth import AuthToken, _KhUser from . import getFullyQualifiedClassName -crc = CRC(32) - - -_conversions: dict[type, Callable] = { +_conversions: dict[type[T := TypeVar('T')], Callable[[T], Any]] = { datetime: str, + timedelta: timedelta.total_seconds, Decimal: float, float: float, int: int, @@ -31,7 +29,7 @@ IntEnum: lambda x : x.name, Enum: lambda x : x.name, UUID: lambda x : x.hex, - KhUser: lambda x : { + _KhUser: lambda x : { 'user_id': x.user_id, 'scope': json_stream(x.scope), 'token': json_stream(x.token) if x.token else None, @@ -44,7 +42,7 @@ 'token': { 'len': len(x.token_string), 'version': int(b64decode(x.token_string[:x.token_string.find('.')]).decode()), - 'hash': f'{crc(x.token_string.encode()):x}', + 'hash': sha1(x.token_string.encode()).hexdigest(), }, }, BaseModel: lambda x : json_stream(x.dict()), diff --git a/tags/repository.py b/tags/repository.py index ca9c6e5..68a7af2 100644 --- a/tags/repository.py +++ b/tags/repository.py @@ -12,7 +12,7 @@ from shared.sql import SqlInterface from shared.timing import timed from shared.utilities import flatten -from users.repository import Users +from users.repository import Repository as Users from .models import InternalTag, Tag, TagGroup, TagGroups, TagPortable @@ -23,7 +23,7 @@ users = Users() -class Tags(SqlInterface) : +class Repository(SqlInterface) : # TODO: figure out a way that we can increase this TTL (updating inheritance won't be reflected in cache) @timed @@ -138,11 +138,15 @@ async def _get_tag_counts(self, tags: Iterable[str]) -> dict[str, int] : """ counts = await CountKVS.get_many_async(tags) + found: dict[str, int] = { } for k, v in counts.items() : - if v is Undefined : - counts[k] = await self._populate_tag_cache(k) + if isinstance(v, int) : + found[k] = v - return counts + else : + found[k] = await self._populate_tag_cache(k) + + return found async def _get_tag_count(self, tag: str) -> int : diff --git a/tags/router.py b/tags/router.py index d1ecb94..dd3ec28 100644 --- a/tags/router.py +++ b/tags/router.py @@ -3,7 +3,7 @@ from posts.models import PostId from shared.auth import Scope from shared.exceptions.http_error import Forbidden -from shared.server import Request +from shared.models.server import Request from shared.timing import timed from .models import InheritRequest, LookupRequest, RemoveInheritance, Tag, TagGroups, TagsRequest, UpdateRequest diff --git a/tags/tagger.py b/tags/tagger.py index bb3be62..ca85c90 100644 --- a/tags/tagger.py +++ b/tags/tagger.py @@ -6,7 +6,7 @@ from psycopg.errors import NotNullViolation, UniqueViolation from posts.models import InternalPost, PostId, Privacy -from posts.repository import Posts +from posts.repository import Repository as Posts from shared.auth import KhUser, Scope from shared.caching import AerospikeCache, SimpleCache from shared.exceptions.http_error import BadRequest, Conflict, Forbidden, HttpErrorHandler, NotFound @@ -16,14 +16,14 @@ from shared.utilities import flatten from .models import InternalTag, Tag, TagGroup, TagGroups -from .repository import TagKVS, Tags, users +from .repository import Repository, TagKVS, users posts = Posts() Misc: TagGroup = TagGroup('misc') -class Tagger(Tags) : +class Tagger(Repository) : def _validateDescription(self, description: str) : if len(description) > 1000 : diff --git a/users/repository.py b/users/repository.py index 5e7c2ea..1ce3f1e 100644 --- a/users/repository.py +++ b/users/repository.py @@ -1,6 +1,8 @@ from asyncio import ensure_future +from collections import defaultdict +from curses.panel import bottom_panel from datetime import datetime -from typing import Iterable, Optional, Self, Union +from typing import Iterable, Mapping, Optional, Self, Union from cache import AsyncLRU @@ -9,12 +11,13 @@ from shared.caching.key_value_store import KeyValueStore from shared.exceptions.http_error import BadRequest, NotFound from shared.maps import privacy_map -from shared.models import Badge, InternalUser, Privacy, Undefined, User, UserPortable, UserPrivacy, Verified +from shared.models import Badge, InternalUser, PostId, Privacy, User, UserPortable, UserPrivacy, Verified from shared.sql import SqlInterface from shared.timing import timed UserKVS: KeyValueStore = KeyValueStore('kheina', 'users', local_TTL=60) +handleKVS: KeyValueStore = KeyValueStore('kheina', 'user_handle_map', local_TTL=60) FollowKVS: KeyValueStore = KeyValueStore('kheina', 'following') @@ -95,7 +98,7 @@ async def get_id(self: Self, key: Badge) -> int : badge_map: BadgeMap = BadgeMap() -class Users(SqlInterface) : +class Repository(SqlInterface) : def _clean_text(self: Self, text: str) -> Optional[str] : text = text.strip() @@ -200,19 +203,34 @@ async def _get_users(self, user_ids: Iterable[int]) -> dict[int, InternalUser] : return { } cached = await UserKVS.get_many_async(user_ids) + found: dict[int, InternalUser] = { } misses: list[int] = [] - for k, v in list(cached.items()) : - if v is not Undefined : + for k, v in cached.items() : + if isinstance(v, InternalUser) : + found[k] = v continue misses.append(k) - del cached[k] if not misses : - return cached - - data: list[tuple] = await self.query_async(""" + return found + + data: list[tuple[ + int, + str, + str, + int, + PostId | None, + str | None, + datetime, + str | None, + PostId | None, + bool, + bool, + bool, + list[int], + ]] = await self.query_async(""" SELECT users.user_id, users.display_name, @@ -236,13 +254,13 @@ async def _get_users(self, user_ids: Iterable[int]) -> dict[int, InternalUser] : """, ( misses, ), - fetch_all=True, + fetch_all = True, ) if not data : - raise NotFound('not all users could be found.', user_ids=user_ids, misses=misses, cached=cached, data=data) + raise NotFound('not all users could be found.', user_ids=user_ids, misses=misses, found=found, data=data) - users: dict[int, InternalUser] = cached + users: dict[int, InternalUser] = found for datum in data : verified: Optional[Verified] = None @@ -276,7 +294,8 @@ async def _get_users(self, user_ids: Iterable[int]) -> dict[int, InternalUser] : return users - @AerospikeCache('kheina', 'user_handle_map', '{handle}', local_TTL=60) + @timed + @AerospikeCache('kheina', 'user_handle_map', '{handle}', local_TTL=60, _kvs=handleKVS) async def _handle_to_user_id(self: Self, handle: str) -> int : data = await self.query_async(""" SELECT @@ -286,7 +305,7 @@ async def _handle_to_user_id(self: Self, handle: str) -> int : """, ( handle.lower(), ), - fetch_one=True, + fetch_one = True, ) if not data : @@ -295,6 +314,46 @@ async def _handle_to_user_id(self: Self, handle: str) -> int : return data[0] + @timed + async def _handles_to_user_ids(self: Self, handles: Iterable[str]) -> dict[str, int] : + handles = list(handles) + + if not handles : + return { } + + cached = await handleKVS.get_many_async(handles) + found: dict[str, int] = { } + misses: list[str] = [] + + for k, v in cached.items() : + if isinstance(v, int) : + found[k] = v + continue + + misses.append(k) + + if not misses : + return found + + data: list[tuple[str, int]] = await self.query_async(""" + SELECT + users.handle, + users.user_id + FROM kheina.public.users + WHERE users.handle = any(%s); + """, ( + misses, + ), + fetch_all = True, + ) + + for datum in data : + found[datum[0]] = datum[1] + ensure_future(handleKVS.put_async(datum[0], datum[1])) + + return found + + async def _get_user_by_handle(self: Self, handle: str) -> InternalUser : user_id: int = await self._handle_to_user_id(handle.lower()) return await self._get_user(user_id) @@ -337,17 +396,18 @@ async def following_many(self: Self, user_id: int, targets: list[int]) -> dict[i int(k[k.rfind('|') + 1:]): v for k, v in (await FollowKVS.get_many_async([f'{user_id}|{t}' for t in targets])).items() } + found: dict[int, bool] = { } misses: list[int] = [] - for k, v in list(cached.items()) : - if v is not Undefined : + for k, v in cached.items() : + if isinstance(v, bool) : + found[k] = v continue misses.append(k) - cached[k] = None if not misses : - return cached + return found data: list[tuple[int, int]] = await self.query_async(""" SELECT following.follows, count(1) @@ -359,10 +419,10 @@ async def following_many(self: Self, user_id: int, targets: list[int]) -> dict[i user_id, misses, ), - fetch_all=True, + fetch_all = True, ) - return_value: dict[int, bool] = cached + return_value: dict[int, bool] = found for target, following in data : following = bool(following) @@ -411,7 +471,18 @@ async def portables(self: Self, user: KhUser, iusers: Iterable[InternalUser]) -> returns a map of user id -> UserPortable """ - following = await self.following_many(user.user_id, [iuser.user_id for iuser in iusers]) + iusers = list(iusers) + if not iusers : + return { } + + following: Mapping[int, Optional[bool]] + + if await user.authenticated(False) : + following = await self.following_many(user.user_id, [iuser.user_id for iuser in iusers]) + + else : + following = defaultdict(lambda : None) + return { iuser.user_id: UserPortable( name = iuser.name, diff --git a/users/router.py b/users/router.py index a9d187f..1e06599 100644 --- a/users/router.py +++ b/users/router.py @@ -5,8 +5,8 @@ from shared.exceptions.http_error import HttpErrorHandler from shared.models import Badge, User from shared.models.auth import Scope +from shared.models.server import Request from shared.models.user import SetMod, SetVerified, UpdateSelf -from shared.server import Request from shared.timing import timed from .users import Users @@ -22,14 +22,6 @@ users: Users = Users() -################################################## INTERNAL ################################################## -# @app.get('/i1/{user_id}', response_model=InternalUser) -# async def i1User(req: Request, user_id: int) : -# await req.user.verify_scope(Scope.internal) -# return await users._get_user(user_id) - - -################################################## PUBLIC ################################################## @userRouter.get('/self', response_model=User) @timed.root @HttpErrorHandler("retrieving user's own profile") diff --git a/users/users.py b/users/users.py index c0c1a41..46bd743 100644 --- a/users/users.py +++ b/users/users.py @@ -1,25 +1,38 @@ -from asyncio import Task, ensure_future -from typing import List, Optional +from asyncio import Task, create_task +from typing import List, Optional, Self +from notifications.models import InternalUserNotification, UserNotificationEvent +from notifications.repository import Notifier from shared.auth import KhUser from shared.caching import SimpleCache from shared.exceptions.http_error import BadRequest, HttpErrorHandler, NotFound from shared.models import Badge, InternalUser, User, UserPrivacy, Verified +from shared.models._shared import UserPortable +from shared.timing import timed +from shared.utilities import ensure_future -from .repository import FollowKVS, UserKVS, Users, badge_map, privacy_map # type: ignore +from .repository import FollowKVS, Repository, UserKVS, badge_map, privacy_map -class Users(Users) : +notifier: Notifier = Notifier() + + +class Users(Repository) : @HttpErrorHandler('retrieving user') - async def getUser(self: 'Users', user: KhUser, handle: str) -> User : + async def getUser(self: Self, user: KhUser, handle: str) -> User : iuser: InternalUser = await self._get_user_by_handle(handle) return await self.user(user, iuser) - async def followUser(self: 'Users', user: KhUser, handle: str) -> None : + @timed + async def followUser(self: Self, user: KhUser, handle: str) -> None : user_id: int = await self._handle_to_user_id(handle.lower()) following: bool = await self.following(user.user_id, user_id) + portable: Task[UserPortable] = create_task(self.portable( + KhUser(user_id=user_id), + await self._get_user(user.user_id), + )) if following : raise BadRequest('you are already following this user.') @@ -34,15 +47,22 @@ async def followUser(self: 'Users', user: KhUser, handle: str) -> None : commit=True, ) - FollowKVS.put(f'{user.user_id}|{user_id}', True) - - - async def unfollowUser(self: 'Users', user: KhUser, handle: str) -> None : + ensure_future(FollowKVS.put_async(f'{user.user_id}|{user_id}', True)) + ensure_future(notifier.sendNotification( + user_id, + InternalUserNotification( + event = UserNotificationEvent.follow, + user_id = user.user_id, + ), + user = await portable, + ), name = 'sending notifications') + + async def unfollowUser(self: Self, user: KhUser, handle: str) -> None : user_id: int = await self._handle_to_user_id(handle.lower()) following: bool = await self.following(user.user_id, user_id) if following is False : - raise BadRequest('you are already not following this user.') + raise BadRequest('you are not currently following this user.') await self.query_async(""" DELETE FROM kheina.public.following @@ -53,16 +73,16 @@ async def unfollowUser(self: 'Users', user: KhUser, handle: str) -> None : commit=True, ) - FollowKVS.put(f'{user.user_id}|{user_id}', False) + ensure_future(FollowKVS.put_async(f'{user.user_id}|{user_id}', False)) - async def getSelf(self: 'Users', user: KhUser) -> User : + async def getSelf(self: Self, user: KhUser) -> User : iuser: InternalUser = await self._get_user(user.user_id) return await self.user(user, iuser) @HttpErrorHandler('updating user profile') - async def updateSelf(self: 'Users', user: KhUser, name: Optional[str], privacy: Optional[UserPrivacy], website: Optional[str], description: Optional[str]) -> None : + async def updateSelf(self: Self, user: KhUser, name: Optional[str], privacy: Optional[UserPrivacy], website: Optional[str], description: Optional[str]) -> None : iuser: InternalUser = await self._get_user(user.user_id) if not any([name, privacy, website, description]) : @@ -87,7 +107,7 @@ async def updateSelf(self: 'Users', user: KhUser, name: Optional[str], privacy: @HttpErrorHandler('fetching all users') - async def getUsers(self: 'Users', user: KhUser) : + async def getUsers(self: Self, user: KhUser) : # TODO: this function desperately needs to be reworked data = await self.query_async(""" SELECT @@ -145,9 +165,9 @@ async def getUsers(self: 'Users', user: KhUser) : @HttpErrorHandler('setting mod') - async def setMod(self: 'Users', handle: str, mod: bool) -> None : + async def setMod(self: Self, handle: str, mod: bool) -> None : user_id: int = await self._handle_to_user_id(handle.lower()) - user_task: Task[InternalUser] = ensure_future(self._get_user(user_id)) + user_task: Task[InternalUser] = create_task(self._get_user(user_id)) await self.query_async(""" UPDATE kheina.public.users @@ -165,13 +185,13 @@ async def setMod(self: 'Users', handle: str, mod: bool) -> None : @SimpleCache(60) - async def fetchBadges(self: 'Users') -> List[Badge] : + async def fetchBadges(self: Self) -> List[Badge] : return await badge_map.all() @HttpErrorHandler('adding badge to self') - async def addBadge(self: 'Users', user: KhUser, badge: Badge) -> None : - iuser_task: Task[InternalUser] = ensure_future(self._get_user(user.user_id)) + async def addBadge(self: Self, user: KhUser, badge: Badge) -> None : + iuser_task: Task[InternalUser] = create_task(self._get_user(user.user_id)) try : badge_id: int = await badge_map.get_id(badge) @@ -198,8 +218,8 @@ async def addBadge(self: 'Users', user: KhUser, badge: Badge) -> None : @HttpErrorHandler('removing badge from self') - async def removeBadge(self: 'Users', user: KhUser, badge: Badge) -> None : - iuser_task: Task[InternalUser] = ensure_future(self._get_user(user.user_id)) + async def removeBadge(self: Self, user: KhUser, badge: Badge) -> None : + iuser_task: Task[InternalUser] = create_task(self._get_user(user.user_id)) try : badge_id: int = await badge_map.get_id(badge) @@ -227,7 +247,7 @@ async def removeBadge(self: 'Users', user: KhUser, badge: Badge) -> None : @HttpErrorHandler('creating badge') - async def createBadge(self: 'Users', badge: Badge) -> None : + async def createBadge(self: Self, badge: Badge) -> None : await self.query_async(""" INSERT INTO kheina.public.badges (emoji, label) @@ -240,9 +260,9 @@ async def createBadge(self: 'Users', badge: Badge) -> None : @HttpErrorHandler('verifying user') - async def verifyUser(self: 'Users', handle: str, verified: Verified) -> None : + async def verifyUser(self: Self, handle: str, verified: Verified) -> None : user_id: int = await self._handle_to_user_id(handle.lower()) - user_task: Task[InternalUser] = ensure_future(self._get_user(user_id)) + user_task: Task[InternalUser] = create_task(self._get_user(user_id)) await self.query_async(f""" UPDATE kheina.public.users From 3d5713c07d2cc3fff8466a7230184c258ad27e63 Mon Sep 17 00:00:00 2001 From: dani <29378233+DanielleMiu@users.noreply.github.com> Date: Mon, 14 Apr 2025 19:38:30 -0400 Subject: [PATCH 2/5] things be workin --- .dockerignore | 2 + Dockerfile | 82 +++++----- authenticator/authenticator.py | 20 +-- configs/configs.py | 28 ++-- configs/models.py | 9 +- configs/router.py | 16 +- db/10/00-add-notification-subscriptions.sql | 23 ++- db/11/00-create-user-configs-table.sql | 2 +- docker-exec.sh | 3 + notifications/models.py | 3 + notifications/notifications.py | 98 +++++++++++ notifications/repository.py | 147 ++++++++++------- notifications/router.py | 25 ++- posts/blocking.py | 9 +- posts/repository.py | 171 ++++++++++++-------- reporting/models/__init__.py | 1 - reporting/repository.py | 4 +- reporting/router.py | 3 +- shared/auth/__init__.py | 32 +++- shared/caching/__init__.py | 16 +- shared/caching/key_value_store.py | 3 +- shared/models/auth.py | 35 ++-- shared/models/config.py | 2 +- shared/sql/__init__.py | 18 ++- users/repository.py | 6 +- users/router.py | 12 +- users/users.py | 92 ++++++----- 27 files changed, 558 insertions(+), 304 deletions(-) create mode 100644 docker-exec.sh create mode 100644 notifications/notifications.py diff --git a/.dockerignore b/.dockerignore index 15a1166..ac1bde8 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,4 +1,6 @@ **/__pycache__ +.github +.pytest_cache .venv credentials db diff --git a/Dockerfile b/Dockerfile index c1d723d..9b22f53 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,45 +1,45 @@ -FROM python:3.12 -# FROM python:3.12-alpine +# FROM python:3.13-slim +FROM python:3.13-alpine # can't install aerospike==8.0.0 on alpine -# RUN apk update && \ -# apk upgrade && \ -# apk add --no-cache git && \ -# apk add --no-cache --virtual \ -# .build-deps \ -# gcc \ -# g++ \ -# musl-dev \ -# libffi-dev \ -# postgresql-dev \ -# build-base \ -# bash linux-headers \ -# libuv libuv-dev \ -# openssl openssl-dev \ -# lua5.1 lua5.1-dev \ -# zlib zlib-dev \ -# python3-dev \ -# exiftool - -ENV DEBIAN_FRONTEND=noninteractive - -RUN apt update && \ - apt upgrade -y && \ - apt install -y \ - build-essential \ - libssl-dev \ +RUN apk update && \ + apk upgrade && \ + apk add --no-cache git && \ + apk add --no-cache --virtual \ + .build-deps \ + gcc \ + g++ \ + musl-dev \ libffi-dev \ - git \ - jq \ - libpq-dev \ + postgresql-dev \ + build-base \ + bash linux-headers \ + libuv libuv-dev \ + openssl openssl-dev \ + lua5.1 lua5.1-dev \ + zlib zlib-dev \ python3-dev \ - libpng-dev \ - libjpeg-dev \ - libtiff-dev \ - libwebp-dev \ - imagemagick \ - libimage-exiftool-perl \ - ffmpeg + exiftool + +# ENV DEBIAN_FRONTEND=noninteractive + +# RUN apt update && \ +# apt upgrade -y && \ +# apt install -y \ +# build-essential \ +# libssl-dev \ +# libffi-dev \ +# git \ +# jq \ +# libpq-dev \ +# python3-dev \ +# libpng-dev \ +# libjpeg-dev \ +# libtiff-dev \ +# libwebp-dev \ +# imagemagick \ +# libimage-exiftool-perl \ +# ffmpeg RUN rm -rf /var/lib/apt/lists/* @@ -54,6 +54,8 @@ RUN wget https://go.dev/dl/go1.22.5.linux-amd64.tar.gz && \ rm go1.22.5.linux-amd64.tar.gz ENV GOROOT=/usr/local/go +# redefine $HOME to the default so that it stops complaining +ENV HOME="/root" ENV GOPATH=$HOME/go ENV PATH=$GOPATH/bin:$GOROOT/bin:$PATH @@ -67,6 +69,4 @@ ENV PATH="/opt/.venv/bin:$PATH" ENV PORT=80 ENV ENVIRONMENT=DEV -CMD jq '.fullchain' -r /etc/certs/cert.json > fullchain.pem && \ - jq '.privkey' -r /etc/certs/cert.json > privkey.pem && \ - gunicorn -w 2 -k uvicorn.workers.UvicornWorker --certfile fullchain.pem --keyfile privkey.pem -b 0.0.0.0:443 --timeout 1200 server:app +CMD ["docker-exec.sh"] diff --git a/authenticator/authenticator.py b/authenticator/authenticator.py index b596ee0..428e2c3 100644 --- a/authenticator/authenticator.py +++ b/authenticator/authenticator.py @@ -89,6 +89,13 @@ BotLoginSerializer: AvroSerializer = AvroSerializer(BotLogin) BotLoginDeserializer: AvroDeserializer = AvroDeserializer(BotLogin) +token_kvs: KeyValueStore = KeyValueStore('kheina', 'token') +KeyValueStore._client.index_integer_create( # type: ignore + 'kheina', + 'token', + 'user_id', + 'kheina_token_user_id_idx', +) class BotTypeMap(SqlInterface): @@ -130,7 +137,6 @@ async def get_id(self: Self, key: BotType) -> int : class Authenticator(SqlInterface, Hashable) : EmailRegex = re_compile(r'^(?P[A-Z0-9._%+-]+)@(?P[A-Z0-9.-]+\.[A-Z]{2,})$', flags=IGNORECASE) - KVS: KeyValueStore def __init__(self) : Hashable.__init__(self) @@ -151,16 +157,6 @@ def __init__(self) : 'id': 0, } - if not getattr(Authenticator, 'KVS', None) : - Authenticator.KVS = KeyValueStore('kheina', 'token') - # create the index used to query active logins - KeyValueStore._client.index_integer_create( # type: ignore - 'kheina', - 'token', - 'user_id', - 'kheina_token_user_id_idx', - ) - def _validateEmail(self, email: str) -> Dict[str, str] : e = Authenticator.EmailRegex.search(email) @@ -271,7 +267,7 @@ async def generate_token(self, user_id: int, token_data: dict, ttl: Optional[int algorithm = self._token_algorithm, fingerprint = token_data.get('fp', '').encode(), ) - await Authenticator.KVS.put_async( + await token_kvs.put_async( guid.bytes, token_info, ttl or self._token_expires_interval, diff --git a/configs/configs.py b/configs/configs.py index 618e4e4..98d4681 100644 --- a/configs/configs.py +++ b/configs/configs.py @@ -1,4 +1,4 @@ -from asyncio import Task, ensure_future +from asyncio import Task, create_task from collections.abc import Iterable from datetime import datetime from enum import Enum @@ -21,7 +21,7 @@ from shared.timing import timed from users.repository import Repository as Users -from .models import OTP, BannerStore, BlockBehavior, Blocking, BlockingBehavior, ConfigsResponse, ConfigType, CostsStore, CssProperty, Funding, OtpType, Store, Theme, UserConfigKeyFormat, UserConfigResponse, UserConfigType +from .models import OTP, BannerStore, BlockBehavior, Blocking, BlockingBehavior, ConfigsResponse, ConfigType, CostsStore, CssProperty, CssValue, Funding, OtpType, Store, Theme, UserConfigKeyFormat, UserConfigResponse, UserConfigType PatreonClient: PatreonApi = PatreonApi(fetch('creator_access_token', str)) @@ -93,13 +93,13 @@ async def getConfigs(self: Self, configs: Iterable[ConfigType]) -> dict[ConfigTy for k, v in data : config: ConfigType = ConfigType(k) value = found[config] = await self.SerializerTypeMap[config].deserialize(bytes(v)) - ensure_future(KVS.put_async(config, value)) + create_task(KVS.put_async(config, value)) return found async def allConfigs(self: Self) -> ConfigsResponse : - funds = ensure_future(self.getFunding()) + funds = create_task(self.getFunding()) configs = await self.getConfigs([ ConfigType.banner, ConfigType.costs, @@ -141,11 +141,11 @@ async def updateConfig(self: Self, user: KhUser, config: Store) -> None : @staticmethod - def _validateColors(css_properties: Optional[dict[CssProperty, str]]) -> Optional[dict[str, str | int]] : + def _validateColors(css_properties: Optional[dict[CssProperty, str]]) -> Optional[dict[str, CssValue | int | str]] : if not css_properties : return None - output: dict[str, str | int] = { } + output: dict[str, CssValue | int | str] = { } # color input is very strict for prop, value in css_properties.items() : @@ -175,8 +175,8 @@ def _validateColors(css_properties: Optional[dict[CssProperty, str]]) -> Optiona else : c: str = match.group('var').replace('-', '_') - if c in CssProperty._member_map_ : - output[color.value] = c + if c in CssValue._member_map_ : + output[color.value] = CssValue(c) else : raise BadRequest(f'{value} is not a valid color. value must be in the form "#xxxxxx", "#xxxxxxxx", or the name of another color variable (without the preceding deshes)') @@ -258,7 +258,7 @@ async def setUserConfig( ) for store in stores : - ensure_future(KVS.put_async( + create_task(KVS.put_async( UserConfigKeyFormat.format( user_id = user.user_id, key = store.key(), @@ -343,13 +343,11 @@ async def getUserConfig(self: Self, user: KhUser) -> UserConfigResponse : ) res = UserConfigResponse() - otp: Task[list[OTP]] = ensure_future(self._getUserOTP(user.user_id)) - print('==> data:', data) + otp: Task[list[OTP]] = create_task(self._getUserOTP(user.user_id)) if data : for key, type_, value in data : - t = UserConfigType(type_) - - match v := await Configs.SerializerTypeMap[t].deserialize(value) : + t: type[Store] = Configs.SerializerTypeMap[UserConfigType(type_)] + match v := await t.deserialize(value) : case BlockBehavior() : res.blocking_behavior = v.behavior @@ -380,7 +378,7 @@ async def getUserTheme(self: Self, user: KhUser) -> str : if isinstance(value, int) : css_properties += f'--{name}:#{value:08x} !important;' - elif isinstance(value, CssProperty) : + elif isinstance(value, CssValue) : css_properties += f'--{name}:var(--{value.value.replace("_", "-")}) !important;' else : diff --git a/configs/models.py b/configs/models.py index d8289bd..952f340 100644 --- a/configs/models.py +++ b/configs/models.py @@ -6,8 +6,8 @@ from pydantic import BaseModel, Field from shared.models import PostId -from shared.sql.query import Table from shared.models.config import Store +from shared.sql.query import Table UserConfigKeyFormat: Literal['user.{user_id}.{key}'] = 'user.{user_id}.{key}' @@ -70,6 +70,8 @@ class CssProperty(Enum) : background_repeat = 'background_repeat' background_size = 'background_size' + +class CssValue(Enum) : transition = 'transition' fadetime = 'fadetime' warning = 'warning' @@ -136,8 +138,8 @@ def type_() -> UserConfigType : class Theme(Store) : - wallpaper: Optional[PostId] = None - css_properties: Optional[dict[str, CssProperty | AvroInt | str]] = None + wallpaper: Optional[PostId] = None + css_properties: Optional[dict[str, CssValue | AvroInt | str]] = None @staticmethod def type_() -> UserConfigType : @@ -188,5 +190,4 @@ class Config(BaseModel) : created: datetime = Field(description='orm:"gen"') updated: datetime = Field(description='orm:"gen"') updated_by: int - value: Optional[str] = None bytes_: Optional[bytes] = Field(None, description='orm:"col[bytes]"') diff --git a/configs/router.py b/configs/router.py index 8502629..1e68288 100644 --- a/configs/router.py +++ b/configs/router.py @@ -54,11 +54,11 @@ async def v1UserTheme(req: Request) -> PlainTextResponse : ) -@app.patch('', status_code=204) -async def v1UpdateConfig(req: Request, body: UpdateConfigRequest) -> None : - await req.user.verify_scope(Scope.mod) - await configs.updateConfig( - req.user, - body.config, - body.value, - ) +# @app.patch('', status_code=204) +# async def v1UpdateConfig(req: Request, body: UpdateConfigRequest) -> None : +# await req.user.verify_scope(Scope.mod) +# await configs.updateConfig( +# req.user, +# body.config, +# body.value, +# ) diff --git a/db/10/00-add-notification-subscriptions.sql b/db/10/00-add-notification-subscriptions.sql index e5817f9..38e243f 100644 --- a/db/10/00-add-notification-subscriptions.sql +++ b/db/10/00-add-notification-subscriptions.sql @@ -1,5 +1,16 @@ begin; +create or replace function generated_created() returns trigger as +$$ +begin + + new.created = now(); + return new; + +end; +$$ +language plpgsql; + drop table if exists public.subscriptions; create table public.subscriptions ( sub_id uuid unique not null, @@ -19,12 +30,20 @@ create table public.notifications ( on update cascade on delete cascade, type smallint not null, - created timestamptz not null generated always as ('now'::timestamptz) stored, + created timestamptz not null, data bytea not null, primary key (user_id, id) ); -drop function public.register_subscription; +create index if not exists notifications_user_created_idx on public.notifications (user_id, created); + +create or replace trigger generated_created before insert on public.notifications + for each row execute procedure generated_created(); + +create or replace trigger immutable_columns before update on public.notifications + for each row execute procedure public.immutable_columns('id', 'user_id', 'type', 'created', 'data'); + +drop function if exists public.register_subscription; create or replace function public.register_subscription(sid uuid, uid bigint, sinfo bytea) returns void as $$ begin diff --git a/db/11/00-create-user-configs-table.sql b/db/11/00-create-user-configs-table.sql index 2436edc..395f904 100644 --- a/db/11/00-create-user-configs-table.sql +++ b/db/11/00-create-user-configs-table.sql @@ -1,6 +1,6 @@ begin; -alter table public.configs drop column value; +alter table public.configs drop column if exists value; drop table if exists public.user_configs; create table public.user_configs ( diff --git a/docker-exec.sh b/docker-exec.sh new file mode 100644 index 0000000..4a559a8 --- /dev/null +++ b/docker-exec.sh @@ -0,0 +1,3 @@ +jq '.fullchain' -r /etc/certs/cert.json > fullchain.pem && \ +jq '.privkey' -r /etc/certs/cert.json > privkey.pem && \ +gunicorn -w 2 -k uvicorn.workers.UvicornWorker --certfile fullchain.pem --keyfile privkey.pem -b 0.0.0.0:443 --timeout 1200 server:app diff --git a/notifications/models.py b/notifications/models.py index 17a6c7f..ac521ab 100644 --- a/notifications/models.py +++ b/notifications/models.py @@ -91,6 +91,7 @@ class InteractNotification(Notification) : """ an interact notification represents a user taking an action on a post """ + id: UUID type: Literal['interact'] = 'interact' event: InteractNotificationEvent created: datetime @@ -113,6 +114,7 @@ def type_(cls) -> NotificationType : class PostNotification(Notification) : + id: UUID type: Literal['post'] = 'post' event: PostNotificationEvent created: datetime @@ -137,6 +139,7 @@ def type_(cls) -> NotificationType : class UserNotification(Notification) : + id: UUID type: Literal['user'] = 'user' event: UserNotificationEvent created: datetime diff --git a/notifications/notifications.py b/notifications/notifications.py new file mode 100644 index 0000000..def535d --- /dev/null +++ b/notifications/notifications.py @@ -0,0 +1,98 @@ +from asyncio import Task, create_task +from datetime import datetime, timedelta +from typing import Self + +from posts.models import InternalPost, Post +from posts.repository import Repository as Posts +from shared.auth import KhUser +from shared.exceptions.http_error import HttpErrorHandler +from shared.models import InternalUser, PostId, UserPortable +from shared.sql.query import Field, Operator, Order, Value, Where +from users.repository import Repository as Users + +from .models import InteractNotification, InternalInteractNotification, InternalNotification, InternalPostNotification, InternalUserNotification, NotificationType, PostNotification, UserNotification +from .repository import Notifier + + +posts: Posts = Posts() +users: Users = Users() + + +class Notifications(Notifier) : + + @HttpErrorHandler('fetching notifications') + async def fetchNotifications(self: Self, user: KhUser) -> list[InteractNotification | PostNotification | UserNotification] : + data: list[InternalNotification] = await self.where( + InternalNotification, + Where( + Field('notifications', 'user_id'), + Operator.equal, + Value(user.user_id), + ), + order = [ + (Field('notifications', 'created'), Order.ascending), + ], + limit = 100, + ) + + if not data : + return [] + + inotifications: list[tuple[InternalNotification, InternalInteractNotification | InternalPostNotification | InternalUserNotification]] = [] + post_ids: list[PostId] = [] + user_ids: list[int] = [] + for n in data : + match n.type() : + case NotificationType.interact : + inotifications.append((n, notif := await InternalInteractNotification.deserialize(n.data))) + post_ids.append(PostId(notif.post_id)) + user_ids.append(notif.user_id) + + case NotificationType.post : + inotifications.append((n, notif := await InternalPostNotification.deserialize(n.data))) + post_ids.append(PostId(notif.post_id)) + + case NotificationType.user : + inotifications.append((n, notif := await InternalUserNotification.deserialize(n.data))) + user_ids.append(notif.user_id) + + iposts: Task[dict[PostId, InternalPost]] = create_task(posts._get_posts(post_ids)) + iusers: Task[dict[int, InternalUser]] = create_task(users._get_users(user_ids)) + + posts_task: Task[list[Post]] = create_task(posts.posts(user, list((await iposts).values()))) + all_users: dict[int, UserPortable] = await users.portables(user, list((await iusers).values())) + all_posts: dict[PostId, Post] = { + p.post_id: p + for p in await posts_task + } + + notifications: list[InteractNotification | PostNotification | UserNotification] = [] + + for n, i in inotifications : + match i : + case InternalInteractNotification() : + notifications.append(InteractNotification( + id = n.id, + event = i.event, + created = n.created, + user = all_users[i.user_id], + post = all_posts[PostId(i.post_id)], + )) + + case InternalPostNotification() : + notifications.append(PostNotification( + id = n.id, + event = i.event, + created = n.created, + post = all_posts[PostId(i.post_id)], + )) + + case InternalUserNotification() : + notifications.append(UserNotification( + id = n.id, + event = i.event, + created = n.created, + user = all_users[i.user_id], + )) + + return notifications diff --git a/notifications/repository.py b/notifications/repository.py index ec80283..72bfe3c 100644 --- a/notifications/repository.py +++ b/notifications/repository.py @@ -1,10 +1,10 @@ -from typing import Self +from typing import Optional, Self from urllib.parse import urlparse from uuid import UUID +import aerospike import ujson from aiohttp import ClientResponse, ClientSession, ClientTimeout -from cache import AsyncLRU from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePublicKey from py_vapid import Vapid02 @@ -12,18 +12,19 @@ from pywebpush import WebPusher as _WebPusher from configs.models import Config -from posts.models import Post -from shared.auth import KhUser +from posts.models import InternalPost, Post +from shared.auth import KhUser, tokenMetadata from shared.base64 import b64encode from shared.caching import AerospikeCache from shared.caching.key_value_store import KeyValueStore from shared.config.credentials import fetch from shared.datetime import datetime from shared.exceptions.http_error import HttpErrorHandler, InternalServerError -from shared.models import UserPortable +from shared.models import PostId, UserPortable +from shared.models.auth import AuthState, TokenMetadata from shared.models.encryption import Keys from shared.sql import SqlInterface -from shared.sql.query import Field, Operator, Value, Where +from shared.sql.query import Field, Operator, Order, Value, Where from shared.timing import timed from shared.utilities import uuid7 from shared.utilities.json import json_stream @@ -31,6 +32,15 @@ from .models import InteractNotification, InternalInteractNotification, InternalNotification, InternalPostNotification, InternalUserNotification, NotificationType, PostNotification, ServerKey, Subscription, SubscriptionInfo, UserNotification +@timed +async def getTokenMetadata(guid: UUID) -> Optional[TokenMetadata] : + try : + return await tokenMetadata(guid) + + except aerospike.exception.RecordNotFound : + return None + + class WebPusher(_WebPusher) : @timed async def send_async(self, *args, **kwargs) -> ClientResponse | str : @@ -74,6 +84,7 @@ async def startup(self) -> None : Notifier.keys = Keys.load(**fetch('notifications', dict[str, str])) + @timed @AerospikeCache('kheina', 'notifications', 'vapid-config', _kvs=kvs) async def getVapidPem(self: Self) -> bytes : async with self.transaction() as t : @@ -176,10 +187,17 @@ async def vapidHeaders(self: Self, sub_info: SubscriptionInfo) -> dict[str, str] @timed async def _send(self: Self, user_id: int, data: dict) -> int : - subs = await self.getSubInfo(user_id) unregister: list[UUID] = [] successes: int = 0 + subs = await self.getSubInfo(user_id) for sub_id, sub in subs.items() : + # sub_id is the token guid of the token that created the subscription + # check that it's still active before sending the notification + token = await getTokenMetadata(sub_id) + if not token or token.state != AuthState.active : + unregister.append(sub_id) + continue + res = await WebPusher( sub.dict(), verbose = True, @@ -215,13 +233,13 @@ async def _send(self: Self, user_id: int, data: dict) -> int : return successes - @timed + @timed.root async def sendNotification( self: Self, user_id: int, data: InternalInteractNotification | InternalPostNotification | InternalUserNotification, **kwargs: UserPortable | Post, - ) -> int : + ) -> None : """ creates, persists and then sends the given notification to the provided user_id. kwargs must include the user and/or post of the notification's user_id/post_id in the form of @@ -229,58 +247,69 @@ async def sendNotification( await sendNotification(..., user=UserPortable(...), post=Post(...)) ``` """ - type_: NotificationType = data.type_() - inotification = await self.insert(InternalNotification( - id = uuid7(), - user_id = user_id, - type_ = type_, - created = datetime.zero(), - data = await data.serialize(), - )) - - self.logger.debug({ - 'message': 'notification', - 'to': user_id, - 'notification': { - 'type': type(data), - 'type_enm': data.type_(), - **data.dict(), - }, - }) - - match data : - case InternalInteractNotification() : - user, post = kwargs.get('user'), kwargs.get('post') - assert isinstance(user, UserPortable) and isinstance(post, Post) - notification = InteractNotification( - event = data.event, - created = inotification.created, - user = user, - post = post, - ) - return await self._send(user_id, notification.dict()) - - case InternalPostNotification() : - post = kwargs.get('post') - assert isinstance(post, Post) - notification = PostNotification( - event = data.event, - created = inotification.created, - post = post, - ) - return await self._send(user_id, notification.dict()) - - case InternalUserNotification() : - user = kwargs.get('user') - assert isinstance(user, UserPortable) - notification = UserNotification( - event = data.event, - created = inotification.created, - user = user, - ) - return await self._send(user_id, notification.dict()) + try : + inotification = await self.insert(InternalNotification( + id = uuid7(), + user_id = user_id, + type_ = data.type_(), + created = datetime.zero(), + data = await data.serialize(), + )) + + self.logger.debug({ + 'message': 'notification', + 'to': user_id, + 'notification': { + 'type': type(data), + 'type_enm': data.type_(), + **data.dict(), + }, + }) + + match data : + case InternalInteractNotification() : + user, post = kwargs.get('user'), kwargs.get('post') + assert isinstance(user, UserPortable) and isinstance(post, Post), 'interact notifications must include user and post kwargs' + notification = InteractNotification( + id = inotification.id, + event = data.event, + created = inotification.created, + user = user, + post = post, + ) + await self._send(user_id, notification.dict()) + + case InternalPostNotification() : + post = kwargs.get('post') + assert isinstance(post, Post), 'post notifications must include a post kwarg' + notification = PostNotification( + id = inotification.id, + event = data.event, + created = inotification.created, + post = post, + ) + await self._send(user_id, notification.dict()) + + case InternalUserNotification() : + user = kwargs.get('user') + assert isinstance(user, UserPortable), 'user notifications must include a user kwarg' + notification = UserNotification( + id = inotification.id, + event = data.event, + created = inotification.created, + user = user, + ) + await self._send(user_id, notification.dict()) + + except : + # since this function will almost always be run async using ensure_future, handle errors internally + self.logger.exception('failed to send notification') @HttpErrorHandler('sending some random cunt a notif') + @timed async def debugSendNotification(self: Self, user_id: int, data: dict) -> int : return await self._send(user_id, data) + + +notifier: Notifier = Notifier() diff --git a/notifications/router.py b/notifications/router.py index 4bc77f1..1483141 100644 --- a/notifications/router.py +++ b/notifications/router.py @@ -1,14 +1,17 @@ from fastapi import APIRouter +from shared.datetime import datetime from shared.models.auth import Scope from shared.models.server import Request from shared.timing import timed -from .models import ServerKey, SubscriptionInfo -from .repository import Notifier +from .models import InteractNotification, PostNotification, ServerKey, SubscriptionInfo, UserNotification +from .notifications import Notifications + + +notifier = Notifications() -repo = Notifier() notificationsRouter = APIRouter( prefix='/notifications', ) @@ -16,7 +19,7 @@ @notificationsRouter.on_event('startup') async def startup() -> None : - await repo.startup() + await notifier.startup() @notificationsRouter.get('/register', response_model=ServerKey) @@ -26,20 +29,28 @@ async def v1GetServerKey(req: Request) -> ServerKey : only auth required """ await req.user.authenticated() - return await repo.getApplicationServerKey() + return await notifier.getApplicationServerKey() @notificationsRouter.put('/register', response_model=None) @timed.root async def v1RegisterNotificationTarget(req: Request, body: SubscriptionInfo) -> None : await req.user.authenticated() - await repo.registerSubInfo(req.user, body) + await notifier.registerSubInfo(req.user, body) + + +@notificationsRouter.get('') +@timed.root +async def v1GetNotifications(req: Request) -> list[InteractNotification | PostNotification | UserNotification] : + await req.user.authenticated() + return await notifier.fetchNotifications(req.user) @notificationsRouter.post('', status_code=201) +@timed.root async def v1SendThisBitchAVibe(req: Request, body: dict) -> int : await req.user.verify_scope(Scope.admin) - return await repo.debugSendNotification(req.user.user_id, body) + return await notifier.debugSendNotification(req.user.user_id, body) app = APIRouter( diff --git a/posts/blocking.py b/posts/blocking.py index 184e8df..b184284 100644 --- a/posts/blocking.py +++ b/posts/blocking.py @@ -2,9 +2,9 @@ from configs.configs import Configs from configs.models import Blocking +from .models import Rating from shared.auth import KhUser from shared.caching import ArgsCache -from shared.models import InternalUser from shared.timing import timed @@ -112,13 +112,14 @@ async def fetch_block_tree(user: KhUser) -> tuple[BlockTree, Optional[set[int]]] @timed -async def is_post_blocked(user: KhUser, uploader: InternalUser, tags: Iterable[str]) -> bool : +async def is_post_blocked(user: KhUser, uploader: int, rating: Rating, tags: Iterable[str]) -> bool : block_tree, blocked_users = await fetch_block_tree(user) - if blocked_users and uploader.user_id in blocked_users : + if blocked_users and uploader in blocked_users : return True tags: set[str | int] = set(tags) # TODO: convert handles to user_ids (int) - tags.add(uploader.user_id) + tags.add(uploader) + tags.add(f'rating:{rating.name}') return block_tree.blocked(tags) diff --git a/posts/repository.py b/posts/repository.py index 095e3c5..c7e76a0 100644 --- a/posts/repository.py +++ b/posts/repository.py @@ -1,8 +1,6 @@ -from asyncio import Task, ensure_future +from asyncio import Task, create_task from collections import defaultdict -from dataclasses import dataclass -from functools import partial -from typing import Callable, Mapping, Optional, Self, Tuple, Union +from typing import Callable, Iterable, Mapping, Optional, Self from cache import AsyncLRU @@ -12,10 +10,11 @@ from shared.datetime import datetime from shared.exceptions.http_error import BadRequest, NotFound from shared.maps import privacy_map -from shared.models import InternalUser, Undefined, UserPortable +from shared.models import InternalUser, UserPortable from shared.sql import SqlInterface from shared.sql.query import CTE, Field, Join, JoinType, Operator, Order, Query, Table, Value, Where from shared.timing import timed +from shared.utilities import ensure_future from tags.models import InternalTag, Tag, TagGroup from tags.repository import Repository as Tags from tags.repository import TagKVS @@ -38,7 +37,7 @@ class RatingMap(SqlInterface) : @timed @AsyncLRU(maxsize=0) async def get(self, key: int) -> Rating : - data: Tuple[str] = await self.query_async(""" + data: tuple[str] = await self.query_async(""" SELECT rating FROM kheina.public.ratings WHERE ratings.rating_id = %s @@ -55,7 +54,7 @@ async def get(self, key: int) -> Rating : @timed @AsyncLRU(maxsize=0) async def get_id(self, key: str | Rating) -> int : - data: Tuple[int] = await self.query_async(""" + data: tuple[int] = await self.query_async(""" SELECT rating_id FROM kheina.public.ratings WHERE ratings.rating = %s @@ -78,7 +77,7 @@ class MediaTypeMap(SqlInterface) : @timed @AsyncLRU(maxsize=0) async def get(self, key: int) -> MediaType : - data: Tuple[str, str] = await self.query_async(""" + data: tuple[str, str] = await self.query_async(""" SELECT file_type, mime_type FROM kheina.public.media_type WHERE media_type.media_type_id = %s @@ -96,7 +95,7 @@ async def get(self, key: int) -> MediaType : @timed @AsyncLRU(maxsize=0) async def get_id(self, mime: str) -> int : - data: Tuple[int] = await self.query_async(""" + data: tuple[int] = await self.query_async(""" SELECT media_type_id FROM kheina.public.media_type WHERE media_type.mime_type = %s @@ -112,18 +111,12 @@ async def get_id(self, mime: str) -> int : media_type_map: MediaTypeMap = MediaTypeMap() -@dataclass -class UserCombined: - portable: UserPortable - internal: InternalUser - - class Repository(SqlInterface) : def parse_response( self: Self, data: list[ - Tuple[ + tuple[ int, # 0 post_id str, # 1 title str, # 2 description @@ -182,7 +175,7 @@ def parse_response( def internal_select(self: Self, query: Query) -> Callable[[ list[ - Tuple[ + tuple[ int, # 0 post_id str, # 1 title str, # 2 description @@ -298,10 +291,45 @@ async def _get_post(self: Self, post_id: PostId) -> InternalPost : @timed - async def parents(self: Self, user: KhUser, ipost: InternalPost) -> Optional[Post] : - if not ipost.parent : - return None + async def _get_posts(self: Self, post_ids: Iterable[PostId]) -> dict[PostId, InternalPost] : + if not post_ids : + return { } + + cached = await PostKVS.get_many_async(post_ids) + found: dict[PostId, InternalPost] = { } + misses: list[PostId] = [] + + for k, v in cached.items() : + if v is None or isinstance(v, InternalPost) : + found[k] = v + continue + + misses.append(k) + + if not misses : + return found + + posts: dict[PostId, InternalPost] = found + data: list[InternalPost] = await self.where( + InternalPost, + Where( + Field('internal_posts', 'post_id'), + Operator.equal, + Value(misses, functions = ['any']), + ), + ) + + for post in data : + post_id = PostId(post.post_id) + ensure_future(PostKVS.put_async(post_id, post)) + posts[post_id] = post + + return posts + + @timed + @AerospikeCache('kheina', 'posts', 'parents={parent}', _kvs=PostKVS) + async def _parents(self: Self, parent: int) -> list[InternalPost] : cte = Query( Table('post_ids', cte=True), ).cte( @@ -317,7 +345,7 @@ async def parents(self: Self, user: KhUser, ipost: InternalPost) -> Optional[Pos Where( Field('posts', 'post_id'), Operator.equal, - Value(ipost.parent), + Value(parent), ), ).union( Query( @@ -374,9 +402,16 @@ async def parents(self: Self, user: KhUser, ipost: InternalPost) -> Optional[Pos ), ), ) - parser = self.internal_select(query := self.CteQuery(cte)) - iposts = parser(await self.query_async(query, fetch_all=True)) + return parser(await self.query_async(query, fetch_all=True)) + + + @timed + async def parents(self: Self, user: KhUser, ipost: InternalPost) -> Optional[Post] : + if not ipost.parent : + return None + + iposts = await self._parents(ipost.parent) posts = await self.posts(user, iposts) assert len(posts) == 1 return posts[0] @@ -385,22 +420,22 @@ async def parents(self: Self, user: KhUser, ipost: InternalPost) -> Optional[Pos @timed async def post(self: Self, user: KhUser, ipost: InternalPost) -> Post : post_id: PostId = PostId(ipost.post_id) - parent: Task[Optional[Post]] = ensure_future(self.parents(user, ipost)) - upl: Task[InternalUser] = ensure_future(users._get_user(ipost.user_id)) - tags_task: Task[list[InternalTag]] = ensure_future(tagger._fetch_tags_by_post(post_id)) - score: Task[Optional[Score]] = ensure_future(self.getScore(user, post_id)) + parent: Task[Optional[Post]] = create_task(self.parents(user, ipost)) + upl: Task[InternalUser] = create_task(users._get_user(ipost.user_id)) + tags_task: Task[list[InternalTag]] = create_task(tagger._fetch_tags_by_post(post_id)) + score: Task[Optional[Score]] = create_task(self.getScore(user, post_id)) uploader: InternalUser = await upl - upl_portable: Task[UserPortable] = ensure_future(users.portable(user, uploader)) + upl_portable: Task[UserPortable] = create_task(users.portable(user, uploader)) itags: list[InternalTag] = await tags_task - tags: Task[list[Tag]] = ensure_future(tagger.tags(user, itags)) - blocked: Task[bool] = ensure_future(is_post_blocked(user, uploader, (t.name for t in itags))) + tags: Task[list[Tag]] = create_task(tagger.tags(user, itags)) + blocked: Task[bool] = create_task(is_post_blocked(user, ipost.user_id, await rating_map.get(ipost.rating), (t.name for t in itags))) post = Post( post_id = post_id, - title = ipost.title, - description = ipost.description, - user = await upl_portable, + title = None, + description = None, + user = None, score = await score, rating = await rating_map.get(ipost.rating), parent = await parent, @@ -417,6 +452,13 @@ async def post(self: Self, user: KhUser, ipost: InternalPost) -> Post : if ipost.locked : post.locked = True + if not await user.verify_scope(Scope.mod, False) and ipost.user_id != user.user_id : + return post # we don't want any other fields populated + + post.title = ipost.title + post.description = ipost.description + post.user = await upl_portable + if ipost.filename and ipost.media_type and ipost.size and ipost.content_length and ipost.thumbnails : flags: list[MediaFlag] = [] @@ -499,7 +541,7 @@ async def scores_many(self: Self, post_ids: list[PostId]) -> dict[PostId, Option return found scores: dict[PostId, Optional[InternalScore]] = found - data: list[Tuple[int, int, int]] = await self.query_async(""" + data: list[tuple[int, int, int]] = await self.query_async(""" SELECT post_scores.post_id, post_scores.upvotes, @@ -530,7 +572,7 @@ async def scores_many(self: Self, post_ids: list[PostId]) -> dict[PostId, Option @timed @AerospikeCache('kheina', 'votes', '{user_id}|{post_id}', _kvs=VoteKVS) async def _get_vote(self: Self, user_id: int, post_id: PostId) -> int : - data: Optional[Tuple[bool]] = await self.query_async(""" + data: Optional[tuple[bool]] = await self.query_async(""" SELECT upvote FROM kheina.public.post_votes @@ -572,7 +614,7 @@ async def votes_many(self: Self, user_id: int, post_ids: list[PostId]) -> dict[P return found votes: dict[PostId, int] = found - data: list[Tuple[int, int]] = await self.query_async(""" + data: list[tuple[int, int]] = await self.query_async(""" SELECT post_votes.post_id, post_votes.upvote @@ -599,8 +641,8 @@ async def votes_many(self: Self, user_id: int, post_ids: list[PostId]) -> dict[P @timed async def getScore(self: Self, user: KhUser, post_id: PostId) -> Optional[Score] : - score_task: Task[Optional[InternalScore]] = ensure_future(self._get_score(post_id)) - vote: Task[int] = ensure_future(self._get_vote(user.user_id, post_id)) + score_task: Task[Optional[InternalScore]] = create_task(self._get_score(post_id)) + vote: Task[int] = create_task(self._get_vote(user.user_id, post_id)) score = await score_task @@ -683,7 +725,7 @@ async def _vote(self: Self, user: KhUser, post_id: PostId, upvote: Optional[bool ), ) - data: Tuple[int, int, datetime] = await transaction.query_async(""" + data: tuple[int, int, datetime] = await transaction.query_async(""" SELECT COUNT(post_votes.upvote), SUM(post_votes.upvote::int), posts.created FROM kheina.public.posts LEFT JOIN kheina.public.post_votes @@ -748,14 +790,14 @@ async def _vote(self: Self, user: KhUser, post_id: PostId, upvote: Optional[bool @timed - async def _uploaders(self: Self, user: KhUser, iposts: list[InternalPost]) -> dict[int, UserCombined] : + async def _uploaders(self: Self, user: KhUser, iposts: list[InternalPost]) -> dict[int, UserPortable] : """ returns populated user objects for every uploader id provided :return: dict in the form user id -> populated User object """ uploader_ids: list[int] = list(set(map(lambda x : x.user_id, iposts))) - users_task: Task[dict[int, InternalUser]] = ensure_future(users._get_users(uploader_ids)) + users_task: Task[dict[int, InternalUser]] = create_task(users._get_users(uploader_ids)) following: Mapping[int, Optional[bool]] if await user.authenticated(False) : @@ -767,16 +809,14 @@ async def _uploaders(self: Self, user: KhUser, iposts: list[InternalPost]) -> di iusers: dict[int, InternalUser] = await users_task return { - user_id: UserCombined( - internal = iuser, - portable = UserPortable( - name = iuser.name, - handle = iuser.handle, - privacy = users._validate_privacy(await privacy_map.get(iuser.privacy)), - icon = iuser.icon, - verified = iuser.verified, - following = following[user_id], - ), + user_id: + UserPortable( + name = iuser.name, + handle = iuser.handle, + privacy = users._validate_privacy(await privacy_map.get(iuser.privacy)), + icon = iuser.icon, + verified = iuser.verified, + following = following[user_id], ) for user_id, iuser in iusers.items() } @@ -802,7 +842,7 @@ async def _scores(self: Self, user: KhUser, iposts: list[InternalPost]) -> dict[ # but put all of them in the dict scores[post_id] = None - iscores_task: Task[dict[PostId, Optional[InternalScore]]] = ensure_future(self.scores_many(post_ids)) + iscores_task: Task[dict[PostId, Optional[InternalScore]]] = create_task(self.scores_many(post_ids)) user_votes: dict[PostId, int] if await user.authenticated(False) : @@ -833,7 +873,7 @@ async def _tags_many(self: Self, post_ids: list[PostId]) -> dict[PostId, list[In cached = { PostId(k[k.rfind('.') + 1:]): v - for k, v in (await VoteKVS.get_many_async([f'post.{post_id}' for post_id in post_ids])).items() + for k, v in (await TagKVS.get_many_async([f'post.{post_id}' for post_id in post_ids])).items() } found: dict[PostId, list[InternalTag]] = { } misses: list[PostId] = [] @@ -879,8 +919,8 @@ async def _tags_many(self: Self, post_ids: list[PostId]) -> dict[PostId, list[In description = None, # in this case, we don't care about this field )) - for post_id, t in tags.items() : - ensure_future(TagKVS.put_async(f'post.{post_id}', t)) + for post_id in post_ids : + ensure_future(TagKVS.put_async(f'post.{post_id}', tags[post_id])) return tags @@ -892,14 +932,17 @@ async def posts(self: Self, user: KhUser, iposts: list[InternalPost], assign_par assign_parents = True will assign any posts found with a matching parent id to the `parent` field of the resulting Post object assign_parents = False will instead assign these posts to the `replies` field of the resulting Post object """ - uploaders_task: Task[dict[int, UserCombined]] = ensure_future(self._uploaders(user, iposts)) - scores_task: Task[dict[PostId, Optional[Score]]] = ensure_future(self._scores(user, iposts)) + + # TODO: at some point, we should make this even faster by joining the uploaders and tag owners tasks + + uploaders_task: Task[dict[int, UserPortable]] = create_task(self._uploaders(user, iposts)) + scores_task: Task[dict[PostId, Optional[Score]]] = create_task(self._scores(user, iposts)) tags: dict[PostId, list[InternalTag]] = await self._tags_many(list(map(lambda x : PostId(x.post_id), iposts))) - at_task: Task[list[Tag]] = ensure_future(tagger.tags(user, [t for l in tags.values() for t in l])) - uploaders: dict[int, UserCombined] = await uploaders_task + at_task: Task[list[Tag]] = create_task(tagger.tags(user, [t for l in tags.values() for t in l])) + uploaders: dict[int, UserPortable] = await uploaders_task scores: dict[PostId, Optional[Score]] = await scores_task - all_tags: dict[str, Tag] = { + all_tags: dict[str, Tag] = { tag.tag: tag for tag in await at_task } @@ -940,7 +983,7 @@ async def posts(self: Self, user: KhUser, iposts: list[InternalPost], assign_par parent_id = parent_id, # only the first call retrieves blocked info, all the rest should be cached and not actually await - blocked = await is_post_blocked(user, uploaders[ipost.user_id].internal, tag_names), + blocked = await is_post_blocked(user, ipost.user_id, await rating_map.get(ipost.rating), tag_names), tags = tagger.groups(post_tags) ) @@ -953,11 +996,13 @@ async def posts(self: Self, user: KhUser, iposts: list[InternalPost], assign_par if ipost.locked : post.locked = True - continue # we don't want any other fields populated + + if not await user.verify_scope(Scope.mod, False) and ipost.user_id != user.user_id : + continue # we don't want any other fields populated post.title = ipost.title post.description = ipost.description - post.user = uploaders[ipost.user_id].portable + post.user = uploaders[ipost.user_id] if ipost.filename and ipost.media_type and ipost.size and ipost.content_length and ipost.thumbnails : post.media = Media( diff --git a/reporting/models/__init__.py b/reporting/models/__init__.py index 5eb970c..1c18d8d 100644 --- a/reporting/models/__init__.py +++ b/reporting/models/__init__.py @@ -31,5 +31,4 @@ class CreateActionRequest(BaseModel) : class CloseReponseRequest(BaseModel) : - report_id: int response: str diff --git a/reporting/repository.py b/reporting/repository.py index c6154ed..6877286 100644 --- a/reporting/repository.py +++ b/reporting/repository.py @@ -228,7 +228,7 @@ async def close_response(self: Self, user: KhUser, report_id: int, response: str ireport: InternalReport async with self.transaction() as t : - data: Optional[tuple[int, Optional[int]]] = await t.query_async(""" + data: Optional[tuple[int]] = await t.query_async(""" delete from kheina.public.mod_queue where mod_queue.report_id = %s returning mod_queue.assignee; @@ -241,7 +241,7 @@ async def close_response(self: Self, user: KhUser, report_id: int, response: str if not data : raise NotFound('provided report does not exist') - if data[1] != user.user_id : + if data[0] != user.user_id : raise BadRequest('cannot close a report that is assigned to someone else') ireport = await self._read(data[0]) diff --git a/reporting/router.py b/reporting/router.py index d9bf27f..52a1a07 100644 --- a/reporting/router.py +++ b/reporting/router.py @@ -1,7 +1,8 @@ -from fastapi import APIRouter, Request +from fastapi import APIRouter from shared.auth import Scope from shared.models import PostId, convert_path_post_id +from shared.models.server import Request from shared.timing import timed from .models import CloseReponseRequest, CreateActionRequest, CreateRequest diff --git a/shared/auth/__init__.py b/shared/auth/__init__.py index 81b520c..df19735 100644 --- a/shared/auth/__init__.py +++ b/shared/auth/__init__.py @@ -12,12 +12,11 @@ from cryptography.hazmat.primitives.serialization import load_der_public_key from fastapi import Request -from authenticator.authenticator import AuthAlgorithm, Authenticator, AuthState, PublicKeyResponse, Scope, TokenMetadata +from authenticator.authenticator import AuthAlgorithm, Authenticator, AuthState, PublicKeyResponse, Scope, TokenMetadata, token_kvs from shared.models.auth import AuthToken, _KhUser from ..base64 import b64decode, b64encode from ..caching import ArgsCache -from ..caching.key_value_store import KeyValueStore from ..datetime import datetime from ..exceptions.http_error import Forbidden, Unauthorized from ..utilities import int_from_bytes @@ -26,7 +25,6 @@ authenticator = Authenticator() ua_strip = re_compile(r'\/\d+(?:\.\d+)*') -KVS: KeyValueStore = KeyValueStore('kheina', 'token') class InvalidToken(ValueError) : @@ -53,7 +51,10 @@ async def verify_scope(self, scope: Scope, raise_error: bool = True) -> bool : await self.authenticated(raise_error) if scope not in self.scope : - raise Forbidden('User is not authorized to access this resource.', user=self) + if raise_error : + raise Forbidden('User is not authorized to access this resource.', user=self) + + return False return True @@ -104,7 +105,8 @@ async def v1token(token: str) -> AuthToken : if datetime.now() > expires : raise Unauthorized('Key has expired.') - token_info_task = ensure_future(KVS.get_async(guid.bytes, TokenMetadata)) + token_info_task = ensure_future(tokenMetadata(guid.bytes)) + token_info: TokenMetadata try : public_key = await _fetchPublicKey(key_id, algorithm) @@ -133,6 +135,7 @@ async def v1token(token: str) -> AuthToken : expires = expires, data = json.loads(data), token_string = token, + metadata = token_info, ) @@ -151,15 +154,28 @@ async def verifyToken(token: str) -> AuthToken : raise InvalidToken('The given token uses a version that is unable to be decoded.') +async def tokenMetadata(guid: bytes | UUID) -> TokenMetadata : + if isinstance(guid, UUID) : + guid = guid.bytes + + token = await token_kvs.get_async(guid, TokenMetadata) + + # though the kvs should only retain the token for as long as it's active, check the expiration anyway + if token.expires <= datetime.now() : + token.state = AuthState.inactive + + return token + + async def deactivateAuthToken(token: str, guid: Optional[bytes] = None) -> None : atoken = await verifyToken(token) if not guid : - return await KVS.remove_async(atoken.guid.bytes) + return await token_kvs.remove_async(atoken.guid.bytes) - tm = await KVS.get_async(guid, TokenMetadata) + tm = await tokenMetadata(guid) if tm.user_id == atoken.user_id : - return await KVS.remove_async(guid) + return await token_kvs.remove_async(guid) async def retrieveAuthToken(request: Request) -> AuthToken : diff --git a/shared/caching/__init__.py b/shared/caching/__init__.py index cc08759..0a52983 100644 --- a/shared/caching/__init__.py +++ b/shared/caching/__init__.py @@ -181,21 +181,25 @@ def wrapper(*args: Tuple[Hashable], **kwargs:Dict[str, Hashable]) -> Any : return decorator -def deepTypecheck(type_: type | tuple, instance: Any) -> bool : +def deepTypecheck(type_: type | tuple[type, ...], instance: Any) -> bool : """ returns true if instance is an instance of type_ """ + match instance : + case list() | tuple() : + return all(map(partial(deepTypecheck, type_), instance)) - if instance is None and type(None) in getattr(type_, '__args__', tuple()) : + if type(instance) is type_ : return True + type_ = getattr(type_, '__args__', type_) + if isinstance(type_, tuple) : if type(instance) not in type_ : return False else : - t = getattr(type_, '__origin__', type_) - if type(instance) is not t : + if type(instance) is not getattr(type_, '__origin__', type_) : return False if di := getattr(instance, '__dict__', None) : @@ -203,10 +207,6 @@ def deepTypecheck(type_: type | tuple, instance: Any) -> bool : if di.keys() != dt.keys() : return False - match instance : - case list() | tuple() : - return all(map(partial(deepTypecheck, type_.__args__), instance)) # type: ignore - return True diff --git a/shared/caching/key_value_store.py b/shared/caching/key_value_store.py index 0591a80..61c2aa4 100644 --- a/shared/caching/key_value_store.py +++ b/shared/caching/key_value_store.py @@ -95,6 +95,7 @@ async def get_async[T](self: Self, key: KeyType, type: Optional[type[T]] = None) def _get_many[T, K: KeyType](self: Self, k: Iterable[K], type: Optional[type[T]] = None) -> dict[K, T | type[Undefined]] : + # this weird ass dict is so that we can "convert" the returned aerospike keytype back to K keys: dict[K, K] = { v: v for v in k } remote_keys: set[K] = keys.keys() - self._cache.keys() values: dict[K, Any] @@ -133,7 +134,7 @@ def _get_many[T, K: KeyType](self: Self, k: Iterable[K], type: Optional[type[T]] if type : return { - k: coerse(v, type) + k: coerse(v, type) if v is not Undefined else v for k, v in values.items() } diff --git a/shared/models/auth.py b/shared/models/auth.py index 86641d0..40f23fd 100644 --- a/shared/models/auth.py +++ b/shared/models/auth.py @@ -6,12 +6,30 @@ from pydantic import BaseModel +@unique +class AuthState(IntEnum) : + active = 0 + inactive = 1 + + +class TokenMetadata(BaseModel) : + state: AuthState + key_id: int + user_id: int + version: bytes + algorithm: str + expires: datetime + issued: datetime + fingerprint: bytes + + class AuthToken(NamedTuple) : user_id: int expires: datetime guid: UUID data: dict[str, Any] token_string: str + metadata: TokenMetadata @unique @@ -50,23 +68,6 @@ class PublicKeyResponse(BaseModel) : expires: datetime -@unique -class AuthState(IntEnum) : - active = 0 - inactive = 1 - - -class TokenMetadata(BaseModel) : - state: AuthState - key_id: int - user_id: int - version: bytes - algorithm: str - expires: datetime - issued: datetime - fingerprint: bytes - - @unique class AuthAlgorithm(Enum) : ed25519 = 'ed25519' diff --git a/shared/models/config.py b/shared/models/config.py index e45dd34..9ca0d68 100644 --- a/shared/models/config.py +++ b/shared/models/config.py @@ -48,7 +48,7 @@ async def deserialize(cls: type[Self], data: bytes) -> Self : assert data[:2] == AvroMarker deserializer: AvroDeserializer = AvroDeserializer( read_model = cls, - # write_model = await getSchema(data[2:10]), + write_model = await getSchema(data[2:10]), ) return deserializer(data[10:]) diff --git a/shared/sql/__init__.py b/shared/sql/__init__.py index ead297f..2f026c8 100644 --- a/shared/sql/__init__.py +++ b/shared/sql/__init__.py @@ -31,31 +31,44 @@ class FieldAttributes : list of paths to columns. first entry of each list member is the route to the field within the model. second entry is the column that field belongs to within the database. + + ex: orm:map[subtype.field:column,field:column2] + """ column: Optional[str] = None """ database column this field belongs to + + ex: orm:"col[column]" """ primary_key: Optional[bool] = None """ field is a primary key + + ex: orm:"pk" """ generated: Optional[bool] = None """ field may be generated by the database + + ex: orm:"gen" """ default: Optional[bool] = None """ field has a database-defined default value + + ex: orm:"default" """ ignore: Optional[bool] = None """ field does not exist in the database + + ex: orm:"-" """ @@ -528,7 +541,7 @@ async def delete(self: Self, model: BaseModel, query: Optional[AwaitableQuery] = await query(sql) - async def where[T: BaseModel](self: Self, model: type[T], *where: Where, order: list[tuple[Field, Order]] = [], query: Optional[AwaitableQuery] = None) -> list[T] : + async def where[T: BaseModel](self: Self, model: type[T], *where: Where, order: list[tuple[Field, Order]] = [], limit: Optional[int] = None, query: Optional[AwaitableQuery] = None) -> list[T] : table = self._table_name(model) sql = Query(table).where(*where) _, t = str(table).rsplit('.', 1) @@ -548,6 +561,9 @@ async def where[T: BaseModel](self: Self, model: type[T], *where: Where, order: else : sql.select(Field(t, attrs.column or field.name)) + if limit : + sql.limit(limit) + if not query : query = partial(self.query_async, commit=False) diff --git a/users/repository.py b/users/repository.py index 1ce3f1e..d0aa06e 100644 --- a/users/repository.py +++ b/users/repository.py @@ -202,16 +202,16 @@ async def _get_users(self, user_ids: Iterable[int]) -> dict[int, InternalUser] : if not user_ids : return { } - cached = await UserKVS.get_many_async(user_ids) + cached = await UserKVS.get_many_async(map(str, user_ids)) found: dict[int, InternalUser] = { } misses: list[int] = [] for k, v in cached.items() : if isinstance(v, InternalUser) : - found[k] = v + found[int(k)] = v continue - misses.append(k) + misses.append(int(k)) if not misses : return found diff --git a/users/router.py b/users/router.py index 1e06599..a5c7473 100644 --- a/users/router.py +++ b/users/router.py @@ -97,18 +97,18 @@ async def v1User(req: Request, handle: str) : return await users.getUser(req.user, handle) -@userRouter.put('/{handle}/follow', status_code=204) +@userRouter.put('/{handle}/follow', response_model=bool) @timed.root -async def v1FollowUser(req: Request, handle: str) : +async def v1FollowUser(req: Request, handle: str) -> bool : await req.user.authenticated() - await users.followUser(req.user, handle.lower()) + return await users.followUser(req.user, handle.lower()) -@userRouter.delete('/{handle}/follow', status_code=204) +@userRouter.delete('/{handle}/follow', response_model=bool) @timed.root -async def v1UnfollowUser(req: Request, handle: str) : +async def v1UnfollowUser(req: Request, handle: str) -> bool : await req.user.authenticated() - await users.unfollowUser(req.user, handle.lower()) + return await users.unfollowUser(req.user, handle.lower()) app = APIRouter( diff --git a/users/users.py b/users/users.py index 46bd743..c7a34d9 100644 --- a/users/users.py +++ b/users/users.py @@ -1,8 +1,8 @@ from asyncio import Task, create_task -from typing import List, Optional, Self +from typing import Optional, Self from notifications.models import InternalUserNotification, UserNotificationEvent -from notifications.repository import Notifier +from notifications.repository import notifier from shared.auth import KhUser from shared.caching import SimpleCache from shared.exceptions.http_error import BadRequest, HttpErrorHandler, NotFound @@ -14,9 +14,6 @@ from .repository import FollowKVS, Repository, UserKVS, badge_map, privacy_map -notifier: Notifier = Notifier() - - class Users(Repository) : @HttpErrorHandler('retrieving user') @@ -26,28 +23,31 @@ async def getUser(self: Self, user: KhUser, handle: str) -> User : @timed - async def followUser(self: Self, user: KhUser, handle: str) -> None : - user_id: int = await self._handle_to_user_id(handle.lower()) + async def followUser(self: Self, user: KhUser, handle: str) -> bool : + user_id: int = await self._handle_to_user_id(handle.lower()) following: bool = await self.following(user.user_id, user_id) + + if following : + raise BadRequest('you are already following this user.') + portable: Task[UserPortable] = create_task(self.portable( KhUser(user_id=user_id), await self._get_user(user.user_id), )) - if following : - raise BadRequest('you are already following this user.') - await self.query_async(""" INSERT INTO kheina.public.following (user_id, follows) VALUES (%s, %s); - """, - (user.user_id, user_id), - commit=True, + """, ( + user.user_id, + user_id, + ), + commit = True, ) - ensure_future(FollowKVS.put_async(f'{user.user_id}|{user_id}', True)) + ensure_future(FollowKVS.put_async(f'{user.user_id}|{user_id}', following := True)) ensure_future(notifier.sendNotification( user_id, InternalUserNotification( @@ -55,10 +55,12 @@ async def followUser(self: Self, user: KhUser, handle: str) -> None : user_id = user.user_id, ), user = await portable, - ), name = 'sending notifications') + )) + return following - async def unfollowUser(self: Self, user: KhUser, handle: str) -> None : - user_id: int = await self._handle_to_user_id(handle.lower()) + + async def unfollowUser(self: Self, user: KhUser, handle: str) -> bool : + user_id: int = await self._handle_to_user_id(handle.lower()) following: bool = await self.following(user.user_id, user_id) if following is False : @@ -68,12 +70,15 @@ async def unfollowUser(self: Self, user: KhUser, handle: str) -> None : DELETE FROM kheina.public.following WHERE following.user_id = %s AND following.follows = %s - """, - (user.user_id, user_id), - commit=True, + """, ( + user.user_id, + user_id, + ), + commit = True, ) - ensure_future(FollowKVS.put_async(f'{user.user_id}|{user_id}', False)) + ensure_future(FollowKVS.put_async(f'{user.user_id}|{user_id}', following := False)) + return following async def getSelf(self: Self, user: KhUser) -> User : @@ -139,7 +144,7 @@ async def getUsers(self: Self, user: KhUser) : users.admin, users.verified; """, - fetch_all=True, + fetch_all = True, ) return [ @@ -158,7 +163,7 @@ async def getUsers(self: Self, user: KhUser) : Verified.artist if row[10] else None ) ), - following=None, + following = None, ) for row in data ] @@ -173,9 +178,11 @@ async def setMod(self: Self, handle: str, mod: bool) -> None : UPDATE kheina.public.users SET mod = %s WHERE users.user_id = %s - """, - (mod, user_id), - commit=True, + """, ( + mod, + user_id, + ), + commit = True, ) user: Optional[InternalUser] = await user_task @@ -185,7 +192,7 @@ async def setMod(self: Self, handle: str, mod: bool) -> None : @SimpleCache(60) - async def fetchBadges(self: Self) -> List[Badge] : + async def fetchBadges(self: Self) -> list[Badge] : return await badge_map.all() @@ -208,9 +215,11 @@ async def addBadge(self: Self, user: KhUser, badge: Badge) -> None : (user_id, badge_id) VALUES (%s, %s); - """, - (user.user_id, badge_id), - commit=True, + """, ( + user.user_id, + badge_id, + ), + commit = True, ) iuser.badges.append(badge) @@ -238,9 +247,11 @@ async def removeBadge(self: Self, user: KhUser, badge: Badge) -> None : DELETE FROM kheina.public.user_badge WHERE user_id = %s AND badge_id = %s; - """, - (user.user_id, badge_id), - commit=True, + """, ( + user.user_id, + badge_id, + ), + commit = True, ) UserKVS.put(str(user.user_id), iuser) @@ -253,9 +264,11 @@ async def createBadge(self: Self, badge: Badge) -> None : (emoji, label) VALUES (%s, %s); - """, - (badge.emoji, badge.label), - commit=True, + """, ( + badge.emoji, + badge.label, + ), + commit = True, ) @@ -268,9 +281,10 @@ async def verifyUser(self: Self, handle: str, verified: Verified) -> None : UPDATE kheina.public.users set {'verified' if verified == Verified.artist else verified.name} = true WHERE users.user_id = %s; - """, - (user_id,), - commit=True, + """, ( + user_id, + ), + commit = True, ) user: Optional[InternalUser] = await user_task From cf399a208b52574fde41583594e17ab46417e940 Mon Sep 17 00:00:00 2001 From: dani <29378233+DanielleMiu@users.noreply.github.com> Date: Mon, 14 Apr 2025 20:40:31 -0400 Subject: [PATCH 3/5] fix docker --- Dockerfile | 7 +++++-- docker-exec.sh | 2 ++ k8s.yml | 2 +- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/Dockerfile b/Dockerfile index 9b22f53..d205e80 100644 --- a/Dockerfile +++ b/Dockerfile @@ -19,7 +19,8 @@ RUN apk update && \ lua5.1 lua5.1-dev \ zlib zlib-dev \ python3-dev \ - exiftool + exiftool \ + jq # ENV DEBIAN_FRONTEND=noninteractive @@ -45,6 +46,8 @@ RUN rm -rf /var/lib/apt/lists/* WORKDIR /app COPY . /app +RUN chmod +x docker-exec.sh + RUN mkdir "images" RUN rm -rf .venv && \ rm -rf credentials @@ -69,4 +72,4 @@ ENV PATH="/opt/.venv/bin:$PATH" ENV PORT=80 ENV ENVIRONMENT=DEV -CMD ["docker-exec.sh"] +CMD ["./docker-exec.sh"] diff --git a/docker-exec.sh b/docker-exec.sh index 4a559a8..a9eda82 100644 --- a/docker-exec.sh +++ b/docker-exec.sh @@ -1,3 +1,5 @@ +#!/bin/sh + jq '.fullchain' -r /etc/certs/cert.json > fullchain.pem && \ jq '.privkey' -r /etc/certs/cert.json > privkey.pem && \ gunicorn -w 2 -k uvicorn.workers.UvicornWorker --certfile fullchain.pem --keyfile privkey.pem -b 0.0.0.0:443 --timeout 1200 server:app diff --git a/k8s.yml b/k8s.yml index 8ebd7ae..ffeb8cc 100644 --- a/k8s.yml +++ b/k8s.yml @@ -14,7 +14,7 @@ spec: spec: containers: - name: fuzzly-backend - image: us-central1-docker.pkg.dev/kheinacom/fuzzly-repo/fuzzly-backend@sha256:e757a3b0c5912818f7131e7441db814adcda5d0a0f7c899abd64a5ebfafe0e99 + image: us-central1-docker.pkg.dev/kheinacom/fuzzly-repo/fuzzly-backend@sha256:13ff06a0bbdd5d9391af23f932030a7e6a7eb5a7cd5f33ad6a9d9f7c069333ce env: - name: pod_ip valueFrom: From 7cdb3ec64a0971cdf779c17442895e039e148902 Mon Sep 17 00:00:00 2001 From: dani <29378233+DanielleMiu@users.noreply.github.com> Date: Mon, 14 Apr 2025 20:43:06 -0400 Subject: [PATCH 4/5] ignore index found error --- Dockerfile | 2 -- authenticator/authenticator.py | 17 +++++++++++------ posts/blocking.py | 3 ++- 3 files changed, 13 insertions(+), 9 deletions(-) diff --git a/Dockerfile b/Dockerfile index d205e80..6d4ec37 100644 --- a/Dockerfile +++ b/Dockerfile @@ -49,8 +49,6 @@ COPY . /app RUN chmod +x docker-exec.sh RUN mkdir "images" -RUN rm -rf .venv && \ - rm -rf credentials RUN wget https://go.dev/dl/go1.22.5.linux-amd64.tar.gz && \ tar -xvf go1.22.5.linux-amd64.tar.gz -C /usr/local && \ diff --git a/authenticator/authenticator.py b/authenticator/authenticator.py index 428e2c3..670d7b2 100644 --- a/authenticator/authenticator.py +++ b/authenticator/authenticator.py @@ -7,6 +7,7 @@ from typing import Any, Awaitable, Callable, Dict, List, Optional, Self, Tuple from uuid import UUID, uuid4 +import aerospike import pyotp import ujson as json from argon2 import PasswordHasher as Argon2 @@ -90,13 +91,17 @@ BotLoginSerializer: AvroSerializer = AvroSerializer(BotLogin) BotLoginDeserializer: AvroDeserializer = AvroDeserializer(BotLogin) token_kvs: KeyValueStore = KeyValueStore('kheina', 'token') -KeyValueStore._client.index_integer_create( # type: ignore - 'kheina', - 'token', - 'user_id', - 'kheina_token_user_id_idx', -) +try : + KeyValueStore._client.index_integer_create( # type: ignore + 'kheina', + 'token', + 'user_id', + 'kheina_token_user_id_idx', + ) + +except aerospike.exception.IndexFoundError : + pass class BotTypeMap(SqlInterface): @AsyncLRU(maxsize=0) diff --git a/posts/blocking.py b/posts/blocking.py index b184284..05367a2 100644 --- a/posts/blocking.py +++ b/posts/blocking.py @@ -2,11 +2,12 @@ from configs.configs import Configs from configs.models import Blocking -from .models import Rating from shared.auth import KhUser from shared.caching import ArgsCache from shared.timing import timed +from .models import Rating + configs = Configs() From da2a4a98a89b350e1bd068066438adaed0f29ee9 Mon Sep 17 00:00:00 2001 From: dani <29378233+DanielleMiu@users.noreply.github.com> Date: Mon, 14 Apr 2025 20:48:26 -0400 Subject: [PATCH 5/5] apply --- k8s.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/k8s.yml b/k8s.yml index ffeb8cc..a4ae2ef 100644 --- a/k8s.yml +++ b/k8s.yml @@ -14,7 +14,7 @@ spec: spec: containers: - name: fuzzly-backend - image: us-central1-docker.pkg.dev/kheinacom/fuzzly-repo/fuzzly-backend@sha256:13ff06a0bbdd5d9391af23f932030a7e6a7eb5a7cd5f33ad6a9d9f7c069333ce + image: us-central1-docker.pkg.dev/kheinacom/fuzzly-repo/fuzzly-backend@sha256:53b0898536ff726a2557ea55cb9bdc99cdadeb310185af21093eb8ad88d1a452 env: - name: pod_ip valueFrom: