diff --git a/src/claude_code_sdk/_errors.py b/src/claude_code_sdk/_errors.py index e832757..8f3e759 100644 --- a/src/claude_code_sdk/_errors.py +++ b/src/claude_code_sdk/_errors.py @@ -44,3 +44,11 @@ class CLIJSONDecodeError(ClaudeSDKError): self.line = line self.original_error = original_error super().__init__(f"Failed to decode JSON: {line[:100]}...") + + +class MessageParseError(ClaudeSDKError): + """Raised when unable to parse a message from CLI output.""" + + def __init__(self, message: str, data: dict | None = None): + self.data = data + super().__init__(message) diff --git a/src/claude_code_sdk/_internal/client.py b/src/claude_code_sdk/_internal/client.py index c1afa9e..715dab5 100644 --- a/src/claude_code_sdk/_internal/client.py +++ b/src/claude_code_sdk/_internal/client.py @@ -27,9 +27,7 @@ class InternalClient: await transport.connect() async for data in transport.receive_messages(): - message = parse_message(data) - if message: - yield message + yield parse_message(data) finally: await transport.disconnect() diff --git a/src/claude_code_sdk/_internal/message_parser.py b/src/claude_code_sdk/_internal/message_parser.py index c5f4fc0..858e24f 100644 --- a/src/claude_code_sdk/_internal/message_parser.py +++ b/src/claude_code_sdk/_internal/message_parser.py @@ -3,6 +3,7 @@ import logging from typing import Any +from .._errors import MessageParseError from ..types import ( AssistantMessage, ContentBlock, @@ -18,7 +19,7 @@ from ..types import ( logger = logging.getLogger(__name__) -def parse_message(data: dict[str, Any]) -> Message | None: +def parse_message(data: dict[str, Any]) -> Message: """ Parse message from CLI output into typed Message objects. @@ -26,25 +27,29 @@ def parse_message(data: dict[str, Any]) -> Message | None: data: Raw message dictionary from CLI output Returns: - Parsed Message object or None if type is unrecognized or parsing fails - """ - try: - message_type = data.get("type") - if not message_type: - logger.warning("Message missing 'type' field: %s", data) - return None + Parsed Message object - except AttributeError: - logger.error("Invalid message data type (expected dict): %s", type(data)) - return None + Raises: + MessageParseError: If parsing fails or message type is unrecognized + """ + if not isinstance(data, dict): + raise MessageParseError( + f"Invalid message data type (expected dict, got {type(data).__name__})", + data, + ) + + message_type = data.get("type") + if not message_type: + raise MessageParseError("Message missing 'type' field", data) match message_type: case "user": try: return UserMessage(content=data["message"]["content"]) except KeyError as e: - logger.error("Missing required field in user message: %s", e) - return None + raise MessageParseError( + f"Missing required field in user message: {e}", data + ) from e case "assistant": try: @@ -72,8 +77,9 @@ def parse_message(data: dict[str, Any]) -> Message | None: return AssistantMessage(content=content_blocks) except KeyError as e: - logger.error("Missing required field in assistant message: %s", e) - return None + raise MessageParseError( + f"Missing required field in assistant message: {e}", data + ) from e case "system": try: @@ -82,8 +88,9 @@ def parse_message(data: dict[str, Any]) -> Message | None: data=data, ) except KeyError as e: - logger.error("Missing required field in system message: %s", e) - return None + raise MessageParseError( + f"Missing required field in system message: {e}", data + ) from e case "result": try: @@ -99,9 +106,9 @@ def parse_message(data: dict[str, Any]) -> Message | None: result=data.get("result"), ) except KeyError as e: - logger.error("Missing required field in result message: %s", e) - return None + raise MessageParseError( + f"Missing required field in result message: {e}", data + ) from e case _: - logger.debug("Unknown message type: %s", message_type) - return None + raise MessageParseError(f"Unknown message type: {message_type}", data) diff --git a/src/claude_code_sdk/_internal/transport/subprocess_cli.py b/src/claude_code_sdk/_internal/transport/subprocess_cli.py index 94c42fa..b39f903 100644 --- a/src/claude_code_sdk/_internal/transport/subprocess_cli.py +++ b/src/claude_code_sdk/_internal/transport/subprocess_cli.py @@ -297,8 +297,9 @@ class SubprocessCLITransport(Transport): except GeneratorExit: return except json.JSONDecodeError: - # Don't clear buffer - we might be in the middle of a split JSON message - # The buffer will be cleared when we successfully parse or hit size limit + # We are speculatively decoding the buffer until we get + # a full JSON object. If there is an actual issue, we + # raise an error after _MAX_BUFFER_SIZE. continue except anyio.ClosedResourceError: diff --git a/src/claude_code_sdk/client.py b/src/claude_code_sdk/client.py index a4c81ed..8e86ba7 100644 --- a/src/claude_code_sdk/client.py +++ b/src/claude_code_sdk/client.py @@ -126,9 +126,7 @@ class ClaudeSDKClient: from ._internal.message_parser import parse_message async for data in self._transport.receive_messages(): - message = parse_message(data) - if message: - yield message + yield parse_message(data) async def query( self, prompt: str | AsyncIterable[dict[str, Any]], session_id: str = "default" diff --git a/tests/test_message_parser.py b/tests/test_message_parser.py new file mode 100644 index 0000000..47bd521 --- /dev/null +++ b/tests/test_message_parser.py @@ -0,0 +1,118 @@ +"""Tests for message parser error handling.""" + +import pytest + +from claude_code_sdk._errors import MessageParseError +from claude_code_sdk._internal.message_parser import parse_message +from claude_code_sdk.types import ( + AssistantMessage, + ResultMessage, + SystemMessage, + TextBlock, + ToolUseBlock, + UserMessage, +) + + +class TestMessageParser: + """Test message parsing with the new exception behavior.""" + + def test_parse_valid_user_message(self): + """Test parsing a valid user message.""" + data = {"type": "user", "message": {"content": [{"type": "text", "text": "Hello"}]}} + message = parse_message(data) + assert isinstance(message, UserMessage) + + def test_parse_valid_assistant_message(self): + """Test parsing a valid assistant message.""" + data = { + "type": "assistant", + "message": { + "content": [ + {"type": "text", "text": "Hello"}, + { + "type": "tool_use", + "id": "tool_123", + "name": "Read", + "input": {"file_path": "/test.txt"}, + }, + ] + }, + } + message = parse_message(data) + assert isinstance(message, AssistantMessage) + assert len(message.content) == 2 + assert isinstance(message.content[0], TextBlock) + assert isinstance(message.content[1], ToolUseBlock) + + def test_parse_valid_system_message(self): + """Test parsing a valid system message.""" + data = {"type": "system", "subtype": "start"} + message = parse_message(data) + assert isinstance(message, SystemMessage) + assert message.subtype == "start" + + def test_parse_valid_result_message(self): + """Test parsing a valid result message.""" + data = { + "type": "result", + "subtype": "success", + "duration_ms": 1000, + "duration_api_ms": 500, + "is_error": False, + "num_turns": 2, + "session_id": "session_123", + } + message = parse_message(data) + assert isinstance(message, ResultMessage) + assert message.subtype == "success" + + def test_parse_invalid_data_type(self): + """Test that non-dict data raises MessageParseError.""" + with pytest.raises(MessageParseError) as exc_info: + parse_message("not a dict") # type: ignore + assert "Invalid message data type" in str(exc_info.value) + assert "expected dict, got str" in str(exc_info.value) + + def test_parse_missing_type_field(self): + """Test that missing 'type' field raises MessageParseError.""" + with pytest.raises(MessageParseError) as exc_info: + parse_message({"message": {"content": []}}) + assert "Message missing 'type' field" in str(exc_info.value) + + def test_parse_unknown_message_type(self): + """Test that unknown message type raises MessageParseError.""" + with pytest.raises(MessageParseError) as exc_info: + parse_message({"type": "unknown_type"}) + assert "Unknown message type: unknown_type" in str(exc_info.value) + + def test_parse_user_message_missing_fields(self): + """Test that user message with missing fields raises MessageParseError.""" + with pytest.raises(MessageParseError) as exc_info: + parse_message({"type": "user"}) + assert "Missing required field in user message" in str(exc_info.value) + + def test_parse_assistant_message_missing_fields(self): + """Test that assistant message with missing fields raises MessageParseError.""" + with pytest.raises(MessageParseError) as exc_info: + parse_message({"type": "assistant"}) + assert "Missing required field in assistant message" in str(exc_info.value) + + def test_parse_system_message_missing_fields(self): + """Test that system message with missing fields raises MessageParseError.""" + with pytest.raises(MessageParseError) as exc_info: + parse_message({"type": "system"}) + assert "Missing required field in system message" in str(exc_info.value) + + def test_parse_result_message_missing_fields(self): + """Test that result message with missing fields raises MessageParseError.""" + with pytest.raises(MessageParseError) as exc_info: + parse_message({"type": "result", "subtype": "success"}) + assert "Missing required field in result message" in str(exc_info.value) + + def test_message_parse_error_contains_data(self): + """Test that MessageParseError contains the original data.""" + data = {"type": "unknown", "some": "data"} + with pytest.raises(MessageParseError) as exc_info: + parse_message(data) + assert exc_info.value.data == data