Skip to content

Commit bdfe12d

Browse files
committed
Check for allow(..) attributes for case diagnostic
1 parent 245e1b5 commit bdfe12d

File tree

1 file changed

+91
-25
lines changed

1 file changed

+91
-25
lines changed

crates/hir_ty/src/diagnostics/decl_check.rs

Lines changed: 91 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ use hir_def::{
1616
adt::VariantData,
1717
expr::{Pat, PatId},
1818
src::HasSource,
19-
AdtId, ConstId, EnumId, FunctionId, Lookup, ModuleDefId, StaticId, StructId,
19+
AdtId, AttrDefId, ConstId, EnumId, FunctionId, Lookup, ModuleDefId, StaticId, StructId,
2020
};
2121
use hir_expand::{
2222
diagnostics::DiagnosticSink,
@@ -32,6 +32,12 @@ use crate::{
3232
diagnostics::{decl_check::case_conv::*, CaseType, IncorrectCase},
3333
};
3434

35+
mod allow {
36+
pub const NON_SNAKE_CASE: &str = "non_snake_case";
37+
pub const NON_UPPER_CASE_GLOBAL: &str = "non_upper_case_globals";
38+
pub const NON_CAMEL_CASE_TYPES: &str = "non_camel_case_types";
39+
}
40+
3541
pub(super) struct DeclValidator<'a, 'b: 'a> {
3642
owner: ModuleDefId,
3743
sink: &'a mut DiagnosticSink<'b>,
@@ -72,11 +78,29 @@ impl<'a, 'b> DeclValidator<'a, 'b> {
7278
}
7379
}
7480

