diff --git a/internal/restapi/location_params.go b/internal/restapi/location_params.go new file mode 100644 index 00000000..a781a53b --- /dev/null +++ b/internal/restapi/location_params.go @@ -0,0 +1,48 @@ +package restapi + +import ( + "net/http" + + "maglev.onebusaway.org/internal/utils" +) + +type LocationParams struct { + Lat float64 + Lon float64 + Radius float64 + LatSpan float64 + LonSpan float64 +} + +func (api *RestAPI) parseLocationParams(r *http.Request, fieldErrors map[string][]string) (*LocationParams, map[string][]string) { + queryParams := r.URL.Query() + + lat, fieldErrors := utils.ParseRequiredFloatParam(queryParams, "lat", fieldErrors) + lon, fieldErrors := utils.ParseRequiredFloatParam(queryParams, "lon", fieldErrors) + radius, fieldErrors := utils.ParseFloatParam(queryParams, "radius", fieldErrors) + latSpan, fieldErrors := utils.ParseFloatParam(queryParams, "latSpan", fieldErrors) + lonSpan, fieldErrors := utils.ParseFloatParam(queryParams, "lonSpan", fieldErrors) + + if len(fieldErrors) > 0 { + return nil, fieldErrors + } + + locationErrors := utils.ValidateLocationParams(lat, lon, radius, latSpan, lonSpan) + if len(locationErrors) > 0 { + if fieldErrors == nil { + fieldErrors = make(map[string][]string) + } + for k, v := range locationErrors { + fieldErrors[k] = append(fieldErrors[k], v...) + } + return nil, fieldErrors + } + + return &LocationParams{ + Lat: lat, + Lon: lon, + Radius: radius, + LatSpan: latSpan, + LonSpan: lonSpan, + }, fieldErrors +} diff --git a/internal/restapi/routes_for_location_handler.go b/internal/restapi/routes_for_location_handler.go index 492b587d..3edb704e 100644 --- a/internal/restapi/routes_for_location_handler.go +++ b/internal/restapi/routes_for_location_handler.go @@ -12,25 +12,15 @@ import ( func (api *RestAPI) routesForLocationHandler(w http.ResponseWriter, r *http.Request) { queryParams := r.URL.Query() - lat, fieldErrors := utils.ParseRequiredFloatParam(queryParams, "lat", nil) - lon, _ := utils.ParseRequiredFloatParam(queryParams, "lon", fieldErrors) - radius, _ := utils.ParseFloatParam(queryParams, "radius", fieldErrors) - latSpan, _ := utils.ParseFloatParam(queryParams, "latSpan", fieldErrors) - lonSpan, _ := utils.ParseFloatParam(queryParams, "lonSpan", fieldErrors) - maxCount, _ := utils.ParseMaxCount(queryParams, models.DefaultMaxCountForRoutes, fieldErrors) - query := queryParams.Get("query") + var fieldErrors map[string][]string + loc, fieldErrors := api.parseLocationParams(r, fieldErrors) + maxCount, fieldErrors := utils.ParseMaxCount(queryParams, models.DefaultMaxCountForRoutes, fieldErrors) if len(fieldErrors) > 0 { api.validationErrorResponse(w, r, fieldErrors) return } - - // Validate location parameters - locationErrors := utils.ValidateLocationParams(lat, lon, radius, latSpan, lonSpan) - if len(locationErrors) > 0 { - api.validationErrorResponse(w, r, locationErrors) - return - } + query := queryParams.Get("query") // Validate and sanitize query sanitizedQuery, err := utils.ValidateAndSanitizeQuery(query) @@ -42,6 +32,7 @@ func (api *RestAPI) routesForLocationHandler(w http.ResponseWriter, r *http.Requ return } query = strings.ToLower(sanitizedQuery) + radius := loc.Radius if radius == 0 { radius = models.DefaultSearchRadiusInMeters if query != "" { @@ -60,7 +51,7 @@ func (api *RestAPI) routesForLocationHandler(w http.ResponseWriter, r *http.Requ api.GtfsManager.RLock() defer api.GtfsManager.RUnlock() - stops := api.GtfsManager.GetStopsForLocation(ctx, lat, lon, radius, latSpan, lonSpan, query, maxCount, true, nil, time.Time{}) + stops := api.GtfsManager.GetStopsForLocation(ctx, loc.Lat, loc.Lon, radius, loc.LatSpan, loc.LonSpan, query, maxCount, true, nil, time.Time{}) var results = []models.Route{} routeIDs := map[string]bool{} @@ -77,7 +68,7 @@ func (api *RestAPI) routesForLocationHandler(w http.ResponseWriter, r *http.Requ agencies := utils.FilterAgencies(api.GtfsManager.GetAgencies(), agencyIDs) references := models.NewEmptyReferences() references.Agencies = agencies - response := models.NewListResponseWithRange(results, references, checkIfOutOfBounds(api, lat, lon, latSpan, lonSpan, radius), api.Clock, false) + response := models.NewListResponseWithRange(results, references, checkIfOutOfBounds(api, loc.Lat, loc.Lon, loc.LatSpan, loc.LonSpan, radius), api.Clock, false) api.sendResponse(w, r, response) return } @@ -141,7 +132,7 @@ func (api *RestAPI) routesForLocationHandler(w http.ResponseWriter, r *http.Requ references.Agencies = agencies references.Situations = situations - response := models.NewListResponseWithRange(results, references, checkIfOutOfBounds(api, lat, lon, latSpan, lonSpan, radius), api.Clock, isLimitExceeded) + response := models.NewListResponseWithRange(results, references, checkIfOutOfBounds(api, loc.Lat, loc.Lon, loc.LatSpan, loc.LonSpan, radius), api.Clock, isLimitExceeded) api.sendResponse(w, r, response) } diff --git a/internal/restapi/stops_for_location_handler.go b/internal/restapi/stops_for_location_handler.go index d3b087c5..6ce9541b 100644 --- a/internal/restapi/stops_for_location_handler.go +++ b/internal/restapi/stops_for_location_handler.go @@ -15,12 +15,9 @@ import ( func (api *RestAPI) stopsForLocationHandler(w http.ResponseWriter, r *http.Request) { queryParams := r.URL.Query() - lat, fieldErrors := utils.ParseRequiredFloatParam(queryParams, "lat", nil) - lon, _ := utils.ParseRequiredFloatParam(queryParams, "lon", fieldErrors) - radius, _ := utils.ParseFloatParam(queryParams, "radius", fieldErrors) - latSpan, _ := utils.ParseFloatParam(queryParams, "latSpan", fieldErrors) - lonSpan, _ := utils.ParseFloatParam(queryParams, "lonSpan", fieldErrors) - maxCount, _ := utils.ParseMaxCount(queryParams, models.DefaultMaxCountForStops, fieldErrors) + var fieldErrors map[string][]string + loc, fieldErrors := api.parseLocationParams(r, fieldErrors) + maxCount, fieldErrors := utils.ParseMaxCount(queryParams, models.DefaultMaxCountForStops, fieldErrors) query := queryParams.Get("query") var routeTypes []int @@ -74,12 +71,6 @@ func (api *RestAPI) stopsForLocationHandler(w http.ResponseWriter, r *http.Reque return } - locationErrors := utils.ValidateLocationParams(lat, lon, radius, latSpan, lonSpan) - if len(locationErrors) > 0 { - api.validationErrorResponse(w, r, locationErrors) - return - } - // Validate and sanitize query sanitizedQuery, err := utils.ValidateAndSanitizeQuery(query) if err != nil { @@ -102,7 +93,7 @@ func (api *RestAPI) stopsForLocationHandler(w http.ResponseWriter, r *http.Reque api.GtfsManager.RLock() defer api.GtfsManager.RUnlock() - stops := api.GtfsManager.GetStopsForLocation(ctx, lat, lon, radius, latSpan, lonSpan, query, maxCount, false, routeTypes, queryTime) + stops := api.GtfsManager.GetStopsForLocation(ctx, loc.Lat, loc.Lon, loc.Radius, loc.LatSpan, loc.LonSpan, query, maxCount, false, routeTypes, queryTime) // Referenced Java code: "here we sort by distance for possible truncation, but later it will be re-sorted by stopId" sort.SliceStable(stops, func(i, j int) bool { @@ -127,7 +118,7 @@ func (api *RestAPI) stopsForLocationHandler(w http.ResponseWriter, r *http.Reque references := models.NewEmptyReferences() references.Agencies = agencies references.Routes = routes - response := models.NewListResponseWithRange(results, references, checkIfOutOfBounds(api, lat, lon, latSpan, lonSpan, radius), api.Clock, false) + response := models.NewListResponseWithRange(results, references, checkIfOutOfBounds(api, loc.Lat, loc.Lon, loc.LatSpan, loc.LonSpan, loc.Radius), api.Clock, false) api.sendResponse(w, r, response) return } @@ -250,6 +241,6 @@ func (api *RestAPI) stopsForLocationHandler(w http.ResponseWriter, r *http.Reque references.Routes = routes references.Situations = situations - response := models.NewListResponseWithRange(results, references, checkIfOutOfBounds(api, lat, lon, latSpan, lonSpan, radius), api.Clock, isLimitExceeded) + response := models.NewListResponseWithRange(results, references, checkIfOutOfBounds(api, loc.Lat, loc.Lon, loc.LatSpan, loc.LonSpan, loc.Radius), api.Clock, isLimitExceeded) api.sendResponse(w, r, response) } diff --git a/internal/restapi/trips_for_location_handler.go b/internal/restapi/trips_for_location_handler.go index feaa16c3..69b63048 100644 --- a/internal/restapi/trips_for_location_handler.go +++ b/internal/restapi/trips_for_location_handler.go @@ -23,7 +23,7 @@ func (api *RestAPI) tripsForLocationHandler(w http.ResponseWriter, r *http.Reque defer api.GtfsManager.RUnlock() lat, lon, latSpan, lonSpan, includeTrip, includeSchedule, currentLocation, todayMidnight, serviceDate, fieldErrors, err := api.parseAndValidateRequest(r) - if fieldErrors != nil { + if len(fieldErrors) > 0 { api.validationErrorResponse(w, r, fieldErrors) return } @@ -129,14 +129,20 @@ func (api *RestAPI) parseAndValidateRequest(r *http.Request) ( todayMidnight time.Time, serviceDate time.Time, fieldErrors map[string][]string, - err error, + serverErr error, ) { + var loc *LocationParams + loc, fieldErrors = api.parseLocationParams(r, nil) + + if loc != nil { + lat = loc.Lat + lon = loc.Lon + latSpan = loc.LatSpan + lonSpan = loc.LonSpan + } + queryParams := r.URL.Query() - lat, fieldErrors = utils.ParseRequiredFloatParam(queryParams, "lat", nil) - lon, _ = utils.ParseRequiredFloatParam(queryParams, "lon", fieldErrors) - latSpan, _ = utils.ParseFloatParam(queryParams, "latSpan", fieldErrors) - lonSpan, _ = utils.ParseFloatParam(queryParams, "lonSpan", fieldErrors) includeTrip = queryParams.Get("includeTrip") == "true" includeSchedule = queryParams.Get("includeSchedule") == "true" @@ -146,9 +152,9 @@ func (api *RestAPI) parseAndValidateRequest(r *http.Request) ( } currentAgency := agencies[0] - currentLocation, err = time.LoadLocation(currentAgency.Timezone) - if err != nil { - return 0, 0, 0, 0, false, false, nil, time.Time{}, time.Time{}, nil, fmt.Errorf("invalid timezone for agency %q: %w", currentAgency.Id, err) + currentLocation, serverErr = time.LoadLocation(currentAgency.Timezone) + if serverErr != nil { + return 0, 0, 0, 0, false, false, nil, time.Time{}, time.Time{}, nil, fmt.Errorf("invalid timezone for agency %q: %w", currentAgency.Id, serverErr) } timeParam := queryParams.Get("time") @@ -156,26 +162,19 @@ func (api *RestAPI) parseAndValidateRequest(r *http.Request) ( todayMidnight = time.Date(currentTime.Year(), currentTime.Month(), currentTime.Day(), 0, 0, 0, 0, currentLocation) var timeFieldErrors map[string][]string - var success bool - _, serviceDate, timeFieldErrors, success = utils.ParseTimeParameter(timeParam, currentLocation) - for k, v := range timeFieldErrors { - fieldErrors[k] = append(fieldErrors[k], v...) - } - - ctx := r.Context() - if ctx.Err() != nil { - return 0, 0, 0, 0, false, false, nil, time.Time{}, time.Time{}, nil, ctx.Err() - } - - if !success { + _, serviceDate, timeFieldErrors, _ = utils.ParseTimeParameter(timeParam, currentLocation) + if len(timeFieldErrors) > 0 { if fieldErrors == nil { fieldErrors = make(map[string][]string) } + for k, v := range timeFieldErrors { + fieldErrors[k] = append(fieldErrors[k], v...) + } } - locationErrors := utils.ValidateLocationParams(lat, lon, 0, latSpan, lonSpan) - for k, v := range locationErrors { - fieldErrors[k] = append(fieldErrors[k], v...) + ctx := r.Context() + if ctx.Err() != nil { + return 0, 0, 0, 0, false, false, nil, time.Time{}, time.Time{}, nil, ctx.Err() } if len(fieldErrors) > 0 {