Refactor LSP main loop

This commit is contained in:
oxalica 2023-03-24 13:16:20 +08:00
parent e8e338d59b
commit 48009b60fa
3 changed files with 87 additions and 67 deletions

View file

@ -7,10 +7,10 @@ mod semantic_tokens;
mod server;
mod vfs;
use anyhow::{Context, Result};
use anyhow::Result;
use ide::VfsPath;
use lsp_server::{Connection, ErrorCode};
use lsp_types::{InitializeParams, Url};
use lsp_types::Url;
use std::fmt;
pub(crate) use server::{Server, StateSnapshot};
@ -67,26 +67,14 @@ impl UrlExt for Url {
}
}
pub fn main_loop(conn: Connection) -> Result<()> {
let init_params = conn.initialize(
serde_json::to_value(capabilities::server_capabilities()).context("Invalid init_params")?,
)?;
tracing::info!("Init params: {}", init_params);
pub fn run_server_stdio() -> Result<()> {
let (conn, io_threads) = Connection::stdio();
let init_params = serde_json::from_value::<InitializeParams>(init_params)?;
let root_path = match init_params
.root_uri
.as_ref()
.and_then(|uri| uri.to_file_path().ok())
{
Some(path) => path,
None => std::env::current_dir()?,
};
let mut server = Server::new(conn.sender.clone(), root_path);
server.run(conn.receiver, init_params)?;
let server = Server::new(conn.sender, conn.receiver);
server.run()?;
tracing::info!("Leaving main loop");
io_threads.join()?;
Ok(())
}

View file

@ -2,7 +2,6 @@ use anyhow::{Context, Result};
use argh::FromArgs;
use codespan_reporting::term::termcolor::WriteColor;
use ide::AnalysisHost;
use lsp_server::Connection;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::{env, fs, io, process};
@ -80,8 +79,7 @@ fn main() {
setup_logger();
let (conn, io_threads) = Connection::stdio();
match nil::main_loop(conn).and_then(|()| io_threads.join().map_err(Into::into)) {
match nil::run_server_stdio() {
Ok(()) => {}
Err(err) => {
tracing::error!("Unexpected error: {err:#}");

View file

@ -1,9 +1,12 @@
use crate::capabilities::server_capabilities;
use crate::config::{Config, CONFIG_KEY};
use crate::{convert, handler, lsp_ext, LspError, UrlExt, Vfs, MAX_FILE_LEN};
use anyhow::{anyhow, bail, Context, Result};
use crossbeam_channel::{Receiver, Sender};
use ide::{Analysis, AnalysisHost, Cancelled, FlakeInfo, VfsPath};
use lsp_server::{ErrorCode, Message, Notification, ReqQueue, Request, RequestId, Response};
use lsp_server::{
Connection, ErrorCode, Message, Notification, ReqQueue, Request, RequestId, Response,
};
use lsp_types::notification::Notification as _;
use lsp_types::{
notification as notif, request as req, ConfigurationItem, ConfigurationParams, Diagnostic,
@ -17,7 +20,7 @@ use std::cell::Cell;
use std::collections::HashMap;
use std::io::ErrorKind;
use std::panic::UnwindSafe;
use std::path::{Path, PathBuf};
use std::path::Path;
use std::sync::{Arc, Once, RwLock};
use std::{fs, panic, thread};
@ -61,6 +64,7 @@ pub struct Server {
// Message passing.
req_queue: ReqQueue<(), ReqHandler>,
lsp_tx: Sender<Message>,
lsp_rx: Receiver<Message>,
task_tx: Sender<Task>,
event_tx: Sender<Event>,
event_rx: Receiver<Event>,
@ -73,7 +77,7 @@ struct FileData {
}
impl Server {
pub fn new(lsp_tx: Sender<Message>, root_path: PathBuf) -> Self {
pub fn new(lsp_tx: Sender<Message>, lsp_rx: Receiver<Message>) -> Self {
let (task_tx, task_rx) = crossbeam_channel::unbounded();
let (event_tx, event_rx) = crossbeam_channel::unbounded();
let worker_cnt = thread::available_parallelism().map_or(1, |n| n.get());
@ -91,12 +95,14 @@ impl Server {
host: AnalysisHost::default(),
vfs: Arc::new(RwLock::new(Vfs::new())),
opened_files: HashMap::default(),
config: Arc::new(Config::new(root_path)),
// Will be initialized in `Server::run`.
config: Arc::new(Config::new("/non-existing-path".into())),
is_shutdown: false,
version_counter: 0,
req_queue: ReqQueue::default(),
lsp_tx,
lsp_rx,
task_tx,
event_tx,
event_rx,
@ -111,47 +117,35 @@ impl Server {
}
}
pub fn run(&mut self, lsp_rx: Receiver<Message>, init_params: InitializeParams) -> Result<()> {
#[cfg(target_os = "linux")]
pub fn run(mut self) -> Result<()> {
let init_params = Connection {
sender: self.lsp_tx.clone(),
receiver: self.lsp_rx.clone(),
}
.initialize(serde_json::to_value(server_capabilities()).unwrap())?;
tracing::info!("Init params: {}", init_params);
let init_params = serde_json::from_value::<InitializeParams>(init_params)
.context("Invalid init_params")?;
let root_path = match init_params
.root_uri
.as_ref()
.and_then(|uri| uri.to_file_path().ok())
{
Some(path) => path,
None => std::env::current_dir().context("Failed to the current directory")?,
};
*Arc::get_mut(&mut self.config).expect("No concurrent access yet") = Config::new(root_path);
if let Some(pid) = init_params.process_id {
use std::io;
use std::mem::MaybeUninit;
use std::os::unix::io::{AsRawFd, FromRawFd, OwnedFd, RawFd};
use std::ptr::null_mut;
fn wait_remote_pid(pid: libc::pid_t) -> Result<(), io::Error> {
let pidfd = unsafe {
let ret = libc::syscall(libc::SYS_pidfd_open, pid, 0 as libc::c_int);
if ret == -1 {
return Err(io::Error::last_os_error());
}
OwnedFd::from_raw_fd(ret as RawFd)
};
unsafe {
let mut fdset = MaybeUninit::uninit();
libc::FD_ZERO(fdset.as_mut_ptr());
libc::FD_SET(pidfd.as_raw_fd(), fdset.as_mut_ptr());
let nfds = pidfd.as_raw_fd() + 1;
let ret =
libc::select(nfds, fdset.as_mut_ptr(), null_mut(), null_mut(), null_mut());
if ret == -1 {
return Err(io::Error::last_os_error());
}
}
Ok(())
}
let event_tx = self.event_tx.clone();
thread::spawn(move || {
match wait_remote_pid(pid as _) {
Ok(()) => {}
Err(err) if err.raw_os_error() == Some(libc::ESRCH) => {}
Err(err) => {
tracing::warn!("Failed to monitor parent pid {}: {}", pid, err);
return;
}
thread::spawn(move || match wait_for_pid(pid as _) {
Ok(()) => {
let _ = event_tx.send(Event::ClientExited);
}
Err(err) => {
tracing::warn!("Failed to monitor parent pid {}: {}", pid, err);
}
let _ = event_tx.send(Event::ClientExited);
});
}
@ -164,7 +158,7 @@ impl Server {
loop {
crossbeam_channel::select! {
recv(lsp_rx) -> msg => {
recv(self.lsp_rx) -> msg => {
match msg.context("Channel closed")? {
Message::Request(req) => self.dispatch_request(req),
Message::Notification(notif) => {
@ -175,7 +169,7 @@ impl Server {
}
Message::Response(resp) => {
if let Some(callback) = self.req_queue.outgoing.complete(resp.id.clone()) {
callback(self, resp);
callback(&mut self, resp);
}
}
}
@ -779,3 +773,43 @@ impl StateSnapshot {
self.vfs.read().unwrap()
}
}
#[cfg(target_os = "linux")]
fn wait_for_pid(pid: u32) -> std::io::Result<()> {
use std::io;
use std::mem::MaybeUninit;
use std::os::unix::io::{AsRawFd, FromRawFd, OwnedFd, RawFd};
use std::ptr::null_mut;
let pidfd = unsafe {
let ret = libc::syscall(libc::SYS_pidfd_open, pid as libc::pid_t, 0 as libc::c_int);
if ret == -1 {
let err = io::Error::last_os_error();
if err.raw_os_error() == Some(libc::ESRCH) {
return Ok(());
}
return Err(err);
}
OwnedFd::from_raw_fd(ret as RawFd)
};
unsafe {
let mut fdset = MaybeUninit::uninit();
libc::FD_ZERO(fdset.as_mut_ptr());
libc::FD_SET(pidfd.as_raw_fd(), fdset.as_mut_ptr());
let nfds = pidfd.as_raw_fd() + 1;
let ret = libc::select(nfds, fdset.as_mut_ptr(), null_mut(), null_mut(), null_mut());
if ret == -1 {
return Err(io::Error::last_os_error());
}
}
Ok(())
}
#[cfg(not(target_os = "linux"))]
fn wait_for_pid(_pid: u32) -> std::io::Result<()> {
Err(std::io::Error::new(
ErrorKind::Other,
"Waiting for arbitrary PID is not supported on this platform",
))
}