diff --git a/src/datachannel/streaming.go b/src/datachannel/streaming.go index 73ad3374..fb8b9b75 100644 --- a/src/datachannel/streaming.go +++ b/src/datachannel/streaming.go @@ -115,6 +115,8 @@ type DataChannel struct { // AgentVersion received during handshake agentVersion string + + mutex sync.Mutex } type ListMessageBuffer struct { @@ -274,6 +276,9 @@ func (dataChannel *DataChannel) SendInputDataMessage( payloadType message.PayloadType, inputData []byte) (err error) { + dataChannel.mutex.Lock() + defer dataChannel.mutex.Unlock() + var ( flag uint64 = 0 msg []byte @@ -338,12 +343,16 @@ func (dataChannel *DataChannel) ResendStreamDataMessageScheduler(log log.T) (err streamMessageElement := dataChannel.OutgoingMessageBuffer.Messages.Front() dataChannel.OutgoingMessageBuffer.Mutex.Unlock() + dataChannel.mutex.Lock() + localTimeout := dataChannel.RetransmissionTimeout + dataChannel.mutex.Unlock() + if streamMessageElement == nil { continue } streamMessage := streamMessageElement.Value.(StreamingMessage) - if time.Since(streamMessage.LastSentTime) > dataChannel.RetransmissionTimeout { + if time.Since(streamMessage.LastSentTime) > localTimeout { log.Debugf("Resend stream data message %d for the %d attempt.", streamMessage.SequenceNumber, *streamMessage.ResendAttempt) if *streamMessage.ResendAttempt >= config.ResendMaxAttempt { log.Warnf("Message %d was resent over %d times.", streamMessage.SequenceNumber, config.ResendMaxAttempt) @@ -363,6 +372,9 @@ func (dataChannel *DataChannel) ResendStreamDataMessageScheduler(log log.T) (err // ProcessAcknowledgedMessage processes acknowledge messages by deleting them from OutgoingMessageBuffer func (dataChannel *DataChannel) ProcessAcknowledgedMessage(log log.T, acknowledgeMessageContent message.AcknowledgeContent) error { + dataChannel.mutex.Lock() + defer dataChannel.mutex.Unlock() + acknowledgeSequenceNumber := acknowledgeMessageContent.SequenceNumber for streamMessageElement := dataChannel.OutgoingMessageBuffer.Messages.Front(); streamMessageElement != nil; streamMessageElement = streamMessageElement.Next() { streamMessage := streamMessageElement.Value.(StreamingMessage) @@ -610,6 +622,8 @@ func (dataChannel *DataChannel) HandleOutputMessage( outputMessage message.ClientMessage, rawMessage []byte) (err error) { + dataChannel.mutex.Lock() + // On receiving expected stream data message, send acknowledgement, process it and increment expected sequence number by 1. // Further process messages from IncomingMessageBuffer if outputMessage.SequenceNumber == dataChannel.ExpectedSequenceNumber { @@ -618,40 +632,51 @@ func (dataChannel *DataChannel) HandleOutputMessage( case message.HandshakeRequestPayloadType: { if err = SendAcknowledgeMessageCall(log, dataChannel, outputMessage); err != nil { + dataChannel.mutex.Unlock() return err } // PayloadType is HandshakeRequest so we call our own handler instead of the provided handler log.Debugf("Processing HandshakeRequest message %s", outputMessage) + + // The handler will eventually request the lock in `SendInputDataMessage`, so we'll unlock here to avoid deadlock + dataChannel.mutex.Unlock() if err = dataChannel.handleHandshakeRequest(log, outputMessage); err != nil { log.Errorf("Unable to process incoming data payload, MessageType %s, "+ "PayloadType HandshakeRequestPayloadType, err: %s.", outputMessage.MessageType, err) return err } + dataChannel.mutex.Lock() } case message.HandshakeCompletePayloadType: { if err = SendAcknowledgeMessageCall(log, dataChannel, outputMessage); err != nil { + dataChannel.mutex.Unlock() return err } + dataChannel.mutex.Unlock() if err = dataChannel.handleHandshakeComplete(log, outputMessage); err != nil { log.Errorf("Unable to process incoming data payload, MessageType %s, "+ "PayloadType HandshakeCompletePayloadType, err: %s.", outputMessage.MessageType, err) return err } + dataChannel.mutex.Lock() } case message.EncChallengeRequest: { if err = SendAcknowledgeMessageCall(log, dataChannel, outputMessage); err != nil { + dataChannel.mutex.Unlock() return err } + dataChannel.mutex.Unlock() if err = dataChannel.handleEncryptionChallengeRequest(log, outputMessage); err != nil { log.Errorf("Unable to process incoming data payload, MessageType %s, "+ "PayloadType EncChallengeRequest, err: %s.", outputMessage.MessageType, err) return err } + dataChannel.mutex.Lock() } default: @@ -681,11 +706,13 @@ func (dataChannel *DataChannel) HandleOutputMessage( } else { // Acknowledge outputMessage only if session specific handler is ready if err := SendAcknowledgeMessageCall(log, dataChannel, outputMessage); err != nil { + dataChannel.mutex.Unlock() return err } } } dataChannel.ExpectedSequenceNumber = dataChannel.ExpectedSequenceNumber + 1 + dataChannel.mutex.Unlock() return dataChannel.ProcessIncomingMessageBufferItems(log, outputMessage) } else { log.Debugf("Unexpected sequence message received. Received Sequence Number: %d. Expected Sequence Number: %d", @@ -698,6 +725,7 @@ func (dataChannel *DataChannel) HandleOutputMessage( outputMessage.SequenceNumber, dataChannel.ExpectedSequenceNumber) if len(dataChannel.IncomingMessageBuffer.Messages) < dataChannel.IncomingMessageBuffer.Capacity { if err = SendAcknowledgeMessageCall(log, dataChannel, outputMessage); err != nil { + dataChannel.mutex.Unlock() return err } @@ -713,6 +741,7 @@ func (dataChannel *DataChannel) HandleOutputMessage( } } } + dataChannel.mutex.Unlock() return nil } @@ -722,6 +751,9 @@ func (dataChannel *DataChannel) HandleOutputMessage( func (dataChannel *DataChannel) ProcessIncomingMessageBufferItems(log log.T, outputMessage message.ClientMessage) (err error) { + dataChannel.mutex.Lock() + defer dataChannel.mutex.Unlock() + for { bufferedStreamMessage := dataChannel.IncomingMessageBuffer.Messages[dataChannel.ExpectedSequenceNumber] if bufferedStreamMessage.Content != nil { diff --git a/src/sessionmanagerplugin/session/portsession/portsession.go b/src/sessionmanagerplugin/session/portsession/portsession.go index 793b6a74..cf9c9d58 100644 --- a/src/sessionmanagerplugin/session/portsession/portsession.go +++ b/src/sessionmanagerplugin/session/portsession/portsession.go @@ -50,7 +50,7 @@ type PortParameters struct { } func init() { - session.Register(&PortSession{}) + session.Register(&PortSession{}, func() session.ISessionPlugin { return &PortSession{} }) } // Name is the session name used inputStream the plugin diff --git a/src/sessionmanagerplugin/session/session.go b/src/sessionmanagerplugin/session/session.go index 6f4d154b..bcd42200 100644 --- a/src/sessionmanagerplugin/session/session.go +++ b/src/sessionmanagerplugin/session/session.go @@ -43,7 +43,7 @@ const ( VersionFile = "VERSION" ) -var SessionRegistry = map[string]ISessionPlugin{} +var SessionRegistry = map[string]func() ISessionPlugin{} type ISessionPlugin interface { SetSessionHandlers(log.T) error @@ -64,11 +64,11 @@ type ISession interface { } func init() { - SessionRegistry = make(map[string]ISessionPlugin) + SessionRegistry = make(map[string]func() ISessionPlugin) } -func Register(session ISessionPlugin) { - SessionRegistry[session.Name()] = session +func Register(session ISessionPlugin, constructor func() ISessionPlugin) { + SessionRegistry[session.Name()] = constructor } type Session struct { @@ -94,8 +94,12 @@ var startSession = func(session *Session, log log.T) error { // setSessionHandlersWithSessionType set session handlers based on session subtype var setSessionHandlersWithSessionType = func(session *Session, log log.T) error { - // SessionType is set inside DataChannel - sessionSubType := SessionRegistry[session.SessionType] + constructor := SessionRegistry[session.SessionType] + if constructor == nil { + return fmt.Errorf("no constructor found for session type %s", session.SessionType) + } + + sessionSubType := constructor() sessionSubType.Initialize(log, session) return sessionSubType.SetSessionHandlers(log) } diff --git a/src/sessionmanagerplugin/session/shellsession/shellsession.go b/src/sessionmanagerplugin/session/shellsession/shellsession.go index 991683c9..d4b200dd 100644 --- a/src/sessionmanagerplugin/session/shellsession/shellsession.go +++ b/src/sessionmanagerplugin/session/shellsession/shellsession.go @@ -47,7 +47,7 @@ var GetTerminalSizeCall = func(fd int) (width int, height int, err error) { } func init() { - session.Register(&ShellSession{}) + session.Register(&ShellSession{}, func() session.ISessionPlugin { return &ShellSession{} }) } // Name is the session name used in the plugin