Modify refcount of contents *before* structure to prevent use-after-free

This commit is contained in:
Brian Carroll 2022-01-05 12:22:26 +00:00
parent 456bda0895
commit d9cc3c5692
3 changed files with 233 additions and 112 deletions

View file

@ -68,8 +68,9 @@ pub fn refcount_stmt<'a>(
refcount_stmt(root, ident_ids, ctx, layout, modify, following)
}
// Struct is stack-only, so DecRef is a no-op
// Struct and non-recursive Unions are stack-only, so DecRef is a no-op
Layout::Struct(_) => following,
Layout::Union(UnionLayout::NonRecursive(_)) => following,
// Inline the refcounting code instead of making a function. Don't iterate fields,
// and replace any return statements with jumps to the `following` statement.
@ -114,7 +115,7 @@ pub fn refcount_generic<'a>(
refcount_struct(root, ident_ids, ctx, field_layouts, structure)
}
Layout::Union(union_layout) => {
refcount_tag_union(root, ident_ids, ctx, union_layout, structure)
refcount_union(root, ident_ids, ctx, union_layout, structure)
}
Layout::LambdaSet(lambda_set) => {
let runtime_layout = lambda_set.runtime_representation();
@ -482,12 +483,32 @@ fn refcount_list<'a>(
//
// modify refcount of the list and its elements
// (elements first, to avoid use-after-free for Dec)
//
let rc_ptr = root.create_symbol(ident_ids, "rc_ptr");
let alignment = layout.alignment_bytes(root.ptr_size);
let modify_elems = if elem_layout.is_refcounted() && !ctx.op.is_decref() {
let ret_stmt = rc_return_stmt(root, ident_ids, ctx);
let modify_list = modify_refcount(
root,
ident_ids,
ctx,
rc_ptr,
alignment,
arena.alloc(ret_stmt),
);
let get_rc_and_modify_list = rc_ptr_from_data_ptr(
root,
ident_ids,
elements,
rc_ptr,
false,
arena.alloc(modify_list),
);
let modify_elems_and_list = if elem_layout.is_refcounted() && !ctx.op.is_decref() {
refcount_list_elems(
root,
ident_ids,
@ -497,43 +518,31 @@ fn refcount_list<'a>(
box_union_layout,
len,
elements,
get_rc_and_modify_list,
)
} else {
rc_return_stmt(root, ident_ids, ctx)
get_rc_and_modify_list
};
let modify_list = modify_refcount(
root,
ident_ids,
ctx,
rc_ptr,
alignment,
arena.alloc(modify_elems),
);
let modify_list_and_elems = elements_stmt(arena.alloc(
//
rc_ptr_from_data_ptr(
root,
ident_ids,
elements,
rc_ptr,
false,
arena.alloc(modify_list),
),
));
//
// Do nothing if the list is empty
//
let non_empty_branch = root.arena.alloc(
//
elements_stmt(root.arena.alloc(
//
modify_elems_and_list,
)),
);
let if_stmt = Stmt::Switch {
cond_symbol: is_empty,
cond_layout: LAYOUT_BOOL,
branches: root
.arena
.alloc([(1, BranchInfo::None, rc_return_stmt(root, ident_ids, ctx))]),
default_branch: (BranchInfo::None, root.arena.alloc(modify_list_and_elems)),
default_branch: (BranchInfo::None, non_empty_branch),
ret_layout: LAYOUT_UNIT,
};
@ -559,6 +568,7 @@ fn refcount_list_elems<'a>(
box_union_layout: UnionLayout<'a>,
length: Symbol,
elements: Symbol,
following: Stmt<'a>,
) -> Stmt<'a> {
use LowLevel::*;
let layout_isize = root.layout_isize;
@ -573,9 +583,9 @@ fn refcount_list_elems<'a>(
//
// let size = literal int
let size = root.create_symbol(ident_ids, "size");
let size_expr = Expr::Literal(Literal::Int(elem_layout.stack_size(root.ptr_size) as i128));
let size_stmt = |next| Stmt::Let(size, size_expr, layout_isize, next);
let elem_size = root.create_symbol(ident_ids, "elem_size");
let elem_size_expr = Expr::Literal(Literal::Int(elem_layout.stack_size(root.ptr_size) as i128));
let elem_size_stmt = |next| Stmt::Let(elem_size, elem_size_expr, layout_isize, next);
// let list_size = len * size
let list_size = root.create_symbol(ident_ids, "list_size");
@ -585,7 +595,7 @@ fn refcount_list_elems<'a>(
layout_isize,
list_size,
NumMul,
&[length, size],
&[length, elem_size],
next,
)
};
@ -641,8 +651,16 @@ fn refcount_list_elems<'a>(
// Next loop iteration
//
let next_addr = root.create_symbol(ident_ids, "next_addr");
let next_addr_stmt =
|next| let_lowlevel(arena, layout_isize, next_addr, NumAdd, &[addr, size], next);
let next_addr_stmt = |next| {
let_lowlevel(
arena,
layout_isize,
next_addr,
NumAdd,
&[addr, elem_size],
next,
)
};
//
// Control flow
@ -655,9 +673,7 @@ fn refcount_list_elems<'a>(
cond_symbol: is_end,
cond_layout: LAYOUT_BOOL,
ret_layout,
branches: root
.arena
.alloc([(1, BranchInfo::None, rc_return_stmt(root, ident_ids, ctx))]),
branches: root.arena.alloc([(1, BranchInfo::None, following)]),
default_branch: (
BranchInfo::None,
arena.alloc(box_stmt(arena.alloc(
@ -693,7 +709,7 @@ fn refcount_list_elems<'a>(
start_stmt(arena.alloc(
//
size_stmt(arena.alloc(
elem_size_stmt(arena.alloc(
//
list_size_stmt(arena.alloc(
//
@ -745,32 +761,30 @@ fn refcount_struct<'a>(
stmt
}
fn refcount_tag_union<'a>(
fn refcount_union<'a>(
root: &mut CodeGenHelp<'a>,
ident_ids: &mut IdentIds,
ctx: &mut Context<'a>,
union_layout: UnionLayout<'a>,
union: UnionLayout<'a>,
structure: Symbol,
) -> Stmt<'a> {
use UnionLayout::*;
let parent_rec_ptr_layout = ctx.recursive_union;
if !matches!(union_layout, NonRecursive(_)) {
ctx.recursive_union = Some(union_layout);
if !matches!(union, NonRecursive(_)) {
ctx.recursive_union = Some(union);
}
let body = match union_layout {
NonRecursive(tags) => {
refcount_tag_union_help(root, ident_ids, ctx, union_layout, tags, None, structure)
}
let body = match union {
NonRecursive(tags) => refcount_union_nonrec(root, ident_ids, ctx, union, tags, structure),
Recursive(tags) => {
refcount_tag_union_help(root, ident_ids, ctx, union_layout, tags, None, structure)
refcount_tag_union_rec(root, ident_ids, ctx, union, tags, None, structure)
}
NonNullableUnwrapped(field_layouts) => {
let tags = root.arena.alloc([field_layouts]);
refcount_tag_union_help(root, ident_ids, ctx, union_layout, tags, None, structure)
refcount_tag_union_rec(root, ident_ids, ctx, union, tags, None, structure)
}
NullableWrapped {
@ -778,7 +792,7 @@ fn refcount_tag_union<'a>(
nullable_id,
} => {
let null_id = Some(nullable_id);
refcount_tag_union_help(root, ident_ids, ctx, union_layout, tags, null_id, structure)
refcount_tag_union_rec(root, ident_ids, ctx, union, tags, null_id, structure)
}
NullableUnwrapped {
@ -787,7 +801,7 @@ fn refcount_tag_union<'a>(
} => {
let null_id = Some(nullable_id as TagIdIntType);
let tags = root.arena.alloc([other_fields]);
refcount_tag_union_help(root, ident_ids, ctx, union_layout, tags, null_id, structure)
refcount_tag_union_rec(root, ident_ids, ctx, union, tags, null_id, structure)
}
};
@ -796,16 +810,14 @@ fn refcount_tag_union<'a>(
body
}
fn refcount_tag_union_help<'a>(
fn refcount_union_nonrec<'a>(
root: &mut CodeGenHelp<'a>,
ident_ids: &mut IdentIds,
ctx: &mut Context<'a>,
union_layout: UnionLayout<'a>,
tag_layouts: &'a [&'a [Layout<'a>]],
null_id: Option<TagIdIntType>,
structure: Symbol,
) -> Stmt<'a> {
let is_non_recursive = matches!(union_layout, UnionLayout::NonRecursive(_));
let tag_id_layout = union_layout.tag_id_layout();
let tag_id_sym = root.create_symbol(ident_ids, "tag_id");
@ -821,93 +833,159 @@ fn refcount_tag_union_help<'a>(
)
};
let modify_fields_stmt = if ctx.op.is_decref() {
rc_return_stmt(root, ident_ids, ctx)
} else {
let mut tag_branches = Vec::with_capacity_in(tag_layouts.len(), root.arena);
let continuation = rc_return_stmt(root, ident_ids, ctx);
let mut tag_id: TagIdIntType = 0;
for field_layouts in tag_layouts.iter() {
if let Some(id) = null_id {
if tag_id == id {
tag_id += 1;
}
let switch_stmt = refcount_union_contents(
root,
ident_ids,
ctx,
union_layout,
tag_layouts,
None,
structure,
tag_id_sym,
tag_id_layout,
continuation,
);
tag_id_stmt(root.arena.alloc(
//
switch_stmt,
))
}
#[allow(clippy::too_many_arguments)]
fn refcount_union_contents<'a>(
root: &mut CodeGenHelp<'a>,
ident_ids: &mut IdentIds,
ctx: &mut Context<'a>,
union_layout: UnionLayout<'a>,
tag_layouts: &'a [&'a [Layout<'a>]],
null_id: Option<TagIdIntType>,
structure: Symbol,
tag_id_sym: Symbol,
tag_id_layout: Layout<'a>,
modify_union_stmt: Stmt<'a>,
) -> Stmt<'a> {
let mut tag_branches = Vec::with_capacity_in(tag_layouts.len() + 1, root.arena);
if let Some(id) = null_id {
let ret = rc_return_stmt(root, ident_ids, ctx);
tag_branches.push((id as u64, BranchInfo::None, ret));
}
let mut tag_id: TagIdIntType = 0;
for field_layouts in tag_layouts.iter() {
match null_id {
Some(id) if id == tag_id => {
tag_id += 1;
}
_ => {}
}
let fields_stmt = refcount_tag_fields(
root,
ident_ids,
ctx,
union_layout,
field_layouts,
let fields_stmt = refcount_tag_fields(
root,
ident_ids,
ctx,
union_layout,
field_layouts,
structure,
tag_id,
modify_union_stmt.clone(), // TODO: Use a jump, this is a bit bloated
);
tag_branches.push((tag_id as u64, BranchInfo::None, fields_stmt));
tag_id += 1;
}
let default_stmt: Stmt<'a> = tag_branches.pop().unwrap().2;
Stmt::Switch {
cond_symbol: tag_id_sym,
cond_layout: tag_id_layout,
branches: tag_branches.into_bump_slice(),
default_branch: (BranchInfo::None, root.arena.alloc(default_stmt)),
ret_layout: LAYOUT_UNIT,
}
}
fn refcount_tag_union_rec<'a>(
root: &mut CodeGenHelp<'a>,
ident_ids: &mut IdentIds,
ctx: &mut Context<'a>,
union_layout: UnionLayout<'a>,
tag_layouts: &'a [&'a [Layout<'a>]],
null_id: Option<TagIdIntType>,
structure: Symbol,
) -> Stmt<'a> {
let tag_id_layout = union_layout.tag_id_layout();
let tag_id_sym = root.create_symbol(ident_ids, "tag_id");
let tag_id_stmt = |next| {
Stmt::Let(
tag_id_sym,
Expr::GetTagId {
structure,
tag_id as TagIdIntType,
);
tag_branches.push((tag_id as u64, BranchInfo::None, fields_stmt));
tag_id += 1;
}
let default_stmt: Stmt<'a> = tag_branches.pop().unwrap().2;
Stmt::Switch {
cond_symbol: tag_id_sym,
cond_layout: tag_id_layout,
branches: tag_branches.into_bump_slice(),
default_branch: (BranchInfo::None, root.arena.alloc(default_stmt)),
ret_layout: LAYOUT_UNIT,
}
union_layout,
},
tag_id_layout,
next,
)
};
let rc_structure_stmt = if is_non_recursive {
modify_fields_stmt
} else {
let rc_structure_stmt = {
let rc_ptr = root.create_symbol(ident_ids, "rc_ptr");
let alignment = Layout::Union(union_layout).alignment_bytes(root.ptr_size);
let ret_stmt = rc_return_stmt(root, ident_ids, ctx);
let modify_structure_stmt = modify_refcount(
root,
ident_ids,
ctx,
rc_ptr,
alignment,
root.arena.alloc(modify_fields_stmt),
root.arena.alloc(ret_stmt),
);
let rc_ptr_stmt = rc_ptr_from_data_ptr(
rc_ptr_from_data_ptr(
root,
ident_ids,
structure,
rc_ptr,
union_layout.stores_tag_id_in_pointer(root.ptr_size),
root.arena.alloc(modify_structure_stmt),
);
if let Some(id) = null_id {
let null_branch = (
id as u64,
BranchInfo::None,
rc_return_stmt(root, ident_ids, ctx),
);
Stmt::Switch {
cond_symbol: tag_id_sym,
cond_layout: tag_id_layout,
branches: root.arena.alloc([null_branch]),
default_branch: (BranchInfo::None, root.arena.alloc(rc_ptr_stmt)),
ret_layout: LAYOUT_UNIT,
}
} else {
rc_ptr_stmt
}
)
};
tag_id_stmt(root.arena.alloc(
//
rc_structure_stmt,
))
let rc_contents_then_structure = if ctx.op.is_decref() {
rc_structure_stmt
} else {
refcount_union_contents(
root,
ident_ids,
ctx,
union_layout,
tag_layouts,
null_id,
structure,
tag_id_sym,
tag_id_layout,
rc_structure_stmt,
)
};
if ctx.op.is_decref() && null_id.is_none() {
rc_contents_then_structure
} else {
tag_id_stmt(root.arena.alloc(
//
rc_contents_then_structure,
))
}
}
#[allow(clippy::too_many_arguments)]
fn refcount_tag_fields<'a>(
root: &mut CodeGenHelp<'a>,
ident_ids: &mut IdentIds,
@ -916,8 +994,9 @@ fn refcount_tag_fields<'a>(
field_layouts: &'a [Layout<'a>],
structure: Symbol,
tag_id: TagIdIntType,
following: Stmt<'a>,
) -> Stmt<'a> {
let mut stmt = rc_return_stmt(root, ident_ids, ctx);
let mut stmt = following;
for (i, field_layout) in field_layouts.iter().enumerate().rev() {
if field_layout.contains_refcounted() {