Merge branch 'dickson/control' into feat/sdk-mcp-server-support

Resolved conflict by keeping both TYPE_CHECKING (needed for MCP) and contextlib.suppress (from base branch fixes)
This commit is contained in:
Kashyap Murali 2025-09-01 03:09:42 -07:00
commit 76f6ed1d9c
No known key found for this signature in database
8 changed files with 201 additions and 72 deletions

View file

@ -1,5 +1,6 @@
"""Internal client implementation."""
import asyncio
from collections.abc import AsyncIterable, AsyncIterator
from typing import Any
@ -64,8 +65,6 @@ class InternalClient:
# Stream input if it's an AsyncIterable
if isinstance(prompt, AsyncIterable):
# Start streaming in background
import asyncio
asyncio.create_task(query.stream_input(prompt))
# For string prompts, the prompt is already passed via CLI args
@ -74,4 +73,4 @@ class InternalClient:
yield parse_message(data)
finally:
query.close()
await query.close()

View file

@ -5,6 +5,7 @@ import json
import logging
import os
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
from contextlib import suppress
from typing import TYPE_CHECKING, Any
from .transport import Transport
@ -146,12 +147,16 @@ class Query:
# Regular SDK messages go to the queue
await self._message_queue.put(message)
except asyncio.CancelledError:
# Task was cancelled - this is expected behavior
logger.debug("Read task cancelled")
raise # Re-raise to properly handle cancellation
except Exception as e:
logger.debug(f"Error reading messages: {e}")
logger.error(f"Fatal error in message reader: {e}")
# Put error in queue so iterators can handle it
await self._message_queue.put({"type": "error", "error": str(e)})
finally:
# Signal end of stream
# Always signal end of stream
await self._message_queue.put({"type": "end"})
async def _handle_control_request(self, request: dict[str, Any]) -> None:
@ -345,7 +350,7 @@ class Query:
break
await self.transport.write(json.dumps(message) + "\n")
# After all messages sent, end input
self.transport.end_input()
await self.transport.end_input()
except Exception as e:
logger.debug(f"Error streaming input: {e}")
@ -362,12 +367,15 @@ class Query:
yield message
def close(self) -> None:
async def close(self) -> None:
"""Close the query and transport."""
self._closed = True
if self._read_task and not self._read_task.done():
self._read_task.cancel()
self.transport.close()
# Wait for task to complete cancellation
with suppress(asyncio.CancelledError):
await self._read_task
await self.transport.close()
# Make Query an async iterator
def __aiter__(self) -> AsyncIterator[dict[str, Any]]:

View file

@ -46,7 +46,7 @@ class Transport(ABC):
pass
@abstractmethod
def close(self) -> None:
async def close(self) -> None:
"""Close the transport connection and clean up resources."""
pass
@ -60,7 +60,7 @@ class Transport(ABC):
pass
@abstractmethod
def end_input(self) -> None:
async def end_input(self) -> None:
"""End the input stream (close stdin for process transports)."""
pass

View file

@ -7,6 +7,7 @@ import shutil
import tempfile
from collections import deque
from collections.abc import AsyncIterable, AsyncIterator
from contextlib import suppress
from pathlib import Path
from subprocess import PIPE
from typing import Any
@ -216,20 +217,31 @@ class SubprocessCLITransport(Transport):
except Exception as e:
raise CLIConnectionError(f"Failed to start Claude Code: {e}") from e
def close(self) -> None:
async def close(self) -> None:
"""Close the transport and clean up resources."""
self._ready = False
if not self._process:
return
if self._process.returncode is None:
from contextlib import suppress
# Close stdin first if it's still open
if self._stdin_stream:
with suppress(Exception):
await self._stdin_stream.aclose()
self._stdin_stream = None
if self._process.stdin:
with suppress(Exception):
await self._process.stdin.aclose()
# Terminate and wait for process
if self._process.returncode is None:
with suppress(ProcessLookupError):
self._process.terminate()
# Note: We can't use async wait here since close() is sync
# The process will be cleaned up by the OS
# Wait for process to finish with timeout
with suppress(Exception):
# Just try to wait, but don't block if it fails
await self._process.wait()
# Clean up temp file
if self._stderr_file:
@ -252,18 +264,15 @@ class SubprocessCLITransport(Transport):
await self._stdin_stream.send(data)
def end_input(self) -> None:
async def end_input(self) -> None:
"""End the input stream (close stdin)."""
if self._stdin_stream:
# Note: We can't use async aclose here since end_input() is sync
# Just mark it as None and let cleanup happen later
with suppress(Exception):
await self._stdin_stream.aclose()
self._stdin_stream = None
if self._process and self._process.stdin:
from contextlib import suppress
with suppress(Exception):
# Mark stdin as closed - actual close will happen during cleanup
pass
await self._process.stdin.aclose()
def read_messages(self) -> AsyncIterator[dict[str, Any]]:
"""Read and parse messages from the transport."""

