diff --git a/docs/supabase/db/reset.md b/docs/supabase/db/reset.md index acb9b9832..2cf8af89d 100644 --- a/docs/supabase/db/reset.md +++ b/docs/supabase/db/reset.md @@ -4,6 +4,6 @@ Resets the local database to a clean state. Requires the local development stack to be started by running `supabase start`. -Recreates the local Postgres container and applies all local migrations found in `supabase/migrations` directory. If test data is defined in `supabase/seed.sql`, it will be seeded after the migrations are run. Any other data or schema changes made during local development will be discarded. +Recreates the local Postgres container and applies all local migrations found in `supabase/migrations` directory. If test data is defined in `supabase/seed.sql` (or a configured `*.sql.gz` seed file), it will be seeded after the migrations are run. Any other data or schema changes made during local development will be discarded. When running db reset with `--linked` or `--db-url` flag, a SQL script is executed to identify and drop all user created entities in the remote database. Since Postgres roles are cluster level entities, any custom roles created through the dashboard or `supabase/roles.sql` will not be deleted by remote reset. diff --git a/pkg/config/templates/config.toml b/pkg/config/templates/config.toml index 97ed4e566..ba3e22427 100644 --- a/pkg/config/templates/config.toml +++ b/pkg/config/templates/config.toml @@ -62,6 +62,7 @@ schema_paths = [] enabled = true # Specifies an ordered list of seed files to load during db reset. # Supports glob patterns relative to supabase directory: "./seeds/*.sql" +# Supports gzipped SQL files with .sql.gz extension. sql_paths = ["./seed.sql"] [db.network_restrictions] diff --git a/pkg/config/testdata/config.toml b/pkg/config/testdata/config.toml index b228a9c07..50cef4c73 100644 --- a/pkg/config/testdata/config.toml +++ b/pkg/config/testdata/config.toml @@ -62,6 +62,7 @@ test_key = "test_value" enabled = true # Specifies an ordered list of seed files to load during db reset. # Supports glob patterns relative to supabase directory: "./seeds/*.sql" +# Supports gzipped SQL files with .sql.gz extension. sql_paths = ["./seed.sql"] [db.network_restrictions] diff --git a/pkg/migration/file.go b/pkg/migration/file.go index 540c129e3..36ca07714 100644 --- a/pkg/migration/file.go +++ b/pkg/migration/file.go @@ -2,6 +2,7 @@ package migration import ( "bytes" + "compress/gzip" "context" "crypto/sha256" "encoding/hex" @@ -31,6 +32,8 @@ var ( typeNamePattern = regexp.MustCompile(`type "([^"]+)" does not exist`) ) +const compressedSQLSizeMultiplier = 8 + func NewMigrationFromFile(path string, fsys fs.FS) (*MigrationFile, error) { lines, err := parseFile(path, fsys) if err != nil { @@ -48,17 +51,15 @@ func NewMigrationFromFile(path string, fsys fs.FS) (*MigrationFile, error) { } func parseFile(path string, fsys fs.FS) ([]string, error) { - sql, err := fsys.Open(path) + sql, scannerBuffer, err := openSQL(path, fsys, "migration file") if err != nil { - return nil, errors.Errorf("failed to open migration file: %w", err) + return nil, err } defer sql.Close() - // Unless explicitly specified, Use file length as max buffer size + // Unless explicitly specified, use file length (or an estimate for .sql.gz) as max buffer size. if !viper.IsSet("SCANNER_BUFFER_SIZE") { - if fi, err := sql.Stat(); err == nil { - if size := int(fi.Size()); size > parser.MaxScannerCapacity { - parser.MaxScannerCapacity = size - } + if scannerBuffer > parser.MaxScannerCapacity { + parser.MaxScannerCapacity = scannerBuffer } } return parser.SplitAndTrim(sql) @@ -182,9 +183,9 @@ type SeedFile struct { } func NewSeedFile(path string, fsys fs.FS) (*SeedFile, error) { - sql, err := fsys.Open(path) + sql, _, err := openSQL(path, fsys, "seed file") if err != nil { - return nil, errors.Errorf("failed to open seed file: %w", err) + return nil, err } defer sql.Close() hash := sha256.New() @@ -195,6 +196,66 @@ func NewSeedFile(path string, fsys fs.FS) (*SeedFile, error) { return &SeedFile{Path: path, Hash: digest}, nil } +func openSQL(path string, fsys fs.FS, kind string) (io.ReadCloser, int, error) { + sql, err := fsys.Open(path) + if err != nil { + return nil, 0, errors.Errorf("failed to open %s: %w", kind, err) + } + bufferSize := scannerBufferSize(path, sql) + if !isCompressedSQL(path) { + return sql, bufferSize, nil + } + gz, err := gzip.NewReader(sql) + if err != nil { + _ = sql.Close() + return nil, 0, errors.Errorf("failed to decompress %s: %w", kind, err) + } + return &compressedSQLReader{Reader: gz, gz: gz, file: sql}, bufferSize, nil +} + +func scannerBufferSize(path string, sql fs.File) int { + info, err := sql.Stat() + if err != nil { + return 0 + } + maxInt := int64(^uint(0) >> 1) + size := info.Size() + if size <= 0 { + return 0 + } + if isCompressedSQL(path) { + if size > maxInt/compressedSQLSizeMultiplier { + return int(maxInt) + } + size *= compressedSQLSizeMultiplier + } + if size > maxInt { + return int(maxInt) + } + return int(size) +} + +func isCompressedSQL(path string) bool { + return strings.HasSuffix(strings.ToLower(path), ".sql.gz") +} + +type compressedSQLReader struct { + io.Reader + gz *gzip.Reader + file fs.File +} + +func (r *compressedSQLReader) Close() error { + var firstErr error + if err := r.gz.Close(); err != nil { + firstErr = err + } + if err := r.file.Close(); err != nil && firstErr == nil { + firstErr = err + } + return firstErr +} + func (m *SeedFile) ExecBatchWithCache(ctx context.Context, conn *pgx.Conn, fsys fs.FS) error { // Parse each file individually to reduce memory usage lines, err := parseFile(m.Path, fsys) diff --git a/pkg/migration/seed_test.go b/pkg/migration/seed_test.go index db4337b54..7a96d2e3e 100644 --- a/pkg/migration/seed_test.go +++ b/pkg/migration/seed_test.go @@ -1,8 +1,12 @@ package migration import ( + "bytes" + "compress/gzip" "context" + "crypto/sha256" _ "embed" + "encoding/hex" "os" "testing" fs "testing/fstest" @@ -85,6 +89,26 @@ func TestPendingSeeds(t *testing.T) { // Check error assert.NoError(t, err) }) + + t.Run("finds gzipped seeds", func(t *testing.T) { + pending := []string{"testdata/seed.sql.gz"} + fsys := fs.MapFS{ + pending[0]: &fs.MapFile{Data: gzipData(t, testSeed)}, + } + // Setup mock postgres + conn := pgtest.NewConn() + defer conn.Close(t) + conn.Query(SELECT_SEED_TABLE). + Reply("SELECT 0") + // Run test + seeds, err := GetPendingSeeds(context.Background(), pending, conn.MockClient(t), fsys) + // Check error + assert.NoError(t, err) + require.Len(t, seeds, 1) + assert.Equal(t, pending[0], seeds[0].Path) + assert.Equal(t, hashString(testSeed), seeds[0].Hash) + assert.False(t, seeds[0].Dirty) + }) } func TestSeedData(t *testing.T) { @@ -124,6 +148,28 @@ func TestSeedData(t *testing.T) { // Check error assert.ErrorContains(t, err, `ERROR: null value in column "age" of relation "employees" (SQLSTATE 23502)`) }) + + t.Run("seeds from gzipped file", func(t *testing.T) { + seed := SeedFile{ + Path: "testdata/seed.sql.gz", + Hash: hashString(testSeed), + } + fsys := fs.MapFS{ + seed.Path: &fs.MapFile{Data: gzipData(t, testSeed)}, + } + // Setup mock postgres + conn := pgtest.NewConn() + defer conn.Close(t) + mockSeedHistory(conn). + Query(testSeed). + Reply("INSERT 0 1"). + Query(UPSERT_SEED_FILE, seed.Path, seed.Hash). + Reply("INSERT 0 1") + // Run test + err := SeedData(context.Background(), []SeedFile{seed}, conn.MockClient(t), fsys) + // Check error + assert.NoError(t, err) + }) } func mockSeedHistory(conn *pgtest.MockConn) *pgtest.MockConn { @@ -173,4 +219,35 @@ func TestSeedGlobals(t *testing.T) { // Check error assert.ErrorContains(t, err, `ERROR: database "postgres" does not exist (SQLSTATE 3D000)`) }) + + t.Run("seeds from gzipped file", func(t *testing.T) { + pending := []string{"testdata/1_globals.sql.gz"} + fsys := fs.MapFS{ + pending[0]: &fs.MapFile{Data: gzipData(t, testGlobals)}, + } + // Setup mock postgres + conn := pgtest.NewConn() + defer conn.Close(t) + conn.Query(testGlobals). + Reply("CREATE ROLE") + // Run test + err := SeedGlobals(context.Background(), pending, conn.MockClient(t), fsys) + // Check error + assert.NoError(t, err) + }) +} + +func gzipData(t *testing.T, input string) []byte { + t.Helper() + var compressed bytes.Buffer + writer := gzip.NewWriter(&compressed) + _, err := writer.Write([]byte(input)) + require.NoError(t, err) + require.NoError(t, writer.Close()) + return compressed.Bytes() +} + +func hashString(input string) string { + digest := sha256.Sum256([]byte(input)) + return hex.EncodeToString(digest[:]) }