fix: cyclic reference modules bugs

This commit is contained in:
Shunsuke Shibayama 2023-12-28 14:59:14 +09:00
parent 26c758e67f
commit a2d1809cee
21 changed files with 266 additions and 162 deletions

View file

@ -46,18 +46,14 @@ impl<Checker: BuildRunnable, Parser: Parsable> Server<Checker, Parser> {
{
if let Some(def) = self.get_min::<Def>(uri, pos) {
if def.def_kind().is_import() {
if vi.t.is_module() {
if let Some(path) = self
.get_local_ctx(uri, pos)
.first()
.and_then(|ctx| ctx.get_path_with_mod_t(&vi.t))
{
let mod_uri = Url::from_file_path(path).unwrap();
return Ok(Some(lsp_types::Location::new(
mod_uri,
lsp_types::Range::default(),
)));
}
if let Some(mod_uri) =
vi.t.module_path()
.and_then(|path| Url::from_file_path(path).ok())
{
return Ok(Some(lsp_types::Location::new(
mod_uri,
lsp_types::Range::default(),
)));
} else {
// line of module member definitions may no longer match after the desugaring process
let mod_t = def.body.ref_t();

View file

@ -4,7 +4,7 @@ use std::fmt;
use std::ops::Deref;
use std::path::{Component, Path, PathBuf};
use crate::normalize_path;
use crate::{normalize_path, Str};
/// Guaranteed equivalence path.
///
@ -233,3 +233,24 @@ pub fn squash(path: PathBuf) -> PathBuf {
pub fn remove_verbatim(path: &Path) -> String {
path.to_string_lossy().replace("\\\\?\\", "")
}
/// e.g. http.d/client.d.er -> http.client
/// math.d.er -> math
pub fn mod_name(path: &Path) -> Str {
let mut name = path
.file_name()
.unwrap()
.to_str()
.unwrap()
.trim_end_matches(".d.er")
.to_string();
for parent in path.components().rev().skip(1) {
let parent = parent.as_os_str().to_str().unwrap();
if parent.ends_with(".d") {
name = parent.trim_end_matches(".d").to_string() + "." + &name;
} else {
break;
}
}
Str::from(name)
}

View file

@ -169,6 +169,16 @@ impl<ASTBuilder: ASTBuildable> GenericHIRBuilder<ASTBuilder> {
}
}
pub fn new_submodule(mut mod_ctx: ModuleContext, name: &str) -> Self {
mod_ctx
.context
.grow(name, ContextKind::Module, VisibilityModifier::Private, None);
Self {
ownership_checker: OwnershipChecker::new(mod_ctx.get_top_cfg()),
lowerer: GenericASTLowerer::new_with_ctx(mod_ctx),
}
}
pub fn check(&mut self, ast: AST, mode: &str) -> Result<CompleteArtifact, IncompleteArtifact> {
let mut artifact = self.lowerer.lower(ast, mode)?;
let effect_checker = SideEffectChecker::new(self.cfg().clone());

View file

@ -21,7 +21,7 @@ use erg_common::error::MultiErrorDisplay;
use erg_common::io::Input;
#[allow(unused)]
use erg_common::log;
use erg_common::pathutil::NormalizedPathBuf;
use erg_common::pathutil::{mod_name, NormalizedPathBuf};
use erg_common::spawn::spawn_new_thread;
use erg_common::str::Str;
use erg_common::traits::{ExitStatus, New, Runnable, Stream};
@ -572,10 +572,9 @@ impl<ASTBuilder: ASTBuildable, HIRBuilder: Buildable>
self.resolve(&mut ast, &import_cfg)
{
*expr = Expr::InlineModule(InlineModule::new(
submod_input.clone(),
Input::file(import_path.to_path_buf()),
ast,
call.clone(),
import_path,
));
if path != from_path {
return Err(ResolveError::CycleDetected { path, submod_input });
@ -596,7 +595,7 @@ impl<ASTBuilder: ASTBuildable, HIRBuilder: Buildable>
let mut graph = self.shared.graph.clone_inner();
let mut ancestors = graph.ancestors(&path).into_vec();
while let Some(ancestor) = ancestors.pop() {
if self.cyclic.contains(&ancestor) || graph.ancestors(&ancestor).is_empty() {
if graph.ancestors(&ancestor).is_empty() {
graph.remove(&ancestor);
if let Some((__name__, ancestor_ast)) = self.asts.remove(&ancestor) {
self.start_analysis_process(ancestor_ast, __name__, ancestor);
@ -676,27 +675,6 @@ impl<ASTBuilder: ASTBuildable, HIRBuilder: Buildable>
self.shared.promises.insert(path, handle);
}
/// e.g. http.d/client.d.er -> http.client
/// math.d.er -> math
fn mod_name(&self, path: &Path) -> Str {
let mut name = path
.file_name()
.unwrap()
.to_str()
.unwrap()
.trim_end_matches(".d.er")
.to_string();
for parent in path.components().rev().skip(1) {
let parent = parent.as_os_str().to_str().unwrap();
if parent.ends_with(".d") {
name = parent.trim_end_matches(".d").to_string() + "." + &name;
} else {
break;
}
}
Str::from(name)
}
/// FIXME: bug with inter-process sharing of type variables (pyimport "math")
fn build_decl_mod(&self, ast: AST, path: NormalizedPathBuf) {
let py_mod_cache = &self.shared.py_mod_cache;
@ -709,8 +687,7 @@ impl<ASTBuilder: ASTBuildable, HIRBuilder: Buildable>
} else {
None
};
let mut builder =
HIRBuilder::inherit_with_name(cfg, self.mod_name(&path), self.shared.clone());
let mut builder = HIRBuilder::inherit_with_name(cfg, mod_name(&path), self.shared.clone());
match builder.build_from_ast(ast, "declare") {
Ok(artifact) => {
let ctx = builder.pop_context().unwrap();

View file

@ -809,15 +809,9 @@ impl Context {
let Some(ctx) = self.get_nominal_type_ctx(other) else {
return Dict::new();
};
let mod_fields = if other.is_module() {
if let Ok(ValueObj::Str(mod_name)) =
ValueObj::try_from(other.typarams()[0].clone())
{
self.get_mod(&mod_name)
.map_or(Dict::new(), |ctx| ctx.local_dir())
} else {
Dict::new()
}
let mod_fields = if let Some(path) = other.module_path() {
self.get_mod_with_path(&path)
.map_or(Dict::new(), |ctx| ctx.local_dir())
} else {
Dict::new()
};
@ -847,9 +841,13 @@ impl Context {
erg_common::fmt_vec(lparams),
erg_common::fmt_vec(rparams)
);
let ctx = self
.get_nominal_type_ctx(typ)
.unwrap_or_else(|| panic!("{typ} is not found"));
let Some(ctx) = self.get_nominal_type_ctx(typ) else {
if DEBUG_MODE {
panic!("{typ} is not found");
} else {
return false;
}
};
let variances = ctx.type_params_variance();
debug_assert_eq!(
lparams.len(),

View file

@ -1415,55 +1415,18 @@ impl Context {
}
}
fn eval_succ_func(&self, val: ValueObj) -> EvalResult<ValueObj> {
match val {
ValueObj::Bool(b) => Ok(ValueObj::Nat(b as u64 + 1)),
ValueObj::Nat(n) => Ok(ValueObj::Nat(n + 1)),
ValueObj::Int(n) => Ok(ValueObj::Int(n + 1)),
// TODO:
ValueObj::Float(n) => Ok(ValueObj::Float(n + f64::EPSILON)),
ValueObj::Inf | ValueObj::NegInf => Ok(val),
_ => Err(EvalErrors::from(EvalError::unreachable(
self.cfg.input.clone(),
fn_name!(),
line!(),
))),
}
}
fn eval_pred_func(&self, val: ValueObj) -> EvalResult<ValueObj> {
match val {
ValueObj::Bool(_) => Ok(ValueObj::Nat(0)),
ValueObj::Nat(n) => Ok(ValueObj::Nat(n.saturating_sub(1))),
ValueObj::Int(n) => Ok(ValueObj::Int(n - 1)),
// TODO:
ValueObj::Float(n) => Ok(ValueObj::Float(n - f64::EPSILON)),
ValueObj::Inf | ValueObj::NegInf => Ok(val),
_ => Err(EvalErrors::from(EvalError::unreachable(
self.cfg.input.clone(),
fn_name!(),
line!(),
))),
}
}
pub(crate) fn eval_app(&self, name: Str, args: Vec<TyParam>) -> EvalResult<TyParam> {
if let Ok(mut value_args) = args
if let Ok(value_args) = args
.iter()
.map(|tp| self.convert_tp_into_value(tp.clone()))
.collect::<Result<Vec<_>, _>>()
{
match &name[..] {
"succ" => self
.eval_succ_func(value_args.remove(0))
.map(TyParam::Value),
"pred" => self
.eval_pred_func(value_args.remove(0))
.map(TyParam::Value),
_ => {
log!(err "eval_app({name}({}))", fmt_vec(&args));
Ok(TyParam::app(name, args))
}
if let Some(ValueObj::Subr(subr)) = self.rec_get_const_obj(&name) {
let args = ValueArgs::pos_only(value_args);
self.call(subr.clone(), args, ().loc())
} else {
log!(err "eval_app({name}({}))", fmt_vec(&args));
Ok(TyParam::app(name, args))
}
} else {
log!(err "eval_app({name}({}))", fmt_vec(&args));

View file

@ -1,5 +1,6 @@
use std::fmt::Display;
use std::mem;
use std::path::Path;
use erg_common::dict::Dict;
#[allow(unused_imports)]
@ -677,3 +678,86 @@ pub(crate) fn as_record(mut args: ValueArgs, ctx: &Context) -> EvalValueResult<T
}
Ok(ValueObj::builtin_type(Type::Record(dict)).into())
}
pub(crate) fn resolve_path_func(mut args: ValueArgs, ctx: &Context) -> EvalValueResult<TyParam> {
let path = args
.remove_left_or_key("Path")
.ok_or_else(|| not_passed("Path"))?;
let path = match &path {
ValueObj::Str(s) => Path::new(&s[..]),
other => {
return Err(type_mismatch("Str", other, "Path"));
}
};
let Some(path) = ctx.cfg.input.resolve_path(path) else {
return Err(ErrorCore::new(
vec![SubMessage::only_loc(Location::Unknown)],
format!("Path {} is not found", path.display()),
line!() as usize,
ErrorKind::IoError,
Location::Unknown,
)
.into());
};
Ok(ValueObj::Str(path.to_string_lossy().into()).into())
}
pub(crate) fn resolve_decl_path_func(
mut args: ValueArgs,
ctx: &Context,
) -> EvalValueResult<TyParam> {
let path = args
.remove_left_or_key("Path")
.ok_or_else(|| not_passed("Path"))?;
let path = match &path {
ValueObj::Str(s) => Path::new(&s[..]),
other => {
return Err(type_mismatch("Str", other, "Path"));
}
};
let Some(path) = ctx.cfg.input.resolve_decl_path(path) else {
return Err(ErrorCore::new(
vec![SubMessage::only_loc(Location::Unknown)],
format!("Path {} is not found", path.display()),
line!() as usize,
ErrorKind::IoError,
Location::Unknown,
)
.into());
};
Ok(ValueObj::Str(path.to_string_lossy().into()).into())
}
pub(crate) fn succ_func(mut args: ValueArgs, _ctx: &Context) -> EvalValueResult<TyParam> {
let val = args
.remove_left_or_key("Value")
.ok_or_else(|| not_passed("Value"))?;
let val = match &val {
ValueObj::Bool(b) => ValueObj::Nat(*b as u64 + 1),
ValueObj::Nat(n) => ValueObj::Nat(n + 1),
ValueObj::Int(n) => ValueObj::Int(n + 1),
ValueObj::Float(n) => ValueObj::Float(n + f64::EPSILON),
v @ (ValueObj::Inf | ValueObj::NegInf) => v.clone(),
_ => {
return Err(type_mismatch("Number", val, "Value"));
}
};
Ok(val.into())
}
pub(crate) fn pred_func(mut args: ValueArgs, _ctx: &Context) -> EvalValueResult<TyParam> {
let val = args
.remove_left_or_key("Value")
.ok_or_else(|| not_passed("Value"))?;
let val = match &val {
ValueObj::Bool(b) => ValueObj::Nat((*b as u64).saturating_sub(1)),
ValueObj::Nat(n) => ValueObj::Nat(n.saturating_sub(1)),
ValueObj::Int(n) => ValueObj::Int(n - 1),
ValueObj::Float(n) => ValueObj::Float(n - f64::EPSILON),
v @ (ValueObj::Inf | ValueObj::NegInf) => v.clone(),
_ => {
return Err(type_mismatch("Number", val, "Value"));
}
};
Ok(val.into())
}

View file

@ -136,7 +136,7 @@ impl Context {
let t_import = nd_func(
vec![anon(tp_enum(Str, set! {Path.clone()}))],
None,
module(Path.clone()),
module(TyParam::app(FUNC_RESOLVE_PATH.into(), vec![Path.clone()])),
)
.quantify();
let t_isinstance = nd_func(
@ -231,7 +231,7 @@ impl Context {
let t_pyimport = nd_func(
vec![anon(tp_enum(Str, set! {Path.clone()}))],
None,
py_module(Path),
py_module(TyParam::app(FUNC_RESOLVE_DECL_PATH.into(), vec![Path])),
)
.quantify();
let t_pycompile = nd_func(
@ -664,7 +664,33 @@ impl Context {
TraitType,
);
let patch = ConstSubr::Builtin(BuiltinConstSubr::new(PATCH, patch_func, patch_t, None));
self.register_builtin_const(PATCH, vis, ValueObj::Subr(patch));
self.register_builtin_const(PATCH, vis.clone(), ValueObj::Subr(patch));
let t_resolve_path = nd_func(vec![kw(KW_PATH, Str)], None, mono(GENERIC_MODULE));
let resolve_path = ConstSubr::Builtin(BuiltinConstSubr::new(
FUNC_RESOLVE_PATH,
resolve_path_func,
t_resolve_path,
None,
));
self.register_builtin_const(FUNC_RESOLVE_PATH, vis.clone(), ValueObj::Subr(resolve_path));
let t_resolve_decl_path = nd_func(vec![kw(KW_PATH, Str)], None, mono(GENERIC_MODULE));
let resolve_decl_path = ConstSubr::Builtin(BuiltinConstSubr::new(
FUNC_RESOLVE_DECL_PATH,
resolve_decl_path_func,
t_resolve_decl_path,
None,
));
self.register_builtin_const(
FUNC_RESOLVE_DECL_PATH,
vis.clone(),
ValueObj::Subr(resolve_decl_path),
);
let t_succ = nd_func(vec![kw(KW_N, Nat)], None, Nat);
let succ = ConstSubr::Builtin(BuiltinConstSubr::new(FUNC_SUCC, succ_func, t_succ, None));
self.register_builtin_const(FUNC_SUCC, vis.clone(), ValueObj::Subr(succ));
let t_pred = nd_func(vec![kw(KW_N, Nat)], None, Nat);
let pred = ConstSubr::Builtin(BuiltinConstSubr::new(FUNC_PRED, pred_func, t_pred, None));
self.register_builtin_const(FUNC_PRED, vis.clone(), ValueObj::Subr(pred));
}
pub(super) fn init_builtin_py_specific_funcs(&mut self) {

View file

@ -411,6 +411,8 @@ const FUNC_GETATTR: &str = "getattr";
const FUNC_SETATTR: &str = "setattr";
const FUNC_DELATTR: &str = "delattr";
const FUNC_NEARLY_EQ: &str = "nearly_eq";
const FUNC_RESOLVE_PATH: &str = "ResolvePath";
const FUNC_RESOLVE_DECL_PATH: &str = "ResolveDeclPath";
const OP_EQ: &str = "__eq__";
const OP_HASH: &str = "__hash__";
@ -580,6 +582,7 @@ const KW_CHARS: &str = "chars";
const KW_OTHER: &str = "other";
const KW_CONFLICT_RESOLVER: &str = "conflict_resolver";
const KW_EPSILON: &str = "epsilon";
const KW_PATH: &str = "Path";
pub fn builtins_path() -> PathBuf {
erg_pystd_path().join("builtins.d.er")

View file

@ -2960,25 +2960,16 @@ impl Context {
}
pub fn get_mod_with_t(&self, mod_t: &Type) -> Option<&Context> {
self.get_mod_with_path(&self.get_path_with_mod_t(mod_t)?)
}
pub fn get_path_with_mod_t(&self, mod_t: &Type) -> Option<PathBuf> {
let tps = mod_t.typarams();
let Some(TyParam::Value(ValueObj::Str(path))) = tps.get(0) else {
return None;
};
if mod_t.is_erg_module() {
self.cfg.input.resolve_path(Path::new(&path[..]))
} else if mod_t.is_py_module() {
self.cfg.input.resolve_decl_path(Path::new(&path[..]))
} else {
None
}
self.get_mod_with_path(&mod_t.module_path()?)
}
// rec_get_const_localとは違い、位置情報を持たないしエラーとならない
pub(crate) fn rec_get_const_obj(&self, name: &str) -> Option<&ValueObj> {
if name.split('.').count() > 1 {
let typ = Type::Mono(Str::rc(name));
let namespace = self.get_namespace(&typ.namespace())?;
return namespace.rec_get_const_obj(&typ.local_name());
}
#[cfg(feature = "py_compat")]
let name = self.erg_to_py_names.get(name).map_or(name, |s| &s[..]);
if name == "Self" {
@ -3032,9 +3023,8 @@ impl Context {
pub(crate) fn get_namespace_path(&self, namespace: &Str) -> Option<PathBuf> {
// get the true name
let namespace = if let Some((_, vi)) = self.get_var_info(namespace) {
// m: PyModule("math") -> math
if vi.t.is_module() {
vi.t.typarams()[0].to_string().replace('"', "").into()
if let Some(path) = vi.t.module_path() {
return Some(path);
} else {
namespace.clone()
}

View file

@ -1135,8 +1135,10 @@ impl Context {
}
pub(crate) fn path(&self) -> Str {
// NOTE: this need to be changed if we want to support nested classes/traits
if let Some(outer) = self.get_outer() {
// NOTE: maybe this need to be changed if we want to support nested classes/traits
if self.kind == ContextKind::Module {
self.name.replace(".__init__", "").into()
} else if let Some(outer) = self.get_outer() {
outer.path()
} else {
self.name.replace(".__init__", "").into()
@ -1217,7 +1219,9 @@ impl Context {
vis: VisibilityModifier,
tv_cache: Option<TyVarCache>,
) {
let name = if vis.is_public() {
let name = if kind.is_module() {
name.into()
} else if vis.is_public() {
format!("{parent}.{name}", parent = self.name)
} else {
format!("{parent}::{name}", parent = self.name)

View file

@ -20,8 +20,8 @@ use ast::{
use erg_parser::ast::{self, ClassAttr, TypeSpecWithOp};
use crate::ty::constructors::{
free_var, func, func0, func1, proc, ref_, ref_mut, str_dict_t, tp_enum, unknown_len_array_t,
v_enum,
free_var, func, func0, func1, module, proc, py_module, ref_, ref_mut, str_dict_t, tp_enum,
unknown_len_array_t, v_enum,
};
use crate::ty::free::{Constraint, HasLevel};
use crate::ty::typaram::TyParam;
@ -1137,22 +1137,19 @@ impl Context {
let Ok(mod_name) = hir::Literal::try_from(mod_name.token.clone()) else {
return Ok(());
};
let res = self
.import_mod(call.additional_operation().unwrap(), &mod_name)
.map(|_path| ());
let arg = TyParam::Value(ValueObj::Str(
mod_name.token.content.replace('\"', "").into(),
));
let typ = if def.def_kind().is_erg_import() {
Type::Poly {
name: Str::ever("Module"),
params: vec![arg],
}
let path = self.import_mod(call.additional_operation().unwrap(), &mod_name);
let arg = if let Ok(path) = &path {
TyParam::Value(ValueObj::Str(path.to_string_lossy().into()))
} else {
Type::Poly {
name: Str::ever("PyModule"),
params: vec![arg],
}
TyParam::Value(ValueObj::Str(
mod_name.token.content.replace('\"', "").into(),
))
};
let res = path.map(|_path| ());
let typ = if def.def_kind().is_erg_import() {
module(arg)
} else {
py_module(arg)
};
let Some(ident) = def.sig.ident() else {
return res;

View file

@ -14,7 +14,6 @@ use erg_common::Str;
use erg_parser::ast::{DefId, OperationKind};
use erg_parser::token::{Token, TokenKind, DOT, EQUAL};
use crate::ty::typaram::TyParam;
use crate::ty::value::ValueObj;
use crate::ty::HasType;
@ -377,11 +376,9 @@ impl<'a> HIRLinker<'a> {
/// ```
fn replace_erg_import(&self, expr: &mut Expr) {
let line = expr.ln_begin().unwrap_or(0);
let TyParam::Value(ValueObj::Str(path)) = expr.ref_t().typarams().remove(0) else {
let Some(path) = expr.ref_t().module_path() else {
unreachable!()
};
let path = Path::new(&path[..]);
let path = self.cfg.input.resolve_real_path(path).unwrap();
// # module.er
// self = import "module"
// ↓

View file

@ -10,6 +10,7 @@ use erg_common::dict;
use erg_common::dict::Dict;
use erg_common::error::{Location, MultiErrorDisplay};
use erg_common::fresh::FreshNameGenerator;
use erg_common::pathutil::mod_name;
use erg_common::set;
use erg_common::set::Set;
use erg_common::traits::{ExitStatus, Locational, NoTypeDisplay, Runnable, Stream};
@ -2961,10 +2962,11 @@ impl<A: ASTBuildable> GenericASTLowerer<A> {
expect: Option<&Type>,
) -> hir::Call {
log!(info "entered {}", fn_name!());
let path = inline.module_path;
let path = inline.input.path().to_path_buf();
let parent = self.get_mod_ctx().context.get_module().unwrap().clone();
let mod_ctx = ModuleContext::new(parent, dict! {});
let mut builder = GenericHIRBuilder::<A>::new_with_ctx(mod_ctx);
let mod_name = mod_name(&path);
let mut builder = GenericHIRBuilder::<A>::new_submodule(mod_ctx, &mod_name);
builder.lowerer.module.context.cfg.input = inline.input.clone();
builder.cfg_mut().input = inline.input.clone();
let mode = if path.to_string_lossy().ends_with("d.er") {

View file

@ -97,7 +97,7 @@ pub struct ModuleCache {
impl fmt::Display for ModuleCache {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "ModuleCache {{")?;
writeln!(f, "ModuleCache {{")?;
for (path, entry) in self.cache.iter() {
writeln!(f, "{}: {}, ", path.display(), entry)?;
}

View file

@ -67,6 +67,20 @@ impl ValueArgs {
ValueArgs { pos_args, kw_args }
}
pub fn empty() -> Self {
ValueArgs {
pos_args: Vec::new(),
kw_args: Dict::new(),
}
}
pub fn pos_only(pos_args: Vec<ValueObj>) -> Self {
ValueArgs {
pos_args,
kw_args: Dict::new(),
}
}
pub fn remove_left_or_key(&mut self, key: &str) -> Option<ValueObj> {
if !self.pos_args.is_empty() {
Some(self.pos_args.remove(0))

View file

@ -49,9 +49,9 @@ use crate::context::eval::UndoableLinkedList;
use self::constructors::{bounded, free_var, named_free_var, proj_call, subr_t};
pub const STR_OMIT_THRESHOLD: usize = 16;
pub const CONTAINER_OMIT_THRESHOLD: usize = 8;
pub const DEFAULT_PARAMS_THRESHOLD: usize = 5;
pub const STR_OMIT_THRESHOLD: usize = if DEBUG_MODE { 100 } else { 16 };
pub const CONTAINER_OMIT_THRESHOLD: usize = if DEBUG_MODE { 100 } else { 8 };
pub const DEFAULT_PARAMS_THRESHOLD: usize = if DEBUG_MODE { 100 } else { 5 };
/// cloneのコストがあるためなるべく.ref_tを使うようにすること
/// いくつかの構造体は直接Typeを保持していないので、その場合は.tを使う
@ -3949,6 +3949,21 @@ impl Type {
_ => {}
}
}
pub fn module_path(&self) -> Option<PathBuf> {
match self {
Self::FreeVar(fv) if fv.is_linked() => fv.crack().module_path(),
Self::Refinement(refine) => refine.t.module_path(),
_ if self.is_module() => {
let tps = self.typarams();
let Some(TyParam::Value(ValueObj::Str(path))) = tps.get(0) else {
return None;
};
Some(PathBuf::from(&path[..]))
}
_ => None,
}
}
}
pub struct ReplaceTable<'t> {

View file

@ -6,7 +6,6 @@ use std::fmt::Write as _;
use erg_common::consts::ERG_MODE;
use erg_common::error::Location;
use erg_common::io::Input;
use erg_common::pathutil::NormalizedPathBuf;
use erg_common::set::Set as HashSet;
// use erg_common::dict::Dict as HashMap;
use erg_common::traits::{Locational, NestedDisplay, Stream};
@ -6168,7 +6167,6 @@ pub struct InlineModule {
pub input: Input,
pub ast: AST,
pub import: Call,
pub module_path: NormalizedPathBuf,
}
impl NestedDisplay for InlineModule {
@ -6190,12 +6188,7 @@ impl InlineModule {
}
impl InlineModule {
pub const fn new(input: Input, ast: AST, import: Call, module_path: NormalizedPathBuf) -> Self {
Self {
input,
ast,
import,
module_path,
}
pub const fn new(input: Input, ast: AST, import: Call) -> Self {
Self { input, ast, import }
}
}

View file

@ -408,12 +408,7 @@ impl Desugarer {
chunks.push(desugar(chunk));
}
let ast = AST::new(inline.ast.name, Module::new(chunks));
Expr::InlineModule(InlineModule::new(
inline.input,
ast,
inline.import,
inline.module_path,
))
Expr::InlineModule(InlineModule::new(inline.input, ast, inline.import))
}
Expr::Dummy(exprs) => {
let loc = exprs.loc;

View file

@ -0,0 +1,14 @@
datasets = pyimport "torchvision/datasets"
transforms = pyimport "torchvision/transforms"
data = pyimport "torch/utils/data"
training_data = datasets.FashionMNIST(
root:="target/data",
train:=True,
download:=True,
transform:=transforms.ToTensor(),
)
train_dataloader = data.DataLoader(training_data, batch_size:=64)
for! train_dataloader, ((x, y),) =>
print! x.shape, y.shape

View file

@ -245,6 +245,11 @@ fn exec_many_import() -> Result<(), ()> {
expect_success("tests/should_ok/many_import/many_import.er", 0)
}
#[test]
fn exec_many_import_pytorch() -> Result<(), ()> {
expect_compile_success("tests/should_ok/many_import/pytorch.er", 0)
}
#[test]
fn exec_map() -> Result<(), ()> {
expect_success("tests/should_ok/map.er", 0)