mirror of
				https://github.com/astral-sh/ruff.git
				synced 2025-10-24 17:16:53 +00:00 
			
		
		
		
	 4c3e1930f6
			
		
	
	
		4c3e1930f6
		
			
		
	
	
	
	
		
			
			<!-- Thank you for contributing to Ruff/ty! To help us out with reviewing, please consider the following: - Does this pull request include a summary of the change? (See below.) - Does this pull request include a descriptive title? (Please prefix with `[ty]` for ty pull requests.) - Does this pull request include references to any relevant issues? --> ## Summary This PR implements https://docs.astral.sh/ruff/rules/yield-from-in-async-function/ as a syntax semantic error ## Test Plan <!-- How was it tested? --> I have written a simple inline test as directed in [https://github.com/astral-sh/ruff/issues/17412](https://github.com/astral-sh/ruff/issues/17412) --------- Signed-off-by: 11happy <soni5happy@gmail.com> Co-authored-by: Alex Waygood <alex.waygood@gmail.com> Co-authored-by: Brent Westbrook <36778786+ntBre@users.noreply.github.com>
		
			
				
	
	
		
			660 lines
		
	
	
	
		
			20 KiB
		
	
	
	
		
			Rust
		
	
	
	
	
	
			
		
		
	
	
			660 lines
		
	
	
	
		
			20 KiB
		
	
	
	
		
			Rust
		
	
	
	
	
	
| use std::cell::RefCell;
 | |
| use std::cmp::Ordering;
 | |
| use std::fmt::{Formatter, Write};
 | |
| use std::fs;
 | |
| use std::path::Path;
 | |
| 
 | |
| use ruff_annotate_snippets::{Level, Renderer, Snippet};
 | |
| use ruff_python_ast::visitor::Visitor;
 | |
| use ruff_python_ast::visitor::source_order::{SourceOrderVisitor, TraversalSignal, walk_module};
 | |
| use ruff_python_ast::{self as ast, AnyNodeRef, Mod, PythonVersion};
 | |
| use ruff_python_parser::semantic_errors::{
 | |
|     SemanticSyntaxChecker, SemanticSyntaxContext, SemanticSyntaxError,
 | |
| };
 | |
| use ruff_python_parser::{Mode, ParseErrorType, ParseOptions, Token, parse_unchecked};
 | |
| use ruff_source_file::{LineIndex, OneIndexed, SourceCode};
 | |
| use ruff_text_size::{Ranged, TextLen, TextRange, TextSize};
 | |
| 
 | |
/// Runs [`test_valid_syntax`] on every file matching `valid/**/*.py` below `../resources`.
#[test]
fn valid_syntax() {
    insta::glob!("../resources", "valid/**/*.py", test_valid_syntax);
}
 | |
| 
 | |
/// Runs [`test_invalid_syntax`] on every file matching `invalid/**/*.py` below `../resources`.
#[test]
fn invalid_syntax() {
    insta::glob!("../resources", "invalid/**/*.py", test_invalid_syntax);
}
 | |
| 
 | |
/// Runs [`test_valid_syntax`] on every file matching `ok/**/*.py` below `../resources/inline`.
#[test]
fn inline_ok() {
    insta::glob!("../resources/inline", "ok/**/*.py", test_valid_syntax);
}
 | |
| 
 | |
/// Runs [`test_invalid_syntax`] on every file matching `err/**/*.py` below `../resources/inline`.
#[test]
fn inline_err() {
    insta::glob!("../resources/inline", "err/**/*.py", test_invalid_syntax);
}
 | |
| 
 | |
| /// Asserts that the parser generates no syntax errors for a valid program.
 | |
| /// Snapshots the AST.
 | |
