[red-knot] Add "cheap" program.snapshot (#11172)

Micha Reiser 2024-04-30 09:13:26 +02:00 committed by GitHub
parent eb6f562419
commit bc03d376e8
24 changed files with 833 additions and 508 deletions

Cargo.lock (generated)

@ -501,6 +501,19 @@ dependencies = [
"itertools 0.10.5",
]
[[package]]
name = "crossbeam"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1137cd7e7fc0fb5d3c5a8678be38ec56e819125d8d7907411fe24ccb943faca8"
dependencies = [
"crossbeam-channel",
"crossbeam-deque",
"crossbeam-epoch",
"crossbeam-queue",
"crossbeam-utils",
]
[[package]]
name = "crossbeam-channel"
version = "0.5.12"
@ -529,6 +542,15 @@ dependencies = [
"crossbeam-utils",
]
[[package]]
name = "crossbeam-queue"
version = "0.3.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df0346b5d5e76ac2fe4e327c5fd1118d6be7c51dfb18f9b7922923f287471e35"
dependencies = [
"crossbeam-utils",
]
[[package]]
name = "crossbeam-utils"
version = "0.8.19"
@ -1804,7 +1826,7 @@ version = "0.1.0"
dependencies = [
"anyhow",
"bitflags 2.5.0",
"crossbeam-channel",
"crossbeam",
"ctrlc",
"dashmap",
"hashbrown 0.14.5",
@ -2341,7 +2363,7 @@ name = "ruff_server"
version = "0.2.2"
dependencies = [
"anyhow",
"crossbeam-channel",
"crossbeam",
"insta",
"jod-thread",
"libc",


@ -30,7 +30,7 @@ console_error_panic_hook = { version = "0.1.7" }
console_log = { version = "1.0.0" }
countme = { version = "3.0.1" }
criterion = { version = "0.5.1", default-features = false }
crossbeam-channel = { version = "0.5.12" }
crossbeam = { version = "0.8.4" }
dashmap = { version = "5.5.3" }
dirs = { version = "5.0.0" }
drop_bomb = { version = "0.1.5" }


@ -22,7 +22,7 @@ ruff_notebook = { path = "../ruff_notebook" }
anyhow = { workspace = true }
bitflags = { workspace = true }
ctrlc = "3.4.4"
crossbeam-channel = { workspace = true }
crossbeam = { workspace = true }
dashmap = { workspace = true }
hashbrown = { workspace = true }
indexmap = { workspace = true }


@ -2,6 +2,7 @@ use std::fmt::Formatter;
use std::hash::Hash;
use std::sync::atomic::{AtomicUsize, Ordering};
use crate::db::QueryResult;
use dashmap::mapref::entry::Entry;
use crate::FxDashMap;
@ -27,11 +28,11 @@ where
}
}
pub fn get<F>(&self, key: &K, compute: F) -> V
pub fn get<F>(&self, key: &K, compute: F) -> QueryResult<V>
where
F: FnOnce(&K) -> V,
F: FnOnce(&K) -> QueryResult<V>,
{
match self.map.entry(key.clone()) {
Ok(match self.map.entry(key.clone()) {
Entry::Occupied(cached) => {
self.statistics.hit();
@ -40,11 +41,11 @@ where
Entry::Vacant(vacant) => {
self.statistics.miss();
let value = compute(key);
let value = compute(key)?;
vacant.insert(value.clone());
value
}
}
})
}
pub fn set(&mut self, key: K, value: V) {
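
A side effect of this change is that a cache `compute` closure can now short-circuit on cancellation. A minimal sketch of a caller after this change (the `cached_source_len` helper is hypothetical, not part of this commit):

```rust
use crate::cache::KeyValueCache;
use crate::db::{QueryResult, SourceDb};
use crate::files::FileId;

// Hypothetical query: the compute closure propagates cancellation with `?`,
// so a pending mutation aborts the cache fill instead of being blocked by it.
fn cached_source_len(
    db: &impl SourceDb,
    cache: &KeyValueCache<FileId, usize>,
    file: FileId,
) -> QueryResult<usize> {
    cache.get(&file, |file| {
        let source = db.source(*file)?; // may return Err(QueryError::Cancelled)
        Ok(source.text().len())
    })
}
```
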


@ -1,35 +1,25 @@
use std::sync::{Arc, Condvar, Mutex};
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
#[derive(Debug, Default)]
#[derive(Debug, Clone, Default)]
pub struct CancellationTokenSource {
signal: Arc<(Mutex<bool>, Condvar)>,
signal: Arc<AtomicBool>,
}
impl CancellationTokenSource {
pub fn new() -> Self {
Self {
signal: Arc::new((Mutex::new(false), Condvar::default())),
signal: Arc::new(AtomicBool::new(false)),
}
}
#[tracing::instrument(level = "trace", skip_all)]
pub fn cancel(&self) {
let (cancelled, condvar) = &*self.signal;
let mut cancelled = cancelled.lock().unwrap();
if *cancelled {
return;
}
*cancelled = true;
condvar.notify_all();
self.signal.store(true, std::sync::atomic::Ordering::SeqCst);
}
pub fn is_cancelled(&self) -> bool {
let (cancelled, _) = &*self.signal;
*cancelled.lock().unwrap()
self.signal.load(std::sync::atomic::Ordering::SeqCst)
}
pub fn token(&self) -> CancellationToken {
@ -41,26 +31,12 @@ impl CancellationTokenSource {
#[derive(Clone, Debug)]
pub struct CancellationToken {
signal: Arc<(Mutex<bool>, Condvar)>,
signal: Arc<AtomicBool>,
}
impl CancellationToken {
/// Returns `true` if cancellation has been requested.
pub fn is_cancelled(&self) -> bool {
let (cancelled, _) = &*self.signal;
*cancelled.lock().unwrap()
}
pub fn wait(&self) {
let (bool, condvar) = &*self.signal;
let lock = condvar
.wait_while(bool.lock().unwrap(), |bool| !*bool)
.unwrap();
debug_assert!(*lock);
drop(lock);
self.signal.load(std::sync::atomic::Ordering::SeqCst)
}
}


@ -1,3 +1,8 @@
mod jars;
mod query;
mod runtime;
mod storage;
use std::path::Path;
use std::sync::Arc;
@ -9,32 +14,115 @@ use crate::source::{Source, SourceStorage};
use crate::symbols::{SymbolId, SymbolTable, SymbolTablesStorage};
use crate::types::{Type, TypeStore};
pub trait SourceDb {
pub use jars::{HasJar, HasJars};
pub use query::{QueryError, QueryResult};
pub use runtime::DbRuntime;
pub use storage::JarsStorage;
pub trait Database {
/// Returns a reference to the runtime of the current worker.
fn runtime(&self) -> &DbRuntime;
/// Returns a mutable reference to the runtime. Only one worker can hold a mutable reference to the runtime.
fn runtime_mut(&mut self) -> &mut DbRuntime;
/// Returns `Ok` if the queries have not been cancelled and `Err(QueryError::Cancelled)` otherwise.
fn cancelled(&self) -> QueryResult<()> {
self.runtime().cancelled()
}
/// Returns `true` if the queries have been cancelled.
fn is_cancelled(&self) -> bool {
self.runtime().is_cancelled()
}
}
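
Hand-written queries are expected to call `cancelled` between units of work so that pending mutations are not blocked for long. A hedged sketch (the `count_nonempty_lines` query is hypothetical, not part of this commit):

```rust
// Hypothetical long-running query that cooperates with cancellation: it checks
// the runtime between units of work and bails out early with
// `QueryError::Cancelled` via `?` instead of finishing the whole loop.
fn count_nonempty_lines<Db: Database>(db: &Db, text: &str) -> QueryResult<usize> {
    let mut count = 0;
    for line in text.lines() {
        db.cancelled()?; // a cheap atomic load; errors once a mutation is pending
        count += usize::from(!line.trim().is_empty());
    }
    Ok(count)
}
```
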
/// Database that supports running queries from multiple threads.
pub trait ParallelDatabase: Database + Send {
/// Creates a snapshot of the database state that can be used to query the database in another thread.
///
/// The snapshot is a read-only view of the database, but query results are shared between threads.
/// All queries are automatically cancelled when a mutation is applied to the database (by calling [`HasJars::jars_mut`]);
/// mutations never apply to a snapshot, because snapshots are read-only.
///
/// ## Creating a snapshot
///
/// Creating a snapshot of the database's jars is cheap, but creating a snapshot of
/// other state stored on the database might require deep-cloning data. That's why you should
/// avoid creating snapshots in a hot function (e.g. don't create a snapshot for each file; instead,
/// create one snapshot when scheduling the check of an entire program).
///
/// ## Salsa compatibility
/// Salsa prohibits creating a snapshot while running a local query (it's fine if other workers run a query) [[source](https://github.com/salsa-rs/salsa/issues/80)].
/// We should avoid creating snapshots while running a query because we might want to adopt Salsa in the future (if we can figure out persistent caching).
/// Unfortunately, the infrastructure doesn't provide an automated way of knowing when a query is run, which is
/// why we have to "enforce" this constraint manually.
fn snapshot(&self) -> Snapshot<Self>;
}
/// Read-only snapshot of a database.
///
/// ## Deadlocks
/// A snapshot should always be dropped as soon as it is no longer needed to run queries.
/// Holding on to a snapshot without running a query, or without periodically checking whether
/// cancellation was requested, can lead to deadlocks, because mutating the [`Database`] requires
/// cancelling all pending queries and waiting for all [`Snapshot`]s to be dropped.
#[derive(Debug)]
pub struct Snapshot<DB: ?Sized>
where
DB: ParallelDatabase,
{
db: DB,
}
impl<DB> Snapshot<DB>
where
DB: ParallelDatabase,
{
pub fn new(db: DB) -> Self {
Snapshot { db }
}
}
impl<DB> std::ops::Deref for Snapshot<DB>
where
DB: ParallelDatabase,
{
type Target = DB;
fn deref(&self) -> &DB {
&self.db
}
}
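
A hedged sketch of the intended usage pattern (the `spawn_check` helper is hypothetical; it only assumes the traits defined above):

```rust
use std::thread;

// Create one snapshot per scheduled check, move it into the worker thread,
// and let it drop when the worker finishes so pending mutations can proceed.
fn spawn_check<Db>(db: &Db) -> thread::JoinHandle<bool>
where
    Db: ParallelDatabase + 'static,
{
    let snapshot = db.snapshot(); // cheap: shares the jars instead of cloning them
    thread::spawn(move || {
        // Queries on the snapshot start returning `Err(QueryError::Cancelled)`
        // as soon as the owning thread mutates the database; returning drops
        // the snapshot, which unblocks that mutation.
        snapshot.is_cancelled()
    })
}
```
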
// Red knot specific databases code.
pub trait SourceDb: Database {
// queries
fn file_id(&self, path: &std::path::Path) -> FileId;
fn file_path(&self, file_id: FileId) -> Arc<std::path::Path>;
fn source(&self, file_id: FileId) -> Source;
fn source(&self, file_id: FileId) -> QueryResult<Source>;
fn parse(&self, file_id: FileId) -> Parsed;
fn parse(&self, file_id: FileId) -> QueryResult<Parsed>;
fn lint_syntax(&self, file_id: FileId) -> Diagnostics;
fn lint_syntax(&self, file_id: FileId) -> QueryResult<Diagnostics>;
}
pub trait SemanticDb: SourceDb {
// queries
fn resolve_module(&self, name: ModuleName) -> Option<Module>;
fn resolve_module(&self, name: ModuleName) -> QueryResult<Option<Module>>;
fn file_to_module(&self, file_id: FileId) -> Option<Module>;
fn file_to_module(&self, file_id: FileId) -> QueryResult<Option<Module>>;
fn path_to_module(&self, path: &Path) -> Option<Module>;
fn path_to_module(&self, path: &Path) -> QueryResult<Option<Module>>;
fn symbol_table(&self, file_id: FileId) -> Arc<SymbolTable>;
fn symbol_table(&self, file_id: FileId) -> QueryResult<Arc<SymbolTable>>;
fn infer_symbol_type(&self, file_id: FileId, symbol_id: SymbolId) -> Type;
fn infer_symbol_type(&self, file_id: FileId, symbol_id: SymbolId) -> QueryResult<Type>;
fn lint_semantic(&self, file_id: FileId) -> Diagnostics;
fn lint_semantic(&self, file_id: FileId) -> QueryResult<Diagnostics>;
// mutations
@ -60,32 +148,15 @@ pub struct SemanticJar {
pub lint_semantic: LintSemanticStorage,
}
/// Gives access to a specific jar in the database.
///
/// Nope, the terminology isn't borrowed from Java but from Salsa <https://salsa-rs.github.io/salsa/>,
/// which is an analogy to storing the salsa in different jars.
///
/// The basic idea is that each crate can define its own jar and the jars can be combined into a single
/// database in the top-level crate. Each crate also defines its own `Database` trait. The combination of
/// the `Database` trait and the jar allows writing queries in isolation without having to know how they get composed at the upper levels.
///
/// Salsa further defines a `HasIngredient` trait which slices the jar to a specific storage (e.g. a specific cache).
/// We don't need this just yet because we write our queries by hand. We may want a similar trait if we decide
/// to use a macro to generate the queries.
pub trait HasJar<T> {
/// Gives a read-only reference to the jar.
fn jar(&self) -> &T;
/// Gives a mutable reference to the jar.
fn jar_mut(&mut self) -> &mut T;
}
#[cfg(test)]
pub(crate) mod tests {
use std::path::Path;
use std::sync::Arc;
use crate::db::{HasJar, SourceDb, SourceJar};
use crate::db::{
Database, DbRuntime, HasJar, HasJars, JarsStorage, ParallelDatabase, QueryResult, Snapshot,
SourceDb, SourceJar,
};
use crate::files::{FileId, Files};
use crate::lint::{lint_semantic, lint_syntax, Diagnostics};
use crate::module::{
@ -104,27 +175,26 @@ pub(crate) mod tests {
#[derive(Debug, Default)]
pub(crate) struct TestDb {
files: Files,
source: SourceJar,
semantic: SemanticJar,
jars: JarsStorage<Self>,
}
impl HasJar<SourceJar> for TestDb {
fn jar(&self) -> &SourceJar {
&self.source
fn jar(&self) -> QueryResult<&SourceJar> {
Ok(&self.jars()?.0)
}
fn jar_mut(&mut self) -> &mut SourceJar {
&mut self.source
&mut self.jars_mut().0
}
}
impl HasJar<SemanticJar> for TestDb {
fn jar(&self) -> &SemanticJar {
&self.semantic
fn jar(&self) -> QueryResult<&SemanticJar> {
Ok(&self.jars()?.1)
}
fn jar_mut(&mut self) -> &mut SemanticJar {
&mut self.semantic
&mut self.jars_mut().1
}
}
@ -137,41 +207,41 @@ pub(crate) mod tests {
self.files.path(file_id)
}
fn source(&self, file_id: FileId) -> Source {
fn source(&self, file_id: FileId) -> QueryResult<Source> {
source_text(self, file_id)
}
fn parse(&self, file_id: FileId) -> Parsed {
fn parse(&self, file_id: FileId) -> QueryResult<Parsed> {
parse(self, file_id)
}
fn lint_syntax(&self, file_id: FileId) -> Diagnostics {
fn lint_syntax(&self, file_id: FileId) -> QueryResult<Diagnostics> {
lint_syntax(self, file_id)
}
}
impl SemanticDb for TestDb {
fn resolve_module(&self, name: ModuleName) -> Option<Module> {
fn resolve_module(&self, name: ModuleName) -> QueryResult<Option<Module>> {
resolve_module(self, name)
}
fn file_to_module(&self, file_id: FileId) -> Option<Module> {
fn file_to_module(&self, file_id: FileId) -> QueryResult<Option<Module>> {
file_to_module(self, file_id)
}
fn path_to_module(&self, path: &Path) -> Option<Module> {
fn path_to_module(&self, path: &Path) -> QueryResult<Option<Module>> {
path_to_module(self, path)
}
fn symbol_table(&self, file_id: FileId) -> Arc<SymbolTable> {
fn symbol_table(&self, file_id: FileId) -> QueryResult<Arc<SymbolTable>> {
symbol_table(self, file_id)
}
fn infer_symbol_type(&self, file_id: FileId, symbol_id: SymbolId) -> Type {
fn infer_symbol_type(&self, file_id: FileId, symbol_id: SymbolId) -> QueryResult<Type> {
infer_symbol_type(self, file_id, symbol_id)
}
fn lint_semantic(&self, file_id: FileId) -> Diagnostics {
fn lint_semantic(&self, file_id: FileId) -> QueryResult<Diagnostics> {
lint_semantic(self, file_id)
}
@ -183,4 +253,35 @@ pub(crate) mod tests {
set_module_search_paths(self, paths);
}
}
impl HasJars for TestDb {
type Jars = (SourceJar, SemanticJar);
fn jars(&self) -> QueryResult<&Self::Jars> {
self.jars.jars()
}
fn jars_mut(&mut self) -> &mut Self::Jars {
self.jars.jars_mut()
}
}
impl Database for TestDb {
fn runtime(&self) -> &DbRuntime {
self.jars.runtime()
}
fn runtime_mut(&mut self) -> &mut DbRuntime {
self.jars.runtime_mut()
}
}
impl ParallelDatabase for TestDb {
fn snapshot(&self) -> Snapshot<Self> {
Snapshot::new(Self {
files: self.files.clone(),
jars: self.jars.snapshot(),
})
}
}
}


@ -0,0 +1,37 @@
use crate::db::query::QueryResult;
/// Gives access to a specific jar in the database.
///
/// Nope, the terminology isn't borrowed from Java but from Salsa <https://salsa-rs.github.io/salsa/>,
/// which is an analogy to storing the salsa in different jars.
///
/// The basic idea is that each crate can define its own jar and the jars can be combined into a single
/// database in the top-level crate. Each crate also defines its own `Database` trait. The combination of
/// the `Database` trait and the jar allows writing queries in isolation without having to know how they get composed at the upper levels.
///
/// Salsa further defines a `HasIngredient` trait which slices the jar to a specific storage (e.g. a specific cache).
/// We don't need this just yet because we write our queries by hand. We may want a similar trait if we decide
/// to use a macro to generate the queries.
pub trait HasJar<T> {
/// Gives a read-only reference to the jar.
fn jar(&self) -> QueryResult<&T>;
/// Gives a mutable reference to the jar.
fn jar_mut(&mut self) -> &mut T;
}
/// Gives access to the jars in a database.
pub trait HasJars {
/// A type storing the jars.
///
/// Most commonly, this is a tuple where each jar is a tuple element.
type Jars: Default;
/// Gives access to the underlying jars, first checking whether the queries have been cancelled.
///
/// Returns `Err(QueryError::Cancelled)` if the queries have been cancelled.
fn jars(&self) -> QueryResult<&Self::Jars>;
/// Gives mutable access to the underlying jars.
fn jars_mut(&mut self) -> &mut Self::Jars;
}


@ -0,0 +1,20 @@
use std::fmt::{Display, Formatter};
/// Reason why a db query operation failed.
#[derive(Debug, Clone, Copy)]
pub enum QueryError {
/// The query was cancelled, either because the DB was mutated or because the host requested cancellation (e.g. on a file change or when pressing CTRL+C).
Cancelled,
}
impl Display for QueryError {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
QueryError::Cancelled => f.write_str("query was cancelled"),
}
}
}
impl std::error::Error for QueryError {}
pub type QueryResult<T> = Result<T, QueryError>;
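
At the top of the call stack, cancellation is typically swallowed rather than propagated, because a fresh check run is scheduled after the mutation anyway. A hedged sketch (the `handle_result` and `report` functions are hypothetical):

```rust
// Hypothetical boundary handler: cancellation is an expected outcome rather
// than a failure, so it is silently discarded.
fn handle_result(result: QueryResult<Vec<String>>) {
    match result {
        Ok(diagnostics) => report(&diagnostics),
        Err(QueryError::Cancelled) => {} // a new check run will be scheduled
    }
}

fn report(diagnostics: &[String]) {
    for diagnostic in diagnostics {
        eprintln!("{diagnostic}");
    }
}
```
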


@ -0,0 +1,41 @@
use crate::cancellation::CancellationTokenSource;
use crate::db::{QueryError, QueryResult};
/// Holds the jar-agnostic state of the database.
#[derive(Debug, Default)]
pub struct DbRuntime {
/// The cancellation token source used to signal other workers that the queries should be aborted and
/// exit at the next possible point.
cancellation_token: CancellationTokenSource,
}
impl DbRuntime {
pub(super) fn snapshot(&self) -> Self {
Self {
cancellation_token: self.cancellation_token.clone(),
}
}
/// Cancels the pending queries of other workers. The current worker cannot have any pending
/// queries because we're holding a mutable reference to the runtime.
pub(super) fn cancel_other_workers(&mut self) {
self.cancellation_token.cancel();
// Set a new cancellation token so that we're in a non-cancelled state again when running the next
// query.
self.cancellation_token = CancellationTokenSource::default();
}
/// Returns `Ok` if the queries have not been cancelled and `Err(QueryError::Cancelled)` otherwise.
pub(super) fn cancelled(&self) -> QueryResult<()> {
if self.cancellation_token.is_cancelled() {
Err(QueryError::Cancelled)
} else {
Ok(())
}
}
/// Returns `true` if the queries have been cancelled.
pub(super) fn is_cancelled(&self) -> bool {
self.cancellation_token.is_cancelled()
}
}
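
To illustrate the token lifecycle (runnable from inside the `db` module, e.g. in a unit test, since these methods are `pub(super)`):

```rust
// Snapshots share the current token; cancelling trips that shared token and
// installs a fresh, uncancelled one on the owner.
fn token_lifecycle() {
    let mut runtime = DbRuntime::default();
    let snapshot = runtime.snapshot();
    runtime.cancel_other_workers();
    assert!(snapshot.is_cancelled()); // the snapshot observes the cancellation
    assert!(!runtime.is_cancelled()); // the owner starts the next query uncancelled
}
```
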


@ -0,0 +1,117 @@
use std::fmt::Formatter;
use std::sync::Arc;
use crossbeam::sync::WaitGroup;
use crate::db::query::QueryResult;
use crate::db::runtime::DbRuntime;
use crate::db::{HasJars, ParallelDatabase};
/// Stores the jars of a database and the state for each worker.
///
/// Today, all state is shared across all workers, but we may want to store per-worker data in the future.
pub struct JarsStorage<T>
where
T: HasJars + Sized,
{
// It's important that `jars_wait_group` is declared after `jars` to ensure that `jars` is dropped first.
// See https://doc.rust-lang.org/reference/destructors.html
/// Stores the jars of the database.
jars: Arc<T::Jars>,
/// Used to count the references to `jars`. Allows implementing `jars_mut` without having to clone `jars`.
jars_wait_group: WaitGroup,
/// The jar-agnostic state.
runtime: DbRuntime,
}
impl<Db> JarsStorage<Db>
where
Db: HasJars,
{
pub(super) fn new() -> Self {
Self {
jars: Arc::new(Db::Jars::default()),
jars_wait_group: WaitGroup::default(),
runtime: DbRuntime::default(),
}
}
/// Creates a snapshot of the jars.
///
/// Creating the snapshot is cheap because it doesn't clone the jars; it only increments a reference count.
#[must_use]
pub fn snapshot(&self) -> JarsStorage<Db>
where
Db: ParallelDatabase,
{
Self {
jars: self.jars.clone(),
jars_wait_group: self.jars_wait_group.clone(),
runtime: self.runtime.snapshot(),
}
}
pub(crate) fn jars(&self) -> QueryResult<&Db::Jars> {
self.runtime.cancelled()?;
Ok(&self.jars)
}
/// Returns a mutable reference to the jars without cloning their content.
///
/// The method cancels any pending queries of other workers and waits for them to complete so that
/// this instance is the only instance holding a reference to the jars.
pub(crate) fn jars_mut(&mut self) -> &mut Db::Jars {
// We have a mutable ref here, so no more workers can be spawned between calling this function and taking the mut ref below.
self.cancel_other_workers();
// Now all other references to `self.jars` should have been released. We can now safely return a mutable reference
// to the Arc's content.
let jars =
Arc::get_mut(&mut self.jars).expect("All references to jars should have been released");
jars
}
pub(crate) fn runtime(&self) -> &DbRuntime {
&self.runtime
}
pub(crate) fn runtime_mut(&mut self) -> &mut DbRuntime {
// Note: This method may need to use a similar trick to `jars_mut` if `DbRuntime` is ever to store data that is shared between workers.
&mut self.runtime
}
#[tracing::instrument(level = "trace", skip(self))]
fn cancel_other_workers(&mut self) {
self.runtime.cancel_other_workers();
// Wait for all other workers to complete.
let existing_wait = std::mem::take(&mut self.jars_wait_group);
existing_wait.wait();
}
}
impl<Db> Default for JarsStorage<Db>
where
Db: HasJars,
{
fn default() -> Self {
Self::new()
}
}
impl<T> std::fmt::Debug for JarsStorage<T>
where
T: HasJars,
<T as HasJars>::Jars: std::fmt::Debug,
{
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SharedStorage")
.field("jars", &self.jars)
.field("jars_wait_group", &self.jars_wait_group)
.field("runtime", &self.runtime)
.finish()
}
}
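
The `jars_mut` trick in isolation, as a hedged stand-alone sketch (using `Vec<u32>` in place of real jars): every snapshot holds a clone of the wait group alongside its `Arc` clone and drops both together, so waiting on the group guarantees that `Arc::get_mut` succeeds afterwards.

```rust
use std::sync::Arc;
use crossbeam::sync::WaitGroup;

// Invariant: whoever clones `data` also clones `guard` and drops both
// together. Waiting on the taken guard therefore blocks until every other
// `Arc` reference is gone, making `Arc::get_mut` infallible below.
fn mutate_exclusively(data: &mut Arc<Vec<u32>>, guard: &mut WaitGroup) -> &mut Vec<u32> {
    std::mem::take(guard).wait(); // blocks until all cloned guards are dropped
    Arc::get_mut(data).expect("all snapshots have been dropped")
}
```
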


@ -27,7 +27,7 @@ pub(crate) type FxDashMap<K, V> = dashmap::DashMap<K, V, BuildHasherDefault<FxHa
pub(crate) type FxDashSet<V> = dashmap::DashSet<V, BuildHasherDefault<FxHasher>>;
pub(crate) type FxIndexSet<V> = indexmap::set::IndexSet<V, BuildHasherDefault<FxHasher>>;
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct Workspace {
/// TODO this should be a resolved path. We should probably use a newtype wrapper that guarantees that
/// PATH is a UTF-8 path and is normalized.


@ -1,12 +1,15 @@
use std::cell::RefCell;
use std::ops::{Deref, DerefMut};
use std::sync::Arc;
use std::time::Duration;
use ruff_python_ast::visitor::Visitor;
use ruff_python_ast::{ModModule, StringLiteral};
use crate::cache::KeyValueCache;
use crate::db::{HasJar, SemanticDb, SemanticJar, SourceDb, SourceJar};
use crate::db::{
HasJar, ParallelDatabase, QueryResult, SemanticDb, SemanticJar, SourceDb, SourceJar,
};
use crate::files::FileId;
use crate::parse::Parsed;
use crate::source::Source;
@ -14,19 +17,28 @@ use crate::symbols::{Definition, SymbolId, SymbolTable};
use crate::types::Type;
#[tracing::instrument(level = "debug", skip(db))]
pub(crate) fn lint_syntax<Db>(db: &Db, file_id: FileId) -> Diagnostics
pub(crate) fn lint_syntax<Db>(db: &Db, file_id: FileId) -> QueryResult<Diagnostics>
where
Db: SourceDb + HasJar<SourceJar>,
Db: SourceDb + HasJar<SourceJar> + ParallelDatabase,
{
let storage = &db.jar().lint_syntax;
let storage = &db.jar()?.lint_syntax;
#[allow(clippy::print_stdout)]
if std::env::var("RED_KNOT_SLOW_LINT").is_ok() {
for i in 0..10 {
db.cancelled()?;
println!("RED_KNOT_SLOW_LINT is set, sleeping for {i}/10 seconds");
std::thread::sleep(Duration::from_secs(1));
}
}
storage.get(&file_id, |file_id| {
let mut diagnostics = Vec::new();
let source = db.source(*file_id);
let source = db.source(*file_id)?;
lint_lines(source.text(), &mut diagnostics);
let parsed = db.parse(*file_id);
let parsed = db.parse(*file_id)?;
if parsed.errors().is_empty() {
let ast = parsed.ast();
@ -41,7 +53,7 @@ where
diagnostics.extend(parsed.errors().iter().map(std::string::ToString::to_string));
}
Diagnostics::from(diagnostics)
Ok(Diagnostics::from(diagnostics))
})
}
@ -63,16 +75,16 @@ fn lint_lines(source: &str, diagnostics: &mut Vec<String>) {
}
#[tracing::instrument(level = "debug", skip(db))]
pub(crate) fn lint_semantic<Db>(db: &Db, file_id: FileId) -> Diagnostics
pub(crate) fn lint_semantic<Db>(db: &Db, file_id: FileId) -> QueryResult<Diagnostics>
where
Db: SemanticDb + HasJar<SemanticJar>,
{
let storage = &db.jar().lint_semantic;
let storage = &db.jar()?.lint_semantic;
storage.get(&file_id, |file_id| {
let source = db.source(*file_id);
let parsed = db.parse(*file_id);
let symbols = db.symbol_table(*file_id);
let source = db.source(*file_id)?;
let parsed = db.parse(*file_id)?;
let symbols = db.symbol_table(*file_id)?;
let context = SemanticLintContext {
file_id: *file_id,
@ -83,25 +95,25 @@ where
diagnostics: RefCell::new(Vec::new()),
};
lint_unresolved_imports(&context);
lint_unresolved_imports(&context)?;
Diagnostics::from(context.diagnostics.take())
Ok(Diagnostics::from(context.diagnostics.take()))
})
}
fn lint_unresolved_imports(context: &SemanticLintContext) {
fn lint_unresolved_imports(context: &SemanticLintContext) -> QueryResult<()> {
// TODO: Consider iterating over the dependencies (imports) only instead of all definitions.
for (symbol, definition) in context.symbols().all_definitions() {
match definition {
Definition::Import(import) => {
let ty = context.eval_symbol(symbol);
let ty = context.infer_symbol_type(symbol)?;
if ty.is_unknown() {
context.push_diagnostic(format!("Unresolved module {}", import.module));
}
}
Definition::ImportFrom(import) => {
let ty = context.eval_symbol(symbol);
let ty = context.infer_symbol_type(symbol)?;
if ty.is_unknown() {
let module_name = import.module().map(Deref::deref).unwrap_or_default();
@ -126,6 +138,8 @@ fn lint_unresolved_imports(context: &SemanticLintContext) {
_ => {}
}
}
Ok(())
}
pub struct SemanticLintContext<'a> {
@ -154,7 +168,7 @@ impl<'a> SemanticLintContext<'a> {
&self.symbols
}
pub fn eval_symbol(&self, symbol_id: SymbolId) -> Type {
pub fn infer_symbol_type(&self, symbol_id: SymbolId) -> QueryResult<Type> {
self.db.infer_symbol_type(self.file_id, symbol_id)
}


@ -4,6 +4,7 @@ use std::collections::hash_map::Entry;
use std::path::Path;
use std::sync::Mutex;
use crossbeam::channel as crossbeam_channel;
use rustc_hash::FxHashMap;
use tracing::subscriber::Interest;
use tracing::{Level, Metadata};
@ -12,11 +13,10 @@ use tracing_subscriber::layer::{Context, Filter, SubscriberExt};
use tracing_subscriber::{Layer, Registry};
use tracing_tree::time::Uptime;
use red_knot::cancellation::CancellationTokenSource;
use red_knot::db::{HasJar, SourceDb, SourceJar};
use red_knot::db::{HasJar, ParallelDatabase, QueryError, SemanticDb, SourceDb, SourceJar};
use red_knot::files::FileId;
use red_knot::module::{ModuleSearchPath, ModuleSearchPathKind};
use red_knot::program::check::{CheckError, RayonCheckScheduler};
use red_knot::program::check::RayonCheckScheduler;
use red_knot::program::{FileChange, FileChangeKind, Program};
use red_knot::watch::FileWatcher;
use red_knot::Workspace;
@ -51,7 +51,8 @@ fn main() -> anyhow::Result<()> {
workspace.root().to_path_buf(),
ModuleSearchPathKind::FirstParty,
);
let mut program = Program::new(workspace, vec![workspace_search_path]);
let mut program = Program::new(workspace);
program.set_module_search_paths(vec![workspace_search_path]);
let entry_id = program.file_id(entry_point);
program.workspace_mut().open_file(entry_id);
@ -82,7 +83,7 @@ fn main() -> anyhow::Result<()> {
main_loop.run(&mut program);
let source_jar: &SourceJar = program.jar();
let source_jar: &SourceJar = program.jar().unwrap();
dbg!(source_jar.parsed.statistics());
dbg!(source_jar.sources.statistics());
@ -101,10 +102,9 @@ impl MainLoop {
let (main_loop_sender, main_loop_receiver) = crossbeam_channel::bounded(1);
let mut orchestrator = Orchestrator {
pending_analysis: None,
receiver: orchestrator_receiver,
sender: main_loop_sender.clone(),
aggregated_changes: AggregatedChanges::default(),
revision: 0,
};
std::thread::spawn(move || {
@ -137,34 +137,32 @@ impl MainLoop {
tracing::trace!("Main Loop: Tick");
match message {
MainLoopMessage::CheckProgram => {
// Remove mutability from program.
let program = &*program;
let run_cancellation_token_source = CancellationTokenSource::new();
let run_cancellation_token = run_cancellation_token_source.token();
let sender = &self.orchestrator_sender;
MainLoopMessage::CheckProgram { revision } => {
let program = program.snapshot();
let sender = self.orchestrator_sender.clone();
sender
.send(OrchestratorMessage::CheckProgramStarted {
cancellation_token: run_cancellation_token_source,
})
.unwrap();
// Spawn a new task that checks the program. This needs to be done in a separate thread
// to prevent blocking the main loop here.
rayon::spawn(move || {
rayon::in_place_scope(|scope| {
let scheduler = RayonCheckScheduler::new(&program, scope);
rayon::in_place_scope(|scope| {
let scheduler = RayonCheckScheduler::new(program, scope);
let result = program.check(&scheduler, run_cancellation_token);
match result {
Ok(result) => sender
.send(OrchestratorMessage::CheckProgramCompleted(result))
.unwrap(),
Err(CheckError::Cancelled) => sender
.send(OrchestratorMessage::CheckProgramCancelled)
.unwrap(),
}
match program.check(&scheduler) {
Ok(result) => {
sender
.send(OrchestratorMessage::CheckProgramCompleted {
diagnostics: result,
revision,
})
.unwrap();
}
Err(QueryError::Cancelled) => {}
}
});
});
}
MainLoopMessage::ApplyChanges(changes) => {
// Automatically cancels any pending queries and waits for them to complete.
program.apply_changes(changes.iter());
}
MainLoopMessage::CheckCompleted(diagnostics) => {
@ -211,13 +209,11 @@ impl MainLoopCancellationToken {
}
struct Orchestrator {
aggregated_changes: AggregatedChanges,
pending_analysis: Option<PendingAnalysisState>,
/// Sends messages to the main loop.
sender: crossbeam_channel::Sender<MainLoopMessage>,
/// Receives messages from the main loop.
receiver: crossbeam_channel::Receiver<OrchestratorMessage>,
revision: usize,
}
impl Orchestrator {
@ -225,51 +221,33 @@ impl Orchestrator {
while let Ok(message) = self.receiver.recv() {
match message {
OrchestratorMessage::Run => {
self.pending_analysis = None;
self.sender.send(MainLoopMessage::CheckProgram).unwrap();
}
OrchestratorMessage::CheckProgramStarted { cancellation_token } => {
debug_assert!(self.pending_analysis.is_none());
self.pending_analysis = Some(PendingAnalysisState { cancellation_token });
}
OrchestratorMessage::CheckProgramCompleted(diagnostics) => {
self.pending_analysis
.take()
.expect("Expected a pending analysis.");
self.sender
.send(MainLoopMessage::CheckCompleted(diagnostics))
.send(MainLoopMessage::CheckProgram {
revision: self.revision,
})
.unwrap();
}
OrchestratorMessage::CheckProgramCancelled => {
self.pending_analysis
.take()
.expect("Expected a pending analysis.");
self.debounce_changes();
OrchestratorMessage::CheckProgramCompleted {
diagnostics,
revision,
} => {
// Only take the diagnostics if they are for the latest revision.
if self.revision == revision {
self.sender
.send(MainLoopMessage::CheckCompleted(diagnostics))
.unwrap();
} else {
tracing::debug!("Discarding diagnostics for outdated revision {revision} (current: {}).", self.revision);
}
}
OrchestratorMessage::FileChanges(changes) => {
// Request cancellation, but wait until all analysis tasks have completed to
// avoid stale messages in the next main loop.
let pending = if let Some(pending_state) = self.pending_analysis.as_ref() {
pending_state.cancellation_token.cancel();
true
} else {
false
};
self.aggregated_changes.extend(changes);
// If there are no pending analysis tasks, apply the file changes. Otherwise
// keep running until all file checks have completed.
if !pending {
self.debounce_changes();
}
self.revision += 1;
self.debounce_changes(changes);
}
OrchestratorMessage::Shutdown => {
return self.shutdown();
@ -278,8 +256,9 @@ impl Orchestrator {
}
}
fn debounce_changes(&mut self) {
debug_assert!(self.pending_analysis.is_none());
fn debounce_changes(&self, changes: Vec<FileChange>) {
let mut aggregated_changes = AggregatedChanges::default();
aggregated_changes.extend(changes);
loop {
// Consume possibly incoming file change messages before running a new analysis, but don't wait for more than 100ms.
@ -290,10 +269,12 @@ impl Orchestrator {
return self.shutdown();
}
Ok(OrchestratorMessage::FileChanges(file_changes)) => {
self.aggregated_changes.extend(file_changes);
aggregated_changes.extend(file_changes);
}
Ok(OrchestratorMessage::CheckProgramStarted {..}| OrchestratorMessage::CheckProgramCompleted(_) | OrchestratorMessage::CheckProgramCancelled) => unreachable!("No program check should be running while debouncing changes."),
Ok(OrchestratorMessage::CheckProgramCompleted { .. })=> {
// disregard any outdated completion message.
}
Ok(OrchestratorMessage::Run) => unreachable!("The orchestrator is already running."),
Err(_) => {
@ -302,10 +283,10 @@ impl Orchestrator {
}
}
},
default(std::time::Duration::from_millis(100)) => {
// No more file changes after 100 ms, send the changes and schedule a new analysis
self.sender.send(MainLoopMessage::ApplyChanges(std::mem::take(&mut self.aggregated_changes))).unwrap();
self.sender.send(MainLoopMessage::CheckProgram).unwrap();
default(std::time::Duration::from_millis(10)) => {
// No more file changes after 10 ms, send the changes and schedule a new analysis
self.sender.send(MainLoopMessage::ApplyChanges(aggregated_changes)).unwrap();
self.sender.send(MainLoopMessage::CheckProgram { revision: self.revision}).unwrap();
return;
}
}
@ -318,15 +299,10 @@ impl Orchestrator {
}
}
#[derive(Debug)]
struct PendingAnalysisState {
cancellation_token: CancellationTokenSource,
}
/// Message sent from the orchestrator to the main loop.
#[derive(Debug)]
enum MainLoopMessage {
CheckProgram,
CheckProgram { revision: usize },
CheckCompleted(Vec<String>),
ApplyChanges(AggregatedChanges),
Exit,
@ -337,11 +313,10 @@ enum OrchestratorMessage {
Run,
Shutdown,
CheckProgramStarted {
cancellation_token: CancellationTokenSource,
CheckProgramCompleted {
diagnostics: Vec<String>,
revision: usize,
},
CheckProgramCompleted(Vec<String>),
CheckProgramCancelled,
FileChanges(Vec<FileChange>),
}
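
The revision counter implements a simple last-writer-wins protocol between the orchestrator and the check runs; condensed into a hedged sketch (the `Revisions` struct is hypothetical, mirroring the logic above):

```rust
// Every file change bumps the revision; a completed check run carries the
// revision it was started with, and stale results are dropped.
struct Revisions {
    current: usize,
}

impl Revisions {
    fn on_file_change(&mut self) {
        self.current += 1;
    }

    fn accept(&self, started_at: usize, diagnostics: Vec<String>) -> Option<Vec<String>> {
        (started_at == self.current).then_some(diagnostics)
    }
}
```
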


@ -7,7 +7,7 @@ use std::sync::Arc;
use dashmap::mapref::entry::Entry;
use smol_str::SmolStr;
use crate::db::{HasJar, SemanticDb, SemanticJar};
use crate::db::{HasJar, QueryResult, SemanticDb, SemanticJar};
use crate::files::FileId;
use crate::symbols::Dependency;
use crate::FxDashMap;
@ -17,44 +17,48 @@ use crate::FxDashMap;
pub struct Module(u32);
impl Module {
pub fn name<Db>(&self, db: &Db) -> ModuleName
pub fn name<Db>(&self, db: &Db) -> QueryResult<ModuleName>
where
Db: HasJar<SemanticJar>,
{
let modules = &db.jar().module_resolver;
let modules = &db.jar()?.module_resolver;
modules.modules.get(self).unwrap().name.clone()
Ok(modules.modules.get(self).unwrap().name.clone())
}
pub fn path<Db>(&self, db: &Db) -> ModulePath
pub fn path<Db>(&self, db: &Db) -> QueryResult<ModulePath>
where
Db: HasJar<SemanticJar>,
{
let modules = &db.jar().module_resolver;
let modules = &db.jar()?.module_resolver;
modules.modules.get(self).unwrap().path.clone()
Ok(modules.modules.get(self).unwrap().path.clone())
}
pub fn kind<Db>(&self, db: &Db) -> ModuleKind
pub fn kind<Db>(&self, db: &Db) -> QueryResult<ModuleKind>
where
Db: HasJar<SemanticJar>,
{
let modules = &db.jar().module_resolver;
let modules = &db.jar()?.module_resolver;
modules.modules.get(self).unwrap().kind
Ok(modules.modules.get(self).unwrap().kind)
}
pub fn resolve_dependency<Db>(&self, db: &Db, dependency: &Dependency) -> Option<ModuleName>
pub fn resolve_dependency<Db>(
&self,
db: &Db,
dependency: &Dependency,
) -> QueryResult<Option<ModuleName>>
where
Db: HasJar<SemanticJar>,
{
let (level, module) = match dependency {
Dependency::Module(module) => return Some(module.clone()),
Dependency::Module(module) => return Ok(Some(module.clone())),
Dependency::Relative { level, module } => (*level, module.as_deref()),
};
let name = self.name(db);
let kind = self.kind(db);
let name = self.name(db)?;
let kind = self.kind(db)?;
let mut components = name.components().peekable();
@ -67,7 +71,9 @@ impl Module {
// Skip over the relative parts.
for _ in start..level.get() {
components.next_back()?;
if components.next_back().is_none() {
return Ok(None);
}
}
let mut name = String::new();
@ -80,11 +86,11 @@ impl Module {
name.push_str(part);
}
if name.is_empty() {
Ok(if name.is_empty() {
None
} else {
Some(ModuleName(SmolStr::new(name)))
}
})
}
}
@ -238,20 +244,25 @@ pub struct ModuleData {
/// TODO: This would not work with Salsa because `ModuleName` isn't an ingredient and, therefore, cannot be used as part of a query.
/// For this to work with salsa, it would be necessary to intern all `ModuleName`s.
#[tracing::instrument(level = "debug", skip(db))]
pub fn resolve_module<Db>(db: &Db, name: ModuleName) -> Option<Module>
pub fn resolve_module<Db>(db: &Db, name: ModuleName) -> QueryResult<Option<Module>>
where
Db: SemanticDb + HasJar<SemanticJar>,
{
let jar = db.jar();
let modules = &jar.module_resolver;
let modules = &jar?.module_resolver;
let entry = modules.by_name.entry(name.clone());
match entry {
Entry::Occupied(entry) => Some(*entry.get()),
Entry::Occupied(entry) => Ok(Some(*entry.get())),
Entry::Vacant(entry) => {
let (root_path, absolute_path, kind) = resolve_name(&name, &modules.search_paths)?;
let normalized = absolute_path.canonicalize().ok()?;
let Some((root_path, absolute_path, kind)) = resolve_name(&name, &modules.search_paths)
else {
return Ok(None);
};
let Ok(normalized) = absolute_path.canonicalize() else {
return Ok(None);
};
let file_id = db.file_id(&normalized);
let path = ModulePath::new(root_path.clone(), file_id);
@ -277,7 +288,7 @@ where
entry.insert_entry(id);
Some(id)
Ok(Some(id))
}
}
}
@ -286,7 +297,7 @@ where
///
/// Returns `None` if the file is not a module in `sys.path`.
#[tracing::instrument(level = "debug", skip(db))]
pub fn file_to_module<Db>(db: &Db, file: FileId) -> Option<Module>
pub fn file_to_module<Db>(db: &Db, file: FileId) -> QueryResult<Option<Module>>
where
Db: SemanticDb + HasJar<SemanticJar>,
{
@ -298,34 +309,42 @@ where
///
/// Returns `None` if the path is not a module in `sys.path`.
#[tracing::instrument(level = "debug", skip(db))]
pub fn path_to_module<Db>(db: &Db, path: &Path) -> Option<Module>
pub fn path_to_module<Db>(db: &Db, path: &Path) -> QueryResult<Option<Module>>
where
Db: SemanticDb + HasJar<SemanticJar>,
{
let jar = db.jar();
let jar = db.jar()?;
let modules = &jar.module_resolver;
debug_assert!(path.is_absolute());
if let Some(existing) = modules.by_path.get(path) {
return Some(*existing);
return Ok(Some(*existing));
}
let (root_path, relative_path) = modules.search_paths.iter().find_map(|root| {
let Some((root_path, relative_path)) = modules.search_paths.iter().find_map(|root| {
let relative_path = path.strip_prefix(root.path()).ok()?;
Some((root.clone(), relative_path))
})?;
}) else {
return Ok(None);
};
let module_name = ModuleName::from_relative_path(relative_path)?;
let Some(module_name) = ModuleName::from_relative_path(relative_path) else {
return Ok(None);
};
// Resolve the module name to see if Python would resolve the name to the same path.
// If it doesn't, then that means that multiple modules have the same name in different
// root paths, but the module corresponding to the passed path is in a lower-priority path,
// in which case we ignore it.
let module_id = resolve_module(db, module_name)?;
let module_path = module_id.path(db);
let Some(module_id) = resolve_module(db, module_name)? else {
return Ok(None);
};
let module_path = module_id.path(db)?;
if module_path.root() == &root_path {
let normalized = path.canonicalize().ok()?;
let Ok(normalized) = path.canonicalize() else {
return Ok(None);
};
let interned_normalized = db.file_id(&normalized);
if interned_normalized != module_path.file() {
@ -336,15 +355,15 @@ where
// ```
// The module name of `src/foo.py` is `foo`, but the module loaded by Python is `src/foo/__init__.py`.
// That means we need to ignore `src/foo.py` even though it resolves to the same module name.
return None;
return Ok(None);
}
// Path has been inserted by `resolve_module`
Some(module_id)
Ok(Some(module_id))
} else {
// This path is for a module with the same name but in a module search path with a lower priority.
// Ignore it.
None
Ok(None)
}
}
@ -378,7 +397,7 @@ where
// TODO This needs tests
// Note: Intentionally bypass caching here. Module should not be in the cache yet.
let module = path_to_module(db, path)?;
let module = path_to_module(db, path).ok()??;
// The code below is to handle the addition of `__init__.py` files.
// When an `__init__.py` file is added, we need to remove all modules that are part of the same package.
@ -392,7 +411,7 @@ where
return Some((module, Vec::new()));
}
let Some(parent_name) = module.name(db).parent() else {
let Some(parent_name) = module.name(db).ok()?.parent() else {
return Some((module, Vec::new()));
};
@ -691,7 +710,7 @@ mod tests {
}
#[test]
fn first_party_module() -> std::io::Result<()> {
fn first_party_module() -> anyhow::Result<()> {
let TestCase {
db,
src,
@ -702,22 +721,22 @@ mod tests {
let foo_path = src.path().join("foo.py");
std::fs::write(&foo_path, "print('Hello, world!')")?;
let foo_module = db.resolve_module(ModuleName::new("foo")).unwrap();
let foo_module = db.resolve_module(ModuleName::new("foo"))?.unwrap();
assert_eq!(Some(foo_module), db.resolve_module(ModuleName::new("foo")));
assert_eq!(Some(foo_module), db.resolve_module(ModuleName::new("foo"))?);
assert_eq!(ModuleName::new("foo"), foo_module.name(&db));
assert_eq!(&src, foo_module.path(&db).root());
assert_eq!(ModuleKind::Module, foo_module.kind(&db));
assert_eq!(&foo_path, &*db.file_path(foo_module.path(&db).file()));
assert_eq!(ModuleName::new("foo"), foo_module.name(&db)?);
assert_eq!(&src, foo_module.path(&db)?.root());
assert_eq!(ModuleKind::Module, foo_module.kind(&db)?);
assert_eq!(&foo_path, &*db.file_path(foo_module.path(&db)?.file()));
assert_eq!(Some(foo_module), db.path_to_module(&foo_path));
assert_eq!(Some(foo_module), db.path_to_module(&foo_path)?);
Ok(())
}
#[test]
fn resolve_package() -> std::io::Result<()> {
fn resolve_package() -> anyhow::Result<()> {
let TestCase {
src,
db,
@ -730,22 +749,22 @@ mod tests {
std::fs::create_dir(&foo_dir)?;
std::fs::write(&foo_path, "print('Hello, world!')")?;
let foo_module = db.resolve_module(ModuleName::new("foo")).unwrap();
let foo_module = db.resolve_module(ModuleName::new("foo"))?.unwrap();
assert_eq!(ModuleName::new("foo"), foo_module.name(&db));
assert_eq!(&src, foo_module.path(&db).root());
assert_eq!(&foo_path, &*db.file_path(foo_module.path(&db).file()));
assert_eq!(ModuleName::new("foo"), foo_module.name(&db)?);
assert_eq!(&src, foo_module.path(&db)?.root());
assert_eq!(&foo_path, &*db.file_path(foo_module.path(&db)?.file()));
assert_eq!(Some(foo_module), db.path_to_module(&foo_path));
assert_eq!(Some(foo_module), db.path_to_module(&foo_path)?);
// Resolving by directory doesn't resolve to the init file.
assert_eq!(None, db.path_to_module(&foo_dir));
assert_eq!(None, db.path_to_module(&foo_dir)?);
Ok(())
}
#[test]
fn package_priority_over_module() -> std::io::Result<()> {
fn package_priority_over_module() -> anyhow::Result<()> {
let TestCase {
db,
temp_dir: _temp_dir,
@ -761,20 +780,20 @@ mod tests {
let foo_py = src.path().join("foo.py");
std::fs::write(&foo_py, "print('Hello, world!')")?;
let foo_module = db.resolve_module(ModuleName::new("foo")).unwrap();
let foo_module = db.resolve_module(ModuleName::new("foo"))?.unwrap();
assert_eq!(&src, foo_module.path(&db).root());
assert_eq!(&foo_init, &*db.file_path(foo_module.path(&db).file()));
assert_eq!(ModuleKind::Package, foo_module.kind(&db));
assert_eq!(&src, foo_module.path(&db)?.root());
assert_eq!(&foo_init, &*db.file_path(foo_module.path(&db)?.file()));
assert_eq!(ModuleKind::Package, foo_module.kind(&db)?);
assert_eq!(Some(foo_module), db.path_to_module(&foo_init));
assert_eq!(None, db.path_to_module(&foo_py));
assert_eq!(Some(foo_module), db.path_to_module(&foo_init)?);
assert_eq!(None, db.path_to_module(&foo_py)?);
Ok(())
}
#[test]
fn typing_stub_over_module() -> std::io::Result<()> {
fn typing_stub_over_module() -> anyhow::Result<()> {
let TestCase {
db,
src,
@ -787,19 +806,19 @@ mod tests {
std::fs::write(&foo_stub, "x: int")?;
std::fs::write(&foo_py, "print('Hello, world!')")?;
let foo = db.resolve_module(ModuleName::new("foo")).unwrap();
let foo = db.resolve_module(ModuleName::new("foo"))?.unwrap();
assert_eq!(&src, foo.path(&db).root());
assert_eq!(&foo_stub, &*db.file_path(foo.path(&db).file()));
assert_eq!(&src, foo.path(&db)?.root());
assert_eq!(&foo_stub, &*db.file_path(foo.path(&db)?.file()));
assert_eq!(Some(foo), db.path_to_module(&foo_stub));
assert_eq!(None, db.path_to_module(&foo_py));
assert_eq!(Some(foo), db.path_to_module(&foo_stub)?);
assert_eq!(None, db.path_to_module(&foo_py)?);
Ok(())
}
#[test]
fn sub_packages() -> std::io::Result<()> {
fn sub_packages() -> anyhow::Result<()> {
let TestCase {
db,
src,
@ -816,18 +835,18 @@ mod tests {
std::fs::write(bar.join("__init__.py"), "")?;
std::fs::write(&baz, "print('Hello, world!')")?;
let baz_module = db.resolve_module(ModuleName::new("foo.bar.baz")).unwrap();
let baz_module = db.resolve_module(ModuleName::new("foo.bar.baz"))?.unwrap();
assert_eq!(&src, baz_module.path(&db).root());
assert_eq!(&baz, &*db.file_path(baz_module.path(&db).file()));
assert_eq!(&src, baz_module.path(&db)?.root());
assert_eq!(&baz, &*db.file_path(baz_module.path(&db)?.file()));
assert_eq!(Some(baz_module), db.path_to_module(&baz));
assert_eq!(Some(baz_module), db.path_to_module(&baz)?);
Ok(())
}
#[test]
fn namespace_package() -> std::io::Result<()> {
fn namespace_package() -> anyhow::Result<()> {
let TestCase {
db,
temp_dir: _,
@ -863,21 +882,21 @@ mod tests {
std::fs::write(&two, "print('Hello, world!')")?;
let one_module = db
.resolve_module(ModuleName::new("parent.child.one"))
.resolve_module(ModuleName::new("parent.child.one"))?
.unwrap();
assert_eq!(Some(one_module), db.path_to_module(&one));
assert_eq!(Some(one_module), db.path_to_module(&one)?);
let two_module = db
.resolve_module(ModuleName::new("parent.child.two"))
.resolve_module(ModuleName::new("parent.child.two"))?
.unwrap();
assert_eq!(Some(two_module), db.path_to_module(&two));
assert_eq!(Some(two_module), db.path_to_module(&two)?);
Ok(())
}
#[test]
fn regular_package_in_namespace_package() -> std::io::Result<()> {
fn regular_package_in_namespace_package() -> anyhow::Result<()> {
let TestCase {
db,
temp_dir: _,
@ -914,17 +933,20 @@ mod tests {
std::fs::write(two, "print('Hello, world!')")?;
let one_module = db
.resolve_module(ModuleName::new("parent.child.one"))
.resolve_module(ModuleName::new("parent.child.one"))?
.unwrap();
assert_eq!(Some(one_module), db.path_to_module(&one));
assert_eq!(Some(one_module), db.path_to_module(&one)?);
assert_eq!(None, db.resolve_module(ModuleName::new("parent.child.two")));
assert_eq!(
None,
db.resolve_module(ModuleName::new("parent.child.two"))?
);
Ok(())
}
#[test]
fn module_search_path_priority() -> std::io::Result<()> {
fn module_search_path_priority() -> anyhow::Result<()> {
let TestCase {
db,
src,
@ -938,20 +960,20 @@ mod tests {
std::fs::write(&foo_src, "")?;
std::fs::write(&foo_site_packages, "")?;
let foo_module = db.resolve_module(ModuleName::new("foo")).unwrap();
let foo_module = db.resolve_module(ModuleName::new("foo"))?.unwrap();
assert_eq!(&src, foo_module.path(&db).root());
assert_eq!(&foo_src, &*db.file_path(foo_module.path(&db).file()));
assert_eq!(&src, foo_module.path(&db)?.root());
assert_eq!(&foo_src, &*db.file_path(foo_module.path(&db)?.file()));
assert_eq!(Some(foo_module), db.path_to_module(&foo_src));
assert_eq!(None, db.path_to_module(&foo_site_packages));
assert_eq!(Some(foo_module), db.path_to_module(&foo_src)?);
assert_eq!(None, db.path_to_module(&foo_site_packages)?);
Ok(())
}
#[test]
#[cfg(target_family = "unix")]
fn symlink() -> std::io::Result<()> {
fn symlink() -> anyhow::Result<()> {
let TestCase {
db,
src,
@ -965,28 +987,28 @@ mod tests {
std::fs::write(&foo, "")?;
std::os::unix::fs::symlink(&foo, &bar)?;
let foo_module = db.resolve_module(ModuleName::new("foo")).unwrap();
let bar_module = db.resolve_module(ModuleName::new("bar")).unwrap();
let foo_module = db.resolve_module(ModuleName::new("foo"))?.unwrap();
let bar_module = db.resolve_module(ModuleName::new("bar"))?.unwrap();
assert_ne!(foo_module, bar_module);
assert_eq!(&src, foo_module.path(&db).root());
assert_eq!(&foo, &*db.file_path(foo_module.path(&db).file()));
assert_eq!(&src, foo_module.path(&db)?.root());
assert_eq!(&foo, &*db.file_path(foo_module.path(&db)?.file()));
// Bar has a different name but it should point to the same file.
assert_eq!(&src, bar_module.path(&db).root());
assert_eq!(foo_module.path(&db).file(), bar_module.path(&db).file());
assert_eq!(&foo, &*db.file_path(bar_module.path(&db).file()));
assert_eq!(&src, bar_module.path(&db)?.root());
assert_eq!(foo_module.path(&db)?.file(), bar_module.path(&db)?.file());
assert_eq!(&foo, &*db.file_path(bar_module.path(&db)?.file()));
assert_eq!(Some(foo_module), db.path_to_module(&foo));
assert_eq!(Some(bar_module), db.path_to_module(&bar));
assert_eq!(Some(foo_module), db.path_to_module(&foo)?);
assert_eq!(Some(bar_module), db.path_to_module(&bar)?);
Ok(())
}
#[test]
fn resolve_dependency() -> std::io::Result<()> {
fn resolve_dependency() -> anyhow::Result<()> {
let TestCase {
src,
db,
@ -1002,8 +1024,8 @@ mod tests {
std::fs::write(foo_path, "from .bar import test")?;
std::fs::write(bar_path, "test = 'Hello world'")?;
let foo_module = db.resolve_module(ModuleName::new("foo")).unwrap();
let bar_module = db.resolve_module(ModuleName::new("foo.bar")).unwrap();
let foo_module = db.resolve_module(ModuleName::new("foo"))?.unwrap();
let bar_module = db.resolve_module(ModuleName::new("foo.bar"))?.unwrap();
// `from . import bar` in `foo/__init__.py` resolves to `foo`
assert_eq!(
@ -1014,13 +1036,13 @@ mod tests {
level: NonZeroU32::new(1).unwrap(),
module: None,
}
)
)?
);
// `from baz import bar` in `foo/__init__.py` should resolve to `baz.py`
assert_eq!(
Some(ModuleName::new("baz")),
foo_module.resolve_dependency(&db, &Dependency::Module(ModuleName::new("baz")))
foo_module.resolve_dependency(&db, &Dependency::Module(ModuleName::new("baz")))?
);
// from .bar import test in `foo/__init__.py` should resolve to `foo/bar.py`
@ -1032,7 +1054,7 @@ mod tests {
level: NonZeroU32::new(1).unwrap(),
module: Some(ModuleName::new("bar"))
}
)
)?
);
// from .. import test in `foo/__init__.py` resolves to `` which is not a module
@ -1044,7 +1066,7 @@ mod tests {
level: NonZeroU32::new(2).unwrap(),
module: None
}
)
)?
);
// `from . import test` in `foo/bar.py` resolves to `foo`
@ -1056,13 +1078,13 @@ mod tests {
level: NonZeroU32::new(1).unwrap(),
module: None
}
)
)?
);
// `from baz import test` in `foo/bar.py` resolves to `baz`
assert_eq!(
Some(ModuleName::new("baz")),
bar_module.resolve_dependency(&db, &Dependency::Module(ModuleName::new("baz")))
bar_module.resolve_dependency(&db, &Dependency::Module(ModuleName::new("baz")))?
);
// `from .baz import test` in `foo/bar.py` resolves to `foo.baz`.
@ -1074,7 +1096,7 @@ mod tests {
level: NonZeroU32::new(1).unwrap(),
module: Some(ModuleName::new("baz"))
}
)
)?
);
Ok(())


@ -6,7 +6,7 @@ use ruff_python_parser::{Mode, ParseError};
use ruff_text_size::{Ranged, TextRange};
use crate::cache::KeyValueCache;
use crate::db::{HasJar, SourceDb, SourceJar};
use crate::db::{HasJar, QueryResult, SourceDb, SourceJar};
use crate::files::FileId;
#[derive(Debug, Clone, PartialEq)]
@ -64,16 +64,16 @@ impl Parsed {
}
#[tracing::instrument(level = "debug", skip(db))]
pub(crate) fn parse<Db>(db: &Db, file_id: FileId) -> Parsed
pub(crate) fn parse<Db>(db: &Db, file_id: FileId) -> QueryResult<Parsed>
where
Db: SourceDb + HasJar<SourceJar>,
{
let parsed = db.jar();
let parsed = db.jar()?;
parsed.parsed.get(&file_id, |file_id| {
let source = db.source(*file_id);
let source = db.source(*file_id)?;
Parsed::from_text(source.text())
Ok(Parsed::from_text(source.text()))
})
}


@ -1,10 +1,9 @@
use std::num::NonZeroUsize;
use rayon::max_num_threads;
use rayon::{current_num_threads, yield_local};
use rustc_hash::FxHashSet;
use crate::cancellation::CancellationToken;
use crate::db::{SemanticDb, SourceDb};
use crate::db::{Database, QueryError, QueryResult, SemanticDb, SourceDb};
use crate::files::FileId;
use crate::lint::Diagnostics;
use crate::program::Program;
@ -13,42 +12,37 @@ use crate::symbols::Dependency;
impl Program {
/// Checks all open files in the workspace and its dependencies.
#[tracing::instrument(level = "debug", skip_all)]
pub fn check(
&self,
scheduler: &dyn CheckScheduler,
cancellation_token: CancellationToken,
) -> Result<Vec<String>, CheckError> {
let check_loop = CheckFilesLoop::new(scheduler, cancellation_token);
pub fn check(&self, scheduler: &dyn CheckScheduler) -> QueryResult<Vec<String>> {
self.cancelled()?;
let check_loop = CheckFilesLoop::new(scheduler);
check_loop.run(self.workspace().open_files.iter().copied())
}
/// Checks a single file and its dependencies.
#[tracing::instrument(level = "debug", skip(self, scheduler, cancellation_token))]
#[tracing::instrument(level = "debug", skip(self, scheduler))]
pub fn check_file(
&self,
file: FileId,
scheduler: &dyn CheckScheduler,
cancellation_token: CancellationToken,
) -> Result<Vec<String>, CheckError> {
let check_loop = CheckFilesLoop::new(scheduler, cancellation_token);
) -> QueryResult<Vec<String>> {
self.cancelled()?;
let check_loop = CheckFilesLoop::new(scheduler);
check_loop.run([file].into_iter())
}
#[tracing::instrument(level = "debug", skip(self, context))]
fn do_check_file(
&self,
file: FileId,
context: &CheckContext,
) -> Result<Diagnostics, CheckError> {
context.cancelled_ok()?;
fn do_check_file(&self, file: FileId, context: &CheckContext) -> QueryResult<Diagnostics> {
self.cancelled()?;
let symbol_table = self.symbol_table(file);
let symbol_table = self.symbol_table(file)?;
let dependencies = symbol_table.dependencies();
if !dependencies.is_empty() {
let module = self.file_to_module(file);
let module = self.file_to_module(file)?;
// TODO scheduling all dependencies here is wasteful if we don't infer any types on them
// but I think that's unlikely, so it is okay?
@ -57,18 +51,19 @@ impl Program {
for dependency in dependencies {
let dependency_name = match dependency {
Dependency::Module(name) => Some(name.clone()),
Dependency::Relative { .. } => module
.as_ref()
.and_then(|module| module.resolve_dependency(self, dependency)),
Dependency::Relative { .. } => match &module {
Some(module) => module.resolve_dependency(self, dependency)?,
None => None,
},
};
if let Some(dependency_name) = dependency_name {
// TODO We may want to have different check functions for non-first-party
// files because we only need to index them and not check them.
// Supporting non-first-party code also requires supporting typing stubs.
if let Some(dependency) = self.resolve_module(dependency_name) {
if dependency.path(self).root().kind().is_first_party() {
context.schedule_check_file(dependency.path(self).file());
if let Some(dependency) = self.resolve_module(dependency_name)? {
if dependency.path(self)?.root().kind().is_first_party() {
context.schedule_check_file(dependency.path(self)?.file());
}
}
}
@ -78,8 +73,8 @@ impl Program {
let mut diagnostics = Vec::new();
if self.workspace().is_file_open(file) {
diagnostics.extend_from_slice(&self.lint_syntax(file));
diagnostics.extend_from_slice(&self.lint_semantic(file));
diagnostics.extend_from_slice(&self.lint_syntax(file)?);
diagnostics.extend_from_slice(&self.lint_semantic(file)?);
}
Ok(Diagnostics::from(diagnostics))
@ -128,10 +123,18 @@ where
self.scope
.spawn(move |_| child_span.in_scope(|| check_file_task.run(program)));
if current_num_threads() == 1 {
yield_local();
}
}
fn max_concurrency(&self) -> Option<NonZeroUsize> {
Some(NonZeroUsize::new(max_num_threads()).unwrap_or(NonZeroUsize::MIN))
if current_num_threads() == 1 {
return None;
}
Some(NonZeroUsize::new(current_num_threads()).unwrap_or(NonZeroUsize::MIN))
}
}
@ -156,11 +159,6 @@ impl CheckScheduler for SameThreadCheckScheduler<'_> {
}
}
#[derive(Debug, Clone)]
pub enum CheckError {
Cancelled,
}
#[derive(Debug)]
pub struct CheckFileTask {
file_id: FileId,
@ -176,7 +174,7 @@ impl CheckFileTask {
.sender
.send(CheckFileMessage::Completed(diagnostics))
.unwrap(),
Err(CheckError::Cancelled) => self
Err(QueryError::Cancelled) => self
.context
.sender
.send(CheckFileMessage::Cancelled)
@ -187,19 +185,12 @@ impl CheckFileTask {
#[derive(Clone, Debug)]
struct CheckContext {
cancellation_token: CancellationToken,
sender: crossbeam_channel::Sender<CheckFileMessage>,
sender: crossbeam::channel::Sender<CheckFileMessage>,
}
impl CheckContext {
fn new(
cancellation_token: CancellationToken,
sender: crossbeam_channel::Sender<CheckFileMessage>,
) -> Self {
Self {
cancellation_token,
sender,
}
fn new(sender: crossbeam::channel::Sender<CheckFileMessage>) -> Self {
Self { sender }
}
/// Queues a new file for checking using the [`CheckScheduler`].
@ -207,52 +198,36 @@ impl CheckContext {
fn schedule_check_file(&self, file_id: FileId) {
self.sender.send(CheckFileMessage::Queue(file_id)).unwrap();
}
/// Returns `true` if the check has been cancelled.
fn is_cancelled(&self) -> bool {
self.cancellation_token.is_cancelled()
}
fn cancelled_ok(&self) -> Result<(), CheckError> {
if self.is_cancelled() {
Err(CheckError::Cancelled)
} else {
Ok(())
}
}
}
struct CheckFilesLoop<'a> {
scheduler: &'a dyn CheckScheduler,
cancellation_token: CancellationToken,
pending: usize,
queued_files: FxHashSet<FileId>,
}
impl<'a> CheckFilesLoop<'a> {
fn new(scheduler: &'a dyn CheckScheduler, cancellation_token: CancellationToken) -> Self {
fn new(scheduler: &'a dyn CheckScheduler) -> Self {
Self {
scheduler,
cancellation_token,
queued_files: FxHashSet::default(),
pending: 0,
}
}
fn run(mut self, files: impl Iterator<Item = FileId>) -> Result<Vec<String>, CheckError> {
fn run(mut self, files: impl Iterator<Item = FileId>) -> QueryResult<Vec<String>> {
let (sender, receiver) = if let Some(max_concurrency) = self.scheduler.max_concurrency() {
crossbeam_channel::bounded(max_concurrency.get())
crossbeam::channel::bounded(max_concurrency.get())
} else {
// The checks run on the current thread. That means it is necessary to store all messages
// or we risk deadlocking when the main loop never gets a chance to read the messages.
crossbeam_channel::unbounded()
crossbeam::channel::unbounded()
};
let context = CheckContext::new(self.cancellation_token.clone(), sender.clone());
let context = CheckContext::new(sender.clone());
for file in files {
self.queue_file(file, context.clone())?;
self.queue_file(file, context.clone());
}
self.run_impl(receiver, &context)
@ -260,14 +235,11 @@ impl<'a> CheckFilesLoop<'a> {
fn run_impl(
mut self,
receiver: crossbeam_channel::Receiver<CheckFileMessage>,
receiver: crossbeam::channel::Receiver<CheckFileMessage>,
context: &CheckContext,
) -> Result<Vec<String>, CheckError> {
if self.cancellation_token.is_cancelled() {
return Err(CheckError::Cancelled);
}
) -> QueryResult<Vec<String>> {
let mut result = Vec::default();
let mut cancelled = false;
for message in receiver {
match message {
@ -281,30 +253,35 @@ impl<'a> CheckFilesLoop<'a> {
}
}
CheckFileMessage::Queue(id) => {
self.queue_file(id, context.clone())?;
if !cancelled {
self.queue_file(id, context.clone());
}
}
CheckFileMessage::Cancelled => {
return Err(CheckError::Cancelled);
self.pending -= 1;
cancelled = true;
if self.pending == 0 {
break;
}
}
}
}
Ok(result)
if cancelled {
Err(QueryError::Cancelled)
} else {
Ok(result)
}
}
fn queue_file(&mut self, file_id: FileId, context: CheckContext) -> Result<(), CheckError> {
if context.is_cancelled() {
return Err(CheckError::Cancelled);
}
fn queue_file(&mut self, file_id: FileId, context: CheckContext) {
if self.queued_files.insert(file_id) {
self.pending += 1;
self.scheduler
.check_file(CheckFileTask { file_id, context });
}
Ok(())
}
}
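
Cancellation is now detected through task results rather than by polling a token before queueing: once one task reports `Cancelled`, the loop stops queueing new files but keeps draining the channel until `pending` reaches zero, so a worker blocked on a `send` to the bounded channel can always complete. A simplified sketch of that drain-on-cancel pattern (toy `Message` type; here, for brevity, every message decrements `pending`):

```rust
use std::thread;

use crossbeam::channel;

enum Message {
    Completed(String),
    Cancelled,
}

// Drain-on-cancel: after the first `Cancelled`, discard further results but
// keep receiving until every pending task has reported back.
fn collect(
    receiver: channel::Receiver<Message>,
    mut pending: usize,
) -> Result<Vec<String>, &'static str> {
    let mut results = Vec::new();
    let mut cancelled = false;

    for message in &receiver {
        pending -= 1;
        match message {
            Message::Completed(diagnostics) if !cancelled => results.push(diagnostics),
            Message::Completed(_) => {} // late result after cancellation
            Message::Cancelled => cancelled = true,
        }
        if pending == 0 {
            break;
        }
    }

    if cancelled { Err("cancelled") } else { Ok(results) }
}

fn main() {
    let (sender, receiver) = channel::bounded(2);
    let workers: Vec<_> = (0..4)
        .map(|i| {
            let sender = sender.clone();
            thread::spawn(move || {
                let message = if i == 2 {
                    Message::Cancelled
                } else {
                    Message::Completed(format!("diagnostics for file {i}"))
                };
                sender.send(message).unwrap();
            })
        })
        .collect();
    drop(sender);

    assert!(collect(receiver, 4).is_err());
    for worker in workers {
        worker.join().unwrap();
    }
}
```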

View file

@ -1,45 +1,35 @@
pub mod check;
use std::path::Path;
use std::sync::Arc;
use crate::db::{Db, HasJar, SemanticDb, SemanticJar, SourceDb, SourceJar};
use crate::files::{FileId, Files};
use crate::lint::{
lint_semantic, lint_syntax, Diagnostics, LintSemanticStorage, LintSyntaxStorage,
use crate::db::{
Database, Db, DbRuntime, HasJar, HasJars, JarsStorage, ParallelDatabase, QueryResult,
SemanticDb, SemanticJar, Snapshot, SourceDb, SourceJar,
};
use crate::files::{FileId, Files};
use crate::lint::{lint_semantic, lint_syntax, Diagnostics};
use crate::module::{
add_module, file_to_module, path_to_module, resolve_module, set_module_search_paths, Module,
ModuleData, ModuleName, ModuleResolver, ModuleSearchPath,
ModuleData, ModuleName, ModuleSearchPath,
};
use crate::parse::{parse, Parsed, ParsedStorage};
use crate::source::{source_text, Source, SourceStorage};
use crate::symbols::{symbol_table, SymbolId, SymbolTable, SymbolTablesStorage};
use crate::types::{infer_symbol_type, Type, TypeStore};
use crate::parse::{parse, Parsed};
use crate::source::{source_text, Source};
use crate::symbols::{symbol_table, SymbolId, SymbolTable};
use crate::types::{infer_symbol_type, Type};
use crate::Workspace;
pub mod check;
#[derive(Debug)]
pub struct Program {
jars: JarsStorage<Program>,
files: Files,
source: SourceJar,
semantic: SemanticJar,
workspace: Workspace,
}
impl Program {
pub fn new(workspace: Workspace, module_search_paths: Vec<ModuleSearchPath>) -> Self {
pub fn new(workspace: Workspace) -> Self {
Self {
source: SourceJar {
sources: SourceStorage::default(),
parsed: ParsedStorage::default(),
lint_syntax: LintSyntaxStorage::default(),
},
semantic: SemanticJar {
module_resolver: ModuleResolver::new(module_search_paths),
symbol_tables: SymbolTablesStorage::default(),
type_store: TypeStore::default(),
lint_semantic: LintSemanticStorage::default(),
},
jars: JarsStorage::default(),
files: Files::default(),
workspace,
}
@ -49,17 +39,19 @@ impl Program {
where
I: IntoIterator<Item = FileChange>,
{
let files = self.files.clone();
let (source, semantic) = self.jars_mut();
for change in changes {
self.semantic
.module_resolver
.remove_module(&self.file_path(change.id));
self.semantic.symbol_tables.remove(&change.id);
self.source.sources.remove(&change.id);
self.source.parsed.remove(&change.id);
self.source.lint_syntax.remove(&change.id);
let file_path = files.path(change.id);
semantic.module_resolver.remove_module(&file_path);
semantic.symbol_tables.remove(&change.id);
source.sources.remove(&change.id);
source.parsed.remove(&change.id);
source.lint_syntax.remove(&change.id);
// TODO: remove all dependent modules as well
self.semantic.type_store.remove_module(change.id);
self.semantic.lint_semantic.remove(&change.id);
semantic.type_store.remove_module(change.id);
semantic.lint_semantic.remove(&change.id);
}
}
@ -85,41 +77,41 @@ impl SourceDb for Program {
self.files.path(file_id)
}
fn source(&self, file_id: FileId) -> Source {
fn source(&self, file_id: FileId) -> QueryResult<Source> {
source_text(self, file_id)
}
fn parse(&self, file_id: FileId) -> Parsed {
fn parse(&self, file_id: FileId) -> QueryResult<Parsed> {
parse(self, file_id)
}
fn lint_syntax(&self, file_id: FileId) -> Diagnostics {
fn lint_syntax(&self, file_id: FileId) -> QueryResult<Diagnostics> {
lint_syntax(self, file_id)
}
}
impl SemanticDb for Program {
fn resolve_module(&self, name: ModuleName) -> Option<Module> {
fn resolve_module(&self, name: ModuleName) -> QueryResult<Option<Module>> {
resolve_module(self, name)
}
fn file_to_module(&self, file_id: FileId) -> Option<Module> {
fn file_to_module(&self, file_id: FileId) -> QueryResult<Option<Module>> {
file_to_module(self, file_id)
}
fn path_to_module(&self, path: &Path) -> Option<Module> {
fn path_to_module(&self, path: &Path) -> QueryResult<Option<Module>> {
path_to_module(self, path)
}
fn symbol_table(&self, file_id: FileId) -> Arc<SymbolTable> {
fn symbol_table(&self, file_id: FileId) -> QueryResult<Arc<SymbolTable>> {
symbol_table(self, file_id)
}
fn infer_symbol_type(&self, file_id: FileId, symbol_id: SymbolId) -> Type {
fn infer_symbol_type(&self, file_id: FileId, symbol_id: SymbolId) -> QueryResult<Type> {
infer_symbol_type(self, file_id, symbol_id)
}
fn lint_semantic(&self, file_id: FileId) -> Diagnostics {
fn lint_semantic(&self, file_id: FileId) -> QueryResult<Diagnostics> {
lint_semantic(self, file_id)
}
@ -135,23 +127,55 @@ impl SemanticDb for Program {
impl Db for Program {}
impl Database for Program {
fn runtime(&self) -> &DbRuntime {
self.jars.runtime()
}
fn runtime_mut(&mut self) -> &mut DbRuntime {
self.jars.runtime_mut()
}
}
impl ParallelDatabase for Program {
fn snapshot(&self) -> Snapshot<Self> {
Snapshot::new(Self {
jars: self.jars.snapshot(),
files: self.files.clone(),
workspace: self.workspace.clone(),
})
}
}
impl HasJars for Program {
type Jars = (SourceJar, SemanticJar);
fn jars(&self) -> QueryResult<&Self::Jars> {
self.jars.jars()
}
fn jars_mut(&mut self) -> &mut Self::Jars {
self.jars.jars_mut()
}
}
impl HasJar<SourceJar> for Program {
fn jar(&self) -> &SourceJar {
&self.source
fn jar(&self) -> QueryResult<&SourceJar> {
Ok(&self.jars()?.0)
}
fn jar_mut(&mut self) -> &mut SourceJar {
&mut self.source
&mut self.jars_mut().0
}
}
impl HasJar<SemanticJar> for Program {
fn jar(&self) -> &SemanticJar {
&self.semantic
fn jar(&self) -> QueryResult<&SemanticJar> {
Ok(&self.jars()?.1)
}
fn jar_mut(&mut self) -> &mut SemanticJar {
&mut self.semantic
&mut self.jars_mut().1
}
}
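
The new `Program` keeps all query storage behind `JarsStorage`, so `snapshot()` is cheap (reference-count bumps, no deep copies) and every read goes through the fallible `jars()`, which is where cancellation surfaces as `QueryError::Cancelled`. A toy sketch of that shape, not the real red-knot types:

```rust
use std::sync::Arc;
use std::thread;

// Stand-in for red-knot's `QueryError`: cancellation is the only failure
// a query can currently report.
#[derive(Debug)]
enum QueryError {
    Cancelled,
}
type QueryResult<T> = Result<T, QueryError>;

#[derive(Clone, Default)]
struct SourceJar {
    sources: Arc<Vec<String>>,
}

#[derive(Clone, Default)]
struct SemanticJar {
    types: Arc<Vec<String>>,
}

#[derive(Clone, Default)]
struct Program {
    // In red-knot this is `JarsStorage<Program>`; cloning for a snapshot
    // is cheap because the storage is `Arc`-backed.
    jars: (SourceJar, SemanticJar),
}

impl Program {
    fn snapshot(&self) -> Program {
        self.clone()
    }

    // The real `jars()` returns `Err(QueryError::Cancelled)` once a writer
    // is waiting, which is what lets long-running checks abort early.
    fn jars(&self) -> QueryResult<&(SourceJar, SemanticJar)> {
        Ok(&self.jars)
    }

    // Queries that only need source storage project out of the tuple,
    // mirroring `HasJar<SourceJar>` above.
    fn source_jar(&self) -> QueryResult<&SourceJar> {
        Ok(&self.jars()?.0)
    }
}

fn main() -> QueryResult<()> {
    let program = Program::default();
    let snapshot = program.snapshot();

    // The snapshot can be moved to a worker thread while the original
    // keeps serving queries on the main thread.
    let worker = thread::spawn(move || snapshot.source_jar().map(|jar| jar.sources.len()));

    assert_eq!(program.source_jar()?.sources.len(), worker.join().unwrap()?);
    Ok(())
}
```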

View file

@ -1,5 +1,5 @@
use crate::cache::KeyValueCache;
use crate::db::{HasJar, SourceDb, SourceJar};
use crate::db::{HasJar, QueryResult, SourceDb, SourceJar};
use ruff_notebook::Notebook;
use ruff_python_ast::PySourceType;
use std::ops::{Deref, DerefMut};
@ -8,11 +8,11 @@ use std::sync::Arc;
use crate::files::FileId;
#[tracing::instrument(level = "debug", skip(db))]
pub(crate) fn source_text<Db>(db: &Db, file_id: FileId) -> Source
pub(crate) fn source_text<Db>(db: &Db, file_id: FileId) -> QueryResult<Source>
where
Db: SourceDb + HasJar<SourceJar>,
{
let sources = &db.jar().sources;
let sources = &db.jar()?.sources;
sources.get(&file_id, |file_id| {
let path = db.file_path(*file_id);
@ -43,7 +43,7 @@ where
}
};
Source { kind }
Ok(Source { kind })
})
}
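
`source_text` memoizes through `KeyValueCache::get`, whose compute closure is now fallible: a `?` inside the closure propagates cancellation and, crucially, leaves no cache entry behind, so the next caller recomputes. A minimal single-mutex sketch of that contract (the real cache shards through a `DashMap`; `QueryError` is a stand-in):

```rust
use std::collections::HashMap;
use std::hash::Hash;
use std::sync::Mutex;

#[derive(Debug)]
enum QueryError {
    Cancelled,
}
type QueryResult<T> = Result<T, QueryError>;

struct KeyValueCache<K, V> {
    map: Mutex<HashMap<K, V>>,
}

impl<K: Hash + Eq + Clone, V: Clone> KeyValueCache<K, V> {
    fn new() -> Self {
        Self { map: Mutex::new(HashMap::new()) }
    }

    fn get(&self, key: &K, compute: impl FnOnce(&K) -> QueryResult<V>) -> QueryResult<V> {
        // Simplification: the lock is held across `compute`; the dashmap
        // version holds only the entry for `key`.
        let mut map = self.map.lock().unwrap();
        if let Some(value) = map.get(key) {
            return Ok(value.clone());
        }
        let value = compute(key)?; // a cancelled computation caches nothing
        map.insert(key.clone(), value.clone());
        Ok(value)
    }
}

fn main() -> QueryResult<()> {
    let cache = KeyValueCache::new();
    let source = cache.get(&1, |_| Ok("print('hello')".to_string()))?;
    assert_eq!(source, "print('hello')");
    // Cancellation propagates as an error and leaves no entry behind.
    assert!(cache.get(&2, |_| Err::<String, _>(QueryError::Cancelled)).is_err());
    Ok(())
}
```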

View file

@ -16,22 +16,22 @@ use ruff_python_ast::visitor::preorder::PreorderVisitor;
use crate::ast_ids::TypedNodeKey;
use crate::cache::KeyValueCache;
use crate::db::{HasJar, SemanticDb, SemanticJar};
use crate::db::{HasJar, QueryResult, SemanticDb, SemanticJar};
use crate::files::FileId;
use crate::module::ModuleName;
use crate::Name;
#[allow(unreachable_pub)]
#[tracing::instrument(level = "debug", skip(db))]
pub fn symbol_table<Db>(db: &Db, file_id: FileId) -> Arc<SymbolTable>
pub fn symbol_table<Db>(db: &Db, file_id: FileId) -> QueryResult<Arc<SymbolTable>>
where
Db: SemanticDb + HasJar<SemanticJar>,
{
let jar = db.jar();
let jar = db.jar()?;
jar.symbol_tables.get(&file_id, |_| {
let parsed = db.parse(file_id);
Arc::from(SymbolTable::from_ast(parsed.ast()))
let parsed = db.parse(file_id)?;
Ok(Arc::from(SymbolTable::from_ast(parsed.ast())))
})
}

View file

@ -2,7 +2,7 @@
use ruff_python_ast::AstNode;
use crate::db::{HasJar, SemanticDb, SemanticJar};
use crate::db::{HasJar, QueryResult, SemanticDb, SemanticJar};
use crate::module::ModuleName;
use crate::symbols::{Definition, ImportFromDefinition, SymbolId};
use crate::types::Type;
@ -11,23 +11,24 @@ use ruff_python_ast as ast;
// FIXME: Figure out proper deadlock-free synchronization now that this takes `&db` instead of `&mut db`.
#[tracing::instrument(level = "trace", skip(db))]
pub fn infer_symbol_type<Db>(db: &Db, file_id: FileId, symbol_id: SymbolId) -> Type
pub fn infer_symbol_type<Db>(db: &Db, file_id: FileId, symbol_id: SymbolId) -> QueryResult<Type>
where
Db: SemanticDb + HasJar<SemanticJar>,
{
let symbols = db.symbol_table(file_id);
let symbols = db.symbol_table(file_id)?;
let defs = symbols.definitions(symbol_id);
if let Some(ty) = db
.jar()
.jar()?
.type_store
.get_cached_symbol_type(file_id, symbol_id)
{
return ty;
return Ok(ty);
}
// TODO handle multiple defs, conditional defs...
assert_eq!(defs.len(), 1);
let type_store = &db.jar()?.type_store;
let ty = match &defs[0] {
Definition::ImportFrom(ImportFromDefinition {
@ -38,11 +39,11 @@ where
// TODO relative imports
assert!(matches!(level, 0));
let module_name = ModuleName::new(module.as_ref().expect("TODO relative imports"));
if let Some(module) = db.resolve_module(module_name) {
let remote_file_id = module.path(db).file();
let remote_symbols = db.symbol_table(remote_file_id);
if let Some(module) = db.resolve_module(module_name)? {
let remote_file_id = module.path(db)?.file();
let remote_symbols = db.symbol_table(remote_file_id)?;
if let Some(remote_symbol_id) = remote_symbols.root_symbol_id_by_name(name) {
db.infer_symbol_type(remote_file_id, remote_symbol_id)
db.infer_symbol_type(remote_file_id, remote_symbol_id)?
} else {
Type::Unknown
}
@ -50,71 +51,68 @@ where
Type::Unknown
}
}
Definition::ClassDef(node_key) => db
.jar()
.type_store
.get_cached_node_type(file_id, node_key.erased())
.unwrap_or_else(|| {
let parsed = db.parse(file_id);
Definition::ClassDef(node_key) => {
if let Some(ty) = type_store.get_cached_node_type(file_id, node_key.erased()) {
ty
} else {
let parsed = db.parse(file_id)?;
let ast = parsed.ast();
let node = node_key.resolve_unwrap(ast.as_any_node_ref());
let bases: Vec<_> = node
.bases()
.iter()
.map(|base_expr| infer_expr_type(db, file_id, base_expr))
.collect();
let mut bases = Vec::with_capacity(node.bases().len());
let store = &db.jar().type_store;
let ty = Type::Class(store.add_class(file_id, &node.name.id, bases));
store.cache_node_type(file_id, *node_key.erased(), ty);
for base in node.bases() {
bases.push(infer_expr_type(db, file_id, base)?);
}
let ty = Type::Class(type_store.add_class(file_id, &node.name.id, bases));
type_store.cache_node_type(file_id, *node_key.erased(), ty);
ty
}),
Definition::FunctionDef(node_key) => db
.jar()
.type_store
.get_cached_node_type(file_id, node_key.erased())
.unwrap_or_else(|| {
let parsed = db.parse(file_id);
}
}
Definition::FunctionDef(node_key) => {
if let Some(ty) = type_store.get_cached_node_type(file_id, node_key.erased()) {
ty
} else {
let parsed = db.parse(file_id)?;
let ast = parsed.ast();
let node = node_key
.resolve(ast.as_any_node_ref())
.expect("node key should resolve");
let store = &db.jar().type_store;
let ty = store.add_function(file_id, &node.name.id).into();
store.cache_node_type(file_id, *node_key.erased(), ty);
let ty = type_store.add_function(file_id, &node.name.id).into();
type_store.cache_node_type(file_id, *node_key.erased(), ty);
ty
}),
}
}
Definition::Assignment(node_key) => {
let parsed = db.parse(file_id);
let parsed = db.parse(file_id)?;
let ast = parsed.ast();
let node = node_key.resolve_unwrap(ast.as_any_node_ref());
// TODO handle unpacking assignment correctly
infer_expr_type(db, file_id, &node.value)
infer_expr_type(db, file_id, &node.value)?
}
_ => todo!("other kinds of definitions"),
};
db.jar()
.type_store
.cache_symbol_type(file_id, symbol_id, ty);
type_store.cache_symbol_type(file_id, symbol_id, ty);
// TODO record dependencies
ty
Ok(ty)
}
fn infer_expr_type<Db>(db: &Db, file_id: FileId, expr: &ast::Expr) -> Type
fn infer_expr_type<Db>(db: &Db, file_id: FileId, expr: &ast::Expr) -> QueryResult<Type>
where
Db: SemanticDb + HasJar<SemanticJar>,
{
// TODO cache the resolution of the type on the node
let symbols = db.symbol_table(file_id);
let symbols = db.symbol_table(file_id)?;
match expr {
ast::Expr::Name(name) => {
if let Some(symbol_id) = symbols.root_symbol_id_by_name(&name.id) {
db.infer_symbol_type(file_id, symbol_id)
} else {
Type::Unknown
Ok(Type::Unknown)
}
}
_ => todo!("full expression type resolution"),
@ -154,7 +152,7 @@ mod tests {
}
#[test]
fn follow_import_to_class() -> std::io::Result<()> {
fn follow_import_to_class() -> anyhow::Result<()> {
let case = create_test()?;
let db = &case.db;
@ -163,18 +161,18 @@ mod tests {
std::fs::write(a_path, "from b import C as D; E = D")?;
std::fs::write(b_path, "class C: pass")?;
let a_file = db
.resolve_module(ModuleName::new("a"))
.resolve_module(ModuleName::new("a"))?
.expect("module should be found")
.path(db)
.path(db)?
.file();
let a_syms = db.symbol_table(a_file);
let a_syms = db.symbol_table(a_file)?;
let e_sym = a_syms
.root_symbol_id_by_name("E")
.expect("E symbol should be found");
let ty = db.infer_symbol_type(a_file, e_sym);
let ty = db.infer_symbol_type(a_file, e_sym)?;
let jar = HasJar::<SemanticJar>::jar(db);
let jar = HasJar::<SemanticJar>::jar(db)?;
assert!(matches!(ty, Type::Class(_)));
assert_eq!(format!("{}", ty.display(&jar.type_store)), "Literal[C]");
@ -182,28 +180,28 @@ mod tests {
}
#[test]
fn resolve_base_class_by_name() -> std::io::Result<()> {
fn resolve_base_class_by_name() -> anyhow::Result<()> {
let case = create_test()?;
let db = &case.db;
let path = case.src.path().join("mod.py");
std::fs::write(path, "class Base: pass\nclass Sub(Base): pass")?;
let file = db
.resolve_module(ModuleName::new("mod"))
.resolve_module(ModuleName::new("mod"))?
.expect("module should be found")
.path(db)
.path(db)?
.file();
let syms = db.symbol_table(file);
let syms = db.symbol_table(file)?;
let sym = syms
.root_symbol_id_by_name("Sub")
.expect("Sub symbol should be found");
let ty = db.infer_symbol_type(file, sym);
let ty = db.infer_symbol_type(file, sym)?;
let Type::Class(class_id) = ty else {
panic!("Sub is not a Class")
};
let jar = HasJar::<SemanticJar>::jar(db);
let jar = HasJar::<SemanticJar>::jar(db)?;
let base_names: Vec<_> = jar
.type_store
.get_class(class_id)

View file

@ -26,7 +26,7 @@ ruff_text_size = { path = "../ruff_text_size" }
ruff_workspace = { path = "../ruff_workspace" }
anyhow = { workspace = true }
crossbeam-channel = { workspace = true }
crossbeam = { workspace = true }
jod-thread = { workspace = true }
libc = { workspace = true }
lsp-server = { workspace = true }

View file

@ -6,7 +6,7 @@ use serde_json::Value;
use super::schedule::Task;
pub(crate) type ClientSender = crossbeam_channel::Sender<lsp_server::Message>;
pub(crate) type ClientSender = crossbeam::channel::Sender<lsp_server::Message>;
type ResponseBuilder<'s> = Box<dyn FnOnce(lsp_server::Response) -> Task<'s>>;

View file

@ -1,6 +1,6 @@
use std::num::NonZeroUsize;
use crossbeam_channel::Sender;
use crossbeam::channel::Sender;
use crate::session::Session;

View file

@ -21,7 +21,7 @@ use std::{
},
};
use crossbeam_channel::{Receiver, Sender};
use crossbeam::channel::{Receiver, Sender};
use super::{Builder, JoinHandle, ThreadPriority};
@ -52,7 +52,7 @@ impl Pool {
let threads = usize::from(threads);
// Channel buffer capacity is between 2 and 4, depending on the pool size.
let (job_sender, job_receiver) = crossbeam_channel::bounded(std::cmp::min(threads * 2, 4));
let (job_sender, job_receiver) = crossbeam::channel::bounded(std::cmp::min(threads * 2, 4));
let extant_tasks = Arc::new(AtomicUsize::new(0));
let mut handles = Vec::with_capacity(threads);
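
The pool's job queue stays deliberately small: with a bound of `min(threads * 2, 4)`, `send` applies backpressure instead of buffering jobs without limit. A runnable sketch of that setup using `crossbeam::channel` (toy `u32` jobs):

```rust
use std::thread;

use crossbeam::channel;

fn main() {
    let threads: usize = 2;
    // Capacity between 2 and 4, exactly as in the pool above.
    let (job_sender, job_receiver) = channel::bounded::<u32>(std::cmp::min(threads * 2, 4));

    let workers: Vec<_> = (0..threads)
        .map(|_| {
            // crossbeam channels are MPMC, so receivers clone cheaply.
            let receiver = job_receiver.clone();
            thread::spawn(move || {
                // Iterating a receiver ends once all senders are dropped.
                for job in receiver {
                    let _ = job; // run the job
                }
            })
        })
        .collect();

    for job in 0..16 {
        job_sender.send(job).unwrap(); // blocks while the buffer is full
    }
    drop(job_sender);

    for worker in workers {
        worker.join().unwrap();
    }
}
```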