Skip to content

Commit d936d91

Browse files
authored
Fixed issue with output stream being reused when using multi concurrency mode
1 parent a534ca7 commit d936d91

File tree

3 files changed

+180
-55
lines changed

3 files changed

+180
-55
lines changed
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
{
2+
"Projects": [
3+
{
4+
"Name": "Amazon.Lambda.RuntimeSupport",
5+
"Type": "Patch",
6+
"ChangelogMessages": [
7+
"Fixed issue with output stream being reused when using multi concurrency mode"
8+
]
9+
}
10+
]
11+
}

Libraries/src/Amazon.Lambda.RuntimeSupport/Bootstrap/HandlerWrapper.cs

Lines changed: 126 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ public class HandlerWrapper : IDisposable
2929
private static readonly InvocationResponse EmptyInvocationResponse =
3030
new InvocationResponse(new MemoryStream(0), false);
3131

32-
private readonly MemoryStream OutputStream = new MemoryStream();
32+
private readonly IOutputStreamFactory _outputStreamFactory;
3333

3434
/// <summary>
3535
/// The handler that will be called for each event.
@@ -39,9 +39,20 @@ public class HandlerWrapper : IDisposable
3939
private HandlerWrapper(LambdaBootstrapHandler handler)
4040
{
4141
Handler = handler;
42+
43+
if (Helpers.Utils.IsUsingMultiConcurrency(new SystemEnvironmentVariables()))
44+
_outputStreamFactory = new MultiConcurrencyOutputStreamFactory();
45+
else
46+
_outputStreamFactory = new OnDemandOutputStreamFactory();
4247
}
4348

44-
private HandlerWrapper() { }
49+
private HandlerWrapper()
50+
{
51+
if (Helpers.Utils.IsUsingMultiConcurrency(new SystemEnvironmentVariables()))
52+
_outputStreamFactory = new MultiConcurrencyOutputStreamFactory();
53+
else
54+
_outputStreamFactory = new OnDemandOutputStreamFactory();
55+
}
4556

