From 849d6c7b79e329fd82ad8bb4be5bd00207ee6cb3 Mon Sep 17 00:00:00 2001 From: Zack Zeng Date: Tue, 11 Nov 2025 15:56:09 -0800 Subject: [PATCH] refactor: reduce code duplication through cleanup Signed-off-by: Zack Zeng --- client.go | 3 +-- multipart.go | 46 ++++++++++++++++++++++++++++++++++++++++++++++ request.go | 2 +- util.go | 44 -------------------------------------------- util_test.go | 12 ++++++++++-- 5 files changed, 58 insertions(+), 49 deletions(-) diff --git a/client.go b/client.go index 2b88d7f..a1fe26c 100644 --- a/client.go +++ b/client.go @@ -1295,8 +1295,7 @@ func (c *Client) execute(req *Request) (*Response, error) { if err != nil { uploadErrChan <- err } - _ = mpw.mw.Close() - _ = mpw.pw.Close() + _ = mpw.Close() close(uploadErrChan) }() } else { diff --git a/multipart.go b/multipart.go index d01f589..e2e53eb 100644 --- a/multipart.go +++ b/multipart.go @@ -1,8 +1,11 @@ package resty import ( + "fmt" "io" "mime/multipart" + "net/http" + "os" ) // ‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ @@ -20,7 +23,50 @@ type MultipartField struct { io.Reader } +func (mf *MultipartField) writeToMultipartWriter(w *multipart.Writer) error { + if len(mf.FilePath) > 0 && mf.Reader == nil { + fr, err := os.Open(mf.FilePath) + if err != nil { + return err + } + mf.Reader = fr + } + r := mf.Reader + + buf := make([]byte, 32*1024) + size, err := r.Read(buf) + if err != nil && err != io.EOF { + return err + } + + if len(mf.ContentType) == 0 { + mf.ContentType = http.DetectContentType(buf[:size]) + } + + partWriter, err := w.CreatePart(createMultipartHeader(mf.Param, mf.FileName, mf.ContentType)) + if err != nil { + return err + } + + if _, err = partWriter.Write(buf[:size]); err != nil { + return err + } + + _, err = io.CopyBuffer(partWriter, r, buf) + return err +} + type multipartAndPipeWriter struct { mw *multipart.Writer pw *io.PipeWriter } + +func (m *multipartAndPipeWriter) Close() error { + if err := m.mw.Close(); err != nil { + return fmt.Errorf("close multipart writer: %w", err) + } + if err := m.pw.Close(); err != nil { + return fmt.Errorf("close pipe writer: %w", err) + } + return nil +} diff --git a/request.go b/request.go index 7147e8d..9f2da93 100644 --- a/request.go +++ b/request.go @@ -1178,7 +1178,7 @@ func (r *Request) writeMultipartFields(w *multipart.Writer) error { // GitHub #130 adding multipart field support with content type for _, mf := range r.multipartFields { - if err := addMultipartFormField(w, mf); err != nil { + if err := mf.writeToMultipartWriter(w); err != nil { return err } } diff --git a/util.go b/util.go index dd5c893..6059784 100644 --- a/util.go +++ b/util.go @@ -10,7 +10,6 @@ import ( "fmt" "io" "log" - "mime/multipart" "net/http" "net/textproto" "os" @@ -211,49 +210,6 @@ func closeFieldReaders(fields []*MultipartField) { } } -func addMultipartFormField(w *multipart.Writer, mf *MultipartField) error { - if len(mf.FilePath) > 0 && mf.Reader == nil { - fr, err := os.Open(mf.FilePath) - if err != nil { - return err - } - mf.Reader = fr - } - - if len(mf.ContentType) > 0 { - partWriter, err := w.CreatePart(createMultipartHeader(mf.Param, mf.FileName, mf.ContentType)) - if err != nil { - return err - } - - _, err = io.Copy(partWriter, mf.Reader) - return err - } - - return writeMultipartFormFile(w, mf.Param, mf.FileName, mf.Reader) -} - -func writeMultipartFormFile(w *multipart.Writer, fieldName, fileName string, r io.Reader) error { - // Auto detect actual multipart content type - cbuf := make([]byte, 512) - size, err := r.Read(cbuf) - if err != nil && err != io.EOF { - return err - } - - partWriter, err := w.CreatePart(createMultipartHeader(fieldName, fileName, http.DetectContentType(cbuf[:size]))) - if err != nil { - return err - } - - if _, err = partWriter.Write(cbuf[:size]); err != nil { - return err - } - - _, err = io.Copy(partWriter, r) - return err -} - func getPointer(v interface{}) interface{} { vv := valueOf(v) if vv.Kind() == reflect.Ptr { diff --git a/util_test.go b/util_test.go index 4d0a888..ce51160 100644 --- a/util_test.go +++ b/util_test.go @@ -80,13 +80,21 @@ func TestIsXMLType(t *testing.T) { func TestWriteMultipartFormFileReaderEmpty(t *testing.T) { w := multipart.NewWriter(bytes.NewBuffer(nil)) defer func() { _ = w.Close() }() - if err := writeMultipartFormFile(w, "foo", "bar", bytes.NewReader(nil)); err != nil { + mf := MultipartField{ + Param: "foo", + FileName: "bar", + Reader: bytes.NewReader(nil), + } + if err := mf.writeToMultipartWriter(w); err != nil { t.Errorf("Got unexpected error: %v", err) } } func TestWriteMultipartFormFileReaderError(t *testing.T) { - err := writeMultipartFormFile(nil, "", "", &brokenReadCloser{}) + mf := MultipartField{ + Reader: &brokenReadCloser{}, + } + err := mf.writeToMultipartWriter(nil) assertNotNil(t, err) assertEqual(t, "read error", err.Error()) }