From 102f8a6cce65f0a100ec12c516d7840b176c7e4e Mon Sep 17 00:00:00 2001 From: 9seconds Date: Tue, 7 Apr 2026 13:41:44 +0200 Subject: [PATCH] Propagate keep alive settings from the config --- example.config.toml | 12 +++++++++++- internal/cli/doctor.go | 6 ++++++ internal/cli/run_proxy.go | 6 ++++++ internal/config/config.go | 6 ++++++ internal/config/parse.go | 6 ++++++ network/v2/base_http_test.go | 2 +- network/v2/base_network_test.go | 2 +- network/v2/init.go | 24 ++++++++++++++++++++++-- network/v2/network.go | 17 ++++++++++------- network/v2/sockopts.go | 9 ++------- network/v2/sockopts_test.go | 8 ++++---- network/v2/socks_proxy_test.go | 2 +- 12 files changed, 76 insertions(+), 24 deletions(-) diff --git a/example.config.toml b/example.config.toml index 73f568374..ece3104dd 100644 --- a/example.config.toml +++ b/example.config.toml @@ -204,13 +204,23 @@ proxies = [ # define a global timeout on establishing of network connections. idle # means a timeout on pumping data between sockset when nothing is # happening. -# [network.timeout] tcp = "5s" http = "10s" idle = "5m" handshake = "10s" +# this defines a configuration for TCP keep alives. Default values are taken +# from Golang default behavior. +[network.keep-alive] +disabled = false +# idle means a time period after which we start sending TCP Keep Alive probes +idle = "15s" +# interval is a period between 2 consecutive probes +interval = "15s" +# if we miss that many probes, a connection will be considered as a dead one. +count = 9 + # mtg has to mimic real websites. It does not mean domain fronting, it also # means that traffic characteristics should be similar to real world traffic. # websites and applications behave differently, their traffic patterns are also diff --git a/internal/cli/doctor.go b/internal/cli/doctor.go index b00c3e08a..48563bc1e 100644 --- a/internal/cli/doctor.go +++ b/internal/cli/doctor.go @@ -97,6 +97,12 @@ func (d *Doctor) Run(cli *CLI, version string) error { conf.Network.Timeout.TCP.Get(10*time.Second), conf.Network.Timeout.HTTP.Get(0), conf.Network.Timeout.Idle.Get(0), + net.KeepAliveConfig{ + Enable: !conf.Network.KeepAlive.Disabled.Get(false), + Idle: conf.Network.KeepAlive.Idle.Get(0), + Interval: conf.Network.KeepAlive.Interval.Get(0), + Count: int(conf.Network.KeepAlive.Count.Get(0)), + }, ) fmt.Println("Validate native network connectivity") diff --git a/internal/cli/run_proxy.go b/internal/cli/run_proxy.go index de8b20616..5f0b3dee6 100644 --- a/internal/cli/run_proxy.go +++ b/internal/cli/run_proxy.go @@ -50,6 +50,12 @@ func makeNetwork(conf *config.Config, version string) (mtglib.Network, error) { conf.Network.Timeout.TCP.Get(0), conf.Network.Timeout.HTTP.Get(0), conf.Network.Timeout.Idle.Get(0), + net.KeepAliveConfig{ + Enable: !conf.Network.KeepAlive.Disabled.Get(false), + Idle: conf.Network.KeepAlive.Idle.Get(0), + Interval: conf.Network.KeepAlive.Interval.Get(0), + Count: int(conf.Network.KeepAlive.Count.Get(0)), + }, ) proxyDialers := make([]mtglib.Network, len(conf.Network.Proxies)) diff --git a/internal/config/config.go b/internal/config/config.go index 2fb79eec3..70e233f17 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -65,6 +65,12 @@ type Config struct { Idle TypeDuration `json:"idle"` Handshake TypeDuration `json:"handshake"` } `json:"timeout"` + KeepAlive struct { + Disabled TypeBool `json:"disabled"` + Idle TypeDuration `json:"idle"` + Interval TypeDuration `json:"interval"` + Count TypeConcurrency `json:"count"` + } `json:"keepAlive"` DOHIP TypeIP `json:"dohIp"` DNS TypeDNSURI `json:"dns"` Proxies []TypeProxyURL `json:"proxies"` diff --git a/internal/config/parse.go b/internal/config/parse.go index 6136588bb..bdc76162d 100644 --- a/internal/config/parse.go +++ b/internal/config/parse.go @@ -60,6 +60,12 @@ type tomlConfig struct { Idle string `toml:"idle" json:"idle,omitempty"` Handshake string `toml:"handshake" json:"handshake,omitempty"` } `toml:"timeout" json:"timeout,omitempty"` + KeepAlive struct { + Disabled bool `toml:"disabled" json:"disabled,omitempty"` + Idle string `toml:"idle" json:"idle,omitempty"` + Interval string `toml:"interval" json:"interval,omitempty"` + Count uint `toml:"count" json:"count,omitempty"` + } `toml:"keep-alive" json:"keepAlive,omitempty"` DOHIP string `toml:"doh-ip" json:"dohIp,omitempty"` DNS string `toml:"dns" json:"dns,omitempty"` Proxies []string `toml:"proxies" json:"proxies,omitempty"` diff --git a/network/v2/base_http_test.go b/network/v2/base_http_test.go index 904778296..926db428b 100644 --- a/network/v2/base_http_test.go +++ b/network/v2/base_http_test.go @@ -25,7 +25,7 @@ func (suite *BaseHTTPTestSuite) SetupSuite() { } func (suite *BaseHTTPTestSuite) SetupTest() { - suite.client = network.New(nil, "mtg/1", 0, 0, 0).MakeHTTPClient(nil) + suite.client = network.New(nil, "mtg/1", 0, 0, 0, network.DefaultKeepAliveConfig).MakeHTTPClient(nil) } func (suite *BaseHTTPTestSuite) TestGet() { diff --git a/network/v2/base_network_test.go b/network/v2/base_network_test.go index f2d7a1075..32fbe35df 100644 --- a/network/v2/base_network_test.go +++ b/network/v2/base_network_test.go @@ -19,7 +19,7 @@ type BaseNetworkTestSuite struct { func (suite *BaseNetworkTestSuite) SetupSuite() { suite.EchoServerTestSuite.SetupSuite() - suite.net = network.New(nil, "agent", 0, 0, 0) + suite.net = network.New(nil, "agent", 0, 0, 0, network.DefaultKeepAliveConfig) } func (suite *BaseNetworkTestSuite) TestDialUnknownNetwork() { diff --git a/network/v2/init.go b/network/v2/init.go index abcc064d4..0b807d479 100644 --- a/network/v2/init.go +++ b/network/v2/init.go @@ -11,6 +11,7 @@ package network import ( "errors" + "net" "time" ) @@ -27,19 +28,25 @@ const ( // DefaultTCPKeepAlivePeriod defines a time period between 2 consecuitive // probes. // - // Deprecated: use DefaultKeepAliveIdle and DefaultKeepAliveInterval instead. + // Deprecated: use DefaultKeepAliveConfig DefaultTCPKeepAlivePeriod = 10 * time.Second // DefaultKeepAliveIdle is the time a connection must be idle before // the first keepalive probe is sent. + // + // Deprecated: use DefaultKeepAliveConfig DefaultKeepAliveIdle = 30 * time.Second // DefaultKeepAliveInterval is the time between consecutive keepalive // probes. + // + // Deprecated: use DefaultKeepAliveConfig DefaultKeepAliveInterval = 10 * time.Second // DefaultKeepAliveCount is the number of unacknowledged probes before // the connection is considered dead. + // + // Deprecated: use DefaultKeepAliveConfig DefaultKeepAliveCount = 3 // User Agent to use in HTTP client. @@ -50,4 +57,17 @@ const ( tcpLingerTimeout = 1 ) -var ErrCannotDial = errors.New("cannot dial to any address") +var ( + ErrCannotDial = errors.New("cannot dial to any address") + + // DefaultKeepAliveConfig defines a default configuration for + // keep alive settings. As per official documentation, if keep alive + // is enabled, then: + // + // Idle = 15 * time.Second + // Interval = 15 * time.Second + // Count = 9 + DefaultKeepAliveConfig = net.KeepAliveConfig{ + Enable: true, + } +) diff --git a/network/v2/network.go b/network/v2/network.go index 3f961edd5..0590b1479 100644 --- a/network/v2/network.go +++ b/network/v2/network.go @@ -14,9 +14,10 @@ import ( type network struct { net.Dialer - httpTimeout time.Duration - idleTimeout time.Duration - userAgent string + keepAliveConfig net.KeepAliveConfig + httpTimeout time.Duration + idleTimeout time.Duration + userAgent string } func (n *network) Dial(network, address string) (essentials.Conn, error) { @@ -37,7 +38,7 @@ func (n *network) DialContext(ctx context.Context, network, address string) (ess tcpConn := conn.(*net.TCPConn) - return tcpConn, setCommonSocketOptions(tcpConn) + return tcpConn, setCommonSocketOptions(tcpConn, n.keepAliveConfig) } func (n *network) MakeHTTPClient( @@ -71,6 +72,7 @@ func New( tcpTimeout, httpTimeout, idleTimeout time.Duration, + keepAliveConfig net.KeepAliveConfig, ) mtglib.Network { if dnsResolver == nil { dnsResolver = net.DefaultResolver @@ -86,8 +88,9 @@ func New( Resolver: dnsResolver, FallbackDelay: -1, }, - userAgent: userAgent, - idleTimeout: idleTimeout, - httpTimeout: httpTimeout, + userAgent: userAgent, + idleTimeout: idleTimeout, + httpTimeout: httpTimeout, + keepAliveConfig: keepAliveConfig, } } diff --git a/network/v2/sockopts.go b/network/v2/sockopts.go index e89b1eaa7..cf3490c40 100644 --- a/network/v2/sockopts.go +++ b/network/v2/sockopts.go @@ -5,13 +5,8 @@ import ( "net" ) -func setCommonSocketOptions(conn *net.TCPConn) error { - if err := conn.SetKeepAliveConfig(net.KeepAliveConfig{ - Enable: true, - Idle: DefaultKeepAliveIdle, - Interval: DefaultKeepAliveInterval, - Count: DefaultKeepAliveCount, - }); err != nil { +func setCommonSocketOptions(conn *net.TCPConn, keepAliveConfig net.KeepAliveConfig) error { + if err := conn.SetKeepAliveConfig(keepAliveConfig); err != nil { return fmt.Errorf("cannot configure TCP keepalive: %w", err) } diff --git a/network/v2/sockopts_test.go b/network/v2/sockopts_test.go index 94226c1aa..939d3c2f2 100644 --- a/network/v2/sockopts_test.go +++ b/network/v2/sockopts_test.go @@ -65,7 +65,7 @@ func TestSetCommonSocketOptionsKeepAlive(t *testing.T) { tcpConn := accepted.(*net.TCPConn) - err = setCommonSocketOptions(tcpConn) + err = setCommonSocketOptions(tcpConn, DefaultKeepAliveConfig) require.NoError(t, err) rawConn, err := tcpConn.SyscallConn() @@ -78,15 +78,15 @@ func TestSetCommonSocketOptionsKeepAlive(t *testing.T) { idle, err := unix.GetsockoptInt(int(fd), syscall.IPPROTO_TCP, tcpKeepIdleOption()) require.NoError(t, err) - require.Equal(t, int(DefaultKeepAliveIdle.Seconds()), idle, "keepalive idle should match DefaultKeepAliveIdle") + require.Equal(t, 15, idle, "keepalive idle should match DefaultKeepAliveIdle") interval, err := unix.GetsockoptInt(int(fd), syscall.IPPROTO_TCP, unix.TCP_KEEPINTVL) require.NoError(t, err) - require.Equal(t, int(DefaultKeepAliveInterval.Seconds()), interval, "keepalive interval should match DefaultKeepAliveInterval") + require.Equal(t, 15, interval, "keepalive interval should match DefaultKeepAliveInterval") count, err := unix.GetsockoptInt(int(fd), syscall.IPPROTO_TCP, unix.TCP_KEEPCNT) require.NoError(t, err) - require.Equal(t, DefaultKeepAliveCount, count, "keepalive count should match DefaultKeepAliveCount") + require.Equal(t, 9, count, "keepalive count should match DefaultKeepAliveCount") }) require.NoError(t, err) } diff --git a/network/v2/socks_proxy_test.go b/network/v2/socks_proxy_test.go index d6d41ae35..6d74b6dd0 100644 --- a/network/v2/socks_proxy_test.go +++ b/network/v2/socks_proxy_test.go @@ -66,7 +66,7 @@ func (suite *SocksProxyTestSuite) SetupSuite() { require.NoError(suite.T(), err) suite.authURL = parsed - suite.baseNetwork = network.New(nil, "mtg", 0, 0, 0) + suite.baseNetwork = network.New(nil, "mtg", 0, 0, 0, network.DefaultKeepAliveConfig) } func (suite *SocksProxyTestSuite) TestIncorrectSchema() {