Implement multiple concurrent downloads

This commit is contained in:
Xpl0itU 2024-04-13 19:01:31 +02:00
parent 894ac30409
commit b74b43fff0
3 changed files with 160 additions and 59 deletions

View file

@ -2,6 +2,8 @@ package main
import ( import (
"fmt" "fmt"
"sync"
"time"
"github.com/dustin/go-humanize" "github.com/dustin/go-humanize"
"github.com/gotk3/gotk3/glib" "github.com/gotk3/gotk3/glib"
@ -33,6 +35,14 @@ func (sa *SpeedAverager) AddSpeed(speed int64) {
sa.speeds = append(sa.speeds, speed) sa.speeds = append(sa.speeds, speed)
} }
func calculateDownloadSpeed(downloaded int64, startTime, endTime time.Time) int64 {
duration := endTime.Sub(startTime).Seconds()
if duration > 0 {
return int64(float64(downloaded) / duration)
}
return 0
}
func (sa *SpeedAverager) calculateAverageOfSpeeds() { func (sa *SpeedAverager) calculateAverageOfSpeeds() {
var total int64 var total int64
for _, speed := range sa.speeds { for _, speed := range sa.speeds {
@ -55,7 +65,9 @@ type ProgressWindow struct {
cancelled bool cancelled bool
totalToDownload int64 totalToDownload int64
totalDownloaded int64 totalDownloaded int64
progressMutex sync.Mutex
speedAverager *SpeedAverager speedAverager *SpeedAverager
startTime time.Time
} }
func (pw *ProgressWindow) SetGameTitle(title string) { func (pw *ProgressWindow) SetGameTitle(title string) {
@ -67,13 +79,13 @@ func (pw *ProgressWindow) SetGameTitle(title string) {
} }
} }
func (pw *ProgressWindow) UpdateDownloadProgress(downloaded, speed int64, filePath string) { func (pw *ProgressWindow) UpdateDownloadProgress(downloaded int64) {
glib.IdleAdd(func() { glib.IdleAdd(func() {
pw.cancelButton.SetSensitive(true) pw.cancelButton.SetSensitive(true)
currentDownload := downloaded + pw.totalDownloaded pw.AddToTotalDownloaded(downloaded)
pw.bar.SetFraction(float64(currentDownload) / float64(pw.totalToDownload)) pw.bar.SetFraction(float64(pw.totalDownloaded) / float64(pw.totalToDownload))
pw.speedAverager.AddSpeed(speed) pw.speedAverager.AddSpeed(calculateDownloadSpeed(pw.totalDownloaded, pw.startTime, time.Now()))
pw.bar.SetText(fmt.Sprintf("Downloading %s (%s/%s) (%s/s)", filePath, humanize.Bytes(uint64(currentDownload)), humanize.Bytes(uint64(pw.totalToDownload)), humanize.Bytes(uint64(int64(pw.speedAverager.GetAverageSpeed()))))) pw.bar.SetText(fmt.Sprintf("Downloading... (%s/%s) (%s/s)", humanize.Bytes(uint64(pw.totalDownloaded)), humanize.Bytes(uint64(pw.totalToDownload)), humanize.Bytes(uint64(int64(pw.speedAverager.GetAverageSpeed())))))
}) })
for gtk.EventsPending() { for gtk.EventsPending() {
gtk.MainIteration() gtk.MainIteration()
@ -110,11 +122,19 @@ func (pw *ProgressWindow) SetDownloadSize(size int64) {
} }
func (pw *ProgressWindow) SetTotalDownloaded(total int64) { func (pw *ProgressWindow) SetTotalDownloaded(total int64) {
pw.progressMutex.Lock()
pw.totalDownloaded = total pw.totalDownloaded = total
pw.progressMutex.Unlock()
} }
func (pw *ProgressWindow) AddToTotalDownloaded(toAdd int64) { func (pw *ProgressWindow) AddToTotalDownloaded(toAdd int64) {
pw.progressMutex.Lock()
pw.totalDownloaded += toAdd pw.totalDownloaded += toAdd
pw.progressMutex.Unlock()
}
func (pw *ProgressWindow) SetStartTime(startTime time.Time) {
pw.startTime = startTime
} }
func createProgressWindow(parent *gtk.ApplicationWindow) (*ProgressWindow, error) { func createProgressWindow(parent *gtk.ApplicationWindow) (*ProgressWindow, error) {

View file

@ -2,6 +2,7 @@ package wiiudownloader
import ( import (
"bytes" "bytes"
"context"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"io" "io"
@ -10,27 +11,92 @@ import (
"path/filepath" "path/filepath"
"strings" "strings"
"time" "time"
"golang.org/x/sync/errgroup"
"golang.org/x/sync/semaphore"
) )
const ( const (
maxRetries = 5 maxRetries = 5
retryDelay = 5 * time.Second retryDelay = 5 * time.Second
maxConcurrentDownloads = 4
)
var (
errCancel = fmt.Errorf("cancelled download")
) )
type ProgressReporter interface { type ProgressReporter interface {
SetGameTitle(title string) SetGameTitle(title string)
UpdateDownloadProgress(downloaded, speed int64, filePath string) UpdateDownloadProgress(downloaded int64)
UpdateDecryptionProgress(progress float64) UpdateDecryptionProgress(progress float64)
Cancelled() bool Cancelled() bool
SetCancelled() SetCancelled()
SetDownloadSize(size int64) SetDownloadSize(size int64)
SetTotalDownloaded(total int64) SetTotalDownloaded(total int64)
AddToTotalDownloaded(toAdd int64) AddToTotalDownloaded(toAdd int64)
SetStartTime(startTime time.Time)
}
func downloadFileWithSemaphore(ctx context.Context, progressReporter ProgressReporter, client *http.Client, downloadURL, dstPath string, doRetries bool, sem *semaphore.Weighted) error {
if err := sem.Acquire(ctx, 1); err != nil {
return nil
}
defer sem.Release(1)
for attempt := 1; attempt <= maxRetries; attempt++ {
req, err := http.NewRequestWithContext(ctx, "GET", downloadURL, nil)
if err != nil {
return err
}
req.Header.Set("User-Agent", "WiiUDownloader")
resp, err := client.Do(req)
if err != nil {
if doRetries && attempt < maxRetries && !progressReporter.Cancelled() {
time.Sleep(retryDelay)
continue
}
return err
}
if resp.StatusCode != http.StatusOK {
resp.Body.Close()
if doRetries && attempt < maxRetries && !progressReporter.Cancelled() {
time.Sleep(retryDelay)
continue
}
fmt.Printf("download error after %d attempts, status code: %d, url: %s\n", attempt, resp.StatusCode, downloadURL)
return fmt.Errorf("download error after %d attempts, status code: %d", attempt, resp.StatusCode)
}
file, err := os.Create(dstPath)
if err != nil {
resp.Body.Close()
return err
}
writerProgress := newWriterProgress(file, progressReporter)
_, err = io.Copy(writerProgress, resp.Body)
if err != nil {
file.Close()
resp.Body.Close()
if doRetries && attempt < maxRetries && !progressReporter.Cancelled() {
time.Sleep(retryDelay)
continue
}
return err
}
file.Close()
resp.Body.Close()
break
}
return nil
} }
func downloadFile(progressReporter ProgressReporter, client *http.Client, downloadURL, dstPath string, doRetries bool) error { func downloadFile(progressReporter ProgressReporter, client *http.Client, downloadURL, dstPath string, doRetries bool) error {
filePath := filepath.Base(dstPath)
for attempt := 1; attempt <= maxRetries; attempt++ { for attempt := 1; attempt <= maxRetries; attempt++ {
req, err := http.NewRequest("GET", downloadURL, nil) req, err := http.NewRequest("GET", downloadURL, nil)
if err != nil { if err != nil {
@ -63,7 +129,7 @@ func downloadFile(progressReporter ProgressReporter, client *http.Client, downlo
return err return err
} }
writerProgress := newWriterProgress(file, progressReporter, time.Now(), filePath) writerProgress := newWriterProgress(file, progressReporter)
_, err = io.Copy(writerProgress, resp.Body) _, err = io.Copy(writerProgress, resp.Body)
if err != nil { if err != nil {
file.Close() file.Close()
@ -133,17 +199,28 @@ func DownloadTitle(titleID, outputDirectory string, doDecryption bool, progressR
} }
var titleSize uint64 var titleSize uint64
var contentSizes []uint64 contents := make([]Content, contentCount)
for i := 0; i < int(contentCount); i++ { tmdDataReader := bytes.NewReader(tmdData)
contentDataLoc := 0xB04 + (0x30 * i)
var contentSizeInt uint64 for i := 0; i < int(contentCount); i++ {
if err := binary.Read(bytes.NewReader(tmdData[contentDataLoc+8:contentDataLoc+8+8]), binary.BigEndian, &contentSizeInt); err != nil { offset := 0xB04 + (0x30 * i)
tmdDataReader.Seek(int64(offset), io.SeekStart)
if err := binary.Read(tmdDataReader, binary.BigEndian, &contents[i].ID); err != nil {
return err return err
} }
titleSize += contentSizeInt tmdDataReader.Seek(8, io.SeekCurrent)
contentSizes = append(contentSizes, contentSizeInt)
if err := binary.Read(tmdDataReader, binary.BigEndian, &contents[i].Type); err != nil {
return err
}
if err := binary.Read(tmdDataReader, binary.BigEndian, &contents[i].Size); err != nil {
return err
}
titleSize += contents[i].Size
} }
progressReporter.SetDownloadSize(int64(titleSize)) progressReporter.SetDownloadSize(int64(titleSize))
@ -164,39 +241,47 @@ func DownloadTitle(titleID, outputDirectory string, doDecryption bool, progressR
if err := binary.Write(certFile, binary.BigEndian, cert.Bytes()); err != nil { if err := binary.Write(certFile, binary.BigEndian, cert.Bytes()); err != nil {
return err return err
} }
defer certFile.Close() certFile.Close()
logger.Info("Certificate saved to %v \n", certPath) logger.Info("Certificate saved to %v \n", certPath)
var content Content g, ctx := errgroup.WithContext(context.Background())
tmdDataReader := bytes.NewReader(tmdData) g.SetLimit(maxConcurrentDownloads)
sem := semaphore.NewWeighted(maxConcurrentDownloads)
progressReporter.SetStartTime(time.Now())
for i := 0; i < int(contentCount); i++ { for i := 0; i < int(contentCount); i++ {
offset := 2820 + (48 * i) i := i
tmdDataReader.Seek(int64(offset), 0) g.Go(func() error {
if err := binary.Read(tmdDataReader, binary.BigEndian, &content.ID); err != nil {
return err
}
filePath := filepath.Join(outputDir, fmt.Sprintf("%08X.app", content.ID))
if err := downloadFile(progressReporter, client, fmt.Sprintf("%s/%08X", baseURL, content.ID), filePath, true); err != nil {
if progressReporter.Cancelled() {
break
}
return err
}
progressReporter.AddToTotalDownloaded(int64(contentSizes[i]))
if tmdData[offset+7]&0x2 == 2 { filePath := filepath.Join(outputDir, fmt.Sprintf("%08X.app", contents[i].ID))
filePath = filepath.Join(outputDir, fmt.Sprintf("%08X.h3", content.ID)) if err := downloadFileWithSemaphore(ctx, progressReporter, client, fmt.Sprintf("%s/%08X", baseURL, contents[i].ID), filePath, true, sem); err != nil {
if err := downloadFile(progressReporter, client, fmt.Sprintf("%s/%08X.h3", baseURL, content.ID), filePath, true); err != nil {
if progressReporter.Cancelled() { if progressReporter.Cancelled() {
break return errCancel
}
return err
}
if contents[i].Type&0x2 == 2 { // has a hash
filePath = filepath.Join(outputDir, fmt.Sprintf("%08X.h3", contents[i].ID))
if err := downloadFileWithSemaphore(ctx, progressReporter, client, fmt.Sprintf("%s/%08X.h3", baseURL, contents[i].ID), filePath, true, sem); err != nil {
if progressReporter.Cancelled() {
return errCancel
} }
return err return err
} }
} }
if progressReporter.Cancelled() { if progressReporter.Cancelled() {
break return errCancel
} }
return nil
})
}
if err := g.Wait(); err != nil {
if err == errCancel {
return nil
}
return err
} }
if doDecryption && !progressReporter.Cancelled() { if doDecryption && !progressReporter.Cancelled() {

View file

@ -5,35 +5,31 @@ import (
"time" "time"
) )
func calculateDownloadSpeed(downloaded int64, startTime, endTime time.Time) int64 {
duration := endTime.Sub(startTime).Seconds()
if duration > 0 {
return int64(float64(downloaded) / duration)
}
return 0
}
type WriterProgress struct { type WriterProgress struct {
writer io.Writer writer io.Writer
progressReporter ProgressReporter progressReporter ProgressReporter
startTime time.Time updateProgressTicker *time.Ticker
filePath string downloadToReport int64 // Number of bytes to report to the progressReporter since the last update
totalDownloaded int64
} }
func newWriterProgress(writer io.Writer, progressReporter ProgressReporter, startTime time.Time, filePath string) *WriterProgress { func newWriterProgress(writer io.Writer, progressReporter ProgressReporter) *WriterProgress {
return &WriterProgress{writer: writer, totalDownloaded: 0, progressReporter: progressReporter, startTime: startTime, filePath: filePath} return &WriterProgress{writer: writer, progressReporter: progressReporter, updateProgressTicker: time.NewTicker(25 * time.Millisecond), downloadToReport: 0}
} }
func (r *WriterProgress) Write(p []byte) (n int, err error) { func (r *WriterProgress) Write(p []byte) (n int, err error) {
select {
case <-r.updateProgressTicker.C:
r.progressReporter.UpdateDownloadProgress(r.downloadToReport)
r.downloadToReport = 0
default:
}
if r.progressReporter.Cancelled() { if r.progressReporter.Cancelled() {
return len(p), nil return 0, nil
} }
n, err = r.writer.Write(p) n, err = r.writer.Write(p)
if err != nil && err != io.EOF { if err != nil && err != io.EOF {
return n, err return n, err
} }
r.totalDownloaded += int64(n) r.downloadToReport += int64(n)
r.progressReporter.UpdateDownloadProgress(r.totalDownloaded, calculateDownloadSpeed(r.totalDownloaded, r.startTime, time.Now()), r.filePath)
return n, err return n, err
} }