diff --git a/internal/daemon/daemon.go b/internal/daemon/daemon.go index b28848754..1d51c41c1 100644 --- a/internal/daemon/daemon.go +++ b/internal/daemon/daemon.go @@ -27,7 +27,6 @@ import ( internalConfig "github.com/canonical/microcluster/v3/internal/config" "github.com/canonical/microcluster/v3/internal/db" "github.com/canonical/microcluster/v3/internal/endpoints" - internalLog "github.com/canonical/microcluster/v3/internal/log" "github.com/canonical/microcluster/v3/internal/recover" internalREST "github.com/canonical/microcluster/v3/internal/rest" internalClient "github.com/canonical/microcluster/v3/internal/rest/client" @@ -153,7 +152,7 @@ func NewDaemon() *Daemon { // log is a convenience to retrieve the internal logger from the shutdown context. // We always expect the logger to be present. func (d *Daemon) log() *slog.Logger { - return d.shutdownCtx.Value(internalLog.CtxLogger).(*slog.Logger) //nolint:revive + return d.shutdownCtx.Value(types.CtxLogger).(*slog.Logger) //nolint:revive } // Run initializes the Daemon with the given configuration, starts the database, diff --git a/internal/daemon/daemon_test.go b/internal/daemon/daemon_test.go index 054251f3c..da1fae401 100644 --- a/internal/daemon/daemon_test.go +++ b/internal/daemon/daemon_test.go @@ -2,7 +2,6 @@ package daemon import ( "context" - "log/slog" "path/filepath" "testing" @@ -13,7 +12,6 @@ import ( "github.com/canonical/microcluster/v3/internal/config" "github.com/canonical/microcluster/v3/internal/endpoints" - "github.com/canonical/microcluster/v3/internal/log" "github.com/canonical/microcluster/v3/internal/rest/client" "github.com/canonical/microcluster/v3/internal/sys" "github.com/canonical/microcluster/v3/internal/trust" @@ -189,7 +187,7 @@ func (t *daemonsSuite) Test_UpdateServers() { for i, test := range tests { t.T().Logf("%s (case %d)", test.name, i) - ctx := context.WithValue(context.TODO(), log.CtxLogger, slog.Default()) + ctx := types.ContextWithLogger(context.TODO()) commonDir := t.T().TempDir() // Create a temp watcher. diff --git a/internal/db/db_test.go b/internal/db/db_test.go index 5764226f1..4054071e8 100644 --- a/internal/db/db_test.go +++ b/internal/db/db_test.go @@ -5,7 +5,6 @@ import ( "database/sql" "encoding/json" "fmt" - "log/slog" "testing" "time" @@ -14,7 +13,6 @@ import ( "github.com/canonical/microcluster/v3/internal/cluster" "github.com/canonical/microcluster/v3/internal/db/update" - "github.com/canonical/microcluster/v3/internal/log" "github.com/canonical/microcluster/v3/internal/sys" clusterDB "github.com/canonical/microcluster/v3/microcluster/db" "github.com/canonical/microcluster/v3/microcluster/types" @@ -652,7 +650,7 @@ func (s *dbSuite) Test_waitUpgradeSchemaAndAPI() { func NewTestDB(extensionsExternal []clusterDB.Update) (*DqliteDB, error) { var err error - ctx := context.WithValue(context.Background(), log.CtxLogger, slog.Default()) + ctx := types.ContextWithLogger(context.Background()) db := &DqliteDB{ ctx: ctx, diff --git a/internal/db/dqlite.go b/internal/db/dqlite.go index a7b96be55..5f626e1c1 100644 --- a/internal/db/dqlite.go +++ b/internal/db/dqlite.go @@ -26,7 +26,6 @@ import ( "github.com/canonical/microcluster/v3/internal/cluster" "github.com/canonical/microcluster/v3/internal/db/update" - "github.com/canonical/microcluster/v3/internal/log" internalClient "github.com/canonical/microcluster/v3/internal/rest/client" "github.com/canonical/microcluster/v3/internal/sys" clusterDB "github.com/canonical/microcluster/v3/microcluster/db" @@ -109,7 +108,7 @@ func NewDB(ctx context.Context, serverCert func() *shared.CertInfo, clusterCert // log is a convenience to retrieve the internal logger from the database's context. // We always expect the logger to be present. func (db *DqliteDB) log() *slog.Logger { - return db.ctx.Value(log.CtxLogger).(*slog.Logger) //nolint:revive + return db.ctx.Value(types.CtxLogger).(*slog.Logger) //nolint:revive } // SetSchema sets schema and API extensions on the DB. diff --git a/internal/db/query/transaction_test.go b/internal/db/query/transaction_test.go index 5ad12f826..a5ebc7c13 100644 --- a/internal/db/query/transaction_test.go +++ b/internal/db/query/transaction_test.go @@ -4,7 +4,6 @@ import ( "context" "database/sql" "errors" - "log/slog" "testing" _ "github.com/mattn/go-sqlite3" @@ -12,8 +11,8 @@ import ( "github.com/stretchr/testify/require" "github.com/canonical/microcluster/v3/internal/db/query" - "github.com/canonical/microcluster/v3/internal/log" clusterDB "github.com/canonical/microcluster/v3/microcluster/db" + "github.com/canonical/microcluster/v3/microcluster/types" ) // Any error happening when beginning the transaction will be propagated. @@ -36,7 +35,7 @@ func TestTransaction_FunctionError(t *testing.T) { db := newDB(t) // Populate the context with the logger as this is required for a failing transaction. - ctx := context.WithValue(context.TODO(), log.CtxLogger, slog.Default()) + ctx := types.ContextWithLogger(context.TODO()) err := query.Transaction(ctx, db, func(ctx context.Context, tx *sql.Tx) error { _, err := tx.Exec("CREATE TABLE test (id INTEGER)") diff --git a/internal/endpoints/endpoints.go b/internal/endpoints/endpoints.go index c69d62ddb..4c92d8fba 100644 --- a/internal/endpoints/endpoints.go +++ b/internal/endpoints/endpoints.go @@ -9,7 +9,7 @@ import ( "github.com/canonical/lxd/shared" - "github.com/canonical/microcluster/v3/internal/log" + "github.com/canonical/microcluster/v3/microcluster/types" ) // Endpoints represents all listeners and servers for the microcluster daemon REST API. @@ -28,7 +28,7 @@ func NewEndpoints(shutdownCtx context.Context, endpoints map[string]Endpoint) *E // log is a convenience to retrieve the internal logger from the shutdown context. // We always expect the logger to be present. func (e *Endpoints) log() *slog.Logger { - return e.shutdownCtx.Value(log.CtxLogger).(*slog.Logger) //nolint:revive + return e.shutdownCtx.Value(types.CtxLogger).(*slog.Logger) //nolint:revive } // Up calls Serve on each of the configured listeners. diff --git a/internal/endpoints/network.go b/internal/endpoints/network.go index 9edaf02bd..9d61868ee 100644 --- a/internal/endpoints/network.go +++ b/internal/endpoints/network.go @@ -14,7 +14,7 @@ import ( "github.com/canonical/lxd/shared" - "github.com/canonical/microcluster/v3/internal/log" + "github.com/canonical/microcluster/v3/microcluster/types" ) // Network represents an HTTPS listener and its server. @@ -53,7 +53,7 @@ func NewNetwork(ctx context.Context, endpointType EndpointType, server *http.Ser // log is a convenience to retrieve the internal logger from the network's context. // We always expect the logger to be present. func (n *Network) log() *slog.Logger { - return n.ctx.Value(log.CtxLogger).(*slog.Logger) //nolint:revive + return n.ctx.Value(types.CtxLogger).(*slog.Logger) //nolint:revive } // Type returns the type of the Endpoint. diff --git a/internal/endpoints/socket.go b/internal/endpoints/socket.go index a854cd2db..4f4fee439 100644 --- a/internal/endpoints/socket.go +++ b/internal/endpoints/socket.go @@ -15,7 +15,7 @@ import ( "github.com/canonical/lxd/shared" - "github.com/canonical/microcluster/v3/internal/log" + "github.com/canonical/microcluster/v3/microcluster/types" ) // Socket represents a unix socket with a given path. @@ -50,7 +50,7 @@ func NewSocket(ctx context.Context, server *http.Server, path *url.URL, group st // log is a convenience to retrieve the internal logger from the socket's context. // We always expect the logger to be present. func (s *Socket) log() *slog.Logger { - return s.ctx.Value(log.CtxLogger).(*slog.Logger) //nolint:revive + return s.ctx.Value(types.CtxLogger).(*slog.Logger) //nolint:revive } // Type returns the type of the Endpoint. diff --git a/internal/log/logger.go b/internal/log/logger.go index 52b8c1031..cde3e4084 100644 --- a/internal/log/logger.go +++ b/internal/log/logger.go @@ -4,16 +4,13 @@ import ( "context" "errors" "log/slog" -) - -type ctxKey string -// CtxLogger is the name of the context value for the central logger. -const CtxLogger ctxKey = "logger" + "github.com/canonical/microcluster/v3/microcluster/types" +) // LoggerFromContext returns the logger from the given context. func LoggerFromContext(ctx context.Context) (*slog.Logger, error) { - logger, ok := ctx.Value(CtxLogger).(*slog.Logger) + logger, ok := ctx.Value(types.CtxLogger).(*slog.Logger) if !ok { return nil, errors.New("Logger does not exist on context") } diff --git a/microcluster/app.go b/microcluster/app.go index 2e084e7fd..c48440a51 100644 --- a/microcluster/app.go +++ b/microcluster/app.go @@ -100,7 +100,7 @@ func (m *MicroCluster) Start(ctx context.Context, daemonArgs DaemonArgs) error { } // Attach the logger to the parent context. - ctx = context.WithValue(ctx, log.CtxLogger, logger) + ctx = types.ContextWithLogger(ctx, logger) err := d.Run(ctx, m.FileSystem.StateDir(), daemonArgs) if err != nil { @@ -280,7 +280,7 @@ func (m *MicroCluster) RecoverFromQuorumLoss(members []types.DqliteMember) (stri // Derive a new context with the central logger attached. // As we don't have a running daemon at this stage, we cannot use its context. // Instead we use the logger populated for the app which uses the custom handler if supplied. - ctx := context.WithValue(context.Background(), log.CtxLogger, m.LoggerFromContext(context.Background())) + ctx := types.ContextWithLogger(context.Background(), m.LoggerFromContext(context.Background())) return recover.RecoverFromQuorumLoss(ctx, m.FileSystem, members) } diff --git a/microcluster/rest/response/response.go b/microcluster/rest/response/response.go index 0682367ab..edd631d75 100644 --- a/microcluster/rest/response/response.go +++ b/microcluster/rest/response/response.go @@ -3,6 +3,7 @@ package response import ( "bytes" "encoding/json" + "errors" "fmt" "io" "log/slog" @@ -118,6 +119,16 @@ func Unavailable(err error) Response { return &errorResponse{http.StatusServiceUnavailable, err} } +// Unauthorized returns an unauthorized response (401) with the given error. +func Unauthorized(err error) Response { + return &errorResponse{http.StatusUnauthorized, err} +} + +// ErrorResponse returns an error response with the given code and msg. +func ErrorResponse(code int, msg string) Response { + return &errorResponse{code, errors.New(msg)} +} + func (r *errorResponse) Render(w http.ResponseWriter, req *http.Request) error { buf := &bytes.Buffer{} resp := api.ResponseRaw{ diff --git a/microcluster/types/log.go b/microcluster/types/log.go new file mode 100644 index 000000000..0d4a19194 --- /dev/null +++ b/microcluster/types/log.go @@ -0,0 +1,24 @@ +package types + +import ( + "context" + "log/slog" +) + +type CtxKey string + +// CtxLogger is the name of the context value for the central logger. +const CtxLogger CtxKey = "logger" + +// ContextWithLogger returns a new context with the given logger. +// If no logger is provided, the default logger is used. +func ContextWithLogger(ctx context.Context, logger ...*slog.Logger) context.Context { + var l *slog.Logger + if len(logger) == 0 { + l = slog.Default() + } else { + l = logger[0] + } + + return context.WithValue(ctx, CtxLogger, l) +}