| fn test_valid_syntax(input_path: &Path) {
 | |
|     let source = fs::read_to_string(input_path).expect("Expected test file to exist");
 | |
|     let options = extract_options(&source).unwrap_or_else(|| {
 | |
|         ParseOptions::from(Mode::Module).with_target_version(PythonVersion::latest_preview())
 | |
|     });
 | |
|     let parsed = parse_unchecked(&source, options.clone());
 | |
| 
 | |
|     if parsed.has_syntax_errors() {
 | |
|         let line_index = LineIndex::from_source_text(&source);
 | |
|         let source_code = SourceCode::new(&source, &line_index);
 | |
| 
 | |
|         let mut message = "Expected no syntax errors for a valid program but the parser generated the following errors:\n".to_string();
 | |
| 
 | |
|         for error in parsed.errors() {
 | |
|             writeln!(
 | |
|                 &mut message,
 | |
|                 "{}\n",
 | |
|                 CodeFrame {
 | |
|                     range: error.location,
 | |
|                     error,
 | |
|                     source_code: &source_code,
 | |
|                 }
 | |
|             )
 | |
|             .unwrap();
 | |
|         }
 | |
| 
 | |
|         for error in parsed.unsupported_syntax_errors() {
 | |
|             writeln!(
 | |
|                 &mut message,
 | |
|                 "{}\n",
 | |
|                 CodeFrame {
 | |
|                     range: error.range,
 | |
|                     error: &ParseErrorType::OtherError(error.to_string()),
 | |
|                     source_code: &source_code,
 | |
|                 }
 | |
|             )
 | |
|             .unwrap();
 | |
|         }
 | |
| 
 | |
|         panic!("{input_path:?}: {message}");
 | |
|     }
 | |
| 
 | |
|     validate_tokens(parsed.tokens(), source.text_len(), input_path);
 | |
|     validate_ast(parsed.syntax(), source.text_len(), input_path);
 | |
| 
 | |
|     let mut output = String::new();
 | |
|     writeln!(&mut output, "## AST").unwrap();
 | |
|     writeln!(&mut output, "\n```\n{:#?}\n```", parsed.syntax()).unwrap();
 | |
| 
 | |
|     let parsed = parsed.try_into_module().expect("Parsed with Mode::Module");
 | |
| 
 | |
|     let mut visitor =
 | |
|         SemanticSyntaxCheckerVisitor::new(&source).with_python_version(options.target_version());
 | |
| 
 | |
|     for stmt in parsed.suite() {
 | |
|         visitor.visit_stmt(stmt);
 | |
|     }
 | |
| 
 | |
|     let semantic_syntax_errors = visitor.into_diagnostics();
 | |
| 
 | |
|     if !semantic_syntax_errors.is_empty() {
 | |
|         let mut message = "Expected no semantic syntax errors for a valid program:\n".to_string();
 | |
| 
 | |
|         let line_index = LineIndex::from_source_text(&source);
 | |
|         let source_code = SourceCode::new(&source, &line_index);
 | |
| 
 | |
|         for error in semantic_syntax_errors {
 | |
|             writeln!(
 | |
|                 &mut message,
 | |
|                 "{}\n",
 | |
|                 CodeFrame {
 | |
|                     range: error.range,
 | |
|                     error: &ParseErrorType::OtherError(error.to_string()),
 | |
|                     source_code: &source_code,
 | |
|                 }
 | |
|             )
 | |
|             .unwrap();
 | |
|         }
 | |
| 
 | |
|         panic!("{input_path:?}: {message}");
 | |
|     }
 | |
| 
 | |
|     insta::with_settings!({
 | |
|         omit_expression => true,
 | |
|         input_file => input_path,
 | |
|         prepend_module_to_snapshot => false,
 | |
|     }, {
 | |
|         insta::assert_snapshot!(output);
 | |
|     });
 | |
| }
 | |
| 
 | |
| /// Assert that the parser generates at least one syntax error for the given input file.
 | |
| /// Snapshots the AST and the error messages.
 | |
