Only parse the object file once

This commit is contained in:
Lukas Wirth 2024-12-12 13:20:08 +01:00
parent 16c0f25579
commit c9314d4aff
2 changed files with 19 additions and 21 deletions

View file

@ -6,7 +6,6 @@ use proc_macro::bridge;
use std::{fmt, fs, io, time::SystemTime}; use std::{fmt, fs, io, time::SystemTime};
use libloading::Library; use libloading::Library;
use memmap2::Mmap;
use object::Object; use object::Object;
use paths::{Utf8Path, Utf8PathBuf}; use paths::{Utf8Path, Utf8PathBuf};
use proc_macro_api::ProcMacroKind; use proc_macro_api::ProcMacroKind;
@ -23,8 +22,8 @@ fn is_derive_registrar_symbol(symbol: &str) -> bool {
symbol.contains(NEW_REGISTRAR_SYMBOL) symbol.contains(NEW_REGISTRAR_SYMBOL)
} }
fn find_registrar_symbol(buffer: &[u8]) -> object::Result<Option<String>> { fn find_registrar_symbol(obj: &object::File<'_>) -> object::Result<Option<String>> {
Ok(object::File::parse(buffer)? Ok(obj
.exports()? .exports()?
.into_iter() .into_iter()
.map(|export| export.name()) .map(|export| export.name())
@ -109,15 +108,16 @@ struct ProcMacroLibraryLibloading {
impl ProcMacroLibraryLibloading { impl ProcMacroLibraryLibloading {
fn open(path: &Utf8Path) -> Result<Self, LoadProcMacroDylibError> { fn open(path: &Utf8Path) -> Result<Self, LoadProcMacroDylibError> {
let buffer = unsafe { Mmap::map(&fs::File::open(path)?)? }; let file = fs::File::open(path)?;
let file = unsafe { memmap2::Mmap::map(&file) }?;
let obj = object::File::parse(&*file)
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
let version_info = version::read_dylib_info(&obj)?;
let symbol_name = let symbol_name =
find_registrar_symbol(&buffer).map_err(invalid_data_err)?.ok_or_else(|| { find_registrar_symbol(&obj).map_err(invalid_data_err)?.ok_or_else(|| {
invalid_data_err(format!("Cannot find registrar symbol in file {path}")) invalid_data_err(format!("Cannot find registrar symbol in file {path}"))
})?; })?;
let version_info = version::read_dylib_info(&buffer)?;
drop(buffer);
let lib = load_library(path).map_err(invalid_data_err)?; let lib = load_library(path).map_err(invalid_data_err)?;
let proc_macros = crate::proc_macros::ProcMacros::from_lib( let proc_macros = crate::proc_macros::ProcMacros::from_lib(
&lib, &lib,

View file

@ -2,7 +2,7 @@
use std::io::{self, Read}; use std::io::{self, Read};
use object::read::{File as BinaryFile, Object, ObjectSection}; use object::read::{Object, ObjectSection};
#[derive(Debug)] #[derive(Debug)]
#[allow(dead_code)] #[allow(dead_code)]
@ -16,14 +16,14 @@ pub struct RustCInfo {
} }
/// Read rustc dylib information /// Read rustc dylib information
pub fn read_dylib_info(buffer: &[u8]) -> io::Result<RustCInfo> { pub fn read_dylib_info(obj: &object::File<'_>) -> io::Result<RustCInfo> {
macro_rules! err { macro_rules! err {
($e:literal) => { ($e:literal) => {
io::Error::new(io::ErrorKind::InvalidData, $e) io::Error::new(io::ErrorKind::InvalidData, $e)
}; };
} }
let ver_str = read_version(buffer)?; let ver_str = read_version(obj)?;
let mut items = ver_str.split_whitespace(); let mut items = ver_str.split_whitespace();
let tag = items.next().ok_or_else(|| err!("version format error"))?; let tag = items.next().ok_or_else(|| err!("version format error"))?;
if tag != "rustc" { if tag != "rustc" {
@ -70,10 +70,8 @@ pub fn read_dylib_info(buffer: &[u8]) -> io::Result<RustCInfo> {
/// This is used inside read_version() to locate the ".rustc" section /// This is used inside read_version() to locate the ".rustc" section
/// from a proc macro crate's binary file. /// from a proc macro crate's binary file.
fn read_section<'a>(dylib_binary: &'a [u8], section_name: &str) -> io::Result<&'a [u8]> { fn read_section<'a>(obj: &object::File<'a>, section_name: &str) -> io::Result<&'a [u8]> {
BinaryFile::parse(dylib_binary) obj.section_by_name(section_name)
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?
.section_by_name(section_name)
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "section read error"))? .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "section read error"))?
.data() .data()
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e)) .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
@ -101,8 +99,8 @@ fn read_section<'a>(dylib_binary: &'a [u8], section_name: &str) -> io::Result<&'
/// ///
/// Check this issue for more about the bytes layout: /// Check this issue for more about the bytes layout:
/// <https://github.com/rust-lang/rust-analyzer/issues/6174> /// <https://github.com/rust-lang/rust-analyzer/issues/6174>
pub fn read_version(buffer: &[u8]) -> io::Result<String> { pub fn read_version(obj: &object::File<'_>) -> io::Result<String> {
let dot_rustc = read_section(buffer, ".rustc")?; let dot_rustc = read_section(obj, ".rustc")?;
// check if magic is valid // check if magic is valid
if &dot_rustc[0..4] != b"rust" { if &dot_rustc[0..4] != b"rust" {
@ -151,10 +149,10 @@ pub fn read_version(buffer: &[u8]) -> io::Result<String> {
#[test] #[test]
fn test_version_check() { fn test_version_check() {
let info = read_dylib_info(&unsafe { let info = read_dylib_info(
memmap2::Mmap::map(&std::fs::File::open(crate::proc_macro_test_dylib_path()).unwrap()) &object::File::parse(&*std::fs::read(crate::proc_macro_test_dylib_path()).unwrap())
.unwrap() .unwrap(),
}) )
.unwrap(); .unwrap();
assert_eq!( assert_eq!(