diff --git a/pkg/completion/completion.go b/pkg/completion/completion.go index 43d407e4..e8265046 100644 --- a/pkg/completion/completion.go +++ b/pkg/completion/completion.go @@ -13,6 +13,25 @@ import ( "github.com/Jguer/yay/v12/pkg/text" ) +// NeedsUpdate checks if the completion cache needs to be regenerated. +// Returns true if the file doesn't exist, is older than interval days, or force is true. +func NeedsUpdate(completionPath string, interval int, force bool) bool { + if force { + return true + } + + info, err := os.Stat(completionPath) + if os.IsNotExist(err) { + return true + } + + if interval != -1 && time.Since(info.ModTime()).Hours() >= float64(interval*24) { + return true + } + + return false +} + type PkgSynchronizer interface { SyncPackages(...string) []db.IPackage } @@ -21,9 +40,10 @@ type PkgSynchronizer interface { func Show(ctx context.Context, httpClient download.HTTPRequestDoer, dbExecutor PkgSynchronizer, aurURL, completionPath string, interval int, force bool, logger *text.Logger, ) error { - err := Update(ctx, httpClient, dbExecutor, aurURL, completionPath, interval, force, logger) - if err != nil { - return err + if NeedsUpdate(completionPath, interval, force) { + if err := UpdateCache(ctx, httpClient, dbExecutor, aurURL, completionPath, logger); err != nil { + return err + } } in, err := os.OpenFile(completionPath, os.O_RDWR|os.O_CREATE, 0o644) @@ -37,35 +57,26 @@ func Show(ctx context.Context, httpClient download.HTTPRequestDoer, return err } -// Update updates completion cache to be used by Complete. -func Update(ctx context.Context, httpClient download.HTTPRequestDoer, - dbExecutor PkgSynchronizer, aurURL, completionPath string, interval int, force bool, logger *text.Logger, +// UpdateCache regenerates the completion cache file unconditionally. +func UpdateCache(ctx context.Context, httpClient download.HTTPRequestDoer, + dbExecutor PkgSynchronizer, aurURL, completionPath string, logger *text.Logger, ) error { - info, err := os.Stat(completionPath) - - if os.IsNotExist(err) || (interval != -1 && time.Since(info.ModTime()).Hours() >= float64(interval*24)) || force { - errd := os.MkdirAll(filepath.Dir(completionPath), 0o755) - if errd != nil { - return errd - } - - out, errf := os.Create(completionPath) - if errf != nil { - return errf - } - - if createAURList(ctx, httpClient, aurURL, out, logger) != nil { - defer os.Remove(completionPath) - } - - erra := createRepoList(dbExecutor, out) - - out.Close() - - return erra + if err := os.MkdirAll(filepath.Dir(completionPath), 0o755); err != nil { + return err } - return nil + out, err := os.Create(completionPath) + if err != nil { + return err + } + defer out.Close() + + if err := createAURList(ctx, httpClient, aurURL, out, logger); err != nil { + os.Remove(completionPath) + return err + } + + return createRepoList(dbExecutor, out) } // createAURList creates a new completion file. diff --git a/pkg/completion/completion_test.go b/pkg/completion/completion_test.go index 0b2507c7..acb85f07 100644 --- a/pkg/completion/completion_test.go +++ b/pkg/completion/completion_test.go @@ -10,9 +10,16 @@ import ( "errors" "io" "net/http" + "os" + "path/filepath" "testing" + "time" + + "github.com/Jguer/yay/v12/pkg/db" + "github.com/Jguer/yay/v12/pkg/db/mock" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) const samplePackageResp = ` @@ -122,3 +129,424 @@ func Test_createAURListStatusError(t *testing.T) { err := createAURList(context.Background(), doer, "https://aur.archlinux.org", out, nil) assert.EqualError(t, err, "invalid status code: 503") } + +func TestNeedsUpdate(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + setupFile func(t *testing.T, path string) + interval int + force bool + expectedResult bool + }{ + { + name: "force returns true", + setupFile: nil, + interval: 7, + force: true, + expectedResult: true, + }, + { + name: "file does not exist returns true", + setupFile: nil, + interval: 7, + force: false, + expectedResult: true, + }, + { + name: "fresh file returns false", + setupFile: func(t *testing.T, path string) { + t.Helper() + err := os.WriteFile(path, []byte("test"), 0o600) + require.NoError(t, err) + }, + interval: 7, + force: false, + expectedResult: false, + }, + { + name: "interval -1 never updates", + setupFile: func(t *testing.T, path string) { + t.Helper() + err := os.WriteFile(path, []byte("test"), 0o600) + require.NoError(t, err) + // Set file time to 30 days ago + oldTime := time.Now().Add(-30 * 24 * time.Hour) + err = os.Chtimes(path, oldTime, oldTime) + require.NoError(t, err) + }, + interval: -1, + force: false, + expectedResult: false, + }, + { + name: "old file returns true", + setupFile: func(t *testing.T, path string) { + t.Helper() + err := os.WriteFile(path, []byte("test"), 0o600) + require.NoError(t, err) + // Set file time to 10 days ago + oldTime := time.Now().Add(-10 * 24 * time.Hour) + err = os.Chtimes(path, oldTime, oldTime) + require.NoError(t, err) + }, + interval: 7, + force: false, + expectedResult: true, + }, + { + name: "file within interval returns false", + setupFile: func(t *testing.T, path string) { + t.Helper() + err := os.WriteFile(path, []byte("test"), 0o600) + require.NoError(t, err) + // Set file time to 3 days ago + oldTime := time.Now().Add(-3 * 24 * time.Hour) + err = os.Chtimes(path, oldTime, oldTime) + require.NoError(t, err) + }, + interval: 7, + force: false, + expectedResult: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + completionPath := filepath.Join(tmpDir, "completion") + + if tt.setupFile != nil { + tt.setupFile(t, completionPath) + } + + result := NeedsUpdate(completionPath, tt.interval, tt.force) + assert.Equal(t, tt.expectedResult, result) + }) + } +} + +// mockPkgSynchronizer implements PkgSynchronizer for testing. +type mockPkgSynchronizer struct { + packages []db.IPackage +} + +func (m *mockPkgSynchronizer) SyncPackages(...string) []db.IPackage { + return m.packages +} + +func Test_createRepoList(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + packages []db.IPackage + expectedOutput string + expectedError error + }{ + { + name: "empty package list", + packages: []db.IPackage{}, + expectedOutput: "", + expectedError: nil, + }, + { + name: "single package", + packages: []db.IPackage{ + &mock.Package{PName: "vim", PDB: mock.NewDB("extra")}, + }, + expectedOutput: "vim\textra\n", + expectedError: nil, + }, + { + name: "multiple packages", + packages: []db.IPackage{ + &mock.Package{PName: "vim", PDB: mock.NewDB("extra")}, + &mock.Package{PName: "git", PDB: mock.NewDB("extra")}, + &mock.Package{PName: "linux", PDB: mock.NewDB("core")}, + }, + expectedOutput: "vim\textra\ngit\textra\nlinux\tcore\n", + expectedError: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + dbExecutor := &mockPkgSynchronizer{packages: tt.packages} + out := &bytes.Buffer{} + + err := createRepoList(dbExecutor, out) + + if tt.expectedError != nil { + assert.EqualError(t, err, tt.expectedError.Error()) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tt.expectedOutput, out.String()) + }) + } +} + +// errorWriter is a writer that always returns an error. +type errorWriter struct{} + +func (e *errorWriter) Write(p []byte) (n int, err error) { + return 0, errors.New("write error") +} + +func Test_createRepoListWriteError(t *testing.T) { + t.Parallel() + + dbExecutor := &mockPkgSynchronizer{ + packages: []db.IPackage{ + &mock.Package{PName: "vim", PDB: mock.NewDB("extra")}, + }, + } + + err := createRepoList(dbExecutor, &errorWriter{}) + assert.EqualError(t, err, "write error") +} + +func TestUpdateCache(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + doer *mockDoer + packages []db.IPackage + expectedOutput string + expectError bool + }{ + { + name: "successful update", + doer: &mockDoer{ + returnStatusCode: 200, + returnBody: []byte("# Comment\npkg1\npkg2\n"), + returnErr: nil, + }, + packages: []db.IPackage{ + &mock.Package{PName: "vim", PDB: mock.NewDB("extra")}, + }, + expectedOutput: "pkg1\tAUR\npkg2\tAUR\nvim\textra\n", + expectError: false, + }, + { + name: "AUR fetch error removes file", + doer: &mockDoer{ + returnStatusCode: 500, + returnBody: []byte{}, + returnErr: nil, + }, + packages: []db.IPackage{}, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + completionPath := filepath.Join(tmpDir, "subdir", "completion") + tt.doer.t = t + tt.doer.wantURL = "https://aur.archlinux.org/packages.gz" + + dbExecutor := &mockPkgSynchronizer{packages: tt.packages} + + err := UpdateCache(context.Background(), tt.doer, dbExecutor, "https://aur.archlinux.org", completionPath, nil) + + if tt.expectError { + assert.Error(t, err) + // File should be removed on error + _, statErr := os.Stat(completionPath) + assert.True(t, os.IsNotExist(statErr)) + } else { + require.NoError(t, err) + content, readErr := os.ReadFile(completionPath) + require.NoError(t, readErr) + assert.Equal(t, tt.expectedOutput, string(content)) + } + }) + } +} + +func TestShow(t *testing.T) { + // Note: Not running in parallel because we need to capture os.Stdout + tests := []struct { + name string + setupFile func(t *testing.T, path string) + doer *mockDoer + packages []db.IPackage + interval int + force bool + expectError bool + }{ + { + name: "existing fresh file", + setupFile: func(t *testing.T, path string) { + t.Helper() + err := os.WriteFile(path, []byte("cached\tdata\n"), 0o600) + require.NoError(t, err) + }, + doer: nil, // Should not be called + packages: nil, + interval: 7, + force: false, + expectError: false, + }, + { + name: "file needs update", + setupFile: nil, + doer: &mockDoer{ + returnStatusCode: 200, + returnBody: []byte("# Comment\naur-pkg\n"), + returnErr: nil, + }, + packages: []db.IPackage{ + &mock.Package{PName: "repo-pkg", PDB: mock.NewDB("core")}, + }, + interval: 7, + force: false, + expectError: false, + }, + { + name: "force update", + setupFile: nil, + doer: &mockDoer{ + returnStatusCode: 200, + returnBody: []byte("# Comment\nforced-pkg\n"), + returnErr: nil, + }, + packages: []db.IPackage{}, + interval: 7, + force: true, + expectError: false, + }, + { + name: "update cache error", + setupFile: nil, + doer: &mockDoer{ + returnStatusCode: 500, + returnBody: []byte{}, + returnErr: nil, + }, + packages: []db.IPackage{}, + interval: 7, + force: false, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Not running in parallel because we capture os.Stdout + tmpDir := t.TempDir() + completionPath := filepath.Join(tmpDir, "completion") + + if tt.setupFile != nil { + tt.setupFile(t, completionPath) + } + + if tt.doer != nil { + tt.doer.t = t + tt.doer.wantURL = "https://aur.archlinux.org/packages.gz" + } + + dbExecutor := &mockPkgSynchronizer{packages: tt.packages} + + // Capture stdout using a pipe + oldStdout := os.Stdout + r, w, pipeErr := os.Pipe() + require.NoError(t, pipeErr) + os.Stdout = w + + err := Show(context.Background(), tt.doer, dbExecutor, "https://aur.archlinux.org", completionPath, tt.interval, tt.force, nil) + + // Close writer first, then restore stdout, then read + w.Close() + os.Stdout = oldStdout + + var buf bytes.Buffer + _, copyErr := io.Copy(&buf, r) + r.Close() + + if tt.expectError { + assert.Error(t, err) + } else { + require.NoError(t, err) + require.NoError(t, copyErr) + // Verify file exists and has content + content, readErr := os.ReadFile(completionPath) + require.NoError(t, readErr) + assert.NotEmpty(t, content) + } + }) + } +} + +func TestShowFileOpenError(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + // Use a path that can't be created (directory as file) + completionPath := filepath.Join(tmpDir, "completion") + + // Create a directory where we expect a file - this will cause OpenFile to fail + err := os.MkdirAll(completionPath, 0o755) + require.NoError(t, err) + + doer := &mockDoer{ + t: t, + wantURL: "https://aur.archlinux.org/packages.gz", + returnStatusCode: 200, + returnBody: []byte("# Comment\npkg\n"), + returnErr: nil, + } + + dbExecutor := &mockPkgSynchronizer{packages: []db.IPackage{}} + + err = Show(context.Background(), doer, dbExecutor, "https://aur.archlinux.org", completionPath, 7, true, nil) + assert.Error(t, err) +} + +func TestUpdateCacheMkdirError(t *testing.T) { + t.Parallel() + + // Create a file where we expect a directory - this will cause MkdirAll to fail + tmpDir := t.TempDir() + blockingFile := filepath.Join(tmpDir, "blocking") + err := os.WriteFile(blockingFile, []byte("block"), 0o600) + require.NoError(t, err) + + completionPath := filepath.Join(blockingFile, "subdir", "completion") + + doer := &mockDoer{ + t: t, + wantURL: "https://aur.archlinux.org/packages.gz", + returnStatusCode: 200, + returnBody: []byte("# Comment\npkg\n"), + returnErr: nil, + } + + dbExecutor := &mockPkgSynchronizer{packages: []db.IPackage{}} + + err = UpdateCache(context.Background(), doer, dbExecutor, "https://aur.archlinux.org", completionPath, nil) + assert.Error(t, err) +} + +func Test_createAURListWriteError(t *testing.T) { + t.Parallel() + + doer := &mockDoer{ + t: t, + wantURL: "https://aur.archlinux.org/packages.gz", + returnStatusCode: 200, + returnBody: []byte("# Comment\npkg1\npkg2\n"), + returnErr: nil, + } + + err := createAURList(context.Background(), doer, "https://aur.archlinux.org", &errorWriter{}, nil) + assert.EqualError(t, err, "write error") +} diff --git a/pkg/sync/sync.go b/pkg/sync/sync.go index 374df975..d3ca2f7e 100644 --- a/pkg/sync/sync.go +++ b/pkg/sync/sync.go @@ -63,13 +63,15 @@ func (o *OperationService) Run(ctx context.Context, run *runtime.Runtime, installer.AddPostInstallHook(cleanAURDirsFunc) } - go func() { - errComp := completion.Update(ctx, run.HTTPClient, o.dbExecutor, - o.cfg.AURURL, o.cfg.CompletionPath, o.cfg.CompletionInterval, false, o.logger) - if errComp != nil { - o.logger.Warnln(errComp) - } - }() + if completion.NeedsUpdate(o.cfg.CompletionPath, o.cfg.CompletionInterval, false) { + go func() { + errComp := completion.UpdateCache(ctx, run.HTTPClient, o.dbExecutor, + o.cfg.AURURL, o.cfg.CompletionPath, o.logger) + if errComp != nil { + o.logger.Warnln(errComp) + } + }() + } srcInfo, errInstall := srcinfo.NewService(o.dbExecutor, o.cfg, o.logger.Child("srcinfo"), run.CmdBuilder, run.VCSStore, pkgBuildDirs)