Fix KeyError: 'cost_usd' for Max subscription users

This fix addresses issue #2 where Max subscription users encounter a KeyError
when the API response doesn't include the 'cost_usd' field.

Changes:
- Modified ResultMessage dataclass to make cost_usd and total_cost_usd optional (float | None)
- Updated client.py to use .get() with None default for cost-related fields
- Maintains backward compatibility for users with cost data

This ensures the SDK works for all subscription types, including Max subscriptions
that don't have access to cost information.

By: Unclecode <https://github.com/unclecode>
This commit is contained in:
unclecode 2025-06-17 15:59:04 +08:00
parent 54bff2e85d
commit 18e5899d33
3 changed files with 30 additions and 4 deletions

View file

@ -84,15 +84,16 @@ class InternalClient:
case "result":
# Map total_cost to total_cost_usd for consistency
# Handle missing cost_usd for Max subscription users
return ResultMessage(
subtype=data["subtype"],
cost_usd=data["cost_usd"],
cost_usd=data.get("cost_usd", None),
duration_ms=data["duration_ms"],
duration_api_ms=data["duration_api_ms"],
is_error=data["is_error"],
num_turns=data["num_turns"],
session_id=data["session_id"],
total_cost_usd=data["total_cost"],
total_cost_usd=data.get("total_cost", None),
usage=data.get("usage"),
result=data.get("result"),
)

View file

@ -75,13 +75,13 @@ class ResultMessage:
"""Result message with cost and usage information."""
subtype: str
cost_usd: float
cost_usd: float | None
duration_ms: int
duration_api_ms: int
is_error: bool
num_turns: int
session_id: str
total_cost_usd: float
total_cost_usd: float | None
usage: dict[str, Any] | None = None
result: str | None = None

25
test_fix.py Normal file
View file

@ -0,0 +1,25 @@
#!/usr/bin/env python3
"""Test script to verify the cost_usd fix works."""
import anyio
from claude_code_sdk import process_query, ClaudeCodeOptions
async def main():
"""Test the SDK with a simple query."""
print("Testing claude-code-sdk with cost_usd fix...")
try:
# Simple test query
async for message in process_query(
"What is 2+2?",
ClaudeCodeOptions()
):
print(f"Message type: {type(message).__name__}")
print(f"Message: {message}")
print("-" * 40)
except Exception as e:
print(f"Error: {type(e).__name__}: {e}")
raise
if __name__ == "__main__":
anyio.run(main)