Compute closure captures

This commit is contained in:
hkalbasi 2023-04-06 16:14:38 +03:30
parent 51d5862caf
commit 59b6f2d9f2
42 changed files with 2537 additions and 433 deletions

View file

@ -21,9 +21,16 @@ use la_arena::ArenaMap;
use rustc_hash::FxHashMap;
use crate::{
consteval::ConstEvalError, db::HirDatabase, display::HirDisplay, infer::TypeMismatch,
inhabitedness::is_ty_uninhabited_from, layout::layout_of_ty, mapping::ToChalk, static_lifetime,
utils::generics, Adjust, Adjustment, AutoBorrow, CallableDefId, TyBuilder, TyExt,
consteval::ConstEvalError,
db::HirDatabase,
display::HirDisplay,
infer::{CaptureKind, CapturedItem, TypeMismatch},
inhabitedness::is_ty_uninhabited_from,
layout::layout_of_ty,
mapping::ToChalk,
static_lifetime,
utils::generics,
Adjust, Adjustment, AutoBorrow, CallableDefId, TyBuilder, TyExt,
};
use super::*;
@ -74,10 +81,12 @@ pub enum MirLowerError {
BreakWithoutLoop,
Loop,
/// Something that should never happen and is definitely a bug, but we don't want to panic if it happened
ImplementationError(&'static str),
ImplementationError(String),
LangItemNotFound(LangItem),
MutatingRvalue,
UnresolvedLabel,
UnresolvedUpvar(Place),
UnaccessableLocal,
}
macro_rules! not_supported {
@ -88,8 +97,8 @@ macro_rules! not_supported {
macro_rules! implementation_error {
($x: expr) => {{
::stdx::never!("MIR lower implementation bug: {}", $x);
return Err(MirLowerError::ImplementationError($x));
::stdx::never!("MIR lower implementation bug: {}", format!($x));
return Err(MirLowerError::ImplementationError(format!($x)));
}};
}
@ -116,7 +125,44 @@ impl MirLowerError {
type Result<T> = std::result::Result<T, MirLowerError>;
impl MirLowerCtx<'_> {
impl<'ctx> MirLowerCtx<'ctx> {
fn new(
db: &'ctx dyn HirDatabase,
owner: DefWithBodyId,
body: &'ctx Body,
infer: &'ctx InferenceResult,
) -> Self {
let mut basic_blocks = Arena::new();
let start_block = basic_blocks.alloc(BasicBlock {
statements: vec![],
terminator: None,
is_cleanup: false,
});
let locals = Arena::new();
let binding_locals: ArenaMap<BindingId, LocalId> = ArenaMap::new();
let mir = MirBody {
basic_blocks,
locals,
start_block,
binding_locals,
param_locals: vec![],
owner,
arg_count: body.params.len(),
closures: vec![],
};
let ctx = MirLowerCtx {
result: mir,
db,
infer,
body,
owner,
current_loop_blocks: None,
labeled_loop_blocks: Default::default(),
discr_temp: None,
};
ctx
}
fn temp(&mut self, ty: Ty) -> Result<LocalId> {
if matches!(ty.kind(Interner), TyKind::Slice(_) | TyKind::Dyn(_)) {
implementation_error!("unsized temporaries");
@ -268,7 +314,7 @@ impl MirLowerCtx<'_> {
self.push_assignment(
current,
place,
Operand::Copy(self.result.binding_locals[pat_id].into()).into(),
Operand::Copy(self.binding_local(pat_id)?.into()).into(),
expr_id.into(),
);
Ok(Some(current))
@ -823,7 +869,51 @@ impl MirLowerCtx<'_> {
);
Ok(Some(current))
},
Expr::Closure { .. } => not_supported!("closure"),
Expr::Closure { .. } => {
let ty = self.expr_ty(expr_id);
let TyKind::Closure(id, _) = ty.kind(Interner) else {
not_supported!("closure with non closure type");
};
self.result.closures.push(*id);
let (captures, _) = self.infer.closure_info(id);
let mut operands = vec![];
for capture in captures.iter() {
let p = Place {
local: self.binding_local(capture.place.local)?,
projection: capture.place.projections.clone().into_iter().map(|x| {
match x {
ProjectionElem::Deref => ProjectionElem::Deref,
ProjectionElem::Field(x) => ProjectionElem::Field(x),
ProjectionElem::TupleOrClosureField(x) => ProjectionElem::TupleOrClosureField(x),
ProjectionElem::ConstantIndex { offset, min_length, from_end } => ProjectionElem::ConstantIndex { offset, min_length, from_end },
ProjectionElem::Subslice { from, to, from_end } => ProjectionElem::Subslice { from, to, from_end },
ProjectionElem::OpaqueCast(x) => ProjectionElem::OpaqueCast(x),
ProjectionElem::Index(x) => match x { },
}
}).collect(),
};
match &capture.kind {
CaptureKind::ByRef(bk) => {
let tmp: Place = self.temp(capture.ty.clone())?.into();
self.push_assignment(
current,
tmp.clone(),
Rvalue::Ref(bk.clone(), p),
expr_id.into(),
);
operands.push(Operand::Move(tmp));
},
CaptureKind::ByValue => operands.push(Operand::Move(p)),
}
}
self.push_assignment(
current,
place,
Rvalue::Aggregate(AggregateKind::Closure(ty), operands),
expr_id.into(),
);
Ok(Some(current))
},
Expr::Tuple { exprs, is_assignee_expr: _ } => {
let Some(values) = exprs
.iter()
@ -893,7 +983,7 @@ impl MirLowerCtx<'_> {
let index = name
.as_tuple_index()
.ok_or(MirLowerError::TypeError("named field on tuple"))?;
place.projection.push(ProjectionElem::TupleField(index))
place.projection.push(ProjectionElem::TupleOrClosureField(index))
} else {
let field =
self.infer.field_resolution(expr_id).ok_or(MirLowerError::UnresolvedField)?;
@ -1126,8 +1216,9 @@ impl MirLowerCtx<'_> {
};
self.set_goto(prev_block, begin);
f(self, begin)?;
let my = mem::replace(&mut self.current_loop_blocks, prev)
.ok_or(MirLowerError::ImplementationError("current_loop_blocks is corrupt"))?;
let my = mem::replace(&mut self.current_loop_blocks, prev).ok_or(
MirLowerError::ImplementationError("current_loop_blocks is corrupt".to_string()),
)?;
if let Some(prev) = prev_label {
self.labeled_loop_blocks.insert(label.unwrap(), prev);
}
@ -1159,7 +1250,9 @@ impl MirLowerCtx<'_> {
let r = match self
.current_loop_blocks
.as_mut()
.ok_or(MirLowerError::ImplementationError("Current loop access out of loop"))?
.ok_or(MirLowerError::ImplementationError(
"Current loop access out of loop".to_string(),
))?
.end
{
Some(x) => x,
@ -1167,7 +1260,9 @@ impl MirLowerCtx<'_> {
let s = self.new_basic_block();
self.current_loop_blocks
.as_mut()
.ok_or(MirLowerError::ImplementationError("Current loop access out of loop"))?
.ok_or(MirLowerError::ImplementationError(
"Current loop access out of loop".to_string(),
))?
.end = Some(s);
s
}
@ -1181,7 +1276,7 @@ impl MirLowerCtx<'_> {
/// This function push `StorageLive` statement for the binding, and applies changes to add `StorageDead` in
/// the appropriated places.
fn push_storage_live(&mut self, b: BindingId, current: BasicBlockId) {
fn push_storage_live(&mut self, b: BindingId, current: BasicBlockId) -> Result<()> {
// Current implementation is wrong. It adds no `StorageDead` at the end of scope, and before each break
// and continue. It just add a `StorageDead` before the `StorageLive`, which is not wrong, but unneeeded in
// the proper implementation. Due this limitation, implementing a borrow checker on top of this mir will falsely
@ -1206,9 +1301,10 @@ impl MirLowerCtx<'_> {
.copied()
.map(MirSpan::PatId)
.unwrap_or(MirSpan::Unknown);
let l = self.result.binding_locals[b];
let l = self.binding_local(b)?;
self.push_statement(current, StatementKind::StorageDead(l).with_span(span));
self.push_statement(current, StatementKind::StorageLive(l).with_span(span));
Ok(())
}
fn resolve_lang_item(&self, item: LangItem) -> Result<LangItemTarget> {
@ -1256,9 +1352,15 @@ impl MirLowerCtx<'_> {
}
}
} else {
let mut err = None;
self.body.walk_bindings_in_pat(*pat, |b| {
self.push_storage_live(b, current);
if let Err(e) = self.push_storage_live(b, current) {
err = Some(e);
}
});
if let Some(e) = err {
return Err(e);
}
}
}
hir_def::hir::Statement::Expr { expr, has_semi: _ } => {
@ -1274,6 +1376,67 @@ impl MirLowerCtx<'_> {
None => Ok(Some(current)),
}
}
fn lower_params_and_bindings(
&mut self,
params: impl Iterator<Item = (PatId, Ty)> + Clone,
pick_binding: impl Fn(BindingId) -> bool,
) -> Result<BasicBlockId> {
let base_param_count = self.result.param_locals.len();
self.result.param_locals.extend(params.clone().map(|(x, ty)| {
let local_id = self.result.locals.alloc(Local { ty });
if let Pat::Bind { id, subpat: None } = self.body[x] {
if matches!(
self.body.bindings[id].mode,
BindingAnnotation::Unannotated | BindingAnnotation::Mutable
) {
self.result.binding_locals.insert(id, local_id);
}
}
local_id
}));
// and then rest of bindings
for (id, _) in self.body.bindings.iter() {
if !pick_binding(id) {
continue;
}
if !self.result.binding_locals.contains_idx(id) {
self.result
.binding_locals
.insert(id, self.result.locals.alloc(Local { ty: self.infer[id].clone() }));
}
}
let mut current = self.result.start_block;
for ((param, _), local) in
params.zip(self.result.param_locals.clone().into_iter().skip(base_param_count))
{
if let Pat::Bind { id, .. } = self.body[param] {
if local == self.binding_local(id)? {
continue;
}
}
let r = self.pattern_match(
current,
None,
local.into(),
self.result.locals[local].ty.clone(),
param,
BindingAnnotation::Unannotated,
)?;
if let Some(b) = r.1 {
self.set_terminator(b, Terminator::Unreachable);
}
current = r.0;
}
Ok(current)
}
fn binding_local(&self, b: BindingId) -> Result<LocalId> {
match self.result.binding_locals.get(b) {
Some(x) => Ok(*x),
None => Err(MirLowerError::UnaccessableLocal),
}
}
}
fn cast_kind(source_ty: &Ty, target_ty: &Ty) -> Result<CastKind> {
@ -1297,6 +1460,87 @@ fn cast_kind(source_ty: &Ty, target_ty: &Ty) -> Result<CastKind> {
})
}
pub fn mir_body_for_closure_query(
db: &dyn HirDatabase,
closure: ClosureId,
) -> Result<Arc<MirBody>> {
let (owner, expr) = db.lookup_intern_closure(closure.into());
let body = db.body(owner);
let infer = db.infer(owner);
let Expr::Closure { args, body: root, .. } = &body[expr] else {
implementation_error!("closure expression is not closure");
};
let TyKind::Closure(_, substs) = &infer[expr].kind(Interner) else {
implementation_error!("closure expression is not closure");
};
let (captures, _) = infer.closure_info(&closure);
let mut ctx = MirLowerCtx::new(db, owner, &body, &infer);
ctx.result.arg_count = args.len() + 1;
// 0 is return local
ctx.result.locals.alloc(Local { ty: infer[*root].clone() });
ctx.result.locals.alloc(Local { ty: infer[expr].clone() });
let Some(sig) = substs.at(Interner, 0).assert_ty_ref(Interner).callable_sig(db) else {
implementation_error!("closure has not callable sig");
};
let current = ctx.lower_params_and_bindings(
args.iter().zip(sig.params().iter()).map(|(x, y)| (*x, y.clone())),
|_| true,
)?;
if let Some(b) = ctx.lower_expr_to_place(*root, return_slot().into(), current)? {
ctx.set_terminator(b, Terminator::Return);
}
let mut upvar_map: FxHashMap<LocalId, Vec<(&CapturedItem, usize)>> = FxHashMap::default();
for (i, capture) in captures.iter().enumerate() {
let local = ctx.binding_local(capture.place.local)?;
upvar_map.entry(local).or_default().push((capture, i));
}
let mut err = None;
let closure_local = ctx.result.locals.iter().nth(1).unwrap().0;
ctx.result.walk_places(|p| {
if let Some(x) = upvar_map.get(&p.local) {
let r = x.iter().find(|x| {
if p.projection.len() < x.0.place.projections.len() {
return false;
}
for (x, y) in p.projection.iter().zip(x.0.place.projections.iter()) {
match (x, y) {
(ProjectionElem::Deref, ProjectionElem::Deref) => (),
(ProjectionElem::Field(x), ProjectionElem::Field(y)) if x == y => (),
(
ProjectionElem::TupleOrClosureField(x),
ProjectionElem::TupleOrClosureField(y),
) if x == y => (),
_ => return false,
}
}
true
});
match r {
Some(x) => {
p.local = closure_local;
let prev_projs =
mem::replace(&mut p.projection, vec![PlaceElem::TupleOrClosureField(x.1)]);
if x.0.kind != CaptureKind::ByValue {
p.projection.push(ProjectionElem::Deref);
}
p.projection.extend(prev_projs.into_iter().skip(x.0.place.projections.len()));
}
None => err = Some(p.clone()),
}
}
});
ctx.result.binding_locals = ctx
.result
.binding_locals
.into_iter()
.filter(|x| ctx.body[x.0].owner == Some(expr))
.collect();
if let Some(err) = err {
return Err(MirLowerError::UnresolvedUpvar(err));
}
Ok(Arc::new(ctx.result))
}
pub fn mir_body_query(db: &dyn HirDatabase, def: DefWithBodyId) -> Result<Arc<MirBody>> {
let _p = profile::span("mir_body_query").detail(|| match def {
DefWithBodyId::FunctionId(it) => db.function_data(it).name.to_string(),
@ -1334,86 +1578,29 @@ pub fn lower_to_mir(
if let Some((_, x)) = infer.type_mismatches().next() {
return Err(MirLowerError::TypeMismatch(x.clone()));
}
let mut basic_blocks = Arena::new();
let start_block =
basic_blocks.alloc(BasicBlock { statements: vec![], terminator: None, is_cleanup: false });
let mut locals = Arena::new();
let mut ctx = MirLowerCtx::new(db, owner, body, infer);
// 0 is return local
locals.alloc(Local { ty: infer[root_expr].clone() });
let mut binding_locals: ArenaMap<BindingId, LocalId> = ArenaMap::new();
ctx.result.locals.alloc(Local { ty: infer[root_expr].clone() });
let binding_picker = |b: BindingId| {
if root_expr == body.body_expr {
body[b].owner.is_none()
} else {
body[b].owner == Some(root_expr)
}
};
// 1 to param_len is for params
let param_locals: Vec<LocalId> = if let DefWithBodyId::FunctionId(fid) = owner {
let current = if let DefWithBodyId::FunctionId(fid) = owner {
let substs = TyBuilder::placeholder_subst(db, fid);
let callable_sig = db.callable_item_signature(fid.into()).substitute(Interner, &substs);
body.params
.iter()
.zip(callable_sig.params().iter())
.map(|(&x, ty)| {
let local_id = locals.alloc(Local { ty: ty.clone() });
if let Pat::Bind { id, subpat: None } = body[x] {
if matches!(
body.bindings[id].mode,
BindingAnnotation::Unannotated | BindingAnnotation::Mutable
) {
binding_locals.insert(id, local_id);
}
}
local_id
})
.collect()
ctx.lower_params_and_bindings(
body.params.iter().zip(callable_sig.params().iter()).map(|(x, y)| (*x, y.clone())),
binding_picker,
)?
} else {
if !body.params.is_empty() {
return Err(MirLowerError::TypeError("Unexpected parameter for non function body"));
}
vec![]
ctx.lower_params_and_bindings([].into_iter(), binding_picker)?
};
// and then rest of bindings
for (id, _) in body.bindings.iter() {
if !binding_locals.contains_idx(id) {
binding_locals.insert(id, locals.alloc(Local { ty: infer[id].clone() }));
}
}
let mir = MirBody {
basic_blocks,
locals,
start_block,
binding_locals,
param_locals,
owner,
arg_count: body.params.len(),
};
let mut ctx = MirLowerCtx {
result: mir,
db,
infer,
body,
owner,
current_loop_blocks: None,
labeled_loop_blocks: Default::default(),
discr_temp: None,
};
let mut current = start_block;
for (&param, local) in body.params.iter().zip(ctx.result.param_locals.clone().into_iter()) {
if let Pat::Bind { id, .. } = body[param] {
if local == ctx.result.binding_locals[id] {
continue;
}
}
let r = ctx.pattern_match(
current,
None,
local.into(),
ctx.result.locals[local].ty.clone(),
param,
BindingAnnotation::Unannotated,
)?;
if let Some(b) = r.1 {
ctx.set_terminator(b, Terminator::Unreachable);
}
current = r.0;
}
if let Some(b) = ctx.lower_expr_to_place(root_expr, return_slot().into(), current)? {
ctx.result.basic_blocks[b].terminator = Some(Terminator::Return);
ctx.set_terminator(b, Terminator::Return);
}
Ok(ctx.result)
}