diff --git a/crates/ruff_notebook/src/cell.rs b/crates/ruff_notebook/src/cell.rs index b43087b52b..196bd9c3d6 100644 --- a/crates/ruff_notebook/src/cell.rs +++ b/crates/ruff_notebook/src/cell.rs @@ -31,6 +31,18 @@ impl Cell { } } + pub fn is_code_cell(&self) -> bool { + matches!(self, Cell::Code(_)) + } + + pub fn metadata(&self) -> &serde_json::Value { + match self { + Cell::Code(cell) => &cell.metadata, + Cell::Markdown(cell) => &cell.metadata, + Cell::Raw(cell) => &cell.metadata, + } + } + /// Update the [`SourceValue`] of the cell. pub(crate) fn set_source(&mut self, source: SourceValue) { match self { diff --git a/crates/ruff_notebook/src/notebook.rs b/crates/ruff_notebook/src/notebook.rs index ed9d986588..99408908a9 100644 --- a/crates/ruff_notebook/src/notebook.rs +++ b/crates/ruff_notebook/src/notebook.rs @@ -19,6 +19,7 @@ use ruff_text_size::TextSize; use crate::cell::CellOffsets; use crate::index::NotebookIndex; use crate::schema::{Cell, RawNotebook, SortAlphabetically, SourceValue}; +use crate::RawNotebookMetadata; /// Run round-trip source code generation on a given Jupyter notebook file path. pub fn round_trip(path: &Path) -> anyhow::Result { @@ -383,6 +384,10 @@ impl Notebook { &self.raw.cells } + pub fn metadata(&self) -> &RawNotebookMetadata { + &self.raw.metadata + } + /// Return `true` if the notebook is a Python notebook, `false` otherwise. pub fn is_python_notebook(&self) -> bool { self.raw diff --git a/crates/ruff_server/resources/test/fixtures/tensorflow_test_notebook.ipynb b/crates/ruff_server/resources/test/fixtures/tensorflow_test_notebook.ipynb new file mode 100644 index 0000000000..91f7122340 --- /dev/null +++ b/crates/ruff_server/resources/test/fixtures/tensorflow_test_notebook.ipynb @@ -0,0 +1,353 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "JfOIB1KdkbYW" + }, + "source": [ + "##### Copyright 2020 The TensorFlow Authors." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "Ojb0aXCmBgo7" + }, + "outputs": [], + "source": [ + "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "M9Y4JZ0ZGoE4" + }, + "source": [ + "# Super resolution with TensorFlow Lite" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "q3FoFSLBjIYK" + }, + "source": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + " View on TensorFlow.org\n", + " \n", + " Run in Google Colab\n", + " \n", + " View source on GitHub\n", + " \n", + " Download notebook\n", + " \n", + " See TF Hub model\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-uF3N4BbaMvA" + }, + "source": [ + "## Overview" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "isbXET4vVHfu" + }, + "source": [ + "The task of recovering a high resolution (HR) image from its low resolution counterpart is commonly referred to as Single Image Super Resolution (SISR). \n", + "\n", + "The model used here is ESRGAN\n", + "([ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks](https://arxiv.org/abs/1809.00219)). And we are going to use TensorFlow Lite to run inference on the pretrained model.\n", + "\n", + "The TFLite model is converted from this\n", + "[implementation](https://tfhub.dev/captain-pool/esrgan-tf2/1) hosted on TF Hub. Note that the model we converted upsamples a 50x50 low resolution image to a 200x200 high resolution image (scale factor=4). If you want a different input size or scale factor, you need to re-convert or re-train the original model." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2dQlTqiffuoU" + }, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qKyMtsGqu3zH" + }, + "source": [ + "Let's install required libraries first." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "7YTT1Rxsw3A9" + }, + "outputs": [], + "source": [ + "!pip install matplotlib tensorflow tensorflow-hub" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Clz5Kl97FswD" + }, + "source": [ + "Import dependencies." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "2xh1kvGEBjuP" + }, + "outputs": [], + "source": [ + "import tensorflow as tf\n", + "import tensorflow_hub as hub\n", + "import matplotlib.pyplot as plt\n", + "print(tf.__version__)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "i5miVfL4kxTA" + }, + "source": [ + "Download and convert the ESRGAN model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "X5PvXIXRwvHj" + }, + "outputs": [], + "source": [ + "model = hub.load(\"https://tfhub.dev/captain-pool/esrgan-tf2/1\")\n", + "concrete_func = model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]\n", + "\n", + "@tf.function(input_signature=[tf.TensorSpec(shape=[1, 50, 50, 3], dtype=tf.float32)])\n", + "def f(input):\n", + " return concrete_func(input);\n", + "\n", + "converter = tf.lite.TFLiteConverter.from_concrete_functions([f.get_concrete_function()], model)\n", + "converter.optimizations = [tf.lite.Optimize.DEFAULT]\n", + "tflite_model = converter.convert()\n", + "\n", + "# Save the TF Lite model.\n", + "with tf.io.gfile.GFile('ESRGAN.tflite', 'wb') as f:\n", + " f.write(tflite_model)\n", + "\n", + "esrgan_model_path = './ESRGAN.tflite'" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jH5-xPkyUEqt" + }, + "source": [ + "Download a test image (insect head)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "suWiStTWgK6e" + }, + "outputs": [], + "source": [ + "test_img_path = tf.keras.utils.get_file('lr.jpg', 'https://raw.githubusercontent.com/tensorflow/examples/master/lite/examples/super_resolution/android/app/src/main/assets/lr-1.jpg')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "rgQ4qRuFNpyW" + }, + "source": [ + "## Generate a super resolution image using TensorFlow Lite" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "J9FV4btf02-2" + }, + "outputs": [], + "source": [ + "lr = tf.io.read_file(test_img_path)\n", + "lr = tf.image.decode_jpeg(lr)\n", + "lr = tf.expand_dims(lr, axis=0)\n", + "lr = tf.cast(lr, tf.float32)\n", + "\n", + "# Load TFLite model and allocate tensors.\n", + "interpreter = tf.lite.Interpreter(model_path=esrgan_model_path)\n", + "interpreter.allocate_tensors()\n", + "\n", + "# Get input and output tensors.\n", + "input_details = interpreter.get_input_details()\n", + "output_details = interpreter.get_output_details()\n", + "\n", + "# Run the model\n", + "interpreter.set_tensor(input_details[0]['index'], lr)\n", + "interpreter.invoke()\n", + "\n", + "# Extract the output and postprocess it\n", + "output_data = interpreter.get_tensor(output_details[0]['index'])\n", + "sr = tf.squeeze(output_data, axis=0)\n", + "sr = tf.clip_by_value(sr, 0, 255)\n", + "sr = tf.round(sr)\n", + "sr = tf.cast(sr, tf.uint8)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "EwddQrDUNQGO" + }, + "source": [ + "## Visualize the result" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "aasKuozt1gNd" + }, + "outputs": [], + "source": [ + "lr = tf.cast(tf.squeeze(lr, axis=0), tf.uint8)\n", + "plt.figure(figsize = (1, 1))\n", + "plt.title('LR')\n", + "plt.imshow(lr.numpy());\n", + "\n", + "plt.figure(figsize=(10, 4))\n", + "plt.subplot(1, 2, 1) \n", + "plt.title(f'ESRGAN (x4)')\n", + "plt.imshow(sr.numpy());\n", + "\n", + "bicubic = tf.image.resize(lr, [200, 200], tf.image.ResizeMethod.BICUBIC)\n", + "bicubic = tf.cast(bicubic, tf.uint8)\n", + "plt.subplot(1, 2, 2) \n", + "plt.title('Bicubic')\n", + "plt.imshow(bicubic.numpy());" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0kb-fkogObjq" + }, + "source": [ + "## Performance Benchmarks" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "tNzdgpqTy5P3" + }, + "source": [ + "Performance benchmark numbers are generated with the tool\n", + "[described here](https://www.tensorflow.org/lite/performance/benchmarks).\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + "
Model NameModel Size Device CPUGPU
\n", + " super resolution (ESRGAN)\n", + " \n", + " 4.8 Mb\n", + " Pixel 3586.8ms*128.6ms
Pixel 4385.1ms*130.3ms
\n", + "\n", + "**4 threads used*" + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "name": "super_resolution.ipynb", + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.4" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/crates/ruff_server/src/edit.rs b/crates/ruff_server/src/edit.rs index e711b36953..b88290dfec 100644 --- a/crates/ruff_server/src/edit.rs +++ b/crates/ruff_server/src/edit.rs @@ -8,7 +8,7 @@ mod text_document; use std::collections::HashMap; use lsp_types::{PositionEncodingKind, Url}; -pub(crate) use notebook::NotebookDocument; +pub use notebook::NotebookDocument; pub(crate) use range::{NotebookRange, RangeExt, ToRangeExt}; pub(crate) use replacement::Replacement; pub(crate) use text_document::DocumentVersion; @@ -34,7 +34,7 @@ pub enum PositionEncoding { /// A unique document ID, derived from a URL passed as part of an LSP request. /// This document ID can point to either be a standalone Python file, a full notebook, or a cell within a notebook. #[derive(Clone, Debug)] -pub(crate) enum DocumentKey { +pub enum DocumentKey { Notebook(Url), NotebookCell(Url), Text(Url), diff --git a/crates/ruff_server/src/edit/notebook.rs b/crates/ruff_server/src/edit/notebook.rs index d489d51265..ea6b3fe338 100644 --- a/crates/ruff_server/src/edit/notebook.rs +++ b/crates/ruff_server/src/edit/notebook.rs @@ -13,7 +13,7 @@ pub(super) type CellId = usize; /// The state of a notebook document in the server. Contains an array of cells whose /// contents are internally represented by [`TextDocument`]s. #[derive(Clone, Debug)] -pub(crate) struct NotebookDocument { +pub struct NotebookDocument { cells: Vec, metadata: ruff_notebook::RawNotebookMetadata, version: DocumentVersion, @@ -30,7 +30,7 @@ struct NotebookCell { } impl NotebookDocument { - pub(crate) fn new( + pub fn new( version: DocumentVersion, cells: Vec, metadata: serde_json::Map, @@ -59,7 +59,7 @@ impl NotebookDocument { /// Generates a pseudo-representation of a notebook that lacks per-cell metadata and contextual information /// but should still work with Ruff's linter. - pub(crate) fn make_ruff_notebook(&self) -> ruff_notebook::Notebook { + pub fn make_ruff_notebook(&self) -> ruff_notebook::Notebook { let cells = self .cells .iter() diff --git a/crates/ruff_server/src/lib.rs b/crates/ruff_server/src/lib.rs index e94a8df72d..595fe7c270 100644 --- a/crates/ruff_server/src/lib.rs +++ b/crates/ruff_server/src/lib.rs @@ -1,8 +1,9 @@ //! ## The Ruff Language Server -pub use edit::{PositionEncoding, TextDocument}; +pub use edit::{DocumentKey, NotebookDocument, PositionEncoding, TextDocument}; use lsp_types::CodeActionKind; pub use server::Server; +pub use session::{ClientSettings, DocumentQuery, DocumentSnapshot, Session}; #[macro_use] mod message; diff --git a/crates/ruff_server/src/session.rs b/crates/ruff_server/src/session.rs index a6072fb6c1..fb01d4fac5 100644 --- a/crates/ruff_server/src/session.rs +++ b/crates/ruff_server/src/session.rs @@ -8,15 +8,16 @@ use crate::edit::{DocumentKey, DocumentVersion, NotebookDocument}; use crate::{PositionEncoding, TextDocument}; pub(crate) use self::capabilities::ResolvedClientCapabilities; -pub(crate) use self::index::DocumentQuery; -pub(crate) use self::settings::{AllSettings, ClientSettings}; +pub use self::index::DocumentQuery; +pub(crate) use self::settings::AllSettings; +pub use self::settings::ClientSettings; mod capabilities; mod index; mod settings; /// The global state for the LSP -pub(crate) struct Session { +pub struct Session { /// Used to retrieve information about open documents and settings. index: index::Index, /// The global position encoding, negotiated during LSP initialization. @@ -29,7 +30,7 @@ pub(crate) struct Session { /// An immutable snapshot of `Session` that references /// a specific document. -pub(crate) struct DocumentSnapshot { +pub struct DocumentSnapshot { resolved_client_capabilities: Arc, client_settings: settings::ResolvedClientSettings, document_ref: index::DocumentQuery, @@ -37,7 +38,7 @@ pub(crate) struct DocumentSnapshot { } impl Session { - pub(crate) fn new( + pub fn new( client_capabilities: &ClientCapabilities, position_encoding: PositionEncoding, global_settings: ClientSettings, @@ -53,12 +54,12 @@ impl Session { }) } - pub(crate) fn key_from_url(&self, url: Url) -> DocumentKey { + pub fn key_from_url(&self, url: Url) -> DocumentKey { self.index.key_from_url(url) } /// Creates a document snapshot with the URL referencing the document to snapshot. - pub(crate) fn take_snapshot(&self, url: Url) -> Option { + pub fn take_snapshot(&self, url: Url) -> Option { let key = self.key_from_url(url); Some(DocumentSnapshot { resolved_client_capabilities: self.resolved_client_capabilities.clone(), @@ -98,7 +99,7 @@ impl Session { /// /// The document key must point to a notebook document or cell, or this will /// throw an error. - pub(crate) fn update_notebook_document( + pub fn update_notebook_document( &mut self, key: &DocumentKey, cells: Option, @@ -112,7 +113,7 @@ impl Session { /// Registers a notebook document at the provided `url`. /// If a document is already open here, it will be overwritten. - pub(crate) fn open_notebook_document(&mut self, url: Url, document: NotebookDocument) { + pub fn open_notebook_document(&mut self, url: Url, document: NotebookDocument) { self.index.open_notebook_document(url, document); } @@ -175,7 +176,7 @@ impl DocumentSnapshot { &self.client_settings } - pub(crate) fn query(&self) -> &index::DocumentQuery { + pub fn query(&self) -> &index::DocumentQuery { &self.document_ref } diff --git a/crates/ruff_server/src/session/index.rs b/crates/ruff_server/src/session/index.rs index 341e92cc73..4b5fdadbea 100644 --- a/crates/ruff_server/src/session/index.rs +++ b/crates/ruff_server/src/session/index.rs @@ -49,7 +49,7 @@ enum DocumentController { /// This query can 'select' a text document, full notebook, or a specific notebook cell. /// It also includes document settings. #[derive(Clone)] -pub(crate) enum DocumentQuery { +pub enum DocumentQuery { Text { file_url: Url, document: Arc, @@ -519,7 +519,7 @@ impl DocumentQuery { } /// Attempts to access the underlying notebook document that this query is selecting. - pub(crate) fn as_notebook(&self) -> Option<&NotebookDocument> { + pub fn as_notebook(&self) -> Option<&NotebookDocument> { match self { Self::Notebook { notebook, .. } => Some(notebook), Self::Text { .. } => None, diff --git a/crates/ruff_server/src/session/index/ruff_settings.rs b/crates/ruff_server/src/session/index/ruff_settings.rs index abb02a463e..39b35fa97b 100644 --- a/crates/ruff_server/src/session/index/ruff_settings.rs +++ b/crates/ruff_server/src/session/index/ruff_settings.rs @@ -18,7 +18,7 @@ use walkdir::WalkDir; use crate::session::settings::{ConfigurationPreference, ResolvedEditorSettings}; -pub(crate) struct RuffSettings { +pub struct RuffSettings { /// The path to this configuration file, used for debugging. /// The default fallback configuration does not have a file path. path: Option, diff --git a/crates/ruff_server/src/session/settings.rs b/crates/ruff_server/src/session/settings.rs index 0d3740d369..80ac4995a1 100644 --- a/crates/ruff_server/src/session/settings.rs +++ b/crates/ruff_server/src/session/settings.rs @@ -60,7 +60,7 @@ pub(crate) enum ConfigurationPreference { #[derive(Debug, Deserialize, Default)] #[cfg_attr(test, derive(PartialEq, Eq))] #[serde(rename_all = "camelCase")] -pub(crate) struct ClientSettings { +pub struct ClientSettings { configuration: Option, fix_all: Option, organize_imports: Option, diff --git a/crates/ruff_server/tests/notebook.rs b/crates/ruff_server/tests/notebook.rs new file mode 100644 index 0000000000..d639655fd0 --- /dev/null +++ b/crates/ruff_server/tests/notebook.rs @@ -0,0 +1,373 @@ +use std::{ + path::{Path, PathBuf}, + str::FromStr, +}; + +use lsp_types::{ + ClientCapabilities, LSPObject, NotebookDocumentCellChange, NotebookDocumentChangeTextContent, + Position, Range, TextDocumentContentChangeEvent, VersionedTextDocumentIdentifier, +}; +use ruff_notebook::SourceValue; +use ruff_server::ClientSettings; + +const SUPER_RESOLUTION_OVERVIEW_PATH: &str = + "./resources/test/fixtures/tensorflow_test_notebook.ipynb"; + +struct NotebookChange { + version: i32, + metadata: Option, + updated_cells: lsp_types::NotebookDocumentCellChange, +} + +#[test] +fn super_resolution_overview() { + let file_path = + std::path::absolute(PathBuf::from_str(SUPER_RESOLUTION_OVERVIEW_PATH).unwrap()).unwrap(); + let file_url = lsp_types::Url::from_file_path(&file_path).unwrap(); + let notebook = create_notebook(&file_path).unwrap(); + + insta::assert_snapshot!("initial_notebook", notebook_source(¬ebook)); + + let mut session = ruff_server::Session::new( + &ClientCapabilities::default(), + ruff_server::PositionEncoding::UTF16, + ClientSettings::default(), + vec![( + lsp_types::Url::from_file_path(file_path.parent().unwrap()).unwrap(), + ClientSettings::default(), + )], + ) + .unwrap(); + + session.open_notebook_document(file_url.clone(), notebook); + + let changes = [NotebookChange { + version: 0, + metadata: None, + updated_cells: NotebookDocumentCellChange { + structure: None, + data: None, + text_content: Some(vec![NotebookDocumentChangeTextContent { + document: VersionedTextDocumentIdentifier { + uri: make_cell_uri(&file_path, 5), + version: 2, + }, + changes: vec![ + TextDocumentContentChangeEvent { + range: Some(Range { + start: Position { + line: 18, + character: 61, + }, + end: Position { + line: 18, + character: 62, + }, + }), + range_length: Some(1), + text: "\"".to_string(), + }, + TextDocumentContentChangeEvent { + range: Some(Range { + start: Position { + line: 18, + character: 55, + }, + end: Position { + line: 18, + character: 56, + }, + }), + range_length: Some(1), + text: "\"".to_string(), + }, + TextDocumentContentChangeEvent { + range: Some(Range { + start: Position { + line: 14, + character: 46, + }, + end: Position { + line: 14, + character: 47, + }, + }), + range_length: Some(1), + text: "\"".to_string(), + }, + TextDocumentContentChangeEvent { + range: Some(Range { + start: Position { + line: 14, + character: 40, + }, + end: Position { + line: 14, + character: 41, + }, + }), + range_length: Some(1), + text: "\"".to_string(), + }, + ], + }]), + }, + }, + NotebookChange { + version: 1, + metadata: None, + updated_cells: NotebookDocumentCellChange { + structure: None, + data: None, + text_content: Some(vec![NotebookDocumentChangeTextContent { + document: VersionedTextDocumentIdentifier { + uri: make_cell_uri(&file_path, 4), + version: 2 + }, + changes: vec![TextDocumentContentChangeEvent { + range: Some(Range { + start: Position { + line: 0, + character: 0 + }, + end: Position { + line: 0, + character: 181 + } }), + range_length: Some(181), + text: "test_img_path = tf.keras.utils.get_file(\n \"lr.jpg\",\n \"https://raw.githubusercontent.com/tensorflow/examples/master/lite/examples/super_resolution/android/app/src/main/assets/lr-1.jpg\",\n)".to_string() + } + ] + } + ] + ) + } + }, + NotebookChange { + version: 2, + metadata: None, + updated_cells: NotebookDocumentCellChange { + structure: None, + data: None, + text_content: Some(vec![NotebookDocumentChangeTextContent { + document: VersionedTextDocumentIdentifier { + uri: make_cell_uri(&file_path, 2), + version: 2, + }, + changes: vec![TextDocumentContentChangeEvent { + range: Some(Range { + start: Position { + line: 3, + character: 0, + }, + end: Position { + line: 3, + character: 21, + }, + }), + range_length: Some(21), + text: "\nprint(tf.__version__)".to_string(), + }], + }]), + } + }, + NotebookChange { + version: 3, + metadata: None, + updated_cells: NotebookDocumentCellChange { + structure: None, + data: None, + text_content: Some(vec![NotebookDocumentChangeTextContent { + document: VersionedTextDocumentIdentifier { + uri: make_cell_uri(&file_path, 1), + version: 2, + }, + changes: vec![TextDocumentContentChangeEvent { + range: Some(Range { + start: Position { + line: 0, + character: 0, + }, + end: Position { + line: 0, + character: 49, + }, + }), + range_length: Some(49), + text: "!pip install matplotlib tensorflow tensorflow-hub".to_string(), + }], + }]), + }, + }, + NotebookChange { + version: 4, + metadata: None, + updated_cells: NotebookDocumentCellChange { + structure: None, + data: None, + text_content: Some(vec![NotebookDocumentChangeTextContent { + document: VersionedTextDocumentIdentifier { + uri: make_cell_uri(&file_path, 3), + version: 2, + }, + changes: vec![TextDocumentContentChangeEvent { + range: Some(Range { + start: Position { + line: 3, + character: 0, + }, + end: Position { + line: 15, + character: 37, + }, + }), + range_length: Some(457), + text: "\n@tf.function(input_signature=[tf.TensorSpec(shape=[1, 50, 50, 3], dtype=tf.float32)])\ndef f(input):\n return concrete_func(input)\n\n\nconverter = tf.lite.TFLiteConverter.from_concrete_functions(\n [f.get_concrete_function()], model\n)\nconverter.optimizations = [tf.lite.Optimize.DEFAULT]\ntflite_model = converter.convert()\n\n# Save the TF Lite model.\nwith tf.io.gfile.GFile(\"ESRGAN.tflite\", \"wb\") as f:\n f.write(tflite_model)\n\nesrgan_model_path = \"./ESRGAN.tflite\"".to_string(), + }], + }]), + }, + }, + NotebookChange { + version: 5, + metadata: None, + updated_cells: NotebookDocumentCellChange { + structure: None, + data: None, + text_content: Some(vec![NotebookDocumentChangeTextContent { + document: VersionedTextDocumentIdentifier { + uri: make_cell_uri(&file_path, 0), + version: 2, + }, + changes: vec![TextDocumentContentChangeEvent { + range: Some(Range { + start: Position { + line: 0, + character: 0, + }, + end: Position { + line: 2, + character: 0, + }, + }), + range_length: Some(139), + text: "# @title Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n".to_string(), + }], + }]), + }, + }, + NotebookChange { + version: 6, + metadata: None, + updated_cells: NotebookDocumentCellChange { + structure: None, + data: None, + text_content: Some(vec![NotebookDocumentChangeTextContent { + document: VersionedTextDocumentIdentifier { + uri: make_cell_uri(&file_path, 6), + version: 2, + }, + changes: vec![TextDocumentContentChangeEvent { + range: Some(Range { + start: Position { + line: 1, + character: 0, + }, + end: Position { + line: 14, + character: 28, + }, + }), + range_length: Some(361), + text: "plt.figure(figsize=(1, 1))\nplt.title(\"LR\")\nplt.imshow(lr.numpy())\nplt.figure(figsize=(10, 4))\nplt.subplot(1, 2, 1)\nplt.title(f\"ESRGAN (x4)\")\nplt.imshow(sr.numpy())\nbicubic = tf.image.resize(lr, [200, 200], tf.image.ResizeMethod.BICUBIC)\nbicubic = tf.cast(bicubic, tf.uint8)\nplt.subplot(1, 2, 2)\nplt.title(\"Bicubic\")\nplt.imshow(bicubic.numpy());".to_string(), + }], + }]), + }, + } + ]; + + let key = session.key_from_url(file_url.clone()); + + for NotebookChange { + version, + metadata, + updated_cells, + } in changes + { + session + .update_notebook_document(&key, Some(updated_cells), metadata, version) + .unwrap(); + } + + let snapshot = session.take_snapshot(file_url.clone()).unwrap(); + + insta::assert_snapshot!( + "changed_notebook", + notebook_source(snapshot.query().as_notebook().unwrap()) + ); +} + +fn notebook_source(notebook: &ruff_server::NotebookDocument) -> String { + notebook.make_ruff_notebook().source_code().to_string() +} + +// produces an opaque URL based on a document path and a cell index +fn make_cell_uri(path: &Path, index: usize) -> lsp_types::Url { + lsp_types::Url::parse(&format!( + "notebook-cell:///Users/test/notebooks/{}.ipynb?cell={index}", + path.file_name().unwrap().to_string_lossy() + )) + .unwrap() +} + +fn create_notebook(file_path: &Path) -> anyhow::Result { + let ruff_notebook = ruff_notebook::Notebook::from_path(file_path)?; + + let mut cells = vec![]; + let mut cell_documents = vec![]; + for (i, cell) in ruff_notebook + .cells() + .iter() + .filter(|cell| cell.is_code_cell()) + .enumerate() + { + let uri = make_cell_uri(file_path, i); + let (lsp_cell, cell_document) = cell_to_lsp_cell(cell, uri)?; + cells.push(lsp_cell); + cell_documents.push(cell_document); + } + + let serde_json::Value::Object(metadata) = serde_json::to_value(ruff_notebook.metadata())? + else { + anyhow::bail!("Notebook metadata was not an object"); + }; + + ruff_server::NotebookDocument::new(0, cells, metadata, cell_documents) +} + +fn cell_to_lsp_cell( + cell: &ruff_notebook::Cell, + cell_uri: lsp_types::Url, +) -> anyhow::Result<(lsp_types::NotebookCell, lsp_types::TextDocumentItem)> { + let contents = match cell.source() { + SourceValue::String(string) => string.clone(), + SourceValue::StringArray(array) => array.join(""), + }; + let metadata = match serde_json::to_value(cell.metadata())? { + serde_json::Value::Null => None, + serde_json::Value::Object(metadata) => Some(metadata), + _ => anyhow::bail!("Notebook cell metadata was not an object"), + }; + Ok(( + lsp_types::NotebookCell { + kind: match cell { + ruff_notebook::Cell::Code(_) => lsp_types::NotebookCellKind::Code, + ruff_notebook::Cell::Markdown(_) => lsp_types::NotebookCellKind::Markup, + ruff_notebook::Cell::Raw(_) => unreachable!(), + }, + document: cell_uri.clone(), + metadata, + execution_summary: None, + }, + lsp_types::TextDocumentItem::new(cell_uri, "python".to_string(), 1, contents), + )) +} diff --git a/crates/ruff_server/tests/snapshots/notebook__changed_notebook.snap b/crates/ruff_server/tests/snapshots/notebook__changed_notebook.snap new file mode 100644 index 0000000000..a90e216678 --- /dev/null +++ b/crates/ruff_server/tests/snapshots/notebook__changed_notebook.snap @@ -0,0 +1,81 @@ +--- +source: crates/ruff_server/tests/notebook.rs +expression: notebook_source(snapshot.query().as_notebook().unwrap()) +--- +# @title Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +!pip install matplotlib tensorflow tensorflow-hub +import tensorflow as tf +import tensorflow_hub as hub +import matplotlib.pyplot as plt + +print(tf.__version__) +model = hub.load("https://tfhub.dev/captain-pool/esrgan-tf2/1") +concrete_func = model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY] + + +@tf.function(input_signature=[tf.TensorSpec(shape=[1, 50, 50, 3], dtype=tf.float32)]) +def f(input): + return concrete_func(input) + + +converter = tf.lite.TFLiteConverter.from_concrete_functions( + [f.get_concrete_function()], model +) +converter.optimizations = [tf.lite.Optimize.DEFAULT] +tflite_model = converter.convert() + +# Save the TF Lite model. +with tf.io.gfile.GFile("ESRGAN.tflite", "wb") as f: + f.write(tflite_model) + +esrgan_model_path = "./ESRGAN.tflite" +test_img_path = tf.keras.utils.get_file( + "lr.jpg", + "https://raw.githubusercontent.com/tensorflow/examples/master/lite/examples/super_resolution/android/app/src/main/assets/lr-1.jpg", +) +lr = tf.io.read_file(test_img_path) +lr = tf.image.decode_jpeg(lr) +lr = tf.expand_dims(lr, axis=0) +lr = tf.cast(lr, tf.float32) + +# Load TFLite model and allocate tensors. +interpreter = tf.lite.Interpreter(model_path=esrgan_model_path) +interpreter.allocate_tensors() + +# Get input and output tensors. +input_details = interpreter.get_input_details() +output_details = interpreter.get_output_details() + +# Run the model +interpreter.set_tensor(input_details[0]["index"], lr) +interpreter.invoke() + +# Extract the output and postprocess it +output_data = interpreter.get_tensor(output_details[0]["index"]) +sr = tf.squeeze(output_data, axis=0) +sr = tf.clip_by_value(sr, 0, 255) +sr = tf.round(sr) +sr = tf.cast(sr, tf.uint8) +lr = tf.cast(tf.squeeze(lr, axis=0), tf.uint8) +plt.figure(figsize=(1, 1)) +plt.title("LR") +plt.imshow(lr.numpy()) +plt.figure(figsize=(10, 4)) +plt.subplot(1, 2, 1) +plt.title(f"ESRGAN (x4)") +plt.imshow(sr.numpy()) +bicubic = tf.image.resize(lr, [200, 200], tf.image.ResizeMethod.BICUBIC) +bicubic = tf.cast(bicubic, tf.uint8) +plt.subplot(1, 2, 2) +plt.title("Bicubic") +plt.imshow(bicubic.numpy()); diff --git a/crates/ruff_server/tests/snapshots/notebook__initial_notebook.snap b/crates/ruff_server/tests/snapshots/notebook__initial_notebook.snap new file mode 100644 index 0000000000..29a2872058 --- /dev/null +++ b/crates/ruff_server/tests/snapshots/notebook__initial_notebook.snap @@ -0,0 +1,75 @@ +--- +source: crates/ruff_server/tests/notebook.rs +expression: notebook_source(¬ebook) +--- +#@title Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +!pip install matplotlib tensorflow tensorflow-hub +import tensorflow as tf +import tensorflow_hub as hub +import matplotlib.pyplot as plt +print(tf.__version__) +model = hub.load("https://tfhub.dev/captain-pool/esrgan-tf2/1") +concrete_func = model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY] + +@tf.function(input_signature=[tf.TensorSpec(shape=[1, 50, 50, 3], dtype=tf.float32)]) +def f(input): + return concrete_func(input); + +converter = tf.lite.TFLiteConverter.from_concrete_functions([f.get_concrete_function()], model) +converter.optimizations = [tf.lite.Optimize.DEFAULT] +tflite_model = converter.convert() + +# Save the TF Lite model. +with tf.io.gfile.GFile('ESRGAN.tflite', 'wb') as f: + f.write(tflite_model) + +esrgan_model_path = './ESRGAN.tflite' +test_img_path = tf.keras.utils.get_file('lr.jpg', 'https://raw.githubusercontent.com/tensorflow/examples/master/lite/examples/super_resolution/android/app/src/main/assets/lr-1.jpg') +lr = tf.io.read_file(test_img_path) +lr = tf.image.decode_jpeg(lr) +lr = tf.expand_dims(lr, axis=0) +lr = tf.cast(lr, tf.float32) + +# Load TFLite model and allocate tensors. +interpreter = tf.lite.Interpreter(model_path=esrgan_model_path) +interpreter.allocate_tensors() + +# Get input and output tensors. +input_details = interpreter.get_input_details() +output_details = interpreter.get_output_details() + +# Run the model +interpreter.set_tensor(input_details[0]['index'], lr) +interpreter.invoke() + +# Extract the output and postprocess it +output_data = interpreter.get_tensor(output_details[0]['index']) +sr = tf.squeeze(output_data, axis=0) +sr = tf.clip_by_value(sr, 0, 255) +sr = tf.round(sr) +sr = tf.cast(sr, tf.uint8) +lr = tf.cast(tf.squeeze(lr, axis=0), tf.uint8) +plt.figure(figsize = (1, 1)) +plt.title('LR') +plt.imshow(lr.numpy()); + +plt.figure(figsize=(10, 4)) +plt.subplot(1, 2, 1) +plt.title(f'ESRGAN (x4)') +plt.imshow(sr.numpy()); + +bicubic = tf.image.resize(lr, [200, 200], tf.image.ResizeMethod.BICUBIC) +bicubic = tf.cast(bicubic, tf.uint8) +plt.subplot(1, 2, 2) +plt.title('Bicubic') +plt.imshow(bicubic.numpy());