mirror of
https://github.com/anthropics/claude-code-sdk-python.git
synced 2025-12-23 09:19:52 +00:00
Merge pull request #75 from anthropics/dickson/streaming
Implement streaming
This commit is contained in:
commit
c4384ead71
13 changed files with 1937 additions and 132 deletions
27
CLAUDE.md
Normal file
27
CLAUDE.md
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
# Workflow
|
||||
|
||||
```bash
|
||||
# Lint and style
|
||||
# Check for issues and fix automatically
|
||||
python -m ruff check src/ tests/ --fix
|
||||
python -m ruff format src/ tests/
|
||||
|
||||
# Typecheck (only done for src/)
|
||||
python -m mypy src/
|
||||
|
||||
# Run all tests
|
||||
python -m pytest tests/
|
||||
|
||||
# Run specific test file
|
||||
python -m pytest tests/test_client.py
|
||||
```
|
||||
|
||||
# Codebase Structure
|
||||
|
||||
- `src/claude_code_sdk/` - Main package
|
||||
- `client.py` - ClaudeSDKClient for interactive sessions
|
||||
- `query.py` - One-shot query function
|
||||
- `types.py` - Type definitions
|
||||
- `_internal/` - Internal implementation details
|
||||
- `transport/subprocess_cli.py` - CLI subprocess management
|
||||
- `message_parser.py` - Message parsing logic
|
||||
394
examples/streaming_mode.py
Executable file
394
examples/streaming_mode.py
Executable file
|
|
@ -0,0 +1,394 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Comprehensive examples of using ClaudeSDKClient for streaming mode.
|
||||
|
||||
This file demonstrates various patterns for building applications with
|
||||
the ClaudeSDKClient streaming interface.
|
||||
|
||||
The queries are intentionally simplistic. In reality, a query can be a more
|
||||
complex task that Claude SDK uses its agentic capabilities and tools (e.g. run
|
||||
bash commands, edit files, search the web, fetch web content) to accomplish.
|
||||
|
||||
Usage:
|
||||
./examples/streaming_mode.py - List the examples
|
||||
./examples/streaming_mode.py all - Run all examples
|
||||
./examples/streaming_mode.py basic_streaming - Run a specific example
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import sys
|
||||
|
||||
from claude_code_sdk import (
|
||||
AssistantMessage,
|
||||
ClaudeCodeOptions,
|
||||
ClaudeSDKClient,
|
||||
CLIConnectionError,
|
||||
ResultMessage,
|
||||
SystemMessage,
|
||||
TextBlock,
|
||||
UserMessage,
|
||||
)
|
||||
|
||||
|
||||
def display_message(msg):
|
||||
"""Standardized message display function.
|
||||
|
||||
- UserMessage: "User: <content>"
|
||||
- AssistantMessage: "Claude: <content>"
|
||||
- SystemMessage: ignored
|
||||
- ResultMessage: "Result ended" + cost if available
|
||||
"""
|
||||
if isinstance(msg, UserMessage):
|
||||
for block in msg.content:
|
||||
if isinstance(block, TextBlock):
|
||||
print(f"User: {block.text}")
|
||||
elif isinstance(msg, AssistantMessage):
|
||||
for block in msg.content:
|
||||
if isinstance(block, TextBlock):
|
||||
print(f"Claude: {block.text}")
|
||||
elif isinstance(msg, SystemMessage):
|
||||
# Ignore system messages
|
||||
pass
|
||||
elif isinstance(msg, ResultMessage):
|
||||
print("Result ended")
|
||||
|
||||
|
||||
async def example_basic_streaming():
|
||||
"""Basic streaming with context manager."""
|
||||
print("=== Basic Streaming Example ===")
|
||||
|
||||
async with ClaudeSDKClient() as client:
|
||||
print("User: What is 2+2?")
|
||||
await client.query("What is 2+2?")
|
||||
|
||||
# Receive complete response using the helper method
|
||||
async for msg in client.receive_response():
|
||||
display_message(msg)
|
||||
|
||||
print("\n")
|
||||
|
||||
|
||||
async def example_multi_turn_conversation():
|
||||
"""Multi-turn conversation using receive_response helper."""
|
||||
print("=== Multi-Turn Conversation Example ===")
|
||||
|
||||
async with ClaudeSDKClient() as client:
|
||||
# First turn
|
||||
print("User: What's the capital of France?")
|
||||
await client.query("What's the capital of France?")
|
||||
|
||||
# Extract and print response
|
||||
async for msg in client.receive_response():
|
||||
display_message(msg)
|
||||
|
||||
# Second turn - follow-up
|
||||
print("\nUser: What's the population of that city?")
|
||||
await client.query("What's the population of that city?")
|
||||
|
||||
async for msg in client.receive_response():
|
||||
display_message(msg)
|
||||
|
||||
print("\n")
|
||||
|
||||
|
||||
async def example_concurrent_responses():
|
||||
"""Handle responses while sending new messages."""
|
||||
print("=== Concurrent Send/Receive Example ===")
|
||||
|
||||
async with ClaudeSDKClient() as client:
|
||||
# Background task to continuously receive messages
|
||||
async def receive_messages():
|
||||
async for message in client.receive_messages():
|
||||
display_message(message)
|
||||
|
||||
# Start receiving in background
|
||||
receive_task = asyncio.create_task(receive_messages())
|
||||
|
||||
# Send multiple messages with delays
|
||||
questions = [
|
||||
"What is 2 + 2?",
|
||||
"What is the square root of 144?",
|
||||
"What is 10% of 80?",
|
||||
]
|
||||
|
||||
for question in questions:
|
||||
print(f"\nUser: {question}")
|
||||
await client.query(question)
|
||||
await asyncio.sleep(3) # Wait between messages
|
||||
|
||||
# Give time for final responses
|
||||
await asyncio.sleep(2)
|
||||
|
||||
# Clean up
|
||||
receive_task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await receive_task
|
||||
|
||||
print("\n")
|
||||
|
||||
|
||||
async def example_with_interrupt():
|
||||
"""Demonstrate interrupt capability."""
|
||||
print("=== Interrupt Example ===")
|
||||
print("IMPORTANT: Interrupts require active message consumption.")
|
||||
|
||||
async with ClaudeSDKClient() as client:
|
||||
# Start a long-running task
|
||||
print("\nUser: Count from 1 to 100 slowly")
|
||||
await client.query(
|
||||
"Count from 1 to 100 slowly, with a brief pause between each number"
|
||||
)
|
||||
|
||||
# Create a background task to consume messages
|
||||
messages_received = []
|
||||
interrupt_sent = False
|
||||
|
||||
async def consume_messages():
|
||||
"""Consume messages in the background to enable interrupt processing."""
|
||||
async for message in client.receive_messages():
|
||||
messages_received.append(message)
|
||||
if isinstance(message, AssistantMessage):
|
||||
for block in message.content:
|
||||
if isinstance(block, TextBlock):
|
||||
# Print first few numbers
|
||||
print(f"Claude: {block.text[:50]}...")
|
||||
elif isinstance(message, ResultMessage):
|
||||
display_message(message)
|
||||
if interrupt_sent:
|
||||
break
|
||||
|
||||
# Start consuming messages in the background
|
||||
consume_task = asyncio.create_task(consume_messages())
|
||||
|
||||
# Wait 2 seconds then send interrupt
|
||||
await asyncio.sleep(2)
|
||||
print("\n[After 2 seconds, sending interrupt...]")
|
||||
interrupt_sent = True
|
||||
await client.interrupt()
|
||||
|
||||
# Wait for the consume task to finish processing the interrupt
|
||||
await consume_task
|
||||
|
||||
# Send new instruction after interrupt
|
||||
print("\nUser: Never mind, just tell me a quick joke")
|
||||
await client.query("Never mind, just tell me a quick joke")
|
||||
|
||||
# Get the joke
|
||||
async for msg in client.receive_response():
|
||||
display_message(msg)
|
||||
|
||||
print("\n")
|
||||
|
||||
|
||||
async def example_manual_message_handling():
|
||||
"""Manually handle message stream for custom logic."""
|
||||
print("=== Manual Message Handling Example ===")
|
||||
|
||||
async with ClaudeSDKClient() as client:
|
||||
await client.query(
|
||||
"List 5 programming languages and their main use cases"
|
||||
)
|
||||
|
||||
# Manually process messages with custom logic
|
||||
languages_found = []
|
||||
|
||||
async for message in client.receive_messages():
|
||||
if isinstance(message, AssistantMessage):
|
||||
for block in message.content:
|
||||
if isinstance(block, TextBlock):
|
||||
text = block.text
|
||||
print(f"Claude: {text}")
|
||||
# Custom logic: extract language names
|
||||
for lang in [
|
||||
"Python",
|
||||
"JavaScript",
|
||||
"Java",
|
||||
"C++",
|
||||
"Go",
|
||||
"Rust",
|
||||
"Ruby",
|
||||
]:
|
||||
if lang in text and lang not in languages_found:
|
||||
languages_found.append(lang)
|
||||
print(f"Found language: {lang}")
|
||||
elif isinstance(message, ResultMessage):
|
||||
display_message(message)
|
||||
print(f"Total languages mentioned: {len(languages_found)}")
|
||||
break
|
||||
|
||||
print("\n")
|
||||
|
||||
|
||||
async def example_with_options():
|
||||
"""Use ClaudeCodeOptions to configure the client."""
|
||||
print("=== Custom Options Example ===")
|
||||
|
||||
# Configure options
|
||||
options = ClaudeCodeOptions(
|
||||
allowed_tools=["Read", "Write"], # Allow file operations
|
||||
max_thinking_tokens=10000,
|
||||
system_prompt="You are a helpful coding assistant.",
|
||||
)
|
||||
|
||||
async with ClaudeSDKClient(options=options) as client:
|
||||
print("User: Create a simple hello.txt file with a greeting message")
|
||||
await client.query(
|
||||
"Create a simple hello.txt file with a greeting message"
|
||||
)
|
||||
|
||||
tool_uses = []
|
||||
async for msg in client.receive_response():
|
||||
if isinstance(msg, AssistantMessage):
|
||||
display_message(msg)
|
||||
for block in msg.content:
|
||||
if hasattr(block, "name") and not isinstance(
|
||||
block, TextBlock
|
||||
): # ToolUseBlock
|
||||
tool_uses.append(getattr(block, "name", ""))
|
||||
else:
|
||||
display_message(msg)
|
||||
|
||||
if tool_uses:
|
||||
print(f"Tools used: {', '.join(tool_uses)}")
|
||||
|
||||
print("\n")
|
||||
|
||||
|
||||
async def example_async_iterable_prompt():
|
||||
"""Demonstrate send_message with async iterable."""
|
||||
print("=== Async Iterable Prompt Example ===")
|
||||
|
||||
async def create_message_stream():
|
||||
"""Generate a stream of messages."""
|
||||
print("User: Hello! I have multiple questions.")
|
||||
yield {
|
||||
"type": "user",
|
||||
"message": {"role": "user", "content": "Hello! I have multiple questions."},
|
||||
"parent_tool_use_id": None,
|
||||
"session_id": "qa-session",
|
||||
}
|
||||
|
||||
print("User: First, what's the capital of Japan?")
|
||||
yield {
|
||||
"type": "user",
|
||||
"message": {
|
||||
"role": "user",
|
||||
"content": "First, what's the capital of Japan?",
|
||||
},
|
||||
"parent_tool_use_id": None,
|
||||
"session_id": "qa-session",
|
||||
}
|
||||
|
||||
print("User: Second, what's 15% of 200?")
|
||||
yield {
|
||||
"type": "user",
|
||||
"message": {"role": "user", "content": "Second, what's 15% of 200?"},
|
||||
"parent_tool_use_id": None,
|
||||
"session_id": "qa-session",
|
||||
}
|
||||
|
||||
async with ClaudeSDKClient() as client:
|
||||
# Send async iterable of messages
|
||||
await client.query(create_message_stream())
|
||||
|
||||
# Receive the three responses
|
||||
async for msg in client.receive_response():
|
||||
display_message(msg)
|
||||
async for msg in client.receive_response():
|
||||
display_message(msg)
|
||||
async for msg in client.receive_response():
|
||||
display_message(msg)
|
||||
|
||||
print("\n")
|
||||
|
||||
|
||||
async def example_error_handling():
|
||||
"""Demonstrate proper error handling."""
|
||||
print("=== Error Handling Example ===")
|
||||
|
||||
client = ClaudeSDKClient()
|
||||
|
||||
try:
|
||||
await client.connect()
|
||||
|
||||
# Send a message that will take time to process
|
||||
print("User: Run a bash sleep command for 60 seconds")
|
||||
await client.query("Run a bash sleep command for 60 seconds")
|
||||
|
||||
# Try to receive response with a short timeout
|
||||
try:
|
||||
messages = []
|
||||
async with asyncio.timeout(10.0):
|
||||
async for msg in client.receive_response():
|
||||
messages.append(msg)
|
||||
if isinstance(msg, AssistantMessage):
|
||||
for block in msg.content:
|
||||
if isinstance(block, TextBlock):
|
||||
print(f"Claude: {block.text[:50]}...")
|
||||
elif isinstance(msg, ResultMessage):
|
||||
display_message(msg)
|
||||
break
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
print(
|
||||
"\nResponse timeout after 10 seconds - demonstrating graceful handling"
|
||||
)
|
||||
print(f"Received {len(messages)} messages before timeout")
|
||||
|
||||
except CLIConnectionError as e:
|
||||
print(f"Connection error: {e}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Unexpected error: {e}")
|
||||
|
||||
finally:
|
||||
# Always disconnect
|
||||
await client.disconnect()
|
||||
|
||||
print("\n")
|
||||
|
||||
|
||||
async def main():
|
||||
"""Run all examples or a specific example based on command line argument."""
|
||||
examples = {
|
||||
"basic_streaming": example_basic_streaming,
|
||||
"multi_turn_conversation": example_multi_turn_conversation,
|
||||
"concurrent_responses": example_concurrent_responses,
|
||||
"with_interrupt": example_with_interrupt,
|
||||
"manual_message_handling": example_manual_message_handling,
|
||||
"with_options": example_with_options,
|
||||
"async_iterable_prompt": example_async_iterable_prompt,
|
||||
"error_handling": example_error_handling,
|
||||
}
|
||||
|
||||
if len(sys.argv) < 2:
|
||||
# List available examples
|
||||
print("Usage: python streaming_mode.py <example_name>")
|
||||
print("\nAvailable examples:")
|
||||
print(" all - Run all examples")
|
||||
for name in examples:
|
||||
print(f" {name}")
|
||||
sys.exit(0)
|
||||
|
||||
example_name = sys.argv[1]
|
||||
|
||||
if example_name == "all":
|
||||
# Run all examples
|
||||
for example in examples.values():
|
||||
await example()
|
||||
print("-" * 50 + "\n")
|
||||
elif example_name in examples:
|
||||
# Run specific example
|
||||
await examples[example_name]()
|
||||
else:
|
||||
print(f"Error: Unknown example '{example_name}'")
|
||||
print("\nAvailable examples:")
|
||||
print(" all - Run all examples")
|
||||
for name in examples:
|
||||
print(f" {name}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
228
examples/streaming_mode_ipython.py
Normal file
228
examples/streaming_mode_ipython.py
Normal file
|
|
@ -0,0 +1,228 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
IPython-friendly code snippets for ClaudeSDKClient streaming mode.
|
||||
|
||||
These examples are designed to be copy-pasted directly into IPython.
|
||||
Each example is self-contained and can be run independently.
|
||||
|
||||
The queries are intentionally simplistic. In reality, a query can be a more
|
||||
complex task that Claude SDK uses its agentic capabilities and tools (e.g. run
|
||||
bash commands, edit files, search the web, fetch web content) to accomplish.
|
||||
"""
|
||||
|
||||
# ============================================================================
|
||||
# BASIC STREAMING
|
||||
# ============================================================================
|
||||
|
||||
from claude_code_sdk import ClaudeSDKClient, AssistantMessage, TextBlock, ResultMessage
|
||||
|
||||
async with ClaudeSDKClient() as client:
|
||||
print("User: What is 2+2?")
|
||||
await client.query("What is 2+2?")
|
||||
|
||||
async for msg in client.receive_response():
|
||||
if isinstance(msg, AssistantMessage):
|
||||
for block in msg.content:
|
||||
if isinstance(block, TextBlock):
|
||||
print(f"Claude: {block.text}")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# STREAMING WITH REAL-TIME DISPLAY
|
||||
# ============================================================================
|
||||
|
||||
import asyncio
|
||||
from claude_code_sdk import ClaudeSDKClient, AssistantMessage, TextBlock
|
||||
|
||||
async with ClaudeSDKClient() as client:
|
||||
async def send_and_receive(prompt):
|
||||
print(f"User: {prompt}")
|
||||
await client.query(prompt)
|
||||
async for msg in client.receive_response():
|
||||
if isinstance(msg, AssistantMessage):
|
||||
for block in msg.content:
|
||||
if isinstance(block, TextBlock):
|
||||
print(f"Claude: {block.text}")
|
||||
|
||||
await send_and_receive("Tell me a short joke")
|
||||
print("\n---\n")
|
||||
await send_and_receive("Now tell me a fun fact")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# PERSISTENT CLIENT FOR MULTIPLE QUESTIONS
|
||||
# ============================================================================
|
||||
|
||||
from claude_code_sdk import ClaudeSDKClient, AssistantMessage, TextBlock
|
||||
|
||||
# Create client
|
||||
client = ClaudeSDKClient()
|
||||
await client.connect()
|
||||
|
||||
|
||||
# Helper to get response
|
||||
async def get_response():
|
||||
async for msg in client.receive_response():
|
||||
if isinstance(msg, AssistantMessage):
|
||||
for block in msg.content:
|
||||
if isinstance(block, TextBlock):
|
||||
print(f"Claude: {block.text}")
|
||||
|
||||
|
||||
# Use it multiple times
|
||||
print("User: What's 2+2?")
|
||||
await client.query("What's 2+2?")
|
||||
await get_response()
|
||||
|
||||
print("User: What's 10*10?")
|
||||
await client.query("What's 10*10?")
|
||||
await get_response()
|
||||
|
||||
# Don't forget to disconnect when done
|
||||
await client.disconnect()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# WITH INTERRUPT CAPABILITY
|
||||
# ============================================================================
|
||||
# IMPORTANT: Interrupts require active message consumption. You must be
|
||||
# consuming messages from the client for the interrupt to be processed.
|
||||
|
||||
import asyncio
|
||||
from claude_code_sdk import ClaudeSDKClient, AssistantMessage, TextBlock, ResultMessage
|
||||
|
||||
async with ClaudeSDKClient() as client:
|
||||
print("\n--- Sending initial message ---\n")
|
||||
|
||||
# Send a long-running task
|
||||
print("User: Count from 1 to 100, run bash sleep for 1 second in between")
|
||||
await client.query("Count from 1 to 100, run bash sleep for 1 second in between")
|
||||
|
||||
# Create a background task to consume messages
|
||||
messages_received = []
|
||||
interrupt_sent = False
|
||||
|
||||
async def consume_messages():
|
||||
async for msg in client.receive_messages():
|
||||
messages_received.append(msg)
|
||||
if isinstance(msg, AssistantMessage):
|
||||
for block in msg.content:
|
||||
if isinstance(block, TextBlock):
|
||||
print(f"Claude: {block.text}")
|
||||
|
||||
# Check if we got a result after interrupt
|
||||
if isinstance(msg, ResultMessage) and interrupt_sent:
|
||||
break
|
||||
|
||||
# Start consuming messages in the background
|
||||
consume_task = asyncio.create_task(consume_messages())
|
||||
|
||||
# Wait a bit then send interrupt
|
||||
await asyncio.sleep(10)
|
||||
print("\n--- Sending interrupt ---\n")
|
||||
interrupt_sent = True
|
||||
await client.interrupt()
|
||||
|
||||
# Wait for the consume task to finish
|
||||
await consume_task
|
||||
|
||||
# Send a new message after interrupt
|
||||
print("\n--- After interrupt, sending new message ---\n")
|
||||
await client.query("Just say 'Hello! I was interrupted.'")
|
||||
|
||||
async for msg in client.receive_response():
|
||||
if isinstance(msg, AssistantMessage):
|
||||
for block in msg.content:
|
||||
if isinstance(block, TextBlock):
|
||||
print(f"Claude: {block.text}")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# ERROR HANDLING PATTERN
|
||||
# ============================================================================
|
||||
|
||||
from claude_code_sdk import ClaudeSDKClient, AssistantMessage, TextBlock
|
||||
|
||||
try:
|
||||
async with ClaudeSDKClient() as client:
|
||||
print("User: Run a bash sleep command for 60 seconds")
|
||||
await client.query("Run a bash sleep command for 60 seconds")
|
||||
|
||||
# Timeout after 20 seconds
|
||||
messages = []
|
||||
async with asyncio.timeout(20.0):
|
||||
async for msg in client.receive_response():
|
||||
messages.append(msg)
|
||||
if isinstance(msg, AssistantMessage):
|
||||
for block in msg.content:
|
||||
if isinstance(block, TextBlock):
|
||||
print(f"Claude: {block.text}")
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
print("Request timed out after 20 seconds")
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# SENDING ASYNC ITERABLE OF MESSAGES
|
||||
# ============================================================================
|
||||
|
||||
from claude_code_sdk import ClaudeSDKClient, AssistantMessage, TextBlock
|
||||
|
||||
async def message_generator():
|
||||
"""Generate multiple messages as an async iterable."""
|
||||
print("User: I have two math questions.")
|
||||
yield {
|
||||
"type": "user",
|
||||
"message": {"role": "user", "content": "I have two math questions."},
|
||||
"parent_tool_use_id": None,
|
||||
"session_id": "math-session"
|
||||
}
|
||||
print("User: What is 25 * 4?")
|
||||
yield {
|
||||
"type": "user",
|
||||
"message": {"role": "user", "content": "What is 25 * 4?"},
|
||||
"parent_tool_use_id": None,
|
||||
"session_id": "math-session"
|
||||
}
|
||||
print("User: What is 100 / 5?")
|
||||
yield {
|
||||
"type": "user",
|
||||
"message": {"role": "user", "content": "What is 100 / 5?"},
|
||||
"parent_tool_use_id": None,
|
||||
"session_id": "math-session"
|
||||
}
|
||||
|
||||
async with ClaudeSDKClient() as client:
|
||||
# Send async iterable instead of string
|
||||
await client.query(message_generator())
|
||||
|
||||
async for msg in client.receive_response():
|
||||
if isinstance(msg, AssistantMessage):
|
||||
for block in msg.content:
|
||||
if isinstance(block, TextBlock):
|
||||
print(f"Claude: {block.text}")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# COLLECTING ALL MESSAGES INTO A LIST
|
||||
# ============================================================================
|
||||
|
||||
from claude_code_sdk import ClaudeSDKClient, AssistantMessage, TextBlock, ResultMessage
|
||||
|
||||
async with ClaudeSDKClient() as client:
|
||||
print("User: What are the primary colors?")
|
||||
await client.query("What are the primary colors?")
|
||||
|
||||
# Collect all messages into a list
|
||||
messages = [msg async for msg in client.receive_response()]
|
||||
|
||||
# Process them afterwards
|
||||
for msg in messages:
|
||||
if isinstance(msg, AssistantMessage):
|
||||
for block in msg.content:
|
||||
if isinstance(block, TextBlock):
|
||||
print(f"Claude: {block.text}")
|
||||
elif isinstance(msg, ResultMessage):
|
||||
print(f"Total messages: {len(messages)}")
|
||||
|
|
@ -1,8 +1,5 @@
|
|||
"""Claude SDK for Python."""
|
||||
|
||||
import os
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from ._errors import (
|
||||
ClaudeSDKError,
|
||||
CLIConnectionError,
|
||||
|
|
@ -10,7 +7,8 @@ from ._errors import (
|
|||
CLINotFoundError,
|
||||
ProcessError,
|
||||
)
|
||||
from ._internal.client import InternalClient
|
||||
from .client import ClaudeSDKClient
|
||||
from .query import query
|
||||
from .types import (
|
||||
AssistantMessage,
|
||||
ClaudeCodeOptions,
|
||||
|
|
@ -29,8 +27,9 @@ from .types import (
|
|||
__version__ = "0.0.14"
|
||||
|
||||
__all__ = [
|
||||
# Main function
|
||||
# Main exports
|
||||
"query",
|
||||
"ClaudeSDKClient",
|
||||
# Types
|
||||
"PermissionMode",
|
||||
"McpServerConfig",
|
||||
|
|
@ -51,52 +50,3 @@ __all__ = [
|
|||
"ProcessError",
|
||||
"CLIJSONDecodeError",
|
||||
]
|
||||
|
||||
|
||||
async def query(
|
||||
*, prompt: str, options: ClaudeCodeOptions | None = None
|
||||
) -> AsyncIterator[Message]:
|
||||
"""
|
||||
Query Claude Code.
|
||||
|
||||
Python SDK for interacting with Claude Code.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to send to Claude
|
||||
options: Optional configuration (defaults to ClaudeCodeOptions() if None).
|
||||
Set options.permission_mode to control tool execution:
|
||||
- 'default': CLI prompts for dangerous tools
|
||||
- 'acceptEdits': Auto-accept file edits
|
||||
- 'bypassPermissions': Allow all tools (use with caution)
|
||||
Set options.cwd for working directory.
|
||||
|
||||
Yields:
|
||||
Messages from the conversation
|
||||
|
||||
|
||||
Example:
|
||||
```python
|
||||
# Simple usage
|
||||
async for message in query(prompt="Hello"):
|
||||
print(message)
|
||||
|
||||
# With options
|
||||
async for message in query(
|
||||
prompt="Hello",
|
||||
options=ClaudeCodeOptions(
|
||||
system_prompt="You are helpful",
|
||||
cwd="/home/user"
|
||||
)
|
||||
):
|
||||
print(message)
|
||||
```
|
||||
"""
|
||||
if options is None:
|
||||
options = ClaudeCodeOptions()
|
||||
|
||||
os.environ["CLAUDE_CODE_ENTRYPOINT"] = "sdk-py"
|
||||
|
||||
client = InternalClient()
|
||||
|
||||
async for message in client.process_query(prompt=prompt, options=options):
|
||||
yield message
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
"""Error types for Claude SDK."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
class ClaudeSDKError(Exception):
|
||||
"""Base exception for all Claude SDK errors."""
|
||||
|
|
@ -44,3 +46,11 @@ class CLIJSONDecodeError(ClaudeSDKError):
|
|||
self.line = line
|
||||
self.original_error = original_error
|
||||
super().__init__(f"Failed to decode JSON: {line[:100]}...")
|
||||
|
||||
|
||||
class MessageParseError(ClaudeSDKError):
|
||||
"""Raised when unable to parse a message from CLI output."""
|
||||
|
||||
def __init__(self, message: str, data: dict[str, Any] | None = None):
|
||||
self.data = data
|
||||
super().__init__(message)
|
||||
|
|
|
|||
|
|
@ -1,20 +1,10 @@
|
|||
"""Internal client implementation."""
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from collections.abc import AsyncIterable, AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from ..types import (
|
||||
AssistantMessage,
|
||||
ClaudeCodeOptions,
|
||||
ContentBlock,
|
||||
Message,
|
||||
ResultMessage,
|
||||
SystemMessage,
|
||||
TextBlock,
|
||||
ToolResultBlock,
|
||||
ToolUseBlock,
|
||||
UserMessage,
|
||||
)
|
||||
from ..types import ClaudeCodeOptions, Message
|
||||
from .message_parser import parse_message
|
||||
from .transport.subprocess_cli import SubprocessCLITransport
|
||||
|
||||
|
||||
|
|
@ -25,73 +15,19 @@ class InternalClient:
|
|||
"""Initialize the internal client."""
|
||||
|
||||
async def process_query(
|
||||
self, prompt: str, options: ClaudeCodeOptions
|
||||
self, prompt: str | AsyncIterable[dict[str, Any]], options: ClaudeCodeOptions
|
||||
) -> AsyncIterator[Message]:
|
||||
"""Process a query through transport."""
|
||||
|
||||
transport = SubprocessCLITransport(prompt=prompt, options=options)
|
||||
transport = SubprocessCLITransport(
|
||||
prompt=prompt, options=options, close_stdin_after_prompt=True
|
||||
)
|
||||
|
||||
try:
|
||||
await transport.connect()
|
||||
|
||||
async for data in transport.receive_messages():
|
||||
message = self._parse_message(data)
|
||||
if message:
|
||||
yield message
|
||||
yield parse_message(data)
|
||||
|
||||
finally:
|
||||
await transport.disconnect()
|
||||
|
||||
def _parse_message(self, data: dict[str, Any]) -> Message | None:
|
||||
"""Parse message from CLI output, trusting the structure."""
|
||||
|
||||
match data["type"]:
|
||||
case "user":
|
||||
return UserMessage(content=data["message"]["content"])
|
||||
|
||||
case "assistant":
|
||||
content_blocks: list[ContentBlock] = []
|
||||
for block in data["message"]["content"]:
|
||||
match block["type"]:
|
||||
case "text":
|
||||
content_blocks.append(TextBlock(text=block["text"]))
|
||||
case "tool_use":
|
||||
content_blocks.append(
|
||||
ToolUseBlock(
|
||||
id=block["id"],
|
||||
name=block["name"],
|
||||
input=block["input"],
|
||||
)
|
||||
)
|
||||
case "tool_result":
|
||||
content_blocks.append(
|
||||
ToolResultBlock(
|
||||
tool_use_id=block["tool_use_id"],
|
||||
content=block.get("content"),
|
||||
is_error=block.get("is_error"),
|
||||
)
|
||||
)
|
||||
|
||||
return AssistantMessage(content=content_blocks)
|
||||
|
||||
case "system":
|
||||
return SystemMessage(
|
||||
subtype=data["subtype"],
|
||||
data=data,
|
||||
)
|
||||
|
||||
case "result":
|
||||
return ResultMessage(
|
||||
subtype=data["subtype"],
|
||||
duration_ms=data["duration_ms"],
|
||||
duration_api_ms=data["duration_api_ms"],
|
||||
is_error=data["is_error"],
|
||||
num_turns=data["num_turns"],
|
||||
session_id=data["session_id"],
|
||||
total_cost_usd=data.get("total_cost_usd"),
|
||||
usage=data.get("usage"),
|
||||
result=data.get("result"),
|
||||
)
|
||||
|
||||
case _:
|
||||
return None
|
||||
|
|
|
|||
114
src/claude_code_sdk/_internal/message_parser.py
Normal file
114
src/claude_code_sdk/_internal/message_parser.py
Normal file
|
|
@ -0,0 +1,114 @@
|
|||
"""Message parser for Claude Code SDK responses."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from .._errors import MessageParseError
|
||||
from ..types import (
|
||||
AssistantMessage,
|
||||
ContentBlock,
|
||||
Message,
|
||||
ResultMessage,
|
||||
SystemMessage,
|
||||
TextBlock,
|
||||
ToolResultBlock,
|
||||
ToolUseBlock,
|
||||
UserMessage,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_message(data: dict[str, Any]) -> Message:
|
||||
"""
|
||||
Parse message from CLI output into typed Message objects.
|
||||
|
||||
Args:
|
||||
data: Raw message dictionary from CLI output
|
||||
|
||||
Returns:
|
||||
Parsed Message object
|
||||
|
||||
Raises:
|
||||
MessageParseError: If parsing fails or message type is unrecognized
|
||||
"""
|
||||
if not isinstance(data, dict):
|
||||
raise MessageParseError(
|
||||
f"Invalid message data type (expected dict, got {type(data).__name__})",
|
||||
data,
|
||||
)
|
||||
|
||||
message_type = data.get("type")
|
||||
if not message_type:
|
||||
raise MessageParseError("Message missing 'type' field", data)
|
||||
|
||||
match message_type:
|
||||
case "user":
|
||||
try:
|
||||
return UserMessage(content=data["message"]["content"])
|
||||
except KeyError as e:
|
||||
raise MessageParseError(
|
||||
f"Missing required field in user message: {e}", data
|
||||
) from e
|
||||
|
||||
case "assistant":
|
||||
try:
|
||||
content_blocks: list[ContentBlock] = []
|
||||
for block in data["message"]["content"]:
|
||||
match block["type"]:
|
||||
case "text":
|
||||
content_blocks.append(TextBlock(text=block["text"]))
|
||||
case "tool_use":
|
||||
content_blocks.append(
|
||||
ToolUseBlock(
|
||||
id=block["id"],
|
||||
name=block["name"],
|
||||
input=block["input"],
|
||||
)
|
||||
)
|
||||
case "tool_result":
|
||||
content_blocks.append(
|
||||
ToolResultBlock(
|
||||
tool_use_id=block["tool_use_id"],
|
||||
content=block.get("content"),
|
||||
is_error=block.get("is_error"),
|
||||
)
|
||||
)
|
||||
|
||||
return AssistantMessage(content=content_blocks)
|
||||
except KeyError as e:
|
||||
raise MessageParseError(
|
||||
f"Missing required field in assistant message: {e}", data
|
||||
) from e
|
||||
|
||||
case "system":
|
||||
try:
|
||||
return SystemMessage(
|
||||
subtype=data["subtype"],
|
||||
data=data,
|
||||
)
|
||||
except KeyError as e:
|
||||
raise MessageParseError(
|
||||
f"Missing required field in system message: {e}", data
|
||||
) from e
|
||||
|
||||
case "result":
|
||||
try:
|
||||
return ResultMessage(
|
||||
subtype=data["subtype"],
|
||||
duration_ms=data["duration_ms"],
|
||||
duration_api_ms=data["duration_api_ms"],
|
||||
is_error=data["is_error"],
|
||||
num_turns=data["num_turns"],
|
||||
session_id=data["session_id"],
|
||||
total_cost_usd=data.get("total_cost_usd"),
|
||||
usage=data.get("usage"),
|
||||
result=data.get("result"),
|
||||
)
|
||||
except KeyError as e:
|
||||
raise MessageParseError(
|
||||
f"Missing required field in result message: {e}", data
|
||||
) from e
|
||||
|
||||
case _:
|
||||
raise MessageParseError(f"Unknown message type: {message_type}", data)
|
||||
|
|
@ -4,14 +4,14 @@ import json
|
|||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from collections.abc import AsyncIterator
|
||||
from collections.abc import AsyncIterable, AsyncIterator
|
||||
from pathlib import Path
|
||||
from subprocess import PIPE
|
||||
from typing import Any
|
||||
|
||||
import anyio
|
||||
from anyio.abc import Process
|
||||
from anyio.streams.text import TextReceiveStream
|
||||
from anyio.streams.text import TextReceiveStream, TextSendStream
|
||||
|
||||
from ..._errors import CLIConnectionError, CLINotFoundError, ProcessError
|
||||
from ..._errors import CLIJSONDecodeError as SDKJSONDecodeError
|
||||
|
|
@ -28,17 +28,23 @@ class SubprocessCLITransport(Transport):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
prompt: str,
|
||||
prompt: str | AsyncIterable[dict[str, Any]],
|
||||
options: ClaudeCodeOptions,
|
||||
cli_path: str | Path | None = None,
|
||||
close_stdin_after_prompt: bool = False,
|
||||
):
|
||||
self._prompt = prompt
|
||||
self._is_streaming = not isinstance(prompt, str)
|
||||
self._options = options
|
||||
self._cli_path = str(cli_path) if cli_path else self._find_cli()
|
||||
self._cwd = str(options.cwd) if options.cwd else None
|
||||
self._process: Process | None = None
|
||||
self._stdout_stream: TextReceiveStream | None = None
|
||||
self._stderr_stream: TextReceiveStream | None = None
|
||||
self._stdin_stream: TextSendStream | None = None
|
||||
self._pending_control_responses: dict[str, dict[str, Any]] = {}
|
||||
self._request_counter = 0
|
||||
self._close_stdin_after_prompt = close_stdin_after_prompt
|
||||
|
||||
def _find_cli(self) -> str:
|
||||
"""Find Claude Code CLI binary."""
|
||||
|
|
@ -116,7 +122,14 @@ class SubprocessCLITransport(Transport):
|
|||
["--mcp-config", json.dumps({"mcpServers": self._options.mcp_servers})]
|
||||
)
|
||||
|
||||
cmd.extend(["--print", self._prompt])
|
||||
# Add prompt handling based on mode
|
||||
if self._is_streaming:
|
||||
# Streaming mode: use --input-format stream-json
|
||||
cmd.extend(["--input-format", "stream-json"])
|
||||
else:
|
||||
# String mode: use --print with the prompt
|
||||
cmd.extend(["--print", str(self._prompt)])
|
||||
|
||||
return cmd
|
||||
|
||||
async def connect(self) -> None:
|
||||
|
|
@ -126,9 +139,10 @@ class SubprocessCLITransport(Transport):
|
|||
|
||||
cmd = self._build_command()
|
||||
try:
|
||||
# Enable stdin pipe for both modes (but we'll close it for string mode)
|
||||
self._process = await anyio.open_process(
|
||||
cmd,
|
||||
stdin=None,
|
||||
stdin=PIPE,
|
||||
stdout=PIPE,
|
||||
stderr=PIPE,
|
||||
cwd=self._cwd,
|
||||
|
|
@ -140,6 +154,20 @@ class SubprocessCLITransport(Transport):
|
|||
if self._process.stderr:
|
||||
self._stderr_stream = TextReceiveStream(self._process.stderr)
|
||||
|
||||
# Handle stdin based on mode
|
||||
if self._is_streaming:
|
||||
# Streaming mode: keep stdin open and start streaming task
|
||||
if self._process.stdin:
|
||||
self._stdin_stream = TextSendStream(self._process.stdin)
|
||||
# Start streaming messages to stdin in background
|
||||
import asyncio
|
||||
|
||||
asyncio.create_task(self._stream_to_stdin())
|
||||
else:
|
||||
# String mode: close stdin immediately (backward compatible)
|
||||
if self._process.stdin:
|
||||
await self._process.stdin.aclose()
|
||||
|
||||
except FileNotFoundError as e:
|
||||
# Check if the error comes from the working directory or the CLI
|
||||
if self._cwd and not Path(self._cwd).exists():
|
||||
|
|
@ -169,9 +197,50 @@ class SubprocessCLITransport(Transport):
|
|||
self._process = None
|
||||
self._stdout_stream = None
|
||||
self._stderr_stream = None
|
||||
self._stdin_stream = None
|
||||
|
||||
async def send_request(self, messages: list[Any], options: dict[str, Any]) -> None:
|
||||
"""Not used for CLI transport - args passed via command line."""
|
||||
"""Send additional messages in streaming mode."""
|
||||
if not self._is_streaming:
|
||||
raise CLIConnectionError("send_request only works in streaming mode")
|
||||
|
||||
if not self._stdin_stream:
|
||||
raise CLIConnectionError("stdin not available - stream may have ended")
|
||||
|
||||
# Send each message as a user message
|
||||
for message in messages:
|
||||
# Ensure message has required structure
|
||||
if not isinstance(message, dict):
|
||||
message = {
|
||||
"type": "user",
|
||||
"message": {"role": "user", "content": str(message)},
|
||||
"parent_tool_use_id": None,
|
||||
"session_id": options.get("session_id", "default"),
|
||||
}
|
||||
|
||||
await self._stdin_stream.send(json.dumps(message) + "\n")
|
||||
|
||||
async def _stream_to_stdin(self) -> None:
|
||||
"""Stream messages to stdin for streaming mode."""
|
||||
if not self._stdin_stream or not isinstance(self._prompt, AsyncIterable):
|
||||
return
|
||||
|
||||
try:
|
||||
async for message in self._prompt:
|
||||
if not self._stdin_stream:
|
||||
break
|
||||
await self._stdin_stream.send(json.dumps(message) + "\n")
|
||||
|
||||
# Close stdin after prompt if requested (e.g., for query() one-shot mode)
|
||||
if self._close_stdin_after_prompt and self._stdin_stream:
|
||||
await self._stdin_stream.aclose()
|
||||
self._stdin_stream = None
|
||||
# Otherwise keep stdin open for send_request (ClaudeSDKClient interactive mode)
|
||||
except Exception as e:
|
||||
logger.debug(f"Error streaming to stdin: {e}")
|
||||
if self._stdin_stream:
|
||||
await self._stdin_stream.aclose()
|
||||
self._stdin_stream = None
|
||||
|
||||
async def receive_messages(self) -> AsyncIterator[dict[str, Any]]:
|
||||
"""Receive messages from CLI."""
|
||||
|
|
@ -213,11 +282,24 @@ class SubprocessCLITransport(Transport):
|
|||
try:
|
||||
data = json.loads(json_buffer)
|
||||
json_buffer = ""
|
||||
|
||||
# Handle control responses separately
|
||||
if data.get("type") == "control_response":
|
||||
response = data.get("response", {})
|
||||
request_id = response.get("request_id")
|
||||
if request_id:
|
||||
# Store the response for the pending request
|
||||
self._pending_control_responses[request_id] = response
|
||||
continue
|
||||
|
||||
try:
|
||||
yield data
|
||||
except GeneratorExit:
|
||||
return
|
||||
except json.JSONDecodeError:
|
||||
# We are speculatively decoding the buffer until we get
|
||||
# a full JSON object. If there is an actual issue, we
|
||||
# raise an error after _MAX_BUFFER_SIZE.
|
||||
continue
|
||||
|
||||
except anyio.ClosedResourceError:
|
||||
|
|
@ -280,3 +362,45 @@ class SubprocessCLITransport(Transport):
|
|||
def is_connected(self) -> bool:
|
||||
"""Check if subprocess is running."""
|
||||
return self._process is not None and self._process.returncode is None
|
||||
|
||||
async def interrupt(self) -> None:
|
||||
"""Send interrupt control request (only works in streaming mode)."""
|
||||
if not self._is_streaming:
|
||||
raise CLIConnectionError(
|
||||
"Interrupt requires streaming mode (AsyncIterable prompt)"
|
||||
)
|
||||
|
||||
if not self._stdin_stream:
|
||||
raise CLIConnectionError("Not connected or stdin not available")
|
||||
|
||||
await self._send_control_request({"subtype": "interrupt"})
|
||||
|
||||
async def _send_control_request(self, request: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Send a control request and wait for response."""
|
||||
if not self._stdin_stream:
|
||||
raise CLIConnectionError("Stdin not available")
|
||||
|
||||
# Generate unique request ID
|
||||
self._request_counter += 1
|
||||
request_id = f"req_{self._request_counter}_{os.urandom(4).hex()}"
|
||||
|
||||
# Build control request
|
||||
control_request = {
|
||||
"type": "control_request",
|
||||
"request_id": request_id,
|
||||
"request": request,
|
||||
}
|
||||
|
||||
# Send request
|
||||
await self._stdin_stream.send(json.dumps(control_request) + "\n")
|
||||
|
||||
# Wait for response
|
||||
while request_id not in self._pending_control_responses:
|
||||
await anyio.sleep(0.1)
|
||||
|
||||
response = self._pending_control_responses.pop(request_id)
|
||||
|
||||
if response.get("subtype") == "error":
|
||||
raise CLIConnectionError(f"Control request failed: {response.get('error')}")
|
||||
|
||||
return response
|
||||
|
|
|
|||
226
src/claude_code_sdk/client.py
Normal file
226
src/claude_code_sdk/client.py
Normal file
|
|
@ -0,0 +1,226 @@
|
|||
"""Claude SDK Client for interacting with Claude Code."""
|
||||
|
||||
import os
|
||||
from collections.abc import AsyncIterable, AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from ._errors import CLIConnectionError
|
||||
from .types import ClaudeCodeOptions, Message, ResultMessage
|
||||
|
||||
|
||||
class ClaudeSDKClient:
|
||||
"""
|
||||
Client for bidirectional, interactive conversations with Claude Code.
|
||||
|
||||
This client provides full control over the conversation flow with support
|
||||
for streaming, interrupts, and dynamic message sending. For simple one-shot
|
||||
queries, consider using the query() function instead.
|
||||
|
||||
Key features:
|
||||
- **Bidirectional**: Send and receive messages at any time
|
||||
- **Stateful**: Maintains conversation context across messages
|
||||
- **Interactive**: Send follow-ups based on responses
|
||||
- **Control flow**: Support for interrupts and session management
|
||||
|
||||
When to use ClaudeSDKClient:
|
||||
- Building chat interfaces or conversational UIs
|
||||
- Interactive debugging or exploration sessions
|
||||
- Multi-turn conversations with context
|
||||
- When you need to react to Claude's responses
|
||||
- Real-time applications with user input
|
||||
- When you need interrupt capabilities
|
||||
|
||||
When to use query() instead:
|
||||
- Simple one-off questions
|
||||
- Batch processing of prompts
|
||||
- Fire-and-forget automation scripts
|
||||
- When all inputs are known upfront
|
||||
- Stateless operations
|
||||
|
||||
Example - Interactive conversation:
|
||||
```python
|
||||
# Automatically connects with empty stream for interactive use
|
||||
async with ClaudeSDKClient() as client:
|
||||
# Send initial message
|
||||
await client.query("Let's solve a math problem step by step")
|
||||
|
||||
# Receive and process response
|
||||
async for message in client.receive_messages():
|
||||
if "ready" in str(message.content).lower():
|
||||
break
|
||||
|
||||
# Send follow-up based on response
|
||||
await client.query("What's 15% of 80?")
|
||||
|
||||
# Continue conversation...
|
||||
# Automatically disconnects
|
||||
```
|
||||
|
||||
Example - With interrupt:
|
||||
```python
|
||||
async with ClaudeSDKClient() as client:
|
||||
# Start a long task
|
||||
await client.query("Count to 1000")
|
||||
|
||||
# Interrupt after 2 seconds
|
||||
await asyncio.sleep(2)
|
||||
await client.interrupt()
|
||||
|
||||
# Send new instruction
|
||||
await client.query("Never mind, what's 2+2?")
|
||||
```
|
||||
|
||||
Example - Manual connection:
|
||||
```python
|
||||
client = ClaudeSDKClient()
|
||||
|
||||
# Connect with initial message stream
|
||||
async def message_stream():
|
||||
yield {"type": "user", "message": {"role": "user", "content": "Hello"}}
|
||||
|
||||
await client.connect(message_stream())
|
||||
|
||||
# Send additional messages dynamically
|
||||
await client.query("What's the weather?")
|
||||
|
||||
async for message in client.receive_messages():
|
||||
print(message)
|
||||
|
||||
await client.disconnect()
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, options: ClaudeCodeOptions | None = None):
|
||||
"""Initialize Claude SDK client."""
|
||||
if options is None:
|
||||
options = ClaudeCodeOptions()
|
||||
self.options = options
|
||||
self._transport: Any | None = None
|
||||
os.environ["CLAUDE_CODE_ENTRYPOINT"] = "sdk-py-client"
|
||||
|
||||
async def connect(
|
||||
self, prompt: str | AsyncIterable[dict[str, Any]] | None = None
|
||||
) -> None:
|
||||
"""Connect to Claude with a prompt or message stream."""
|
||||
from ._internal.transport.subprocess_cli import SubprocessCLITransport
|
||||
|
||||
# Auto-connect with empty async iterable if no prompt is provided
|
||||
async def _empty_stream() -> AsyncIterator[dict[str, Any]]:
|
||||
# Never yields, but indicates that this function is an iterator and
|
||||
# keeps the connection open.
|
||||
# This yield is never reached but makes this an async generator
|
||||
return
|
||||
yield {} # type: ignore[unreachable]
|
||||
|
||||
self._transport = SubprocessCLITransport(
|
||||
prompt=_empty_stream() if prompt is None else prompt,
|
||||
options=self.options,
|
||||
)
|
||||
await self._transport.connect()
|
||||
|
||||
async def receive_messages(self) -> AsyncIterator[Message]:
|
||||
"""Receive all messages from Claude."""
|
||||
if not self._transport:
|
||||
raise CLIConnectionError("Not connected. Call connect() first.")
|
||||
|
||||
from ._internal.message_parser import parse_message
|
||||
|
||||
async for data in self._transport.receive_messages():
|
||||
yield parse_message(data)
|
||||
|
||||
async def query(
|
||||
self, prompt: str | AsyncIterable[dict[str, Any]], session_id: str = "default"
|
||||
) -> None:
|
||||
"""
|
||||
Send a new request in streaming mode.
|
||||
|
||||
Args:
|
||||
prompt: Either a string message or an async iterable of message dictionaries
|
||||
session_id: Session identifier for the conversation
|
||||
"""
|
||||
if not self._transport:
|
||||
raise CLIConnectionError("Not connected. Call connect() first.")
|
||||
|
||||
# Handle string prompts
|
||||
if isinstance(prompt, str):
|
||||
message = {
|
||||
"type": "user",
|
||||
"message": {"role": "user", "content": prompt},
|
||||
"parent_tool_use_id": None,
|
||||
"session_id": session_id,
|
||||
}
|
||||
await self._transport.send_request([message], {"session_id": session_id})
|
||||
else:
|
||||
# Handle AsyncIterable prompts
|
||||
messages = []
|
||||
async for msg in prompt:
|
||||
# Ensure session_id is set on each message
|
||||
if "session_id" not in msg:
|
||||
msg["session_id"] = session_id
|
||||
messages.append(msg)
|
||||
|
||||
if messages:
|
||||
await self._transport.send_request(messages, {"session_id": session_id})
|
||||
|
||||
async def interrupt(self) -> None:
|
||||
"""Send interrupt signal (only works with streaming mode)."""
|
||||
if not self._transport:
|
||||
raise CLIConnectionError("Not connected. Call connect() first.")
|
||||
await self._transport.interrupt()
|
||||
|
||||
async def receive_response(self) -> AsyncIterator[Message]:
|
||||
"""
|
||||
Receive messages from Claude until and including a ResultMessage.
|
||||
|
||||
This async iterator yields all messages in sequence and automatically terminates
|
||||
after yielding a ResultMessage (which indicates the response is complete).
|
||||
It's a convenience method over receive_messages() for single-response workflows.
|
||||
|
||||
**Stopping Behavior:**
|
||||
- Yields each message as it's received
|
||||
- Terminates immediately after yielding a ResultMessage
|
||||
- The ResultMessage IS included in the yielded messages
|
||||
- If no ResultMessage is received, the iterator continues indefinitely
|
||||
|
||||
Yields:
|
||||
Message: Each message received (UserMessage, AssistantMessage, SystemMessage, ResultMessage)
|
||||
|
||||
Example:
|
||||
```python
|
||||
async with ClaudeSDKClient() as client:
|
||||
await client.query("What's the capital of France?")
|
||||
|
||||
async for msg in client.receive_response():
|
||||
if isinstance(msg, AssistantMessage):
|
||||
for block in msg.content:
|
||||
if isinstance(block, TextBlock):
|
||||
print(f"Claude: {block.text}")
|
||||
elif isinstance(msg, ResultMessage):
|
||||
print(f"Cost: ${msg.total_cost_usd:.4f}")
|
||||
# Iterator will terminate after this message
|
||||
```
|
||||
|
||||
Note:
|
||||
To collect all messages: `messages = [msg async for msg in client.receive_response()]`
|
||||
The final message in the list will always be a ResultMessage.
|
||||
"""
|
||||
async for message in self.receive_messages():
|
||||
yield message
|
||||
if isinstance(message, ResultMessage):
|
||||
return
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Disconnect from Claude."""
|
||||
if self._transport:
|
||||
await self._transport.disconnect()
|
||||
self._transport = None
|
||||
|
||||
async def __aenter__(self) -> "ClaudeSDKClient":
|
||||
"""Enter async context - automatically connects with empty stream for interactive use."""
|
||||
await self.connect()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
|
||||
"""Exit async context - always disconnects."""
|
||||
await self.disconnect()
|
||||
return False
|
||||
102
src/claude_code_sdk/query.py
Normal file
102
src/claude_code_sdk/query.py
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
"""Query function for one-shot interactions with Claude Code."""
|
||||
|
||||
import os
|
||||
from collections.abc import AsyncIterable, AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from ._internal.client import InternalClient
|
||||
from .types import ClaudeCodeOptions, Message
|
||||
|
||||
|
||||
async def query(
|
||||
*,
|
||||
prompt: str | AsyncIterable[dict[str, Any]],
|
||||
options: ClaudeCodeOptions | None = None,
|
||||
) -> AsyncIterator[Message]:
|
||||
"""
|
||||
Query Claude Code for one-shot or unidirectional streaming interactions.
|
||||
|
||||
This function is ideal for simple, stateless queries where you don't need
|
||||
bidirectional communication or conversation management. For interactive,
|
||||
stateful conversations, use ClaudeSDKClient instead.
|
||||
|
||||
Key differences from ClaudeSDKClient:
|
||||
- **Unidirectional**: Send all messages upfront, receive all responses
|
||||
- **Stateless**: Each query is independent, no conversation state
|
||||
- **Simple**: Fire-and-forget style, no connection management
|
||||
- **No interrupts**: Cannot interrupt or send follow-up messages
|
||||
|
||||
When to use query():
|
||||
- Simple one-off questions ("What is 2+2?")
|
||||
- Batch processing of independent prompts
|
||||
- Code generation or analysis tasks
|
||||
- Automated scripts and CI/CD pipelines
|
||||
- When you know all inputs upfront
|
||||
|
||||
When to use ClaudeSDKClient:
|
||||
- Interactive conversations with follow-ups
|
||||
- Chat applications or REPL-like interfaces
|
||||
- When you need to send messages based on responses
|
||||
- When you need interrupt capabilities
|
||||
- Long-running sessions with state
|
||||
|
||||
Args:
|
||||
prompt: The prompt to send to Claude. Can be a string for single-shot queries
|
||||
or an AsyncIterable[dict] for streaming mode with continuous interaction.
|
||||
In streaming mode, each dict should have the structure:
|
||||
{
|
||||
"type": "user",
|
||||
"message": {"role": "user", "content": "..."},
|
||||
"parent_tool_use_id": None,
|
||||
"session_id": "..."
|
||||
}
|
||||
options: Optional configuration (defaults to ClaudeCodeOptions() if None).
|
||||
Set options.permission_mode to control tool execution:
|
||||
- 'default': CLI prompts for dangerous tools
|
||||
- 'acceptEdits': Auto-accept file edits
|
||||
- 'bypassPermissions': Allow all tools (use with caution)
|
||||
Set options.cwd for working directory.
|
||||
|
||||
Yields:
|
||||
Messages from the conversation
|
||||
|
||||
Example - Simple query:
|
||||
```python
|
||||
# One-off question
|
||||
async for message in query(prompt="What is the capital of France?"):
|
||||
print(message)
|
||||
```
|
||||
|
||||
Example - With options:
|
||||
```python
|
||||
# Code generation with specific settings
|
||||
async for message in query(
|
||||
prompt="Create a Python web server",
|
||||
options=ClaudeCodeOptions(
|
||||
system_prompt="You are an expert Python developer",
|
||||
cwd="/home/user/project"
|
||||
)
|
||||
):
|
||||
print(message)
|
||||
```
|
||||
|
||||
Example - Streaming mode (still unidirectional):
|
||||
```python
|
||||
async def prompts():
|
||||
yield {"type": "user", "message": {"role": "user", "content": "Hello"}}
|
||||
yield {"type": "user", "message": {"role": "user", "content": "How are you?"}}
|
||||
|
||||
# All prompts are sent, then all responses received
|
||||
async for message in query(prompt=prompts()):
|
||||
print(message)
|
||||
```
|
||||
"""
|
||||
if options is None:
|
||||
options = ClaudeCodeOptions()
|
||||
|
||||
os.environ["CLAUDE_CODE_ENTRYPOINT"] = "sdk-py"
|
||||
|
||||
client = InternalClient()
|
||||
|
||||
async for message in client.process_query(prompt=prompt, options=options):
|
||||
yield message
|
||||
121
tests/test_message_parser.py
Normal file
121
tests/test_message_parser.py
Normal file
|
|
@ -0,0 +1,121 @@
|
|||
"""Tests for message parser error handling."""
|
||||
|
||||
import pytest
|
||||
|
||||
from claude_code_sdk._errors import MessageParseError
|
||||
from claude_code_sdk._internal.message_parser import parse_message
|
||||
from claude_code_sdk.types import (
|
||||
AssistantMessage,
|
||||
ResultMessage,
|
||||
SystemMessage,
|
||||
TextBlock,
|
||||
ToolUseBlock,
|
||||
UserMessage,
|
||||
)
|
||||
|
||||
|
||||
class TestMessageParser:
|
||||
"""Test message parsing with the new exception behavior."""
|
||||
|
||||
def test_parse_valid_user_message(self):
|
||||
"""Test parsing a valid user message."""
|
||||
data = {
|
||||
"type": "user",
|
||||
"message": {"content": [{"type": "text", "text": "Hello"}]},
|
||||
}
|
||||
message = parse_message(data)
|
||||
assert isinstance(message, UserMessage)
|
||||
|
||||
def test_parse_valid_assistant_message(self):
|
||||
"""Test parsing a valid assistant message."""
|
||||
data = {
|
||||
"type": "assistant",
|
||||
"message": {
|
||||
"content": [
|
||||
{"type": "text", "text": "Hello"},
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "tool_123",
|
||||
"name": "Read",
|
||||
"input": {"file_path": "/test.txt"},
|
||||
},
|
||||
]
|
||||
},
|
||||
}
|
||||
message = parse_message(data)
|
||||
assert isinstance(message, AssistantMessage)
|
||||
assert len(message.content) == 2
|
||||
assert isinstance(message.content[0], TextBlock)
|
||||
assert isinstance(message.content[1], ToolUseBlock)
|
||||
|
||||
def test_parse_valid_system_message(self):
|
||||
"""Test parsing a valid system message."""
|
||||
data = {"type": "system", "subtype": "start"}
|
||||
message = parse_message(data)
|
||||
assert isinstance(message, SystemMessage)
|
||||
assert message.subtype == "start"
|
||||
|
||||
def test_parse_valid_result_message(self):
|
||||
"""Test parsing a valid result message."""
|
||||
data = {
|
||||
"type": "result",
|
||||
"subtype": "success",
|
||||
"duration_ms": 1000,
|
||||
"duration_api_ms": 500,
|
||||
"is_error": False,
|
||||
"num_turns": 2,
|
||||
"session_id": "session_123",
|
||||
}
|
||||
message = parse_message(data)
|
||||
assert isinstance(message, ResultMessage)
|
||||
assert message.subtype == "success"
|
||||
|
||||
def test_parse_invalid_data_type(self):
|
||||
"""Test that non-dict data raises MessageParseError."""
|
||||
with pytest.raises(MessageParseError) as exc_info:
|
||||
parse_message("not a dict") # type: ignore
|
||||
assert "Invalid message data type" in str(exc_info.value)
|
||||
assert "expected dict, got str" in str(exc_info.value)
|
||||
|
||||
def test_parse_missing_type_field(self):
|
||||
"""Test that missing 'type' field raises MessageParseError."""
|
||||
with pytest.raises(MessageParseError) as exc_info:
|
||||
parse_message({"message": {"content": []}})
|
||||
assert "Message missing 'type' field" in str(exc_info.value)
|
||||
|
||||
def test_parse_unknown_message_type(self):
|
||||
"""Test that unknown message type raises MessageParseError."""
|
||||
with pytest.raises(MessageParseError) as exc_info:
|
||||
parse_message({"type": "unknown_type"})
|
||||
assert "Unknown message type: unknown_type" in str(exc_info.value)
|
||||
|
||||
def test_parse_user_message_missing_fields(self):
|
||||
"""Test that user message with missing fields raises MessageParseError."""
|
||||
with pytest.raises(MessageParseError) as exc_info:
|
||||
parse_message({"type": "user"})
|
||||
assert "Missing required field in user message" in str(exc_info.value)
|
||||
|
||||
def test_parse_assistant_message_missing_fields(self):
|
||||
"""Test that assistant message with missing fields raises MessageParseError."""
|
||||
with pytest.raises(MessageParseError) as exc_info:
|
||||
parse_message({"type": "assistant"})
|
||||
assert "Missing required field in assistant message" in str(exc_info.value)
|
||||
|
||||
def test_parse_system_message_missing_fields(self):
|
||||
"""Test that system message with missing fields raises MessageParseError."""
|
||||
with pytest.raises(MessageParseError) as exc_info:
|
||||
parse_message({"type": "system"})
|
||||
assert "Missing required field in system message" in str(exc_info.value)
|
||||
|
||||
def test_parse_result_message_missing_fields(self):
|
||||
"""Test that result message with missing fields raises MessageParseError."""
|
||||
with pytest.raises(MessageParseError) as exc_info:
|
||||
parse_message({"type": "result", "subtype": "success"})
|
||||
assert "Missing required field in result message" in str(exc_info.value)
|
||||
|
||||
def test_message_parse_error_contains_data(self):
|
||||
"""Test that MessageParseError contains the original data."""
|
||||
data = {"type": "unknown", "some": "data"}
|
||||
with pytest.raises(MessageParseError) as exc_info:
|
||||
parse_message(data)
|
||||
assert exc_info.value.data == data
|
||||
567
tests/test_streaming_client.py
Normal file
567
tests/test_streaming_client.py
Normal file
|
|
@ -0,0 +1,567 @@
|
|||
"""Tests for ClaudeSDKClient streaming functionality and query() with async iterables."""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import anyio
|
||||
import pytest
|
||||
|
||||
from claude_code_sdk import (
|
||||
AssistantMessage,
|
||||
ClaudeCodeOptions,
|
||||
ClaudeSDKClient,
|
||||
CLIConnectionError,
|
||||
ResultMessage,
|
||||
TextBlock,
|
||||
UserMessage,
|
||||
query,
|
||||
)
|
||||
from claude_code_sdk._internal.transport.subprocess_cli import SubprocessCLITransport
|
||||
|
||||
|
||||
class TestClaudeSDKClientStreaming:
|
||||
"""Test ClaudeSDKClient streaming functionality."""
|
||||
|
||||
def test_auto_connect_with_context_manager(self):
|
||||
"""Test automatic connection when using context manager."""
|
||||
|
||||
async def _test():
|
||||
with patch(
|
||||
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
|
||||
) as mock_transport_class:
|
||||
mock_transport = AsyncMock()
|
||||
mock_transport_class.return_value = mock_transport
|
||||
|
||||
async with ClaudeSDKClient() as client:
|
||||
# Verify connect was called
|
||||
mock_transport.connect.assert_called_once()
|
||||
assert client._transport is mock_transport
|
||||
|
||||
# Verify disconnect was called on exit
|
||||
mock_transport.disconnect.assert_called_once()
|
||||
|
||||
anyio.run(_test)
|
||||
|
||||
def test_manual_connect_disconnect(self):
|
||||
"""Test manual connect and disconnect."""
|
||||
|
||||
async def _test():
|
||||
with patch(
|
||||
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
|
||||
) as mock_transport_class:
|
||||
mock_transport = AsyncMock()
|
||||
mock_transport_class.return_value = mock_transport
|
||||
|
||||
client = ClaudeSDKClient()
|
||||
await client.connect()
|
||||
|
||||
# Verify connect was called
|
||||
mock_transport.connect.assert_called_once()
|
||||
assert client._transport is mock_transport
|
||||
|
||||
await client.disconnect()
|
||||
# Verify disconnect was called
|
||||
mock_transport.disconnect.assert_called_once()
|
||||
assert client._transport is None
|
||||
|
||||
anyio.run(_test)
|
||||
|
||||
def test_connect_with_string_prompt(self):
|
||||
"""Test connecting with a string prompt."""
|
||||
|
||||
async def _test():
|
||||
with patch(
|
||||
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
|
||||
) as mock_transport_class:
|
||||
mock_transport = AsyncMock()
|
||||
mock_transport_class.return_value = mock_transport
|
||||
|
||||
client = ClaudeSDKClient()
|
||||
await client.connect("Hello Claude")
|
||||
|
||||
# Verify transport was created with string prompt
|
||||
call_kwargs = mock_transport_class.call_args.kwargs
|
||||
assert call_kwargs["prompt"] == "Hello Claude"
|
||||
|
||||
anyio.run(_test)
|
||||
|
||||
def test_connect_with_async_iterable(self):
|
||||
"""Test connecting with an async iterable."""
|
||||
|
||||
async def _test():
|
||||
with patch(
|
||||
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
|
||||
) as mock_transport_class:
|
||||
mock_transport = AsyncMock()
|
||||
mock_transport_class.return_value = mock_transport
|
||||
|
||||
async def message_stream():
|
||||
yield {"type": "user", "message": {"role": "user", "content": "Hi"}}
|
||||
yield {
|
||||
"type": "user",
|
||||
"message": {"role": "user", "content": "Bye"},
|
||||
}
|
||||
|
||||
client = ClaudeSDKClient()
|
||||
stream = message_stream()
|
||||
await client.connect(stream)
|
||||
|
||||
# Verify transport was created with async iterable
|
||||
call_kwargs = mock_transport_class.call_args.kwargs
|
||||
# Should be the same async iterator
|
||||
assert call_kwargs["prompt"] is stream
|
||||
|
||||
anyio.run(_test)
|
||||
|
||||
def test_query(self):
|
||||
"""Test sending a query."""
|
||||
|
||||
async def _test():
|
||||
with patch(
|
||||
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
|
||||
) as mock_transport_class:
|
||||
mock_transport = AsyncMock()
|
||||
mock_transport_class.return_value = mock_transport
|
||||
|
||||
async with ClaudeSDKClient() as client:
|
||||
await client.query("Test message")
|
||||
|
||||
# Verify send_request was called with correct format
|
||||
mock_transport.send_request.assert_called_once()
|
||||
call_args = mock_transport.send_request.call_args
|
||||
messages, options = call_args[0]
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["type"] == "user"
|
||||
assert messages[0]["message"]["content"] == "Test message"
|
||||
assert options["session_id"] == "default"
|
||||
|
||||
anyio.run(_test)
|
||||
|
||||
def test_send_message_with_session_id(self):
|
||||
"""Test sending a message with custom session ID."""
|
||||
|
||||
async def _test():
|
||||
with patch(
|
||||
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
|
||||
) as mock_transport_class:
|
||||
mock_transport = AsyncMock()
|
||||
mock_transport_class.return_value = mock_transport
|
||||
|
||||
async with ClaudeSDKClient() as client:
|
||||
await client.query("Test", session_id="custom-session")
|
||||
|
||||
call_args = mock_transport.send_request.call_args
|
||||
messages, options = call_args[0]
|
||||
assert messages[0]["session_id"] == "custom-session"
|
||||
assert options["session_id"] == "custom-session"
|
||||
|
||||
anyio.run(_test)
|
||||
|
||||
def test_send_message_not_connected(self):
|
||||
"""Test sending message when not connected raises error."""
|
||||
|
||||
async def _test():
|
||||
client = ClaudeSDKClient()
|
||||
with pytest.raises(CLIConnectionError, match="Not connected"):
|
||||
await client.query("Test")
|
||||
|
||||
anyio.run(_test)
|
||||
|
||||
def test_receive_messages(self):
|
||||
"""Test receiving messages."""
|
||||
|
||||
async def _test():
|
||||
with patch(
|
||||
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
|
||||
) as mock_transport_class:
|
||||
mock_transport = AsyncMock()
|
||||
mock_transport_class.return_value = mock_transport
|
||||
|
||||
# Mock the message stream
|
||||
async def mock_receive():
|
||||
yield {
|
||||
"type": "assistant",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "Hello!"}],
|
||||
},
|
||||
}
|
||||
yield {
|
||||
"type": "user",
|
||||
"message": {"role": "user", "content": "Hi there"},
|
||||
}
|
||||
|
||||
mock_transport.receive_messages = mock_receive
|
||||
|
||||
async with ClaudeSDKClient() as client:
|
||||
messages = []
|
||||
async for msg in client.receive_messages():
|
||||
messages.append(msg)
|
||||
if len(messages) == 2:
|
||||
break
|
||||
|
||||
assert len(messages) == 2
|
||||
assert isinstance(messages[0], AssistantMessage)
|
||||
assert isinstance(messages[0].content[0], TextBlock)
|
||||
assert messages[0].content[0].text == "Hello!"
|
||||
assert isinstance(messages[1], UserMessage)
|
||||
assert messages[1].content == "Hi there"
|
||||
|
||||
anyio.run(_test)
|
||||
|
||||
def test_receive_response(self):
|
||||
"""Test receive_response stops at ResultMessage."""
|
||||
|
||||
async def _test():
|
||||
with patch(
|
||||
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
|
||||
) as mock_transport_class:
|
||||
mock_transport = AsyncMock()
|
||||
mock_transport_class.return_value = mock_transport
|
||||
|
||||
# Mock the message stream
|
||||
async def mock_receive():
|
||||
yield {
|
||||
"type": "assistant",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "Answer"}],
|
||||
},
|
||||
}
|
||||
yield {
|
||||
"type": "result",
|
||||
"subtype": "success",
|
||||
"duration_ms": 1000,
|
||||
"duration_api_ms": 800,
|
||||
"is_error": False,
|
||||
"num_turns": 1,
|
||||
"session_id": "test",
|
||||
"total_cost_usd": 0.001,
|
||||
}
|
||||
# This should not be yielded
|
||||
yield {
|
||||
"type": "assistant",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "Should not see this"}
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
mock_transport.receive_messages = mock_receive
|
||||
|
||||
async with ClaudeSDKClient() as client:
|
||||
messages = []
|
||||
async for msg in client.receive_response():
|
||||
messages.append(msg)
|
||||
|
||||
# Should only get 2 messages (assistant + result)
|
||||
assert len(messages) == 2
|
||||
assert isinstance(messages[0], AssistantMessage)
|
||||
assert isinstance(messages[1], ResultMessage)
|
||||
|
||||
anyio.run(_test)
|
||||
|
||||
def test_interrupt(self):
|
||||
"""Test interrupt functionality."""
|
||||
|
||||
async def _test():
|
||||
with patch(
|
||||
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
|
||||
) as mock_transport_class:
|
||||
mock_transport = AsyncMock()
|
||||
mock_transport_class.return_value = mock_transport
|
||||
|
||||
async with ClaudeSDKClient() as client:
|
||||
await client.interrupt()
|
||||
mock_transport.interrupt.assert_called_once()
|
||||
|
||||
anyio.run(_test)
|
||||
|
||||
def test_interrupt_not_connected(self):
|
||||
"""Test interrupt when not connected raises error."""
|
||||
|
||||
async def _test():
|
||||
client = ClaudeSDKClient()
|
||||
with pytest.raises(CLIConnectionError, match="Not connected"):
|
||||
await client.interrupt()
|
||||
|
||||
anyio.run(_test)
|
||||
|
||||
def test_client_with_options(self):
|
||||
"""Test client initialization with options."""
|
||||
|
||||
async def _test():
|
||||
options = ClaudeCodeOptions(
|
||||
cwd="/custom/path",
|
||||
allowed_tools=["Read", "Write"],
|
||||
system_prompt="Be helpful",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
|
||||
) as mock_transport_class:
|
||||
mock_transport = AsyncMock()
|
||||
mock_transport_class.return_value = mock_transport
|
||||
|
||||
client = ClaudeSDKClient(options=options)
|
||||
await client.connect()
|
||||
|
||||
# Verify options were passed to transport
|
||||
call_kwargs = mock_transport_class.call_args.kwargs
|
||||
assert call_kwargs["options"] is options
|
||||
|
||||
anyio.run(_test)
|
||||
|
||||
def test_concurrent_send_receive(self):
|
||||
"""Test concurrent sending and receiving messages."""
|
||||
|
||||
async def _test():
|
||||
with patch(
|
||||
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
|
||||
) as mock_transport_class:
|
||||
mock_transport = AsyncMock()
|
||||
mock_transport_class.return_value = mock_transport
|
||||
|
||||
# Mock receive to wait then yield messages
|
||||
async def mock_receive():
|
||||
await asyncio.sleep(0.1)
|
||||
yield {
|
||||
"type": "assistant",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "Response 1"}],
|
||||
},
|
||||
}
|
||||
await asyncio.sleep(0.1)
|
||||
yield {
|
||||
"type": "result",
|
||||
"subtype": "success",
|
||||
"duration_ms": 1000,
|
||||
"duration_api_ms": 800,
|
||||
"is_error": False,
|
||||
"num_turns": 1,
|
||||
"session_id": "test",
|
||||
"total_cost_usd": 0.001,
|
||||
}
|
||||
|
||||
mock_transport.receive_messages = mock_receive
|
||||
|
||||
async with ClaudeSDKClient() as client:
|
||||
# Helper to get next message
|
||||
async def get_next_message():
|
||||
return await client.receive_response().__anext__()
|
||||
|
||||
# Start receiving in background
|
||||
receive_task = asyncio.create_task(get_next_message())
|
||||
|
||||
# Send message while receiving
|
||||
await client.query("Question 1")
|
||||
|
||||
# Wait for first message
|
||||
first_msg = await receive_task
|
||||
assert isinstance(first_msg, AssistantMessage)
|
||||
|
||||
anyio.run(_test)
|
||||
|
||||
|
||||
class TestQueryWithAsyncIterable:
|
||||
"""Test query() function with async iterable inputs."""
|
||||
|
||||
def test_query_with_async_iterable(self):
|
||||
"""Test query with async iterable of messages."""
|
||||
|
||||
async def _test():
|
||||
async def message_stream():
|
||||
yield {"type": "user", "message": {"role": "user", "content": "First"}}
|
||||
yield {"type": "user", "message": {"role": "user", "content": "Second"}}
|
||||
|
||||
# Create a simple test script that validates stdin and outputs a result
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
|
||||
test_script = f.name
|
||||
f.write("""#!/usr/bin/env python3
|
||||
import sys
|
||||
import json
|
||||
|
||||
# Read stdin messages
|
||||
stdin_messages = []
|
||||
while True:
|
||||
line = sys.stdin.readline()
|
||||
if not line:
|
||||
break
|
||||
stdin_messages.append(line.strip())
|
||||
|
||||
# Verify we got 2 messages
|
||||
assert len(stdin_messages) == 2
|
||||
assert '"First"' in stdin_messages[0]
|
||||
assert '"Second"' in stdin_messages[1]
|
||||
|
||||
# Output a valid result
|
||||
print('{"type": "result", "subtype": "success", "duration_ms": 100, "duration_api_ms": 50, "is_error": false, "num_turns": 1, "session_id": "test", "total_cost_usd": 0.001}')
|
||||
""")
|
||||
|
||||
Path(test_script).chmod(0o755)
|
||||
|
||||
try:
|
||||
# Mock _find_cli to return python executing our test script
|
||||
with patch.object(
|
||||
SubprocessCLITransport, "_find_cli", return_value=sys.executable
|
||||
):
|
||||
# Mock _build_command to add our test script as first argument
|
||||
original_build_command = SubprocessCLITransport._build_command
|
||||
|
||||
def mock_build_command(self):
|
||||
# Get original command
|
||||
cmd = original_build_command(self)
|
||||
# Replace the CLI path with python + script
|
||||
cmd[0] = test_script
|
||||
return cmd
|
||||
|
||||
with patch.object(
|
||||
SubprocessCLITransport, "_build_command", mock_build_command
|
||||
):
|
||||
# Run query with async iterable
|
||||
messages = []
|
||||
async for msg in query(prompt=message_stream()):
|
||||
messages.append(msg)
|
||||
|
||||
# Should get the result message
|
||||
assert len(messages) == 1
|
||||
assert isinstance(messages[0], ResultMessage)
|
||||
assert messages[0].subtype == "success"
|
||||
finally:
|
||||
# Clean up
|
||||
Path(test_script).unlink()
|
||||
|
||||
anyio.run(_test)
|
||||
|
||||
|
||||
class TestClaudeSDKClientEdgeCases:
|
||||
"""Test edge cases and error scenarios."""
|
||||
|
||||
def test_receive_messages_not_connected(self):
|
||||
"""Test receiving messages when not connected."""
|
||||
|
||||
async def _test():
|
||||
client = ClaudeSDKClient()
|
||||
with pytest.raises(CLIConnectionError, match="Not connected"):
|
||||
async for _ in client.receive_messages():
|
||||
pass
|
||||
|
||||
anyio.run(_test)
|
||||
|
||||
def test_receive_response_not_connected(self):
|
||||
"""Test receive_response when not connected."""
|
||||
|
||||
async def _test():
|
||||
client = ClaudeSDKClient()
|
||||
with pytest.raises(CLIConnectionError, match="Not connected"):
|
||||
async for _ in client.receive_response():
|
||||
pass
|
||||
|
||||
anyio.run(_test)
|
||||
|
||||
def test_double_connect(self):
|
||||
"""Test connecting twice."""
|
||||
|
||||
async def _test():
|
||||
with patch(
|
||||
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
|
||||
) as mock_transport_class:
|
||||
mock_transport = AsyncMock()
|
||||
mock_transport_class.return_value = mock_transport
|
||||
|
||||
client = ClaudeSDKClient()
|
||||
await client.connect()
|
||||
# Second connect should create new transport
|
||||
await client.connect()
|
||||
|
||||
# Should have been called twice
|
||||
assert mock_transport_class.call_count == 2
|
||||
|
||||
anyio.run(_test)
|
||||
|
||||
def test_disconnect_without_connect(self):
|
||||
"""Test disconnecting without connecting first."""
|
||||
|
||||
async def _test():
|
||||
client = ClaudeSDKClient()
|
||||
# Should not raise error
|
||||
await client.disconnect()
|
||||
|
||||
anyio.run(_test)
|
||||
|
||||
def test_context_manager_with_exception(self):
|
||||
"""Test context manager cleans up on exception."""
|
||||
|
||||
async def _test():
|
||||
with patch(
|
||||
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
|
||||
) as mock_transport_class:
|
||||
mock_transport = AsyncMock()
|
||||
mock_transport_class.return_value = mock_transport
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
async with ClaudeSDKClient():
|
||||
raise ValueError("Test error")
|
||||
|
||||
# Disconnect should still be called
|
||||
mock_transport.disconnect.assert_called_once()
|
||||
|
||||
anyio.run(_test)
|
||||
|
||||
def test_receive_response_list_comprehension(self):
|
||||
"""Test collecting messages with list comprehension as shown in examples."""
|
||||
|
||||
async def _test():
|
||||
with patch(
|
||||
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
|
||||
) as mock_transport_class:
|
||||
mock_transport = AsyncMock()
|
||||
mock_transport_class.return_value = mock_transport
|
||||
|
||||
# Mock the message stream
|
||||
async def mock_receive():
|
||||
yield {
|
||||
"type": "assistant",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "Hello"}],
|
||||
},
|
||||
}
|
||||
yield {
|
||||
"type": "assistant",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "World"}],
|
||||
},
|
||||
}
|
||||
yield {
|
||||
"type": "result",
|
||||
"subtype": "success",
|
||||
"duration_ms": 1000,
|
||||
"duration_api_ms": 800,
|
||||
"is_error": False,
|
||||
"num_turns": 1,
|
||||
"session_id": "test",
|
||||
"total_cost_usd": 0.001,
|
||||
}
|
||||
|
||||
mock_transport.receive_messages = mock_receive
|
||||
|
||||
async with ClaudeSDKClient() as client:
|
||||
# Test list comprehension pattern from docstring
|
||||
messages = [msg async for msg in client.receive_response()]
|
||||
|
||||
assert len(messages) == 3
|
||||
assert all(
|
||||
isinstance(msg, AssistantMessage | ResultMessage)
|
||||
for msg in messages
|
||||
)
|
||||
assert isinstance(messages[-1], ResultMessage)
|
||||
|
||||
anyio.run(_test)
|
||||
|
|
@ -103,6 +103,12 @@ class TestSubprocessCLITransport:
|
|||
mock_process.wait = AsyncMock()
|
||||
mock_process.stdout = MagicMock()
|
||||
mock_process.stderr = MagicMock()
|
||||
|
||||
# Mock stdin with aclose method
|
||||
mock_stdin = MagicMock()
|
||||
mock_stdin.aclose = AsyncMock()
|
||||
mock_process.stdin = mock_stdin
|
||||
|
||||
mock_exec.return_value = mock_process
|
||||
|
||||
transport = SubprocessCLITransport(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue