From 46e687e8d111c93d49463659d22c00ec66899e70 Mon Sep 17 00:00:00 2001 From: Simon Date: Tue, 3 Sep 2024 09:23:28 +0200 Subject: [PATCH] [red-knot] Condense literals display by types (#13185) Co-authored-by: Micha Reiser --- crates/red_knot_python_semantic/src/lib.rs | 1 + crates/red_knot_python_semantic/src/types.rs | 13 + .../src/types/display.rs | 251 ++++++++++++++---- .../src/types/infer.rs | 12 +- 4 files changed, 219 insertions(+), 58 deletions(-) diff --git a/crates/red_knot_python_semantic/src/lib.rs b/crates/red_knot_python_semantic/src/lib.rs index 909c2d8de2..56827bcdd7 100644 --- a/crates/red_knot_python_semantic/src/lib.rs +++ b/crates/red_knot_python_semantic/src/lib.rs @@ -23,3 +23,4 @@ pub(crate) mod site_packages; pub mod types; type FxOrderSet = ordermap::set::OrderSet>; +type FxOrderMap = ordermap::map::OrderMap>; diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index ca28059d45..70cf080ee1 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -224,6 +224,19 @@ impl<'db> Type<'db> { matches!(self, Type::Never) } + /// Returns `true` if this type should be displayed as a literal value. + pub const fn is_literal(&self) -> bool { + matches!( + self, + Type::IntLiteral(_) + | Type::BooleanLiteral(_) + | Type::StringLiteral(_) + | Type::BytesLiteral(_) + | Type::Class(_) + | Type::Function(_) + ) + } + pub fn may_be_unbound(&self, db: &'db dyn Db) -> bool { match self { Type::Unbound => true, diff --git a/crates/red_knot_python_semantic/src/types/display.rs b/crates/red_knot_python_semantic/src/types/display.rs index 8c3cf67ff5..c47506596d 100644 --- a/crates/red_knot_python_semantic/src/types/display.rs +++ b/crates/red_knot_python_semantic/src/types/display.rs @@ -6,12 +6,16 @@ use ruff_python_ast::str::Quote; use ruff_python_literal::escape::AsciiEscape; use crate::types::{IntersectionType, Type, UnionType}; -use crate::Db; +use crate::{Db, FxOrderMap}; impl<'db> Type<'db> { pub fn display(&'db self, db: &'db dyn Db) -> DisplayType<'db> { DisplayType { ty: self, db } } + + fn representation(&'db self, db: &'db dyn Db) -> DisplayRepresentation<'db> { + DisplayRepresentation { db, ty: self } + } } #[derive(Copy, Clone)] @@ -21,6 +25,31 @@ pub struct DisplayType<'db> { } impl Display for DisplayType<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let representation = self.ty.representation(self.db); + if self.ty.is_literal() { + write!(f, "Literal[{representation}]",) + } else { + representation.fmt(f) + } + } +} + +impl std::fmt::Debug for DisplayType<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + std::fmt::Display::fmt(self, f) + } +} + +/// Writes the string representation of a type, which is the value displayed either as +/// `Literal[]` or `Literal[, ]` for literal types or as `` for +/// non literals +struct DisplayRepresentation<'db> { + ty: &'db Type<'db>, + db: &'db dyn Db, +} + +impl std::fmt::Display for DisplayRepresentation<'_> { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self.ty { Type::Any => f.write_str("Any"), @@ -32,39 +61,27 @@ impl Display for DisplayType<'_> { write!(f, "", file.path(self.db)) } // TODO functions and classes should display using a fully qualified name - Type::Class(class) => write!(f, "Literal[{}]", class.name(self.db)), + Type::Class(class) => f.write_str(class.name(self.db)), Type::Instance(class) => f.write_str(class.name(self.db)), - Type::Function(function) => write!(f, "Literal[{}]", function.name(self.db)), + Type::Function(function) => f.write_str(function.name(self.db)), Type::Union(union) => union.display(self.db).fmt(f), Type::Intersection(intersection) => intersection.display(self.db).fmt(f), - Type::IntLiteral(n) => write!(f, "Literal[{n}]"), - Type::BooleanLiteral(boolean) => { - write!(f, "Literal[{}]", if *boolean { "True" } else { "False" }) + Type::IntLiteral(n) => write!(f, "{n}"), + Type::BooleanLiteral(boolean) => f.write_str(if *boolean { "True" } else { "False" }), + Type::StringLiteral(string) => { + write!(f, r#""{}""#, string.value(self.db).replace('"', r#"\""#)) } - Type::StringLiteral(string) => write!( - f, - r#"Literal["{}"]"#, - string.value(self.db).replace('"', r#"\""#) - ), - Type::LiteralString => write!(f, "LiteralString"), + Type::LiteralString => f.write_str("LiteralString"), Type::BytesLiteral(bytes) => { let escape = AsciiEscape::with_preferred_quote(bytes.value(self.db).as_ref(), Quote::Double); - f.write_str("Literal[")?; - escape.bytes_repr().write(f)?; - f.write_str("]") + escape.bytes_repr().write(f) } } } } -impl std::fmt::Debug for DisplayType<'_> { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - std::fmt::Display::fmt(self, f) - } -} - impl<'db> UnionType<'db> { fn display(&'db self, db: &'db dyn Db) -> DisplayUnionType<'db> { DisplayUnionType { db, ty: self } @@ -78,45 +95,61 @@ struct DisplayUnionType<'db> { impl Display for DisplayUnionType<'_> { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - let union = self.ty; + let elements = self.ty.elements(self.db); - let (int_literals, other_types): (Vec, Vec) = union - .elements(self.db) - .iter() - .copied() - .partition(|ty| matches!(ty, Type::IntLiteral(_))); + // Group literal types by kind. + let mut grouped_literals = FxOrderMap::default(); + + for element in elements { + if let Ok(literal_kind) = LiteralTypeKind::try_from(*element) { + grouped_literals + .entry(literal_kind) + .or_insert_with(Vec::new) + .push(*element); + } + } let mut first = true; - if !int_literals.is_empty() { - f.write_str("Literal[")?; - let mut nums: Vec<_> = int_literals - .into_iter() - .filter_map(|ty| { - if let Type::IntLiteral(n) = ty { - Some(n) - } else { - None - } - }) - .collect(); - nums.sort_unstable(); - for num in nums { + + // Print all types, but write all literals together (while preserving their position). + for ty in elements { + if let Ok(literal_kind) = LiteralTypeKind::try_from(*ty) { + let Some(mut literals) = grouped_literals.remove(&literal_kind) else { + continue; + }; + if !first { - f.write_str(", ")?; + f.write_str(" | ")?; + }; + + f.write_str("Literal[")?; + + if literal_kind == LiteralTypeKind::IntLiteral { + literals.sort_unstable_by_key(|ty| match ty { + Type::IntLiteral(n) => *n, + _ => panic!("Expected only int literals when kind is IntLiteral"), + }); } - write!(f, "{num}")?; - first = false; + + for (i, literal_ty) in literals.iter().enumerate() { + if i > 0 { + f.write_str(", ")?; + } + literal_ty.representation(self.db).fmt(f)?; + } + f.write_str("]")?; + } else { + if !first { + f.write_str(" | ")?; + }; + + ty.display(self.db).fmt(f)?; } - f.write_str("]")?; + + first = false; } - for ty in other_types { - if !first { - f.write_str(" | ")?; - }; - first = false; - write!(f, "{}", ty.display(self.db))?; - } + debug_assert!(grouped_literals.is_empty()); Ok(()) } @@ -128,6 +161,30 @@ impl std::fmt::Debug for DisplayUnionType<'_> { } } +#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] +enum LiteralTypeKind { + Class, + Function, + IntLiteral, + StringLiteral, + BytesLiteral, +} + +impl TryFrom> for LiteralTypeKind { + type Error = (); + + fn try_from(value: Type<'_>) -> Result { + match value { + Type::Class(_) => Ok(Self::Class), + Type::Function(_) => Ok(Self::Function), + Type::IntLiteral(_) => Ok(Self::IntLiteral), + Type::StringLiteral(_) => Ok(Self::StringLiteral), + Type::BytesLiteral(_) => Ok(Self::BytesLiteral), + _ => Err(()), + } + } +} + impl<'db> IntersectionType<'db> { fn display(&'db self, db: &'db dyn Db) -> DisplayIntersectionType<'db> { DisplayIntersectionType { db, ty: self } @@ -167,3 +224,93 @@ impl std::fmt::Debug for DisplayIntersectionType<'_> { std::fmt::Display::fmt(self, f) } } + +#[cfg(test)] +mod tests { + use ruff_db::files::system_path_to_file; + use ruff_db::system::{DbWithTestSystem, SystemPathBuf}; + + use crate::db::tests::TestDb; + use crate::types::{ + global_symbol_ty_by_name, BytesLiteralType, StringLiteralType, Type, UnionBuilder, + }; + use crate::{Program, ProgramSettings, PythonVersion, SearchPathSettings}; + + fn setup_db() -> TestDb { + let db = TestDb::new(); + + let src_root = SystemPathBuf::from("/src"); + db.memory_file_system() + .create_directory_all(&src_root) + .unwrap(); + + Program::from_settings( + &db, + &ProgramSettings { + target_version: PythonVersion::default(), + search_paths: SearchPathSettings::new(src_root), + }, + ) + .expect("Valid search path settings"); + + db + } + + #[test] + fn test_condense_literal_display_by_type() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_dedented( + "src/main.py", + " + def foo(x: int) -> int: + return x + 1 + + def bar(s: str) -> str: + return s + + class A: ... + class B: ... + ", + )?; + let mod_file = system_path_to_file(&db, "src/main.py").expect("Expected file to exist."); + + let vec: Vec> = vec![ + Type::Unknown, + Type::IntLiteral(-1), + global_symbol_ty_by_name(&db, mod_file, "A"), + Type::StringLiteral(StringLiteralType::new(&db, Box::from("A"))), + Type::BytesLiteral(BytesLiteralType::new(&db, Box::from([0]))), + Type::BytesLiteral(BytesLiteralType::new(&db, Box::from([7]))), + Type::IntLiteral(0), + Type::IntLiteral(1), + Type::StringLiteral(StringLiteralType::new(&db, Box::from("B"))), + global_symbol_ty_by_name(&db, mod_file, "foo"), + global_symbol_ty_by_name(&db, mod_file, "bar"), + global_symbol_ty_by_name(&db, mod_file, "B"), + Type::BooleanLiteral(true), + Type::None, + ]; + let builder = vec.iter().fold(UnionBuilder::new(&db), |builder, literal| { + builder.add(*literal) + }); + let Type::Union(union) = builder.build() else { + panic!("expected a union"); + }; + let display = format!("{}", union.display(&db)); + assert_eq!( + display, + concat!( + "Unknown | ", + "Literal[-1, 0, 1] | ", + "Literal[A, B] | ", + "Literal[\"A\", \"B\"] | ", + "Literal[b\"\\x00\", b\"\\x07\"] | ", + "Literal[foo, bar] | ", + "Literal[True] | ", + "None" + ) + ); + Ok(()) + } +} diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 66162f9e5a..54cf3f9ce2 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -3093,7 +3093,7 @@ mod tests { ", )?; - assert_public_ty(&db, "src/a.py", "x", "Literal[3] | Unbound"); + assert_public_ty(&db, "src/a.py", "x", "Unbound | Literal[3]"); Ok(()) } @@ -3119,8 +3119,8 @@ mod tests { )?; assert_public_ty(&db, "src/a.py", "x", "Literal[3, 4, 5]"); - assert_public_ty(&db, "src/a.py", "r", "Literal[2] | Unbound"); - assert_public_ty(&db, "src/a.py", "s", "Literal[5] | Unbound"); + assert_public_ty(&db, "src/a.py", "r", "Unbound | Literal[2]"); + assert_public_ty(&db, "src/a.py", "s", "Unbound | Literal[5]"); Ok(()) } @@ -3356,7 +3356,7 @@ mod tests { assert_eq!( y_ty.display(&db).to_string(), - "Literal[1] | Literal[copyright]" + "Literal[copyright] | Literal[1]" ); Ok(()) @@ -3389,7 +3389,7 @@ mod tests { let y_ty = symbol_ty_by_name(&db, class_scope, "y"); let x_ty = symbol_ty_by_name(&db, class_scope, "x"); - assert_eq!(x_ty.display(&db).to_string(), "Literal[2] | Unbound"); + assert_eq!(x_ty.display(&db).to_string(), "Unbound | Literal[2]"); assert_eq!(y_ty.display(&db).to_string(), "Literal[1]"); Ok(()) @@ -3522,7 +3522,7 @@ mod tests { ", )?; - assert_public_ty(&db, "/src/a.py", "x", "Literal[1] | None"); + assert_public_ty(&db, "/src/a.py", "x", "None | Literal[1]"); assert_public_ty(&db, "/src/a.py", "y", "Literal[0, 1]"); Ok(())