diff --git a/mtglib/internal/doppel/conn.go b/mtglib/internal/doppel/conn.go index 7e8ed30e3..33ea88de9 100644 --- a/mtglib/internal/doppel/conn.go +++ b/mtglib/internal/doppel/conn.go @@ -16,48 +16,25 @@ type Conn struct { } type connPayload struct { - ctx context.Context - ctxCancel context.CancelCauseFunc - clock Clock - wg sync.WaitGroup - syncWriteLock sync.RWMutex - writeStream bytes.Buffer - writeCond *sync.Cond + ctx context.Context + ctxCancel context.CancelCauseFunc + clock Clock + wg sync.WaitGroup + writeStream bytes.Buffer + writtenCond sync.Cond + done bool } func (c Conn) Write(p []byte) (int, error) { - c.p.syncWriteLock.RLock() - defer c.p.syncWriteLock.RUnlock() - - c.p.writeCond.L.Lock() - c.p.writeStream.Write(p) - c.p.writeCond.L.Unlock() - - return len(p), context.Cause(c.p.ctx) -} - -func (c Conn) SyncWrite(p []byte) (int, error) { - c.p.syncWriteLock.Lock() - defer c.p.syncWriteLock.Unlock() - - c.p.writeCond.L.Lock() - // wait until buffer is exhausted - for c.p.writeStream.Len() != 0 && context.Cause(c.p.ctx) == nil { - c.p.writeCond.Wait() + if len(p) == 0 { + return 0, context.Cause(c.p.ctx) } - c.p.writeStream.Write(p) - c.p.writeCond.L.Unlock() - if err := context.Cause(c.p.ctx); err != nil { - return len(p), err - } + c.p.writtenCond.L.Lock() + c.p.writeStream.Write(p) + c.p.writtenCond.L.Unlock() - c.p.writeCond.L.Lock() - // wait until data will be sent - for c.p.writeStream.Len() != 0 && context.Cause(c.p.ctx) == nil { - c.p.writeCond.Wait() - } - c.p.writeCond.L.Unlock() + c.p.writtenCond.Signal() return len(p), context.Cause(c.p.ctx) } @@ -69,8 +46,6 @@ func (c Conn) Start() { } func (c Conn) start() { - defer c.p.writeCond.Broadcast() - buf := [tls.MaxRecordSize]byte{} for { @@ -80,11 +55,16 @@ func (c Conn) start() { case <-c.p.clock.tick: } - c.p.writeCond.L.Lock() - n, err := c.p.writeStream.Read(buf[:c.p.clock.stats.Size()]) - c.p.writeCond.L.Unlock() + size := c.p.clock.stats.Size() + + c.p.writtenCond.L.Lock() + for c.p.writeStream.Len() == 0 && !c.p.done { + c.p.writtenCond.Wait() + } + n, _ := c.p.writeStream.Read(buf[:size]) + c.p.writtenCond.L.Unlock() - if n == 0 || err != nil { + if n == 0 { continue } @@ -92,13 +72,17 @@ func (c Conn) start() { c.p.ctxCancel(err) return } - - c.p.writeCond.Signal() } } func (c Conn) Stop() { c.p.ctxCancel(nil) + + c.p.writtenCond.L.Lock() + c.p.done = true + c.p.writtenCond.L.Unlock() + c.p.writtenCond.Broadcast() + c.p.wg.Wait() } @@ -109,7 +93,9 @@ func NewConn(ctx context.Context, conn essentials.Conn, stats *Stats) Conn { p: &connPayload{ ctx: ctx, ctxCancel: cancel, - writeCond: sync.NewCond(&sync.Mutex{}), + writtenCond: sync.Cond{ + L: &sync.Mutex{}, + }, clock: Clock{ stats: stats, tick: make(chan struct{}), diff --git a/mtglib/internal/doppel/conn_test.go b/mtglib/internal/doppel/conn_test.go index eec1f6fb9..850146944 100644 --- a/mtglib/internal/doppel/conn_test.go +++ b/mtglib/internal/doppel/conn_test.go @@ -141,150 +141,51 @@ func (suite *ConnTestSuite) TestWriteReturnsErrorAfterStop() { suite.Error(err) } -func (suite *ConnTestSuite) TestStopOnUnderlyingWriteError() { - suite.connMock. - On("Write", mock.AnythingOfType("[]uint8")). - Return(0, errors.New("connection reset")). - Maybe() - - c := suite.makeConn() - - _, _ = c.Write([]byte("data")) - - suite.Eventually(func() bool { - _, err := c.Write([]byte{1}) - return err != nil - }, 2*time.Second, time.Millisecond) -} - -func (suite *ConnTestSuite) TestSyncWriteDataSent() { - suite.connMock. - On("Write", mock.AnythingOfType("[]uint8")). - Return(0, nil). - Maybe() - - c := suite.makeConn() - defer c.Stop() - - payload := []byte("sync hello") - n, err := c.SyncWrite(payload) - suite.NoError(err) - suite.Equal(len(payload), n) - - // SyncWrite returns only after data is flushed to the wire. - assembled := &bytes.Buffer{} - reader := bytes.NewReader(suite.connMock.Written()) - - for { - header := make([]byte, tls.SizeHeader) - if _, err := io.ReadFull(reader, header); err != nil { - break - } - - suite.Equal(byte(tls.TypeApplicationData), header[0]) - - length := binary.BigEndian.Uint16(header[tls.SizeRecordType+tls.SizeVersion:]) - rec := make([]byte, length) - _, err := io.ReadFull(reader, rec) - suite.NoError(err) - - assembled.Write(rec) - } - - suite.Equal(payload, assembled.Bytes()) -} - -func (suite *ConnTestSuite) TestSyncWriteDrainsBufferFirst() { - suite.connMock. - On("Write", mock.AnythingOfType("[]uint8")). - Return(0, nil). - Maybe() - - c := suite.makeConn() - defer c.Stop() - - // Buffer some data via async Write. - _, err := c.Write([]byte("first")) - suite.NoError(err) - - // SyncWrite must drain "first" before sending "second". - n, err := c.SyncWrite([]byte("second")) - suite.NoError(err) - suite.Equal(6, n) - - // All data should be on the wire now. - assembled := &bytes.Buffer{} - reader := bytes.NewReader(suite.connMock.Written()) - - for { - header := make([]byte, tls.SizeHeader) - if _, err := io.ReadFull(reader, header); err != nil { - break - } - - length := binary.BigEndian.Uint16(header[tls.SizeRecordType+tls.SizeVersion:]) - rec := make([]byte, length) - _, err := io.ReadFull(reader, rec) - suite.NoError(err) - - assembled.Write(rec) - } - - suite.Equal([]byte("firstsecond"), assembled.Bytes()) -} - -func (suite *ConnTestSuite) TestSyncWriteBlocksAsyncWrite() { +func (suite *ConnTestSuite) TestStopDoesNotDeadlockWhenStartIsWaiting() { suite.connMock. On("Write", mock.AnythingOfType("[]uint8")). Return(0, nil). Maybe() - c := suite.makeConn() - defer c.Stop() - - // Start SyncWrite — it holds exclusive lock. - syncDone := make(chan struct{}) - - go func() { - defer close(syncDone) - c.SyncWrite([]byte("exclusive")) //nolint: errcheck - }() - - // Give SyncWrite time to acquire the lock. - time.Sleep(10 * time.Millisecond) - - // Async Write should block until SyncWrite completes. - writeDone := make(chan struct{}) - - go func() { - defer close(writeDone) - c.Write([]byte("blocked")) //nolint: errcheck - }() - - // SyncWrite should finish first. - <-syncDone - - select { - case <-writeDone: - // Write completed after SyncWrite — correct. - case <-time.After(2 * time.Second): - suite.Fail("async Write did not unblock after SyncWrite completed") + for range 100 { + func() { + ctx, cancel := context.WithCancel(suite.ctx) + defer cancel() + + c := NewConn(ctx, suite.connMock, &Stats{ + k: 2.0, + lambda: 0.01, + }) + + done := make(chan struct{}) + go func() { + defer close(done) + c.Stop() + }() + + select { + case <-done: + case <-time.After(2 * time.Second): + suite.Fail("Stop() deadlocked: start() likely stuck in writtenCond.Wait()") + } + }() } } -func (suite *ConnTestSuite) TestSyncWriteReturnsErrorAfterStop() { +func (suite *ConnTestSuite) TestStopOnUnderlyingWriteError() { suite.connMock. On("Write", mock.AnythingOfType("[]uint8")). - Return(0, nil). + Return(0, errors.New("connection reset")). Maybe() c := suite.makeConn() - c.Stop() - time.Sleep(10 * time.Millisecond) + _, _ = c.Write([]byte("data")) - _, err := c.SyncWrite([]byte("too late")) - suite.Error(err) + suite.Eventually(func() bool { + _, err := c.Write([]byte{1}) + return err != nil + }, 2*time.Second, time.Millisecond) } func TestConn(t *testing.T) {