Skip to content

Commit

Permalink
Support dumping to AWS RDS Postgres servers (#125)
Browse files Browse the repository at this point in the history
* remove and re-add foreign keys when dumping to AWS RDS postgres databases
* add --no-comments to pg_dump schema call
  • Loading branch information
sjhewitt authored Sep 25, 2020
1 parent d726dda commit b4a5096
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 12 deletions.
3 changes: 3 additions & 0 deletions cmd/steal.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ type (

from string
to string
toRDS bool
concurrency int
readOpts connOpts
writeOpts connOpts
Expand Down Expand Up @@ -65,6 +66,7 @@ func NewStealCmd() *cobra.Command {
persistentFlags.StringVarP(&opts.configPath, "config", "c", config.DefaultConfigFileName, "Path to config file")
persistentFlags.StringVarP(&opts.from, "from", "f", "mysql://root:root@tcp(localhost:3306)/klepto", "Database dsn to steal from")
persistentFlags.StringVarP(&opts.to, "to", "t", "os://stdout/", "Database to output to (default writes to stdOut)")
persistentFlags.BoolVar(&opts.toRDS, "to-rds", false, "If the output server is an AWS RDS server")
persistentFlags.IntVar(&opts.concurrency, "concurrency", runtime.NumCPU(), "Sets the amount of dumps to be performed concurrently")
persistentFlags.DurationVar(&opts.readOpts.timeout, "read-timeout", 5*time.Minute, "Sets the timeout for read operations")
persistentFlags.DurationVar(&opts.readOpts.maxConnLifetime, "read-conn-lifetime", 0, "Sets the maximum amount of time a connection may be reused on the read database")
Expand Down Expand Up @@ -99,6 +101,7 @@ func RunSteal(opts *StealOptions) (err error) {
source = anonymiser.NewAnonymiser(source, opts.cfgTables)
target, err := dumper.NewDumper(dumper.ConnOpts{
DSN: opts.to,
IsRDS: opts.toRDS,
Timeout: opts.writeOpts.timeout,
MaxConnLifetime: opts.writeOpts.maxConnLifetime,
MaxConns: opts.writeOpts.maxConns,
Expand Down
2 changes: 2 additions & 0 deletions pkg/dumper/dumper.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ type (
ConnOpts struct {
// DSN is the connection address.
DSN string
// IsRDS identifies if the server is an AWS RDS server
IsRDS bool
// Timeout is the timeout for dump operations.
Timeout time.Duration
// MaxConnLifetime is the maximum amount of time a connection may be reused on the read database.
Expand Down
72 changes: 61 additions & 11 deletions pkg/dumper/postgres/dumper.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,26 @@ import (
)

type (
foreignKeyInfo struct {
tableName string
constraintName string
constraintDefinition string
}

pgDumper struct {
conn *sql.DB
reader reader.Reader
conn *sql.DB
reader reader.Reader
isRDS bool
foreignKeys []foreignKeyInfo
}
)

// NewDumper returns a new postgres dumper.
func NewDumper(conn *sql.DB, rdr reader.Reader) dumper.Dumper {
func NewDumper(opts dumper.ConnOpts, conn *sql.DB, rdr reader.Reader) dumper.Dumper {
return engine.New(rdr, &pgDumper{
conn: conn,
reader: rdr,
isRDS: opts.IsRDS,
})
}

Expand Down Expand Up @@ -72,26 +81,67 @@ func (d *pgDumper) DumpTable(tableName string, rowChan <-chan database.Row) erro
// PreDumpTables Disable triggers on all tables to avoid foreign key constraints
func (d *pgDumper) PreDumpTables(tables []string) error {
// We can't use `SET session_replication_role = replica` because multiple connections and stuff
for _, tbl := range tables {
query := fmt.Sprintf("ALTER TABLE %s DISABLE TRIGGER ALL", strconv.Quote(tbl))
if _, err := d.conn.Exec(query); err != nil {
return errors.Wrapf(err, "Failed to disable triggers for %s", tbl)
// For RDS databases, the superuser does not have the required permission to call
// DISABLE TRIGGER ALL, so manually remove and re-add all Foreign Keys
if !d.isRDS {
log.Debug("Disabling triggers")
for _, tbl := range tables {
query := fmt.Sprintf("ALTER TABLE %s DISABLE TRIGGER ALL", strconv.Quote(tbl))
if _, err := d.conn.Exec(query); err != nil {
return errors.Wrapf(err, "Failed to disable triggers for %s", tbl)
}
}
return nil
}

log.Debug("Removing foreign keys")
query := `SELECT conrelid::regclass::varchar tableName,
conname constraintName,
pg_catalog.pg_get_constraintdef(r.oid, true) constraintDefinition
FROM pg_catalog.pg_constraint r
WHERE r.contype = 'f'
AND r.connamespace = (SELECT n.oid FROM pg_namespace n WHERE n.nspname = current_schema())
`
rows, err := d.conn.Query(query)
if err != nil {
return errors.Wrapf(err, "Failed to query ForeignKeys")
}
defer rows.Close()
for rows.Next() {
var fk foreignKeyInfo
if err := rows.Scan(&fk.tableName, &fk.constraintName, &fk.constraintDefinition); err != nil {
return errors.Wrapf(err, "Failed to load ForeignKeyInfo")
}
query := fmt.Sprintf("ALTER TABLE %s DROP CONSTRAINT %s", strconv.Quote(fk.tableName), strconv.Quote(fk.constraintName))
if _, err := d.conn.Exec(query); err != nil {
return errors.Wrapf(err, "Failed to frop contraint %s.%s", fk.tableName, fk.constraintName)
}
d.foreignKeys = append(d.foreignKeys, fk)
}
return nil
}

// PostDumpTables enable triggers on all tables to enforce foreign key constraints
func (d *pgDumper) PostDumpTables(tables []string) error {
// We can't use `SET session_replication_role = DEFAULT` because multiple connections and stuff
for _, tbl := range tables {
query := fmt.Sprintf("ALTER TABLE %s ENABLE TRIGGER ALL", strconv.Quote(tbl))
if _, err := d.conn.Exec(query); err != nil {
return errors.Wrapf(err, "Failed to anble triggers for %s", tbl)
if !d.isRDS {
log.Debug("Reenabling triggers")
for _, tbl := range tables {
query := fmt.Sprintf("ALTER TABLE %s ENABLE TRIGGER ALL", strconv.Quote(tbl))
if _, err := d.conn.Exec(query); err != nil {
return errors.Wrapf(err, "Failed to enable triggers for %s", tbl)
}
}
return nil
}

log.Debug("Recreating foreign keys")
for _, fk := range d.foreignKeys {
query := fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s %s", strconv.Quote(fk.tableName), strconv.Quote(fk.constraintName), fk.constraintDefinition)
if _, err := d.conn.Exec(query); err != nil {
return errors.Wrapf(err, "Failed to re-create ForeignKey %s.%s", fk.tableName, fk.constraintName)
}
}
return nil
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/dumper/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func (m *driver) NewConnection(opts dumper.ConnOpts, rdr reader.Reader) (dumper.
conn.SetMaxIdleConns(opts.MaxIdleConns)
conn.SetConnMaxLifetime(opts.MaxConnLifetime)

return NewDumper(conn, rdr), nil
return NewDumper(opts, conn, rdr), nil
}

func init() {
Expand Down
1 change: 1 addition & 0 deletions pkg/reader/postgres/pg_dump.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ func (p *PgDump) GetStructure() (string, error) {
"--schema-only",
"--no-privileges",
"--no-owner",
"--no-comments",
)

logger.Debug("loading schema for table")
Expand Down

0 comments on commit b4a5096

Please sign in to comment.