diff --git a/pkg/sdkserver/datasets.go b/pkg/sdkserver/datasets.go index a65566a4..bfef595a 100644 --- a/pkg/sdkserver/datasets.go +++ b/pkg/sdkserver/datasets.go @@ -11,16 +11,19 @@ import ( ) type datasetRequest struct { - Input string `json:"input"` - Workspace string `json:"workspace"` - DatasetToolRepo string `json:"datasetToolRepo"` + Input string `json:"input"` + WorkspaceID string `json:"workspaceID"` + DatasetToolRepo string `json:"datasetToolRepo"` + Env []string `json:"env"` } func (r datasetRequest) validate(requireInput bool) error { - if r.Workspace == "" { - return fmt.Errorf("workspace is required") + if r.WorkspaceID == "" { + return fmt.Errorf("workspaceID is required") } else if requireInput && r.Input == "" { return fmt.Errorf("input is required") + } else if len(r.Env) == 0 { + return fmt.Errorf("env is required") } return nil } @@ -30,7 +33,7 @@ func (r datasetRequest) opts(o gptscript.Options) gptscript.Options { Cache: o.Cache, Monitor: o.Monitor, Runner: o.Runner, - Workspace: r.Workspace, + Workspace: r.WorkspaceID, } return opts } @@ -39,7 +42,7 @@ func (r datasetRequest) getToolRepo() string { if r.DatasetToolRepo != "" { return r.DatasetToolRepo } - return "github.com/gptscript-ai/datasets" + return "github.com/otto8-ai/datasets" } func (s *server) listDatasets(w http.ResponseWriter, r *http.Request) { @@ -71,7 +74,7 @@ func (s *server) listDatasets(w http.ResponseWriter, r *http.Request) { return } - result, err := g.Run(r.Context(), prg, s.gptscriptOpts.Env, req.Input) + result, err := g.Run(r.Context(), prg, req.Env, req.Input) if err != nil { writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err)) return @@ -132,7 +135,7 @@ func (s *server) createDataset(w http.ResponseWriter, r *http.Request) { return } - result, err := g.Run(r.Context(), prg, s.gptscriptOpts.Env, req.Input) + result, err := g.Run(r.Context(), prg, req.Env, req.Input) if err != nil { writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err)) return @@ -200,7 +203,80 @@ func (s *server) addDatasetElement(w http.ResponseWriter, r *http.Request) { return } - result, err := g.Run(r.Context(), prg, s.gptscriptOpts.Env, req.Input) + result, err := g.Run(r.Context(), prg, req.Env, req.Input) + if err != nil { + writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err)) + return + } + + writeResponse(logger, w, map[string]any{"stdout": result}) +} + +type addDatasetElementsArgs struct { + DatasetID string `json:"datasetID"` + Elements []struct { + Name string `json:"name"` + Description string `json:"description"` + Contents string `json:"contents"` + } +} + +func (a addDatasetElementsArgs) validate() error { + if a.DatasetID == "" { + return fmt.Errorf("datasetID is required") + } + if len(a.Elements) == 0 { + return fmt.Errorf("elements is required") + } + return nil +} + +func (s *server) addDatasetElements(w http.ResponseWriter, r *http.Request) { + logger := gcontext.GetLogger(r.Context()) + + var req datasetRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeError(logger, w, http.StatusBadRequest, fmt.Errorf("failed to decode request body: %w", err)) + return + } + + if err := req.validate(true); err != nil { + writeError(logger, w, http.StatusBadRequest, err) + return + } + + g, err := gptscript.New(r.Context(), req.opts(s.gptscriptOpts)) + if err != nil { + writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to initialize gptscript: %w", err)) + return + } + + var args addDatasetElementsArgs + if err := json.Unmarshal([]byte(req.Input), &args); err != nil { + writeError(logger, w, http.StatusBadRequest, fmt.Errorf("failed to unmarshal input: %w", err)) + return + } + + if err := args.validate(); err != nil { + writeError(logger, w, http.StatusBadRequest, err) + return + } + + prg, err := loader.Program(r.Context(), req.getToolRepo(), "Add Elements", loader.Options{ + Cache: g.Cache, + }) + if err != nil { + writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to load program: %w", err)) + return + } + + elementsJSON, err := json.Marshal(args.Elements) + if err != nil { + writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to marshal elements: %w", err)) + return + } + + result, err := g.Run(r.Context(), prg, req.Env, fmt.Sprintf(`{"datasetID":%q, "elements":%q}`, args.DatasetID, string(elementsJSON))) if err != nil { writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err)) return @@ -259,7 +335,7 @@ func (s *server) listDatasetElements(w http.ResponseWriter, r *http.Request) { return } - result, err := g.Run(r.Context(), prg, s.gptscriptOpts.Env, req.Input) + result, err := g.Run(r.Context(), prg, req.Env, req.Input) if err != nil { writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err)) return @@ -322,7 +398,7 @@ func (s *server) getDatasetElement(w http.ResponseWriter, r *http.Request) { return } - result, err := g.Run(r.Context(), prg, s.gptscriptOpts.Env, req.Input) + result, err := g.Run(r.Context(), prg, req.Env, req.Input) if err != nil { writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err)) return diff --git a/pkg/sdkserver/routes.go b/pkg/sdkserver/routes.go index 8afdb8a4..713f74fe 100644 --- a/pkg/sdkserver/routes.go +++ b/pkg/sdkserver/routes.go @@ -73,6 +73,7 @@ func (s *server) addRoutes(mux *http.ServeMux) { mux.HandleFunc("POST /datasets/list-elements", s.listDatasetElements) mux.HandleFunc("POST /datasets/get-element", s.getDatasetElement) mux.HandleFunc("POST /datasets/add-element", s.addDatasetElement) + mux.HandleFunc("POST /datasets/add-elements", s.addDatasetElements) mux.HandleFunc("POST /workspaces/create", s.createWorkspace) mux.HandleFunc("POST /workspaces/delete", s.deleteWorkspace)