[pki] add RSA-2048 signature validation for all server downloads

* Closes #1172
* Also fix a MinGW warning in badblocks.c
This commit is contained in:
Pete Batard 2018-06-29 18:19:05 +01:00
parent 2d262df8f3
commit fdfc9ff82d
8 changed files with 331 additions and 85 deletions

212
src/net.c
View file

@ -29,6 +29,7 @@
#include <malloc.h>
#include <string.h>
#include <inttypes.h>
#include <assert.h>
#include "rufus.h"
#include "missing.h"
@ -216,13 +217,14 @@ const char* WinInetErrorString(void)
}
/*
* Download a file from an URL
* Download a file or fill a buffer from an URL
* Mostly taken from http://support.microsoft.com/kb/234913
* If file is NULL, a buffer is allocated for the download (that needs to be freed by the caller)
* If hProgressDialog is not NULL, this function will send INIT and EXIT messages
* to the dialog in question, with WPARAM being set to nonzero for EXIT on success
* and also attempt to indicate progress using an IDC_PROGRESS control
*/
DWORD DownloadFile(const char* url, const char* file, HWND hProgressDialog)
static DWORD DownloadToFileOrBuffer(const char* url, const char* file, BYTE** buffer, HWND hProgressDialog)
{
HWND hProgressBar = NULL;
BOOL r = FALSE;
@ -234,8 +236,8 @@ DWORD DownloadFile(const char* url, const char* file, HWND hProgressDialog)
HINTERNET hSession = NULL, hConnection = NULL, hRequest = NULL;
URL_COMPONENTSA UrlParts = {sizeof(URL_COMPONENTSA), NULL, 1, (INTERNET_SCHEME)0,
hostname, sizeof(hostname), 0, NULL, 1, urlpath, sizeof(urlpath), NULL, 1};
size_t last_slash;
int i;
const char* short_name;
size_t i;
// Can't link with wininet.lib because of sideloading issues
PF_TYPE_DECL(WINAPI, BOOL, InternetCrackUrlA, (LPCSTR, DWORD, DWORD, LPURL_COMPONENTSA));
@ -257,7 +259,7 @@ DWORD DownloadFile(const char* url, const char* file, HWND hProgressDialog)
PF_INIT_OR_OUT(HttpSendRequestA, WinInet);
PF_INIT_OR_OUT(HttpQueryInfoA, WinInet);
DownloadStatus = 0;
DownloadStatus = 404;
if (hProgressDialog != NULL) {
// Use the progress control provided, if any
hProgressBar = GetDlgItem(hProgressDialog, IDC_PROGRESS);
@ -268,22 +270,18 @@ DWORD DownloadFile(const char* url, const char* file, HWND hProgressDialog)
SendMessage(hProgressDialog, UM_PROGRESS_INIT, 0, 0);
}
if (file == NULL)
if (url == NULL)
goto out;
for (last_slash = safe_strlen(file); last_slash != 0; last_slash--) {
if ((file[last_slash] == '/') || (file[last_slash] == '\\')) {
last_slash++;
break;
}
}
short_name = (file != NULL) ? PathFindFileNameU(file) : PathFindFileNameU(url);
PrintInfo(0, MSG_085, &file[last_slash]);
uprintf("Downloading '%s' from %s\n", &file[last_slash], url);
if (hProgressDialog != NULL)
PrintInfo(0, MSG_085, short_name);
uprintf("Downloading %s", url);
if ( (!pfInternetCrackUrlA(url, (DWORD)safe_strlen(url), 0, &UrlParts))
|| (UrlParts.lpszHostName == NULL) || (UrlParts.lpszUrlPath == NULL)) {
uprintf("Unable to decode URL: %s\n", WinInetErrorString());
uprintf("Unable to decode URL: %s", WinInetErrorString());
goto out;
}
hostname[sizeof(hostname)-1] = 0;
@ -295,7 +293,7 @@ DWORD DownloadFile(const char* url, const char* file, HWND hProgressDialog)
if (i <= 0) {
// http://msdn.microsoft.com/en-us/library/windows/desktop/aa384702.aspx is wrong...
SetLastError(ERROR_INTERNET_NOT_INITIALIZED);
uprintf("Network is unavailable: %s\n", WinInetErrorString());
uprintf("Network is unavailable: %s", WinInetErrorString());
goto out;
}
static_sprintf(agent, APPLICATION_NAME "/%d.%d.%d (Windows NT %d.%d%s)",
@ -303,13 +301,13 @@ DWORD DownloadFile(const char* url, const char* file, HWND hProgressDialog)
nWindowsVersion>>4, nWindowsVersion&0x0F, is_x64()?"; WOW64":"");
hSession = pfInternetOpenA(agent, INTERNET_OPEN_TYPE_PRECONFIG, NULL, NULL, 0);
if (hSession == NULL) {
uprintf("Could not open Internet session: %s\n", WinInetErrorString());
uprintf("Could not open Internet session: %s", WinInetErrorString());
goto out;
}
hConnection = pfInternetConnectA(hSession, UrlParts.lpszHostName, UrlParts.nPort, NULL, NULL, INTERNET_SERVICE_HTTP, 0, (DWORD_PTR)NULL);
if (hConnection == NULL) {
uprintf("Could not connect to server %s:%d: %s\n", UrlParts.lpszHostName, UrlParts.nPort, WinInetErrorString());
uprintf("Could not connect to server %s:%d: %s", UrlParts.lpszHostName, UrlParts.nPort, WinInetErrorString());
goto out;
}
@ -318,35 +316,46 @@ DWORD DownloadFile(const char* url, const char* file, HWND hProgressDialog)
INTERNET_FLAG_NO_COOKIES|INTERNET_FLAG_NO_UI|INTERNET_FLAG_NO_CACHE_WRITE|INTERNET_FLAG_HYPERLINK|
((UrlParts.nScheme==INTERNET_SCHEME_HTTPS)?INTERNET_FLAG_SECURE:0), (DWORD_PTR)NULL);
if (hRequest == NULL) {
uprintf("Could not open URL %s: %s\n", url, WinInetErrorString());
uprintf("Could not open URL %s: %s", url, WinInetErrorString());
goto out;
}
if (!pfHttpSendRequestA(hRequest, NULL, 0, NULL, 0)) {
uprintf("Unable to send request: %s\n", WinInetErrorString());
uprintf("Unable to send request: %s", WinInetErrorString());
goto out;
}
// Get the file size
dwSize = sizeof(DownloadStatus);
DownloadStatus = 404;
pfHttpQueryInfoA(hRequest, HTTP_QUERY_STATUS_CODE|HTTP_QUERY_FLAG_NUMBER, (LPVOID)&DownloadStatus, &dwSize, NULL);
if (DownloadStatus != 200) {
error_code = ERROR_INTERNET_ITEM_NOT_FOUND;
uprintf("Unable to access file: %d\n", DownloadStatus);
uprintf("Unable to access file: %d", DownloadStatus);
goto out;
}
dwSize = sizeof(dwTotalSize);
if (!pfHttpQueryInfoA(hRequest, HTTP_QUERY_CONTENT_LENGTH|HTTP_QUERY_FLAG_NUMBER, (LPVOID)&dwTotalSize, &dwSize, NULL)) {
uprintf("Unable to retrieve file length: %s\n", WinInetErrorString());
uprintf("Unable to retrieve file length: %s", WinInetErrorString());
goto out;
}
uprintf("File length: %d bytes\n", dwTotalSize);
uprintf("File length: %d bytes", dwTotalSize);
hFile = CreateFileU(file, GENERIC_READ | GENERIC_WRITE, FILE_SHARE_READ, NULL, CREATE_ALWAYS, FILE_ATTRIBUTE_NORMAL, NULL);
if (hFile == INVALID_HANDLE_VALUE) {
uprintf("Unable to create file '%s': %s\n", &file[last_slash], WinInetErrorString());
goto out;
if (file != NULL) {
hFile = CreateFileU(file, GENERIC_READ | GENERIC_WRITE, FILE_SHARE_READ, NULL, CREATE_ALWAYS, FILE_ATTRIBUTE_NORMAL, NULL);
if (hFile == INVALID_HANDLE_VALUE) {
uprintf("Unable to create file '%s': %s", short_name, WinInetErrorString());
goto out;
}
} else {
if (buffer == NULL) {
uprintf("No buffer pointer provided for download");
goto out;
}
*buffer = malloc(dwTotalSize);
if (*buffer == NULL) {
uprintf("Could not allocate buffer for download");
goto out;
}
}
// Keep checking for data until there is nothing left.
@ -354,28 +363,37 @@ DWORD DownloadFile(const char* url, const char* file, HWND hProgressDialog)
while (1) {
if (IS_ERROR(FormatStatus))
goto out;
if (!pfInternetReadFile(hRequest, buf, sizeof(buf), &dwDownloaded) || (dwDownloaded == 0))
break;
dwSize += dwDownloaded;
SendMessage(hProgressBar, PBM_SETPOS, (WPARAM)(MAX_PROGRESS*((1.0f*dwSize)/(1.0f*dwTotalSize))), 0);
PrintInfo(0, MSG_241, (100.0f*dwSize)/(1.0f*dwTotalSize));
if (!WriteFile(hFile, buf, dwDownloaded, &dwWritten, NULL)) {
uprintf("Error writing file '%s': %s\n", &file[last_slash], WinInetErrorString());
goto out;
} else if (dwDownloaded != dwWritten) {
uprintf("Error writing file '%s': Only %d/%d bytes written\n", dwWritten, dwDownloaded);
goto out;
if (hProgressDialog != NULL) {
SendMessage(hProgressBar, PBM_SETPOS, (WPARAM)(MAX_PROGRESS*((1.0f*dwSize) / (1.0f*dwTotalSize))), 0);
PrintInfo(0, MSG_241, (100.0f*dwSize) / (1.0f*dwTotalSize));
}
if (file != NULL) {
if (!WriteFile(hFile, buf, dwDownloaded, &dwWritten, NULL)) {
uprintf("Error writing file '%s': %s", short_name, WinInetErrorString());
goto out;
} else if (dwDownloaded != dwWritten) {
uprintf("Error writing file '%s': Only %d/%d bytes written", short_name, dwWritten, dwDownloaded);
goto out;
}
} else {
memcpy(&(*buffer)[dwSize], buf, dwDownloaded);
}
dwSize += dwDownloaded;
}
if (dwSize != dwTotalSize) {
uprintf("Could not download complete file - read: %d bytes, expected: %d bytes\n", dwSize, dwTotalSize);
uprintf("Could not download complete file - read: %d bytes, expected: %d bytes", dwSize, dwTotalSize);
FormatStatus = ERROR_SEVERITY_ERROR|FAC(FACILITY_STORAGE)|ERROR_WRITE_FAULT;
goto out;
} else {
r = TRUE;
uprintf("Successfully downloaded '%s'\n", &file[last_slash]);
uprintf("Successfully downloaded '%s'", short_name);
if (hProgressDialog != NULL) {
SendMessage(hProgressBar, PBM_SETPOS, (WPARAM)MAX_PROGRESS, 0);
PrintInfo(0, MSG_241, 100.0f);
}
}
out:
@ -406,20 +424,76 @@ out:
return r?dwSize:0;
}
// Download and validate a signed file. The file must have a corresponding '.sig' on the server.
DWORD DownloadSignedFile(const char* url, const char* file, HWND hProgressDialog)
{
char* url_sig = NULL;
BYTE *buf = NULL, *sig = NULL;
DWORD buf_len = 0, sig_len = 0;
DWORD ret = 0;
HANDLE hFile = INVALID_HANDLE_VALUE;
if (url == NULL)
goto out;
url_sig = malloc(strlen(url) + 5);
if (url_sig == NULL) {
uprintf("Could not allocate signature URL");
goto out;
}
strcpy(url_sig, url);
strcat(url_sig, ".sig");
buf_len = DownloadToFileOrBuffer(url, NULL, &buf, hProgressDialog);
if (buf_len == 0)
goto out;
sig_len = DownloadToFileOrBuffer(url_sig, NULL, &sig, NULL);
if ((sig_len != RSA_SIGNATURE_SIZE) || (!ValidateOpensslSignature(buf, buf_len, sig, sig_len))) {
uprintf("FATAL: Server signature is invalid!");
DownloadStatus = 403; // Forbidden
goto out;
}
uprintf("Download signature is valid");
DownloadStatus = 206; // Partial content
hFile = CreateFileU(file, GENERIC_READ | GENERIC_WRITE, FILE_SHARE_READ, NULL, CREATE_ALWAYS, FILE_ATTRIBUTE_NORMAL, NULL);
if (hFile == INVALID_HANDLE_VALUE) {
uprintf("Unable to create file '%s': %s", PathFindFileNameU(file), WinInetErrorString());
goto out;
}
if (!WriteFile(hFile, buf, buf_len, &ret, NULL)) {
uprintf("Error writing file '%s': %s", PathFindFileNameU(file), WinInetErrorString());
ret = 0;
goto out;
} else if (ret != buf_len) {
uprintf("Error writing file '%s': Only %d/%d bytes written", PathFindFileNameU(file), ret, buf_len);
ret = 0;
goto out;
}
DownloadStatus = 200; // Full content
out:
safe_closehandle(hFile);
free(url_sig);
free(buf);
free(sig);
return ret;
}
/* Threaded download */
static const char *_url, *_file;
static HWND _hProgressDialog;
static DWORD WINAPI _DownloadFileThread(LPVOID param)
static DWORD WINAPI _DownloadSignedFileThread(LPVOID param)
{
ExitThread(DownloadFile(_url, _file, _hProgressDialog) != 0);
ExitThread(DownloadSignedFile(_url, _file, _hProgressDialog) != 0);
}
HANDLE DownloadFileThreaded(const char* url, const char* file, HWND hProgressDialog)
HANDLE DownloadSignedFileThreaded(const char* url, const char* file, HWND hProgressDialog)
{
_url = url;
_file = file;
_hProgressDialog = hProgressDialog;
return CreateThread(NULL, 0, _DownloadFileThread, NULL, 0, NULL);
return CreateThread(NULL, 0, _DownloadSignedFileThread, NULL, 0, NULL);
}
static __inline uint64_t to_uint64_t(uint16_t x[4]) {
@ -443,8 +517,9 @@ static DWORD WINAPI CheckForUpdatesThread(LPVOID param)
static const char* channel[] = {"release", "beta", "test"}; // release channel
const char* accept_types[] = {"*/*\0", NULL};
DWORD dwFlags, dwSize, dwDownloaded, dwTotalSize, dwStatus;
BYTE *sig = NULL;
char* buf = NULL;
char agent[64], hostname[64], urlpath[128], mime[32];
char agent[64], hostname[64], urlpath[128], sigpath[256], mime[32];
OSVERSIONINFOA os_version = {sizeof(OSVERSIONINFOA), 0, 0, 0, 0, ""};
HINTERNET hSession = NULL, hConnection = NULL, hRequest = NULL;
URL_COMPONENTSA UrlParts = {sizeof(URL_COMPONENTSA), NULL, 1, (INTERNET_SCHEME)0,
@ -487,7 +562,7 @@ static DWORD WINAPI CheckForUpdatesThread(LPVOID param)
} while ((!force_update_check) && ((iso_op_in_progress || format_op_in_progress || (dialog_showing>0))));
if (!force_update_check) {
if ((ReadSetting32(SETTING_UPDATE_INTERVAL) == -1)) {
vuprintf("Check for updates disabled, as per settings.\n");
vuprintf("Check for updates disabled, as per settings.");
goto out;
}
reg_time = ReadSetting64(SETTING_LAST_UPDATE);
@ -500,9 +575,9 @@ static DWORD WINAPI CheckForUpdatesThread(LPVOID param)
if (!SystemTimeToFileTime(&LocalTime, &FileTime))
goto out;
local_time = ((((int64_t)FileTime.dwHighDateTime)<<32) + FileTime.dwLowDateTime) / 10000000;
vvuprintf("Local time: %" PRId64 "\n", local_time);
vvuprintf("Local time: %" PRId64, local_time);
if (local_time < reg_time + update_interval) {
vuprintf("Next update check in %" PRId64 " seconds.\n", reg_time + update_interval - local_time);
vuprintf("Next update check in %" PRId64 " seconds.", reg_time + update_interval - local_time);
goto out;
}
}
@ -512,7 +587,7 @@ static DWORD WINAPI CheckForUpdatesThread(LPVOID param)
status++; // 1
if (!GetVersionExA(&os_version)) {
uprintf("Could not read Windows version - Check for updates cancelled.\n");
uprintf("Could not read Windows version - Check for updates cancelled.");
goto out;
}
@ -540,7 +615,7 @@ static DWORD WINAPI CheckForUpdatesThread(LPVOID param)
max_channel = releases_only ? 1 : (int)ARRAYSIZE(channel) - 1;
#endif
for (k=0; (k<max_channel) && (!found_new_version); k++) {
uprintf("Checking %s channel...\n", channel[k]);
uprintf("Checking %s channel...", channel[k]);
// At this stage we can query the server for various update version files.
// We first try to lookup for "<appname>_<os_arch>_<os_version_major>_<os_version_minor>.ver"
// and then remove each each of the <os_> components until we find our match. For instance, we may first
@ -548,21 +623,18 @@ static DWORD WINAPI CheckForUpdatesThread(LPVOID param)
// This allows sunsetting OS versions (eg XP) or providing different downloads for different archs/groups.
static_sprintf(urlpath, "%s%s%s_%s_%lu.%lu.ver", APPLICATION_NAME, (k==0)?"":"_",
(k==0)?"":channel[k], archname[is_x64()?1:0], os_version.dwMajorVersion, os_version.dwMinorVersion);
vuprintf("Base update check: %s\n", urlpath);
vuprintf("Base update check: %s", urlpath);
for (i=0, j=(int)safe_strlen(urlpath)-5; (j>0)&&(i<ARRAYSIZE(verpos)); j--) {
if ((urlpath[j] == '.') || (urlpath[j] == '_')) {
verpos[i++] = j;
}
}
if (i != ARRAYSIZE(verpos)) {
uprintf("Broken code in CheckForUpdatesThread()!\n");
goto out;
}
assert(i == ARRAYSIZE(verpos));
UrlParts.lpszUrlPath = urlpath;
UrlParts.dwUrlPathLength = sizeof(urlpath);
for (i=0; i<ARRAYSIZE(verpos); i++) {
vvuprintf("Trying %s\n", UrlParts.lpszUrlPath);
vvuprintf("Trying %s", UrlParts.lpszUrlPath);
hRequest = pfHttpOpenRequestA(hConnection, "GET", UrlParts.lpszUrlPath, NULL, NULL, accept_types,
INTERNET_FLAG_IGNORE_REDIRECT_TO_HTTP|INTERNET_FLAG_IGNORE_REDIRECT_TO_HTTPS|
INTERNET_FLAG_NO_COOKIES|INTERNET_FLAG_NO_UI|INTERNET_FLAG_NO_CACHE_WRITE|INTERNET_FLAG_HYPERLINK|
@ -607,7 +679,7 @@ static DWORD WINAPI CheckForUpdatesThread(LPVOID param)
|| (!SystemTimeToFileTime(&ServerTime, &FileTime)) )
goto out;
server_time = ((((int64_t)FileTime.dwHighDateTime)<<32) + FileTime.dwLowDateTime) / 10000000;
vvuprintf("Server time: %" PRId64 "\n", server_time);
vvuprintf("Server time: %" PRId64, server_time);
// Always store the server response time - the only clock we trust!
WriteSetting64(SETTING_LAST_UPDATE, server_time);
// Might as well let the user know
@ -625,29 +697,39 @@ static DWORD WINAPI CheckForUpdatesThread(LPVOID param)
safe_free(buf);
// Make sure the file is NUL terminated
buf = (char*)calloc(dwTotalSize+1, 1);
if (buf == NULL) goto out;
if (buf == NULL)
goto out;
// This is a version file - we should be able to gulp it down in one go
if (!pfInternetReadFile(hRequest, buf, dwTotalSize, &dwDownloaded) || (dwDownloaded != dwTotalSize))
goto out;
vuprintf("Successfully downloaded version file (%d bytes)", dwTotalSize);
// Now download the signature file
static_sprintf(sigpath, "%s/%s.sig", server_url, urlpath);
dwDownloaded = DownloadToFileOrBuffer(sigpath, NULL, &sig, NULL);
if ((dwDownloaded != RSA_SIGNATURE_SIZE) || (!ValidateOpensslSignature(buf, dwTotalSize, sig, dwDownloaded))) {
uprintf("FATAL: Version signature is invalid!");
goto out;
}
vuprintf("Version signature is valid");
status++;
vuprintf("Successfully downloaded version file (%d bytes)\n", dwTotalSize);
parse_update(buf, dwTotalSize+1);
vuprintf("UPDATE DATA:\n");
vuprintf(" version: %d.%d.%d (%s)\n", update.version[0], update.version[1], update.version[2], channel[k]);
vuprintf(" platform_min: %d.%d\n", update.platform_min[0], update.platform_min[1]);
vuprintf(" url: %s\n", update.download_url);
vuprintf("UPDATE DATA:");
vuprintf(" version: %d.%d.%d (%s)", update.version[0], update.version[1], update.version[2], channel[k]);
vuprintf(" platform_min: %d.%d", update.platform_min[0], update.platform_min[1]);
vuprintf(" url: %s", update.download_url);
found_new_version = ((to_uint64_t(update.version) > to_uint64_t(rufus_version)) || (force_update))
&& ( (os_version.dwMajorVersion > update.platform_min[0])
|| ( (os_version.dwMajorVersion == update.platform_min[0]) && (os_version.dwMinorVersion >= update.platform_min[1])) );
uprintf("N%sew %s version found%c\n", found_new_version?"":"o n", channel[k], found_new_version?'!':'.');
uprintf("N%sew %s version found%c", found_new_version?"":"o n", channel[k], found_new_version?'!':'.');
}
out:
safe_free(buf);
safe_free(sig);
if (hRequest)
pfInternetCloseHandle(hRequest);
if (hConnection)