From eeb0be9955f9f70a34c9d423352693eddf49f5a4 Mon Sep 17 00:00:00 2001 From: Dickson Tsai Date: Sat, 19 Jul 2025 15:01:43 -0700 Subject: [PATCH] Close stdin for query() --- src/claude_code_sdk/_internal/client.py | 6 ++- .../_internal/transport/subprocess_cli.py | 9 +++- tests/test_streaming_client.py | 54 +++++++++++++++---- 3 files changed, 55 insertions(+), 14 deletions(-) diff --git a/src/claude_code_sdk/_internal/client.py b/src/claude_code_sdk/_internal/client.py index fb4eeb8..d40540f 100644 --- a/src/claude_code_sdk/_internal/client.py +++ b/src/claude_code_sdk/_internal/client.py @@ -19,7 +19,11 @@ class InternalClient: ) -> AsyncIterator[Message]: """Process a query through transport.""" - transport = SubprocessCLITransport(prompt=prompt, options=options) + transport = SubprocessCLITransport( + prompt=prompt, + options=options, + close_stdin_after_prompt=True + ) try: await transport.connect() diff --git a/src/claude_code_sdk/_internal/transport/subprocess_cli.py b/src/claude_code_sdk/_internal/transport/subprocess_cli.py index 5632c21..701b686 100644 --- a/src/claude_code_sdk/_internal/transport/subprocess_cli.py +++ b/src/claude_code_sdk/_internal/transport/subprocess_cli.py @@ -31,6 +31,7 @@ class SubprocessCLITransport(Transport): prompt: str | AsyncIterable[dict[str, Any]], options: ClaudeCodeOptions, cli_path: str | Path | None = None, + close_stdin_after_prompt: bool = False, ): self._prompt = prompt self._is_streaming = not isinstance(prompt, str) @@ -43,6 +44,7 @@ class SubprocessCLITransport(Transport): self._stdin_stream: TextSendStream | None = None self._pending_control_responses: dict[str, Any] = {} self._request_counter = 0 + self._close_stdin_after_prompt = close_stdin_after_prompt def _find_cli(self) -> str: """Find Claude Code CLI binary.""" @@ -228,8 +230,11 @@ class SubprocessCLITransport(Transport): break await self._stdin_stream.send(json.dumps(message) + "\n") - # Don't close stdin - keep it open for send_request - # Users can explicitly call disconnect() when done + # Close stdin after prompt if requested (e.g., for query() one-shot mode) + if self._close_stdin_after_prompt and self._stdin_stream: + await self._stdin_stream.aclose() + self._stdin_stream = None + # Otherwise keep stdin open for send_request (ClaudeSDKClient interactive mode) except Exception as e: logger.debug(f"Error streaming to stdin: {e}") if self._stdin_stream: diff --git a/tests/test_streaming_client.py b/tests/test_streaming_client.py index 4f62545..ed83bcd 100644 --- a/tests/test_streaming_client.py +++ b/tests/test_streaming_client.py @@ -24,6 +24,7 @@ class TestClaudeSDKClientStreaming: def test_auto_connect_with_context_manager(self): """Test automatic connection when using context manager.""" + async def _test(): with patch( "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" @@ -43,6 +44,7 @@ class TestClaudeSDKClientStreaming: def test_manual_connect_disconnect(self): """Test manual connect and disconnect.""" + async def _test(): with patch( "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" @@ -66,6 +68,7 @@ class TestClaudeSDKClientStreaming: def test_connect_with_string_prompt(self): """Test connecting with a string prompt.""" + async def _test(): with patch( "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" @@ -84,6 +87,7 @@ class TestClaudeSDKClientStreaming: def test_connect_with_async_iterable(self): """Test connecting with an async iterable.""" + async def _test(): with patch( "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" @@ -93,7 +97,10 @@ class TestClaudeSDKClientStreaming: async def message_stream(): yield {"type": "user", "message": {"role": "user", "content": "Hi"}} - yield {"type": "user", "message": {"role": "user", "content": "Bye"}} + yield { + "type": "user", + "message": {"role": "user", "content": "Bye"}, + } client = ClaudeSDKClient() stream = message_stream() @@ -108,6 +115,7 @@ class TestClaudeSDKClientStreaming: def test_send_message(self): """Test sending a message.""" + async def _test(): with patch( "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" @@ -131,6 +139,7 @@ class TestClaudeSDKClientStreaming: def test_send_message_with_session_id(self): """Test sending a message with custom session ID.""" + async def _test(): with patch( "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" @@ -150,6 +159,7 @@ class TestClaudeSDKClientStreaming: def test_send_message_not_connected(self): """Test sending message when not connected raises error.""" + async def _test(): client = ClaudeSDKClient() with pytest.raises(CLIConnectionError, match="Not connected"): @@ -159,6 +169,7 @@ class TestClaudeSDKClientStreaming: def test_receive_messages(self): """Test receiving messages.""" + async def _test(): with patch( "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" @@ -200,6 +211,7 @@ class TestClaudeSDKClientStreaming: def test_receive_response(self): """Test receive_response stops at ResultMessage.""" + async def _test(): with patch( "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" @@ -231,7 +243,9 @@ class TestClaudeSDKClientStreaming: "type": "assistant", "message": { "role": "assistant", - "content": [{"type": "text", "text": "Should not see this"}], + "content": [ + {"type": "text", "text": "Should not see this"} + ], }, } @@ -251,6 +265,7 @@ class TestClaudeSDKClientStreaming: def test_interrupt(self): """Test interrupt functionality.""" + async def _test(): with patch( "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" @@ -266,6 +281,7 @@ class TestClaudeSDKClientStreaming: def test_interrupt_not_connected(self): """Test interrupt when not connected raises error.""" + async def _test(): client = ClaudeSDKClient() with pytest.raises(CLIConnectionError, match="Not connected"): @@ -275,6 +291,7 @@ class TestClaudeSDKClientStreaming: def test_client_with_options(self): """Test client initialization with options.""" + async def _test(): options = ClaudeCodeOptions( cwd="/custom/path", @@ -299,6 +316,7 @@ class TestClaudeSDKClientStreaming: def test_concurrent_send_receive(self): """Test concurrent sending and receiving messages.""" + async def _test(): with patch( "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" @@ -334,7 +352,7 @@ class TestClaudeSDKClientStreaming: # Helper to get next message async def get_next_message(): return await client.receive_response().__anext__() - + # Start receiving in background receive_task = asyncio.create_task(get_next_message()) @@ -353,13 +371,14 @@ class TestQueryWithAsyncIterable: def test_query_with_async_iterable(self): """Test query with async iterable of messages.""" + async def _test(): async def message_stream(): yield {"type": "user", "message": {"role": "user", "content": "First"}} yield {"type": "user", "message": {"role": "user", "content": "Second"}} with patch( - "claude_code_sdk.query.InternalClient" + "claude_code_sdk._internal.client.InternalClient" ) as mock_client_class: mock_client = MagicMock() mock_client_class.return_value = mock_client @@ -399,6 +418,7 @@ class TestQueryWithAsyncIterable: def test_query_async_iterable_with_options(self): """Test query with async iterable and custom options.""" + async def _test(): async def complex_stream(): yield { @@ -421,7 +441,7 @@ class TestQueryWithAsyncIterable: ) with patch( - "claude_code_sdk.query.InternalClient" + "claude_code_sdk._internal.client.InternalClient" ) as mock_client_class: mock_client = MagicMock() mock_client_class.return_value = mock_client @@ -445,6 +465,7 @@ class TestQueryWithAsyncIterable: def test_query_empty_async_iterable(self): """Test query with empty async iterable.""" + async def _test(): async def empty_stream(): # Never yields anything @@ -452,7 +473,7 @@ class TestQueryWithAsyncIterable: yield with patch( - "claude_code_sdk.query.InternalClient" + "claude_code_sdk._internal.client.InternalClient" ) as mock_client_class: mock_client = MagicMock() mock_client_class.return_value = mock_client @@ -460,8 +481,7 @@ class TestQueryWithAsyncIterable: # Mock response async def mock_process(): yield SystemMessage( - subtype="info", - data={"message": "No input provided"} + subtype="info", data={"message": "No input provided"} ) mock_client.process_query.return_value = mock_process() @@ -478,6 +498,7 @@ class TestQueryWithAsyncIterable: def test_query_async_iterable_with_delay(self): """Test query with async iterable that has delays between yields.""" + async def _test(): async def delayed_stream(): yield {"type": "user", "message": {"role": "user", "content": "Start"}} @@ -487,7 +508,7 @@ class TestQueryWithAsyncIterable: yield {"type": "user", "message": {"role": "user", "content": "End"}} with patch( - "claude_code_sdk.query.InternalClient" + "claude_code_sdk._internal.client.InternalClient" ) as mock_client_class: mock_client = MagicMock() mock_client_class.return_value = mock_client @@ -512,6 +533,7 @@ class TestQueryWithAsyncIterable: # Time the execution import time + start_time = time.time() messages = [] async for msg in query(prompt=delayed_stream()): @@ -527,13 +549,14 @@ class TestQueryWithAsyncIterable: def test_query_async_iterable_exception_handling(self): """Test query handles exceptions in async iterable.""" + async def _test(): async def failing_stream(): yield {"type": "user", "message": {"role": "user", "content": "First"}} raise ValueError("Stream error") with patch( - "claude_code_sdk.query.InternalClient" + "claude_code_sdk._internal.client.InternalClient" ) as mock_client_class: mock_client = MagicMock() mock_client_class.return_value = mock_client @@ -561,6 +584,7 @@ class TestClaudeSDKClientEdgeCases: def test_receive_messages_not_connected(self): """Test receiving messages when not connected.""" + async def _test(): client = ClaudeSDKClient() with pytest.raises(CLIConnectionError, match="Not connected"): @@ -571,6 +595,7 @@ class TestClaudeSDKClientEdgeCases: def test_receive_response_not_connected(self): """Test receive_response when not connected.""" + async def _test(): client = ClaudeSDKClient() with pytest.raises(CLIConnectionError, match="Not connected"): @@ -581,6 +606,7 @@ class TestClaudeSDKClientEdgeCases: def test_double_connect(self): """Test connecting twice.""" + async def _test(): with patch( "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" @@ -600,6 +626,7 @@ class TestClaudeSDKClientEdgeCases: def test_disconnect_without_connect(self): """Test disconnecting without connecting first.""" + async def _test(): client = ClaudeSDKClient() # Should not raise error @@ -609,6 +636,7 @@ class TestClaudeSDKClientEdgeCases: def test_context_manager_with_exception(self): """Test context manager cleans up on exception.""" + async def _test(): with patch( "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" @@ -627,6 +655,7 @@ class TestClaudeSDKClientEdgeCases: def test_receive_response_list_comprehension(self): """Test collecting messages with list comprehension as shown in examples.""" + async def _test(): with patch( "claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" @@ -668,7 +697,10 @@ class TestClaudeSDKClientEdgeCases: messages = [msg async for msg in client.receive_response()] assert len(messages) == 3 - assert all(isinstance(msg, AssistantMessage | ResultMessage) for msg in messages) + assert all( + isinstance(msg, AssistantMessage | ResultMessage) + for msg in messages + ) assert isinstance(messages[-1], ResultMessage) anyio.run(_test)