11from ssl import SSLContext
2- from typing import List , Optional , Tuple , Union
2+ from typing import List , Optional , Tuple
33
44from .._backends .auto import AsyncLock , AsyncSocketStream , AutoBackend
55from .._types import URL , Headers , Origin , TimeoutDict
1010 ConnectionState ,
1111 NewConnectionRequired ,
1212)
13- from .http2 import AsyncHTTP2Connection
14- from .http11 import AsyncHTTP11Connection
13+ from .http import AsyncBaseHTTPConnection
1514
1615logger = get_logger (__name__ )
1716
@@ -32,7 +31,7 @@ def __init__(
3231 if self .http2 :
3332 self .ssl_context .set_alpn_protocols (["http/1.1" , "h2" ])
3433
35- self .connection : Union [ None , AsyncHTTP11Connection , AsyncHTTP2Connection ] = None
34+ self .connection : Optional [ AsyncBaseHTTPConnection ] = None
3635 self .is_http11 = False
3736 self .is_http2 = False
3837 self .connect_failed = False
@@ -110,11 +109,15 @@ def _create_connection(self, socket: AsyncSocketStream) -> None:
110109 "create_connection socket=%r http_version=%r" , socket , http_version
111110 )
112111 if http_version == "HTTP/2" :
112+ from .http2 import AsyncHTTP2Connection
113+
113114 self .is_http2 = True
114115 self .connection = AsyncHTTP2Connection (
115116 socket = socket , backend = self .backend , ssl_context = self .ssl_context
116117 )
117118 else :
119+ from .http11 import AsyncHTTP11Connection
120+
118121 self .is_http11 = True
119122 self .connection = AsyncHTTP11Connection (
120123 socket = socket , ssl_context = self .ssl_context
@@ -126,7 +129,7 @@ def state(self) -> ConnectionState:
126129 return ConnectionState .CLOSED
127130 elif self .connection is None :
128131 return ConnectionState .PENDING
129- return self .connection .state
132+ return self .connection .get_state ()
130133
131134 def is_connection_dropped (self ) -> bool :
132135 return self .connection is not None and self .connection .is_connection_dropped ()
@@ -138,9 +141,8 @@ def mark_as_ready(self) -> None:
138141 async def start_tls (self , hostname : bytes , timeout : TimeoutDict = None ) -> None :
139142 if self .connection is not None :
140143 logger .trace ("start_tls hostname=%r timeout=%r" , hostname , timeout )
141- await self .connection .start_tls (hostname , timeout )
144+ self . socket = await self .connection .start_tls (hostname , timeout )
142145 logger .trace ("start_tls complete hostname=%r timeout=%r" , hostname , timeout )
143- self .socket = self .connection .socket
144146
145147 async def aclose (self ) -> None :
146148 async with self .request_lock :
0 commit comments