Merge pull request #18995 from alibektas/12210

fix: Lower range pattern bounds to expressions
This commit is contained in:
Lukas Wirth 2025-02-12 11:58:33 +00:00 committed by GitHub
commit 622ef64f93
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 187 additions and 117 deletions

View file

@ -112,9 +112,9 @@ pub struct ExpressionStoreSourceMap {
// AST expressions can create patterns in destructuring assignments. Therefore, `ExprSource` can also map // AST expressions can create patterns in destructuring assignments. Therefore, `ExprSource` can also map
// to `PatId`, and `PatId` can also map to `ExprSource` (the other way around is unaffected). // to `PatId`, and `PatId` can also map to `ExprSource` (the other way around is unaffected).
expr_map: FxHashMap<ExprSource, ExprOrPatId>, expr_map: FxHashMap<ExprSource, ExprOrPatId>,
expr_map_back: ArenaMap<ExprId, ExprSource>, expr_map_back: ArenaMap<ExprId, ExprOrPatSource>,
pat_map: FxHashMap<PatSource, PatId>, pat_map: FxHashMap<PatSource, ExprOrPatId>,
pat_map_back: ArenaMap<PatId, ExprOrPatSource>, pat_map_back: ArenaMap<PatId, ExprOrPatSource>,
label_map: FxHashMap<LabelSource, LabelId>, label_map: FxHashMap<LabelSource, LabelId>,
@ -606,12 +606,12 @@ impl Index<TypeRefId> for ExpressionStore {
impl ExpressionStoreSourceMap { impl ExpressionStoreSourceMap {
pub fn expr_or_pat_syntax(&self, id: ExprOrPatId) -> Result<ExprOrPatSource, SyntheticSyntax> { pub fn expr_or_pat_syntax(&self, id: ExprOrPatId) -> Result<ExprOrPatSource, SyntheticSyntax> {
match id { match id {
ExprOrPatId::ExprId(id) => self.expr_syntax(id).map(|it| it.map(AstPtr::wrap_left)), ExprOrPatId::ExprId(id) => self.expr_syntax(id),
ExprOrPatId::PatId(id) => self.pat_syntax(id), ExprOrPatId::PatId(id) => self.pat_syntax(id),
} }
} }
pub fn expr_syntax(&self, expr: ExprId) -> Result<ExprSource, SyntheticSyntax> { pub fn expr_syntax(&self, expr: ExprId) -> Result<ExprOrPatSource, SyntheticSyntax> {
self.expr_map_back.get(expr).cloned().ok_or(SyntheticSyntax) self.expr_map_back.get(expr).cloned().ok_or(SyntheticSyntax)
} }
@ -633,7 +633,7 @@ impl ExpressionStoreSourceMap {
self.pat_map_back.get(pat).cloned().ok_or(SyntheticSyntax) self.pat_map_back.get(pat).cloned().ok_or(SyntheticSyntax)
} }
pub fn node_pat(&self, node: InFile<&ast::Pat>) -> Option<PatId> { pub fn node_pat(&self, node: InFile<&ast::Pat>) -> Option<ExprOrPatId> {
self.pat_map.get(&node.map(AstPtr::new)).cloned() self.pat_map.get(&node.map(AstPtr::new)).cloned()
} }

View file

@ -44,8 +44,8 @@ use crate::{
FormatPlaceholder, FormatSign, FormatTrait, FormatPlaceholder, FormatSign, FormatTrait,
}, },
Array, Binding, BindingAnnotation, BindingId, BindingProblems, CaptureBy, ClosureKind, Array, Binding, BindingAnnotation, BindingId, BindingProblems, CaptureBy, ClosureKind,
Expr, ExprId, Item, Label, LabelId, Literal, LiteralOrConst, MatchArm, Movability, Expr, ExprId, Item, Label, LabelId, Literal, MatchArm, Movability, OffsetOf, Pat, PatId,
OffsetOf, Pat, PatId, RecordFieldPat, RecordLitField, Statement, RecordFieldPat, RecordLitField, Statement,
}, },
item_scope::BuiltinShadowMode, item_scope::BuiltinShadowMode,
lang_item::LangItem, lang_item::LangItem,
@ -1784,23 +1784,33 @@ impl ExprCollector<'_> {
self.collect_macro_call(call, macro_ptr, true, |this, expanded_pat| { self.collect_macro_call(call, macro_ptr, true, |this, expanded_pat| {
this.collect_pat_opt(expanded_pat, binding_list) this.collect_pat_opt(expanded_pat, binding_list)
}); });
self.source_map.pat_map.insert(src, pat); self.source_map.pat_map.insert(src, pat.into());
return pat; return pat;
} }
None => Pat::Missing, None => Pat::Missing,
}, },
// FIXME: implement in a way that also builds source map and calculates assoc resolutions in type inference.
ast::Pat::RangePat(p) => { ast::Pat::RangePat(p) => {
let mut range_part_lower = |p: Option<ast::Pat>| { let mut range_part_lower = |p: Option<ast::Pat>| -> Option<ExprId> {
p.and_then(|it| match &it { p.and_then(|it| {
ast::Pat::LiteralPat(it) => { let ptr = PatPtr::new(&it);
Some(Box::new(LiteralOrConst::Literal(pat_literal_to_hir(it)?.0))) match &it {
} ast::Pat::LiteralPat(it) => Some(self.alloc_expr_from_pat(
pat @ (ast::Pat::IdentPat(_) | ast::Pat::PathPat(_)) => { Expr::Literal(pat_literal_to_hir(it)?.0),
let subpat = self.collect_pat(pat.clone(), binding_list); ptr,
Some(Box::new(LiteralOrConst::Const(subpat))) )),
} ast::Pat::IdentPat(ident) if ident.is_simple_ident() => ident
.name()
.map(|name| name.as_name())
.map(Path::from)
.map(|path| self.alloc_expr_from_pat(Expr::Path(path), ptr)),
ast::Pat::PathPat(p) => p
.path()
.and_then(|path| self.parse_path(path))
.map(|parsed| self.alloc_expr_from_pat(Expr::Path(parsed), ptr)),
// We only need to handle literal, ident (if bare) and path patterns here,
// as any other pattern as a range pattern operand is semantically invalid.
_ => None, _ => None,
}
}) })
}; };
let start = range_part_lower(p.start()); let start = range_part_lower(p.start());
@ -1863,7 +1873,7 @@ impl ExprCollector<'_> {
} }
}); });
if let Some(pat) = pat.left() { if let Some(pat) = pat.left() {
self.source_map.pat_map.insert(src, pat); self.source_map.pat_map.insert(src, pat.into());
} }
pat pat
} }
@ -2490,7 +2500,7 @@ impl ExprCollector<'_> {
fn alloc_expr(&mut self, expr: Expr, ptr: ExprPtr) -> ExprId { fn alloc_expr(&mut self, expr: Expr, ptr: ExprPtr) -> ExprId {
let src = self.expander.in_file(ptr); let src = self.expander.in_file(ptr);
let id = self.store.exprs.alloc(expr); let id = self.store.exprs.alloc(expr);
self.source_map.expr_map_back.insert(id, src); self.source_map.expr_map_back.insert(id, src.map(AstPtr::wrap_left));
self.source_map.expr_map.insert(src, id.into()); self.source_map.expr_map.insert(src, id.into());
id id
} }
@ -2502,7 +2512,7 @@ impl ExprCollector<'_> {
fn alloc_expr_desugared_with_ptr(&mut self, expr: Expr, ptr: ExprPtr) -> ExprId { fn alloc_expr_desugared_with_ptr(&mut self, expr: Expr, ptr: ExprPtr) -> ExprId {
let src = self.expander.in_file(ptr); let src = self.expander.in_file(ptr);
let id = self.store.exprs.alloc(expr); let id = self.store.exprs.alloc(expr);
self.source_map.expr_map_back.insert(id, src); self.source_map.expr_map_back.insert(id, src.map(AstPtr::wrap_left));
// We intentionally don't fill this as it could overwrite a non-desugared entry // We intentionally don't fill this as it could overwrite a non-desugared entry
// self.source_map.expr_map.insert(src, id); // self.source_map.expr_map.insert(src, id);
id id
@ -2526,11 +2536,20 @@ impl ExprCollector<'_> {
self.source_map.pat_map_back.insert(id, src.map(AstPtr::wrap_left)); self.source_map.pat_map_back.insert(id, src.map(AstPtr::wrap_left));
id id
} }
fn alloc_expr_from_pat(&mut self, expr: Expr, ptr: PatPtr) -> ExprId {
let src = self.expander.in_file(ptr);
let id = self.store.exprs.alloc(expr);
self.source_map.pat_map.insert(src, id.into());
self.source_map.expr_map_back.insert(id, src.map(AstPtr::wrap_right));
id
}
fn alloc_pat(&mut self, pat: Pat, ptr: PatPtr) -> PatId { fn alloc_pat(&mut self, pat: Pat, ptr: PatPtr) -> PatId {
let src = self.expander.in_file(ptr); let src = self.expander.in_file(ptr);
let id = self.store.pats.alloc(pat); let id = self.store.pats.alloc(pat);
self.source_map.pat_map_back.insert(id, src.map(AstPtr::wrap_right)); self.source_map.pat_map_back.insert(id, src.map(AstPtr::wrap_right));
self.source_map.pat_map.insert(src, id); self.source_map.pat_map.insert(src, id.into());
id id
} }
// FIXME: desugared pats don't have ptr, that's wrong and should be fixed somehow. // FIXME: desugared pats don't have ptr, that's wrong and should be fixed somehow.

