crypto/spl: implement rsa-oaep

This commit is contained in:
Michael Scire 2020-02-24 19:09:13 -08:00
parent ad1158b30a
commit a429c61f33
6 changed files with 373 additions and 75 deletions

View file

@ -151,80 +151,6 @@ namespace ams::spl::impl {
R_ABORT_UNLESS(svcMapDeviceAddressSpaceAligned(g_se_das_hnd, dd::GetCurrentProcessHandle(), work_buffer_addr, sizeof(g_work_buffer), g_se_mapped_work_buffer_addr, 3));
}
/* RSA OAEP implementation helpers. */
void CalcMgf1AndXor(void *dst, size_t dst_size, const void *src, size_t src_size) {
uint8_t *dst_u8 = reinterpret_cast<u8 *>(dst);
u32 ctr = 0;
while (dst_size > 0) {
const size_t cur_size = std::min(size_t(SHA256_HASH_SIZE), dst_size);
dst_size -= cur_size;
u32 ctr_be = __builtin_bswap32(ctr++);
u8 hash[SHA256_HASH_SIZE];
{
Sha256Context ctx;
sha256ContextCreate(&ctx);
sha256ContextUpdate(&ctx, src, src_size);
sha256ContextUpdate(&ctx, &ctr_be, sizeof(ctr_be));
sha256ContextGetHash(&ctx, hash);
}
for (size_t i = 0; i < cur_size; i++) {
*(dst_u8++) ^= hash[i];
}
}
}
size_t DecodeRsaOaep(void *dst, size_t dst_size, const void *label_digest, size_t label_digest_size, const void *src, size_t src_size) {
/* Very basic validation. */
if (dst_size == 0 || src_size != 0x100 || label_digest_size != SHA256_HASH_SIZE) {
return 0;
}
u8 block[0x100];
std::memcpy(block, src, sizeof(block));
/* First, validate byte 0 == 0, and unmask DB. */
int invalid = block[0];
u8 *salt = block + 1;
u8 *db = salt + SHA256_HASH_SIZE;
CalcMgf1AndXor(salt, SHA256_HASH_SIZE, db, src_size - (1 + SHA256_HASH_SIZE));
CalcMgf1AndXor(db, src_size - (1 + SHA256_HASH_SIZE), salt, SHA256_HASH_SIZE);
/* Validate label digest. */
for (size_t i = 0; i < SHA256_HASH_SIZE; i++) {
invalid |= db[i] ^ reinterpret_cast<const u8 *>(label_digest)[i];
}
/* Locate message after 00...0001 padding. */
const u8 *padded_msg = db + SHA256_HASH_SIZE;
size_t padded_msg_size = src_size - (1 + 2 * SHA256_HASH_SIZE);
size_t msg_ind = 0;
int not_found = 1;
int wrong_padding = 0;
size_t i = 0;
while (i < padded_msg_size) {
int zero = (padded_msg[i] == 0);
int one = (padded_msg[i] == 1);
msg_ind += static_cast<size_t>(not_found & one) * (++i);
not_found &= ~one;
wrong_padding |= (not_found & ~zero);
}
if (invalid | not_found | wrong_padding) {
return 0;
}
/* Copy message out. */
size_t msg_size = padded_msg_size - msg_ind;
if (msg_size > dst_size) {
return 0;
}
std::memcpy(dst, padded_msg + msg_ind, msg_size);
return msg_size;
}
/* Internal RNG functionality. */
Result GenerateRandomBytesInternal(void *out, size_t size) {
if (!g_drbg.GenerateRandomBytes(out, size)) {
@ -793,7 +719,7 @@ namespace ams::spl::impl {
/* Nintendo doesn't check this result code, but we will. */
R_TRY(SecureExpMod(g_work_buffer, 0x100, base, base_size, mod, mod_size, smc::SecureExpModMode::Lotus));
size_t data_size = DecodeRsaOaep(dst, dst_size, label_digest, label_digest_size, g_work_buffer, 0x100);
size_t data_size = crypto::DecodeRsa2048OaepSha256(dst, dst_size, label_digest, label_digest_size, g_work_buffer, 0x100);
R_UNLESS(data_size > 0, spl::ResultDecryptionFailed());
*out_size = static_cast<u32>(data_size);