diff --git a/datasets.go b/datasets.go
new file mode 100644
index 0000000..45b665b
--- /dev/null
+++ b/datasets.go
@@ -0,0 +1,61 @@
+package gptscript
+
+// DatasetElementMeta is the name and description of a dataset element.
+type DatasetElementMeta struct {
+	Name        string `json:"name"`
+	Description string `json:"description"`
+}
+
+// DatasetElement is a dataset element's metadata together with its contents.
+type DatasetElement struct {
+	DatasetElementMeta `json:",inline"`
+	Contents           string `json:"contents"`
+}
+
+// DatasetMeta is the ID, name, and description of a dataset.
+type DatasetMeta struct {
+	ID          string `json:"id"`
+	Name        string `json:"name"`
+	Description string `json:"description"`
+}
+
+// Dataset is a dataset's metadata plus its base directory and the metadata of
+// its elements.
+type Dataset struct {
+	DatasetMeta `json:",inline"`
+	BaseDir     string                        `json:"baseDir,omitempty"`
+	Elements    map[string]DatasetElementMeta `json:"elements"`
+}
+
+// datasetRequest is the request body passed to runBasicCommand for every
+// dataset request; Input carries the JSON-encoded arguments.
+type datasetRequest struct {
+	Input           string `json:"input"`
+	Workspace       string `json:"workspace"`
+	DatasetToolRepo string `json:"datasetToolRepo"`
+}
+
+// createDatasetArgs is the input for the "datasets/create" request.
+type createDatasetArgs struct {
+	Name        string `json:"datasetName"`
+	Description string `json:"datasetDescription"`
+}
+
+// addDatasetElementArgs is the input for the "datasets/add-element" request.
+type addDatasetElementArgs struct {
+	DatasetID          string `json:"datasetID"`
+	ElementName        string `json:"elementName"`
+	ElementDescription string `json:"elementDescription"`
+	ElementContent     string `json:"elementContent"`
+}
+
+// listDatasetElementArgs is the input for the "datasets/list-elements" request.
+type listDatasetElementArgs struct {
+	DatasetID string `json:"datasetID"`
+}
+
+// getDatasetElementArgs is the input for the "datasets/get-element" request.
+type getDatasetElementArgs struct {
+	DatasetID string `json:"datasetID"`
+	Element   string `json:"element"`
+}
diff --git a/gptscript.go b/gptscript.go
index 1e30d95..6178c67 100644
--- a/gptscript.go
+++ b/gptscript.go
@@ -7,6 +7,7 @@ import (
 	"context"
 	"encoding/base64"
 	"encoding/json"
+	"errors"
 	"fmt"
 	"io"
 	"log/slog"
@@ -388,6 +389,123 @@ func (g *GPTScript) DeleteCredential(ctx context.Context, credCtx, name string)
 	return err
 }
 
