Skip to content

Commit ceb584c

Browse files
committed
Applying review comments
1 parent 8810db5 commit ceb584c

File tree

3 files changed

+122
-114
lines changed

3 files changed

+122
-114
lines changed

ptbcontrib/aiohttp_request/aiohttprequest.py

Lines changed: 54 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,17 @@
1818
# along with this program. If not, see [http://www.gnu.org/licenses/].
1919
"""This module contains methods to make POST and GET requests using the aiohttp library."""
2020
import asyncio
21+
import logging
2122
from typing import Any, Optional, Union
2223

2324
import aiohttp
2425
import yarl
25-
from telegram._utils.logging import get_logger
26+
from telegram._utils.defaultvalue import DefaultValue
2627
from telegram._utils.types import ODVInput
2728
from telegram.error import NetworkError, TimedOut
2829
from telegram.request import BaseRequest, RequestData
2930

30-
_LOGGER = get_logger(__name__, "AiohttpRequest")
31+
_LOGGER = logging.getLogger("AiohttpRequest")
3132

3233

3334
class AiohttpRequest(BaseRequest):
@@ -43,13 +44,13 @@ class AiohttpRequest(BaseRequest):
4344
Note:
4445
:paramref:`media_total_timeout` will still be applied if a file is send, so be sure
4546
to also set it to an appropriate value.
46-
PylintDoesntAllowMeToWriteTODO Should I warn about this?
4747
media_total_timeout (:obj:`float` | :obj:`None`, optional): This overrides the total
4848
timeout with requests that upload media/files. Defaults to ``20`` seconds.
4949
proxy (:obj:`str` | `yarl.URL``, optional): The URL to a proxy server, aiohttp supports
5050
plain HTTP proxies and HTTP proxies that can be upgraded to HTTPS via the HTTP
5151
CONNECT method. See the docs of aiohttp: https://docs.aiohttp.org/en/stable/
5252
client_advanced.html#aiohttp-client-proxy-support.
53+
proxy_auth (``aiohttp.BasicAuth``, optional): Proxy authorization, see :paramref:`proxy`.
5354
trust_env (:obj:`bool`, optional): In order to read proxy environmental variables, see the
5455
docs of aiohttp: https://docs.aiohttp.org/en/stable/client_advanced.html
5556
#aiohttp-client-proxy-support.
@@ -61,25 +62,28 @@ class AiohttpRequest(BaseRequest):
6162
This parameter is intended for advanced users that want to fine-tune the behavior
6263
of the underlying ``aiohttp`` clientSession. The values passed here will override
6364
all the defaults set by ``python-telegram-bot`` and all other parameters passed to
64-
:class:`ClientSession`. The only exception is the :paramref:`media_write_timeout`
65+
:class:`ClientSession`. The only exception is the :paramref:`media_total_timeout`
6566
parameter, which is not passed to the client constructor.
6667
No runtime warnings will be issued about parameters that are overridden in this
6768
way.
6869
6970
"""
7071

71-
__slots__ = ("_session", "_session_kwargs", "_media_total_timeout")
72+
__slots__ = ("_session", "_session_kwargs", "_media_total_timeout", "_connection_pool_size")
7273

7374
def __init__( # pylint: disable=too-many-arguments,too-many-positional-arguments
7475
self,
7576
connection_pool_size: int = 1,
7677
client_timeout: Optional[aiohttp.ClientTimeout] = None,
77-
media_total_timeout: Optional[float] = 20.0,
78+
media_total_timeout: Optional[float] = 30.0,
7879
proxy: Optional[Union[str, yarl.URL]] = None,
80+
proxy_auth: Optional[aiohttp.BasicAuth] = None,
7981
trust_env: Optional[bool] = None,
8082
aiohttp_kwargs: Optional[dict[str, Any]] = None,
8183
):
8284
self._media_total_timeout = media_total_timeout
85+
# this needs to be saved in case of initialize gets a closed session
86+
self._connection_pool_size = connection_pool_size
8387
timeout = (
8488
client_timeout
8589
if client_timeout
@@ -103,6 +107,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-positional-argument
103107
"timeout": timeout,
104108
"connector": conn,
105109
"proxy": proxy,
110+
"proxy_auth": proxy_auth,
106111
"trust_env": trust_env,
107112
**(aiohttp_kwargs or {}),
108113
}
@@ -113,8 +118,8 @@ def __init__( # pylint: disable=too-many-arguments,too-many-positional-argument
113118
def read_timeout(self) -> Optional[float]:
114119
"""See :attr:`BaseRequest.read_timeout`.
115120
116-
This makes not a lot of sense to implement since there is no actual read_timeout.
117-
But what can I do.
121+
aiohttp does not have a read timeout. Instead the total timeout for a request (including
122+
connection establishment, request sending and response reading) is returned.
118123
"""
119124
return self._session.timeout.total
120125

@@ -124,6 +129,14 @@ def _build_client(self) -> aiohttp.ClientSession:
124129
async def initialize(self) -> None:
125130
"""See :meth:`BaseRequest.initialize`."""
126131
if self._session.closed:
132+
# this means the TCPConnector has been closed, so we need to recreate it
133+
try:
134+
loop = asyncio.get_running_loop()
135+
except RuntimeError:
136+
loop = asyncio.get_event_loop()
137+
138+
conn = aiohttp.TCPConnector(limit=self._connection_pool_size, loop=loop)
139+
self._session_kwargs["connector"] = conn
127140
self._session = self._build_client()
128141

129142
async def shutdown(self) -> None:
@@ -134,7 +147,6 @@ async def shutdown(self) -> None:
134147

135148
await self._session.close()
136149

137-
# pylint: disable=unused-argument
138150
async def do_request( # pylint: disable=too-many-arguments,too-many-positional-arguments
139151
self,
140152
url: str,
@@ -145,7 +157,18 @@ async def do_request( # pylint: disable=too-many-arguments,too-many-positional-
145157
connect_timeout: ODVInput[float] = BaseRequest.DEFAULT_NONE,
146158
pool_timeout: ODVInput[float] = BaseRequest.DEFAULT_NONE,
147159
) -> tuple[int, bytes]:
148-
"""See :meth:`BaseRequest.do_request`."""
160+
"""See :meth:`BaseRequest.do_request`.
161+
162+
Since aiohttp has differen't timeouts, the params were mapped.
163+
164+
* :paramref:`pool_timeout` is mapped to :attr`~aiohttp.ClientTimeout.connect`
165+
* :paramref:`connect_timeout` is mapped to :attr`~aiohttp.ClientTimeout.sock_connect`
166+
* :paramref:`read_timeout` is mapped to :attr`~aiohttp.ClientTimeout.sock_read`
167+
* :paramref:`write_timeout` is mapped to :attr`~aiohttp.ClientTimeout.ceil_threshold`
168+
169+
The :attr`~aiohttp.ClientTimeout.total` timeout is not changed since it also includes
170+
response reading. You can only change them when initializing the class.
171+
"""
149172
if self._session.closed:
150173
raise RuntimeError("This AiohttpRequest is not initialized!")
151174

@@ -161,20 +184,27 @@ async def do_request( # pylint: disable=too-many-arguments,too-many-positional-
161184
filename=request_data.multipart_data[field_name][0],
162185
)
163186

164-
# I dont think it makes sense to support the timeout params.
165-
# PylintDoesntAllowMeToWriteTOLiDO if no one complains in initial PR
166-
# raise warnings if they are passed
167-
168-
timeout = (
169-
aiohttp.ClientTimeout(
170-
total=self._media_total_timeout,
171-
connect=self._session_kwargs["timeout"].connect,
172-
sock_read=self._session_kwargs["timeout"].sock_read,
173-
sock_connect=self._session_kwargs["timeout"].sock_connect,
174-
ceil_threshold=self._session_kwargs["timeout"].ceil_threshold,
175-
)
176-
if request_data and request_data.contains_files
177-
else self._session_kwargs["timeout"]
187+
# If user did not specify timeouts (for e.g. in a bot method), use the default ones when we
188+
# created this instance.
189+
if isinstance(read_timeout, DefaultValue):
190+
read_timeout = self._session_kwargs["timeout"].sock_read
191+
if isinstance(connect_timeout, DefaultValue):
192+
connect_timeout = self._session_kwargs["timeout"].sock_connect
193+
if isinstance(pool_timeout, DefaultValue):
194+
pool_timeout = self._session_kwargs["timeout"].connect
195+
if isinstance(write_timeout, DefaultValue):
196+
write_timeout = self._session_kwargs["timeout"].ceil_threshold
197+
198+
timeout = aiohttp.ClientTimeout(
199+
total=(
200+
self._media_total_timeout
201+
if (request_data and request_data.contains_files)
202+
else self._session_kwargs["timeout"].total
203+
),
204+
connect=pool_timeout,
205+
sock_read=read_timeout,
206+
sock_connect=connect_timeout,
207+
ceil_threshold=write_timeout,
178208
)
179209

180210
try:
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
python-telegram-bot>=21.1
1+
python-telegram-bot>=22.1
22
aiohttp[speedups]>=3.11

tests/test_aiohttp_request.py

Lines changed: 67 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -112,35 +112,28 @@ class TestRequest:
112112
def _reset(self):
113113
self.test_flag = None
114114

115-
# def test_aiohttp_kwargs(self, monkeypatch):
116-
# self.test_flag = {}
117-
#
118-
# orig_init = aiohttp.ClientSession.__init__
119-
#
120-
# class Session(aiohttp.ClientSession):
121-
# def __init__(*args, **kwargs):
122-
# orig_init(*args, **kwargs)
123-
# self.test_flag["args"] = args
124-
# self.test_flag["kwargs"] = kwargs
125-
#
126-
# monkeypatch.setattr(aiohttp, "ClientSession", Session)
127-
#
128-
# AiohttpRequest(
129-
# client_timeout=aiohttp.ClientTimeout(total=1.0),
130-
# connection_pool_size=42,
131-
# httpx_kwargs={
132-
# "timeout": httpx.Timeout(7),
133-
# "limits": httpx.Limits(max_connections=7),
134-
# "http1": True,
135-
# "verify": False,
136-
# },
137-
# )
138-
# kwargs = self.test_flag["kwargs"]
139-
#
140-
# assert kwargs["timeout"].connect == 7
141-
# assert kwargs["limits"].max_connections == 7
142-
# assert kwargs["http1"] is True
143-
# assert kwargs["verify"] is False
115+
def test_aiohttp_kwargs(self, monkeypatch):
116+
self.test_flag = {}
117+
118+
orig_init = aiohttp.ClientSession.__init__
119+
120+
class Session(aiohttp.ClientSession):
121+
def __init__(*args, **kwargs):
122+
orig_init(*args, **kwargs)
123+
self.test_flag["args"] = args
124+
self.test_flag["kwargs"] = kwargs
125+
126+
monkeypatch.setattr(aiohttp, "ClientSession", Session)
127+
128+
AiohttpRequest(
129+
client_timeout=aiohttp.ClientTimeout(total=1.0),
130+
aiohttp_kwargs={
131+
"timeout": aiohttp.ClientTimeout(total=40.0),
132+
},
133+
)
134+
kwargs = self.test_flag["kwargs"]
135+
136+
assert kwargs["timeout"].total == 40
144137

145138
async def test_context_manager(self, monkeypatch):
146139
async def initialize():
@@ -386,41 +379,26 @@ class TestAiohttpRequest:
386379
def _reset(self):
387380
self.test_flag = None
388381

389-
# def test_init(self, monkeypatch):
390-
# @dataclass
391-
# class Session:
392-
# timeout: object
393-
# proxy: object
394-
# limits: object
395-
# http1: object
396-
# http2: object
397-
# transport: object = None
398-
#
399-
# monkeypatch.setattr(aiohttp, "ClientSession", Session)
400-
#
401-
# request = AiohttpRequest()
402-
# assert request._client.timeout ==
403-
# httpx.Timeout(connect=5.0, read=5.0, write=5.0, pool=1.0)
404-
# assert request._client.proxy is None
405-
# assert request._client.limits == httpx.Limits(
406-
# max_connections=1, max_keepalive_connections=1
407-
# )
408-
# assert request._client.http1 is True
409-
# assert not request._client.http2
410-
#
411-
# request = AiohttpRequest(
412-
# connection_pool_size=42,
413-
# proxy="proxy",
414-
# connect_timeout=43,
415-
# read_timeout=44,
416-
# write_timeout=45,
417-
# pool_timeout=46,
418-
# )
419-
# assert request._client.proxy == "proxy"
420-
# assert request._client.limits == httpx.Limits(
421-
# max_connections=42, max_keepalive_connections=42
422-
# )
423-
# assert request._client.timeout == httpx.Timeout(connect=43, read=44, write=45, pool=46)
382+
def test_init(self, monkeypatch):
383+
384+
request = AiohttpRequest()
385+
assert request._session.timeout == aiohttp.ClientTimeout(total=15)
386+
assert request._session._default_proxy is None
387+
assert request._session._version is aiohttp.HttpVersion11
388+
389+
request = AiohttpRequest(
390+
connection_pool_size=42,
391+
client_timeout=aiohttp.ClientTimeout(total=25),
392+
media_total_timeout=200,
393+
proxy="proxy",
394+
proxy_auth=aiohttp.BasicAuth("user", "pass"),
395+
trust_env=True,
396+
)
397+
assert request._session._default_proxy == "proxy"
398+
assert request._session._default_proxy_auth == aiohttp.BasicAuth("user", "pass")
399+
assert request._session.timeout == aiohttp.ClientTimeout(total=25)
400+
assert request._media_total_timeout == 200
401+
assert request._session._trust_env
424402

425403
async def test_multiple_inits_and_shutdowns(self, monkeypatch):
426404
self.test_flag = defaultdict(int)
@@ -507,31 +485,31 @@ async def make_assertion(_, **kwargs):
507485

508486
assert self.test_flag
509487

510-
# async def test_do_request_manual_timeouts(self, monkeypatch, aiohttp_request):
511-
# default_timeouts = aiohttp.ClientTimeout(total=42, connect=43,
512-
# sock_read=44, sock_connect=45, ceil_threshold=46)
513-
# manual_timeouts = aiohttp.ClientTimeout(total=52, connect=53,
514-
# sock_read=54, sock_connect=55, ceil_threshold=56)
515-
#
516-
# async def make_assertion(_, **kwargs):
517-
# print(kwargs.get("timeout"))
518-
# self.test_flag = kwargs.get("timeout") == manual_timeouts
519-
# return Response()
520-
#
521-
# async with AiohttpRequest(
522-
# client_timeout=default_timeouts
523-
# ) as aiohttp_request_ctx:
524-
# monkeypatch.setattr(aiohttp.ClientSession, "request", make_assertion)
525-
# await aiohttp_request_ctx.do_request(
526-
# method="GET",
527-
# url="URL",
528-
# connect_timeout=manual_timeouts.connect,
529-
# read_timeout=manual_timeouts.total,
530-
# write_timeout=manual_timeouts.sock_read,
531-
# pool_timeout=manual_timeouts.sock_connect,
532-
# )
533-
#
534-
# assert self.test_flag
488+
async def test_do_request_manual_timeouts(self, monkeypatch, aiohttp_request):
489+
default_timeouts = aiohttp.ClientTimeout(
490+
total=42, connect=43, sock_read=44, sock_connect=45, ceil_threshold=46
491+
)
492+
manual_timeouts = aiohttp.ClientTimeout(
493+
total=42, connect=53, sock_read=54, sock_connect=55, ceil_threshold=56
494+
)
495+
496+
async def make_assertion(_, **kwargs):
497+
print(kwargs.get("timeout"))
498+
self.test_flag = kwargs.get("timeout") == manual_timeouts
499+
return Response()
500+
501+
async with AiohttpRequest(client_timeout=default_timeouts) as aiohttp_request_ctx:
502+
monkeypatch.setattr(aiohttp.ClientSession, "request", make_assertion)
503+
await aiohttp_request_ctx.do_request(
504+
method="GET",
505+
url="URL",
506+
connect_timeout=manual_timeouts.sock_connect,
507+
read_timeout=manual_timeouts.sock_read,
508+
write_timeout=manual_timeouts.ceil_threshold,
509+
pool_timeout=manual_timeouts.connect,
510+
)
511+
512+
assert self.test_flag
535513

536514
async def test_do_request_params_no_data(self, monkeypatch, aiohttp_request):
537515
async def make_assertion(self, **kwargs):

0 commit comments

Comments
 (0)