Close stdin for query()

This commit is contained in:
Dickson Tsai 2025-07-19 15:01:43 -07:00
parent 489677d614
commit eeb0be9955
No known key found for this signature in database
3 changed files with 55 additions and 14 deletions

View file

@ -19,7 +19,11 @@ class InternalClient:
) -> AsyncIterator[Message]:
"""Process a query through transport."""
transport = SubprocessCLITransport(prompt=prompt, options=options)
transport = SubprocessCLITransport(
prompt=prompt,
options=options,
close_stdin_after_prompt=True
)
try:
await transport.connect()

View file

@ -31,6 +31,7 @@ class SubprocessCLITransport(Transport):
prompt: str | AsyncIterable[dict[str, Any]],
options: ClaudeCodeOptions,
cli_path: str | Path | None = None,
close_stdin_after_prompt: bool = False,
):
self._prompt = prompt
self._is_streaming = not isinstance(prompt, str)
@ -43,6 +44,7 @@ class SubprocessCLITransport(Transport):
self._stdin_stream: TextSendStream | None = None
self._pending_control_responses: dict[str, Any] = {}
self._request_counter = 0
self._close_stdin_after_prompt = close_stdin_after_prompt
def _find_cli(self) -> str:
"""Find Claude Code CLI binary."""
@ -228,8 +230,11 @@ class SubprocessCLITransport(Transport):
break
await self._stdin_stream.send(json.dumps(message) + "\n")
# Don't close stdin - keep it open for send_request
# Users can explicitly call disconnect() when done
# Close stdin after prompt if requested (e.g., for query() one-shot mode)
if self._close_stdin_after_prompt and self._stdin_stream:
await self._stdin_stream.aclose()
self._stdin_stream = None
# Otherwise keep stdin open for send_request (ClaudeSDKClient interactive mode)
except Exception as e:
logger.debug(f"Error streaming to stdin: {e}")
if self._stdin_stream:

View file

