Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 27 additions & 21 deletions internal/gtfs/advanced_direction_calculator.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"sync"
"sync/atomic"

"golang.org/x/sync/singleflight"
"maglev.onebusaway.org/gtfsdb"
"maglev.onebusaway.org/internal/utils"
)
Expand All @@ -29,7 +30,13 @@ type AdvancedDirectionCalculator struct {
shapeCache map[string][]gtfsdb.GetShapePointsWithDistanceRow // Cache of all shape data for bulk operations
initialized atomic.Bool // Tracks whether concurrent operations have started
cacheMutex sync.RWMutex // Protects map access
directionResults sync.Map // Cached direction results (stopID -> string), includes negative cache

// directionResults caches computed stop directions.
// Lifecycle note: This map grows indefinitely for the lifetime of the application.
// Unbounded growth is acceptable here because it is strictly bounded by the finite
// number of valid real-world stops, and computed directions remain stable across GTFS reloads.
directionResults sync.Map // Cached direction results (stopID -> string), includes negative cache
requestGroup singleflight.Group // Prevents duplicate concurrent computations for the same stop
}

// NewAdvancedDirectionCalculator creates a new advanced direction calculator
Expand All @@ -54,8 +61,9 @@ func (adc *AdvancedDirectionCalculator) SetStandardDeviationThreshold(threshold
return nil
}

