From f4dc4b7cc171c45c252d50a0463d88a451edaa40 Mon Sep 17 00:00:00 2001 From: Nick Hale <4175918+njhale@users.noreply.github.com> Date: Mon, 24 Mar 2025 17:03:49 -0400 Subject: [PATCH] feat: switch to postgres for database tools Addresses https://github.com/obot-platform/obot/issues/2265 Signed-off-by: Nick Hale <4175918+njhale@users.noreply.github.com> --- database/go.mod | 21 ----- database/go.sum | 51 ------------ database/main.go | 150 +++++++----------------------------- database/pkg/cmd/command.go | 143 +++++++++++++++++++++++++++++----- database/pkg/cmd/context.go | 115 +++++++++++++++++++-------- database/pkg/cmd/rows.go | 74 +++++------------- database/pkg/cmd/table.go | 47 ++++------- database/tool.gpt | 4 +- 8 files changed, 274 insertions(+), 331 deletions(-) diff --git a/database/go.mod b/database/go.mod index 1bb12ae12..a5aeafffd 100644 --- a/database/go.mod +++ b/database/go.mod @@ -1,24 +1,3 @@ module obot-platform/database go 1.23.3 - -require ( - github.com/gptscript-ai/go-gptscript v0.9.6-0.20250222170845-eee4337500a6 - github.com/ncruces/go-sqlite3 v0.20.3 -) - -require ( - github.com/getkin/kin-openapi v0.129.0 // indirect - github.com/go-openapi/jsonpointer v0.21.0 // indirect - github.com/go-openapi/swag v0.23.0 // indirect - github.com/josharian/intern v1.0.0 // indirect - github.com/mailru/easyjson v0.9.0 // indirect - github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 // indirect - github.com/ncruces/julianday v1.0.0 // indirect - github.com/oasdiff/yaml v0.0.0-20241210131133-6b86fb107d80 // indirect - github.com/oasdiff/yaml3 v0.0.0-20241210130736-a94c01f36349 // indirect - github.com/perimeterx/marshmallow v1.1.5 // indirect - github.com/tetratelabs/wazero v1.8.2 // indirect - golang.org/x/sys v0.27.0 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect -) diff --git a/database/go.sum b/database/go.sum index 728f4cceb..e69de29bb 100644 --- a/database/go.sum +++ b/database/go.sum @@ -1,51 +0,0 @@ -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/getkin/kin-openapi v0.129.0 h1:QGYTNcmyP5X0AtFQ2Dkou9DGBJsUETeLH9rFrJXZh30= -github.com/getkin/kin-openapi v0.129.0/go.mod h1:gmWI+b/J45xqpyK5wJmRRZse5wefA5H0RDMK46kLUtI= -github.com/go-openapi/jsonpointer v0.21.0 h1:YgdVicSA9vH5RiHs9TZW5oyafXZFc6+2Vc1rr/O9oNQ= -github.com/go-openapi/jsonpointer v0.21.0/go.mod h1:IUyH9l/+uyhIYQ/PXVA41Rexl+kOkAPDdXEYns6fzUY= -github.com/go-openapi/swag v0.23.0 h1:vsEVJDUo2hPJ2tu0/Xc+4noaxyEffXNIs3cOULZ+GrE= -github.com/go-openapi/swag v0.23.0/go.mod h1:esZ8ITTYEsH1V2trKHjAN8Ai7xHb8RV+YSZ577vPjgQ= -github.com/go-test/deep v1.0.8 h1:TDsG77qcSprGbC6vTN8OuXp5g+J+b5Pcguhf7Zt61VM= -github.com/go-test/deep v1.0.8/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE= -github.com/gptscript-ai/go-gptscript v0.9.6-0.20250222170845-eee4337500a6 h1:vsZ09cWfNWUXT6AOVQc1GpfEdIxcLusUs6Hgo9IgAKs= -github.com/gptscript-ai/go-gptscript v0.9.6-0.20250222170845-eee4337500a6/go.mod h1:QvGPZoRuAiA8P5EzPI05kTrs+LZ0ipHywUGsKruSknw= -github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= -github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= -github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= -github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= -github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= -github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4= -github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= -github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 h1:RWengNIwukTxcDr9M+97sNutRR1RKhG96O6jWumTTnw= -github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826/go.mod h1:TaXosZuwdSHYgviHp1DAtfrULt5eUgsSMsZf+YrPgl8= -github.com/ncruces/go-sqlite3 v0.20.3 h1:+4G4uEqOeusF0yRuQVUl9fuoEebUolwQSnBUjYBLYIw= -github.com/ncruces/go-sqlite3 v0.20.3/go.mod h1:ojLIAB243gtz68Eo283Ps+k9PyR3dvzS+9/RgId4+AA= -github.com/ncruces/julianday v1.0.0 h1:fH0OKwa7NWvniGQtxdJRxAgkBMolni2BjDHaWTxqt7M= -github.com/ncruces/julianday v1.0.0/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g= -github.com/oasdiff/yaml v0.0.0-20241210131133-6b86fb107d80 h1:nZspmSkneBbtxU9TopEAE0CY+SBJLxO8LPUlw2vG4pU= -github.com/oasdiff/yaml v0.0.0-20241210131133-6b86fb107d80/go.mod h1:7tFDb+Y51LcDpn26GccuUgQXUk6t0CXZsivKjyimYX8= -github.com/oasdiff/yaml3 v0.0.0-20241210130736-a94c01f36349 h1:t05Ww3DxZutOqbMN+7OIuqDwXbhl32HiZGpLy26BAPc= -github.com/oasdiff/yaml3 v0.0.0-20241210130736-a94c01f36349/go.mod h1:y5+oSEHCPT/DGrS++Wc/479ERge0zTFxaF8PbGKcg2o= -github.com/perimeterx/marshmallow v1.1.5 h1:a2LALqQ1BlHM8PZblsDdidgv1mWi1DgC2UmX50IvK2s= -github.com/perimeterx/marshmallow v1.1.5/go.mod h1:dsXbUu8CRzfYP5a87xpp0xq9S3u0Vchtcl8we9tYaXw= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= -github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= -github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= -github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -github.com/tetratelabs/wazero v1.8.2 h1:yIgLR/b2bN31bjxwXHD8a3d+BogigR952csSDdLYEv4= -github.com/tetratelabs/wazero v1.8.2/go.mod h1:yAI0XTsMBhREkM/YDAK/zNou3GoiAce1P6+rp/wQhjs= -github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0= -github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY= -golang.org/x/sys v0.27.0 h1:wBqf8DvsY9Y/2P8gAfPDEYNuS30J4lPHJxXSb/nJZ+s= -golang.org/x/sys v0.27.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug= -golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= -gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= -gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= -gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/database/main.go b/database/main.go index c36521a92..a3f6efdee 100644 --- a/database/main.go +++ b/database/main.go @@ -2,89 +2,60 @@ package main import ( "context" - "crypto/sha256" - "encoding/hex" - "errors" "fmt" - "os" - "slices" - "obot-platform/database/pkg/cmd" - - "github.com/gptscript-ai/go-gptscript" - _ "github.com/ncruces/go-sqlite3/driver" - _ "github.com/ncruces/go-sqlite3/embed" + "os" ) -var workspaceID = os.Getenv("DATABASE_WORKSPACE_ID") - func main() { if len(os.Args) != 2 { fmt.Println("Usage: gptscript-go-tool ") os.Exit(1) } command := os.Args[1] + ctx := context.Background() - g, err := gptscript.NewGPTScript() - if err != nil { - fmt.Printf("Error creating GPTScript: %v\n", err) - os.Exit(1) + workspaceID := os.Getenv("DATABASE_WORKSPACE_ID") + if workspaceID == "" { + // TODO(njhale): Figure out why DATABASE_WORKSPACE_ID is not set here for the UI tools. + workspaceID = os.Getenv("GPTSCRIPT_WORKSPACE_ID") } - defer g.Close() - - var ( - ctx = context.Background() - dbFileName = "obot.db" - dbWorkspacePath = "/databases/" + dbFileName - revisionID string = "-1" - initialDBData []byte - ) - - workspaceDB, err := g.ReadFileWithRevisionInWorkspace(ctx, dbWorkspacePath, gptscript.ReadFileInWorkspaceOptions{ - WorkspaceID: workspaceID, - }) - var notFoundErr *gptscript.NotFoundInWorkspaceError - if err != nil && !errors.As(err, ¬FoundErr) { - fmt.Printf("Error reading DB file: %v\n", err) - os.Exit(1) - } + // Get admin DSN from environment variable + adminDSN := os.Getenv("POSTGRES_DSN") - // Create a temporary file for the SQLite database - dbFile, err := os.CreateTemp("", dbFileName) + // Setup database and user with admin credentials + dsn, err := cmd.EnsureTenantSchema(ctx, adminDSN, workspaceID) if err != nil { - fmt.Printf("Error creating temp file: %v\n", err) + fmt.Printf("Error setting up database: %v\n", err) os.Exit(1) } - defer dbFile.Close() - defer os.Remove(dbFile.Name()) - // Write the data to the temporary file - if workspaceDB != nil && workspaceDB.Content != nil { - initialDBData = workspaceDB.Content - if err := os.WriteFile(dbFile.Name(), initialDBData, 0644); err != nil { - fmt.Printf("Error writing to temp file: %v\n", err) - os.Exit(1) - } - if workspaceDB.RevisionID != "" { - revisionID = workspaceDB.RevisionID - } - } - - // Run the requested command + // Run the requested command using the user credentials var result string switch command { case "listDatabaseTables": - result, err = cmd.ListDatabaseTables(ctx, dbFile) + result, err = cmd.ListDatabaseTables(ctx, dsn) + case "listDatabaseTableRows": - result, err = cmd.ListDatabaseTableRows(ctx, dbFile, os.Getenv("TABLE")) + table := os.Getenv("TABLE") + if table == "" { + err = fmt.Errorf("TABLE environment variable is required") + break + } + result, err = cmd.ListDatabaseTableRows(ctx, dsn, table) + case "runDatabaseSQL": - result, err = cmd.RunDatabaseCommand(ctx, dbFile, os.Getenv("SQL"), "-header") - if err == nil { - err = saveWorkspaceDB(ctx, g, dbWorkspacePath, revisionID, dbFile, initialDBData) + sql := os.Getenv("SQL") + if sql == "" { + err = fmt.Errorf("SQL environment variable is required") + break } + result, err = cmd.RunDatabaseCommand(ctx, dsn, sql) + case "databaseContext": - result, err = cmd.DatabaseContext(ctx, dbFile) + result, err = cmd.DatabaseContext(ctx, dsn) + default: err = fmt.Errorf("unknown command: %s", command) } @@ -96,66 +67,3 @@ func main() { fmt.Print(result) } - -// saveWorkspaceDB saves the updated database file to the workspace if the content of the database has changed. -func saveWorkspaceDB( - ctx context.Context, - g *gptscript.GPTScript, - dbWorkspacePath string, - revisionID string, - dbFile *os.File, - initialDBData []byte, -) error { - updatedDBData, err := os.ReadFile(dbFile.Name()) - if err != nil { - return fmt.Errorf("Error reading updated DB file: %v", err) - } - - if hash(initialDBData) == hash(updatedDBData) { - return nil - } - - if err := g.WriteFileInWorkspace(ctx, dbWorkspacePath, updatedDBData, gptscript.WriteFileInWorkspaceOptions{ - WorkspaceID: workspaceID, - CreateRevision: &([]bool{true}[0]), - LatestRevisionID: revisionID, - }); err != nil { - return fmt.Errorf("Error writing updated DB file to workspace: %v", err) - } - - // Delete old revisions after successfully writing the new revision - revisions, err := g.ListRevisionsForFileInWorkspace(ctx, dbWorkspacePath, gptscript.ListRevisionsForFileInWorkspaceOptions{ - WorkspaceID: workspaceID, - }) - if err != nil { - fmt.Fprintf(os.Stderr, "Error listing revisions: %v\n", err) - return nil - } - - lastRevisionIndex := slices.IndexFunc(revisions, func(rev gptscript.FileInfo) bool { - return rev.RevisionID == revisionID - }) - - if lastRevisionIndex < 0 { - return nil - } - - for _, rev := range revisions[:lastRevisionIndex+1] { - if err := g.DeleteRevisionForFileInWorkspace(ctx, dbWorkspacePath, rev.RevisionID, gptscript.DeleteRevisionForFileInWorkspaceOptions{ - WorkspaceID: workspaceID, - }); err != nil { - fmt.Fprintf(os.Stderr, "Error deleting revision %s: %v\n", rev.RevisionID, err) - } - } - - return nil -} - -// hash computes the SHA-256 hash of the given data and returns it as a hexadecimal string -func hash(data []byte) string { - if data == nil { - return "" - } - hash := sha256.Sum256(data) - return hex.EncodeToString(hash[:]) -} diff --git a/database/pkg/cmd/command.go b/database/pkg/cmd/command.go index cb9864e9d..b09b90c10 100644 --- a/database/pkg/cmd/command.go +++ b/database/pkg/cmd/command.go @@ -3,39 +3,146 @@ package cmd import ( "bytes" "context" + "crypto/sha256" + "encoding/hex" "fmt" - "os" "os/exec" + "regexp" "strconv" "strings" ) -// RunDatabaseCommand runs a sqlite3 command against the database and returns the output from the sqlite3 CLI. -func RunDatabaseCommand(ctx context.Context, dbFile *os.File, sql string, opts ...string) (string, error) { - // Remove the "sqlite3" prefix and trim whitespace - args := append(opts, dbFile.Name()) - if arg := strings.TrimSpace(sql); arg != "" { - // Use strconv.Unquote to safely handle quotes and escape sequences - unquoted, err := strconv.Unquote(arg) - if err != nil { - // If unquoting fails (e.g. string wasn't quoted), use original - unquoted = arg - } - args = append(args, unquoted) +// RunDatabaseCommand executes a command against the Postgres database +func RunDatabaseCommand(ctx context.Context, dsn string, sql string, opts ...string) (string, error) { + if sql == "" { + return "", fmt.Errorf("SQL cannot be empty") + } + + args := append([]string{dsn}, opts...) + + unquoted, err := strconv.Unquote(sql) + if err != nil { + unquoted = sql } + args = append(args, "-c", unquoted) - // Build the sqlite3 command - cmd := exec.CommandContext(ctx, "sqlite3", args...) + cmd := exec.CommandContext(ctx, "psql", args...) - // Redirect command output var stdout, stderr bytes.Buffer cmd.Stdout = &stdout cmd.Stderr = &stderr - // Run the command and capture errors if err := cmd.Run(); err != nil { - return "", fmt.Errorf("error executing sqlite3: %w, stderr: %s", err, stderr.String()) + return "", fmt.Errorf("psql error: %w\nstderr: %s", err, stderr.String()) + } + + if stderr.Len() > 0 { + return stdout.String(), fmt.Errorf("psql stderr: %s", stderr.String()) } return stdout.String(), nil } + +// EnsureTenantSchema creates a schema and role for a tenant with proper isolation +func EnsureTenantSchema(ctx context.Context, adminDSN, workspaceID string) (string, error) { + schemaName := workspaceSchemaName(workspaceID) + userName := schemaName + password := generatePassword(workspaceID) + dbName := "obot_db" + + // Create shared database if it doesn't exist + checkDBSQL := fmt.Sprintf("SELECT 1 FROM pg_database WHERE datname = '%s'", dbName) + dbExistsCheck, err := RunDatabaseCommand(ctx, adminDSN, checkDBSQL, "-At") + if err != nil { + return "", fmt.Errorf("error checking for shared database: %w", err) + } + if strings.TrimSpace(dbExistsCheck) != "1" { + createDBSQL := fmt.Sprintf("CREATE DATABASE %s", dbName) + if _, err := RunDatabaseCommand(ctx, adminDSN, createDBSQL); err != nil { + return "", fmt.Errorf("error creating shared database: %w", err) + } + + // Create tenant role + createRoleSQL := fmt.Sprintf(`CREATE ROLE %s WITH LOGIN PASSWORD '%s'`, userName, password) + if _, err := RunDatabaseCommand(ctx, adminDSN, createRoleSQL); err != nil && !strings.Contains(err.Error(), "already exists") { + return "", fmt.Errorf("error creating role: %w", err) + } + + // Connect to shared database + dbDSN, err := dsnWithDatabase(adminDSN, dbName) + if err != nil { + return "", fmt.Errorf("error constructing DSN for shared database: %w", err) + } + + // Create schema for tenant (owned by admin user) + createSchemaSQL := fmt.Sprintf(`CREATE SCHEMA IF NOT EXISTS %s`, schemaName) + if _, err := RunDatabaseCommand(ctx, dbDSN, createSchemaSQL); err != nil { + return "", fmt.Errorf("error creating schema: %w", err) + } + + // Revoke PUBLIC access on public schema + revokePublicSchemaSQL := `REVOKE ALL ON SCHEMA public FROM PUBLIC` + if _, err := RunDatabaseCommand(ctx, dbDSN, revokePublicSchemaSQL); err != nil { + return "", fmt.Errorf("error revoking public schema privileges: %w", err) + } + + // Set up tenant schema permissions + statements := []string{ + fmt.Sprintf(`REVOKE ALL ON SCHEMA %s FROM PUBLIC`, schemaName), + fmt.Sprintf(`GRANT USAGE, CREATE ON SCHEMA %s TO %s`, schemaName, userName), + fmt.Sprintf(`GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA %s TO %s`, schemaName, userName), + fmt.Sprintf(`GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA %s TO %s`, schemaName, userName), + fmt.Sprintf(`GRANT ALL PRIVILEGES ON ALL FUNCTIONS IN SCHEMA %s TO %s`, schemaName, userName), + fmt.Sprintf(`ALTER DEFAULT PRIVILEGES IN SCHEMA %s GRANT ALL ON TABLES TO %s`, schemaName, userName), + fmt.Sprintf(`ALTER DEFAULT PRIVILEGES IN SCHEMA %s GRANT ALL ON SEQUENCES TO %s`, schemaName, userName), + fmt.Sprintf(`ALTER DEFAULT PRIVILEGES IN SCHEMA %s GRANT ALL ON FUNCTIONS TO %s`, schemaName, userName), + fmt.Sprintf(`ALTER ROLE %s SET search_path = %s`, userName, schemaName), + } + + for _, stmt := range statements { + if _, err := RunDatabaseCommand(ctx, dbDSN, stmt); err != nil { + return "", fmt.Errorf("error executing statement '%s': %w", stmt, err) + } + } + } + + userDSN := fmt.Sprintf("postgresql://%s:%s@%s/%s?sslmode=require", userName, password, extractHost(adminDSN), dbName) + return userDSN, nil +} + +// generatePassword creates a hashed password using the workspaceID +func generatePassword(workspaceID string) string { + hash := sha256.Sum256([]byte(workspaceID)) + return hex.EncodeToString(hash[:]) +} + +// extractHost extracts the host part from the DSN +func extractHost(dsn string) string { + re := regexp.MustCompile(`^(postgresql://[^:]+:[^@]+@)([^/]+)(/[^?]*)(\?.+)?$`) + matches := re.FindStringSubmatch(dsn) + if len(matches) >= 3 { + return matches[2] + } + return "" +} + +// workspaceSchemaName converts a workspace ID into a valid PostgreSQL schema/role identifier +func workspaceSchemaName(workspaceID string) string { + hash := sha256.Sum256([]byte(workspaceID)) + return "schema_" + hex.EncodeToString(hash[:16]) +} + +// dsnWithDatabase switches the database in the DSN string +func dsnWithDatabase(adminDSN, dbName string) (string, error) { + if strings.HasPrefix(adminDSN, "postgresql://") { + re := regexp.MustCompile(`^(postgresql://[^/]+/)([^?]*)(\?.+)?$`) + matches := re.FindStringSubmatch(adminDSN) + if len(matches) >= 3 { + if matches[3] != "" { + return matches[1] + dbName + matches[3], nil + } + return matches[1] + dbName, nil + } + } + return "", fmt.Errorf("invalid DSN format") +} diff --git a/database/pkg/cmd/context.go b/database/pkg/cmd/context.go index 57c3a665c..34e4b9b17 100644 --- a/database/pkg/cmd/context.go +++ b/database/pkg/cmd/context.go @@ -3,52 +3,103 @@ package cmd import ( "context" "fmt" - "os" "strings" ) -// DatabaseContext generates a markdown-formatted string with instructions -// and the database's current schemas. -func DatabaseContext(ctx context.Context, dbFile *os.File) (string, error) { +const getSchemasSQL = ` +WITH table_columns AS ( + SELECT + table_name, + ordinal_position, + column_name, + data_type, + is_nullable, + column_default + FROM information_schema.columns + WHERE table_schema = 'public' +), +constraints AS ( + SELECT + conname, + contype, + conrelid::regclass::text AS table_name, + pg_get_constraintdef(oid, true) AS definition + FROM pg_constraint + WHERE connamespace = 'public'::regnamespace +), +indexes AS ( + SELECT + tablename, + indexdef + FROM pg_indexes + WHERE schemaname = 'public' +) +SELECT format( + E'\nCREATE TABLE %I (\n%s%s\n);\n\n%s\n', + tc.table_name, + tc.table_name, + string_agg( + format(' %I %s%s%s', + tc.column_name, + tc.data_type, + CASE WHEN tc.column_default IS NOT NULL THEN ' DEFAULT ' || tc.column_default ELSE '' END, + CASE WHEN tc.is_nullable = 'NO' THEN ' NOT NULL' ELSE '' END + ), + E',\n' + ORDER BY tc.ordinal_position + ), + CASE + WHEN ct.constraint_defs IS NOT NULL THEN E',\n' || ct.constraint_defs + ELSE '' + END, + COALESCE(idx.index_defs, '') +) +FROM table_columns tc +LEFT JOIN ( + SELECT + table_name, + string_agg( + format(' CONSTRAINT %I %s', conname, definition), + E',\n' + ) AS constraint_defs + FROM constraints + GROUP BY table_name +) ct ON tc.table_name = ct.table_name +LEFT JOIN ( + SELECT + tablename, + string_agg(indexdef, E'\n') AS index_defs + FROM indexes + GROUP BY tablename +) idx ON tc.table_name = idx.tablename +GROUP BY tc.table_name, ct.constraint_defs, idx.index_defs +ORDER BY tc.table_name; +` + +// DatabaseContext returns markdown with database schema information +func DatabaseContext(ctx context.Context, dsn string) (string, error) { var builder strings.Builder - // Add usage instructions - builder.WriteString(`# START INSTRUCTIONS: Run Database SQL tool + builder.WriteString(`# PostgreSQL Database Tool -You have access to tools for interacting with a SQLite database. -The "Run Database SQL" tool lets you run SQL against the SQLite3 database. +You have access to tools for interacting with a PostgreSQL database. +The "Run Database SQL" tool lets you run SQL against the PostgreSQL database. Display all results from these tools and their schemas in markdown format. -If the user refers to creating or modifying tables, assume they mean a SQLite3 table and not writing a table in a markdown file. +If the user refers to creating or modifying tables, assume they mean a PostgreSQL table and not writing a table in a markdown file. -# END INSTRUCTIONS: Run Database SQL tool `) - // Add the schemas section - schemas, err := getSchemas(ctx, dbFile) + schemas, err := RunDatabaseCommand(ctx, dsn, getSchemasSQL, "-At") if err != nil { - return "", fmt.Errorf("failed to retrieve schemas: %w", err) + return "", fmt.Errorf("error getting schemas: %w", err) } - if schemas != "" { - builder.WriteString("# START CURRENT DATABASE SCHEMAS\n") - builder.WriteString(schemas) - builder.WriteString("\n# END CURRENT DATABASE SCHEMAS\n") + + if schemas == "" { + builder.WriteString("\n# No tables found in database\n") } else { - builder.WriteString("# DATABASE HAS NO TABLES\n") + builder.WriteString("\n# Database Schema\n\n") + builder.WriteString(schemas) } return builder.String(), nil } - -// getSchemas retrieves all schemas from the database using the sqlite3 CLI. -func getSchemas(ctx context.Context, dbFile *os.File) (string, error) { - query := `SELECT sql FROM sqlite_master WHERE type IN ('table', 'index', 'view', 'trigger') AND name NOT LIKE 'sqlite_%' ORDER BY name;` - - // Execute the query using the RunDatabaseCommand function - output, err := RunDatabaseCommand(ctx, dbFile, query) - if err != nil { - return "", fmt.Errorf("error querying schemas: %w", err) - } - - // Return raw output as-is - return strings.TrimSpace(output), nil -} diff --git a/database/pkg/cmd/rows.go b/database/pkg/cmd/rows.go index e3e61031a..f9bf5d03c 100644 --- a/database/pkg/cmd/rows.go +++ b/database/pkg/cmd/rows.go @@ -2,67 +2,29 @@ package cmd import ( "context" - "encoding/json" "fmt" - "os" ) -type Output struct { - Columns []string `json:"columns"` - Rows []map[string]any `json:"rows"` -} - -// ListDatabaseTableRows lists all rows from the specified table using RunDatabaseCommand and returns the JSON output directly. -func ListDatabaseTableRows(ctx context.Context, dbFile *os.File, table string) (string, error) { +const tableRowsSQL = ` +SELECT json_build_object( + 'columns', ( + SELECT array_agg(column_name ORDER BY ordinal_position) + FROM information_schema.columns + WHERE table_schema = 'public' AND table_name = '%s' + ), + 'rows', COALESCE(( + SELECT json_agg(row_to_json(t)) + FROM %s t + ), '[]') +)::text; +` + +// ListDatabaseTableRows returns table contents with columns +func ListDatabaseTableRows(ctx context.Context, dsn string, table string) (string, error) { if table == "" { return "", fmt.Errorf("table name cannot be empty") } - // Get column names using PRAGMA - columnsQuery := fmt.Sprintf("PRAGMA table_info(%q);", table) - columnsOutput, err := RunDatabaseCommand(ctx, dbFile, columnsQuery, "-json") - if err != nil { - return "", fmt.Errorf("error getting columns for table %q: %w", table, err) - } - - // Parse column information - var columnInfo []struct { - Name string `json:"name"` - } - if err := json.Unmarshal([]byte(columnsOutput), &columnInfo); err != nil { - return "", fmt.Errorf("error parsing column information: %w", err) - } - - columns := make([]string, len(columnInfo)) - for i, col := range columnInfo { - columns[i] = col.Name - } - - // Get all rows - rowsQuery := fmt.Sprintf("SELECT * FROM %q;", table) - rowsOutput, err := RunDatabaseCommand(ctx, dbFile, rowsQuery, "-json") - if err != nil { - return "", fmt.Errorf("error executing query for table %q: %w", table, err) - } - - // Parse rows - var rows []map[string]any - if rowsOutput != "" { - if err := json.Unmarshal([]byte(rowsOutput), &rows); err != nil { - return "", fmt.Errorf("error parsing JSON output: %w", err) - } - } - - // Create and marshal output - output := Output{ - Columns: columns, - Rows: rows, - } - - result, err := json.Marshal(output) - if err != nil { - return "", fmt.Errorf("error marshaling output: %w", err) - } - - return string(result), nil + query := fmt.Sprintf(tableRowsSQL, table, table) + return RunDatabaseCommand(ctx, dsn, query, "-At") } diff --git a/database/pkg/cmd/table.go b/database/pkg/cmd/table.go index 38aca3fe0..aab6eb1f2 100644 --- a/database/pkg/cmd/table.go +++ b/database/pkg/cmd/table.go @@ -2,42 +2,29 @@ package cmd import ( "context" - "encoding/json" "fmt" - "os" ) -type tables struct { - Tables []Table `json:"tables"` -} - -type Table struct { - Name string `json:"name,omitempty"` -} - -// ListDatabaseTables returns a JSON string containing the list of tables in the database. -func ListDatabaseTables(ctx context.Context, dbFile *os.File) (string, error) { - // Query to fetch table names - query := "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';" - - // Execute the query using RunDatabaseCommand with JSON output - output, err := RunDatabaseCommand(ctx, dbFile, query, "-json") +const listTablesSQL = `SELECT COALESCE( + json_agg(json_build_object('name', table_name)), '[]' +)::text +FROM ( + SELECT table_name + FROM information_schema.tables + WHERE table_schema = 'public' AND table_type = 'BASE TABLE' + ORDER BY table_name +) AS ordered_tables;` + +// ListDatabaseTables returns a JSON string containing the list of tables +func ListDatabaseTables(ctx context.Context, dsn string) (string, error) { + output, err := RunDatabaseCommand(ctx, dsn, listTablesSQL, "-At") if err != nil { - return "", fmt.Errorf("error executing query to list tables: %w", err) - } - - var dbTables tables - if output != "" { - if err := json.Unmarshal([]byte(output), &(dbTables.Tables)); err != nil { - return "", fmt.Errorf("error parsing table names: %w", err) - } + return "", fmt.Errorf("error listing tables: %w", err) } - // Marshal final result - data, err := json.Marshal(dbTables) - if err != nil { - return "", fmt.Errorf("error marshaling tables to JSON: %w", err) + if output == "" { + return `{"tables":[]}`, nil } - return string(data), nil + return fmt.Sprintf(`{"tables":%s}`, output), nil } diff --git a/database/tool.gpt b/database/tool.gpt index 2eb636a52..abb003160 100644 --- a/database/tool.gpt +++ b/database/tool.gpt @@ -8,8 +8,8 @@ Share Tools: Run Database SQL --- Name: Run Database SQL Share Context: Database Context -Description: Run SQL against the SQLite3 database and return the results -Param: sql: SQL to run against the SQLite3 database (e.g. "SELECT * FROM users") +Description: Run SQL against the PostgreSQL database and return the results +Param: sql: SQL to run against the PostgreSQL database (e.g. "SELECT * FROM users") #!${GPTSCRIPT_TOOL_DIR}/bin/gptscript-go-tool runDatabaseSQL