diff --git a/tests/test_streaming_client.py b/tests/test_streaming_client.py index cf1c6a5..c196c56 100644 --- a/tests/test_streaming_client.py +++ b/tests/test_streaming_client.py @@ -3,7 +3,6 @@ import asyncio import sys import tempfile -import textwrap from pathlib import Path from unittest.mock import AsyncMock, patch @@ -373,80 +372,6 @@ class TestClaudeSDKClientStreaming: class TestQueryWithAsyncIterable: """Test query() function with async iterable inputs.""" - def _create_test_script( - self, expected_messages=None, response=None, should_error=False - ): - """Create a test script that validates CLI args and stdin messages. - - Args: - expected_messages: List of expected message content strings, or None to skip validation - response: Custom response to output, defaults to a success result - should_error: If True, script will exit with error after reading stdin - - Returns: - Path to the test script - """ - if response is None: - response = '{"type": "result", "subtype": "success", "duration_ms": 100, "duration_api_ms": 50, "is_error": false, "num_turns": 1, "session_id": "test", "total_cost_usd": 0.001}' - - script_content = textwrap.dedent( - """ - #!/usr/bin/env python3 - import sys - import json - import time - - # Check command line args - args = sys.argv[1:] - assert "--output-format" in args - assert "stream-json" in args - - # Read stdin messages - stdin_messages = [] - stdin_closed = False - try: - while True: - line = sys.stdin.readline() - if not line: - stdin_closed = True - break - stdin_messages.append(line.strip()) - except: - stdin_closed = True - """, - ) - - if expected_messages is not None: - script_content += textwrap.dedent( - f""" - # Verify we got the expected messages - assert len(stdin_messages) == {len(expected_messages)} - """, - ) - for i, msg in enumerate(expected_messages): - script_content += f'''assert '"{msg}"' in stdin_messages[{i}]\n''' - - if should_error: - script_content += textwrap.dedent( - """ - sys.exit(1) - """, - ) - else: - script_content += textwrap.dedent( - f""" - # Output response - print('{response}') - """, - ) - - with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: - test_script = f.name - f.write(script_content) - - Path(test_script).chmod(0o755) - return test_script - def test_query_with_async_iterable(self): """Test query with async iterable of messages.""" @@ -455,32 +380,63 @@ class TestQueryWithAsyncIterable: yield {"type": "user", "message": {"role": "user", "content": "First"}} yield {"type": "user", "message": {"role": "user", "content": "Second"}} - test_script = self._create_test_script( - expected_messages=["First", "Second"] - ) + # Create a simple test script that validates stdin and outputs a result + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + test_script = f.name + f.write("""#!/usr/bin/env python3 +import sys +import json + +# Read stdin messages +stdin_messages = [] +while True: + line = sys.stdin.readline() + if not line: + break + stdin_messages.append(line.strip()) + +# Verify we got 2 messages +assert len(stdin_messages) == 2 +assert '"First"' in stdin_messages[0] +assert '"Second"' in stdin_messages[1] + +# Output a valid result +print('{"type": "result", "subtype": "success", "duration_ms": 100, "duration_api_ms": 50, "is_error": false, "num_turns": 1, "session_id": "test", "total_cost_usd": 0.001}') +""") + + Path(test_script).chmod(0o755) try: - # Mock _build_command to return our test script + # Mock _find_cli to return python executing our test script with patch.object( SubprocessCLITransport, - "_build_command", - return_value=[ - sys.executable, - test_script, - "--output-format", - "stream-json", - "--verbose", - ], + "_find_cli", + return_value=sys.executable ): - # Run query with async iterable - messages = [] - async for msg in query(prompt=message_stream()): - messages.append(msg) + # Mock _build_command to add our test script as first argument + original_build_command = SubprocessCLITransport._build_command + + def mock_build_command(self): + # Get original command + cmd = original_build_command(self) + # Replace the CLI path with python + script + cmd[0] = test_script + return cmd + + with patch.object( + SubprocessCLITransport, + "_build_command", + mock_build_command + ): + # Run query with async iterable + messages = [] + async for msg in query(prompt=message_stream()): + messages.append(msg) - # Should get the result message - assert len(messages) == 1 - assert isinstance(messages[0], ResultMessage) - assert messages[0].subtype == "success" + # Should get the result message + assert len(messages) == 1 + assert isinstance(messages[0], ResultMessage) + assert messages[0].subtype == "success" finally: # Clean up Path(test_script).unlink()