From c83ee062e4e2291f6eacc824050fa1c320031d40 Mon Sep 17 00:00:00 2001 From: Xpl0itU <24777100+Xpl0itU@users.noreply.github.com> Date: Wed, 10 Apr 2024 16:30:15 +0200 Subject: [PATCH] Rewrite download logic --- certificate.go | 8 +-- cmd/WiiUDownloader/mainwindow.go | 2 +- downloader.go | 108 ++++++++----------------------- writerProgress.go | 39 +++++++++++ 4 files changed, 72 insertions(+), 85 deletions(-) create mode 100644 writerProgress.go diff --git a/certificate.go b/certificate.go index 205a2e2..f682acd 100644 --- a/certificate.go +++ b/certificate.go @@ -28,12 +28,12 @@ func getCert(tmdData []byte, id int, numContents uint16) ([]byte, error) { } } -func getDefaultCert(cancelCtx context.Context, progressReporter ProgressReporter, client *http.Client, buffer []byte) ([]byte, error) { +func getDefaultCert(cancelCtx context.Context, progressReporter ProgressReporter, client *http.Client) ([]byte, error) { if len(cetkData) >= 0x350+0x300 { return cetkData[0x350 : 0x350+0x300], nil } cetkDir := path.Join(os.TempDir(), "cetk") - if err := downloadFile(cancelCtx, progressReporter, client, "http://ccs.cdn.c.shop.nintendowifi.net/ccs/download/000500101000400a/cetk", cetkDir, true, buffer); err != nil { + if err := downloadFile(cancelCtx, progressReporter, client, "http://ccs.cdn.c.shop.nintendowifi.net/ccs/download/000500101000400a/cetk", cetkDir, true); err != nil { return nil, err } cetkData, err := os.ReadFile(cetkDir) @@ -51,7 +51,7 @@ func getDefaultCert(cancelCtx context.Context, progressReporter ProgressReporter return nil, fmt.Errorf("failed to download OSv10 cetk, length: %d", len(cetkData)) } -func GenerateCert(tmdData []byte, contentCount uint16, progressReporter ProgressReporter, client *http.Client, cancelCtx context.Context, buffer []byte) (bytes.Buffer, error) { +func GenerateCert(tmdData []byte, contentCount uint16, progressReporter ProgressReporter, client *http.Client, cancelCtx context.Context) (bytes.Buffer, error) { cert := bytes.Buffer{} cert0, err := getCert(tmdData, 0, contentCount) @@ -66,7 +66,7 @@ func GenerateCert(tmdData []byte, contentCount uint16, progressReporter Progress } cert.Write(cert1) - defaultCert, err := getDefaultCert(cancelCtx, progressReporter, client, buffer) + defaultCert, err := getDefaultCert(cancelCtx, progressReporter, client) if err != nil { return bytes.Buffer{}, err } diff --git a/cmd/WiiUDownloader/mainwindow.go b/cmd/WiiUDownloader/mainwindow.go index 3fc598c..effa2e3 100644 --- a/cmd/WiiUDownloader/mainwindow.go +++ b/cmd/WiiUDownloader/mainwindow.go @@ -265,7 +265,7 @@ func (mw *MainWindow) ShowAll() { wiiudownloader.GenerateTicket(filepath.Join(parentDir, "title.tik"), titleID, titleKey, titleVersion) - cert, err := wiiudownloader.GenerateCert(tmdData, contentCount, mw.progressWindow, http.DefaultClient, context.Background(), make([]byte, 0)) + cert, err := wiiudownloader.GenerateCert(tmdData, contentCount, mw.progressWindow, http.DefaultClient, context.Background()) if err != nil { return } diff --git a/downloader.go b/downloader.go index 4a4c125..8d8864d 100644 --- a/downloader.go +++ b/downloader.go @@ -14,9 +14,8 @@ import ( ) const ( - maxRetries = 5 - retryDelay = 5 * time.Second - BUFFER_SIZE = 1048576 + maxRetries = 5 + retryDelay = 5 * time.Second ) type ProgressReporter interface { @@ -30,33 +29,14 @@ type ProgressReporter interface { AddToTotalDownloaded(toAdd int64) } -func calculateDownloadSpeed(downloaded int64, startTime, endTime time.Time) int64 { - duration := endTime.Sub(startTime).Seconds() - if duration > 0 { - return int64(float64(downloaded) / duration) - } - return 0 -} - -func downloadFile(ctx context.Context, progressReporter ProgressReporter, client *http.Client, downloadURL, dstPath string, doRetries bool, buffer []byte) error { +func downloadFile(ctx context.Context, progressReporter ProgressReporter, client *http.Client, downloadURL, dstPath string, doRetries bool) error { filePath := filepath.Base(dstPath) startTime := time.Now() ticker := time.NewTicker(50 * time.Millisecond) defer ticker.Stop() - isError := false - - updateProgress := func(downloaded *int64) { - for range ticker.C { - if progressReporter.Cancelled() { - return - } - progressReporter.UpdateDownloadProgress(*downloaded, calculateDownloadSpeed(*downloaded, startTime, time.Now()), filePath) - } - } for attempt := 1; attempt <= maxRetries; attempt++ { - isError = false req, err := http.NewRequestWithContext(ctx, "GET", downloadURL, nil) if err != nil { return err @@ -68,15 +48,19 @@ func downloadFile(ctx context.Context, progressReporter ProgressReporter, client resp, err := client.Do(req) if err != nil { - return err - } - - if resp.StatusCode != http.StatusOK { if doRetries && attempt < maxRetries { time.Sleep(retryDelay) continue } + return err + } + + if resp.StatusCode != http.StatusOK { resp.Body.Close() + if doRetries && attempt < maxRetries { + time.Sleep(retryDelay) + continue + } return fmt.Errorf("download error after %d attempts, status code: %d", attempt, resp.StatusCode) } @@ -86,54 +70,20 @@ func downloadFile(ctx context.Context, progressReporter ProgressReporter, client return err } - var downloaded int64 - - go updateProgress(&downloaded) - - Loop: - for { - select { - case <-ctx.Done(): - resp.Body.Close() - file.Close() - return ctx.Err() - default: - n, err := resp.Body.Read(buffer) - if err != nil && err != io.EOF { - resp.Body.Close() - file.Close() - if doRetries && attempt < maxRetries { - time.Sleep(retryDelay) - isError = true - break Loop - } - return fmt.Errorf("download error after %d attempts: %+v", attempt, err) - } - - if n == 0 { - resp.Body.Close() - file.Close() - break Loop - } - - _, err = file.Write(buffer[:n]) - if err != nil { - resp.Body.Close() - file.Close() - if doRetries && attempt < maxRetries { - time.Sleep(retryDelay) - isError = true - break Loop - } - return fmt.Errorf("write error after %d attempts: %+v", attempt, err) - } - - downloaded += int64(n) + writerProgress := newWriterProgress(file, progressReporter, startTime, filePath) + _, err = io.Copy(writerProgress, resp.Body) + if err != nil { + file.Close() + resp.Body.Close() + if doRetries && attempt < maxRetries { + time.Sleep(retryDelay) + continue } + return err } - if !isError { - break - } + file.Close() + resp.Body.Close() + break } return nil @@ -152,10 +102,8 @@ func DownloadTitle(cancelCtx context.Context, titleID, outputDirectory string, d return err } - buffer := make([]byte, BUFFER_SIZE) - tmdPath := filepath.Join(outputDir, "title.tmd") - if err := downloadFile(cancelCtx, progressReporter, client, fmt.Sprintf("%s/%s", baseURL, "tmd"), tmdPath, true, buffer); err != nil { + if err := downloadFile(cancelCtx, progressReporter, client, fmt.Sprintf("%s/%s", baseURL, "tmd"), tmdPath, true); err != nil { if progressReporter.Cancelled() { return nil } @@ -173,7 +121,7 @@ func DownloadTitle(cancelCtx context.Context, titleID, outputDirectory string, d } tikPath := filepath.Join(outputDir, "title.tik") - if err := downloadFile(cancelCtx, progressReporter, client, fmt.Sprintf("%s/%s", baseURL, "cetk"), tikPath, false, buffer); err != nil { + if err := downloadFile(cancelCtx, progressReporter, client, fmt.Sprintf("%s/%s", baseURL, "cetk"), tikPath, false); err != nil { if progressReporter.Cancelled() { return nil } @@ -207,7 +155,7 @@ func DownloadTitle(cancelCtx context.Context, titleID, outputDirectory string, d progressReporter.SetDownloadSize(int64(titleSize)) - cert, err := GenerateCert(tmdData, contentCount, progressReporter, client, cancelCtx, buffer) + cert, err := GenerateCert(tmdData, contentCount, progressReporter, client, cancelCtx) if err != nil { if progressReporter.Cancelled() { return nil @@ -236,7 +184,7 @@ func DownloadTitle(cancelCtx context.Context, titleID, outputDirectory string, d return err } filePath := filepath.Join(outputDir, fmt.Sprintf("%08X.app", content.ID)) - if err := downloadFile(cancelCtx, progressReporter, client, fmt.Sprintf("%s/%08X", baseURL, content.ID), filePath, true, buffer); err != nil { + if err := downloadFile(cancelCtx, progressReporter, client, fmt.Sprintf("%s/%08X", baseURL, content.ID), filePath, true); err != nil { if progressReporter.Cancelled() { break } @@ -246,7 +194,7 @@ func DownloadTitle(cancelCtx context.Context, titleID, outputDirectory string, d if tmdData[offset+7]&0x2 == 2 { filePath = filepath.Join(outputDir, fmt.Sprintf("%08X.h3", content.ID)) - if err := downloadFile(cancelCtx, progressReporter, client, fmt.Sprintf("%s/%08X.h3", baseURL, content.ID), filePath, true, buffer); err != nil { + if err := downloadFile(cancelCtx, progressReporter, client, fmt.Sprintf("%s/%08X.h3", baseURL, content.ID), filePath, true); err != nil { if progressReporter.Cancelled() { break } diff --git a/writerProgress.go b/writerProgress.go new file mode 100644 index 0000000..b88a6f0 --- /dev/null +++ b/writerProgress.go @@ -0,0 +1,39 @@ +package wiiudownloader + +import ( + "io" + "time" +) + +func calculateDownloadSpeed(downloaded int64, startTime, endTime time.Time) int64 { + duration := endTime.Sub(startTime).Seconds() + if duration > 0 { + return int64(float64(downloaded) / duration) + } + return 0 +} + +type WriterProgress struct { + writer io.Writer + progressReporter ProgressReporter + startTime time.Time + filePath string + totalDownloaded int64 +} + +func newWriterProgress(writer io.Writer, progressReporter ProgressReporter, startTime time.Time, filePath string) *WriterProgress { + return &WriterProgress{writer: writer, totalDownloaded: 0, progressReporter: progressReporter, startTime: startTime, filePath: filePath} +} + +func (r *WriterProgress) Write(p []byte) (n int, err error) { + if r.progressReporter.Cancelled() { + return len(p), nil + } + n, err = r.writer.Write(p) + if err != nil && err != io.EOF { + return n, err + } + r.totalDownloaded += int64(n) + r.progressReporter.UpdateDownloadProgress(r.totalDownloaded, calculateDownloadSpeed(r.totalDownloaded, r.startTime, time.Now()), r.filePath) + return n, err +}