Add Jupyter Notebook document change snapshot test (#11944)

## Summary

Closes #11914.

This PR introduces a snapshot test that replays the LSP requests made
during a document formatting request, and confirms that the notebook
document is updated in the expected way.
This commit is contained in:
Jane Lewis 2024-06-20 22:29:27 -07:00 committed by GitHub
parent 927069c12f
commit 3ab7a8da73
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 921 additions and 20 deletions

View file

@ -31,6 +31,18 @@ impl Cell {
}
}
pub fn is_code_cell(&self) -> bool {
matches!(self, Cell::Code(_))
}
pub fn metadata(&self) -> &serde_json::Value {
match self {
Cell::Code(cell) => &cell.metadata,
Cell::Markdown(cell) => &cell.metadata,
Cell::Raw(cell) => &cell.metadata,
}
}
/// Update the [`SourceValue`] of the cell.
pub(crate) fn set_source(&mut self, source: SourceValue) {
match self {

View file

@ -19,6 +19,7 @@ use ruff_text_size::TextSize;
use crate::cell::CellOffsets;
use crate::index::NotebookIndex;
use crate::schema::{Cell, RawNotebook, SortAlphabetically, SourceValue};
use crate::RawNotebookMetadata;
/// Run round-trip source code generation on a given Jupyter notebook file path.
pub fn round_trip(path: &Path) -> anyhow::Result<String> {
@ -383,6 +384,10 @@ impl Notebook {
&self.raw.cells
}
pub fn metadata(&self) -> &RawNotebookMetadata {
&self.raw.metadata
}
/// Return `true` if the notebook is a Python notebook, `false` otherwise.
pub fn is_python_notebook(&self) -> bool {
self.raw

View file

@ -0,0 +1,353 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "JfOIB1KdkbYW"
},
"source": [
"##### Copyright 2020 The TensorFlow Authors."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "Ojb0aXCmBgo7"
},
"outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "M9Y4JZ0ZGoE4"
},
"source": [
"# Super resolution with TensorFlow Lite"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "q3FoFSLBjIYK"
},
"source": [
"<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://www.tensorflow.org/lite/examples/super_resolution/overview\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
" </td>\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/examples/super_resolution/overview.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
" </td>\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/examples/super_resolution/overview.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
" </td>\n",
" <td>\n",
" <a href=\"https://storage.googleapis.com/tensorflow_docs/tensorflow/tensorflow/lite/g3doc/examples/super_resolution/overview.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
" </td>\n",
" <td>\n",
" <a href=\"https://tfhub.dev/captain-pool/esrgan-tf2/1\"><img src=\"https://www.tensorflow.org/images/hub_logo_32px.png\" />See TF Hub model</a>\n",
" </td>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-uF3N4BbaMvA"
},
"source": [
"## Overview"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "isbXET4vVHfu"
},
"source": [
"The task of recovering a high resolution (HR) image from its low resolution counterpart is commonly referred to as Single Image Super Resolution (SISR). \n",
"\n",
"The model used here is ESRGAN\n",
"([ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks](https://arxiv.org/abs/1809.00219)). And we are going to use TensorFlow Lite to run inference on the pretrained model.\n",
"\n",
"The TFLite model is converted from this\n",
"[implementation](https://tfhub.dev/captain-pool/esrgan-tf2/1) hosted on TF Hub. Note that the model we converted upsamples a 50x50 low resolution image to a 200x200 high resolution image (scale factor=4). If you want a different input size or scale factor, you need to re-convert or re-train the original model."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2dQlTqiffuoU"
},
"source": [
"## Setup"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qKyMtsGqu3zH"
},
"source": [
"Let's install required libraries first."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7YTT1Rxsw3A9"
},
"outputs": [],
"source": [
"!pip install matplotlib tensorflow tensorflow-hub"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Clz5Kl97FswD"
},
"source": [
"Import dependencies."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "2xh1kvGEBjuP"
},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"import tensorflow_hub as hub\n",
"import matplotlib.pyplot as plt\n",
"print(tf.__version__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "i5miVfL4kxTA"
},
"source": [
"Download and convert the ESRGAN model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "X5PvXIXRwvHj"
},
"outputs": [],
"source": [
"model = hub.load(\"https://tfhub.dev/captain-pool/esrgan-tf2/1\")\n",
"concrete_func = model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]\n",
"\n",
"@tf.function(input_signature=[tf.TensorSpec(shape=[1, 50, 50, 3], dtype=tf.float32)])\n",
"def f(input):\n",
" return concrete_func(input);\n",
"\n",
"converter = tf.lite.TFLiteConverter.from_concrete_functions([f.get_concrete_function()], model)\n",
"converter.optimizations = [tf.lite.Optimize.DEFAULT]\n",
"tflite_model = converter.convert()\n",
"\n",
"# Save the TF Lite model.\n",
"with tf.io.gfile.GFile('ESRGAN.tflite', 'wb') as f:\n",
" f.write(tflite_model)\n",
"\n",
"esrgan_model_path = './ESRGAN.tflite'"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jH5-xPkyUEqt"
},
"source": [
"Download a test image (insect head)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "suWiStTWgK6e"
},
"outputs": [],
"source": [
"test_img_path = tf.keras.utils.get_file('lr.jpg', 'https://raw.githubusercontent.com/tensorflow/examples/master/lite/examples/super_resolution/android/app/src/main/assets/lr-1.jpg')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rgQ4qRuFNpyW"
},
"source": [
"## Generate a super resolution image using TensorFlow Lite"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "J9FV4btf02-2"
},
"outputs": [],
"source": [
"lr = tf.io.read_file(test_img_path)\n",
"lr = tf.image.decode_jpeg(lr)\n",
"lr = tf.expand_dims(lr, axis=0)\n",
"lr = tf.cast(lr, tf.float32)\n",
"\n",
"# Load TFLite model and allocate tensors.\n",
"interpreter = tf.lite.Interpreter(model_path=esrgan_model_path)\n",
"interpreter.allocate_tensors()\n",
"\n",
"# Get input and output tensors.\n",
"input_details = interpreter.get_input_details()\n",
"output_details = interpreter.get_output_details()\n",
"\n",
"# Run the model\n",
"interpreter.set_tensor(input_details[0]['index'], lr)\n",
"interpreter.invoke()\n",
"\n",
"# Extract the output and postprocess it\n",
"output_data = interpreter.get_tensor(output_details[0]['index'])\n",
"sr = tf.squeeze(output_data, axis=0)\n",
"sr = tf.clip_by_value(sr, 0, 255)\n",
"sr = tf.round(sr)\n",
"sr = tf.cast(sr, tf.uint8)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EwddQrDUNQGO"
},
"source": [
"## Visualize the result"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "aasKuozt1gNd"
},
"outputs": [],
"source": [
"lr = tf.cast(tf.squeeze(lr, axis=0), tf.uint8)\n",
"plt.figure(figsize = (1, 1))\n",
"plt.title('LR')\n",
"plt.imshow(lr.numpy());\n",
"\n",
"plt.figure(figsize=(10, 4))\n",
"plt.subplot(1, 2, 1) \n",
"plt.title(f'ESRGAN (x4)')\n",
"plt.imshow(sr.numpy());\n",
"\n",
"bicubic = tf.image.resize(lr, [200, 200], tf.image.ResizeMethod.BICUBIC)\n",
"bicubic = tf.cast(bicubic, tf.uint8)\n",
"plt.subplot(1, 2, 2) \n",
"plt.title('Bicubic')\n",
"plt.imshow(bicubic.numpy());"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0kb-fkogObjq"
},
"source": [
"## Performance Benchmarks"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tNzdgpqTy5P3"
},
"source": [
"Performance benchmark numbers are generated with the tool\n",
"[described here](https://www.tensorflow.org/lite/performance/benchmarks).\n",
"\n",
"<table>\n",
" <thead>\n",
" <tr>\n",
" <th>Model Name</th>\n",
" <th>Model Size </th>\n",
" <th>Device </th>\n",
" <th>CPU</th>\n",
" <th>GPU</th>\n",
" </tr>\n",
" </thead>\n",
" <tr>\n",
" <td rowspan = 3>\n",
" super resolution (ESRGAN)\n",
" </td>\n",
" <td rowspan = 3>\n",
" 4.8 Mb\n",
" </td>\n",
" <td>Pixel 3</td>\n",
" <td>586.8ms*</td>\n",
" <td>128.6ms</td>\n",
" </tr>\n",
" <tr>\n",
" <td>Pixel 4</td>\n",
" <td>385.1ms*</td>\n",
" <td>130.3ms</td>\n",
" </tr>\n",
"\n",
"</table>\n",
"\n",
"**4 threads used*"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "super_resolution.ipynb",
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
}
},
"nbformat": 4,
"nbformat_minor": 0
}