+// Dataset methods
+
+// ListDatasets returns the metadata of the datasets in workspace. An empty
+// workspace falls back to the GPTSCRIPT_WORKSPACE_DIR environment variable.
+func (g *GPTScript) ListDatasets(ctx context.Context, workspace string) ([]DatasetMeta, error) {
+	var datasets []DatasetMeta
+	if err := g.runDatasetCommand(ctx, "datasets", workspace, nil, &datasets); err != nil {
+		return nil, err
+	}
+	return datasets, nil
+}
+
+// CreateDataset creates a dataset with the given name and description in
+// workspace and returns it. An empty workspace falls back to the
+// GPTSCRIPT_WORKSPACE_DIR environment variable.
+func (g *GPTScript) CreateDataset(ctx context.Context, workspace, name, description string) (Dataset, error) {
+	args := createDatasetArgs{
+		Name:        name,
+		Description: description,
+	}
+
+	var dataset Dataset
+	if err := g.runDatasetCommand(ctx, "datasets/create", workspace, args, &dataset); err != nil {
+		return Dataset{}, err
+	}
+	return dataset, nil
+}
+
+// AddDatasetElement adds an element with the given name, description, and
+// content to the dataset identified by datasetID and returns the element's
+// metadata. An empty workspace falls back to the GPTSCRIPT_WORKSPACE_DIR
+// environment variable.
+func (g *GPTScript) AddDatasetElement(ctx context.Context, workspace, datasetID, elementName, elementDescription, elementContent string) (DatasetElementMeta, error) {
+	args := addDatasetElementArgs{
+		DatasetID:          datasetID,
+		ElementName:        elementName,
+		ElementDescription: elementDescription,
+		ElementContent:     elementContent,
+	}
+
+	var element DatasetElementMeta
+	if err := g.runDatasetCommand(ctx, "datasets/add-element", workspace, args, &element); err != nil {
+		return DatasetElementMeta{}, err
+	}
+	return element, nil
+}
+
+// ListDatasetElements returns the metadata of the elements in the dataset
+// identified by datasetID. An empty workspace falls back to the
+// GPTSCRIPT_WORKSPACE_DIR environment variable.
+func (g *GPTScript) ListDatasetElements(ctx context.Context, workspace, datasetID string) ([]DatasetElementMeta, error) {
+	args := listDatasetElementArgs{
+		DatasetID: datasetID,
+	}
+
+	var elements []DatasetElementMeta
+	if err := g.runDatasetCommand(ctx, "datasets/list-elements", workspace, args, &elements); err != nil {
+		return nil, err
+	}
+	return elements, nil
+}
+
+// GetDatasetElement returns the element named elementName, including its
+// contents, from the dataset identified by datasetID. An empty workspace falls
+// back to the GPTSCRIPT_WORKSPACE_DIR environment variable.
+func (g *GPTScript) GetDatasetElement(ctx context.Context, workspace, datasetID, elementName string) (DatasetElement, error) {
+	args := getDatasetElementArgs{
+		DatasetID: datasetID,
+		Element:   elementName,
+	}
+
+	var element DatasetElement
+	if err := g.runDatasetCommand(ctx, "datasets/get-element", workspace, args, &element); err != nil {
+		return DatasetElement{}, err
+	}
+	return element, nil
+}
+
+// runDatasetCommand sends a dataset request to requestPath through
+// runBasicCommand and decodes the JSON output into result, which must be a
+// non-nil pointer. args is marshaled to JSON and sent as the request input; nil
+// args sends an empty JSON object. An empty workspace falls back to the
+// GPTSCRIPT_WORKSPACE_DIR environment variable. Output that starts with
+// "ERROR:" is returned verbatim as an error.
+func (g *GPTScript) runDatasetCommand(ctx context.Context, requestPath, workspace string, args, result any) error {
+	if workspace == "" {
+		workspace = os.Getenv("GPTSCRIPT_WORKSPACE_DIR")
+	}
+
+	input := "{}"
+	if args != nil {
+		argsJSON, err := json.Marshal(args)
+		if err != nil {
+			return fmt.Errorf("failed to marshal %s args: %w", requestPath, err)
+		}
+		input = string(argsJSON)
+	}
+
+	out, err := g.runBasicCommand(ctx, requestPath, datasetRequest{
+		Input:           input,
+		Workspace:       workspace,
+		DatasetToolRepo: g.globalOpts.DatasetToolRepo,
+	})
+	if err != nil {
+		return err
+	}
+
+	if strings.HasPrefix(out, "ERROR:") {
+		return errors.New(out)
+	}
+
+	if err = json.Unmarshal([]byte(out), result); err != nil {
+		return fmt.Errorf("failed to parse %s response: %w", requestPath, err)
+	}
+	return nil
+}
+
 func (g *GPTScript) runBasicCommand(ctx context.Context, requestPath string, body any) (string, error) {
 	run := &Run{
 		url:         g.globalOpts.URL,
diff --git a/gptscript_test.go b/gptscript_test.go
index ec4419c..eb471c8 100644
--- a/gptscript_test.go
+++ b/gptscript_test.go
@@ -670,7 +670,7 @@ func TestParseToolWithTextNode(t *testing.T) {
 		t.Fatalf("No text node found")
 	}
 
-	if tools[1].TextNode.Text != "hello\n" {
+	if strings.TrimSpace(tools[1].TextNode.Text) != "hello" {
 		t.Errorf("Unexpected text: %s", tools[1].TextNode.Text)
 	}
 	if tools[1].TextNode.Fmt != "markdown" {
@@ -1047,7 +1047,7 @@ func TestConfirmDeny(t *testing.T) {
 		return
 	}
 
-	if !strings.Contains(confirmCallEvent.Input, "\"ls\"") {
+	if !strings.Contains(confirmCallEvent.Input, "ls") {
 		t.Errorf("unexpected confirm input: %s", confirmCallEvent.Input)
 	}
 
@@ -1560,3 +1560,43 @@ func TestCredentials(t *testing.T) {
 	require.Error(t, err)
 	require.True(t, errors.As(err, &ErrNotFound{}))
 }
+
+func TestDatasets(t *testing.T) {
+	// t.TempDir removes the directory automatically when the test finishes.
+	workspace := t.TempDir()
+
+	// Create a dataset
+	dataset, err := g.CreateDataset(context.Background(), workspace, "test-dataset", "This is a test dataset")
+	require.NoError(t, err)
+	require.Equal(t, "test-dataset", dataset.Name)
+	require.Equal(t, "This is a test dataset", dataset.Description)
+	require.Empty(t, dataset.Elements)
+
+	// Add an element
+	elementMeta, err := g.AddDatasetElement(context.Background(), workspace, dataset.ID, "test-element", "This is a test element", "This is the content")
+	require.NoError(t, err)
+	require.Equal(t, "test-element", elementMeta.Name)
+	require.Equal(t, "This is a test element", elementMeta.Description)
+
+	// Get the element
+	element, err := g.GetDatasetElement(context.Background(), workspace, dataset.ID, "test-element")
+	require.NoError(t, err)
+	require.Equal(t, "test-element", element.Name)
+	require.Equal(t, "This is a test element", element.Description)
+	require.Equal(t, "This is the content", element.Contents)
+
+	// List elements in the dataset
+	elements, err := g.ListDatasetElements(context.Background(), workspace, dataset.ID)
+	require.NoError(t, err)
+	require.Len(t, elements, 1)
+	require.Equal(t, "test-element", elements[0].Name)
+	require.Equal(t, "This is a test element", elements[0].Description)
+
+	// List datasets
+	datasets, err := g.ListDatasets(context.Background(), workspace)
+	require.NoError(t, err)
+	require.Len(t, datasets, 1)
+	require.Equal(t, "test-dataset", datasets[0].Name)
+	require.Equal(t, "This is a test dataset", datasets[0].Description)
+	require.Equal(t, dataset.ID, datasets[0].ID)
+}
diff --git a/opts.go b/opts.go
index 779dcb9..283b4ec 100644
--- a/opts.go
+++ b/opts.go
@@ -11,6 +11,7 @@ type GlobalOptions struct {
 	DefaultModelProvider string   `json:"DefaultModelProvider"`
 	CacheDir             string   `json:"CacheDir"`
 	Env                  []string `json:"env"`
+	DatasetToolRepo      string   `json:"DatasetToolRepo"`
 }
 
 func (g GlobalOptions) toEnv() []string {
@@ -41,6 +42,7 @@ func completeGlobalOptions(opts ...GlobalOptions) GlobalOptions {
 		result.OpenAIBaseURL = firstSet(opt.OpenAIBaseURL, result.OpenAIBaseURL)
 		result.DefaultModel = firstSet(opt.DefaultModel, result.DefaultModel)
 		result.DefaultModelProvider = firstSet(opt.DefaultModelProvider, result.DefaultModelProvider)
+		result.DatasetToolRepo = firstSet(opt.DatasetToolRepo, result.DatasetToolRepo)
 		result.Env = append(result.Env, opt.Env...)
 	}
 	return result