diff --git a/crates/hir-expand/src/builtin_derive_macro.rs b/crates/hir-expand/src/builtin_derive_macro.rs index 78218a8345..c8c3cebb88 100644 --- a/crates/hir-expand/src/builtin_derive_macro.rs +++ b/crates/hir-expand/src/builtin_derive_macro.rs @@ -12,9 +12,7 @@ use crate::{ name::{AsName, Name}, tt::{self, TokenId}, }; -use syntax::ast::{ - self, AstNode, FieldList, HasAttrs, HasGenericParams, HasModuleItem, HasName, HasTypeBounds, -}; +use syntax::ast::{self, AstNode, FieldList, HasAttrs, HasGenericParams, HasName, HasTypeBounds}; use crate::{db::ExpandDatabase, name, quote, ExpandError, ExpandResult, MacroCallId}; @@ -30,12 +28,13 @@ macro_rules! register_builtin { &self, db: &dyn ExpandDatabase, id: MacroCallId, - tt: &tt::Subtree, + tt: &ast::Adt, + token_map: &TokenMap, ) -> ExpandResult { let expander = match *self { $( BuiltinDeriveExpander::$trait => $expand, )* }; - expander(db, id, tt) + expander(db, id, tt, token_map) } fn find_by_name(name: &name::Name) -> Option { @@ -118,13 +117,13 @@ impl VariantShape { } } - fn from(value: Option, token_map: &TokenMap) -> Result { + fn from(tm: &TokenMap, value: Option) -> Result { let r = match value { None => VariantShape::Unit, Some(FieldList::RecordFieldList(it)) => VariantShape::Struct( it.fields() .map(|it| it.name()) - .map(|it| name_to_token(token_map, it)) + .map(|it| name_to_token(tm, it)) .collect::>()?, ), Some(FieldList::TupleFieldList(it)) => VariantShape::Tuple(it.fields().count()), @@ -190,25 +189,12 @@ struct BasicAdtInfo { associated_types: Vec, } -fn parse_adt(tt: &tt::Subtree) -> Result { - let (parsed, token_map) = mbe::token_tree_to_syntax_node(tt, mbe::TopEntryPoint::MacroItems); - let macro_items = ast::MacroItems::cast(parsed.syntax_node()).ok_or_else(|| { - debug!("derive node didn't parse"); - ExpandError::other("invalid item definition") - })?; - let item = macro_items.items().next().ok_or_else(|| { - debug!("no module item parsed"); - ExpandError::other("no item found") - })?; - let adt = ast::Adt::cast(item.syntax().clone()).ok_or_else(|| { - debug!("expected adt, found: {:?}", item); - ExpandError::other("expected struct, enum or union") - })?; +fn parse_adt(tm: &TokenMap, adt: &ast::Adt) -> Result { let (name, generic_param_list, shape) = match &adt { ast::Adt::Struct(it) => ( it.name(), it.generic_param_list(), - AdtShape::Struct(VariantShape::from(it.field_list(), &token_map)?), + AdtShape::Struct(VariantShape::from(tm, it.field_list())?), ), ast::Adt::Enum(it) => { let default_variant = it @@ -227,8 +213,8 @@ fn parse_adt(tt: &tt::Subtree) -> Result { .flat_map(|it| it.variants()) .map(|it| { Ok(( - name_to_token(&token_map, it.name())?, - VariantShape::from(it.field_list(), &token_map)?, + name_to_token(tm, it.name())?, + VariantShape::from(tm, it.field_list())?, )) }) .collect::>()?, @@ -298,7 +284,7 @@ fn parse_adt(tt: &tt::Subtree) -> Result { }) .map(|it| mbe::syntax_node_to_token_tree(it.syntax()).0) .collect(); - let name_token = name_to_token(&token_map, name)?; + let name_token = name_to_token(&tm, name)?; Ok(BasicAdtInfo { name: name_token, shape, param_types, associated_types }) } @@ -345,11 +331,12 @@ fn name_to_token(token_map: &TokenMap, name: Option) -> Result tt::Subtree, ) -> ExpandResult { - let info = match parse_adt(tt) { + let info = match parse_adt(tm, tt) { Ok(info) => info, Err(e) => return ExpandResult::new(tt::Subtree::empty(), e), }; @@ -405,19 +392,21 @@ fn find_builtin_crate(db: &dyn ExpandDatabase, id: MacroCallId) -> tt::TokenTree fn copy_expand( db: &dyn ExpandDatabase, id: MacroCallId, - tt: &tt::Subtree, + tt: &ast::Adt, + tm: &TokenMap, ) -> ExpandResult { let krate = find_builtin_crate(db, id); - expand_simple_derive(tt, quote! { #krate::marker::Copy }, |_| quote! {}) + expand_simple_derive(tt, tm, quote! { #krate::marker::Copy }, |_| quote! {}) } fn clone_expand( db: &dyn ExpandDatabase, id: MacroCallId, - tt: &tt::Subtree, + tt: &ast::Adt, + tm: &TokenMap, ) -> ExpandResult { let krate = find_builtin_crate(db, id); - expand_simple_derive(tt, quote! { #krate::clone::Clone }, |adt| { + expand_simple_derive(tt, tm, quote! { #krate::clone::Clone }, |adt| { if matches!(adt.shape, AdtShape::Union) { let star = tt::Punct { char: '*', @@ -479,10 +468,11 @@ fn and_and() -> ::tt::Subtree { fn default_expand( db: &dyn ExpandDatabase, id: MacroCallId, - tt: &tt::Subtree, + tt: &ast::Adt, + tm: &TokenMap, ) -> ExpandResult { let krate = &find_builtin_crate(db, id); - expand_simple_derive(tt, quote! { #krate::default::Default }, |adt| { + expand_simple_derive(tt, tm, quote! { #krate::default::Default }, |adt| { let body = match &adt.shape { AdtShape::Struct(fields) => { let name = &adt.name; @@ -518,10 +508,11 @@ fn default_expand( fn debug_expand( db: &dyn ExpandDatabase, id: MacroCallId, - tt: &tt::Subtree, + tt: &ast::Adt, + tm: &TokenMap, ) -> ExpandResult { let krate = &find_builtin_crate(db, id); - expand_simple_derive(tt, quote! { #krate::fmt::Debug }, |adt| { + expand_simple_derive(tt, tm, quote! { #krate::fmt::Debug }, |adt| { let for_variant = |name: String, v: &VariantShape| match v { VariantShape::Struct(fields) => { let for_fields = fields.iter().map(|it| { @@ -598,10 +589,11 @@ fn debug_expand( fn hash_expand( db: &dyn ExpandDatabase, id: MacroCallId, - tt: &tt::Subtree, + tt: &ast::Adt, + tm: &TokenMap, ) -> ExpandResult { let krate = &find_builtin_crate(db, id); - expand_simple_derive(tt, quote! { #krate::hash::Hash }, |adt| { + expand_simple_derive(tt, tm, quote! { #krate::hash::Hash }, |adt| { if matches!(adt.shape, AdtShape::Union) { // FIXME: Return expand error here return quote! {}; @@ -646,19 +638,21 @@ fn hash_expand( fn eq_expand( db: &dyn ExpandDatabase, id: MacroCallId, - tt: &tt::Subtree, + tt: &ast::Adt, + tm: &TokenMap, ) -> ExpandResult { let krate = find_builtin_crate(db, id); - expand_simple_derive(tt, quote! { #krate::cmp::Eq }, |_| quote! {}) + expand_simple_derive(tt, tm, quote! { #krate::cmp::Eq }, |_| quote! {}) } fn partial_eq_expand( db: &dyn ExpandDatabase, id: MacroCallId, - tt: &tt::Subtree, + tt: &ast::Adt, + tm: &TokenMap, ) -> ExpandResult { let krate = find_builtin_crate(db, id); - expand_simple_derive(tt, quote! { #krate::cmp::PartialEq }, |adt| { + expand_simple_derive(tt, tm, quote! { #krate::cmp::PartialEq }, |adt| { if matches!(adt.shape, AdtShape::Union) { // FIXME: Return expand error here return quote! {}; @@ -722,10 +716,11 @@ fn self_and_other_patterns( fn ord_expand( db: &dyn ExpandDatabase, id: MacroCallId, - tt: &tt::Subtree, + tt: &ast::Adt, + tm: &TokenMap, ) -> ExpandResult { let krate = &find_builtin_crate(db, id); - expand_simple_derive(tt, quote! { #krate::cmp::Ord }, |adt| { + expand_simple_derive(tt, tm, quote! { #krate::cmp::Ord }, |adt| { fn compare( krate: &tt::TokenTree, left: tt::Subtree, @@ -786,10 +781,11 @@ fn ord_expand( fn partial_ord_expand( db: &dyn ExpandDatabase, id: MacroCallId, - tt: &tt::Subtree, + tt: &ast::Adt, + tm: &TokenMap, ) -> ExpandResult { let krate = &find_builtin_crate(db, id); - expand_simple_derive(tt, quote! { #krate::cmp::PartialOrd }, |adt| { + expand_simple_derive(tt, tm, quote! { #krate::cmp::PartialOrd }, |adt| { fn compare( krate: &tt::TokenTree, left: tt::Subtree, diff --git a/crates/hir-expand/src/db.rs b/crates/hir-expand/src/db.rs index 3f08a09869..f528dfed31 100644 --- a/crates/hir-expand/src/db.rs +++ b/crates/hir-expand/src/db.rs @@ -55,7 +55,9 @@ impl TokenExpander { TokenExpander::Builtin(it) => it.expand(db, id, tt).map_err(Into::into), TokenExpander::BuiltinEager(it) => it.expand(db, id, tt).map_err(Into::into), TokenExpander::BuiltinAttr(it) => it.expand(db, id, tt), - TokenExpander::BuiltinDerive(it) => it.expand(db, id, tt), + TokenExpander::BuiltinDerive(_) => { + unreachable!("builtin derives should be expanded manually") + } TokenExpander::ProcMacro(_) => { unreachable!("ExpandDatabase::expand_proc_macro should be used for proc macros") } @@ -232,6 +234,11 @@ pub fn expand_speculative( MacroDefKind::BuiltInAttr(BuiltinAttrExpander::Derive, _) => { pseudo_derive_attr_expansion(&tt, attr_arg.as_ref()?) } + MacroDefKind::BuiltInDerive(expander, ..) => { + // this cast is a bit sus, can we avoid losing the typedness here? + let adt = ast::Adt::cast(speculative_args.clone()).unwrap(); + expander.expand(db, actual_macro_call, &adt, &spec_args_tmap) + } _ => macro_def.expand(db, actual_macro_call, &tt), }; @@ -333,6 +340,9 @@ fn macro_arg( Some(Arc::new((tt, tmap, fixups.undo_info))) } +/// Certain macro calls expect some nodes in the input to be preprocessed away, namely: +/// - derives expect all `#[derive(..)]` invocations up to the currently invoked one to be stripped +/// - attributes expect the invoking attribute to be stripped fn censor_for_macro_input(loc: &MacroCallLoc, node: &SyntaxNode) -> FxHashSet { // FIXME: handle `cfg_attr` (|| { @@ -451,37 +461,59 @@ fn macro_expand(db: &dyn ExpandDatabase, id: MacroCallId) -> ExpandResult return db.expand_proc_macro(id), + MacroDefKind::BuiltInDerive(expander, ..) => { + let arg = db.macro_arg_text(id).unwrap(); - let expander = match db.macro_def(loc.def) { - Ok(it) => it, - // FIXME: We should make sure to enforce a variant that invalid macro - // definitions do not get expanders that could reach this call path! - Err(err) => { - return ExpandResult { - value: Arc::new(tt::Subtree { - delimiter: tt::Delimiter::UNSPECIFIED, - token_trees: vec![], - }), - err: Some(ExpandError::other(format!("invalid macro definition: {err}"))), - } + let node = SyntaxNode::new_root(arg); + let censor = censor_for_macro_input(&loc, &node); + let mut fixups = fixup::fixup_syntax(&node); + fixups.replace.extend(censor.into_iter().map(|node| (node.into(), Vec::new()))); + let (tmap, _) = mbe::syntax_node_to_token_map_with_modifications( + &node, + fixups.token_map, + fixups.next_id, + fixups.replace, + fixups.append, + ); + + // this cast is a bit sus, can we avoid losing the typedness here? + let adt = ast::Adt::cast(node).unwrap(); + (expander.expand(db, id, &adt, &tmap), Some((tmap, fixups.undo_info))) + } + _ => { + let expander = match db.macro_def(loc.def) { + Ok(it) => it, + // FIXME: We should make sure to enforce a variant that invalid macro + // definitions do not get expanders that could reach this call path! + Err(err) => { + return ExpandResult { + value: Arc::new(tt::Subtree { + delimiter: tt::Delimiter::UNSPECIFIED, + token_trees: vec![], + }), + err: Some(ExpandError::other(format!("invalid macro definition: {err}"))), + } + } + }; + let Some(macro_arg) = db.macro_arg(id) else { + return ExpandResult { + value: Arc::new(tt::Subtree { + delimiter: tt::Delimiter::UNSPECIFIED, + token_trees: Vec::new(), + }), + // FIXME: We should make sure to enforce an invariant that invalid macro + // calls do not reach this call path! + err: Some(ExpandError::other("invalid token tree")), + }; + }; + let (arg, arg_tm, undo_info) = &*macro_arg; + let mut res = expander.expand(db, id, arg); + fixup::reverse_fixups(&mut res.value, arg_tm, undo_info); + (res, None) } }; - let Some(macro_arg) = db.macro_arg(id) else { - return ExpandResult { - value: Arc::new(tt::Subtree { - delimiter: tt::Delimiter::UNSPECIFIED, - token_trees: Vec::new(), - }), - // FIXME: We should make sure to enforce an invariant that invalid macro - // calls do not reach this call path! - err: Some(ExpandError::other("invalid token tree")), - }; - }; - let (arg_tt, arg_tm, undo_info) = &*macro_arg; - let ExpandResult { value: mut tt, mut err } = expander.expand(db, id, arg_tt); if let Some(EagerCallInfo { error, .. }) = loc.eager.as_deref() { // FIXME: We should report both errors! @@ -493,7 +525,9 @@ fn macro_expand(db: &dyn ExpandDatabase, id: MacroCallId) -> ExpandResult TokenMap { + syntax_node_to_token_map_with_modifications( + node, + Default::default(), + 0, + Default::default(), + Default::default(), + ) + .0 +} + +/// Convert the syntax node to a `TokenTree` (what macro will consume) +/// with the censored range excluded. +pub fn syntax_node_to_token_map_with_modifications( + node: &SyntaxNode, + existing_token_map: TokenMap, + next_id: u32, + replace: FxHashMap>, + append: FxHashMap>, +) -> (TokenMap, u32) { + let global_offset = node.text_range().start(); + let mut c = Converter::new(node, global_offset, existing_token_map, next_id, replace, append); + collect_tokens(&mut c); + c.id_alloc.map.shrink_to_fit(); + always!(c.replace.is_empty(), "replace: {:?}", c.replace); + always!(c.append.is_empty(), "append: {:?}", c.append); + (c.id_alloc.map, c.id_alloc.next_id) +} + #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub struct SyntheticTokenId(pub u32); @@ -327,6 +358,111 @@ fn convert_tokens(conv: &mut C) -> tt::Subtree { } } +fn collect_tokens(conv: &mut C) { + struct StackEntry { + idx: usize, + open_range: TextRange, + delimiter: tt::DelimiterKind, + } + + let entry = StackEntry { + delimiter: tt::DelimiterKind::Invisible, + // never used (delimiter is `None`) + idx: !0, + open_range: TextRange::empty(TextSize::of('.')), + }; + let mut stack = NonEmptyVec::new(entry); + + loop { + let StackEntry { delimiter, .. } = stack.last_mut(); + let (token, range) = match conv.bump() { + Some(it) => it, + None => break, + }; + let synth_id = token.synthetic_id(conv); + + let kind = token.kind(conv); + if kind == COMMENT { + // Since `convert_doc_comment` can fail, we need to peek the next id, so that we can + // figure out which token id to use for the doc comment, if it is converted successfully. + let next_id = conv.id_alloc().peek_next_id(); + if let Some(_tokens) = conv.convert_doc_comment(&token, next_id) { + let id = conv.id_alloc().alloc(range, synth_id); + debug_assert_eq!(id, next_id); + } + continue; + } + if kind.is_punct() && kind != UNDERSCORE { + if synth_id.is_none() { + assert_eq!(range.len(), TextSize::of('.')); + } + + let expected = match delimiter { + tt::DelimiterKind::Parenthesis => Some(T![')']), + tt::DelimiterKind::Brace => Some(T!['}']), + tt::DelimiterKind::Bracket => Some(T![']']), + tt::DelimiterKind::Invisible => None, + }; + + if let Some(expected) = expected { + if kind == expected { + if let Some(entry) = stack.pop() { + conv.id_alloc().close_delim(entry.idx, Some(range)); + } + continue; + } + } + + let delim = match kind { + T!['('] => Some(tt::DelimiterKind::Parenthesis), + T!['{'] => Some(tt::DelimiterKind::Brace), + T!['['] => Some(tt::DelimiterKind::Bracket), + _ => None, + }; + + if let Some(kind) = delim { + let (_id, idx) = conv.id_alloc().open_delim(range, synth_id); + + stack.push(StackEntry { idx, open_range: range, delimiter: kind }); + continue; + } + + conv.id_alloc().alloc(range, synth_id); + } else { + macro_rules! make_leaf { + ($i:ident) => {{ + conv.id_alloc().alloc(range, synth_id); + }}; + } + match kind { + T![true] | T![false] => make_leaf!(Ident), + IDENT => make_leaf!(Ident), + UNDERSCORE => make_leaf!(Ident), + k if k.is_keyword() => make_leaf!(Ident), + k if k.is_literal() => make_leaf!(Literal), + LIFETIME_IDENT => { + let char_unit = TextSize::of('\''); + let r = TextRange::at(range.start(), char_unit); + conv.id_alloc().alloc(r, synth_id); + + let r = TextRange::at(range.start() + char_unit, range.len() - char_unit); + conv.id_alloc().alloc(r, synth_id); + continue; + } + _ => continue, + }; + }; + + // If we get here, we've consumed all input tokens. + // We might have more than one subtree in the stack, if the delimiters are improperly balanced. + // Merge them so we're left with one. + while let Some(entry) = stack.pop() { + conv.id_alloc().close_delim(entry.idx, None); + conv.id_alloc().alloc(entry.open_range, None); + } + } +} + fn is_single_token_op(kind: SyntaxKind) -> bool { matches!( kind,