Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion shortcuts/common/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,13 @@ func (ctx *RuntimeContext) DoAPIStream(callCtx context.Context, req *larkcore.Ap
option.Header = make(http.Header)
}
if shortcutHeaders := cmdutil.ShortcutHeaderOpts(ctx.ctx); shortcutHeaders != nil {
shortcutHeaders(&option)
var shortcutOption larkcore.RequestOption
shortcutHeaders(&shortcutOption)
for key, values := range shortcutOption.Header {
for _, value := range values {
option.Header.Add(key, value)
}
}
}

accessToken, err := ctx.AccessToken()
Expand Down
251 changes: 251 additions & 0 deletions shortcuts/im/helpers_network_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@ package im
import (
"bytes"
"context"
"crypto/md5"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"reflect"
"strconv"
"strings"
"testing"
"unsafe"
Expand Down Expand Up @@ -289,6 +291,9 @@ func TestDownloadIMResourceToPathSuccess(t *testing.T) {
if gotHeaders.Get(cmdutil.HeaderExecutionId) != "exec-123" {
t.Fatalf("%s = %q, want %q", cmdutil.HeaderExecutionId, gotHeaders.Get(cmdutil.HeaderExecutionId), "exec-123")
}
if gotHeaders.Get("Range") != fmt.Sprintf("bytes=0-%d", probeChunkSize-1) {
t.Fatalf("Range header = %q, want %q", gotHeaders.Get("Range"), fmt.Sprintf("bytes=0-%d", probeChunkSize-1))
}
}

func TestDownloadIMResourceToPathHTTPErrorBody(t *testing.T) {
Expand All @@ -313,6 +318,252 @@ func TestDownloadIMResourceToPathHTTPErrorBody(t *testing.T) {
}
}

func TestDownloadIMResourceToPathRetriesNetworkError(t *testing.T) {
attempts := 0
payload := []byte("retry success")
runtime := newBotShortcutRuntime(t, shortcutRoundTripFunc(func(req *http.Request) (*http.Response, error) {
switch {
case strings.Contains(req.URL.Path, "tenant_access_token"):
return shortcutJSONResponse(200, map[string]interface{}{
"code": 0,
"tenant_access_token": "tenant-token",
"expire": 7200,
}), nil
case strings.Contains(req.URL.Path, "/open-apis/im/v1/messages/om_retry/resources/file_retry"):
attempts++
if attempts < 3 {
return nil, fmt.Errorf("temporary network failure")
}
return shortcutRawResponse(200, payload, http.Header{"Content-Type": []string{"application/octet-stream"}}), nil
default:
return nil, fmt.Errorf("unexpected request: %s", req.URL.String())
}
}))

target := filepath.Join(t.TempDir(), "out.bin")
_, size, err := downloadIMResourceToPath(context.Background(), runtime, "om_retry", "file_retry", "file", target)
if err != nil {
t.Fatalf("downloadIMResourceToPath() error = %v", err)
}
if attempts != 3 {
t.Fatalf("download attempts = %d, want 3", attempts)
}
if size != int64(len(payload)) {
t.Fatalf("downloadIMResourceToPath() size = %d, want %d", size, len(payload))
}
}

func TestDownloadIMResourceToPathRetrySecondAttemptSuccess(t *testing.T) {
attempts := 0
payload := []byte("second retry success")
runtime := newBotShortcutRuntime(t, shortcutRoundTripFunc(func(req *http.Request) (*http.Response, error) {
switch {
case strings.Contains(req.URL.Path, "tenant_access_token"):
return shortcutJSONResponse(200, map[string]interface{}{
"code": 0,
"tenant_access_token": "tenant-token",
"expire": 7200,
}), nil
case strings.Contains(req.URL.Path, "/open-apis/im/v1/messages/om_retry2/resources/file_retry2"):
attempts++
if attempts < 2 {
return nil, fmt.Errorf("temporary network failure")
}
return shortcutRawResponse(200, payload, http.Header{"Content-Type": []string{"application/octet-stream"}}), nil
default:
return nil, fmt.Errorf("unexpected request: %s", req.URL.String())
}
}))

target := filepath.Join(t.TempDir(), "out.bin")
_, size, err := downloadIMResourceToPath(context.Background(), runtime, "om_retry2", "file_retry2", "file", target)
if err != nil {
t.Fatalf("downloadIMResourceToPath() error = %v", err)
}
if attempts != 2 {
t.Fatalf("download attempts = %d, want 2", attempts)
}
if size != int64(len(payload)) {
t.Fatalf("downloadIMResourceToPath() size = %d, want %d", size, len(payload))
}
}

func TestDownloadIMResourceToPathRetryContextCanceled(t *testing.T) {
attempts := 0
runtime := newBotShortcutRuntime(t, shortcutRoundTripFunc(func(req *http.Request) (*http.Response, error) {
switch {
case strings.Contains(req.URL.Path, "tenant_access_token"):
return shortcutJSONResponse(200, map[string]interface{}{
"code": 0,
"tenant_access_token": "tenant-token",
"expire": 7200,
}), nil
case strings.Contains(req.URL.Path, "/open-apis/im/v1/messages/om_cancel/resources/file_cancel"):
attempts++
return nil, fmt.Errorf("temporary network failure")
default:
return nil, fmt.Errorf("unexpected request: %s", req.URL.String())
}
}))

ctx, cancel := context.WithCancel(context.Background())
// Cancel context immediately to trigger context error on first retry
cancel()

target := filepath.Join(t.TempDir(), "out.bin")
_, _, err := downloadIMResourceToPath(ctx, runtime, "om_cancel", "file_cancel", "file", target)
if err != context.Canceled {
t.Fatalf("downloadIMResourceToPath() error = %v, want context.Canceled", err)
}
// First attempt is made, then retry checks ctx.Err() and returns
if attempts != 1 {
t.Fatalf("download attempts = %d, want 1", attempts)
}
}

func TestDownloadIMResourceToPathRangeDownload(t *testing.T) {
cases := []struct {
name string
payloadLen int64
wantRanges []string
}{
{
name: "single small chunk",
payloadLen: 16,
wantRanges: []string{"bytes=0-131071"},
},
{
name: "exact probe chunk",
payloadLen: probeChunkSize,
wantRanges: []string{"bytes=0-131071"},
},
{
name: "multiple chunks with tail",
payloadLen: probeChunkSize + normalChunkSize + 1234,
wantRanges: []string{
"bytes=0-131071",
fmt.Sprintf("bytes=%d-%d", probeChunkSize, probeChunkSize+normalChunkSize-1),
fmt.Sprintf("bytes=%d-%d", probeChunkSize+normalChunkSize, probeChunkSize+normalChunkSize+1233),
},
},
{
name: "multiple chunks exact 8mb tail",
payloadLen: probeChunkSize + 2*normalChunkSize,
wantRanges: []string{
"bytes=0-131071",
fmt.Sprintf("bytes=%d-%d", probeChunkSize, probeChunkSize+normalChunkSize-1),
fmt.Sprintf("bytes=%d-%d", probeChunkSize+normalChunkSize, probeChunkSize+2*normalChunkSize-1),
},
},
}

for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
payload := bytes.Repeat([]byte("range-download-"), int(tt.payloadLen/15)+1)
payload = payload[:tt.payloadLen]

var gotRanges []string
runtime := newBotShortcutRuntime(t, shortcutRoundTripFunc(func(req *http.Request) (*http.Response, error) {
switch {
case strings.Contains(req.URL.Path, "tenant_access_token"):
return shortcutJSONResponse(200, map[string]interface{}{
"code": 0,
"tenant_access_token": "tenant-token",
"expire": 7200,
}), nil
case strings.Contains(req.URL.Path, "/open-apis/im/v1/messages/om_range/resources/file_range"):
rangeHeader := req.Header.Get("Range")
gotRanges = append(gotRanges, rangeHeader)
if req.Header.Get("Authorization") != "Bearer tenant-token" {
return nil, fmt.Errorf("missing authorization header")
}
start, end, err := parseRangeHeader(rangeHeader, int64(len(payload)))
if err != nil {
return nil, err
}
return shortcutRawResponse(http.StatusPartialContent, payload[start:end+1], http.Header{
"Content-Type": []string{"application/octet-stream"},
"Content-Range": []string{fmt.Sprintf("bytes %d-%d/%d", start, end, len(payload))},
}), nil
default:
return nil, fmt.Errorf("unexpected request: %s", req.URL.String())
}
}))

target := filepath.Join(t.TempDir(), "nested", "resource.bin")
_, size, err := downloadIMResourceToPath(context.Background(), runtime, "om_range", "file_range", "file", target)
if err != nil {
t.Fatalf("downloadIMResourceToPath() error = %v", err)
}
if size != int64(len(payload)) {
t.Fatalf("downloadIMResourceToPath() size = %d, want %d", size, len(payload))
}
if !reflect.DeepEqual(gotRanges, tt.wantRanges) {
t.Fatalf("Range requests = %#v, want %#v", gotRanges, tt.wantRanges)
}

got, err := os.ReadFile(target)
if err != nil {
t.Fatalf("ReadFile() error = %v", err)
}
if md5.Sum(got) != md5.Sum(payload) {
t.Fatalf("downloaded payload MD5 = %x, want %x", md5.Sum(got), md5.Sum(payload))
}
})
}
}

