Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 50 additions & 16 deletions pkg/migration/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,23 +72,22 @@ func NewMigrationFromReader(sql io.Reader) (*MigrationFile, error) {
return &MigrationFile{Statements: lines}, nil
}

func isPipelineIncompatible(sql string) bool {
upper := strings.ToUpper(strings.TrimSpace(sql))
return strings.Contains(upper, "INDEX CONCURRENTLY") ||
strings.Contains(upper, "REINDEX CONCURRENTLY") ||
strings.HasPrefix(upper, "VACUUM ") ||
strings.HasPrefix(upper, "ALTER SYSTEM") ||
strings.HasPrefix(upper, "CLUSTER ")
}

func (m *MigrationFile) ExecBatch(ctx context.Context, conn *pgx.Conn) error {
// Batch migration commands, without using statement cache
batch := &pgconn.Batch{}
for _, line := range m.Statements {
batch.ExecParams(line, nil, nil, nil, nil)
}
// Insert into migration history
if len(m.Version) > 0 {
if err := m.insertVersionSQL(conn, batch); err != nil {
return err
}
}
// ExecBatch is implicitly transactional
if result, err := conn.PgConn().ExecBatch(ctx, batch).ReadAll(); err != nil {
// Defaults to printing the last statement on error
batchSize := 0
executed := 0

formatError := func(err error, i int) error {
stat := INSERT_MIGRATION_VERSION
i := len(result)
if i < len(m.Statements) {
stat = m.Statements[i]
}
Expand All @@ -99,7 +98,6 @@ func (m *MigrationFile) ExecBatch(ctx context.Context, conn *pgx.Conn) error {
if len(pgErr.Detail) > 0 {
msg = append(msg, pgErr.Detail)
}
// Provide helpful hint for extension type errors (SQLSTATE 42704: undefined_object)
if typeName := extractTypeName(pgErr.Message); len(typeName) > 0 && pgErr.Code == "42704" && !IsSchemaQualified(typeName) {
msg = append(msg, "")
msg = append(msg, "Hint: This type may be defined in a schema that's not in your search_path.")
Expand All @@ -111,7 +109,43 @@ func (m *MigrationFile) ExecBatch(ctx context.Context, conn *pgx.Conn) error {
msg = append(msg, fmt.Sprintf("At statement: %d", i), stat)
return errors.Errorf("%w\n%s", err, strings.Join(msg, "\n"))
}
return nil

flushBatch := func() error {
if batchSize == 0 {
return nil
}
if result, err := conn.PgConn().ExecBatch(ctx, batch).ReadAll(); err != nil {
return formatError(err, executed+len(result))
}
executed += batchSize
batch = &pgconn.Batch{}
batchSize = 0
return nil
}

for _, line := range m.Statements {
if isPipelineIncompatible(line) {
if err := flushBatch(); err != nil {
return err
}
if _, err := conn.Exec(ctx, line); err != nil {
return formatError(err, executed)
}
executed++
} else {
batch.ExecParams(line, nil, nil, nil, nil)
batchSize++
}
}

// Insert into migration history
if len(m.Version) > 0 {
if err := m.insertVersionSQL(conn, batch); err != nil {
return err
}
}

return flushBatch()
}

func markError(stat string, pos int) string {
Expand Down