diff --git a/src/linux/init/GnsEngine.cpp b/src/linux/init/GnsEngine.cpp index e0af19028..ce9e0ada7 100644 --- a/src/linux/init/GnsEngine.cpp +++ b/src/linux/init/GnsEngine.cpp @@ -25,12 +25,13 @@ constexpr auto c_ipStrings = {"ip", "ip6"}; const char* c_loopbackInterfaceName = "lo"; GnsEngine::GnsEngine( + wsl::shared::SocketChannel& channel, const NotificationRoutine& notificationRoutine, const StatusRoutine& statusRoutine, NetworkManager& manager, std::optional dnsTunnelingFd, const std::string& dnsTunnelingIpAddress) : - notificationRoutine(notificationRoutine), statusRoutine(statusRoutine), manager(manager) + channel(channel), notificationRoutine(notificationRoutine), statusRoutine(statusRoutine), manager(manager) { if (dnsTunnelingFd.has_value()) { @@ -363,11 +364,11 @@ void GnsEngine::ProcessLinkChange(Interface& interface, const wsl::shared::hns:: } } -std::tuple GnsEngine::ProcessNextMessage() +std::tuple GnsEngine::ProcessNextMessage(wsl::shared::Transaction& transaction) { int return_value = 0; - auto payload = notificationRoutine(); + auto payload = notificationRoutine(transaction); if (!payload.has_value()) { GNS_LOG_ERROR("Received empty message, exiting"); @@ -723,22 +724,23 @@ void GnsEngine::run() while (true) { + auto transaction = channel.ReceiveTransaction(); try { GNS_LOG_INFO("Processing Next Message"); - auto [should_continue, return_value] = ProcessNextMessage(); + auto [should_continue, return_value] = ProcessNextMessage(transaction); if (!should_continue) { break; } GNS_LOG_INFO("Processing Next Message Successful ({:#x})", return_value); - statusRoutine(return_value, ""); + statusRoutine(return_value, "", transaction); } catch (const std::exception& e) { GNS_LOG_ERROR("Error while processing message: {}", e.what()); - statusRoutine(-1, e.what()); + statusRoutine(-1, e.what(), transaction); } } diff --git a/src/linux/init/GnsEngine.h b/src/linux/init/GnsEngine.h index c6ea91576..39119ead9 100644 --- a/src/linux/init/GnsEngine.h +++ b/src/linux/init/GnsEngine.h @@ -21,10 +21,11 @@ class GnsEngine std::optional AdapterId; }; - using NotificationRoutine = std::function()>; - using StatusRoutine = std::function; + using NotificationRoutine = std::function(wsl::shared::Transaction&)>; + using StatusRoutine = std::function; GnsEngine( + wsl::shared::SocketChannel& channel, const NotificationRoutine& notificationRoutine, const StatusRoutine& statusRoutine, NetworkManager& manager, @@ -36,7 +37,7 @@ class GnsEngine void run(); private: - std::tuple ProcessNextMessage(); + std::tuple ProcessNextMessage(wsl::shared::Transaction& transaction); void ProcessNotification(const nlohmann::json& payload, Interface& interface); @@ -69,6 +70,7 @@ class GnsEngine void ProcessNotificationImpl( Interface& interface, const nlohmann::json& payload, void (GnsEngine::*routine)(Interface&, const T&, wsl::shared::hns::ModifyRequestType)); + wsl::shared::SocketChannel& channel; const NotificationRoutine& notificationRoutine; const StatusRoutine& statusRoutine; NetworkManager& manager; diff --git a/src/linux/init/binfmt.cpp b/src/linux/init/binfmt.cpp index 0bc10e289..496a2d302 100644 --- a/src/linux/init/binfmt.cpp +++ b/src/linux/init/binfmt.cpp @@ -174,7 +174,8 @@ try // Send the create process message to the interop server. // - channel.SendMessage(Span); + auto transaction = channel.StartTransaction(); + transaction.Send(Span); // // Accept connections from the interop server. diff --git a/src/linux/init/config.cpp b/src/linux/init/config.cpp index b4a554b1c..6162ee49a 100644 --- a/src/linux/init/config.cpp +++ b/src/linux/init/config.cpp @@ -334,7 +334,7 @@ try CATCH_LOG() void ConfigHandleInteropMessage( - wsl::shared::SocketChannel& ResponseChannel, + wsl::shared::Transaction& Transaction, wsl::shared::SocketChannel& InteropChannel, bool Elevated, gsl::span Message, @@ -350,7 +350,7 @@ Routine Description: Arguments: - ResponseChannel - Supplies channel used to send responses. + Transaction - Supplies transaction used to send responses. InteropChannel - Supplies a channel to the host to be used for create process requests. @@ -381,7 +381,7 @@ try case LxInitMessageQueryDrvfsElevated: { - ResponseChannel.SendResultMessage(Elevated); + Transaction.SendResultMessage(Elevated); break; } @@ -397,7 +397,7 @@ try auto Value = UtilGetEnvironmentVariable(Query->Buffer); wsl::shared::MessageWriter Response(LxInitMessageQueryEnvironmentVariable); Response.WriteString(Value); - ResponseChannel.SendMessage(Response.Span()); + Transaction.Send(Response.Span()); } break; @@ -405,7 +405,7 @@ try case LxInitMessageQueryFeatureFlags: { assert(Config.FeatureFlags.has_value()); - ResponseChannel.SendResultMessage(Config.FeatureFlags.value()); + Transaction.SendResultMessage(Config.FeatureFlags.value()); break; } @@ -419,7 +419,7 @@ try } bool success = false; - auto sendResponse = wil::scope_exit([&]() { ResponseChannel.SendResultMessage(success); }); + auto sendResponse = wil::scope_exit([&]() { Transaction.SendResultMessage(success); }); if (!Config.BootInit || Config.InitPid.value_or(0) != getpid()) { @@ -435,7 +435,7 @@ try case LxInitMessageQueryNetworkingMode: assert(Config.NetworkingMode.has_value()); - ResponseChannel.SendResultMessage(static_cast(Config.NetworkingMode.value())); + Transaction.SendResultMessage(static_cast(Config.NetworkingMode.value())); break; case LxInitMessageQueryVmId: @@ -446,7 +446,7 @@ try Response.WriteString(Config.VmId.value()); } - ResponseChannel.SendMessage(Response.Span()); + Transaction.Send(Response.Span()); break; } @@ -618,7 +618,7 @@ try } CATCH_LOG() -int ConfigInitializeInstance(wsl::shared::SocketChannel& Channel, gsl::span Buffer, wsl::linux::WslDistributionConfig& Config) +int ConfigInitializeInstance(const std::function&)>& SendResponse, gsl::span Buffer, wsl::linux::WslDistributionConfig& Config) /*++ @@ -632,7 +632,7 @@ Routine Description: Arguments: - MessageFd - Supplies a file descriptor to send the response message. + SendResponse - Supplies a function to send the response message. Buffer - Supplies the message buffer. @@ -923,7 +923,7 @@ try Response.WriteString(Response->VersionIndex, Version->c_str()); } - Channel.SendMessage(Response.Span()); + SendResponse(Response.Span()); // // Accept the interop connection. @@ -973,13 +973,14 @@ try continue; } - auto [Message, Span] = ClientChannel.ReceiveMessageOrClosed(); + auto transaction = ClientChannel.ReceiveTransaction(); + auto [Message, Span] = transaction.ReceiveOrClosed(); if (Message == nullptr) { continue; } - ConfigHandleInteropMessage(ClientChannel, InteropChannel, Elevated, Span, Message, Config); + ConfigHandleInteropMessage(transaction, InteropChannel, Elevated, Span, Message, Config); } }); @@ -2186,7 +2187,7 @@ Return Value: return Result; } -int ConfigRemountDrvFs(gsl::span Buffer, wsl::shared::SocketChannel& Channel, const wsl::linux::WslDistributionConfig& Config) +int ConfigRemountDrvFs(gsl::span Buffer, wsl::shared::Transaction& Transaction, const wsl::linux::WslDistributionConfig& Config) /*++ @@ -2207,7 +2208,7 @@ Return Value: --*/ { - Channel.SendResultMessage(ConfigRemountDrvFsImpl(Buffer, Config)); + Transaction.SendResultMessage(ConfigRemountDrvFsImpl(Buffer, Config)); return 0; } diff --git a/src/linux/init/config.h b/src/linux/init/config.h index bec5d71c2..771bda245 100644 --- a/src/linux/init/config.h +++ b/src/linux/init/config.h @@ -20,6 +20,7 @@ Module Name: #include #include #include +#include #include "SocketChannel.h" #include "WslDistributionConfig.h" @@ -399,7 +400,7 @@ std::set> ConfigGetMountedDrvFsVolumes(void std::vector> ConfigGetWslgEnvironmentVariables(const wsl::linux::WslDistributionConfig& Config); void ConfigHandleInteropMessage( - wsl::shared::SocketChannel& ResponseChannel, + wsl::shared::Transaction& Transaction, wsl::shared::SocketChannel& InteropChannel, bool Elevated, gsl::span Message, @@ -408,7 +409,7 @@ void ConfigHandleInteropMessage( void ConfigInitializeCgroups(wsl::linux::WslDistributionConfig& Config); -int ConfigInitializeInstance(wsl::shared::SocketChannel& Channel, gsl::span Buffer, wsl::linux::WslDistributionConfig& Config); +int ConfigInitializeInstance(const std::function&)>& SendResponse, gsl::span Buffer, wsl::linux::WslDistributionConfig& Config); void ConfigMountDrvFsVolumes(unsigned int DrvFsVolumes, uid_t OwnerUid, std::optional Admin, const wsl::linux::WslDistributionConfig& Config); @@ -420,7 +421,7 @@ int ConfigRegisterBinfmtInterpreter(void); int ConfigSetMountNamespace(bool Elevated); -int ConfigRemountDrvFs(gsl::span Buffer, wsl::shared::SocketChannel& Channel, const wsl::linux::WslDistributionConfig& Config); +int ConfigRemountDrvFs(gsl::span Buffer, wsl::shared::Transaction& Transaction, const wsl::linux::WslDistributionConfig& Config); int ConfigRemountDrvFsImpl(gsl::span Buffer, const wsl::linux::WslDistributionConfig& Config); diff --git a/src/linux/init/drvfs.cpp b/src/linux/init/drvfs.cpp index c36ae7e81..128157e57 100644 --- a/src/linux/init/drvfs.cpp +++ b/src/linux/init/drvfs.cpp @@ -209,8 +209,9 @@ Return Value: QueryPortMessage.MessageType = LxInitMessageQueryDrvfsElevated; QueryPortMessage.MessageSize = sizeof(QueryPortMessage); - channel.SendMessage(QueryPortMessage); - return channel.ReceiveMessage>().Result; + auto transaction = channel.StartTransaction(); + transaction.Send(QueryPortMessage); + return transaction.Receive>().Result; } int MountFilesystem(const char* FsType, const char* Source, const char* Target, const char* Options, int* ExitCode) diff --git a/src/linux/init/init.cpp b/src/linux/init/init.cpp index 479b0992f..2cd22fc0d 100644 --- a/src/linux/init/init.cpp +++ b/src/linux/init/init.cpp @@ -116,10 +116,15 @@ int InitConnectToServer(int LxBusFd, bool WaitForServer); int InitCreateProcessUtilityVm( gsl::span Message, const LX_INIT_CREATE_PROCESS_UTILITY_VM& Header, - wsl::shared::SocketChannel& MessageFd, + wsl::shared::Transaction& Transaction, const wsl::linux::WslDistributionConfig& Config); -int InitCreateSessionLeader(gsl::span Buffer, wsl::shared::SocketChannel& Channel, int LxBusFd, wsl::linux::WslDistributionConfig& Config); +int InitCreateSessionLeader( + gsl::span Buffer, + wsl::shared::SocketChannel& Channel, + const std::function& SendResponse, + int LxBusFd, + wsl::linux::WslDistributionConfig& Config); void InitEntry(int Argc, char* Argv[]); @@ -127,7 +132,7 @@ void InitEntryWsl(wsl::linux::WslDistributionConfig& Config); void InitEntryUtilityVm(wsl::linux::WslDistributionConfig& Config); -void InitTerminateInstance(gsl::span Buffer, wsl::shared::SocketChannel& Channel, wsl::linux::WslDistributionConfig& Config); +void InitTerminateInstance(gsl::span Buffer, const std::function& SendResult, wsl::linux::WslDistributionConfig& Config); void InitTerminateInstanceInternal(const wsl::linux::WslDistributionConfig& Config); @@ -1111,7 +1116,12 @@ Return Value: return 0; } -int InitCreateSessionLeader(gsl::span Buffer, wsl::shared::SocketChannel& Channel, int LxBusFd, wsl::linux::WslDistributionConfig& Config) +int InitCreateSessionLeader( + gsl::span Buffer, + wsl::shared::SocketChannel& Channel, + const std::function& SendResponse, + int LxBusFd, + wsl::linux::WslDistributionConfig& Config) /*++ @@ -1228,7 +1238,7 @@ try Response.Header.MessageType = LxInitMessageCreateSessionResponse; Response.Header.MessageSize = sizeof(Response); Response.Port = SocketAddress.svm_port; - Channel.SendMessage(Response); + SendResponse(Response); if (!ListenSocket) { @@ -1329,7 +1339,7 @@ Return Value: int InitCreateProcessUtilityVm( gsl::span Span, const LX_INIT_CREATE_PROCESS_UTILITY_VM& CreateProcess, - wsl::shared::SocketChannel& Channel, + wsl::shared::Transaction& Transaction, const wsl::linux::WslDistributionConfig& Config) /*++ @@ -1414,7 +1424,7 @@ Return Value: // Tell the service which sockets ports to connect to. // - Channel.SendResultMessage(SocketAddress.svm_port); + Transaction.SendResultMessage(SocketAddress.svm_port); // // Exit if creating the listening socket failed. @@ -1978,13 +1988,14 @@ Return Value: continue; } - auto [Header, Span] = channel.ReceiveMessageOrClosed(); + auto transaction = channel.ReceiveTransaction(); + auto [Header, Span] = transaction.ReceiveOrClosed(); if (Header != nullptr) { try { ConfigHandleInteropMessage( - channel, ControlChannel, WI_IsFlagSet(CreateProcess.Common.Flags, LxInitCreateProcessFlagsElevated), Span, Header, Config); + transaction, ControlChannel, WI_IsFlagSet(CreateProcess.Common.Flags, LxInitCreateProcessFlagsElevated), Span, Header, Config); } CATCH_LOG(); } @@ -2445,7 +2456,8 @@ Return Value: } else if (PollDescriptors[0].revents & POLLIN) { - auto [Header, Span] = channel.ReceiveMessageOrClosed(); + auto transaction = channel.ReceiveTransaction(); + auto [Header, Span] = transaction.ReceiveOrClosed(); if (Header == nullptr) { break; @@ -2454,16 +2466,23 @@ Return Value: switch (Header->MessageType) { case LxInitMessageCreateSession: - if (InitCreateSessionLeader(Span, channel, -1, Config) < 0) + { + auto SendResponse = [&](LX_INIT_CREATE_SESSION_RESPONSE& response) { transaction.Send(response); }; + if (InitCreateSessionLeader(Span, channel, SendResponse, -1, Config) < 0) { FATAL_ERROR("InitCreateSessionLeader failed"); } - - break; + } + break; case LxInitMessageInitialize: - ConfigInitializeInstance(channel, Span, Config); - break; + { + auto SendResponse = [&](const gsl::span& span) { + transaction.Send(span); + }; + ConfigInitializeInstance(SendResponse, Span, Config); + } + break; case LxInitMessageTimezoneInformation: UpdateTimezone(Span, Config); @@ -2479,15 +2498,18 @@ Return Value: // WaitForBootProcess(Config); - ConfigRemountDrvFs(Span, channel, Config); + ConfigRemountDrvFs(Span, transaction, Config); break; case LxInitMessageTerminateInstance: - InitTerminateInstance(Span, channel, Config); - break; + { + auto SendResult = [&](bool result) { transaction.SendResultMessage(result); }; + InitTerminateInstance(Span, SendResult, Config); + } + break; case LxInitCreateProcess: - ProcessCreateProcessMessage(channel, Span); + ProcessCreateProcessMessage(transaction, Span); break; default: @@ -2602,7 +2624,9 @@ Return Value: switch (Header->MessageType) { case LxInitMessageCreateSession: - if (InitCreateSessionLeader(Message, Channel, LxBusFd.get(), Config) < 0) + { + auto SendResponse = [&](LX_INIT_CREATE_SESSION_RESPONSE& response) { Channel.SendMessage(response); }; + if (InitCreateSessionLeader(Message, Channel, SendResponse, LxBusFd.get(), Config) < 0) { // // If this distro has no children, exit on failure. @@ -2616,24 +2640,32 @@ Return Value: LOG_ERROR("InitCreateSessionLeader failed"); } - - break; + } + break; case LxInitMessageNetworkInformation: ConfigUpdateNetworkInformation(Message, Config); break; case LxInitMessageInitialize: - ConfigInitializeInstance(Channel, Message, Config); - break; + { + auto SendResponse = [&](const gsl::span& span) { + Channel.SendMessage(span); + }; + ConfigInitializeInstance(SendResponse, Message, Config); + } + break; case LxInitMessageTimezoneInformation: UpdateTimezone(Message, Config); break; case LxInitMessageTerminateInstance: - InitTerminateInstance(Message, Channel, Config); - break; + { + auto SendResult = [&](bool result) { Channel.SendResultMessage(result); }; + InitTerminateInstance(Message, SendResult, Config); + } + break; default: FATAL_ERROR("Unexpected message {}", Header->MessageType); @@ -2643,7 +2675,7 @@ Return Value: return; } -void InitTerminateInstance(gsl::span Buffer, wsl::shared::SocketChannel& Channel, wsl::linux::WslDistributionConfig& Config) +void InitTerminateInstance(gsl::span Buffer, const std::function& SendResult, wsl::linux::WslDistributionConfig& Config) /*++ @@ -2655,7 +2687,7 @@ Routine Description: Buffer - Supplies the message buffer. - Channel - Supplies a channel to send the response. + SendResult - Supplies a function to send the response. Config - Supplies the distribution config. @@ -2680,7 +2712,7 @@ try if (!StopPlan9Server(Message->Force, Config)) { - Channel.SendResultMessage(false); + SendResult(false); return; } @@ -3026,7 +3058,8 @@ Return Value: for (;;) { - auto [Message, Span] = channel.ReceiveMessageOrClosed(); + auto transaction = channel.ReceiveTransaction(); + auto [Message, Span] = transaction.ReceiveOrClosed(); if (Message == nullptr) { _exit(0); @@ -3035,7 +3068,7 @@ Return Value: switch (Message->Header.MessageType) { case LxInitMessageCreateProcessUtilityVm: - if (InitCreateProcessUtilityVm(Span, *Message, channel, Config) < 0) + if (InitCreateProcessUtilityVm(Span, *Message, transaction, Config) < 0) { FATAL_ERROR("InitCreateProcessUtilityVm failed"); } @@ -3284,7 +3317,7 @@ unsigned int StartGns(int Argc, char** Argv) if (channel.Socket() == -1) { - readNotification = [&]() -> std::optional { + readNotification = [&](wsl::shared::Transaction&) -> std::optional { std::string content{std::istreambuf_iterator(std::cin), std::istreambuf_iterator()}; if (content.empty()) { @@ -3298,7 +3331,7 @@ unsigned int StartGns(int Argc, char** Argv) return {{AdapterId.has_value() ? LxGnsMessageNotification : LxGnsMessageInterfaceConfiguration, content, AdapterId}}; }; - returnStatus = [&](int Result, const std::string& Error) { + returnStatus = [&](int Result, const std::string& Error, wsl::shared::Transaction&) { GNS_LOG_INFO("Returning LxGnsMessageResult (no output fd) [{} - {}]", Result, Error.c_str()); // exitCode keeps the most recent error in the test path if (Result != 0) @@ -3310,9 +3343,9 @@ unsigned int StartGns(int Argc, char** Argv) } else { - readNotification = [&]() -> std::optional { + readNotification = [&](wsl::shared::Transaction& transaction) -> std::optional { std::vector Buffer; - auto [Message, Span] = channel.ReceiveMessageOrClosed(); + auto [Message, Span] = transaction.ReceiveOrClosed(); if (Message == nullptr) { return {}; @@ -3375,7 +3408,7 @@ unsigned int StartGns(int Argc, char** Argv) } }; - returnStatus = [&](int Result, const std::string& Error) { + returnStatus = [&](int Result, const std::string& Error, wsl::shared::Transaction& transaction) { std::vector Buffer(sizeof(LX_GNS_RESULT) + Error.size() + 1); GNS_LOG_INFO("Returning LxGnsMessageResult [{} - {}]", Result, Error.c_str()); @@ -3387,13 +3420,13 @@ unsigned int StartGns(int Argc, char** Argv) response.WriteString(Error); } - return channel.SendMessage(response.Span()); + return transaction.Send(response.Span()); }; } RoutingTable routingTable(RT_TABLE_MAIN); NetworkManager manager(routingTable); - GnsEngine engine(readNotification, returnStatus, manager, DnsFd, DnsTunnelingIp); + GnsEngine engine(channel, readNotification, returnStatus, manager, DnsFd, DnsTunnelingIp); engine.run(); diff --git a/src/linux/init/localhost.cpp b/src/linux/init/localhost.cpp index f1af32182..baca3278e 100644 --- a/src/linux/init/localhost.cpp +++ b/src/linux/init/localhost.cpp @@ -246,7 +246,8 @@ try { auto message = SockToRelayMessage(sock); message.Header.MessageType = LxGnsMessagePortListenerRelayStart; - channel.SendMessage(message); + auto transaction = channel.StartTransaction(); + transaction.Send(message); return 0; } @@ -257,7 +258,8 @@ try { auto message = SockToRelayMessage(sock); message.Header.MessageType = LxGnsMessagePortListenerRelayStop; - channel.SendMessage(message); + auto transaction = channel.StartTransaction(); + transaction.Send(message); return 0; } diff --git a/src/linux/init/main.cpp b/src/linux/init/main.cpp index f31ebf732..c67077249 100644 --- a/src/linux/init/main.cpp +++ b/src/linux/init/main.cpp @@ -193,7 +193,7 @@ int MountInit(const char* Target); int MountPlan9(const char* Name, const char* Target, bool ReadOnly, std::optional BufferSize = {}); -int ProcessMessage(wsl::shared::SocketChannel& channel, LX_MESSAGE_TYPE Type, gsl::span Buffer, VmConfiguration& Config); +int ProcessMessage(wsl::shared::Transaction& Transaction, LX_MESSAGE_TYPE Type, gsl::span Buffer, VmConfiguration& Config); wil::unique_fd RegisterSeccompHook(); @@ -2808,7 +2808,7 @@ void ProcessImportExportMessage(gsl::span Buffer, wsl::shared::Socket } } -int ProcessMountFolderMessage(wsl::shared::SocketChannel& Channel, gsl::span Buffer) +int ProcessMountFolderMessage(wsl::shared::Transaction& Transaction, gsl::span Buffer) /*++ @@ -2844,7 +2844,7 @@ Return Value: } int Result = MountPlan9(Name, Target, Message->ReadOnly); - Channel.SendResultMessage(Result); + Transaction.SendResultMessage(Result); return 0; } @@ -3163,7 +3163,7 @@ try } CATCH_RETURN_ERRNO(); -int ProcessMessage(wsl::shared::SocketChannel& Channel, LX_MESSAGE_TYPE Type, gsl::span Buffer, VmConfiguration& Config) +int ProcessMessage(wsl::shared::Transaction& Transaction, LX_MESSAGE_TYPE Type, gsl::span Buffer, VmConfiguration& Config) /*++ @@ -3173,9 +3173,7 @@ Routine Description: Arguments: - MessageFd - Supplies a file descriptor to the socket on which the message was - received. This is used for operations that require responses, for example a - VHD eject request. + Transaction - Supplies the transaction for replying to the message. Buffer - Supplies the message. @@ -3259,7 +3257,7 @@ try return -1; } - Channel.SendResultMessage(EjectScsi(EjectMessage->Lun)); + Transaction.SendResultMessage(EjectScsi(EjectMessage->Lun)); return 0; } @@ -3495,10 +3493,10 @@ try return 0; case LxMiniInitMountFolder: - return ProcessMountFolderMessage(Channel, Buffer); + return ProcessMountFolderMessage(Transaction, Buffer); case LxInitCreateProcess: - return ProcessCreateProcessMessage(Channel, Buffer); + return ProcessCreateProcessMessage(Transaction, Buffer); case LxMiniInitMessageWaitForPmemDevice: { @@ -4175,13 +4173,14 @@ int main(int Argc, char* Argv[]) } else if (PollDescriptors[0].revents & POLLIN) { - auto [Message, Range] = channel.ReceiveMessageOrClosed(); + auto transaction = channel.ReceiveTransaction(); + auto [Message, Range] = transaction.ReceiveOrClosed(); if (Message == nullptr) { break; // Socket was closed, exit } - Result = ProcessMessage(channel, Message->MessageType, Range, Config); + Result = ProcessMessage(transaction, Message->MessageType, Range, Config); if (Result < 0) { goto ErrorExit; diff --git a/src/linux/init/plan9.cpp b/src/linux/init/plan9.cpp index 1cc5cf3c0..81df7490f 100644 --- a/src/linux/init/plan9.cpp +++ b/src/linux/init/plan9.cpp @@ -157,13 +157,14 @@ try std::vector Buffer; for (;;) { - auto [Message, _] = channel.ReceiveMessageOrClosed(); + auto transaction = channel.ReceiveTransaction(); + auto [Message, _] = transaction.ReceiveOrClosed(); if (Message == nullptr) { _exit(0); } - channel.SendResultMessage(StopPlan9Server(fileSystem, Message->Force)); + transaction.SendResultMessage(StopPlan9Server(fileSystem, Message->Force)); } } CATCH_LOG(); diff --git a/src/linux/init/util.cpp b/src/linux/init/util.cpp index 398b94c43..ebcda116f 100644 --- a/src/linux/init/util.cpp +++ b/src/linux/init/util.cpp @@ -1110,13 +1110,14 @@ try wsl::shared::MessageWriter Message(LxInitMessageQueryEnvironmentVariable); Message.WriteString(Name); - channel.SendMessage(Message.Span()); + auto transaction = channel.StartTransaction(); + transaction.Send(Message.Span()); // // Read a response, this will contain the environment variable value if it exists. // - Value = channel.ReceiveMessage().Buffer; + Value = transaction.Receive().Buffer; // // Set the environment variable for future queries. @@ -1195,8 +1196,9 @@ Return Value: Message.MessageType = LxInitMessageQueryFeatureFlags; Message.MessageSize = sizeof(Message); - channel.SendMessage(Message); - FeatureFlags = channel.ReceiveMessage>().Result; + auto transaction = channel.StartTransaction(); + transaction.Send(Message); + FeatureFlags = transaction.Receive>().Result; } UtilSetFeatureFlags(FeatureFlags, FeatureFlagEnv == nullptr); @@ -1264,9 +1266,10 @@ try Message.MessageType = LxInitMessageQueryNetworkingMode; Message.MessageSize = sizeof(Message); - channel.SendMessage(Message); + auto transaction = channel.StartTransaction(); + transaction.Send(Message); - const auto& response = channel.ReceiveMessage>(); + const auto& response = transaction.Receive>(); auto NetworkingMode = static_cast(response.Result); THROW_ERRNO_IF(EINVAL, NetworkingMode < LxMiniInitNetworkingModeNone || NetworkingMode > LxMiniInitNetworkingModeVirtioProxy); @@ -1358,9 +1361,10 @@ try THROW_LAST_ERROR_IF(channel.Socket() < 0); wsl::shared::MessageWriter Message(LxInitMessageQueryVmId); - channel.SendMessage(Message.Span()); + auto transaction = channel.StartTransaction(); + transaction.Send(Message.Span()); - return channel.ReceiveMessage().Buffer; + return transaction.Receive().Buffer; } catch (...) { @@ -3344,7 +3348,7 @@ Return Value: return 0; } -int ProcessCreateProcessMessage(wsl::shared::SocketChannel& channel, gsl::span Buffer) +int ProcessCreateProcessMessage(wsl::shared::Transaction& Transaction, gsl::span Buffer) { auto* Message = gslhelpers::try_get_struct(Buffer); if (!Message) @@ -3353,7 +3357,7 @@ int ProcessCreateProcessMessage(wsl::shared::SocketChannel& channel, gsl::span(Result); }; + auto sendResult = [&](unsigned long Result) { Transaction.SendResultMessage(Result); }; sockaddr_vm SocketAddress{}; wil::unique_fd ListenSocket{UtilListenVsockAnyPort(&SocketAddress, 1, false)}; diff --git a/src/linux/init/util.h b/src/linux/init/util.h index 1fa5414d5..35d0aa3ad 100644 --- a/src/linux/init/util.h +++ b/src/linux/init/util.h @@ -35,7 +35,8 @@ Module Name: namespace wsl::shared { class SocketChannel; -} +class Transaction; +} // namespace wsl::shared namespace wsl::linux { struct WslDistributionConfig; @@ -312,4 +313,4 @@ uint16_t UtilWinAfToLinuxAf(uint16_t AddressFamily); int WriteToFile(const char* Path, const char* Content, int permissions = 0644); -int ProcessCreateProcessMessage(wsl::shared::SocketChannel& channel, gsl::span Buffer); \ No newline at end of file +int ProcessCreateProcessMessage(wsl::shared::Transaction& Transaction, gsl::span Buffer); \ No newline at end of file diff --git a/src/shared/inc/SocketChannel.h b/src/shared/inc/SocketChannel.h index c14147254..7ce62792e 100644 --- a/src/shared/inc/SocketChannel.h +++ b/src/shared/inc/SocketChannel.h @@ -14,6 +14,7 @@ Module Name: #pragma once +#include #include #include "socketshared.h" #include "lxinitshared.h" @@ -41,6 +42,44 @@ constexpr timeval* DefaultSocketTimeout = nullptr; #endif +class SocketChannel; + +class Transaction +{ + friend class SocketChannel; + +public: + ~Transaction() = default; + + NON_COPYABLE(Transaction); + + template + void Send(gsl::span span); + + template + void Send(TMessage& message); + + template + void SendResultMessage(TResult value); + + template + std::pair> ReceiveOrClosed(TTimeout timeout = DefaultSocketTimeout); + + template + TMessage& Receive(gsl::span* responseSpan = nullptr, TTimeout timeout = DefaultSocketTimeout); + +private: + Transaction(SocketChannel& channel, uint32_t id) : + m_channel(channel), m_id(id), m_step(static_cast(TRANSACTION_STEP::REQUEST)) + { + } + + SocketChannel& m_channel; + uint32_t m_id; + /** Use uint32_t as step can go beyond FIRST_REPLY */ + uint32_t m_step; +}; + class SocketChannel { @@ -63,6 +102,9 @@ class SocketChannel m_exitEvent = std::move(other.m_exitEvent); #endif m_ignore_sequence = other.m_ignore_sequence; + m_sent_non_transaction_messages = other.m_sent_non_transaction_messages; + m_received_non_transaction_messages = other.m_received_non_transaction_messages; + m_transaction_id_seed = other.m_transaction_id_seed.load(); return *this; } @@ -82,7 +124,7 @@ class SocketChannel #endif template - void SendMessage(gsl::span span) + void SendMessage(gsl::span span, uint32_t transactionStep = static_cast(TRANSACTION_STEP::NONE), uint32_t transactionId = 0) { // Ensure that no other thread is using this channel. const std::unique_lock lock{m_sendMutex, std::try_to_lock}; @@ -103,12 +145,20 @@ class SocketChannel THROW_INVALID_ARG_IF(m_name == nullptr || span.size() < sizeof(TMessage)); - m_sent_messages++; - auto* header = gslhelpers::try_get_struct(span); WI_ASSERT(header->MessageSize == span.size()); - header->SequenceNumber = m_sent_messages; + if (transactionStep == static_cast(TRANSACTION_STEP::NONE)) + { + m_sent_non_transaction_messages++; + header->TransactionId = m_sent_non_transaction_messages; + header->TransactionStep = static_cast(TRANSACTION_STEP::NONE); + } + else + { + header->TransactionId = transactionId; + header->TransactionStep = transactionStep; + } #ifdef WIN32 @@ -150,7 +200,7 @@ class SocketChannel } template - void SendMessage(TMessage& message) + void SendMessage(TMessage& message, uint32_t transactionStep = static_cast(TRANSACTION_STEP::NONE), uint32_t transactionId = 0) { // Catch situations where the other SendMessage() method should be used const auto& header = GetMessageHeader(message); @@ -164,7 +214,7 @@ class SocketChannel #endif } - SendMessage(gslhelpers::struct_as_writeable_bytes(message)); + SendMessage(gslhelpers::struct_as_writeable_bytes(message), transactionStep, transactionId); } template @@ -179,7 +229,10 @@ class SocketChannel } template - std::pair> ReceiveMessageOrClosed(TTimeout timeout = DefaultSocketTimeout) + std::pair> ReceiveMessageOrClosed( + TTimeout timeout = DefaultSocketTimeout, + uint32_t expectedTransactionStep = static_cast(TRANSACTION_STEP::NONE), + uint32_t expectedTransactionId = 0) { WI_ASSERT(m_name != nullptr); @@ -199,20 +252,180 @@ class SocketChannel #endif } - m_received_messages++; - - auto receivedSpan = ReceiveImpl(TMessage::Type, timeout); - if (receivedSpan.empty()) + gsl::span receivedSpan{}; + for (;;) { + if (expectedTransactionStep == static_cast(TRANSACTION_STEP::NONE)) + { + // Adhere to the old ++ before receive behavior for non-transaction messages. + m_received_non_transaction_messages++; + } + + receivedSpan = ReceiveImpl(TMessage::Type, timeout); + if (receivedSpan.empty()) + { + +#ifdef WIN32 + if (errno == HCS_E_CONNECTION_TIMEOUT) + { + THROW_HR_MSG(HCS_E_CONNECTION_TIMEOUT, "Timeout: %u, expected type: %hs, channel: %hs", timeout, ToString(TMessage::Type), m_name); + } +#endif + + return {nullptr, {}}; + } + + auto* header = gslhelpers::try_get_struct(receivedSpan); + if (header == nullptr) + { +#ifdef WIN32 + THROW_HR_MSG(E_UNEXPECTED, "Message too small for header: %zd, channel: %hs", receivedSpan.size(), m_name); +#else + LOG_ERROR("Message too small for header: {}, channel: {}", receivedSpan.size(), m_name); + THROW_ERRNO(EINVAL); +#endif + } + + if (expectedTransactionStep == static_cast(TRANSACTION_STEP::NONE)) + { + // Handle non-transaction messages with legacy logic. + if (!m_ignore_sequence) + { + if (header->TransactionStep != static_cast(TRANSACTION_STEP::NONE)) + { +#ifdef WIN32 + THROW_HR_MSG( + E_UNEXPECTED, + "Unexpected transaction message received on non-transaction channel: %hs, message type: %hs", + m_name, + ToString(header->MessageType)); +#else + LOG_ERROR( + "Unexpected transaction message received on non-transaction channel: {}, message type: {}", + m_name, + ToString(header->MessageType)); + THROW_ERRNO(EINVAL); +#endif + } + if (header->TransactionId != m_received_non_transaction_messages) + { +#ifdef WIN32 + THROW_HR_MSG( + E_UNEXPECTED, + "Unexpected non-transaction message id: %u, expected: %u, channel: %hs", + header->TransactionId, + m_received_non_transaction_messages, + m_name); +#else + LOG_ERROR("Unexpected non-transaction message id: {}, expected: {}, channel: {}", header->TransactionId, m_received_non_transaction_messages, m_name); + THROW_ERRNO(EINVAL); +#endif + } + } + break; + } + // Handle transaction messages + if (header->TransactionStep == static_cast(TRANSACTION_STEP::NONE)) + { + // Skip stale non-transaction messages #ifdef WIN32 - if (errno == HCS_E_CONNECTION_TIMEOUT) + WSL_LOG( + "DiscardStaleNonTransactionMessage", + TraceLoggingValue(m_name, "Name"), + TraceLoggingValue(ToString(header->MessageType), "MessageType"), + TraceLoggingValue(ToString(TMessage::Type), "ExpectedMessageType"), + TraceLoggingValue(header->TransactionId, "StaleNonTransactionId"), + TraceLoggingValue(m_received_non_transaction_messages, "ExpectedNonTransactionId")); +#else + LOG_WARNING( + "Discard stale non-transaction message on channel: {}. MessageType: {}, ExpectedMessageType: {}, " + "StaleNonTransactionId: {}, ExpectedNonTransactionId: {}", + m_name, + header->MessageType, + TMessage::Type, + header->TransactionId, + m_received_non_transaction_messages); +#endif + continue; + } + + if (expectedTransactionStep == static_cast(TRANSACTION_STEP::REQUEST)) { - THROW_HR_MSG(HCS_E_CONNECTION_TIMEOUT, "Timeout: %d, expected type: %hs, channel: %hs", timeout, ToString(TMessage::Type), m_name); + // Skip until we get the next request. No matter the transaction id. + if (header->TransactionStep != static_cast(TRANSACTION_STEP::REQUEST)) + { +#ifdef WIN32 + WSL_LOG( + "DiscardOutOfOrderTransactionMessage", + TraceLoggingValue(m_name, "Name"), + TraceLoggingValue(ToString(header->MessageType), "MessageType"), + TraceLoggingValue(ToString(TMessage::Type), "ExpectedMessageType"), + TraceLoggingValue(header->TransactionStep, "StaleTransactionStep"), + TraceLoggingValue(expectedTransactionStep, "ExpectedTransactionStep")); +#else + LOG_WARNING( + "Discard out of order transaction message on channel: {}. MessageType: {}, ExpectedMessageType: {}, " + "StaleTransactionStep: {}, ExpectedTransactionStep: {}", + m_name, + header->MessageType, + TMessage::Type, + header->TransactionStep, + expectedTransactionStep); +#endif + continue; + } + break; } + + auto diff = static_cast(header->TransactionId - expectedTransactionId); + if (diff < 0) + { + // Skip stale transaction messages +#ifdef WIN32 + WSL_LOG( + "DiscardStaleTransactionMessage", + TraceLoggingValue(m_name, "Name"), + TraceLoggingValue(ToString(header->MessageType), "MessageType"), + TraceLoggingValue(ToString(TMessage::Type), "ExpectedMessageType"), + TraceLoggingValue(header->TransactionId, "StaleTransactionId"), + TraceLoggingValue(expectedTransactionId, "ExpectedTransactionId")); +#else + LOG_WARNING( + "Discard stale transaction message on channel: {}. MessageType: {}, ExpectedMessageType: {}, " + "StaleTransactionId: {}, ExpectedTransactionId: {}", + m_name, + header->MessageType, + TMessage::Type, + header->TransactionId, + expectedTransactionId); +#endif + continue; + } + + if (diff > 0) + { + // Message is from the future. +#ifdef WIN32 + THROW_HR_MSG(E_UNEXPECTED, "Unexpected transaction message id: %u, expected: %u, channel: %hs", header->TransactionId, expectedTransactionId, m_name); +#else + LOG_ERROR("Unexpected transaction message id: {}, expected: {}, channel: {}", header->TransactionId, expectedTransactionId, m_name); + THROW_ERRNO(EINVAL); +#endif + } + + if (header->TransactionStep != expectedTransactionStep) + { + // Broken transaction. +#ifdef WIN32 + THROW_HR_MSG(E_UNEXPECTED, "Unexpected transaction message step: %u, expected: %u, channel: %hs", header->TransactionStep, expectedTransactionStep, m_name); +#else + LOG_ERROR("Unexpected transaction message step: {}, expected: {}, channel: {}", header->TransactionStep, expectedTransactionStep, m_name); + THROW_ERRNO(EINVAL); #endif + } - return {nullptr, {}}; + break; } auto* message = gslhelpers::try_get_struct(receivedSpan); @@ -228,7 +441,7 @@ class SocketChannel #endif } - ValidateMessageHeader(GetMessageHeader(*message), TMessage::Type, m_received_messages); + ValidateMessageHeader(GetMessageHeader(*message), TMessage::Type); #ifdef WIN32 WSL_LOG( @@ -243,9 +456,13 @@ class SocketChannel } template - TMessage& ReceiveMessage(gsl::span* responseSpan = nullptr, TTimeout timeout = DefaultSocketTimeout) + TMessage& ReceiveMessage( + gsl::span* responseSpan = nullptr, + TTimeout timeout = DefaultSocketTimeout, + uint32_t expectedTransactionStep = static_cast(TRANSACTION_STEP::NONE), + uint32_t expectedTransactionId = 0) { - auto [message, span] = ReceiveMessageOrClosed(timeout); + auto [message, span] = ReceiveMessageOrClosed(timeout, expectedTransactionStep, expectedTransactionId); if (message == nullptr) { #ifdef WIN32 @@ -264,16 +481,28 @@ class SocketChannel return *message; } - template - TSentMessage::TResponse& Transaction(gsl::span message, gsl::span* responseSpan = nullptr, TTimeout timeout = DefaultSocketTimeout) + Transaction StartTransaction() + { + uint32_t transactionId = m_transaction_id_seed++; + return wsl::shared::Transaction(*this, transactionId); + } + + Transaction ReceiveTransaction() { - SendMessage(message); + // Transaction id should follow the received one on the receive end. + return wsl::shared::Transaction(*this, 0); + } - return ReceiveMessage(responseSpan, timeout); + template + typename TSentMessage::TResponse& Transaction(gsl::span message, gsl::span* responseSpan = nullptr, TTimeout timeout = DefaultSocketTimeout) + { + auto transaction = StartTransaction(); + transaction.Send(message); + return transaction.Receive(responseSpan, timeout); } template - TSentMessage::TResponse& Transaction(TSentMessage& message, gsl::span* responseSpan = nullptr, TTimeout timeout = DefaultSocketTimeout) + typename TSentMessage::TResponse& Transaction(TSentMessage& message, gsl::span* responseSpan = nullptr, TTimeout timeout = DefaultSocketTimeout) { WI_ASSERT(message.Header.MessageSize == sizeof(message)); @@ -321,33 +550,33 @@ class SocketChannel #endif - void ValidateMessageHeader(const MESSAGE_HEADER& header, LX_MESSAGE_TYPE expected, unsigned int expectedSequence) const + void ValidateMessageHeader(const MESSAGE_HEADER& header, LX_MESSAGE_TYPE expected) const { - if (header.MessageSize < sizeof(header) || (expected != LxMiniInitMessageAny && header.MessageType != expected) || - (!m_ignore_sequence && header.SequenceNumber != expectedSequence)) + + if (header.MessageSize < sizeof(header) || (expected != LxMiniInitMessageAny && header.MessageType != expected)) { #ifdef WIN32 THROW_HR_MSG( E_UNEXPECTED, - "Protocol error: Received message size: %u, type: %u, sequence: %u. Expected type: %u, expected sequence: %u, " + "Protocol error: Received message size: %u, type: %u, id: %u, step: %u. Expected type: %u, " "channel: %hs", header.MessageSize, header.MessageType, - header.SequenceNumber, + header.TransactionId, + header.TransactionStep, expected, - expectedSequence, m_name); #else LOG_ERROR( - "Protocol error: Received message size: {}, type: {}, sequence: {}. Expected type: {}, expected sequence: {}, " + "Protocol error: Received message size: {}, type: {}, id: {}, step: {}. Expected type: {}, " "channel: {}", header.MessageSize, header.MessageType, - header.SequenceNumber, + header.TransactionId, + header.TransactionStep, expected, - expectedSequence, m_name); THROW_ERRNO(EINVAL); @@ -392,11 +621,65 @@ class SocketChannel HANDLE m_exitEvent{}; #endif - uint32_t m_sent_messages = 0; - uint32_t m_received_messages = 0; + uint32_t m_sent_non_transaction_messages = 0; + uint32_t m_received_non_transaction_messages = 0; + std::atomic m_transaction_id_seed = 0; bool m_ignore_sequence = false; const char* m_name{}; std::mutex m_sendMutex; std::mutex m_receiveMutex; }; -} // namespace wsl::shared \ No newline at end of file + +template +void Transaction::Send(gsl::span span) +{ + m_channel.SendMessage(span, m_step, m_id); + m_step++; +} + +template +void Transaction::Send(TMessage& message) +{ + Send(gslhelpers::struct_as_writeable_bytes(message)); +} + +template +void Transaction::SendResultMessage(TResult value) +{ + RESULT_MESSAGE Result{}; + Result.Header.MessageSize = sizeof(Result); + Result.Header.MessageType = RESULT_MESSAGE::Type; + Result.Result = value; + + Send(Result); +} + +template +std::pair> Transaction::ReceiveOrClosed(TTimeout timeout) +{ + auto result = m_channel.ReceiveMessageOrClosed(timeout, m_step, m_id); + if (m_step == static_cast(TRANSACTION_STEP::REQUEST) && result.first != nullptr) + { + // Use the request's id for the reply side transaction. + MESSAGE_HEADER& header = m_channel.GetMessageHeader(*result.first); + m_id = header.TransactionId; + } + m_step++; + return result; +} + +template +TMessage& Transaction::Receive(gsl::span* responseSpan, TTimeout timeout) +{ + auto& message = m_channel.ReceiveMessage(responseSpan, timeout, m_step, m_id); + if (m_step == static_cast(TRANSACTION_STEP::REQUEST)) + { + // Use the request's id for the reply side transaction. + MESSAGE_HEADER& header = m_channel.GetMessageHeader(message); + m_id = header.TransactionId; + } + m_step++; + return message; +} + +} // namespace wsl::shared diff --git a/src/shared/inc/lxinitshared.h b/src/shared/inc/lxinitshared.h index 0e0ba26e0..e57c55fed 100644 --- a/src/shared/inc/lxinitshared.h +++ b/src/shared/inc/lxinitshared.h @@ -475,15 +475,23 @@ inline void PrettyPrint(std::stringstream& Out, LX_MESSAGE_TYPE Value) Out << ToString(Value); } +enum class TRANSACTION_STEP : unsigned int +{ + NONE = 0, + REQUEST = 1, + FIRST_REPLY = 2, +}; + struct MESSAGE_HEADER { static inline auto Type = LxMiniInitMessageAny; // Setting this allows using MESSAGE_HEADER to receive any type of message LX_MESSAGE_TYPE MessageType; unsigned int MessageSize; - unsigned int SequenceNumber; + unsigned int TransactionId; + unsigned int TransactionStep; - PRETTY_PRINT(FIELD(MessageType), FIELD(MessageSize), FIELD(SequenceNumber)); + PRETTY_PRINT(FIELD(MessageType), FIELD(MessageSize), FIELD(TransactionId), FIELD(TransactionStep)); }; // @@ -771,7 +779,7 @@ typedef struct _LX_GNS_SET_PORT_LISTENER PRETTY_PRINT(FIELD(Header), FIELD(HvSocketPort)); } LX_GNS_SET_PORT_LISTENER, *PLX_GNS_SET_PORT_LISTENER; -static_assert(sizeof(LX_GNS_SET_PORT_LISTENER) == 16); +static_assert(sizeof(LX_GNS_SET_PORT_LISTENER) == 20); typedef struct _LX_GNS_PORT_LISTENER_RELAY { diff --git a/src/shared/inc/socketshared.h b/src/shared/inc/socketshared.h index 1c9129180..85c6423d4 100644 --- a/src/shared/inc/socketshared.h +++ b/src/shared/inc/socketshared.h @@ -95,18 +95,18 @@ try LOG_HR_MSG( E_UNEXPECTED, - "Socket closed while reading message. Size: %u, type: %i, sequence: %u", + "Socket closed while reading message. Size: %u, type: %i, id: %u", Header->MessageSize, Header->MessageType, - Header->SequenceNumber); + Header->TransactionId); #elif defined(__GNUC__) LOG_ERROR( - "Socket closed while reading message. Size: {}, type: {}, sequence: {}", + "Socket closed while reading message. Size: {}, type: {}, id: {}", Header->MessageSize, Header->MessageType, - Header->SequenceNumber); + Header->TransactionId); #endif diff --git a/src/windows/common/GnsPortTrackerChannel.cpp b/src/windows/common/GnsPortTrackerChannel.cpp index 09f70e8d3..61ea606f9 100644 --- a/src/windows/common/GnsPortTrackerChannel.cpp +++ b/src/windows/common/GnsPortTrackerChannel.cpp @@ -33,7 +33,8 @@ void GnsPortTrackerChannel::Run() { for (;;) { - auto [header, range] = m_channel.ReceiveMessageOrClosed(); + auto transaction = m_channel.ReceiveTransaction(); + auto [header, range] = transaction.ReceiveOrClosed(); if (header == nullptr) { return; @@ -46,7 +47,8 @@ void GnsPortTrackerChannel::Run() const auto* message = gslhelpers::try_get_struct(range); THROW_HR_IF_MSG(E_UNEXPECTED, !message, "Unexpected message size: %i", header->MessageSize); - m_channel.SendResultMessage(m_callback(ConvertPortRequestToSockAddr(message), message->Protocol, message->Allocate)); + transaction.SendResultMessage( + m_callback(ConvertPortRequestToSockAddr(message), message->Protocol, message->Allocate)); } break; case LxGnsMessageIfStateChangeRequest: @@ -55,7 +57,7 @@ void GnsPortTrackerChannel::Run() THROW_HR_IF_MSG(E_UNEXPECTED, !message, "Unexpected message size: %i", header->MessageSize); m_interfaceStateCallback(message->InterfaceName, message->InterfaceUp); - m_channel.SendResultMessage(0); + transaction.SendResultMessage(0); } break; default: diff --git a/src/windows/service/exe/LxssCreateProcess.h b/src/windows/service/exe/LxssCreateProcess.h index 8e1c704a4..0bb3e793f 100644 --- a/src/windows/service/exe/LxssCreateProcess.h +++ b/src/windows/service/exe/LxssCreateProcess.h @@ -85,15 +85,15 @@ class LxssCreateProcess wsl::shared::MessageWriter message(LxInitCreateProcess); message.WriteString(message->PathIndex, Path); gsl::copy(as_bytes(gsl::span(ArgumentsData)), message.InsertBuffer(message->CommandLineIndex, ArgumentsData.size())); - channel.SendMessage(message.Span()); + auto transaction = channel.StartTransaction(); + transaction.Send(message.Span()); auto readResult = [&]() { - const auto& message = channel.ReceiveMessage>(nullptr, Timeout); + const auto& message = transaction.Receive>(nullptr, Timeout); return message.Result; }; auto processSocket = wsl::windows::common::hvsocket::Connect(RuntimeId, readResult(), terminatingEvent); - const auto execResult = readResult(); THROW_HR_IF_MSG(E_FAIL, execResult != 0, "Failed to execute '%hs', error=%d", Path, execResult); diff --git a/src/windows/service/exe/WslCoreInstance.cpp b/src/windows/service/exe/WslCoreInstance.cpp index 74fe3168d..98b7379c4 100644 --- a/src/windows/service/exe/WslCoreInstance.cpp +++ b/src/windows/service/exe/WslCoreInstance.cpp @@ -344,7 +344,8 @@ void WslCoreInstance::UpdateTimezone() wsl::windows::common::helpers::GenerateTimezoneUpdateMessage(wsl::windows::common::helpers::GetLinuxTimezone(m_userToken.get())); auto lock = m_initChannel->Lock(); - m_initChannel->GetChannel().SendMessage(gsl::make_span(message)); + auto transaction = m_initChannel->GetChannel().StartTransaction(); + transaction.Send(gsl::make_span(message)); } ULONG64 WslCoreInstance::GetLifetimeManagerId() const @@ -389,11 +390,12 @@ void WslCoreInstance::Initialize() auto config = wsl::windows::common::helpers::GenerateConfigurationMessage( m_configuration.Name, fixedDrives, m_defaultUid, timezone, {}, m_featureFlags, drvfsMount); - m_initChannel->GetChannel().SendMessage(gsl::span(config)); + auto transaction = m_initChannel->GetChannel().StartTransaction(); + transaction.Send(gsl::span(config)); // Init replies with information about the distribution. gsl::span span; - const auto& response = m_initChannel->GetChannel().ReceiveMessage(&span); + const auto& response = transaction.Receive(&span); m_defaultUid = response.DefaultUid; m_plan9Port = response.Plan9Port; m_distributionInfo.PidNamespace = response.PidNamespace; @@ -473,8 +475,9 @@ bool WslCoreInstance::RequestStop(_In_ bool Force) terminateMessage.Header.MessageSize = sizeof(terminateMessage); terminateMessage.Force = Force; - m_initChannel->GetChannel().SendMessage(terminateMessage); - auto [message, span] = m_initChannel->GetChannel().ReceiveMessageOrClosed>(m_socketTimeout); + auto transaction = m_initChannel->GetChannel().StartTransaction(); + transaction.Send(terminateMessage); + auto [message, span] = transaction.ReceiveOrClosed>(m_socketTimeout); if (message) { shutdown = message->Result; diff --git a/src/windows/service/exe/WslCoreVm.cpp b/src/windows/service/exe/WslCoreVm.cpp index 2d26d2bee..20ba3477d 100644 --- a/src/windows/service/exe/WslCoreVm.cpp +++ b/src/windows/service/exe/WslCoreVm.cpp @@ -527,7 +527,8 @@ void WslCoreVm::Initialize(const GUID& VmId, const wil::shared_handle& UserToken message.WriteString(message->KernelModulesListOffset, m_vmConfig.KernelModulesList); message->DnsTunnelingIpAddress = m_vmConfig.DnsTunnelingIpAddress.value_or(0); - m_miniInitChannel.SendMessage(message.Span()); + auto transaction = m_miniInitChannel.StartTransaction(); + transaction.Send(message.Span()); { ExecutionContext context(Context::ConfigureNetworking); @@ -1098,7 +1099,8 @@ void WslCoreVm::CollectCrashDumps(wil::unique_socket&& listenSocket) const auto channel = wsl::shared::SocketChannel{std::move(socket.value()), "crash_dump", m_terminatingEvent.get()}; - const auto& message = channel.ReceiveMessage(); + auto transaction = channel.ReceiveTransaction(); + const auto& message = transaction.Receive(); const char* process = reinterpret_cast(&message.Buffer); constexpr auto dumpExtension = ".dmp"; @@ -1146,7 +1148,7 @@ void WslCoreVm::CollectCrashDumps(wil::unique_socket&& listenSocket) const wil::unique_hfile file{CreateFileW(fullPath.c_str(), GENERIC_WRITE, 0, nullptr, CREATE_NEW, FILE_ATTRIBUTE_TEMPORARY, nullptr)}; THROW_LAST_ERROR_IF(!file); - channel.SendResultMessage(0); + transaction.SendResultMessage(0); wsl::windows::common::relay::InterruptableRelay(reinterpret_cast(channel.Socket()), file.get(), nullptr); } @@ -1206,7 +1208,8 @@ std::shared_ptr WslCoreVm::CreateInstance( message.WriteString(message->SharedMemoryRootOffset, sharedMemoryRoot); message.WriteString(message->InstallPathOffset, installPath); message.WriteString(message->UserProfileOffset, userProfile); - m_miniInitChannel.SendMessage(message.Span()); + auto transaction = m_miniInitChannel.StartTransaction(); + transaction.Send(message.Span()); return CreateInstanceInternal( InstanceId, Configuration, ReceiveTimeout, DefaultUid, ClientLifetimeId, WI_IsFlagSet(flags, LxMiniInitMessageFlagLaunchSystemDistro), ConnectPort); @@ -1844,7 +1847,8 @@ void WslCoreVm::InitializeGuest() } // Send the message. - m_miniInitChannel.SendMessage(message.Span()); + auto transaction = m_miniInitChannel.StartTransaction(); + transaction.Send(message.Span()); // If port tracker or localhost relay are enabled, establish a connection with the guest and start processing messages. switch (message->NetworkingConfiguration.PortTrackerType) @@ -1978,7 +1982,8 @@ WslCoreVm::DiskMountResult WslCoreVm::MountDiskLockHeld( message.WriteString(message->OptionsOffset, Options); // Send the message. - m_miniInitChannel.SendMessage(message.Span()); + auto transaction = m_miniInitChannel.StartTransaction(); + transaction.Send(message.Span()); // Accept a connection from mini_init wsl::shared::SocketChannel channel{AcceptConnection(m_vmConfig.KernelBootTimeout), "MountResult", m_terminatingEvent.get()}; @@ -2106,7 +2111,8 @@ void WslCoreVm::WaitForPmemDeviceInVm(_In_ ULONG PmemId) { auto lock = m_lock.lock_exclusive(); - m_miniInitChannel.SendMessage(message); + auto transaction = m_miniInitChannel.StartTransaction(); + transaction.Send(message); channel = { AcceptConnection(m_vmConfig.KernelBootTimeout), "WaitForPmem", @@ -2415,7 +2421,8 @@ void WslCoreVm::ResizeDistribution(_In_ ULONG Lun, _In_ HANDLE OutputHandle, _In message.ScsiLun = Lun; message.NewSize = NewSize; - m_miniInitChannel.SendMessage(message); + auto transaction = m_miniInitChannel.StartTransaction(); + transaction.Send(message); wsl::shared::SocketChannel channel{AcceptConnection(m_vmConfig.KernelBootTimeout), "ResizeDistribution", m_terminatingEvent.get()}; auto outputChannel = AcceptConnection(m_vmConfig.KernelBootTimeout); @@ -2492,7 +2499,8 @@ std::pair WslCoreVm::UnmountDisk(_In_ const AttachedDis message.Header.MessageSize = sizeof(message); message.ScsiLun = State.Lun; - m_miniInitChannel.SendMessage(message); + auto transaction = m_miniInitChannel.StartTransaction(); + transaction.Send(message); // Accept a connection from mini_init. wsl::shared::SocketChannel channel{AcceptConnection(m_vmConfig.KernelBootTimeout), "MountResult", m_terminatingEvent.get()}; @@ -2507,7 +2515,8 @@ std::pair WslCoreVm::UnmountVolume(_In_ const AttachedD message.WriteString(Name); // Send the message. - m_miniInitChannel.SendMessage(message.Span()); + auto transaction = m_miniInitChannel.StartTransaction(); + transaction.Send(message.Span()); // Accept a connection from mini_init. wsl::shared::SocketChannel channel{AcceptConnection(m_vmConfig.KernelBootTimeout), "MountResult", m_terminatingEvent.get()}; @@ -2573,7 +2582,8 @@ try { wsl::windows::common::wslutil::SetThreadDescription(L"VirtioFs - Request"); - auto [message, span] = channel.ReceiveMessageOrClosed(); + auto transaction = channel.ReceiveTransaction(); + auto [message, span] = transaction.ReceiveOrClosed(); if (message == nullptr) { return; @@ -2587,7 +2597,7 @@ try response.WriteString(response->TagOffset, tag); response.WriteString(response->SourceOffset, source); - channel.SendMessage(response.Span()); + transaction.Send(response.Span()); }; if (message->MessageType == LxInitMessageAddVirtioFsDevice) diff --git a/src/windows/wslrelay/localhost.cpp b/src/windows/wslrelay/localhost.cpp index 7e6a7a734..4d2c7f869 100644 --- a/src/windows/wslrelay/localhost.cpp +++ b/src/windows/wslrelay/localhost.cpp @@ -81,7 +81,8 @@ void wsl::windows::wslrelay::localhost::RelayWorker(_In_ wsl::shared::SocketChan for (;;) { - auto [Message, Span] = Channel.ReceiveMessageOrClosed(); + auto Transaction = Channel.ReceiveTransaction(); + auto [Message, Span] = Transaction.ReceiveOrClosed(); if (Message == nullptr) { break; @@ -151,7 +152,7 @@ void wsl::windows::wslrelay::localhost::RelayWorker(_In_ wsl::shared::SocketChan } } - Channel.SendMessage(Response); + Transaction.Send(Response); break; }