Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions internal/restapi/location_params.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package restapi

import (
"net/http"

"maglev.onebusaway.org/internal/utils"
)

type LocationParams struct {
Lat float64
Lon float64
Radius float64
LatSpan float64
LonSpan float64
}

func (api *RestAPI) parseLocationParams(r *http.Request, fieldErrors map[string][]string) (*LocationParams, map[string][]string) {
queryParams := r.URL.Query()

lat, fieldErrors := utils.ParseRequiredFloatParam(queryParams, "lat", fieldErrors)
lon, fieldErrors := utils.ParseRequiredFloatParam(queryParams, "lon", fieldErrors)
radius, fieldErrors := utils.ParseFloatParam(queryParams, "radius", fieldErrors)
latSpan, fieldErrors := utils.ParseFloatParam(queryParams, "latSpan", fieldErrors)
lonSpan, fieldErrors := utils.ParseFloatParam(queryParams, "lonSpan", fieldErrors)

if len(fieldErrors) > 0 {
return nil, fieldErrors
}

locationErrors := utils.ValidateLocationParams(lat, lon, radius, latSpan, lonSpan)
if len(locationErrors) > 0 {
if fieldErrors == nil {
fieldErrors = make(map[string][]string)
}
for k, v := range locationErrors {
fieldErrors[k] = append(fieldErrors[k], v...)
}
return nil, fieldErrors
}

return &LocationParams{
Lat: lat,
Lon: lon,
Radius: radius,
LatSpan: latSpan,
LonSpan: lonSpan,
}, fieldErrors
}
25 changes: 8 additions & 17 deletions internal/restapi/routes_for_location_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,15 @@ import (
func (api *RestAPI) routesForLocationHandler(w http.ResponseWriter, r *http.Request) {
queryParams := r.URL.Query()

lat, fieldErrors := utils.ParseRequiredFloatParam(queryParams, "lat", nil)
lon, _ := utils.ParseRequiredFloatParam(queryParams, "lon", fieldErrors)
radius, _ := utils.ParseFloatParam(queryParams, "radius", fieldErrors)
latSpan, _ := utils.ParseFloatParam(queryParams, "latSpan", fieldErrors)
lonSpan, _ := utils.ParseFloatParam(queryParams, "lonSpan", fieldErrors)
maxCount, _ := utils.ParseMaxCount(queryParams, models.DefaultMaxCountForRoutes, fieldErrors)
query := queryParams.Get("query")
var fieldErrors map[string][]string
loc, fieldErrors := api.parseLocationParams(r, fieldErrors)
maxCount, fieldErrors := utils.ParseMaxCount(queryParams, models.DefaultMaxCountForRoutes, fieldErrors)

if len(fieldErrors) > 0 {
api.validationErrorResponse(w, r, fieldErrors)
return
}

// Validate location parameters
locationErrors := utils.ValidateLocationParams(lat, lon, radius, latSpan, lonSpan)
if len(locationErrors) > 0 {
api.validationErrorResponse(w, r, locationErrors)
return
}
query := queryParams.Get("query")

// Validate and sanitize query
sanitizedQuery, err := utils.ValidateAndSanitizeQuery(query)
Expand All @@ -42,6 +32,7 @@ func (api *RestAPI) routesForLocationHandler(w http.ResponseWriter, r *http.Requ
return
}
query = strings.ToLower(sanitizedQuery)
radius := loc.Radius
if radius == 0 {
radius = models.DefaultSearchRadiusInMeters
if query != "" {
Expand All @@ -60,7 +51,7 @@ func (api *RestAPI) routesForLocationHandler(w http.ResponseWriter, r *http.Requ
api.GtfsManager.RLock()
defer api.GtfsManager.RUnlock()

stops := api.GtfsManager.GetStopsForLocation(ctx, lat, lon, radius, latSpan, lonSpan, query, maxCount, true, nil, time.Time{})
stops := api.GtfsManager.GetStopsForLocation(ctx, loc.Lat, loc.Lon, radius, loc.LatSpan, loc.LonSpan, query, maxCount, true, nil, time.Time{})

var results = []models.Route{}
routeIDs := map[string]bool{}
Expand All @@ -77,7 +68,7 @@ func (api *RestAPI) routesForLocationHandler(w http.ResponseWriter, r *http.Requ
agencies := utils.FilterAgencies(api.GtfsManager.GetAgencies(), agencyIDs)
references := models.NewEmptyReferences()
references.Agencies = agencies
response := models.NewListResponseWithRange(results, references, checkIfOutOfBounds(api, lat, lon, latSpan, lonSpan, radius), api.Clock, false)
response := models.NewListResponseWithRange(results, references, checkIfOutOfBounds(api, loc.Lat, loc.Lon, loc.LatSpan, loc.LonSpan, radius), api.Clock, false)
api.sendResponse(w, r, response)
return
}
Expand Down Expand Up @@ -141,7 +132,7 @@ func (api *RestAPI) routesForLocationHandler(w http.ResponseWriter, r *http.Requ
references.Agencies = agencies
references.Situations = situations

response := models.NewListResponseWithRange(results, references, checkIfOutOfBounds(api, lat, lon, latSpan, lonSpan, radius), api.Clock, isLimitExceeded)
response := models.NewListResponseWithRange(results, references, checkIfOutOfBounds(api, loc.Lat, loc.Lon, loc.LatSpan, loc.LonSpan, radius), api.Clock, isLimitExceeded)
api.sendResponse(w, r, response)
}

Expand Down
21 changes: 6 additions & 15 deletions internal/restapi/stops_for_location_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,9 @@ import (
func (api *RestAPI) stopsForLocationHandler(w http.ResponseWriter, r *http.Request) {
queryParams := r.URL.Query()

lat, fieldErrors := utils.ParseRequiredFloatParam(queryParams, "lat", nil)
lon, _ := utils.ParseRequiredFloatParam(queryParams, "lon", fieldErrors)
radius, _ := utils.ParseFloatParam(queryParams, "radius", fieldErrors)
latSpan, _ := utils.ParseFloatParam(queryParams, "latSpan", fieldErrors)
lonSpan, _ := utils.ParseFloatParam(queryParams, "lonSpan", fieldErrors)
maxCount, _ := utils.ParseMaxCount(queryParams, models.DefaultMaxCountForStops, fieldErrors)
var fieldErrors map[string][]string
loc, fieldErrors := api.parseLocationParams(r, fieldErrors)
maxCount, fieldErrors := utils.ParseMaxCount(queryParams, models.DefaultMaxCountForStops, fieldErrors)
query := queryParams.Get("query")

var routeTypes []int
Expand Down Expand Up @@ -74,12 +71,6 @@ func (api *RestAPI) stopsForLocationHandler(w http.ResponseWriter, r *http.Reque
return
}

locationErrors := utils.ValidateLocationParams(lat, lon, radius, latSpan, lonSpan)
if len(locationErrors) > 0 {
api.validationErrorResponse(w, r, locationErrors)
return
}

// Validate and sanitize query
sanitizedQuery, err := utils.ValidateAndSanitizeQuery(query)
if err != nil {
Expand All @@ -102,7 +93,7 @@ func (api *RestAPI) stopsForLocationHandler(w http.ResponseWriter, r *http.Reque
api.GtfsManager.RLock()
defer api.GtfsManager.RUnlock()

stops := api.GtfsManager.GetStopsForLocation(ctx, lat, lon, radius, latSpan, lonSpan, query, maxCount, false, routeTypes, queryTime)
stops := api.GtfsManager.GetStopsForLocation(ctx, loc.Lat, loc.Lon, loc.Radius, loc.LatSpan, loc.LonSpan, query, maxCount, false, routeTypes, queryTime)

// Referenced Java code: "here we sort by distance for possible truncation, but later it will be re-sorted by stopId"
sort.SliceStable(stops, func(i, j int) bool {
Expand All @@ -127,7 +118,7 @@ func (api *RestAPI) stopsForLocationHandler(w http.ResponseWriter, r *http.Reque
references := models.NewEmptyReferences()
references.Agencies = agencies
references.Routes = routes
response := models.NewListResponseWithRange(results, references, checkIfOutOfBounds(api, lat, lon, latSpan, lonSpan, radius), api.Clock, false)
response := models.NewListResponseWithRange(results, references, checkIfOutOfBounds(api, loc.Lat, loc.Lon, loc.LatSpan, loc.LonSpan, loc.Radius), api.Clock, false)
api.sendResponse(w, r, response)
return
}
Expand Down Expand Up @@ -250,6 +241,6 @@ func (api *RestAPI) stopsForLocationHandler(w http.ResponseWriter, r *http.Reque
references.Routes = routes
references.Situations = situations

response := models.NewListResponseWithRange(results, references, checkIfOutOfBounds(api, lat, lon, latSpan, lonSpan, radius), api.Clock, isLimitExceeded)
response := models.NewListResponseWithRange(results, references, checkIfOutOfBounds(api, loc.Lat, loc.Lon, loc.LatSpan, loc.LonSpan, loc.Radius), api.Clock, isLimitExceeded)
api.sendResponse(w, r, response)
}
47 changes: 23 additions & 24 deletions internal/restapi/trips_for_location_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func (api *RestAPI) tripsForLocationHandler(w http.ResponseWriter, r *http.Reque
defer api.GtfsManager.RUnlock()

lat, lon, latSpan, lonSpan, includeTrip, includeSchedule, currentLocation, todayMidnight, serviceDate, fieldErrors, err := api.parseAndValidateRequest(r)
if fieldErrors != nil {
if len(fieldErrors) > 0 {
api.validationErrorResponse(w, r, fieldErrors)
return
}
Expand Down Expand Up @@ -129,14 +129,20 @@ func (api *RestAPI) parseAndValidateRequest(r *http.Request) (
todayMidnight time.Time,
serviceDate time.Time,
fieldErrors map[string][]string,
err error,
serverErr error,
) {
var loc *LocationParams
loc, fieldErrors = api.parseLocationParams(r, nil)

if loc != nil {
lat = loc.Lat
lon = loc.Lon
latSpan = loc.LatSpan
lonSpan = loc.LonSpan
}

queryParams := r.URL.Query()

lat, fieldErrors = utils.ParseRequiredFloatParam(queryParams, "lat", nil)
lon, _ = utils.ParseRequiredFloatParam(queryParams, "lon", fieldErrors)
latSpan, _ = utils.ParseFloatParam(queryParams, "latSpan", fieldErrors)
lonSpan, _ = utils.ParseFloatParam(queryParams, "lonSpan", fieldErrors)
includeTrip = queryParams.Get("includeTrip") == "true"
includeSchedule = queryParams.Get("includeSchedule") == "true"

Expand All @@ -146,36 +152,29 @@ func (api *RestAPI) parseAndValidateRequest(r *http.Request) (
}

currentAgency := agencies[0]
currentLocation, err = time.LoadLocation(currentAgency.Timezone)
if err != nil {
return 0, 0, 0, 0, false, false, nil, time.Time{}, time.Time{}, nil, fmt.Errorf("invalid timezone for agency %q: %w", currentAgency.Id, err)
currentLocation, serverErr = time.LoadLocation(currentAgency.Timezone)
if serverErr != nil {
return 0, 0, 0, 0, false, false, nil, time.Time{}, time.Time{}, nil, fmt.Errorf("invalid timezone for agency %q: %w", currentAgency.Id, serverErr)
}

timeParam := queryParams.Get("time")
currentTime := api.Clock.Now().In(currentLocation)
todayMidnight = time.Date(currentTime.Year(), currentTime.Month(), currentTime.Day(), 0, 0, 0, 0, currentLocation)

var timeFieldErrors map[string][]string
var success bool
_, serviceDate, timeFieldErrors, success = utils.ParseTimeParameter(timeParam, currentLocation)
for k, v := range timeFieldErrors {
fieldErrors[k] = append(fieldErrors[k], v...)
}

ctx := r.Context()
if ctx.Err() != nil {
return 0, 0, 0, 0, false, false, nil, time.Time{}, time.Time{}, nil, ctx.Err()
}

if !success {
_, serviceDate, timeFieldErrors, _ = utils.ParseTimeParameter(timeParam, currentLocation)
if len(timeFieldErrors) > 0 {
if fieldErrors == nil {
fieldErrors = make(map[string][]string)
}
for k, v := range timeFieldErrors {
fieldErrors[k] = append(fieldErrors[k], v...)
}
}

locationErrors := utils.ValidateLocationParams(lat, lon, 0, latSpan, lonSpan)
for k, v := range locationErrors {
fieldErrors[k] = append(fieldErrors[k], v...)
ctx := r.Context()
if ctx.Err() != nil {
return 0, 0, 0, 0, false, false, nil, time.Time{}, time.Time{}, nil, ctx.Err()
}

if len(fieldErrors) > 0 {
Expand Down
Loading