diff --git a/README.md b/README.md index 4461532..b191f2d 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ # embedded-postgres -Run a real Postgres database locally on Linux, OSX or Windows as part of another Go application or test. +Run a real Postgres database locally on Linux, OSX, Windows or FreeBSD (amd64) as part of another Go application or test. When testing this provides a higher level of confidence than using any in memory alternative. It also requires no other external dependencies outside of the Go build ecosystem. @@ -69,6 +69,10 @@ is done). If your test need to run multiple different versions of Postgres for different tests, make sure *BinaryPath* is a subdirectory of *RuntimePath*. +On FreeBSD amd64 the default artifact naming currently uses `freebsd13-amd64`. +If you publish a different FreeBSD line such as `freebsd14-amd64`, select it explicitly with `Platform("freebsd14")`. +If you already have a prebuilt binary tree on disk, prefer `BinariesPath(...)` to skip remote downloads entirely. + A single Postgres instance can be created, started and stopped as follows ```go @@ -89,7 +93,9 @@ Username("beer"). Password("wine"). Database("gin"). Version(V12). +Platform("freebsd14"). RuntimePath("/tmp"). +BinariesPath("/opt/embedded-postgres"). BinaryRepositoryURL("https://repo.local/central.proxy"). Port(9876). StartTimeout(45 * time.Second). @@ -102,6 +108,35 @@ err := postgres.Start() err := postgres.Stop() ``` +If you want a reusable starting point for local macOS development and FreeBSD 13 +deployments, use the `preset` helper package: + +```go +import ( + "time" + + embeddedpostgres "github.com/fergusstrange/embedded-postgres" + "github.com/fergusstrange/embedded-postgres/preset" +) + +config := preset.LocalDevelopment("myapp", preset.Options{ + Version: embeddedpostgres.V18, + Port: 5433, + StartTimeout: 30 * time.Second, + // Optional on FreeBSD when you pre-install the binary tree yourself. + // BinariesPath: "/opt/embedded-postgres", +}) + +postgres := embeddedpostgres.NewDatabase(config) +err := postgres.Start() +defer postgres.Stop() +``` + +On FreeBSD this preset defaults to the `freebsd13` artifact line and uses +`/var/tmp//embedded-postgres/runtime` plus +`/var/db//embedded-postgres/data`. Override `Platform("freebsd14")` or the +paths if your host layout differs. + It should be noted that if `postgres.Stop()` is not called then the child Postgres process will not be released and the caller will block. @@ -119,4 +154,3 @@ in [examples](https://github.com/fergusstrange/embedded-postgres/tree/master/exa ## Contributing View the [contributing guide](CONTRIBUTING.md). - diff --git a/bundles/libcrypto.so.111 b/bundles/libcrypto.so.111 new file mode 100644 index 0000000..5c40536 Binary files /dev/null and b/bundles/libcrypto.so.111 differ diff --git a/bundles/libssl.so.111 b/bundles/libssl.so.111 new file mode 100644 index 0000000..7cae13f Binary files /dev/null and b/bundles/libssl.so.111 differ diff --git a/bundles/postgres-freebsd13-x86_64.txz b/bundles/postgres-freebsd13-x86_64.txz new file mode 100644 index 0000000..28fa9d1 Binary files /dev/null and b/bundles/postgres-freebsd13-x86_64.txz differ diff --git a/bundles/postgres-freebsd14-x86_64.txz b/bundles/postgres-freebsd14-x86_64.txz new file mode 100644 index 0000000..09e0622 Binary files /dev/null and b/bundles/postgres-freebsd14-x86_64.txz differ diff --git a/config.go b/config.go index fb18a49..0299d98 100644 --- a/config.go +++ b/config.go @@ -3,6 +3,7 @@ package embeddedpostgres import ( "fmt" "io" + "net/url" "os" "time" ) @@ -10,7 +11,11 @@ import ( // Config maintains the runtime configuration for the Postgres process to be created. type Config struct { version PostgresVersion + platform string port uint32 + useUnixSocket bool + unixSocketDirectory string + logDirectory string database string username string password string @@ -38,6 +43,8 @@ func DefaultConfig() Config { return Config{ version: V18, port: 5432, + useUnixSocket: false, + unixSocketDirectory: "/tmp/", database: "postgres", username: "postgres", password: "postgres", @@ -53,12 +60,38 @@ func (c Config) Version(version PostgresVersion) Config { return c } +// Platform sets the artifact platform line used to resolve postgres binaries. +// For example: "freebsd13", "freebsd14" or "alpine". +func (c Config) Platform(platform string) Config { + c.platform = platform + return c +} + // Port sets the runtime port that Postgres can be accessed on. func (c Config) Port(port uint32) Config { c.port = port return c } +// WithoutTcp makes Postgres listen on a UNIX socket instead of opening a TCP port. +func (c Config) WithoutTcp() Config { + c.useUnixSocket = true + return c +} + +// WithUnixSocketDirectory sets the directory where Postgres creates its UNIX socket. +func (c Config) WithUnixSocketDirectory(dir string) Config { + c.unixSocketDirectory = dir + return c +} + +// LogDirectory sets the directory where the temporary embedded Postgres capture +// file is created before its content is forwarded to the configured logger. +func (c Config) LogDirectory(dir string) Config { + c.logDirectory = dir + return c +} + // Database sets the database name that will be created. func (c Config) Database(database string) Config { c.database = database @@ -145,7 +178,23 @@ func (c Config) BinaryRepositoryURL(binaryRepositoryURL string) Config { } func (c Config) GetConnectionURL() string { - return fmt.Sprintf("postgresql://%s:%s@%s:%d/%s", c.username, c.password, "localhost", c.port, c.database) + u := &url.URL{ + Scheme: "postgresql", + User: url.UserPassword(c.username, c.password), + Path: "/" + c.database, + } + + if c.useUnixSocket { + u.Host = fmt.Sprintf(":%d", c.port) + + q := url.Values{} + q.Set("host", c.unixSocketDirectory) + u.RawQuery = q.Encode() + } else { + u.Host = fmt.Sprintf("localhost:%d", c.port) + } + + return u.String() } // PostgresVersion represents the semantic version used to fetch and run the Postgres process. diff --git a/config_test.go b/config_test.go new file mode 100644 index 0000000..266f61b --- /dev/null +++ b/config_test.go @@ -0,0 +1,35 @@ +package embeddedpostgres + +import "testing" + +func TestGetConnectionURL(t *testing.T) { + config := DefaultConfig().Database("mydb").Username("myuser").Password("mypass") + expect := "postgresql://myuser:mypass@localhost:5432/mydb" + + if got := config.GetConnectionURL(); got != expect { + t.Errorf("expected %q got %q", expect, got) + } +} + +func TestGetConnectionURLWithUnixSocket(t *testing.T) { + config := DefaultConfig().Database("mydb").Username("myuser").Password("mypass").WithoutTcp() + expect := "postgresql://myuser:mypass@:5432/mydb?host=%2Ftmp%2F" + + if got := config.GetConnectionURL(); got != expect { + t.Errorf("expected %q got %q", expect, got) + } +} + +func TestGetConnectionURLWithUnixSocketInCustomDir(t *testing.T) { + config := DefaultConfig(). + Database("mydb"). + Username("myuser"). + Password("mypass"). + WithoutTcp(). + WithUnixSocketDirectory("/path/to/socks") + expect := "postgresql://myuser:mypass@:5432/mydb?host=%2Fpath%2Fto%2Fsocks" + + if got := config.GetConnectionURL(); got != expect { + t.Errorf("expected %q got %q", expect, got) + } +} diff --git a/decompression.go b/decompression.go index 79cecd9..6daaaa3 100644 --- a/decompression.go +++ b/decompression.go @@ -10,6 +10,8 @@ import ( "github.com/xi2/xz" ) +type progressLogger func(format string, args ...any) + func defaultTarReader(xzReader *xz.Reader) (func() (*tar.Header, error), func() io.Reader) { tarReader := tar.NewReader(xzReader) @@ -20,7 +22,7 @@ func defaultTarReader(xzReader *xz.Reader) (func() (*tar.Header, error), func() } } -func decompressTarXz(tarReader func(*xz.Reader) (func() (*tar.Header, error), func() io.Reader), path, extractPath string) error { +func decompressTarXz(tarReader func(*xz.Reader) (func() (*tar.Header, error), func() io.Reader), path, extractPath string, logf progressLogger) error { extractDirectory := filepath.Dir(extractPath) if err := os.MkdirAll(extractDirectory, os.ModePerm); err != nil { @@ -54,6 +56,7 @@ func decompressTarXz(tarReader func(*xz.Reader) (func() (*tar.Header, error), fu } readNext, reader := tarReader(xzReader) + entryCount := 0 for { header, err := readNext() @@ -79,6 +82,7 @@ func decompressTarXz(tarReader func(*xz.Reader) (func() (*tar.Header, error), fu switch header.Typeflag { case tar.TypeReg: + logProgress(logf, "extracting embedded postgres entry archive=%s entry=%s type=file target=%s size=%d", path, header.Name, finalPath, header.Size) outFile, err := os.OpenFile(targetPath, os.O_CREATE|os.O_RDWR, os.FileMode(header.Mode)) if err != nil { return errorExtractingPostgres(err) @@ -92,6 +96,7 @@ func decompressTarXz(tarReader func(*xz.Reader) (func() (*tar.Header, error), fu return errorExtractingPostgres(err) } case tar.TypeSymlink: + logProgress(logf, "extracting embedded postgres entry archive=%s entry=%s type=symlink target=%s link=%s", path, header.Name, finalPath, header.Linkname) if err := os.RemoveAll(targetPath); err != nil { return errorExtractingPostgres(err) } @@ -101,20 +106,32 @@ func decompressTarXz(tarReader func(*xz.Reader) (func() (*tar.Header, error), fu } case tar.TypeDir: + logProgress(logf, "extracting embedded postgres entry archive=%s entry=%s type=dir target=%s", path, header.Name, finalPath) if err := os.MkdirAll(finalPath, os.FileMode(header.Mode)); err != nil { return errorExtractingPostgres(err) } + entryCount++ continue } if err := renameOrIgnore(targetPath, finalPath); err != nil { return errorExtractingPostgres(err) } + entryCount++ } + logProgress(logf, "finished extracting embedded postgres archive archive=%s destination=%s entries=%d", path, extractPath, entryCount) + return nil } +func logProgress(logf progressLogger, format string, args ...any) { + if logf == nil { + return + } + logf(format, args...) +} + func errorUnableToExtract(cacheLocation, binariesPath string, err error) error { return fmt.Errorf("unable to extract postgres archive %s to %s, if running parallel tests, configure RuntimePath to isolate testing directories, %w", cacheLocation, diff --git a/decompression_test.go b/decompression_test.go index 92c6a73..79f073e 100644 --- a/decompression_test.go +++ b/decompression_test.go @@ -2,6 +2,7 @@ package embeddedpostgres import ( "archive/tar" + "bytes" "errors" "fmt" "io" @@ -28,7 +29,7 @@ func Test_decompressTarXz(t *testing.T) { archive, cleanUp := createTempXzArchive() defer cleanUp() - err = decompressTarXz(defaultTarReader, archive, tempDir) + err = decompressTarXz(defaultTarReader, archive, tempDir, nil) assert.NoError(t, err) @@ -42,7 +43,7 @@ func Test_decompressTarXz(t *testing.T) { } func Test_decompressTarXz_ErrorWhenFileNotExists(t *testing.T) { - err := decompressTarXz(defaultTarReader, "/does-not-exist", "/also-fake") + err := decompressTarXz(defaultTarReader, "/does-not-exist", "/also-fake", nil) assert.Error(t, err) assert.Contains( @@ -68,7 +69,7 @@ func Test_decompressTarXz_ErrorWhenErrorDuringRead(t *testing.T) { return func() (*tar.Header, error) { return nil, errors.New("oh noes") }, nil - }, archive, tempDir) + }, archive, tempDir, nil) assert.EqualError(t, err, "unable to extract postgres archive: oh noes") } @@ -108,7 +109,7 @@ func Test_decompressTarXz_ErrorWhenFailedToReadFileToCopy(t *testing.T) { } } - err = decompressTarXz(fileBlockingExtractTarReader, archive, tempDir) + err = decompressTarXz(fileBlockingExtractTarReader, archive, tempDir, nil) assert.Regexp(t, "^unable to extract postgres archive:.+$", err) } @@ -145,7 +146,7 @@ func Test_decompressTarXz_ErrorWhenFileToCopyToNotExists(t *testing.T) { } } - err = decompressTarXz(fileBlockingExtractTarReader, archive, tempDir) + err = decompressTarXz(fileBlockingExtractTarReader, archive, tempDir, nil) assert.Regexp(t, "^unable to extract postgres archive:.+$", err) } @@ -180,7 +181,7 @@ func Test_decompressTarXz_ErrorWhenArchiveCorrupted(t *testing.T) { panic(err) } - err = decompressTarXz(defaultTarReader, archive, tempDir) + err = decompressTarXz(defaultTarReader, archive, tempDir, nil) assert.EqualError(t, err, "unable to extract postgres archive: xz: data is corrupt") } @@ -197,10 +198,26 @@ func Test_decompressTarXz_ErrorWithInvalidDestination(t *testing.T) { op := fmt.Sprintf(path.Join(tempDir, "%c"), rune(0)) - err = decompressTarXz(defaultTarReader, archive, op) + err = decompressTarXz(defaultTarReader, archive, op, nil) assert.EqualError( t, err, fmt.Sprintf("unable to extract postgres archive: mkdir %s: invalid argument", op), ) } + +func Test_decompressTarXz_LogsExtractedEntries(t *testing.T) { + tempDir := t.TempDir() + archive, cleanUp := createTempXzArchive() + defer cleanUp() + + var logs bytes.Buffer + err := decompressTarXz(defaultTarReader, archive, tempDir, func(format string, args ...any) { + _, _ = fmt.Fprintf(&logs, format+"\n", args...) + }) + + require.NoError(t, err) + assert.Contains(t, logs.String(), "extracting embedded postgres entry") + assert.Contains(t, logs.String(), "entry=dir1/dir2/some_content") + assert.Contains(t, logs.String(), "finished extracting embedded postgres archive") +} diff --git a/embedded_postgres.go b/embedded_postgres.go index afe8497..155cf6c 100644 --- a/embedded_postgres.go +++ b/embedded_postgres.go @@ -71,11 +71,13 @@ func (ep *EmbeddedPostgres) Start() error { return ErrServerAlreadyStarted } - if err := ensurePortAvailable(ep.config.port); err != nil { - return err + if !ep.config.useUnixSocket { + if err := ensurePortAvailable(ep.config.port); err != nil { + return err + } } - logger, err := newSyncedLogger("", ep.config.logger) + logger, err := newSyncedLogger(ep.config.logDirectory, ep.config.logger) if err != nil { return errors.New("unable to create logger") } @@ -100,6 +102,15 @@ func (ep *EmbeddedPostgres) Start() error { ep.config.binariesPath = ep.config.runtimePath } + ep.logf( + "embedded postgres binary setup cache=%s cache_exists=%t runtime_path=%s binaries_path=%s data_path=%s", + cacheLocation, + cacheExists, + ep.config.runtimePath, + ep.config.binariesPath, + ep.config.dataPath, + ) + if err := ep.downloadAndExtractBinary(cacheExists, cacheLocation); err != nil { return err } @@ -127,7 +138,12 @@ func (ep *EmbeddedPostgres) Start() error { ep.started = true if !reuseData { - if err := ep.createDatabase(ep.config.port, ep.config.username, ep.config.password, ep.config.database); err != nil { + host := "localhost" + if ep.config.useUnixSocket { + host = ep.config.unixSocketDirectory + } + + if err := ep.createDatabase(host, ep.config.port, ep.config.username, ep.config.password, ep.config.database); err != nil { if stopErr := stopPostgres(ep); stopErr != nil { return fmt.Errorf("unable to stop database caused by error %s", err) } @@ -152,21 +168,40 @@ func (ep *EmbeddedPostgres) downloadAndExtractBinary(cacheExists bool, cacheLoca mu.Lock() defer mu.Unlock() - _, binDirErr := os.Stat(filepath.Join(ep.config.binariesPath, "bin", "pg_ctl")) + pgCtlPath := filepath.Join(ep.config.binariesPath, "bin", "pg_ctl") + _, binDirErr := os.Stat(pgCtlPath) if os.IsNotExist(binDirErr) { if !cacheExists { - if err := ep.remoteFetchStrategy(); err != nil { + ep.logf("downloading embedded postgres archive cache=%s", cacheLocation) + if err := ep.remoteFetchStrategy(ep.logf); err != nil { return err } + } else { + ep.logf("using cached embedded postgres archive cache=%s", cacheLocation) } - if err := decompressTarXz(defaultTarReader, cacheLocation, ep.config.binariesPath); err != nil { + ep.logf("extracting embedded postgres archive archive=%s destination=%s", cacheLocation, ep.config.binariesPath) + if err := decompressTarXz(defaultTarReader, cacheLocation, ep.config.binariesPath, ep.logf); err != nil { return err } + ep.logf("embedded postgres archive extracted archive=%s destination=%s", cacheLocation, ep.config.binariesPath) + } else { + ep.logf("embedded postgres binaries already available pg_ctl=%s", pgCtlPath) } return nil } +func (ep *EmbeddedPostgres) logf(format string, args ...any) { + if ep == nil || ep.syncedLogger == nil { + return + } + ep.syncedLogger.logf(format, args...) +} + +func (ep *EmbeddedPostgres) GetConnectionURL() string { + return ep.config.GetConnectionURL() +} + func (ep *EmbeddedPostgres) cleanDataDirectoryAndInit() error { if err := os.RemoveAll(ep.config.dataPath); err != nil { return fmt.Errorf("unable to clean up data directory %s with error: %s", ep.config.dataPath, err) @@ -210,27 +245,65 @@ func encodeOptions(port uint32, parameters map[string]string) string { } func startPostgres(ep *EmbeddedPostgres) error { + if err := ensureFreeBSDRuntimeUser(ep.config.binariesPath, ep.config.runtimePath, ep.config.dataPath); err != nil { + return err + } + + if ep.config.startParameters == nil { + ep.config.startParameters = make(map[string]string) + } + + if ep.config.useUnixSocket { + ep.config.startParameters["listen_addresses"] = "" + ep.config.startParameters["unix_socket_directories"] = ep.config.unixSocketDirectory + } + postgresBinary := filepath.Join(ep.config.binariesPath, "bin/pg_ctl") - postgresProcess := exec.Command(postgresBinary, "start", "-w", + postgresProcess := wrapCommandForRuntimeUser(exec.Command(postgresBinary, "start", "-w", "-D", ep.config.dataPath, - "-o", encodeOptions(ep.config.port, ep.config.startParameters)) + "-o", encodeOptions(ep.config.port, ep.config.startParameters))) postgresProcess.Stdout = ep.syncedLogger.file postgresProcess.Stderr = ep.syncedLogger.file if err := postgresProcess.Run(); err != nil { _ = ep.syncedLogger.flush() logContent, _ := readLogsOrTimeout(ep.syncedLogger.file) + logText := string(logContent) - return fmt.Errorf("could not start postgres using %s:\n%s", postgresProcess.String(), string(logContent)) + if runtime.GOOS == "freebsd" && needsFreeBSDICUCopyHint(logText) { + logText += freeBSDICUCopyHint(ep.config.binariesPath) + } + + return fmt.Errorf("could not start postgres using %s:\n%s", postgresProcess.String(), logText) } return nil } +func needsFreeBSDICUCopyHint(logText string) bool { + return strings.Contains(logText, `U_FILE_ACCESS_ERROR`) || + strings.Contains(logText, `could not open collator for locale "und"`) || + strings.Contains(logText, `pg_collation_actual_version(oid)`) || + strings.Contains(logText, `icu`) +} + +func freeBSDICUCopyHint(binariesPath string) string { + icuRoot := filepath.Join(binariesPath, "share", "icu") + return fmt.Sprintf( + "\nFreeBSD ICU hint:\n"+ + " The embedded bundle ships ICU data under %s\n"+ + " If PostgreSQL still fails to load collations, copy the bundled ICU version directory into /usr/local/share/icu/.\n"+ + " Example:\n"+ + " cp -R %s/* /usr/local/share/icu/\n", + icuRoot, + icuRoot, + ) +} + func stopPostgres(ep *EmbeddedPostgres) error { postgresBinary := filepath.Join(ep.config.binariesPath, "bin/pg_ctl") - postgresProcess := exec.Command(postgresBinary, "stop", "-w", - "-D", ep.config.dataPath) + postgresProcess := wrapCommandForRuntimeUser(exec.Command(postgresBinary, "stop", "-w", + "-D", ep.config.dataPath)) postgresProcess.Stderr = ep.syncedLogger.file postgresProcess.Stdout = ep.syncedLogger.file diff --git a/embedded_postgres_test.go b/embedded_postgres_test.go index 6429a89..ce4544a 100644 --- a/embedded_postgres_test.go +++ b/embedded_postgres_test.go @@ -65,11 +65,11 @@ func Test_ErrorWhenPortAlreadyTaken(t *testing.T) { } func Test_ErrorWhenRemoteFetchError(t *testing.T) { - database := NewDatabase() + database := NewDatabase(DefaultConfig().Port(9875)) database.cacheLocator = func() (string, bool) { return "", false } - database.remoteFetchStrategy = func() error { + database.remoteFetchStrategy = func(progressLogger) error { return errors.New("did not work") } @@ -156,7 +156,7 @@ func Test_ErrorWhenUnableToCreateDatabase(t *testing.T) { RuntimePath(extractPath). StartTimeout(10 * time.Second)) - database.createDatabase = func(port uint32, username, password, database string) error { + database.createDatabase = func(host string, port uint32, username, password, database string) error { return errors.New("ah noes") } @@ -176,7 +176,7 @@ func Test_TimesOutWhenCannotStart(t *testing.T) { Database("something-fancy"). StartTimeout(500 * time.Millisecond)) - database.createDatabase = func(port uint32, username, password, database string) error { + database.createDatabase = func(host string, port uint32, username, password, database string) error { return nil } @@ -328,7 +328,7 @@ func Test_CustomLog(t *testing.T) { assert.Contains(t, lines, fmt.Sprintf("The files belonging to this database system will be owned by user \"%s\".", current.Username)) assert.Contains(t, lines, "syncing data to disk ... ok") assert.Contains(t, lines, "server stopped") - assert.Less(t, len(lines), 55) + assert.Less(t, len(lines), 70) assert.Greater(t, len(lines), 40) } @@ -767,12 +767,12 @@ func Test_PrefetchedBinaries(t *testing.T) { RuntimePath(runtimeTempDir)) // Download and unarchive postgres into the bindir. - if err := database.remoteFetchStrategy(); err != nil { + if err := database.remoteFetchStrategy(nil); err != nil { panic(err) } cacheLocation, _ := database.cacheLocator() - if err := decompressTarXz(defaultTarReader, cacheLocation, binTempDir); err != nil { + if err := decompressTarXz(defaultTarReader, cacheLocation, binTempDir, nil); err != nil { panic(err) } @@ -780,7 +780,7 @@ func Test_PrefetchedBinaries(t *testing.T) { database.cacheLocator = func() (string, bool) { return "", false } - database.remoteFetchStrategy = func() error { + database.remoteFetchStrategy = func(progressLogger) error { return errors.New("did not work") } @@ -833,3 +833,51 @@ func Test_RunningInParallel(t *testing.T) { waitGroup.Wait() } + +func Test_RunOnUnixSocket(t *testing.T) { + database := NewDatabase(DefaultConfig().Port(9876).WithoutTcp()) + if err := database.Start(); err != nil { + shutdownDBAndFail(t, err, database) + } + + defer database.Stop() + + if _, err := os.Stat("/tmp/.s.PGSQL.9876"); err != nil { + shutdownDBAndFail(t, err, database) + } +} + +func Test_RunOnUnixSocketOnCustomPath(t *testing.T) { + tempPath, err := os.MkdirTemp("", "custom_dir_socks") + if err != nil { + panic(err) + } + + defer os.RemoveAll(tempPath) + + database := NewDatabase(DefaultConfig().Port(9876).WithoutTcp().WithUnixSocketDirectory(tempPath)) + if err := database.Start(); err != nil { + shutdownDBAndFail(t, err, database) + } + + defer database.Stop() + + if _, err := os.Stat(fmt.Sprintf("%s/.s.PGSQL.9876", tempPath)); err != nil { + shutdownDBAndFail(t, err, database) + } +} + +func Test_RunOnUnixSocket_IgnoresBusyTCPPort(t *testing.T) { + listener, err := net.Listen("tcp", "localhost:5432") + if err != nil { + t.Skip("tcp/5432 already in use on this machine") + } + defer listener.Close() + + database := NewDatabase(DefaultConfig().WithoutTcp()) + if err := database.Start(); err != nil { + shutdownDBAndFail(t, err, database) + } + + defer database.Stop() +} diff --git a/examples/embedded_postgres_config_test.go b/examples/embedded_postgres_config_test.go new file mode 100644 index 0000000..ad13dc4 --- /dev/null +++ b/examples/embedded_postgres_config_test.go @@ -0,0 +1,59 @@ +package examples + +import ( + "fmt" + "os" + "strconv" + + embeddedpostgres "github.com/fergusstrange/embedded-postgres" +) + +func exampleDatabaseConfig() embeddedpostgres.Config { + config := embeddedpostgres.DefaultConfig() + + if version := os.Getenv("EMBEDDED_POSTGRES_VERSION"); version != "" { + config = config.Version(embeddedpostgres.PostgresVersion(version)) + } + if platform := os.Getenv("EMBEDDED_POSTGRES_PLATFORM"); platform != "" { + config = config.Platform(platform) + } + if binariesPath := os.Getenv("EMBEDDED_POSTGRES_BINARIES_PATH"); binariesPath != "" { + config = config.BinariesPath(binariesPath) + } + if runtimePath := os.Getenv("EMBEDDED_POSTGRES_RUNTIME_PATH"); runtimePath != "" { + config = config.RuntimePath(runtimePath) + } + if dataPath := os.Getenv("EMBEDDED_POSTGRES_DATA_PATH"); dataPath != "" { + config = config.DataPath(dataPath) + } + if repositoryURL := os.Getenv("EMBEDDED_POSTGRES_BINARY_REPOSITORY_URL"); repositoryURL != "" { + config = config.BinaryRepositoryURL(repositoryURL) + } + if port := os.Getenv("EMBEDDED_POSTGRES_PORT"); port != "" { + if parsedPort, err := strconv.ParseUint(port, 10, 32); err == nil { + config = config.Port(uint32(parsedPort)) + } + } + + return config +} + +func newExampleDatabase() *embeddedpostgres.EmbeddedPostgres { + return embeddedpostgres.NewDatabase(exampleDatabaseConfig()) +} + +func exampleDatabasePort() uint32 { + if port := os.Getenv("EMBEDDED_POSTGRES_PORT"); port != "" { + if parsedPort, err := strconv.ParseUint(port, 10, 32); err == nil { + return uint32(parsedPort) + } + } + return 5432 +} + +func exampleConnectionString() string { + return fmt.Sprintf( + "host=localhost port=%d user=postgres password=postgres dbname=postgres sslmode=disable", + exampleDatabasePort(), + ) +} diff --git a/examples/examples_test.go b/examples/examples_test.go index 9b100ac..cb2aae9 100644 --- a/examples/examples_test.go +++ b/examples/examples_test.go @@ -4,6 +4,7 @@ import ( "fmt" "net/http" "net/http/httptest" + "net/url" "reflect" "testing" @@ -16,7 +17,7 @@ import ( ) func Test_GooseMigrations(t *testing.T) { - database := embeddedpostgres.NewDatabase() + database := newExampleDatabase() if err := database.Start(); err != nil { t.Fatal(err) } @@ -27,7 +28,7 @@ func Test_GooseMigrations(t *testing.T) { } }() - db, err := connect() + db, err := connect(database.GetConnectionURL()) if err != nil { t.Fatal(err) } @@ -45,8 +46,8 @@ func Test_ZapioLogger(t *testing.T) { w := &zapio.Writer{Log: logger} - database := embeddedpostgres.NewDatabase(embeddedpostgres.DefaultConfig(). - Logger(w)) + config := exampleDatabaseConfig().Logger(w) + database := embeddedpostgres.NewDatabase(config) if err := database.Start(); err != nil { t.Fatal(err) } @@ -57,7 +58,7 @@ func Test_ZapioLogger(t *testing.T) { } }() - db, err := connect() + db, err := connect(database.GetConnectionURL()) if err != nil { t.Fatal(err) } @@ -68,7 +69,7 @@ func Test_ZapioLogger(t *testing.T) { } func Test_Sqlx_SelectOne(t *testing.T) { - database := embeddedpostgres.NewDatabase() + database := newExampleDatabase() if err := database.Start(); err != nil { t.Fatal(err) } @@ -79,7 +80,36 @@ func Test_Sqlx_SelectOne(t *testing.T) { } }() - db, err := connect() + db, err := connect(database.GetConnectionURL()) + if err != nil { + t.Fatal(err) + } + + rows := make([]int32, 0) + + err = db.Select(&rows, "SELECT 1") + if err != nil { + t.Fatal(err) + } + + if len(rows) != 1 { + t.Fatal("Expected one row returned") + } +} + +func Test_UnixSocket_Sqlx_SelectOne(t *testing.T) { + database := embeddedpostgres.NewDatabase(embeddedpostgres.DefaultConfig().Port(9878).WithoutTcp()) + if err := database.Start(); err != nil { + t.Fatal(err) + } + + defer func() { + if err := database.Stop(); err != nil { + t.Fatal(err) + } + }() + + db, err := connect(database.GetConnectionURL()) if err != nil { t.Fatal(err) } @@ -97,7 +127,7 @@ func Test_Sqlx_SelectOne(t *testing.T) { } func Test_ManyTestsAgainstOneDatabase(t *testing.T) { - database := embeddedpostgres.NewDatabase() + database := newExampleDatabase() if err := database.Start(); err != nil { t.Fatal(err) } @@ -108,7 +138,7 @@ func Test_ManyTestsAgainstOneDatabase(t *testing.T) { } }() - db, err := connect() + db, err := connect(database.GetConnectionURL()) if err != nil { t.Fatal(err) } @@ -160,7 +190,7 @@ func Test_ManyTestsAgainstOneDatabase(t *testing.T) { } func Test_SimpleHttpWebApp(t *testing.T) { - database := embeddedpostgres.NewDatabase() + database := newExampleDatabase() if err := database.Start(); err != nil { t.Fatal(err) } @@ -188,7 +218,18 @@ func Test_SimpleHttpWebApp(t *testing.T) { } } -func connect() (*sqlx.DB, error) { - db, err := sqlx.Connect("postgres", "host=localhost port=5432 user=postgres password=postgres dbname=postgres sslmode=disable") +func connect(connectionURL string) (*sqlx.DB, error) { + parsed, err := url.Parse(connectionURL) + if err != nil { + return nil, err + } + + q := parsed.Query() + if q.Get("sslmode") == "" { + q.Set("sslmode", "disable") + } + parsed.RawQuery = q.Encode() + + db, err := sqlx.Connect("postgres", parsed.String()) return db, err } diff --git a/examples/server_test.go b/examples/server_test.go index 24d3090..0ac5da4 100644 --- a/examples/server_test.go +++ b/examples/server_test.go @@ -25,7 +25,7 @@ func (a *App) Start() error { } func NewApp() *App { - db, err := sqlx.Connect("postgres", "host=localhost port=5432 user=postgres password=postgres dbname=postgres sslmode=disable") + db, err := sqlx.Connect("postgres", exampleConnectionString()) if err != nil { log.Fatal(err) } diff --git a/icu_hint_test.go b/icu_hint_test.go new file mode 100644 index 0000000..1f07a5e --- /dev/null +++ b/icu_hint_test.go @@ -0,0 +1,31 @@ +package embeddedpostgres + +import ( + "path/filepath" + "strings" + "testing" +) + +func TestNeedsFreeBSDICUCopyHint(t *testing.T) { + t.Parallel() + + if !needsFreeBSDICUCopyHint(`could not open collator for locale "und": U_FILE_ACCESS_ERROR`) { + t.Fatal("expected ICU copy hint to be needed") + } + if needsFreeBSDICUCopyHint(`database system is ready to accept connections`) { + t.Fatal("did not expect ICU copy hint to be needed") + } +} + +func TestFreeBSDICUCopyHint(t *testing.T) { + t.Parallel() + + got := freeBSDICUCopyHint("/opt/embedded-postgres") + wantPath := filepath.Join("/opt/embedded-postgres", "share", "icu") + if !strings.Contains(got, wantPath) { + t.Fatalf("hint %q does not contain %q", got, wantPath) + } + if !strings.Contains(got, "cp -R") { + t.Fatalf("hint %q does not contain copy command", got) + } +} diff --git a/logging.go b/logging.go index 364e932..cd8d3ec 100644 --- a/logging.go +++ b/logging.go @@ -15,6 +15,11 @@ type syncedLogger struct { } func newSyncedLogger(dir string, logger io.Writer) (*syncedLogger, error) { + if dir != "" { + if err := os.MkdirAll(dir, 0o755); err != nil { + return nil, err + } + } file, err := os.CreateTemp(dir, "embedded_postgres_log") if err != nil { return nil, err @@ -56,6 +61,21 @@ func (s *syncedLogger) flush() error { return nil } +func (s *syncedLogger) logf(format string, args ...any) { + if s == nil || s.file == nil { + return + } + if _, err := fmt.Fprintf(s.file, format+"\n", args...); err != nil { + panic(err) + } + if err := s.file.Sync(); err != nil { + panic(err) + } + if err := s.flush(); err != nil { + panic(err) + } +} + func readLogsOrTimeout(logger *os.File) (logContent []byte, err error) { logContent = []byte("logs could not be read") diff --git a/prepare_database.go b/prepare_database.go index 751aaea..83ae30f 100644 --- a/prepare_database.go +++ b/prepare_database.go @@ -9,6 +9,7 @@ import ( "os" "os/exec" "path/filepath" + "runtime" "github.com/lib/pq" ) @@ -19,7 +20,7 @@ const ( ) type initDatabase func(binaryExtractLocation, runtimePath, pgDataDir, username, password, locale string, encoding string, logger *os.File) error -type createDatabase func(port uint32, username, password, database string) error +type createDatabase func(host string, port uint32, username, password, database string) error func defaultInitDatabase(binaryExtractLocation, runtimePath, pgDataDir, username, password, locale string, encoding string, logger *os.File) error { passwordFile, err := createPasswordFile(runtimePath, password) @@ -42,8 +43,12 @@ func defaultInitDatabase(binaryExtractLocation, runtimePath, pgDataDir, username args = append(args, fmt.Sprintf("--encoding=%s", encoding)) } + if err := ensureFreeBSDRuntimeUser(binaryExtractLocation, runtimePath, pgDataDir); err != nil { + return err + } + postgresInitDBBinary := filepath.Join(binaryExtractLocation, "bin/initdb") - postgresInitDBProcess := exec.Command(postgresInitDBBinary, args...) + postgresInitDBProcess := wrapCommandForRuntimeUser(exec.Command(postgresInitDBBinary, args...)) postgresInitDBProcess.Stderr = logger postgresInitDBProcess.Stdout = logger @@ -52,7 +57,11 @@ func defaultInitDatabase(binaryExtractLocation, runtimePath, pgDataDir, username if readLogsErr != nil { logContent = []byte(string(logContent) + " - " + readLogsErr.Error()) } - return fmt.Errorf("unable to init database using '%s': %w\n%s", postgresInitDBProcess.String(), err, string(logContent)) + logText := string(logContent) + if runtime.GOOS == "freebsd" && needsFreeBSDICUCopyHint(logText) { + logText += freeBSDICUCopyHint(binaryExtractLocation) + } + return fmt.Errorf("unable to init database using '%s': %w\n%s", postgresInitDBProcess.String(), err, logText) } if err = os.Remove(passwordFile); err != nil { @@ -71,12 +80,12 @@ func createPasswordFile(runtimePath, password string) (string, error) { return passwordFileLocation, nil } -func defaultCreateDatabase(port uint32, username, password, database string) (err error) { +func defaultCreateDatabase(host string, port uint32, username, password, database string) (err error) { if database == "postgres" { return nil } - conn, err := openDatabaseConnection(port, username, password, "postgres") + conn, err := openDatabaseConnection(host, port, username, password, "postgres") if err != nil { return errorCustomDatabase(database, err) } @@ -120,7 +129,12 @@ func healthCheckDatabaseOrTimeout(config Config) error { go func() { for timeout.Err() == nil { - if err := healthCheckDatabase(config.port, config.database, config.username, config.password); err != nil { + host := "localhost" + if config.useUnixSocket { + host = config.unixSocketDirectory + } + + if err := healthCheckDatabase(host, config.port, config.database, config.username, config.password); err != nil { continue } healthCheckSignal <- true @@ -137,8 +151,8 @@ func healthCheckDatabaseOrTimeout(config Config) error { } } -func healthCheckDatabase(port uint32, database, username, password string) (err error) { - conn, err := openDatabaseConnection(port, username, password, database) +func healthCheckDatabase(host string, port uint32, database, username, password string) (err error) { + conn, err := openDatabaseConnection(host, port, username, password, database) if err != nil { return err } @@ -155,8 +169,9 @@ func healthCheckDatabase(port uint32, database, username, password string) (err return nil } -func openDatabaseConnection(port uint32, username string, password string, database string) (*pq.Connector, error) { - conn, err := pq.NewConnector(fmt.Sprintf("host=localhost port=%d user=%s password=%s dbname=%s sslmode=disable", +func openDatabaseConnection(host string, port uint32, username string, password string, database string) (*pq.Connector, error) { + conn, err := pq.NewConnector(fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", + host, port, username, password, diff --git a/prepare_database_test.go b/prepare_database_test.go index 2700d27..477bf2c 100644 --- a/prepare_database_test.go +++ b/prepare_database_test.go @@ -132,7 +132,13 @@ func Test_defaultInitDatabase_PwFileRemoved(t *testing.T) { } func Test_defaultCreateDatabase_ErrorWhenSQLOpenError(t *testing.T) { - err := defaultCreateDatabase(1234, "user client_encoding=lol", "password", "database") + err := defaultCreateDatabase("localhost", 1234, "user client_encoding=lol", "password", "database") + + assert.EqualError(t, err, "unable to connect to create database with custom name database with the following error: client_encoding must be absent or 'UTF8'") +} + +func Test_defaultCreateDatabase_ErrorWhenSQLOpenError_UnixSocket(t *testing.T) { + err := defaultCreateDatabase("/tmp", 1234, "user client_encoding=lol", "password", "database") assert.EqualError(t, err, "unable to connect to create database with custom name database with the following error: client_encoding must be absent or 'UTF8'") } @@ -165,13 +171,13 @@ func Test_defaultCreateDatabase_ErrorWhenQueryError(t *testing.T) { } }() - err := defaultCreateDatabase(9831, "postgres", "postgres", "b33r") + err := defaultCreateDatabase("localhost", 9831, "postgres", "postgres", "b33r") assert.EqualError(t, err, `unable to connect to create database with custom name b33r with the following error: pq: database "b33r" already exists`) } func Test_healthCheckDatabase_ErrorWhenSQLConnectingError(t *testing.T) { - err := healthCheckDatabase(1234, "tom client_encoding=lol", "more", "b33r") + err := healthCheckDatabase("localhost", 1234, "tom client_encoding=lol", "more", "b33r") assert.EqualError(t, err, "client_encoding must be absent or 'UTF8'") } diff --git a/preset/local_development.go b/preset/local_development.go new file mode 100644 index 0000000..a317419 --- /dev/null +++ b/preset/local_development.go @@ -0,0 +1,170 @@ +package preset + +import ( + "io" + "os" + "path/filepath" + "runtime" + "strings" + "time" + + embeddedpostgres "github.com/fergusstrange/embedded-postgres" +) + +const ( + defaultAppName = "embedded-postgres" + defaultFreeBSDPlatform = "freebsd13" +) + +// Options customizes the LocalDevelopment preset. +type Options struct { + Version embeddedpostgres.PostgresVersion + Platform string + Port uint32 + Database string + Username string + Password string + CachePath string + RuntimePath string + DataPath string + BinariesPath string + BinaryRepositoryURL string + StartTimeout time.Duration + StartParameters map[string]string + Logger io.Writer +} + +type resolvedConfig struct { + version embeddedpostgres.PostgresVersion + platform string + port uint32 + database string + username string + password string + cachePath string + runtimePath string + dataPath string + binariesPath string + binaryRepositoryURL string + startTimeout time.Duration + startParameters map[string]string + logger io.Writer +} + +// LocalDevelopment returns a config with filesystem defaults that work well for +// local macOS development and FreeBSD 13 hosts. +// +// On FreeBSD amd64 it pins the default artifact line to freebsd13 unless +// Options.Platform overrides it. Runtime files are kept under /var/tmp and the +// data directory is placed under /var/db so it survives process restarts. +func LocalDevelopment(appName string, options Options) embeddedpostgres.Config { + resolved := resolveLocalDevelopment(runtime.GOOS, os.TempDir(), appName, options) + + config := embeddedpostgres.DefaultConfig(). + RuntimePath(resolved.runtimePath). + DataPath(resolved.dataPath) + + if resolved.version != "" { + config = config.Version(resolved.version) + } + if resolved.platform != "" { + config = config.Platform(resolved.platform) + } + if resolved.port != 0 { + config = config.Port(resolved.port) + } + if resolved.database != "" { + config = config.Database(resolved.database) + } + if resolved.username != "" { + config = config.Username(resolved.username) + } + if resolved.password != "" { + config = config.Password(resolved.password) + } + if resolved.cachePath != "" { + config = config.CachePath(resolved.cachePath) + } + if resolved.binariesPath != "" { + config = config.BinariesPath(resolved.binariesPath) + } + if resolved.binaryRepositoryURL != "" { + config = config.BinaryRepositoryURL(resolved.binaryRepositoryURL) + } + if resolved.startTimeout != 0 { + config = config.StartTimeout(resolved.startTimeout) + } + if resolved.startParameters != nil { + config = config.StartParameters(resolved.startParameters) + } + if resolved.logger != nil { + config = config.Logger(resolved.logger) + } + + return config +} + +func resolveLocalDevelopment(goos, tempDir, appName string, options Options) resolvedConfig { + runtimePath, dataPath := defaultPaths(goos, tempDir, appName) + + resolved := resolvedConfig{ + version: options.Version, + port: options.Port, + database: options.Database, + username: options.Username, + password: options.Password, + cachePath: options.CachePath, + runtimePath: runtimePath, + dataPath: dataPath, + binariesPath: options.BinariesPath, + binaryRepositoryURL: options.BinaryRepositoryURL, + startTimeout: options.StartTimeout, + startParameters: options.StartParameters, + logger: options.Logger, + } + + if goos == "freebsd" { + resolved.platform = defaultFreeBSDPlatform + } + + if options.Platform != "" { + resolved.platform = options.Platform + } + if options.RuntimePath != "" { + resolved.runtimePath = options.RuntimePath + } + if options.DataPath != "" { + resolved.dataPath = options.DataPath + } + + return resolved +} + +func defaultPaths(goos, tempDir, appName string) (string, string) { + name := sanitizeAppName(appName) + + if goos == "freebsd" { + runtimeBase := filepath.Join("/var/tmp", name, "embedded-postgres") + dataBase := filepath.Join("/var/db", name, "embedded-postgres") + return filepath.Join(runtimeBase, "runtime"), filepath.Join(dataBase, "data") + } + + base := filepath.Join(tempDir, name, "embedded-postgres") + return filepath.Join(base, "runtime"), filepath.Join(base, "data") +} + +func sanitizeAppName(appName string) string { + appName = strings.TrimSpace(appName) + if appName == "" { + return defaultAppName + } + + replacer := strings.NewReplacer( + string(os.PathSeparator), "-", + "/", "-", + "\\", "-", + " ", "-", + ) + + return replacer.Replace(appName) +} diff --git a/preset/local_development_test.go b/preset/local_development_test.go new file mode 100644 index 0000000..38d7969 --- /dev/null +++ b/preset/local_development_test.go @@ -0,0 +1,64 @@ +package preset + +import ( + "path/filepath" + "testing" + "time" + + embeddedpostgres "github.com/fergusstrange/embedded-postgres" + "github.com/stretchr/testify/assert" +) + +func TestResolveLocalDevelopment_DarwinDefaults(t *testing.T) { + resolved := resolveLocalDevelopment("darwin", "/tmp", "my app", Options{}) + + assert.Equal(t, "", resolved.platform) + assert.Equal(t, filepath.Join("/tmp", "my-app", "embedded-postgres", "runtime"), resolved.runtimePath) + assert.Equal(t, filepath.Join("/tmp", "my-app", "embedded-postgres", "data"), resolved.dataPath) +} + +func TestResolveLocalDevelopment_FreeBSDDefaults(t *testing.T) { + resolved := resolveLocalDevelopment("freebsd", "/tmp", "my app", Options{}) + + assert.Equal(t, "freebsd13", resolved.platform) + assert.Equal(t, filepath.Join("/var/tmp", "my-app", "embedded-postgres", "runtime"), resolved.runtimePath) + assert.Equal(t, filepath.Join("/var/db", "my-app", "embedded-postgres", "data"), resolved.dataPath) +} + +func TestResolveLocalDevelopment_OptionsOverrideDefaults(t *testing.T) { + resolved := resolveLocalDevelopment("freebsd", "/tmp", "my app", Options{ + Version: embeddedpostgres.V17, + Platform: "freebsd14", + Port: 5544, + Database: "appdb", + Username: "appuser", + Password: "secret", + CachePath: "/cache", + RuntimePath: "/runtime", + DataPath: "/data", + BinariesPath: "/binaries", + BinaryRepositoryURL: "https://repo.local/maven2", + StartTimeout: 45 * time.Second, + StartParameters: map[string]string{"max_connections": "200"}, + }) + + assert.Equal(t, embeddedpostgres.V17, resolved.version) + assert.Equal(t, "freebsd14", resolved.platform) + assert.Equal(t, uint32(5544), resolved.port) + assert.Equal(t, "appdb", resolved.database) + assert.Equal(t, "appuser", resolved.username) + assert.Equal(t, "secret", resolved.password) + assert.Equal(t, "/cache", resolved.cachePath) + assert.Equal(t, "/runtime", resolved.runtimePath) + assert.Equal(t, "/data", resolved.dataPath) + assert.Equal(t, "/binaries", resolved.binariesPath) + assert.Equal(t, "https://repo.local/maven2", resolved.binaryRepositoryURL) + assert.Equal(t, 45*time.Second, resolved.startTimeout) + assert.Equal(t, map[string]string{"max_connections": "200"}, resolved.startParameters) +} + +func TestSanitizeAppName(t *testing.T) { + assert.Equal(t, defaultAppName, sanitizeAppName(" ")) + assert.Equal(t, "my-app-name", sanitizeAppName("my/app name")) + assert.Equal(t, "windows-path", sanitizeAppName("windows\\path")) +} diff --git a/remote_fetch.go b/remote_fetch.go index bc95756..b2b2b4a 100644 --- a/remote_fetch.go +++ b/remote_fetch.go @@ -16,12 +16,23 @@ import ( ) // RemoteFetchStrategy provides a strategy to fetch a Postgres binary so that it is available for use. -type RemoteFetchStrategy func() error +type RemoteFetchStrategy func(progressLogger) error + +var freeBSDBinaryRepositoryURL = "https://web.sintel.com.tr/downloads/siper/pg" //nolint:funlen func defaultRemoteFetchStrategy(remoteFetchHost string, versionStrategy VersionStrategy, cacheLocator CacheLocator) RemoteFetchStrategy { - return func() error { + return func(logf progressLogger) error { operatingSystem, architecture, version := versionStrategy() + cacheLocation, _ := cacheLocator() + + if freeBSDDirectDownloadURL, ok := freeBSDBundleDownloadURL(operatingSystem, architecture); ok { + logProgress(logf, "downloading embedded postgres archive url=%s cache=%s", freeBSDDirectDownloadURL, cacheLocation) + if err := downloadArchiveToCache(freeBSDDirectDownloadURL, cacheLocation, logf); err != nil { + return err + } + return nil + } jarDownloadURL := fmt.Sprintf("%s/io/zonky/test/postgres/embedded-postgres-binaries-%s-%s/%s/embedded-postgres-binaries-%s-%s-%s.jar", remoteFetchHost, @@ -31,6 +42,7 @@ func defaultRemoteFetchStrategy(remoteFetchHost string, versionStrategy VersionS operatingSystem, architecture, version) + logProgress(logf, "downloading embedded postgres bundle url=%s cache=%s", jarDownloadURL, cacheLocation) jarDownloadResponse, err := http.Get(jarDownloadURL) if err != nil { @@ -64,8 +76,69 @@ func defaultRemoteFetchStrategy(remoteFetchHost string, versionStrategy VersionS } } - return decompressResponse(jarBodyBytes, jarDownloadResponse.ContentLength, cacheLocator, jarDownloadURL) + return decompressResponse(jarBodyBytes, jarDownloadResponse.ContentLength, cacheLocator, jarDownloadURL, logf) + } +} + +func freeBSDBundleDownloadURL(operatingSystem, architecture string) (string, bool) { + switch { + case operatingSystem == "freebsd13" && architecture == "amd64": + return freeBSDBinaryRepositoryURL + "/postgres-freebsd13-x86_64.txz", true + case operatingSystem == "freebsd14" && architecture == "amd64": + return freeBSDBinaryRepositoryURL + "/postgres-freebsd14-x86_64.txz", true + default: + return "", false + } +} + +func downloadArchiveToCache(downloadURL, cacheLocation string, logf progressLogger) error { + downloadResponse, err := http.Get(downloadURL) + if err != nil { + return fmt.Errorf("unable to connect to %s", downloadURL) + } + defer closeBody(downloadResponse)() + + if downloadResponse.StatusCode != http.StatusOK { + return fmt.Errorf("no version found matching archive at %s", downloadURL) + } + + archiveBytes, err := io.ReadAll(downloadResponse.Body) + if err != nil { + return errorFetchingPostgres(err) + } + if err := os.MkdirAll(filepath.Dir(cacheLocation), 0755); err != nil { + return errorExtractingPostgres(err) + } + logProgress(logf, "downloaded embedded postgres archive url=%s cache=%s bytes=%d", downloadURL, cacheLocation, len(archiveBytes)) + return writeArchiveAtomically(cacheLocation, archiveBytes) +} + +func writeArchiveAtomically(cacheLocation string, archiveBytes []byte) error { + renamed := false + + tmp, err := os.CreateTemp(filepath.Dir(cacheLocation), "temp_") + if err != nil { + return errorExtractingPostgres(err) + } + defer func() { + if !renamed { + if err := os.Remove(tmp.Name()); err != nil { + panic(err) + } + } + }() + + if _, err := tmp.Write(archiveBytes); err != nil { + return errorExtractingPostgres(err) + } + if err := tmp.Close(); err != nil { + return errorExtractingPostgres(err) } + if err := renameOrIgnore(tmp.Name(), cacheLocation); err != nil { + return errorExtractingPostgres(err) + } + renamed = true + return nil } func closeBody(resp *http.Response) func() { @@ -79,7 +152,7 @@ func closeBody(resp *http.Response) func() { } } -func decompressResponse(bodyBytes []byte, contentLength int64, cacheLocator CacheLocator, downloadURL string) error { +func decompressResponse(bodyBytes []byte, contentLength int64, cacheLocator CacheLocator, downloadURL string, logf progressLogger) error { size := contentLength // if the content length is not set (i.e. chunked encoding), // we need to use the length of the bodyBytes otherwise @@ -100,6 +173,7 @@ func decompressResponse(bodyBytes []byte, contentLength int64, cacheLocator Cach for _, file := range zipReader.File { if !file.FileHeader.FileInfo().IsDir() && strings.HasSuffix(file.FileHeader.Name, ".txz") { + logProgress(logf, "writing embedded postgres archive from bundle url=%s entry=%s cache=%s", downloadURL, file.FileHeader.Name, cacheLocation) if err := decompressSingleFile(file, cacheLocation); err != nil { return err } @@ -113,8 +187,6 @@ func decompressResponse(bodyBytes []byte, contentLength int64, cacheLocator Cach } func decompressSingleFile(file *zip.File, cacheLocation string) error { - renamed := false - archiveReader, err := file.Open() if err != nil { return errorExtractingPostgres(err) @@ -124,40 +196,7 @@ func decompressSingleFile(file *zip.File, cacheLocation string) error { if err != nil { return errorExtractingPostgres(err) } - - // if multiple processes attempt to extract - // to prevent file corruption when multiple processes attempt to extract at the same time - // first to a cache location, and then move the file into place. - tmp, err := os.CreateTemp(filepath.Dir(cacheLocation), "temp_") - if err != nil { - return errorExtractingPostgres(err) - } - defer func() { - // if anything failed before the rename then the temporary file should be cleaned up. - // if the rename was successful then there is no temporary file to remove. - if !renamed { - if err := os.Remove(tmp.Name()); err != nil { - panic(err) - } - } - }() - - if _, err := tmp.Write(archiveBytes); err != nil { - return errorExtractingPostgres(err) - } - - // Windows cannot rename a file if is it still open. - // The file needs to be manually closed to allow the rename to happen - if err := tmp.Close(); err != nil { - return errorExtractingPostgres(err) - } - - if err := renameOrIgnore(tmp.Name(), cacheLocation); err != nil { - return errorExtractingPostgres(err) - } - renamed = true - - return nil + return writeArchiveAtomically(cacheLocation, archiveBytes) } func errorExtractingPostgres(err error) error { diff --git a/remote_fetch_test.go b/remote_fetch_test.go index 2b4a76b..0626923 100644 --- a/remote_fetch_test.go +++ b/remote_fetch_test.go @@ -2,8 +2,10 @@ package embeddedpostgres import ( "archive/zip" + "bytes" "crypto/sha256" "encoding/hex" + "fmt" "github.com/stretchr/testify/require" "io" "net/http" @@ -22,7 +24,7 @@ func Test_defaultRemoteFetchStrategy_ErrorWhenHttpGet(t *testing.T) { testVersionStrategy(), testCacheLocator()) - err := remoteFetchStrategy() + err := remoteFetchStrategy(nil) assert.EqualError(t, err, "unable to connect to http://localhost:1234/maven2") } @@ -37,11 +39,126 @@ func Test_defaultRemoteFetchStrategy_ErrorWhenHttpStatusNot200(t *testing.T) { testVersionStrategy(), testCacheLocator()) - err := remoteFetchStrategy() + err := remoteFetchStrategy(nil) assert.EqualError(t, err, "no version found matching 1.2.3") } +func Test_defaultRemoteFetchStrategy_UsesConfiguredPlatformOverride(t *testing.T) { + requests := make([]string, 0, 2) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requests = append(requests, r.RequestURI) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + versionStrategy := defaultVersionStrategy( + DefaultConfig(). + Version(PostgresVersion("18.3.0")). + Platform("freebsd14"), + "freebsd", + "amd64", + linuxMachineName, + func() bool { + return false + }, + ) + + originalURL := freeBSDBinaryRepositoryURL + freeBSDBinaryRepositoryURL = server.URL + defer func() { + freeBSDBinaryRepositoryURL = originalURL + }() + + cacheDir := t.TempDir() + remoteFetchStrategy := defaultRemoteFetchStrategy(server.URL+"/maven2", versionStrategy, func() (string, bool) { + return filepath.Join(cacheDir, "freebsd14.txz"), false + }) + + err := remoteFetchStrategy(nil) + + require.NoError(t, err) + require.NotEmpty(t, requests) + assert.Equal(t, "/postgres-freebsd14-x86_64.txz", requests[0]) +} + +func Test_defaultRemoteFetchStrategy_UsesDefaultFreeBSD13Bundle(t *testing.T) { + requests := make([]string, 0, 2) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requests = append(requests, r.RequestURI) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + versionStrategy := defaultVersionStrategy( + DefaultConfig().Version(PostgresVersion("18.3.0")), + "freebsd", + "amd64", + linuxMachineName, + func() bool { + return false + }, + ) + + originalURL := freeBSDBinaryRepositoryURL + freeBSDBinaryRepositoryURL = server.URL + defer func() { + freeBSDBinaryRepositoryURL = originalURL + }() + + cacheDir := t.TempDir() + remoteFetchStrategy := defaultRemoteFetchStrategy(server.URL+"/maven2", versionStrategy, func() (string, bool) { + return filepath.Join(cacheDir, "freebsd13.txz"), false + }) + + err := remoteFetchStrategy(nil) + + require.NoError(t, err) + require.NotEmpty(t, requests) + assert.Equal(t, "/postgres-freebsd13-x86_64.txz", requests[0]) +} + +func Test_defaultRemoteFetchStrategy_LogsFreeBSDBundleDownload(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("txz")) + })) + defer server.Close() + + versionStrategy := defaultVersionStrategy( + DefaultConfig().Version(PostgresVersion("18.3.0")), + "freebsd", + "amd64", + linuxMachineName, + func() bool { + return false + }, + ) + + originalURL := freeBSDBinaryRepositoryURL + freeBSDBinaryRepositoryURL = server.URL + defer func() { + freeBSDBinaryRepositoryURL = originalURL + }() + + cacheDir := t.TempDir() + cacheLocation := filepath.Join(cacheDir, "freebsd13.txz") + remoteFetchStrategy := defaultRemoteFetchStrategy(server.URL+"/maven2", versionStrategy, func() (string, bool) { + return cacheLocation, false + }) + + var logs bytes.Buffer + err := remoteFetchStrategy(func(format string, args ...any) { + _, _ = fmt.Fprintf(&logs, format+"\n", args...) + }) + + require.NoError(t, err) + assert.Contains(t, logs.String(), "downloading embedded postgres archive") + assert.Contains(t, logs.String(), "/postgres-freebsd13-x86_64.txz") + assert.Contains(t, logs.String(), cacheLocation) + assert.Contains(t, logs.String(), "downloaded embedded postgres archive") +} + func Test_defaultRemoteFetchStrategy_ErrorWhenBodyReadIssue(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Length", "1") @@ -52,7 +169,7 @@ func Test_defaultRemoteFetchStrategy_ErrorWhenBodyReadIssue(t *testing.T) { testVersionStrategy(), testCacheLocator()) - err := remoteFetchStrategy() + err := remoteFetchStrategy(nil) assert.EqualError(t, err, "error fetching postgres: unexpected EOF") } @@ -70,7 +187,7 @@ func Test_defaultRemoteFetchStrategy_ErrorWhenCannotUnzipSubFile(t *testing.T) { testVersionStrategy(), testCacheLocator()) - err := remoteFetchStrategy() + err := remoteFetchStrategy(nil) assert.EqualError(t, err, "error fetching postgres: zip: not a valid zip file") } @@ -92,7 +209,7 @@ func Test_defaultRemoteFetchStrategy_ErrorWhenCannotUnzip(t *testing.T) { testVersionStrategy(), testCacheLocator()) - err := remoteFetchStrategy() + err := remoteFetchStrategy(nil) assert.EqualError(t, err, "error fetching postgres: zip: not a valid zip file") } @@ -116,7 +233,7 @@ func Test_defaultRemoteFetchStrategy_ErrorWhenNoSubTarArchive(t *testing.T) { testVersionStrategy(), testCacheLocator()) - err := remoteFetchStrategy() + err := remoteFetchStrategy(nil) assert.EqualError(t, err, "error fetching postgres: cannot find binary in archive retrieved from "+server.URL+"/maven2/io/zonky/test/postgres/embedded-postgres-binaries-darwin-amd64/1.2.3/embedded-postgres-binaries-darwin-amd64-1.2.3.jar") } @@ -147,7 +264,7 @@ func Test_defaultRemoteFetchStrategy_ErrorWhenCannotExtractSubArchive(t *testing return filepath.FromSlash("/invalid"), false }) - err := remoteFetchStrategy() + err := remoteFetchStrategy(nil) assert.Regexp(t, "^unable to extract postgres archive:.+$", err) } @@ -187,7 +304,7 @@ func Test_defaultRemoteFetchStrategy_ErrorWhenCannotCreateCacheDirectory(t *test return cacheLocation, false }) - err := remoteFetchStrategy() + err := remoteFetchStrategy(nil) assert.Regexp(t, "^unable to extract postgres archive:.+$", err) } @@ -224,7 +341,7 @@ func Test_defaultRemoteFetchStrategy_ErrorWhenCannotCreateSubArchiveFile(t *test return "/\\000", false }) - err := remoteFetchStrategy() + err := remoteFetchStrategy(nil) assert.Regexp(t, "^unable to extract postgres archive:.+$", err) } @@ -262,7 +379,7 @@ func Test_defaultRemoteFetchStrategy_ErrorWhenSHA256NotMatch(t *testing.T) { return cacheLocation, false }) - err := remoteFetchStrategy() + err := remoteFetchStrategy(nil) assert.EqualError(t, err, "downloaded checksums do not match") } @@ -301,7 +418,7 @@ func Test_defaultRemoteFetchStrategy(t *testing.T) { return cacheLocation, false }) - err := remoteFetchStrategy() + err := remoteFetchStrategy(nil) assert.NoError(t, err) assert.FileExists(t, cacheLocation) @@ -354,7 +471,7 @@ func Test_defaultRemoteFetchStrategyWithExistingDownload(t *testing.T) { }) // call it the remoteFetchStrategy(). The output location should be empty and a new file created - err = remoteFetchStrategy() + err = remoteFetchStrategy(nil) assert.NoError(t, err) assert.FileExists(t, cacheLocation) out1, err := os.ReadFile(cacheLocation) @@ -364,7 +481,7 @@ func Test_defaultRemoteFetchStrategyWithExistingDownload(t *testing.T) { assert.NoError(t, err) // call the remoteFetchStrategy() again, this time the file should be overwritten - err = remoteFetchStrategy() + err = remoteFetchStrategy(nil) assert.NoError(t, err) assert.FileExists(t, cacheLocation) @@ -412,7 +529,7 @@ func Test_defaultRemoteFetchStrategy_whenContentLengthNotSet(t *testing.T) { return cacheLocation, false }) - err = remoteFetchStrategy() + err = remoteFetchStrategy(nil) assert.NoError(t, err) assert.FileExists(t, cacheLocation) diff --git a/runtime_user.go b/runtime_user.go new file mode 100644 index 0000000..5741c8f --- /dev/null +++ b/runtime_user.go @@ -0,0 +1,91 @@ +package embeddedpostgres + +import ( + "fmt" + "os" + "os/exec" + "path/filepath" + "runtime" +) + +const freebsdRuntimeUser = "embeddedpg" + +func shouldUseFreeBSDRuntimeUser() bool { + return runtime.GOOS == "freebsd" && os.Geteuid() == 0 +} + +func ensureFreeBSDRuntimeUser(binaryExtractLocation, runtimePath, dataPath string) error { + if !shouldUseFreeBSDRuntimeUser() { + return nil + } + + if err := exec.Command("pw", "groupshow", freebsdRuntimeUser).Run(); err != nil { + cmd := exec.Command("pw", "useradd", freebsdRuntimeUser, "-m", "-s", "/bin/sh") + if output, addErr := cmd.CombinedOutput(); addErr != nil { + return errorWithOutput(addErr, output) + } + } + + for _, path := range []string{runtimePath, filepath.Dir(dataPath), dataPath} { + if path == "" { + continue + } + if err := os.MkdirAll(path, 0755); err != nil { + return err + } + } + + for _, path := range []string{binaryExtractLocation, runtimePath, dataPath, filepath.Dir(dataPath)} { + if path == "" { + continue + } + cmd := exec.Command("chown", "-R", freebsdRuntimeUser+":"+freebsdRuntimeUser, path) + if output, chownErr := cmd.CombinedOutput(); chownErr != nil { + return errorWithOutput(chownErr, output) + } + } + + return nil +} + +func wrapCommandForRuntimeUser(cmd *exec.Cmd) *exec.Cmd { + if !shouldUseFreeBSDRuntimeUser() { + return cmd + } + + args := append([]string{"-m", freebsdRuntimeUser, "-c", shellQuoteCommand(cmd.Path, cmd.Args[1:])}, []string{}...) + wrapped := exec.Command("su", args...) + wrapped.Stdout = cmd.Stdout + wrapped.Stderr = cmd.Stderr + wrapped.Env = cmd.Env + wrapped.Dir = cmd.Dir + return wrapped +} + +func shellQuoteCommand(binary string, args []string) string { + quoted := shellQuote(binary) + for _, arg := range args { + quoted += " " + shellQuote(arg) + } + return quoted +} + +func shellQuote(value string) string { + escaped := "'" + for _, r := range value { + if r == '\'' { + escaped += "'\"'\"'" + } else { + escaped += string(r) + } + } + escaped += "'" + return escaped +} + +func errorWithOutput(err error, output []byte) error { + if len(output) == 0 { + return err + } + return fmt.Errorf("%w: %s", err, string(output)) +} diff --git a/test_config.go b/test_config.go deleted file mode 100644 index 5646c9b..0000000 --- a/test_config.go +++ /dev/null @@ -1,12 +0,0 @@ -package embeddedpostgres - -import "testing" - -func TestGetConnectionURL(t *testing.T) { - config := DefaultConfig().Database("mydb").Username("myuser").Password("mypass") - expect := "postgresql://myuser:mypass@localhost:5432/mydb" - - if got := config.GetConnectionURL(); got != expect { - t.Errorf("expected \"%s\" got \"%s\"", expect, got) - } -} diff --git a/version_strategy.go b/version_strategy.go index 136826f..0f8db7e 100644 --- a/version_strategy.go +++ b/version_strategy.go @@ -16,7 +16,16 @@ func defaultVersionStrategy(config Config, goos, arch string, linuxMachineName f goos := goos arch := arch - if goos == "linux" { + if config.platform != "" { + goos = config.platform + } + + if config.platform == "" && goos == "freebsd" && arch == "amd64" { + // The current embedded FreeBSD artifact line is published as freebsd13-amd64. + goos = "freebsd13" + } + + if goos == "linux" || goos == "alpine" { // the zonkyio/embedded-postgres-binaries project produces // arm binaries with the following name schema: // 32bit: arm32v6 / arm32v7 @@ -32,7 +41,7 @@ func defaultVersionStrategy(config Config, goos, arch string, linuxMachineName f } } - if shouldUseAlpineLinuxBuild() { + if goos == "linux" && config.platform == "" && shouldUseAlpineLinuxBuild() { arch += "-alpine" } } diff --git a/version_strategy_test.go b/version_strategy_test.go index d4c89bf..7faa587 100644 --- a/version_strategy_test.go +++ b/version_strategy_test.go @@ -19,7 +19,7 @@ func Test_DefaultVersionStrategy_AllGolangDistributions(t *testing.T) { "darwin/arm64": {"darwin", "amd64"}, "dragonfly/amd64": {"dragonfly", "amd64"}, "freebsd/386": {"freebsd", "386"}, - "freebsd/amd64": {"freebsd", "amd64"}, + "freebsd/amd64": {"freebsd13", "amd64"}, "freebsd/arm": {"freebsd", "arm"}, "freebsd/arm64": {"freebsd", "arm64"}, "illumos/amd64": {"illumos", "amd64"}, @@ -142,6 +142,40 @@ func Test_DefaultVersionStrategy_Linux_Alpine(t *testing.T) { assert.Equal(t, V18, postgresVersion) } +func Test_DefaultVersionStrategy_FreeBSD_PlatformOverride(t *testing.T) { + operatingSystem, architecture, postgresVersion := defaultVersionStrategy( + DefaultConfig().Platform("freebsd14"), + "freebsd", + "amd64", + linuxMachineName, + func() bool { + return false + }, + )() + + assert.Equal(t, "freebsd14", operatingSystem) + assert.Equal(t, "amd64", architecture) + assert.Equal(t, V18, postgresVersion) +} + +func Test_DefaultVersionStrategy_Alpine_PlatformOverride(t *testing.T) { + operatingSystem, architecture, postgresVersion := defaultVersionStrategy( + DefaultConfig().Platform("alpine"), + "linux", + "arm64", + func() string { + return "" + }, + func() bool { + return false + }, + )() + + assert.Equal(t, "alpine", operatingSystem) + assert.Equal(t, "arm64v8", architecture) + assert.Equal(t, V18, postgresVersion) +} + func Test_DefaultVersionStrategy_shouldUseAlpineLinuxBuild(t *testing.T) { assert.NotPanics(t, func() { shouldUseAlpineLinuxBuild()