@@ -12,13 +12,16 @@ import (
1212
1313type datasetRequest struct {
1414 Input string `json:"input"`
15+ WorkspaceID string `json:"workspaceID"`
1516 DatasetTool string `json:"datasetTool"`
1617 Env []string `json:"env"`
1718}
1819
19- func (r datasetRequest ) validate () error {
20- if r .Input == "" {
20+ func (r datasetRequest ) validate (requireInput bool ) error {
21+ if requireInput && r .Input == "" {
2122 return fmt .Errorf ("input is required" )
23+ } else if r .WorkspaceID == "" {
24+ return fmt .Errorf ("workspaceID is required" )
2225 } else if len (r .Env ) == 0 {
2326 return fmt .Errorf ("env is required" )
2427 }
@@ -27,9 +30,10 @@ func (r datasetRequest) validate() error {
2730
2831func (r datasetRequest ) opts (o gptscript.Options ) gptscript.Options {
2932 opts := gptscript.Options {
30- Cache : o .Cache ,
31- Monitor : o .Monitor ,
32- Runner : o .Runner ,
33+ Cache : o .Cache ,
34+ Monitor : o .Monitor ,
35+ Runner : o .Runner ,
36+ Workspace : r .WorkspaceID ,
3337 }
3438 return opts
3539}
@@ -41,17 +45,6 @@ func (r datasetRequest) getToolRepo() string {
4145 return "github.com/otto8-ai/datasets"
4246}
4347
44- type listDatasetsArgs struct {
45- WorkspaceID string `json:"workspaceID"`
46- }
47-
48- func (a listDatasetsArgs ) validate () error {
49- if a .WorkspaceID == "" {
50- return fmt .Errorf ("workspaceID is required" )
51- }
52- return nil
53- }
54-
5548func (s * server ) listDatasets (w http.ResponseWriter , r * http.Request ) {
5649 logger := gcontext .GetLogger (r .Context ())
5750
@@ -61,7 +54,7 @@ func (s *server) listDatasets(w http.ResponseWriter, r *http.Request) {
6154 return
6255 }
6356
64- if err := req .validate (); err != nil {
57+ if err := req .validate (false ); err != nil {
6558 writeError (logger , w , http .StatusBadRequest , err )
6659 return
6760 }
@@ -72,17 +65,6 @@ func (s *server) listDatasets(w http.ResponseWriter, r *http.Request) {
7265 return
7366 }
7467
75- var args listDatasetsArgs
76- if err := json .Unmarshal ([]byte (req .Input ), & args ); err != nil {
77- writeError (logger , w , http .StatusBadRequest , fmt .Errorf ("failed to unmarshal input: %w" , err ))
78- return
79- }
80-
81- if err := args .validate (); err != nil {
82- writeError (logger , w , http .StatusBadRequest , err )
83- return
84- }
85-
8668 prg , err := loader .Program (r .Context (), req .getToolRepo (), "List Datasets" , loader.Options {
8769 Cache : g .Cache ,
8870 })
@@ -102,9 +84,8 @@ func (s *server) listDatasets(w http.ResponseWriter, r *http.Request) {
10284}
10385
10486type addDatasetElementsArgs struct {
105- WorkspaceID string `json:"workspaceID"`
106- DatasetID string `json:"datasetID"`
107- Elements []struct {
87+ DatasetID string `json:"datasetID"`
88+ Elements []struct {
10889 Name string `json:"name"`
10990 Description string `json:"description"`
11091 Contents string `json:"contents"`
@@ -113,9 +94,7 @@ type addDatasetElementsArgs struct {
11394}
11495
11596func (a addDatasetElementsArgs ) validate () error {
116- if a .WorkspaceID == "" {
117- return fmt .Errorf ("workspaceID is required" )
118- } else if len (a .Elements ) == 0 {
97+ if len (a .Elements ) == 0 {
11998 return fmt .Errorf ("elements is required" )
12099 }
121100 return nil
@@ -130,7 +109,7 @@ func (s *server) addDatasetElements(w http.ResponseWriter, r *http.Request) {
130109 return
131110 }
132111
133- if err := req .validate (); err != nil {
112+ if err := req .validate (true ); err != nil {
134113 writeError (logger , w , http .StatusBadRequest , err )
135114 return
136115 }
@@ -170,14 +149,11 @@ func (s *server) addDatasetElements(w http.ResponseWriter, r *http.Request) {
170149}
171150
172151type listDatasetElementsArgs struct {
173- WorkspaceID string `json:"workspaceID"`
174- DatasetID string `json:"datasetID"`
152+ DatasetID string `json:"datasetID"`
175153}
176154
177155func (a listDatasetElementsArgs ) validate () error {
178- if a .WorkspaceID == "" {
179- return fmt .Errorf ("workspaceID is required" )
180- } else if a .DatasetID == "" {
156+ if a .DatasetID == "" {
181157 return fmt .Errorf ("datasetID is required" )
182158 }
183159 return nil
@@ -192,7 +168,7 @@ func (s *server) listDatasetElements(w http.ResponseWriter, r *http.Request) {
192168 return
193169 }
194170
195- if err := req .validate (); err != nil {
171+ if err := req .validate (true ); err != nil {
196172 writeError (logger , w , http .StatusBadRequest , err )
197173 return
198174 }
@@ -232,15 +208,12 @@ func (s *server) listDatasetElements(w http.ResponseWriter, r *http.Request) {
232208}
233209
234210type getDatasetElementArgs struct {
235- WorkspaceID string `json:"workspaceID"`
236- DatasetID string `json:"datasetID"`
237- Name string `json:"name"`
211+ DatasetID string `json:"datasetID"`
212+ Name string `json:"name"`
238213}
239214
240215func (a getDatasetElementArgs ) validate () error {
241- if a .WorkspaceID == "" {
242- return fmt .Errorf ("workspaceID is required" )
243- } else if a .DatasetID == "" {
216+ if a .DatasetID == "" {
244217 return fmt .Errorf ("datasetID is required" )
245218 } else if a .Name == "" {
246219 return fmt .Errorf ("name is required" )
@@ -257,7 +230,7 @@ func (s *server) getDatasetElement(w http.ResponseWriter, r *http.Request) {
257230 return
258231 }
259232
260- if err := req .validate (); err != nil {
233+ if err := req .validate (true ); err != nil {
261234 writeError (logger , w , http .StatusBadRequest , err )
262235 return
263236 }
0 commit comments