Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 31 additions & 45 deletions mtglib/internal/doppel/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,48 +16,25 @@ type Conn struct {
}

type connPayload struct {
ctx context.Context
ctxCancel context.CancelCauseFunc
clock Clock
wg sync.WaitGroup
syncWriteLock sync.RWMutex
writeStream bytes.Buffer
writeCond *sync.Cond
ctx context.Context
ctxCancel context.CancelCauseFunc
clock Clock
wg sync.WaitGroup
writeStream bytes.Buffer
writtenCond sync.Cond
done bool
}

func (c Conn) Write(p []byte) (int, error) {
c.p.syncWriteLock.RLock()
defer c.p.syncWriteLock.RUnlock()

c.p.writeCond.L.Lock()
c.p.writeStream.Write(p)
c.p.writeCond.L.Unlock()

return len(p), context.Cause(c.p.ctx)
}

func (c Conn) SyncWrite(p []byte) (int, error) {
c.p.syncWriteLock.Lock()
defer c.p.syncWriteLock.Unlock()

c.p.writeCond.L.Lock()
// wait until buffer is exhausted
for c.p.writeStream.Len() != 0 && context.Cause(c.p.ctx) == nil {
c.p.writeCond.Wait()
if len(p) == 0 {
return 0, context.Cause(c.p.ctx)
}
c.p.writeStream.Write(p)
c.p.writeCond.L.Unlock()

if err := context.Cause(c.p.ctx); err != nil {
return len(p), err
}
c.p.writtenCond.L.Lock()
c.p.writeStream.Write(p)
c.p.writtenCond.L.Unlock()

c.p.writeCond.L.Lock()
// wait until data will be sent
for c.p.writeStream.Len() != 0 && context.Cause(c.p.ctx) == nil {
c.p.writeCond.Wait()
}
c.p.writeCond.L.Unlock()
c.p.writtenCond.Signal()

return len(p), context.Cause(c.p.ctx)
}
Expand All @@ -69,8 +46,6 @@ func (c Conn) Start() {
}

func (c Conn) start() {
defer c.p.writeCond.Broadcast()

buf := [tls.MaxRecordSize]byte{}

for {
Expand All @@ -80,25 +55,34 @@ func (c Conn) start() {
case <-c.p.clock.tick:
}

c.p.writeCond.L.Lock()
n, err := c.p.writeStream.Read(buf[:c.p.clock.stats.Size()])
c.p.writeCond.L.Unlock()
size := c.p.clock.stats.Size()

c.p.writtenCond.L.Lock()
for c.p.writeStream.Len() == 0 && !c.p.done {
c.p.writtenCond.Wait()
}
n, _ := c.p.writeStream.Read(buf[:size])
c.p.writtenCond.L.Unlock()

if n == 0 || err != nil {
if n == 0 {
continue
}

if err := tls.WriteRecord(c.Conn, buf[:n]); err != nil {
c.p.ctxCancel(err)
return
}

c.p.writeCond.Signal()
}
}

func (c Conn) Stop() {
c.p.ctxCancel(nil)

c.p.writtenCond.L.Lock()
c.p.done = true
c.p.writtenCond.L.Unlock()
c.p.writtenCond.Broadcast()

c.p.wg.Wait()
}

Expand All @@ -109,7 +93,9 @@ func NewConn(ctx context.Context, conn essentials.Conn, stats *Stats) Conn {
p: &connPayload{
ctx: ctx,
ctxCancel: cancel,
writeCond: sync.NewCond(&sync.Mutex{}),
writtenCond: sync.Cond{
L: &sync.Mutex{},
},
clock: Clock{
stats: stats,
tick: make(chan struct{}),
Expand Down
159 changes: 30 additions & 129 deletions mtglib/internal/doppel/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,150 +141,51 @@ func (suite *ConnTestSuite) TestWriteReturnsErrorAfterStop() {
suite.Error(err)
}

func (suite *ConnTestSuite) TestStopOnUnderlyingWriteError() {
suite.connMock.
On("Write", mock.AnythingOfType("[]uint8")).
Return(0, errors.New("connection reset")).
Maybe()

c := suite.makeConn()

_, _ = c.Write([]byte("data"))

suite.Eventually(func() bool {
_, err := c.Write([]byte{1})
return err != nil
}, 2*time.Second, time.Millisecond)
}

func (suite *ConnTestSuite) TestSyncWriteDataSent() {
suite.connMock.
On("Write", mock.AnythingOfType("[]uint8")).
Return(0, nil).
Maybe()

c := suite.makeConn()
defer c.Stop()

payload := []byte("sync hello")
n, err := c.SyncWrite(payload)
suite.NoError(err)
suite.Equal(len(payload), n)

// SyncWrite returns only after data is flushed to the wire.
assembled := &bytes.Buffer{}
reader := bytes.NewReader(suite.connMock.Written())

for {
header := make([]byte, tls.SizeHeader)
if _, err := io.ReadFull(reader, header); err != nil {
break
}

suite.Equal(byte(tls.TypeApplicationData), header[0])

length := binary.BigEndian.Uint16(header[tls.SizeRecordType+tls.SizeVersion:])
rec := make([]byte, length)
_, err := io.ReadFull(reader, rec)
suite.NoError(err)

assembled.Write(rec)
}

suite.Equal(payload, assembled.Bytes())
}

func (suite *ConnTestSuite) TestSyncWriteDrainsBufferFirst() {
suite.connMock.
On("Write", mock.AnythingOfType("[]uint8")).
Return(0, nil).
Maybe()

c := suite.makeConn()
defer c.Stop()

// Buffer some data via async Write.
_, err := c.Write([]byte("first"))
suite.NoError(err)

// SyncWrite must drain "first" before sending "second".
n, err := c.SyncWrite([]byte("second"))
suite.NoError(err)
suite.Equal(6, n)

// All data should be on the wire now.
assembled := &bytes.Buffer{}
reader := bytes.NewReader(suite.connMock.Written())

for {
header := make([]byte, tls.SizeHeader)
if _, err := io.ReadFull(reader, header); err != nil {
break
}

length := binary.BigEndian.Uint16(header[tls.SizeRecordType+tls.SizeVersion:])
rec := make([]byte, length)
_, err := io.ReadFull(reader, rec)
suite.NoError(err)

assembled.Write(rec)
}

suite.Equal([]byte("firstsecond"), assembled.Bytes())
}

func (suite *ConnTestSuite) TestSyncWriteBlocksAsyncWrite() {
func (suite *ConnTestSuite) TestStopDoesNotDeadlockWhenStartIsWaiting() {
suite.connMock.
On("Write", mock.AnythingOfType("[]uint8")).
Return(0, nil).
Maybe()

c := suite.makeConn()
defer c.Stop()

// Start SyncWrite — it holds exclusive lock.
syncDone := make(chan struct{})

go func() {
defer close(syncDone)
c.SyncWrite([]byte("exclusive")) //nolint: errcheck
}()

// Give SyncWrite time to acquire the lock.
time.Sleep(10 * time.Millisecond)

// Async Write should block until SyncWrite completes.
writeDone := make(chan struct{})

go func() {
defer close(writeDone)
c.Write([]byte("blocked")) //nolint: errcheck
}()

// SyncWrite should finish first.
<-syncDone

select {
case <-writeDone:
// Write completed after SyncWrite — correct.
case <-time.After(2 * time.Second):
suite.Fail("async Write did not unblock after SyncWrite completed")
for range 100 {
func() {
ctx, cancel := context.WithCancel(suite.ctx)
defer cancel()

c := NewConn(ctx, suite.connMock, &Stats{
k: 2.0,
lambda: 0.01,
})

done := make(chan struct{})
go func() {
defer close(done)
c.Stop()
}()

select {
case <-done:
case <-time.After(2 * time.Second):
suite.Fail("Stop() deadlocked: start() likely stuck in writtenCond.Wait()")
}
}()
}
}

func (suite *ConnTestSuite) TestSyncWriteReturnsErrorAfterStop() {
func (suite *ConnTestSuite) TestStopOnUnderlyingWriteError() {
suite.connMock.
On("Write", mock.AnythingOfType("[]uint8")).
Return(0, nil).
Return(0, errors.New("connection reset")).
Maybe()

c := suite.makeConn()
c.Stop()

time.Sleep(10 * time.Millisecond)
_, _ = c.Write([]byte("data"))

_, err := c.SyncWrite([]byte("too late"))
suite.Error(err)
suite.Eventually(func() bool {
_, err := c.Write([]byte{1})
return err != nil
}, 2*time.Second, time.Millisecond)
}

func TestConn(t *testing.T) {
Expand Down
Loading