Skip to content

Commit ce74515

Browse files
committed
feat: add dataset functions
Signed-off-by: Grant Linville <grant@acorn.io>
1 parent 2af5143 commit ce74515

File tree

4 files changed

+264
-0
lines changed

4 files changed

+264
-0
lines changed

datasets.go

+50
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
package gptscript
2+
3+
type DatasetElementMeta struct {
4+
Name string `json:"name"`
5+
Description string `json:"description"`
6+
}
7+
8+
type DatasetElement struct {
9+
DatasetElementMeta `json:",inline"`
10+
Contents string `json:"contents"`
11+
}
12+
13+
type DatasetMeta struct {
14+
ID string `json:"id"`
15+
Name string `json:"name"`
16+
Description string `json:"description"`
17+
}
18+
19+
type Dataset struct {
20+
DatasetMeta `json:",inline"`
21+
BaseDir string `json:"baseDir,omitempty"`
22+
Elements map[string]DatasetElement `json:"elements"`
23+
}
24+
25+
type datasetRequest struct {
26+
Input string `json:"input"`
27+
Workspace string `json:"workspace"`
28+
DatasetToolRepo string `json:"datasetToolRepo"`
29+
}
30+
31+
type createDatasetArgs struct {
32+
Name string `json:"dataset_name"`
33+
Description string `json:"dataset_description"`
34+
}
35+
36+
type addDatasetElementArgs struct {
37+
DatasetID string `json:"dataset_id"`
38+
ElementName string `json:"element_name"`
39+
ElementDescription string `json:"element_description"`
40+
ElementContent string `json:"element_content"`
41+
}
42+
43+
type listDatasetElementArgs struct {
44+
DatasetID string `json:"dataset_id"`
45+
}
46+
47+
type getDatasetElementArgs struct {
48+
DatasetID string `json:"dataset_id"`
49+
Element string `json:"element"`
50+
}

gptscript.go

+170
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,176 @@ func (g *GPTScript) DeleteCredential(ctx context.Context, credCtx, name string)
388388
return err
389389
}
390390

