diff --git a/rich/prompt.py b/rich/prompt.py index c7cf25ba..2afbea4a 100644 --- a/rich/prompt.py +++ b/rich/prompt.py @@ -1,3 +1,4 @@ +import threading from typing import Any, Generic, List, Optional, TextIO, TypeVar, Union, overload from . import get_console @@ -27,6 +28,16 @@ class InvalidResponse(PromptError): return self.message +class TimeoutError(PromptError): + """Exception raised when a prompt times out.""" + + def __init__(self, message: TextType) -> None: + self.message = message + + def __rich__(self) -> TextType: + return self.message + + class PromptBase(Generic[PromptType]): """Ask the user for input until a valid response is received. This is the base class, see one of the concrete classes for examples. @@ -39,6 +50,7 @@ class PromptBase(Generic[PromptType]): case_sensitive (bool, optional): Matching of choices should be case-sensitive. Defaults to True. show_default (bool, optional): Show default in prompt. Defaults to True. show_choices (bool, optional): Show choices in prompt. Defaults to True. + timeout (Optional[float], optional): Timeout in seconds. Defaults to None. """ response_type: type = str @@ -52,15 +64,16 @@ class PromptBase(Generic[PromptType]): choices: Optional[List[str]] = None def __init__( - self, - prompt: TextType = "", - *, - console: Optional[Console] = None, - password: bool = False, - choices: Optional[List[str]] = None, - case_sensitive: bool = True, - show_default: bool = True, - show_choices: bool = True, + self, + prompt: TextType = "", + *, + console: Optional[Console] = None, + password: bool = False, + choices: Optional[List[str]] = None, + case_sensitive: bool = True, + show_default: bool = True, + show_choices: bool = True, + timeout: Optional[float] = None, ) -> None: self.console = console or get_console() self.prompt = ( @@ -74,53 +87,57 @@ class PromptBase(Generic[PromptType]): self.case_sensitive = case_sensitive self.show_default = show_default self.show_choices = show_choices + self.timeout = timeout @classmethod @overload def ask( - cls, - prompt: TextType = "", - *, - console: Optional[Console] = None, - password: bool = False, - choices: Optional[List[str]] = None, - case_sensitive: bool = True, - show_default: bool = True, - show_choices: bool = True, - default: DefaultType, - stream: Optional[TextIO] = None, + cls, + prompt: TextType = "", + *, + console: Optional[Console] = None, + password: bool = False, + choices: Optional[List[str]] = None, + case_sensitive: bool = True, + show_default: bool = True, + show_choices: bool = True, + default: DefaultType, + timeout: Optional[float] = None, + stream: Optional[TextIO] = None, ) -> Union[DefaultType, PromptType]: ... @classmethod @overload def ask( - cls, - prompt: TextType = "", - *, - console: Optional[Console] = None, - password: bool = False, - choices: Optional[List[str]] = None, - case_sensitive: bool = True, - show_default: bool = True, - show_choices: bool = True, - stream: Optional[TextIO] = None, + cls, + prompt: TextType = "", + *, + console: Optional[Console] = None, + password: bool = False, + choices: Optional[List[str]] = None, + case_sensitive: bool = True, + show_default: bool = True, + show_choices: bool = True, + timeout: Optional[float] = None, + stream: Optional[TextIO] = None, ) -> PromptType: ... @classmethod def ask( - cls, - prompt: TextType = "", - *, - console: Optional[Console] = None, - password: bool = False, - choices: Optional[List[str]] = None, - case_sensitive: bool = True, - show_default: bool = True, - show_choices: bool = True, - default: Any = ..., - stream: Optional[TextIO] = None, + cls, + prompt: TextType = "", + *, + console: Optional[Console] = None, + password: bool = False, + choices: Optional[List[str]] = None, + case_sensitive: bool = True, + show_default: bool = True, + show_choices: bool = True, + default: Any = ..., + timeout: Optional[float] = None, + stream: Optional[TextIO] = None, ) -> Any: """Shortcut to construct and run a prompt loop and return the result. @@ -135,6 +152,7 @@ class PromptBase(Generic[PromptType]): case_sensitive (bool, optional): Matching of choices should be case-sensitive. Defaults to True. show_default (bool, optional): Show default in prompt. Defaults to True. show_choices (bool, optional): Show choices in prompt. Defaults to True. + timeout (Optional[float], optional): Timeout in seconds. Defaults to None. stream (TextIO, optional): Optional text file open for reading to get input. Defaults to None. """ _prompt = cls( @@ -145,6 +163,7 @@ class PromptBase(Generic[PromptType]): case_sensitive=case_sensitive, show_default=show_default, show_choices=show_choices, + timeout=timeout, ) return _prompt(default=default, stream=stream) @@ -178,9 +197,9 @@ class PromptBase(Generic[PromptType]): prompt.append(choices, "prompt.choices") if ( - default != ... - and self.show_default - and isinstance(default, (str, self.response_type)) + default != ... + and self.show_default + and isinstance(default, (str, self.response_type)) ): prompt.append(" ") _default = self.render_default(default) @@ -190,25 +209,49 @@ class PromptBase(Generic[PromptType]): return prompt - @classmethod def get_input( - cls, - console: Console, - prompt: TextType, - password: bool, - stream: Optional[TextIO] = None, + cls, + console: Console, + prompt: TextType, + password: bool, + stream: Optional[TextIO] = None, + timeout: Optional[float] = None, ) -> str: - """Get input from user. + """Get input from user with optional timeout. Args: console (Console): Console instance. prompt (TextType): Prompt text. password (bool): Enable password entry. + timeout (Optional[float]): Timeout in seconds. Returns: str: String from user. + + Raises: + TimeoutError: If the user does not respond within the timeout. """ - return console.input(prompt, password=password, stream=stream) + user_input = None + thread_finished = threading.Event() + + def read_input(): + nonlocal user_input + try: + user_input = console.input(prompt, password=password, stream=stream) + finally: + thread_finished.set() + + thread = threading.Thread(target=read_input) + thread.start() + + thread_finished.wait(timeout) + + if not thread_finished.is_set(): + raise TimeoutError("\nPrompt timed out.") + + thread.join() + + return user_input or "" def check_choice(self, value: str) -> bool: """Check value is in the list of valid choices. @@ -264,6 +307,14 @@ class PromptBase(Generic[PromptType]): """ self.console.print(error) + def on_timeout_error(self, error: TimeoutError) -> None: + """Called to handle timeout error. + + Args: + error (TimeoutError): Exception instance that initiated the error. + """ + self.console.print(f"\n{error.message}") + def pre_prompt(self) -> None: """Hook to display something before the prompt.""" @@ -273,7 +324,7 @@ class PromptBase(Generic[PromptType]): @overload def __call__( - self, *, default: DefaultType, stream: Optional[TextIO] = None + self, *, default: DefaultType, stream: Optional[TextIO] = None ) -> Union[PromptType, DefaultType]: ... @@ -289,7 +340,17 @@ class PromptBase(Generic[PromptType]): while True: self.pre_prompt() prompt = self.make_prompt(default) - value = self.get_input(self.console, prompt, self.password, stream=stream) + try: + value = self.get_input( + self.console, prompt, self.password, stream=stream, timeout=self.timeout + ) + except TimeoutError as error: + # Ensure the TimeoutError message is printed to console + self.on_timeout_error(error) + if default != ...: + return default + break + if value == "" and default != ...: return default try: @@ -301,13 +362,13 @@ class PromptBase(Generic[PromptType]): return return_value + class Prompt(PromptBase[str]): """A prompt that returns a str. Example: >>> name = Prompt.ask("Enter your name") - """ response_type = str @@ -367,34 +428,18 @@ if __name__ == "__main__": # pragma: no cover from rich import print if Confirm.ask("Run [i]prompt[/i] tests?", default=True): - while True: + try: result = IntPrompt.ask( - ":rocket: Enter a number between [b]1[/b] and [b]10[/b]", default=5 + ":rocket: Enter a number (will timeout in 10 seconds)", + default=5, + timeout=10, ) - if result >= 1 and result <= 10: - break - print(":pile_of_poo: [prompt.invalid]Number must be between 1 and 10") - print(f"number={result}") + print(f"number={result}") + except TimeoutError: + print("[prompt.invalid]Prompt timed out!") - while True: - password = Prompt.ask( - "Please enter a password [cyan](must be at least 5 characters)", - password=True, - ) - if len(password) >= 5: - break - print("[prompt.invalid]password too short") - print(f"password={password!r}") - - fruit = Prompt.ask("Enter a fruit", choices=["apple", "orange", "pear"]) - print(f"fruit={fruit!r}") - - doggie = Prompt.ask( - "What's the best Dog? (Case INSENSITIVE)", - choices=["Border Terrier", "Collie", "Labradoodle"], - case_sensitive=False, - ) - print(f"doggie={doggie!r}") - - else: - print("[b]OK :loudly_crying_face:") + try: + fruit = Prompt.ask("Enter a fruit", choices=["apple", "orange", "pear"], timeout=5) + print(f"fruit={fruit!r}") + except TimeoutError: + print("[prompt.invalid]You took too long to respond!") diff --git a/tests/test_prompt.py b/tests/test_prompt.py index 11bffa71..5951e1c6 100644 --- a/tests/test_prompt.py +++ b/tests/test_prompt.py @@ -1,4 +1,7 @@ import io +from time import sleep + +import pytest from rich.console import Console from rich.prompt import Confirm, IntPrompt, Prompt @@ -111,3 +114,27 @@ def test_prompt_confirm_default(): output = console.file.getvalue() print(repr(output)) assert output == expected + + +def test_prompt_timeout_handling(): + """Test that a timeout prints the correct message and returns default.""" + console = Console(file=io.StringIO()) + default_value = "Default" + result = Prompt.ask( + "Enter a value", + console=console, + timeout=0.1, + default=default_value + ) + assert result == default_value + + +def test_prompt_timeout_no_default(): + """Test that a timeout with no default does not raise but returns an empty string.""" + console = Console(file=io.StringIO()) + result = Prompt.ask( + "Enter a value", + console=console, + timeout=0.1 + ) + assert result == ""