Fix code formatting for CI

This commit is contained in:
Lina Tawfik 2025-06-12 00:20:28 -07:00
parent 6ca3514261
commit 63ef121e18
No known key found for this signature in database
8 changed files with 101 additions and 79 deletions

View file

@ -34,8 +34,7 @@ async def with_options_example():
) )
async for message in query( async for message in query(
prompt="Explain what Python is in one sentence.", prompt="Explain what Python is in one sentence.", options=options
options=options
): ):
if isinstance(message, AssistantMessage): if isinstance(message, AssistantMessage):
for block in message.content: for block in message.content:
@ -55,7 +54,7 @@ async def with_tools_example():
async for message in query( async for message in query(
prompt="Create a file called hello.txt with 'Hello, World!' in it", prompt="Create a file called hello.txt with 'Hello, World!' in it",
options=options options=options,
): ):
if isinstance(message, AssistantMessage): if isinstance(message, AssistantMessage):
for block in message.content: for block in message.content:

View file

@ -48,9 +48,7 @@ class InternalClient:
match data["type"]: match data["type"]:
case "user": case "user":
# Extract just the content from the nested structure # Extract just the content from the nested structure
return UserMessage( return UserMessage(content=data["message"]["content"])
content=data["message"]["content"]
)
case "assistant": case "assistant":
# Parse content blocks # Parse content blocks
@ -60,24 +58,28 @@ class InternalClient:
case "text": case "text":
content_blocks.append(TextBlock(text=block["text"])) content_blocks.append(TextBlock(text=block["text"]))
case "tool_use": case "tool_use":
content_blocks.append(ToolUseBlock( content_blocks.append(
id=block["id"], ToolUseBlock(
name=block["name"], id=block["id"],
input=block["input"] name=block["name"],
)) input=block["input"],
)
)
case "tool_result": case "tool_result":
content_blocks.append(ToolResultBlock( content_blocks.append(
tool_use_id=block["tool_use_id"], ToolResultBlock(
content=block.get("content"), tool_use_id=block["tool_use_id"],
is_error=block.get("is_error") content=block.get("content"),
)) is_error=block.get("is_error"),
)
)
return AssistantMessage(content=content_blocks) return AssistantMessage(content=content_blocks)
case "system": case "system":
return SystemMessage( return SystemMessage(
subtype=data["subtype"], subtype=data["subtype"],
data=data # Pass through all data data=data, # Pass through all data
) )
case "result": case "result":
@ -92,7 +94,7 @@ class InternalClient:
session_id=data["session_id"], session_id=data["session_id"],
total_cost_usd=data["total_cost"], total_cost_usd=data["total_cost"],
usage=data.get("usage"), usage=data.get("usage"),
result=data.get("result") result=data.get("result"),
) )
case _: case _:

View file

@ -22,7 +22,10 @@ class SubprocessCLITransport(Transport):
"""Subprocess transport using Claude Code CLI.""" """Subprocess transport using Claude Code CLI."""
def __init__( def __init__(
self, prompt: str, options: ClaudeCodeOptions, cli_path: str | Path | None = None self,
prompt: str,
options: ClaudeCodeOptions,
cli_path: str | Path | None = None,
): ):
self._prompt = prompt self._prompt = prompt
self._options = options self._options = options

View file