81+
/// Checks whether not following the convention is allowed for this item.
82+
///
83+
/// Currently this method doesn't check parent attributes.
84+
fn allowed(&self, db: &dyn HirDatabase, id: AttrDefId, allow_name: &str) -> bool {
85+
db.attrs(id).by_key("allow").tt_values().any(|tt| tt.to_string().contains(allow_name))
86+
}
87+
7588
fn validate_func(&mut self, db: &dyn HirDatabase, func: FunctionId) {
7689
let data = db.function_data(func);
7790
let body = db.body(func.into());
7891

79-
// 1. Check the function name.
92+
// 1. Recursively validate inner scope items, such as static variables and constants.
93+
for (item_id, _) in body.item_scope.values() {
94+
let mut validator = DeclValidator::new(item_id, self.sink);
95+
validator.validate_item(db);
96+
}
97+
98+
// 2. Check whether non-snake case identifiers are allowed for this function.
99+
if self.allowed(db, func.into(), allow::NON_SNAKE_CASE) {
100+
return;
101+
}
102+
103+
// 2. Check the function name.
80104
let function_name = data.name.to_string();
81105
let fn_name_replacement = if let Some(new_name) = to_lower_snake_case(&function_name) {
82106
let replacement = Replacement {
@@ -89,7 +113,7 @@ impl<'a, 'b> DeclValidator<'a, 'b> {
89113
None
90114
};
91115

92-
// 2. Check the param names.
116+
// 3. Check the param names.
93117
let mut fn_param_replacements = Vec::new();
94118

95119
for pat_id in body.params.iter().cloned() {
@@ -111,7 +135,7 @@ impl<'a, 'b> DeclValidator<'a, 'b> {
111135
}
112136
}
113137

114-
// 3. Check the patterns inside the function body.
138+
// 4. Check the patterns inside the function body.
115139
let mut pats_replacements = Vec::new();
116140

117141
for (pat_idx, pat) in body.pats.iter() {
@@ -136,20 +160,14 @@ impl<'a, 'b> DeclValidator<'a, 'b> {
136160
}
137161
}
138162

139-
// 4. If there is at least one element to spawn a warning on, go to the source map and generate a warning.
163+
// 5. If there is at least one element to spawn a warning on, go to the source map and generate a warning.
140164
self.create_incorrect_case_diagnostic_for_func(
141165
func,
142166
db,
143167
fn_name_replacement,
144168
fn_param_replacements,
145169
);
146170
self.create_incorrect_case_diagnostic_for_variables(func, db, pats_replacements);
147-
148-
// 5. Recursively validate inner scope items, such as static variables and constants.
149-
for (item_id, _) in body.item_scope.values() {
150-
let mut validator = DeclValidator::new(item_id, self.sink);
151-
validator.validate_item(db);
152-
}
153171
}
154172

155173
/// Given the information about incorrect names in the function declaration, looks up into the source code
@@ -312,6 +330,10 @@ impl<'a, 'b> DeclValidator<'a, 'b> {
312330
fn validate_struct(&mut self, db: &dyn HirDatabase, struct_id: StructId) {
313331
let data = db.struct_data(struct_id);
314332

333+
let non_camel_case_allowed =
334+
self.allowed(db, struct_id.into(), allow::NON_CAMEL_CASE_TYPES);
335+
let non_snake_case_allowed = self.allowed(db, struct_id.into(), allow::NON_SNAKE_CASE);
336+
315337
// 1. Check the structure name.
316338
let struct_name = data.name.to_string();
317339
let struct_name_replacement = if let Some(new_name) = to_camel_case(&struct_name) {
@@ -320,24 +342,30 @@ impl<'a, 'b> DeclValidator<'a, 'b> {
320342
suggested_text: new_name,
321343
expected_case: CaseType::UpperCamelCase,
322344
};
323-
Some(replacement)
345+
if !non_camel_case_allowed {
346+
Some(replacement)
347+
} else {
348+
None
349+
}
324350
} else {
325351
None
326352
};
327353

328354
// 2. Check the field names.
329355
let mut struct_fields_replacements = Vec::new();
330356

331-
if let VariantData::Record(fields) = data.variant_data.as_ref() {
332-
for (_, field) in fields.iter() {
333-
let field_name = field.name.to_string();
334-
if let Some(new_name) = to_lower_snake_case(&field_name) {
335-
let replacement = Replacement {
336-
current_name: field.name.clone(),
337-
suggested_text: new_name,
338-
expected_case: CaseType::LowerSnakeCase,
339-
};
340-
struct_fields_replacements.push(replacement);
357+
if !non_snake_case_allowed {
358+
if let VariantData::Record(fields) = data.variant_data.as_ref() {
359+
for (_, field) in fields.iter() {
360+
let field_name = field.name.to_string();
361+
if let Some(new_name) = to_lower_snake_case(&field_name) {
362+
let replacement = Replacement {
363+
current_name: field.name.clone(),
364+
suggested_text: new_name,
365+
expected_case: CaseType::LowerSnakeCase,
366+
};
367+
struct_fields_replacements.push(replacement);
368+
}
341369
}
342370
}
343371
}
@@ -442,7 +470,12 @@ impl<'a, 'b> DeclValidator<'a, 'b> {
442470
fn validate_enum(&mut self, db: &dyn HirDatabase, enum_id: EnumId) {
443471
let data = db.enum_data(enum_id);
444472

445-
// 1. Check the enum name.
473+
// 1. Check whether non-camel case names are allowed for this enum.
474+
if self.allowed(db, enum_id.into(), allow::NON_CAMEL_CASE_TYPES) {
475+
return;
476+
}
477+
478+
// 2. Check the enum name.
446479
let enum_name = data.name.to_string();
447480
let enum_name_replacement = if let Some(new_name) = to_camel_case(&enum_name) {
448481
let replacement = Replacement {
@@ -455,7 +488,7 @@ impl<'a, 'b> DeclValidator<'a, 'b> {
455488
None
456489
};
457490

458-
// 2. Check the field names.
491+
// 3. Check the field names.
459492
let mut enum_fields_replacements = Vec::new();
460493

461494
for (_, variant) in data.variants.iter() {
@@ -470,7 +503,7 @@ impl<'a, 'b> DeclValidator<'a, 'b> {
470503
}
471504
}
472505

473-
// 3. If there is at least one element to spawn a warning on, go to the source map and generate a warning.
506+
// 4. If there is at least one element to spawn a warning on, go to the source map and generate a warning.
474507
self.create_incorrect_case_diagnostic_for_enum(
475508
enum_id,
476509
db,
@@ -572,6 +605,10 @@ impl<'a, 'b> DeclValidator<'a, 'b> {
572605
fn validate_const(&mut self, db: &dyn HirDatabase, const_id: ConstId) {
573606
let data = db.const_data(const_id);
574607

608+
if self.allowed(db, const_id.into(), allow::NON_UPPER_CASE_GLOBAL) {
609+
return;
610+
}
611+
575612
let name = match &data.name {
576613
Some(name) => name,
577614
None => return,
@@ -612,6 +649,10 @@ impl<'a, 'b> DeclValidator<'a, 'b> {
612649
fn validate_static(&mut self, db: &dyn HirDatabase, static_id: StaticId) {
613650
let data = db.static_data(static_id);
614651

652+
if self.allowed(db, static_id.into(), allow::NON_UPPER_CASE_GLOBAL) {
653+
return;
654+
}
655+
615656
let name = match &data.name {
616657
Some(name) => name,
617658
None => return,
@@ -854,4 +895,29 @@ fn main() {
854895
"#,
855896
);
856897
}
898+
899+
#[test]
900+
fn allow_attributes() {
901+
check_diagnostics(
902+
r#"
903+
#[allow(non_snake_case)]
904+
fn NonSnakeCaseName(SOME_VAR: u8) -> u8{
905+
let OtherVar = SOME_VAR + 1;
906+
OtherVar
907+
}
908+
909+
#[allow(non_snake_case, non_camel_case_types)]
910+
pub struct some_type {
911+
SOME_FIELD: u8,
912+
SomeField: u16,
913+
}
914+
915+
#[allow(non_upper_case_globals)]
916+
pub const some_const: u8 = 10;
917+
918+
#[allow(non_upper_case_globals)]
919+
pub static SomeStatic: u8 = 10;
920+
"#,
921+
);
922+
}
857923
}

0 commit comments

Comments
 (0)