Skip to content

Commit 9d4fa97

Browse files
authored
Add a Runtime API client (#298)
* Add a Runtime API client * make the new client function private * fix typo * Update entry.go * Update invoke_loop.go * Update invoke_loop_test.go * Update invoke_loop_test.go * Update runtime_api_client_test.go * fix StartWithContext, add lambda/entry_test.go * appease errcheck
1 parent 26aa364 commit 9d4fa97

File tree

12 files changed

+731
-48
lines changed

12 files changed

+731
-48
lines changed

go.sum

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ github.com/russross/blackfriday/v2 v2.0.1 h1:lPqVAte+HuHNfhJ/0LC98ESWRz8afy9tM/0
99
github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
1010
github.com/shurcooL/sanitized_anchor_name v1.0.0 h1:PdmoCO6wvbs+7yrJyMORt4/BmY5IYyJwS/kOiWx8mHo=
1111
github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc=
12+
github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4=
1213
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
1314
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
1415
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=

lambda/entry.go

Lines changed: 37 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,8 @@ package lambda
44

55
import (
66
"context"
7+
"errors"
78
"log"
8-
"net"
9-
"net/rpc"
109
"os"
1110
)
1211

@@ -56,23 +55,47 @@ func StartHandler(handler Handler) {
5655
StartHandlerWithContext(context.Background(), handler)
5756
}
5857

58+
type startFunction struct {
59+
env string
60+
f func(ctx context.Context, envValue string, hander Handler) error
61+
}
62+
63+
var (
64+
// This allow users to save a little bit of coldstart time in the download, by the dependecies brought in for RPC support.
65+
// The tradeoff is dropping compatibility with the go1.x runtime, functions must be "Custom Runtime" instead.
66+
// To drop the rpc dependecies, compile with `-tags lambda.norpc`
67+
rpcStartFunction = &startFunction{
68+
env: "_LAMBDA_SERVER_PORT",
69+
f: func(c context.Context, p string, h Handler) error {
70+
return errors.New("_LAMBDA_SERVER_PORT was present but the function was compiled without RPC support")
71+
},
72+
}
73+
runtimeAPIStartFunction = &startFunction{
74+
env: "AWS_LAMBDA_RUNTIME_API",
75+
f: startRuntimeAPILoop,
76+
}
77+
startFunctions = []*startFunction{rpcStartFunction, runtimeAPIStartFunction}
78+
79+
// This allows end to end testing of the Start functions, by tests overwriting this function to keep the program alive
80+
logFatalf = log.Fatalf
81+
)
82+
5983
// StartHandlerWithContext is the same as StartHandler except sets the base context for the function.
6084
//
6185
// Handler implementation requires a single "Invoke()" function:
6286
//
6387
// func Invoke(context.Context, []byte) ([]byte, error)
6488
func StartHandlerWithContext(ctx context.Context, handler Handler) {
65-
port := os.Getenv("_LAMBDA_SERVER_PORT")
66-
lis, err := net.Listen("tcp", "localhost:"+port)
67-
if err != nil {
68-
log.Fatal(err)
89+
var keys []string
90+
for _, start := range startFunctions {
91+
config := os.Getenv(start.env)
92+
if config != "" {
93+
// in normal operation, the start function never returns
94+
// if it does, exit!, this triggers a restart of the lambda function
95+
err := start.f(ctx, config, handler)
96+
logFatalf("%v", err)
97+
}
98+
keys = append(keys, start.env)
6999
}
70-
71-
fn := NewFunction(handler).withContext(ctx)
72-
if err := rpc.Register(fn); err != nil {
73-
log.Fatal("failed to register handler function")
74-
}
75-
76-
rpc.Accept(lis)
77-
log.Fatal("accept should not have returned")
100+
logFatalf("expected AWS Lambda environment variables %s are not defined", keys)
78101
}

lambda/entry_test.go

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
// Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved
2+
3+
package lambda
4+
5+
import (
6+
"context"
7+
"fmt"
8+
"log"
9+
"net"
10+
"net/rpc"
11+
"os"
12+
"strings"
13+
"testing"
14+
15+
"github.com/aws/aws-lambda-go/lambda/messages"
16+
"github.com/stretchr/testify/assert"
17+
)
18+
19+
func TestStartRuntimeAPIWithContext(t *testing.T) {
20+
server, _ := runtimeAPIServer("null", 1) // serve a single invoke, and then cause an internal error
21+
expected := "expected"
22+
actual := "unexpected"
23+
24+
os.Setenv("AWS_LAMBDA_RUNTIME_API", strings.Split(server.URL, "://")[1])
25+
defer os.Unsetenv("AWS_LAMBDA_RUNTIME_API")
26+
logFatalf = func(format string, v ...interface{}) {}
27+
defer func() { logFatalf = log.Fatalf }()
28+
29+
StartWithContext(context.WithValue(context.Background(), "key", expected), func(ctx context.Context) error {
30+
actual, _ = ctx.Value("key").(string)
31+
return nil
32+
})
33+
34+
assert.Equal(t, expected, actual)
35+
}
36+
37+
func TestStartRPCWithContext(t *testing.T) {
38+
expected := "expected"
39+
actual := "unexpected"
40+
port := getFreeTCPPort()
41+
os.Setenv("_LAMBDA_SERVER_PORT", fmt.Sprintf("%d", port))
42+
defer os.Unsetenv("_LAMBDA_SERVER_PORT")
43+
go StartWithContext(context.WithValue(context.Background(), "key", expected), func(ctx context.Context) error {
44+
actual, _ = ctx.Value("key").(string)
45+
return nil
46+
})
47+
48+
var client *rpc.Client
49+
var pingResponse messages.PingResponse
50+
var invokeResponse messages.InvokeResponse
51+
var err error
52+
for {
53+
client, err = rpc.Dial("tcp", fmt.Sprintf("localhost:%d", port))
54+
if err != nil {
55+
continue
56+
}
57+
break
58+
}
59+
for {
60+
if err := client.Call("Function.Ping", &messages.PingRequest{}, &pingResponse); err != nil {
61+
continue
62+
}
63+
break
64+
}
65+
if err := client.Call("Function.Invoke", &messages.InvokeRequest{}, &invokeResponse); err != nil {
66+
t.Logf("error invoking function: %v", err)
67+
}
68+
69+
assert.Equal(t, expected, actual)
70+
}
71+
72+
func getFreeTCPPort() int {
73+
l, err := net.Listen("tcp", "localhost:0")
74+
if err != nil {
75+
log.Fatal("getFreeTCPPort failed: ", err)
76+
}
77+
defer l.Close()
78+
79+
return l.Addr().(*net.TCPAddr).Port
80+
}
81+
82+
func TestStartNotInLambda(t *testing.T) {
83+
actual := "unexpected"
84+
logFatalf = func(format string, v ...interface{}) {
85+
actual = fmt.Sprintf(format, v...)
86+
}
87+
88+
Start(func() error { return nil })
89+
assert.Equal(t, "expected AWS Lambda environment variables [_LAMBDA_SERVER_PORT AWS_LAMBDA_RUNTIME_API] are not defined", actual)
90+
}

lambda/errors.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved
2+
3+
package lambda
4+
5+
import (
6+
"reflect"
7+
8+
"github.com/aws/aws-lambda-go/lambda/messages"
9+
)
10+
11+
func getErrorType(err interface{}) string {
12+
errorType := reflect.TypeOf(err)
13+
if errorType.Kind() == reflect.Ptr {
14+
return errorType.Elem().Name()
15+
}
16+
return errorType.Name()
17+
}
18+
19+
func lambdaErrorResponse(invokeError error) *messages.InvokeResponse_Error {
20+
var errorName string
21+
if errorType := reflect.TypeOf(invokeError); errorType.Kind() == reflect.Ptr {
22+
errorName = errorType.Elem().Name()
23+
} else {
24+
errorName = errorType.Name()
25+
}
26+
return &messages.InvokeResponse_Error{
27+
Message: invokeError.Error(),
28+
Type: errorName,
29+
}
30+
}
31+
32+
func lambdaPanicResponse(err interface{}) *messages.InvokeResponse_Error {
33+
panicInfo := getPanicInfo(err)
34+
return &messages.InvokeResponse_Error{
35+
Message: panicInfo.Message,
36+
Type: getErrorType(err),
37+
StackTrace: panicInfo.StackTrace,
38+
ShouldExit: true,
39+
}
40+
}

lambda/function.go

Lines changed: 1 addition & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import (
66
"context"
77
"encoding/json"
88
"os"
9-
"reflect"
109
"time"
1110

1211
"github.com/aws/aws-lambda-go/lambda/messages"
@@ -34,13 +33,7 @@ func (fn *Function) Ping(req *messages.PingRequest, response *messages.PingRespo
3433
func (fn *Function) Invoke(req *messages.InvokeRequest, response *messages.InvokeResponse) error {
3534
defer func() {
3635
if err := recover(); err != nil {
37-
panicInfo := getPanicInfo(err)
38-
response.Error = &messages.InvokeResponse_Error{
39-
Message: panicInfo.Message,
40-
Type: getErrorType(err),
41-
StackTrace: panicInfo.StackTrace,
42-
ShouldExit: true,
43-
}
36+
response.Error = lambdaPanicResponse(err)
4437
}
4538
}()
4639

@@ -99,24 +92,3 @@ func (fn *Function) withContext(ctx context.Context) *Function {
9992

10093
return fn2
10194
}
102-
103-
func getErrorType(err interface{}) string {
104-
errorType := reflect.TypeOf(err)
105-
if errorType.Kind() == reflect.Ptr {
106-
return errorType.Elem().Name()
107-
}
108-
return errorType.Name()
109-
}
110-
111-
func lambdaErrorResponse(invokeError error) *messages.InvokeResponse_Error {
112-
var errorName string
113-
if errorType := reflect.TypeOf(invokeError); errorType.Kind() == reflect.Ptr {
114-
errorName = errorType.Elem().Name()
115-
} else {
116-
errorName = errorType.Name()
117-
}
118-
return &messages.InvokeResponse_Error{
119-
Message: invokeError.Error(),
120-
Type: errorName,
121-
}
122-
}

lambda/handler_test.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,6 @@ func TestInvalidJsonInput(t *testing.T) {
207207
lambdaHandler := NewHandler(func(s string) error { return nil })
208208
_, err := lambdaHandler.Invoke(context.TODO(), []byte(`{"invalid json`))
209209
assert.Equal(t, "unexpected end of JSON input", err.Error())
210-
211210
}
212211

213212
func TestHandlerTrace(t *testing.T) {

lambda/invoke_loop.go

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
// Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved
2+
3+
package lambda
4+
5+
import (
6+
"context"
7+
"encoding/json"
8+
"fmt"
9+
"strconv"
10+
"time"
11+
12+
"github.com/aws/aws-lambda-go/lambda/messages"
13+
)
14+
15+
const (
16+
serializationErrorFormat = `{"errorType": "Runtime.SerializationError", "errorMessage": "%s"}`
17+
msPerS = int64(time.Second / time.Millisecond)
18+
nsPerMS = int64(time.Millisecond / time.Nanosecond)
19+
)
20+
21+
// startRuntimeAPILoop will return an error if handling a particular invoke resulted in a non-recoverable error
22+
func startRuntimeAPILoop(ctx context.Context, api string, handler Handler) error {
23+
client := newRuntimeAPIClient(api)
24+
function := NewFunction(handler).withContext(ctx)
25+
for {
26+
invoke, err := client.next()
27+
if err != nil {
28+
return err
29+
}
30+
31+
err = handleInvoke(invoke, function)
32+
if err != nil {
33+
return err
34+
}
35+
}
36+
}
37+
38+
// handleInvoke returns an error if the function panics, or some other non-recoverable error occurred
39+
func handleInvoke(invoke *invoke, function *Function) error {
40+
functionRequest, err := convertInvokeRequest(invoke)
41+
if err != nil {
42+
return fmt.Errorf("unexpected error occured when parsing the invoke: %v", err)
43+
}
44+
45+
functionResponse := &messages.InvokeResponse{}
46+
if err := function.Invoke(functionRequest, functionResponse); err != nil {
47+
return fmt.Errorf("unexpected error occured when invoking the handler: %v", err)
48+
}
49+
50+
if functionResponse.Error != nil {
51+
payload := safeMarshal(functionResponse.Error)
52+
if err := invoke.failure(payload, contentTypeJSON); err != nil {
53+
return fmt.Errorf("unexpected error occured when sending the function error to the API: %v", err)
54+
}
55+
if functionResponse.Error.ShouldExit {
56+
return fmt.Errorf("calling the handler function resulted in a panic, the process should exit")
57+
}
58+
return nil
59+
}
60+
61+
if err := invoke.success(functionResponse.Payload, contentTypeJSON); err != nil {
62+
return fmt.Errorf("unexpected error occured when sending the function functionResponse to the API: %v", err)
63+
}
64+
65+
return nil
66+
}
67+
68+
// convertInvokeRequest converts an invoke from the Runtime API, and unpacks it to be compatible with the shape of a `lambda.Function` InvokeRequest.
69+
func convertInvokeRequest(invoke *invoke) (*messages.InvokeRequest, error) {
70+
deadlineEpochMS, err := strconv.ParseInt(invoke.headers.Get(headerDeadlineMS), 10, 64)
71+
if err != nil {
72+
return nil, fmt.Errorf("failed to parse contents of header: %s", headerDeadlineMS)
73+
}
74+
deadlineS := deadlineEpochMS / msPerS
75+
deadlineNS := (deadlineEpochMS % msPerS) * nsPerMS
76+
77+
res := &messages.InvokeRequest{
78+
InvokedFunctionArn: invoke.headers.Get(headerInvokedFunctionARN),
79+
XAmznTraceId: invoke.headers.Get(headerTraceID),
80+
Deadline: messages.InvokeRequest_Timestamp{
81+
Seconds: deadlineS,
82+
Nanos: deadlineNS,
83+
},
84+
Payload: invoke.payload,
85+
}
86+
87+
clientContextJSON := invoke.headers.Get(headerClientContext)
88+
if clientContextJSON != "" {
89+
res.ClientContext = []byte(clientContextJSON)
90+
}
91+
92+
cognitoIdentityJSON := invoke.headers.Get(headerCognitoIdentity)
93+
if cognitoIdentityJSON != "" {
94+
if err := json.Unmarshal([]byte(invoke.headers.Get(headerCognitoIdentity)), res); err != nil {
95+
return nil, fmt.Errorf("failed to unmarshal cognito identity json: %v", err)
96+
}
97+
}
98+
99+
return res, nil
100+
}
101+
102+
func safeMarshal(v interface{}) []byte {
103+
payload, err := json.Marshal(v)
104+
if err != nil {
105+
return []byte(fmt.Sprintf(serializationErrorFormat, err.Error()))
106+
}
107+
return payload
108+
}

0 commit comments

Comments
 (0)