@ -13,6 +13,7 @@ PermissionMode = Literal["default", "acceptEdits", "bypassPermissions"]
# MCP Server config # MCP Server config
class McpServerConfig(TypedDict): class McpServerConfig(TypedDict):
"""MCP server configuration.""" """MCP server configuration."""
transport: list[str] transport: list[str]
env: NotRequired[dict[str, Any]] env: NotRequired[dict[str, Any]]
@ -21,12 +22,14 @@ class McpServerConfig(TypedDict):
@dataclass @dataclass
class TextBlock: class TextBlock:
"""Text content block.""" """Text content block."""
text: str text: str
@dataclass @dataclass
class ToolUseBlock: class ToolUseBlock:
"""Tool use content block.""" """Tool use content block."""
id: str id: str
name: str name: str
input: dict[str, Any] input: dict[str, Any]
@ -35,6 +38,7 @@ class ToolUseBlock:
@dataclass @dataclass
class ToolResultBlock: class ToolResultBlock:
"""Tool result content block.""" """Tool result content block."""
tool_use_id: str tool_use_id: str
content: str | list[dict[str, Any]] | None = None content: str | list[dict[str, Any]] | None = None
is_error: bool | None = None is_error: bool | None = None
@ -47,18 +51,21 @@ ContentBlock = TextBlock | ToolUseBlock | ToolResultBlock
@dataclass @dataclass
class UserMessage: class UserMessage:
"""User message.""" """User message."""
content: str content: str
@dataclass @dataclass
class AssistantMessage: class AssistantMessage:
"""Assistant message with content blocks.""" """Assistant message with content blocks."""
content: list[ContentBlock] content: list[ContentBlock]
@dataclass @dataclass
class SystemMessage: class SystemMessage:
"""System message with metadata.""" """System message with metadata."""
subtype: str subtype: str
data: dict[str, Any] data: dict[str, Any]
@ -66,6 +73,7 @@ class SystemMessage:
@dataclass @dataclass
class ResultMessage: class ResultMessage:
"""Result message with cost and usage information.""" """Result message with cost and usage information."""
subtype: str subtype: str
cost_usd: float cost_usd: float
duration_ms: int duration_ms: int
@ -84,6 +92,7 @@ Message = UserMessage | AssistantMessage | SystemMessage | ResultMessage
@dataclass @dataclass
class ClaudeCodeOptions: class ClaudeCodeOptions:
"""Query options for Claude SDK.""" """Query options for Claude SDK."""
allowed_tools: list[str] = field(default_factory=list) allowed_tools: list[str] = field(default_factory=list)
max_thinking_tokens: int = 8000 max_thinking_tokens: int = 8000
system_prompt: str | None = None system_prompt: str | None = None

View file

