mirror of
https://github.com/anthropics/claude-code-sdk-python.git
synced 2025-12-23 09:19:52 +00:00
Fix trio compatibility with owner task pattern for cancel scopes
🏠 Remote-Dev: homespace
This commit is contained in:
parent
7a5b413159
commit
7dd64ff237
3 changed files with 486 additions and 12 deletions
|
|
@ -107,6 +107,11 @@ class Query:
|
|||
self._closed = False
|
||||
self._initialization_result: dict[str, Any] | None = None
|
||||
|
||||
# Owner task pattern: events for coordinating lifecycle
|
||||
self._owner_stop_event: anyio.Event | None = None
|
||||
self._owner_started_event: anyio.Event | None = None
|
||||
self._outer_tg: anyio.abc.TaskGroup | None = None
|
||||
|
||||
async def initialize(self) -> dict[str, Any] | None:
|
||||
"""Initialize control protocol if in streaming mode.
|
||||
|
||||
|
|
@ -152,11 +157,43 @@ class Query:
|
|||
return response
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start reading messages from transport."""
|
||||
"""Start reading messages from transport.
|
||||
|
||||
Uses the owner task pattern to ensure the inner task group is properly
|
||||
managed by a single task, which is required for trio compatibility.
|
||||
"""
|
||||
if self._tg is None:
|
||||
self._tg = anyio.create_task_group()
|
||||
await self._tg.__aenter__()
|
||||
self._tg.start_soon(self._read_messages)
|
||||
self._owner_stop_event = anyio.Event()
|
||||
self._owner_started_event = anyio.Event()
|
||||
|
||||
# Outer task group spawns the owner task
|
||||
self._outer_tg = anyio.create_task_group()
|
||||
await self._outer_tg.__aenter__()
|
||||
self._outer_tg.start_soon(self._task_group_owner)
|
||||
|
||||
# Wait for owner to signal it's ready
|
||||
await self._owner_started_event.wait()
|
||||
|
||||
async def _task_group_owner(self) -> None:
|
||||
"""Owner task that manages the inner task group.
|
||||
|
||||
This task owns the task group for its entire lifetime, ensuring that
|
||||
the same task that enters the cancel scope also exits it. This is
|
||||
required for trio compatibility.
|
||||
"""
|
||||
try:
|
||||
async with anyio.create_task_group() as tg:
|
||||
self._tg = tg
|
||||
tg.start_soon(self._read_messages)
|
||||
self._owner_started_event.set() # type: ignore[union-attr]
|
||||
|
||||
# Wait until close() signals us to stop
|
||||
await self._owner_stop_event.wait() # type: ignore[union-attr]
|
||||
|
||||
# Cancel child tasks
|
||||
tg.cancel_scope.cancel()
|
||||
finally:
|
||||
self._tg = None
|
||||
|
||||
async def _read_messages(self) -> None:
|
||||
"""Read messages from transport and route them."""
|
||||
|
|
@ -550,11 +587,17 @@ class Query:
|
|||
async def close(self) -> None:
|
||||
"""Close the query and transport."""
|
||||
self._closed = True
|
||||
if self._tg:
|
||||
self._tg.cancel_scope.cancel()
|
||||
# Wait for task group to complete cancellation
|
||||
|
||||
# Signal owner task to stop
|
||||
if self._owner_stop_event:
|
||||
self._owner_stop_event.set()
|
||||
|
||||
# Wait for outer task group to finish (owner will exit after stop event)
|
||||
if self._outer_tg:
|
||||
with suppress(anyio.get_cancelled_exc_class()):
|
||||
await self._tg.__aexit__(None, None, None)
|
||||
await self._outer_tg.__aexit__(None, None, None)
|
||||
self._outer_tg = None
|
||||
|
||||
await self.transport.close()
|
||||
|
||||
# Make Query an async iterator
|
||||
|
|
|
|||
|
|
@ -57,6 +57,8 @@ class SubprocessCLITransport(Transport):
|
|||
self._stdin_stream: TextSendStream | None = None
|
||||
self._stderr_stream: TextReceiveStream | None = None
|
||||
self._stderr_task_group: anyio.abc.TaskGroup | None = None
|
||||
self._stderr_stop_event: anyio.Event | None = None
|
||||
self._stderr_started_event: anyio.Event | None = None
|
||||
self._ready = False
|
||||
self._exit_error: Exception | None = None # Track process exit errors
|
||||
self._max_buffer_size = (
|
||||
|
|
@ -340,10 +342,13 @@ class SubprocessCLITransport(Transport):
|
|||
# Setup stderr stream if piped
|
||||
if should_pipe_stderr and self._process.stderr:
|
||||
self._stderr_stream = TextReceiveStream(self._process.stderr)
|
||||
# Start async task to read stderr
|
||||
# Start async task to read stderr using owner task pattern
|
||||
self._stderr_stop_event = anyio.Event()
|
||||
self._stderr_started_event = anyio.Event()
|
||||
self._stderr_task_group = anyio.create_task_group()
|
||||
await self._stderr_task_group.__aenter__()
|
||||
self._stderr_task_group.start_soon(self._handle_stderr)
|
||||
self._stderr_task_group.start_soon(self._stderr_owner_task)
|
||||
await self._stderr_started_event.wait()
|
||||
|
||||
# Setup stdin for streaming mode
|
||||
if self._is_streaming and self._process.stdin:
|
||||
|
|
@ -370,6 +375,26 @@ class SubprocessCLITransport(Transport):
|
|||
self._exit_error = error
|
||||
raise error from e
|
||||
|
||||
async def _stderr_owner_task(self) -> None:
|
||||
"""Owner task that manages the stderr reader task group.
|
||||
|
||||
This task owns the task group for its entire lifetime, ensuring that
|
||||
the same task that enters the cancel scope also exits it. This is
|
||||
required for trio compatibility.
|
||||
"""
|
||||
try:
|
||||
async with anyio.create_task_group() as tg:
|
||||
tg.start_soon(self._handle_stderr)
|
||||
self._stderr_started_event.set() # type: ignore[union-attr]
|
||||
|
||||
# Wait until close() signals us to stop
|
||||
await self._stderr_stop_event.wait() # type: ignore[union-attr]
|
||||
|
||||
# Cancel child tasks
|
||||
tg.cancel_scope.cancel()
|
||||
except Exception:
|
||||
pass # Ignore errors during stderr task cleanup
|
||||
|
||||
async def _handle_stderr(self) -> None:
|
||||
"""Handle stderr stream - read and invoke callbacks."""
|
||||
if not self._stderr_stream:
|
||||
|
|
@ -411,10 +436,13 @@ class SubprocessCLITransport(Transport):
|
|||
if not self._process:
|
||||
return
|
||||
|
||||
# Close stderr task group if active
|
||||
# Signal stderr owner task to stop
|
||||
if self._stderr_stop_event:
|
||||
self._stderr_stop_event.set()
|
||||
|
||||
# Wait for stderr task group to finish (owner will exit after stop event)
|
||||
if self._stderr_task_group:
|
||||
with suppress(Exception):
|
||||
self._stderr_task_group.cancel_scope.cancel()
|
||||
await self._stderr_task_group.__aexit__(None, None, None)
|
||||
self._stderr_task_group = None
|
||||
|
||||
|
|
|
|||
403
tests/test_owner_task_pattern.py
Normal file
403
tests/test_owner_task_pattern.py
Normal file
|
|
@ -0,0 +1,403 @@
|
|||
"""Tests for the owner task pattern used for trio/asyncio compatibility.
|
||||
|
||||
The owner task pattern ensures that task groups are properly managed by a single
|
||||
task, which is required for trio compatibility. These tests verify the pattern
|
||||
works correctly when connect() and disconnect() are called from the same task.
|
||||
|
||||
Note: Cross-task connect/disconnect (calling connect() in one task and
|
||||
disconnect() in another) is NOT supported due to cancel scope ownership
|
||||
requirements. The owner task pattern ensures the INNER task group (which does
|
||||
the actual message reading work) is properly managed.
|
||||
"""
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import anyio
|
||||
import pytest
|
||||
|
||||
from claude_agent_sdk import ClaudeSDKClient
|
||||
from claude_agent_sdk._internal.query import Query
|
||||
|
||||
|
||||
def create_mock_transport(with_init_response: bool = True) -> AsyncMock:
|
||||
"""Create a properly configured mock transport.
|
||||
|
||||
Args:
|
||||
with_init_response: If True, automatically respond to initialization request
|
||||
"""
|
||||
mock_transport = AsyncMock()
|
||||
mock_transport.connect = AsyncMock()
|
||||
mock_transport.close = AsyncMock()
|
||||
mock_transport.end_input = AsyncMock()
|
||||
mock_transport.write = AsyncMock()
|
||||
mock_transport.is_ready = Mock(return_value=True)
|
||||
|
||||
written_messages: list[str] = []
|
||||
|
||||
async def mock_write(data: str) -> None:
|
||||
written_messages.append(data)
|
||||
|
||||
mock_transport.write.side_effect = mock_write
|
||||
|
||||
async def control_protocol_generator():
|
||||
if with_init_response:
|
||||
# Use anyio.sleep for trio compatibility
|
||||
await anyio.sleep(0.01)
|
||||
|
||||
for msg_str in written_messages:
|
||||
try:
|
||||
msg = json.loads(msg_str.strip())
|
||||
if (
|
||||
msg.get("type") == "control_request"
|
||||
and msg.get("request", {}).get("subtype") == "initialize"
|
||||
):
|
||||
yield {
|
||||
"type": "control_response",
|
||||
"response": {
|
||||
"request_id": msg.get("request_id"),
|
||||
"subtype": "success",
|
||||
"commands": [],
|
||||
"output_style": "default",
|
||||
},
|
||||
}
|
||||
break
|
||||
except (json.JSONDecodeError, KeyError, AttributeError):
|
||||
pass
|
||||
|
||||
# Keep the generator alive briefly
|
||||
timeout_counter = 0
|
||||
while timeout_counter < 50:
|
||||
await anyio.sleep(0.01)
|
||||
timeout_counter += 1
|
||||
|
||||
mock_transport.read_messages = control_protocol_generator
|
||||
return mock_transport
|
||||
|
||||
|
||||
class TestQueryOwnerTaskPattern:
|
||||
"""Test Query class owner task pattern lifecycle."""
|
||||
|
||||
def test_query_start_creates_owner_task(self):
|
||||
"""Verify start() creates owner task and sets events."""
|
||||
|
||||
async def _test():
|
||||
mock_transport = create_mock_transport()
|
||||
|
||||
query = Query(
|
||||
transport=mock_transport,
|
||||
is_streaming_mode=True,
|
||||
)
|
||||
|
||||
await query.start()
|
||||
|
||||
# Verify owner task infrastructure is set up
|
||||
assert query._owner_started_event is not None
|
||||
assert query._owner_stop_event is not None
|
||||
assert query._outer_tg is not None
|
||||
assert query._tg is not None
|
||||
|
||||
# Verify started event is set (owner task is running)
|
||||
assert query._owner_started_event.is_set()
|
||||
|
||||
# Clean up
|
||||
await query.close()
|
||||
|
||||
anyio.run(_test)
|
||||
|
||||
def test_query_close_signals_stop_event(self):
|
||||
"""Verify close() signals the owner task to stop."""
|
||||
|
||||
async def _test():
|
||||
mock_transport = create_mock_transport()
|
||||
|
||||
query = Query(
|
||||
transport=mock_transport,
|
||||
is_streaming_mode=True,
|
||||
)
|
||||
|
||||
await query.start()
|
||||
await query.initialize()
|
||||
|
||||
# Store reference to stop event before close
|
||||
stop_event = query._owner_stop_event
|
||||
|
||||
await query.close()
|
||||
|
||||
# Verify stop event was set
|
||||
assert stop_event is not None
|
||||
assert stop_event.is_set()
|
||||
|
||||
# Verify task group is cleaned up
|
||||
assert query._tg is None
|
||||
|
||||
anyio.run(_test)
|
||||
|
||||
def test_query_double_close_is_safe(self):
|
||||
"""Verify calling close() twice doesn't error."""
|
||||
|
||||
async def _test():
|
||||
mock_transport = create_mock_transport()
|
||||
|
||||
query = Query(
|
||||
transport=mock_transport,
|
||||
is_streaming_mode=True,
|
||||
)
|
||||
|
||||
await query.start()
|
||||
await query.initialize()
|
||||
|
||||
# First close
|
||||
await query.close()
|
||||
|
||||
# Second close should not raise
|
||||
await query.close()
|
||||
|
||||
anyio.run(_test)
|
||||
|
||||
def test_query_close_without_start(self):
|
||||
"""Verify close() works even if start() was never called."""
|
||||
|
||||
async def _test():
|
||||
mock_transport = create_mock_transport()
|
||||
|
||||
query = Query(
|
||||
transport=mock_transport,
|
||||
is_streaming_mode=True,
|
||||
)
|
||||
|
||||
# Close without start should not raise
|
||||
await query.close()
|
||||
|
||||
anyio.run(_test)
|
||||
|
||||
|
||||
class TestClientOwnerTaskPattern:
|
||||
"""Test ClaudeSDKClient with owner task pattern."""
|
||||
|
||||
def test_client_context_manager_lifecycle(self):
|
||||
"""Test that context manager properly manages owner task lifecycle."""
|
||||
|
||||
async def _test():
|
||||
with patch(
|
||||
"claude_agent_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
|
||||
) as mock_transport_class:
|
||||
mock_transport = create_mock_transport()
|
||||
mock_transport_class.return_value = mock_transport
|
||||
|
||||
async with ClaudeSDKClient() as client:
|
||||
# Verify query's owner task is running
|
||||
assert client._query is not None
|
||||
assert client._query._tg is not None
|
||||
assert client._query._owner_started_event.is_set()
|
||||
|
||||
# After exit, transport should be closed
|
||||
mock_transport.close.assert_called()
|
||||
|
||||
anyio.run(_test)
|
||||
|
||||
def test_client_manual_connect_disconnect(self):
|
||||
"""Test manual connect/disconnect with owner task pattern."""
|
||||
|
||||
async def _test():
|
||||
with patch(
|
||||
"claude_agent_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
|
||||
) as mock_transport_class:
|
||||
mock_transport = create_mock_transport()
|
||||
mock_transport_class.return_value = mock_transport
|
||||
|
||||
client = ClaudeSDKClient()
|
||||
|
||||
await client.connect()
|
||||
|
||||
# Verify owner task is running
|
||||
assert client._query is not None
|
||||
assert client._query._owner_started_event.is_set()
|
||||
|
||||
await client.disconnect()
|
||||
|
||||
# Verify cleanup
|
||||
assert client._query is None
|
||||
|
||||
anyio.run(_test)
|
||||
|
||||
def test_client_double_disconnect_is_safe(self):
|
||||
"""Test that disconnecting twice doesn't error."""
|
||||
|
||||
async def _test():
|
||||
with patch(
|
||||
"claude_agent_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
|
||||
) as mock_transport_class:
|
||||
mock_transport = create_mock_transport()
|
||||
mock_transport_class.return_value = mock_transport
|
||||
|
||||
client = ClaudeSDKClient()
|
||||
await client.connect()
|
||||
|
||||
await client.disconnect()
|
||||
await client.disconnect() # Should not raise
|
||||
|
||||
anyio.run(_test)
|
||||
|
||||
|
||||
class TestConcurrentOperations:
|
||||
"""Test concurrent operations within the same async context.
|
||||
|
||||
Note: connect() and disconnect() must be called from the same task due to
|
||||
cancel scope ownership requirements. However, query and other operations
|
||||
can be performed concurrently while the client is connected.
|
||||
"""
|
||||
|
||||
def test_query_operations_across_tasks(self):
|
||||
"""Test that query operations work across different tasks."""
|
||||
|
||||
async def _test():
|
||||
with patch(
|
||||
"claude_agent_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
|
||||
) as mock_transport_class:
|
||||
mock_transport = create_mock_transport()
|
||||
mock_transport_class.return_value = mock_transport
|
||||
|
||||
client = ClaudeSDKClient()
|
||||
await client.connect()
|
||||
|
||||
query_completed = anyio.Event()
|
||||
|
||||
async def query_in_different_task():
|
||||
await client.query("Hello from another task")
|
||||
query_completed.set()
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
tg.start_soon(query_in_different_task)
|
||||
with anyio.fail_after(5):
|
||||
await query_completed.wait()
|
||||
|
||||
# Verify query was sent
|
||||
write_calls = mock_transport.write.call_args_list
|
||||
user_msg_found = False
|
||||
for call in write_calls:
|
||||
data = call[0][0]
|
||||
try:
|
||||
msg = json.loads(data.strip())
|
||||
if msg.get("type") == "user":
|
||||
assert "Hello from another task" in str(msg)
|
||||
user_msg_found = True
|
||||
break
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
pass
|
||||
assert user_msg_found
|
||||
|
||||
await client.disconnect()
|
||||
|
||||
anyio.run(_test)
|
||||
|
||||
def test_context_manager_with_concurrent_operations(self):
|
||||
"""Test context manager properly handles concurrent operations."""
|
||||
|
||||
async def _test():
|
||||
with patch(
|
||||
"claude_agent_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
|
||||
) as mock_transport_class:
|
||||
mock_transport = create_mock_transport()
|
||||
mock_transport_class.return_value = mock_transport
|
||||
|
||||
async with ClaudeSDKClient() as client:
|
||||
# Start multiple concurrent queries
|
||||
async def send_query(msg: str):
|
||||
await client.query(msg)
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
tg.start_soon(send_query, "Query 1")
|
||||
tg.start_soon(send_query, "Query 2")
|
||||
|
||||
# Context manager ensures proper cleanup
|
||||
mock_transport.close.assert_called()
|
||||
|
||||
anyio.run(_test)
|
||||
|
||||
|
||||
class TestTrioBackend:
|
||||
"""Tests that verify the owner task pattern works with trio backend.
|
||||
|
||||
These tests run with trio's stricter cancel scope rules to ensure
|
||||
the implementation is compatible with both asyncio and trio.
|
||||
"""
|
||||
|
||||
def test_client_with_trio_backend(self):
|
||||
"""Verify client context manager works with trio backend."""
|
||||
|
||||
async def _test():
|
||||
with patch(
|
||||
"claude_agent_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
|
||||
) as mock_transport_class:
|
||||
mock_transport = create_mock_transport()
|
||||
mock_transport_class.return_value = mock_transport
|
||||
|
||||
async with ClaudeSDKClient() as client:
|
||||
assert client._query is not None
|
||||
await client.query("test")
|
||||
|
||||
mock_transport.close.assert_called()
|
||||
|
||||
anyio.run(_test, backend="trio")
|
||||
|
||||
def test_query_lifecycle_with_trio_backend(self):
|
||||
"""Verify Query lifecycle works with trio backend."""
|
||||
|
||||
async def _test():
|
||||
mock_transport = create_mock_transport()
|
||||
|
||||
query = Query(
|
||||
transport=mock_transport,
|
||||
is_streaming_mode=True,
|
||||
)
|
||||
|
||||
await query.start()
|
||||
assert query._tg is not None
|
||||
assert query._owner_started_event.is_set()
|
||||
|
||||
await query.close()
|
||||
assert query._tg is None
|
||||
|
||||
anyio.run(_test, backend="trio")
|
||||
|
||||
def test_manual_connect_disconnect_with_trio_backend(self):
|
||||
"""Verify manual connect/disconnect works with trio backend."""
|
||||
|
||||
async def _test():
|
||||
with patch(
|
||||
"claude_agent_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
|
||||
) as mock_transport_class:
|
||||
mock_transport = create_mock_transport()
|
||||
mock_transport_class.return_value = mock_transport
|
||||
|
||||
client = ClaudeSDKClient()
|
||||
await client.connect()
|
||||
|
||||
assert client._query is not None
|
||||
await client.query("test message")
|
||||
|
||||
await client.disconnect()
|
||||
assert client._query is None
|
||||
|
||||
anyio.run(_test, backend="trio")
|
||||
|
||||
def test_concurrent_queries_with_trio_backend(self):
|
||||
"""Verify concurrent operations work with trio backend."""
|
||||
|
||||
async def _test():
|
||||
with patch(
|
||||
"claude_agent_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
|
||||
) as mock_transport_class:
|
||||
mock_transport = create_mock_transport()
|
||||
mock_transport_class.return_value = mock_transport
|
||||
|
||||
async with ClaudeSDKClient() as client:
|
||||
async def send_query(msg: str):
|
||||
await client.query(msg)
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
tg.start_soon(send_query, "Query A")
|
||||
tg.start_soon(send_query, "Query B")
|
||||
|
||||
anyio.run(_test, backend="trio")
|
||||
Loading…
Add table
Add a link
Reference in a new issue