WIP: submodule resolution bug

This commit is contained in:
Shunsuke Shibayama 2023-12-25 02:14:33 +09:00
parent 1b824d78e1
commit 26c758e67f
11 changed files with 169 additions and 37 deletions

View file

@ -498,11 +498,12 @@ macro_rules! log {
}
}};
(backtrace) => {{
(backtrace $($arg: tt)*) => {{
if cfg!(feature = "debug") {
use $crate::style::*;
$crate::debug_info!();
println!("\n{}", std::backtrace::Backtrace::capture());
println!($($arg)*);
println!("{}", std::backtrace::Backtrace::capture());
}
}};

View file

@ -571,7 +571,12 @@ impl<ASTBuilder: ASTBuildable, HIRBuilder: Buildable>
if let Err(ResolveError::CycleDetected { path, submod_input }) =
self.resolve(&mut ast, &import_cfg)
{
*expr = Expr::InlineModule(InlineModule::new(submod_input.clone(), ast, call.clone()));
*expr = Expr::InlineModule(InlineModule::new(
submod_input.clone(),
ast,
call.clone(),
import_path,
));
if path != from_path {
return Err(ResolveError::CycleDetected { path, submod_input });
} else {

View file

@ -2654,7 +2654,7 @@ impl Context {
if let Some(ctx) = self.get_nominal_type_ctx(sup) {
sup_ctxs.push(ctx);
} else if DEBUG_MODE {
todo!("no ctx for {sup}");
todo!("no ctx ({} / {}) for {sup}", self.name, self.kind);
}
}
Some(vec![ctx].into_iter().chain(sup_ctxs))

View file

@ -61,8 +61,6 @@ pub trait ContextProvider {
}
}
const BUILTINS: &Str = &Str::ever("<builtins>");
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ControlKind {
If,
@ -1140,10 +1138,8 @@ impl Context {
// NOTE: this need to be changed if we want to support nested classes/traits
if let Some(outer) = self.get_outer() {
outer.path()
} else if self.kind == ContextKind::Module {
self.name.replace(".__init__", "").into()
} else {
BUILTINS.clone()
self.name.replace(".__init__", "").into()
}
}
@ -1151,7 +1147,7 @@ impl Context {
/// This avoids infinite loops.
pub(crate) fn get_builtins(&self) -> Option<&Context> {
// builtins中で定義した型等はmod_cacheがNoneになっている
if self.kind != ContextKind::Module || &self.path()[..] != "<builtins>" {
if &self.path()[..] != "<builtins>" {
self.shared
.as_ref()
.map(|shared| {
@ -1175,7 +1171,13 @@ impl Context {
outer.get_module()
}
})
.or(Some(self))
.or_else(|| {
if self.kind == ContextKind::Module {
Some(self)
} else {
None
}
})
}
pub(crate) fn get_module_from_stack(&self, path: &NormalizedPathBuf) -> Option<&Context> {
@ -1249,10 +1251,10 @@ impl Context {
self.params.retain(|(_, v)| v.t != Failure);
}
/// Note that the popped context is detached and `outer == None`.
pub fn pop(&mut self) -> Context {
self.check_types();
if let Some(parent) = self.outer.as_mut() {
let parent = mem::take(parent);
if let Some(parent) = self.outer.take() {
let ctx = mem::take(self);
*self = *parent;
log!(info "{}: current namespace: {}", fn_name!(), self.name);

View file

@ -2513,6 +2513,30 @@ impl Context {
#[allow(clippy::single_match)]
match expr {
ast::Expr::Accessor(acc) => self.inc_ref_acc(acc, namespace, tmp_tv_cache),
ast::Expr::BinOp(bin) => {
self.inc_ref_expr(&bin.args[0], namespace, tmp_tv_cache)
|| self.inc_ref_expr(&bin.args[1], namespace, tmp_tv_cache)
}
ast::Expr::UnaryOp(unary) => self.inc_ref_expr(&unary.value(), namespace, tmp_tv_cache),
ast::Expr::Call(call) => {
let mut res = self.inc_ref_expr(&call.obj, namespace, tmp_tv_cache);
for arg in call.args.pos_args() {
if self.inc_ref_expr(&arg.expr, namespace, tmp_tv_cache) {
res = true;
}
}
if let Some(arg) = call.args.var_args() {
if self.inc_ref_expr(&arg.expr, namespace, tmp_tv_cache) {
res = true;
}
}
for arg in call.args.kw_args() {
if self.inc_ref_expr(&arg.expr, namespace, tmp_tv_cache) {
res = true;
}
}
res
}
ast::Expr::Record(ast::Record::Normal(rec)) => {
let mut res = false;
for val in rec.attrs.iter() {
@ -2522,10 +2546,6 @@ impl Context {
}
res
}
ast::Expr::BinOp(bin) => {
self.inc_ref_expr(&bin.args[0], namespace, tmp_tv_cache)
|| self.inc_ref_expr(&bin.args[1], namespace, tmp_tv_cache)
}
ast::Expr::Array(ast::Array::Normal(arr)) => {
let mut res = false;
for val in arr.elems.pos_args().iter() {
@ -2535,6 +2555,24 @@ impl Context {
}
res
}
ast::Expr::Tuple(ast::Tuple::Normal(tup)) => {
let mut res = false;
for val in tup.elems.pos_args().iter() {
if self.inc_ref_expr(&val.expr, namespace, tmp_tv_cache) {
res = true;
}
}
res
}
ast::Expr::Set(ast::Set::Normal(set)) => {
let mut res = false;
for val in set.elems.pos_args().iter() {
if self.inc_ref_expr(&val.expr, namespace, tmp_tv_cache) {
res = true;
}
}
res
}
ast::Expr::Set(ast::Set::Comprehension(comp)) => {
let mut res = false;
for (_, gen) in comp.generators.iter() {
@ -2549,6 +2587,32 @@ impl Context {
}
res
}
ast::Expr::Dict(ast::Dict::Normal(dict)) => {
let mut res = false;
for ast::KeyValue { key, value } in dict.kvs.iter() {
if self.inc_ref_expr(key, namespace, tmp_tv_cache) {
res = true;
}
if self.inc_ref_expr(value, namespace, tmp_tv_cache) {
res = true;
}
}
res
}
ast::Expr::Dict(ast::Dict::Comprehension(comp)) => {
let mut res = false;
for (_, gen) in comp.generators.iter() {
if self.inc_ref_expr(gen, namespace, tmp_tv_cache) {
res = true;
}
}
if let Some(guard) = &comp.guard {
if self.inc_ref_expr(guard, namespace, tmp_tv_cache) {
res = true;
}
}
res
}
other => {
log!(err "inc_ref_expr: {other}");
false

View file

@ -996,6 +996,26 @@ impl<A: ASTBuildable> GenericASTLowerer<A> {
{
Ok(hir::Expr::Call(self.lower_call(call, None)))
}
ast::Expr::Compound(compound) => {
let mut chunks = vec![];
for chunk in compound.into_iter() {
let chunk = self.declare_chunk(chunk, true)?;
chunks.push(chunk);
}
Ok(hir::Expr::Compound(hir::Block::new(chunks)))
}
ast::Expr::Dummy(dummy) => {
let mut dummy_ = vec![];
for elem in dummy.into_iter() {
let elem = self.declare_chunk(elem, true)?;
dummy_.push(elem);
}
Ok(hir::Expr::Dummy(hir::Dummy::new(dummy_)))
}
ast::Expr::InlineModule(inline) => {
let import = self.lower_inline_module(inline, None);
Ok(hir::Expr::Call(import))
}
other => Err(LowerErrors::from(LowerError::declare_error(
self.cfg().input.clone(),
line!() as usize,

View file

@ -1,6 +1,11 @@
.modules = pyimport "./modules"
.parameter = pyimport "./parameter"
{.CrossEntropyLoss; .Flatten; .Linear; .Module; .ReLU;} = .modules
{
.CrossEntropyLoss;
.Flatten;
.Linear;
.Module;
.ReLU;
} = .modules
{.Parameter;} = .parameter
{.Module;} = pyimport "./modules/module"

View file

@ -1,5 +1,36 @@
{.Parameter;} = pyimport "torch/nn/parameter"
# {.Tensor;} = pyimport "torch"
.Module: ClassType
.Module.
parameters: (self: Ref(.Module)) -> Iterator .Parameter
parameters: (self: Ref(.Module), recurse := Bool) -> Iterator .Parameter
named_parameters: (self: Ref(.Module), prefix := Str, recurse := Bool, remove_duplicate := Bool) -> Iterator((Str, .Parameter))
# buffers: (self: Ref(.Module), recurse := Bool) -> Iterator .Tensor
# named_buffers: (self: Ref(.Module), prefix := Str, recurse := Bool, remove_duplicate := Bool) -> Iterator((Str, .Tensor))
children: (self: Ref(.Module)) -> Iterator .Module
named_children: (self: Ref(.Module), prefix := Str) -> Iterator((Str, .Module))
modules: (self: Ref(.Module)) -> Iterator .Module
named_modules: (self: Ref(.Module), memo := {.Module; _}, prefix := Str, remove_duplicate := Bool) -> Iterator((Str, .Module))
train: |T <: .Module|(self: Ref(T), mode := Bool) -> T
eval: |T <: .Module|(self: Ref(T)) -> T
zero_grad!: (self: RefMut(.Module), set_to_none := Bool) => NoneType
compile: (self: Ref(.Module), *args: Obj, **kwargs: Obj) -> .Module
# register_buffer!: (self: RefMut(.Module), name: Str, tensor := Tensor, persistent := Bool) => NoneType
register_parameter!: (self: RefMut(.Module), name: Str, param := .Parameter) => NoneType
add_module!: (self: RefMut(.Module), name: Str, module := .Module) => NoneType
register_module!: (self: RefMut(.Module), name: Str, module := .Module) => NoneType
get_submodule: (self: Ref(.Module), name: Str) -> .Module
get_parameter: (self: Ref(.Module), name: Str) -> .Parameter
# get_buffer: (self: Ref(.Module), name: Str) -> .Tensor
get_extra_state: (self: Ref(.Module)) -> Obj
set_extra_state!: (self: RefMut(.Module), state: Obj) => NoneType
apply!: |T <: .Module|(self: T, fn: (module: RefMut(T)) => NoneType) => T
cuda!: |T <: .Module|(self: T, device := Int) => T
ipu!: |T <: .Module|(self: T, device := Int) => T
xpu!: |T <: .Module|(self: T, device := Int) => T
cpu!: |T <: .Module|(self: T) => T
float: |T <: .Module|(self: T) -> T
double: |T <: .Module|(self: T) -> T
half: |T <: .Module|(self: T) -> T
bfloat16: |T <: .Module|(self: T) -> T
to: |T <: .Module|(self: T, *args: Obj, **kwargs: Obj) -> T

View file

@ -3,7 +3,6 @@
//! ASTLowerer(ASTからHIRへの変換器)を実装
use std::marker::PhantomData;
use std::mem;
use std::path::Path;
use erg_common::config::{ErgConfig, ErgMode};
use erg_common::consts::{ELS, ERG_MODE, PYTHON_MODE};
@ -2956,20 +2955,13 @@ impl<A: ASTBuildable> GenericASTLowerer<A> {
Ok(hir::Dummy::new(hir_dummy))
}
fn lower_inline_module(&mut self, inline: InlineModule, expect: Option<&Type>) -> hir::Call {
pub(crate) fn lower_inline_module(
&mut self,
inline: InlineModule,
expect: Option<&Type>,
) -> hir::Call {
log!(info "entered {}", fn_name!());
let Some(ast::Expr::Literal(mod_name)) = inline.import.args.get_left_or_key("Path") else {
unreachable!();
};
let Ok(mod_name) = hir::Literal::try_from(mod_name.token.clone()) else {
unreachable!();
};
let ValueObj::Str(__name__) = &mod_name.value else {
unreachable!();
};
let Some(path) = inline.input.resolve_path(Path::new(&__name__[..])) else {
unreachable!();
};
let path = inline.module_path;
let parent = self.get_mod_ctx().context.get_module().unwrap().clone();
let mod_ctx = ModuleContext::new(parent, dict! {});
let mut builder = GenericHIRBuilder::<A>::new_with_ctx(mod_ctx);

View file

@ -6,6 +6,7 @@ use std::fmt::Write as _;
use erg_common::consts::ERG_MODE;
use erg_common::error::Location;
use erg_common::io::Input;
use erg_common::pathutil::NormalizedPathBuf;
use erg_common::set::Set as HashSet;
// use erg_common::dict::Dict as HashMap;
use erg_common::traits::{Locational, NestedDisplay, Stream};
@ -6167,6 +6168,7 @@ pub struct InlineModule {
pub input: Input,
pub ast: AST,
pub import: Call,
pub module_path: NormalizedPathBuf,
}
impl NestedDisplay for InlineModule {
@ -6188,7 +6190,12 @@ impl InlineModule {
}
impl InlineModule {
pub const fn new(input: Input, ast: AST, import: Call) -> Self {
Self { input, ast, import }
pub const fn new(input: Input, ast: AST, import: Call, module_path: NormalizedPathBuf) -> Self {
Self {
input,
ast,
import,
module_path,
}
}
}

View file

@ -408,7 +408,12 @@ impl Desugarer {
chunks.push(desugar(chunk));
}
let ast = AST::new(inline.ast.name, Module::new(chunks));
Expr::InlineModule(InlineModule::new(inline.input, ast, inline.import))
Expr::InlineModule(InlineModule::new(
inline.input,
ast,
inline.import,
inline.module_path,
))
}
Expr::Dummy(exprs) => {
let loc = exprs.loc;