@ -13,13 +13,14 @@ class TestQueryFunction:
def test_query_single_prompt(self): def test_query_single_prompt(self):
"""Test query with a single prompt.""" """Test query with a single prompt."""
async def _test(): async def _test():
with patch('claude_code_sdk._internal.client.InternalClient.process_query') as mock_process: with patch(
"claude_code_sdk._internal.client.InternalClient.process_query"
) as mock_process:
# Mock the async generator # Mock the async generator
async def mock_generator(): async def mock_generator():
yield AssistantMessage( yield AssistantMessage(content=[TextBlock(text="4")])
content=[TextBlock(text="4")]
)
mock_process.return_value = mock_generator() mock_process.return_value = mock_generator()
@ -35,12 +36,14 @@ class TestQueryFunction:
def test_query_with_options(self): def test_query_with_options(self):
"""Test query with various options.""" """Test query with various options."""
async def _test(): async def _test():
with patch('claude_code_sdk._internal.client.InternalClient.process_query') as mock_process: with patch(
"claude_code_sdk._internal.client.InternalClient.process_query"
) as mock_process:
async def mock_generator(): async def mock_generator():
yield AssistantMessage( yield AssistantMessage(content=[TextBlock(text="Hello!")])
content=[TextBlock(text="Hello!")]
)
mock_process.return_value = mock_generator() mock_process.return_value = mock_generator()
@ -48,28 +51,28 @@ class TestQueryFunction:
allowed_tools=["Read", "Write"], allowed_tools=["Read", "Write"],
system_prompt="You are helpful", system_prompt="You are helpful",
permission_mode="acceptEdits", permission_mode="acceptEdits",
max_turns=5 max_turns=5,
) )
messages = [] messages = []
async for msg in query( async for msg in query(prompt="Hi", options=options):
prompt="Hi",
options=options
):
messages.append(msg) messages.append(msg)
# Verify process_query was called with correct prompt and options # Verify process_query was called with correct prompt and options
mock_process.assert_called_once() mock_process.assert_called_once()
call_args = mock_process.call_args call_args = mock_process.call_args
assert call_args[1]['prompt'] == "Hi" assert call_args[1]["prompt"] == "Hi"
assert call_args[1]['options'] == options assert call_args[1]["options"] == options
anyio.run(_test) anyio.run(_test)
def test_query_with_cwd(self): def test_query_with_cwd(self):
"""Test query with custom working directory.""" """Test query with custom working directory."""
async def _test(): async def _test():
with patch('claude_code_sdk._internal.client.SubprocessCLITransport') as mock_transport_class: with patch(
"claude_code_sdk._internal.client.SubprocessCLITransport"
) as mock_transport_class:
mock_transport = AsyncMock() mock_transport = AsyncMock()
mock_transport_class.return_value = mock_transport mock_transport_class.return_value = mock_transport
@ -79,8 +82,8 @@ class TestQueryFunction:
"type": "assistant", "type": "assistant",
"message": { "message": {
"role": "assistant", "role": "assistant",
"content": [{"type": "text", "text": "Done"}] "content": [{"type": "text", "text": "Done"}],
} },
} }
yield { yield {
"type": "result", "type": "result",
@ -91,7 +94,7 @@ class TestQueryFunction:
"is_error": False, "is_error": False,
"num_turns": 1, "num_turns": 1,
"session_id": "test-session", "session_id": "test-session",
"total_cost": 0.001 "total_cost": 0.001,
} }
mock_transport.receive_messages = mock_receive mock_transport.receive_messages = mock_receive
@ -100,16 +103,13 @@ class TestQueryFunction:
options = ClaudeCodeOptions(cwd="/custom/path") options = ClaudeCodeOptions(cwd="/custom/path")
messages = [] messages = []
async for msg in query( async for msg in query(prompt="test", options=options):
prompt="test",
options=options
):
messages.append(msg) messages.append(msg)
# Verify transport was created with correct parameters # Verify transport was created with correct parameters
mock_transport_class.assert_called_once() mock_transport_class.assert_called_once()
call_kwargs = mock_transport_class.call_args.kwargs call_kwargs = mock_transport_class.call_args.kwargs
assert call_kwargs['prompt'] == "test" assert call_kwargs["prompt"] == "test"
assert call_kwargs['options'].cwd == "/custom/path" assert call_kwargs["options"].cwd == "/custom/path"
anyio.run(_test) anyio.run(_test)

View file

