@@ -207,15 +207,10 @@ func visitPayloads(
207207 parent proto.Message,
208208 objs ...interface{},
209209) error {
210- for i , obj := range objs {
210+ for _ , obj := range objs {
211211 ctx.SinglePayloadRequired = false
212212
213213 switch o := obj.(type) {
214- case *common.Payload:
215- if o == nil { continue }
216- no, err := visitPayload(ctx, options, parent, o)
217- if err != nil { return err }
218- objs[i] = no
219214 case map[string]*common.Payload:
220215 for ix, x := range o {
221216 if nx, err := visitPayload(ctx, options, parent, x); err != nil {
@@ -273,6 +268,14 @@ func visitPayloads(
273268 if options.SkipSearchAttributes { continue }
274269 {{end}}
275270 if o == nil { continue }
271+ {{range $record.Payloads -}}
272+ if o.{{.}} != nil {
273+ no, err := visitPayload(ctx, options, o, o.{{.}})
274+ if err != nil { return err }
275+ o.{{.}} = no
276+ }
277+ {{end}}
278+ {{if $record.Methods}}
276279 if err := visitPayloads(
277280 ctx,
278281 options,
@@ -281,6 +284,7 @@ func visitPayloads(
281284 o.{{.}}(),
282285 {{end}}
283286 ); err != nil { return err }
287+ {{end}}
284288{{end}}
285289 }
286290 }
@@ -326,10 +330,11 @@ var interceptorTemplate = template.Must(template.New("interceptor").Parse(Interc
326330
327331// TypeRecord holds the state for a type referred to by the workflow service
328332type TypeRecord struct {
329- Methods []string // List of methods on this type that can eventually lead to Payload(s)
330- Slice bool // The API refers to slices of this type
331- Map bool // The API refers to maps with this type as the value
332- Matches bool // We found methods on this type that can eventually lead to Payload(s)
333+ Methods []string // List of methods on this type that can eventually lead to Payload(s)
334+ Payloads []string // List of attributes on this type that are of type Payload
335+ Slice bool // The API refers to slices of this type
336+ Map bool // The API refers to maps with this type as the value
337+ Matches bool // We found methods on this type that can eventually lead to Payload(s)
333338}
334339
335340// isSlice returns true if a type is slice, false otherwise
@@ -411,12 +416,12 @@ func pruneRecords(input map[string]*TypeRecord) map[string]*TypeRecord {
411416
412417// walk iterates the methods on a type and returns whether any of them can eventually lead to Payload(s)
413418// The return type for each method on this type is walked recursively to decide which methods can lead to Payload(s)
414- func walk (desired []types.Type , typ types.Type , records * map [string ]* TypeRecord ) bool {
419+ func walk (desired []types.Type , typ types.Type , records * map [string ]* TypeRecord , checkDirectPayload bool ) bool {
415420 typeName := typeName (typ )
416421
417422 // If this type is a slice then walk the underlying type and then make a note we need to encode slices of this type
418423 if isSlice (typ ) {
419- result := walk (desired , elemType (typ ), records )
424+ result := walk (desired , elemType (typ ), records , checkDirectPayload )
420425 if result {
421426 record := (* records )[typeName ]
422427 record .Slice = true
@@ -426,7 +431,7 @@ func walk(desired []types.Type, typ types.Type, records *map[string]*TypeRecord)
426431
427432 // If this type is a map then walk the underlying type and then make a note we need to encode maps with values of this type
428433 if isMap (typ ) {
429- result := walk (desired , elemType (typ ), records )
434+ result := walk (desired , elemType (typ ), records , checkDirectPayload )
430435 if result {
431436 record := (* records )[typeName ]
432437 record .Map = true
@@ -459,8 +464,18 @@ func walk(desired []types.Type, typ types.Type, records *map[string]*TypeRecord)
459464 // All the Get... methods return the relevant protobuf as the first result
460465 resultType := sig .Results ().At (0 ).Type ()
461466
467+ if checkDirectPayload && resultType .String () == "*go.temporal.io/api/common/v1.Payload" {
468+ record .Matches = true
469+ prefix , ok := strings .CutPrefix (methodName , "Get" )
470+ if ! ok {
471+ panic (fmt .Errorf ("expected method to have a Get prefix: %s" , methodName ))
472+ }
473+ record .Payloads = append (record .Payloads , prefix )
474+ continue
475+ }
476+
462477 // Check if this method returns a Payload(s) or if it leads (eventually) to a Type which refers to a Payload(s)
463- if typeMatches (resultType , desired ... ) || walk (desired , resultType , records ) {
478+ if typeMatches (resultType , desired ... ) || walk (desired , resultType , records , checkDirectPayload ) {
464479 record .Matches = true
465480 record .Methods = append (record .Methods , methodName )
466481 }
@@ -536,10 +551,10 @@ func generateInterceptor(cfg config) error {
536551 }
537552
538553 sig := meth .Obj ().Type ().(* types.Signature )
539- walk (payloadTypes , sig .Params ().At (1 ).Type (), & payloadRecords )
540- walk (failureTypes , sig .Params ().At (1 ).Type (), & failureRecords )
541- walk (payloadTypes , sig .Results ().At (0 ).Type (), & payloadRecords )
542- walk (failureTypes , sig .Results ().At (0 ).Type (), & failureRecords )
554+ walk (payloadTypes , sig .Params ().At (1 ).Type (), & payloadRecords , true )
555+ walk (failureTypes , sig .Params ().At (1 ).Type (), & failureRecords , false )
556+ walk (payloadTypes , sig .Results ().At (0 ).Type (), & payloadRecords , true )
557+ walk (failureTypes , sig .Results ().At (0 ).Type (), & failureRecords , false )
543558 }
544559
545560 for _ , meth := range typeutil .IntuitiveMethodSet (operatorService , nil ) {
@@ -548,14 +563,14 @@ func generateInterceptor(cfg config) error {
548563 }
549564
550565 sig := meth .Obj ().Type ().(* types.Signature )
551- walk (payloadTypes , sig .Params ().At (1 ).Type (), & payloadRecords )
552- walk (failureTypes , sig .Params ().At (1 ).Type (), & failureRecords )
553- walk (payloadTypes , sig .Results ().At (0 ).Type (), & payloadRecords )
554- walk (failureTypes , sig .Results ().At (0 ).Type (), & failureRecords )
566+ walk (payloadTypes , sig .Params ().At (1 ).Type (), & payloadRecords , true )
567+ walk (failureTypes , sig .Params ().At (1 ).Type (), & failureRecords , false )
568+ walk (payloadTypes , sig .Results ().At (0 ).Type (), & payloadRecords , true )
569+ walk (failureTypes , sig .Results ().At (0 ).Type (), & failureRecords , false )
555570 }
556571
557- walk (payloadTypes , workflowExecutions , & payloadRecords )
558- walk (failureTypes , workflowExecutions , & failureRecords )
572+ walk (payloadTypes , workflowExecutions , & payloadRecords , true )
573+ walk (failureTypes , workflowExecutions , & failureRecords , false )
559574
560575 payloadRecords = pruneRecords (payloadRecords )
561576 failureRecords = pruneRecords (failureRecords )
0 commit comments