gh-132975: Improve Remote PDB interrupt handling (#133223)

This commit is contained in:
Matt Wozniski 2025-05-05 15:33:59 -04:00 committed by GitHub
parent 24ebb9ccfd
commit 9434709edf
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 346 additions and 188 deletions

View file

@ -77,6 +77,7 @@ import glob
import json
import token
import types
import atexit
import codeop
import pprint
import signal
@ -92,11 +93,12 @@ import tokenize
import itertools
import traceback
import linecache
import selectors
import threading
import _colorize
import _pyrepl.utils
from contextlib import closing
from contextlib import contextmanager
from contextlib import ExitStack, closing, contextmanager
from rlcompleter import Completer
from types import CodeType
from warnings import deprecated
@ -2670,12 +2672,21 @@ async def set_trace_async(*, header=None, commands=None):
# Remote PDB
class _PdbServer(Pdb):
def __init__(self, sockfile, owns_sockfile=True, **kwargs):
def __init__(
self,
sockfile,
signal_server=None,
owns_sockfile=True,
**kwargs,
):
self._owns_sockfile = owns_sockfile
self._interact_state = None
self._sockfile = sockfile
self._command_name_cache = []
self._write_failed = False
if signal_server:
# Only started by the top level _PdbServer, not recursive ones.
self._start_signal_listener(signal_server)
super().__init__(colorize=False, **kwargs)
@staticmethod
@ -2731,15 +2742,49 @@ class _PdbServer(Pdb):
f"PDB message doesn't follow the schema! {msg}"
)
@classmethod
def _start_signal_listener(cls, address):
    """Start a background thread that converts socket traffic into SIGINTs.

    A connection is opened to *address*; every non-empty chunk received on
    it makes this process raise SIGINT for itself.  The thread ends when
    the peer closes the connection (EOF) or when the interpreter exits.
    """
    stop_requested = threading.Event()

    def listener(conn):
        with closing(conn):
            # recv() wakes up every quarter of a second so that a stop
            # request issued at interpreter exit is noticed promptly.
            conn.settimeout(0.25)
            # This side only ever reads; close the write half right away.
            conn.shutdown(socket.SHUT_WR)
            while not stop_requested.is_set():
                try:
                    chunk = conn.recv(1024)
                except socket.timeout:
                    continue
                if not chunk:
                    return  # EOF
                signal.raise_signal(signal.SIGINT)

    def stop_listener():
        stop_requested.set()
        worker.join()

    # The thread is a daemon so that it keeps running until every
    # non-daemon thread has finished; the atexit hook then stops it
    # gracefully before the interpreter is torn down.
    conn = socket.create_connection(address, timeout=5)
    worker = threading.Thread(target=listener, args=(conn,), daemon=True)
    atexit.register(stop_listener)
    worker.start()
