diff --git a/api_client.go b/api_client.go index 64571b1a..0d1bd32d 100644 --- a/api_client.go +++ b/api_client.go @@ -34,10 +34,12 @@ import ( "time" ) -const maxChunkSize = 8 * 1024 * 1024 // 8 MB chunk size -const maxRetryCount = 3 -const initialRetryDelay = time.Second -const delayMultiplier = 2 +const ( + maxChunkSize = 8 * 1024 * 1024 // 8 MB chunk size + maxRetryCount = 3 + initialRetryDelay = time.Second + delayMultiplier = 2 +) type apiClient struct { clientConfig *ClientConfig @@ -75,7 +77,6 @@ func sendStreamRequest[T responseStream[R], R any](ctx context.Context, ac *apiC // sendRequest issues an API request and returns a map of the response contents. func sendRequest(ctx context.Context, ac *apiClient, path string, method string, body map[string]any, httpOptions *HTTPOptions) (map[string]any, error) { - req, httpOptions, err := buildRequest(ctx, ac, path, body, method, httpOptions) if err != nil { return nil, err @@ -435,7 +436,7 @@ func iterateResponseStream[R any](rs *responseStream[R], responseConverter func( default: var err error if len(line) > 0 { - var respWithError = new(responseWithError) + respWithError := new(responseWithError) // Stream chunk that doesn't matches error format. if marshalErr := json.Unmarshal(line, respWithError); marshalErr != nil { err = fmt.Errorf("iterateResponseStream: invalid stream chunk: %s:%s", string(prefix), string(data)) @@ -479,7 +480,7 @@ type responseWithError struct { } func newAPIError(resp *http.Response) error { - var respWithError = new(responseWithError) + respWithError := new(responseWithError) body, err := io.ReadAll(resp.Body) if err != nil { return fmt.Errorf("newAPIError: error reading response body: %w. Response: %v", err, string(body)) @@ -561,7 +562,7 @@ func (ac *apiClient) uploadFile(ctx context.Context, r io.Reader, uploadURL stri var offset int64 = 0 var resp *http.Response var respBody map[string]any - var uploadCommand = "upload" + uploadCommand := "upload" buffer := make([]byte, maxChunkSize) for { @@ -574,7 +575,7 @@ func (ac *apiClient) uploadFile(ctx context.Context, r io.Reader, uploadURL stri } else if err != nil { return nil, fmt.Errorf("Failed to read bytes from file at offset %d: %w. Bytes actually read: %d", offset, err, bytesRead) } - for attempt := 0; attempt < maxRetryCount; attempt++ { + for attempt := range maxRetryCount { patchedHTTPOptions, err := patchHTTPOptions(ac.clientConfig.HTTPOptions, *httpOptions) if err != nil { return nil, err @@ -641,7 +642,7 @@ func (ac *apiClient) uploadFile(ctx context.Context, r io.Reader, uploadURL stri return nil, fmt.Errorf("Failed to upload file: Upload status is not finalized") } - var response = new(File) + response := new(File) err := mapToStruct(respBody["file"].(map[string]any), &response) if err != nil { return nil, err diff --git a/api_client_test.go b/api_client_test.go index 43c71e0e..0e3fb7aa 100644 --- a/api_client_test.go +++ b/api_client_test.go @@ -658,11 +658,13 @@ func TestMapToStruct(t *testing.T) { inputMap: map[string]any{ "role": "test-role", "TokenIDs": []string{"123", "456"}, - "Tokens": [][]byte{[]byte("token1"), []byte("token2")}}, + "Tokens": [][]byte{[]byte("token1"), []byte("token2")}, + }, wantValue: TokensInfo{ Role: "test-role", TokenIDs: []int64{123, 456}, - Tokens: [][]byte{[]byte("token1"), []byte("token2")}}, + Tokens: [][]byte{[]byte("token1"), []byte("token2")}, + }, }, { name: "Citation", @@ -705,7 +707,6 @@ func TestMapToStruct(t *testing.T) { outputValue := reflect.New(reflect.TypeOf(tc.wantValue)).Interface() err := mapToStruct(tc.inputMap, &outputValue) - if err != nil { t.Fatalf("mapToStruct failed: %v", err) } @@ -1309,7 +1310,7 @@ func createTestFile(t *testing.T, size int64) (string, func()) { buf := make([]byte, 1024*1024) // 1MB buffer pattern := []byte("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@#$%^&*()") - for i := 0; i < len(buf); i++ { + for i := range buf { buf[i] = pattern[i%len(pattern)] } @@ -1504,7 +1505,6 @@ func TestUploadFile(t *testing.T) { uploadURL := server.URL + "/upload" uploadedFile, err := ac.uploadFile(ctx, fileReader, uploadURL, httpOpts) - if err != nil { t.Fatalf("uploadFile failed: %v", err) } @@ -1530,7 +1530,6 @@ func TestUploadFile(t *testing.T) { if uploadedFile.MIMEType != "text/plain" { // Matches mock server response t.Errorf("uploadedFile.MIMEType mismatch: want 'text/plain', got '%s'", uploadedFile.MIMEType) } - }) } } diff --git a/base_url.go b/base_url.go index 378370c5..71a4562c 100644 --- a/base_url.go +++ b/base_url.go @@ -14,8 +14,10 @@ package genai -var defaultBaseGeminiURL string = "" -var defaultBaseVertexURL string = "" +var ( + defaultBaseGeminiURL string = "" + defaultBaseVertexURL string = "" +) // BaseURLParameters are parameters for setting the base URLs for the Gemini API and Vertex AI API. type BaseURLParameters struct { diff --git a/caches_test.go b/caches_test.go index 24536fa8..a6ffc809 100644 --- a/caches_test.go +++ b/caches_test.go @@ -91,7 +91,8 @@ func TestCachesAll(t *testing.T) { defer ts.Close() // Create a client with the test server - client, err := NewClient(context.Background(), &ClientConfig{HTTPOptions: HTTPOptions{BaseURL: ts.URL}, + client, err := NewClient(context.Background(), &ClientConfig{ + HTTPOptions: HTTPOptions{BaseURL: ts.URL}, envVarProvider: func() map[string]string { return map[string]string{ "GOOGLE_API_KEY": "test-api-key", diff --git a/chats_test.go b/chats_test.go index 7205d859..b16d4553 100644 --- a/chats_test.go +++ b/chats_test.go @@ -35,9 +35,9 @@ func TestValidateContent(t *testing.T) { {"NilContent", nil, false}, {"EmptyParts", &Content{Parts: []*Part{}}, false}, {"NilPart", &Content{Parts: []*Part{nil}}, false}, - {"EmptyTextPart", &Content{Parts: []*Part{&Part{Text: ""}}}, false}, - {"ValidTextPart", &Content{Parts: []*Part{&Part{Text: "hello"}}}, true}, - {"ValidFunctionCall", &Content{Parts: []*Part{&Part{FunctionCall: &FunctionCall{Name: "test"}}}}, true}, + {"EmptyTextPart", &Content{Parts: []*Part{{Text: ""}}}, false}, + {"ValidTextPart", &Content{Parts: []*Part{{Text: "hello"}}}, true}, + {"ValidFunctionCall", &Content{Parts: []*Part{{FunctionCall: &FunctionCall{Name: "test"}}}}, true}, } for _, tt := range tests { @@ -176,7 +176,6 @@ func TestChatsUnitTest(t *testing.T) { break } }) - } func TestChatsText(t *testing.T) { @@ -316,16 +315,16 @@ func TestChatsHistory(t *testing.T) { // Create a new Chat with handwritten history. var config *GenerateContentConfig = &GenerateContentConfig{Temperature: Ptr[float32](0.5)} history := []*Content{ - &Content{ + { Role: "user", Parts: []*Part{ - &Part{Text: "What is 1 + 2?"}, + {Text: "What is 1 + 2?"}, }, }, - &Content{ + { Role: "model", Parts: []*Part{ - &Part{Text: "3"}, + {Text: "3"}, }, }, } @@ -721,10 +720,10 @@ data:{ } var expectedResponses []*Content - expectedResponses = append(expectedResponses, &Content{Role: "model", Parts: []*Part{&Part{Text: "text1_candidate1"}}}) - expectedResponses = append(expectedResponses, &Content{Role: "model", Parts: []*Part{&Part{Text: " "}}}) - expectedResponses = append(expectedResponses, &Content{Role: "model", Parts: []*Part{&Part{Text: "text3_candidate1"}, &Part{Text: " additional text3_candidate1 "}}}) - expectedResponses = append(expectedResponses, &Content{Role: "model", Parts: []*Part{&Part{Text: "text4_candidate1"}, &Part{Text: " additional text4_candidate1"}}}) + expectedResponses = append(expectedResponses, &Content{Role: "model", Parts: []*Part{{Text: "text1_candidate1"}}}) + expectedResponses = append(expectedResponses, &Content{Role: "model", Parts: []*Part{{Text: " "}}}) + expectedResponses = append(expectedResponses, &Content{Role: "model", Parts: []*Part{{Text: "text3_candidate1"}, {Text: " additional text3_candidate1 "}}}) + expectedResponses = append(expectedResponses, &Content{Role: "model", Parts: []*Part{{Text: "text4_candidate1"}, {Text: " additional text4_candidate1"}}}) history := chat.History(false) expectedUserMessage := "What is 1 + 2?" @@ -738,6 +737,5 @@ data:{ } } } - }) } diff --git a/client_test.go b/client_test.go index 5af56153..89739923 100644 --- a/client_test.go +++ b/client_test.go @@ -28,7 +28,6 @@ import ( // TestNewClient only runs in replay mode. func TestNewClient(t *testing.T) { - ctx := context.Background() t.Run("VertexAI with default credentials", func(t *testing.T) { // Needed for account default credential. @@ -85,7 +84,8 @@ func TestNewClient(t *testing.T) { }) t.Run("Explicit project and location takes precedence over project and location from environment when set VertexAI", func(t *testing.T) { - client, err := NewClient(ctx, &ClientConfig{Backend: BackendVertexAI, Project: "constructor-project", Location: "constructor-location", + client, err := NewClient(ctx, &ClientConfig{ + Backend: BackendVertexAI, Project: "constructor-project", Location: "constructor-location", envVarProvider: func() map[string]string { return map[string]string{ "GOOGLE_CLOUD_PROJECT": "env-project-id", @@ -113,7 +113,8 @@ func TestNewClient(t *testing.T) { t.Run("API key from config when set VertexAI", func(t *testing.T) { apiKey := "test-api-key-constructor" - client, err := NewClient(ctx, &ClientConfig{Backend: BackendVertexAI, APIKey: apiKey, + client, err := NewClient(ctx, &ClientConfig{ + Backend: BackendVertexAI, APIKey: apiKey, envVarProvider: func() map[string]string { return map[string]string{ "GOOGLE_API_KEY": "test-api-key-env", @@ -139,7 +140,8 @@ func TestNewClient(t *testing.T) { t.Run("API key from environment when set VertexAI", func(t *testing.T) { apiKey := "test-api-key-env" - client, err := NewClient(ctx, &ClientConfig{Backend: BackendVertexAI, + client, err := NewClient(ctx, &ClientConfig{ + Backend: BackendVertexAI, envVarProvider: func() map[string]string { return map[string]string{ "GOOGLE_API_KEY": apiKey, @@ -165,7 +167,8 @@ func TestNewClient(t *testing.T) { t.Run("Project from environment", func(t *testing.T) { projectID := "test-project-env" - client, err := NewClient(ctx, &ClientConfig{Backend: BackendVertexAI, Location: "test-location", + client, err := NewClient(ctx, &ClientConfig{ + Backend: BackendVertexAI, Location: "test-location", envVarProvider: func() map[string]string { return map[string]string{ "GOOGLE_CLOUD_PROJECT": projectID, @@ -182,7 +185,8 @@ func TestNewClient(t *testing.T) { t.Run("Location from GOOGLE_CLOUD_REGION environment", func(t *testing.T) { location := "test-region-env" - client, err := NewClient(ctx, &ClientConfig{Project: "test-project", Backend: BackendVertexAI, + client, err := NewClient(ctx, &ClientConfig{ + Project: "test-project", Backend: BackendVertexAI, envVarProvider: func() map[string]string { return map[string]string{ "GOOGLE_CLOUD_REGION": location, @@ -199,12 +203,14 @@ func TestNewClient(t *testing.T) { t.Run("Location from GOOGLE_CLOUD_LOCATION environment", func(t *testing.T) { location := "test-location-env" - client, err := NewClient(ctx, &ClientConfig{Project: "test-project", Backend: BackendVertexAI, + client, err := NewClient(ctx, &ClientConfig{ + Project: "test-project", Backend: BackendVertexAI, envVarProvider: func() map[string]string { return map[string]string{ "GOOGLE_CLOUD_LOCATION": location, } - }}) + }, + }) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -214,7 +220,8 @@ func TestNewClient(t *testing.T) { }) t.Run("VertexAI set from environment", func(t *testing.T) { - client, err := NewClient(ctx, &ClientConfig{Project: "test-project", Location: "test-location", + client, err := NewClient(ctx, &ClientConfig{ + Project: "test-project", Location: "test-location", envVarProvider: func() map[string]string { return map[string]string{ "GOOGLE_GENAI_USE_VERTEXAI": "true", @@ -230,7 +237,8 @@ func TestNewClient(t *testing.T) { }) t.Run("VertexAI false from environment", func(t *testing.T) { - client, err := NewClient(ctx, &ClientConfig{APIKey: "test-api-key", + client, err := NewClient(ctx, &ClientConfig{ + APIKey: "test-api-key", envVarProvider: func() map[string]string { return map[string]string{ "GOOGLE_GENAI_USE_VERTEXAI": "false", @@ -246,7 +254,8 @@ func TestNewClient(t *testing.T) { }) t.Run("VertexAI from config", func(t *testing.T) { - client, err := NewClient(ctx, &ClientConfig{Backend: BackendVertexAI, Project: "test-project", Location: "test-location", + client, err := NewClient(ctx, &ClientConfig{ + Backend: BackendVertexAI, Project: "test-project", Location: "test-location", envVarProvider: func() map[string]string { return map[string]string{ "GOOGLE_GENAI_USE_VERTEXAI": "false", @@ -262,7 +271,8 @@ func TestNewClient(t *testing.T) { }) t.Run("VertexAI is unset from config and environment is false", func(t *testing.T) { - client, err := NewClient(ctx, &ClientConfig{APIKey: "test-api-key", + client, err := NewClient(ctx, &ClientConfig{ + APIKey: "test-api-key", envVarProvider: func() map[string]string { return map[string]string{ "GOOGLE_GENAI_USE_VERTEXAI": "false", @@ -278,7 +288,8 @@ func TestNewClient(t *testing.T) { }) t.Run("VertexAI is unset from config but environment is true", func(t *testing.T) { - client, err := NewClient(ctx, &ClientConfig{Backend: BackendGeminiAPI, APIKey: "test-api-key", + client, err := NewClient(ctx, &ClientConfig{ + Backend: BackendGeminiAPI, APIKey: "test-api-key", envVarProvider: func() map[string]string { return map[string]string{ "GOOGLE_GENAI_USE_VERTEXAI": "true", @@ -296,7 +307,8 @@ func TestNewClient(t *testing.T) { t.Run("API key from constructor takes precedence over proj/location from environment", func(t *testing.T) { // Vertex AI API key combo 1 apiKey := "vertexai-api-key" - client, err := NewClient(ctx, &ClientConfig{Backend: BackendVertexAI, APIKey: apiKey, + client, err := NewClient(ctx, &ClientConfig{ + Backend: BackendVertexAI, APIKey: apiKey, envVarProvider: func() map[string]string { return map[string]string{ "GOOGLE_API_KEY": "", @@ -331,7 +343,8 @@ func TestNewClient(t *testing.T) { // Vertex AI API key combo 2 project := "test-project" location := "test-location" - client, err := NewClient(ctx, &ClientConfig{Backend: BackendVertexAI, Project: project, Location: location, + client, err := NewClient(ctx, &ClientConfig{ + Backend: BackendVertexAI, Project: project, Location: location, envVarProvider: func() map[string]string { return map[string]string{ "GOOGLE_API_KEY": "vertexai-api-key-env", @@ -366,7 +379,8 @@ func TestNewClient(t *testing.T) { // Vertex AI API key combo 3 project := "test-project-env" location := "test-location-env" - client, err := NewClient(ctx, &ClientConfig{Backend: BackendVertexAI, + client, err := NewClient(ctx, &ClientConfig{ + Backend: BackendVertexAI, envVarProvider: func() map[string]string { return map[string]string{ "GOOGLE_API_KEY": "vertexai-api-key-env", @@ -399,10 +413,12 @@ func TestNewClient(t *testing.T) { t.Run("Base URL from HTTPOptions", func(t *testing.T) { baseURL := "https://test-base-url.com/" - client, err := NewClient(ctx, &ClientConfig{Project: "test-project", Location: "test-location", Backend: BackendVertexAI, + client, err := NewClient(ctx, &ClientConfig{ + Project: "test-project", Location: "test-location", Backend: BackendVertexAI, HTTPOptions: HTTPOptions{ BaseURL: baseURL, - }}) + }, + }) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -431,12 +447,14 @@ func TestNewClient(t *testing.T) { t.Run("Base URL from environment", func(t *testing.T) { baseURL := "https://test-base-url.com/" - client, err := NewClient(ctx, &ClientConfig{Project: "test-project", Location: "test-location", Backend: BackendVertexAI, + client, err := NewClient(ctx, &ClientConfig{ + Project: "test-project", Location: "test-location", Backend: BackendVertexAI, envVarProvider: func() map[string]string { return map[string]string{ "GOOGLE_VERTEX_BASE_URL": baseURL, } - }}) + }, + }) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -469,7 +487,8 @@ func TestNewClient(t *testing.T) { t.Run("API Key from config", func(t *testing.T) { apiKey := "test-constructor-api-key" - client, err := NewClient(ctx, &ClientConfig{APIKey: apiKey, + client, err := NewClient(ctx, &ClientConfig{ + APIKey: apiKey, envVarProvider: func() map[string]string { return map[string]string{ "GOOGLE_API_KEY": "test-env-api-key", @@ -493,7 +512,8 @@ func TestNewClient(t *testing.T) { t.Run("API Key from GOOGLE_API_KEY only", func(t *testing.T) { apiKey := "test-api-key-env" - client, err := NewClient(ctx, &ClientConfig{Backend: BackendGeminiAPI, + client, err := NewClient(ctx, &ClientConfig{ + Backend: BackendGeminiAPI, envVarProvider: func() map[string]string { return map[string]string{ "GOOGLE_API_KEY": apiKey, @@ -509,7 +529,8 @@ func TestNewClient(t *testing.T) { }) t.Run("API Key from GEMINI_API_KEY only", func(t *testing.T) { apiKey := "test-api-key-env" - client, err := NewClient(ctx, &ClientConfig{Backend: BackendGeminiAPI, + client, err := NewClient(ctx, &ClientConfig{ + Backend: BackendGeminiAPI, envVarProvider: func() map[string]string { return map[string]string{ "GEMINI_API_KEY": apiKey, @@ -526,7 +547,8 @@ func TestNewClient(t *testing.T) { t.Run("API Key from GEMINI_API_KEY and GOOGLE_API_KEY as empty string", func(t *testing.T) { apiKey := "test-api-key-env" - client, err := NewClient(ctx, &ClientConfig{Backend: BackendGeminiAPI, + client, err := NewClient(ctx, &ClientConfig{ + Backend: BackendGeminiAPI, envVarProvider: func() map[string]string { return map[string]string{ "GOOGLE_API_KEY": "", @@ -545,7 +567,8 @@ func TestNewClient(t *testing.T) { t.Run("API Key both GEMINI_API_KEY and GOOGLE_API_KEY", func(t *testing.T) { geminiAPIKey := "gemini-api-key-env" googleAPIKey := "google-api-key-env" - client, err := NewClient(ctx, &ClientConfig{Backend: BackendGeminiAPI, + client, err := NewClient(ctx, &ClientConfig{ + Backend: BackendGeminiAPI, envVarProvider: func() map[string]string { return map[string]string{ "GOOGLE_API_KEY": googleAPIKey, @@ -563,10 +586,12 @@ func TestNewClient(t *testing.T) { t.Run("Base URL from HTTPOptions", func(t *testing.T) { baseURL := "https://test-base-url.com/" - client, err := NewClient(ctx, &ClientConfig{APIKey: "test-api-key", Backend: BackendGeminiAPI, + client, err := NewClient(ctx, &ClientConfig{ + APIKey: "test-api-key", Backend: BackendGeminiAPI, HTTPOptions: HTTPOptions{ BaseURL: baseURL, - }}) + }, + }) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -595,12 +620,14 @@ func TestNewClient(t *testing.T) { t.Run("Base URL from environment", func(t *testing.T) { baseURL := "https://test-base-url.com/" - client, err := NewClient(ctx, &ClientConfig{APIKey: "test-api-key", Backend: BackendGeminiAPI, + client, err := NewClient(ctx, &ClientConfig{ + APIKey: "test-api-key", Backend: BackendGeminiAPI, envVarProvider: func() map[string]string { return map[string]string{ "GOOGLE_GEMINI_BASE_URL": baseURL, } - }}) + }, + }) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -690,7 +717,6 @@ func TestNewClient(t *testing.T) { t.Errorf("Models.apiClient.clientConfig mismatch (-want +got):\n%s", diff) } }) - } func TestClientConfigHTTPOptions(t *testing.T) { diff --git a/common.go b/common.go index 413dc5d6..f5468947 100644 --- a/common.go +++ b/common.go @@ -21,6 +21,7 @@ import ( "fmt" "iter" "log" + "maps" "net/http" "net/url" "reflect" @@ -135,9 +136,7 @@ func setValueByPath(data map[string]any, keys []string, value any) { if newMap, ok2 := value.(map[string]any); ok2 { // Instead of overwriting dictionary with another dictionary, merge them. // This is important for handling training and validation datasets in tuning. - for k, v := range newMap { - existingMap[k] = v - } + maps.Copy(existingMap, newMap) data[finalKey] = existingMap // Assign the updated map back } } else { @@ -147,9 +146,7 @@ func setValueByPath(data map[string]any, keys []string, value any) { if finalKey == "_self" && reflect.TypeOf(value).Kind() == reflect.Map { // Iterate through the `value` map and copy its contents to `data`. if valMap, ok := value.(map[string]any); ok { - for k, v := range valMap { - data[k] = v - } + maps.Copy(data, valMap) } } else { // If existing_data is None (or key doesn't exist), set the value directly. diff --git a/common_test.go b/common_test.go index 16cea1d9..f1f1d44a 100644 --- a/common_test.go +++ b/common_test.go @@ -21,9 +21,7 @@ import ( "github.com/google/go-cmp/cmp" ) -var ( - dummyExtrasProvider = func(body map[string]any) map[string]any { return body } -) +var dummyExtrasProvider = func(body map[string]any) map[string]any { return body } func TestMergeHTTPOptions(t *testing.T) { tests := []struct { @@ -304,12 +302,10 @@ func TestSetValueByPath(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - setValueByPath(tt.data, tt.keys, tt.value) if diff := cmp.Diff(tt.data, tt.want); diff != "" { t.Errorf("setValueByPath() mismatch (-want +got):\n%s", diff) } - }) } } @@ -393,7 +389,6 @@ func TestGetValueByPath(t *testing.T) { t.Errorf("getValueByPath() mismatch (-want +got):\n%s", diff) } } - }) } } diff --git a/examples/batches/create_get_cancel.go b/examples/batches/create_get_cancel.go index 6e402832..06af5b57 100644 --- a/examples/batches/create_get_cancel.go +++ b/examples/batches/create_get_cancel.go @@ -44,7 +44,7 @@ func run(ctx context.Context) { } if client.ClientConfig().Backend == genai.BackendVertexAI { fmt.Println("Calling VertexAI Backend...") - var model = flag.String("model", "gemini-1.5-flash-002", "the model name, e.g. gemini-1.5-pro-002") + model := flag.String("model", "gemini-1.5-flash-002", "the model name, e.g. gemini-1.5-pro-002") // Create a batch job. result, err := client.Batches.Create( ctx, @@ -84,7 +84,7 @@ func run(ctx context.Context) { fmt.Println("Cancelled batch job:", result.Name) } else { fmt.Println("Calling GeminiAPI Backend...") - var model = flag.String("model", "gemini-2.0-flash", "the model name, e.g. gemini-1.5-pro-002") + model := flag.String("model", "gemini-2.0-flash", "the model name, e.g. gemini-1.5-pro-002") // Create a batch job. result, err := client.Batches.Create( ctx, diff --git a/examples/caches/create_get_delete.go b/examples/caches/create_get_delete.go index cca21652..7fb30670 100644 --- a/examples/caches/create_get_delete.go +++ b/examples/caches/create_get_delete.go @@ -70,7 +70,8 @@ func run(ctx context.Context) { }, }, }, - }}) + }, + }) if err != nil { log.Fatal(err) } diff --git a/examples/files/upload_file.go b/examples/files/upload_file.go index 9891609b..6d13de02 100644 --- a/examples/files/upload_file.go +++ b/examples/files/upload_file.go @@ -80,7 +80,7 @@ func run(ctx context.Context) { fmt.Println("Calling GeminiAPI Backend...") } // Upload a new file. - var testDataDir = filepath.Join(moduleRootDir(), "testdata") + testDataDir := filepath.Join(moduleRootDir(), "testdata") filePath := filepath.Join(testDataDir, "google.jpg") file, err := client.Files.UploadFromPath(ctx, filePath, nil) if err != nil { diff --git a/examples/mcptoolbox/mcp_toolbox.go b/examples/mcptoolbox/mcp_toolbox.go index bb994545..6212ca30 100644 --- a/examples/mcptoolbox/mcp_toolbox.go +++ b/examples/mcptoolbox/mcp_toolbox.go @@ -13,7 +13,6 @@ import ( // ConvertToGenaiTool translates a ToolboxTool into the genai.FunctionDeclaration format. func ConvertToGenaiTool(toolboxTool *core.ToolboxTool) *genai.Tool { - inputschema, err := toolboxTool.InputSchema() if err != nil { return &genai.Tool{} @@ -135,5 +134,4 @@ func main() { } log.Println("=== Final Response from Model (after processing function result) ===") printResponse(finalResponse) - } diff --git a/examples/models/generate_content/function_declaration_json_schema.go b/examples/models/generate_content/function_declaration_json_schema.go index 2c9f0bb8..6d7a24d5 100644 --- a/examples/models/generate_content/function_declaration_json_schema.go +++ b/examples/models/generate_content/function_declaration_json_schema.go @@ -29,7 +29,7 @@ import ( var model = flag.String("model", "gemini-2.0-flash", "the model name, e.g. gemini-2.0-flash") func run(ctx context.Context) { - var parameterSchema = map[string]any{ + parameterSchema := map[string]any{ "type": "object", "properties": map[string]any{ "brightness": map[string]any{ @@ -44,7 +44,7 @@ func run(ctx context.Context) { "required": []string{"brightness", "colorTemperature"}, } - var tools = []*genai.Tool{ + tools := []*genai.Tool{ { FunctionDeclarations: []*genai.FunctionDeclaration{ { diff --git a/examples/models/generate_content/function_declaration_schema.go b/examples/models/generate_content/function_declaration_schema.go index 795fc765..aaa47061 100644 --- a/examples/models/generate_content/function_declaration_schema.go +++ b/examples/models/generate_content/function_declaration_schema.go @@ -29,8 +29,7 @@ import ( var model = flag.String("model", "gemini-2.0-flash", "the model name, e.g. gemini-2.0-flash") func run(ctx context.Context) { - - var tools = []*genai.Tool{ + tools := []*genai.Tool{ { FunctionDeclarations: []*genai.FunctionDeclaration{ { diff --git a/examples/models/generate_content/text_stream.go b/examples/models/generate_content/text_stream.go index eb55837a..e0ae3552 100644 --- a/examples/models/generate_content/text_stream.go +++ b/examples/models/generate_content/text_stream.go @@ -37,7 +37,7 @@ func run(ctx context.Context) { } else { fmt.Println("Calling GeminiAI.GenerateContentStream API...") } - var config *genai.GenerateContentConfig = &genai.GenerateContentConfig{SystemInstruction: &genai.Content{Parts: []*genai.Part{&genai.Part{Text: "You are a story writer."}}}} + var config *genai.GenerateContentConfig = &genai.GenerateContentConfig{SystemInstruction: &genai.Content{Parts: []*genai.Part{{Text: "You are a story writer."}}}} // Call the GenerateContent method. for result, err := range client.Models.GenerateContentStream(ctx, *model, genai.Text("Tell me a story in 300 words."), config) { if err != nil { diff --git a/examples/models/recontext_image/image.go b/examples/models/recontext_image/image.go index 53aa187b..17fba516 100644 --- a/examples/models/recontext_image/image.go +++ b/examples/models/recontext_image/image.go @@ -58,7 +58,8 @@ func run(ctx context.Context) { &genai.RecontextImageSource{ Prompt: prompt, PersonImage: nil, - ProductImages: productImages}, + ProductImages: productImages, + }, &genai.RecontextImageConfig{ OutputMIMEType: "image/jpeg", }, @@ -79,7 +80,8 @@ func run(ctx context.Context) { &genai.RecontextImageSource{ Prompt: "", PersonImage: personImage, - ProductImages: productImages2}, + ProductImages: productImages2, + }, &genai.RecontextImageConfig{ OutputMIMEType: "image/jpeg", }, diff --git a/files_test.go b/files_test.go index c4dc915a..627c9e61 100644 --- a/files_test.go +++ b/files_test.go @@ -245,7 +245,8 @@ func TestFilesAll(t *testing.T) { })) defer ts.Close() - client, err := NewClient(ctx, &ClientConfig{HTTPOptions: HTTPOptions{BaseURL: ts.URL}, + client, err := NewClient(ctx, &ClientConfig{ + HTTPOptions: HTTPOptions{BaseURL: ts.URL}, envVarProvider: func() map[string]string { return map[string]string{ "GOOGLE_API_KEY": "test-api-key", @@ -692,7 +693,7 @@ func TestFilesUploadFromPath(t *testing.T) { tempDir := t.TempDir() filePath := filepath.Join(tempDir, "testfile.txt") fileContent := "Content for UploadFromPath test." - err = os.WriteFile(filePath, []byte(fileContent), 0644) + err = os.WriteFile(filePath, []byte(fileContent), 0o644) if err != nil { t.Fatalf("Failed to create temp file: %v", err) } @@ -754,7 +755,7 @@ func TestFilesUploadFromPath(t *testing.T) { name: "Error - Unknown MIME Type", path: func() string { // Create a file with an unknown extension p := filepath.Join(tempDir, "file.unknownext") - _ = os.WriteFile(p, []byte("data"), 0644) + _ = os.WriteFile(p, []byte("data"), 0o644) return p }(), config: nil, // No MIME override diff --git a/internal/changefinder/main.go b/internal/changefinder/main.go index bb8ecf85..ee922ab2 100644 --- a/internal/changefinder/main.go +++ b/internal/changefinder/main.go @@ -156,7 +156,7 @@ func touchModule(root, mod string) error { c := exec.Command("echo") log.Printf(c.String()) - f, err := os.OpenFile(path.Join(root, mod, "CHANGES.md"), os.O_APPEND|os.O_WRONLY, 0644) + f, err := os.OpenFile(path.Join(root, mod, "CHANGES.md"), os.O_APPEND|os.O_WRONLY, 0o644) if err != nil { return err } diff --git a/live.go b/live.go index e73e6b48..62d49bad 100644 --- a/live.go +++ b/live.go @@ -291,7 +291,7 @@ func (s *Session) Receive() (*LiveServerMessage, error) { return nil, err } - var message = new(LiveServerMessage) + message := new(LiveServerMessage) err = mapToStruct(responseMap, message) if err != nil { return nil, err diff --git a/live_test.go b/live_test.go index 31a7a394..1878b5f3 100644 --- a/live_test.go +++ b/live_test.go @@ -194,7 +194,7 @@ func TestLiveConnect(t *testing.T) { for _, tt := range connectTests { t.Run(tt.desc, func(t *testing.T) { - var upgrader = websocket.Upgrader{} + upgrader := websocket.Upgrader{} ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { conn, _ := upgrader.Upgrade(w, r, nil) defer conn.Close() @@ -213,7 +213,6 @@ func TestLiveConnect(t *testing.T) { } mt, message, err := conn.ReadMessage() - if err != nil { if tt.wantErr { return @@ -500,7 +499,7 @@ func TestLiveConnect(t *testing.T) { func setupTestWebsocketServer(t *testing.T, wantRequestBodySlice []string, fakeResponseBodySlice []string) *httptest.Server { t.Helper() - var upgrader = websocket.Upgrader{} + upgrader := websocket.Upgrader{} ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { conn, _ := upgrader.Upgrade(w, r, nil) diff --git a/main_test.go b/main_test.go index 55c51be1..4dde54d9 100644 --- a/main_test.go +++ b/main_test.go @@ -30,10 +30,10 @@ var ( "mldev/models/generate_images/test_all_vertexai_config_person_generation_enum_parameters_3", } disabledTestsByMode = map[string][]string{ - apiMode: []string{ + apiMode: { "TestModelsGenerateContentAudio/", }, - replayMode: []string{ + replayMode: { // TODO(b/372730941): httpOptions related tests are not covered in replay mode. "models/delete/test_delete_model_with_http_options_in_method", "models/generate_content/test_http_options_in_method", @@ -60,7 +60,7 @@ var ( // filter=display_name%3A%22genai_%2A%22&pageSize=5 are reordered and mismatched "batches/list/test_list_batch_jobs_with_config", }, - unitMode: []string{ + unitMode: { // We don't run table tests in unit mode. "TestTable/", }, diff --git a/models_test.go b/models_test.go index 7c0a2252..d9692c5b 100644 --- a/models_test.go +++ b/models_test.go @@ -466,7 +466,8 @@ func TestModelsAll(t *testing.T) { })) defer ts.Close() - client, err := NewClient(ctx, &ClientConfig{HTTPOptions: HTTPOptions{BaseURL: ts.URL}, + client, err := NewClient(ctx, &ClientConfig{ + HTTPOptions: HTTPOptions{BaseURL: ts.URL}, envVarProvider: func() map[string]string { return map[string]string{ "GOOGLE_API_KEY": "test-api-key", @@ -544,7 +545,8 @@ func TestModelsAllEmptyResponse(t *testing.T) { })) defer ts.Close() - client, err := NewClient(ctx, &ClientConfig{HTTPOptions: HTTPOptions{BaseURL: ts.URL}, + client, err := NewClient(ctx, &ClientConfig{ + HTTPOptions: HTTPOptions{BaseURL: ts.URL}, envVarProvider: func() map[string]string { return map[string]string{ "GOOGLE_API_KEY": "test-api-key", diff --git a/pages.go b/pages.go index c8b8f98c..bb20f125 100644 --- a/pages.go +++ b/pages.go @@ -18,6 +18,7 @@ import ( "context" "errors" "iter" + "maps" ) // ErrPageDone is the error returned by an iterator's Next method when no more pages are available. @@ -88,9 +89,7 @@ func (p Page[T]) Next(ctx context.Context) (Page[T], error) { return p, ErrPageDone } c := make(map[string]any) - for k, v := range p.config { - c[k] = v - } + maps.Copy(c, p.config) c["PageToken"] = p.NextPageToken return newPage[T](ctx, p.Name, c, p.listFunc) diff --git a/pages_test.go b/pages_test.go index ecef819e..b676db60 100644 --- a/pages_test.go +++ b/pages_test.go @@ -137,7 +137,5 @@ func TestPageAll(t *testing.T) { t.Fatalf("Unexpected error during iteration: %v", err) } - } - } diff --git a/replay_sanitizer.go b/replay_sanitizer.go index d12ee342..e4938c7e 100644 --- a/replay_sanitizer.go +++ b/replay_sanitizer.go @@ -60,7 +60,7 @@ func sanitizeMapWithSourceType(t *testing.T, sourceType reflect.Type, m any) { for _, path := range paths { if sourceType.Kind() == reflect.Slice { data := m.([]any) - for i := 0; i < len(data); i++ { + for i := range data { sanitizeMapByPath(data[i], path, stdBase64Handler, false) } } else { diff --git a/replay_sanitizer_test.go b/replay_sanitizer_test.go index 619d5e70..a16ada90 100644 --- a/replay_sanitizer_test.go +++ b/replay_sanitizer_test.go @@ -30,7 +30,7 @@ type nestedStruct struct { type outerStruct struct { PointerField *nestedStruct `json:"pointerField,omitempty"` - StructField nestedStruct `json:"structField,omitempty"` + StructField nestedStruct `json:"structField"` SliceField []nestedStruct `json:"sliceField,omitempty"` SlicePointerField []*nestedStruct `json:"slicePointerField,omitempty"` // Recursive types. @@ -121,9 +121,9 @@ func TestSanitizeMapByPath(t *testing.T) { sanitized: map[string]any{"k1": []any{"sanitized", "sanitized"}}, }, { - input: map[string]any{"k1": []map[string]any{map[string]any{"k2": "v2"}, map[string]any{"k2": "v2"}}}, + input: map[string]any{"k1": []map[string]any{{"k2": "v2"}, {"k2": "v2"}}}, path: "[]k1.k2", - sanitized: map[string]any{"k1": []map[string]any{map[string]any{"k2": "sanitized"}, map[string]any{"k2": "sanitized"}}}, + sanitized: map[string]any{"k1": []map[string]any{{"k2": "sanitized"}, {"k2": "sanitized"}}}, }, { input: map[string]any{"k1": map[string]any{"k2": []any{"v2", "v2"}}}, @@ -147,9 +147,9 @@ func TestSanitizeMapByPath(t *testing.T) { sanitized: map[string]any{"k1": []any{"v1", "v1"}}, }, { - input: map[string]any{"k1": []map[string]any{map[string]any{"k2": "v2"}, map[string]any{"k2": "v2"}}}, + input: map[string]any{"k1": []map[string]any{{"k2": "v2"}, {"k2": "v2"}}}, path: "[]wrongPath.k2", - sanitized: map[string]any{"k1": []map[string]any{map[string]any{"k2": "v2"}, map[string]any{"k2": "v2"}}}, + sanitized: map[string]any{"k1": []map[string]any{{"k2": "v2"}, {"k2": "v2"}}}, }, { input: map[string]any{"k1": map[string]any{"k2": []string{"v2", "v2"}}}, diff --git a/table_test.go b/table_test.go index 5783c60a..c287d26d 100644 --- a/table_test.go +++ b/table_test.go @@ -51,7 +51,7 @@ func snakeToCamel(s string) string { // methodParamType is extra mapping of method param name to its param type because reflect module cannot process private struct. var methodParamType = map[string]map[string]reflect.Type{ - "editImage": map[string]reflect.Type{ + "editImage": { "referenceImages": reflect.TypeOf(([]ReferenceImage)(nil)), }, } diff --git a/tokenizer/tokenizer.go b/tokenizer/tokenizer.go index a26fa6fa..713fdb9c 100644 --- a/tokenizer/tokenizer.go +++ b/tokenizer/tokenizer.go @@ -404,11 +404,11 @@ func loadModelData(url string, wantHash string) ([]byte, error) { return nil, fmt.Errorf("downloaded model hash mismatch") } - err = os.MkdirAll(cacheDir, 0770) + err = os.MkdirAll(cacheDir, 0o770) if err != nil { return nil, fmt.Errorf("creating cache dir: %w", err) } - err = os.WriteFile(cachePath, cacheData, 0660) + err = os.WriteFile(cachePath, cacheData, 0o660) if err != nil { return nil, fmt.Errorf("writing cache file: %w", err) } diff --git a/tokenizer/tokenizer_test.go b/tokenizer/tokenizer_test.go index 6c808c91..f59aea64 100644 --- a/tokenizer/tokenizer_test.go +++ b/tokenizer/tokenizer_test.go @@ -59,8 +59,8 @@ func TestLoadModelData(t *testing.T) { // Overwrite cache file with wrong data, and try again. cacheDir := filepath.Join(os.TempDir(), "vertexai_tokenizer_model") cachePath := filepath.Join(cacheDir, hashString([]byte(config.modelURL))) - _ = os.MkdirAll(cacheDir, 0770) - _ = os.WriteFile(cachePath, []byte{0, 1, 2, 3}, 0660) + _ = os.MkdirAll(cacheDir, 0o770) + _ = os.WriteFile(cachePath, []byte{0, 1, 2, 3}, 0o660) data, err = loadModelData(config.modelURL, config.modelHash) checkDataAndErr(data, err) } @@ -92,7 +92,7 @@ func TestCreateLocalTokenizer(t *testing.T) { } func TestCountTokens(t *testing.T) { - var tests = []struct { + tests := []struct { contents []*genai.Content wantCount int32 }{ diff --git a/transformer.go b/transformer.go index 9571dc3a..e97f7478 100644 --- a/transformer.go +++ b/transformer.go @@ -139,6 +139,7 @@ func tLiveSpeechConfig(speechConfig any) (any, error) { return nil, fmt.Errorf("unsupported speechConfig type: %T", speechConfig) } } + func tBytes(fromImageBytes any) (any, error) { // TODO(b/389133914): Remove dummy bytes converter. return fromImageBytes, nil @@ -219,8 +220,8 @@ func tFileName(name any) (string, error) { return "", fmt.Errorf("could not extract file name from URI: %s", name) } name = match[0] - } else if strings.HasPrefix(name, "files/") { - name = strings.TrimPrefix(name, "files/") + } else if after, ok := strings.CutPrefix(name, "files/"); ok { + name = after } return name, nil } diff --git a/tunings_test.go b/tunings_test.go index ec499886..bc105819 100644 --- a/tunings_test.go +++ b/tunings_test.go @@ -117,6 +117,7 @@ func TestTuningsTuneUnit(t *testing.T) { }) } } + func TestTuningsTuneAPIMode(t *testing.T) { if *mode != apiMode { t.Skip("Skip. This test is only in the API mode") @@ -164,7 +165,6 @@ func TestTuningsTuneAPIMode(t *testing.T) { // Test tuning with a pre-tuned model. continuousJob, err := client.Tunings.Tune(ctx, preTunedModelName, trainingDataset, nil) - if err != nil { t.Fatalf("Tunings.Tune() with pre-tuned model failed: %v", err) } diff --git a/types_json_test.go b/types_json_test.go index bf22a63f..9386002e 100644 --- a/types_json_test.go +++ b/types_json_test.go @@ -618,7 +618,7 @@ func TestMarshalJSON(t *testing.T) { SizeBytes: Ptr[int64](1024), CreateTime: time.Date(2024, 12, 31, 23, 59, 59, 0, time.UTC), ExpirationTime: time.Date(2025, 12, 31, 23, 59, 59, 0, time.UTC), - UpdateTime: time.Date(2025, 01, 01, 0, 0, 0, 0, time.UTC), + UpdateTime: time.Date(2025, 0o1, 0o1, 0, 0, 0, 0, time.UTC), Sha256Hash: "test-hash", URI: "https://example.com/test-file", DownloadURI: "https://example.com/download/test-file", @@ -759,13 +759,11 @@ func TestMarshalJSON(t *testing.T) { if string(roundTripMarshal) != tt.want { t.Errorf("%s.MarshalJSON() = %v, want %v", tt.target, string(roundTripMarshal), tt.want) } - }) } } func TestJSONCustomTypes(t *testing.T) { - t.Run("int64SliceJSON", func(t *testing.T) { type valueStruct struct { Val int64SliceJSON `json:"val,omitempty"` @@ -919,7 +917,7 @@ func TestJSONCustomTypes(t *testing.T) { t.Run("dateJSON", func(t *testing.T) { type valueStruct struct { - Val dateJSON `json:"val,omitempty"` + Val dateJSON `json:"val"` } type pointerStruct struct { Val *dateJSON `json:"val,omitempty"` diff --git a/version.go b/version.go index 2a177888..9ca7d949 100644 --- a/version.go +++ b/version.go @@ -14,7 +14,5 @@ package genai -var ( - // Version is the version of the SDK. - version = "1.29.0" // x-release-please-version -) +// Version is the version of the SDK. +var version = "1.29.0" // x-release-please-version