View file

@ -6,10 +6,7 @@ use itertools::Itertools;
use span::Edition; use span::Edition;
use crate::{ use crate::{
hir::{ hir::{Array, BindingAnnotation, CaptureBy, ClosureKind, Literal, Movability, Statement},
Array, BindingAnnotation, CaptureBy, ClosureKind, Literal, LiteralOrConst, Movability,
Statement,
},
pretty::{print_generic_args, print_path, print_type_ref}, pretty::{print_generic_args, print_path, print_type_ref},
}; };
@ -656,11 +653,11 @@ impl Printer<'_> {
} }
Pat::Range { start, end } => { Pat::Range { start, end } => {
if let Some(start) = start { if let Some(start) = start {
self.print_literal_or_const(start); self.print_expr(*start);
} }
w!(self, "..="); w!(self, "..=");
if let Some(end) = end { if let Some(end) = end {
self.print_literal_or_const(end); self.print_expr(*end);
} }
} }
Pat::Slice { prefix, slice, suffix } => { Pat::Slice { prefix, slice, suffix } => {
@ -757,13 +754,6 @@ impl Printer<'_> {
} }
} }
fn print_literal_or_const(&mut self, literal_or_const: &LiteralOrConst) {
match literal_or_const {
LiteralOrConst::Literal(l) => self.print_literal(l),
LiteralOrConst::Const(c) => self.print_pat(*c),
}
}
fn print_literal(&mut self, literal: &Literal) { fn print_literal(&mut self, literal: &Literal) {
match literal { match literal {
Literal::String(it) => w!(self, "{:?}", it), Literal::String(it) => w!(self, "{:?}", it),

View file

@ -1,11 +1,10 @@
mod block; mod block;
use crate::{hir::MatchArm, test_db::TestDB, ModuleDefId};
use expect_test::{expect, Expect}; use expect_test::{expect, Expect};
use la_arena::RawIdx; use la_arena::RawIdx;
use test_fixture::WithFixture; use test_fixture::WithFixture;
use crate::{test_db::TestDB, ModuleDefId};
use super::*; use super::*;
fn lower(#[rust_analyzer::rust_fixture] ra_fixture: &str) -> (TestDB, Arc<Body>, DefWithBodyId) { fn lower(#[rust_analyzer::rust_fixture] ra_fixture: &str) -> (TestDB, Arc<Body>, DefWithBodyId) {
@ -460,3 +459,45 @@ async fn foo(a: (), b: i32) -> u32 {
expect!["fn foo(<28>: (), <20>: i32) -> impl ::core::future::Future::<Output = u32> <20>"] expect!["fn foo(<28>: (), <20>: i32) -> impl ::core::future::Future::<Output = u32> <20>"]
.assert_eq(&printed); .assert_eq(&printed);
} }
#[test]
fn range_bounds_are_hir_exprs() {
let (_, body, _) = lower(
r#"
pub const L: i32 = 6;
mod x {
pub const R: i32 = 100;
}
const fn f(x: i32) -> i32 {
match x {
-1..=5 => x * 10,
L..=x::R => x * 100,
_ => x,
}
}"#,
);
let mtch_arms = body
.exprs
.iter()
.find_map(|(_, expr)| {
if let Expr::Match { arms, .. } = expr {
return Some(arms);
}
None
})
.unwrap();
let MatchArm { pat, .. } = mtch_arms[1];
match body.pats[pat] {
Pat::Range { start, end } => {
let hir_start = &body.exprs[start.unwrap()];
let hir_end = &body.exprs[end.unwrap()];
assert!(matches!(hir_start, Expr::Path { .. }));
assert!(matches!(hir_end, Expr::Path { .. }));
}
_ => {}
}
}

View file

@ -55,12 +55,20 @@ impl ExprOrPatId {
} }
} }
pub fn is_expr(&self) -> bool {
matches!(self, Self::ExprId(_))
}
pub fn as_pat(self) -> Option<PatId> { pub fn as_pat(self) -> Option<PatId> {
match self { match self {
Self::PatId(v) => Some(v), Self::PatId(v) => Some(v),
_ => None, _ => None,
} }
} }
pub fn is_pat(&self) -> bool {
matches!(self, Self::PatId(_))
}
} }
stdx::impl_from!(ExprId, PatId for ExprOrPatId); stdx::impl_from!(ExprId, PatId for ExprOrPatId);
@ -571,8 +579,8 @@ pub enum Pat {
ellipsis: bool, ellipsis: bool,
}, },
Range { Range {
start: Option<Box<LiteralOrConst>>, start: Option<ExprId>,
end: Option<Box<LiteralOrConst>>, end: Option<ExprId>,
}, },
Slice { Slice {
prefix: Box<[PatId]>, prefix: Box<[PatId]>,

View file

@ -440,7 +440,9 @@ impl ExprValidator {
return; return;
}; };
let root = source_ptr.file_syntax(db.upcast()); let root = source_ptr.file_syntax(db.upcast());
let ast::Expr::IfExpr(if_expr) = source_ptr.value.to_node(&root) else { let either::Left(ast::Expr::IfExpr(if_expr)) =
source_ptr.value.to_node(&root)
else {
return; return;
}; };
let mut top_if_expr = if_expr; let mut top_if_expr = if_expr;

View file

@ -8,8 +8,8 @@ use hir_def::{
data::adt::{StructKind, VariantData}, data::adt::{StructKind, VariantData},
expr_store::{Body, HygieneId}, expr_store::{Body, HygieneId},
hir::{ hir::{
ArithOp, Array, BinaryOp, BindingAnnotation, BindingId, ExprId, LabelId, Literal, ArithOp, Array, BinaryOp, BindingAnnotation, BindingId, ExprId, LabelId, Literal, MatchArm,
LiteralOrConst, MatchArm, Pat, PatId, RecordFieldPat, RecordLitField, Pat, PatId, RecordFieldPat, RecordLitField,
}, },
lang_item::{LangItem, LangItemTarget}, lang_item::{LangItem, LangItemTarget},
path::Path, path::Path,
@ -1358,20 +1358,10 @@ impl<'ctx> MirLowerCtx<'ctx> {
Ok(()) Ok(())
} }
fn lower_literal_or_const_to_operand( fn lower_literal_or_const_to_operand(&mut self, ty: Ty, loc: &ExprId) -> Result<Operand> {
&mut self, match &self.body.exprs[*loc] {
ty: Ty, Expr::Literal(l) => self.lower_literal_to_operand(ty, l),
loc: &LiteralOrConst, Expr::Path(c) => {
) -> Result<Operand> {
match loc {
LiteralOrConst::Literal(l) => self.lower_literal_to_operand(ty, l),
LiteralOrConst::Const(c) => {
let c = match &self.body.pats[*c] {
Pat::Path(p) => p,
_ => not_supported!(
"only `char` and numeric types are allowed in range patterns"
),
};
let edition = self.edition(); let edition = self.edition();
let unresolved_name = let unresolved_name =
|| MirLowerError::unresolved_path(self.db, c, edition, &self.body.types); || MirLowerError::unresolved_path(self.db, c, edition, &self.body.types);
@ -1392,6 +1382,9 @@ impl<'ctx> MirLowerCtx<'ctx> {
} }
} }
} }
_ => {
not_supported!("only `char` and numeric types are allowed in range patterns");
}
} }
} }