View file

@ -1,5 +1,7 @@
"""Claude SDK Client for interacting with Claude Code."""
import asyncio
import json
import os
from collections.abc import AsyncIterable, AsyncIterator
from typing import Any
@ -137,8 +139,6 @@ class ClaudeSDKClient:
# If we have an initial prompt stream, start streaming it
if prompt is not None and isinstance(prompt, AsyncIterable):
import asyncio
asyncio.create_task(self._query.stream_input(prompt))
async def receive_messages(self) -> AsyncIterator[Message]:
@ -164,8 +164,6 @@ class ClaudeSDKClient:
if not self._query or not self._transport:
raise CLIConnectionError("Not connected. Call connect() first.")
import json
# Handle string prompts
if isinstance(prompt, str):
message = {
@ -258,7 +256,7 @@ class ClaudeSDKClient:
async def disconnect(self) -> None:
"""Disconnect from Claude."""
if self._query:
self._query.close()
await self._query.close()
self._query = None
self._transport = None

View file

@ -1,6 +1,6 @@
"""Tests for Claude SDK client functionality."""
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, Mock, patch
import anyio
@ -102,9 +102,12 @@ class TestQueryFunction:
"total_cost_usd": 0.001,
}
mock_transport.receive_messages = mock_receive
mock_transport.read_messages = mock_receive
mock_transport.connect = AsyncMock()
mock_transport.disconnect = AsyncMock()
mock_transport.close = AsyncMock()
mock_transport.end_input = AsyncMock()
mock_transport.write = AsyncMock()
mock_transport.is_ready = Mock(return_value=True)
options = ClaudeCodeOptions(cwd="/custom/path")
messages = []

View file