391+
// Dataset methods
392+
393+
func (g *GPTScript) ListDatasets(ctx context.Context, workspace string) ([]DatasetMeta, error) {
394+
if workspace == "" {
395+
workspace = os.Getenv("GPTSCRIPT_WORKSPACE_DIR")
396+
}
397+
398+
out, err := g.runBasicCommand(ctx, "datasets", datasetRequest{
399+
Input: "{}",
400+
Workspace: workspace,
401+
DatasetToolRepo: g.globalOpts.DatasetToolRepo,
402+
})
403+
404+
if err != nil {
405+
return nil, err
406+
}
407+
408+
if strings.HasPrefix(out, "ERROR:") {
409+
return nil, fmt.Errorf(out)
410+
}
411+
412+
var datasets []DatasetMeta
413+
if err = json.Unmarshal([]byte(out), &datasets); err != nil {
414+
return nil, err
415+
}
416+
return datasets, nil
417+
}
418+
419+
func (g *GPTScript) CreateDataset(ctx context.Context, workspace, name, description string) (Dataset, error) {
420+
if workspace == "" {
421+
workspace = os.Getenv("GPTSCRIPT_WORKSPACE_DIR")
422+
}
423+
424+
args := createDatasetArgs{
425+
Name: name,
426+
Description: description,
427+
}
428+
argsJSON, err := json.Marshal(args)
429+
if err != nil {
430+
return Dataset{}, fmt.Errorf("failed to marshal dataset args: %w", err)
431+
}
432+
433+
out, err := g.runBasicCommand(ctx, "datasets/create", datasetRequest{
434+
Input: string(argsJSON),
435+
Workspace: workspace,
436+
DatasetToolRepo: g.globalOpts.DatasetToolRepo,
437+
})
438+
439+
if err != nil {
440+
return Dataset{}, err
441+
}
442+
443+
if strings.HasPrefix(out, "ERROR:") {
444+
return Dataset{}, fmt.Errorf(out)
445+
}
446+
447+
var dataset Dataset
448+
if err = json.Unmarshal([]byte(out), &dataset); err != nil {
449+
return Dataset{}, err
450+
}
451+
return dataset, nil
452+
}
453+
454+
func (g *GPTScript) AddDatasetElement(ctx context.Context, workspace, datasetID, elementName, elementDescription, elementContent string) (DatasetElementMeta, error) {
455+
if workspace == "" {
456+
workspace = os.Getenv("GPTSCRIPT_WORKSPACE_DIR")
457+
}
458+
459+
args := addDatasetElementArgs{
460+
DatasetID: datasetID,
461+
ElementName: elementName,
462+
ElementDescription: elementDescription,
463+
ElementContent: elementContent,
464+
}
465+
argsJSON, err := json.Marshal(args)
466+
if err != nil {
467+
return DatasetElementMeta{}, fmt.Errorf("failed to marshal element args: %w", err)
468+
}
469+
470+
out, err := g.runBasicCommand(ctx, "datasets/add-element", datasetRequest{
471+
Input: string(argsJSON),
472+
Workspace: workspace,
473+
DatasetToolRepo: g.globalOpts.DatasetToolRepo,
474+
})
475+
476+
if err != nil {
477+
return DatasetElementMeta{}, err
478+
}
479+
480+
if strings.HasPrefix(out, "ERROR:") {
481+
return DatasetElementMeta{}, fmt.Errorf(out)
482+
}
483+
484+
var element DatasetElementMeta
485+
if err = json.Unmarshal([]byte(out), &element); err != nil {
486+
return DatasetElementMeta{}, err
487+
}
488+
return element, nil
489+
}
490+
491+
func (g *GPTScript) ListDatasetElements(ctx context.Context, workspace, datasetID string) ([]DatasetElementMeta, error) {
492+
if workspace == "" {
493+
workspace = os.Getenv("GPTSCRIPT_WORKSPACE_DIR")
494+
}
495+
496+
args := listDatasetElementArgs{
497+
DatasetID: datasetID,
498+
}
499+
argsJSON, err := json.Marshal(args)
500+
if err != nil {
501+
return nil, fmt.Errorf("failed to marshal element args: %w", err)
502+
}
503+
504+
out, err := g.runBasicCommand(ctx, "datasets/list-elements", datasetRequest{
505+
Input: string(argsJSON),
506+
Workspace: workspace,
507+
DatasetToolRepo: g.globalOpts.DatasetToolRepo,
508+
})
509+
510+
if err != nil {
511+
return nil, err
512+
}
513+
514+
if strings.HasPrefix(out, "ERROR:") {
515+
return nil, fmt.Errorf(out)
516+
}
517+
518+
var elements []DatasetElementMeta
519+
if err = json.Unmarshal([]byte(out), &elements); err != nil {
520+
return nil, err
521+
}
522+
return elements, nil
523+
}
524+
525+
func (g *GPTScript) GetDatasetElement(ctx context.Context, workspace, datasetID, elementName string) (DatasetElement, error) {
526+
if workspace == "" {
527+
workspace = os.Getenv("GPTSCRIPT_WORKSPACE_DIR")
528+
}
529+
530+
args := getDatasetElementArgs{
531+
DatasetID: datasetID,
532+
Element: elementName,
533+
}
534+
argsJSON, err := json.Marshal(args)
535+
if err != nil {
536+
return DatasetElement{}, fmt.Errorf("failed to marshal element args: %w", err)
537+
}
538+
539+
out, err := g.runBasicCommand(ctx, "datasets/get-element", datasetRequest{
540+
Input: string(argsJSON),
541+
Workspace: workspace,
542+
DatasetToolRepo: g.globalOpts.DatasetToolRepo,
543+
})
544+
545+
if err != nil {
546+
return DatasetElement{}, err
547+
}
548+
549+
if strings.HasPrefix(out, "ERROR:") {
550+
return DatasetElement{}, fmt.Errorf(out)
551+
}
552+
553+
var element DatasetElement
554+
if err = json.Unmarshal([]byte(out), &element); err != nil {
555+
return DatasetElement{}, err
556+
}
557+
558+
return element, nil
559+
}
560+
391561
func (g *GPTScript) runBasicCommand(ctx context.Context, requestPath string, body any) (string, error) {
392562
run := &Run{
393563
url: g.globalOpts.URL,

gptscript_test.go

+43
Original file line numberDiff line numberDiff line change
@@ -1560,3 +1560,46 @@ func TestCredentials(t *testing.T) {
15601560
require.Error(t, err)
15611561
require.True(t, errors.As(err, &ErrNotFound{}))
15621562
}
1563+
1564+
func TestDatasets(t *testing.T) {
1565+
workspace, err := os.MkdirTemp("/tmp", "go-gptscript-test")
1566+
require.NoError(t, err)
1567+
defer func() {
1568+
_ = os.RemoveAll(workspace)
1569+
}()
1570+
1571+
// Create a dataset
1572+
dataset, err := g.CreateDataset(context.Background(), workspace, "test-dataset", "This is a test dataset")
1573+
require.NoError(t, err)
1574+
require.Equal(t, "test-dataset", dataset.Name)
1575+
require.Equal(t, "This is a test dataset", dataset.Description)
1576+
require.Equal(t, 0, len(dataset.Elements))
1577+
1578+
// Add an element
1579+
elementMeta, err := g.AddDatasetElement(context.Background(), workspace, dataset.ID, "test-element", "This is a test element", "This is the content")
1580+
require.NoError(t, err)
1581+
require.Equal(t, "test-element", elementMeta.Name)
1582+
require.Equal(t, "This is a test element", elementMeta.Description)
1583+
1584+
// Get the element
1585+
element, err := g.GetDatasetElement(context.Background(), workspace, dataset.ID, "test-element")
1586+
require.NoError(t, err)
1587+
require.Equal(t, "test-element", element.Name)
1588+
require.Equal(t, "This is a test element", element.Description)
1589+
require.Equal(t, "This is the content", element.Contents)
1590+
1591+
// List elements in the dataset
1592+
elements, err := g.ListDatasetElements(context.Background(), workspace, dataset.ID)
1593+
require.NoError(t, err)
1594+
require.Equal(t, 1, len(elements))
1595+
require.Equal(t, "test-element", elements[0].Name)
1596+
require.Equal(t, "This is a test element", elements[0].Description)
1597+
1598+
// List datasets
1599+
datasets, err := g.ListDatasets(context.Background(), workspace)
1600+
require.NoError(t, err)
1601+
require.Equal(t, 1, len(datasets))
1602+
require.Equal(t, "test-dataset", datasets[0].Name)
1603+
require.Equal(t, "This is a test dataset", datasets[0].Description)
1604+
require.Equal(t, dataset.ID, datasets[0].ID)
1605+
}

opts.go

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ type GlobalOptions struct {
1111
DefaultModelProvider string `json:"DefaultModelProvider"`
1212
CacheDir string `json:"CacheDir"`
1313
Env []string `json:"env"`
14+
DatasetToolRepo string `json:"DatasetToolRepo"`
1415
}
1516

1617
func (g GlobalOptions) toEnv() []string {

0 commit comments

Comments
 (0)