def _send(self, **kwargs):
self._ensure_valid_message(kwargs)
json_payload = json.dumps(kwargs)
try:
self._sockfile.write(json_payload.encode() + b"\n")
self._sockfile.flush()
except OSError:
# This means that the client has abruptly disconnected, but we'll
# handle that the next time we try to read from the client instead
except (OSError, ValueError):
# We get an OSError if the network connection has dropped, and a
# ValueError if the sockfile has been closed (e.g. by detach()). We'll
# handle this the next time we try to read from the client instead
# of trying to handle it from everywhere _send() may be called.
# Track this with a flag rather than assuming readline() will ever
# return an empty string because the socket may be half-closed.
@ -2967,10 +3012,15 @@ class _PdbServer(Pdb):
class _PdbClient:
def __init__(self, pid, sockfile, interrupt_script):
def __init__(self, pid, server_socket, interrupt_sock):
self.pid = pid
self.sockfile = sockfile
self.interrupt_script = interrupt_script
self.read_buf = b""
self.signal_read = None
self.signal_write = None
self.sigint_received = False
self.raise_on_sigint = False
self.server_socket = server_socket
self.interrupt_sock = interrupt_sock
self.pdb_instance = Pdb()
self.pdb_commands = set()
self.completion_matches = []
@ -3012,8 +3062,7 @@ class _PdbClient:
self._ensure_valid_message(kwargs)
json_payload = json.dumps(kwargs)
try:
self.sockfile.write(json_payload.encode() + b"\n")
self.sockfile.flush()
self.server_socket.sendall(json_payload.encode() + b"\n")
except OSError:
# This means that the client has abruptly disconnected, but we'll
# handle that the next time we try to read from the client instead
@ -3022,10 +3071,44 @@ class _PdbClient:
# return an empty string because the socket may be half-closed.
self.write_failed = True
def read_command(self, prompt):
self.multiline_block = False
reply = input(prompt)
def _readline(self):
if self.sigint_received:
# There's a pending unhandled SIGINT. Handle it now.
self.sigint_received = False
raise KeyboardInterrupt
# Wait for either a SIGINT or a line or EOF from the PDB server.
selector = selectors.DefaultSelector()
selector.register(self.signal_read, selectors.EVENT_READ)
selector.register(self.server_socket, selectors.EVENT_READ)
while b"\n" not in self.read_buf:
for key, _ in selector.select():
if key.fileobj == self.signal_read:
self.signal_read.recv(1024)
if self.sigint_received:
# If not, we're reading wakeup events for sigints that
# we've previously handled, and can ignore them.
self.sigint_received = False
raise KeyboardInterrupt
elif key.fileobj == self.server_socket:
data = self.server_socket.recv(16 * 1024)
self.read_buf += data
if not data and b"\n" not in self.read_buf:
# EOF without a full final line. Drop the partial line.
self.read_buf = b""
return b""
ret, sep, self.read_buf = self.read_buf.partition(b"\n")
return ret + sep
def read_input(self, prompt, multiline_block):
    """Prompt the user for one line and return it.

    Records *multiline_block* on the client first, then reads the line
    with input() while SIGINT is configured to raise KeyboardInterrupt.
    """
    self.multiline_block = multiline_block
    sigint_guard = self._sigint_raises_keyboard_interrupt()
    with sigint_guard:
        line = input(prompt)
    return line