| fn test_invalid_syntax(input_path: &Path) {
 | |
|     let source = fs::read_to_string(input_path).expect("Expected test file to exist");
 | |
|     let options = extract_options(&source).unwrap_or_else(|| {
 | |
|         ParseOptions::from(Mode::Module).with_target_version(PythonVersion::PY314)
 | |
|     });
 | |
|     let parsed = parse_unchecked(&source, options.clone());
 | |
| 
 | |
|     validate_tokens(parsed.tokens(), source.text_len(), input_path);
 | |
|     validate_ast(parsed.syntax(), source.text_len(), input_path);
 | |
| 
 | |
|     let mut output = String::new();
 | |
|     writeln!(&mut output, "## AST").unwrap();
 | |
|     writeln!(&mut output, "\n```\n{:#?}\n```", parsed.syntax()).unwrap();
 | |
| 
 | |
|     let line_index = LineIndex::from_source_text(&source);
 | |
|     let source_code = SourceCode::new(&source, &line_index);
 | |
| 
 | |
|     if !parsed.errors().is_empty() {
 | |
|         writeln!(&mut output, "## Errors\n").unwrap();
 | |
|     }
 | |
| 
 | |
|     for error in parsed.errors() {
 | |
|         writeln!(
 | |
|             &mut output,
 | |
|             "{}\n",
 | |
|             CodeFrame {
 | |
|                 range: error.location,
 | |
|                 error,
 | |
|                 source_code: &source_code,
 | |
|             }
 | |
|         )
 | |
|         .unwrap();
 | |
|     }
 | |
| 
 | |
|     if !parsed.unsupported_syntax_errors().is_empty() {
 | |
|         writeln!(&mut output, "## Unsupported Syntax Errors\n").unwrap();
 | |
|     }
 | |
| 
 | |
|     for error in parsed.unsupported_syntax_errors() {
 | |
|         writeln!(
 | |
|             &mut output,
 | |
|             "{}\n",
 | |
|             CodeFrame {
 | |
|                 range: error.range,
 | |
|                 error: &ParseErrorType::OtherError(error.to_string()),
 | |
|                 source_code: &source_code,
 | |
|             }
 | |
|         )
 | |
|         .unwrap();
 | |
|     }
 | |
| 
 | |
|     let parsed = parsed.try_into_module().expect("Parsed with Mode::Module");
 | |
| 
 | |
|     let mut visitor =
 | |
|         SemanticSyntaxCheckerVisitor::new(&source).with_python_version(options.target_version());
 | |
| 
 | |
|     for stmt in parsed.suite() {
 | |
|         visitor.visit_stmt(stmt);
 | |
|     }
 | |
| 
 | |
|     let semantic_syntax_errors = visitor.into_diagnostics();
 | |
| 
 | |
|     assert!(
 | |
|         parsed.has_syntax_errors() || !semantic_syntax_errors.is_empty(),
 | |
|         "{input_path:?}: Expected parser to generate at least one syntax error for a program containing syntax errors."
 | |
|     );
 | |
| 
 | |
|     if !semantic_syntax_errors.is_empty() {
 | |
|         writeln!(&mut output, "## Semantic Syntax Errors\n").unwrap();
 | |
|     }
 | |
| 
 | |
|     for error in semantic_syntax_errors {
 | |
|         writeln!(
 | |
|             &mut output,
 | |
|             "{}\n",
 | |
|             CodeFrame {
 | |
|                 range: error.range,
 | |
|                 error: &ParseErrorType::OtherError(error.to_string()),
 | |
|                 source_code: &source_code,
 | |
|             }
 | |
|         )
 | |
|         .unwrap();
 | |
|     }
 | |
| 
 | |
|     insta::with_settings!({
 | |
|         omit_expression => true,
 | |
|         input_file => input_path,
 | |
|         prepend_module_to_snapshot => false,
 | |
|     }, {
 | |
|         insta::assert_snapshot!(output);
 | |
|     });
 | |
| }
 | |
| 
 | |
| /// Copy of [`ParseOptions`] for deriving [`Deserialize`] with serde as a dev-dependency.
 | |
| #[derive(serde::Deserialize)]
 | |
| #[serde(rename_all = "kebab-case")]
 | |
| struct JsonParseOptions {
 | |
|     #[serde(default)]
 | |
|     mode: JsonMode,
 | |
|     #[serde(default)]
 | |
|     target_version: PythonVersion,
 | |
| }
 | |
| 
 | |
| /// Copy of [`Mode`] for deserialization.
 | |
| #[derive(Default, serde::Deserialize)]
 | |
| #[serde(rename_all = "kebab-case")]
 | |
| enum JsonMode {
 | |
|     #[default]
 | |
|     Module,
 | |
|     Expression,
 | |
|     ParenthesizedExpression,
 | |
|     Ipython,
 | |
| }
 | |
| 
 | |
| impl From<JsonParseOptions> for ParseOptions {
 | |
|     fn from(value: JsonParseOptions) -> Self {
 | |
|         let mode = match value.mode {
 | |
|             JsonMode::Module => Mode::Module,
 | |
|             JsonMode::Expression => Mode::Expression,
 | |
|             JsonMode::ParenthesizedExpression => Mode::ParenthesizedExpression,
 | |
|             JsonMode::Ipython => Mode::Ipython,
 | |
|         };
 | |
|         Self::from(mode).with_target_version(value.target_version)
 | |
|     }
 | |
| }
 | |
