diff --git a/Cargo.lock b/Cargo.lock index f48ba867f1..9c8c36681e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4004,6 +4004,7 @@ dependencies = [ "test-case", "thiserror 2.0.12", "tracing", + "ty_python_semantic", "ty_test", "ty_vendored", ] @@ -4039,6 +4040,7 @@ name = "ty_test" version = "0.0.0" dependencies = [ "anyhow", + "bitflags 2.9.1", "camino", "colored 3.0.0", "insta", diff --git a/crates/ty_python_semantic/Cargo.toml b/crates/ty_python_semantic/Cargo.toml index 1bbdf10a49..7c5268299d 100644 --- a/crates/ty_python_semantic/Cargo.toml +++ b/crates/ty_python_semantic/Cargo.toml @@ -50,6 +50,7 @@ strum_macros = { workspace = true } [dev-dependencies] ruff_db = { workspace = true, features = ["testing", "os"] } ruff_python_parser = { workspace = true } +ty_python_semantic = { workspace = true, features = ["testing"] } ty_test = { workspace = true } ty_vendored = { workspace = true } @@ -63,6 +64,7 @@ quickcheck_macros = { version = "1.0.0" } [features] serde = ["ruff_db/serde", "dep:serde", "ruff_python_ast/serde"] +testing = [] [lints] workspace = true diff --git a/crates/ty_python_semantic/resources/mdtest/annotations/any.md b/crates/ty_python_semantic/resources/mdtest/annotations/any.md index d4b1e6f502..a35b18168e 100644 --- a/crates/ty_python_semantic/resources/mdtest/annotations/any.md +++ b/crates/ty_python_semantic/resources/mdtest/annotations/any.md @@ -139,6 +139,8 @@ x: int = MagicMock() ## Invalid + + `Any` cannot be parameterized: ```py diff --git a/crates/ty_python_semantic/resources/mdtest/annotations/callable.md b/crates/ty_python_semantic/resources/mdtest/annotations/callable.md index 4b398acef0..b5420c2873 100644 --- a/crates/ty_python_semantic/resources/mdtest/annotations/callable.md +++ b/crates/ty_python_semantic/resources/mdtest/annotations/callable.md @@ -58,6 +58,8 @@ def _(c: Callable[[int, 42, str, False], None]): ### Missing return type + + Using a parameter list: ```py diff --git a/crates/ty_python_semantic/resources/mdtest/type_api.md b/crates/ty_python_semantic/resources/mdtest/type_api.md index 639bd4f83a..8074e1fa43 100644 --- a/crates/ty_python_semantic/resources/mdtest/type_api.md +++ b/crates/ty_python_semantic/resources/mdtest/type_api.md @@ -14,6 +14,8 @@ directly. ### Negation + + ```py from typing import Literal from ty_extensions import Not, static_assert @@ -371,6 +373,8 @@ static_assert(not is_single_valued(Literal["a"] | Literal["b"])) ## `TypeOf` + + We use `TypeOf` to get the inferred type of an expression. This is useful when we want to refer to it in a type expression. For example, if we want to make sure that the class literal type `str` is a subtype of `type[str]`, we can not use `is_subtype_of(str, type[str])`, as that would test if the @@ -412,6 +416,8 @@ def f(x: TypeOf) -> None: ## `CallableTypeOf` + + The `CallableTypeOf` special form can be used to extract the `Callable` structural type inhabited by a given callable object. This can be used to get the externally visibly signature of the object, which can then be used to test various type properties. diff --git a/crates/ty_python_semantic/resources/mdtest/type_qualifiers/classvar.md b/crates/ty_python_semantic/resources/mdtest/type_qualifiers/classvar.md index fdf443db9c..c92901341e 100644 --- a/crates/ty_python_semantic/resources/mdtest/type_qualifiers/classvar.md +++ b/crates/ty_python_semantic/resources/mdtest/type_qualifiers/classvar.md @@ -84,6 +84,8 @@ d.a = 2 ## Too many arguments + + ```py from typing import ClassVar diff --git a/crates/ty_python_semantic/resources/mdtest/type_qualifiers/final.md b/crates/ty_python_semantic/resources/mdtest/type_qualifiers/final.md index 32e84412eb..c5f5a86375 100644 --- a/crates/ty_python_semantic/resources/mdtest/type_qualifiers/final.md +++ b/crates/ty_python_semantic/resources/mdtest/type_qualifiers/final.md @@ -45,6 +45,8 @@ reveal_type(FINAL_E) # revealed: int ## Too many arguments + + ```py from typing import Final diff --git a/crates/ty_python_semantic/src/lib.rs b/crates/ty_python_semantic/src/lib.rs index 0123d28c17..9728cf3731 100644 --- a/crates/ty_python_semantic/src/lib.rs +++ b/crates/ty_python_semantic/src/lib.rs @@ -35,6 +35,9 @@ pub mod types; mod unpack; mod util; +#[cfg(feature = "testing")] +pub mod pull_types; + type FxOrderSet = ordermap::set::OrderSet>; /// Returns the default registry with all known semantic lints. diff --git a/crates/ty_python_semantic/src/pull_types.rs b/crates/ty_python_semantic/src/pull_types.rs new file mode 100644 index 0000000000..68feb73edc --- /dev/null +++ b/crates/ty_python_semantic/src/pull_types.rs @@ -0,0 +1,134 @@ +//! A utility visitor for testing, which attempts to "pull a type" for ever sub-node in a given AST. +//! +//! This is used in the "corpus" and (indirectly) the "mdtest" integration tests for this crate. +//! (Mdtest uses the `pull_types` function via the `ty_test` crate.) + +use crate::{Db, HasType, SemanticModel}; +use ruff_db::{files::File, parsed::parsed_module}; +use ruff_python_ast::{ + self as ast, visitor::source_order, visitor::source_order::SourceOrderVisitor, +}; + +pub fn pull_types(db: &dyn Db, file: File) { + let mut visitor = PullTypesVisitor::new(db, file); + + let ast = parsed_module(db.upcast(), file).load(db.upcast()); + + visitor.visit_body(ast.suite()); +} + +struct PullTypesVisitor<'db> { + model: SemanticModel<'db>, +} + +impl<'db> PullTypesVisitor<'db> { + fn new(db: &'db dyn Db, file: File) -> Self { + Self { + model: SemanticModel::new(db, file), + } + } + + fn visit_target(&mut self, target: &ast::Expr) { + match target { + ast::Expr::List(ast::ExprList { elts, .. }) + | ast::Expr::Tuple(ast::ExprTuple { elts, .. }) => { + for element in elts { + self.visit_target(element); + } + } + _ => self.visit_expr(target), + } + } +} + +impl SourceOrderVisitor<'_> for PullTypesVisitor<'_> { + fn visit_stmt(&mut self, stmt: &ast::Stmt) { + match stmt { + ast::Stmt::FunctionDef(function) => { + let _ty = function.inferred_type(&self.model); + } + ast::Stmt::ClassDef(class) => { + let _ty = class.inferred_type(&self.model); + } + ast::Stmt::Assign(assign) => { + for target in &assign.targets { + self.visit_target(target); + } + self.visit_expr(&assign.value); + return; + } + ast::Stmt::For(for_stmt) => { + self.visit_target(&for_stmt.target); + self.visit_expr(&for_stmt.iter); + self.visit_body(&for_stmt.body); + self.visit_body(&for_stmt.orelse); + return; + } + ast::Stmt::With(with_stmt) => { + for item in &with_stmt.items { + if let Some(target) = &item.optional_vars { + self.visit_target(target); + } + self.visit_expr(&item.context_expr); + } + + self.visit_body(&with_stmt.body); + return; + } + ast::Stmt::AnnAssign(_) + | ast::Stmt::Return(_) + | ast::Stmt::Delete(_) + | ast::Stmt::AugAssign(_) + | ast::Stmt::TypeAlias(_) + | ast::Stmt::While(_) + | ast::Stmt::If(_) + | ast::Stmt::Match(_) + | ast::Stmt::Raise(_) + | ast::Stmt::Try(_) + | ast::Stmt::Assert(_) + | ast::Stmt::Import(_) + | ast::Stmt::ImportFrom(_) + | ast::Stmt::Global(_) + | ast::Stmt::Nonlocal(_) + | ast::Stmt::Expr(_) + | ast::Stmt::Pass(_) + | ast::Stmt::Break(_) + | ast::Stmt::Continue(_) + | ast::Stmt::IpyEscapeCommand(_) => {} + } + + source_order::walk_stmt(self, stmt); + } + + fn visit_expr(&mut self, expr: &ast::Expr) { + let _ty = expr.inferred_type(&self.model); + + source_order::walk_expr(self, expr); + } + + fn visit_comprehension(&mut self, comprehension: &ast::Comprehension) { + self.visit_expr(&comprehension.iter); + self.visit_target(&comprehension.target); + for if_expr in &comprehension.ifs { + self.visit_expr(if_expr); + } + } + + fn visit_parameter(&mut self, parameter: &ast::Parameter) { + let _ty = parameter.inferred_type(&self.model); + + source_order::walk_parameter(self, parameter); + } + + fn visit_parameter_with_default(&mut self, parameter_with_default: &ast::ParameterWithDefault) { + let _ty = parameter_with_default.inferred_type(&self.model); + + source_order::walk_parameter_with_default(self, parameter_with_default); + } + + fn visit_alias(&mut self, alias: &ast::Alias) { + let _ty = alias.inferred_type(&self.model); + + source_order::walk_alias(self, alias); + } +} diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index 5c99354bb6..d1ea86b1c6 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -471,8 +471,10 @@ impl<'db> TypeInference<'db> { #[track_caller] pub(crate) fn expression_type(&self, expression: ScopedExpressionId) -> Type<'db> { self.try_expression_type(expression).expect( - "expression should belong to this TypeInference region and \ - TypeInferenceBuilder should have inferred a type for it", + "Failed to retrieve the inferred type for an `ast::Expr` node \ + passed to `TypeInference::expression_type()`. The `TypeInferenceBuilder` \ + should infer and store types for all `ast::Expr` nodes in any `TypeInference` \ + region it analyzes.", ) } diff --git a/crates/ty_python_semantic/tests/corpus.rs b/crates/ty_python_semantic/tests/corpus.rs index 6880dac918..0a9222a3d6 100644 --- a/crates/ty_python_semantic/tests/corpus.rs +++ b/crates/ty_python_semantic/tests/corpus.rs @@ -1,19 +1,15 @@ use anyhow::{Context, anyhow}; use ruff_db::Upcast; use ruff_db::files::{File, Files, system_path_to_file}; -use ruff_db::parsed::parsed_module; use ruff_db::system::{DbWithTestSystem, System, SystemPath, SystemPathBuf, TestSystem}; use ruff_db::vendored::VendoredFileSystem; -use ruff_python_ast::visitor::source_order; -use ruff_python_ast::visitor::source_order::SourceOrderVisitor; -use ruff_python_ast::{ - self as ast, Alias, Comprehension, Expr, Parameter, ParameterWithDefault, PythonVersion, Stmt, -}; +use ruff_python_ast::PythonVersion; use ty_python_semantic::lint::{LintRegistry, RuleSelection}; +use ty_python_semantic::pull_types::pull_types; use ty_python_semantic::{ - Db, HasType, Program, ProgramSettings, PythonPlatform, PythonVersionSource, - PythonVersionWithSource, SearchPathSettings, SemanticModel, default_lint_registry, + Program, ProgramSettings, PythonPlatform, PythonVersionSource, PythonVersionWithSource, + SearchPathSettings, default_lint_registry, }; fn get_cargo_workspace_root() -> anyhow::Result { @@ -174,129 +170,6 @@ fn run_corpus_tests(pattern: &str) -> anyhow::Result<()> { Ok(()) } -fn pull_types(db: &dyn Db, file: File) { - let mut visitor = PullTypesVisitor::new(db, file); - - let ast = parsed_module(db.upcast(), file).load(db.upcast()); - - visitor.visit_body(ast.suite()); -} - -struct PullTypesVisitor<'db> { - model: SemanticModel<'db>, -} - -impl<'db> PullTypesVisitor<'db> { - fn new(db: &'db dyn Db, file: File) -> Self { - Self { - model: SemanticModel::new(db, file), - } - } - - fn visit_target(&mut self, target: &Expr) { - match target { - Expr::List(ast::ExprList { elts, .. }) | Expr::Tuple(ast::ExprTuple { elts, .. }) => { - for element in elts { - self.visit_target(element); - } - } - _ => self.visit_expr(target), - } - } -} - -impl SourceOrderVisitor<'_> for PullTypesVisitor<'_> { - fn visit_stmt(&mut self, stmt: &Stmt) { - match stmt { - Stmt::FunctionDef(function) => { - let _ty = function.inferred_type(&self.model); - } - Stmt::ClassDef(class) => { - let _ty = class.inferred_type(&self.model); - } - Stmt::Assign(assign) => { - for target in &assign.targets { - self.visit_target(target); - } - self.visit_expr(&assign.value); - return; - } - Stmt::For(for_stmt) => { - self.visit_target(&for_stmt.target); - self.visit_expr(&for_stmt.iter); - self.visit_body(&for_stmt.body); - self.visit_body(&for_stmt.orelse); - return; - } - Stmt::With(with_stmt) => { - for item in &with_stmt.items { - if let Some(target) = &item.optional_vars { - self.visit_target(target); - } - self.visit_expr(&item.context_expr); - } - - self.visit_body(&with_stmt.body); - return; - } - Stmt::AnnAssign(_) - | Stmt::Return(_) - | Stmt::Delete(_) - | Stmt::AugAssign(_) - | Stmt::TypeAlias(_) - | Stmt::While(_) - | Stmt::If(_) - | Stmt::Match(_) - | Stmt::Raise(_) - | Stmt::Try(_) - | Stmt::Assert(_) - | Stmt::Import(_) - | Stmt::ImportFrom(_) - | Stmt::Global(_) - | Stmt::Nonlocal(_) - | Stmt::Expr(_) - | Stmt::Pass(_) - | Stmt::Break(_) - | Stmt::Continue(_) - | Stmt::IpyEscapeCommand(_) => {} - } - - source_order::walk_stmt(self, stmt); - } - - fn visit_expr(&mut self, expr: &Expr) { - let _ty = expr.inferred_type(&self.model); - - source_order::walk_expr(self, expr); - } - - fn visit_comprehension(&mut self, comprehension: &Comprehension) { - self.visit_expr(&comprehension.iter); - self.visit_target(&comprehension.target); - for if_expr in &comprehension.ifs { - self.visit_expr(if_expr); - } - } - - fn visit_parameter(&mut self, parameter: &Parameter) { - let _ty = parameter.inferred_type(&self.model); - - source_order::walk_parameter(self, parameter); - } - - fn visit_parameter_with_default(&mut self, parameter_with_default: &ParameterWithDefault) { - let _ty = parameter_with_default.inferred_type(&self.model); - - source_order::walk_parameter_with_default(self, parameter_with_default); - } - - fn visit_alias(&mut self, alias: &Alias) { - let _ty = alias.inferred_type(&self.model); - - source_order::walk_alias(self, alias); - } -} - /// Whether or not the .py/.pyi version of this file is expected to fail #[rustfmt::skip] const KNOWN_FAILURES: &[(&str, bool, bool)] = &[ diff --git a/crates/ty_test/Cargo.toml b/crates/ty_test/Cargo.toml index c80bf76011..f3d698f21f 100644 --- a/crates/ty_test/Cargo.toml +++ b/crates/ty_test/Cargo.toml @@ -18,10 +18,11 @@ ruff_python_trivia = { workspace = true } ruff_source_file = { workspace = true } ruff_text_size = { workspace = true } ruff_python_ast = { workspace = true } -ty_python_semantic = { workspace = true, features = ["serde"] } +ty_python_semantic = { workspace = true, features = ["serde", "testing"] } ty_vendored = { workspace = true } anyhow = { workspace = true } +bitflags = { workspace = true } camino = { workspace = true } colored = { workspace = true } insta = { workspace = true, features = ["filters"] } diff --git a/crates/ty_test/src/lib.rs b/crates/ty_test/src/lib.rs index 4749f477bd..f9e779fd0e 100644 --- a/crates/ty_test/src/lib.rs +++ b/crates/ty_test/src/lib.rs @@ -1,4 +1,5 @@ use crate::config::Log; +use crate::db::Db; use crate::parser::{BacktickOffsets, EmbeddedFileSourceMap}; use camino::Utf8Path; use colored::Colorize; @@ -17,6 +18,7 @@ use ruff_db::testing::{setup_logging, setup_logging_with_filter}; use ruff_source_file::{LineIndex, OneIndexed}; use std::backtrace::BacktraceStatus; use std::fmt::Write; +use ty_python_semantic::pull_types::pull_types; use ty_python_semantic::types::check_types; use ty_python_semantic::{ Program, ProgramSettings, PythonPath, PythonPlatform, PythonVersionSource, @@ -291,9 +293,31 @@ fn run_test( // all diagnostics. Otherwise it remains empty. let mut snapshot_diagnostics = vec![]; - let failures: Failures = test_files - .into_iter() + let mut any_pull_types_failures = false; + + let mut failures: Failures = test_files + .iter() .filter_map(|test_file| { + let pull_types_result = attempt_test( + db, + pull_types, + test_file, + "\"pull types\"", + Some( + "Note: either fix the panic or add the `` \ + directive to this test", + ), + ); + match pull_types_result { + Ok(()) => {} + Err(failures) => { + any_pull_types_failures = true; + if !test.should_skip_pulling_types() { + return Some(failures); + } + } + } + let parsed = parsed_module(db, test_file.file).load(db); let mut diagnostics: Vec = parsed @@ -309,64 +333,50 @@ fn run_test( .map(|error| create_unsupported_syntax_diagnostic(test_file.file, error)), ); - let type_diagnostics = match catch_unwind(|| check_types(db, test_file.file)) { - Ok(type_diagnostics) => type_diagnostics, - Err(info) => { - let mut by_line = matcher::FailuresByLine::default(); - let mut messages = vec![]; - match info.location { - Some(location) => messages.push(format!("panicked at {location}")), - None => messages.push("panicked at unknown location".to_string()), - } - match info.payload.as_str() { - Some(message) => messages.push(message.to_string()), - // Mimic the default panic hook's rendering of the panic payload if it's - // not a string. - None => messages.push("Box".to_string()), - } - if let Some(backtrace) = info.backtrace { - match backtrace.status() { - BacktraceStatus::Disabled => { - let msg = "run with `RUST_BACKTRACE=1` environment variable to display a backtrace"; - messages.push(msg.to_string()); - } - BacktraceStatus::Captured => { - messages.extend(backtrace.to_string().split('\n').map(String::from)); - } - _ => {} - } - } - - if let Some(backtrace) = info.salsa_backtrace { - salsa::attach(db, || { - messages.extend(format!("{backtrace:#}").split('\n').map(String::from)); - }); - } - - by_line.push(OneIndexed::from_zero_indexed(0), messages); - return Some(FileFailures { - backtick_offsets: test_file.backtick_offsets, - by_line, - }); - } + let mdtest_result = attempt_test(db, check_types, test_file, "run mdtest", None); + let type_diagnostics = match mdtest_result { + Ok(diagnostics) => diagnostics, + Err(failures) => return Some(failures), }; + diagnostics.extend(type_diagnostics.into_iter().cloned()); - diagnostics.sort_by(|left, right|left.rendering_sort_key(db).cmp(&right.rendering_sort_key(db))); + diagnostics.sort_by(|left, right| { + left.rendering_sort_key(db) + .cmp(&right.rendering_sort_key(db)) + }); let failure = match matcher::match_file(db, test_file.file, &diagnostics) { Ok(()) => None, Err(line_failures) => Some(FileFailures { - backtick_offsets: test_file.backtick_offsets, + backtick_offsets: test_file.backtick_offsets.clone(), by_line: line_failures, }), }; if test.should_snapshot_diagnostics() { snapshot_diagnostics.extend(diagnostics); } + failure }) .collect(); + if test.should_skip_pulling_types() && !any_pull_types_failures { + let mut by_line = matcher::FailuresByLine::default(); + by_line.push( + OneIndexed::from_zero_indexed(0), + vec![ + "Remove the `` directive from this test: pulling types \ + succeeded for all files in the test." + .to_string(), + ], + ); + let failure = FileFailures { + backtick_offsets: test_files[0].backtick_offsets.clone(), + by_line, + }; + failures.push(failure); + } + if snapshot_diagnostics.is_empty() && test.should_snapshot_diagnostics() { panic!( "Test `{}` requested snapshotting diagnostics but it didn't produce any.", @@ -462,3 +472,71 @@ fn create_diagnostic_snapshot( } snapshot } + +/// Run a function over an embedded test file, catching any panics that occur in the process. +/// +/// If no panic occurs, the result of the function is returned as an `Ok()` variant. +/// +/// If a panic occurs, a nicely formatted [`FileFailures`] is returned as an `Err()` variant. +/// This will be formatted into a diagnostic message by `ty_test`. +fn attempt_test<'db, T, F>( + db: &'db Db, + test_fn: F, + test_file: &TestFile, + action: &str, + clarification: Option<&str>, +) -> Result +where + F: FnOnce(&'db dyn ty_python_semantic::Db, File) -> T + std::panic::UnwindSafe, +{ + catch_unwind(|| test_fn(db, test_file.file)).map_err(|info| { + let mut by_line = matcher::FailuresByLine::default(); + let mut messages = vec![]; + match info.location { + Some(location) => messages.push(format!( + "Attempting to {action} caused a panic at {location}" + )), + None => messages.push(format!( + "Attempting to {action} caused a panic at an unknown location", + )), + } + if let Some(clarification) = clarification { + messages.push(clarification.to_string()); + } + messages.push(String::new()); + match info.payload.as_str() { + Some(message) => messages.push(message.to_string()), + // Mimic the default panic hook's rendering of the panic payload if it's + // not a string. + None => messages.push("Box".to_string()), + } + messages.push(String::new()); + + if let Some(backtrace) = info.backtrace { + match backtrace.status() { + BacktraceStatus::Disabled => { + let msg = + "run with `RUST_BACKTRACE=1` environment variable to display a backtrace"; + messages.push(msg.to_string()); + } + BacktraceStatus::Captured => { + messages.extend(backtrace.to_string().split('\n').map(String::from)); + } + _ => {} + } + } + + if let Some(backtrace) = info.salsa_backtrace { + salsa::attach(db, || { + messages.extend(format!("{backtrace:#}").split('\n').map(String::from)); + }); + } + + by_line.push(OneIndexed::from_zero_indexed(0), messages); + + FileFailures { + backtick_offsets: test_file.backtick_offsets.clone(), + by_line, + } + }) +} diff --git a/crates/ty_test/src/parser.rs b/crates/ty_test/src/parser.rs index b1cc448beb..68e9ab5198 100644 --- a/crates/ty_test/src/parser.rs +++ b/crates/ty_test/src/parser.rs @@ -143,7 +143,15 @@ impl<'m, 's> MarkdownTest<'m, 's> { } pub(super) fn should_snapshot_diagnostics(&self) -> bool { - self.section.snapshot_diagnostics + self.section + .directives + .contains(MdtestDirectives::SNAPSHOT_DIAGNOSTICS) + } + + pub(super) fn should_skip_pulling_types(&self) -> bool { + self.section + .directives + .contains(MdtestDirectives::PULL_TYPES_SKIP) } } @@ -194,7 +202,7 @@ struct Section<'s> { level: u8, parent_id: Option, config: MarkdownTestConfig, - snapshot_diagnostics: bool, + directives: MdtestDirectives, } #[newtype_index] @@ -428,7 +436,7 @@ impl<'s> Parser<'s> { level: 0, parent_id: None, config: MarkdownTestConfig::default(), - snapshot_diagnostics: false, + directives: MdtestDirectives::default(), }); Self { sections, @@ -486,6 +494,7 @@ impl<'s> Parser<'s> { fn parse_impl(&mut self) -> anyhow::Result<()> { const SECTION_CONFIG_SNAPSHOT: &str = "snapshot-diagnostics"; + const SECTION_CONFIG_PULLTYPES: &str = "pull-types:skip"; const HTML_COMMENT_ALLOWLIST: &[&str] = &["blacken-docs:on", "blacken-docs:off"]; const CODE_BLOCK_END: &[u8] = b"```"; const HTML_COMMENT_END: &[u8] = b"-->"; @@ -498,10 +507,12 @@ impl<'s> Parser<'s> { { let html_comment = self.cursor.as_str()[..position].trim(); if html_comment == SECTION_CONFIG_SNAPSHOT { - self.process_snapshot_diagnostics()?; + self.process_mdtest_directive(MdtestDirective::SnapshotDiagnostics)?; + } else if html_comment == SECTION_CONFIG_PULLTYPES { + self.process_mdtest_directive(MdtestDirective::PullTypesSkip)?; } else if !HTML_COMMENT_ALLOWLIST.contains(&html_comment) { bail!( - "Unknown HTML comment `{}` -- possibly a `snapshot-diagnostics` typo? \ + "Unknown HTML comment `{}` -- possibly a typo? \ (Add to `HTML_COMMENT_ALLOWLIST` if this is a false positive)", html_comment ); @@ -636,7 +647,7 @@ impl<'s> Parser<'s> { level: header_level.try_into()?, parent_id: Some(parent), config: self.sections[parent].config.clone(), - snapshot_diagnostics: self.sections[parent].snapshot_diagnostics, + directives: self.sections[parent].directives, }; if !self.current_section_files.is_empty() { @@ -784,28 +795,28 @@ impl<'s> Parser<'s> { Ok(()) } - fn process_snapshot_diagnostics(&mut self) -> anyhow::Result<()> { + fn process_mdtest_directive(&mut self, directive: MdtestDirective) -> anyhow::Result<()> { if self.current_section_has_config { bail!( - "Section config to enable snapshotting diagnostics must come before \ + "Section config to enable {directive} must come before \ everything else (including TOML configuration blocks).", ); } if !self.current_section_files.is_empty() { bail!( - "Section config to enable snapshotting diagnostics must come before \ + "Section config to enable {directive} must come before \ everything else (including embedded files).", ); } let current_section = &mut self.sections[self.stack.top()]; - if current_section.snapshot_diagnostics { + if current_section.directives.has_directive_set(directive) { bail!( - "Section config to enable snapshotting diagnostics should appear \ + "Section config to enable {directive} should appear \ at most once.", ); } - current_section.snapshot_diagnostics = true; + current_section.directives.add_directive(directive); Ok(()) } @@ -824,6 +835,56 @@ impl<'s> Parser<'s> { } } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum MdtestDirective { + /// A directive to enable snapshotting diagnostics. + SnapshotDiagnostics, + /// A directive to skip pull types. + PullTypesSkip, +} + +impl std::fmt::Display for MdtestDirective { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + MdtestDirective::SnapshotDiagnostics => f.write_str("snapshotting diagnostics"), + MdtestDirective::PullTypesSkip => f.write_str("skipping the pull-types visitor"), + } + } +} + +bitflags::bitflags! { + /// Directives that can be applied to a Markdown test section. + #[derive(Default, Debug, Clone, Copy, PartialEq, Eq)] + pub(crate) struct MdtestDirectives: u8 { + /// We should snapshot diagnostics for this section. + const SNAPSHOT_DIAGNOSTICS = 1 << 0; + /// We should skip pulling types for this section. + const PULL_TYPES_SKIP = 1 << 1; + } +} + +impl MdtestDirectives { + const fn has_directive_set(self, directive: MdtestDirective) -> bool { + match directive { + MdtestDirective::SnapshotDiagnostics => { + self.contains(MdtestDirectives::SNAPSHOT_DIAGNOSTICS) + } + MdtestDirective::PullTypesSkip => self.contains(MdtestDirectives::PULL_TYPES_SKIP), + } + } + + fn add_directive(&mut self, directive: MdtestDirective) { + match directive { + MdtestDirective::SnapshotDiagnostics => { + self.insert(MdtestDirectives::SNAPSHOT_DIAGNOSTICS); + } + MdtestDirective::PullTypesSkip => { + self.insert(MdtestDirectives::PULL_TYPES_SKIP); + } + } + } +} + #[cfg(test)] mod tests { use ruff_python_ast::PySourceType; @@ -1906,7 +1967,7 @@ mod tests { let err = super::parse("file.md", &source).expect_err("Should fail to parse"); assert_eq!( err.to_string(), - "Unknown HTML comment `snpshotttt-digggggnosstic` -- possibly a `snapshot-diagnostics` typo? \ + "Unknown HTML comment `snpshotttt-digggggnosstic` -- possibly a typo? \ (Add to `HTML_COMMENT_ALLOWLIST` if this is a false positive)", ); }