def read_command(self, prompt):
reply = self.read_input(prompt, multiline_block=False)
if self.state == "dumb":
# No logic applied whatsoever, just pass the raw reply back.
return reply
@ -3048,10 +3131,9 @@ class _PdbClient:
return prefix + reply
# Otherwise, valid first line of a multi-line statement
self.multiline_block = True
continue_prompt = "...".ljust(len(prompt))
more_prompt = "...".ljust(len(prompt))
while codeop.compile_command(reply, "<stdin>", "single") is None:
reply += "\n" + input(continue_prompt)
reply += "\n" + self.read_input(more_prompt, multiline_block=True)
return prefix + reply
@ -3076,11 +3158,70 @@ class _PdbClient:
finally:
readline.set_completer(old_completer)
@contextmanager
def _sigint_handler(self):
# Context manager that takes over SIGINT handling for the duration of the
# `with` block: it installs `handler` for SIGINT, creates a socket pair
# stored on self.signal_read / self.signal_write, and registers the write
# end as the signal wakeup fd.  On exit the previous wakeup fd and handler
# are restored, both sockets are closed and the attributes reset to None.
#
# Signal handling strategy:
# - When we call input() we want a SIGINT to raise KeyboardInterrupt
# - Otherwise we want to write to the wakeup FD and set a flag.
# We'll break out of select() when the wakeup FD is written to,
# and we'll check the flag whenever we're about to accept input.
def handler(signum, frame):
# Record the signal; it is consumed later unless we raise right now.
self.sigint_received = True
if self.raise_on_sigint:
# One-shot; don't raise again until the flag is set again.
self.raise_on_sigint = False
self.sigint_received = False
raise KeyboardInterrupt
# A unique sentinel marks "nothing installed yet", so the cleanup code below
# only restores what was actually replaced.
sentinel = object()
old_handler = sentinel
old_wakeup_fd = sentinel
self.signal_read, self.signal_write = socket.socketpair()
with (closing(self.signal_read), closing(self.signal_write)):
# Both ends are made non-blocking before the write end is handed to
# set_wakeup_fd().
self.signal_read.setblocking(False)
self.signal_write.setblocking(False)
try:
old_handler = signal.signal(signal.SIGINT, handler)
try:
old_wakeup_fd = signal.set_wakeup_fd(
self.signal_write.fileno(),
warn_on_full_buffer=False,
)
yield
finally:
# Restore the old wakeup fd if we installed a new one
if old_wakeup_fd is not sentinel:
signal.set_wakeup_fd(old_wakeup_fd)
finally:
# Drop our references to the socket pair; the enclosing closing() contexts
# still hold the socket objects and close them when the `with` exits.
self.signal_read = self.signal_write = None
if old_handler is not sentinel:
# Restore the old handler if we installed a new one
signal.signal(signal.SIGINT, old_handler)
@contextmanager
def _sigint_raises_keyboard_interrupt(self):
# Context manager that sets self.raise_on_sigint for the duration of the
# `with` block and always clears it again on exit.  The SIGINT handler
# defined in _sigint_handler() raises KeyboardInterrupt only while that
# flag is set.
if self.sigint_received:
# There's a pending unhandled SIGINT. Handle it now.
self.sigint_received = False
raise KeyboardInterrupt
try:
# Set the flag inside the try so that the finally clause resets it on
# every exit path, including a KeyboardInterrupt raised by the handler.
self.raise_on_sigint = True
yield
finally:
self.raise_on_sigint = False
def cmdloop(self):
with self.readline_completion(self.complete):
with (
self._sigint_handler(),
self.readline_completion(self.complete),
):
while not self.write_failed:
try:
if not (payload_bytes := self.sockfile.readline()):
if not (payload_bytes := self._readline()):
break
except KeyboardInterrupt:
self.send_interrupt()
@ -3098,11 +3239,17 @@ class _PdbClient:
self.process_payload(payload)
def send_interrupt(self):
print(
"\n*** Program will stop at the next bytecode instruction."
" (Use 'cont' to resume)."
)
sys.remote_exec(self.pid, self.interrupt_script)
if self.interrupt_sock is not None:
# Write to a socket that the PDB server listens on. This triggers
# the remote to raise a SIGINT for itself. We do this because
# Windows doesn't allow triggering SIGINT remotely.
# See https://stackoverflow.com/a/35792192 for many more details.
self.interrupt_sock.sendall(signal.SIGINT.to_bytes())
else:
# On Unix we can just send a SIGINT to the remote process.
# This is preferable to using the signal thread approach that we
# use on Windows because it can interrupt IO in the main thread.
os.kill(self.pid, signal.SIGINT)
def process_payload(self, payload):
match payload:
@ -3172,7 +3319,7 @@ class _PdbClient:
if self.write_failed:
return None
payload = self.sockfile.readline()
payload = self._readline()
if not payload:
return None
@ -3189,11 +3336,18 @@ class _PdbClient:
return None
def _connect(host, port, frame, commands, version):
def _connect(*, host, port, frame, commands, version, signal_raising_thread):
with closing(socket.create_connection((host, port))) as conn:
sockfile = conn.makefile("rwb")
remote_pdb = _PdbServer(sockfile)
# The client requests this thread on Windows but not on Unix.
# Most tests don't request this thread, to keep them simpler.
if signal_raising_thread:
signal_server = (host, port)
else:
signal_server = None
remote_pdb = _PdbServer(sockfile, signal_server=signal_server)
weakref.finalize(remote_pdb, sockfile.close)
if Pdb._last_pdb_instance is not None:
@ -3214,43 +3368,48 @@ def _connect(host, port, frame, commands, version):
def attach(pid, commands=()):
"""Attach to a running process with the given PID."""
with closing(socket.create_server(("localhost", 0))) as server:
with ExitStack() as stack:
server = stack.enter_context(
closing(socket.create_server(("localhost", 0)))
)
port = server.getsockname()[1]
with tempfile.NamedTemporaryFile("w", delete_on_close=False) as connect_script:
connect_script.write(
textwrap.dedent(
f"""
import pdb, sys
pdb._connect(
host="localhost",
port={port},
frame=sys._getframe(1),
commands={json.dumps("\n".join(commands))},
version={_PdbServer.protocol_version()},
)
"""
connect_script = stack.enter_context(
tempfile.NamedTemporaryFile("w", delete_on_close=False)
)
use_signal_thread = sys.platform == "win32"
connect_script.write(
textwrap.dedent(
f"""
import pdb, sys
pdb._connect(
host="localhost",
port={port},
frame=sys._getframe(1),
commands={json.dumps("\n".join(commands))},
version={_PdbServer.protocol_version()},
signal_raising_thread={use_signal_thread!r},
)
"""
)
connect_script.close()
sys.remote_exec(pid, connect_script.name)
)
connect_script.close()
sys.remote_exec(pid, connect_script.name)
# TODO Add a timeout? Or don't bother since the user can ^C?
client_sock, _ = server.accept()
# TODO Add a timeout? Or don't bother since the user can ^C?
client_sock, _ = server.accept()
stack.enter_context(closing(client_sock))
with closing(client_sock):
sockfile = client_sock.makefile("rwb")
if use_signal_thread:
interrupt_sock, _ = server.accept()
stack.enter_context(closing(interrupt_sock))
interrupt_sock.setblocking(False)
else:
interrupt_sock = None
with closing(sockfile):
with tempfile.NamedTemporaryFile("w", delete_on_close=False) as interrupt_script:
interrupt_script.write(
'import pdb, sys\n'
'if inst := pdb.Pdb._last_pdb_instance:\n'
' inst.set_trace(sys._getframe(1))\n'
)
interrupt_script.close()
_PdbClient(pid, sockfile, interrupt_script.name).cmdloop()
_PdbClient(pid, client_sock, interrupt_sock).cmdloop()
# Post-Mortem interface

View file

@ -12,7 +12,7 @@ import textwrap
import threading
import unittest
import unittest.mock
from contextlib import contextmanager, redirect_stdout, ExitStack
from contextlib import closing, contextmanager, redirect_stdout, ExitStack
from pathlib import Path
from test.support import is_wasi, os_helper, requires_subprocess, SHORT_TIMEOUT
from test.support.os_helper import temp_dir, TESTFN, unlink
@ -79,44 +79,6 @@ class MockSocketFile:
return results
class MockDebuggerSocket:
    """In-memory stand-in for the socket file connected to a remote debugger.

    Incoming lines are taken from a prepared iterable; everything written
    is decoded as JSON on flush() and collected in ``outgoing``.
    """

    def __init__(self, incoming):
        self.incoming = iter(incoming)
        self.outgoing = []
        self.buffered = bytearray()

    def write(self, data: bytes) -> None:
        """Buffer *data* until the next flush()."""
        self.buffered += data

    def flush(self) -> None:
        """Decode each buffered line as JSON and record it in ``outgoing``."""
        pending = self.buffered.splitlines(keepends=True)
        self.buffered.clear()
        for raw_line in pending:
            # Every message must be a complete, newline-terminated line.
            assert raw_line.endswith(b"\n")
            self.outgoing.append(json.loads(raw_line))

    def readline(self) -> bytes:
        """Return the next prepared line, or b"" once the input is exhausted."""
        # A read depends on the previous write, so nothing may still be
        # sitting unflushed in the buffer at this point.
        assert not self.buffered
        exhausted = object()
        item = next(self.incoming, exhausted)
        if item is exhausted:
            return b""
        if isinstance(item, bytes):
            return item + b"\n"
        return json.dumps(item).encode() + b"\n"

    def close(self) -> None:
        """Closing the mock does nothing."""
        return None
class PdbClientTestCase(unittest.TestCase):
"""Tests for the _PdbClient class."""
@ -124,8 +86,11 @@ class PdbClientTestCase(unittest.TestCase):
self,
*,
incoming,
simulate_failure=None,
simulate_send_failure=False,
simulate_sigint_during_stdout_write=False,
use_interrupt_socket=False,
expected_outgoing=None,
expected_outgoing_signals=None,
expected_completions=None,
expected_exception=None,
expected_stdout="",
@ -134,6 +99,8 @@ class PdbClientTestCase(unittest.TestCase):
):
if expected_outgoing is None:
expected_outgoing = []
if expected_outgoing_signals is None:
expected_outgoing_signals = []
if expected_completions is None:
expected_completions = []
if expected_state is None:
@ -142,16 +109,6 @@ class PdbClientTestCase(unittest.TestCase):
expected_state.setdefault("write_failed", False)
messages = [m for source, m in incoming if source == "server"]
prompts = [m["prompt"] for source, m in incoming if source == "user"]
sockfile = MockDebuggerSocket(messages)
stdout = io.StringIO()
if simulate_failure:
sockfile.write = unittest.mock.Mock()
sockfile.flush = unittest.mock.Mock()
if simulate_failure == "write":
sockfile.write.side_effect = OSError("write failed")
elif simulate_failure == "flush":
sockfile.flush.side_effect = OSError("flush failed")
input_iter = (m for source, m in incoming if source == "user")
completions = []
@ -178,18 +135,60 @@ class PdbClientTestCase(unittest.TestCase):
reply = message["input"]
if isinstance(reply, BaseException):
raise reply
return reply
if isinstance(reply, str):
return reply
return reply()
with ExitStack() as stack:
client_sock, server_sock = socket.socketpair()
stack.enter_context(closing(client_sock))
stack.enter_context(closing(server_sock))
server_sock = unittest.mock.Mock(wraps=server_sock)
client_sock.sendall(
b"".join(
(m if isinstance(m, bytes) else json.dumps(m).encode()) + b"\n"
for m in messages
)
)
client_sock.shutdown(socket.SHUT_WR)
if simulate_send_failure:
server_sock.sendall = unittest.mock.Mock(
side_effect=OSError("sendall failed")
)
client_sock.shutdown(socket.SHUT_RD)
stdout = io.StringIO()
if simulate_sigint_during_stdout_write:
orig_stdout_write = stdout.write
def sigint_stdout_write(s):
signal.raise_signal(signal.SIGINT)
return orig_stdout_write(s)
stdout.write = sigint_stdout_write
input_mock = stack.enter_context(
unittest.mock.patch("pdb.input", side_effect=mock_input)
)
stack.enter_context(redirect_stdout(stdout))
if use_interrupt_socket:
interrupt_sock = unittest.mock.Mock(spec=socket.socket)
mock_kill = None
else:
interrupt_sock = None
mock_kill = stack.enter_context(
unittest.mock.patch("os.kill", spec=os.kill)
)
client = _PdbClient(
pid=0,
sockfile=sockfile,
interrupt_script="/a/b.py",
pid=12345,
server_socket=server_sock,
interrupt_sock=interrupt_sock,
)
if expected_exception is not None:
@ -199,13 +198,12 @@ class PdbClientTestCase(unittest.TestCase):
client.cmdloop()
actual_outgoing = sockfile.outgoing
if simulate_failure:
actual_outgoing += [
json.loads(msg.args[0]) for msg in sockfile.write.mock_calls
]
sent_msgs = [msg.args[0] for msg in server_sock.sendall.mock_calls]
for msg in sent_msgs:
assert msg.endswith(b"\n")
actual_outgoing = [json.loads(msg) for msg in sent_msgs]
self.assertEqual(sockfile.outgoing, expected_outgoing)
self.assertEqual(actual_outgoing, expected_outgoing)
self.assertEqual(completions, expected_completions)
if expected_stdout_substring and not expected_stdout:
self.assertIn(expected_stdout_substring, stdout.getvalue())
@ -215,6 +213,20 @@ class PdbClientTestCase(unittest.TestCase):
actual_state = {k: getattr(client, k) for k in expected_state}
self.assertEqual(actual_state, expected_state)
if use_interrupt_socket:
outgoing_signals = [
signal.Signals(int.from_bytes(call.args[0]))
for call in interrupt_sock.sendall.call_args_list
]
else:
assert mock_kill is not None
outgoing_signals = []
for call in mock_kill.call_args_list:
pid, signum = call.args
self.assertEqual(pid, 12345)
outgoing_signals.append(signal.Signals(signum))
self.assertEqual(outgoing_signals, expected_outgoing_signals)
def test_remote_immediately_closing_the_connection(self):
"""Test the behavior when the remote closes the connection immediately."""
incoming = []
@ -409,11 +421,17 @@ class PdbClientTestCase(unittest.TestCase):
expected_state={"state": "dumb"},
)
def test_keyboard_interrupt_at_prompt(self):
"""Test signaling when a prompt gets a KeyboardInterrupt."""
def test_sigint_at_prompt(self):
"""Test signaling when a prompt gets interrupted."""
incoming = [
("server", {"prompt": "(Pdb) ", "state": "pdb"}),
("user", {"prompt": "(Pdb) ", "input": KeyboardInterrupt()}),
(
"user",
{
"prompt": "(Pdb) ",
"input": lambda: signal.raise_signal(signal.SIGINT),
},
),
]
self.do_test(
incoming=incoming,
@ -423,6 +441,43 @@ class PdbClientTestCase(unittest.TestCase):
expected_state={"state": "pdb"},
)
def test_sigint_at_continuation_prompt(self):
    """Test signaling when a continuation prompt gets interrupted."""
    server_prompt = ("server", {"prompt": "(Pdb) ", "state": "pdb"})
    # First line of a compound statement, so the client asks for more input.
    first_line = ("user", {"prompt": "(Pdb) ", "input": "if True:"})
    # NOTE(review): read_command() builds the continuation prompt with
    # "...".ljust(len(prompt)); confirm the whitespace in this literal
    # matches what the client actually displays.
    interrupted_line = (
        "user",
        {
            "prompt": "... ",
            "input": lambda: signal.raise_signal(signal.SIGINT),
        },
    )
    self.do_test(
        incoming=[server_prompt, first_line, interrupted_line],
        expected_outgoing=[{"signal": "INT"}],
        expected_state={"state": "pdb"},
    )