| 
 | |
| /// Extract [`ParseOptions`] from an initial pragma line, if present.
 | |
| ///
 | |
| /// For example,
 | |
| ///
 | |
| /// ```python
 | |
| /// # parse_options: { "target-version": "3.10" }
 | |
| /// def f(): ...
 | |
| fn extract_options(source: &str) -> Option<ParseOptions> {
 | |
|     let header = source.lines().next()?;
 | |
|     let (_label, options) = header.split_once("# parse_options: ")?;
 | |
|     let options: Option<JsonParseOptions> = serde_json::from_str(options.trim()).ok();
 | |
|     options.map(ParseOptions::from)
 | |
| }
 | |
| 
 | |
// Scratch test for quickly debugging a parser issue; ignored unless run explicitly.
// Edit the `source` literal below, then inspect the printed AST, tokens and errors.
#[test]
#[ignore]
#[expect(clippy::print_stdout)]
fn parser_quick_test() {
    let source = "\
f'{'
f'{foo!r'
";

    let parsed = parse_unchecked(source, ParseOptions::from(Mode::Module));

    println!("AST:\n----\n{:#?}", parsed.syntax());
    println!("Tokens:\n-------\n{:#?}", parsed.tokens());

    // Nothing further to print for sources without invalid syntax.
    if !parsed.has_invalid_syntax() {
        return;
    }

    println!("Errors:\n-------");

    let line_index = LineIndex::from_source_text(source);
    let source_code = SourceCode::new(source, &line_index);

    for error in parsed.errors() {
        // Sometimes the code frame doesn't show the error message, so the plain
        // message is printed ahead of it.
        println!("Syntax Error: {error}");

        let frame = CodeFrame {
            range: error.location,
            error,
            source_code: &source_code,
        };
        println!("{frame}\n");
    }

    println!();
}
 | |
| 
 | |
| struct CodeFrame<'a> {
 | |
|     range: TextRange,
 | |
|     error: &'a ParseErrorType,
 | |
|     source_code: &'a SourceCode<'a, 'a>,
 | |
| }
 | |
| 
 | |
| impl std::fmt::Display for CodeFrame<'_> {
 | |
|     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 | |
|         // Copied and modified from ruff_linter/src/message/text.rs
 | |
|         let content_start_index = self.source_code.line_index(self.range.start());
 | |
|         let mut start_index = content_start_index.saturating_sub(2);
 | |
| 
 | |
|         // Trim leading empty lines.
 | |
|         while start_index < content_start_index {
 | |
|             if !self.source_code.line_text(start_index).trim().is_empty() {
 | |
|                 break;
 | |
|             }
 | |
|             start_index = start_index.saturating_add(1);
 | |
|         }
 | |
| 
 | |
|         let content_end_index = self.source_code.line_index(self.range.end());
 | |
|         let mut end_index = content_end_index
 | |
|             .saturating_add(2)
 | |
|             .min(OneIndexed::from_zero_indexed(self.source_code.line_count()));
 | |
| 
 | |
|         // Trim trailing empty lines.
 | |
|         while end_index > content_end_index {
 | |
|             if !self.source_code.line_text(end_index).trim().is_empty() {
 | |
|                 break;
 | |
|             }
 | |
| 
 | |
|             end_index = end_index.saturating_sub(1);
 | |
|         }
 | |
| 
 | |
|         let start_offset = self.source_code.line_start(start_index);
 | |
|         let end_offset = self.source_code.line_end(end_index);
 | |
| 
 | |
|         let annotation_range = self.range - start_offset;
 | |
|         let source = self
 | |
|             .source_code
 | |
|             .slice(TextRange::new(start_offset, end_offset));
 | |
| 
 | |
|         let label = format!("Syntax Error: {error}", error = self.error);
 | |
| 
 | |
|         let span = usize::from(annotation_range.start())..usize::from(annotation_range.end());
 | |
|         let annotation = Level::Error.span(span).label(&label);
 | |
|         let snippet = Snippet::source(source)
 | |
|             .line_start(start_index.get())
 | |
|             .annotation(annotation)
 | |
|             .fold(false);
 | |
