Loosen the Graphene type system to allow contravariant function arguments (#1740)

* Accept any input for nodes that expect () as input

* Add comments

* More comments

---------

Co-authored-by: Keavon Chambers <keavon@keavon.com>
This commit is contained in:
Dennis Kobert 2024-05-08 01:36:25 +02:00 committed by GitHub
parent 07fd2c2782
commit ce96ae66f2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 54 additions and 293 deletions

View file

@ -969,7 +969,7 @@ impl NodeNetwork {
return;
}
// replace value inputs with value nodes
// Replace value inputs with value nodes
for input in node.inputs.iter_mut() {
// Skip inputs that are already value nodes
if node.implementation == DocumentNodeImplementation::ProtoNode("graphene_core::value::ClonedNode".into()) {

View file

@ -735,14 +735,28 @@ impl TypingContext {
}) {
return Err(vec![GraphError::new(node, GraphErrorType::UnexpectedGenerics { index, parameters })]);
}
fn covariant(from: &Type, to: &Type) -> bool {
/// Checks if a proposed input to a particular (primary or secondary) input is valid for its type signature.
/// `from` indicates the value given to a input, `to` indicates the input's allowed type as specified by its type signature.
fn valid_subtype(from: &Type, to: &Type) -> bool {
match (from, to) {
(Type::Concrete(t1), Type::Concrete(t2)) => t1 == t2,
(Type::Fn(a1, b1), Type::Fn(a2, b2)) => covariant(a1, a2) && covariant(b1, b2),
// Direct comparison of two concrete types.
(Type::Concrete(type1), Type::Concrete(type2)) => type1 == type2,
// Loose comparison of function types, where loose means that functions are considered on a "greater than or equal to" basis of its function type's generality.
// That means we compare their types with a contravariant relationship, which means that a more general type signature may be substituted for a more specific type signature.
// For example, we allow `T -> V` to be substituted with `T' -> V` or `() -> V` where T' and () are more specific than T.
// This allows us to supply anything to a function that is satisfied with `()`.
// In other words, we are implementing these two relations, where the >= operator means that the left side is more general than the right side:
// - `T >= T' ⇒ (T' -> V) >= (T -> V)` (functions are contravariant in their input types)
// - `V >= V' ⇒ (T -> V) >= (T -> V')` (functions are covariant in their output types)
// While these two relations aren't a truth about the universe, they are a design decision that we are employing in our language design that is also common in other languages.
// For example, Rust implements these same relations as it describes here: <https://doc.rust-lang.org/nomicon/subtyping.html>
// More details explained here: <https://github.com/GraphiteEditor/Graphite/issues/1741>
(Type::Fn(in1, out1), Type::Fn(in2, out2)) => valid_subtype(out1, out2) && (valid_subtype(in1, in2) || **in1 == concrete!(())),
// If either the proposed input or the allowed input are generic, we allow the substitution (meaning this is a valid subtype).
// TODO: Add proper generic counting which is not based on the name
(Type::Generic(_), Type::Generic(_)) => true,
(Type::Generic(_), _) => true,
(_, Type::Generic(_)) => true,
(Type::Generic(_), _) | (_, Type::Generic(_)) => true,
// Reject unknown type relationships.
_ => false,
}
}
@ -750,7 +764,7 @@ impl TypingContext {
// List of all implementations that match the input and parameter types
let valid_output_types = impls
.keys()
.filter(|node_io| covariant(&input, &node_io.input) && parameters.iter().zip(node_io.parameters.iter()).all(|(p1, p2)| covariant(p1, p2)))
.filter(|node_io| valid_subtype(&input, &node_io.input) && parameters.iter().zip(node_io.parameters.iter()).all(|(p1, p2)| valid_subtype(p1, p2)))
.collect::<Vec<_>>();
// Attempt to substitute generic types with concrete types and save the list of results
@ -785,7 +799,7 @@ impl TypingContext {
.cloned()
.zip([&node_io.input].into_iter().chain(&node_io.parameters).cloned())
.enumerate()
.filter(|(_, (p1, p2))| !covariant(p1, p2))
.filter(|(_, (p1, p2))| !valid_subtype(p1, p2))
.map(|(index, ty)| (node.original_location.inputs(index).min_by_key(|s| s.node.len()).map(|s| s.index).unwrap_or(index), ty))
.collect::<Vec<_>>();
if current_errors.len() < best_errors {