@ -23,8 +23,11 @@ class TestIntegration:
def test_simple_query_response(self): def test_simple_query_response(self):
"""Test a simple query with text response.""" """Test a simple query with text response."""
async def _test(): async def _test():
with patch("claude_code_sdk._internal.client.SubprocessCLITransport") as mock_transport_class: with patch(
"claude_code_sdk._internal.client.SubprocessCLITransport"
) as mock_transport_class:
mock_transport = AsyncMock() mock_transport = AsyncMock()
mock_transport_class.return_value = mock_transport mock_transport_class.return_value = mock_transport
@ -71,12 +74,15 @@ class TestIntegration:
assert messages[1].cost_usd == 0.001 assert messages[1].cost_usd == 0.001
assert messages[1].session_id == "test-session" assert messages[1].session_id == "test-session"
anyio.run(_test) anyio.run(_test)
def test_query_with_tool_use(self): def test_query_with_tool_use(self):
"""Test query that uses tools.""" """Test query that uses tools."""
async def _test(): async def _test():
with patch("claude_code_sdk._internal.client.SubprocessCLITransport") as mock_transport_class: with patch(
"claude_code_sdk._internal.client.SubprocessCLITransport"
) as mock_transport_class:
mock_transport = AsyncMock() mock_transport = AsyncMock()
mock_transport_class.return_value = mock_transport mock_transport_class.return_value = mock_transport
@ -135,25 +141,31 @@ class TestIntegration:
assert messages[0].content[1].name == "Read" assert messages[0].content[1].name == "Read"
assert messages[0].content[1].input["file_path"] == "/test.txt" assert messages[0].content[1].input["file_path"] == "/test.txt"
anyio.run(_test) anyio.run(_test)
def test_cli_not_found(self): def test_cli_not_found(self):
"""Test handling when CLI is not found.""" """Test handling when CLI is not found."""
async def _test(): async def _test():
with patch("shutil.which", return_value=None), patch( with (
"pathlib.Path.exists", return_value=False patch("shutil.which", return_value=None),
), pytest.raises(CLINotFoundError) as exc_info: patch("pathlib.Path.exists", return_value=False),
pytest.raises(CLINotFoundError) as exc_info,
):
async for _ in query(prompt="test"): async for _ in query(prompt="test"):
pass pass
assert "Claude Code requires Node.js" in str(exc_info.value) assert "Claude Code requires Node.js" in str(exc_info.value)
anyio.run(_test) anyio.run(_test)
def test_continuation_option(self): def test_continuation_option(self):
"""Test query with continue_conversation option.""" """Test query with continue_conversation option."""
async def _test(): async def _test():
with patch("claude_code_sdk._internal.client.SubprocessCLITransport") as mock_transport_class: with patch(
"claude_code_sdk._internal.client.SubprocessCLITransport"
) as mock_transport_class:
mock_transport = AsyncMock() mock_transport = AsyncMock()
mock_transport_class.return_value = mock_transport mock_transport_class.return_value = mock_transport
@ -179,13 +191,14 @@ class TestIntegration:
# Run query with continuation # Run query with continuation
messages = [] messages = []
async for msg in query( async for msg in query(
prompt="Continue", options=ClaudeCodeOptions(continue_conversation=True) prompt="Continue",
options=ClaudeCodeOptions(continue_conversation=True),
): ):
messages.append(msg) messages.append(msg)
# Verify transport was created with continuation option # Verify transport was created with continuation option
mock_transport_class.assert_called_once() mock_transport_class.assert_called_once()
call_kwargs = mock_transport_class.call_args.kwargs call_kwargs = mock_transport_class.call_args.kwargs
assert call_kwargs['options'].continue_conversation is True assert call_kwargs["options"].continue_conversation is True
anyio.run(_test) anyio.run(_test)

View file

@ -16,12 +16,12 @@ class TestSubprocessCLITransport:
"""Test CLI not found error.""" """Test CLI not found error."""
from claude_code_sdk._errors import CLINotFoundError from claude_code_sdk._errors import CLINotFoundError
with patch("shutil.which", return_value=None), patch( with (
"pathlib.Path.exists", return_value=False patch("shutil.which", return_value=None),
), pytest.raises(CLINotFoundError) as exc_info: patch("pathlib.Path.exists", return_value=False),
SubprocessCLITransport( pytest.raises(CLINotFoundError) as exc_info,
prompt="test", options=ClaudeCodeOptions() ):
) SubprocessCLITransport(prompt="test", options=ClaudeCodeOptions())
assert "Claude Code requires Node.js" in str(exc_info.value) assert "Claude Code requires Node.js" in str(exc_info.value)
@ -43,7 +43,9 @@ class TestSubprocessCLITransport:
from pathlib import Path from pathlib import Path
transport = SubprocessCLITransport( transport = SubprocessCLITransport(
prompt="Hello", options=ClaudeCodeOptions(), cli_path=Path("/usr/bin/claude") prompt="Hello",
options=ClaudeCodeOptions(),
cli_path=Path("/usr/bin/claude"),
) )
assert transport._cli_path == "/usr/bin/claude" assert transport._cli_path == "/usr/bin/claude"
@ -92,6 +94,7 @@ class TestSubprocessCLITransport:
def test_connect_disconnect(self): def test_connect_disconnect(self):
"""Test connect and disconnect lifecycle.""" """Test connect and disconnect lifecycle."""
async def _test(): async def _test():
with patch("anyio.open_process") as mock_exec: with patch("anyio.open_process") as mock_exec:
mock_process = MagicMock() mock_process = MagicMock()
@ -103,7 +106,9 @@ class TestSubprocessCLITransport:
mock_exec.return_value = mock_process mock_exec.return_value = mock_process
transport = SubprocessCLITransport( transport = SubprocessCLITransport(
prompt="test", options=ClaudeCodeOptions(), cli_path="/usr/bin/claude" prompt="test",
options=ClaudeCodeOptions(),
cli_path="/usr/bin/claude",
) )
await transport.connect() await transport.connect()

