diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2fdbd52..8eb8f82 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -29,7 +29,7 @@ jobs: - name: Install tools run: | - go install github.com/golangci/golangci-lint/cmd/golangci-lint@v1.60.1 + go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@v2.11.4 go install mvdan.cc/gofumpt@v0.6.0 - name: Format check (gofumpt) diff --git a/.golangci.yml b/.golangci.yml index 463eac3..78c9781 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,26 +1,14 @@ -version: 2 +version: "2" run: timeout: 3m tests: true + linters: - disable-all: true - disable: - - errcheck - - unused + default: none enable: - govet - # revive and errcheck disabled to keep CLI glue concise - # - revive - # - errcheck -linters-settings: - revive: - rules: - - name: unused-parameter - disabled: true - - name: package-comments - disabled: true + issues: - exclude-use-default: false max-issues-per-linter: 0 max-same-issues: 0 diff --git a/README.md b/README.md index 0b05956..08c3d7c 100644 --- a/README.md +++ b/README.md @@ -24,9 +24,13 @@ chmod 600 ~/.config/eightctl/config.yaml # check pod state EIGHTCTL_EMAIL=you@example.com EIGHTCTL_PASSWORD=your-password eightctl status -# set temperature level (-100..100) +# set temperature level (-100..100); without --side, applies to all discovered sides/users eightctl temp 20 +# target a specific side when the household is split +eightctl temp -40 --side right +eightctl on --side left + # run daemon with your YAML schedule (see docs/example-schedule.yaml) eightctl daemon --dry-run ``` @@ -47,6 +51,13 @@ eightctl daemon --dry-run Use `--output table|json|csv` and `--fields field1,field2` to shape output. `--verbose` enables debug logs; `--quiet` hides the config banner. +## Household Targeting +- `status` shows discovered household targets by default when available, including `left` / `right` or inferred `solo`. +- `on`, `off`, and `temp` apply to all discovered household targets by default. +- Use `--side left|right|solo` to target one household side. +- Use `--target-user-id ` when you want to address a specific discovered user directly. +- For split households, `eightctl status --output json` is the quickest way to inspect available sides and user IDs. + ## Configuration Priority: flags > env vars (`EIGHTCTL_*`) > config file. diff --git a/docs/spec.md b/docs/spec.md index 56660f9..287b213 100644 --- a/docs/spec.md +++ b/docs/spec.md @@ -63,7 +63,7 @@ Travel: - `travel flight-status --flight` Household: -- `household summary|schedule|current-set|invitations` +- `household summary|schedule|current-set|invitations|devices|users|guests` Audio/temperature data helpers: - `tracks`, `feats` remain for backward compatibility. @@ -71,6 +71,9 @@ Audio/temperature data helpers: ## Output & UX - Output formats: table (default), json, csv via `--output`; `--fields` to select columns. - Logs via charmbracelet/log; `--verbose` for debug; `--quiet` hides config notice. +- `status` should prefer discovered household targets when available and display `left` / `right` or inferred `solo`. +- `on`, `off`, and `temp` should default to all discovered household targets unless narrowed with `--side` or `--target-user-id`. +- `temp` accepts negative positional levels such as `temp -40` without requiring `--`. ## Daemon Behavior - Reads YAML schedule (time, action on|off|temp, temperature with unit), minute tick, executes once per day, PID guard, SIGINT/SIGTERM graceful stop. @@ -78,7 +81,7 @@ Audio/temperature data helpers: ## Testing & Quality Gates - `go test ./...` (fast compile checks) — run before handoff. -- Formatting via `gofmt`; prefer `gofumpt`/`staticcheck` later. +- Formatting via `gofumpt`; lint via `golangci-lint` with `govet` enabled in repo config. - Live checks: `eightctl status`, `metrics summary`, `tempmode nap status` with test creds to validate auth + userId resolution. ## Prior Work (references) diff --git a/internal/client/eightsleep.go b/internal/client/eightsleep.go index 55fad67..a98293d 100644 --- a/internal/client/eightsleep.go +++ b/internal/client/eightsleep.go @@ -2,6 +2,7 @@ package client import ( "bytes" + "compress/gzip" "context" "crypto/tls" "encoding/json" @@ -10,6 +11,7 @@ import ( "io" "net/http" "net/url" + "strings" "time" "github.com/charmbracelet/log" @@ -18,10 +20,14 @@ import ( const ( defaultBaseURL = "https://client-api.8slp.net/v1" + defaultAppURL = "https://app-api.8slp.net" authURL = "https://auth-api.8slp.net/v1/tokens" // Extracted from the official Eight Sleep Android app v7.39.17 (public client creds) - defaultClientID = "0894c7f33bb94800a03f1f4df13a4f38" - defaultClientSecret = "f0954a3ed5763ba3d06834c73731a32f15f168f47d4f164751275def86db0c76" + defaultClientID = "0894c7f33bb94800a03f1f4df13a4f38" + defaultClientSecret = "f0954a3ed5763ba3d06834c73731a32f15f168f47d4f164751275def86db0c76" + maxRateLimitRetries = 2 + maxUnauthorizedRetries = 1 + defaultRetryDelay = 2 * time.Second ) // Client represents Eight Sleep API client. @@ -35,10 +41,58 @@ type Client struct { HTTP *http.Client BaseURL string + AppURL string token string tokenExp time.Time } +// breakpointBeforeAuthRequest exists solely to provide a stable debugger stop +// point immediately before the first outbound auth request is sent. +func breakpointBeforeAuthRequest(kind string, req *http.Request) { + _ = kind + _ = req +} + +func responseBodyBytes(resp *http.Response) ([]byte, error) { + reader := io.Reader(resp.Body) + if strings.Contains(strings.ToLower(resp.Header.Get("Content-Encoding")), "gzip") { + gz, err := gzip.NewReader(resp.Body) + if err != nil { + return nil, err + } + defer gz.Close() + reader = gz + } + return io.ReadAll(reader) +} + +func decodeJSONResponse(resp *http.Response, out any) error { + body, err := responseBodyBytes(resp) + if err != nil { + return err + } + return json.Unmarshal(body, out) +} + +func retryAfterDelay(resp *http.Response) time.Duration { + if resp == nil { + return defaultRetryDelay + } + value := strings.TrimSpace(resp.Header.Get("Retry-After")) + if value == "" { + return defaultRetryDelay + } + if seconds, err := time.ParseDuration(value + "s"); err == nil && seconds > 0 { + return seconds + } + if when, err := http.ParseTime(value); err == nil { + if delay := time.Until(when); delay > 0 { + return delay + } + } + return defaultRetryDelay +} + // New creates a Client. func New(email, password, userID, clientID, clientSecret string) *Client { @@ -62,6 +116,7 @@ func New(email, password, userID, clientID, clientSecret string) *Client { ClientSecret: clientSecret, HTTP: &http.Client{Timeout: 20 * time.Second, Transport: tr}, BaseURL: defaultBaseURL, + AppURL: defaultAppURL, } } @@ -100,6 +155,7 @@ func (c *Client) EnsureDeviceID(ctx context.Context) (string, error) { } var res struct { User struct { + Devices []string `json:"devices"` CurrentDevice struct { ID string `json:"id"` } `json:"currentDevice"` @@ -108,10 +164,14 @@ func (c *Client) EnsureDeviceID(ctx context.Context) (string, error) { if err := c.do(ctx, http.MethodGet, "/users/me", nil, nil, &res); err != nil { return "", err } - if res.User.CurrentDevice.ID == "" { + if res.User.CurrentDevice.ID != "" { + c.DeviceID = res.User.CurrentDevice.ID + return c.DeviceID, nil + } + if len(res.User.Devices) == 0 { return "", errors.New("no current device id") } - c.DeviceID = res.User.CurrentDevice.ID + c.DeviceID = res.User.Devices[0] return c.DeviceID, nil } @@ -120,8 +180,8 @@ func (c *Client) authTokenEndpoint(ctx context.Context) error { "grant_type": "password", "username": c.Email, "password": c.Password, - "client_id": "sleep-client", - "client_secret": "", + "client_id": c.ClientID, + "client_secret": c.ClientSecret, } body, _ := json.Marshal(payload) req, err := http.NewRequestWithContext(ctx, http.MethodPost, authURL, bytes.NewReader(body)) @@ -130,13 +190,14 @@ func (c *Client) authTokenEndpoint(ctx context.Context) error { } req.Header.Set("Content-Type", "application/json") + breakpointBeforeAuthRequest("oauth", req) resp, err := c.HTTP.Do(req) if err != nil { return err } defer resp.Body.Close() if resp.StatusCode >= 300 { - b, _ := io.ReadAll(resp.Body) + b, _ := responseBodyBytes(resp) log.Debug("token auth failed", "status", resp.Status, "headers", resp.Header, "body", string(b)) return fmt.Errorf("token auth failed: %s", resp.Status) } @@ -146,7 +207,7 @@ func (c *Client) authTokenEndpoint(ctx context.Context) error { ExpiresIn int `json:"expires_in"` UserID string `json:"userId"` } - if err := json.NewDecoder(resp.Body).Decode(&res); err != nil { + if err := decodeJSONResponse(resp, &res); err != nil { return err } if res.AccessToken == "" { @@ -183,13 +244,14 @@ func (c *Client) authLegacyLogin(ctx context.Context) error { req.Header.Set("Connection", "keep-alive") req.Header.Set("User-Agent", "okhttp/4.9.3") req.Header.Set("Accept-Encoding", "gzip") + breakpointBeforeAuthRequest("legacy", req) resp, err := c.HTTP.Do(req) if err != nil { return err } defer resp.Body.Close() if resp.StatusCode >= 300 { - b, _ := io.ReadAll(resp.Body) + b, _ := responseBodyBytes(resp) log.Debug("legacy login failed", "status", resp.Status, "headers", resp.Header, "body", string(b)) return fmt.Errorf("login failed: %s", string(b)) } @@ -200,7 +262,7 @@ func (c *Client) authLegacyLogin(ctx context.Context) error { ExpirationDate string `json:"expirationDate"` } `json:"session"` } - if err := json.NewDecoder(resp.Body).Decode(&res); err != nil { + if err := decodeJSONResponse(resp, &res); err != nil { return err } if res.Session.Token == "" { @@ -257,6 +319,18 @@ func (c *Client) requireUser(ctx context.Context) error { } func (c *Client) do(ctx context.Context, method, path string, query url.Values, body any, out any) error { + return c.doWithBase(ctx, c.BaseURL, method, path, query, body, out) +} + +func (c *Client) doApp(ctx context.Context, method, path string, query url.Values, body any, out any) error { + return c.doWithBase(ctx, c.AppURL, method, path, query, body, out) +} + +func (c *Client) doWithBase(ctx context.Context, baseURL, method, path string, query url.Values, body any, out any) error { + return c.doWithRetry(ctx, baseURL, method, path, query, body, out, maxRateLimitRetries, maxUnauthorizedRetries) +} + +func (c *Client) doWithRetry(ctx context.Context, baseURL, method, path string, query url.Values, body any, out any, remaining429 int, remaining401 int) error { if err := c.ensureToken(ctx); err != nil { return err } @@ -268,7 +342,7 @@ func (c *Client) do(ctx context.Context, method, path string, query url.Values, } rdr = bytes.NewReader(b) } - u := c.BaseURL + path + u := baseURL + path if len(query) > 0 { u += "?" + query.Encode() } @@ -289,44 +363,70 @@ func (c *Client) do(ctx context.Context, method, path string, query url.Values, } defer resp.Body.Close() if resp.StatusCode == http.StatusTooManyRequests { - time.Sleep(2 * time.Second) - return c.do(ctx, method, path, query, body, out) + if remaining429 == 0 { + b, _ := responseBodyBytes(resp) + return fmt.Errorf("api %s %s: rate limited after retries: %s", method, path, string(b)) + } + delay := retryAfterDelay(resp) + log.Debug("rate limited; retrying request", "method", method, "path", path, "delay", delay, "remaining_retries", remaining429) + time.Sleep(delay) + return c.doWithRetry(ctx, baseURL, method, path, query, body, out, remaining429-1, remaining401) } if resp.StatusCode == http.StatusUnauthorized { + if remaining401 == 0 { + b, _ := responseBodyBytes(resp) + return fmt.Errorf("api %s %s: unauthorized after re-auth retry: %s", method, path, string(b)) + } c.token = "" _ = tokencache.Clear(c.Identity()) if err := c.ensureToken(ctx); err != nil { return err } - return c.do(ctx, method, path, query, body, out) + return c.doWithRetry(ctx, baseURL, method, path, query, body, out, remaining429, remaining401-1) } if resp.StatusCode >= 300 { - b, _ := io.ReadAll(resp.Body) + b, _ := responseBodyBytes(resp) return fmt.Errorf("api %s %s: %s", method, path, string(b)) } if out != nil { - return json.NewDecoder(resp.Body).Decode(out) + return decodeJSONResponse(resp, out) } return nil } // TurnOn powers device on. func (c *Client) TurnOn(ctx context.Context) error { - return c.setPower(ctx, true) + return c.TurnOnForUser(ctx, "") } // TurnOff powers device off. func (c *Client) TurnOff(ctx context.Context) error { - return c.setPower(ctx, false) + return c.TurnOffForUser(ctx, "") } -func (c *Client) setPower(ctx context.Context, on bool) error { - if err := c.requireUser(ctx); err != nil { - return err +func (c *Client) TurnOnForUser(ctx context.Context, userID string) error { + return c.setPowerForUser(ctx, userID, true) +} + +func (c *Client) TurnOffForUser(ctx context.Context, userID string) error { + return c.setPowerForUser(ctx, userID, false) +} + +func (c *Client) setPowerForUser(ctx context.Context, userID string, on bool) error { + targetUserID := userID + if targetUserID == "" { + if err := c.requireUser(ctx); err != nil { + return err + } + targetUserID = c.UserID + } + path := fmt.Sprintf("/v1/users/%s/temperature", targetUserID) + state := "off" + if on { + state = "smart" } - path := fmt.Sprintf("/users/%s/devices/power", c.UserID) - body := map[string]bool{"on": on} - return c.do(ctx, http.MethodPost, path, nil, body, nil) + body := map[string]any{"currentState": map[string]string{"type": state}} + return c.doApp(ctx, http.MethodPut, path, nil, body, nil) } func (c *Client) Identity() tokencache.Identity { @@ -337,17 +437,34 @@ func (c *Client) Identity() tokencache.Identity { } } -// SetTemperature sets target heating/cooling level (-100..100). +// SetTemperature sets target heating/cooling level (-100..100) for the +// authenticated user's current pod side. func (c *Client) SetTemperature(ctx context.Context, level int) error { - if err := c.requireUser(ctx); err != nil { - return err - } + return c.SetTemperatureForUser(ctx, "", level) +} + +// SetTemperatureForUser sets target heating/cooling level (-100..100) for a +// specific household user ID. If userID is empty, the authenticated user's ID +// is resolved and used. +func (c *Client) SetTemperatureForUser(ctx context.Context, userID string, level int) error { if level < -100 || level > 100 { return fmt.Errorf("level must be between -100 and 100") } - path := fmt.Sprintf("/users/%s/temperature", c.UserID) + targetUserID := userID + if targetUserID == "" { + if err := c.requireUser(ctx); err != nil { + return err + } + targetUserID = c.UserID + } + path := fmt.Sprintf("/v1/users/%s/temperature", targetUserID) + if err := c.doApp(ctx, http.MethodPut, path, nil, map[string]any{ + "currentState": map[string]string{"type": "smart"}, + }, nil); err != nil { + return err + } body := map[string]int{"currentLevel": level} - return c.do(ctx, http.MethodPut, path, nil, body, nil) + return c.doApp(ctx, http.MethodPut, path, nil, body, nil) } // TempStatus represents current temperature state payload. @@ -360,12 +477,20 @@ type TempStatus struct { // GetStatus fetches temperature-based status (current mode/level). func (c *Client) GetStatus(ctx context.Context) (*TempStatus, error) { - if err := c.requireUser(ctx); err != nil { - return nil, err + return c.GetStatusForUser(ctx, "") +} + +func (c *Client) GetStatusForUser(ctx context.Context, userID string) (*TempStatus, error) { + targetUserID := userID + if targetUserID == "" { + if err := c.requireUser(ctx); err != nil { + return nil, err + } + targetUserID = c.UserID } - path := fmt.Sprintf("/users/%s/temperature", c.UserID) + path := fmt.Sprintf("/v1/users/%s/temperature", targetUserID) var res TempStatus - if err := c.do(ctx, http.MethodGet, path, nil, nil, &res); err != nil { + if err := c.doApp(ctx, http.MethodGet, path, nil, nil, &res); err != nil { return nil, err } return &res, nil diff --git a/internal/client/eightsleep_test.go b/internal/client/eightsleep_test.go index a6204ac..c59cfed 100644 --- a/internal/client/eightsleep_test.go +++ b/internal/client/eightsleep_test.go @@ -1,13 +1,23 @@ package client import ( + "bytes" + "compress/gzip" "context" + "io" "net/http" "net/http/httptest" + "strings" "testing" "time" ) +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + // mockServer builds a test server that can serve a handful of endpoints the client expects. func mockServer(t *testing.T) (*httptest.Server, *Client) { t.Helper() @@ -15,10 +25,10 @@ func mockServer(t *testing.T) (*httptest.Server, *Client) { mux.HandleFunc("/users/me", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - w.Write([]byte(`{"user":{"userId":"uid-123","currentDevice":{"id":"dev-1"}}}`)) + w.Write([]byte(`{"user":{"userId":"uid-123","devices":["dev-1"],"currentDevice":{"id":"dev-1"}}}`)) }) - mux.HandleFunc("/users/uid-123/temperature", func(w http.ResponseWriter, r *http.Request) { + mux.HandleFunc("/v1/users/uid-123/temperature", func(w http.ResponseWriter, r *http.Request) { if r.Method == http.MethodGet { w.Header().Set("Content-Type", "application/json") w.Write([]byte(`{"currentLevel":5,"currentState":{"type":"on"}}`)) @@ -46,6 +56,7 @@ func mockServer(t *testing.T) (*httptest.Server, *Client) { // client with pre-set token to skip auth c := New("email", "pass", "", "", "") c.BaseURL = srv.URL + c.AppURL = srv.URL c.token = "t" c.tokenExp = time.Now().Add(time.Hour) c.HTTP = srv.Client() @@ -101,3 +112,228 @@ func Test429Retry(t *testing.T) { t.Fatalf("expected backoff, got %v", elapsed) } } + +func Test429StopsAfterRetryLimit(t *testing.T) { + count := 0 + mux := http.NewServeMux() + mux.HandleFunc("/ping", func(w http.ResponseWriter, r *http.Request) { + count++ + w.WriteHeader(http.StatusTooManyRequests) + w.Write([]byte(`{"error":"limit exceeded"}`)) + }) + srv := httptest.NewServer(mux) + defer srv.Close() + + c := New("email", "pass", "uid", "", "") + c.BaseURL = srv.URL + c.token = "t" + c.tokenExp = time.Now().Add(time.Hour) + c.HTTP = srv.Client() + + start := time.Now() + err := c.do(context.Background(), http.MethodGet, "/ping", nil, nil, nil) + if err == nil { + t.Fatalf("expected rate limit error") + } + if !strings.Contains(err.Error(), "rate limited after retries") { + t.Fatalf("unexpected error: %v", err) + } + if count != maxRateLimitRetries+1 { + t.Fatalf("expected %d attempts, got %d", maxRateLimitRetries+1, count) + } + if elapsed := time.Since(start); elapsed < time.Duration(maxRateLimitRetries)*defaultRetryDelay { + t.Fatalf("expected retry delays, got %v", elapsed) + } +} + +func TestUnauthorizedStopsAfterReauthRetryLimit(t *testing.T) { + count := 0 + mux := http.NewServeMux() + mux.HandleFunc("/ping", func(w http.ResponseWriter, r *http.Request) { + count++ + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte(`{"error":"bad token"}`)) + }) + srv := httptest.NewServer(mux) + defer srv.Close() + + c := New("email", "pass", "uid", "", "") + c.BaseURL = srv.URL + c.token = "t" + c.tokenExp = time.Now().Add(time.Hour) + c.HTTP = &http.Client{ + Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + if req.URL.String() == srv.URL+"/ping" { + return srv.Client().Transport.RoundTrip(req) + } + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{"access_token":"tok","expires_in":3600,"userId":"uid"}`)), + Header: make(http.Header), + }, nil + }), + } + + err := c.do(context.Background(), http.MethodGet, "/ping", nil, nil, nil) + if err == nil { + t.Fatalf("expected unauthorized error") + } + if !strings.Contains(err.Error(), "unauthorized after re-auth retry") { + t.Fatalf("unexpected error: %v", err) + } + if count != maxUnauthorizedRetries+1 { + t.Fatalf("expected %d attempts, got %d", maxUnauthorizedRetries+1, count) + } +} + +func TestAuthTokenEndpointUsesClientCredentials(t *testing.T) { + c := New("user@example.com", "pass-123", "", "", "") + c.HTTP = &http.Client{ + Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + if req.Method != http.MethodPost { + t.Fatalf("method = %s, want POST", req.Method) + } + if got := req.URL.String(); got != authURL { + t.Fatalf("url = %s, want %s", got, authURL) + } + body, err := io.ReadAll(req.Body) + if err != nil { + t.Fatalf("read body: %v", err) + } + payload := string(body) + if !strings.Contains(payload, `"client_id":"`+defaultClientID+`"`) { + t.Fatalf("payload missing default client_id: %s", payload) + } + if !strings.Contains(payload, `"client_secret":"`+defaultClientSecret+`"`) { + t.Fatalf("payload missing default client_secret: %s", payload) + } + if strings.Contains(payload, `"client_id":"sleep-client"`) { + t.Fatalf("payload still contains legacy client_id: %s", payload) + } + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader( + `{"access_token":"tok","expires_in":3600,"userId":"uid-123"}`, + )), + Header: make(http.Header), + } + return resp, nil + }), + } + + if err := c.authTokenEndpoint(context.Background()); err != nil { + t.Fatalf("authTokenEndpoint: %v", err) + } + if c.token != "tok" { + t.Fatalf("token = %q, want tok", c.token) + } + if c.UserID != "uid-123" { + t.Fatalf("user id = %q, want uid-123", c.UserID) + } +} + +func TestDoHandlesGzipJSONResponse(t *testing.T) { + var payload bytes.Buffer + gz := gzip.NewWriter(&payload) + if _, err := gz.Write([]byte(`{"currentLevel":7,"currentState":{"type":"cooling"}}`)); err != nil { + t.Fatalf("gzip write: %v", err) + } + if err := gz.Close(); err != nil { + t.Fatalf("gzip close: %v", err) + } + + mux := http.NewServeMux() + mux.HandleFunc("/v1/users/uid-123/temperature", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Encoding", "gzip") + w.Header().Set("Content-Type", "application/json") + w.Write(payload.Bytes()) + }) + srv := httptest.NewServer(mux) + defer srv.Close() + + c := New("email", "pass", "uid-123", "", "") + c.BaseURL = srv.URL + c.AppURL = srv.URL + c.token = "t" + c.tokenExp = time.Now().Add(time.Hour) + c.HTTP = srv.Client() + + st, err := c.GetStatus(context.Background()) + if err != nil { + t.Fatalf("GetStatus: %v", err) + } + if st.CurrentLevel != 7 || st.CurrentState.Type != "cooling" { + t.Fatalf("unexpected status %+v", st) + } +} + +func TestSetTemperatureForUserUsesExplicitUserID(t *testing.T) { + var gotPaths []string + var gotBodies []string + + mux := http.NewServeMux() + mux.HandleFunc("/v1/users/other-user/temperature", func(w http.ResponseWriter, r *http.Request) { + gotPaths = append(gotPaths, r.URL.Path) + body, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("read body: %v", err) + } + gotBodies = append(gotBodies, string(body)) + w.WriteHeader(http.StatusNoContent) + }) + srv := httptest.NewServer(mux) + defer srv.Close() + + c := New("email", "pass", "auth-user", "", "") + c.BaseURL = srv.URL + c.AppURL = srv.URL + c.token = "t" + c.tokenExp = time.Now().Add(time.Hour) + c.HTTP = srv.Client() + + if err := c.SetTemperatureForUser(context.Background(), "other-user", 12); err != nil { + t.Fatalf("SetTemperatureForUser: %v", err) + } + if len(gotPaths) != 2 { + t.Fatalf("expected 2 app requests, got %d", len(gotPaths)) + } + if gotPaths[0] != "/v1/users/other-user/temperature" || gotPaths[1] != "/v1/users/other-user/temperature" { + t.Fatalf("paths = %#v, want both /v1/users/other-user/temperature", gotPaths) + } + if gotBodies[0] != `{"currentState":{"type":"smart"}}` { + t.Fatalf("first body = %q, want smart currentState payload", gotBodies[0]) + } + if gotBodies[1] != `{"currentLevel":12}` { + t.Fatalf("second body = %q, want {\"currentLevel\":12}", gotBodies[1]) + } +} + +func TestTurnOnForUserUsesSmartCurrentState(t *testing.T) { + var gotBody string + + mux := http.NewServeMux() + mux.HandleFunc("/v1/users/other-user/temperature", func(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("read body: %v", err) + } + gotBody = string(body) + w.WriteHeader(http.StatusNoContent) + }) + srv := httptest.NewServer(mux) + defer srv.Close() + + c := New("email", "pass", "auth-user", "", "") + c.BaseURL = srv.URL + c.AppURL = srv.URL + c.token = "t" + c.tokenExp = time.Now().Add(time.Hour) + c.HTTP = srv.Client() + + if err := c.TurnOnForUser(context.Background(), "other-user"); err != nil { + t.Fatalf("TurnOnForUser: %v", err) + } + if gotBody != `{"currentState":{"type":"smart"}}` { + t.Fatalf("body = %q, want smart currentState payload", gotBody) + } +} diff --git a/internal/client/household.go b/internal/client/household.go index 5ddfff0..945bbd5 100644 --- a/internal/client/household.go +++ b/internal/client/household.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net/http" + "net/url" ) type HouseholdActions struct{ c *Client } @@ -14,9 +15,9 @@ func (h *HouseholdActions) Summary(ctx context.Context) (any, error) { if err := h.c.requireUser(ctx); err != nil { return nil, err } - path := fmt.Sprintf("/household/users/%s/summary", h.c.UserID) + path := fmt.Sprintf("/v1/household/users/%s/summary", h.c.UserID) var res any - err := h.c.do(ctx, http.MethodGet, path, nil, nil, &res) + err := h.c.doApp(ctx, http.MethodGet, path, nil, nil, &res) return res, err } @@ -24,9 +25,9 @@ func (h *HouseholdActions) Schedule(ctx context.Context) (any, error) { if err := h.c.requireUser(ctx); err != nil { return nil, err } - path := fmt.Sprintf("/household/users/%s/schedule", h.c.UserID) + path := fmt.Sprintf("/v1/household/users/%s/schedule", h.c.UserID) var res any - err := h.c.do(ctx, http.MethodGet, path, nil, nil, &res) + err := h.c.doApp(ctx, http.MethodGet, path, nil, nil, &res) return res, err } @@ -34,9 +35,9 @@ func (h *HouseholdActions) CurrentSet(ctx context.Context) (any, error) { if err := h.c.requireUser(ctx); err != nil { return nil, err } - path := fmt.Sprintf("/household/users/%s/current-set", h.c.UserID) + path := fmt.Sprintf("/v1/household/users/%s/current-set", h.c.UserID) var res any - err := h.c.do(ctx, http.MethodGet, path, nil, nil, &res) + err := h.c.doApp(ctx, http.MethodGet, path, nil, nil, &res) return res, err } @@ -44,9 +45,9 @@ func (h *HouseholdActions) Invitations(ctx context.Context) (any, error) { if err := h.c.requireUser(ctx); err != nil { return nil, err } - path := fmt.Sprintf("/household/users/%s/invitations", h.c.UserID) + path := fmt.Sprintf("/v1/household/users/%s/invitations", h.c.UserID) var res any - err := h.c.do(ctx, http.MethodGet, path, nil, nil, &res) + err := h.c.doApp(ctx, http.MethodGet, path, nil, nil, &res) return res, err } @@ -54,28 +55,79 @@ func (h *HouseholdActions) Devices(ctx context.Context) (any, error) { if err := h.c.requireUser(ctx); err != nil { return nil, err } - path := fmt.Sprintf("/household/users/%s/devices", h.c.UserID) - var res any - err := h.c.do(ctx, http.MethodGet, path, nil, nil, &res) - return res, err + path := fmt.Sprintf("/v1/household/users/%s/summary", h.c.UserID) + var res struct { + Households []struct { + Sets []struct { + Devices []map[string]any `json:"devices"` + } `json:"sets"` + } `json:"households"` + } + err := h.c.doApp(ctx, http.MethodGet, path, nil, nil, &res) + if err != nil { + return nil, err + } + out := []map[string]any{} + for _, household := range res.Households { + for _, set := range household.Sets { + out = append(out, set.Devices...) + } + } + return out, nil } func (h *HouseholdActions) Users(ctx context.Context) (any, error) { - if err := h.c.requireUser(ctx); err != nil { + targets, err := h.c.HouseholdUserTargets(ctx) + if err != nil { return nil, err } - path := fmt.Sprintf("/household/users/%s/users", h.c.UserID) - var res any - err := h.c.do(ctx, http.MethodGet, path, nil, nil, &res) - return res, err + out := make([]map[string]any, 0, len(targets)) + for _, target := range targets { + out = append(out, map[string]any{ + "userId": target.UserID, + "firstName": target.FirstName, + "lastName": target.LastName, + "email": target.Email, + "side": target.Side, + }) + } + return out, nil } func (h *HouseholdActions) Guests(ctx context.Context) (any, error) { if err := h.c.requireUser(ctx); err != nil { return nil, err } - path := fmt.Sprintf("/household/users/%s/guests", h.c.UserID) + path := fmt.Sprintf("/v1/household/users/%s/guests", h.c.UserID) var res any - err := h.c.do(ctx, http.MethodGet, path, nil, nil, &res) + err := h.c.doApp(ctx, http.MethodGet, path, nil, nil, &res) return res, err } + +func mapToValues(values map[string]string) url.Values { + out := make(url.Values, len(values)) + for key, value := range values { + out.Set(key, value) + } + return out +} + +func orderedUniqueStrings(values ...string) []string { + out := []string{} + for _, value := range values { + out = appendUniqueString(out, value) + } + return out +} + +func appendUniqueString(existing []string, value string) []string { + if value == "" { + return existing + } + for _, current := range existing { + if current == value { + return existing + } + } + return append(existing, value) +} diff --git a/internal/client/metrics.go b/internal/client/metrics.go index 0b114ad..0e6e2dc 100644 --- a/internal/client/metrics.go +++ b/internal/client/metrics.go @@ -11,13 +11,14 @@ type MetricsActions struct{ c *Client } func (c *Client) Metrics() *MetricsActions { return &MetricsActions{c: c} } -func (m *MetricsActions) Trends(ctx context.Context, from, to string, out any) error { +func (m *MetricsActions) Trends(ctx context.Context, from, to, timezone string, out any) error { if err := m.c.requireUser(ctx); err != nil { return err } q := url.Values{} q.Set("from", from) q.Set("to", to) + q.Set("tz", timezone) q.Set("include-main", "false") q.Set("include-all-sessions", "true") q.Set("model-version", "v2") diff --git a/internal/client/presence.go b/internal/client/presence.go index 8353488..e97b18b 100644 --- a/internal/client/presence.go +++ b/internal/client/presence.go @@ -4,21 +4,83 @@ import ( "context" "fmt" "net/http" + "net/url" + "time" ) -// Presence indicates if user is in bed. -type Presence struct { - Present bool `json:"presence"` +type trendSample struct { + Days []trendDay `json:"days"` } -func (c *Client) GetPresence(ctx context.Context) (bool, error) { +type trendDay struct { + Day string `json:"day"` + PresenceStart string `json:"presenceStart"` + PresenceEnd string `json:"presenceEnd"` + Sessions []trendSession `json:"sessions"` +} + +type trendSession struct { + Timeseries map[string][][]any `json:"timeseries"` +} + +func (c *Client) GetPresence(ctx context.Context, timezone string) (bool, error) { if err := c.requireUser(ctx); err != nil { return false, err } - path := fmt.Sprintf("/users/%s/presence", c.UserID) - var res Presence - if err := c.do(ctx, http.MethodGet, path, nil, nil, &res); err != nil { + + now := time.Now() + q := url.Values{} + q.Set("tz", timezone) + q.Set("from", now.Add(-24*time.Hour).Format("2006-01-02")) + q.Set("to", now.Format("2006-01-02")) + q.Set("include-main", "false") + q.Set("include-all-sessions", "true") + q.Set("model-version", "v2") + + path := fmt.Sprintf("/users/%s/trends", c.UserID) + var res trendSample + if err := c.do(ctx, http.MethodGet, path, q, nil, &res); err != nil { return false, err } - return res.Present, nil + return presenceFromTrendDays(res.Days, now.UTC()), nil +} + +func presenceFromTrendDays(days []trendDay, now time.Time) bool { + for i := len(days) - 1; i >= 0; i-- { + day := days[i] + if ts, ok := latestHeartRateTimestamp(day); ok { + age := now.Sub(ts) + if age >= 0 && age <= 10*time.Minute { + return true + } + if day.PresenceEnd == "" && age >= 0 && age <= 30*time.Minute { + return true + } + } + if day.PresenceStart != "" { + return day.PresenceEnd == "" + } + } + return false +} + +func latestHeartRateTimestamp(day trendDay) (time.Time, bool) { + for i := len(day.Sessions) - 1; i >= 0; i-- { + samples := day.Sessions[i].Timeseries["heartRate"] + for j := len(samples) - 1; j >= 0; j-- { + if len(samples[j]) == 0 { + continue + } + rawTS, ok := samples[j][0].(string) + if !ok || rawTS == "" { + continue + } + ts, err := time.Parse(time.RFC3339Nano, rawTS) + if err != nil { + continue + } + return ts.UTC(), true + } + } + return time.Time{}, false } diff --git a/internal/client/presence_test.go b/internal/client/presence_test.go new file mode 100644 index 0000000..ccf054c --- /dev/null +++ b/internal/client/presence_test.go @@ -0,0 +1,81 @@ +package client + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestPresenceFromTrendDays(t *testing.T) { + now := time.Date(2026, 4, 13, 12, 0, 0, 0, time.UTC) + + t.Run("active session with recent heart rate", func(t *testing.T) { + days := []trendDay{{ + PresenceStart: "2026-04-13T11:00:00Z", + Sessions: []trendSession{{ + Timeseries: map[string][][]any{ + "heartRate": {{"2026-04-13T11:55:00Z", 60}}, + }, + }}, + }} + if !presenceFromTrendDays(days, now) { + t.Fatalf("expected presence to be true") + } + }) + + t.Run("ended session is not present", func(t *testing.T) { + days := []trendDay{{ + PresenceStart: "2026-04-13T02:00:00Z", + PresenceEnd: "2026-04-13T09:00:00Z", + Sessions: []trendSession{{ + Timeseries: map[string][][]any{ + "heartRate": {{"2026-04-13T08:55:00Z", 55}}, + }, + }}, + }} + if presenceFromTrendDays(days, now) { + t.Fatalf("expected presence to be false") + } + }) +} + +func TestGetPresenceUsesTrendsEndpoint(t *testing.T) { + var gotPath string + var gotTZ string + + mux := http.NewServeMux() + mux.HandleFunc("/users/me", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"user":{"userId":"uid-123","devices":["dev-1"],"currentDevice":{"id":"dev-1"}}}`)) + }) + mux.HandleFunc("/users/uid-123/trends", func(w http.ResponseWriter, r *http.Request) { + gotPath = r.URL.Path + gotTZ = r.URL.Query().Get("tz") + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"days":[]}`)) + }) + srv := httptest.NewServer(mux) + defer srv.Close() + + c := New("email", "pass", "", "", "") + c.BaseURL = srv.URL + c.token = "t" + c.tokenExp = time.Now().Add(time.Hour) + c.HTTP = srv.Client() + + present, err := c.GetPresence(context.Background(), "America/New_York") + if err != nil { + t.Fatalf("GetPresence: %v", err) + } + if present { + t.Fatalf("expected no presence from empty trends response") + } + if gotPath != "/users/uid-123/trends" { + t.Fatalf("path = %q, want /users/uid-123/trends", gotPath) + } + if gotTZ != "America/New_York" { + t.Fatalf("tz = %q, want America/New_York", gotTZ) + } +} diff --git a/internal/client/targets.go b/internal/client/targets.go new file mode 100644 index 0000000..1ea46f4 --- /dev/null +++ b/internal/client/targets.go @@ -0,0 +1,126 @@ +package client + +import ( + "context" + "fmt" + "net/http" + "strings" +) + +// HouseholdUserTarget describes a user that can be targeted for side-aware actions. +type HouseholdUserTarget struct { + UserID string + Side string + FirstName string + LastName string + Email string +} + +func (t HouseholdUserTarget) DisplayName() string { + name := strings.TrimSpace(strings.TrimSpace(t.FirstName + " " + t.LastName)) + if name != "" { + return name + } + if t.Email != "" { + return t.Email + } + return t.UserID +} + +func (t HouseholdUserTarget) SideLabel() string { + side := strings.TrimSpace(strings.ToLower(t.Side)) + if side == "" { + return "unknown" + } + return side +} + +// HouseholdUserTargets returns the household users that can be targeted for side-aware commands. +func (c *Client) HouseholdUserTargets(ctx context.Context) ([]HouseholdUserTarget, error) { + deviceID, err := c.EnsureDeviceID(ctx) + if err != nil { + return nil, err + } + var deviceRes struct { + Result struct { + LeftUserID string `json:"leftUserId"` + RightUserID string `json:"rightUserId"` + AwaySides map[string]string `json:"awaySides"` + } `json:"result"` + } + path := fmt.Sprintf("/devices/%s", deviceID) + query := mapToValues(map[string]string{ + "filter": "leftUserId,rightUserId,awaySides", + }) + if err := c.do(ctx, http.MethodGet, path, query, nil, &deviceRes); err != nil { + return nil, err + } + userIDs := orderedUniqueStrings( + deviceRes.Result.LeftUserID, + deviceRes.Result.RightUserID, + ) + for _, awayUserID := range deviceRes.Result.AwaySides { + userIDs = appendUniqueString(userIDs, awayUserID) + } + targets := make([]HouseholdUserTarget, 0, len(userIDs)) + for _, userID := range userIDs { + var userRes struct { + User struct { + UserID string `json:"userId"` + FirstName string `json:"firstName"` + LastName string `json:"lastName"` + Email string `json:"email"` + CurrentDevice struct { + Side string `json:"side"` + } `json:"currentDevice"` + } `json:"user"` + } + if err := c.do(ctx, http.MethodGet, fmt.Sprintf("/users/%s", userID), nil, nil, &userRes); err != nil { + return nil, err + } + targets = append(targets, HouseholdUserTarget{ + UserID: userRes.User.UserID, + Side: strings.ToLower(strings.TrimSpace(userRes.User.CurrentDevice.Side)), + FirstName: userRes.User.FirstName, + LastName: userRes.User.LastName, + Email: userRes.User.Email, + }) + } + if len(targets) == 1 && strings.TrimSpace(targets[0].Side) == "" { + targets[0].Side = "solo" + } + return targets, nil +} + +// ResolveHouseholdSide resolves a single user target for left/right/solo side-aware commands. +func ResolveHouseholdSide(targets []HouseholdUserTarget, side string) (*HouseholdUserTarget, error) { + side = strings.ToLower(strings.TrimSpace(side)) + switch side { + case "left", "right", "solo": + default: + return nil, fmt.Errorf("invalid side %q; expected left, right, or solo", side) + } + + matches := []HouseholdUserTarget{} + available := []string{} + for _, target := range targets { + if target.Side != "" { + available = appendUniqueString(available, target.Side) + } + if target.Side == side { + matches = append(matches, target) + } + } + + if len(matches) == 1 { + match := matches[0] + return &match, nil + } + if len(matches) > 1 { + return nil, fmt.Errorf("side %q maps to multiple household users; use --target-user-id", side) + } + if len(available) == 0 { + return nil, fmt.Errorf("could not resolve household side mapping; use --target-user-id") + } + return nil, fmt.Errorf("side %q is not available for this household; available sides: %s", side, strings.Join(available, ", ")) +} diff --git a/internal/client/targets_test.go b/internal/client/targets_test.go new file mode 100644 index 0000000..541efcd --- /dev/null +++ b/internal/client/targets_test.go @@ -0,0 +1,143 @@ +package client + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" +) + +func TestResolveHouseholdSideLeftRight(t *testing.T) { + targets := []HouseholdUserTarget{ + {UserID: "left-user", Side: "left", FirstName: "Lefty"}, + {UserID: "right-user", Side: "right", FirstName: "Righty"}, + } + + target, err := ResolveHouseholdSide(targets, "right") + if err != nil { + t.Fatalf("ResolveHouseholdSide: %v", err) + } + if target.UserID != "right-user" { + t.Fatalf("user id = %q, want right-user", target.UserID) + } +} + +func TestResolveHouseholdSideSolo(t *testing.T) { + target, err := ResolveHouseholdSide([]HouseholdUserTarget{ + {UserID: "solo-user", Side: "solo"}, + }, "solo") + if err != nil { + t.Fatalf("ResolveHouseholdSide: %v", err) + } + if target.UserID != "solo-user" { + t.Fatalf("user id = %q, want solo-user", target.UserID) + } +} + +func TestResolveHouseholdSideUnavailable(t *testing.T) { + _, err := ResolveHouseholdSide([]HouseholdUserTarget{ + {UserID: "solo-user", Side: "solo"}, + }, "right") + if err == nil { + t.Fatalf("expected side resolution error") + } + if !strings.Contains(err.Error(), `side "right" is not available`) { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestResolveHouseholdSideUnknownMapping(t *testing.T) { + _, err := ResolveHouseholdSide([]HouseholdUserTarget{ + {UserID: "mystery-user"}, + }, "left") + if err == nil { + t.Fatalf("expected unknown mapping error") + } + if !strings.Contains(err.Error(), "could not resolve household side mapping") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestHouseholdUserTargets(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("/users/me", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"user":{"devices":["dev-1"]}}`)) + }) + mux.HandleFunc("/devices/dev-1", func(w http.ResponseWriter, r *http.Request) { + if got := r.URL.Query().Get("filter"); got != "leftUserId,rightUserId,awaySides" { + t.Fatalf("filter = %q", got) + } + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"result":{"leftUserId":"left-user","rightUserId":"right-user"}}`)) + }) + mux.HandleFunc("/users/left-user", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"user":{"userId":"left-user","firstName":"Igor","lastName":"Left","email":"left@example.com","currentDevice":{"side":"left"}}}`)) + }) + mux.HandleFunc("/users/right-user", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"user":{"userId":"right-user","firstName":"Renata","lastName":"Right","email":"right@example.com","currentDevice":{"side":"right"}}}`)) + }) + + srv := httptest.NewServer(mux) + defer srv.Close() + + c := New("email", "pass", "", "", "") + c.BaseURL = srv.URL + c.token = "t" + c.tokenExp = time.Now().Add(time.Hour) + c.HTTP = srv.Client() + + targets, err := c.HouseholdUserTargets(context.Background()) + if err != nil { + t.Fatalf("HouseholdUserTargets: %v", err) + } + if len(targets) != 2 { + t.Fatalf("len(targets) = %d, want 2", len(targets)) + } + if targets[0].UserID != "left-user" || targets[0].Side != "left" { + t.Fatalf("unexpected left target: %+v", targets[0]) + } + if targets[1].UserID != "right-user" || targets[1].Side != "right" { + t.Fatalf("unexpected right target: %+v", targets[1]) + } +} + +func TestHouseholdUserTargetsInfersSoloWhenOnlyOneUserExists(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("/users/me", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"user":{"devices":["dev-1"]}}`)) + }) + mux.HandleFunc("/devices/dev-1", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"result":{"leftUserId":"solo-user"}}`)) + }) + mux.HandleFunc("/users/solo-user", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"user":{"userId":"solo-user","firstName":"Solo","lastName":"Sleeper","email":"solo@example.com","currentDevice":{}}}`)) + }) + + srv := httptest.NewServer(mux) + defer srv.Close() + + c := New("email", "pass", "", "", "") + c.BaseURL = srv.URL + c.token = "t" + c.tokenExp = time.Now().Add(time.Hour) + c.HTTP = srv.Client() + + targets, err := c.HouseholdUserTargets(context.Background()) + if err != nil { + t.Fatalf("HouseholdUserTargets: %v", err) + } + if len(targets) != 1 { + t.Fatalf("len(targets) = %d, want 1", len(targets)) + } + if targets[0].Side != "solo" { + t.Fatalf("side = %q, want solo", targets[0].Side) + } +} diff --git a/internal/cmd/metrics.go b/internal/cmd/metrics.go index fdb9255..ecaddee 100644 --- a/internal/cmd/metrics.go +++ b/internal/cmd/metrics.go @@ -16,11 +16,21 @@ var metricsTrendsCmd = &cobra.Command{Use: "trends", RunE: func(cmd *cobra.Comma if err := requireAuthFields(); err != nil { return err } - from := viper.GetString("from") - to := viper.GetString("to") + from, err := cmd.Flags().GetString("from") + if err != nil { + return err + } + to, err := cmd.Flags().GetString("to") + if err != nil { + return err + } + tz, err := resolveAPITimezone(viper.GetString("timezone")) + if err != nil { + return err + } cl := client.New(viper.GetString("email"), viper.GetString("password"), viper.GetString("user_id"), viper.GetString("client_id"), viper.GetString("client_secret")) var out any - if err := cl.Metrics().Trends(context.Background(), from, to, &out); err != nil { + if err := cl.Metrics().Trends(context.Background(), from, to, tz, &out); err != nil { return err } return output.Print(output.Format(viper.GetString("output")), []string{"trends"}, []map[string]any{{"trends": out}}) @@ -30,7 +40,10 @@ var metricsIntervalsCmd = &cobra.Command{Use: "intervals", RunE: func(cmd *cobra if err := requireAuthFields(); err != nil { return err } - id := viper.GetString("id") + id, err := cmd.Flags().GetString("id") + if err != nil { + return err + } cl := client.New(viper.GetString("email"), viper.GetString("password"), viper.GetString("user_id"), viper.GetString("client_id"), viper.GetString("client_secret")) var out any if err := cl.Metrics().Intervals(context.Background(), id, &out); err != nil { diff --git a/internal/cmd/off.go b/internal/cmd/off.go index bb62f15..5a944b5 100644 --- a/internal/cmd/off.go +++ b/internal/cmd/off.go @@ -18,10 +18,28 @@ var offCmd = &cobra.Command{ return err } cl := client.New(viper.GetString("email"), viper.GetString("password"), viper.GetString("user_id"), viper.GetString("client_id"), viper.GetString("client_secret")) - if err := cl.TurnOff(context.Background()); err != nil { + targets, targeted, err := resolveCommandTargets(context.Background(), cmd, cl) + if err != nil { return err } - fmt.Println("pod turned off") + if targeted { + for _, target := range targets { + if err := cl.TurnOffForUser(context.Background(), target.UserID); err != nil { + return err + } + } + fmt.Printf("pod turned off%s\n", targetListSuffix(targets)) + return nil + } + + if err := cl.TurnOffForUser(context.Background(), ""); err != nil { + return err + } + fmt.Printf("pod turned off\n") return nil }, } + +func init() { + addTargetingFlags(offCmd, true) +} diff --git a/internal/cmd/on.go b/internal/cmd/on.go index 807a09b..41e5c7f 100644 --- a/internal/cmd/on.go +++ b/internal/cmd/on.go @@ -18,10 +18,28 @@ var onCmd = &cobra.Command{ return err } cl := client.New(viper.GetString("email"), viper.GetString("password"), viper.GetString("user_id"), viper.GetString("client_id"), viper.GetString("client_secret")) - if err := cl.TurnOn(context.Background()); err != nil { + targets, targeted, err := resolveCommandTargets(context.Background(), cmd, cl) + if err != nil { return err } - fmt.Println("pod turned on") + if targeted { + for _, target := range targets { + if err := cl.TurnOnForUser(context.Background(), target.UserID); err != nil { + return err + } + } + fmt.Printf("pod turned on%s\n", targetListSuffix(targets)) + return nil + } + + if err := cl.TurnOnForUser(context.Background(), ""); err != nil { + return err + } + fmt.Printf("pod turned on\n") return nil }, } + +func init() { + addTargetingFlags(onCmd, true) +} diff --git a/internal/cmd/presence.go b/internal/cmd/presence.go index 0b22f55..5aae198 100644 --- a/internal/cmd/presence.go +++ b/internal/cmd/presence.go @@ -2,12 +2,12 @@ package cmd import ( "context" - "fmt" "github.com/spf13/cobra" "github.com/spf13/viper" "github.com/steipete/eightctl/internal/client" + "github.com/steipete/eightctl/internal/output" ) var presenceCmd = &cobra.Command{ @@ -17,12 +17,15 @@ var presenceCmd = &cobra.Command{ if err := requireAuthFields(); err != nil { return err } + tz, err := resolveAPITimezone(viper.GetString("timezone")) + if err != nil { + return err + } cl := client.New(viper.GetString("email"), viper.GetString("password"), viper.GetString("user_id"), viper.GetString("client_id"), viper.GetString("client_secret")) - present, err := cl.GetPresence(context.Background()) + present, err := cl.GetPresence(context.Background(), tz) if err != nil { return err } - fmt.Printf("present: %v\n", present) - return nil + return output.Print(output.Format(viper.GetString("output")), []string{"present"}, []map[string]any{{"present": present}}) }, } diff --git a/internal/cmd/sleep.go b/internal/cmd/sleep.go index 8a22ff1..83c2548 100644 --- a/internal/cmd/sleep.go +++ b/internal/cmd/sleep.go @@ -2,7 +2,6 @@ package cmd import ( "context" - "time" "github.com/spf13/cobra" "github.com/spf13/viper" @@ -24,13 +23,16 @@ var sleepDayCmd = &cobra.Command{ if err := requireAuthFields(); err != nil { return err } - date := viper.GetString("date") + date, err := cmd.Flags().GetString("date") + if err != nil { + return err + } if date == "" { - date = time.Now().Format("2006-01-02") + date = currentDate() } - tz := viper.GetString("timezone") - if tz == "local" { - tz = time.Local.String() + tz, err := resolveAPITimezone(viper.GetString("timezone")) + if err != nil { + return err } cl := client.New(viper.GetString("email"), viper.GetString("password"), viper.GetString("user_id"), viper.GetString("client_id"), viper.GetString("client_secret")) day, err := cl.GetSleepDay(context.Background(), date, tz) @@ -56,7 +58,7 @@ var sleepDayCmd = &cobra.Command{ } func init() { - sleepCmd.PersistentFlags().String("date", "", "date YYYY-MM-DD (default today)") - viper.BindPFlag("date", sleepCmd.PersistentFlags().Lookup("date")) + sleepDayCmd.Flags().String("date", "", "date YYYY-MM-DD (default today)") + viper.BindPFlag("date", sleepDayCmd.Flags().Lookup("date")) sleepCmd.AddCommand(sleepDayCmd) } diff --git a/internal/cmd/sleep_range.go b/internal/cmd/sleep_range.go index d31f1ec..6ef1c9f 100644 --- a/internal/cmd/sleep_range.go +++ b/internal/cmd/sleep_range.go @@ -19,8 +19,14 @@ var sleepRangeCmd = &cobra.Command{ if err := requireAuthFields(); err != nil { return err } - from := viper.GetString("from") - to := viper.GetString("to") + from, err := cmd.Flags().GetString("from") + if err != nil { + return err + } + to, err := cmd.Flags().GetString("to") + if err != nil { + return err + } if from == "" || to == "" { return fmt.Errorf("--from and --to are required") } @@ -36,9 +42,9 @@ var sleepRangeCmd = &cobra.Command{ if end.Before(start) { return fmt.Errorf("to must be >= from") } - tz := viper.GetString("timezone") - if tz == "local" { - tz = time.Local.String() + tz, err := resolveAPITimezone(viper.GetString("timezone")) + if err != nil { + return err } cl := client.New(viper.GetString("email"), viper.GetString("password"), viper.GetString("user_id"), viper.GetString("client_id"), viper.GetString("client_secret")) rows := []map[string]any{} diff --git a/internal/cmd/status.go b/internal/cmd/status.go index 9fb2a68..91265cf 100644 --- a/internal/cmd/status.go +++ b/internal/cmd/status.go @@ -2,6 +2,7 @@ package cmd import ( "context" + "fmt" "github.com/spf13/cobra" "github.com/spf13/viper" @@ -18,17 +19,98 @@ var statusCmd = &cobra.Command{ return err } cl := client.New(viper.GetString("email"), viper.GetString("password"), viper.GetString("user_id"), viper.GetString("client_id"), viper.GetString("client_secret")) - st, err := cl.GetStatus(context.Background()) + allSides, err := cmd.Flags().GetBool("all-sides") if err != nil { return err } - row := map[string]any{"mode": st.CurrentState.Type, "level": st.CurrentLevel} + target, err := resolveSelectedTarget(context.Background(), cmd, cl) + if err != nil { + return err + } + if allSides && target != nil { + return fmt.Errorf("use --all-sides by itself, not with --side or --target-user-id") + } + rows := []map[string]any{} + headers := []string{"mode", "level"} + if allSides { + targets, err := cl.HouseholdUserTargets(context.Background()) + if err != nil { + return err + } + rows, err = householdStatusRows(context.Background(), cl, targets) + if err != nil { + return err + } + headers = householdStatusHeaders() + } else { + rows, headers, err = defaultStatusRows(context.Background(), cl, target) + if err != nil { + return err + } + } fields := viper.GetStringSlice("fields") - rows := output.FilterFields([]map[string]any{row}, fields) - headers := fields - if len(headers) == 0 { - headers = []string{"mode", "level"} + rows = output.FilterFields(rows, fields) + if len(fields) > 0 { + headers = fields } return output.Print(output.Format(viper.GetString("output")), headers, rows) }, } + +func init() { + addTargetingFlags(statusCmd, true) + statusCmd.Flags().Bool("all-sides", false, "show status for all discovered household sides") +} + +func defaultStatusRows(ctx context.Context, cl *client.Client, target *client.HouseholdUserTarget) ([]map[string]any, []string, error) { + if target != nil { + st, err := cl.GetStatusForUser(ctx, target.UserID) + if err != nil { + return nil, nil, err + } + return []map[string]any{{ + "side": target.SideLabel(), + "name": target.DisplayName(), + "user_id": target.UserID, + "mode": st.CurrentState.Type, + "level": st.CurrentLevel, + }}, householdStatusHeaders(), nil + } + + targets, err := cl.HouseholdUserTargets(ctx) + if err == nil && len(targets) > 0 { + rows, err := householdStatusRows(ctx, cl, targets) + if err != nil { + return nil, nil, err + } + return rows, householdStatusHeaders(), nil + } + + st, err := cl.GetStatusForUser(ctx, "") + if err != nil { + return nil, nil, err + } + return []map[string]any{{"mode": st.CurrentState.Type, "level": st.CurrentLevel}}, []string{"mode", "level"}, nil +} + +func householdStatusRows(ctx context.Context, cl *client.Client, targets []client.HouseholdUserTarget) ([]map[string]any, error) { + rows := make([]map[string]any, 0, len(targets)) + for _, current := range targets { + st, err := cl.GetStatusForUser(ctx, current.UserID) + if err != nil { + return nil, err + } + rows = append(rows, map[string]any{ + "side": current.SideLabel(), + "name": current.DisplayName(), + "user_id": current.UserID, + "mode": st.CurrentState.Type, + "level": st.CurrentLevel, + }) + } + return rows, nil +} + +func householdStatusHeaders() []string { + return []string{"side", "name", "user_id", "mode", "level"} +} diff --git a/internal/cmd/targeting.go b/internal/cmd/targeting.go new file mode 100644 index 0000000..7202c47 --- /dev/null +++ b/internal/cmd/targeting.go @@ -0,0 +1,121 @@ +package cmd + +import ( + "context" + "fmt" + "strings" + + "github.com/spf13/cobra" + + "github.com/steipete/eightctl/internal/client" +) + +func addTargetingFlags(cmd *cobra.Command, includeTargetUser bool) { + cmd.Flags().String("side", "", "target household side: left|right|solo") + if includeTargetUser { + cmd.Flags().String("target-user-id", "", "set or query a specific household user ID") + } +} + +func resolveSelectedTarget(ctx context.Context, cmd *cobra.Command, cl *client.Client) (*client.HouseholdUserTarget, error) { + targetUserID, err := cmd.Flags().GetString("target-user-id") + if err != nil { + return nil, err + } + side, err := cmd.Flags().GetString("side") + if err != nil { + return nil, err + } + return resolveSelectedTargetValues(ctx, cl, targetUserID, side) +} + +func resolveCommandTargets(ctx context.Context, cmd *cobra.Command, cl *client.Client) ([]client.HouseholdUserTarget, bool, error) { + targetUserID, err := cmd.Flags().GetString("target-user-id") + if err != nil { + return nil, false, err + } + side, err := cmd.Flags().GetString("side") + if err != nil { + return nil, false, err + } + return resolveCommandTargetValues(ctx, cl, targetUserID, side) +} + +func resolveSelectedTargetValues(ctx context.Context, cl *client.Client, targetUserID string, side string) (*client.HouseholdUserTarget, error) { + if targetUserID != "" && side != "" { + return nil, fmt.Errorf("use either --target-user-id or --side, not both") + } + if side != "" { + targets, err := cl.HouseholdUserTargets(ctx) + if err != nil { + return nil, err + } + return client.ResolveHouseholdSide(targets, side) + } + if targetUserID == "" { + return nil, nil + } + + targets, err := cl.HouseholdUserTargets(ctx) + if err != nil { + return &client.HouseholdUserTarget{UserID: targetUserID}, nil + } + for _, target := range targets { + if target.UserID == targetUserID { + return &target, nil + } + } + return &client.HouseholdUserTarget{UserID: targetUserID}, nil +} + +func resolveCommandTargetValues(ctx context.Context, cl *client.Client, targetUserID string, side string) ([]client.HouseholdUserTarget, bool, error) { + if targetUserID != "" || side != "" { + target, err := resolveSelectedTargetValues(ctx, cl, targetUserID, side) + if err != nil { + return nil, false, err + } + if target == nil { + return nil, false, nil + } + return []client.HouseholdUserTarget{*target}, true, nil + } + + targets, err := cl.HouseholdUserTargets(ctx) + if err != nil || len(targets) == 0 { + return nil, false, nil + } + return targets, true, nil +} + +func targetSuffix(target *client.HouseholdUserTarget) string { + if target == nil { + return "" + } + if side := strings.TrimSpace(target.Side); side != "" { + return " for side " + side + } + if target.UserID != "" { + return " for user " + target.UserID + } + return "" +} + +func targetListSuffix(targets []client.HouseholdUserTarget) string { + if len(targets) == 0 { + return "" + } + if len(targets) == 1 { + target := targets[0] + return targetSuffix(&target) + } + + sides := []string{} + for _, target := range targets { + side := strings.TrimSpace(target.Side) + if side == "" { + return " for all discovered users" + } + sides = append(sides, side) + } + return " for sides " + strings.Join(sides, ", ") +} diff --git a/internal/cmd/targeting_test.go b/internal/cmd/targeting_test.go new file mode 100644 index 0000000..25e623c --- /dev/null +++ b/internal/cmd/targeting_test.go @@ -0,0 +1,34 @@ +package cmd + +import ( + "testing" + + "github.com/steipete/eightctl/internal/client" +) + +func TestTargetListSuffixSingle(t *testing.T) { + got := targetListSuffix([]client.HouseholdUserTarget{{UserID: "u1", Side: "right"}}) + if want := " for side right"; got != want { + t.Fatalf("suffix = %q, want %q", got, want) + } +} + +func TestTargetListSuffixMultipleSides(t *testing.T) { + got := targetListSuffix([]client.HouseholdUserTarget{ + {UserID: "u1", Side: "left"}, + {UserID: "u2", Side: "right"}, + }) + if want := " for sides left, right"; got != want { + t.Fatalf("suffix = %q, want %q", got, want) + } +} + +func TestTargetListSuffixMultipleUsersWithoutSides(t *testing.T) { + got := targetListSuffix([]client.HouseholdUserTarget{ + {UserID: "u1"}, + {UserID: "u2"}, + }) + if want := " for all discovered users"; got != want { + t.Fatalf("suffix = %q, want %q", got, want) + } +} diff --git a/internal/cmd/temp.go b/internal/cmd/temp.go index 494c10a..add45ad 100644 --- a/internal/cmd/temp.go +++ b/internal/cmd/temp.go @@ -3,6 +3,7 @@ package cmd import ( "context" "fmt" + "strings" "github.com/spf13/cobra" "github.com/spf13/viper" @@ -12,22 +13,102 @@ import ( ) var tempCmd = &cobra.Command{ - Use: "temp ", - Short: "Set pod temperature (e.g., 68F, 20C, or heating level -100..100)", - Args: cobra.ExactArgs(1), + Use: "temp ", + Short: "Set pod temperature (e.g., 68F, 20C, or heating level -100..100)", + DisableFlagParsing: true, RunE: func(cmd *cobra.Command, args []string) error { if err := requireAuthFields(); err != nil { return err } - lvl, err := daemon.ParseTemp(args[0]) + tempValue, targetUserID, side, help, err := parseTempCommandArgs(args) + if err != nil { + return err + } + if help { + return cmd.Help() + } + lvl, err := daemon.ParseTemp(tempValue) if err != nil { return err } cl := client.New(viper.GetString("email"), viper.GetString("password"), viper.GetString("user_id"), viper.GetString("client_id"), viper.GetString("client_secret")) - if err := cl.SetTemperature(context.Background(), lvl); err != nil { + targets, targeted, err := resolveCommandTargetValues(context.Background(), cl, targetUserID, side) + if err != nil { + return err + } + if targeted { + for _, target := range targets { + if err := cl.SetTemperatureForUser(context.Background(), target.UserID, lvl); err != nil { + return err + } + } + fmt.Printf("temperature set (level %d)%s\n", lvl, targetListSuffix(targets)) + return nil + } + + if err := cl.SetTemperatureForUser(context.Background(), "", lvl); err != nil { return err } fmt.Printf("temperature set (level %d)\n", lvl) return nil }, } + +func init() { + addTargetingFlags(tempCmd, true) +} + +func parseTempCommandArgs(args []string) (tempValue string, targetUserID string, side string, help bool, err error) { + for i := 0; i < len(args); i++ { + arg := args[i] + switch { + case arg == "-h" || arg == "--help": + return "", "", "", true, nil + case arg == "--side": + i++ + if i >= len(args) { + return "", "", "", false, fmt.Errorf("flag needs an argument: --side") + } + side = args[i] + case strings.HasPrefix(arg, "--side="): + side = strings.TrimPrefix(arg, "--side=") + case arg == "--target-user-id": + i++ + if i >= len(args) { + return "", "", "", false, fmt.Errorf("flag needs an argument: --target-user-id") + } + targetUserID = args[i] + case strings.HasPrefix(arg, "--target-user-id="): + targetUserID = strings.TrimPrefix(arg, "--target-user-id=") + case arg == "--": + if i+1 >= len(args) { + return "", "", "", false, fmt.Errorf("requires exactly 1 temperature value") + } + if tempValue != "" || len(args[i+1:]) != 1 { + return "", "", "", false, fmt.Errorf("requires exactly 1 temperature value") + } + tempValue = args[i+1] + i = len(args) + case strings.HasPrefix(arg, "-") && !isNegativeTempCandidate(arg): + return "", "", "", false, fmt.Errorf("unknown flag: %s", arg) + default: + if tempValue != "" { + return "", "", "", false, fmt.Errorf("requires exactly 1 temperature value") + } + tempValue = arg + } + } + + if tempValue == "" { + return "", "", "", false, fmt.Errorf("requires exactly 1 temperature value") + } + return tempValue, targetUserID, side, false, nil +} + +func isNegativeTempCandidate(arg string) bool { + if len(arg) < 2 || arg[0] != '-' { + return false + } + b := arg[1] + return (b >= '0' && b <= '9') || b == '.' +} diff --git a/internal/cmd/temp_test.go b/internal/cmd/temp_test.go new file mode 100644 index 0000000..0dc5996 --- /dev/null +++ b/internal/cmd/temp_test.go @@ -0,0 +1,71 @@ +package cmd + +import "testing" + +func TestParseTempCommandArgsAllowsNegativeLevelBeforeFlags(t *testing.T) { + tempValue, targetUserID, side, help, err := parseTempCommandArgs([]string{"-40", "--side", "right"}) + if err != nil { + t.Fatalf("parse args: %v", err) + } + if help { + t.Fatalf("did not expect help") + } + if tempValue != "-40" { + t.Fatalf("tempValue = %q, want %q", tempValue, "-40") + } + if side != "right" { + t.Fatalf("side = %q, want %q", side, "right") + } + if targetUserID != "" { + t.Fatalf("targetUserID = %q, want empty", targetUserID) + } +} + +func TestParseTempCommandArgsAllowsNegativeCelsiusAfterFlags(t *testing.T) { + tempValue, targetUserID, side, help, err := parseTempCommandArgs([]string{"--target-user-id", "user-123", "-40C"}) + if err != nil { + t.Fatalf("parse args: %v", err) + } + if help { + t.Fatalf("did not expect help") + } + if tempValue != "-40C" { + t.Fatalf("tempValue = %q, want %q", tempValue, "-40C") + } + if side != "" { + t.Fatalf("side = %q, want empty", side) + } + if targetUserID != "user-123" { + t.Fatalf("targetUserID = %q, want %q", targetUserID, "user-123") + } +} + +func TestParseTempCommandArgsRejectsUnknownFlag(t *testing.T) { + _, _, _, _, err := parseTempCommandArgs([]string{"--bogus", "-40"}) + if err == nil { + t.Fatalf("expected error") + } + if got, want := err.Error(), "unknown flag: --bogus"; got != want { + t.Fatalf("error = %q, want %q", got, want) + } +} + +func TestParseTempCommandArgsRejectsMissingTemperature(t *testing.T) { + _, _, _, _, err := parseTempCommandArgs([]string{"--side", "left"}) + if err == nil { + t.Fatalf("expected error") + } + if got, want := err.Error(), "requires exactly 1 temperature value"; got != want { + t.Fatalf("error = %q, want %q", got, want) + } +} + +func TestParseTempCommandArgsHelp(t *testing.T) { + _, _, _, help, err := parseTempCommandArgs([]string{"--help"}) + if err != nil { + t.Fatalf("parse args: %v", err) + } + if !help { + t.Fatalf("expected help") + } +} diff --git a/internal/cmd/timezone.go b/internal/cmd/timezone.go new file mode 100644 index 0000000..eb13b3e --- /dev/null +++ b/internal/cmd/timezone.go @@ -0,0 +1,22 @@ +package cmd + +import ( + "fmt" + "strings" + "time" +) + +func resolveAPITimezone(value string) (string, error) { + tz := strings.TrimSpace(value) + if tz == "" || strings.EqualFold(tz, "local") { + tz = strings.TrimSpace(time.Local.String()) + } + if tz == "" || strings.EqualFold(tz, "local") { + return "", fmt.Errorf("timezone must be an explicit IANA timezone for sleep/metrics queries on this system; set --timezone or EIGHTCTL_TIMEZONE, e.g. America/New_York") + } + return tz, nil +} + +func currentDate() string { + return time.Now().Format("2006-01-02") +} diff --git a/internal/cmd/timezone_test.go b/internal/cmd/timezone_test.go new file mode 100644 index 0000000..b7ba411 --- /dev/null +++ b/internal/cmd/timezone_test.go @@ -0,0 +1,13 @@ +package cmd + +import "testing" + +func TestResolveAPITimezoneExplicit(t *testing.T) { + got, err := resolveAPITimezone("America/New_York") + if err != nil { + t.Fatalf("resolveAPITimezone: %v", err) + } + if got != "America/New_York" { + t.Fatalf("timezone = %q, want America/New_York", got) + } +} diff --git a/internal/tokencache/tokencache.go b/internal/tokencache/tokencache.go index 9d1c6a6..801d039 100644 --- a/internal/tokencache/tokencache.go +++ b/internal/tokencache/tokencache.go @@ -1,7 +1,9 @@ package tokencache import ( + "encoding/base64" "encoding/json" + "errors" "os" "path/filepath" "strings" @@ -12,8 +14,9 @@ import ( ) const ( - serviceName = "eightctl" - tokenKey = "oauth-token" + serviceName = "eightctl" + tokenKey = "oauth-token" + storageKeyV2Prefix = tokenKey + "_v2_" ) type CachedToken struct { @@ -75,7 +78,7 @@ func Save(id Identity, token string, expiresAt time.Time, userID string) error { return err } if err := ring.Set(keyring.Item{ - Key: cacheKey(id), + Key: storageKey(id), Label: serviceName + " token", Data: data, }); err != nil { @@ -92,8 +95,17 @@ func Load(id Identity, expectedUserID string) (*CachedToken, error) { log.Debug("keyring open failed (load)", "error", err) return nil, err } - key := cacheKey(id) + key := storageKey(id) item, err := ring.Get(key) + if err == keyring.ErrKeyNotFound { + legacyKey := cacheKey(id) + item, err = ring.Get(legacyKey) + if err == nil { + key = legacyKey + } else if isIgnorableLegacyKeyError(err) { + err = keyring.ErrKeyNotFound + } + } if err == keyring.ErrKeyNotFound && id.Email == "" { // No email specified: attempt to find a single matching token for this base/client. if alt, findErr := findSingleForClient(ring, id); findErr == nil { @@ -126,12 +138,13 @@ func Clear(id Identity) error { if err != nil { return err } - key := cacheKey(id) - if err := ring.Remove(key); err != nil { - if err == keyring.ErrKeyNotFound || os.IsNotExist(err) { - return nil + for _, key := range []string{storageKey(id), cacheKey(id)} { + if err := ring.Remove(key); err != nil { + if err == keyring.ErrKeyNotFound || os.IsNotExist(err) || isIgnorableLegacyKeyError(err) { + continue + } + return err } - return err } return nil } @@ -142,6 +155,36 @@ func cacheKey(id Identity) string { return tokenKey + ":" + base + "|" + id.ClientID + "|" + email } +func storageKey(id Identity) string { + return storageKeyV2Prefix + base64.RawURLEncoding.EncodeToString([]byte(cacheKey(id))) +} + +func identityKeyFromStorageKey(key string) (string, bool) { + if strings.HasPrefix(key, storageKeyV2Prefix) { + raw := strings.TrimPrefix(key, storageKeyV2Prefix) + decoded, err := base64.RawURLEncoding.DecodeString(raw) + if err != nil { + return "", false + } + return string(decoded), true + } + if strings.HasPrefix(key, tokenKey+":") { + return key, true + } + return "", false +} + +func isIgnorableLegacyKeyError(err error) bool { + if err == nil { + return false + } + var pathErr *os.PathError + if errors.As(err, &pathErr) { + return true + } + return strings.Contains(strings.ToLower(err.Error()), "filename, directory name, or volume label syntax is incorrect") +} + // findSingleForClient finds a single cached key for the given base/client when email is unknown. // Returns ErrKeyNotFound if none or multiple exist. func findSingleForClient(ring keyring.Keyring, id Identity) (string, error) { @@ -152,7 +195,8 @@ func findSingleForClient(ring keyring.Keyring, id Identity) (string, error) { prefix := tokenKey + ":" + strings.TrimSuffix(strings.ToLower(strings.TrimSpace(id.BaseURL)), "/") + "|" + id.ClientID + "|" matches := []string{} for _, k := range keys { - if strings.HasPrefix(k, prefix) { + identityKey, ok := identityKeyFromStorageKey(k) + if ok && strings.HasPrefix(identityKey, prefix) { matches = append(matches, k) } }