mirror of
https://github.com/anthropics/claude-code-sdk-python.git
synced 2025-12-23 09:19:52 +00:00
Close stdin for query()
This commit is contained in:
parent
489677d614
commit
eeb0be9955
3 changed files with 55 additions and 14 deletions
|
|
@ -19,7 +19,11 @@ class InternalClient:
|
|||
) -> AsyncIterator[Message]:
|
||||
"""Process a query through transport."""
|
||||
|
||||
transport = SubprocessCLITransport(prompt=prompt, options=options)
|
||||
transport = SubprocessCLITransport(
|
||||
prompt=prompt,
|
||||
options=options,
|
||||
close_stdin_after_prompt=True
|
||||
)
|
||||
|
||||
try:
|
||||
await transport.connect()
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ class SubprocessCLITransport(Transport):
|
|||
prompt: str | AsyncIterable[dict[str, Any]],
|
||||
options: ClaudeCodeOptions,
|
||||
cli_path: str | Path | None = None,
|
||||
close_stdin_after_prompt: bool = False,
|
||||
):
|
||||
self._prompt = prompt
|
||||
self._is_streaming = not isinstance(prompt, str)
|
||||
|
|
@ -43,6 +44,7 @@ class SubprocessCLITransport(Transport):
|
|||
self._stdin_stream: TextSendStream | None = None
|
||||
self._pending_control_responses: dict[str, Any] = {}
|
||||
self._request_counter = 0
|
||||
self._close_stdin_after_prompt = close_stdin_after_prompt
|
||||
|
||||
def _find_cli(self) -> str:
|
||||
"""Find Claude Code CLI binary."""
|
||||
|
|
@ -228,8 +230,11 @@ class SubprocessCLITransport(Transport):
|
|||
break
|
||||
await self._stdin_stream.send(json.dumps(message) + "\n")
|
||||
|
||||
# Don't close stdin - keep it open for send_request
|
||||
# Users can explicitly call disconnect() when done
|
||||
# Close stdin after prompt if requested (e.g., for query() one-shot mode)
|
||||
if self._close_stdin_after_prompt and self._stdin_stream:
|
||||
await self._stdin_stream.aclose()
|
||||
self._stdin_stream = None
|
||||
# Otherwise keep stdin open for send_request (ClaudeSDKClient interactive mode)
|
||||
except Exception as e:
|
||||
logger.debug(f"Error streaming to stdin: {e}")
|
||||
if self._stdin_stream:
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ class TestClaudeSDKClientStreaming:
|
|||
|
||||
def test_auto_connect_with_context_manager(self):
|
||||
"""Test automatic connection when using context manager."""
|
||||
|
||||
async def _test():
|
||||
with patch(
|
||||
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
|
||||
|
|
@ -43,6 +44,7 @@ class TestClaudeSDKClientStreaming:
|
|||
|
||||
def test_manual_connect_disconnect(self):
|
||||
"""Test manual connect and disconnect."""
|
||||
|
||||
async def _test():
|
||||
with patch(
|
||||
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
|
||||
|
|
@ -66,6 +68,7 @@ class TestClaudeSDKClientStreaming:
|
|||
|
||||
def test_connect_with_string_prompt(self):
|
||||
"""Test connecting with a string prompt."""
|
||||
|
||||
async def _test():
|
||||
with patch(
|
||||
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
|
||||
|
|
@ -84,6 +87,7 @@ class TestClaudeSDKClientStreaming:
|
|||
|
||||
def test_connect_with_async_iterable(self):
|
||||
"""Test connecting with an async iterable."""
|
||||
|
||||
async def _test():
|
||||
with patch(
|
||||
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
|
||||
|
|
@ -93,7 +97,10 @@ class TestClaudeSDKClientStreaming:
|
|||
|
||||
async def message_stream():
|
||||
yield {"type": "user", "message": {"role": "user", "content": "Hi"}}
|
||||
yield {"type": "user", "message": {"role": "user", "content": "Bye"}}
|
||||
yield {
|
||||
"type": "user",
|
||||
"message": {"role": "user", "content": "Bye"},
|
||||
}
|
||||
|
||||
client = ClaudeSDKClient()
|
||||
stream = message_stream()
|
||||
|
|
@ -108,6 +115,7 @@ class TestClaudeSDKClientStreaming:
|
|||
|
||||
def test_send_message(self):
|
||||
"""Test sending a message."""
|
||||
|
||||
async def _test():
|
||||
with patch(
|
||||
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
|
||||
|
|
@ -131,6 +139,7 @@ class TestClaudeSDKClientStreaming:
|
|||
|
||||
def test_send_message_with_session_id(self):
|
||||
"""Test sending a message with custom session ID."""
|
||||
|
||||
async def _test():
|
||||
with patch(
|
||||
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
|
||||
|
|
@ -150,6 +159,7 @@ class TestClaudeSDKClientStreaming:
|
|||
|
||||
def test_send_message_not_connected(self):
|
||||
"""Test sending message when not connected raises error."""
|
||||
|
||||
async def _test():
|
||||
client = ClaudeSDKClient()
|
||||
with pytest.raises(CLIConnectionError, match="Not connected"):
|
||||
|
|
@ -159,6 +169,7 @@ class TestClaudeSDKClientStreaming:
|
|||
|
||||
def test_receive_messages(self):
|
||||
"""Test receiving messages."""
|
||||
|
||||
async def _test():
|
||||
with patch(
|
||||
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
|
||||
|
|
@ -200,6 +211,7 @@ class TestClaudeSDKClientStreaming:
|
|||
|
||||
def test_receive_response(self):
|
||||
"""Test receive_response stops at ResultMessage."""
|
||||
|
||||
async def _test():
|
||||
with patch(
|
||||
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
|
||||
|
|
@ -231,7 +243,9 @@ class TestClaudeSDKClientStreaming:
|
|||
"type": "assistant",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "Should not see this"}],
|
||||
"content": [
|
||||
{"type": "text", "text": "Should not see this"}
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -251,6 +265,7 @@ class TestClaudeSDKClientStreaming:
|
|||
|
||||
def test_interrupt(self):
|
||||
"""Test interrupt functionality."""
|
||||
|
||||
async def _test():
|
||||
with patch(
|
||||
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
|
||||
|
|
@ -266,6 +281,7 @@ class TestClaudeSDKClientStreaming:
|
|||
|
||||
def test_interrupt_not_connected(self):
|
||||
"""Test interrupt when not connected raises error."""
|
||||
|
||||
async def _test():
|
||||
client = ClaudeSDKClient()
|
||||
with pytest.raises(CLIConnectionError, match="Not connected"):
|
||||
|
|
@ -275,6 +291,7 @@ class TestClaudeSDKClientStreaming:
|
|||
|
||||
def test_client_with_options(self):
|
||||
"""Test client initialization with options."""
|
||||
|
||||
async def _test():
|
||||
options = ClaudeCodeOptions(
|
||||
cwd="/custom/path",
|
||||
|
|
@ -299,6 +316,7 @@ class TestClaudeSDKClientStreaming:
|
|||
|
||||
def test_concurrent_send_receive(self):
|
||||
"""Test concurrent sending and receiving messages."""
|
||||
|
||||
async def _test():
|
||||
with patch(
|
||||
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
|
||||
|
|
@ -334,7 +352,7 @@ class TestClaudeSDKClientStreaming:
|
|||
# Helper to get next message
|
||||
async def get_next_message():
|
||||
return await client.receive_response().__anext__()
|
||||
|
||||
|
||||
# Start receiving in background
|
||||
receive_task = asyncio.create_task(get_next_message())
|
||||
|
||||
|
|
@ -353,13 +371,14 @@ class TestQueryWithAsyncIterable:
|
|||
|
||||
def test_query_with_async_iterable(self):
|
||||
"""Test query with async iterable of messages."""
|
||||
|
||||
async def _test():
|
||||
async def message_stream():
|
||||
yield {"type": "user", "message": {"role": "user", "content": "First"}}
|
||||
yield {"type": "user", "message": {"role": "user", "content": "Second"}}
|
||||
|
||||
with patch(
|
||||
"claude_code_sdk.query.InternalClient"
|
||||
"claude_code_sdk._internal.client.InternalClient"
|
||||
) as mock_client_class:
|
||||
mock_client = MagicMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
|
@ -399,6 +418,7 @@ class TestQueryWithAsyncIterable:
|
|||
|
||||
def test_query_async_iterable_with_options(self):
|
||||
"""Test query with async iterable and custom options."""
|
||||
|
||||
async def _test():
|
||||
async def complex_stream():
|
||||
yield {
|
||||
|
|
@ -421,7 +441,7 @@ class TestQueryWithAsyncIterable:
|
|||
)
|
||||
|
||||
with patch(
|
||||
"claude_code_sdk.query.InternalClient"
|
||||
"claude_code_sdk._internal.client.InternalClient"
|
||||
) as mock_client_class:
|
||||
mock_client = MagicMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
|
@ -445,6 +465,7 @@ class TestQueryWithAsyncIterable:
|
|||
|
||||
def test_query_empty_async_iterable(self):
|
||||
"""Test query with empty async iterable."""
|
||||
|
||||
async def _test():
|
||||
async def empty_stream():
|
||||
# Never yields anything
|
||||
|
|
@ -452,7 +473,7 @@ class TestQueryWithAsyncIterable:
|
|||
yield
|
||||
|
||||
with patch(
|
||||
"claude_code_sdk.query.InternalClient"
|
||||
"claude_code_sdk._internal.client.InternalClient"
|
||||
) as mock_client_class:
|
||||
mock_client = MagicMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
|
@ -460,8 +481,7 @@ class TestQueryWithAsyncIterable:
|
|||
# Mock response
|
||||
async def mock_process():
|
||||
yield SystemMessage(
|
||||
subtype="info",
|
||||
data={"message": "No input provided"}
|
||||
subtype="info", data={"message": "No input provided"}
|
||||
)
|
||||
|
||||
mock_client.process_query.return_value = mock_process()
|
||||
|
|
@ -478,6 +498,7 @@ class TestQueryWithAsyncIterable:
|
|||
|
||||
def test_query_async_iterable_with_delay(self):
|
||||
"""Test query with async iterable that has delays between yields."""
|
||||
|
||||
async def _test():
|
||||
async def delayed_stream():
|
||||
yield {"type": "user", "message": {"role": "user", "content": "Start"}}
|
||||
|
|
@ -487,7 +508,7 @@ class TestQueryWithAsyncIterable:
|
|||
yield {"type": "user", "message": {"role": "user", "content": "End"}}
|
||||
|
||||
with patch(
|
||||
"claude_code_sdk.query.InternalClient"
|
||||
"claude_code_sdk._internal.client.InternalClient"
|
||||
) as mock_client_class:
|
||||
mock_client = MagicMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
|
@ -512,6 +533,7 @@ class TestQueryWithAsyncIterable:
|
|||
|
||||
# Time the execution
|
||||
import time
|
||||
|
||||
start_time = time.time()
|
||||
messages = []
|
||||
async for msg in query(prompt=delayed_stream()):
|
||||
|
|
@ -527,13 +549,14 @@ class TestQueryWithAsyncIterable:
|
|||
|
||||
def test_query_async_iterable_exception_handling(self):
|
||||
"""Test query handles exceptions in async iterable."""
|
||||
|
||||
async def _test():
|
||||
async def failing_stream():
|
||||
yield {"type": "user", "message": {"role": "user", "content": "First"}}
|
||||
raise ValueError("Stream error")
|
||||
|
||||
with patch(
|
||||
"claude_code_sdk.query.InternalClient"
|
||||
"claude_code_sdk._internal.client.InternalClient"
|
||||
) as mock_client_class:
|
||||
mock_client = MagicMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
|
@ -561,6 +584,7 @@ class TestClaudeSDKClientEdgeCases:
|
|||
|
||||
def test_receive_messages_not_connected(self):
|
||||
"""Test receiving messages when not connected."""
|
||||
|
||||
async def _test():
|
||||
client = ClaudeSDKClient()
|
||||
with pytest.raises(CLIConnectionError, match="Not connected"):
|
||||
|
|
@ -571,6 +595,7 @@ class TestClaudeSDKClientEdgeCases:
|
|||
|
||||
def test_receive_response_not_connected(self):
|
||||
"""Test receive_response when not connected."""
|
||||
|
||||
async def _test():
|
||||
client = ClaudeSDKClient()
|
||||
with pytest.raises(CLIConnectionError, match="Not connected"):
|
||||
|
|
@ -581,6 +606,7 @@ class TestClaudeSDKClientEdgeCases:
|
|||
|
||||
def test_double_connect(self):
|
||||
"""Test connecting twice."""
|
||||
|
||||
async def _test():
|
||||
with patch(
|
||||
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
|
||||
|
|
@ -600,6 +626,7 @@ class TestClaudeSDKClientEdgeCases:
|
|||
|
||||
def test_disconnect_without_connect(self):
|
||||
"""Test disconnecting without connecting first."""
|
||||
|
||||
async def _test():
|
||||
client = ClaudeSDKClient()
|
||||
# Should not raise error
|
||||
|
|
@ -609,6 +636,7 @@ class TestClaudeSDKClientEdgeCases:
|
|||
|
||||
def test_context_manager_with_exception(self):
|
||||
"""Test context manager cleans up on exception."""
|
||||
|
||||
async def _test():
|
||||
with patch(
|
||||
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
|
||||
|
|
@ -627,6 +655,7 @@ class TestClaudeSDKClientEdgeCases:
|
|||
|
||||
def test_receive_response_list_comprehension(self):
|
||||
"""Test collecting messages with list comprehension as shown in examples."""
|
||||
|
||||
async def _test():
|
||||
with patch(
|
||||
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
|
||||
|
|
@ -668,7 +697,10 @@ class TestClaudeSDKClientEdgeCases:
|
|||
messages = [msg async for msg in client.receive_response()]
|
||||
|
||||
assert len(messages) == 3
|
||||
assert all(isinstance(msg, AssistantMessage | ResultMessage) for msg in messages)
|
||||
assert all(
|
||||
isinstance(msg, AssistantMessage | ResultMessage)
|
||||
for msg in messages
|
||||
)
|
||||
assert isinstance(messages[-1], ResultMessage)
|
||||
|
||||
anyio.run(_test)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue