Skip to content

Commit 51f2be4

Browse files
authored
feat(s2n-quic-tls): record server's ConnectionInfo in s2n-quic's TLS Connection (#2906)
1 parent 2158845 commit 51f2be4

17 files changed

Lines changed: 248 additions & 21 deletions

File tree

quic/s2n-quic-core/src/crypto/tls.rs

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
// SPDX-License-Identifier: Apache-2.0
33

4+
use crate::path::{LocalAddress, RemoteAddress};
45
#[cfg(feature = "alloc")]
56
use alloc::vec::Vec;
67
#[cfg(feature = "alloc")]
78
pub use bytes::{Bytes, BytesMut};
8-
use core::{any::Any, fmt::Debug};
9+
use core::{any::Any, fmt::Debug, net::SocketAddr};
910
use zerocopy::{FromBytes, IntoBytes, Unaligned};
1011

1112
mod error;
@@ -23,6 +24,25 @@ pub mod slow_tls;
2324
#[cfg(feature = "std")]
2425
pub mod offload;
2526

27+
/// Holds connection address information for establishing a TLS session.
28+
/// This includes both the local and remote addresses.
29+
#[derive(Debug, Clone, Copy)]
30+
#[non_exhaustive]
31+
pub struct ConnectionInfo {
32+
pub local_address: SocketAddr,
33+
pub remote_address: SocketAddr,
34+
}
35+
36+
impl ConnectionInfo {
37+
#[doc(hidden)]
38+
pub fn new(local_address: LocalAddress, remote_address: RemoteAddress) -> Self {
39+
Self {
40+
local_address: local_address.into(),
41+
remote_address: remote_address.into(),
42+
}
43+
}
44+
}
45+
2646
/// Holds all application parameters which are exchanged within the TLS handshake.
2747
#[derive(Debug)]
2848
pub struct ApplicationParameters<'a> {
@@ -206,6 +226,7 @@ pub trait Endpoint: 'static + Sized + Send {
206226
fn new_server_session<Params: s2n_codec::EncoderValue>(
207227
&mut self,
208228
transport_parameters: &Params,
229+
connection_info: ConnectionInfo,
209230
) -> Self::Session;
210231

211232
fn new_client_session<Params: s2n_codec::EncoderValue>(

quic/s2n-quic-core/src/crypto/tls/null.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ impl<T: Send + Clone + 'static + std::fmt::Debug> crypto::tls::Endpoint for Endp
8888
fn new_server_session<Params: s2n_codec::EncoderValue>(
8989
&mut self,
9090
transport_parameters: &Params,
91+
_connection_info: tls::ConnectionInfo,
9192
) -> Self::Session {
9293
let params = transport_parameters.encode_to_vec().into();
9394
Session::Server(server::TlsSession::Init {

quic/s2n-quic-core/src/crypto/tls/offload.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
use crate::{
44
application,
55
crypto::{
6-
tls::{self, NamedGroup, TlsSession},
6+
tls::{self, ConnectionInfo, NamedGroup, TlsSession},
77
CryptoSuite,
88
},
99
sync::spsc::{channel, Receiver, SendSlice, Sender},
@@ -76,9 +76,11 @@ where
7676
fn new_server_session<Params: s2n_codec::EncoderValue>(
7777
&mut self,
7878
transport_parameters: &Params,
79+
connection_info: ConnectionInfo,
7980
) -> Self::Session {
8081
OffloadSession::new(
81-
self.inner.new_server_session(transport_parameters),
82+
self.inner
83+
.new_server_session(transport_parameters, connection_info),
8284
&self.executor,
8385
self.exporter.clone(),
8486
self.channel_capacity,

quic/s2n-quic-core/src/crypto/tls/slow_tls.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,11 @@ impl<E: tls::Endpoint> tls::Endpoint for SlowEndpoint<E> {
2626
fn new_server_session<Params: s2n_codec::EncoderValue>(
2727
&mut self,
2828
transport_parameters: &Params,
29+
connection_info: tls::ConnectionInfo,
2930
) -> Self::Session {
30-
let inner_session = self.endpoint.new_server_session(transport_parameters);
31+
let inner_session = self
32+
.endpoint
33+
.new_server_session(transport_parameters, connection_info);
3134
SlowSession {
3235
defer: DEFER_COUNT,
3336
inner_session,

quic/s2n-quic-core/src/crypto/tls/testing.rs

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,16 @@ use crate::{
88
crypto::{
99
header_crypto::{LONG_HEADER_MASK, SHORT_HEADER_MASK},
1010
scatter, tls,
11-
tls::{ApplicationParameters, CipherSuite, NamedGroup, TlsExportError, TlsSession},
11+
tls::{
12+
ApplicationParameters, CipherSuite, ConnectionInfo, NamedGroup, TlsExportError,
13+
TlsSession,
14+
},
1215
CryptoSuite, HeaderKey, Key,
1316
},
14-
endpoint, transport,
17+
endpoint,
18+
inet::SocketAddressV4,
19+
path::{LocalAddress, RemoteAddress},
20+
transport,
1521
transport::parameters::{ClientTransportParameters, ServerTransportParameters},
1622
};
1723
use alloc::sync::Arc;
@@ -69,6 +75,7 @@ impl super::Endpoint for Endpoint {
6975
fn new_server_session<Params: EncoderValue>(
7076
&mut self,
7177
_transport_parameters: &Params,
78+
_connection_info: ConnectionInfo,
7279
) -> Self::Session {
7380
Session
7481
}
@@ -155,7 +162,7 @@ pub struct Pair<S: tls::Session, C: tls::Session> {
155162
pub server_name: ServerName,
156163
}
157164

158-
fn server_params() -> Bytes {
165+
pub fn server_params() -> Bytes {
159166
ServerTransportParameters {
160167
initial_max_data: 123.try_into().unwrap(),
161168
..Default::default()
@@ -185,7 +192,12 @@ impl<S: tls::Session, C: tls::Session> Pair<S, C> {
185192
{
186193
use crate::crypto::InitialKey;
187194

188-
let server = server_endpoint.new_server_session(&&server_params()[..]);
195+
// This testing pair doesn't use tls::ConnectionInfo. Hence, we can create random local/remote addresses to pass in new_server_session
196+
let local_address: LocalAddress = SocketAddressV4::new([127, 0, 0, 1], 443).into();
197+
let remote_address: RemoteAddress = SocketAddressV4::new([127, 0, 0, 1], 12345).into();
198+
let connection_info = tls::ConnectionInfo::new(local_address, remote_address);
199+
200+
let server = server_endpoint.new_server_session(&&server_params()[..], connection_info);
189201
let mut server_context =
190202
Context::new(endpoint::Type::Server, ServerState::WaitingClientHello);
191203
server_context.initial.crypto = Some(S::InitialKey::new_server(server_name.as_bytes()));

quic/s2n-quic-rustls/src/client.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ impl tls::Endpoint for Client {
7171
fn new_server_session<Params: EncoderValue>(
7272
&mut self,
7373
_transport_parameters: &Params,
74+
_connection_info: tls::ConnectionInfo,
7475
) -> Self::Session {
7576
panic!("cannot create a server session from a client config");
7677
}

quic/s2n-quic-rustls/src/server.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ impl tls::Endpoint for Server {
6969
fn new_server_session<Params: EncoderValue>(
7070
&mut self,
7171
transport_parameters: &Params,
72+
_connection_info: tls::ConnectionInfo,
7273
) -> Self::Session {
7374
//= https://www.rfc-editor.org/rfc/rfc9001#section-8.2
7475
//# Endpoints MUST send the quic_transport_parameters extension;

quic/s2n-quic-tests/Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,5 +31,9 @@ zerocopy = { version = "0.8", features = ["derive"] }
3131
[target.'cfg(not(target_arch = "x86"))'.dependencies]
3232
quiche = "0.24"
3333

34+
# s2n-tls is required by ch_callback_server_local_address_test and doesn't build on Windows
35+
[target.'cfg(not(target_os = "windows"))'.dependencies]
36+
s2n-tls = "0.3.31"
37+
3438
[target.'cfg(unix)'.dependencies]
3539
s2n-quic = { path = "../s2n-quic", features = ["provider-event-tracing", "provider-tls-s2n", "unstable-provider-io-testing", "unstable-provider-dc", "unstable-provider-packet-interceptor", "unstable-provider-random", "unstable-offload-tls", "unstable_client_hello"] }

quic/s2n-quic-tests/src/tests.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use s2n_quic::{
1111
self as io, network::Packet, primary, rand, spawn, test, time::delay, Model,
1212
},
1313
packet_interceptor::Loss,
14+
tls,
1415
},
1516
Client, Server,
1617
};
@@ -53,6 +54,9 @@ mod tls_context;
5354
#[cfg(not(target_arch = "x86"))]
5455
mod zero_length_cid_client_connection_migration;
5556

57+
// The ClientHelloCallback trait is only available with s2n-tls
58+
#[cfg(not(target_os = "windows"))]
59+
mod ch_callback_connection_info;
5660
// TODO: https://github.com/aws/s2n-quic/issues/1726
5761
//
5862
// The rustls tls provider is used on windows and has different
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
use super::*;
5+
use s2n_quic_core::crypto::tls::ConnectionInfo;
6+
use s2n_tls::{
7+
callbacks::{ClientHelloCallback, ConnectionFuture},
8+
error::Error as S2nError,
9+
};
10+
use std::{
11+
pin::Pin,
12+
sync::{Arc, Mutex},
13+
};
14+
15+
struct TestClientHelloHandle {
16+
// The ClientHelloCallback trait requires `&self` as a immutable reference.
17+
// We use Arc<Mutex<>> to enable interior mutability - allowing us to mutate the recorded
18+
// ConnectionInfo through an immutable reference.
19+
recorded_info: Arc<Mutex<Option<ConnectionInfo>>>,
20+
}
21+
22+
impl TestClientHelloHandle {
23+
pub fn new(recorded_info: Arc<Mutex<Option<ConnectionInfo>>>) -> Self {
24+
Self { recorded_info }
25+
}
26+
}
27+
28+
impl ClientHelloCallback for TestClientHelloHandle {
29+
fn on_client_hello(
30+
&self,
31+
connection: &mut s2n_tls::connection::Connection,
32+
) -> Result<Option<Pin<Box<dyn ConnectionFuture>>>, S2nError> {
33+
let connection_info = connection.application_context::<ConnectionInfo>();
34+
35+
assert!(connection_info.is_some());
36+
if let Some(info) = connection_info {
37+
*self.recorded_info.lock().unwrap() = Some(*info);
38+
}
39+
40+
Ok(None)
41+
}
42+
}
43+
44+
/// Tests that ConnectionInfo is accessible in the client hello callback and contains
45+
/// the correct local (server) and remote (client) socket addresses.
46+
///
47+
/// This test:
48+
/// 1. Creates a server with a client hello callback that records ConnectionInfo
49+
/// 2. Records the actual server and client socket addresses during connection setup
50+
/// 3. Verifies that the ConnectionInfo captured in the callback matches the expected addresses
51+
///
52+
/// Note: Uses interior mutability (Arc<Mutex<>>) to store data from the callback since
53+
/// ClientHelloCallback requires an immutable reference (&self).
54+
#[test]
55+
#[cfg_attr(miri, ignore)]
56+
fn ch_callback_connection_info_test() {
57+
let model = Model::default();
58+
59+
let ch_callback_handle_inner = Arc::new(Mutex::new(None));
60+
let ch_callback_handle_inner_clone = ch_callback_handle_inner.clone();
61+
62+
let mut server_local_address = None;
63+
let mut server_remote_address = None;
64+
65+
test(model.clone(), |handle| {
66+
let server_tls = tls::s2n_tls::Server::builder()
67+
.with_certificate(certificates::CERT_PEM, certificates::KEY_PEM)
68+
.unwrap()
69+
.with_client_hello_handler(TestClientHelloHandle::new(ch_callback_handle_inner_clone))
70+
.unwrap()
71+
.build()
72+
.unwrap();
73+
74+
let server = Server::builder()
75+
.with_io(handle.builder().build()?)?
76+
.with_tls(server_tls)?
77+
.with_event(tracing_events(true, model.clone()))?
78+
.with_random(Random::with_seed(456))?
79+
.start()?;
80+
81+
let server_addr = start_server(server)?;
82+
server_local_address = Some(server_addr);
83+
84+
let client = build_client(handle, model.clone(), true)?;
85+
server_remote_address = Some(client.local_addr().unwrap());
86+
87+
start_client(client, server_addr, Data::new(1000))?;
88+
89+
Ok(server_addr)
90+
})
91+
.unwrap();
92+
93+
let connection_info = ch_callback_handle_inner.lock().unwrap().unwrap();
94+
95+
// Verify that the ConnectionInfo contains the exact server local address
96+
assert_eq!(connection_info.local_address, server_local_address.unwrap());
97+
98+
// Verify that the ConnectionInfo contains the exact server's remote address (client's local address)
99+
assert_eq!(
100+
connection_info.remote_address,
101+
server_remote_address.unwrap()
102+
);
103+
}

0 commit comments

Comments
 (0)