diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index 68f665637b..dfdf263b32 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -659,232 +659,3 @@ pub struct BytesLiteralType<'db> { #[return_ref] value: Box<[u8]>, } - -#[cfg(test)] -mod tests { - use anyhow::Context; - - use ruff_db::files::system_path_to_file; - use ruff_db::system::{DbWithTestSystem, SystemPathBuf}; - - use crate::db::tests::TestDb; - use crate::{Program, ProgramSettings, PythonVersion, SearchPathSettings}; - - use super::TypeCheckDiagnostics; - - fn setup_db() -> TestDb { - let db = TestDb::new(); - db.memory_file_system() - .create_directory_all("/src") - .unwrap(); - - Program::from_settings( - &db, - &ProgramSettings { - target_version: PythonVersion::default(), - search_paths: SearchPathSettings::new(SystemPathBuf::from("/src")), - }, - ) - .expect("Valid search path settings"); - - db - } - - fn assert_diagnostic_messages(diagnostics: &TypeCheckDiagnostics, expected: &[&str]) { - let messages: Vec<&str> = diagnostics - .iter() - .map(|diagnostic| diagnostic.message()) - .collect(); - assert_eq!(&messages, expected); - } - - #[test] - fn unresolved_import_statement() -> anyhow::Result<()> { - let mut db = setup_db(); - - db.write_file("src/foo.py", "import bar\n") - .context("Failed to write foo.py")?; - - let foo = system_path_to_file(&db, "src/foo.py").context("Failed to resolve foo.py")?; - - let diagnostics = super::check_types(&db, foo); - assert_diagnostic_messages(&diagnostics, &["Cannot resolve import 'bar'."]); - - Ok(()) - } - - #[test] - fn unresolved_import_from_statement() { - let mut db = setup_db(); - - db.write_file("src/foo.py", "from bar import baz\n") - .unwrap(); - let foo = system_path_to_file(&db, "src/foo.py").unwrap(); - let diagnostics = super::check_types(&db, foo); - assert_diagnostic_messages(&diagnostics, &["Cannot resolve import 'bar'."]); - } - - #[test] - fn unresolved_import_from_resolved_module() { - let mut db = setup_db(); - - db.write_files([("/src/a.py", ""), ("/src/b.py", "from a import thing")]) - .unwrap(); - - let b_file = system_path_to_file(&db, "/src/b.py").unwrap(); - let b_file_diagnostics = super::check_types(&db, b_file); - assert_diagnostic_messages(&b_file_diagnostics, &["Module 'a' has no member 'thing'"]); - } - - #[test] - fn resolved_import_of_symbol_from_unresolved_import() { - let mut db = setup_db(); - - db.write_files([ - ("/src/a.py", "import foo as foo"), - ("/src/b.py", "from a import foo"), - ]) - .unwrap(); - - let a_file = system_path_to_file(&db, "/src/a.py").unwrap(); - let a_file_diagnostics = super::check_types(&db, a_file); - assert_diagnostic_messages(&a_file_diagnostics, &["Cannot resolve import 'foo'."]); - - // Importing the unresolved import into a second first-party file should not trigger - // an additional "unresolved import" violation - let b_file = system_path_to_file(&db, "/src/b.py").unwrap(); - let b_file_diagnostics = super::check_types(&db, b_file); - assert_eq!(&*b_file_diagnostics, &[]); - } - - #[test] - fn invalid_callable() { - let mut db = setup_db(); - - db.write_dedented( - "src/a.py", - " - nonsense = 123 - x = nonsense() - ", - ) - .unwrap(); - - let a_file = system_path_to_file(&db, "/src/a.py").unwrap(); - let a_file_diagnostics = super::check_types(&db, a_file); - assert_diagnostic_messages( - &a_file_diagnostics, - &["Object of type 'Literal[123]' is not callable"], - ); - } - - #[test] - fn invalid_iterable() { - let mut db = setup_db(); - - db.write_dedented( - "src/a.py", - " - nonsense = 123 - for x in nonsense: - pass - ", - ) - .unwrap(); - - let a_file = system_path_to_file(&db, "/src/a.py").unwrap(); - let a_file_diagnostics = super::check_types(&db, a_file); - assert_diagnostic_messages( - &a_file_diagnostics, - &["Object of type 'Literal[123]' is not iterable"], - ); - } - - #[test] - fn new_iteration_protocol_takes_precedence_over_old_style() { - let mut db = setup_db(); - - db.write_dedented( - "src/a.py", - " - class NotIterable: - def __getitem__(self, key: int) -> int: - return 42 - - __iter__ = None - - for x in NotIterable(): - pass - ", - ) - .unwrap(); - - let a_file = system_path_to_file(&db, "/src/a.py").unwrap(); - let a_file_diagnostics = super::check_types(&db, a_file); - assert_diagnostic_messages( - &a_file_diagnostics, - &["Object of type 'NotIterable' is not iterable"], - ); - } - - #[test] - fn starred_expressions_must_be_iterable() { - let mut db = setup_db(); - - db.write_dedented( - "src/a.py", - " - class NotIterable: pass - - class Iterator: - def __next__(self) -> int: - return 42 - - class Iterable: - def __iter__(self) -> Iterator: - - x = [*NotIterable()] - y = [*Iterable()] - ", - ) - .unwrap(); - - let a_file = system_path_to_file(&db, "/src/a.py").unwrap(); - let a_file_diagnostics = super::check_types(&db, a_file); - assert_diagnostic_messages( - &a_file_diagnostics, - &["Object of type 'NotIterable' is not iterable"], - ); - } - - #[test] - fn yield_from_expression_must_be_iterable() { - let mut db = setup_db(); - - db.write_dedented( - "src/a.py", - " - class NotIterable: pass - - class Iterator: - def __next__(self) -> int: - return 42 - - class Iterable: - def __iter__(self) -> Iterator: - - def generator_function(): - yield from Iterable() - yield from NotIterable() - ", - ) - .unwrap(); - - let a_file = system_path_to_file(&db, "/src/a.py").unwrap(); - let a_file_diagnostics = super::check_types(&db, a_file); - assert_diagnostic_messages( - &a_file_diagnostics, - &["Object of type 'NotIterable' is not iterable"], - ); - } -} diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 0dcc108531..4b3bf4af42 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -2409,7 +2409,9 @@ mod tests { use crate::semantic_index::symbol::FileScopeId; use crate::semantic_index::{global_scope, semantic_index, symbol_table, use_def_map}; use crate::stdlib::builtins_module_scope; - use crate::types::{global_symbol_ty, infer_definition_types, symbol_ty}; + use crate::types::{ + check_types, global_symbol_ty, infer_definition_types, symbol_ty, TypeCheckDiagnostics, + }; use crate::{HasTy, ProgramSettings, SemanticModel}; use super::TypeInferenceBuilder; @@ -2491,6 +2493,21 @@ mod tests { assert_eq!(ty.display(db).to_string(), expected); } + fn assert_diagnostic_messages(diagnostics: &TypeCheckDiagnostics, expected: &[&str]) { + let messages: Vec<&str> = diagnostics + .iter() + .map(|diagnostic| diagnostic.message()) + .collect(); + assert_eq!(&messages, expected); + } + + fn assert_file_diagnostics(db: &TestDb, filename: &str, expected: &[&str]) { + let file = system_path_to_file(db, filename).unwrap(); + let diagnostics = check_types(db, file); + + assert_diagnostic_messages(&diagnostics, expected); + } + #[test] fn follow_import_to_class() -> anyhow::Result<()> { let mut db = setup_db(); @@ -2997,108 +3014,6 @@ mod tests { Ok(()) } - #[test] - fn basic_for_loop() -> anyhow::Result<()> { - let mut db = setup_db(); - - db.write_dedented( - "src/a.py", - " - class IntIterator: - def __next__(self) -> int: - return 42 - - class IntIterable: - def __iter__(self) -> IntIterator: - return IntIterator() - - for x in IntIterable(): - pass - ", - )?; - - assert_public_ty(&db, "src/a.py", "x", "int"); - - Ok(()) - } - - #[test] - fn for_loop_with_old_style_iteration_protocol() -> anyhow::Result<()> { - let mut db = setup_db(); - - db.write_dedented( - "src/a.py", - " - class OldStyleIterable: - def __getitem__(self, key: int) -> int: - return 42 - - for x in OldStyleIterable(): - pass - ", - )?; - - assert_public_ty(&db, "src/a.py", "x", "int"); - - Ok(()) - } - - /// This tests that we understand that `async` for loops - /// do not work according to the synchronous iteration protocol - #[test] - fn invalid_async_for_loop() -> anyhow::Result<()> { - let mut db = setup_db(); - - db.write_dedented( - "src/a.py", - " - async def foo(): - class Iterator: - def __next__(self) -> int: - return 42 - - class Iterable: - def __iter__(self) -> Iterator: - return Iterator() - - async for x in Iterator(): - pass - ", - )?; - - // TODO(Alex) async iterables/iterators! - assert_scope_ty(&db, "src/a.py", &["foo"], "x", "Unknown"); - - Ok(()) - } - - #[test] - fn basic_async_for_loop() -> anyhow::Result<()> { - let mut db = setup_db(); - - db.write_dedented( - "src/a.py", - " - async def foo(): - class IntAsyncIterator: - async def __anext__(self) -> int: - return 42 - - class IntAsyncIterable: - def __aiter__(self) -> IntAsyncIterator: - return IntAsyncIterator() - - async for x in IntAsyncIterable(): - pass - ", - )?; - - // TODO(Alex) async iterables/iterators! - assert_scope_ty(&db, "src/a.py", &["foo"], "x", "Unknown"); - - Ok(()) - } - #[test] fn class_constructor_call_expression() -> anyhow::Result<()> { let mut db = setup_db(); @@ -3117,6 +3032,26 @@ mod tests { Ok(()) } + #[test] + fn invalid_callable() { + let mut db = setup_db(); + + db.write_dedented( + "src/a.py", + " + nonsense = 123 + x = nonsense() + ", + ) + .unwrap(); + + assert_file_diagnostics( + &db, + "/src/a.py", + &["Object of type 'Literal[123]' is not callable"], + ); + } + #[test] fn resolve_union() -> anyhow::Result<()> { let mut db = setup_db(); @@ -4014,6 +3949,259 @@ mod tests { Ok(()) } + #[test] + fn unresolved_import_statement() { + let mut db = setup_db(); + + db.write_file("src/foo.py", "import bar\n").unwrap(); + + assert_file_diagnostics(&db, "src/foo.py", &["Cannot resolve import 'bar'."]); + } + + #[test] + fn unresolved_import_from_statement() { + let mut db = setup_db(); + + db.write_file("src/foo.py", "from bar import baz\n") + .unwrap(); + assert_file_diagnostics(&db, "/src/foo.py", &["Cannot resolve import 'bar'."]); + } + + #[test] + fn unresolved_import_from_resolved_module() { + let mut db = setup_db(); + + db.write_files([("/src/a.py", ""), ("/src/b.py", "from a import thing")]) + .unwrap(); + + assert_file_diagnostics(&db, "/src/b.py", &["Module 'a' has no member 'thing'"]); + } + + #[test] + fn resolved_import_of_symbol_from_unresolved_import() { + let mut db = setup_db(); + + db.write_files([ + ("/src/a.py", "import foo as foo"), + ("/src/b.py", "from a import foo"), + ]) + .unwrap(); + + assert_file_diagnostics(&db, "/src/a.py", &["Cannot resolve import 'foo'."]); + + // Importing the unresolved import into a second first-party file should not trigger + // an additional "unresolved import" violation + assert_file_diagnostics(&db, "/src/b.py", &[]); + } + + #[test] + fn basic_for_loop() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_dedented( + "src/a.py", + " + class IntIterator: + def __next__(self) -> int: + return 42 + + class IntIterable: + def __iter__(self) -> IntIterator: + return IntIterator() + + for x in IntIterable(): + pass + ", + )?; + + assert_public_ty(&db, "src/a.py", "x", "int"); + + Ok(()) + } + + #[test] + fn for_loop_with_old_style_iteration_protocol() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_dedented( + "src/a.py", + " + class OldStyleIterable: + def __getitem__(self, key: int) -> int: + return 42 + + for x in OldStyleIterable(): + pass + ", + )?; + + assert_public_ty(&db, "src/a.py", "x", "int"); + + Ok(()) + } + + /// This tests that we understand that `async` for loops + /// do not work according to the synchronous iteration protocol + #[test] + fn invalid_async_for_loop() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_dedented( + "src/a.py", + " + async def foo(): + class Iterator: + def __next__(self) -> int: + return 42 + + class Iterable: + def __iter__(self) -> Iterator: + return Iterator() + + async for x in Iterator(): + pass + ", + )?; + + // TODO(Alex) async iterables/iterators! + assert_scope_ty(&db, "src/a.py", &["foo"], "x", "Unknown"); + + Ok(()) + } + + #[test] + fn basic_async_for_loop() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_dedented( + "src/a.py", + " + async def foo(): + class IntAsyncIterator: + async def __anext__(self) -> int: + return 42 + + class IntAsyncIterable: + def __aiter__(self) -> IntAsyncIterator: + return IntAsyncIterator() + + async for x in IntAsyncIterable(): + pass + ", + )?; + + // TODO(Alex) async iterables/iterators! + assert_scope_ty(&db, "src/a.py", &["foo"], "x", "Unknown"); + + Ok(()) + } + + #[test] + fn invalid_iterable() { + let mut db = setup_db(); + + db.write_dedented( + "src/a.py", + " + nonsense = 123 + for x in nonsense: + pass + ", + ) + .unwrap(); + + assert_file_diagnostics( + &db, + "/src/a.py", + &["Object of type 'Literal[123]' is not iterable"], + ); + } + + #[test] + fn new_iteration_protocol_takes_precedence_over_old_style() { + let mut db = setup_db(); + + db.write_dedented( + "src/a.py", + " + class NotIterable: + def __getitem__(self, key: int) -> int: + return 42 + + __iter__ = None + + for x in NotIterable(): + pass + ", + ) + .unwrap(); + + assert_file_diagnostics( + &db, + "/src/a.py", + &["Object of type 'NotIterable' is not iterable"], + ); + } + + #[test] + fn starred_expressions_must_be_iterable() { + let mut db = setup_db(); + + db.write_dedented( + "src/a.py", + " + class NotIterable: pass + + class Iterator: + def __next__(self) -> int: + return 42 + + class Iterable: + def __iter__(self) -> Iterator: + + x = [*NotIterable()] + y = [*Iterable()] + ", + ) + .unwrap(); + + assert_file_diagnostics( + &db, + "/src/a.py", + &["Object of type 'NotIterable' is not iterable"], + ); + } + + #[test] + fn yield_from_expression_must_be_iterable() { + let mut db = setup_db(); + + db.write_dedented( + "src/a.py", + " + class NotIterable: pass + + class Iterator: + def __next__(self) -> int: + return 42 + + class Iterable: + def __iter__(self) -> Iterator: + + def generator_function(): + yield from Iterable() + yield from NotIterable() + ", + ) + .unwrap(); + + assert_file_diagnostics( + &db, + "/src/a.py", + &["Object of type 'NotIterable' is not iterable"], + ); + } + // Incremental inference tests fn first_public_def<'db>(db: &'db TestDb, file: File, name: &str) -> Definition<'db> {