diff --git a/.github/workflows/quality.yml b/.github/workflows/quality.yml index dbdea43..d2121ec 100644 --- a/.github/workflows/quality.yml +++ b/.github/workflows/quality.yml @@ -27,6 +27,7 @@ jobs: with: version: latest args: --timeout=5m + install-mode: goinstall - name: Run Tests (With Coverage) run: go test -v -coverprofile=coverage.out ./... diff --git a/.golangci.yml b/.golangci.yml index 144ec2c..60c378c 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,19 +1,15 @@ run: timeout: 5m - modules-download-mode: readonly linters: + disable-all: true enable: - errcheck - - gosimple - govet - ineffassign - staticcheck - - typecheck - unused - revive # Replacement for golint (enforces comments) - - gofmt - - goimports - misspell - unconvert - unparam diff --git a/Makefile b/Makefile index ef3a394..be23ac3 100644 --- a/Makefile +++ b/Makefile @@ -9,7 +9,7 @@ API_CMD=$(GO_RUN) ./cmd/api # Database Connection (for psql) DB_DSN=$(DATABASE_URL) -.PHONY: all build run test clean lint migrate-up migrate-down migrate-status docker-up docker-down help +.PHONY: all build run test clean lint migrate-up migrate-down migrate-status docker-up docker-down help setup help: ## Show this help message @echo 'Usage:' @@ -18,6 +18,16 @@ help: ## Show this help message @echo 'Targets:' @awk 'BEGIN {FS = ":.*?## "} /^[a-zA-Z_-]+:.*?## / {printf " \033[36m%-20s\033[0m %s\n", $$1, $$2}' $(MAKEFILE_LIST) +setup: ## Install developer dependencies (goose, swag, pkgsite, gomarkdoc, golangci-lint) + @echo "Installing Go tools..." + $(GO_CMD) install github.com/pressly/goose/v3/cmd/goose@latest + $(GO_CMD) install github.com/swaggo/swag/cmd/swag@latest + $(GO_CMD) install golang.org/x/pkgsite/cmd/pkgsite@latest + $(GO_CMD) install github.com/princjef/gomarkdoc/cmd/gomarkdoc@latest + @echo "Installing golangci-lint via official install script..." + curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(shell go env GOPATH)/bin v1.64.4 + @echo "Setup complete. Make sure $$(go env GOPATH)/bin is in your PATH." + all: lint test build ## Run linter, tests, and build build: ## Build the API binary diff --git a/README.md b/README.md index c68f21a..07a2eb1 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,15 @@ cd sal cp .env.example .env ``` -### 2. Start the database +### 2. Install Dependencies + +Install the Go toolchain tools required by the Makefile (linter, swagger generator, database migration tool, and documentation server). + +```bash +make setup +``` + +### 3. Start the database ```bash docker compose up postgres -d @@ -70,6 +78,7 @@ We use `make` for common tasks: | Command | Description | |---------|-------------| +| `make setup` | Install Go development tools and linters | | `make all` | Run linter, tests, and build | | `make run` | Run API locally | | `make test` | Run tests with coverage | diff --git a/cmd/api/server.go b/cmd/api/server.go index 816aeb5..c970c04 100644 --- a/cmd/api/server.go +++ b/cmd/api/server.go @@ -70,7 +70,7 @@ func (s *Server) routes() { s.Router.Use(middleware.Logger) s.Router.Use(middleware.Recoverer) s.Router.Use(cors.Handler(cors.Options{ - AllowedOrigins: []string{"*"}, // TODO: Restrict in production + AllowedOrigins: s.Config.AllowedOrigins, AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"}, ExposedHeaders: []string{"Link"}, @@ -91,7 +91,7 @@ func (s *Server) routes() { // API Group s.Router.Route("/api/v1", func(r chi.Router) { - r.Get("/", func(w http.ResponseWriter, r *http.Request) { + r.Get("/", func(w http.ResponseWriter, _ *http.Request) { response.JSON(w, http.StatusOK, map[string]string{"message": "Welcome to Sal API v1"}) }) diff --git a/docs/ROADMAP.md b/docs/ROADMAP.md index 921fb62..28fef5d 100644 --- a/docs/ROADMAP.md +++ b/docs/ROADMAP.md @@ -39,11 +39,27 @@ Replcaed Supabase Auth. Complete. ### 3d. Middleware - [x] **Auth Middleware**: Check `Authorization: Bearer ...`. -- [ ] **Permission Middleware**: Check `staff.permissions` JSON. +- [x] **Permission Middleware**: Check `staff.permissions` JSON. --- -## 🔮 Phase 4: Organization Management +## 🔮 Phase 4: Billing & Onboarding (SaaS/Stripe) + +Manage paid subscriptions, Stripe integration, and the onboarding flow. Estimated: **4-5 Days**. + +### 4a. Stripe Integration (Backend) +- [ ] **Data Models**: Add `subscriptions`, `packages` tables, and link `stripe_customer_id` to organizations. +- [ ] **Checkout Route**: Generate Stripe Checkout Sessions `POST /billing/checkout`. +- [ ] **Webhook Handler**: `POST /billing/webhook` to handle `checkout.session.completed` for auto-provisioning. + +### 4b. Account Provisioning & Email +- [ ] **Email Service**: Integrate SMTP/AWS SES/SendGrid for transactional emails. +- [ ] **Welcome Flow**: Auto-create Org & Admin user upon successful payment. +- [ ] **Set Password Route**: `POST /auth/set-password` securely handle the one-time setup token. + +--- + +## 🔮 Phase 5: Organization Management Manage Tenants, Staff, and Patients. Estimated: **3-4 Days**. @@ -63,35 +79,35 @@ Manage Tenants, Staff, and Patients. Estimated: **3-4 Days**. --- -## 🔮 Phase 5: Clinical Forms & Templates +## 🔮 Phase 6: Clinical Forms & Templates Dynamic Form Builder. Estimated: **3 Days**. -### 5a. Templates +### 6a. Templates - [ ] CRUD for `form_templates` (JSON Schema). - [ ] Versioning logic (`template_key` + `version`). -### 5b. Document Flows +### 6b. Document Flows - [ ] `document_flows` (Workflow definitions). --- -## 🔮 Phase 6: Core Product (AI Notes) +## 🔮 Phase 7: Core Product (AI Notes) Audio Processing Pipeline. Estimated: **5-7 Days**. -### 6a. Audio Upload +### 7a. Audio Upload - [ ] `POST /audio-notes`: Upload file to S3/MinIO. - [ ] Architecture: Signed URLs vs Direct Upload. -### 6b. Transcription & Generation +### 7b. Transcription & Generation - [ ] **Worker**: Background job to process audio. - [ ] **LLM**: Integration with Anthropic/OpenAI API. - [ ] **Optimistic Locking**: Handle concurrent edits on `generated_notes`. --- -## 🔮 Phase 7: Advanced Features +## 🔮 Phase 8: Advanced Features - [ ] **WS**: WebSockets for real-time status updates. - [ ] **2FA**: TOTP implementation. - [ ] **OAuth**: Google Login. diff --git a/docs/reference.md b/docs/reference.md index c406f5f..9d5486a 100644 --- a/docs/reference.md +++ b/docs/reference.md @@ -27,7 +27,7 @@ Package main serves as the entry point for the Sal API server. It handles depend -## type [Server]() +## type [Server]() Server is the main HTTP server container. It holds references to all shared dependencies required by HTTP handlers. @@ -41,7 +41,7 @@ type Server struct { ``` -### func [NewServer]() +### func [NewServer]() ```go func NewServer(cfg *config.Config, db *database.Postgres) *Server @@ -50,7 +50,7 @@ func NewServer(cfg *config.Config, db *database.Postgres) *Server NewServer creates and configures a new HTTP server. -### func \(\*Server\) [Shutdown]() +### func \(\*Server\) [Shutdown]() ```go func (s *Server) Shutdown(ctx context.Context) error @@ -59,7 +59,7 @@ func (s *Server) Shutdown(ctx context.Context) error Shutdown gracefully stops the HTTP server. -### func \(\*Server\) [Start]() +### func \(\*Server\) [Start]() ```go func (s *Server) Start() error @@ -145,7 +145,7 @@ const RefreshTokenLen = 32 ``` -## func [CheckPasswordHash]() +## func [CheckPasswordHash]() ```go func CheckPasswordHash(password, hash string) error @@ -154,7 +154,7 @@ func CheckPasswordHash(password, hash string) error CheckPasswordHash compares a bcrypt hashed password with a plain text password. Returns nil if the passwords match, or an error if they don't. -## func [HashPassword]() +## func [HashPassword]() ```go func HashPassword(password string) (string, error) @@ -163,7 +163,7 @@ func HashPassword(password string) (string, error) HashPassword hashes a plain text password using bcrypt with a default cost. -## func [NewAccessToken]() +## func [NewAccessToken]() ```go func NewAccessToken(userID, orgID, role, secret string) (string, error) @@ -172,7 +172,7 @@ func NewAccessToken(userID, orgID, role, secret string) (string, error) NewAccessToken creates a signed JWT for the given user context. -## func [NewRefreshToken]() +## func [NewRefreshToken]() ```go func NewRefreshToken() (string, error) @@ -181,7 +181,7 @@ func NewRefreshToken() (string, error) NewRefreshToken generates a secure random hex string. This matches the format expected by the 'refresh\_tokens' table. -## type [Claims]() +## type [Claims]() Claims represents the JWT payload. @@ -195,7 +195,7 @@ type Claims struct { ``` -### func [ParseAccessToken]() +### func [ParseAccessToken]() ```go func ParseAccessToken(tokenString, secret string) (*Claims, error) @@ -218,21 +218,22 @@ Package config handles environment variable loading and application configuratio -## type [Config]() +## type [Config]() Config holds all configuration values for the application. ```go type Config struct { - DatabaseURL string - Port string - Env string - JWTSecret string + DatabaseURL string + Port string + Env string + JWTSecret string + AllowedOrigins []string } ``` -### func [Load]() +### func [Load]() ```go func Load() *Config @@ -257,7 +258,7 @@ Package database manages the PostgreSQL connection pool and related utilities. -## type [Postgres]() +## type [Postgres]() Postgres holds the connection pool to the database. @@ -268,7 +269,7 @@ type Postgres struct { ``` -### func [New]() +### func [New]() ```go func New(ctx context.Context, connectionString string) (*Postgres, error) @@ -277,7 +278,7 @@ func New(ctx context.Context, connectionString string) (*Postgres, error) New creates a new Postgres connection pool with optimized production settings. It parses the connection string, sets connection limits, and verifies the connection with a Ping. -### func \(\*Postgres\) [Close]() +### func \(\*Postgres\) [Close]() ```go func (p *Postgres) Close() @@ -286,7 +287,7 @@ func (p *Postgres) Close() Close ensures the database connection pool allows graceful shutdown. It waits for active queries to finish before closing connections. -### func \(\*Postgres\) [Health]() +### func \(\*Postgres\) [Health]() ```go func (p *Postgres) Health(ctx context.Context) error @@ -313,7 +314,7 @@ Package handler provides HTTP handlers for the API. -## type [AuthHandler]() +## type [AuthHandler]() AuthHandler handles authentication requests. @@ -329,7 +330,7 @@ type AuthHandler struct { ``` -### func [NewAuthHandler]() +### func [NewAuthHandler]() ```go func NewAuthHandler(db *database.Postgres, userEq *repository.UserRepository, orgEq *repository.OrganizationRepository, staffEq *repository.StaffRepository, jwtSecret string) *AuthHandler @@ -338,7 +339,7 @@ func NewAuthHandler(db *database.Postgres, userEq *repository.UserRepository, or NewAuthHandler creates a new AuthHandler. -### func \(\*AuthHandler\) [Login]() +### func \(\*AuthHandler\) [Login]() ```go func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) @@ -347,7 +348,7 @@ func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) Login authenticates a user and returns tokens. @Summary Login @Description Authenticates user by email/password and returns JWT pairs. @Tags auth @Accept json @Produce json @Param input body LoginInput true "Login Credentials" @Success 200 \{object\} response.Response\{data=map\[string\]string\} "Tokens" @Failure 401 \{object\} response.Response "Unauthorized" @Router /auth/login \[post\] -### func \(\*AuthHandler\) [Register]() +### func \(\*AuthHandler\) [Register]() ```go func (h *AuthHandler) Register(w http.ResponseWriter, r *http.Request) @@ -356,7 +357,7 @@ func (h *AuthHandler) Register(w http.ResponseWriter, r *http.Request) Register creates a new user, organization, and admin staff entry atomically. @Summary Register a new Admin @Description Creates a new User, Organization, and links them as Admin Staff. @Tags auth @Accept json @Produce json @Param input body RegisterInput true "Registration Config" @Success 201 \{object\} response.Response\{data=map\[string\]interface\{\}\} "User and Org created" @Failure 400 \{object\} response.Response "Validation Error" @Failure 500 \{object\} response.Response "Internal Server Error" @Router /auth/register \[post\] -## type [LoginInput]() +## type [LoginInput]() LoginInput defines the payload for login. @@ -368,7 +369,7 @@ type LoginInput struct { ``` -## type [RegisterInput]() +## type [RegisterInput]() RegisterInput defines the payload for admin registration. @@ -382,6 +383,54 @@ type RegisterInput struct { } ``` +# middleware + +```go +import "github.com/off-by-2/sal/internal/middleware" +``` + +Package middleware provides HTTP middleware handlers like Auth and Permissions. + +## Index + +- [func AuthMiddleware\(jwtSecret string\) func\(http.Handler\) http.Handler](<#AuthMiddleware>) +- [func RequirePermission\(db \*database.Postgres, required string\) func\(http.Handler\) http.Handler](<#RequirePermission>) +- [type ContextKey](<#ContextKey>) + + + +## func [AuthMiddleware]() + +```go +func AuthMiddleware(jwtSecret string) func(http.Handler) http.Handler +``` + +AuthMiddleware validates the JWT access token from the Authorization header. + + +## func [RequirePermission]() + +```go +func RequirePermission(db *database.Postgres, required string) func(http.Handler) http.Handler +``` + +RequirePermission ensures the user has a specific permission in their staff record. The required perm should be in the format "resource.action", e.g. "notes.create" + + +## type [ContextKey]() + +ContextKey is used for typed keys in requests context + +```go +type ContextKey string +``` + +ClaimsKey is the key to fetch valid claims inside an HTTP Request context. + +```go +const ClaimsKey ContextKey = "claims" +``` + # repository ```go @@ -424,7 +473,7 @@ var ( ``` -## type [Organization]() +## type [Organization]() Organization represents a row in the organizations table. @@ -440,7 +489,7 @@ type Organization struct { ``` -## type [OrganizationRepository]() +## type [OrganizationRepository]() OrganizationRepository handles database operations for organizations. @@ -451,7 +500,7 @@ type OrganizationRepository struct { ``` -### func [NewOrganizationRepository]() +### func [NewOrganizationRepository]() ```go func NewOrganizationRepository(db *database.Postgres) *OrganizationRepository @@ -460,7 +509,7 @@ func NewOrganizationRepository(db *database.Postgres) *OrganizationRepository NewOrganizationRepository creates a new OrganizationRepository. -### func \(\*OrganizationRepository\) [CreateOrg]() +### func \(\*OrganizationRepository\) [CreateOrg]() ```go func (r *OrganizationRepository) CreateOrg(ctx context.Context, o *Organization) error @@ -469,7 +518,7 @@ func (r *OrganizationRepository) CreateOrg(ctx context.Context, o *Organization) CreateOrg inserts a new organization. -## type [Staff]() +## type [Staff]() Staff represents a row in the staff table. @@ -486,7 +535,7 @@ type Staff struct { ``` -## type [StaffRepository]() +## type [StaffRepository]() StaffRepository handles database operations for staff. @@ -497,7 +546,7 @@ type StaffRepository struct { ``` -### func [NewStaffRepository]() +### func [NewStaffRepository]() ```go func NewStaffRepository(db *database.Postgres) *StaffRepository @@ -506,7 +555,7 @@ func NewStaffRepository(db *database.Postgres) *StaffRepository NewStaffRepository creates a new StaffRepository. -### func \(\*StaffRepository\) [CreateStaff]() +### func \(\*StaffRepository\) [CreateStaff]() ```go func (r *StaffRepository) CreateStaff(ctx context.Context, s *Staff) error @@ -515,7 +564,7 @@ func (r *StaffRepository) CreateStaff(ctx context.Context, s *Staff) error CreateStaff inserts a new staff member. -## type [User]() +## type [User]() User represents a row in the users table. @@ -537,7 +586,7 @@ type User struct { ``` -## type [UserRepository]() +## type [UserRepository]() UserRepository handles database operations for users. @@ -548,7 +597,7 @@ type UserRepository struct { ``` -### func [NewUserRepository]() +### func [NewUserRepository]() ```go func NewUserRepository(db *database.Postgres) *UserRepository @@ -557,7 +606,7 @@ func NewUserRepository(db *database.Postgres) *UserRepository NewUserRepository creates a new UserRepository. -### func \(\*UserRepository\) [CreateUser]() +### func \(\*UserRepository\) [CreateUser]() ```go func (r *UserRepository) CreateUser(ctx context.Context, u *User) error @@ -566,7 +615,7 @@ func (r *UserRepository) CreateUser(ctx context.Context, u *User) error CreateUser inserts a new user into the database. -### func \(\*UserRepository\) [GetUserByEmail]() +### func \(\*UserRepository\) [GetUserByEmail]() ```go func (r *UserRepository) GetUserByEmail(ctx context.Context, email string) (*User, error) @@ -591,7 +640,7 @@ Package response provides helper functions for sending consistent JSON responses -## func [Error]() +## func [Error]() ```go func Error(w http.ResponseWriter, status int, message string) @@ -600,7 +649,7 @@ func Error(w http.ResponseWriter, status int, message string) Error sends a standardized error response. -## func [JSON]() +## func [JSON]() ```go func JSON(w http.ResponseWriter, status int, data interface{}) @@ -609,7 +658,7 @@ func JSON(w http.ResponseWriter, status int, data interface{}) JSON sends a JSON response with the given status code and data. -## func [ValidationError]() +## func [ValidationError]() ```go func ValidationError(w http.ResponseWriter, err error) @@ -618,7 +667,7 @@ func ValidationError(w http.ResponseWriter, err error) ValidationError sends a response with detailed validation errors. It parses go\-playground/validator errors into a simplified map. -## type [Response]() +## type [Response]() Response represents the standard JSON envelope for all API responses. diff --git a/internal/auth/token_test.go b/internal/auth/token_test.go index 6eaa0c7..91a5a55 100644 --- a/internal/auth/token_test.go +++ b/internal/auth/token_test.go @@ -47,7 +47,7 @@ func TestNewRefreshToken(t *testing.T) { } } -func TestParseAccessToken_Expired(t *testing.T) { +func TestParseAccessToken_Expired(_ *testing.T) { // We can't easily mock time in the current token implementation without refactoring. // But we can test malformed tokens. } diff --git a/internal/config/config.go b/internal/config/config.go index b0942bc..1493b2b 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -4,16 +4,18 @@ package config import ( "log" "os" + "strings" "github.com/joho/godotenv" ) // Config holds all configuration values for the application. type Config struct { - DatabaseURL string - Port string - Env string - JWTSecret string + DatabaseURL string + Port string + Env string + JWTSecret string + AllowedOrigins []string } // Load retrieves configuration from environment variables. @@ -25,11 +27,20 @@ func Load() *Config { log.Println("No .env file found, using system environment variables") } + originsStr := getEnv("CORS_ALLOWED_ORIGINS", "http://localhost:3000") + origins := []string{} + if originsStr != "" { + for _, o := range strings.Split(originsStr, ",") { + origins = append(origins, strings.TrimSpace(o)) + } + } + return &Config{ - DatabaseURL: getEnv("DATABASE_URL", "postgres://salvia:localdev@localhost:5432/salvia?sslmode=disable"), - Port: getEnv("PORT", "8000"), - Env: getEnv("ENV", "development"), - JWTSecret: getEnv("JWT_SECRET", "super-secret-dev-key-change-me"), + DatabaseURL: getEnv("DATABASE_URL", "postgres://salvia:localdev@localhost:5432/salvia?sslmode=disable"), + Port: getEnv("PORT", "8000"), + Env: getEnv("ENV", "development"), + JWTSecret: getEnv("JWT_SECRET", "super-secret-dev-key-change-me"), + AllowedOrigins: origins, } } diff --git a/internal/handler/auth.go b/internal/handler/auth.go index bff252c..7cec45d 100644 --- a/internal/handler/auth.go +++ b/internal/handler/auth.go @@ -222,7 +222,22 @@ func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) { return } - // TODO: Store Refresh Token hash in DB + // 5a. Hash and Store Refresh Token in DB + hashedRefreshToken, err := auth.HashPassword(refreshToken) + if err != nil { + response.Error(w, http.StatusInternalServerError, "Failed to secure token") + return + } + + _, err = h.DB.Pool.Exec(r.Context(), ` + INSERT INTO refresh_tokens (token_hash, user_id, expires_at) + VALUES ($1, $2, $3)`, + hashedRefreshToken, user.ID, time.Now().Add(7*24*time.Hour), + ) + if err != nil { + response.Error(w, http.StatusInternalServerError, "Failed to save session") + return + } // 6. Set Refresh Cookie http.SetCookie(w, &http.Cookie{ diff --git a/internal/middleware/auth.go b/internal/middleware/auth.go new file mode 100644 index 0000000..c4a7de2 --- /dev/null +++ b/internal/middleware/auth.go @@ -0,0 +1,48 @@ +// Package middleware provides HTTP middleware handlers like Auth and Permissions. +package middleware + +import ( + "context" + "net/http" + "strings" + + "github.com/off-by-2/sal/internal/auth" + "github.com/off-by-2/sal/internal/response" +) + +// ContextKey is used for typed keys in requests context +type ContextKey string + +// ClaimsKey is the key to fetch valid claims inside an HTTP Request context. +const ClaimsKey ContextKey = "claims" + +// AuthMiddleware validates the JWT access token from the Authorization header. +func AuthMiddleware(jwtSecret string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + response.Error(w, http.StatusUnauthorized, "Missing Authorization header") + return + } + + // Expect format: "Bearer " + parts := strings.Split(authHeader, " ") + if len(parts) != 2 || parts[0] != "Bearer" { + response.Error(w, http.StatusUnauthorized, "Invalid Authorization format") + return + } + + tokenString := parts[1] + claims, err := auth.ParseAccessToken(tokenString, jwtSecret) + if err != nil { + response.Error(w, http.StatusUnauthorized, "Invalid or expired token") + return + } + + // Add claims to request context + ctx := context.WithValue(r.Context(), ClaimsKey, claims) + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} diff --git a/internal/middleware/auth_test.go b/internal/middleware/auth_test.go new file mode 100644 index 0000000..c78e617 --- /dev/null +++ b/internal/middleware/auth_test.go @@ -0,0 +1,81 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/off-by-2/sal/internal/auth" +) + +func TestAuthMiddleware(t *testing.T) { + jwtSecret := "test-secret" + + // Create a valid token + validToken, err := auth.NewAccessToken("user-123", "org-456", "admin", jwtSecret) + if err != nil { + t.Fatalf("Failed to generate test token: %v", err) + } + + tests := []struct { + name string + authHeader string + expectedStatus int + }{ + { + name: "Valid Token", + authHeader: "Bearer " + validToken, + expectedStatus: http.StatusOK, + }, + { + name: "Missing Header", + authHeader: "", + expectedStatus: http.StatusUnauthorized, + }, + { + name: "Invalid Format (No Bearer)", + authHeader: "Token " + validToken, + expectedStatus: http.StatusUnauthorized, + }, + { + name: "Invalid Format (Just Bearer)", + authHeader: "Bearer", + expectedStatus: http.StatusUnauthorized, + }, + { + name: "Invalid Token Signature", + authHeader: "Bearer " + validToken + "invalid", + expectedStatus: http.StatusUnauthorized, + }, + } + + // Mock endpoint handler that validates context loading + mockHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + claims, ok := r.Context().Value(ClaimsKey).(*auth.Claims) + if !ok || claims.UserID != "user-123" { + t.Errorf("Claims not properly loaded into context") + http.Error(w, "invalid claims", http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) + }) + + middleware := AuthMiddleware(jwtSecret) + handlerUnderTest := middleware(mockHandler) + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + if tc.authHeader != "" { + req.Header.Set("Authorization", tc.authHeader) + } + w := httptest.NewRecorder() + + handlerUnderTest.ServeHTTP(w, req) + + if w.Code != tc.expectedStatus { + t.Errorf("Expected status %d, got %d", tc.expectedStatus, w.Code) + } + }) + } +} diff --git a/internal/middleware/permissions.go b/internal/middleware/permissions.go new file mode 100644 index 0000000..54ef437 --- /dev/null +++ b/internal/middleware/permissions.go @@ -0,0 +1,77 @@ +package middleware + +import ( + "encoding/json" + "net/http" + "strings" + + "github.com/off-by-2/sal/internal/auth" + "github.com/off-by-2/sal/internal/database" + "github.com/off-by-2/sal/internal/response" +) + +// RequirePermission ensures the user has a specific permission in their staff record. +// The required perm should be in the format "resource.action", e.g. "notes.create" +func RequirePermission(db *database.Postgres, required string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // 1. Get Claims from Context + claims, ok := r.Context().Value(ClaimsKey).(*auth.Claims) + if !ok { + response.Error(w, http.StatusUnauthorized, "Missing authentication context") + return + } + + // If user is guest/no-org, deny automatically + if claims.Role == "guest" || claims.OrgID == "" { + response.Error(w, http.StatusForbidden, "Insufficient permissions") + return + } + + // Admin bypass + if claims.Role == "admin" || claims.Role == "owner" { + next.ServeHTTP(w, r) + return + } + + // 2. Query Staff Permissions + var permissionsJSON []byte + err := db.Pool.QueryRow(r.Context(), ` + SELECT permissions FROM staff + WHERE user_id = $1 AND organization_id = $2 LIMIT 1`, + claims.UserID, claims.OrgID, + ).Scan(&permissionsJSON) + + if err != nil { + response.Error(w, http.StatusForbidden, "Could not load staff profile") + return + } + + // 3. Parse JSONB + var perms map[string]map[string]bool + if err := json.Unmarshal(permissionsJSON, &perms); err != nil { + response.Error(w, http.StatusInternalServerError, "Failed to parse permissions") + return + } + + // 4. Validate Permission String "resource.action" + parts := strings.Split(required, ".") + if len(parts) != 2 { + response.Error(w, http.StatusInternalServerError, "Invalid permission configuration") + return + } + + resource := parts[0] + action := parts[1] + + if resourcePerms, ok := perms[resource]; ok { + if allowed, exists := resourcePerms[action]; exists && allowed { + next.ServeHTTP(w, r) + return + } + } + + response.Error(w, http.StatusForbidden, "You do not have permission to perform this action") + }) + } +} diff --git a/internal/response/response_test.go b/internal/response/response_test.go index 1548673..3d26396 100644 --- a/internal/response/response_test.go +++ b/internal/response/response_test.go @@ -87,7 +87,7 @@ func TestError(t *testing.T) { } } -func TestValidationError(t *testing.T) { +func TestValidationError(_ *testing.T) { // Standard error fallback w := httptest.NewRecorder() ValidationError(w, errors.New("simple error")) diff --git a/migrations/20240218140000_initial_schema.sql b/migrations/00001_initial_schema.sql similarity index 100% rename from migrations/20240218140000_initial_schema.sql rename to migrations/00001_initial_schema.sql diff --git a/migrations/20240219000000_add_extensions.sql b/migrations/00002_add_extensions.sql similarity index 100% rename from migrations/20240219000000_add_extensions.sql rename to migrations/00002_add_extensions.sql diff --git a/migrations/00003_create_refresh_tokens.sql b/migrations/00003_create_refresh_tokens.sql new file mode 100644 index 0000000..ab4a62f --- /dev/null +++ b/migrations/00003_create_refresh_tokens.sql @@ -0,0 +1,18 @@ +-- +goose Up +-- +goose StatementBegin +CREATE TABLE public.refresh_tokens ( + token_hash character varying(255) NOT NULL, + user_id uuid NOT NULL, + expires_at timestamp with time zone NOT NULL, + created_at timestamp with time zone DEFAULT CURRENT_TIMESTAMP NOT NULL, + CONSTRAINT refresh_tokens_pkey PRIMARY KEY (token_hash), + CONSTRAINT fk_refresh_tokens_user FOREIGN KEY (user_id) REFERENCES public.users(id) ON DELETE CASCADE +); + +CREATE INDEX idx_refresh_tokens_user ON public.refresh_tokens USING btree (user_id); +-- +goose StatementEnd + +-- +goose Down +-- +goose StatementBegin +DROP TABLE IF EXISTS public.refresh_tokens; +-- +goose StatementEnd