def test_sigint_when_writing(self):
    """Test signaling when sys.stdout.write() gets interrupted."""
    server_messages = [
        ("server", {"message": "Some message or other\n", "type": "info"}),
    ]
    # Exercise both ways of delivering the interrupt to the remote process.
    for use_interrupt_socket in (False, True):
        with self.subTest(use_interrupt_socket=use_interrupt_socket):
            self.do_test(
                incoming=server_messages,
                simulate_sigint_during_stdout_write=True,
                use_interrupt_socket=use_interrupt_socket,
                expected_outgoing=[],
                expected_outgoing_signals=[signal.SIGINT],
                expected_stdout="Some message or other\n",
            )
def test_eof_at_prompt(self):
"""Test signaling when a prompt gets an EOFError."""
incoming = [
@ -478,20 +533,7 @@ class PdbClientTestCase(unittest.TestCase):
self.do_test(
incoming=incoming,
expected_outgoing=[{"signal": "INT"}],
simulate_failure="write",
expected_state={"write_failed": True},
)
def test_flush_failing(self):
"""Test terminating if flush fails due to a half closed socket."""
incoming = [
("server", {"prompt": "(Pdb) ", "state": "pdb"}),
("user", {"prompt": "(Pdb) ", "input": KeyboardInterrupt()}),
]
self.do_test(
incoming=incoming,
expected_outgoing=[{"signal": "INT"}],
simulate_failure="flush",
simulate_send_failure=True,
expected_state={"write_failed": True},
)
@ -660,42 +702,7 @@ class PdbClientTestCase(unittest.TestCase):
},
{"reply": "xyz"},
],
simulate_failure="write",
expected_completions=[],
expected_state={"state": "interact", "write_failed": True},
)
def test_flush_failure_during_completion(self):
"""Test failing to flush to the socket to request tab completions."""
incoming = [
("server", {"prompt": ">>> ", "state": "interact"}),
(
"user",
{
"prompt": ">>> ",
"completion_request": {
"line": "xy",
"begidx": 0,
"endidx": 2,
},
"input": "xyz",
},
),
]
self.do_test(
incoming=incoming,
expected_outgoing=[
{
"complete": {
"text": "xy",
"line": "xy",
"begidx": 0,
"endidx": 2,
}
},
{"reply": "xyz"},
],
simulate_failure="flush",
simulate_send_failure=True,
expected_completions=[],
expected_state={"state": "interact", "write_failed": True},
)
@ -1032,6 +1039,7 @@ class PdbConnectTestCase(unittest.TestCase):
frame=frame,
commands="",
version=pdb._PdbServer.protocol_version(),
signal_raising_thread=False,
)
return x # This line won't be reached in debugging
@ -1089,23 +1097,6 @@ class PdbConnectTestCase(unittest.TestCase):
client_file.write(json.dumps({"reply": command}).encode() + b"\n")
client_file.flush()
def _send_interrupt(self, pid):
"""Helper to send an interrupt signal to the debugger."""
# NOTE(review): a NamedTemporaryFile-based variant was left commented out
# here; the script is written to a TESTFN-derived path instead:
# with tempfile.NamedTemporaryFile("w", delete_on_close=False) as interrupt_script:
interrupt_script = TESTFN + "_interrupt_script.py"
with open(interrupt_script, 'w') as f:
# The injected script prints a marker line and, if a Pdb instance is
# registered as Pdb._last_pdb_instance, calls its set_trace() on the
# frame one level above the script.
f.write(
'import pdb, sys\n'
'print("Hello, world!")\n'
'if inst := pdb.Pdb._last_pdb_instance:\n'
' inst.set_trace(sys._getframe(1))\n'
)
# Remove the script file again when the test finishes.
self.addCleanup(unlink, interrupt_script)
try:
sys.remote_exec(pid, interrupt_script)
except PermissionError:
# Without permission to inject code into the target process the test is
# skipped rather than failed.
self.skipTest("Insufficient permissions to execute code in remote process")
def test_connect_and_basic_commands(self):
"""Test connecting to a remote debugger and sending basic commands."""
self._create_script()
@ -1218,6 +1209,7 @@ class PdbConnectTestCase(unittest.TestCase):
frame=frame,
commands="",
version=pdb._PdbServer.protocol_version(),
signal_raising_thread=True,
)
print("Connected to debugger")
iterations = 50
@ -1233,6 +1225,10 @@ class PdbConnectTestCase(unittest.TestCase):
self._create_script(script=script)
process, client_file = self._connect_and_get_client_file()
# Accept a 2nd connection from the subprocess to tell it about signals
signal_sock, _ = self.server_sock.accept()
self.addCleanup(signal_sock.close)
with kill_on_error(process):
# Skip initial messages until we get to the prompt
self._read_until_prompt(client_file)
@ -1248,7 +1244,7 @@ class PdbConnectTestCase(unittest.TestCase):
break
# Inject a script to interrupt the running process
self._send_interrupt(process.pid)
signal_sock.sendall(signal.SIGINT.to_bytes())
messages = self._read_until_prompt(client_file)
# Verify we got the keyboard interrupt message.
@ -1304,6 +1300,7 @@ class PdbConnectTestCase(unittest.TestCase):
frame=frame,
commands="",
version=fake_version,
signal_raising_thread=False,
)
# This should print if the debugger detaches correctly

View file

@ -0,0 +1,2 @@
When PDB is attached to a remote process, do a better job of intercepting
Ctrl+C and forwarding it to the remote process.