Implemented prototype of refinement types.

This commit is contained in:
Eric Traut 2024-08-19 20:45:48 -07:00
parent 243983220e
commit e08c7a1e0f
40 changed files with 6946 additions and 112 deletions

View file

@ -670,7 +670,10 @@ export class Binder extends ParseTreeWalker {
// Skip if we're in an 'Annotated' annotation because this creates
// problems for "No Return" return type analysis when annotation
// evaluation is deferred.
if (!this._isInAnnotatedAnnotation) {
if (
!this._isInAnnotatedAnnotation &&
!ParseTreeUtils.isNodeContainedWithinNodeType(node, ParseNodeType.StringList)
) {
this._createCallFlowNode(node);
}
}

View file

@ -9,19 +9,25 @@
*/
import { assert } from '../common/debug';
import { RefinementExpr, RefinementVarId } from './refinementTypes';
import { FunctionType, ParamSpecType, Type, TypeVarType } from './types';
export type RefinementVarMap = Map<RefinementVarId, RefinementExpr | undefined>;
// Records the types associated with a set of type variables.
export class ConstraintSolutionSet {
// Indexed by TypeVar ID.
private _typeVarMap: Map<string, Type | undefined>;
// Indexed by refinement var ID.
private _refinementVarMap: RefinementVarMap | undefined;
constructor() {
this._typeVarMap = new Map();
}
isEmpty() {
return this._typeVarMap.size === 0;
return this._typeVarMap.size === 0 && !this._refinementVarMap;
}
getType(typeVar: ParamSpecType): FunctionType | undefined;
@ -48,6 +54,34 @@ export class ConstraintSolutionSet {
}
});
}
getRefinementVarType(refinementVarId: string): RefinementExpr | undefined {
return this._refinementVarMap?.get(refinementVarId);
}
setRefinementVarType(refinementVarId: string, value: RefinementExpr | undefined) {
if (!this._refinementVarMap) {
this._refinementVarMap = new Map();
}
this._refinementVarMap.set(refinementVarId, value);
}
hasRefinementVarType(refinementVarId: string): boolean {
return this._refinementVarMap?.has(refinementVarId) ?? false;
}
doForEachRefinementVar(callback: (value: RefinementExpr, refinementVarId: string) => void) {
this._refinementVarMap?.forEach((type, key) => {
if (type) {
callback(type, key);
}
});
}
getRefinementVarMap(): RefinementVarMap {
return this._refinementVarMap ?? new Map();
}
}
export class ConstraintSolution {
@ -68,6 +102,12 @@ export class ConstraintSolution {
});
}
setRefinementVarType(refinementVarId: string, value: RefinementExpr) {
return this._solutionSets.forEach((set) => {
set.setRefinementVarType(refinementVarId, value);
});
}
getMainSolutionSet() {
return this.getSolutionSet(0);
}

View file

@ -13,6 +13,7 @@ import { DiagnosticAddendum } from '../common/diagnostic';
import { LocAddendum } from '../localization/localize';
import { ConstraintSolution, ConstraintSolutionSet } from './constraintSolution';
import { ConstraintSet, ConstraintTracker, TypeVarConstraints } from './constraintTracker';
import { solveRefinementVarRecursive } from './refinementSolver';
import {
AssignTypeFlags,
maxSubtypesForInferredType,
@ -248,6 +249,11 @@ export function solveConstraintSet(
solveTypeVarRecursive(evaluator, constraintSet, options, solutionSet, entry);
});
// Solve the refinement variables.
constraintSet.doForEachRefinementVar((name) => {
solveRefinementVarRecursive(constraintSet, solutionSet, name);
});
return solutionSet;
}

View file

@ -10,6 +10,8 @@
*/
import { assert } from '../common/debug';
import { RefinementExpr } from './refinementTypes';
import { isRefinementExprEquivalent } from './refinementTypeUtils';
import { getComplexityScoreForType } from './typeComplexity';
import { Type, TypeVarScopeId, TypeVarType, isTypeSame } from './types';
@ -41,6 +43,9 @@ export class ConstraintSet {
// Maps type variable IDs to their current constraints.
private _typeVarMap: Map<string, TypeVarConstraints>;
// Maps refinement variable IDs to their current values.
private _refinementVarMap: Map<string, RefinementExpr> | undefined;
// A set of one or more TypeVar scope IDs that identify this constraint set.
// This corresponds to the scope ID of the overload signature. Normally
// there will be only one scope ID associated with each signature, but
@ -65,6 +70,13 @@ export class ConstraintSet {
this._scopeIds.forEach((scopeId) => constraintSet.addScopeId(scopeId));
}
if (this._refinementVarMap) {
constraintSet._refinementVarMap = new Map<string, RefinementExpr>();
this._refinementVarMap.forEach((value, key) => {
constraintSet._refinementVarMap!.set(key, value);
});
}
return constraintSet;
}
@ -93,6 +105,21 @@ export class ConstraintSet {
}
});
if (this._refinementVarMap) {
if (!other._refinementVarMap || this._refinementVarMap.size !== other._refinementVarMap.size) {
isSame = false;
} else {
this._refinementVarMap.forEach((value, key) => {
const otherValue = other._refinementVarMap!.get(key);
if (!otherValue || !isRefinementExprEquivalent(value, otherValue)) {
isSame = false;
}
});
}
} else if (other._refinementVarMap) {
isSame = false;
}
return isSame;
}
@ -180,6 +207,24 @@ export class ConstraintSet {
return false;
}
getRefinementVarType(refinementVarId: string): RefinementExpr | undefined {
return this._refinementVarMap?.get(refinementVarId);
}
setRefinementVarType(refinementVarId: string, value: RefinementExpr) {
if (!this._refinementVarMap) {
this._refinementVarMap = new Map<string, RefinementExpr>();
}
this._refinementVarMap.set(refinementVarId, value);
}
doForEachRefinementVar(cb: (id: string, value: RefinementExpr) => void) {
this._refinementVarMap?.forEach((value, key) => {
cb(key, value);
});
}
}
export class ConstraintTracker {
@ -276,4 +321,10 @@ export class ConstraintTracker {
assert(index >= 0 && index < this._constraintSets.length);
return this._constraintSets[index];
}
setRefinementVarType(refinementVarId: string, value: RefinementExpr) {
return this._constraintSets.forEach((set) => {
set.setRefinementVarType(refinementVarId, value);
});
}
}

View file

@ -12,7 +12,7 @@ import { getEmptyRange } from '../common/textRange';
import { Uri } from '../common/uri/uri';
import { NameNode, ParseNodeType } from '../parser/parseNodes';
import { ImportLookup, ImportLookupResult } from './analyzerFileInfo';
import { AliasDeclaration, Declaration, DeclarationType, ModuleLoaderActions, isAliasDeclaration } from './declaration';
import { AliasDeclaration, Declaration, DeclarationType, isAliasDeclaration, ModuleLoaderActions } from './declaration';
import { getFileInfoFromNode } from './parseTreeUtils';
import { Symbol } from './symbol';

View file