|         let message = Level::None.title("").snippet(snippet);
 | |
|         let renderer = Renderer::plain().cut_indicator("…");
 | |
|         let rendered = renderer.render(message);
 | |
|         writeln!(f, "{rendered}")
 | |
|     }
 | |
| }
 | |
| 
 | |
| /// Verifies that:
 | |
| /// * the ranges are strictly increasing when loop the tokens in insertion order
 | |
| /// * all ranges are within the length of the source code
 | |
| fn validate_tokens(tokens: &[Token], source_length: TextSize, test_path: &Path) {
 | |
|     let mut previous: Option<&Token> = None;
 | |
| 
 | |
|     for token in tokens {
 | |
|         assert!(
 | |
|             token.end() <= source_length,
 | |
|             "{path}: Token range exceeds the source code length. Token: {token:#?}",
 | |
|             path = test_path.display()
 | |
|         );
 | |
| 
 | |
|         if let Some(previous) = previous {
 | |
|             assert_eq!(
 | |
|                 previous.range().ordering(token.range()),
 | |
|                 Ordering::Less,
 | |
|                 "{path}: Token ranges are not in increasing order
 | |
| Previous token: {previous:#?}
 | |
| Current token: {token:#?}
 | |
| Tokens: {tokens:#?}
 | |
| ",
 | |
|                 path = test_path.display(),
 | |
|             );
 | |
|         }
 | |
| 
 | |
|         previous = Some(token);
 | |
|     }
 | |
| }
 | |
| 
 | |
| /// Verifies that:
 | |
| /// * the range of the parent node fully encloses all its child nodes
 | |
| /// * the ranges are strictly increasing when traversing the nodes in pre-order.
 | |
| /// * all ranges are within the length of the source code.
 | |
| fn validate_ast(root: &Mod, source_len: TextSize, test_path: &Path) {
 | |
|     walk_module(&mut ValidateAstVisitor::new(source_len, test_path), root);
 | |
| }
 | |
| 
 | |
| #[derive(Debug)]
 | |
| struct ValidateAstVisitor<'a> {
 | |
|     parents: Vec<AnyNodeRef<'a>>,
 | |
|     previous: Option<AnyNodeRef<'a>>,
 | |
|     source_length: TextSize,
 | |
|     test_path: &'a Path,
 | |
| }
 | |
| 
 | |
| impl<'a> ValidateAstVisitor<'a> {
 | |
|     fn new(source_length: TextSize, test_path: &'a Path) -> Self {
 | |
|         Self {
 | |
|             parents: Vec::new(),
 | |
|             previous: None,
 | |
|             source_length,
 | |
|             test_path,
 | |
|         }
 | |
|     }
 | |
| }
 | |
| 
 | |
| impl<'ast> SourceOrderVisitor<'ast> for ValidateAstVisitor<'ast> {
 | |
|     fn enter_node(&mut self, node: AnyNodeRef<'ast>) -> TraversalSignal {
 | |
|         assert!(
 | |
|             node.end() <= self.source_length,
 | |
|             "{path}: The range of the node exceeds the length of the source code. Node: {node:#?}",
 | |
|             path = self.test_path.display()
 | |
|         );
 | |
| 
 | |
|         if let Some(previous) = self.previous {
 | |
|             assert_ne!(
 | |
|                 previous.range().ordering(node.range()),
 | |
|                 Ordering::Greater,
 | |
|                 "{path}: The ranges of the nodes are not strictly increasing when traversing the AST in pre-order.\nPrevious node: {previous:#?}\n\nCurrent node: {node:#?}\n\nRoot: {root:#?}",
 | |
|                 path = self.test_path.display(),
 | |
|                 root = self.parents.first()
 | |
|             );
 | |
|         }
 | |
| 
 | |
|         if let Some(parent) = self.parents.last() {
 | |
|             assert!(
 | |
|                 parent.range().contains_range(node.range()),
 | |
|                 "{path}: The range of the parent node does not fully enclose the range of the child node.\nParent node: {parent:#?}\n\nChild node: {node:#?}\n\nRoot: {root:#?}",
 | |
|                 path = self.test_path.display(),
 | |
|                 root = self.parents.first()
 | |
|             );
 | |
|         }
 | |
| 
 | |
|         self.parents.push(node);
 | |
| 
 | |
|         TraversalSignal::Traverse
 | |
|     }
 | |
| 
 | |
|     fn leave_node(&mut self, node: AnyNodeRef<'ast>) {
 | |
|         self.parents.pop().expect("Expected tree to be balanced");
 | |
| 
 | |
|         self.previous = Some(node);
 | |
|     }
 | |
| }
 | |
| 
 | |
/// Kind of scope tracked on the visitor's scope stack while walking a module.
///
/// Derives the common value-type traits so scopes can be printed while debugging,
/// compared, and copied freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Scope {
    /// The top-level module scope.
    Module,
    /// A function (or lambda) body; `is_async` records whether the function is `async`.
    Function { is_async: bool },
    /// A comprehension or generator expression; `is_async` records whether any of its
    /// generators is `async`.
    Comprehension { is_async: bool },
    /// A class body.
    Class,
}
 | |
