Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions db/sqlc/querier.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions db/sqlc/queries/sessions.sql
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ WHERE id = $2;
DELETE FROM sessions
WHERE state = $1;

-- name: DeleteSession :exec
DELETE FROM sessions
WHERE id = $1;

-- name: GetSessionByLocalPublicKey :one
SELECT * FROM sessions
WHERE local_public_key = $1;
Expand Down
10 changes: 10 additions & 0 deletions db/sqlc/sessions.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions session/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,10 @@ type Store interface {
// StateReserved state.
DeleteReservedSessions(ctx context.Context) error

// DeleteReservedSession deletes the session with the given ID if it is
// in the StateReserved state.
DeleteReservedSession(ctx context.Context, id ID) error

// ShiftState updates the state of the session with the given ID to the
// "dest" state.
ShiftState(ctx context.Context, id ID, dest State) error
Expand Down
154 changes: 104 additions & 50 deletions session/kvdb_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,10 @@ func (db *BoltStore) DeleteReservedSessions(_ context.Context) error {
return err
}

return sessionBucket.ForEach(func(k, v []byte) error {
// We create a copy of the sessions to delete so that we are
// not iterating and modifying the bucket at the same time.
var sessionsToDelete []*Session
err = sessionBucket.ForEach(func(k, v []byte) error {
// We'll also get buckets here, skip those (identified
// by nil value).
if v == nil {
Expand All @@ -458,69 +461,120 @@ func (db *BoltStore) DeleteReservedSessions(_ context.Context) error {
return nil
}

err = sessionBucket.Delete(k)
if err != nil {
return err
}
sessionsToDelete = append(sessionsToDelete, session)

idIndexBkt := sessionBucket.Bucket(idIndexKey)
if idIndexBkt == nil {
return ErrDBInitErr
}
return nil
})
if err != nil {
return err
}

// Delete the entire session ID bucket.
err = idIndexBkt.DeleteBucket(session.ID[:])
if err != nil {
for _, session := range sessionsToDelete {
if err := deleteSession(sessionBucket,
session); err != nil {
return err
}
}

groupIdIndexBkt := sessionBucket.Bucket(groupIDIndexKey)
if groupIdIndexBkt == nil {
return ErrDBInitErr
}
return nil
})
}

groupBkt := groupIdIndexBkt.Bucket(session.GroupID[:])
if groupBkt == nil {
return ErrDBInitErr
}
// deleteSession deletes all the parts of a session from the database. This
// assumes that the session has already been fetched from the db.
func deleteSession(sessionBucket *bbolt.Bucket, session *Session) error {
sessionKey := getSessionKey(session)
err := sessionBucket.Delete(sessionKey)
if err != nil {
return err
}

sessionIDsBkt := groupBkt.Bucket(sessionIDKey)
if sessionIDsBkt == nil {
return ErrDBInitErr
}
idIndexBkt := sessionBucket.Bucket(idIndexKey)
if idIndexBkt == nil {
return ErrDBInitErr
}

var (
seqKey []byte
numSessions int
)
err = sessionIDsBkt.ForEach(func(k, v []byte) error {
numSessions++
// Delete the entire session ID bucket.
err = idIndexBkt.DeleteBucket(session.ID[:])
if err != nil {
return err
}

if !bytes.Equal(v, session.ID[:]) {
return nil
}
groupIdIndexBkt := sessionBucket.Bucket(groupIDIndexKey)
if groupIdIndexBkt == nil {
return ErrDBInitErr
}

seqKey = k
groupBkt := groupIdIndexBkt.Bucket(session.GroupID[:])
if groupBkt == nil {
return ErrDBInitErr
}

return nil
})
if err != nil {
return err
}
sessionIDsBkt := groupBkt.Bucket(sessionIDKey)
if sessionIDsBkt == nil {
return ErrDBInitErr
}

if numSessions == 0 {
return fmt.Errorf("no sessions found for "+
"group ID %x", session.GroupID)
}
var (
seqKey []byte
numSessions int
)
err = sessionIDsBkt.ForEach(func(k, v []byte) error {
numSessions++

if numSessions == 1 {
// Delete the whole group bucket.
return groupBkt.DeleteBucket(sessionIDKey)
}
if !bytes.Equal(v, session.ID[:]) {
return nil
}

// Else, delete just the session ID entry.
return sessionIDsBkt.Delete(seqKey)
})
seqKey = k

return nil
})
if err != nil {
return err
}

if numSessions == 0 {
return fmt.Errorf("no sessions found for "+
"group ID %x", session.GroupID)
}

if numSessions == 1 {
// If this is the last session in the group, we can delete the
// whole group bucket.
return groupIdIndexBkt.DeleteBucket(session.GroupID[:])
}

// Else, delete just the session ID entry from the group.
return sessionIDsBkt.Delete(seqKey)
}

// DeleteReservedSession removes a given session that is in the reserved state
// from the database.
//
// NOTE: This is part of the Store interface.
func (db *BoltStore) DeleteReservedSession(_ context.Context, id ID) error {
return db.Update(func(tx *bbolt.Tx) error {
sessionBucket, err := getBucket(tx, sessionBucketKey)
if err != nil {
return err
}

// We'll first get the session to make sure it's actually in the
// reserved state before deleting. This gives us a slightly
// better error message than just trying to delete and getting a
// "not found" if the session was in another state.
session, err := getSessionByID(sessionBucket, id)
if err != nil {
return err
}

if session.State != StateReserved {
return fmt.Errorf("session not in reserved state, is "+
"%v", session.State)
}

return deleteSession(sessionBucket, session)
})
}