View file

@ -8,7 +8,7 @@ mod text_document;
use std::collections::HashMap;
use lsp_types::{PositionEncodingKind, Url};
pub(crate) use notebook::NotebookDocument;
pub use notebook::NotebookDocument;
pub(crate) use range::{NotebookRange, RangeExt, ToRangeExt};
pub(crate) use replacement::Replacement;
pub(crate) use text_document::DocumentVersion;
@ -34,7 +34,7 @@ pub enum PositionEncoding {
/// A unique document ID, derived from a URL passed as part of an LSP request.
/// This document ID can point to either be a standalone Python file, a full notebook, or a cell within a notebook.
#[derive(Clone, Debug)]
pub(crate) enum DocumentKey {
pub enum DocumentKey {
Notebook(Url),
NotebookCell(Url),
Text(Url),

View file

@ -13,7 +13,7 @@ pub(super) type CellId = usize;
/// The state of a notebook document in the server. Contains an array of cells whose
/// contents are internally represented by [`TextDocument`]s.
#[derive(Clone, Debug)]
pub(crate) struct NotebookDocument {
pub struct NotebookDocument {
cells: Vec<NotebookCell>,
metadata: ruff_notebook::RawNotebookMetadata,
version: DocumentVersion,
@ -30,7 +30,7 @@ struct NotebookCell {
}
impl NotebookDocument {
pub(crate) fn new(
pub fn new(
version: DocumentVersion,
cells: Vec<lsp_types::NotebookCell>,
metadata: serde_json::Map<String, serde_json::Value>,
@ -59,7 +59,7 @@ impl NotebookDocument {
/// Generates a pseudo-representation of a notebook that lacks per-cell metadata and contextual information
/// but should still work with Ruff's linter.
pub(crate) fn make_ruff_notebook(&self) -> ruff_notebook::Notebook {
pub fn make_ruff_notebook(&self) -> ruff_notebook::Notebook {
let cells = self
.cells
.iter()

View file

@ -1,8 +1,9 @@
//! ## The Ruff Language Server
pub use edit::{PositionEncoding, TextDocument};
pub use edit::{DocumentKey, NotebookDocument, PositionEncoding, TextDocument};
use lsp_types::CodeActionKind;
pub use server::Server;
pub use session::{ClientSettings, DocumentQuery, DocumentSnapshot, Session};
#[macro_use]
mod message;

View file

@ -8,15 +8,16 @@ use crate::edit::{DocumentKey, DocumentVersion, NotebookDocument};
use crate::{PositionEncoding, TextDocument};
pub(crate) use self::capabilities::ResolvedClientCapabilities;
pub(crate) use self::index::DocumentQuery;
pub(crate) use self::settings::{AllSettings, ClientSettings};
pub use self::index::DocumentQuery;
pub(crate) use self::settings::AllSettings;
pub use self::settings::ClientSettings;
mod capabilities;
mod index;
mod settings;
/// The global state for the LSP
pub(crate) struct Session {
pub struct Session {
/// Used to retrieve information about open documents and settings.
index: index::Index,
/// The global position encoding, negotiated during LSP initialization.
@ -29,7 +30,7 @@ pub(crate) struct Session {
/// An immutable snapshot of `Session` that references
/// a specific document.
pub(crate) struct DocumentSnapshot {
pub struct DocumentSnapshot {
resolved_client_capabilities: Arc<ResolvedClientCapabilities>,
client_settings: settings::ResolvedClientSettings,
document_ref: index::DocumentQuery,
@ -37,7 +38,7 @@ pub(crate) struct DocumentSnapshot {
}
impl Session {
pub(crate) fn new(
pub fn new(
client_capabilities: &ClientCapabilities,
position_encoding: PositionEncoding,
global_settings: ClientSettings,
@ -53,12 +54,12 @@ impl Session {
})
}
pub(crate) fn key_from_url(&self, url: Url) -> DocumentKey {
pub fn key_from_url(&self, url: Url) -> DocumentKey {
self.index.key_from_url(url)
}
/// Creates a document snapshot with the URL referencing the document to snapshot.
pub(crate) fn take_snapshot(&self, url: Url) -> Option<DocumentSnapshot> {
pub fn take_snapshot(&self, url: Url) -> Option<DocumentSnapshot> {
let key = self.key_from_url(url);
Some(DocumentSnapshot {
resolved_client_capabilities: self.resolved_client_capabilities.clone(),
@ -98,7 +99,7 @@ impl Session {
///
/// The document key must point to a notebook document or cell, or this will
/// throw an error.
pub(crate) fn update_notebook_document(
pub fn update_notebook_document(
&mut self,
key: &DocumentKey,
cells: Option<NotebookDocumentCellChange>,
@ -112,7 +113,7 @@ impl Session {
/// Registers a notebook document at the provided `url`.
/// If a document is already open here, it will be overwritten.
pub(crate) fn open_notebook_document(&mut self, url: Url, document: NotebookDocument) {
pub fn open_notebook_document(&mut self, url: Url, document: NotebookDocument) {
self.index.open_notebook_document(url, document);
}
@ -175,7 +176,7 @@ impl DocumentSnapshot {
&self.client_settings
}
pub(crate) fn query(&self) -> &index::DocumentQuery {
pub fn query(&self) -> &index::DocumentQuery {
&self.document_ref
}

View file

@ -49,7 +49,7 @@ enum DocumentController {
/// This query can 'select' a text document, full notebook, or a specific notebook cell.
/// It also includes document settings.
#[derive(Clone)]
pub(crate) enum DocumentQuery {
pub enum DocumentQuery {
Text {
file_url: Url,
document: Arc<TextDocument>,
@ -519,7 +519,7 @@ impl DocumentQuery {
}
/// Attempts to access the underlying notebook document that this query is selecting.
pub(crate) fn as_notebook(&self) -> Option<&NotebookDocument> {
pub fn as_notebook(&self) -> Option<&NotebookDocument> {
match self {
Self::Notebook { notebook, .. } => Some(notebook),
Self::Text { .. } => None,

View file

@ -18,7 +18,7 @@ use walkdir::WalkDir;
use crate::session::settings::{ConfigurationPreference, ResolvedEditorSettings};
pub(crate) struct RuffSettings {
pub struct RuffSettings {
/// The path to this configuration file, used for debugging.
/// The default fallback configuration does not have a file path.
path: Option<PathBuf>,

View file

@ -60,7 +60,7 @@ pub(crate) enum ConfigurationPreference {
#[derive(Debug, Deserialize, Default)]
#[cfg_attr(test, derive(PartialEq, Eq))]
#[serde(rename_all = "camelCase")]
pub(crate) struct ClientSettings {
pub struct ClientSettings {
configuration: Option<String>,
fix_all: Option<bool>,
organize_imports: Option<bool>,

View file

@ -0,0 +1,373 @@
use std::{
path::{Path, PathBuf},
str::FromStr,
};
use lsp_types::{
ClientCapabilities, LSPObject, NotebookDocumentCellChange, NotebookDocumentChangeTextContent,
Position, Range, TextDocumentContentChangeEvent, VersionedTextDocumentIdentifier,
};
use ruff_notebook::SourceValue;
use ruff_server::ClientSettings;
const SUPER_RESOLUTION_OVERVIEW_PATH: &str =
"./resources/test/fixtures/tensorflow_test_notebook.ipynb";
struct NotebookChange {
version: i32,
metadata: Option<LSPObject>,
updated_cells: lsp_types::NotebookDocumentCellChange,
}
#[test]
fn super_resolution_overview() {
let file_path =
std::path::absolute(PathBuf::from_str(SUPER_RESOLUTION_OVERVIEW_PATH).unwrap()).unwrap();
let file_url = lsp_types::Url::from_file_path(&file_path).unwrap();
let notebook = create_notebook(&file_path).unwrap();
insta::assert_snapshot!("initial_notebook", notebook_source(&notebook));
let mut session = ruff_server::Session::new(
&ClientCapabilities::default(),
ruff_server::PositionEncoding::UTF16,
ClientSettings::default(),
vec![(
lsp_types::Url::from_file_path(file_path.parent().unwrap()).unwrap(),
ClientSettings::default(),
)],
)
.unwrap();
session.open_notebook_document(file_url.clone(), notebook);
let changes = [NotebookChange {
version: 0,
metadata: None,
updated_cells: NotebookDocumentCellChange {
structure: None,
data: None,
text_content: Some(vec![NotebookDocumentChangeTextContent {
document: VersionedTextDocumentIdentifier {
uri: make_cell_uri(&file_path, 5),
version: 2,
},
changes: vec![
TextDocumentContentChangeEvent {
range: Some(Range {
start: Position {
line: 18,
character: 61,
},
end: Position {
line: 18,
character: 62,
},
}),
range_length: Some(1),
text: "\"".to_string(),
},
TextDocumentContentChangeEvent {
range: Some(Range {
start: Position {
line: 18,
character: 55,
},
end: Position {
line: 18,
character: 56,
},
}),
range_length: Some(1),
text: "\"".to_string(),
},
TextDocumentContentChangeEvent {
range: Some(Range {
start: Position {
line: 14,
character: 46,
},
end: Position {
line: 14,
character: 47,
},
}),
range_length: Some(1),
text: "\"".to_string(),
},
TextDocumentContentChangeEvent {
range: Some(Range {
start: Position {
line: 14,
character: 40,
},
end: Position {
line: 14,
character: 41,
},
}),
range_length: Some(1),
text: "\"".to_string(),
},
],
}]),
},
},
NotebookChange {
version: 1,
metadata: None,
updated_cells: NotebookDocumentCellChange {
structure: None,
data: None,
text_content: Some(vec![NotebookDocumentChangeTextContent {
document: VersionedTextDocumentIdentifier {
uri: make_cell_uri(&file_path, 4),
version: 2
},
changes: vec![TextDocumentContentChangeEvent {
range: Some(Range {
start: Position {
line: 0,
character: 0
},
end: Position {
line: 0,
character: 181
} }),
range_length: Some(181),
text: "test_img_path = tf.keras.utils.get_file(\n \"lr.jpg\",\n \"https://raw.githubusercontent.com/tensorflow/examples/master/lite/examples/super_resolution/android/app/src/main/assets/lr-1.jpg\",\n)".to_string()
}
]
}
]
)
}
},
NotebookChange {
version: 2,
metadata: None,
updated_cells: NotebookDocumentCellChange {
structure: None,
data: None,
text_content: Some(vec![NotebookDocumentChangeTextContent {
document: VersionedTextDocumentIdentifier {
uri: make_cell_uri(&file_path, 2),
version: 2,
},
changes: vec![TextDocumentContentChangeEvent {
range: Some(Range {
start: Position {
line: 3,
character: 0,
},
end: Position {
line: 3,
character: 21,
},
}),
range_length: Some(21),
text: "\nprint(tf.__version__)".to_string(),
}],
}]),
}
},
NotebookChange {
version: 3,
metadata: None,
updated_cells: NotebookDocumentCellChange {
structure: None,
data: None,
text_content: Some(vec![NotebookDocumentChangeTextContent {
document: VersionedTextDocumentIdentifier {
uri: make_cell_uri(&file_path, 1),
version: 2,
},
changes: vec![TextDocumentContentChangeEvent {
range: Some(Range {
start: Position {
line: 0,
character: 0,
},
end: Position {
line: 0,
character: 49,
},
}),
range_length: Some(49),
text: "!pip install matplotlib tensorflow tensorflow-hub".to_string(),
}],
}]),
},
},
NotebookChange {
version: 4,
metadata: None,
updated_cells: NotebookDocumentCellChange {
structure: None,
data: None,
text_content: Some(vec![NotebookDocumentChangeTextContent {
document: VersionedTextDocumentIdentifier {
uri: make_cell_uri(&file_path, 3),
version: 2,
},
changes: vec![TextDocumentContentChangeEvent {
range: Some(Range {
start: Position {
line: 3,
character: 0,
},
end: Position {
line: 15,
character: 37,
},
}),
range_length: Some(457),
text: "\n@tf.function(input_signature=[tf.TensorSpec(shape=[1, 50, 50, 3], dtype=tf.float32)])\ndef f(input):\n return concrete_func(input)\n\n\nconverter = tf.lite.TFLiteConverter.from_concrete_functions(\n [f.get_concrete_function()], model\n)\nconverter.optimizations = [tf.lite.Optimize.DEFAULT]\ntflite_model = converter.convert()\n\n# Save the TF Lite model.\nwith tf.io.gfile.GFile(\"ESRGAN.tflite\", \"wb\") as f:\n f.write(tflite_model)\n\nesrgan_model_path = \"./ESRGAN.tflite\"".to_string(),
}],
}]),
},
},
NotebookChange {
version: 5,
metadata: None,
updated_cells: NotebookDocumentCellChange {
structure: None,
data: None,
text_content: Some(vec![NotebookDocumentChangeTextContent {
document: VersionedTextDocumentIdentifier {
uri: make_cell_uri(&file_path, 0),
version: 2,
},
changes: vec![TextDocumentContentChangeEvent {
range: Some(Range {
start: Position {
line: 0,
character: 0,
},
end: Position {
line: 2,
character: 0,
},
}),
range_length: Some(139),
text: "# @title Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n".to_string(),
}],
}]),
},
},
NotebookChange {
version: 6,
metadata: None,
updated_cells: NotebookDocumentCellChange {
structure: None,
data: None,
text_content: Some(vec![NotebookDocumentChangeTextContent {
document: VersionedTextDocumentIdentifier {
uri: make_cell_uri(&file_path, 6),
version: 2,
},
changes: vec![TextDocumentContentChangeEvent {
range: Some(Range {
start: Position {
line: 1,
character: 0,
},
end: Position {
line: 14,
character: 28,
},
}),
range_length: Some(361),
text: "plt.figure(figsize=(1, 1))\nplt.title(\"LR\")\nplt.imshow(lr.numpy())\nplt.figure(figsize=(10, 4))\nplt.subplot(1, 2, 1)\nplt.title(f\"ESRGAN (x4)\")\nplt.imshow(sr.numpy())\nbicubic = tf.image.resize(lr, [200, 200], tf.image.ResizeMethod.BICUBIC)\nbicubic = tf.cast(bicubic, tf.uint8)\nplt.subplot(1, 2, 2)\nplt.title(\"Bicubic\")\nplt.imshow(bicubic.numpy());".to_string(),
}],
}]),
},
}
];
let key = session.key_from_url(file_url.clone());
for NotebookChange {
version,
metadata,
updated_cells,
} in changes
{
session
.update_notebook_document(&key, Some(updated_cells), metadata, version)
.unwrap();
}
let snapshot = session.take_snapshot(file_url.clone()).unwrap();
insta::assert_snapshot!(
"changed_notebook",
notebook_source(snapshot.query().as_notebook().unwrap())
);
}
fn notebook_source(notebook: &ruff_server::NotebookDocument) -> String {
notebook.make_ruff_notebook().source_code().to_string()
}
// produces an opaque URL based on a document path and a cell index
fn make_cell_uri(path: &Path, index: usize) -> lsp_types::Url {
lsp_types::Url::parse(&format!(
"notebook-cell:///Users/test/notebooks/{}.ipynb?cell={index}",
path.file_name().unwrap().to_string_lossy()
))
.unwrap()
}
fn create_notebook(file_path: &Path) -> anyhow::Result<ruff_server::NotebookDocument> {
let ruff_notebook = ruff_notebook::Notebook::from_path(file_path)?;
let mut cells = vec![];
let mut cell_documents = vec![];
for (i, cell) in ruff_notebook
.cells()
.iter()
.filter(|cell| cell.is_code_cell())
.enumerate()
{
let uri = make_cell_uri(file_path, i);
let (lsp_cell, cell_document) = cell_to_lsp_cell(cell, uri)?;
cells.push(lsp_cell);
cell_documents.push(cell_document);
}
let serde_json::Value::Object(metadata) = serde_json::to_value(ruff_notebook.metadata())?
else {
anyhow::bail!("Notebook metadata was not an object");
};
ruff_server::NotebookDocument::new(0, cells, metadata, cell_documents)
}
fn cell_to_lsp_cell(
cell: &ruff_notebook::Cell,
cell_uri: lsp_types::Url,
) -> anyhow::Result<(lsp_types::NotebookCell, lsp_types::TextDocumentItem)> {
let contents = match cell.source() {
SourceValue::String(string) => string.clone(),
SourceValue::StringArray(array) => array.join(""),
};
let metadata = match serde_json::to_value(cell.metadata())? {
serde_json::Value::Null => None,
serde_json::Value::Object(metadata) => Some(metadata),
_ => anyhow::bail!("Notebook cell metadata was not an object"),
};
Ok((
lsp_types::NotebookCell {
kind: match cell {
ruff_notebook::Cell::Code(_) => lsp_types::NotebookCellKind::Code,
ruff_notebook::Cell::Markdown(_) => lsp_types::NotebookCellKind::Markup,
ruff_notebook::Cell::Raw(_) => unreachable!(),
},
document: cell_uri.clone(),
metadata,
execution_summary: None,
},
lsp_types::TextDocumentItem::new(cell_uri, "python".to_string(), 1, contents),
))
}

View file

@ -0,0 +1,81 @@
---
source: crates/ruff_server/tests/notebook.rs
expression: notebook_source(snapshot.query().as_notebook().unwrap())
---
# @title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
!pip install matplotlib tensorflow tensorflow-hub
import tensorflow as tf
import tensorflow_hub as hub
import matplotlib.pyplot as plt
print(tf.__version__)
model = hub.load("https://tfhub.dev/captain-pool/esrgan-tf2/1")
concrete_func = model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
@tf.function(input_signature=[tf.TensorSpec(shape=[1, 50, 50, 3], dtype=tf.float32)])
def f(input):
return concrete_func(input)
converter = tf.lite.TFLiteConverter.from_concrete_functions(
[f.get_concrete_function()], model
)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()
# Save the TF Lite model.
with tf.io.gfile.GFile("ESRGAN.tflite", "wb") as f:
f.write(tflite_model)
esrgan_model_path = "./ESRGAN.tflite"
test_img_path = tf.keras.utils.get_file(
"lr.jpg",
"https://raw.githubusercontent.com/tensorflow/examples/master/lite/examples/super_resolution/android/app/src/main/assets/lr-1.jpg",
)
lr = tf.io.read_file(test_img_path)
lr = tf.image.decode_jpeg(lr)
lr = tf.expand_dims(lr, axis=0)
lr = tf.cast(lr, tf.float32)
# Load TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path=esrgan_model_path)
interpreter.allocate_tensors()
# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
# Run the model
interpreter.set_tensor(input_details[0]["index"], lr)
interpreter.invoke()
# Extract the output and postprocess it
output_data = interpreter.get_tensor(output_details[0]["index"])
sr = tf.squeeze(output_data, axis=0)
sr = tf.clip_by_value(sr, 0, 255)
sr = tf.round(sr)
sr = tf.cast(sr, tf.uint8)
lr = tf.cast(tf.squeeze(lr, axis=0), tf.uint8)
plt.figure(figsize=(1, 1))
plt.title("LR")
plt.imshow(lr.numpy())
plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.title(f"ESRGAN (x4)")
plt.imshow(sr.numpy())
bicubic = tf.image.resize(lr, [200, 200], tf.image.ResizeMethod.BICUBIC)
bicubic = tf.cast(bicubic, tf.uint8)
plt.subplot(1, 2, 2)
plt.title("Bicubic")
plt.imshow(bicubic.numpy());

View file

@ -0,0 +1,75 @@
---
source: crates/ruff_server/tests/notebook.rs
expression: notebook_source(&notebook)
---
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
!pip install matplotlib tensorflow tensorflow-hub
import tensorflow as tf
import tensorflow_hub as hub
import matplotlib.pyplot as plt
print(tf.__version__)
model = hub.load("https://tfhub.dev/captain-pool/esrgan-tf2/1")
concrete_func = model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
@tf.function(input_signature=[tf.TensorSpec(shape=[1, 50, 50, 3], dtype=tf.float32)])
def f(input):
return concrete_func(input);
converter = tf.lite.TFLiteConverter.from_concrete_functions([f.get_concrete_function()], model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()
# Save the TF Lite model.
with tf.io.gfile.GFile('ESRGAN.tflite', 'wb') as f:
f.write(tflite_model)
esrgan_model_path = './ESRGAN.tflite'
test_img_path = tf.keras.utils.get_file('lr.jpg', 'https://raw.githubusercontent.com/tensorflow/examples/master/lite/examples/super_resolution/android/app/src/main/assets/lr-1.jpg')
lr = tf.io.read_file(test_img_path)
lr = tf.image.decode_jpeg(lr)
lr = tf.expand_dims(lr, axis=0)
lr = tf.cast(lr, tf.float32)
# Load TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path=esrgan_model_path)
interpreter.allocate_tensors()
# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
# Run the model
interpreter.set_tensor(input_details[0]['index'], lr)
interpreter.invoke()
# Extract the output and postprocess it
output_data = interpreter.get_tensor(output_details[0]['index'])
sr = tf.squeeze(output_data, axis=0)
sr = tf.clip_by_value(sr, 0, 255)
sr = tf.round(sr)
sr = tf.cast(sr, tf.uint8)
lr = tf.cast(tf.squeeze(lr, axis=0), tf.uint8)
plt.figure(figsize = (1, 1))
plt.title('LR')
plt.imshow(lr.numpy());
plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.title(f'ESRGAN (x4)')
plt.imshow(sr.numpy());
bicubic = tf.image.resize(lr, [200, 200], tf.image.ResizeMethod.BICUBIC)
bicubic = tf.cast(bicubic, tf.uint8)
plt.subplot(1, 2, 2)
plt.title('Bicubic')
plt.imshow(bicubic.numpy());