[ty] Pull types on synthesized Python files created by mdtest (#18539)

Alex Waygood, 2025-06-12 10:32:17 +01:00 (committed by GitHub)
parent e6fe2af292
commit 324e5cbc19
14 changed files with 361 additions and 191 deletions
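
In brief, the change teaches mdtest to run the new `pull_types` visitor (exposed from `ty_python_semantic` behind a `testing` feature) over every synthesized Python file, alongside the existing `check_types` pass. A panic while pulling types fails the test unless the test opts out with a `<!-- pull-types:skip -->` directive, and a directive that is no longer needed is itself reported. Below is a minimal sketch of that contract, not the actual `ty_test` harness: the helper name `check_pull_types_for_file` and the per-file `skip_requested` flag are simplifications, since the real code uses the `attempt_test` helper shown later in this diff and aggregates results across all files in a test.

```rust
use std::panic::{AssertUnwindSafe, catch_unwind};

use ruff_db::files::File;
use ty_python_semantic::pull_types::pull_types; // requires the "testing" feature

// Hypothetical helper illustrating the new mdtest contract (not part of the commit).
fn check_pull_types_for_file(
    db: &dyn ty_python_semantic::Db,
    file: File,
    skip_requested: bool, // true if the test carries `<!-- pull-types:skip -->`
) -> Result<(), String> {
    // `pull_types` walks the whole AST and asks for the inferred type of every
    // node; a panic here means inference failed to store a type for some node.
    // `AssertUnwindSafe` is needed because the closure captures `db` by reference.
    let pulled = catch_unwind(AssertUnwindSafe(|| pull_types(db, file)));

    match (pulled.is_ok(), skip_requested) {
        // Pulling types panicked and the test did not opt out: fail the test.
        (false, false) => Err("pulling types panicked; either fix the panic or add the \
             `<!-- pull-types:skip -->` directive to this test"
            .to_string()),
        // The test opted out, but pulling types now succeeds: the directive is stale.
        (true, true) => Err("remove the `<!-- pull-types:skip -->` directive: \
             pulling types succeeded"
            .to_string()),
        _ => Ok(()),
    }
}
```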

Cargo.lock (generated)

@@ -4004,6 +4004,7 @@ dependencies = [
  "test-case",
  "thiserror 2.0.12",
  "tracing",
+ "ty_python_semantic",
  "ty_test",
  "ty_vendored",
 ]
@@ -4039,6 +4040,7 @@ name = "ty_test"
 version = "0.0.0"
 dependencies = [
  "anyhow",
+ "bitflags 2.9.1",
  "camino",
  "colored 3.0.0",
  "insta",


@@ -50,6 +50,7 @@ strum_macros = { workspace = true }

 [dev-dependencies]
 ruff_db = { workspace = true, features = ["testing", "os"] }
 ruff_python_parser = { workspace = true }
+ty_python_semantic = { workspace = true, features = ["testing"] }
 ty_test = { workspace = true }
 ty_vendored = { workspace = true }
@@ -63,6 +64,7 @@ quickcheck_macros = { version = "1.0.0" }

 [features]
 serde = ["ruff_db/serde", "dep:serde", "ruff_python_ast/serde"]
+testing = []

 [lints]
 workspace = true


@@ -139,6 +139,8 @@ x: int = MagicMock()
 ## Invalid

+<!-- pull-types:skip -->
+
 `Any` cannot be parameterized:

 ```py


@@ -58,6 +58,8 @@ def _(c: Callable[[int, 42, str, False], None]):
 ### Missing return type

+<!-- pull-types:skip -->
+
 Using a parameter list:

 ```py


@@ -14,6 +14,8 @@ directly.
 ### Negation

+<!-- pull-types:skip -->
+
 ```py
 from typing import Literal
 from ty_extensions import Not, static_assert
@@ -371,6 +373,8 @@ static_assert(not is_single_valued(Literal["a"] | Literal["b"]))
 ## `TypeOf`

+<!-- pull-types:skip -->
+
 We use `TypeOf` to get the inferred type of an expression. This is useful when we want to refer to
 it in a type expression. For example, if we want to make sure that the class literal type `str` is a
 subtype of `type[str]`, we can not use `is_subtype_of(str, type[str])`, as that would test if the
@@ -412,6 +416,8 @@ def f(x: TypeOf) -> None:
 ## `CallableTypeOf`

+<!-- pull-types:skip -->
+
 The `CallableTypeOf` special form can be used to extract the `Callable` structural type inhabited by
 a given callable object. This can be used to get the externally visibly signature of the object,
 which can then be used to test various type properties.


@@ -84,6 +84,8 @@ d.a = 2
 ## Too many arguments

+<!-- pull-types:skip -->
+
 ```py
 from typing import ClassVar


@@ -45,6 +45,8 @@ reveal_type(FINAL_E) # revealed: int
 ## Too many arguments

+<!-- pull-types:skip -->
+
 ```py
 from typing import Final


@@ -35,6 +35,9 @@ pub mod types;
 mod unpack;
 mod util;
+
+#[cfg(feature = "testing")]
+pub mod pull_types;

 type FxOrderSet<V> = ordermap::set::OrderSet<V, BuildHasherDefault<FxHasher>>;

 /// Returns the default registry with all known semantic lints.


@@ -0,0 +1,134 @@
//! A utility visitor for testing, which attempts to "pull a type" for every sub-node in a given AST.
//!
//! This is used in the "corpus" and (indirectly) the "mdtest" integration tests for this crate.
//! (Mdtest uses the `pull_types` function via the `ty_test` crate.)
use crate::{Db, HasType, SemanticModel};
use ruff_db::{files::File, parsed::parsed_module};
use ruff_python_ast::{
    self as ast, visitor::source_order, visitor::source_order::SourceOrderVisitor,
};

pub fn pull_types(db: &dyn Db, file: File) {
    let mut visitor = PullTypesVisitor::new(db, file);

    let ast = parsed_module(db.upcast(), file).load(db.upcast());

    visitor.visit_body(ast.suite());
}

struct PullTypesVisitor<'db> {
    model: SemanticModel<'db>,
}

impl<'db> PullTypesVisitor<'db> {
    fn new(db: &'db dyn Db, file: File) -> Self {
        Self {
            model: SemanticModel::new(db, file),
        }
    }

    fn visit_target(&mut self, target: &ast::Expr) {
        match target {
            ast::Expr::List(ast::ExprList { elts, .. })
            | ast::Expr::Tuple(ast::ExprTuple { elts, .. }) => {
                for element in elts {
                    self.visit_target(element);
                }
            }
            _ => self.visit_expr(target),
        }
    }
}

impl SourceOrderVisitor<'_> for PullTypesVisitor<'_> {
    fn visit_stmt(&mut self, stmt: &ast::Stmt) {
        match stmt {
            ast::Stmt::FunctionDef(function) => {
                let _ty = function.inferred_type(&self.model);
            }
            ast::Stmt::ClassDef(class) => {
                let _ty = class.inferred_type(&self.model);
            }
            ast::Stmt::Assign(assign) => {
                for target in &assign.targets {
                    self.visit_target(target);
                }
                self.visit_expr(&assign.value);
                return;
            }
            ast::Stmt::For(for_stmt) => {
                self.visit_target(&for_stmt.target);
                self.visit_expr(&for_stmt.iter);
                self.visit_body(&for_stmt.body);
                self.visit_body(&for_stmt.orelse);
                return;
            }
            ast::Stmt::With(with_stmt) => {
                for item in &with_stmt.items {
                    if let Some(target) = &item.optional_vars {
                        self.visit_target(target);
                    }
                    self.visit_expr(&item.context_expr);
                }
                self.visit_body(&with_stmt.body);
                return;
            }
            ast::Stmt::AnnAssign(_)
            | ast::Stmt::Return(_)
            | ast::Stmt::Delete(_)
            | ast::Stmt::AugAssign(_)
            | ast::Stmt::TypeAlias(_)
            | ast::Stmt::While(_)
            | ast::Stmt::If(_)
            | ast::Stmt::Match(_)
            | ast::Stmt::Raise(_)
            | ast::Stmt::Try(_)
            | ast::Stmt::Assert(_)
            | ast::Stmt::Import(_)
            | ast::Stmt::ImportFrom(_)
            | ast::Stmt::Global(_)
            | ast::Stmt::Nonlocal(_)
            | ast::Stmt::Expr(_)
            | ast::Stmt::Pass(_)
            | ast::Stmt::Break(_)
            | ast::Stmt::Continue(_)
            | ast::Stmt::IpyEscapeCommand(_) => {}
        }

        source_order::walk_stmt(self, stmt);
    }

    fn visit_expr(&mut self, expr: &ast::Expr) {
        let _ty = expr.inferred_type(&self.model);

        source_order::walk_expr(self, expr);
    }

    fn visit_comprehension(&mut self, comprehension: &ast::Comprehension) {
        self.visit_expr(&comprehension.iter);
        self.visit_target(&comprehension.target);
        for if_expr in &comprehension.ifs {
            self.visit_expr(if_expr);
        }
    }

    fn visit_parameter(&mut self, parameter: &ast::Parameter) {
        let _ty = parameter.inferred_type(&self.model);

        source_order::walk_parameter(self, parameter);
    }

    fn visit_parameter_with_default(&mut self, parameter_with_default: &ast::ParameterWithDefault) {
        let _ty = parameter_with_default.inferred_type(&self.model);

        source_order::walk_parameter_with_default(self, parameter_with_default);
    }

    fn visit_alias(&mut self, alias: &ast::Alias) {
        let _ty = alias.inferred_type(&self.model);

        source_order::walk_alias(self, alias);
    }
}
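
For context, the visitor above leans entirely on the public `SemanticModel`/`HasType` API. Here is a standalone sketch of pulling a single type under the same assumptions as `pull_types` (a `db` implementing `ty_python_semantic::Db` and a `file` registered with it); the function name `pull_one_type` is illustrative and not part of the commit.

```rust
use ruff_db::{files::File, parsed::parsed_module};
use ruff_python_ast as ast;
use ty_python_semantic::{Db, HasType, SemanticModel};

// Illustrative only; not part of the commit.
fn pull_one_type(db: &dyn Db, file: File) {
    let model = SemanticModel::new(db, file);
    let parsed = parsed_module(db.upcast(), file).load(db.upcast());

    // Pull the type of the first top-level expression statement, if any.
    if let Some(ast::Stmt::Expr(expr_stmt)) = parsed.suite().first() {
        // This is the call that panics if the `TypeInferenceBuilder` never
        // stored a type for the node -- exactly the failure mode the new
        // mdtest pass (and the improved panic message below) is about.
        let _ty = expr_stmt.value.inferred_type(&model);
    }
}
```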


@@ -471,8 +471,10 @@ impl<'db> TypeInference<'db> {
     #[track_caller]
     pub(crate) fn expression_type(&self, expression: ScopedExpressionId) -> Type<'db> {
         self.try_expression_type(expression).expect(
-            "expression should belong to this TypeInference region and \
-            TypeInferenceBuilder should have inferred a type for it",
+            "Failed to retrieve the inferred type for an `ast::Expr` node \
+            passed to `TypeInference::expression_type()`. The `TypeInferenceBuilder` \
+            should infer and store types for all `ast::Expr` nodes in any `TypeInference` \
+            region it analyzes.",
         )
     }


@@ -1,19 +1,15 @@
 use anyhow::{Context, anyhow};
 use ruff_db::Upcast;
 use ruff_db::files::{File, Files, system_path_to_file};
-use ruff_db::parsed::parsed_module;
 use ruff_db::system::{DbWithTestSystem, System, SystemPath, SystemPathBuf, TestSystem};
 use ruff_db::vendored::VendoredFileSystem;
-use ruff_python_ast::visitor::source_order;
-use ruff_python_ast::visitor::source_order::SourceOrderVisitor;
-use ruff_python_ast::{
-    self as ast, Alias, Comprehension, Expr, Parameter, ParameterWithDefault, PythonVersion, Stmt,
-};
+use ruff_python_ast::PythonVersion;
 use ty_python_semantic::lint::{LintRegistry, RuleSelection};
+use ty_python_semantic::pull_types::pull_types;
 use ty_python_semantic::{
-    Db, HasType, Program, ProgramSettings, PythonPlatform, PythonVersionSource,
-    PythonVersionWithSource, SearchPathSettings, SemanticModel, default_lint_registry,
+    Program, ProgramSettings, PythonPlatform, PythonVersionSource, PythonVersionWithSource,
+    SearchPathSettings, default_lint_registry,
 };

 fn get_cargo_workspace_root() -> anyhow::Result<SystemPathBuf> {
@@ -174,129 +170,6 @@ fn run_corpus_tests(pattern: &str) -> anyhow::Result<()> {
     Ok(())
 }

-fn pull_types(db: &dyn Db, file: File) {
-    let mut visitor = PullTypesVisitor::new(db, file);
-
-    let ast = parsed_module(db.upcast(), file).load(db.upcast());
-
-    visitor.visit_body(ast.suite());
-}
-
-struct PullTypesVisitor<'db> {
-    model: SemanticModel<'db>,
-}
-
-impl<'db> PullTypesVisitor<'db> {
-    fn new(db: &'db dyn Db, file: File) -> Self {
-        Self {
-            model: SemanticModel::new(db, file),
-        }
-    }
-
-    fn visit_target(&mut self, target: &Expr) {
-        match target {
-            Expr::List(ast::ExprList { elts, .. }) | Expr::Tuple(ast::ExprTuple { elts, .. }) => {
-                for element in elts {
-                    self.visit_target(element);
-                }
-            }
-            _ => self.visit_expr(target),
-        }
-    }
-}
-
-impl SourceOrderVisitor<'_> for PullTypesVisitor<'_> {
-    fn visit_stmt(&mut self, stmt: &Stmt) {
-        match stmt {
-            Stmt::FunctionDef(function) => {
-                let _ty = function.inferred_type(&self.model);
-            }
-            Stmt::ClassDef(class) => {
-                let _ty = class.inferred_type(&self.model);
-            }
-            Stmt::Assign(assign) => {
-                for target in &assign.targets {
-                    self.visit_target(target);
-                }
-                self.visit_expr(&assign.value);
-                return;
-            }
-            Stmt::For(for_stmt) => {
-                self.visit_target(&for_stmt.target);
-                self.visit_expr(&for_stmt.iter);
-                self.visit_body(&for_stmt.body);
-                self.visit_body(&for_stmt.orelse);
-                return;
-            }
-            Stmt::With(with_stmt) => {
-                for item in &with_stmt.items {
-                    if let Some(target) = &item.optional_vars {
-                        self.visit_target(target);
-                    }
-                    self.visit_expr(&item.context_expr);
-                }
-                self.visit_body(&with_stmt.body);
-                return;
-            }
-            Stmt::AnnAssign(_)
-            | Stmt::Return(_)
-            | Stmt::Delete(_)
-            | Stmt::AugAssign(_)
-            | Stmt::TypeAlias(_)
-            | Stmt::While(_)
-            | Stmt::If(_)
-            | Stmt::Match(_)
-            | Stmt::Raise(_)
-            | Stmt::Try(_)
-            | Stmt::Assert(_)
-            | Stmt::Import(_)
-            | Stmt::ImportFrom(_)
-            | Stmt::Global(_)
-            | Stmt::Nonlocal(_)
-            | Stmt::Expr(_)
-            | Stmt::Pass(_)
-            | Stmt::Break(_)
-            | Stmt::Continue(_)
-            | Stmt::IpyEscapeCommand(_) => {}
-        }
-
-        source_order::walk_stmt(self, stmt);
-    }
-
-    fn visit_expr(&mut self, expr: &Expr) {
-        let _ty = expr.inferred_type(&self.model);
-
-        source_order::walk_expr(self, expr);
-    }
-
-    fn visit_comprehension(&mut self, comprehension: &Comprehension) {
-        self.visit_expr(&comprehension.iter);
-        self.visit_target(&comprehension.target);
-        for if_expr in &comprehension.ifs {
-            self.visit_expr(if_expr);
-        }
-    }
-
-    fn visit_parameter(&mut self, parameter: &Parameter) {
-        let _ty = parameter.inferred_type(&self.model);
-
-        source_order::walk_parameter(self, parameter);
-    }
-
-    fn visit_parameter_with_default(&mut self, parameter_with_default: &ParameterWithDefault) {
-        let _ty = parameter_with_default.inferred_type(&self.model);
-
-        source_order::walk_parameter_with_default(self, parameter_with_default);
-    }
-
-    fn visit_alias(&mut self, alias: &Alias) {
-        let _ty = alias.inferred_type(&self.model);
-
-        source_order::walk_alias(self, alias);
-    }
-}
-
 /// Whether or not the .py/.pyi version of this file is expected to fail
 #[rustfmt::skip]
 const KNOWN_FAILURES: &[(&str, bool, bool)] = &[


@@ -18,10 +18,11 @@ ruff_python_trivia = { workspace = true }
 ruff_source_file = { workspace = true }
 ruff_text_size = { workspace = true }
 ruff_python_ast = { workspace = true }
-ty_python_semantic = { workspace = true, features = ["serde"] }
+ty_python_semantic = { workspace = true, features = ["serde", "testing"] }
 ty_vendored = { workspace = true }

 anyhow = { workspace = true }
+bitflags = { workspace = true }
 camino = { workspace = true }
 colored = { workspace = true }
 insta = { workspace = true, features = ["filters"] }


@@ -1,4 +1,5 @@
 use crate::config::Log;
+use crate::db::Db;
 use crate::parser::{BacktickOffsets, EmbeddedFileSourceMap};
 use camino::Utf8Path;
 use colored::Colorize;
@@ -17,6 +18,7 @@ use ruff_db::testing::{setup_logging, setup_logging_with_filter};
 use ruff_source_file::{LineIndex, OneIndexed};
 use std::backtrace::BacktraceStatus;
 use std::fmt::Write;
+use ty_python_semantic::pull_types::pull_types;
 use ty_python_semantic::types::check_types;
 use ty_python_semantic::{
     Program, ProgramSettings, PythonPath, PythonPlatform, PythonVersionSource,
@@ -291,9 +293,31 @@ fn run_test(
     // all diagnostics. Otherwise it remains empty.
     let mut snapshot_diagnostics = vec![];

-    let failures: Failures = test_files
-        .into_iter()
+    let mut any_pull_types_failures = false;
+
+    let mut failures: Failures = test_files
+        .iter()
         .filter_map(|test_file| {
+            let pull_types_result = attempt_test(
+                db,
+                pull_types,
+                test_file,
+                "\"pull types\"",
+                Some(
+                    "Note: either fix the panic or add the `<!-- pull-types:skip -->` \
+                    directive to this test",
+                ),
+            );
+
+            match pull_types_result {
+                Ok(()) => {}
+                Err(failures) => {
+                    any_pull_types_failures = true;
+                    if !test.should_skip_pulling_types() {
+                        return Some(failures);
+                    }
+                }
+            }
+
             let parsed = parsed_module(db, test_file.file).load(db);

             let mut diagnostics: Vec<Diagnostic> = parsed
@@ -309,64 +333,50 @@
                     .map(|error| create_unsupported_syntax_diagnostic(test_file.file, error)),
             );

-            let type_diagnostics = match catch_unwind(|| check_types(db, test_file.file)) {
-                Ok(type_diagnostics) => type_diagnostics,
-                Err(info) => {
-                    let mut by_line = matcher::FailuresByLine::default();
-                    let mut messages = vec![];
-                    match info.location {
-                        Some(location) => messages.push(format!("panicked at {location}")),
-                        None => messages.push("panicked at unknown location".to_string()),
-                    }
-                    match info.payload.as_str() {
-                        Some(message) => messages.push(message.to_string()),
-                        // Mimic the default panic hook's rendering of the panic payload if it's
-                        // not a string.
-                        None => messages.push("Box<dyn Any>".to_string()),
-                    }
-                    if let Some(backtrace) = info.backtrace {
-                        match backtrace.status() {
-                            BacktraceStatus::Disabled => {
-                                let msg = "run with `RUST_BACKTRACE=1` environment variable to display a backtrace";
-                                messages.push(msg.to_string());
-                            }
-                            BacktraceStatus::Captured => {
-                                messages.extend(backtrace.to_string().split('\n').map(String::from));
-                            }
-                            _ => {}
-                        }
-                    }
-                    if let Some(backtrace) = info.salsa_backtrace {
-                        salsa::attach(db, || {
-                            messages.extend(format!("{backtrace:#}").split('\n').map(String::from));
-                        });
-                    }
-                    by_line.push(OneIndexed::from_zero_indexed(0), messages);
-                    return Some(FileFailures {
-                        backtick_offsets: test_file.backtick_offsets,
-                        by_line,
-                    });
-                }
+            let mdtest_result = attempt_test(db, check_types, test_file, "run mdtest", None);
+
+            let type_diagnostics = match mdtest_result {
+                Ok(diagnostics) => diagnostics,
+                Err(failures) => return Some(failures),
             };

             diagnostics.extend(type_diagnostics.into_iter().cloned());

-            diagnostics.sort_by(|left, right| left.rendering_sort_key(db).cmp(&right.rendering_sort_key(db)));
+            diagnostics.sort_by(|left, right| {
+                left.rendering_sort_key(db)
+                    .cmp(&right.rendering_sort_key(db))
+            });

             let failure = match matcher::match_file(db, test_file.file, &diagnostics) {
                 Ok(()) => None,
                 Err(line_failures) => Some(FileFailures {
-                    backtick_offsets: test_file.backtick_offsets,
+                    backtick_offsets: test_file.backtick_offsets.clone(),
                     by_line: line_failures,
                 }),
             };

             if test.should_snapshot_diagnostics() {
                 snapshot_diagnostics.extend(diagnostics);
             }

             failure
         })
         .collect();

+    if test.should_skip_pulling_types() && !any_pull_types_failures {
+        let mut by_line = matcher::FailuresByLine::default();
+        by_line.push(
+            OneIndexed::from_zero_indexed(0),
+            vec![
+                "Remove the `<!-- pull-types:skip -->` directive from this test: pulling types \
+                succeeded for all files in the test."
+                    .to_string(),
+            ],
+        );
+        let failure = FileFailures {
+            backtick_offsets: test_files[0].backtick_offsets.clone(),
+            by_line,
+        };
+        failures.push(failure);
+    }
+
     if snapshot_diagnostics.is_empty() && test.should_snapshot_diagnostics() {
         panic!(
             "Test `{}` requested snapshotting diagnostics but it didn't produce any.",
@@ -462,3 +472,71 @@ fn create_diagnostic_snapshot(
     }
     snapshot
 }
+
+/// Run a function over an embedded test file, catching any panics that occur in the process.
+///
+/// If no panic occurs, the result of the function is returned as an `Ok()` variant.
+///
+/// If a panic occurs, a nicely formatted [`FileFailures`] is returned as an `Err()` variant.
+/// This will be formatted into a diagnostic message by `ty_test`.
+fn attempt_test<'db, T, F>(
+    db: &'db Db,
+    test_fn: F,
+    test_file: &TestFile,
+    action: &str,
+    clarification: Option<&str>,
+) -> Result<T, FileFailures>
+where
+    F: FnOnce(&'db dyn ty_python_semantic::Db, File) -> T + std::panic::UnwindSafe,
+{
+    catch_unwind(|| test_fn(db, test_file.file)).map_err(|info| {
+        let mut by_line = matcher::FailuresByLine::default();
+        let mut messages = vec![];
+
+        match info.location {
+            Some(location) => messages.push(format!(
+                "Attempting to {action} caused a panic at {location}"
+            )),
+            None => messages.push(format!(
+                "Attempting to {action} caused a panic at an unknown location",
+            )),
+        }
+
+        if let Some(clarification) = clarification {
+            messages.push(clarification.to_string());
+        }
+
+        messages.push(String::new());
+
+        match info.payload.as_str() {
+            Some(message) => messages.push(message.to_string()),
+            // Mimic the default panic hook's rendering of the panic payload if it's
+            // not a string.
+            None => messages.push("Box<dyn Any>".to_string()),
+        }
+
+        messages.push(String::new());
+
+        if let Some(backtrace) = info.backtrace {
+            match backtrace.status() {
+                BacktraceStatus::Disabled => {
+                    let msg =
+                        "run with `RUST_BACKTRACE=1` environment variable to display a backtrace";
+                    messages.push(msg.to_string());
+                }
+                BacktraceStatus::Captured => {
+                    messages.extend(backtrace.to_string().split('\n').map(String::from));
+                }
+                _ => {}
+            }
+        }
+
+        if let Some(backtrace) = info.salsa_backtrace {
+            salsa::attach(db, || {
+                messages.extend(format!("{backtrace:#}").split('\n').map(String::from));
+            });
+        }
+
+        by_line.push(OneIndexed::from_zero_indexed(0), messages);
+
+        FileFailures {
+            backtick_offsets: test_file.backtick_offsets.clone(),
+            by_line,
+        }
+    })
+}


@@ -143,7 +143,15 @@ impl<'m, 's> MarkdownTest<'m, 's> {
     }

     pub(super) fn should_snapshot_diagnostics(&self) -> bool {
-        self.section.snapshot_diagnostics
+        self.section
+            .directives
+            .contains(MdtestDirectives::SNAPSHOT_DIAGNOSTICS)
+    }
+
+    pub(super) fn should_skip_pulling_types(&self) -> bool {
+        self.section
+            .directives
+            .contains(MdtestDirectives::PULL_TYPES_SKIP)
     }
 }
@@ -194,7 +202,7 @@ struct Section<'s> {
     level: u8,
     parent_id: Option<SectionId>,
     config: MarkdownTestConfig,
-    snapshot_diagnostics: bool,
+    directives: MdtestDirectives,
 }

 #[newtype_index]
@@ -428,7 +436,7 @@
             level: 0,
             parent_id: None,
             config: MarkdownTestConfig::default(),
-            snapshot_diagnostics: false,
+            directives: MdtestDirectives::default(),
         });
         Self {
             sections,
@@ -486,6 +494,7 @@
     fn parse_impl(&mut self) -> anyhow::Result<()> {
         const SECTION_CONFIG_SNAPSHOT: &str = "snapshot-diagnostics";
+        const SECTION_CONFIG_PULLTYPES: &str = "pull-types:skip";
         const HTML_COMMENT_ALLOWLIST: &[&str] = &["blacken-docs:on", "blacken-docs:off"];
         const CODE_BLOCK_END: &[u8] = b"```";
         const HTML_COMMENT_END: &[u8] = b"-->";
@@ -498,10 +507,12 @@
             {
                 let html_comment = self.cursor.as_str()[..position].trim();
                 if html_comment == SECTION_CONFIG_SNAPSHOT {
-                    self.process_snapshot_diagnostics()?;
+                    self.process_mdtest_directive(MdtestDirective::SnapshotDiagnostics)?;
+                } else if html_comment == SECTION_CONFIG_PULLTYPES {
+                    self.process_mdtest_directive(MdtestDirective::PullTypesSkip)?;
                 } else if !HTML_COMMENT_ALLOWLIST.contains(&html_comment) {
                     bail!(
-                        "Unknown HTML comment `{}` -- possibly a `snapshot-diagnostics` typo? \
+                        "Unknown HTML comment `{}` -- possibly a typo? \
                         (Add to `HTML_COMMENT_ALLOWLIST` if this is a false positive)",
                         html_comment
                     );
@@ -636,7 +647,7 @@
             level: header_level.try_into()?,
             parent_id: Some(parent),
             config: self.sections[parent].config.clone(),
-            snapshot_diagnostics: self.sections[parent].snapshot_diagnostics,
+            directives: self.sections[parent].directives,
         };

         if !self.current_section_files.is_empty() {
@@ -784,28 +795,28 @@
         Ok(())
     }

-    fn process_snapshot_diagnostics(&mut self) -> anyhow::Result<()> {
+    fn process_mdtest_directive(&mut self, directive: MdtestDirective) -> anyhow::Result<()> {
         if self.current_section_has_config {
             bail!(
-                "Section config to enable snapshotting diagnostics must come before \
+                "Section config to enable {directive} must come before \
                 everything else (including TOML configuration blocks).",
             );
         }

         if !self.current_section_files.is_empty() {
             bail!(
-                "Section config to enable snapshotting diagnostics must come before \
+                "Section config to enable {directive} must come before \
                 everything else (including embedded files).",
            );
         }

         let current_section = &mut self.sections[self.stack.top()];

-        if current_section.snapshot_diagnostics {
+        if current_section.directives.has_directive_set(directive) {
             bail!(
-                "Section config to enable snapshotting diagnostics should appear \
+                "Section config to enable {directive} should appear \
                 at most once.",
             );
         }

-        current_section.snapshot_diagnostics = true;
+        current_section.directives.add_directive(directive);

         Ok(())
     }
@@ -824,6 +835,56 @@
     }
 }

+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+enum MdtestDirective {
+    /// A directive to enable snapshotting diagnostics.
+    SnapshotDiagnostics,
+    /// A directive to skip pull types.
+    PullTypesSkip,
+}
+
+impl std::fmt::Display for MdtestDirective {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        match self {
+            MdtestDirective::SnapshotDiagnostics => f.write_str("snapshotting diagnostics"),
+            MdtestDirective::PullTypesSkip => f.write_str("skipping the pull-types visitor"),
+        }
+    }
+}
+
+bitflags::bitflags! {
+    /// Directives that can be applied to a Markdown test section.
+    #[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
+    pub(crate) struct MdtestDirectives: u8 {
+        /// We should snapshot diagnostics for this section.
+        const SNAPSHOT_DIAGNOSTICS = 1 << 0;
+        /// We should skip pulling types for this section.
+        const PULL_TYPES_SKIP = 1 << 1;
+    }
+}
+
+impl MdtestDirectives {
+    const fn has_directive_set(self, directive: MdtestDirective) -> bool {
+        match directive {
+            MdtestDirective::SnapshotDiagnostics => {
+                self.contains(MdtestDirectives::SNAPSHOT_DIAGNOSTICS)
+            }
+            MdtestDirective::PullTypesSkip => self.contains(MdtestDirectives::PULL_TYPES_SKIP),
+        }
+    }
+
+    fn add_directive(&mut self, directive: MdtestDirective) {
+        match directive {
+            MdtestDirective::SnapshotDiagnostics => {
+                self.insert(MdtestDirectives::SNAPSHOT_DIAGNOSTICS);
+            }
+            MdtestDirective::PullTypesSkip => {
+                self.insert(MdtestDirectives::PULL_TYPES_SKIP);
+            }
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use ruff_python_ast::PySourceType;
@@ -1906,7 +1967,7 @@ mod tests {
         let err = super::parse("file.md", &source).expect_err("Should fail to parse");
         assert_eq!(
             err.to_string(),
-            "Unknown HTML comment `snpshotttt-digggggnosstic` -- possibly a `snapshot-diagnostics` typo? \
+            "Unknown HTML comment `snpshotttt-digggggnosstic` -- possibly a typo? \
             (Add to `HTML_COMMENT_ALLOWLIST` if this is a false positive)",
         );
     }
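
Finally, a small illustration of the new directives representation: because a child section is seeded with a copy of its parent's directives (`directives: self.sections[parent].directives` above), a directive such as `<!-- pull-types:skip -->` on a parent section also applies to its nested sections. The sketch below is a unit test that could sit next to the parser code; it is not part of the commit, and the flag values used are hypothetical.

```rust
// Sketch only: relies on the `MdtestDirectives` bitflags defined in the diff above.
#[test]
fn directives_are_inherited_by_copy() {
    // A parent section with diagnostics snapshotting enabled...
    let parent = MdtestDirectives::SNAPSHOT_DIAGNOSTICS;

    // ...a child section starts from a copy of its parent's directives
    // (mirroring `directives: self.sections[parent].directives`)...
    let mut child = parent;
    // ...and can add its own after seeing `<!-- pull-types:skip -->`.
    child.insert(MdtestDirectives::PULL_TYPES_SKIP);

    assert!(child.contains(
        MdtestDirectives::SNAPSHOT_DIAGNOSTICS | MdtestDirectives::PULL_TYPES_SKIP
    ));
    // The parent itself is unchanged because the flags type is `Copy`.
    assert!(!parent.contains(MdtestDirectives::PULL_TYPES_SKIP));
}
```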