Skip to content

Redis version #2

@omarkhatibgg

Description

@omarkhatibgg

Hey,

Thank you for this awesome library — it really solves a huge problem for me. To be honest, I created this Redis version with Claude's help, so that it can be used across multiple servers.

Here is the code:

import stringify from "fast-json-stable-stringify";
import Redis from "ioredis";

/** A value that may be returned either directly or wrapped in a Promise. */
type MaybePromise<T> = Promise<T> | T;

/**
 * Minimal structural shape of the resolver options this module relies on:
 * only an optional AbortSignal, which `RedisLiveStore.live()` listens to in
 * order to stop its generator.
 */
interface ProcedureResolverOptionsLike {
  signal?: AbortSignal;
}

/** Options accepted by `RedisLiveStore.live()`. */
export interface LiveOptions<TOpts extends ProcedureResolverOptionsLike, T> {
  /**
   * Key(s) whose invalidation re-runs `resolver`, or a function that derives
   * them from the resolver options of each individual call.
   */
  key: string | string[] | ((opts: TOpts) => string | string[]);
  /** Produces the value the live generator yields; may be sync or async. */
  resolver: (opts: TOpts) => MaybePromise<T>;
}

/**
 * Redis-backed store for "live" subscriptions that can be invalidated from
 * any server connected to the same Redis.
 *
 * Every key maps to the pub/sub channel `channelPrefix + key`. `invalidate()`
 * publishes to that channel, and every store subscribed to it re-runs the
 * resolvers of its local `live()` generators watching that key.
 *
 * Two connections are used because a connection that has issued SUBSCRIBE is
 * in subscriber mode and cannot run regular commands such as PUBLISH.
 */
export class RedisLiveStore {
  /** Connection used for PUBLISH and PUBSUB commands. */
  private readonly redis: Redis;
  /** Connection dedicated to SUBSCRIBE/UNSUBSCRIBE and incoming messages. */
  private readonly subscriber: Redis;
  /** Prefix prepended to each key to form its channel name. */
  private readonly channelPrefix: string;
  /**
   * Local invalidation callbacks per key. While a key has an entry here, its
   * Redis channel is (or is being) subscribed on `subscriber`.
   */
  private readonly subscriptions: Map<string, Set<() => void>> = new Map();

  /**
   * @param options.redisOptions Options for any connection this store creates
   *   itself.
   * @param options.channelPrefix Channel-name prefix; defaults to "trpc-live:".
   * @param options.redis Existing connection to use for publishing.
   * @param options.subscriber Existing connection to use for subscribing. If
   *   omitted, it is duplicated from `redis` when given, or created from
   *   `redisOptions` otherwise.
   */
  constructor(
    options: {
      redisOptions?: Redis.RedisOptions;
      channelPrefix?: string;
      redis?: Redis;
      subscriber?: Redis;
    } = {},
  ) {
    this.channelPrefix = options.channelPrefix || "trpc-live:";

    this.redis = options.redis ?? new Redis(options.redisOptions);
    this.subscriber =
      options.subscriber ??
      // duplicate() opens a new connection with the same options as `redis`.
      (options.redis ? options.redis.duplicate() : new Redis(options.redisOptions));

    // Dispatch each message on one of our channels to the local handlers of
    // the corresponding key; the message payload itself is not inspected.
    this.subscriber.on("message", (channel: string) => {
      if (!channel.startsWith(this.channelPrefix)) return;
      const handlers = this.subscriptions.get(channel.slice(this.channelPrefix.length));
      handlers?.forEach((fn) => fn());
    });
  }

  /**
   * Sums Redis' NUMSUB counts for the channels of the given key(s).
   *
   * NUMSUB counts subscribed *connections*, not local handlers. Each store
   * subscribes at most once per key, so this is roughly the number of stores
   * (servers) currently watching each key, plus any other client subscribed
   * to the same channels.
   */
  async count(key: string | string[]) {
    const keys: string[] = castArray(key);
    if (keys.length === 0) return 0;

    const channels = keys.map((k) => this.channelPrefix + k);
    // One round trip; the reply is [channel1, count1, channel2, count2, ...].
    const reply = await this.redis.pubsub("NUMSUB", ...channels);

    let total = 0;
    if (Array.isArray(reply)) {
      for (let i = 1; i < reply.length; i += 2) {
        total += Number(reply[i]);
      }
    }
    return total;
  }

  /**
   * Publishes an invalidation message on the channel of each key, so that
   * every store subscribed to it re-runs the affected live resolvers.
   */
  async invalidate(key: string | string[]) {
    const keys: string[] = castArray(key);
    // Issue all publishes at once instead of waiting for each reply in turn.
    await Promise.all(
      keys.map((k) => this.redis.publish(this.channelPrefix + k, "invalidate")),
    );
  }

