mirror of
https://github.com/microsoft/pyright.git
synced 2025-12-23 09:19:29 +00:00
Implemented prototype of refinement types.
This commit is contained in:
parent
243983220e
commit
e08c7a1e0f
40 changed files with 6946 additions and 112 deletions
|
|
@ -670,7 +670,10 @@ export class Binder extends ParseTreeWalker {
|
|||
// Skip if we're in an 'Annotated' annotation because this creates
|
||||
// problems for "No Return" return type analysis when annotation
|
||||
// evaluation is deferred.
|
||||
if (!this._isInAnnotatedAnnotation) {
|
||||
if (
|
||||
!this._isInAnnotatedAnnotation &&
|
||||
!ParseTreeUtils.isNodeContainedWithinNodeType(node, ParseNodeType.StringList)
|
||||
) {
|
||||
this._createCallFlowNode(node);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,19 +9,25 @@
|
|||
*/
|
||||
|
||||
import { assert } from '../common/debug';
|
||||
import { RefinementExpr, RefinementVarId } from './refinementTypes';
|
||||
import { FunctionType, ParamSpecType, Type, TypeVarType } from './types';
|
||||
|
||||
export type RefinementVarMap = Map<RefinementVarId, RefinementExpr | undefined>;
|
||||
|
||||
// Records the types associated with a set of type variables.
|
||||
export class ConstraintSolutionSet {
|
||||
// Indexed by TypeVar ID.
|
||||
private _typeVarMap: Map<string, Type | undefined>;
|
||||
|
||||
// Indexed by refinement var ID.
|
||||
private _refinementVarMap: RefinementVarMap | undefined;
|
||||
|
||||
constructor() {
|
||||
this._typeVarMap = new Map();
|
||||
}
|
||||
|
||||
isEmpty() {
|
||||
return this._typeVarMap.size === 0;
|
||||
return this._typeVarMap.size === 0 && !this._refinementVarMap;
|
||||
}
|
||||
|
||||
getType(typeVar: ParamSpecType): FunctionType | undefined;
|
||||
|
|
@ -48,6 +54,34 @@ export class ConstraintSolutionSet {
|
|||
}
|
||||
});
|
||||
}
|
||||
|
||||
getRefinementVarType(refinementVarId: string): RefinementExpr | undefined {
|
||||
return this._refinementVarMap?.get(refinementVarId);
|
||||
}
|
||||
|
||||
setRefinementVarType(refinementVarId: string, value: RefinementExpr | undefined) {
|
||||
if (!this._refinementVarMap) {
|
||||
this._refinementVarMap = new Map();
|
||||
}
|
||||
|
||||
this._refinementVarMap.set(refinementVarId, value);
|
||||
}
|
||||
|
||||
hasRefinementVarType(refinementVarId: string): boolean {
|
||||
return this._refinementVarMap?.has(refinementVarId) ?? false;
|
||||
}
|
||||
|
||||
doForEachRefinementVar(callback: (value: RefinementExpr, refinementVarId: string) => void) {
|
||||
this._refinementVarMap?.forEach((type, key) => {
|
||||
if (type) {
|
||||
callback(type, key);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
getRefinementVarMap(): RefinementVarMap {
|
||||
return this._refinementVarMap ?? new Map();
|
||||
}
|
||||
}
|
||||
|
||||
export class ConstraintSolution {
|
||||
|
|
@ -68,6 +102,12 @@ export class ConstraintSolution {
|
|||
});
|
||||
}
|
||||
|
||||
setRefinementVarType(refinementVarId: string, value: RefinementExpr) {
|
||||
return this._solutionSets.forEach((set) => {
|
||||
set.setRefinementVarType(refinementVarId, value);
|
||||
});
|
||||
}
|
||||
|
||||
getMainSolutionSet() {
|
||||
return this.getSolutionSet(0);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ import { DiagnosticAddendum } from '../common/diagnostic';
|
|||
import { LocAddendum } from '../localization/localize';
|
||||
import { ConstraintSolution, ConstraintSolutionSet } from './constraintSolution';
|
||||
import { ConstraintSet, ConstraintTracker, TypeVarConstraints } from './constraintTracker';
|
||||
import { solveRefinementVarRecursive } from './refinementSolver';
|
||||
import {
|
||||
AssignTypeFlags,
|
||||
maxSubtypesForInferredType,
|
||||
|
|
@ -248,6 +249,11 @@ export function solveConstraintSet(
|
|||
solveTypeVarRecursive(evaluator, constraintSet, options, solutionSet, entry);
|
||||
});
|
||||
|
||||
// Solve the refinement variables.
|
||||
constraintSet.doForEachRefinementVar((name) => {
|
||||
solveRefinementVarRecursive(constraintSet, solutionSet, name);
|
||||
});
|
||||
|
||||
return solutionSet;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@
|
|||
*/
|
||||
|
||||
import { assert } from '../common/debug';
|
||||
import { RefinementExpr } from './refinementTypes';
|
||||
import { isRefinementExprEquivalent } from './refinementTypeUtils';
|
||||
import { getComplexityScoreForType } from './typeComplexity';
|
||||
import { Type, TypeVarScopeId, TypeVarType, isTypeSame } from './types';
|
||||
|
||||
|
|
@ -41,6 +43,9 @@ export class ConstraintSet {
|
|||
// Maps type variable IDs to their current constraints.
|
||||
private _typeVarMap: Map<string, TypeVarConstraints>;
|
||||
|
||||
// Maps refinement variable IDs to their current values.
|
||||
private _refinementVarMap: Map<string, RefinementExpr> | undefined;
|
||||
|
||||
// A set of one or more TypeVar scope IDs that identify this constraint set.
|
||||
// This corresponds to the scope ID of the overload signature. Normally
|
||||
// there will be only one scope ID associated with each signature, but
|
||||
|
|
@ -65,6 +70,13 @@ export class ConstraintSet {
|
|||
this._scopeIds.forEach((scopeId) => constraintSet.addScopeId(scopeId));
|
||||
}
|
||||
|
||||
if (this._refinementVarMap) {
|
||||
constraintSet._refinementVarMap = new Map<string, RefinementExpr>();
|
||||
this._refinementVarMap.forEach((value, key) => {
|
||||
constraintSet._refinementVarMap!.set(key, value);
|
||||
});
|
||||
}
|
||||
|
||||
return constraintSet;
|
||||
}
|
||||
|
||||
|
|
@ -93,6 +105,21 @@ export class ConstraintSet {
|
|||
}
|
||||
});
|
||||
|
||||
if (this._refinementVarMap) {
|
||||
if (!other._refinementVarMap || this._refinementVarMap.size !== other._refinementVarMap.size) {
|
||||
isSame = false;
|
||||
} else {
|
||||
this._refinementVarMap.forEach((value, key) => {
|
||||
const otherValue = other._refinementVarMap!.get(key);
|
||||
if (!otherValue || !isRefinementExprEquivalent(value, otherValue)) {
|
||||
isSame = false;
|
||||
}
|
||||
});
|
||||
}
|
||||
} else if (other._refinementVarMap) {
|
||||
isSame = false;
|
||||
}
|
||||
|
||||
return isSame;
|
||||
}
|
||||
|
||||
|
|
@ -180,6 +207,24 @@ export class ConstraintSet {
|
|||
|
||||
return false;
|
||||
}
|
||||
|
||||
getRefinementVarType(refinementVarId: string): RefinementExpr | undefined {
|
||||
return this._refinementVarMap?.get(refinementVarId);
|
||||
}
|
||||
|
||||
setRefinementVarType(refinementVarId: string, value: RefinementExpr) {
|
||||
if (!this._refinementVarMap) {
|
||||
this._refinementVarMap = new Map<string, RefinementExpr>();
|
||||
}
|
||||
|
||||
this._refinementVarMap.set(refinementVarId, value);
|
||||
}
|
||||
|
||||
doForEachRefinementVar(cb: (id: string, value: RefinementExpr) => void) {
|
||||
this._refinementVarMap?.forEach((value, key) => {
|
||||
cb(key, value);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
export class ConstraintTracker {
|
||||
|
|
@ -276,4 +321,10 @@ export class ConstraintTracker {
|
|||
assert(index >= 0 && index < this._constraintSets.length);
|
||||
return this._constraintSets[index];
|
||||
}
|
||||
|
||||
setRefinementVarType(refinementVarId: string, value: RefinementExpr) {
|
||||
return this._constraintSets.forEach((set) => {
|
||||
set.setRefinementVarType(refinementVarId, value);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ import { getEmptyRange } from '../common/textRange';
|
|||
import { Uri } from '../common/uri/uri';
|
||||
import { NameNode, ParseNodeType } from '../parser/parseNodes';
|
||||
import { ImportLookup, ImportLookupResult } from './analyzerFileInfo';
|
||||
import { AliasDeclaration, Declaration, DeclarationType, ModuleLoaderActions, isAliasDeclaration } from './declaration';
|
||||
import { AliasDeclaration, Declaration, DeclarationType, isAliasDeclaration, ModuleLoaderActions } from './declaration';
|
||||
import { getFileInfoFromNode } from './parseTreeUtils';
|
||||
import { Symbol } from './symbol';
|
||||
|
||||
|
|
|
|||
|
|
@ -23,12 +23,15 @@ import {
|
|||
import { OperatorType } from '../parser/tokenizerTypes';
|
||||
import { getFileInfo } from './analyzerNodeInfo';
|
||||
import { getEnclosingLambda, isWithinLoop, operatorSupportsChaining, printOperator } from './parseTreeUtils';
|
||||
import { isRefinementWildcard } from './refinementTypeUtils';
|
||||
import { TypeRefinement } from './refinementTypes';
|
||||
import { getScopeForNode } from './scopeUtils';
|
||||
import { evaluateStaticBoolExpression } from './staticExpressions';
|
||||
import { EvalFlags, MagicMethodDeprecationInfo, TypeEvaluator, TypeResult } from './typeEvaluatorTypes';
|
||||
import {
|
||||
InferenceContext,
|
||||
convertToInstantiable,
|
||||
getIntValueRefinement,
|
||||
getLiteralTypeClassName,
|
||||
getTypeCondition,
|
||||
getUnionSubtypeCount,
|
||||
|
|
@ -335,6 +338,16 @@ export function getTypeOfBinaryOperation(
|
|||
}
|
||||
}
|
||||
|
||||
// Is this a "@" operator used in a context where it is supposed to be
|
||||
// interpreted as a refinement type operator?
|
||||
if (node.d.operator === OperatorType.MatrixMultiply && (flags & EvalFlags.TypeExpression) !== 0) {
|
||||
if (getFileInfo(node).diagnosticRuleSet.enableExperimentalFeatures) {
|
||||
if (!customMetaclassSupportsMethod(leftType, '__matmul__')) {
|
||||
return evaluator.getTypeMetadata(leftExpression, leftTypeResult, rightExpression);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const rightTypeResult = evaluator.getTypeOfExpression(
|
||||
rightExpression,
|
||||
flags,
|
||||
|
|
@ -447,26 +460,24 @@ export function getTypeOfBinaryOperation(
|
|||
node.d.leftExpr
|
||||
);
|
||||
} else {
|
||||
// If neither the LHS or RHS are unions, don't include a diagnostic addendum
|
||||
let diagMessage = diag.getString();
|
||||
|
||||
// If neither the LHS or RHS are unions, don't add more diagnostic information
|
||||
// because it will be redundant with the main diagnostic message. The addenda
|
||||
// are useful only if union expansion was used for one or both operands.
|
||||
let diagString = '';
|
||||
if (
|
||||
isUnion(evaluator.makeTopLevelTypeVarsConcrete(leftType)) ||
|
||||
isUnion(evaluator.makeTopLevelTypeVarsConcrete(rightType))
|
||||
) {
|
||||
diagString = diag.getString();
|
||||
diagMessage =
|
||||
LocMessage.typeNotSupportBinaryOperator().format({
|
||||
operator: printOperator(node.d.operator),
|
||||
leftType: evaluator.printType(leftType),
|
||||
rightType: evaluator.printType(rightType),
|
||||
}) + diagMessage;
|
||||
}
|
||||
|
||||
evaluator.addDiagnostic(
|
||||
DiagnosticRule.reportOperatorIssue,
|
||||
LocMessage.typeNotSupportBinaryOperator().format({
|
||||
operator: printOperator(node.d.operator),
|
||||
leftType: evaluator.printType(leftType),
|
||||
rightType: evaluator.printType(rightType),
|
||||
}) + diagString,
|
||||
node
|
||||
);
|
||||
evaluator.addDiagnostic(DiagnosticRule.reportOperatorIssue, diagMessage, node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1266,12 +1277,14 @@ function validateArithmeticOperation(
|
|||
}
|
||||
|
||||
const magicMethodName = binaryOperatorMap[operator][0];
|
||||
const subDiag = new DiagnosticAddendum();
|
||||
let resultTypeResult = evaluator.getTypeOfMagicMethodCall(
|
||||
convertFunctionToObject(evaluator, leftSubtypeUnexpanded),
|
||||
magicMethodName,
|
||||
[{ type: rightSubtypeUnexpanded, isIncomplete: rightTypeResult.isIncomplete }],
|
||||
errorNode,
|
||||
inferenceContext
|
||||
inferenceContext,
|
||||
subDiag
|
||||
);
|
||||
|
||||
if (!resultTypeResult && leftSubtypeUnexpanded !== leftSubtypeExpanded) {
|
||||
|
|
@ -1336,6 +1349,8 @@ function validateArithmeticOperation(
|
|||
}
|
||||
|
||||
if (!resultTypeResult) {
|
||||
diag.addAddendum(subDiag);
|
||||
|
||||
if (inferenceContext) {
|
||||
diag.addMessage(
|
||||
LocMessage.typeNotSupportBinaryOperatorBidirectional().format({
|
||||
|
|
@ -1360,6 +1375,15 @@ function validateArithmeticOperation(
|
|||
deprecatedInfo = resultTypeResult.magicMethodDeprecationInfo;
|
||||
}
|
||||
|
||||
if (resultTypeResult && options.isLiteralMathAllowed) {
|
||||
resultTypeResult.type = applyRefinementsForBinaryOp(
|
||||
operator,
|
||||
leftSubtypeExpanded,
|
||||
rightSubtypeExpanded,
|
||||
resultTypeResult.type
|
||||
);
|
||||
}
|
||||
|
||||
return resultTypeResult?.type ?? UnknownType.create(isIncomplete);
|
||||
}
|
||||
);
|
||||
|
|
@ -1368,3 +1392,44 @@ function validateArithmeticOperation(
|
|||
|
||||
return { type, magicMethodDeprecationInfo: deprecatedInfo };
|
||||
}
|
||||
|
||||
function applyRefinementsForBinaryOp(operator: OperatorType, leftType: Type, rightType: Type, resultType: Type): Type {
|
||||
if (
|
||||
!isClassInstance(leftType) ||
|
||||
!ClassType.isBuiltIn(leftType, 'int') ||
|
||||
!isClassInstance(rightType) ||
|
||||
!ClassType.isBuiltIn(rightType, 'int') ||
|
||||
!isClassInstance(resultType) ||
|
||||
!ClassType.isBuiltIn(resultType, 'int') ||
|
||||
resultType.priv.literalValue !== undefined
|
||||
) {
|
||||
return resultType;
|
||||
}
|
||||
|
||||
const supportedOps: OperatorType[] = [
|
||||
OperatorType.Add,
|
||||
OperatorType.Subtract,
|
||||
OperatorType.Multiply,
|
||||
OperatorType.FloorDivide,
|
||||
OperatorType.Mod,
|
||||
];
|
||||
|
||||
if (!supportedOps.includes(operator)) {
|
||||
return resultType;
|
||||
}
|
||||
|
||||
const leftIntValue = getIntValueRefinement(leftType);
|
||||
const rightIntValue = getIntValueRefinement(rightType);
|
||||
|
||||
if (!leftIntValue || !rightIntValue) {
|
||||
return resultType;
|
||||
}
|
||||
|
||||
const refinement = TypeRefinement.fromBinaryOp(operator, leftIntValue, rightIntValue);
|
||||
|
||||
if (isRefinementWildcard(refinement.value)) {
|
||||
return resultType;
|
||||
}
|
||||
|
||||
return ClassType.cloneWithRefinements(resultType, [refinement]);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -43,6 +43,7 @@ export enum ParamKind {
|
|||
|
||||
export interface VirtualParamDetails {
|
||||
param: FunctionParam;
|
||||
realParamType: Type;
|
||||
type: Type;
|
||||
declaredType: Type;
|
||||
defaultType?: Type | undefined;
|
||||
|
|
@ -124,6 +125,7 @@ export function getParamListDetails(type: FunctionType): ParamListDetails {
|
|||
|
||||
const addVirtualParam = (
|
||||
param: FunctionParam,
|
||||
realParamType: Type,
|
||||
index: number,
|
||||
typeOverride?: Type,
|
||||
defaultTypeOverride?: Type,
|
||||
|
|
@ -145,6 +147,7 @@ export function getParamListDetails(type: FunctionType): ParamListDetails {
|
|||
|
||||
result.params.push({
|
||||
param,
|
||||
realParamType,
|
||||
index,
|
||||
type: typeOverride ?? FunctionType.getParamType(type, index),
|
||||
declaredType: FunctionType.getDeclaredParamType(type, index),
|
||||
|
|
@ -182,6 +185,7 @@ export function getParamListDetails(type: FunctionType): ParamListDetails {
|
|||
FunctionParamFlags.NameSynthesized | FunctionParamFlags.TypeDeclared,
|
||||
`${param.name}[${tupleIndex.toString()}]`
|
||||
),
|
||||
paramType,
|
||||
index,
|
||||
tupleArg.type,
|
||||
/* defaultArgTypeOverride */ undefined,
|
||||
|
|
@ -226,7 +230,7 @@ export function getParamListDetails(type: FunctionType): ParamListDetails {
|
|||
sawKeywordOnlySeparator = true;
|
||||
}
|
||||
|
||||
addVirtualParam(param, index);
|
||||
addVirtualParam(param, paramType, index);
|
||||
}
|
||||
} else if (param.category === ParamCategory.KwargsDict) {
|
||||
sawKeywordOnlySeparator = true;
|
||||
|
|
@ -256,6 +260,7 @@ export function getParamListDetails(type: FunctionType): ParamListDetails {
|
|||
name,
|
||||
defaultParamType
|
||||
),
|
||||
paramType,
|
||||
index,
|
||||
specializedParamType,
|
||||
defaultParamType
|
||||
|
|
@ -270,6 +275,7 @@ export function getParamListDetails(type: FunctionType): ParamListDetails {
|
|||
FunctionParamFlags.TypeDeclared,
|
||||
'kwargs'
|
||||
),
|
||||
paramType,
|
||||
index,
|
||||
paramType.shared.typedDictEntries.extraItems.valueType
|
||||
);
|
||||
|
|
@ -288,7 +294,7 @@ export function getParamListDetails(type: FunctionType): ParamListDetails {
|
|||
result.firstKeywordOnlyIndex = result.params.length;
|
||||
}
|
||||
|
||||
addVirtualParam(param, index);
|
||||
addVirtualParam(param, paramType, index);
|
||||
}
|
||||
} else if (param.category === ParamCategory.Simple) {
|
||||
if (param.name && !sawKeywordOnlySeparator) {
|
||||
|
|
@ -297,6 +303,7 @@ export function getParamListDetails(type: FunctionType): ParamListDetails {
|
|||
|
||||
addVirtualParam(
|
||||
param,
|
||||
FunctionType.getParamType(type, index),
|
||||
index,
|
||||
/* typeOverride */ undefined,
|
||||
type.priv.specializedTypes?.parameterDefaultTypes
|
||||
|
|
|
|||
|
|
@ -35,6 +35,7 @@ import {
|
|||
ParameterNode,
|
||||
ParseNode,
|
||||
ParseNodeType,
|
||||
RefinementNode,
|
||||
StatementListNode,
|
||||
StatementNode,
|
||||
StringListNode,
|
||||
|
|
@ -1081,6 +1082,39 @@ export function getTypeVarScopeNode(node: ParseNode): TypeParameterScopeNode | u
|
|||
return undefined;
|
||||
}
|
||||
|
||||
// Similar to getTypeVarScopeNode except for refinement variable scopes.
|
||||
export function getRefinementScopeNode(node: ParseNode): ParseNode | undefined {
|
||||
let prevNode: ParseNode | undefined;
|
||||
let curNode: ParseNode | undefined = node;
|
||||
let exprNode: ExpressionNode | RefinementNode | undefined;
|
||||
let sawNonExprNode = false;
|
||||
|
||||
while (curNode) {
|
||||
if (curNode.nodeType === ParseNodeType.Function) {
|
||||
if (prevNode === curNode.d.suite) {
|
||||
return exprNode ?? curNode;
|
||||
}
|
||||
|
||||
if (!curNode.d.decorators.some((decorator) => decorator === prevNode)) {
|
||||
return curNode;
|
||||
}
|
||||
}
|
||||
|
||||
if (isExpressionNode(curNode) && curNode.nodeType !== ParseNodeType.TypeAnnotation) {
|
||||
if (!sawNonExprNode) {
|
||||
exprNode = curNode;
|
||||
}
|
||||
} else if (curNode.nodeType !== ParseNodeType.Argument && curNode.nodeType !== ParseNodeType.Refinement) {
|
||||
sawNonExprNode = true;
|
||||
}
|
||||
|
||||
prevNode = curNode;
|
||||
curNode = curNode.parent;
|
||||
}
|
||||
|
||||
return exprNode;
|
||||
}
|
||||
|
||||
// Returns the parse node corresponding to the scope that is used
|
||||
// for executing the code referenced in the specified node.
|
||||
export function getExecutionScopeNode(node: ParseNode): ExecutionScopeNode {
|
||||
|
|
@ -2217,6 +2251,9 @@ export function printParseNodeType(type: ParseNodeType) {
|
|||
|
||||
case ParseNodeType.TypeAlias:
|
||||
return 'TypeAlias';
|
||||
|
||||
case ParseNodeType.Refinement:
|
||||
return 'Refinement';
|
||||
}
|
||||
|
||||
assertNever(type);
|
||||
|
|
|
|||
|
|
@ -69,6 +69,7 @@ import {
|
|||
PatternSequenceNode,
|
||||
PatternValueNode,
|
||||
RaiseNode,
|
||||
RefinementNode,
|
||||
ReturnNode,
|
||||
SetNode,
|
||||
SliceNode,
|
||||
|
|
@ -271,6 +272,9 @@ export function getChildNodes(node: ParseNode): (ParseNode | undefined)[] {
|
|||
case ParseNodeType.PatternValue:
|
||||
return [node.d.expr];
|
||||
|
||||
case ParseNodeType.Refinement:
|
||||
return [node.d.valueExpr, node.d.conditionExpr];
|
||||
|
||||
case ParseNodeType.Raise:
|
||||
return [node.d.expr, node.d.fromExpr];
|
||||
|
||||
|
|
@ -287,7 +291,7 @@ export function getChildNodes(node: ParseNode): (ParseNode | undefined)[] {
|
|||
return node.d.statements;
|
||||
|
||||
case ParseNodeType.StringList:
|
||||
return [node.d.annotation, ...node.d.strings];
|
||||
return [node.d.annotation, node.d.refinement, ...node.d.strings];
|
||||
|
||||
case ParseNodeType.String:
|
||||
return [];
|
||||
|
|
@ -519,6 +523,9 @@ export class ParseTreeVisitor<T> {
|
|||
case ParseNodeType.PatternValue:
|
||||
return this.visitPatternValue(node);
|
||||
|
||||
case ParseNodeType.Refinement:
|
||||
return this.visitRefinement(node);
|
||||
|
||||
case ParseNodeType.Raise:
|
||||
return this.visitRaise(node);
|
||||
|
||||
|
|
@ -815,6 +822,10 @@ export class ParseTreeVisitor<T> {
|
|||
return this._default;
|
||||
}
|
||||
|
||||
visitRefinement(node: RefinementNode) {
|
||||
return this._default;
|
||||
}
|
||||
|
||||
visitRaise(node: RaiseNode) {
|
||||
return this._default;
|
||||
}
|
||||
|
|
|
|||
173
packages/pyright-internal/src/analyzer/refinementPrinter.ts
Normal file
173
packages/pyright-internal/src/analyzer/refinementPrinter.ts
Normal file
|
|
@ -0,0 +1,173 @@
|
|||
/*
|
||||
* refinementPrinter.ts
|
||||
* Copyright (c) Microsoft Corporation.
|
||||
* Licensed under the MIT license.
|
||||
* Author: Eric Traut
|
||||
*
|
||||
* Logic that converts a refinement type to a user-visible string.
|
||||
*/
|
||||
|
||||
import { assertNever } from '../common/debug';
|
||||
import { OperatorType } from '../parser/tokenizerTypes';
|
||||
import { RefinementExpr, RefinementNodeType, TypeRefinement } from './refinementTypes';
|
||||
import { printBytesLiteral, printStringLiteral } from './typePrinterUtils';
|
||||
|
||||
export interface PrintRefinementExprOptions {
|
||||
// Include scopes for refinement variables?
|
||||
printVarScopes?: boolean;
|
||||
|
||||
// Surround tuples with parens?
|
||||
encloseTupleInParens?: boolean;
|
||||
}
|
||||
|
||||
// Converts a refinement definition to its text form.
|
||||
export function printRefinement(refinement: TypeRefinement, options?: PrintRefinementExprOptions): string {
|
||||
const value = printRefinementExpr(refinement.value, options);
|
||||
const condition = refinement.condition ? ` if ${printRefinementExpr(refinement.condition, options)}` : '';
|
||||
|
||||
if (refinement.classDetails.baseSupportsLiteral && !condition) {
|
||||
if (
|
||||
refinement.value.nodeType === RefinementNodeType.Number ||
|
||||
refinement.value.nodeType === RefinementNodeType.String ||
|
||||
refinement.value.nodeType === RefinementNodeType.Bytes ||
|
||||
refinement.value.nodeType === RefinementNodeType.Boolean
|
||||
) {
|
||||
return value;
|
||||
}
|
||||
}
|
||||
|
||||
if (refinement.classDetails.baseSupportsStringShortcut) {
|
||||
return `"${value}${condition}"`;
|
||||
}
|
||||
|
||||
return `${refinement.classDetails.className}("${value}${condition}")`;
|
||||
}
|
||||
|
||||
export function printRefinementExpr(expr: RefinementExpr, options: PrintRefinementExprOptions = {}): string {
|
||||
switch (expr.nodeType) {
|
||||
case RefinementNodeType.Number: {
|
||||
return expr.value.toString();
|
||||
}
|
||||
|
||||
case RefinementNodeType.String: {
|
||||
return printStringLiteral(expr.value, "'");
|
||||
}
|
||||
|
||||
case RefinementNodeType.Bytes: {
|
||||
return printBytesLiteral(expr.value);
|
||||
}
|
||||
|
||||
case RefinementNodeType.Boolean: {
|
||||
return expr.value ? 'True' : 'False';
|
||||
}
|
||||
|
||||
case RefinementNodeType.Wildcard: {
|
||||
return '_';
|
||||
}
|
||||
|
||||
case RefinementNodeType.Var: {
|
||||
if (options?.printVarScopes) {
|
||||
return `${expr.var.shared.name}@${expr.var.shared.scopeName}`;
|
||||
}
|
||||
|
||||
return expr.var.shared.name;
|
||||
}
|
||||
|
||||
case RefinementNodeType.BinaryOp: {
|
||||
// Map the operator to a string and numerical evaluation precedence.
|
||||
const operatorMap: { [key: number]: [string, number, boolean] } = {
|
||||
[OperatorType.Multiply]: ['*', 1, true],
|
||||
[OperatorType.FloorDivide]: ['//', 1, false],
|
||||
[OperatorType.Mod]: ['%', 1, false],
|
||||
[OperatorType.Add]: ['+', 2, true],
|
||||
[OperatorType.Subtract]: ['-', 2, false],
|
||||
[OperatorType.Equals]: ['==', 3, false],
|
||||
[OperatorType.NotEquals]: ['!=', 3, false],
|
||||
[OperatorType.LessThan]: ['<', 3, false],
|
||||
[OperatorType.LessThanOrEqual]: ['<=', 3, false],
|
||||
[OperatorType.GreaterThan]: ['>', 3, false],
|
||||
[OperatorType.GreaterThanOrEqual]: ['>=', 3, false],
|
||||
[OperatorType.And]: ['and', 4, true],
|
||||
[OperatorType.Or]: ['or', 5, true],
|
||||
};
|
||||
|
||||
const operatorStr = operatorMap[expr.operator][0] ?? '<unknown>';
|
||||
const isCommutative = operatorMap[expr.operator][2] ?? false;
|
||||
|
||||
let leftStr = printRefinementExpr(expr.leftExpr, options);
|
||||
let rightStr = printRefinementExpr(expr.rightExpr, options);
|
||||
|
||||
const operatorPrecedence = operatorMap[expr.operator][1] ?? 0;
|
||||
|
||||
if (expr.leftExpr.nodeType === RefinementNodeType.BinaryOp) {
|
||||
const leftPrecedence = operatorMap[expr.leftExpr.operator][1] ?? 0;
|
||||
if (leftPrecedence > operatorPrecedence) {
|
||||
leftStr = `(${leftStr})`;
|
||||
}
|
||||
}
|
||||
|
||||
if (expr.rightExpr.nodeType === RefinementNodeType.BinaryOp) {
|
||||
const rightPrecedence = operatorMap[expr.rightExpr.operator][1] ?? 0;
|
||||
const isRightCommutative = operatorMap[expr.rightExpr.operator][1] ?? 0;
|
||||
|
||||
let includeParens = rightPrecedence >= operatorPrecedence;
|
||||
if (rightPrecedence === operatorPrecedence && isCommutative && isRightCommutative) {
|
||||
includeParens = false;
|
||||
}
|
||||
|
||||
if (includeParens) {
|
||||
rightStr = `(${rightStr})`;
|
||||
}
|
||||
}
|
||||
|
||||
return `${leftStr} ${operatorStr} ${rightStr}`;
|
||||
}
|
||||
|
||||
case RefinementNodeType.UnaryOp: {
|
||||
const operatorMap: { [key: number]: string } = {
|
||||
[OperatorType.Add]: '+',
|
||||
[OperatorType.Subtract]: '-',
|
||||
[OperatorType.Not]: 'not ',
|
||||
};
|
||||
|
||||
const operatorStr = operatorMap[expr.operator] ?? '<unknown>';
|
||||
return `${operatorStr}${printRefinementExpr(expr.expr, options)}`;
|
||||
}
|
||||
|
||||
case RefinementNodeType.Tuple: {
|
||||
const entries = expr.entries.map((elem) => {
|
||||
let baseElemStr = printRefinementExpr(elem.value, options);
|
||||
if (elem.value.nodeType === RefinementNodeType.Tuple) {
|
||||
baseElemStr = `(${baseElemStr})`;
|
||||
}
|
||||
return `${elem.isUnpacked ? '*' : ''}${baseElemStr}`;
|
||||
});
|
||||
|
||||
if (expr.entries.length === 0) {
|
||||
return '()';
|
||||
}
|
||||
|
||||
let tupleStr: string;
|
||||
if (expr.entries.length === 1) {
|
||||
tupleStr = `${entries[0]},`;
|
||||
} else {
|
||||
tupleStr = entries.join(', ');
|
||||
}
|
||||
|
||||
if (options.encloseTupleInParens) {
|
||||
return `(${tupleStr})`;
|
||||
}
|
||||
|
||||
return tupleStr;
|
||||
}
|
||||
|
||||
case RefinementNodeType.Call: {
|
||||
const args = expr.args.map((arg) => printRefinementExpr(arg, { ...options, encloseTupleInParens: true }));
|
||||
return `${expr.name}(${args.join(', ')})`;
|
||||
}
|
||||
|
||||
default: {
|
||||
assertNever(expr);
|
||||
}
|
||||
}
|
||||
}
|
||||
566
packages/pyright-internal/src/analyzer/refinementSolver.ts
Normal file
566
packages/pyright-internal/src/analyzer/refinementSolver.ts
Normal file
|
|
@ -0,0 +1,566 @@
|
|||
/*
|
||||
* refinementSolver.ts
|
||||
* Copyright (c) Microsoft Corporation.
|
||||
* Licensed under the MIT license.
|
||||
* Author: Eric Traut
|
||||
*
|
||||
* Constraint solver for refinement types.
|
||||
*/
|
||||
|
||||
import { assert } from '../common/debug';
|
||||
import { DiagnosticAddendum } from '../common/diagnostic';
|
||||
import { LocAddendum } from '../localization/localize';
|
||||
import { ConstraintSolutionSet } from './constraintSolution';
|
||||
import { ConstraintSet, ConstraintTracker } from './constraintTracker';
|
||||
import { printRefinementExpr } from './refinementPrinter';
|
||||
import {
|
||||
RefinementExpr,
|
||||
RefinementNodeType,
|
||||
RefinementNumberNode,
|
||||
RefinementTupleEntry,
|
||||
RefinementVarId,
|
||||
RefinementVarNode,
|
||||
TypeRefinement,
|
||||
} from './refinementTypes';
|
||||
import {
|
||||
applySolvedRefinementVars,
|
||||
createWildcardRefinementValue,
|
||||
evaluateRefinementCondition,
|
||||
evaluateRefinementExpression,
|
||||
getFreeRefinementVars,
|
||||
isRefinementExprEquivalent,
|
||||
isRefinementLiteral,
|
||||
isRefinementTuple,
|
||||
isRefinementVar,
|
||||
isRefinementWildcard,
|
||||
RefinementTypeDiag,
|
||||
} from './refinementTypeUtils';
|
||||
import { ClassType, isAnyOrUnknown, isClass, isClassInstance } from './types';
|
||||
import { getBuiltInRefinementClassId } from './typeUtils';
|
||||
|
||||
export interface AssignRefinementsOptions {
|
||||
checkOverloadOverlap?: boolean;
|
||||
}
|
||||
|
||||
// Attempts to assign a srcType with a refinement type to a destType
|
||||
// with a refinement type.
|
||||
export function assignRefinements(
|
||||
destType: ClassType,
|
||||
srcType: ClassType,
|
||||
diag: DiagnosticAddendum | undefined,
|
||||
constraints: ConstraintTracker | undefined,
|
||||
options?: AssignRefinementsOptions
|
||||
): boolean {
|
||||
let assignmentOk = true;
|
||||
|
||||
// Apply refinements by class.
|
||||
const destRefs = destType.priv.refinements ?? [];
|
||||
const srcRefs = srcType.priv.refinements ?? [];
|
||||
|
||||
const synthesizedRefinement = synthesizeRefinementTypeFromLiteral(srcType);
|
||||
if (synthesizedRefinement) {
|
||||
srcRefs.push(synthesizedRefinement);
|
||||
}
|
||||
|
||||
for (const destRef of destRefs) {
|
||||
let srcClassMatch = false;
|
||||
|
||||
for (const srcRef of srcRefs) {
|
||||
if (destRef.classDetails.classId !== srcRef.classDetails.classId) {
|
||||
continue;
|
||||
}
|
||||
|
||||
srcClassMatch = true;
|
||||
|
||||
if (!assignRefinement(destRef, srcRef, diag, constraints, options)) {
|
||||
assignmentOk = false;
|
||||
}
|
||||
}
|
||||
|
||||
if (srcRefs.length === 0 && options?.checkOverloadOverlap) {
|
||||
assignmentOk = false;
|
||||
}
|
||||
|
||||
// If no source refinements matched the dest refinement class,
|
||||
// the assignment validity is based on whether it's enforced.
|
||||
if (!srcClassMatch && destRef.isEnforced) {
|
||||
assignmentOk = false;
|
||||
}
|
||||
}
|
||||
|
||||
return assignmentOk;
|
||||
}
|
||||
|
||||
export function solveRefinementVarRecursive(
|
||||
constraintSet: ConstraintSet,
|
||||
solutionSet: ConstraintSolutionSet,
|
||||
varId: RefinementVarId
|
||||
): RefinementExpr | undefined {
|
||||
// If this refinement variable already has a solution, don't attempt to re-solve it.
|
||||
if (solutionSet.hasRefinementVarType(varId)) {
|
||||
return solutionSet.getRefinementVarType(varId);
|
||||
}
|
||||
|
||||
const value = constraintSet.getRefinementVarType(varId);
|
||||
if (!value) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
// Protect against infinite recursion by setting the initial value to
|
||||
// undefined. We'll replace this later with a real value.
|
||||
solutionSet.setRefinementVarType(varId, /* value */ undefined);
|
||||
|
||||
// Determine which free variables are referenced by this expression. We need
|
||||
// to ensure that they are solved first.
|
||||
const freeVars = getFreeRefinementVars(value);
|
||||
for (const freeVar of freeVars) {
|
||||
solveRefinementVarRecursive(constraintSet, solutionSet, freeVar.id);
|
||||
}
|
||||
|
||||
// Now evaluate the expression.
|
||||
const solvedValue = applySolvedRefinementVars(value, solutionSet.getRefinementVarMap());
|
||||
const simplifiedValue = evaluateRefinementExpression(solvedValue);
|
||||
|
||||
solutionSet.setRefinementVarType(varId, simplifiedValue);
|
||||
|
||||
return simplifiedValue;
|
||||
}
|
||||
|
||||
export function assignRefinement(
|
||||
destRefinement: TypeRefinement,
|
||||
srcRefinement: TypeRefinement,
|
||||
diag: DiagnosticAddendum | undefined,
|
||||
constraints: ConstraintTracker | undefined,
|
||||
options?: AssignRefinementsOptions
|
||||
): boolean {
|
||||
assert(destRefinement.classDetails.classId === srcRefinement.classDetails.classId);
|
||||
|
||||
const destValue = destRefinement.value;
|
||||
const srcValue = srcRefinement.value;
|
||||
|
||||
// Determine if there are any conditions provided by the caller (for
|
||||
// function call evaluation) or local conditions (for assignments).
|
||||
let conditions: RefinementExpr[] | undefined;
|
||||
if (!conditions && destRefinement.condition) {
|
||||
conditions = [destRefinement.condition];
|
||||
}
|
||||
|
||||
// If we have conditions to verify but have not been provided a
|
||||
// constraint tracker, create a temporary one.
|
||||
if (conditions && !constraints) {
|
||||
constraints = new ConstraintTracker();
|
||||
}
|
||||
|
||||
// Handle tuples specially.
|
||||
if (
|
||||
destRefinement.classDetails.domain === 'IntTupleRefinement' &&
|
||||
destValue.nodeType === RefinementNodeType.Tuple &&
|
||||
srcValue.nodeType === RefinementNodeType.Tuple
|
||||
) {
|
||||
const srcEntries = [...srcValue.entries];
|
||||
|
||||
// If the dest and source tuple shapes match, we can skip any reshaping efforts.
|
||||
if (
|
||||
destValue.entries.length !== srcValue.entries.length ||
|
||||
destValue.entries.some((entry, i) => entry.isUnpacked !== srcValue.entries[i].isUnpacked)
|
||||
) {
|
||||
if (!adjustSourceTupleShape(destValue.entries, srcEntries)) {
|
||||
const msg =
|
||||
destRefinement.classDetails.classId === getBuiltInRefinementClassId('Shape')
|
||||
? LocAddendum.refinementShapeMismatch()
|
||||
: LocAddendum.refinementTupleMismatch();
|
||||
|
||||
diag?.addMessage(
|
||||
msg.format({
|
||||
expected: printRefinementExpr(destValue),
|
||||
received: printRefinementExpr(srcValue),
|
||||
})
|
||||
);
|
||||
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// At this point, the dest and src tuples should have the same shape
|
||||
// (i.e. same length and with the same entries unpacked or not).
|
||||
assert(destValue.entries.length === srcEntries.length);
|
||||
|
||||
for (let i = 0; i < destValue.entries.length; i++) {
|
||||
const destEntry = destValue.entries[i];
|
||||
const srcEntry = srcEntries[i];
|
||||
|
||||
assert(destEntry.isUnpacked === srcEntry.isUnpacked);
|
||||
|
||||
if (!assignRefinementValue(destEntry.value, srcEntry.value, diag, constraints, options)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (!assignRefinementValue(destValue, srcValue, diag, constraints, options)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (conditions && !validateRefinementConditions(conditions, diag, constraints)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
export function validateRefinementConditions(
|
||||
conditions: RefinementExpr[],
|
||||
diag: DiagnosticAddendum | undefined,
|
||||
constraints: ConstraintTracker | undefined
|
||||
): boolean {
|
||||
let solvedConditions = [...conditions];
|
||||
|
||||
if (constraints) {
|
||||
const solutionSet = new ConstraintSolutionSet();
|
||||
const constraintSet = constraints.getMainConstraintSet();
|
||||
|
||||
// Solve the refinement variables.
|
||||
constraintSet.doForEachRefinementVar((name) => {
|
||||
solveRefinementVarRecursive(constraintSet, solutionSet, name);
|
||||
});
|
||||
|
||||
solvedConditions = conditions.map((condition) =>
|
||||
applySolvedRefinementVars(condition, solutionSet.getRefinementVarMap())
|
||||
);
|
||||
}
|
||||
|
||||
for (let i = 0; i < solvedConditions.length; i++) {
|
||||
const errors: RefinementTypeDiag[] = [];
|
||||
|
||||
if (!evaluateRefinementCondition(solvedConditions[i], { refinements: { errors } })) {
|
||||
diag?.addMessage(
|
||||
LocAddendum.refinementConditionNotSatisfied().format({
|
||||
condition: printRefinementExpr(solvedConditions[i]),
|
||||
})
|
||||
);
|
||||
|
||||
errors.forEach((error) => {
|
||||
diag?.addAddendum(error.diag);
|
||||
});
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
export function synthesizeRefinementTypeFromLiteral(classType: ClassType): TypeRefinement | undefined {
|
||||
if (!isClass(classType)) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
if (ClassType.isBuiltIn(classType, 'tuple')) {
|
||||
const typeArgs = classType.priv.tupleTypeArgs;
|
||||
if (!typeArgs) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
let foundInt = false;
|
||||
if (
|
||||
!typeArgs.every((typeArg) => {
|
||||
if (isClassInstance(typeArg.type) && ClassType.isBuiltIn(typeArg.type, 'int')) {
|
||||
foundInt = true;
|
||||
return true;
|
||||
}
|
||||
return isAnyOrUnknown(typeArg.type);
|
||||
}) ||
|
||||
!foundInt
|
||||
) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
const entries: RefinementTupleEntry[] = typeArgs.map((typeArg) => {
|
||||
if (isClassInstance(typeArg.type) && typeArg.type.priv.literalValue !== undefined && !typeArg.isUnbounded) {
|
||||
const value = typeArg.type.priv.literalValue;
|
||||
assert(typeof value === 'number' || typeof value === 'bigint');
|
||||
const entry: RefinementNumberNode = { nodeType: RefinementNodeType.Number, value: value };
|
||||
return { value: entry, isUnpacked: false };
|
||||
}
|
||||
|
||||
return { value: createWildcardRefinementValue(), isUnpacked: typeArg.isUnbounded };
|
||||
});
|
||||
|
||||
return {
|
||||
classDetails: {
|
||||
domain: 'IntTupleRefinement',
|
||||
className: 'IntTupleValue',
|
||||
classId: getBuiltInRefinementClassId('IntTupleValue'),
|
||||
},
|
||||
value: {
|
||||
nodeType: RefinementNodeType.Tuple,
|
||||
entries,
|
||||
},
|
||||
isEnforced: false,
|
||||
};
|
||||
}
|
||||
|
||||
if (classType.priv.literalValue === undefined) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
if (ClassType.isBuiltIn(classType, 'int')) {
|
||||
assert(typeof classType.priv.literalValue === 'number' || typeof classType.priv.literalValue === 'bigint');
|
||||
|
||||
return {
|
||||
classDetails: {
|
||||
domain: 'IntRefinement',
|
||||
className: 'IntValue',
|
||||
classId: getBuiltInRefinementClassId('IntValue'),
|
||||
},
|
||||
value: {
|
||||
nodeType: RefinementNodeType.Number,
|
||||
value: classType.priv.literalValue,
|
||||
},
|
||||
isEnforced: true,
|
||||
};
|
||||
}
|
||||
|
||||
if (ClassType.isBuiltIn(classType, 'str')) {
|
||||
assert(typeof classType.priv.literalValue === 'string');
|
||||
|
||||
return {
|
||||
classDetails: {
|
||||
domain: 'StrRefinement',
|
||||
className: 'StrValue',
|
||||
classId: getBuiltInRefinementClassId('StrValue'),
|
||||
},
|
||||
value: {
|
||||
nodeType: RefinementNodeType.String,
|
||||
value: classType.priv.literalValue,
|
||||
},
|
||||
isEnforced: true,
|
||||
};
|
||||
}
|
||||
|
||||
if (ClassType.isBuiltIn(classType, 'bytes')) {
|
||||
assert(typeof classType.priv.literalValue === 'string');
|
||||
|
||||
return {
|
||||
classDetails: {
|
||||
domain: 'BytesRefinement',
|
||||
className: 'BytesValue',
|
||||
classId: getBuiltInRefinementClassId('BytesValue'),
|
||||
},
|
||||
value: {
|
||||
nodeType: RefinementNodeType.Bytes,
|
||||
value: classType.priv.literalValue,
|
||||
},
|
||||
isEnforced: true,
|
||||
};
|
||||
}
|
||||
|
||||
if (ClassType.isBuiltIn(classType, 'bool')) {
|
||||
assert(typeof classType.priv.literalValue === 'boolean');
|
||||
|
||||
return {
|
||||
classDetails: {
|
||||
domain: 'BoolRefinement',
|
||||
className: 'BoolValue',
|
||||
classId: getBuiltInRefinementClassId('BoolValue'),
|
||||
},
|
||||
value: {
|
||||
nodeType: RefinementNodeType.Boolean,
|
||||
value: classType.priv.literalValue,
|
||||
},
|
||||
isEnforced: true,
|
||||
};
|
||||
}
|
||||
|
||||
return undefined;
|
||||
}
|
||||
|
||||
// Adjusts the srcEntries to match the shape of the destEntries if possible.
|
||||
// It assumes the caller has already confirmed that the dest and src shapes
|
||||
// don't already match. The shape includes both the number of entries and
|
||||
// whether each of those entries is unpacked.
|
||||
function adjustSourceTupleShape(destEntries: RefinementTupleEntry[], srcEntries: RefinementTupleEntry[]): boolean {
|
||||
const destUnpackCount = destEntries.filter((entry) => entry.isUnpacked).length;
|
||||
const srcUnpackCount = srcEntries.filter((entry) => entry.isUnpacked).length;
|
||||
const srcUnpackIndex = srcEntries.findIndex((entry) => entry.isUnpacked);
|
||||
|
||||
if (destUnpackCount > 1) {
|
||||
// If there's more than one unpacked entry in the dest, there
|
||||
// is no unambiguous way to adjust the source, so don't attempt.
|
||||
return false;
|
||||
}
|
||||
|
||||
if (destUnpackCount === 1) {
|
||||
// If the dest has a single unpacked entry, we may be able to adjust
|
||||
// the source shape to match it.
|
||||
|
||||
const srcEntriesToPack = srcEntries.length - destEntries.length + 1;
|
||||
if (srcEntriesToPack < 0) {
|
||||
return false;
|
||||
}
|
||||
const destUnpackIndex = destEntries.findIndex((entry) => entry.isUnpacked);
|
||||
const removedEntries = srcEntries.splice(destUnpackIndex, srcEntriesToPack);
|
||||
|
||||
// If any of the remaining source entries are unpacked, we can't
|
||||
// make the shapes match.
|
||||
if (srcEntries.some((entry) => entry.isUnpacked)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Add a new unpacked tuple entry.
|
||||
srcEntries.splice(destUnpackIndex, 0, {
|
||||
value: { nodeType: RefinementNodeType.Tuple, entries: removedEntries },
|
||||
isUnpacked: true,
|
||||
});
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// If the dest has no unpacked entries, the source cannot have any
|
||||
// unpacked entries unless it has an unpacked wildcard.
|
||||
if (srcUnpackCount > 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (srcUnpackIndex < 0 || srcEntries[srcUnpackIndex].value.nodeType !== RefinementNodeType.Wildcard) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Remove the unpacked wildcard entry to make the shapes match.
|
||||
srcEntries.splice(srcUnpackIndex, 1);
|
||||
|
||||
if (srcEntries.length > destEntries.length) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Insert wildcard entries to match the dest shape.
|
||||
while (srcEntries.length < destEntries.length) {
|
||||
srcEntries.splice(srcUnpackIndex, 0, {
|
||||
value: createWildcardRefinementValue(),
|
||||
isUnpacked: false,
|
||||
});
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
function assignRefinementValue(
|
||||
destExpr: RefinementExpr,
|
||||
srcExpr: RefinementExpr,
|
||||
diag: DiagnosticAddendum | undefined,
|
||||
constraints: ConstraintTracker | undefined,
|
||||
options?: AssignRefinementsOptions
|
||||
): boolean {
|
||||
// Handle assignment to or from wildcard.
|
||||
if (isRefinementWildcard(destExpr) || isRefinementWildcard(srcExpr)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (isRefinementExprEquivalent(srcExpr, destExpr)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Handle assignments to literals.
|
||||
if (isRefinementLiteral(destExpr) && isRefinementLiteral(srcExpr)) {
|
||||
if (destExpr.value !== srcExpr.value) {
|
||||
diag?.addMessage(
|
||||
LocAddendum.refinementLiteralAssignment().format({
|
||||
expected: printRefinementExpr(destExpr),
|
||||
received: printRefinementExpr(srcExpr),
|
||||
})
|
||||
);
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
if (isRefinementVar(destExpr)) {
|
||||
if (assignToRefinementVar(destExpr, srcExpr, diag, constraints)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
if (isRefinementExprEquivalent(destExpr, srcExpr)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (isRefinementTuple(destExpr) && isRefinementTuple(srcExpr)) {
|
||||
if (destExpr.entries.length === srcExpr.entries.length) {
|
||||
if (
|
||||
destExpr.entries.every((destEntry, i) => {
|
||||
const srcEntry = srcExpr.entries[i];
|
||||
return (
|
||||
destEntry.isUnpacked === srcEntry.isUnpacked &&
|
||||
assignRefinementValue(destEntry.value, srcEntry.value, diag, constraints)
|
||||
);
|
||||
})
|
||||
) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// See if we can simplify the source or the dest expression and try again.
|
||||
const simplifiedDest = evaluateRefinementExpression(destExpr);
|
||||
const simplifiedSrc = evaluateRefinementExpression(srcExpr);
|
||||
if (
|
||||
!TypeRefinement.isRefinementExprSame(simplifiedDest, destExpr) ||
|
||||
!TypeRefinement.isRefinementExprSame(simplifiedSrc, srcExpr)
|
||||
) {
|
||||
return assignRefinementValue(simplifiedDest, simplifiedSrc, diag, constraints);
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
function assignToRefinementVar(
|
||||
destExpr: RefinementVarNode,
|
||||
srcExpr: RefinementExpr,
|
||||
diag: DiagnosticAddendum | undefined,
|
||||
constraints: ConstraintTracker | undefined
|
||||
): boolean {
|
||||
// If the dest is a bound variable, it cannot receive any value other
|
||||
// than a wildcard and itself.
|
||||
if (destExpr.var.isBound) {
|
||||
if (isRefinementWildcard(srcExpr)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (isRefinementExprEquivalent(srcExpr, destExpr)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
diag?.addMessage(
|
||||
LocAddendum.refinementValMismatch().format({
|
||||
expected: printRefinementExpr({ nodeType: RefinementNodeType.Var, var: destExpr.var }),
|
||||
received: printRefinementExpr(srcExpr),
|
||||
})
|
||||
);
|
||||
return false;
|
||||
}
|
||||
|
||||
// If there is no constraint tracker, we have nothing more to do.
|
||||
if (!constraints) {
|
||||
return true;
|
||||
}
|
||||
|
||||
const constraintSet = constraints.getMainConstraintSet();
|
||||
const curValue = constraintSet.getRefinementVarType(destExpr.var.id);
|
||||
|
||||
// If there is a current value, the new value must be the same or more specific.
|
||||
if (curValue) {
|
||||
if (!assignRefinementValue(curValue, srcExpr, /* diag */ undefined, /* constraints */ undefined)) {
|
||||
diag?.addMessage(
|
||||
LocAddendum.refinementValMismatch().format({
|
||||
expected: printRefinementExpr(curValue),
|
||||
received: printRefinementExpr(srcExpr),
|
||||
})
|
||||
);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Assign the new value.
|
||||
constraintSet.setRefinementVarType(destExpr.var.id, srcExpr);
|
||||
return true;
|
||||
}
|
||||
1715
packages/pyright-internal/src/analyzer/refinementTypeUtils.ts
Normal file
1715
packages/pyright-internal/src/analyzer/refinementTypeUtils.ts
Normal file
File diff suppressed because it is too large
Load diff
1676
packages/pyright-internal/src/analyzer/refinementTypes.ts
Normal file
1676
packages/pyright-internal/src/analyzer/refinementTypes.ts
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -14,6 +14,7 @@ import { ExpressionNode, ParseNodeType, SliceNode, TupleNode } from '../parser/p
|
|||
import { addConstraintsForExpectedType } from './constraintSolver';
|
||||
import { ConstraintTracker } from './constraintTracker';
|
||||
import { getTypeVarScopesForNode } from './parseTreeUtils';
|
||||
import { RefinementNodeType } from './refinementTypes';
|
||||
import { AssignTypeFlags, EvalFlags, maxInferredContainerDepth, TypeEvaluator, TypeResult } from './typeEvaluatorTypes';
|
||||
import {
|
||||
AnyType,
|
||||
|
|
@ -35,6 +36,7 @@ import {
|
|||
import {
|
||||
convertToInstance,
|
||||
doForEachSubtype,
|
||||
getBuiltInRefinementClassId,
|
||||
getContainerDepth,
|
||||
InferenceContext,
|
||||
isLiteralType,
|
||||
|
|
@ -582,3 +584,60 @@ function getTupleSliceParam(
|
|||
|
||||
return value;
|
||||
}
|
||||
|
||||
// If this is a tuple[int, ...] with one or more refinements in the
|
||||
// IntTupleRefinement refinement domain, returns a refined version
|
||||
// of the TupleTypeArg entries.
|
||||
export function getRefinementShapeForTuple(type: Type, entries: TupleTypeArg[]): TupleTypeArg[] {
|
||||
if (!isClassInstance(type)) {
|
||||
return entries;
|
||||
}
|
||||
|
||||
if (entries.length !== 1 || !entries[0].isUnbounded) {
|
||||
return entries;
|
||||
}
|
||||
|
||||
// Shapes apply only to types of type[int, ...].
|
||||
const entryType = entries[0].type;
|
||||
if (
|
||||
!isClassInstance(entryType) ||
|
||||
!ClassType.isBuiltIn(entryType, 'int') ||
|
||||
entryType.priv.literalValue !== undefined
|
||||
) {
|
||||
return entries;
|
||||
}
|
||||
|
||||
const tupleRefinement = type.priv.refinements?.find((r) => r.classDetails.domain === 'IntTupleRefinement');
|
||||
|
||||
if (!tupleRefinement) {
|
||||
return entries;
|
||||
}
|
||||
|
||||
const refinementShape = tupleRefinement.value;
|
||||
if (refinementShape.nodeType !== RefinementNodeType.Tuple) {
|
||||
return entries;
|
||||
}
|
||||
|
||||
return refinementShape.entries.map((entry) => {
|
||||
const refinement = {
|
||||
...tupleRefinement,
|
||||
value: entry.value,
|
||||
condition: undefined,
|
||||
};
|
||||
|
||||
if (!entry.isUnpacked) {
|
||||
refinement.classDetails = {
|
||||
domain: 'IntRefinement',
|
||||
className: 'IntValue',
|
||||
classId: getBuiltInRefinementClassId('IntValue'),
|
||||
baseSupportsLiteral: true,
|
||||
baseSupportsStringShortcut: true,
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
type: ClassType.cloneAddRefinement(entryType, refinement),
|
||||
isUnbounded: entry.isUnpacked,
|
||||
};
|
||||
});
|
||||
}
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -34,6 +34,7 @@ import { CodeFlowReferenceExpressionNode, FlowNode } from './codeFlowTypes';
|
|||
import { ConstraintTracker } from './constraintTracker';
|
||||
import { Declaration } from './declaration';
|
||||
import * as DeclarationUtils from './declarationUtils';
|
||||
import { RefinementExpr, RefinementVarType, TypeRefinement } from './refinementTypes';
|
||||
import { SymbolWithScope } from './scope';
|
||||
import { Symbol } from './symbol';
|
||||
import { PrintTypeFlags } from './typePrinter';
|
||||
|
|
@ -178,6 +179,9 @@ export const enum EvalFlags {
|
|||
// with the enclosing class or an outer scope.
|
||||
EnforceClassTypeVarScope = 1 << 31,
|
||||
|
||||
// Expecting a possible refinement type expression.
|
||||
Refinement = 1 << 32,
|
||||
|
||||
// Defaults used for evaluating the LHS of a call expression.
|
||||
CallBaseDefaults = NoSpecialize,
|
||||
|
||||
|
|
@ -254,6 +258,9 @@ export interface TypeResult<T extends Type = Type> {
|
|||
|
||||
// Deprecation messages related to magic methods.
|
||||
magicMethodDeprecationInfo?: MagicMethodDeprecationInfo;
|
||||
|
||||
// Refinement definition.
|
||||
refinement?: TypeRefinement;
|
||||
}
|
||||
|
||||
export interface TypeResultWithNode extends TypeResult {
|
||||
|
|
@ -345,6 +352,7 @@ export interface EffectiveTypeResult {
|
|||
export interface ValidateArgTypeParams {
|
||||
paramCategory: ParamCategory;
|
||||
paramType: Type;
|
||||
refinementConditions: RefinementExpr[] | undefined;
|
||||
requiresTypeVarMatching: boolean;
|
||||
argument: Arg;
|
||||
isDefaultArg?: boolean;
|
||||
|
|
@ -352,7 +360,7 @@ export interface ValidateArgTypeParams {
|
|||
errorNode: ExpressionNode;
|
||||
paramName?: string | undefined;
|
||||
isParamNameSynthesized?: boolean;
|
||||
mapsToVarArgList?: boolean | undefined;
|
||||
mapsToVarArgListType?: Type | undefined;
|
||||
isinstanceParam?: boolean;
|
||||
}
|
||||
|
||||
|
|
@ -406,6 +414,9 @@ export interface CallResult {
|
|||
// Were any errors discovered when evaluating argument types?
|
||||
argumentErrors?: boolean;
|
||||
|
||||
// Were any refinement errors discovered when evaluating arg types?
|
||||
refinementErrors?: boolean;
|
||||
|
||||
// Did one or more arguments evaluated to Any or Unknown?
|
||||
anyOrUnknownArg?: UnknownType | AnyType;
|
||||
|
||||
|
|
@ -517,6 +528,19 @@ export interface SynthesizedTypeInfo {
|
|||
export interface SymbolDeclInfo {
|
||||
decls: Declaration[];
|
||||
synthesizedTypes: SynthesizedTypeInfo[];
|
||||
refinementInfo?: {
|
||||
// Type of refinement variable, if applicable.
|
||||
varType?: RefinementVarType;
|
||||
|
||||
// Indicates that the symbol is a refinement wildcard symbol ("_").
|
||||
isWildcard?: boolean;
|
||||
|
||||
// Signature of refinement function call.
|
||||
callSignature?: string;
|
||||
|
||||
// Docstring for refinement call.
|
||||
callDocstring?: string;
|
||||
};
|
||||
}
|
||||
|
||||
export const enum AssignTypeFlags {
|
||||
|
|
@ -608,6 +632,7 @@ export interface TypeEvaluator {
|
|||
createSubclass: (errorNode: ExpressionNode, type1: ClassType, type2: ClassType) => ClassType;
|
||||
getTypeOfFunction: (node: FunctionNode) => FunctionTypeResult | undefined;
|
||||
getTypeOfExpressionExpectingType: (node: ExpressionNode, options?: ExpectedTypeOptions) => TypeResult;
|
||||
getTypeMetadata: (errorNode: ExpressionNode, typeResult: TypeResult, node: ExpressionNode) => TypeResult;
|
||||
evaluateTypeForSubnode: (subnode: ParseNode, callback: () => void) => TypeResult | undefined;
|
||||
evaluateTypesForStatement: (node: ParseNode) => void;
|
||||
evaluateTypesForMatchStatement: (node: MatchNode) => void;
|
||||
|
|
@ -652,6 +677,7 @@ export interface TypeEvaluator {
|
|||
|
||||
getDeclInfoForStringNode: (node: StringNode) => SymbolDeclInfo | undefined;
|
||||
getDeclInfoForNameNode: (node: NameNode, skipUnreachableCode?: boolean) => SymbolDeclInfo | undefined;
|
||||
getRefinementVarType: (node: NameNode) => RefinementVarType | undefined;
|
||||
getTypeForDeclaration: (declaration: Declaration) => DeclaredSymbolTypeInfo;
|
||||
resolveAliasDeclaration: (
|
||||
declaration: Declaration,
|
||||
|
|
@ -728,7 +754,8 @@ export interface TypeEvaluator {
|
|||
methodName: string,
|
||||
argList: TypeResult[],
|
||||
errorNode: ExpressionNode,
|
||||
inferenceContext: InferenceContext | undefined
|
||||
inferenceContext: InferenceContext | undefined,
|
||||
diag?: DiagnosticAddendum
|
||||
) => TypeResult | undefined;
|
||||
bindFunctionToClassOrObject: (
|
||||
baseType: ClassType | undefined,
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ import { assert } from '../common/debug';
|
|||
import { ParamCategory } from '../parser/parseNodes';
|
||||
import { isTypedKwargs } from './parameterUtils';
|
||||
import * as ParseTreeUtils from './parseTreeUtils';
|
||||
import { printRefinement } from './refinementPrinter';
|
||||
import { printBytesLiteral, printStringLiteral } from './typePrinterUtils';
|
||||
import {
|
||||
ClassType,
|
||||
|
|
@ -197,7 +198,6 @@ function printTypeInternal(
|
|||
recursionCount++;
|
||||
|
||||
const originalPrintTypeFlags = printTypeFlags;
|
||||
const parenthesizeUnion = (printTypeFlags & PrintTypeFlags.ParenthesizeUnion) !== 0;
|
||||
printTypeFlags &= ~(PrintTypeFlags.ParenthesizeUnion | PrintTypeFlags.ParenthesizeCallable);
|
||||
|
||||
// If this is a type alias, see if we should use its name rather than
|
||||
|
|
@ -213,6 +213,11 @@ function printTypeInternal(
|
|||
}
|
||||
}
|
||||
|
||||
// If there are refinements on the type, don't use the type alias.
|
||||
if (type.category === TypeCategory.Class && type.priv.refinements) {
|
||||
expandTypeAlias = true;
|
||||
}
|
||||
|
||||
if (!expandTypeAlias) {
|
||||
try {
|
||||
recursionTypes.push(type);
|
||||
|
|
@ -363,6 +368,36 @@ function printTypeInternal(
|
|||
return '...';
|
||||
}
|
||||
|
||||
const baseTypeStr = printTypeWithoutRefinement(
|
||||
type,
|
||||
originalPrintTypeFlags,
|
||||
returnTypeCallback,
|
||||
uniqueNameMap,
|
||||
recursionTypes,
|
||||
recursionCount
|
||||
);
|
||||
|
||||
if (!isClass(type) || !type.priv?.refinements || type.priv.refinements.length === 0) {
|
||||
return baseTypeStr;
|
||||
}
|
||||
|
||||
const refinements = type.priv.refinements.map((refinement) => printRefinement(refinement));
|
||||
return `${baseTypeStr} @ ${refinements.join(' @ ')}`;
|
||||
}
|
||||
|
||||
function printTypeWithoutRefinement(
|
||||
type: Type,
|
||||
printTypeFlags: PrintTypeFlags,
|
||||
returnTypeCallback: FunctionReturnTypeCallback,
|
||||
uniqueNameMap: UniqueNameMap,
|
||||
recursionTypes: Type[],
|
||||
recursionCount: number
|
||||
): string {
|
||||
const originalPrintTypeFlags = printTypeFlags;
|
||||
const parenthesizeUnion = (printTypeFlags & PrintTypeFlags.ParenthesizeUnion) !== 0;
|
||||
|
||||
printTypeFlags &= ~(PrintTypeFlags.ParenthesizeUnion | PrintTypeFlags.ParenthesizeCallable);
|
||||
|
||||
try {
|
||||
recursionTypes.push(type);
|
||||
|
||||
|
|
|
|||
|
|
@ -12,6 +12,16 @@ import { assert } from '../common/debug';
|
|||
import { ParamCategory } from '../parser/parseNodes';
|
||||
import { ConstraintSolution, ConstraintSolutionSet } from './constraintSolution';
|
||||
import { DeclarationType } from './declaration';
|
||||
import { RefinementClassDetails, RefinementDomain, TypeRefinement } from './refinementTypes';
|
||||
import {
|
||||
applySolvedRefinementVars,
|
||||
evaluateRefinementExpression,
|
||||
isRefinementLiteral,
|
||||
isRefinementWildcard,
|
||||
makeRefinementVarsBound,
|
||||
makeRefinementVarsFree,
|
||||
RefinementTypeDiag,
|
||||
} from './refinementTypeUtils';
|
||||
import { Symbol, SymbolFlags, SymbolTable } from './symbol';
|
||||
import { isEffectivelyClassVar, isTypedDictMemberAccessedThroughIndex } from './symbolUtils';
|
||||
import {
|
||||
|
|
@ -214,6 +224,10 @@ export interface ApplyTypeVarOptions {
|
|||
useUnknown?: boolean;
|
||||
eliminateUnsolvedInUnions?: boolean;
|
||||
};
|
||||
refinements?: {
|
||||
errors?: RefinementTypeDiag[];
|
||||
warnings?: RefinementTypeDiag[];
|
||||
};
|
||||
}
|
||||
|
||||
// Tracks whether a function signature has been seen before within
|
||||
|
|
@ -294,6 +308,41 @@ export function isIncompleteUnknown(type: Type): boolean {
|
|||
return isUnknown(type) && type.priv.isIncomplete;
|
||||
}
|
||||
|
||||
export function getRefinementDomain(type: ClassType): RefinementDomain | undefined {
|
||||
for (const baseClass of type.shared.mro) {
|
||||
if (isClass(baseClass) && ClassType.isBuiltIn(baseClass)) {
|
||||
const className = baseClass.shared.name;
|
||||
const domainNames = ['IntRefinement', 'StrRefinement', 'IntTupleRefinement', 'Refinement'];
|
||||
if (domainNames.includes(className)) {
|
||||
return className as RefinementDomain;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return undefined;
|
||||
}
|
||||
|
||||
export function getRefinementClassId(type: ClassType): string {
|
||||
if (ClassType.isBuiltIn(type)) {
|
||||
return getBuiltInRefinementClassId(type.shared.name);
|
||||
}
|
||||
|
||||
return `${type.shared.fileUri}:${type.shared.name}`;
|
||||
}
|
||||
|
||||
export function getBuiltInRefinementClassId(className: string): string {
|
||||
// Choose a unique ID that will never conflict with a user-defined class.
|
||||
return `stdlib:${className}`;
|
||||
}
|
||||
|
||||
export function hasTupleRefinement(type: ClassType): boolean {
|
||||
if (!type.priv.refinements) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return type.priv.refinements.some((refinement) => refinement.classDetails.domain === 'IntTupleRefinement');
|
||||
}
|
||||
|
||||
// Similar to isTypeSame except that type1 is a TypeVar and type2
|
||||
// can be either a TypeVar of the same type or a union that includes
|
||||
// conditional types associated with that bound TypeVar.
|
||||
|
|
@ -1200,9 +1249,78 @@ export function isLiteralLikeType(type: ClassType): boolean {
|
|||
return true;
|
||||
}
|
||||
|
||||
// If the class has a value refinement type, it's considered literal-like.
|
||||
if (type.priv.refinements) {
|
||||
const valueClasses = ['IntValue', 'StrValue', 'BytesValue'].map((name) => getBuiltInRefinementClassId(name));
|
||||
return type.priv.refinements.some(
|
||||
(refinement) => valueClasses.includes(refinement.classDetails.classId) && refinement.isEnforced
|
||||
);
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
// If possible, converts an int, str, or bytes value with a refinement type
|
||||
// into its corresponding Literal type.
|
||||
export function simplifyRefinementTypeToLiteral(type: ClassType): ClassType {
|
||||
const refinementClassMap: { [key: string]: string } = {
|
||||
int: 'IntValue',
|
||||
str: 'StrValue',
|
||||
bytes: 'BytesValue',
|
||||
bool: 'BoolValue',
|
||||
};
|
||||
|
||||
if (!type.priv.refinements || type.priv.literalValue !== undefined || !ClassType.isBuiltIn(type)) {
|
||||
return type;
|
||||
}
|
||||
|
||||
const refClassName = refinementClassMap[type.shared.name];
|
||||
if (!refClassName) {
|
||||
return type;
|
||||
}
|
||||
const refClassId = getBuiltInRefinementClassId(refClassName);
|
||||
|
||||
const refinements = type.priv.refinements.filter((refinement) => refinement.classDetails.classId === refClassId);
|
||||
if (refinements.length !== 1) {
|
||||
return type;
|
||||
}
|
||||
const refinementExpr = refinements[0].value;
|
||||
if (!isRefinementLiteral(refinementExpr)) {
|
||||
return type;
|
||||
}
|
||||
|
||||
const remainingRefinements = type.priv.refinements.filter(
|
||||
(refinement) => refinement.classDetails.classId !== refClassId
|
||||
);
|
||||
|
||||
let resultType = ClassType.cloneWithLiteral(type, refinementExpr.value);
|
||||
resultType = ClassType.cloneWithRefinements(resultType, remainingRefinements);
|
||||
return resultType;
|
||||
}
|
||||
|
||||
// If the value has an IntValue refinement type associated with it or is an
|
||||
// integer literal, this function returns the effective value.
|
||||
export function getIntValueRefinement(type: ClassType): TypeRefinement | undefined {
|
||||
if (type.priv.literalValue !== undefined) {
|
||||
if (!ClassType.isBuiltIn(type, 'int')) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
assert(typeof type.priv.literalValue === 'number' || typeof type.priv.literalValue === 'bigint');
|
||||
const classDetails: RefinementClassDetails = {
|
||||
domain: 'IntRefinement',
|
||||
className: 'IntValue',
|
||||
classId: getBuiltInRefinementClassId('IntValue'),
|
||||
baseSupportsLiteral: true,
|
||||
baseSupportsStringShortcut: true,
|
||||
};
|
||||
return TypeRefinement.fromLiteral(classDetails, type.priv.literalValue, /* isEnforced */ true);
|
||||
}
|
||||
|
||||
const intValueClassId = getBuiltInRefinementClassId('IntValue');
|
||||
return type.priv.refinements?.find((refinement) => refinement.classDetails.classId === intValueClassId);
|
||||
}
|
||||
|
||||
export function containsLiteralType(type: Type, includeTypeArgs = false): boolean {
|
||||
class ContainsLiteralTypeWalker extends TypeWalker {
|
||||
foundLiteral = false;
|
||||
|
|
@ -2025,6 +2143,65 @@ export function getTypeVarArgsRecursive(type: Type, recursionCount = 0): TypeVar
|
|||
return [];
|
||||
}
|
||||
|
||||
export function getRefinementsRecursive(type: Type, recursionCount = 0): TypeRefinement[] {
|
||||
if (recursionCount > maxTypeRecursionCount) {
|
||||
return [];
|
||||
}
|
||||
recursionCount++;
|
||||
|
||||
const aliasInfo = type.props?.typeAliasInfo;
|
||||
if (aliasInfo?.typeArgs) {
|
||||
const combinedList: TypeRefinement[] = [];
|
||||
|
||||
aliasInfo?.typeArgs.forEach((typeArg) => {
|
||||
appendArray(combinedList, getRefinementsRecursive(typeArg, recursionCount));
|
||||
});
|
||||
|
||||
return combinedList;
|
||||
}
|
||||
|
||||
if (isClass(type)) {
|
||||
const combinedList: TypeRefinement[] = [];
|
||||
const typeArgs = type.priv.tupleTypeArgs ? type.priv.tupleTypeArgs.map((e) => e.type) : type.priv.typeArgs;
|
||||
if (typeArgs) {
|
||||
typeArgs.forEach((typeArg) => {
|
||||
appendArray(combinedList, getRefinementsRecursive(typeArg, recursionCount));
|
||||
});
|
||||
}
|
||||
|
||||
if (type.priv.refinements) {
|
||||
appendArray(combinedList, type.priv.refinements);
|
||||
}
|
||||
|
||||
return combinedList;
|
||||
}
|
||||
|
||||
if (isUnion(type)) {
|
||||
const combinedList: TypeRefinement[] = [];
|
||||
doForEachSubtype(type, (subtype) => {
|
||||
appendArray(combinedList, getRefinementsRecursive(subtype, recursionCount));
|
||||
});
|
||||
return combinedList;
|
||||
}
|
||||
|
||||
if (isFunction(type)) {
|
||||
const combinedList: TypeRefinement[] = [];
|
||||
|
||||
for (let i = 0; i < type.shared.parameters.length; i++) {
|
||||
appendArray(combinedList, getRefinementsRecursive(FunctionType.getParamType(type, i), recursionCount));
|
||||
}
|
||||
|
||||
const returnType = FunctionType.getEffectiveReturnType(type);
|
||||
if (returnType) {
|
||||
appendArray(combinedList, getRefinementsRecursive(returnType, recursionCount));
|
||||
}
|
||||
|
||||
return combinedList;
|
||||
}
|
||||
|
||||
return [];
|
||||
}
|
||||
|
||||
// Creates a specialized version of the class, filling in any unspecified
|
||||
// type arguments with Unknown or default value.
|
||||
export function specializeWithDefaultTypeArgs(type: ClassType): ClassType {
|
||||
|
|
@ -2805,6 +2982,11 @@ function _requiresSpecialization(type: Type, options?: RequiresSpecializationOpt
|
|||
|
||||
switch (type.category) {
|
||||
case TypeCategory.Class: {
|
||||
// If the type has a refinement type, it may need to be specialized.
|
||||
if (type.priv.refinements) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (ClassType.isPseudoGenericClass(type) && options?.ignorePseudoGeneric) {
|
||||
return false;
|
||||
}
|
||||
|
|
@ -3514,7 +3696,8 @@ export class TypeVarTransformer {
|
|||
if (
|
||||
typeParams.length === 0 &&
|
||||
!ClassType.isSpecialBuiltIn(classType) &&
|
||||
!ClassType.isBuiltIn(classType, 'type')
|
||||
!ClassType.isBuiltIn(classType, 'type') &&
|
||||
!classType.priv.refinements
|
||||
) {
|
||||
return classType;
|
||||
}
|
||||
|
|
@ -3596,18 +3779,51 @@ export class TypeVarTransformer {
|
|||
});
|
||||
}
|
||||
|
||||
// If specialization wasn't needed, don't allocate a new class.
|
||||
if (!specializationNeeded) {
|
||||
return classType;
|
||||
let newClassType = classType;
|
||||
|
||||
if (specializationNeeded) {
|
||||
newClassType = ClassType.specialize(
|
||||
classType,
|
||||
newTypeArgs,
|
||||
/* isTypeArgExplicit */ true,
|
||||
/* includeSubclasses */ undefined,
|
||||
newTupleTypeArgs
|
||||
);
|
||||
}
|
||||
|
||||
return ClassType.specialize(
|
||||
classType,
|
||||
newTypeArgs,
|
||||
/* isTypeArgExplicit */ true,
|
||||
/* includeSubclasses */ undefined,
|
||||
newTupleTypeArgs
|
||||
);
|
||||
// Apply transforms to the refinements.
|
||||
if (newClassType.priv.refinements) {
|
||||
const newRefinements: TypeRefinement[] = [];
|
||||
let refinementsChanged = false;
|
||||
|
||||
newClassType.priv.refinements.forEach((refinement) => {
|
||||
const newRefinement = this.transformRefinement(refinement);
|
||||
if (newRefinement) {
|
||||
// Strip wildcard refinements that are not enforced.
|
||||
if (isRefinementWildcard(newRefinement.value) && !newRefinement.isEnforced) {
|
||||
refinementsChanged = true;
|
||||
} else {
|
||||
newRefinements.push(newRefinement);
|
||||
if (newRefinement !== refinement) {
|
||||
refinementsChanged = true;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
refinementsChanged = true;
|
||||
}
|
||||
});
|
||||
|
||||
if (refinementsChanged) {
|
||||
newClassType = ClassType.cloneWithRefinements(
|
||||
newClassType,
|
||||
newRefinements.length > 0 ? newRefinements : undefined
|
||||
);
|
||||
|
||||
newClassType = simplifyRefinementTypeToLiteral(newClassType);
|
||||
}
|
||||
}
|
||||
|
||||
return newClassType;
|
||||
}
|
||||
|
||||
transformTypeVarsInFunctionType(sourceType: FunctionType, recursionCount: number): FunctionType | OverloadedType {
|
||||
|
|
@ -3798,6 +4014,10 @@ export class TypeVarTransformer {
|
|||
});
|
||||
}
|
||||
|
||||
transformRefinement(refinement: TypeRefinement): TypeRefinement | undefined {
|
||||
return refinement;
|
||||
}
|
||||
|
||||
private _isTypeVarScopePending(typeVarScopeId: TypeVarScopeId | undefined) {
|
||||
return !!typeVarScopeId && this._pendingTypeVarTransformations.has(typeVarScopeId);
|
||||
}
|
||||
|
|
@ -3897,6 +4117,23 @@ class BoundTypeVarTransform extends TypeVarTransformer {
|
|||
return undefined;
|
||||
}
|
||||
|
||||
override transformRefinement(refinement: TypeRefinement): TypeRefinement | undefined {
|
||||
if (!this._scopeIds) {
|
||||
return refinement;
|
||||
}
|
||||
|
||||
const newValue = makeRefinementVarsBound(refinement.value, this._scopeIds);
|
||||
const newCondition = refinement.condition
|
||||
? makeRefinementVarsBound(refinement.condition, this._scopeIds)
|
||||
: undefined;
|
||||
|
||||
if (newValue === refinement.value && newCondition === refinement.condition) {
|
||||
return refinement;
|
||||
}
|
||||
|
||||
return { ...refinement, value: newValue, condition: newCondition };
|
||||
}
|
||||
|
||||
private _isTypeVarInScope(typeVar: TypeVarType) {
|
||||
if (!typeVar.priv.scopeId) {
|
||||
return false;
|
||||
|
|
@ -3930,6 +4167,23 @@ class FreeTypeVarTransform extends TypeVarTransformer {
|
|||
return undefined;
|
||||
}
|
||||
|
||||
override transformRefinement(refinement: TypeRefinement): TypeRefinement | undefined {
|
||||
if (!this._scopeIds) {
|
||||
return refinement;
|
||||
}
|
||||
|
||||
const newValue = makeRefinementVarsFree(refinement.value, this._scopeIds);
|
||||
const newCondition = refinement.condition
|
||||
? makeRefinementVarsFree(refinement.condition, this._scopeIds)
|
||||
: undefined;
|
||||
|
||||
if (newValue === refinement.value && newCondition === refinement.condition) {
|
||||
return refinement;
|
||||
}
|
||||
|
||||
return { ...refinement, value: newValue, condition: newCondition };
|
||||
}
|
||||
|
||||
private _isTypeVarInScope(typeVar: TypeVarType) {
|
||||
if (!typeVar.priv.scopeId) {
|
||||
return false;
|
||||
|
|
@ -4141,6 +4395,25 @@ class ApplySolvedTypeVarsTransformer extends TypeVarTransformer {
|
|||
return type;
|
||||
}
|
||||
|
||||
override transformRefinement(refinement: TypeRefinement): TypeRefinement | undefined {
|
||||
const solutionSet = this._solution.getSolutionSet(this._activeConstraintSetIndex ?? 0);
|
||||
const newRefinementValue = applySolvedRefinementVars(refinement.value, solutionSet.getRefinementVarMap(), {
|
||||
replaceUnsolved: !!this._options.replaceUnsolved,
|
||||
refinements: this._options.refinements,
|
||||
});
|
||||
|
||||
if (newRefinementValue === refinement.value) {
|
||||
return refinement;
|
||||
}
|
||||
|
||||
return {
|
||||
...refinement,
|
||||
value: evaluateRefinementExpression(newRefinementValue, {
|
||||
refinements: this._options.refinements,
|
||||
}),
|
||||
};
|
||||
}
|
||||
|
||||
override doForEachConstraintSet(callback: () => FunctionType): FunctionType | OverloadedType {
|
||||
const solutionSets = this._solution.getSolutionSets();
|
||||
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ import { assert } from '../common/debug';
|
|||
import { Uri } from '../common/uri/uri';
|
||||
import { ArgumentNode, ExpressionNode, NameNode, ParamCategory } from '../parser/parseNodes';
|
||||
import { ClassDeclaration, FunctionDeclaration, SpecialBuiltInClassDeclaration } from './declaration';
|
||||
import { RefinementExpr, RefinementVar, TypeRefinement } from './refinementTypes';
|
||||
import { Symbol, SymbolTable } from './symbol';
|
||||
|
||||
export const enum TypeCategory {
|
||||
|
|
@ -791,6 +792,9 @@ export interface ClassDetailsPriv {
|
|||
// literal types (e.g. true or 'string' or 3).
|
||||
literalValue?: LiteralValue | undefined;
|
||||
|
||||
// Refinement definitions.
|
||||
refinements?: TypeRefinement[] | undefined;
|
||||
|
||||
// The typing module defines aliases for builtin types
|
||||
// (e.g. Tuple, List, Dict). This field holds the alias
|
||||
// name.
|
||||
|
|
@ -960,6 +964,26 @@ export namespace ClassType {
|
|||
return newClassType;
|
||||
}
|
||||
|
||||
export function cloneAddRefinement(classType: ClassType, refinement: TypeRefinement): ClassType {
|
||||
const newClassType = TypeBase.cloneType(classType);
|
||||
|
||||
if (!newClassType.priv.refinements) {
|
||||
newClassType.priv.refinements = [refinement];
|
||||
} else {
|
||||
newClassType.priv.refinements = [...newClassType.priv.refinements, refinement];
|
||||
}
|
||||
return newClassType;
|
||||
}
|
||||
|
||||
export function cloneWithRefinements(classType: ClassType, refinements: TypeRefinement[] | undefined): ClassType {
|
||||
const newClassType = TypeBase.cloneType(classType);
|
||||
if (refinements && refinements.length === 0) {
|
||||
refinements = undefined;
|
||||
}
|
||||
newClassType.priv.refinements = refinements;
|
||||
return newClassType;
|
||||
}
|
||||
|
||||
export function cloneForDeprecatedInstance(type: ClassType, deprecatedMessage?: string): ClassType {
|
||||
const newClassType = TypeBase.cloneType(type);
|
||||
newClassType.priv.deprecatedInstanceMessage = deprecatedMessage;
|
||||
|
|
@ -1038,6 +1062,26 @@ export namespace ClassType {
|
|||
return type1.priv.literalValue === type2.priv.literalValue;
|
||||
}
|
||||
|
||||
export function isRefinementSame(type1: ClassType, type2: ClassType): boolean {
|
||||
if (type1.priv.refinements === undefined) {
|
||||
return type2.priv.refinements === undefined;
|
||||
} else if (type2.priv.refinements === undefined) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (type1.priv.refinements.length !== type2.priv.refinements.length) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (let i = 0; i < type1.priv.refinements.length; i++) {
|
||||
if (!TypeRefinement.isSame(type1.priv.refinements[i], type2.priv.refinements[i])) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Determines whether two typed dict classes are equivalent given
|
||||
// that one or both have narrowed entries (i.e. entries that are
|
||||
// guaranteed to be present).
|
||||
|
|
@ -1132,6 +1176,14 @@ export namespace ClassType {
|
|||
return true;
|
||||
}
|
||||
|
||||
export function hasRefinement(classType: ClassType, classId: string): boolean {
|
||||
if (!classType.priv.refinements) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return classType.priv.refinements.some((refinement) => refinement.classDetails.classId === classId);
|
||||
}
|
||||
|
||||
export function supportsAbstractMethods(classType: ClassType) {
|
||||
return !!(classType.shared.flags & ClassTypeFlags.SupportsAbstractMethods);
|
||||
}
|
||||
|
|
@ -1492,6 +1544,8 @@ export interface FunctionParam {
|
|||
_defaultType: Type | undefined;
|
||||
|
||||
defaultExpr: ExpressionNode | undefined;
|
||||
|
||||
refinementConditions?: RefinementExpr[];
|
||||
}
|
||||
|
||||
export namespace FunctionParam {
|
||||
|
|
@ -1501,9 +1555,10 @@ export namespace FunctionParam {
|
|||
flags = FunctionParamFlags.None,
|
||||
name?: string,
|
||||
defaultType?: Type,
|
||||
defaultExpr?: ExpressionNode
|
||||
defaultExpr?: ExpressionNode,
|
||||
refinementConditions?: RefinementExpr[]
|
||||
): FunctionParam {
|
||||
return { category, flags, name, _type: type, _defaultType: defaultType, defaultExpr };
|
||||
return { category, flags, name, _type: type, _defaultType: defaultType, defaultExpr, refinementConditions };
|
||||
}
|
||||
|
||||
export function isNameSynthesized(param: FunctionParam) {
|
||||
|
|
@ -1611,6 +1666,8 @@ interface FunctionDetailsShared {
|
|||
moduleName: string;
|
||||
flags: FunctionTypeFlags;
|
||||
typeParams: TypeVarType[];
|
||||
refinements: TypeRefinement[] | undefined;
|
||||
refinementVars: RefinementVar[] | undefined;
|
||||
parameters: FunctionParam[];
|
||||
declaredReturnType: Type | undefined;
|
||||
declaration: FunctionDeclaration | undefined;
|
||||
|
|
@ -1732,6 +1789,8 @@ export namespace FunctionType {
|
|||
moduleName,
|
||||
flags: functionFlags,
|
||||
typeParams: [],
|
||||
refinements: undefined,
|
||||
refinementVars: undefined,
|
||||
parameters: [],
|
||||
declaredReturnType: undefined,
|
||||
declaration: undefined,
|
||||
|
|
@ -3338,6 +3397,10 @@ export function isTypeSame(type1: Type, type2: Type, options: TypeSameOptions =
|
|||
return false;
|
||||
}
|
||||
|
||||
if (!ClassType.isRefinementSame(type1, classType2)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!type1.priv.isUnpacked !== !classType2.priv.isUnpacked) {
|
||||
return false;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ import {
|
|||
isUnresolvedAliasDeclaration,
|
||||
} from '../analyzer/declaration';
|
||||
import * as ParseTreeUtils from '../analyzer/parseTreeUtils';
|
||||
import { RefinementExprType } from '../analyzer/refinementTypes';
|
||||
import { SourceMapper } from '../analyzer/sourceMapper';
|
||||
import { isBuiltInModule } from '../analyzer/typeDocStringUtils';
|
||||
import { PrintTypeOptions, SynthesizedTypeInfo, TypeEvaluator } from '../analyzer/typeEvaluatorTypes';
|
||||
|
|
@ -269,6 +270,44 @@ export class HoverProvider {
|
|||
this._addResultsForSynthesizedType(results.parts, type, name);
|
||||
});
|
||||
this._addDocumentationPart(results.parts, node, /* resolvedDecl */ undefined);
|
||||
} else if (declInfo?.refinementInfo) {
|
||||
if (declInfo.refinementInfo.callSignature) {
|
||||
this._addResultsPart(
|
||||
results.parts,
|
||||
`(refinement operator) ${node.d.value}: ${declInfo.refinementInfo.callSignature}`,
|
||||
/* python */ true
|
||||
);
|
||||
if (declInfo.refinementInfo.callDocstring) {
|
||||
addDocumentationResultsPart(
|
||||
this._program.serviceProvider,
|
||||
declInfo.refinementInfo.callDocstring,
|
||||
this._format,
|
||||
results.parts,
|
||||
/* resolvedDecl */ undefined
|
||||
);
|
||||
}
|
||||
} else {
|
||||
let varTypeText: string | undefined;
|
||||
|
||||
if (declInfo.refinementInfo.isWildcard) {
|
||||
varTypeText = 'any';
|
||||
} else if (declInfo.refinementInfo.varType) {
|
||||
const varTypeMapping: { [key in RefinementExprType]: string } = {
|
||||
[RefinementExprType.Int]: 'Int',
|
||||
[RefinementExprType.Str]: 'str',
|
||||
[RefinementExprType.Bytes]: 'bytes',
|
||||
[RefinementExprType.Bool]: 'bool',
|
||||
[RefinementExprType.IntTuple]: 'tuple',
|
||||
};
|
||||
|
||||
varTypeText = varTypeMapping[declInfo.refinementInfo.varType];
|
||||
}
|
||||
this._addResultsPart(
|
||||
results.parts,
|
||||
`(refinement var) ${node.d.value}: ${varTypeText ?? 'unknown'}`,
|
||||
/* python */ true
|
||||
);
|
||||
}
|
||||
} else if (!node.parent || node.parent.nodeType !== ParseNodeType.ModuleName) {
|
||||
// If we had no declaration, see if we can provide a minimal tooltip. We'll skip
|
||||
// this if it's part of a module name, since a module name part with no declaration
|
||||
|
|
|
|||
|
|
@ -494,6 +494,8 @@ export namespace Localizer {
|
|||
export const expectedPatternExpr = () => getRawString('Diagnostic.expectedPatternExpr');
|
||||
export const expectedPatternSubjectExpr = () => getRawString('Diagnostic.expectedPatternSubjectExpr');
|
||||
export const expectedPatternValue = () => getRawString('Diagnostic.expectedPatternValue');
|
||||
export const expectedPredicateIf = () => getRawString('Diagnostic.expectedPredicateIf');
|
||||
export const expectedRefinement = () => getRawString('Diagnostic.expectedRefinement');
|
||||
export const expectedReturnExpr = () => getRawString('Diagnostic.expectedReturnExpr');
|
||||
export const expectedSliceIndex = () => getRawString('Diagnostic.expectedSliceIndex');
|
||||
export const expectedTypeNotString = () => getRawString('Diagnostic.expectedTypeNotString');
|
||||
|
|
@ -848,6 +850,30 @@ export namespace Localizer {
|
|||
export const readOnlyNotInTypedDict = () => getRawString('Diagnostic.readOnlyNotInTypedDict');
|
||||
export const recursiveDefinition = () =>
|
||||
new ParameterizedString<{ name: string }>(getRawString('Diagnostic.recursiveDefinition'));
|
||||
export const refinementBaseTypeInvalid = () => getRawString('Diagnostic.refinementBaseTypeInvalid');
|
||||
export const refinementCallArgCount = () =>
|
||||
new ParameterizedString<{ name: string; expected: number; received: number }>(
|
||||
getRawString('Diagnostic.refinementCallArgCount')
|
||||
);
|
||||
export const refinementCallArgUnpacked = () => getRawString('Diagnostic.refinementCallArgUnpacked');
|
||||
export const refinementCallArgKeyword = () => getRawString('Diagnostic.refinementCallArgKeyword');
|
||||
export const refinementConditionFailure = () => getRawString('Diagnostic.refinementConditionFailure');
|
||||
export const refinementFloatImaginary = () => getRawString('Diagnostic.refinementFloatImaginary');
|
||||
export const refinementIntTupleNotAllowed = () => getRawString('Diagnostic.refinementIntTupleNotAllowed');
|
||||
export const refinementPostCondition = () => getRawString('Diagnostic.refinementPostCondition');
|
||||
export const refinementPrecondition = () => getRawString('Diagnostic.refinementPrecondition');
|
||||
export const refinementTypeNotSupported = () =>
|
||||
new ParameterizedString<{ name: string }>(getRawString('Diagnostic.refinementTypeNotSupported'));
|
||||
export const refinementUnexpectedValueType = () =>
|
||||
new ParameterizedString<{ expected: string; received: string }>(
|
||||
getRawString('Diagnostic.refinementUnexpectedValueType')
|
||||
);
|
||||
export const refinementUnsupportedCall = () =>
|
||||
new ParameterizedString<{ name: string }>(getRawString('Diagnostic.refinementUnsupportedCall'));
|
||||
export const refinementUnsupportedExpression = () => getRawString('Diagnostic.refinementUnsupportedExpression');
|
||||
export const refinementUnsupportedOperation = () => getRawString('Diagnostic.refinementUnsupportedOperation');
|
||||
export const refinementVarNotInValue = () =>
|
||||
new ParameterizedString<{ name: string }>(getRawString('Diagnostic.refinementVarNotInValue'));
|
||||
export const relativeImportNotAllowed = () => getRawString('Diagnostic.relativeImportNotAllowed');
|
||||
export const requiredArgCount = () => getRawString('Diagnostic.requiredArgCount');
|
||||
export const requiredNotInTypedDict = () => getRawString('Diagnostic.requiredNotInTypedDict');
|
||||
|
|
@ -1007,6 +1033,7 @@ export namespace Localizer {
|
|||
export const typeGuardParamCount = () => getRawString('Diagnostic.typeGuardParamCount');
|
||||
export const typeIsReturnType = () =>
|
||||
new ParameterizedString<{ type: string; returnType: string }>(getRawString('Diagnostic.typeIsReturnType'));
|
||||
export const typeMetadataInvalid = () => getRawString('Diagnostic.typeMetadataInvalid');
|
||||
export const typeNotAwaitable = () =>
|
||||
new ParameterizedString<{ type: string }>(getRawString('Diagnostic.typeNotAwaitable'));
|
||||
export const typeNotIntantiable = () =>
|
||||
|
|
@ -1457,6 +1484,37 @@ export namespace Localizer {
|
|||
export const pyrightCommentIgnoreTip = () => getRawString('DiagnosticAddendum.pyrightCommentIgnoreTip');
|
||||
export const readOnlyAttribute = () =>
|
||||
new ParameterizedString<{ name: string }>(getRawString('DiagnosticAddendum.readOnlyAttribute'));
|
||||
export const refinementBroadcast = () => getRawString('DiagnosticAddendum.refinementBroadcast');
|
||||
export const refinementConcatMismatch = () => getRawString('DiagnosticAddendum.refinementConcatMismatch');
|
||||
export const refinementConditionNotSatisfied = () =>
|
||||
new ParameterizedString<{ condition: string }>(
|
||||
getRawString('DiagnosticAddendum.refinementConditionNotSatisfied')
|
||||
);
|
||||
export const refinementIndexOutOfRange = () =>
|
||||
new ParameterizedString<{ value: number }>(getRawString('DiagnosticAddendum.refinementIndexOutOfRange'));
|
||||
export const refinementLiteralAssignment = () =>
|
||||
new ParameterizedString<{ expected: string; received: string }>(
|
||||
getRawString('DiagnosticAddendum.refinementLiteralAssignment')
|
||||
);
|
||||
export const refinementPermuteDuplicate = () => getRawString('DiagnosticAddendum.refinementPermuteDuplicate');
|
||||
export const refinementPermuteMismatch = () =>
|
||||
new ParameterizedString<{ expected: string; received: string }>(
|
||||
getRawString('DiagnosticAddendum.refinementPermuteMismatch')
|
||||
);
|
||||
export const refinementReshapeInferred = () => getRawString('DiagnosticAddendum.refinementReshapeInferred');
|
||||
export const refinementReshapeMismatch = () => getRawString('DiagnosticAddendum.refinementReshapeMismatch');
|
||||
export const refinementShapeMismatch = () =>
|
||||
new ParameterizedString<{ expected: string; received: string }>(
|
||||
getRawString('DiagnosticAddendum.refinementShapeMismatch')
|
||||
);
|
||||
export const refinementTupleMismatch = () =>
|
||||
new ParameterizedString<{ expected: string; received: string }>(
|
||||
getRawString('DiagnosticAddendum.refinementTupleMismatch')
|
||||
);
|
||||
export const refinementValMismatch = () =>
|
||||
new ParameterizedString<{ expected: string; received: string }>(
|
||||
getRawString('DiagnosticAddendum.refinementValMismatch')
|
||||
);
|
||||
export const seeDeclaration = () => getRawString('DiagnosticAddendum.seeDeclaration');
|
||||
export const seeClassDeclaration = () => getRawString('DiagnosticAddendum.seeClassDeclaration');
|
||||
export const seeFunctionDeclaration = () => getRawString('DiagnosticAddendum.seeFunctionDeclaration');
|
||||
|
|
|
|||
|
|
@ -441,6 +441,11 @@
|
|||
"message": "Expected pattern value expression of the form \"a.b\"",
|
||||
"comment": "{Locked='a.b'}"
|
||||
},
|
||||
"expectedPredicateIf": "Expected \"if\" after refinement value expression",
|
||||
"expectedRefinement": {
|
||||
"message": "Expected string literal for refinement type definition",
|
||||
"comment": "{Locked='literal'"
|
||||
},
|
||||
"expectedReturnExpr": {
|
||||
"message": "Expected expression after \"return\"",
|
||||
"comment": "{Locked='return'}"
|
||||
|
|
@ -1063,6 +1068,23 @@
|
|||
"comment": "{Locked='ReadOnly'}"
|
||||
},
|
||||
"recursiveDefinition": "Type of \"{name}\" could not be determined because it refers to itself",
|
||||
"refinementBaseTypeInvalid": "Base type for a refinement type must be a nominal class",
|
||||
"refinementBoolValueExpected": "Expected boolean expression",
|
||||
"refinementBytes": "Bytes values are not supported in refinement types",
|
||||
"refinementCallArgCount": "Incorrect argument count for \"name\"; expected {expected} but received {received}",
|
||||
"refinementCallArgKeyword": "Keyword arguments are not supported in refinement types",
|
||||
"refinementConditionFailure": "Argument value violates refinement condition",
|
||||
"refinementCallArgUnpacked": "Unpacked argument expressions are not supported in refinement types",
|
||||
"refinementFloatImaginary": "Float and imaginary values are not supported in refinement types",
|
||||
"refinementIntTupleNotAllowed": "Literal value not supported by IntTupleRefinement",
|
||||
"refinementPostCondition": "Refinement condition not allowed in return type",
|
||||
"refinementPrecondition": "Invalid refinement value expression; use variable, literal, or \"_\"",
|
||||
"refinementTypeNotSupported": "Refinement type \"{name}\" is not supported",
|
||||
"refinementUnexpectedValueType": "Unexpected expression type; expected \"{expected}\" but received \"{received}\"",
|
||||
"refinementUnsupportedCall": "Operation \"{name}\" not supported in refinement type",
|
||||
"refinementUnsupportedExpression": "Expression not supported in refinement type",
|
||||
"refinementUnsupportedOperation": "Operation not supported in refinement type",
|
||||
"refinementVarNotInValue": "Refinement variable \"{name}\" is undefined",
|
||||
"relativeImportNotAllowed": {
|
||||
"message": "Relative imports cannot be used with \"import .a\" form; use \"from . import a\" instead",
|
||||
"comment": "{Locked='import .a','from . import a'}"
|
||||
|
|
@ -1296,6 +1318,9 @@
|
|||
"message": "Return type of TypeIs (\"{returnType}\") is not consistent with value parameter type (\"{type}\")",
|
||||
"comment": "{Locked='TypeIs'}"
|
||||
},
|
||||
"typeMetadataInvalid": {
|
||||
"message": "Type metadata must be instance of TypeMetadata or a literal value",
|
||||
"comment": "{Locked='TypeMetadata'}"},
|
||||
"typeNotAwaitable": {
|
||||
"message": "\"{type}\" is not awaitable",
|
||||
"comment": "{Locked='awaitable'}"
|
||||
|
|
@ -1933,6 +1958,18 @@
|
|||
"comment": "{Locked='# pyright: ignore[<diagnostic rules>]'}"
|
||||
},
|
||||
"readOnlyAttribute": "Attribute \"{name}\" is read-only",
|
||||
"refinementBroadcast": "Shapes differ and do not allow for broadcasting",
|
||||
"refinementConcatMismatch": "Shape must match for all dimensions other than the concatenation dimension",
|
||||
"refinementConditionNotSatisfied": "Refinement condition not satisfied: \"{condition}\"",
|
||||
"refinementIndexOutOfRange": "Index {value} is out of range",
|
||||
"refinementLiteralAssignment": "Refinement literal value is incompatible; expected {expected} but received {received}",
|
||||
"refinementPermuteDuplicate": "Permute does not allow duplicate dimensions",
|
||||
"refinementPermuteMismatch": "Permute dimension count mismatch (expected {expected} but received {received})",
|
||||
"refinementReshapeInferred": "Only one inferred dimension is allowed",
|
||||
"refinementReshapeMismatch": "New shape does not match existing shape",
|
||||
"refinementShapeMismatch": "Could not assign shape \"{received}\" to \"{expected}\"",
|
||||
"refinementTupleMismatch": "Could not assign refinement tuple \"{received}\" to \"{expected}\"",
|
||||
"refinementValMismatch": "Refinement value is incompatible; expected {expected} but received {received}",
|
||||
"seeClassDeclaration": "See class declaration",
|
||||
"seeDeclaration": "See declaration",
|
||||
"seeFunctionDeclaration": "See function declaration",
|
||||
|
|
|
|||
|
|
@ -110,6 +110,8 @@ export const enum ParseNodeType {
|
|||
TypeParameter,
|
||||
TypeParameterList,
|
||||
TypeAlias,
|
||||
|
||||
Refinement,
|
||||
}
|
||||
|
||||
export const enum ErrorExpressionCategory {
|
||||
|
|
@ -1807,6 +1809,11 @@ export interface StringListNode extends ParseNodeBase<ParseNodeType.StringList>
|
|||
// into an expression.
|
||||
annotation: ExpressionNode | undefined;
|
||||
|
||||
// If the string represents a refinement type definition
|
||||
// within a type annotation, it is further parsed into
|
||||
// an expression (lazily by the type evaluator).
|
||||
refinement: RefinementNode | undefined;
|
||||
|
||||
// Indicates that the string list is enclosed in parens.
|
||||
hasParens: boolean;
|
||||
};
|
||||
|
|
@ -1824,6 +1831,7 @@ export namespace StringListNode {
|
|||
d: {
|
||||
strings,
|
||||
annotation: undefined,
|
||||
refinement: undefined,
|
||||
hasParens: false,
|
||||
},
|
||||
};
|
||||
|
|
@ -2742,6 +2750,39 @@ export type PatternAtomNode =
|
|||
| PatternValueNode
|
||||
| ErrorNode;
|
||||
|
||||
export interface RefinementNode extends ParseNodeBase<ParseNodeType.Refinement> {
|
||||
d: {
|
||||
valueExpr: ExpressionNode;
|
||||
conditionExpr: ExpressionNode | undefined;
|
||||
};
|
||||
}
|
||||
|
||||
export namespace RefinementNode {
|
||||
export function create(valueExpr: ExpressionNode, conditionExpr: ExpressionNode | undefined) {
|
||||
const node: RefinementNode = {
|
||||
start: valueExpr.start,
|
||||
length: valueExpr.length,
|
||||
nodeType: ParseNodeType.Refinement,
|
||||
id: _nextNodeId++,
|
||||
parent: undefined,
|
||||
a: undefined,
|
||||
d: {
|
||||
valueExpr,
|
||||
conditionExpr,
|
||||
},
|
||||
};
|
||||
|
||||
valueExpr.parent = node;
|
||||
|
||||
if (conditionExpr) {
|
||||
conditionExpr.parent = node;
|
||||
extendRange(node, conditionExpr);
|
||||
}
|
||||
|
||||
return node;
|
||||
}
|
||||
}
|
||||
|
||||
export type ParseNode =
|
||||
| ErrorNode
|
||||
| ArgumentNode
|
||||
|
|
@ -2799,6 +2840,7 @@ export type ParseNode =
|
|||
| PatternMappingNode
|
||||
| PatternSequenceNode
|
||||
| PatternValueNode
|
||||
| RefinementNode
|
||||
| RaiseNode
|
||||
| ReturnNode
|
||||
| SetNode
|
||||
|
|
|
|||
|
|
@ -103,6 +103,7 @@ import {
|
|||
PatternSequenceNode,
|
||||
PatternValueNode,
|
||||
RaiseNode,
|
||||
RefinementNode,
|
||||
ReturnNode,
|
||||
SetNode,
|
||||
SliceNode,
|
||||
|
|
@ -217,6 +218,7 @@ export const enum ParseTextMode {
|
|||
Expression,
|
||||
VariableAnnotation,
|
||||
FunctionAnnotation,
|
||||
Refinement,
|
||||
}
|
||||
|
||||
// Limit the max child node depth to prevent stack overflows.
|
||||
|
|
@ -318,6 +320,15 @@ export class Parser {
|
|||
initialParenDepth?: number,
|
||||
typingSymbolAliases?: Map<string, string>
|
||||
): ParseExpressionTextResults<FunctionAnnotationNode>;
|
||||
parseTextExpression(
|
||||
fileContents: string,
|
||||
textOffset: number,
|
||||
textLength: number,
|
||||
parseOptions: ParseOptions,
|
||||
parseTextMode: ParseTextMode.Refinement,
|
||||
initialParenDepth?: number,
|
||||
typingSymbolAliases?: Map<string, string>
|
||||
): ParseExpressionTextResults<RefinementNode>;
|
||||
parseTextExpression(
|
||||
fileContents: string,
|
||||
textOffset: number,
|
||||
|
|
@ -326,7 +337,7 @@ export class Parser {
|
|||
parseTextMode = ParseTextMode.Expression,
|
||||
initialParenDepth = 0,
|
||||
typingSymbolAliases?: Map<string, string>
|
||||
): ParseExpressionTextResults<ExpressionNode | FunctionAnnotationNode> {
|
||||
): ParseExpressionTextResults<ExpressionNode | FunctionAnnotationNode | RefinementNode> {
|
||||
const diagSink = new DiagnosticSink();
|
||||
this._startNewParse(fileContents, textOffset, textLength, parseOptions, diagSink, initialParenDepth);
|
||||
|
||||
|
|
@ -334,13 +345,16 @@ export class Parser {
|
|||
this._typingSymbolAliases = new Map<string, string>(typingSymbolAliases);
|
||||
}
|
||||
|
||||
let parseTree: ExpressionNode | FunctionAnnotationNode | undefined;
|
||||
let parseTree: ExpressionNode | FunctionAnnotationNode | RefinementNode | undefined;
|
||||
if (parseTextMode === ParseTextMode.VariableAnnotation) {
|
||||
this._isParsingQuotedText = true;
|
||||
parseTree = this._parseTypeAnnotation();
|
||||
} else if (parseTextMode === ParseTextMode.FunctionAnnotation) {
|
||||
this._isParsingQuotedText = true;
|
||||
parseTree = this._parseFunctionTypeAnnotation();
|
||||
} else if (parseTextMode === ParseTextMode.Refinement) {
|
||||
this._isParsingQuotedText = true;
|
||||
parseTree = this._parseRefinement();
|
||||
} else {
|
||||
const exprListResult = this._parseTestOrStarExpressionList(
|
||||
/* allowAssignmentExpression */ false,
|
||||
|
|
@ -3434,6 +3448,8 @@ export class Parser {
|
|||
return leftExpr;
|
||||
}
|
||||
|
||||
const wasParsingTypeAnnotation = this._isParsingTypeAnnotation;
|
||||
|
||||
let peekToken = this._peekToken();
|
||||
let nextOperator = this._peekOperatorType();
|
||||
while (
|
||||
|
|
@ -3444,12 +3460,19 @@ export class Parser {
|
|||
nextOperator === OperatorType.FloorDivide
|
||||
) {
|
||||
this._getNextToken();
|
||||
|
||||
if (nextOperator === OperatorType.MatrixMultiply) {
|
||||
this._isParsingTypeAnnotation = false;
|
||||
}
|
||||
|
||||
const rightExpr = this._parseArithmeticFactor();
|
||||
leftExpr = this._createBinaryOperationNode(leftExpr, rightExpr, peekToken, nextOperator);
|
||||
peekToken = this._peekToken();
|
||||
nextOperator = this._peekOperatorType();
|
||||
}
|
||||
|
||||
this._isParsingTypeAnnotation = wasParsingTypeAnnotation;
|
||||
|
||||
return leftExpr;
|
||||
}
|
||||
|
||||
|
|
@ -4923,6 +4946,45 @@ export class Parser {
|
|||
return FormatStringNode.create(startToken, endToken, middleTokens, fieldExpressions, formatExpressions);
|
||||
}
|
||||
|
||||
private _parseRefinement(): RefinementNode {
|
||||
const valueExpr = this._parseRefinementValue();
|
||||
let conditionExpr: ExpressionNode | undefined;
|
||||
|
||||
if (this._consumeTokenIfKeyword(KeywordType.If)) {
|
||||
conditionExpr = this._parseOrTest();
|
||||
} else {
|
||||
const nextToken = this._peekToken();
|
||||
if (nextToken.type !== TokenType.EndOfStream) {
|
||||
this._addSyntaxError(LocMessage.expectedPredicateIf(), nextToken);
|
||||
this._consumeTokensUntilType([TokenType.EndOfStream]);
|
||||
}
|
||||
}
|
||||
|
||||
return RefinementNode.create(valueExpr, conditionExpr);
|
||||
}
|
||||
|
||||
private _parseRefinementValue(): ExpressionNode {
|
||||
if (this._isNextTokenNeverExpression()) {
|
||||
return this._handleExpressionParseError(
|
||||
ErrorExpressionCategory.MissingExpression,
|
||||
LocMessage.expectedExpr()
|
||||
);
|
||||
}
|
||||
|
||||
const exprListResult = this._parseExpressionListGeneric(() => {
|
||||
if (this._peekOperatorType() === OperatorType.Multiply) {
|
||||
return this._parseExpression(/* allowUnpack */ true);
|
||||
}
|
||||
|
||||
return this._parseOrTest();
|
||||
});
|
||||
|
||||
if (exprListResult.parseError) {
|
||||
return exprListResult.parseError;
|
||||
}
|
||||
return this._makeExpressionOrTuple(exprListResult, /* enclosedInParens */ false);
|
||||
}
|
||||
|
||||
private _createBinaryOperationNode(
|
||||
leftExpression: ExpressionNode,
|
||||
rightExpression: ExpressionNode,
|
||||
|
|
|
|||
56
packages/pyright-internal/src/tests/samples/refinement1.py
Normal file
56
packages/pyright-internal/src/tests/samples/refinement1.py
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
# This sample tests basic handling of refinement types.
|
||||
|
||||
# pyright: reportMissingModuleSource=false
|
||||
|
||||
from typing import Annotated, Any, Iterable, TypedDict
|
||||
from typing_extensions import IntValue
|
||||
|
||||
|
||||
class TD1(TypedDict):
|
||||
x: int
|
||||
|
||||
|
||||
def v_ok1(v: Annotated[int, IntValue("x")]):
|
||||
reveal_type(v, expected_text='int @ "x"')
|
||||
|
||||
def v_ok2(v: Annotated[int, IntValue(value=1)]):
|
||||
reveal_type(v, expected_text='int @ 1')
|
||||
|
||||
# This should generate an error because refinement types
|
||||
# apply only to nominal class types.
|
||||
v_bad1: Annotated[int | str, IntValue("x")]
|
||||
|
||||
# This should generate an error because refinement types
|
||||
# apply only to nominal class types.
|
||||
v_bad2: Annotated[Iterable, IntValue("x")]
|
||||
|
||||
# This should generate an error because refinement types
|
||||
# apply only to nominal class types.
|
||||
v_bad3: Annotated[TD1, IntValue("x")]
|
||||
|
||||
# This should generate an error because refinement types
|
||||
# apply only to nominal class types.
|
||||
v_bad4: Annotated[Any, IntValue("x")]
|
||||
|
||||
|
||||
def x_ok1(v: int @ IntValue("x")):
|
||||
reveal_type(v, expected_text='int @ "x"')
|
||||
|
||||
def x_ok2(v: int @ IntValue(value=1)):
|
||||
reveal_type(v, expected_text='int @ 1')
|
||||
|
||||
# This should generate an error because refinement types
|
||||
# apply only to nominal class types.
|
||||
x_bad1: (int | str) @ IntValue("x")
|
||||
|
||||
# This should generate an error because refinement types
|
||||
# apply only to nominal class types.
|
||||
x_bad2: Iterable @ IntValue("x")
|
||||
|
||||
# This should generate an error because refinement types
|
||||
# apply only to nominal class types.
|
||||
x_bad3: TD1 @ IntValue("x")
|
||||
|
||||
# This should generate an error because refinement types
|
||||
# apply only to nominal class types.
|
||||
x_bad4: Any @ IntValue("x")
|
||||
59
packages/pyright-internal/src/tests/samples/refinement2.py
Normal file
59
packages/pyright-internal/src/tests/samples/refinement2.py
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
# This sample tests basic parsing of refinement predicates and
|
||||
# refinement variable consistency.
|
||||
|
||||
# pyright: reportMissingModuleSource=false
|
||||
|
||||
from typing import Annotated
|
||||
from typing_extensions import IntValue, Refinement, Shape, StrValue
|
||||
|
||||
|
||||
class Tensor: ...
|
||||
|
||||
|
||||
v_ok1: Annotated[int, IntValue("x if x > 1 and (x < 10 or x % 2 == 0)")]
|
||||
v_ok2: Annotated[int, Shape("x, y if x > 1 and (x < 10 or x % 2 == 0)")]
|
||||
v_ok3: Annotated[int, Shape("x, *y if x > 1 and (x < 10 or x % 2 == 0)")]
|
||||
|
||||
|
||||
# This should generate a syntax error because ":" should be "if".
|
||||
v_bad1: Annotated[int, IntValue("x: x > 1")]
|
||||
|
||||
# This should generate a syntax error because "x" isn't a bool value.
|
||||
v_bad2: Annotated[int, IntValue("x if x")]
|
||||
|
||||
# This should generate a syntax error because "y" isn't an int value.
|
||||
v_bad3: Annotated[int, Shape("x, *y if y < 1")]
|
||||
|
||||
# This should generate a syntax error because "x, y" isn't an int value.
|
||||
v_bad4: Annotated[int, IntValue("x, y")]
|
||||
|
||||
# This should generate a syntax error because "x" isn't an int value.
|
||||
v_bad5: Annotated[int, Shape("x if x < 1")]
|
||||
|
||||
# This should generate a syntax error because it's an invalid expression.
|
||||
v_bad6: Annotated[int, IntValue("x if x < 1 x")]
|
||||
|
||||
# This should generate a syntax error because it's an invalid expression.
|
||||
v_bad7: Annotated[int, IntValue("x if x.foo > 1")]
|
||||
|
||||
# This should generate a syntax error because it's an invalid expression.
|
||||
v_bad8: Annotated[int, IntValue("x if x[1] > 1")]
|
||||
|
||||
# This should generate a syntax error because it's an unsupported call.
|
||||
v_bad9: Annotated[int, IntValue("x if call(x)")]
|
||||
|
||||
|
||||
class CustomRefinementDomain(Refinement):
|
||||
def __str__(self) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
# This should generate a syntax error because it's an unknown refinement domain.
|
||||
v_bad10: Annotated[int, CustomRefinementDomain("x")]
|
||||
|
||||
# This should generate two errors because "x" is inconsistent.
|
||||
v_bad11: int @ IntValue("x") | str @ StrValue("x") | Tensor @ Shape("x")
|
||||
|
||||
|
||||
# This should generate two errors because "x" is inconsistent.
|
||||
type TA_Bad1 = int @ IntValue("x") | str @ StrValue("x") | Tensor @ Shape("x")
|
||||
27
packages/pyright-internal/src/tests/samples/refinement3.py
Normal file
27
packages/pyright-internal/src/tests/samples/refinement3.py
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
# This sample tests basic scoping rules for refinement variables.
|
||||
|
||||
# pyright: reportMissingModuleSource=false
|
||||
|
||||
from typing_extensions import IntValue, StrValue
|
||||
|
||||
def func_good1(a: int @ IntValue("x"), b: str @ IntValue("x")) -> None:
|
||||
pass
|
||||
|
||||
# This should generate an error because "x" is used twice
|
||||
# and has inconsistent types.
|
||||
def func_bad1(a: int @ IntValue("x"), b: str @ StrValue("x")) -> None:
|
||||
pass
|
||||
|
||||
def outer1(a: int @ IntValue("x")):
|
||||
# This should generate an error because "x"
|
||||
# has inconsistent types.
|
||||
def inner(a: str @ StrValue("x")):
|
||||
pass
|
||||
|
||||
v1: int @ IntValue("x")
|
||||
|
||||
# This should generate an error because "x"
|
||||
# has inconsistent types.
|
||||
v2: str @ StrValue("x")
|
||||
|
||||
|
||||
55
packages/pyright-internal/src/tests/samples/refinement4.py
Normal file
55
packages/pyright-internal/src/tests/samples/refinement4.py
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
# This sample tests refinement condition validation.
|
||||
|
||||
# pyright: reportMissingModuleSource=false
|
||||
|
||||
from typing import Any, cast
|
||||
from typing_extensions import StrValue
|
||||
|
||||
|
||||
type SingleDigitIt = int @ "x if x >= 0 and x < 10"
|
||||
|
||||
x1: SingleDigitIt = 0
|
||||
x2: SingleDigitIt = 9
|
||||
|
||||
# This should generate an error.
|
||||
x3: SingleDigitIt = 10
|
||||
|
||||
# This should generate an error.
|
||||
x4: SingleDigitIt = -1
|
||||
|
||||
|
||||
def func1(a: int @ "x if x >= 0", b: int @ "y if y < 0 and x + y < 2") -> int:
|
||||
return a + b
|
||||
|
||||
|
||||
func1(1, -1)
|
||||
func1(10, -9)
|
||||
|
||||
# This should generate an error.
|
||||
func1(1, 0)
|
||||
|
||||
# This should generate an error.
|
||||
func1(10, -2)
|
||||
|
||||
|
||||
def func2(a: int @ "x", b: int @ "y if x + y < 2") -> int @ "x":
|
||||
result = a + b
|
||||
reveal_type(result, expected_text='int @ "x + y"')
|
||||
return cast(Any, result)
|
||||
|
||||
|
||||
func2(1, -1)
|
||||
func2(10, -9)
|
||||
|
||||
# This should generate an error.
|
||||
func2(1, 4)
|
||||
|
||||
# This should generate an error.
|
||||
func2(10, -2)
|
||||
|
||||
|
||||
y1: str @ StrValue("x if x == 'hi' or x == 'bye'") = "bye"
|
||||
y1 = "hi"
|
||||
|
||||
# This should generate an error.
|
||||
y1 = "neither"
|
||||
|
|
@ -0,0 +1,40 @@
|
|||
# This sample tests the handling of custom refinement types.
|
||||
|
||||
# pyright: reportMissingModuleSource=false
|
||||
|
||||
from typing_extensions import StrRefinement
|
||||
|
||||
|
||||
class Units(StrRefinement):
|
||||
def __str__(self) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
type FloatYards = float @ Units(value="yards")
|
||||
type FloatMeters = float @ Units(value="meters")
|
||||
|
||||
|
||||
def add_units(a: float @ Units("x"), b: float @ Units("x")) -> float @ Units("x"):
|
||||
return a + b
|
||||
|
||||
|
||||
def convert_yards_to_meters(a: float @ Units("'yards'")) -> float @ Units("'meters'"):
|
||||
return a * 0.9144
|
||||
|
||||
|
||||
def test2(a: float @ Units("'meters'"), b: float @ Units("'yards'")):
|
||||
m = convert_yards_to_meters(b)
|
||||
reveal_type(m, expected_text="float @ Units(\"'meters'\")")
|
||||
add_units(a, m)
|
||||
|
||||
# This should generate an error.
|
||||
add_units(a, b)
|
||||
|
||||
|
||||
def test3(a: FloatMeters, b: FloatYards):
|
||||
m = convert_yards_to_meters(b)
|
||||
reveal_type(m, expected_text="float @ Units(\"'meters'\")")
|
||||
add_units(a, m)
|
||||
|
||||
# This should generate an error.
|
||||
add_units(a, b)
|
||||
|
|
@ -0,0 +1,33 @@
|
|||
# This sample tests enforced and unenforced refinement types.
|
||||
|
||||
# pyright: reportMissingModuleSource=false
|
||||
|
||||
from typing_extensions import IntValue, Shape
|
||||
|
||||
|
||||
def get_int() -> int: ...
|
||||
|
||||
|
||||
# This should generate an error.
|
||||
i1: int @ 1 = get_int()
|
||||
|
||||
# This should generate an error.
|
||||
i2: int @ IntValue(value=1) = get_int()
|
||||
|
||||
i3: int @ IntValue(value=1, enforce=False) = get_int()
|
||||
|
||||
i4: int @ "x" = get_int()
|
||||
|
||||
i5: int @ IntValue("x", enforce=False) = get_int()
|
||||
|
||||
# This should generate an error.
|
||||
i6: int @ IntValue("x", enforce=True) = get_int()
|
||||
|
||||
|
||||
class Tensor: ...
|
||||
|
||||
|
||||
t1: Tensor @ Shape("x") = Tensor()
|
||||
|
||||
# This should generate an error.
|
||||
t2: Tensor @ Shape("x", enforce=True) = Tensor()
|
||||
|
|
@ -0,0 +1,25 @@
|
|||
# This sample tests the case where a class supports implicit refinement
|
||||
# types through the __type_metadata__ magic method.
|
||||
|
||||
# pyright: reportMissingModuleSource=false
|
||||
|
||||
from typing import cast
|
||||
from typing_extensions import Shape
|
||||
|
||||
|
||||
class ClassA:
|
||||
@classmethod
|
||||
def __type_metadata__(cls, pred: str) -> Shape:
|
||||
return Shape(pred)
|
||||
|
||||
|
||||
def test1(a: ClassA @ Shape("x, "), b: ClassA @ "y, ") -> ClassA @ "x, y":
|
||||
reveal_type(a, expected_text='ClassA @ "x,"')
|
||||
reveal_type(b, expected_text='ClassA @ "y,"')
|
||||
|
||||
return cast(ClassA @ "x, y", ClassA())
|
||||
|
||||
|
||||
def test2(m: ClassA @ "1, ", n: ClassA @ "2, "):
|
||||
v = test1(m, n)
|
||||
reveal_type(v, expected_text='ClassA @ "1, 2"')
|
||||
|
|
@ -0,0 +1,59 @@
|
|||
# This sample tests basic interactions between Literal and refinement types.
|
||||
|
||||
# pyright: reportMissingModuleSource=false
|
||||
|
||||
from typing import Literal, cast
|
||||
from typing_extensions import StrValue
|
||||
|
||||
|
||||
def func1(a: int @ "x") -> int @ "x":
|
||||
return a
|
||||
|
||||
|
||||
v1 = func1(-1)
|
||||
reveal_type(v1, expected_text="Literal[-1]")
|
||||
|
||||
|
||||
def func2(a: int @ "x", b: int @ "y") -> int @ "x" | int @ "y":
|
||||
return a if a > b else b
|
||||
|
||||
|
||||
v2 = func2(-1, 2)
|
||||
reveal_type(v2, expected_text="Literal[-1, 2]")
|
||||
|
||||
|
||||
def func3(a: str @ StrValue("x"), b: str @ StrValue("y")) -> str @ StrValue("x + y"):
|
||||
return cast(str @ StrValue("x + y"), a + b)
|
||||
|
||||
|
||||
v3 = func3("hi ", "there")
|
||||
reveal_type(v3, expected_text="Literal['hi there']")
|
||||
|
||||
|
||||
def func4(x: int @ 2):
|
||||
y1: Literal[2] = x
|
||||
y2: Literal[1, 2] = x
|
||||
|
||||
# This should result in an error.
|
||||
y3: Literal[3] = x
|
||||
|
||||
|
||||
def func5(x: bool @ False):
|
||||
y1: Literal[False] = x
|
||||
|
||||
y2: bool @ False = x
|
||||
|
||||
# This should result in an error.
|
||||
y3: Literal[True] = x
|
||||
|
||||
# This should result in an error.
|
||||
y4: bool @ True = x
|
||||
|
||||
|
||||
def is_greater(a: int @ "a", b: int @ "b") -> bool @ "a > b":
|
||||
return a > b
|
||||
|
||||
|
||||
def func6():
|
||||
reveal_type(is_greater(1, 2), expected_text="Literal[False]")
|
||||
reveal_type(is_greater(2, 1), expected_text="Literal[True]")
|
||||
|
|
@ -0,0 +1,55 @@
|
|||
# This sample tests the "Shape" refinement type.
|
||||
|
||||
# pyright: reportMissingModuleSource=false
|
||||
|
||||
from typing_extensions import Shape
|
||||
|
||||
|
||||
class Tensor: ...
|
||||
|
||||
|
||||
def matmul(a: Tensor @ Shape("x, y"), b: Tensor @ Shape("y, z")) -> Tensor @ Shape(
|
||||
"x, z"
|
||||
): ...
|
||||
|
||||
|
||||
def func1(
|
||||
a: Tensor @ Shape("a, b"), b: Tensor @ Shape("b, c"), c: Tensor @ Shape("b, b")
|
||||
):
|
||||
v1 = matmul(a, b)
|
||||
reveal_type(v1, expected_text='Tensor @ Shape("a, c")')
|
||||
|
||||
v2 = matmul(a, c)
|
||||
reveal_type(v2, expected_text='Tensor @ Shape("a, b")')
|
||||
|
||||
v3 = matmul(c, c)
|
||||
reveal_type(v3, expected_text='Tensor @ Shape("b, b")')
|
||||
|
||||
v4 = matmul(c, b)
|
||||
reveal_type(v4, expected_text='Tensor @ Shape("b, c")')
|
||||
|
||||
# This should generate an error.
|
||||
matmul(c, a)
|
||||
|
||||
# This should generate an error.
|
||||
matmul(b, a)
|
||||
|
||||
|
||||
class Size(tuple[int, ...]): ...
|
||||
|
||||
|
||||
def func2(v: Size @ Shape("x, y, z")):
|
||||
a, b, c = v
|
||||
reveal_type(a, expected_text='int @ "x"')
|
||||
reveal_type(b, expected_text='int @ "y"')
|
||||
reveal_type(c, expected_text='int @ "z"')
|
||||
|
||||
d, *x = v
|
||||
reveal_type(d, expected_text='int @ "x"')
|
||||
reveal_type(x, expected_text='list[int @ "y" | int @ "z"]')
|
||||
|
||||
k, j, *m, n = v
|
||||
reveal_type(k, expected_text='int @ "x"')
|
||||
reveal_type(j, expected_text='int @ "y"')
|
||||
reveal_type(m, expected_text="list[Never]")
|
||||
reveal_type(n, expected_text='int @ "z"')
|
||||
137
packages/pyright-internal/src/tests/samples/refinementShape2.py
Normal file
137
packages/pyright-internal/src/tests/samples/refinementShape2.py
Normal file
|
|
@ -0,0 +1,137 @@
|
|||
# This sample tests various aspects of the Shape refinement type.
|
||||
|
||||
# pyright: reportMissingModuleSource=false
|
||||
|
||||
from tensorlib import Size, Tensor, cat, conv2d, randn, sum
|
||||
|
||||
|
||||
def func1(x: Size @ "x, y") -> Size @ "x, y":
|
||||
return x
|
||||
|
||||
|
||||
def func2(s1: Size @ "1, 2", s2: Size @ "1, 2, x"):
|
||||
# This should generate an error.
|
||||
func1(s2)
|
||||
|
||||
v1 = func1(s1)
|
||||
reveal_type(v1, expected_text='Size @ "1, 2"')
|
||||
|
||||
x1, y1 = s1
|
||||
reveal_type(x1, expected_text="int @ 1")
|
||||
reveal_type(y1, expected_text="int @ 2")
|
||||
|
||||
# This should generate an error.
|
||||
x2, y2, z2 = s1
|
||||
|
||||
x3, *other3 = s2
|
||||
reveal_type(x3, expected_text="int @ 1")
|
||||
reveal_type(other3, expected_text='list[int @ 2 | int @ "x"]')
|
||||
|
||||
|
||||
def index1(t1: Tensor @ "a, b, c"):
|
||||
s1 = t1.shape
|
||||
reveal_type(s1, expected_text='Size @ "a, b, c"')
|
||||
|
||||
s2 = s1[2]
|
||||
reveal_type(s2, expected_text='int @ "c"')
|
||||
|
||||
s3 = s1[-3]
|
||||
reveal_type(s3, expected_text='int @ "a"')
|
||||
|
||||
# This should generate an error.
|
||||
s4 = s1[-4]
|
||||
|
||||
# This should generate an error.
|
||||
s5 = s1[4]
|
||||
|
||||
|
||||
def index2(t1: Tensor @ "a, b, *other"):
|
||||
s1 = t1.shape
|
||||
reveal_type(s1, expected_text='Size @ "a, b, *other"')
|
||||
|
||||
s2 = s1[2]
|
||||
reveal_type(s2, expected_text='int @ "index((a, b, *other), 2)"')
|
||||
|
||||
s3 = s1[-3]
|
||||
reveal_type(s3, expected_text='int @ "index((a, b, *other), -3)"')
|
||||
|
||||
s4 = s1[-4]
|
||||
reveal_type(s4, expected_text='int @ "index((a, b, *other), -4)"')
|
||||
|
||||
s5 = s1[4]
|
||||
reveal_type(s5, expected_text='int @ "index((a, b, *other), 4)"')
|
||||
|
||||
|
||||
def concat1(t1: Tensor @ "a, b, c", t2: Tensor @ "a, 1, c"):
|
||||
s1 = cat((t1, t2), dim=1)
|
||||
reveal_type(s1, expected_text='Tensor @ "a, b + 1, c"')
|
||||
|
||||
s2 = cat((t1, t2, t2), dim=1)
|
||||
reveal_type(s2, expected_text='Tensor @ "a, b + 2, c"')
|
||||
|
||||
# This should generate an error.
|
||||
s3 = cat((t1, t2, t2))
|
||||
|
||||
# This should generate an error.
|
||||
s4 = cat((t1, t2, t2), dim=2)
|
||||
|
||||
# This should generate an error.
|
||||
s5 = cat((t1, t2, t2), dim=-1)
|
||||
|
||||
# This should generate an error.
|
||||
s6 = cat((t1, t2, t2), dim=5)
|
||||
|
||||
|
||||
def conv1(input: Tensor @ "n, c_in, y, x", weight: Tensor @ "c_out, c_in, ky, kx"):
|
||||
c1 = conv2d(input, weight)
|
||||
reveal_type(c1, expected_text='Tensor @ "n, c_out, y - ky + 1, x - kx + 1"')
|
||||
|
||||
|
||||
def conv2(x: Tensor @ "B, C, H, W", filters: Tensor @ "C, C, F1, F2"):
|
||||
return conv2d(x, filters, stride=2)
|
||||
|
||||
|
||||
def conv3():
|
||||
filters = randn(4, 4, 5, 5)
|
||||
reveal_type(filters, expected_text='Tensor @ "4, 4, 5, 5"')
|
||||
|
||||
c0 = conv2(randn(1, 4, 5, 5), filters)
|
||||
reveal_type(c0, expected_text='Tensor @ "1, 4, 1, 1"')
|
||||
|
||||
c1 = conv2(randn(1, 4, 32, 32), filters)
|
||||
reveal_type(c1, expected_text='Tensor @ "1, 4, 14, 14"')
|
||||
|
||||
c2 = conv2(randn(1, 4, 53, 32), filters)
|
||||
reveal_type(c2, expected_text='Tensor @ "1, 4, 25, 14"')
|
||||
|
||||
c3 = conv2(randn(1, 4, 28, 28), filters)
|
||||
reveal_type(c3, expected_text='Tensor @ "1, 4, 12, 12"')
|
||||
|
||||
|
||||
def sum1(t1: Tensor @ "a, b"):
|
||||
s1 = sum(t1)
|
||||
reveal_type(s1, expected_text='Tensor @ "1,"')
|
||||
|
||||
s2 = sum(t1, dim=0)
|
||||
reveal_type(s2, expected_text='Tensor @ "b,"')
|
||||
|
||||
s3 = sum(t1, dim=0, keepdim=True)
|
||||
reveal_type(s3, expected_text='Tensor @ "1, b"')
|
||||
|
||||
s4 = sum(t1, dim=1)
|
||||
reveal_type(s4, expected_text='Tensor @ "a,"')
|
||||
|
||||
s5 = sum(t1, dim=1, keepdim=True)
|
||||
reveal_type(s5, expected_text='Tensor @ "a, 1"')
|
||||
|
||||
s6 = sum(t1, dim=-1)
|
||||
reveal_type(s6, expected_text='Tensor @ "a,"')
|
||||
|
||||
s7 = sum(t1, dim=-2)
|
||||
reveal_type(s7, expected_text='Tensor @ "b,"')
|
||||
|
||||
# This should generate an error.
|
||||
s8 = sum(t1, dim=2)
|
||||
|
||||
# This should generate an error.
|
||||
s9 = sum(t1, dim=-3)
|
||||
119
packages/pyright-internal/src/tests/samples/refinementShape3.py
Normal file
119
packages/pyright-internal/src/tests/samples/refinementShape3.py
Normal file
|
|
@ -0,0 +1,119 @@
|
|||
# This sample tests various aspects of the "Shape" refinement type.
|
||||
|
||||
# pyright: reportMissingModuleSource=false
|
||||
|
||||
from tensorlib import Tensor, linspace, randn, index_select, permute, squeeze, unsqueeze
|
||||
|
||||
|
||||
def broadcast1(
|
||||
t1: Tensor @ "a, b",
|
||||
t2: Tensor @ "x, a, b",
|
||||
t3: Tensor @ "x, 1, 1",
|
||||
t4: Tensor @ "1, a, c if c == b",
|
||||
t5: Tensor @ "1, a, d",
|
||||
t6: Tensor @ "3, 1, 4",
|
||||
t7: Tensor @ "5, 1, 5, 1",
|
||||
):
|
||||
d1 = t1.sub(t2)
|
||||
reveal_type(d1, expected_text='Tensor @ "x, a, b"')
|
||||
|
||||
d2 = t1 - t2
|
||||
reveal_type(d2, expected_text='Tensor @ "x, a, b"')
|
||||
|
||||
d3 = t2 + t3
|
||||
reveal_type(d3, expected_text='Tensor @ "x, a, b"')
|
||||
|
||||
d3 = t2 - t3
|
||||
reveal_type(d3, expected_text='Tensor @ "x, a, b"')
|
||||
|
||||
d4 = t2 + t4
|
||||
reveal_type(d4, expected_text='Tensor @ "x, a, b"')
|
||||
|
||||
# This should generate an error.
|
||||
d5 = t2 - t5
|
||||
|
||||
d6 = t6 + t7
|
||||
reveal_type(d6, expected_text='Tensor @ "5, 3, 5, 4"')
|
||||
|
||||
|
||||
def linspace1(i1: int @ "a if a > 0"):
|
||||
t1 = linspace(0, 10, 4)
|
||||
reveal_type(t1, expected_text='Tensor @ "4,"')
|
||||
|
||||
t2 = linspace(0, 4, i1)
|
||||
reveal_type(t2, expected_text='Tensor @ "a,"')
|
||||
|
||||
t3_out = randn(2)
|
||||
reveal_type(t3_out, expected_text='Tensor @ "2,"')
|
||||
t3 = linspace(0, 4, 2, out=t3_out)
|
||||
reveal_type(t3, expected_text='Tensor @ "2,"')
|
||||
|
||||
# This should generate an error.
|
||||
t4 = linspace(0, 4, 3, out=t3_out)
|
||||
|
||||
# This should generate an error.
|
||||
t5 = linspace(0, 4, 0)
|
||||
|
||||
|
||||
def index_select1(i: Tensor @ "a, b", x: Tensor @ "3, "):
|
||||
t1 = index_select(i, 0, x)
|
||||
reveal_type(t1, expected_text='Tensor @ "3, b"')
|
||||
|
||||
t2 = index_select(i, 1, x)
|
||||
reveal_type(t2, expected_text='Tensor @ "a, 3"')
|
||||
|
||||
# This should generate an error.
|
||||
t3 = index_select(i, 2, x)
|
||||
|
||||
|
||||
def permute1(t1: Tensor @ "a, b, c"):
|
||||
p1 = permute(t1, (1, 2, 0))
|
||||
reveal_type(p1, expected_text='Tensor @ "b, c, a"')
|
||||
|
||||
p2 = permute(t1, (0, 2, 1))
|
||||
reveal_type(p2, expected_text='Tensor @ "a, c, b"')
|
||||
|
||||
# This should generate an error.
|
||||
p3 = permute(t1, (0, 2, 0))
|
||||
|
||||
# This should generate an error.
|
||||
p4 = permute(t1, (0, 2, 4))
|
||||
|
||||
# This should generate an error.
|
||||
p5 = permute(t1, (0, 1, 2, 3))
|
||||
|
||||
|
||||
def squeeze1(t1: Tensor @ "a, b, 1, 2"):
|
||||
s1 = squeeze(t1, 2)
|
||||
reveal_type(s1, expected_text='Tensor @ "a, b, 2"')
|
||||
|
||||
s2 = squeeze(t1, -2)
|
||||
reveal_type(s1, expected_text='Tensor @ "a, b, 2"')
|
||||
|
||||
s3 = squeeze(t1, -1)
|
||||
reveal_type(s3, expected_text='Tensor @ "a, b, 1, 2"')
|
||||
|
||||
# This should generate two errors.
|
||||
s4 = squeeze(t1, 5)
|
||||
|
||||
s5 = squeeze(t1, 1)
|
||||
reveal_type(s5, expected_text='Tensor @ "a, b, 1, 2"')
|
||||
|
||||
|
||||
def squeeze2(t1: Tensor @ "a, b, 1, 2"):
|
||||
s1 = squeeze(t1, (1, 2))
|
||||
reveal_type(s1, expected_text="Tensor")
|
||||
|
||||
|
||||
def unsqueeze1(t1: Tensor @ "a, b, c"):
|
||||
u1 = unsqueeze(t1, 1)
|
||||
reveal_type(u1, expected_text='Tensor @ "a, 1, b, c"')
|
||||
|
||||
u2 = unsqueeze(t1, 3)
|
||||
reveal_type(u2, expected_text='Tensor @ "a, b, c, 1"')
|
||||
|
||||
u3 = unsqueeze(t1, -1)
|
||||
reveal_type(u3, expected_text='Tensor @ "a, b, c, 1"')
|
||||
|
||||
u4 = unsqueeze(t1, -2)
|
||||
reveal_type(u4, expected_text='Tensor @ "a, b, 1, c"')
|
||||
|
|
@ -0,0 +1,52 @@
|
|||
# This sample tests the handling of refinement types that use the
|
||||
# IntTuple refinement domain.
|
||||
|
||||
# pyright: reportMissingModuleSource=false
|
||||
|
||||
from typing import cast
|
||||
from typing_extensions import Shape
|
||||
|
||||
|
||||
class Tensor:
|
||||
...
|
||||
|
||||
|
||||
def func1(a: int @ "x", b: int @ "y") -> Tensor @ Shape("x, y"):
|
||||
...
|
||||
|
||||
|
||||
v1 = func1(1, 2)
|
||||
reveal_type(v1, expected_text='Tensor @ Shape("1, 2")')
|
||||
|
||||
|
||||
def func2(a: Tensor @ Shape("x, y")) -> Tensor @ Shape("y, x"):
|
||||
...
|
||||
|
||||
|
||||
t2 = cast(Tensor @ Shape("1, 2"), Tensor())
|
||||
v2 = func2(t2)
|
||||
reveal_type(v2, expected_text='Tensor @ Shape("2, 1")')
|
||||
|
||||
|
||||
def func3(a: Tensor @ Shape("a, b")):
|
||||
x = func2(a)
|
||||
reveal_type(x, expected_text='Tensor @ Shape("b, a")')
|
||||
return x
|
||||
|
||||
|
||||
t3_1 = cast(Tensor @ Shape("1, 2"), Tensor())
|
||||
v3_1 = func3(t3_1)
|
||||
reveal_type(v3_1, expected_text='Tensor @ Shape("2, 1")')
|
||||
|
||||
t3_2 = cast(Tensor @ Shape("_, 1"), Tensor())
|
||||
v3_2 = func3(t3_2)
|
||||
reveal_type(v3_2, expected_text='Tensor @ Shape("1, _")')
|
||||
|
||||
|
||||
def func4(a: Tensor @ Shape("a, b, *other")) -> Tensor @ Shape("a, *other, b"):
|
||||
...
|
||||
|
||||
|
||||
t4_1 = cast(Tensor @ Shape("1, 2, 3, 4"), Tensor())
|
||||
v4_1 = func4(t4_1)
|
||||
reveal_type(v4_1, expected_text='Tensor @ Shape("1, 3, 4, 2")')
|
||||
105
packages/pyright-internal/src/tests/samples/tensorlib.pyi
Normal file
105
packages/pyright-internal/src/tests/samples/tensorlib.pyi
Normal file
|
|
@ -0,0 +1,105 @@
|
|||
# This sample represents a typical tensor library that might use the
|
||||
# "Shape" refinement. It is used by other test cases.
|
||||
|
||||
# This sample defines the basic classes found in a typical tensor
|
||||
# library. It's meant to test refinement types for tensor shape
|
||||
# validation.
|
||||
|
||||
from typing import Literal, overload
|
||||
from typing_extensions import IntTupleValue, Shape
|
||||
|
||||
class Size(tuple[int, ...]):
|
||||
@classmethod
|
||||
def __type_metadata__(cls, predicate: str) -> Shape: ...
|
||||
def __getitem__(self: Size @ "o", key: int @ "i") -> int @ "index(o, i)": ... # pyright: ignore[reportIncompatibleMethodOverride]
|
||||
def elem_count(self: Size @ "o") -> int @ "len(o)": ...
|
||||
|
||||
class Tensor:
|
||||
def __init__(self, value) -> None: ...
|
||||
@classmethod
|
||||
def __type_metadata__(cls, predicate: str) -> Shape: ...
|
||||
@property
|
||||
def shape(self: Tensor @ "o") -> Size @ "o": ...
|
||||
def transpose(self: Tensor @ "o", dim0: int @ "a", dim1: int @ "b") -> Tensor @ "swap(o, a, b)": ...
|
||||
def view(self: Tensor @ "o", *args: *(tuple[int, ...] @ IntTupleValue("t"))) -> Tensor @ "reshape(o, t)": ...
|
||||
def cos(self: Tensor @ "o") -> Tensor @ "o": ...
|
||||
def sin(self: Tensor @ "o") -> Tensor @ "o": ...
|
||||
def add(self: Tensor @ "o", input: Tensor @ "i") -> Tensor @ "broadcast(o, i)": ...
|
||||
def __add__(self: Tensor @ "o", input: Tensor @ "i") -> Tensor @ "broadcast(o, i)": ...
|
||||
def sub(self: Tensor @ "o", input: Tensor @ "i") -> Tensor @ "broadcast(o, i)": ...
|
||||
def __sub__(self: Tensor @ "o", input: Tensor @ "i") -> Tensor @ "broadcast(o, i)": ...
|
||||
def pow(self: Tensor @ "o", exp: float) -> Tensor @ "o": ...
|
||||
def __pow__(self: Tensor @ "o", exp: float) -> Tensor @ "o": ...
|
||||
|
||||
@overload
|
||||
def cat(
|
||||
tensors: tuple[Tensor @ "t1", Tensor @ "t2"],
|
||||
dim: int @ "d" = 0,
|
||||
*,
|
||||
out: Tensor @ "o if o == concat(t1, t2, d)" | None = ...,
|
||||
) -> Tensor @ "concat(t1, t2, d)": ...
|
||||
@overload
|
||||
def cat(
|
||||
tensors: tuple[Tensor @ "t1", Tensor @ "t2", Tensor @ "t3"],
|
||||
dim: int @ "d" = 0,
|
||||
*,
|
||||
out: Tensor @ "o if o == concat(concat(t1, t2, d), t3, d)" | None = ...,
|
||||
) -> Tensor @ "concat(concat(t1, t2, d), t3, d)": ...
|
||||
@overload
|
||||
def cat(
|
||||
tensors: tuple[Tensor @ "t1", Tensor @ "t2", Tensor @ "t3", Tensor @ "t4"],
|
||||
dim: int @ "d" = 0,
|
||||
*,
|
||||
out: Tensor @ "o if o == concat(concat(concat(t1, t2, d), t3, d), t4, d)" | None = ...,
|
||||
) -> Tensor @ "concat(concat(concat(t1, t2, d), t3, d), t4, d)": ...
|
||||
@overload
|
||||
def cat(tensors: tuple[Tensor, ...], dim: int = 0, *, out: Tensor | None = ...) -> Tensor: ...
|
||||
def conv2d(
|
||||
input: Tensor @ "n, c_in, h_in, w_in",
|
||||
weight: Tensor @ "c_out, gr, k0, k1 if gr == c_in // g",
|
||||
bias: Tensor @ "c_out," | None = None,
|
||||
stride: int @ "s0" @ "s1" | tuple[int @ "s0", int @ "s1"] = 1,
|
||||
padding: int @ "p0" @ "p1" | tuple[int @ "p0", int @ "p1"] = 0,
|
||||
dilation: int @ "d0" @ "d1" | tuple[int @ "d0", int @ "d1"] = 1,
|
||||
groups: int @ "g" = 1,
|
||||
) -> (
|
||||
Tensor @ "n, c_out, (h_in + 2 * p0 - d0 * (k0 - 1) - 1) // s0 + 1, (w_in + 2 * p1 - d1 * (k1 - 1) - 1) // s1 + 1"
|
||||
): ...
|
||||
def index_select(
|
||||
input: Tensor @ "o", dim: int @ "i if index(o, i) == _", index: Tensor @ "x,"
|
||||
) -> Tensor @ "splice(o, i, 1, (x,))": ...
|
||||
def linspace(
|
||||
start: float, end: float, steps: int @ "x if x > 0", *, out: Tensor @ "x, " | None = None
|
||||
) -> Tensor @ "x, ": ...
|
||||
def logspace(
|
||||
start: float, end: float, steps: int @ "x if x > 0", *, out: Tensor @ "x, " | None = None
|
||||
) -> Tensor @ "x, ": ...
|
||||
def permute(
|
||||
input: Tensor @ "o", dims: tuple[int, ...] @ IntTupleValue("d if permute(o, d) == _")
|
||||
) -> Tensor @ "permute(o, d)": ...
|
||||
def randn(*args: *(tuple[int, ...] @ Shape("o"))) -> Tensor @ "o": ...
|
||||
def sqrt(input: Tensor @ "o") -> Tensor @ "o": ...
|
||||
@overload
|
||||
def squeeze(input: Tensor @ "o", dim: int @ "i if index(o, i) == 1") -> Tensor @ "splice(o, i, 1, ())": ...
|
||||
@overload
|
||||
def squeeze(input: Tensor @ "o", dim: int @ "i if index(o, i) != 1") -> Tensor @ "o": ...
|
||||
@overload
|
||||
def squeeze(input: Tensor, dim: tuple[int, ...] | None = None) -> Tensor: ...
|
||||
@overload
|
||||
def sum(input: Tensor @ "o", *, dim: None = None, keepdim: bool = False) -> Tensor @ "(1, )": ...
|
||||
@overload
|
||||
def sum(
|
||||
input: Tensor @ "o",
|
||||
*,
|
||||
dim: int @ "d",
|
||||
keepdim: Literal[False] = False,
|
||||
) -> Tensor @ "splice(o, d, 1, ())": ...
|
||||
@overload
|
||||
def sum(input: Tensor @ "o", *, dim: int @ "d", keepdim: Literal[True]) -> Tensor @ "splice(o, d, 1, (1, ))": ...
|
||||
def take(input: Tensor @ "o", index: Tensor @ "x,") -> Tensor @ "x,": ...
|
||||
@overload
|
||||
def unsqueeze(input: Tensor @ "o", dim: int @ "d if d >= 0") -> Tensor @ "splice(o, d, 0, (1,))": ...
|
||||
@overload
|
||||
def unsqueeze(input: Tensor @ "o", dim: int @ "d if d < 0") -> Tensor @ "splice(o, len(o) + d + 1, 0, (1,))": ...
|
||||
@overload
|
||||
def unsqueeze(input: Tensor, dim: int) -> Tensor: ...
|
||||
26
packages/pyright-internal/src/tests/samples/typeMeta1.py
Normal file
26
packages/pyright-internal/src/tests/samples/typeMeta1.py
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
# This sample tests the handling of the TypeMetadata class.
|
||||
|
||||
# pyright: reportMissingModuleSource=false
|
||||
|
||||
from typing_extensions import TypeMetadata
|
||||
|
||||
|
||||
class NotTypeMeta: ...
|
||||
|
||||
|
||||
class TM1(TypeMetadata): ...
|
||||
|
||||
|
||||
# This should generate an error.
|
||||
v1: int @ []
|
||||
|
||||
# This should generate an error.
|
||||
v2: list[int] @ [1]
|
||||
|
||||
# This should generate an error.
|
||||
v3: int @ NotTypeMeta()
|
||||
|
||||
# This should generate an error.
|
||||
v4: int @ TM1
|
||||
|
||||
v5: int @ TM1()
|
||||
|
|
@ -11,7 +11,7 @@
|
|||
import * as assert from 'assert';
|
||||
|
||||
import { ConfigOptions } from '../common/configOptions';
|
||||
import { pythonVersion3_10, pythonVersion3_11, pythonVersion3_8 } from '../common/pythonVersion';
|
||||
import { pythonVersion3_10, pythonVersion3_11, pythonVersion3_12, pythonVersion3_8 } from '../common/pythonVersion';
|
||||
import { Uri } from '../common/uri/uri';
|
||||
import * as TestUtils from './testUtils';
|
||||
|
||||
|
|
@ -972,3 +972,110 @@ test('TypeForm7', () => {
|
|||
|
||||
TestUtils.validateResults(analysisResults, 1);
|
||||
});
|
||||
|
||||
test('TypeMeta1', () => {
|
||||
const configOptions = new ConfigOptions(Uri.empty());
|
||||
configOptions.diagnosticRuleSet.enableExperimentalFeatures = true;
|
||||
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['typeMeta1.py'], configOptions);
|
||||
|
||||
TestUtils.validateResults(analysisResults, 4);
|
||||
});
|
||||
|
||||
test('Refinement1', () => {
|
||||
const configOptions = new ConfigOptions(Uri.empty());
|
||||
configOptions.diagnosticRuleSet.enableExperimentalFeatures = true;
|
||||
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['refinement1.py'], configOptions);
|
||||
|
||||
TestUtils.validateResults(analysisResults, 8);
|
||||
});
|
||||
|
||||
test('Refinement2', () => {
|
||||
const configOptions = new ConfigOptions(Uri.empty());
|
||||
configOptions.defaultPythonVersion = pythonVersion3_12;
|
||||
configOptions.diagnosticRuleSet.enableExperimentalFeatures = true;
|
||||
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['refinement2.py'], configOptions);
|
||||
|
||||
TestUtils.validateResults(analysisResults, 14);
|
||||
});
|
||||
|
||||
test('Refinement3', () => {
|
||||
const configOptions = new ConfigOptions(Uri.empty());
|
||||
configOptions.diagnosticRuleSet.enableExperimentalFeatures = true;
|
||||
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['refinement3.py'], configOptions);
|
||||
|
||||
TestUtils.validateResults(analysisResults, 3);
|
||||
});
|
||||
|
||||
test('Refinement4', () => {
|
||||
const configOptions = new ConfigOptions(Uri.empty());
|
||||
configOptions.defaultPythonVersion = pythonVersion3_12;
|
||||
configOptions.diagnosticRuleSet.enableExperimentalFeatures = true;
|
||||
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['refinement4.py'], configOptions);
|
||||
|
||||
TestUtils.validateResults(analysisResults, 7);
|
||||
});
|
||||
|
||||
test('RefinementEnforce1', () => {
|
||||
const configOptions = new ConfigOptions(Uri.empty());
|
||||
configOptions.diagnosticRuleSet.enableExperimentalFeatures = true;
|
||||
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['refinementEnforce1.py'], configOptions);
|
||||
|
||||
TestUtils.validateResults(analysisResults, 4);
|
||||
});
|
||||
|
||||
test('RefinementImplicit1', () => {
|
||||
const configOptions = new ConfigOptions(Uri.empty());
|
||||
configOptions.defaultPythonVersion = pythonVersion3_12;
|
||||
configOptions.diagnosticRuleSet.enableExperimentalFeatures = true;
|
||||
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['refinementImplicit1.py'], configOptions);
|
||||
|
||||
TestUtils.validateResults(analysisResults, 0);
|
||||
});
|
||||
|
||||
test('RefinementLiteral1', () => {
|
||||
const configOptions = new ConfigOptions(Uri.empty());
|
||||
configOptions.diagnosticRuleSet.enableExperimentalFeatures = true;
|
||||
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['refinementLiteral1.py'], configOptions);
|
||||
|
||||
TestUtils.validateResults(analysisResults, 3);
|
||||
});
|
||||
|
||||
test('RefinementShape1', () => {
|
||||
const configOptions = new ConfigOptions(Uri.empty());
|
||||
configOptions.diagnosticRuleSet.enableExperimentalFeatures = true;
|
||||
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['refinementShape1.py'], configOptions);
|
||||
|
||||
TestUtils.validateResults(analysisResults, 2);
|
||||
});
|
||||
|
||||
test('RefinementShape2', () => {
|
||||
const configOptions = new ConfigOptions(Uri.empty());
|
||||
configOptions.diagnosticRuleSet.enableExperimentalFeatures = true;
|
||||
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['refinementShape2.py'], configOptions);
|
||||
|
||||
TestUtils.validateResults(analysisResults, 10);
|
||||
});
|
||||
|
||||
test('RefinementShape3', () => {
|
||||
const configOptions = new ConfigOptions(Uri.empty());
|
||||
configOptions.diagnosticRuleSet.enableExperimentalFeatures = true;
|
||||
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['refinementShape3.py'], configOptions);
|
||||
|
||||
TestUtils.validateResults(analysisResults, 9);
|
||||
});
|
||||
|
||||
test('RefinementCustom1', () => {
|
||||
const configOptions = new ConfigOptions(Uri.empty());
|
||||
configOptions.diagnosticRuleSet.enableExperimentalFeatures = true;
|
||||
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['refinementCustom1.py'], configOptions);
|
||||
|
||||
TestUtils.validateResults(analysisResults, 2);
|
||||
});
|
||||
|
||||
test('RefinementTuple1', () => {
|
||||
const configOptions = new ConfigOptions(Uri.empty());
|
||||
configOptions.diagnosticRuleSet.enableExperimentalFeatures = true;
|
||||
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['refinementTuple1.py'], configOptions);
|
||||
|
||||
TestUtils.validateResults(analysisResults, 0);
|
||||
});
|
||||
|
|
|
|||
|
|
@ -539,3 +539,73 @@ if sys.version_info >= (3, 14):
|
|||
from typing import TypeForm
|
||||
else:
|
||||
TypeForm: _SpecialForm
|
||||
|
||||
|
||||
# Refinement Types
|
||||
|
||||
class TypeMetadata:
|
||||
def __rmatmul__[T](self, typ: T) -> T:
|
||||
return typ
|
||||
|
||||
class Refinement(TypeMetadata, metaclass=abc.ABCMeta):
|
||||
def __init__(self, predicate: str | None, /) -> None: ...
|
||||
@abc.abstractmethod
|
||||
def __str__(self) -> str: ...
|
||||
|
||||
class IntRefinement(Refinement, metaclass=abc.ABCMeta):
|
||||
predicate: str | None
|
||||
@overload
|
||||
def __init__(
|
||||
self, predicate: None = None, /, *, value: int, enforce: bool = True
|
||||
) -> None: ...
|
||||
@overload
|
||||
def __init__(self, predicate: str, /, *, enforce: bool = True) -> None: ...
|
||||
|
||||
class StrRefinement(Refinement, metaclass=abc.ABCMeta):
|
||||
predicate: str | None
|
||||
@overload
|
||||
def __init__(
|
||||
self, predicate: None = None, /, *, value: str, enforce: bool = True
|
||||
) -> None: ...
|
||||
@overload
|
||||
def __init__(self, predicate: str, /, *, enforce: bool = False) -> None: ...
|
||||
|
||||
class BytesRefinement(Refinement, metaclass=abc.ABCMeta):
|
||||
predicate: str | None
|
||||
@overload
|
||||
def __init__(
|
||||
self, predicate: None = None, /, *, value: bytes, enforce: bool = True
|
||||
) -> None: ...
|
||||
@overload
|
||||
def __init__(self, predicate: str, /, *, enforce: bool = False) -> None: ...
|
||||
|
||||
class BoolRefinement(Refinement, metaclass=abc.ABCMeta):
|
||||
predicate: str | None
|
||||
@overload
|
||||
def __init__(
|
||||
self, predicate: None = None, /, *, value: bool, enforce: bool = True
|
||||
) -> None: ...
|
||||
@overload
|
||||
def __init__(self, predicate: str, /, *, enforce: bool = False) -> None: ...
|
||||
|
||||
class IntTupleRefinement(Refinement, metaclass=abc.ABCMeta):
|
||||
predicate: str | None
|
||||
def __init__(self, predicate: str, /, *, enforce: bool = False) -> None: ...
|
||||
|
||||
class IntValue(IntRefinement):
|
||||
def __str__(self) -> str: ...
|
||||
|
||||
class StrValue(StrRefinement):
|
||||
def __str__(self) -> str: ...
|
||||
|
||||
class BytesValue(BytesRefinement):
|
||||
def __str__(self) -> str: ...
|
||||
|
||||
class BoolValue(BoolRefinement):
|
||||
def __str__(self) -> str: ...
|
||||
|
||||
class IntTupleValue(IntTupleRefinement):
|
||||
def __str__(self) -> str: ...
|
||||
|
||||
class Shape(IntTupleRefinement):
|
||||
def __str__(self) -> str: ...
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue