Skip to content

Commit a6b8ad8

Browse files
authored
Fix proxy manipulation of single payload fields (#202)
* Fix proxy manipulation of single payload fields * Remove case for *common.Payload
1 parent 28f4dfc commit a6b8ad8

File tree

3 files changed

+253
-85
lines changed

3 files changed

+253
-85
lines changed

cmd/proxygenerator/interceptor.go

Lines changed: 39 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -207,15 +207,10 @@ func visitPayloads(
207207
parent proto.Message,
208208
objs ...interface{},
209209
) error {
210-
for i, obj := range objs {
210+
for _, obj := range objs {
211211
ctx.SinglePayloadRequired = false
212212
213213
switch o := obj.(type) {
214-
case *common.Payload:
215-
if o == nil { continue }
216-
no, err := visitPayload(ctx, options, parent, o)
217-
if err != nil { return err }
218-
objs[i] = no
219214
case map[string]*common.Payload:
220215
for ix, x := range o {
221216
if nx, err := visitPayload(ctx, options, parent, x); err != nil {
@@ -273,6 +268,14 @@ func visitPayloads(
273268
if options.SkipSearchAttributes { continue }
274269
{{end}}
275270
if o == nil { continue }
271+
{{range $record.Payloads -}}
272+
if o.{{.}} != nil {
273+
no, err := visitPayload(ctx, options, o, o.{{.}})
274+
if err != nil { return err }
275+
o.{{.}} = no
276+
}
277+
{{end}}
278+
{{if $record.Methods}}
276279
if err := visitPayloads(
277280
ctx,
278281
options,
@@ -281,6 +284,7 @@ func visitPayloads(
281284
o.{{.}}(),
282285
{{end}}
283286
); err != nil { return err }
287+
{{end}}
284288
{{end}}
285289
}
286290
}
@@ -326,10 +330,11 @@ var interceptorTemplate = template.Must(template.New("interceptor").Parse(Interc
326330

327331
// TypeRecord holds the state for a type referred to by the workflow service
328332
type TypeRecord struct {
329-
Methods []string // List of methods on this type that can eventually lead to Payload(s)
330-
Slice bool // The API refers to slices of this type
331-
Map bool // The API refers to maps with this type as the value
332-
Matches bool // We found methods on this type that can eventually lead to Payload(s)
333+
Methods []string // List of methods on this type that can eventually lead to Payload(s)
334+
Payloads []string // List of attributes on this type that are of type Payload
335+
Slice bool // The API refers to slices of this type
336+
Map bool // The API refers to maps with this type as the value
337+
Matches bool // We found methods on this type that can eventually lead to Payload(s)
333338
}
334339

335340
// isSlice returns true if a type is slice, false otherwise
@@ -411,12 +416,12 @@ func pruneRecords(input map[string]*TypeRecord) map[string]*TypeRecord {
411416

412417
// walk iterates the methods on a type and returns whether any of them can eventually lead to Payload(s)
413418
// The return type for each method on this type is walked recursively to decide which methods can lead to Payload(s)
414-
func walk(desired []types.Type, typ types.Type, records *map[string]*TypeRecord) bool {
419+
func walk(desired []types.Type, typ types.Type, records *map[string]*TypeRecord, checkDirectPayload bool) bool {
415420
typeName := typeName(typ)
416421

417422
// If this type is a slice then walk the underlying type and then make a note we need to encode slices of this type
418423
if isSlice(typ) {
419-
result := walk(desired, elemType(typ), records)
424+
result := walk(desired, elemType(typ), records, checkDirectPayload)
420425
if result {
421426
record := (*records)[typeName]
422427
record.Slice = true
@@ -426,7 +431,7 @@ func walk(desired []types.Type, typ types.Type, records *map[string]*TypeRecord)
426431

427432
// If this type is a map then walk the underlying type and then make a note we need to encode maps with values of this type
428433
if isMap(typ) {
429-
result := walk(desired, elemType(typ), records)
434+
result := walk(desired, elemType(typ), records, checkDirectPayload)
430435
if result {
431436
record := (*records)[typeName]
432437
record.Map = true
@@ -459,8 +464,18 @@ func walk(desired []types.Type, typ types.Type, records *map[string]*TypeRecord)
459464
// All the Get... methods return the relevant protobuf as the first result
460465
resultType := sig.Results().At(0).Type()
461466

467+
if checkDirectPayload && resultType.String() == "*go.temporal.io/api/common/v1.Payload" {
468+
record.Matches = true
469+
prefix, ok := strings.CutPrefix(methodName, "Get")
470+
if !ok {
471+
panic(fmt.Errorf("expected method to have a Get prefix: %s", methodName))
472+
}
473+
record.Payloads = append(record.Payloads, prefix)
474+
continue
475+
}
476+
462477
// Check if this method returns a Payload(s) or if it leads (eventually) to a Type which refers to a Payload(s)
463-
if typeMatches(resultType, desired...) || walk(desired, resultType, records) {
478+
if typeMatches(resultType, desired...) || walk(desired, resultType, records, checkDirectPayload) {
464479
record.Matches = true
465480
record.Methods = append(record.Methods, methodName)
466481
}
@@ -536,10 +551,10 @@ func generateInterceptor(cfg config) error {
536551
}
537552

538553
sig := meth.Obj().Type().(*types.Signature)
539-
walk(payloadTypes, sig.Params().At(1).Type(), &payloadRecords)
540-
walk(failureTypes, sig.Params().At(1).Type(), &failureRecords)
541-
walk(payloadTypes, sig.Results().At(0).Type(), &payloadRecords)
542-
walk(failureTypes, sig.Results().At(0).Type(), &failureRecords)
554+
walk(payloadTypes, sig.Params().At(1).Type(), &payloadRecords, true)
555+
walk(failureTypes, sig.Params().At(1).Type(), &failureRecords, false)
556+
walk(payloadTypes, sig.Results().At(0).Type(), &payloadRecords, true)
557+
walk(failureTypes, sig.Results().At(0).Type(), &failureRecords, false)
543558
}
544559

545560
for _, meth := range typeutil.IntuitiveMethodSet(operatorService, nil) {
@@ -548,14 +563,14 @@ func generateInterceptor(cfg config) error {
548563
}
549564

550565
sig := meth.Obj().Type().(*types.Signature)
551-
walk(payloadTypes, sig.Params().At(1).Type(), &payloadRecords)
552-
walk(failureTypes, sig.Params().At(1).Type(), &failureRecords)
553-
walk(payloadTypes, sig.Results().At(0).Type(), &payloadRecords)
554-
walk(failureTypes, sig.Results().At(0).Type(), &failureRecords)
566+
walk(payloadTypes, sig.Params().At(1).Type(), &payloadRecords, true)
567+
walk(failureTypes, sig.Params().At(1).Type(), &failureRecords, false)
568+
walk(payloadTypes, sig.Results().At(0).Type(), &payloadRecords, true)
569+
walk(failureTypes, sig.Results().At(0).Type(), &failureRecords, false)
555570
}
556571

557-
walk(payloadTypes, workflowExecutions, &payloadRecords)
558-
walk(failureTypes, workflowExecutions, &failureRecords)
572+
walk(payloadTypes, workflowExecutions, &payloadRecords, true)
573+
walk(failureTypes, workflowExecutions, &failureRecords, false)
559574

560575
payloadRecords = pruneRecords(payloadRecords)
561576
failureRecords = pruneRecords(failureRecords)

0 commit comments

Comments
 (0)