diff --git a/decoder/api_fort.go b/decoder/api_fort.go index c6cd33f0..8a06b654 100644 --- a/decoder/api_fort.go +++ b/decoder/api_fort.go @@ -331,9 +331,9 @@ func StationScanEndpoint(retrieveParameters ApiFortScan, dbDetails db.DbDetails) start := time.Now() for _, key := range returnKeys { - station, unlock, err := getStationRecordReadOnly(context.Background(), dbDetails, key, "API.GetScanStation") + station, unlock, err := GetStationRecordReadOnly(context.Background(), dbDetails, key, "API.GetScanStation") if err == nil && station != nil { - stationCopy := buildStationResult(station) + stationCopy := BuildStationResult(station) results = append(results, &stationCopy) } if unlock != nil { @@ -380,9 +380,9 @@ func FortCombinedScanEndpoint(retrieveParameters ApiFortScan, dbDetails db.DbDet stations := make([]*ApiStationResult, 0, len(stationKeys)) for _, key := range stationKeys { - station, unlock, err := getStationRecordReadOnly(context.Background(), dbDetails, key, "API.GetScanStationPokemon") + station, unlock, err := GetStationRecordReadOnly(context.Background(), dbDetails, key, "API.GetScanStationPokemon") if err == nil && station != nil { - stationCopy := buildStationResult(station) + stationCopy := BuildStationResult(station) stations = append(stations, &stationCopy) } if unlock != nil { diff --git a/decoder/api_station.go b/decoder/api_station.go index c581d9e6..88d9c1ee 100644 --- a/decoder/api_station.go +++ b/decoder/api_station.go @@ -27,7 +27,7 @@ type ApiStationResult struct { StationedPokemon null.String `json:"stationed_pokemon"` } -func buildStationResult(station *Station) ApiStationResult { +func BuildStationResult(station *Station) ApiStationResult { return ApiStationResult{ Id: station.Id, Lat: station.Lat, diff --git a/decoder/station_state.go b/decoder/station_state.go index 194d0b93..2199bee0 100644 --- a/decoder/station_state.go +++ b/decoder/station_state.go @@ -66,10 +66,10 @@ func peekStationRecord(stationId string, caller string) (*Station, func(), error return nil, nil, nil } -// getStationRecordReadOnly acquires lock but does NOT take snapshot. +// GetStationRecordReadOnly acquires lock but does NOT take snapshot. // Use for read-only checks. Will cause a backing database lookup. // Caller MUST call returned unlock function if non-nil. -func getStationRecordReadOnly(ctx context.Context, db db.DbDetails, stationId string, caller string) (*Station, func(), error) { +func GetStationRecordReadOnly(ctx context.Context, db db.DbDetails, stationId string, caller string) (*Station, func(), error) { // Check cache first if item := stationCache.Get(stationId); item != nil { station := item.Value() @@ -104,7 +104,7 @@ func getStationRecordReadOnly(ctx context.Context, db db.DbDetails, stationId st // getStationRecordForUpdate acquires lock AND takes snapshot for webhook comparison. // Caller MUST call returned unlock function if non-nil. func getStationRecordForUpdate(ctx context.Context, db db.DbDetails, stationId string, caller string) (*Station, func(), error) { - station, unlock, err := getStationRecordReadOnly(ctx, db, stationId, caller) + station, unlock, err := GetStationRecordReadOnly(ctx, db, stationId, caller) if err != nil || station == nil { return nil, nil, err } diff --git a/main.go b/main.go index 8e1f3e42..407da4ca 100644 --- a/main.go +++ b/main.go @@ -335,6 +335,7 @@ func main() { apiGroup.POST("/gym/search", SearchGyms) apiGroup.POST("/gym/scan", GymScan) apiGroup.POST("/pokestop/scan", PokestopScan) + apiGroup.POST("/station/query", GetStations) apiGroup.POST("/station/scan", StationScan) apiGroup.POST("/fort/scan", FortScan) apiGroup.POST("/reload-geojson", ReloadGeojson) diff --git a/routes.go b/routes.go index fd3c9a10..95c6ceea 100644 --- a/routes.go +++ b/routes.go @@ -607,6 +607,81 @@ func GetGyms(c *gin.Context) { c.JSON(http.StatusOK, out) } +// POST /api/station/query +// +// { "ids": ["stationid1", "stationid2", ...] } +func GetStations(c *gin.Context) { + type idsPayload struct { + IDs []string `json:"ids"` + } + + var payload idsPayload + if err := c.ShouldBindJSON(&payload); err != nil { + var arr []string + if err2 := c.ShouldBindJSON(&arr); err2 != nil { + log.Warnf("invalid JSON: %v / %v", err, err2) + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid JSON body; expected {\"ids\":[...] }"}) + return + } + payload.IDs = arr + } + + seen := make(map[string]struct{}, len(payload.IDs)) + ids := make([]string, 0, len(payload.IDs)) + for _, id := range payload.IDs { + if id == "" { + continue + } + if _, ok := seen[id]; ok { + continue + } + seen[id] = struct{}{} + ids = append(ids, id) + } + + const maxIDs = 500 + if len(ids) > maxIDs { + c.JSON(http.StatusRequestEntityTooLarge, gin.H{ + "error": "too many ids", + "max_supported": maxIDs, + }) + return + } + + if len(ids) == 0 { + c.JSON(http.StatusOK, []decoder.ApiStationResult{}) + return + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + out := make([]decoder.ApiStationResult, 0, len(ids)) + for _, id := range ids { + s, unlock, err := decoder.GetStationRecordReadOnly(ctx, dbDetails, id, "API.GetStations") + if err != nil { + if unlock != nil { + unlock() + } + log.Warnf("error retrieving station %s: %v", id, err) + c.Status(http.StatusInternalServerError) + return + } + if s != nil { + out = append(out, decoder.BuildStationResult(s)) + } + if unlock != nil { + unlock() + } + if ctx.Err() != nil { + c.Status(http.StatusInternalServerError) + return + } + } + + c.JSON(http.StatusOK, out) +} + // POST /api/gym/search // Multiple filter combinations with AND logic //