  /**
   * Registers `fn` as a local handler for each key, issuing a Redis SUBSCRIBE
   * for every key that had no local handlers yet.
   */
  private async subscribe(keys: string[], fn: () => void) {
    for (const key of keys) {
      let handlers = this.subscriptions.get(key);
      if (!handlers) {
        // Register the set before awaiting so concurrent callers for the same
        // key reuse it instead of issuing another SUBSCRIBE.
        handlers = new Set();
        this.subscriptions.set(key, handlers);
        try {
          await this.subscriber.subscribe(this.channelPrefix + key);
        } catch (error) {
          // Roll back so a later call retries SUBSCRIBE rather than assuming
          // the channel is already subscribed.
          if (this.subscriptions.get(key) === handlers && handlers.size === 0) {
            this.subscriptions.delete(key);
          }
          throw error;
        }
      }
      handlers.add(fn);
    }
  }

  /**
   * Removes `fn` from each key's local handlers, issuing a Redis UNSUBSCRIBE
   * once a key has no handlers left. Keys `fn` was never registered for are
   * left untouched, so another caller's in-flight subscribe is not disturbed.
   */
  private async unsubscribe(keys: string[], fn: () => void) {
    for (const key of keys) {
      const handlers = this.subscriptions.get(key);
      if (handlers?.delete(fn) && handlers.size === 0) {
        this.subscriptions.delete(key);
        await this.subscriber.unsubscribe(this.channelPrefix + key);
      }
    }
  }

  /**
   * Wraps `resolver` in an async generator for a subscription procedure.
   *
   * The generator yields the resolver's result once immediately, then again
   * after each invalidation of one of its keys, until `opts.signal` aborts,
   * at which point it returns normally. Invalidations that arrive while the
   * resolver runs or while the consumer handles a value are remembered;
   * several pending invalidations are coalesced into a single re-run.
   *
   * @param key Key(s) to watch, or a function deriving them from `opts`.
   * @param resolver Produces each yielded value.
   */
  live<TOpts extends ProcedureResolverOptionsLike, T>({
    key,
    resolver,
  }: LiveOptions<TOpts, T>) {
    const store = this;

    return async function* (opts: TOpts) {
      const keys: string[] = castArray(typeof key === "function" ? key(opts) : key);
      const signal = opts.signal;
      // Set by an invalidation and cleared right before the resolver re-runs,
      // so an invalidation arriving while we are not waiting is not lost.
      let dirty = false;
      // Resolves the current wait, if the generator is waiting.
      let wake: (() => void) | undefined;

      const onInvalidate = () => {
        dirty = true;
        wake?.();
      };
      const onAbort = () => {
        wake?.();
      };

      signal?.addEventListener("abort", onAbort);
      try {
        // Inside `try` so a partially failed subscribe is still cleaned up.
        await store.subscribe(keys, onInvalidate);
        yield resolver(opts);

        for (;;) {
          if (!dirty && !signal?.aborted) {
            await new Promise<void>((resolve) => {
              wake = resolve;
            });
            wake = undefined;
          }
          // Finish cleanly on abort instead of throwing out of the generator.
          if (signal?.aborted) return;
          dirty = false;
          yield resolver(opts);
        }
      } finally {
        signal?.removeEventListener("abort", onAbort);
        await store.unsubscribe(keys, onInvalidate);
      }
    };
  }

  /**
   * Unsubscribes from every channel and forgets all local handlers. Live
   * generators that are still running stop receiving invalidations.
   *
   * @param closeConnections When true (the default), also closes BOTH Redis
   *   connections, including ones passed in through the constructor. Pass
   *   false to leave them open.
   */
  async disconnect(closeConnections = true) {
    const channels = Array.from(this.subscriptions.keys(), (key) => this.channelPrefix + key);
    if (channels.length > 0) {
      await this.subscriber.unsubscribe(...channels);
    }
    this.subscriptions.clear();

    if (closeConnections) {
      await this.closeConnection(this.subscriber);
      await this.closeConnection(this.redis);
    }
  }

  /**
   * Closes `conn` whatever its state: a graceful QUIT when ready, otherwise
   * an immediate disconnect so a connecting/reconnecting client is not left
   * running. Already-ended connections are skipped.
   */
  private async closeConnection(conn: Redis) {
    if (conn.status === "ready") {
      await conn.quit();
    } else if (conn.status !== "end") {
      conn.disconnect();
    }
  }
}

/**
 * Normalizes a single value or an array into an array. An array argument is
 * returned as-is (same reference, not copied); anything else is wrapped in a
 * new one-element array.
 */
function castArray<T>(value: T | T[]) {
  if (Array.isArray(value)) {
    return value;
  }
  return [value];
}

/**
 * Builds a deterministic key string from arbitrary arguments by serializing
 * the argument list with fast-json-stable-stringify, a stable JSON serializer
 * intended to give the same output regardless of object key order.
 *
 * @param args Values to combine into the key; presumably they should be
 *   JSON-serializable — confirm for your inputs.
 * @returns The stable JSON serialization of the `args` array.
 */
export function key(...args: unknown[]): string {
  return stringify(args);
}

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions