From fb1d1e3241c6182b2e1e7e799cf11a4cf3cd3fd5 Mon Sep 17 00:00:00 2001 From: Alex Waygood Date: Wed, 16 Oct 2024 07:58:24 +0100 Subject: [PATCH] [red-knot] Simplify some branches in `infer_subscript_expression` (#13762) ## Summary Just a small simplification to remove some unnecessary complexity here. Rather than using separate branches for subscript expressions involving boolean literals, we can simply convert them to integer literals and reuse the logic in the `IntLiteral` branches. ## Test Plan `cargo test -p red_knot_python_semantic` --- .../src/types/infer.rs | 74 ++++++++----------- 1 file changed, 29 insertions(+), 45 deletions(-) diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index d83b5a8ba6..06b24d7ae0 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -2855,7 +2855,15 @@ impl<'db> TypeInferenceBuilder<'db> { let value_ty = self.infer_expression(value); let slice_ty = self.infer_expression(slice); + self.infer_subscript_expression_types(value, value_ty, slice_ty) + } + fn infer_subscript_expression_types( + &mut self, + value_node: &ast::Expr, + value_ty: Type<'db>, + slice_ty: Type<'db>, + ) -> Type<'db> { match (value_ty, slice_ty) { // Ex) Given `("a", "b", "c", "d")[1]`, return `"b"` (Type::Tuple(tuple_ty), Type::IntLiteral(int)) if int >= 0 => { @@ -2865,7 +2873,7 @@ impl<'db> TypeInferenceBuilder<'db> { .and_then(|index| elements.get(index).copied()) .unwrap_or_else(|| { self.tuple_index_out_of_bounds_diagnostic( - (&**value).into(), + value_node.into(), value_ty, elements.len(), int, @@ -2882,7 +2890,7 @@ impl<'db> TypeInferenceBuilder<'db> { .and_then(|index| elements.get(index).copied()) .unwrap_or_else(|| { self.tuple_index_out_of_bounds_diagnostic( - (&**value).into(), + value_node.into(), value_ty, elements.len(), int, @@ -2891,19 +2899,11 @@ impl<'db> TypeInferenceBuilder<'db> { }) } // Ex) Given `("a", "b", "c", "d")[True]`, return `"b"` - (Type::Tuple(tuple_ty), Type::BooleanLiteral(bool)) => { - let elements = tuple_ty.elements(self.db); - let int = i64::from(bool); - elements.get(usize::from(bool)).copied().unwrap_or_else(|| { - self.tuple_index_out_of_bounds_diagnostic( - (&**value).into(), - value_ty, - elements.len(), - int, - ); - Type::Unknown - }) - } + (Type::Tuple(_), Type::BooleanLiteral(bool)) => self.infer_subscript_expression_types( + value_node, + value_ty, + Type::IntLiteral(i64::from(bool)), + ), // Ex) Given `"value"[1]`, return `"a"` (Type::StringLiteral(literal_ty), Type::IntLiteral(int)) if int >= 0 => { let literal_value = literal_ty.value(self.db); @@ -2918,7 +2918,7 @@ impl<'db> TypeInferenceBuilder<'db> { }) .unwrap_or_else(|| { self.string_index_out_of_bounds_diagnostic( - (&**value).into(), + value_node.into(), value_ty, literal_value.chars().count(), int, @@ -2941,7 +2941,7 @@ impl<'db> TypeInferenceBuilder<'db> { }) .unwrap_or_else(|| { self.string_index_out_of_bounds_diagnostic( - (&**value).into(), + value_node.into(), value_ty, literal_value.chars().count(), int, @@ -2950,28 +2950,12 @@ impl<'db> TypeInferenceBuilder<'db> { }) } // Ex) Given `"value"[True]`, return `"a"` - (Type::StringLiteral(literal_ty), Type::BooleanLiteral(bool)) => { - let literal_value = literal_ty.value(self.db); - let int = i64::from(bool); - literal_value - .chars() - .nth(usize::from(bool)) - .map(|ch| { - Type::StringLiteral(StringLiteralType::new( - self.db, - ch.to_string().into_boxed_str(), - )) - }) - .unwrap_or_else(|| { - self.string_index_out_of_bounds_diagnostic( - (&**value).into(), - value_ty, - literal_value.chars().count(), - int, - ); - Type::Unknown - }) - } + (Type::StringLiteral(_), Type::BooleanLiteral(bool)) => self + .infer_subscript_expression_types( + value_node, + value_ty, + Type::IntLiteral(i64::from(bool)), + ), (value_ty, slice_ty) => { // Resolve the value to its class. let value_meta_ty = value_ty.to_meta_type(self.db); @@ -2983,10 +2967,10 @@ impl<'db> TypeInferenceBuilder<'db> { if !dunder_getitem_method.is_unbound() { return dunder_getitem_method .call(self.db, &[slice_ty]) - .return_ty_result(self.db, value.as_ref().into(), self) + .return_ty_result(self.db, value_node.into(), self) .unwrap_or_else(|err| { self.add_diagnostic( - (&**value).into(), + value_node.into(), "call-non-callable", format_args!( "Method `__getitem__` of type `{}` is not callable on object of type `{}`", @@ -3012,10 +2996,10 @@ impl<'db> TypeInferenceBuilder<'db> { if !dunder_class_getitem_method.is_unbound() { return dunder_class_getitem_method .call(self.db, &[slice_ty]) - .return_ty_result(self.db, value.as_ref().into(), self) + .return_ty_result(self.db, value_node.into(), self) .unwrap_or_else(|err| { self.add_diagnostic( - (&**value).into(), + value_node.into(), "call-non-callable", format_args!( "Method `__class_getitem__` of type `{}` is not callable on object of type `{}`", @@ -3033,12 +3017,12 @@ impl<'db> TypeInferenceBuilder<'db> { } self.non_subscriptable_diagnostic( - (&**value).into(), + value_node.into(), value_ty, "__class_getitem__", ); } else { - self.non_subscriptable_diagnostic((&**value).into(), value_ty, "__getitem__"); + self.non_subscriptable_diagnostic(value_node.into(), value_ty, "__getitem__"); } Type::Unknown