diff --git a/go/ai/_test_data/prompts/example.prompt b/go/ai/_test_data/prompts/example.prompt new file mode 100644 index 0000000000..c47c70b4fc --- /dev/null +++ b/go/ai/_test_data/prompts/example.prompt @@ -0,0 +1,20 @@ +--- +model: test-model +maxTurns: 5 +description: A test prompt +toolChoice: required +returnToolRequests: true +input: + schema: + type: object + properties: + name: + type: string + default: + name: world +output: + format: text + schema: + type: string +--- +Hello, {{name}}! \ No newline at end of file diff --git a/go/ai/prompt.go b/go/ai/prompt.go index 95214dd7a6..b2932f2720 100644 --- a/go/ai/prompt.go +++ b/go/ai/prompt.go @@ -18,6 +18,7 @@ import ( "context" "errors" "fmt" + "io/fs" "log/slog" "maps" "os" @@ -505,12 +506,17 @@ func LoadPromptDir(r api.Registry, dir string, namespace string) { return } - loadPromptDir(r, path, namespace) + loadPromptDir(r, os.DirFS(dir), ".", namespace) +} + +// LoadPromptFS loads prompts and partials from the given filesystem for the given namespace. +func LoadPromptFS(r api.Registry, fsys fs.FS, dir string, namespace string) { + loadPromptDir(r, fsys, dir, namespace) } // loadPromptDir recursively loads prompts and partials from the directory. -func loadPromptDir(r api.Registry, dir string, namespace string) { - entries, err := os.ReadDir(dir) +func loadPromptDir(r api.Registry, fsys fs.FS, dir, namespace string) { + entries, err := fs.ReadDir(fsys, dir) if err != nil { panic(fmt.Errorf("failed to read prompt directory structure: %w", err)) } @@ -519,7 +525,7 @@ func loadPromptDir(r api.Registry, dir string, namespace string) { filename := entry.Name() path := filepath.Join(dir, filename) if entry.IsDir() { - loadPromptDir(r, path, namespace) + loadPromptDir(r, fsys, path, namespace) } else if strings.HasSuffix(filename, ".prompt") { if strings.HasPrefix(filename, "_") { partialName := strings.TrimSuffix(filename[1:], ".prompt") @@ -531,7 +537,7 @@ func loadPromptDir(r api.Registry, dir string, namespace string) { r.RegisterPartial(partialName, string(source)) slog.Debug("Registered Dotprompt partial", "name", partialName, "file", path) } else { - LoadPrompt(r, dir, filename, namespace) + loadPrompt(r, fsys, dir, filename, namespace) } } } @@ -539,11 +545,17 @@ func loadPromptDir(r api.Registry, dir string, namespace string) { // LoadPrompt loads a single prompt into the registry. func LoadPrompt(r api.Registry, dir, filename, namespace string) Prompt { + dir, rest := filepath.Split(dir) + return loadPrompt(r, os.DirFS(dir), rest, filename, namespace) +} + +// loadPrompt uses provided fsys to load a single prompt into the registry. +func loadPrompt(r api.Registry, fsys fs.FS, dir, filename, namespace string) Prompt { name := strings.TrimSuffix(filename, ".prompt") name, variant, _ := strings.Cut(name, ".") sourceFile := filepath.Join(dir, filename) - source, err := os.ReadFile(sourceFile) + source, err := fs.ReadFile(fsys, sourceFile) if err != nil { slog.Error("Failed to read prompt file", "file", sourceFile, "error", err) return nil diff --git a/go/ai/prompt_test.go b/go/ai/prompt_test.go index be1f6c1354..008b0d811c 100644 --- a/go/ai/prompt_test.go +++ b/go/ai/prompt_test.go @@ -16,6 +16,7 @@ package ai import ( "context" + "embed" "fmt" "os" "path/filepath" @@ -29,6 +30,9 @@ import ( "github.com/google/go-cmp/cmp/cmpopts" ) +//go:embed _test_data/prompts +var embededPrompts embed.FS + type InputOutput struct { Text string `json:"text"` } @@ -877,6 +881,15 @@ func assertResponse(t *testing.T, resp *ModelResponse, want string) { } } +func TestLoadPrompt_FromFS(t *testing.T) { + reg := registry.New() + LoadPromptFS(reg, embededPrompts, "_test_data/prompts", "test-namespace") + prompt := LookupPrompt(reg, "test-namespace/example") + if prompt == nil { + t.Fatalf("Prompt was not registered") + } +} + func TestLoadPrompt(t *testing.T) { // Create a temporary directory for testing tempDir := t.TempDir() diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index 83cd33afeb..2ec88fd036 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -21,6 +21,7 @@ import ( "context" "errors" "fmt" + "io/fs" "log/slog" "os" "os/signal" @@ -46,6 +47,7 @@ type Genkit struct { type genkitOptions struct { DefaultModel string // Default model to use if no other model is specified. PromptDir string // Directory where dotprompts are stored. Will be loaded automatically on initialization. + PromptFS fs.FS // Filesystem that will be used for PromptDir lookup. Plugins []api.Plugin // Plugin to initialize automatically. } @@ -69,6 +71,13 @@ func (o *genkitOptions) apply(gOpts *genkitOptions) error { gOpts.PromptDir = o.PromptDir } + if o.PromptFS != nil { + if gOpts.PromptFS != nil { + return errors.New("cannot set prompt filesystem more than once (WithPromptFS)") + } + gOpts.PromptFS = o.PromptFS + } + if len(o.Plugins) > 0 { if gOpts.Plugins != nil { return errors.New("cannot set plugins more than once (WithPlugins)") @@ -106,6 +115,12 @@ func WithPromptDir(dir string) GenkitOption { return &genkitOptions{PromptDir: dir} } +// WithPromptFS is a more generic version of `WithPromptDir` and accepts a filesytem +// instead of directory path +func WithPromptFS(fsys fs.FS) GenkitOption { + return &genkitOptions{PromptFS: fsys} +} + // Init creates and initializes a new [Genkit] instance with the provided options. // It sets up the registry, initializes plugins ([WithPlugins]), loads prompts // ([WithPromptDir]), and configures other settings like the default model @@ -184,7 +199,11 @@ func Init(ctx context.Context, opts ...GenkitOption) *Genkit { ai.ConfigureFormats(r) ai.DefineGenerateAction(ctx, r) - ai.LoadPromptDir(r, gOpts.PromptDir, "") + if gOpts.PromptFS == nil { + ai.LoadPromptDir(r, gOpts.PromptDir, "") + } else { + ai.LoadPromptFS(r, gOpts.PromptFS, gOpts.PromptDir, "") + } r.RegisterValue(api.DefaultModelKey, gOpts.DefaultModel) r.RegisterValue(api.PromptDirKey, gOpts.PromptDir) diff --git a/go/samples/prompts-dir/main.go b/go/samples/prompts-dir/main.go index 59e5e83843..a5e15113bf 100644 --- a/go/samples/prompts-dir/main.go +++ b/go/samples/prompts-dir/main.go @@ -6,6 +6,7 @@ package main import ( "context" + "embed" "errors" // Import Genkit and the Google AI plugin @@ -14,12 +15,17 @@ import ( "github.com/firebase/genkit/go/plugins/googlegenai" ) +//go:embed prompts +var prompts embed.FS + func main() { ctx := context.Background() g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{}), genkit.WithPromptDir("prompts"), + // Without it OS's filesystem will be used + genkit.WithPromptFS(prompts), ) type greetingStyle struct {