diff --git a/db/sqlc/querier.go b/db/sqlc/querier.go index 3e2615030..6e8e7ed7d 100644 --- a/db/sqlc/querier.go +++ b/db/sqlc/querier.go @@ -17,6 +17,7 @@ type Querier interface { DeleteFeatureKVStoreRecord(ctx context.Context, arg DeleteFeatureKVStoreRecordParams) error DeleteGlobalKVStoreRecord(ctx context.Context, arg DeleteGlobalKVStoreRecordParams) error DeleteGroupKVStoreRecord(ctx context.Context, arg DeleteGroupKVStoreRecordParams) error + DeleteSession(ctx context.Context, id int64) error DeleteSessionsWithState(ctx context.Context, state int16) error GetAccount(ctx context.Context, id int64) (Account, error) GetAccountByLabel(ctx context.Context, label sql.NullString) (Account, error) diff --git a/db/sqlc/queries/sessions.sql b/db/sqlc/queries/sessions.sql index 2d0442cce..57dbdef31 100644 --- a/db/sqlc/queries/sessions.sql +++ b/db/sqlc/queries/sessions.sql @@ -18,6 +18,10 @@ WHERE id = $2; DELETE FROM sessions WHERE state = $1; +-- name: DeleteSession :exec +DELETE FROM sessions +WHERE id = $1; + -- name: GetSessionByLocalPublicKey :one SELECT * FROM sessions WHERE local_public_key = $1; diff --git a/db/sqlc/sessions.sql.go b/db/sqlc/sessions.sql.go index a7ff55185..72a7ab3bd 100644 --- a/db/sqlc/sessions.sql.go +++ b/db/sqlc/sessions.sql.go @@ -11,6 +11,16 @@ import ( "time" ) +const deleteSession = `-- name: DeleteSession :exec +DELETE FROM sessions +WHERE id = $1 +` + +func (q *Queries) DeleteSession(ctx context.Context, id int64) error { + _, err := q.db.ExecContext(ctx, deleteSession, id) + return err +} + const deleteSessionsWithState = `-- name: DeleteSessionsWithState :exec DELETE FROM sessions WHERE state = $1 diff --git a/session/interface.go b/session/interface.go index 36b5075fe..9c64d8956 100644 --- a/session/interface.go +++ b/session/interface.go @@ -325,6 +325,10 @@ type Store interface { // StateReserved state. DeleteReservedSessions(ctx context.Context) error + // DeleteReservedSession deletes the session with the given ID if it is + // in the StateReserved state. + DeleteReservedSession(ctx context.Context, id ID) error + // ShiftState updates the state of the session with the given ID to the // "dest" state. ShiftState(ctx context.Context, id ID, dest State) error diff --git a/session/kvdb_store.go b/session/kvdb_store.go index d52897966..da58ebf75 100644 --- a/session/kvdb_store.go +++ b/session/kvdb_store.go @@ -442,7 +442,10 @@ func (db *BoltStore) DeleteReservedSessions(_ context.Context) error { return err } - return sessionBucket.ForEach(func(k, v []byte) error { + // We create a copy of the sessions to delete so that we are + // not iterating and modifying the bucket at the same time. + var sessionsToDelete []*Session + err = sessionBucket.ForEach(func(k, v []byte) error { // We'll also get buckets here, skip those (identified // by nil value). if v == nil { @@ -458,69 +461,120 @@ func (db *BoltStore) DeleteReservedSessions(_ context.Context) error { return nil } - err = sessionBucket.Delete(k) - if err != nil { - return err - } + sessionsToDelete = append(sessionsToDelete, session) - idIndexBkt := sessionBucket.Bucket(idIndexKey) - if idIndexBkt == nil { - return ErrDBInitErr - } + return nil + }) + if err != nil { + return err + } - // Delete the entire session ID bucket. - err = idIndexBkt.DeleteBucket(session.ID[:]) - if err != nil { + for _, session := range sessionsToDelete { + if err := deleteSession(sessionBucket, + session); err != nil { return err } + } - groupIdIndexBkt := sessionBucket.Bucket(groupIDIndexKey) - if groupIdIndexBkt == nil { - return ErrDBInitErr - } + return nil + }) +} - groupBkt := groupIdIndexBkt.Bucket(session.GroupID[:]) - if groupBkt == nil { - return ErrDBInitErr - } +// deleteSession deletes all the parts of a session from the database. This +// assumes that the session has already been fetched from the db. +func deleteSession(sessionBucket *bbolt.Bucket, session *Session) error { + sessionKey := getSessionKey(session) + err := sessionBucket.Delete(sessionKey) + if err != nil { + return err + } - sessionIDsBkt := groupBkt.Bucket(sessionIDKey) - if sessionIDsBkt == nil { - return ErrDBInitErr - } + idIndexBkt := sessionBucket.Bucket(idIndexKey) + if idIndexBkt == nil { + return ErrDBInitErr + } - var ( - seqKey []byte - numSessions int - ) - err = sessionIDsBkt.ForEach(func(k, v []byte) error { - numSessions++ + // Delete the entire session ID bucket. + err = idIndexBkt.DeleteBucket(session.ID[:]) + if err != nil { + return err + } - if !bytes.Equal(v, session.ID[:]) { - return nil - } + groupIdIndexBkt := sessionBucket.Bucket(groupIDIndexKey) + if groupIdIndexBkt == nil { + return ErrDBInitErr + } - seqKey = k + groupBkt := groupIdIndexBkt.Bucket(session.GroupID[:]) + if groupBkt == nil { + return ErrDBInitErr + } - return nil - }) - if err != nil { - return err - } + sessionIDsBkt := groupBkt.Bucket(sessionIDKey) + if sessionIDsBkt == nil { + return ErrDBInitErr + } - if numSessions == 0 { - return fmt.Errorf("no sessions found for "+ - "group ID %x", session.GroupID) - } + var ( + seqKey []byte + numSessions int + ) + err = sessionIDsBkt.ForEach(func(k, v []byte) error { + numSessions++ - if numSessions == 1 { - // Delete the whole group bucket. - return groupBkt.DeleteBucket(sessionIDKey) - } + if !bytes.Equal(v, session.ID[:]) { + return nil + } - // Else, delete just the session ID entry. - return sessionIDsBkt.Delete(seqKey) - }) + seqKey = k + + return nil + }) + if err != nil { + return err + } + + if numSessions == 0 { + return fmt.Errorf("no sessions found for "+ + "group ID %x", session.GroupID) + } + + if numSessions == 1 { + // If this is the last session in the group, we can delete the + // whole group bucket. + return groupIdIndexBkt.DeleteBucket(session.GroupID[:]) + } + + // Else, delete just the session ID entry from the group. + return sessionIDsBkt.Delete(seqKey) +} + +// DeleteReservedSession removes a given session that is in the reserved state +// from the database. +// +// NOTE: This is part of the Store interface. +func (db *BoltStore) DeleteReservedSession(_ context.Context, id ID) error { + return db.Update(func(tx *bbolt.Tx) error { + sessionBucket, err := getBucket(tx, sessionBucketKey) + if err != nil { + return err + } + + // We'll first get the session to make sure it's actually in the + // reserved state before deleting. This gives us a slightly + // better error message than just trying to delete and getting a + // "not found" if the session was in another state. + session, err := getSessionByID(sessionBucket, id) + if err != nil { + return err + } + + if session.State != StateReserved { + return fmt.Errorf("session not in reserved state, is "+ + "%v", session.State) + } + + return deleteSession(sessionBucket, session) }) } diff --git a/session/sql_store.go b/session/sql_store.go index b1d366fe7..a82fd0e63 100644 --- a/session/sql_store.go +++ b/session/sql_store.go @@ -45,6 +45,7 @@ type SQLQueries interface { SetSessionGroupID(ctx context.Context, arg sqlc.SetSessionGroupIDParams) error UpdateSessionState(ctx context.Context, arg sqlc.UpdateSessionStateParams) error DeleteSessionsWithState(ctx context.Context, state int16) error + DeleteSession(ctx context.Context, id int64) error GetAccountIDByAlias(ctx context.Context, alias int64) (int64, error) GetAccount(ctx context.Context, id int64) (sqlc.Account, error) } @@ -431,6 +432,30 @@ func (s *SQLStore) DeleteReservedSessions(ctx context.Context) error { }) } +// DeleteReservedSession removes a given session that is in the reserved state +// from the database. +// +// NOTE: This is part of the Store interface. +func (s *SQLStore) DeleteReservedSession(ctx context.Context, id ID) error { + var writeTxOpts db.QueriesTxOptions + return s.db.ExecTx(ctx, &writeTxOpts, func(db SQLQueries) error { + session, err := db.GetSessionByAlias(ctx, id[:]) + if errors.Is(err, sql.ErrNoRows) { + return fmt.Errorf("%w: unable to get session: %w", + ErrSessionNotFound, err) + } else if err != nil { + return fmt.Errorf("unable to get session: %w", err) + } + + if State(session.State) != StateReserved { + return fmt.Errorf("session not in reserved state, is "+ + "%v", State(session.State)) + } + + return db.DeleteSession(ctx, session.ID) + }) +} + // GetSessionByLocalPub fetches the session with the given local pub key. // // NOTE: This is part of the Store interface. diff --git a/session/store_test.go b/session/store_test.go index 4b1b7d3bb..98b632919 100644 --- a/session/store_test.go +++ b/session/store_test.go @@ -156,6 +156,10 @@ func TestBasicSessionStore(t *testing.T) { // of the sessions are reserved. require.NoError(t, db.DeleteReservedSessions(ctx)) + // Explicitly trying to delete session 1 should fail as it's not + // reserved. + require.Error(t, db.DeleteReservedSession(ctx, s1.ID)) + sessions, err = db.ListSessionsByState(ctx, StateReserved) require.NoError(t, err) require.Empty(t, sessions) @@ -192,6 +196,28 @@ func TestBasicSessionStore(t *testing.T) { _, err = db.GetGroupID(ctx, s4.ID) require.ErrorIs(t, err, ErrSessionNotFound) + // Reserve a new session and link it to session 1. + s5, err := reserveSession( + db, "session 5", withLinkedGroupID(&session1.GroupID), + ) + require.NoError(t, err) + sessions, err = db.ListSessionsByState(ctx, StateReserved) + require.NoError(t, err) + require.Equal(t, 1, len(sessions)) + assertEqualSessions(t, s5, sessions[0]) + + // Now delete the reserved session by its ID and show that it is no + // longer in the database and no longer in the group ID/session ID + // index. + require.NoError(t, db.DeleteReservedSession(ctx, s5.ID)) + + sessions, err = db.ListSessionsByState(ctx, StateReserved) + require.NoError(t, err) + require.Empty(t, sessions) + + _, err = db.GetGroupID(ctx, s5.ID) + require.ErrorIs(t, err, ErrSessionNotFound) + // Only session 1 should remain in this group. sessIDs, err = db.GetSessionIDs(ctx, s4.GroupID) require.NoError(t, err) diff --git a/session_rpcserver.go b/session_rpcserver.go index 0d56bd79b..131ca1887 100644 --- a/session_rpcserver.go +++ b/session_rpcserver.go @@ -1213,6 +1213,18 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context, return nil, fmt.Errorf("error creating new session: %v", err) } + // If we tried to link to a previous session, we delete the newly + // created session in the case of errors to avoid having non-revoked + // sessions lying around. + fail := func(err error) error { + if len(req.LinkedGroupId) != 0 { + err := s.cfg.db.DeleteReservedSession(ctx, sess.ID) + log.Errorf("error deleting session after failed "+ + "linking attempt: %v", err) + } + return err + } + // If this session is being linked to a previous one, then we need to // use the previous session's local private key to sign the new // session's public key in order to prove to the Autopilot server that @@ -1286,8 +1298,8 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context, privacyFlags.Serialize(), ) if err != nil { - return nil, fmt.Errorf("error registering session with "+ - "autopilot server: %v", err) + return nil, fail(fmt.Errorf("error registering session with "+ + "autopilot server: %v", err)) } err = s.cfg.db.UpdateSessionRemotePubKey(ctx, sess.ID, remoteKey)