mirror of
https://github.com/anthropics/claude-code-sdk-python.git
synced 2025-12-23 09:19:52 +00:00
Merge branch 'dickson/control' into feat/sdk-mcp-server-support
Resolved conflict by keeping both TYPE_CHECKING (needed for MCP) and contextlib.suppress (from base branch fixes)
This commit is contained in:
commit
76f6ed1d9c
8 changed files with 201 additions and 72 deletions
|
|
@ -1,5 +1,6 @@
|
|||
"""Internal client implementation."""
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterable, AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
|
|
@ -64,8 +65,6 @@ class InternalClient:
|
|||
# Stream input if it's an AsyncIterable
|
||||
if isinstance(prompt, AsyncIterable):
|
||||
# Start streaming in background
|
||||
import asyncio
|
||||
|
||||
asyncio.create_task(query.stream_input(prompt))
|
||||
# For string prompts, the prompt is already passed via CLI args
|
||||
|
||||
|
|
@ -74,4 +73,4 @@ class InternalClient:
|
|||
yield parse_message(data)
|
||||
|
||||
finally:
|
||||
query.close()
|
||||
await query.close()
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import json
|
|||
import logging
|
||||
import os
|
||||
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
|
||||
from contextlib import suppress
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from .transport import Transport
|
||||
|
|
@ -146,12 +147,16 @@ class Query:
|
|||
# Regular SDK messages go to the queue
|
||||
await self._message_queue.put(message)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
# Task was cancelled - this is expected behavior
|
||||
logger.debug("Read task cancelled")
|
||||
raise # Re-raise to properly handle cancellation
|
||||
except Exception as e:
|
||||
logger.debug(f"Error reading messages: {e}")
|
||||
logger.error(f"Fatal error in message reader: {e}")
|
||||
# Put error in queue so iterators can handle it
|
||||
await self._message_queue.put({"type": "error", "error": str(e)})
|
||||
finally:
|
||||
# Signal end of stream
|
||||
# Always signal end of stream
|
||||
await self._message_queue.put({"type": "end"})
|
||||
|
||||
async def _handle_control_request(self, request: dict[str, Any]) -> None:
|
||||
|
|
@ -345,7 +350,7 @@ class Query:
|
|||
break
|
||||
await self.transport.write(json.dumps(message) + "\n")
|
||||
# After all messages sent, end input
|
||||
self.transport.end_input()
|
||||
await self.transport.end_input()
|
||||
except Exception as e:
|
||||
logger.debug(f"Error streaming input: {e}")
|
||||
|
||||
|
|
@ -362,12 +367,15 @@ class Query:
|
|||
|
||||
yield message
|
||||
|
||||
def close(self) -> None:
|
||||
async def close(self) -> None:
|
||||
"""Close the query and transport."""
|
||||
self._closed = True
|
||||
if self._read_task and not self._read_task.done():
|
||||
self._read_task.cancel()
|
||||
self.transport.close()
|
||||
# Wait for task to complete cancellation
|
||||
with suppress(asyncio.CancelledError):
|
||||
await self._read_task
|
||||
await self.transport.close()
|
||||
|
||||
# Make Query an async iterator
|
||||
def __aiter__(self) -> AsyncIterator[dict[str, Any]]:
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ class Transport(ABC):
|
|||
pass
|
||||
|
||||
@abstractmethod
|
||||
def close(self) -> None:
|
||||
async def close(self) -> None:
|
||||
"""Close the transport connection and clean up resources."""
|
||||
pass
|
||||
|
||||
|
|
@ -60,7 +60,7 @@ class Transport(ABC):
|
|||
pass
|
||||
|
||||
@abstractmethod
|
||||
def end_input(self) -> None:
|
||||
async def end_input(self) -> None:
|
||||
"""End the input stream (close stdin for process transports)."""
|
||||
pass
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import shutil
|
|||
import tempfile
|
||||
from collections import deque
|
||||
from collections.abc import AsyncIterable, AsyncIterator
|
||||
from contextlib import suppress
|
||||
from pathlib import Path
|
||||
from subprocess import PIPE
|
||||
from typing import Any
|
||||
|
|
@ -216,20 +217,31 @@ class SubprocessCLITransport(Transport):
|
|||
except Exception as e:
|
||||
raise CLIConnectionError(f"Failed to start Claude Code: {e}") from e
|
||||
|
||||
def close(self) -> None:
|
||||
async def close(self) -> None:
|
||||
"""Close the transport and clean up resources."""
|
||||
self._ready = False
|
||||
|
||||
if not self._process:
|
||||
return
|
||||
|
||||
if self._process.returncode is None:
|
||||
from contextlib import suppress
|
||||
# Close stdin first if it's still open
|
||||
if self._stdin_stream:
|
||||
with suppress(Exception):
|
||||
await self._stdin_stream.aclose()
|
||||
self._stdin_stream = None
|
||||
|
||||
if self._process.stdin:
|
||||
with suppress(Exception):
|
||||
await self._process.stdin.aclose()
|
||||
|
||||
# Terminate and wait for process
|
||||
if self._process.returncode is None:
|
||||
with suppress(ProcessLookupError):
|
||||
self._process.terminate()
|
||||
# Note: We can't use async wait here since close() is sync
|
||||
# The process will be cleaned up by the OS
|
||||
# Wait for process to finish with timeout
|
||||
with suppress(Exception):
|
||||
# Just try to wait, but don't block if it fails
|
||||
await self._process.wait()
|
||||
|
||||
# Clean up temp file
|
||||
if self._stderr_file:
|
||||
|
|
@ -252,18 +264,15 @@ class SubprocessCLITransport(Transport):
|
|||
|
||||
await self._stdin_stream.send(data)
|
||||
|
||||
def end_input(self) -> None:
|
||||
async def end_input(self) -> None:
|
||||
"""End the input stream (close stdin)."""
|
||||
if self._stdin_stream:
|
||||
# Note: We can't use async aclose here since end_input() is sync
|
||||
# Just mark it as None and let cleanup happen later
|
||||
with suppress(Exception):
|
||||
await self._stdin_stream.aclose()
|
||||
self._stdin_stream = None
|
||||
if self._process and self._process.stdin:
|
||||
from contextlib import suppress
|
||||
|
||||
with suppress(Exception):
|
||||
# Mark stdin as closed - actual close will happen during cleanup
|
||||
pass
|
||||
await self._process.stdin.aclose()
|
||||
|
||||
def read_messages(self) -> AsyncIterator[dict[str, Any]]:
|
||||
"""Read and parse messages from the transport."""
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
"""Claude SDK Client for interacting with Claude Code."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from collections.abc import AsyncIterable, AsyncIterator
|
||||
from typing import Any
|
||||
|
|
@ -137,8 +139,6 @@ class ClaudeSDKClient:
|
|||
|
||||
# If we have an initial prompt stream, start streaming it
|
||||
if prompt is not None and isinstance(prompt, AsyncIterable):
|
||||
import asyncio
|
||||
|
||||
asyncio.create_task(self._query.stream_input(prompt))
|
||||
|
||||
async def receive_messages(self) -> AsyncIterator[Message]:
|
||||
|
|
@ -164,8 +164,6 @@ class ClaudeSDKClient:
|
|||
if not self._query or not self._transport:
|
||||
raise CLIConnectionError("Not connected. Call connect() first.")
|
||||
|
||||
import json
|
||||
|
||||
# Handle string prompts
|
||||
if isinstance(prompt, str):
|
||||
message = {
|
||||
|
|
@ -258,7 +256,7 @@ class ClaudeSDKClient:
|
|||
async def disconnect(self) -> None:
|
||||
"""Disconnect from Claude."""
|
||||
if self._query:
|
||||
self._query.close()
|
||||
await self._query.close()
|
||||
self._query = None
|
||||
self._transport = None
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"""Tests for Claude SDK client functionality."""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import anyio
|
||||
|
||||
|
|
@ -102,9 +102,12 @@ class TestQueryFunction:
|
|||
"total_cost_usd": 0.001,
|
||||
}
|
||||
|
||||
mock_transport.receive_messages = mock_receive
|
||||
mock_transport.read_messages = mock_receive
|
||||
mock_transport.connect = AsyncMock()
|
||||
mock_transport.disconnect = AsyncMock()
|
||||
mock_transport.close = AsyncMock()
|
||||
mock_transport.end_input = AsyncMock()
|
||||
mock_transport.write = AsyncMock()
|
||||
mock_transport.is_ready = Mock(return_value=True)
|
||||
|
||||
options = ClaudeCodeOptions(cwd="/custom/path")
|
||||
messages = []
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
These tests verify end-to-end functionality with mocked CLI responses.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import anyio
|
||||
import pytest
|
||||
|
|
@ -52,9 +52,12 @@ class TestIntegration:
|
|||
"total_cost_usd": 0.001,
|
||||
}
|
||||
|
||||
mock_transport.receive_messages = mock_receive
|
||||
mock_transport.read_messages = mock_receive
|
||||
mock_transport.connect = AsyncMock()
|
||||
mock_transport.disconnect = AsyncMock()
|
||||
mock_transport.close = AsyncMock()
|
||||
mock_transport.end_input = AsyncMock()
|
||||
mock_transport.write = AsyncMock()
|
||||
mock_transport.is_ready = Mock(return_value=True)
|
||||
|
||||
# Run query
|
||||
messages = []
|
||||
|
|
@ -118,9 +121,12 @@ class TestIntegration:
|
|||
"total_cost_usd": 0.002,
|
||||
}
|
||||
|
||||
mock_transport.receive_messages = mock_receive
|
||||
mock_transport.read_messages = mock_receive
|
||||
mock_transport.connect = AsyncMock()
|
||||
mock_transport.disconnect = AsyncMock()
|
||||
mock_transport.close = AsyncMock()
|
||||
mock_transport.end_input = AsyncMock()
|
||||
mock_transport.write = AsyncMock()
|
||||
mock_transport.is_ready = Mock(return_value=True)
|
||||
|
||||
# Run query with tools enabled
|
||||
messages = []
|
||||
|
|
@ -185,9 +191,12 @@ class TestIntegration:
|
|||
},
|
||||
}
|
||||
|
||||
mock_transport.receive_messages = mock_receive
|
||||
mock_transport.read_messages = mock_receive
|
||||
mock_transport.connect = AsyncMock()
|
||||
mock_transport.disconnect = AsyncMock()
|
||||
mock_transport.close = AsyncMock()
|
||||
mock_transport.end_input = AsyncMock()
|
||||
mock_transport.write = AsyncMock()
|
||||
mock_transport.is_ready = Mock(return_value=True)
|
||||
|
||||
# Run query with continuation
|
||||
messages = []
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
"""Tests for ClaudeSDKClient streaming functionality and query() with async iterables."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import anyio
|
||||
import pytest
|
||||
|
|
@ -22,6 +23,63 @@ from claude_code_sdk import (
|
|||
from claude_code_sdk._internal.transport.subprocess_cli import SubprocessCLITransport
|
||||
|
||||
|
||||
def create_mock_transport(with_init_response=True):
|
||||
"""Create a properly configured mock transport.
|
||||
|
||||
Args:
|
||||
with_init_response: If True, automatically respond to initialization request
|
||||
"""
|
||||
mock_transport = AsyncMock()
|
||||
mock_transport.connect = AsyncMock()
|
||||
mock_transport.close = AsyncMock()
|
||||
mock_transport.end_input = AsyncMock()
|
||||
mock_transport.write = AsyncMock()
|
||||
mock_transport.is_ready = Mock(return_value=True)
|
||||
|
||||
# Track written messages to simulate control protocol responses
|
||||
written_messages = []
|
||||
|
||||
async def mock_write(data):
|
||||
written_messages.append(data)
|
||||
|
||||
mock_transport.write.side_effect = mock_write
|
||||
|
||||
# Default read_messages to handle control protocol
|
||||
async def control_protocol_generator():
|
||||
# Wait for initialization request if needed
|
||||
if with_init_response:
|
||||
# Wait a bit for the write to happen
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
# Check if initialization was requested
|
||||
for msg_str in written_messages:
|
||||
try:
|
||||
msg = json.loads(msg_str.strip())
|
||||
if (
|
||||
msg.get("type") == "control_request"
|
||||
and msg.get("request", {}).get("subtype") == "initialize"
|
||||
):
|
||||
# Send initialization response
|
||||
yield {
|
||||
"type": "control_response",
|
||||
"response": {
|
||||
"request_id": msg.get("request_id"),
|
||||
"subtype": "success",
|
||||
"commands": [],
|
||||
"output_style": "default",
|
||||
},
|
||||
}
|
||||
break
|
||||
except (json.JSONDecodeError, KeyError, AttributeError):
|
||||
pass
|
||||
|
||||
# Then end the stream
|
||||
return
|
||||
|
||||
mock_transport.read_messages = control_protocol_generator
|
||||
return mock_transport
|
||||
|
||||
|
||||
class TestClaudeSDKClientStreaming:
|
||||
"""Test ClaudeSDKClient streaming functionality."""
|
||||
|
||||
|
|
@ -32,7 +90,7 @@ class TestClaudeSDKClientStreaming:
|
|||
with patch(
|
||||
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
|
||||
) as mock_transport_class:
|
||||
mock_transport = AsyncMock()
|
||||
mock_transport = create_mock_transport()
|
||||
mock_transport_class.return_value = mock_transport
|
||||
|
||||
async with ClaudeSDKClient() as client:
|
||||
|
|
@ -41,7 +99,7 @@ class TestClaudeSDKClientStreaming:
|
|||
assert client._transport is mock_transport
|
||||
|
||||
# Verify disconnect was called on exit
|
||||
mock_transport.disconnect.assert_called_once()
|
||||
mock_transport.close.assert_called_once()
|
||||
|
||||
anyio.run(_test)
|
||||
|
||||
|
|
@ -52,7 +110,7 @@ class TestClaudeSDKClientStreaming:
|
|||
with patch(
|
||||
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
|
||||
) as mock_transport_class:
|
||||
mock_transport = AsyncMock()
|
||||
mock_transport = create_mock_transport()
|
||||
mock_transport_class.return_value = mock_transport
|
||||
|
||||
client = ClaudeSDKClient()
|
||||
|
|
@ -64,7 +122,7 @@ class TestClaudeSDKClientStreaming:
|
|||
|
||||
await client.disconnect()
|
||||
# Verify disconnect was called
|
||||
mock_transport.disconnect.assert_called_once()
|
||||
mock_transport.close.assert_called_once()
|
||||
assert client._transport is None
|
||||
|
||||
anyio.run(_test)
|
||||
|
|
@ -76,7 +134,7 @@ class TestClaudeSDKClientStreaming:
|
|||
with patch(
|
||||
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
|
||||
) as mock_transport_class:
|
||||
mock_transport = AsyncMock()
|
||||
mock_transport = create_mock_transport()
|
||||
mock_transport_class.return_value = mock_transport
|
||||
|
||||
client = ClaudeSDKClient()
|
||||
|
|
@ -95,7 +153,7 @@ class TestClaudeSDKClientStreaming:
|
|||
with patch(
|
||||
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
|
||||
) as mock_transport_class:
|
||||
mock_transport = AsyncMock()
|
||||
mock_transport = create_mock_transport()
|
||||
mock_transport_class.return_value = mock_transport
|
||||
|
||||
async def message_stream():
|
||||
|
|
@ -123,20 +181,30 @@ class TestClaudeSDKClientStreaming:
|
|||
with patch(
|
||||
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
|
||||
) as mock_transport_class:
|
||||
mock_transport = AsyncMock()
|
||||
mock_transport = create_mock_transport()
|
||||
mock_transport_class.return_value = mock_transport
|
||||
|
||||
async with ClaudeSDKClient() as client:
|
||||
await client.query("Test message")
|
||||
|
||||
# Verify send_request was called with correct format
|
||||
mock_transport.send_request.assert_called_once()
|
||||
call_args = mock_transport.send_request.call_args
|
||||
messages, options = call_args[0]
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["type"] == "user"
|
||||
assert messages[0]["message"]["content"] == "Test message"
|
||||
assert options["session_id"] == "default"
|
||||
# Verify write was called with correct format
|
||||
# Should have at least 2 writes: init request and user message
|
||||
assert mock_transport.write.call_count >= 2
|
||||
|
||||
# Find the user message in the write calls
|
||||
user_msg_found = False
|
||||
for call in mock_transport.write.call_args_list:
|
||||
data = call[0][0]
|
||||
try:
|
||||
msg = json.loads(data.strip())
|
||||
if msg.get("type") == "user":
|
||||
assert msg["message"]["content"] == "Test message"
|
||||
assert msg["session_id"] == "default"
|
||||
user_msg_found = True
|
||||
break
|
||||
except (json.JSONDecodeError, KeyError, AttributeError):
|
||||
pass
|
||||
assert user_msg_found, "User message not found in write calls"
|
||||
|
||||
anyio.run(_test)
|
||||
|
||||
|
|
@ -147,16 +215,25 @@ class TestClaudeSDKClientStreaming:
|
|||
with patch(
|
||||
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
|
||||
) as mock_transport_class:
|
||||
mock_transport = AsyncMock()
|
||||
mock_transport = create_mock_transport()
|
||||
mock_transport_class.return_value = mock_transport
|
||||
|
||||
async with ClaudeSDKClient() as client:
|
||||
await client.query("Test", session_id="custom-session")
|
||||
|
||||
call_args = mock_transport.send_request.call_args
|
||||
messages, options = call_args[0]
|
||||
assert messages[0]["session_id"] == "custom-session"
|
||||
assert options["session_id"] == "custom-session"
|
||||
# Find the user message with custom session ID
|
||||
session_found = False
|
||||
for call in mock_transport.write.call_args_list:
|
||||
data = call[0][0]
|
||||
try:
|
||||
msg = json.loads(data.strip())
|
||||
if msg.get("type") == "user":
|
||||
assert msg["session_id"] == "custom-session"
|
||||
session_found = True
|
||||
break
|
||||
except (json.JSONDecodeError, KeyError, AttributeError):
|
||||
pass
|
||||
assert session_found, "User message with custom session not found"
|
||||
|
||||
anyio.run(_test)
|
||||
|
||||
|
|
@ -177,11 +254,37 @@ class TestClaudeSDKClientStreaming:
|
|||
with patch(
|
||||
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
|
||||
) as mock_transport_class:
|
||||
mock_transport = AsyncMock()
|
||||
mock_transport = create_mock_transport()
|
||||
mock_transport_class.return_value = mock_transport
|
||||
|
||||
# Mock the message stream
|
||||
# Mock the message stream with control protocol support
|
||||
async def mock_receive():
|
||||
# First handle initialization
|
||||
await asyncio.sleep(0.01)
|
||||
written = mock_transport.write.call_args_list
|
||||
for call in written:
|
||||
data = call[0][0]
|
||||
try:
|
||||
msg = json.loads(data.strip())
|
||||
if (
|
||||
msg.get("type") == "control_request"
|
||||
and msg.get("request", {}).get("subtype")
|
||||
== "initialize"
|
||||
):
|
||||
yield {
|
||||
"type": "control_response",
|
||||
"response": {
|
||||
"request_id": msg.get("request_id"),
|
||||
"subtype": "success",
|
||||
"commands": [],
|
||||
"output_style": "default",
|
||||
},
|
||||
}
|
||||
break
|
||||
except (json.JSONDecodeError, KeyError, AttributeError):
|
||||
pass
|
||||
|
||||
# Then yield the actual messages
|
||||
yield {
|
||||
"type": "assistant",
|
||||
"message": {
|
||||
|
|
@ -195,7 +298,7 @@ class TestClaudeSDKClientStreaming:
|
|||
"message": {"role": "user", "content": "Hi there"},
|
||||
}
|
||||
|
||||
mock_transport.receive_messages = mock_receive
|
||||
mock_transport.read_messages = mock_receive
|
||||
|
||||
async with ClaudeSDKClient() as client:
|
||||
messages = []
|
||||
|
|
@ -220,7 +323,7 @@ class TestClaudeSDKClientStreaming:
|
|||
with patch(
|
||||
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
|
||||
) as mock_transport_class:
|
||||
mock_transport = AsyncMock()
|
||||
mock_transport = create_mock_transport()
|
||||
mock_transport_class.return_value = mock_transport
|
||||
|
||||
# Mock the message stream
|
||||
|
|
@ -255,7 +358,7 @@ class TestClaudeSDKClientStreaming:
|
|||
"model": "claude-opus-4-1-20250805",
|
||||
}
|
||||
|
||||
mock_transport.receive_messages = mock_receive
|
||||
mock_transport.read_messages = mock_receive
|
||||
|
||||
async with ClaudeSDKClient() as client:
|
||||
messages = []
|
||||
|
|
@ -276,7 +379,7 @@ class TestClaudeSDKClientStreaming:
|
|||
with patch(
|
||||
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
|
||||
) as mock_transport_class:
|
||||
mock_transport = AsyncMock()
|
||||
mock_transport = create_mock_transport()
|
||||
mock_transport_class.return_value = mock_transport
|
||||
|
||||
async with ClaudeSDKClient() as client:
|
||||
|
|
@ -308,7 +411,7 @@ class TestClaudeSDKClientStreaming:
|
|||
with patch(
|
||||
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
|
||||
) as mock_transport_class:
|
||||
mock_transport = AsyncMock()
|
||||
mock_transport = create_mock_transport()
|
||||
mock_transport_class.return_value = mock_transport
|
||||
|
||||
client = ClaudeSDKClient(options=options)
|
||||
|
|
@ -327,7 +430,7 @@ class TestClaudeSDKClientStreaming:
|
|||
with patch(
|
||||
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
|
||||
) as mock_transport_class:
|
||||
mock_transport = AsyncMock()
|
||||
mock_transport = create_mock_transport()
|
||||
mock_transport_class.return_value = mock_transport
|
||||
|
||||
# Mock receive to wait then yield messages
|
||||
|
|
@ -353,7 +456,7 @@ class TestClaudeSDKClientStreaming:
|
|||
"total_cost_usd": 0.001,
|
||||
}
|
||||
|
||||
mock_transport.receive_messages = mock_receive
|
||||
mock_transport.read_messages = mock_receive
|
||||
|
||||
async with ClaudeSDKClient() as client:
|
||||
# Helper to get next message
|
||||
|
|
@ -476,7 +579,7 @@ class TestClaudeSDKClientEdgeCases:
|
|||
with patch(
|
||||
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
|
||||
) as mock_transport_class:
|
||||
mock_transport = AsyncMock()
|
||||
mock_transport = create_mock_transport()
|
||||
mock_transport_class.return_value = mock_transport
|
||||
|
||||
client = ClaudeSDKClient()
|
||||
|
|
@ -506,7 +609,7 @@ class TestClaudeSDKClientEdgeCases:
|
|||
with patch(
|
||||
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
|
||||
) as mock_transport_class:
|
||||
mock_transport = AsyncMock()
|
||||
mock_transport = create_mock_transport()
|
||||
mock_transport_class.return_value = mock_transport
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
|
|
@ -514,7 +617,7 @@ class TestClaudeSDKClientEdgeCases:
|
|||
raise ValueError("Test error")
|
||||
|
||||
# Disconnect should still be called
|
||||
mock_transport.disconnect.assert_called_once()
|
||||
mock_transport.close.assert_called_once()
|
||||
|
||||
anyio.run(_test)
|
||||
|
||||
|
|
@ -525,7 +628,7 @@ class TestClaudeSDKClientEdgeCases:
|
|||
with patch(
|
||||
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
|
||||
) as mock_transport_class:
|
||||
mock_transport = AsyncMock()
|
||||
mock_transport = create_mock_transport()
|
||||
mock_transport_class.return_value = mock_transport
|
||||
|
||||
# Mock the message stream
|
||||
|
|
@ -557,7 +660,7 @@ class TestClaudeSDKClientEdgeCases:
|
|||
"total_cost_usd": 0.001,
|
||||
}
|
||||
|
||||
mock_transport.receive_messages = mock_receive
|
||||
mock_transport.read_messages = mock_receive
|
||||
|
||||
async with ClaudeSDKClient() as client:
|
||||
# Test list comprehension pattern from docstring
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue