added rich cast to protocol

This commit is contained in:
Will McGugan 2021-11-16 12:58:41 +00:00
parent 189731826b
commit 20b27d53e4
5 changed files with 80 additions and 16 deletions

View file

@ -5,7 +5,7 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [10.13.1] - 2021-11-09
## [10.14.0] - Unreleased
### Fixed
@ -17,6 +17,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- Added file protocol to URL highlighter https://github.com/willmcgugan/rich/issues/1681
- Added rich.protocol.rich_cast
### Changed
- Allowed `__rich__` to work recursively
## [10.13.0] - 2021-11-07

View file

@ -25,6 +25,7 @@ from typing import (
Mapping,
NamedTuple,
Optional,
Set,
TextIO,
Tuple,
Type,
@ -53,6 +54,7 @@ from .markup import render as render_markup
from .measure import Measurement, measure_renderables
from .pager import Pager, SystemPager
from .pretty import Pretty, is_expandable
from .protocol import rich_cast
from .region import Region
from .scope import render_scope
from .screen import Screen
@ -1220,8 +1222,8 @@ class Console:
# No space to render anything. This prevents potential recursion errors.
return
render_iterable: RenderResult
if hasattr(renderable, "__rich__") and not isclass(renderable):
renderable = renderable.__rich__() # type: ignore
renderable = rich_cast(renderable)
if hasattr(renderable, "__rich_console__") and not isclass(renderable):
render_iterable = renderable.__rich_console__(self, _options) # type: ignore
elif isinstance(renderable, str):
@ -1439,15 +1441,7 @@ class Console:
del text[:]
for renderable in objects:
# I promise this is sane
# This detects an object which claims to have all attributes, such as MagicMock.mock_calls
if hasattr(
renderable, "jwevpw_eors4dfo6mwo345ermk7kdnfnwerwer"
): # pragma: no cover
renderable = repr(renderable)
rich_cast = getattr(renderable, "__rich__", None)
if rich_cast:
renderable = rich_cast()
renderable = rich_cast(renderable)
if isinstance(renderable, str):
append_text(
self.render_str(

View file

@ -2,7 +2,7 @@ from operator import itemgetter
from typing import Callable, Iterable, NamedTuple, Optional, TYPE_CHECKING
from . import errors
from .protocol import is_renderable
from .protocol import is_renderable, rich_cast
if TYPE_CHECKING:
from .console import Console, ConsoleOptions, RenderableType
@ -97,8 +97,7 @@ class Measurement(NamedTuple):
return Measurement(0, 0)
if isinstance(renderable, str):
renderable = console.render_str(renderable, markup=options.markup)
if hasattr(renderable, "__rich__"):
renderable = renderable.__rich__() # type: ignore
renderable = rich_cast(renderable)
if is_renderable(renderable):
get_console_width: Optional[
Callable[["Console", "ConsoleOptions"], "Measurement"]

View file

@ -1,4 +1,10 @@
from typing import Any
from typing import Any, Callable, cast, Set, TYPE_CHECKING
from inspect import isclass
if TYPE_CHECKING:
from rich.console import RenderableType
_GIBBERISH = """aihwerij235234ljsdnp34ksodfipwoe234234jlskjdf"""
def is_renderable(check_object: Any) -> bool:
@ -8,3 +14,29 @@ def is_renderable(check_object: Any) -> bool:
or hasattr(check_object, "__rich__")
or hasattr(check_object, "__rich_console__")
)
def rich_cast(renderable: object) -> "RenderableType":
"""Cast an object to a renderable by calling __rich__ if present.
Args:
renderable (object): A potentially renderable object
Returns:
object: The result of recursively calling __rich__.
"""
from rich.console import RenderableType
rich_visited_set: Set[type] = set() # Prevent potential infinite loop
while hasattr(renderable, "__rich__") and not isclass(renderable):
# Detect object which claim to have all the attributes
if hasattr(renderable, _GIBBERISH):
return repr(renderable)
cast_method = getattr(renderable, "__rich__")
renderable = cast_method()
renderable_type = type(renderable)
if renderable_type in rich_visited_set:
break
rich_visited_set.add(renderable_type)
return cast(RenderableType, renderable)

View file

@ -33,3 +33,37 @@ def test_abc():
assert not isinstance(foo, str)
assert not isinstance("foo", RichRenderable)
assert not isinstance([], RichRenderable)
def test_cast_deep():
class B:
def __rich__(self) -> Foo:
return Foo()
class A:
def __rich__(self) -> B:
return B()
console = Console(file=io.StringIO())
console.print(A())
assert console.file.getvalue() == "Foo\n"
def test_cast_recursive():
class B:
def __rich__(self) -> "A":
return A()
def __repr__(self) -> str:
return "<B>"
class A:
def __rich__(self) -> B:
return B()
def __repr__(self) -> str:
return "<A>"
console = Console(file=io.StringIO())
console.print(A())
assert console.file.getvalue() == "<B>\n"