diff --git a/decryption.go b/decryption.go index f73d359..2cae75c 100644 --- a/decryption.go +++ b/decryption.go @@ -128,13 +128,11 @@ func extractFileHash(src *os.File, partDataOffset uint64, fileOffset uint64, siz return nil } -func extractFile(src *os.File, part_data_offset uint64, file_offset uint64, size uint64, path string, content_id uint16, cipherHashTree cipher.Block) error { - enc := make([]byte, BLOCK_SIZE) - dec := make([]byte, BLOCK_SIZE) - iv := make([]byte, 16) +func extractFile(src *os.File, partDataOffset uint64, fileOffset uint64, size uint64, path string, contentId uint16, cipherHashTree cipher.Block) error { + encryptedContent := make([]byte, BLOCK_SIZE) + decryptedContent := make([]byte, BLOCK_SIZE) - roffset := file_offset / BLOCK_SIZE * BLOCK_SIZE - soffset := file_offset - (file_offset / BLOCK_SIZE * BLOCK_SIZE) + writeSize := BLOCK_SIZE dst, err := os.Create(path) if err != nil { @@ -142,39 +140,43 @@ func extractFile(src *os.File, part_data_offset uint64, file_offset uint64, size } defer dst.Close() - iv[1] = byte(content_id) + roffset := fileOffset / BLOCK_SIZE * BLOCK_SIZE + soffset := fileOffset - (fileOffset / BLOCK_SIZE * BLOCK_SIZE) - write_size := BLOCK_SIZE - if soffset+size > uint64(write_size) { - write_size = write_size - int(soffset) + if soffset+size > uint64(writeSize) { + writeSize = writeSize - int(soffset) } - _, err = src.Seek(int64(part_data_offset+roffset), io.SeekStart) + _, err = src.Seek(int64(partDataOffset+roffset), io.SeekStart) if err != nil { return err } + iv := make([]byte, aes.BlockSize) + iv[1] = byte(contentId) + + aesCipher := cipher.NewCBCDecrypter(cipherHashTree, iv) + for size > 0 { - if uint64(write_size) > size { - write_size = int(size) + if uint64(writeSize) > size { + writeSize = int(size) } - if _, err := io.ReadFull(src, enc); err != nil { + if n, err := io.ReadFull(src, encryptedContent); err != nil && n != BLOCK_SIZE { return fmt.Errorf("could not read %d bytes from '%s': %w", BLOCK_SIZE, path, err) } - mode := cipher.NewCBCDecrypter(cipherHashTree, iv) - mode.CryptBlocks(dec, enc) + aesCipher.CryptBlocks(decryptedContent, encryptedContent) - size -= uint64(write_size) - - _, err = dst.Write(dec[soffset : soffset+uint64(write_size)]) + n, err := dst.Write(decryptedContent[soffset : soffset+uint64(writeSize)]) if err != nil { return err } + size -= uint64(n) + if soffset != 0 { - write_size = BLOCK_SIZE + writeSize = BLOCK_SIZE soffset = 0 } } @@ -566,9 +568,9 @@ func DecryptContents(path string, progressReporter ProgressReporter, deleteEncry } defer srcFile.Close() if tmdFlags&0x02 != 0 { - err = extractFileHash(srcFile, 0, uint64(fst.FSTEntries[i].Offset), uint64(fst.FSTEntries[i].Length), outputPath, fst.FSTEntries[i].ContentID, cipherHashTree) + err = extractFileHash(srcFile, 0, contentOffset, uint64(fst.FSTEntries[i].Length), outputPath, fst.FSTEntries[i].ContentID, cipherHashTree) } else { - err = extractFile(srcFile, 0, uint64(fst.FSTEntries[i].Offset), uint64(fst.FSTEntries[i].Length), outputPath, fst.FSTEntries[i].ContentID, cipherHashTree) + err = extractFile(srcFile, 0, contentOffset, uint64(fst.FSTEntries[i].Length), outputPath, fst.FSTEntries[i].ContentID, cipherHashTree) } if err != nil { return err