Add tqdm compatibility shim with Rich progress backend

Implements a runtime adapter that replaces tqdm imports with Rich's
progress bars via install_tqdm(). The shim preserves common tqdm APIs
(iteration, len(), postfix, write, wrapattr) while routing rendering
through rich.progress.Progress.

Features:
- Opt-in monkeypatching of tqdm/trange and notebook variants
- Shared global Progress instance for performance
- Render throttling via mininterval/miniters/maxinterval
- Simplified text backend for file outputs (partial mode)
- Postfix value caching to avoid redundant formatting

Includes runnable examples demonstrating basic replacement and
aggressive reference overwriting, plus comprehensive unit tests
covering adapter installation, throttling, and postfix handling.

Updates README and contributor list.
This commit is contained in:
K2 2025-12-05 07:41:51 -05:00
parent f82a399d58
commit adcb5c748c
9 changed files with 1097 additions and 0 deletions

View file

@ -41,6 +41,7 @@ The following people have contributed to the development of Rich:
- [Hedy Li](https://github.com/hedythedev)
- [Henry Mai](https://github.com/tanducmai)
- [Luka Mamukashvili](https://github.com/UltraStudioLTD)
- [Shane K2 Macaulay](https://github.com/K2)
- [Alexander Mancevice](https://github.com/amancevice)
- [Will McGugan](https://github.com/willmcgugan)
- [Paul McGuire](https://github.com/ptmcg)

View file

@ -275,6 +275,26 @@ The columns may be configured to show any details you want. Built-in columns inc
To try this out yourself, see [examples/downloader.py](https://github.com/textualize/rich/blob/master/examples/downloader.py) which can download multiple URLs simultaneously while displaying progress.
#### tqdm compatibility shim
If you need to redirect `tqdm` calls in third-party code to Rich at runtime, call `install_tqdm()` before the code imports `tqdm`:
```python
from rich.tqdm import install_tqdm
install_tqdm()
from tqdm import tqdm
for _ in tqdm(range(10), desc="basic"):
...
```
Runnable examples:
- Basic replacement: `PYTHONPATH=. python examples/tqdm_adapter_basic.py`
- Hidden reference overwrite (e.g. stashed `_hidden_tqdm`): `PYTHONPATH=. python examples/tqdm_adapter_hidden.py`
</details>
<details>

View file

@ -0,0 +1,22 @@
"""Replace tqdm with rich.progress at runtime (basic example).
Run:
PYTHONPATH=. python examples/tqdm_adapter_basic.py
"""
from time import sleep
from rich.tqdm import install_tqdm
def main() -> None:
install_tqdm()
from tqdm import tqdm # resolved to the rich-backed shim after install
for _ in tqdm(range(10), desc="basic"):
sleep(0.05)
if __name__ == "__main__":
main()

View file

@ -0,0 +1,23 @@
"""Overwrite hidden tqdm references (aggressive replacement example).
Simulates a third-party module (examples/tqdm_hidden_lib.py) that stashes
``tqdm`` in an internal variable at import time. The adapter should still
replace that hidden reference after install_tqdm().
Run:
PYTHONPATH=. python examples/tqdm_adapter_hidden.py
"""
from rich.tqdm import install_tqdm
from examples import tqdm_hidden_lib
def main() -> None:
# Import happens before install; hidden module already captured original tqdm
install_tqdm() # aggressive replacement updates hidden references
tqdm_hidden_lib.run_hidden_loop()
if __name__ == "__main__":
main()

View file

@ -0,0 +1,15 @@
"""Simulated third-party module that hides a tqdm reference internally.
Imported by tqdm_adapter_hidden.py to demonstrate aggressive replacement.
"""
import time
import tqdm as _tqdm
_hidden_tqdm = _tqdm.tqdm
def run_hidden_loop() -> None:
for _ in _hidden_tqdm(range(5), desc="hidden-lib"):
time.sleep(0.05)

741
rich/tqdm.py Normal file
View file

@ -0,0 +1,741 @@
"""Compatibility shim to replace ``tqdm`` with Rich's progress bars.
Rationale
---------
- Opt-in bridge: call :func:`install_tqdm` to monkeypatch ``tqdm`` symbols
(``tqdm``, ``trange``, notebook variants) in imported modules **and** any
already-imported aliases found in ``sys.modules``. The shim stays inert until
you install it; uninstall with :func:`uninstall_tqdm`.
- Keep call sites stable: the wrapper preserves the common tqdm surface area
(iteration, len(), postfix, write, wrapattr) while routing rendering through
:class:`rich.progress.Progress`.
- Favor performance: reuses a shared global ``Progress`` when no custom console
or file is supplied, throttles refreshes via ``mininterval``/``miniters`` and
optional ``maxinterval``, and caches formatted postfix values to avoid string
churn when metrics are unchanged.
- Handle non-interactive outputs: a simplified text backend is used when
``partial=True`` with a file-like target to approximate tqdm's plain output
without Rich's live rendering.
This shim aims to be compatible with the most common tqdm usage patterns, but
it does not implement every tqdm argument. Unrecognized kwargs are ignored
rather than raising, matching tqdm's permissive behavior.
"""
# pylint: disable=missing-function-docstring, import-outside-toplevel
from __future__ import annotations
import importlib
import sys
import time
import types
import threading
from types import TracebackType
from typing import (
Any,
Callable,
ContextManager,
Dict,
IO,
Iterable,
Iterator,
List,
Optional,
Tuple,
Type,
cast,
Reversible,
Literal,
)
_PATCH_STATE: Dict[str, Any] = {
"installed": False,
"patched_modules": [], # list of (module, attr, original)
"replaced_locations": [], # list of (module, name, original)
}
_GLOBAL_STATE: Dict[str, Optional[Any]] = {"progress": None}
class _DummyLock:
def __enter__(self) -> _DummyLock:
return self
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
tb: Optional[TracebackType],
) -> Literal[False]:
return False
def _format_interval(seconds: float) -> str:
# Simplified copy of tqdm.format_interval
minutes, sec = divmod(int(seconds), 60)
hours, minutes = divmod(minutes, 60)
if hours:
return f"{hours:d}:{minutes:02d}:{sec:02d}"
return f"{minutes:02d}:{sec:02d}"
def _format_postfix_value(value: Any) -> Any:
if isinstance(value, (int, float)):
formatted = f"{value:.3g}"
value_str = str(value)
return formatted if len(formatted) < len(value_str) else value_str
return value if isinstance(value, str) else str(value)
def format_num(n: Any) -> str:
formatted = str(_format_postfix_value(n))
return formatted.replace("e+0", "e+").replace("e-0", "e-")
def format_interval(t: float) -> str:
return _format_interval(t)
def format_meter(
n: float, total: Optional[float], elapsed: float, prefix: str = "", **_: Any
) -> str:
total_str = "?" if total is None else format_num(total)
percentage: float
if total is None or total == 0:
percentage = 0.0
else:
percentage = 100 * n / total
elapsed_str = _format_interval(elapsed)
rate = n / elapsed if elapsed > 0 else 0
rate_str = "?" if rate == 0 else format_num(rate)
return (
f"{prefix}{percentage:3.0f}%| | {format_num(n)}/{total_str} "
f"[{elapsed_str}<?, {rate_str}it/s]"
)
def _get_progress(console: Any = None, file: Any = None) -> Tuple[Any, bool]:
"""Return a Progress instance and whether it is shared/global."""
if console is not None or file is not None:
from .progress import Progress
from .console import Console
prog_console = console or Console(file=file)
progress = Progress(console=prog_console)
progress.start()
return progress, False
progress = cast("Progress", _GLOBAL_STATE["progress"])
if progress is None:
from . import get_console
from .progress import Progress
progress = Progress(console=get_console())
progress.start()
_GLOBAL_STATE["progress"] = progress
return progress, True
def _stop_global_progress() -> None:
progress = _GLOBAL_STATE.get("progress")
if progress is not None:
try:
progress.stop()
finally:
_GLOBAL_STATE["progress"] = None
def _attach_api_shims(obj: Any) -> None:
"""Attach tqdm-compatible helper functions and class variables to a callable."""
# pylint: disable=protected-access
obj.format_num = format_num
obj.format_interval = format_interval
obj.format_meter = format_meter
obj.write = _TqdmWrapper.write
obj.get_lock = _TqdmWrapper.get_lock
obj.set_lock = _TqdmWrapper.set_lock
obj.external_write_mode = _TqdmWrapper.external_write_mode
obj.wrapattr = _TqdmWrapper.wrapattr
obj._instances = _TqdmWrapper._instances
obj._lock = None
class _TextBackend:
"""Minimal text renderer for partial compatibility mode."""
def __init__(self, stream: IO[str], mininterval: float) -> None:
self.stream = stream
self.mininterval = mininterval
self._last_render = 0.0
def render(
self,
*,
n: int,
total: Optional[int],
desc: str,
postfix: Dict[Any, Any],
leave: bool,
to_string: bool,
) -> str:
now = time.monotonic()
if not to_string and now - self._last_render < self.mininterval:
return ""
self._last_render = now
percent = 0 if not total else (n / total * 100 if total else 0)
bar_len = 10
done = int(bar_len * (0 if not total else min(1.0, n / float(total))))
abar = "#" * done + " " * (bar_len - done)
total_str = "?" if total is None else str(total)
line = f"\r{percent:3.0f}%|{abar}| {n}/{total_str} [00:00<]"
if desc:
line = f"{desc} {line}"
if postfix:
postfix_str = ", ".join(f"{k}={v}" for k, v in postfix.items())
line = f"{line} {postfix_str}"
if to_string:
return line.strip()
if not leave and total is not None and n >= total:
return line
try:
self.stream.write(line)
except (OSError, ValueError):
return line
return line
class _TqdmWrapper:
"""Lightweight stand-in for ``tqdm`` based on Rich progress with render throttling."""
_instances: List[Any] = []
_lock_obj: Any = None
def __init__(
self,
iterable: Optional[Iterable[Any]] = None,
total: Optional[int] = None,
desc: Optional[str] = None,
leave: bool = True,
disable: bool = False,
initial: int = 0,
console: Any = None,
file: Any = None,
partial: bool = False,
mininterval: float = 0.1,
maxinterval: Optional[float] = None,
miniters: Optional[int] = None,
ascii: bool = False, # pylint: disable=redefined-builtin
smoothing: Optional[float] = None,
unit: str = "it",
unit_scale: bool = False,
unit_divisor: int = 1000,
bar_format: Optional[str] = None,
**_: Any,
) -> None:
del smoothing, unit, unit_scale, unit_divisor, bar_format
self._iterable = iterable
self._total = total if total is not None else None
self._desc = desc or ""
self._leave = leave
self.desc = self._desc
self._closed = False
self._disabled = disable
self._progress: Optional[Any] = None
self._task: Optional[Any] = None
self._write_stream: Optional[IO[str]] = None
self._partial = partial
self._use_text_backend = bool(partial and file is not None)
self._text_backend: Optional[_TextBackend] = None
self.dynamic_miniters = True
self.ascii = ascii
self._lock: Optional[Any] = None
self.postfix: Optional[Dict[Any, Any]] = {0: {}}
self.n = initial
self._is_global = False
self._mininterval = mininterval
self._maxinterval = maxinterval
self._miniters = 1 if miniters is None else max(1, int(miniters))
# Seed render markers so the first update is eligible to render promptly.
self._last_render_time = 0.0
self._last_render_n = self.n
if self._total is None and hasattr(iterable, "__len__"):
try:
self._total = len(iterable) # type: ignore[arg-type]
except (TypeError, OverflowError):
self._total = None
if self._use_text_backend:
self._write_stream = file
self._text_backend = _TextBackend(file, mininterval)
if not self._disabled and not self._use_text_backend:
self._progress, self._is_global = _get_progress(console=console, file=file)
if self._progress is not None:
self._task = self._progress.add_task(self._desc, total=self._total)
if initial:
self._progress.update(self._task, completed=initial)
def __iter__(self) -> Iterator[Any]:
"""Iterate while tracking progress updates, honoring disabled/text modes."""
if self._iterable is None:
return iter(())
iterator = iter(self._iterable)
if self._disabled:
return self._iter_disabled(iterator)
if self._use_text_backend:
self._maybe_render_text()
return self._iter_enabled(iterator)
def _iter_disabled(self, iterator: Iterator[Any]) -> Iterator[Any]:
"""Yield items while disabled, tracking count without rendering."""
for item in iterator:
self.n += 1
yield item
def _iter_enabled(self, iterator: Iterator[Any]) -> Iterator[Any]:
"""Yield items, updating progress each step and closing if not leaving."""
for item in iterator:
yield item
self.update(1)
if not self._leave:
self.close()
def update(
self,
n: int = 1,
*,
postfix: Optional[Dict[str, Any]] = None,
refresh: bool = False,
) -> None:
"""Advance the bar and conditionally refresh based on throttling thresholds."""
if self._closed:
return
self.n += n
postfix_changed = False
if postfix:
new_postfix = {k: _format_postfix_value(v) for k, v in postfix.items()}
if new_postfix != (self.postfix or {}):
self.postfix = new_postfix
postfix_changed = True
if self._disabled:
return
if self._use_text_backend:
self._maybe_render_text()
return
progress = self._progress
task = self._task
if progress is None or task is None:
return
now = time.monotonic()
should_refresh = self._should_render(
now, refresh=refresh, postfix_changed=postfix_changed
)
progress.update(task, advance=n, postfix=self.postfix, refresh=should_refresh)
if should_refresh:
self._mark_render(now)
def set_description(self, desc: str = "") -> None:
if self._closed:
return
self._desc = str(desc)
if self._disabled:
return
self.desc = self._desc
if self._use_text_backend:
self._maybe_render_text()
return
progress = self._progress
task = self._task
if progress is None or task is None:
return
now = time.monotonic()
progress.update(task, description=self._desc, refresh=True)
self._mark_render(now)
def set_postfix(
self,
ordered_dict: Optional[Dict[str, Any]] = None,
refresh: bool = True,
**kwargs: Any,
) -> None:
"""Update postfix fields while reusing the last render unless content changed."""
if self._closed:
return
combined: Dict[str, Any] = {}
if ordered_dict:
combined.update(ordered_dict)
if kwargs:
combined.update(kwargs)
postfix_changed = False
if combined:
new_postfix = {k: _format_postfix_value(v) for k, v in combined.items()}
if new_postfix != (self.postfix or {}):
self.postfix = new_postfix
postfix_changed = True
if self._disabled:
return
if self._use_text_backend:
self._maybe_render_text()
return
progress = self._progress
task = self._task
if progress is None or task is None:
return
now = time.monotonic()
should_refresh = self._should_render(
now, refresh=refresh, postfix_changed=postfix_changed
)
progress.update(task, postfix=self.postfix, refresh=should_refresh)
if should_refresh:
self._mark_render(now)
def set_postfix_str(self, s: str = "", refresh: bool = True) -> None:
self.set_postfix({"postfix": s}, refresh=refresh)
def close(self) -> None:
if self._closed:
return
if not self._disabled and self._progress and self._task is not None:
self._progress.remove_task(self._task)
self._closed = True
try:
self._instances.remove(self)
except ValueError:
pass
if self._progress and not self._progress.tasks and self._is_global:
_stop_global_progress()
def refresh(self) -> None:
if self._disabled:
return
if self._closed:
return
if self._use_text_backend:
self._maybe_render_text()
return
if self._progress is None:
return
self._progress.refresh()
def __len__(self) -> int:
if self._total is not None:
return int(self._total)
if self._iterable is not None:
try:
return len(self._iterable) # type: ignore[arg-type]
except TypeError:
pass
raise TypeError("object of type '_TqdmWrapper' has no len()")
def __bool__(self) -> bool:
return bool(self._total or self._iterable)
def __repr__(self) -> str:
return self._maybe_render_text(to_string=True)
def __lt__(self, other: Any) -> bool:
try:
return (self._total or 0) < (getattr(other, "_total", 0) or 0)
except (AttributeError, TypeError):
return False
def __enter__(self) -> _TqdmWrapper:
return self
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
self.close()
def __contains__(self, item: Any) -> bool:
if self._iterable is None:
return False
try:
return item in self._iterable
except (TypeError, ValueError):
return False
def __reversed__(self) -> Iterator[Any]:
if self._iterable is None:
raise TypeError("_TqdmWrapper object is not reversible")
try:
reversible = cast(Reversible[Any], self._iterable)
return reversed(reversible)
except (TypeError, AttributeError) as exc:
raise TypeError("_TqdmWrapper object is not reversible") from exc
def _maybe_render_text(self, to_string: bool = False) -> str:
if not self._text_backend:
return ""
postfix_items = {k: v for k, v in (self.postfix or {}).items() if k != 0 and v}
return self._text_backend.render(
n=self.n,
total=self._total,
desc=self._desc,
postfix=postfix_items,
leave=self._leave,
to_string=to_string,
)
def _should_render(
self, now: float, *, refresh: bool, postfix_changed: bool = False
) -> bool:
"""Decide whether to render based on min interval/iters and optional max interval."""
if refresh:
return True
delta_time = now - self._last_render_time
delta_n = self.n - self._last_render_n
if postfix_changed and delta_time >= self._mininterval:
return True
if delta_time >= self._mininterval and delta_n >= self._miniters:
return True
if self._maxinterval is not None and delta_time >= self._maxinterval:
return True
return False
def _mark_render(self, now: float) -> None:
"""Record the last render time and counter for subsequent throttling decisions."""
self._last_render_time = now
self._last_render_n = self.n
def set_description_str(self, desc: str = "") -> None:
self.set_description(desc)
def reset(self, total: Optional[int] = None) -> None:
if total is not None:
self._total = total
self.n = 0
if self._use_text_backend:
self._maybe_render_text()
elif self._progress and self._task is not None:
self._progress.reset(self._task, total=self._total)
def unpause(self) -> None:
return
def clear(self, nolock: bool = False) -> None:
del nolock
if self._use_text_backend and self._write_stream is not None:
try:
self._write_stream.write("\n")
except (OSError, ValueError):
return
return
@classmethod
def write(
cls,
s: str,
file: Optional[IO[str]] = None,
end: str = "\n",
nolock: bool = False,
) -> None:
del nolock
target = file or sys.stderr
try:
target.write(str(s) + end)
except (OSError, ValueError):
return
@classmethod
def get_lock(cls) -> Any:
if cls._lock_obj is None:
try:
cls._lock_obj = threading.Lock()
except (OSError, RuntimeError):
cls._lock_obj = _DummyLock()
return cls._lock_obj
@classmethod
def set_lock(cls, lock: Any) -> None:
cls._lock_obj = lock
@classmethod
def external_write_mode(
cls, file: Optional[IO[str]] = None, nolock: bool = False
) -> ContextManager[IO[str]]:
"""Context manager mirroring tqdm.external_write_mode for compatibility."""
del nolock
class _CM:
def __enter__(self) -> IO[str]:
return file or sys.stderr
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
tb: Optional[TracebackType],
) -> Literal[False]:
return False
return _CM()
@classmethod
def wrapattr(
cls,
stream: Any,
method: str,
total: Optional[int] = None,
bytes: bool = False, # pylint: disable=redefined-builtin
**kwargs: Any,
) -> Any:
"""Wrap a stream attribute call for compatibility with tqdm's wrapattr."""
del total, bytes, kwargs
func = getattr(stream, method)
class _Wrap:
def __init__(self, fn: Callable[..., Any]):
self._fn = fn
def __call__(self, *a: Any, **k: Any) -> Any:
return self._fn(*a, **k)
def __getattr__(self, name: str) -> Any:
return getattr(self._fn, name)
def __enter__(self) -> "_Wrap":
return self
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
tb: Optional[TracebackType],
) -> Literal[False]:
return False
def write(self, data: Any) -> Any:
return self._fn(data)
return _Wrap(func)
def __del__(self) -> None:
if getattr(self, "_closed", True):
return
self.close()
def _make_tqdm_callable(
console: Any = None, *, partial: bool = False
) -> Callable[[Optional[Iterable[Any]]], _TqdmWrapper]:
def _tqdm(iterable: Optional[Iterable[Any]] = None, **kwargs: Any) -> _TqdmWrapper:
return _TqdmWrapper(
iterable=iterable, console=console, partial=partial, **kwargs
)
_attach_api_shims(_tqdm)
return _tqdm
def _make_trange(
console: Any = None, *, partial: bool = False
) -> Callable[..., _TqdmWrapper]:
def _trange(*args: Any, **kwargs: Any) -> _TqdmWrapper:
return _make_tqdm_callable(console, partial=partial)(range(*args), **kwargs)
_attach_api_shims(_trange)
return _trange
def install_tqdm(console: Any = None, *, partial: bool = False) -> None:
"""Monkeypatch ``tqdm`` symbols to use Rich progress.
Args:
console: Optional Console to route progress output through.
partial: When True, enable simplified text rendering for file-like outputs to
approximate tqdm aesthetics. When False (default), favor the Rich progress
presentation and avoid extra tqdm-style formatting.
"""
if _PATCH_STATE["installed"]:
return
module_names = ["tqdm", "tqdm.std", "tqdm.auto", "tqdm.notebook"]
target_attrs = ["tqdm", "trange", "tqdm_notebook", "tnrange"]
replacements: Dict[Any, Any] = {}
to_patch: List[Tuple[types.ModuleType, str, Any]] = []
for modname in module_names:
try:
mod = importlib.import_module(modname)
except ImportError:
continue
for attr in target_attrs:
if not hasattr(mod, attr):
continue
orig = getattr(mod, attr)
if orig in replacements:
continue
newobj = (
_make_tqdm_callable(console, partial=partial)
if attr in ("tqdm", "tqdm_notebook")
else _make_trange(console, partial=partial)
)
replacements[orig] = newobj
to_patch.append((mod, attr, orig))
replaced_locations: List[Tuple[types.ModuleType, str, Any]] = []
for orig, newval in list(replacements.items()):
for module in list(sys.modules.values()):
moddict = getattr(module, "__dict__", None)
if not isinstance(moddict, dict):
continue
for name, val in list(moddict.items()):
if val is orig:
replaced_locations.append((module, name, orig))
try:
setattr(module, name, newval)
except (AttributeError, TypeError):
continue
patched_modules: List[Tuple[types.ModuleType, str, Any]] = []
for mod, attr, orig in to_patch:
replacement: Any = replacements.get(orig)
if replacement is None:
continue
new_callable = cast(Callable[..., Any], replacement)
try:
setattr(mod, attr, new_callable)
patched_modules.append((mod, attr, orig))
except (AttributeError, TypeError):
continue
_PATCH_STATE["installed"] = True
_PATCH_STATE["patched_modules"] = patched_modules
_PATCH_STATE["replaced_locations"] = replaced_locations
def uninstall_tqdm() -> None:
"""Revert monkeypatching performed by :func:`install_tqdm`."""
if not _PATCH_STATE["installed"]:
return
for module, name, orig in _PATCH_STATE.get("replaced_locations", []):
try:
setattr(module, name, orig)
except (AttributeError, TypeError):
continue
for module, attr, orig in _PATCH_STATE.get("patched_modules", []):
try:
setattr(module, attr, orig)
except (AttributeError, TypeError):
continue
_PATCH_STATE["installed"] = False
_PATCH_STATE["patched_modules"] = []
_PATCH_STATE["replaced_locations"] = []
_stop_global_progress()

View file

@ -0,0 +1,82 @@
"""Adapter install/uninstall coverage for the Rich tqdm shim."""
import sys
import types
from typing import cast
from rich.tqdm import _TqdmWrapper, install_tqdm, uninstall_tqdm
def test_install_replaces_imported_reference() -> None:
"""Replacing imported tqdm symbols should affect existing aliases."""
saved = sys.modules.pop("tqdm", None)
try:
fake = types.ModuleType("tqdm")
def orig_tqdm(iterable=None, **kwargs):
del kwargs
if iterable is None:
return None
return ("orig", list(iterable))
def orig_trange(*args, **kwargs):
del kwargs
return range(*args)
fake.tqdm = orig_tqdm
fake.trange = orig_trange
sys.modules["tqdm"] = fake
third = types.ModuleType("thirdparty")
third.tqdm_alias = fake.tqdm
sys.modules["thirdparty"] = third
try:
assert third.tqdm_alias is orig_tqdm
install_tqdm()
assert third.tqdm_alias is not orig_tqdm
it = third.tqdm_alias(range(3))
assert list(it) == [0, 1, 2]
finally:
uninstall_tqdm()
assert third.tqdm_alias is orig_tqdm
finally:
sys.modules.pop("thirdparty", None)
if saved is not None:
sys.modules["tqdm"] = saved
else:
sys.modules.pop("tqdm", None)
def test_install_replaces_hidden_reference_and_postfix_support() -> None:
"""Hidden tqdm references should update and still format postfix values."""
saved = sys.modules.pop("tqdm", None)
try:
fake = types.ModuleType("tqdm")
def orig_tqdm(iterable=None, **kwargs):
del kwargs
if iterable is None:
return None
return ("orig", list(iterable))
fake.tqdm = orig_tqdm
fake._hidden_tqdm = orig_tqdm # pylint: disable=protected-access
sys.modules["tqdm"] = fake
try:
install_tqdm()
patched = fake.tqdm
assert patched is not orig_tqdm
assert fake._hidden_tqdm is patched # pylint: disable=protected-access
progress = cast(_TqdmWrapper, patched(range(2)))
progress.update(postfix={"loss": 1.23}) # pylint: disable=no-member
assert progress.postfix == {"loss": "1.23"} # pylint: disable=no-member
finally:
uninstall_tqdm()
finally:
if saved is not None:
sys.modules["tqdm"] = saved
else:
sys.modules.pop("tqdm", None)

60
tests/test_tqdm_parity.py Normal file
View file

@ -0,0 +1,60 @@
"""Parity checks between Rich's tqdm shim and vendored tqdm reference.
Note: this test relies on a vendored copy of tqdm under ``tests/tqdm`` to
provide a reference implementation. If that folder is absent, the test is
skipped. Keep the vendored copy lightweight and in sync with the targeted tqdm
version when you want to exercise parity; otherwise the skip is expected in
normal installs.
"""
# pylint: disable=import-error
from __future__ import annotations
import importlib
import sys
from pathlib import Path
import pytest
from rich.tqdm import install_tqdm, uninstall_tqdm
SUBMODULE_ROOT = Path(__file__).parent / "tqdm"
def test_tqdm_adapter_len_and_disable_parity() -> None:
"""Use the vendored tqdm submodule as a baseline for adapter parity."""
if not SUBMODULE_ROOT.exists():
pytest.skip("vendored tqdm submodule missing")
added_path = False
submodule_path = str(SUBMODULE_ROOT)
if submodule_path not in sys.path:
sys.path.insert(0, submodule_path)
added_path = True
orig_mod = None
try:
orig_mod = importlib.import_module("tqdm")
orig_tqdm = orig_mod.tqdm
baseline_len = len(orig_tqdm(range(5)))
install_tqdm()
patched_tqdm = orig_mod.tqdm
assert patched_tqdm is not orig_tqdm
assert len(patched_tqdm(range(5))) == baseline_len
assert list(patched_tqdm(range(3), disable=True)) == [0, 1, 2]
finally:
uninstall_tqdm()
if orig_mod is not None:
importlib.reload(orig_mod)
sys.modules.pop("tqdm", None)
if added_path:
try:
sys.path.remove(submodule_path)
except ValueError:
pass

133
tests/test_tqdm_related.py Normal file
View file

@ -0,0 +1,133 @@
"""Unit-level helpers that mirror tqdm formatting and postfix handling."""
import io
from rich.tqdm import _TqdmWrapper, _TextBackend
# pylint: disable=protected-access
def format_num(n):
"""Intelligent scientific notation (.3g), mirroring tqdm."""
f = "{0:.3g}".format(n).replace("+0", "+").replace("-0", "-")
n_str = str(n)
return f if len(f) < len(n_str) else n_str
def test_format_num_matches_shorter_representation():
"""Ensure numeric formatting mirrors tqdm's shorter-of fixed/scientific rule."""
assert format_num(1337) == "1337"
# prefers the shorter of fixed or scientific
assert format_num(1239876) == "1239876"
assert format_num(0.00001234) == "1.23e-5"
assert format_num(-0.1234) == "-.123"
def test_wrapper_accepts_postfix_and_len():
"""Wrapper should expose len() and normalize postfix values."""
progress = _TqdmWrapper(range(3), total=None)
assert len(progress) == 3
progress.update(postfix={"loss": 1.2345})
assert progress.postfix == {"loss": "1.23"}
def test_set_postfix_allows_multiple_fields():
"""Multiple postfix fields remain formatted to short strings."""
progress = _TqdmWrapper(range(2), total=2)
progress.set_postfix({"loss": 0.9876}, acc=0.42)
assert progress.postfix == {"loss": "0.988", "acc": "0.42"}
def test_set_postfix_str_passthrough():
"""String postfix should pass through unchanged."""
progress = _TqdmWrapper(range(1), total=1)
progress.set_postfix_str("done")
assert progress.postfix == {"postfix": "done"}
def test_wrapper_disable_skips_progress_calls():
"""Disabled wrapper still increments counters without rendering."""
progress = _TqdmWrapper(range(2), disable=True)
assert list(progress) == [0, 1]
# update should not raise when disabled and should still track n
progress.update(2, postfix={"acc": 0.9})
assert progress.n == 4
def test_postfix_idempotent_when_unchanged():
"""Calling set_postfix with identical content should keep the same mapping."""
progress = _TqdmWrapper(range(1), total=1)
progress.set_postfix({"loss": 1.0})
first = progress.postfix
progress.set_postfix({"loss": 1.0})
assert progress.postfix is first
def test_text_backend_throttles(monkeypatch):
"""Text backend should skip renders inside the mininterval window."""
backend = _TextBackend(io.StringIO(), mininterval=0.1)
backend._last_render = -1.0 # noqa: SLF001 ensure first render passes throttle
ticks = iter([0.0, 0.05, 0.2])
monkeypatch.setattr("rich.tqdm.time.monotonic", lambda: next(ticks))
first = backend.render(n=1, total=2, desc="d", postfix={}, leave=True, to_string=False)
second = backend.render(n=2, total=2, desc="d", postfix={}, leave=True, to_string=False)
third = backend.render(n=2, total=2, desc="d", postfix={}, leave=True, to_string=False)
assert first
assert second == "" # throttled
assert third # rendered after interval
def test_should_render_thresholds():
"""Render should wait for both time and iter thresholds unless postfix changed."""
progress = _TqdmWrapper(range(1), mininterval=1.0, miniters=2)
progress._last_render_time = 0.0 # noqa: SLF001
progress._last_render_n = 0 # noqa: SLF001
progress.n = 1
assert progress._should_render(0.5, refresh=False, postfix_changed=False) is False # noqa: SLF001
progress.n = 2
assert progress._should_render(0.5, refresh=False, postfix_changed=False) is False # noqa: SLF001
assert progress._should_render(1.0, refresh=False, postfix_changed=False) is True # noqa: SLF001
progress._last_render_time = 0.0 # noqa: SLF001
progress.n = 0
assert progress._should_render(1.0, refresh=False, postfix_changed=True) is True # noqa: SLF001
def test_wrapattr_passthrough():
"""wrapattr should forward calls to the underlying method."""
buf = io.StringIO()
wrapped = _TqdmWrapper.wrapattr(buf, "write")
wrapped("ok")
assert buf.getvalue() == "ok"
def test_external_write_mode_uses_provided_file():
"""external_write_mode should enter with provided file object."""
buf = io.StringIO()
cm = _TqdmWrapper.external_write_mode(file=buf)
with cm as target:
target.write("line")
assert buf.getvalue() == "line"
def test_first_update_marks_render(monkeypatch):
"""Initial update should render and record counters instead of throttling away."""
monkeypatch.setattr("rich.tqdm.time.monotonic", lambda: 100.0)
progress = _TqdmWrapper(range(1), mininterval=0.1)
progress.update()
assert progress._last_render_n == progress.n # noqa: SLF001
def test_update_after_close_is_noop():
"""Updating after close should be ignored and not bump counters."""
progress = _TqdmWrapper(range(1))
progress.close()
progress.update()
assert progress.n == 0