diff --git a/cmd/reviewGOOSE/cache.go b/cmd/reviewGOOSE/cache.go index c627a7e..55f2fd2 100644 --- a/cmd/reviewGOOSE/cache.go +++ b/cmd/reviewGOOSE/cache.go @@ -2,101 +2,79 @@ package main import ( "context" - "crypto/sha256" - "encoding/hex" "encoding/json" "fmt" "log/slog" - "os" - "path/filepath" - "strings" "time" + "github.com/codeGROOVE-dev/goose/pkg/prcache" "github.com/codeGROOVE-dev/goose/pkg/safebrowse" "github.com/codeGROOVE-dev/retry" "github.com/codeGROOVE-dev/turnclient/pkg/turn" ) -type cacheEntry struct { - Data *turn.CheckResponse `json:"data"` - CachedAt time.Time `json:"cached_at"` - UpdatedAt time.Time `json:"updated_at"` -} - // checkCache checks the cache for a PR and returns the cached data if valid. // Returns (data, hit, running) where running indicates incomplete tests. -func (app *App) checkCache(path, url string, updatedAt time.Time) (data *turn.CheckResponse, hit, running bool) { - b, err := os.ReadFile(path) - if err != nil { - if !os.IsNotExist(err) { - slog.Debug("[CACHE] Cache file read error", "url", url, "error", err) +func (app *App) checkCache(cacheManager *prcache.Manager, path, url string, updatedAt time.Time) (data *turn.CheckResponse, hit, running bool) { + // State check function for incomplete tests + stateCheck := func(d any) bool { + if m, ok := d.(map[string]any); ok { + if pr, ok := m["pull_request"].(map[string]any); ok { + if state, ok := pr["test_state"].(string); ok { + incomplete := state == "running" || state == "queued" || state == "pending" + // Only bypass for recently updated PRs + if incomplete && time.Since(updatedAt) < time.Hour { + return true + } + } + } } - return nil, false, false + return false } - var e cacheEntry - if err := json.Unmarshal(b, &e); err != nil { - slog.Warn("Failed to unmarshal cache data", "url", url, "error", err) - if err := os.Remove(path); err != nil { - slog.Debug("Failed to remove corrupted cache file", "error", err) - } + // Determine TTL based on whether we expect tests to be 
running + ttl := cacheTTL + bypassTTL := runningTestsCacheBypass + + result, err := cacheManager.Get(path, updatedAt, ttl, bypassTTL, stateCheck) + if err != nil { + slog.Debug("[CACHE] Cache error", "url", url, "error", err) return nil, false, false } - if e.Data == nil { - slog.Warn("Cache entry missing data", "url", url) - if err := os.Remove(path); err != nil { - slog.Debug("Failed to remove corrupted cache file", "error", err) - } - return nil, false, false + + if !result.Hit { + return nil, false, result.ShouldBypass } - // Determine TTL based on test state - use shorter TTL for incomplete tests - state := e.Data.PullRequest.TestState - incomplete := state == "running" || state == "queued" || state == "pending" - ttl := cacheTTL - if incomplete { - ttl = runningTestsCacheTTL + // Extract turn.CheckResponse from cached data + if result.Entry == nil || result.Entry.Data == nil { + return nil, false, false } - // Check if cache is expired or PR updated - if time.Since(e.CachedAt) >= ttl || !e.UpdatedAt.Equal(updatedAt) { - if !e.UpdatedAt.Equal(updatedAt) { - slog.Debug("[CACHE] Cache miss - PR updated", - "url", url, - "cached_pr_time", e.UpdatedAt.Format(time.RFC3339), - "current_pr_time", updatedAt.Format(time.RFC3339)) - } else { - slog.Debug("[CACHE] Cache miss - TTL expired", - "url", url, - "cached_at", e.CachedAt.Format(time.RFC3339), - "cache_age", time.Since(e.CachedAt).Round(time.Second), - "ttl", ttl, - "test_state", state) - } - return nil, false, incomplete + // Convert map back to CheckResponse + dataBytes, err := json.Marshal(result.Entry.Data) + if err != nil { + slog.Warn("Failed to marshal cached data", "url", url, "error", err) + return nil, false, false } - // Invalidate cache for incomplete tests on recently-updated PRs to catch completion - // Skip this for PRs not updated in over an hour - their pending tests are likely stale - age := time.Since(e.CachedAt) - if incomplete && age < runningTestsCacheBypass && time.Since(updatedAt) < 
time.Hour { - slog.Debug("[CACHE] Cache invalidated - tests incomplete and cache entry is fresh", - "url", url, - "test_state", state, - "cache_age", age.Round(time.Minute), - "cached_at", e.CachedAt.Format(time.RFC3339)) - return nil, false, true + var response turn.CheckResponse + if err := json.Unmarshal(dataBytes, &response); err != nil { + slog.Warn("Failed to unmarshal cached data", "url", url, "error", err) + return nil, false, false } slog.Debug("[CACHE] Cache hit", "url", url, - "cached_at", e.CachedAt.Format(time.RFC3339), - "cache_age", time.Since(e.CachedAt).Round(time.Second), - "pr_updated_at", e.UpdatedAt.Format(time.RFC3339)) + "cached_at", result.Entry.CachedAt.Format(time.RFC3339), + "cache_age", time.Since(result.Entry.CachedAt).Round(time.Second), + "pr_updated_at", result.Entry.UpdatedAt.Format(time.RFC3339)) + if app.healthMonitor != nil { app.healthMonitor.recordCacheAccess(true) } - return e.Data, true, false + + return &response, true, false } // turnData fetches Turn API data with caching. 
@@ -110,21 +88,20 @@ func (app *App) turnData(ctx context.Context, url string, updatedAt time.Time) ( return nil, false, fmt.Errorf("invalid URL: %w", err) } - // Create cache key from URL and updated timestamp - key := fmt.Sprintf("%s-%s", url, updatedAt.Format(time.RFC3339)) - h := sha256.Sum256([]byte(key)) - path := filepath.Join(app.cacheDir, hex.EncodeToString(h[:])[:16]+".json") + // Create cache manager and path + cacheManager := prcache.NewManager(app.cacheDir) + cacheKey := prcache.CacheKey(url, updatedAt) + path := cacheManager.CachePath(cacheKey) slog.Debug("[CACHE] Checking cache", "url", url, "updated_at", updatedAt.Format(time.RFC3339), - "cache_key", key, - "cache_file", filepath.Base(path)) + "cache_key", cacheKey) // Check cache unless --no-cache flag is set var running bool if !app.noCache { - data, hit, r := app.checkCache(path, url, updatedAt) + data, hit, r := app.checkCache(cacheManager, path, url, updatedAt) if hit { return data, true, nil } @@ -203,18 +180,11 @@ func (app *App) turnData(ctx context.Context, url string, updatedAt time.Time) ( // Save to cache (don't fail if caching fails) if !app.noCache && data != nil { - e := cacheEntry{Data: data, CachedAt: time.Now(), UpdatedAt: updatedAt} - b, err := json.Marshal(e) - if err != nil { - slog.Error("Failed to marshal cache data", "url", url, "error", err) - } else if err := os.MkdirAll(filepath.Dir(path), 0o700); err != nil { - slog.Error("Failed to create cache directory", "error", err) - } else if err := os.WriteFile(path, b, 0o600); err != nil { - slog.Error("Failed to write cache file", "error", err) + if err := cacheManager.Put(path, data, updatedAt); err != nil { + slog.Error("Failed to save cache", "url", url, "error", err) } else { slog.Debug("[CACHE] Saved to cache", "url", url, - "cache_file", filepath.Base(path), "test_state", data.PullRequest.TestState) } } @@ -224,33 +194,8 @@ func (app *App) turnData(ctx context.Context, url string, updatedAt time.Time) ( // cleanupOldCache 
removes cache files older than the cleanup interval (15 days). func (app *App) cleanupOldCache() { - entries, err := os.ReadDir(app.cacheDir) - if err != nil { - slog.Error("Failed to read cache directory for cleanup", "error", err) - return - } - - var cleaned, errs int - for _, e := range entries { - if !strings.HasSuffix(e.Name(), ".json") { - continue - } - info, err := e.Info() - if err != nil { - slog.Error("Failed to get file info for cache entry", "entry", e.Name(), "error", err) - errs++ - continue - } - if time.Since(info.ModTime()) > cacheCleanupInterval { - p := filepath.Join(app.cacheDir, e.Name()) - if err := os.Remove(p); err != nil { - slog.Error("Failed to remove old cache file", "file", p, "error", err) - errs++ - } else { - cleaned++ - } - } - } + cacheManager := prcache.NewManager(app.cacheDir) + cleaned, errs := cacheManager.CleanupOldFiles(cacheCleanupInterval) if cleaned > 0 || errs > 0 { slog.Info("Cache cleanup completed", "removed", cleaned, "errors", errs) diff --git a/cmd/reviewGOOSE/deadlock_test.go b/cmd/reviewGOOSE/deadlock_test.go index 7bd1fd8..c4accdd 100644 --- a/cmd/reviewGOOSE/deadlock_test.go +++ b/cmd/reviewGOOSE/deadlock_test.go @@ -4,6 +4,8 @@ import ( "sync" "testing" "time" + + "github.com/codeGROOVE-dev/goose/pkg/ratelimit" ) // TestConcurrentMenuOperations tests that concurrent menu operations don't cause deadlocks @@ -14,7 +16,7 @@ func TestConcurrentMenuOperations(t *testing.T) { hiddenOrgs: make(map[string]bool), seenOrgs: make(map[string]bool), blockedPRTimes: make(map[string]time.Time), - browserRateLimiter: NewBrowserRateLimiter(startupGracePeriod, 5, defaultMaxBrowserOpensDay), + browserRateLimiter: ratelimit.NewBrowserRateLimiter(startupGracePeriod, 5, defaultMaxBrowserOpensDay), systrayInterface: &MockSystray{}, incoming: []PR{ {Repository: "org1/repo1", Number: 1, Title: "Fix bug", URL: "https://github.com/org1/repo1/pull/1"}, @@ -101,7 +103,7 @@ func TestMenuClickDeadlockScenario(t *testing.T) { hiddenOrgs: 
make(map[string]bool), seenOrgs: make(map[string]bool), blockedPRTimes: make(map[string]time.Time), - browserRateLimiter: NewBrowserRateLimiter(startupGracePeriod, 5, defaultMaxBrowserOpensDay), + browserRateLimiter: ratelimit.NewBrowserRateLimiter(startupGracePeriod, 5, defaultMaxBrowserOpensDay), systrayInterface: &MockSystray{}, incoming: []PR{ {Repository: "org1/repo1", Number: 1, Title: "Test PR", URL: "https://github.com/org1/repo1/pull/1"}, @@ -142,7 +144,7 @@ func TestRapidMenuClicks(t *testing.T) { hiddenOrgs: make(map[string]bool), seenOrgs: make(map[string]bool), blockedPRTimes: make(map[string]time.Time), - browserRateLimiter: NewBrowserRateLimiter(startupGracePeriod, 5, defaultMaxBrowserOpensDay), + browserRateLimiter: ratelimit.NewBrowserRateLimiter(startupGracePeriod, 5, defaultMaxBrowserOpensDay), systrayInterface: &MockSystray{}, lastSearchAttempt: time.Now().Add(-15 * time.Second), // Allow first click incoming: []PR{ diff --git a/cmd/reviewGOOSE/main.go b/cmd/reviewGOOSE/main.go index 6b1f93d..a65ceaf 100644 --- a/cmd/reviewGOOSE/main.go +++ b/cmd/reviewGOOSE/main.go @@ -7,6 +7,7 @@ package main import ( "context" _ "embed" + "errors" "flag" "fmt" "log/slog" @@ -19,6 +20,8 @@ import ( "time" "github.com/codeGROOVE-dev/goose/cmd/reviewGOOSE/x11tray" + "github.com/codeGROOVE-dev/goose/pkg/logging" + "github.com/codeGROOVE-dev/goose/pkg/ratelimit" "github.com/codeGROOVE-dev/retry" "github.com/codeGROOVE-dev/turnclient/pkg/turn" "github.com/energye/systray" @@ -38,8 +41,8 @@ var ( date = "unknown" ) -// getVersion returns the version string, preferring ldflags but falling back to VERSION file. -func getVersion() string { +// appVersion returns the version string, preferring ldflags but falling back to VERSION file. 
+func appVersion() string { // If version was set via ldflags and isn't the default, use it if version != "" && version != "dev" { return version @@ -51,6 +54,46 @@ func getVersion() string { return "dev" } +// logDir returns the platform-appropriate directory for application logs. +// - macOS: ~/Library/Logs/reviewGOOSE. +// - Linux: ~/.local/state/reviewGOOSE (or $XDG_STATE_HOME/reviewGOOSE if set). +// - Windows: %LOCALAPPDATA%\reviewGOOSE\Logs. +func logDir() (string, error) { + var dir string + + switch runtime.GOOS { + case "darwin": + // macOS: use ~/Library/Logs + home, err := os.UserHomeDir() + if err != nil { + return "", err + } + dir = filepath.Join(home, "Library", "Logs", "reviewGOOSE") + + case "windows": + // Windows: use %LOCALAPPDATA%\reviewGOOSE\Logs + localAppData := os.Getenv("LOCALAPPDATA") + if localAppData == "" { + return "", errors.New("LOCALAPPDATA environment variable not set") + } + dir = filepath.Join(localAppData, "reviewGOOSE", "Logs") + + default: + // Linux and other Unix: use XDG_STATE_HOME or ~/.local/state + stateHome := os.Getenv("XDG_STATE_HOME") + if stateHome == "" { + home, err := os.UserHomeDir() + if err != nil { + return "", err + } + stateHome = filepath.Join(home, ".local", "state") + } + dir = filepath.Join(stateHome, "reviewGOOSE") + } + + return dir, nil +} + const ( cacheTTL = 10 * 24 * time.Hour // 10 days - rely mostly on PR UpdatedAt runningTestsCacheTTL = 2 * time.Minute // Short TTL for PRs with incomplete tests to catch completions quickly @@ -112,7 +155,7 @@ type App struct { lastSuccessfulFetch time.Time startTime time.Time systrayInterface SystrayInterface - browserRateLimiter *BrowserRateLimiter + browserRateLimiter *ratelimit.BrowserRateLimiter blockedPRTimes map[string]time.Time currentUser *github.User stateManager *PRStateManager @@ -168,7 +211,7 @@ func main() { // Handle version flag if showVersion { - fmt.Printf("goose version %s\ncommit: %s\nbuilt: %s\n", getVersion(), commit, date) + 
fmt.Printf("goose version %s\ncommit: %s\nbuilt: %s\n", appVersion(), commit, date) os.Exit(0) } @@ -207,7 +250,7 @@ func main() { } opts := &slog.HandlerOptions{AddSource: true, Level: logLevel, ReplaceAttr: simplifySource} slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, opts))) - slog.Info("Starting Goose", "version", getVersion(), "commit", commit, "date", date) + slog.Info("Starting Goose", "version", appVersion(), "commit", commit, "date", date) slog.Info("Configuration", "update_interval", updateInterval, "max_retries", maxRetries, "max_delay", maxRetryDelay) slog.Info("Browser auto-open configuration", "startup_delay", browserOpenDelay, @@ -228,25 +271,26 @@ func main() { os.Exit(1) } - // Set up file-based logging alongside cache - logDir := filepath.Join(cacheDir, "logs") - if err := os.MkdirAll(logDir, dirPerm); err != nil { + // Set up file-based logging in platform-appropriate location + logDirectory, err := logDir() + if err != nil { + slog.Error("Failed to determine log directory", "error", err) + // Continue without file logging + } else if err := os.MkdirAll(logDirectory, dirPerm); err != nil { slog.Error("Failed to create log directory", "error", err) // Continue without file logging } else { // Create log file with daily rotation - logPath := filepath.Join(logDir, fmt.Sprintf("goose-%s.log", time.Now().Format("2006-01-02"))) + logPath := filepath.Join(logDirectory, fmt.Sprintf("goose-%s.log", time.Now().Format("2006-01-02"))) logFile, err := os.OpenFile(logPath, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o600) if err != nil { slog.Error("Failed to open log file", "error", err) } else { // Update logger to write to both stderr and file - multiHandler := &MultiHandler{ - handlers: []slog.Handler{ - slog.NewTextHandler(os.Stderr, opts), - slog.NewTextHandler(logFile, opts), - }, - } + multiHandler := logging.NewMultiHandler( + slog.NewTextHandler(os.Stderr, opts), + slog.NewTextHandler(logFile, opts), + ) slog.SetDefault(slog.New(multiHandler)) 
slog.Info("Logs are being written to", "path", logPath) } @@ -262,7 +306,7 @@ func main() { updateInterval: updateInterval, enableAudioCues: true, enableAutoBrowser: false, // Default to false for safety - browserRateLimiter: NewBrowserRateLimiter(browserOpenDelay, maxBrowserOpensMinute, maxBrowserOpensDay), + browserRateLimiter: ratelimit.NewBrowserRateLimiter(browserOpenDelay, maxBrowserOpensMinute, maxBrowserOpensDay), startTime: startTime, systrayInterface: &RealSystray{}, // Use real systray implementation seenOrgs: make(map[string]bool), @@ -959,11 +1003,13 @@ func (app *App) tryAutoOpenPR(ctx context.Context, pr *PR, autoBrowserEnabled bo return } - // Skip draft PRs authored by the user we're querying for + // Determine queried user for draft check queriedUser := app.targetUser if queriedUser == "" && app.currentUser != nil { queriedUser = app.currentUser.GetLogin() } + + // Skip draft PRs authored by the user we're querying for if pr.IsDraft && pr.Author == queriedUser { slog.Debug("[BROWSER] Skipping auto-open for draft PR by queried user", "repo", pr.Repository, "number", pr.Number, "author", pr.Author) @@ -971,7 +1017,6 @@ func (app *App) tryAutoOpenPR(ctx context.Context, pr *PR, autoBrowserEnabled bo } // Only auto-open if the PR is actually blocked or needs review - // This ensures we have a valid NextAction before opening if !pr.IsBlocked && !pr.NeedsReview { slog.Debug("[BROWSER] Skipping auto-open for non-blocked PR", "repo", pr.Repository, "number", pr.Number, diff --git a/cmd/reviewGOOSE/main_test.go b/cmd/reviewGOOSE/main_test.go index 745a5e8..bbd10ac 100644 --- a/cmd/reviewGOOSE/main_test.go +++ b/cmd/reviewGOOSE/main_test.go @@ -14,6 +14,7 @@ import ( "testing" "time" + "github.com/codeGROOVE-dev/goose/pkg/ratelimit" "github.com/codeGROOVE-dev/turnclient/pkg/turn" "github.com/google/go-github/v57/github" ) @@ -79,7 +80,7 @@ func TestMenuItemTitleTransition(t *testing.T) { hiddenOrgs: make(map[string]bool), seenOrgs: make(map[string]bool), 
blockedPRTimes: make(map[string]time.Time), - browserRateLimiter: NewBrowserRateLimiter(startupGracePeriod, 5, defaultMaxBrowserOpensDay), + browserRateLimiter: ratelimit.NewBrowserRateLimiter(startupGracePeriod, 5, defaultMaxBrowserOpensDay), systrayInterface: &MockSystray{}, // Use mock systray to avoid panics } @@ -328,7 +329,7 @@ func TestTrayIconRestoredAfterNetworkRecovery(t *testing.T) { blockedPRTimes: make(map[string]time.Time), hiddenOrgs: make(map[string]bool), seenOrgs: make(map[string]bool), - browserRateLimiter: NewBrowserRateLimiter(startupGracePeriod, 5, defaultMaxBrowserOpensDay), + browserRateLimiter: ratelimit.NewBrowserRateLimiter(startupGracePeriod, 5, defaultMaxBrowserOpensDay), systrayInterface: mock, menuInitialized: true, } @@ -376,7 +377,7 @@ func TestTrayTitleUpdates(t *testing.T) { blockedPRTimes: make(map[string]time.Time), hiddenOrgs: make(map[string]bool), seenOrgs: make(map[string]bool), - browserRateLimiter: NewBrowserRateLimiter(startupGracePeriod, 5, defaultMaxBrowserOpensDay), + browserRateLimiter: ratelimit.NewBrowserRateLimiter(startupGracePeriod, 5, defaultMaxBrowserOpensDay), systrayInterface: &MockSystray{}, // Use mock systray to avoid panics } @@ -500,7 +501,7 @@ func TestSoundPlaybackDuringTransitions(t *testing.T) { hiddenOrgs: make(map[string]bool), seenOrgs: make(map[string]bool), previousBlockedPRs: make(map[string]bool), - browserRateLimiter: NewBrowserRateLimiter(startupGracePeriod, 5, defaultMaxBrowserOpensDay), + browserRateLimiter: ratelimit.NewBrowserRateLimiter(startupGracePeriod, 5, defaultMaxBrowserOpensDay), enableAudioCues: true, initialLoadComplete: true, // Set to true to allow sound playback menuInitialized: true, @@ -651,7 +652,7 @@ func TestSoundDisabledNoPlayback(t *testing.T) { hiddenOrgs: make(map[string]bool), seenOrgs: make(map[string]bool), previousBlockedPRs: make(map[string]bool), - browserRateLimiter: NewBrowserRateLimiter(startupGracePeriod, 5, defaultMaxBrowserOpensDay), + 
browserRateLimiter: ratelimit.NewBrowserRateLimiter(startupGracePeriod, 5, defaultMaxBrowserOpensDay), enableAudioCues: false, // Audio disabled initialLoadComplete: true, menuInitialized: true, @@ -691,7 +692,7 @@ func TestGracePeriodPreventsNotifications(t *testing.T) { hiddenOrgs: make(map[string]bool), seenOrgs: make(map[string]bool), previousBlockedPRs: make(map[string]bool), - browserRateLimiter: NewBrowserRateLimiter(startupGracePeriod, 5, defaultMaxBrowserOpensDay), + browserRateLimiter: ratelimit.NewBrowserRateLimiter(startupGracePeriod, 5, defaultMaxBrowserOpensDay), enableAudioCues: true, initialLoadComplete: true, menuInitialized: true, @@ -816,7 +817,7 @@ func TestNotificationScenarios(t *testing.T) { hiddenOrgs: make(map[string]bool), seenOrgs: make(map[string]bool), previousBlockedPRs: make(map[string]bool), - browserRateLimiter: NewBrowserRateLimiter(startupGracePeriod, 5, defaultMaxBrowserOpensDay), + browserRateLimiter: ratelimit.NewBrowserRateLimiter(startupGracePeriod, 5, defaultMaxBrowserOpensDay), enableAudioCues: true, initialLoadComplete: tt.initialLoadComplete, menuInitialized: true, @@ -876,7 +877,7 @@ func TestNewlyBlockedPRAfterGracePeriod(t *testing.T) { hiddenOrgs: make(map[string]bool), seenOrgs: make(map[string]bool), previousBlockedPRs: make(map[string]bool), - browserRateLimiter: NewBrowserRateLimiter(startupGracePeriod, 5, defaultMaxBrowserOpensDay), + browserRateLimiter: ratelimit.NewBrowserRateLimiter(startupGracePeriod, 5, defaultMaxBrowserOpensDay), enableAudioCues: true, initialLoadComplete: true, // Already past initial load menuInitialized: true, diff --git a/cmd/reviewGOOSE/menu_change_detection_test.go b/cmd/reviewGOOSE/menu_change_detection_test.go index 1ec52f5..6ee7438 100644 --- a/cmd/reviewGOOSE/menu_change_detection_test.go +++ b/cmd/reviewGOOSE/menu_change_detection_test.go @@ -5,6 +5,8 @@ import ( "sync" "testing" "time" + + "github.com/codeGROOVE-dev/goose/pkg/ratelimit" ) // TestMenuChangeDetection tests that 
the menu change detection logic works correctly @@ -17,7 +19,7 @@ func TestMenuChangeDetection(t *testing.T) { hiddenOrgs: make(map[string]bool), seenOrgs: make(map[string]bool), blockedPRTimes: make(map[string]time.Time), - browserRateLimiter: NewBrowserRateLimiter(startupGracePeriod, 5, defaultMaxBrowserOpensDay), + browserRateLimiter: ratelimit.NewBrowserRateLimiter(startupGracePeriod, 5, defaultMaxBrowserOpensDay), systrayInterface: &MockSystray{}, incoming: []PR{ {Repository: "org1/repo1", Number: 1, Title: "Fix bug", URL: "https://github.com/org1/repo1/pull/1", NeedsReview: true, UpdatedAt: time.Now()}, @@ -110,7 +112,7 @@ func TestFirstRunMenuRebuildBug(t *testing.T) { hiddenOrgs: make(map[string]bool), seenOrgs: make(map[string]bool), blockedPRTimes: make(map[string]time.Time), - browserRateLimiter: NewBrowserRateLimiter(startupGracePeriod, 5, defaultMaxBrowserOpensDay), + browserRateLimiter: ratelimit.NewBrowserRateLimiter(startupGracePeriod, 5, defaultMaxBrowserOpensDay), menuInitialized: false, systrayInterface: &MockSystray{}, lastMenuTitles: nil, // This is nil on first run - the bug condition @@ -169,7 +171,7 @@ func TestHiddenOrgChangesMenu(t *testing.T) { hiddenOrgs: make(map[string]bool), seenOrgs: make(map[string]bool), blockedPRTimes: make(map[string]time.Time), - browserRateLimiter: NewBrowserRateLimiter(startupGracePeriod, 5, defaultMaxBrowserOpensDay), + browserRateLimiter: ratelimit.NewBrowserRateLimiter(startupGracePeriod, 5, defaultMaxBrowserOpensDay), systrayInterface: &MockSystray{}, incoming: []PR{ {Repository: "org1/repo1", Number: 1, Title: "PR 1", URL: "https://github.com/org1/repo1/pull/1"}, diff --git a/cmd/reviewGOOSE/notifications.go b/cmd/reviewGOOSE/notifications.go index 490f561..b1760cb 100644 --- a/cmd/reviewGOOSE/notifications.go +++ b/cmd/reviewGOOSE/notifications.go @@ -38,8 +38,8 @@ func (app *App) processNotifications(ctx context.Context) { // Update deprecated fields for test compatibility app.mu.Lock() - 
app.previousBlockedPRs = make(map[string]bool) - app.blockedPRTimes = make(map[string]time.Time) + clear(app.previousBlockedPRs) + clear(app.blockedPRTimes) states := app.stateManager.BlockedPRs() for url, state := range states { app.previousBlockedPRs[url] = true diff --git a/cmd/reviewGOOSE/reliability.go b/cmd/reviewGOOSE/reliability.go index b7082e1..4cffbd1 100644 --- a/cmd/reviewGOOSE/reliability.go +++ b/cmd/reviewGOOSE/reliability.go @@ -178,7 +178,11 @@ func (hm *healthMonitor) logMetrics() { sprinklerConnected := false sprinklerLastConnected := "" if hm.app.sprinklerMonitor != nil { - connected, lastConnectedAt := hm.app.sprinklerMonitor.connectionStatus() + hm.app.sprinklerMonitor.mu.RLock() + connected := hm.app.sprinklerMonitor.isConnected + lastConnectedAt := hm.app.sprinklerMonitor.lastConnectedAt + hm.app.sprinklerMonitor.mu.RUnlock() + sprinklerConnected = connected if !lastConnectedAt.IsZero() { sprinklerLastConnected = time.Since(lastConnectedAt).Round(time.Second).String() + " ago" diff --git a/cmd/reviewGOOSE/settings.go b/cmd/reviewGOOSE/settings.go index 7077457..f940c18 100644 --- a/cmd/reviewGOOSE/settings.go +++ b/cmd/reviewGOOSE/settings.go @@ -1,10 +1,9 @@ package main import ( - "encoding/json" "log/slog" - "os" - "path/filepath" + + "github.com/codeGROOVE-dev/goose/pkg/appsettings" ) // Settings represents persistent user settings. 
@@ -23,24 +22,17 @@ func (app *App) loadSettings() { app.enableAutoBrowser = true app.hiddenOrgs = make(map[string]bool) - configDir, err := os.UserConfigDir() - if err != nil { - slog.Error("Failed to get settings directory", "error", err) - return - } + manager := appsettings.NewManager("reviewGOOSE") - settingsPath := filepath.Join(configDir, "reviewGOOSE", "settings.json") - data, err := os.ReadFile(settingsPath) + var settings Settings + found, err := manager.Load(&settings) if err != nil { - if !os.IsNotExist(err) { - slog.Debug("Failed to read settings", "error", err) - } + slog.Error("Failed to load settings", "error", err) return } - var settings Settings - if err := json.Unmarshal(data, &settings); err != nil { - slog.Error("Failed to parse settings", "error", err) + if !found { + slog.Debug("No settings file found, using defaults") return } @@ -61,13 +53,6 @@ func (app *App) loadSettings() { // saveSettings saves current settings to disk. func (app *App) saveSettings() { - configDir, err := os.UserConfigDir() - if err != nil { - slog.Error("Failed to get settings directory", "error", err) - return - } - settingsDir := filepath.Join(configDir, "reviewGOOSE") - app.mu.RLock() settings := Settings{ EnableAudioCues: app.enableAudioCues, @@ -77,21 +62,8 @@ func (app *App) saveSettings() { } app.mu.RUnlock() - // Ensure directory exists - if err := os.MkdirAll(settingsDir, 0o700); err != nil { - slog.Error("Failed to create settings directory", "error", err) - return - } - - settingsPath := filepath.Join(settingsDir, "settings.json") - - data, err := json.MarshalIndent(settings, "", " ") - if err != nil { - slog.Error("Failed to marshal settings", "error", err) - return - } - - if err := os.WriteFile(settingsPath, data, 0o600); err != nil { + manager := appsettings.NewManager("reviewGOOSE") + if err := manager.Save(&settings); err != nil { slog.Error("Failed to save settings", "error", err) return } diff --git a/cmd/reviewGOOSE/sprinkler.go 
b/cmd/reviewGOOSE/sprinkler.go index 717b8d7..c2d0867 100644 --- a/cmd/reviewGOOSE/sprinkler.go +++ b/cmd/reviewGOOSE/sprinkler.go @@ -10,6 +10,7 @@ import ( "sync" "time" + "github.com/codeGROOVE-dev/goose/pkg/dedup" "github.com/codeGROOVE-dev/retry" "github.com/codeGROOVE-dev/sprinkler/pkg/client" "github.com/codeGROOVE-dev/turnclient/pkg/turn" @@ -38,7 +39,7 @@ type sprinklerMonitor struct { client *client.Client cancel context.CancelFunc eventChan chan prEvent - lastEventMap map[string]time.Time + dedup *dedup.Manager token string serverAddress string // Custom server hostname (empty = use default) orgs []string @@ -56,7 +57,7 @@ func newSprinklerMonitor(app *App, token, sprinklerServer string) *sprinklerMoni serverAddress: sprinklerServer, orgs: make([]string, 0), eventChan: make(chan prEvent, eventChannelSize), - lastEventMap: make(map[string]time.Time), + dedup: dedup.New(eventDedupWindow, eventMapCleanupAge, eventMapMaxSize), } } @@ -225,31 +226,10 @@ func (sm *sprinklerMonitor) handleEvent(event client.Event) { } // Dedupe events - only process if we haven't seen this URL recently - sm.mu.Lock() - lastSeen, exists := sm.lastEventMap[event.URL] - now := time.Now() - if exists && now.Sub(lastSeen) < eventDedupWindow { - sm.mu.Unlock() - slog.Debug("[SPRINKLER] Skipping duplicate event", - "url", event.URL, - "last_seen", now.Sub(lastSeen).Round(time.Millisecond)) + if !sm.dedup.ShouldProcess(event.URL, time.Now()) { + slog.Debug("[SPRINKLER] Skipping duplicate event", "url", event.URL) return } - sm.lastEventMap[event.URL] = now - - // Clean up old entries to prevent memory leak - if len(sm.lastEventMap) > eventMapMaxSize { - // Remove entries older than the cleanup age threshold - cutoff := now.Add(-eventMapCleanupAge) - for url, timestamp := range sm.lastEventMap { - if timestamp.Before(cutoff) { - delete(sm.lastEventMap, url) - } - } - slog.Debug("[SPRINKLER] Cleaned up event map", - "entries_remaining", len(sm.lastEventMap)) - } - sm.mu.Unlock() 
slog.Info("[SPRINKLER] PR event received", "url", event.URL, @@ -304,9 +284,17 @@ func (sm *sprinklerMonitor) checkAndNotify(ctx context.Context, evt prEvent) { return } - repo, n := parseRepoAndNumberFromURL(evt.url) - if repo == "" || n == 0 { - slog.Warn("[SPRINKLER] Failed to parse PR URL", "url", evt.url) + // Parse repo and PR number from URL (https://github.com/org/repo/pull/123) + parts := strings.Split(evt.url, "/") + const minParts = 7 + if len(parts) < minParts || parts[2] != "github.com" { + slog.Warn("[SPRINKLER] Invalid PR URL format", "url", evt.url) + return + } + repo := fmt.Sprintf("%s/%s", parts[3], parts[4]) + var n int + if _, err := fmt.Sscanf(parts[6], "%d", &n); err != nil { + slog.Warn("[SPRINKLER] Failed to parse PR number from URL", "url", evt.url, "error", err) return } @@ -319,12 +307,33 @@ func (sm *sprinklerMonitor) checkAndNotify(ctx context.Context, evt prEvent) { return } - act := validateUserAction(data, user, repo, n, cached) - if act == nil { + // Check if user needs to take critical action + if data.Analysis.NextAction == nil { + slog.Debug("[SPRINKLER] No turn data available", + "repo", repo, + "number", n, + "cached", cached) + return + } + act, exists := data.Analysis.NextAction[user] + if !exists { + slog.Debug("[SPRINKLER] No action required for user", + "repo", repo, + "number", n, + "user", user, + "state", data.PullRequest.State) + return + } + if !act.Critical { + slog.Debug("[SPRINKLER] Non-critical action, skipping notification", + "repo", repo, + "number", n, + "action", act.Kind, + "critical", act.Critical) return } - if sm.handleNewPR(ctx, evt.url, repo, n, act) { + if sm.handleNewPR(ctx, evt.url, repo, n, &act) { return } @@ -340,7 +349,7 @@ func (sm *sprinklerMonitor) checkAndNotify(ctx context.Context, evt prEvent) { "event_timestamp", evt.timestamp.Format(time.RFC3339), "elapsed", time.Since(start).Round(time.Millisecond)) - sm.sendNotifications(ctx, evt.url, repo, n, act) + sm.sendNotifications(ctx, evt.url, 
repo, n, &act) } // fetchTurnData retrieves PR data from Turn API with retry logic. @@ -415,38 +424,6 @@ func (sm *sprinklerMonitor) handleClosedPR( return false } -// validateUserAction checks if the user needs to take action and returns the action if critical. -func validateUserAction(data *turn.CheckResponse, user, repo string, n int, cached bool) *turn.Action { - if data == nil || data.Analysis.NextAction == nil { - slog.Debug("[SPRINKLER] No turn data available", - "repo", repo, - "number", n, - "cached", cached) - return nil - } - - act, exists := data.Analysis.NextAction[user] - if !exists { - slog.Debug("[SPRINKLER] No action required for user", - "repo", repo, - "number", n, - "user", user, - "state", data.PullRequest.State) - return nil - } - - if !act.Critical { - slog.Debug("[SPRINKLER] Non-critical action, skipping notification", - "repo", repo, - "number", n, - "action", act.Kind, - "critical", act.Critical) - return nil - } - - return &act -} - // handleNewPR triggers a refresh for PRs not in our lists and returns true if handled. func (sm *sprinklerMonitor) handleNewPR(ctx context.Context, url, repo string, n int, act *turn.Action) bool { sm.app.mu.RLock() @@ -598,30 +575,3 @@ func (sm *sprinklerMonitor) stop() { sm.cancel() sm.isRunning = false } - -// connectionStatus returns the current WebSocket connection status. -func (sm *sprinklerMonitor) connectionStatus() (connected bool, lastConnectedAt time.Time) { - sm.mu.RLock() - defer sm.mu.RUnlock() - return sm.isConnected, sm.lastConnectedAt -} - -// parseRepoAndNumberFromURL extracts repo and PR number from URL. 
-func parseRepoAndNumberFromURL(url string) (repo string, number int) { - // URL format: https://github.com/org/repo/pull/123 - const minParts = 7 - parts := strings.Split(url, "/") - if len(parts) < minParts || parts[2] != "github.com" { - return "", 0 - } - - repo = fmt.Sprintf("%s/%s", parts[3], parts[4]) - - var n int - _, err := fmt.Sscanf(parts[6], "%d", &n) - if err != nil { - return "", 0 - } - - return repo, n -} diff --git a/cmd/reviewGOOSE/ui.go b/cmd/reviewGOOSE/ui.go index f7b906d..196e45e 100644 --- a/cmd/reviewGOOSE/ui.go +++ b/cmd/reviewGOOSE/ui.go @@ -740,7 +740,7 @@ func (app *App) rebuildMenu(ctx context.Context) { // Add Web Dashboard link dashboardItem := app.systrayInterface.AddMenuItem("Web Dashboard", "") dashboardItem.Click(func() { - if err := openURL(ctx, "https://reviewGOOSE.dev/", ""); err != nil { + if err := openURL(ctx, "https://my.reviewGOOSE.dev/", ""); err != nil { slog.Error("failed to open dashboard", "error", err) } }) diff --git a/pkg/appsettings/appsettings.go b/pkg/appsettings/appsettings.go new file mode 100644 index 0000000..8f8cde0 --- /dev/null +++ b/pkg/appsettings/appsettings.go @@ -0,0 +1,76 @@ +// Package appsettings provides functionality for loading and saving application settings. +package appsettings + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" +) + +// Manager handles loading and saving settings to disk. +type Manager struct { + appName string +} + +// NewManager creates a new settings manager for the given application name. +func NewManager(appName string) *Manager { + return &Manager{appName: appName} +} + +// Path returns the path to the settings file. +func (m *Manager) Path() (string, error) { + configDir, err := os.UserConfigDir() + if err != nil { + return "", fmt.Errorf("get user config dir: %w", err) + } + return filepath.Join(configDir, m.appName, "settings.json"), nil +} + +// Load loads settings from disk into the provided struct. 
+// Returns false if the file doesn't exist (not an error). +func (m *Manager) Load(settings any) (bool, error) { + path, err := m.Path() + if err != nil { + return false, err + } + + data, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + return false, nil // File doesn't exist, not an error + } + return false, fmt.Errorf("read settings file: %w", err) + } + + if err := json.Unmarshal(data, settings); err != nil { + return false, fmt.Errorf("parse settings: %w", err) + } + + return true, nil +} + +// Save saves settings to disk. +func (m *Manager) Save(settings any) error { + path, err := m.Path() + if err != nil { + return err + } + + // Ensure directory exists + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0o700); err != nil { + return fmt.Errorf("create settings directory: %w", err) + } + + data, err := json.MarshalIndent(settings, "", " ") + if err != nil { + return fmt.Errorf("marshal settings: %w", err) + } + + if err := os.WriteFile(path, data, 0o600); err != nil { + return fmt.Errorf("write settings file: %w", err) + } + + return nil +} diff --git a/pkg/appsettings/appsettings_test.go b/pkg/appsettings/appsettings_test.go new file mode 100644 index 0000000..4dd0751 --- /dev/null +++ b/pkg/appsettings/appsettings_test.go @@ -0,0 +1,291 @@ +package appsettings + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +type testSettings struct { + Name string `json:"name"` + Enabled bool `json:"enabled"` + Count int `json:"count"` + Tags map[string]bool `json:"tags"` +} + +func TestNewManager(t *testing.T) { + m := NewManager("testapp") + if m == nil { + t.Fatal("NewManager returned nil") + } + if m.appName != "testapp" { + t.Errorf("appName = %q, want %q", m.appName, "testapp") + } +} + +func TestPath(t *testing.T) { + m := NewManager("testapp") + path, err := m.Path() + if err != nil { + t.Fatalf("Path() error = %v", err) + } + + if !filepath.IsAbs(path) { + t.Errorf("Path is not absolute: %q", path) + } + + if 
!strings.HasPrefix(path, os.Getenv("HOME")) && !strings.HasPrefix(path, os.Getenv("APPDATA")) { + t.Logf("Warning: Path may not be in user directory: %q", path) + } + + expectedSuffix := filepath.Join("testapp", "settings.json") + if !strings.HasSuffix(path, expectedSuffix) { + t.Errorf("Path should end with %q, got %q", expectedSuffix, path) + } +} + +func TestSaveAndLoad(t *testing.T) { + // Use temporary directory to avoid interfering with real config + tmpDir := t.TempDir() + t.Setenv("XDG_CONFIG_HOME", tmpDir) // Linux + t.Setenv("HOME", tmpDir) // macOS fallback + t.Setenv("APPDATA", tmpDir) // Windows + + m := NewManager("testapp") + + // Create test settings + original := testSettings{ + Name: "test", + Enabled: true, + Count: 42, + Tags: map[string]bool{"go": true, "rust": false}, + } + + // Save settings + err := m.Save(&original) + if err != nil { + t.Fatalf("Save() error = %v", err) + } + + // Verify file exists + path, err := m.Path() + if err != nil { + t.Fatalf("Path() error = %v", err) + } + if _, err := os.Stat(path); os.IsNotExist(err) { + t.Fatal("Settings file was not created") + } + + // Load settings + var loaded testSettings + found, err := m.Load(&loaded) + if err != nil { + t.Fatalf("Load() error = %v", err) + } + if !found { + t.Fatal("Load() returned found=false, expected true") + } + + // Verify loaded matches original + if loaded.Name != original.Name { + t.Errorf("Name = %q, want %q", loaded.Name, original.Name) + } + if loaded.Enabled != original.Enabled { + t.Errorf("Enabled = %v, want %v", loaded.Enabled, original.Enabled) + } + if loaded.Count != original.Count { + t.Errorf("Count = %d, want %d", loaded.Count, original.Count) + } + if len(loaded.Tags) != len(original.Tags) { + t.Errorf("Tags length = %d, want %d", len(loaded.Tags), len(original.Tags)) + } + for k, v := range original.Tags { + if loaded.Tags[k] != v { + t.Errorf("Tags[%q] = %v, want %v", k, loaded.Tags[k], v) + } + } +} + +func TestLoad_FileNotExists(t *testing.T) { 
+ tmpDir := t.TempDir() + t.Setenv("XDG_CONFIG_HOME", tmpDir) + t.Setenv("HOME", tmpDir) + t.Setenv("APPDATA", tmpDir) + + m := NewManager("nonexistent") + + var settings testSettings + found, err := m.Load(&settings) + if err != nil { + t.Errorf("Load() error = %v, want nil for nonexistent file", err) + } + if found { + t.Error("Load() returned found=true for nonexistent file, want false") + } +} + +func TestLoad_CorruptedFile(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("XDG_CONFIG_HOME", tmpDir) + t.Setenv("HOME", tmpDir) + t.Setenv("APPDATA", tmpDir) + + m := NewManager("corrupted") + + // Create corrupted settings file + path, err := m.Path() + if err != nil { + t.Fatalf("Path() error = %v", err) + } + err = os.MkdirAll(filepath.Dir(path), 0o700) + if err != nil { + t.Fatalf("Failed to create directory: %v", err) + } + err = os.WriteFile(path, []byte("not valid json {{{"), 0o600) + if err != nil { + t.Fatalf("Failed to write corrupted file: %v", err) + } + + var settings testSettings + found, err := m.Load(&settings) + if err == nil { + t.Error("Load() should return error for corrupted file") + } + if found { + t.Error("Load() returned found=true for corrupted file") + } +} + +func TestSave_CreateDirectory(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("XDG_CONFIG_HOME", tmpDir) + t.Setenv("HOME", tmpDir) + t.Setenv("APPDATA", tmpDir) + + m := NewManager("testapp") + + settings := testSettings{Name: "test"} + + err := m.Save(&settings) + if err != nil { + t.Fatalf("Save() error = %v", err) + } + + // Verify directory was created + path, err := m.Path() + if err != nil { + t.Fatalf("Path() error = %v", err) + } + dir := filepath.Dir(path) + if _, err := os.Stat(dir); os.IsNotExist(err) { + t.Error("Settings directory should have been created") + } + + // Verify file exists + if _, err := os.Stat(path); os.IsNotExist(err) { + t.Error("Settings file should have been created") + } +} + +func TestSave_Overwrite(t *testing.T) { + tmpDir := t.TempDir() + 
t.Setenv("XDG_CONFIG_HOME", tmpDir) + t.Setenv("HOME", tmpDir) + t.Setenv("APPDATA", tmpDir) + + m := NewManager("testapp") + + // Save first version + first := testSettings{Name: "first", Count: 1} + err := m.Save(&first) + if err != nil { + t.Fatalf("First Save() error = %v", err) + } + + // Save second version (overwrite) + second := testSettings{Name: "second", Count: 2} + err = m.Save(&second) + if err != nil { + t.Fatalf("Second Save() error = %v", err) + } + + // Load and verify it has the second version + var loaded testSettings + found, err := m.Load(&loaded) + if err != nil { + t.Fatalf("Load() error = %v", err) + } + if !found { + t.Fatal("Load() returned found=false") + } + + if loaded.Name != "second" { + t.Errorf("Name = %q, want %q (should be overwritten)", loaded.Name, "second") + } + if loaded.Count != 2 { + t.Errorf("Count = %d, want %d (should be overwritten)", loaded.Count, 2) + } +} + +func TestSave_FilePermissions(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("XDG_CONFIG_HOME", tmpDir) + t.Setenv("HOME", tmpDir) + t.Setenv("APPDATA", tmpDir) + + m := NewManager("testapp") + + settings := testSettings{Name: "test"} + err := m.Save(&settings) + if err != nil { + t.Fatalf("Save() error = %v", err) + } + + path, err := m.Path() + if err != nil { + t.Fatalf("Path() error = %v", err) + } + info, err := os.Stat(path) + if err != nil { + t.Fatalf("Stat() error = %v", err) + } + + // Check file permissions are restrictive (0o600 on Unix) + mode := info.Mode() + if mode.Perm() != 0o600 { + t.Logf("Warning: File permissions are %o, expected 0o600", mode.Perm()) + } +} + +func TestLoad_EmptyFile(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("XDG_CONFIG_HOME", tmpDir) + t.Setenv("HOME", tmpDir) + t.Setenv("APPDATA", tmpDir) + + m := NewManager("emptyfile") + + // Create empty settings file + path, err := m.Path() + if err != nil { + t.Fatalf("Path() error = %v", err) + } + err = os.MkdirAll(filepath.Dir(path), 0o700) + if err != nil { + 
t.Fatalf("Failed to create directory: %v", err) + } + err = os.WriteFile(path, []byte(""), 0o600) + if err != nil { + t.Fatalf("Failed to write empty file: %v", err) + } + + var settings testSettings + found, err := m.Load(&settings) + if err == nil { + t.Error("Load() should return error for empty file") + } + if found { + t.Error("Load() returned found=true for empty file") + } +} diff --git a/pkg/dedup/dedup.go b/pkg/dedup/dedup.go new file mode 100644 index 0000000..ec1e5f6 --- /dev/null +++ b/pkg/dedup/dedup.go @@ -0,0 +1,60 @@ +// Package dedup provides time-based event deduplication. +package dedup + +import ( + "sync" + "time" +) + +// Manager deduplicates events within a time window. +type Manager struct { + last map[string]time.Time + mu sync.Mutex + window time.Duration + cleanupAge time.Duration + maxSize int +} + +// New creates a deduplication manager. +func New(window, cleanupAge time.Duration, maxSize int) *Manager { + return &Manager{ + last: make(map[string]time.Time), + window: window, + maxSize: maxSize, + cleanupAge: cleanupAge, + } +} + +// ShouldProcess returns true if the event should be processed. +// Returns false if it's a duplicate within the dedup window. +// Safe for concurrent use. +func (d *Manager) ShouldProcess(key string, t time.Time) bool { + d.mu.Lock() + defer d.mu.Unlock() + + if last, ok := d.last[key]; ok && t.Sub(last) < d.window { + return false + } + + d.last[key] = t + + // Cleanup old entries if map is too large + if len(d.last) > d.maxSize { + cutoff := t.Add(-d.cleanupAge) + for k, ts := range d.last { + if ts.Before(cutoff) { + delete(d.last, k) + } + } + } + + return true +} + +// Size returns the current number of tracked events. +// Safe for concurrent use. 
+func (d *Manager) Size() int { + d.mu.Lock() + defer d.mu.Unlock() + return len(d.last) +} diff --git a/pkg/dedup/dedup_test.go b/pkg/dedup/dedup_test.go new file mode 100644 index 0000000..38e64ff --- /dev/null +++ b/pkg/dedup/dedup_test.go @@ -0,0 +1,124 @@ +package dedup + +import ( + "testing" + "time" +) + +func TestNew(t *testing.T) { + m := New(5*time.Second, 1*time.Hour, 100) + if m == nil { + t.Fatal("New returned nil") + } + if m.window != 5*time.Second { + t.Errorf("window = %v, want %v", m.window, 5*time.Second) + } + if m.cleanupAge != 1*time.Hour { + t.Errorf("cleanupAge = %v, want %v", m.cleanupAge, 1*time.Hour) + } + if m.maxSize != 100 { + t.Errorf("maxSize = %d, want 100", m.maxSize) + } + if m.Size() != 0 { + t.Errorf("initial Size() = %d, want 0", m.Size()) + } +} + +func TestManager_ShouldProcess(t *testing.T) { + m := New(100*time.Millisecond, 1*time.Hour, 100) + now := time.Now() + + // First event should be processed + if !m.ShouldProcess("url1", now) { + t.Error("First event should be processed") + } + if m.Size() != 1 { + t.Errorf("Size after first event = %d, want 1", m.Size()) + } + + // Duplicate within window should not be processed + if m.ShouldProcess("url1", now.Add(50*time.Millisecond)) { + t.Error("Duplicate within dedup window should not be processed") + } + + // After window, should be processed again + if !m.ShouldProcess("url1", now.Add(150*time.Millisecond)) { + t.Error("Event after dedup window should be processed") + } + + // Different URL should be processed + if !m.ShouldProcess("url2", now) { + t.Error("Different URL should be processed") + } + if m.Size() != 2 { + t.Errorf("Size after second URL = %d, want 2", m.Size()) + } +} + +func TestManager_Cleanup(t *testing.T) { + m := New(5*time.Second, 1*time.Minute, 10) + now := time.Now() + + // Add events to fill beyond max + for i := range 15 { + url := "url" + string(rune('0'+i)) + m.ShouldProcess(url, now.Add(-2*time.Minute)) + } + + // Add a new event to trigger 
cleanup + m.ShouldProcess("trigger", now) + + // Old entries should be cleaned up + if sz := m.Size(); sz > 11 { + t.Errorf("Size after cleanup = %d, expected cleanup to reduce it", sz) + } +} + +func TestManager_MultipleEvents(t *testing.T) { + m := New(200*time.Millisecond, 1*time.Hour, 100) + now := time.Now() + + urls := []string{"url1", "url2", "url3"} + + // All should be processed initially + for _, url := range urls { + if !m.ShouldProcess(url, now) { + t.Errorf("Initial event for %s should be processed", url) + } + } + + if m.Size() != 3 { + t.Errorf("Size = %d, want 3", m.Size()) + } + + // All duplicates should be rejected + for _, url := range urls { + if m.ShouldProcess(url, now.Add(100*time.Millisecond)) { + t.Errorf("Duplicate event for %s should not be processed", url) + } + } + + // After window, all should be processed again + for _, url := range urls { + if !m.ShouldProcess(url, now.Add(250*time.Millisecond)) { + t.Errorf("Event after window for %s should be processed", url) + } + } +} + +func TestManager_ExactWindowBoundary(t *testing.T) { + m := New(100*time.Millisecond, 1*time.Hour, 100) + now := time.Now() + + m.ShouldProcess("url1", now) + + // Just before the window boundary should still be deduplicated (dedup uses strict <) + if m.ShouldProcess("url1", now.Add(99*time.Millisecond)) { + t.Error("Event just before window end should not be processed") + } + + // Exactly at the window boundary the strict < no longer holds, so the event is processed + if !m.ShouldProcess("url1", now.Add(100*time.Millisecond)) { + t.Error("Event at window boundary should be processed") + } +} diff --git a/pkg/icon/icon.go b/pkg/icon/icon.go index 837a016..1a20a1a 100644 --- a/pkg/icon/icon.go +++ b/pkg/icon/icon.go @@ -221,7 +221,7 @@ func (c *Cache) Put(incoming, outgoing int, data []byte) { // Simple size limit if len(c.icons) > 100 { - c.icons = make(map[string][]byte) + clear(c.icons) } c.icons[key(incoming, outgoing)] = data diff --git a/pkg/icon/icon_test.go b/pkg/icon/icon_test.go index 9f6cadc..066cf1e 100644 ---
a/pkg/icon/icon_test.go +++ b/pkg/icon/icon_test.go @@ -94,3 +94,68 @@ func TestFormat(t *testing.T) { } } } + +func TestScale(t *testing.T) { + // Create a test icon (red circle with "5") + originalData, err := Badge(5, 0) + if err != nil { + t.Fatalf("Badge() failed: %v", err) + } + if originalData == nil { + t.Fatal("Badge() returned nil") + } + + // Scale it + scaled, err := Scale(originalData) + if err != nil { + t.Fatalf("Scale() error = %v", err) + } + + // Verify it's valid PNG + img, err := png.Decode(bytes.NewReader(scaled)) + if err != nil { + t.Fatalf("invalid PNG after scaling: %v", err) + } + + // Verify dimensions match Size constant + bounds := img.Bounds() + if bounds.Dx() != Size || bounds.Dy() != Size { + t.Errorf("wrong dimensions: got %dx%d, want %dx%d", + bounds.Dx(), bounds.Dy(), Size, Size) + } + + // Test error case: invalid PNG data + _, err = Scale([]byte("not a png")) + if err == nil { + t.Error("Scale() should fail with invalid PNG data") + } +} + +func TestCacheOverflow(t *testing.T) { + c := NewCache() + + // Fill cache to exactly 101 entries (exceeds limit of 100) + for i := range 101 { + c.Put(i, 0, []byte("test")) + } + + // At this point we have 101 entries (exceeds limit but not cleared yet) + // Add one more entry to trigger cache clear + c.Put(999, 0, []byte("test")) + + // After clearing and adding entry 999, only entry 999 should be present + if _, ok := c.Lookup(999, 0); !ok { + t.Error("expected entry 999 after cache overflow") + } + + // Old entries should be gone after cache was cleared + found := 0 + for i := range 101 { + if _, ok := c.Lookup(i, 0); ok { + found++ + } + } + if found > 0 { + t.Errorf("expected old entries to be cleared after overflow, but found %d", found) + } +} diff --git a/cmd/reviewGOOSE/multihandler.go b/pkg/logging/multihandler.go similarity index 71% rename from cmd/reviewGOOSE/multihandler.go rename to pkg/logging/multihandler.go index a4f0be9..3c22ddb 100644 --- a/cmd/reviewGOOSE/multihandler.go 
+++ b/pkg/logging/multihandler.go @@ -1,4 +1,5 @@ -package main +// Package logging provides logging utilities for the Goose application. +package logging import ( "context" @@ -10,6 +11,11 @@ type MultiHandler struct { handlers []slog.Handler } +// NewMultiHandler creates a new MultiHandler that writes to multiple destinations. +func NewMultiHandler(handlers ...slog.Handler) *MultiHandler { + return &MultiHandler{handlers: handlers} +} + // Enabled returns true if at least one handler is enabled. func (h *MultiHandler) Enabled(ctx context.Context, level slog.Level) bool { for _, handler := range h.handlers { @@ -21,15 +27,14 @@ func (h *MultiHandler) Enabled(ctx context.Context, level slog.Level) bool { } // Handle writes the record to all handlers. +// Errors from individual handlers are silently ignored to ensure all handlers execute. // //nolint:gocritic // record is an interface parameter, cannot change to pointer func (h *MultiHandler) Handle(ctx context.Context, record slog.Record) error { for _, handler := range h.handlers { if handler.Enabled(ctx, record.Level) { - if err := handler.Handle(ctx, record); err != nil { - // Continue logging to other destinations even if one fails - _ = err - } + // Intentionally ignore handler errors to ensure all handlers run + _ = handler.Handle(ctx, record) //nolint:errcheck // Error intentionally ignored } } return nil diff --git a/pkg/logging/multihandler_test.go b/pkg/logging/multihandler_test.go new file mode 100644 index 0000000..f741f5f --- /dev/null +++ b/pkg/logging/multihandler_test.go @@ -0,0 +1,179 @@ +package logging + +import ( + "bytes" + "context" + "log/slog" + "strings" + "testing" +) + +func TestMultiHandler_Enabled(t *testing.T) { + tests := []struct { + name string + handlers []slog.Handler + level slog.Level + want bool + }{ + { + name: "all handlers disabled", + handlers: []slog.Handler{ + slog.NewTextHandler(&bytes.Buffer{}, &slog.HandlerOptions{Level: slog.LevelError}), + 
slog.NewTextHandler(&bytes.Buffer{}, &slog.HandlerOptions{Level: slog.LevelError}), + }, + level: slog.LevelInfo, + want: false, + }, + { + name: "one handler enabled", + handlers: []slog.Handler{ + slog.NewTextHandler(&bytes.Buffer{}, &slog.HandlerOptions{Level: slog.LevelInfo}), + slog.NewTextHandler(&bytes.Buffer{}, &slog.HandlerOptions{Level: slog.LevelError}), + }, + level: slog.LevelInfo, + want: true, + }, + { + name: "all handlers enabled", + handlers: []slog.Handler{ + slog.NewTextHandler(&bytes.Buffer{}, &slog.HandlerOptions{Level: slog.LevelDebug}), + slog.NewTextHandler(&bytes.Buffer{}, &slog.HandlerOptions{Level: slog.LevelDebug}), + }, + level: slog.LevelInfo, + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := NewMultiHandler(tt.handlers...) + ctx := context.Background() + got := h.Enabled(ctx, tt.level) + if got != tt.want { + t.Errorf("Enabled() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestMultiHandler_Handle(t *testing.T) { + var buf1, buf2 bytes.Buffer + + handler1 := slog.NewTextHandler(&buf1, &slog.HandlerOptions{Level: slog.LevelInfo}) + handler2 := slog.NewTextHandler(&buf2, &slog.HandlerOptions{Level: slog.LevelInfo}) + + multi := NewMultiHandler(handler1, handler2) + + logger := slog.New(multi) + logger.Info("test message", "key", "value") + + // Both buffers should contain the log message + output1 := buf1.String() + output2 := buf2.String() + + if !strings.Contains(output1, "test message") { + t.Errorf("handler1 output missing 'test message': %s", output1) + } + if !strings.Contains(output2, "test message") { + t.Errorf("handler2 output missing 'test message': %s", output2) + } + + if !strings.Contains(output1, "key=value") { + t.Errorf("handler1 output missing 'key=value': %s", output1) + } + if !strings.Contains(output2, "key=value") { + t.Errorf("handler2 output missing 'key=value': %s", output2) + } +} + +func TestMultiHandler_WithAttrs(t *testing.T) { + var buf1, buf2 
bytes.Buffer + + handler1 := slog.NewTextHandler(&buf1, &slog.HandlerOptions{Level: slog.LevelInfo}) + handler2 := slog.NewTextHandler(&buf2, &slog.HandlerOptions{Level: slog.LevelInfo}) + + multi := NewMultiHandler(handler1, handler2) + + // Add attributes + multiWithAttrs := multi.WithAttrs([]slog.Attr{ + slog.String("source", "test"), + }) + + logger := slog.New(multiWithAttrs) + logger.Info("test message") + + // Both buffers should contain the attribute + output1 := buf1.String() + output2 := buf2.String() + + if !strings.Contains(output1, "source=test") { + t.Errorf("handler1 output missing attribute: %s", output1) + } + if !strings.Contains(output2, "source=test") { + t.Errorf("handler2 output missing attribute: %s", output2) + } +} + +func TestMultiHandler_WithGroup(t *testing.T) { + var buf1, buf2 bytes.Buffer + + handler1 := slog.NewTextHandler(&buf1, &slog.HandlerOptions{Level: slog.LevelInfo}) + handler2 := slog.NewTextHandler(&buf2, &slog.HandlerOptions{Level: slog.LevelInfo}) + + multi := NewMultiHandler(handler1, handler2) + + // Add group + multiWithGroup := multi.WithGroup("metrics") + + logger := slog.New(multiWithGroup) + logger.Info("test message", "count", 42) + + // Both buffers should contain the group + output1 := buf1.String() + output2 := buf2.String() + + if !strings.Contains(output1, "metrics.count=42") { + t.Errorf("handler1 output missing grouped attribute: %s", output1) + } + if !strings.Contains(output2, "metrics.count=42") { + t.Errorf("handler2 output missing grouped attribute: %s", output2) + } +} + +func TestMultiHandler_OneHandlerDisabled(t *testing.T) { + var buf1, buf2 bytes.Buffer + + // handler1 accepts Info, handler2 only accepts Error + handler1 := slog.NewTextHandler(&buf1, &slog.HandlerOptions{Level: slog.LevelInfo}) + handler2 := slog.NewTextHandler(&buf2, &slog.HandlerOptions{Level: slog.LevelError}) + + multi := NewMultiHandler(handler1, handler2) + + logger := slog.New(multi) + logger.Info("test message") + + // Only 
buf1 should have output + output1 := buf1.String() + output2 := buf2.String() + + if !strings.Contains(output1, "test message") { + t.Errorf("handler1 should have logged: %s", output1) + } + if output2 != "" { + t.Errorf("handler2 should not have logged: %s", output2) + } +} + +func TestMultiHandler_Empty(t *testing.T) { + // Test with no handlers + multi := NewMultiHandler() + + ctx := context.Background() + if multi.Enabled(ctx, slog.LevelInfo) { + t.Error("Enabled() should return false with no handlers") + } + + // Should not panic + logger := slog.New(multi) + logger.Info("test message") +} diff --git a/pkg/prcache/prcache.go b/pkg/prcache/prcache.go new file mode 100644 index 0000000..a931a17 --- /dev/null +++ b/pkg/prcache/prcache.go @@ -0,0 +1,146 @@ +// Package prcache provides caching functionality for PR metadata with TTL support. +package prcache + +import ( + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "log/slog" + "os" + "path/filepath" + "strings" + "time" +) + +// Entry represents a cached item with metadata. +type Entry[T any] struct { + Data T `json:"data"` + CachedAt time.Time `json:"cached_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// Manager handles caching of PR metadata with TTL and invalidation logic. +type Manager struct { + cacheDir string +} + +// NewManager creates a new cache manager. +func NewManager(cacheDir string) *Manager { + return &Manager{cacheDir: cacheDir} +} + +// CacheKey generates a cache key from a URL and timestamp. +func CacheKey(url string, updatedAt time.Time) string { + key := fmt.Sprintf("%s-%s", url, updatedAt.Format(time.RFC3339)) + h := sha256.Sum256([]byte(key)) + return hex.EncodeToString(h[:])[:16] +} + +// CachePath returns the file path for a cache key. +func (m *Manager) CachePath(cacheKey string) string { + return filepath.Join(m.cacheDir, cacheKey+".json") +} + +// CacheResult represents the result of a cache lookup. 
+type CacheResult struct { + Entry *Entry[any] + Hit bool // True if cache entry was found and valid + ShouldBypass bool // True if cache should be bypassed (e.g., for running tests) +} + +// Get retrieves cached data if valid according to TTL rules. +func (*Manager) Get(path string, updatedAt time.Time, ttl time.Duration, bypassTTL time.Duration, stateCheck func(any) bool) (*CacheResult, error) { + b, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + return &CacheResult{}, nil + } + return nil, fmt.Errorf("read cache file: %w", err) + } + + var e Entry[any] + if err := json.Unmarshal(b, &e); err != nil { + // Corrupted cache file - try to remove it + if removeErr := os.Remove(path); removeErr != nil { + slog.Debug("Failed to remove corrupted cache file", "path", path, "error", removeErr) + } + return nil, fmt.Errorf("unmarshal cache: %w", err) + } + + // Check if PR was updated since cache + if !e.UpdatedAt.Equal(updatedAt) { + return &CacheResult{}, nil + } + + age := time.Since(e.CachedAt) + + // Bypass the cache for incomplete state even if the normal TTL has not expired, but only while the entry is younger than bypassTTL + // This ensures we fetch fresh data when tests are still running + if stateCheck != nil && stateCheck(e.Data) && age < bypassTTL { + return &CacheResult{ShouldBypass: true}, nil + } + + // Check TTL + if age >= ttl { + return &CacheResult{}, nil + } + + return &CacheResult{Entry: &e, Hit: true}, nil +} + +// Put stores data in the cache.
+func (*Manager) Put(path string, data any, updatedAt time.Time) error { + e := Entry[any]{ + Data: data, + CachedAt: time.Now(), + UpdatedAt: updatedAt, + } + + b, err := json.Marshal(e) + if err != nil { + return fmt.Errorf("marshal cache data: %w", err) + } + + if err := os.MkdirAll(filepath.Dir(path), 0o700); err != nil { + return fmt.Errorf("create cache directory: %w", err) + } + + if err := os.WriteFile(path, b, 0o600); err != nil { + return fmt.Errorf("write cache file: %w", err) + } + + return nil +} + +// CleanupOldFiles removes cache files older than the specified interval. +func (m *Manager) CleanupOldFiles(maxAge time.Duration) (cleaned int, errs int) { + entries, err := os.ReadDir(m.cacheDir) + if err != nil { + slog.Error("Failed to read cache directory for cleanup", "error", err) + return 0, 1 + } + + for _, e := range entries { + if !strings.HasSuffix(e.Name(), ".json") { + continue + } + + info, err := e.Info() + if err != nil { + errs++ + continue + } + + if time.Since(info.ModTime()) > maxAge { + p := filepath.Join(m.cacheDir, e.Name()) + if err := os.Remove(p); err != nil { + errs++ + } else { + cleaned++ + } + } + } + + return cleaned, errs +} diff --git a/pkg/prcache/prcache_test.go b/pkg/prcache/prcache_test.go new file mode 100644 index 0000000..db46991 --- /dev/null +++ b/pkg/prcache/prcache_test.go @@ -0,0 +1,361 @@ +package prcache + +import ( + "os" + "path/filepath" + "testing" + "time" +) + +func TestNewManager(t *testing.T) { + m := NewManager("/tmp/test-cache") + if m == nil { + t.Fatal("NewManager returned nil") + } + if m.cacheDir != "/tmp/test-cache" { + t.Errorf("cacheDir = %q, want %q", m.cacheDir, "/tmp/test-cache") + } +} + +func TestCacheKey(t *testing.T) { + url := "https://github.com/owner/repo/pull/123" + ts := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC) + + key1 := CacheKey(url, ts) + key2 := CacheKey(url, ts) + + // Same inputs should produce same key + if key1 != key2 { + t.Errorf("CacheKey not deterministic: %q != 
%q", key1, key2) + } + + // Different timestamp should produce different key + ts2 := ts.Add(1 * time.Second) + key3 := CacheKey(url, ts2) + if key1 == key3 { + t.Error("CacheKey should differ for different timestamps") + } + + // Key should be 16 characters (hex encoded) + if len(key1) != 16 { + t.Errorf("CacheKey length = %d, want 16", len(key1)) + } +} + +func TestCachePath(t *testing.T) { + cacheDir := t.TempDir() + m := NewManager(cacheDir) + path := m.CachePath("abcd1234") + + expected := filepath.Join(cacheDir, "abcd1234.json") + if path != expected { + t.Errorf("CachePath = %q, want %q", path, expected) + } +} + +func TestPutAndGet(t *testing.T) { + // Create temporary cache directory + tmpDir := t.TempDir() + m := NewManager(tmpDir) + + url := "https://github.com/owner/repo/pull/123" + updatedAt := time.Now().Add(-1 * time.Hour) + cacheKey := CacheKey(url, updatedAt) + path := m.CachePath(cacheKey) + + // Test data + data := map[string]string{ + "test": "value", + "foo": "bar", + } + + // Put data in cache + err := m.Put(path, data, updatedAt) + if err != nil { + t.Fatalf("Put failed: %v", err) + } + + // Verify file exists + if _, err := os.Stat(path); os.IsNotExist(err) { + t.Fatal("Cache file was not created") + } + + // Get data from cache with long TTL (should hit) + ttl := 24 * time.Hour + bypassTTL := 1 * time.Hour + result, err := m.Get(path, updatedAt, ttl, bypassTTL, nil) + if err != nil { + t.Fatalf("Get failed: %v", err) + } + if !result.Hit { + t.Error("Expected cache hit") + } + if result.ShouldBypass { + t.Error("Should not bypass cache") + } + if result.Entry == nil { + t.Fatal("Entry is nil") + } + + // Verify cached timestamp is recent + if time.Since(result.Entry.CachedAt) > 5*time.Second { + t.Errorf("CachedAt is too old: %v", result.Entry.CachedAt) + } +} + +func TestGet_CacheMiss_FileNotExists(t *testing.T) { + tmpDir := t.TempDir() + m := NewManager(tmpDir) + + path := filepath.Join(tmpDir, "nonexistent.json") + updatedAt := 
time.Now() + + result, err := m.Get(path, updatedAt, 1*time.Hour, 1*time.Minute, nil) + if err != nil { + t.Errorf("Get returned error for nonexistent file: %v", err) + } + if result.Hit { + t.Error("Should not have cache hit for nonexistent file") + } + if result.ShouldBypass { + t.Error("Should not bypass for nonexistent file") + } +} + +func TestGet_CacheMiss_CorruptedFile(t *testing.T) { + tmpDir := t.TempDir() + m := NewManager(tmpDir) + + path := filepath.Join(tmpDir, "corrupted.json") + + // Write corrupted JSON + err := os.WriteFile(path, []byte("not valid json {{{"), 0o600) + if err != nil { + t.Fatalf("Failed to write test file: %v", err) + } + + updatedAt := time.Now() + result, err := m.Get(path, updatedAt, 1*time.Hour, 1*time.Minute, nil) + if err == nil { + t.Error("Expected error for corrupted cache file") + } + if result != nil && result.Hit { + t.Error("Should not have cache hit for corrupted file") + } + if result != nil && result.ShouldBypass { + t.Error("Should not bypass for corrupted file") + } + + // File should be removed + if _, err := os.Stat(path); !os.IsNotExist(err) { + t.Error("Corrupted cache file should have been removed") + } +} + +func TestGet_CacheMiss_PRUpdated(t *testing.T) { + tmpDir := t.TempDir() + m := NewManager(tmpDir) + + url := "https://github.com/owner/repo/pull/123" + oldUpdatedAt := time.Now().Add(-2 * time.Hour) + newUpdatedAt := time.Now().Add(-1 * time.Hour) + + cacheKey := CacheKey(url, oldUpdatedAt) + path := m.CachePath(cacheKey) + + // Put data with old timestamp + data := map[string]string{"test": "value"} + err := m.Put(path, data, oldUpdatedAt) + if err != nil { + t.Fatalf("Put failed: %v", err) + } + + // Try to get with new timestamp (PR was updated) + result, err := m.Get(path, newUpdatedAt, 24*time.Hour, 1*time.Hour, nil) + if err != nil { + t.Errorf("Get failed: %v", err) + } + if result.Hit { + t.Error("Should not have cache hit when PR was updated") + } +} + +func TestGet_CacheMiss_TTLExpired(t 
*testing.T) { + tmpDir := t.TempDir() + m := NewManager(tmpDir) + + url := "https://github.com/owner/repo/pull/123" + updatedAt := time.Now().Add(-1 * time.Hour) + cacheKey := CacheKey(url, updatedAt) + path := m.CachePath(cacheKey) + + // Put data + data := map[string]string{"test": "value"} + err := m.Put(path, data, updatedAt) + if err != nil { + t.Fatalf("Put failed: %v", err) + } + + // Wait a bit to ensure cache age + time.Sleep(100 * time.Millisecond) + + // Get with very short TTL (should miss) + result, err := m.Get(path, updatedAt, 50*time.Millisecond, 1*time.Hour, nil) + if err != nil { + t.Errorf("Get failed: %v", err) + } + if result.Hit { + t.Error("Should not have cache hit when TTL expired") + } + if result.ShouldBypass { + t.Error("Should not bypass without state check") + } +} + +func TestGet_Bypass_WithStateCheck(t *testing.T) { + tmpDir := t.TempDir() + m := NewManager(tmpDir) + + url := "https://github.com/owner/repo/pull/123" + updatedAt := time.Now().Add(-1 * time.Hour) + cacheKey := CacheKey(url, updatedAt) + path := m.CachePath(cacheKey) + + // Put data + data := map[string]any{ + "state": "running", + } + err := m.Put(path, data, updatedAt) + if err != nil { + t.Fatalf("Put failed: %v", err) + } + + // Wait to ensure TTL expired + time.Sleep(100 * time.Millisecond) + + // State check that returns true (incomplete state) + stateCheck := func(d any) bool { + if m, ok := d.(map[string]any); ok { + if state, ok := m["state"].(string); ok { + return state == "running" + } + } + return false + } + + // Get with expired TTL but bypass window still valid + result, err := m.Get(path, updatedAt, 50*time.Millisecond, 1*time.Hour, stateCheck) + if err != nil { + t.Errorf("Get failed: %v", err) + } + if result.Hit { + t.Error("Should not have cache hit when TTL expired") + } + if !result.ShouldBypass { + t.Error("Should bypass cache for incomplete state within bypass window") + } +} + +func TestCleanupOldFiles(t *testing.T) { + tmpDir := t.TempDir() + 
m := NewManager(tmpDir) + + // Create some old files + oldFile1 := filepath.Join(tmpDir, "old1.json") + oldFile2 := filepath.Join(tmpDir, "old2.json") + recentFile := filepath.Join(tmpDir, "recent.json") + nonJSONFile := filepath.Join(tmpDir, "other.txt") + + // Write files + for _, f := range []string{oldFile1, oldFile2, recentFile, nonJSONFile} { + if err := os.WriteFile(f, []byte("{}"), 0o600); err != nil { + t.Fatalf("Failed to write test file: %v", err) + } + } + + // Make old files actually old by changing their modification time + oldTime := time.Now().Add(-20 * 24 * time.Hour) // 20 days ago + if err := os.Chtimes(oldFile1, oldTime, oldTime); err != nil { + t.Fatalf("Failed to change file time: %v", err) + } + if err := os.Chtimes(oldFile2, oldTime, oldTime); err != nil { + t.Fatalf("Failed to change file time: %v", err) + } + + // Cleanup files older than 15 days + cleaned, errs := m.CleanupOldFiles(15 * 24 * time.Hour) + + if errs != 0 { + t.Errorf("Cleanup had errors: %d", errs) + } + + if cleaned != 2 { + t.Errorf("Cleaned %d files, want 2", cleaned) + } + + // Verify old files are gone + for _, f := range []string{oldFile1, oldFile2} { + if _, err := os.Stat(f); !os.IsNotExist(err) { + t.Errorf("Old file %q should have been removed", f) + } + } + + // Verify recent file and non-JSON file still exist + for _, f := range []string{recentFile, nonJSONFile} { + if _, err := os.Stat(f); err != nil { + t.Errorf("File %q should still exist: %v", f, err) + } + } +} + +func TestCleanupOldFiles_NoFiles(t *testing.T) { + tmpDir := t.TempDir() + m := NewManager(tmpDir) + + cleaned, errs := m.CleanupOldFiles(15 * 24 * time.Hour) + + if cleaned != 0 { + t.Errorf("Cleaned %d files, want 0", cleaned) + } + if errs != 0 { + t.Errorf("Had %d errors, want 0", errs) + } +} + +func TestCleanupOldFiles_NonexistentDir(t *testing.T) { + m := NewManager("/nonexistent/directory") + + cleaned, errs := m.CleanupOldFiles(15 * 24 * time.Hour) + + if cleaned != 0 { + 
t.Errorf("Cleaned %d files, want 0", cleaned) + } + if errs != 1 { + t.Errorf("Had %d errors, want 1", errs) + } +} + +func TestPut_CreateDirectory(t *testing.T) { + tmpDir := t.TempDir() + nestedDir := filepath.Join(tmpDir, "nested", "deep") + m := NewManager(nestedDir) + + path := filepath.Join(nestedDir, "test.json") + data := map[string]string{"test": "value"} + updatedAt := time.Now() + + err := m.Put(path, data, updatedAt) + if err != nil { + t.Fatalf("Put failed: %v", err) + } + + // Verify directory was created + if _, err := os.Stat(nestedDir); os.IsNotExist(err) { + t.Error("Nested directory should have been created") + } + + // Verify file exists + if _, err := os.Stat(path); os.IsNotExist(err) { + t.Error("Cache file should have been created") + } +} diff --git a/cmd/reviewGOOSE/browser_rate_limiter.go b/pkg/ratelimit/browser.go similarity index 96% rename from cmd/reviewGOOSE/browser_rate_limiter.go rename to pkg/ratelimit/browser.go index 7aef4c6..4371acc 100644 --- a/cmd/reviewGOOSE/browser_rate_limiter.go +++ b/pkg/ratelimit/browser.go @@ -1,4 +1,5 @@ -package main +// Package ratelimit provides rate limiting functionality for browser operations. 
+package ratelimit import ( "log/slog" @@ -120,6 +121,6 @@ func (b *BrowserRateLimiter) Reset() { b.mu.Lock() defer b.mu.Unlock() previousCount := len(b.openedPRs) - b.openedPRs = make(map[string]bool) + clear(b.openedPRs) slog.Info("[BROWSER] Rate limiter reset", "clearedPRs", previousCount) } diff --git a/pkg/ratelimit/browser_test.go b/pkg/ratelimit/browser_test.go new file mode 100644 index 0000000..5630ac6 --- /dev/null +++ b/pkg/ratelimit/browser_test.go @@ -0,0 +1,241 @@ +package ratelimit + +import ( + "testing" + "time" +) + +func TestNewBrowserRateLimiter(t *testing.T) { + limiter := NewBrowserRateLimiter(1*time.Minute, 2, 10) + + if limiter == nil { + t.Fatal("NewBrowserRateLimiter returned nil") + } + + if limiter.startupDelay != 1*time.Minute { + t.Errorf("startupDelay = %v, want %v", limiter.startupDelay, 1*time.Minute) + } + + if limiter.maxPerMinute != 2 { + t.Errorf("maxPerMinute = %d, want 2", limiter.maxPerMinute) + } + + if limiter.maxPerDay != 10 { + t.Errorf("maxPerDay = %d, want 10", limiter.maxPerDay) + } +} + +func TestBrowserRateLimiter_CanOpen_StartupDelay(t *testing.T) { + startTime := time.Now() + limiter := NewBrowserRateLimiter(1*time.Minute, 10, 100) + + // Should not allow opening during startup delay + if limiter.CanOpen(startTime, "https://github.com/owner/repo/pull/1") { + t.Error("CanOpen should return false during startup delay") + } + + // Should allow opening after startup delay + pastStartTime := time.Now().Add(-2 * time.Minute) + if !limiter.CanOpen(pastStartTime, "https://github.com/owner/repo/pull/1") { + t.Error("CanOpen should return true after startup delay") + } +} + +func TestBrowserRateLimiter_CanOpen_DuplicatePR(t *testing.T) { + startTime := time.Now().Add(-2 * time.Minute) // Past startup delay + limiter := NewBrowserRateLimiter(1*time.Minute, 10, 100) + + prURL := "https://github.com/owner/repo/pull/1" + + // First call should succeed + if !limiter.CanOpen(startTime, prURL) { + t.Error("CanOpen should return 
true for first call") + } + + // Record the open + limiter.RecordOpen(prURL) + + // Second call for same PR should fail + if limiter.CanOpen(startTime, prURL) { + t.Error("CanOpen should return false for duplicate PR") + } +} + +func TestBrowserRateLimiter_CanOpen_PerMinuteLimit(t *testing.T) { + startTime := time.Now().Add(-2 * time.Minute) // Past startup delay + limiter := NewBrowserRateLimiter(1*time.Minute, 2, 100) // Max 2 per minute + + // Open first PR + if !limiter.CanOpen(startTime, "https://github.com/owner/repo/pull/1") { + t.Error("First CanOpen should succeed") + } + limiter.RecordOpen("https://github.com/owner/repo/pull/1") + + // Open second PR + if !limiter.CanOpen(startTime, "https://github.com/owner/repo/pull/2") { + t.Error("Second CanOpen should succeed") + } + limiter.RecordOpen("https://github.com/owner/repo/pull/2") + + // Third PR should fail per-minute limit + if limiter.CanOpen(startTime, "https://github.com/owner/repo/pull/3") { + t.Error("Third CanOpen should fail per-minute limit") + } +} + +func TestBrowserRateLimiter_CanOpen_PerDayLimit(t *testing.T) { + startTime := time.Now().Add(-2 * time.Minute) // Past startup delay + limiter := NewBrowserRateLimiter(1*time.Minute, 100, 2) // Max 2 per day + + // Open first PR + if !limiter.CanOpen(startTime, "https://github.com/owner/repo/pull/1") { + t.Error("First CanOpen should succeed") + } + limiter.RecordOpen("https://github.com/owner/repo/pull/1") + + // Manually clear per-minute limit to test daily limit in isolation + limiter.mu.Lock() + limiter.openedLastMinute = []time.Time{} + limiter.mu.Unlock() + + // Open second PR + if !limiter.CanOpen(startTime, "https://github.com/owner/repo/pull/2") { + t.Error("Second CanOpen should succeed") + } + limiter.RecordOpen("https://github.com/owner/repo/pull/2") + + // Manually clear per-minute limit again + limiter.mu.Lock() + limiter.openedLastMinute = []time.Time{} + limiter.mu.Unlock() + + // Third PR should fail per-day limit + if 
limiter.CanOpen(startTime, "https://github.com/owner/repo/pull/3") { + t.Error("Third CanOpen should fail per-day limit") + } +} + +func TestBrowserRateLimiter_RecordOpen(t *testing.T) { + limiter := NewBrowserRateLimiter(1*time.Minute, 10, 100) + + prURL := "https://github.com/owner/repo/pull/1" + + // Record an open + limiter.RecordOpen(prURL) + + // Verify it's tracked in openedPRs + limiter.mu.Lock() + if !limiter.openedPRs[prURL] { + t.Error("PR should be tracked in openedPRs") + } + + // Verify it's tracked in time windows + if len(limiter.openedLastMinute) != 1 { + t.Errorf("openedLastMinute = %d, want 1", len(limiter.openedLastMinute)) + } + + if len(limiter.openedToday) != 1 { + t.Errorf("openedToday = %d, want 1", len(limiter.openedToday)) + } + limiter.mu.Unlock() +} + +func TestBrowserRateLimiter_CleanOldEntries(t *testing.T) { + limiter := NewBrowserRateLimiter(1*time.Minute, 10, 100) + + now := time.Now() + + // Add entries at different times + limiter.mu.Lock() + limiter.openedLastMinute = []time.Time{ + now.Add(-2 * time.Minute), // Should be cleaned (>1 minute ago) + now.Add(-30 * time.Second), // Should remain (<1 minute ago) + } + limiter.openedToday = []time.Time{ + now.Add(-25 * time.Hour), // Should be cleaned (>24 hours ago) + now.Add(-1 * time.Hour), // Should remain (<24 hours ago) + } + limiter.mu.Unlock() + + // Clean entries + limiter.mu.Lock() + limiter.cleanOldEntries(now) + + // Check per-minute entries + if len(limiter.openedLastMinute) != 1 { + t.Errorf("openedLastMinute after clean = %d, want 1", len(limiter.openedLastMinute)) + } + + // Check per-day entries + if len(limiter.openedToday) != 1 { + t.Errorf("openedToday after clean = %d, want 1", len(limiter.openedToday)) + } + limiter.mu.Unlock() +} + +func TestBrowserRateLimiter_Reset(t *testing.T) { + limiter := NewBrowserRateLimiter(1*time.Minute, 10, 100) + + // Add some opened PRs + limiter.RecordOpen("https://github.com/owner/repo/pull/1") + 
limiter.RecordOpen("https://github.com/owner/repo/pull/2") + limiter.RecordOpen("https://github.com/owner/repo/pull/3") + + // Verify they're tracked + limiter.mu.Lock() + if len(limiter.openedPRs) != 3 { + t.Errorf("openedPRs before reset = %d, want 3", len(limiter.openedPRs)) + } + limiter.mu.Unlock() + + // Reset + limiter.Reset() + + // Verify they're cleared + limiter.mu.Lock() + if len(limiter.openedPRs) != 0 { + t.Errorf("openedPRs after reset = %d, want 0", len(limiter.openedPRs)) + } + limiter.mu.Unlock() + + // Time window entries should still exist (reset only clears PR tracking) + limiter.mu.Lock() + if len(limiter.openedLastMinute) == 0 { + t.Error("openedLastMinute should not be cleared by Reset") + } + if len(limiter.openedToday) == 0 { + t.Error("openedToday should not be cleared by Reset") + } + limiter.mu.Unlock() +} + +func TestBrowserRateLimiter_Concurrent(t *testing.T) { + startTime := time.Now().Add(-2 * time.Minute) // Past startup delay + limiter := NewBrowserRateLimiter(1*time.Minute, 100, 1000) + + // Test concurrent access + done := make(chan bool) + for i := range 10 { + go func(id int) { + prURL := "https://github.com/owner/repo/pull/" + string(rune('1'+id)) + if limiter.CanOpen(startTime, prURL) { + limiter.RecordOpen(prURL) + } + done <- true + }(i) + } + + // Wait for all goroutines + for range 10 { + <-done + } + + // Verify no race conditions (test runs with -race flag will catch issues) + limiter.mu.Lock() + totalOpened := len(limiter.openedPRs) + limiter.mu.Unlock() + + if totalOpened == 0 { + t.Error("No PRs were opened") + } +} diff --git a/pkg/safebrowse/safebrowse_test.go b/pkg/safebrowse/safebrowse_test.go index 848a517..713c502 100644 --- a/pkg/safebrowse/safebrowse_test.go +++ b/pkg/safebrowse/safebrowse_test.go @@ -524,3 +524,205 @@ func TestValidateParamString(t *testing.T) { }) } } + +func TestOpen(t *testing.T) { + tests := []struct { + name string + url string + wantErr bool + }{ + { + name: "valid URL", + url: 
"https://github.com/owner/repo", + wantErr: false, // validation passes, browser open may fail in test + }, + { + name: "invalid URL - HTTP", + url: "http://github.com/owner/repo", + wantErr: true, // validation fails + }, + { + name: "invalid URL - control char", + url: "https://github.com/owner\n/repo", + wantErr: true, // validation fails + }, + { + name: "invalid URL - query params", + url: "https://github.com/owner/repo?foo=bar", + wantErr: true, // validation fails (params not allowed in Open) + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + err := Open(ctx, tt.url) + + // If wantErr is true, we expect validation to fail + // If wantErr is false, validation passes (browser open may fail, which is OK) + if tt.wantErr && err == nil { + t.Errorf("Open() expected error but got none") + } + }) + } +} + +func TestOpenBrowser_InvalidCommand(t *testing.T) { + // Test that openBrowser handles context cancellation + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Cancel immediately + + err := openBrowser(ctx, "https://github.com/owner/repo") + // We expect an error because context is cancelled + // The exact error depends on timing and platform + if err == nil { + // It's OK if err is nil in some cases due to Start() not blocking + t.Log("openBrowser with cancelled context returned nil (Start() doesn't block)") + } +} + +func TestOpenWithParams_PercentEncoding(t *testing.T) { + // Test that OpenWithParams rejects URLs that produce percent encoding + ctx := context.Background() + + // Valid base URL but param value that would need encoding + // The current implementation actually encodes and then rejects if % is present + // Let's verify this behavior + err := OpenWithParams(ctx, "https://github.com/owner/repo", map[string]string{ + "key": "value with space", // spaces would require encoding + }) + + // This should fail during validation of the parameter value + if err == nil { + 
t.Error("OpenWithParams() should reject parameter value with space") + } +} + +func TestValidate_EdgeCases(t *testing.T) { + tests := []struct { + name string + url string + allowParams bool + wantErr bool + }{ + { + name: "valid URL with params when allowed", + url: "https://github.com/owner/repo?key=value", + allowParams: true, + wantErr: false, + }, + { + name: "valid URL with params when not allowed", + url: "https://github.com/owner/repo?key=value", + allowParams: false, + wantErr: true, + }, + { + name: "URL with non-ASCII character", + url: "https://github.com/owner/repö", + allowParams: false, + wantErr: true, + }, + { + name: "URL with DEL character (0x7F)", + url: "https://github.com/owner/repo\x7F", + allowParams: false, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validate(tt.url, tt.allowParams) + if (err != nil) != tt.wantErr { + t.Errorf("validate() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestValidateGitHubPRURL_EdgeCases(t *testing.T) { + tests := []struct { + name string + url string + wantErr bool + }{ + { + name: "valid minimal PR", + url: "https://github.com/a/b/pull/1", + wantErr: false, + }, + { + name: "valid with underscore in repo", + url: "https://github.com/owner/repo_name/pull/123", + wantErr: false, + }, + { + name: "missing pull segment", + url: "https://github.com/owner/repo/123", + wantErr: true, + }, + { + name: "too few path segments", + url: "https://github.com/owner/pull/123", + wantErr: true, + }, + { + name: "PR number empty", + url: "https://github.com/owner/repo/pull/", + wantErr: true, + }, + { + name: "PR number has letters", + url: "https://github.com/owner/repo/pull/12a", + wantErr: true, + }, + { + name: "goose param with multiple values", + url: "https://github.com/owner/repo/pull/123?goose=1&goose=2", + wantErr: true, + }, + { + name: "query param without goose prefix", + url: "https://github.com/owner/repo/pull/123?other=value", + 
wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateGitHubPRURL(tt.url) + if (err != nil) != tt.wantErr { + t.Errorf("ValidateGitHubPRURL(%q) error = %v, wantErr %v", tt.url, err, tt.wantErr) + } + }) + } +} + +func TestOpenWithParams_EmptyParams(t *testing.T) { + ctx := context.Background() + + // Test with empty params map + err := OpenWithParams(ctx, "https://github.com/owner/repo", map[string]string{}) + // Validation should pass, browser open may fail (which is OK for test) + // We're just checking that empty params don't cause a panic + if err != nil && strings.Contains(err.Error(), "panic") { + t.Error("OpenWithParams with empty params should not panic") + } +} + +func TestOpenWithParams_MultipleValidParams(t *testing.T) { + ctx := context.Background() + + // Test with multiple valid params + err := OpenWithParams(ctx, "https://github.com/owner/repo", map[string]string{ + "goose": "review", + "source": "tray", + }) + // The function will encode params and then check for % + // Since the values don't need encoding, it should pass validation + // Browser open may fail (which is OK for test) + if err != nil && strings.Contains(err.Error(), "invalid parameter") { + t.Errorf("OpenWithParams with valid params should not fail validation: %v", err) + } +}