Use errgroup

This commit is contained in:
Xpl0itU 2023-08-26 13:29:12 +02:00
parent 2f2b6d888d
commit 929c6bda56
4 changed files with 29 additions and 35 deletions

View file

@ -8,13 +8,13 @@ import (
"path/filepath" "path/filepath"
"strconv" "strconv"
"strings" "strings"
"sync"
wiiudownloader "github.com/Xpl0itU/WiiUDownloader" wiiudownloader "github.com/Xpl0itU/WiiUDownloader"
"github.com/Xpl0itU/dialog" "github.com/Xpl0itU/dialog"
"github.com/cavaliergopher/grab/v3" "github.com/cavaliergopher/grab/v3"
"github.com/gotk3/gotk3/glib" "github.com/gotk3/gotk3/glib"
"github.com/gotk3/gotk3/gtk" "github.com/gotk3/gotk3/gtk"
"golang.org/x/sync/errgroup"
) )
const ( const (
@ -736,46 +736,36 @@ func (mw *MainWindow) onDownloadQueueClicked(selectedPath string) error {
} }
queueStatusChan := make(chan bool, 1) queueStatusChan := make(chan bool, 1)
errorChan := make(chan error, 1) defer close(queueStatusChan)
errGroup := errgroup.Group{}
var wg sync.WaitGroup
queueProcessingLoop: queueProcessingLoop:
for _, title := range mw.titleQueue { for _, title := range mw.titleQueue {
wg.Add(1) errGroup.Go(func() error {
go func(title wiiudownloader.TitleEntry, selectedPath string, progressWindow *wiiudownloader.ProgressWindow) {
defer wg.Done()
tidStr := fmt.Sprintf("%016x", title.TitleID) tidStr := fmt.Sprintf("%016x", title.TitleID)
titlePath := filepath.Join(selectedPath, fmt.Sprintf("%s [%s] [%s]", normalizeFilename(title.Name), wiiudownloader.GetFormattedKind(title.TitleID), tidStr)) titlePath := filepath.Join(selectedPath, fmt.Sprintf("%s [%s] [%s]", normalizeFilename(title.Name), wiiudownloader.GetFormattedKind(title.TitleID), tidStr))
if err := wiiudownloader.DownloadTitle(tidStr, titlePath, mw.decryptContents, progressWindow, mw.getDeleteEncryptedContents(), mw.logger); err != nil { if err := wiiudownloader.DownloadTitle(tidStr, titlePath, mw.decryptContents, &mw.progressWindow, mw.getDeleteEncryptedContents(), mw.logger); err != nil {
errorChan <- err return err
return
} }
queueStatusChan <- true queueStatusChan <- true
}(title, selectedPath, &mw.progressWindow) return nil
})
select { if err := errGroup.Wait(); err != nil {
case err := <-errorChan:
queueStatusChan <- false queueStatusChan <- false
wg.Wait()
mw.titleQueue = []wiiudownloader.TitleEntry{} mw.titleQueue = []wiiudownloader.TitleEntry{}
if mw.progressWindow.Window.IsVisible() { if mw.progressWindow.Window.IsVisible() {
mw.progressWindow.Window.Close() mw.progressWindow.Window.Close()
} }
mw.updateTitlesInQueue() mw.updateTitlesInQueue()
mw.onSelectionChanged() mw.onSelectionChanged()
close(queueStatusChan)
close(errorChan)
return err return err
case queueStatus := <-queueStatusChan:
if !queueStatus {
wg.Wait()
break queueProcessingLoop
} }
queueStatus := <-queueStatusChan
if !queueStatus {
break queueProcessingLoop
} }
} }
@ -788,8 +778,5 @@ queueProcessingLoop:
mw.updateTitlesInQueue() mw.updateTitlesInQueue()
mw.onSelectionChanged() mw.onSelectionChanged()
close(queueStatusChan)
close(errorChan)
return nil return nil
} }

View file

