feat!: change constructor syntax: C::__new__() -> C()

This commit is contained in:
Shunsuke Shibayama 2024-02-08 02:42:49 +09:00
parent 36fcc8cb79
commit fce88717b0
16 changed files with 90 additions and 56 deletions

View file

@ -116,9 +116,14 @@ fn escape_name(
fn escape_ident(ident: Identifier) -> Str { fn escape_ident(ident: Identifier) -> Str {
let vis = ident.vis(); let vis = ident.vis();
if &ident.inspect()[..] == "Self" { if &ident.inspect()[..] == "Self" {
let Ok(ty) = <&Type>::try_from(ident.vi.t.singleton_value().unwrap()) else { // reference the self type or the self type constructor
unreachable!() let ty = ident
}; .vi
.t
.singleton_value()
.and_then(|tp| <&Type>::try_from(tp).ok())
.or_else(|| ident.vi.t.return_t())
.unwrap();
escape_name( escape_name(
&ty.local_name(), &ty.local_name(),
&ident.vi.vis.modifier, &ident.vi.vis.modifier,
@ -925,9 +930,6 @@ impl PyCodeGenerator {
fn emit_load_method_instr(&mut self, ident: Identifier) { fn emit_load_method_instr(&mut self, ident: Identifier) {
log!(info "entered {} ({ident})", fn_name!()); log!(info "entered {} ({ident})", fn_name!());
if &ident.inspect()[..] == "__new__" {
log!("{:?}", ident.vi);
}
let escaped = escape_ident(ident); let escaped = escape_ident(ident);
let name = self let name = self
.local_search(&escaped, BoundAttr) .local_search(&escaped, BoundAttr)
@ -3450,11 +3452,12 @@ impl PyCodeGenerator {
self.emit_store_instr(Identifier::public("__module__"), Name); self.emit_store_instr(Identifier::public("__module__"), Name);
self.emit_load_const(name); self.emit_load_const(name);
self.emit_store_instr(Identifier::public("__qualname__"), Name); self.emit_store_instr(Identifier::public("__qualname__"), Name);
self.emit_init_method(&class.sig, class.__new__.clone()); let mut methods = ClassDef::take_all_methods(class.methods_list);
let __init__ = methods.remove_def("__init__");
self.emit_init_method(&class.sig, __init__, class.__new__.clone());
if class.need_to_gen_new { if class.need_to_gen_new {
self.emit_new_func(&class.sig, class.__new__); self.emit_new_func(&class.sig, class.__new__);
} }
let methods = ClassDef::take_all_methods(class.methods_list);
if !methods.is_empty() { if !methods.is_empty() {
self.emit_simple_block(methods); self.emit_simple_block(methods);
} }
@ -3494,7 +3497,7 @@ impl PyCodeGenerator {
unit.codeobj unit.codeobj
} }
fn emit_init_method(&mut self, sig: &Signature, __new__: Type) { fn emit_init_method(&mut self, sig: &Signature, __init__: Option<Def>, __new__: Type) {
log!(info "entered {}", fn_name!()); log!(info "entered {}", fn_name!());
let new_first_param = __new__.non_default_params().unwrap().first(); let new_first_param = __new__.non_default_params().unwrap().first();
let line = sig.ln_begin().unwrap_or(0); let line = sig.ln_begin().unwrap_or(0);
@ -3539,6 +3542,9 @@ impl PyCodeGenerator {
vec![], vec![],
); );
let mut attrs = vec![]; let mut attrs = vec![];
if let Some(__init__) = __init__ {
attrs.extend(__init__.body.block.clone());
}
match new_first_param.map(|pt| pt.typ()) { match new_first_param.map(|pt| pt.typ()) {
// namedtupleは仕様上::xなどの名前を使えない // namedtupleは仕様上::xなどの名前を使えない
// {x = Int; y = Int} // {x = Int; y = Int}
@ -3587,8 +3593,8 @@ impl PyCodeGenerator {
/// ```python /// ```python
/// class C: /// class C:
/// # __new__ => __call__ /// # __new__ => C
/// def new(x): return C.__call__(x) /// def new(x): return C(x)
/// ``` /// ```
fn emit_new_func(&mut self, sig: &Signature, __new__: Type) { fn emit_new_func(&mut self, sig: &Signature, __new__: Type) {
log!(info "entered {}", fn_name!()); log!(info "entered {}", fn_name!());
@ -3596,9 +3602,6 @@ impl PyCodeGenerator {
let line = sig.ln_begin().unwrap_or(0); let line = sig.ln_begin().unwrap_or(0);
let mut ident = Identifier::public_with_line(DOT, Str::ever("new"), line); let mut ident = Identifier::public_with_line(DOT, Str::ever("new"), line);
let class = Expr::Accessor(Accessor::Ident(class_ident.clone())); let class = Expr::Accessor(Accessor::Ident(class_ident.clone()));
let mut new_ident = Identifier::private_with_line(Str::ever("__new__"), line);
new_ident.vi.py_name = Some(Str::ever("__call__"));
let class_new = class.attr_expr(new_ident);
ident.vi.t = __new__; ident.vi.t = __new__;
if let Some(new_first_param) = ident.vi.t.non_default_params().unwrap().first() { if let Some(new_first_param) = ident.vi.t.non_default_params().unwrap().first() {
let param_name = new_first_param let param_name = new_first_param
@ -3627,7 +3630,7 @@ impl PyCodeGenerator {
let arg = PosArg::new(Expr::Accessor(Accessor::private_with_line( let arg = PosArg::new(Expr::Accessor(Accessor::private_with_line(
param_name, line, param_name, line,
))); )));
let call = class_new.call_expr(Args::single(arg)); let call = class.call_expr(Args::single(arg));
let block = Block::new(vec![call]); let block = Block::new(vec![call]);
let body = DefBody::new(EQUAL, block, DefId(0)); let body = DefBody::new(EQUAL, block, DefId(0));
self.emit_subr_def(Some(class_ident.inspect()), sig, body); self.emit_subr_def(Some(class_ident.inspect()), sig, body);
@ -3642,7 +3645,7 @@ impl PyCodeGenerator {
sig.t_spec_with_op().cloned(), sig.t_spec_with_op().cloned(),
vec![], vec![],
); );
let call = class_new.call_expr(Args::empty()); let call = class.call_expr(Args::empty());
let block = Block::new(vec![call]); let block = Block::new(vec![call]);
let body = DefBody::new(EQUAL, block, DefId(0)); let body = DefBody::new(EQUAL, block, DefId(0));
self.emit_subr_def(Some(class_ident.inspect()), sig, body); self.emit_subr_def(Some(class_ident.inspect()), sig, body);

View file

@ -1684,10 +1684,11 @@ impl Context {
instance: &Type, instance: &Type,
pos_args: &[hir::PosArg], pos_args: &[hir::PosArg],
kw_args: &[hir::KwArg], kw_args: &[hir::KwArg],
namespace: &Context,
) -> TyCheckResult<SubstituteResult> { ) -> TyCheckResult<SubstituteResult> {
match instance { match instance {
Type::FreeVar(fv) if fv.is_linked() => { Type::FreeVar(fv) if fv.is_linked() => {
self.substitute_call(obj, attr_name, &fv.crack(), pos_args, kw_args) self.substitute_call(obj, attr_name, &fv.crack(), pos_args, kw_args, namespace)
} }
Type::FreeVar(fv) => { Type::FreeVar(fv) => {
if let Some(sub) = fv.get_sub() { if let Some(sub) = fv.get_sub() {
@ -1699,11 +1700,14 @@ impl Context {
instance.destructive_coerce(); instance.destructive_coerce();
if instance.is_quantified_subr() { if instance.is_quantified_subr() {
let instance = self.instantiate(instance.clone(), obj)?; let instance = self.instantiate(instance.clone(), obj)?;
self.substitute_call(obj, attr_name, &instance, pos_args, kw_args)?; self.substitute_call(
obj, attr_name, &instance, pos_args, kw_args, namespace,
)?;
return Ok(SubstituteResult::Coerced(instance)); return Ok(SubstituteResult::Coerced(instance));
} else if get_hash(instance) != hash { } else if get_hash(instance) != hash {
return self return self.substitute_call(
.substitute_call(obj, attr_name, instance, pos_args, kw_args); obj, attr_name, instance, pos_args, kw_args, namespace,
);
} }
} }
} }
@ -1734,7 +1738,7 @@ impl Context {
} }
} }
Type::Refinement(refine) => { Type::Refinement(refine) => {
self.substitute_call(obj, attr_name, &refine.t, pos_args, kw_args) self.substitute_call(obj, attr_name, &refine.t, pos_args, kw_args, namespace)
} }
// instance must be instantiated // instance must be instantiated
Type::Quantified(_) => unreachable_error!(TyCheckErrors, TyCheckError, self), Type::Quantified(_) => unreachable_error!(TyCheckErrors, TyCheckError, self),
@ -1932,7 +1936,9 @@ impl Context {
} }
} }
Type::Failure => Ok(SubstituteResult::Ok), Type::Failure => Ok(SubstituteResult::Ok),
_ => self.substitute_dunder_call(obj, attr_name, instance, pos_args, kw_args), _ => {
self.substitute_dunder_call(obj, attr_name, instance, pos_args, kw_args, namespace)
}
} }
} }
@ -1943,6 +1949,7 @@ impl Context {
instance: &Type, instance: &Type,
pos_args: &[hir::PosArg], pos_args: &[hir::PosArg],
kw_args: &[hir::KwArg], kw_args: &[hir::KwArg],
namespace: &Context,
) -> TyCheckResult<SubstituteResult> { ) -> TyCheckResult<SubstituteResult> {
let ctxs = self let ctxs = self
.get_singular_ctxs_by_hir_expr(obj, self) .get_singular_ctxs_by_hir_expr(obj, self)
@ -1979,8 +1986,16 @@ impl Context {
if let Some(call_vi) = if let Some(call_vi) =
typ_ctx.get_current_scope_var(&VarName::from_static("__call__")) typ_ctx.get_current_scope_var(&VarName::from_static("__call__"))
{ {
if call_vi.vis.is_private() {
self.validate_visibility(
&Identifier::private_with_loc("__call__".into(), obj.loc()),
call_vi,
&self.cfg.input,
namespace,
)?;
}
let instance = self.instantiate_def_type(&call_vi.t)?; let instance = self.instantiate_def_type(&call_vi.t)?;
self.substitute_call(obj, attr_name, &instance, pos_args, kw_args)?; self.substitute_call(obj, attr_name, &instance, pos_args, kw_args, namespace)?;
return Ok(SubstituteResult::__Call__(instance)); return Ok(SubstituteResult::__Call__(instance));
} }
// instance method __call__ // instance method __call__
@ -1991,7 +2006,7 @@ impl Context {
}) })
{ {
let instance = self.instantiate_def_type(&call_vi.t)?; let instance = self.instantiate_def_type(&call_vi.t)?;
self.substitute_call(obj, attr_name, &instance, pos_args, kw_args)?; self.substitute_call(obj, attr_name, &instance, pos_args, kw_args, namespace)?;
return Ok(SubstituteResult::__Call__(instance)); return Ok(SubstituteResult::__Call__(instance));
} }
} }
@ -2356,7 +2371,7 @@ impl Context {
fmt_slice(kw_args) fmt_slice(kw_args)
); );
let instance = match self let instance = match self
.substitute_call(obj, attr_name, &instance, pos_args, kw_args) .substitute_call(obj, attr_name, &instance, pos_args, kw_args, namespace)
.map_err(|errs| { .map_err(|errs| {
( (
Some(VarInfo { Some(VarInfo {

View file

@ -1331,7 +1331,7 @@ impl Context {
} }
} }
/// e.g. ::__new__ /// e.g. `::__call__`
fn register_fixed_auto_impl( fn register_fixed_auto_impl(
&mut self, &mut self,
name: &'static str, name: &'static str,
@ -1678,18 +1678,18 @@ impl Context {
func0(gen.typ().clone()) func0(gen.typ().clone())
}; };
methods.register_fixed_auto_impl( methods.register_fixed_auto_impl(
"__new__", "__call__",
new_t.clone(), new_t.clone(),
Immutable, Immutable,
Visibility::BUILTIN_PRIVATE, Visibility::private(ctx.name.clone()),
Some("__call__".into()), None,
)?; )?;
// 必要なら、ユーザーが独自に上書きする // 必要なら、ユーザーが独自に上書きする
methods.register_auto_impl( methods.register_auto_impl(
"new", "new",
new_t, new_t,
Immutable, Immutable,
Visibility::BUILTIN_PUBLIC, Visibility::public(ctx.name.clone()),
None, None,
)?; )?;
ctx.methods_list.push(MethodContext::new( ctx.methods_list.push(MethodContext::new(
@ -1903,10 +1903,10 @@ impl Context {
}; };
if ERG_MODE { if ERG_MODE {
methods.register_fixed_auto_impl( methods.register_fixed_auto_impl(
"__new__", "__call__",
new_t.clone(), new_t.clone(),
Immutable, Immutable,
Visibility::BUILTIN_PRIVATE, Visibility::private(ctx.name.clone()),
Some("__call__".into()), Some("__call__".into()),
)?; )?;
// users can override this if necessary // users can override this if necessary
@ -1914,7 +1914,7 @@ impl Context {
"new", "new",
new_t, new_t,
Immutable, Immutable,
Visibility::BUILTIN_PUBLIC, Visibility::public(ctx.name.clone()),
None, None,
)?; )?;
} else { } else {
@ -1922,7 +1922,7 @@ impl Context {
"__call__", "__call__",
new_t, new_t,
Immutable, Immutable,
Visibility::BUILTIN_PUBLIC, Visibility::public(ctx.name.clone()),
Some("__call__".into()), Some("__call__".into()),
)?; )?;
} }

View file

@ -1596,6 +1596,21 @@ impl Locational for Block {
} }
} }
impl Block {
pub fn remove_def(&mut self, name: &str) -> Option<Def> {
let mut i = 0;
while i < self.0.len() {
if let Expr::Def(def) = &self.0[i] {
if def.sig.ident().inspect() == name {
return Def::try_from(self.0.remove(i)).ok();
}
}
i += 1;
}
None
}
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)] #[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Dummy(Vec<Expr>); pub struct Dummy(Vec<Expr>);

View file

@ -9,6 +9,7 @@
input: Tensor!(T, _), input: Tensor!(T, _),
) -> Tensor!(T, _) ) -> Tensor!(T, _)
.Module. .Module.
__init__: (self: RefMut(.Module)) => NoneType
parameters: (self: Ref(.Module), recurse := Bool) -> Iterator Parameter parameters: (self: Ref(.Module), recurse := Bool) -> Iterator Parameter
named_parameters: (self: Ref(.Module), prefix := Str, recurse := Bool, remove_duplicate := Bool) -> Iterator((Str, Parameter)) named_parameters: (self: Ref(.Module), prefix := Str, recurse := Bool, remove_duplicate := Bool) -> Iterator((Str, Parameter))
# buffers: (self: Ref(.Module), recurse := Bool) -> Iterator .Tensor! # buffers: (self: Ref(.Module), recurse := Bool) -> Iterator .Tensor!

View file

@ -11,7 +11,7 @@
[name, num] -> [name, num] ->
num_ = nat(num) num_ = nat(num)
assert num_ in Nat assert num_ in Nat
.Identifier::__new__ { .name; .num = num_ } .Identifier { .name; .num = num_ }
_ -> panic "invalid identifier string: \{s}" _ -> panic "invalid identifier string: \{s}"
@Override @Override
__repr__ ref self = "Identifier(\{self.__str__()})" __repr__ ref self = "Identifier(\{self.__str__()})"
@ -24,7 +24,7 @@
do: "\{self.major}.\{self.minor}.\{self.patch}" do: "\{self.major}.\{self.minor}.\{self.patch}"
.SemVer. .SemVer.
new major, minor, patch, pre := None = new major, minor, patch, pre := None =
.SemVer::__new__ { .major; .minor; .patch; .pre } .SemVer { .major; .minor; .patch; .pre }
from_str s: Str = from_str s: Str =
match s.split("."): match s.split("."):
[major, minor, patch] -> [major, minor, patch] ->

View file

@ -20,7 +20,7 @@ Point3D.
# Overloading is prohibited by default. Remove this decorator and check for errors. # Overloading is prohibited by default. Remove this decorator and check for errors.
@Override @Override
new x, y, z = new x, y, z =
Point3D::__new__ {x; y; z} Point3D {x; y; z}
@Override @Override
norm self = self::x**2 + self::y**2 + self::z**2 norm self = self::x**2 + self::y**2 + self::z**2

View file

@ -1,6 +1,6 @@
Point = Class {x = Int; y = Int} Point = Class {x = Int; y = Int}
Point. Point.
new x, y = Point::__new__ {x; y} new x, y = Point {x; y}
norm self = self::x**2 + self::y**2 norm self = self::x**2 + self::y**2
Point|<: Add(Point)|. Point|<: Add(Point)|.
Output = Point Output = Point

View file

@ -1,7 +1,7 @@
IntList = Class NoneType or { .node = Int; .next = IntList } IntList = Class NoneType or { .node = Int; .next = IntList }
IntList. IntList.
null = IntList::__new__ None null = IntList None
insert self, node = IntList::__new__ { .node; .next = self } insert self, node = IntList { .node; .next = self }
fst self = fst self =
match self::base: match self::base:
{ node; next = _ } => node { node; next = _ } => node

View file

@ -2,10 +2,10 @@ name n: Structural { .name = Str } = n.name
C = Class { .name = Str } C = Class { .name = Str }
C. C.
new name = C::__new__ { .name = name } new name = C { .name = name }
D = Class { .name = Str; .id = Nat } D = Class { .name = Str; .id = Nat }
D. D.
new name, id = D::__new__ { .name = name; .id = id } new name, id = D { .name = name; .id = id }
c = C.new "foo" c = C.new "foo"
d = D.new "bar", 1 d = D.new "bar", 1
@ -17,8 +17,8 @@ inner|T| x: Structural { .inner = T } = x.inner
E = Class { .inner = Int } E = Class { .inner = Int }
E. E.
new inner = E::__new__ { .inner = inner } new inner = E { .inner = inner }
__add__ self, other: E = E::__new__ { .inner = self.inner + other.inner } __add__ self, other: E = E { .inner = self.inner + other.inner }
e = E.new 1 e = E.new 1

View file

@ -4,7 +4,7 @@ C = Class()
D = Class { .name = Int; .id = Nat } D = Class { .name = Int; .id = Nat }
D. D.
__add__ self, _ = 1 __add__ self, _ = 1
new name, id = D::__new__ { .name = name; .id = id } new name, id = D { .name = name; .id = id }
c = C.new() c = C.new()
d = D.new 1, 2 d = D.new 1, 2
@ -17,8 +17,8 @@ inner|T: Type| x: Structural { .inner = T } = x.inner
E = Class { .inner = Int } E = Class { .inner = Int }
E. E.
new inner = E::__new__ { .inner = inner } new inner = E { .inner = inner }
__add__ self, other: E = E::__new__ { .inner = self.inner + other.inner } __add__ self, other: E = E { .inner = self.inner + other.inner }
e = E.new 1 e = E.new 1

View file

@ -11,7 +11,7 @@ C.
foo self, x = self.x.foo(x) foo self, x = self.x.foo(x)
D = Class { .y = Int } D = Class { .y = Int }
D. D.
new y = Self::__new__ { .y; } new y = Self { .y; }
@staticmethod @staticmethod
foo x = x + 1 foo x = x + 1
one = Self.new 1 one = Self.new 1

View file

@ -25,7 +25,7 @@ for! [1], _ =>
Versions! = Class Dict! { Str: Array!(SemVer) } Versions! = Class Dict! { Str: Array!(SemVer) }
Versions!. Versions!.
new() = Versions!::__new__ !{:} new() = Versions! !{:}
insert!(ref! self, name: Str, version: SemVer) = insert!(ref! self, name: Str, version: SemVer) =
if! self::base.get(name) == None: if! self::base.get(name) == None:
do!: do!:

View file

@ -3,12 +3,12 @@ C = Class { .x = Int }
C. C.
const = 0 const = 0
C. C.
new x: Int = C::__new__ { .x = x } new x: Int = C { .x = x }
D = Inherit C D = Inherit C
D. D.
@Override @Override
new x: Int = D::__new__ { .x = x } new x: Int = D { .x = x }
d: D = D.new(1) d: D = D.new(1)
print! d print! d

View file

@ -28,7 +28,7 @@ Net = Inherit nn.Module, Additional := {
Net. Net.
@Override @Override
new() = Net::__new__ { new() = Net {
conv1 = nn.Conv2d(1, 16, kernel_size:=3, stride:=1, padding:=1); conv1 = nn.Conv2d(1, 16, kernel_size:=3, stride:=1, padding:=1);
conv2 = nn.Conv2d(16, 32, kernel_size:=3, stride:=1, padding:=1); conv2 = nn.Conv2d(16, 32, kernel_size:=3, stride:=1, padding:=1);
pool = nn.MaxPool2d(kernel_size:=2, stride:=2); pool = nn.MaxPool2d(kernel_size:=2, stride:=2);

View file

@ -4,16 +4,16 @@ C::
C. C.
method self = method self =
_ = self _ = self
x = C::__new__() x = C()
y: C::X = Self::__new__() y: C::X = Self()
log x, y log x, y
.C2 = Class { .x = Int } .C2 = Class { .x = Int }
.C2. .C2.
method self = method self =
_ = self _ = self
x = .C2::__new__ { .x = 1 } x = .C2 { .x = 1 }
y = Self::__new__ { .x = 1 } y = Self { .x = 1 }
log x, y log x, y
x = C.new() x = C.new()