Expand Down
25 changes: 25 additions & 0 deletions session/sql_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ type SQLQueries interface {
SetSessionGroupID(ctx context.Context, arg sqlc.SetSessionGroupIDParams) error
UpdateSessionState(ctx context.Context, arg sqlc.UpdateSessionStateParams) error
DeleteSessionsWithState(ctx context.Context, state int16) error
DeleteSession(ctx context.Context, id int64) error
GetAccountIDByAlias(ctx context.Context, alias int64) (int64, error)
GetAccount(ctx context.Context, id int64) (sqlc.Account, error)
}
Expand Down Expand Up @@ -431,6 +432,30 @@ func (s *SQLStore) DeleteReservedSessions(ctx context.Context) error {
})
}

// DeleteReservedSession removes a given session that is in the reserved state
// from the database.
//
// NOTE: This is part of the Store interface.
func (s *SQLStore) DeleteReservedSession(ctx context.Context, id ID) error {
var writeTxOpts db.QueriesTxOptions
return s.db.ExecTx(ctx, &writeTxOpts, func(db SQLQueries) error {
session, err := db.GetSessionByAlias(ctx, id[:])
if errors.Is(err, sql.ErrNoRows) {
return fmt.Errorf("%w: unable to get session: %w",
ErrSessionNotFound, err)
} else if err != nil {
return fmt.Errorf("unable to get session: %w", err)
}

if State(session.State) != StateReserved {
return fmt.Errorf("session not in reserved state, is "+
"%v", State(session.State))
}

return db.DeleteSession(ctx, session.ID)
})
}

// GetSessionByLocalPub fetches the session with the given local pub key.
//
// NOTE: This is part of the Store interface.
Expand Down
26 changes: 26 additions & 0 deletions session/store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,10 @@ func TestBasicSessionStore(t *testing.T) {
// of the sessions are reserved.
require.NoError(t, db.DeleteReservedSessions(ctx))

// Explicitly trying to delete session 1 should fail as it's not
// reserved.
require.Error(t, db.DeleteReservedSession(ctx, s1.ID))

sessions, err = db.ListSessionsByState(ctx, StateReserved)
require.NoError(t, err)
require.Empty(t, sessions)
Expand Down Expand Up @@ -192,6 +196,28 @@ func TestBasicSessionStore(t *testing.T) {
_, err = db.GetGroupID(ctx, s4.ID)
require.ErrorIs(t, err, ErrSessionNotFound)

// Reserve a new session and link it to session 1.
s5, err := reserveSession(
db, "session 5", withLinkedGroupID(&session1.GroupID),
)
require.NoError(t, err)
sessions, err = db.ListSessionsByState(ctx, StateReserved)
require.NoError(t, err)
require.Equal(t, 1, len(sessions))
assertEqualSessions(t, s5, sessions[0])

// Now delete the reserved session by its ID and show that it is no
// longer in the database and no longer in the group ID/session ID
// index.
require.NoError(t, db.DeleteReservedSession(ctx, s5.ID))

sessions, err = db.ListSessionsByState(ctx, StateReserved)
require.NoError(t, err)
require.Empty(t, sessions)

_, err = db.GetGroupID(ctx, s5.ID)
require.ErrorIs(t, err, ErrSessionNotFound)

// Only session 1 should remain in this group.
sessIDs, err = db.GetSessionIDs(ctx, s4.GroupID)
require.NoError(t, err)
Expand Down
16 changes: 14 additions & 2 deletions session_rpcserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -1213,6 +1213,18 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context,
return nil, fmt.Errorf("error creating new session: %v", err)
}

// If we tried to link to a previous session, we delete the newly
// created session in the case of errors to avoid having non-revoked
// sessions lying around.
fail := func(err error) error {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Not sure we need a separate lambda function for this when it's only being called once?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

right, I wasn't sure if we also want to delete the session on an error later on, so I left it for now

if len(req.LinkedGroupId) != 0 {
err := s.cfg.db.DeleteReservedSession(ctx, sess.ID)
log.Errorf("error deleting session after failed "+
"linking attempt: %v", err)
}
return err
}

// If this session is being linked to a previous one, then we need to
// use the previous session's local private key to sign the new
// session's public key in order to prove to the Autopilot server that
Expand Down Expand Up @@ -1286,8 +1298,8 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context,
privacyFlags.Serialize(),
)
if err != nil {
return nil, fmt.Errorf("error registering session with "+
"autopilot server: %v", err)
return nil, fail(fmt.Errorf("error registering session with "+
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given that we're not 100% sure here that this errors due to the session not having been stored on the autopilot server (as this could simplify fail for loss of the connection mid execution for example), I can see the motivation still persisting this on the litd node, to aim to not have sessions stored on the autopilot server, which do not exist on the corresponding litd node.
But since that can already occur for various other reasons as well, I think it's fine to just delete it.

"autopilot server: %v", err))
}

err = s.cfg.db.UpdateSessionRemotePubKey(ctx, sess.ID, remoteKey)
Expand Down
Loading