Color system based on traits, and conversion to linear color in the graph (#1123)

* Migrate Nodes to use RasterMut + Samplable

* Add Pixel trait to include serialization

* Implement traits for Color and propagate new generics

* Always convert to linear color when loading images
This commit is contained in:
Dennis Kobert 2023-04-16 02:57:05 +02:00 committed by Keavon Chambers
parent e21c2fb67b
commit 37b892a516
16 changed files with 638 additions and 265 deletions

View file

@ -2,6 +2,9 @@ use core::{fmt::Debug, marker::PhantomData};
use crate::Node;
use bytemuck::{Pod, Zeroable};
use glam::DVec2;
use num::Num;
#[cfg(target_arch = "spirv")]
use spirv_std::num_traits::float::Float;
@ -12,6 +15,183 @@ pub mod brightness_contrast;
pub mod color;
pub use adjustments::*;
pub trait Channel: Copy + Debug + num::Num + num::NumCast {
fn to_linear<Out: Linear>(self) -> Out;
fn from_linear<In: Linear>(linear: In) -> Self;
fn to_f32(self) -> f32 {
num::cast(self).expect("Failed to convert channel to f32")
}
fn from_f32(value: f32) -> Self {
num::cast(value).expect("Failed to convert f32 to channel")
}
fn to_f64(self) -> f64 {
num::cast(self).expect("Failed to convert channel to f64")
}
fn from_f64(value: f64) -> Self {
num::cast(value).expect("Failed to convert f64 to channel")
}
fn to_channel<Out: Channel>(self) -> Out {
num::cast(self).expect("Failed to convert channel to channel")
}
}
pub trait Linear: num::NumCast + Num {}
impl Linear for f32 {}
impl Linear for f64 {}
impl<T: Linear + Debug + Copy> Channel for T {
#[inline(always)]
fn to_linear<Out: Linear>(self) -> Out {
num::cast(self).expect("Failed to convert channel to linear")
}
#[inline(always)]
fn from_linear<In: Linear>(linear: In) -> Self {
num::cast(linear).expect("Failed to convert linear to channel")
}
}
use num_derive::*;
#[derive(Copy, Clone, Debug, PartialEq, PartialOrd, Num, NumCast, NumOps, One, Zero, ToPrimitive, FromPrimitive)]
struct SRGBGammaFloat(f32);
impl Channel for SRGBGammaFloat {
#[inline(always)]
fn to_linear<Out: Linear>(self) -> Out {
let channel = num::cast::<_, f32>(self).expect("Failed to convert srgb to linear");
let out = if channel <= 0.04045 { channel / 12.92 } else { ((channel + 0.055) / 1.055).powf(2.4) };
num::cast(out).expect("Failed to convert srgb to linear")
}
#[inline(always)]
fn from_linear<In: Linear>(linear: In) -> Self {
let linear = num::cast::<_, f32>(linear).expect("Failed to convert linear to srgb");
let out = if linear <= 0.0031308 { linear * 12.92 } else { 1.055 * linear.powf(1. / 2.4) - 0.055 };
num::cast(out).expect("Failed to convert linear to srgb")
}
}
pub trait RGBPrimaries {
const RED: DVec2;
const GREEN: DVec2;
const BLUE: DVec2;
const WHITE: DVec2;
}
pub trait Rec709Primaries {}
impl<T: Rec709Primaries> RGBPrimaries for T {
const RED: DVec2 = DVec2::new(0.64, 0.33);
const GREEN: DVec2 = DVec2::new(0.3, 0.6);
const BLUE: DVec2 = DVec2::new(0.15, 0.06);
const WHITE: DVec2 = DVec2::new(0.3127, 0.329);
}
pub trait SRGB: Rec709Primaries {}
#[cfg(feature = "serde")]
pub trait Serde: serde::Serialize + for<'a> serde::Deserialize<'a> {}
#[cfg(not(feature = "serde"))]
pub trait Serde {}
#[cfg(feature = "serde")]
impl<T: serde::Serialize + for<'a> serde::Deserialize<'a>> Serde for T {}
#[cfg(not(feature = "serde"))]
impl<T> Serde for T {}
// TODO: Come up with a better name for this trait
pub trait Pixel: Clone + Pod + Zeroable {
fn to_bytes(&self) -> Vec<u8> {
bytemuck::bytes_of(self).to_vec()
}
// TODO: use u8 for Color
fn from_bytes(bytes: &[u8]) -> &Self {
bytemuck::try_from_bytes(bytes).expect("Failed to convert bytes to pixel")
}
}
impl<T: Serde + Clone + Pod + Zeroable> Pixel for T {}
pub trait RGB: Pixel {
type ColorChannel: Channel;
fn red(&self) -> Self::ColorChannel;
fn r(&self) -> Self::ColorChannel {
self.red()
}
fn green(&self) -> Self::ColorChannel;
fn g(&self) -> Self::ColorChannel {
self.green()
}
fn blue(&self) -> Self::ColorChannel;
fn b(&self) -> Self::ColorChannel {
self.blue()
}
}
pub trait AssociatedAlpha: RGB + Alpha {
fn to_unassociated<Out: UnassociatedAlpha>(&self) -> Out;
}
pub trait UnassociatedAlpha: RGB + Alpha {
fn to_associated<Out: AssociatedAlpha>(&self) -> Out;
}
pub trait Alpha {
type AlphaChannel: Channel;
fn alpha(&self) -> Self::AlphaChannel;
fn a(&self) -> Self::AlphaChannel {
self.alpha()
}
fn multiply_alpha(&self, alpha: Self::AlphaChannel) -> Self;
}
pub trait Depth {
type DepthChannel: Channel;
fn depth(&self) -> Self::DepthChannel;
fn d(&self) -> Self::DepthChannel {
self.depth()
}
}
pub trait ExtraChannels<const NUM: usize> {
type ChannelType: Channel;
fn extra_channels(&self) -> [Self::ChannelType; NUM];
}
pub trait Luminance {
type LuminanceChannel: Channel;
fn luminance(&self) -> Self::LuminanceChannel;
fn l(&self) -> Self::LuminanceChannel {
self.luminance()
}
}
// TODO: We might rename this to Raster at some point
pub trait Sample {
type Pixel: Pixel;
// TODO: Add an area parameter
fn sample(&self, pos: DVec2) -> Option<Self::Pixel>;
}
// TODO: We might rename this to Bitmap at some point
pub trait Raster {
type Pixel: Pixel;
fn width(&self) -> u32;
fn height(&self) -> u32;
fn get_pixel(&self, x: u32, y: u32) -> Option<Self::Pixel>;
}
pub trait RasterMut: Raster {
fn get_pixel_mut(&mut self, x: u32, y: u32) -> Option<&mut Self::Pixel>;
fn set_pixel(&mut self, x: u32, y: u32, pixel: Self::Pixel) {
*self.get_pixel_mut(x, y).unwrap() = pixel;
}
fn map_pixels<F: Fn(Self::Pixel) -> Self::Pixel>(&mut self, map_fn: F) {
for y in 0..self.height() {
for x in 0..self.width() {
let pixel = self.get_pixel(x, y).unwrap();
self.set_pixel(x, y, map_fn(pixel));
}
}
}
}
#[derive(Debug, Default)]
pub struct MapNode<MapFn> {
map_fn: MapFn,
@ -113,25 +293,28 @@ fn distance_node(input: (i32, i32)) -> f32 {
}
#[derive(Debug, Clone, Copy)]
pub struct ImageIndexIterNode;
pub struct ImageIndexIterNode<P> {
_p: core::marker::PhantomData<P>,
}
#[node_macro::node_fn(ImageIndexIterNode)]
fn image_index_iter_node(input: ImageSlice<'input>) -> core::ops::Range<u32> {
#[node_macro::node_fn(ImageIndexIterNode<_P>)]
fn image_index_iter_node<_P>(input: ImageSlice<'input, _P>) -> core::ops::Range<u32> {
0..(input.width * input.height)
}
#[derive(Debug)]
pub struct WindowNode<Radius: for<'i> Node<'i, (), Output = u32>, Image: for<'i> Node<'i, (), Output = ImageSlice<'i>>> {
pub struct WindowNode<P, Radius: for<'i> Node<'i, (), Output = u32>, Image: for<'i> Node<'i, (), Output = ImageSlice<'i, P>>> {
radius: Radius,
image: Image,
_pixel: core::marker::PhantomData<P>,
}
impl<'input, S0: 'input, S1: 'input> Node<'input, u32> for WindowNode<S0, S1>
impl<'input, P: 'input, S0: 'input, S1: 'input> Node<'input, u32> for WindowNode<P, S0, S1>
where
S0: for<'any_input> Node<'any_input, (), Output = u32>,
S1: for<'any_input> Node<'any_input, (), Output = ImageSlice<'any_input>>,
S1: for<'any_input> Node<'any_input, (), Output = ImageSlice<'any_input, P>>,
{
type Output = ImageWindowIterator<'input>;
type Output = ImageWindowIterator<'input, P>;
#[inline]
fn eval<'node: 'input>(&'node self, input: u32) -> Self::Output {
let radius = self.radius.eval(());
@ -142,13 +325,17 @@ where
}
}
}
impl<S0, S1> WindowNode<S0, S1>
impl<P, S0, S1> WindowNode<P, S0, S1>
where
S0: for<'any_input> Node<'any_input, (), Output = u32>,
S1: for<'any_input> Node<'any_input, (), Output = ImageSlice<'any_input>>,
S1: for<'any_input> Node<'any_input, (), Output = ImageSlice<'any_input, P>>,
{
pub const fn new(radius: S0, image: S1) -> Self {
Self { radius, image }
Self {
radius,
image,
_pixel: core::marker::PhantomData,
}
}
}
/*
@ -159,16 +346,16 @@ fn window_node(input: u32, radius: u32, image: ImageSlice<'input>) -> ImageWindo
}*/
#[derive(Debug, Clone, Copy)]
pub struct ImageWindowIterator<'a> {
image: ImageSlice<'a>,
pub struct ImageWindowIterator<'a, P> {
image: ImageSlice<'a, P>,
radius: u32,
index: u32,
x: u32,
y: u32,
}
impl<'a> ImageWindowIterator<'a> {
fn new(image: ImageSlice<'a>, radius: u32, index: u32) -> Self {
impl<'a, P> ImageWindowIterator<'a, P> {
fn new(image: ImageSlice<'a, P>, radius: u32, index: u32) -> Self {
let start_x = index as i32 % image.width as i32;
let start_y = index as i32 / image.width as i32;
let min_x = (start_x - radius as i32).max(0) as u32;
@ -185,8 +372,8 @@ impl<'a> ImageWindowIterator<'a> {
}
#[cfg(not(target_arch = "spirv"))]
impl<'a> Iterator for ImageWindowIterator<'a> {
type Item = (Color, (i32, i32));
impl<'a, P: Copy> Iterator for ImageWindowIterator<'a, P> {
type Item = (P, (i32, i32));
#[inline]
fn next(&mut self) -> Option<Self::Item> {
let start_x = self.index as i32 % self.image.width as i32;
@ -255,20 +442,24 @@ where
#[cfg(target_arch = "spirv")]
const NOTHING: () = ();
use dyn_any::{DynAny, StaticType};
#[derive(Clone, Debug, PartialEq, DynAny, Copy)]
use dyn_any::{DynAny, StaticType, StaticTypeSized};
#[derive(Clone, Debug, PartialEq, Copy)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
pub struct ImageSlice<'a> {
pub struct ImageSlice<'a, Pixel> {
pub width: u32,
pub height: u32,
#[cfg(not(target_arch = "spirv"))]
pub data: &'a [Color],
pub data: &'a [Pixel],
#[cfg(target_arch = "spirv")]
pub data: &'a (),
}
impl<P: StaticTypeSized> StaticType for ImageSlice<'_, P> {
type Static = ImageSlice<'static, P::Static>;
}
#[allow(clippy::derivable_impls)]
impl<'a> Default for ImageSlice<'a> {
impl<'a, P> Default for ImageSlice<'a, P> {
#[cfg(not(target_arch = "spirv"))]
fn default() -> Self {
Self {
@ -287,7 +478,25 @@ impl<'a> Default for ImageSlice<'a> {
}
}
impl ImageSlice<'_> {
impl<P: Copy + Debug + Pixel> Raster for ImageSlice<'_, P> {
type Pixel = P;
#[cfg(not(target_arch = "spirv"))]
fn get_pixel(&self, x: u32, y: u32) -> Option<P> {
self.data.get((x + y * self.width) as usize).copied()
}
#[cfg(target_arch = "spirv")]
fn get_pixel(&self, _x: u32, _y: u32) -> P {
Color::default()
}
fn width(&self) -> u32 {
self.width
}
fn height(&self) -> u32 {
self.height
}
}
impl<P> ImageSlice<'_, P> {
#[cfg(not(target_arch = "spirv"))]
pub const fn empty() -> Self {
Self { width: 0, height: 0, data: &[] }
@ -295,28 +504,30 @@ impl ImageSlice<'_> {
}
#[cfg(not(target_arch = "spirv"))]
impl<'a> IntoIterator for ImageSlice<'a> {
type Item = &'a Color;
type IntoIter = core::slice::Iter<'a, Color>;
impl<'a, P: 'a> IntoIterator for ImageSlice<'a, P> {
type Item = &'a P;
type IntoIter = core::slice::Iter<'a, P>;
fn into_iter(self) -> Self::IntoIter {
self.data.iter()
}
}
#[cfg(not(target_arch = "spirv"))]
impl<'a> IntoIterator for &'a ImageSlice<'a> {
type Item = &'a Color;
type IntoIter = core::slice::Iter<'a, Color>;
impl<'a, P: 'a> IntoIterator for &'a ImageSlice<'a, P> {
type Item = &'a P;
type IntoIter = core::slice::Iter<'a, P>;
fn into_iter(self) -> Self::IntoIter {
self.data.iter()
}
}
#[derive(Debug)]
pub struct ImageDimensionsNode;
pub struct ImageDimensionsNode<P> {
_p: PhantomData<P>,
}
#[node_macro::node_fn(ImageDimensionsNode)]
fn dimensions_node(input: ImageSlice<'input>) -> (u32, u32) {
#[node_macro::node_fn(ImageDimensionsNode<_P>)]
fn dimensions_node<_P>(input: ImageSlice<'input, _P>) -> (u32, u32) {
(input.width, input.height)
}
@ -335,27 +546,24 @@ mod image {
mod base64_serde {
//! Basic wrapper for [`serde`] for [`base64`] encoding
use crate::Color;
use super::super::Pixel;
use serde::{Deserialize, Deserializer, Serializer};
pub fn as_base64<S>(key: &[Color], serializer: S) -> Result<S::Ok, S::Error>
pub fn as_base64<S, P: Pixel>(key: &Vec<P>, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let u8_data = key
.iter()
.flat_map(|color| [color.r(), color.g(), color.b(), color.a()].into_iter().map(|channel| (channel * 255.).clamp(0., 255.) as u8))
.collect::<Vec<_>>();
let u8_data = key.iter().flat_map(|color| color.to_bytes()).collect::<Vec<_>>();
serializer.serialize_str(&base64::encode(u8_data))
}
pub fn from_base64<'a, D>(deserializer: D) -> Result<Vec<Color>, D::Error>
pub fn from_base64<'a, D, P: Pixel>(deserializer: D) -> Result<Vec<P>, D::Error>
where
D: Deserializer<'a>,
{
use serde::de::Error;
let color_from_chunk = |chunk: &[u8]| Color::from_rgba8(chunk[0], chunk[1], chunk[2], chunk[3]);
let color_from_chunk = |chunk: &[u8]| P::from_bytes(chunk.try_into().unwrap()).clone();
let colors_from_bytes = |bytes: Vec<u8>| bytes.chunks_exact(4).map(color_from_chunk).collect();
@ -366,16 +574,42 @@ mod image {
}
}
#[derive(Clone, Debug, PartialEq, DynAny, Default, specta::Type)]
#[derive(Clone, Debug, PartialEq, Default, specta::Type)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Image {
pub struct Image<P: Pixel> {
pub width: u32,
pub height: u32,
#[cfg_attr(feature = "serde", serde(serialize_with = "base64_serde::as_base64", deserialize_with = "base64_serde::from_base64"))]
pub data: Vec<Color>,
pub data: Vec<P>,
}
impl Hash for Image {
impl<P: StaticTypeSized + Pixel> StaticType for Image<P>
where
P::Static: Pixel,
{
type Static = Image<P::Static>;
}
impl<P: Copy + Pixel> Raster for Image<P> {
type Pixel = P;
fn get_pixel(&self, x: u32, y: u32) -> Option<P> {
self.data.get((x + y * self.width) as usize).copied()
}
fn width(&self) -> u32 {
self.width
}
fn height(&self) -> u32 {
self.height
}
}
impl<P: Copy + Pixel> RasterMut for Image<P> {
fn get_pixel_mut(&mut self, x: u32, y: u32) -> Option<&mut P> {
self.data.get_mut((x + y * self.width) as usize)
}
}
impl<P: Hash + Pixel> Hash for Image<P> {
fn hash<H: Hasher>(&self, state: &mut H) {
const HASH_SAMPLES: u64 = 1000;
let data_length = self.data.len() as u64;
@ -387,7 +621,7 @@ mod image {
}
}
impl Image {
impl<P: Pixel> Image<P> {
pub const fn empty() -> Self {
Self {
width: 0,
@ -396,7 +630,7 @@ mod image {
}
}
pub fn new(width: u32, height: u32, color: Color) -> Self {
pub fn new(width: u32, height: u32, color: P) -> Self {
Self {
width,
height,
@ -404,51 +638,66 @@ mod image {
}
}
pub fn as_slice(&self) -> ImageSlice {
pub fn as_slice(&self) -> ImageSlice<P> {
ImageSlice {
width: self.width,
height: self.height,
data: self.data.as_slice(),
}
}
}
pub fn get_mut(&mut self, x: u32, y: u32) -> Option<&mut Color> {
self.data.get_mut((y * self.width + x) as usize)
}
pub fn get(&self, x: u32, y: u32) -> Option<&Color> {
self.data.get((y * self.width + x) as usize)
}
impl Image<Color> {
/// Generate Image from some frontend image data (the canvas pixels as u8s in a flat array)
pub fn from_image_data(image_data: &[u8], width: u32, height: u32) -> Self {
let data = image_data.chunks_exact(4).map(|v| Color::from_rgba8(v[0], v[1], v[2], v[3])).collect();
let data = image_data.chunks_exact(4).map(|v| Color::from_rgba8_srgb(v[0], v[1], v[2], v[3])).collect();
Image { width, height, data }
}
}
use super::*;
impl<P: Alpha + RGB> Image<P>
where
P::ColorChannel: Linear,
{
/// Flattens each channel cast to a u8
pub fn into_flat_u8(self) -> (Vec<u8>, u32, u32) {
let Image { width, height, data } = self;
let result_bytes = data.into_iter().flat_map(|color| color.to_rgba8()).collect();
let to_gamma = |x| SRGBGammaFloat::from_linear(x);
let to_u8 = |x| (num::cast::<_, f32>(x).unwrap() * 255.) as u8;
let result_bytes = data
.into_iter()
.flat_map(|color| {
[
to_u8(to_gamma(color.r())),
to_u8(to_gamma(color.g())),
to_u8(to_gamma(color.b())),
(num::cast::<_, f32>(color.a()).unwrap() * 255.) as u8,
]
})
.collect();
(result_bytes, width, height)
}
}
impl IntoIterator for Image {
type Item = Color;
type IntoIter = alloc::vec::IntoIter<Color>;
impl<P: Pixel> IntoIterator for Image<P> {
type Item = P;
type IntoIter = alloc::vec::IntoIter<P>;
fn into_iter(self) -> Self::IntoIter {
self.data.into_iter()
}
}
#[derive(Debug, Clone, Copy, Default)]
pub struct ImageRefNode;
pub struct ImageRefNode<P> {
_p: PhantomData<P>,
}
#[node_macro::node_fn(ImageRefNode)]
fn image_ref_node(image: &'input Image) -> ImageSlice<'input> {
#[node_macro::node_fn(ImageRefNode<_P>)]
fn image_ref_node<_P: Pixel>(image: &'input Image<_P>) -> ImageSlice<'input, _P> {
image.as_slice()
}
@ -469,7 +718,7 @@ mod image {
}
#[node_macro::node_fn(MapImageSliceNode)]
fn map_node(input: (u32, u32), data: Vec<Color>) -> Image {
fn map_node<P: Pixel>(input: (u32, u32), data: Vec<P>) -> Image<P> {
Image {
width: input.0,
height: input.1,
@ -477,14 +726,56 @@ mod image {
}
}
#[derive(Clone, Debug, PartialEq, DynAny, Default, specta::Type)]
#[derive(Clone, Debug, PartialEq, Default, specta::Type)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ImageFrame {
pub image: Image,
pub struct ImageFrame<P: Pixel> {
pub image: Image<P>,
pub transform: DAffine2,
}
impl ImageFrame {
impl<P: Debug + Copy + Pixel> Sample for ImageFrame<P> {
type Pixel = P;
fn sample(&self, pos: DVec2) -> Option<Self::Pixel> {
let image_size = DVec2::new(self.image.width() as f64, self.image.height() as f64);
let pos = (DAffine2::from_scale(image_size) * self.transform.inverse()).transform_point2(pos);
if pos.x < 0. || pos.y < 0. || pos.x >= image_size.x || pos.y >= image_size.y {
return None;
}
self.image.get_pixel(pos.x as u32, pos.y as u32)
}
}
impl<P: Copy + Pixel> Raster for ImageFrame<P> {
type Pixel = P;
fn width(&self) -> u32 {
self.image.width()
}
fn height(&self) -> u32 {
self.image.height()
}
fn get_pixel(&self, x: u32, y: u32) -> Option<Self::Pixel> {
self.image.get_pixel(x, y)
}
}
impl<P: Copy + Pixel> RasterMut for ImageFrame<P> {
fn get_pixel_mut(&mut self, x: u32, y: u32) -> Option<&mut Self::Pixel> {
self.image.get_pixel_mut(x, y)
}
}
impl<P: StaticTypeSized + Pixel> StaticType for ImageFrame<P>
where
P::Static: Pixel,
{
type Static = ImageFrame<P::Static>;
}
impl<P: Copy + Pixel> ImageFrame<P> {
pub const fn empty() -> Self {
Self {
image: Image::empty(),
@ -492,12 +783,12 @@ mod image {
}
}
pub fn get_mut(&mut self, x: usize, y: usize) -> &mut Color {
pub fn get_mut(&mut self, x: usize, y: usize) -> &mut P {
&mut self.image.data[y * (self.image.width as usize) + x]
}
/// Clamps the provided point to ((0, 0), (ImageSize.x, ImageSize.y)) and returns the closest pixel
pub fn sample(&self, position: DVec2) -> Color {
pub fn sample(&self, position: DVec2) -> P {
let x = position.x.clamp(0., self.image.width as f64 - 1.) as usize;
let y = position.y.clamp(0., self.image.height as f64 - 1.) as usize;
@ -505,13 +796,13 @@ mod image {
}
}
impl AsRef<ImageFrame> for ImageFrame {
fn as_ref(&self) -> &ImageFrame {
impl<P: Pixel> AsRef<ImageFrame<P>> for ImageFrame<P> {
fn as_ref(&self) -> &ImageFrame<P> {
self
}
}
impl Hash for ImageFrame {
impl<P: Hash + Pixel> Hash for ImageFrame<P> {
fn hash<H: Hasher>(&self, state: &mut H) {
self.image.hash(state);
self.transform.to_cols_array().iter().for_each(|x| x.to_bits().hash(state))