diff --git a/cmd/WiiUDownloader/progressWindow.go b/cmd/WiiUDownloader/progressWindow.go index 11bd563..98d8a2a 100644 --- a/cmd/WiiUDownloader/progressWindow.go +++ b/cmd/WiiUDownloader/progressWindow.go @@ -2,6 +2,8 @@ package main import ( "fmt" + "sync" + "time" "github.com/dustin/go-humanize" "github.com/gotk3/gotk3/glib" @@ -33,6 +35,14 @@ func (sa *SpeedAverager) AddSpeed(speed int64) { sa.speeds = append(sa.speeds, speed) } +func calculateDownloadSpeed(downloaded int64, startTime, endTime time.Time) int64 { + duration := endTime.Sub(startTime).Seconds() + if duration > 0 { + return int64(float64(downloaded) / duration) + } + return 0 +} + func (sa *SpeedAverager) calculateAverageOfSpeeds() { var total int64 for _, speed := range sa.speeds { @@ -55,7 +65,9 @@ type ProgressWindow struct { cancelled bool totalToDownload int64 totalDownloaded int64 + progressMutex sync.Mutex speedAverager *SpeedAverager + startTime time.Time } func (pw *ProgressWindow) SetGameTitle(title string) { @@ -67,13 +79,13 @@ func (pw *ProgressWindow) SetGameTitle(title string) { } } -func (pw *ProgressWindow) UpdateDownloadProgress(downloaded, speed int64, filePath string) { +func (pw *ProgressWindow) UpdateDownloadProgress(downloaded int64) { glib.IdleAdd(func() { pw.cancelButton.SetSensitive(true) - currentDownload := downloaded + pw.totalDownloaded - pw.bar.SetFraction(float64(currentDownload) / float64(pw.totalToDownload)) - pw.speedAverager.AddSpeed(speed) - pw.bar.SetText(fmt.Sprintf("Downloading %s (%s/%s) (%s/s)", filePath, humanize.Bytes(uint64(currentDownload)), humanize.Bytes(uint64(pw.totalToDownload)), humanize.Bytes(uint64(int64(pw.speedAverager.GetAverageSpeed()))))) + pw.AddToTotalDownloaded(downloaded) + pw.bar.SetFraction(float64(pw.totalDownloaded) / float64(pw.totalToDownload)) + pw.speedAverager.AddSpeed(calculateDownloadSpeed(pw.totalDownloaded, pw.startTime, time.Now())) + pw.bar.SetText(fmt.Sprintf("Downloading... (%s/%s) (%s/s)", humanize.Bytes(uint64(pw.totalDownloaded)), humanize.Bytes(uint64(pw.totalToDownload)), humanize.Bytes(uint64(int64(pw.speedAverager.GetAverageSpeed()))))) }) for gtk.EventsPending() { gtk.MainIteration() @@ -110,11 +122,19 @@ func (pw *ProgressWindow) SetDownloadSize(size int64) { } func (pw *ProgressWindow) SetTotalDownloaded(total int64) { + pw.progressMutex.Lock() pw.totalDownloaded = total + pw.progressMutex.Unlock() } func (pw *ProgressWindow) AddToTotalDownloaded(toAdd int64) { + pw.progressMutex.Lock() pw.totalDownloaded += toAdd + pw.progressMutex.Unlock() +} + +func (pw *ProgressWindow) SetStartTime(startTime time.Time) { + pw.startTime = startTime } func createProgressWindow(parent *gtk.ApplicationWindow) (*ProgressWindow, error) { diff --git a/downloader.go b/downloader.go index e6b4014..5de4f42 100644 --- a/downloader.go +++ b/downloader.go @@ -2,6 +2,7 @@ package wiiudownloader import ( "bytes" + "context" "encoding/binary" "fmt" "io" @@ -10,27 +11,92 @@ import ( "path/filepath" "strings" "time" + + "golang.org/x/sync/errgroup" + "golang.org/x/sync/semaphore" ) const ( - maxRetries = 5 - retryDelay = 5 * time.Second + maxRetries = 5 + retryDelay = 5 * time.Second + maxConcurrentDownloads = 4 +) + +var ( + errCancel = fmt.Errorf("cancelled download") ) type ProgressReporter interface { SetGameTitle(title string) - UpdateDownloadProgress(downloaded, speed int64, filePath string) + UpdateDownloadProgress(downloaded int64) UpdateDecryptionProgress(progress float64) Cancelled() bool SetCancelled() SetDownloadSize(size int64) SetTotalDownloaded(total int64) AddToTotalDownloaded(toAdd int64) + SetStartTime(startTime time.Time) +} + +func downloadFileWithSemaphore(ctx context.Context, progressReporter ProgressReporter, client *http.Client, downloadURL, dstPath string, doRetries bool, sem *semaphore.Weighted) error { + if err := sem.Acquire(ctx, 1); err != nil { + return nil + } + defer sem.Release(1) + + for attempt := 1; attempt <= maxRetries; attempt++ { + req, err := http.NewRequestWithContext(ctx, "GET", downloadURL, nil) + if err != nil { + return err + } + + req.Header.Set("User-Agent", "WiiUDownloader") + + resp, err := client.Do(req) + if err != nil { + if doRetries && attempt < maxRetries && !progressReporter.Cancelled() { + time.Sleep(retryDelay) + continue + } + return err + } + + if resp.StatusCode != http.StatusOK { + resp.Body.Close() + if doRetries && attempt < maxRetries && !progressReporter.Cancelled() { + time.Sleep(retryDelay) + continue + } + fmt.Printf("download error after %d attempts, status code: %d, url: %s\n", attempt, resp.StatusCode, downloadURL) + return fmt.Errorf("download error after %d attempts, status code: %d", attempt, resp.StatusCode) + } + + file, err := os.Create(dstPath) + if err != nil { + resp.Body.Close() + return err + } + + writerProgress := newWriterProgress(file, progressReporter) + _, err = io.Copy(writerProgress, resp.Body) + if err != nil { + file.Close() + resp.Body.Close() + if doRetries && attempt < maxRetries && !progressReporter.Cancelled() { + time.Sleep(retryDelay) + continue + } + return err + } + file.Close() + resp.Body.Close() + break + } + + return nil } func downloadFile(progressReporter ProgressReporter, client *http.Client, downloadURL, dstPath string, doRetries bool) error { - filePath := filepath.Base(dstPath) - for attempt := 1; attempt <= maxRetries; attempt++ { req, err := http.NewRequest("GET", downloadURL, nil) if err != nil { @@ -63,7 +129,7 @@ func downloadFile(progressReporter ProgressReporter, client *http.Client, downlo return err } - writerProgress := newWriterProgress(file, progressReporter, time.Now(), filePath) + writerProgress := newWriterProgress(file, progressReporter) _, err = io.Copy(writerProgress, resp.Body) if err != nil { file.Close() @@ -133,17 +199,28 @@ func DownloadTitle(titleID, outputDirectory string, doDecryption bool, progressR } var titleSize uint64 - var contentSizes []uint64 - for i := 0; i < int(contentCount); i++ { - contentDataLoc := 0xB04 + (0x30 * i) + contents := make([]Content, contentCount) + tmdDataReader := bytes.NewReader(tmdData) - var contentSizeInt uint64 - if err := binary.Read(bytes.NewReader(tmdData[contentDataLoc+8:contentDataLoc+8+8]), binary.BigEndian, &contentSizeInt); err != nil { + for i := 0; i < int(contentCount); i++ { + offset := 0xB04 + (0x30 * i) + + tmdDataReader.Seek(int64(offset), io.SeekStart) + if err := binary.Read(tmdDataReader, binary.BigEndian, &contents[i].ID); err != nil { return err } - titleSize += contentSizeInt - contentSizes = append(contentSizes, contentSizeInt) + tmdDataReader.Seek(8, io.SeekCurrent) + + if err := binary.Read(tmdDataReader, binary.BigEndian, &contents[i].Type); err != nil { + return err + } + + if err := binary.Read(tmdDataReader, binary.BigEndian, &contents[i].Size); err != nil { + return err + } + + titleSize += contents[i].Size } progressReporter.SetDownloadSize(int64(titleSize)) @@ -164,39 +241,47 @@ func DownloadTitle(titleID, outputDirectory string, doDecryption bool, progressR if err := binary.Write(certFile, binary.BigEndian, cert.Bytes()); err != nil { return err } - defer certFile.Close() + certFile.Close() logger.Info("Certificate saved to %v \n", certPath) - var content Content - tmdDataReader := bytes.NewReader(tmdData) + g, ctx := errgroup.WithContext(context.Background()) + g.SetLimit(maxConcurrentDownloads) + sem := semaphore.NewWeighted(maxConcurrentDownloads) + progressReporter.SetStartTime(time.Now()) for i := 0; i < int(contentCount); i++ { - offset := 2820 + (48 * i) - tmdDataReader.Seek(int64(offset), 0) - if err := binary.Read(tmdDataReader, binary.BigEndian, &content.ID); err != nil { - return err - } - filePath := filepath.Join(outputDir, fmt.Sprintf("%08X.app", content.ID)) - if err := downloadFile(progressReporter, client, fmt.Sprintf("%s/%08X", baseURL, content.ID), filePath, true); err != nil { - if progressReporter.Cancelled() { - break - } - return err - } - progressReporter.AddToTotalDownloaded(int64(contentSizes[i])) + i := i + g.Go(func() error { - if tmdData[offset+7]&0x2 == 2 { - filePath = filepath.Join(outputDir, fmt.Sprintf("%08X.h3", content.ID)) - if err := downloadFile(progressReporter, client, fmt.Sprintf("%s/%08X.h3", baseURL, content.ID), filePath, true); err != nil { + filePath := filepath.Join(outputDir, fmt.Sprintf("%08X.app", contents[i].ID)) + if err := downloadFileWithSemaphore(ctx, progressReporter, client, fmt.Sprintf("%s/%08X", baseURL, contents[i].ID), filePath, true, sem); err != nil { if progressReporter.Cancelled() { - break + return errCancel } return err } + + if contents[i].Type&0x2 == 2 { // has a hash + filePath = filepath.Join(outputDir, fmt.Sprintf("%08X.h3", contents[i].ID)) + if err := downloadFileWithSemaphore(ctx, progressReporter, client, fmt.Sprintf("%s/%08X.h3", baseURL, contents[i].ID), filePath, true, sem); err != nil { + if progressReporter.Cancelled() { + return errCancel + } + return err + } + } + if progressReporter.Cancelled() { + return errCancel + } + return nil + }) + } + + if err := g.Wait(); err != nil { + if err == errCancel { + return nil } - if progressReporter.Cancelled() { - break - } + return err } if doDecryption && !progressReporter.Cancelled() { diff --git a/writerProgress.go b/writerProgress.go index b88a6f0..051d4dd 100644 --- a/writerProgress.go +++ b/writerProgress.go @@ -5,35 +5,31 @@ import ( "time" ) -func calculateDownloadSpeed(downloaded int64, startTime, endTime time.Time) int64 { - duration := endTime.Sub(startTime).Seconds() - if duration > 0 { - return int64(float64(downloaded) / duration) - } - return 0 -} - type WriterProgress struct { - writer io.Writer - progressReporter ProgressReporter - startTime time.Time - filePath string - totalDownloaded int64 + writer io.Writer + progressReporter ProgressReporter + updateProgressTicker *time.Ticker + downloadToReport int64 // Number of bytes to report to the progressReporter since the last update } -func newWriterProgress(writer io.Writer, progressReporter ProgressReporter, startTime time.Time, filePath string) *WriterProgress { - return &WriterProgress{writer: writer, totalDownloaded: 0, progressReporter: progressReporter, startTime: startTime, filePath: filePath} +func newWriterProgress(writer io.Writer, progressReporter ProgressReporter) *WriterProgress { + return &WriterProgress{writer: writer, progressReporter: progressReporter, updateProgressTicker: time.NewTicker(25 * time.Millisecond), downloadToReport: 0} } func (r *WriterProgress) Write(p []byte) (n int, err error) { + select { + case <-r.updateProgressTicker.C: + r.progressReporter.UpdateDownloadProgress(r.downloadToReport) + r.downloadToReport = 0 + default: + } if r.progressReporter.Cancelled() { - return len(p), nil + return 0, nil } n, err = r.writer.Write(p) if err != nil && err != io.EOF { return n, err } - r.totalDownloaded += int64(n) - r.progressReporter.UpdateDownloadProgress(r.totalDownloaded, calculateDownloadSpeed(r.totalDownloaded, r.startTime, time.Now()), r.filePath) + r.downloadToReport += int64(n) return n, err }