Skip to content

Commit 15b2d10

Browse files
committed
snapshot: trim extension statements from pgdump when uploading
1 parent 7462d79 commit 15b2d10

3 files changed

Lines changed: 168 additions & 19 deletions

File tree

cmd/src/snapshot_upload.go

Lines changed: 40 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ BUCKET
3636
flagSet := flag.NewFlagSet("upload", flag.ExitOnError)
3737
bucketName := flagSet.String("bucket", "", "destination Cloud Storage bucket name")
3838
credentialsPath := flagSet.String("credentials", "", "JSON credentials file for Google Cloud service account")
39+
trimExtensions := flagSet.Bool("trim-extensions", true, "trim EXTENSION statements from database dumps for import to Google Cloud SQL")
3940

4041
snapshotCommands = append(snapshotCommands, &command{
4142
flagSet: flagSet,
@@ -59,8 +60,9 @@ BUCKET
5960
}
6061

6162
type upload struct {
62-
file *os.File
63-
stat os.FileInfo
63+
file *os.File
64+
stat os.FileInfo
65+
trimExtensions bool
6466
}
6567
var (
6668
uploads []upload // index aligned with progressBars
@@ -76,8 +78,9 @@ BUCKET
7678
return errors.Wrap(err, "get file size")
7779
}
7880
uploads = append(uploads, upload{
79-
file: f,
80-
stat: stat,
81+
file: f,
82+
stat: stat,
83+
trimExtensions: false, // not a database dump
8184
})
8285
progressBars = append(progressBars, output.ProgressBar{
8386
Label: stat.Name(),
@@ -95,8 +98,9 @@ BUCKET
9598
return errors.Wrap(err, "get file size")
9699
}
97100
uploads = append(uploads, upload{
98-
file: f,
99-
stat: stat,
101+
file: f,
102+
stat: stat,
103+
trimExtensions: *trimExtensions,
100104
})
101105
progressBars = append(progressBars, output.ProgressBar{
102106
Label: stat.Name(),
@@ -116,7 +120,7 @@ BUCKET
116120
g.Go(func(ctx context.Context) error {
117121
progressFn := func(p int64) { progress.SetValue(i, float64(p)) }
118122

119-
if err := copyToBucket(ctx, u.file, u.stat, bucket, progressFn); err != nil {
123+
if err := copyDumpToBucket(ctx, u.file, u.stat, bucket, progressFn, u.trimExtensions); err != nil {
120124
return errors.Wrap(err, u.stat.Name())
121125
}
122126

@@ -139,26 +143,43 @@ BUCKET
139143
})
140144
}
141145

142-
func copyToBucket(ctx context.Context, src io.Reader, stat fs.FileInfo, dst *storage.BucketHandle, progressFn func(int64)) error {
143-
writer := dst.Object(stat.Name()).NewWriter(ctx)
144-
writer.ProgressFunc = progressFn
145-
defer writer.Close()
146+
func copyDumpToBucket(ctx context.Context, src io.ReadSeeker, stat fs.FileInfo, dst *storage.BucketHandle, progressFn func(int64), trimExtensions bool) error {
147+
// Set up object to write to
148+
object := dst.Object(stat.Name()).NewWriter(ctx)
149+
object.ProgressFunc = progressFn
150+
defer object.Close()
151+
152+
// To assert against actual file size
153+
var totalWritten int64
154+
155+
// Do a partial copy that trims out unwanted statements
156+
if trimExtensions {
157+
written, err := pgdump.PartialCopyWithoutExtensions(object, src, progressFn)
158+
if err != nil {
159+
return errors.Wrap(err, "trim extensions and upload")
160+
}
161+
totalWritten += written
162+
}
146163

147164
// io.Copy is the best way to copy from a reader to writer in Go, and storage.Writer
148-
// has its own chunking mechanisms internally.
149-
written, err := io.Copy(writer, src)
165+
// has its own chunking mechanisms internally. io.Reader is stateful, so this copy
166+
// will just continue from where we left off if we use copyAndTrimExtensions.
167+
written, err := io.Copy(object, src)
150168
if err != nil {
151-
return err
169+
return errors.Wrap(err, "upload")
152170
}
171+
totalWritten += written
153172

154-
// Progress is not called on completion, so we call it manually after io.Copy is done
173+
// Progress is not called on completion of io.Copy, so we call it manually after to
174+
// update our pretty progress bars.
155175
progressFn(written)
156176

157-
// Validate we have sent all data
177+
// Validate we have sent all data. copyAndTrimExtensions may add some bytes, so the
178+
// check is not a strict equality.
158179
size := stat.Size()
159-
if written != size {
160-
return errors.Newf("expected to write %d bytes, but actually wrote %d bytes",
161-
size, written)
180+
if totalWritten < size {
181+
return errors.Newf("expected to write %d bytes, but actually wrote %d bytes (diff: %d bytes)",
182+
size, totalWritten, totalWritten-size)
162183
}
163184

164185
return nil

internal/pgdump/extensions.go

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
package pgdump
2+
3+
import (
4+
"bufio"
5+
"bytes"
6+
"io"
7+
)
8+
9+
// PartialCopyWithoutExtensions will perform a partial copy of a SQL database dump from
10+
// src to dst while commenting out EXTENSIONs-related statements. When it determines there
11+
// are no more EXTENSIONs-related statements, it will return, resetting src to the position
12+
// of the last contents written to dst.
13+
//
14+
// This is needed for import to Google Cloud Storage, which does not like many EXTENSION
15+
// statements. For more details, see https://cloud.google.com/sql/docs/postgres/import-export/import-export-dmp
16+
//
17+
// Filtering requires reading entire lines into memory - this can be a very expensive
18+
// operation, so when filtering is complete the more efficient io.Copy should be used
19+
// to perform the remainder of the copy from src to dst.
20+
func PartialCopyWithoutExtensions(dst io.Writer, src io.ReadSeeker, progressFn func(int64)) (written int64, err error) {
21+
var (
22+
reader = bufio.NewReader(src)
23+
// position we have consumed up to, track separately because bufio.Reader may have
24+
// read ahead on src. This allows us to reset src later.
25+
consumed int64
26+
// set to true when we have done all our filtering
27+
noMoreExtensions bool
28+
)
29+
30+
for !noMoreExtensions {
31+
// Read up to a line, keeping track of our position in src
32+
var line []byte
33+
line, err = reader.ReadBytes('\n')
34+
consumed += int64(len(line))
35+
if err != nil {
36+
return
37+
}
38+
39+
// Once we start seeing table creations, we are definitely done with extensions,
40+
// so we can hand off the rest to the superior io.Copy implementation.
41+
if bytes.HasPrefix(line, []byte("CREATE TABLE")) {
42+
// we are done with extensions
43+
noMoreExtensions = true
44+
} else if bytes.HasPrefix(line, []byte("COMMENT ON EXTENSION")) {
45+
// comment out this line
46+
line = append([]byte("-- "), line...)
47+
}
48+
49+
// Write this line and update our progress before returning on error
50+
var lineWritten int
51+
lineWritten, err = dst.Write(line)
52+
written += int64(lineWritten)
53+
progressFn(written)
54+
if err != nil {
55+
return
56+
}
57+
}
58+
59+
// No more extensions - reset src to the last actual consumed position
60+
_, err = src.Seek(consumed, io.SeekStart)
61+
62+
return
63+
}

internal/pgdump/extensions_test.go

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
package pgdump
2+
3+
import (
4+
"bytes"
5+
"io"
6+
"os"
7+
"path/filepath"
8+
"testing"
9+
10+
"github.com/hexops/autogold"
11+
"github.com/stretchr/testify/assert"
12+
"github.com/stretchr/testify/require"
13+
)
14+
15+
func TestPartialCopyWithoutExtensions(t *testing.T) {
16+
// Create test data - there is no stdlib in-memory io.ReadSeeker implementation
17+
src, err := os.Create(filepath.Join(t.TempDir(), t.Name()))
18+
require.NoError(t, err)
19+
_, err = src.WriteString(`-- Some comment
20+
21+
CREATE EXTENSION foobar
22+
23+
COMMENT ON EXTENSION barbaz
24+
25+
CREATE TYPE asdf
26+
27+
CREATE TABLE robert (
28+
...
29+
)
30+
31+
CREATE TABLE bobhead (
32+
...
33+
)`)
34+
require.NoError(t, err)
35+
_, err = src.Seek(0, io.SeekStart)
36+
require.NoError(t, err)
37+
38+
// Set up target to assert against
39+
var dst bytes.Buffer
40+
41+
// Perform partial copy
42+
_, err = PartialCopyWithoutExtensions(&dst, src, func(i int64) {})
43+
assert.NoError(t, err)
44+
45+
// Copy rest of contents
46+
_, err = io.Copy(&dst, src)
47+
assert.NoError(t, err)
48+
49+
// Assert contents (update with -update)
50+
autogold.Want("partial-copy-without-extensions", `-- Some comment
51+
52+
CREATE EXTENSION foobar
53+
54+
-- COMMENT ON EXTENSION barbaz
55+
56+
CREATE TYPE asdf
57+
58+
CREATE TABLE robert (
59+
...
60+
)
61+
62+
CREATE TABLE bobhead (
63+
...
64+
)`).Equal(t, dst.String())
65+
}

0 commit comments

Comments
 (0)