feat: argumen type narrowing

This commit is contained in:
Shunsuke Shibayama 2024-03-17 20:57:13 +09:00
parent 95e675cccd
commit fd873a1916
6 changed files with 268 additions and 45 deletions

View file

@ -39,6 +39,13 @@ impl PartialEq for Str {
}
}
impl PartialEq<Str> for &mut Str {
#[inline]
fn eq(&self, other: &Str) -> bool {
self[..] == other[..]
}
}
impl PartialEq<str> for Str {
#[inline]
fn eq(&self, other: &str) -> bool {

View file

@ -19,7 +19,9 @@ use erg_parser::Parser;
use crate::ty::free::{CanbeFree, Constraint, HasLevel};
use crate::ty::typaram::{IntervalOp, OpKind, TyParam, TyParamLambda, TyParamOrdering};
use crate::ty::value::ValueObj;
use crate::ty::{constructors::*, Predicate, RefinementType, VisibilityModifier};
use crate::ty::{
constructors::*, CastTarget, GuardType, Predicate, RefinementType, VisibilityModifier,
};
use crate::ty::{Field, HasType, ParamTy, SubrKind, SubrType, Type};
use crate::type_feature_error;
use crate::varinfo::{AbsLocation, VarInfo};
@ -271,7 +273,7 @@ impl Context {
}
}
}
let var_args = if let Some(var_args) = sig.params.var_params.as_ref() {
let var_params = if let Some(var_args) = sig.params.var_params.as_ref() {
let opt_decl_t = opt_decl_sig_t
.as_ref()
.and_then(|subr| subr.var_params.as_ref().map(|v| v.as_ref()));
@ -313,7 +315,7 @@ impl Context {
}
}
}
let kw_var_args = if let Some(kw_var_args) = sig.params.kw_var_params.as_ref() {
let kw_var_params = if let Some(kw_var_args) = sig.params.kw_var_params.as_ref() {
let opt_decl_t = opt_decl_sig_t
.as_ref()
.and_then(|subr| subr.kw_var_params.as_ref().map(|v| v.as_ref()));
@ -346,7 +348,15 @@ impl Context {
mode,
false,
) {
Ok(ty) => ty,
Ok(ty) => {
let params = non_defaults
.iter()
.chain(&var_params)
.chain(&defaults)
.chain(&kw_var_params)
.filter_map(|pt| pt.name());
self.recover_guard(ty, params)
}
Err(es) => {
errs.extend(es);
Type::Failure
@ -363,9 +373,21 @@ impl Context {
};
// tmp_tv_cache.warn_isolated_vars(self);
let typ = if sig.ident.is_procedural() {
proc(non_defaults, var_args, defaults, kw_var_args, spec_return_t)
proc(
non_defaults,
var_params,
defaults,
kw_var_params,
spec_return_t,
)
} else {
func(non_defaults, var_args, defaults, kw_var_args, spec_return_t)
func(
non_defaults,
var_params,
defaults,
kw_var_params,
spec_return_t,
)
};
if errs.is_empty() {
Ok(typ)
@ -1861,6 +1883,29 @@ impl Context {
}
}
#[inline]
fn recover_guard<'a>(&self, return_t: Type, mut params: impl Iterator<Item = &'a Str>) -> Type {
match return_t {
Type::Guard(GuardType {
namespace,
target: CastTarget::Expr(expr),
to,
}) => {
let target = if let Some(nth) = params.position(|p| Some(p) == expr.get_name()) {
CastTarget::arg(nth, expr.get_name().unwrap().clone(), ().loc())
} else {
CastTarget::Expr(expr)
};
Type::Guard(GuardType {
namespace,
target,
to,
})
}
_ => return_t,
}
}
// FIXME: opt_decl_t must be disassembled for each polymorphic type
pub(crate) fn instantiate_typespec_full(
&self,
@ -2001,6 +2046,28 @@ impl Context {
// TODO: エラー処理(リテラルでない)はパーサーにやらせる
TypeSpec::Enum(set) => {
let mut new_set = set! {};
// guard type (e.g. {x in Int})
if set.pos_args.len() == 1 {
let expr = &set.pos_args().next().unwrap().expr;
match expr {
ConstExpr::BinOp(bin) if bin.op.is(TokenKind::InOp) => {
if let Ok(to) = self.instantiate_const_expr_as_type(
&bin.rhs,
None,
tmp_tv_cache,
not_found_is_qvar,
) {
let target = CastTarget::expr(bin.lhs.clone().downgrade());
return Ok(Type::Guard(GuardType::new(
self.name.clone(),
target,
to,
)));
}
}
_ => {}
}
}
for arg in set.pos_args() {
new_set.insert(self.instantiate_const_expr(
&arg.expr,
@ -2132,6 +2199,13 @@ impl Context {
Type::Failure
}
};
let params = non_defaults
.iter()
.chain(&var_params)
.chain(&defaults)
.chain(&kw_var_params)
.filter_map(|pt| pt.name());
let return_t = self.recover_guard(return_t, params);
// no quantification at this point (in `generalize_t`)
if errs.is_empty() {
Ok(subr_t(

View file

@ -2413,6 +2413,7 @@ impl Context {
pub(crate) fn cast(
&mut self,
guard: GuardType,
args: Option<&hir::Args>,
overwritten: &mut Vec<(VarName, VarInfo)>,
) -> TyCheckResult<()> {
match &guard.target {
@ -2443,15 +2444,57 @@ impl Context {
return Err(errs);
}
}
Ok(())
}
CastTarget::Param { .. } => {
// TODO:
// ```
// i: Obj
// is_int: (x: Obj) -> {x in Int} # change the 0th arg type to Int
// assert is_int i
// i: Int
// ```
CastTarget::Arg { nth, name, loc } => {
if let Some(name) = args
.and_then(|args| args.get(*nth))
.and_then(|ex| ex.local_name())
{
let vi = if let Some((name, vi)) = self.locals.remove_entry(name) {
overwritten.push((name, vi.clone()));
vi
} else if let Some((n, vi)) = self.get_var_kv(name) {
overwritten.push((n.clone(), vi.clone()));
vi.clone()
} else {
VarInfo::nd_parameter(
*guard.to.clone(),
self.absolutize(().loc()),
self.name.clone(),
)
};
match self.recover_typarams(&vi.t, &guard) {
Ok(t) => {
self.locals
.insert(VarName::from_str(Str::rc(name)), VarInfo { t, ..vi });
}
Err(errs) => {
self.locals.insert(VarName::from_str(Str::rc(name)), vi);
return Err(errs);
}
}
Ok(())
} else {
let target = CastTarget::Var {
name: name.clone(),
loc: *loc,
};
let guard = GuardType::new(guard.namespace, target, *guard.to);
self.cast(guard, args, overwritten)
}
}
CastTarget::Expr(_) => {
self.guards.push(guard);
Ok(())
}
}
Ok(())
}
pub(crate) fn inc_ref<L: Locational>(

View file

@ -60,16 +60,6 @@ use crate::{AccessKind, GenericHIRBuilder};
use VisibilityModifier::*;
pub fn expr_to_cast_target(expr: &ast::Expr) -> CastTarget {
match expr {
ast::Expr::Accessor(ast::Accessor::Ident(ident)) => CastTarget::Var {
name: ident.inspect().clone(),
loc: ident.loc(),
},
_ => CastTarget::expr(expr.clone()),
}
}
/// Checks & infers types of an AST, and convert (lower) it into a HIR
#[derive(Debug)]
pub struct GenericASTLowerer<ASTBuilder: ASTBuildable = DefaultASTBuilder> {
@ -1010,6 +1000,32 @@ impl<A: ASTBuildable> GenericASTLowerer<A> {
}
}
pub fn expr_to_cast_target(&self, expr: &ast::Expr) -> CastTarget {
match expr {
ast::Expr::Accessor(ast::Accessor::Ident(ident)) => {
if let Some(nth) = self
.module
.context
.params
.iter()
.position(|(name, _)| name.as_ref() == Some(&ident.name))
{
CastTarget::Arg {
nth,
name: ident.inspect().clone(),
loc: ident.loc(),
}
} else {
CastTarget::Var {
name: ident.inspect().clone(),
loc: ident.loc(),
}
}
}
_ => CastTarget::expr(expr.clone()),
}
}
fn get_bin_guard_type(&self, op: &Token, lhs: &ast::Expr, rhs: &ast::Expr) -> Option<Type> {
match op.kind {
TokenKind::AndOp => {
@ -1025,9 +1041,9 @@ impl<A: ASTBuildable> GenericASTLowerer<A> {
_ => {}
}
let target = if op.kind == TokenKind::ContainsOp {
expr_to_cast_target(rhs)
self.expr_to_cast_target(rhs)
} else {
expr_to_cast_target(lhs)
self.expr_to_cast_target(lhs)
};
let namespace = self.module.context.name.clone();
match op.kind {
@ -1463,7 +1479,15 @@ impl<A: ASTBuildable> GenericASTLowerer<A> {
if let Some(Type::Guard(guard)) =
call.args.get_left_or_key("test").map(|exp| exp.ref_t())
{
self.module.context.cast(guard.clone(), &mut vec![])?;
let test = call.args.get_left_or_key("test").unwrap();
let test_args = if let hir::Expr::Call(call) = test {
Some(&call.args)
} else {
None
};
self.module
.context
.cast(guard.clone(), test_args, &mut vec![])?;
}
Ok(())
}
@ -1692,7 +1716,7 @@ impl<A: ASTBuildable> GenericASTLowerer<A> {
mem::take(&mut self.module.context.get_mut_outer().unwrap().guards)
};
for guard in guards.into_iter() {
if let Err(errs) = self.module.context.cast(guard, &mut overwritten) {
if let Err(errs) = self.module.context.cast(guard, None, &mut overwritten) {
self.errs.extend(errs);
}
}

View file

@ -199,6 +199,13 @@ impl ParamTy {
}
}
pub fn name_mut(&mut self) -> Option<&mut Str> {
match self {
Self::Pos(_) => None,
Self::Kw { name, .. } | Self::KwWithDefault { name, .. } => Some(name),
}
}
pub const fn typ(&self) -> &Type {
match self {
Self::Pos(ty) | Self::Kw { ty, .. } | Self::KwWithDefault { ty, .. } => ty,
@ -713,6 +720,55 @@ impl SubrType {
.map(|t| (t.name().cloned(), t.typ().ownership()));
ArgsOwnership::new(nd_args, var_args, d_args, kw_var_args)
}
pub fn _replace(mut self, target: &Type, to: &Type) -> Self {
for nd in self.non_default_params.iter_mut() {
*nd.typ_mut() = std::mem::take(nd.typ_mut())._replace(target, to);
}
if let Some(var) = self.var_params.as_mut() {
*var.as_mut().typ_mut() = std::mem::take(var.as_mut().typ_mut())._replace(target, to);
}
for d in self.default_params.iter_mut() {
*d.typ_mut() = std::mem::take(d.typ_mut())._replace(target, to);
}
self.return_t = Box::new(self.return_t._replace(target, to));
self
}
pub fn replace_params(mut self, target_and_to: Vec<(Str, Str)>) -> Self {
for (target, to) in target_and_to {
for nd in self.non_default_params.iter_mut() {
if let Some(name) = nd.name_mut() {
if name == target {
*name = to.clone();
}
}
}
if let Some(var) = self.var_params.as_mut() {
if let Some(name) = var.name_mut() {
if name == target {
*name = to.clone();
}
}
}
for d in self.default_params.iter_mut() {
if let Some(name) = d.name_mut() {
if name == target {
*name = to.clone();
}
}
}
if let Some(kw_var) = self.kw_var_params.as_mut() {
if let Some(name) = kw_var.name_mut() {
if name == target {
*name = to.clone();
}
}
}
*self.return_t = self.return_t.replace_param(&target, &to);
}
self
}
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
@ -918,7 +974,7 @@ impl ArgsOwnership {
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum CastTarget {
Param {
Arg {
nth: usize,
name: Str,
loc: Location,
@ -934,7 +990,7 @@ pub enum CastTarget {
impl fmt::Display for CastTarget {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Self::Param { nth, name, .. } => write!(f, "{name}#{nth}"),
Self::Arg { name, .. } => write!(f, "{name}"),
Self::Var { name, .. } => write!(f, "{name}"),
Self::Expr(expr) => write!(f, "{expr}"),
}
@ -944,7 +1000,7 @@ impl fmt::Display for CastTarget {
impl Locational for CastTarget {
fn loc(&self) -> Location {
match self {
Self::Param { loc, .. } => *loc,
Self::Arg { loc, .. } => *loc,
Self::Var { loc, .. } => *loc,
Self::Expr(expr) => expr.loc(),
}
@ -952,8 +1008,8 @@ impl Locational for CastTarget {
}
impl CastTarget {
pub const fn param(nth: usize, name: Str, loc: Location) -> Self {
Self::Param { nth, name, loc }
pub const fn arg(nth: usize, name: Str, loc: Location) -> Self {
Self::Arg { nth, name, loc }
}
pub fn expr(expr: Expr) -> Self {
@ -968,6 +1024,14 @@ pub struct GuardType {
pub to: Box<Type>,
}
impl LimitedDisplay for GuardType {
fn limited_fmt<W: std::fmt::Write>(&self, f: &mut W, limit: isize) -> fmt::Result {
write!(f, "{{{} in ", self.target)?;
self.to.limited_fmt(f, limit - 1)?;
write!(f, "}}")
}
}
impl fmt::Display for GuardType {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{{{} in {}}}", self.target, self.to)
@ -988,6 +1052,14 @@ impl GuardType {
to: Box::new(to),
}
}
pub fn replace_param(mut self, target: &Str, to: &Str) -> Self {
match &mut self.target {
CastTarget::Arg { name, .. } if name == target => *name = to.clone(),
_ => {}
}
self
}
}
#[derive(Debug, Clone, Hash, Default)]
@ -1383,9 +1455,7 @@ impl LimitedDisplay for Type {
ty.limited_fmt(f, limit - 1)?;
write!(f, ")")
}
Self::Guard(guard) if cfg!(feature = "debug") => {
write!(f, "Guard({guard})")
}
Self::Guard(guard) => guard.limited_fmt(f, limit),
Self::Bounded { sub, sup } => {
if sub.is_union_type() || sub.is_intersection_type() {
write!(f, "(")?;
@ -3614,20 +3684,7 @@ impl Type {
}
Self::NamedTuple(r)
}
Self::Subr(mut subr) => {
for nd in subr.non_default_params.iter_mut() {
*nd.typ_mut() = std::mem::take(nd.typ_mut())._replace(target, to);
}
if let Some(var) = subr.var_params.as_mut() {
*var.as_mut().typ_mut() =
std::mem::take(var.as_mut().typ_mut())._replace(target, to);
}
for d in subr.default_params.iter_mut() {
*d.typ_mut() = std::mem::take(d.typ_mut())._replace(target, to);
}
subr.return_t = Box::new(subr.return_t._replace(target, to));
Self::Subr(subr)
}
Self::Subr(subr) => Self::Subr(subr._replace(target, to)),
Self::Callable { param_ts, return_t } => {
let param_ts = param_ts
.into_iter()
@ -3678,6 +3735,19 @@ impl Type {
}
}
fn replace_param(self, target: &Str, to: &Str) -> Self {
match self {
Self::FreeVar(fv) if fv.is_linked() => fv.crack().clone().replace_param(target, to),
Self::Refinement(mut refine) => {
*refine.t = refine.t.replace_param(target, to);
Self::Refinement(refine)
}
Self::And(l, r) => l.replace_param(target, to) & r.replace_param(target, to),
Self::Guard(guard) => Self::Guard(guard.replace_param(target, to)),
_ => self,
}
}
/// TyParam::Value(ValueObj::Type(_)) => TyParam::Type
pub fn normalize(self) -> Self {
match self {

View file

@ -20,3 +20,8 @@ s = "{ \"key\": \"value\" }"
jdata = json.loads(s)
assert jdata in {Str: Str}
assert jdata["key"] == "value"
is_int(x: Obj) = x in Int
y as Obj = 1
assert is_int y
y: Int