Formatter: Add SourceType to context to enable special formatting for stub files (#6331)

**Summary** This adds information about whether we're formatting a `.py` Python
source file or a `.pyi` stub file, to enable people working on #5822 and
related issues.

I'm not completely happy with `Default` for something that depends on
the input.

**Test Plan** None; this is currently unused. I'm leaving testing to the first
implementation of stub-file-specific formatting.

---------

Co-authored-by: Micha Reiser <micha@reiser.io>
This commit is contained in:
konsti 2023-08-04 13:52:26 +02:00 committed by GitHub
parent fe97a2a302
commit 1031bb6550
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 137 additions and 97 deletions

View file

@ -52,15 +52,13 @@ pub(crate) fn type_comparison(checker: &mut Checker, compare: &ast::ExprCompare)
} }
// Left-hand side must be, e.g., `type(obj)`. // Left-hand side must be, e.g., `type(obj)`.
let Expr::Call(ast::ExprCall { let Expr::Call(ast::ExprCall { func, .. }) = left else {
func, ..
}) = left else {
continue; continue;
}; };
let Expr::Name(ast::ExprName { id, .. }) = func.as_ref() else { let Expr::Name(ast::ExprName { id, .. }) = func.as_ref() else {
continue; continue;
}; };
if !(id == "type" && checker.semantic().is_builtin("type")) { if !(id == "type" && checker.semantic().is_builtin("type")) {
continue; continue;

View file

@ -1,6 +1,7 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use ruff_benchmark::{TestCase, TestCaseSpeed, TestFile, TestFileDownloadError}; use ruff_benchmark::{TestCase, TestCaseSpeed, TestFile, TestFileDownloadError};
use ruff_python_formatter::{format_module, PyFormatOptions}; use ruff_python_formatter::{format_module, PyFormatOptions};
use std::path::Path;
use std::time::Duration; use std::time::Duration;
#[cfg(target_os = "windows")] #[cfg(target_os = "windows")]
@ -51,8 +52,8 @@ fn benchmark_formatter(criterion: &mut Criterion) {
&case, &case,
|b, case| { |b, case| {
b.iter(|| { b.iter(|| {
format_module(case.code(), PyFormatOptions::default()) let options = PyFormatOptions::from_extension(Path::new(case.name()));
.expect("Formatting to succeed") format_module(case.code(), options).expect("Formatting to succeed")
}); });
}, },
); );

View file

@ -161,25 +161,20 @@ fn format(files: &[PathBuf]) -> Result<ExitStatus> {
internal use only." internal use only."
); );
let format_code = |code: &str| {
// dummy, to check that the function was actually called
let contents = code.replace("# DEL", "");
// real formatting that is currently a passthrough
format_module(&contents, PyFormatOptions::default())
};
match &files { match &files {
// Check if we should read from stdin // Check if we should read from stdin
[path] if path == Path::new("-") => { [path] if path == Path::new("-") => {
let unformatted = read_from_stdin()?; let unformatted = read_from_stdin()?;
let formatted = format_code(&unformatted)?; let options = PyFormatOptions::from_extension(Path::new("stdin.py"));
let formatted = format_module(&unformatted, options)?;
stdout().lock().write_all(formatted.as_code().as_bytes())?; stdout().lock().write_all(formatted.as_code().as_bytes())?;
} }
_ => { _ => {
for file in files { for file in files {
let unformatted = std::fs::read_to_string(file) let unformatted = std::fs::read_to_string(file)
.with_context(|| format!("Could not read {}: ", file.display()))?; .with_context(|| format!("Could not read {}: ", file.display()))?;
let formatted = format_code(&unformatted)?; let options = PyFormatOptions::from_extension(file);
let formatted = format_module(&unformatted, options)?;
std::fs::write(file, formatted.as_code().as_bytes()) std::fs::write(file, formatted.as_code().as_bytes())
.with_context(|| format!("Could not write to {}, exiting", file.display()))?; .with_context(|| format!("Could not write to {}, exiting", file.display()))?;
} }

View file

@ -376,10 +376,10 @@ fn format_dev_project(
// TODO(konstin): The assumptions between this script (one repo) and ruff (pass in a bunch of // TODO(konstin): The assumptions between this script (one repo) and ruff (pass in a bunch of
// files) mismatch. // files) mismatch.
let options = BlackOptions::from_file(&files[0])?.to_py_format_options(); let black_options = BlackOptions::from_file(&files[0])?;
debug!( debug!(
parent: None, parent: None,
"Options for {}: {options:?}", "Options for {}: {black_options:?}",
files[0].display() files[0].display()
); );
@ -398,7 +398,7 @@ fn format_dev_project(
paths paths
.into_par_iter() .into_par_iter()
.map(|dir_entry| { .map(|dir_entry| {
let result = format_dir_entry(dir_entry, stability_check, write, &options); let result = format_dir_entry(dir_entry, stability_check, write, &black_options);
pb_span.pb_inc(1); pb_span.pb_inc(1);
result result
}) })
@ -447,7 +447,7 @@ fn format_dir_entry(
dir_entry: Result<DirEntry, ignore::Error>, dir_entry: Result<DirEntry, ignore::Error>,
stability_check: bool, stability_check: bool,
write: bool, write: bool,
options: &PyFormatOptions, options: &BlackOptions,
) -> anyhow::Result<(Result<Statistics, CheckFileError>, PathBuf), Error> { ) -> anyhow::Result<(Result<Statistics, CheckFileError>, PathBuf), Error> {
let dir_entry = match dir_entry.context("Iterating the files in the repository failed") { let dir_entry = match dir_entry.context("Iterating the files in the repository failed") {
Ok(dir_entry) => dir_entry, Ok(dir_entry) => dir_entry,
@ -460,27 +460,27 @@ fn format_dir_entry(
} }
let file = dir_entry.path().to_path_buf(); let file = dir_entry.path().to_path_buf();
let options = options.to_py_format_options(&file);
// Handle panics (mostly in `debug_assert!`) // Handle panics (mostly in `debug_assert!`)
let result = let result = match catch_unwind(|| format_dev_file(&file, stability_check, write, options)) {
match catch_unwind(|| format_dev_file(&file, stability_check, write, options.clone())) { Ok(result) => result,
Ok(result) => result, Err(panic) => {
Err(panic) => { if let Some(message) = panic.downcast_ref::<String>() {
if let Some(message) = panic.downcast_ref::<String>() { Err(CheckFileError::Panic {
Err(CheckFileError::Panic { message: message.clone(),
message: message.clone(), })
}) } else if let Some(&message) = panic.downcast_ref::<&str>() {
} else if let Some(&message) = panic.downcast_ref::<&str>() { Err(CheckFileError::Panic {
Err(CheckFileError::Panic { message: message.to_string(),
message: message.to_string(), })
}) } else {
} else { Err(CheckFileError::Panic {
Err(CheckFileError::Panic { // This should not happen, but it can
// This should not happen, but it can message: "(Panic didn't set a string message)".to_string(),
message: "(Panic didn't set a string message)".to_string(), })
})
}
} }
}; }
};
Ok((result, file)) Ok((result, file))
} }
@ -833,9 +833,8 @@ impl BlackOptions {
Self::from_toml(&fs::read_to_string(&path)?, repo) Self::from_toml(&fs::read_to_string(&path)?, repo)
} }
fn to_py_format_options(&self) -> PyFormatOptions { fn to_py_format_options(&self, file: &Path) -> PyFormatOptions {
let mut options = PyFormatOptions::default(); PyFormatOptions::from_extension(file)
options
.with_line_width( .with_line_width(
LineWidth::try_from(self.line_length).expect("Invalid line length limit"), LineWidth::try_from(self.line_length).expect("Invalid line length limit"),
) )
@ -843,8 +842,7 @@ impl BlackOptions {
MagicTrailingComma::Ignore MagicTrailingComma::Ignore
} else { } else {
MagicTrailingComma::Respect MagicTrailingComma::Respect
}); })
options
} }
} }
@ -868,7 +866,7 @@ mod tests {
"}; "};
let options = BlackOptions::from_toml(toml, Path::new("pyproject.toml")) let options = BlackOptions::from_toml(toml, Path::new("pyproject.toml"))
.unwrap() .unwrap()
.to_py_format_options(); .to_py_format_options(Path::new("code_inline.py"));
assert_eq!(options.line_width(), LineWidth::try_from(119).unwrap()); assert_eq!(options.line_width(), LineWidth::try_from(119).unwrap());
assert!(matches!( assert!(matches!(
options.magic_trailing_comma(), options.magic_trailing_comma(),
@ -887,7 +885,7 @@ mod tests {
"#}; "#};
let options = BlackOptions::from_toml(toml, Path::new("pyproject.toml")) let options = BlackOptions::from_toml(toml, Path::new("pyproject.toml"))
.unwrap() .unwrap()
.to_py_format_options(); .to_py_format_options(Path::new("code_inline.py"));
assert_eq!(options.line_width(), LineWidth::try_from(130).unwrap()); assert_eq!(options.line_width(), LineWidth::try_from(130).unwrap());
assert!(matches!( assert!(matches!(
options.magic_trailing_comma(), options.magic_trailing_comma(),

View file

@ -1,4 +1,5 @@
use ruff_text_size::{TextRange, TextSize}; use ruff_text_size::{TextRange, TextSize};
use std::path::Path;
pub mod all; pub mod all;
pub mod call_path; pub mod call_path;
@ -49,3 +50,36 @@ where
T::range(self) T::range(self)
} }
} }
/// The kind of Python source a file holds.
///
/// Use the `From<&Path>` conversion to derive it from a file name.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum PySourceType {
    /// A regular Python source file. This is the default and also the fallback
    /// for paths whose extension is missing or not recognized.
    #[default]
    Python,
    /// A stub file (`.pyi` extension).
    Stub,
    /// A Jupyter notebook (`.ipynb` extension).
    Jupyter,
}

impl PySourceType {
    /// Returns `true` if this is [`PySourceType::Python`].
    pub const fn is_python(&self) -> bool {
        matches!(self, PySourceType::Python)
    }

    /// Returns `true` if this is [`PySourceType::Stub`].
    pub const fn is_stub(&self) -> bool {
        matches!(self, PySourceType::Stub)
    }

    /// Returns `true` if this is [`PySourceType::Jupyter`].
    pub const fn is_jupyter(&self) -> bool {
        matches!(self, PySourceType::Jupyter)
    }
}

impl From<&Path> for PySourceType {
    /// Classifies `path` by its extension only; the file is never opened.
    ///
    /// `pyi` maps to [`PySourceType::Stub`] and `ipynb` to
    /// [`PySourceType::Jupyter`]. The comparison is exact (case-sensitive), and
    /// every other path, including one without an extension, is treated as
    /// [`PySourceType::Python`].
    fn from(path: &Path) -> Self {
        match path.extension() {
            Some(ext) if ext == "pyi" => PySourceType::Stub,
            Some(ext) if ext == "ipynb" => PySourceType::Jupyter,
            // Deliberately permissive: anything unrecognized is formatted as Python.
            _ => PySourceType::Python,
        }
    }
}

View file

@ -32,7 +32,7 @@ smallvec = { workspace = true }
thiserror = { workspace = true } thiserror = { workspace = true }
[dev-dependencies] [dev-dependencies]
ruff_formatter = { path = "../ruff_formatter", features = ["serde"]} ruff_formatter = { path = "../ruff_formatter", features = ["serde"] }
insta = { workspace = true, features = ["glob"] } insta = { workspace = true, features = ["glob"] }
serde = { workspace = true } serde = { workspace = true }
@ -43,8 +43,8 @@ similar = { workspace = true }
name = "ruff_python_formatter_fixtures" name = "ruff_python_formatter_fixtures"
path = "tests/fixtures.rs" path = "tests/fixtures.rs"
test = true test = true
required-features = [ "serde" ] required-features = ["serde"]
[features] [features]
serde = ["dep:serde", "ruff_formatter/serde", "ruff_source_file/serde"] serde = ["dep:serde", "ruff_formatter/serde", "ruff_source_file/serde", "ruff_python_ast/serde"]
default = ["serde"] default = ["serde"]

View file

@ -1,14 +1,14 @@
#![allow(clippy::print_stdout)] #![allow(clippy::print_stdout)]
use std::path::PathBuf; use std::path::{Path, PathBuf};
use anyhow::{bail, Context, Result}; use anyhow::{bail, Context, Result};
use clap::{command, Parser, ValueEnum}; use clap::{command, Parser, ValueEnum};
use ruff_python_parser::lexer::lex;
use ruff_python_parser::{parse_tokens, Mode};
use ruff_formatter::SourceCode; use ruff_formatter::SourceCode;
use ruff_python_index::CommentRangesBuilder; use ruff_python_index::CommentRangesBuilder;
use ruff_python_parser::lexer::lex;
use ruff_python_parser::{parse_tokens, Mode};
use crate::{format_node, PyFormatOptions}; use crate::{format_node, PyFormatOptions};
@ -37,7 +37,7 @@ pub struct Cli {
pub print_comments: bool, pub print_comments: bool,
} }
pub fn format_and_debug_print(input: &str, cli: &Cli) -> Result<String> { pub fn format_and_debug_print(input: &str, cli: &Cli, source_type: &Path) -> Result<String> {
let mut tokens = Vec::new(); let mut tokens = Vec::new();
let mut comment_ranges = CommentRangesBuilder::default(); let mut comment_ranges = CommentRangesBuilder::default();
@ -57,13 +57,9 @@ pub fn format_and_debug_print(input: &str, cli: &Cli) -> Result<String> {
let python_ast = let python_ast =
parse_tokens(tokens, Mode::Module, "<filename>").context("Syntax error in input")?; parse_tokens(tokens, Mode::Module, "<filename>").context("Syntax error in input")?;
let formatted = format_node( let options = PyFormatOptions::from_extension(source_type);
&python_ast, let formatted = format_node(&python_ast, &comment_ranges, input, options)
&comment_ranges, .context("Failed to format node")?;
input,
PyFormatOptions::default(),
)
.context("Failed to format node")?;
if cli.print_ir { if cli.print_ir {
println!("{}", formatted.document().display(SourceCode::new(input))); println!("{}", formatted.document().display(SourceCode::new(input)));
} }

View file

@ -255,6 +255,7 @@ mod tests {
use ruff_python_index::CommentRangesBuilder; use ruff_python_index::CommentRangesBuilder;
use ruff_python_parser::lexer::lex; use ruff_python_parser::lexer::lex;
use ruff_python_parser::{parse_tokens, Mode}; use ruff_python_parser::{parse_tokens, Mode};
use std::path::Path;
/// Very basic test intentionally kept very similar to the CLI /// Very basic test intentionally kept very similar to the CLI
#[test] #[test]
@ -321,15 +322,10 @@ with [
let comment_ranges = comment_ranges.finish(); let comment_ranges = comment_ranges.finish();
// Parse the AST. // Parse the AST.
let python_ast = parse_tokens(tokens, Mode::Module, "<filename>").unwrap(); let source_path = "code_inline.py";
let python_ast = parse_tokens(tokens, Mode::Module, source_path).unwrap();
let formatted = format_node( let options = PyFormatOptions::from_extension(Path::new(source_path));
&python_ast, let formatted = format_node(&python_ast, &comment_ranges, src, options).unwrap();
&comment_ranges,
src,
PyFormatOptions::default(),
)
.unwrap();
// Uncomment the `dbg` to print the IR. // Uncomment the `dbg` to print the IR.
// Use `dbg_write!(f, []) instead of `write!(f, [])` in your formatting code to print some IR // Use `dbg_write!(f, []) instead of `write!(f, [])` in your formatting code to print some IR

View file

@ -1,4 +1,5 @@
use std::io::{stdout, Read, Write}; use std::io::{stdout, Read, Write};
use std::path::Path;
use std::{fs, io}; use std::{fs, io};
use anyhow::{bail, Context, Result}; use anyhow::{bail, Context, Result};
@ -25,7 +26,8 @@ fn main() -> Result<()> {
); );
} }
let input = read_from_stdin()?; let input = read_from_stdin()?;
let formatted = format_and_debug_print(&input, &cli)?; // It seems reasonable to give this a dummy name
let formatted = format_and_debug_print(&input, &cli, Path::new("stdin.py"))?;
if cli.check { if cli.check {
if formatted == input { if formatted == input {
return Ok(()); return Ok(());
@ -37,7 +39,7 @@ fn main() -> Result<()> {
for file in &cli.files { for file in &cli.files {
let input = fs::read_to_string(file) let input = fs::read_to_string(file)
.with_context(|| format!("Could not read {}: ", file.display()))?; .with_context(|| format!("Could not read {}: ", file.display()))?;
let formatted = format_and_debug_print(&input, &cli)?; let formatted = format_and_debug_print(&input, &cli, file)?;
match cli.emit { match cli.emit {
Some(Emit::Stdout) => stdout().lock().write_all(formatted.as_bytes())?, Some(Emit::Stdout) => stdout().lock().write_all(formatted.as_bytes())?,
None | Some(Emit::Files) => { None | Some(Emit::Files) => {

View file

@ -1,5 +1,7 @@
use ruff_formatter::printer::{LineEnding, PrinterOptions}; use ruff_formatter::printer::{LineEnding, PrinterOptions};
use ruff_formatter::{FormatOptions, IndentStyle, LineWidth}; use ruff_formatter::{FormatOptions, IndentStyle, LineWidth};
use ruff_python_ast::PySourceType;
use std::path::Path;
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
#[cfg_attr( #[cfg_attr(
@ -8,6 +10,9 @@ use ruff_formatter::{FormatOptions, IndentStyle, LineWidth};
serde(default) serde(default)
)] )]
pub struct PyFormatOptions { pub struct PyFormatOptions {
/// Whether we're in a `.py` file or `.pyi` file, which have different rules
source_type: PySourceType,
/// Specifies the indent style: /// Specifies the indent style:
/// * Either a tab /// * Either a tab
/// * or a specific amount of spaces /// * or a specific amount of spaces
@ -28,7 +33,31 @@ fn default_line_width() -> LineWidth {
LineWidth::try_from(88).unwrap() LineWidth::try_from(88).unwrap()
} }
impl Default for PyFormatOptions {
    /// Builds the baseline options: the default source type, an indent of four
    /// spaces, a line width of 88, and the default quote style and
    /// magic-trailing-comma handling.
    fn default() -> Self {
        let indent_style = IndentStyle::Space(4);
        let line_width = LineWidth::try_from(88).unwrap();

        Self {
            source_type: PySourceType::default(),
            indent_style,
            line_width,
            quote_style: QuoteStyle::default(),
            magic_trailing_comma: MagicTrailingComma::default(),
        }
    }
}
impl PyFormatOptions { impl PyFormatOptions {
/// Otherwise sets the defaults. Returns none if the extension is unknown
pub fn from_extension(path: &Path) -> Self {
Self::from_source_type(PySourceType::from(path))
}
/// Creates options with the given `source_type` and the `Default` value for
/// every other field.
pub fn from_source_type(source_type: PySourceType) -> Self {
    let defaults = Self::default();
    Self {
        source_type,
        ..defaults
    }
}
pub fn magic_trailing_comma(&self) -> MagicTrailingComma { pub fn magic_trailing_comma(&self) -> MagicTrailingComma {
self.magic_trailing_comma self.magic_trailing_comma
} }
@ -42,17 +71,20 @@ impl PyFormatOptions {
self self
} }
pub fn with_magic_trailing_comma(&mut self, trailing_comma: MagicTrailingComma) -> &mut Self { #[must_use]
pub fn with_magic_trailing_comma(mut self, trailing_comma: MagicTrailingComma) -> Self {
self.magic_trailing_comma = trailing_comma; self.magic_trailing_comma = trailing_comma;
self self
} }
pub fn with_indent_style(&mut self, indent_style: IndentStyle) -> &mut Self { #[must_use]
pub fn with_indent_style(mut self, indent_style: IndentStyle) -> Self {
self.indent_style = indent_style; self.indent_style = indent_style;
self self
} }
pub fn with_line_width(&mut self, line_width: LineWidth) -> &mut Self { #[must_use]
pub fn with_line_width(mut self, line_width: LineWidth) -> Self {
self.line_width = line_width; self.line_width = line_width;
self self
} }
@ -77,17 +109,6 @@ impl FormatOptions for PyFormatOptions {
} }
} }
impl Default for PyFormatOptions {
fn default() -> Self {
Self {
indent_style: IndentStyle::Space(4),
line_width: LineWidth::try_from(88).unwrap(),
quote_style: QuoteStyle::default(),
magic_trailing_comma: MagicTrailingComma::default(),
}
}
}
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq)] #[derive(Copy, Clone, Debug, Default, Eq, PartialEq)]
#[cfg_attr( #[cfg_attr(
feature = "serde", feature = "serde",

View file

@ -17,7 +17,7 @@ fn black_compatibility() {
let reader = BufReader::new(options_file); let reader = BufReader::new(options_file);
serde_json::from_reader(reader).expect("Options to be a valid Json file") serde_json::from_reader(reader).expect("Options to be a valid Json file")
} else { } else {
PyFormatOptions::default() PyFormatOptions::from_extension(input_path)
}; };
let printed = format_module(&content, options.clone()).unwrap_or_else(|err| { let printed = format_module(&content, options.clone()).unwrap_or_else(|err| {
@ -106,11 +106,11 @@ fn format() {
let test_file = |input_path: &Path| { let test_file = |input_path: &Path| {
let content = fs::read_to_string(input_path).unwrap(); let content = fs::read_to_string(input_path).unwrap();
let options = PyFormatOptions::default(); let options = PyFormatOptions::from_extension(input_path);
let printed = format_module(&content, options.clone()).expect("Formatting to succeed"); let printed = format_module(&content, options.clone()).expect("Formatting to succeed");
let formatted_code = printed.as_code(); let formatted_code = printed.as_code();
ensure_stability_when_formatting_twice(formatted_code, options, input_path); ensure_stability_when_formatting_twice(formatted_code, options.clone(), input_path);
let mut snapshot = format!("## Input\n{}", CodeFrame::new("py", &content)); let mut snapshot = format!("## Input\n{}", CodeFrame::new("py", &content));
@ -139,7 +139,6 @@ fn format() {
.unwrap(); .unwrap();
} }
} else { } else {
let options = PyFormatOptions::default();
let printed = format_module(&content, options.clone()).expect("Formatting to succeed"); let printed = format_module(&content, options.clone()).expect("Formatting to succeed");
let formatted_code = printed.as_code(); let formatted_code = printed.as_code();

View file

@ -21,6 +21,7 @@ use ruff::rules::{
use ruff::settings::configuration::Configuration; use ruff::settings::configuration::Configuration;
use ruff::settings::options::Options; use ruff::settings::options::Options;
use ruff::settings::{defaults, flags, Settings}; use ruff::settings::{defaults, flags, Settings};
use ruff_python_ast::PySourceType;
use ruff_python_codegen::Stylist; use ruff_python_codegen::Stylist;
use ruff_python_formatter::{format_module, format_node, PyFormatOptions}; use ruff_python_formatter::{format_module, format_node, PyFormatOptions};
use ruff_python_index::{CommentRangesBuilder, Indexer}; use ruff_python_index::{CommentRangesBuilder, Indexer};
@ -262,7 +263,9 @@ impl Workspace {
} }
pub fn format(&self, contents: &str) -> Result<String, Error> { pub fn format(&self, contents: &str) -> Result<String, Error> {
let printed = format_module(contents, PyFormatOptions::default()).map_err(into_error)?; // TODO(konstin): Add an options for py/pyi to the UI (1/2)
let options = PyFormatOptions::from_source_type(PySourceType::default());
let printed = format_module(contents, options).map_err(into_error)?;
Ok(printed.into_code()) Ok(printed.into_code())
} }
@ -278,13 +281,10 @@ impl Workspace {
let comment_ranges = comment_ranges.finish(); let comment_ranges = comment_ranges.finish();
let module = parse_tokens(tokens, Mode::Module, ".").map_err(into_error)?; let module = parse_tokens(tokens, Mode::Module, ".").map_err(into_error)?;
let formatted = format_node( // TODO(konstin): Add an options for py/pyi to the UI (2/2)
&module, let options = PyFormatOptions::from_source_type(PySourceType::default());
&comment_ranges, let formatted =
contents, format_node(&module, &comment_ranges, contents, options).map_err(into_error)?;
PyFormatOptions::default(),
)
.map_err(into_error)?;
Ok(format!("{formatted}")) Ok(format!("{formatted}"))
} }