Fix trio compatibility with owner task pattern for cancel scopes

🏠 Remote-Dev: homespace
This commit is contained in:
Michael Dworsky 2025-11-24 16:19:06 +00:00
parent 7a5b413159
commit 7dd64ff237
No known key found for this signature in database
3 changed files with 486 additions and 12 deletions

View file

@ -107,6 +107,11 @@ class Query:
self._closed = False
self._initialization_result: dict[str, Any] | None = None
# Owner task pattern: events for coordinating lifecycle
self._owner_stop_event: anyio.Event | None = None
self._owner_started_event: anyio.Event | None = None
self._outer_tg: anyio.abc.TaskGroup | None = None
async def initialize(self) -> dict[str, Any] | None:
"""Initialize control protocol if in streaming mode.
@ -152,11 +157,43 @@ class Query:
return response
async def start(self) -> None:
    """Start reading messages from transport.

    Uses the owner task pattern: rather than entering the reader task group
    from the caller's task, an outer task group spawns
    ``_task_group_owner``, which creates, owns and exits the inner task
    group itself. Keeping enter and exit of the inner cancel scope in one
    task is required for trio compatibility.

    This method returns only after the owner task has signalled that the
    inner task group exists and the reader has been scheduled. Calling it
    while an inner task group is already active is a no-op.
    """
    if self._tg is not None:
        return

    # Fresh events per start so a previous close() cannot leave a stale
    # "stop" signal behind.
    self._owner_stop_event = anyio.Event()
    self._owner_started_event = anyio.Event()

    # Outer task group spawns the owner task. It is entered here and exited
    # in close(), so both calls must happen in the same task.
    self._outer_tg = anyio.create_task_group()
    await self._outer_tg.__aenter__()
    self._outer_tg.start_soon(self._task_group_owner)

    # Wait for owner to signal it's ready
    await self._owner_started_event.wait()
async def _task_group_owner(self) -> None:
    """Run as the single task that owns the inner task group.

    The inner task group is created, published as ``self._tg`` and torn
    down entirely inside this coroutine, so the task that enters its cancel
    scope is also the one that exits it (required for trio compatibility).

    Lifecycle:
      1. create the task group and schedule ``_read_messages`` in it;
      2. set the "started" event so ``start()`` can return;
      3. block until the "stop" event is set;
      4. cancel the task group's children and leave the ``async with``.

    ``self._tg`` is reset to ``None`` on every exit path.
    """
    try:
        async with anyio.create_task_group() as inner_tg:
            self._tg = inner_tg
            inner_tg.start_soon(self._read_messages)

            ready, stop = self._owner_started_event, self._owner_stop_event
            ready.set()  # type: ignore[union-attr]

            # Park here until close() asks us to shut down.
            await stop.wait()  # type: ignore[union-attr]

            # Cancel child tasks so the task group can exit.
            inner_tg.cancel_scope.cancel()
    finally:
        self._tg = None
async def _read_messages(self) -> None:
"""Read messages from transport and route them."""
@ -550,11 +587,17 @@ class Query:
async def close(self) -> None:
    """Close the query and transport.

    Signals the owner task to stop, waits for the outer task group to
    finish (the owner exits once it sees the stop event), and then closes
    the transport. Safe to call more than once and safe to call when
    ``start()`` was never called.

    The outer task group was entered by ``start()``; it is exited here, so
    both calls are expected to run in the same task.
    """
    self._closed = True

    # Signal owner task to stop
    if self._owner_stop_event is not None:
        self._owner_stop_event.set()

    # Detach the outer task group before awaiting so the attribute is reset
    # even if __aexit__ raises, and a repeated close() never exits it twice.
    outer_tg, self._outer_tg = self._outer_tg, None
    try:
        # Wait for outer task group to finish (owner will exit after stop event)
        if outer_tg is not None:
            with suppress(anyio.get_cancelled_exc_class()):
                await outer_tg.__aexit__(None, None, None)
    finally:
        # Always release the transport, even when task-group teardown failed.
        await self.transport.close()
# Make Query an async iterator

View file

@ -57,6 +57,8 @@ class SubprocessCLITransport(Transport):
self._stdin_stream: TextSendStream | None = None
self._stderr_stream: TextReceiveStream | None = None
self._stderr_task_group: anyio.abc.TaskGroup | None = None
self._stderr_stop_event: anyio.Event | None = None
self._stderr_started_event: anyio.Event | None = None
self._ready = False
self._exit_error: Exception | None = None # Track process exit errors
self._max_buffer_size = (
@ -340,10 +342,13 @@ class SubprocessCLITransport(Transport):
# Setup stderr stream if piped
if should_pipe_stderr and self._process.stderr:
self._stderr_stream = TextReceiveStream(self._process.stderr)
# Start async task to read stderr
# Start async task to read stderr using owner task pattern
self._stderr_stop_event = anyio.Event()
self._stderr_started_event = anyio.Event()
self._stderr_task_group = anyio.create_task_group()
await self._stderr_task_group.__aenter__()
self._stderr_task_group.start_soon(self._handle_stderr)
self._stderr_task_group.start_soon(self._stderr_owner_task)
await self._stderr_started_event.wait()
# Setup stdin for streaming mode
if self._is_streaming and self._process.stdin:
@ -370,6 +375,26 @@ class SubprocessCLITransport(Transport):
self._exit_error = error
raise error from e
async def _stderr_owner_task(self) -> None:
    """Run as the single task that owns the stderr reader task group.

    The task group is created and exited inside this coroutine, so the task
    that enters its cancel scope is also the one that exits it (required
    for trio compatibility). ``_handle_stderr`` is scheduled in the group,
    the "started" event is set, and the coroutine then blocks until the
    "stop" event is set, at which point the reader is cancelled.

    Errors are treated as best-effort: they never propagate to the caller,
    but they are logged at debug level instead of being silently dropped.
    """
    import logging

    try:
        async with anyio.create_task_group() as tg:
            tg.start_soon(self._handle_stderr)
            self._stderr_started_event.set()  # type: ignore[union-attr]

            # Wait until close() signals us to stop
            await self._stderr_stop_event.wait()  # type: ignore[union-attr]

            # Cancel child tasks
            tg.cancel_scope.cancel()
    except Exception:
        # stderr handling must not break the transport, but keep a trace of
        # what went wrong for debugging.
        logging.getLogger(__name__).debug(
            "Ignoring error during stderr task cleanup", exc_info=True
        )
async def _handle_stderr(self) -> None:
"""Handle stderr stream - read and invoke callbacks."""
if not self._stderr_stream:
@ -411,10 +436,13 @@ class SubprocessCLITransport(Transport):
if not self._process:
return
# Close stderr task group if active
# Signal stderr owner task to stop
if self._stderr_stop_event:
self._stderr_stop_event.set()
# Wait for stderr task group to finish (owner will exit after stop event)
if self._stderr_task_group:
with suppress(Exception):
self._stderr_task_group.cancel_scope.cancel()
await self._stderr_task_group.__aexit__(None, None, None)
self._stderr_task_group = None

View file

@ -0,0 +1,403 @@
"""Tests for the owner task pattern used for trio/asyncio compatibility.
The owner task pattern ensures that task groups are properly managed by a single
task, which is required for trio compatibility. These tests verify the pattern
works correctly when connect() and disconnect() are called from the same task.
Note: Cross-task connect/disconnect (calling connect() in one task and
disconnect() in another) is NOT supported due to cancel scope ownership
requirements: the outer task group is still entered in start() and exited in
close() by the calling task. The owner task pattern only ensures that the
INNER task group (which does the actual message reading work) is entered and
exited by a single dedicated task.
"""
import json
from unittest.mock import AsyncMock, Mock, patch
import anyio
import pytest
from claude_agent_sdk import ClaudeSDKClient
from claude_agent_sdk._internal.query import Query
def create_mock_transport(with_init_response: bool = True) -> AsyncMock:
    """Build an ``AsyncMock`` transport suitable for driving ``Query``.

    The mock records everything passed to ``write`` and exposes a
    ``read_messages`` async generator function.

    Args:
        with_init_response: If True, the generator waits briefly, scans the
            recorded writes for an ``initialize`` control request and yields
            a matching success ``control_response`` for the first one found.

    Returns:
        The configured mock transport.
    """
    sent: list[str] = []

    async def record_write(data: str) -> None:
        sent.append(data)

    async def control_protocol_generator():
        if with_init_response:
            # Use anyio.sleep for trio compatibility
            await anyio.sleep(0.01)
            for raw in sent:
                try:
                    payload = json.loads(raw.strip())
                    is_initialize = (
                        payload.get("type") == "control_request"
                        and payload.get("request", {}).get("subtype") == "initialize"
                    )
                    if not is_initialize:
                        continue
                    yield {
                        "type": "control_response",
                        "response": {
                            "request_id": payload.get("request_id"),
                            "subtype": "success",
                            "commands": [],
                            "output_style": "default",
                        },
                    }
                    break
                except (json.JSONDecodeError, KeyError, AttributeError):
                    pass

        # Keep the generator alive briefly (50 ticks of 10 ms)
        for _ in range(50):
            await anyio.sleep(0.01)

    transport = AsyncMock()
    transport.connect = AsyncMock()
    transport.close = AsyncMock()
    transport.end_input = AsyncMock()
    transport.write = AsyncMock(side_effect=record_write)
    transport.is_ready = Mock(return_value=True)
    transport.read_messages = control_protocol_generator
    return transport
class TestQueryOwnerTaskPattern:
    """Lifecycle checks for the owner task pattern on ``Query``."""

    def test_query_start_creates_owner_task(self):
        """start() populates the events and both task groups."""

        async def scenario():
            transport = create_mock_transport()
            query = Query(transport=transport, is_streaming_mode=True)

            await query.start()

            # Owner task infrastructure must exist after start()
            assert query._owner_started_event is not None
            assert query._owner_stop_event is not None
            assert query._outer_tg is not None
            assert query._tg is not None

            # The owner task has reported that it is running
            assert query._owner_started_event.is_set()

            # Clean up
            await query.close()

        anyio.run(scenario)

    def test_query_close_signals_stop_event(self):
        """close() sets the stop event and the inner task group is cleared."""

        async def scenario():
            transport = create_mock_transport()
            query = Query(transport=transport, is_streaming_mode=True)

            await query.start()
            await query.initialize()

            # Keep a handle on the event so it can be inspected after close()
            stop_event = query._owner_stop_event

            await query.close()

            # The owner was told to stop
            assert stop_event is not None
            assert stop_event.is_set()

            # The owner reset the inner task group on its way out
            assert query._tg is None

        anyio.run(scenario)

    def test_query_double_close_is_safe(self):
        """A second close() after a successful close() does not raise."""

        async def scenario():
            transport = create_mock_transport()
            query = Query(transport=transport, is_streaming_mode=True)

            await query.start()
            await query.initialize()

            await query.close()
            # Repeating the call must be harmless
            await query.close()

        anyio.run(scenario)

    def test_query_close_without_start(self):
        """close() does not raise when start() was never called."""

        async def scenario():
            transport = create_mock_transport()
            query = Query(transport=transport, is_streaming_mode=True)

            await query.close()

        anyio.run(scenario)
class TestClientOwnerTaskPattern:
    """Owner task pattern as seen through ``ClaudeSDKClient``."""

    def test_client_context_manager_lifecycle(self):
        """``async with`` starts the owner task and closes the transport on exit."""

        async def scenario():
            transport = create_mock_transport()
            with patch(
                "claude_agent_sdk._internal.transport.subprocess_cli.SubprocessCLITransport",
                return_value=transport,
            ):
                async with ClaudeSDKClient() as client:
                    # The query's owner task must be up while connected
                    assert client._query is not None
                    assert client._query._tg is not None
                    assert client._query._owner_started_event.is_set()

                # Leaving the context closes the transport
                transport.close.assert_called()

        anyio.run(scenario)

    def test_client_manual_connect_disconnect(self):
        """Explicit connect()/disconnect() starts the owner task and clears the query."""

        async def scenario():
            transport = create_mock_transport()
            with patch(
                "claude_agent_sdk._internal.transport.subprocess_cli.SubprocessCLITransport",
                return_value=transport,
            ):
                client = ClaudeSDKClient()
                await client.connect()

                # Owner task is running after connect()
                assert client._query is not None
                assert client._query._owner_started_event.is_set()

                await client.disconnect()

                # disconnect() drops the query
                assert client._query is None

        anyio.run(scenario)

    def test_client_double_disconnect_is_safe(self):
        """Calling disconnect() twice in a row does not raise."""

        async def scenario():
            transport = create_mock_transport()
            with patch(
                "claude_agent_sdk._internal.transport.subprocess_cli.SubprocessCLITransport",
                return_value=transport,
            ):
                client = ClaudeSDKClient()
                await client.connect()

                await client.disconnect()
                await client.disconnect()  # Should not raise

        anyio.run(scenario)
class TestConcurrentOperations:
    """Concurrent use of a connected client from several tasks.

    Note: connect() and disconnect() must be called from the same task due
    to cancel scope ownership requirements. Queries and other operations,
    however, may run concurrently while the client is connected.
    """

    def test_query_operations_across_tasks(self):
        """A query issued from a child task is written to the transport."""

        async def scenario():
            transport = create_mock_transport()
            with patch(
                "claude_agent_sdk._internal.transport.subprocess_cli.SubprocessCLITransport",
                return_value=transport,
            ):
                client = ClaudeSDKClient()
                await client.connect()

                done = anyio.Event()

                async def query_in_different_task():
                    await client.query("Hello from another task")
                    done.set()

                async with anyio.create_task_group() as tg:
                    tg.start_soon(query_in_different_task)
                    with anyio.fail_after(5):
                        await done.wait()

                # Look for the user message among everything written
                user_msg_found = False
                for recorded in transport.write.call_args_list:
                    raw = recorded.args[0]
                    try:
                        parsed = json.loads(raw.strip())
                        if parsed.get("type") == "user":
                            assert "Hello from another task" in str(parsed)
                            user_msg_found = True
                            break
                    except (json.JSONDecodeError, KeyError):
                        pass

                assert user_msg_found

                await client.disconnect()

        anyio.run(scenario)

    def test_context_manager_with_concurrent_operations(self):
        """Two concurrent queries inside ``async with`` still end with a closed transport."""

        async def scenario():
            transport = create_mock_transport()
            with patch(
                "claude_agent_sdk._internal.transport.subprocess_cli.SubprocessCLITransport",
                return_value=transport,
            ):
                async with ClaudeSDKClient() as client:

                    async def send_query(msg: str):
                        await client.query(msg)

                    # Fire both queries at once
                    async with anyio.create_task_group() as tg:
                        tg.start_soon(send_query, "Query 1")
                        tg.start_soon(send_query, "Query 2")

                # Leaving the context closes the transport
                transport.close.assert_called()

        anyio.run(scenario)
class TestTrioBackend:
    """Owner task pattern exercised on the trio backend.

    Every scenario here is executed with ``anyio.run(..., backend="trio")``
    so that the implementation is checked against trio's stricter cancel
    scope rules as well as asyncio's.
    """

    def test_client_with_trio_backend(self):
        """Client context manager connects, queries and closes under trio."""

        async def scenario():
            transport = create_mock_transport()
            with patch(
                "claude_agent_sdk._internal.transport.subprocess_cli.SubprocessCLITransport",
                return_value=transport,
            ):
                async with ClaudeSDKClient() as client:
                    assert client._query is not None
                    await client.query("test")

                transport.close.assert_called()

        anyio.run(scenario, backend="trio")

    def test_query_lifecycle_with_trio_backend(self):
        """Query start()/close() sets up and tears down the inner task group under trio."""

        async def scenario():
            transport = create_mock_transport()
            query = Query(transport=transport, is_streaming_mode=True)

            await query.start()
            assert query._tg is not None
            assert query._owner_started_event.is_set()

            await query.close()
            assert query._tg is None

        anyio.run(scenario, backend="trio")

    def test_manual_connect_disconnect_with_trio_backend(self):
        """Explicit connect()/query()/disconnect() works under trio."""

        async def scenario():
            transport = create_mock_transport()
            with patch(
                "claude_agent_sdk._internal.transport.subprocess_cli.SubprocessCLITransport",
                return_value=transport,
            ):
                client = ClaudeSDKClient()
                await client.connect()

                assert client._query is not None
                await client.query("test message")

                await client.disconnect()
                assert client._query is None

        anyio.run(scenario, backend="trio")

    def test_concurrent_queries_with_trio_backend(self):
        """Two queries issued from sibling tasks complete under trio."""

        async def scenario():
            transport = create_mock_transport()
            with patch(
                "claude_agent_sdk._internal.transport.subprocess_cli.SubprocessCLITransport",
                return_value=transport,
            ):
                async with ClaudeSDKClient() as client:

                    async def send_query(msg: str):
                        await client.query(msg)

                    async with anyio.create_task_group() as tg:
                        tg.start_soon(send_query, "Query A")
                        tg.start_soon(send_query, "Query B")

        anyio.run(scenario, backend="trio")