func TestDownloadIMResourceToPathInvalidContentRange(t *testing.T) {
runtime := newBotShortcutRuntime(t, shortcutRoundTripFunc(func(req *http.Request) (*http.Response, error) {
switch {
case strings.Contains(req.URL.Path, "tenant_access_token"):
return shortcutJSONResponse(200, map[string]interface{}{
"code": 0,
"tenant_access_token": "tenant-token",
"expire": 7200,
}), nil
case strings.Contains(req.URL.Path, "/open-apis/im/v1/messages/om_bad/resources/file_bad"):
return shortcutRawResponse(http.StatusPartialContent, []byte("bad"), http.Header{
"Content-Type": []string{"application/octet-stream"},
"Content-Range": []string{"bytes 0-2/not-a-number"},
}), nil
default:
return nil, fmt.Errorf("unexpected request: %s", req.URL.String())
}
}))

_, _, err := downloadIMResourceToPath(context.Background(), runtime, "om_bad", "file_bad", "file", filepath.Join(t.TempDir(), "out.bin"))
if err == nil || !strings.Contains(err.Error(), "invalid Content-Range header") {
t.Fatalf("downloadIMResourceToPath() error = %v", err)
}
}

func parseRangeHeader(header string, totalSize int64) (int64, int64, error) {
if !strings.HasPrefix(header, "bytes=") {
return 0, 0, fmt.Errorf("unexpected range header: %q", header)
}
parts := strings.SplitN(strings.TrimPrefix(header, "bytes="), "-", 2)
if len(parts) != 2 {
return 0, 0, fmt.Errorf("unexpected range header: %q", header)
}

start, err := strconv.ParseInt(parts[0], 10, 64)
if err != nil {
return 0, 0, fmt.Errorf("parse start: %w", err)
}
end, err := strconv.ParseInt(parts[1], 10, 64)
if err != nil {
return 0, 0, fmt.Errorf("parse end: %w", err)
}
if start < 0 || end < start || start >= totalSize {
return 0, 0, fmt.Errorf("invalid range bounds: %d-%d for size %d", start, end, totalSize)
}
if end >= totalSize {
end = totalSize - 1
}
return start, end, nil
}

