diff --git a/src/update/BUILD b/src/update/BUILD index 7fffc5204..7729885a4 100644 --- a/src/update/BUILD +++ b/src/update/BUILD @@ -48,6 +48,7 @@ go_test( "///third_party/go/github.com_sigstore_sigstore//pkg/cryptoutils", "///third_party/go/github.com_sigstore_sigstore//pkg/signature", "///third_party/go/github.com_stretchr_testify//assert", + "///third_party/go/github.com_stretchr_testify//require", "///third_party/go/gopkg.in_op_go-logging.v1//:go-logging.v1", "//src/cli", "//src/core", diff --git a/src/update/update.go b/src/update/update.go index 81220ec50..fdb653d89 100644 --- a/src/update/update.go +++ b/src/update/update.go @@ -15,6 +15,7 @@ import ( "io" "net/http" "os" + "os/exec" "path/filepath" "runtime" "strconv" @@ -62,8 +63,9 @@ func CheckAndUpdate(config *core.Configuration, updatesEnabled, updateCommand, f clean(config, updateCommand) return } + newPlease := filepath.Join(config.Please.Location, config.Please.Version.VersionString(), "please") word := describe(config.Please.Version.Semver(), pleaseVersion(), true) - if !updateCommand { + if !updateCommand && !core.PathExists(newPlease) { fmt.Fprintf(os.Stderr, "%s Please from version %s to %s\n", word, pleaseVersion(), config.Please.Version.VersionString()) } @@ -83,7 +85,7 @@ func CheckAndUpdate(config *core.Configuration, updatesEnabled, updateCommand, f } // Download it. - newPlease := downloadAndLinkPlease(config, verify, progress) + newPlease = downloadAndLinkPlease(config, verify, progress) // Print update milestone message if we hit a milestone printMilestoneMessage(config.Please.Version.VersionString()) @@ -91,6 +93,9 @@ func CheckAndUpdate(config *core.Configuration, updatesEnabled, updateCommand, f // Clean out any old ones clean(config, updateCommand) + // Warn if the binary in PATH isn't the one we just updated. + warnIfNotInPath(config) + // Now run the new one. core.ReturnToInitialWorkingDir() args := filterArgs(forceUpdate, append([]string{newPlease}, os.Args[1:]...)) @@ -406,6 +411,36 @@ func writeTarFile(hdr *tar.Header, r io.Reader, destination string) error { return err } +// warnIfNotInPath emits a warning if the binary found via PATH is not the one we just updated. +// This catches the common case where e.g. a Homebrew-installed plz shadows the one in ~/.please. +func warnIfNotInPath(config *core.Configuration) { + name := filepath.Base(os.Args[0]) + pathBinary, err := exec.LookPath(name) + if err != nil { + return + } + pathBinary, _ = filepath.Abs(pathBinary) + if resolved, err := filepath.EvalSymlinks(pathBinary); err == nil { + pathBinary = resolved + } + location, _ := filepath.Abs(config.Please.Location) + if !strings.HasPrefix(pathBinary, location+string(filepath.Separator)) && pathBinary != location { + // Ensure a symlink exists so the binary name the user invoked (e.g. "plz") + // is available in the Please location directory. + nameLink := filepath.Join(location, name) + if _, err := os.Lstat(nameLink); os.IsNotExist(err) { + if err := os.Symlink("please", nameLink); err != nil { + log.Warning("Failed to create %s symlink: %s", nameLink, err) + } + } + log.Warning("Updated Please to %s in %s but %q in your PATH resolves to %s", + config.Please.Version.VersionString(), location, name, pathBinary) + log.Warning("To use the updated version, add %s to your PATH ahead of %s:", + location, filepath.Dir(pathBinary)) + log.Warning(" export PATH=%s:$PATH", location) + } +} + // filterArgs filters out the --force update if forced updates were specified. // This is important so that we don't end up in a loop of repeatedly forcing re-downloads. func filterArgs(forceUpdate bool, args []string) []string { diff --git a/src/update/update_test.go b/src/update/update_test.go index 2f79eeca9..d3b776bd5 100644 --- a/src/update/update_test.go +++ b/src/update/update_test.go @@ -11,6 +11,7 @@ import ( "github.com/hashicorp/go-retryablehttp" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "gopkg.in/op/go-logging.v1" "github.com/thought-machine/please/src/cli" @@ -151,6 +152,43 @@ func TestFilterArgs(t *testing.T) { assert.Equal(t, []string{"plz", "update"}, filterArgs(true, []string{"plz", "update", "--force"})) } +func TestWarnIfNotInPath(t *testing.T) { + // Create a fake binary in a temp directory and put it on PATH. + tmpDir := t.TempDir() + fakeBin := filepath.Join(tmpDir, "fake_plz") + assert.NoError(t, os.WriteFile(fakeBin, []byte("#!/bin/sh\n"), 0o755)) + + oldPath := os.Getenv("PATH") + t.Setenv("PATH", tmpDir+string(os.PathListSeparator)+oldPath) + oldArgs := os.Args + defer func() { os.Args = oldArgs }() + os.Args = []string{"fake_plz"} + + // When location matches the directory containing the binary, no warning expected. + c := makeConfig("warnpath") + c.Please.Location = tmpDir + warnIfNotInPath(c) // should not warn + + // When location doesn't match, it should warn and create a symlink for the binary name. + symlinkDir := t.TempDir() + // Write a "please" binary so the symlink target exists. + assert.NoError(t, os.WriteFile(filepath.Join(symlinkDir, "please"), []byte("#!/bin/sh\n"), 0o755)) + c.Please.Location = symlinkDir + warnIfNotInPath(c) // should warn but not panic + // Should have created a relative symlink: fake_plz -> please + link := filepath.Join(symlinkDir, "fake_plz") + target, err := os.Readlink(link) + require.NoError(t, err) + require.Equal(t, "please", target) + + // Calling again should not fail (symlink already exists). + warnIfNotInPath(c) + + // When the binary isn't in PATH at all, should silently return. + os.Args = []string{"not_in_path_at_all"} + warnIfNotInPath(c) // should not panic +} + func TestFullDistVersion(t *testing.T) { var v cli.Version v.UnmarshalFlag("13.1.9")