use std::cmp::Ordering::*; use std::fmt; use std::ops::{Deref, DerefMut}; use erg_common::traits::Stream; use erg_compiler::erg_parser::ast; use erg_compiler::erg_parser::ast::Module; use erg_compiler::hir; use erg_compiler::hir::HIR; use erg_compiler::lower::ASTLowerer; use erg_compiler::ty::HasType; #[derive(Debug, Clone, PartialEq, Eq)] pub enum ASTDiff { Deletion(usize), Addition(usize, ast::Expr), Modification(usize, ast::Expr), Nop, } impl fmt::Display for ASTDiff { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::Deletion(idx) => write!(f, "Deletion({idx})"), Self::Addition(idx, expr) => write!(f, "Addition({idx}, {expr})"), Self::Modification(idx, expr) => write!(f, "Modification({idx}, {expr})"), Self::Nop => write!(f, "Nop"), } } } /// diff(old: {x, y, z}, new: {x, a, y, z}) => ASTDiff::Addition(1) /// diff(old: {x, y}, new: {x, y, a}) => ASTDiff::Addition(2) /// diff(old: {x, y}, new: {x}) => ASTDiff::Deletion(1) /// diff(old: {x, y, z}, new: {x, z}) => ASTDiff::Deletion(1) /// diff(old: {x, y, z}, new: {x, ya, z}) => ASTDiff::Modification(1) /// diff(old: {x, y, z}, new: {x, a, z}) => ASTDiff::Modification(1) /// diff(old: {x, y, z}, new: {x, y, z}) => ASTDiff::Nop impl ASTDiff { pub fn diff, M2: Deref>( old: M1, new: M2, ) -> ASTDiff { match old.len().cmp(&new.len()) { Less => { let idx = new .iter() .zip(old.iter()) .position(|(new, old)| new != old) .unwrap_or(new.len() - 1); Self::Addition(idx, new.get(idx).unwrap().clone()) } Greater => Self::Deletion( old.iter() .zip(new.iter()) .position(|(old, new)| old != new) .unwrap_or(old.len() - 1), ), Equal => old .iter() .zip(new.iter()) .position(|(old, new)| old != new) .map(|idx| Self::Modification(idx, new.get(idx).unwrap().clone())) .unwrap_or(Self::Nop), } } pub const fn is_nop(&self) -> bool { matches!(self, Self::Nop) } pub fn update(self, mut old: impl DerefMut) { match self { Self::Addition(idx, expr) => { if idx > old.len() { old.push(expr); } else { old.insert(idx, expr); } } Self::Deletion(usize) => { if old.get(usize).is_some() { old.remove(usize); } } Self::Modification(idx, expr) => { if let Some(old_expr) = old.get_mut(idx) { *old_expr = expr; } } Self::Nop => {} } } } #[derive(Debug, Clone, PartialEq, Eq)] pub enum HIRDiff { Deletion(usize), Addition(usize, hir::Expr), Modification(usize, hir::Expr), Nop, } impl fmt::Display for HIRDiff { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::Deletion(idx) => write!(f, "Deletion({idx})"), Self::Addition(idx, expr) => write!(f, "Addition({idx}, {expr})"), Self::Modification(idx, expr) => write!(f, "Modification({idx}, {expr})"), Self::Nop => write!(f, "Nop"), } } } impl HIRDiff { pub fn new(diff: ASTDiff, lowerer: &mut ASTLowerer) -> Option { match diff { ASTDiff::Deletion(idx) => Some(Self::Deletion(idx)), ASTDiff::Addition(idx, expr) => { let expr = match lowerer.lower_and_resolve_chunk(expr, None) { Ok(expr) => expr, Err((opt_expr, _err)) => opt_expr?, }; Some(Self::Addition(idx, expr)) } ASTDiff::Modification(idx, expr) => { if let ast::Expr::Def(def) | ast::Expr::ClassDef(ast::ClassDef { def, .. }) | ast::Expr::PatchDef(ast::PatchDef { def, .. }) = &expr { if let Some(name) = def.sig.name_as_str() { lowerer.unregister(name); } } let expr = match lowerer.lower_and_resolve_chunk(expr, None) { Ok(expr) => expr, Err((opt_expr, _err)) => opt_expr?, }; Some(Self::Modification(idx, expr)) } ASTDiff::Nop => Some(Self::Nop), } } pub fn update>(self, mut old: H) { match self { Self::Addition(idx, expr) => { if idx > old.module.len() { old.module.push(expr); } else { old.module.insert(idx, expr); } } Self::Deletion(usize) => { if old.module.get(usize).is_some() { old.module.remove(usize); } } Self::Modification(idx, expr) => { if let Some(old_expr) = old.module.get_mut(idx) { *old_expr = expr; } } Self::Nop => {} } } pub fn fix(ast: &ast::Module, hir: &mut hir::Module, lowerer: &mut ASTLowerer) -> usize { let mut fixed = 0; for (ast_chunk, chunk) in ast.iter().zip(hir.iter_mut()) { if ast_chunk.name() != chunk.name() { continue; } if chunk.ref_t().contains_failure() { match lowerer.lower_and_resolve_chunk(ast_chunk.clone(), None) { Ok(expr) | Err((Some(expr), _)) => { *chunk = expr; fixed += 1; } _ => {} } } } fixed } }