func TestUploadImageToIMSuccess(t *testing.T) {
var gotBody string
runtime := newBotShortcutRuntime(t, shortcutRoundTripFunc(func(req *http.Request) (*http.Response, error) {
Expand Down
37 changes: 37 additions & 0 deletions shortcuts/im/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,43 @@ func TestDownloadIMResourceToPathHTTPClientError(t *testing.T) {
}
}

func TestParseTotalSize(t *testing.T) {
tests := []struct {
name string
contentRange string
want int64
wantErr string
}{
{name: "normal", contentRange: "bytes 0-131071/104857600", want: 104857600},
{name: "single probe chunk", contentRange: "bytes 0-131071/131072", want: 131072},
{name: "single small chunk", contentRange: "bytes 0-15/16", want: 16},
{name: "empty", contentRange: "", wantErr: "content-range is empty"},
{name: "invalid prefix", contentRange: "items 0-15/16", wantErr: `unsupported content-range: "items 0-15/16"`},
{name: "missing total", contentRange: "bytes 0-15/", wantErr: `unsupported content-range: "bytes 0-15/"`},
{name: "wildcard", contentRange: "bytes */16", wantErr: `unsupported content-range: "bytes */16"`},
{name: "unknown total size", contentRange: "bytes 0-99/*", wantErr: `unknown total size in content-range: "bytes 0-99/*"`},
{name: "invalid total", contentRange: "bytes 0-15/not-a-number", wantErr: "parse total size:"},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := parseTotalSize(tt.contentRange)
if tt.wantErr != "" {
if err == nil || !strings.Contains(err.Error(), tt.wantErr) {
t.Fatalf("parseTotalSize() error = %v, want substring %q", err, tt.wantErr)
}
return
}
if err != nil {
t.Fatalf("parseTotalSize() unexpected error = %v", err)
}
if got != tt.want {
t.Fatalf("parseTotalSize() = %d, want %d", got, tt.want)
}
})
}
}

func TestShortcuts(t *testing.T) {
var commands []string
for _, shortcut := range Shortcuts() {
Expand Down
Loading
Loading