Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ import type {
} from "rivetkit";
import { lookupInRegistry } from "rivetkit";
import type { Client } from "rivetkit/client";
import {
type ActorDriver,
type AnyActorInstance,
type ManagerDriver,
import type {
ActorDriver,
AnyActorInstance,
ManagerDriver,
} from "rivetkit/driver-helpers";
import { promiseWithResolvers } from "rivetkit/utils";
import { KEYS } from "./actor-handler-do";
Expand Down Expand Up @@ -239,7 +239,6 @@ export class CloudflareActorsActorDriver implements ActorDriver {
// Persist data key
return Uint8Array.from([1]);
}

}

export function createCloudflareActorsActorDriverBuilder(
Expand Down
157 changes: 49 additions & 108 deletions rivetkit-typescript/packages/rivetkit/src/actor/conn/mod.ts
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
import * as cbor from "cbor-x";
import onChange from "on-change";
import { isCborSerializable } from "@/common/utils";
import type * as protocol from "@/schemas/client-protocol/mod";
import { TO_CLIENT_VERSIONED } from "@/schemas/client-protocol/versioned";
import { arrayBuffersEqual, bufferToArrayBuffer } from "@/utils";
import type { AnyDatabaseProvider } from "../database";
import * as errors from "../errors";
import {
ACTOR_INSTANCE_PERSIST_SYMBOL,
type ActorInstance,
} from "../instance/mod";
import type { PersistedConn } from "../instance/persisted";
import { CachedSerializer } from "../protocol/serde";
import type { ConnDriver } from "./driver";
import { StateManager } from "./state-manager";

export function generateConnRequestId(): string {
return crypto.randomUUID();
Expand All @@ -24,6 +22,12 @@ export type AnyConn = Conn<any, any, any, any, any, any>;

export const CONN_PERSIST_SYMBOL = Symbol("persist");
export const CONN_DRIVER_SYMBOL = Symbol("driver");
export const CONN_ACTOR_SYMBOL = Symbol("actor");
export const CONN_STATE_ENABLED_SYMBOL = Symbol("stateEnabled");
export const CONN_PERSIST_RAW_SYMBOL = Symbol("persistRaw");
export const CONN_HAS_CHANGES_SYMBOL = Symbol("hasChanges");
export const CONN_MARK_SAVED_SYMBOL = Symbol("markSaved");
export const CONN_SEND_MESSAGE_SYMBOL = Symbol("sendMessage");

/**
* Represents a client connection to a actor.
Expand All @@ -38,72 +42,66 @@ export class Conn<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
// TODO: Remove this cyclical reference
#actor: ActorInstance<S, CP, CS, V, I, DB>;

/**
* The proxied state that notifies of changes automatically.
*
* Any data that should be stored indefinitely should be held within this
* object.
*
* This will only be persisted if using hibernatable WebSockets. If not,
* this is just used to hole state.
*/
[CONN_PERSIST_SYMBOL]!: PersistedConn<CP, CS>;

/** Raw persist object without the proxy wrapper */
#persistRaw: PersistedConn<CP, CS>;

/** Track if this connection's state has changed */
#changed = false;
// MARK: - Managers
#stateManager!: StateManager<CP, CS>;

/**
* If undefined, then nothing is connected to this.
*/
[CONN_DRIVER_SYMBOL]?: ConnDriver;

public get params(): CP {
return this[CONN_PERSIST_SYMBOL].params;
// MARK: - Public Getters

get [CONN_ACTOR_SYMBOL](): ActorInstance<S, CP, CS, V, I, DB> {
return this.#actor;
}

public get stateEnabled() {
return this.#actor.connStateEnabled;
get [CONN_PERSIST_SYMBOL](): PersistedConn<CP, CS> {
return this.#stateManager.persist;
}

get params(): CP {
return this.#stateManager.params;
}

get [CONN_STATE_ENABLED_SYMBOL](): boolean {
return this.#stateManager.stateEnabled;
}

/**
* Gets the current state of the connection.
*
* Throws an error if the state is not enabled.
*/
public get state(): CS {
this.#validateStateEnabled();
if (!this[CONN_PERSIST_SYMBOL].state)
throw new Error("state should exists");
return this[CONN_PERSIST_SYMBOL].state;
get state(): CS {
return this.#stateManager.state;
}

/**
* Sets the state of the connection.
*
* Throws an error if the state is not enabled.
*/
public set state(value: CS) {
this.#validateStateEnabled();
this[CONN_PERSIST_SYMBOL].state = value;
set state(value: CS) {
this.#stateManager.state = value;
}

/**
* Unique identifier for the connection.
*/
public get id(): ConnId {
return this[CONN_PERSIST_SYMBOL].connId;
get id(): ConnId {
return this.#stateManager.persist.connId;
}

/**
* @experimental
*
* If the underlying connection can hibernate.
*/
public get isHibernatable(): boolean {
if (!this[CONN_PERSIST_SYMBOL].hibernatableRequestId) {
get isHibernatable(): boolean {
const hibernatableRequestId =
this.#stateManager.persist.hibernatableRequestId;
if (!hibernatableRequestId) {
return false;
}
return (
Expand All @@ -112,7 +110,7 @@ export class Conn<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
].hibernatableConns.findIndex((conn: any) =>
arrayBuffersEqual(
conn.hibernatableRequestId,
this[CONN_PERSIST_SYMBOL].hibernatableRequestId!,
hibernatableRequestId,
),
) > -1
);
Expand All @@ -121,8 +119,8 @@ export class Conn<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
/**
* Timestamp of the last time the connection was seen, i.e. the last time the connection was active and checked for liveness.
*/
public get lastSeen(): number {
return this[CONN_PERSIST_SYMBOL].lastSeen;
get lastSeen(): number {
return this.#stateManager.persist.lastSeen;
}

/**
Expand All @@ -132,94 +130,37 @@ export class Conn<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
*
* @protected
*/
public constructor(
constructor(
actor: ActorInstance<S, CP, CS, V, I, DB>,
persist: PersistedConn<CP, CS>,
) {
this.#actor = actor;
this.#persistRaw = persist;
this.#setupPersistProxy(persist);
}

/**
* Sets up the proxy for connection persistence with change tracking
*/
#setupPersistProxy(persist: PersistedConn<CP, CS>) {
// If this can't be proxied, return raw value
if (persist === null || typeof persist !== "object") {
this[CONN_PERSIST_SYMBOL] = persist;
return;
}

// Listen for changes to the object
this[CONN_PERSIST_SYMBOL] = onChange(
persist,
(
path: string,
value: any,
_previousValue: any,
_applyData: any,
) => {
// Validate CBOR serializability for state changes
if (path.startsWith("state")) {
let invalidPath = "";
if (
!isCborSerializable(
value,
(invalidPathPart: string) => {
invalidPath = invalidPathPart;
},
"",
)
) {
throw new errors.InvalidStateType({
path: path + (invalidPath ? `.${invalidPath}` : ""),
});
}
}

this.#changed = true;
this.#actor.rLog.debug({
msg: "conn onChange triggered",
connId: this.id,
path,
});

// Notify actor that this connection has changed
this.#actor.markConnChanged(this);
},
{ ignoreDetached: true },
);
this.#stateManager = new StateManager(this);
this.#stateManager.initPersistProxy(persist);
}

/**
* Returns whether this connection has unsaved changes
*/
get hasChanges(): boolean {
return this.#changed;
[CONN_HAS_CHANGES_SYMBOL](): boolean {
return this.#stateManager.hasChanges();
}

/**
* Marks changes as saved
*/
markSaved() {
this.#changed = false;
[CONN_MARK_SAVED_SYMBOL]() {
this.#stateManager.markSaved();
}

/**
* Gets the raw persist data for serialization
*/
get persistRaw(): PersistedConn<CP, CS> {
return this.#persistRaw;
}

#validateStateEnabled() {
if (!this.stateEnabled) {
throw new errors.ConnStateNotEnabled();
}
get [CONN_PERSIST_RAW_SYMBOL](): PersistedConn<CP, CS> {
return this.#stateManager.persistRaw;
}

public sendMessage(message: CachedSerializer<protocol.ToClient>) {
[CONN_SEND_MESSAGE_SYMBOL](message: CachedSerializer<protocol.ToClient>) {
if (this[CONN_DRIVER_SYMBOL]) {
const driver = this[CONN_DRIVER_SYMBOL];
if (driver.sendMessage) {
Expand All @@ -245,14 +186,14 @@ export class Conn<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
* @param args - The arguments for the event.
* @see {@link https://rivet.dev/docs/events|Events Documentation}
*/
public send(eventName: string, ...args: unknown[]) {
send(eventName: string, ...args: unknown[]) {
this.#actor.inspector.emitter.emit("eventFired", {
type: "event",
eventName,
args,
connId: this.id,
});
this.sendMessage(
this[CONN_SEND_MESSAGE_SYMBOL](
new CachedSerializer<protocol.ToClient>(
{
body: {
Expand All @@ -273,7 +214,7 @@ export class Conn<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
*
* @param reason - The reason for disconnection.
*/
public async disconnect(reason?: string) {
async disconnect(reason?: string) {
if (this[CONN_DRIVER_SYMBOL]) {
const driver = this[CONN_DRIVER_SYMBOL];
if (driver.disconnect) {
Expand Down
Loading
Loading