diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 4b04786b..ebcd23de 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -41,6 +41,7 @@ The following people have contributed to the development of Rich: - [Hedy Li](https://github.com/hedythedev) - [Henry Mai](https://github.com/tanducmai) - [Luka Mamukashvili](https://github.com/UltraStudioLTD) +- [Shane K2 Macaulay](https://github.com/K2) - [Alexander Mancevice](https://github.com/amancevice) - [Will McGugan](https://github.com/willmcgugan) - [Paul McGuire](https://github.com/ptmcg) diff --git a/README.md b/README.md index e46a1e42..8798ee13 100644 --- a/README.md +++ b/README.md @@ -275,6 +275,26 @@ The columns may be configured to show any details you want. Built-in columns inc To try this out yourself, see [examples/downloader.py](https://github.com/textualize/rich/blob/master/examples/downloader.py) which can download multiple URLs simultaneously while displaying progress. +#### tqdm compatibility shim + +If you need to redirect `tqdm` calls in third-party code to Rich at runtime, call `install_tqdm()` before the code imports `tqdm`: + +```python +from rich.tqdm import install_tqdm + +install_tqdm() + +from tqdm import tqdm + +for _ in tqdm(range(10), desc="basic"): + ... +``` + +Runnable examples: + +- Basic replacement: `PYTHONPATH=. python examples/tqdm_adapter_basic.py` +- Hidden reference overwrite (e.g. stashed `_hidden_tqdm`): `PYTHONPATH=. python examples/tqdm_adapter_hidden.py` +
diff --git a/examples/tqdm_adapter_basic.py b/examples/tqdm_adapter_basic.py new file mode 100644 index 00000000..d83422ff --- /dev/null +++ b/examples/tqdm_adapter_basic.py @@ -0,0 +1,22 @@ +"""Replace tqdm with rich.progress at runtime (basic example). + +Run: + PYTHONPATH=. python examples/tqdm_adapter_basic.py +""" + +from time import sleep + +from rich.tqdm import install_tqdm + + +def main() -> None: + install_tqdm() + + from tqdm import tqdm # resolved to the rich-backed shim after install + + for _ in tqdm(range(10), desc="basic"): + sleep(0.05) + + +if __name__ == "__main__": + main() diff --git a/examples/tqdm_adapter_hidden.py b/examples/tqdm_adapter_hidden.py new file mode 100644 index 00000000..69c3c412 --- /dev/null +++ b/examples/tqdm_adapter_hidden.py @@ -0,0 +1,23 @@ +"""Overwrite hidden tqdm references (aggressive replacement example). + +Simulates a third-party module (examples/tqdm_hidden_lib.py) that stashes +``tqdm`` in an internal variable at import time. The adapter should still +replace that hidden reference after install_tqdm(). + +Run: + PYTHONPATH=. python examples/tqdm_adapter_hidden.py +""" + +from rich.tqdm import install_tqdm + +from examples import tqdm_hidden_lib + + +def main() -> None: + # Import happens before install; hidden module already captured original tqdm + install_tqdm() # aggressive replacement updates hidden references + tqdm_hidden_lib.run_hidden_loop() + + +if __name__ == "__main__": + main() diff --git a/examples/tqdm_hidden_lib.py b/examples/tqdm_hidden_lib.py new file mode 100644 index 00000000..a0410917 --- /dev/null +++ b/examples/tqdm_hidden_lib.py @@ -0,0 +1,15 @@ +"""Simulated third-party module that hides a tqdm reference internally. + +Imported by tqdm_adapter_hidden.py to demonstrate aggressive replacement. +""" + +import time + +import tqdm as _tqdm + +_hidden_tqdm = _tqdm.tqdm + + +def run_hidden_loop() -> None: + for _ in _hidden_tqdm(range(5), desc="hidden-lib"): + time.sleep(0.05) diff --git a/rich/tqdm.py b/rich/tqdm.py new file mode 100644 index 00000000..9196f763 --- /dev/null +++ b/rich/tqdm.py @@ -0,0 +1,741 @@ +"""Compatibility shim to replace ``tqdm`` with Rich's progress bars. + +Rationale +--------- +- Opt-in bridge: call :func:`install_tqdm` to monkeypatch ``tqdm`` symbols + (``tqdm``, ``trange``, notebook variants) in imported modules **and** any + already-imported aliases found in ``sys.modules``. The shim stays inert until + you install it; uninstall with :func:`uninstall_tqdm`. +- Keep call sites stable: the wrapper preserves the common tqdm surface area + (iteration, len(), postfix, write, wrapattr) while routing rendering through + :class:`rich.progress.Progress`. +- Favor performance: reuses a shared global ``Progress`` when no custom console + or file is supplied, throttles refreshes via ``mininterval``/``miniters`` and + optional ``maxinterval``, and caches formatted postfix values to avoid string + churn when metrics are unchanged. +- Handle non-interactive outputs: a simplified text backend is used when + ``partial=True`` with a file-like target to approximate tqdm's plain output + without Rich's live rendering. + +This shim aims to be compatible with the most common tqdm usage patterns, but +it does not implement every tqdm argument. Unrecognized kwargs are ignored +rather than raising, matching tqdm's permissive behavior. +""" + +# pylint: disable=missing-function-docstring, import-outside-toplevel + +from __future__ import annotations + +import importlib +import sys +import time +import types +import threading +from types import TracebackType +from typing import ( + Any, + Callable, + ContextManager, + Dict, + IO, + Iterable, + Iterator, + List, + Optional, + Tuple, + Type, + cast, + Reversible, + Literal, +) + + +_PATCH_STATE: Dict[str, Any] = { + "installed": False, + "patched_modules": [], # list of (module, attr, original) + "replaced_locations": [], # list of (module, name, original) +} + +_GLOBAL_STATE: Dict[str, Optional[Any]] = {"progress": None} + + +class _DummyLock: + def __enter__(self) -> _DummyLock: + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc: Optional[BaseException], + tb: Optional[TracebackType], + ) -> Literal[False]: + return False + + +def _format_interval(seconds: float) -> str: + # Simplified copy of tqdm.format_interval + minutes, sec = divmod(int(seconds), 60) + hours, minutes = divmod(minutes, 60) + if hours: + return f"{hours:d}:{minutes:02d}:{sec:02d}" + return f"{minutes:02d}:{sec:02d}" + + +def _format_postfix_value(value: Any) -> Any: + if isinstance(value, (int, float)): + formatted = f"{value:.3g}" + value_str = str(value) + return formatted if len(formatted) < len(value_str) else value_str + return value if isinstance(value, str) else str(value) + + +def format_num(n: Any) -> str: + formatted = str(_format_postfix_value(n)) + return formatted.replace("e+0", "e+").replace("e-0", "e-") + + +def format_interval(t: float) -> str: + return _format_interval(t) + + +def format_meter( + n: float, total: Optional[float], elapsed: float, prefix: str = "", **_: Any +) -> str: + total_str = "?" if total is None else format_num(total) + percentage: float + if total is None or total == 0: + percentage = 0.0 + else: + percentage = 100 * n / total + elapsed_str = _format_interval(elapsed) + rate = n / elapsed if elapsed > 0 else 0 + rate_str = "?" if rate == 0 else format_num(rate) + return ( + f"{prefix}{percentage:3.0f}%| | {format_num(n)}/{total_str} " + f"[{elapsed_str} Tuple[Any, bool]: + """Return a Progress instance and whether it is shared/global.""" + + if console is not None or file is not None: + from .progress import Progress + from .console import Console + + prog_console = console or Console(file=file) + progress = Progress(console=prog_console) + progress.start() + return progress, False + + progress = cast("Progress", _GLOBAL_STATE["progress"]) + if progress is None: + from . import get_console + from .progress import Progress + + progress = Progress(console=get_console()) + progress.start() + _GLOBAL_STATE["progress"] = progress + return progress, True + + +def _stop_global_progress() -> None: + progress = _GLOBAL_STATE.get("progress") + if progress is not None: + try: + progress.stop() + finally: + _GLOBAL_STATE["progress"] = None + + +def _attach_api_shims(obj: Any) -> None: + """Attach tqdm-compatible helper functions and class variables to a callable.""" + # pylint: disable=protected-access + obj.format_num = format_num + obj.format_interval = format_interval + obj.format_meter = format_meter + obj.write = _TqdmWrapper.write + obj.get_lock = _TqdmWrapper.get_lock + obj.set_lock = _TqdmWrapper.set_lock + obj.external_write_mode = _TqdmWrapper.external_write_mode + obj.wrapattr = _TqdmWrapper.wrapattr + obj._instances = _TqdmWrapper._instances + obj._lock = None + + +class _TextBackend: + """Minimal text renderer for partial compatibility mode.""" + + def __init__(self, stream: IO[str], mininterval: float) -> None: + self.stream = stream + self.mininterval = mininterval + self._last_render = 0.0 + + def render( + self, + *, + n: int, + total: Optional[int], + desc: str, + postfix: Dict[Any, Any], + leave: bool, + to_string: bool, + ) -> str: + now = time.monotonic() + if not to_string and now - self._last_render < self.mininterval: + return "" + self._last_render = now + + percent = 0 if not total else (n / total * 100 if total else 0) + bar_len = 10 + done = int(bar_len * (0 if not total else min(1.0, n / float(total)))) + abar = "#" * done + " " * (bar_len - done) + total_str = "?" if total is None else str(total) + line = f"\r{percent:3.0f}%|{abar}| {n}/{total_str} [00:00<]" + if desc: + line = f"{desc} {line}" + if postfix: + postfix_str = ", ".join(f"{k}={v}" for k, v in postfix.items()) + line = f"{line} {postfix_str}" + if to_string: + return line.strip() + if not leave and total is not None and n >= total: + return line + try: + self.stream.write(line) + except (OSError, ValueError): + return line + return line + + +class _TqdmWrapper: + """Lightweight stand-in for ``tqdm`` based on Rich progress with render throttling.""" + + _instances: List[Any] = [] + _lock_obj: Any = None + + def __init__( + self, + iterable: Optional[Iterable[Any]] = None, + total: Optional[int] = None, + desc: Optional[str] = None, + leave: bool = True, + disable: bool = False, + initial: int = 0, + console: Any = None, + file: Any = None, + partial: bool = False, + mininterval: float = 0.1, + maxinterval: Optional[float] = None, + miniters: Optional[int] = None, + ascii: bool = False, # pylint: disable=redefined-builtin + smoothing: Optional[float] = None, + unit: str = "it", + unit_scale: bool = False, + unit_divisor: int = 1000, + bar_format: Optional[str] = None, + **_: Any, + ) -> None: + del smoothing, unit, unit_scale, unit_divisor, bar_format + self._iterable = iterable + self._total = total if total is not None else None + self._desc = desc or "" + self._leave = leave + self.desc = self._desc + self._closed = False + self._disabled = disable + self._progress: Optional[Any] = None + self._task: Optional[Any] = None + self._write_stream: Optional[IO[str]] = None + self._partial = partial + self._use_text_backend = bool(partial and file is not None) + self._text_backend: Optional[_TextBackend] = None + self.dynamic_miniters = True + self.ascii = ascii + self._lock: Optional[Any] = None + self.postfix: Optional[Dict[Any, Any]] = {0: {}} + self.n = initial + self._is_global = False + self._mininterval = mininterval + self._maxinterval = maxinterval + self._miniters = 1 if miniters is None else max(1, int(miniters)) + # Seed render markers so the first update is eligible to render promptly. + self._last_render_time = 0.0 + self._last_render_n = self.n + + if self._total is None and hasattr(iterable, "__len__"): + try: + self._total = len(iterable) # type: ignore[arg-type] + except (TypeError, OverflowError): + self._total = None + + if self._use_text_backend: + self._write_stream = file + self._text_backend = _TextBackend(file, mininterval) + if not self._disabled and not self._use_text_backend: + self._progress, self._is_global = _get_progress(console=console, file=file) + if self._progress is not None: + self._task = self._progress.add_task(self._desc, total=self._total) + if initial: + self._progress.update(self._task, completed=initial) + + def __iter__(self) -> Iterator[Any]: + """Iterate while tracking progress updates, honoring disabled/text modes.""" + if self._iterable is None: + return iter(()) + iterator = iter(self._iterable) + if self._disabled: + return self._iter_disabled(iterator) + if self._use_text_backend: + self._maybe_render_text() + return self._iter_enabled(iterator) + + def _iter_disabled(self, iterator: Iterator[Any]) -> Iterator[Any]: + """Yield items while disabled, tracking count without rendering.""" + for item in iterator: + self.n += 1 + yield item + + def _iter_enabled(self, iterator: Iterator[Any]) -> Iterator[Any]: + """Yield items, updating progress each step and closing if not leaving.""" + for item in iterator: + yield item + self.update(1) + if not self._leave: + self.close() + + def update( + self, + n: int = 1, + *, + postfix: Optional[Dict[str, Any]] = None, + refresh: bool = False, + ) -> None: + """Advance the bar and conditionally refresh based on throttling thresholds.""" + if self._closed: + return + self.n += n + postfix_changed = False + if postfix: + new_postfix = {k: _format_postfix_value(v) for k, v in postfix.items()} + if new_postfix != (self.postfix or {}): + self.postfix = new_postfix + postfix_changed = True + if self._disabled: + return + if self._use_text_backend: + self._maybe_render_text() + return + progress = self._progress + task = self._task + if progress is None or task is None: + return + now = time.monotonic() + should_refresh = self._should_render( + now, refresh=refresh, postfix_changed=postfix_changed + ) + progress.update(task, advance=n, postfix=self.postfix, refresh=should_refresh) + if should_refresh: + self._mark_render(now) + + def set_description(self, desc: str = "") -> None: + if self._closed: + return + self._desc = str(desc) + if self._disabled: + return + self.desc = self._desc + if self._use_text_backend: + self._maybe_render_text() + return + progress = self._progress + task = self._task + if progress is None or task is None: + return + now = time.monotonic() + progress.update(task, description=self._desc, refresh=True) + self._mark_render(now) + + def set_postfix( + self, + ordered_dict: Optional[Dict[str, Any]] = None, + refresh: bool = True, + **kwargs: Any, + ) -> None: + """Update postfix fields while reusing the last render unless content changed.""" + if self._closed: + return + combined: Dict[str, Any] = {} + if ordered_dict: + combined.update(ordered_dict) + if kwargs: + combined.update(kwargs) + postfix_changed = False + if combined: + new_postfix = {k: _format_postfix_value(v) for k, v in combined.items()} + if new_postfix != (self.postfix or {}): + self.postfix = new_postfix + postfix_changed = True + if self._disabled: + return + if self._use_text_backend: + self._maybe_render_text() + return + progress = self._progress + task = self._task + if progress is None or task is None: + return + now = time.monotonic() + should_refresh = self._should_render( + now, refresh=refresh, postfix_changed=postfix_changed + ) + progress.update(task, postfix=self.postfix, refresh=should_refresh) + if should_refresh: + self._mark_render(now) + + def set_postfix_str(self, s: str = "", refresh: bool = True) -> None: + self.set_postfix({"postfix": s}, refresh=refresh) + + def close(self) -> None: + if self._closed: + return + if not self._disabled and self._progress and self._task is not None: + self._progress.remove_task(self._task) + self._closed = True + try: + self._instances.remove(self) + except ValueError: + pass + if self._progress and not self._progress.tasks and self._is_global: + _stop_global_progress() + + def refresh(self) -> None: + if self._disabled: + return + if self._closed: + return + if self._use_text_backend: + self._maybe_render_text() + return + if self._progress is None: + return + self._progress.refresh() + + def __len__(self) -> int: + if self._total is not None: + return int(self._total) + if self._iterable is not None: + try: + return len(self._iterable) # type: ignore[arg-type] + except TypeError: + pass + raise TypeError("object of type '_TqdmWrapper' has no len()") + + def __bool__(self) -> bool: + return bool(self._total or self._iterable) + + def __repr__(self) -> str: + return self._maybe_render_text(to_string=True) + + def __lt__(self, other: Any) -> bool: + try: + return (self._total or 0) < (getattr(other, "_total", 0) or 0) + except (AttributeError, TypeError): + return False + + def __enter__(self) -> _TqdmWrapper: + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + self.close() + + def __contains__(self, item: Any) -> bool: + if self._iterable is None: + return False + try: + return item in self._iterable + except (TypeError, ValueError): + return False + + def __reversed__(self) -> Iterator[Any]: + if self._iterable is None: + raise TypeError("_TqdmWrapper object is not reversible") + try: + reversible = cast(Reversible[Any], self._iterable) + return reversed(reversible) + except (TypeError, AttributeError) as exc: + raise TypeError("_TqdmWrapper object is not reversible") from exc + + def _maybe_render_text(self, to_string: bool = False) -> str: + if not self._text_backend: + return "" + postfix_items = {k: v for k, v in (self.postfix or {}).items() if k != 0 and v} + return self._text_backend.render( + n=self.n, + total=self._total, + desc=self._desc, + postfix=postfix_items, + leave=self._leave, + to_string=to_string, + ) + + def _should_render( + self, now: float, *, refresh: bool, postfix_changed: bool = False + ) -> bool: + """Decide whether to render based on min interval/iters and optional max interval.""" + if refresh: + return True + delta_time = now - self._last_render_time + delta_n = self.n - self._last_render_n + if postfix_changed and delta_time >= self._mininterval: + return True + if delta_time >= self._mininterval and delta_n >= self._miniters: + return True + if self._maxinterval is not None and delta_time >= self._maxinterval: + return True + return False + + def _mark_render(self, now: float) -> None: + """Record the last render time and counter for subsequent throttling decisions.""" + self._last_render_time = now + self._last_render_n = self.n + + def set_description_str(self, desc: str = "") -> None: + self.set_description(desc) + + def reset(self, total: Optional[int] = None) -> None: + if total is not None: + self._total = total + self.n = 0 + if self._use_text_backend: + self._maybe_render_text() + elif self._progress and self._task is not None: + self._progress.reset(self._task, total=self._total) + + def unpause(self) -> None: + return + + def clear(self, nolock: bool = False) -> None: + del nolock + if self._use_text_backend and self._write_stream is not None: + try: + self._write_stream.write("\n") + except (OSError, ValueError): + return + return + + @classmethod + def write( + cls, + s: str, + file: Optional[IO[str]] = None, + end: str = "\n", + nolock: bool = False, + ) -> None: + del nolock + target = file or sys.stderr + try: + target.write(str(s) + end) + except (OSError, ValueError): + return + + @classmethod + def get_lock(cls) -> Any: + if cls._lock_obj is None: + try: + cls._lock_obj = threading.Lock() + except (OSError, RuntimeError): + cls._lock_obj = _DummyLock() + return cls._lock_obj + + @classmethod + def set_lock(cls, lock: Any) -> None: + cls._lock_obj = lock + + @classmethod + def external_write_mode( + cls, file: Optional[IO[str]] = None, nolock: bool = False + ) -> ContextManager[IO[str]]: + """Context manager mirroring tqdm.external_write_mode for compatibility.""" + + del nolock + + class _CM: + def __enter__(self) -> IO[str]: + return file or sys.stderr + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc: Optional[BaseException], + tb: Optional[TracebackType], + ) -> Literal[False]: + return False + + return _CM() + + @classmethod + def wrapattr( + cls, + stream: Any, + method: str, + total: Optional[int] = None, + bytes: bool = False, # pylint: disable=redefined-builtin + **kwargs: Any, + ) -> Any: + """Wrap a stream attribute call for compatibility with tqdm's wrapattr.""" + del total, bytes, kwargs + func = getattr(stream, method) + + class _Wrap: + def __init__(self, fn: Callable[..., Any]): + self._fn = fn + + def __call__(self, *a: Any, **k: Any) -> Any: + return self._fn(*a, **k) + + def __getattr__(self, name: str) -> Any: + return getattr(self._fn, name) + + def __enter__(self) -> "_Wrap": + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc: Optional[BaseException], + tb: Optional[TracebackType], + ) -> Literal[False]: + return False + + def write(self, data: Any) -> Any: + return self._fn(data) + + return _Wrap(func) + + def __del__(self) -> None: + if getattr(self, "_closed", True): + return + self.close() + + +def _make_tqdm_callable( + console: Any = None, *, partial: bool = False +) -> Callable[[Optional[Iterable[Any]]], _TqdmWrapper]: + def _tqdm(iterable: Optional[Iterable[Any]] = None, **kwargs: Any) -> _TqdmWrapper: + return _TqdmWrapper( + iterable=iterable, console=console, partial=partial, **kwargs + ) + + _attach_api_shims(_tqdm) + + return _tqdm + + +def _make_trange( + console: Any = None, *, partial: bool = False +) -> Callable[..., _TqdmWrapper]: + def _trange(*args: Any, **kwargs: Any) -> _TqdmWrapper: + return _make_tqdm_callable(console, partial=partial)(range(*args), **kwargs) + + _attach_api_shims(_trange) + + return _trange + + +def install_tqdm(console: Any = None, *, partial: bool = False) -> None: + """Monkeypatch ``tqdm`` symbols to use Rich progress. + + Args: + console: Optional Console to route progress output through. + partial: When True, enable simplified text rendering for file-like outputs to + approximate tqdm aesthetics. When False (default), favor the Rich progress + presentation and avoid extra tqdm-style formatting. + """ + + if _PATCH_STATE["installed"]: + return + + module_names = ["tqdm", "tqdm.std", "tqdm.auto", "tqdm.notebook"] + target_attrs = ["tqdm", "trange", "tqdm_notebook", "tnrange"] + + replacements: Dict[Any, Any] = {} + to_patch: List[Tuple[types.ModuleType, str, Any]] = [] + + for modname in module_names: + try: + mod = importlib.import_module(modname) + except ImportError: + continue + for attr in target_attrs: + if not hasattr(mod, attr): + continue + orig = getattr(mod, attr) + if orig in replacements: + continue + newobj = ( + _make_tqdm_callable(console, partial=partial) + if attr in ("tqdm", "tqdm_notebook") + else _make_trange(console, partial=partial) + ) + replacements[orig] = newobj + to_patch.append((mod, attr, orig)) + + replaced_locations: List[Tuple[types.ModuleType, str, Any]] = [] + for orig, newval in list(replacements.items()): + for module in list(sys.modules.values()): + moddict = getattr(module, "__dict__", None) + if not isinstance(moddict, dict): + continue + for name, val in list(moddict.items()): + if val is orig: + replaced_locations.append((module, name, orig)) + try: + setattr(module, name, newval) + except (AttributeError, TypeError): + continue + + patched_modules: List[Tuple[types.ModuleType, str, Any]] = [] + for mod, attr, orig in to_patch: + replacement: Any = replacements.get(orig) + if replacement is None: + continue + new_callable = cast(Callable[..., Any], replacement) + try: + setattr(mod, attr, new_callable) + patched_modules.append((mod, attr, orig)) + except (AttributeError, TypeError): + continue + + _PATCH_STATE["installed"] = True + _PATCH_STATE["patched_modules"] = patched_modules + _PATCH_STATE["replaced_locations"] = replaced_locations + + +def uninstall_tqdm() -> None: + """Revert monkeypatching performed by :func:`install_tqdm`.""" + + if not _PATCH_STATE["installed"]: + return + + for module, name, orig in _PATCH_STATE.get("replaced_locations", []): + try: + setattr(module, name, orig) + except (AttributeError, TypeError): + continue + + for module, attr, orig in _PATCH_STATE.get("patched_modules", []): + try: + setattr(module, attr, orig) + except (AttributeError, TypeError): + continue + + _PATCH_STATE["installed"] = False + _PATCH_STATE["patched_modules"] = [] + _PATCH_STATE["replaced_locations"] = [] + _stop_global_progress() diff --git a/tests/test_tqdm_adapter.py b/tests/test_tqdm_adapter.py new file mode 100644 index 00000000..07c9d69f --- /dev/null +++ b/tests/test_tqdm_adapter.py @@ -0,0 +1,82 @@ +"""Adapter install/uninstall coverage for the Rich tqdm shim.""" + +import sys +import types +from typing import cast + +from rich.tqdm import _TqdmWrapper, install_tqdm, uninstall_tqdm + + +def test_install_replaces_imported_reference() -> None: + """Replacing imported tqdm symbols should affect existing aliases.""" + saved = sys.modules.pop("tqdm", None) + try: + fake = types.ModuleType("tqdm") + + def orig_tqdm(iterable=None, **kwargs): + del kwargs + if iterable is None: + return None + return ("orig", list(iterable)) + + def orig_trange(*args, **kwargs): + del kwargs + return range(*args) + + fake.tqdm = orig_tqdm + fake.trange = orig_trange + sys.modules["tqdm"] = fake + + third = types.ModuleType("thirdparty") + third.tqdm_alias = fake.tqdm + sys.modules["thirdparty"] = third + + try: + assert third.tqdm_alias is orig_tqdm + install_tqdm() + assert third.tqdm_alias is not orig_tqdm + it = third.tqdm_alias(range(3)) + assert list(it) == [0, 1, 2] + finally: + uninstall_tqdm() + assert third.tqdm_alias is orig_tqdm + finally: + sys.modules.pop("thirdparty", None) + if saved is not None: + sys.modules["tqdm"] = saved + else: + sys.modules.pop("tqdm", None) + + +def test_install_replaces_hidden_reference_and_postfix_support() -> None: + """Hidden tqdm references should update and still format postfix values.""" + saved = sys.modules.pop("tqdm", None) + try: + fake = types.ModuleType("tqdm") + + def orig_tqdm(iterable=None, **kwargs): + del kwargs + if iterable is None: + return None + return ("orig", list(iterable)) + + fake.tqdm = orig_tqdm + fake._hidden_tqdm = orig_tqdm # pylint: disable=protected-access + sys.modules["tqdm"] = fake + + try: + install_tqdm() + patched = fake.tqdm + assert patched is not orig_tqdm + assert fake._hidden_tqdm is patched # pylint: disable=protected-access + + progress = cast(_TqdmWrapper, patched(range(2))) + progress.update(postfix={"loss": 1.23}) # pylint: disable=no-member + assert progress.postfix == {"loss": "1.23"} # pylint: disable=no-member + finally: + uninstall_tqdm() + finally: + if saved is not None: + sys.modules["tqdm"] = saved + else: + sys.modules.pop("tqdm", None) diff --git a/tests/test_tqdm_parity.py b/tests/test_tqdm_parity.py new file mode 100644 index 00000000..cb4c9984 --- /dev/null +++ b/tests/test_tqdm_parity.py @@ -0,0 +1,60 @@ +"""Parity checks between Rich's tqdm shim and vendored tqdm reference. + +Note: this test relies on a vendored copy of tqdm under ``tests/tqdm`` to +provide a reference implementation. If that folder is absent, the test is +skipped. Keep the vendored copy lightweight and in sync with the targeted tqdm +version when you want to exercise parity; otherwise the skip is expected in +normal installs. +""" + +# pylint: disable=import-error + +from __future__ import annotations + +import importlib +import sys +from pathlib import Path + +import pytest + +from rich.tqdm import install_tqdm, uninstall_tqdm + + +SUBMODULE_ROOT = Path(__file__).parent / "tqdm" + + +def test_tqdm_adapter_len_and_disable_parity() -> None: + """Use the vendored tqdm submodule as a baseline for adapter parity.""" + + if not SUBMODULE_ROOT.exists(): + pytest.skip("vendored tqdm submodule missing") + + added_path = False + submodule_path = str(SUBMODULE_ROOT) + if submodule_path not in sys.path: + sys.path.insert(0, submodule_path) + added_path = True + + orig_mod = None + try: + orig_mod = importlib.import_module("tqdm") + orig_tqdm = orig_mod.tqdm + + baseline_len = len(orig_tqdm(range(5))) + + install_tqdm() + patched_tqdm = orig_mod.tqdm + + assert patched_tqdm is not orig_tqdm + assert len(patched_tqdm(range(5))) == baseline_len + assert list(patched_tqdm(range(3), disable=True)) == [0, 1, 2] + finally: + uninstall_tqdm() + if orig_mod is not None: + importlib.reload(orig_mod) + sys.modules.pop("tqdm", None) + if added_path: + try: + sys.path.remove(submodule_path) + except ValueError: + pass diff --git a/tests/test_tqdm_related.py b/tests/test_tqdm_related.py new file mode 100644 index 00000000..adbcea18 --- /dev/null +++ b/tests/test_tqdm_related.py @@ -0,0 +1,133 @@ +"""Unit-level helpers that mirror tqdm formatting and postfix handling.""" + +import io + +from rich.tqdm import _TqdmWrapper, _TextBackend + +# pylint: disable=protected-access + + +def format_num(n): + """Intelligent scientific notation (.3g), mirroring tqdm.""" + + f = "{0:.3g}".format(n).replace("+0", "+").replace("-0", "-") + n_str = str(n) + return f if len(f) < len(n_str) else n_str + + +def test_format_num_matches_shorter_representation(): + """Ensure numeric formatting mirrors tqdm's shorter-of fixed/scientific rule.""" + assert format_num(1337) == "1337" + # prefers the shorter of fixed or scientific + assert format_num(1239876) == "1239876" + assert format_num(0.00001234) == "1.23e-5" + assert format_num(-0.1234) == "-.123" + + +def test_wrapper_accepts_postfix_and_len(): + """Wrapper should expose len() and normalize postfix values.""" + progress = _TqdmWrapper(range(3), total=None) + assert len(progress) == 3 + progress.update(postfix={"loss": 1.2345}) + assert progress.postfix == {"loss": "1.23"} + + +def test_set_postfix_allows_multiple_fields(): + """Multiple postfix fields remain formatted to short strings.""" + progress = _TqdmWrapper(range(2), total=2) + progress.set_postfix({"loss": 0.9876}, acc=0.42) + assert progress.postfix == {"loss": "0.988", "acc": "0.42"} + + +def test_set_postfix_str_passthrough(): + """String postfix should pass through unchanged.""" + progress = _TqdmWrapper(range(1), total=1) + progress.set_postfix_str("done") + assert progress.postfix == {"postfix": "done"} + + +def test_wrapper_disable_skips_progress_calls(): + """Disabled wrapper still increments counters without rendering.""" + progress = _TqdmWrapper(range(2), disable=True) + assert list(progress) == [0, 1] + # update should not raise when disabled and should still track n + progress.update(2, postfix={"acc": 0.9}) + assert progress.n == 4 + + +def test_postfix_idempotent_when_unchanged(): + """Calling set_postfix with identical content should keep the same mapping.""" + progress = _TqdmWrapper(range(1), total=1) + progress.set_postfix({"loss": 1.0}) + first = progress.postfix + progress.set_postfix({"loss": 1.0}) + assert progress.postfix is first + + +def test_text_backend_throttles(monkeypatch): + """Text backend should skip renders inside the mininterval window.""" + backend = _TextBackend(io.StringIO(), mininterval=0.1) + backend._last_render = -1.0 # noqa: SLF001 ensure first render passes throttle + ticks = iter([0.0, 0.05, 0.2]) + monkeypatch.setattr("rich.tqdm.time.monotonic", lambda: next(ticks)) + + first = backend.render(n=1, total=2, desc="d", postfix={}, leave=True, to_string=False) + second = backend.render(n=2, total=2, desc="d", postfix={}, leave=True, to_string=False) + third = backend.render(n=2, total=2, desc="d", postfix={}, leave=True, to_string=False) + + assert first + assert second == "" # throttled + assert third # rendered after interval + + +def test_should_render_thresholds(): + """Render should wait for both time and iter thresholds unless postfix changed.""" + progress = _TqdmWrapper(range(1), mininterval=1.0, miniters=2) + progress._last_render_time = 0.0 # noqa: SLF001 + progress._last_render_n = 0 # noqa: SLF001 + + progress.n = 1 + assert progress._should_render(0.5, refresh=False, postfix_changed=False) is False # noqa: SLF001 + + progress.n = 2 + assert progress._should_render(0.5, refresh=False, postfix_changed=False) is False # noqa: SLF001 + assert progress._should_render(1.0, refresh=False, postfix_changed=False) is True # noqa: SLF001 + + progress._last_render_time = 0.0 # noqa: SLF001 + progress.n = 0 + assert progress._should_render(1.0, refresh=False, postfix_changed=True) is True # noqa: SLF001 + + +def test_wrapattr_passthrough(): + """wrapattr should forward calls to the underlying method.""" + buf = io.StringIO() + wrapped = _TqdmWrapper.wrapattr(buf, "write") + wrapped("ok") + assert buf.getvalue() == "ok" + + +def test_external_write_mode_uses_provided_file(): + """external_write_mode should enter with provided file object.""" + buf = io.StringIO() + cm = _TqdmWrapper.external_write_mode(file=buf) + with cm as target: + target.write("line") + assert buf.getvalue() == "line" + + +def test_first_update_marks_render(monkeypatch): + """Initial update should render and record counters instead of throttling away.""" + monkeypatch.setattr("rich.tqdm.time.monotonic", lambda: 100.0) + progress = _TqdmWrapper(range(1), mininterval=0.1) + progress.update() + + assert progress._last_render_n == progress.n # noqa: SLF001 + + +def test_update_after_close_is_noop(): + """Updating after close should be ignored and not bump counters.""" + progress = _TqdmWrapper(range(1)) + progress.close() + progress.update() + + assert progress.n == 0