From 5bcef5d77a0e3ee471afd4bdc5b6f940dc0a7410 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alfonso=20S=C3=A1nchez-Beato?= Date: Thu, 16 Apr 2026 07:44:21 -0400 Subject: [PATCH 01/21] core-initrd: update changelog with latest PPA upload --- core-initrd/24.04/debian/changelog | 6 ++++++ core-initrd/25.10/debian/changelog | 6 ++++++ core-initrd/26.04/debian/changelog | 6 ++++++ 3 files changed, 18 insertions(+) diff --git a/core-initrd/24.04/debian/changelog b/core-initrd/24.04/debian/changelog index 637d6b69667..1ddb1df6430 100644 --- a/core-initrd/24.04/debian/changelog +++ b/core-initrd/24.04/debian/changelog @@ -1,3 +1,9 @@ +ubuntu-core-initramfs (69+2.75.2+g199.9a8c2f3+24.04) noble; urgency=medium + + * Update to snapd version 2.75.2+g199.9a8c2f3 + + -- Alfonso Sanchez-Beato Wed, 15 Apr 2026 16:49:42 -0400 + ubuntu-core-initramfs (69+2.75+g75.4b39daa+24.04) noble; urgency=medium * Update to snapd version 2.75+g75.4b39daa diff --git a/core-initrd/25.10/debian/changelog b/core-initrd/25.10/debian/changelog index 62279b587ec..96fe311a6a7 100644 --- a/core-initrd/25.10/debian/changelog +++ b/core-initrd/25.10/debian/changelog @@ -1,3 +1,9 @@ +ubuntu-core-initramfs (72+2.75.2+g199.9a8c2f3+25.10) questing; urgency=medium + + * Update to snapd version 2.75.2+g199.9a8c2f3 + + -- Alfonso Sanchez-Beato Wed, 15 Apr 2026 16:50:12 -0400 + ubuntu-core-initramfs (72+2.75+g75.4b39daa+25.10) questing; urgency=medium * Update to snapd version 2.75+g75.4b39daa diff --git a/core-initrd/26.04/debian/changelog b/core-initrd/26.04/debian/changelog index ed61e99486d..0b08697e2b6 100644 --- a/core-initrd/26.04/debian/changelog +++ b/core-initrd/26.04/debian/changelog @@ -1,3 +1,9 @@ +ubuntu-core-initramfs (73+2.75.2+g199.9a8c2f3+26.04) resolute; urgency=medium + + * Update to snapd version 2.75.2+g199.9a8c2f3 + + -- Alfonso Sanchez-Beato Wed, 15 Apr 2026 16:50:37 -0400 + ubuntu-core-initramfs (73+2.75+g75.4b39daa+26.04) resolute; urgency=medium * Update to snapd version 2.75+g75.4b39daa From 2da6fd0063a7617a0d06368848f665775cf9807f Mon Sep 17 00:00:00 2001 From: Miguel Pires Date: Mon, 20 Apr 2026 14:23:46 +0100 Subject: [PATCH 02/21] o/confdbstate: check for ephemeral change when missing save-view hook on commit (#16889) Although we error early if we can tell that a write affects ephemeral data but no save-view hook is present, a change-view hook may have written to an ephemeral path after that initial check so we need to check again before committing. Signed-off-by: Miguel Pires --- overlord/confdbstate/confdbmgr.go | 40 ++++++++++++++++ overlord/confdbstate/confdbmgr_test.go | 64 +++++++++++++++++++++++++- overlord/confdbstate/confdbstate.go | 2 + 3 files changed, 105 insertions(+), 1 deletion(-) diff --git a/overlord/confdbstate/confdbmgr.go b/overlord/confdbstate/confdbmgr.go index 7146beca59e..4ff4874fd1c 100644 --- a/overlord/confdbstate/confdbmgr.go +++ b/overlord/confdbstate/confdbmgr.go @@ -105,6 +105,46 @@ func (m *ConfdbManager) doCommitTransaction(t *state.Task, _ *tomb.Tomb) (err er } schema := confdbAssert.Schema().DatabagSchema + hasSaveViewHook := false + for _, task := range t.Change().Tasks() { + if task.Kind() != "run-hook" { + continue + } + + var hooksup hookstate.HookSetup + err := task.Get("hook-setup", &hooksup) + if err != nil { + return fmt.Errorf(`internal error: cannot get "hook-setup" from run-hook task: %w`, err) + } + + if strings.HasPrefix(hooksup.Hook, "save-view-") { + hasSaveViewHook = true + break + } + } + + // we error early if a write may affect ephemeral data but no save-view hook + // is present. However, a change-view hook may have written to an ephemeral + // path after that so we have to check again + if !hasSaveViewHook { + var viewName string + err = t.Get("view", &viewName) + if err != nil { + return fmt.Errorf(`internal error: cannot get "view" from task: %w`, err) + } + + view := confdbAssert.Schema().View(viewName) + paths := tx.AlteredPaths() + mightAffectEph, err := view.WriteAffectsEphemeral(paths) + if err != nil { + return fmt.Errorf("cannot commit transaction: cannot check for ephemeral paths: %v", err) + } + + if mightAffectEph { + return fmt.Errorf("cannot commit transaction: write may affect ephemeral data but no save-view hook is present") + } + } + return tx.Commit(st, schema) } diff --git a/overlord/confdbstate/confdbmgr_test.go b/overlord/confdbstate/confdbmgr_test.go index c1c60861aad..f9e05a44e65 100644 --- a/overlord/confdbstate/confdbmgr_test.go +++ b/overlord/confdbstate/confdbmgr_test.go @@ -19,6 +19,7 @@ package confdbstate_test import ( + "context" "errors" "strings" "time" @@ -33,6 +34,7 @@ import ( "github.com/snapcore/snapd/overlord/ifacestate/ifacerepo" "github.com/snapcore/snapd/overlord/state" "github.com/snapcore/snapd/testutil" + "gopkg.in/tomb.v2" . "gopkg.in/check.v1" ) @@ -442,6 +444,7 @@ func (s *confdbTestSuite) TestCommitTransaction(c *C) { c.Assert(err, IsNil) setTransaction(t, tx) + t.Set("view", "setup-wifi") s.state.Unlock() err = s.o.Settle(testutil.HostScaledTimeout(5 * time.Second)) @@ -518,6 +521,7 @@ func (s *confdbTestSuite) TestClearTransactionOnError(c *C) { err = tx.Set(parsePath(c, "foo"), "bar") c.Assert(err, IsNil) setTransaction(commitTask, tx) + commitTask.Set("view", "setup-wifi") // add this transaction to the state err = confdbstate.SetWriteTransaction(s.state, s.devAccID, "network", commitTask.ID()) @@ -531,10 +535,68 @@ func (s *confdbTestSuite) TestClearTransactionOnError(c *C) { c.Assert(chg.Status(), Equals, state.ErrorStatus) c.Assert(commitTask.Status(), Equals, state.ErrorStatus) c.Assert(clearTask.Status(), Equals, state.UndoneStatus) - c.Assert(strings.Join(commitTask.Log(), "\n"), Matches, ".*ERROR cannot accept top level element: map contains unexpected key \"foo\"") + c.Assert(strings.Join(commitTask.Log(), "\n"), Matches, ".*ERROR cannot commit transaction: cannot check for ephemeral paths: cannot check if write affects ephemeral data: cannot use \"foo\" as key in map") // no ongoing confdb transaction var ongoingTxs map[string]*confdbstate.ConfdbTransactions err = s.state.Get("confdb-ongoing-txs", &ongoingTxs) c.Assert(err, testutil.ErrorIs, &state.NoStateError{}) } + +func (s *confdbTestSuite) TestCommitTransactionEphemeralCheckWithoutSaveViewHooks(c *C) { + s.state.Lock() + defer s.state.Unlock() + + // the custodian has a change-view hook but no save-view + custodians := map[string]confdbHooks{"custodian-snap": changeView} + s.setupConfdbScenario(c, custodians, nil) + + // mock a change-view hook that writes to ephemeral data + restore := hookstate.MockRunHook(func(ctx *hookstate.Context, _ *tomb.Tomb) ([]byte, error) { + t, _ := ctx.Task() + ctx.State().Lock() + defer ctx.State().Unlock() + + var hooksup *hookstate.HookSetup + err := t.Get("hook-setup", &hooksup) + if err != nil { + return nil, err + } + c.Assert(strings.HasPrefix(hooksup.Hook, "change-view-"), Equals, true) + + tx, _, saveChanges, err := confdbstate.GetStoredTransaction(t) + if err != nil { + return nil, err + } + + err = tx.Set(parsePath(c, "wifi.eph"), "ephemeral-from-hook") + if err != nil { + return nil, err + } + saveChanges() + + return nil, nil + }) + defer restore() + + view, err := confdbstate.GetView(s.state, s.devAccID, "network", "setup-wifi") + c.Assert(err, IsNil) + + chgID, err := confdbstate.WriteConfdb(context.Background(), s.state, view, map[string]any{"ssid": "my-wifi"}) + c.Assert(err, IsNil) + + chg := s.state.Change(chgID) + c.Assert(chg, NotNil) + + s.state.Unlock() + err = s.o.Settle(testutil.HostScaledTimeout(5 * time.Second)) + s.state.Lock() + c.Assert(err, IsNil) + + // commit fails because change-view hook wrote ephemeral data but no save-view hooks exist + c.Assert(chg.Status(), Equals, state.ErrorStatus) + + commitTask := findTask(chg, "commit-confdb-tx") + c.Assert(commitTask, NotNil) + c.Assert(strings.Join(commitTask.Log(), "\n"), Matches, `.*ERROR cannot commit transaction: write may affect ephemeral data but no save-view hook is present.*`) +} diff --git a/overlord/confdbstate/confdbstate.go b/overlord/confdbstate/confdbstate.go index cff60e1cd3a..42449deda2a 100644 --- a/overlord/confdbstate/confdbstate.go +++ b/overlord/confdbstate/confdbstate.go @@ -525,6 +525,8 @@ func createChangeConfdbTasks(st *state.State, tx *Transaction, view *confdb.View // commit after custodians save ephemeral data commitTask := st.NewTask("commit-confdb-tx", fmt.Sprintf("Commit changes to confdb (%s)", view.ID())) commitTask.Set("confdb-transaction", tx) + commitTask.Set("view", view.Name) + // link all previous tasks to the commit task that carries the transaction for _, t := range ts.Tasks() { t.Set("tx-task", commitTask.ID()) From 017bd1b6ea28dd18f957e1e2ceb550ac9f46daf9 Mon Sep 17 00:00:00 2001 From: Katie May Date: Mon, 20 Apr 2026 17:04:46 +0200 Subject: [PATCH 03/21] tests: source nested.sh before calling function in prepare.sh (#16937) --- tests/lib/prepare.sh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/lib/prepare.sh b/tests/lib/prepare.sh index eaa03c754f0..3441b825ce2 100755 --- a/tests/lib/prepare.sh +++ b/tests/lib/prepare.sh @@ -1405,6 +1405,8 @@ setup_reflash_magic() { snap tasks --last=seed || true journalctl -u snapd snap model --verbose + #shellcheck source=tests/lib/nested.sh + . "$TESTSLIB/nested.sh" # remove the above debug lines once the mentioned bug is fixed snap install "--channel=$(nested_get_base_channel)" "$core_name" # TODO set up a trap to clean this up properly? From 538fb99f92d84dab00177fca9a11716169cbee65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alfonso=20S=C3=A1nchez-Beato?= Date: Mon, 20 Apr 2026 09:19:08 -0400 Subject: [PATCH 04/21] tests/lib/nested.sh: ensure test tools that need python can run on UC26 images, by setting appropriately the path. --- tests/lib/nested.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/lib/nested.sh b/tests/lib/nested.sh index 55cad68a63f..c99e70ddd30 100755 --- a/tests/lib/nested.sh +++ b/tests/lib/nested.sh @@ -1863,7 +1863,7 @@ nested_prepare_tools() { if ! remote.exec "grep -qE PATH=.*$TOOLS_PATH /etc/environment"; then # shellcheck disable=SC2016 REMOTE_PATH="$(remote.exec 'echo $PATH')" - remote.exec "echo PATH=$TOOLS_PATH:$REMOTE_PATH | sudo tee -a /etc/environment" + remote.exec "echo PATH=$TOOLS_PATH:$REMOTE_PATH:/usr/lib/python | sudo tee -a /etc/environment" fi if [ -n "$TAG_FEATURES" ]; then From 970e0b681ee1afbe84dd6db1285f4ba92e3748d4 Mon Sep 17 00:00:00 2001 From: Katie May Date: Tue, 21 Apr 2026 12:02:26 +0200 Subject: [PATCH 05/21] tests: use noble for lp-1871652 (#16938) --- tests/regression/lp-1871652/task.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/regression/lp-1871652/task.yaml b/tests/regression/lp-1871652/task.yaml index 32b516b4e4c..f696c293218 100644 --- a/tests/regression/lp-1871652/task.yaml +++ b/tests/regression/lp-1871652/task.yaml @@ -15,11 +15,11 @@ details: | aware of the shutdown. # Run on a system matching the guest container. -systems: [ubuntu-18.04-64] +systems: [ubuntu-24.04-64] prepare: | "$TESTSTOOLS"/lxd-state prepare-snap - "$TESTSTOOLS"/lxd-state launch --remote ubuntu --image 18.04 --name bionic + "$TESTSTOOLS"/lxd-state launch --remote ubuntu --image 24.04 --name bionic # Install snapd inside the container and then install the core snap so that # we get re-execution logic to applies as snapd in the store is more recent From 604f5e7435869611f12c1fb68214ae49ab048f28 Mon Sep 17 00:00:00 2001 From: Katie May Date: Tue, 21 Apr 2026 12:45:44 +0200 Subject: [PATCH 06/21] github: use symlink for debian folder (#16948) --- .github/workflows/deb-builds.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/deb-builds.yaml b/.github/workflows/deb-builds.yaml index 22e6951c963..e0e3005b2cb 100644 --- a/.github/workflows/deb-builds.yaml +++ b/.github/workflows/deb-builds.yaml @@ -47,10 +47,10 @@ jobs: target_system="${{ inputs.os }}-${{ inputs.os-version }}" case "$target_system" in debian-sid) - cp -av packaging/debian-sid debian + ln -sfn packaging/debian-sid debian ;; ubuntu-*) - cp -av packaging/ubuntu-16.04 debian + ln -sfn packaging/ubuntu-16.04 debian ;; *) echo "unsupported deb packaging for $target_system" From 3fbb81c2d7ba9b1caa4be873128e23948ac26856 Mon Sep 17 00:00:00 2001 From: Miguel Pires Date: Tue, 21 Apr 2026 10:55:00 +0100 Subject: [PATCH 07/21] tests: fix upgrade-from-release Signed-off-by: Miguel Pires --- tests/main/upgrade-from-release/task.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/main/upgrade-from-release/task.yaml b/tests/main/upgrade-from-release/task.yaml index 6df7040c58f..b20c72b3ed3 100644 --- a/tests/main/upgrade-from-release/task.yaml +++ b/tests/main/upgrade-from-release/task.yaml @@ -44,7 +44,7 @@ execute: | # TODO: add automatic package lookup - manual list maintenance is impractical declare -A EXPECTED_SNAPD_VERSIONS=( ["26.04"]='2.74.1\+ubuntu26.04' - ["25.10"]='2.73\+ubuntu25.10' + ["25.10"]='2.74.1\+ubuntu25.10' ["24.04"]='2.62\+24.04' ["22.04"]='2.55.3\+22.04' ["20.04"]='2.44.3\+20.04' From 582e12161f1e6e99ab10ebb1c76e094a3a10b1ae Mon Sep 17 00:00:00 2001 From: Maciej Borzecki Date: Tue, 21 Apr 2026 15:07:57 +0200 Subject: [PATCH 08/21] interfaces/builtin/content: add unit tests for non-obvious targets (#16928) Add unit tests for the non-obvious values that are accepted for target. Signed-off-by: Maciej Borzecki --- interfaces/builtin/content_test.go | 41 ++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/interfaces/builtin/content_test.go b/interfaces/builtin/content_test.go index 1e9c70387c6..793f9a2942e 100644 --- a/interfaces/builtin/content_test.go +++ b/interfaces/builtin/content_test.go @@ -20,6 +20,7 @@ package builtin_test import ( + "fmt" "path/filepath" "strings" @@ -322,6 +323,46 @@ apps: c.Assert(interfaces.BeforePreparePlug(s.iface, plug), ErrorMatches, "content plug must contain target path") } +func (s *ContentSuite) TestSanitizePlugTargetEdgeCases(c *C) { + const snapYamlTemplate = `name: content-snap +version: 1.0 +plugs: + content-plug: + interface: content + content: mycont + target: %s +` + for _, tc := range []struct { + target string + errMsg string + }{ + // explicit, well understood, covered by other unit tests + {target: "$SNAP/import"}, + {target: "$SNAP_DATA/import"}, + {target: "$SNAP_COMMON/import"}, + // bare $SNAP, no subpath + {target: "$SNAP"}, + // bare path, implicit $SNAP prefix + {target: "import"}, + // absolute path, implicit $SNAP prefix + {target: "/import"}, + // bare root, ends up as $SNAP + {target: "/"}, + + // trailing slash is not a clean path, inconsistent with the rest + {target: "$SNAP/", errMsg: `content interface path is not clean: .*`}, + } { + info := snaptest.MockInfo(c, fmt.Sprintf(snapYamlTemplate, tc.target), nil) + plug := info.Plugs["content-plug"] + err := interfaces.BeforePreparePlug(s.iface, plug) + if tc.errMsg == "" { + c.Assert(err, IsNil, Commentf("target: %s", tc.target)) + } else { + c.Assert(err, ErrorMatches, tc.errMsg, Commentf("target: %s", tc.target)) + } + } +} + func (s *ContentSuite) TestSanitizeSlotNilAttrMap(c *C) { const mockSnapYaml = `name: content-slot-snap version: 1.0 From 19a413d93f34e2a89267da93adf910866f9d7e2e Mon Sep 17 00:00:00 2001 From: Miguel Pires Date: Tue, 21 Apr 2026 20:44:36 +0100 Subject: [PATCH 09/21] o/confdbstate: block concurrent snapctl accesses (#16834) * o/confdbstate: block concurrent snapctl accesses Extend the blocking UX to include accesses coming from snapctl. o/confdbstate: unblock as many reads as possible o/confdbstate: fix ignored err return + outdated TODO o/confdbstate: remove channel when closing o/confdbstate: read w/o tasks also unblocks next access o/confdbstate: unblock next pending access on error paths o/confdbstate: test hook helper o/confdbstate: unblock depending on tasks o/confdbstate: unblocked access must remove its own access o/confdbstate: prevent race conditions in multi read scenarios Signed-off-by: Miguel Pires * o/confdbstate: make channel buffered; tweaks Signed-off-by: Miguel Pires * o/confdbstate: more tweaks Signed-off-by: Miguel Pires * o/confdbstate: fix racy test Signed-off-by: Miguel Pires * o/confdbstate: remove waitID when setting ongoing tx Signed-off-by: Miguel Pires * o/confdbstate: merge WriteConfdbFromSnap and getTransactionToSet Signed-off-by: Miguel Pires * o/confdbstate: improve docs Signed-off-by: Miguel Pires * o/confdbstate: more docs + rename Signed-off-by: Miguel Pires * o/confdbstate: tweak cleanup and test Signed-off-by: Miguel Pires * o/confdbstate: rm unnecessary code; minor improvements Signed-off-by: Miguel Pires * o/confdbstate: remove edge usage; other minor tweaks Signed-off-by: Miguel Pires --------- Signed-off-by: Miguel Pires --- overlord/confdbstate/confdbmgr.go | 142 ++- overlord/confdbstate/confdbmgr_test.go | 33 +- overlord/confdbstate/confdbstate.go | 477 +++++----- overlord/confdbstate/confdbstate_test.go | 1095 +++++++++++++++++----- overlord/confdbstate/export_test.go | 26 +- overlord/hookstate/ctlcmd/export_test.go | 16 +- overlord/hookstate/ctlcmd/get.go | 7 +- overlord/hookstate/ctlcmd/get_test.go | 22 +- overlord/hookstate/ctlcmd/set.go | 30 +- overlord/hookstate/ctlcmd/set_test.go | 101 +- overlord/hookstate/ctlcmd/unset_test.go | 27 +- 11 files changed, 1298 insertions(+), 678 deletions(-) diff --git a/overlord/confdbstate/confdbmgr.go b/overlord/confdbstate/confdbmgr.go index 4ff4874fd1c..9529fbd7ec6 100644 --- a/overlord/confdbstate/confdbmgr.go +++ b/overlord/confdbstate/confdbmgr.go @@ -35,9 +35,14 @@ import ( ) const ( - // cacheKeyPrefix is the prefix to be concatenated with confdb IDs to form a - // cache key used to store pending access data. - cacheKeyPrefix = "confdb-accesses-" + // pendingCachePrefix is the prefix to be concatenated with confdb IDs to + // form a cache key used to store pending access data. + pendingCachePrefix = "pending-confdb-" + + // schedulingCachePrefix is the prefix to be concatenated with confdb IDs to + // form a cache key used to store access data that was unblocked and is being + // scheduled. + schedulingCachePrefix = "scheduling-confdb-" ) func setupConfdbHook(st *state.State, snapName, hookName string, ignoreError bool) *state.Task { @@ -163,7 +168,6 @@ func (m *ConfdbManager) clearOngoingTransaction(t *state.Task, _ *tomb.Tomb) err return err } - // TODO: unblock next waiting confdb writer once we add the blocking logic return nil } @@ -242,9 +246,13 @@ type confdbTransactions struct { ReadTxIDs []string `json:"read-tx-ids,omitempty"` WriteTxID string `json:"write-tx-id,omitempty"` - // pending holds accesses that are waiting to be scheduled. It's read from + // Pending holds accesses that are waiting to be scheduled. It's read from // the state cache so it's only kept in-memory, never persisted into state. - pending []pendingAccess + Pending []access `json:"-"` + + // Scheduling holds accesses that have been unblocked (moved from pending) + // but have not yet finished scheduling tasks/exiting. + Scheduling []access `json:"-"` } // CanStartReadTx returns true if there isn't a write transaction running or @@ -254,7 +262,10 @@ func (txs *confdbTransactions) CanStartReadTx() bool { return false } - for _, access := range txs.pending { + accesses := append([]access{}, txs.Pending...) + accesses = append(accesses, txs.Scheduling...) + + for _, access := range accesses { if access.AccessType == writeAccess { return false } @@ -263,21 +274,32 @@ func (txs *confdbTransactions) CanStartReadTx() bool { return true } -// CanStartWriteTx returns true if there is no running or pending transaction. +// CanStartWriteTx returns true if there is no access currently running or +// waiting to run. func (txs *confdbTransactions) CanStartWriteTx() bool { - return txs.WriteTxID == "" && len(txs.ReadTxIDs) == 0 && len(txs.pending) == 0 + return txs.WriteTxID == "" && len(txs.ReadTxIDs) == 0 && + len(txs.Pending) == 0 && len(txs.Scheduling) == 0 } // addReadTransaction adds a read transaction for the specified confdb, if no -// write transactions is ongoing. The state must be locked by the caller. -func addReadTransaction(st *state.State, account, confdbName, id string) error { +// write transactions is ongoing. If a accessID is passed in, it'll be removed +// from the Scheduling list. The state must be locked by the caller. +func addReadTransaction(st *state.State, account, confdbName, id, accessID string) error { txs, updateTxStateFunc, err := getOngoingTxs(st, account, confdbName) if err != nil { return err } + for i, acc := range txs.Scheduling { + if acc.ID == accessID { + txs.Scheduling = append(txs.Scheduling[:i], txs.Scheduling[i+1:]...) + break + } + } + if txs.WriteTxID != "" { - return fmt.Errorf("cannot read confdb (%s/%s): a write transaction is ongoing", account, confdbName) + // shouldn't happen save for programmer error + return fmt.Errorf("internal error: cannot read confdb (%s/%s): a write transaction is ongoing", account, confdbName) } txs.ReadTxIDs = append(txs.ReadTxIDs, id) @@ -286,21 +308,30 @@ func addReadTransaction(st *state.State, account, confdbName, id string) error { } // setWriteTransaction sets a write transaction for the specified confdb schema, -// if no other transactions (read or write) are ongoing. The state must be locked -// by the caller. -func setWriteTransaction(st *state.State, account, schemaName, id string) error { +// if no other transactions (read or write) are ongoing. If a accessID is passed +// in, it'll be removed from the Scheduling list. The state must be locked by +// the caller. +func setWriteTransaction(st *state.State, account, schemaName, id, accessID string) error { txs, updateTxStateFunc, err := getOngoingTxs(st, account, schemaName) if err != nil { return err } + for i, acc := range txs.Scheduling { + if acc.ID == accessID { + txs.Scheduling = append(txs.Scheduling[:i], txs.Scheduling[i+1:]...) + break + } + } + if txs.WriteTxID != "" || len(txs.ReadTxIDs) != 0 { op := "read" if txs.WriteTxID != "" { op = "write" } - return fmt.Errorf("cannot write confdb (%s/%s): a %s transaction is ongoing", account, schemaName, op) + // shouldn't happen save for programmer error + return fmt.Errorf("internal error: cannot write confdb (%s/%s): a %s transaction is ongoing", account, schemaName, op) } txs.WriteTxID = id @@ -341,16 +372,35 @@ func getOngoingTxs(st *state.State, account, schemaName string) (ongoingTxs *con st.Set("confdb-ongoing-txs", confdbTxs) } - st.Cache(cacheKeyPrefix+ref, ongoingTxs.pending) + if len(ongoingTxs.Pending) == 0 { + st.Cache(pendingCachePrefix+ref, nil) + } else { + st.Cache(pendingCachePrefix+ref, ongoingTxs.Pending) + } + + if len(ongoingTxs.Scheduling) == 0 { + st.Cache(schedulingCachePrefix+ref, nil) + } else { + st.Cache(schedulingCachePrefix+ref, ongoingTxs.Scheduling) + } } - cached := st.Cached(cacheKeyPrefix + ref) + cached := st.Cached(pendingCachePrefix + ref) if cached != nil { - queue, ok := cached.([]pendingAccess) + queue, ok := cached.([]access) if !ok { return nil, nil, fmt.Errorf("internal error: cannot access confdb pending transaction queue") } - confdbTxs[ref].pending = queue + confdbTxs[ref].Pending = queue + } + + cached = st.Cached(schedulingCachePrefix + ref) + if cached != nil { + queue, ok := cached.([]access) + if !ok { + return nil, nil, fmt.Errorf("internal error: cannot access confdb scheduling list") + } + confdbTxs[ref].Scheduling = queue } return confdbTxs[ref], updateTxStateFunc, nil @@ -378,34 +428,50 @@ func unsetOngoingTransaction(st *state.State, account, schemaName, id string) er if len(txs.ReadTxIDs) > 0 { // there are other transactions running (can only be reads) so skip this. - // The last one will unblock the next access + // The last one will unblock the next accesses return nil } - // unblock any waiting routine - if len(txs.pending) > 0 { - logger.Debugf("remove pending access %s", txs.pending[0].ID) - close(txs.pending[0].WaitChan) - } - - return nil + return maybeUnblockAccesses(txs) } -func unblockNextAccess(st *state.State, account, schemaName string) error { - txs, updateTxStateFunc, err := getOngoingTxs(st, account, schemaName) - if err != nil { - return err +// maybeUnblockAccesses unblocks as many consecutive pending accesses as +// possible, either one write or one or more sequential reads. +// This may be a no-op, if there are: +// - no pending changes (i.e., there's nothing to unblock) +// - changes running for other transactions - pending accesses would've been +// scheduled w/o waiting if they could (see waitForAccess) so any pending +// accesses are guaranteed to be incompatible. +// - accesses that have been unblocked but are still scheduling changes. If we +// unblocked accesses here, they would race with the ones already scheduling +// +// If accesses are unblocked, they're removed from the Pending list and put into +// the Scheduling list so we can track unblocked but still unscheduled accesses. +func maybeUnblockAccesses(txs *confdbTransactions) error { + if len(txs.Pending) == 0 || txs.WriteTxID != "" || len(txs.ReadTxIDs) > 0 || len(txs.Scheduling) != 0 { + return nil } - if len(txs.pending) == 0 { - return nil + var upTo int + for i, acc := range txs.Pending { + if acc.AccessType == writeAccess { + if i == 0 { + acc.WaitChan <- struct{}{} + logger.Debugf("unblocking pending %s access %s", acc.AccessType, acc.ID) + upTo = i + } + + break + } + + acc.WaitChan <- struct{}{} + logger.Debugf("unblocking pending %s access %s", acc.AccessType, acc.ID) + upTo = i } - // unblock any waiting routine - logger.Debugf("remove pending access %s", txs.pending[0].ID) - close(txs.pending[0].WaitChan) + txs.Scheduling = append([]access{}, txs.Pending[:upTo+1]...) + txs.Pending = txs.Pending[upTo+1:] - updateTxStateFunc(txs) return nil } diff --git a/overlord/confdbstate/confdbmgr_test.go b/overlord/confdbstate/confdbmgr_test.go index f9e05a44e65..a81bcba62ab 100644 --- a/overlord/confdbstate/confdbmgr_test.go +++ b/overlord/confdbstate/confdbmgr_test.go @@ -366,10 +366,14 @@ func (s *confdbTestSuite) TestSetAndUnsetOngoingTransactionHelpers(c *C) { err := s.state.Get("confdb-ongoing-txs", &ongoingTxs) c.Assert(err, testutil.ErrorIs, &state.NoStateError{}) - err = confdbstate.SetWriteTransaction(s.state, "my-acc", "my-confdb", "1") + s.state.Cache("scheduling-confdb-my-acc/my-confdb", []confdbstate.Access{{ID: "foo"}}) + + err = confdbstate.SetWriteTransaction(s.state, "my-acc", "my-confdb", "1", "foo") c.Assert(err, IsNil) + accs := s.state.Cached("scheduling-confdb-my-acc/my-confdb") + c.Assert(accs, IsNil) - err = confdbstate.SetWriteTransaction(s.state, "other-acc", "other-confdb", "2") + err = confdbstate.SetWriteTransaction(s.state, "other-acc", "other-confdb", "2", "") c.Assert(err, IsNil) err = s.state.Get("confdb-ongoing-txs", &ongoingTxs) @@ -402,29 +406,32 @@ func (s *confdbTestSuite) TestConflictingOngoingTransactions(c *C) { s.state.Lock() defer s.state.Unlock() - err := confdbstate.SetWriteTransaction(s.state, "my-acc", "my-confdb", "1") + err := confdbstate.SetWriteTransaction(s.state, "my-acc", "my-confdb", "1", "") c.Assert(err, IsNil) // can't set write due to ongoing write - err = confdbstate.SetWriteTransaction(s.state, "my-acc", "my-confdb", "2") - c.Assert(err, ErrorMatches, `cannot write confdb \(my-acc/my-confdb\): a write transaction is ongoing`) + err = confdbstate.SetWriteTransaction(s.state, "my-acc", "my-confdb", "2", "") + c.Assert(err, ErrorMatches, `internal error: cannot write confdb \(my-acc/my-confdb\): a write transaction is ongoing`) // can't add read due to ongoing write - err = confdbstate.AddReadTransaction(s.state, "my-acc", "my-confdb", "2") - c.Assert(err, ErrorMatches, `cannot read confdb \(my-acc/my-confdb\): a write transaction is ongoing`) + err = confdbstate.AddReadTransaction(s.state, "my-acc", "my-confdb", "2", "") + c.Assert(err, ErrorMatches, `internal error: cannot read confdb \(my-acc/my-confdb\): a write transaction is ongoing`) err = confdbstate.UnsetOngoingTransaction(s.state, "my-acc", "my-confdb", "1") c.Assert(err, IsNil) - err = confdbstate.AddReadTransaction(s.state, "my-acc", "my-confdb", "1") + s.state.Cache("scheduling-confdb-my-acc/my-confdb", []confdbstate.Access{{ID: "foo"}}) + err = confdbstate.AddReadTransaction(s.state, "my-acc", "my-confdb", "1", "foo") c.Assert(err, IsNil) + accs := s.state.Cached("scheduling-confdb-my-acc/my-confdb") + c.Assert(accs, IsNil) // can't set write due to ongoing read - err = confdbstate.SetWriteTransaction(s.state, "my-acc", "my-confdb", "2") - c.Assert(err, ErrorMatches, `cannot write confdb \(my-acc/my-confdb\): a read transaction is ongoing`) + err = confdbstate.SetWriteTransaction(s.state, "my-acc", "my-confdb", "2", "") + c.Assert(err, ErrorMatches, `internal error: cannot write confdb \(my-acc/my-confdb\): a read transaction is ongoing`) // many reads are fine - err = confdbstate.AddReadTransaction(s.state, "my-acc", "my-confdb", "2") + err = confdbstate.AddReadTransaction(s.state, "my-acc", "my-confdb", "2", "") c.Assert(err, IsNil) } @@ -483,7 +490,7 @@ func (s *confdbTestSuite) TestClearOngoingTransaction(c *C) { chg.AddTask(t) t.Set("tx-task", commitTask.ID()) - confdbstate.SetWriteTransaction(s.state, s.devAccID, "network", commitTask.ID()) + confdbstate.SetWriteTransaction(s.state, s.devAccID, "network", commitTask.ID(), "") c.Assert(err, IsNil) var confdbTxs map[string]*confdbstate.ConfdbTransactions @@ -524,7 +531,7 @@ func (s *confdbTestSuite) TestClearTransactionOnError(c *C) { commitTask.Set("view", "setup-wifi") // add this transaction to the state - err = confdbstate.SetWriteTransaction(s.state, s.devAccID, "network", commitTask.ID()) + err = confdbstate.SetWriteTransaction(s.state, s.devAccID, "network", commitTask.ID(), "") c.Assert(err, IsNil) s.state.Unlock() diff --git a/overlord/confdbstate/confdbstate.go b/overlord/confdbstate/confdbstate.go index 42449deda2a..505e895ba0c 100644 --- a/overlord/confdbstate/confdbstate.go +++ b/overlord/confdbstate/confdbstate.go @@ -45,21 +45,26 @@ import ( var ( assertstateConfdbSchema = assertstate.ConfdbSchema assertstateFetchConfdbSchemaAssertion = assertstate.FetchConfdbSchemaAssertion -) -var ( setConfdbChangeKind = swfeats.RegisterChangeKind("set-confdb") getConfdbChangeKind = swfeats.RegisterChangeKind("get-confdb") - // testBlockingChan is closed right before blocking to wait for access. - blockingSignalChan chan struct{} + // blockingSignals holds channels that, if present, will be closed to signal + // that an operation is about to block. Its only use is to test some blocking + // behaviour. + blockingSignals map[string]chan struct{} defaultWaitTimeout = 10 * time.Minute + + ensureNow = func(st *state.State) { + st.EnsureBefore(0) + } + + transactionTimeout = 2 * time.Minute ) -// SetViaView uses the view to set the requests in the transaction's databag. -// TODO: unexport this once the next PR refactors the writing from snapctl -func SetViaView(bag confdb.Databag, view *confdb.View, requests map[string]any) error { +// setViaView uses the view to set the requests in the transaction's databag. +func setViaView(bag confdb.Databag, view *confdb.View, requests map[string]any) error { for request, value := range requests { var err error if value == nil { @@ -201,142 +206,144 @@ var writeDatabag = func(st *state.State, databag confdb.JSONDatabag, account, db return nil } -// waitForAccess blocks until the access can be processed or until the context -// was cancelled/timed out, in which case an error is returned. Caller must hold -// the state lock. -func waitForAccess(ctx context.Context, st *state.State, view *confdb.View, access accessType) (err error) { +// waitForAccess checks if ongoing transactions prevent this access from running +// and if necessary blocks until it can. The following scenarios can occur: +// - the access can immediately run (no ongoing tx or all are reads) - returns +// without waiting, with no accessID or error +// - the access must wait - returns after being unblocked, with a non-empty +// accessID matching an access in Processing (to be removed after scheduling) +// - any error occurs or the context times out or is cancelled - returns an +// error but no accessID, since relevant state in Processing/Pending is cleared +// +// Caller must hold the state lock. +func waitForAccess(ctx context.Context, st *state.State, view *confdb.View, accKind accessType) (accessID string, err error) { account, schema := view.Schema().Account, view.Schema().Name txs, updateTxs, err := getOngoingTxs(st, account, schema) if err != nil { - return fmt.Errorf("cannot access confdb view %s: cannot check ongoing transactions: %v", view.ID(), err) + return "", fmt.Errorf("cannot access confdb view %s: cannot check ongoing transactions: %v", view.ID(), err) } - if (access == readAccess && txs.CanStartReadTx()) || (access == writeAccess && txs.CanStartWriteTx()) { - return nil + if (accKind == readAccess && txs.CanStartReadTx()) || (accKind == writeAccess && txs.CanStartWriteTx()) { + return "", nil } - id := randutil.RandomString(20) + accessID = randutil.RandomString(20) - wait := make(chan struct{}) - txs.pending = append(txs.pending, pendingAccess{ - AccessType: access, + // AFAICT a buffer isn't strictly necessary here because if a writer sends to + // the channel, this goroutine will already have unlocked state and will eventually + // read from the channel, unblocking the lock holding goroutine. But let's be extra safe + wait := make(chan struct{}, 2) + txs.Pending = append(txs.Pending, access{ + AccessType: accKind, WaitChan: wait, - ID: id, + ID: accessID, }) updateTxs(txs) st.Unlock() - defer func() { - st.Lock() - txs, updateTxs, defErr := getOngoingTxs(st, account, schema) - if defErr != nil { - if err == nil { - err = fmt.Errorf("cannot access %s: cannot check ongoing transactions: %v", view.ID(), defErr) - } - return - } - - accIndex := -1 - for i, acc := range txs.pending { - if acc.ID == id { - accIndex = i - } - } - - if accIndex == -1 { - logger.Noticef("cannot find access id %s when updating pending accesses", id) - } else { - txs.pending = append(txs.pending[:accIndex], txs.pending[accIndex+1:]...) - } - - updateTxs(txs) - }() - - _, set := ctx.Deadline() - if !set { + if _, set := ctx.Deadline(); !set { // set a maximum waiting time to safeguard against this hanging forever var cancel context.CancelFunc ctx, cancel = context.WithTimeout(ctx, defaultWaitTimeout) defer cancel() } - if blockingSignalChan != nil { - // signal we're about to block for testing - close(blockingSignalChan) + if blockingSignals["wait-for-access"] != nil { + // for testing purposes only + close(blockingSignals["wait-for-access"]) } select { case <-wait: + st.Lock() case <-ctx.Done(): - return fmt.Errorf("cannot %s %s: timed out waiting for access", access, view.ID()) + // if the waiting was cancelled or timed out, clean up the pending state + st.Lock() + txs, updateTxs, err := getOngoingTxs(st, account, schema) + if err != nil { + return "", fmt.Errorf("cannot cleanup state after timeout/cancel: %v", err) + } + + for i, acc := range txs.Pending { + if acc.ID == accessID { + txs.Pending = append(txs.Pending[:i], txs.Pending[i+1:]...) + break + } + } + + // if the timeout/cancel raced with an unblock, the access might be in + // Scheduling so remove that + for i, acc := range txs.Scheduling { + if acc.ID == accessID { + txs.Scheduling = append(txs.Scheduling[:i], txs.Scheduling[i+1:]...) + break + } + } + + err = maybeUnblockAccesses(txs) + if err != nil { + return "", fmt.Errorf("cannot cleanup state after timeout/cancel: %v", err) + } + + updateTxs(txs) + + return "", fmt.Errorf("cannot %s %s: timed out waiting for access", accKind, view.ID()) } - return nil + return accessID, nil } // WriteConfdb takes a map of request paths to values, schedules a change to // set the values in specified confdb view and run the appropriate hooks. // Returns a change ID. func WriteConfdb(ctx context.Context, st *state.State, view *confdb.View, values map[string]any) (changeID string, err error) { - defer func() { - if err != nil { - uerr := unblockNextAccess(st, view.Schema().Account, view.Schema().Name) - if uerr != nil { - logger.Noticef("cannot unblock next access after failed write: %v", uerr) - } - } - }() - - err = waitForAccess(ctx, st, view, writeAccess) + accessID, err := waitForAccess(ctx, st, view, writeAccess) if err != nil { return "", err } - account, schemaName := view.Schema().Account, view.Schema().Name + + account, schema := view.Schema().Account, view.Schema().Name + // accessID is empty if we didn't release the lock and wait, so no state was + // modified and there aren't other accesses to unblock + if accessID != "" { + defer cleanupAccess(st, accessID, account, schema) + } // not running in an existing confdb hook context, so create a transaction // and a change to verify its changes and commit - tx, err := NewTransaction(st, account, schemaName) + tx, err := NewTransaction(st, account, schema) if err != nil { return "", fmt.Errorf("cannot modify confdb through view %s: cannot create transaction: %v", view.ID(), err) } - err = SetViaView(tx, view, values) + err = setViaView(tx, view, values) if err != nil { return "", err } // the hooks we schedule depend on the paths written so this must happen after writing - ts, err := createChangeConfdbTasks(st, tx, view, "") + ts, commitTask, _, err := createChangeConfdbTasks(st, tx, view, "") if err != nil { return "", err } - chg := st.NewChange(setConfdbChangeKind, fmt.Sprintf("Set confdb through %q", view.ID())) - chg.AddAll(ts) - - commitTask, err := ts.Edge(commitEdge) + err = setWriteTransaction(st, account, schema, commitTask.ID(), accessID) if err != nil { return "", err } - err = setWriteTransaction(st, account, schemaName, commitTask.ID()) - if err != nil { - return "", err - } + // schedule tasks after saving the tx ID so the deferred cleanup skips waking + // up waiters if a task will do it (txs.WriteTxID != "") + chg := st.NewChange(setConfdbChangeKind, fmt.Sprintf("Set confdb through %q", view.ID())) + chg.AddAll(ts) - return chg.ID(), err + return chg.ID(), nil } -type CommitTxFunc func() (changeID string, waitChan <-chan struct{}, err error) - -// GetTransactionToSet gets a transaction to change the confdb through the view. -// The state must be locked by the caller. Returns a transaction through which -// the confdb can be modified and a CommitTxFunc. The latter is called once the -// modifications are made to commit them. It will return a changeID and a channel, -// allowing the caller to block until commit. If a transaction was already ongoing, -// CommitTxFunc simply returns that without blocking (changes to it will be -// saved on ctx.Done()). -func GetTransactionToSet(hookCtx *hookstate.Context, st *state.State, view *confdb.View) (*Transaction, CommitTxFunc, error) { - account, schemaName := view.Schema().Account, view.Schema().Name +// WriteConfdbFromSnap takes a hook context and a map of requests to values that +// are written through the provided view. It will block until the writing change +// completes. +func WriteConfdbFromSnap(hookCtx *hookstate.Context, view *confdb.View, values map[string]any) (err error) { + account, schema := view.Schema().Account, view.Schema().Name // check if we're already running in the context of a committing transaction if IsConfdbHookCtx(hookCtx) { @@ -345,11 +352,11 @@ func GetTransactionToSet(hookCtx *hookstate.Context, st *state.State, view *conf t, _ := hookCtx.Task() tx, _, saveTxChanges, err := GetStoredTransaction(t) if err != nil { - return nil, nil, fmt.Errorf("cannot access confdb through view %s: cannot get transaction: %v", view.ID(), err) + return fmt.Errorf("cannot access confdb through view %s: cannot get transaction: %v", view.ID(), err) } - if tx.ConfdbAccount != account || tx.ConfdbName != schemaName { - return nil, nil, fmt.Errorf("cannot access confdb through view %s: ongoing transaction for %s/%s", view.ID(), tx.ConfdbAccount, tx.ConfdbName) + if tx.ConfdbAccount != account || tx.ConfdbName != schema { + return fmt.Errorf("cannot access confdb through view %s: ongoing transaction for %s/%s", view.ID(), tx.ConfdbAccount, tx.ConfdbName) } // update the commit task to save transaction changes made by the hook @@ -358,109 +365,109 @@ func GetTransactionToSet(hookCtx *hookstate.Context, st *state.State, view *conf return nil }) - return tx, nil, nil + return setViaView(tx, view, values) + } + + // get --wait-for timeout from context state, if any is set + ctx := context.Background() + if hookCtx.Timeout() != time.Duration(0) { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, hookCtx.Timeout()) + defer cancel() } - txs, _, err := getOngoingTxs(st, account, schemaName) + st := hookCtx.State() + accessID, err := waitForAccess(ctx, st, view, writeAccess) if err != nil { - return nil, nil, fmt.Errorf("cannot access confdb view %s: cannot check ongoing transactions: %v", view.ID(), err) + return err } - if txs != nil && !txs.CanStartWriteTx() { - // TODO: eventually we want to queue this write and block until we serve it. - // It might also be necessary to have some form of timeout. - return nil, nil, fmt.Errorf("cannot write confdb through view %s: ongoing transaction", view.ID()) + // accessID is empty if we didn't release the lock and wait, so no state was + // modified and there aren't other accesses to unblock + if accessID != "" { + defer cleanupAccess(st, accessID, account, schema) } // not running in an existing confdb hook context, so create a transaction // and a change to verify its changes and commit - tx, err := NewTransaction(st, account, schemaName) + tx, err := NewTransaction(st, account, schema) if err != nil { - return nil, nil, fmt.Errorf("cannot modify confdb through view %s: cannot create transaction: %v", view.ID(), err) + return fmt.Errorf("cannot modify confdb through view %s: cannot create transaction: %v", view.ID(), err) } - commitTx := func() (string, <-chan struct{}, error) { - var chg *state.Change - if hookCtx == nil || hookCtx.IsEphemeral() { - chg = st.NewChange(setConfdbChangeKind, fmt.Sprintf("Set confdb through %q", view.ID())) - } else { - // we're running in the context of a non-confdb hook, add the tasks to that change - task, _ := hookCtx.Task() - chg = task.Change() - } + err = setViaView(tx, view, values) + if err != nil { + return err + } - var callingSnap string - if hookCtx != nil { - callingSnap = hookCtx.InstanceName() - } + var chg *state.Change + if hookCtx.IsEphemeral() { + chg = st.NewChange(setConfdbChangeKind, fmt.Sprintf("Set confdb through %q", view.ID())) + } else { + // we're running in the context of a non-confdb hook, add the tasks to that change + task, _ := hookCtx.Task() + chg = task.Change() + } - ts, err := createChangeConfdbTasks(st, tx, view, callingSnap) - if err != nil { - return "", nil, err - } - chg.AddAll(ts) + ts, commitTask, clearTxTask, err := createChangeConfdbTasks(st, tx, view, hookCtx.InstanceName()) + if err != nil { + return err + } - commitTask, err := ts.Edge(commitEdge) - if err != nil { - return "", nil, err - } + // schedule tasks after saving the tx ID so the deferred cleanup skips waking + // up waiters if a task will do it (txs.WriteTxID != "") + err = setWriteTransaction(st, account, schema, commitTask.ID(), accessID) + if err != nil { + return err + } + chg.AddAll(ts) - clearTxTask, err := ts.Edge(clearTxEdge) - if err != nil { - return "", nil, err + waitChan := make(chan struct{}) + st.AddTaskStatusChangedHandler(func(t *state.Task, _, new state.Status) (remove bool) { + if t.ID() == clearTxTask.ID() && new.Ready() { + close(waitChan) + return true } + return false + }) - err = setWriteTransaction(st, account, schemaName, commitTask.ID()) - if err != nil { - return "", nil, err - } + ensureNow(st) - waitChan := make(chan struct{}) - st.AddTaskStatusChangedHandler(func(t *state.Task, old, new state.Status) (remove bool) { - if t.ID() == clearTxTask.ID() && new.Ready() { - close(waitChan) - return true - } - return false - }) + // wait for the transaction to be committed + hookCtx.Unlock() + defer hookCtx.Lock() - ensureNow(st) - return chg.ID(), waitChan, nil + if blockingSignals["wait-for-change-done"] != nil { + // for testing purposes only + close(blockingSignals["wait-for-change-done"]) } - return tx, commitTx, nil -} - -var ( - ensureNow = func(st *state.State) { - st.EnsureBefore(0) + select { + case <-waitChan: + case <-time.After(transactionTimeout): + return fmt.Errorf("cannot set confdb %s: timed out after %s", view.ID(), transactionTimeout) } - transactionTimeout = 2 * time.Minute -) - -const ( - commitEdge = state.TaskSetEdge("commit-edge") - clearTxEdge = state.TaskSetEdge("clear-tx-edge") -) + return nil +} -func createChangeConfdbTasks(st *state.State, tx *Transaction, view *confdb.View, callingSnap string) (*state.TaskSet, error) { +func createChangeConfdbTasks(st *state.State, tx *Transaction, view *confdb.View, callingSnap string) (ts *state.TaskSet, commitTask, clearTxTask *state.Task, err error) { custodians, custodianPlugs, err := getCustodianPlugsForView(st, view) if err != nil { - return nil, err + return nil, nil, nil, err } if len(custodianPlugs) == 0 { - return nil, fmt.Errorf("cannot commit changes to confdb made through view %s: no custodian snap connected", view.ID()) + return nil, nil, nil, fmt.Errorf("cannot commit changes to confdb made through view %s: no custodian snap connected", view.ID()) } paths := tx.AlteredPaths() mightAffectEph, err := view.WriteAffectsEphemeral(paths) if err != nil { - return nil, err + return nil, nil, nil, err } - ts := state.NewTaskSet() + ts = state.NewTaskSet() linkTask := func(t *state.Task) { tasks := ts.Tasks() if len(tasks) > 0 { @@ -492,7 +499,7 @@ func createChangeConfdbTasks(st *state.State, tx *Transaction, view *confdb.View } if hookPrefix == "save-view-" && mightAffectEph && !saveViewHookPresent { - return nil, fmt.Errorf("cannot access %s: write might change ephemeral data but no custodians has a save-view hook", view.ID()) + return nil, nil, nil, fmt.Errorf("cannot access %s: write might change ephemeral data but no custodians has a save-view hook", view.ID()) } } @@ -500,7 +507,7 @@ func createChangeConfdbTasks(st *state.State, tx *Transaction, view *confdb.View // changed with this data modification affectedPlugs, err := getPlugsAffectedByPaths(st, view.Schema(), paths) if err != nil { - return nil, err + return nil, nil, nil, err } viewChangedSnaps := make([]string, 0, len(affectedPlugs)) @@ -523,7 +530,7 @@ func createChangeConfdbTasks(st *state.State, tx *Transaction, view *confdb.View } // commit after custodians save ephemeral data - commitTask := st.NewTask("commit-confdb-tx", fmt.Sprintf("Commit changes to confdb (%s)", view.ID())) + commitTask = st.NewTask("commit-confdb-tx", fmt.Sprintf("Commit changes to confdb (%s)", view.ID())) commitTask.Set("confdb-transaction", tx) commitTask.Set("view", view.Name) @@ -532,15 +539,13 @@ func createChangeConfdbTasks(st *state.State, tx *Transaction, view *confdb.View t.Set("tx-task", commitTask.ID()) } linkTask(commitTask) - ts.MarkEdge(commitTask, commitEdge) // clear the ongoing tx from the state and unblock other writers waiting for it - clearTxTask := st.NewTask("clear-confdb-tx", "Clears the ongoing confdb transaction from state") + clearTxTask = st.NewTask("clear-confdb-tx", "Clears the ongoing confdb transaction from state") linkTask(clearTxTask) clearTxTask.Set("tx-task", commitTask.ID()) - ts.MarkEdge(clearTxTask, clearTxEdge) - return ts, nil + return ts, commitTask, clearTxTask, nil } // getCustodianPlugsForView returns a list of snaps that have connected plugs @@ -665,7 +670,7 @@ func GetStoredTransaction(t *state.Task) (tx *Transaction, txTask *state.Task, s // IsConfdbHookCtx returns whether the hook context belongs to a confdb hook. func IsConfdbHookCtx(ctx *hookstate.Context) bool { - return ctx != nil && !ctx.IsEphemeral() && IsConfdbHookname(ctx.HookName()) + return !ctx.IsEphemeral() && IsConfdbHookname(ctx.HookName()) } // IsConfdbHookname returns whether the hookname denotes a confdb hook. @@ -678,21 +683,22 @@ func IsConfdbHookname(name string) bool { } // CanHookSetConfdb returns whether the hook context belongs to a confdb hook -// that supports snapctl set (either a write hook or load-view). +// that supports snapctl set (either a write hook or load-view). Returns false +// if the context is ephemeral. func CanHookSetConfdb(ctx *hookstate.Context) bool { - return ctx != nil && !ctx.IsEphemeral() && + return !ctx.IsEphemeral() && (strings.HasPrefix(ctx.HookName(), "change-view-") || strings.HasPrefix(ctx.HookName(), "query-view-") || strings.HasPrefix(ctx.HookName(), "load-view-")) } -// GetTransactionForSnapctlGet gets a transaction to read the view's confdb. It -// schedules tasks to load the confdb as needed, unless no custodian defined -// relevant hooks. Blocks until the confdb has been loaded into the Transaction. -// If no tasks need to run to load the confdb, returns without blocking. -func GetTransactionForSnapctlGet(hookCtx *hookstate.Context, view *confdb.View, paths []string, constraints map[string]any) (*Transaction, error) { +// ReadConfdbFromSnap gets a transaction to read the view's confdb. It schedules +// tasks to load the confdb as needed, unless no custodian defined relevant +// hooks. Blocks until the confdb has been loaded into the Transaction. If no +// tasks need to run to load the confdb, returns without blocking. +func ReadConfdbFromSnap(hookCtx *hookstate.Context, view *confdb.View, paths []string, constraints map[string]any) (tx *Transaction, err error) { st := hookCtx.State() - account, schemaName := view.Schema().Account, view.Schema().Name + account, schema := view.Schema().Account, view.Schema().Name if IsConfdbHookCtx(hookCtx) { // running in the context of a transaction, so if the referenced confdb @@ -703,37 +709,41 @@ func GetTransactionForSnapctlGet(hookCtx *hookstate.Context, view *confdb.View, return nil, fmt.Errorf("cannot load confdb view %s: cannot get transaction: %v", view.ID(), err) } - if tx.ConfdbAccount != account || tx.ConfdbName != schemaName { + if tx.ConfdbAccount != account || tx.ConfdbName != schema { // TODO: this should be enabled at some point - return nil, fmt.Errorf("cannot load confdb %s/%s: ongoing transaction for %s/%s", account, schemaName, tx.ConfdbAccount, tx.ConfdbName) + return nil, fmt.Errorf("cannot load confdb %s/%s: ongoing transaction for %s/%s", account, schema, tx.ConfdbAccount, tx.ConfdbName) } // we're reading the tx that this hook is modifying, just return that return tx, nil } - // TODO: replace this with the concurrent access logic. Derive timeout from hookstate.Context - // if not otherwise set? - txs, _, err := getOngoingTxs(st, account, schemaName) + ctx := context.Background() + if hookCtx.Timeout() != time.Duration(0) { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, hookCtx.Timeout()) + defer cancel() + } + + accessID, err := waitForAccess(ctx, st, view, readAccess) if err != nil { - return nil, fmt.Errorf("cannot access confdb view %s: cannot check ongoing transactions: %v", view.ID(), err) + return nil, err } - // TODO: use txs.CanStartReadTx() once we support blocking access here - if txs.WriteTxID != "" || len(txs.pending) > 0 { - // TODO: eventually we want to queue this load and block until we serve it. - // It might also be necessary to have some form of timeout. - return nil, fmt.Errorf("cannot access confdb view %s: ongoing write transaction", view.ID()) + // accessID is empty if we didn't release the lock and wait, so no state was + // modified and there aren't other accesses to unblock + if accessID != "" { + defer cleanupAccess(st, accessID, account, schema) } // not running in an existing confdb hook context, so create a transaction // and a change to load/modify data - tx, err := NewTransaction(st, account, schemaName) + tx, err = NewTransaction(st, account, schema) if err != nil { return nil, fmt.Errorf("cannot load confdb view %s: cannot create transaction: %v", view.ID(), err) } - ts, err := createLoadConfdbTasks(st, tx, view, paths, constraints) + ts, clearTxTask, err := createLoadConfdbTasks(st, tx, view, paths, constraints) if err != nil { return nil, err } @@ -752,13 +762,6 @@ func GetTransactionForSnapctlGet(hookCtx *hookstate.Context, view *confdb.View, chg = task.Change() } - chg.AddAll(ts) - - clearTxTask, err := ts.Edge(clearTxEdge) - if err != nil { - return nil, err - } - waitChan := make(chan struct{}) st.AddTaskStatusChangedHandler(func(t *state.Task, old, new state.Status) (remove bool) { if t.ID() == clearTxTask.ID() && new.Ready() { @@ -768,19 +771,27 @@ func GetTransactionForSnapctlGet(hookCtx *hookstate.Context, view *confdb.View, return false }) - err = addReadTransaction(st, account, schemaName, clearTxTask.ID()) + // schedule tasks after saving the tx ID so the deferred cleanup skips waking + // up waiters if a task will do it (len(txs.ReadTxIDs) > 0) + err = addReadTransaction(st, account, schema, clearTxTask.ID(), accessID) if err != nil { return nil, err } + chg.AddAll(ts) ensureNow(st) hookCtx.Unlock() + if blockingSignals["wait-for-change-done"] != nil { + // for testing purposes only + close(blockingSignals["wait-for-change-done"]) + } + select { case <-waitChan: case <-time.After(transactionTimeout): hookCtx.Lock() - return nil, fmt.Errorf("cannot load confdb %s/%s in change %s: timed out after %s", account, schemaName, chg.ID(), transactionTimeout) + return nil, fmt.Errorf("cannot load confdb %s/%s in change %s: timed out after %s", account, schema, chg.ID(), transactionTimeout) } hookCtx.Lock() @@ -797,53 +808,73 @@ const ( writeAccess accessType = "write" ) -type pendingAccess struct { +// access holds data for a pending access, namely a unique identifier, +// access type (read or write) and a channel use to signal that the access can +// proceed. +type access struct { // ID is a random string identifying this access. ID string - // AccessType denotes whether the access is read or write. Exported for - // testing purposes. + // AccessType denotes whether the access is read or write. AccessType accessType // WaitChan is closed to unblock the pending access. WaitChan chan<- struct{} } +// cleanupAccess removes state related to processing an access, if any exists +// (i.e., if the access had to wait and was eventually unblocked). If no tasks +// were scheduled and there aren't other accesses waiting to schedule, it unblocks +// the next pending accesses. +func cleanupAccess(st *state.State, accessID, account, schema string) { + txs, updateTxStateFunc, uerr := getOngoingTxs(st, account, schema) + if uerr != nil { + logger.Noticef("cannot unblock next access after failed access: %v", uerr) + return + } + defer updateTxStateFunc(txs) + + // remove this access from the scheduling list, if we haven't yet + for i, acc := range txs.Scheduling { + if acc.ID == accessID { + txs.Scheduling = append(txs.Scheduling[:i], txs.Scheduling[i+1:]...) + break + } + } + + // this may actually not unblock anything, if other accesses are being processed + uerr = maybeUnblockAccesses(txs) + if uerr != nil { + logger.Noticef("cannot unblock next access after failed access: %v", uerr) + } +} + // ReadConfdb schedules a change to load a confdb, running any appropriate // hooks and fulfilling the requests by reading the view and placing the // resulting data in the change's data (so it can be read by the client). func ReadConfdb(ctx context.Context, st *state.State, view *confdb.View, requests []string, constraints map[string]any, userAccess confdb.Access) (changeID string, err error) { - defer func() { - if err != nil { - uerr := unblockNextAccess(st, view.Schema().Account, view.Schema().Name) - if uerr != nil { - logger.Noticef("cannot unblock next access after failed read: %v", uerr) - } - } - }() - - err = waitForAccess(ctx, st, view, readAccess) + accessID, err := waitForAccess(ctx, st, view, readAccess) if err != nil { return "", err } account, schema := view.Schema().Account, view.Schema().Name + // accessID is empty if we didn't release the lock and wait, so no state was + // modified and there aren't other accesses to unblock + if accessID != "" { + defer cleanupAccess(st, accessID, account, schema) + } + tx, err := NewTransaction(st, account, schema) if err != nil { return "", fmt.Errorf("cannot access confdb view %s: cannot create transaction: %v", view.ID(), err) } - ts, err := createLoadConfdbTasks(st, tx, view, requests, constraints) + ts, clearTxTask, err := createLoadConfdbTasks(st, tx, view, requests, constraints) if err != nil { return "", err } chg := st.NewChange(getConfdbChangeKind, fmt.Sprintf(`Get confdb through %q`, view.ID())) if ts != nil { - // if there are hooks to run, link the read-confdb task to those tasks - clearTxTask, err := ts.Edge(clearTxEdge) - if err != nil { - return "", err - } - // schedule a task to read the tx after the hook and add the data to the // change so it can be read by the client loadConfdbTask := st.NewTask("load-confdb-change", "Load confdb data into the change") @@ -856,7 +887,9 @@ func ReadConfdb(ctx context.Context, st *state.State, view *confdb.View, request loadConfdbTask.WaitFor(clearTxTask) chg.AddAll(ts) - err = addReadTransaction(st, account, schema, clearTxTask.ID()) + // schedule tasks after saving the tx ID so the deferred cleanup skips waking + // up waiters if a task will do it (len(txs.ReadTxIDs) > 0) + err = addReadTransaction(st, account, schema, clearTxTask.ID(), accessID) if err != nil { return "", err } @@ -878,14 +911,14 @@ func ReadConfdb(ctx context.Context, st *state.State, view *confdb.View, request // read a transaction through the given view. In case no custodian snap has any // load-view or query-view hooks, nil is returned. If there are hooks to run, // a clear-confdb-tx task is also scheduled to remove the ongoing transaction at the end. -func createLoadConfdbTasks(st *state.State, tx *Transaction, view *confdb.View, requests []string, constraints map[string]any) (*state.TaskSet, error) { +func createLoadConfdbTasks(st *state.State, tx *Transaction, view *confdb.View, requests []string, constraints map[string]any) (*state.TaskSet, *state.Task, error) { custodians, custodianPlugs, err := getCustodianPlugsForView(st, view) if err != nil { - return nil, err + return nil, nil, err } if len(custodians) == 0 { - return nil, fmt.Errorf("cannot load confdb through view %s: no custodian snap connected", view.ID()) + return nil, nil, fmt.Errorf("cannot load confdb through view %s: no custodian snap connected", view.ID()) } ts := state.NewTaskSet() @@ -899,7 +932,7 @@ func createLoadConfdbTasks(st *state.State, tx *Transaction, view *confdb.View, mightAffectEph, err := view.ReadAffectsEphemeral(requests, constraints) if err != nil { - return nil, err + return nil, nil, err } hookPrefixes := []string{"load-view-", "query-view-"} @@ -923,14 +956,14 @@ func createLoadConfdbTasks(st *state.State, tx *Transaction, view *confdb.View, // there must be least one load-view hook if we're accessing ephemeral data if hookPrefix == "load-view-" && mightAffectEph && !loadViewHookPresent { - return nil, fmt.Errorf("cannot schedule tasks to access %s: read might cover ephemeral data but no custodian has a load-view hook", view.ID()) + return nil, nil, fmt.Errorf("cannot schedule tasks to access %s: read might cover ephemeral data but no custodian has a load-view hook", view.ID()) } } if len(hooks) == 0 { // no hooks to run and not running from API (don't need task to populate) // data in change so we can just read the databag synchronously - return nil, nil + return nil, nil, nil } // clear the tx from the state if the change fails @@ -947,11 +980,9 @@ func createLoadConfdbTasks(st *state.State, tx *Transaction, view *confdb.View, for _, t := range ts.Tasks() { t.Set("tx-task", clearTxTask.ID()) } - linkTask(clearTxTask) - ts.MarkEdge(clearTxTask, clearTxEdge) - return ts, nil + return ts, clearTxTask, nil } func MockFetchConfdbSchemaAssertion(f func(*state.State, int, string, string) error) func() { diff --git a/overlord/confdbstate/confdbstate_test.go b/overlord/confdbstate/confdbstate_test.go index 5ae52a227f1..5690bfc2533 100644 --- a/overlord/confdbstate/confdbstate_test.go +++ b/overlord/confdbstate/confdbstate_test.go @@ -43,7 +43,6 @@ import ( "github.com/snapcore/snapd/overlord/confdbstate" "github.com/snapcore/snapd/overlord/configstate/config" "github.com/snapcore/snapd/overlord/hookstate" - "github.com/snapcore/snapd/overlord/hookstate/ctlcmd" "github.com/snapcore/snapd/overlord/hookstate/hooktest" "github.com/snapcore/snapd/overlord/ifacestate/ifacerepo" "github.com/snapcore/snapd/overlord/snapstate" @@ -193,7 +192,7 @@ func (s *confdbTestSuite) SetUpTest(c *C) { c.Assert(err, IsNil) tr.Commit() - confdbstate.SetBlockingSignalChan(nil) + confdbstate.ResetBlockingSignals() } func parsePath(c *C, path string) []confdb.Accessor { @@ -365,7 +364,7 @@ func (s *confdbTestSuite) TestUnsetView(c *C) { c.Assert(err, testutil.ErrorIs, &confdb.NoDataError{}) } -func (s *confdbTestSuite) TestConfdbstateGetEntireView(c *C) { +func (s *confdbTestSuite) TestGetEntireView(c *C) { s.state.Lock() defer s.state.Unlock() @@ -613,18 +612,12 @@ func (s *confdbTestSuite) TestConfdbTasksUserSetWithCustodianInstalled(c *C) { chg := s.state.NewChange("modify-confdb", "") // a user (not a snap) changes a confdb - ts, err := confdbstate.CreateChangeConfdbTasks(s.state, tx, view, "") - c.Assert(err, IsNil) - chg.AddAll(ts) - - // there are two edges in the taskset - commitTask, err := ts.Edge(confdbstate.CommitEdge) + ts, commitTask, clearTask, err := confdbstate.CreateChangeConfdbTasks(s.state, tx, view, "") c.Assert(err, IsNil) c.Assert(commitTask.Kind(), Equals, "commit-confdb-tx") + c.Assert(clearTask.Kind(), Equals, "clear-confdb-tx") - cleanupTask, err := ts.Edge(confdbstate.ClearTxEdge) - c.Assert(err, IsNil) - c.Assert(cleanupTask.Kind(), Equals, "clear-confdb-tx") + chg.AddAll(ts) // the custodian snap's hooks are run tasks := []string{"clear-confdb-tx-on-error", "run-hook", "run-hook", "run-hook", "commit-confdb-tx", "clear-confdb-tx"} @@ -670,7 +663,7 @@ func (s *confdbTestSuite) TestConfdbTasksCustodianSnapSet(c *C) { chg := s.state.NewChange("set-confdb", "") // a user (not a snap) changes a confdb - ts, err := confdbstate.CreateChangeConfdbTasks(s.state, tx, view, "custodian-snap") + ts, _, _, err := confdbstate.CreateChangeConfdbTasks(s.state, tx, view, "custodian-snap") c.Assert(err, IsNil) chg.AddAll(ts) @@ -713,7 +706,7 @@ func (s *confdbTestSuite) TestConfdbTasksObserverSnapSetWithCustodianInstalled(c chg := s.state.NewChange("modify-confdb", "") // a non-custodian snap modifies a confdb - ts, err := confdbstate.CreateChangeConfdbTasks(s.state, tx, view, "test-snap-1") + ts, _, _, err := confdbstate.CreateChangeConfdbTasks(s.state, tx, view, "test-snap-1") c.Assert(err, IsNil) chg.AddAll(ts) @@ -784,7 +777,7 @@ func (s *confdbTestSuite) testConfdbTasksNoCustodian(c *C) { view := s.dbSchema.View("setup-wifi") // a non-custodian snap modifies a confdb - _, err = confdbstate.CreateChangeConfdbTasks(s.state, tx, view, "test-snap-1") + _, _, _, err = confdbstate.CreateChangeConfdbTasks(s.state, tx, view, "test-snap-1") c.Assert(err, ErrorMatches, fmt.Sprintf("cannot commit changes to confdb made through view %s/network/%s: no custodian snap connected", s.devAccID, view.Name)) } @@ -1019,7 +1012,7 @@ func (s *confdbTestSuite) checkOngoingWriteConfdbTx(c *C, account, confdbName st c.Assert(commitTask.Status(), Equals, state.DoStatus) } -func (s *confdbTestSuite) TestGetTransactionFromUserCreatesNewChange(c *C) { +func (s *confdbTestSuite) TestWriteConfdbCreatesNewChange(c *C) { hooks, restore := s.mockConfdbHooks() defer restore() @@ -1038,37 +1031,24 @@ func (s *confdbTestSuite) TestGetTransactionFromUserCreatesNewChange(c *C) { s.setupConfdbScenario(c, custodians, nil) view := s.dbSchema.View("setup-wifi") - - tx, commitTxFunc, err := confdbstate.GetTransactionToSet(nil, s.state, view) - c.Assert(err, IsNil) - c.Assert(tx, NotNil) - c.Assert(commitTxFunc, NotNil) - - err = tx.Set(parsePath(c, "wifi.ssid"), "foo") - c.Assert(err, IsNil) - - // mock the daemon triggering the commit - changeID, waitChan, err := commitTxFunc() + chgID, err := confdbstate.WriteConfdb(context.Background(), s.state, view, map[string]any{ + "ssid": "foo", + }) c.Assert(err, IsNil) - s.state.Unlock() - select { - case <-waitChan: - case <-time.After(testutil.HostScaledTimeout(5 * time.Second)): - s.state.Lock() - c.Fatal("test timed out after 5s") - } - s.state.Lock() - c.Assert(s.state.Changes(), HasLen, 1) chg := s.state.Changes()[0] c.Assert(chg.Kind(), Equals, "set-confdb") - c.Assert(changeID, Equals, chg.ID()) + c.Assert(chg.ID(), Equals, chgID) + + s.state.Unlock() + s.o.Settle(testutil.HostScaledTimeout(5 * time.Second)) + s.state.Lock() s.checkSetConfdbChange(c, chg, hooks) } -func (s *confdbTestSuite) TestGetTransactionFromSnapCreatesNewChange(c *C) { +func (s *confdbTestSuite) TestWriteConfdbFromSnapCreatesNewChange(c *C) { hooks, restore := s.mockConfdbHooks() defer restore() @@ -1081,6 +1061,7 @@ func (s *confdbTestSuite) TestGetTransactionFromSnapCreatesNewChange(c *C) { s.state.Lock() defer s.state.Unlock() + view := s.dbSchema.View("setup-wifi") // only one custodian snap is installed custodians := map[string]confdbHooks{"custodian-snap": allHooks} @@ -1089,15 +1070,13 @@ func (s *confdbTestSuite) TestGetTransactionFromSnapCreatesNewChange(c *C) { ctx, err := hookstate.NewContext(nil, s.state, &hookstate.HookSetup{Snap: "test-snap"}, nil, "") c.Assert(err, IsNil) - s.state.Unlock() - stdout, stderr, err := ctlcmd.Run(ctx, []string{"set", "--view", ":setup", "ssid=foo"}, 0, nil) + + ctx.Lock() + err = confdbstate.WriteConfdbFromSnap(ctx, view, map[string]any{"ssid": "foo"}) c.Assert(err, IsNil) - c.Check(stdout, IsNil) - c.Check(stderr, IsNil) // this is called automatically by hooks or manually for daemon/ - ctx.Lock() ctx.Done() ctx.Unlock() @@ -1110,20 +1089,24 @@ func (s *confdbTestSuite) TestGetTransactionFromSnapCreatesNewChange(c *C) { } func (s *confdbTestSuite) TestGetTransactionFromNonConfdbHookAddsConfdbTx(c *C) { + view := s.dbSchema.View("setup-wifi") + var hooks []string restore := hookstate.MockRunHook(func(ctx *hookstate.Context, _ *tomb.Tomb) ([]byte, error) { t, _ := ctx.Task() - ctx.State().Lock() + s.state.Lock() var hooksup *hookstate.HookSetup err := t.Get("hook-setup", &hooksup) - ctx.State().Unlock() + s.state.Unlock() if err != nil { return nil, err } if hooksup.Hook == "install" { - _, _, err := ctlcmd.Run(ctx, []string{"set", "--view", ":setup", "ssid=foo"}, 0, nil) + ctx.Lock() + err := confdbstate.WriteConfdbFromSnap(ctx, view, map[string]any{"ssid": "foo"}) + ctx.Unlock() c.Assert(err, IsNil) return nil, nil } @@ -1220,56 +1203,8 @@ func (s *confdbTestSuite) checkSetConfdbChange(c *C, chg *state.Change, hooks *[ c.Assert(val, Equals, "foo") } -func (s *confdbTestSuite) TestGetTransactionFromChangeViewHook(c *C) { - ctx := s.testGetReadableOngoingTransaction(c, "change-view-setup") - - // change-view hooks can also write to the transaction - stdout, stderr, err := ctlcmd.Run(ctx, []string{"set", "--view", ":setup", "ssid=bar"}, 0, nil) - c.Assert(err, IsNil) - // accessed an ongoing transaction - c.Assert(stdout, IsNil) - c.Assert(stderr, IsNil) - - // this save the changes that the hook performs - ctx.Lock() - ctx.Done() - ctx.Unlock() - - s.state.Lock() - defer s.state.Unlock() - t, _ := ctx.Task() - tx, _, _, err := confdbstate.GetStoredTransaction(t) - c.Assert(err, IsNil) - - val, err := tx.Get(parsePath(c, "wifi.ssid"), nil) - c.Assert(err, IsNil) - c.Assert(val, Equals, "bar") -} - -func (s *confdbTestSuite) TestGetTransactionFromSaveViewHook(c *C) { - ctx := s.testGetReadableOngoingTransaction(c, "save-view-setup") - - // non change-view hooks cannot modify the transaction - stdout, stderr, err := ctlcmd.Run(ctx, []string{"set", "--view", ":setup", "ssid=bar"}, 0, nil) - c.Assert(err, ErrorMatches, `cannot modify confdb in "save-view-setup" hook`) - c.Assert(stdout, IsNil) - c.Assert(stderr, IsNil) -} - -func (s *confdbTestSuite) TestGetTransactionFromViewChangedHook(c *C) { - ctx := s.testGetReadableOngoingTransaction(c, "observe-view-setup") - - // non change-view hooks cannot modify the transaction - stdout, stderr, err := ctlcmd.Run(ctx, []string{"set", "--view", ":setup", "ssid=bar"}, 0, nil) - c.Assert(err, ErrorMatches, `cannot modify confdb in "observe-view-setup" hook`) - c.Assert(stdout, IsNil) - c.Assert(stderr, IsNil) -} - -func (s *confdbTestSuite) testGetReadableOngoingTransaction(c *C, hook string) *hookstate.Context { +func (s *confdbTestSuite) TestWriteConfdbFromChangeViewHook(c *C) { s.state.Lock() - defer s.state.Unlock() - custodians := map[string]confdbHooks{"custodian-snap": allHooks} s.setupConfdbScenario(c, custodians, []string{"test-snap"}) @@ -1286,22 +1221,42 @@ func (s *confdbTestSuite) testGetReadableOngoingTransaction(c *C, hook string) * hookTask := s.state.NewTask("run-hook", "") chg.AddTask(hookTask) - setup := &hookstate.HookSetup{Snap: "test-snap", Revision: snap.R(1), Hook: hook} + setup := &hookstate.HookSetup{Snap: "test-snap", Revision: snap.R(1), Hook: "change-view-setup"} mockHandler := hooktest.NewMockHandler() hookTask.Set("tx-task", commitTask.ID()) + s.state.Unlock() ctx, err := hookstate.NewContext(hookTask, s.state, setup, mockHandler, "") c.Assert(err, IsNil) - s.state.Unlock() - stdout, stderr, err := ctlcmd.Run(ctx, []string{"get", "--view", ":setup", "ssid"}, 0, nil) - s.state.Lock() + ctx.Lock() + view := s.dbSchema.View("setup-wifi") + tx, err := confdbstate.ReadConfdbFromSnap(ctx, view, []string{"ssid"}, nil) c.Assert(err, IsNil) // accessed an ongoing transaction - c.Assert(string(stdout), Equals, "foo\n") - c.Assert(stderr, IsNil) + data, err := tx.Get(parsePath(c, "wifi.ssid"), nil) + c.Assert(err, IsNil) + c.Assert(data, Equals, "foo") + + // change-view hooks can also write to the transaction + err = confdbstate.WriteConfdbFromSnap(ctx, view, map[string]any{ + "ssid": "bar", + }) + c.Assert(err, IsNil) + + // accessed an ongoing transaction so save the changes made by the hook + ctx.Done() + ctx.Unlock() + + s.state.Lock() + defer s.state.Unlock() + t, _ := ctx.Task() + tx, _, _, err = confdbstate.GetStoredTransaction(t) + c.Assert(err, IsNil) - return ctx + val, err := tx.Get(parsePath(c, "wifi.ssid"), nil) + c.Assert(err, IsNil) + c.Assert(val, Equals, "bar") } func (s *confdbTestSuite) TestGetDifferentTransactionThanOngoing(c *C) { @@ -1334,11 +1289,10 @@ func (s *confdbTestSuite) TestGetDifferentTransactionThanOngoing(c *C) { c.Assert(err, IsNil) ctx.Lock() - tx, commitTxFunc, err := confdbstate.GetTransactionToSet(ctx, s.state, confdb.View("foo")) + view := confdb.View("foo") + err = confdbstate.WriteConfdbFromSnap(ctx, view, nil) ctx.Unlock() c.Assert(err, ErrorMatches, fmt.Sprintf(`cannot access confdb through view foo/bar/foo: ongoing transaction for %s/network`, s.devAccID)) - c.Assert(tx, IsNil) - c.Assert(commitTxFunc, IsNil) } func (s *confdbTestSuite) TestConfdbLoadDisconnectedCustodianSnap(c *C) { @@ -1373,7 +1327,7 @@ func (s *confdbTestSuite) testConfdbLoadNoCustodian(c *C) { view := s.dbSchema.View("setup-wifi") // a non-custodian snap modifies a confdb - _, err = confdbstate.CreateLoadConfdbTasks(s.state, tx, view, []string{"ssid"}, nil) + _, _, err = confdbstate.CreateLoadConfdbTasks(s.state, tx, view, []string{"ssid"}, nil) c.Assert(err, ErrorMatches, fmt.Sprintf("cannot load confdb through view %s/network/setup-wifi: no custodian snap connected", s.devAccID)) } @@ -1434,12 +1388,9 @@ func (s *confdbTestSuite) TestConfdbLoadCustodianInstalled(c *C) { view := s.dbSchema.View("setup-wifi") chg := s.state.NewChange("load-confdb", "") - ts, err := confdbstate.CreateLoadConfdbTasks(s.state, tx, view, []string{"ssid"}, nil) + ts, cleanupTask, err := confdbstate.CreateLoadConfdbTasks(s.state, tx, view, []string{"ssid"}, nil) c.Assert(err, IsNil) chg.AddAll(ts) - - cleanupTask, err := ts.Edge(confdbstate.ClearTxEdge) - c.Assert(err, IsNil) c.Assert(cleanupTask.Kind(), Equals, "clear-confdb-tx") // the custodian snap's hooks are run @@ -1477,7 +1428,7 @@ func (s *confdbTestSuite) TestConfdbLoadCustodianWithNoHooks(c *C) { c.Assert(err, IsNil) view := s.dbSchema.View("setup-wifi") - ts, err := confdbstate.CreateLoadConfdbTasks(s.state, tx, view, []string{"ssid"}, nil) + ts, _, err := confdbstate.CreateLoadConfdbTasks(s.state, tx, view, []string{"ssid"}, nil) c.Assert(err, IsNil) // no hooks, nothing to run c.Assert(ts, IsNil) @@ -1498,7 +1449,7 @@ func (s *confdbTestSuite) TestConfdbLoadTasks(c *C) { c.Assert(err, IsNil) view := s.dbSchema.View("setup-wifi") - ts, err := confdbstate.CreateLoadConfdbTasks(s.state, tx, view, []string{"ssid"}, nil) + ts, _, err := confdbstate.CreateLoadConfdbTasks(s.state, tx, view, []string{"ssid"}, nil) c.Assert(err, IsNil) chg := s.state.NewChange("get-confdb", "") chg.AddAll(ts) @@ -1521,18 +1472,19 @@ func (s *confdbTestSuite) TestConfdbLoadTasks(c *C) { checkLoadConfdbTasks(c, chg, tasks, hooks) } -func (s *confdbTestSuite) TestGetTransactionForSnapctlNoHook(c *C) { +func (s *confdbTestSuite) TestReadConfdbFromSnapEphemeral(c *C) { s.state.Lock() // only one custodian snap is installed custodians := map[string]confdbHooks{"custodian-snap": allHooks} s.setupConfdbScenario(c, custodians, nil) mockHandler := hooktest.NewMockHandler() - ctx, err := hookstate.NewContext(nil, s.state, nil, mockHandler, "") + setup := &hookstate.HookSetup{Snap: "test-snap", Hook: "change-view-setup"} + ctx, err := hookstate.NewContext(nil, s.state, setup, mockHandler, "") c.Assert(err, IsNil) s.state.Unlock() - chg := s.testGetTransactionForSnapctl(c, ctx) + chg := s.testReadConfdbFromSnap(c, ctx) s.state.Lock() defer s.state.Unlock() @@ -1563,18 +1515,16 @@ func (s *confdbTestSuite) TestGetTransactionForSnapctlNonConfdbHook(c *C) { c.Assert(err, IsNil) s.state.Unlock() - s.testGetTransactionForSnapctl(c, ctx) + s.testReadConfdbFromSnap(c, ctx) } -func (s *confdbTestSuite) testGetTransactionForSnapctl(c *C, ctx *hookstate.Context) *state.Change { +func (s *confdbTestSuite) testReadConfdbFromSnap(c *C, ctx *hookstate.Context) *state.Change { hooks, restore := s.mockConfdbHooks() defer restore() restore = confdbstate.MockEnsureNow(func(*state.State) { s.checkOngoingReadConfdbTx(c, s.devAccID, "network") - go func() { - s.o.Settle(5 * time.Second) - }() + go s.o.Settle(5 * time.Second) }) defer restore() @@ -1589,7 +1539,7 @@ func (s *confdbTestSuite) testGetTransactionForSnapctl(c *C, ctx *hookstate.Cont s.state.Set("confdb-databags", map[string]map[string]confdb.JSONDatabag{s.devAccID: {"network": bag}}) view := s.dbSchema.View("setup-wifi") - tx, err := confdbstate.GetTransactionForSnapctlGet(ctx, view, []string{"ssid"}, nil) + tx, err := confdbstate.ReadConfdbFromSnap(ctx, view, []string{"ssid"}, nil) c.Assert(err, IsNil) c.Assert(s.state.Changes(), HasLen, 1) @@ -1630,7 +1580,7 @@ func (s *confdbTestSuite) TestGetTransactionInConfdbHook(c *C) { c.Assert(err, IsNil) view := s.dbSchema.View("setup-wifi") - tx, err := confdbstate.GetTransactionForSnapctlGet(ctx, view, []string{"ssid"}, nil) + tx, err := confdbstate.ReadConfdbFromSnap(ctx, view, []string{"ssid"}, nil) c.Assert(err, IsNil) // reads synchronously without creating new change or tasks c.Assert(s.state.Changes(), HasLen, 1) @@ -1656,19 +1606,24 @@ func (s *confdbTestSuite) TestGetTransactionNoConfdbHooks(c *C) { Hook: "install", } hookTask.Set("hook-setup", hooksup) - mockHandler := hooktest.NewMockHandler() - ctx, err := hookstate.NewContext(hookTask, s.state, hooksup, mockHandler, "") - c.Assert(err, IsNil) // write some value for the get to read bag := confdb.NewJSONDatabag() - err = bag.Set(parsePath(c, "wifi.ssid"), "foo") + err := bag.Set(parsePath(c, "wifi.ssid"), "foo") c.Assert(err, IsNil) s.state.Set("confdb-databags", map[string]map[string]confdb.JSONDatabag{s.devAccID: {"network": bag}}) + mockHandler := hooktest.NewMockHandler() + ctx, err := hookstate.NewContext(hookTask, s.state, hooksup, mockHandler, "") + c.Assert(err, IsNil) + view := s.dbSchema.View("setup-wifi") - tx, err := confdbstate.GetTransactionForSnapctlGet(ctx, view, []string{"ssid"}, nil) + s.state.Unlock() + ctx.Lock() + tx, err := confdbstate.ReadConfdbFromSnap(ctx, view, []string{"ssid"}, nil) + ctx.Unlock() + s.state.Lock() c.Assert(err, IsNil) c.Assert(tx, NotNil) @@ -1694,7 +1649,8 @@ func (s *confdbTestSuite) TestGetTransactionTimesOut(c *C) { s.setupConfdbScenario(c, custodians, nil) mockHandler := hooktest.NewMockHandler() - ctx, err := hookstate.NewContext(nil, s.state, nil, mockHandler, "") + setup := &hookstate.HookSetup{Snap: "test-snap", Hook: "change-view-setup"} + ctx, err := hookstate.NewContext(nil, s.state, setup, mockHandler, "") c.Assert(err, IsNil) // write some value for the get to read @@ -1709,7 +1665,7 @@ func (s *confdbTestSuite) TestGetTransactionTimesOut(c *C) { ctx.Lock() defer ctx.Unlock() - tx, err := confdbstate.GetTransactionForSnapctlGet(ctx, view, nil, nil) + tx, err := confdbstate.ReadConfdbFromSnap(ctx, view, nil, nil) c.Assert(err, ErrorMatches, fmt.Sprintf("cannot load confdb %s/network in change 1: timed out after 0s", s.devAccID)) c.Assert(tx, IsNil) } @@ -1781,7 +1737,7 @@ func (s *confdbTestSuite) checkOngoingReadConfdbTx(c *C, account, confdbName str c.Assert(clearTask.Status(), Equals, state.DoStatus) } -func (s *confdbTestSuite) TestGetTransactionForAPI(c *C) { +func (s *confdbTestSuite) TestAPIReadConfdb(c *C) { s.state.Lock() custodians := map[string]confdbHooks{"custodian-snap": allHooks} nonCustodians := []string{"test-snap"} @@ -1837,7 +1793,7 @@ func (s *confdbTestSuite) TestGetTransactionForAPI(c *C) { }) } -func (s *confdbTestSuite) TestGetTransactionForAPINoHooks(c *C) { +func (s *confdbTestSuite) TestReadConfdbNoHooks(c *C) { s.state.Lock() defer s.state.Unlock() @@ -1881,7 +1837,76 @@ func (s *confdbTestSuite) TestGetTransactionForAPINoHooks(c *C) { }) } -func (s *confdbTestSuite) TestGetTransactionForAPINoHooksError(c *C) { +func (s *confdbTestSuite) TestReadConfdbNoHooksUnblocksNextPendingAccess(c *C) { + s.state.Lock() + + custodians := map[string]confdbHooks{"custodian-snap": noHooks} + nonCustodians := []string{"test-snap"} + s.setupConfdbScenario(c, custodians, nonCustodians) + + view := s.dbSchema.View("setup-wifi") + ref := view.Schema().Account + "/" + view.Schema().Name + s.state.Set("confdb-ongoing-txs", map[string]*confdbstate.ConfdbTransactions{ + ref: {WriteTxID: "10"}, + }) + + // testing helper closed when the access is about to block + blockingChan := make(chan struct{}) + confdbstate.SetBlockingSignal("wait-for-access", blockingChan) + + var chgID string + doneChan := make(chan struct{}) + go func() { + var err error + chgID, err = confdbstate.ReadConfdb(context.Background(), s.state, view, []string{"ssid"}, nil, 0) + c.Assert(err, IsNil) + s.state.Unlock() + close(doneChan) + }() + + select { + case <-blockingChan: + case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): + c.Fatal("expected access to block but timed out") + } + + // the blocked read released the lock before waiting + s.state.Lock() + accs, ok := s.state.Cached("pending-confdb-" + ref).([]confdbstate.Access) + c.Assert(ok, Equals, true) + c.Assert(accs, HasLen, 1) + c.Assert(accs[0].AccessType, Equals, confdbstate.AccessType("read")) + + nextWaitChan := make(chan struct{}, 1) + s.endOngoingAccess(c, &confdbstate.Access{ + ID: "next-write", + AccessType: confdbstate.AccessType("write"), + WaitChan: nextWaitChan, + }) + s.state.Unlock() + + select { + case <-nextWaitChan: + case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): + c.Fatal("expected next access to be unblocked but timed out") + } + + select { + case <-doneChan: + case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): + c.Fatal("expected read to complete but timed out") + } + + s.state.Lock() + defer s.state.Unlock() + + chg := s.state.Change(chgID) + c.Assert(chg, NotNil) + c.Assert(chg.Tasks(), HasLen, 0) + c.Assert(chg.Status(), Equals, state.DoneStatus) +} + +func (s *confdbTestSuite) TestAPIReadConfdbNoHooksError(c *C) { s.state.Lock() defer s.state.Unlock() @@ -1915,7 +1940,7 @@ func (s *confdbTestSuite) TestGetTransactionForAPINoHooksError(c *C) { c.Assert(errKind, Equals, "option-not-found") } -func (s *confdbTestSuite) TestGetTransactionForAPIError(c *C) { +func (s *confdbTestSuite) TestAPIReadConfdbError(c *C) { s.state.Lock() custodians := map[string]confdbHooks{"custodian-snap": allHooks} nonCustodians := []string{"test-snap"} @@ -1949,87 +1974,38 @@ func (s *confdbTestSuite) TestGetTransactionForAPIError(c *C) { c.Assert(errKind, Equals, "option-not-found") } -// TODO: replace these tests once the snapctl flow is also blocking -func (s *confdbTestSuite) TestConcurrentAccessWithOngoingWrite(c *C) { +func (s *confdbTestSuite) TestWriteAffectingEphemeralMustDefineSaveViewHook(c *C) { s.state.Lock() - defer s.state.Unlock() - - s.setupConfdbScenario(c, map[string]confdbHooks{"custodian-snap": allHooks}, nil) - _, restore := s.mockConfdbHooks() - defer restore() - - err := confdbstate.SetWriteTransaction(s.state, s.devAccID, "network", "1") - c.Assert(err, IsNil) - - view := s.dbSchema.View("setup-wifi") - - // reading from the snap - mockHandler := hooktest.NewMockHandler() - ctx, err := hookstate.NewContext(nil, s.state, nil, mockHandler, "") - c.Assert(err, IsNil) - - _, err = confdbstate.GetTransactionForSnapctlGet(ctx, view, nil, nil) - c.Assert(err, ErrorMatches, fmt.Sprintf("cannot access confdb view %s/network/setup-wifi: ongoing write transaction", s.devAccID)) - - // writing (used both from snap or API) - _, _, err = confdbstate.GetTransactionToSet(nil, s.state, view) - c.Assert(err, ErrorMatches, fmt.Sprintf("cannot write confdb through view %s/network/setup-wifi: ongoing transaction", s.devAccID)) -} + hooks := observeView | queryView | loadView | changeView + s.setupConfdbScenario(c, map[string]confdbHooks{"custodian-snap": hooks}, nil) + s.state.Unlock() -func (s *confdbTestSuite) TestConcurrentAccessWithOngoingRead(c *C) { - s.state.Lock() - // it's better not to have hooks here because if we do the GetTransactionForSnapctlGet - // needs to schedule tasks and will block on them, making this test more timing based/annoying - s.setupConfdbScenario(c, map[string]confdbHooks{"custodian-snap": noHooks}, nil) + restore := confdbstate.MockEnsureNow(func(*state.State) { + s.checkOngoingWriteConfdbTx(c, s.devAccID, "network") - err := confdbstate.AddReadTransaction(s.state, s.devAccID, "network", "1") - c.Assert(err, IsNil) - s.state.Unlock() + go s.o.Settle(testutil.HostScaledTimeout(5 * time.Second)) + }) + defer restore() mockHandler := hooktest.NewMockHandler() - ctx, err := hookstate.NewContext(nil, s.state, nil, mockHandler, "") + setup := &hookstate.HookSetup{Snap: "test-snap", Revision: snap.R(1)} + ctx, err := hookstate.NewContext(nil, s.state, setup, mockHandler, "") c.Assert(err, IsNil) ctx.Lock() defer ctx.Unlock() - - view := s.dbSchema.View("setup-wifi") - // writing (used both from snap or API) conflicts - _, _, err = confdbstate.GetTransactionToSet(ctx, s.state, view) - c.Assert(err, ErrorMatches, fmt.Sprintf("cannot write confdb through view %s/network/setup-wifi: ongoing transaction", s.devAccID)) - - // we can read from the API and the snap concurrently with other reads - _, err = confdbstate.ReadConfdb(context.Background(), s.state, view, []string{"ssid"}, nil, confdb.AdminAccess) - c.Assert(err, IsNil) - - _, err = confdbstate.GetTransactionForSnapctlGet(ctx, view, []string{"ssid"}, nil) - c.Assert(err, IsNil) -} - -func (s *confdbTestSuite) TestWriteAffectingEphemeralMustDefineSaveViewHook(c *C) { - s.state.Lock() - defer s.state.Unlock() - - hooks := observeView | queryView | loadView | changeView - s.setupConfdbScenario(c, map[string]confdbHooks{"custodian-snap": hooks}, nil) - view := s.dbSchema.View("setup-wifi") - tx, commitTx, err := confdbstate.GetTransactionToSet(nil, s.state, view) - c.Assert(err, IsNil) - err = tx.Set(parsePath(c, "wifi.eph"), "foo") - c.Assert(err, IsNil) // can't write an ephemeral path w/o a save-view hook - _, _, err = commitTx() + err = confdbstate.WriteConfdbFromSnap(ctx, view, map[string]any{ + "eph": "foo", + }) c.Assert(err, ErrorMatches, fmt.Sprintf("cannot access %s/network/setup-wifi: write might change ephemeral data but no custodians has a save-view hook", s.devAccID)) - err = tx.Clear(s.state) - c.Assert(err, IsNil) - err = tx.Set(parsePath(c, "wifi.ssid"), "foo") - c.Assert(err, IsNil) - // but we can if the path can't touch any ephemeral data - _, _, err = commitTx() + err = confdbstate.WriteConfdbFromSnap(ctx, view, map[string]any{ + "ssid": "foo", + }) c.Assert(err, IsNil) } @@ -2039,14 +2015,15 @@ func (s *confdbTestSuite) TestReadCoveringEphemeralMustDefineLoadViewHook(c *C) s.setupConfdbScenario(c, map[string]confdbHooks{"custodian-snap": hooks}, nil) mockHandler := hooktest.NewMockHandler() - ctx, err := hookstate.NewContext(nil, s.state, nil, mockHandler, "") + setup := &hookstate.HookSetup{Snap: "test-snap", Revision: snap.R(1)} + ctx, err := hookstate.NewContext(nil, s.state, setup, mockHandler, "") c.Assert(err, IsNil) s.state.Unlock() ctx.Lock() view := s.dbSchema.View("setup-wifi") // can't read an ephemeral path w/o a load-view hook - _, err = confdbstate.GetTransactionForSnapctlGet(ctx, view, []string{"eph"}, nil) + _, err = confdbstate.ReadConfdbFromSnap(ctx, view, []string{"eph"}, nil) c.Assert(err, ErrorMatches, fmt.Sprintf("cannot schedule tasks to access %s/network/setup-wifi: read might cover ephemeral data but no custodian has a load-view hook", s.devAccID)) // so we don't block on the read @@ -2054,7 +2031,7 @@ func (s *confdbTestSuite) TestReadCoveringEphemeralMustDefineLoadViewHook(c *C) defer restore() // but if the path isn't ephemeral it's fine - _, err = confdbstate.GetTransactionForSnapctlGet(ctx, view, []string{"ssid"}, nil) + _, err = confdbstate.ReadConfdbFromSnap(ctx, view, []string{"ssid"}, nil) c.Assert(err, ErrorMatches, fmt.Sprintf("cannot load confdb %s/network in change 1: timed out after 0s", s.devAccID)) ctx.Unlock() @@ -2074,7 +2051,8 @@ func (s *confdbTestSuite) TestBadPathHookChecks(c *C) { s.setupConfdbScenario(c, map[string]confdbHooks{"custodian-snap": allHooks}, nil) mockHandler := hooktest.NewMockHandler() - ctx, err := hookstate.NewContext(nil, s.state, nil, mockHandler, "") + setup := &hookstate.HookSetup{Snap: "test-snap", Hook: "change-view-setup"} + ctx, err := hookstate.NewContext(nil, s.state, setup, mockHandler, "") c.Assert(err, IsNil) s.state.Unlock() @@ -2082,18 +2060,51 @@ func (s *confdbTestSuite) TestBadPathHookChecks(c *C) { defer ctx.Unlock() view := s.dbSchema.View("setup-wifi") - _, err = confdbstate.GetTransactionForSnapctlGet(ctx, view, []string{"foo"}, nil) + _, err = confdbstate.ReadConfdbFromSnap(ctx, view, []string{"foo"}, nil) c.Assert(err, ErrorMatches, fmt.Sprintf(`cannot get "foo" through %s/network/setup-wifi: no matching rule`, s.devAccID)) _, err = confdbstate.ReadConfdb(context.Background(), s.state, view, []string{"foo"}, nil, confdb.AdminAccess) c.Assert(err, ErrorMatches, fmt.Sprintf(`cannot get "foo" through %s/network/setup-wifi: no matching rule`, s.devAccID)) - tx, commitTxFunc, err := confdbstate.GetTransactionToSet(nil, s.state, view) - c.Assert(err, IsNil) - // this shouldn't happen unless there's a mismatch between views and schemas but check we're robust - c.Assert(tx.Set(parsePath(c, "foo"), "bar"), IsNil) - _, _, err = commitTxFunc() - c.Assert(err, ErrorMatches, `cannot check if write affects ephemeral data: cannot use "foo" as key in map`) + err = confdbstate.WriteConfdbFromSnap(ctx, view, map[string]any{"foo": "bar"}) + c.Assert(err, ErrorMatches, fmt.Sprintf(`cannot set "foo" through %s/network/setup-wifi: no matching rule`, s.devAccID)) +} + +func (s *confdbTestSuite) TestCanHookSetConfdb(c *C) { + s.state.Lock() + defer s.state.Unlock() + + mockHandler := hooktest.NewMockHandler() + chg := s.state.NewChange("test", "test change") + task := s.state.NewTask("test-task", "test task") + chg.AddTask(task) + + for _, tc := range []struct { + hook string + task *state.Task + expected bool + }{ + // we can set to modify transactions in read or write + {hook: "change-view-setup", task: task, expected: true}, + {hook: "query-view-setup", task: task, expected: true}, + // also to load data into a transaction + {hook: "load-view-setup", task: task, expected: true}, + // the other hooks cannot set + {hook: "save-view-setup", task: task, expected: false}, + {hook: "observe-view-setup", task: task, expected: false}, + // same for non-confdb hooks + {hook: "install", task: task, expected: false}, + {hook: "configure", task: task, expected: false}, + // helper expects the context to not be ephemeral + {hook: "change-view-setup", task: nil, expected: false}, + {hook: "query-view-setup", task: nil, expected: false}, + {hook: "load-view-setup", task: nil, expected: false}, + } { + setup := &hookstate.HookSetup{Snap: "test-snap", Hook: tc.hook} + ctx, err := hookstate.NewContext(tc.task, s.state, setup, mockHandler, "") + c.Assert(err, IsNil) + c.Check(confdbstate.CanHookSetConfdb(ctx), Equals, tc.expected) + } } func (s *confdbTestSuite) TestEnsureLoopLogging(c *C) { @@ -2161,7 +2172,7 @@ func (s *confdbTestSuite) TestGetTransactionWithSecretVisibility(c *C) { c.Assert(log[0], Matches, fmt.Sprintf(`.*cannot get "private" through %s/network/setup-wifi: unauthorized access`, s.devAccID)) } -func (s *confdbTestSuite) TestReadWithOngoingWrite(c *C) { +func (s *confdbTestSuite) TestAPIReadWithOngoingWrite(c *C) { view := s.dbSchema.View("setup-wifi") firstAccess := func(ctx context.Context) string { chgID, err := confdbstate.WriteConfdb(ctx, s.state, view, map[string]any{"ssid": "foo"}) @@ -2176,7 +2187,7 @@ func (s *confdbTestSuite) TestReadWithOngoingWrite(c *C) { s.testConcurrentAccess(c, firstAccess, secondAccess) } -func (s *confdbTestSuite) TestWriteWithOngoingWrite(c *C) { +func (s *confdbTestSuite) TestAPIWriteWithOngoingWrite(c *C) { view := s.dbSchema.View("setup-wifi") firstAccess := func(ctx context.Context) string { chgID, err := confdbstate.WriteConfdb(ctx, s.state, view, map[string]any{"ssid": "foo"}) @@ -2191,7 +2202,7 @@ func (s *confdbTestSuite) TestWriteWithOngoingWrite(c *C) { s.testConcurrentAccess(c, firstAccess, secondAccess) } -func (s *confdbTestSuite) TestWriteWithOngoingRead(c *C) { +func (s *confdbTestSuite) TestAPIWriteWithOngoingRead(c *C) { view := s.dbSchema.View("setup-wifi") firstAccess := func(ctx context.Context) string { chgID, err := confdbstate.ReadConfdb(ctx, s.state, view, []string{"ssid"}, nil, 0) @@ -2223,7 +2234,7 @@ func (s *confdbTestSuite) testConcurrentAccess(c *C, firstAccess, secondAccess a // testing helper closed when the access is about to block blockingChan := make(chan struct{}) - confdbstate.SetBlockingSignalChan(blockingChan) + confdbstate.SetBlockingSignal("wait-for-access", blockingChan) doneChan := make(chan struct{}) var secondChgID string @@ -2236,7 +2247,6 @@ func (s *confdbTestSuite) testConcurrentAccess(c *C, firstAccess, secondAccess a select { case <-blockingChan: // signals that the second access is going to block - break case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): c.Fatal("expected access to block but timed out") } @@ -2250,7 +2260,6 @@ func (s *confdbTestSuite) testConcurrentAccess(c *C, firstAccess, secondAccess a select { case <-doneChan: // signals that the second access was unblocked and scheduled the operation - break case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): c.Fatal("expected access to block but timed out") } @@ -2261,7 +2270,7 @@ func (s *confdbTestSuite) testConcurrentAccess(c *C, firstAccess, secondAccess a c.Assert(secondChgID, Not(Equals), "") } -func (s *confdbTestSuite) TestMultipleConcurrentReads(c *C) { +func (s *confdbTestSuite) TestAPIMultipleConcurrentReads(c *C) { s.state.Lock() defer s.state.Unlock() @@ -2281,9 +2290,8 @@ func (s *confdbTestSuite) TestMultipleConcurrentReads(c *C) { c.Assert(err, IsNil) c.Assert(secondChgID, Not(Equals), "") - // mock a pending write - waitChan := make(chan struct{}) - s.state.Cache("confdb-accesses-"+view.Schema().Account+"/network", []confdbstate.PendingAccess{{ + waitChan := make(chan struct{}, 1) + s.state.Cache("pending-confdb-"+view.Schema().Account+"/network", []confdbstate.Access{{ ID: "foo", AccessType: confdbstate.AccessType("write"), WaitChan: waitChan, @@ -2297,7 +2305,7 @@ func (s *confdbTestSuite) TestMultipleConcurrentReads(c *C) { select { case <-waitChan: // only one read tx close this otherwise the other would panic - case <-time.After(2 * time.Second): + case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): c.Fatal("expected write to be unblocked but timed out") } @@ -2323,7 +2331,7 @@ func (s *confdbTestSuite) TestBlockingAccessIsCancelled(c *C) { // testing helper closed when the access is about to block blockingChan := make(chan struct{}) - confdbstate.SetBlockingSignalChan(blockingChan) + confdbstate.SetBlockingSignal("wait-for-access", blockingChan) doneChan := make(chan struct{}) var readErr error @@ -2335,7 +2343,6 @@ func (s *confdbTestSuite) TestBlockingAccessIsCancelled(c *C) { select { case <-blockingChan: // signals that the timed out read is done - break case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): c.Fatal("expected access to block but timed out") } @@ -2343,14 +2350,13 @@ func (s *confdbTestSuite) TestBlockingAccessIsCancelled(c *C) { cancel() select { case <-doneChan: - break case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): c.Fatal("expected access to block but timed out") } c.Assert(readErr, ErrorMatches, ".*timed out waiting for access") } -func (s *confdbTestSuite) TestBlockingAccessTimedOut(c *C) { +func (s *confdbTestSuite) TestAPIBlockingAccessTimedOut(c *C) { s.state.Lock() defer s.state.Unlock() @@ -2365,7 +2371,7 @@ func (s *confdbTestSuite) TestBlockingAccessTimedOut(c *C) { // testing helper closed when the access is about to block blockingChan := make(chan struct{}) - confdbstate.SetBlockingSignalChan(blockingChan) + confdbstate.SetBlockingSignal("wait-for-access", blockingChan) restore = confdbstate.MockDefaultWaitTimeout(time.Millisecond) defer restore() @@ -2379,14 +2385,13 @@ func (s *confdbTestSuite) TestBlockingAccessTimedOut(c *C) { select { case <-doneChan: - break case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): c.Fatal("expected access to block but timed out") } c.Assert(readErr, ErrorMatches, ".*timed out waiting for access") } -func (s *confdbTestSuite) TestAccessDifferentConfdbIndependently(c *C) { +func (s *confdbTestSuite) TestAPIAccessDifferentConfdbIndependently(c *C) { s.state.Lock() defer s.state.Unlock() @@ -2401,7 +2406,7 @@ func (s *confdbTestSuite) TestAccessDifferentConfdbIndependently(c *C) { // testing helper closed when the access is about to block blockingChan := make(chan struct{}) - confdbstate.SetBlockingSignalChan(blockingChan) + confdbstate.SetBlockingSignal("wait-for-access", blockingChan) restore = confdbstate.MockDefaultWaitTimeout(time.Millisecond) defer restore() @@ -2413,7 +2418,6 @@ func (s *confdbTestSuite) TestAccessDifferentConfdbIndependently(c *C) { func (s *confdbTestSuite) TestFailedAccessUnblocksNextAccess(c *C) { s.state.Lock() - defer s.state.Unlock() // force the read/writes to fail due to missing custodian repo := interfaces.NewRepository() @@ -2421,6 +2425,7 @@ func (s *confdbTestSuite) TestFailedAccessUnblocksNextAccess(c *C) { view := s.dbSchema.View("setup-wifi") ctx := context.Background() + s.state.Unlock() var accErr error // mock ongoing read transaction and pending access @@ -2435,58 +2440,660 @@ func (s *confdbTestSuite) TestFailedAccessUnblocksNextAccess(c *C) { ongoingTxs[ref] = &confdbstate.ConfdbTransactions{ WriteTxID: "10", } + s.state.Lock() s.state.Set("confdb-ongoing-txs", ongoingTxs) - s.state.Cache("confdb-accesses-"+ref, nil) + s.state.Cache("pending-confdb-"+ref, nil) + s.state.Cache("scheduling-confdb-"+ref, nil) // testing helper closed when the access is about to block blockingChan := make(chan struct{}) - confdbstate.SetBlockingSignalChan(blockingChan) + confdbstate.SetBlockingSignal("wait-for-access", blockingChan) accDone := make(chan struct{}) go func() { accessFunc() + s.state.Unlock() close(accDone) }() select { case <-blockingChan: - case <-time.After(2 * time.Second): + case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): c.Fatal("expected access to block but timed out") } // while the access is blocked mock another one coming in s.state.Lock() - accs := s.state.Cached("confdb-accesses-" + ref) + accs := s.state.Cached("pending-confdb-" + ref) c.Assert(accs, NotNil) - pending := accs.([]confdbstate.PendingAccess) + pending := accs.([]confdbstate.Access) c.Assert(pending, HasLen, 1) - // mock another pending access - waitChan := make(chan struct{}) - pending = append(pending, confdbstate.PendingAccess{ + waitChan := make(chan struct{}, 1) + s.endOngoingAccess(c, &confdbstate.Access{ ID: "foo", AccessType: confdbstate.AccessType("write"), WaitChan: waitChan, }) - s.state.Cache("confdb-accesses-"+ref, pending) s.state.Unlock() - // unblock the access we started - close(pending[0].WaitChan) - // the access we mocked should be unblocked select { case <-waitChan: - case <-time.After(2 * time.Second): + case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): c.Fatal("expected next access to be unblocked but timed out") } // the access failed with the expected error select { case <-accDone: - case <-time.After(2 * time.Second): + case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): c.Fatal("expected failed access to return but timed out") } c.Assert(accErr, ErrorMatches, ".*: no custodian snap connected") } } + +func (s *confdbTestSuite) testSnapctlConcurrentAccess(c *C, firstAccess accessFunc, secondAccess func()) { + s.state.Lock() + + s.setupConfdbScenario(c, map[string]confdbHooks{"custodian-snap": allHooks}, nil) + _, restore := s.mockConfdbHooks() + defer restore() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + firstAccess(ctx) + s.state.Unlock() + + // testing helper closed when the access is about to block + blockingChan := make(chan struct{}) + confdbstate.SetBlockingSignal("wait-for-access", blockingChan) + + doneChan := make(chan struct{}) + go func() { + secondAccess() + close(doneChan) + }() + + select { + case <-blockingChan: + // second access blocked waiting for its turn + case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): + c.Fatal("expected access to block but timed out") + } + + // closed when the second access waits for the change to complete + blockingChan = make(chan struct{}) + confdbstate.SetBlockingSignal("wait-for-change-done", blockingChan) + + err := s.o.Settle(5 * time.Second) + c.Assert(err, IsNil) + + // once the first access completes the second access should be unblocked, scheduled + // and again while the change runs + select { + case <-blockingChan: + case <-time.After(testutil.HostScaledTimeout(5 * time.Second)): + c.Fatal("expected second access to block while change runs but timed out") + } + + // when the second access is ongoing and waiting for the change to end, the + // queues are empty + s.state.Lock() + txs, _, err := confdbstate.GetOngoingTxs(s.state, s.devAccID, "network") + s.state.Unlock() + c.Assert(err, IsNil) + c.Assert(txs.Pending, IsNil) + c.Assert(txs.Scheduling, IsNil) + + err = s.o.Settle(5 * time.Second) + c.Assert(err, IsNil) + + select { + case <-doneChan: + case <-time.After(testutil.HostScaledTimeout(5 * time.Second)): + c.Fatal("expected access to block but timed out") + } +} + +func (s *confdbTestSuite) TestSnapctlWriteOngoingRead(c *C) { + view := s.dbSchema.View("setup-wifi") + + firstAccess := func(ctx context.Context) string { + chgID, err := confdbstate.ReadConfdb(ctx, s.state, view, []string{"ssid"}, nil, 0) + c.Assert(err, IsNil) + return chgID + } + + mockHandler := hooktest.NewMockHandler() + setup := &hookstate.HookSetup{Snap: "test-snap", Hook: "change-view-setup"} + ctx, err := hookstate.NewContext(nil, s.state, setup, mockHandler, "") + c.Assert(err, IsNil) + + secondAccess := func() { + ctx.Lock() + err := confdbstate.WriteConfdbFromSnap(ctx, view, map[string]any{"ssid": "foo"}) + ctx.Unlock() + c.Assert(err, IsNil) + } + s.testSnapctlConcurrentAccess(c, firstAccess, secondAccess) +} + +func (s *confdbTestSuite) TestSnapctlReadOngoingWrite(c *C) { + view := s.dbSchema.View("setup-wifi") + + mockHandler := hooktest.NewMockHandler() + setup := &hookstate.HookSetup{Snap: "test-snap", Hook: "change-view-setup"} + ctx, err := hookstate.NewContext(nil, s.state, setup, mockHandler, "") + c.Assert(err, IsNil) + + firstAccess := func(ctx context.Context) string { + chgID, err := confdbstate.WriteConfdb(ctx, s.state, view, map[string]any{"ssid": "foo"}) + c.Assert(err, IsNil) + return chgID + } + + secondAccess := func() { + ctx.Lock() + _, err := confdbstate.ReadConfdbFromSnap(ctx, view, []string{"ssid"}, nil) + ctx.Unlock() + c.Assert(err, IsNil) + } + s.testSnapctlConcurrentAccess(c, firstAccess, secondAccess) +} + +func (s *confdbTestSuite) TestReadWithOngoingReadBlocksIfWriteIsPending(c *C) { + s.state.Lock() + defer s.state.Unlock() + + view := s.dbSchema.View("setup-wifi") + + // mock ongoing read transaction and pending access + ref := s.devAccID + "/network" + ongoingTxs := make(map[string]*confdbstate.ConfdbTransactions) + ongoingTxs[ref] = &confdbstate.ConfdbTransactions{ + ReadTxIDs: []string{"10"}, + } + s.state.Set("confdb-ongoing-txs", ongoingTxs) + s.state.Cache("pending-confdb-"+ref, []confdbstate.Access{{ + ID: "foo", + AccessType: confdbstate.AccessType("write"), + WaitChan: make(chan struct{}), + }}) + + // testing helper closed when the access is about to block + blockingChan := make(chan struct{}) + confdbstate.SetBlockingSignal("wait-for-access", blockingChan) + + ctx, cancel := context.WithCancel(context.Background()) + readDone := make(chan struct{}) + go func() { + _, err := confdbstate.ReadConfdb(ctx, s.state, view, []string{"ssid"}, nil, 0) + c.Assert(err, ErrorMatches, fmt.Sprintf("cannot read %s: timed out waiting for access", view.ID())) + close(readDone) + }() + + select { + case <-blockingChan: + case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): + c.Fatal("expected access to block but timed out") + } + + // the read access released the lock and blocked so we have to re-lock + s.state.Lock() + pending, ok := s.state.Cached("pending-confdb-" + ref).([]confdbstate.Access) + s.state.Unlock() + c.Assert(ok, Equals, true) + c.Assert(pending, HasLen, 2) + c.Assert(pending[1].AccessType, Equals, confdbstate.AccessType("read")) + + // cancel the pending read access which should return an error and clean up + // its waiting channel from the pending queue + cancel() + + select { + case <-readDone: + // at this point the read returned and the state was re-locked + case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): + c.Fatal("expected access to block but timed out") + } + + // check that cancelling an access cleans up the pending state + pending, ok = s.state.Cached("pending-confdb-" + ref).([]confdbstate.Access) + c.Assert(ok, Equals, true) + c.Assert(pending, HasLen, 1) + c.Assert(pending[0].AccessType, Equals, confdbstate.AccessType("write")) +} + +func (s *confdbTestSuite) TestSnapctlReadAndWriteUseHookTimeout(c *C) { + s.state.Lock() + s.setupConfdbScenario(c, map[string]confdbHooks{"custodian-snap": allHooks}, nil) + _, restore := s.mockConfdbHooks() + defer restore() + + view := s.dbSchema.View("setup-wifi") + + mockHandler := hooktest.NewMockHandler() + setup := &hookstate.HookSetup{Snap: "test-snap", Hook: "change-view-setup", Timeout: time.Microsecond} + ctx, err := hookstate.NewContext(nil, s.state, setup, mockHandler, "") + c.Assert(err, IsNil) + + ref := s.devAccID + "/network" + ongoingTxs := make(map[string]*confdbstate.ConfdbTransactions) + ongoingTxs[ref] = &confdbstate.ConfdbTransactions{ + WriteTxID: "10", + } + s.state.Set("confdb-ongoing-txs", ongoingTxs) + s.state.Unlock() + + ctx.Lock() + defer ctx.Unlock() + + _, err = confdbstate.ReadConfdbFromSnap(ctx, view, []string{"ssid"}, nil) + c.Assert(err, ErrorMatches, fmt.Sprintf("cannot read %s: timed out waiting for access", view.ID())) + + err = confdbstate.WriteConfdbFromSnap(ctx, view, map[string]any{"ssid": "foo"}) + c.Assert(err, ErrorMatches, fmt.Sprintf("cannot write %s: timed out waiting for access", view.ID())) +} + +func (s *confdbTestSuite) TestOngoingTxUnblocksMultiplePendingReads(c *C) { + s.state.Lock() + defer s.state.Unlock() + + s.setupConfdbScenario(c, map[string]confdbHooks{"custodian-snap": allHooks}, nil) + _, restore := s.mockConfdbHooks() + defer restore() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + view := s.dbSchema.View("setup-wifi") + chgID, err := confdbstate.WriteConfdb(ctx, s.state, view, map[string]any{"ssid": "foo"}) + c.Assert(err, IsNil) + + readOneChan, readTwoChan, writeChan := make(chan struct{}, 1), make(chan struct{}, 1), make(chan struct{}, 1) + s.state.Cache("pending-confdb-"+view.Schema().Account+"/network", []confdbstate.Access{ + { + ID: "foo", + AccessType: confdbstate.AccessType("read"), + WaitChan: readOneChan, + }, + { + ID: "bar", + AccessType: confdbstate.AccessType("read"), + WaitChan: readTwoChan, + }, + { + ID: "baz", + AccessType: confdbstate.AccessType("write"), + WaitChan: writeChan, + }, + }) + + s.state.Unlock() + err = s.o.Settle(5 * time.Second) + s.state.Lock() + c.Assert(err, IsNil) + + chg := s.state.Change(chgID) + c.Assert(chg.Status(), Equals, state.DoneStatus) + + // the running transaction unblocked the reads + select { + case <-readOneChan: + case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): + c.Fatal("expected 1st read to be unblocked but timed out") + } + + select { + case <-readTwoChan: + case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): + c.Fatal("expected 2nd read to be unblocked but timed out") + } + + // but not the write + select { + case <-writeChan: + c.Fatal("expected write not to have been unblocked") + case <-time.After(testutil.HostScaledTimeout(time.Millisecond)): + } +} + +func (s *confdbTestSuite) TestAPIConfdbErrorUnblocksNextAccess(c *C) { + s.state.Lock() + s.setupConfdbScenario(c, map[string]confdbHooks{"custodian-snap": allHooks}, nil) + s.state.Unlock() + + view := s.dbSchema.View("setup-wifi") + ref := view.Schema().Account + "/" + view.Schema().Name + ctx := context.Background() + + var accErr error + for _, accFunc := range []func(){ + func() { + _, accErr = confdbstate.WriteConfdb(ctx, s.state, view, map[string]any{"nonexistent": "value"}) + }, + func() { _, accErr = confdbstate.ReadConfdb(ctx, s.state, view, []string{"nonexistent"}, nil, 0) }, + } { + s.state.Lock() + // mock an ongoing write transaction so the next access blocks + s.state.Set("confdb-ongoing-txs", map[string]*confdbstate.ConfdbTransactions{ + ref: {WriteTxID: "10"}, + }) + s.state.Cache("pending-confdb-"+ref, nil) + s.state.Cache("scheduling-confdb-"+ref, nil) + + blockingChan := make(chan struct{}) + confdbstate.SetBlockingSignal("wait-for-access", blockingChan) + + doneChan := make(chan struct{}) + go func() { + accFunc() + s.state.Unlock() + close(doneChan) + }() + + select { + case <-blockingChan: + case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): + c.Fatal("expected access to block but timed out") + } + + // the blocked access released the lock; set up the next pending access + s.state.Lock() + accs := s.state.Cached("pending-confdb-" + ref) + c.Assert(accs, NotNil) + pending := accs.([]confdbstate.Access) + c.Assert(pending, HasLen, 1) + + // clear the ongoing tx, queue another pending access, then unblock + nextWaitChan := make(chan struct{}, 1) + s.endOngoingAccess(c, &confdbstate.Access{ + ID: "next-access", + AccessType: confdbstate.AccessType("write"), + WaitChan: nextWaitChan, + }) + s.state.Unlock() + + // the access should fail and unblock the next pending access + select { + case <-nextWaitChan: + case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): + c.Fatal("expected next access to be unblocked but timed out") + } + + select { + case <-doneChan: + case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): + c.Fatal("expected failed write to return but timed out") + } + c.Assert(accErr, ErrorMatches, `.*no matching rule`) + } +} + +// endOngoingAccess can be used to simulate the termination of a mocked ongoing +// transaction. It unsets the ongoing tx in the state, unblocks the next pending +// accesses and moves them to processing. If a new pending access is provided, +// it's set in the state. +func (s *confdbTestSuite) endOngoingAccess(c *C, newPending *confdbstate.Access) { + txs, updateFunc, err := confdbstate.GetOngoingTxs(s.state, s.devAccID, "network") + c.Assert(err, IsNil) + defer updateFunc(txs) + + txs.ReadTxIDs = nil + txs.WriteTxID = "" + + err = confdbstate.MaybeUnblockAccesses(txs) + c.Assert(err, IsNil) + + if newPending != nil { + txs.Pending = append(txs.Pending, *newPending) + } +} + +func (s *confdbTestSuite) TestSnapctlConfdbErrorUnblocksNextAccess(c *C) { + // force the read/writes to fail due to missing custodian + s.state.Lock() + repo := interfaces.NewRepository() + ifacerepo.Replace(s.state, repo) + + view := s.dbSchema.View("setup-wifi") + ref := view.Schema().Account + "/" + view.Schema().Name + + mockHandler := hooktest.NewMockHandler() + setup := &hookstate.HookSetup{Snap: "test-snap", Hook: "install"} + t := s.state.NewTask("run-hook", "") + chg := s.state.NewChange("some-change", "") + chg.AddTask(t) + + hookCtx, err := hookstate.NewContext(t, s.state, setup, mockHandler, "") + c.Assert(err, IsNil) + s.state.Unlock() + + var accErr error + for _, accFunc := range []func(){ + func() { + _, accErr = confdbstate.ReadConfdbFromSnap(hookCtx, view, []string{"ssid"}, nil) + }, + func() { + accErr = confdbstate.WriteConfdbFromSnap(hookCtx, view, map[string]any{"ssid": "foo"}) + }, + } { + accErr = nil + + s.state.Lock() + s.state.Set("confdb-ongoing-txs", map[string]*confdbstate.ConfdbTransactions{ + ref: {WriteTxID: "10"}, + }) + s.state.Cache("pending-confdb-"+ref, nil) + s.state.Cache("scheduling-confdb-"+ref, nil) + + blockingChan := make(chan struct{}) + confdbstate.SetBlockingSignal("wait-for-access", blockingChan) + s.state.Unlock() + + accDone := make(chan struct{}) + go func() { + hookCtx.Lock() + accFunc() + hookCtx.Unlock() + close(accDone) + }() + + select { + case <-blockingChan: + case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): + c.Fatal("expected access to block but timed out") + } + + // the blocked access released the lock; set up the next pending access + s.state.Lock() + accs := s.state.Cached("pending-confdb-" + ref) + c.Assert(accs, NotNil) + pending := accs.([]confdbstate.Access) + c.Assert(pending, HasLen, 1) + + // clear the ongoing tx, queue another pending access, then unblock + nextWaitChan := make(chan struct{}, 1) + s.endOngoingAccess(c, &confdbstate.Access{ + ID: "next-access", + AccessType: confdbstate.AccessType("write"), + WaitChan: nextWaitChan, + }) + s.state.Unlock() + + // the failed access should unblock the next pending access + select { + case <-nextWaitChan: + case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): + c.Fatal("expected next access to be unblocked but timed out") + } + + select { + case <-accDone: + case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): + c.Fatal("expected failed access to return but timed out") + } + c.Assert(accErr, ErrorMatches, ".*: no custodian snap connected") + } +} + +func (s *confdbTestSuite) TestReadConfdbFromSnapNoHooksToRun(c *C) { + s.state.Lock() + + // the custodian snap has no hooks, so no tasks should be scheduled + custodians := map[string]confdbHooks{"custodian-snap": noHooks} + s.setupConfdbScenario(c, custodians, nil) + + // write some value for the get to read + bag := confdb.NewJSONDatabag() + err := bag.Set(parsePath(c, "wifi.ssid"), "foo") + c.Assert(err, IsNil) + + view := s.dbSchema.View("setup-wifi") + ref := view.Schema().Account + "/" + view.Schema().Name + s.state.Set("confdb-databags", map[string]map[string]confdb.JSONDatabag{s.devAccID: {"network": bag}}) + + mockHandler := hooktest.NewMockHandler() + setup := &hookstate.HookSetup{Snap: "test-snap", Hook: "change-view-setup"} + hookCtx, err := hookstate.NewContext(nil, s.state, setup, mockHandler, "") + c.Assert(err, IsNil) + + // simulate an ongoing write transaction so the read blocks + s.state.Set("confdb-ongoing-txs", map[string]*confdbstate.ConfdbTransactions{ + ref: {WriteTxID: "10"}, + }) + + blockingChan := make(chan struct{}) + confdbstate.SetBlockingSignal("wait-for-access", blockingChan) + s.state.Unlock() + + var tx *confdbstate.Transaction + var readErr error + doneChan := make(chan struct{}) + go func() { + hookCtx.Lock() + tx, readErr = confdbstate.ReadConfdbFromSnap(hookCtx, view, []string{"ssid"}, nil) + hookCtx.Unlock() + close(doneChan) + }() + + select { + case <-blockingChan: + case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): + c.Fatal("expected access to block but timed out") + } + + // clear the ongoing tx, queue another pending access, then unblock + nextWaitChan := make(chan struct{}, 1) + s.state.Lock() + s.endOngoingAccess(c, &confdbstate.Access{ + ID: "next-write", + AccessType: confdbstate.AccessType("write"), + WaitChan: nextWaitChan, + }) + s.state.Unlock() + + // the no-hooks read path should unblock the next pending access + select { + case <-nextWaitChan: + case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): + c.Fatal("expected next access to be unblocked but timed out") + } + + select { + case <-doneChan: + case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): + c.Fatal("expected read to complete but timed out") + } + + c.Assert(readErr, IsNil) + c.Assert(tx, NotNil) + + s.state.Lock() + defer s.state.Unlock() + + // no tasks were scheduled because there are no hooks to run + c.Assert(s.state.Changes(), HasLen, 0) + + val, err := tx.Get(parsePath(c, "wifi.ssid"), nil) + c.Assert(err, IsNil) + c.Assert(val, Equals, "foo") +} + +func (s *confdbTestSuite) TestAPIBlockingAccessTimedOutRacesWithUnblock(c *C) { + s.state.Lock() + defer s.state.Unlock() + + s.setupConfdbScenario(c, map[string]confdbHooks{"custodian-snap": allHooks}, nil) + _, restore := s.mockConfdbHooks() + defer restore() + + view := s.dbSchema.View("setup-wifi") + ref := view.Schema().Account + "/" + view.Schema().Name + // simulate an ongoing write transaction so the read blocks + s.state.Set("confdb-ongoing-txs", map[string]*confdbstate.ConfdbTransactions{ + ref: {WriteTxID: "10"}, + }) + + blockingChan := make(chan struct{}) + confdbstate.SetBlockingSignal("wait-for-access", blockingChan) + + ctx, cancel := context.WithCancel(context.Background()) + doneChan := make(chan struct{}) + var cancelErr error + go func() { + _, cancelErr = confdbstate.WriteConfdb(ctx, s.state, view, map[string]any{"ssid": "foo"}) + close(doneChan) + }() + + select { + case <-blockingChan: + case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): + c.Fatal("expected access to block but timed out") + } + + // mock a time out/cancel racing with an unblock + s.state.Lock() + cancel() + waitChan := make(chan struct{}, 1) + // in order to mock a race, we need to cancel the context and mock that another + // goroutine unblocked the channel and removed it. We won't actually unblock + // the channel otherwise we couldn't be sure which case the select would pick + txs, updateFunc, err := confdbstate.GetOngoingTxs(s.state, s.devAccID, "network") + c.Assert(err, IsNil) + c.Assert(txs.Pending, HasLen, 1) + c.Assert(txs.Pending[0].AccessType, Equals, confdbstate.AccessType("write")) + + // mock another goroutine unblocking the pending write + txs.WriteTxID = "" + txs.Scheduling = txs.Pending + txs.Pending = []confdbstate.Access{{ + ID: "next-read", + AccessType: confdbstate.AccessType("read"), + WaitChan: waitChan, + }} + updateFunc(txs) + s.state.Unlock() + + select { + case <-doneChan: + case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): + c.Fatal("expected access to block but timed out") + } + c.Assert(cancelErr, ErrorMatches, ".*timed out waiting for access") + + // even though the pending access was already unblocked, the time out/cancel + // still cleaned up its state and unblocked the next access + cached := s.state.Cached("scheduling-confdb-" + ref).([]confdbstate.Access) + c.Assert(cached, HasLen, 1) + c.Assert(cached[0].AccessType, Equals, confdbstate.AccessType("read")) + + select { + case <-waitChan: + case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): + c.Fatal("expected access to block but timed out") + } +} diff --git a/overlord/confdbstate/export_test.go b/overlord/confdbstate/export_test.go index 3ed2d493d85..a040bc8ebfc 100644 --- a/overlord/confdbstate/export_test.go +++ b/overlord/confdbstate/export_test.go @@ -39,15 +39,10 @@ var ( type ( ConfdbTransactions = confdbTransactions - PendingAccess = pendingAccess + Access = access AccessType = accessType ) -const ( - CommitEdge = commitEdge - ClearTxEdge = clearTxEdge -) - func ChangeViewHandlerGenerator(ctx *hookstate.Context) hookstate.Handler { return &changeViewHandler{ctx: ctx} } @@ -96,6 +91,21 @@ func MockDefaultWaitTimeout(dur time.Duration) func() { } } -func SetBlockingSignalChan(signalChan chan struct{}) { - blockingSignalChan = signalChan +func SetBlockingSignal(key string, signalChan chan struct{}) { + if blockingSignals == nil { + blockingSignals = make(map[string]chan struct{}) + } + blockingSignals[key] = signalChan +} + +func ResetBlockingSignals() { + blockingSignals = nil +} + +func MaybeUnblockAccesses(txs *confdbTransactions) error { + return maybeUnblockAccesses(txs) +} + +func GetOngoingTxs(st *state.State, account, schemaName string) (ongoingTxs *confdbTransactions, updateTxStateFunc func(*confdbTransactions), err error) { + return getOngoingTxs(st, account, schemaName) } diff --git a/overlord/hookstate/ctlcmd/export_test.go b/overlord/hookstate/ctlcmd/export_test.go index b91f0fea492..78a8b63fd79 100644 --- a/overlord/hookstate/ctlcmd/export_test.go +++ b/overlord/hookstate/ctlcmd/export_test.go @@ -187,11 +187,11 @@ func MockNewStatusDecorator(f func(ctx context.Context, isGlobal bool, uid strin return restore } -func MockConfdbstateTransactionForSet(f func(*hookstate.Context, *state.State, *confdb.View) (*confdbstate.Transaction, confdbstate.CommitTxFunc, error)) (restore func()) { - old := confdbstateTransactionForSet - confdbstateTransactionForSet = f +func MockConfdbstateWriteConfdb(f func(*hookstate.Context, *confdb.View, map[string]any) error) (restore func()) { + old := confdbstateWriteConfdb + confdbstateWriteConfdb = f return func() { - confdbstateTransactionForSet = old + confdbstateWriteConfdb = old } } @@ -203,10 +203,10 @@ func MockConfdbstateGetView(f func(st *state.State, account, confdbName, viewNam } } -func MockConfdbstateTransactionForGet(f func(*hookstate.Context, *confdb.View, []string, map[string]any) (*confdbstate.Transaction, error)) (restore func()) { - old := confdbstateTransactionForGet - confdbstateTransactionForGet = f +func MockConfdbstateReadConfdb(f func(*hookstate.Context, *confdb.View, []string, map[string]any) (*confdbstate.Transaction, error)) (restore func()) { + old := confdbstateReadConfdb + confdbstateReadConfdb = f return func() { - confdbstateTransactionForGet = old + confdbstateReadConfdb = old } } diff --git a/overlord/hookstate/ctlcmd/get.go b/overlord/hookstate/ctlcmd/get.go index 0f84ee39f4a..aaf486871d8 100644 --- a/overlord/hookstate/ctlcmd/get.go +++ b/overlord/hookstate/ctlcmd/get.go @@ -42,8 +42,8 @@ import ( ) var ( - confdbstateGetView = confdbstate.GetView - confdbstateTransactionForGet = confdbstate.GetTransactionForSnapctlGet + confdbstateGetView = confdbstate.GetView + confdbstateReadConfdb = confdbstate.ReadConfdbFromSnap ) type getCommand struct { @@ -448,7 +448,8 @@ func (c *getCommand) getConfdbValues(ctx *hookstate.Context, plugName string, re return err } - tx, err := confdbstateTransactionForGet(ctx, view, requests, constraints) + // TODO: add --wait-for timeout to options and cache in hookstate context + tx, err := confdbstateReadConfdb(ctx, view, requests, constraints) if err != nil { return err } diff --git a/overlord/hookstate/ctlcmd/get_test.go b/overlord/hookstate/ctlcmd/get_test.go index faf0ffa6b9c..05f9e328896 100644 --- a/overlord/hookstate/ctlcmd/get_test.go +++ b/overlord/hookstate/ctlcmd/get_test.go @@ -634,7 +634,7 @@ func (s *confdbSuite) TestConfdbGetSingleView(c *C) { c.Assert(err, IsNil) s.state.Unlock() - restore := ctlcmd.MockConfdbstateTransactionForGet(func(ctx *hookstate.Context, view *confdb.View, requests []string, _ map[string]any) (*confdbstate.Transaction, error) { + restore := ctlcmd.MockConfdbstateReadConfdb(func(ctx *hookstate.Context, view *confdb.View, requests []string, _ map[string]any) (*confdbstate.Transaction, error) { c.Assert(requests, DeepEquals, []string{"ssid"}) c.Assert(view.Schema().Account, Equals, s.devAccID) c.Assert(view.Schema().Name, Equals, "network") @@ -658,7 +658,7 @@ func (s *confdbSuite) TestConfdbGetManyViews(c *C) { c.Assert(err, IsNil) s.state.Unlock() - restore := ctlcmd.MockConfdbstateTransactionForGet(func(ctx *hookstate.Context, view *confdb.View, requests []string, _ map[string]any) (*confdbstate.Transaction, error) { + restore := ctlcmd.MockConfdbstateReadConfdb(func(ctx *hookstate.Context, view *confdb.View, requests []string, _ map[string]any) (*confdbstate.Transaction, error) { c.Assert(requests, DeepEquals, []string{"ssid", "password"}) c.Assert(view.Schema().Account, Equals, s.devAccID) c.Assert(view.Schema().Name, Equals, "network") @@ -687,7 +687,7 @@ func (s *confdbSuite) TestConfdbGetNoRequest(c *C) { c.Assert(err, IsNil) s.state.Unlock() - restore := ctlcmd.MockConfdbstateTransactionForGet(func(ctx *hookstate.Context, view *confdb.View, requests []string, _ map[string]any) (*confdbstate.Transaction, error) { + restore := ctlcmd.MockConfdbstateReadConfdb(func(ctx *hookstate.Context, view *confdb.View, requests []string, _ map[string]any) (*confdbstate.Transaction, error) { c.Assert(requests, IsNil) c.Assert(view.Schema().Account, Equals, s.devAccID) c.Assert(view.Schema().Name, Equals, "network") @@ -850,7 +850,7 @@ func (s *confdbSuite) TestConfdbGetPrevious(c *C) { err = tx.Set(parsePath(c, "wifi.ssid"), "bar") c.Assert(err, IsNil) - restore := ctlcmd.MockConfdbstateTransactionForGet(func(*hookstate.Context, *confdb.View, []string, map[string]any) (*confdbstate.Transaction, error) { + restore := ctlcmd.MockConfdbstateReadConfdb(func(*hookstate.Context, *confdb.View, []string, map[string]any) (*confdbstate.Transaction, error) { return tx, nil }) defer restore() @@ -1017,7 +1017,7 @@ func (s *confdbSuite) TestConfdbAccessUnconnectedPlug(c *C) { err = tx.Set(parsePath(c, "wifi.ssid"), "foo") c.Assert(err, IsNil) - restore := ctlcmd.MockConfdbstateTransactionForGet(func(*hookstate.Context, *confdb.View, []string, map[string]any) (*confdbstate.Transaction, error) { + restore := ctlcmd.MockConfdbstateReadConfdb(func(*hookstate.Context, *confdb.View, []string, map[string]any) (*confdbstate.Transaction, error) { c.Fatal("should not allow access to confdb") return tx, nil }) @@ -1077,7 +1077,7 @@ func (s *confdbSuite) TestConfdbDefaultIfNoData(c *C) { err = tx.Set(parsePath(c, "wifi.ssid"), "foo") c.Assert(err, IsNil) - restore := ctlcmd.MockConfdbstateTransactionForGet(func(*hookstate.Context, *confdb.View, []string, map[string]any) (*confdbstate.Transaction, error) { + restore := ctlcmd.MockConfdbstateReadConfdb(func(*hookstate.Context, *confdb.View, []string, map[string]any) (*confdbstate.Transaction, error) { return tx, nil }) defer restore() @@ -1098,7 +1098,7 @@ func (s *confdbSuite) TestConfdbDefaultNoFallbackIfTyped(c *C) { err = tx.Set(parsePath(c, "wifi.ssid"), "foo") c.Assert(err, IsNil) - restore := ctlcmd.MockConfdbstateTransactionForGet(func(*hookstate.Context, *confdb.View, []string, map[string]any) (*confdbstate.Transaction, error) { + restore := ctlcmd.MockConfdbstateReadConfdb(func(*hookstate.Context, *confdb.View, []string, map[string]any) (*confdbstate.Transaction, error) { return tx, nil }) defer restore() @@ -1117,7 +1117,7 @@ func (s *confdbSuite) TestConfdbDefaultWithOtherFlags(c *C) { tx, err := confdbstate.NewTransaction(s.state, s.devAccID, "network") c.Assert(err, IsNil) - restore := ctlcmd.MockConfdbstateTransactionForGet(func(*hookstate.Context, *confdb.View, []string, map[string]any) (*confdbstate.Transaction, error) { + restore := ctlcmd.MockConfdbstateReadConfdb(func(*hookstate.Context, *confdb.View, []string, map[string]any) (*confdbstate.Transaction, error) { return tx, nil }) defer restore() @@ -1191,7 +1191,7 @@ func (s *confdbSuite) TestConfdbGetWithConstraints(c *C) { s.state.Unlock() var gotConstraints map[string]any - restore := ctlcmd.MockConfdbstateTransactionForGet(func(_ *hookstate.Context, _ *confdb.View, _ []string, constraints map[string]any) (*confdbstate.Transaction, error) { + restore := ctlcmd.MockConfdbstateReadConfdb(func(_ *hookstate.Context, _ *confdb.View, _ []string, constraints map[string]any) (*confdbstate.Transaction, error) { gotConstraints = constraints return tx, nil }) @@ -1315,7 +1315,7 @@ func (s *confdbSuite) TestConfdbGetTypedConstraints(c *C) { s.state.Unlock() var gotConstraints map[string]any - restore := ctlcmd.MockConfdbstateTransactionForGet(func(_ *hookstate.Context, _ *confdb.View, _ []string, constraints map[string]any) (*confdbstate.Transaction, error) { + restore := ctlcmd.MockConfdbstateReadConfdb(func(_ *hookstate.Context, _ *confdb.View, _ []string, constraints map[string]any) (*confdbstate.Transaction, error) { gotConstraints = constraints return tx, nil }) @@ -1354,7 +1354,7 @@ func (s *confdbSuite) TestConfdbGetSecretVisibility(c *C) { c.Assert(err, IsNil) s.state.Unlock() - restore := ctlcmd.MockConfdbstateTransactionForGet(func(ctx *hookstate.Context, view *confdb.View, requests []string, _ map[string]any) (*confdbstate.Transaction, error) { + restore := ctlcmd.MockConfdbstateReadConfdb(func(ctx *hookstate.Context, view *confdb.View, requests []string, _ map[string]any) (*confdbstate.Transaction, error) { c.Assert(requests, DeepEquals, []string{"password"}) c.Assert(view.Schema().Account, Equals, s.devAccID) c.Assert(view.Schema().Name, Equals, "network") diff --git a/overlord/hookstate/ctlcmd/set.go b/overlord/hookstate/ctlcmd/set.go index e359e48f8b6..d29ab8dcd1f 100644 --- a/overlord/hookstate/ctlcmd/set.go +++ b/overlord/hookstate/ctlcmd/set.go @@ -36,7 +36,7 @@ import ( "github.com/snapcore/snapd/snap" ) -var confdbstateTransactionForSet = confdbstate.GetTransactionToSet +var confdbstateWriteConfdb = confdbstate.WriteConfdbFromSnap type setCommand struct { baseCommand @@ -244,7 +244,7 @@ func (s *setCommand) setInterfaceSetting(context *hookstate.Context, plugOrSlot return nil } -func setConfdbValues(ctx *hookstate.Context, plugName string, requests map[string]any) error { +func setConfdbValues(ctx *hookstate.Context, plugName string, values map[string]any) error { ctx.Lock() defer ctx.Unlock() @@ -267,28 +267,6 @@ func setConfdbValues(ctx *hookstate.Context, plugName string, requests map[strin return fmt.Errorf("cannot modify confdb in %q hook", ctx.HookName()) } - tx, commitTxFunc, err := confdbstateTransactionForSet(ctx, ctx.State(), view) - if err != nil { - return err - } - - err = confdbstate.SetViaView(tx, view, requests) - if err != nil { - return err - } - - // if a new transaction was created, commit it - if commitTxFunc != nil { - _, waitChan, err := commitTxFunc() - if err != nil { - return err - } - - // wait for the transaction to be committed - ctx.Unlock() - <-waitChan - ctx.Lock() - } - - return nil + // TODO: add --wait-for timeout to options and cache in hookstate context + return confdbstateWriteConfdb(ctx, view, values) } diff --git a/overlord/hookstate/ctlcmd/set_test.go b/overlord/hookstate/ctlcmd/set_test.go index 2ac846b60ab..d73689771d9 100644 --- a/overlord/hookstate/ctlcmd/set_test.go +++ b/overlord/hookstate/ctlcmd/set_test.go @@ -28,14 +28,12 @@ import ( "github.com/snapcore/snapd/confdb" "github.com/snapcore/snapd/interfaces" - "github.com/snapcore/snapd/overlord/confdbstate" "github.com/snapcore/snapd/overlord/configstate/config" "github.com/snapcore/snapd/overlord/hookstate" "github.com/snapcore/snapd/overlord/hookstate/ctlcmd" "github.com/snapcore/snapd/overlord/hookstate/hooktest" "github.com/snapcore/snapd/overlord/state" "github.com/snapcore/snapd/snap" - "github.com/snapcore/snapd/testutil" ) type setSuite struct { @@ -412,44 +410,14 @@ func parsePath(c *C, path string) []confdb.Accessor { return accs } -func (s *confdbSuite) TestConfdbSetSingleView(c *C) { - s.state.Lock() - tx, err := confdbstate.NewTransaction(s.state, s.devAccID, "network") - s.state.Unlock() - c.Assert(err, IsNil) - - restore := ctlcmd.MockConfdbstateTransactionForSet(func(*hookstate.Context, *state.State, *confdb.View) (*confdbstate.Transaction, confdbstate.CommitTxFunc, error) { - return tx, nil, nil - }) - defer restore() - - stdout, stderr, err := ctlcmd.Run(s.mockContext, []string{"set", "--view", ":write-wifi", "ssid=other-ssid"}, 0, nil) - c.Assert(err, IsNil) - c.Check(stdout, IsNil) - c.Check(stderr, IsNil) - s.mockContext.Lock() - c.Assert(s.mockContext.Done(), IsNil) - s.mockContext.Unlock() - - val, err := tx.Get(parsePath(c, "wifi.ssid"), nil) - c.Assert(err, IsNil) - c.Assert(val, DeepEquals, "other-ssid") -} - func (s *confdbSuite) TestConfdbSetSingleViewNewTransaction(c *C) { - s.state.Lock() - tx, err := confdbstate.NewTransaction(s.state, s.devAccID, "network") - s.state.Unlock() - c.Assert(err, IsNil) - var called bool - restore := ctlcmd.MockConfdbstateTransactionForSet(func(*hookstate.Context, *state.State, *confdb.View) (*confdbstate.Transaction, confdbstate.CommitTxFunc, error) { - return tx, func() (string, <-chan struct{}, error) { - called = true - waitChan := make(chan struct{}) - close(waitChan) - return "123", waitChan, nil - }, nil + restore := ctlcmd.MockConfdbstateWriteConfdb(func(_ *hookstate.Context, _ *confdb.View, values map[string]any) error { + called = true + c.Assert(values, DeepEquals, map[string]any{ + "ssid": "other-ssid", + }) + return nil }) defer restore() @@ -457,22 +425,16 @@ func (s *confdbSuite) TestConfdbSetSingleViewNewTransaction(c *C) { c.Assert(err, IsNil) c.Check(stdout, IsNil) c.Check(stderr, IsNil) - c.Assert(called, Equals, true) - - val, err := tx.Get(parsePath(c, "wifi.ssid"), nil) - c.Assert(err, IsNil) - c.Assert(val, DeepEquals, "other-ssid") } func (s *confdbSuite) TestConfdbSetManyViews(c *C) { - s.state.Lock() - tx, err := confdbstate.NewTransaction(s.state, s.devAccID, "network") - s.state.Unlock() - c.Assert(err, IsNil) - - restore := ctlcmd.MockConfdbstateTransactionForSet(func(*hookstate.Context, *state.State, *confdb.View) (*confdbstate.Transaction, confdbstate.CommitTxFunc, error) { - return tx, nil, nil + restore := ctlcmd.MockConfdbstateWriteConfdb(func(_ *hookstate.Context, _ *confdb.View, values map[string]any) error { + c.Assert(values, DeepEquals, map[string]any{ + "ssid": "other-ssid", + "password": "other-secret", + }) + return nil }) defer restore() @@ -480,14 +442,6 @@ func (s *confdbSuite) TestConfdbSetManyViews(c *C) { c.Assert(err, IsNil) c.Check(stdout, IsNil) c.Check(stderr, IsNil) - - val, err := tx.Get(parsePath(c, "wifi.ssid"), nil) - c.Assert(err, IsNil) - c.Assert(val, Equals, "other-ssid") - - val, err = tx.Get(parsePath(c, "wifi.psk"), nil) - c.Assert(err, IsNil) - c.Assert(val, Equals, "other-secret") } func (s *confdbSuite) TestConfdbSetInvalid(c *C) { @@ -516,19 +470,9 @@ func (s *confdbSuite) TestConfdbSetInvalid(c *C) { } func (s *confdbSuite) TestConfdbSetExclamationMark(c *C) { - s.state.Lock() - tx, err := confdbstate.NewTransaction(s.state, s.devAccID, "network") - s.state.Unlock() - c.Assert(err, IsNil) - - err = tx.Set(parsePath(c, "wifi.ssid"), "foo") - c.Assert(err, IsNil) - - err = tx.Set(parsePath(c, "wifi.psk"), "bar") - c.Assert(err, IsNil) - - restore := ctlcmd.MockConfdbstateTransactionForSet(func(*hookstate.Context, *state.State, *confdb.View) (*confdbstate.Transaction, confdbstate.CommitTxFunc, error) { - return tx, nil, nil + restore := ctlcmd.MockConfdbstateWriteConfdb(func(_ *hookstate.Context, _ *confdb.View, values map[string]any) error { + c.Assert(values, DeepEquals, map[string]any{"password": nil}) + return nil }) defer restore() @@ -536,24 +480,15 @@ func (s *confdbSuite) TestConfdbSetExclamationMark(c *C) { c.Assert(err, IsNil) c.Check(stdout, IsNil) c.Check(stderr, IsNil) - - _, err = tx.Get(parsePath(c, "wifi.psk"), nil) - c.Assert(err, testutil.ErrorIs, &confdb.NoDataError{}) - - val, err := tx.Get(parsePath(c, "wifi.ssid"), nil) - c.Assert(err, IsNil) - c.Assert(val, Equals, "foo") } func (s *confdbSuite) TestConfdbModifyHooks(c *C) { s.state.Lock() defer s.state.Unlock() - tx, err := confdbstate.NewTransaction(s.state, s.devAccID, "network") - c.Assert(err, IsNil) - - restore := ctlcmd.MockConfdbstateTransactionForSet(func(*hookstate.Context, *state.State, *confdb.View) (*confdbstate.Transaction, confdbstate.CommitTxFunc, error) { - return tx, nil, nil + restore := ctlcmd.MockConfdbstateWriteConfdb(func(_ *hookstate.Context, _ *confdb.View, values map[string]any) error { + c.Assert(values, DeepEquals, map[string]any{"password": "thing"}) + return nil }) defer restore() diff --git a/overlord/hookstate/ctlcmd/unset_test.go b/overlord/hookstate/ctlcmd/unset_test.go index 884a0007128..53f18bfaf50 100644 --- a/overlord/hookstate/ctlcmd/unset_test.go +++ b/overlord/hookstate/ctlcmd/unset_test.go @@ -25,14 +25,12 @@ import ( . "gopkg.in/check.v1" "github.com/snapcore/snapd/confdb" - "github.com/snapcore/snapd/overlord/confdbstate" "github.com/snapcore/snapd/overlord/configstate/config" "github.com/snapcore/snapd/overlord/hookstate" "github.com/snapcore/snapd/overlord/hookstate/ctlcmd" "github.com/snapcore/snapd/overlord/hookstate/hooktest" "github.com/snapcore/snapd/overlord/state" "github.com/snapcore/snapd/snap" - "github.com/snapcore/snapd/testutil" ) type unsetSuite struct { @@ -165,31 +163,18 @@ func (s *unsetSuite) TestCommandWithoutContext(c *C) { } func (s *confdbSuite) TestConfdbUnsetManyViews(c *C) { - s.state.Lock() - tx, err := confdbstate.NewTransaction(s.state, s.devAccID, "network") - s.state.Unlock() - c.Assert(err, IsNil) - - err = tx.Set(parsePath(c, "wifi.ssid"), "foo") - c.Assert(err, IsNil) - - err = tx.Set(parsePath(c, "wifi.psk"), "bar") - c.Assert(err, IsNil) - - ctlcmd.MockConfdbstateTransactionForSet(func(*hookstate.Context, *state.State, *confdb.View) (*confdbstate.Transaction, confdbstate.CommitTxFunc, error) { - return tx, nil, nil + ctlcmd.MockConfdbstateWriteConfdb(func(_ *hookstate.Context, _ *confdb.View, values map[string]any) error { + c.Assert(values, DeepEquals, map[string]any{ + "ssid": nil, + "password": nil, + }) + return nil }) stdout, stderr, err := ctlcmd.Run(s.mockContext, []string{"unset", "--view", ":write-wifi", "ssid", "password"}, 0, nil) c.Assert(err, IsNil) c.Check(stdout, IsNil) c.Check(stderr, IsNil) - - _, err = tx.Get(parsePath(c, "wifi.ssid"), nil) - c.Assert(err, testutil.ErrorIs, &confdb.NoDataError{}) - - _, err = tx.Get(parsePath(c, "wifi.psk"), nil) - c.Assert(err, testutil.ErrorIs, &confdb.NoDataError{}) } func (s *confdbSuite) TestConfdbUnsetInvalid(c *C) { From 7573c0192d134225d31148fed78b4626c3110bba Mon Sep 17 00:00:00 2001 From: Robert Fudge Date: Tue, 21 Apr 2026 17:59:14 -0230 Subject: [PATCH 10/21] o/h/ctlcmd, t/main: add per-change rate limit, unit tests, spread test (#16777) * o/h/ctlcmd, t/main: add per-change rate limit, unit tests, spread test * o/h/ctlcmd: update to table-driven tests * o/h/ctlcmd: keep rate-limiting in memory, have ready check outside of lock * o/h/ctlcmd: ensure is-ready aligns with spec implentation * fixup! o/h/ctlcmd: ensure is-ready aligns with spec implentation * o/h/ctlcmd, t/m/snapctl-is-ready: fix lock hold while sleeping, refactor test tables to seperate logic and rate-lim tests, refactor spread tests to adhere to command format, add rate-lim spread test * o/h/ctlcmd: fix unit test complexity, simplify table-driven tests, fix race condition in helpers.go * t/main/snapctl-is-ready*: fix issue with --classic on UC * fixup! t/main/snapctl-is-ready*: fix issue with --classic on UC * o/h/ctlcmd: change behavior of last-accessed so that if it doesnt exist in cache, it is created * fixup! o/h/ctlcmd: change behavior of last-accessed so that if it doesnt exist in cache, it is created * t/main/snapctl-is-ready: fix spread test formatting * o/h/ctlcmd, t/main/snapctl-is-ready: fix spread test issues, clean-up isReady helper * t/main/snapctl-is-ready: remove rate limit spread test * o/h/ctlcmd, t/main/snapctl-is-ready: fix comments, error strings, use snap install --wait in spread test * fixup! o/h/ctlcmd, t/main/snapctl-is-ready: fix comments, error strings, use snap install --wait in spread test * o/h/ctlcmd, t/m/snapctl-is-ready: Move last access caching later in isReady, simplify spread test * fixup! o/h/ctlcmd, t/m/snapctl-is-ready: Move last access caching later in isReady, simplify spread test * fixup! fixup! o/h/ctlcmd, t/m/snapctl-is-ready: Move last access caching later in isReady, simplify spread test * fixup! fixup! fixup! o/h/ctlcmd, t/m/snapctl-is-ready: Move last access caching later in isReady, simplify spread test * fixup! fixup! fixup! fixup! o/h/ctlcmd, t/m/snapctl-is-ready: Move last access caching later in isReady, simplify spread test * o/h/ctlcmd: add toWait to the cached time to start waiting from when the last wait is finished * o/h/ctlcmd: add private type for key, ensure last access accounts for wait time for current request * o/ctlcmd: refactor rate limiting into some named functions * ov/ho/ctlcmd: add to non-root allowlist, write change status to stderr on ret code 2, fix spread test description * fixup! ov/ho/ctlcmd: add to non-root allowlist, write change status to stderr on ret code 2, fix spread test description * ov/ho/ctlcmd, tests: improve code documentation, revert spread test to pack and install instead of INSTALL_LOCAL tool * ov/ho/ctlcmd: fix doc comments * fixup! ov/ho/ctlcmd: fix doc comments * fixup! fixup! ov/ho/ctlcmd: fix doc comments * fixup! fixup! fixup! ov/ho/ctlcmd: fix doc comments * fixup! fixup! fixup! fixup! ov/ho/ctlcmd: fix doc comments --------- Co-authored-by: Andrew Phelps --- overlord/hookstate/ctlcmd/ctlcmd.go | 2 +- overlord/hookstate/ctlcmd/export_test.go | 10 + overlord/hookstate/ctlcmd/helpers.go | 139 ++++++++++ overlord/hookstate/ctlcmd/is_ready.go | 89 +++++++ overlord/hookstate/ctlcmd/is_ready_test.go | 249 ++++++++++++++++++ tests/main/snapctl-is-ready/task.yaml | 48 ++++ .../test-comp/meta/component.yaml | 5 + tests/main/snapctl-is-ready/test-snap/bin/app | 2 + .../snapctl-is-ready/test-snap/meta/snap.yaml | 12 + 9 files changed, 555 insertions(+), 1 deletion(-) create mode 100644 overlord/hookstate/ctlcmd/is_ready.go create mode 100644 overlord/hookstate/ctlcmd/is_ready_test.go create mode 100644 tests/main/snapctl-is-ready/task.yaml create mode 100644 tests/main/snapctl-is-ready/test-comp/meta/component.yaml create mode 100755 tests/main/snapctl-is-ready/test-snap/bin/app create mode 100644 tests/main/snapctl-is-ready/test-snap/meta/snap.yaml diff --git a/overlord/hookstate/ctlcmd/ctlcmd.go b/overlord/hookstate/ctlcmd/ctlcmd.go index ce7354f1205..2b7774490d2 100644 --- a/overlord/hookstate/ctlcmd/ctlcmd.go +++ b/overlord/hookstate/ctlcmd/ctlcmd.go @@ -148,7 +148,7 @@ func (f ForbiddenCommandError) Error() string { // nonRootAllowed lists the commands that can be performed even when snapctl // is invoked not by root. -var nonRootAllowed = []string{"get", "services", "set-health", "is-connected", "system-mode", "refresh", "model", "version"} +var nonRootAllowed = []string{"get", "services", "set-health", "is-connected", "system-mode", "refresh", "model", "version", "is-ready"} // Run runs the requested command. func Run(context *hookstate.Context, args []string, uid uint32, features []string) (stdout, stderr []byte, err error) { diff --git a/overlord/hookstate/ctlcmd/export_test.go b/overlord/hookstate/ctlcmd/export_test.go index 78a8b63fd79..bba0e79095f 100644 --- a/overlord/hookstate/ctlcmd/export_test.go +++ b/overlord/hookstate/ctlcmd/export_test.go @@ -22,6 +22,7 @@ package ctlcmd import ( "context" "errors" + "time" "github.com/snapcore/snapd/asserts" "github.com/snapcore/snapd/asserts/snapasserts" @@ -51,6 +52,8 @@ var ( ) type KmodCommand = kmodCommand +type IsReadyCommand = isReadyCommand +type ChangeRateLimitKey = changeRateLimitKey func MockKmodCheckConnection(f func(*hookstate.Context, string, []string) error) (restore func()) { r := testutil.Backup(&kmodCheckConnection) @@ -210,3 +213,10 @@ func MockConfdbstateReadConfdb(f func(*hookstate.Context, *confdb.View, []string confdbstateReadConfdb = old } } + +// TODO:GOVERSION: use time bubbles once project is updated to Go 1.26 +func MockTimeAfter(f func(time.Duration) <-chan time.Time) (restore func()) { + old := timeAfter + timeAfter = f + return func() { timeAfter = old } +} diff --git a/overlord/hookstate/ctlcmd/helpers.go b/overlord/hookstate/ctlcmd/helpers.go index fe2c39572a9..f5178178e8d 100644 --- a/overlord/hookstate/ctlcmd/helpers.go +++ b/overlord/hookstate/ctlcmd/helpers.go @@ -48,6 +48,8 @@ var ( snapstateRemoveComponents = snapstate.RemoveComponents ) +var timeAfter = time.After + var ( serviceControlChangeKind = swfeats.RegisterChangeKind("service-control") snapctlInstallChangeKind = swfeats.RegisterChangeKind("snapctl-install") @@ -61,6 +63,8 @@ func init() { } } +const snapctlDebounceWindow = 200 * time.Millisecond + // finalSeedTask is the last task that should run during seeding. This is used // in the special handling of the "seed" change, which requires that we // introspect the change for this specific task. Finding this task allows us to @@ -451,6 +455,7 @@ func runSnapManagementCommand(hctx *hookstate.Context, cmd managementCommand) er chg := st.NewChange(changeKind, fmt.Sprintf("%s components %v for snap %s", cmdVerb, cmd.components, hctx.InstanceName())) + chg.Set("initiated-by-snap", hctx.InstanceName()) for _, ts := range tss { chg.AddAll(ts) } @@ -491,6 +496,140 @@ func jsonRaw(v any) *json.RawMessage { return &raw } +type changeRateLimitKey struct { + ChangeID string +} + +// isReady checks if the change is ready, if it is, it returns the status, otherwise state.DoingStatus. +func isReady(hctx *hookstate.Context, changeID string) (state.Status, error) { + callerSnapName := hctx.InstanceName() + + st := hctx.State() + st.Lock() + defer st.Unlock() + + chg := st.Change(changeID) + + if chg == nil { + return state.DefaultStatus, fmt.Errorf("change %q not found", changeID) + } + + var initiatorSnapName string + err := chg.Get("initiated-by-snap", &initiatorSnapName) + if err != nil { + return state.DefaultStatus, fmt.Errorf("change %q not found", changeID) + } + + if initiatorSnapName != callerSnapName { + return state.DefaultStatus, fmt.Errorf("change %q not found", changeID) + } + + wait, err := rateLimit(st, changeID, snapctlDebounceWindow) + if err != nil { + return state.DefaultStatus, err + } + + return unlockAndWaitForStatus(st, chg, wait), nil +} + +// unlockAndWaitForStatus unlocks the state and waits for the change to be ready. +// The lock must be held prior to calling, and will be re-acquired before returning. +// Returns doingStatus if the change is still in progress, otherwise returns the final +// status of the change. +func unlockAndWaitForStatus(st *state.State, chg *state.Change, wait time.Duration) state.Status { + st.Unlock() + // note: we cannot defer the re-lock, since we must re-lock prior to + // calculating the return value in some branches. + + ready := chg.Ready() + + // The check ensures that both select cases aren't true immediately. + if wait <= 0 { + select { + // use default so the channel is prioritized. + case <-ready: + st.Lock() + return chg.Status() + default: + st.Lock() + return state.DoingStatus + } + } + + // Because the wait could've been > 0, the last select between a closed ready channel + // and a timer.After channel would've be racy. + select { + case <-ready: + case <-timeAfter(wait): + st.Lock() + return state.DoingStatus + } + + st.Lock() + return chg.Status() +} + +// rateLimit returns the amount of time that should be waited before accessing +// this change via snapctl. Internally, data associated with the change is +// cached so that all access to the change shares the same rate limit. +// The lock must be acquired before calling, as it modifies the state object. +func rateLimit(st *state.State, changeID string, rate time.Duration) (wait time.Duration, err error) { + now := time.Now() + + accessed, err := changeAccessedAt(st, changeID) + if err != nil { + return 0, err + } + + // first time through, we just set the change access to now. next request + // must wait at least "rate" duration before access. + if accessed.IsZero() { + setChangeAccessedAt(st, now, changeID) + return 0, nil + } + + durationSinceLastAccess := now.Sub(accessed) + + // user waited on their own, no waiting needed. next access will require + // waiting at least "rate" duration. + if durationSinceLastAccess >= rate { + setChangeAccessedAt(st, now, changeID) + return 0, nil + } + + // user needs to wait a bit still. note that durationSinceLastAccess might + // be negative, since "accessed" could be in the future. this can happen + // when there are multiple requests in parallel, within a duration less than + // "rate". + wait = rate - durationSinceLastAccess + + // current request must wait. next request must wait this amount of time, + // plus at least "rate" duration. + setChangeAccessedAt(st, now.Add(wait), changeID) + + return wait, nil +} + +func changeAccessedAt(st *state.State, changeID string) (time.Time, error) { + key := changeRateLimitKey{ChangeID: changeID} + accessedAt := st.Cached(key) + if accessedAt == nil { + return time.Time{}, nil + } + + accessedNano, ok := accessedAt.(int64) + if !ok { + return time.Time{}, fmt.Errorf("error: invalid type (%T) for access time", accessedAt) + } + + return time.Unix(0, accessedNano), nil +} + +func setChangeAccessedAt(st *state.State, accessed time.Time, changeID string) { + key := changeRateLimitKey{ChangeID: changeID} + st.Cache(key, accessed.UnixNano()) +} + // getAttribute unmarshals into result the value of the provided key from attributes map. // If the key does not exist, an error of type *NoAttributeError is returned. // The provided key may be formed as a dotted key path through nested maps. diff --git a/overlord/hookstate/ctlcmd/is_ready.go b/overlord/hookstate/ctlcmd/is_ready.go new file mode 100644 index 00000000000..9289a1cc5d9 --- /dev/null +++ b/overlord/hookstate/ctlcmd/is_ready.go @@ -0,0 +1,89 @@ +// -*- Mode: Go; indent-tabs-mode: t -*- + +/* + * Copyright (C) 2026 Canonical Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 3 as + * published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +package ctlcmd + +import ( + "fmt" + + "github.com/snapcore/snapd/i18n" + "github.com/snapcore/snapd/overlord/state" +) + +type isReadyCommand struct { + baseCommand +} + +const ( + changeReadyExitCode = iota + changeNotReadyExitCode + changeUnsuccessfulExitCode + otherErrorExitCode +) + +var shortIsReadyHelp = i18n.G(`Return the status of the associated change id.`) +var longIsReadyHelp = i18n.G(` +The is-ready command is used to query the status of change ids that are returned +by asynchronous snapctl commands. + +$ snapctl is-ready + 0: change completed successfully (Done) + 1: change is not ready + 2: change is ready but did not complete successfully (Undone, Error, Hold) + 3: other errors (invalid change id, permissions error) +stdout: empty, exit code conveys change readiness +stderr: empty for exit codes 0 and 1. Contains relevant errors for exit codes 2 and 3. +`) + +func init() { + addCommand("is-ready", shortIsReadyHelp, longIsReadyHelp, func() command { + return &isReadyCommand{} + }) +} + +func (c *isReadyCommand) Execute(args []string) error { + ctx, err := c.ensureContext() + if err != nil { + return err + } + + if len(args) != 1 { + return fmt.Errorf("invalid number of arguments: expected 1, got %d", len(args)) + } + + changeID := args[0] + + ready, err := isReady(ctx, changeID) + + if err != nil { + fmt.Fprint(c.stderr, err.Error()) + return &UnsuccessfulError{ExitCode: otherErrorExitCode} + } + + if !ready.Ready() { + return &UnsuccessfulError{ExitCode: changeNotReadyExitCode} + } + + if ready != state.DoneStatus { + fmt.Fprintf(c.stderr, "change finished with status %s", ready) + return &UnsuccessfulError{ExitCode: changeUnsuccessfulExitCode} + } + + return nil +} diff --git a/overlord/hookstate/ctlcmd/is_ready_test.go b/overlord/hookstate/ctlcmd/is_ready_test.go new file mode 100644 index 00000000000..489548bce73 --- /dev/null +++ b/overlord/hookstate/ctlcmd/is_ready_test.go @@ -0,0 +1,249 @@ +// -*- Mode: Go; indent-tabs-mode: t -*- + +/* + * Copyright (C) 2026 Canonical Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 3 as + * published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +package ctlcmd_test + +import ( + "time" + + . "gopkg.in/check.v1" + + "github.com/snapcore/snapd/dirs" + "github.com/snapcore/snapd/overlord/hookstate" + "github.com/snapcore/snapd/overlord/hookstate/ctlcmd" + "github.com/snapcore/snapd/overlord/hookstate/hooktest" + "github.com/snapcore/snapd/overlord/state" + "github.com/snapcore/snapd/snap" + "github.com/snapcore/snapd/testutil" +) + +type isReadySuite struct { + testutil.BaseTest + mockHandler *hooktest.MockHandler +} + +var _ = Suite(&isReadySuite{}) + +func (s *isReadySuite) SetUpTest(c *C) { + s.BaseTest.SetUpTest(c) + dirs.SetRootDir(c.MkDir()) + s.AddCleanup(func() { dirs.SetRootDir("/") }) + s.mockHandler = hooktest.NewMockHandler() +} + +// setupChangeAndContext creates a state, a change (with an optional initiator), +// and a non-ephemeral hook context for "test-snap". +func (s *isReadySuite) setupChangeAndContext(c *C, taskStatus state.Status, initiatorSnap string) (*state.State, *hookstate.Context, string) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + chg := st.NewChange("snapctl-install", "install via snapctl") + task := st.NewTask("test-task", "test task") + chg.AddTask(task) + + if initiatorSnap != "" { + chg.Set("initiated-by-snap", initiatorSnap) + } + + task.SetStatus(taskStatus) + + setup := &hookstate.HookSetup{Snap: "test-snap", Revision: snap.R(1), Hook: "install"} + ctx, err := hookstate.NewContext(task, st, setup, s.mockHandler, "") + c.Assert(err, IsNil) + + return st, ctx, chg.ID() +} + +func (s *isReadySuite) TestIsReadyNoContext(c *C) { + _, _, err := ctlcmd.Run(nil, []string{"is-ready", "1"}, 0, nil) + c.Assert(err, ErrorMatches, `cannot invoke snapctl operation commands.*from outside of a snap`) +} + +func (s *isReadySuite) TestIsReadyArgCount(c *C) { + _, ctx, _ := s.setupChangeAndContext(c, state.DoneStatus, "test-snap") + _, _, err := ctlcmd.Run(ctx, []string{"is-ready"}, 0, nil) + c.Assert(err, ErrorMatches, `invalid number of arguments: expected 1, got 0`) + + _, _, err = ctlcmd.Run(ctx, []string{"is-ready", "1", "extra-arg"}, 0, nil) + c.Assert(err, ErrorMatches, `invalid number of arguments: expected 1, got 2`) +} + +func (s *isReadySuite) TestIsReadyChangeNotFound(c *C) { + _, ctx, _ := s.setupChangeAndContext(c, state.DoneStatus, "") + _, stderr, err := ctlcmd.Run(ctx, []string{"is-ready", "nonexistent-id"}, 0, nil) + c.Assert(err, DeepEquals, &ctlcmd.UnsuccessfulError{ExitCode: 3}) + c.Check(string(stderr), Matches, `change "nonexistent-id" not found`) +} + +func (s *isReadySuite) TestIsReadyLogic(c *C) { + var logicTests = []struct { + taskStatus state.Status + initiatorSnap string // empty = don't set initiated-by-snap on the change + errValue error // if set, expect err to deep equal this value + expectedOut string + expectedStderr string // if set, checked as regexp match against stderr + }{ + { + taskStatus: state.DoneStatus, + errValue: &ctlcmd.UnsuccessfulError{ExitCode: 3}, + expectedStderr: `change .* not found`, + }, + { + taskStatus: state.DoneStatus, + initiatorSnap: "other-snap", // different from context snap "test-snap" + errValue: &ctlcmd.UnsuccessfulError{ExitCode: 3}, + expectedStderr: `change .* not found`, + }, + { + taskStatus: state.DoneStatus, + initiatorSnap: "test-snap", + }, + { + taskStatus: state.DoingStatus, + initiatorSnap: "test-snap", + errValue: &ctlcmd.UnsuccessfulError{ExitCode: 1}, + }, + { + taskStatus: state.ErrorStatus, + initiatorSnap: "test-snap", + errValue: &ctlcmd.UnsuccessfulError{ExitCode: 2}, + expectedStderr: `change finished with status Error`, + }, + { + taskStatus: state.HoldStatus, + initiatorSnap: "test-snap", + errValue: &ctlcmd.UnsuccessfulError{ExitCode: 2}, + expectedStderr: `change finished with status Hold`, + }, + } + + for _, tt := range logicTests { + _, ctx, changeID := s.setupChangeAndContext(c, tt.taskStatus, tt.initiatorSnap) + stdout, stderr, err := ctlcmd.Run(ctx, []string{"is-ready", changeID}, 0, nil) + if tt.errValue != nil { + c.Assert(err, DeepEquals, tt.errValue) + } else { + c.Assert(err, IsNil) + } + c.Check(string(stdout), Equals, tt.expectedOut) + if tt.expectedStderr != "" { + c.Check(string(stderr), Matches, tt.expectedStderr) + } else { + c.Check(string(stderr), Equals, "") + } + } +} + +// Rate-limiting tests +func (s *isReadySuite) rateLimitSetup(c *C, taskStatus state.Status, lastAccessedTime any) (*hookstate.Context, string) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + chg := st.NewChange("snapctl-install", "install via snapctl") + task := st.NewTask("test-task", "test task") + chg.AddTask(task) + chg.Set("initiated-by-snap", "test-snap") + + if lastAccessedTime != nil { + st.Cache(ctlcmd.ChangeRateLimitKey{ChangeID: chg.ID()}, lastAccessedTime) + } + + task.SetStatus(taskStatus) + + setup := &hookstate.HookSetup{Snap: "test-snap", Revision: snap.R(1), Hook: "install"} + ctx, err := hookstate.NewContext(task, st, setup, s.mockHandler, "") + c.Assert(err, IsNil) + + return ctx, chg.ID() +} + +// TestIsReadyMissingLastAccessed verifies that is-ready treats a missing +// last-accessed cache entry (e.g. after a snapd restart) as a first access and +// proceeds to report the real change status rather than returning an error. +func (s *isReadySuite) TestIsReadyMissingLastAccessed(c *C) { + ctx, changeID := s.rateLimitSetup(c, state.DoneStatus, nil) + + _, _, err := ctlcmd.Run(ctx, []string{"is-ready", changeID}, 0, nil) + + c.Assert(err, IsNil) +} + +// TestIsReadyRateLimitDelaysPolling verifies that when a snap polls within the +// 200 ms debounce window, is-ready sleeps for the remaining window duration +// before checking the change status. +func (s *isReadySuite) TestIsReadyRateLimitDelaysPolling(c *C) { + // A last-accessed time in the future guarantees we are within the debounce + // window, ensuring timeAfter is called with a positive duration. + ctx, changeID := s.rateLimitSetup(c, state.DoneStatus, time.Now().Add(time.Second).UnixNano()) + + var waitedFor time.Duration + restore := ctlcmd.MockTimeAfter(func(d time.Duration) <-chan time.Time { + waitedFor = d + return make(chan time.Time) // never fires; chg.Ready() wins + }) + defer restore() + + _, _, err := ctlcmd.Run(ctx, []string{"is-ready", changeID}, 0, nil) + + c.Assert(err, IsNil) + c.Check(waitedFor > 0, Equals, true) +} + +// TestIsReadyRateLimitTimerFires verifies that when timeAfter fires before the +// change is ready, is-ready reports DoingStatus (exit code 1) and the timer +// channel is drained. +func (s *isReadySuite) TestIsReadyRateLimitTimerFires(c *C) { + // A last-accessed time in the future puts us inside the debounce window. + // The task is left in DoingStatus so chg.Ready() never fires, ensuring + // the timer case is the only one that can win the select. + ctx, changeID := s.rateLimitSetup(c, state.DoingStatus, time.Now().Add(time.Second).UnixNano()) + + timerCh := make(chan time.Time, 1) + timerCh <- time.Now() // pre-fill so the timer fires immediately + restore := ctlcmd.MockTimeAfter(func(d time.Duration) <-chan time.Time { + return timerCh + }) + defer restore() + + _, _, err := ctlcmd.Run(ctx, []string{"is-ready", changeID}, 0, nil) + + c.Assert(err, DeepEquals, &ctlcmd.UnsuccessfulError{ExitCode: 1}) + c.Check(len(timerCh), Equals, 0) // element was consumed by the select +} + +// TestIsReadyExpiredWindowSkipsTimeAfter verifies that when the debounce window +// has already elapsed, is-ready returns the change status directly +func (s *isReadySuite) TestIsReadyExpiredWindowSkipsTimeAfter(c *C) { + // A last-accessed time sufficiently in the past guarantees toWait <= 0. + ctx, changeID := s.rateLimitSetup(c, state.DoneStatus, time.Now().Add(-time.Second).UnixNano()) + + called := false + restore := ctlcmd.MockTimeAfter(func(d time.Duration) <-chan time.Time { + called = true + return make(chan time.Time) + }) + defer restore() + + _, _, err := ctlcmd.Run(ctx, []string{"is-ready", changeID}, 0, nil) + + c.Assert(err, IsNil) + c.Check(called, Equals, false) +} diff --git a/tests/main/snapctl-is-ready/task.yaml b/tests/main/snapctl-is-ready/task.yaml new file mode 100644 index 00000000000..ef8f81f326a --- /dev/null +++ b/tests/main/snapctl-is-ready/task.yaml @@ -0,0 +1,48 @@ +summary: Ensure that snapctl is-ready command works. + +details: | + Verifies that the snapctl is-ready command correctly reports the status of a + change initiated by a snap via snapctl. A test snap with a component is + installed locally. The component is then removed via snapctl (from within the + snap's app context using snap run --shell), which creates a snapctl-remove + change marked with the initiated-by-snap change key for the calling snap. + The implementation also tracks last-access information in the in-memory state + cache rather than as a change attribute. is-ready is then called against that + change ID to verify it reports Done and exits successfully. + + Also verifies that is-ready fails appropriately for invalid change IDs, wrong + argument counts, and changes not initiated by the calling snap. + +systems: [ubuntu-16.04-64, ubuntu-18.04-64, ubuntu-2*, ubuntu-core-*, fedora-*] + +prepare: | + snap pack test-snap/ + snap pack test-comp/ + snap install --dangerous test-snapctl-is-ready_1.0_all.snap + snap install --dangerous test-snapctl-is-ready+comp_1.0.comp + +execute: | + echo "Remove component via snapctl to create a snapctl-remove change" + snap run test-snapctl-is-ready.app snapctl remove +comp + + CHANGE_ID=$(snap debug api /v2/changes?select=all | \ + gojq --raw-output '[.result[] | select(.kind == "snapctl-remove")] | last | .id') + + test -n "$CHANGE_ID" + test "$CHANGE_ID" != "null" + + echo "snapctl is-ready exits 0 for a completed change (stdout is empty; exit code conveys status)" + snap run test-snapctl-is-ready.app snapctl is-ready "$CHANGE_ID" + + echo "snapctl is-ready fails with exit 3 for an invalid change ID" + snap run test-snapctl-is-ready.app snapctl is-ready nonexistent-id || test $? -eq 3 + + echo "snapctl is-ready fails with too few arguments" + not snap run test-snapctl-is-ready.app snapctl is-ready + + echo "snapctl is-ready fails with too many arguments" + not snap run test-snapctl-is-ready.app snapctl is-ready "$CHANGE_ID" extra-arg + + echo "snapctl is-ready fails for a change not initiated by the snap" + INSTALL_CHANGE_ID=$(snap install --no-wait test-snapd-tools) + snap run test-snapctl-is-ready.app snapctl is-ready "$INSTALL_CHANGE_ID" || test $? -eq 3 diff --git a/tests/main/snapctl-is-ready/test-comp/meta/component.yaml b/tests/main/snapctl-is-ready/test-comp/meta/component.yaml new file mode 100644 index 00000000000..64b8573c6ab --- /dev/null +++ b/tests/main/snapctl-is-ready/test-comp/meta/component.yaml @@ -0,0 +1,5 @@ +component: test-snapctl-is-ready+comp +type: standard +version: 1.0 +summary: Test component for snapctl is-ready +description: Test component for snapctl is-ready diff --git a/tests/main/snapctl-is-ready/test-snap/bin/app b/tests/main/snapctl-is-ready/test-snap/bin/app new file mode 100755 index 00000000000..311cb8cb40c --- /dev/null +++ b/tests/main/snapctl-is-ready/test-snap/bin/app @@ -0,0 +1,2 @@ +#!/bin/sh +exec "$@" diff --git a/tests/main/snapctl-is-ready/test-snap/meta/snap.yaml b/tests/main/snapctl-is-ready/test-snap/meta/snap.yaml new file mode 100644 index 00000000000..121ab5f0d5d --- /dev/null +++ b/tests/main/snapctl-is-ready/test-snap/meta/snap.yaml @@ -0,0 +1,12 @@ +name: test-snapctl-is-ready +version: 1.0 +summary: Test snap for snapctl is-ready +apps: + app: + command: bin/app +base: core24 +components: + comp: + summary: test component for snapctl is-ready + description: test component for snapctl is-ready + type: standard From e71b9163301aa68e18aec6926c9bfeee851374fe Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 13 Apr 2026 08:58:12 +0000 Subject: [PATCH 11/21] fix OOB read in C mountinfo parser and add regression tests Agent-Logs-Url: https://github.com/canonical/snapd/sessions/cb4d4ae7-c044-49ec-baad-4c67d1c638be Co-authored-by: zyga <784262+zyga@users.noreply.github.com> --- cmd/libsnap-confine-private/mountinfo-test.c | 43 ++++++++++++++++++++ cmd/libsnap-confine-private/mountinfo.c | 17 ++++---- osutil/mountinfo_linux_test.go | 39 ++++++++++++++++++ 3 files changed, 91 insertions(+), 8 deletions(-) diff --git a/cmd/libsnap-confine-private/mountinfo-test.c b/cmd/libsnap-confine-private/mountinfo-test.c index 2078d4634e0..2ee7a596c2e 100644 --- a/cmd/libsnap-confine-private/mountinfo-test.c +++ b/cmd/libsnap-confine-private/mountinfo-test.c @@ -215,6 +215,47 @@ static void test_parse_mountinfo_entry__broken_octal_escaping(void) { g_assert_null(entry->next); } +static void test_parse_mountinfo_entry__partial_escape_oob(void) { + // Regression tests: partial octal escape sequences (fewer than 3 octal + // digits after the backslash, or just a trailing backslash) must not cause + // out-of-bounds reads. Each partial escape is copied verbatim. + const char *line; + struct sc_mountinfo_entry *entry; + + // Backslash followed by one octal digit at end of string. + line = "2074 27 0:54 / /tmp/dir rw - tmpfs source rw\\0"; + entry = sc_parse_mountinfo_entry(line); + g_assert_nonnull(entry); + g_test_queue_destroy((GDestroyNotify)sc_free_mountinfo_entry, entry); + g_assert_cmpstr(entry->mount_source, ==, "source"); + g_assert_cmpstr(entry->super_opts, ==, "rw\\0"); + + // Backslash followed by two octal digits at end of string. + line = "2074 27 0:54 / /tmp/dir rw - tmpfs source rw\\05"; + entry = sc_parse_mountinfo_entry(line); + g_assert_nonnull(entry); + g_test_queue_destroy((GDestroyNotify)sc_free_mountinfo_entry, entry); + g_assert_cmpstr(entry->mount_source, ==, "source"); + g_assert_cmpstr(entry->super_opts, ==, "rw\\05"); + + // Backslash followed by one octal digit then space (partial escape at + // end of a space-delimited field, not at end of string). + line = "2074 27 0:54 / /tmp/dir rw - tmpfs source\\5 rw"; + entry = sc_parse_mountinfo_entry(line); + g_assert_nonnull(entry); + g_test_queue_destroy((GDestroyNotify)sc_free_mountinfo_entry, entry); + g_assert_cmpstr(entry->mount_source, ==, "source\\5"); + g_assert_cmpstr(entry->super_opts, ==, "rw"); + + // Backslash followed by two octal digits then space. + line = "2074 27 0:54 / /tmp/dir rw - tmpfs source\\57 rw"; + entry = sc_parse_mountinfo_entry(line); + g_assert_nonnull(entry); + g_test_queue_destroy((GDestroyNotify)sc_free_mountinfo_entry, entry); + g_assert_cmpstr(entry->mount_source, ==, "source\\57"); + g_assert_cmpstr(entry->super_opts, ==, "rw"); +} + static void test_parse_mountinfo_entry__unescaped_whitespace(void) { // The kernel does not escape '\r' const char *line = "2074 27 0:54 / /tmp/strange\rdir rw,relatime shared:1039 - tmpfs tmpfs rw"; @@ -271,6 +312,8 @@ static void __attribute__((constructor)) init(void) { g_test_add_func("/mountinfo/parse_mountinfo_entry/octal_escaping", test_parse_mountinfo_entry__octal_escaping); g_test_add_func("/mountinfo/parse_mountinfo_entry/broken_octal_escaping", test_parse_mountinfo_entry__broken_octal_escaping); + g_test_add_func("/mountinfo/parse_mountinfo_entry/partial_escape_oob", + test_parse_mountinfo_entry__partial_escape_oob); g_test_add_func("/mountinfo/parse_mountinfo_entry/unescaped_whitespace", test_parse_mountinfo_entry__unescaped_whitespace); g_test_add_func("/mountinfo/parse_mountinfo_entry/broken_9p_superblock", diff --git a/cmd/libsnap-confine-private/mountinfo.c b/cmd/libsnap-confine-private/mountinfo.c index 3b84a585e81..1104c3c7f20 100644 --- a/cmd/libsnap-confine-private/mountinfo.c +++ b/cmd/libsnap-confine-private/mountinfo.c @@ -134,8 +134,9 @@ static char *parse_next_string_field_ex(sc_mountinfo_entry *entry, const char *l bool allow_spaces_in_field) { const char *input = &line[*offset]; char *output = &entry->line_buf[*offset]; - size_t input_idx = 0; // reading index - size_t output_idx = 0; // writing index + size_t input_idx = 0; // reading index + size_t output_idx = 0; // writing index + size_t input_len = strlen(input); // length of remaining input (used for bounds checks below) // Scan characters until we run out of memory to scan or we find a // space. The kernel uses simple octal escape sequences for the @@ -170,13 +171,13 @@ static char *parse_next_string_field_ex(sc_mountinfo_entry *entry, const char *l } else if (c == '\\') { // Three *more* octal digits required for the escape // sequence. For reference see mangle_path() in - // fs/seq_file.c. Note that is_octal_digit returns - // false on the string terminator character NUL and the - // short-circuiting behavior of && makes this check - // correct even if '\\' is the last character of the - // string. + // fs/seq_file.c. We explicitly verify that at least 3 + // more bytes remain before the end of the string to + // prevent out-of-bounds reads when the input contains + // fewer than 3 bytes after the backslash. const char *s = &input[input_idx]; - if (is_octal_digit(s[1]) && is_octal_digit(s[2]) && is_octal_digit(s[3])) { + if (input_idx + 4 <= input_len && + is_octal_digit(s[1]) && is_octal_digit(s[2]) && is_octal_digit(s[3])) { // Unescape the octal value encoded in s[1], // s[2] and s[3]. Because we are working with // byte values there are no issues related to diff --git a/osutil/mountinfo_linux_test.go b/osutil/mountinfo_linux_test.go index 4746f721b85..4e454a88b4a 100644 --- a/osutil/mountinfo_linux_test.go +++ b/osutil/mountinfo_linux_test.go @@ -157,6 +157,45 @@ func (s *mountinfoSuite) TestParseMountInfoEntry5(c *C) { c.Assert(entry.MountDir, Equals, "/tmp/strange\rdir") } +// TestParseMountInfoEntryBrokenOctalEscaping checks that partial octal escape +// sequences (fewer than 3 octal digits after a backslash, including a trailing +// backslash) do not cause a panic and are preserved verbatim, consistent with +// the behaviour of the C mountinfo parser. +func (s *mountinfoSuite) TestParseMountInfoEntryBrokenOctalEscaping(c *C) { + // Non-octal chars after backslash and trailing backslash in last field. + entry, err := osutil.ParseMountInfoEntry( + `2074 27 0:54 / /tmp/strange-dir rw,relatime shared:1039 - tmpfs no\888thing rw\`) + c.Assert(err, IsNil) + c.Assert(entry.MountSource, Equals, `no\888thing`) + c.Assert(entry.SuperOptions, DeepEquals, map[string]string{`rw\`: ""}) + + // Backslash followed by one octal digit at end of string. + entry, err = osutil.ParseMountInfoEntry( + `2074 27 0:54 / /tmp/dir rw - tmpfs source rw\0`) + c.Assert(err, IsNil) + c.Assert(entry.SuperOptions, DeepEquals, map[string]string{`rw\0`: ""}) + + // Backslash followed by two octal digits at end of string. + entry, err = osutil.ParseMountInfoEntry( + `2074 27 0:54 / /tmp/dir rw - tmpfs source rw\05`) + c.Assert(err, IsNil) + c.Assert(entry.SuperOptions, DeepEquals, map[string]string{`rw\05`: ""}) + + // Backslash followed by one octal digit in mount source (field ended by space). + entry, err = osutil.ParseMountInfoEntry( + `2074 27 0:54 / /tmp/dir rw - tmpfs source\5 rw`) + c.Assert(err, IsNil) + c.Assert(entry.MountSource, Equals, `source\5`) + c.Assert(entry.SuperOptions, DeepEquals, map[string]string{"rw": ""}) + + // Backslash followed by two octal digits in mount source. + entry, err = osutil.ParseMountInfoEntry( + `2074 27 0:54 / /tmp/dir rw - tmpfs source\57 rw`) + c.Assert(err, IsNil) + c.Assert(entry.MountSource, Equals, `source\57`) + c.Assert(entry.SuperOptions, DeepEquals, map[string]string{"rw": ""}) +} + // Test that empty mountinfo is parsed without errors. func (s *mountinfoSuite) TestReadMountInfo1(c *C) { entries, err := osutil.ReadMountInfo(strings.NewReader("")) From 57533c061535363503a121cd35c8d69ccc83c077 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 13 Apr 2026 10:17:41 +0000 Subject: [PATCH 12/21] cmd/libsnap-confine-private: add lone trailing backslash case to partial_escape_oob test Agent-Logs-Url: https://github.com/canonical/snapd/sessions/c41263a1-0b0a-4730-a291-91a8ecf80263 Co-authored-by: zyga <784262+zyga@users.noreply.github.com> --- cmd/libsnap-confine-private/mountinfo-test.c | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/cmd/libsnap-confine-private/mountinfo-test.c b/cmd/libsnap-confine-private/mountinfo-test.c index 2ee7a596c2e..d54b699c282 100644 --- a/cmd/libsnap-confine-private/mountinfo-test.c +++ b/cmd/libsnap-confine-private/mountinfo-test.c @@ -222,6 +222,14 @@ static void test_parse_mountinfo_entry__partial_escape_oob(void) { const char *line; struct sc_mountinfo_entry *entry; + // Lone trailing backslash at end of string. + line = "2074 27 0:54 / /tmp/dir rw - tmpfs source rw\\"; + entry = sc_parse_mountinfo_entry(line); + g_assert_nonnull(entry); + g_test_queue_destroy((GDestroyNotify)sc_free_mountinfo_entry, entry); + g_assert_cmpstr(entry->mount_source, ==, "source"); + g_assert_cmpstr(entry->super_opts, ==, "rw\\"); + // Backslash followed by one octal digit at end of string. line = "2074 27 0:54 / /tmp/dir rw - tmpfs source rw\\0"; entry = sc_parse_mountinfo_entry(line); From 4a5044e24b38d812770073215f896cd23df294db Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 13 Apr 2026 10:42:44 +0000 Subject: [PATCH 13/21] cmd/libsnap-confine-private: clang-format mountinfo.c and mountinfo-test.c Agent-Logs-Url: https://github.com/canonical/snapd/sessions/b588b58a-c79b-4639-9612-5a71bb0c7285 Co-authored-by: zyga <784262+zyga@users.noreply.github.com> --- cmd/libsnap-confine-private/mountinfo.c | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/cmd/libsnap-confine-private/mountinfo.c b/cmd/libsnap-confine-private/mountinfo.c index 1104c3c7f20..7298772ef73 100644 --- a/cmd/libsnap-confine-private/mountinfo.c +++ b/cmd/libsnap-confine-private/mountinfo.c @@ -134,8 +134,8 @@ static char *parse_next_string_field_ex(sc_mountinfo_entry *entry, const char *l bool allow_spaces_in_field) { const char *input = &line[*offset]; char *output = &entry->line_buf[*offset]; - size_t input_idx = 0; // reading index - size_t output_idx = 0; // writing index + size_t input_idx = 0; // reading index + size_t output_idx = 0; // writing index size_t input_len = strlen(input); // length of remaining input (used for bounds checks below) // Scan characters until we run out of memory to scan or we find a @@ -176,8 +176,7 @@ static char *parse_next_string_field_ex(sc_mountinfo_entry *entry, const char *l // prevent out-of-bounds reads when the input contains // fewer than 3 bytes after the backslash. const char *s = &input[input_idx]; - if (input_idx + 4 <= input_len && - is_octal_digit(s[1]) && is_octal_digit(s[2]) && is_octal_digit(s[3])) { + if (input_idx + 4 <= input_len && is_octal_digit(s[1]) && is_octal_digit(s[2]) && is_octal_digit(s[3])) { // Unescape the octal value encoded in s[1], // s[2] and s[3]. Because we are working with // byte values there are no issues related to From 9401825fd645f63c883a1a141e2785c9a2e96274 Mon Sep 17 00:00:00 2001 From: Zygmunt Krynicki Date: Tue, 14 Apr 2026 13:10:32 +0200 Subject: [PATCH 14/21] cmd: tweak spaces Co-authored-by: Maciej Borzecki --- cmd/libsnap-confine-private/mountinfo.c | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cmd/libsnap-confine-private/mountinfo.c b/cmd/libsnap-confine-private/mountinfo.c index 7298772ef73..24d1f991584 100644 --- a/cmd/libsnap-confine-private/mountinfo.c +++ b/cmd/libsnap-confine-private/mountinfo.c @@ -170,8 +170,8 @@ static char *parse_next_string_field_ex(sc_mountinfo_entry *entry, const char *l break; } else if (c == '\\') { // Three *more* octal digits required for the escape - // sequence. For reference see mangle_path() in - // fs/seq_file.c. We explicitly verify that at least 3 + // sequence. For reference see mangle_path() in + // fs/seq_file.c. We explicitly verify that at least 3 // more bytes remain before the end of the string to // prevent out-of-bounds reads when the input contains // fewer than 3 bytes after the backslash. From 1c95eb2dc5bf6d0be91acce74a07c03d7f46b6eb Mon Sep 17 00:00:00 2001 From: Katie May Date: Wed, 22 Apr 2026 09:15:24 +0200 Subject: [PATCH 15/21] tests: set path in enviornment for service in enable-disable-units-gpio --- tests/core/enable-disable-units-gpio/task.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/core/enable-disable-units-gpio/task.yaml b/tests/core/enable-disable-units-gpio/task.yaml index 22e4a0ef6c9..a474b4e3497 100644 --- a/tests/core/enable-disable-units-gpio/task.yaml +++ b/tests/core/enable-disable-units-gpio/task.yaml @@ -21,7 +21,7 @@ skip: prepare: | echo "Create/enable fake gpio" - tests.systemd create-and-start-unit fake-gpio "$TESTSLIB/fakegpio/fake-gpio.py" "[Unit]\\nBefore=snap.snapd.interface.gpio-100.service\\n[Service]\\nType=notify" + tests.systemd create-and-start-unit fake-gpio "$TESTSLIB/fakegpio/fake-gpio.py" "[Unit]\\nBefore=snap.snapd.interface.gpio-100.service\\n[Service]\\nType=notify\\nEnvironment=PATH=$PATH" echo "Given a snap declaring a plug on gpio is installed" "$TESTSTOOLS"/snaps-state install-local gpio-consumer From d8ede655ac4abc028a41e548a76074cf7ee6c646 Mon Sep 17 00:00:00 2001 From: Maciej Borzecki Date: Wed, 22 Apr 2026 10:57:23 +0200 Subject: [PATCH 16/21] tests/main/disk-space-awareness: ensure consistent state before and after the test, bump size (#16954) Ensure that the state before and after the state is consistent. Specifically, mount units created during suite prepare are carried over to the test, thus mounting a tmpfs on top of /var/lib/snapd creates a discrepancy between e.g. unit files under /etc/systemd/system and actual snapd state. Adding purge ensures that the system state and snapd state match again. Signed-off-by: Maciej Borzecki --- tests/main/disk-space-awareness/task.yaml | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/main/disk-space-awareness/task.yaml b/tests/main/disk-space-awareness/task.yaml index a11cd739af8..304a7c63f1e 100644 --- a/tests/main/disk-space-awareness/task.yaml +++ b/tests/main/disk-space-awareness/task.yaml @@ -11,18 +11,26 @@ environment: TMPFSMOUNT: /var/lib/snapd # filling tmpfs mounted under /var/lib/snapd triggers OOM SNAPD_NO_MEMORY_LIMIT: 1 - SUFFICIENT_SIZE: 200M + SUFFICIENT_SIZE: 300M prepare: | systemctl stop snapd.{socket,service} + SNAP_MOUNT_DIR="$(os.paths snap-mount-dir)" + # purge removes the snap mount directory, which needs to be restored + snapd.tool exec snap-mgmt --purge + mkdir -p "$SNAP_MOUNT_DIR" # mount /var/lib/snapd on a tmpfs mount -t tmpfs tmpfs -o size="$SUFFICIENT_SIZE",mode=0755 "$TMPFSMOUNT" systemctl start snapd.{socket,service} + snap wait system seed.loaded restore: | systemctl stop snapd.{socket,service} + SNAP_MOUNT_DIR="$(os.paths snap-mount-dir)" + snapd.tool exec snap-mgmt --purge + mkdir -p "$SNAP_MOUNT_DIR" umount -l "$TMPFSMOUNT" systemctl start snapd.{socket,service} From 74ab8c6abf8f105f7c4281872d25d334bcb64405 Mon Sep 17 00:00:00 2001 From: alfonsosanchezbeato Date: Wed, 22 Apr 2026 05:36:58 -0400 Subject: [PATCH 17/21] overlord, snap: omit snapd refresh suggestion for ISA-related assume errors (#16945) * overlord, snap: omit snapd refresh suggestion for ISA-related assume errors Remove "(try to refresh snapd)" suggestion from error messages when snap assumes fail due to ISA-related issues. ISA errors indicate architectural incompatibilities that cannot be resolved by refreshing snapd, so the suggestion is misleading in these cases. Introduce IsaError type to distinguish ISA validation failures from other assume validation errors, allowing appropriate error message formatting. * tests/regression/lp-1813365: adapt to python locaction in UC26 --- overlord/snapstate/check_snap.go | 7 ++++++- overlord/snapstate/check_snap_test.go | 6 ++++++ snap/naming/validate.go | 11 ++++++++++- tests/regression/lp-1813365/task.yaml | 2 +- 4 files changed, 23 insertions(+), 3 deletions(-) diff --git a/overlord/snapstate/check_snap.go b/overlord/snapstate/check_snap.go index 1833d2b974e..48e2db07bd4 100644 --- a/overlord/snapstate/check_snap.go +++ b/overlord/snapstate/check_snap.go @@ -144,7 +144,12 @@ func validateInfoAndFlags(info *snap.Info, snapst *SnapState, flags Flags) error // check assumes err := naming.ValidateAssumes(info.Assumes, snapdtool.Version, featureSet, arch.DpkgArchitecture()) if err != nil { - return fmt.Errorf("snap %q assumes %w (try to refresh snapd)", info.InstanceName(), err) + askToRefreshSnapd := " (try to refresh snapd)" + isaErr := &naming.IsaError{} + if errors.As(err, &isaErr) { + askToRefreshSnapd = "" + } + return fmt.Errorf("snap %q assumes %w%s", info.InstanceName(), err, askToRefreshSnapd) } // check and create system-usernames diff --git a/overlord/snapstate/check_snap_test.go b/overlord/snapstate/check_snap_test.go index b0406619509..c445322f8d6 100644 --- a/overlord/snapstate/check_snap_test.go +++ b/overlord/snapstate/check_snap_test.go @@ -27,6 +27,7 @@ import ( . "gopkg.in/check.v1" "github.com/snapcore/snapd/arch" + "github.com/snapcore/snapd/arch/archtest" "github.com/snapcore/snapd/asserts" "github.com/snapcore/snapd/dirs" "github.com/snapcore/snapd/osutil" @@ -95,6 +96,8 @@ architectures: } func (s *checkSnapSuite) TestCheckSnapAssumes(c *C) { + s.AddCleanup(archtest.MockArchitecture("arm64")) + var assumesTests = []struct { version string assumes string @@ -109,6 +112,9 @@ func (s *checkSnapSuite) TestCheckSnapAssumes(c *C) { assumes: "[f1, f2]", classic: true, error: `snap "foo" assumes unsupported features: f1, f2 \(try to refresh snapd\)`, + }, { + assumes: "[isa-arm64-someisa]", + error: `snap "foo" assumes isa-arm64-someisa: ISA specification is not supported for arch: arm64`, }, } diff --git a/snap/naming/validate.go b/snap/naming/validate.go index 7ce54e27f99..ab3ebed0842 100644 --- a/snap/naming/validate.go +++ b/snap/naming/validate.go @@ -315,6 +315,15 @@ func validateAssumedSnapdVersion(assumedVersion, currentVersion string) (bool, e var archIsISASupportedByCPU = arch.IsISASupportedByCPU +type IsaError struct { + Flag string + Err error +} + +func (e *IsaError) Error() string { + return fmt.Sprintf("%s: %s", e.Flag, e.Err) +} + // validateAssumedISAArch checks that, when a snap requires an ISA to be supported: // 1. compares the specified with the device's one. If they differ, it exits // without error signaling that the flag is valid @@ -338,7 +347,7 @@ func validateAssumedISAArch(flag string, currentArchitecture string) error { } if err := archIsISASupportedByCPU(tokens[2]); err != nil { - return fmt.Errorf("%s: %s", flag, err) + return &IsaError{Flag: flag, Err: err} } return nil diff --git a/tests/regression/lp-1813365/task.yaml b/tests/regression/lp-1813365/task.yaml index 9c814f7ce0c..c94e78b123f 100644 --- a/tests/regression/lp-1813365/task.yaml +++ b/tests/regression/lp-1813365/task.yaml @@ -30,5 +30,5 @@ restore: | rm -f /tmp/logger.log execute: | - su -l -c "$(pwd)/helper" test + su -l -c "PATH=\$PATH:/usr/lib/python $(pwd)/helper" test not test -e /tmp/logger.log From 6a01a45793d585c4b754b9d712ebd96346edd441 Mon Sep 17 00:00:00 2001 From: Valentin David Date: Wed, 22 Apr 2026 12:27:39 +0200 Subject: [PATCH 18/21] tests: enable more nested tests on 26 (#16939) --- tests/nested/classic/azure-cvm/task.yaml | 2 -- tests/nested/core/interfaces-custom-devices/task.yaml | 2 +- tests/nested/manual/component-recovery-system-offline/task.yaml | 2 +- tests/nested/manual/component-recovery-system/task.yaml | 2 +- .../manual/core20-fault-inject-on-install-component/task.yaml | 2 -- tests/nested/manual/seeding-failure/task.yaml | 2 -- .../manual/snapd-removes-vulnerable-snap-confine-revs/task.yaml | 2 +- 7 files changed, 4 insertions(+), 10 deletions(-) diff --git a/tests/nested/classic/azure-cvm/task.yaml b/tests/nested/classic/azure-cvm/task.yaml index 69c247262bb..405059ccd29 100644 --- a/tests/nested/classic/azure-cvm/task.yaml +++ b/tests/nested/classic/azure-cvm/task.yaml @@ -9,8 +9,6 @@ systems: - -ubuntu-16.04-* - -ubuntu-18.04-* - -ubuntu-20.04-* - # FIXME - - -ubuntu-26.04-* environment: SNAPD_DEB_FROM_REPO: false diff --git a/tests/nested/core/interfaces-custom-devices/task.yaml b/tests/nested/core/interfaces-custom-devices/task.yaml index 9ec3a22507a..ed0a918a5f1 100644 --- a/tests/nested/core/interfaces-custom-devices/task.yaml +++ b/tests/nested/core/interfaces-custom-devices/task.yaml @@ -5,10 +5,10 @@ details: | granting access to the devices it defines. systems: - # FIXME: make it work on 26 - ubuntu-20* - ubuntu-22* - ubuntu-24* + - ubuntu-26* prepare: | # Add our interface to the gadget snap diff --git a/tests/nested/manual/component-recovery-system-offline/task.yaml b/tests/nested/manual/component-recovery-system-offline/task.yaml index 03dc50501b4..0a0daeb635e 100644 --- a/tests/nested/manual/component-recovery-system-offline/task.yaml +++ b/tests/nested/manual/component-recovery-system-offline/task.yaml @@ -9,8 +9,8 @@ details: | HTTP form. systems: - # FIXME: make it work on 26 - ubuntu-24* + - ubuntu-26* environment: MODEL_JSON: $TESTSLIB/assertions/test-snapd-component-recovery-system-pc-VERSION.json diff --git a/tests/nested/manual/component-recovery-system/task.yaml b/tests/nested/manual/component-recovery-system/task.yaml index 4f4dd0ebcd9..762673728e5 100644 --- a/tests/nested/manual/component-recovery-system/task.yaml +++ b/tests/nested/manual/component-recovery-system/task.yaml @@ -5,8 +5,8 @@ details: | validates that the newly created system can be rebooted into. systems: - # FIXME: make it work on 26 - ubuntu-24* + - ubuntu-26* environment: MODEL_JSON: $TESTSLIB/assertions/test-snapd-component-recovery-system-pc-VERSION.json diff --git a/tests/nested/manual/core20-fault-inject-on-install-component/task.yaml b/tests/nested/manual/core20-fault-inject-on-install-component/task.yaml index 60c2ef205cd..31a0306ffd8 100644 --- a/tests/nested/manual/core20-fault-inject-on-install-component/task.yaml +++ b/tests/nested/manual/core20-fault-inject-on-install-component/task.yaml @@ -8,8 +8,6 @@ systems: - -ubuntu-1* - -ubuntu-20* - -ubuntu-22* - # FIXME - - -ubuntu-26* environment: TAG/kernel_panic_prepare_kernel_components: prepare-kernel-components diff --git a/tests/nested/manual/seeding-failure/task.yaml b/tests/nested/manual/seeding-failure/task.yaml index e9a78e63313..f3f8d28625c 100644 --- a/tests/nested/manual/seeding-failure/task.yaml +++ b/tests/nested/manual/seeding-failure/task.yaml @@ -14,8 +14,6 @@ systems: - -ubuntu-1* - -ubuntu-20* - -ubuntu-22* - # FIXME - - -ubuntu-26* environment: MODEL_JSON: $TESTSLIB/assertions/test-snapd-failed-seeding-pc-VERSION.json diff --git a/tests/nested/manual/snapd-removes-vulnerable-snap-confine-revs/task.yaml b/tests/nested/manual/snapd-removes-vulnerable-snap-confine-revs/task.yaml index 3829166a64d..04031878b3c 100644 --- a/tests/nested/manual/snapd-removes-vulnerable-snap-confine-revs/task.yaml +++ b/tests/nested/manual/snapd-removes-vulnerable-snap-confine-revs/task.yaml @@ -7,10 +7,10 @@ details: | # just focal is fine for this test - we only need to check that things happen on # classic systems: - # FIXME: make it work on 26 - ubuntu-20* - ubuntu-22* - ubuntu-24* + - ubuntu-26* environment: # which snap snapd comes from in this test From 24531bfc0cbb6d3c53d96b9f0d7f1a44f629edf5 Mon Sep 17 00:00:00 2001 From: ernestl Date: Tue, 7 Oct 2025 22:52:04 +0200 Subject: [PATCH 19/21] seclog: add structured security logger - at this point both journal and audit sinks are implemented. --- cmd/snapd/main.go | 8 + daemon/api_users.go | 10 + daemon/api_users_test.go | 20 ++ daemon/export_api_users_test.go | 5 + data/systemd/Makefile | 10 + data/systemd/journald@snapd-security.conf | 12 + .../snapd.service.d/security-journal.conf | 6 + .../00-snapd.conf | 5 + .../configstate/configcore/export_test.go | 8 + overlord/configstate/configcore/handlers.go | 3 + .../configcore/security_logging.go | 218 ++++++++++++ .../configcore/security_logging_test.go | 300 ++++++++++++++++ seclog/audit.go | 115 ++++++ seclog/export_slog_test.go | 26 ++ seclog/export_test.go | 92 +++++ seclog/journal.go | 134 +++++++ seclog/journal_test.go | 138 +++++++ seclog/nop.go | 46 +++ seclog/nop_test.go | 77 ++++ seclog/seclog.go | 301 ++++++++++++++++ seclog/seclog_test.go | 336 ++++++++++++++++++ seclog/slog.go | 234 ++++++++++++ seclog/slog_test.go | 306 ++++++++++++++++ wrappers/core18.go | 66 ++++ wrappers/core18_test.go | 34 ++ 25 files changed, 2510 insertions(+) create mode 100644 data/systemd/journald@snapd-security.conf create mode 100644 data/systemd/snapd.service.d/security-journal.conf create mode 100644 data/systemd/systemd-journald@snapd-security.service.d/00-snapd.conf create mode 100644 overlord/configstate/configcore/security_logging.go create mode 100644 overlord/configstate/configcore/security_logging_test.go create mode 100644 seclog/audit.go create mode 100644 seclog/export_slog_test.go create mode 100644 seclog/export_test.go create mode 100644 seclog/journal.go create mode 100644 seclog/journal_test.go create mode 100644 seclog/nop.go create mode 100644 seclog/nop_test.go create mode 100644 seclog/seclog.go create mode 100644 seclog/seclog_test.go create mode 100644 seclog/slog.go create mode 100644 seclog/slog_test.go diff --git a/cmd/snapd/main.go b/cmd/snapd/main.go index 89c3ffd9580..9f7c6dea01e 100644 --- a/cmd/snapd/main.go +++ b/cmd/snapd/main.go @@ -33,6 +33,7 @@ import ( "github.com/snapcore/snapd/osutil" "github.com/snapcore/snapd/sandbox" "github.com/snapcore/snapd/secboot" + "github.com/snapcore/snapd/seclog" "github.com/snapcore/snapd/snapdenv" "github.com/snapcore/snapd/snapdtool" "github.com/snapcore/snapd/syscheck" @@ -43,8 +44,15 @@ var ( syscheckCheckSystem = syscheck.CheckSystem ) +const secLogAppID = "canonical.snapd.snapd" +const secLogMinLevel seclog.Level = seclog.LevelInfo + func init() { logger.SimpleSetup(nil) + + if err := seclog.Setup(seclog.ImplSlog, seclog.SinkJournal, secLogAppID, secLogMinLevel); err != nil { + logger.Noticef("%v", err) + } } func main() { diff --git a/daemon/api_users.go b/daemon/api_users.go index 4550dbee2e7..23648f04faa 100644 --- a/daemon/api_users.go +++ b/daemon/api_users.go @@ -31,6 +31,7 @@ import ( "github.com/snapcore/snapd/overlord/devicestate" "github.com/snapcore/snapd/overlord/state" "github.com/snapcore/snapd/release" + "github.com/snapcore/snapd/seclog" "github.com/snapcore/snapd/store" ) @@ -68,6 +69,8 @@ var ( deviceStateCreateUser = devicestate.CreateUser deviceStateCreateKnownUsers = devicestate.CreateKnownUsers deviceStateRemoveUser = devicestate.RemoveUser + + seclogLogLoginSuccess = seclog.LogLoginSuccess ) // userResponseData contains the data releated to user creation/login/query @@ -175,6 +178,13 @@ func loginUser(c *Command, r *http.Request, user *auth.UserState) Response { return InternalError("cannot persist authentication details: %v", err) } + seclogLogLoginSuccess(seclog.SnapdUser{ + ID: int64(user.ID), + SystemUserName: user.Username, + StoreUserEmail: user.Email, + Expiration: user.Expiration, + }) + result := userResponseData{ ID: user.ID, Username: user.Username, diff --git a/daemon/api_users_test.go b/daemon/api_users_test.go index 28a8b729002..b0e645e06ad 100644 --- a/daemon/api_users_test.go +++ b/daemon/api_users_test.go @@ -39,6 +39,7 @@ import ( "github.com/snapcore/snapd/overlord/devicestate/devicestatetest" "github.com/snapcore/snapd/overlord/state" "github.com/snapcore/snapd/release" + "github.com/snapcore/snapd/seclog" "github.com/snapcore/snapd/store" "github.com/snapcore/snapd/testutil" ) @@ -113,6 +114,11 @@ func (s *userSuite) TestLoginUser(c *check.C) { s.expectLoginAccess() + var loggedUser seclog.SnapdUser + s.AddCleanup(daemon.MockSeclogLogLoginSuccess(func(user seclog.SnapdUser) { + loggedUser = user + })) + s.loginUserStoreMacaroon = "user-macaroon" s.loginUserDischarge = "the-discharge-macaroon-serialized-data" buf := bytes.NewBufferString(`{"username": "email@.com", "password": "password"}`) @@ -149,6 +155,10 @@ func (s *userSuite) TestLoginUser(c *check.C) { c.Check(err, check.IsNil) c.Check(snapdMacaroon.Id(), check.Equals, "1") c.Check(snapdMacaroon.Location(), check.Equals, "snapd") + + // security log was called with the right user details + c.Check(loggedUser.ID, check.Equals, int64(1)) + c.Check(loggedUser.StoreUserEmail, check.Equals, "email@.com") } func (s *userSuite) TestLoginUserWithUsername(c *check.C) { @@ -156,6 +166,11 @@ func (s *userSuite) TestLoginUserWithUsername(c *check.C) { s.expectLoginAccess() + var loggedUser seclog.SnapdUser + s.AddCleanup(daemon.MockSeclogLogLoginSuccess(func(user seclog.SnapdUser) { + loggedUser = user + })) + s.loginUserStoreMacaroon = "user-macaroon" s.loginUserDischarge = "the-discharge-macaroon-serialized-data" buf := bytes.NewBufferString(`{"username": "username", "email": "email@.com", "password": "password"}`) @@ -191,6 +206,11 @@ func (s *userSuite) TestLoginUserWithUsername(c *check.C) { c.Check(err, check.IsNil) c.Check(snapdMacaroon.Id(), check.Equals, "1") c.Check(snapdMacaroon.Location(), check.Equals, "snapd") + + // security log was called with the right user details + c.Check(loggedUser.ID, check.Equals, int64(1)) + c.Check(loggedUser.SystemUserName, check.Equals, "username") + c.Check(loggedUser.StoreUserEmail, check.Equals, "email@.com") } func (s *userSuite) TestLoginUserNoEmailWithExistentLocalUser(c *check.C) { diff --git a/daemon/export_api_users_test.go b/daemon/export_api_users_test.go index 6a4aff33dec..8a20ca73b03 100644 --- a/daemon/export_api_users_test.go +++ b/daemon/export_api_users_test.go @@ -25,6 +25,7 @@ import ( "github.com/snapcore/snapd/overlord/auth" "github.com/snapcore/snapd/overlord/devicestate" "github.com/snapcore/snapd/overlord/state" + "github.com/snapcore/snapd/seclog" "github.com/snapcore/snapd/testutil" ) @@ -52,6 +53,10 @@ func MockDeviceStateRemoveUser(removeUser func(st *state.State, username string, return restore } +func MockSeclogLogLoginSuccess(f func(user seclog.SnapdUser)) (restore func()) { + return testutil.Mock(&seclogLogLoginSuccess, f) +} + type ( UserResponseData = userResponseData ) diff --git a/data/systemd/Makefile b/data/systemd/Makefile index d294d1b5d27..f782a977248 100644 --- a/data/systemd/Makefile +++ b/data/systemd/Makefile @@ -40,6 +40,16 @@ install: $(SYSTEMD_UNITS) install -d -m 0755 $(DESTDIR)/$(LIBEXECDIR)/snapd install -m 0755 -t $(DESTDIR)/$(LIBEXECDIR)/snapd snapd.core-fixup.sh install -m 0755 -t $(DESTDIR)/$(LIBEXECDIR)/snapd snapd.run-from-snap + # security log journal namespace + install -d -m 0755 $(DESTDIR)/etc/systemd + install -m 0644 journald@snapd-security.conf $(DESTDIR)/etc/systemd/journald@snapd-security.conf + install -d -m 0755 $(DESTDIR)/$(SYSTEMDSYSTEMUNITDIR)/systemd-journald@snapd-security.service.d + install -m 0644 systemd-journald@snapd-security.service.d/00-snapd.conf \ + $(DESTDIR)/$(SYSTEMDSYSTEMUNITDIR)/systemd-journald@snapd-security.service.d/00-snapd.conf + # snapd.service drop-in to pull in the security journal socket (Wants + After) + install -d -m 0755 $(DESTDIR)/$(SYSTEMDSYSTEMUNITDIR)/snapd.service.d + install -m 0644 snapd.service.d/security-journal.conf \ + $(DESTDIR)/$(SYSTEMDSYSTEMUNITDIR)/snapd.service.d/security-journal.conf .PHONY: clean clean: diff --git a/data/systemd/journald@snapd-security.conf b/data/systemd/journald@snapd-security.conf new file mode 100644 index 00000000000..337346a708d --- /dev/null +++ b/data/systemd/journald@snapd-security.conf @@ -0,0 +1,12 @@ +# Journald configuration for the snapd security log namespace. +# This namespace isolates security audit events (login, access control) +# from the main system journal. +[Journal] +Storage=persistent +Compress=yes +SystemMaxFileSize=10M +SystemMaxUse=10M +SyncIntervalSec=30s +SyncOnShutdown=yes +RateLimitIntervalSec=30s +RateLimitBurst=10000 diff --git a/data/systemd/snapd.service.d/security-journal.conf b/data/systemd/snapd.service.d/security-journal.conf new file mode 100644 index 00000000000..f5fd3123a47 --- /dev/null +++ b/data/systemd/snapd.service.d/security-journal.conf @@ -0,0 +1,6 @@ +# Pull in the snapd-security journal namespace socket alongside snapd. +# This is a best-effort dependency: if the socket is masked or fails, +# snapd starts normally without security logging. +[Unit] +Wants=systemd-journald@snapd-security.socket +After=systemd-journald@snapd-security.socket diff --git a/data/systemd/systemd-journald@snapd-security.service.d/00-snapd.conf b/data/systemd/systemd-journald@snapd-security.service.d/00-snapd.conf new file mode 100644 index 00000000000..9e36515dc43 --- /dev/null +++ b/data/systemd/systemd-journald@snapd-security.service.d/00-snapd.conf @@ -0,0 +1,5 @@ +# Drop-in for the snapd-security journal namespace instance. +# Clears the default LogsDirectory to prevent failures with +# namespaced journald instances. +[Service] +LogsDirectory= diff --git a/overlord/configstate/configcore/export_test.go b/overlord/configstate/configcore/export_test.go index bd112770045..2afb299aaac 100644 --- a/overlord/configstate/configcore/export_test.go +++ b/overlord/configstate/configcore/export_test.go @@ -107,3 +107,11 @@ func MockEnvPath(newEnvPath string) func() { envFilePath = newEnvPath return func() { envFilePath = oldEnvPath } } + +func MockSeclogEnable(f func() error) func() { + return testutil.Mock(&seclogEnable, f) +} + +func MockSeclogDisable(f func() error) func() { + return testutil.Mock(&seclogDisable, f) +} diff --git a/overlord/configstate/configcore/handlers.go b/overlord/configstate/configcore/handlers.go index da39522a240..a5c4f63729a 100644 --- a/overlord/configstate/configcore/handlers.go +++ b/overlord/configstate/configcore/handlers.go @@ -124,6 +124,9 @@ func init() { // system.motd addFSOnlyHandler(validateMotdConfiguration, handleMotdConfiguration, coreOnly) + // security-logging.* + addFSOnlyHandler(validateSecurityLoggingSettings, handleSecurityLoggingConfiguration, nil) + sysconfig.ApplyFilesystemOnlyDefaultsImpl = filesystemOnlyApply } diff --git a/overlord/configstate/configcore/security_logging.go b/overlord/configstate/configcore/security_logging.go new file mode 100644 index 00000000000..1bb50bf3d4d --- /dev/null +++ b/overlord/configstate/configcore/security_logging.go @@ -0,0 +1,218 @@ +// -*- Mode: Go; indent-tabs-mode: t -*- + +/* + * Copyright (C) 2026 Canonical Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 3 as + * published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +package configcore + +import ( + "fmt" + "os" + "path/filepath" + + "github.com/snapcore/snapd/dirs" + "github.com/snapcore/snapd/logger" + "github.com/snapcore/snapd/osutil" + "github.com/snapcore/snapd/seclog" + "github.com/snapcore/snapd/sysconfig" + "github.com/snapcore/snapd/systemd" +) + +const ( + securityJournalSocketUnit = "systemd-journald@snapd-security.socket" + securityJournalServiceUnit = "systemd-journald@snapd-security.service" + securityJournalConfPath = "/etc/systemd/journald@snapd-security.conf" + securityJournalDefaultMaxSize = "10M" +) + +var ( + seclogEnable = seclog.Enable + seclogDisable = seclog.Disable +) + +func init() { + supportedConfigurations["core.security-logging.enabled"] = true + supportedConfigurations["core.security-logging.max-size"] = true +} + +func validateSecurityLoggingSettings(tr ConfGetter) error { + if err := validateBoolFlag(tr, "security-logging.enabled"); err != nil { + return err + } + + maxSize, err := coreCfg(tr, "security-logging.max-size") + if err != nil { + return err + } + if maxSize != "" { + if err := validateJournalSizeValue(maxSize); err != nil { + return fmt.Errorf("security-logging.max-size %v", err) + } + } + + return nil +} + +func handleSecurityLoggingConfiguration(_ sysconfig.Device, tr ConfGetter, opts *fsOnlyContext) error { + enabled, err := coreCfg(tr, "security-logging.enabled") + if err != nil { + return err + } + maxSize, err := coreCfg(tr, "security-logging.max-size") + if err != nil { + return err + } + + // If nothing is set, do nothing. + if enabled == "" && maxSize == "" { + return nil + } + + rootDir := dirs.GlobalRootDir + if opts != nil { + rootDir = opts.RootDir + } + + switch enabled { + case "false": + return disableSecurityLogging(rootDir, opts) + default: + return enableSecurityLogging(rootDir, opts, maxSize) + } +} + +func disableSecurityLogging(rootDir string, opts *fsOnlyContext) error { + // Disconnect from the journal namespace first so that the sink + // is closed before we tear down the service. + if err := seclogDisable(); err != nil { + logger.Noticef("cannot disable security logger: %v", err) + } + + if opts != nil { + // During filesystem-only apply, mask the socket by creating + // a symlink to /dev/null, mirroring what systemctl mask does. + maskDir := filepath.Join(rootDir, "/etc/systemd/system") + if err := os.MkdirAll(maskDir, 0755); err != nil { + return err + } + maskPath := filepath.Join(maskDir, securityJournalSocketUnit) + os.Remove(maskPath) + return os.Symlink("/dev/null", maskPath) + } + + sysd := systemd.NewUnderRoot(rootDir, systemd.SystemMode, nil) + // Mask the socket to prevent future activation. The running + // journald instance (if any) will exit on its own once idle. + if err := sysd.Mask(securityJournalSocketUnit); err != nil { + return err + } + return nil +} + +func enableSecurityLogging(rootDir string, opts *fsOnlyContext, maxSize string) error { + confPath := filepath.Join(rootDir, securityJournalConfPath) + + // Write the namespace config file with current settings. + conf := generateSecurityJournalConf(maxSize) + if err := osutil.AtomicWriteFile(confPath, conf, 0644, 0); err != nil { + return err + } + + if opts != nil { + // Filesystem-only apply; remove any mask symlink that may + // have been created by a previous disable. + maskPath := filepath.Join(rootDir, "/etc/systemd/system", securityJournalSocketUnit) + os.Remove(maskPath) + return nil + } + + sysd := systemd.NewUnderRoot(rootDir, systemd.SystemMode, nil) + + // Unmask in case it was previously disabled. + if err := sysd.Unmask(securityJournalSocketUnit); err != nil { + return err + } + + // Start the socket unit so socket activation is available immediately. + if err := sysd.Start([]string{securityJournalSocketUnit}); err != nil { + logger.Noticef("cannot start security journal socket: %v", err) + } + + // Signal the namespaced journald to reload configuration if it was + // already running; if not, socket activation picks up the new config. + if err := sysd.Kill(securityJournalServiceUnit, "USR1", ""); err != nil { + // Non-fatal: new config takes effect on next activation. + } + + // Re-open the security logger against the fresh namespace connection. + // This is done last so that the journal service is in a stable state + // before we connect to the sink. + if err := seclogEnable(); err != nil { + logger.Noticef("cannot enable security logger: %v", err) + } + + return nil +} + +func generateSecurityJournalConf(maxSize string) []byte { + conf := "[Journal]\nStorage=persistent\nCompress=yes\n" + if maxSize == "" { + maxSize = securityJournalDefaultMaxSize + } + conf += fmt.Sprintf("SystemMaxUse=%s\n", maxSize) + conf += "SyncIntervalSec=30s\nSyncOnShutdown=yes\n" + // Sanity rate limit: generous enough to never trigger under + // normal operation, but prevents runaway log storms from + // filling the journal if something goes very wrong. + conf += "RateLimitIntervalSec=30s\nRateLimitBurst=10000\n" + return []byte(conf) +} + +// validateJournalSizeValue validates a systemd journal size value (e.g. "10M", "1G"). +// The minimum allowed value is 10M. +func validateJournalSizeValue(value string) error { + if len(value) < 2 { + return fmt.Errorf("cannot parse size %q: must be a number followed by a suffix like K, M, G or T", value) + } + suffix := value[len(value)-1] + var multiplier uint64 + switch suffix { + case 'K': + multiplier = 1024 + case 'M': + multiplier = 1024 * 1024 + case 'G': + multiplier = 1024 * 1024 * 1024 + case 'T': + multiplier = 1024 * 1024 * 1024 * 1024 + default: + return fmt.Errorf("cannot parse size %q: must be a number followed by a suffix like K, M, G or T", value) + } + numStr := value[:len(value)-1] + var num uint64 + for _, ch := range numStr { + if ch < '0' || ch > '9' { + return fmt.Errorf("cannot parse size %q: must be a number followed by a suffix like K, M, G or T", value) + } + num = num*10 + uint64(ch-'0') + } + const minSize = 10 * 1024 * 1024 // 10M + if num*multiplier < minSize { + return fmt.Errorf("cannot set size %q: must be at least 10M", value) + } + return nil +} diff --git a/overlord/configstate/configcore/security_logging_test.go b/overlord/configstate/configcore/security_logging_test.go new file mode 100644 index 00000000000..0828600a004 --- /dev/null +++ b/overlord/configstate/configcore/security_logging_test.go @@ -0,0 +1,300 @@ +// -*- Mode: Go; indent-tabs-mode: t -*- + +/* + * Copyright (C) 2026 Canonical Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 3 as + * published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +package configcore_test + +import ( + "os" + "path/filepath" + + . "gopkg.in/check.v1" + + "github.com/snapcore/snapd/dirs" + "github.com/snapcore/snapd/osutil" + "github.com/snapcore/snapd/overlord/configstate/configcore" + "github.com/snapcore/snapd/testutil" +) + +type securityLoggingSuite struct { + configcoreSuite +} + +var _ = Suite(&securityLoggingSuite{}) + +func (s *securityLoggingSuite) SetUpTest(c *C) { + s.configcoreSuite.SetUpTest(c) + + err := os.MkdirAll(filepath.Join(dirs.GlobalRootDir, "/etc/systemd"), 0755) + c.Assert(err, IsNil) + + s.AddCleanup(configcore.MockSeclogEnable(func() error { return nil })) + s.AddCleanup(configcore.MockSeclogDisable(func() error { return nil })) +} + +// Validation tests + +func (s *securityLoggingSuite) TestValidateEnabledValid(c *C) { + for _, val := range []string{"true", "false"} { + err := configcore.FilesystemOnlyRun(coreDev, &mockConf{ + state: s.state, + conf: map[string]any{"security-logging.enabled": val}, + }) + c.Check(err, IsNil, Commentf("value %q", val)) + } +} + +func (s *securityLoggingSuite) TestValidateEnabledInvalid(c *C) { + err := configcore.FilesystemOnlyRun(coreDev, &mockConf{ + state: s.state, + conf: map[string]any{"security-logging.enabled": "maybe"}, + }) + c.Assert(err, ErrorMatches, `security-logging.enabled can only be set to 'true' or 'false'`) +} + +func (s *securityLoggingSuite) TestValidateMaxSizeValid(c *C) { + for _, val := range []string{"10M", "100M", "1G", "2T"} { + err := configcore.FilesystemOnlyRun(coreDev, &mockConf{ + state: s.state, + conf: map[string]any{ + "security-logging.enabled": "true", + "security-logging.max-size": val, + }, + }) + c.Check(err, IsNil, Commentf("value %q", val)) + } +} + +func (s *securityLoggingSuite) TestValidateMaxSizeInvalidSuffix(c *C) { + err := configcore.FilesystemOnlyRun(coreDev, &mockConf{ + state: s.state, + conf: map[string]any{ + "security-logging.enabled": "true", + "security-logging.max-size": "100X", + }, + }) + c.Assert(err, ErrorMatches, `security-logging.max-size cannot parse size "100X": must be a number followed by a suffix like K, M, G or T`) +} + +func (s *securityLoggingSuite) TestValidateMaxSizeTooSmall(c *C) { + err := configcore.FilesystemOnlyRun(coreDev, &mockConf{ + state: s.state, + conf: map[string]any{ + "security-logging.enabled": "true", + "security-logging.max-size": "5M", + }, + }) + c.Assert(err, ErrorMatches, `security-logging.max-size cannot set size "5M": must be at least 10M`) +} + +func (s *securityLoggingSuite) TestValidateMaxSizeNonNumeric(c *C) { + err := configcore.FilesystemOnlyRun(coreDev, &mockConf{ + state: s.state, + conf: map[string]any{ + "security-logging.enabled": "true", + "security-logging.max-size": "abcM", + }, + }) + c.Assert(err, ErrorMatches, `security-logging.max-size cannot parse size "abcM": must be a number followed by a suffix like K, M, G or T`) +} + +func (s *securityLoggingSuite) TestValidateMaxSizeTooShort(c *C) { + err := configcore.FilesystemOnlyRun(coreDev, &mockConf{ + state: s.state, + conf: map[string]any{ + "security-logging.enabled": "true", + "security-logging.max-size": "M", + }, + }) + c.Assert(err, ErrorMatches, `security-logging.max-size cannot parse size "M": must be a number followed by a suffix like K, M, G or T`) +} + +// Nothing set -> no-op + +func (s *securityLoggingSuite) TestHandleNothingSet(c *C) { + err := configcore.FilesystemOnlyRun(coreDev, &mockConf{ + state: s.state, + conf: map[string]any{}, + }) + c.Assert(err, IsNil) + c.Check(s.systemctlArgs, HasLen, 0) +} + +// Enable path (runtime) + +func (s *securityLoggingSuite) TestEnableSecurityLogging(c *C) { + err := configcore.FilesystemOnlyRun(coreDev, &mockConf{ + state: s.state, + conf: map[string]any{ + "security-logging.enabled": "true", + }, + }) + c.Assert(err, IsNil) + + // Check the journal conf was written with default max-size. + confPath := filepath.Join(dirs.GlobalRootDir, "/etc/systemd/journald@snapd-security.conf") + c.Check(confPath, testutil.FileContains, "Storage=persistent") + c.Check(confPath, testutil.FileContains, "SystemMaxUse=10M") + c.Check(confPath, testutil.FileContains, "SyncIntervalSec=30s") + c.Check(confPath, testutil.FileContains, "RateLimitBurst=10000") + + // Check systemctl calls: unmask, start, kill (reload config). + // The mock systemd operates under a root dir, so --root is prepended. + c.Assert(s.systemctlArgs, HasLen, 3) + c.Check(s.systemctlArgs[0], testutil.DeepContains, "unmask") + c.Check(s.systemctlArgs[0], testutil.DeepContains, "systemd-journald@snapd-security.socket") + c.Check(s.systemctlArgs[1], testutil.DeepContains, "start") + c.Check(s.systemctlArgs[1], testutil.DeepContains, "systemd-journald@snapd-security.socket") + c.Check(s.systemctlArgs[2], testutil.DeepContains, "kill") + c.Check(s.systemctlArgs[2], testutil.DeepContains, "systemd-journald@snapd-security.service") +} + +func (s *securityLoggingSuite) TestEnableSecurityLoggingWithMaxSize(c *C) { + err := configcore.FilesystemOnlyRun(coreDev, &mockConf{ + state: s.state, + conf: map[string]any{ + "security-logging.enabled": "true", + "security-logging.max-size": "50M", + }, + }) + c.Assert(err, IsNil) + + confPath := filepath.Join(dirs.GlobalRootDir, "/etc/systemd/journald@snapd-security.conf") + c.Check(confPath, testutil.FileContains, "SystemMaxUse=50M") +} + +func (s *securityLoggingSuite) TestEnableCallsSeclogEnable(c *C) { + enableCalled := false + s.AddCleanup(configcore.MockSeclogEnable(func() error { + enableCalled = true + return nil + })) + + err := configcore.FilesystemOnlyRun(coreDev, &mockConf{ + state: s.state, + conf: map[string]any{ + "security-logging.enabled": "true", + }, + }) + c.Assert(err, IsNil) + c.Check(enableCalled, Equals, true) +} + +// Disable path (runtime) + +func (s *securityLoggingSuite) TestDisableSecurityLogging(c *C) { + err := configcore.FilesystemOnlyRun(coreDev, &mockConf{ + state: s.state, + conf: map[string]any{ + "security-logging.enabled": "false", + }, + }) + c.Assert(err, IsNil) + + // Check systemctl calls: mask the socket. + c.Assert(s.systemctlArgs, HasLen, 1) + c.Check(s.systemctlArgs[0], testutil.DeepContains, "mask") + c.Check(s.systemctlArgs[0], testutil.DeepContains, "systemd-journald@snapd-security.socket") +} + +func (s *securityLoggingSuite) TestDisableCallsSeclogDisable(c *C) { + disableCalled := false + s.AddCleanup(configcore.MockSeclogDisable(func() error { + disableCalled = true + return nil + })) + + err := configcore.FilesystemOnlyRun(coreDev, &mockConf{ + state: s.state, + conf: map[string]any{ + "security-logging.enabled": "false", + }, + }) + c.Assert(err, IsNil) + c.Check(disableCalled, Equals, true) +} + +// Filesystem-only enable (preseeding/install) + +func (s *securityLoggingSuite) TestEnableFSOnly(c *C) { + rootDir := c.MkDir() + err := os.MkdirAll(filepath.Join(rootDir, "/etc/systemd"), 0755) + c.Assert(err, IsNil) + + // Place a mask symlink as if previously disabled. + maskDir := filepath.Join(rootDir, "/etc/systemd/system") + err = os.MkdirAll(maskDir, 0755) + c.Assert(err, IsNil) + maskPath := filepath.Join(maskDir, "systemd-journald@snapd-security.socket") + err = os.Symlink("/dev/null", maskPath) + c.Assert(err, IsNil) + + err = configcore.FilesystemOnlyApply(coreDev, rootDir, map[string]any{ + "security-logging.enabled": "true", + }) + c.Assert(err, IsNil) + + // Conf written. + confPath := filepath.Join(rootDir, "/etc/systemd/journald@snapd-security.conf") + c.Check(confPath, testutil.FileContains, "Storage=persistent") + + // Mask symlink removed. + c.Check(osutil.IsSymlink(maskPath), Equals, false) + + // No systemctl calls in fs-only mode. + c.Check(s.systemctlArgs, HasLen, 0) +} + +// Filesystem-only disable (preseeding/install) + +func (s *securityLoggingSuite) TestDisableFSOnly(c *C) { + rootDir := c.MkDir() + + err := configcore.FilesystemOnlyApply(coreDev, rootDir, map[string]any{ + "security-logging.enabled": "false", + }) + c.Assert(err, IsNil) + + // Mask symlink created. + maskPath := filepath.Join(rootDir, "/etc/systemd/system", "systemd-journald@snapd-security.socket") + c.Check(osutil.IsSymlink(maskPath), Equals, true) + target, err := os.Readlink(maskPath) + c.Assert(err, IsNil) + c.Check(target, Equals, "/dev/null") + + // No systemctl calls. + c.Check(s.systemctlArgs, HasLen, 0) +} + +// max-size alone implies enable + +func (s *securityLoggingSuite) TestMaxSizeAloneImpliesEnable(c *C) { + err := configcore.FilesystemOnlyRun(coreDev, &mockConf{ + state: s.state, + conf: map[string]any{ + "security-logging.max-size": "100M", + }, + }) + c.Assert(err, IsNil) + + confPath := filepath.Join(dirs.GlobalRootDir, "/etc/systemd/journald@snapd-security.conf") + c.Check(confPath, testutil.FileContains, "SystemMaxUse=100M") + + // Systemctl was called (unmask, start, kill) — enable path was taken. + c.Check(len(s.systemctlArgs) > 0, Equals, true) +} diff --git a/seclog/audit.go b/seclog/audit.go new file mode 100644 index 00000000000..43805b820b5 --- /dev/null +++ b/seclog/audit.go @@ -0,0 +1,115 @@ +// -*- Mode: Go; indent-tabs-mode: t -*- + +/* + * Copyright (C) 2026 Canonical Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 3 as + * published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +package seclog + +import ( + "encoding/binary" + "fmt" + "io" + "sync/atomic" + "syscall" +) + +const ( + // AUDIT_USER_MSG is the audit message type for user-space messages. + auditUserMsg = 1112 + + // NETLINK_AUDIT is the netlink protocol for audit. + netlinkAudit = 15 +) + +func init() { + registerSink(SinkAudit, newAuditSink) +} + +// newAuditSink opens a netlink audit socket and returns an [auditWriter] +// that sends each written payload as an AUDIT_USER_MSG. The appID is +// currently unused but accepted for sink signature compatibility. +func newAuditSink(_ string) (io.Writer, error) { + fd, err := syscall.Socket(syscall.AF_NETLINK, syscall.SOCK_RAW, netlinkAudit) + if err != nil { + return nil, fmt.Errorf("cannot open audit socket: %w", err) + } + addr := &syscall.SockaddrNetlink{ + Family: syscall.AF_NETLINK, + Pid: 0, // kernel + Groups: 0, + } + if err := syscall.Bind(fd, addr); err != nil { + syscall.Close(fd) + return nil, fmt.Errorf("cannot bind audit socket: %w", err) + } + return &auditWriter{fd: fd}, nil +} + +// auditWriter sends messages to the kernel audit subsystem via a netlink +// socket. Each Write call sends the payload as an AUDIT_USER_MSG. +// +// The writer is safe for sequential use; concurrent use requires external +// synchronization. +type auditWriter struct { + fd int + seq atomic.Uint32 +} + +// Write sends p as the payload of an AUDIT_USER_MSG netlink message. +// The returned byte count reflects only the original payload length. +func (aw *auditWriter) Write(p []byte) (int, error) { + msg := aw.buildMessage(p) + addr := &syscall.SockaddrNetlink{ + Family: syscall.AF_NETLINK, + Pid: 0, // kernel + } + if err := syscall.Sendto(aw.fd, msg, 0, addr); err != nil { + return 0, fmt.Errorf("cannot send audit message: %w", err) + } + return len(p), nil +} + +// Close closes the underlying netlink socket. +func (aw *auditWriter) Close() error { + return syscall.Close(aw.fd) +} + +// nlmsghdrSize is the size of a netlink message header in bytes +// (uint32 + uint16 + uint16 + uint32 + uint32 = 16). +const nlmsghdrSize = 16 + +// buildMessage constructs a raw netlink AUDIT_USER_MSG containing payload. +func (aw *auditWriter) buildMessage(payload []byte) []byte { + totalLen := nlmsghdrSize + uint32(len(payload)) + buf := make([]byte, nlmsgAlign(totalLen)) + + // Write header. + binary.LittleEndian.PutUint32(buf[0:4], totalLen) + binary.LittleEndian.PutUint16(buf[4:6], auditUserMsg) + binary.LittleEndian.PutUint16(buf[6:8], 0x01|0x04) // NLM_F_REQUEST | NLM_F_ACK + binary.LittleEndian.PutUint32(buf[8:12], aw.seq.Add(1)) + binary.LittleEndian.PutUint32(buf[12:16], 0) // pid 0 = kernel + + // Write payload. + copy(buf[nlmsghdrSize:], payload) + return buf +} + +// nlmsgAlign rounds up to the nearest 4-byte boundary per NLMSG_ALIGN. +func nlmsgAlign(n uint32) uint32 { + return (n + 3) &^ 3 +} diff --git a/seclog/export_slog_test.go b/seclog/export_slog_test.go new file mode 100644 index 00000000000..6d9c0366637 --- /dev/null +++ b/seclog/export_slog_test.go @@ -0,0 +1,26 @@ +// -*- Mode: Go; indent-tabs-mode: t -*- +//go:build go1.21 && !noslog + +/* + * Copyright (C) 2026 Canonical Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 3 as + * published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +package seclog + +type ( + SlogProvider = slogProvider + SlogLogger = slogLogger +) diff --git a/seclog/export_test.go b/seclog/export_test.go new file mode 100644 index 00000000000..edf95624fc9 --- /dev/null +++ b/seclog/export_test.go @@ -0,0 +1,92 @@ +// -*- Mode: Go; indent-tabs-mode: t -*- + +/* + * Copyright (C) 2026 Canonical Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 3 as + * published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +package seclog + +import ( + "io" + + "github.com/snapcore/snapd/testutil" +) + +var NewNopLogger = newNopLogger + +var Register = register +var RegisterSink = registerSink + +type ( + Provider = provider + SecurityLogger = securityLogger +) + +func MockSinks(m map[Sink]func(string) (io.Writer, error)) (restore func()) { + restore = testutil.Backup(&sinks) + sinks = m + return restore +} + +// MockNewSink is a convenience wrapper that replaces the journal sink factory +// in the sinks map. The rest of the sinks map is preserved. +func MockNewSink(f func(string) (io.Writer, error)) (restore func()) { + restore = testutil.Backup(&sinks) + sinks = map[Sink]func(string) (io.Writer, error){ + SinkJournal: f, + SinkAudit: newAuditSink, + } + return restore +} + +func MockProviders(m map[Impl]provider) (restore func()) { + restore = testutil.Backup(&providers) + providers = m + return restore +} + +func MockGlobalLogger(l securityLogger) (restore func()) { + restore = testutil.Backup(&globalLogger) + globalLogger = l + return restore +} + +func MockGlobalCloser(c io.Closer) (restore func()) { + restore = testutil.Backup(&globalCloser) + globalCloser = c + return restore +} + +// LoggerSetup is the exported alias for the unexported loggerSetup type, +// allowing tests to create and mock setup state. +type LoggerSetup = loggerSetup + +// NewLoggerSetup constructs a LoggerSetup for use in tests. +func NewLoggerSetup(impl Impl, sink Sink, appID string, minLevel Level) *LoggerSetup { + return &LoggerSetup{impl: impl, sink: sink, appID: appID, minLevel: minLevel} +} + +func MockGlobalSetup(s *LoggerSetup) (restore func()) { + restore = testutil.Backup(&globalSetup) + globalSetup = s + return restore +} + +var SyslogPriority = syslogPriority + +var NewJournalWriter = newJournalWriter + +type JournalWriter = journalWriter diff --git a/seclog/journal.go b/seclog/journal.go new file mode 100644 index 00000000000..e23446e9385 --- /dev/null +++ b/seclog/journal.go @@ -0,0 +1,134 @@ +// -*- Mode: Go; indent-tabs-mode: t -*- + +/* + * Copyright (C) 2026 Canonical Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 3 as + * published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +package seclog + +import ( + "fmt" + "io" + "log/syslog" + "os" + + "github.com/snapcore/snapd/systemd" +) + +const securityNamespace = "snapd-security" + +func init() { + registerSink(SinkJournal, newJournalSink) +} + +// newJournalSink opens a journald stream for the "snapd-security" namespace +// and returns a [journalWriter] that prepends syslog priority prefixes to every +// written line. The resulting writer is suitable as the output sink for a +// structured security logger. +func newJournalSink(appID string) (io.Writer, error) { + f, err := newJournalStream(appID) + if err != nil { + return nil, err + } + return newJournalWriter(f), nil +} + +// journalWriter implements [levelWriter] by wrapping an [io.Writer] and +// prepending a syslog-style "" priority prefix to each Write call. When +// used with a journald stream opened in level-prefix mode, the prefix +// overrides the per-message PRIORITY field. journald strips the prefix from +// the stored MESSAGE content. +// +// SetLevel must be called before each Write to select the priority for the +// upcoming message. Concurrent use of SetLevel and Write requires external +// synchronization. +type journalWriter struct { + w io.Writer + level Level +} + +// Ensure [journalWriter] implements [levelWriter]. +var _ levelWriter = (*journalWriter)(nil) + +// newJournalWriter returns a [journalWriter] that writes to the given +// writer with per-message syslog priority prefixes. +func newJournalWriter(w io.Writer) *journalWriter { + return &journalWriter{w: w, level: LevelInfo} +} + +// SetLevel sets the syslog priority for the next Write call. +func (jw *journalWriter) SetLevel(level Level) { + jw.level = level +} + +// Close closes the underlying writer if it implements [io.Closer]. +func (jw *journalWriter) Close() error { + if closer, ok := jw.w.(io.Closer); ok { + return closer.Close() + } + return nil +} + +// Write prepends "" to p and writes the result to the underlying writer. +// The returned byte count reflects only the original payload, excluding the +// prefix, to satisfy [io.Writer] callers that compare n against len(p). +func (jw *journalWriter) Write(p []byte) (int, error) { + prefix := fmt.Sprintf("<%d>", syslogPriority(jw.level)) + buf := make([]byte, len(prefix)+len(p)) + copy(buf, prefix) + copy(buf[len(prefix):], p) + n, err := jw.w.Write(buf) + // Report bytes written minus the prefix length so that + // callers see n == len(p) on success. + if n >= len(prefix) { + return n - len(prefix), err + } + return 0, err +} + +// newJournalStream opens a journald stream connection to the +// "snapd-security" namespace. The stream uses level-prefix mode so that +// each written line can override PRIORITY per message by prepending a +// "" syslog priority prefix. journald strips the prefix from the +// stored MESSAGE. +// +// The returned *os.File is suitable as the underlying writer for a +// [journalWriter]. +func newJournalStream(appID string) (*os.File, error) { + return systemd.NewJournalStreamFile(systemd.JournalStreamFileParams{ + Namespace: securityNamespace, + Identifier: appID, + Priority: syslog.LOG_DEBUG, + LevelPrefix: true, + }) +} + +// syslogPriority maps a security log [Level] to the equivalent syslog +// priority used by journald for the PRIORITY field. +func syslogPriority(level Level) syslog.Priority { + switch { + case level >= LevelCritical: + return syslog.LOG_CRIT + case level >= LevelError: + return syslog.LOG_ERR + case level >= LevelWarn: + return syslog.LOG_WARNING + case level >= LevelInfo: + return syslog.LOG_INFO + default: + return syslog.LOG_DEBUG + } +} diff --git a/seclog/journal_test.go b/seclog/journal_test.go new file mode 100644 index 00000000000..7659c2cd23b --- /dev/null +++ b/seclog/journal_test.go @@ -0,0 +1,138 @@ +// -*- Mode: Go; indent-tabs-mode: t -*- + +/* + * Copyright (C) 2026 Canonical Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 3 as + * published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +package seclog_test + +import ( + "bytes" + "fmt" + + . "gopkg.in/check.v1" + + "github.com/snapcore/snapd/seclog" + "github.com/snapcore/snapd/testutil" +) + +type JournalSuite struct { + testutil.BaseTest + buf *bytes.Buffer +} + +var _ = Suite(&JournalSuite{}) + +func (s *JournalSuite) SetUpTest(c *C) { + s.BaseTest.SetUpTest(c) + s.buf = &bytes.Buffer{} +} + +func (s *JournalSuite) TearDownTest(c *C) { + s.BaseTest.TearDownTest(c) +} + +func (s *JournalSuite) TestNewJournalWriterDefaultLevel(c *C) { + jw := seclog.NewJournalWriter(s.buf) + c.Assert(jw, NotNil) + + // default level is LevelInfo, syslog.LOG_INFO == 6 + _, err := jw.Write([]byte("hello")) + c.Assert(err, IsNil) + c.Check(s.buf.String(), Equals, "<6>hello") +} + +func (s *JournalSuite) TestSetLevelAndWrite(c *C) { + jw := seclog.NewJournalWriter(s.buf) + + tests := []struct { + level seclog.Level + expectedPrefix string + }{ + {seclog.LevelDebug, "<7>"}, // LOG_DEBUG + {seclog.LevelInfo, "<6>"}, // LOG_INFO + {seclog.LevelWarn, "<4>"}, // LOG_WARNING + {seclog.LevelError, "<3>"}, // LOG_ERR + {seclog.LevelCritical, "<2>"}, // LOG_CRIT + } + + for _, t := range tests { + s.buf.Reset() + jw.SetLevel(t.level) + msg := []byte("test message") + n, err := jw.Write(msg) + c.Assert(err, IsNil) + c.Check(n, Equals, len(msg), + Commentf("level %v", t.level)) + c.Check(s.buf.String(), Equals, t.expectedPrefix+"test message", + Commentf("level %v", t.level)) + } +} + +func (s *JournalSuite) TestWriteByteCountExcludesPrefix(c *C) { + jw := seclog.NewJournalWriter(s.buf) + + msg := []byte("payload") + n, err := jw.Write(msg) + c.Assert(err, IsNil) + // n must equal len(msg), not len("<6>payload") + c.Check(n, Equals, len(msg)) +} + +type errWriter struct { + err error +} + +func (w *errWriter) Write(p []byte) (int, error) { + return 0, w.err +} + +func (s *JournalSuite) TestWritePropagatesError(c *C) { + expected := fmt.Errorf("disk full") + jw := seclog.NewJournalWriter(&errWriter{err: expected}) + + n, err := jw.Write([]byte("data")) + c.Check(err, Equals, expected) + c.Check(n, Equals, 0) +} + +// closeRecorder implements io.WriteCloser and records whether Close was called. +type closeRecorder struct { + bytes.Buffer + closed bool +} + +func (cr *closeRecorder) Close() error { + cr.closed = true + return nil +} + +func (s *JournalSuite) TestCloseForwardsToUnderlyingWriter(c *C) { + cr := &closeRecorder{} + jw := seclog.NewJournalWriter(cr) + + err := jw.Close() + c.Assert(err, IsNil) + c.Check(cr.closed, Equals, true) +} + +func (s *JournalSuite) TestCloseWithNonCloserReturnsNil(c *C) { + // bytes.Buffer does not implement io.Closer + jw := seclog.NewJournalWriter(s.buf) + + err := jw.Close() + c.Assert(err, IsNil) +} diff --git a/seclog/nop.go b/seclog/nop.go new file mode 100644 index 00000000000..b4513c5b0df --- /dev/null +++ b/seclog/nop.go @@ -0,0 +1,46 @@ +// -*- Mode: Go; indent-tabs-mode: t -*- + +/* + * Copyright (C) 2026 Canonical Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 3 as + * published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +package seclog + +// nopLogger provides a no-operation [securityLogger] implementation. +type nopLogger struct{} + +// Ensure [nopLogger] implements [securityLogger]. +var _ securityLogger = (*nopLogger)(nil) + +func newNopLogger() securityLogger { + return nopLogger{} +} + +// LogLoggingEnabled implements [securityLogger.LogLoggingEnabled]. +func (nopLogger) LogLoggingEnabled() { +} + +// LogLoggingDisabled implements [securityLogger.LogLoggingDisabled]. +func (nopLogger) LogLoggingDisabled() { +} + +// LogLoginSuccess implements [securityLogger.LogLoginSuccess]. +func (nopLogger) LogLoginSuccess(user SnapdUser) { +} + +// LogLoginFailure implements [securityLogger.LogLoginFailure]. +func (nopLogger) LogLoginFailure(user SnapdUser) { +} diff --git a/seclog/nop_test.go b/seclog/nop_test.go new file mode 100644 index 00000000000..d183e96def7 --- /dev/null +++ b/seclog/nop_test.go @@ -0,0 +1,77 @@ +// -*- Mode: Go; indent-tabs-mode: t -*- + +/* + * Copyright (C) 2026 Canonical Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 3 as + * published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +package seclog_test + +import ( + "testing" + + . "gopkg.in/check.v1" + + "github.com/snapcore/snapd/seclog" + "github.com/snapcore/snapd/testutil" +) + +type NopSuite struct { + testutil.BaseTest +} + +var _ = Suite(&NopSuite{}) + +func TestNop(t *testing.T) { TestingT(t) } + +func (s *NopSuite) SetUpTest(c *C) { + s.BaseTest.SetUpTest(c) +} + +func (s *NopSuite) TearDownTest(c *C) { + s.BaseTest.TearDownTest(c) +} + +func (s *NopSuite) TestLogLoggingEnabled(c *C) { + logger := seclog.NewNopLogger() + c.Assert(logger, NotNil) + + // nop logger discards all messages without error + logger.LogLoggingEnabled() +} + +func (s *NopSuite) TestLogLoggingDisabled(c *C) { + logger := seclog.NewNopLogger() + c.Assert(logger, NotNil) + + // nop logger discards all messages without error + logger.LogLoggingDisabled() +} + +func (s *NopSuite) TestLogLoginSuccess(c *C) { + logger := seclog.NewNopLogger() + c.Assert(logger, NotNil) + + // nop logger discards all messages without error + logger.LogLoginSuccess(seclog.SnapdUser{StoreUserEmail: "user@gmail.com"}) +} + +func (s *NopSuite) TestLogLoginFailure(c *C) { + logger := seclog.NewNopLogger() + c.Assert(logger, NotNil) + + // nop logger discards all messages without error + logger.LogLoginFailure(seclog.SnapdUser{StoreUserEmail: "user@gmail.com"}) +} diff --git a/seclog/seclog.go b/seclog/seclog.go new file mode 100644 index 00000000000..ac1aef6e954 --- /dev/null +++ b/seclog/seclog.go @@ -0,0 +1,301 @@ +// -*- Mode: Go; indent-tabs-mode: t -*- + +/* + * Copyright (C) 2026 Canonical Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 3 as + * published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +package seclog + +import ( + "fmt" + "io" + "sync" + "time" +) + +var ( + providers = map[Impl]provider{} + sinks = map[Sink]func(string) (io.Writer, error){} + globalLogger securityLogger = newNopLogger() + globalCloser io.Closer + globalSetup *loggerSetup + lock sync.Mutex +) + +// Level is the importance or severity of a log event. +// The higher the level, the more severe the event. +type Level int + +// Log levels. +const ( + LevelDebug Level = 1 + LevelInfo Level = 2 + LevelWarn Level = 3 + LevelError Level = 4 + LevelCritical Level = 5 +) + +// String returns a name for the level. +// If the level has a name, then that name +// in uppercase is returned. +// If the level is between named values, then +// an integer is appended to the uppercased name. +// Examples: +// +// LevelWarn.String() => "WARN" +// (LevelCritical+2).String() => "CRITICAL+2" +func (l Level) String() string { + str := func(base string, val Level) string { + if val == 0 { + return base + } + return fmt.Sprintf("%s%+d", base, val) + } + + switch { + case l < LevelInfo: + return str("DEBUG", l-LevelDebug) + case l < LevelWarn: + return str("INFO", l-LevelInfo) + case l < LevelError: + return str("WARN", l-LevelWarn) + case l < LevelCritical: + return str("ERROR", l-LevelError) + default: + return str("CRITICAL", l-LevelCritical) + } +} + +// Impl represents a known logger implementation identifier used for +// registration and selection of security loggers. +type Impl string + +// Logger implementations. +const ( + ImplSlog Impl = "slog" // slog based structured logger +) + +// Sink identifies a log output destination. +type Sink string + +// Sink types. +const ( + SinkJournal Sink = "journal" // journald namespace stream + SinkAudit Sink = "audit" // kernel audit via netlink +) + +// SnapdUser represents the identity of a user for security log events. +type SnapdUser struct { + ID int64 `json:"snapd-user-id"` + SystemUserName string `json:"system-user-name,omitempty"` + StoreUserEmail string `json:"store-user-email,omitempty"` + Expiration time.Time `json:"expiration,omitzero"` +} + +// String returns a colon-separated description of the user in the form +// "::". Fields that are unset use +// "unknown" as a placeholder. A zero ID is treated as unset. +func (u SnapdUser) String() string { + const unknown = "unknown" + id := unknown + if u.ID != 0 { + id = fmt.Sprintf("%d", u.ID) + } + email := unknown + if u.StoreUserEmail != "" { + email = u.StoreUserEmail + } + name := unknown + if u.SystemUserName != "" { + name = u.SystemUserName + } + return id + ":" + email + ":" + name +} + +// securityLogger defines the interface for emitting structured security +// audit events. Implementations are created by a [provider] and write +// to a configured sink. +type securityLogger interface { + LogLoggingEnabled() + LogLoggingDisabled() + LogLoginSuccess(user SnapdUser) + LogLoginFailure(user SnapdUser) +} + +// loggerSetup holds the configuration provided to Setup. +type loggerSetup struct { + impl Impl + sink Sink + appID string + minLevel Level +} + +// provider provides functions required for constructing a [securityLogger]. +// It is intended for registration of available loggers. +type provider interface { + // New creates a securityLogger that writes to writer. Messages with a + // severity below minLevel are silently dropped. + New(writer io.Writer, appID string, minLevel Level) securityLogger + // Impl returns the identifier for this provider. + Impl() Impl +} + +// Setup stores the logger configuration and attempts to enable the +// security logger immediately. If the log sink cannot be opened (e.g. +// because the journal namespace is not active yet), the configuration +// is still stored and a non-fatal "security logger disabled" error is +// returned. A subsequent call to Enable will re-attempt activation. +func Setup(impl Impl, sink Sink, appID string, minLevel Level) error { + lock.Lock() + defer lock.Unlock() + + if _, exists := providers[impl]; !exists { + return fmt.Errorf("cannot set up security logger: unknown implementation %q", string(impl)) + } + if _, exists := sinks[sink]; !exists { + return fmt.Errorf("cannot set up security logger: unknown sink %q", string(sink)) + } + globalSetup = &loggerSetup{impl: impl, sink: sink, appID: appID, minLevel: minLevel} + if err := enableLocked(); err != nil { + return fmt.Errorf("security logger disabled") + } + return nil +} + +// Enable opens the security log sink using the configuration stored by Setup, +// activating the security logger. If the sink is already open, it is closed +// and re-opened, refreshing the connection to the journal namespace. +// Returns an error if Setup has not been called or if the sink cannot be opened. +func Enable() error { + lock.Lock() + defer lock.Unlock() + + if globalSetup == nil { + return fmt.Errorf("cannot enable security logger: setup has not been called") + } + return enableLocked() +} + +// Disable closes the security log sink and resets the global logger to nop. +// The stored configuration is retained so that Enable can re-open the sink +// later. It is safe to call even if the logger is already a nop. +func Disable() error { + lock.Lock() + defer lock.Unlock() + if globalSetup == nil { + return nil + } + return closeSinkLocked() +} + +// LogLoggingEnabled logs that security auditing has been enabled. +func LogLoggingEnabled() { + lock.Lock() + defer lock.Unlock() + globalLogger.LogLoggingEnabled() +} + +// LogLoggingDisabled logs that security auditing has been disabled. +func LogLoggingDisabled() { + lock.Lock() + defer lock.Unlock() + globalLogger.LogLoggingDisabled() +} + +// LogLoginSuccess logs a successful login using the global security logger. +func LogLoginSuccess(user SnapdUser) { + lock.Lock() + defer lock.Unlock() + globalLogger.LogLoginSuccess(user) +} + +// LogLoginFailure logs a failed login attempt using the global security logger. +func LogLoginFailure(user SnapdUser) { + lock.Lock() + defer lock.Unlock() + globalLogger.LogLoginFailure(user) +} + +// register makes a provider available by name. +// Should be called from init(). +func register(p provider) { + lock.Lock() + defer lock.Unlock() + impl := p.Impl() + if _, exists := providers[impl]; exists { + panic(fmt.Sprintf("attempting registration for existing logger %q", impl)) + } + providers[impl] = p +} + +// registerSink makes a sink factory available by name. +// Should be called from init(). +func registerSink(name Sink, factory func(string) (io.Writer, error)) { + lock.Lock() + defer lock.Unlock() + if _, exists := sinks[name]; exists { + panic(fmt.Sprintf("attempting registration for existing sink %q", name)) + } + sinks[name] = factory +} + +// enableLocked resolves the provider, opens the sink, and activates the +// logger. Must be called with lock held and globalSetup non-nil. +func enableLocked() error { + provider, exists := providers[globalSetup.impl] + if !exists { + return fmt.Errorf("internal error: provider %q missing", string(globalSetup.impl)) + } + newSink, exists := sinks[globalSetup.sink] + if !exists { + return fmt.Errorf("internal error: sink %q missing", string(globalSetup.sink)) + } + writer, err := openSinkLocked(newSink, globalSetup.appID) + if err != nil { + return fmt.Errorf("cannot enable security logger: %w", err) + } + globalLogger = provider.New(writer, globalSetup.appID, globalSetup.minLevel) + return nil +} + +// openSinkLocked opens the log sink and manages the closer. Any previously +// open sink is closed first. Must be called with lock held. +func openSinkLocked(newSink func(string) (io.Writer, error), appID string) (io.Writer, error) { + writer, err := newSink(appID) + if err != nil { + return nil, err + } + if globalCloser != nil { + globalCloser.Close() + globalCloser = nil + } + if closer, ok := writer.(io.Closer); ok { + globalCloser = closer + } + return writer, nil +} + +// closeSinkLocked closes the security log sink and resets the global logger to +// nop. Must be called with lock held. +func closeSinkLocked() error { + globalLogger = newNopLogger() + if globalCloser != nil { + err := globalCloser.Close() + globalCloser = nil + return err + } + return nil +} diff --git a/seclog/seclog_test.go b/seclog/seclog_test.go new file mode 100644 index 00000000000..cd30a8ba870 --- /dev/null +++ b/seclog/seclog_test.go @@ -0,0 +1,336 @@ +// -*- Mode: Go; indent-tabs-mode: t -*- + +/* + * Copyright (C) 2026 Canonical Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 3 as + * published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +package seclog_test + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "log/syslog" + "testing" + + . "gopkg.in/check.v1" + + "github.com/snapcore/snapd/seclog" + "github.com/snapcore/snapd/testutil" +) + +type SecLogSuite struct { + testutil.BaseTest + buf *bytes.Buffer + appID string +} + +var _ = Suite(&SecLogSuite{}) + +func TestSecLog(t *testing.T) { TestingT(t) } + +func (s *SecLogSuite) SetUpSuite(c *C) { + s.buf = &bytes.Buffer{} + s.appID = "canonical.snapd" +} + +func (s *SecLogSuite) SetUpTest(c *C) { + s.BaseTest.SetUpTest(c) + s.buf.Reset() +} + +func (s *SecLogSuite) TearDownTest(c *C) { + s.BaseTest.TearDownTest(c) +} + +func (s *SecLogSuite) TestString(c *C) { + levels := []seclog.Level{ + seclog.LevelDebug - 1, + seclog.LevelDebug, + seclog.LevelInfo, + seclog.LevelWarn, + seclog.LevelError, + seclog.LevelError + 1, + seclog.LevelCritical, + seclog.LevelCritical + 2, + } + + expected := []string{ + "DEBUG-1", + "DEBUG", + "INFO", + "WARN", + "ERROR", + "CRITICAL", + "CRITICAL", + "CRITICAL+2", + } + + c.Assert(len(levels), Equals, len(expected)) + + obtained := make([]string, 0, len(levels)) + + for _, level := range levels { + obtained = append(obtained, level.String()) + } + + c.Assert(expected, DeepEquals, obtained) +} + +func (s *SecLogSuite) TestSyslogPriority(c *C) { + tests := []struct { + level seclog.Level + expected syslog.Priority + }{ + {seclog.LevelDebug - 1, syslog.LOG_DEBUG}, + {seclog.LevelDebug, syslog.LOG_DEBUG}, + {seclog.LevelInfo, syslog.LOG_INFO}, + {seclog.LevelWarn, syslog.LOG_WARNING}, + {seclog.LevelError, syslog.LOG_ERR}, + {seclog.LevelCritical, syslog.LOG_CRIT}, + {seclog.LevelCritical + 1, syslog.LOG_CRIT}, + } + for _, t := range tests { + c.Check(seclog.SyslogPriority(t.level), Equals, t.expected, + Commentf("level %v", t.level)) + } +} + +func (s *SecLogSuite) TestRegister(c *C) { + restore := seclog.MockProviders(map[seclog.Impl]seclog.Provider{}) + defer restore() + + seclog.Register(seclog.SlogProvider{}) + + // registering the same implementation again panics + c.Assert(func() { seclog.Register(seclog.SlogProvider{}) }, PanicMatches, + `attempting registration for existing logger "slog"`) +} + +func (s *SecLogSuite) TestSetupUnknownImpl(c *C) { + restore := seclog.MockProviders(map[seclog.Impl]seclog.Provider{}) + defer restore() + + err := seclog.Setup("unknown", seclog.SinkJournal, s.appID, seclog.LevelInfo) + c.Assert(err, ErrorMatches, + `cannot set up security logger: unknown implementation "unknown"`) +} + +func (s *SecLogSuite) TestSetupSinkError(c *C) { + restore := seclog.MockNewSink(func(appID string) (io.Writer, error) { + return nil, fmt.Errorf("journal unavailable") + }) + defer restore() + + err := seclog.Setup(seclog.ImplSlog, seclog.SinkJournal, s.appID, seclog.LevelInfo) + c.Assert(err, ErrorMatches, "security logger disabled") +} + +func (s *SecLogSuite) TestSetupSuccess(c *C) { + restore := seclog.MockNewSink(func(appID string) (io.Writer, error) { + c.Check(appID, Equals, s.appID) + return s.buf, nil + }) + defer restore() + + restoreLogger := seclog.MockGlobalLogger(seclog.NewNopLogger()) + defer restoreLogger() + + err := seclog.Setup(seclog.ImplSlog, seclog.SinkJournal, s.appID, seclog.LevelInfo) + c.Assert(err, IsNil) + + // verify the logger is functional by logging through it + seclog.LogLoginSuccess(seclog.SnapdUser{ID: 1, SystemUserName: "testuser"}) + c.Check(s.buf.Len() > 0, Equals, true) +} + +func (s *SecLogSuite) setupSlogLogger(c *C) { + restore := seclog.MockNewSink(func(appID string) (io.Writer, error) { + return s.buf, nil + }) + s.AddCleanup(restore) + + restoreLogger := seclog.MockGlobalLogger(seclog.NewNopLogger()) + s.AddCleanup(restoreLogger) + + err := seclog.Setup(seclog.ImplSlog, seclog.SinkJournal, s.appID, seclog.LevelInfo) + c.Assert(err, IsNil) +} + +func (s *SecLogSuite) TestLogLoginSuccess(c *C) { + s.setupSlogLogger(c) + + user := seclog.SnapdUser{ + ID: 42, + StoreUserEmail: "user@example.com", + SystemUserName: "jdoe", + } + seclog.LogLoginSuccess(user) + + var obtained map[string]any + err := json.Unmarshal(s.buf.Bytes(), &obtained) + c.Assert(err, IsNil) + c.Check(obtained["level"], Equals, "INFO") + c.Check(obtained["description"], Equals, + "User 42:user@example.com:jdoe login success") + c.Check(obtained["app_id"], Equals, s.appID) + c.Check(obtained["category"], Equals, "AUTHN") + c.Check(obtained["event"], Equals, "authn_login_success") + userMap, ok := obtained["user"].(map[string]any) + c.Assert(ok, Equals, true) + c.Check(userMap["snapd-user-id"], Equals, float64(42)) + c.Check(userMap["store-user-email"], Equals, "user@example.com") + c.Check(userMap["system-user-name"], Equals, "jdoe") + c.Check(obtained["type"], Equals, "security") +} + +func (s *SecLogSuite) TestLogLoginFailure(c *C) { + s.setupSlogLogger(c) + + user := seclog.SnapdUser{ + ID: 42, + StoreUserEmail: "user@example.com", + SystemUserName: "jdoe", + } + seclog.LogLoginFailure(user) + + var obtained map[string]any + err := json.Unmarshal(s.buf.Bytes(), &obtained) + c.Assert(err, IsNil) + c.Check(obtained["level"], Equals, "WARN") + c.Check(obtained["description"], Equals, + "User 42:user@example.com:jdoe login failure") + c.Check(obtained["app_id"], Equals, s.appID) + c.Check(obtained["category"], Equals, "AUTHN") + c.Check(obtained["event"], Equals, "authn_login_failure") + userMap, ok := obtained["user"].(map[string]any) + c.Assert(ok, Equals, true) + c.Check(userMap["snapd-user-id"], Equals, float64(42)) + c.Check(userMap["store-user-email"], Equals, "user@example.com") + c.Check(userMap["system-user-name"], Equals, "jdoe") + c.Check(obtained["type"], Equals, "security") +} + +// closeTracker is a test helper that records whether Close was called. +type closeTracker struct { + closed bool + err error +} + +func (ct *closeTracker) Close() error { + ct.closed = true + return ct.err +} + +func (s *SecLogSuite) TestDisableClosesTheSink(c *C) { + tracker := &closeTracker{} + restoreCloser := seclog.MockGlobalCloser(tracker) + defer restoreCloser() + restoreLogger := seclog.MockGlobalLogger(seclog.NewNopLogger()) + defer restoreLogger() + restoreSetup := seclog.MockGlobalSetup( + seclog.NewLoggerSetup(seclog.ImplSlog, seclog.SinkJournal, s.appID, seclog.LevelInfo)) + defer restoreSetup() + + err := seclog.Disable() + c.Assert(err, IsNil) + c.Check(tracker.closed, Equals, true) +} + +func (s *SecLogSuite) TestDisableWithNoSinkReturnsNil(c *C) { + restoreCloser := seclog.MockGlobalCloser(nil) + defer restoreCloser() + restoreLogger := seclog.MockGlobalLogger(seclog.NewNopLogger()) + defer restoreLogger() + + err := seclog.Disable() + c.Assert(err, IsNil) +} + +func (s *SecLogSuite) TestDisableIsIdempotent(c *C) { + tracker := &closeTracker{} + restoreCloser := seclog.MockGlobalCloser(tracker) + defer restoreCloser() + restoreLogger := seclog.MockGlobalLogger(seclog.NewNopLogger()) + defer restoreLogger() + restoreSetup := seclog.MockGlobalSetup( + seclog.NewLoggerSetup(seclog.ImplSlog, seclog.SinkJournal, s.appID, seclog.LevelInfo)) + defer restoreSetup() + + err := seclog.Disable() + c.Assert(err, IsNil) + c.Check(tracker.closed, Equals, true) + + // second call does not error even though closer is now nil + err = seclog.Disable() + c.Assert(err, IsNil) +} + +func (s *SecLogSuite) TestDisablePropagatesError(c *C) { + tracker := &closeTracker{err: fmt.Errorf("disk full")} + restoreCloser := seclog.MockGlobalCloser(tracker) + defer restoreCloser() + restoreLogger := seclog.MockGlobalLogger(seclog.NewNopLogger()) + defer restoreLogger() + restoreSetup := seclog.MockGlobalSetup( + seclog.NewLoggerSetup(seclog.ImplSlog, seclog.SinkJournal, s.appID, seclog.LevelInfo)) + defer restoreSetup() + + err := seclog.Disable() + c.Assert(err, ErrorMatches, "disk full") +} + +// writeCloseTracker is a test helper that implements io.WriteCloser and +// records whether Close was called. +type writeCloseTracker struct { + bytes.Buffer + closed bool +} + +func (wc *writeCloseTracker) Close() error { + wc.closed = true + return nil +} + +func (s *SecLogSuite) TestSetupClosesPreviousSink(c *C) { + first := &writeCloseTracker{} + second := &writeCloseTracker{} + call := 0 + restore := seclog.MockNewSink(func(appID string) (io.Writer, error) { + call++ + if call == 1 { + return first, nil + } + return second, nil + }) + defer restore() + restoreCloser := seclog.MockGlobalCloser(nil) + defer restoreCloser() + restoreLogger := seclog.MockGlobalLogger(seclog.NewNopLogger()) + defer restoreLogger() + + // first setup + err := seclog.Setup(seclog.ImplSlog, seclog.SinkJournal, s.appID, seclog.LevelInfo) + c.Assert(err, IsNil) + c.Check(first.closed, Equals, false) + + // second setup should close the first sink + err = seclog.Setup(seclog.ImplSlog, seclog.SinkJournal, s.appID, seclog.LevelInfo) + c.Assert(err, IsNil) + c.Check(first.closed, Equals, true) + c.Check(second.closed, Equals, false) +} diff --git a/seclog/slog.go b/seclog/slog.go new file mode 100644 index 00000000000..0eda4763e88 --- /dev/null +++ b/seclog/slog.go @@ -0,0 +1,234 @@ +// -*- Mode: Go; indent-tabs-mode: t -*- +//go:build go1.21 && !noslog + +/* + * Copyright (C) 2026 Canonical Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 3 as + * published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +package seclog + +import ( + "context" + "fmt" + "io" + "log/slog" + "sync" + "time" + + "github.com/snapcore/snapd/osutil" +) + +// slogProvider implements [provider]. +type slogProvider struct{} + +// Ensure [slogProvider] implements [provider]. +var _ provider = slogProvider{} + +// New constructs an slog based [securityLogger] that emits structured JSON to the +// provided [io.Writer]. The returned logger enables dynamic level control via +// an internal [slog.LevelVar]. +func (slogProvider) New(writer io.Writer, appID string, minLevel Level) securityLogger { + return newSlogLogger(writer, appID, minLevel) +} + +// Impl returns the implementation. +func (slogProvider) Impl() Impl { + return ImplSlog +} + +func init() { + register(slogProvider{}) +} + +func newSlogLogger(writer io.Writer, appID string, minLevel Level) securityLogger { + levelVar := new(slog.LevelVar) + levelVar.Set(slog.Level(minLevel)) + var handler slog.Handler = newJsonHandler(writer, levelVar) + if lw, ok := writer.(levelWriter); ok { + handler = newLevelHandler(handler, lw) + } + logger := &slogLogger{ + // enable dynamic level adjustment + levelVar: levelVar, + // always include app_id and type + logger: slog.New(handler).With( + slog.String("app_id", appID), + slog.String("type", "security"), + ), + } + return logger +} + +// slogLogger implements [securityLogger] and is constructed by the +// [slogProvider]. It wraps a [slog.Logger] and provides the required +// methods. The logger emits structured JSON with a predefined schema for +// built-in attributes and supports dynamic log level control via an internal +// [slog.LevelVar]. When used with a [levelWriter] sink, it ensures that +// each message is written with the correct severity level. +type slogLogger struct { + logger *slog.Logger + levelVar *slog.LevelVar +} + +// Ensure [slogLogger] implements [securityLogger]. +var _ securityLogger = (*slogLogger)(nil) + +// SlogLogger is a test only helper to retrieve a pointer to the underlying +// [slog.Logger]. +func (l *slogLogger) SlogLogger() *slog.Logger { + osutil.MustBeTestBinary("SlogLogger() is for testing only") + return l.logger +} + +// LogLoggingEnabled implements [securityLogger.LogLoggingEnabled]. +func (l *slogLogger) LogLoggingEnabled() { + l.logger.LogAttrs( + context.Background(), + slog.Level(LevelInfo), + "Security auditing enabled", + slog.Attr{Key: "category", Value: slog.StringValue("SYS")}, + slog.Attr{Key: "event", Value: slog.StringValue("sys_logging_enabled")}, + ) +} + +// LogLoggingDisabled implements [securityLogger.LogLoggingDisabled]. +func (l *slogLogger) LogLoggingDisabled() { + l.logger.LogAttrs( + context.Background(), + slog.Level(LevelCritical), + "Security auditing disabled", + slog.Attr{Key: "category", Value: slog.StringValue("SYS")}, + slog.Attr{Key: "event", Value: slog.StringValue("sys_logging_disabled")}, + ) +} + +// LogLoginSuccess implements [securityLogger.LogLoginSuccess]. +func (l *slogLogger) LogLoginSuccess(user SnapdUser) { + l.logger.LogAttrs( + context.Background(), + slog.Level(LevelInfo), + fmt.Sprintf("User %s login success", user.String()), + slog.Attr{Key: "category", Value: slog.StringValue("AUTHN")}, + slog.Attr{Key: "event", Value: slog.StringValue("authn_login_success")}, + slog.Any("user", user), + ) +} + +// LogLoginFailure implements [securityLogger.LogLoginFailure]. +func (l *slogLogger) LogLoginFailure(user SnapdUser) { + l.logger.LogAttrs( + context.Background(), + slog.Level(LevelWarn), + fmt.Sprintf("User %s login failure", user.String()), + slog.Attr{Key: "category", Value: slog.StringValue("AUTHN")}, + slog.Attr{Key: "event", Value: slog.StringValue("authn_login_failure")}, + slog.Any("user", user), + ) +} + +// LogValue implements [slog.LogValuer], allowing SnapdUser to be +// used directly as a structured log attribute value. +func (u SnapdUser) LogValue() slog.Value { + return slog.GroupValue( + slog.Int64("snapd-user-id", u.ID), + slog.String("system-user-name", u.SystemUserName), + slog.String("store-user-email", u.StoreUserEmail), + slog.String("expiration", u.Expiration.UTC().Format(time.RFC3339Nano)), + ) +} + +// newJsonHandler returns a slog JSON handler configured for security logs. +// +// It writes newline-delimited JSON to writer and enforces a schema for the +// built-in attributes: +// - time: key "datetime", formatted in UTC using [time.RFC3339Nano] +// - level: rendered as a string via [Level.String] +// - message: key "description" +// - app_id: always included with the value provided to newSlogLogger +// - type: always included with the value "security" +// +// Additional attributes are preserved verbatim, including nested groups. The +// handler logs at or above the minLevel threshold. It does not +// close or sync writer. +func newJsonHandler(writer io.Writer, minLevel slog.Leveler) slog.Handler { + options := &slog.HandlerOptions{ + Level: minLevel, + ReplaceAttr: func(groups []string, attr slog.Attr) slog.Attr { + switch attr.Key { + case slog.TimeKey: + // use "datetime" instead of default "time" + attr.Key = "datetime" + if t, ok := attr.Value.Any().(time.Time); ok { + // convert to formatted string + attr.Value = slog.StringValue(t.UTC().Format(time.RFC3339Nano)) + } + case slog.LevelKey: + if l, ok := attr.Value.Any().(slog.Level); ok { + attr.Value = slog.StringValue(Level(l).String()) + } + case slog.MessageKey: + // use "description" instead of default "msg" + attr.Key = "description" + } + return attr + }, + } + + return slog.NewJSONHandler(writer, options) +} + +// levelWriter extends [io.Writer] with per-message level control. Writers +// that implement this interface allow log handlers to set the severity for +// each message before writing. +type levelWriter interface { + io.Writer + SetLevel(Level) +} + +// levelHandler is a [slog.Handler] wrapper that sets the level on a +// [levelWriter] before each message is handled. This ensures that the +// written output carries the correct per-message priority. +// +// All derived handlers returned by WithAttrs and WithGroup share the same +// [levelWriter] and mutex, since they write to the same sink. +type levelHandler struct { + inner slog.Handler + lw levelWriter + mu *sync.Mutex +} + +func newLevelHandler(inner slog.Handler, lw levelWriter) slog.Handler { + return &levelHandler{inner: inner, lw: lw, mu: &sync.Mutex{}} +} + +func (h *levelHandler) Enabled(ctx context.Context, level slog.Level) bool { + return h.inner.Enabled(ctx, level) +} + +func (h *levelHandler) Handle(ctx context.Context, r slog.Record) error { + h.mu.Lock() + defer h.mu.Unlock() + h.lw.SetLevel(Level(r.Level)) + return h.inner.Handle(ctx, r) +} + +func (h *levelHandler) WithAttrs(attrs []slog.Attr) slog.Handler { + return &levelHandler{inner: h.inner.WithAttrs(attrs), lw: h.lw, mu: h.mu} +} + +func (h *levelHandler) WithGroup(name string) slog.Handler { + return &levelHandler{inner: h.inner.WithGroup(name), lw: h.lw, mu: h.mu} +} diff --git a/seclog/slog_test.go b/seclog/slog_test.go new file mode 100644 index 00000000000..83241f29482 --- /dev/null +++ b/seclog/slog_test.go @@ -0,0 +1,306 @@ +// -*- Mode: Go; indent-tabs-mode: t -*- +//go:build go1.21 && !noslog + +/* + * Copyright (C) 2026 Canonical Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 3 as + * published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +package seclog_test + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "testing" + "time" + + "log/slog" + + . "gopkg.in/check.v1" + + "github.com/snapcore/snapd/seclog" + "github.com/snapcore/snapd/testutil" +) + +type SlogSuite struct { + testutil.BaseTest + buf *bytes.Buffer + appID string + provider seclog.Provider +} + +var _ = Suite(&SlogSuite{}) + +func TestSlog(t *testing.T) { TestingT(t) } + +func (s *SlogSuite) SetUpSuite(c *C) { + s.buf = &bytes.Buffer{} + s.appID = "canonical.snapd" + s.provider = seclog.SlogProvider{} +} + +func (s *SlogSuite) SetUpTest(c *C) { + s.BaseTest.SetUpTest(c) + s.buf.Reset() +} + +func (s *SlogSuite) TearDownTest(c *C) { + s.BaseTest.TearDownTest(c) +} + +// extractSlogLogger is a test helper to extract the internal [slog.Logger] from +// SecurityLogger. +func extractSlogLogger(logger seclog.SecurityLogger) (*slog.Logger, error) { + if l, ok := logger.(*seclog.SlogLogger); !ok { + return nil, errors.New("cannot extract slog logger") + } else { + // return the internal slog logger + return l.SlogLogger(), nil + } +} + +func (s *SlogSuite) TestSlogProvider(c *C) { + logger := s.provider.New(s.buf, s.appID, seclog.LevelInfo) + c.Check(logger, NotNil) + + impl := s.provider.Impl() + c.Check(impl, Equals, seclog.ImplSlog) +} + +// baseAttrs represents the non-optional attributes that is present in +// every record +type baseAttrs struct { + Datetime time.Time `json:"datetime"` + Level string `json:"level"` + Description string `json:"description"` + AppID string `json:"app_id"` + Type string `json:"type"` + Category string `json:"category"` +} + +// orderedKeys extracts the top-level JSON object keys in order. +func orderedKeys(data []byte) ([]string, error) { + decoder := json.NewDecoder(bytes.NewReader(data)) + // consume opening '{' + token, err := decoder.Token() + if err != nil { + return nil, err + } + if delim, ok := token.(json.Delim); !ok || delim != '{' { + return nil, errors.New("expected '{' delimiter") + } + var keys []string + for decoder.More() { + token, err = decoder.Token() + if err != nil { + return nil, err + } + key, ok := token.(string) + if !ok { + return nil, errors.New("expected string key") + } + keys = append(keys, key) + // skip value + var raw json.RawMessage + if err := decoder.Decode(&raw); err != nil { + return nil, err + } + } + return keys, nil +} + +type attrsAllTypes struct { + baseAttrs + String string `json:"string"` + Duration time.Duration `json:"duration"` + Timestamp time.Time `json:"timestamp"` + Float64 float64 `json:"float64"` + Int64 int64 `json:"int64"` + Int int `json:"int"` + Uint64 uint64 `json:"uint64"` + Any any `json:"any"` +} + +func (s *SlogSuite) TestHandlerAttrsAllTypes(c *C) { + logger := s.provider.New(s.buf, s.appID, seclog.LevelInfo) + c.Assert(logger, NotNil) + + sl, err := extractSlogLogger(logger) + c.Assert(err, IsNil) + sl.LogAttrs( + context.Background(), + slog.Level(seclog.LevelInfo), + "test description", + slog.Attr{Key: "category", Value: slog.StringValue("AUTHN")}, + slog.Attr{Key: "string", Value: slog.StringValue("test string")}, + slog.Attr{Key: "duration", Value: slog.DurationValue(time.Duration(90 * time.Second))}, + slog.Attr{ + Key: "timestamp", + Value: slog.TimeValue(time.Date(2025, 10, 8, 8, 0, 0, 0, time.UTC)), + }, + slog.Attr{Key: "float64", Value: slog.Float64Value(3.141592653589793)}, + slog.Attr{Key: "int64", Value: slog.Int64Value(-4611686018427387904)}, + slog.Attr{Key: "int", Value: slog.IntValue(-4294967296)}, + slog.Attr{Key: "uint64", Value: slog.Uint64Value(4294967295)}, + // AnyValue returns value of KindInt64, the original + // numeric type is not preserved + slog.Attr{Key: "any", Value: slog.AnyValue(map[string]any{"k": "v", "n": int(1)})}, + ) + + var obtained attrsAllTypes + err = json.Unmarshal(s.buf.Bytes(), &obtained) + c.Assert(err, IsNil) + + c.Check(time.Since(obtained.Datetime) < time.Second, Equals, true) + c.Check(obtained.Level, Equals, "INFO") + c.Check(obtained.Description, Equals, "test description") + c.Check(obtained.AppID, Equals, s.appID) + c.Check(obtained.Type, Equals, "security") + c.Check(obtained.Category, Equals, "AUTHN") + + c.Check(obtained.String, Equals, "test string") + c.Check(obtained.Duration, Equals, time.Duration(90*time.Second)) + c.Check(obtained.Timestamp, Equals, time.Date(2025, 10, 8, 8, 0, 0, 0, time.UTC)) + c.Check(obtained.Float64, Equals, float64(3.141592653589793)) + c.Check(obtained.Int64, Equals, int64(-4611686018427387904)) + c.Check(obtained.Int, Equals, int(-4294967296)) + c.Check(obtained.Uint64, Equals, uint64(4294967295)) + c.Check(obtained.Any, DeepEquals, map[string]any{"k": "v", "n": float64(1)}) +} + +func (s *SlogSuite) TestLogLoginSuccess(c *C) { + logger := s.provider.New(s.buf, s.appID, seclog.LevelInfo) + c.Assert(logger, NotNil) + + type LoginSuccess struct { + baseAttrs + Event string `json:"event"` + User struct { + ID int64 `json:"snapd-user-id"` + SystemUserName string `json:"system-user-name"` + StoreUserEmail string `json:"store-user-email"` + Expiration string `json:"expiration"` + } `json:"user"` + } + + user := seclog.SnapdUser{ + ID: 42, + StoreUserEmail: "user@gmail.com", + SystemUserName: "jdoe", + } + logger.LogLoginSuccess(user) + + var obtained LoginSuccess + err := json.Unmarshal(s.buf.Bytes(), &obtained) + c.Assert(err, IsNil) + c.Check(time.Since(obtained.Datetime) < time.Second, Equals, true) + c.Check(obtained.Level, Equals, "INFO") + c.Check(obtained.Description, Equals, "User 42:user@gmail.com:jdoe login success") + c.Check(obtained.AppID, Equals, s.appID) + c.Check(obtained.Event, Equals, "authn_login_success") + c.Check(obtained.User.ID, Equals, int64(42)) + c.Check(obtained.User.StoreUserEmail, Equals, "user@gmail.com") + c.Check(obtained.User.SystemUserName, Equals, "jdoe") + + // verify key order for human readability + keys, err := orderedKeys(s.buf.Bytes()) + c.Assert(err, IsNil) + c.Check(keys, DeepEquals, []string{ + "datetime", "level", "description", + "app_id", "type", "category", "event", "user", + }) +} + +func (s *SlogSuite) TestLogLoginFailure(c *C) { + logger := s.provider.New(s.buf, s.appID, seclog.LevelInfo) + c.Assert(logger, NotNil) + + type loginFailure struct { + baseAttrs + Event string `json:"event"` + User struct { + ID int64 `json:"snapd-user-id"` + SystemUserName string `json:"system-user-name"` + StoreUserEmail string `json:"store-user-email"` + Expiration string `json:"expiration"` + } `json:"user"` + } + + user := seclog.SnapdUser{ + ID: 42, + StoreUserEmail: "user@gmail.com", + SystemUserName: "jdoe", + } + logger.LogLoginFailure(user) + + var obtained loginFailure + err := json.Unmarshal(s.buf.Bytes(), &obtained) + c.Assert(err, IsNil) + c.Check(time.Since(obtained.Datetime) < time.Second, Equals, true) + c.Check(obtained.Level, Equals, "WARN") + c.Check(obtained.Description, Equals, "User 42:user@gmail.com:jdoe login failure") + c.Check(obtained.AppID, Equals, s.appID) + c.Check(obtained.Event, Equals, "authn_login_failure") + c.Check(obtained.User.ID, Equals, int64(42)) + c.Check(obtained.User.StoreUserEmail, Equals, "user@gmail.com") + c.Check(obtained.User.SystemUserName, Equals, "jdoe") + + // verify key order for human readability + keys, err := orderedKeys(s.buf.Bytes()) + c.Assert(err, IsNil) + c.Check(keys, DeepEquals, []string{ + "datetime", "level", "description", + "app_id", "type", "category", "event", "user", + }) +} + +func (s *SlogSuite) TestLevelWriterSink(c *C) { + // wrap buffer in a journalWriter to exercise the levelWriter + // branch in newSlogLogger and the levelHandler wrapper + jw := seclog.NewJournalWriter(s.buf) + logger := s.provider.New(jw, s.appID, seclog.LevelInfo) + c.Assert(logger, NotNil) + + admin := seclog.SnapdUser{ + ID: 1, + SystemUserName: "admin", + } + logger.LogLoginSuccess(admin) + + // the journalWriter prepends a syslog priority prefix + raw := s.buf.String() + // INFO maps to syslog.LOG_INFO (6) + c.Check(raw[:3], Equals, "<6>") + + // the JSON payload follows the prefix + var obtained map[string]any + err := json.Unmarshal([]byte(raw[3:]), &obtained) + c.Assert(err, IsNil) + c.Check(obtained["level"], Equals, "INFO") + c.Check(obtained["event"], Equals, "authn_login_success") + userMap, ok := obtained["user"].(map[string]any) + c.Assert(ok, Equals, true) + c.Check(userMap["system-user-name"], Equals, "admin") + + // log a WARN-level message and verify the prefix changes + s.buf.Reset() + logger.LogLoginFailure(admin) + + raw = s.buf.String() + // WARN maps to syslog.LOG_WARNING (4) + c.Check(raw[:3], Equals, "<4>") +} diff --git a/wrappers/core18.go b/wrappers/core18.go index d0ec1f8ac2c..9062546cbe4 100644 --- a/wrappers/core18.go +++ b/wrappers/core18.go @@ -379,9 +379,75 @@ func AddSnapdSnapServices(s *snap.Info, opts *AddSnapdSnapServicesOptions, inter return err } + // Handle the security log journal namespace + if err := writeSnapdSecurityJournalOnCore(s, sysd); err != nil { + return err + } + return nil } +const securityJournalConfFile = "journald@snapd-security.conf" +const securityJournalDropInDir = "systemd-journald@snapd-security.service.d" +const securityJournalDropInFile = "00-snapd.conf" +const securityJournalSnapdDropInDir = "snapd.service.d" +const securityJournalSnapdDropInFile = "security-journal.conf" + +// writeSnapdSecurityJournalOnCore installs the journald namespace +// configuration and drop-in files for the snapd security log from +// the snapd snap onto the host filesystem. The namespace socket is +// pulled in via a snapd.service.d drop-in (Wants + After) so that +// it starts alongside snapd but cannot cause snapd to fail. +func writeSnapdSecurityJournalOnCore(s *snap.Info, sysd systemd.Systemd) error { + // Install the journald namespace config to /etc/systemd/ + srcConf := filepath.Join(s.MountDir(), "etc/systemd", securityJournalConfFile) + dstConf := filepath.Join(dirs.SnapSystemdDir, securityJournalConfFile) + if err := copyFileIfChanged(srcConf, dstConf); err != nil { + return fmt.Errorf("cannot install security journal config: %v", err) + } + + // Install the journald service drop-in to /etc/systemd/system/ + srcDropIn := filepath.Join(s.MountDir(), "lib/systemd/system", securityJournalDropInDir, securityJournalDropInFile) + dstDropInDir := filepath.Join(dirs.SnapServicesDir, securityJournalDropInDir) + if err := os.MkdirAll(dstDropInDir, 0755); err != nil { + return fmt.Errorf("cannot create security journal drop-in dir: %v", err) + } + dstDropIn := filepath.Join(dstDropInDir, securityJournalDropInFile) + if err := copyFileIfChanged(srcDropIn, dstDropIn); err != nil { + return fmt.Errorf("cannot install security journal service drop-in: %v", err) + } + + // Install the snapd.service drop-in to pull in the journal socket + srcSnapdDropIn := filepath.Join(s.MountDir(), "lib/systemd/system", securityJournalSnapdDropInDir, securityJournalSnapdDropInFile) + dstSnapdDropInDir := filepath.Join(dirs.SnapServicesDir, securityJournalSnapdDropInDir) + if err := os.MkdirAll(dstSnapdDropInDir, 0755); err != nil { + return fmt.Errorf("cannot create snapd.service.d dir: %v", err) + } + dstSnapdDropIn := filepath.Join(dstSnapdDropInDir, securityJournalSnapdDropInFile) + if err := copyFileIfChanged(srcSnapdDropIn, dstSnapdDropIn); err != nil { + return fmt.Errorf("cannot install snapd service drop-in for security journal: %v", err) + } + + return sysd.DaemonReload() +} + +// copyFileIfChanged copies src to dst only if the contents differ or +// dst does not exist. Returns nil if dst is already up to date. +func copyFileIfChanged(src, dst string) error { + srcContent, err := os.ReadFile(src) + if err != nil { + return err + } + err = osutil.EnsureFileState(dst, &osutil.MemoryFileState{ + Content: srcContent, + Mode: 0644, + }) + if err == osutil.ErrSameState { + return nil + } + return err +} + // undoSnapdUserServicesOnCore attempts to remove services that were deployed in // the filesystem as part of snapd snap installation. This should only be // executed as part of a controlled undo path. diff --git a/wrappers/core18_test.go b/wrappers/core18_test.go index 595dc0860c4..6c1eb81ffe4 100644 --- a/wrappers/core18_test.go +++ b/wrappers/core18_test.go @@ -124,6 +124,19 @@ func makeMockSnapdSnapWithOverrides(c *C, metaSnapYaml string, extra [][]string) "[Desktop Entry]\n" + "Name=Handler for snap:// URIs", }, + // security journal namespace files + {"etc/systemd/journald@snapd-security.conf", "" + + "[Journal]\nStorage=persistent\nCompress=yes\n" + + "SystemMaxFileSize=10M\nSystemMaxUse=10M\n" + + "SyncIntervalSec=30s\nSyncOnShutdown=yes\n", + }, + {"lib/systemd/system/systemd-journald@snapd-security.service.d/00-snapd.conf", "" + + "[Service]\nLogsDirectory=\n", + }, + {"lib/systemd/system/snapd.service.d/security-journal.conf", "" + + "[Unit]\nWants=systemd-journald@snapd-security.socket\n" + + "After=systemd-journald@snapd-security.socket\n", + }, } content := append(defaultContent, extra...) @@ -245,6 +258,21 @@ WantedBy=snapd.service }, { filepath.Join(dirs.SnapDesktopFilesDir, "snap-handle-link.desktop"), "[Desktop Entry]\nName=Handler for snap:// URIs", + }, { + // check that security journal config is installed + filepath.Join(dirs.SnapSystemdDir, "journald@snapd-security.conf"), + "[Journal]\nStorage=persistent\nCompress=yes\n" + + "SystemMaxFileSize=10M\nSystemMaxUse=10M\n" + + "SyncIntervalSec=30s\nSyncOnShutdown=yes\n", + }, { + // check that security journal service drop-in is installed + filepath.Join(dirs.SnapServicesDir, "systemd-journald@snapd-security.service.d/00-snapd.conf"), + "[Service]\nLogsDirectory=\n", + }, { + // check that snapd.service.d drop-in for security journal is installed + filepath.Join(dirs.SnapServicesDir, "snapd.service.d/security-journal.conf"), + "[Unit]\nWants=systemd-journald@snapd-security.socket\n" + + "After=systemd-journald@snapd-security.socket\n", }} { c.Check(entry[0], testutil.FileEquals, entry[1]) } @@ -278,6 +306,8 @@ WantedBy=snapd.service {"--user", "--global", "--no-reload", "disable", "snapd.session-agent.socket"}, {"--user", "--global", "--no-reload", "enable", "snapd.session-agent.socket"}, {"--user", "daemon-reload"}, + // security journal files installed + {"daemon-reload"}, }) } @@ -321,6 +351,8 @@ type: snapd {"--user", "--global", "--no-reload", "disable", "snapd.session-agent.socket"}, {"--user", "--global", "--no-reload", "enable", "snapd.session-agent.socket"}, {"--user", "daemon-reload"}, + // security journal files installed + {"daemon-reload"}, } s.testAddSnapServicesOperationsWithQuirks(c, quirkySnapdYaml, extras, expectedOps) @@ -372,6 +404,8 @@ type: snapd {"--user", "--global", "--no-reload", "disable", "snapd.session-agent.socket"}, {"--user", "--global", "--no-reload", "enable", "snapd.session-agent.socket"}, {"--user", "daemon-reload"}, + // security journal files installed + {"daemon-reload"}, } s.testAddSnapServicesOperationsWithQuirks(c, quirkySnapdYaml, extras, expectedOps) From 62793aa565bf49ffeb30ac1a2ab9965d89ce8dd5 Mon Sep 17 00:00:00 2001 From: ernestl Date: Wed, 22 Apr 2026 22:38:05 +0200 Subject: [PATCH 20/21] many: remove journal sink --- cmd/snapd/export_test.go | 22 +- cmd/snapd/main.go | 21 +- cmd/snapd/main_test.go | 43 ++ daemon/api_users.go | 51 ++- daemon/api_users_test.go | 59 +++ daemon/export_api_users_test.go | 4 + data/systemd/Makefile | 10 - data/systemd/journald@snapd-security.conf | 12 - .../snapd.service.d/security-journal.conf | 6 - .../00-snapd.conf | 5 - .../configstate/configcore/export_test.go | 8 - overlord/configstate/configcore/handlers.go | 3 - .../configcore/security_logging.go | 218 ----------- .../configcore/security_logging_test.go | 300 -------------- seclog/audit.go | 115 ------ seclog/audit_linux.go | 197 ++++++++++ seclog/audit_linux_test.go | 320 +++++++++++++++ seclog/export_audit_linux_test.go | 49 +++ seclog/export_slog_test.go | 5 +- seclog/export_test.go | 52 ++- seclog/journal.go | 134 ------- seclog/journal_test.go | 138 ------- seclog/nop.go | 2 +- seclog/nop_test.go | 6 +- seclog/seclog.go | 245 +++++++++--- seclog/seclog_test.go | 366 ++++++++++++++++-- seclog/slog.go | 50 +-- seclog/slog_test.go | 120 +++--- tests/main/security-logging/task.yaml | 39 ++ wrappers/core18.go | 66 ---- wrappers/core18_test.go | 34 -- 31 files changed, 1445 insertions(+), 1255 deletions(-) delete mode 100644 data/systemd/journald@snapd-security.conf delete mode 100644 data/systemd/snapd.service.d/security-journal.conf delete mode 100644 data/systemd/systemd-journald@snapd-security.service.d/00-snapd.conf delete mode 100644 overlord/configstate/configcore/security_logging.go delete mode 100644 overlord/configstate/configcore/security_logging_test.go delete mode 100644 seclog/audit.go create mode 100644 seclog/audit_linux.go create mode 100644 seclog/audit_linux_test.go create mode 100644 seclog/export_audit_linux_test.go delete mode 100644 seclog/journal.go delete mode 100644 seclog/journal_test.go create mode 100644 tests/main/security-logging/task.yaml diff --git a/cmd/snapd/export_test.go b/cmd/snapd/export_test.go index dc67d86fbae..2bc816fd88d 100644 --- a/cmd/snapd/export_test.go +++ b/cmd/snapd/export_test.go @@ -21,10 +21,14 @@ package main import ( "time" + + "github.com/snapcore/snapd/seclog" ) var ( - Run = run + Run = run + SetupSecurityLogger = setupSecurityLogger + DisableSecurityLogger = disableSecurityLogger ) func MockSyscheckCheckSystem(f func() error) (restore func()) { @@ -35,6 +39,22 @@ func MockSyscheckCheckSystem(f func() error) (restore func()) { } } +func MockSeclogSetup(f func(seclog.Impl, seclog.Sink, string, seclog.Level) error) (restore func()) { + old := seclogSetup + seclogSetup = f + return func() { + seclogSetup = old + } +} + +func MockSeclogDisable(f func() error) (restore func()) { + old := seclogDisable + seclogDisable = f + return func() { + seclogDisable = old + } +} + func MockCheckRunningConditionsRetryDelay(d time.Duration) (restore func()) { oldCheckRunningConditionsRetryDelay := checkRunningConditionsRetryDelay checkRunningConditionsRetryDelay = d diff --git a/cmd/snapd/main.go b/cmd/snapd/main.go index 9f7c6dea01e..7214ecfc43b 100644 --- a/cmd/snapd/main.go +++ b/cmd/snapd/main.go @@ -42,20 +42,33 @@ import ( var ( syscheckCheckSystem = syscheck.CheckSystem + seclogSetup = seclog.Setup + seclogDisable = seclog.Disable ) const secLogAppID = "canonical.snapd.snapd" const secLogMinLevel seclog.Level = seclog.LevelInfo -func init() { - logger.SimpleSetup(nil) +func setupSecurityLogger() { + if err := seclogSetup(seclog.ImplSlog, seclog.SinkAudit, secLogAppID, secLogMinLevel); err != nil { + logger.Noticef("WARNING: %v", err) + } +} - if err := seclog.Setup(seclog.ImplSlog, seclog.SinkJournal, secLogAppID, secLogMinLevel); err != nil { - logger.Noticef("%v", err) +func disableSecurityLogger() { + if err := seclogDisable(); err != nil { + logger.Noticef("WARNING: cannot disable security logger: %v", err) } } +func init() { + logger.SimpleSetup(nil) + setupSecurityLogger() +} + func main() { + defer disableSecurityLogger() + // When preseeding re-exec is not used if snapdenv.Preseeding() { logger.Noticef("running for preseeding") diff --git a/cmd/snapd/main_test.go b/cmd/snapd/main_test.go index bb69327bc83..041ee3fa0fb 100644 --- a/cmd/snapd/main_test.go +++ b/cmd/snapd/main_test.go @@ -35,6 +35,7 @@ import ( "github.com/snapcore/snapd/interfaces/seccomp" "github.com/snapcore/snapd/logger" "github.com/snapcore/snapd/osutil" + "github.com/snapcore/snapd/seclog" "github.com/snapcore/snapd/testutil" ) @@ -60,6 +61,48 @@ func (s *snapdSuite) SetUpTest(c *C) { s.AddCleanup(restore) } +func (s *snapdSuite) TestSetupSecurityLoggerWarnsOnError(c *C) { + logbuf, restore := logger.MockLogger() + defer restore() + + restore = snapd.MockSeclogSetup(func(impl seclog.Impl, sink seclog.Sink, appID string, level seclog.Level) error { + return fmt.Errorf("security logger disabled: cannot open audit socket: permission denied") + }) + defer restore() + + snapd.SetupSecurityLogger() + + c.Check(logbuf.String(), testutil.Contains, "WARNING: security logger disabled: cannot open audit socket: permission denied") +} + +func (s *snapdSuite) TestDisableSecurityLoggerCallsDisable(c *C) { + _, restore := logger.MockLogger() + defer restore() + + disabled := false + restore = snapd.MockSeclogDisable(func() error { + disabled = true + return nil + }) + defer restore() + + snapd.DisableSecurityLogger() + c.Check(disabled, Equals, true) +} + +func (s *snapdSuite) TestDisableSecurityLoggerWarnsOnError(c *C) { + logbuf, restore := logger.MockLogger() + defer restore() + + restore = snapd.MockSeclogDisable(func() error { + return fmt.Errorf("audit socket busy") + }) + defer restore() + + snapd.DisableSecurityLogger() + c.Check(logbuf.String(), testutil.Contains, "WARNING: cannot disable security logger: audit socket busy") +} + func (s *snapdSuite) TestSyscheckFailGoesIntoDegradedMode(c *C) { logbuf, restore := logger.MockLogger() defer restore() diff --git a/daemon/api_users.go b/daemon/api_users.go index 23648f04faa..ba27164d476 100644 --- a/daemon/api_users.go +++ b/daemon/api_users.go @@ -71,6 +71,7 @@ var ( deviceStateRemoveUser = devicestate.RemoveUser seclogLogLoginSuccess = seclog.LogLoginSuccess + seclogLogLoginFailure = seclog.LogLoginFailure ) // userResponseData contains the data releated to user creation/login/query @@ -86,6 +87,17 @@ type userResponseData struct { var isEmailish = regexp.MustCompile(`.@.*\..`).MatchString +// loginError logs a login failure to the security audit log and returns resp +// unchanged. It is a convenience wrapper so that each error return path in +// loginUser can log with a single call. +func loginError(resp *apiError, snapdUser seclog.SnapdUser, code string) *apiError { + seclogLogLoginFailure(snapdUser, seclog.Reason{ + Code: code, + Message: resp.Message, + }) + return resp +} + func loginUser(c *Command, r *http.Request, user *auth.UserState) Response { var loginData struct { Username string `json:"username"` @@ -119,41 +131,49 @@ func loginUser(c *Command, r *http.Request, user *auth.UserState) Response { } } + // Build the user identity for security audit logging. At this + // point we know the email and optional username; the numeric ID + // is only available after successful authentication. + snapdUser := seclog.SnapdUser{ + SystemUserName: loginData.Username, + StoreUserEmail: loginData.Email, + } + overlord := c.d.overlord st := overlord.State() theStore := storeFrom(c.d) macaroon, discharge, err := theStore.LoginUser(loginData.Email, loginData.Password, loginData.Otp) switch err { case store.ErrAuthenticationNeeds2fa: - return &apiError{ + return loginError(&apiError{ Status: 401, Message: err.Error(), Kind: client.ErrorKindTwoFactorRequired, - } + }, snapdUser, seclog.ReasonTwoFactorRequired) case store.Err2faFailed: - return &apiError{ + return loginError(&apiError{ Status: 401, Message: err.Error(), Kind: client.ErrorKindTwoFactorFailed, - } + }, snapdUser, seclog.ReasonTwoFactorFailed) default: switch err := err.(type) { case store.InvalidAuthDataError: - return &apiError{ + return loginError(&apiError{ Status: 400, Message: err.Error(), Kind: client.ErrorKindInvalidAuthData, Value: err, - } + }, snapdUser, seclog.ReasonInvalidAuthData) case store.PasswordPolicyError: - return &apiError{ + return loginError(&apiError{ Status: 401, Message: err.Error(), Kind: client.ErrorKindPasswordPolicy, Value: err, - } + }, snapdUser, seclog.ReasonPasswordPolicy) } - return Unauthorized(err.Error()) + return loginError(Unauthorized(err.Error()), snapdUser, seclog.ReasonInvalidCredentials) case nil: // continue } @@ -175,15 +195,14 @@ func loginUser(c *Command, r *http.Request, user *auth.UserState) Response { } st.Unlock() if err != nil { - return InternalError("cannot persist authentication details: %v", err) + return loginError(InternalError("cannot persist authentication details: %v", err), snapdUser, seclog.ReasonInternal) } - seclogLogLoginSuccess(seclog.SnapdUser{ - ID: int64(user.ID), - SystemUserName: user.Username, - StoreUserEmail: user.Email, - Expiration: user.Expiration, - }) + snapdUser.ID = int64(user.ID) + snapdUser.SystemUserName = user.Username + snapdUser.StoreUserEmail = user.Email + snapdUser.Expiration = user.Expiration + seclogLogLoginSuccess(snapdUser) result := userResponseData{ ID: user.ID, diff --git a/daemon/api_users_test.go b/daemon/api_users_test.go index b0e645e06ad..2e0ef43b241 100644 --- a/daemon/api_users_test.go +++ b/daemon/api_users_test.go @@ -425,6 +425,13 @@ func (s *userSuite) TestLoginUserDeveloperAPIError(c *check.C) { func (s *userSuite) TestLoginUserTwoFactorRequiredError(c *check.C) { s.expectLoginAccess() + var loggedUser seclog.SnapdUser + var loggedReason seclog.Reason + s.AddCleanup(daemon.MockSeclogLogLoginFailure(func(user seclog.SnapdUser, reason seclog.Reason) { + loggedUser = user + loggedReason = reason + })) + s.err = store.ErrAuthenticationNeeds2fa buf := bytes.NewBufferString(`{"username": "email@.com", "password": "password"}`) req, err := http.NewRequest("POST", "/v2/login", buf) @@ -433,11 +440,21 @@ func (s *userSuite) TestLoginUserTwoFactorRequiredError(c *check.C) { rspe := s.errorReq(c, req, nil, actionIsExpected) c.Check(rspe.Status, check.Equals, 401) c.Check(rspe.Kind, check.Equals, client.ErrorKindTwoFactorRequired) + + c.Check(loggedUser.StoreUserEmail, check.Equals, "email@.com") + c.Check(loggedReason.Code, check.Equals, seclog.ReasonTwoFactorRequired) } func (s *userSuite) TestLoginUserTwoFactorFailedError(c *check.C) { s.expectLoginAccess() + var loggedUser seclog.SnapdUser + var loggedReason seclog.Reason + s.AddCleanup(daemon.MockSeclogLogLoginFailure(func(user seclog.SnapdUser, reason seclog.Reason) { + loggedUser = user + loggedReason = reason + })) + s.err = store.Err2faFailed buf := bytes.NewBufferString(`{"username": "email@.com", "password": "password"}`) req, err := http.NewRequest("POST", "/v2/login", buf) @@ -446,11 +463,21 @@ func (s *userSuite) TestLoginUserTwoFactorFailedError(c *check.C) { rspe := s.errorReq(c, req, nil, actionIsExpected) c.Check(rspe.Status, check.Equals, 401) c.Check(rspe.Kind, check.Equals, client.ErrorKindTwoFactorFailed) + + c.Check(loggedUser.StoreUserEmail, check.Equals, "email@.com") + c.Check(loggedReason.Code, check.Equals, seclog.ReasonTwoFactorFailed) } func (s *userSuite) TestLoginUserInvalidCredentialsError(c *check.C) { s.expectLoginAccess() + var loggedUser seclog.SnapdUser + var loggedReason seclog.Reason + s.AddCleanup(daemon.MockSeclogLogLoginFailure(func(user seclog.SnapdUser, reason seclog.Reason) { + loggedUser = user + loggedReason = reason + })) + s.err = store.ErrInvalidCredentials buf := bytes.NewBufferString(`{"username": "email@.com", "password": "password"}`) req, err := http.NewRequest("POST", "/v2/login", buf) @@ -459,6 +486,10 @@ func (s *userSuite) TestLoginUserInvalidCredentialsError(c *check.C) { rspe := s.errorReq(c, req, nil, actionIsExpected) c.Check(rspe.Status, check.Equals, 401) c.Check(rspe.Message, check.Equals, "invalid credentials") + + c.Check(loggedUser.StoreUserEmail, check.Equals, "email@.com") + c.Check(loggedReason.Code, check.Equals, seclog.ReasonInvalidCredentials) + c.Check(loggedReason.Message, check.Equals, "invalid credentials") } func (s *userSuite) TestLoginUserInvalidAuthDataError(c *check.C) { @@ -489,6 +520,34 @@ func (s *userSuite) TestLoginUserPasswordPolicyError(c *check.C) { c.Check(rspe.Value, check.DeepEquals, s.err) } +func (s *userSuite) TestLoginUserPersistError(c *check.C) { + s.expectLoginAccess() + + var loggedUser seclog.SnapdUser + var loggedReason seclog.Reason + s.AddCleanup(daemon.MockSeclogLogLoginFailure(func(user seclog.SnapdUser, reason seclog.Reason) { + loggedUser = user + loggedReason = reason + })) + + s.loginUserStoreMacaroon = "user-macaroon" + s.loginUserDischarge = "the-discharge-macaroon-serialized-data" + buf := bytes.NewBufferString(`{"username": "username", "email": "email@.com", "password": "password"}`) + req, err := http.NewRequest("POST", "/v2/login", buf) + c.Assert(err, check.IsNil) + + // Pass a user whose ID does not exist in the auth state, so + // auth.UpdateUser returns ErrInvalidUser. + fakeUser := &auth.UserState{ID: 99999, Username: "username", Email: "email@.com"} + rspe := s.errorReq(c, req, fakeUser, actionIsExpected) + c.Check(rspe.Status, check.Equals, 500) + c.Check(rspe.Message, check.Matches, "cannot persist authentication details: .*") + + c.Check(loggedUser.StoreUserEmail, check.Equals, "email@.com") + c.Check(loggedUser.SystemUserName, check.Equals, "username") + c.Check(loggedReason.Message, check.Matches, "cannot persist authentication details: .*") +} + func (s *userSuite) TestPostCreateUser(c *check.C) { s.testCreateUser(c, true) } diff --git a/daemon/export_api_users_test.go b/daemon/export_api_users_test.go index 8a20ca73b03..1c032364140 100644 --- a/daemon/export_api_users_test.go +++ b/daemon/export_api_users_test.go @@ -57,6 +57,10 @@ func MockSeclogLogLoginSuccess(f func(user seclog.SnapdUser)) (restore func()) { return testutil.Mock(&seclogLogLoginSuccess, f) } +func MockSeclogLogLoginFailure(f func(user seclog.SnapdUser, reason seclog.Reason)) (restore func()) { + return testutil.Mock(&seclogLogLoginFailure, f) +} + type ( UserResponseData = userResponseData ) diff --git a/data/systemd/Makefile b/data/systemd/Makefile index f782a977248..d294d1b5d27 100644 --- a/data/systemd/Makefile +++ b/data/systemd/Makefile @@ -40,16 +40,6 @@ install: $(SYSTEMD_UNITS) install -d -m 0755 $(DESTDIR)/$(LIBEXECDIR)/snapd install -m 0755 -t $(DESTDIR)/$(LIBEXECDIR)/snapd snapd.core-fixup.sh install -m 0755 -t $(DESTDIR)/$(LIBEXECDIR)/snapd snapd.run-from-snap - # security log journal namespace - install -d -m 0755 $(DESTDIR)/etc/systemd - install -m 0644 journald@snapd-security.conf $(DESTDIR)/etc/systemd/journald@snapd-security.conf - install -d -m 0755 $(DESTDIR)/$(SYSTEMDSYSTEMUNITDIR)/systemd-journald@snapd-security.service.d - install -m 0644 systemd-journald@snapd-security.service.d/00-snapd.conf \ - $(DESTDIR)/$(SYSTEMDSYSTEMUNITDIR)/systemd-journald@snapd-security.service.d/00-snapd.conf - # snapd.service drop-in to pull in the security journal socket (Wants + After) - install -d -m 0755 $(DESTDIR)/$(SYSTEMDSYSTEMUNITDIR)/snapd.service.d - install -m 0644 snapd.service.d/security-journal.conf \ - $(DESTDIR)/$(SYSTEMDSYSTEMUNITDIR)/snapd.service.d/security-journal.conf .PHONY: clean clean: diff --git a/data/systemd/journald@snapd-security.conf b/data/systemd/journald@snapd-security.conf deleted file mode 100644 index 337346a708d..00000000000 --- a/data/systemd/journald@snapd-security.conf +++ /dev/null @@ -1,12 +0,0 @@ -# Journald configuration for the snapd security log namespace. -# This namespace isolates security audit events (login, access control) -# from the main system journal. -[Journal] -Storage=persistent -Compress=yes -SystemMaxFileSize=10M -SystemMaxUse=10M -SyncIntervalSec=30s -SyncOnShutdown=yes -RateLimitIntervalSec=30s -RateLimitBurst=10000 diff --git a/data/systemd/snapd.service.d/security-journal.conf b/data/systemd/snapd.service.d/security-journal.conf deleted file mode 100644 index f5fd3123a47..00000000000 --- a/data/systemd/snapd.service.d/security-journal.conf +++ /dev/null @@ -1,6 +0,0 @@ -# Pull in the snapd-security journal namespace socket alongside snapd. -# This is a best-effort dependency: if the socket is masked or fails, -# snapd starts normally without security logging. -[Unit] -Wants=systemd-journald@snapd-security.socket -After=systemd-journald@snapd-security.socket diff --git a/data/systemd/systemd-journald@snapd-security.service.d/00-snapd.conf b/data/systemd/systemd-journald@snapd-security.service.d/00-snapd.conf deleted file mode 100644 index 9e36515dc43..00000000000 --- a/data/systemd/systemd-journald@snapd-security.service.d/00-snapd.conf +++ /dev/null @@ -1,5 +0,0 @@ -# Drop-in for the snapd-security journal namespace instance. -# Clears the default LogsDirectory to prevent failures with -# namespaced journald instances. -[Service] -LogsDirectory= diff --git a/overlord/configstate/configcore/export_test.go b/overlord/configstate/configcore/export_test.go index 2afb299aaac..bd112770045 100644 --- a/overlord/configstate/configcore/export_test.go +++ b/overlord/configstate/configcore/export_test.go @@ -107,11 +107,3 @@ func MockEnvPath(newEnvPath string) func() { envFilePath = newEnvPath return func() { envFilePath = oldEnvPath } } - -func MockSeclogEnable(f func() error) func() { - return testutil.Mock(&seclogEnable, f) -} - -func MockSeclogDisable(f func() error) func() { - return testutil.Mock(&seclogDisable, f) -} diff --git a/overlord/configstate/configcore/handlers.go b/overlord/configstate/configcore/handlers.go index a5c4f63729a..da39522a240 100644 --- a/overlord/configstate/configcore/handlers.go +++ b/overlord/configstate/configcore/handlers.go @@ -124,9 +124,6 @@ func init() { // system.motd addFSOnlyHandler(validateMotdConfiguration, handleMotdConfiguration, coreOnly) - // security-logging.* - addFSOnlyHandler(validateSecurityLoggingSettings, handleSecurityLoggingConfiguration, nil) - sysconfig.ApplyFilesystemOnlyDefaultsImpl = filesystemOnlyApply } diff --git a/overlord/configstate/configcore/security_logging.go b/overlord/configstate/configcore/security_logging.go deleted file mode 100644 index 1bb50bf3d4d..00000000000 --- a/overlord/configstate/configcore/security_logging.go +++ /dev/null @@ -1,218 +0,0 @@ -// -*- Mode: Go; indent-tabs-mode: t -*- - -/* - * Copyright (C) 2026 Canonical Ltd - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License version 3 as - * published by the Free Software Foundation. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - * - */ - -package configcore - -import ( - "fmt" - "os" - "path/filepath" - - "github.com/snapcore/snapd/dirs" - "github.com/snapcore/snapd/logger" - "github.com/snapcore/snapd/osutil" - "github.com/snapcore/snapd/seclog" - "github.com/snapcore/snapd/sysconfig" - "github.com/snapcore/snapd/systemd" -) - -const ( - securityJournalSocketUnit = "systemd-journald@snapd-security.socket" - securityJournalServiceUnit = "systemd-journald@snapd-security.service" - securityJournalConfPath = "/etc/systemd/journald@snapd-security.conf" - securityJournalDefaultMaxSize = "10M" -) - -var ( - seclogEnable = seclog.Enable - seclogDisable = seclog.Disable -) - -func init() { - supportedConfigurations["core.security-logging.enabled"] = true - supportedConfigurations["core.security-logging.max-size"] = true -} - -func validateSecurityLoggingSettings(tr ConfGetter) error { - if err := validateBoolFlag(tr, "security-logging.enabled"); err != nil { - return err - } - - maxSize, err := coreCfg(tr, "security-logging.max-size") - if err != nil { - return err - } - if maxSize != "" { - if err := validateJournalSizeValue(maxSize); err != nil { - return fmt.Errorf("security-logging.max-size %v", err) - } - } - - return nil -} - -func handleSecurityLoggingConfiguration(_ sysconfig.Device, tr ConfGetter, opts *fsOnlyContext) error { - enabled, err := coreCfg(tr, "security-logging.enabled") - if err != nil { - return err - } - maxSize, err := coreCfg(tr, "security-logging.max-size") - if err != nil { - return err - } - - // If nothing is set, do nothing. - if enabled == "" && maxSize == "" { - return nil - } - - rootDir := dirs.GlobalRootDir - if opts != nil { - rootDir = opts.RootDir - } - - switch enabled { - case "false": - return disableSecurityLogging(rootDir, opts) - default: - return enableSecurityLogging(rootDir, opts, maxSize) - } -} - -func disableSecurityLogging(rootDir string, opts *fsOnlyContext) error { - // Disconnect from the journal namespace first so that the sink - // is closed before we tear down the service. - if err := seclogDisable(); err != nil { - logger.Noticef("cannot disable security logger: %v", err) - } - - if opts != nil { - // During filesystem-only apply, mask the socket by creating - // a symlink to /dev/null, mirroring what systemctl mask does. - maskDir := filepath.Join(rootDir, "/etc/systemd/system") - if err := os.MkdirAll(maskDir, 0755); err != nil { - return err - } - maskPath := filepath.Join(maskDir, securityJournalSocketUnit) - os.Remove(maskPath) - return os.Symlink("/dev/null", maskPath) - } - - sysd := systemd.NewUnderRoot(rootDir, systemd.SystemMode, nil) - // Mask the socket to prevent future activation. The running - // journald instance (if any) will exit on its own once idle. - if err := sysd.Mask(securityJournalSocketUnit); err != nil { - return err - } - return nil -} - -func enableSecurityLogging(rootDir string, opts *fsOnlyContext, maxSize string) error { - confPath := filepath.Join(rootDir, securityJournalConfPath) - - // Write the namespace config file with current settings. - conf := generateSecurityJournalConf(maxSize) - if err := osutil.AtomicWriteFile(confPath, conf, 0644, 0); err != nil { - return err - } - - if opts != nil { - // Filesystem-only apply; remove any mask symlink that may - // have been created by a previous disable. - maskPath := filepath.Join(rootDir, "/etc/systemd/system", securityJournalSocketUnit) - os.Remove(maskPath) - return nil - } - - sysd := systemd.NewUnderRoot(rootDir, systemd.SystemMode, nil) - - // Unmask in case it was previously disabled. - if err := sysd.Unmask(securityJournalSocketUnit); err != nil { - return err - } - - // Start the socket unit so socket activation is available immediately. - if err := sysd.Start([]string{securityJournalSocketUnit}); err != nil { - logger.Noticef("cannot start security journal socket: %v", err) - } - - // Signal the namespaced journald to reload configuration if it was - // already running; if not, socket activation picks up the new config. - if err := sysd.Kill(securityJournalServiceUnit, "USR1", ""); err != nil { - // Non-fatal: new config takes effect on next activation. - } - - // Re-open the security logger against the fresh namespace connection. - // This is done last so that the journal service is in a stable state - // before we connect to the sink. - if err := seclogEnable(); err != nil { - logger.Noticef("cannot enable security logger: %v", err) - } - - return nil -} - -func generateSecurityJournalConf(maxSize string) []byte { - conf := "[Journal]\nStorage=persistent\nCompress=yes\n" - if maxSize == "" { - maxSize = securityJournalDefaultMaxSize - } - conf += fmt.Sprintf("SystemMaxUse=%s\n", maxSize) - conf += "SyncIntervalSec=30s\nSyncOnShutdown=yes\n" - // Sanity rate limit: generous enough to never trigger under - // normal operation, but prevents runaway log storms from - // filling the journal if something goes very wrong. - conf += "RateLimitIntervalSec=30s\nRateLimitBurst=10000\n" - return []byte(conf) -} - -// validateJournalSizeValue validates a systemd journal size value (e.g. "10M", "1G"). -// The minimum allowed value is 10M. -func validateJournalSizeValue(value string) error { - if len(value) < 2 { - return fmt.Errorf("cannot parse size %q: must be a number followed by a suffix like K, M, G or T", value) - } - suffix := value[len(value)-1] - var multiplier uint64 - switch suffix { - case 'K': - multiplier = 1024 - case 'M': - multiplier = 1024 * 1024 - case 'G': - multiplier = 1024 * 1024 * 1024 - case 'T': - multiplier = 1024 * 1024 * 1024 * 1024 - default: - return fmt.Errorf("cannot parse size %q: must be a number followed by a suffix like K, M, G or T", value) - } - numStr := value[:len(value)-1] - var num uint64 - for _, ch := range numStr { - if ch < '0' || ch > '9' { - return fmt.Errorf("cannot parse size %q: must be a number followed by a suffix like K, M, G or T", value) - } - num = num*10 + uint64(ch-'0') - } - const minSize = 10 * 1024 * 1024 // 10M - if num*multiplier < minSize { - return fmt.Errorf("cannot set size %q: must be at least 10M", value) - } - return nil -} diff --git a/overlord/configstate/configcore/security_logging_test.go b/overlord/configstate/configcore/security_logging_test.go deleted file mode 100644 index 0828600a004..00000000000 --- a/overlord/configstate/configcore/security_logging_test.go +++ /dev/null @@ -1,300 +0,0 @@ -// -*- Mode: Go; indent-tabs-mode: t -*- - -/* - * Copyright (C) 2026 Canonical Ltd - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License version 3 as - * published by the Free Software Foundation. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - * - */ - -package configcore_test - -import ( - "os" - "path/filepath" - - . "gopkg.in/check.v1" - - "github.com/snapcore/snapd/dirs" - "github.com/snapcore/snapd/osutil" - "github.com/snapcore/snapd/overlord/configstate/configcore" - "github.com/snapcore/snapd/testutil" -) - -type securityLoggingSuite struct { - configcoreSuite -} - -var _ = Suite(&securityLoggingSuite{}) - -func (s *securityLoggingSuite) SetUpTest(c *C) { - s.configcoreSuite.SetUpTest(c) - - err := os.MkdirAll(filepath.Join(dirs.GlobalRootDir, "/etc/systemd"), 0755) - c.Assert(err, IsNil) - - s.AddCleanup(configcore.MockSeclogEnable(func() error { return nil })) - s.AddCleanup(configcore.MockSeclogDisable(func() error { return nil })) -} - -// Validation tests - -func (s *securityLoggingSuite) TestValidateEnabledValid(c *C) { - for _, val := range []string{"true", "false"} { - err := configcore.FilesystemOnlyRun(coreDev, &mockConf{ - state: s.state, - conf: map[string]any{"security-logging.enabled": val}, - }) - c.Check(err, IsNil, Commentf("value %q", val)) - } -} - -func (s *securityLoggingSuite) TestValidateEnabledInvalid(c *C) { - err := configcore.FilesystemOnlyRun(coreDev, &mockConf{ - state: s.state, - conf: map[string]any{"security-logging.enabled": "maybe"}, - }) - c.Assert(err, ErrorMatches, `security-logging.enabled can only be set to 'true' or 'false'`) -} - -func (s *securityLoggingSuite) TestValidateMaxSizeValid(c *C) { - for _, val := range []string{"10M", "100M", "1G", "2T"} { - err := configcore.FilesystemOnlyRun(coreDev, &mockConf{ - state: s.state, - conf: map[string]any{ - "security-logging.enabled": "true", - "security-logging.max-size": val, - }, - }) - c.Check(err, IsNil, Commentf("value %q", val)) - } -} - -func (s *securityLoggingSuite) TestValidateMaxSizeInvalidSuffix(c *C) { - err := configcore.FilesystemOnlyRun(coreDev, &mockConf{ - state: s.state, - conf: map[string]any{ - "security-logging.enabled": "true", - "security-logging.max-size": "100X", - }, - }) - c.Assert(err, ErrorMatches, `security-logging.max-size cannot parse size "100X": must be a number followed by a suffix like K, M, G or T`) -} - -func (s *securityLoggingSuite) TestValidateMaxSizeTooSmall(c *C) { - err := configcore.FilesystemOnlyRun(coreDev, &mockConf{ - state: s.state, - conf: map[string]any{ - "security-logging.enabled": "true", - "security-logging.max-size": "5M", - }, - }) - c.Assert(err, ErrorMatches, `security-logging.max-size cannot set size "5M": must be at least 10M`) -} - -func (s *securityLoggingSuite) TestValidateMaxSizeNonNumeric(c *C) { - err := configcore.FilesystemOnlyRun(coreDev, &mockConf{ - state: s.state, - conf: map[string]any{ - "security-logging.enabled": "true", - "security-logging.max-size": "abcM", - }, - }) - c.Assert(err, ErrorMatches, `security-logging.max-size cannot parse size "abcM": must be a number followed by a suffix like K, M, G or T`) -} - -func (s *securityLoggingSuite) TestValidateMaxSizeTooShort(c *C) { - err := configcore.FilesystemOnlyRun(coreDev, &mockConf{ - state: s.state, - conf: map[string]any{ - "security-logging.enabled": "true", - "security-logging.max-size": "M", - }, - }) - c.Assert(err, ErrorMatches, `security-logging.max-size cannot parse size "M": must be a number followed by a suffix like K, M, G or T`) -} - -// Nothing set -> no-op - -func (s *securityLoggingSuite) TestHandleNothingSet(c *C) { - err := configcore.FilesystemOnlyRun(coreDev, &mockConf{ - state: s.state, - conf: map[string]any{}, - }) - c.Assert(err, IsNil) - c.Check(s.systemctlArgs, HasLen, 0) -} - -// Enable path (runtime) - -func (s *securityLoggingSuite) TestEnableSecurityLogging(c *C) { - err := configcore.FilesystemOnlyRun(coreDev, &mockConf{ - state: s.state, - conf: map[string]any{ - "security-logging.enabled": "true", - }, - }) - c.Assert(err, IsNil) - - // Check the journal conf was written with default max-size. - confPath := filepath.Join(dirs.GlobalRootDir, "/etc/systemd/journald@snapd-security.conf") - c.Check(confPath, testutil.FileContains, "Storage=persistent") - c.Check(confPath, testutil.FileContains, "SystemMaxUse=10M") - c.Check(confPath, testutil.FileContains, "SyncIntervalSec=30s") - c.Check(confPath, testutil.FileContains, "RateLimitBurst=10000") - - // Check systemctl calls: unmask, start, kill (reload config). - // The mock systemd operates under a root dir, so --root is prepended. - c.Assert(s.systemctlArgs, HasLen, 3) - c.Check(s.systemctlArgs[0], testutil.DeepContains, "unmask") - c.Check(s.systemctlArgs[0], testutil.DeepContains, "systemd-journald@snapd-security.socket") - c.Check(s.systemctlArgs[1], testutil.DeepContains, "start") - c.Check(s.systemctlArgs[1], testutil.DeepContains, "systemd-journald@snapd-security.socket") - c.Check(s.systemctlArgs[2], testutil.DeepContains, "kill") - c.Check(s.systemctlArgs[2], testutil.DeepContains, "systemd-journald@snapd-security.service") -} - -func (s *securityLoggingSuite) TestEnableSecurityLoggingWithMaxSize(c *C) { - err := configcore.FilesystemOnlyRun(coreDev, &mockConf{ - state: s.state, - conf: map[string]any{ - "security-logging.enabled": "true", - "security-logging.max-size": "50M", - }, - }) - c.Assert(err, IsNil) - - confPath := filepath.Join(dirs.GlobalRootDir, "/etc/systemd/journald@snapd-security.conf") - c.Check(confPath, testutil.FileContains, "SystemMaxUse=50M") -} - -func (s *securityLoggingSuite) TestEnableCallsSeclogEnable(c *C) { - enableCalled := false - s.AddCleanup(configcore.MockSeclogEnable(func() error { - enableCalled = true - return nil - })) - - err := configcore.FilesystemOnlyRun(coreDev, &mockConf{ - state: s.state, - conf: map[string]any{ - "security-logging.enabled": "true", - }, - }) - c.Assert(err, IsNil) - c.Check(enableCalled, Equals, true) -} - -// Disable path (runtime) - -func (s *securityLoggingSuite) TestDisableSecurityLogging(c *C) { - err := configcore.FilesystemOnlyRun(coreDev, &mockConf{ - state: s.state, - conf: map[string]any{ - "security-logging.enabled": "false", - }, - }) - c.Assert(err, IsNil) - - // Check systemctl calls: mask the socket. - c.Assert(s.systemctlArgs, HasLen, 1) - c.Check(s.systemctlArgs[0], testutil.DeepContains, "mask") - c.Check(s.systemctlArgs[0], testutil.DeepContains, "systemd-journald@snapd-security.socket") -} - -func (s *securityLoggingSuite) TestDisableCallsSeclogDisable(c *C) { - disableCalled := false - s.AddCleanup(configcore.MockSeclogDisable(func() error { - disableCalled = true - return nil - })) - - err := configcore.FilesystemOnlyRun(coreDev, &mockConf{ - state: s.state, - conf: map[string]any{ - "security-logging.enabled": "false", - }, - }) - c.Assert(err, IsNil) - c.Check(disableCalled, Equals, true) -} - -// Filesystem-only enable (preseeding/install) - -func (s *securityLoggingSuite) TestEnableFSOnly(c *C) { - rootDir := c.MkDir() - err := os.MkdirAll(filepath.Join(rootDir, "/etc/systemd"), 0755) - c.Assert(err, IsNil) - - // Place a mask symlink as if previously disabled. - maskDir := filepath.Join(rootDir, "/etc/systemd/system") - err = os.MkdirAll(maskDir, 0755) - c.Assert(err, IsNil) - maskPath := filepath.Join(maskDir, "systemd-journald@snapd-security.socket") - err = os.Symlink("/dev/null", maskPath) - c.Assert(err, IsNil) - - err = configcore.FilesystemOnlyApply(coreDev, rootDir, map[string]any{ - "security-logging.enabled": "true", - }) - c.Assert(err, IsNil) - - // Conf written. - confPath := filepath.Join(rootDir, "/etc/systemd/journald@snapd-security.conf") - c.Check(confPath, testutil.FileContains, "Storage=persistent") - - // Mask symlink removed. - c.Check(osutil.IsSymlink(maskPath), Equals, false) - - // No systemctl calls in fs-only mode. - c.Check(s.systemctlArgs, HasLen, 0) -} - -// Filesystem-only disable (preseeding/install) - -func (s *securityLoggingSuite) TestDisableFSOnly(c *C) { - rootDir := c.MkDir() - - err := configcore.FilesystemOnlyApply(coreDev, rootDir, map[string]any{ - "security-logging.enabled": "false", - }) - c.Assert(err, IsNil) - - // Mask symlink created. - maskPath := filepath.Join(rootDir, "/etc/systemd/system", "systemd-journald@snapd-security.socket") - c.Check(osutil.IsSymlink(maskPath), Equals, true) - target, err := os.Readlink(maskPath) - c.Assert(err, IsNil) - c.Check(target, Equals, "/dev/null") - - // No systemctl calls. - c.Check(s.systemctlArgs, HasLen, 0) -} - -// max-size alone implies enable - -func (s *securityLoggingSuite) TestMaxSizeAloneImpliesEnable(c *C) { - err := configcore.FilesystemOnlyRun(coreDev, &mockConf{ - state: s.state, - conf: map[string]any{ - "security-logging.max-size": "100M", - }, - }) - c.Assert(err, IsNil) - - confPath := filepath.Join(dirs.GlobalRootDir, "/etc/systemd/journald@snapd-security.conf") - c.Check(confPath, testutil.FileContains, "SystemMaxUse=100M") - - // Systemctl was called (unmask, start, kill) — enable path was taken. - c.Check(len(s.systemctlArgs) > 0, Equals, true) -} diff --git a/seclog/audit.go b/seclog/audit.go deleted file mode 100644 index 43805b820b5..00000000000 --- a/seclog/audit.go +++ /dev/null @@ -1,115 +0,0 @@ -// -*- Mode: Go; indent-tabs-mode: t -*- - -/* - * Copyright (C) 2026 Canonical Ltd - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License version 3 as - * published by the Free Software Foundation. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - * - */ - -package seclog - -import ( - "encoding/binary" - "fmt" - "io" - "sync/atomic" - "syscall" -) - -const ( - // AUDIT_USER_MSG is the audit message type for user-space messages. - auditUserMsg = 1112 - - // NETLINK_AUDIT is the netlink protocol for audit. - netlinkAudit = 15 -) - -func init() { - registerSink(SinkAudit, newAuditSink) -} - -// newAuditSink opens a netlink audit socket and returns an [auditWriter] -// that sends each written payload as an AUDIT_USER_MSG. The appID is -// currently unused but accepted for sink signature compatibility. -func newAuditSink(_ string) (io.Writer, error) { - fd, err := syscall.Socket(syscall.AF_NETLINK, syscall.SOCK_RAW, netlinkAudit) - if err != nil { - return nil, fmt.Errorf("cannot open audit socket: %w", err) - } - addr := &syscall.SockaddrNetlink{ - Family: syscall.AF_NETLINK, - Pid: 0, // kernel - Groups: 0, - } - if err := syscall.Bind(fd, addr); err != nil { - syscall.Close(fd) - return nil, fmt.Errorf("cannot bind audit socket: %w", err) - } - return &auditWriter{fd: fd}, nil -} - -// auditWriter sends messages to the kernel audit subsystem via a netlink -// socket. Each Write call sends the payload as an AUDIT_USER_MSG. -// -// The writer is safe for sequential use; concurrent use requires external -// synchronization. -type auditWriter struct { - fd int - seq atomic.Uint32 -} - -// Write sends p as the payload of an AUDIT_USER_MSG netlink message. -// The returned byte count reflects only the original payload length. -func (aw *auditWriter) Write(p []byte) (int, error) { - msg := aw.buildMessage(p) - addr := &syscall.SockaddrNetlink{ - Family: syscall.AF_NETLINK, - Pid: 0, // kernel - } - if err := syscall.Sendto(aw.fd, msg, 0, addr); err != nil { - return 0, fmt.Errorf("cannot send audit message: %w", err) - } - return len(p), nil -} - -// Close closes the underlying netlink socket. -func (aw *auditWriter) Close() error { - return syscall.Close(aw.fd) -} - -// nlmsghdrSize is the size of a netlink message header in bytes -// (uint32 + uint16 + uint16 + uint32 + uint32 = 16). -const nlmsghdrSize = 16 - -// buildMessage constructs a raw netlink AUDIT_USER_MSG containing payload. -func (aw *auditWriter) buildMessage(payload []byte) []byte { - totalLen := nlmsghdrSize + uint32(len(payload)) - buf := make([]byte, nlmsgAlign(totalLen)) - - // Write header. - binary.LittleEndian.PutUint32(buf[0:4], totalLen) - binary.LittleEndian.PutUint16(buf[4:6], auditUserMsg) - binary.LittleEndian.PutUint16(buf[6:8], 0x01|0x04) // NLM_F_REQUEST | NLM_F_ACK - binary.LittleEndian.PutUint32(buf[8:12], aw.seq.Add(1)) - binary.LittleEndian.PutUint32(buf[12:16], 0) // pid 0 = kernel - - // Write payload. - copy(buf[nlmsghdrSize:], payload) - return buf -} - -// nlmsgAlign rounds up to the nearest 4-byte boundary per NLMSG_ALIGN. -func nlmsgAlign(n uint32) uint32 { - return (n + 3) &^ 3 -} diff --git a/seclog/audit_linux.go b/seclog/audit_linux.go new file mode 100644 index 00000000000..ac66ac66ca8 --- /dev/null +++ b/seclog/audit_linux.go @@ -0,0 +1,197 @@ +// -*- Mode: Go; indent-tabs-mode: t -*- + +// go1.21 is required for binary.NativeEndian which is used to serialize +// netlink headers in host byte order. NativeEndian is supported on all +// architectures snapd targets: amd64, arm, arm64, ppc64le, riscv64. +// See https://cs.opensource.google/go/go/+/refs/tags/go1.26.2:src/encoding/binary/native_endian_little.go +// The nonativeendian tag allows excluding this file on toolchains that +// lack NativeEndian support. +// See https://go.dev/doc/go1.21#encoding/binary +//go:build go1.21 && !nonativeendian + +/* + * Copyright (C) 2026 Canonical Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 3 as + * published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +package seclog + +import ( + "encoding/binary" + "errors" + "fmt" + "io" + "sync/atomic" + "syscall" +) + +const ( + // AUDIT_TRUSTED_APP is the audit message type for trusted application messages. + // See https://github.com/linux-audit/audit-userspace/blob/master/lib/audit-records.h + auditTrustedApp = 1121 + + // NETLINK_AUDIT is the netlink protocol for audit. + // See https://github.com/torvalds/linux/blob/master/include/uapi/linux/netlink.h + netlinkAudit = 9 +) + +// netlinkOps abstracts the syscall operations needed to open, bind, query, +// send to, and close a netlink socket. Production code uses [realNetlinkOps]; +// tests can substitute a recording or stubbing implementation. +type netlinkOps interface { + Socket(domain, typ, proto int) (int, error) + Bind(fd int, sa syscall.Sockaddr) error + Getsockname(fd int) (syscall.Sockaddr, error) + Sendto(fd int, p []byte, flags int, to syscall.Sockaddr) error + Close(fd int) error +} + +// realNetlinkOps delegates every operation to the corresponding syscall. +type realNetlinkOps struct{} + +func (realNetlinkOps) Socket(domain, typ, proto int) (int, error) { + return syscall.Socket(domain, typ, proto) +} + +func (realNetlinkOps) Bind(fd int, sa syscall.Sockaddr) error { + return syscall.Bind(fd, sa) +} + +func (realNetlinkOps) Getsockname(fd int) (syscall.Sockaddr, error) { + return syscall.Getsockname(fd) +} + +func (realNetlinkOps) Sendto(fd int, p []byte, flags int, to syscall.Sockaddr) error { + return syscall.Sendto(fd, p, flags, to) +} + +func (realNetlinkOps) Close(fd int) error { + return syscall.Close(fd) +} + +var netlink netlinkOps = realNetlinkOps{} + +func init() { + registerSink(SinkAudit, auditSinkFactory{}) +} + +// auditSinkFactory implements [sinkFactory] for the kernel audit sink. +type auditSinkFactory struct{} + +// Ensure [auditSinkFactory] implements [sinkFactory]. +var _ sinkFactory = auditSinkFactory{} + +// Open opens a netlink audit socket and returns an [auditWriter] +// that sends each written payload as an AUDIT_TRUSTED_APP. The appID is +// currently unused but accepted for sink factory compatibility. +func (auditSinkFactory) Open(_ string) (io.Writer, error) { + // SOCK_CLOEXEC prevents the fd from leaking to child processes. + fd, err := netlink.Socket(syscall.AF_NETLINK, syscall.SOCK_RAW|syscall.SOCK_CLOEXEC, netlinkAudit) + if err != nil { + return nil, fmt.Errorf("cannot open audit socket: %w", err) + } + addr := &syscall.SockaddrNetlink{ + Family: syscall.AF_NETLINK, + Pid: 0, // let kernel assign port ID + Groups: 0, + } + if err := netlink.Bind(fd, addr); err != nil { + netlink.Close(fd) + return nil, fmt.Errorf("cannot bind audit socket: %w", err) + } + portID, err := getPortID(fd) + if err != nil { + netlink.Close(fd) + return nil, fmt.Errorf("cannot get audit socket port ID: %w", err) + } + return &auditWriter{fd: fd, portID: portID}, nil +} + +// getPortID returns the kernel-assigned port ID of the netlink socket. +// When binding with Pid 0, the kernel assigns a unique port ID that may +// or may not equal the process PID. This value must be used in outgoing +// netlink message headers. +func getPortID(fd int) (uint32, error) { + sa, err := netlink.Getsockname(fd) + if err != nil { + return 0, err + } + addr, ok := sa.(*syscall.SockaddrNetlink) + if !ok { + return 0, errors.New("unexpected socket address type") + } + return addr.Pid, nil +} + +// auditWriter sends messages to the kernel audit subsystem via a netlink +// socket. Each Write call sends the payload as an AUDIT_TRUSTED_APP. +// +// The writer is safe for sequential use; concurrent use requires external +// synchronization. +type auditWriter struct { + fd int + portID uint32 + seq atomic.Uint32 +} + +// Write sends p as the payload of an AUDIT_TRUSTED_APP netlink message. +// The returned byte count reflects only the original payload length. +func (aw *auditWriter) Write(payload []byte) (int, error) { + msg := aw.buildMessage(payload) + addr := &syscall.SockaddrNetlink{ + Family: syscall.AF_NETLINK, + Pid: 0, // kernel + } + if err := netlink.Sendto(aw.fd, msg, 0, addr); err != nil { + return 0, fmt.Errorf("cannot send audit message: %w", err) + } + return len(payload), nil +} + +// Close closes the underlying netlink socket. +func (aw *auditWriter) Close() error { + return netlink.Close(aw.fd) +} + +// nlmsghdrSize is the size of a netlink message header in bytes +// (uint32 + uint16 + uint16 + uint32 + uint32 = 16). +const nlmsghdrSize = 16 + +// buildMessage constructs a raw netlink AUDIT_TRUSTED_APP containing payload. +// The header layout follows struct nlmsghdr from +// https://github.com/torvalds/linux/blob/master/include/uapi/linux/netlink.h#L45 +func (aw *auditWriter) buildMessage(payload []byte) []byte { + totalLen := nlmsghdrSize + uint32(len(payload)) + buf := make([]byte, nlmsgAlign(totalLen)) + + // Write header in native byte order (netlink uses host endianness). + // NativeEndian is supported on all architectures snapd targets: + // amd64, arm, arm64, ppc64le, riscv64. + // See https://cs.opensource.google/go/go/+/refs/tags/go1.26.2:src/encoding/binary/native_endian_little.go + binary.NativeEndian.PutUint32(buf[0:4], totalLen) + binary.NativeEndian.PutUint16(buf[4:6], auditTrustedApp) + binary.NativeEndian.PutUint16(buf[6:8], syscall.NLM_F_REQUEST) // fire-and-forget, no ACK + binary.NativeEndian.PutUint32(buf[8:12], aw.seq.Add(1)) + binary.NativeEndian.PutUint32(buf[12:16], aw.portID) + + // Write payload. + copy(buf[nlmsghdrSize:], payload) + return buf +} + +// nlmsgAlign rounds up to the nearest 4-byte boundary per NLMSG_ALIGN. +func nlmsgAlign(size uint32) uint32 { + return (size + 3) &^ 3 +} diff --git a/seclog/audit_linux_test.go b/seclog/audit_linux_test.go new file mode 100644 index 00000000000..35b43c8a93b --- /dev/null +++ b/seclog/audit_linux_test.go @@ -0,0 +1,320 @@ +// -*- Mode: Go; indent-tabs-mode: t -*- +//go:build go1.21 && !nonativeendian + +/* + * Copyright (C) 2026 Canonical Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 3 as + * published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +package seclog_test + +import ( + "encoding/binary" + "fmt" + "slices" + "syscall" + + . "gopkg.in/check.v1" + + "github.com/snapcore/snapd/seclog" +) + +type AuditSuite struct{} + +var _ = Suite(&AuditSuite{}) + +func (s *AuditSuite) TestNlmsgAlignAlreadyAligned(c *C) { + c.Check(seclog.NlmsgAlign(0), Equals, uint32(0)) + c.Check(seclog.NlmsgAlign(4), Equals, uint32(4)) + c.Check(seclog.NlmsgAlign(8), Equals, uint32(8)) + c.Check(seclog.NlmsgAlign(16), Equals, uint32(16)) +} + +func (s *AuditSuite) TestNlmsgAlignRoundsUp(c *C) { + c.Check(seclog.NlmsgAlign(1), Equals, uint32(4)) + c.Check(seclog.NlmsgAlign(2), Equals, uint32(4)) + c.Check(seclog.NlmsgAlign(3), Equals, uint32(4)) + c.Check(seclog.NlmsgAlign(5), Equals, uint32(8)) + c.Check(seclog.NlmsgAlign(17), Equals, uint32(20)) +} + +func (s *AuditSuite) TestBuildMessageHeaderLayout(c *C) { + aw := &seclog.AuditWriter{} + + payload := []byte("hello") + msg := seclog.AuditWriterBuildMessage(aw, payload) + + // Total length: 16 (header) + 5 (payload) = 21, aligned to 24. + c.Assert(len(msg), Equals, 24) + + // nlmsghdr fields in native byte order. + totalLen := binary.NativeEndian.Uint32(msg[0:4]) + c.Check(totalLen, Equals, uint32(21)) + + msgType := binary.NativeEndian.Uint16(msg[4:6]) + c.Check(msgType, Equals, uint16(seclog.AuditTrustedApp)) + + flags := binary.NativeEndian.Uint16(msg[6:8]) + c.Check(flags, Equals, uint16(syscall.NLM_F_REQUEST)) + + seq := binary.NativeEndian.Uint32(msg[8:12]) + c.Check(seq, Equals, uint32(1)) + + portID := binary.NativeEndian.Uint32(msg[12:16]) + c.Check(portID, Equals, uint32(0)) + + // Payload follows header. + c.Check(string(msg[seclog.NlmsghdrSize:seclog.NlmsghdrSize+5]), Equals, "hello") + + // Padding bytes after payload should be zero. + c.Check(msg[21], Equals, byte(0)) + c.Check(msg[22], Equals, byte(0)) + c.Check(msg[23], Equals, byte(0)) +} + +func (s *AuditSuite) TestBuildMessagePortID(c *C) { + aw := &seclog.AuditWriter{} + seclog.AuditWriterSetPortID(aw, 42) + + msg := seclog.AuditWriterBuildMessage(aw, []byte("x")) + + portID := binary.NativeEndian.Uint32(msg[12:16]) + c.Check(portID, Equals, uint32(42)) +} + +func (s *AuditSuite) TestBuildMessageSequenceIncrements(c *C) { + aw := &seclog.AuditWriter{} + + msg1 := seclog.AuditWriterBuildMessage(aw, []byte("a")) + msg2 := seclog.AuditWriterBuildMessage(aw, []byte("b")) + msg3 := seclog.AuditWriterBuildMessage(aw, []byte("c")) + + seq1 := binary.NativeEndian.Uint32(msg1[8:12]) + seq2 := binary.NativeEndian.Uint32(msg2[8:12]) + seq3 := binary.NativeEndian.Uint32(msg3[8:12]) + + c.Check(seq1, Equals, uint32(1)) + c.Check(seq2, Equals, uint32(2)) + c.Check(seq3, Equals, uint32(3)) +} + +func (s *AuditSuite) TestBuildMessageAlignedPayload(c *C) { + aw := &seclog.AuditWriter{} + + // Payload of exactly 4 bytes: total = 20 which is already aligned. + msg := seclog.AuditWriterBuildMessage(aw, []byte("abcd")) + c.Check(len(msg), Equals, 20) + + totalLen := binary.NativeEndian.Uint32(msg[0:4]) + c.Check(totalLen, Equals, uint32(20)) +} + +func (s *AuditSuite) TestBuildMessageEmptyPayload(c *C) { + aw := &seclog.AuditWriter{} + + msg := seclog.AuditWriterBuildMessage(aw, []byte{}) + + // 16-byte header, already aligned. + c.Check(len(msg), Equals, 16) + + totalLen := binary.NativeEndian.Uint32(msg[0:4]) + c.Check(totalLen, Equals, uint32(16)) +} + +func (s *AuditSuite) TestNlmsghdrSizeConstant(c *C) { + // nlmsghdr is: uint32 + uint16 + uint16 + uint32 + uint32 = 16 + c.Check(seclog.NlmsghdrSize, Equals, 16) +} + +func (s *AuditSuite) TestAuditSinkRegistered(c *C) { + // The init() in audit_linux.go registers SinkAudit. + // Setup should not fail with "unknown sink" for SinkAudit. + // We verify indirectly: if the sink were missing, Setup would + // return "unknown sink". + restore := seclog.MockImplementations(map[seclog.Impl]seclog.ImplFactory{}) + defer restore() + + err := seclog.Setup(seclog.ImplSlog, seclog.SinkAudit, "test", seclog.LevelInfo) + // This should fail with "unknown implementation" (not "unknown sink"), + // proving the audit sink is registered. + c.Check(err, ErrorMatches, `cannot set up security logger: unknown implementation "slog"`) +} + +// mockNetlinkOps records calls and returns configurable results. +type mockNetlinkOps struct { + socketFD int + socketErr error + bindErr error + getsockname syscall.Sockaddr + getsocknErr error + sendtoData []byte + sendtoErr error + closedFDs []int + closeErr error +} + +func (m *mockNetlinkOps) Socket(domain, typ, proto int) (int, error) { + return m.socketFD, m.socketErr +} + +func (m *mockNetlinkOps) Bind(fd int, sa syscall.Sockaddr) error { + return m.bindErr +} + +func (m *mockNetlinkOps) Getsockname(fd int) (syscall.Sockaddr, error) { + return m.getsockname, m.getsocknErr +} + +func (m *mockNetlinkOps) Sendto(fd int, p []byte, flags int, to syscall.Sockaddr) error { + m.sendtoData = slices.Clone(p) + return m.sendtoErr +} + +func (m *mockNetlinkOps) Close(fd int) error { + m.closedFDs = append(m.closedFDs, fd) + return m.closeErr +} + +// Ensure mockNetlinkOps satisfies the interface. +var _ seclog.NetlinkOps = (*mockNetlinkOps)(nil) + +func (s *AuditSuite) TestOpenSuccess(c *C) { + mock := &mockNetlinkOps{ + socketFD: 42, + getsockname: &syscall.SockaddrNetlink{ + Family: syscall.AF_NETLINK, + Pid: 99, + }, + } + restore := seclog.MockNetlink(mock) + defer restore() + + writer, err := seclog.AuditSinkFactory{}.Open("test") + c.Assert(err, IsNil) + c.Assert(writer, NotNil) +} + +func (s *AuditSuite) TestOpenSocketError(c *C) { + mock := &mockNetlinkOps{ + socketErr: fmt.Errorf("permission denied"), + } + restore := seclog.MockNetlink(mock) + defer restore() + + _, err := seclog.AuditSinkFactory{}.Open("test") + c.Assert(err, ErrorMatches, "cannot open audit socket: permission denied") +} + +func (s *AuditSuite) TestOpenBindError(c *C) { + mock := &mockNetlinkOps{ + socketFD: 10, + bindErr: fmt.Errorf("address in use"), + } + restore := seclog.MockNetlink(mock) + defer restore() + + _, err := seclog.AuditSinkFactory{}.Open("test") + c.Assert(err, ErrorMatches, "cannot bind audit socket: address in use") + // Socket should have been closed on bind failure. + c.Check(mock.closedFDs, DeepEquals, []int{10}) +} + +func (s *AuditSuite) TestOpenGetsocknameError(c *C) { + mock := &mockNetlinkOps{ + socketFD: 10, + getsocknErr: fmt.Errorf("bad fd"), + } + restore := seclog.MockNetlink(mock) + defer restore() + + _, err := seclog.AuditSinkFactory{}.Open("test") + c.Assert(err, ErrorMatches, "cannot get audit socket port ID: bad fd") + c.Check(mock.closedFDs, DeepEquals, []int{10}) +} + +func (s *AuditSuite) TestOpenGetsocknameWrongAddressType(c *C) { + mock := &mockNetlinkOps{ + socketFD: 10, + // Return a non-netlink address type. + getsockname: &syscall.SockaddrUnix{Name: "/tmp/sock"}, + } + restore := seclog.MockNetlink(mock) + defer restore() + + _, err := seclog.AuditSinkFactory{}.Open("test") + c.Assert(err, ErrorMatches, "cannot get audit socket port ID: unexpected socket address type") + c.Check(mock.closedFDs, DeepEquals, []int{10}) +} + +func (s *AuditSuite) TestWriteSendtoError(c *C) { + mock := &mockNetlinkOps{ + socketFD: 7, + getsockname: &syscall.SockaddrNetlink{ + Family: syscall.AF_NETLINK, + Pid: 1, + }, + sendtoErr: fmt.Errorf("no buffer space"), + } + restore := seclog.MockNetlink(mock) + defer restore() + + writer, err := seclog.AuditSinkFactory{}.Open("test") + c.Assert(err, IsNil) + + _, err = writer.Write([]byte("test")) + c.Assert(err, ErrorMatches, "cannot send audit message: no buffer space") +} + +func (s *AuditSuite) TestWriteSuccess(c *C) { + mock := &mockNetlinkOps{ + socketFD: 7, + getsockname: &syscall.SockaddrNetlink{ + Family: syscall.AF_NETLINK, + Pid: 1, + }, + } + restore := seclog.MockNetlink(mock) + defer restore() + + writer, err := seclog.AuditSinkFactory{}.Open("test") + c.Assert(err, IsNil) + + n, err := writer.Write([]byte("hello")) + c.Assert(err, IsNil) + c.Check(n, Equals, 5) + // The mock captured the raw netlink message. + c.Check(len(mock.sendtoData) > seclog.NlmsghdrSize, Equals, true) +} + +func (s *AuditSuite) TestClose(c *C) { + mock := &mockNetlinkOps{ + socketFD: 7, + getsockname: &syscall.SockaddrNetlink{ + Family: syscall.AF_NETLINK, + Pid: 1, + }, + } + restore := seclog.MockNetlink(mock) + defer restore() + + writer, err := seclog.AuditSinkFactory{}.Open("test") + c.Assert(err, IsNil) + + closer, ok := writer.(interface{ Close() error }) + c.Assert(ok, Equals, true) + err = closer.Close() + c.Assert(err, IsNil) + c.Check(mock.closedFDs, DeepEquals, []int{7}) +} diff --git a/seclog/export_audit_linux_test.go b/seclog/export_audit_linux_test.go new file mode 100644 index 00000000000..891ec26061e --- /dev/null +++ b/seclog/export_audit_linux_test.go @@ -0,0 +1,49 @@ +// -*- Mode: Go; indent-tabs-mode: t -*- +//go:build go1.21 && !nonativeendian + +/* + * Copyright (C) 2026 Canonical Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 3 as + * published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +package seclog + +import ( + "github.com/snapcore/snapd/testutil" +) + +type AuditWriter = auditWriter + +type AuditSinkFactory = auditSinkFactory + +type NetlinkOps = netlinkOps + +var NlmsgAlign = nlmsgAlign + +const NlmsghdrSize = nlmsghdrSize + +const AuditTrustedApp = auditTrustedApp + +func AuditWriterBuildMessage(aw *auditWriter, payload []byte) []byte { + return aw.buildMessage(payload) +} + +func AuditWriterSetPortID(aw *auditWriter, id uint32) { + aw.portID = id +} + +func MockNetlink(ops netlinkOps) (restore func()) { + return testutil.Mock(&netlink, ops) +} diff --git a/seclog/export_slog_test.go b/seclog/export_slog_test.go index 6d9c0366637..f893af128e0 100644 --- a/seclog/export_slog_test.go +++ b/seclog/export_slog_test.go @@ -21,6 +21,7 @@ package seclog type ( - SlogProvider = slogProvider - SlogLogger = slogLogger + SlogImplementation = slogImplementation + SlogLogger = slogLogger + LevelWriter = levelWriter ) diff --git a/seclog/export_test.go b/seclog/export_test.go index edf95624fc9..b6cf469437f 100644 --- a/seclog/export_test.go +++ b/seclog/export_test.go @@ -27,34 +27,42 @@ import ( var NewNopLogger = newNopLogger -var Register = register +var RegisterImpl = registerImpl var RegisterSink = registerSink type ( - Provider = provider + ImplFactory = implFactory + SinkFactory = sinkFactory SecurityLogger = securityLogger ) -func MockSinks(m map[Sink]func(string) (io.Writer, error)) (restore func()) { +func MockSinks(m map[Sink]sinkFactory) (restore func()) { restore = testutil.Backup(&sinks) sinks = m return restore } -// MockNewSink is a convenience wrapper that replaces the journal sink factory -// in the sinks map. The rest of the sinks map is preserved. +// sinkFunc adapts a plain function to the [sinkFactory] interface. +type sinkFunc func(string) (io.Writer, error) + +// SinkFunc exports sinkFunc for use in external test packages. +type SinkFunc = sinkFunc + +func (f sinkFunc) Open(appID string) (io.Writer, error) { return f(appID) } + +// MockNewSink is a convenience wrapper that replaces the audit sink factory +// in the sinks map. func MockNewSink(f func(string) (io.Writer, error)) (restore func()) { restore = testutil.Backup(&sinks) - sinks = map[Sink]func(string) (io.Writer, error){ - SinkJournal: f, - SinkAudit: newAuditSink, + sinks = map[Sink]sinkFactory{ + SinkAudit: sinkFunc(f), } return restore } -func MockProviders(m map[Impl]provider) (restore func()) { - restore = testutil.Backup(&providers) - providers = m +func MockImplementations(m map[Impl]implFactory) (restore func()) { + restore = testutil.Backup(&implementations) + implementations = m return restore } @@ -85,8 +93,24 @@ func MockGlobalSetup(s *LoggerSetup) (restore func()) { return restore } -var SyslogPriority = syslogPriority +const MaxWriteFailures = maxWriteFailures -var NewJournalWriter = newJournalWriter +func MockWriteFailures(n int) (restore func()) { + restore = testutil.Backup(&writeFailures) + writeFailures = n + return restore +} -type JournalWriter = journalWriter +func MockFailed(f bool) (restore func()) { + restore = testutil.Backup(&failed) + failed = f + return restore +} + +func GetFailed() bool { + return failed +} + +func GetWriteFailures() int { + return writeFailures +} diff --git a/seclog/journal.go b/seclog/journal.go deleted file mode 100644 index e23446e9385..00000000000 --- a/seclog/journal.go +++ /dev/null @@ -1,134 +0,0 @@ -// -*- Mode: Go; indent-tabs-mode: t -*- - -/* - * Copyright (C) 2026 Canonical Ltd - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License version 3 as - * published by the Free Software Foundation. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - * - */ - -package seclog - -import ( - "fmt" - "io" - "log/syslog" - "os" - - "github.com/snapcore/snapd/systemd" -) - -const securityNamespace = "snapd-security" - -func init() { - registerSink(SinkJournal, newJournalSink) -} - -// newJournalSink opens a journald stream for the "snapd-security" namespace -// and returns a [journalWriter] that prepends syslog priority prefixes to every -// written line. The resulting writer is suitable as the output sink for a -// structured security logger. -func newJournalSink(appID string) (io.Writer, error) { - f, err := newJournalStream(appID) - if err != nil { - return nil, err - } - return newJournalWriter(f), nil -} - -// journalWriter implements [levelWriter] by wrapping an [io.Writer] and -// prepending a syslog-style "" priority prefix to each Write call. When -// used with a journald stream opened in level-prefix mode, the prefix -// overrides the per-message PRIORITY field. journald strips the prefix from -// the stored MESSAGE content. -// -// SetLevel must be called before each Write to select the priority for the -// upcoming message. Concurrent use of SetLevel and Write requires external -// synchronization. -type journalWriter struct { - w io.Writer - level Level -} - -// Ensure [journalWriter] implements [levelWriter]. -var _ levelWriter = (*journalWriter)(nil) - -// newJournalWriter returns a [journalWriter] that writes to the given -// writer with per-message syslog priority prefixes. -func newJournalWriter(w io.Writer) *journalWriter { - return &journalWriter{w: w, level: LevelInfo} -} - -// SetLevel sets the syslog priority for the next Write call. -func (jw *journalWriter) SetLevel(level Level) { - jw.level = level -} - -// Close closes the underlying writer if it implements [io.Closer]. -func (jw *journalWriter) Close() error { - if closer, ok := jw.w.(io.Closer); ok { - return closer.Close() - } - return nil -} - -// Write prepends "" to p and writes the result to the underlying writer. -// The returned byte count reflects only the original payload, excluding the -// prefix, to satisfy [io.Writer] callers that compare n against len(p). -func (jw *journalWriter) Write(p []byte) (int, error) { - prefix := fmt.Sprintf("<%d>", syslogPriority(jw.level)) - buf := make([]byte, len(prefix)+len(p)) - copy(buf, prefix) - copy(buf[len(prefix):], p) - n, err := jw.w.Write(buf) - // Report bytes written minus the prefix length so that - // callers see n == len(p) on success. - if n >= len(prefix) { - return n - len(prefix), err - } - return 0, err -} - -// newJournalStream opens a journald stream connection to the -// "snapd-security" namespace. The stream uses level-prefix mode so that -// each written line can override PRIORITY per message by prepending a -// "" syslog priority prefix. journald strips the prefix from the -// stored MESSAGE. -// -// The returned *os.File is suitable as the underlying writer for a -// [journalWriter]. -func newJournalStream(appID string) (*os.File, error) { - return systemd.NewJournalStreamFile(systemd.JournalStreamFileParams{ - Namespace: securityNamespace, - Identifier: appID, - Priority: syslog.LOG_DEBUG, - LevelPrefix: true, - }) -} - -// syslogPriority maps a security log [Level] to the equivalent syslog -// priority used by journald for the PRIORITY field. -func syslogPriority(level Level) syslog.Priority { - switch { - case level >= LevelCritical: - return syslog.LOG_CRIT - case level >= LevelError: - return syslog.LOG_ERR - case level >= LevelWarn: - return syslog.LOG_WARNING - case level >= LevelInfo: - return syslog.LOG_INFO - default: - return syslog.LOG_DEBUG - } -} diff --git a/seclog/journal_test.go b/seclog/journal_test.go deleted file mode 100644 index 7659c2cd23b..00000000000 --- a/seclog/journal_test.go +++ /dev/null @@ -1,138 +0,0 @@ -// -*- Mode: Go; indent-tabs-mode: t -*- - -/* - * Copyright (C) 2026 Canonical Ltd - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License version 3 as - * published by the Free Software Foundation. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - * - */ - -package seclog_test - -import ( - "bytes" - "fmt" - - . "gopkg.in/check.v1" - - "github.com/snapcore/snapd/seclog" - "github.com/snapcore/snapd/testutil" -) - -type JournalSuite struct { - testutil.BaseTest - buf *bytes.Buffer -} - -var _ = Suite(&JournalSuite{}) - -func (s *JournalSuite) SetUpTest(c *C) { - s.BaseTest.SetUpTest(c) - s.buf = &bytes.Buffer{} -} - -func (s *JournalSuite) TearDownTest(c *C) { - s.BaseTest.TearDownTest(c) -} - -func (s *JournalSuite) TestNewJournalWriterDefaultLevel(c *C) { - jw := seclog.NewJournalWriter(s.buf) - c.Assert(jw, NotNil) - - // default level is LevelInfo, syslog.LOG_INFO == 6 - _, err := jw.Write([]byte("hello")) - c.Assert(err, IsNil) - c.Check(s.buf.String(), Equals, "<6>hello") -} - -func (s *JournalSuite) TestSetLevelAndWrite(c *C) { - jw := seclog.NewJournalWriter(s.buf) - - tests := []struct { - level seclog.Level - expectedPrefix string - }{ - {seclog.LevelDebug, "<7>"}, // LOG_DEBUG - {seclog.LevelInfo, "<6>"}, // LOG_INFO - {seclog.LevelWarn, "<4>"}, // LOG_WARNING - {seclog.LevelError, "<3>"}, // LOG_ERR - {seclog.LevelCritical, "<2>"}, // LOG_CRIT - } - - for _, t := range tests { - s.buf.Reset() - jw.SetLevel(t.level) - msg := []byte("test message") - n, err := jw.Write(msg) - c.Assert(err, IsNil) - c.Check(n, Equals, len(msg), - Commentf("level %v", t.level)) - c.Check(s.buf.String(), Equals, t.expectedPrefix+"test message", - Commentf("level %v", t.level)) - } -} - -func (s *JournalSuite) TestWriteByteCountExcludesPrefix(c *C) { - jw := seclog.NewJournalWriter(s.buf) - - msg := []byte("payload") - n, err := jw.Write(msg) - c.Assert(err, IsNil) - // n must equal len(msg), not len("<6>payload") - c.Check(n, Equals, len(msg)) -} - -type errWriter struct { - err error -} - -func (w *errWriter) Write(p []byte) (int, error) { - return 0, w.err -} - -func (s *JournalSuite) TestWritePropagatesError(c *C) { - expected := fmt.Errorf("disk full") - jw := seclog.NewJournalWriter(&errWriter{err: expected}) - - n, err := jw.Write([]byte("data")) - c.Check(err, Equals, expected) - c.Check(n, Equals, 0) -} - -// closeRecorder implements io.WriteCloser and records whether Close was called. -type closeRecorder struct { - bytes.Buffer - closed bool -} - -func (cr *closeRecorder) Close() error { - cr.closed = true - return nil -} - -func (s *JournalSuite) TestCloseForwardsToUnderlyingWriter(c *C) { - cr := &closeRecorder{} - jw := seclog.NewJournalWriter(cr) - - err := jw.Close() - c.Assert(err, IsNil) - c.Check(cr.closed, Equals, true) -} - -func (s *JournalSuite) TestCloseWithNonCloserReturnsNil(c *C) { - // bytes.Buffer does not implement io.Closer - jw := seclog.NewJournalWriter(s.buf) - - err := jw.Close() - c.Assert(err, IsNil) -} diff --git a/seclog/nop.go b/seclog/nop.go index b4513c5b0df..63c60986890 100644 --- a/seclog/nop.go +++ b/seclog/nop.go @@ -42,5 +42,5 @@ func (nopLogger) LogLoginSuccess(user SnapdUser) { } // LogLoginFailure implements [securityLogger.LogLoginFailure]. -func (nopLogger) LogLoginFailure(user SnapdUser) { +func (nopLogger) LogLoginFailure(user SnapdUser, reason Reason) { } diff --git a/seclog/nop_test.go b/seclog/nop_test.go index d183e96def7..c479459238b 100644 --- a/seclog/nop_test.go +++ b/seclog/nop_test.go @@ -20,8 +20,6 @@ package seclog_test import ( - "testing" - . "gopkg.in/check.v1" "github.com/snapcore/snapd/seclog" @@ -34,8 +32,6 @@ type NopSuite struct { var _ = Suite(&NopSuite{}) -func TestNop(t *testing.T) { TestingT(t) } - func (s *NopSuite) SetUpTest(c *C) { s.BaseTest.SetUpTest(c) } @@ -73,5 +69,5 @@ func (s *NopSuite) TestLogLoginFailure(c *C) { c.Assert(logger, NotNil) // nop logger discards all messages without error - logger.LogLoginFailure(seclog.SnapdUser{StoreUserEmail: "user@gmail.com"}) + logger.LogLoginFailure(seclog.SnapdUser{StoreUserEmail: "user@gmail.com"}, seclog.Reason{}) } diff --git a/seclog/seclog.go b/seclog/seclog.go index ac1aef6e954..fa6525873db 100644 --- a/seclog/seclog.go +++ b/seclog/seclog.go @@ -24,15 +24,8 @@ import ( "io" "sync" "time" -) -var ( - providers = map[Impl]provider{} - sinks = map[Sink]func(string) (io.Writer, error){} - globalLogger securityLogger = newNopLogger() - globalCloser io.Closer - globalSetup *loggerSetup - lock sync.Mutex + "github.com/snapcore/snapd/logger" ) // Level is the importance or severity of a log event. @@ -93,16 +86,18 @@ type Sink string // Sink types. const ( - SinkJournal Sink = "journal" // journald namespace stream - SinkAudit Sink = "audit" // kernel audit via netlink + SinkAudit Sink = "audit" // kernel audit via netlink ) // SnapdUser represents the identity of a user for security log events. +// The slog output schema is defined by [SnapdUser.LogValue], which +// renders Expiration as "never" for zero values instead of emitting a +// zero-value datetime. type SnapdUser struct { ID int64 `json:"snapd-user-id"` - SystemUserName string `json:"system-user-name,omitempty"` - StoreUserEmail string `json:"store-user-email,omitempty"` - Expiration time.Time `json:"expiration,omitzero"` + SystemUserName string `json:"system-user-name"` + StoreUserEmail string `json:"store-user-email"` + Expiration time.Time `json:"expiration"` } // String returns a colon-separated description of the user in the form @@ -110,29 +105,68 @@ type SnapdUser struct { // "unknown" as a placeholder. A zero ID is treated as unset. func (u SnapdUser) String() string { const unknown = "unknown" + id := unknown if u.ID != 0 { id = fmt.Sprintf("%d", u.ID) } + email := unknown if u.StoreUserEmail != "" { email = u.StoreUserEmail } + name := unknown if u.SystemUserName != "" { name = u.SystemUserName } + return id + ":" + email + ":" + name } +// Reason codes are stable identifiers for security audit events. +const ( + ReasonInvalidCredentials = "invalid-credentials" + ReasonTwoFactorRequired = "two-factor-required" + ReasonTwoFactorFailed = "two-factor-failed" + ReasonInvalidAuthData = "invalid-auth-data" + ReasonPasswordPolicy = "password-policy" + ReasonInternal = "internal" +) + +// Reason describes why a security event happened. +type Reason struct { + Code string `json:"code"` + Message string `json:"message"` +} + +// String returns a colon-separated representation in the form +// ":". Fields that are unset use "unknown" as a +// placeholder. +func (r Reason) String() string { + const unknown = "unknown" + + code := unknown + if r.Code != "" { + code = r.Code + } + + message := unknown + if r.Message != "" { + message = r.Message + } + + return code + ":" + message +} + // securityLogger defines the interface for emitting structured security -// audit events. Implementations are created by a [provider] and write +// audit events. Implementations are created by an [implFactory] and write // to a configured sink. type securityLogger interface { LogLoggingEnabled() LogLoggingDisabled() LogLoginSuccess(user SnapdUser) - LogLoginFailure(user SnapdUser) + LogLoginFailure(user SnapdUser, reason Reason) } // loggerSetup holds the configuration provided to Setup. @@ -143,16 +177,40 @@ type loggerSetup struct { minLevel Level } -// provider provides functions required for constructing a [securityLogger]. +// implFactory provides functions required for constructing a [securityLogger]. // It is intended for registration of available loggers. -type provider interface { +type implFactory interface { // New creates a securityLogger that writes to writer. Messages with a // severity below minLevel are silently dropped. New(writer io.Writer, appID string, minLevel Level) securityLogger - // Impl returns the identifier for this provider. - Impl() Impl } +// sinkFactory creates an [io.Writer] for a log output destination. +// The appID identifies the application opening the sink and may be +// used by implementations for tagging or routing. +// +// If the returned writer also implements [io.Closer], it will be closed +// automatically when the sink is replaced or disabled. +type sinkFactory interface { + Open(appID string) (io.Writer, error) +} + +var ( + implementations = map[Impl]implFactory{} + sinks = map[Sink]sinkFactory{} + globalLogger securityLogger = newNopLogger() + globalCloser io.Closer + globalSetup *loggerSetup + writeFailures int + failed bool + lock sync.Mutex +) + +// maxWriteFailures is the number of consecutive write failures +// tolerated before the security logger enters the failed state and +// is automatically disabled. +const maxWriteFailures = 3 + // Setup stores the logger configuration and attempts to enable the // security logger immediately. If the log sink cannot be opened (e.g. // because the journal namespace is not active yet), the configuration @@ -162,16 +220,19 @@ func Setup(impl Impl, sink Sink, appID string, minLevel Level) error { lock.Lock() defer lock.Unlock() - if _, exists := providers[impl]; !exists { + if _, exists := implementations[impl]; !exists { return fmt.Errorf("cannot set up security logger: unknown implementation %q", string(impl)) } + if _, exists := sinks[sink]; !exists { return fmt.Errorf("cannot set up security logger: unknown sink %q", string(sink)) } + globalSetup = &loggerSetup{impl: impl, sink: sink, appID: appID, minLevel: minLevel} if err := enableLocked(); err != nil { - return fmt.Errorf("security logger disabled") + return fmt.Errorf("security logger disabled: %v", err) } + return nil } @@ -191,100 +252,123 @@ func Enable() error { // Disable closes the security log sink and resets the global logger to nop. // The stored configuration is retained so that Enable can re-open the sink -// later. It is safe to call even if the logger is already a nop. +// later. Returns an error if Setup has not been called or if the sink +// cannot be closed. func Disable() error { lock.Lock() defer lock.Unlock() + if globalSetup == nil { - return nil + return fmt.Errorf("cannot disable security logger: setup has not been called") } - return closeSinkLocked() -} - -// LogLoggingEnabled logs that security auditing has been enabled. -func LogLoggingEnabled() { - lock.Lock() - defer lock.Unlock() - globalLogger.LogLoggingEnabled() -} - -// LogLoggingDisabled logs that security auditing has been disabled. -func LogLoggingDisabled() { - lock.Lock() - defer lock.Unlock() globalLogger.LogLoggingDisabled() + logger.Noticef("security logger disabled") + return closeSinkLocked() } // LogLoginSuccess logs a successful login using the global security logger. func LogLoginSuccess(user SnapdUser) { lock.Lock() defer lock.Unlock() + globalLogger.LogLoginSuccess(user) } // LogLoginFailure logs a failed login attempt using the global security logger. -func LogLoginFailure(user SnapdUser) { +func LogLoginFailure(user SnapdUser, reason Reason) { lock.Lock() defer lock.Unlock() - globalLogger.LogLoginFailure(user) + + globalLogger.LogLoginFailure(user, reason) } -// register makes a provider available by name. -// Should be called from init(). -func register(p provider) { +// registerImpl makes a logger factory available by name. +// The registration pattern allows implementations to be conditionally +// compiled via build tags without requiring the core package to +// import them directly. +// Should be called from the init() of the implementation file. +func registerImpl(name Impl, factory implFactory) { lock.Lock() defer lock.Unlock() - impl := p.Impl() - if _, exists := providers[impl]; exists { - panic(fmt.Sprintf("attempting registration for existing logger %q", impl)) + + if _, exists := implementations[name]; exists { + panic(fmt.Sprintf("attempting re-registration for existing logger %q", name)) } - providers[impl] = p + implementations[name] = factory } // registerSink makes a sink factory available by name. -// Should be called from init(). -func registerSink(name Sink, factory func(string) (io.Writer, error)) { +// The registration pattern allows sinks to be conditionally compiled +// via build tags without requiring the core package to import them +// directly. +// Should be called from the init() of the sink file. +func registerSink(name Sink, factory sinkFactory) { lock.Lock() defer lock.Unlock() + if _, exists := sinks[name]; exists { - panic(fmt.Sprintf("attempting registration for existing sink %q", name)) + panic(fmt.Sprintf("attempting re-registration for existing sink %q", name)) } sinks[name] = factory } -// enableLocked resolves the provider, opens the sink, and activates the +// enableLocked resolves the logger factory, opens the sink, and activates the // logger. Must be called with lock held and globalSetup non-nil. func enableLocked() error { - provider, exists := providers[globalSetup.impl] + factory, exists := implementations[globalSetup.impl] if !exists { - return fmt.Errorf("internal error: provider %q missing", string(globalSetup.impl)) + return fmt.Errorf("internal error: implementation %q missing", string(globalSetup.impl)) } + newSink, exists := sinks[globalSetup.sink] if !exists { return fmt.Errorf("internal error: sink %q missing", string(globalSetup.sink)) } + writer, err := openSinkLocked(newSink, globalSetup.appID) if err != nil { return fmt.Errorf("cannot enable security logger: %w", err) } - globalLogger = provider.New(writer, globalSetup.appID, globalSetup.minLevel) + + // Wrap the writer with failure tracking so that repeated write + // errors automatically disable the logger. + tracked := &failureTrackingWriter{ + writer: writer, + writeFailures: &writeFailures, + failed: &failed, + maxFailures: maxWriteFailures, + onThresholdReached: func(failures int, lastErr error) { + logger.Noticef("security logger failed after %d consecutive write errors, disabling (last error: %v)", failures, lastErr) + closeSinkLocked() + }, + } + globalLogger = factory.New(tracked, globalSetup.appID, globalSetup.minLevel) + writeFailures = 0 + failed = false + globalLogger.LogLoggingEnabled() + logger.Noticef("security logger enabled") return nil } // openSinkLocked opens the log sink and manages the closer. Any previously // open sink is closed first. Must be called with lock held. -func openSinkLocked(newSink func(string) (io.Writer, error), appID string) (io.Writer, error) { - writer, err := newSink(appID) +func openSinkLocked(factory sinkFactory, appID string) (io.Writer, error) { + writer, err := factory.Open(appID) if err != nil { return nil, err } + if globalCloser != nil { globalCloser.Close() globalCloser = nil } + + // If the writer also implements io.Closer, track it so + // the sink is closed when replaced or disabled. if closer, ok := writer.(io.Closer); ok { globalCloser = closer } + return writer, nil } @@ -299,3 +383,54 @@ func closeSinkLocked() error { } return nil } + +// levelWriter extends [io.Writer] with per-message level control. Writers +// that implement this interface allow log handlers to set the severity for +// each message before writing. +// +// This interface is defined here rather than in slog.go so that +// [failureTrackingWriter] can implement it without a build-tag +// dependency on log/slog. The slog layer's [levelHandler] and the +// audit sink's [auditWriter] are the primary consumers. +type levelWriter interface { + io.Writer + SetLevel(Level) +} + +// failureTrackingWriter wraps an [io.Writer] and counts consecutive +// write failures. When maxFailures consecutive errors are reached it +// invokes onThresholdReached and marks the logger as failed. +// +// All mutable state (writeFailures, failed) is injected via pointers +// so that the writer does not implicitly depend on package globals. +// The caller must hold [lock] when calling Write; since Write is +// invoked from within a locked Log* call, the lock is already held. +type failureTrackingWriter struct { + writer io.Writer + writeFailures *int + failed *bool + maxFailures int + onThresholdReached func(failures int, lastErr error) +} + +func (w *failureTrackingWriter) Write(p []byte) (int, error) { + n, err := w.writer.Write(p) + if err != nil { + *w.writeFailures++ + if *w.writeFailures >= w.maxFailures && !*w.failed { + *w.failed = true + w.onThresholdReached(*w.writeFailures, err) + } + return n, err + } + *w.writeFailures = 0 + return n, nil +} + +// SetLevel implements [levelWriter] so the tracking wrapper is +// transparent to the [levelHandler]. +func (w *failureTrackingWriter) SetLevel(l Level) { + if lw, ok := w.writer.(levelWriter); ok { + lw.SetLevel(l) + } +} diff --git a/seclog/seclog_test.go b/seclog/seclog_test.go index cd30a8ba870..c741cd90b91 100644 --- a/seclog/seclog_test.go +++ b/seclog/seclog_test.go @@ -1,4 +1,5 @@ // -*- Mode: Go; indent-tabs-mode: t -*- +//go:build go1.21 && !noslog /* * Copyright (C) 2026 Canonical Ltd @@ -24,11 +25,11 @@ import ( "encoding/json" "fmt" "io" - "log/syslog" "testing" . "gopkg.in/check.v1" + "github.com/snapcore/snapd/logger" "github.com/snapcore/snapd/seclog" "github.com/snapcore/snapd/testutil" ) @@ -91,53 +92,90 @@ func (s *SecLogSuite) TestString(c *C) { c.Assert(expected, DeepEquals, obtained) } -func (s *SecLogSuite) TestSyslogPriority(c *C) { - tests := []struct { - level seclog.Level - expected syslog.Priority - }{ - {seclog.LevelDebug - 1, syslog.LOG_DEBUG}, - {seclog.LevelDebug, syslog.LOG_DEBUG}, - {seclog.LevelInfo, syslog.LOG_INFO}, - {seclog.LevelWarn, syslog.LOG_WARNING}, - {seclog.LevelError, syslog.LOG_ERR}, - {seclog.LevelCritical, syslog.LOG_CRIT}, - {seclog.LevelCritical + 1, syslog.LOG_CRIT}, - } - for _, t := range tests { - c.Check(seclog.SyslogPriority(t.level), Equals, t.expected, - Commentf("level %v", t.level)) - } +func (s *SecLogSuite) TestSnapdUserString(c *C) { + // All fields set. + c.Check(seclog.SnapdUser{ + ID: 42, StoreUserEmail: "a@b.com", SystemUserName: "jdoe", + }.String(), Equals, "42:a@b.com:jdoe") + + // All fields zero/empty — all "unknown". + c.Check(seclog.SnapdUser{}.String(), Equals, "unknown:unknown:unknown") + + // Only ID set. + c.Check(seclog.SnapdUser{ID: 7}.String(), Equals, "7:unknown:unknown") + + // Only email set. + c.Check(seclog.SnapdUser{StoreUserEmail: "x@y.z"}.String(), Equals, "unknown:x@y.z:unknown") + + // Only username set. + c.Check(seclog.SnapdUser{SystemUserName: "root"}.String(), Equals, "unknown:unknown:root") +} + +func (s *SecLogSuite) TestReasonString(c *C) { + // Both fields set. + c.Check(seclog.Reason{ + Code: seclog.ReasonInvalidCredentials, Message: "bad password", + }.String(), Equals, "invalid-credentials:bad password") + + // Both fields empty — all "unknown". + c.Check(seclog.Reason{}.String(), Equals, "unknown:unknown") + + // Only code set. + c.Check(seclog.Reason{Code: seclog.ReasonInternal}.String(), Equals, "internal:unknown") + + // Only message set. + c.Check(seclog.Reason{Message: "something broke"}.String(), Equals, "unknown:something broke") } func (s *SecLogSuite) TestRegister(c *C) { - restore := seclog.MockProviders(map[seclog.Impl]seclog.Provider{}) + restore := seclog.MockImplementations(map[seclog.Impl]seclog.ImplFactory{}) defer restore() - seclog.Register(seclog.SlogProvider{}) + seclog.RegisterImpl(seclog.ImplSlog, seclog.SlogImplementation{}) // registering the same implementation again panics - c.Assert(func() { seclog.Register(seclog.SlogProvider{}) }, PanicMatches, - `attempting registration for existing logger "slog"`) + c.Assert(func() { seclog.RegisterImpl(seclog.ImplSlog, seclog.SlogImplementation{}) }, PanicMatches, + `attempting re-registration for existing logger "slog"`) +} + +func (s *SecLogSuite) TestRegisterSinkDuplicatePanics(c *C) { + restore := seclog.MockSinks(map[seclog.Sink]seclog.SinkFactory{}) + defer restore() + + dummy := seclog.SinkFunc(func(string) (io.Writer, error) { return nil, nil }) + seclog.RegisterSink(seclog.SinkAudit, dummy) + + // registering the same sink again panics + c.Assert(func() { seclog.RegisterSink(seclog.SinkAudit, dummy) }, PanicMatches, + `attempting re-registration for existing sink "audit"`) } func (s *SecLogSuite) TestSetupUnknownImpl(c *C) { - restore := seclog.MockProviders(map[seclog.Impl]seclog.Provider{}) + restore := seclog.MockImplementations(map[seclog.Impl]seclog.ImplFactory{}) defer restore() - err := seclog.Setup("unknown", seclog.SinkJournal, s.appID, seclog.LevelInfo) + err := seclog.Setup("unknown", seclog.SinkAudit, s.appID, seclog.LevelInfo) c.Assert(err, ErrorMatches, `cannot set up security logger: unknown implementation "unknown"`) } +func (s *SecLogSuite) TestSetupUnknownSink(c *C) { + restore := seclog.MockSinks(map[seclog.Sink]seclog.SinkFactory{}) + defer restore() + + err := seclog.Setup(seclog.ImplSlog, "unknown", s.appID, seclog.LevelInfo) + c.Assert(err, ErrorMatches, + `cannot set up security logger: unknown sink "unknown"`) +} + func (s *SecLogSuite) TestSetupSinkError(c *C) { restore := seclog.MockNewSink(func(appID string) (io.Writer, error) { return nil, fmt.Errorf("journal unavailable") }) defer restore() - err := seclog.Setup(seclog.ImplSlog, seclog.SinkJournal, s.appID, seclog.LevelInfo) - c.Assert(err, ErrorMatches, "security logger disabled") + err := seclog.Setup(seclog.ImplSlog, seclog.SinkAudit, s.appID, seclog.LevelInfo) + c.Assert(err, ErrorMatches, "security logger disabled: cannot enable security logger: journal unavailable") } func (s *SecLogSuite) TestSetupSuccess(c *C) { @@ -150,7 +188,7 @@ func (s *SecLogSuite) TestSetupSuccess(c *C) { restoreLogger := seclog.MockGlobalLogger(seclog.NewNopLogger()) defer restoreLogger() - err := seclog.Setup(seclog.ImplSlog, seclog.SinkJournal, s.appID, seclog.LevelInfo) + err := seclog.Setup(seclog.ImplSlog, seclog.SinkAudit, s.appID, seclog.LevelInfo) c.Assert(err, IsNil) // verify the logger is functional by logging through it @@ -167,8 +205,11 @@ func (s *SecLogSuite) setupSlogLogger(c *C) { restoreLogger := seclog.MockGlobalLogger(seclog.NewNopLogger()) s.AddCleanup(restoreLogger) - err := seclog.Setup(seclog.ImplSlog, seclog.SinkJournal, s.appID, seclog.LevelInfo) + err := seclog.Setup(seclog.ImplSlog, seclog.SinkAudit, s.appID, seclog.LevelInfo) c.Assert(err, IsNil) + + // Reset buffer after Setup, which logs the "logging enabled" event. + s.buf.Reset() } func (s *SecLogSuite) TestLogLoginSuccess(c *C) { @@ -206,14 +247,14 @@ func (s *SecLogSuite) TestLogLoginFailure(c *C) { StoreUserEmail: "user@example.com", SystemUserName: "jdoe", } - seclog.LogLoginFailure(user) + seclog.LogLoginFailure(user, seclog.Reason{Code: seclog.ReasonInvalidCredentials, Message: "invalid credentials"}) var obtained map[string]any err := json.Unmarshal(s.buf.Bytes(), &obtained) c.Assert(err, IsNil) c.Check(obtained["level"], Equals, "WARN") c.Check(obtained["description"], Equals, - "User 42:user@example.com:jdoe login failure") + "User 42:user@example.com:jdoe login failure: invalid-credentials:invalid credentials") c.Check(obtained["app_id"], Equals, s.appID) c.Check(obtained["category"], Equals, "AUTHN") c.Check(obtained["event"], Equals, "authn_login_failure") @@ -222,6 +263,10 @@ func (s *SecLogSuite) TestLogLoginFailure(c *C) { c.Check(userMap["snapd-user-id"], Equals, float64(42)) c.Check(userMap["store-user-email"], Equals, "user@example.com") c.Check(userMap["system-user-name"], Equals, "jdoe") + errMap, ok := obtained["error"].(map[string]any) + c.Assert(ok, Equals, true) + c.Check(errMap["code"], Equals, seclog.ReasonInvalidCredentials) + c.Check(errMap["message"], Equals, "invalid credentials") c.Check(obtained["type"], Equals, "security") } @@ -243,7 +288,7 @@ func (s *SecLogSuite) TestDisableClosesTheSink(c *C) { restoreLogger := seclog.MockGlobalLogger(seclog.NewNopLogger()) defer restoreLogger() restoreSetup := seclog.MockGlobalSetup( - seclog.NewLoggerSetup(seclog.ImplSlog, seclog.SinkJournal, s.appID, seclog.LevelInfo)) + seclog.NewLoggerSetup(seclog.ImplSlog, seclog.SinkAudit, s.appID, seclog.LevelInfo)) defer restoreSetup() err := seclog.Disable() @@ -251,14 +296,84 @@ func (s *SecLogSuite) TestDisableClosesTheSink(c *C) { c.Check(tracker.closed, Equals, true) } -func (s *SecLogSuite) TestDisableWithNoSinkReturnsNil(c *C) { +func (s *SecLogSuite) TestDisableLogsDisabledEvent(c *C) { + s.setupSlogLogger(c) + + err := seclog.Disable() + c.Assert(err, IsNil) + + var obtained map[string]any + err = json.Unmarshal(s.buf.Bytes(), &obtained) + c.Assert(err, IsNil) + c.Check(obtained["level"], Equals, "CRITICAL") + c.Check(obtained["description"], Equals, "Security logging disabled") + c.Check(obtained["category"], Equals, "SYS") + c.Check(obtained["event"], Equals, "sys_logging_disabled") +} + +func (s *SecLogSuite) TestDisableWithNoSetupReturnsError(c *C) { restoreCloser := seclog.MockGlobalCloser(nil) defer restoreCloser() restoreLogger := seclog.MockGlobalLogger(seclog.NewNopLogger()) defer restoreLogger() + restoreSetup := seclog.MockGlobalSetup(nil) + defer restoreSetup() + + err := seclog.Disable() + c.Assert(err, ErrorMatches, "cannot disable security logger: setup has not been called") +} + +func (s *SecLogSuite) TestEnableWithNoSetupReturnsError(c *C) { + restoreSetup := seclog.MockGlobalSetup(nil) + defer restoreSetup() + + err := seclog.Enable() + c.Assert(err, ErrorMatches, "cannot enable security logger: setup has not been called") +} + +func (s *SecLogSuite) TestEnableWithMissingImpl(c *C) { + restoreSetup := seclog.MockGlobalSetup( + seclog.NewLoggerSetup(seclog.ImplSlog, seclog.SinkAudit, s.appID, seclog.LevelInfo)) + defer restoreSetup() + restoreImpls := seclog.MockImplementations(map[seclog.Impl]seclog.ImplFactory{}) + defer restoreImpls() + + err := seclog.Enable() + c.Assert(err, ErrorMatches, `internal error: implementation "slog" missing`) +} + +func (s *SecLogSuite) TestEnableWithMissingSink(c *C) { + restoreSetup := seclog.MockGlobalSetup( + seclog.NewLoggerSetup(seclog.ImplSlog, seclog.SinkAudit, s.appID, seclog.LevelInfo)) + defer restoreSetup() + restoreSinks := seclog.MockSinks(map[seclog.Sink]seclog.SinkFactory{}) + defer restoreSinks() + + err := seclog.Enable() + c.Assert(err, ErrorMatches, `internal error: sink "audit" missing`) +} + +func (s *SecLogSuite) TestEnableAfterDisable(c *C) { + s.setupSlogLogger(c) err := seclog.Disable() c.Assert(err, IsNil) + s.buf.Reset() + + err = seclog.Enable() + c.Assert(err, IsNil) + s.buf.Reset() + user := seclog.SnapdUser{ + ID: 1, + StoreUserEmail: "a@b.com", + SystemUserName: "u", + } + seclog.LogLoginSuccess(user) + + var obtained map[string]any + err = json.Unmarshal(s.buf.Bytes(), &obtained) + c.Assert(err, IsNil) + c.Check(obtained["event"], Equals, "authn_login_success") } func (s *SecLogSuite) TestDisableIsIdempotent(c *C) { @@ -268,7 +383,7 @@ func (s *SecLogSuite) TestDisableIsIdempotent(c *C) { restoreLogger := seclog.MockGlobalLogger(seclog.NewNopLogger()) defer restoreLogger() restoreSetup := seclog.MockGlobalSetup( - seclog.NewLoggerSetup(seclog.ImplSlog, seclog.SinkJournal, s.appID, seclog.LevelInfo)) + seclog.NewLoggerSetup(seclog.ImplSlog, seclog.SinkAudit, s.appID, seclog.LevelInfo)) defer restoreSetup() err := seclog.Disable() @@ -287,7 +402,7 @@ func (s *SecLogSuite) TestDisablePropagatesError(c *C) { restoreLogger := seclog.MockGlobalLogger(seclog.NewNopLogger()) defer restoreLogger() restoreSetup := seclog.MockGlobalSetup( - seclog.NewLoggerSetup(seclog.ImplSlog, seclog.SinkJournal, s.appID, seclog.LevelInfo)) + seclog.NewLoggerSetup(seclog.ImplSlog, seclog.SinkAudit, s.appID, seclog.LevelInfo)) defer restoreSetup() err := seclog.Disable() @@ -324,13 +439,192 @@ func (s *SecLogSuite) TestSetupClosesPreviousSink(c *C) { defer restoreLogger() // first setup - err := seclog.Setup(seclog.ImplSlog, seclog.SinkJournal, s.appID, seclog.LevelInfo) + err := seclog.Setup(seclog.ImplSlog, seclog.SinkAudit, s.appID, seclog.LevelInfo) c.Assert(err, IsNil) c.Check(first.closed, Equals, false) // second setup should close the first sink - err = seclog.Setup(seclog.ImplSlog, seclog.SinkJournal, s.appID, seclog.LevelInfo) + err = seclog.Setup(seclog.ImplSlog, seclog.SinkAudit, s.appID, seclog.LevelInfo) c.Assert(err, IsNil) c.Check(first.closed, Equals, true) c.Check(second.closed, Equals, false) } + +// countingWriter counts successful writes before switching to errors. +type countingWriter struct { + buf bytes.Buffer + successes int // number of remaining successful writes +} + +func (w *countingWriter) Write(p []byte) (int, error) { + if w.successes > 0 { + w.successes-- + return w.buf.Write(p) + } + return 0, fmt.Errorf("write failed") +} + +func (s *SecLogSuite) TestWriteFailuresDisableAfterThreshold(c *C) { + // Allow LogLoggingEnabled to succeed so writeFailures starts at 0; + // only the test loop writes trigger failures. + cw := &countingWriter{successes: 1} + restore := seclog.MockNewSink(func(appID string) (io.Writer, error) { + return cw, nil + }) + defer restore() + restoreLogger := seclog.MockGlobalLogger(seclog.NewNopLogger()) + defer restoreLogger() + + logBuf, restoreStdLogger := logger.MockLogger() + defer restoreStdLogger() + + err := seclog.Setup(seclog.ImplSlog, seclog.SinkAudit, s.appID, seclog.LevelInfo) + c.Assert(err, IsNil) + logBuf.Reset() + + user := seclog.SnapdUser{ID: 1, SystemUserName: "test"} + + // Exactly maxWriteFailures consecutive failures trigger auto-disable. + for i := 0; i < seclog.MaxWriteFailures; i++ { + seclog.LogLoginSuccess(user) + } + + c.Check(seclog.GetFailed(), Equals, true) + c.Check(seclog.GetWriteFailures(), Equals, seclog.MaxWriteFailures) + c.Check(logBuf.String(), testutil.Contains, + "security logger failed after 3 consecutive write errors, disabling") +} + +func (s *SecLogSuite) TestWriteFailuresDoNotDisableBelowThreshold(c *C) { + // Allow LogLoggingEnabled to succeed so writeFailures starts at 0. + cw := &countingWriter{successes: 1} + restore := seclog.MockNewSink(func(appID string) (io.Writer, error) { + return cw, nil + }) + defer restore() + restoreLogger := seclog.MockGlobalLogger(seclog.NewNopLogger()) + defer restoreLogger() + + err := seclog.Setup(seclog.ImplSlog, seclog.SinkAudit, s.appID, seclog.LevelInfo) + c.Assert(err, IsNil) + + user := seclog.SnapdUser{ID: 1, SystemUserName: "test"} + + // Fewer than maxWriteFailures failures should not trigger auto-disable. + for i := 0; i < seclog.MaxWriteFailures-1; i++ { + seclog.LogLoginSuccess(user) + } + + c.Check(seclog.GetFailed(), Equals, false) + c.Check(seclog.GetWriteFailures(), Equals, seclog.MaxWriteFailures-1) +} + +func (s *SecLogSuite) TestWriteSuccessResetsFailureCount(c *C) { + cw := &countingWriter{successes: 100} + restore := seclog.MockNewSink(func(appID string) (io.Writer, error) { + return cw, nil + }) + defer restore() + restoreLogger := seclog.MockGlobalLogger(seclog.NewNopLogger()) + defer restoreLogger() + + err := seclog.Setup(seclog.ImplSlog, seclog.SinkAudit, s.appID, seclog.LevelInfo) + c.Assert(err, IsNil) + + // Simulate some failures below the threshold. + restoreFailures := seclog.MockWriteFailures(seclog.MaxWriteFailures - 1) + defer restoreFailures() + + user := seclog.SnapdUser{ID: 1, SystemUserName: "test"} + // A successful write resets the counter. + seclog.LogLoginSuccess(user) + + c.Check(seclog.GetWriteFailures(), Equals, 0) + c.Check(seclog.GetFailed(), Equals, false) +} + +func (s *SecLogSuite) TestEnableResetsFailureState(c *C) { + restore := seclog.MockNewSink(func(appID string) (io.Writer, error) { + return s.buf, nil + }) + defer restore() + restoreLogger := seclog.MockGlobalLogger(seclog.NewNopLogger()) + defer restoreLogger() + + err := seclog.Setup(seclog.ImplSlog, seclog.SinkAudit, s.appID, seclog.LevelInfo) + c.Assert(err, IsNil) + + // Simulate a failed state. + restoreFailures := seclog.MockWriteFailures(seclog.MaxWriteFailures) + defer restoreFailures() + restoreFailed := seclog.MockFailed(true) + defer restoreFailed() + + // Re-enable should reset the failure state. + err = seclog.Enable() + c.Assert(err, IsNil) + c.Check(seclog.GetFailed(), Equals, false) + c.Check(seclog.GetWriteFailures(), Equals, 0) +} + +func (s *SecLogSuite) TestEnableLogsToStandardLogger(c *C) { + restore := seclog.MockNewSink(func(appID string) (io.Writer, error) { + return s.buf, nil + }) + defer restore() + restoreLogger := seclog.MockGlobalLogger(seclog.NewNopLogger()) + defer restoreLogger() + + logBuf, restoreStdLogger := logger.MockLogger() + defer restoreStdLogger() + + err := seclog.Setup(seclog.ImplSlog, seclog.SinkAudit, s.appID, seclog.LevelInfo) + c.Assert(err, IsNil) + + c.Check(logBuf.String(), testutil.Contains, "security logger enabled") +} + +func (s *SecLogSuite) TestDisableLogsToStandardLogger(c *C) { + restore := seclog.MockNewSink(func(appID string) (io.Writer, error) { + return s.buf, nil + }) + defer restore() + restoreLogger := seclog.MockGlobalLogger(seclog.NewNopLogger()) + defer restoreLogger() + + err := seclog.Setup(seclog.ImplSlog, seclog.SinkAudit, s.appID, seclog.LevelInfo) + c.Assert(err, IsNil) + + logBuf, restoreStdLogger := logger.MockLogger() + defer restoreStdLogger() + + err = seclog.Disable() + c.Assert(err, IsNil) + + c.Check(logBuf.String(), testutil.Contains, "security logger disabled") +} + +func (s *SecLogSuite) TestFailureTrackingWriterPassesSetLevel(c *C) { + // Use a levelBuf (defined in slog_test.go) which implements + // levelWriter so we can verify SetLevel is called through + // the failureTrackingWriter wrapper. + lb := &levelBuf{} + restore := seclog.MockNewSink(func(appID string) (io.Writer, error) { + return lb, nil + }) + defer restore() + restoreLogger := seclog.MockGlobalLogger(seclog.NewNopLogger()) + defer restoreLogger() + + err := seclog.Setup(seclog.ImplSlog, seclog.SinkAudit, s.appID, seclog.LevelInfo) + c.Assert(err, IsNil) + lb.Reset() + lb.levels = nil + + seclog.LogLoginSuccess(seclog.SnapdUser{ID: 1, SystemUserName: "test"}) + + // The levelHandler should have called SetLevel on the underlying + // levelBuf through the failureTrackingWriter wrapper. + c.Assert(len(lb.levels), Equals, 1) + c.Check(lb.levels[0], Equals, seclog.LevelInfo) +} diff --git a/seclog/slog.go b/seclog/slog.go index 0eda4763e88..16458bd29c6 100644 --- a/seclog/slog.go +++ b/seclog/slog.go @@ -1,4 +1,8 @@ // -*- Mode: Go; indent-tabs-mode: t -*- + +// go1.21 is required for log/slog which was added in Go 1.21. +// See https://go.dev/doc/go1.21#slog +// The noslog tag allows excluding the slog-based logger entirely. //go:build go1.21 && !noslog /* @@ -31,26 +35,21 @@ import ( "github.com/snapcore/snapd/osutil" ) -// slogProvider implements [provider]. -type slogProvider struct{} +// slogImplementation implements [implFactory]. +type slogImplementation struct{} -// Ensure [slogProvider] implements [provider]. -var _ provider = slogProvider{} +// Ensure [slogImplementation] implements [implFactory]. +var _ implFactory = slogImplementation{} // New constructs an slog based [securityLogger] that emits structured JSON to the // provided [io.Writer]. The returned logger enables dynamic level control via // an internal [slog.LevelVar]. -func (slogProvider) New(writer io.Writer, appID string, minLevel Level) securityLogger { +func (slogImplementation) New(writer io.Writer, appID string, minLevel Level) securityLogger { return newSlogLogger(writer, appID, minLevel) } -// Impl returns the implementation. -func (slogProvider) Impl() Impl { - return ImplSlog -} - func init() { - register(slogProvider{}) + registerImpl(ImplSlog, slogImplementation{}) } func newSlogLogger(writer io.Writer, appID string, minLevel Level) securityLogger { @@ -60,6 +59,7 @@ func newSlogLogger(writer io.Writer, appID string, minLevel Level) securityLogge if lw, ok := writer.(levelWriter); ok { handler = newLevelHandler(handler, lw) } + logger := &slogLogger{ // enable dynamic level adjustment levelVar: levelVar, @@ -73,7 +73,7 @@ func newSlogLogger(writer io.Writer, appID string, minLevel Level) securityLogge } // slogLogger implements [securityLogger] and is constructed by the -// [slogProvider]. It wraps a [slog.Logger] and provides the required +// [slogImplementation]. It wraps a [slog.Logger] and provides the required // methods. The logger emits structured JSON with a predefined schema for // built-in attributes and supports dynamic log level control via an internal // [slog.LevelVar]. When used with a [levelWriter] sink, it ensures that @@ -98,7 +98,7 @@ func (l *slogLogger) LogLoggingEnabled() { l.logger.LogAttrs( context.Background(), slog.Level(LevelInfo), - "Security auditing enabled", + "Security logging enabled", slog.Attr{Key: "category", Value: slog.StringValue("SYS")}, slog.Attr{Key: "event", Value: slog.StringValue("sys_logging_enabled")}, ) @@ -109,7 +109,7 @@ func (l *slogLogger) LogLoggingDisabled() { l.logger.LogAttrs( context.Background(), slog.Level(LevelCritical), - "Security auditing disabled", + "Security logging disabled", slog.Attr{Key: "category", Value: slog.StringValue("SYS")}, slog.Attr{Key: "event", Value: slog.StringValue("sys_logging_disabled")}, ) @@ -128,25 +128,30 @@ func (l *slogLogger) LogLoginSuccess(user SnapdUser) { } // LogLoginFailure implements [securityLogger.LogLoginFailure]. -func (l *slogLogger) LogLoginFailure(user SnapdUser) { +func (l *slogLogger) LogLoginFailure(user SnapdUser, reason Reason) { l.logger.LogAttrs( context.Background(), slog.Level(LevelWarn), - fmt.Sprintf("User %s login failure", user.String()), + fmt.Sprintf("User %s login failure: %s", user.String(), reason.String()), slog.Attr{Key: "category", Value: slog.StringValue("AUTHN")}, slog.Attr{Key: "event", Value: slog.StringValue("authn_login_failure")}, slog.Any("user", user), + slog.Any("error", reason), ) } // LogValue implements [slog.LogValuer], allowing SnapdUser to be // used directly as a structured log attribute value. func (u SnapdUser) LogValue() slog.Value { + expiration := "never" + if !u.Expiration.IsZero() { + expiration = u.Expiration.UTC().Format(time.RFC3339Nano) + } return slog.GroupValue( slog.Int64("snapd-user-id", u.ID), slog.String("system-user-name", u.SystemUserName), slog.String("store-user-email", u.StoreUserEmail), - slog.String("expiration", u.Expiration.UTC().Format(time.RFC3339Nano)), + slog.String("expiration", expiration), ) } @@ -190,14 +195,6 @@ func newJsonHandler(writer io.Writer, minLevel slog.Leveler) slog.Handler { return slog.NewJSONHandler(writer, options) } -// levelWriter extends [io.Writer] with per-message level control. Writers -// that implement this interface allow log handlers to set the severity for -// each message before writing. -type levelWriter interface { - io.Writer - SetLevel(Level) -} - // levelHandler is a [slog.Handler] wrapper that sets the level on a // [levelWriter] before each message is handled. This ensures that the // written output carries the correct per-message priority. @@ -221,6 +218,7 @@ func (h *levelHandler) Enabled(ctx context.Context, level slog.Level) bool { func (h *levelHandler) Handle(ctx context.Context, r slog.Record) error { h.mu.Lock() defer h.mu.Unlock() + h.lw.SetLevel(Level(r.Level)) return h.inner.Handle(ctx, r) } @@ -229,6 +227,8 @@ func (h *levelHandler) WithAttrs(attrs []slog.Attr) slog.Handler { return &levelHandler{inner: h.inner.WithAttrs(attrs), lw: h.lw, mu: h.mu} } +// WithGroup is required by the [slog.Handler] interface but is not +// currently used by seclog. func (h *levelHandler) WithGroup(name string) slog.Handler { return &levelHandler{inner: h.inner.WithGroup(name), lw: h.lw, mu: h.mu} } diff --git a/seclog/slog_test.go b/seclog/slog_test.go index 83241f29482..d74a79cf750 100644 --- a/seclog/slog_test.go +++ b/seclog/slog_test.go @@ -25,7 +25,6 @@ import ( "context" "encoding/json" "errors" - "testing" "time" "log/slog" @@ -38,19 +37,17 @@ import ( type SlogSuite struct { testutil.BaseTest - buf *bytes.Buffer - appID string - provider seclog.Provider + buf *bytes.Buffer + appID string + factory seclog.ImplFactory } var _ = Suite(&SlogSuite{}) -func TestSlog(t *testing.T) { TestingT(t) } - func (s *SlogSuite) SetUpSuite(c *C) { s.buf = &bytes.Buffer{} s.appID = "canonical.snapd" - s.provider = seclog.SlogProvider{} + s.factory = seclog.SlogImplementation{} } func (s *SlogSuite) SetUpTest(c *C) { @@ -73,12 +70,9 @@ func extractSlogLogger(logger seclog.SecurityLogger) (*slog.Logger, error) { } } -func (s *SlogSuite) TestSlogProvider(c *C) { - logger := s.provider.New(s.buf, s.appID, seclog.LevelInfo) +func (s *SlogSuite) TestSlogImplementation(c *C) { + logger := s.factory.New(s.buf, s.appID, seclog.LevelInfo) c.Check(logger, NotNil) - - impl := s.provider.Impl() - c.Check(impl, Equals, seclog.ImplSlog) } // baseAttrs represents the non-optional attributes that is present in @@ -136,7 +130,7 @@ type attrsAllTypes struct { } func (s *SlogSuite) TestHandlerAttrsAllTypes(c *C) { - logger := s.provider.New(s.buf, s.appID, seclog.LevelInfo) + logger := s.factory.New(s.buf, s.appID, seclog.LevelInfo) c.Assert(logger, NotNil) sl, err := extractSlogLogger(logger) @@ -183,7 +177,7 @@ func (s *SlogSuite) TestHandlerAttrsAllTypes(c *C) { } func (s *SlogSuite) TestLogLoginSuccess(c *C) { - logger := s.provider.New(s.buf, s.appID, seclog.LevelInfo) + logger := s.factory.New(s.buf, s.appID, seclog.LevelInfo) c.Assert(logger, NotNil) type LoginSuccess struct { @@ -215,6 +209,7 @@ func (s *SlogSuite) TestLogLoginSuccess(c *C) { c.Check(obtained.User.ID, Equals, int64(42)) c.Check(obtained.User.StoreUserEmail, Equals, "user@gmail.com") c.Check(obtained.User.SystemUserName, Equals, "jdoe") + c.Check(obtained.User.Expiration, Equals, "never") // verify key order for human readability keys, err := orderedKeys(s.buf.Bytes()) @@ -225,8 +220,38 @@ func (s *SlogSuite) TestLogLoginSuccess(c *C) { }) } +func (s *SlogSuite) TestLogLoginSuccessWithExpiration(c *C) { + logger := s.factory.New(s.buf, s.appID, seclog.LevelInfo) + c.Assert(logger, NotNil) + + type LoginSuccess struct { + baseAttrs + Event string `json:"event"` + User struct { + ID int64 `json:"snapd-user-id"` + SystemUserName string `json:"system-user-name"` + StoreUserEmail string `json:"store-user-email"` + Expiration string `json:"expiration"` + } `json:"user"` + } + + expiry := time.Date(2026, 6, 15, 12, 0, 0, 0, time.UTC) + user := seclog.SnapdUser{ + ID: 42, + StoreUserEmail: "user@gmail.com", + SystemUserName: "jdoe", + Expiration: expiry, + } + logger.LogLoginSuccess(user) + + var obtained LoginSuccess + err := json.Unmarshal(s.buf.Bytes(), &obtained) + c.Assert(err, IsNil) + c.Check(obtained.User.Expiration, Equals, "2026-06-15T12:00:00Z") +} + func (s *SlogSuite) TestLogLoginFailure(c *C) { - logger := s.provider.New(s.buf, s.appID, seclog.LevelInfo) + logger := s.factory.New(s.buf, s.appID, seclog.LevelInfo) c.Assert(logger, NotNil) type loginFailure struct { @@ -238,6 +263,10 @@ func (s *SlogSuite) TestLogLoginFailure(c *C) { StoreUserEmail string `json:"store-user-email"` Expiration string `json:"expiration"` } `json:"user"` + Error struct { + Code string `json:"code"` + Message string `json:"message"` + } `json:"error"` } user := seclog.SnapdUser{ @@ -245,62 +274,59 @@ func (s *SlogSuite) TestLogLoginFailure(c *C) { StoreUserEmail: "user@gmail.com", SystemUserName: "jdoe", } - logger.LogLoginFailure(user) + logger.LogLoginFailure(user, seclog.Reason{Code: seclog.ReasonInvalidCredentials, Message: "invalid credentials"}) var obtained loginFailure err := json.Unmarshal(s.buf.Bytes(), &obtained) c.Assert(err, IsNil) c.Check(time.Since(obtained.Datetime) < time.Second, Equals, true) c.Check(obtained.Level, Equals, "WARN") - c.Check(obtained.Description, Equals, "User 42:user@gmail.com:jdoe login failure") + c.Check(obtained.Description, Equals, "User 42:user@gmail.com:jdoe login failure: invalid-credentials:invalid credentials") c.Check(obtained.AppID, Equals, s.appID) c.Check(obtained.Event, Equals, "authn_login_failure") c.Check(obtained.User.ID, Equals, int64(42)) c.Check(obtained.User.StoreUserEmail, Equals, "user@gmail.com") c.Check(obtained.User.SystemUserName, Equals, "jdoe") + c.Check(obtained.User.Expiration, Equals, "never") + c.Check(obtained.Error.Code, Equals, seclog.ReasonInvalidCredentials) + c.Check(obtained.Error.Message, Equals, "invalid credentials") // verify key order for human readability keys, err := orderedKeys(s.buf.Bytes()) c.Assert(err, IsNil) c.Check(keys, DeepEquals, []string{ "datetime", "level", "description", - "app_id", "type", "category", "event", "user", + "app_id", "type", "category", "event", "user", "error", }) } -func (s *SlogSuite) TestLevelWriterSink(c *C) { - // wrap buffer in a journalWriter to exercise the levelWriter - // branch in newSlogLogger and the levelHandler wrapper - jw := seclog.NewJournalWriter(s.buf) - logger := s.provider.New(jw, s.appID, seclog.LevelInfo) - c.Assert(logger, NotNil) +// levelBuf is a bytes.Buffer that also implements [seclog.LevelWriter], +// recording the level set before each log message is written. +type levelBuf struct { + bytes.Buffer + levels []seclog.Level +} - admin := seclog.SnapdUser{ - ID: 1, - SystemUserName: "admin", - } - logger.LogLoginSuccess(admin) +func (lb *levelBuf) SetLevel(l seclog.Level) { + lb.levels = append(lb.levels, l) +} - // the journalWriter prepends a syslog priority prefix - raw := s.buf.String() - // INFO maps to syslog.LOG_INFO (6) - c.Check(raw[:3], Equals, "<6>") +// Ensure levelBuf satisfies the interface. +var _ seclog.LevelWriter = (*levelBuf)(nil) - // the JSON payload follows the prefix - var obtained map[string]any - err := json.Unmarshal([]byte(raw[3:]), &obtained) +func (s *SlogSuite) TestLevelHandlerSetsLevelBeforeWrite(c *C) { + lb := &levelBuf{} + logger := seclog.SlogImplementation{}.New(lb, s.appID, seclog.LevelInfo) + + slogLogger, err := extractSlogLogger(logger) c.Assert(err, IsNil) - c.Check(obtained["level"], Equals, "INFO") - c.Check(obtained["event"], Equals, "authn_login_success") - userMap, ok := obtained["user"].(map[string]any) - c.Assert(ok, Equals, true) - c.Check(userMap["system-user-name"], Equals, "admin") - // log a WARN-level message and verify the prefix changes - s.buf.Reset() - logger.LogLoginFailure(admin) + // Use seclog level values cast to slog.Level so they pass the + // level threshold set by newSlogLogger (slog.Level(seclog.LevelInfo)). + slogLogger.Log(context.Background(), slog.Level(seclog.LevelInfo), "info message") + slogLogger.Log(context.Background(), slog.Level(seclog.LevelWarn), "warn message") - raw = s.buf.String() - // WARN maps to syslog.LOG_WARNING (4) - c.Check(raw[:3], Equals, "<4>") + c.Assert(len(lb.levels), Equals, 2) + c.Check(lb.levels[0], Equals, seclog.LevelInfo) + c.Check(lb.levels[1], Equals, seclog.LevelWarn) } diff --git a/tests/main/security-logging/task.yaml b/tests/main/security-logging/task.yaml new file mode 100644 index 00000000000..e1e23e13860 --- /dev/null +++ b/tests/main/security-logging/task.yaml @@ -0,0 +1,39 @@ +summary: Checks that security audit events are written to the kernel audit log + +details: | + The snapd daemon writes structured security audit events via the kernel + audit subsystem (AUDIT_TRUSTED_APP, type 1121). This test verifies that + a failed login attempt produces an "authn_login_failure" event and, + when store credentials are available, that a successful login produces + an "authn_login_success" event in the audit log. + +prepare: | + # Create an audit checkpoint so we only see events from this test. + ausearch --checkpoint stamp -m 1121 || true + +restore: | + snap logout || true + rm -f stamp + +execute: | + echo "Checking that a failed login attempt produces an audit event" + echo '{"email":"someemail@testing.com","password":"wrong-password"}' | \ + snap debug api -X POST -H 'Content-Type: application/json' /v2/login || true + + # The audit log entry is the raw JSON payload sent by snapd. + ausearch --start checkpoint --checkpoint stamp -m 1121 --raw 2>&1 | MATCH 'authn_login_failure' + ausearch --start checkpoint --checkpoint stamp -m 1121 --raw 2>&1 | MATCH 'invalid-credentials' + ausearch --start checkpoint --checkpoint stamp -m 1121 --raw 2>&1 | MATCH 'someemail@testing.com' + + if [ -n "$SPREAD_STORE_USER" ] && [ -n "$SPREAD_STORE_PASSWORD" ]; then + echo "Checking that a successful login produces an audit event" + # Reset the checkpoint so we only see the success event. + ausearch --checkpoint stamp -m 1121 || true + + expect -d -f "$TESTSLIB"/successful_login.exp + + ausearch --start checkpoint --checkpoint stamp -m 1121 --raw 2>&1 | MATCH 'authn_login_success' + ausearch --start checkpoint --checkpoint stamp -m 1121 --raw 2>&1 | MATCH "$SPREAD_STORE_USER" + + snap logout + fi diff --git a/wrappers/core18.go b/wrappers/core18.go index 9062546cbe4..d0ec1f8ac2c 100644 --- a/wrappers/core18.go +++ b/wrappers/core18.go @@ -379,75 +379,9 @@ func AddSnapdSnapServices(s *snap.Info, opts *AddSnapdSnapServicesOptions, inter return err } - // Handle the security log journal namespace - if err := writeSnapdSecurityJournalOnCore(s, sysd); err != nil { - return err - } - return nil } -const securityJournalConfFile = "journald@snapd-security.conf" -const securityJournalDropInDir = "systemd-journald@snapd-security.service.d" -const securityJournalDropInFile = "00-snapd.conf" -const securityJournalSnapdDropInDir = "snapd.service.d" -const securityJournalSnapdDropInFile = "security-journal.conf" - -// writeSnapdSecurityJournalOnCore installs the journald namespace -// configuration and drop-in files for the snapd security log from -// the snapd snap onto the host filesystem. The namespace socket is -// pulled in via a snapd.service.d drop-in (Wants + After) so that -// it starts alongside snapd but cannot cause snapd to fail. -func writeSnapdSecurityJournalOnCore(s *snap.Info, sysd systemd.Systemd) error { - // Install the journald namespace config to /etc/systemd/ - srcConf := filepath.Join(s.MountDir(), "etc/systemd", securityJournalConfFile) - dstConf := filepath.Join(dirs.SnapSystemdDir, securityJournalConfFile) - if err := copyFileIfChanged(srcConf, dstConf); err != nil { - return fmt.Errorf("cannot install security journal config: %v", err) - } - - // Install the journald service drop-in to /etc/systemd/system/ - srcDropIn := filepath.Join(s.MountDir(), "lib/systemd/system", securityJournalDropInDir, securityJournalDropInFile) - dstDropInDir := filepath.Join(dirs.SnapServicesDir, securityJournalDropInDir) - if err := os.MkdirAll(dstDropInDir, 0755); err != nil { - return fmt.Errorf("cannot create security journal drop-in dir: %v", err) - } - dstDropIn := filepath.Join(dstDropInDir, securityJournalDropInFile) - if err := copyFileIfChanged(srcDropIn, dstDropIn); err != nil { - return fmt.Errorf("cannot install security journal service drop-in: %v", err) - } - - // Install the snapd.service drop-in to pull in the journal socket - srcSnapdDropIn := filepath.Join(s.MountDir(), "lib/systemd/system", securityJournalSnapdDropInDir, securityJournalSnapdDropInFile) - dstSnapdDropInDir := filepath.Join(dirs.SnapServicesDir, securityJournalSnapdDropInDir) - if err := os.MkdirAll(dstSnapdDropInDir, 0755); err != nil { - return fmt.Errorf("cannot create snapd.service.d dir: %v", err) - } - dstSnapdDropIn := filepath.Join(dstSnapdDropInDir, securityJournalSnapdDropInFile) - if err := copyFileIfChanged(srcSnapdDropIn, dstSnapdDropIn); err != nil { - return fmt.Errorf("cannot install snapd service drop-in for security journal: %v", err) - } - - return sysd.DaemonReload() -} - -// copyFileIfChanged copies src to dst only if the contents differ or -// dst does not exist. Returns nil if dst is already up to date. -func copyFileIfChanged(src, dst string) error { - srcContent, err := os.ReadFile(src) - if err != nil { - return err - } - err = osutil.EnsureFileState(dst, &osutil.MemoryFileState{ - Content: srcContent, - Mode: 0644, - }) - if err == osutil.ErrSameState { - return nil - } - return err -} - // undoSnapdUserServicesOnCore attempts to remove services that were deployed in // the filesystem as part of snapd snap installation. This should only be // executed as part of a controlled undo path. diff --git a/wrappers/core18_test.go b/wrappers/core18_test.go index 6c1eb81ffe4..595dc0860c4 100644 --- a/wrappers/core18_test.go +++ b/wrappers/core18_test.go @@ -124,19 +124,6 @@ func makeMockSnapdSnapWithOverrides(c *C, metaSnapYaml string, extra [][]string) "[Desktop Entry]\n" + "Name=Handler for snap:// URIs", }, - // security journal namespace files - {"etc/systemd/journald@snapd-security.conf", "" + - "[Journal]\nStorage=persistent\nCompress=yes\n" + - "SystemMaxFileSize=10M\nSystemMaxUse=10M\n" + - "SyncIntervalSec=30s\nSyncOnShutdown=yes\n", - }, - {"lib/systemd/system/systemd-journald@snapd-security.service.d/00-snapd.conf", "" + - "[Service]\nLogsDirectory=\n", - }, - {"lib/systemd/system/snapd.service.d/security-journal.conf", "" + - "[Unit]\nWants=systemd-journald@snapd-security.socket\n" + - "After=systemd-journald@snapd-security.socket\n", - }, } content := append(defaultContent, extra...) @@ -258,21 +245,6 @@ WantedBy=snapd.service }, { filepath.Join(dirs.SnapDesktopFilesDir, "snap-handle-link.desktop"), "[Desktop Entry]\nName=Handler for snap:// URIs", - }, { - // check that security journal config is installed - filepath.Join(dirs.SnapSystemdDir, "journald@snapd-security.conf"), - "[Journal]\nStorage=persistent\nCompress=yes\n" + - "SystemMaxFileSize=10M\nSystemMaxUse=10M\n" + - "SyncIntervalSec=30s\nSyncOnShutdown=yes\n", - }, { - // check that security journal service drop-in is installed - filepath.Join(dirs.SnapServicesDir, "systemd-journald@snapd-security.service.d/00-snapd.conf"), - "[Service]\nLogsDirectory=\n", - }, { - // check that snapd.service.d drop-in for security journal is installed - filepath.Join(dirs.SnapServicesDir, "snapd.service.d/security-journal.conf"), - "[Unit]\nWants=systemd-journald@snapd-security.socket\n" + - "After=systemd-journald@snapd-security.socket\n", }} { c.Check(entry[0], testutil.FileEquals, entry[1]) } @@ -306,8 +278,6 @@ WantedBy=snapd.service {"--user", "--global", "--no-reload", "disable", "snapd.session-agent.socket"}, {"--user", "--global", "--no-reload", "enable", "snapd.session-agent.socket"}, {"--user", "daemon-reload"}, - // security journal files installed - {"daemon-reload"}, }) } @@ -351,8 +321,6 @@ type: snapd {"--user", "--global", "--no-reload", "disable", "snapd.session-agent.socket"}, {"--user", "--global", "--no-reload", "enable", "snapd.session-agent.socket"}, {"--user", "daemon-reload"}, - // security journal files installed - {"daemon-reload"}, } s.testAddSnapServicesOperationsWithQuirks(c, quirkySnapdYaml, extras, expectedOps) @@ -404,8 +372,6 @@ type: snapd {"--user", "--global", "--no-reload", "disable", "snapd.session-agent.socket"}, {"--user", "--global", "--no-reload", "enable", "snapd.session-agent.socket"}, {"--user", "daemon-reload"}, - // security journal files installed - {"daemon-reload"}, } s.testAddSnapServicesOperationsWithQuirks(c, quirkySnapdYaml, extras, expectedOps) From 15c1f6c227f44f8d9619c4111a5eae55d167a240 Mon Sep 17 00:00:00 2001 From: ernestl Date: Thu, 23 Apr 2026 22:53:20 +0200 Subject: [PATCH 21/21] many: review improvements --- cmd/snapd/main.go | 2 +- daemon/api_users.go | 10 ++- daemon/api_users_test.go | 4 +- seclog/audit_linux.go | 8 +-- seclog/seclog.go | 64 +++++++++-------- seclog/seclog_test.go | 66 ++++++++++++----- seclog/slog.go | 12 ++-- seclog/slog_test.go | 25 +++---- tests/main/security-logging/task.yaml | 100 ++++++++++++++++++++++---- 9 files changed, 200 insertions(+), 91 deletions(-) diff --git a/cmd/snapd/main.go b/cmd/snapd/main.go index 7214ecfc43b..70b9c666ac0 100644 --- a/cmd/snapd/main.go +++ b/cmd/snapd/main.go @@ -63,10 +63,10 @@ func disableSecurityLogger() { func init() { logger.SimpleSetup(nil) - setupSecurityLogger() } func main() { + setupSecurityLogger() defer disableSecurityLogger() // When preseeding re-exec is not used diff --git a/daemon/api_users.go b/daemon/api_users.go index ba27164d476..df0da136130 100644 --- a/daemon/api_users.go +++ b/daemon/api_users.go @@ -135,7 +135,7 @@ func loginUser(c *Command, r *http.Request, user *auth.UserState) Response { // point we know the email and optional username; the numeric ID // is only available after successful authentication. snapdUser := seclog.SnapdUser{ - SystemUserName: loginData.Username, + StoreUserName: loginData.Username, StoreUserEmail: loginData.Email, } @@ -173,7 +173,11 @@ func loginUser(c *Command, r *http.Request, user *auth.UserState) Response { Value: err, }, snapdUser, seclog.ReasonPasswordPolicy) } - return loginError(Unauthorized(err.Error()), snapdUser, seclog.ReasonInvalidCredentials) + reason := seclog.ReasonInternal + if err == store.ErrInvalidCredentials { + reason = seclog.ReasonInvalidCredentials + } + return loginError(Unauthorized(err.Error()), snapdUser, reason) case nil: // continue } @@ -199,7 +203,7 @@ func loginUser(c *Command, r *http.Request, user *auth.UserState) Response { } snapdUser.ID = int64(user.ID) - snapdUser.SystemUserName = user.Username + snapdUser.StoreUserName = user.Username snapdUser.StoreUserEmail = user.Email snapdUser.Expiration = user.Expiration seclogLogLoginSuccess(snapdUser) diff --git a/daemon/api_users_test.go b/daemon/api_users_test.go index 2e0ef43b241..d2b382ae29f 100644 --- a/daemon/api_users_test.go +++ b/daemon/api_users_test.go @@ -209,7 +209,7 @@ func (s *userSuite) TestLoginUserWithUsername(c *check.C) { // security log was called with the right user details c.Check(loggedUser.ID, check.Equals, int64(1)) - c.Check(loggedUser.SystemUserName, check.Equals, "username") + c.Check(loggedUser.StoreUserName, check.Equals, "username") c.Check(loggedUser.StoreUserEmail, check.Equals, "email@.com") } @@ -544,7 +544,7 @@ func (s *userSuite) TestLoginUserPersistError(c *check.C) { c.Check(rspe.Message, check.Matches, "cannot persist authentication details: .*") c.Check(loggedUser.StoreUserEmail, check.Equals, "email@.com") - c.Check(loggedUser.SystemUserName, check.Equals, "username") + c.Check(loggedUser.StoreUserName, check.Equals, "username") c.Check(loggedReason.Message, check.Matches, "cannot persist authentication details: .*") } diff --git a/seclog/audit_linux.go b/seclog/audit_linux.go index ac66ac66ca8..022546a9b28 100644 --- a/seclog/audit_linux.go +++ b/seclog/audit_linux.go @@ -100,7 +100,7 @@ func (auditSinkFactory) Open(_ string) (io.Writer, error) { // SOCK_CLOEXEC prevents the fd from leaking to child processes. fd, err := netlink.Socket(syscall.AF_NETLINK, syscall.SOCK_RAW|syscall.SOCK_CLOEXEC, netlinkAudit) if err != nil { - return nil, fmt.Errorf("cannot open audit socket: %w", err) + return nil, fmt.Errorf("cannot open audit socket: %v", err) } addr := &syscall.SockaddrNetlink{ Family: syscall.AF_NETLINK, @@ -109,12 +109,12 @@ func (auditSinkFactory) Open(_ string) (io.Writer, error) { } if err := netlink.Bind(fd, addr); err != nil { netlink.Close(fd) - return nil, fmt.Errorf("cannot bind audit socket: %w", err) + return nil, fmt.Errorf("cannot bind audit socket: %v", err) } portID, err := getPortID(fd) if err != nil { netlink.Close(fd) - return nil, fmt.Errorf("cannot get audit socket port ID: %w", err) + return nil, fmt.Errorf("cannot get audit socket port ID: %v", err) } return &auditWriter{fd: fd, portID: portID}, nil } @@ -155,7 +155,7 @@ func (aw *auditWriter) Write(payload []byte) (int, error) { Pid: 0, // kernel } if err := netlink.Sendto(aw.fd, msg, 0, addr); err != nil { - return 0, fmt.Errorf("cannot send audit message: %w", err) + return 0, fmt.Errorf("cannot send audit message: %v", err) } return len(payload), nil } diff --git a/seclog/seclog.go b/seclog/seclog.go index fa6525873db..efbebbdf824 100644 --- a/seclog/seclog.go +++ b/seclog/seclog.go @@ -41,15 +41,9 @@ const ( LevelCritical Level = 5 ) -// String returns a name for the level. -// If the level has a name, then that name -// in uppercase is returned. -// If the level is between named values, then -// an integer is appended to the uppercased name. -// Examples: -// -// LevelWarn.String() => "WARN" -// (LevelCritical+2).String() => "CRITICAL+2" +// String returns a name for the level. If the level has a name, then that name +// in uppercase is returned. If the level is between named values, then an +// integer is appended to the uppercased name. func (l Level) String() string { str := func(base string, val Level) string { if val == 0 { @@ -81,7 +75,8 @@ const ( ImplSlog Impl = "slog" // slog based structured logger ) -// Sink identifies a log output destination. +// Sink identifies a log output destination used for +// registration and selection of security log sinks. type Sink string // Sink types. @@ -95,13 +90,13 @@ const ( // zero-value datetime. type SnapdUser struct { ID int64 `json:"snapd-user-id"` - SystemUserName string `json:"system-user-name"` + StoreUserName string `json:"store-user-name"` StoreUserEmail string `json:"store-user-email"` Expiration time.Time `json:"expiration"` } // String returns a colon-separated description of the user in the form -// "::". Fields that are unset use +// "::". Fields that are unset use // "unknown" as a placeholder. A zero ID is treated as unset. func (u SnapdUser) String() string { const unknown = "unknown" @@ -117,8 +112,8 @@ func (u SnapdUser) String() string { } name := unknown - if u.SystemUserName != "" { - name = u.SystemUserName + if u.StoreUserName != "" { + name = u.StoreUserName } return id + ":" + email + ":" + name @@ -163,6 +158,8 @@ func (r Reason) String() string { // audit events. Implementations are created by an [implFactory] and write // to a configured sink. type securityLogger interface { + // LogLoggingEnabled and LogLoggingDisabled are internal events + // emitted automatically when the logger is enabled or disabled. LogLoggingEnabled() LogLoggingDisabled() LogLoginSuccess(user SnapdUser) @@ -213,9 +210,12 @@ const maxWriteFailures = 3 // Setup stores the logger configuration and attempts to enable the // security logger immediately. If the log sink cannot be opened (e.g. -// because the journal namespace is not active yet), the configuration -// is still stored and a non-fatal "security logger disabled" error is -// returned. A subsequent call to Enable will re-attempt activation. +// because the sink is not yet available), the configuration is still +// stored and a non-fatal "security logger disabled" error is returned. +// A subsequent call to Enable will re-attempt activation. +// +// Although Setup is reentrant, it is intended to be called exactly +// once per application, typically during early initialization. func Setup(impl Impl, sink Sink, appID string, minLevel Level) error { lock.Lock() defer lock.Unlock() @@ -238,8 +238,8 @@ func Setup(impl Impl, sink Sink, appID string, minLevel Level) error { // Enable opens the security log sink using the configuration stored by Setup, // activating the security logger. If the sink is already open, it is closed -// and re-opened, refreshing the connection to the journal namespace. -// Returns an error if Setup has not been called or if the sink cannot be opened. +// and re-opened, refreshing the connection to the sink. Returns an error if +// Setup has not been called or if the sink cannot be opened. func Enable() error { lock.Lock() defer lock.Unlock() @@ -252,14 +252,14 @@ func Enable() error { // Disable closes the security log sink and resets the global logger to nop. // The stored configuration is retained so that Enable can re-open the sink -// later. Returns an error if Setup has not been called or if the sink -// cannot be closed. +// later. If Setup has not been called, Disable is a no-op. Returns an error +// if the sink cannot be closed. func Disable() error { lock.Lock() defer lock.Unlock() if globalSetup == nil { - return fmt.Errorf("cannot disable security logger: setup has not been called") + return nil } globalLogger.LogLoggingDisabled() logger.Noticef("security logger disabled") @@ -282,10 +282,10 @@ func LogLoginFailure(user SnapdUser, reason Reason) { globalLogger.LogLoginFailure(user, reason) } -// registerImpl makes a logger factory available by name. -// The registration pattern allows implementations to be conditionally -// compiled via build tags without requiring the core package to -// import them directly. +// registerImpl makes a logger factory available by name. The registration +// pattern allows implementations to be conditionally compiled via build tags +// without requiring the core package to import them directly. +// // Should be called from the init() of the implementation file. func registerImpl(name Impl, factory implFactory) { lock.Lock() @@ -297,10 +297,10 @@ func registerImpl(name Impl, factory implFactory) { implementations[name] = factory } -// registerSink makes a sink factory available by name. -// The registration pattern allows sinks to be conditionally compiled -// via build tags without requiring the core package to import them -// directly. +// registerSink makes a sink factory available by name. The registration +// pattern allows sinks to be conditionally compiled via build tags without +// requiring the core package to import them directly. +// // Should be called from the init() of the sink file. func registerSink(name Sink, factory sinkFactory) { lock.Lock() @@ -327,7 +327,7 @@ func enableLocked() error { writer, err := openSinkLocked(newSink, globalSetup.appID) if err != nil { - return fmt.Errorf("cannot enable security logger: %w", err) + return fmt.Errorf("cannot enable security logger: %v", err) } // Wrap the writer with failure tracking so that repeated write @@ -413,6 +413,8 @@ type failureTrackingWriter struct { onThresholdReached func(failures int, lastErr error) } +var _ levelWriter = (*failureTrackingWriter)(nil) + func (w *failureTrackingWriter) Write(p []byte) (int, error) { n, err := w.writer.Write(p) if err != nil { diff --git a/seclog/seclog_test.go b/seclog/seclog_test.go index c741cd90b91..d7fae15bf33 100644 --- a/seclog/seclog_test.go +++ b/seclog/seclog_test.go @@ -95,7 +95,7 @@ func (s *SecLogSuite) TestString(c *C) { func (s *SecLogSuite) TestSnapdUserString(c *C) { // All fields set. c.Check(seclog.SnapdUser{ - ID: 42, StoreUserEmail: "a@b.com", SystemUserName: "jdoe", + ID: 42, StoreUserEmail: "a@b.com", StoreUserName: "jdoe", }.String(), Equals, "42:a@b.com:jdoe") // All fields zero/empty — all "unknown". @@ -108,7 +108,7 @@ func (s *SecLogSuite) TestSnapdUserString(c *C) { c.Check(seclog.SnapdUser{StoreUserEmail: "x@y.z"}.String(), Equals, "unknown:x@y.z:unknown") // Only username set. - c.Check(seclog.SnapdUser{SystemUserName: "root"}.String(), Equals, "unknown:unknown:root") + c.Check(seclog.SnapdUser{StoreUserName: "root"}.String(), Equals, "unknown:unknown:root") } func (s *SecLogSuite) TestReasonString(c *C) { @@ -127,7 +127,7 @@ func (s *SecLogSuite) TestReasonString(c *C) { c.Check(seclog.Reason{Message: "something broke"}.String(), Equals, "unknown:something broke") } -func (s *SecLogSuite) TestRegister(c *C) { +func (s *SecLogSuite) TestRegisterImpl(c *C) { restore := seclog.MockImplementations(map[seclog.Impl]seclog.ImplFactory{}) defer restore() @@ -138,7 +138,7 @@ func (s *SecLogSuite) TestRegister(c *C) { `attempting re-registration for existing logger "slog"`) } -func (s *SecLogSuite) TestRegisterSinkDuplicatePanics(c *C) { +func (s *SecLogSuite) TestRegisterSinkDuplicate(c *C) { restore := seclog.MockSinks(map[seclog.Sink]seclog.SinkFactory{}) defer restore() @@ -192,7 +192,7 @@ func (s *SecLogSuite) TestSetupSuccess(c *C) { c.Assert(err, IsNil) // verify the logger is functional by logging through it - seclog.LogLoginSuccess(seclog.SnapdUser{ID: 1, SystemUserName: "testuser"}) + seclog.LogLoginSuccess(seclog.SnapdUser{ID: 1, StoreUserName: "testuser"}) c.Check(s.buf.Len() > 0, Equals, true) } @@ -218,7 +218,7 @@ func (s *SecLogSuite) TestLogLoginSuccess(c *C) { user := seclog.SnapdUser{ ID: 42, StoreUserEmail: "user@example.com", - SystemUserName: "jdoe", + StoreUserName: "jdoe", } seclog.LogLoginSuccess(user) @@ -235,7 +235,7 @@ func (s *SecLogSuite) TestLogLoginSuccess(c *C) { c.Assert(ok, Equals, true) c.Check(userMap["snapd-user-id"], Equals, float64(42)) c.Check(userMap["store-user-email"], Equals, "user@example.com") - c.Check(userMap["system-user-name"], Equals, "jdoe") + c.Check(userMap["store-user-name"], Equals, "jdoe") c.Check(obtained["type"], Equals, "security") } @@ -245,7 +245,7 @@ func (s *SecLogSuite) TestLogLoginFailure(c *C) { user := seclog.SnapdUser{ ID: 42, StoreUserEmail: "user@example.com", - SystemUserName: "jdoe", + StoreUserName: "jdoe", } seclog.LogLoginFailure(user, seclog.Reason{Code: seclog.ReasonInvalidCredentials, Message: "invalid credentials"}) @@ -262,7 +262,7 @@ func (s *SecLogSuite) TestLogLoginFailure(c *C) { c.Assert(ok, Equals, true) c.Check(userMap["snapd-user-id"], Equals, float64(42)) c.Check(userMap["store-user-email"], Equals, "user@example.com") - c.Check(userMap["system-user-name"], Equals, "jdoe") + c.Check(userMap["store-user-name"], Equals, "jdoe") errMap, ok := obtained["error"].(map[string]any) c.Assert(ok, Equals, true) c.Check(errMap["code"], Equals, seclog.ReasonInvalidCredentials) @@ -311,7 +311,7 @@ func (s *SecLogSuite) TestDisableLogsDisabledEvent(c *C) { c.Check(obtained["event"], Equals, "sys_logging_disabled") } -func (s *SecLogSuite) TestDisableWithNoSetupReturnsError(c *C) { +func (s *SecLogSuite) TestDisableWithNoSetupIsNoop(c *C) { restoreCloser := seclog.MockGlobalCloser(nil) defer restoreCloser() restoreLogger := seclog.MockGlobalLogger(seclog.NewNopLogger()) @@ -320,7 +320,7 @@ func (s *SecLogSuite) TestDisableWithNoSetupReturnsError(c *C) { defer restoreSetup() err := seclog.Disable() - c.Assert(err, ErrorMatches, "cannot disable security logger: setup has not been called") + c.Assert(err, IsNil) } func (s *SecLogSuite) TestEnableWithNoSetupReturnsError(c *C) { @@ -366,7 +366,7 @@ func (s *SecLogSuite) TestEnableAfterDisable(c *C) { user := seclog.SnapdUser{ ID: 1, StoreUserEmail: "a@b.com", - SystemUserName: "u", + StoreUserName: "u", } seclog.LogLoginSuccess(user) @@ -395,6 +395,27 @@ func (s *SecLogSuite) TestDisableIsIdempotent(c *C) { c.Assert(err, IsNil) } +func (s *SecLogSuite) TestEnableIsIdempotent(c *C) { + restore := seclog.MockNewSink(func(appID string) (io.Writer, error) { + return s.buf, nil + }) + defer restore() + restoreLogger := seclog.MockGlobalLogger(seclog.NewNopLogger()) + defer restoreLogger() + + err := seclog.Setup(seclog.ImplSlog, seclog.SinkAudit, s.appID, seclog.LevelInfo) + c.Assert(err, IsNil) + + // second call does not error + err = seclog.Enable() + c.Assert(err, IsNil) + + // logger is still functional + s.buf.Reset() + seclog.LogLoginSuccess(seclog.SnapdUser{ID: 1, StoreUserName: "test"}) + c.Check(s.buf.Len() > 0, Equals, true) +} + func (s *SecLogSuite) TestDisablePropagatesError(c *C) { tracker := &closeTracker{err: fmt.Errorf("disk full")} restoreCloser := seclog.MockGlobalCloser(tracker) @@ -409,6 +430,19 @@ func (s *SecLogSuite) TestDisablePropagatesError(c *C) { c.Assert(err, ErrorMatches, "disk full") } +func (s *SecLogSuite) TestEnablePropagatesError(c *C) { + restore := seclog.MockNewSink(func(appID string) (io.Writer, error) { + return nil, fmt.Errorf("sink unavailable") + }) + defer restore() + restoreSetup := seclog.MockGlobalSetup( + seclog.NewLoggerSetup(seclog.ImplSlog, seclog.SinkAudit, s.appID, seclog.LevelInfo)) + defer restoreSetup() + + err := seclog.Enable() + c.Assert(err, ErrorMatches, "cannot enable security logger: sink unavailable") +} + // writeCloseTracker is a test helper that implements io.WriteCloser and // records whether Close was called. type writeCloseTracker struct { @@ -482,7 +516,7 @@ func (s *SecLogSuite) TestWriteFailuresDisableAfterThreshold(c *C) { c.Assert(err, IsNil) logBuf.Reset() - user := seclog.SnapdUser{ID: 1, SystemUserName: "test"} + user := seclog.SnapdUser{ID: 1, StoreUserName: "test"} // Exactly maxWriteFailures consecutive failures trigger auto-disable. for i := 0; i < seclog.MaxWriteFailures; i++ { @@ -508,7 +542,7 @@ func (s *SecLogSuite) TestWriteFailuresDoNotDisableBelowThreshold(c *C) { err := seclog.Setup(seclog.ImplSlog, seclog.SinkAudit, s.appID, seclog.LevelInfo) c.Assert(err, IsNil) - user := seclog.SnapdUser{ID: 1, SystemUserName: "test"} + user := seclog.SnapdUser{ID: 1, StoreUserName: "test"} // Fewer than maxWriteFailures failures should not trigger auto-disable. for i := 0; i < seclog.MaxWriteFailures-1; i++ { @@ -535,7 +569,7 @@ func (s *SecLogSuite) TestWriteSuccessResetsFailureCount(c *C) { restoreFailures := seclog.MockWriteFailures(seclog.MaxWriteFailures - 1) defer restoreFailures() - user := seclog.SnapdUser{ID: 1, SystemUserName: "test"} + user := seclog.SnapdUser{ID: 1, StoreUserName: "test"} // A successful write resets the counter. seclog.LogLoginSuccess(user) @@ -621,7 +655,7 @@ func (s *SecLogSuite) TestFailureTrackingWriterPassesSetLevel(c *C) { lb.Reset() lb.levels = nil - seclog.LogLoginSuccess(seclog.SnapdUser{ID: 1, SystemUserName: "test"}) + seclog.LogLoginSuccess(seclog.SnapdUser{ID: 1, StoreUserName: "test"}) // The levelHandler should have called SetLevel on the underlying // levelBuf through the failureTrackingWriter wrapper. diff --git a/seclog/slog.go b/seclog/slog.go index 16458bd29c6..cd2bf9c98c7 100644 --- a/seclog/slog.go +++ b/seclog/slog.go @@ -149,7 +149,7 @@ func (u SnapdUser) LogValue() slog.Value { } return slog.GroupValue( slog.Int64("snapd-user-id", u.ID), - slog.String("system-user-name", u.SystemUserName), + slog.String("store-user-name", u.StoreUserName), slog.String("store-user-email", u.StoreUserEmail), slog.String("expiration", expiration), ) @@ -162,12 +162,14 @@ func (u SnapdUser) LogValue() slog.Value { // - time: key "datetime", formatted in UTC using [time.RFC3339Nano] // - level: rendered as a string via [Level.String] // - message: key "description" +// +// [newSlogLogger] adds additional built-in attributes to the logger context: // - app_id: always included with the value provided to newSlogLogger // - type: always included with the value "security" // // Additional attributes are preserved verbatim, including nested groups. The -// handler logs at or above the minLevel threshold. It does not -// close or sync writer. +// handler logs at or above the minLevel threshold. It does not close or sync +// writer. func newJsonHandler(writer io.Writer, minLevel slog.Leveler) slog.Handler { options := &slog.HandlerOptions{ Level: minLevel, @@ -227,8 +229,8 @@ func (h *levelHandler) WithAttrs(attrs []slog.Attr) slog.Handler { return &levelHandler{inner: h.inner.WithAttrs(attrs), lw: h.lw, mu: h.mu} } -// WithGroup is required by the [slog.Handler] interface but is not -// currently used by seclog. +// WithGroup is required by the [slog.Handler] interface but is not currently +// used by seclog. func (h *levelHandler) WithGroup(name string) slog.Handler { return &levelHandler{inner: h.inner.WithGroup(name), lw: h.lw, mu: h.mu} } diff --git a/seclog/slog_test.go b/seclog/slog_test.go index d74a79cf750..6c3b1315823 100644 --- a/seclog/slog_test.go +++ b/seclog/slog_test.go @@ -124,7 +124,7 @@ type attrsAllTypes struct { Timestamp time.Time `json:"timestamp"` Float64 float64 `json:"float64"` Int64 int64 `json:"int64"` - Int int `json:"int"` + Int int64 `json:"int"` Uint64 uint64 `json:"uint64"` Any any `json:"any"` } @@ -148,7 +148,7 @@ func (s *SlogSuite) TestHandlerAttrsAllTypes(c *C) { }, slog.Attr{Key: "float64", Value: slog.Float64Value(3.141592653589793)}, slog.Attr{Key: "int64", Value: slog.Int64Value(-4611686018427387904)}, - slog.Attr{Key: "int", Value: slog.IntValue(-4294967296)}, + slog.Attr{Key: "int", Value: slog.IntValue(-2147483648)}, slog.Attr{Key: "uint64", Value: slog.Uint64Value(4294967295)}, // AnyValue returns value of KindInt64, the original // numeric type is not preserved @@ -171,7 +171,7 @@ func (s *SlogSuite) TestHandlerAttrsAllTypes(c *C) { c.Check(obtained.Timestamp, Equals, time.Date(2025, 10, 8, 8, 0, 0, 0, time.UTC)) c.Check(obtained.Float64, Equals, float64(3.141592653589793)) c.Check(obtained.Int64, Equals, int64(-4611686018427387904)) - c.Check(obtained.Int, Equals, int(-4294967296)) + c.Check(obtained.Int, Equals, int64(-2147483648)) // 32 bit compatible c.Check(obtained.Uint64, Equals, uint64(4294967295)) c.Check(obtained.Any, DeepEquals, map[string]any{"k": "v", "n": float64(1)}) } @@ -185,7 +185,7 @@ func (s *SlogSuite) TestLogLoginSuccess(c *C) { Event string `json:"event"` User struct { ID int64 `json:"snapd-user-id"` - SystemUserName string `json:"system-user-name"` + StoreUserName string `json:"store-user-name"` StoreUserEmail string `json:"store-user-email"` Expiration string `json:"expiration"` } `json:"user"` @@ -194,7 +194,7 @@ func (s *SlogSuite) TestLogLoginSuccess(c *C) { user := seclog.SnapdUser{ ID: 42, StoreUserEmail: "user@gmail.com", - SystemUserName: "jdoe", + StoreUserName: "jdoe", } logger.LogLoginSuccess(user) @@ -208,7 +208,7 @@ func (s *SlogSuite) TestLogLoginSuccess(c *C) { c.Check(obtained.Event, Equals, "authn_login_success") c.Check(obtained.User.ID, Equals, int64(42)) c.Check(obtained.User.StoreUserEmail, Equals, "user@gmail.com") - c.Check(obtained.User.SystemUserName, Equals, "jdoe") + c.Check(obtained.User.StoreUserName, Equals, "jdoe") c.Check(obtained.User.Expiration, Equals, "never") // verify key order for human readability @@ -229,7 +229,7 @@ func (s *SlogSuite) TestLogLoginSuccessWithExpiration(c *C) { Event string `json:"event"` User struct { ID int64 `json:"snapd-user-id"` - SystemUserName string `json:"system-user-name"` + StoreUserName string `json:"store-user-name"` StoreUserEmail string `json:"store-user-email"` Expiration string `json:"expiration"` } `json:"user"` @@ -239,7 +239,7 @@ func (s *SlogSuite) TestLogLoginSuccessWithExpiration(c *C) { user := seclog.SnapdUser{ ID: 42, StoreUserEmail: "user@gmail.com", - SystemUserName: "jdoe", + StoreUserName: "jdoe", Expiration: expiry, } logger.LogLoginSuccess(user) @@ -259,7 +259,7 @@ func (s *SlogSuite) TestLogLoginFailure(c *C) { Event string `json:"event"` User struct { ID int64 `json:"snapd-user-id"` - SystemUserName string `json:"system-user-name"` + StoreUserName string `json:"store-user-name"` StoreUserEmail string `json:"store-user-email"` Expiration string `json:"expiration"` } `json:"user"` @@ -272,7 +272,7 @@ func (s *SlogSuite) TestLogLoginFailure(c *C) { user := seclog.SnapdUser{ ID: 42, StoreUserEmail: "user@gmail.com", - SystemUserName: "jdoe", + StoreUserName: "jdoe", } logger.LogLoginFailure(user, seclog.Reason{Code: seclog.ReasonInvalidCredentials, Message: "invalid credentials"}) @@ -286,7 +286,7 @@ func (s *SlogSuite) TestLogLoginFailure(c *C) { c.Check(obtained.Event, Equals, "authn_login_failure") c.Check(obtained.User.ID, Equals, int64(42)) c.Check(obtained.User.StoreUserEmail, Equals, "user@gmail.com") - c.Check(obtained.User.SystemUserName, Equals, "jdoe") + c.Check(obtained.User.StoreUserName, Equals, "jdoe") c.Check(obtained.User.Expiration, Equals, "never") c.Check(obtained.Error.Code, Equals, seclog.ReasonInvalidCredentials) c.Check(obtained.Error.Message, Equals, "invalid credentials") @@ -311,9 +311,6 @@ func (lb *levelBuf) SetLevel(l seclog.Level) { lb.levels = append(lb.levels, l) } -// Ensure levelBuf satisfies the interface. -var _ seclog.LevelWriter = (*levelBuf)(nil) - func (s *SlogSuite) TestLevelHandlerSetsLevelBeforeWrite(c *C) { lb := &levelBuf{} logger := seclog.SlogImplementation{}.New(lb, s.appID, seclog.LevelInfo) diff --git a/tests/main/security-logging/task.yaml b/tests/main/security-logging/task.yaml index e1e23e13860..6519f5744df 100644 --- a/tests/main/security-logging/task.yaml +++ b/tests/main/security-logging/task.yaml @@ -7,33 +7,103 @@ details: | when store credentials are available, that a successful login produces an "authn_login_success" event in the audit log. +# ubuntu-core: auditd is not available as a distro package +systems: [-ubuntu-core-*] + prepare: | - # Create an audit checkpoint so we only see events from this test. - ausearch --checkpoint stamp -m 1121 || true + # Ensure auditd (which provides ausearch) is installed and running. + if ! command -v ausearch; then + #shellcheck source=tests/lib/pkgdb.sh + . "$TESTSLIB/pkgdb.sh" + distro_install_package auditd + systemctl enable --now auditd.service + fi restore: | snap logout || true - rm -f stamp + rm -f stamp content execute: | + # Pin snapd journal search position to the current journal position + journal_pin() { + "$TESTSTOOLS"/journal-state start_new_log + } + + # Assert that snapd.service logs since the pinned position do not match expression + journal_nomatch() { + expression="$1" + ! "$TESTSTOOLS"/journal-state match-log "$expression" -n 1 -u snapd.service >/dev/null 2>&1 + } + + # Assert that snapd.service logs since the pinned position match expression + journal_match() { + expression="$1" + "$TESTSTOOLS"/journal-state match-log "$expression" -n 3 -u snapd.service >/dev/null 2>&1 + } + + # Pin audit search position to after the latest existing 1121 events + audit_pin() { + rm -f stamp content + ausearch --checkpoint stamp -m 1121 --raw >/dev/null 2>&1 || true + } + + # Dump all content since pinned position into file with provided name + audit_dump() { + filename=$1 + ausearch --start checkpoint --checkpoint stamp -m 1121 --raw > "$filename" 2>&1 || true + } + + # Wait for snapd to become active and responsive to API calls, retrying if necessary. + wait_snapd_ready() { + local retries="${1:-3}" + local delay="${2:-1}" + local attempt=1 + local output + + while (( attempt <= retries )); do + if output="$(snap debug api /v2/snaps 2>&1)"; then + echo "snapd ready" >&2 + return 0 + fi + + echo "snapd API not ready yet, retrying ($attempt/$retries): $output" >&2 + sleep "$delay" + ((attempt++)) + done + + echo "snapd API did not become ready after $retries attempts" >&2 + echo "last error: $output" >&2 + return 1 + } + + echo "Checking that the security logger is disabled in snapd stop and enabled on snapd start" + audit_pin + journal_pin + systemctl restart snapd.service + wait_snapd_ready + journal_match "security logger disabled" + journal_match "security logger enabled" + + echo "Checking that restart produces disable/enable audit events" + audit_dump content + MATCH 'sys_logging_disabled' < content + MATCH 'sys_logging_enabled' < content + echo "Checking that a failed login attempt produces an audit event" + audit_pin echo '{"email":"someemail@testing.com","password":"wrong-password"}' | \ snap debug api -X POST -H 'Content-Type: application/json' /v2/login || true - - # The audit log entry is the raw JSON payload sent by snapd. - ausearch --start checkpoint --checkpoint stamp -m 1121 --raw 2>&1 | MATCH 'authn_login_failure' - ausearch --start checkpoint --checkpoint stamp -m 1121 --raw 2>&1 | MATCH 'invalid-credentials' - ausearch --start checkpoint --checkpoint stamp -m 1121 --raw 2>&1 | MATCH 'someemail@testing.com' + audit_dump content + MATCH 'authn_login_failure' < content + MATCH 'invalid-credentials' < content + MATCH 'someemail@testing.com' < content if [ -n "$SPREAD_STORE_USER" ] && [ -n "$SPREAD_STORE_PASSWORD" ]; then echo "Checking that a successful login produces an audit event" - # Reset the checkpoint so we only see the success event. - ausearch --checkpoint stamp -m 1121 || true - + audit_pin expect -d -f "$TESTSLIB"/successful_login.exp - - ausearch --start checkpoint --checkpoint stamp -m 1121 --raw 2>&1 | MATCH 'authn_login_success' - ausearch --start checkpoint --checkpoint stamp -m 1121 --raw 2>&1 | MATCH "$SPREAD_STORE_USER" - + audit_dump content + MATCH 'authn_login_success' < content + MATCH "$SPREAD_STORE_USER" < content snap logout fi