4657
/// <summary>
4758
/// Get a HandlerWrapper that will call the given delegate on function invocation.
@@ -54,10 +65,10 @@ public static HandlerWrapper GetHandlerWrapper(Action<Stream, ILambdaContext, Me
5465
var handlerWrapper = new HandlerWrapper();
5566
handlerWrapper.Handler = invocation =>
5667
{
57-
handlerWrapper.OutputStream.SetLength(0);
58-
invokeDelegate(invocation.InputStream, invocation.LambdaContext, handlerWrapper.OutputStream);
59-
handlerWrapper.OutputStream.Position = 0;
60-
var response = new InvocationResponse(handlerWrapper.OutputStream, false);
68+
var outputStream = handlerWrapper._outputStreamFactory.CreateOutputStream();
69+
invokeDelegate(invocation.InputStream, invocation.LambdaContext, outputStream);
70+
outputStream.Position = 0;
71+
var response = new InvocationResponse(outputStream, false);
6172
return Task.FromResult(response);
6273
};
6374
return handlerWrapper;
@@ -271,10 +282,10 @@ public static HandlerWrapper GetHandlerWrapper<TOutput>(Func<Task<TOutput>> hand
271282
handlerWrapper.Handler = async (invocation) =>
272283
{
273284
TOutput output = await handler();
274-
handlerWrapper.OutputStream.SetLength(0);
275-
serializer.Serialize(output, handlerWrapper.OutputStream);
276-
handlerWrapper.OutputStream.Position = 0;
277-
return new InvocationResponse(handlerWrapper.OutputStream, false);
285+
var outputStream = handlerWrapper._outputStreamFactory.CreateOutputStream();
286+
serializer.Serialize(output, outputStream);
287+
outputStream.Position = 0;
288+
return new InvocationResponse(outputStream, false);
278289
};
279290
return handlerWrapper;
280291
}
@@ -293,10 +304,10 @@ public static HandlerWrapper GetHandlerWrapper<TOutput>(Func<Stream, Task<TOutpu
293304
handlerWrapper.Handler = async (invocation) =>
294305
{
295306
TOutput output = await handler(invocation.InputStream);
296-
handlerWrapper.OutputStream.SetLength(0);
297-
serializer.Serialize(output, handlerWrapper.OutputStream);
298-
handlerWrapper.OutputStream.Position = 0;
299-
return new InvocationResponse(handlerWrapper.OutputStream, false);
307+
var outputStream = handlerWrapper._outputStreamFactory.CreateOutputStream();
308+
serializer.Serialize(output, outputStream);
309+
outputStream.Position = 0;
310+
return new InvocationResponse(outputStream, false);
300311
};
301312
return handlerWrapper;
302313
}
@@ -316,10 +327,10 @@ public static HandlerWrapper GetHandlerWrapper<TInput, TOutput>(Func<TInput, Tas
316327
{
317328
TInput input = serializer.Deserialize<TInput>(invocation.InputStream);
318329
TOutput output = await handler(input);
319-
handlerWrapper.OutputStream.SetLength(0);
320-
serializer.Serialize(output, handlerWrapper.OutputStream);
321-
handlerWrapper.OutputStream.Position = 0;
322-
return new InvocationResponse(handlerWrapper.OutputStream, false);
330+
var outputStream = handlerWrapper._outputStreamFactory.CreateOutputStream();
331+
serializer.Serialize(output, outputStream);
332+
outputStream.Position = 0;
333+
return new InvocationResponse(outputStream, false);
323334
};
324335
return handlerWrapper;
325336
}
@@ -338,10 +349,10 @@ public static HandlerWrapper GetHandlerWrapper<TOutput>(Func<ILambdaContext, Tas
338349
handlerWrapper.Handler = async (invocation) =>
339350
{
340351
TOutput output = await handler(invocation.LambdaContext);
341-
handlerWrapper.OutputStream.SetLength(0);
342-
serializer.Serialize(output, handlerWrapper.OutputStream);
343-
handlerWrapper.OutputStream.Position = 0; ;
344-
return new InvocationResponse(handlerWrapper.OutputStream, false);
352+
var outputStream = handlerWrapper._outputStreamFactory.CreateOutputStream();
353+
serializer.Serialize(output, outputStream);
354+
outputStream.Position = 0; ;
355+
return new InvocationResponse(outputStream, false);
345356
};
346357
return handlerWrapper;
347358
}
@@ -360,10 +371,10 @@ public static HandlerWrapper GetHandlerWrapper<TOutput>(Func<Stream, ILambdaCont
360371
handlerWrapper.Handler = async (invocation) =>
361372
{
362373
TOutput output = await handler(invocation.InputStream, invocation.LambdaContext);
363-
handlerWrapper.OutputStream.SetLength(0);
364-
serializer.Serialize(output, handlerWrapper.OutputStream);
365-
handlerWrapper.OutputStream.Position = 0;
366-
return new InvocationResponse(handlerWrapper.OutputStream, false);
374+
var outputStream = handlerWrapper._outputStreamFactory.CreateOutputStream();
375+
serializer.Serialize(output, outputStream);
376+
outputStream.Position = 0;
377+
return new InvocationResponse(outputStream, false);
367378
};
368379
return handlerWrapper;
369380
}
@@ -383,10 +394,10 @@ public static HandlerWrapper GetHandlerWrapper<TInput, TOutput>(Func<TInput, ILa
383394
{
384395
TInput input = serializer.Deserialize<TInput>(invocation.InputStream);
385396
TOutput output = await handler(input, invocation.LambdaContext);
386-
handlerWrapper.OutputStream.SetLength(0);
387-
serializer.Serialize(output, handlerWrapper.OutputStream);
388-
handlerWrapper.OutputStream.Position = 0;
389-
return new InvocationResponse(handlerWrapper.OutputStream, false);
397+
var outputStream = handlerWrapper._outputStreamFactory.CreateOutputStream();
398+
serializer.Serialize(output, outputStream);
399+
outputStream.Position = 0;
400+
return new InvocationResponse(outputStream, false);
390401
};
391402
return handlerWrapper;
392403
}
@@ -599,10 +610,10 @@ public static HandlerWrapper GetHandlerWrapper<TOutput>(Func<TOutput> handler, I
599610
handlerWrapper.Handler = (invocation) =>
600611
{
601612
TOutput output = handler();
602-
handlerWrapper.OutputStream.SetLength(0);
603-
serializer.Serialize(output, handlerWrapper.OutputStream);
604-
handlerWrapper.OutputStream.Position = 0;
605-
return Task.FromResult(new InvocationResponse(handlerWrapper.OutputStream, false));
613+
var outputStream = handlerWrapper._outputStreamFactory.CreateOutputStream();
614+
serializer.Serialize(output, outputStream);
615+
outputStream.Position = 0;
616+
return Task.FromResult(new InvocationResponse(outputStream, false));
606617
};
607618
return handlerWrapper;
608619
}
@@ -621,10 +632,10 @@ public static HandlerWrapper GetHandlerWrapper<TOutput>(Func<Stream, TOutput> ha
621632
handlerWrapper.Handler = (invocation) =>
622633
{
623634
TOutput output = handler(invocation.InputStream);
624-
handlerWrapper.OutputStream.SetLength(0);
625-
serializer.Serialize(output, handlerWrapper.OutputStream);
626-
handlerWrapper.OutputStream.Position = 0;
627-
return Task.FromResult(new InvocationResponse(handlerWrapper.OutputStream, false));
635+
var outputStream = handlerWrapper._outputStreamFactory.CreateOutputStream();
636+
serializer.Serialize(output, outputStream);
637+
outputStream.Position = 0;
638+
return Task.FromResult(new InvocationResponse(outputStream, false));
628639
};
629640
return handlerWrapper;
630641
}
@@ -644,10 +655,10 @@ public static HandlerWrapper GetHandlerWrapper<TInput, TOutput>(Func<TInput, TOu
644655
{
645656
TInput input = serializer.Deserialize<TInput>(invocation.InputStream);
646657
TOutput output = handler(input);
647-
handlerWrapper.OutputStream.SetLength(0);
648-
serializer.Serialize(output, handlerWrapper.OutputStream);
649-
handlerWrapper.OutputStream.Position = 0;
650-
return Task.FromResult(new InvocationResponse(handlerWrapper.OutputStream, false));
658+
var outputStream = handlerWrapper._outputStreamFactory.CreateOutputStream();
659+
serializer.Serialize(output, outputStream);
660+
outputStream.Position = 0;
661+
return Task.FromResult(new InvocationResponse(outputStream, false));
651662
};
652663
return handlerWrapper;
653664
}
@@ -666,10 +677,10 @@ public static HandlerWrapper GetHandlerWrapper<TOutput>(Func<ILambdaContext, TOu
666677
handlerWrapper.Handler = (invocation) =>
667678
{
668679
TOutput output = handler(invocation.LambdaContext);
669-
handlerWrapper.OutputStream.SetLength(0);
670-
serializer.Serialize(output, handlerWrapper.OutputStream);
671-
handlerWrapper.OutputStream.Position = 0; ;
672-
return Task.FromResult(new InvocationResponse(handlerWrapper.OutputStream, false));
680+
var outputStream = handlerWrapper._outputStreamFactory.CreateOutputStream();
681+
serializer.Serialize(output, outputStream);
682+
outputStream.Position = 0; ;
683+
return Task.FromResult(new InvocationResponse(outputStream, false));
673684
};
674685
return handlerWrapper;
675686
}
@@ -688,10 +699,10 @@ public static HandlerWrapper GetHandlerWrapper<TOutput>(Func<Stream, ILambdaCont
688699
handlerWrapper.Handler = (invocation) =>
689700
{
690701
TOutput output = handler(invocation.InputStream, invocation.LambdaContext);
691-
handlerWrapper.OutputStream.SetLength(0);
692-
serializer.Serialize(output, handlerWrapper.OutputStream);
693-
handlerWrapper.OutputStream.Position = 0;
694-
return Task.FromResult(new InvocationResponse(handlerWrapper.OutputStream, false));
702+
var outputStream = handlerWrapper._outputStreamFactory.CreateOutputStream();
703+
serializer.Serialize(output, outputStream);
704+
outputStream.Position = 0;
705+
return Task.FromResult(new InvocationResponse(outputStream, false));
695706
};
696707
return handlerWrapper;
697708
}
@@ -711,10 +722,10 @@ public static HandlerWrapper GetHandlerWrapper<TInput, TOutput>(Func<TInput, ILa
711722
{
712723
TInput input = serializer.Deserialize<TInput>(invocation.InputStream);
713724
TOutput output = handler(input, invocation.LambdaContext);
714-
handlerWrapper.OutputStream.SetLength(0);
715-
serializer.Serialize(output, handlerWrapper.OutputStream);
716-
handlerWrapper.OutputStream.Position = 0;
717-
return Task.FromResult(new InvocationResponse(handlerWrapper.OutputStream, false));
725+
var outputStream = handlerWrapper._outputStreamFactory.CreateOutputStream();
726+
serializer.Serialize(output, outputStream);
727+
outputStream.Position = 0;
728+
return Task.FromResult(new InvocationResponse(outputStream, false));
718729
};
719730
return handlerWrapper;
720731
}
@@ -731,7 +742,7 @@ protected virtual void Dispose(bool disposing)
731742
{
732743
if (disposing)
733744
{
734-
OutputStream.Dispose();
745+
_outputStreamFactory.Dispose();
735746
}
736747

737748
disposedValue = true;
@@ -746,5 +757,65 @@ public void Dispose()
746757
Dispose(true);
747758
}
748759
#endregion
760+
761+
interface IOutputStreamFactory : IDisposable
762+
{
763+
MemoryStream CreateOutputStream();
764+
}
765+
766+
/// <summary>
767+
/// In on demand mode there is never a more then one invocation happening at a time within the process
768+
/// so the same memory stream can be reused.
769+
/// </summary>
770+
class OnDemandOutputStreamFactory : IOutputStreamFactory
771+
{
772+
private readonly MemoryStream OutputStream = new MemoryStream();
773+
private bool _disposedValue;
774+
775+
public MemoryStream CreateOutputStream()
776+
{
777+
OutputStream.SetLength(0);
778+
return OutputStream;
779+
}
780+
781+
protected virtual void Dispose(bool disposing)
782+
{
783+
if (!_disposedValue)
784+
{
785+
if (disposing)
786+
{
787+
OutputStream.Dispose();
788+
}
789+
790+
_disposedValue = true;
791+
}
792+
}
793+
794+
public void Dispose()
795+
{
796+
// Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method
797+
Dispose(disposing: true);
798+
GC.SuppressFinalize(this);
799+
}
800+
}
801+
802+
/// <summary>
803+
/// In multi concurrency mode multiple invocations can happen at the same time within the process
804+
/// so we need to make sure each invocation gets its own output stream.
805+
/// </summary>
806+
class MultiConcurrencyOutputStreamFactory : IOutputStreamFactory
807+
{
808+
public MemoryStream CreateOutputStream()
809+
{
810+
return new MemoryStream();
811+
}
812+
813+
public void Dispose()
814+
{
815+
// Technically we are creating MemoryStreams that have a Dispose method but that is inherited from the base
816+
// class. A MemoryStream is fully managed and doesn't have anything to dispose so it is okay to not worry
817+
// about disposing any of the MemoryStreams created from the CreateOutputStream call.
818+
}
819+
}
749820
}
750821
}

Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/HandlerWrapperTests.cs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -633,6 +633,49 @@ public async Task TestSerializtionOfString()
633633
}
634634
}
635635

