Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES/6485.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed ``Response.text`` when body is a ``Payload`` -- by :user:`Dreamsorcerer`.
17 changes: 17 additions & 0 deletions aiohttp/multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,8 @@ def filename(self) -> Optional[str]:

@payload_type(BodyPartReader, order=Order.try_first)
class BodyPartReaderPayload(Payload):
_value: BodyPartReader

def __init__(self, value: BodyPartReader, *args: Any, **kwargs: Any) -> None:
super().__init__(value, *args, **kwargs)

Expand All @@ -573,6 +575,9 @@ def __init__(self, value: BodyPartReader, *args: Any, **kwargs: Any) -> None:
if params:
self.set_content_disposition("attachment", True, **params)

def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
raise TypeError("Unable to decode.")

async def write(self, writer: Any) -> None:
field = self._value
chunk = await field.read_chunk(size=2**16)
Expand Down Expand Up @@ -788,6 +793,8 @@ async def _maybe_release_last_part(self) -> None:
class MultipartWriter(Payload):
"""Multipart body writer."""

_value: None

def __init__(self, subtype: str = "mixed", boundary: Optional[str] = None) -> None:
boundary = boundary if boundary is not None else uuid.uuid4().hex
# The underlying Payload API demands a str (utf-8), not bytes,
Expand Down Expand Up @@ -972,6 +979,16 @@ def size(self) -> Optional[int]:
total += 2 + len(self._boundary) + 4 # b'--'+self._boundary+b'--\r\n'
return total

def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
return "".join(
"--"
+ self.boundary
+ "\n"
+ part._binary_headers.decode(encoding, errors)
+ part.decode()
for part, _e, _te in self._parts
)

async def write(self, writer: Any, close_boundary: bool = True) -> None:
"""Write body."""
for part, encoding, te_encoding in self._parts:
Expand Down
36 changes: 34 additions & 2 deletions aiohttp/payload.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,13 @@ def set_content_disposition(
disptype, quote_fields=quote_fields, _charset=_charset, params=params
)

@abstractmethod
def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
"""Return string representation of the value.

This is named decode() to allow compatibility with bytes objects.
"""