@ -24,6 +24,7 @@ class TestClaudeSDKClientStreaming:
def test_auto_connect_with_context_manager(self):
"""Test automatic connection when using context manager."""
async def _test():
with patch(
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
@ -43,6 +44,7 @@ class TestClaudeSDKClientStreaming:
def test_manual_connect_disconnect(self):
"""Test manual connect and disconnect."""
async def _test():
with patch(
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
@ -66,6 +68,7 @@ class TestClaudeSDKClientStreaming:
def test_connect_with_string_prompt(self):
"""Test connecting with a string prompt."""
async def _test():
with patch(
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
@ -84,6 +87,7 @@ class TestClaudeSDKClientStreaming:
def test_connect_with_async_iterable(self):
"""Test connecting with an async iterable."""
async def _test():
with patch(
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
@ -93,7 +97,10 @@ class TestClaudeSDKClientStreaming:
async def message_stream():
yield {"type": "user", "message": {"role": "user", "content": "Hi"}}
yield {"type": "user", "message": {"role": "user", "content": "Bye"}}
yield {
"type": "user",
"message": {"role": "user", "content": "Bye"},
}
client = ClaudeSDKClient()
stream = message_stream()
@ -108,6 +115,7 @@ class TestClaudeSDKClientStreaming:
def test_send_message(self):
"""Test sending a message."""
async def _test():
with patch(
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
@ -131,6 +139,7 @@ class TestClaudeSDKClientStreaming:
def test_send_message_with_session_id(self):
"""Test sending a message with custom session ID."""
async def _test():
with patch(
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
@ -150,6 +159,7 @@ class TestClaudeSDKClientStreaming:
def test_send_message_not_connected(self):
"""Test sending message when not connected raises error."""
async def _test():
client = ClaudeSDKClient()
with pytest.raises(CLIConnectionError, match="Not connected"):
@ -159,6 +169,7 @@ class TestClaudeSDKClientStreaming:
def test_receive_messages(self):
"""Test receiving messages."""
async def _test():
with patch(
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
@ -200,6 +211,7 @@ class TestClaudeSDKClientStreaming:
def test_receive_response(self):
"""Test receive_response stops at ResultMessage."""
async def _test():
with patch(
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
@ -231,7 +243,9 @@ class TestClaudeSDKClientStreaming:
"type": "assistant",
"message": {
"role": "assistant",
"content": [{"type": "text", "text": "Should not see this"}],
"content": [
{"type": "text", "text": "Should not see this"}
],
},
}
@ -251,6 +265,7 @@ class TestClaudeSDKClientStreaming:
def test_interrupt(self):
"""Test interrupt functionality."""
async def _test():
with patch(
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
@ -266,6 +281,7 @@ class TestClaudeSDKClientStreaming:
def test_interrupt_not_connected(self):
"""Test interrupt when not connected raises error."""
async def _test():
client = ClaudeSDKClient()
with pytest.raises(CLIConnectionError, match="Not connected"):
@ -275,6 +291,7 @@ class TestClaudeSDKClientStreaming:
def test_client_with_options(self):
"""Test client initialization with options."""
async def _test():
options = ClaudeCodeOptions(
cwd="/custom/path",
@ -299,6 +316,7 @@ class TestClaudeSDKClientStreaming:
def test_concurrent_send_receive(self):
"""Test concurrent sending and receiving messages."""
async def _test():
with patch(
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
@ -334,7 +352,7 @@ class TestClaudeSDKClientStreaming:
# Helper to get next message
async def get_next_message():
return await client.receive_response().__anext__()
# Start receiving in background
receive_task = asyncio.create_task(get_next_message())
@ -353,13 +371,14 @@ class TestQueryWithAsyncIterable:
def test_query_with_async_iterable(self):
"""Test query with async iterable of messages."""
async def _test():
async def message_stream():
yield {"type": "user", "message": {"role": "user", "content": "First"}}
yield {"type": "user", "message": {"role": "user", "content": "Second"}}
with patch(
"claude_code_sdk.query.InternalClient"
"claude_code_sdk._internal.client.InternalClient"
) as mock_client_class:
mock_client = MagicMock()
mock_client_class.return_value = mock_client
@ -399,6 +418,7 @@ class TestQueryWithAsyncIterable:
def test_query_async_iterable_with_options(self):
"""Test query with async iterable and custom options."""
async def _test():
async def complex_stream():
yield {
@ -421,7 +441,7 @@ class TestQueryWithAsyncIterable:
)
with patch(
"claude_code_sdk.query.InternalClient"
"claude_code_sdk._internal.client.InternalClient"
) as mock_client_class:
mock_client = MagicMock()
mock_client_class.return_value = mock_client
@ -445,6 +465,7 @@ class TestQueryWithAsyncIterable:
def test_query_empty_async_iterable(self):
"""Test query with empty async iterable."""
async def _test():
async def empty_stream():
# Never yields anything
@ -452,7 +473,7 @@ class TestQueryWithAsyncIterable:
yield
with patch(
"claude_code_sdk.query.InternalClient"
"claude_code_sdk._internal.client.InternalClient"
) as mock_client_class:
mock_client = MagicMock()
mock_client_class.return_value = mock_client
@ -460,8 +481,7 @@ class TestQueryWithAsyncIterable:
# Mock response
async def mock_process():
yield SystemMessage(
subtype="info",
data={"message": "No input provided"}
subtype="info", data={"message": "No input provided"}
)
mock_client.process_query.return_value = mock_process()
@ -478,6 +498,7 @@ class TestQueryWithAsyncIterable:
def test_query_async_iterable_with_delay(self):
"""Test query with async iterable that has delays between yields."""
async def _test():
async def delayed_stream():
yield {"type": "user", "message": {"role": "user", "content": "Start"}}
@ -487,7 +508,7 @@ class TestQueryWithAsyncIterable:
yield {"type": "user", "message": {"role": "user", "content": "End"}}
with patch(
"claude_code_sdk.query.InternalClient"
"claude_code_sdk._internal.client.InternalClient"
) as mock_client_class:
mock_client = MagicMock()
mock_client_class.return_value = mock_client
@ -512,6 +533,7 @@ class TestQueryWithAsyncIterable:
# Time the execution
import time
start_time = time.time()
messages = []
async for msg in query(prompt=delayed_stream()):
@ -527,13 +549,14 @@ class TestQueryWithAsyncIterable:
def test_query_async_iterable_exception_handling(self):
"""Test query handles exceptions in async iterable."""
async def _test():
async def failing_stream():
yield {"type": "user", "message": {"role": "user", "content": "First"}}
raise ValueError("Stream error")
with patch(
"claude_code_sdk.query.InternalClient"
"claude_code_sdk._internal.client.InternalClient"
) as mock_client_class:
mock_client = MagicMock()
mock_client_class.return_value = mock_client
@ -561,6 +584,7 @@ class TestClaudeSDKClientEdgeCases:
def test_receive_messages_not_connected(self):
"""Test receiving messages when not connected."""
async def _test():
client = ClaudeSDKClient()
with pytest.raises(CLIConnectionError, match="Not connected"):
@ -571,6 +595,7 @@ class TestClaudeSDKClientEdgeCases:
def test_receive_response_not_connected(self):
"""Test receive_response when not connected."""
async def _test():
client = ClaudeSDKClient()
with pytest.raises(CLIConnectionError, match="Not connected"):
@ -581,6 +606,7 @@ class TestClaudeSDKClientEdgeCases:
def test_double_connect(self):
"""Test connecting twice."""
async def _test():
with patch(
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
@ -600,6 +626,7 @@ class TestClaudeSDKClientEdgeCases:
def test_disconnect_without_connect(self):
"""Test disconnecting without connecting first."""
async def _test():
client = ClaudeSDKClient()
# Should not raise error
@ -609,6 +636,7 @@ class TestClaudeSDKClientEdgeCases:
def test_context_manager_with_exception(self):
"""Test context manager cleans up on exception."""
async def _test():
with patch(
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
@ -627,6 +655,7 @@ class TestClaudeSDKClientEdgeCases:
def test_receive_response_list_comprehension(self):
"""Test collecting messages with list comprehension as shown in examples."""
async def _test():
with patch(
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
@ -668,7 +697,10 @@ class TestClaudeSDKClientEdgeCases:
messages = [msg async for msg in client.receive_response()]
assert len(messages) == 3
assert all(isinstance(msg, AssistantMessage | ResultMessage) for msg in messages)
assert all(
isinstance(msg, AssistantMessage | ResultMessage)
for msg in messages
)
assert isinstance(messages[-1], ResultMessage)
anyio.run(_test)