add new render hooks for transparent progress

This commit is contained in:
Will McGugan 2020-06-12 21:00:56 +01:00
parent 8b5a8ddf07
commit f8b45a87be
4 changed files with 99 additions and 101 deletions

View file

@ -235,6 +235,15 @@ class ConsoleThreadLocals(threading.local):
buffer_index: int = 0
class RenderHook:
"""Provides hooks in to render process."""
def process_renderables(
self, renderables: List[ConsoleRenderable]
) -> List[ConsoleRenderable]:
return renderables
def detect_legacy_windows() -> bool:
"""Detect legacy Windows."""
return "WINDIR" in os.environ and "WT_SESSION" not in os.environ
@ -325,6 +334,7 @@ class Console:
self._record_buffer_lock = threading.RLock()
self._thread_locals = ConsoleThreadLocals()
self._record_buffer: List[Segment] = []
self._render_hooks: List[RenderHook] = []
def __repr__(self) -> str:
return f"<console width={self.width} {str(self._color_system)}>"
@ -368,6 +378,19 @@ class Console:
self._buffer_index -= 1
self._check_buffer()
def push_render_hook(self, hook: RenderHook) -> None:
"""Add a new render hook to the stack.
Args:
hook (RenderHook): Render hook instance.
"""
self._render_hooks.append(hook)
def pop_render_hook(self) -> None:
"""Pop the last renderhook from the stack."""
self._render_hooks.pop()
def __enter__(self) -> "Console":
"""Own context manager to enter buffer context."""
self._enter_buffer()
@ -684,6 +707,7 @@ class Console:
if rich_cast:
renderable = rich_cast()
if isinstance(renderable, str):
if renderable:
append_text(
self.render_str(
renderable,
@ -754,7 +778,7 @@ class Console:
end (str, optional): String to write at end of print data. Defaults to "\n".
style (Union[str, Style], optional): A style to apply to output. Defaults to None.
justify (str, optional): Justify method: "default", "left", "right", "center", or "full". Defaults to ``None``.
overflow (str, optional): Overflow method: "crop", "fold", or "ellipisis". Defaults to None.
overflow (str, optional): Overflow method: "crop", "fold", or "ellipsis". Defaults to None.
no_wrap (Optional[bool], optional): Disable word wrapping. Defaults to None.
emoji (Optional[bool], optional): Enable emoji code, or ``None`` to use console default. Defaults to ``None``.
markup (Optional[bool], optional): Enable markup, or ``None`` to use console default. Defaults to ``None``.
@ -775,6 +799,8 @@ class Console:
markup=markup,
highlight=highlight,
)
for hook in self._render_hooks:
renderables = hook.process_renderables(renderables)
render_options = self.options.update(
justify=justify, overflow=overflow, width=width, no_wrap=no_wrap
)
@ -843,6 +869,7 @@ class Console:
if not objects:
self.line()
return
with self:
renderables = self._collect_renderables(
objects,
sep,
@ -864,13 +891,16 @@ class Console:
}
renderables.append(tabulate_mapping(locals_map, title="Locals"))
with self:
self._buffer.extend(
self.render(
self._log_render(self, renderables, path=path, line_no=line_no),
self.options,
)
)
renderables = [
self._log_render(self, renderables, path=path, line_no=line_no)
]
for hook in self._render_hooks:
renderables = hook.process_renderables(renderables)
extend = self._buffer.extend
render = self.render
render_options = self.options
for renderable in renderables:
extend(render(renderable, render_options))
def _check_buffer(self) -> None:
"""Check if the buffer may be rendered."""
@ -883,6 +913,7 @@ class Console:
del self._buffer[:]
else:
text = self._render_buffer()
if text:
self.file.write(text)
self.file.flush()

View file

@ -37,7 +37,7 @@ class LiveRender:
if self._shape is not None:
_, height = self._shape
if height > 1:
return Control(f"\r\x1b[{height - 1}A\x1b[2K")
return Control("\r\x1b[2K" + "\x1b[1A\x1b[2K" * (height - 1))
else:
return Control("\r\x1b[2K")
return Control("")

View file

@ -3,6 +3,7 @@ from datetime import datetime
from logging import Handler, LogRecord
from pathlib import Path
from . import get_console
from rich._log_render import LogRender
from rich.console import Console
from rich.highlighter import ReprHighlighter
@ -25,7 +26,7 @@ class RichHandler(Handler):
def __init__(self, level: int = logging.NOTSET, console: Console = None) -> None:
super().__init__(level=level)
self.console = Console() if console is None else console
self.console = get_console() if console is None else console
self.highlighter = ReprHighlighter()
self._log_render = LogRender(show_level=True)

View file

@ -27,7 +27,15 @@ from typing import (
from . import get_console
from .bar import Bar
from .console import Console, JustifyMethod, RenderGroup, RenderableType
from .console import (
Console,
ConsoleRenderable,
JustifyMethod,
RenderGroup,
RenderHook,
RenderableType,
)
from .control import Control
from .highlighter import Highlighter
from . import filesize
from .live_render import LiveRender
@ -368,7 +376,7 @@ class _RefreshThread(Thread):
self.progress.refresh()
class Progress:
class Progress(RenderHook):
"""Renders an auto-updating progress bar(s).
Args:
@ -410,6 +418,8 @@ class Progress:
self._refresh_thread: Optional[_RefreshThread] = None
self._refresh_count = 0
self._started = False
self.print = self.console.print
self.log = self.console.log
@property
def tasks(self) -> List[Task]:
@ -438,6 +448,7 @@ class Progress:
return
self._started = True
self.console.show_cursor(False)
self.console.push_render_hook(self)
self.refresh()
if self.auto_refresh:
self._refresh_thread = _RefreshThread(self, self.refresh_per_second)
@ -457,6 +468,7 @@ class Progress:
self.console.line()
finally:
self.console.show_cursor(True)
self.console.pop_render_hook()
if self._refresh_thread is not None:
self._refresh_thread.join()
self._refresh_thread = None
@ -598,8 +610,9 @@ class Progress:
self._live_render.set_renderable(self.get_renderable())
if self.console.is_terminal:
with self.console:
self.console.print(self._live_render.position_cursor())
self.console.print(self._live_render)
# self.console.print(self._live_render.position_cursor())
self.console.print(Control(""))
# self.console.print(self._live_render)
self._refresh_count += 1
def get_renderable(self) -> RenderableType:
@ -693,65 +706,17 @@ class Progress:
with self._lock:
del self._tasks[task_id]
def print(
self,
*objects: Any,
sep=" ",
end="\n",
style: Union[str, Style] = None,
justify: JustifyMethod = None,
emoji: bool = None,
markup: bool = None,
highlight: bool = None,
) -> None:
"""Print to the terminal and preserve progress display. Parameters identical to :class:`~rich.console.Console.print`."""
console = self.console
with console:
if console.is_terminal:
console.print(self._live_render.position_cursor())
console.print(
*objects,
sep=sep,
end=end,
style=style,
justify=justify,
emoji=emoji,
markup=markup,
highlight=highlight,
)
if console.is_terminal:
console.print(self._live_render)
def log(
self,
*objects: Any,
sep=" ",
end="\n",
justify: JustifyMethod = None,
emoji: bool = None,
markup: bool = None,
highlight: bool = None,
log_locals: bool = False,
_stack_offset=1,
) -> None:
"""Log to the terminal and preserve progress display. Parameters identical to :class:`~rich.console.Console.log`."""
console = self.console
with console:
if console.is_terminal:
console.print(self._live_render.position_cursor())
console.log(
*objects,
sep=sep,
end=end,
justify=justify,
emoji=emoji,
markup=markup,
highlight=highlight,
log_locals=log_locals,
_stack_offset=_stack_offset + 1,
)
if console.is_terminal:
console.print(self._live_render)
def process_renderables(
self, renderables: List[ConsoleRenderable]
) -> List[ConsoleRenderable]:
"""Process renderables to restore cursor and display progress."""
if self.console.is_terminal:
renderables = [
self._live_render.position_cursor(),
*renderables,
self._live_render,
]
return renderables
if __name__ == "__main__": # pragma: no coverage
@ -799,7 +764,8 @@ yield True, previous_value''',
examples = cycle(progress_renderables)
with Progress(transient=True) as progress:
console = Console()
with Progress(console=console, transient=True) as progress:
task1 = progress.add_task(" [red]Downloading", total=1000)
task2 = progress.add_task(" [green]Processing", total=1000)