change command handler strat to function decorator (#26)

* change command handler strat to function decorator

* add handlers back to agent class
This commit is contained in:
Josh Thomas 2024-12-12 19:39:43 -06:00 committed by GitHub
parent 0a6e975ca5
commit 520a2eff59
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 279 additions and 447 deletions

View file

@ -1,47 +1,34 @@
from __future__ import annotations
import logging
import struct
import sys
from typing import Any
from typing import cast
from google.protobuf.message import Message
from .commands import COMMANDS
from .commands import Command
from .logging import configure_logging
from .proto.v1 import messages_pb2
logger = logging.getLogger("djls")
logger.setLevel(logging.DEBUG)
fh = logging.FileHandler("/tmp/djls_debug.log")
fh.setLevel(logging.DEBUG)
ch = logging.StreamHandler(sys.stderr)
ch.setLevel(logging.DEBUG)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
fh.setFormatter(formatter)
ch.setFormatter(formatter)
logger.addHandler(fh)
logger.addHandler(ch)
logger = configure_logging()
class LSPAgent:
def __init__(self):
self._commands: dict[str, Command] = {cmd.name: cmd() for cmd in COMMANDS}
from .handlers import handlers
self._handlers = handlers
logger.debug(
"LSPAgent initialized with commands: %s", list(self._commands.keys())
"LSPAgent initialized with handlers: %s", list(self._handlers.keys())
)
def serve(self):
async def serve(self):
print("ready", flush=True)
try:
import django
django.setup()
except Exception as e:
error_response = self.create_error(messages_pb2.Error.DJANGO_ERROR, str(e))
self.write_message(error_response)
@ -52,7 +39,7 @@ class LSPAgent:
if not data:
break
response = self.handle_request(data)
response = await self.handle_request(data)
self.write_message(response)
except Exception as e:
@ -71,23 +58,30 @@ class LSPAgent:
logger.debug("Read data bytes: %r", data)
return data
def handle_request(self, request_data: bytes) -> Message:
async def handle_request(self, request_data: bytes) -> Message:
request = messages_pb2.Request()
request.ParseFromString(request_data)
command_name = request.WhichOneof("command")
logger.debug("Command name: %s", command_name)
command = self._commands.get(command_name)
if not command:
if not command_name:
logger.error("No command specified")
return self.create_error(
messages_pb2.Error.INVALID_REQUEST, "No command specified"
)
handler = self._handlers.get(command_name)
if not handler:
logger.error("Unknown command: %s", command_name)
return self.create_error(
messages_pb2.Error.INVALID_REQUEST, f"Unknown command: {command_name}"
)
try:
result = command.execute(getattr(request, command_name))
return messages_pb2.Response(**{command_name: result})
command_message = getattr(request, command_name)
result = await handler(command_message)
return messages_pb2.Response(**{command_name: cast(Any, result)})
except Exception as e:
logger.exception("Error executing command")
return self.create_error(messages_pb2.Error.UNKNOWN, str(e))
@ -110,14 +104,14 @@ class LSPAgent:
return response
def main() -> None:
async def main() -> None:
logger.debug("Starting DJLS...")
try:
logger.debug("Initializing LSPAgent...")
agent = LSPAgent()
logger.debug("Starting LSPAgent serve...")
agent.serve()
await agent.serve()
except KeyboardInterrupt:
logger.debug("Received KeyboardInterrupt")
sys.exit(0)
@ -128,4 +122,6 @@ def main() -> None:
if __name__ == "__main__":
main()
import asyncio
asyncio.run(main())

View file

@ -1,209 +0,0 @@
from __future__ import annotations
import importlib.metadata
import os
import sys
import sysconfig
from abc import ABC
from abc import abstractmethod
from typing import ClassVar
from typing import Generic
from typing import TypeVar
from google.protobuf.message import Message
from ._typing import override
from .proto.v1 import check_pb2
from .proto.v1 import django_pb2
from .proto.v1 import python_pb2
Request = TypeVar("Request", bound=Message)
Response = TypeVar("Response", bound=Message)
class Command(ABC, Generic[Request, Response]):
name: ClassVar[str]
request: ClassVar[type[Message]]
response: ClassVar[type[Message]]
def __init_subclass__(cls) -> None:
super().__init_subclass__()
class_vars = ["name", "request", "response"]
for class_var in class_vars:
if not hasattr(cls, class_var):
raise TypeError(
f"Command subclass {cls.__name__} must define '{class_var}'"
)
@abstractmethod
def execute(self, request: Request) -> Response: ...
class CheckHealth(Command[check_pb2.HealthRequest, check_pb2.HealthResponse]):
name = "check__health"
request = check_pb2.HealthRequest
response = check_pb2.HealthResponse
@override
def execute(self, request: check_pb2.HealthRequest) -> check_pb2.HealthResponse:
return check_pb2.HealthResponse(passed=True)
class CheckDjangoAvailable(
Command[check_pb2.DjangoAvailableRequest, check_pb2.DjangoAvailableResponse]
):
name = "check__django_available"
request = check_pb2.DjangoAvailableRequest
response = check_pb2.DjangoAvailableResponse
@override
def execute(
self, request: check_pb2.DjangoAvailableRequest
) -> check_pb2.DjangoAvailableResponse:
try:
import django
return check_pb2.DjangoAvailableResponse(passed=True)
except ImportError:
return check_pb2.DjangoAvailableResponse(
passed=False, error="Django is not installed"
)
class CheckAppInstalled(
Command[check_pb2.AppInstalledRequest, check_pb2.AppInstalledResponse]
):
name = "check__app_installed"
request = check_pb2.AppInstalledRequest
response = check_pb2.AppInstalledResponse
@override
def execute(
self, request: check_pb2.AppInstalledRequest
) -> check_pb2.AppInstalledResponse:
try:
from django.apps import apps
return check_pb2.AppInstalledResponse(
passed=apps.is_installed(request.app_name)
)
except ImportError:
return check_pb2.AppInstalledResponse(
passed=False, error="Django is not installed"
)
class PythonGetEnvironment(
Command[python_pb2.GetEnvironmentRequest, python_pb2.GetEnvironmentResponse]
):
name = "python__get_environment"
request = python_pb2.GetEnvironmentRequest
response = python_pb2.GetEnvironmentResponse
@override
def execute(
self, request: python_pb2.GetEnvironmentRequest
) -> python_pb2.GetEnvironmentResponse:
packages = {}
for dist in importlib.metadata.distributions():
try:
requires = []
try:
requires = list(dist.requires) if hasattr(dist, "requires") else []
except Exception:
pass
location = None
try:
location = str(dist._path) if hasattr(dist, "_path") else None
except Exception:
pass
packages[dist.metadata["Name"]] = python_pb2.Package(
dist_name=dist.metadata["Name"],
dist_version=dist.metadata["Version"],
dist_location=location,
dist_requires=requires,
dist_requires_python=dist.metadata.get("Requires-Python"),
dist_entry_points=str(dist.entry_points)
if hasattr(dist, "entry_points")
else None,
)
except Exception:
continue
sysconfig_paths = sysconfig.get_paths()
version_info = python_pb2.VersionInfo(
major=sys.version_info.major,
minor=sys.version_info.minor,
micro=sys.version_info.micro,
releaselevel={
"alpha": python_pb2.ReleaseLevel.ALPHA,
"beta": python_pb2.ReleaseLevel.BETA,
"candidate": python_pb2.ReleaseLevel.CANDIDATE,
"final": python_pb2.ReleaseLevel.FINAL,
}[sys.version_info.releaselevel],
serial=sys.version_info.serial,
)
return python_pb2.GetEnvironmentResponse(
python=python_pb2.Python(
os=python_pb2.Os(environ={k: v for k, v in os.environ.items()}),
site=python_pb2.Site(packages=packages),
sys=python_pb2.Sys(
debug_build=hasattr(sys, "gettotalrefcount"),
dev_mode=sys.flags.dev_mode,
is_venv=sys.prefix != sys.base_prefix,
abiflags=sys.abiflags,
base_prefix=sys.base_prefix,
default_encoding=sys.getdefaultencoding(),
executable=sys.executable,
filesystem_encoding=sys.getfilesystemencoding(),
implementation_name=sys.implementation.name,
platform=sys.platform,
prefix=sys.prefix,
builtin_module_names=list(sys.builtin_module_names),
dll_paths=sys.path if sys.platform == "win32" else [],
path=sys.path,
version_info=version_info,
),
sysconfig=python_pb2.Sysconfig(
data=sysconfig_paths.get("data", ""),
include=sysconfig_paths.get("include", ""),
platinclude=sysconfig_paths.get("platinclude", ""),
platlib=sysconfig_paths.get("platlib", ""),
platstdlib=sysconfig_paths.get("platstdlib", ""),
purelib=sysconfig_paths.get("purelib", ""),
scripts=sysconfig_paths.get("scripts", ""),
stdlib=sysconfig_paths.get("stdlib", ""),
),
)
)
class DjangoGetProjectInfo(
Command[django_pb2.GetProjectInfoRequest, django_pb2.GetProjectInfoResponse]
):
name = "django__get_project_info"
request = django_pb2.GetProjectInfoRequest
response = django_pb2.GetProjectInfoResponse
@override
def execute(
self, request: django_pb2.GetProjectInfoRequest
) -> django_pb2.GetProjectInfoResponse:
import django
return django_pb2.GetProjectInfoResponse(
project=django_pb2.Project(version=django.__version__)
)
COMMANDS = [
CheckAppInstalled,
CheckDjangoAvailable,
CheckHealth,
PythonGetEnvironment,
DjangoGetProjectInfo,
]

208
python/djls/handlers.py Normal file
View file

@ -0,0 +1,208 @@
from __future__ import annotations
import importlib.metadata
import inspect
import os
import sys
import sysconfig
import traceback
from collections.abc import Awaitable
from collections.abc import Coroutine
from functools import wraps
from typing import Any
from typing import Callable
from typing import TypeVar
from typing import cast
import django
from django.apps import apps
from google.protobuf.message import Message
from .proto.v1 import check_pb2
from .proto.v1 import django_pb2
from .proto.v1 import messages_pb2
from .proto.v1 import python_pb2
T = TypeVar("T", bound=Message)
R = TypeVar("R", bound=Message)
handlers: dict[str, Callable[[Message], Coroutine[Any, Any, Message]]] = {}
def proto_handler(
request_type: type[T],
error: messages_pb2.Error | None = None,
) -> Callable[
[Callable[[T], R] | Callable[[T], Awaitable[R]]],
Callable[[T], Coroutine[Any, Any, R]],
]:
for req_field in messages_pb2.Request.DESCRIPTOR.fields:
if req_field.message_type == request_type.DESCRIPTOR:
command_name = req_field.name
# Find corresponding response type
for resp_field in messages_pb2.Response.DESCRIPTOR.fields:
if resp_field.name == command_name:
response_type = resp_field.message_type._concrete_class
break
else:
raise ValueError(f"No response type found for {request_type}")
break
else:
raise ValueError(f"Message type {request_type} not found in Request message")
def decorator(
func: Callable[[T], R] | Callable[[T], Awaitable[R]],
) -> Callable[[T], Coroutine[Any, Any, R]]:
is_async = inspect.iscoroutinefunction(func)
@wraps(func)
async def wrapper(request: T) -> R:
try:
if is_async:
result = await cast(Callable[[T], Awaitable[R]], func)(request)
else:
result = cast(Callable[[T], R], func)(request)
# Runtime type checking
if not isinstance(result, response_type):
raise TypeError(
f"Handler returned {type(result)}, expected {response_type}"
)
return result
except Exception as e:
if error:
err = error
else:
err = messages_pb2.Error(
code=messages_pb2.Error.PYTHON_ERROR,
message=str(e),
traceback=traceback.format_exc(),
)
return cast(R, messages_pb2.Response(error=err))
handlers[command_name] = wrapper # pyright: ignore[reportArgumentType]
return wrapper
return decorator
@proto_handler(check_pb2.HealthRequest)
async def check__health(_request: check_pb2.HealthRequest) -> check_pb2.HealthResponse:
return check_pb2.HealthResponse(passed=True)
@proto_handler(
check_pb2.DjangoAvailableRequest,
error=messages_pb2.Error(
code=messages_pb2.Error.DJANGO_ERROR, message="Django is not installed"
),
)
async def check__django_available(
_request: check_pb2.DjangoAvailableRequest,
) -> check_pb2.DjangoAvailableResponse:
import django # noqa: F401
return check_pb2.DjangoAvailableResponse(passed=True)
@proto_handler(
check_pb2.AppInstalledRequest,
error=messages_pb2.Error(
code=messages_pb2.Error.DJANGO_ERROR, message="App is not in INSTALLED_APPS"
),
)
async def check__app_installed(
request: check_pb2.AppInstalledRequest,
) -> check_pb2.AppInstalledResponse:
return check_pb2.AppInstalledResponse(passed=apps.is_installed(request.app_name))
@proto_handler(python_pb2.GetEnvironmentRequest)
async def python__get_environment(
_request: python_pb2.GetEnvironmentRequest,
) -> python_pb2.GetEnvironmentResponse:
packages = {}
for dist in importlib.metadata.distributions():
try:
requires = []
try:
requires = list(dist.requires) if hasattr(dist, "requires") else []
except Exception:
pass
location = None
try:
location = str(dist._path) if hasattr(dist, "_path") else None
except Exception:
pass
packages[dist.metadata["Name"]] = python_pb2.Package(
dist_name=dist.metadata["Name"],
dist_version=dist.metadata["Version"],
dist_location=location,
dist_requires=requires,
dist_requires_python=dist.metadata.get("Requires-Python"),
dist_entry_points=str(dist.entry_points)
if hasattr(dist, "entry_points")
else None,
)
except Exception:
continue
sysconfig_paths = sysconfig.get_paths()
version_info = python_pb2.VersionInfo(
major=sys.version_info.major,
minor=sys.version_info.minor,
micro=sys.version_info.micro,
releaselevel={
"alpha": python_pb2.ReleaseLevel.ALPHA,
"beta": python_pb2.ReleaseLevel.BETA,
"candidate": python_pb2.ReleaseLevel.CANDIDATE,
"final": python_pb2.ReleaseLevel.FINAL,
}[sys.version_info.releaselevel],
serial=sys.version_info.serial,
)
return python_pb2.GetEnvironmentResponse(
python=python_pb2.Python(
os=python_pb2.Os(environ={k: v for k, v in os.environ.items()}),
site=python_pb2.Site(packages=packages),
sys=python_pb2.Sys(
debug_build=hasattr(sys, "gettotalrefcount"),
dev_mode=sys.flags.dev_mode,
is_venv=sys.prefix != sys.base_prefix,
abiflags=sys.abiflags,
base_prefix=sys.base_prefix,
default_encoding=sys.getdefaultencoding(),
executable=sys.executable,
filesystem_encoding=sys.getfilesystemencoding(),
implementation_name=sys.implementation.name,
platform=sys.platform,
prefix=sys.prefix,
builtin_module_names=list(sys.builtin_module_names),
dll_paths=sys.path if sys.platform == "win32" else [],
path=sys.path,
version_info=version_info,
),
sysconfig=python_pb2.Sysconfig(
data=sysconfig_paths.get("data", ""),
include=sysconfig_paths.get("include", ""),
platinclude=sysconfig_paths.get("platinclude", ""),
platlib=sysconfig_paths.get("platlib", ""),
platstdlib=sysconfig_paths.get("platstdlib", ""),
purelib=sysconfig_paths.get("purelib", ""),
scripts=sysconfig_paths.get("scripts", ""),
stdlib=sysconfig_paths.get("stdlib", ""),
),
)
)
@proto_handler(django_pb2.GetProjectInfoRequest)
async def django__get_project_info(
_request: django_pb2.GetProjectInfoRequest,
) -> django_pb2.GetProjectInfoResponse:
return django_pb2.GetProjectInfoResponse(
project=django_pb2.Project(version=django.__version__)
)

44
python/djls/logging.py Normal file
View file

@ -0,0 +1,44 @@
from __future__ import annotations
import logging
import sys
from dataclasses import dataclass
from pathlib import Path
@dataclass
class LogConfig:
log_file: Path | str = "/tmp/djls_debug.log"
log_level: int = logging.DEBUG
console_level: int = logging.DEBUG
file_level: int = logging.DEBUG
format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
def configure_logging(config: LogConfig | None = None) -> logging.Logger:
if config is None:
config = LogConfig()
logger = logging.getLogger("djls")
logger.setLevel(config.log_level)
# Clear any existing handlers
logger.handlers.clear()
# File handler
fh = logging.FileHandler(config.log_file)
fh.setLevel(config.file_level)
# Console handler
ch = logging.StreamHandler(sys.stderr)
ch.setLevel(config.console_level)
# Formatter
formatter = logging.Formatter(config.format)
fh.setFormatter(formatter)
ch.setFormatter(formatter)
logger.addHandler(fh)
logger.addHandler(ch)
return logger

View file

@ -1,68 +0,0 @@
from __future__ import annotations
import json
import sys
from typing import Any
from .scripts import django_setup
from .scripts import has_import
from .scripts import python_setup
def handle_json_command(data: dict[str, Any]) -> dict[str, Any]:
command = data["command"]
args = data.get("args", []) # Get args if they exist
if command == "django_setup":
import django
django.setup()
return {"status": "ok", "data": django_setup.get_django_setup_info()}
if command == "has_import":
if not args:
return {"status": "error", "error": "Missing module name argument"}
return {
"status": "ok",
"data": {"can_import": has_import.check_import(args[0])},
}
if command == "health":
return {"status": "ok"}
if command == "installed_apps_check":
import django
from django.conf import settings
django.setup()
if not args:
return {"status": "error", "error": "Missing module name argument"}
return {
"status": "ok",
"data": {"has_app": args[0] in settings.INSTALLED_APPS},
}
if command == "python_setup":
return {"status": "ok", "data": python_setup.get_python_info()}
if command == "version":
return {"status": "ok", "data": "0.1.0"}
return {"status": "error", "error": f"Unknown command: {command}"}
def main():
transport_type = sys.stdin.readline().strip()
print("ready", flush=True)
while True:
try:
line = sys.stdin.readline()
if not line:
break
data = json.loads(line)
response = handle_json_command(data)
print(json.dumps(response), flush=True)
except Exception as e:
print(json.dumps({"status": "error", "error": str(e)}), flush=True)
if __name__ == "__main__":
main()

View file

@ -1,31 +0,0 @@
from __future__ import annotations
import json
def get_django_setup_info():
from django.conf import settings
from django.template.engine import Engine
return {
"installed_apps": list(settings.INSTALLED_APPS),
"templatetags": [
{
"name": tag_name,
"library": module_name.split(".")[-1],
"doc": tag_func.__doc__ if hasattr(tag_func, "__doc__") else None,
}
for module_name, library in (
[("", lib) for lib in Engine.get_default().template_builtins]
+ sorted(Engine.get_default().template_libraries.items())
)
for tag_name, tag_func in library.tags.items()
],
}
if __name__ == "__main__":
import django
django.setup()
print(json.dumps(get_django_setup_info()))

View file

@ -1,21 +0,0 @@
# has_import.py
from __future__ import annotations
import json
import sys
def check_import(module: str) -> bool:
try:
module_parts = module.split(".")
current = __import__(module_parts[0])
for part in module_parts[1:]:
current = getattr(current, part)
return True
except (ImportError, AttributeError):
return False
if __name__ == "__main__":
result = {"can_import": check_import(sys.argv[1])}
print(json.dumps(result))

View file

@ -1,9 +0,0 @@
from __future__ import annotations
import json
import sys
from django.conf import settings
if __name__ == "__main__":
print(json.dumps({"has_app": sys.argv[1] in settings.INSTALLED_APPS}))

View file

@ -1,78 +0,0 @@
from __future__ import annotations
import importlib.metadata
import json
import sys
import sysconfig
from typing import Dict
from typing import List
from typing import Optional
from typing import TypedDict
def get_version_info():
version_parts = sys.version.split()[0].split(".")
patch_and_suffix = version_parts[2]
for i, c in enumerate(patch_and_suffix):
if not c.isdigit():
patch = patch_and_suffix[:i]
suffix = patch_and_suffix[i:]
break
else:
patch = patch_and_suffix
suffix = None
return {
"major": int(version_parts[0]),
"minor": int(version_parts[1]),
"patch": int(patch),
"suffix": suffix,
}
class Package(TypedDict):
name: str
version: str
location: Optional[str]
def get_installed_packages() -> Dict[str, Package]:
packages: Dict[str, Package] = {}
for dist in importlib.metadata.distributions():
try:
location_path = dist.locate_file("")
location = location_path.parent.as_posix() if location_path else None
packages[dist.metadata["Name"]] = {
"name": dist.metadata["Name"],
"version": dist.version,
"location": location,
}
except Exception:
continue
return packages
def get_python_info() -> (
Dict[
str,
str
| Dict[str, str]
| List[str]
| Dict[str, Package]
| Dict[str, int | str | None],
]
):
return {
"version_info": get_version_info(),
"sysconfig_paths": sysconfig.get_paths(),
"sys_prefix": sys.prefix,
"sys_base_prefix": sys.base_prefix,
"sys_executable": sys.executable,
"sys_path": [p for p in sys.path if p],
"packages": get_installed_packages(),
}
if __name__ == "__main__":
print(json.dumps(get_python_info()))