Skip to content

Commit d46ec1c

Browse files
committed
add Add Elements tool to dataset sdk
Signed-off-by: Grant Linville <grant@acorn.io>
1 parent e7e9231 commit d46ec1c

File tree

2 files changed

+78
-8
lines changed

2 files changed

+78
-8
lines changed

pkg/sdkserver/datasets.go

Lines changed: 77 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,25 +12,21 @@ import (
1212

1313
type datasetRequest struct {
1414
Input string `json:"input"`
15-
Workspace string `json:"workspace"`
1615
DatasetToolRepo string `json:"datasetToolRepo"`
1716
}
1817

1918
func (r datasetRequest) validate(requireInput bool) error {
20-
if r.Workspace == "" {
21-
return fmt.Errorf("workspace is required")
22-
} else if requireInput && r.Input == "" {
19+
if requireInput && r.Input == "" {
2320
return fmt.Errorf("input is required")
2421
}
2522
return nil
2623
}
2724

2825
func (r datasetRequest) opts(o gptscript.Options) gptscript.Options {
2926
opts := gptscript.Options{
30-
Cache: o.Cache,
31-
Monitor: o.Monitor,
32-
Runner: o.Runner,
33-
Workspace: r.Workspace,
27+
Cache: o.Cache,
28+
Monitor: o.Monitor,
29+
Runner: o.Runner,
3430
}
3531
return opts
3632
}
@@ -209,6 +205,79 @@ func (s *server) addDatasetElement(w http.ResponseWriter, r *http.Request) {
209205
writeResponse(logger, w, map[string]any{"stdout": result})
210206
}
211207

208+
type addDatasetElementsArgs struct {
209+
DatasetID string `json:"datasetID"`
210+
Elements []struct {
211+
Name string `json:"name"`
212+
Description string `json:"description"`
213+
Contents string `json:"contents"`
214+
}
215+
}
216+
217+
func (a addDatasetElementsArgs) validate() error {
218+
if a.DatasetID == "" {
219+
return fmt.Errorf("datasetID is required")
220+
}
221+
if len(a.Elements) == 0 {
222+
return fmt.Errorf("elements is required")
223+
}
224+
return nil
225+
}
226+
227+
func (s *server) addDatasetElements(w http.ResponseWriter, r *http.Request) {
228+
logger := gcontext.GetLogger(r.Context())
229+
230+
var req datasetRequest
231+
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
232+
writeError(logger, w, http.StatusBadRequest, fmt.Errorf("failed to decode request body: %w", err))
233+
return
234+
}
235+
236+
if err := req.validate(true); err != nil {
237+
writeError(logger, w, http.StatusBadRequest, err)
238+
return
239+
}
240+
241+
g, err := gptscript.New(r.Context(), req.opts(s.gptscriptOpts))
242+
if err != nil {
243+
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to initialize gptscript: %w", err))
244+
return
245+
}
246+
247+
var args addDatasetElementsArgs
248+
if err := json.Unmarshal([]byte(req.Input), &args); err != nil {
249+
writeError(logger, w, http.StatusBadRequest, fmt.Errorf("failed to unmarshal input: %w", err))
250+
return
251+
}
252+
253+
if err := args.validate(); err != nil {
254+
writeError(logger, w, http.StatusBadRequest, err)
255+
return
256+
}
257+
258+
prg, err := loader.Program(r.Context(), req.getToolRepo(), "Add Elements", loader.Options{
259+
Cache: g.Cache,
260+
})
261+
if err != nil {
262+
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to load program: %w", err))
263+
return
264+
}
265+
266+
elementsJSON, err := json.Marshal(args.Elements)
267+
if err != nil {
268+
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to marshal elements: %w", err))
269+
return
270+
}
271+
272+
result, err := g.Run(r.Context(), prg, s.gptscriptOpts.Env, fmt.Sprintf(`{"datasetID":%q, elements:%s}`, args.DatasetID, elementsJSON))
273+
if err != nil {
274+
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err))
275+
return
276+
}
277+
278+
writeResponse(logger, w, map[string]any{"stdout": result})
279+
}
280+
212281
type listDatasetElementsArgs struct {
213282
DatasetID string `json:"datasetID"`
214283
}

pkg/sdkserver/routes.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ func (s *server) addRoutes(mux *http.ServeMux) {
7373
mux.HandleFunc("POST /datasets/list-elements", s.listDatasetElements)
7474
mux.HandleFunc("POST /datasets/get-element", s.getDatasetElement)
7575
mux.HandleFunc("POST /datasets/add-element", s.addDatasetElement)
76+
mux.HandleFunc("POST /datasets/add-elements", s.addDatasetElements)
7677

7778
mux.HandleFunc("POST /workspaces/create", s.createWorkspace)
7879
mux.HandleFunc("POST /workspaces/delete", s.deleteWorkspace)

0 commit comments

Comments
 (0)