@abstractmethod
async def write(self, writer: AbstractStreamWriter) -> None:
"""Write payload.
Expand All @@ -216,6 +223,8 @@ async def write(self, writer: AbstractStreamWriter) -> None:


class BytesPayload(Payload):
_value: bytes

def __init__(
self, value: Union[bytes, bytearray, memoryview], *args: Any, **kwargs: Any
) -> None:
Expand All @@ -241,6 +250,9 @@ def __init__(
source=self,
)

def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
return self._value.decode(encoding, errors)

async def write(self, writer: AbstractStreamWriter) -> None:
await writer.write(self._value)

Expand Down Expand Up @@ -281,7 +293,7 @@ def __init__(self, value: IO[str], *args: Any, **kwargs: Any) -> None:


class IOBasePayload(Payload):
_value: IO[Any]
_value: io.IOBase

def __init__(
self, value: IO[Any], disposition: str = "attachment", *args: Any, **kwargs: Any
Expand All @@ -305,9 +317,12 @@ async def write(self, writer: AbstractStreamWriter) -> None:
finally:
await loop.run_in_executor(None, self._value.close)

def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
return "".join(r.decode(encoding, errors) for r in self._value.readlines())


class TextIOPayload(IOBasePayload):
_value: TextIO
_value: io.TextIOBase

def __init__(
self,
Expand Down Expand Up @@ -343,6 +358,9 @@ def size(self) -> Optional[int]:
except OSError:
return None

def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
return self._value.read()

async def write(self, writer: AbstractStreamWriter) -> None:
loop = asyncio.get_event_loop()
try:
Expand All @@ -360,15 +378,22 @@ async def write(self, writer: AbstractStreamWriter) -> None:


class BytesIOPayload(IOBasePayload):
_value: io.BytesIO

@property
def size(self) -> int:
position = self._value.tell()
end = self._value.seek(0, os.SEEK_END)
self._value.seek(position)
return end - position

def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
return self._value.read().decode(encoding, errors)


class BufferedReaderPayload(IOBasePayload):
_value: io.BufferedIOBase

@property
def size(self) -> Optional[int]:
try:
Expand All @@ -378,6 +403,9 @@ def size(self) -> Optional[int]:
# io.BufferedReader(io.BytesIO(b'data'))
return None

def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
return self._value.read().decode(encoding, errors)


class JsonPayload(BytesPayload):
def __init__(
Expand Down Expand Up @@ -412,6 +440,7 @@ def __init__(

class AsyncIterablePayload(Payload):
_iter: Optional[_AsyncIterator] = None
_value: _AsyncIterable

def __init__(self, value: _AsyncIterable, *args: Any, **kwargs: Any) -> None:
if not isinstance(value, AsyncIterable):
Expand Down Expand Up @@ -439,6 +468,9 @@ async def write(self, writer: AbstractStreamWriter) -> None:
except StopAsyncIteration:
self._iter = None

def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
raise TypeError("Unable to decode.")


class StreamReaderPayload(AsyncIterablePayload):
def __init__(self, value: StreamReader, *args: Any, **kwargs: Any) -> None:
Expand Down
28 changes: 11 additions & 17 deletions aiohttp/web_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ class StreamResponse(BaseClass, HeadersMixin, CookieMixin):
"__weakref__",
)

_body: Union[None, bytes, bytearray, Payload]

def __init__(
self,
*,
Expand Down Expand Up @@ -499,7 +501,6 @@ def __eq__(self, other: object) -> bool:

class Response(StreamResponse):
__slots__ = (
"_body_payload",
"_compressed_body",
"_zlib_executor_size",
"_zlib_executor",
Expand Down Expand Up @@ -580,21 +581,17 @@ def body(self) -> Optional[Union[bytes, Payload]]:
return self._body

@body.setter
def body(self, body: bytes) -> None:
def body(self, body: Any) -> None:
if body is None:
self._body: Optional[bytes] = None
self._body_payload: bool = False
self._body = None
elif isinstance(body, (bytes, bytearray)):
self._body = body
self._body_payload = False
else:
try:
self._body = body = payload.PAYLOAD_REGISTRY.get(body)
except payload.LookupError:
raise ValueError("Unsupported body type %r" % type(body))

self._body_payload = True

headers = self._headers

# set content-type
Expand Down Expand Up @@ -625,7 +622,6 @@ def text(self, text: str) -> None:
self.charset = "utf-8"

self._body = text.encode(self.charset)
self._body_payload = False
self._compressed_body = None

@property
Expand All @@ -639,7 +635,7 @@ def content_length(self) -> Optional[int]:
if self._compressed_body is not None:
# Return length of the compressed body
return len(self._compressed_body)
elif self._body_payload:
elif isinstance(self._body, Payload):
# A payload without content length, or a compressed payload
return None
elif self._body is not None:
Expand All @@ -664,9 +660,8 @@ async def write_eof(self, data: bytes = b"") -> None:
if body is not None:
if self._must_be_empty_body:
await super().write_eof()
elif self._body_payload:
payload = cast(Payload, body)
await payload.write(self._payload_writer)
elif isinstance(self._body, Payload):
await self._body.write(self._payload_writer)
await super().write_eof()
else:
await super().write_eof(cast(bytes, body))
Expand All @@ -678,10 +673,9 @@ async def _start(self, request: "BaseRequest") -> AbstractStreamWriter:
if hdrs.CONTENT_LENGTH in self._headers:
del self._headers[hdrs.CONTENT_LENGTH]
elif not self._chunked and hdrs.CONTENT_LENGTH not in self._headers:
if self._body_payload:
size = cast(Payload, self._body).size
if size is not None:
self._headers[hdrs.CONTENT_LENGTH] = str(size)
if isinstance(self._body, Payload):
if self._body.size is not None:
self._headers[hdrs.CONTENT_LENGTH] = str(self._body.size)
else:
body_len = len(self._body) if self._body else "0"
# https://www.rfc-editor.org/rfc/rfc9110.html#section-8.6-7
Expand All @@ -693,7 +687,7 @@ async def _start(self, request: "BaseRequest") -> AbstractStreamWriter:
return await super()._start(request)

async def _do_start_compression(self, coding: ContentCoding) -> None:
if self._body_payload or self._chunked:
if self._chunked or isinstance(self._body, Payload):
return await super()._do_start_compression(coding)

if coding != ContentCoding.identity:
Expand Down
3 changes: 3 additions & 0 deletions tests/test_payload.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ def registry() -> Iterator[payload.PayloadRegistry]:


class Payload(payload.Payload):
def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
assert False

async def write(self, writer: Any) -> None:
pass

Expand Down
48 changes: 46 additions & 2 deletions tests/test_web_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
import collections.abc
import datetime
import gzip
import io
import json
import re
import weakref
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Optional
from typing import Any, AsyncIterator, Optional
from unittest import mock

import aiosignal
Expand All @@ -17,7 +18,8 @@
from aiohttp import HttpVersion, HttpVersion10, HttpVersion11, hdrs
from aiohttp.helpers import ETag
from aiohttp.http_writer import StreamWriter, _serialize_headers
from aiohttp.payload import BytesPayload
from aiohttp.multipart import BodyPartReader, MultipartWriter
from aiohttp.payload import BytesPayload, StringPayload
from aiohttp.test_utils import make_mocked_coro, make_mocked_request
from aiohttp.web import ContentCoding, Response, StreamResponse, json_response

Expand Down Expand Up @@ -986,6 +988,48 @@ def test_assign_nonstr_text() -> None:
assert 4 == resp.content_length


mpwriter = MultipartWriter(boundary="x")
mpwriter.append_payload(StringPayload("test"))


async def async_iter() -> AsyncIterator[str]:
yield "foo" # pragma: no cover


class CustomIO(io.IOBase):
def __init__(self):
self._lines = [b"", b"", b"test"]

def read(self, size: int = -1) -> bytes:
return self._lines.pop()


@pytest.mark.parametrize(
"payload,expected",
(
("test", "test"),
(CustomIO(), "test"),
(io.StringIO("test"), "test"),
(io.TextIOWrapper(io.BytesIO(b"test")), "test"),
(io.BytesIO(b"test"), "test"),
(io.BufferedReader(io.BytesIO(b"test")), "test"),
(async_iter(), None),
(BodyPartReader("x", CIMultiDictProxy(CIMultiDict()), mock.Mock()), None),
(
mpwriter,
"--x\nContent-Type: text/plain; charset=utf-8\r\nContent-Length: 4\r\n\r\ntest",
),
),
)
def test_payload_body_get_text(payload, expected: Optional[str]) -> None:
resp = Response(body=payload)
if expected is None:
with pytest.raises(TypeError):
resp.text
else:
assert resp.text == expected


def test_response_set_content_length() -> None:
resp = Response()
with pytest.raises(RuntimeError):
Expand Down