| 
 | |
| struct SemanticSyntaxCheckerVisitor<'a> {
 | |
|     checker: SemanticSyntaxChecker,
 | |
|     diagnostics: RefCell<Vec<SemanticSyntaxError>>,
 | |
|     python_version: PythonVersion,
 | |
|     source: &'a str,
 | |
|     scopes: Vec<Scope>,
 | |
| }
 | |
| 
 | |
| impl<'a> SemanticSyntaxCheckerVisitor<'a> {
 | |
|     fn new(source: &'a str) -> Self {
 | |
|         Self {
 | |
|             checker: SemanticSyntaxChecker::new(),
 | |
|             diagnostics: RefCell::default(),
 | |
|             python_version: PythonVersion::default(),
 | |
|             source,
 | |
|             scopes: vec![Scope::Module],
 | |
|         }
 | |
|     }
 | |
| 
 | |
|     #[must_use]
 | |
|     fn with_python_version(mut self, python_version: PythonVersion) -> Self {
 | |
|         self.python_version = python_version;
 | |
|         self
 | |
|     }
 | |
| 
 | |
|     fn into_diagnostics(self) -> Vec<SemanticSyntaxError> {
 | |
|         self.diagnostics.into_inner()
 | |
|     }
 | |
| 
 | |
|     fn with_semantic_checker(&mut self, f: impl FnOnce(&mut SemanticSyntaxChecker, &Self)) {
 | |
|         let mut checker = std::mem::take(&mut self.checker);
 | |
|         f(&mut checker, self);
 | |
|         self.checker = checker;
 | |
|     }
 | |
| }
 | |
| 
 | |
| impl SemanticSyntaxContext for SemanticSyntaxCheckerVisitor<'_> {
 | |
|     fn future_annotations_or_stub(&self) -> bool {
 | |
|         false
 | |
|     }
 | |
| 
 | |
|     fn python_version(&self) -> PythonVersion {
 | |
|         self.python_version
 | |
|     }
 | |
| 
 | |
|     fn report_semantic_error(&self, error: SemanticSyntaxError) {
 | |
|         self.diagnostics.borrow_mut().push(error);
 | |
|     }
 | |
| 
 | |
|     fn source(&self) -> &str {
 | |
|         self.source
 | |
|     }
 | |
| 
 | |
|     fn global(&self, _name: &str) -> Option<TextRange> {
 | |
|         None
 | |
|     }
 | |
| 
 | |
|     fn in_async_context(&self) -> bool {
 | |
|         if let Some(scope) = self.scopes.iter().next_back() {
 | |
|             match scope {
 | |
|                 Scope::Class | Scope::Module => false,
 | |
|                 Scope::Comprehension { is_async } => *is_async,
 | |
|                 Scope::Function { is_async } => *is_async,
 | |
|             }
 | |
|         } else {
 | |
|             false
 | |
|         }
 | |
|     }
 | |
| 
 | |
|     fn in_sync_comprehension(&self) -> bool {
 | |
|         for scope in &self.scopes {
 | |
|             if let Scope::Comprehension { is_async: false } = scope {
 | |
|                 return true;
 | |
|             }
 | |
|         }
 | |
|         false
 | |
|     }
 | |
| 
 | |
|     fn in_module_scope(&self) -> bool {
 | |
|         self.scopes.len() == 1
 | |
|     }
 | |
| 
 | |
|     fn in_function_scope(&self) -> bool {
 | |
|         true
 | |
|     }
 | |
| 
 | |