View file

@ -26,9 +26,7 @@ class TestMessageTypes:
def test_tool_use_block(self): def test_tool_use_block(self):
"""Test creating a ToolUseBlock.""" """Test creating a ToolUseBlock."""
block = ToolUseBlock( block = ToolUseBlock(
id="tool-123", id="tool-123", name="Read", input={"file_path": "/test.txt"}
name="Read",
input={"file_path": "/test.txt"}
) )
assert block.id == "tool-123" assert block.id == "tool-123"
assert block.name == "Read" assert block.name == "Read"
@ -37,9 +35,7 @@ class TestMessageTypes:
def test_tool_result_block(self): def test_tool_result_block(self):
"""Test creating a ToolResultBlock.""" """Test creating a ToolResultBlock."""
block = ToolResultBlock( block = ToolResultBlock(
tool_use_id="tool-123", tool_use_id="tool-123", content="File contents here", is_error=False
content="File contents here",
is_error=False
) )
assert block.tool_use_id == "tool-123" assert block.tool_use_id == "tool-123"
assert block.content == "File contents here" assert block.content == "File contents here"
@ -55,7 +51,7 @@ class TestMessageTypes:
is_error=False, is_error=False,
num_turns=1, num_turns=1,
session_id="session-123", session_id="session-123",
total_cost_usd=0.01 total_cost_usd=0.01,
) )
assert msg.subtype == "success" assert msg.subtype == "success"
assert msg.cost_usd == 0.01 assert msg.cost_usd == 0.01
@ -78,8 +74,7 @@ class TestOptions:
def test_claude_code_options_with_tools(self): def test_claude_code_options_with_tools(self):
"""Test Options with built-in tools.""" """Test Options with built-in tools."""
options = ClaudeCodeOptions( options = ClaudeCodeOptions(
allowed_tools=["Read", "Write", "Edit"], allowed_tools=["Read", "Write", "Edit"], disallowed_tools=["Bash"]
disallowed_tools=["Bash"]
) )
assert options.allowed_tools == ["Read", "Write", "Edit"] assert options.allowed_tools == ["Read", "Write", "Edit"]
assert options.disallowed_tools == ["Bash"] assert options.disallowed_tools == ["Bash"]
@ -93,25 +88,21 @@ class TestOptions:
"""Test Options with system prompt.""" """Test Options with system prompt."""
options = ClaudeCodeOptions( options = ClaudeCodeOptions(
system_prompt="You are a helpful assistant.", system_prompt="You are a helpful assistant.",
append_system_prompt="Be concise." append_system_prompt="Be concise.",
) )
assert options.system_prompt == "You are a helpful assistant." assert options.system_prompt == "You are a helpful assistant."
assert options.append_system_prompt == "Be concise." assert options.append_system_prompt == "Be concise."
def test_claude_code_options_with_session_continuation(self): def test_claude_code_options_with_session_continuation(self):
"""Test Options with session continuation.""" """Test Options with session continuation."""
options = ClaudeCodeOptions( options = ClaudeCodeOptions(continue_conversation=True, resume="session-123")
continue_conversation=True,
resume="session-123"
)
assert options.continue_conversation is True assert options.continue_conversation is True
assert options.resume == "session-123" assert options.resume == "session-123"
def test_claude_code_options_with_model_specification(self): def test_claude_code_options_with_model_specification(self):
"""Test Options with model specification.""" """Test Options with model specification."""
options = ClaudeCodeOptions( options = ClaudeCodeOptions(
model="claude-3-5-sonnet-20241022", model="claude-3-5-sonnet-20241022", permission_prompt_tool_name="CustomTool"
permission_prompt_tool_name="CustomTool"
) )
assert options.model == "claude-3-5-sonnet-20241022" assert options.model == "claude-3-5-sonnet-20241022"
assert options.permission_prompt_tool_name == "CustomTool" assert options.permission_prompt_tool_name == "CustomTool"