diff --git a/examples/files/gcs_register.go b/examples/files/gcs_register.go new file mode 100644 index 00000000..1b1c44e2 --- /dev/null +++ b/examples/files/gcs_register.go @@ -0,0 +1,87 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build ignore_vet + +package main + +import ( + "context" + "flag" + "fmt" + "log" + + "cloud.google.com/go/auth/credentials" + "google.golang.org/genai" +) + +var model = flag.String("model", "gemini-2.5-flash", "the model name, e.g. gemini-2.5-flash") +var gcsURI = flag.String("gcs-uri", "gs://cloud-samples-data/generative-ai/pdf/2312.11805v3.pdf", "the gcs uri of the pdf file") + +// This example shows how to register a file from GCS and use it in the Gemini API. +// Setup instructions: https://ai.google.dev/gemini-api/docs/file-input-methods#registration +// GOOGLE_APPLICATION_CREDENTIALS=/path/to/service-account.json GEMINI_API_KEY= go run gcs_register.go +func run(ctx context.Context) { + client, err := genai.NewClient(ctx, nil) + if err != nil { + log.Fatalln(err) + } + if client.ClientConfig().Backend == genai.BackendVertexAI { + log.Fatalln("Not supported for VertexAI backend") + } else { + fmt.Println("Calling GeminiAPI Backend...") + } + + creds, err := credentials.DetectDefault(&credentials.DetectOptions{ + Scopes: []string{"https://www.googleapis.com/auth/cloud-platform", "https://www.googleapis.com/auth/devstorage.read_only"}, + }) + if err != nil { + log.Fatal(err) + } + + registeredFiles, err := client.Files.RegisterFiles(ctx, creds, []string{*gcsURI}, nil) + if err != nil { + log.Fatal(err) + } + fmt.Println("Registered files:", registeredFiles.Files) + if len(registeredFiles.Files) == 0 { + log.Fatal("No files were registered") + } + + result, err := client.Models.GenerateContent(ctx, *model, []*genai.Content{ + { + Role: "user", + Parts: []*genai.Part{ + {Text: "What's this pdf about?"}, + { + FileData: &genai.FileData{ + FileURI: registeredFiles.Files[0].URI, + MIMEType: "application/pdf", + }, + }, + }, + }, + }, nil) + if err != nil { + log.Fatal(err) + } + + fmt.Println("Generated content:", result.Text()) +} + +func main() { + ctx := context.Background() + flag.Parse() + run(ctx) +} diff --git a/files.go b/files.go index 3833ba34..e00501cf 100644 --- a/files.go +++ b/files.go @@ -27,6 +27,8 @@ import ( "path/filepath" "strconv" "strings" + + "cloud.google.com/go/auth" ) func createFileParametersToMldev(fromObject map[string]any, parentObject map[string]any, rootObject map[string]any) (toObject map[string]any, err error) { @@ -145,6 +147,33 @@ func listFilesResponseFromMldev(fromObject map[string]any, parentObject map[stri return toObject, nil } +func registerFilesParametersToMldev(fromObject map[string]any, parentObject map[string]any, rootObject map[string]any) (toObject map[string]any, err error) { + toObject = make(map[string]any) + + fromURIs := getValueByPath(fromObject, []string{"uris"}) + if fromURIs != nil { + setValueByPath(toObject, []string{"uris"}, fromURIs) + } + + return toObject, nil +} + +func registerFilesResponseFromMldev(fromObject map[string]any, parentObject map[string]any, rootObject map[string]any) (toObject map[string]any, err error) { + toObject = make(map[string]any) + + fromSdkHttpResponse := getValueByPath(fromObject, []string{"sdkHttpResponse"}) + if fromSdkHttpResponse != nil { + setValueByPath(toObject, []string{"sdkHttpResponse"}, fromSdkHttpResponse) + } + + fromFiles := getValueByPath(fromObject, []string{"files"}) + if fromFiles != nil { + setValueByPath(toObject, []string{"files"}, fromFiles) + } + + return toObject, nil +} + type Files struct { apiClient *apiClient } @@ -430,6 +459,117 @@ func (m Files) Delete(ctx context.Context, name string, config *DeleteFileConfig return response, nil } +func (m Files) registerFiles(ctx context.Context, uris []string, config *RegisterFilesConfig) (*RegisterFilesResponse, error) { + parameterMap := make(map[string]any) + + kwargs := map[string]any{"uris": uris} + deepMarshal(kwargs, ¶meterMap) + + var httpOptions *HTTPOptions + if config == nil || config.HTTPOptions == nil { + httpOptions = &HTTPOptions{} + } else { + httpOptions = config.HTTPOptions + } + if httpOptions.Headers == nil { + httpOptions.Headers = http.Header{} + } + var response = new(RegisterFilesResponse) + var responseMap map[string]any + var fromConverter func(map[string]any, map[string]any, map[string]any) (map[string]any, error) + var toConverter func(map[string]any, map[string]any, map[string]any) (map[string]any, error) + if m.apiClient.clientConfig.Backend == BackendVertexAI { + + return nil, fmt.Errorf("method RegisterFiles is only supported in the Gemini Developer client. You can choose to use Gemini Developer client by setting ClientConfig.Backend to BackendGeminiAPI.") + + } else { + toConverter = registerFilesParametersToMldev + fromConverter = registerFilesResponseFromMldev + } + + body, err := toConverter(parameterMap, nil, parameterMap) + if err != nil { + return nil, err + } + var path string + var urlParams map[string]any + if _, ok := body["_url"]; ok { + urlParams = body["_url"].(map[string]any) + delete(body, "_url") + } + if m.apiClient.clientConfig.Backend == BackendVertexAI { + path, err = formatMap("None", urlParams) + } else { + path, err = formatMap("files:register", urlParams) + } + if err != nil { + return nil, fmt.Errorf("invalid url params: %#v.\n%w", urlParams, err) + } + if _, ok := body["_query"]; ok { + query, err := createURLQuery(body["_query"].(map[string]any)) + if err != nil { + return nil, err + } + path += "?" + query + delete(body, "_query") + } + responseMap, err = sendRequest(ctx, m.apiClient, path, http.MethodPost, body, httpOptions) + if err != nil { + return nil, err + } + if fromConverter != nil { + responseMap, err = fromConverter(responseMap, nil, parameterMap) + } + if err != nil { + return nil, err + } + err = mapToStruct(responseMap, response) + if err != nil { + return nil, err + } + + return response, nil +} + +// RegisterFiles registers GCS files with the Gemini file service. +// This method is only supported in the Gemini Developer client (not Vertex AI). +// It requires explicit OAuth credentials for authentication. +func (m Files) RegisterFiles(ctx context.Context, credentials *auth.Credentials, uris []string, config *RegisterFilesConfig) (*RegisterFilesResponse, error) { + if m.apiClient.clientConfig.Backend == BackendVertexAI { + return nil, fmt.Errorf("method RegisterFiles is only supported in the Gemini Developer client. You can choose to use Gemini Developer client by setting ClientConfig.Backend to BackendGeminiAPI.") + } + if credentials == nil { + return nil, fmt.Errorf("credentials are required for RegisterFiles") + } + if len(uris) == 0 { + return nil, fmt.Errorf("at least one URI is required for RegisterFiles") + } + + token, err := credentials.Token(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get token: %w", err) + } + + var localConfig RegisterFilesConfig + if config != nil { + deepCopy(*config, &localConfig) + } + if localConfig.HTTPOptions == nil { + localConfig.HTTPOptions = &HTTPOptions{} + } + if localConfig.HTTPOptions.Headers == nil { + localConfig.HTTPOptions.Headers = http.Header{} + } + localConfig.HTTPOptions.Headers.Set("Authorization", fmt.Sprintf("Bearer %s", token.Value)) + + quotaProjectID, err := credentials.QuotaProjectID(ctx) + if err == nil && quotaProjectID != "" { + localConfig.HTTPOptions.Headers.Set("X-Goog-User-Project", quotaProjectID) + } + + return m.registerFiles(ctx, uris, &localConfig) +} + // List retrieves a paginated list of files resources. func (m Files) List(ctx context.Context, config *ListFilesConfig) (Page[File], error) { listFunc := func(ctx context.Context, config map[string]any) ([]*File, string, *HTTPResponse, error) { diff --git a/files_test.go b/files_test.go index c4dc915a..ea0326bd 100644 --- a/files_test.go +++ b/files_test.go @@ -831,3 +831,233 @@ type errorReader struct{} func (r *errorReader) Read(p []byte) (n int, err error) { return 0, fmt.Errorf("intentional read error") } + +// mockTokenProvider implements auth.TokenProvider for testing. +type mockTokenProvider struct { + token *auth.Token +} + +func (m *mockTokenProvider) Token(ctx context.Context) (*auth.Token, error) { + return m.token, nil +} + +func newMockCredentials(tokenValue string) *auth.Credentials { + return auth.NewCredentials(&auth.CredentialsOptions{ + TokenProvider: &mockTokenProvider{ + token: &auth.Token{Value: tokenValue}, + }, + }) +} + +func TestRegisterFiles(t *testing.T) { + ctx := context.Background() + + t.Run("Success", func(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("expected POST, got %s", r.Method) + } + if !strings.HasSuffix(r.URL.Path, "files:register") { + t.Errorf("expected path ending with files:register, got %s", r.URL.Path) + } + authHeader := r.Header.Get("Authorization") + if authHeader != "Bearer test-token" { + t.Errorf("expected Authorization header 'Bearer test-token', got %q", authHeader) + } + + var body map[string]any + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Fatalf("failed to decode request body: %v", err) + } + uris, ok := body["uris"].([]any) + if !ok || len(uris) != 1 || uris[0] != "gs://bucket/object" { + t.Errorf("expected uris [gs://bucket/object], got %v", body["uris"]) + } + + resp := map[string]any{ + "files": []map[string]any{ + { + "name": "files/abc123", + "uri": "gs://bucket/object", + "sizeBytes": "1024", + "mimeType": "application/octet-stream", + "state": "ACTIVE", + "source": "REGISTERED", + }, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer ts.Close() + + client, err := NewClient(ctx, &ClientConfig{ + HTTPOptions: HTTPOptions{BaseURL: ts.URL}, + HTTPClient: ts.Client(), + envVarProvider: func() map[string]string { + return map[string]string{"GOOGLE_API_KEY": "test-api-key"} + }, + }) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + + creds := newMockCredentials("test-token") + resp, err := client.Files.RegisterFiles(ctx, creds, []string{"gs://bucket/object"}, nil) + if err != nil { + t.Fatalf("RegisterFiles() error: %v", err) + } + if len(resp.Files) != 1 { + t.Fatalf("expected 1 file, got %d", len(resp.Files)) + } + if resp.Files[0].Name != "files/abc123" { + t.Errorf("expected file name 'files/abc123', got %q", resp.Files[0].Name) + } + }) + + t.Run("VertexAINotSupported", func(t *testing.T) { + client, err := NewClient(ctx, &ClientConfig{ + Backend: BackendVertexAI, + Credentials: &auth.Credentials{}, + envVarProvider: func() map[string]string { + return map[string]string{ + "GOOGLE_CLOUD_PROJECT": "test-project", + "GOOGLE_CLOUD_LOCATION": "test-location", + } + }, + }) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + + creds := newMockCredentials("test-token") + _, err = client.Files.RegisterFiles(ctx, creds, []string{"gs://bucket/object"}, nil) + if err == nil { + t.Fatal("expected error for Vertex AI, got nil") + } + if !strings.Contains(err.Error(), "only supported in the Gemini Developer client") { + t.Errorf("unexpected error: %v", err) + } + }) + + t.Run("NilCredentials", func(t *testing.T) { + client, err := NewClient(ctx, &ClientConfig{ + envVarProvider: func() map[string]string { + return map[string]string{"GOOGLE_API_KEY": "test-api-key"} + }, + }) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + + _, err = client.Files.RegisterFiles(ctx, nil, []string{"gs://bucket/object"}, nil) + if err == nil { + t.Fatal("expected error for nil credentials, got nil") + } + if !strings.Contains(err.Error(), "credentials are required") { + t.Errorf("unexpected error: %v", err) + } + }) + + t.Run("EmptyURIs", func(t *testing.T) { + client, err := NewClient(ctx, &ClientConfig{ + envVarProvider: func() map[string]string { + return map[string]string{"GOOGLE_API_KEY": "test-api-key"} + }, + }) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + + creds := newMockCredentials("test-token") + _, err = client.Files.RegisterFiles(ctx, creds, []string{}, nil) + if err == nil { + t.Fatal("expected error for empty URIs, got nil") + } + if !strings.Contains(err.Error(), "at least one URI is required") { + t.Errorf("unexpected error: %v", err) + } + }) + + t.Run("MultipleURIs", func(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var body map[string]any + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Fatalf("failed to decode request body: %v", err) + } + uris, ok := body["uris"].([]any) + if !ok || len(uris) != 3 { + t.Errorf("expected 3 uris, got %v", body["uris"]) + } + + resp := map[string]any{ + "files": []map[string]any{ + {"name": "files/file1", "uri": "gs://bucket/obj1"}, + {"name": "files/file2", "uri": "gs://bucket/obj2"}, + {"name": "files/file3", "uri": "gs://bucket/obj3"}, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer ts.Close() + + client, err := NewClient(ctx, &ClientConfig{ + HTTPOptions: HTTPOptions{BaseURL: ts.URL}, + HTTPClient: ts.Client(), + envVarProvider: func() map[string]string { + return map[string]string{"GOOGLE_API_KEY": "test-api-key"} + }, + }) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + + creds := newMockCredentials("test-token") + resp, err := client.Files.RegisterFiles(ctx, creds, []string{"gs://bucket/obj1", "gs://bucket/obj2", "gs://bucket/obj3"}, nil) + if err != nil { + t.Fatalf("RegisterFiles() error: %v", err) + } + if len(resp.Files) != 3 { + t.Fatalf("expected 3 files, got %d", len(resp.Files)) + } + }) + + t.Run("AuthHeaderVerification", func(t *testing.T) { + var receivedAuthHeader string + var receivedQuotaHeader string + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedAuthHeader = r.Header.Get("Authorization") + receivedQuotaHeader = r.Header.Get("X-Goog-User-Project") + + resp := map[string]any{"files": []map[string]any{}} + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer ts.Close() + + client, err := NewClient(ctx, &ClientConfig{ + HTTPOptions: HTTPOptions{BaseURL: ts.URL}, + HTTPClient: ts.Client(), + envVarProvider: func() map[string]string { + return map[string]string{"GOOGLE_API_KEY": "test-api-key"} + }, + }) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + + creds := newMockCredentials("my-secret-token") + _, err = client.Files.RegisterFiles(ctx, creds, []string{"gs://bucket/object"}, nil) + if err != nil { + t.Fatalf("RegisterFiles() error: %v", err) + } + if receivedAuthHeader != "Bearer my-secret-token" { + t.Errorf("expected Authorization 'Bearer my-secret-token', got %q", receivedAuthHeader) + } + // QuotaProjectID from empty credentials returns "" + if receivedQuotaHeader != "" { + t.Errorf("expected empty X-Goog-User-Project header, got %q", receivedQuotaHeader) + } + }) +} diff --git a/types.go b/types.go index f9313c9c..0328d74a 100644 --- a/types.go +++ b/types.go @@ -5332,6 +5332,20 @@ type DeleteFileResponse struct { SDKHTTPResponse *HTTPResponse `json:"sdkHttpResponse,omitempty"` } +// Used to override the default configuration for RegisterFiles. +type RegisterFilesConfig struct { + // Optional. Used to override HTTP request options. + HTTPOptions *HTTPOptions `json:"httpOptions,omitempty"` +} + +// Response for the register files method. +type RegisterFilesResponse struct { + // Optional. Used to retain the full HTTP response. + SDKHTTPResponse *HTTPResponse `json:"sdkHttpResponse,omitempty"` + // The list of registered files. + Files []*File `json:"files,omitempty"` +} + // Config for inlined request. type InlinedRequest struct { // ID of the model to use. For a list of models, see `Google models