Fix the Into nodes, which were broken but unused except in GPU nodes (#2480)

* Prototype document network level into node insertion

* Fix generic type resolution

* Cleanup

* Remove network nesting
This commit is contained in:
Dennis Kobert 2025-03-27 10:11:11 +01:00 committed by GitHub
parent 92132919d1
commit 41288d7642
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 100 additions and 47 deletions

View file

@ -281,21 +281,16 @@ impl DocumentNode {
self.inputs[index] = NodeInput::Node { node_id, output_index, lambda };
let input_source = &mut self.original_location.inputs_source;
for source in source {
input_source.insert(source, index + self.original_location.skip_inputs - skip);
input_source.insert(source, (index + self.original_location.skip_inputs).saturating_sub(skip));
}
}
fn resolve_proto_node(mut self) -> ProtoNode {
assert!(!self.inputs.is_empty() || self.manual_composition.is_some(), "Resolving document node {self:#?} with no inputs");
let DocumentNodeImplementation::ProtoNode(fqn) = self.implementation else {
let DocumentNodeImplementation::ProtoNode(identifier) = self.implementation else {
unreachable!("tried to resolve not flattened node on resolved node {self:?}");
};
// TODO replace with proper generics removal
let identifier = match fqn.name.clone().split_once('<') {
Some((path, _generics)) => ProtoNodeIdentifier { name: Cow::Owned(path.to_string()) },
_ => ProtoNodeIdentifier { name: fqn.name },
};
let (input, mut args) = if let Some(ty) = self.manual_composition {
(ProtoNodeInput::ManualComposition(ty), ConstructionArgs::Nodes(vec![]))
} else {

View file

@ -696,7 +696,7 @@ impl TypingContext {
// Direct comparison of two concrete types.
(Type::Concrete(type1), Type::Concrete(type2)) => type1 == type2,
// Check inner type for futures
(Type::Future(type1), Type::Future(type2)) => type1 == type2,
(Type::Future(type1), Type::Future(type2)) => valid_type(type1, type2),
// Direct comparison of two function types.
// Note: in the presence of subtyping, functions are considered on a "greater than or equal to" basis of its function type's generality.
// That means we compare their types with a contravariant relationship, which means that a more general type signature may be substituted for a more specific type signature.
@ -728,16 +728,17 @@ impl TypingContext {
let substitution_results = valid_output_types
.iter()
.map(|node_io| {
collect_generics(node_io)
let generics_lookup: Result<HashMap<_, _>, _> = collect_generics(node_io)
.iter()
.try_for_each(|generic| check_generic(node_io, &primary_input_or_call_argument, &inputs, generic).map(|_| ()))
.map(|_| {
if let Type::Generic(out) = &node_io.return_value {
((*node_io).clone(), check_generic(node_io, &primary_input_or_call_argument, &inputs, out).unwrap())
} else {
((*node_io).clone(), node_io.return_value.clone())
}
})
.map(|generic| check_generic(node_io, &primary_input_or_call_argument, &inputs, generic).map(|x| (generic.to_string(), x)))
.collect();
generics_lookup.map(|generics_lookup| {
let orig_node_io = (*node_io).clone();
let mut new_node_io = orig_node_io.clone();
replace_generics(&mut new_node_io, &generics_lookup);
(new_node_io, orig_node_io)
})
})
.collect::<Vec<_>>();
@ -783,8 +784,8 @@ impl TypingContext {
.join("\n");
Err(vec![GraphError::new(node, GraphErrorType::InvalidImplementations { inputs, error_inputs })])
}
[(org_nio, _)] => {
let node_io = org_nio.clone();
[(node_io, org_nio)] => {
let node_io = node_io.clone();
// Save the inferred type
self.inferred.insert(node_id, node_io.clone());
@ -794,15 +795,15 @@ impl TypingContext {
// If two types are available and one of them accepts () an input, always choose that one
[first, second] => {
if first.0.call_argument != second.0.call_argument {
for (org_nio, _) in [first, second] {
if org_nio.call_argument != concrete!(()) {
for (node_io, orig_nio) in [first, second] {
if node_io.call_argument != concrete!(()) {
continue;
}
// Save the inferred type
self.inferred.insert(node_id, org_nio.clone());
self.constructor.insert(node_id, impls[org_nio]);
return Ok(org_nio.clone());
self.inferred.insert(node_id, node_io.clone());
self.constructor.insert(node_id, impls[orig_nio]);
return Ok(node_io.clone());
}
}
let inputs = [&primary_input_or_call_argument].into_iter().chain(&inputs).map(|t| t.to_string()).collect::<Vec<_>>().join(", ");
@ -821,7 +822,7 @@ impl TypingContext {
/// Returns a list of all generic types used in the node
fn collect_generics(types: &NodeIOTypes) -> Vec<Cow<'static, str>> {
let inputs = [&types.call_argument].into_iter().chain(types.inputs.iter().flat_map(|x| x.fn_output()));
let inputs = [&types.call_argument].into_iter().chain(types.inputs.iter().map(|x| x.nested_type()));
let mut generics = inputs
.filter_map(|t| match t {
Type::Generic(out) => Some(out.clone()),
@ -839,6 +840,7 @@ fn collect_generics(types: &NodeIOTypes) -> Vec<Cow<'static, str>> {
fn check_generic(types: &NodeIOTypes, input: &Type, parameters: &[Type], generic: &str) -> Result<Type, String> {
let inputs = [(Some(&types.call_argument), Some(input))]
.into_iter()
.chain(types.inputs.iter().map(|x| x.fn_input()).zip(parameters.iter().map(|x| x.fn_input())))
.chain(types.inputs.iter().map(|x| x.fn_output()).zip(parameters.iter().map(|x| x.fn_output())));
let concrete_inputs = inputs.filter(|(ni, _)| matches!(ni, Some(Type::Generic(input)) if generic == input));
let mut outputs = concrete_inputs.flat_map(|(_, out)| out);
@ -851,6 +853,21 @@ fn check_generic(types: &NodeIOTypes, input: &Type, parameters: &[Type], generic
Ok(out_ty.clone())
}
/// Returns a list of all generic types used in the node
fn replace_generics(types: &mut NodeIOTypes, lookup: &HashMap<String, Type>) {
let replace = |ty: &Type| {
let Type::Generic(ident) = ty else {
return None;
};
lookup.get(ident.as_ref()).cloned()
};
types.call_argument.replace_nested(replace);
types.return_value.replace_nested(replace);
for input in &mut types.inputs {
input.replace_nested(replace);
}
}
#[cfg(test)]
mod test {
use super::*;