Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions src/anthropic/_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,24 @@ def _transform_file(file: FileTypes) -> HttpxFileTypes:
return file

if is_tuple_t(file):
return (file[0], read_file_content(file[1]), *file[2:])
return cast(HttpxFileTypes, _transform_file_tuple(file))

raise TypeError(f"Expected file types input to be a FileContent type or to be a tuple")


def _transform_file_tuple(file: tuple[object, ...]) -> tuple[object, ...]:
# Copy mutable entries in file tuples to prevent shared state.
# File tuples can be: (filename, content), (filename, content, content_type),
# or (filename, content, content_type, headers) where headers is a mutable Mapping.
result: list[object] = [file[0], read_file_content(file[1])]
for item in file[2:]:
if isinstance(item, dict):
result.append(dict(item))
else:
result.append(item)
return tuple(result)


def read_file_content(file: FileContent) -> HttpxFileContent:
if isinstance(file, os.PathLike):
return pathlib.Path(file).read_bytes()
Expand Down Expand Up @@ -113,7 +126,7 @@ async def _async_transform_file(file: FileTypes) -> HttpxFileTypes:
return file

if is_tuple_t(file):
return (file[0], await async_read_file_content(file[1]), *file[2:])
return cast(HttpxFileTypes, _transform_file_tuple((file[0], await async_read_file_content(file[1]), *file[2:])))

raise TypeError(f"Expected file types input to be a FileContent type or to be a tuple")

Expand Down
17 changes: 14 additions & 3 deletions src/anthropic/_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import httpx

from ._utils import is_dict, extract_type_var_from_base
from ._exceptions import APIConnectionError

if TYPE_CHECKING:
from ._client import Anthropic, AsyncAnthropic
Expand Down Expand Up @@ -72,7 +73,12 @@ def __iter__(self) -> Iterator[_T]:
yield item

def _iter_events(self) -> Iterator[ServerSentEvent]:
yield from self._decoder.iter_bytes(self.response.iter_bytes())
try:
yield from self._decoder.iter_bytes(self.response.iter_bytes())
except httpx.TimeoutException:
raise
except httpx.TransportError as err:
raise APIConnectionError(request=self.response.request) from err

def __stream__(self) -> Iterator[_T]:
cast_to = cast(Any, self._cast_to)
Expand Down Expand Up @@ -226,8 +232,13 @@ async def __aiter__(self) -> AsyncIterator[_T]:
yield item

async def _iter_events(self) -> AsyncIterator[ServerSentEvent]:
async for sse in self._decoder.aiter_bytes(self.response.aiter_bytes()):
yield sse
try:
async for sse in self._decoder.aiter_bytes(self.response.aiter_bytes()):
yield sse
except httpx.TimeoutException:
raise
except httpx.TransportError as err:
raise APIConnectionError(request=self.response.request) from err

async def __stream__(self) -> AsyncIterator[_T]:
cast_to = cast(Any, self._cast_to)
Expand Down
75 changes: 72 additions & 3 deletions tests/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from anthropic import Anthropic, AsyncAnthropic
from anthropic._streaming import Stream, AsyncStream, ServerSentEvent
from anthropic._exceptions import APIStatusError
from anthropic._exceptions import APIStatusError, APIConnectionError

_T = TypeVar("_T")

Expand Down Expand Up @@ -238,6 +238,60 @@ def body() -> Iterator[bytes]:
assert "Overloaded" in str(exc_info.value)


@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
async def test_mid_stream_transport_error_wrapped(sync: bool, client: Anthropic, async_client: AsyncAnthropic) -> None:
"""Mid-stream httpx.RemoteProtocolError is wrapped as APIConnectionError with __cause__ preserved."""

def body() -> Iterator[bytes]:
yield b"event: completion\n"
yield b'data: {"type":"message","content":[]}\n'
yield b"\n"
raise httpx.RemoteProtocolError("peer closed connection without sending complete message body")

request = httpx.Request("GET", "https://api.anthropic.com/v1/messages/stream")
iterator = make_event_iterator_with_request(
content=body(), sync=sync, client=client, async_client=async_client, request=request
)

# First event should succeed
sse = await iter_next(iterator)
assert sse.event == "completion"

# Second read should raise APIConnectionError, not the bare httpx error
with pytest.raises(APIConnectionError) as exc_info:
await iter_next(iterator)

assert isinstance(exc_info.value.__cause__, httpx.RemoteProtocolError)


@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
async def test_mid_stream_read_timeout_passes_through(
sync: bool,
client: Anthropic,
async_client: AsyncAnthropic,
) -> None:
"""Mid-stream httpx.ReadTimeout passes through unchanged (not double-wrapped)."""

def body() -> Iterator[bytes]:
yield b"event: completion\n"
yield b'data: {"type":"message","content":[]}\n'
yield b"\n"
raise httpx.ReadTimeout("read timeout")

request = httpx.Request("GET", "https://api.anthropic.com/v1/messages/stream")
iterator = make_event_iterator_with_request(
content=body(), sync=sync, client=client, async_client=async_client, request=request
)

# First event should succeed
sse = await iter_next(iterator)
assert sse.event == "completion"

# ReadTimeout should pass through unchanged
with pytest.raises(httpx.ReadTimeout):
await iter_next(iterator)


def test_isinstance_check(client: Anthropic, async_client: AsyncAnthropic) -> None:
async_stream = AsyncStream(cast_to=object, client=async_client, response=httpx.Response(200, content=b"foo"))
assert isinstance(async_stream, AsyncStream)
Expand Down Expand Up @@ -269,12 +323,27 @@ def make_event_iterator(
sync: bool,
client: Anthropic,
async_client: AsyncAnthropic,
) -> Iterator[ServerSentEvent] | AsyncIterator[ServerSentEvent]:
return make_event_iterator_with_request(
content=content, sync=sync, client=client, async_client=async_client, request=None
)


def make_event_iterator_with_request(
content: Iterator[bytes],
*,
sync: bool,
client: Anthropic,
async_client: AsyncAnthropic,
request: httpx.Request | None,
) -> Iterator[ServerSentEvent] | AsyncIterator[ServerSentEvent]:
if sync:
return Stream(cast_to=object, client=client, response=httpx.Response(200, content=content))._iter_events()
return Stream(
cast_to=object, client=client, response=httpx.Response(200, content=content, request=request)
)._iter_events()

return AsyncStream(
cast_to=object, client=async_client, response=httpx.Response(200, content=to_aiter(content))
cast_to=object, client=async_client, response=httpx.Response(200, content=to_aiter(content), request=request)
)._iter_events()


Expand Down