636+
[Theory]
637+
[InlineData(true)]
638+
[InlineData(false)]
639+
public async Task TestOutputStreamReuse(bool onDemand)
640+
{
641+
if (!onDemand)
642+
{
643+
Environment.SetEnvironmentVariable(Constants.ENV_VAR_AWS_LAMBDA_MAX_CONCURRENCY, "10");
644+
}
645+
try
646+
{
647+
using (var handlerWrapper = HandlerWrapper.GetHandlerWrapper<string, string>((input) =>
648+
{
649+
return input.ToUpper();
650+
}, Serializer))
651+
{
652+
var invocation1 = new InvocationRequest
653+
{
654+
InputStream = new MemoryStream(UTF8Encoding.UTF8.GetBytes("\"Hello\"")),
655+
LambdaContext = new LambdaContext(_runtimeApiHeaders, _lambdaEnvironment, new Helpers.SimpleLoggerWriter(new SystemEnvironmentVariables()))
656+
};
657+
658+
var invocationResponse1 = await handlerWrapper.Handler(invocation1);
659+
660+
var invocation2 = new InvocationRequest
661+
{
662+
InputStream = new MemoryStream(UTF8Encoding.UTF8.GetBytes("\"World\"")),
663+
LambdaContext = new LambdaContext(_runtimeApiHeaders, _lambdaEnvironment, new Helpers.SimpleLoggerWriter(new SystemEnvironmentVariables()))
664+
};
665+
666+
var invocationResponse2 = await handlerWrapper.Handler(invocation2);
667+
if (onDemand)
668+
Assert.True(object.ReferenceEquals(invocationResponse1.OutputStream, invocationResponse2.OutputStream));
669+
else
670+
Assert.False(object.ReferenceEquals(invocationResponse1.OutputStream, invocationResponse2.OutputStream));
671+
}
672+
}
673+
finally
674+
{
675+
Environment.SetEnvironmentVariable(Constants.ENV_VAR_AWS_LAMBDA_MAX_CONCURRENCY, null);
676+
}
677+
}
678+
636679
private async Task TestHandlerWrapper(HandlerWrapper handlerWrapper, byte[] input, byte[] expectedOutput, bool expectedDisposeOutputStream)
637680
{
638681
// run twice to make sure wrappers that reuse the output stream work correctly

0 commit comments

Comments
 (0)