Fix Windows with_tempfile

This commit is contained in:
Richard Feldman 2024-07-04 14:53:08 -04:00
parent 5bf2c11bea
commit 6932b3f8f2
No known key found for this signature in database
GPG key ID: F1F21AA5B1D9E43B
2 changed files with 50 additions and 47 deletions

View file

@ -473,6 +473,7 @@ pub fn test(_matches: &ArgMatches, _target: Target) -> io::Result<i32> {
todo!("running tests does not work on windows right now")
}
#[cfg(not(windows))]
struct ModuleTestResults {
module_id: ModuleId,
failed_count: usize,

View file

@ -63,7 +63,13 @@ extern "system" {
lpOverlapped: *mut core::ffi::c_void,
) -> i32;
fn DeleteFileW(lpFileName: *const u16) -> i32;
fn GetTempPathW(nBufferLength: u32, lpBuffer: *mut u16) -> u32;
fn GetTempPath2W(nBufferLength: u32, lpBuffer: *mut u16) -> u32;
fn GetTempFileNameW(
lpPathName: *const u16,
lpPrefixString: *const u16,
uUnique: u32,
lpTempFileName: *mut u16,
) -> u32;
fn CreateDirectoryW(
lpPathName: *const u16,
lpSecurityAttributes: *mut core::ffi::c_void,
@ -260,6 +266,9 @@ impl File {
#[cfg(windows)]
impl File {
/// Source: https://microsoft.github.io/windows-docs-rs/doc/windows/Win32/Foundation/constant.MAX_PATH.html
const MAX_PATH: u32 = 260;
/// Returns whether it succeeded.
pub fn remove(path: &NativePath) -> bool {
unsafe { DeleteFileW(path.inner.as_ptr()) != 0 }
@ -304,75 +313,69 @@ impl File {
run(Err(FileIoErr::most_recent()))
}
}
}
#[cfg(windows)]
#[cfg(windows)]
impl File {
/// Create a tempfile, open it as a File, pass that File and its generated path
/// to the given function, and then delete it after the function returns.
pub fn with_tempfile<T>(run: impl FnOnce(Result<(&NativePath, Self), FileIoErr>) -> T) -> T {
let mut temp_path_buf = [0u16; 261];
let temp_path_len =
unsafe { GetTempPathW(temp_path_buf.len() as u32, temp_path_buf.as_mut_ptr()) };
pub fn with_tempfile<T>(
run: impl FnOnce(Result<(&NativePath, &mut Self), FileIoErr>) -> T,
) -> T {
let tempdir_path: &mut [MaybeUninit<u16>] =
&mut [MaybeUninit::uninit(); Self::MAX_PATH as usize + 1];
if temp_path_len == 0 || temp_path_len > temp_path_buf.len() as u32 {
let native_path =
NativePath::new(U16CString::from_vec_with_nul(temp_path_buf.to_vec()).unwrap());
return run(&native_path, None);
let tempdir_path_len =
unsafe { GetTempPath2W(tempdir_path.len() as u32, tempdir_path.as_mut_ptr().cast()) };
if tempdir_path_len == 0 {
return run(Err(FileIoErr::most_recent()));
}
let mut template = Vec::from(&temp_path_buf[..temp_path_len as usize]);
template.extend_from_slice(&[
b't' as u16,
b'e' as u16,
b'm' as u16,
b'p' as u16,
b'f' as u16,
b'i' as u16,
b'l' as u16,
b'e' as u16,
b'.' as u16,
b't' as u16,
b'm' as u16,
b'p' as u16,
0,
]);
let suffix =
// Note: only up to the first 3 chars of this string will be used,
// so we only give it 3.
//
// https://learn.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-gettempfilenamew
widestring::u16cstr!("roc").as_slice_with_nul();
let tempfile_path: &mut [MaybeUninit<u16>] =
&mut [MaybeUninit::uninit(); Self::MAX_PATH as usize + 1];
let mut temp_file_name = [0u16; 261];
let result = unsafe {
GetTempFileNameW(
temp_path_buf.as_ptr(),
U16CString::from_str("tmp").unwrap().as_ptr(),
tempdir_path.as_ptr().cast(),
suffix.as_ptr(),
0,
temp_file_name.as_mut_ptr(),
tempfile_path.as_mut_ptr().cast(),
)
};
if result == 0 {
let native_path =
NativePath::new(U16CString::from_vec_with_nul(temp_file_name.to_vec()).unwrap());
return run(&native_path, None);
return run(Err(FileIoErr::most_recent()));
}
let handle = unsafe {
CreateFileW(
temp_file_name.as_ptr(),
0x40000000, // GENERIC_WRITE
0,
tempfile_path.as_mut_ptr().cast(),
Self::GENERIC_WRITE,
Self::FILE_SHARE_READ | Self::FILE_SHARE_WRITE,
core::ptr::null_mut(),
1, // CREATE_ALWAYS
Self::CREATE_ALWAYS,
0x80, // FILE_ATTRIBUTE_TEMPORARY
0,
)
};
if handle != -1 {
// Since we pass an owned File, it will get closed automatically once dropped.
// This in turn will result in the file getting deleted, since Windows sets
// tempfiles to delete once the handle is closed.
run(Some(File { handle }))
let mut file = File { handle };
let native_path = unsafe { U16CStr::from_ptr_str(tempfile_path.as_ptr().cast()) };
// Windows automatically deletes tempfiles when the last handle
// to them is closed, so we don't need to delete this explicitly.
run(Ok((native_path.into(), &mut file)))
} else {
// Here we assume that since CreateFileW errored out, the file was not created
// and we shouldn't attempt to delete it.
run(None)
run(Err(FileIoErr::most_recent()))
}
}
}
@ -481,9 +484,8 @@ mod tests {
#[cfg(windows)]
fn str_to_u16cstr(s: &str) -> &widestring::U16CStr {
widestring::U16CString::from_str(s)
.expect("U16CStr conversion failed")
.as_ref()
let wide_str: Vec<u16> = s.encode_utf16().collect::<Vec<_>>();
unsafe { U16CStr::from_ptr_str(wide_str.as_ptr()) }
}
#[cfg(unix)]