diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index d83b5a8ba6..06b24d7ae0 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -2855,7 +2855,15 @@ impl<'db> TypeInferenceBuilder<'db> { let value_ty = self.infer_expression(value); let slice_ty = self.infer_expression(slice); + self.infer_subscript_expression_types(value, value_ty, slice_ty) + } + fn infer_subscript_expression_types( + &mut self, + value_node: &ast::Expr, + value_ty: Type<'db>, + slice_ty: Type<'db>, + ) -> Type<'db> { match (value_ty, slice_ty) { // Ex) Given `("a", "b", "c", "d")[1]`, return `"b"` (Type::Tuple(tuple_ty), Type::IntLiteral(int)) if int >= 0 => { @@ -2865,7 +2873,7 @@ impl<'db> TypeInferenceBuilder<'db> { .and_then(|index| elements.get(index).copied()) .unwrap_or_else(|| { self.tuple_index_out_of_bounds_diagnostic( - (&**value).into(), + value_node.into(), value_ty, elements.len(), int, @@ -2882,7 +2890,7 @@ impl<'db> TypeInferenceBuilder<'db> { .and_then(|index| elements.get(index).copied()) .unwrap_or_else(|| { self.tuple_index_out_of_bounds_diagnostic( - (&**value).into(), + value_node.into(), value_ty, elements.len(), int, @@ -2891,19 +2899,11 @@ impl<'db> TypeInferenceBuilder<'db> { }) } // Ex) Given `("a", "b", "c", "d")[True]`, return `"b"` - (Type::Tuple(tuple_ty), Type::BooleanLiteral(bool)) => { - let elements = tuple_ty.elements(self.db); - let int = i64::from(bool); - elements.get(usize::from(bool)).copied().unwrap_or_else(|| { - self.tuple_index_out_of_bounds_diagnostic( - (&**value).into(), - value_ty, - elements.len(), - int, - ); - Type::Unknown - }) - } + (Type::Tuple(_), Type::BooleanLiteral(bool)) => self.infer_subscript_expression_types( + value_node, + value_ty, + Type::IntLiteral(i64::from(bool)), + ), // Ex) Given `"value"[1]`, return `"a"` (Type::StringLiteral(literal_ty), Type::IntLiteral(int)) if int >= 0 => { let literal_value = literal_ty.value(self.db); @@ -2918,7 +2918,7 @@ impl<'db> TypeInferenceBuilder<'db> { }) .unwrap_or_else(|| { self.string_index_out_of_bounds_diagnostic( - (&**value).into(), + value_node.into(), value_ty, literal_value.chars().count(), int, @@ -2941,7 +2941,7 @@ impl<'db> TypeInferenceBuilder<'db> { }) .unwrap_or_else(|| { self.string_index_out_of_bounds_diagnostic( - (&**value).into(), + value_node.into(), value_ty, literal_value.chars().count(), int, @@ -2950,28 +2950,12 @@ impl<'db> TypeInferenceBuilder<'db> { }) } // Ex) Given `"value"[True]`, return `"a"` - (Type::StringLiteral(literal_ty), Type::BooleanLiteral(bool)) => { - let literal_value = literal_ty.value(self.db); - let int = i64::from(bool); - literal_value - .chars() - .nth(usize::from(bool)) - .map(|ch| { - Type::StringLiteral(StringLiteralType::new( - self.db, - ch.to_string().into_boxed_str(), - )) - }) - .unwrap_or_else(|| { - self.string_index_out_of_bounds_diagnostic( - (&**value).into(), - value_ty, - literal_value.chars().count(), - int, - ); - Type::Unknown - }) - } + (Type::StringLiteral(_), Type::BooleanLiteral(bool)) => self + .infer_subscript_expression_types( + value_node, + value_ty, + Type::IntLiteral(i64::from(bool)), + ), (value_ty, slice_ty) => { // Resolve the value to its class. let value_meta_ty = value_ty.to_meta_type(self.db); @@ -2983,10 +2967,10 @@ impl<'db> TypeInferenceBuilder<'db> { if !dunder_getitem_method.is_unbound() { return dunder_getitem_method .call(self.db, &[slice_ty]) - .return_ty_result(self.db, value.as_ref().into(), self) + .return_ty_result(self.db, value_node.into(), self) .unwrap_or_else(|err| { self.add_diagnostic( - (&**value).into(), + value_node.into(), "call-non-callable", format_args!( "Method `__getitem__` of type `{}` is not callable on object of type `{}`", @@ -3012,10 +2996,10 @@ impl<'db> TypeInferenceBuilder<'db> { if !dunder_class_getitem_method.is_unbound() { return dunder_class_getitem_method .call(self.db, &[slice_ty]) - .return_ty_result(self.db, value.as_ref().into(), self) + .return_ty_result(self.db, value_node.into(), self) .unwrap_or_else(|err| { self.add_diagnostic( - (&**value).into(), + value_node.into(), "call-non-callable", format_args!( "Method `__class_getitem__` of type `{}` is not callable on object of type `{}`", @@ -3033,12 +3017,12 @@ impl<'db> TypeInferenceBuilder<'db> { } self.non_subscriptable_diagnostic( - (&**value).into(), + value_node.into(), value_ty, "__class_getitem__", ); } else { - self.non_subscriptable_diagnostic((&**value).into(), value_ty, "__getitem__"); + self.non_subscriptable_diagnostic(value_node.into(), value_ty, "__getitem__"); } Type::Unknown