Add type checking to the node graph (#1025)

* Implement type inference

Add type hints to node trait

Add type annotation infrastructure

Refactor type ascription infrastructure

Run cargo fix

Insert infer types stub

Remove types from node identifier

* Implement covariance

* Disable rejection of generic inputs + parameters

* Fix lints

* Extend type checking to cover Network inputs

* Implement generic specialization

* Relax covariance rules

* Fix type annotations for TypErasedComposeNode

* Fix type checking errors

* Keep connection information during node resolution
* Fix TypeDescriptor PartialEq implementation

* Apply review suggestions

* Add documentation to type inference

* Add Imaginate node to document node types

* Fix whitespace in macros

* Add types to imaginate node

* Fix type declaration for imaginate node + add console logging

* Use fully qualified type names as fallback during comparison

---------

Co-authored-by: Keavon Chambers <keavon@keavon.com>
This commit is contained in:
Dennis Kobert 2023-02-15 23:31:30 +01:00 committed by Keavon Chambers
parent a64c856ec4
commit 5dab7de68d
25 changed files with 1365 additions and 1008 deletions

View file

@ -9,7 +9,7 @@ license = "MIT OR Apache-2.0"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[features]
std = ["dyn-any", "dyn-any/std"]
std = ["dyn-any", "dyn-any/std", "alloc"]
default = ["async", "serde", "kurbo", "log", "std"]
log = ["dep:log"]
serde = ["dep:serde", "glam/serde"]
@ -17,9 +17,10 @@ gpu = ["spirv-std", "bytemuck", "glam/bytemuck", "dyn-any"]
async = ["async-trait", "alloc"]
nightly = []
alloc = ["dyn-any", "bezier-rs"]
type_id_logging = []
[dependencies]
dyn-any = {path = "../../libraries/dyn-any", features = ["derive"], optional = true, default-features = false }
dyn-any = {path = "../../libraries/dyn-any", features = ["derive", "glam"], optional = true, default-features = false }
spirv-std = { git = "https://github.com/EmbarkStudios/rust-gpu", features = ["glam"] , optional = true}
bytemuck = {version = "1.8", features = ["derive"], optional = true}
@ -34,4 +35,5 @@ kurbo = { git = "https://github.com/linebender/kurbo.git", features = [
glam = { version = "^0.22", default-features = false, features = ["scalar-math", "libm"]}
node-macro = {path = "../node-macro"}
specta.workspace = true
once_cell = { version = "1.17.0", default-features = false }
# forma = { version = "0.1.0", package = "forma-render" }

View file

@ -21,12 +21,52 @@ pub mod raster;
#[cfg(feature = "alloc")]
pub mod vector;
use core::any::TypeId;
// pub trait Node: for<'n> NodeIO<'n> {
pub trait Node<'i, Input: 'i>: 'i {
type Output: 'i;
fn eval<'s: 'i>(&'s self, input: Input) -> Self::Output;
}
#[cfg(feature = "alloc")]
mod types;
pub use types::*;
pub trait NodeIO<'i, Input: 'i>: 'i + Node<'i, Input>
where
Self::Output: 'i + StaticType,
Input: 'i + StaticType,
{
fn input_type(&self) -> TypeId {
TypeId::of::<Input::Static>()
}
fn input_type_name(&self) -> &'static str {
core::any::type_name::<Input>()
}
fn output_type(&self) -> core::any::TypeId {
TypeId::of::<<Self::Output as StaticType>::Static>()
}
fn output_type_name(&self) -> &'static str {
core::any::type_name::<Self::Output>()
}
#[cfg(feature = "alloc")]
fn to_node_io(&self, parameters: Vec<Type>) -> NodeIOTypes {
NodeIOTypes {
input: concrete!(<Input as StaticType>::Static),
output: concrete!(<Self::Output as StaticType>::Static),
parameters,
}
}
}
impl<'i, N: Node<'i, I>, I> NodeIO<'i, I> for N
where
N::Output: 'i + StaticType,
I: 'i + StaticType,
{
}
/*impl<'i, I: 'i, O: 'i> Node<'i, I> for &'i dyn for<'n> Node<'n, I, Output = O> {
type Output = O;
@ -42,6 +82,8 @@ impl<'i, 'n: 'i, I: 'i, O: 'i> Node<'i, I> for &'n dyn for<'a> Node<'a, I, Outpu
}
}
use core::pin::Pin;
use dyn_any::StaticType;
#[cfg(feature = "alloc")]
impl<'i, I: 'i, O: 'i> Node<'i, I> for Pin<Box<dyn for<'a> Node<'a, I, Output = O> + 'i>> {
type Output = O;

View file

@ -0,0 +1,108 @@
use core::any::TypeId;
#[cfg(not(feature = "std"))]
pub use alloc::borrow::Cow;
#[cfg(feature = "std")]
pub use std::borrow::Cow;
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct NodeIOTypes {
pub input: Type,
pub output: Type,
pub parameters: Vec<Type>,
}
impl NodeIOTypes {
pub fn new(input: Type, output: Type, parameters: Vec<Type>) -> Self {
Self { input, output, parameters }
}
}
#[macro_export]
macro_rules! concrete {
($type:ty) => {
Type::Concrete(TypeDescriptor {
id: Some(core::any::TypeId::of::<$type>()),
name: Cow::Borrowed(core::any::type_name::<$type>()),
})
};
}
#[macro_export]
macro_rules! generic {
($type:ty) => {{
Type::Generic(Cow::Borrowed(stringify!($type)))
}};
}
#[derive(Clone, Debug, PartialEq, Eq, Hash, specta::Type)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct NodeIdentifier {
pub name: Cow<'static, str>,
}
#[derive(Clone, Debug, Eq, specta::Type)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct TypeDescriptor {
#[cfg_attr(feature = "serde", serde(skip))]
#[specta(skip)]
pub id: Option<TypeId>,
pub name: Cow<'static, str>,
}
impl core::hash::Hash for TypeDescriptor {
fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
self.id.hash(state);
}
}
impl PartialEq for TypeDescriptor {
fn eq(&self, other: &Self) -> bool {
match (self.id, other.id) {
(Some(id), Some(other_id)) => id == other_id,
_ => {
warn!("TypeDescriptor::eq: comparing types without ids based on name");
self.name == other.name
}
}
}
}
#[derive(Clone, PartialEq, Eq, Hash, specta::Type)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum Type {
Generic(Cow<'static, str>),
Concrete(TypeDescriptor),
}
impl core::fmt::Debug for Type {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self {
Self::Generic(arg0) => f.write_fmt(format_args!("Generic({})", arg0)),
#[cfg(feature = "type_id_logging")]
Self::Concrete(arg0) => f.write_fmt(format_args!("Concrete({}, {:?}))", arg0.name, arg0.id)),
#[cfg(not(feature = "type_id_logging"))]
Self::Concrete(arg0) => f.write_fmt(format_args!("Concrete({})", arg0.name)),
}
}
}
impl std::fmt::Display for Type {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Type::Generic(name) => write!(f, "{}", name),
Type::Concrete(ty) => write!(f, "{}", ty.name),
}
}
}
impl From<&'static str> for NodeIdentifier {
fn from(s: &'static str) -> Self {
NodeIdentifier { name: Cow::Borrowed(s) }
}
}
impl NodeIdentifier {
pub const fn new(name: &'static str) -> Self {
NodeIdentifier { name: Cow::Borrowed(name) }
}
}