Refactor the solve inferrer

This commit is contained in:
Ayaz Hafiz 2023-04-01 12:45:08 -05:00
parent 9d8d36b532
commit bfcafb0be3
No known key found for this signature in database
GPG key ID: 0E2A37416A25EF58
2 changed files with 164 additions and 72 deletions

View file

@ -582,14 +582,23 @@ pub fn find_type_at(region: Region, decls: &Declarations) -> Option<Variable> {
visitor.typ visitor.typ
} }
pub enum FoundSymbol {
/// Specialization(T, foo1) is the specialization of foo for T.
Specialization(Symbol, Symbol),
/// AbilityMember(Foo, foo) is the ability member foo of Foo.
AbilityMember(Symbol, Symbol),
/// Raw symbol, not specialized to anything.
Symbol(Symbol),
}
/// Given an ability Foo has foo : ..., returns (T, foo1) if the symbol at the given region is a /// Given an ability Foo has foo : ..., returns (T, foo1) if the symbol at the given region is a
/// symbol foo1 that specializes foo for T. Otherwise if the symbol is foo but the specialization /// symbol foo1 that specializes foo for T. Otherwise if the symbol is foo but the specialization
/// is unknown, (Foo, foo) is returned. Otherwise [None] is returned. /// is unknown, (Foo, foo) is returned. Otherwise [None] is returned.
pub fn find_ability_member_and_owning_type_at( pub fn find_symbol_at(
region: Region, region: Region,
decls: &Declarations, decls: &Declarations,
abilities_store: &AbilitiesStore, abilities_store: &AbilitiesStore,
) -> Option<(Symbol, Symbol)> { ) -> Option<FoundSymbol> {
let mut visitor = Finder { let mut visitor = Finder {
region, region,
found: None, found: None,
@ -601,7 +610,7 @@ pub fn find_ability_member_and_owning_type_at(
struct Finder<'a> { struct Finder<'a> {
region: Region, region: Region,
abilities_store: &'a AbilitiesStore, abilities_store: &'a AbilitiesStore,
found: Option<(Symbol, Symbol)>, found: Option<FoundSymbol>,
} }
impl Visitor for Finder<'_> { impl Visitor for Finder<'_> {
@ -611,16 +620,19 @@ pub fn find_ability_member_and_owning_type_at(
fn visit_pattern(&mut self, pattern: &Pattern, region: Region, _opt_var: Option<Variable>) { fn visit_pattern(&mut self, pattern: &Pattern, region: Region, _opt_var: Option<Variable>) {
if region == self.region { if region == self.region {
if let Pattern::AbilityMemberSpecialization { match pattern {
ident: spec_symbol, Pattern::AbilityMemberSpecialization {
specializes: _, ident: spec_symbol,
} = pattern specializes: _,
{ } => {
debug_assert!(self.found.is_none()); debug_assert!(self.found.is_none());
let spec_type = let spec_type =
find_specialization_type_of_symbol(*spec_symbol, self.abilities_store) find_specialization_type_of_symbol(*spec_symbol, self.abilities_store)
.unwrap(); .unwrap();
self.found = Some((spec_type, *spec_symbol)) self.found = Some(FoundSymbol::Specialization(spec_type, *spec_symbol))
}
Pattern::Identifier(symbol) => self.found = Some(FoundSymbol::Symbol(*symbol)),
_ => {}
} }
} }
@ -640,7 +652,7 @@ pub fn find_ability_member_and_owning_type_at(
self.abilities_store, self.abilities_store,
) )
.unwrap(); .unwrap();
Some((spec_type, spec_symbol)) Some(FoundSymbol::Specialization(spec_type, spec_symbol))
} }
None => { None => {
let parent_ability = self let parent_ability = self
@ -648,7 +660,7 @@ pub fn find_ability_member_and_owning_type_at(
.member_def(member_symbol) .member_def(member_symbol)
.unwrap() .unwrap()
.parent_ability; .parent_ability;
Some((parent_ability, member_symbol)) Some(FoundSymbol::AbilityMember(parent_ability, member_symbol))
} }
}; };
return; return;

View file

@ -3,8 +3,9 @@ use std::{error::Error, io, path::PathBuf};
use lazy_static::lazy_static; use lazy_static::lazy_static;
use regex::Regex; use regex::Regex;
use roc_can::{ use roc_can::{
abilities::AbilitiesStore,
expr::Declarations, expr::Declarations,
traverse::{find_ability_member_and_owning_type_at, find_type_at}, traverse::{find_symbol_at, find_type_at, FoundSymbol},
}; };
use roc_load::LoadedModule; use roc_load::LoadedModule;
use roc_module::symbol::{Interns, ModuleId}; use roc_module::symbol::{Interns, ModuleId};
@ -13,7 +14,10 @@ use roc_problem::can::Problem;
use roc_region::all::{LineColumn, LineColumnRegion, LineInfo, Region}; use roc_region::all::{LineColumn, LineColumnRegion, LineInfo, Region};
use roc_reporting::report::{can_problem, type_problem, RocDocAllocator}; use roc_reporting::report::{can_problem, type_problem, RocDocAllocator};
use roc_solve_problem::TypeError; use roc_solve_problem::TypeError;
use roc_types::pretty_print::{name_and_print_var, DebugPrint}; use roc_types::{
pretty_print::{name_and_print_var, DebugPrint},
subs::{Subs, Variable},
};
fn promote_expr_to_module(src: &str) -> String { fn promote_expr_to_module(src: &str) -> String {
let mut buffer = String::from(indoc::indoc!( let mut buffer = String::from(indoc::indoc!(
@ -134,12 +138,18 @@ lazy_static! {
/// inst # instantiate the given generic instance /// inst # instantiate the given generic instance
/// ``` /// ```
static ref RE_TYPE_QUERY: Regex = static ref RE_TYPE_QUERY: Regex =
Regex::new(r#"(?P<where>\^+)(?:\{-(?P<sub>\d+)\})?"#).unwrap(); Regex::new(r#"(?P<where>\^+)(?:\{(?P<directives>.*?)\})?"#).unwrap();
static ref RE_DIRECTIVE : Regex =
Regex::new(r#"(?:-(?P<sub>\d+))|(?P<inst>inst)"#).unwrap();
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct TypeQuery { pub struct TypeQuery {
query_region: Region, query_region: Region,
/// If true, the query is under a function call, which should be instantiated with the present
/// value and have its nested queries printed.
instantiate: bool,
source: String, source: String,
comment_column: u32, comment_column: u32,
source_line_column: LineColumn, source_line_column: LineColumn,
@ -179,10 +189,23 @@ fn parse_queries(src: &str) -> Vec<TypeQuery> {
.to_string(); .to_string();
let wher = capture.name("where").unwrap(); let wher = capture.name("where").unwrap();
let subtract_col = capture
.name("sub") let mut subtract_col = 0u32;
.and_then(|m| str::parse(m.as_str()).ok()) let mut instantiate = false;
.unwrap_or(0);
if let Some(directives) = capture.name("directives") {
for directive in directives.as_str().split(',') {
let directive = RE_DIRECTIVE
.captures(directive)
.expect(&format!("directive {directive} must match RE_DIRECTIVE"));
if let Some(sub) = directive.name("sub") {
subtract_col += sub.as_str().parse::<u32>().expect("must be a number");
}
if directive.name("inst").is_some() {
instantiate = true;
}
}
}
let (source_start, source_end) = (wher.start() as u32, wher.end() as u32); let (source_start, source_end) = (wher.start() as u32, wher.end() as u32);
let (query_start, query_end) = (source_start - subtract_col, source_end - subtract_col); let (query_start, query_end) = (source_start - subtract_col, source_end - subtract_col);
@ -211,6 +234,7 @@ fn parse_queries(src: &str) -> Vec<TypeQuery> {
source, source,
comment_column, comment_column,
source_line_column, source_line_column,
instantiate,
}); });
} }
} }
@ -335,58 +359,21 @@ pub fn infer_queries(src: &str, options: InferOptions) -> Result<InferredProgram
} }
let mut inferred_queries = Vec::with_capacity(queries.len()); let mut inferred_queries = Vec::with_capacity(queries.len());
for TypeQuery {
query_region,
source,
comment_column,
source_line_column,
} in queries.into_iter()
{
let start = query_region.start().offset;
let end = query_region.end().offset;
let text = &src[start as usize..end as usize];
let var = find_type_at(query_region, &declarations)
.ok_or_else(|| format!("No type for {:?} ({:?})!", &text, query_region))?;
let snapshot = subs.snapshot(); let mut ctx = QueryCtx {
let actual_str = name_and_print_var( all_queries: &queries,
var, source: &src,
subs, declarations: &declarations,
home, subs,
&interns, abilities_store: &abilities_store,
DebugPrint { home,
print_lambda_sets: true, interns: &interns,
print_only_under_alias: options.print_only_under_alias, options,
ignore_polarity: true, };
print_weakened_vars: true,
},
);
subs.rollback_to(snapshot);
let (header, elaboration) = match find_ability_member_and_owning_type_at( for query in queries.iter() {
query_region, let answer = ctx.answer(query)?;
&declarations, inferred_queries.push(answer);
&abilities_store,
) {
Some((spec_type, spec_symbol)) => (
InferredHeader::Specialization(format!(
"{}#{}({})",
spec_type.as_str(&interns),
text,
spec_symbol.ident_id().index(),
)),
actual_str,
),
None => (InferredHeader::Source(text.to_owned()), actual_str),
};
inferred_queries.push(InferredQuery {
header,
elaboration,
comment_column,
source_line_column,
source,
});
} }
Ok(InferredProgram { Ok(InferredProgram {
@ -399,6 +386,99 @@ pub fn infer_queries(src: &str, options: InferOptions) -> Result<InferredProgram
}) })
} }
struct QueryCtx<'a> {
all_queries: &'a [TypeQuery],
source: &'a str,
declarations: &'a Declarations,
subs: &'a mut Subs,
abilities_store: &'a AbilitiesStore,
home: ModuleId,
interns: &'a Interns,
options: InferOptions,
}
impl<'a> QueryCtx<'a> {
fn answer(&mut self, query: &TypeQuery) -> Result<InferredQuery, Box<dyn Error>> {
let TypeQuery {
query_region,
source,
comment_column,
source_line_column,
instantiate,
} = query;
let start = query_region.start().offset;
let end = query_region.end().offset;
let text = &self.source[start as usize..end as usize];
let var = find_type_at(*query_region, self.declarations)
.ok_or_else(|| format!("No type for {:?} ({:?})!", &text, query_region))?;
let snapshot = self.subs.snapshot();
let (header, elaboration) = if *instantiate {
self.infer_instantiated(var)
} else {
self.infer_direct(var, *query_region, text)
};
self.subs.rollback_to(snapshot);
Ok(InferredQuery {
header,
elaboration,
comment_column: *comment_column,
source_line_column: *source_line_column,
source: source.to_string(),
})
}
fn infer_direct(
&mut self,
var: Variable,
query_region: Region,
text: &str,
) -> (InferredHeader, String) {
let actual_str = name_and_print_var(
var,
self.subs,
self.home,
self.interns,
DebugPrint {
print_lambda_sets: true,
print_only_under_alias: self.options.print_only_under_alias,
ignore_polarity: true,
print_weakened_vars: true,
},
);
let (header, elaboration) =
match find_symbol_at(query_region, self.declarations, self.abilities_store) {
Some(found_symbol) => {
let header = match found_symbol {
FoundSymbol::Specialization(spec_type, spec_symbol)
| FoundSymbol::AbilityMember(spec_type, spec_symbol) => {
InferredHeader::Specialization(format!(
"{}#{}({})",
spec_type.as_str(self.interns),
text,
spec_symbol.ident_id().index(),
))
}
FoundSymbol::Symbol(symbol) => {
InferredHeader::Source(symbol.as_str(self.interns).to_owned())
}
};
(header, actual_str)
}
None => (InferredHeader::Source(text.to_owned()), actual_str),
};
(header, elaboration)
}
fn infer_instantiated(&self, var: Variable) -> (InferredHeader, String) {
todo!()
}
}
pub fn infer_queries_help(src: &str, expected: impl FnOnce(&str), options: InferOptions) { pub fn infer_queries_help(src: &str, expected: impl FnOnce(&str), options: InferOptions) {
let InferredProgram { let InferredProgram {
program, program,