Refactor the solve inferrer

This commit is contained in:
Ayaz Hafiz 2023-04-01 12:45:08 -05:00
parent 9d8d36b532
commit bfcafb0be3
No known key found for this signature in database
GPG key ID: 0E2A37416A25EF58
2 changed files with 164 additions and 72 deletions

View file

@ -3,8 +3,9 @@ use std::{error::Error, io, path::PathBuf};
use lazy_static::lazy_static;
use regex::Regex;
use roc_can::{
abilities::AbilitiesStore,
expr::Declarations,
traverse::{find_ability_member_and_owning_type_at, find_type_at},
traverse::{find_symbol_at, find_type_at, FoundSymbol},
};
use roc_load::LoadedModule;
use roc_module::symbol::{Interns, ModuleId};
@ -13,7 +14,10 @@ use roc_problem::can::Problem;
use roc_region::all::{LineColumn, LineColumnRegion, LineInfo, Region};
use roc_reporting::report::{can_problem, type_problem, RocDocAllocator};
use roc_solve_problem::TypeError;
use roc_types::pretty_print::{name_and_print_var, DebugPrint};
use roc_types::{
pretty_print::{name_and_print_var, DebugPrint},
subs::{Subs, Variable},
};
fn promote_expr_to_module(src: &str) -> String {
let mut buffer = String::from(indoc::indoc!(
@ -134,12 +138,18 @@ lazy_static! {
/// inst # instantiate the given generic instance
/// ```
static ref RE_TYPE_QUERY: Regex =
Regex::new(r#"(?P<where>\^+)(?:\{-(?P<sub>\d+)\})?"#).unwrap();
Regex::new(r#"(?P<where>\^+)(?:\{(?P<directives>.*?)\})?"#).unwrap();
static ref RE_DIRECTIVE : Regex =
Regex::new(r#"(?:-(?P<sub>\d+))|(?P<inst>inst)"#).unwrap();
}
#[derive(Debug, Clone)]
pub struct TypeQuery {
query_region: Region,
/// If true, the query is under a function call, which should be instantiated with the present
/// value and have its nested queries printed.
instantiate: bool,
source: String,
comment_column: u32,
source_line_column: LineColumn,
@ -179,10 +189,23 @@ fn parse_queries(src: &str) -> Vec<TypeQuery> {
.to_string();
let wher = capture.name("where").unwrap();
let subtract_col = capture
.name("sub")
.and_then(|m| str::parse(m.as_str()).ok())
.unwrap_or(0);
let mut subtract_col = 0u32;
let mut instantiate = false;
if let Some(directives) = capture.name("directives") {
for directive in directives.as_str().split(',') {
let directive = RE_DIRECTIVE
.captures(directive)
.expect(&format!("directive {directive} must match RE_DIRECTIVE"));
if let Some(sub) = directive.name("sub") {
subtract_col += sub.as_str().parse::<u32>().expect("must be a number");
}
if directive.name("inst").is_some() {
instantiate = true;
}
}
}
let (source_start, source_end) = (wher.start() as u32, wher.end() as u32);
let (query_start, query_end) = (source_start - subtract_col, source_end - subtract_col);
@ -211,6 +234,7 @@ fn parse_queries(src: &str) -> Vec<TypeQuery> {
source,
comment_column,
source_line_column,
instantiate,
});
}
}
@ -335,58 +359,21 @@ pub fn infer_queries(src: &str, options: InferOptions) -> Result<InferredProgram
}
let mut inferred_queries = Vec::with_capacity(queries.len());
for TypeQuery {
query_region,
source,
comment_column,
source_line_column,
} in queries.into_iter()
{
let start = query_region.start().offset;
let end = query_region.end().offset;
let text = &src[start as usize..end as usize];
let var = find_type_at(query_region, &declarations)
.ok_or_else(|| format!("No type for {:?} ({:?})!", &text, query_region))?;
let snapshot = subs.snapshot();
let actual_str = name_and_print_var(
var,
subs,
home,
&interns,
DebugPrint {
print_lambda_sets: true,
print_only_under_alias: options.print_only_under_alias,
ignore_polarity: true,
print_weakened_vars: true,
},
);
subs.rollback_to(snapshot);
let mut ctx = QueryCtx {
all_queries: &queries,
source: &src,
declarations: &declarations,
subs,
abilities_store: &abilities_store,
home,
interns: &interns,
options,
};
let (header, elaboration) = match find_ability_member_and_owning_type_at(
query_region,
&declarations,
&abilities_store,
) {
Some((spec_type, spec_symbol)) => (
InferredHeader::Specialization(format!(
"{}#{}({})",
spec_type.as_str(&interns),
text,
spec_symbol.ident_id().index(),
)),
actual_str,
),
None => (InferredHeader::Source(text.to_owned()), actual_str),
};
inferred_queries.push(InferredQuery {
header,
elaboration,
comment_column,
source_line_column,
source,
});
for query in queries.iter() {
let answer = ctx.answer(query)?;
inferred_queries.push(answer);
}
Ok(InferredProgram {
@ -399,6 +386,99 @@ pub fn infer_queries(src: &str, options: InferOptions) -> Result<InferredProgram
})
}
struct QueryCtx<'a> {
all_queries: &'a [TypeQuery],
source: &'a str,
declarations: &'a Declarations,
subs: &'a mut Subs,
abilities_store: &'a AbilitiesStore,
home: ModuleId,
interns: &'a Interns,
options: InferOptions,
}
impl<'a> QueryCtx<'a> {
fn answer(&mut self, query: &TypeQuery) -> Result<InferredQuery, Box<dyn Error>> {
let TypeQuery {
query_region,
source,
comment_column,
source_line_column,
instantiate,
} = query;
let start = query_region.start().offset;
let end = query_region.end().offset;
let text = &self.source[start as usize..end as usize];
let var = find_type_at(*query_region, self.declarations)
.ok_or_else(|| format!("No type for {:?} ({:?})!", &text, query_region))?;
let snapshot = self.subs.snapshot();
let (header, elaboration) = if *instantiate {
self.infer_instantiated(var)
} else {
self.infer_direct(var, *query_region, text)
};
self.subs.rollback_to(snapshot);
Ok(InferredQuery {
header,
elaboration,
comment_column: *comment_column,
source_line_column: *source_line_column,
source: source.to_string(),
})
}
fn infer_direct(
&mut self,
var: Variable,
query_region: Region,
text: &str,
) -> (InferredHeader, String) {
let actual_str = name_and_print_var(
var,
self.subs,
self.home,
self.interns,
DebugPrint {
print_lambda_sets: true,
print_only_under_alias: self.options.print_only_under_alias,
ignore_polarity: true,
print_weakened_vars: true,
},
);
let (header, elaboration) =
match find_symbol_at(query_region, self.declarations, self.abilities_store) {
Some(found_symbol) => {
let header = match found_symbol {
FoundSymbol::Specialization(spec_type, spec_symbol)
| FoundSymbol::AbilityMember(spec_type, spec_symbol) => {
InferredHeader::Specialization(format!(
"{}#{}({})",
spec_type.as_str(self.interns),
text,
spec_symbol.ident_id().index(),
))
}
FoundSymbol::Symbol(symbol) => {
InferredHeader::Source(symbol.as_str(self.interns).to_owned())
}
};
(header, actual_str)
}
None => (InferredHeader::Source(text.to_owned()), actual_str),
};
(header, elaboration)
}
fn infer_instantiated(&self, var: Variable) -> (InferredHeader, String) {
todo!()
}
}
pub fn infer_queries_help(src: &str, expected: impl FnOnce(&str), options: InferOptions) {
let InferredProgram {
program,