// SetShapeCache sets a pre-loaded cache of shape data to avoid database queries during bulk operations.
// This significantly improves performance when calculating directions for many stops.
// SetShapeCache is retained exclusively for use by the DirectionPrecomputer during startup.
// It sets a pre-loaded cache of shape data to avoid thousands of database queries during
// the precomputation phase, significantly improving startup performance.
// IMPORTANT: This must be called before any concurrent operations begin.
// Returns an error if called after CalculateStopDirection has been invoked.
func (adc *AdvancedDirectionCalculator) SetShapeCache(cache map[string][]gtfsdb.GetShapePointsWithDistanceRow) error {
Expand All @@ -69,20 +77,6 @@ func (adc *AdvancedDirectionCalculator) SetShapeCache(cache map[string][]gtfsdb.
return nil
}

// SetContextCache injects the bulk-loaded context data.
// IMPORTANT: This must be called before any concurrent calculation operations begin.
// Returns an error if called after CalculateStopDirection has been invoked.
func (adc *AdvancedDirectionCalculator) SetContextCache(cache map[string][]gtfsdb.GetStopsWithShapeContextRow) error {
adc.cacheMutex.Lock()
defer adc.cacheMutex.Unlock()

if adc.initialized.Load() {
return errors.New("SetContextCache called after concurrent operations have started")
}
adc.contextCache = cache
return nil
}

// CalculateStopDirection computes the direction for a stop using the Java algorithm
func (adc *AdvancedDirectionCalculator) CalculateStopDirection(ctx context.Context, stopID string, gtfsDirection ...sql.NullString) string {
if len(gtfsDirection) > 0 && gtfsDirection[0].Valid && gtfsDirection[0].String != "" {
Expand All @@ -99,12 +93,24 @@ func (adc *AdvancedDirectionCalculator) CalculateStopDirection(ctx context.Conte
// Mark as initialized for concurrency safety
adc.initialized.Store(true)

result := adc.computeFromShapes(ctx, stopID)
// Fall back to computing from shapes, protected by singleflight
// This ensures concurrent requests for the SAME stopID don't hit the DB multiple times.
v, _, _ := adc.requestGroup.Do(stopID, func() (interface{}, error) {
// Double-check cache inside the singleflight in case another goroutine just finished it
if cached, ok := adc.directionResults.Load(stopID); ok {
return cached.(string), nil
}

// Actually compute it (Hits the DB)
computedDir := adc.computeFromShapes(ctx, stopID)

// Store in sync.Map for all future requests
adc.directionResults.Store(stopID, computedDir)

// Cache the result (even empty strings) to avoid recomputation
adc.directionResults.Store(stopID, result)
return computedDir, nil
})

return result
return v.(string)
}

// translateGtfsDirection converts GTFS direction field to compass direction
Expand Down
95 changes: 5 additions & 90 deletions internal/gtfs/advanced_direction_calculator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"maglev.onebusaway.org/gtfsdb"
"maglev.onebusaway.org/internal/models"
)

Expand Down Expand Up @@ -105,16 +104,14 @@ func TestTranslateGtfsDirection(t *testing.T) {
}

func TestCalculateStopDirectionResultCache(t *testing.T) {
calc := &AdvancedDirectionCalculator{}
// Set an empty context cache so computeFromShapes doesn't try to hit the DB
_ = calc.SetContextCache(make(map[string][]gtfsdb.GetStopsWithShapeContextRow))
_, calc := getSharedTestComponents(t)

// First call: precomputed direction "NE" from DB should be recognized
result := calc.CalculateStopDirection(context.Background(), "stop-1", sql.NullString{String: "NE", Valid: true})
assert.Equal(t, "NE", result, "should recognize compass abbreviation NE from precomputed direction")

// Verify that a stop with no GTFS direction falls through to computeFromShapes,
// gets an empty result (no data in cache), and caches the empty result.
// gets an empty result (no data in cache for nonexistent stop), and caches the empty result.
result = calc.CalculateStopDirection(context.Background(), "nonexistent-stop", sql.NullString{Valid: false})
assert.Equal(t, "", result, "should return empty for stop with no direction data")

Expand Down Expand Up @@ -274,7 +271,7 @@ func TestStandardDeviationThreshold(t *testing.T) {

func TestCalculateStopDirection_WithShapeData(t *testing.T) {
ctx := context.Background()
// Optimization: Reuse shared DB and Cache
// Optimization: Reuse shared DB
_, calc := getSharedTestComponents(t)

// Test with a real stop from RABA data
Expand All @@ -285,7 +282,7 @@ func TestCalculateStopDirection_WithShapeData(t *testing.T) {

func TestComputeFromShapes_NoShapeData(t *testing.T) {
ctx := context.Background()
// Optimization: Reuse shared DB and Cache
// Optimization: Reuse shared DB
_, calc := getSharedTestComponents(t)

// Test with a non-existent stop
Expand All @@ -295,7 +292,7 @@ func TestComputeFromShapes_NoShapeData(t *testing.T) {

func TestComputeFromShapes_SingleOrientation(t *testing.T) {
ctx := context.Background()
// Optimization: Reuse shared DB and Cache
// Optimization: Reuse shared DB
_, calc := getSharedTestComponents(t)

// Test with actual stop data - single orientation path will be taken if only one trip
Expand Down Expand Up @@ -442,43 +439,6 @@ func TestTranslateGtfsDirection_NumericEdgeCases(t *testing.T) {
}
}

func TestSetContextCache_HappyPath(t *testing.T) {
// Create a bare instance (no queries needed for this test)
adc := &AdvancedDirectionCalculator{}

// Create dummy cache data
cache := make(map[string][]gtfsdb.GetStopsWithShapeContextRow)
cache["stop1"] = []gtfsdb.GetStopsWithShapeContextRow{
{
ID: "stop1",
Lat: 40.7128,
Lon: -74.0060,
},
}

// Set the cache
err := adc.SetContextCache(cache)
assert.NoError(t, err)

// Verify it was set correctly (accessing private field)
assert.Equal(t, 1, len(adc.contextCache))
assert.Equal(t, "stop1", adc.contextCache["stop1"][0].ID)
}

func TestSetContextCache_ReturnsErrorAfterInit(t *testing.T) {
// Create the instance
adc := &AdvancedDirectionCalculator{}

// Simulate that concurrent operations have already started
// We manually toggle the atomic boolean to "true"
adc.initialized.Store(true)

// This call MUST return an error now
err := adc.SetContextCache(make(map[string][]gtfsdb.GetStopsWithShapeContextRow))
assert.Error(t, err)
assert.Equal(t, "SetContextCache called after concurrent operations have started", err.Error())
}

func TestCalculateStopDirection_VariadicSignature(t *testing.T) {
ctx := context.Background()
_, calc := getSharedTestComponents(t)
Expand All @@ -495,51 +455,6 @@ func TestCalculateStopDirection_VariadicSignature(t *testing.T) {
assert.Equal(t, "", dirOmitted, "Should fall back gracefully when argument is omitted")
}

func TestSetContextCache_ConcurrentAccess(t *testing.T) {
ctx := context.Background()
manager, _ := getSharedTestComponents(t)
// We use shared DB, but MUST use a fresh Calculator to test the race condition specifically on that instance.
calc := NewAdvancedDirectionCalculator(manager.GtfsDB.Queries)

// Create dummy cache
cache := make(map[string][]gtfsdb.GetStopsWithShapeContextRow)

// Channel to coordinate start
start := make(chan struct{})
done := make(chan struct{})
setErrCh := make(chan error, 1)

// Launch a "Reader" Goroutine (Simulating a request coming in)
go func() {
<-start // Wait for signal
// This triggers 'initialized.Store(true)' internally
calc.CalculateStopDirection(ctx, "7000")
close(done)
}()

// Launch a "Writer" (Simulating the bulk loader trying to set cache late)
// We want to verify this doesn't crash the program with a race condition,
// but correctly returns an error if it happens too late.
go func() {
<-start // Wait for signal
setErrCh <- calc.SetContextCache(cache)
}()

// Start the race
close(start)

// Wait for reader to finish
<-done

// Wait for writer to finish
err := <-setErrCh
if err != nil {
assert.Equal(t, "SetContextCache called after concurrent operations have started", err.Error())
}

// If got here without the test binary crashing/deadlocking, the atomic guards did their job.
}

// TestBulkQuery_GetStopsWithShapeContextByIDs verifies the bulk optimization
func TestBulkQuery_GetStopsWithShapeContextByIDs(t *testing.T) {
ctx := context.Background()
Expand Down
74 changes: 0 additions & 74 deletions internal/gtfs/global_cache.go

This file was deleted.

Loading
Loading