View file

@ -1,6 +1,6 @@
//! MIR lowering for patterns //! MIR lowering for patterns
use hir_def::{hir::LiteralOrConst, AssocItemId}; use hir_def::{hir::ExprId, AssocItemId};
use crate::{ use crate::{
mir::{ mir::{
@ -207,7 +207,7 @@ impl MirLowerCtx<'_> {
)? )?
} }
Pat::Range { start, end } => { Pat::Range { start, end } => {
let mut add_check = |l: &LiteralOrConst, binop| -> Result<()> { let mut add_check = |l: &ExprId, binop| -> Result<()> {
let lv = let lv =
self.lower_literal_or_const_to_operand(self.infer[pattern].clone(), l)?; self.lower_literal_or_const_to_operand(self.infer[pattern].clone(), l)?;
let else_target = *current_else.get_or_insert_with(|| self.new_basic_block()); let else_target = *current_else.get_or_insert_with(|| self.new_basic_block());

View file

@ -6,6 +6,7 @@
use cfg::{CfgExpr, CfgOptions}; use cfg::{CfgExpr, CfgOptions};
use either::Either; use either::Either;
use hir_def::{ use hir_def::{
expr_store::ExprOrPatPtr,
hir::ExprOrPatId, hir::ExprOrPatId,
path::{hir_segment_to_ast_segment, ModPath}, path::{hir_segment_to_ast_segment, ModPath},
type_ref::TypesSourceMap, type_ref::TypesSourceMap,
@ -115,14 +116,14 @@ diagnostics![
#[derive(Debug)] #[derive(Debug)]
pub struct BreakOutsideOfLoop { pub struct BreakOutsideOfLoop {
pub expr: InFile<AstPtr<ast::Expr>>, pub expr: InFile<ExprOrPatPtr>,
pub is_break: bool, pub is_break: bool,
pub bad_value_break: bool, pub bad_value_break: bool,
} }
#[derive(Debug)] #[derive(Debug)]
pub struct TypedHole { pub struct TypedHole {
pub expr: InFile<AstPtr<ast::Expr>>, pub expr: InFile<ExprOrPatPtr>,
pub expected: Type, pub expected: Type,
} }
@ -221,26 +222,26 @@ pub struct NoSuchField {
#[derive(Debug)] #[derive(Debug)]
pub struct PrivateAssocItem { pub struct PrivateAssocItem {
pub expr_or_pat: InFile<AstPtr<Either<ast::Expr, ast::Pat>>>, pub expr_or_pat: InFile<ExprOrPatPtr>,
pub item: AssocItem, pub item: AssocItem,
} }
#[derive(Debug)] #[derive(Debug)]
pub struct MismatchedTupleStructPatArgCount { pub struct MismatchedTupleStructPatArgCount {
pub expr_or_pat: InFile<AstPtr<Either<ast::Expr, ast::Pat>>>, pub expr_or_pat: InFile<ExprOrPatPtr>,
pub expected: usize, pub expected: usize,
pub found: usize, pub found: usize,
} }
#[derive(Debug)] #[derive(Debug)]
pub struct ExpectedFunction { pub struct ExpectedFunction {
pub call: InFile<AstPtr<ast::Expr>>, pub call: InFile<ExprOrPatPtr>,
pub found: Type, pub found: Type,
} }
#[derive(Debug)] #[derive(Debug)]
pub struct UnresolvedField { pub struct UnresolvedField {
pub expr: InFile<AstPtr<ast::Expr>>, pub expr: InFile<ExprOrPatPtr>,
pub receiver: Type, pub receiver: Type,
pub name: Name, pub name: Name,
pub method_with_same_name_exists: bool, pub method_with_same_name_exists: bool,
@ -248,7 +249,7 @@ pub struct UnresolvedField {
#[derive(Debug)] #[derive(Debug)]
pub struct UnresolvedMethodCall { pub struct UnresolvedMethodCall {
pub expr: InFile<AstPtr<ast::Expr>>, pub expr: InFile<ExprOrPatPtr>,
pub receiver: Type, pub receiver: Type,
pub name: Name, pub name: Name,
pub field_with_same_name: Option<Type>, pub field_with_same_name: Option<Type>,
@ -257,17 +258,17 @@ pub struct UnresolvedMethodCall {
#[derive(Debug)] #[derive(Debug)]
pub struct UnresolvedAssocItem { pub struct UnresolvedAssocItem {
pub expr_or_pat: InFile<AstPtr<Either<ast::Expr, ast::Pat>>>, pub expr_or_pat: InFile<ExprOrPatPtr>,
} }
#[derive(Debug)] #[derive(Debug)]
pub struct UnresolvedIdent { pub struct UnresolvedIdent {
pub node: InFile<(AstPtr<Either<ast::Expr, ast::Pat>>, Option<TextRange>)>, pub node: InFile<(ExprOrPatPtr, Option<TextRange>)>,
} }
#[derive(Debug)] #[derive(Debug)]
pub struct PrivateField { pub struct PrivateField {
pub expr: InFile<AstPtr<ast::Expr>>, pub expr: InFile<ExprOrPatPtr>,
pub field: Field, pub field: Field,
} }
@ -280,7 +281,7 @@ pub enum UnsafeLint {
#[derive(Debug)] #[derive(Debug)]
pub struct MissingUnsafe { pub struct MissingUnsafe {
pub node: InFile<AstPtr<Either<ast::Expr, ast::Pat>>>, pub node: InFile<ExprOrPatPtr>,
pub lint: UnsafeLint, pub lint: UnsafeLint,
pub reason: UnsafetyReason, pub reason: UnsafetyReason,
} }
@ -302,7 +303,7 @@ pub struct ReplaceFilterMapNextWithFindMap {
#[derive(Debug)] #[derive(Debug)]
pub struct MismatchedArgCount { pub struct MismatchedArgCount {
pub call_expr: InFile<AstPtr<ast::Expr>>, pub call_expr: InFile<ExprOrPatPtr>,
pub expected: usize, pub expected: usize,
pub found: usize, pub found: usize,
} }
@ -321,7 +322,7 @@ pub struct NonExhaustiveLet {
#[derive(Debug)] #[derive(Debug)]
pub struct TypeMismatch { pub struct TypeMismatch {
pub expr_or_pat: InFile<AstPtr<Either<ast::Expr, ast::Pat>>>, pub expr_or_pat: InFile<ExprOrPatPtr>,
pub expected: Type, pub expected: Type,
pub actual: Type, pub actual: Type,
} }
@ -395,13 +396,13 @@ pub struct RemoveUnnecessaryElse {
#[derive(Debug)] #[derive(Debug)]
pub struct CastToUnsized { pub struct CastToUnsized {
pub expr: InFile<AstPtr<ast::Expr>>, pub expr: InFile<ExprOrPatPtr>,
pub cast_ty: Type, pub cast_ty: Type,
} }
#[derive(Debug)] #[derive(Debug)]
pub struct InvalidCast { pub struct InvalidCast {
pub expr: InFile<AstPtr<ast::Expr>>, pub expr: InFile<ExprOrPatPtr>,
pub error: CastError, pub error: CastError,
pub expr_ty: Type, pub expr_ty: Type,
pub cast_ty: Type, pub cast_ty: Type,
@ -428,9 +429,7 @@ impl AnyDiagnostic {
.collect(); .collect();
let record = match record { let record = match record {
Either::Left(record_expr) => { Either::Left(record_expr) => source_map.expr_syntax(record_expr).ok()?,
source_map.expr_syntax(record_expr).ok()?.map(AstPtr::wrap_left)
}
Either::Right(record_pat) => source_map.pat_syntax(record_pat).ok()?, Either::Right(record_pat) => source_map.pat_syntax(record_pat).ok()?,
}; };
let file = record.file_id; let file = record.file_id;
@ -474,7 +473,7 @@ impl AnyDiagnostic {
return Some( return Some(
ReplaceFilterMapNextWithFindMap { ReplaceFilterMapNextWithFindMap {
file: next_source_ptr.file_id, file: next_source_ptr.file_id,
next_expr: next_source_ptr.value, next_expr: next_source_ptr.value.cast()?,
} }
.into(), .into(),
); );
@ -484,7 +483,9 @@ impl AnyDiagnostic {
match source_map.expr_syntax(match_expr) { match source_map.expr_syntax(match_expr) {
Ok(source_ptr) => { Ok(source_ptr) => {
let root = source_ptr.file_syntax(db.upcast()); let root = source_ptr.file_syntax(db.upcast());
if let ast::Expr::MatchExpr(match_expr) = &source_ptr.value.to_node(&root) { if let Either::Left(ast::Expr::MatchExpr(match_expr)) =
&source_ptr.value.to_node(&root)
{
match match_expr.expr() { match match_expr.expr() {
Some(scrut_expr) if match_expr.match_arm_list().is_some() => { Some(scrut_expr) if match_expr.match_arm_list().is_some() => {
return Some( return Some(
@ -561,7 +562,7 @@ impl AnyDiagnostic {
let pat_syntax = let pat_syntax =
|pat| source_map.pat_syntax(pat).inspect_err(|_| stdx::never!("synthetic syntax")).ok(); |pat| source_map.pat_syntax(pat).inspect_err(|_| stdx::never!("synthetic syntax")).ok();
let expr_or_pat_syntax = |id| match id { let expr_or_pat_syntax = |id| match id {
ExprOrPatId::ExprId(expr) => expr_syntax(expr).map(|it| it.map(AstPtr::wrap_left)), ExprOrPatId::ExprId(expr) => expr_syntax(expr),
ExprOrPatId::PatId(pat) => pat_syntax(pat), ExprOrPatId::PatId(pat) => pat_syntax(pat),
}; };
Some(match d { Some(match d {
@ -633,7 +634,7 @@ impl AnyDiagnostic {
&InferenceDiagnostic::UnresolvedIdent { id } => { &InferenceDiagnostic::UnresolvedIdent { id } => {
let node = match id { let node = match id {
ExprOrPatId::ExprId(id) => match source_map.expr_syntax(id) { ExprOrPatId::ExprId(id) => match source_map.expr_syntax(id) {
Ok(syntax) => syntax.map(|it| (it.wrap_left(), None)), Ok(syntax) => syntax.map(|it| (it, None)),
Err(SyntheticSyntax) => source_map Err(SyntheticSyntax) => source_map
.format_args_implicit_capture(id)? .format_args_implicit_capture(id)?
.map(|(node, range)| (node.wrap_left(), Some(range))), .map(|(node, range)| (node.wrap_left(), Some(range))),
@ -652,7 +653,7 @@ impl AnyDiagnostic {
} }
&InferenceDiagnostic::MismatchedTupleStructPatArgCount { pat, expected, found } => { &InferenceDiagnostic::MismatchedTupleStructPatArgCount { pat, expected, found } => {
let expr_or_pat = match pat { let expr_or_pat = match pat {
ExprOrPatId::ExprId(expr) => expr_syntax(expr)?.map(AstPtr::wrap_left), ExprOrPatId::ExprId(expr) => expr_syntax(expr)?,
ExprOrPatId::PatId(pat) => { ExprOrPatId::PatId(pat) => {
let InFile { file_id, value } = pat_syntax(pat)?; let InFile { file_id, value } = pat_syntax(pat)?;

View file

@ -248,7 +248,7 @@ impl HasSource for Param {
let ast @ InFile { file_id, value } = source_map.expr_syntax(expr_id).ok()?; let ast @ InFile { file_id, value } = source_map.expr_syntax(expr_id).ok()?;
let root = db.parse_or_expand(file_id); let root = db.parse_or_expand(file_id);
match value.to_node(&root) { match value.to_node(&root) {
ast::Expr::ClosureExpr(it) => it Either::Left(ast::Expr::ClosureExpr(it)) => it
.param_list()? .param_list()?
.params() .params()
.nth(self.idx) .nth(self.idx)
@ -301,7 +301,7 @@ impl HasSource for InlineAsmOperand {
let root = src.file_syntax(db.upcast()); let root = src.file_syntax(db.upcast());
return src return src
.map(|ast| match ast.to_node(&root) { .map(|ast| match ast.to_node(&root) {
ast::Expr::AsmExpr(asm) => asm Either::Left(ast::Expr::AsmExpr(asm)) => asm
.asm_pieces() .asm_pieces()
.filter_map(|it| match it { .filter_map(|it| match it {
ast::AsmPiece::AsmOperandNamed(it) => Some(it), ast::AsmPiece::AsmOperandNamed(it) => Some(it),

View file

@ -1957,7 +1957,7 @@ impl DefWithBody {
ExprOrPatId::PatId(pat) => source_map.pat_syntax(pat).map(Either::Right), ExprOrPatId::PatId(pat) => source_map.pat_syntax(pat).map(Either::Right),
}; };
let expr_or_pat = match expr_or_pat { let expr_or_pat = match expr_or_pat {
Ok(Either::Left(expr)) => expr.map(AstPtr::wrap_left), Ok(Either::Left(expr)) => expr,
Ok(Either::Right(InFile { file_id, value: pat })) => { Ok(Either::Right(InFile { file_id, value: pat })) => {
// cast from Either<Pat, SelfParam> -> Either<_, Pat> // cast from Either<Pat, SelfParam> -> Either<_, Pat>
let Some(ptr) = AstPtr::try_from_raw(pat.syntax_node_ptr()) else { let Some(ptr) = AstPtr::try_from_raw(pat.syntax_node_ptr()) else {
@ -2003,7 +2003,7 @@ impl DefWithBody {
match source_map.expr_syntax(node) { match source_map.expr_syntax(node) {
Ok(node) => acc.push( Ok(node) => acc.push(
MissingUnsafe { MissingUnsafe {
node: node.map(|it| it.wrap_left()), node,
lint: UnsafeLint::DeprecatedSafe2024, lint: UnsafeLint::DeprecatedSafe2024,
reason: UnsafetyReason::UnsafeFnCall, reason: UnsafetyReason::UnsafeFnCall,
} }
@ -4592,10 +4592,7 @@ impl CaptureUsages {
match span { match span {
mir::MirSpan::ExprId(expr) => { mir::MirSpan::ExprId(expr) => {
if let Ok(expr) = source_map.expr_syntax(expr) { if let Ok(expr) = source_map.expr_syntax(expr) {
result.push(CaptureUsageSource { result.push(CaptureUsageSource { is_ref, source: expr })
is_ref,
source: expr.map(AstPtr::wrap_left),
})
} }
} }
mir::MirSpan::PatId(pat) => { mir::MirSpan::PatId(pat) => {

View file

@ -352,7 +352,7 @@ impl SourceToDefCtx<'_, '_> {
let src = src.cloned().map(ast::Pat::from); let src = src.cloned().map(ast::Pat::from);
let pat_id = source_map.node_pat(src.as_ref())?; let pat_id = source_map.node_pat(src.as_ref())?;
// the pattern could resolve to a constant, verify that this is not the case // the pattern could resolve to a constant, verify that this is not the case
if let crate::Pat::Bind { id, .. } = body[pat_id] { if let crate::Pat::Bind { id, .. } = body[pat_id.as_pat()?] {
Some((container, id)) Some((container, id))
} else { } else {
None None

View file

@ -18,7 +18,7 @@ use hir_def::{
scope::{ExprScopes, ScopeId}, scope::{ExprScopes, ScopeId},
Body, BodySourceMap, HygieneId, Body, BodySourceMap, HygieneId,
}, },
hir::{BindingId, Expr, ExprId, ExprOrPatId, Pat, PatId}, hir::{BindingId, Expr, ExprId, ExprOrPatId, Pat},
lang_item::LangItem, lang_item::LangItem,
lower::LowerCtx, lower::LowerCtx,
nameres::MacroSubNs, nameres::MacroSubNs,
@ -139,7 +139,7 @@ impl SourceAnalyzer {
sm.node_expr(src.as_ref()) sm.node_expr(src.as_ref())
} }
fn pat_id(&self, pat: &ast::Pat) -> Option<PatId> { fn pat_id(&self, pat: &ast::Pat) -> Option<ExprOrPatId> {
// FIXME: macros, see `expr_id` // FIXME: macros, see `expr_id`
let src = InFile { file_id: self.file_id, value: pat }; let src = InFile { file_id: self.file_id, value: pat };
self.body_source_map()?.node_pat(src) self.body_source_map()?.node_pat(src)
@ -147,7 +147,7 @@ impl SourceAnalyzer {
fn binding_id_of_pat(&self, pat: &ast::IdentPat) -> Option<BindingId> { fn binding_id_of_pat(&self, pat: &ast::IdentPat) -> Option<BindingId> {
let pat_id = self.pat_id(&pat.clone().into())?; let pat_id = self.pat_id(&pat.clone().into())?;
if let Pat::Bind { id, .. } = self.body()?.pats[pat_id] { if let Pat::Bind { id, .. } = self.body()?.pats[pat_id.as_pat()?] {
Some(id) Some(id)
} else { } else {
None None
@ -210,11 +210,20 @@ impl SourceAnalyzer {
db: &dyn HirDatabase, db: &dyn HirDatabase,
pat: &ast::Pat, pat: &ast::Pat,
) -> Option<(Type, Option<Type>)> { ) -> Option<(Type, Option<Type>)> {
let pat_id = self.pat_id(pat)?; let expr_or_pat_id = self.pat_id(pat)?;
let infer = self.infer.as_ref()?; let infer = self.infer.as_ref()?;
let coerced = let coerced = match expr_or_pat_id {
infer.pat_adjustments.get(&pat_id).and_then(|adjusts| adjusts.last().cloned()); ExprOrPatId::ExprId(idx) => infer
let ty = infer[pat_id].clone(); .expr_adjustments
.get(&idx)
.and_then(|adjusts| adjusts.last().cloned())
.map(|adjust| adjust.target),
ExprOrPatId::PatId(idx) => {
infer.pat_adjustments.get(&idx).and_then(|adjusts| adjusts.last().cloned())
}
};
let ty = infer[expr_or_pat_id].clone();
let mk_ty = |ty| Type::new_with_resolver(db, &self.resolver, ty); let mk_ty = |ty| Type::new_with_resolver(db, &self.resolver, ty);
Some((mk_ty(ty), coerced.map(mk_ty))) Some((mk_ty(ty), coerced.map(mk_ty)))
} }
@ -248,7 +257,7 @@ impl SourceAnalyzer {
) -> Option<BindingMode> { ) -> Option<BindingMode> {
let id = self.pat_id(&pat.clone().into())?; let id = self.pat_id(&pat.clone().into())?;
let infer = self.infer.as_ref()?; let infer = self.infer.as_ref()?;
infer.binding_modes.get(id).map(|bm| match bm { infer.binding_modes.get(id.as_pat()?).map(|bm| match bm {
hir_ty::BindingMode::Move => BindingMode::Move, hir_ty::BindingMode::Move => BindingMode::Move,
hir_ty::BindingMode::Ref(hir_ty::Mutability::Mut) => BindingMode::Ref(Mutability::Mut), hir_ty::BindingMode::Ref(hir_ty::Mutability::Mut) => BindingMode::Ref(Mutability::Mut),
hir_ty::BindingMode::Ref(hir_ty::Mutability::Not) => { hir_ty::BindingMode::Ref(hir_ty::Mutability::Not) => {
@ -266,7 +275,7 @@ impl SourceAnalyzer {
Some( Some(
infer infer
.pat_adjustments .pat_adjustments
.get(&pat_id)? .get(&pat_id.as_pat()?)?
.iter() .iter()
.map(|ty| Type::new_with_resolver(db, &self.resolver, ty.clone())) .map(|ty| Type::new_with_resolver(db, &self.resolver, ty.clone()))
.collect(), .collect(),
@ -649,10 +658,10 @@ impl SourceAnalyzer {
let field_name = field.field_name()?.as_name(); let field_name = field.field_name()?.as_name();
let record_pat = ast::RecordPat::cast(field.syntax().parent().and_then(|p| p.parent())?)?; let record_pat = ast::RecordPat::cast(field.syntax().parent().and_then(|p| p.parent())?)?;
let pat_id = self.pat_id(&record_pat.into())?; let pat_id = self.pat_id(&record_pat.into())?;
let variant = self.infer.as_ref()?.variant_resolution_for_pat(pat_id)?; let variant = self.infer.as_ref()?.variant_resolution_for_pat(pat_id.as_pat()?)?;
let variant_data = variant.variant_data(db.upcast()); let variant_data = variant.variant_data(db.upcast());
let field = FieldId { parent: variant, local_id: variant_data.field(&field_name)? }; let field = FieldId { parent: variant, local_id: variant_data.field(&field_name)? };
let (adt, subst) = self.infer.as_ref()?.type_of_pat.get(pat_id)?.as_adt()?; let (adt, subst) = self.infer.as_ref()?.type_of_pat.get(pat_id.as_pat()?)?.as_adt()?;
let field_ty = let field_ty =
db.field_types(variant).get(field.local_id)?.clone().substitute(Interner, subst); db.field_types(variant).get(field.local_id)?.clone().substitute(Interner, subst);
Some(( Some((
@ -682,12 +691,20 @@ impl SourceAnalyzer {
db: &dyn HirDatabase, db: &dyn HirDatabase,
pat: &ast::IdentPat, pat: &ast::IdentPat,
) -> Option<ModuleDef> { ) -> Option<ModuleDef> {
let pat_id = self.pat_id(&pat.clone().into())?; let expr_or_pat_id = self.pat_id(&pat.clone().into())?;
let body = self.body()?; let body = self.body()?;
let path = match &body[pat_id] {
let path = match expr_or_pat_id {
ExprOrPatId::ExprId(idx) => match &body[idx] {
Expr::Path(path) => path,
_ => return None,
},
ExprOrPatId::PatId(idx) => match &body[idx] {
Pat::Path(path) => path, Pat::Path(path) => path,
_ => return None, _ => return None,
},
}; };
let res = resolve_hir_path(db, &self.resolver, path, HygieneId::ROOT, TypesMap::EMPTY)?; let res = resolve_hir_path(db, &self.resolver, path, HygieneId::ROOT, TypesMap::EMPTY)?;
match res { match res {
PathResolution::Def(def) => Some(def), PathResolution::Def(def) => Some(def),
@ -782,8 +799,9 @@ impl SourceAnalyzer {
} }
prefer_value_ns = true; prefer_value_ns = true;
} else if let Some(path_pat) = parent().and_then(ast::PathPat::cast) { } else if let Some(path_pat) = parent().and_then(ast::PathPat::cast) {
let pat_id = self.pat_id(&path_pat.into())?; let expr_or_pat_id = self.pat_id(&path_pat.into())?;
if let Some((assoc, subs)) = infer.assoc_resolutions_for_pat(pat_id) { if let Some((assoc, subs)) = infer.assoc_resolutions_for_expr_or_pat(expr_or_pat_id)
{
let (assoc, subst) = match assoc { let (assoc, subst) = match assoc {
AssocItemId::ConstId(const_id) => { AssocItemId::ConstId(const_id) => {
let (konst, subst) = let (konst, subst) =
@ -807,7 +825,7 @@ impl SourceAnalyzer {
return Some((PathResolution::Def(AssocItem::from(assoc).into()), Some(subst))); return Some((PathResolution::Def(AssocItem::from(assoc).into()), Some(subst)));
} }
if let Some(VariantId::EnumVariantId(variant)) = if let Some(VariantId::EnumVariantId(variant)) =
infer.variant_resolution_for_pat(pat_id) infer.variant_resolution_for_expr_or_pat(expr_or_pat_id)
{ {
return Some((PathResolution::Def(ModuleDef::Variant(variant.into())), None)); return Some((PathResolution::Def(ModuleDef::Variant(variant.into())), None));
} }
@ -824,7 +842,7 @@ impl SourceAnalyzer {
|| parent().and_then(ast::TupleStructPat::cast).map(ast::Pat::from); || parent().and_then(ast::TupleStructPat::cast).map(ast::Pat::from);
if let Some(pat) = record_pat.or_else(tuple_struct_pat) { if let Some(pat) = record_pat.or_else(tuple_struct_pat) {
let pat_id = self.pat_id(&pat)?; let pat_id = self.pat_id(&pat)?;
let variant_res_for_pat = infer.variant_resolution_for_pat(pat_id); let variant_res_for_pat = infer.variant_resolution_for_pat(pat_id.as_pat()?);
if let Some(VariantId::EnumVariantId(variant)) = variant_res_for_pat { if let Some(VariantId::EnumVariantId(variant)) = variant_res_for_pat {
return Some(( return Some((
PathResolution::Def(ModuleDef::Variant(variant.into())), PathResolution::Def(ModuleDef::Variant(variant.into())),
@ -1080,7 +1098,7 @@ impl SourceAnalyzer {
let body = self.body()?; let body = self.body()?;
let infer = self.infer.as_ref()?; let infer = self.infer.as_ref()?;
let pat_id = self.pat_id(&pattern.clone().into())?; let pat_id = self.pat_id(&pattern.clone().into())?.as_pat()?;
let substs = infer.type_of_pat[pat_id].as_adt()?.1; let substs = infer.type_of_pat[pat_id].as_adt()?.1;
let (variant, missing_fields, _exhaustive) = let (variant, missing_fields, _exhaustive) =

View file

@ -40,7 +40,7 @@ pub(crate) fn mismatched_arg_count(
Diagnostic::new( Diagnostic::new(
DiagnosticCode::RustcHardError("E0107"), DiagnosticCode::RustcHardError("E0107"),
message, message,
invalid_args_range(ctx, d.call_expr.map(AstPtr::wrap_left), d.expected, d.found), invalid_args_range(ctx, d.call_expr, d.expected, d.found),
) )
} }

View file

@ -1,5 +1,6 @@
use std::iter; use std::iter;
use either::Either;
use hir::{db::ExpandDatabase, Adt, FileRange, HasSource, HirDisplay, InFile, Struct, Union}; use hir::{db::ExpandDatabase, Adt, FileRange, HasSource, HirDisplay, InFile, Struct, Union};
use ide_db::text_edit::TextEdit; use ide_db::text_edit::TextEdit;
use ide_db::{ use ide_db::{
@ -41,7 +42,7 @@ pub(crate) fn unresolved_field(
), ),
adjusted_display_range(ctx, d.expr, &|expr| { adjusted_display_range(ctx, d.expr, &|expr| {
Some( Some(
match expr { match expr.left()? {
ast::Expr::MethodCallExpr(it) => it.name_ref(), ast::Expr::MethodCallExpr(it) => it.name_ref(),
ast::Expr::FieldExpr(it) => it.name_ref(), ast::Expr::FieldExpr(it) => it.name_ref(),
_ => None, _ => None,
@ -72,7 +73,7 @@ fn fixes(ctx: &DiagnosticsContext<'_>, d: &hir::UnresolvedField) -> Option<Vec<A
fn field_fix(ctx: &DiagnosticsContext<'_>, d: &hir::UnresolvedField) -> Option<Assist> { fn field_fix(ctx: &DiagnosticsContext<'_>, d: &hir::UnresolvedField) -> Option<Assist> {
// Get the FileRange of the invalid field access // Get the FileRange of the invalid field access
let root = ctx.sema.db.parse_or_expand(d.expr.file_id); let root = ctx.sema.db.parse_or_expand(d.expr.file_id);
let expr = d.expr.value.to_node(&root); let expr = d.expr.value.to_node(&root).left()?;
let error_range = ctx.sema.original_range_opt(expr.syntax())?; let error_range = ctx.sema.original_range_opt(expr.syntax())?;
let field_name = d.name.as_str(); let field_name = d.name.as_str();
@ -263,7 +264,7 @@ fn record_field_layout(
// FIXME: We should fill out the call here, move the cursor and trigger signature help // FIXME: We should fill out the call here, move the cursor and trigger signature help
fn method_fix( fn method_fix(
ctx: &DiagnosticsContext<'_>, ctx: &DiagnosticsContext<'_>,
expr_ptr: &InFile<AstPtr<ast::Expr>>, expr_ptr: &InFile<AstPtr<Either<ast::Expr, ast::Pat>>>,
) -> Option<Assist> { ) -> Option<Assist> {
let root = ctx.sema.db.parse_or_expand(expr_ptr.file_id); let root = ctx.sema.db.parse_or_expand(expr_ptr.file_id);
let expr = expr_ptr.value.to_node(&root); let expr = expr_ptr.value.to_node(&root);

View file

@ -35,7 +35,7 @@ pub(crate) fn unresolved_method(
), ),
adjusted_display_range(ctx, d.expr, &|expr| { adjusted_display_range(ctx, d.expr, &|expr| {
Some( Some(
match expr { match expr.left()? {
ast::Expr::MethodCallExpr(it) => it.name_ref(), ast::Expr::MethodCallExpr(it) => it.name_ref(),
ast::Expr::FieldExpr(it) => it.name_ref(), ast::Expr::FieldExpr(it) => it.name_ref(),
_ => None, _ => None,
@ -85,7 +85,7 @@ fn field_fix(
let expr_ptr = &d.expr; let expr_ptr = &d.expr;
let root = ctx.sema.db.parse_or_expand(expr_ptr.file_id); let root = ctx.sema.db.parse_or_expand(expr_ptr.file_id);
let expr = expr_ptr.value.to_node(&root); let expr = expr_ptr.value.to_node(&root);
let (file_id, range) = match expr { let (file_id, range) = match expr.left()? {
ast::Expr::MethodCallExpr(mcall) => { ast::Expr::MethodCallExpr(mcall) => {
let FileRange { range, file_id } = let FileRange { range, file_id } =
ctx.sema.original_range_opt(mcall.receiver()?.syntax())?; ctx.sema.original_range_opt(mcall.receiver()?.syntax())?;
@ -117,7 +117,7 @@ fn assoc_func_fix(ctx: &DiagnosticsContext<'_>, d: &hir::UnresolvedMethodCall) -
let expr_ptr = &d.expr; let expr_ptr = &d.expr;
let root = db.parse_or_expand(expr_ptr.file_id); let root = db.parse_or_expand(expr_ptr.file_id);
let expr: ast::Expr = expr_ptr.value.to_node(&root); let expr: ast::Expr = expr_ptr.value.to_node(&root).left()?;
let call = ast::MethodCallExpr::cast(expr.syntax().clone())?; let call = ast::MethodCallExpr::cast(expr.syntax().clone())?;
let range = InFile::new(expr_ptr.file_id, call.syntax().text_range()) let range = InFile::new(expr_ptr.file_id, call.syntax().text_range())