@ -3,7 +3,7 @@
These tests verify end-to-end functionality with mocked CLI responses.
"""
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, Mock, patch
import anyio
import pytest
@ -52,9 +52,12 @@ class TestIntegration:
"total_cost_usd": 0.001,
}
mock_transport.receive_messages = mock_receive
mock_transport.read_messages = mock_receive
mock_transport.connect = AsyncMock()
mock_transport.disconnect = AsyncMock()
mock_transport.close = AsyncMock()
mock_transport.end_input = AsyncMock()
mock_transport.write = AsyncMock()
mock_transport.is_ready = Mock(return_value=True)
# Run query
messages = []
@ -118,9 +121,12 @@ class TestIntegration:
"total_cost_usd": 0.002,
}
mock_transport.receive_messages = mock_receive
mock_transport.read_messages = mock_receive
mock_transport.connect = AsyncMock()
mock_transport.disconnect = AsyncMock()
mock_transport.close = AsyncMock()
mock_transport.end_input = AsyncMock()
mock_transport.write = AsyncMock()
mock_transport.is_ready = Mock(return_value=True)
# Run query with tools enabled
messages = []
@ -185,9 +191,12 @@ class TestIntegration:
},
}
mock_transport.receive_messages = mock_receive
mock_transport.read_messages = mock_receive
mock_transport.connect = AsyncMock()
mock_transport.disconnect = AsyncMock()
mock_transport.close = AsyncMock()
mock_transport.end_input = AsyncMock()
mock_transport.write = AsyncMock()
mock_transport.is_ready = Mock(return_value=True)
# Run query with continuation
messages = []

View file

@ -1,10 +1,11 @@
"""Tests for ClaudeSDKClient streaming functionality and query() with async iterables."""
import asyncio
import json
import sys
import tempfile
from pathlib import Path
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, Mock, patch
import anyio
import pytest
@ -22,6 +23,63 @@ from claude_code_sdk import (
from claude_code_sdk._internal.transport.subprocess_cli import SubprocessCLITransport
def create_mock_transport(with_init_response=True):
"""Create a properly configured mock transport.
Args:
with_init_response: If True, automatically respond to initialization request
"""
mock_transport = AsyncMock()
mock_transport.connect = AsyncMock()
mock_transport.close = AsyncMock()
mock_transport.end_input = AsyncMock()
mock_transport.write = AsyncMock()
mock_transport.is_ready = Mock(return_value=True)
# Track written messages to simulate control protocol responses
written_messages = []
async def mock_write(data):
written_messages.append(data)
mock_transport.write.side_effect = mock_write
# Default read_messages to handle control protocol
async def control_protocol_generator():
# Wait for initialization request if needed
if with_init_response:
# Wait a bit for the write to happen
await asyncio.sleep(0.01)
# Check if initialization was requested
for msg_str in written_messages:
try:
msg = json.loads(msg_str.strip())
if (
msg.get("type") == "control_request"
and msg.get("request", {}).get("subtype") == "initialize"
):
# Send initialization response
yield {
"type": "control_response",
"response": {
"request_id": msg.get("request_id"),
"subtype": "success",
"commands": [],
"output_style": "default",
},
}
break
except (json.JSONDecodeError, KeyError, AttributeError):
pass
# Then end the stream
return
mock_transport.read_messages = control_protocol_generator
return mock_transport
class TestClaudeSDKClientStreaming:
"""Test ClaudeSDKClient streaming functionality."""
@ -32,7 +90,7 @@ class TestClaudeSDKClientStreaming:
with patch(
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
) as mock_transport_class:
mock_transport = AsyncMock()
mock_transport = create_mock_transport()
mock_transport_class.return_value = mock_transport
async with ClaudeSDKClient() as client:
@ -41,7 +99,7 @@ class TestClaudeSDKClientStreaming:
assert client._transport is mock_transport
# Verify disconnect was called on exit
mock_transport.disconnect.assert_called_once()
mock_transport.close.assert_called_once()
anyio.run(_test)
@ -52,7 +110,7 @@ class TestClaudeSDKClientStreaming:
with patch(
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
) as mock_transport_class:
mock_transport = AsyncMock()
mock_transport = create_mock_transport()
mock_transport_class.return_value = mock_transport
client = ClaudeSDKClient()
@ -64,7 +122,7 @@ class TestClaudeSDKClientStreaming:
await client.disconnect()
# Verify disconnect was called
mock_transport.disconnect.assert_called_once()
mock_transport.close.assert_called_once()
assert client._transport is None
anyio.run(_test)
@ -76,7 +134,7 @@ class TestClaudeSDKClientStreaming:
with patch(
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
) as mock_transport_class:
mock_transport = AsyncMock()
mock_transport = create_mock_transport()
mock_transport_class.return_value = mock_transport
client = ClaudeSDKClient()
@ -95,7 +153,7 @@ class TestClaudeSDKClientStreaming:
with patch(
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
) as mock_transport_class:
mock_transport = AsyncMock()
mock_transport = create_mock_transport()
mock_transport_class.return_value = mock_transport
async def message_stream():
@ -123,20 +181,30 @@ class TestClaudeSDKClientStreaming:
with patch(
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
) as mock_transport_class:
mock_transport = AsyncMock()
mock_transport = create_mock_transport()
mock_transport_class.return_value = mock_transport
async with ClaudeSDKClient() as client:
await client.query("Test message")
# Verify send_request was called with correct format
mock_transport.send_request.assert_called_once()
call_args = mock_transport.send_request.call_args
messages, options = call_args[0]
assert len(messages) == 1
assert messages[0]["type"] == "user"
assert messages[0]["message"]["content"] == "Test message"
assert options["session_id"] == "default"
# Verify write was called with correct format
# Should have at least 2 writes: init request and user message
assert mock_transport.write.call_count >= 2
# Find the user message in the write calls
user_msg_found = False
for call in mock_transport.write.call_args_list:
data = call[0][0]
try:
msg = json.loads(data.strip())
if msg.get("type") == "user":
assert msg["message"]["content"] == "Test message"
assert msg["session_id"] == "default"
user_msg_found = True
break
except (json.JSONDecodeError, KeyError, AttributeError):
pass
assert user_msg_found, "User message not found in write calls"
anyio.run(_test)
@ -147,16 +215,25 @@ class TestClaudeSDKClientStreaming:
with patch(
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
) as mock_transport_class:
mock_transport = AsyncMock()
mock_transport = create_mock_transport()
mock_transport_class.return_value = mock_transport
async with ClaudeSDKClient() as client:
await client.query("Test", session_id="custom-session")
call_args = mock_transport.send_request.call_args
messages, options = call_args[0]
assert messages[0]["session_id"] == "custom-session"
assert options["session_id"] == "custom-session"
# Find the user message with custom session ID
session_found = False
for call in mock_transport.write.call_args_list:
data = call[0][0]
try:
msg = json.loads(data.strip())
if msg.get("type") == "user":
assert msg["session_id"] == "custom-session"
session_found = True
break
except (json.JSONDecodeError, KeyError, AttributeError):
pass
assert session_found, "User message with custom session not found"
anyio.run(_test)
@ -177,11 +254,37 @@ class TestClaudeSDKClientStreaming:
with patch(
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
) as mock_transport_class:
mock_transport = AsyncMock()
mock_transport = create_mock_transport()
mock_transport_class.return_value = mock_transport
# Mock the message stream
# Mock the message stream with control protocol support
async def mock_receive():
# First handle initialization
await asyncio.sleep(0.01)
written = mock_transport.write.call_args_list
for call in written:
data = call[0][0]
try:
msg = json.loads(data.strip())
if (
msg.get("type") == "control_request"
and msg.get("request", {}).get("subtype")
== "initialize"
):
yield {
"type": "control_response",
"response": {
"request_id": msg.get("request_id"),
"subtype": "success",
"commands": [],
"output_style": "default",
},
}
break
except (json.JSONDecodeError, KeyError, AttributeError):
pass
# Then yield the actual messages
yield {
"type": "assistant",
"message": {
@ -195,7 +298,7 @@ class TestClaudeSDKClientStreaming:
"message": {"role": "user", "content": "Hi there"},
}
mock_transport.receive_messages = mock_receive
mock_transport.read_messages = mock_receive
async with ClaudeSDKClient() as client:
messages = []
@ -220,7 +323,7 @@ class TestClaudeSDKClientStreaming:
with patch(
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
) as mock_transport_class:
mock_transport = AsyncMock()
mock_transport = create_mock_transport()
mock_transport_class.return_value = mock_transport
# Mock the message stream
@ -255,7 +358,7 @@ class TestClaudeSDKClientStreaming:
"model": "claude-opus-4-1-20250805",
}
mock_transport.receive_messages = mock_receive
mock_transport.read_messages = mock_receive
async with ClaudeSDKClient() as client:
messages = []
@ -276,7 +379,7 @@ class TestClaudeSDKClientStreaming:
with patch(
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
) as mock_transport_class:
mock_transport = AsyncMock()
mock_transport = create_mock_transport()
mock_transport_class.return_value = mock_transport
async with ClaudeSDKClient() as client:
@ -308,7 +411,7 @@ class TestClaudeSDKClientStreaming:
with patch(
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
) as mock_transport_class:
mock_transport = AsyncMock()
mock_transport = create_mock_transport()
mock_transport_class.return_value = mock_transport
client = ClaudeSDKClient(options=options)
@ -327,7 +430,7 @@ class TestClaudeSDKClientStreaming:
with patch(
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
) as mock_transport_class:
mock_transport = AsyncMock()
mock_transport = create_mock_transport()
mock_transport_class.return_value = mock_transport
# Mock receive to wait then yield messages
@ -353,7 +456,7 @@ class TestClaudeSDKClientStreaming:
"total_cost_usd": 0.001,
}
mock_transport.receive_messages = mock_receive
mock_transport.read_messages = mock_receive
async with ClaudeSDKClient() as client:
# Helper to get next message
@ -476,7 +579,7 @@ class TestClaudeSDKClientEdgeCases:
with patch(
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
) as mock_transport_class:
mock_transport = AsyncMock()
mock_transport = create_mock_transport()
mock_transport_class.return_value = mock_transport
client = ClaudeSDKClient()
@ -506,7 +609,7 @@ class TestClaudeSDKClientEdgeCases:
with patch(
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
) as mock_transport_class:
mock_transport = AsyncMock()
mock_transport = create_mock_transport()
mock_transport_class.return_value = mock_transport
with pytest.raises(ValueError):
@ -514,7 +617,7 @@ class TestClaudeSDKClientEdgeCases:
raise ValueError("Test error")
# Disconnect should still be called
mock_transport.disconnect.assert_called_once()
mock_transport.close.assert_called_once()
anyio.run(_test)
@ -525,7 +628,7 @@ class TestClaudeSDKClientEdgeCases:
with patch(
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
) as mock_transport_class:
mock_transport = AsyncMock()
mock_transport = create_mock_transport()
mock_transport_class.return_value = mock_transport
# Mock the message stream
@ -557,7 +660,7 @@ class TestClaudeSDKClientEdgeCases:
"total_cost_usd": 0.001,
}
mock_transport.receive_messages = mock_receive
mock_transport.read_messages = mock_receive
async with ClaudeSDKClient() as client:
# Test list comprehension pattern from docstring