diff --git a/cmd.go b/cmd.go index a4842f01..3149e1e0 100644 --- a/cmd.go +++ b/cmd.go @@ -1,7 +1,6 @@ package main import ( - "bufio" "context" "errors" "fmt" @@ -269,7 +268,7 @@ func handlePrint(ctx context.Context, run *runtime.Runtime, cmdArgs *parser.Argu dbExecutor.LastBuildTime(), run.Cfg.BottomUp, double, quiet) case cmdArgs.ExistsArg("c", "complete"): return completion.Show(ctx, run.HTTPClient, dbExecutor, - run.Cfg.AURURL, run.Cfg.CompletionPath, run.Cfg.CompletionInterval, cmdArgs.ExistsDouble("c", "complete")) + run.Cfg.AURURL, run.Cfg.CompletionPath, run.Cfg.CompletionInterval, cmdArgs.ExistsDouble("c", "complete"), run.Logger) case cmdArgs.ExistsArg("s", "stats"): return localStatistics(ctx, run, dbExecutor) } @@ -427,18 +426,11 @@ func syncList(ctx context.Context, run *runtime.Runtime, } if run.Cfg.Mode.AtLeastAUR() && (len(cmdArgs.Targets) == 0 || aur) { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, run.Cfg.AURURL+"/packages.gz", http.NoBody) + scanner, err := download.GetPackageScanner(ctx, httpClient, run.Cfg.AURURL, run.Logger) if err != nil { return err } - - resp, err := httpClient.Do(req) - if err != nil { - return err - } - defer resp.Body.Close() - - scanner := bufio.NewScanner(resp.Body) + defer scanner.Close() scanner.Scan() diff --git a/pkg/completion/completion.go b/pkg/completion/completion.go index 640fdf0e..43d407e4 100644 --- a/pkg/completion/completion.go +++ b/pkg/completion/completion.go @@ -1,34 +1,27 @@ package completion import ( - "bufio" "context" - "fmt" "io" - "net/http" - "net/url" "os" - "path" "path/filepath" "strings" "time" "github.com/Jguer/yay/v12/pkg/db" + "github.com/Jguer/yay/v12/pkg/download" + "github.com/Jguer/yay/v12/pkg/text" ) type PkgSynchronizer interface { SyncPackages(...string) []db.IPackage } -type httpRequestDoer interface { - Do(req *http.Request) (*http.Response, error) -} - // Show provides completion info for shells. -func Show(ctx context.Context, httpClient httpRequestDoer, - dbExecutor PkgSynchronizer, aurURL, completionPath string, interval int, force bool, +func Show(ctx context.Context, httpClient download.HTTPRequestDoer, + dbExecutor PkgSynchronizer, aurURL, completionPath string, interval int, force bool, logger *text.Logger, ) error { - err := Update(ctx, httpClient, dbExecutor, aurURL, completionPath, interval, force) + err := Update(ctx, httpClient, dbExecutor, aurURL, completionPath, interval, force, logger) if err != nil { return err } @@ -45,8 +38,8 @@ func Show(ctx context.Context, httpClient httpRequestDoer, } // Update updates completion cache to be used by Complete. -func Update(ctx context.Context, httpClient httpRequestDoer, - dbExecutor PkgSynchronizer, aurURL, completionPath string, interval int, force bool, +func Update(ctx context.Context, httpClient download.HTTPRequestDoer, + dbExecutor PkgSynchronizer, aurURL, completionPath string, interval int, force bool, logger *text.Logger, ) error { info, err := os.Stat(completionPath) @@ -61,7 +54,7 @@ func Update(ctx context.Context, httpClient httpRequestDoer, return errf } - if createAURList(ctx, httpClient, aurURL, out) != nil { + if createAURList(ctx, httpClient, aurURL, out, logger) != nil { defer os.Remove(completionPath) } @@ -75,41 +68,23 @@ func Update(ctx context.Context, httpClient httpRequestDoer, return nil } -// CreateAURList creates a new completion file. -func createAURList(ctx context.Context, client httpRequestDoer, aurURL string, out io.Writer) error { - u, err := url.Parse(aurURL) +// createAURList creates a new completion file. +func createAURList(ctx context.Context, client download.HTTPRequestDoer, aurURL string, out io.Writer, logger *text.Logger) error { + scanner, err := download.GetPackageScanner(ctx, client, aurURL, logger) if err != nil { return err } - - u.Path = path.Join(u.Path, "packages.gz") - - req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), http.NoBody) - if err != nil { - return err - } - - resp, err := client.Do(req) - if err != nil { - return err - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("invalid status code: %d", resp.StatusCode) - } - - scanner := bufio.NewScanner(resp.Body) + defer scanner.Close() scanner.Scan() for scanner.Scan() { - text := scanner.Text() - if strings.HasPrefix(text, "#") { + pkgName := scanner.Text() + if strings.HasPrefix(pkgName, "#") { continue } - if _, err := io.WriteString(out, text+"\tAUR\n"); err != nil { + if _, err := io.WriteString(out, pkgName+"\tAUR\n"); err != nil { return err } } diff --git a/pkg/completion/completion_test.go b/pkg/completion/completion_test.go index 5af0beac..0b2507c7 100644 --- a/pkg/completion/completion_test.go +++ b/pkg/completion/completion_test.go @@ -5,6 +5,7 @@ package completion import ( "bytes" + "compress/gzip" "context" "errors" "io" @@ -38,31 +39,55 @@ eternallands-sound AUR type mockDoer struct { t *testing.T - returnBody string + returnBody []byte returnStatusCode int returnErr error - wantUrl string + wantURL string } -func (m *mockDoer) Do(req *http.Request) (*http.Response, error) { - assert.Equal(m.t, m.wantUrl, req.URL.String()) +func (m *mockDoer) Get(url string) (*http.Response, error) { + assert.Equal(m.t, m.wantURL, url) return &http.Response{ StatusCode: m.returnStatusCode, - Body: io.NopCloser(bytes.NewBufferString(m.returnBody)), + Body: io.NopCloser(bytes.NewReader(m.returnBody)), }, m.returnErr } +func gzipString(s string) []byte { + var buf bytes.Buffer + gz := gzip.NewWriter(&buf) + gz.Write([]byte(s)) + gz.Close() + return buf.Bytes() +} + func Test_createAURList(t *testing.T) { t.Parallel() doer := &mockDoer{ t: t, - wantUrl: "https://aur.archlinux.org/packages.gz", + wantURL: "https://aur.archlinux.org/packages.gz", returnStatusCode: 200, - returnBody: samplePackageResp, + returnBody: []byte(samplePackageResp), returnErr: nil, } out := &bytes.Buffer{} - err := createAURList(context.Background(), doer, "https://aur.archlinux.org", out) + err := createAURList(context.Background(), doer, "https://aur.archlinux.org", out, nil) + assert.NoError(t, err) + gotOut := out.String() + assert.Equal(t, expectPackageCompletion, gotOut) +} + +func Test_createAURListGzip(t *testing.T) { + t.Parallel() + doer := &mockDoer{ + t: t, + wantURL: "https://aur.archlinux.org/packages.gz", + returnStatusCode: 200, + returnBody: gzipString(samplePackageResp), + returnErr: nil, + } + out := &bytes.Buffer{} + err := createAURList(context.Background(), doer, "https://aur.archlinux.org", out, nil) assert.NoError(t, err) gotOut := out.String() assert.Equal(t, expectPackageCompletion, gotOut) @@ -72,14 +97,14 @@ func Test_createAURListHTTPError(t *testing.T) { t.Parallel() doer := &mockDoer{ t: t, - wantUrl: "https://aur.archlinux.org/packages.gz", + wantURL: "https://aur.archlinux.org/packages.gz", returnStatusCode: 200, - returnBody: samplePackageResp, + returnBody: []byte(samplePackageResp), returnErr: errors.New("Not available"), } out := &bytes.Buffer{} - err := createAURList(context.Background(), doer, "https://aur.archlinux.org", out) + err := createAURList(context.Background(), doer, "https://aur.archlinux.org", out, nil) assert.EqualError(t, err, "Not available") } @@ -87,13 +112,13 @@ func Test_createAURListStatusError(t *testing.T) { t.Parallel() doer := &mockDoer{ t: t, - wantUrl: "https://aur.archlinux.org/packages.gz", + wantURL: "https://aur.archlinux.org/packages.gz", returnStatusCode: 503, - returnBody: samplePackageResp, + returnBody: []byte(samplePackageResp), returnErr: nil, } out := &bytes.Buffer{} - err := createAURList(context.Background(), doer, "https://aur.archlinux.org", out) + err := createAURList(context.Background(), doer, "https://aur.archlinux.org", out, nil) assert.EqualError(t, err, "invalid status code: 503") } diff --git a/pkg/download/abs.go b/pkg/download/abs.go index 7ca16338..db773d89 100644 --- a/pkg/download/abs.go +++ b/pkg/download/abs.go @@ -59,7 +59,7 @@ func convertPkgNameForURL(pkgName string) string { } // ABSPKGBUILD retrieves the PKGBUILD file to a dest directory. -func ABSPKGBUILD(httpClient httpRequestDoer, dbName, pkgName string) ([]byte, error) { +func ABSPKGBUILD(httpClient HTTPRequestDoer, dbName, pkgName string) ([]byte, error) { packageURL := getPackagePKGBUILDURL(pkgName) resp, err := httpClient.Get(packageURL) diff --git a/pkg/download/aur.go b/pkg/download/aur.go index 874cb6c5..648a5e41 100644 --- a/pkg/download/aur.go +++ b/pkg/download/aur.go @@ -1,11 +1,15 @@ package download import ( + "bufio" + "bytes" + "compress/gzip" "context" "fmt" "io" "net/http" "net/url" + "path" "sync" "github.com/leonelquinteros/gotext" @@ -15,7 +19,7 @@ import ( "github.com/Jguer/yay/v12/pkg/text" ) -func AURPKGBUILD(httpClient httpRequestDoer, pkgName, aurURL string) ([]byte, error) { +func AURPKGBUILD(httpClient HTTPRequestDoer, pkgName, aurURL string) ([]byte, error) { values := url.Values{} values.Set("h", pkgName) pkgURL := aurURL + "/cgit/aur.git/plain/PKGBUILD?" + values.Encode() @@ -98,3 +102,69 @@ func AURPKGBUILDRepos( return cloned, errs.Return() } + +// ScannerCloser combines a bufio.Scanner with a Close method. +type ScannerCloser struct { + *bufio.Scanner + closer io.Closer +} + +// Close closes the underlying gzip reader if present. +func (s *ScannerCloser) Close() error { + if s.closer != nil { + return s.closer.Close() + } + return nil +} + +// GetPackageScanner fetches the AUR packages.gz file and returns a scanner for reading its contents. +// The caller must call Close() on the returned ScannerCloser when done to properly release resources. +func GetPackageScanner(ctx context.Context, client HTTPRequestDoer, aurURL string, logger *text.Logger) (*ScannerCloser, error) { + u, err := url.Parse(aurURL) + if err != nil { + return nil, err + } + + u.Path = path.Join(u.Path, "packages.gz") + packagesURL := u.String() + + resp, err := client.Get(packagesURL) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + resp.Body.Close() + return nil, fmt.Errorf("invalid status code: %d", resp.StatusCode) + } + + // Read the entire body to allow trying gzip decompression + body, err := io.ReadAll(resp.Body) + resp.Body.Close() + + if err != nil { + return nil, err + } + + // Try to decompress as gzip; if that fails, use raw body + var reader io.Reader + var closer io.Closer + + gzReader, gzErr := gzip.NewReader(bytes.NewReader(body)) + if gzErr == nil { + reader = gzReader + closer = gzReader + } else { + if logger != nil { + logger.Debugln("gzip decompression not needed, using raw response body") + } + reader = bytes.NewReader(body) + } + + scanner := bufio.NewScanner(reader) + + return &ScannerCloser{ + Scanner: scanner, + closer: closer, + }, nil +} diff --git a/pkg/download/unified.go b/pkg/download/unified.go index 3433b855..86fce756 100644 --- a/pkg/download/unified.go +++ b/pkg/download/unified.go @@ -18,7 +18,8 @@ import ( "github.com/Jguer/yay/v12/pkg/text" ) -type httpRequestDoer interface { +// HTTPRequestDoer is an interface for HTTP clients that can perform GET requests. +type HTTPRequestDoer interface { Get(string) (*http.Response, error) } diff --git a/pkg/sync/sync.go b/pkg/sync/sync.go index d5edcde4..374df975 100644 --- a/pkg/sync/sync.go +++ b/pkg/sync/sync.go @@ -65,7 +65,7 @@ func (o *OperationService) Run(ctx context.Context, run *runtime.Runtime, go func() { errComp := completion.Update(ctx, run.HTTPClient, o.dbExecutor, - o.cfg.AURURL, o.cfg.CompletionPath, o.cfg.CompletionInterval, false) + o.cfg.AURURL, o.cfg.CompletionPath, o.cfg.CompletionInterval, false, o.logger) if errComp != nil { o.logger.Warnln(errComp) }