perf: Reduce best fitting allocations (#7411)

This commit is contained in:
Micha Reiser 2023-09-18 21:49:44 +02:00 committed by GitHub
parent 2421805033
commit 3336d23f48
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 173 additions and 83 deletions

View file

@ -9,9 +9,7 @@ use Tag::*;
use crate::format_element::tag::{Condition, Tag};
use crate::prelude::tag::{DedentMode, GroupMode, LabelId};
use crate::prelude::*;
use crate::{
format_element, write, Argument, Arguments, FormatContext, FormatOptions, GroupId, TextSize,
};
use crate::{write, Argument, Arguments, FormatContext, FormatOptions, GroupId, TextSize};
use crate::{Buffer, VecBuffer};
/// A line break that only gets printed if the enclosing `Group` doesn't fit on a single line.
@ -2543,15 +2541,12 @@ impl<Context> Format<Context> for BestFitting<'_, Context> {
fn fmt(&self, f: &mut Formatter<Context>) -> FormatResult<()> {
let variants = self.variants.items();
let mut formatted_variants = Vec::with_capacity(variants.len());
let mut buffer = VecBuffer::with_capacity(variants.len() * 8, f.state_mut());
for variant in variants {
let mut buffer = VecBuffer::with_capacity(8, f.state_mut());
buffer.write_element(FormatElement::Tag(StartEntry));
buffer.write_element(FormatElement::Tag(StartBestFittingEntry));
buffer.write_fmt(Arguments::from(variant))?;
buffer.write_element(FormatElement::Tag(EndEntry));
formatted_variants.push(buffer.into_vec().into_boxed_slice());
buffer.write_element(FormatElement::Tag(EndBestFittingEntry));
}
// SAFETY: The constructor guarantees that there are always at least two variants. It's, therefore,
@ -2559,9 +2554,7 @@ impl<Context> Format<Context> for BestFitting<'_, Context> {
#[allow(unsafe_code)]
let element = unsafe {
FormatElement::BestFitting {
variants: format_element::BestFittingVariants::from_vec_unchecked(
formatted_variants,
),
variants: BestFittingVariants::from_vec_unchecked(buffer.into_vec()),
mode: self.mode,
}
};

View file

@ -3,6 +3,7 @@ pub mod tag;
use std::borrow::Cow;
use std::hash::{Hash, Hasher};
use std::iter::FusedIterator;
use std::num::NonZeroU32;
use std::ops::Deref;
use std::rc::Rc;
@ -67,6 +68,16 @@ pub enum FormatElement {
Tag(Tag),
}
impl FormatElement {
    /// Returns the [`TagKind`] of the wrapped tag if this element is a
    /// [`FormatElement::Tag`], and `None` for every other kind of element.
    pub fn tag_kind(&self) -> Option<TagKind> {
        match self {
            FormatElement::Tag(tag) => Some(tag.kind()),
            _ => None,
        }
    }
}
impl std::fmt::Debug for FormatElement {
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
@ -318,7 +329,7 @@ pub enum BestFittingMode {
/// The first element is the one that takes up the most space horizontally (the most flat),
/// The last element takes up the least space horizontally (but the most space vertically).
#[derive(Clone, Eq, PartialEq, Debug)]
pub struct BestFittingVariants(Box<[Box<[FormatElement]>]>);
pub struct BestFittingVariants(Box<[FormatElement]>);
impl BestFittingVariants {
/// Creates a new best fitting IR with the given variants. The method itself isn't unsafe
@ -331,9 +342,13 @@ impl BestFittingVariants {
/// The slice must contain at least two variants.
#[doc(hidden)]
#[allow(unsafe_code)]
pub unsafe fn from_vec_unchecked(variants: Vec<Box<[FormatElement]>>) -> Self {
pub unsafe fn from_vec_unchecked(variants: Vec<FormatElement>) -> Self {
debug_assert!(
variants.len() >= 2,
variants
.iter()
.filter(|element| matches!(element, FormatElement::Tag(Tag::StartBestFittingEntry)))
.count()
>= 2,
"Requires at least the least expanded and most expanded variants"
);
@ -342,40 +357,85 @@ impl BestFittingVariants {
/// Returns the most expanded variant
pub fn most_expanded(&self) -> &[FormatElement] {
self.0.last().expect(
"Most contain at least two elements, as guaranteed by the best fitting builder.",
)
self.into_iter().last().unwrap()
}
pub fn as_slice(&self) -> &[Box<[FormatElement]>] {
pub fn as_slice(&self) -> &[FormatElement] {
&self.0
}
/// Returns the least expanded variant
pub fn most_flat(&self) -> &[FormatElement] {
self.0.first().expect(
"Most contain at least two elements, as guaranteed by the best fitting builder.",
)
self.into_iter().next().unwrap()
}
}
impl Deref for BestFittingVariants {
type Target = [Box<[FormatElement]>];
type Target = [FormatElement];
fn deref(&self) -> &Self::Target {
self.as_slice()
}
}
/// Iterator over the variants stored in a flat slice of [`FormatElement`]s.
///
/// Every yielded item is a sub-slice of `elements`; the variant boundaries are located by
/// searching for the `Tag::StartBestFittingEntry` and `Tag::EndBestFittingEntry` tags.
pub struct BestFittingVariantsIter<'a> {
    /// The elements that have not been yielded yet.
    elements: &'a [FormatElement],
}
impl<'a> IntoIterator for &'a BestFittingVariants {
type Item = &'a Box<[FormatElement]>;
type IntoIter = std::slice::Iter<'a, Box<[FormatElement]>>;
type Item = &'a [FormatElement];
type IntoIter = BestFittingVariantsIter<'a>;
fn into_iter(self) -> Self::IntoIter {
self.as_slice().iter()
BestFittingVariantsIter { elements: &self.0 }
}
}
impl<'a> Iterator for BestFittingVariantsIter<'a> {
    type Item = &'a [FormatElement];

    /// Yields the next variant from the front.
    ///
    /// A variant spans from the leading `StartBestFittingEntry` tag up to and including the
    /// first `EndBestFittingEntry` tag, or to the end of the remaining elements when no end
    /// tag is present. Returns `None` when nothing is left or when the next element is not a
    /// `StartBestFittingEntry` tag.
    fn next(&mut self) -> Option<Self::Item> {
        let remaining = self.elements;

        // Guard: an empty slice or anything other than a start tag ends the iteration
        // without consuming any element.
        if !matches!(
            remaining.first(),
            Some(FormatElement::Tag(Tag::StartBestFittingEntry))
        ) {
            return None;
        }

        let closing_tag = remaining
            .iter()
            .position(|element| matches!(element, FormatElement::Tag(Tag::EndBestFittingEntry)));

        // The end tag is part of the variant, hence the `+ 1`.
        let variant_len = match closing_tag {
            Some(index) => index + 1,
            None => remaining.len(),
        };

        self.elements = &remaining[variant_len..];
        Some(&remaining[..variant_len])
    }

    /// Returns the last variant by taking it from the back instead of walking over every
    /// variant from the front.
    fn last(mut self) -> Option<Self::Item>
    where
        Self: Sized,
    {
        DoubleEndedIterator::next_back(&mut self)
    }
}
impl<'a> DoubleEndedIterator for BestFittingVariantsIter<'a> {
    /// Yields the next variant from the back.
    ///
    /// The variant starts at the right-most `StartBestFittingEntry` tag and runs to the end of
    /// the remaining elements. Returns `None` when no such tag is left.
    fn next_back(&mut self) -> Option<Self::Item> {
        let remaining = self.elements;

        let variant_start = remaining
            .iter()
            .rposition(|element| matches!(element, FormatElement::Tag(Tag::StartBestFittingEntry)))?;

        // Everything in front of the start tag stays available for later calls.
        self.elements = &remaining[..variant_start];
        Some(&remaining[variant_start..])
    }
}
// `next` does not modify `elements` when it returns `None`, so every later call to `next`
// returns `None` as well.
impl FusedIterator for BestFittingVariantsIter<'_> {}
pub trait FormatElements {
/// Returns true if this [`FormatElement`] is guaranteed to break across multiple lines by the printer.
/// This is the case if this format element recursively contains a:

View file

@ -35,7 +35,10 @@ impl Document {
enum Enclosing<'a> {
Group(&'a tag::Group),
ConditionalGroup(&'a tag::ConditionalGroup),
FitsExpanded(&'a tag::FitsExpanded),
FitsExpanded {
tag: &'a tag::FitsExpanded,
expands_before: bool,
},
BestFitting,
}
@ -43,7 +46,7 @@ impl Document {
match enclosing.last() {
Some(Enclosing::Group(group)) => group.propagate_expand(),
Some(Enclosing::ConditionalGroup(group)) => group.propagate_expand(),
Some(Enclosing::FitsExpanded(fits_expanded)) => fits_expanded.propagate_expand(),
Some(Enclosing::FitsExpanded { tag, .. }) => tag.propagate_expand(),
_ => {}
}
}
@ -85,23 +88,24 @@ impl Document {
FormatElement::BestFitting { variants, mode: _ } => {
enclosing.push(Enclosing::BestFitting);
for variant in variants {
propagate_expands(variant, enclosing, checked_interned);
}
// Best fitting acts as a boundary
expands = false;
propagate_expands(variants, enclosing, checked_interned);
enclosing.pop();
continue;
}
FormatElement::Tag(Tag::StartFitsExpanded(fits_expanded)) => {
enclosing.push(Enclosing::FitsExpanded(fits_expanded));
enclosing.push(Enclosing::FitsExpanded {
tag: fits_expanded,
expands_before: expands,
});
false
}
FormatElement::Tag(Tag::EndFitsExpanded) => {
enclosing.pop();
// Fits expanded acts as a boundary
expands = false;
if let Some(Enclosing::FitsExpanded { expands_before, .. }) =
enclosing.pop()
{
expands = expands_before;
}
continue;
}
FormatElement::Text {
@ -338,14 +342,20 @@ impl Format<IrFormatContext<'_>> for &[FormatElement] {
}
FormatElement::BestFitting { variants, mode } => {
write!(f, [token("best_fitting([")])?;
write!(f, [token("best_fitting(")])?;
if *mode != BestFittingMode::FirstLine {
write!(f, [text(&std::format!("mode: {mode:?}, "), None)])?;
}
write!(f, [token("[")])?;
f.write_elements([
FormatElement::Tag(StartIndent),
FormatElement::Line(LineMode::Hard),
]);
for variant in variants {
write!(f, [&**variant, hard_line_break()])?;
write!(f, [variant, hard_line_break()])?;
}
f.write_elements([
@ -353,13 +363,7 @@ impl Format<IrFormatContext<'_>> for &[FormatElement] {
FormatElement::Line(LineMode::Hard),
]);
write!(f, [token("]")])?;
if *mode != BestFittingMode::FirstLine {
write!(f, [text(&std::format!(", mode: {mode:?}"), None),])?;
}
write!(f, [token(")")])?;
write!(f, [token("])")])?;
}
FormatElement::Interned(interned) => {
@ -594,10 +598,10 @@ impl Format<IrFormatContext<'_>> for &[FormatElement] {
}
}
StartEntry => {
StartEntry | StartBestFittingEntry { .. } => {
// handled after the match for all start tags
}
EndEntry => write!(f, [ContentArrayEnd])?,
EndEntry | EndBestFittingEntry => write!(f, [ContentArrayEnd])?,
EndFill
| EndLabelled

View file

@ -83,6 +83,9 @@ pub enum Tag {
StartFitsExpanded(FitsExpanded),
EndFitsExpanded,
StartBestFittingEntry,
EndBestFittingEntry,
}
impl Tag {
@ -103,6 +106,7 @@ impl Tag {
| Tag::StartVerbatim(_)
| Tag::StartLabelled(_)
| Tag::StartFitsExpanded(_)
| Tag::StartBestFittingEntry,
)
}
@ -129,6 +133,7 @@ impl Tag {
StartVerbatim(_) | EndVerbatim => TagKind::Verbatim,
StartLabelled(_) | EndLabelled => TagKind::Labelled,
StartFitsExpanded { .. } | EndFitsExpanded => TagKind::FitsExpanded,
StartBestFittingEntry { .. } | EndBestFittingEntry => TagKind::BestFittingEntry,
}
}
}
@ -152,6 +157,7 @@ pub enum TagKind {
Verbatim,
Labelled,
FitsExpanded,
BestFittingEntry,
}
#[derive(Debug, Copy, Default, Clone, Eq, PartialEq)]

View file

@ -288,7 +288,9 @@ impl<'a> Printer<'a> {
stack.push(TagKind::FitsExpanded, args);
}
FormatElement::Tag(tag @ (StartLabelled(_) | StartEntry)) => {
FormatElement::Tag(
tag @ (StartLabelled(_) | StartEntry | StartBestFittingEntry { .. }),
) => {
stack.push(tag.kind(), args);
}
@ -305,6 +307,7 @@ impl<'a> Printer<'a> {
| EndFitsExpanded
| EndVerbatim
| EndLineSuffix
| EndBestFittingEntry
| EndFill),
) => {
stack.pop(tag.kind())?;
@ -495,47 +498,64 @@ impl<'a> Printer<'a> {
if args.mode().is_flat() && self.state.measured_group_fits {
queue.extend_back(variants.most_flat());
self.print_entry(queue, stack, args)
self.print_entry(queue, stack, args, TagKind::BestFittingEntry)
} else {
self.state.measured_group_fits = true;
let normal_variants = &variants[..variants.len() - 1];
let mut variants_iter = variants.into_iter();
let mut current = variants_iter.next().unwrap();
for variant in normal_variants {
for next in variants_iter {
// Test if this variant fits and if so, use it. Otherwise try the next
// variant.
// Try to fit only the first variant on a single line
if !matches!(variant.first(), Some(&FormatElement::Tag(Tag::StartEntry))) {
return invalid_start_tag(TagKind::Entry, variant.first());
if !matches!(
current.first(),
Some(&FormatElement::Tag(Tag::StartBestFittingEntry))
) {
return invalid_start_tag(TagKind::BestFittingEntry, current.first());
}
// Skip the first element because we want to override the args for the entry and the
// args must be popped from the stack as soon as it sees the matching end entry.
let content = &variant[1..];
let content = &current[1..];
let entry_args = args
.with_print_mode(PrintMode::Flat)
.with_measure_mode(MeasureMode::from(mode));
queue.extend_back(content);
stack.push(TagKind::Entry, entry_args);
stack.push(TagKind::BestFittingEntry, entry_args);
let variant_fits = self.fits(queue, stack)?;
stack.pop(TagKind::Entry)?;
stack.pop(TagKind::BestFittingEntry)?;
// Remove the content slice because printing needs the variant WITH the start entry
let popped_slice = queue.pop_slice();
debug_assert_eq!(popped_slice, Some(content));
if variant_fits {
queue.extend_back(variant);
return self.print_entry(queue, stack, args.with_print_mode(PrintMode::Flat));
queue.extend_back(current);
return self.print_entry(
queue,
stack,
args.with_print_mode(PrintMode::Flat),
TagKind::BestFittingEntry,
);
}
current = next;
}
// At this stage current is the most expanded.
// No variant fits, take the last (most expanded) as fallback
let most_expanded = variants.most_expanded();
queue.extend_back(most_expanded);
self.print_entry(queue, stack, args.with_print_mode(PrintMode::Expanded))
queue.extend_back(current);
self.print_entry(
queue,
stack,
args.with_print_mode(PrintMode::Expanded),
TagKind::BestFittingEntry,
)
}
}
@ -686,7 +706,7 @@ impl<'a> Printer<'a> {
stack: &mut PrintCallStack,
args: PrintElementArgs,
) -> PrintResult<()> {
self.print_entry(queue, stack, args)
self.print_entry(queue, stack, args, TagKind::Entry)
}
/// Semantic alias for [`Self::print_entry`] for fill separators.
@ -696,7 +716,7 @@ impl<'a> Printer<'a> {
stack: &mut PrintCallStack,
args: PrintElementArgs,
) -> PrintResult<()> {
self.print_entry(queue, stack, args)
self.print_entry(queue, stack, args, TagKind::Entry)
}
/// Fully print an element (print the element itself and all its descendants)
@ -708,32 +728,31 @@ impl<'a> Printer<'a> {
queue: &mut PrintQueue<'a>,
stack: &mut PrintCallStack,
args: PrintElementArgs,
kind: TagKind,
) -> PrintResult<()> {
let start_entry = queue.top();
if !matches!(start_entry, Some(&FormatElement::Tag(Tag::StartEntry))) {
return invalid_start_tag(TagKind::Entry, start_entry);
if queue
.pop()
.is_some_and(|start| start.tag_kind() == Some(kind))
{
stack.push(kind, args);
} else {
return invalid_start_tag(kind, start_entry);
}
let mut depth = 0;
let mut depth = 1u32;
while let Some(element) = queue.pop() {
match element {
FormatElement::Tag(Tag::StartEntry) => {
// Handle the start of the first element by pushing the args on the stack.
if depth == 0 {
depth = 1;
stack.push(TagKind::Entry, args);
continue;
}
FormatElement::Tag(Tag::StartEntry | Tag::StartBestFittingEntry) => {
depth += 1;
}
FormatElement::Tag(Tag::EndEntry) => {
FormatElement::Tag(end_tag @ (Tag::EndEntry | Tag::EndBestFittingEntry)) => {
depth -= 1;
// Reached the end entry, pop the entry from the stack and return.
if depth == 0 {
stack.pop(TagKind::Entry)?;
stack.pop(end_tag.kind())?;
return Ok(());
}
}
@ -745,7 +764,7 @@ impl<'a> Printer<'a> {
self.print_element(stack, queue, element)?;
}
invalid_end_tag(TagKind::Entry, stack.top_kind())
invalid_end_tag(kind, stack.top_kind())
}
fn print_char(&mut self, char: char) {
@ -1148,11 +1167,14 @@ impl<'a, 'print> FitsMeasurer<'a, 'print> {
PrintMode::Expanded => (variants.most_expanded(), args),
};
if !matches!(slice.first(), Some(FormatElement::Tag(Tag::StartEntry))) {
return invalid_start_tag(TagKind::Entry, slice.first());
if !matches!(
slice.first(),
Some(FormatElement::Tag(Tag::StartBestFittingEntry))
) {
return invalid_start_tag(TagKind::BestFittingEntry, slice.first());
}
self.stack.push(TagKind::Entry, args);
self.stack.push(TagKind::BestFittingEntry, args);
self.queue.extend_back(&slice[1..]);
}
@ -1277,7 +1299,11 @@ impl<'a, 'print> FitsMeasurer<'a, 'print> {
}
FormatElement::Tag(
tag @ (StartFill | StartVerbatim(_) | StartLabelled(_) | StartEntry),
tag @ (StartFill
| StartVerbatim(_)
| StartLabelled(_)
| StartEntry
| StartBestFittingEntry { .. }),
) => {
self.stack.push(tag.kind(), args);
}
@ -1294,6 +1320,7 @@ impl<'a, 'print> FitsMeasurer<'a, 'print> {
| EndAlign
| EndDedent
| EndIndent
| EndBestFittingEntry
| EndFitsExpanded),
) => {
self.stack.pop(tag.kind())?;