refactor hir-ty::diagnostics::decl_check for function bodies

This commit is contained in:
davidsemakula 2024-02-03 16:37:58 +03:00
parent c0071ace5a
commit 23aa872f3c

View file

@ -16,12 +16,9 @@ mod case_conv;
use std::fmt; use std::fmt;
use hir_def::{ use hir_def::{
data::adt::VariantData, data::adt::VariantData, db::DefDatabase, hir::Pat, src::HasSource, AdtId, AttrDefId, ConstId,
db::DefDatabase, EnumId, FunctionId, ItemContainerId, Lookup, ModuleDefId, ModuleId, StaticId, StructId,
hir::{Pat, PatId}, TraitId, TypeAliasId,
src::HasSource,
AdtId, AttrDefId, ConstId, EnumId, FunctionId, ItemContainerId, Lookup, ModuleDefId, ModuleId,
StaticId, StructId, TraitId, TypeAliasId,
}; };
use hir_expand::{ use hir_expand::{
name::{AsName, Name}, name::{AsName, Name},
@ -298,11 +295,9 @@ impl<'a> DeclValidator<'a> {
return; return;
} }
// Check whether function is an associated item of a trait implementation
let is_trait_impl_assoc_fn = self.is_trait_impl_container(container);
// Check the function name. // Check the function name.
if !is_trait_impl_assoc_fn { // Skipped if function is an associated item of a trait implementation.
if !self.is_trait_impl_container(container) {
let data = self.db.function_data(func); let data = self.db.function_data(func);
self.create_incorrect_case_diagnostic_for_item_name( self.create_incorrect_case_diagnostic_for_item_name(
func, func,
@ -315,82 +310,73 @@ impl<'a> DeclValidator<'a> {
} }
// Check the patterns inside the function body. // Check the patterns inside the function body.
// This includes function parameters if it's not an associated function self.validate_func_body(func);
// of a trait implementation. }
/// Check incorrect names for patterns inside the function body.
/// This includes function parameters except for trait implementation associated functions.
fn validate_func_body(&mut self, func: FunctionId) {
// Check whether function is an associated item of a trait implementation
let container = func.lookup(self.db.upcast()).container;
let is_trait_impl_assoc_fn = self.is_trait_impl_container(container);
let body = self.db.body(func.into()); let body = self.db.body(func.into());
let pats_replacements = body let mut pats_replacements = body
.pats .pats
.iter() .iter()
.filter_map(|(pat_id, pat)| match pat { .filter_map(|(pat_id, pat)| match pat {
Pat::Bind { id, .. } => { Pat::Bind { id, .. } => {
// Filter out parameters if it's an associated function // Filter out parameters for trait implementation associated functions.
// of a trait implementation.
if is_trait_impl_assoc_fn if is_trait_impl_assoc_fn
&& body.params.iter().any(|param_id| *param_id == pat_id) && body.params.iter().any(|param_id| *param_id == pat_id)
{ {
cov_mark::hit!(trait_impl_assoc_func_param_incorrect_case_ignored); cov_mark::hit!(trait_impl_assoc_func_param_incorrect_case_ignored);
None None
} else { } else {
Some((pat_id, &body.bindings[*id].name)) let bind_name = &body.bindings[*id].name;
let replacement = Replacement {
current_name: bind_name.clone(),
suggested_text: to_lower_snake_case(&bind_name.to_smol_str())?,
expected_case: CaseType::LowerSnakeCase,
};
Some((pat_id, replacement))
} }
} }
_ => None, _ => None,
}) })
.filter_map(|(id, bind_name)| { .peekable();
Some((
id,
Replacement {
current_name: bind_name.clone(),
suggested_text: to_lower_snake_case(
&bind_name.display(self.db.upcast()).to_string(),
)?,
expected_case: CaseType::LowerSnakeCase,
},
))
})
.collect();
self.create_incorrect_case_diagnostic_for_func_variables(func, pats_replacements);
}
/// Given the information about incorrect variable names, looks up into the source code
/// for exact locations and adds diagnostics into the sink.
fn create_incorrect_case_diagnostic_for_func_variables(
&mut self,
func: FunctionId,
pats_replacements: Vec<(PatId, Replacement)>,
) {
// XXX: only look at source_map if we do have missing fields // XXX: only look at source_map if we do have missing fields
if pats_replacements.is_empty() { if pats_replacements.peek().is_none() {
return; return;
} }
let (_, source_map) = self.db.body_with_source_map(func.into()); let (_, source_map) = self.db.body_with_source_map(func.into());
for (id, replacement) in pats_replacements { for (id, replacement) in pats_replacements {
if let Ok(source_ptr) = source_map.pat_syntax(id) { let Ok(source_ptr) = source_map.pat_syntax(id) else {
if let Some(ptr) = source_ptr.value.cast::<ast::IdentPat>() { continue;
};
let Some(ptr) = source_ptr.value.cast::<ast::IdentPat>() else {
continue;
};
let root = source_ptr.file_syntax(self.db.upcast()); let root = source_ptr.file_syntax(self.db.upcast());
let ident_pat = ptr.to_node(&root); let ident_pat = ptr.to_node(&root);
let parent = match ident_pat.syntax().parent() { let Some(parent) = ident_pat.syntax().parent() else {
Some(parent) => parent, continue;
None => continue,
}; };
let is_param = ast::Param::can_cast(parent.kind()); let is_param = ast::Param::can_cast(parent.kind());
// We have to check that it's either `let var = ...` or `var @ Variant(_)` statement, // We have to check that it's either `let var = ...` or `var @ Variant(_)` statement,
// because e.g. match arms are patterns as well. // because e.g. match arms are patterns as well.
// In other words, we check that it's a named variable binding. // In other words, we check that it's a named variable binding.
let is_binding = ast::LetStmt::can_cast(parent.kind()) let is_binding = ast::LetStmt::can_cast(parent.kind())
|| (ast::MatchArm::can_cast(parent.kind()) || (ast::MatchArm::can_cast(parent.kind()) && ident_pat.at_token().is_some());
&& ident_pat.at_token().is_some());
if !(is_param || is_binding) { if !(is_param || is_binding) {
// This pattern is not an actual variable declaration, e.g. `Some(val) => {..}` match arm. // This pattern is not an actual variable declaration, e.g. `Some(val) => {..}` match arm.
continue; continue;
} }
let ident_type = let ident_type = if is_param { IdentType::Parameter } else { IdentType::Variable };
if is_param { IdentType::Parameter } else { IdentType::Variable };
self.create_incorrect_case_diagnostic_for_ast_node( self.create_incorrect_case_diagnostic_for_ast_node(
replacement, replacement,
@ -400,8 +386,6 @@ impl<'a> DeclValidator<'a> {
); );
} }
} }
}
}
fn validate_struct(&mut self, struct_id: StructId) { fn validate_struct(&mut self, struct_id: StructId) {
let data = self.db.struct_data(struct_id); let data = self.db.struct_data(struct_id);