@ -20,6 +20,7 @@ import (
"github.com/gotk3/gotk3/glib" "github.com/gotk3/gotk3/glib"
"github.com/gotk3/gotk3/gtk" "github.com/gotk3/gotk3/gtk"
"golang.org/x/sync/errgroup"
) )
//export callProgressCallback //export callProgressCallback
@ -30,11 +31,13 @@ func callProgressCallback(progress C.int) {
var progressChan chan int var progressChan chan int
func DecryptContents(path string, progress *ProgressWindow, deleteEncryptedContents bool) error { func DecryptContents(path string, progress *ProgressWindow, deleteEncryptedContents bool) error {
errorChan := make(chan error)
progressChan = make(chan int) progressChan = make(chan int)
defer close(errorChan)
go runDecryption(path, errorChan, deleteEncryptedContents) errGroup := errgroup.Group{}
errGroup.Go(func() error {
return runDecryption(path, deleteEncryptedContents)
})
glib.IdleAdd(func() { glib.IdleAdd(func() {
progress.bar.SetText("Decrypting...") progress.bar.SetText("Decrypting...")
@ -52,10 +55,10 @@ func DecryptContents(path string, progress *ProgressWindow, deleteEncryptedConte
time.Sleep(time.Millisecond * 10) time.Sleep(time.Millisecond * 10)
} }
return <-errorChan return errGroup.Wait()
} }
func runDecryption(path string, errorChan chan<- error, deleteEncryptedContents bool) { func runDecryption(path string, deleteEncryptedContents bool) error {
defer close(progressChan) defer close(progressChan)
argv := make([]*C.char, 2) argv := make([]*C.char, 2)
argv[0] = C.CString("WiiUDownloader") argv[0] = C.CString("WiiUDownloader")
@ -67,13 +70,12 @@ func runDecryption(path string, errorChan chan<- error, deleteEncryptedContents
C.set_progress_callback(C.ProgressCallback(C.callProgressCallback)) C.set_progress_callback(C.ProgressCallback(C.callProgressCallback))
if int(C.cdecrypt_main(2, (**C.char)(unsafe.Pointer(&argv[0])))) != 0 { if int(C.cdecrypt_main(2, (**C.char)(unsafe.Pointer(&argv[0])))) != 0 {
errorChan <- fmt.Errorf("decryption failed") return fmt.Errorf("decryption failed")
return
} }
if deleteEncryptedContents { if deleteEncryptedContents {
doDeleteEncryptedContents(path) doDeleteEncryptedContents(path)
} }
errorChan <- nil return nil
} }

5
go.mod
View file

@ -10,4 +10,7 @@ require (
golang.org/x/crypto v0.12.0 golang.org/x/crypto v0.12.0
) )
require github.com/TheTitanrain/w32 v0.0.0-20200114052255-2654d97dbd3d // indirect require (
github.com/TheTitanrain/w32 v0.0.0-20200114052255-2654d97dbd3d // indirect
golang.org/x/sync v0.3.0 // indirect
)

2
go.sum
View file

@ -10,3 +10,5 @@ github.com/gotk3/gotk3 v0.6.2 h1:sx/PjaKfKULJPTPq8p2kn2ZbcNFxpOJqi4VLzMbEOO8=
github.com/gotk3/gotk3 v0.6.2/go.mod h1:/hqFpkNa9T3JgNAE2fLvCdov7c5bw//FHNZrZ3Uv9/Q= github.com/gotk3/gotk3 v0.6.2/go.mod h1:/hqFpkNa9T3JgNAE2fLvCdov7c5bw//FHNZrZ3Uv9/Q=
golang.org/x/crypto v0.12.0 h1:tFM/ta59kqch6LlvYnPa0yx5a83cL2nHflFhYKvv9Yk= golang.org/x/crypto v0.12.0 h1:tFM/ta59kqch6LlvYnPa0yx5a83cL2nHflFhYKvv9Yk=
golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E=
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=