Merge pull request #7130 from snobee/annotate-type-signatures

Automatic annotation of type signatures
This commit is contained in:
Sam Mohr 2025-01-31 11:46:12 -05:00 committed by GitHub
commit 670d255060
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 775 additions and 27 deletions

3
Cargo.lock generated
View file

@ -2596,12 +2596,14 @@ dependencies = [
"roc_mono",
"roc_packaging",
"roc_parse",
"roc_problem",
"roc_region",
"roc_repl_cli",
"roc_repl_expect",
"roc_reporting",
"roc_target",
"roc_tracing",
"roc_types",
"roc_wasm_interp",
"signal-hook",
"strum",
@ -2862,6 +2864,7 @@ dependencies = [
"log",
"parking_lot",
"roc_can",
"roc_cli",
"roc_collections",
"roc_error_macros",
"roc_fmt",

View file

@ -60,10 +60,12 @@ roc_module.workspace = true
roc_mono.workspace = true
roc_packaging.workspace = true
roc_parse.workspace = true
roc_problem.workspace = true
roc_region.workspace = true
roc_reporting.workspace = true
roc_target.workspace = true
roc_tracing.workspace = true
roc_types.workspace = true
roc_repl_cli = { workspace = true, optional = true }
roc_wasm_interp = { workspace = true, optional = true }

View file

@ -1,17 +1,28 @@
use std::ffi::OsStr;
use std::io::Write;
use std::ops::Range;
use std::path::{Path, PathBuf};
use bumpalo::Bump;
use bumpalo::{collections::String as BumpString, Bump};
use roc_can::abilities::{IAbilitiesStore, Resolved};
use roc_can::expr::{DeclarationTag, Declarations, Expr};
use roc_error_macros::{internal_error, user_error};
use roc_fmt::def::fmt_defs;
use roc_fmt::header::fmt_header;
use roc_fmt::Buf;
use roc_fmt::MigrationFlags;
use roc_load::{ExecutionMode, FunctionKind, LoadConfig, LoadedModule, LoadingProblem, Threading};
use roc_module::symbol::{Interns, ModuleId};
use roc_packaging::cache::{self, RocCacheDir};
use roc_parse::ast::{FullAst, SpacesBefore};
use roc_parse::header::parse_module_defs;
use roc_parse::normalize::Normalize;
use roc_parse::{header, parser::SyntaxError, state::State};
use roc_problem::can::RuntimeError;
use roc_region::all::{LineColumn, LineInfo};
use roc_reporting::report::{RenderTarget, DEFAULT_PALETTE};
use roc_target::Target;
use roc_types::subs::{Subs, Variable};
#[derive(Copy, Clone, Debug)]
pub enum FormatMode {
@ -263,10 +274,177 @@ fn fmt_all<'a>(buf: &mut Buf<'a>, ast: &'a FullAst) {
buf.fmt_end_of_file();
}
#[derive(Debug)]
pub enum AnnotationProblem<'a> {
Loading(LoadingProblem<'a>),
Type(TypeProblem),
}
#[derive(Debug)]
pub struct TypeProblem {
pub name: String,
pub position: LineColumn,
}
pub fn annotate_file(arena: &Bump, file: PathBuf) -> Result<(), AnnotationProblem> {
let load_config = LoadConfig {
target: Target::default(),
function_kind: FunctionKind::from_env(),
render: RenderTarget::ColorTerminal,
palette: DEFAULT_PALETTE,
threading: Threading::AllAvailable,
exec_mode: ExecutionMode::Check,
};
let mut loaded = roc_load::load_and_typecheck(
arena,
file.clone(),
None,
RocCacheDir::Persistent(cache::roc_cache_dir().as_path()),
load_config,
)
.map_err(AnnotationProblem::Loading)?;
let buf = annotate_module(arena, &mut loaded)?;
std::fs::write(&file, buf.as_str())
.unwrap_or_else(|e| internal_error!("failed to write annotated file to {file:?}: {e}"));
Ok(())
}
fn annotate_module<'a>(
arena: &'a Bump,
loaded: &mut LoadedModule,
) -> Result<BumpString<'a>, AnnotationProblem<'a>> {
let (decls, subs, abilities) =
if let Some(decls) = loaded.declarations_by_id.get(&loaded.module_id) {
let subs = loaded.solved.inner_mut();
let abilities = &loaded.abilities_store;
(decls, subs, abilities)
} else if let Some(checked) = loaded.typechecked.get_mut(&loaded.module_id) {
let decls = &checked.decls;
let subs = checked.solved_subs.inner_mut();
let abilities = &checked.abilities_store;
(decls, subs, abilities)
} else {
internal_error!("Could not find file's module");
};
let src = &loaded
.sources
.get(&loaded.module_id)
.unwrap_or_else(|| internal_error!("Could not find the file's source"))
.1;
let mut edits = annotation_edits(
decls,
subs,
abilities,
src,
loaded.module_id,
&loaded.interns,
)
.map_err(AnnotationProblem::Type)?;
edits.sort_by_key(|(offset, _)| *offset);
let mut buffer = BumpString::new_in(arena);
let mut file_progress = 0;
for (position, edit) in edits {
buffer.push_str(&src[file_progress..position]);
buffer.push_str(&edit);
file_progress = position;
}
buffer.push_str(&src[file_progress..]);
Ok(buffer)
}
pub fn annotation_edits(
decls: &Declarations,
subs: &Subs,
abilities: &IAbilitiesStore<Resolved>,
src: &str,
module_id: ModuleId,
interns: &Interns,
) -> Result<Vec<(usize, String)>, TypeProblem> {
let mut edits = Vec::with_capacity(decls.len());
for (index, tag) in decls.iter_bottom_up() {
let var = decls.variables[index];
let symbol = decls.symbols[index];
let expr = &decls.expressions[index].value;
if decls.annotations[index].is_some()
| matches!(
*expr,
Expr::RuntimeError(RuntimeError::ExposedButNotDefined(..)) | Expr::ImportParams(..)
)
| abilities.is_specialization_name(symbol.value)
| matches!(tag, DeclarationTag::MutualRecursion { .. })
{
continue;
}
let byte_range = match tag {
DeclarationTag::Destructure(i) => decls.destructs[i.index()].loc_pattern.byte_range(),
_ => symbol.byte_range(),
};
let edit = annotation_edit(src, subs, interns, module_id, var, byte_range)?;
edits.push(edit);
}
Ok(edits)
}
pub fn annotation_edit(
src: &str,
subs: &Subs,
interns: &Interns,
module_id: ModuleId,
var: Variable,
symbol_range: Range<usize>,
) -> Result<(usize, String), TypeProblem> {
let symbol_str = &src[symbol_range.clone()];
if subs.var_contains_error(var) {
let line_info = LineInfo::new(src);
let position = line_info.convert_offset(symbol_range.start as u32);
return Err(TypeProblem {
name: symbol_str.to_owned(),
position,
});
}
let signature = roc_types::pretty_print::name_and_print_var(
var,
&mut subs.clone(),
module_id,
interns,
roc_types::pretty_print::DebugPrint::NOTHING,
);
let line_start = src[..symbol_range.start]
.rfind('\n')
.map_or(symbol_range.start, |pos| pos + 1);
let indent = src[line_start..]
.split_once(|c: char| !c.is_ascii_whitespace())
.map_or("", |pair| pair.0);
let edit = format!("{indent}{symbol_str} : {signature}\n");
Ok((line_start, edit))
}
#[cfg(test)]
mod tests {
use super::*;
use std::fs::File;
use indoc::indoc;
use std::fs::{read_to_string, File};
use std::io::Write;
use tempfile::{tempdir, TempDir};
@ -379,4 +557,75 @@ main =
cleanup_temp_dir(dir);
}
const HEADER: &str = indoc! {r#"
interface Test
exposes []
imports []
"#};
fn annotate_string(before: String) -> String {
let dir = tempdir().unwrap();
let file_path = setup_test_file(dir.path(), "before.roc", &before);
let arena = Bump::new();
let result = annotate_file(&arena, file_path.clone());
result.unwrap();
let annotated = read_to_string(file_path).unwrap();
cleanup_temp_dir(dir);
annotated
}
#[test]
fn test_annotate_simple() {
let before = HEADER.to_string()
+ indoc! {r#"
main =
"Hello, World!""#};
let after = HEADER.to_string()
+ indoc! {r#"
main : Str
main =
"Hello, World!"
"#};
let annotated = annotate_string(before);
assert_eq!(annotated, after);
}
#[test]
fn test_annotate_empty() {
let before = HEADER.to_string();
let after = HEADER.to_string() + "\n";
let annotated = annotate_string(before);
assert_eq!(annotated, after);
}
#[test]
fn test_annotate_destructure() {
let before = HEADER.to_string()
+ indoc! {r#"
{a, b} = {a: "zero", b: (1, 2)}
main = a"#};
let after = HEADER.to_string()
+ indoc! {r#"
{a, b} : { a : Str, b : ( Num *, Num * )* }
{a, b} = {a: "zero", b: (1, 2)}
main : Str
main = a
"#};
let annotated = annotate_string(before);
assert_eq!(annotated, after);
}
}

View file

@ -42,7 +42,10 @@ use strum::IntoEnumIterator;
use tempfile::TempDir;
mod format;
pub use format::{format_files, format_src, FormatMode};
pub use format::{
annotate_file, annotation_edit, annotation_edits, format_files, format_src, AnnotationProblem,
FormatMode,
};
pub const CMD_BUILD: &str = "build";
pub const CMD_RUN: &str = "run";
@ -52,6 +55,7 @@ pub const CMD_DOCS: &str = "docs";
pub const CMD_CHECK: &str = "check";
pub const CMD_VERSION: &str = "version";
pub const CMD_FORMAT: &str = "format";
pub const CMD_FORMAT_ANNOTATE: &str = "annotate";
pub const CMD_TEST: &str = "test";
pub const CMD_GLUE: &str = "glue";
pub const CMD_PREPROCESS_HOST: &str = "preprocess-host";
@ -380,6 +384,16 @@ pub fn build_app() -> Command {
.required(false),
)
.after_help("If DIRECTORY_OR_FILES is omitted, the .roc files in the current working\ndirectory are formatted.")
.subcommand(Command::new(CMD_FORMAT_ANNOTATE)
.about("Annotate all top level definitions from a .roc file")
.arg(
Arg::new(ROC_FILE)
.help("The .roc file ot annotate")
.value_parser(value_parser!(PathBuf))
.required(false)
.default_value(DEFAULT_ROC_FILENAME),
)
)
)
.subcommand(Command::new(CMD_VERSION)
.about(concatcp!("Print the Roc compilers version, which is currently ", VERSION)))

View file

@ -3,15 +3,16 @@ use bumpalo::Bump;
use roc_build::link::LinkType;
use roc_build::program::{check_file, CodeGenBackend};
use roc_cli::{
build_app, default_linking_strategy, format_files, format_src, test, BuildConfig, FormatMode,
CMD_BUILD, CMD_CHECK, CMD_DEV, CMD_DOCS, CMD_FORMAT, CMD_GLUE, CMD_PREPROCESS_HOST, CMD_REPL,
CMD_RUN, CMD_TEST, CMD_VERSION, DIRECTORY_OR_FILES, FLAG_CHECK, FLAG_DEV, FLAG_DOCS_ROOT,
FLAG_LIB, FLAG_MAIN, FLAG_MIGRATE, FLAG_NO_COLOR, FLAG_NO_HEADER, FLAG_NO_LINK, FLAG_OUTPUT,
FLAG_PP_DYLIB, FLAG_PP_HOST, FLAG_PP_PLATFORM, FLAG_STDIN, FLAG_STDOUT, FLAG_TARGET, FLAG_TIME,
FLAG_VERBOSE, GLUE_DIR, GLUE_SPEC, ROC_FILE, VERSION,
annotate_file, build_app, default_linking_strategy, format_files, format_src, test,
AnnotationProblem, BuildConfig, FormatMode, CMD_BUILD, CMD_CHECK, CMD_DEV, CMD_DOCS,
CMD_FORMAT, CMD_FORMAT_ANNOTATE, CMD_GLUE, CMD_PREPROCESS_HOST, CMD_REPL, CMD_RUN, CMD_TEST,
CMD_VERSION, DIRECTORY_OR_FILES, FLAG_CHECK, FLAG_DEV, FLAG_DOCS_ROOT, FLAG_LIB, FLAG_MAIN,
FLAG_MIGRATE, FLAG_NO_COLOR, FLAG_NO_HEADER, FLAG_NO_LINK, FLAG_OUTPUT, FLAG_PP_DYLIB,
FLAG_PP_HOST, FLAG_PP_PLATFORM, FLAG_STDIN, FLAG_STDOUT, FLAG_TARGET, FLAG_TIME, FLAG_VERBOSE,
GLUE_DIR, GLUE_SPEC, ROC_FILE, VERSION,
};
use roc_docs::generate_docs_html;
use roc_error_macros::user_error;
use roc_error_macros::{internal_error, user_error};
use roc_fmt::MigrationFlags;
use roc_gen_dev::AssemblyBackendMode;
use roc_gen_llvm::llvm::build::LlvmBackendMode;
@ -346,6 +347,42 @@ fn main() -> io::Result<()> {
Ok(0)
}
Some((CMD_FORMAT, fmatches)) if Some(CMD_FORMAT_ANNOTATE) == fmatches.subcommand_name() => {
let matches = fmatches
.subcommand_matches(CMD_FORMAT_ANNOTATE)
.unwrap_or_else(|| internal_error!("No annotate subcommand present"));
let arena = Bump::new();
let roc_file_path = matches
.get_one::<PathBuf>(ROC_FILE)
.unwrap_or_else(|| internal_error!("No default for ROC_FILE"));
let annotate_exit_code = match annotate_file(&arena, roc_file_path.to_owned()) {
Ok(()) => 0,
Err(AnnotationProblem::Loading(LoadingProblem::FormattedReport(report, ..))) => {
eprintln!("{report}");
1
}
Err(AnnotationProblem::Type(type_problem)) => {
eprintln!(
"The type generated for `{}` on line {} contains an error",
type_problem.name, type_problem.position.line,
);
eprintln!(
"run `roc check \"{}\"` for a more detailed error",
roc_file_path.to_str().unwrap_or_else(|| internal_error!(
"File path is not a valid utf8 string"
))
);
1
}
Err(other) => {
internal_error!("build_file failed with error:\n{other:?}");
}
};
Ok(annotate_exit_code)
}
Some((CMD_FORMAT, matches)) => {
let from_stdin = matches.get_flag(FLAG_STDIN);
let to_stdout = matches.get_flag(FLAG_STDOUT);

View file

@ -2990,16 +2990,16 @@ impl Declarations {
}
pub fn iter_bottom_up(&self) -> impl Iterator<Item = (usize, DeclarationTag)> + '_ {
self.declarations
.iter()
.rev()
.scan(self.declarations.len() - 1, |state, e| {
self.declarations.iter().rev().scan(
self.declarations.len().saturating_sub(1),
|state, e| {
let length_so_far = *state;
*state = length_so_far.saturating_sub(e.len());
Some((length_so_far, *e))
})
},
)
}
pub fn expects(&self) -> ExpectCollector {

View file

@ -31,6 +31,7 @@ pub enum DeclarationInfo<'a> {
expr_var: Variable,
pattern: Pattern,
function: &'a Loc<expr::FunctionDef>,
annotation: Option<&'a Annotation>,
},
Destructure {
loc_pattern: &'a Loc<Pattern>,
@ -113,6 +114,7 @@ pub fn walk_decls<V: Visitor>(visitor: &mut V, decls: &Declarations) {
let loc_symbol = decls.symbols[index];
let expr_var = decls.variables[index];
let annotation = decls.annotations[index].as_ref();
let pattern = match decls.specializes.get(&index).copied() {
Some(specializes) => Pattern::AbilityMemberSpecialization {
@ -130,6 +132,7 @@ pub fn walk_decls<V: Visitor>(visitor: &mut V, decls: &Declarations) {
expr_var,
pattern,
function: function_def,
annotation,
}
}
Destructure(destructure_index) => {
@ -191,6 +194,7 @@ pub fn walk_decl<V: Visitor>(visitor: &mut V, decl: DeclarationInfo<'_>) {
expr_var,
pattern,
function,
annotation,
} => {
visitor.visit_pattern(&pattern, loc_symbol.region, Some(expr_var));
@ -199,7 +203,11 @@ pub fn walk_decl<V: Visitor>(visitor: &mut V, decl: DeclarationInfo<'_>) {
&function.value.arguments,
loc_body,
function.value.return_type,
)
);
if let Some(annot) = annotation {
visitor.visit_annotation(annot);
}
}
Destructure {
loc_pattern,

View file

@ -2198,6 +2198,98 @@ impl Subs {
_ => false,
}
}
pub fn var_contains_error(&self, var: Variable) -> bool {
match &self.get_content_without_compacting(var).clone() {
Content::Error => true,
Content::FlexVar(Some(index)) => {
// Generated names for errors start with `#`
self[*index].as_str().starts_with('#')
}
Content::FlexVar(..)
| Content::RigidVar(..)
| Content::FlexAbleVar(..)
| Content::RigidAbleVar(..)
| Content::ErasedLambda
| Content::RangedNumber(..)
| Content::Pure
| Content::Effectful
| Content::Structure(FlatType::EmptyRecord)
| Content::Structure(FlatType::EmptyTagUnion)
| Content::Structure(FlatType::EffectfulFunc) => false,
Content::RecursionVar { structure, .. } => self.var_contains_error(*structure),
Content::LambdaSet(LambdaSet {
solved,
recursion_var,
unspecialized,
..
}) => {
if let Some(rec_var) = recursion_var.into_variable() {
if self.var_contains_error(rec_var) {
return true;
}
}
unspecialized
.into_iter()
.any(|uls_index| self.var_contains_error(self[uls_index].0))
|| solved.variables().into_iter().any(|slice_index| {
self[slice_index]
.into_iter()
.any(|var_index| self.var_contains_error(self[var_index]))
})
}
Content::Alias(_symbol, args, actual, _kind) => {
self.var_contains_error(*actual)
|| args
.into_iter()
.take(args.len())
.any(|index| self.var_contains_error(self[index]))
}
Content::Structure(FlatType::Apply(_, args)) => args
.into_iter()
.any(|index| self.var_contains_error(self[index])),
Content::Structure(FlatType::Func(arg_vars, closure_var, ret_var, fx_var)) => {
self.var_contains_error(*closure_var)
|| self.var_contains_error(*ret_var)
|| self.var_contains_error(*fx_var)
|| arg_vars
.into_iter()
.any(|index| self.var_contains_error(self[index]))
}
Content::Structure(FlatType::Record(sorted_fields, ext_var)) => {
self.var_contains_error(*ext_var)
|| sorted_fields
.iter_variables()
.any(|index| self.var_contains_error(self[index]))
}
Content::Structure(FlatType::Tuple(elems, ext_var)) => {
self.var_contains_error(*ext_var)
|| elems
.iter_variables()
.any(|index| self.var_contains_error(self[index]))
}
Content::Structure(FlatType::TagUnion(tags, ext_var)) => {
self.var_contains_error(ext_var.var())
|| tags.variables().into_iter().any(|slice_index| {
self[slice_index]
.into_iter()
.any(|var_index| self.var_contains_error(self[var_index]))
})
}
Content::Structure(FlatType::FunctionOrTagUnion(_, _, ext_var)) => {
self.var_contains_error(ext_var.var())
}
Content::Structure(FlatType::RecursiveTagUnion(rec_var, tags, ext_var)) => {
self.var_contains_error(ext_var.var())
|| self.var_contains_error(*rec_var)
|| tags.variables().into_iter().any(|slice_index| {
self[slice_index]
.into_iter()
.any(|var_index| self.var_contains_error(self[var_index]))
})
}
}
}
}
#[inline(always)]

View file

@ -25,6 +25,7 @@ roc_solve_problem.workspace = true
roc_target.workspace = true
roc_types.workspace = true
roc_packaging.workspace = true
roc_cli.workspace = true
bumpalo.workspace = true
parking_lot.workspace = true
@ -41,4 +42,3 @@ indoc.workspace = true
env_logger = "0.10.1"
futures.workspace = true
roc_error_macros.workspace = true

View file

@ -20,6 +20,7 @@ use roc_types::subs::{Subs, Variable};
use tower_lsp::lsp_types::{Diagnostic, SemanticTokenType, Url};
mod analysed_doc;
mod annotation_visitor;
mod completion;
mod parse_ast;
mod semantic_tokens;

View file

@ -1,4 +1,6 @@
use log::{debug, info};
use roc_cli::{annotation_edit, annotation_edits};
use roc_fmt::MigrationFlags;
use std::collections::HashMap;
@ -6,11 +8,12 @@ use bumpalo::Bump;
use roc_module::symbol::{ModuleId, Symbol};
use roc_region::all::LineInfo;
use roc_region::all::{LineInfo, Position as RocPosition, Region};
use tower_lsp::lsp_types::{
CompletionItem, Diagnostic, GotoDefinitionResponse, Hover, HoverContents, LanguageString,
Location, MarkedString, Position, Range, SemanticTokens, SemanticTokensResult, TextEdit, Url,
CodeAction, CodeActionKind, CompletionItem, Diagnostic, GotoDefinitionResponse, Hover,
HoverContents, LanguageString, Location, MarkedString, Position, Range, SemanticTokens,
SemanticTokensResult, TextEdit, Url, WorkspaceEdit,
};
use crate::{
@ -18,10 +21,11 @@ use crate::{
field_completion, get_completion_items, get_module_completion_items,
get_tag_completion_items,
},
convert::{ToRange, ToRocPosition},
convert::{ToRange, ToRegion, ToRocPosition},
};
use super::{
annotation_visitor::{find_declaration_at, FoundDeclaration, NotFound},
parse_ast::Ast,
semantic_tokens::arrange_semantic_tokens,
utils::{format_var_type, is_roc_identifier_char},
@ -324,4 +328,85 @@ impl AnalyzedDocument {
}
}
}
pub fn annotate(&self, range: Range) -> Option<CodeAction> {
let region = range.to_region(self.line_info());
match find_declaration_at(region, &self.module()?.declarations) {
Ok(found_declaration) => self.annotate_declaration(found_declaration),
Err(NotFound::TopLevel) => self.annnotate_top_level(),
_ => None,
}
}
fn annnotate_top_level(&self) -> Option<CodeAction> {
let AnalyzedModule {
module_id,
interns,
subs,
abilities,
declarations,
..
} = self.module()?;
let edits = annotation_edits(
declarations,
subs,
abilities,
&self.doc_info.source,
*module_id,
interns,
)
.ok()?
.into_iter()
.map(|(offset, new_text)| {
let pos = roc_region::all::Position::new(offset as u32);
let range = Region::new(pos, pos).to_range(self.line_info());
TextEdit { range, new_text }
})
.collect();
Some(CodeAction {
title: "Add top-level signatures".to_owned(),
edit: Some(WorkspaceEdit::new(HashMap::from([(
self.url().clone(),
edits,
)]))),
kind: Some(CodeActionKind::SOURCE),
..Default::default()
})
}
fn annotate_declaration(&self, decl: FoundDeclaration) -> Option<CodeAction> {
let AnalyzedModule {
module_id,
interns,
subs,
..
} = self.module()?;
let (offset, new_text) = annotation_edit(
&self.doc_info.source,
subs,
interns,
*module_id,
decl.var,
decl.range,
)
.ok()?;
let pos = RocPosition::new(offset as u32);
let range = Region::new(pos, pos).to_range(self.line_info());
let edit = TextEdit { range, new_text };
Some(CodeAction {
title: "Add signature".to_owned(),
edit: Some(WorkspaceEdit::new(HashMap::from([(
self.url().clone(),
vec![edit],
)]))),
..Default::default()
})
}
}

View file

@ -0,0 +1,90 @@
use roc_can::{
def::{Def, DefKind},
expr::{Declarations, Expr},
traverse::{self, DeclarationInfo, Visitor},
};
use roc_region::all::Region;
use roc_types::subs::Variable;
use std::ops::Range;
pub struct FoundDeclaration {
pub var: Variable,
pub range: Range<usize>,
}
pub enum NotFound {
TopLevel,
AlreadyAnnotated,
}
pub fn find_declaration_at(
region: Region,
decls: &Declarations,
) -> Result<FoundDeclaration, NotFound> {
let mut visitor = Finder {
region,
found: Err(NotFound::TopLevel),
};
visitor.visit_decls(decls);
return visitor.found;
struct Finder {
region: Region,
found: Result<FoundDeclaration, NotFound>,
}
impl Visitor for Finder {
fn should_visit(&mut self, region: Region) -> bool {
region.contains(&self.region)
}
fn visit_decl(&mut self, decl: DeclarationInfo<'_>) {
if self.should_visit(decl.region()) {
match decl {
DeclarationInfo::Value { loc_expr, .. }
if matches!(loc_expr.value, Expr::ImportParams(..)) => {}
DeclarationInfo::Value {
expr_var: var,
loc_symbol,
annotation,
..
}
| DeclarationInfo::Function {
expr_var: var,
loc_symbol,
annotation,
..
} if annotation.is_none() => {
let range = loc_symbol.byte_range();
self.found = Ok(FoundDeclaration { var, range })
}
DeclarationInfo::Destructure {
expr_var: var,
loc_pattern,
annotation,
..
} if annotation.is_none() => {
let range = loc_pattern.byte_range();
self.found = Ok(FoundDeclaration { var, range })
}
DeclarationInfo::Expectation { .. } => {}
_ => self.found = Err(NotFound::AlreadyAnnotated),
}
traverse::walk_decl(self, decl)
}
}
fn visit_def(&mut self, def: &Def) {
if self.should_visit(def.region()) {
if !matches!(def.kind, DefKind::Stmt(..)) && def.annotation.is_none() {
self.found = Ok(FoundDeclaration {
var: def.expr_var,
range: def.loc_pattern.byte_range(),
});
}
traverse::walk_def(self, def)
}
}
}
}

View file

@ -42,7 +42,7 @@ impl ToRegion for Range {
},
end: LineColumn {
line: self.end.line,
column: self.end.line,
column: self.end.character,
},
};

View file

@ -9,8 +9,8 @@ use std::{
use tokio::sync::{Mutex, MutexGuard};
use tower_lsp::lsp_types::{
CompletionResponse, Diagnostic, GotoDefinitionResponse, Hover, Position, SemanticTokensResult,
TextEdit, Url,
CodeActionOrCommand, CodeActionResponse, CompletionResponse, Diagnostic,
GotoDefinitionResponse, Hover, Position, Range, SemanticTokensResult, TextEdit, Url,
};
use crate::analysis::{AnalyzedDocument, DocInfo};
@ -217,4 +217,14 @@ impl Registry {
Some(CompletionResponse::Array(completions))
}
pub async fn code_actions(&self, url: &Url, range: Range) -> Option<CodeActionResponse> {
let document = self.latest_document_by_url(url).await?;
let mut responses = vec![];
if let Some(edit) = document.annotate(range) {
responses.push(CodeActionOrCommand::CodeAction(edit));
}
Some(responses)
}
}

View file

@ -103,6 +103,7 @@ impl RocServer {
work_done_progress: None,
},
};
let code_action_provider = CodeActionProviderCapability::Simple(true);
ServerCapabilities {
text_document_sync: Some(text_document_sync),
hover_provider: Some(hover_provider),
@ -110,6 +111,7 @@ impl RocServer {
document_formatting_provider: Some(OneOf::Right(document_formatting_provider)),
semantic_tokens_provider: Some(semantic_tokens_provider),
completion_provider: Some(completion_provider),
code_action_provider: Some(code_action_provider),
..ServerCapabilities::default()
}
}
@ -338,6 +340,18 @@ impl LanguageServer for RocServer {
)
.await
}
async fn code_action(&self, params: CodeActionParams) -> Result<Option<CodeActionResponse>> {
let CodeActionParams {
text_document,
range,
context: _,
partial_result_params: _,
work_done_progress_params: _,
} = params;
unwind_async(self.state.registry.code_actions(&text_document.uri, range)).await
}
}
async fn unwind_async<Fut, T>(future: Fut) -> tower_lsp::jsonrpc::Result<T>
@ -466,7 +480,7 @@ mod tests {
+ indoc! {r#"
main =
when a is
inn as outer ->
inn as outer ->
"#};
let (inner, url) = test_setup(suffix.clone()).await;
@ -508,7 +522,7 @@ mod tests {
+ indoc! {r#"
main =
when a is
{one,two} as outer ->
{one,two} as outer ->
"#};
let (inner, url) = test_setup(doc.clone()).await;
@ -572,7 +586,7 @@ mod tests {
async fn test_completion_closure() {
let actual = completion_test_labels(
indoc! {r"
main = [] |> List.map \ param1 , param2->
main = [] |> List.map \ param1 , param2->
"},
"par",
Position::new(4, 3),
@ -644,4 +658,147 @@ mod tests {
"#]]
.assert_debug_eq(&actual);
}
async fn code_action_edits(doc: String, position: Position, name: &str) -> Vec<TextEdit> {
let (inner, url) = test_setup(doc.clone()).await;
let registry = &inner.registry;
let actions = registry
.code_actions(&url, Range::new(position, position))
.await
.unwrap();
actions
.into_iter()
.find_map(|either| match either {
CodeActionOrCommand::CodeAction(action) if name == action.title => Some(action),
_ => None,
})
.expect("Code action not present")
.edit
.expect("Code action does not have an associated edit")
.changes
.expect("Edit does not have any changes")
.get(&url)
.expect("Edit does not have changes for this file")
.clone()
}
#[tokio::test]
async fn test_annotate_single() {
let edit = code_action_edits(
DOC_LIT.to_string() + r#"main = "Hello, world!""#,
Position::new(3, 2),
"Add signature",
)
.await;
expect![[r#"
[
TextEdit {
range: Range {
start: Position {
line: 3,
character: 0,
},
end: Position {
line: 3,
character: 0,
},
},
new_text: "main : Str\n",
},
]
"#]]
.assert_debug_eq(&edit);
}
#[tokio::test]
async fn test_annotate_top_level() {
let edit = code_action_edits(
DOC_LIT.to_string()
+ indoc! {r#"
other = \_ ->
"Something else?"
main =
other {}
"#},
Position::new(5, 0),
"Add top-level signatures",
)
.await;
expect![[r#"
[
TextEdit {
range: Range {
start: Position {
line: 3,
character: 0,
},
end: Position {
line: 3,
character: 0,
},
},
new_text: "other : * -> Str\n",
},
TextEdit {
range: Range {
start: Position {
line: 6,
character: 0,
},
end: Position {
line: 6,
character: 0,
},
},
new_text: "main : Str\n",
},
]
"#]]
.assert_debug_eq(&edit);
}
#[tokio::test]
async fn test_annotate_inner() {
let edit = code_action_edits(
DOC_LIT.to_string()
+ indoc! {r#"
main =
start = 10
fib start 0 1
fib = \n, a, b ->
if n == 0 then
a
else
fib (n - 1) b (a + b)
"#},
Position::new(4, 8),
"Add signature",
)
.await;
expect![[r#"
[
TextEdit {
range: Range {
start: Position {
line: 4,
character: 0,
},
end: Position {
line: 4,
character: 0,
},
},
new_text: " start : Num *\n",
},
]
"#]]
.assert_debug_eq(&edit);
}
}