diff --git a/gtfsdb/helpers.go b/gtfsdb/helpers.go index 9f79d509..3d08d9d1 100644 --- a/gtfsdb/helpers.go +++ b/gtfsdb/helpers.go @@ -66,6 +66,29 @@ func performDatabaseMigration(ctx context.Context, db *sql.DB) error { return nil } +// withTransaction executes the given function within a transaction. +// If tx is non-nil, it uses the provided transaction and does not commit. +// If tx is nil, it starts a new transaction, ensures rollback on error, and commits on success. +func (c *Client) withTransaction(ctx context.Context, tx *sql.Tx, label string, fn func(*sql.Tx) error) error { + if tx != nil { + return fn(tx) + } + + newTx, err := c.DB.BeginTx(ctx, nil) + if err != nil { + return err + } + + logger := slog.Default().With(slog.String("component", "bulk_insert")) + defer logging.SafeRollbackWithLogging(newTx, logger, label) + + if err := fn(newTx); err != nil { + return err + } + + return newTx.Commit() +} + func (c *Client) processAndStoreGTFSDataWithSource(b []byte, source string) error { logger := slog.Default().With(slog.String("component", "gtfs_importer")) @@ -553,35 +576,23 @@ func pickFirstAvailable(a, b string) string { // bulkInsertStops inserts stops. If tx is non-nil it uses that transaction and does not commit; if nil it starts its own and commits. func (c *Client) bulkInsertStops(ctx context.Context, stops []CreateStopParams, tx *sql.Tx) error { - db := c.DB queries := c.Queries logger := slog.Default().With(slog.String("component", "bulk_insert")) logging.LogOperation(logger, "inserting_stops", slog.Int("count", len(stops))) - useTx := tx - if useTx == nil { - var err error - useTx, err = db.Begin() - if err != nil { - return err - } - defer logging.SafeRollbackWithLogging(useTx, logger, "bulk_insert_stops") - } - - qtx := queries.WithTx(useTx) - for _, params := range stops { - _, err := qtx.CreateStop(ctx, params) - if err != nil { - return err - } - } - - if tx == nil { - if err := useTx.Commit(); err != nil { - return err + if err := c.withTransaction(ctx, tx, "bulk_insert_stops", func(tx *sql.Tx) error { + qtx := queries.WithTx(tx) + for _, params := range stops { + _, err := qtx.CreateStop(ctx, params) + if err != nil { + return err + } } + return nil + }); err != nil { + return err } logging.LogOperation(logger, "stops_inserted", @@ -592,35 +603,23 @@ func (c *Client) bulkInsertStops(ctx context.Context, stops []CreateStopParams, // bulkInsertTrips inserts trips. If tx is non-nil it uses that transaction and does not commit; if nil it starts its own and commits. func (c *Client) bulkInsertTrips(ctx context.Context, trips []CreateTripParams, tx *sql.Tx) error { - db := c.DB queries := c.Queries logger := slog.Default().With(slog.String("component", "bulk_insert")) logging.LogOperation(logger, "inserting_trips", slog.Int("count", len(trips))) - useTx := tx - if useTx == nil { - var err error - useTx, err = db.Begin() - if err != nil { - return err - } - defer logging.SafeRollbackWithLogging(useTx, logger, "bulk_insert_trips") - } - - qtx := queries.WithTx(useTx) - for _, params := range trips { - _, err := qtx.CreateTrip(ctx, params) - if err != nil { - return err - } - } - - if tx == nil { - if err := useTx.Commit(); err != nil { - return err + if err := c.withTransaction(ctx, tx, "bulk_insert_trips", func(tx *sql.Tx) error { + qtx := queries.WithTx(tx) + for _, params := range trips { + _, err := qtx.CreateTrip(ctx, params) + if err != nil { + return err + } } + return nil + }); err != nil { + return err } logging.LogOperation(logger, "trips_inserted", @@ -639,7 +638,6 @@ type preparedStopTimeBatch struct { // bulkInsertStopTimes inserts stop times. If tx is non-nil it uses that transaction and does not commit; if nil it starts its own and commits. func (c *Client) bulkInsertStopTimes(ctx context.Context, stopTimes []CreateStopTimeParams, tx *sql.Tx) error { - db := c.DB logger := slog.Default().With(slog.String("component", "bulk_insert")) logging.LogOperation(logger, "inserting_stop_times", @@ -656,17 +654,6 @@ func (c *Client) bulkInsertStopTimes(ctx context.Context, stopTimes []CreateStop // Calculate number of batches numBatches := (len(stopTimes) + batchSize - 1) / batchSize - // Use provided transaction or start our own - useTx := tx - if useTx == nil { - var err error - useTx, err = db.Begin() - if err != nil { - return err - } - defer logging.SafeRollbackWithLogging(useTx, logger, "bulk_insert_stop_times") - } - // Create channels for pipeline numWorkers := runtime.NumCPU() batchChan := make(chan int, numWorkers) @@ -774,31 +761,30 @@ func (c *Client) bulkInsertStopTimes(ctx context.Context, stopTimes []CreateStop slog.Int("total", len(stopTimes)), ) - // Execute sorted batches - for _, batch := range preparedBatches { - // Check context before executing - if ctx.Err() != nil { - return ctx.Err() - } - - // Execute the batch insert - _, err := useTx.ExecContext(ctx, batch.query, batch.args...) - if err != nil { - return fmt.Errorf("failed to insert stop_times batch: %w", err) - } + if err := c.withTransaction(ctx, tx, "bulk_insert_stop_times", func(tx *sql.Tx) error { + // Execute sorted batches + for _, batch := range preparedBatches { + // Check context before executing + if ctx.Err() != nil { + return ctx.Err() + } - // Log progress every 100k records - if (batch.end)%100000 == 0 || batch.end == len(stopTimes) { - logging.LogOperation(logger, "stop_times_progress", - slog.Int("inserted", batch.end), - slog.Int("total", len(stopTimes))) - } - } + // Execute the batch insert + _, err := tx.ExecContext(ctx, batch.query, batch.args...) + if err != nil { + return fmt.Errorf("failed to insert stop_times batch: %w", err) + } - if tx == nil { - if err := useTx.Commit(); err != nil { - return err + // Log progress every 100k records + if (batch.end)%100000 == 0 || batch.end == len(stopTimes) { + logging.LogOperation(logger, "stop_times_progress", + slog.Int("inserted", batch.end), + slog.Int("total", len(stopTimes))) + } } + return nil + }); err != nil { + return err } logging.LogOperation(logger, "stop_times_inserted", @@ -817,7 +803,6 @@ type preparedShapeBatch struct { // bulkInsertShapes inserts shapes. If tx is non-nil it uses that transaction and does not commit; if nil it starts its own and commits. func (c *Client) bulkInsertShapes(ctx context.Context, shapes []CreateShapeParams, tx *sql.Tx) error { - db := c.DB logger := slog.Default().With(slog.String("component", "bulk_insert")) logging.LogOperation(logger, "inserting_shapes", @@ -942,40 +927,29 @@ func (c *Client) bulkInsertShapes(ctx context.Context, shapes []CreateShapeParam }) // ===== PHASE 3: SEQUENTIAL DATABASE EXECUTION ===== - useTx := tx - if useTx == nil { - var err error - useTx, err = db.Begin() - if err != nil { - return err - } - defer logging.SafeRollbackWithLogging(useTx, logger, "bulk_insert_shapes") - } - - for _, batch := range preparedBatches { - // Check context before executing - if ctx.Err() != nil { - return ctx.Err() - } - - // Execute the batch insert - _, err := useTx.ExecContext(ctx, batch.query, batch.args...) - if err != nil { - return fmt.Errorf("failed to insert shapes batch: %w", err) - } + if err := c.withTransaction(ctx, tx, "bulk_insert_shapes", func(tx *sql.Tx) error { + for _, batch := range preparedBatches { + // Check context before executing + if ctx.Err() != nil { + return ctx.Err() + } - // Log progress every 50k records - if (batch.end)%50000 == 0 || batch.end == len(shapes) { - logging.LogOperation(logger, "shapes_progress", - slog.Int("inserted", batch.end), - slog.Int("total", len(shapes))) - } - } + // Execute the batch insert + _, err := tx.ExecContext(ctx, batch.query, batch.args...) + if err != nil { + return fmt.Errorf("failed to insert shapes batch: %w", err) + } - if tx == nil { - if err := useTx.Commit(); err != nil { - return err + // Log progress every 50k records + if (batch.end)%50000 == 0 || batch.end == len(shapes) { + logging.LogOperation(logger, "shapes_progress", + slog.Int("inserted", batch.end), + slog.Int("total", len(shapes))) + } } + return nil + }); err != nil { + return err } logging.LogOperation(logger, "shapes_inserted", @@ -986,35 +960,23 @@ func (c *Client) bulkInsertShapes(ctx context.Context, shapes []CreateShapeParam // bulkInsertFrequencies inserts frequencies. If tx is non-nil it uses that transaction and does not commit; if nil it starts its own and commits. func (c *Client) bulkInsertFrequencies(ctx context.Context, frequencies []CreateFrequencyParams, tx *sql.Tx) error { - db := c.DB queries := c.Queries logger := slog.Default().With(slog.String("component", "bulk_insert")) logging.LogOperation(logger, "inserting_frequencies", slog.Int("count", len(frequencies))) - useTx := tx - if useTx == nil { - var err error - useTx, err = db.Begin() - if err != nil { - return err - } - defer logging.SafeRollbackWithLogging(useTx, logger, "bulk_insert_frequencies") - } - - qtx := queries.WithTx(useTx) - for _, params := range frequencies { - err := qtx.CreateFrequency(ctx, params) - if err != nil { - return err - } - } - - if tx == nil { - if err := useTx.Commit(); err != nil { - return err + if err := c.withTransaction(ctx, tx, "bulk_insert_frequencies", func(tx *sql.Tx) error { + qtx := queries.WithTx(tx) + for _, params := range frequencies { + err := qtx.CreateFrequency(ctx, params) + if err != nil { + return err + } } + return nil + }); err != nil { + return err } logging.LogOperation(logger, "frequencies_inserted", @@ -1025,30 +987,20 @@ func (c *Client) bulkInsertFrequencies(ctx context.Context, frequencies []Create // bulkInsertCalendarDates inserts calendar dates. If tx is non-nil it uses that transaction and does not commit; if nil it starts its own and commits. func (c *Client) bulkInsertCalendarDates(ctx context.Context, calendarDates []CreateCalendarDateParams, tx *sql.Tx) error { - db := c.DB queries := c.Queries - logger := slog.Default().With(slog.String("component", "bulk_insert")) - - useTx := tx - if useTx == nil { - var err error - useTx, err = db.Begin() - if err != nil { - return err + if err := c.withTransaction(ctx, tx, "bulk_insert_calendar_dates", func(tx *sql.Tx) error { + qtx := queries.WithTx(tx) + for _, params := range calendarDates { + _, err := qtx.CreateCalendarDate(ctx, params) + if err != nil { + return err + } } - defer logging.SafeRollbackWithLogging(useTx, logger, "bulk_insert_calendar_dates") + return nil + }); err != nil { + return err } - qtx := queries.WithTx(useTx) - for _, params := range calendarDates { - _, err := qtx.CreateCalendarDate(ctx, params) - if err != nil { - return err - } - } - if tx == nil { - return useTx.Commit() - } return nil } @@ -1172,61 +1124,52 @@ func (c *Client) buildBlockTripIndex(ctx context.Context, staticData *gtfs.Stati slog.Int("total_trips", len(tripMap)), slog.Int("unique_indices", len(indexGroups))) - // buildBlockTripIndex uses tx if provided; otherwise starts its own transaction and commits. - useTx := tx - if useTx == nil { - var err error - useTx, err = c.DB.BeginTx(ctx, nil) - if err != nil { - return err - } - defer logging.SafeRollbackWithLogging(useTx, logger, "build_block_trip_index") - } - - qtx := c.Queries.WithTx(useTx) + q := c.Queries createdAt := time.Now().Unix() - for key, trips := range indexGroups { - // Create unique index key (service ID + layover stop) - indexKey := fmt.Sprintf("%s|%s", key.serviceIDs, key.stopSequenceKey) + if err := c.withTransaction(ctx, tx, "build_block_trip_index", func(tx *sql.Tx) error { + qtx := q.WithTx(tx) - indexID, err := qtx.CreateBlockTripIndex(ctx, CreateBlockTripIndexParams{ - IndexKey: indexKey, - ServiceIds: key.serviceIDs, - StopSequenceKey: key.stopSequenceKey, - CreatedAt: createdAt, - }) - if err != nil { - return fmt.Errorf("failed to create block trip index: %w", err) - } + for key, trips := range indexGroups { + // Create unique index key (service ID + layover stop) + indexKey := fmt.Sprintf("%s|%s", key.serviceIDs, key.stopSequenceKey) - // Sort trips within the group by block_id and then trip_id for deterministic ordering - sort.Slice(trips, func(i, j int) bool { - if trips[i].blockID != trips[j].blockID { - return trips[i].blockID < trips[j].blockID - } - return trips[i].tripID < trips[j].tripID - }) - - // Insert block_trip_entry records for each trip in this index - for sequence, trip := range trips { - err = qtx.CreateBlockTripEntry(ctx, CreateBlockTripEntryParams{ - BlockTripIndexID: indexID, - TripID: trip.tripID, - BlockID: toNullString(trip.blockID), - ServiceID: trip.serviceID, - BlockTripSequence: int64(sequence), + indexID, err := qtx.CreateBlockTripIndex(ctx, CreateBlockTripIndexParams{ + IndexKey: indexKey, + ServiceIds: key.serviceIDs, + StopSequenceKey: key.stopSequenceKey, + CreatedAt: createdAt, }) if err != nil { - return fmt.Errorf("failed to create block trip entry: %w", err) + return fmt.Errorf("failed to create block trip index: %w", err) } - } - } - if tx == nil { - if err := useTx.Commit(); err != nil { - return err + // Sort trips within the group by block_id and then trip_id for deterministic ordering + sort.Slice(trips, func(i, j int) bool { + if trips[i].blockID != trips[j].blockID { + return trips[i].blockID < trips[j].blockID + } + return trips[i].tripID < trips[j].tripID + }) + + // Insert block_trip_entry records for each trip in this index + for sequence, trip := range trips { + err = qtx.CreateBlockTripEntry(ctx, CreateBlockTripEntryParams{ + BlockTripIndexID: indexID, + TripID: trip.tripID, + BlockID: toNullString(trip.blockID), + ServiceID: trip.serviceID, + BlockTripSequence: int64(sequence), + }) + if err != nil { + return fmt.Errorf("failed to create block trip entry: %w", err) + } + } } + + return nil + }); err != nil { + return err } totalEntries := 0 diff --git a/internal/models/schedule_for_route.go b/internal/models/schedule_for_route.go index 244015b6..6078dadc 100644 --- a/internal/models/schedule_for_route.go +++ b/internal/models/schedule_for_route.go @@ -28,5 +28,6 @@ type ScheduleForRouteEntry struct { RouteID string `json:"routeId"` ScheduleDate int64 `json:"scheduleDate"` ServiceIDs []string `json:"serviceIds"` + Stops []string `json:"stops"` StopTripGroupings []StopTripGrouping `json:"stopTripGroupings"` } diff --git a/internal/restapi/schedule_for_route_handler.go b/internal/restapi/schedule_for_route_handler.go index 25565907..6314b6a8 100644 --- a/internal/restapi/schedule_for_route_handler.go +++ b/internal/restapi/schedule_for_route_handler.go @@ -81,6 +81,7 @@ func (api *RestAPI) scheduleForRouteHandler(w http.ResponseWriter, r *http.Reque RouteID: utils.FormCombinedID(agencyID, routeID), ScheduleDate: scheduleDate, ServiceIDs: []string{}, + Stops: []string{}, StopTripGroupings: []models.StopTripGrouping{}, } api.sendResponse(w, r, models.NewEntryResponse(entry, *models.NewEmptyReferences(), api.Clock)) @@ -108,6 +109,7 @@ func (api *RestAPI) scheduleForRouteHandler(w http.ResponseWriter, r *http.Reque RouteID: utils.FormCombinedID(agencyID, routeID), ScheduleDate: scheduleDate, ServiceIDs: combinedServiceIDs, + Stops: []string{}, StopTripGroupings: []models.StopTripGrouping{}, } api.sendResponse(w, r, models.NewEntryResponse(entry, *models.NewEmptyReferences(), api.Clock)) @@ -284,6 +286,8 @@ func (api *RestAPI) scheduleForRouteHandler(w http.ResponseWriter, r *http.Reque uniqueStopIDs = append(uniqueStopIDs, sid) } + combinedStopIDs := make([]string, 0, len(uniqueStopIDs)) + if len(uniqueStopIDs) > 0 { modelStops, _, err := BuildStopReferencesAndRouteIDsForStops(api, ctx, agencyID, uniqueStopIDs) if err != nil { @@ -291,6 +295,10 @@ func (api *RestAPI) scheduleForRouteHandler(w http.ResponseWriter, r *http.Reque return } references.Stops = append(references.Stops, modelStops...) + + for _, sid := range uniqueStopIDs { + combinedStopIDs = append(combinedStopIDs, utils.FormCombinedID(agencyID, sid)) + } } for _, sref := range stopTimesRefs { @@ -301,6 +309,7 @@ func (api *RestAPI) scheduleForRouteHandler(w http.ResponseWriter, r *http.Reque RouteID: utils.FormCombinedID(agencyID, routeID), ScheduleDate: scheduleDate, ServiceIDs: combinedServiceIDs, + Stops: combinedStopIDs, StopTripGroupings: stopTripGroupings, } api.sendResponse(w, r, models.NewEntryResponse(entry, *references, api.Clock))