fix: overload resolution bug

This commit is contained in:
Shunsuke Shibayama 2024-02-09 14:33:58 +09:00
parent 66352ddd3a
commit 21d5f22ca8
5 changed files with 30 additions and 1 deletions

View file

@ -1126,9 +1126,17 @@ impl Context {
_ => {} _ => {}
} }
if self.subtype_of(ty, &input_t) { if self.subtype_of(ty, &input_t) {
if let Ok(instance) = self.instantiate(ty.clone(), obj) {
let subst = self
.substitute_call(obj, &None, &instance, pos_args, kw_args, self)
.is_ok();
let eval = self.eval_t_params(instance, self.level, obj).is_ok();
if subst && eval {
return Ok(ty.clone()); return Ok(ty.clone());
} }
} }
}
}
let Type::Subr(subr_t) = input_t else { let Type::Subr(subr_t) = input_t else {
unreachable!() unreachable!()
}; };

View file

@ -0,0 +1,4 @@
_ImageBase: ClassType
.AxesImage!: ClassType
.AxesImage! <: _ImageBase

View file

@ -6,9 +6,12 @@ legend = pyimport "../legend"
.style = pyimport "../style" .style = pyimport "../style"
figure = pyimport "../figure" figure = pyimport "../figure"
axes = pyimport "../axes" axes = pyimport "../axes"
image = pyimport "../image"
.plot!: (*args: Obj, scaleX := Bool, scaleY := Bool) => [Obj; _] .plot!: (*args: Obj, scaleX := Bool, scaleY := Bool) => [Obj; _]
.imshow!: (X: Obj, cmap := Str, interpolation := Str) => image.AxesImage!
.show!: () => NoneType .show!: () => NoneType
.text!: (x: Float, y: Float, s: Str, fontdict := {Str: Obj}, fontsize := Nat) => text.Text
.title!: (title: Str) => text.Text .title!: (title: Str) => text.Text
.xlabel!: (label: Str) => text.Text .xlabel!: (label: Str) => text.Text
.ylabel!: (label: Str) => text.Text .ylabel!: (label: Str) => text.Text

View file

@ -1,3 +1,5 @@
np = pyimport "numpy"
.backends = pyimport "./backends" .backends = pyimport "./backends"
.cuda = pyimport "./cuda" .cuda = pyimport "./cuda"
.nn = pyimport "./nn" .nn = pyimport "./nn"
@ -33,9 +35,17 @@
.Tensor!(T, S)|<: IrregularEq|. .Tensor!(T, S)|<: IrregularEq|.
Output: {Tensor!(Bool, S)} Output: {Tensor!(Bool, S)}
__eq__: (self: .Tensor!(T, S), other: .Tensor!(T, S)) -> .Tensor!(Bool, S) __eq__: (self: .Tensor!(T, S), other: .Tensor!(T, S)) -> .Tensor!(Bool, S)
.Tensor!(T, S)|<: Indexable(Nat, .Tensor!(T, _))|.
__getitem__: (self: .Tensor!(T, S), index: Nat or [Nat; _]) -> .Tensor!(T, _)
.Tensor!(T, S).
data: .Tensor!(T, S)
.Tensor!(_, _). .Tensor!(_, _).
dtype: .DType dtype: .DType
shape: .Size shape: .Size
clone: |T, S: [Nat; _]|(self: .Tensor!(T, S)) -> .Tensor!(T, S)
cpu: |T, S: [Nat; _]|(self: .Tensor!(T, S)) -> .Tensor!(T, S)
detach: |T, S: [Nat; _]|(self: .Tensor!(T, S)) -> .Tensor!(T, S)
numpy: |T, S: [Nat; _]|(self: .Tensor!(T, S)) -> np.NDArray(T, S)
view: (|T, Old: [Nat; _], S: {A: [Nat; _] | A.prod() == Old.prod()}|( view: (|T, Old: [Nat; _], S: {A: [Nat; _] | A.prod() == Old.prod()}|(
self: .Tensor!(T, Old), self: .Tensor!(T, Old),
shape: {S}, shape: {S},
@ -64,6 +74,7 @@
and (|T|(self: .Tensor!(T, _), dim: Nat) -> .Tensor!(T, _)) and (|T|(self: .Tensor!(T, _), dim: Nat) -> .Tensor!(T, _))
.relu: |T, S: [Nat; _]|(x: .Tensor!(T, S)) -> .Tensor!(T, S) .relu: |T, S: [Nat; _]|(x: .Tensor!(T, S)) -> .Tensor!(T, S)
.softmax: |T, S: [Nat; _]|(x: .Tensor!(T, S), dim: Nat) -> .Tensor!(T, S)
.max: (|T|(input: .Tensor!(T, _), dim: Nat, keepdim := Bool) -> (.Tensor!(T, _)), .Tensor!(T, _)) \ .max: (|T|(input: .Tensor!(T, _), dim: Nat, keepdim := Bool) -> (.Tensor!(T, _)), .Tensor!(T, _)) \
and (|T|(input: .Tensor!(T, _)) -> .Tensor!(T, _)) and (|T|(input: .Tensor!(T, _)) -> .Tensor!(T, _))
.min: (|T|(input: .Tensor!(T, _), dim: Nat, keepdim := Bool) -> (.Tensor!(T, _)), .Tensor!(T, _)) \ .min: (|T|(input: .Tensor!(T, _), dim: Nat, keepdim := Bool) -> (.Tensor!(T, _)), .Tensor!(T, _)) \

View file

@ -1,7 +1,10 @@
dataset = pyimport "torch/utils/data/dataset" dataset = pyimport "torch/utils/data/dataset"
{Tensor!;} = pyimport "torch"
.VisionDataset: ClassType .VisionDataset: ClassType
.VisionDataset <: dataset.Dataset .VisionDataset <: dataset.Dataset
.VisionDataset|<: Indexable(Nat, (Tensor!(Float, _), Tensor!(Float, _)))|.
__getitem__: (index: Nat) -> (Tensor!(Float, _), Tensor!(Float, _))
.VisionDataset. .VisionDataset.
__call__: ( __call__: (
root: Str, root: Str,