This commit is contained in:
Miłosz Matuszewski 2025-12-04 16:18:29 +00:00 committed by GitHub
commit fadfb62d64
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 155 additions and 83 deletions

View file

@ -1,3 +1,4 @@
import threading
from typing import Any, Generic, List, Optional, TextIO, TypeVar, Union, overload
from . import get_console
@ -27,6 +28,16 @@ class InvalidResponse(PromptError):
return self.message
class TimeoutError(PromptError):
"""Exception raised when a prompt times out."""
def __init__(self, message: TextType) -> None:
self.message = message
def __rich__(self) -> TextType:
return self.message
class PromptBase(Generic[PromptType]):
"""Ask the user for input until a valid response is received. This is the base class, see one of
the concrete classes for examples.
@ -39,6 +50,7 @@ class PromptBase(Generic[PromptType]):
case_sensitive (bool, optional): Matching of choices should be case-sensitive. Defaults to True.
show_default (bool, optional): Show default in prompt. Defaults to True.
show_choices (bool, optional): Show choices in prompt. Defaults to True.
timeout (Optional[float], optional): Timeout in seconds. Defaults to None.
"""
response_type: type = str
@ -52,15 +64,16 @@ class PromptBase(Generic[PromptType]):
choices: Optional[List[str]] = None
def __init__(
self,
prompt: TextType = "",
*,
console: Optional[Console] = None,
password: bool = False,
choices: Optional[List[str]] = None,
case_sensitive: bool = True,
show_default: bool = True,
show_choices: bool = True,
self,
prompt: TextType = "",
*,
console: Optional[Console] = None,
password: bool = False,
choices: Optional[List[str]] = None,
case_sensitive: bool = True,
show_default: bool = True,
show_choices: bool = True,
timeout: Optional[float] = None,
) -> None:
self.console = console or get_console()
self.prompt = (
@ -74,53 +87,57 @@ class PromptBase(Generic[PromptType]):
self.case_sensitive = case_sensitive
self.show_default = show_default
self.show_choices = show_choices
self.timeout = timeout
@classmethod
@overload
def ask(
cls,
prompt: TextType = "",
*,
console: Optional[Console] = None,
password: bool = False,
choices: Optional[List[str]] = None,
case_sensitive: bool = True,
show_default: bool = True,
show_choices: bool = True,
default: DefaultType,
stream: Optional[TextIO] = None,
cls,
prompt: TextType = "",
*,
console: Optional[Console] = None,
password: bool = False,
choices: Optional[List[str]] = None,
case_sensitive: bool = True,
show_default: bool = True,
show_choices: bool = True,
default: DefaultType,
timeout: Optional[float] = None,
stream: Optional[TextIO] = None,
) -> Union[DefaultType, PromptType]:
...
@classmethod
@overload
def ask(
cls,
prompt: TextType = "",
*,
console: Optional[Console] = None,
password: bool = False,
choices: Optional[List[str]] = None,
case_sensitive: bool = True,
show_default: bool = True,
show_choices: bool = True,
stream: Optional[TextIO] = None,
cls,
prompt: TextType = "",
*,
console: Optional[Console] = None,
password: bool = False,
choices: Optional[List[str]] = None,
case_sensitive: bool = True,
show_default: bool = True,
show_choices: bool = True,
timeout: Optional[float] = None,
stream: Optional[TextIO] = None,
) -> PromptType:
...
@classmethod
def ask(
cls,
prompt: TextType = "",
*,
console: Optional[Console] = None,
password: bool = False,
choices: Optional[List[str]] = None,
case_sensitive: bool = True,
show_default: bool = True,
show_choices: bool = True,
default: Any = ...,
stream: Optional[TextIO] = None,
cls,
prompt: TextType = "",
*,
console: Optional[Console] = None,
password: bool = False,
choices: Optional[List[str]] = None,
case_sensitive: bool = True,
show_default: bool = True,
show_choices: bool = True,
default: Any = ...,
timeout: Optional[float] = None,
stream: Optional[TextIO] = None,
) -> Any:
"""Shortcut to construct and run a prompt loop and return the result.
@ -135,6 +152,7 @@ class PromptBase(Generic[PromptType]):
case_sensitive (bool, optional): Matching of choices should be case-sensitive. Defaults to True.
show_default (bool, optional): Show default in prompt. Defaults to True.
show_choices (bool, optional): Show choices in prompt. Defaults to True.
timeout (Optional[float], optional): Timeout in seconds. Defaults to None.
stream (TextIO, optional): Optional text file open for reading to get input. Defaults to None.
"""
_prompt = cls(
@ -145,6 +163,7 @@ class PromptBase(Generic[PromptType]):
case_sensitive=case_sensitive,
show_default=show_default,
show_choices=show_choices,
timeout=timeout,
)
return _prompt(default=default, stream=stream)
@ -178,9 +197,9 @@ class PromptBase(Generic[PromptType]):
prompt.append(choices, "prompt.choices")
if (
default != ...
and self.show_default
and isinstance(default, (str, self.response_type))
default != ...
and self.show_default
and isinstance(default, (str, self.response_type))
):
prompt.append(" ")
_default = self.render_default(default)
@ -190,25 +209,49 @@ class PromptBase(Generic[PromptType]):
return prompt
@classmethod
def get_input(
cls,
console: Console,
prompt: TextType,
password: bool,
stream: Optional[TextIO] = None,
cls,
console: Console,
prompt: TextType,
password: bool,
stream: Optional[TextIO] = None,
timeout: Optional[float] = None,
) -> str:
"""Get input from user.
"""Get input from user with optional timeout.
Args:
console (Console): Console instance.
prompt (TextType): Prompt text.
password (bool): Enable password entry.
timeout (Optional[float]): Timeout in seconds.
Returns:
str: String from user.
Raises:
TimeoutError: If the user does not respond within the timeout.
"""
return console.input(prompt, password=password, stream=stream)
user_input = None
thread_finished = threading.Event()
def read_input():
nonlocal user_input
try:
user_input = console.input(prompt, password=password, stream=stream)
finally:
thread_finished.set()
thread = threading.Thread(target=read_input)
thread.start()
thread_finished.wait(timeout)
if not thread_finished.is_set():
raise TimeoutError("\nPrompt timed out.")
thread.join()
return user_input or ""
def check_choice(self, value: str) -> bool:
"""Check value is in the list of valid choices.
@ -264,6 +307,14 @@ class PromptBase(Generic[PromptType]):
"""
self.console.print(error)
def on_timeout_error(self, error: TimeoutError) -> None:
"""Called to handle timeout error.
Args:
error (TimeoutError): Exception instance that initiated the error.
"""
self.console.print(f"\n{error.message}")
def pre_prompt(self) -> None:
"""Hook to display something before the prompt."""
@ -273,7 +324,7 @@ class PromptBase(Generic[PromptType]):
@overload
def __call__(
self, *, default: DefaultType, stream: Optional[TextIO] = None
self, *, default: DefaultType, stream: Optional[TextIO] = None
) -> Union[PromptType, DefaultType]:
...
@ -289,7 +340,17 @@ class PromptBase(Generic[PromptType]):
while True:
self.pre_prompt()
prompt = self.make_prompt(default)
value = self.get_input(self.console, prompt, self.password, stream=stream)
try:
value = self.get_input(
self.console, prompt, self.password, stream=stream, timeout=self.timeout
)
except TimeoutError as error:
# Ensure the TimeoutError message is printed to console
self.on_timeout_error(error)
if default != ...:
return default
break
if value == "" and default != ...:
return default
try:
@ -301,13 +362,13 @@ class PromptBase(Generic[PromptType]):
return return_value
class Prompt(PromptBase[str]):
"""A prompt that returns a str.
Example:
>>> name = Prompt.ask("Enter your name")
"""
response_type = str
@ -367,34 +428,18 @@ if __name__ == "__main__": # pragma: no cover
from rich import print
if Confirm.ask("Run [i]prompt[/i] tests?", default=True):
while True:
try:
result = IntPrompt.ask(
":rocket: Enter a number between [b]1[/b] and [b]10[/b]", default=5
":rocket: Enter a number (will timeout in 10 seconds)",
default=5,
timeout=10,
)
if result >= 1 and result <= 10:
break
print(":pile_of_poo: [prompt.invalid]Number must be between 1 and 10")
print(f"number={result}")
print(f"number={result}")
except TimeoutError:
print("[prompt.invalid]Prompt timed out!")
while True:
password = Prompt.ask(
"Please enter a password [cyan](must be at least 5 characters)",
password=True,
)
if len(password) >= 5:
break
print("[prompt.invalid]password too short")
print(f"password={password!r}")
fruit = Prompt.ask("Enter a fruit", choices=["apple", "orange", "pear"])
print(f"fruit={fruit!r}")
doggie = Prompt.ask(
"What's the best Dog? (Case INSENSITIVE)",
choices=["Border Terrier", "Collie", "Labradoodle"],
case_sensitive=False,
)
print(f"doggie={doggie!r}")
else:
print("[b]OK :loudly_crying_face:")
try:
fruit = Prompt.ask("Enter a fruit", choices=["apple", "orange", "pear"], timeout=5)
print(f"fruit={fruit!r}")
except TimeoutError:
print("[prompt.invalid]You took too long to respond!")

View file

@ -1,4 +1,7 @@
import io
from time import sleep
import pytest
from rich.console import Console
from rich.prompt import Confirm, IntPrompt, Prompt
@ -111,3 +114,27 @@ def test_prompt_confirm_default():
output = console.file.getvalue()
print(repr(output))
assert output == expected
def test_prompt_timeout_handling():
"""Test that a timeout prints the correct message and returns default."""
console = Console(file=io.StringIO())
default_value = "Default"
result = Prompt.ask(
"Enter a value",
console=console,
timeout=0.1,
default=default_value
)
assert result == default_value
def test_prompt_timeout_no_default():
"""Test that a timeout with no default does not raise but returns an empty string."""
console = Console(file=io.StringIO())
result = Prompt.ask(
"Enter a value",
console=console,
timeout=0.1
)
assert result == ""