diff --git a/certificate.go b/certificate.go index d1cf529..f682acd 100644 --- a/certificate.go +++ b/certificate.go @@ -2,11 +2,11 @@ package wiiudownloader import ( "bytes" + "context" "fmt" + "net/http" "os" "path" - - "github.com/cavaliergopher/grab/v3" ) var cetkData []byte @@ -28,12 +28,12 @@ func getCert(tmdData []byte, id int, numContents uint16) ([]byte, error) { } } -func getDefaultCert(progressReporter ProgressReporter, client *grab.Client) ([]byte, error) { +func getDefaultCert(cancelCtx context.Context, progressReporter ProgressReporter, client *http.Client) ([]byte, error) { if len(cetkData) >= 0x350+0x300 { return cetkData[0x350 : 0x350+0x300], nil } cetkDir := path.Join(os.TempDir(), "cetk") - if err := downloadFile(progressReporter, client, "http://ccs.cdn.c.shop.nintendowifi.net/ccs/download/000500101000400a/cetk", cetkDir, true); err != nil { + if err := downloadFile(cancelCtx, progressReporter, client, "http://ccs.cdn.c.shop.nintendowifi.net/ccs/download/000500101000400a/cetk", cetkDir, true); err != nil { return nil, err } cetkData, err := os.ReadFile(cetkDir) @@ -51,7 +51,7 @@ func getDefaultCert(progressReporter ProgressReporter, client *grab.Client) ([]b return nil, fmt.Errorf("failed to download OSv10 cetk, length: %d", len(cetkData)) } -func GenerateCert(tmdData []byte, contentCount uint16, progressReporter ProgressReporter, client *grab.Client) (bytes.Buffer, error) { +func GenerateCert(tmdData []byte, contentCount uint16, progressReporter ProgressReporter, client *http.Client, cancelCtx context.Context) (bytes.Buffer, error) { cert := bytes.Buffer{} cert0, err := getCert(tmdData, 0, contentCount) @@ -66,7 +66,7 @@ func GenerateCert(tmdData []byte, contentCount uint16, progressReporter Progress } cert.Write(cert1) - defaultCert, err := getDefaultCert(progressReporter, client) + defaultCert, err := getDefaultCert(cancelCtx, progressReporter, client) if err != nil { return bytes.Buffer{}, err } diff --git a/cmd/WiiUDownloader/mainwindow.go b/cmd/WiiUDownloader/mainwindow.go index a962fbf..c3cfc03 100644 --- a/cmd/WiiUDownloader/mainwindow.go +++ b/cmd/WiiUDownloader/mainwindow.go @@ -5,6 +5,7 @@ import ( "context" "encoding/binary" "fmt" + "net/http" "os" "path/filepath" "strconv" @@ -12,7 +13,6 @@ import ( wiiudownloader "github.com/Xpl0itU/WiiUDownloader" "github.com/Xpl0itU/dialog" - "github.com/cavaliergopher/grab/v3" "github.com/gotk3/gotk3/glib" "github.com/gotk3/gotk3/gtk" "golang.org/x/sync/errgroup" @@ -263,7 +263,7 @@ func (mw *MainWindow) ShowAll() { wiiudownloader.GenerateTicket(filepath.Join(parentDir, "title.tik"), titleID, titleKey, titleVersion) - cert, err := wiiudownloader.GenerateCert(tmdData, contentCount, mw.progressWindow, grab.NewClient()) + cert, err := wiiudownloader.GenerateCert(tmdData, contentCount, mw.progressWindow, http.DefaultClient, context.Background()) if err != nil { return } @@ -745,7 +745,8 @@ func (mw *MainWindow) onDownloadQueueClicked(selectedPath string) error { errGroup := errgroup.Group{} queueCtx, cancel := context.WithCancel(context.Background()) - defer cancel() + mw.progressWindow.cancelFunc = cancel + defer mw.progressWindow.cancelFunc() for _, title := range mw.titleQueue { errGroup.Go(func() error { @@ -757,7 +758,7 @@ func (mw *MainWindow) onDownloadQueueClicked(selectedPath string) error { } tidStr := fmt.Sprintf("%016x", title.TitleID) titlePath := filepath.Join(selectedPath, fmt.Sprintf("%s [%s] [%s]", normalizeFilename(title.Name), wiiudownloader.GetFormattedKind(title.TitleID), tidStr)) - if err := wiiudownloader.DownloadTitle(cancel, tidStr, titlePath, mw.decryptContents, mw.progressWindow, mw.getDeleteEncryptedContents(), mw.logger); err != nil { + if err := wiiudownloader.DownloadTitle(queueCtx, tidStr, titlePath, mw.decryptContents, mw.progressWindow, mw.getDeleteEncryptedContents(), mw.logger); err != nil && err != context.Canceled { return err } diff --git a/cmd/WiiUDownloader/progressWindow.go b/cmd/WiiUDownloader/progressWindow.go index a90cbfe..eee0935 100644 --- a/cmd/WiiUDownloader/progressWindow.go +++ b/cmd/WiiUDownloader/progressWindow.go @@ -1,9 +1,9 @@ package main import ( + "context" "fmt" - "github.com/cavaliergopher/grab/v3" "github.com/dustin/go-humanize" "github.com/gotk3/gotk3/glib" "github.com/gotk3/gotk3/gtk" @@ -16,6 +16,7 @@ type ProgressWindow struct { bar *gtk.ProgressBar cancelButton *gtk.Button cancelled bool + cancelFunc context.CancelFunc } func (pw *ProgressWindow) SetGameTitle(title string) { @@ -27,10 +28,11 @@ func (pw *ProgressWindow) SetGameTitle(title string) { } } -func (pw *ProgressWindow) UpdateDownloadProgress(resp *grab.Response, filePath string) { +func (pw *ProgressWindow) UpdateDownloadProgress(downloaded, total, speed int64, filePath string) { glib.IdleAdd(func() { - pw.bar.SetFraction(resp.Progress()) - pw.bar.SetText(fmt.Sprintf("Downloading %s (%s/%s) (%s/s)", filePath, humanize.Bytes(uint64(resp.BytesComplete())), humanize.Bytes(uint64(resp.Size())), humanize.Bytes(uint64(resp.BytesPerSecond())))) + pw.cancelButton.SetSensitive(true) + pw.bar.SetFraction(float64(downloaded) / float64(total)) + pw.bar.SetText(fmt.Sprintf("Downloading %s (%s/%s) (%s/s)", filePath, humanize.Bytes(uint64(downloaded)), humanize.Bytes(uint64(total)), humanize.Bytes(uint64(speed)))) }) for gtk.EventsPending() { gtk.MainIteration() @@ -39,6 +41,7 @@ func (pw *ProgressWindow) UpdateDownloadProgress(resp *grab.Response, filePath s func (pw *ProgressWindow) UpdateDecryptionProgress(progress float64) { glib.IdleAdd(func() { + pw.cancelButton.SetSensitive(false) pw.bar.SetFraction(progress) pw.bar.SetText(fmt.Sprintf("Decrypting (%.2f%%)", progress*100)) }) @@ -51,6 +54,10 @@ func (pw *ProgressWindow) Cancelled() bool { return pw.cancelled } +func (pw *ProgressWindow) SetCancelled() { + pw.cancelFunc() +} + func createProgressWindow(parent *gtk.ApplicationWindow) (*ProgressWindow, error) { win, err := gtk.WindowNew(gtk.WINDOW_TOPLEVEL) if err != nil { @@ -106,6 +113,7 @@ func createProgressWindow(parent *gtk.ApplicationWindow) (*ProgressWindow, error progressWindow.cancelButton.Connect("clicked", func() { progressWindow.cancelled = true + progressWindow.SetCancelled() }) return &progressWindow, nil diff --git a/downloader.go b/downloader.go index e239c70..8da9cb3 100644 --- a/downloader.go +++ b/downloader.go @@ -8,12 +8,12 @@ import ( "encoding/binary" "encoding/hex" "fmt" + "io" + "net/http" "os" "path/filepath" "strings" "time" - - "github.com/cavaliergopher/grab/v3" ) const ( @@ -24,53 +24,92 @@ const ( type ProgressReporter interface { SetGameTitle(title string) - UpdateDownloadProgress(resp *grab.Response, filePath string) + UpdateDownloadProgress(downloaded, total int64, speed int64, filePath string) UpdateDecryptionProgress(progress float64) Cancelled() bool + SetCancelled() } -func downloadFile(progressReporter ProgressReporter, client *grab.Client, downloadURL, dstPath string, doRetries bool) error { +func calculateDownloadSpeed(downloaded int64, startTime, endTime time.Time) int64 { + duration := endTime.Sub(startTime).Seconds() + if duration > 0 { + return int64(float64(downloaded) / duration) + } + return 0 +} + +func downloadFile(ctx context.Context, progressReporter ProgressReporter, client *http.Client, downloadURL, dstPath string, doRetries bool) error { filePath := filepath.Base(dstPath) - t := time.NewTicker(500 * time.Millisecond) - defer t.Stop() + var speed int64 + var startTime time.Time for attempt := 1; attempt <= maxRetries; attempt++ { - req, err := grab.NewRequest(dstPath, downloadURL) + req, err := http.NewRequestWithContext(ctx, "GET", downloadURL, nil) if err != nil { return err } - req.BufferSize = bufferSize - resp := client.Do(req) - progressReporter.UpdateDownloadProgress(resp, filePath) - - Loop: - for { - select { - case <-t.C: - progressReporter.UpdateDownloadProgress(resp, filePath) - if progressReporter.Cancelled() { - resp.Cancel() - break Loop - } - case <-resp.Done: - if err := resp.Err(); err != nil { - if doRetries && attempt < maxRetries { - time.Sleep(retryDelay) - break Loop - } - return fmt.Errorf("download error after %d attempts: %+v", attempt, err) - } - break Loop - } + resp, err := client.Do(req) + if err != nil { + return err } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + if doRetries && attempt < maxRetries { + time.Sleep(retryDelay) + continue + } + return fmt.Errorf("download error after %d attempts, status code: %d", attempt, resp.StatusCode) + } + + file, err := os.Create(dstPath) + if err != nil { + return err + } + defer file.Close() + + total := resp.ContentLength + buffer := make([]byte, bufferSize) + var downloaded int64 + + startTime = time.Now() + for { + n, err := resp.Body.Read(buffer) + if err != nil && err != io.EOF { + if doRetries && attempt < maxRetries { + time.Sleep(retryDelay) + break + } + return fmt.Errorf("download error after %d attempts: %+v", attempt, err) + } + + if n == 0 { + break + } + + _, err = file.Write(buffer[:n]) + if err != nil { + if doRetries && attempt < maxRetries { + time.Sleep(retryDelay) + break + } + return fmt.Errorf("write error after %d attempts: %+v", attempt, err) + } + + downloaded += int64(n) + endTime := time.Now() + speed = calculateDownloadSpeed(downloaded, startTime, endTime) + progressReporter.UpdateDownloadProgress(downloaded, total, speed, filePath) + } + break } return nil } -func DownloadTitle(cancel context.CancelFunc, titleID, outputDirectory string, doDecryption bool, progressReporter ProgressReporter, deleteEncryptedContents bool, logger *Logger) error { +func DownloadTitle(cancelCtx context.Context, titleID, outputDirectory string, doDecryption bool, progressReporter ProgressReporter, deleteEncryptedContents bool, logger *Logger) error { titleEntry := getTitleEntryFromTid(titleID) progressReporter.SetGameTitle(titleEntry.Name) @@ -86,10 +125,12 @@ func DownloadTitle(cancel context.CancelFunc, titleID, outputDirectory string, d return err } - client := grab.NewClient() - client.BufferSize = bufferSize + client := &http.Client{} tmdPath := filepath.Join(outputDir, "title.tmd") - if err := downloadFile(progressReporter, client, fmt.Sprintf("%s/%s", baseURL, "tmd"), tmdPath, true); err != nil { + if err := downloadFile(cancelCtx, progressReporter, client, fmt.Sprintf("%s/%s", baseURL, "tmd"), tmdPath, true); err != nil { + if progressReporter.Cancelled() { + return nil + } return err } @@ -104,7 +145,10 @@ func DownloadTitle(cancel context.CancelFunc, titleID, outputDirectory string, d } tikPath := filepath.Join(outputDir, "title.tik") - if err := downloadFile(progressReporter, client, fmt.Sprintf("%s/%s", baseURL, "cetk"), tikPath, false); err != nil { + if err := downloadFile(cancelCtx, progressReporter, client, fmt.Sprintf("%s/%s", baseURL, "cetk"), tikPath, false); err != nil { + if progressReporter.Cancelled() { + return nil + } titleKey, err := GenerateKey(titleID) if err != nil { return err @@ -124,8 +168,11 @@ func DownloadTitle(cancel context.CancelFunc, titleID, outputDirectory string, d return err } - cert, err := GenerateCert(tmdData, contentCount, progressReporter, client) + cert, err := GenerateCert(tmdData, contentCount, progressReporter, client, cancelCtx) if err != nil { + if progressReporter.Cancelled() { + return nil + } return err } @@ -165,31 +212,38 @@ func DownloadTitle(cancel context.CancelFunc, titleID, outputDirectory string, d return err } filePath := filepath.Join(outputDir, fmt.Sprintf("%08X.app", id)) - if err := downloadFile(progressReporter, client, fmt.Sprintf("%s/%08X", baseURL, id), filePath, true); err != nil { + if err := downloadFile(cancelCtx, progressReporter, client, fmt.Sprintf("%s/%08X", baseURL, id), filePath, true); err != nil { + if progressReporter.Cancelled() { + break + } return err } if tmdData[offset+7]&0x2 == 2 { filePath = filepath.Join(outputDir, fmt.Sprintf("%08X.h3", id)) - if err := downloadFile(progressReporter, client, fmt.Sprintf("%s/%08X.h3", baseURL, id), filePath, true); err != nil { + if err := downloadFile(cancelCtx, progressReporter, client, fmt.Sprintf("%s/%08X.h3", baseURL, id), filePath, true); err != nil { + if progressReporter.Cancelled() { + break + } return err } content.Hash = tmdData[offset+16 : offset+0x14] content.ID = fmt.Sprintf("%08X", id) tmdDataReader.Seek(int64(offset+8), 0) if err := binary.Read(tmdDataReader, binary.BigEndian, &content.Size); err != nil { + if progressReporter.Cancelled() { + break + } return err } if err := checkContentHashes(outputDirectory, content, cipherHashTree); err != nil { if progressReporter.Cancelled() { - cancel() break } return err } } if progressReporter.Cancelled() { - cancel() break } } diff --git a/go.mod b/go.mod index 2b0e792..a376617 100644 --- a/go.mod +++ b/go.mod @@ -4,13 +4,12 @@ go 1.20 require ( github.com/Xpl0itU/dialog v0.0.0-20230805114139-ec888310aded - github.com/cavaliergopher/grab/v3 v3.0.1 github.com/dustin/go-humanize v1.0.1 github.com/gotk3/gotk3 v0.6.2 - golang.org/x/crypto v0.13.0 + golang.org/x/crypto v0.17.0 ) require ( github.com/TheTitanrain/w32 v0.0.0-20200114052255-2654d97dbd3d // indirect - golang.org/x/sync v0.3.0 + golang.org/x/sync v0.5.0 ) diff --git a/go.sum b/go.sum index faadc78..be69bf4 100644 --- a/go.sum +++ b/go.sum @@ -2,13 +2,11 @@ github.com/TheTitanrain/w32 v0.0.0-20200114052255-2654d97dbd3d h1:2xp1BQbqcDDaik github.com/TheTitanrain/w32 v0.0.0-20200114052255-2654d97dbd3d/go.mod h1:peYoMncQljjNS6tZwI9WVyQB3qZS6u79/N3mBOcnd3I= github.com/Xpl0itU/dialog v0.0.0-20230805114139-ec888310aded h1:GkBw5aNvID1+SKAD3xC5fU4EwMgOmkrvICy5NX3Rqvw= github.com/Xpl0itU/dialog v0.0.0-20230805114139-ec888310aded/go.mod h1:Yl652wzqaetwEMJ8FnDRKBK1+CisE+PU5BGJXItbYFg= -github.com/cavaliergopher/grab/v3 v3.0.1 h1:4z7TkBfmPjmLAAmkkAZNX/6QJ1nNFdv3SdIHXju0Fr4= -github.com/cavaliergopher/grab/v3 v3.0.1/go.mod h1:1U/KNnD+Ft6JJiYoYBAimKH2XrYptb8Kl3DFGmsjpq4= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/gotk3/gotk3 v0.6.2 h1:sx/PjaKfKULJPTPq8p2kn2ZbcNFxpOJqi4VLzMbEOO8= github.com/gotk3/gotk3 v0.6.2/go.mod h1:/hqFpkNa9T3JgNAE2fLvCdov7c5bw//FHNZrZ3Uv9/Q= -golang.org/x/crypto v0.13.0 h1:mvySKfSWJ+UKUii46M40LOvyWfN0s2U+46/jDd0e6Ck= -golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= -golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E= -golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= +golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k= +golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= +golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE= +golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=