diff --git a/src/commands/index.ts b/src/commands/index.ts index 21e757e..0aee1e0 100644 --- a/src/commands/index.ts +++ b/src/commands/index.ts @@ -7,6 +7,7 @@ import { logger } from "@/logger" import { modules } from "@/modules" import type { TelemetryContextFlavor } from "@/modules/telemetry" import { redis } from "@/redis" +import { RedisSet } from "@/redis/set" import { fmt } from "@/utils/format" import { ephemeral } from "@/utils/messages" import type { Context, Role } from "@/utils/types" @@ -19,6 +20,19 @@ import { pin } from "./pin" import { report } from "./report" import { search } from "./search" +const userSet = new RedisSet({ + redis, + prefix: "managed-commands:cached-users", + ttl: 60 * 60 * 24, // 24h, we can afford some staleness here and it helps reduce the number of Redis calls significantly +}) + +const userRolesCache = new RedisFallbackAdapter({ + redis, + prefix: "managed-commands:user-roles", + ttl: 60 * 5, + logger, +}) + const adapter = new RedisFallbackAdapter>({ redis, prefix: "conv", @@ -34,6 +48,14 @@ export const commands = new ManagedCommands { + const key = `${userId}:${chatId}` + if (await userSet.has(key)) { + return true + } + await userSet.add(key) + return false + }, wrongScope: async ({ context, command }) => { await context.deleteMessage().catch(() => {}) logger.info( @@ -104,19 +126,15 @@ export const commands = new ManagedCommands { - // TODO: cache this to avoid hitting the db on every command - const { roles } = await api.tg.permissions.getRoles.query({ userId }) - return roles || [] + const cached = await userRolesCache.read(String(userId)) + if (cached) return cached + + const res = await api.tg.permissions.getRoles.query({ userId }) + const roles = res.roles ?? [] + await userRolesCache.write(String(userId), roles) + return roles }, }) - .createCommand({ - trigger: "ping", - scope: "private", - description: "Replies with pong", - handler: async ({ context }) => { - await context.reply("pong") - }, - }) .createCommand({ trigger: "start", scope: "private", @@ -138,4 +156,12 @@ export const commands = new ManagedCommands { + await context.reply("pong") + }, + }) .withCollection(linkAdminDashboard, report, search, management, moderation, pin, invite) diff --git a/src/lib/managed-commands/command.ts b/src/lib/managed-commands/command.ts index a25c396..52073db 100644 --- a/src/lib/managed-commands/command.ts +++ b/src/lib/managed-commands/command.ts @@ -1,6 +1,6 @@ import type { Conversation } from "@grammyjs/conversations" import type { Context } from "grammy" -import type { Message } from "grammy/types" +import type { BotCommand, Message } from "grammy/types" import type { z } from "zod" import type { MaybeArray } from "@/utils/types" import type { ConversationContext } from "./context" @@ -168,6 +168,14 @@ export type AnyCommand +export type AnyGroupCommand = Command< + CommandArgs, + CommandReplyTo, + "group" | "both", + TRole, + C +> + /** * Type guard to check if a command is allowed in groups. * @param cmd The command to check @@ -221,3 +229,45 @@ export function isAllowedInPrivateOnly< >(cmd: Command): cmd is Command { return cmd.scope === "private" } + +export function isAllowedInPrivate< + A extends CommandArgs, + R extends CommandReplyTo, + TRole extends string = string, + C extends Context = Context, +>(cmd: Command): cmd is Command { + return cmd.scope !== "group" +} + +export function isAllowedEverywhere< + A extends CommandArgs, + R extends CommandReplyTo, + TRole extends string = string, + C extends Context = Context, +>(cmd: Command): cmd is Command { + return cmd.scope === "both" || cmd.scope === undefined +} + +export function toBotCommands(command: AnyCommand): BotCommand[] { + const triggers = Array.isArray(command.trigger) ? command.trigger : [command.trigger] + return triggers.map((trigger) => ({ + command: trigger, + description: command.description ?? "No description", + })) +} + +export function isForThisScope(cmd: AnyCommand, chatType: "private" | "group" | "supergroup" | "channel"): boolean { + if (chatType === "channel") return false + if (cmd.scope === "private") return chatType === "private" + if (cmd.scope === "group") return chatType === "group" || chatType === "supergroup" + return true +} + +export function switchOnScope( + cmd: Command, + handlers: { private: T; group: T; both: T } +) { + if (cmd.scope === "private") return handlers.private + if (cmd.scope === "group") return handlers.group + return handlers.both +} diff --git a/src/lib/managed-commands/index.ts b/src/lib/managed-commands/index.ts index 951c8c3..21f1f24 100644 --- a/src/lib/managed-commands/index.ts +++ b/src/lib/managed-commands/index.ts @@ -9,16 +9,20 @@ import { hydrate } from "@grammyjs/hydrate" import { hydrateReply, parseMode } from "@grammyjs/parse-mode" import type { CommandContext, Context, Middleware, MiddlewareObj } from "grammy" import { Composer, MemorySessionStorage } from "grammy" -import type { Message } from "grammy/types" +import type { BotCommand, Message } from "grammy/types" import type { Result } from "neverthrow" import { err, ok } from "neverthrow" import z from "zod" +import { asyncFilter, asyncMap } from "@/utils/arrays" import { isFromGroupChat, isFromPrivateChat } from "@/utils/chat" import { fmt } from "@/utils/format" import { ephemeral } from "@/utils/messages" +import { once } from "@/utils/once" +import type { ContextWith } from "@/utils/types" import type { CommandsCollection } from "./collection" import type { AnyCommand, + AnyGroupCommand, ArgumentMap, ArgumentOptions, Command, @@ -29,7 +33,7 @@ import type { CommandScopedContext, RepliedTo, } from "./command" -import { isAllowedInGroups, isTypedArgumentOptions } from "./command" +import { isAllowedInGroups, isAllowedInPrivate, isTypedArgumentOptions, switchOnScope, toBotCommands } from "./command" import type { ManagedCommandsFlavor } from "./context" export type Hook = ( @@ -71,7 +75,7 @@ export type ManagedCommandsHooks) => Promise + overrideGroupAdminCheck?: (userId: number, chatId: number, context: OC) => Promise /** * Called when a command is invoked, before any processing is done, can be used to implement custom logic that should * run before checking permissions or requirements, for example logging or analytics @@ -82,6 +86,11 @@ export type ManagedCommandsHooks + /** + * A function to externally cache wether a user has had the commands menu generated for them or not + * @returns true if the user has had the commands menu generated, false otherwise + */ + cachedUserSetCommands?: (userId: number, chatId: number) => Promise } export interface IManagedCommandsOptions { @@ -99,7 +108,7 @@ export interface IManagedCommandsOptions { + * getUserRoles: async (userId) => { * const roles = await db.getUserRoles(userId) // Array<"admin" | "user">[] * return roles * }, @@ -114,7 +123,7 @@ export interface IManagedCommandsOptions) => Promise + getUserRoles: (userId: number) => Promise /** * Additional plugins to apply to the conversation inner composer. @@ -162,7 +171,7 @@ export class ManagedCommands< { private composer = new Composer() private commands: Record[]> = {} - private getUserRoles: (userId: number, context: CommandContext) => Promise + private getUserRoles: (userId: number) => Promise private hooks: ManagedCommandsHooks private adapter: ConversationStorage private registeredTriggers = new Set() @@ -263,9 +272,11 @@ export class ManagedCommands< */ private static formatCommandUsage(cmd: AnyCommand): string { const args = cmd.args ?? [] - const scope = - cmd.scope === "private" ? "Private Chat" : cmd.scope === "group" ? "Groups" : "Groups and Private Chat" - + const scope = switchOnScope(cmd, { + private: "šŸ‘¤ Private chats only", + group: "šŸ‘„ Groups only", + both: "šŸŒ Both private and group chats", + }) return fmt(({ n, b, i }) => [ typeof cmd.trigger === "string" ? `/${cmd.trigger}` : cmd.trigger.map((t) => `/${t}`).join(" | "), ...args.map(({ key, optional }) => (optional ? n`[${i`${key}`}]` : n`<${i`${key}`}>`)), @@ -281,10 +292,14 @@ export class ManagedCommands< private static formatCommandShort(cmd: AnyCommand): string { const args = cmd.args ?? [] + const trigger: string = + typeof cmd.trigger === "string" ? `/${cmd.trigger}` : cmd.trigger.map((t) => `/${t}`).join(" | ") + const scope = switchOnScope(cmd, { private: "šŸ‘¤", group: "šŸ‘„", both: "šŸŒ" }) + const admin = isAllowedInGroups(cmd) && cmd.permissions?.allowGroupAdmins ? "šŸ›”ļø" : "" return fmt(({ i, n }) => [ - typeof cmd.trigger === "string" ? `/${cmd.trigger}` : cmd.trigger.map((t) => `/${t}`).join(" | "), + trigger, ...args.map(({ key, optional }) => (optional ? i` [${key}]` : i` <${key}>`)), - n`\n\t${cmd.description ?? "No description"}`, + n`\n\t${scope}${admin}${cmd.description ?? "No description"}`, ]) } @@ -364,13 +379,41 @@ export class ManagedCommands< }) ) - this.composer.command("help", async (ctx) => { - if (ctx.chat.type !== "private") - return void ephemeral( - ctx.reply(fmt(({ n, code }) => n`You can only send ${code`/help`} in private chat with the bot.`)), - 10_000 - ) + const setFreeCommands = once(async (ctx: OC) => { + const freeCommands = this.getCommands().filter((cmd) => this.isCommandAllowedForRoles(cmd, [])) + const privateCommands: BotCommand[] = freeCommands + .filter((cmd) => isAllowedInPrivate(cmd)) + .flatMap((cmd) => toBotCommands(cmd)) + .concat([{ command: "help", description: "Show available commands" }]) + await ctx.api.setMyCommands(privateCommands, { scope: { type: "all_private_chats" } }).catch(() => {}) + const groupCommands: BotCommand[] = freeCommands + .filter((cmd) => isAllowedInGroups(cmd) && this.isCommandAllowedInGroup(cmd, -100)) // only include commands that are allowed in all groups + .flatMap((cmd) => toBotCommands(cmd)) + .concat([{ command: "help", description: "Show available commands" }]) + await ctx.api.setMyCommands(groupCommands, { scope: { type: "all_group_chats" } }).catch(() => {}) + }) + + this.composer.use(async (ctx, next) => { + await setFreeCommands(ctx) + return next() + }) + + this.composer.on("message").use(async (ctx, next) => { + if (!ctx.from) return next() + const shouldSkip = (await this.hooks.cachedUserSetCommands?.(ctx.from.id, ctx.chat.id)) ?? false + if (shouldSkip) return next() + const allowedCommands = await this.getAllowedCommandsFor(ctx) + await ctx.api + .setMyCommands(allowedCommands.flatMap(toBotCommands), { + scope: { type: "chat_member", chat_id: ctx.chat.id, user_id: ctx.from.id }, + }) + .catch(() => {}) + return next() + }) + this.composer.command("help", async (ctx) => { + if (!ctx.from) return + const userId = ctx.from.id const text = ctx.message?.text ?? "" const [_, cmdArg] = text.replaceAll("/", "").split(" ") @@ -383,14 +426,30 @@ export class ManagedCommands< return ctx.reply(ManagedCommands.formatCommandUsage(cmd)) } + const getUserRoles = once(async () => await this.getUserRoles(userId)) + const isFromGroupAdmin = once(async () => { + if (ctx.chat.type === "private") return true + return await this.isFromGroupAdmin(ctx) + }) + + const rawCollections = await asyncMap(Object.entries(this.commands), async ([collection, cmds]) => ({ + collection, + commands: await asyncFilter(cmds, async (cmd) => + this.checkPermissionsCached(cmd, ctx, getUserRoles, isFromGroupAdmin) + ), + })) + const collections = rawCollections.filter((c) => c.commands.length > 0) + const reply = fmt( - ({ u, b, skip, n, code }) => [ + ({ u, b, skip, n, code, i }) => [ b`Available commands:`, - ...Object.entries(this.commands).flatMap(([collection, cmds]) => [ + ...collections.flatMap(({ collection, commands }) => [ collection === "default" ? "" : u`${b`\n${collection}:`}`, - ...cmds.flatMap((cmd) => [skip`${ManagedCommands.formatCommandShort(cmd)}`]), + ...commands.map((cmd) => skip`${ManagedCommands.formatCommandShort(cmd)}`), ]), - n`\n\nType ${code`\/help `} for more details on a specific command.`, + i`\nšŸ‘¤: Private only, šŸ‘„: Group only, šŸŒ: Everywhere`, + i`Commands marked with šŸ›”ļø are restricted to administrators.`, + n`Type ${code`\/help `} for more details on a specific command.`, ], { sep: "\n" } ) @@ -407,36 +466,84 @@ export class ManagedCommands< return cmds } - private async checkPermissions(command: AnyCommand, ctx: CommandContext): Promise { - if (!command.permissions) return true - if (!ctx.from) return false + /** + * Checks whether a command is allowed in a specific group based on its permissions + */ + private isCommandAllowedInGroup(command: AnyGroupCommand, chatId: number): boolean { + const { allowedGroupsId, excludedGroupsId } = command.permissions ?? {} + if (allowedGroupsId && !allowedGroupsId.includes(chatId)) return false + if (excludedGroupsId?.includes(chatId)) return false + return true + } - const { allowedRoles, excludedRoles } = command.permissions + /** + * Checks whether a command is allowed for a specific set of roles based on its permissions + */ + private isCommandAllowedForRoles(command: AnyCommand, roles: TRole[]): boolean { + const { allowedRoles, excludedRoles } = command.permissions ?? {} + if (allowedRoles?.every((r) => !roles.includes(r))) return false + if (excludedRoles?.some((r) => roles.includes(r))) return false + return true + } - if (isAllowedInGroups(command) && (ctx.chat.type === "group" || ctx.chat.type === "supergroup")) { - const { allowGroupAdmins, allowedGroupsId, excludedGroupsId } = command.permissions + private async isFromGroupAdmin(ctx: OC): Promise { + if (!ctx.from || !ctx.chatId) return false + if (this.hooks.overrideGroupAdminCheck) { + const isAdmin = await this.hooks.overrideGroupAdminCheck(ctx.from.id, ctx.chatId, ctx) + if (isAdmin) return true + } else { + const { status: groupRole } = await ctx.getChatMember(ctx.from.id) + if (groupRole === "administrator" || groupRole === "creator") return true + } + return false + } - if (allowedGroupsId && !allowedGroupsId.includes(ctx.chatId)) return false - if (excludedGroupsId?.includes(ctx.chatId)) return false + private async getAllowedCommandsFor(ctx: ContextWith): Promise[]> { + const getUserRoles = once(() => this.getUserRoles(ctx.from.id)) + const isFromGroupAdmin = once(() => this.isFromGroupAdmin(ctx)) + + return await Promise.all( + this.getCommands() + .filter(isFromPrivateChat(ctx) ? (cmd) => isAllowedInPrivate(cmd) : (cmd) => isAllowedInGroups(cmd)) + .map((cmd) => + this.checkPermissionsCached(cmd, ctx, getUserRoles, isFromGroupAdmin).then((allowed) => + allowed ? cmd : null + ) + ) + ).then((cmds) => cmds.filter((c) => c !== null)) + } - if (allowGroupAdmins) { - if (this.hooks.overrideGroupAdminCheck) { - const isAdmin = await this.hooks.overrideGroupAdminCheck(ctx.from.id, ctx.chatId, ctx) - if (isAdmin) return true - } else { - const { status: groupRole } = await ctx.getChatMember(ctx.from.id) - if (groupRole === "administrator" || groupRole === "creator") return true - } + private async checkPermissionsCached( + command: AnyCommand, + ctx: ContextWith, + getUserRoles: () => Promise, + isFromGroupAdmin: () => Promise + ): Promise { + if (!command.permissions) return true + + if (isAllowedInGroups(command)) { + const allowed = this.isCommandAllowedInGroup(command, ctx.chat.id) + if (!allowed) return false + + if (command.permissions.allowGroupAdmins) { + const isAdmin = await isFromGroupAdmin() + if (isAdmin) return true } } - const roles = await this.getUserRoles(ctx.from.id, ctx) - - // blacklist is stronger than whitelist - if (allowedRoles?.every((r) => !roles.includes(r))) return false - if (excludedRoles?.some((r) => roles.includes(r))) return false + const roles = await getUserRoles() + return this.isCommandAllowedForRoles(command, roles) + } - return true + private async checkPermissions(command: AnyCommand, ctx: CommandContext): Promise { + if (!ctx.from) return false + const userId = ctx.from.id + return this.checkPermissionsCached( + command, + ctx, + () => this.getUserRoles(userId), + () => this.isFromGroupAdmin(ctx) + ) } /** diff --git a/src/redis/set.ts b/src/redis/set.ts new file mode 100644 index 0000000..ada3c75 --- /dev/null +++ b/src/redis/set.ts @@ -0,0 +1,117 @@ +import { EventEmitter } from "node:events" +import { + createClient, + type RedisClientOptions, + type RedisClientType, + type RedisFunctions, + type RedisModules, + type RedisScripts, +} from "redis" + +export interface RedisSetOptions { + /** Redis client instance, or options to create one */ + redis: RedisClientType | RedisClientOptions + /** Time to live for each entry in seconds, uses redis' EXPIRE command */ + ttl?: number + /** + * Prefix for each key stored in redis, to avoid collisions, if not provided a + * default one will be used to ensure uniqueness across multiple instances + */ + prefix?: string +} + +export class RedisSet< + M extends RedisModules = RedisModules, + F extends RedisFunctions = RedisFunctions, + S extends RedisScripts = RedisScripts, +> { + private static instanceCount = 0 + private prefix: string + // In-memory cache used when Redis is not available + private memoryCache: Set = new Set() + // temporary store for keys that need to be deleted once redis is back (used when delete does not find the key in memoryCache) + private deletions: Set = new Set() + private redisClient: RedisClientType + + constructor(private options: RedisSetOptions) { + const prefix = options.prefix ?? `redis-set-${RedisSet.instanceCount++}` + if (prefix.endsWith(":")) { + prefix.slice(0, -1) + } + this.prefix = prefix + if (options.redis instanceof EventEmitter) { + // RedisClient extends event emitter :) + this.redisClient = options.redis + } else { + this.redisClient = createClient(options.redis) + void this.redisClient.connect() + } + + this.redisClient.on("ready", () => { + void this.flushMemoryCache() + }) + } + + /** + * Flush the in-memory cache to Redis. Called automatically when the Redis + * connection is re-established. + */ + private async flushMemoryCache() { + // write all memoryCache entries to redis + await Promise.all(this.memoryCache.values().map((value) => this._add(value))) + this.memoryCache.clear() + // delete all keys that were marked for deletion while redis was down + await Promise.all(this.deletions.values().map((k) => this._delete(k))) + this.deletions.clear() + } + + private ready(): boolean { + return this.redisClient.isOpen && this.redisClient.isReady + } + + /** + * Writes a value to Redis. + * + * Sets an expiry if ttl is set in options. + * @param value The value to insert in the set. + */ + private async _add(value: string) { + await this.redisClient.sAdd(this.prefix, value) + if (this.options.ttl) { + await this.redisClient.expire(this.prefix, this.options.ttl) + } + } + + /** + * Deletes a key from Redis. + * @param key The key to delete. + */ + private async _delete(value: string) { + await this.redisClient.sRem(this.prefix, value) + } + + async add(value: string): Promise { + if (this.ready()) { + await this._add(value) + } else { + this.memoryCache.add(value) + } + } + + async delete(value: string): Promise { + if (this.ready()) { + await this._delete(value) + } else { + // Try to delete from memory cache, if not found add to deletions set + if (!this.memoryCache.delete(value)) this.deletions.add(value) + } + } + + async has(value: string): Promise { + if (this.ready()) { + return await this.redisClient.sIsMember(this.prefix, value) + } else { + return this.memoryCache.has(value) + } + } +} diff --git a/src/utils/arrays.ts b/src/utils/arrays.ts new file mode 100644 index 0000000..4cb15c8 --- /dev/null +++ b/src/utils/arrays.ts @@ -0,0 +1,9 @@ +export function asyncFilter(arr: T[], predicate: (item: T) => Promise): Promise { + return Promise.all(arr.map(async (item) => ({ item, keep: await predicate(item) }))).then((results) => + results.filter((result) => result.keep).map((result) => result.item) + ) +} + +export function asyncMap(arr: T[], mapper: (item: T) => Promise): Promise { + return Promise.all(arr.map(mapper)) +} diff --git a/src/utils/once.ts b/src/utils/once.ts index a6ae914..65574a3 100644 --- a/src/utils/once.ts +++ b/src/utils/once.ts @@ -1,4 +1,5 @@ import type { MaybePromise } from "./types" +import { Awaiter } from "./wait" /** * Wraps a function so that it can only be invoked once. @@ -15,14 +16,14 @@ import type { MaybePromise } from "./types" * @returns A wrapped version of `fn` that only runs on the first call */ export function once(fn: (...args: A) => MaybePromise) { + const result: Awaiter = new Awaiter() let called = false - let result: R return async (...args: A) => { if (!called) { called = true - result = await fn(...args) + result.resolve(await fn(...args)) } - return result + return await result } } diff --git a/src/utils/types.ts b/src/utils/types.ts index b6ccd85..0c42622 100644 --- a/src/utils/types.ts +++ b/src/utils/types.ts @@ -5,14 +5,8 @@ import type { ApiInput, ApiOutput } from "@/backend" import type { ManagedCommandsFlavor } from "@/lib/managed-commands" import type { TelemetryContextFlavor } from "@/modules/telemetry" -export type OptionalPropertyOf = Exclude< - { - [K in keyof T]: T[K] extends undefined ? never : K - }[keyof T], - undefined -> -export type ContextWith

> = Exclude & { - [K in P]: NonNullable +export type ContextWith = C & { + [K in P]: NonNullable } export type MaybePromise = T | Promise