@ -23,12 +23,15 @@ import {
import { OperatorType } from '../parser/tokenizerTypes';
import { getFileInfo } from './analyzerNodeInfo';
import { getEnclosingLambda, isWithinLoop, operatorSupportsChaining, printOperator } from './parseTreeUtils';
import { isRefinementWildcard } from './refinementTypeUtils';
import { TypeRefinement } from './refinementTypes';
import { getScopeForNode } from './scopeUtils';
import { evaluateStaticBoolExpression } from './staticExpressions';
import { EvalFlags, MagicMethodDeprecationInfo, TypeEvaluator, TypeResult } from './typeEvaluatorTypes';
import {
InferenceContext,
convertToInstantiable,
getIntValueRefinement,
getLiteralTypeClassName,
getTypeCondition,
getUnionSubtypeCount,
@ -335,6 +338,16 @@ export function getTypeOfBinaryOperation(
}
}
// Is this a "@" operator used in a context where it is supposed to be
// interpreted as a refinement type operator?
if (node.d.operator === OperatorType.MatrixMultiply && (flags & EvalFlags.TypeExpression) !== 0) {
if (getFileInfo(node).diagnosticRuleSet.enableExperimentalFeatures) {
if (!customMetaclassSupportsMethod(leftType, '__matmul__')) {
return evaluator.getTypeMetadata(leftExpression, leftTypeResult, rightExpression);
}
}
}
const rightTypeResult = evaluator.getTypeOfExpression(
rightExpression,
flags,
@ -447,26 +460,24 @@ export function getTypeOfBinaryOperation(
node.d.leftExpr
);
} else {
// If neither the LHS or RHS are unions, don't include a diagnostic addendum
let diagMessage = diag.getString();
// If neither the LHS or RHS are unions, don't add more diagnostic information
// because it will be redundant with the main diagnostic message. The addenda
// are useful only if union expansion was used for one or both operands.
let diagString = '';
if (
isUnion(evaluator.makeTopLevelTypeVarsConcrete(leftType)) ||
isUnion(evaluator.makeTopLevelTypeVarsConcrete(rightType))
) {
diagString = diag.getString();
diagMessage =
LocMessage.typeNotSupportBinaryOperator().format({
operator: printOperator(node.d.operator),
leftType: evaluator.printType(leftType),
rightType: evaluator.printType(rightType),
}) + diagMessage;
}
evaluator.addDiagnostic(
DiagnosticRule.reportOperatorIssue,
LocMessage.typeNotSupportBinaryOperator().format({
operator: printOperator(node.d.operator),
leftType: evaluator.printType(leftType),
rightType: evaluator.printType(rightType),
}) + diagString,
node
);
evaluator.addDiagnostic(DiagnosticRule.reportOperatorIssue, diagMessage, node);
}
}
}
@ -1266,12 +1277,14 @@ function validateArithmeticOperation(
}
const magicMethodName = binaryOperatorMap[operator][0];
const subDiag = new DiagnosticAddendum();
let resultTypeResult = evaluator.getTypeOfMagicMethodCall(
convertFunctionToObject(evaluator, leftSubtypeUnexpanded),
magicMethodName,
[{ type: rightSubtypeUnexpanded, isIncomplete: rightTypeResult.isIncomplete }],
errorNode,
inferenceContext
inferenceContext,
subDiag
);
if (!resultTypeResult && leftSubtypeUnexpanded !== leftSubtypeExpanded) {
@ -1336,6 +1349,8 @@ function validateArithmeticOperation(
}
if (!resultTypeResult) {
diag.addAddendum(subDiag);
if (inferenceContext) {
diag.addMessage(
LocMessage.typeNotSupportBinaryOperatorBidirectional().format({
@ -1360,6 +1375,15 @@ function validateArithmeticOperation(
deprecatedInfo = resultTypeResult.magicMethodDeprecationInfo;
}
if (resultTypeResult && options.isLiteralMathAllowed) {
resultTypeResult.type = applyRefinementsForBinaryOp(
operator,
leftSubtypeExpanded,
rightSubtypeExpanded,
resultTypeResult.type
);
}
return resultTypeResult?.type ?? UnknownType.create(isIncomplete);
}
);
@ -1368,3 +1392,44 @@ function validateArithmeticOperation(
return { type, magicMethodDeprecationInfo: deprecatedInfo };
}
function applyRefinementsForBinaryOp(operator: OperatorType, leftType: Type, rightType: Type, resultType: Type): Type {
if (
!isClassInstance(leftType) ||
!ClassType.isBuiltIn(leftType, 'int') ||
!isClassInstance(rightType) ||
!ClassType.isBuiltIn(rightType, 'int') ||
!isClassInstance(resultType) ||
!ClassType.isBuiltIn(resultType, 'int') ||
resultType.priv.literalValue !== undefined
) {
return resultType;
}
const supportedOps: OperatorType[] = [
OperatorType.Add,
OperatorType.Subtract,
OperatorType.Multiply,
OperatorType.FloorDivide,
OperatorType.Mod,
];
if (!supportedOps.includes(operator)) {
return resultType;
}
const leftIntValue = getIntValueRefinement(leftType);
const rightIntValue = getIntValueRefinement(rightType);
if (!leftIntValue || !rightIntValue) {
return resultType;
}
const refinement = TypeRefinement.fromBinaryOp(operator, leftIntValue, rightIntValue);
if (isRefinementWildcard(refinement.value)) {
return resultType;
}
return ClassType.cloneWithRefinements(resultType, [refinement]);
}

View file

@ -43,6 +43,7 @@ export enum ParamKind {
export interface VirtualParamDetails {
param: FunctionParam;
realParamType: Type;
type: Type;
declaredType: Type;
defaultType?: Type | undefined;
@ -124,6 +125,7 @@ export function getParamListDetails(type: FunctionType): ParamListDetails {
const addVirtualParam = (
param: FunctionParam,
realParamType: Type,
index: number,
typeOverride?: Type,
defaultTypeOverride?: Type,
@ -145,6 +147,7 @@ export function getParamListDetails(type: FunctionType): ParamListDetails {
result.params.push({
param,
realParamType,
index,
type: typeOverride ?? FunctionType.getParamType(type, index),
declaredType: FunctionType.getDeclaredParamType(type, index),
@ -182,6 +185,7 @@ export function getParamListDetails(type: FunctionType): ParamListDetails {
FunctionParamFlags.NameSynthesized | FunctionParamFlags.TypeDeclared,
`${param.name}[${tupleIndex.toString()}]`
),
paramType,
index,
tupleArg.type,
/* defaultArgTypeOverride */ undefined,
@ -226,7 +230,7 @@ export function getParamListDetails(type: FunctionType): ParamListDetails {
sawKeywordOnlySeparator = true;
}
addVirtualParam(param, index);
addVirtualParam(param, paramType, index);
}
} else if (param.category === ParamCategory.KwargsDict) {
sawKeywordOnlySeparator = true;
@ -256,6 +260,7 @@ export function getParamListDetails(type: FunctionType): ParamListDetails {
name,
defaultParamType
),
paramType,
index,
specializedParamType,
defaultParamType
@ -270,6 +275,7 @@ export function getParamListDetails(type: FunctionType): ParamListDetails {
FunctionParamFlags.TypeDeclared,
'kwargs'
),
paramType,
index,
paramType.shared.typedDictEntries.extraItems.valueType
);
@ -288,7 +294,7 @@ export function getParamListDetails(type: FunctionType): ParamListDetails {
result.firstKeywordOnlyIndex = result.params.length;
}
addVirtualParam(param, index);
addVirtualParam(param, paramType, index);
}
} else if (param.category === ParamCategory.Simple) {
if (param.name && !sawKeywordOnlySeparator) {
@ -297,6 +303,7 @@ export function getParamListDetails(type: FunctionType): ParamListDetails {
addVirtualParam(
param,
FunctionType.getParamType(type, index),
index,
/* typeOverride */ undefined,
type.priv.specializedTypes?.parameterDefaultTypes

View file

@ -35,6 +35,7 @@ import {
ParameterNode,
ParseNode,
ParseNodeType,
RefinementNode,
StatementListNode,
StatementNode,
StringListNode,
@ -1081,6 +1082,39 @@ export function getTypeVarScopeNode(node: ParseNode): TypeParameterScopeNode | u
return undefined;
}
// Similar to getTypeVarScopeNode except for refinement variable scopes.
export function getRefinementScopeNode(node: ParseNode): ParseNode | undefined {
let prevNode: ParseNode | undefined;
let curNode: ParseNode | undefined = node;
let exprNode: ExpressionNode | RefinementNode | undefined;
let sawNonExprNode = false;
while (curNode) {
if (curNode.nodeType === ParseNodeType.Function) {
if (prevNode === curNode.d.suite) {
return exprNode ?? curNode;
}
if (!curNode.d.decorators.some((decorator) => decorator === prevNode)) {
return curNode;
}
}
if (isExpressionNode(curNode) && curNode.nodeType !== ParseNodeType.TypeAnnotation) {
if (!sawNonExprNode) {
exprNode = curNode;
}
} else if (curNode.nodeType !== ParseNodeType.Argument && curNode.nodeType !== ParseNodeType.Refinement) {
sawNonExprNode = true;
}
prevNode = curNode;
curNode = curNode.parent;
}
return exprNode;
}
// Returns the parse node corresponding to the scope that is used
// for executing the code referenced in the specified node.
export function getExecutionScopeNode(node: ParseNode): ExecutionScopeNode {
@ -2217,6 +2251,9 @@ export function printParseNodeType(type: ParseNodeType) {
case ParseNodeType.TypeAlias:
return 'TypeAlias';
case ParseNodeType.Refinement:
return 'Refinement';
}
assertNever(type);

View file

@ -69,6 +69,7 @@ import {
PatternSequenceNode,
PatternValueNode,
RaiseNode,
RefinementNode,
ReturnNode,
SetNode,
SliceNode,
@ -271,6 +272,9 @@ export function getChildNodes(node: ParseNode): (ParseNode | undefined)[] {
case ParseNodeType.PatternValue:
return [node.d.expr];
case ParseNodeType.Refinement:
return [node.d.valueExpr, node.d.conditionExpr];
case ParseNodeType.Raise:
return [node.d.expr, node.d.fromExpr];
@ -287,7 +291,7 @@ export function getChildNodes(node: ParseNode): (ParseNode | undefined)[] {
return node.d.statements;
case ParseNodeType.StringList:
return [node.d.annotation, ...node.d.strings];
return [node.d.annotation, node.d.refinement, ...node.d.strings];
case ParseNodeType.String:
return [];
@ -519,6 +523,9 @@ export class ParseTreeVisitor<T> {
case ParseNodeType.PatternValue:
return this.visitPatternValue(node);
case ParseNodeType.Refinement:
return this.visitRefinement(node);
case ParseNodeType.Raise:
return this.visitRaise(node);
@ -815,6 +822,10 @@ export class ParseTreeVisitor<T> {
return this._default;
}
visitRefinement(node: RefinementNode) {
return this._default;
}
visitRaise(node: RaiseNode) {
return this._default;
}

View file

@ -0,0 +1,173 @@
/*
* refinementPrinter.ts
* Copyright (c) Microsoft Corporation.
* Licensed under the MIT license.
* Author: Eric Traut
*
* Logic that converts a refinement type to a user-visible string.
*/
import { assertNever } from '../common/debug';
import { OperatorType } from '../parser/tokenizerTypes';
import { RefinementExpr, RefinementNodeType, TypeRefinement } from './refinementTypes';
import { printBytesLiteral, printStringLiteral } from './typePrinterUtils';
export interface PrintRefinementExprOptions {
// Include scopes for refinement variables?
printVarScopes?: boolean;
// Surround tuples with parens?
encloseTupleInParens?: boolean;
}
// Converts a refinement definition to its text form.
export function printRefinement(refinement: TypeRefinement, options?: PrintRefinementExprOptions): string {
const value = printRefinementExpr(refinement.value, options);
const condition = refinement.condition ? ` if ${printRefinementExpr(refinement.condition, options)}` : '';
if (refinement.classDetails.baseSupportsLiteral && !condition) {
if (
refinement.value.nodeType === RefinementNodeType.Number ||
refinement.value.nodeType === RefinementNodeType.String ||
refinement.value.nodeType === RefinementNodeType.Bytes ||
refinement.value.nodeType === RefinementNodeType.Boolean
) {
return value;
}
}
if (refinement.classDetails.baseSupportsStringShortcut) {
return `"${value}${condition}"`;
}
return `${refinement.classDetails.className}("${value}${condition}")`;
}
export function printRefinementExpr(expr: RefinementExpr, options: PrintRefinementExprOptions = {}): string {
switch (expr.nodeType) {
case RefinementNodeType.Number: {
return expr.value.toString();
}
case RefinementNodeType.String: {
return printStringLiteral(expr.value, "'");
}
case RefinementNodeType.Bytes: {
return printBytesLiteral(expr.value);
}
case RefinementNodeType.Boolean: {
return expr.value ? 'True' : 'False';
}
case RefinementNodeType.Wildcard: {
return '_';
}
case RefinementNodeType.Var: {
if (options?.printVarScopes) {
return `${expr.var.shared.name}@${expr.var.shared.scopeName}`;
}
return expr.var.shared.name;
}
case RefinementNodeType.BinaryOp: {
// Map the operator to a string and numerical evaluation precedence.
const operatorMap: { [key: number]: [string, number, boolean] } = {
[OperatorType.Multiply]: ['*', 1, true],
[OperatorType.FloorDivide]: ['//', 1, false],
[OperatorType.Mod]: ['%', 1, false],
[OperatorType.Add]: ['+', 2, true],
[OperatorType.Subtract]: ['-', 2, false],
[OperatorType.Equals]: ['==', 3, false],
[OperatorType.NotEquals]: ['!=', 3, false],
[OperatorType.LessThan]: ['<', 3, false],
[OperatorType.LessThanOrEqual]: ['<=', 3, false],
[OperatorType.GreaterThan]: ['>', 3, false],
[OperatorType.GreaterThanOrEqual]: ['>=', 3, false],
[OperatorType.And]: ['and', 4, true],
[OperatorType.Or]: ['or', 5, true],
};
const operatorStr = operatorMap[expr.operator][0] ?? '<unknown>';
const isCommutative = operatorMap[expr.operator][2] ?? false;
let leftStr = printRefinementExpr(expr.leftExpr, options);
let rightStr = printRefinementExpr(expr.rightExpr, options);
const operatorPrecedence = operatorMap[expr.operator][1] ?? 0;
if (expr.leftExpr.nodeType === RefinementNodeType.BinaryOp) {
const leftPrecedence = operatorMap[expr.leftExpr.operator][1] ?? 0;
if (leftPrecedence > operatorPrecedence) {
leftStr = `(${leftStr})`;
}
}
if (expr.rightExpr.nodeType === RefinementNodeType.BinaryOp) {
const rightPrecedence = operatorMap[expr.rightExpr.operator][1] ?? 0;
const isRightCommutative = operatorMap[expr.rightExpr.operator][1] ?? 0;
let includeParens = rightPrecedence >= operatorPrecedence;
if (rightPrecedence === operatorPrecedence && isCommutative && isRightCommutative) {
includeParens = false;
}
if (includeParens) {
rightStr = `(${rightStr})`;
}
}
return `${leftStr} ${operatorStr} ${rightStr}`;
}
case RefinementNodeType.UnaryOp: {
const operatorMap: { [key: number]: string } = {
[OperatorType.Add]: '+',
[OperatorType.Subtract]: '-',
[OperatorType.Not]: 'not ',
};
const operatorStr = operatorMap[expr.operator] ?? '<unknown>';
return `${operatorStr}${printRefinementExpr(expr.expr, options)}`;
}
case RefinementNodeType.Tuple: {
const entries = expr.entries.map((elem) => {
let baseElemStr = printRefinementExpr(elem.value, options);
if (elem.value.nodeType === RefinementNodeType.Tuple) {
baseElemStr = `(${baseElemStr})`;
}
return `${elem.isUnpacked ? '*' : ''}${baseElemStr}`;
});
if (expr.entries.length === 0) {
return '()';
}
let tupleStr: string;
if (expr.entries.length === 1) {
tupleStr = `${entries[0]},`;
} else {
tupleStr = entries.join(', ');
}
if (options.encloseTupleInParens) {
return `(${tupleStr})`;
}
return tupleStr;
}
case RefinementNodeType.Call: {
const args = expr.args.map((arg) => printRefinementExpr(arg, { ...options, encloseTupleInParens: true }));
return `${expr.name}(${args.join(', ')})`;
}
default: {
assertNever(expr);
}
}
}

View file

@ -0,0 +1,566 @@
/*
* refinementSolver.ts
* Copyright (c) Microsoft Corporation.
* Licensed under the MIT license.
* Author: Eric Traut
*
* Constraint solver for refinement types.
*/
import { assert } from '../common/debug';
import { DiagnosticAddendum } from '../common/diagnostic';
import { LocAddendum } from '../localization/localize';
import { ConstraintSolutionSet } from './constraintSolution';
import { ConstraintSet, ConstraintTracker } from './constraintTracker';
import { printRefinementExpr } from './refinementPrinter';
import {
RefinementExpr,
RefinementNodeType,
RefinementNumberNode,
RefinementTupleEntry,
RefinementVarId,
RefinementVarNode,
TypeRefinement,
} from './refinementTypes';
import {
applySolvedRefinementVars,
createWildcardRefinementValue,
evaluateRefinementCondition,
evaluateRefinementExpression,
getFreeRefinementVars,
isRefinementExprEquivalent,
isRefinementLiteral,
isRefinementTuple,
isRefinementVar,
isRefinementWildcard,
RefinementTypeDiag,
} from './refinementTypeUtils';
import { ClassType, isAnyOrUnknown, isClass, isClassInstance } from './types';
import { getBuiltInRefinementClassId } from './typeUtils';
export interface AssignRefinementsOptions {
checkOverloadOverlap?: boolean;
}
// Attempts to assign a srcType with a refinement type to a destType
// with a refinement type.
export function assignRefinements(
destType: ClassType,
srcType: ClassType,
diag: DiagnosticAddendum | undefined,
constraints: ConstraintTracker | undefined,
options?: AssignRefinementsOptions
): boolean {
let assignmentOk = true;
// Apply refinements by class.
const destRefs = destType.priv.refinements ?? [];
const srcRefs = srcType.priv.refinements ?? [];
const synthesizedRefinement = synthesizeRefinementTypeFromLiteral(srcType);
if (synthesizedRefinement) {
srcRefs.push(synthesizedRefinement);
}
for (const destRef of destRefs) {
let srcClassMatch = false;
for (const srcRef of srcRefs) {
if (destRef.classDetails.classId !== srcRef.classDetails.classId) {
continue;
}
srcClassMatch = true;
if (!assignRefinement(destRef, srcRef, diag, constraints, options)) {
assignmentOk = false;
}
}
if (srcRefs.length === 0 && options?.checkOverloadOverlap) {
assignmentOk = false;
}
// If no source refinements matched the dest refinement class,
// the assignment validity is based on whether it's enforced.
if (!srcClassMatch && destRef.isEnforced) {
assignmentOk = false;
}
}
return assignmentOk;
}
export function solveRefinementVarRecursive(
constraintSet: ConstraintSet,
solutionSet: ConstraintSolutionSet,
varId: RefinementVarId
): RefinementExpr | undefined {
// If this refinement variable already has a solution, don't attempt to re-solve it.
if (solutionSet.hasRefinementVarType(varId)) {
return solutionSet.getRefinementVarType(varId);
}
const value = constraintSet.getRefinementVarType(varId);
if (!value) {
return undefined;
}
// Protect against infinite recursion by setting the initial value to
// undefined. We'll replace this later with a real value.
solutionSet.setRefinementVarType(varId, /* value */ undefined);
// Determine which free variables are referenced by this expression. We need
// to ensure that they are solved first.
const freeVars = getFreeRefinementVars(value);
for (const freeVar of freeVars) {
solveRefinementVarRecursive(constraintSet, solutionSet, freeVar.id);
}
// Now evaluate the expression.
const solvedValue = applySolvedRefinementVars(value, solutionSet.getRefinementVarMap());
const simplifiedValue = evaluateRefinementExpression(solvedValue);
solutionSet.setRefinementVarType(varId, simplifiedValue);
return simplifiedValue;
}
export function assignRefinement(
destRefinement: TypeRefinement,
srcRefinement: TypeRefinement,
diag: DiagnosticAddendum | undefined,
constraints: ConstraintTracker | undefined,
options?: AssignRefinementsOptions
): boolean {
assert(destRefinement.classDetails.classId === srcRefinement.classDetails.classId);
const destValue = destRefinement.value;
const srcValue = srcRefinement.value;
// Determine if there are any conditions provided by the caller (for
// function call evaluation) or local conditions (for assignments).
let conditions: RefinementExpr[] | undefined;
if (!conditions && destRefinement.condition) {
conditions = [destRefinement.condition];
}
// If we have conditions to verify but have not been provided a
// constraint tracker, create a temporary one.
if (conditions && !constraints) {
constraints = new ConstraintTracker();
}
// Handle tuples specially.
if (
destRefinement.classDetails.domain === 'IntTupleRefinement' &&
destValue.nodeType === RefinementNodeType.Tuple &&
srcValue.nodeType === RefinementNodeType.Tuple
) {
const srcEntries = [...srcValue.entries];
// If the dest and source tuple shapes match, we can skip any reshaping efforts.
if (
destValue.entries.length !== srcValue.entries.length ||
destValue.entries.some((entry, i) => entry.isUnpacked !== srcValue.entries[i].isUnpacked)
) {
if (!adjustSourceTupleShape(destValue.entries, srcEntries)) {
const msg =
destRefinement.classDetails.classId === getBuiltInRefinementClassId('Shape')
? LocAddendum.refinementShapeMismatch()
: LocAddendum.refinementTupleMismatch();
diag?.addMessage(
msg.format({
expected: printRefinementExpr(destValue),
received: printRefinementExpr(srcValue),
})
);
return false;
}
}
// At this point, the dest and src tuples should have the same shape
// (i.e. same length and with the same entries unpacked or not).
assert(destValue.entries.length === srcEntries.length);
for (let i = 0; i < destValue.entries.length; i++) {
const destEntry = destValue.entries[i];
const srcEntry = srcEntries[i];
assert(destEntry.isUnpacked === srcEntry.isUnpacked);
if (!assignRefinementValue(destEntry.value, srcEntry.value, diag, constraints, options)) {
return false;
}
}
} else {
if (!assignRefinementValue(destValue, srcValue, diag, constraints, options)) {
return false;
}
}
if (conditions && !validateRefinementConditions(conditions, diag, constraints)) {
return false;
}
return true;
}
export function validateRefinementConditions(
conditions: RefinementExpr[],
diag: DiagnosticAddendum | undefined,
constraints: ConstraintTracker | undefined
): boolean {
let solvedConditions = [...conditions];
if (constraints) {
const solutionSet = new ConstraintSolutionSet();
const constraintSet = constraints.getMainConstraintSet();
// Solve the refinement variables.
constraintSet.doForEachRefinementVar((name) => {
solveRefinementVarRecursive(constraintSet, solutionSet, name);
});
solvedConditions = conditions.map((condition) =>
applySolvedRefinementVars(condition, solutionSet.getRefinementVarMap())
);
}
for (let i = 0; i < solvedConditions.length; i++) {
const errors: RefinementTypeDiag[] = [];
if (!evaluateRefinementCondition(solvedConditions[i], { refinements: { errors } })) {
diag?.addMessage(
LocAddendum.refinementConditionNotSatisfied().format({
condition: printRefinementExpr(solvedConditions[i]),
})
);
errors.forEach((error) => {
diag?.addAddendum(error.diag);
});
return false;
}
}
return true;
}
export function synthesizeRefinementTypeFromLiteral(classType: ClassType): TypeRefinement | undefined {
if (!isClass(classType)) {
return undefined;
}
if (ClassType.isBuiltIn(classType, 'tuple')) {
const typeArgs = classType.priv.tupleTypeArgs;
if (!typeArgs) {
return undefined;
}
let foundInt = false;
if (
!typeArgs.every((typeArg) => {
if (isClassInstance(typeArg.type) && ClassType.isBuiltIn(typeArg.type, 'int')) {
foundInt = true;
return true;
}
return isAnyOrUnknown(typeArg.type);
}) ||
!foundInt
) {
return undefined;
}
const entries: RefinementTupleEntry[] = typeArgs.map((typeArg) => {
if (isClassInstance(typeArg.type) && typeArg.type.priv.literalValue !== undefined && !typeArg.isUnbounded) {
const value = typeArg.type.priv.literalValue;
assert(typeof value === 'number' || typeof value === 'bigint');
const entry: RefinementNumberNode = { nodeType: RefinementNodeType.Number, value: value };
return { value: entry, isUnpacked: false };
}
return { value: createWildcardRefinementValue(), isUnpacked: typeArg.isUnbounded };
});
return {
classDetails: {
domain: 'IntTupleRefinement',
className: 'IntTupleValue',
classId: getBuiltInRefinementClassId('IntTupleValue'),
},
value: {
nodeType: RefinementNodeType.Tuple,
entries,
},
isEnforced: false,
};
}
if (classType.priv.literalValue === undefined) {
return undefined;
}
if (ClassType.isBuiltIn(classType, 'int')) {
assert(typeof classType.priv.literalValue === 'number' || typeof classType.priv.literalValue === 'bigint');
return {
classDetails: {
domain: 'IntRefinement',
className: 'IntValue',
classId: getBuiltInRefinementClassId('IntValue'),
},
value: {
nodeType: RefinementNodeType.Number,
value: classType.priv.literalValue,
},
isEnforced: true,
};
}
if (ClassType.isBuiltIn(classType, 'str')) {
assert(typeof classType.priv.literalValue === 'string');
return {
classDetails: {
domain: 'StrRefinement',
className: 'StrValue',
classId: getBuiltInRefinementClassId('StrValue'),
},
value: {
nodeType: RefinementNodeType.String,
value: classType.priv.literalValue,
},
isEnforced: true,
};
}
if (ClassType.isBuiltIn(classType, 'bytes')) {
assert(typeof classType.priv.literalValue === 'string');
return {
classDetails: {
domain: 'BytesRefinement',
className: 'BytesValue',
classId: getBuiltInRefinementClassId('BytesValue'),
},
value: {
nodeType: RefinementNodeType.Bytes,
value: classType.priv.literalValue,
},
isEnforced: true,
};
}
if (ClassType.isBuiltIn(classType, 'bool')) {
assert(typeof classType.priv.literalValue === 'boolean');
return {
classDetails: {
domain: 'BoolRefinement',
className: 'BoolValue',
classId: getBuiltInRefinementClassId('BoolValue'),
},
value: {
nodeType: RefinementNodeType.Boolean,
value: classType.priv.literalValue,
},
isEnforced: true,
};
}
return undefined;
}
// Adjusts the srcEntries to match the shape of the destEntries if possible.
// It assumes the caller has already confirmed that the dest and src shapes
// don't already match. The shape includes both the number of entries and
// whether each of those entries is unpacked.
function adjustSourceTupleShape(destEntries: RefinementTupleEntry[], srcEntries: RefinementTupleEntry[]): boolean {
const destUnpackCount = destEntries.filter((entry) => entry.isUnpacked).length;
const srcUnpackCount = srcEntries.filter((entry) => entry.isUnpacked).length;
const srcUnpackIndex = srcEntries.findIndex((entry) => entry.isUnpacked);
if (destUnpackCount > 1) {
// If there's more than one unpacked entry in the dest, there
// is no unambiguous way to adjust the source, so don't attempt.
return false;
}
if (destUnpackCount === 1) {
// If the dest has a single unpacked entry, we may be able to adjust
// the source shape to match it.
const srcEntriesToPack = srcEntries.length - destEntries.length + 1;
if (srcEntriesToPack < 0) {
return false;
}
const destUnpackIndex = destEntries.findIndex((entry) => entry.isUnpacked);
const removedEntries = srcEntries.splice(destUnpackIndex, srcEntriesToPack);
// If any of the remaining source entries are unpacked, we can't
// make the shapes match.
if (srcEntries.some((entry) => entry.isUnpacked)) {
return false;
}
// Add a new unpacked tuple entry.
srcEntries.splice(destUnpackIndex, 0, {
value: { nodeType: RefinementNodeType.Tuple, entries: removedEntries },
isUnpacked: true,
});
return true;
}
// If the dest has no unpacked entries, the source cannot have any
// unpacked entries unless it has an unpacked wildcard.
if (srcUnpackCount > 1) {
return false;
}
if (srcUnpackIndex < 0 || srcEntries[srcUnpackIndex].value.nodeType !== RefinementNodeType.Wildcard) {
return false;
}
// Remove the unpacked wildcard entry to make the shapes match.
srcEntries.splice(srcUnpackIndex, 1);
if (srcEntries.length > destEntries.length) {
return false;
}
// Insert wildcard entries to match the dest shape.
while (srcEntries.length < destEntries.length) {
srcEntries.splice(srcUnpackIndex, 0, {
value: createWildcardRefinementValue(),
isUnpacked: false,
});
}
return true;
}
function assignRefinementValue(
destExpr: RefinementExpr,
srcExpr: RefinementExpr,
diag: DiagnosticAddendum | undefined,
constraints: ConstraintTracker | undefined,
options?: AssignRefinementsOptions
): boolean {
// Handle assignment to or from wildcard.
if (isRefinementWildcard(destExpr) || isRefinementWildcard(srcExpr)) {
return true;
}
if (isRefinementExprEquivalent(srcExpr, destExpr)) {
return true;
}
// Handle assignments to literals.
if (isRefinementLiteral(destExpr) && isRefinementLiteral(srcExpr)) {
if (destExpr.value !== srcExpr.value) {
diag?.addMessage(
LocAddendum.refinementLiteralAssignment().format({
expected: printRefinementExpr(destExpr),
received: printRefinementExpr(srcExpr),
})
);
return false;
}
return true;
}
if (isRefinementVar(destExpr)) {
if (assignToRefinementVar(destExpr, srcExpr, diag, constraints)) {
return true;
}
}
if (isRefinementExprEquivalent(destExpr, srcExpr)) {
return true;
}
if (isRefinementTuple(destExpr) && isRefinementTuple(srcExpr)) {
if (destExpr.entries.length === srcExpr.entries.length) {
if (
destExpr.entries.every((destEntry, i) => {
const srcEntry = srcExpr.entries[i];
return (
destEntry.isUnpacked === srcEntry.isUnpacked &&
assignRefinementValue(destEntry.value, srcEntry.value, diag, constraints)
);
})
) {
return true;
}
}
}
// See if we can simplify the source or the dest expression and try again.
const simplifiedDest = evaluateRefinementExpression(destExpr);
const simplifiedSrc = evaluateRefinementExpression(srcExpr);
if (
!TypeRefinement.isRefinementExprSame(simplifiedDest, destExpr) ||
!TypeRefinement.isRefinementExprSame(simplifiedSrc, srcExpr)
) {
return assignRefinementValue(simplifiedDest, simplifiedSrc, diag, constraints);
}
return false;
}
function assignToRefinementVar(
destExpr: RefinementVarNode,
srcExpr: RefinementExpr,
diag: DiagnosticAddendum | undefined,
constraints: ConstraintTracker | undefined
): boolean {
// If the dest is a bound variable, it cannot receive any value other
// than a wildcard and itself.
if (destExpr.var.isBound) {
if (isRefinementWildcard(srcExpr)) {
return true;
}
if (isRefinementExprEquivalent(srcExpr, destExpr)) {
return true;
}
diag?.addMessage(
LocAddendum.refinementValMismatch().format({
expected: printRefinementExpr({ nodeType: RefinementNodeType.Var, var: destExpr.var }),
received: printRefinementExpr(srcExpr),
})
);
return false;
}
// If there is no constraint tracker, we have nothing more to do.
if (!constraints) {
return true;
}
const constraintSet = constraints.getMainConstraintSet();
const curValue = constraintSet.getRefinementVarType(destExpr.var.id);
// If there is a current value, the new value must be the same or more specific.
if (curValue) {
if (!assignRefinementValue(curValue, srcExpr, /* diag */ undefined, /* constraints */ undefined)) {
diag?.addMessage(
LocAddendum.refinementValMismatch().format({
expected: printRefinementExpr(curValue),
received: printRefinementExpr(srcExpr),
})
);
return false;
}
}
// Assign the new value.
constraintSet.setRefinementVarType(destExpr.var.id, srcExpr);
return true;
}

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -14,6 +14,7 @@ import { ExpressionNode, ParseNodeType, SliceNode, TupleNode } from '../parser/p
import { addConstraintsForExpectedType } from './constraintSolver';
import { ConstraintTracker } from './constraintTracker';
import { getTypeVarScopesForNode } from './parseTreeUtils';
import { RefinementNodeType } from './refinementTypes';
import { AssignTypeFlags, EvalFlags, maxInferredContainerDepth, TypeEvaluator, TypeResult } from './typeEvaluatorTypes';
import {
AnyType,
@ -35,6 +36,7 @@ import {
import {
convertToInstance,
doForEachSubtype,
getBuiltInRefinementClassId,
getContainerDepth,
InferenceContext,
isLiteralType,
@ -582,3 +584,60 @@ function getTupleSliceParam(
return value;
}
// If this is a tuple[int, ...] with one or more refinements in the
// IntTupleRefinement refinement domain, returns a refined version
// of the TupleTypeArg entries.
export function getRefinementShapeForTuple(type: Type, entries: TupleTypeArg[]): TupleTypeArg[] {
if (!isClassInstance(type)) {
return entries;
}
if (entries.length !== 1 || !entries[0].isUnbounded) {
return entries;
}
// Shapes apply only to types of type[int, ...].
const entryType = entries[0].type;
if (
!isClassInstance(entryType) ||
!ClassType.isBuiltIn(entryType, 'int') ||
entryType.priv.literalValue !== undefined
) {
return entries;
}
const tupleRefinement = type.priv.refinements?.find((r) => r.classDetails.domain === 'IntTupleRefinement');
if (!tupleRefinement) {
return entries;
}
const refinementShape = tupleRefinement.value;
if (refinementShape.nodeType !== RefinementNodeType.Tuple) {
return entries;
}
return refinementShape.entries.map((entry) => {
const refinement = {
...tupleRefinement,
value: entry.value,
condition: undefined,
};
if (!entry.isUnpacked) {
refinement.classDetails = {
domain: 'IntRefinement',
className: 'IntValue',
classId: getBuiltInRefinementClassId('IntValue'),
baseSupportsLiteral: true,
baseSupportsStringShortcut: true,
};
}
return {
type: ClassType.cloneAddRefinement(entryType, refinement),
isUnbounded: entry.isUnpacked,
};
});
}

File diff suppressed because it is too large Load diff

View file

@ -34,6 +34,7 @@ import { CodeFlowReferenceExpressionNode, FlowNode } from './codeFlowTypes';
import { ConstraintTracker } from './constraintTracker';
import { Declaration } from './declaration';
import * as DeclarationUtils from './declarationUtils';
import { RefinementExpr, RefinementVarType, TypeRefinement } from './refinementTypes';
import { SymbolWithScope } from './scope';
import { Symbol } from './symbol';
import { PrintTypeFlags } from './typePrinter';
@ -178,6 +179,9 @@ export const enum EvalFlags {
// with the enclosing class or an outer scope.
EnforceClassTypeVarScope = 1 << 31,
// Expecting a possible refinement type expression.
Refinement = 1 << 32,
// Defaults used for evaluating the LHS of a call expression.
CallBaseDefaults = NoSpecialize,
@ -254,6 +258,9 @@ export interface TypeResult<T extends Type = Type> {
// Deprecation messages related to magic methods.
magicMethodDeprecationInfo?: MagicMethodDeprecationInfo;
// Refinement definition.
refinement?: TypeRefinement;
}
export interface TypeResultWithNode extends TypeResult {
@ -345,6 +352,7 @@ export interface EffectiveTypeResult {
export interface ValidateArgTypeParams {
paramCategory: ParamCategory;
paramType: Type;
refinementConditions: RefinementExpr[] | undefined;
requiresTypeVarMatching: boolean;
argument: Arg;
isDefaultArg?: boolean;
@ -352,7 +360,7 @@ export interface ValidateArgTypeParams {
errorNode: ExpressionNode;
paramName?: string | undefined;
isParamNameSynthesized?: boolean;
mapsToVarArgList?: boolean | undefined;
mapsToVarArgListType?: Type | undefined;
isinstanceParam?: boolean;
}
@ -406,6 +414,9 @@ export interface CallResult {
// Were any errors discovered when evaluating argument types?
argumentErrors?: boolean;
// Were any refinement errors discovered when evaluating arg types?
refinementErrors?: boolean;
// Did one or more arguments evaluated to Any or Unknown?
anyOrUnknownArg?: UnknownType | AnyType;
@ -517,6 +528,19 @@ export interface SynthesizedTypeInfo {
export interface SymbolDeclInfo {
decls: Declaration[];
synthesizedTypes: SynthesizedTypeInfo[];
refinementInfo?: {
// Type of refinement variable, if applicable.
varType?: RefinementVarType;
// Indicates that the symbol is a refinement wildcard symbol ("_").
isWildcard?: boolean;
// Signature of refinement function call.
callSignature?: string;
// Docstring for refinement call.
callDocstring?: string;
};
}
export const enum AssignTypeFlags {
@ -608,6 +632,7 @@ export interface TypeEvaluator {
createSubclass: (errorNode: ExpressionNode, type1: ClassType, type2: ClassType) => ClassType;
getTypeOfFunction: (node: FunctionNode) => FunctionTypeResult | undefined;
getTypeOfExpressionExpectingType: (node: ExpressionNode, options?: ExpectedTypeOptions) => TypeResult;
getTypeMetadata: (errorNode: ExpressionNode, typeResult: TypeResult, node: ExpressionNode) => TypeResult;
evaluateTypeForSubnode: (subnode: ParseNode, callback: () => void) => TypeResult | undefined;
evaluateTypesForStatement: (node: ParseNode) => void;
evaluateTypesForMatchStatement: (node: MatchNode) => void;
@ -652,6 +677,7 @@ export interface TypeEvaluator {
getDeclInfoForStringNode: (node: StringNode) => SymbolDeclInfo | undefined;
getDeclInfoForNameNode: (node: NameNode, skipUnreachableCode?: boolean) => SymbolDeclInfo | undefined;
getRefinementVarType: (node: NameNode) => RefinementVarType | undefined;
getTypeForDeclaration: (declaration: Declaration) => DeclaredSymbolTypeInfo;
resolveAliasDeclaration: (
declaration: Declaration,
@ -728,7 +754,8 @@ export interface TypeEvaluator {
methodName: string,
argList: TypeResult[],
errorNode: ExpressionNode,
inferenceContext: InferenceContext | undefined
inferenceContext: InferenceContext | undefined,
diag?: DiagnosticAddendum
) => TypeResult | undefined;
bindFunctionToClassOrObject: (
baseType: ClassType | undefined,

View file

@ -12,6 +12,7 @@ import { assert } from '../common/debug';
import { ParamCategory } from '../parser/parseNodes';
import { isTypedKwargs } from './parameterUtils';
import * as ParseTreeUtils from './parseTreeUtils';
import { printRefinement } from './refinementPrinter';
import { printBytesLiteral, printStringLiteral } from './typePrinterUtils';
import {
ClassType,
@ -197,7 +198,6 @@ function printTypeInternal(
recursionCount++;
const originalPrintTypeFlags = printTypeFlags;
const parenthesizeUnion = (printTypeFlags & PrintTypeFlags.ParenthesizeUnion) !== 0;
printTypeFlags &= ~(PrintTypeFlags.ParenthesizeUnion | PrintTypeFlags.ParenthesizeCallable);
// If this is a type alias, see if we should use its name rather than
@ -213,6 +213,11 @@ function printTypeInternal(
}
}
// If there are refinements on the type, don't use the type alias.
if (type.category === TypeCategory.Class && type.priv.refinements) {
expandTypeAlias = true;
}
if (!expandTypeAlias) {
try {
recursionTypes.push(type);
@ -363,6 +368,36 @@ function printTypeInternal(
return '...';
}
const baseTypeStr = printTypeWithoutRefinement(
type,
originalPrintTypeFlags,
returnTypeCallback,
uniqueNameMap,
recursionTypes,
recursionCount
);
if (!isClass(type) || !type.priv?.refinements || type.priv.refinements.length === 0) {
return baseTypeStr;
}
const refinements = type.priv.refinements.map((refinement) => printRefinement(refinement));
return `${baseTypeStr} @ ${refinements.join(' @ ')}`;
}
function printTypeWithoutRefinement(
type: Type,
printTypeFlags: PrintTypeFlags,
returnTypeCallback: FunctionReturnTypeCallback,
uniqueNameMap: UniqueNameMap,
recursionTypes: Type[],
recursionCount: number
): string {
const originalPrintTypeFlags = printTypeFlags;
const parenthesizeUnion = (printTypeFlags & PrintTypeFlags.ParenthesizeUnion) !== 0;
printTypeFlags &= ~(PrintTypeFlags.ParenthesizeUnion | PrintTypeFlags.ParenthesizeCallable);
try {
recursionTypes.push(type);

View file

@ -12,6 +12,16 @@ import { assert } from '../common/debug';
import { ParamCategory } from '../parser/parseNodes';
import { ConstraintSolution, ConstraintSolutionSet } from './constraintSolution';
import { DeclarationType } from './declaration';
import { RefinementClassDetails, RefinementDomain, TypeRefinement } from './refinementTypes';
import {
applySolvedRefinementVars,
evaluateRefinementExpression,
isRefinementLiteral,
isRefinementWildcard,
makeRefinementVarsBound,
makeRefinementVarsFree,
RefinementTypeDiag,
} from './refinementTypeUtils';
import { Symbol, SymbolFlags, SymbolTable } from './symbol';
import { isEffectivelyClassVar, isTypedDictMemberAccessedThroughIndex } from './symbolUtils';
import {
@ -214,6 +224,10 @@ export interface ApplyTypeVarOptions {
useUnknown?: boolean;
eliminateUnsolvedInUnions?: boolean;
};
refinements?: {
errors?: RefinementTypeDiag[];
warnings?: RefinementTypeDiag[];
};
}
// Tracks whether a function signature has been seen before within
@ -294,6 +308,41 @@ export function isIncompleteUnknown(type: Type): boolean {
return isUnknown(type) && type.priv.isIncomplete;
}
export function getRefinementDomain(type: ClassType): RefinementDomain | undefined {
for (const baseClass of type.shared.mro) {
if (isClass(baseClass) && ClassType.isBuiltIn(baseClass)) {
const className = baseClass.shared.name;
const domainNames = ['IntRefinement', 'StrRefinement', 'IntTupleRefinement', 'Refinement'];
if (domainNames.includes(className)) {
return className as RefinementDomain;
}
}
}
return undefined;
}
export function getRefinementClassId(type: ClassType): string {
if (ClassType.isBuiltIn(type)) {
return getBuiltInRefinementClassId(type.shared.name);
}
return `${type.shared.fileUri}:${type.shared.name}`;
}
export function getBuiltInRefinementClassId(className: string): string {
// Choose a unique ID that will never conflict with a user-defined class.
return `stdlib:${className}`;
}
export function hasTupleRefinement(type: ClassType): boolean {
if (!type.priv.refinements) {
return false;
}
return type.priv.refinements.some((refinement) => refinement.classDetails.domain === 'IntTupleRefinement');
}
// Similar to isTypeSame except that type1 is a TypeVar and type2
// can be either a TypeVar of the same type or a union that includes
// conditional types associated with that bound TypeVar.
@ -1200,9 +1249,78 @@ export function isLiteralLikeType(type: ClassType): boolean {
return true;
}
// If the class has a value refinement type, it's considered literal-like.
if (type.priv.refinements) {
const valueClasses = ['IntValue', 'StrValue', 'BytesValue'].map((name) => getBuiltInRefinementClassId(name));
return type.priv.refinements.some(
(refinement) => valueClasses.includes(refinement.classDetails.classId) && refinement.isEnforced
);
}
return false;
}
// If possible, converts an int, str, or bytes value with a refinement type
// into its corresponding Literal type.
export function simplifyRefinementTypeToLiteral(type: ClassType): ClassType {
const refinementClassMap: { [key: string]: string } = {
int: 'IntValue',
str: 'StrValue',
bytes: 'BytesValue',
bool: 'BoolValue',
};
if (!type.priv.refinements || type.priv.literalValue !== undefined || !ClassType.isBuiltIn(type)) {
return type;
}
const refClassName = refinementClassMap[type.shared.name];
if (!refClassName) {
return type;
}
const refClassId = getBuiltInRefinementClassId(refClassName);
const refinements = type.priv.refinements.filter((refinement) => refinement.classDetails.classId === refClassId);
if (refinements.length !== 1) {
return type;
}
const refinementExpr = refinements[0].value;
if (!isRefinementLiteral(refinementExpr)) {
return type;
}
const remainingRefinements = type.priv.refinements.filter(
(refinement) => refinement.classDetails.classId !== refClassId
);
let resultType = ClassType.cloneWithLiteral(type, refinementExpr.value);
resultType = ClassType.cloneWithRefinements(resultType, remainingRefinements);
return resultType;
}
// If the value has an IntValue refinement type associated with it or is an
// integer literal, this function returns the effective value.
export function getIntValueRefinement(type: ClassType): TypeRefinement | undefined {
if (type.priv.literalValue !== undefined) {
if (!ClassType.isBuiltIn(type, 'int')) {
return undefined;
}
assert(typeof type.priv.literalValue === 'number' || typeof type.priv.literalValue === 'bigint');
const classDetails: RefinementClassDetails = {
domain: 'IntRefinement',
className: 'IntValue',
classId: getBuiltInRefinementClassId('IntValue'),
baseSupportsLiteral: true,
baseSupportsStringShortcut: true,
};
return TypeRefinement.fromLiteral(classDetails, type.priv.literalValue, /* isEnforced */ true);
}
const intValueClassId = getBuiltInRefinementClassId('IntValue');
return type.priv.refinements?.find((refinement) => refinement.classDetails.classId === intValueClassId);
}
export function containsLiteralType(type: Type, includeTypeArgs = false): boolean {
class ContainsLiteralTypeWalker extends TypeWalker {
foundLiteral = false;
@ -2025,6 +2143,65 @@ export function getTypeVarArgsRecursive(type: Type, recursionCount = 0): TypeVar
return [];
}
export function getRefinementsRecursive(type: Type, recursionCount = 0): TypeRefinement[] {
if (recursionCount > maxTypeRecursionCount) {
return [];
}
recursionCount++;
const aliasInfo = type.props?.typeAliasInfo;
if (aliasInfo?.typeArgs) {
const combinedList: TypeRefinement[] = [];
aliasInfo?.typeArgs.forEach((typeArg) => {
appendArray(combinedList, getRefinementsRecursive(typeArg, recursionCount));
});
return combinedList;
}
if (isClass(type)) {
const combinedList: TypeRefinement[] = [];
const typeArgs = type.priv.tupleTypeArgs ? type.priv.tupleTypeArgs.map((e) => e.type) : type.priv.typeArgs;
if (typeArgs) {
typeArgs.forEach((typeArg) => {
appendArray(combinedList, getRefinementsRecursive(typeArg, recursionCount));
});
}
if (type.priv.refinements) {
appendArray(combinedList, type.priv.refinements);
}
return combinedList;
}
if (isUnion(type)) {
const combinedList: TypeRefinement[] = [];
doForEachSubtype(type, (subtype) => {
appendArray(combinedList, getRefinementsRecursive(subtype, recursionCount));
});
return combinedList;
}
if (isFunction(type)) {
const combinedList: TypeRefinement[] = [];
for (let i = 0; i < type.shared.parameters.length; i++) {
appendArray(combinedList, getRefinementsRecursive(FunctionType.getParamType(type, i), recursionCount));
}
const returnType = FunctionType.getEffectiveReturnType(type);
if (returnType) {
appendArray(combinedList, getRefinementsRecursive(returnType, recursionCount));
}
return combinedList;
}
return [];
}
// Creates a specialized version of the class, filling in any unspecified
// type arguments with Unknown or default value.
export function specializeWithDefaultTypeArgs(type: ClassType): ClassType {
@ -2805,6 +2982,11 @@ function _requiresSpecialization(type: Type, options?: RequiresSpecializationOpt
switch (type.category) {
case TypeCategory.Class: {
// If the type has a refinement type, it may need to be specialized.
if (type.priv.refinements) {
return true;
}
if (ClassType.isPseudoGenericClass(type) && options?.ignorePseudoGeneric) {
return false;
}
@ -3514,7 +3696,8 @@ export class TypeVarTransformer {
if (
typeParams.length === 0 &&
!ClassType.isSpecialBuiltIn(classType) &&
!ClassType.isBuiltIn(classType, 'type')
!ClassType.isBuiltIn(classType, 'type') &&
!classType.priv.refinements
) {
return classType;
}
@ -3596,18 +3779,51 @@ export class TypeVarTransformer {
});
}
// If specialization wasn't needed, don't allocate a new class.
if (!specializationNeeded) {
return classType;
let newClassType = classType;
if (specializationNeeded) {
newClassType = ClassType.specialize(
classType,
newTypeArgs,
/* isTypeArgExplicit */ true,
/* includeSubclasses */ undefined,
newTupleTypeArgs
);
}
return ClassType.specialize(
classType,
newTypeArgs,
/* isTypeArgExplicit */ true,
/* includeSubclasses */ undefined,
newTupleTypeArgs
);
// Apply transforms to the refinements.
if (newClassType.priv.refinements) {
const newRefinements: TypeRefinement[] = [];
let refinementsChanged = false;
newClassType.priv.refinements.forEach((refinement) => {
const newRefinement = this.transformRefinement(refinement);
if (newRefinement) {
// Strip wildcard refinements that are not enforced.
if (isRefinementWildcard(newRefinement.value) && !newRefinement.isEnforced) {
refinementsChanged = true;
} else {
newRefinements.push(newRefinement);
if (newRefinement !== refinement) {
refinementsChanged = true;
}
}
} else {
refinementsChanged = true;
}
});
if (refinementsChanged) {
newClassType = ClassType.cloneWithRefinements(
newClassType,
newRefinements.length > 0 ? newRefinements : undefined
);
newClassType = simplifyRefinementTypeToLiteral(newClassType);
}
}
return newClassType;
}
transformTypeVarsInFunctionType(sourceType: FunctionType, recursionCount: number): FunctionType | OverloadedType {
@ -3798,6 +4014,10 @@ export class TypeVarTransformer {
});
}
transformRefinement(refinement: TypeRefinement): TypeRefinement | undefined {
return refinement;
}
private _isTypeVarScopePending(typeVarScopeId: TypeVarScopeId | undefined) {
return !!typeVarScopeId && this._pendingTypeVarTransformations.has(typeVarScopeId);
}
@ -3897,6 +4117,23 @@ class BoundTypeVarTransform extends TypeVarTransformer {
return undefined;
}
override transformRefinement(refinement: TypeRefinement): TypeRefinement | undefined {
if (!this._scopeIds) {
return refinement;
}
const newValue = makeRefinementVarsBound(refinement.value, this._scopeIds);
const newCondition = refinement.condition
? makeRefinementVarsBound(refinement.condition, this._scopeIds)
: undefined;
if (newValue === refinement.value && newCondition === refinement.condition) {
return refinement;
}
return { ...refinement, value: newValue, condition: newCondition };
}
private _isTypeVarInScope(typeVar: TypeVarType) {
if (!typeVar.priv.scopeId) {
return false;
@ -3930,6 +4167,23 @@ class FreeTypeVarTransform extends TypeVarTransformer {
return undefined;
}
override transformRefinement(refinement: TypeRefinement): TypeRefinement | undefined {
if (!this._scopeIds) {
return refinement;
}
const newValue = makeRefinementVarsFree(refinement.value, this._scopeIds);
const newCondition = refinement.condition
? makeRefinementVarsFree(refinement.condition, this._scopeIds)
: undefined;
if (newValue === refinement.value && newCondition === refinement.condition) {
return refinement;
}
return { ...refinement, value: newValue, condition: newCondition };
}
private _isTypeVarInScope(typeVar: TypeVarType) {
if (!typeVar.priv.scopeId) {
return false;
@ -4141,6 +4395,25 @@ class ApplySolvedTypeVarsTransformer extends TypeVarTransformer {
return type;
}
override transformRefinement(refinement: TypeRefinement): TypeRefinement | undefined {
const solutionSet = this._solution.getSolutionSet(this._activeConstraintSetIndex ?? 0);
const newRefinementValue = applySolvedRefinementVars(refinement.value, solutionSet.getRefinementVarMap(), {
replaceUnsolved: !!this._options.replaceUnsolved,
refinements: this._options.refinements,
});
if (newRefinementValue === refinement.value) {
return refinement;
}
return {
...refinement,
value: evaluateRefinementExpression(newRefinementValue, {
refinements: this._options.refinements,
}),
};
}
override doForEachConstraintSet(callback: () => FunctionType): FunctionType | OverloadedType {
const solutionSets = this._solution.getSolutionSets();

View file

@ -12,6 +12,7 @@ import { assert } from '../common/debug';
import { Uri } from '../common/uri/uri';
import { ArgumentNode, ExpressionNode, NameNode, ParamCategory } from '../parser/parseNodes';
import { ClassDeclaration, FunctionDeclaration, SpecialBuiltInClassDeclaration } from './declaration';
import { RefinementExpr, RefinementVar, TypeRefinement } from './refinementTypes';
import { Symbol, SymbolTable } from './symbol';
export const enum TypeCategory {
@ -791,6 +792,9 @@ export interface ClassDetailsPriv {
// literal types (e.g. true or 'string' or 3).
literalValue?: LiteralValue | undefined;
// Refinement definitions.
refinements?: TypeRefinement[] | undefined;
// The typing module defines aliases for builtin types
// (e.g. Tuple, List, Dict). This field holds the alias
// name.
@ -960,6 +964,26 @@ export namespace ClassType {
return newClassType;
}
export function cloneAddRefinement(classType: ClassType, refinement: TypeRefinement): ClassType {
const newClassType = TypeBase.cloneType(classType);
if (!newClassType.priv.refinements) {
newClassType.priv.refinements = [refinement];
} else {
newClassType.priv.refinements = [...newClassType.priv.refinements, refinement];
}
return newClassType;
}
export function cloneWithRefinements(classType: ClassType, refinements: TypeRefinement[] | undefined): ClassType {
const newClassType = TypeBase.cloneType(classType);
if (refinements && refinements.length === 0) {
refinements = undefined;
}
newClassType.priv.refinements = refinements;
return newClassType;
}
export function cloneForDeprecatedInstance(type: ClassType, deprecatedMessage?: string): ClassType {
const newClassType = TypeBase.cloneType(type);
newClassType.priv.deprecatedInstanceMessage = deprecatedMessage;
@ -1038,6 +1062,26 @@ export namespace ClassType {
return type1.priv.literalValue === type2.priv.literalValue;
}
export function isRefinementSame(type1: ClassType, type2: ClassType): boolean {
if (type1.priv.refinements === undefined) {
return type2.priv.refinements === undefined;
} else if (type2.priv.refinements === undefined) {
return false;
}
if (type1.priv.refinements.length !== type2.priv.refinements.length) {
return false;
}
for (let i = 0; i < type1.priv.refinements.length; i++) {
if (!TypeRefinement.isSame(type1.priv.refinements[i], type2.priv.refinements[i])) {
return false;
}
}
return true;
}
// Determines whether two typed dict classes are equivalent given
// that one or both have narrowed entries (i.e. entries that are
// guaranteed to be present).
@ -1132,6 +1176,14 @@ export namespace ClassType {
return true;
}
export function hasRefinement(classType: ClassType, classId: string): boolean {
if (!classType.priv.refinements) {
return false;
}
return classType.priv.refinements.some((refinement) => refinement.classDetails.classId === classId);
}
export function supportsAbstractMethods(classType: ClassType) {
return !!(classType.shared.flags & ClassTypeFlags.SupportsAbstractMethods);
}
@ -1492,6 +1544,8 @@ export interface FunctionParam {
_defaultType: Type | undefined;
defaultExpr: ExpressionNode | undefined;
refinementConditions?: RefinementExpr[];
}
export namespace FunctionParam {
@ -1501,9 +1555,10 @@ export namespace FunctionParam {
flags = FunctionParamFlags.None,
name?: string,
defaultType?: Type,
defaultExpr?: ExpressionNode
defaultExpr?: ExpressionNode,
refinementConditions?: RefinementExpr[]
): FunctionParam {
return { category, flags, name, _type: type, _defaultType: defaultType, defaultExpr };
return { category, flags, name, _type: type, _defaultType: defaultType, defaultExpr, refinementConditions };
}
export function isNameSynthesized(param: FunctionParam) {
@ -1611,6 +1666,8 @@ interface FunctionDetailsShared {
moduleName: string;
flags: FunctionTypeFlags;
typeParams: TypeVarType[];
refinements: TypeRefinement[] | undefined;
refinementVars: RefinementVar[] | undefined;
parameters: FunctionParam[];
declaredReturnType: Type | undefined;
declaration: FunctionDeclaration | undefined;
@ -1732,6 +1789,8 @@ export namespace FunctionType {
moduleName,
flags: functionFlags,
typeParams: [],
refinements: undefined,
refinementVars: undefined,
parameters: [],
declaredReturnType: undefined,
declaration: undefined,
@ -3338,6 +3397,10 @@ export function isTypeSame(type1: Type, type2: Type, options: TypeSameOptions =
return false;
}
if (!ClassType.isRefinementSame(type1, classType2)) {
return false;
}
if (!type1.priv.isUnpacked !== !classType2.priv.isUnpacked) {
return false;
}

View file

@ -18,6 +18,7 @@ import {
isUnresolvedAliasDeclaration,
} from '../analyzer/declaration';
import * as ParseTreeUtils from '../analyzer/parseTreeUtils';
import { RefinementExprType } from '../analyzer/refinementTypes';
import { SourceMapper } from '../analyzer/sourceMapper';
import { isBuiltInModule } from '../analyzer/typeDocStringUtils';
import { PrintTypeOptions, SynthesizedTypeInfo, TypeEvaluator } from '../analyzer/typeEvaluatorTypes';
@ -269,6 +270,44 @@ export class HoverProvider {
this._addResultsForSynthesizedType(results.parts, type, name);
});
this._addDocumentationPart(results.parts, node, /* resolvedDecl */ undefined);
} else if (declInfo?.refinementInfo) {
if (declInfo.refinementInfo.callSignature) {
this._addResultsPart(
results.parts,
`(refinement operator) ${node.d.value}: ${declInfo.refinementInfo.callSignature}`,
/* python */ true
);
if (declInfo.refinementInfo.callDocstring) {
addDocumentationResultsPart(
this._program.serviceProvider,
declInfo.refinementInfo.callDocstring,
this._format,
results.parts,
/* resolvedDecl */ undefined
);
}
} else {
let varTypeText: string | undefined;
if (declInfo.refinementInfo.isWildcard) {
varTypeText = 'any';
} else if (declInfo.refinementInfo.varType) {
const varTypeMapping: { [key in RefinementExprType]: string } = {
[RefinementExprType.Int]: 'Int',
[RefinementExprType.Str]: 'str',
[RefinementExprType.Bytes]: 'bytes',
[RefinementExprType.Bool]: 'bool',
[RefinementExprType.IntTuple]: 'tuple',
};
varTypeText = varTypeMapping[declInfo.refinementInfo.varType];
}
this._addResultsPart(
results.parts,
`(refinement var) ${node.d.value}: ${varTypeText ?? 'unknown'}`,
/* python */ true
);
}
} else if (!node.parent || node.parent.nodeType !== ParseNodeType.ModuleName) {
// If we had no declaration, see if we can provide a minimal tooltip. We'll skip
// this if it's part of a module name, since a module name part with no declaration

View file

@ -494,6 +494,8 @@ export namespace Localizer {
export const expectedPatternExpr = () => getRawString('Diagnostic.expectedPatternExpr');
export const expectedPatternSubjectExpr = () => getRawString('Diagnostic.expectedPatternSubjectExpr');
export const expectedPatternValue = () => getRawString('Diagnostic.expectedPatternValue');
export const expectedPredicateIf = () => getRawString('Diagnostic.expectedPredicateIf');
export const expectedRefinement = () => getRawString('Diagnostic.expectedRefinement');
export const expectedReturnExpr = () => getRawString('Diagnostic.expectedReturnExpr');
export const expectedSliceIndex = () => getRawString('Diagnostic.expectedSliceIndex');
export const expectedTypeNotString = () => getRawString('Diagnostic.expectedTypeNotString');
@ -848,6 +850,30 @@ export namespace Localizer {
export const readOnlyNotInTypedDict = () => getRawString('Diagnostic.readOnlyNotInTypedDict');
export const recursiveDefinition = () =>
new ParameterizedString<{ name: string }>(getRawString('Diagnostic.recursiveDefinition'));
export const refinementBaseTypeInvalid = () => getRawString('Diagnostic.refinementBaseTypeInvalid');
export const refinementCallArgCount = () =>
new ParameterizedString<{ name: string; expected: number; received: number }>(
getRawString('Diagnostic.refinementCallArgCount')
);
export const refinementCallArgUnpacked = () => getRawString('Diagnostic.refinementCallArgUnpacked');
export const refinementCallArgKeyword = () => getRawString('Diagnostic.refinementCallArgKeyword');
export const refinementConditionFailure = () => getRawString('Diagnostic.refinementConditionFailure');
export const refinementFloatImaginary = () => getRawString('Diagnostic.refinementFloatImaginary');
export const refinementIntTupleNotAllowed = () => getRawString('Diagnostic.refinementIntTupleNotAllowed');
export const refinementPostCondition = () => getRawString('Diagnostic.refinementPostCondition');
export const refinementPrecondition = () => getRawString('Diagnostic.refinementPrecondition');
export const refinementTypeNotSupported = () =>
new ParameterizedString<{ name: string }>(getRawString('Diagnostic.refinementTypeNotSupported'));
export const refinementUnexpectedValueType = () =>
new ParameterizedString<{ expected: string; received: string }>(
getRawString('Diagnostic.refinementUnexpectedValueType')
);
export const refinementUnsupportedCall = () =>
new ParameterizedString<{ name: string }>(getRawString('Diagnostic.refinementUnsupportedCall'));
export const refinementUnsupportedExpression = () => getRawString('Diagnostic.refinementUnsupportedExpression');
export const refinementUnsupportedOperation = () => getRawString('Diagnostic.refinementUnsupportedOperation');
export const refinementVarNotInValue = () =>
new ParameterizedString<{ name: string }>(getRawString('Diagnostic.refinementVarNotInValue'));
export const relativeImportNotAllowed = () => getRawString('Diagnostic.relativeImportNotAllowed');
export const requiredArgCount = () => getRawString('Diagnostic.requiredArgCount');
export const requiredNotInTypedDict = () => getRawString('Diagnostic.requiredNotInTypedDict');
@ -1007,6 +1033,7 @@ export namespace Localizer {
export const typeGuardParamCount = () => getRawString('Diagnostic.typeGuardParamCount');
export const typeIsReturnType = () =>
new ParameterizedString<{ type: string; returnType: string }>(getRawString('Diagnostic.typeIsReturnType'));
export const typeMetadataInvalid = () => getRawString('Diagnostic.typeMetadataInvalid');
export const typeNotAwaitable = () =>
new ParameterizedString<{ type: string }>(getRawString('Diagnostic.typeNotAwaitable'));
export const typeNotIntantiable = () =>
@ -1457,6 +1484,37 @@ export namespace Localizer {
export const pyrightCommentIgnoreTip = () => getRawString('DiagnosticAddendum.pyrightCommentIgnoreTip');
export const readOnlyAttribute = () =>
new ParameterizedString<{ name: string }>(getRawString('DiagnosticAddendum.readOnlyAttribute'));
export const refinementBroadcast = () => getRawString('DiagnosticAddendum.refinementBroadcast');
export const refinementConcatMismatch = () => getRawString('DiagnosticAddendum.refinementConcatMismatch');
export const refinementConditionNotSatisfied = () =>
new ParameterizedString<{ condition: string }>(
getRawString('DiagnosticAddendum.refinementConditionNotSatisfied')
);
export const refinementIndexOutOfRange = () =>
new ParameterizedString<{ value: number }>(getRawString('DiagnosticAddendum.refinementIndexOutOfRange'));
export const refinementLiteralAssignment = () =>
new ParameterizedString<{ expected: string; received: string }>(
getRawString('DiagnosticAddendum.refinementLiteralAssignment')
);
export const refinementPermuteDuplicate = () => getRawString('DiagnosticAddendum.refinementPermuteDuplicate');
export const refinementPermuteMismatch = () =>
new ParameterizedString<{ expected: string; received: string }>(
getRawString('DiagnosticAddendum.refinementPermuteMismatch')
);
export const refinementReshapeInferred = () => getRawString('DiagnosticAddendum.refinementReshapeInferred');
export const refinementReshapeMismatch = () => getRawString('DiagnosticAddendum.refinementReshapeMismatch');
export const refinementShapeMismatch = () =>
new ParameterizedString<{ expected: string; received: string }>(
getRawString('DiagnosticAddendum.refinementShapeMismatch')
);
export const refinementTupleMismatch = () =>
new ParameterizedString<{ expected: string; received: string }>(
getRawString('DiagnosticAddendum.refinementTupleMismatch')
);
export const refinementValMismatch = () =>
new ParameterizedString<{ expected: string; received: string }>(
getRawString('DiagnosticAddendum.refinementValMismatch')
);
export const seeDeclaration = () => getRawString('DiagnosticAddendum.seeDeclaration');
export const seeClassDeclaration = () => getRawString('DiagnosticAddendum.seeClassDeclaration');
export const seeFunctionDeclaration = () => getRawString('DiagnosticAddendum.seeFunctionDeclaration');

View file

@ -441,6 +441,11 @@
"message": "Expected pattern value expression of the form \"a.b\"",
"comment": "{Locked='a.b'}"
},
"expectedPredicateIf": "Expected \"if\" after refinement value expression",
"expectedRefinement": {
"message": "Expected string literal for refinement type definition",
"comment": "{Locked='literal'"
},
"expectedReturnExpr": {
"message": "Expected expression after \"return\"",
"comment": "{Locked='return'}"
@ -1063,6 +1068,23 @@
"comment": "{Locked='ReadOnly'}"
},
"recursiveDefinition": "Type of \"{name}\" could not be determined because it refers to itself",
"refinementBaseTypeInvalid": "Base type for a refinement type must be a nominal class",
"refinementBoolValueExpected": "Expected boolean expression",
"refinementBytes": "Bytes values are not supported in refinement types",
"refinementCallArgCount": "Incorrect argument count for \"name\"; expected {expected} but received {received}",
"refinementCallArgKeyword": "Keyword arguments are not supported in refinement types",
"refinementConditionFailure": "Argument value violates refinement condition",
"refinementCallArgUnpacked": "Unpacked argument expressions are not supported in refinement types",
"refinementFloatImaginary": "Float and imaginary values are not supported in refinement types",
"refinementIntTupleNotAllowed": "Literal value not supported by IntTupleRefinement",
"refinementPostCondition": "Refinement condition not allowed in return type",
"refinementPrecondition": "Invalid refinement value expression; use variable, literal, or \"_\"",
"refinementTypeNotSupported": "Refinement type \"{name}\" is not supported",
"refinementUnexpectedValueType": "Unexpected expression type; expected \"{expected}\" but received \"{received}\"",
"refinementUnsupportedCall": "Operation \"{name}\" not supported in refinement type",
"refinementUnsupportedExpression": "Expression not supported in refinement type",
"refinementUnsupportedOperation": "Operation not supported in refinement type",
"refinementVarNotInValue": "Refinement variable \"{name}\" is undefined",
"relativeImportNotAllowed": {
"message": "Relative imports cannot be used with \"import .a\" form; use \"from . import a\" instead",
"comment": "{Locked='import .a','from . import a'}"
@ -1296,6 +1318,9 @@
"message": "Return type of TypeIs (\"{returnType}\") is not consistent with value parameter type (\"{type}\")",
"comment": "{Locked='TypeIs'}"
},
"typeMetadataInvalid": {
"message": "Type metadata must be instance of TypeMetadata or a literal value",
"comment": "{Locked='TypeMetadata'}"},
"typeNotAwaitable": {
"message": "\"{type}\" is not awaitable",
"comment": "{Locked='awaitable'}"
@ -1933,6 +1958,18 @@
"comment": "{Locked='# pyright: ignore[<diagnostic rules>]'}"
},
"readOnlyAttribute": "Attribute \"{name}\" is read-only",
"refinementBroadcast": "Shapes differ and do not allow for broadcasting",
"refinementConcatMismatch": "Shape must match for all dimensions other than the concatenation dimension",
"refinementConditionNotSatisfied": "Refinement condition not satisfied: \"{condition}\"",
"refinementIndexOutOfRange": "Index {value} is out of range",
"refinementLiteralAssignment": "Refinement literal value is incompatible; expected {expected} but received {received}",
"refinementPermuteDuplicate": "Permute does not allow duplicate dimensions",
"refinementPermuteMismatch": "Permute dimension count mismatch (expected {expected} but received {received})",
"refinementReshapeInferred": "Only one inferred dimension is allowed",
"refinementReshapeMismatch": "New shape does not match existing shape",
"refinementShapeMismatch": "Could not assign shape \"{received}\" to \"{expected}\"",
"refinementTupleMismatch": "Could not assign refinement tuple \"{received}\" to \"{expected}\"",
"refinementValMismatch": "Refinement value is incompatible; expected {expected} but received {received}",
"seeClassDeclaration": "See class declaration",
"seeDeclaration": "See declaration",
"seeFunctionDeclaration": "See function declaration",

View file

@ -110,6 +110,8 @@ export const enum ParseNodeType {
TypeParameter,
TypeParameterList,
TypeAlias,
Refinement,
}
export const enum ErrorExpressionCategory {
@ -1807,6 +1809,11 @@ export interface StringListNode extends ParseNodeBase<ParseNodeType.StringList>
// into an expression.
annotation: ExpressionNode | undefined;
// If the string represents a refinement type definition
// within a type annotation, it is further parsed into
// an expression (lazily by the type evaluator).
refinement: RefinementNode | undefined;
// Indicates that the string list is enclosed in parens.
hasParens: boolean;
};
@ -1824,6 +1831,7 @@ export namespace StringListNode {
d: {
strings,
annotation: undefined,
refinement: undefined,
hasParens: false,
},
};
@ -2742,6 +2750,39 @@ export type PatternAtomNode =
| PatternValueNode
| ErrorNode;
export interface RefinementNode extends ParseNodeBase<ParseNodeType.Refinement> {
d: {
valueExpr: ExpressionNode;
conditionExpr: ExpressionNode | undefined;
};
}
export namespace RefinementNode {
export function create(valueExpr: ExpressionNode, conditionExpr: ExpressionNode | undefined) {
const node: RefinementNode = {
start: valueExpr.start,
length: valueExpr.length,
nodeType: ParseNodeType.Refinement,
id: _nextNodeId++,
parent: undefined,
a: undefined,
d: {
valueExpr,
conditionExpr,
},
};
valueExpr.parent = node;
if (conditionExpr) {
conditionExpr.parent = node;
extendRange(node, conditionExpr);
}
return node;
}
}
export type ParseNode =
| ErrorNode
| ArgumentNode
@ -2799,6 +2840,7 @@ export type ParseNode =
| PatternMappingNode
| PatternSequenceNode
| PatternValueNode
| RefinementNode
| RaiseNode
| ReturnNode
| SetNode

View file

@ -103,6 +103,7 @@ import {
PatternSequenceNode,
PatternValueNode,
RaiseNode,
RefinementNode,
ReturnNode,
SetNode,
SliceNode,
@ -217,6 +218,7 @@ export const enum ParseTextMode {
Expression,
VariableAnnotation,
FunctionAnnotation,
Refinement,
}
// Limit the max child node depth to prevent stack overflows.
@ -318,6 +320,15 @@ export class Parser {
initialParenDepth?: number,
typingSymbolAliases?: Map<string, string>
): ParseExpressionTextResults<FunctionAnnotationNode>;
parseTextExpression(
fileContents: string,
textOffset: number,
textLength: number,
parseOptions: ParseOptions,
parseTextMode: ParseTextMode.Refinement,
initialParenDepth?: number,
typingSymbolAliases?: Map<string, string>
): ParseExpressionTextResults<RefinementNode>;
parseTextExpression(
fileContents: string,
textOffset: number,
@ -326,7 +337,7 @@ export class Parser {
parseTextMode = ParseTextMode.Expression,
initialParenDepth = 0,
typingSymbolAliases?: Map<string, string>
): ParseExpressionTextResults<ExpressionNode | FunctionAnnotationNode> {
): ParseExpressionTextResults<ExpressionNode | FunctionAnnotationNode | RefinementNode> {
const diagSink = new DiagnosticSink();
this._startNewParse(fileContents, textOffset, textLength, parseOptions, diagSink, initialParenDepth);
@ -334,13 +345,16 @@ export class Parser {
this._typingSymbolAliases = new Map<string, string>(typingSymbolAliases);
}
let parseTree: ExpressionNode | FunctionAnnotationNode | undefined;
let parseTree: ExpressionNode | FunctionAnnotationNode | RefinementNode | undefined;
if (parseTextMode === ParseTextMode.VariableAnnotation) {
this._isParsingQuotedText = true;
parseTree = this._parseTypeAnnotation();
} else if (parseTextMode === ParseTextMode.FunctionAnnotation) {
this._isParsingQuotedText = true;
parseTree = this._parseFunctionTypeAnnotation();
} else if (parseTextMode === ParseTextMode.Refinement) {
this._isParsingQuotedText = true;
parseTree = this._parseRefinement();
} else {
const exprListResult = this._parseTestOrStarExpressionList(
/* allowAssignmentExpression */ false,
@ -3434,6 +3448,8 @@ export class Parser {
return leftExpr;
}
const wasParsingTypeAnnotation = this._isParsingTypeAnnotation;
let peekToken = this._peekToken();
let nextOperator = this._peekOperatorType();
while (
@ -3444,12 +3460,19 @@ export class Parser {
nextOperator === OperatorType.FloorDivide
) {
this._getNextToken();
if (nextOperator === OperatorType.MatrixMultiply) {
this._isParsingTypeAnnotation = false;
}
const rightExpr = this._parseArithmeticFactor();
leftExpr = this._createBinaryOperationNode(leftExpr, rightExpr, peekToken, nextOperator);
peekToken = this._peekToken();
nextOperator = this._peekOperatorType();
}
this._isParsingTypeAnnotation = wasParsingTypeAnnotation;
return leftExpr;
}
@ -4923,6 +4946,45 @@ export class Parser {
return FormatStringNode.create(startToken, endToken, middleTokens, fieldExpressions, formatExpressions);
}
private _parseRefinement(): RefinementNode {
const valueExpr = this._parseRefinementValue();
let conditionExpr: ExpressionNode | undefined;
if (this._consumeTokenIfKeyword(KeywordType.If)) {
conditionExpr = this._parseOrTest();
} else {
const nextToken = this._peekToken();
if (nextToken.type !== TokenType.EndOfStream) {
this._addSyntaxError(LocMessage.expectedPredicateIf(), nextToken);
this._consumeTokensUntilType([TokenType.EndOfStream]);
}
}
return RefinementNode.create(valueExpr, conditionExpr);
}
private _parseRefinementValue(): ExpressionNode {
if (this._isNextTokenNeverExpression()) {
return this._handleExpressionParseError(
ErrorExpressionCategory.MissingExpression,
LocMessage.expectedExpr()
);
}
const exprListResult = this._parseExpressionListGeneric(() => {
if (this._peekOperatorType() === OperatorType.Multiply) {
return this._parseExpression(/* allowUnpack */ true);
}
return this._parseOrTest();
});
if (exprListResult.parseError) {
return exprListResult.parseError;
}
return this._makeExpressionOrTuple(exprListResult, /* enclosedInParens */ false);
}
private _createBinaryOperationNode(
leftExpression: ExpressionNode,
rightExpression: ExpressionNode,

View file

@ -0,0 +1,56 @@
# This sample tests basic handling of refinement types.
# pyright: reportMissingModuleSource=false
from typing import Annotated, Any, Iterable, TypedDict
from typing_extensions import IntValue
class TD1(TypedDict):
x: int
def v_ok1(v: Annotated[int, IntValue("x")]):
reveal_type(v, expected_text='int @ "x"')
def v_ok2(v: Annotated[int, IntValue(value=1)]):
reveal_type(v, expected_text='int @ 1')
# This should generate an error because refinement types
# apply only to nominal class types.
v_bad1: Annotated[int | str, IntValue("x")]
# This should generate an error because refinement types
# apply only to nominal class types.
v_bad2: Annotated[Iterable, IntValue("x")]
# This should generate an error because refinement types
# apply only to nominal class types.
v_bad3: Annotated[TD1, IntValue("x")]
# This should generate an error because refinement types
# apply only to nominal class types.
v_bad4: Annotated[Any, IntValue("x")]
def x_ok1(v: int @ IntValue("x")):
reveal_type(v, expected_text='int @ "x"')
def x_ok2(v: int @ IntValue(value=1)):
reveal_type(v, expected_text='int @ 1')
# This should generate an error because refinement types
# apply only to nominal class types.
x_bad1: (int | str) @ IntValue("x")
# This should generate an error because refinement types
# apply only to nominal class types.
x_bad2: Iterable @ IntValue("x")
# This should generate an error because refinement types
# apply only to nominal class types.
x_bad3: TD1 @ IntValue("x")
# This should generate an error because refinement types
# apply only to nominal class types.
x_bad4: Any @ IntValue("x")

View file

@ -0,0 +1,59 @@
# This sample tests basic parsing of refinement predicates and
# refinement variable consistency.
# pyright: reportMissingModuleSource=false
from typing import Annotated
from typing_extensions import IntValue, Refinement, Shape, StrValue
class Tensor: ...
v_ok1: Annotated[int, IntValue("x if x > 1 and (x < 10 or x % 2 == 0)")]
v_ok2: Annotated[int, Shape("x, y if x > 1 and (x < 10 or x % 2 == 0)")]
v_ok3: Annotated[int, Shape("x, *y if x > 1 and (x < 10 or x % 2 == 0)")]
# This should generate a syntax error because ":" should be "if".
v_bad1: Annotated[int, IntValue("x: x > 1")]
# This should generate a syntax error because "x" isn't a bool value.
v_bad2: Annotated[int, IntValue("x if x")]
# This should generate a syntax error because "y" isn't an int value.
v_bad3: Annotated[int, Shape("x, *y if y < 1")]
# This should generate a syntax error because "x, y" isn't an int value.
v_bad4: Annotated[int, IntValue("x, y")]
# This should generate a syntax error because "x" isn't an int value.
v_bad5: Annotated[int, Shape("x if x < 1")]
# This should generate a syntax error because it's an invalid expression.
v_bad6: Annotated[int, IntValue("x if x < 1 x")]
# This should generate a syntax error because it's an invalid expression.
v_bad7: Annotated[int, IntValue("x if x.foo > 1")]
# This should generate a syntax error because it's an invalid expression.
v_bad8: Annotated[int, IntValue("x if x[1] > 1")]
# This should generate a syntax error because it's an unsupported call.
v_bad9: Annotated[int, IntValue("x if call(x)")]
class CustomRefinementDomain(Refinement):
def __str__(self) -> str:
return ""
# This should generate a syntax error because it's an unknown refinement domain.
v_bad10: Annotated[int, CustomRefinementDomain("x")]
# This should generate two errors because "x" is inconsistent.
v_bad11: int @ IntValue("x") | str @ StrValue("x") | Tensor @ Shape("x")
# This should generate two errors because "x" is inconsistent.
type TA_Bad1 = int @ IntValue("x") | str @ StrValue("x") | Tensor @ Shape("x")

View file

@ -0,0 +1,27 @@
# This sample tests basic scoping rules for refinement variables.
# pyright: reportMissingModuleSource=false
from typing_extensions import IntValue, StrValue
def func_good1(a: int @ IntValue("x"), b: str @ IntValue("x")) -> None:
pass
# This should generate an error because "x" is used twice
# and has inconsistent types.
def func_bad1(a: int @ IntValue("x"), b: str @ StrValue("x")) -> None:
pass
def outer1(a: int @ IntValue("x")):
# This should generate an error because "x"
# has inconsistent types.
def inner(a: str @ StrValue("x")):
pass
v1: int @ IntValue("x")
# This should generate an error because "x"
# has inconsistent types.
v2: str @ StrValue("x")

View file

@ -0,0 +1,55 @@
# This sample tests refinement condition validation.
# pyright: reportMissingModuleSource=false
from typing import Any, cast
from typing_extensions import StrValue
type SingleDigitIt = int @ "x if x >= 0 and x < 10"
x1: SingleDigitIt = 0
x2: SingleDigitIt = 9
# This should generate an error.
x3: SingleDigitIt = 10
# This should generate an error.
x4: SingleDigitIt = -1
def func1(a: int @ "x if x >= 0", b: int @ "y if y < 0 and x + y < 2") -> int:
return a + b
func1(1, -1)
func1(10, -9)
# This should generate an error.
func1(1, 0)
# This should generate an error.
func1(10, -2)
def func2(a: int @ "x", b: int @ "y if x + y < 2") -> int @ "x":
result = a + b
reveal_type(result, expected_text='int @ "x + y"')
return cast(Any, result)
func2(1, -1)
func2(10, -9)
# This should generate an error.
func2(1, 4)
# This should generate an error.
func2(10, -2)
y1: str @ StrValue("x if x == 'hi' or x == 'bye'") = "bye"
y1 = "hi"
# This should generate an error.
y1 = "neither"

View file

@ -0,0 +1,40 @@
# This sample tests the handling of custom refinement types.
# pyright: reportMissingModuleSource=false
from typing_extensions import StrRefinement
class Units(StrRefinement):
def __str__(self) -> str:
return ""
type FloatYards = float @ Units(value="yards")
type FloatMeters = float @ Units(value="meters")
def add_units(a: float @ Units("x"), b: float @ Units("x")) -> float @ Units("x"):
return a + b
def convert_yards_to_meters(a: float @ Units("'yards'")) -> float @ Units("'meters'"):
return a * 0.9144
def test2(a: float @ Units("'meters'"), b: float @ Units("'yards'")):
m = convert_yards_to_meters(b)
reveal_type(m, expected_text="float @ Units(\"'meters'\")")
add_units(a, m)
# This should generate an error.
add_units(a, b)
def test3(a: FloatMeters, b: FloatYards):
m = convert_yards_to_meters(b)
reveal_type(m, expected_text="float @ Units(\"'meters'\")")
add_units(a, m)
# This should generate an error.
add_units(a, b)

View file

@ -0,0 +1,33 @@
# This sample tests enforced and unenforced refinement types.
# pyright: reportMissingModuleSource=false
from typing_extensions import IntValue, Shape
def get_int() -> int: ...
# This should generate an error.
i1: int @ 1 = get_int()
# This should generate an error.
i2: int @ IntValue(value=1) = get_int()
i3: int @ IntValue(value=1, enforce=False) = get_int()
i4: int @ "x" = get_int()
i5: int @ IntValue("x", enforce=False) = get_int()
# This should generate an error.
i6: int @ IntValue("x", enforce=True) = get_int()
class Tensor: ...
t1: Tensor @ Shape("x") = Tensor()
# This should generate an error.
t2: Tensor @ Shape("x", enforce=True) = Tensor()

View file

@ -0,0 +1,25 @@
# This sample tests the case where a class supports implicit refinement
# types through the __type_metadata__ magic method.
# pyright: reportMissingModuleSource=false
from typing import cast
from typing_extensions import Shape
class ClassA:
@classmethod
def __type_metadata__(cls, pred: str) -> Shape:
return Shape(pred)
def test1(a: ClassA @ Shape("x, "), b: ClassA @ "y, ") -> ClassA @ "x, y":
reveal_type(a, expected_text='ClassA @ "x,"')
reveal_type(b, expected_text='ClassA @ "y,"')
return cast(ClassA @ "x, y", ClassA())
def test2(m: ClassA @ "1, ", n: ClassA @ "2, "):
v = test1(m, n)
reveal_type(v, expected_text='ClassA @ "1, 2"')

View file

@ -0,0 +1,59 @@
# This sample tests basic interactions between Literal and refinement types.
# pyright: reportMissingModuleSource=false
from typing import Literal, cast
from typing_extensions import StrValue
def func1(a: int @ "x") -> int @ "x":
return a
v1 = func1(-1)
reveal_type(v1, expected_text="Literal[-1]")
def func2(a: int @ "x", b: int @ "y") -> int @ "x" | int @ "y":
return a if a > b else b
v2 = func2(-1, 2)
reveal_type(v2, expected_text="Literal[-1, 2]")
def func3(a: str @ StrValue("x"), b: str @ StrValue("y")) -> str @ StrValue("x + y"):
return cast(str @ StrValue("x + y"), a + b)
v3 = func3("hi ", "there")
reveal_type(v3, expected_text="Literal['hi there']")
def func4(x: int @ 2):
y1: Literal[2] = x
y2: Literal[1, 2] = x
# This should result in an error.
y3: Literal[3] = x
def func5(x: bool @ False):
y1: Literal[False] = x
y2: bool @ False = x
# This should result in an error.
y3: Literal[True] = x
# This should result in an error.
y4: bool @ True = x
def is_greater(a: int @ "a", b: int @ "b") -> bool @ "a > b":
return a > b
def func6():
reveal_type(is_greater(1, 2), expected_text="Literal[False]")
reveal_type(is_greater(2, 1), expected_text="Literal[True]")

View file

@ -0,0 +1,55 @@
# This sample tests the "Shape" refinement type.
# pyright: reportMissingModuleSource=false
from typing_extensions import Shape
class Tensor: ...
def matmul(a: Tensor @ Shape("x, y"), b: Tensor @ Shape("y, z")) -> Tensor @ Shape(
"x, z"
): ...
def func1(
a: Tensor @ Shape("a, b"), b: Tensor @ Shape("b, c"), c: Tensor @ Shape("b, b")
):
v1 = matmul(a, b)
reveal_type(v1, expected_text='Tensor @ Shape("a, c")')
v2 = matmul(a, c)
reveal_type(v2, expected_text='Tensor @ Shape("a, b")')
v3 = matmul(c, c)
reveal_type(v3, expected_text='Tensor @ Shape("b, b")')
v4 = matmul(c, b)
reveal_type(v4, expected_text='Tensor @ Shape("b, c")')
# This should generate an error.
matmul(c, a)
# This should generate an error.
matmul(b, a)
class Size(tuple[int, ...]): ...
def func2(v: Size @ Shape("x, y, z")):
a, b, c = v
reveal_type(a, expected_text='int @ "x"')
reveal_type(b, expected_text='int @ "y"')
reveal_type(c, expected_text='int @ "z"')
d, *x = v
reveal_type(d, expected_text='int @ "x"')
reveal_type(x, expected_text='list[int @ "y" | int @ "z"]')
k, j, *m, n = v
reveal_type(k, expected_text='int @ "x"')
reveal_type(j, expected_text='int @ "y"')
reveal_type(m, expected_text="list[Never]")
reveal_type(n, expected_text='int @ "z"')

View file

@ -0,0 +1,137 @@
# This sample tests various aspects of the Shape refinement type.
# pyright: reportMissingModuleSource=false
from tensorlib import Size, Tensor, cat, conv2d, randn, sum
def func1(x: Size @ "x, y") -> Size @ "x, y":
return x
def func2(s1: Size @ "1, 2", s2: Size @ "1, 2, x"):
# This should generate an error.
func1(s2)
v1 = func1(s1)
reveal_type(v1, expected_text='Size @ "1, 2"')
x1, y1 = s1
reveal_type(x1, expected_text="int @ 1")
reveal_type(y1, expected_text="int @ 2")
# This should generate an error.
x2, y2, z2 = s1
x3, *other3 = s2
reveal_type(x3, expected_text="int @ 1")
reveal_type(other3, expected_text='list[int @ 2 | int @ "x"]')
def index1(t1: Tensor @ "a, b, c"):
s1 = t1.shape
reveal_type(s1, expected_text='Size @ "a, b, c"')
s2 = s1[2]
reveal_type(s2, expected_text='int @ "c"')
s3 = s1[-3]
reveal_type(s3, expected_text='int @ "a"')
# This should generate an error.
s4 = s1[-4]
# This should generate an error.
s5 = s1[4]
def index2(t1: Tensor @ "a, b, *other"):
s1 = t1.shape
reveal_type(s1, expected_text='Size @ "a, b, *other"')
s2 = s1[2]
reveal_type(s2, expected_text='int @ "index((a, b, *other), 2)"')
s3 = s1[-3]
reveal_type(s3, expected_text='int @ "index((a, b, *other), -3)"')
s4 = s1[-4]
reveal_type(s4, expected_text='int @ "index((a, b, *other), -4)"')
s5 = s1[4]
reveal_type(s5, expected_text='int @ "index((a, b, *other), 4)"')
def concat1(t1: Tensor @ "a, b, c", t2: Tensor @ "a, 1, c"):
s1 = cat((t1, t2), dim=1)
reveal_type(s1, expected_text='Tensor @ "a, b + 1, c"')
s2 = cat((t1, t2, t2), dim=1)
reveal_type(s2, expected_text='Tensor @ "a, b + 2, c"')
# This should generate an error.
s3 = cat((t1, t2, t2))
# This should generate an error.
s4 = cat((t1, t2, t2), dim=2)
# This should generate an error.
s5 = cat((t1, t2, t2), dim=-1)
# This should generate an error.
s6 = cat((t1, t2, t2), dim=5)
def conv1(input: Tensor @ "n, c_in, y, x", weight: Tensor @ "c_out, c_in, ky, kx"):
c1 = conv2d(input, weight)
reveal_type(c1, expected_text='Tensor @ "n, c_out, y - ky + 1, x - kx + 1"')
def conv2(x: Tensor @ "B, C, H, W", filters: Tensor @ "C, C, F1, F2"):
return conv2d(x, filters, stride=2)
def conv3():
filters = randn(4, 4, 5, 5)
reveal_type(filters, expected_text='Tensor @ "4, 4, 5, 5"')
c0 = conv2(randn(1, 4, 5, 5), filters)
reveal_type(c0, expected_text='Tensor @ "1, 4, 1, 1"')
c1 = conv2(randn(1, 4, 32, 32), filters)
reveal_type(c1, expected_text='Tensor @ "1, 4, 14, 14"')
c2 = conv2(randn(1, 4, 53, 32), filters)
reveal_type(c2, expected_text='Tensor @ "1, 4, 25, 14"')
c3 = conv2(randn(1, 4, 28, 28), filters)
reveal_type(c3, expected_text='Tensor @ "1, 4, 12, 12"')
def sum1(t1: Tensor @ "a, b"):
s1 = sum(t1)
reveal_type(s1, expected_text='Tensor @ "1,"')
s2 = sum(t1, dim=0)
reveal_type(s2, expected_text='Tensor @ "b,"')
s3 = sum(t1, dim=0, keepdim=True)
reveal_type(s3, expected_text='Tensor @ "1, b"')
s4 = sum(t1, dim=1)
reveal_type(s4, expected_text='Tensor @ "a,"')
s5 = sum(t1, dim=1, keepdim=True)
reveal_type(s5, expected_text='Tensor @ "a, 1"')
s6 = sum(t1, dim=-1)
reveal_type(s6, expected_text='Tensor @ "a,"')
s7 = sum(t1, dim=-2)
reveal_type(s7, expected_text='Tensor @ "b,"')
# This should generate an error.
s8 = sum(t1, dim=2)
# This should generate an error.
s9 = sum(t1, dim=-3)

View file

@ -0,0 +1,119 @@
# This sample tests various aspects of the "Shape" refinement type.
# pyright: reportMissingModuleSource=false
from tensorlib import Tensor, linspace, randn, index_select, permute, squeeze, unsqueeze
def broadcast1(
t1: Tensor @ "a, b",
t2: Tensor @ "x, a, b",
t3: Tensor @ "x, 1, 1",
t4: Tensor @ "1, a, c if c == b",
t5: Tensor @ "1, a, d",
t6: Tensor @ "3, 1, 4",
t7: Tensor @ "5, 1, 5, 1",
):
d1 = t1.sub(t2)
reveal_type(d1, expected_text='Tensor @ "x, a, b"')
d2 = t1 - t2
reveal_type(d2, expected_text='Tensor @ "x, a, b"')
d3 = t2 + t3
reveal_type(d3, expected_text='Tensor @ "x, a, b"')
d3 = t2 - t3
reveal_type(d3, expected_text='Tensor @ "x, a, b"')
d4 = t2 + t4
reveal_type(d4, expected_text='Tensor @ "x, a, b"')
# This should generate an error.
d5 = t2 - t5
d6 = t6 + t7
reveal_type(d6, expected_text='Tensor @ "5, 3, 5, 4"')
def linspace1(i1: int @ "a if a > 0"):
t1 = linspace(0, 10, 4)
reveal_type(t1, expected_text='Tensor @ "4,"')
t2 = linspace(0, 4, i1)
reveal_type(t2, expected_text='Tensor @ "a,"')
t3_out = randn(2)
reveal_type(t3_out, expected_text='Tensor @ "2,"')
t3 = linspace(0, 4, 2, out=t3_out)
reveal_type(t3, expected_text='Tensor @ "2,"')
# This should generate an error.
t4 = linspace(0, 4, 3, out=t3_out)
# This should generate an error.
t5 = linspace(0, 4, 0)
def index_select1(i: Tensor @ "a, b", x: Tensor @ "3, "):
t1 = index_select(i, 0, x)
reveal_type(t1, expected_text='Tensor @ "3, b"')
t2 = index_select(i, 1, x)
reveal_type(t2, expected_text='Tensor @ "a, 3"')
# This should generate an error.
t3 = index_select(i, 2, x)
def permute1(t1: Tensor @ "a, b, c"):
p1 = permute(t1, (1, 2, 0))
reveal_type(p1, expected_text='Tensor @ "b, c, a"')
p2 = permute(t1, (0, 2, 1))
reveal_type(p2, expected_text='Tensor @ "a, c, b"')
# This should generate an error.
p3 = permute(t1, (0, 2, 0))
# This should generate an error.
p4 = permute(t1, (0, 2, 4))
# This should generate an error.
p5 = permute(t1, (0, 1, 2, 3))
def squeeze1(t1: Tensor @ "a, b, 1, 2"):
s1 = squeeze(t1, 2)
reveal_type(s1, expected_text='Tensor @ "a, b, 2"')
s2 = squeeze(t1, -2)
reveal_type(s1, expected_text='Tensor @ "a, b, 2"')
s3 = squeeze(t1, -1)
reveal_type(s3, expected_text='Tensor @ "a, b, 1, 2"')
# This should generate two errors.
s4 = squeeze(t1, 5)
s5 = squeeze(t1, 1)
reveal_type(s5, expected_text='Tensor @ "a, b, 1, 2"')
def squeeze2(t1: Tensor @ "a, b, 1, 2"):
s1 = squeeze(t1, (1, 2))
reveal_type(s1, expected_text="Tensor")
def unsqueeze1(t1: Tensor @ "a, b, c"):
u1 = unsqueeze(t1, 1)
reveal_type(u1, expected_text='Tensor @ "a, 1, b, c"')
u2 = unsqueeze(t1, 3)
reveal_type(u2, expected_text='Tensor @ "a, b, c, 1"')
u3 = unsqueeze(t1, -1)
reveal_type(u3, expected_text='Tensor @ "a, b, c, 1"')
u4 = unsqueeze(t1, -2)
reveal_type(u4, expected_text='Tensor @ "a, b, 1, c"')

View file

@ -0,0 +1,52 @@
# This sample tests the handling of refinement types that use the
# IntTuple refinement domain.
# pyright: reportMissingModuleSource=false
from typing import cast
from typing_extensions import Shape
class Tensor:
...
def func1(a: int @ "x", b: int @ "y") -> Tensor @ Shape("x, y"):
...
v1 = func1(1, 2)
reveal_type(v1, expected_text='Tensor @ Shape("1, 2")')
def func2(a: Tensor @ Shape("x, y")) -> Tensor @ Shape("y, x"):
...
t2 = cast(Tensor @ Shape("1, 2"), Tensor())
v2 = func2(t2)
reveal_type(v2, expected_text='Tensor @ Shape("2, 1")')
def func3(a: Tensor @ Shape("a, b")):
x = func2(a)
reveal_type(x, expected_text='Tensor @ Shape("b, a")')
return x
t3_1 = cast(Tensor @ Shape("1, 2"), Tensor())
v3_1 = func3(t3_1)
reveal_type(v3_1, expected_text='Tensor @ Shape("2, 1")')
t3_2 = cast(Tensor @ Shape("_, 1"), Tensor())
v3_2 = func3(t3_2)
reveal_type(v3_2, expected_text='Tensor @ Shape("1, _")')
def func4(a: Tensor @ Shape("a, b, *other")) -> Tensor @ Shape("a, *other, b"):
...
t4_1 = cast(Tensor @ Shape("1, 2, 3, 4"), Tensor())
v4_1 = func4(t4_1)
reveal_type(v4_1, expected_text='Tensor @ Shape("1, 3, 4, 2")')

View file

@ -0,0 +1,105 @@
# This sample represents a typical tensor library that might use the
# "Shape" refinement. It is used by other test cases.
# This sample defines the basic classes found in a typical tensor
# library. It's meant to test refinement types for tensor shape
# validation.
from typing import Literal, overload
from typing_extensions import IntTupleValue, Shape
class Size(tuple[int, ...]):
@classmethod
def __type_metadata__(cls, predicate: str) -> Shape: ...
def __getitem__(self: Size @ "o", key: int @ "i") -> int @ "index(o, i)": ... # pyright: ignore[reportIncompatibleMethodOverride]
def elem_count(self: Size @ "o") -> int @ "len(o)": ...
class Tensor:
def __init__(self, value) -> None: ...
@classmethod
def __type_metadata__(cls, predicate: str) -> Shape: ...
@property
def shape(self: Tensor @ "o") -> Size @ "o": ...
def transpose(self: Tensor @ "o", dim0: int @ "a", dim1: int @ "b") -> Tensor @ "swap(o, a, b)": ...
def view(self: Tensor @ "o", *args: *(tuple[int, ...] @ IntTupleValue("t"))) -> Tensor @ "reshape(o, t)": ...
def cos(self: Tensor @ "o") -> Tensor @ "o": ...
def sin(self: Tensor @ "o") -> Tensor @ "o": ...
def add(self: Tensor @ "o", input: Tensor @ "i") -> Tensor @ "broadcast(o, i)": ...
def __add__(self: Tensor @ "o", input: Tensor @ "i") -> Tensor @ "broadcast(o, i)": ...
def sub(self: Tensor @ "o", input: Tensor @ "i") -> Tensor @ "broadcast(o, i)": ...
def __sub__(self: Tensor @ "o", input: Tensor @ "i") -> Tensor @ "broadcast(o, i)": ...
def pow(self: Tensor @ "o", exp: float) -> Tensor @ "o": ...
def __pow__(self: Tensor @ "o", exp: float) -> Tensor @ "o": ...
@overload
def cat(
tensors: tuple[Tensor @ "t1", Tensor @ "t2"],
dim: int @ "d" = 0,
*,
out: Tensor @ "o if o == concat(t1, t2, d)" | None = ...,
) -> Tensor @ "concat(t1, t2, d)": ...
@overload
def cat(
tensors: tuple[Tensor @ "t1", Tensor @ "t2", Tensor @ "t3"],
dim: int @ "d" = 0,
*,
out: Tensor @ "o if o == concat(concat(t1, t2, d), t3, d)" | None = ...,
) -> Tensor @ "concat(concat(t1, t2, d), t3, d)": ...
@overload
def cat(
tensors: tuple[Tensor @ "t1", Tensor @ "t2", Tensor @ "t3", Tensor @ "t4"],
dim: int @ "d" = 0,
*,
out: Tensor @ "o if o == concat(concat(concat(t1, t2, d), t3, d), t4, d)" | None = ...,
) -> Tensor @ "concat(concat(concat(t1, t2, d), t3, d), t4, d)": ...
@overload
def cat(tensors: tuple[Tensor, ...], dim: int = 0, *, out: Tensor | None = ...) -> Tensor: ...
def conv2d(
input: Tensor @ "n, c_in, h_in, w_in",
weight: Tensor @ "c_out, gr, k0, k1 if gr == c_in // g",
bias: Tensor @ "c_out," | None = None,
stride: int @ "s0" @ "s1" | tuple[int @ "s0", int @ "s1"] = 1,
padding: int @ "p0" @ "p1" | tuple[int @ "p0", int @ "p1"] = 0,
dilation: int @ "d0" @ "d1" | tuple[int @ "d0", int @ "d1"] = 1,
groups: int @ "g" = 1,
) -> (
Tensor @ "n, c_out, (h_in + 2 * p0 - d0 * (k0 - 1) - 1) // s0 + 1, (w_in + 2 * p1 - d1 * (k1 - 1) - 1) // s1 + 1"
): ...
def index_select(
input: Tensor @ "o", dim: int @ "i if index(o, i) == _", index: Tensor @ "x,"
) -> Tensor @ "splice(o, i, 1, (x,))": ...
def linspace(
start: float, end: float, steps: int @ "x if x > 0", *, out: Tensor @ "x, " | None = None
) -> Tensor @ "x, ": ...
def logspace(
start: float, end: float, steps: int @ "x if x > 0", *, out: Tensor @ "x, " | None = None
) -> Tensor @ "x, ": ...
def permute(
input: Tensor @ "o", dims: tuple[int, ...] @ IntTupleValue("d if permute(o, d) == _")
) -> Tensor @ "permute(o, d)": ...
def randn(*args: *(tuple[int, ...] @ Shape("o"))) -> Tensor @ "o": ...
def sqrt(input: Tensor @ "o") -> Tensor @ "o": ...
@overload
def squeeze(input: Tensor @ "o", dim: int @ "i if index(o, i) == 1") -> Tensor @ "splice(o, i, 1, ())": ...
@overload
def squeeze(input: Tensor @ "o", dim: int @ "i if index(o, i) != 1") -> Tensor @ "o": ...
@overload
def squeeze(input: Tensor, dim: tuple[int, ...] | None = None) -> Tensor: ...
@overload
def sum(input: Tensor @ "o", *, dim: None = None, keepdim: bool = False) -> Tensor @ "(1, )": ...
@overload
def sum(
input: Tensor @ "o",
*,
dim: int @ "d",
keepdim: Literal[False] = False,
) -> Tensor @ "splice(o, d, 1, ())": ...
@overload
def sum(input: Tensor @ "o", *, dim: int @ "d", keepdim: Literal[True]) -> Tensor @ "splice(o, d, 1, (1, ))": ...
def take(input: Tensor @ "o", index: Tensor @ "x,") -> Tensor @ "x,": ...
@overload
def unsqueeze(input: Tensor @ "o", dim: int @ "d if d >= 0") -> Tensor @ "splice(o, d, 0, (1,))": ...
@overload
def unsqueeze(input: Tensor @ "o", dim: int @ "d if d < 0") -> Tensor @ "splice(o, len(o) + d + 1, 0, (1,))": ...
@overload
def unsqueeze(input: Tensor, dim: int) -> Tensor: ...

View file

@ -0,0 +1,26 @@
# This sample tests the handling of the TypeMetadata class.
# pyright: reportMissingModuleSource=false
from typing_extensions import TypeMetadata
class NotTypeMeta: ...
class TM1(TypeMetadata): ...
# This should generate an error.
v1: int @ []
# This should generate an error.
v2: list[int] @ [1]
# This should generate an error.
v3: int @ NotTypeMeta()
# This should generate an error.
v4: int @ TM1
v5: int @ TM1()

View file

@ -11,7 +11,7 @@
import * as assert from 'assert';
import { ConfigOptions } from '../common/configOptions';
import { pythonVersion3_10, pythonVersion3_11, pythonVersion3_8 } from '../common/pythonVersion';
import { pythonVersion3_10, pythonVersion3_11, pythonVersion3_12, pythonVersion3_8 } from '../common/pythonVersion';
import { Uri } from '../common/uri/uri';
import * as TestUtils from './testUtils';
@ -972,3 +972,110 @@ test('TypeForm7', () => {
TestUtils.validateResults(analysisResults, 1);
});
test('TypeMeta1', () => {
const configOptions = new ConfigOptions(Uri.empty());
configOptions.diagnosticRuleSet.enableExperimentalFeatures = true;
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['typeMeta1.py'], configOptions);
TestUtils.validateResults(analysisResults, 4);
});
test('Refinement1', () => {
const configOptions = new ConfigOptions(Uri.empty());
configOptions.diagnosticRuleSet.enableExperimentalFeatures = true;
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['refinement1.py'], configOptions);
TestUtils.validateResults(analysisResults, 8);
});
test('Refinement2', () => {
const configOptions = new ConfigOptions(Uri.empty());
configOptions.defaultPythonVersion = pythonVersion3_12;
configOptions.diagnosticRuleSet.enableExperimentalFeatures = true;
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['refinement2.py'], configOptions);
TestUtils.validateResults(analysisResults, 14);
});
test('Refinement3', () => {
const configOptions = new ConfigOptions(Uri.empty());
configOptions.diagnosticRuleSet.enableExperimentalFeatures = true;
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['refinement3.py'], configOptions);
TestUtils.validateResults(analysisResults, 3);
});
test('Refinement4', () => {
const configOptions = new ConfigOptions(Uri.empty());
configOptions.defaultPythonVersion = pythonVersion3_12;
configOptions.diagnosticRuleSet.enableExperimentalFeatures = true;
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['refinement4.py'], configOptions);
TestUtils.validateResults(analysisResults, 7);
});
test('RefinementEnforce1', () => {
const configOptions = new ConfigOptions(Uri.empty());
configOptions.diagnosticRuleSet.enableExperimentalFeatures = true;
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['refinementEnforce1.py'], configOptions);
TestUtils.validateResults(analysisResults, 4);
});
test('RefinementImplicit1', () => {
const configOptions = new ConfigOptions(Uri.empty());
configOptions.defaultPythonVersion = pythonVersion3_12;
configOptions.diagnosticRuleSet.enableExperimentalFeatures = true;
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['refinementImplicit1.py'], configOptions);
TestUtils.validateResults(analysisResults, 0);
});
test('RefinementLiteral1', () => {
const configOptions = new ConfigOptions(Uri.empty());
configOptions.diagnosticRuleSet.enableExperimentalFeatures = true;
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['refinementLiteral1.py'], configOptions);
TestUtils.validateResults(analysisResults, 3);
});
test('RefinementShape1', () => {
const configOptions = new ConfigOptions(Uri.empty());
configOptions.diagnosticRuleSet.enableExperimentalFeatures = true;
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['refinementShape1.py'], configOptions);
TestUtils.validateResults(analysisResults, 2);
});
test('RefinementShape2', () => {
const configOptions = new ConfigOptions(Uri.empty());
configOptions.diagnosticRuleSet.enableExperimentalFeatures = true;
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['refinementShape2.py'], configOptions);
TestUtils.validateResults(analysisResults, 10);
});
test('RefinementShape3', () => {
const configOptions = new ConfigOptions(Uri.empty());
configOptions.diagnosticRuleSet.enableExperimentalFeatures = true;
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['refinementShape3.py'], configOptions);
TestUtils.validateResults(analysisResults, 9);
});
test('RefinementCustom1', () => {
const configOptions = new ConfigOptions(Uri.empty());
configOptions.diagnosticRuleSet.enableExperimentalFeatures = true;
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['refinementCustom1.py'], configOptions);
TestUtils.validateResults(analysisResults, 2);
});
test('RefinementTuple1', () => {
const configOptions = new ConfigOptions(Uri.empty());
configOptions.diagnosticRuleSet.enableExperimentalFeatures = true;
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['refinementTuple1.py'], configOptions);
TestUtils.validateResults(analysisResults, 0);
});

View file

@ -539,3 +539,73 @@ if sys.version_info >= (3, 14):
from typing import TypeForm
else:
TypeForm: _SpecialForm
# Refinement Types
class TypeMetadata:
def __rmatmul__[T](self, typ: T) -> T:
return typ
class Refinement(TypeMetadata, metaclass=abc.ABCMeta):
def __init__(self, predicate: str | None, /) -> None: ...
@abc.abstractmethod
def __str__(self) -> str: ...
class IntRefinement(Refinement, metaclass=abc.ABCMeta):
predicate: str | None
@overload
def __init__(
self, predicate: None = None, /, *, value: int, enforce: bool = True
) -> None: ...
@overload
def __init__(self, predicate: str, /, *, enforce: bool = True) -> None: ...
class StrRefinement(Refinement, metaclass=abc.ABCMeta):
predicate: str | None
@overload
def __init__(
self, predicate: None = None, /, *, value: str, enforce: bool = True
) -> None: ...
@overload
def __init__(self, predicate: str, /, *, enforce: bool = False) -> None: ...
class BytesRefinement(Refinement, metaclass=abc.ABCMeta):
predicate: str | None
@overload
def __init__(
self, predicate: None = None, /, *, value: bytes, enforce: bool = True
) -> None: ...
@overload
def __init__(self, predicate: str, /, *, enforce: bool = False) -> None: ...
class BoolRefinement(Refinement, metaclass=abc.ABCMeta):
predicate: str | None
@overload
def __init__(
self, predicate: None = None, /, *, value: bool, enforce: bool = True
) -> None: ...
@overload
def __init__(self, predicate: str, /, *, enforce: bool = False) -> None: ...
class IntTupleRefinement(Refinement, metaclass=abc.ABCMeta):
predicate: str | None
def __init__(self, predicate: str, /, *, enforce: bool = False) -> None: ...
class IntValue(IntRefinement):
def __str__(self) -> str: ...
class StrValue(StrRefinement):
def __str__(self) -> str: ...
class BytesValue(BytesRefinement):
def __str__(self) -> str: ...
class BoolValue(BoolRefinement):
def __str__(self) -> str: ...
class IntTupleValue(IntTupleRefinement):
def __str__(self) -> str: ...
class Shape(IntTupleRefinement):
def __str__(self) -> str: ...