diff --git a/certificate.go b/certificate.go index 205a2e2..88ea3f6 100644 --- a/certificate.go +++ b/certificate.go @@ -4,9 +4,10 @@ import ( "bytes" "context" "fmt" - "net/http" "os" "path" + + "github.com/valyala/fasthttp" ) var cetkData []byte @@ -28,12 +29,12 @@ func getCert(tmdData []byte, id int, numContents uint16) ([]byte, error) { } } -func getDefaultCert(cancelCtx context.Context, progressReporter ProgressReporter, client *http.Client, buffer []byte) ([]byte, error) { +func getDefaultCert(cancelCtx context.Context, progressReporter ProgressReporter, client *fasthttp.Client) ([]byte, error) { if len(cetkData) >= 0x350+0x300 { return cetkData[0x350 : 0x350+0x300], nil } cetkDir := path.Join(os.TempDir(), "cetk") - if err := downloadFile(cancelCtx, progressReporter, client, "http://ccs.cdn.c.shop.nintendowifi.net/ccs/download/000500101000400a/cetk", cetkDir, true, buffer); err != nil { + if err := downloadFile(cancelCtx, progressReporter, client, "http://ccs.cdn.c.shop.nintendowifi.net/ccs/download/000500101000400a/cetk", cetkDir, true); err != nil { return nil, err } cetkData, err := os.ReadFile(cetkDir) @@ -51,7 +52,7 @@ func getDefaultCert(cancelCtx context.Context, progressReporter ProgressReporter return nil, fmt.Errorf("failed to download OSv10 cetk, length: %d", len(cetkData)) } -func GenerateCert(tmdData []byte, contentCount uint16, progressReporter ProgressReporter, client *http.Client, cancelCtx context.Context, buffer []byte) (bytes.Buffer, error) { +func GenerateCert(tmdData []byte, contentCount uint16, progressReporter ProgressReporter, client *fasthttp.Client, cancelCtx context.Context) (bytes.Buffer, error) { cert := bytes.Buffer{} cert0, err := getCert(tmdData, 0, contentCount) @@ -66,7 +67,7 @@ func GenerateCert(tmdData []byte, contentCount uint16, progressReporter Progress } cert.Write(cert1) - defaultCert, err := getDefaultCert(cancelCtx, progressReporter, client, buffer) + defaultCert, err := getDefaultCert(cancelCtx, progressReporter, client) if err != nil { return bytes.Buffer{}, err } diff --git a/cmd/WiiUDownloader/main.go b/cmd/WiiUDownloader/main.go index 2583852..3e3250b 100644 --- a/cmd/WiiUDownloader/main.go +++ b/cmd/WiiUDownloader/main.go @@ -2,7 +2,6 @@ package main import ( "fmt" - "net/http" "os" "path/filepath" "runtime" @@ -11,6 +10,8 @@ import ( wiiudownloader "github.com/Xpl0itU/WiiUDownloader" "github.com/gotk3/gotk3/glib" "github.com/gotk3/gotk3/gtk" + + "github.com/valyala/fasthttp" ) func main() { @@ -42,14 +43,11 @@ func main() { logger.Fatal(err.Error()) } - t := http.DefaultTransport.(*http.Transport).Clone() - t.MaxIdleConns = 100 - t.MaxConnsPerHost = 100 - t.MaxIdleConnsPerHost = 100 - - client := &http.Client{ - Timeout: time.Duration(30) * time.Second, - Transport: t, + client := &fasthttp.Client{ + MaxConnsPerHost: 100, + MaxIdleConnDuration: 30 * time.Second, + ReadTimeout: 30 * time.Second, + WriteTimeout: 30 * time.Second, } app.Connect("activate", func() { diff --git a/cmd/WiiUDownloader/mainwindow.go b/cmd/WiiUDownloader/mainwindow.go index 3fc598c..f3f4cf6 100644 --- a/cmd/WiiUDownloader/mainwindow.go +++ b/cmd/WiiUDownloader/mainwindow.go @@ -5,7 +5,6 @@ import ( "context" "encoding/binary" "fmt" - "net/http" "os" "path/filepath" "strconv" @@ -15,6 +14,7 @@ import ( "github.com/Xpl0itU/dialog" "github.com/gotk3/gotk3/glib" "github.com/gotk3/gotk3/gtk" + "github.com/valyala/fasthttp" "golang.org/x/sync/errgroup" ) @@ -40,10 +40,10 @@ type MainWindow struct { titles []wiiudownloader.TitleEntry decryptContents bool currentRegion uint8 - client *http.Client + client *fasthttp.Client } -func NewMainWindow(app *gtk.Application, entries []wiiudownloader.TitleEntry, logger *wiiudownloader.Logger, client *http.Client) *MainWindow { +func NewMainWindow(app *gtk.Application, entries []wiiudownloader.TitleEntry, logger *wiiudownloader.Logger, client *fasthttp.Client) *MainWindow { gSettings, err := gtk.SettingsGetDefault() if err != nil { logger.Error(err.Error()) @@ -265,7 +265,7 @@ func (mw *MainWindow) ShowAll() { wiiudownloader.GenerateTicket(filepath.Join(parentDir, "title.tik"), titleID, titleKey, titleVersion) - cert, err := wiiudownloader.GenerateCert(tmdData, contentCount, mw.progressWindow, http.DefaultClient, context.Background(), make([]byte, 0)) + cert, err := wiiudownloader.GenerateCert(tmdData, contentCount, mw.progressWindow, mw.client, context.Background()) if err != nil { return } diff --git a/downloader.go b/downloader.go index 637f0c1..d5421aa 100644 --- a/downloader.go +++ b/downloader.go @@ -6,11 +6,12 @@ import ( "encoding/binary" "fmt" "io" - "net/http" "os" "path/filepath" "strings" "time" + + "github.com/valyala/fasthttp" ) const ( @@ -38,7 +39,7 @@ func calculateDownloadSpeed(downloaded int64, startTime, endTime time.Time) int6 return 0 } -func downloadFile(ctx context.Context, progressReporter ProgressReporter, client *http.Client, downloadURL, dstPath string, doRetries bool, buffer []byte) error { +func downloadFile(ctx context.Context, progressReporter ProgressReporter, client *fasthttp.Client, downloadURL, dstPath string, doRetries bool) error { filePath := filepath.Base(dstPath) startTime := time.Now() @@ -57,32 +58,41 @@ func downloadFile(ctx context.Context, progressReporter ProgressReporter, client for attempt := 1; attempt <= maxRetries; attempt++ { isError = false - req, err := http.NewRequestWithContext(ctx, "GET", downloadURL, nil) - if err != nil { - return err - } + req := fasthttp.AcquireRequest() + + req.SetRequestURI(downloadURL) + req.Header.SetMethod("GET") req.Header.Set("User-Agent", "WiiUDownloader") req.Header.Set("Connection", "Keep-Alive") req.Header.Set("Accept-Encoding", "*") - resp, err := client.Do(req) - if err != nil { + resp := fasthttp.AcquireResponse() + resp.StreamBody = true + resp.ImmediateHeaderFlush = true + + if err := client.Do(req, resp); err != nil { + fasthttp.ReleaseRequest(req) + fasthttp.ReleaseResponse(resp) return err } - if resp.StatusCode != http.StatusOK { + if resp.StatusCode() != fasthttp.StatusOK { if doRetries && attempt < maxRetries { time.Sleep(retryDelay) continue } - resp.Body.Close() - return fmt.Errorf("download error after %d attempts, status code: %d", attempt, resp.StatusCode) + resp.CloseBodyStream() + fasthttp.ReleaseRequest(req) + fasthttp.ReleaseResponse(resp) + return fmt.Errorf("download error after %d attempts, status code: %d", attempt, resp.StatusCode()) } file, err := os.Create(dstPath) if err != nil { - resp.Body.Close() + resp.CloseBodyStream() + fasthttp.ReleaseRequest(req) + fasthttp.ReleaseResponse(resp) return err } @@ -90,48 +100,42 @@ func downloadFile(ctx context.Context, progressReporter ProgressReporter, client go updateProgress(&downloaded) - Loop: - for { - select { - case <-ctx.Done(): - resp.Body.Close() + customBufferedWriter, err := NewFileWriterWithProgress(file, &downloaded) + if err != nil { + resp.CloseBodyStream() + file.Close() + fasthttp.ReleaseRequest(req) + fasthttp.ReleaseResponse(resp) + return err + } + + select { + case <-ctx.Done(): + resp.CloseBodyStream() + file.Close() + fasthttp.ReleaseRequest(req) + fasthttp.ReleaseResponse(resp) + return ctx.Err() + default: + err := resp.BodyWriteTo(customBufferedWriter) + if err != nil && err != io.EOF { + resp.CloseBodyStream() file.Close() - return ctx.Err() - default: - n, err := resp.Body.Read(buffer) - if err != nil && err != io.EOF { - resp.Body.Close() - file.Close() - if doRetries && attempt < maxRetries { - time.Sleep(retryDelay) - isError = true - break Loop - } - return fmt.Errorf("download error after %d attempts: %+v", attempt, err) + fasthttp.ReleaseRequest(req) + fasthttp.ReleaseResponse(resp) + if doRetries && attempt < maxRetries { + time.Sleep(retryDelay) + isError = true + break } - - if n == 0 { - resp.Body.Close() - file.Close() - break Loop - } - - _, err = file.Write(buffer[:n]) - if err != nil { - resp.Body.Close() - file.Close() - if doRetries && attempt < maxRetries { - time.Sleep(retryDelay) - isError = true - break Loop - } - return fmt.Errorf("write error after %d attempts: %+v", attempt, err) - } - - downloaded += int64(n) + return fmt.Errorf("download error after %d attempts: %+v", attempt, err) } } if !isError { + resp.CloseBodyStream() + file.Close() + fasthttp.ReleaseRequest(req) + fasthttp.ReleaseResponse(resp) break } } @@ -139,7 +143,7 @@ func downloadFile(ctx context.Context, progressReporter ProgressReporter, client return nil } -func DownloadTitle(cancelCtx context.Context, titleID, outputDirectory string, doDecryption bool, progressReporter ProgressReporter, deleteEncryptedContents bool, logger *Logger, client *http.Client) error { +func DownloadTitle(cancelCtx context.Context, titleID, outputDirectory string, doDecryption bool, progressReporter ProgressReporter, deleteEncryptedContents bool, logger *Logger, client *fasthttp.Client) error { tEntry := getTitleEntryFromTid(titleID) progressReporter.SetTotalDownloaded(0) @@ -152,10 +156,8 @@ func DownloadTitle(cancelCtx context.Context, titleID, outputDirectory string, d return err } - buffer := make([]byte, bufferSize) - tmdPath := filepath.Join(outputDir, "title.tmd") - if err := downloadFile(cancelCtx, progressReporter, client, fmt.Sprintf("%s/%s", baseURL, "tmd"), tmdPath, true, buffer); err != nil { + if err := downloadFile(cancelCtx, progressReporter, client, fmt.Sprintf("%s/%s", baseURL, "tmd"), tmdPath, true); err != nil { if progressReporter.Cancelled() { return nil } @@ -173,7 +175,7 @@ func DownloadTitle(cancelCtx context.Context, titleID, outputDirectory string, d } tikPath := filepath.Join(outputDir, "title.tik") - if err := downloadFile(cancelCtx, progressReporter, client, fmt.Sprintf("%s/%s", baseURL, "cetk"), tikPath, false, buffer); err != nil { + if err := downloadFile(cancelCtx, progressReporter, client, fmt.Sprintf("%s/%s", baseURL, "cetk"), tikPath, false); err != nil { if progressReporter.Cancelled() { return nil } @@ -207,7 +209,7 @@ func DownloadTitle(cancelCtx context.Context, titleID, outputDirectory string, d progressReporter.SetDownloadSize(int64(titleSize)) - cert, err := GenerateCert(tmdData, contentCount, progressReporter, client, cancelCtx, buffer) + cert, err := GenerateCert(tmdData, contentCount, progressReporter, client, cancelCtx) if err != nil { if progressReporter.Cancelled() { return nil @@ -236,7 +238,7 @@ func DownloadTitle(cancelCtx context.Context, titleID, outputDirectory string, d return err } filePath := filepath.Join(outputDir, fmt.Sprintf("%08X.app", content.ID)) - if err := downloadFile(cancelCtx, progressReporter, client, fmt.Sprintf("%s/%08X", baseURL, content.ID), filePath, true, buffer); err != nil { + if err := downloadFile(cancelCtx, progressReporter, client, fmt.Sprintf("%s/%08X", baseURL, content.ID), filePath, true); err != nil { if progressReporter.Cancelled() { break } @@ -246,7 +248,7 @@ func DownloadTitle(cancelCtx context.Context, titleID, outputDirectory string, d if tmdData[offset+7]&0x2 == 2 { filePath = filepath.Join(outputDir, fmt.Sprintf("%08X.h3", content.ID)) - if err := downloadFile(cancelCtx, progressReporter, client, fmt.Sprintf("%s/%08X.h3", baseURL, content.ID), filePath, true, buffer); err != nil { + if err := downloadFile(cancelCtx, progressReporter, client, fmt.Sprintf("%s/%08X.h3", baseURL, content.ID), filePath, true); err != nil { if progressReporter.Cancelled() { break } diff --git a/go.mod b/go.mod index 075f7d9..d7f4a31 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,14 @@ require ( golang.org/x/crypto v0.21.0 ) +require ( + github.com/andybalholm/brotli v1.1.0 // indirect + github.com/klauspost/compress v1.17.6 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect +) + require ( github.com/TheTitanrain/w32 v0.0.0-20200114052255-2654d97dbd3d // indirect + github.com/valyala/fasthttp v1.52.0 golang.org/x/sync v0.6.0 ) diff --git a/go.sum b/go.sum index 960ea52..11b1a11 100644 --- a/go.sum +++ b/go.sum @@ -2,10 +2,18 @@ github.com/TheTitanrain/w32 v0.0.0-20200114052255-2654d97dbd3d h1:2xp1BQbqcDDaik github.com/TheTitanrain/w32 v0.0.0-20200114052255-2654d97dbd3d/go.mod h1:peYoMncQljjNS6tZwI9WVyQB3qZS6u79/N3mBOcnd3I= github.com/Xpl0itU/dialog v0.0.0-20230805114139-ec888310aded h1:GkBw5aNvID1+SKAD3xC5fU4EwMgOmkrvICy5NX3Rqvw= github.com/Xpl0itU/dialog v0.0.0-20230805114139-ec888310aded/go.mod h1:Yl652wzqaetwEMJ8FnDRKBK1+CisE+PU5BGJXItbYFg= +github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M= +github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/gotk3/gotk3 v0.6.3 h1:+Ke4WkM1TQUNOlM2TZH6szqknqo+zNbX3BZWVXjSHYw= github.com/gotk3/gotk3 v0.6.3/go.mod h1:/hqFpkNa9T3JgNAE2fLvCdov7c5bw//FHNZrZ3Uv9/Q= +github.com/klauspost/compress v1.17.6 h1:60eq2E/jlfwQXtvZEeBUYADs+BwKBWURIY+Gj2eRGjI= +github.com/klauspost/compress v1.17.6/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasthttp v1.52.0 h1:wqBQpxH71XW0e2g+Og4dzQM8pk34aFYlA1Ga8db7gU0= +github.com/valyala/fasthttp v1.52.0/go.mod h1:hf5C4QnVMkNXMspnsUlfM3WitlgYflyhHYoKol/szxQ= golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= diff --git a/utils.go b/utils.go index 2e6e3cc..12c80c9 100644 --- a/utils.go +++ b/utils.go @@ -30,3 +30,33 @@ func doDeleteEncryptedContents(path string) error { return nil }) } + +type BufferedWriter struct { + file *os.File + downloaded *int64 + buffer []byte +} + +func NewFileWriterWithProgress(file *os.File, downloaded *int64) (*BufferedWriter, error) { + return &BufferedWriter{ + file: file, + downloaded: downloaded, + buffer: make([]byte, bufferSize), + }, nil +} + +func (bw *BufferedWriter) Write(data []byte) (int, error) { + written := 0 + for written < len(data) { + remaining := len(data) - written + toWrite := min(bufferSize, uint64(remaining)) + copy(bw.buffer, data[written:written+int(toWrite)]) + n, err := bw.file.Write(bw.buffer[:toWrite]) + if err != nil { + return written, err + } + written += n + *bw.downloaded += int64(n) + } + return written, nil +}