diff --git a/mtglib/internal/doppel/conn.go b/mtglib/internal/doppel/conn.go index 33ea88de9..cc81d912b 100644 --- a/mtglib/internal/doppel/conn.go +++ b/mtglib/internal/doppel/conn.go @@ -61,14 +61,14 @@ func (c Conn) start() { for c.p.writeStream.Len() == 0 && !c.p.done { c.p.writtenCond.Wait() } - n, _ := c.p.writeStream.Read(buf[:size]) + n, _ := c.p.writeStream.Read(buf[tls.SizeHeader : tls.SizeHeader+size]) c.p.writtenCond.L.Unlock() if n == 0 { continue } - if err := tls.WriteRecord(c.Conn, buf[:n]); err != nil { + if err := tls.WriteRecordInPlace(c.Conn, buf[:], n); err != nil { c.p.ctxCancel(err) return } diff --git a/mtglib/internal/doppel/ganger.go b/mtglib/internal/doppel/ganger.go index 6621761f1..c8fbfbcc6 100644 --- a/mtglib/internal/doppel/ganger.go +++ b/mtglib/internal/doppel/ganger.go @@ -98,7 +98,8 @@ func (g *Ganger) run() { g.durations = append(g.durations, durations...) if len(g.durations) > DoppelGangerMaxDurations { - g.durations = g.durations[len(g.durations)-DoppelGangerMaxDurations:] + copy(g.durations, g.durations[len(g.durations)-DoppelGangerMaxDurations:]) + g.durations = g.durations[:DoppelGangerMaxDurations] } if len(g.durations) < MinDurationsToCalculate { diff --git a/mtglib/internal/tls/conn.go b/mtglib/internal/tls/conn.go index e8893109d..ce33fa1f6 100644 --- a/mtglib/internal/tls/conn.go +++ b/mtglib/internal/tls/conn.go @@ -34,7 +34,6 @@ type Conn struct { type connPayload struct { readBuf bytes.Buffer - writeBuf bytes.Buffer connBuffered *bufio.Reader read bool write bool @@ -80,7 +79,6 @@ func New(conn essentials.Conn, read, write bool) Conn { } newConn.p.readBuf.Grow(DefaultBufferSize) - newConn.p.writeBuf.Grow(DefaultBufferSize) return newConn } diff --git a/mtglib/internal/tls/utils.go b/mtglib/internal/tls/utils.go index 978e048c9..c0bdc39fd 100644 --- a/mtglib/internal/tls/utils.go +++ b/mtglib/internal/tls/utils.go @@ -29,20 +29,24 @@ func ReadRecord(r io.Reader, w io.Writer) (byte, int64, error) { func WriteRecord(w io.Writer, payload []byte) error { buf := [MaxRecordSize]byte{} - buf[0] = TypeApplicationData - - bufV := buf[SizeRecordType:] - copy(bufV[:SizeVersion], TLSVersion[:]) + copy(buf[SizeHeader:], payload) - bufS := bufV[SizeVersion:] - binary.BigEndian.PutUint16(bufS[:SizeSize], uint16(len(payload))) + return WriteRecordInPlace(w, buf[:], len(payload)) +} - bufP := buf[SizeHeader:] - if n := copy(bufP, payload); n != len(payload) { - return fmt.Errorf("copied %d bytes of payload instead of %d", n, len(payload)) +func WriteRecordInPlace(w io.Writer, buf []byte, payloadLen int) error { + if payloadLen > MaxRecordPayloadSize { + return fmt.Errorf("payload %d exceeds max %d", payloadLen, MaxRecordPayloadSize) } - _, err := w.Write(buf[:SizeHeader+len(payload)]) + buf[0] = TypeApplicationData + copy(buf[SizeRecordType:SizeRecordType+SizeVersion], TLSVersion[:]) + binary.BigEndian.PutUint16( + buf[SizeRecordType+SizeVersion:SizeRecordType+SizeVersion+SizeSize], + uint16(payloadLen), + ) + + _, err := w.Write(buf[:SizeHeader+payloadLen]) return err } diff --git a/mtglib/internal/tls/utils_test.go b/mtglib/internal/tls/utils_test.go index 9ddbfa8c6..ad5a47f9c 100644 --- a/mtglib/internal/tls/utils_test.go +++ b/mtglib/internal/tls/utils_test.go @@ -119,6 +119,84 @@ func (suite *UtilsTestSuite) TestWriteRecordPayloadTooLarge() { suite.Error(err) } +func (suite *UtilsTestSuite) TestWriteRecordInPlace() { + payload := []byte("hello in-place") + + var buf [MaxRecordSize]byte + copy(buf[SizeHeader:], payload) + + err := WriteRecordInPlace(suite.dst, buf[:], len(payload)) + suite.NoError(err) + + written := suite.dst.Bytes() + suite.Equal(byte(TypeApplicationData), written[0]) + suite.Equal(TLSVersion[:], written[SizeRecordType:SizeRecordType+SizeVersion]) + + length := binary.BigEndian.Uint16(written[SizeRecordType+SizeVersion:]) + suite.Equal(uint16(len(payload)), length) + suite.Equal(payload, written[SizeHeader:]) +} + +func (suite *UtilsTestSuite) TestWriteRecordInPlaceRoundTrip() { + payload := []byte("round trip in-place") + + var buf [MaxRecordSize]byte + copy(buf[SizeHeader:], payload) + + var wire bytes.Buffer + + err := WriteRecordInPlace(&wire, buf[:], len(payload)) + suite.NoError(err) + + var recovered bytes.Buffer + + recordType, length, err := ReadRecord(&wire, &recovered) + suite.NoError(err) + suite.Equal(byte(TypeApplicationData), recordType) + suite.Equal(int64(len(payload)), length) + suite.Equal(payload, recovered.Bytes()) +} + +func (suite *UtilsTestSuite) TestWriteRecordInPlacePayloadTooLarge() { + var buf [MaxRecordSize]byte + + err := WriteRecordInPlace(suite.dst, buf[:], MaxRecordPayloadSize+1) + suite.Error(err) +} + +func (suite *UtilsTestSuite) TestWriteRecordInPlacePropagatesError() { + m := &WriterMock{} + m. + On("Write", mock.AnythingOfType("[]uint8")). + Once(). + Return(0, errors.New("disk full")) + + var buf [MaxRecordSize]byte + copy(buf[SizeHeader:], []byte("data")) + + err := WriteRecordInPlace(m, buf[:], 4) + suite.Error(err) + + m.AssertExpectations(suite.T()) +} + +func (suite *UtilsTestSuite) TestWriteRecordInPlaceMatchesWriteRecord() { + payload := []byte("equivalence check") + + var legacy bytes.Buffer + err := WriteRecord(&legacy, payload) + suite.NoError(err) + + var buf [MaxRecordSize]byte + copy(buf[SizeHeader:], payload) + + var inPlace bytes.Buffer + err = WriteRecordInPlace(&inPlace, buf[:], len(payload)) + suite.NoError(err) + + suite.Equal(legacy.Bytes(), inPlace.Bytes()) +} + func TestUtils(t *testing.T) { t.Parallel() suite.Run(t, &UtilsTestSuite{})