mirror of
https://github.com/Textualize/rich.git
synced 2025-12-23 07:08:35 +00:00
Add tqdm compatibility shim with Rich progress backend
Implements a runtime adapter that replaces tqdm imports with Rich's progress bars via install_tqdm(). The shim preserves common tqdm APIs (iteration, len(), postfix, write, wrapattr) while routing rendering through rich.progress.Progress. Features: - Opt-in monkeypatching of tqdm/trange and notebook variants - Shared global Progress instance for performance - Render throttling via mininterval/miniters/maxinterval - Simplified text backend for file outputs (partial mode) - Postfix value caching to avoid redundant formatting Includes runnable examples demonstrating basic replacement and aggressive reference overwriting, plus comprehensive unit tests covering adapter installation, throttling, and postfix handling. Updates README and contributor list.
This commit is contained in:
parent
f82a399d58
commit
adcb5c748c
9 changed files with 1097 additions and 0 deletions
|
|
@ -41,6 +41,7 @@ The following people have contributed to the development of Rich:
|
|||
- [Hedy Li](https://github.com/hedythedev)
|
||||
- [Henry Mai](https://github.com/tanducmai)
|
||||
- [Luka Mamukashvili](https://github.com/UltraStudioLTD)
|
||||
- [Shane K2 Macaulay](https://github.com/K2)
|
||||
- [Alexander Mancevice](https://github.com/amancevice)
|
||||
- [Will McGugan](https://github.com/willmcgugan)
|
||||
- [Paul McGuire](https://github.com/ptmcg)
|
||||
|
|
|
|||
20
README.md
20
README.md
|
|
@ -275,6 +275,26 @@ The columns may be configured to show any details you want. Built-in columns inc
|
|||
|
||||
To try this out yourself, see [examples/downloader.py](https://github.com/textualize/rich/blob/master/examples/downloader.py) which can download multiple URLs simultaneously while displaying progress.
|
||||
|
||||
#### tqdm compatibility shim
|
||||
|
||||
If you need to redirect `tqdm` calls in third-party code to Rich at runtime, call `install_tqdm()` before the code imports `tqdm`:
|
||||
|
||||
```python
|
||||
from rich.tqdm import install_tqdm
|
||||
|
||||
install_tqdm()
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
for _ in tqdm(range(10), desc="basic"):
|
||||
...
|
||||
```
|
||||
|
||||
Runnable examples:
|
||||
|
||||
- Basic replacement: `PYTHONPATH=. python examples/tqdm_adapter_basic.py`
|
||||
- Hidden reference overwrite (e.g. stashed `_hidden_tqdm`): `PYTHONPATH=. python examples/tqdm_adapter_hidden.py`
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
|
|
|||
22
examples/tqdm_adapter_basic.py
Normal file
22
examples/tqdm_adapter_basic.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
"""Replace tqdm with rich.progress at runtime (basic example).
|
||||
|
||||
Run:
|
||||
PYTHONPATH=. python examples/tqdm_adapter_basic.py
|
||||
"""
|
||||
|
||||
from time import sleep
|
||||
|
||||
from rich.tqdm import install_tqdm
|
||||
|
||||
|
||||
def main() -> None:
|
||||
install_tqdm()
|
||||
|
||||
from tqdm import tqdm # resolved to the rich-backed shim after install
|
||||
|
||||
for _ in tqdm(range(10), desc="basic"):
|
||||
sleep(0.05)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
23
examples/tqdm_adapter_hidden.py
Normal file
23
examples/tqdm_adapter_hidden.py
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
"""Overwrite hidden tqdm references (aggressive replacement example).
|
||||
|
||||
Simulates a third-party module (examples/tqdm_hidden_lib.py) that stashes
|
||||
``tqdm`` in an internal variable at import time. The adapter should still
|
||||
replace that hidden reference after install_tqdm().
|
||||
|
||||
Run:
|
||||
PYTHONPATH=. python examples/tqdm_adapter_hidden.py
|
||||
"""
|
||||
|
||||
from rich.tqdm import install_tqdm
|
||||
|
||||
from examples import tqdm_hidden_lib
|
||||
|
||||
|
||||
def main() -> None:
|
||||
# Import happens before install; hidden module already captured original tqdm
|
||||
install_tqdm() # aggressive replacement updates hidden references
|
||||
tqdm_hidden_lib.run_hidden_loop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
15
examples/tqdm_hidden_lib.py
Normal file
15
examples/tqdm_hidden_lib.py
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
"""Simulated third-party module that hides a tqdm reference internally.
|
||||
|
||||
Imported by tqdm_adapter_hidden.py to demonstrate aggressive replacement.
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
import tqdm as _tqdm
|
||||
|
||||
_hidden_tqdm = _tqdm.tqdm
|
||||
|
||||
|
||||
def run_hidden_loop() -> None:
|
||||
for _ in _hidden_tqdm(range(5), desc="hidden-lib"):
|
||||
time.sleep(0.05)
|
||||
741
rich/tqdm.py
Normal file
741
rich/tqdm.py
Normal file
|
|
@ -0,0 +1,741 @@
|
|||
"""Compatibility shim to replace ``tqdm`` with Rich's progress bars.
|
||||
|
||||
Rationale
|
||||
---------
|
||||
- Opt-in bridge: call :func:`install_tqdm` to monkeypatch ``tqdm`` symbols
|
||||
(``tqdm``, ``trange``, notebook variants) in imported modules **and** any
|
||||
already-imported aliases found in ``sys.modules``. The shim stays inert until
|
||||
you install it; uninstall with :func:`uninstall_tqdm`.
|
||||
- Keep call sites stable: the wrapper preserves the common tqdm surface area
|
||||
(iteration, len(), postfix, write, wrapattr) while routing rendering through
|
||||
:class:`rich.progress.Progress`.
|
||||
- Favor performance: reuses a shared global ``Progress`` when no custom console
|
||||
or file is supplied, throttles refreshes via ``mininterval``/``miniters`` and
|
||||
optional ``maxinterval``, and caches formatted postfix values to avoid string
|
||||
churn when metrics are unchanged.
|
||||
- Handle non-interactive outputs: a simplified text backend is used when
|
||||
``partial=True`` with a file-like target to approximate tqdm's plain output
|
||||
without Rich's live rendering.
|
||||
|
||||
This shim aims to be compatible with the most common tqdm usage patterns, but
|
||||
it does not implement every tqdm argument. Unrecognized kwargs are ignored
|
||||
rather than raising, matching tqdm's permissive behavior.
|
||||
"""
|
||||
|
||||
# pylint: disable=missing-function-docstring, import-outside-toplevel
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import sys
|
||||
import time
|
||||
import types
|
||||
import threading
|
||||
from types import TracebackType
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
ContextManager,
|
||||
Dict,
|
||||
IO,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
cast,
|
||||
Reversible,
|
||||
Literal,
|
||||
)
|
||||
|
||||
|
||||
_PATCH_STATE: Dict[str, Any] = {
|
||||
"installed": False,
|
||||
"patched_modules": [], # list of (module, attr, original)
|
||||
"replaced_locations": [], # list of (module, name, original)
|
||||
}
|
||||
|
||||
_GLOBAL_STATE: Dict[str, Optional[Any]] = {"progress": None}
|
||||
|
||||
|
||||
class _DummyLock:
|
||||
def __enter__(self) -> _DummyLock:
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc: Optional[BaseException],
|
||||
tb: Optional[TracebackType],
|
||||
) -> Literal[False]:
|
||||
return False
|
||||
|
||||
|
||||
def _format_interval(seconds: float) -> str:
|
||||
# Simplified copy of tqdm.format_interval
|
||||
minutes, sec = divmod(int(seconds), 60)
|
||||
hours, minutes = divmod(minutes, 60)
|
||||
if hours:
|
||||
return f"{hours:d}:{minutes:02d}:{sec:02d}"
|
||||
return f"{minutes:02d}:{sec:02d}"
|
||||
|
||||
|
||||
def _format_postfix_value(value: Any) -> Any:
|
||||
if isinstance(value, (int, float)):
|
||||
formatted = f"{value:.3g}"
|
||||
value_str = str(value)
|
||||
return formatted if len(formatted) < len(value_str) else value_str
|
||||
return value if isinstance(value, str) else str(value)
|
||||
|
||||
|
||||
def format_num(n: Any) -> str:
|
||||
formatted = str(_format_postfix_value(n))
|
||||
return formatted.replace("e+0", "e+").replace("e-0", "e-")
|
||||
|
||||
|
||||
def format_interval(t: float) -> str:
|
||||
return _format_interval(t)
|
||||
|
||||
|
||||
def format_meter(
|
||||
n: float, total: Optional[float], elapsed: float, prefix: str = "", **_: Any
|
||||
) -> str:
|
||||
total_str = "?" if total is None else format_num(total)
|
||||
percentage: float
|
||||
if total is None or total == 0:
|
||||
percentage = 0.0
|
||||
else:
|
||||
percentage = 100 * n / total
|
||||
elapsed_str = _format_interval(elapsed)
|
||||
rate = n / elapsed if elapsed > 0 else 0
|
||||
rate_str = "?" if rate == 0 else format_num(rate)
|
||||
return (
|
||||
f"{prefix}{percentage:3.0f}%| | {format_num(n)}/{total_str} "
|
||||
f"[{elapsed_str}<?, {rate_str}it/s]"
|
||||
)
|
||||
|
||||
|
||||
def _get_progress(console: Any = None, file: Any = None) -> Tuple[Any, bool]:
|
||||
"""Return a Progress instance and whether it is shared/global."""
|
||||
|
||||
if console is not None or file is not None:
|
||||
from .progress import Progress
|
||||
from .console import Console
|
||||
|
||||
prog_console = console or Console(file=file)
|
||||
progress = Progress(console=prog_console)
|
||||
progress.start()
|
||||
return progress, False
|
||||
|
||||
progress = cast("Progress", _GLOBAL_STATE["progress"])
|
||||
if progress is None:
|
||||
from . import get_console
|
||||
from .progress import Progress
|
||||
|
||||
progress = Progress(console=get_console())
|
||||
progress.start()
|
||||
_GLOBAL_STATE["progress"] = progress
|
||||
return progress, True
|
||||
|
||||
|
||||
def _stop_global_progress() -> None:
|
||||
progress = _GLOBAL_STATE.get("progress")
|
||||
if progress is not None:
|
||||
try:
|
||||
progress.stop()
|
||||
finally:
|
||||
_GLOBAL_STATE["progress"] = None
|
||||
|
||||
|
||||
def _attach_api_shims(obj: Any) -> None:
|
||||
"""Attach tqdm-compatible helper functions and class variables to a callable."""
|
||||
# pylint: disable=protected-access
|
||||
obj.format_num = format_num
|
||||
obj.format_interval = format_interval
|
||||
obj.format_meter = format_meter
|
||||
obj.write = _TqdmWrapper.write
|
||||
obj.get_lock = _TqdmWrapper.get_lock
|
||||
obj.set_lock = _TqdmWrapper.set_lock
|
||||
obj.external_write_mode = _TqdmWrapper.external_write_mode
|
||||
obj.wrapattr = _TqdmWrapper.wrapattr
|
||||
obj._instances = _TqdmWrapper._instances
|
||||
obj._lock = None
|
||||
|
||||
|
||||
class _TextBackend:
|
||||
"""Minimal text renderer for partial compatibility mode."""
|
||||
|
||||
def __init__(self, stream: IO[str], mininterval: float) -> None:
|
||||
self.stream = stream
|
||||
self.mininterval = mininterval
|
||||
self._last_render = 0.0
|
||||
|
||||
def render(
|
||||
self,
|
||||
*,
|
||||
n: int,
|
||||
total: Optional[int],
|
||||
desc: str,
|
||||
postfix: Dict[Any, Any],
|
||||
leave: bool,
|
||||
to_string: bool,
|
||||
) -> str:
|
||||
now = time.monotonic()
|
||||
if not to_string and now - self._last_render < self.mininterval:
|
||||
return ""
|
||||
self._last_render = now
|
||||
|
||||
percent = 0 if not total else (n / total * 100 if total else 0)
|
||||
bar_len = 10
|
||||
done = int(bar_len * (0 if not total else min(1.0, n / float(total))))
|
||||
abar = "#" * done + " " * (bar_len - done)
|
||||
total_str = "?" if total is None else str(total)
|
||||
line = f"\r{percent:3.0f}%|{abar}| {n}/{total_str} [00:00<]"
|
||||
if desc:
|
||||
line = f"{desc} {line}"
|
||||
if postfix:
|
||||
postfix_str = ", ".join(f"{k}={v}" for k, v in postfix.items())
|
||||
line = f"{line} {postfix_str}"
|
||||
if to_string:
|
||||
return line.strip()
|
||||
if not leave and total is not None and n >= total:
|
||||
return line
|
||||
try:
|
||||
self.stream.write(line)
|
||||
except (OSError, ValueError):
|
||||
return line
|
||||
return line
|
||||
|
||||
|
||||
class _TqdmWrapper:
|
||||
"""Lightweight stand-in for ``tqdm`` based on Rich progress with render throttling."""
|
||||
|
||||
_instances: List[Any] = []
|
||||
_lock_obj: Any = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
iterable: Optional[Iterable[Any]] = None,
|
||||
total: Optional[int] = None,
|
||||
desc: Optional[str] = None,
|
||||
leave: bool = True,
|
||||
disable: bool = False,
|
||||
initial: int = 0,
|
||||
console: Any = None,
|
||||
file: Any = None,
|
||||
partial: bool = False,
|
||||
mininterval: float = 0.1,
|
||||
maxinterval: Optional[float] = None,
|
||||
miniters: Optional[int] = None,
|
||||
ascii: bool = False, # pylint: disable=redefined-builtin
|
||||
smoothing: Optional[float] = None,
|
||||
unit: str = "it",
|
||||
unit_scale: bool = False,
|
||||
unit_divisor: int = 1000,
|
||||
bar_format: Optional[str] = None,
|
||||
**_: Any,
|
||||
) -> None:
|
||||
del smoothing, unit, unit_scale, unit_divisor, bar_format
|
||||
self._iterable = iterable
|
||||
self._total = total if total is not None else None
|
||||
self._desc = desc or ""
|
||||
self._leave = leave
|
||||
self.desc = self._desc
|
||||
self._closed = False
|
||||
self._disabled = disable
|
||||
self._progress: Optional[Any] = None
|
||||
self._task: Optional[Any] = None
|
||||
self._write_stream: Optional[IO[str]] = None
|
||||
self._partial = partial
|
||||
self._use_text_backend = bool(partial and file is not None)
|
||||
self._text_backend: Optional[_TextBackend] = None
|
||||
self.dynamic_miniters = True
|
||||
self.ascii = ascii
|
||||
self._lock: Optional[Any] = None
|
||||
self.postfix: Optional[Dict[Any, Any]] = {0: {}}
|
||||
self.n = initial
|
||||
self._is_global = False
|
||||
self._mininterval = mininterval
|
||||
self._maxinterval = maxinterval
|
||||
self._miniters = 1 if miniters is None else max(1, int(miniters))
|
||||
# Seed render markers so the first update is eligible to render promptly.
|
||||
self._last_render_time = 0.0
|
||||
self._last_render_n = self.n
|
||||
|
||||
if self._total is None and hasattr(iterable, "__len__"):
|
||||
try:
|
||||
self._total = len(iterable) # type: ignore[arg-type]
|
||||
except (TypeError, OverflowError):
|
||||
self._total = None
|
||||
|
||||
if self._use_text_backend:
|
||||
self._write_stream = file
|
||||
self._text_backend = _TextBackend(file, mininterval)
|
||||
if not self._disabled and not self._use_text_backend:
|
||||
self._progress, self._is_global = _get_progress(console=console, file=file)
|
||||
if self._progress is not None:
|
||||
self._task = self._progress.add_task(self._desc, total=self._total)
|
||||
if initial:
|
||||
self._progress.update(self._task, completed=initial)
|
||||
|
||||
def __iter__(self) -> Iterator[Any]:
|
||||
"""Iterate while tracking progress updates, honoring disabled/text modes."""
|
||||
if self._iterable is None:
|
||||
return iter(())
|
||||
iterator = iter(self._iterable)
|
||||
if self._disabled:
|
||||
return self._iter_disabled(iterator)
|
||||
if self._use_text_backend:
|
||||
self._maybe_render_text()
|
||||
return self._iter_enabled(iterator)
|
||||
|
||||
def _iter_disabled(self, iterator: Iterator[Any]) -> Iterator[Any]:
|
||||
"""Yield items while disabled, tracking count without rendering."""
|
||||
for item in iterator:
|
||||
self.n += 1
|
||||
yield item
|
||||
|
||||
def _iter_enabled(self, iterator: Iterator[Any]) -> Iterator[Any]:
|
||||
"""Yield items, updating progress each step and closing if not leaving."""
|
||||
for item in iterator:
|
||||
yield item
|
||||
self.update(1)
|
||||
if not self._leave:
|
||||
self.close()
|
||||
|
||||
def update(
|
||||
self,
|
||||
n: int = 1,
|
||||
*,
|
||||
postfix: Optional[Dict[str, Any]] = None,
|
||||
refresh: bool = False,
|
||||
) -> None:
|
||||
"""Advance the bar and conditionally refresh based on throttling thresholds."""
|
||||
if self._closed:
|
||||
return
|
||||
self.n += n
|
||||
postfix_changed = False
|
||||
if postfix:
|
||||
new_postfix = {k: _format_postfix_value(v) for k, v in postfix.items()}
|
||||
if new_postfix != (self.postfix or {}):
|
||||
self.postfix = new_postfix
|
||||
postfix_changed = True
|
||||
if self._disabled:
|
||||
return
|
||||
if self._use_text_backend:
|
||||
self._maybe_render_text()
|
||||
return
|
||||
progress = self._progress
|
||||
task = self._task
|
||||
if progress is None or task is None:
|
||||
return
|
||||
now = time.monotonic()
|
||||
should_refresh = self._should_render(
|
||||
now, refresh=refresh, postfix_changed=postfix_changed
|
||||
)
|
||||
progress.update(task, advance=n, postfix=self.postfix, refresh=should_refresh)
|
||||
if should_refresh:
|
||||
self._mark_render(now)
|
||||
|
||||
def set_description(self, desc: str = "") -> None:
|
||||
if self._closed:
|
||||
return
|
||||
self._desc = str(desc)
|
||||
if self._disabled:
|
||||
return
|
||||
self.desc = self._desc
|
||||
if self._use_text_backend:
|
||||
self._maybe_render_text()
|
||||
return
|
||||
progress = self._progress
|
||||
task = self._task
|
||||
if progress is None or task is None:
|
||||
return
|
||||
now = time.monotonic()
|
||||
progress.update(task, description=self._desc, refresh=True)
|
||||
self._mark_render(now)
|
||||
|
||||
def set_postfix(
|
||||
self,
|
||||
ordered_dict: Optional[Dict[str, Any]] = None,
|
||||
refresh: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Update postfix fields while reusing the last render unless content changed."""
|
||||
if self._closed:
|
||||
return
|
||||
combined: Dict[str, Any] = {}
|
||||
if ordered_dict:
|
||||
combined.update(ordered_dict)
|
||||
if kwargs:
|
||||
combined.update(kwargs)
|
||||
postfix_changed = False
|
||||
if combined:
|
||||
new_postfix = {k: _format_postfix_value(v) for k, v in combined.items()}
|
||||
if new_postfix != (self.postfix or {}):
|
||||
self.postfix = new_postfix
|
||||
postfix_changed = True
|
||||
if self._disabled:
|
||||
return
|
||||
if self._use_text_backend:
|
||||
self._maybe_render_text()
|
||||
return
|
||||
progress = self._progress
|
||||
task = self._task
|
||||
if progress is None or task is None:
|
||||
return
|
||||
now = time.monotonic()
|
||||
should_refresh = self._should_render(
|
||||
now, refresh=refresh, postfix_changed=postfix_changed
|
||||
)
|
||||
progress.update(task, postfix=self.postfix, refresh=should_refresh)
|
||||
if should_refresh:
|
||||
self._mark_render(now)
|
||||
|
||||
def set_postfix_str(self, s: str = "", refresh: bool = True) -> None:
|
||||
self.set_postfix({"postfix": s}, refresh=refresh)
|
||||
|
||||
def close(self) -> None:
|
||||
if self._closed:
|
||||
return
|
||||
if not self._disabled and self._progress and self._task is not None:
|
||||
self._progress.remove_task(self._task)
|
||||
self._closed = True
|
||||
try:
|
||||
self._instances.remove(self)
|
||||
except ValueError:
|
||||
pass
|
||||
if self._progress and not self._progress.tasks and self._is_global:
|
||||
_stop_global_progress()
|
||||
|
||||
def refresh(self) -> None:
|
||||
if self._disabled:
|
||||
return
|
||||
if self._closed:
|
||||
return
|
||||
if self._use_text_backend:
|
||||
self._maybe_render_text()
|
||||
return
|
||||
if self._progress is None:
|
||||
return
|
||||
self._progress.refresh()
|
||||
|
||||
def __len__(self) -> int:
|
||||
if self._total is not None:
|
||||
return int(self._total)
|
||||
if self._iterable is not None:
|
||||
try:
|
||||
return len(self._iterable) # type: ignore[arg-type]
|
||||
except TypeError:
|
||||
pass
|
||||
raise TypeError("object of type '_TqdmWrapper' has no len()")
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self._total or self._iterable)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self._maybe_render_text(to_string=True)
|
||||
|
||||
def __lt__(self, other: Any) -> bool:
|
||||
try:
|
||||
return (self._total or 0) < (getattr(other, "_total", 0) or 0)
|
||||
except (AttributeError, TypeError):
|
||||
return False
|
||||
|
||||
def __enter__(self) -> _TqdmWrapper:
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> None:
|
||||
self.close()
|
||||
|
||||
def __contains__(self, item: Any) -> bool:
|
||||
if self._iterable is None:
|
||||
return False
|
||||
try:
|
||||
return item in self._iterable
|
||||
except (TypeError, ValueError):
|
||||
return False
|
||||
|
||||
def __reversed__(self) -> Iterator[Any]:
|
||||
if self._iterable is None:
|
||||
raise TypeError("_TqdmWrapper object is not reversible")
|
||||
try:
|
||||
reversible = cast(Reversible[Any], self._iterable)
|
||||
return reversed(reversible)
|
||||
except (TypeError, AttributeError) as exc:
|
||||
raise TypeError("_TqdmWrapper object is not reversible") from exc
|
||||
|
||||
def _maybe_render_text(self, to_string: bool = False) -> str:
|
||||
if not self._text_backend:
|
||||
return ""
|
||||
postfix_items = {k: v for k, v in (self.postfix or {}).items() if k != 0 and v}
|
||||
return self._text_backend.render(
|
||||
n=self.n,
|
||||
total=self._total,
|
||||
desc=self._desc,
|
||||
postfix=postfix_items,
|
||||
leave=self._leave,
|
||||
to_string=to_string,
|
||||
)
|
||||
|
||||
def _should_render(
|
||||
self, now: float, *, refresh: bool, postfix_changed: bool = False
|
||||
) -> bool:
|
||||
"""Decide whether to render based on min interval/iters and optional max interval."""
|
||||
if refresh:
|
||||
return True
|
||||
delta_time = now - self._last_render_time
|
||||
delta_n = self.n - self._last_render_n
|
||||
if postfix_changed and delta_time >= self._mininterval:
|
||||
return True
|
||||
if delta_time >= self._mininterval and delta_n >= self._miniters:
|
||||
return True
|
||||
if self._maxinterval is not None and delta_time >= self._maxinterval:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _mark_render(self, now: float) -> None:
|
||||
"""Record the last render time and counter for subsequent throttling decisions."""
|
||||
self._last_render_time = now
|
||||
self._last_render_n = self.n
|
||||
|
||||
def set_description_str(self, desc: str = "") -> None:
|
||||
self.set_description(desc)
|
||||
|
||||
def reset(self, total: Optional[int] = None) -> None:
|
||||
if total is not None:
|
||||
self._total = total
|
||||
self.n = 0
|
||||
if self._use_text_backend:
|
||||
self._maybe_render_text()
|
||||
elif self._progress and self._task is not None:
|
||||
self._progress.reset(self._task, total=self._total)
|
||||
|
||||
def unpause(self) -> None:
|
||||
return
|
||||
|
||||
def clear(self, nolock: bool = False) -> None:
|
||||
del nolock
|
||||
if self._use_text_backend and self._write_stream is not None:
|
||||
try:
|
||||
self._write_stream.write("\n")
|
||||
except (OSError, ValueError):
|
||||
return
|
||||
return
|
||||
|
||||
@classmethod
|
||||
def write(
|
||||
cls,
|
||||
s: str,
|
||||
file: Optional[IO[str]] = None,
|
||||
end: str = "\n",
|
||||
nolock: bool = False,
|
||||
) -> None:
|
||||
del nolock
|
||||
target = file or sys.stderr
|
||||
try:
|
||||
target.write(str(s) + end)
|
||||
except (OSError, ValueError):
|
||||
return
|
||||
|
||||
@classmethod
|
||||
def get_lock(cls) -> Any:
|
||||
if cls._lock_obj is None:
|
||||
try:
|
||||
cls._lock_obj = threading.Lock()
|
||||
except (OSError, RuntimeError):
|
||||
cls._lock_obj = _DummyLock()
|
||||
return cls._lock_obj
|
||||
|
||||
@classmethod
|
||||
def set_lock(cls, lock: Any) -> None:
|
||||
cls._lock_obj = lock
|
||||
|
||||
@classmethod
|
||||
def external_write_mode(
|
||||
cls, file: Optional[IO[str]] = None, nolock: bool = False
|
||||
) -> ContextManager[IO[str]]:
|
||||
"""Context manager mirroring tqdm.external_write_mode for compatibility."""
|
||||
|
||||
del nolock
|
||||
|
||||
class _CM:
|
||||
def __enter__(self) -> IO[str]:
|
||||
return file or sys.stderr
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc: Optional[BaseException],
|
||||
tb: Optional[TracebackType],
|
||||
) -> Literal[False]:
|
||||
return False
|
||||
|
||||
return _CM()
|
||||
|
||||
@classmethod
|
||||
def wrapattr(
|
||||
cls,
|
||||
stream: Any,
|
||||
method: str,
|
||||
total: Optional[int] = None,
|
||||
bytes: bool = False, # pylint: disable=redefined-builtin
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Wrap a stream attribute call for compatibility with tqdm's wrapattr."""
|
||||
del total, bytes, kwargs
|
||||
func = getattr(stream, method)
|
||||
|
||||
class _Wrap:
|
||||
def __init__(self, fn: Callable[..., Any]):
|
||||
self._fn = fn
|
||||
|
||||
def __call__(self, *a: Any, **k: Any) -> Any:
|
||||
return self._fn(*a, **k)
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
return getattr(self._fn, name)
|
||||
|
||||
def __enter__(self) -> "_Wrap":
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc: Optional[BaseException],
|
||||
tb: Optional[TracebackType],
|
||||
) -> Literal[False]:
|
||||
return False
|
||||
|
||||
def write(self, data: Any) -> Any:
|
||||
return self._fn(data)
|
||||
|
||||
return _Wrap(func)
|
||||
|
||||
def __del__(self) -> None:
|
||||
if getattr(self, "_closed", True):
|
||||
return
|
||||
self.close()
|
||||
|
||||
|
||||
def _make_tqdm_callable(
|
||||
console: Any = None, *, partial: bool = False
|
||||
) -> Callable[[Optional[Iterable[Any]]], _TqdmWrapper]:
|
||||
def _tqdm(iterable: Optional[Iterable[Any]] = None, **kwargs: Any) -> _TqdmWrapper:
|
||||
return _TqdmWrapper(
|
||||
iterable=iterable, console=console, partial=partial, **kwargs
|
||||
)
|
||||
|
||||
_attach_api_shims(_tqdm)
|
||||
|
||||
return _tqdm
|
||||
|
||||
|
||||
def _make_trange(
|
||||
console: Any = None, *, partial: bool = False
|
||||
) -> Callable[..., _TqdmWrapper]:
|
||||
def _trange(*args: Any, **kwargs: Any) -> _TqdmWrapper:
|
||||
return _make_tqdm_callable(console, partial=partial)(range(*args), **kwargs)
|
||||
|
||||
_attach_api_shims(_trange)
|
||||
|
||||
return _trange
|
||||
|
||||
|
||||
def install_tqdm(console: Any = None, *, partial: bool = False) -> None:
|
||||
"""Monkeypatch ``tqdm`` symbols to use Rich progress.
|
||||
|
||||
Args:
|
||||
console: Optional Console to route progress output through.
|
||||
partial: When True, enable simplified text rendering for file-like outputs to
|
||||
approximate tqdm aesthetics. When False (default), favor the Rich progress
|
||||
presentation and avoid extra tqdm-style formatting.
|
||||
"""
|
||||
|
||||
if _PATCH_STATE["installed"]:
|
||||
return
|
||||
|
||||
module_names = ["tqdm", "tqdm.std", "tqdm.auto", "tqdm.notebook"]
|
||||
target_attrs = ["tqdm", "trange", "tqdm_notebook", "tnrange"]
|
||||
|
||||
replacements: Dict[Any, Any] = {}
|
||||
to_patch: List[Tuple[types.ModuleType, str, Any]] = []
|
||||
|
||||
for modname in module_names:
|
||||
try:
|
||||
mod = importlib.import_module(modname)
|
||||
except ImportError:
|
||||
continue
|
||||
for attr in target_attrs:
|
||||
if not hasattr(mod, attr):
|
||||
continue
|
||||
orig = getattr(mod, attr)
|
||||
if orig in replacements:
|
||||
continue
|
||||
newobj = (
|
||||
_make_tqdm_callable(console, partial=partial)
|
||||
if attr in ("tqdm", "tqdm_notebook")
|
||||
else _make_trange(console, partial=partial)
|
||||
)
|
||||
replacements[orig] = newobj
|
||||
to_patch.append((mod, attr, orig))
|
||||
|
||||
replaced_locations: List[Tuple[types.ModuleType, str, Any]] = []
|
||||
for orig, newval in list(replacements.items()):
|
||||
for module in list(sys.modules.values()):
|
||||
moddict = getattr(module, "__dict__", None)
|
||||
if not isinstance(moddict, dict):
|
||||
continue
|
||||
for name, val in list(moddict.items()):
|
||||
if val is orig:
|
||||
replaced_locations.append((module, name, orig))
|
||||
try:
|
||||
setattr(module, name, newval)
|
||||
except (AttributeError, TypeError):
|
||||
continue
|
||||
|
||||
patched_modules: List[Tuple[types.ModuleType, str, Any]] = []
|
||||
for mod, attr, orig in to_patch:
|
||||
replacement: Any = replacements.get(orig)
|
||||
if replacement is None:
|
||||
continue
|
||||
new_callable = cast(Callable[..., Any], replacement)
|
||||
try:
|
||||
setattr(mod, attr, new_callable)
|
||||
patched_modules.append((mod, attr, orig))
|
||||
except (AttributeError, TypeError):
|
||||
continue
|
||||
|
||||
_PATCH_STATE["installed"] = True
|
||||
_PATCH_STATE["patched_modules"] = patched_modules
|
||||
_PATCH_STATE["replaced_locations"] = replaced_locations
|
||||
|
||||
|
||||
def uninstall_tqdm() -> None:
|
||||
"""Revert monkeypatching performed by :func:`install_tqdm`."""
|
||||
|
||||
if not _PATCH_STATE["installed"]:
|
||||
return
|
||||
|
||||
for module, name, orig in _PATCH_STATE.get("replaced_locations", []):
|
||||
try:
|
||||
setattr(module, name, orig)
|
||||
except (AttributeError, TypeError):
|
||||
continue
|
||||
|
||||
for module, attr, orig in _PATCH_STATE.get("patched_modules", []):
|
||||
try:
|
||||
setattr(module, attr, orig)
|
||||
except (AttributeError, TypeError):
|
||||
continue
|
||||
|
||||
_PATCH_STATE["installed"] = False
|
||||
_PATCH_STATE["patched_modules"] = []
|
||||
_PATCH_STATE["replaced_locations"] = []
|
||||
_stop_global_progress()
|
||||
82
tests/test_tqdm_adapter.py
Normal file
82
tests/test_tqdm_adapter.py
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
"""Adapter install/uninstall coverage for the Rich tqdm shim."""
|
||||
|
||||
import sys
|
||||
import types
|
||||
from typing import cast
|
||||
|
||||
from rich.tqdm import _TqdmWrapper, install_tqdm, uninstall_tqdm
|
||||
|
||||
|
||||
def test_install_replaces_imported_reference() -> None:
|
||||
"""Replacing imported tqdm symbols should affect existing aliases."""
|
||||
saved = sys.modules.pop("tqdm", None)
|
||||
try:
|
||||
fake = types.ModuleType("tqdm")
|
||||
|
||||
def orig_tqdm(iterable=None, **kwargs):
|
||||
del kwargs
|
||||
if iterable is None:
|
||||
return None
|
||||
return ("orig", list(iterable))
|
||||
|
||||
def orig_trange(*args, **kwargs):
|
||||
del kwargs
|
||||
return range(*args)
|
||||
|
||||
fake.tqdm = orig_tqdm
|
||||
fake.trange = orig_trange
|
||||
sys.modules["tqdm"] = fake
|
||||
|
||||
third = types.ModuleType("thirdparty")
|
||||
third.tqdm_alias = fake.tqdm
|
||||
sys.modules["thirdparty"] = third
|
||||
|
||||
try:
|
||||
assert third.tqdm_alias is orig_tqdm
|
||||
install_tqdm()
|
||||
assert third.tqdm_alias is not orig_tqdm
|
||||
it = third.tqdm_alias(range(3))
|
||||
assert list(it) == [0, 1, 2]
|
||||
finally:
|
||||
uninstall_tqdm()
|
||||
assert third.tqdm_alias is orig_tqdm
|
||||
finally:
|
||||
sys.modules.pop("thirdparty", None)
|
||||
if saved is not None:
|
||||
sys.modules["tqdm"] = saved
|
||||
else:
|
||||
sys.modules.pop("tqdm", None)
|
||||
|
||||
|
||||
def test_install_replaces_hidden_reference_and_postfix_support() -> None:
|
||||
"""Hidden tqdm references should update and still format postfix values."""
|
||||
saved = sys.modules.pop("tqdm", None)
|
||||
try:
|
||||
fake = types.ModuleType("tqdm")
|
||||
|
||||
def orig_tqdm(iterable=None, **kwargs):
|
||||
del kwargs
|
||||
if iterable is None:
|
||||
return None
|
||||
return ("orig", list(iterable))
|
||||
|
||||
fake.tqdm = orig_tqdm
|
||||
fake._hidden_tqdm = orig_tqdm # pylint: disable=protected-access
|
||||
sys.modules["tqdm"] = fake
|
||||
|
||||
try:
|
||||
install_tqdm()
|
||||
patched = fake.tqdm
|
||||
assert patched is not orig_tqdm
|
||||
assert fake._hidden_tqdm is patched # pylint: disable=protected-access
|
||||
|
||||
progress = cast(_TqdmWrapper, patched(range(2)))
|
||||
progress.update(postfix={"loss": 1.23}) # pylint: disable=no-member
|
||||
assert progress.postfix == {"loss": "1.23"} # pylint: disable=no-member
|
||||
finally:
|
||||
uninstall_tqdm()
|
||||
finally:
|
||||
if saved is not None:
|
||||
sys.modules["tqdm"] = saved
|
||||
else:
|
||||
sys.modules.pop("tqdm", None)
|
||||
60
tests/test_tqdm_parity.py
Normal file
60
tests/test_tqdm_parity.py
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
"""Parity checks between Rich's tqdm shim and vendored tqdm reference.
|
||||
|
||||
Note: this test relies on a vendored copy of tqdm under ``tests/tqdm`` to
|
||||
provide a reference implementation. If that folder is absent, the test is
|
||||
skipped. Keep the vendored copy lightweight and in sync with the targeted tqdm
|
||||
version when you want to exercise parity; otherwise the skip is expected in
|
||||
normal installs.
|
||||
"""
|
||||
|
||||
# pylint: disable=import-error
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from rich.tqdm import install_tqdm, uninstall_tqdm
|
||||
|
||||
|
||||
SUBMODULE_ROOT = Path(__file__).parent / "tqdm"
|
||||
|
||||
|
||||
def test_tqdm_adapter_len_and_disable_parity() -> None:
|
||||
"""Use the vendored tqdm submodule as a baseline for adapter parity."""
|
||||
|
||||
if not SUBMODULE_ROOT.exists():
|
||||
pytest.skip("vendored tqdm submodule missing")
|
||||
|
||||
added_path = False
|
||||
submodule_path = str(SUBMODULE_ROOT)
|
||||
if submodule_path not in sys.path:
|
||||
sys.path.insert(0, submodule_path)
|
||||
added_path = True
|
||||
|
||||
orig_mod = None
|
||||
try:
|
||||
orig_mod = importlib.import_module("tqdm")
|
||||
orig_tqdm = orig_mod.tqdm
|
||||
|
||||
baseline_len = len(orig_tqdm(range(5)))
|
||||
|
||||
install_tqdm()
|
||||
patched_tqdm = orig_mod.tqdm
|
||||
|
||||
assert patched_tqdm is not orig_tqdm
|
||||
assert len(patched_tqdm(range(5))) == baseline_len
|
||||
assert list(patched_tqdm(range(3), disable=True)) == [0, 1, 2]
|
||||
finally:
|
||||
uninstall_tqdm()
|
||||
if orig_mod is not None:
|
||||
importlib.reload(orig_mod)
|
||||
sys.modules.pop("tqdm", None)
|
||||
if added_path:
|
||||
try:
|
||||
sys.path.remove(submodule_path)
|
||||
except ValueError:
|
||||
pass
|
||||
133
tests/test_tqdm_related.py
Normal file
133
tests/test_tqdm_related.py
Normal file
|
|
@ -0,0 +1,133 @@
|
|||
"""Unit-level helpers that mirror tqdm formatting and postfix handling."""
|
||||
|
||||
import io
|
||||
|
||||
from rich.tqdm import _TqdmWrapper, _TextBackend
|
||||
|
||||
# pylint: disable=protected-access
|
||||
|
||||
|
||||
def format_num(n):
|
||||
"""Intelligent scientific notation (.3g), mirroring tqdm."""
|
||||
|
||||
f = "{0:.3g}".format(n).replace("+0", "+").replace("-0", "-")
|
||||
n_str = str(n)
|
||||
return f if len(f) < len(n_str) else n_str
|
||||
|
||||
|
||||
def test_format_num_matches_shorter_representation():
|
||||
"""Ensure numeric formatting mirrors tqdm's shorter-of fixed/scientific rule."""
|
||||
assert format_num(1337) == "1337"
|
||||
# prefers the shorter of fixed or scientific
|
||||
assert format_num(1239876) == "1239876"
|
||||
assert format_num(0.00001234) == "1.23e-5"
|
||||
assert format_num(-0.1234) == "-.123"
|
||||
|
||||
|
||||
def test_wrapper_accepts_postfix_and_len():
|
||||
"""Wrapper should expose len() and normalize postfix values."""
|
||||
progress = _TqdmWrapper(range(3), total=None)
|
||||
assert len(progress) == 3
|
||||
progress.update(postfix={"loss": 1.2345})
|
||||
assert progress.postfix == {"loss": "1.23"}
|
||||
|
||||
|
||||
def test_set_postfix_allows_multiple_fields():
|
||||
"""Multiple postfix fields remain formatted to short strings."""
|
||||
progress = _TqdmWrapper(range(2), total=2)
|
||||
progress.set_postfix({"loss": 0.9876}, acc=0.42)
|
||||
assert progress.postfix == {"loss": "0.988", "acc": "0.42"}
|
||||
|
||||
|
||||
def test_set_postfix_str_passthrough():
|
||||
"""String postfix should pass through unchanged."""
|
||||
progress = _TqdmWrapper(range(1), total=1)
|
||||
progress.set_postfix_str("done")
|
||||
assert progress.postfix == {"postfix": "done"}
|
||||
|
||||
|
||||
def test_wrapper_disable_skips_progress_calls():
|
||||
"""Disabled wrapper still increments counters without rendering."""
|
||||
progress = _TqdmWrapper(range(2), disable=True)
|
||||
assert list(progress) == [0, 1]
|
||||
# update should not raise when disabled and should still track n
|
||||
progress.update(2, postfix={"acc": 0.9})
|
||||
assert progress.n == 4
|
||||
|
||||
|
||||
def test_postfix_idempotent_when_unchanged():
|
||||
"""Calling set_postfix with identical content should keep the same mapping."""
|
||||
progress = _TqdmWrapper(range(1), total=1)
|
||||
progress.set_postfix({"loss": 1.0})
|
||||
first = progress.postfix
|
||||
progress.set_postfix({"loss": 1.0})
|
||||
assert progress.postfix is first
|
||||
|
||||
|
||||
def test_text_backend_throttles(monkeypatch):
|
||||
"""Text backend should skip renders inside the mininterval window."""
|
||||
backend = _TextBackend(io.StringIO(), mininterval=0.1)
|
||||
backend._last_render = -1.0 # noqa: SLF001 ensure first render passes throttle
|
||||
ticks = iter([0.0, 0.05, 0.2])
|
||||
monkeypatch.setattr("rich.tqdm.time.monotonic", lambda: next(ticks))
|
||||
|
||||
first = backend.render(n=1, total=2, desc="d", postfix={}, leave=True, to_string=False)
|
||||
second = backend.render(n=2, total=2, desc="d", postfix={}, leave=True, to_string=False)
|
||||
third = backend.render(n=2, total=2, desc="d", postfix={}, leave=True, to_string=False)
|
||||
|
||||
assert first
|
||||
assert second == "" # throttled
|
||||
assert third # rendered after interval
|
||||
|
||||
|
||||
def test_should_render_thresholds():
|
||||
"""Render should wait for both time and iter thresholds unless postfix changed."""
|
||||
progress = _TqdmWrapper(range(1), mininterval=1.0, miniters=2)
|
||||
progress._last_render_time = 0.0 # noqa: SLF001
|
||||
progress._last_render_n = 0 # noqa: SLF001
|
||||
|
||||
progress.n = 1
|
||||
assert progress._should_render(0.5, refresh=False, postfix_changed=False) is False # noqa: SLF001
|
||||
|
||||
progress.n = 2
|
||||
assert progress._should_render(0.5, refresh=False, postfix_changed=False) is False # noqa: SLF001
|
||||
assert progress._should_render(1.0, refresh=False, postfix_changed=False) is True # noqa: SLF001
|
||||
|
||||
progress._last_render_time = 0.0 # noqa: SLF001
|
||||
progress.n = 0
|
||||
assert progress._should_render(1.0, refresh=False, postfix_changed=True) is True # noqa: SLF001
|
||||
|
||||
|
||||
def test_wrapattr_passthrough():
|
||||
"""wrapattr should forward calls to the underlying method."""
|
||||
buf = io.StringIO()
|
||||
wrapped = _TqdmWrapper.wrapattr(buf, "write")
|
||||
wrapped("ok")
|
||||
assert buf.getvalue() == "ok"
|
||||
|
||||
|
||||
def test_external_write_mode_uses_provided_file():
|
||||
"""external_write_mode should enter with provided file object."""
|
||||
buf = io.StringIO()
|
||||
cm = _TqdmWrapper.external_write_mode(file=buf)
|
||||
with cm as target:
|
||||
target.write("line")
|
||||
assert buf.getvalue() == "line"
|
||||
|
||||
|
||||
def test_first_update_marks_render(monkeypatch):
|
||||
"""Initial update should render and record counters instead of throttling away."""
|
||||
monkeypatch.setattr("rich.tqdm.time.monotonic", lambda: 100.0)
|
||||
progress = _TqdmWrapper(range(1), mininterval=0.1)
|
||||
progress.update()
|
||||
|
||||
assert progress._last_render_n == progress.n # noqa: SLF001
|
||||
|
||||
|
||||
def test_update_after_close_is_noop():
|
||||
"""Updating after close should be ignored and not bump counters."""
|
||||
progress = _TqdmWrapper(range(1))
|
||||
progress.close()
|
||||
progress.update()
|
||||
|
||||
assert progress.n == 0
|
||||
Loading…
Add table
Add a link
Reference in a new issue