|     fn in_notebook(&self) -> bool {
 | |
|         false
 | |
|     }
 | |
| 
 | |
|     fn in_await_allowed_context(&self) -> bool {
 | |
|         true
 | |
|     }
 | |
| 
 | |
|     fn in_yield_allowed_context(&self) -> bool {
 | |
|         true
 | |
|     }
 | |
| 
 | |
|     fn in_generator_scope(&self) -> bool {
 | |
|         true
 | |
|     }
 | |
| }
 | |
| 
 | |
| impl Visitor<'_> for SemanticSyntaxCheckerVisitor<'_> {
 | |
|     fn visit_stmt(&mut self, stmt: &ast::Stmt) {
 | |
|         self.with_semantic_checker(|semantic, context| semantic.visit_stmt(stmt, context));
 | |
|         match stmt {
 | |
|             ast::Stmt::ClassDef(ast::StmtClassDef {
 | |
|                 arguments,
 | |
|                 body,
 | |
|                 decorator_list,
 | |
|                 type_params,
 | |
|                 ..
 | |
|             }) => {
 | |
|                 for decorator in decorator_list {
 | |
|                     self.visit_decorator(decorator);
 | |
|                 }
 | |
|                 if let Some(type_params) = type_params {
 | |
|                     self.visit_type_params(type_params);
 | |
|                 }
 | |
|                 if let Some(arguments) = arguments {
 | |
|                     self.visit_arguments(arguments);
 | |
|                 }
 | |
|                 self.scopes.push(Scope::Class);
 | |
|                 self.visit_body(body);
 | |
|                 self.scopes.pop().unwrap();
 | |
|             }
 | |
|             ast::Stmt::FunctionDef(ast::StmtFunctionDef { is_async, .. }) => {
 | |
|                 self.scopes.push(Scope::Function {
 | |
|                     is_async: *is_async,
 | |
|                 });
 | |
|                 ast::visitor::walk_stmt(self, stmt);
 | |
|                 self.scopes.pop().unwrap();
 | |
|             }
 | |
|             _ => {
 | |
|                 ast::visitor::walk_stmt(self, stmt);
 | |
|             }
 | |
|         }
 | |
|     }
 | |
| 
 | |
|     fn visit_expr(&mut self, expr: &ast::Expr) {
 | |
|         self.with_semantic_checker(|semantic, context| semantic.visit_expr(expr, context));
 | |
|         match expr {
 | |
|             ast::Expr::Lambda(_) => {
 | |
|                 self.scopes.push(Scope::Function { is_async: false });
 | |
|                 ast::visitor::walk_expr(self, expr);
 | |
|                 self.scopes.pop().unwrap();
 | |
|             }
 | |
|             ast::Expr::ListComp(ast::ExprListComp {
 | |
|                 elt, generators, ..
 | |
|             })
 | |
|             | ast::Expr::SetComp(ast::ExprSetComp {
 | |
|                 elt, generators, ..
 | |
|             })
 | |
|             | ast::Expr::Generator(ast::ExprGenerator {
 | |
|                 elt, generators, ..
 | |
|             }) => {
 | |
|                 for comprehension in generators {
 | |
|                     self.visit_comprehension(comprehension);
 | |
|                 }
 | |
|                 self.scopes.push(Scope::Comprehension {
 | |
|                     is_async: generators.iter().any(|generator| generator.is_async),
 | |
|                 });
 | |
|                 self.visit_expr(elt);
 | |
|                 self.scopes.pop().unwrap();
 | |
|             }
 | |
|             ast::Expr::DictComp(ast::ExprDictComp {
 | |
|                 key,
 | |
|                 value,
 | |
|                 generators,
 | |
|                 ..
 | |
|             }) => {
 | |
|                 for comprehension in generators {
 | |
|                     self.visit_comprehension(comprehension);
 | |
|                 }
 | |
|                 self.scopes.push(Scope::Comprehension {
 | |
|                     is_async: generators.iter().any(|generator| generator.is_async),
 | |
|                 });
 | |
|                 self.visit_expr(key);
 | |
|                 self.visit_expr(value);
 | |
|                 self.scopes.pop().unwrap();
 | |
|             }
 | |
|             _ => {
 | |
|                 ast::visitor::walk_expr(self, expr);
 | |
|             }
 | |
|         }
 | |
|     }
 | |
| }
 |