@@ -5,6 +5,7 @@ package tarantool
55import (
66 "bufio"
77 "bytes"
8+ "context"
89 "errors"
910 "fmt"
1011 "io"
@@ -125,8 +126,11 @@ type Connection struct {
125126 c net.Conn
126127 mutex sync.Mutex
127128 // Schema contains schema loaded on connection.
128- Schema *Schema
129+ Schema *Schema
130+ // requestId contains the last request ID for requests with nil context.
129131 requestId uint32
132+ // contextRequestId contains the last request ID for requests with context.
133+ contextRequestId uint32
130134 // Greeting contains first message sent by Tarantool.
131135 Greeting *Greeting
132136
@@ -143,16 +147,57 @@ type Connection struct {
143147
144148var _ = Connector(&Connection{}) // Check compatibility with connector interface.
145149
150+ type futureList struct {
151+ first *Future
152+ last **Future
153+ }
154+
155+ func (list *futureList) findFuture(reqid uint32, fetch bool) *Future {
156+ root := &list.first
157+ for {
158+ fut := *root
159+ if fut == nil {
160+ return nil
161+ }
162+ if fut.requestId == reqid {
163+ if fetch {
164+ *root = fut.next
165+ if fut.next == nil {
166+ list.last = root
167+ } else {
168+ fut.next = nil
169+ }
170+ }
171+ return fut
172+ }
173+ root = &fut.next
174+ }
175+ }
176+
177+ func (list *futureList) addFuture(fut *Future) {
178+ *list.last = fut
179+ list.last = &fut.next
180+ }
181+
182+ func (list *futureList) clear(err error, conn *Connection) {
183+ fut := list.first
184+ list.first = nil
185+ list.last = &list.first
186+ for fut != nil {
187+ fut.SetError(err)
188+ conn.markDone(fut)
189+ fut, fut.next = fut.next, nil
190+ }
191+ }
192+
146193type connShard struct {
147- rmut sync.Mutex
148- requests [requestsMap]struct {
149- first *Future
150- last **Future
151- }
152- bufmut sync.Mutex
153- buf smallWBuf
154- enc *msgpack.Encoder
155- _pad [16]uint64 //nolint: unused,structcheck
194+ rmut sync.Mutex
195+ requests [requestsMap]futureList
196+ requestsWithCtx [requestsMap]futureList
197+ bufmut sync.Mutex
198+ buf smallWBuf
199+ enc *msgpack.Encoder
200+ _pad [16]uint64 //nolint: unused,structcheck
156201}
157202
158203// Greeting is a message sent by Tarantool on connect.
@@ -167,6 +212,11 @@ type Opts struct {
167212 // push messages are received. If Timeout is zero, any request can be
168213 // blocked infinitely.
169214 // Also used to setup net.TCPConn.Set(Read|Write)Deadline.
215+ //
216+ // Pay attention, when using contexts with request objects,
217+ // the timeout option for Connection does not affect the lifetime
218+ // of the request. For those purposes use context.WithTimeout() as
219+ // the root context.
170220 Timeout time.Duration
171221 // Timeout between reconnect attempts. If Reconnect is zero, no
172222 // reconnect attempts will be made.
@@ -262,12 +312,13 @@ type SslOpts struct {
262312// and will not finish to make attempts on authorization failures.
263313func Connect(addr string, opts Opts) (conn *Connection, err error) {
264314 conn = &Connection{
265- addr: addr,
266- requestId: 0,
267- Greeting: &Greeting{},
268- control: make(chan struct{}),
269- opts: opts,
270- dec: msgpack.NewDecoder(&smallBuf{}),
315+ addr: addr,
316+ requestId: 0,
317+ contextRequestId: 1,
318+ Greeting: &Greeting{},
319+ control: make(chan struct{}),
320+ opts: opts,
321+ dec: msgpack.NewDecoder(&smallBuf{}),
271322 }
272323 maxprocs := uint32(runtime.GOMAXPROCS(-1))
273324 if conn.opts.Concurrency == 0 || conn.opts.Concurrency > maxprocs*128 {
@@ -283,8 +334,11 @@ func Connect(addr string, opts Opts) (conn *Connection, err error) {
283334 conn.shard = make([]connShard, conn.opts.Concurrency)
284335 for i := range conn.shard {
285336 shard := &conn.shard[i]
286- for j := range shard.requests {
287- shard.requests[j].last = &shard.requests[j].first
337+ requestsLists := []*[requestsMap]futureList{&shard.requests, &shard.requestsWithCtx}
338+ for _, requests := range requestsLists {
339+ for j := range requests {
340+ requests[j].last = &requests[j].first
341+ }
288342 }
289343 }
290344
@@ -387,6 +441,13 @@ func (conn *Connection) Handle() interface{} {
387441 return conn.opts.Handle
388442}
389443
444+ func (conn *Connection) cancelFuture(fut *Future, err error) {
445+ if fut = conn.fetchFuture(fut.requestId); fut != nil {
446+ fut.SetError(err)
447+ conn.markDone(fut)
448+ }
449+ }
450+
390451func (conn *Connection) dial() (err error) {
391452 var connection net.Conn
392453 network := "tcp"
@@ -580,15 +641,10 @@ func (conn *Connection) closeConnection(neterr error, forever bool) (err error)
580641 }
581642 for i := range conn.shard {
582643 conn.shard[i].buf.Reset()
583- requests := &conn.shard[i].requests
584- for pos := range requests {
585- fut := requests[pos].first
586- requests[pos].first = nil
587- requests[pos].last = &requests[pos].first
588- for fut != nil {
589- fut.SetError(neterr)
590- conn.markDone(fut)
591- fut, fut.next = fut.next, nil
644+ requestsLists := []*[requestsMap]futureList{&conn.shard[i].requests, &conn.shard[i].requestsWithCtx}
645+ for _, requests := range requestsLists {
646+ for pos := range requests {
647+ requests[pos].clear(neterr, conn)
592648 }
593649 }
594650 }
@@ -721,7 +777,7 @@ func (conn *Connection) reader(r *bufio.Reader, c net.Conn) {
721777 }
722778}
723779
724- func (conn *Connection) newFuture() (fut *Future) {
780+ func (conn *Connection) newFuture(ctx context.Context ) (fut *Future) {
725781 fut = NewFuture()
726782 if conn.rlimit != nil && conn.opts.RLimitAction == RLimitDrop {
727783 select {
@@ -736,7 +792,7 @@ func (conn *Connection) newFuture() (fut *Future) {
736792 return
737793 }
738794 }
739- fut.requestId = conn.nextRequestId()
795+ fut.requestId = conn.nextRequestId(ctx != nil )
740796 shardn := fut.requestId & (conn.opts.Concurrency - 1)
741797 shard := &conn.shard[shardn]
742798 shard.rmut.Lock()
@@ -761,11 +817,20 @@ func (conn *Connection) newFuture() (fut *Future) {
761817 return
762818 }
763819 pos := (fut.requestId / conn.opts.Concurrency) & (requestsMap - 1)
764- pair := &shard.requests[pos]
765- *pair.last = fut
766- pair.last = &fut.next
767- if conn.opts.Timeout > 0 {
768- fut.timeout = time.Since(epoch) + conn.opts.Timeout
820+ if ctx != nil {
821+ select {
822+ case <-ctx.Done():
823+ fut.SetError(fmt.Errorf("context is done"))
824+ shard.rmut.Unlock()
825+ return
826+ default:
827+ }
828+ shard.requestsWithCtx[pos].addFuture(fut)
829+ } else {
830+ shard.requests[pos].addFuture(fut)
831+ if conn.opts.Timeout > 0 {
832+ fut.timeout = time.Since(epoch) + conn.opts.Timeout
833+ }
769834 }
770835 shard.rmut.Unlock()
771836 if conn.rlimit != nil && conn.opts.RLimitAction == RLimitWait {
@@ -785,12 +850,43 @@ func (conn *Connection) newFuture() (fut *Future) {
785850 return
786851}
787852
853+ // This method removes a future from the internal queue if the context
854+ // is "done" before the response is come. Such select logic is inspired
855+ // from this thread: https://groups.google.com/g/golang-dev/c/jX4oQEls3uk
856+ func (conn *Connection) contextWatchdog(fut *Future, ctx context.Context) {
857+ select {
858+ case <-fut.done:
859+ default:
860+ select {
861+ case <-ctx.Done():
862+ conn.cancelFuture(fut, fmt.Errorf("context is done"))
863+ default:
864+ select {
865+ case <-fut.done:
866+ case <-ctx.Done():
867+ conn.cancelFuture(fut, fmt.Errorf("context is done"))
868+ }
869+ }
870+ }
871+ }
872+
788873func (conn *Connection) send(req Request) *Future {
789- fut := conn.newFuture()
874+ fut := conn.newFuture(req.Ctx() )
790875 if fut.ready == nil {
791876 return fut
792877 }
878+ if req.Ctx() != nil {
879+ select {
880+ case <-req.Ctx().Done():
881+ conn.cancelFuture(fut, fmt.Errorf("context is done"))
882+ return fut
883+ default:
884+ }
885+ }
793886 conn.putFuture(fut, req)
887+ if req.Ctx() != nil {
888+ go conn.contextWatchdog(fut, req.Ctx())
889+ }
794890 return fut
795891}
796892
@@ -877,25 +973,11 @@ func (conn *Connection) fetchFuture(reqid uint32) (fut *Future) {
877973func (conn *Connection) getFutureImp(reqid uint32, fetch bool) *Future {
878974 shard := &conn.shard[reqid&(conn.opts.Concurrency-1)]
879975 pos := (reqid / conn.opts.Concurrency) & (requestsMap - 1)
880- pair := &shard.requests[pos]
881- root := &pair.first
882- for {
883- fut := *root
884- if fut == nil {
885- return nil
886- }
887- if fut.requestId == reqid {
888- if fetch {
889- *root = fut.next
890- if fut.next == nil {
891- pair.last = root
892- } else {
893- fut.next = nil
894- }
895- }
896- return fut
897- }
898- root = &fut.next
976+ // futures with even requests id belong to requests list with nil context
977+ if reqid%2 == 0 {
978+ return shard.requests[pos].findFuture(reqid, fetch)
979+ } else {
980+ return shard.requestsWithCtx[pos].findFuture(reqid, fetch)
899981 }
900982}
901983
@@ -984,8 +1066,12 @@ func (conn *Connection) read(r io.Reader) (response []byte, err error) {
9841066 return
9851067}
9861068
987- func (conn *Connection) nextRequestId() (requestId uint32) {
988- return atomic.AddUint32(&conn.requestId, 1)
1069+ func (conn *Connection) nextRequestId(context bool) (requestId uint32) {
1070+ if context {
1071+ return atomic.AddUint32(&conn.contextRequestId, 2)
1072+ } else {
1073+ return atomic.AddUint32(&conn.requestId, 2)
1074+ }
9891075}
9901076
9911077// Do performs a request asynchronously on the connection.
@@ -1000,6 +1086,15 @@ func (conn *Connection) Do(req Request) *Future {
10001086 return fut
10011087 }
10021088 }
1089+ if req.Ctx() != nil {
1090+ select {
1091+ case <-req.Ctx().Done():
1092+ fut := NewFuture()
1093+ fut.SetError(fmt.Errorf("context is done"))
1094+ return fut
1095+ default:
1096+ }
1097+ }
10031098 return conn.send(req)
10041099}
10051100
0 commit comments