Skip to content

Commit 6515d93

Browse files
committed
chore: sdkserver: update dataset methods for the rewrite
Signed-off-by: Grant Linville <grant@acorn.io>
1 parent 4ce687f commit 6515d93

File tree

2 files changed

+49
-162
lines changed

2 files changed

+49
-162
lines changed

pkg/sdkserver/datasets.go

Lines changed: 49 additions & 160 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,13 @@ import (
1111
)
1212

1313
type datasetRequest struct {
14-
Input string `json:"input"`
15-
WorkspaceID string `json:"workspaceID"`
16-
DatasetToolRepo string `json:"datasetToolRepo"`
17-
Env []string `json:"env"`
14+
Input string `json:"input"`
15+
DatasetTool string `json:"datasetTool"`
16+
Env []string `json:"env"`
1817
}
1918

20-
func (r datasetRequest) validate(requireInput bool) error {
21-
if r.WorkspaceID == "" {
22-
return fmt.Errorf("workspaceID is required")
23-
} else if requireInput && r.Input == "" {
19+
func (r datasetRequest) validate() error {
20+
if r.Input == "" {
2421
return fmt.Errorf("input is required")
2522
} else if len(r.Env) == 0 {
2623
return fmt.Errorf("env is required")
@@ -30,72 +27,32 @@ func (r datasetRequest) validate(requireInput bool) error {
3027

3128
func (r datasetRequest) opts(o gptscript.Options) gptscript.Options {
3229
opts := gptscript.Options{
33-
Cache: o.Cache,
34-
Monitor: o.Monitor,
35-
Runner: o.Runner,
36-
Workspace: r.WorkspaceID,
30+
Cache: o.Cache,
31+
Monitor: o.Monitor,
32+
Runner: o.Runner,
3733
}
3834
return opts
3935
}
4036

4137
func (r datasetRequest) getToolRepo() string {
42-
if r.DatasetToolRepo != "" {
43-
return r.DatasetToolRepo
38+
if r.DatasetTool != "" {
39+
return r.DatasetTool
4440
}
45-
return "github.com/otto8-ai/datasets"
41+
return "github.com/g-linville/datasets@rewrite-as-daemon"
4642
}
4743

48-
func (s *server) listDatasets(w http.ResponseWriter, r *http.Request) {
49-
logger := gcontext.GetLogger(r.Context())
50-
51-
var req datasetRequest
52-
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
53-
writeError(logger, w, http.StatusBadRequest, fmt.Errorf("failed to decode request body: %w", err))
54-
return
55-
}
56-
57-
if err := req.validate(false); err != nil {
58-
writeError(logger, w, http.StatusBadRequest, err)
59-
return
60-
}
61-
62-
g, err := gptscript.New(r.Context(), req.opts(s.gptscriptOpts))
63-
if err != nil {
64-
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to initialize gptscript: %w", err))
65-
return
66-
}
67-
68-
prg, err := loader.Program(r.Context(), req.getToolRepo(), "List Datasets", loader.Options{
69-
Cache: g.Cache,
70-
})
71-
72-
if err != nil {
73-
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to load program: %w", err))
74-
return
75-
}
76-
77-
result, err := g.Run(r.Context(), prg, req.Env, req.Input)
78-
if err != nil {
79-
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err))
80-
return
81-
}
82-
83-
writeResponse(logger, w, map[string]any{"stdout": result})
84-
}
85-
86-
type createDatasetArgs struct {
87-
Name string `json:"datasetName"`
88-
Description string `json:"datasetDescription"`
44+
type listDatasetsArgs struct {
45+
WorkspaceID string `json:"workspaceID"`
8946
}
9047

91-
func (a createDatasetArgs) validate() error {
92-
if a.Name == "" {
93-
return fmt.Errorf("datasetName is required")
48+
func (a listDatasetsArgs) validate() error {
49+
if a.WorkspaceID == "" {
50+
return fmt.Errorf("workspaceID is required")
9451
}
9552
return nil
9653
}
9754

98-
func (s *server) createDataset(w http.ResponseWriter, r *http.Request) {
55+
func (s *server) listDatasets(w http.ResponseWriter, r *http.Request) {
9956
logger := gcontext.GetLogger(r.Context())
10057

10158
var req datasetRequest
@@ -104,7 +61,7 @@ func (s *server) createDataset(w http.ResponseWriter, r *http.Request) {
10461
return
10562
}
10663

107-
if err := req.validate(true); err != nil {
64+
if err := req.validate(); err != nil {
10865
writeError(logger, w, http.StatusBadRequest, err)
10966
return
11067
}
@@ -115,7 +72,7 @@ func (s *server) createDataset(w http.ResponseWriter, r *http.Request) {
11572
return
11673
}
11774

118-
var args createDatasetArgs
75+
var args listDatasetsArgs
11976
if err := json.Unmarshal([]byte(req.Input), &args); err != nil {
12077
writeError(logger, w, http.StatusBadRequest, fmt.Errorf("failed to unmarshal input: %w", err))
12178
return
@@ -126,7 +83,7 @@ func (s *server) createDataset(w http.ResponseWriter, r *http.Request) {
12683
return
12784
}
12885

129-
prg, err := loader.Program(r.Context(), req.getToolRepo(), "Create Dataset", loader.Options{
86+
prg, err := loader.Program(r.Context(), req.getToolRepo(), "List Datasets", loader.Options{
13087
Cache: g.Cache,
13188
})
13289

@@ -144,88 +101,21 @@ func (s *server) createDataset(w http.ResponseWriter, r *http.Request) {
144101
writeResponse(logger, w, map[string]any{"stdout": result})
145102
}
146103

147-
type addDatasetElementArgs struct {
148-
DatasetID string `json:"datasetID"`
149-
ElementName string `json:"elementName"`
150-
ElementDescription string `json:"elementDescription"`
151-
ElementContent string `json:"elementContent"`
152-
}
153-
154-
func (a addDatasetElementArgs) validate() error {
155-
if a.DatasetID == "" {
156-
return fmt.Errorf("datasetID is required")
157-
}
158-
if a.ElementName == "" {
159-
return fmt.Errorf("elementName is required")
160-
}
161-
if a.ElementContent == "" {
162-
return fmt.Errorf("elementContent is required")
163-
}
164-
return nil
165-
}
166-
167-
func (s *server) addDatasetElement(w http.ResponseWriter, r *http.Request) {
168-
logger := gcontext.GetLogger(r.Context())
169-
170-
var req datasetRequest
171-
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
172-
writeError(logger, w, http.StatusBadRequest, fmt.Errorf("failed to decode request body: %w", err))
173-
return
174-
}
175-
176-
if err := req.validate(true); err != nil {
177-
writeError(logger, w, http.StatusBadRequest, err)
178-
return
179-
}
180-
181-
g, err := gptscript.New(r.Context(), req.opts(s.gptscriptOpts))
182-
if err != nil {
183-
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to initialize gptscript: %w", err))
184-
return
185-
}
186-
187-
var args addDatasetElementArgs
188-
if err := json.Unmarshal([]byte(req.Input), &args); err != nil {
189-
writeError(logger, w, http.StatusBadRequest, fmt.Errorf("failed to unmarshal input: %w", err))
190-
return
191-
}
192-
193-
if err := args.validate(); err != nil {
194-
writeError(logger, w, http.StatusBadRequest, err)
195-
return
196-
}
197-
198-
prg, err := loader.Program(r.Context(), req.getToolRepo(), "Add Element", loader.Options{
199-
Cache: g.Cache,
200-
})
201-
if err != nil {
202-
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to load program: %w", err))
203-
return
204-
}
205-
206-
result, err := g.Run(r.Context(), prg, req.Env, req.Input)
207-
if err != nil {
208-
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err))
209-
return
210-
}
211-
212-
writeResponse(logger, w, map[string]any{"stdout": result})
213-
}
214-
215104
type addDatasetElementsArgs struct {
216-
DatasetID string `json:"datasetID"`
217-
Elements []struct {
218-
Name string `json:"name"`
219-
Description string `json:"description"`
220-
Contents string `json:"contents"`
221-
}
105+
WorkspaceID string `json:"workspaceID"`
106+
DatasetID string `json:"datasetID"`
107+
Elements []struct {
108+
Name string `json:"name"`
109+
Description string `json:"description"`
110+
Contents string `json:"contents"`
111+
BinaryContents []byte `json:"binaryContents"`
112+
} `json:"elements"`
222113
}
223114

224115
func (a addDatasetElementsArgs) validate() error {
225-
if a.DatasetID == "" {
226-
return fmt.Errorf("datasetID is required")
227-
}
228-
if len(a.Elements) == 0 {
116+
if a.WorkspaceID == "" {
117+
return fmt.Errorf("workspaceID is required")
118+
} else if len(a.Elements) == 0 {
229119
return fmt.Errorf("elements is required")
230120
}
231121
return nil
@@ -240,7 +130,7 @@ func (s *server) addDatasetElements(w http.ResponseWriter, r *http.Request) {
240130
return
241131
}
242132

243-
if err := req.validate(true); err != nil {
133+
if err := req.validate(); err != nil {
244134
writeError(logger, w, http.StatusBadRequest, err)
245135
return
246136
}
@@ -270,13 +160,7 @@ func (s *server) addDatasetElements(w http.ResponseWriter, r *http.Request) {
270160
return
271161
}
272162

273-
elementsJSON, err := json.Marshal(args.Elements)
274-
if err != nil {
275-
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to marshal elements: %w", err))
276-
return
277-
}
278-
279-
result, err := g.Run(r.Context(), prg, req.Env, fmt.Sprintf(`{"datasetID":%q, "elements":%q}`, args.DatasetID, string(elementsJSON)))
163+
result, err := g.Run(r.Context(), prg, req.Env, req.Input)
280164
if err != nil {
281165
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err))
282166
return
@@ -286,11 +170,14 @@ func (s *server) addDatasetElements(w http.ResponseWriter, r *http.Request) {
286170
}
287171

288172
type listDatasetElementsArgs struct {
289-
DatasetID string `json:"datasetID"`
173+
WorkspaceID string `json:"workspaceID"`
174+
DatasetID string `json:"datasetID"`
290175
}
291176

292177
func (a listDatasetElementsArgs) validate() error {
293-
if a.DatasetID == "" {
178+
if a.WorkspaceID == "" {
179+
return fmt.Errorf("workspaceID is required")
180+
} else if a.DatasetID == "" {
294181
return fmt.Errorf("datasetID is required")
295182
}
296183
return nil
@@ -305,7 +192,7 @@ func (s *server) listDatasetElements(w http.ResponseWriter, r *http.Request) {
305192
return
306193
}
307194

308-
if err := req.validate(true); err != nil {
195+
if err := req.validate(); err != nil {
309196
writeError(logger, w, http.StatusBadRequest, err)
310197
return
311198
}
@@ -345,16 +232,18 @@ func (s *server) listDatasetElements(w http.ResponseWriter, r *http.Request) {
345232
}
346233

347234
type getDatasetElementArgs struct {
348-
DatasetID string `json:"datasetID"`
349-
Element string `json:"element"`
235+
WorkspaceID string `json:"workspaceID"`
236+
DatasetID string `json:"datasetID"`
237+
Name string `json:"name"`
350238
}
351239

352240
func (a getDatasetElementArgs) validate() error {
353-
if a.DatasetID == "" {
241+
if a.WorkspaceID == "" {
242+
return fmt.Errorf("workspaceID is required")
243+
} else if a.DatasetID == "" {
354244
return fmt.Errorf("datasetID is required")
355-
}
356-
if a.Element == "" {
357-
return fmt.Errorf("element is required")
245+
} else if a.Name == "" {
246+
return fmt.Errorf("name is required")
358247
}
359248
return nil
360249
}
@@ -368,7 +257,7 @@ func (s *server) getDatasetElement(w http.ResponseWriter, r *http.Request) {
368257
return
369258
}
370259

371-
if err := req.validate(true); err != nil {
260+
if err := req.validate(); err != nil {
372261
writeError(logger, w, http.StatusBadRequest, err)
373262
return
374263
}
@@ -390,7 +279,7 @@ func (s *server) getDatasetElement(w http.ResponseWriter, r *http.Request) {
390279
return
391280
}
392281

393-
prg, err := loader.Program(r.Context(), req.getToolRepo(), "Get Element SDK", loader.Options{
282+
prg, err := loader.Program(r.Context(), req.getToolRepo(), "Get Element", loader.Options{
394283
Cache: g.Cache,
395284
})
396285
if err != nil {

pkg/sdkserver/routes.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,8 @@ func (s *server) addRoutes(mux *http.ServeMux) {
6969
mux.HandleFunc("POST /credentials/delete", s.deleteCredential)
7070

7171
mux.HandleFunc("POST /datasets", s.listDatasets)
72-
mux.HandleFunc("POST /datasets/create", s.createDataset)
7372
mux.HandleFunc("POST /datasets/list-elements", s.listDatasetElements)
7473
mux.HandleFunc("POST /datasets/get-element", s.getDatasetElement)
75-
mux.HandleFunc("POST /datasets/add-element", s.addDatasetElement)
7674
mux.HandleFunc("POST /datasets/add-elements", s.addDatasetElements)
7775

7876
mux.HandleFunc("POST /workspaces/create", s.createWorkspace)

0 commit comments

Comments
 (0)