Skip to content

Commit 420fe90

Browse files
committed
Refactor inproc pair sockets
1 parent d8df24d commit 420fe90

File tree

8 files changed

+156
-145
lines changed

8 files changed

+156
-145
lines changed

ipykernel/kernelbase.py

Lines changed: 32 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,7 @@ def should_handle(self, stream, msg, idents):
350350
return False
351351
return True
352352

353-
async def dispatch_shell(self, msg, /, stream=None, subshell_id: str | None = None):
353+
async def dispatch_shell(self, msg, /, subshell_id: str | None = None):
354354
"""dispatch shell requests"""
355355
if len(msg) == 1 and msg[0].buffer == b"stop aborting":
356356
# Dummy "stop aborting" message to stop aborting execute requests on this subshell.
@@ -386,11 +386,12 @@ async def dispatch_shell(self, msg, /, stream=None, subshell_id: str | None = No
386386
msg_type = msg["header"]["msg_type"]
387387
assert msg["header"].get("subshell_id") == subshell_id
388388

389-
if stream is None:
390-
if self._supports_kernel_subshells:
391-
stream = self.shell_channel_thread.manager.get_other_stream(subshell_id)
392-
else:
393-
stream = self.shell_stream
389+
if self._supports_kernel_subshells:
390+
stream = self.shell_channel_thread.manager.get_subshell_to_shell_channel_socket(
391+
subshell_id
392+
)
393+
else:
394+
stream = self.shell_stream
394395

395396
# Only abort execute requests
396397
if msg_type == "execute_request":
@@ -611,7 +612,7 @@ async def shell_channel_thread_main(self, msg):
611612

612613
# Find inproc pair socket to use to send message to correct subshell.
613614
subshell_manager = self.shell_channel_thread.manager
614-
socket = subshell_manager.get_shell_channel_stream(subshell_id)
615+
socket = subshell_manager.get_shell_channel_to_subshell_socket(subshell_id)
615616
assert socket is not None
616617
socket.send_multipart(msg, copy=False)
617618
except Exception:
@@ -630,26 +631,25 @@ async def shell_main(self, subshell_id: str | None, msg):
630631
self.shell_channel_thread,
631632
threading.main_thread(),
632633
)
633-
# Inproc pair socket that this subshell uses to talk to shell channel thread.
634-
stream = self.shell_channel_thread.manager.get_other_stream(subshell_id)
634+
socket_pair = self.shell_channel_thread.manager.get_shell_channel_to_subshell_pair(
635+
subshell_id
636+
)
635637
else:
636638
assert subshell_id is None
637639
assert threading.current_thread() == threading.main_thread()
638-
stream = self.shell_stream
640+
socket_pair = None
639641

640642
try:
641643
# Whilst executing a shell message, do not accept any other shell messages on the
642644
# same subshell, so that cells are run sequentially. Without this we can run multiple
643645
# async cells at the same time which would be a nice feature to have but is an API
644646
# change.
645-
stream.stop_on_recv()
646-
await self.dispatch_shell(msg, stream=stream, subshell_id=subshell_id)
647+
if socket_pair:
648+
socket_pair.pause_on_recv()
649+
await self.dispatch_shell(msg, subshell_id=subshell_id)
647650
finally:
648-
stream.on_recv(
649-
partial(self.shell_main, subshell_id),
650-
copy=False,
651-
)
652-
stream.flush()
651+
if socket_pair:
652+
socket_pair.resume_on_recv()
653653

654654
def record_ports(self, ports):
655655
"""Record the ports that this kernel is using.
@@ -690,7 +690,7 @@ def _publish_status(self, status, channel, parent=None):
690690
def _publish_status_and_flush(self, status, channel, stream, parent=None):
691691
"""send status on IOPub and flush specified stream to ensure reply is sent before handling the next reply"""
692692
self._publish_status(status, channel, parent)
693-
if stream:
693+
if stream and hasattr(stream, "flush"):
694694
stream.flush(zmq.POLLOUT)
695695

696696
def _publish_debug_event(self, event):
@@ -835,6 +835,8 @@ async def execute_request(self, stream, ident, parent):
835835
if self._do_exec_accepted_params["cell_id"]:
836836
do_execute_args["cell_id"] = cell_id
837837

838+
subshell_id = parent["header"].get("subshell_id")
839+
838840
# Call do_execute with the appropriate arguments
839841
reply_content = self.do_execute(**do_execute_args)
840842

@@ -1174,9 +1176,9 @@ async def create_subshell_request(self, socket, ident, parent) -> None:
11741176

11751177
# This should only be called in the control thread if it exists.
11761178
# Request is passed to shell channel thread to process.
1177-
other_socket = self.shell_channel_thread.manager.get_control_other_socket()
1178-
other_socket.send_json({"type": "create"})
1179-
reply = other_socket.recv_json()
1179+
control_socket = self.shell_channel_thread.manager.control_to_shell_channel.from_socket
1180+
control_socket.send_json({"type": "create"})
1181+
reply = control_socket.recv_json()
11801182
self.session.send(socket, "create_subshell_reply", reply, parent, ident)
11811183

11821184
async def delete_subshell_request(self, socket, ident, parent) -> None:
@@ -1197,9 +1199,10 @@ async def delete_subshell_request(self, socket, ident, parent) -> None:
11971199

11981200
# This should only be called in the control thread if it exists.
11991201
# Request is passed to shell channel thread to process.
1200-
other_socket = self.shell_channel_thread.manager.get_control_other_socket()
1201-
other_socket.send_json({"type": "delete", "subshell_id": subshell_id})
1202-
reply = other_socket.recv_json()
1202+
control_socket = self.shell_channel_thread.manager.control_to_shell_channel.from_socket
1203+
control_socket.send_json({"type": "delete", "subshell_id": subshell_id})
1204+
reply = control_socket.recv_json()
1205+
12031206
self.session.send(socket, "delete_subshell_reply", reply, parent, ident)
12041207

12051208
async def list_subshell_request(self, socket, ident, parent) -> None:
@@ -1213,9 +1216,10 @@ async def list_subshell_request(self, socket, ident, parent) -> None:
12131216

12141217
# This should only be called in the control thread if it exists.
12151218
# Request is passed to shell channel thread to process.
1216-
other_socket = self.shell_channel_thread.manager.get_control_other_socket()
1217-
other_socket.send_json({"type": "list"})
1218-
reply = other_socket.recv_json()
1219+
control_socket = self.shell_channel_thread.manager.control_to_shell_channel.from_socket
1220+
control_socket.send_json({"type": "list"})
1221+
reply = control_socket.recv_json()
1222+
12191223
self.session.send(socket, "list_subshell_reply", reply, parent, ident)
12201224

12211225
# ---------------------------------------------------------------------------
@@ -1315,7 +1319,7 @@ def _post_dummy_stop_aborting_message(self, subshell_id: str | None) -> None:
13151319
the _aborting flag.
13161320
"""
13171321
subshell_manager = self.shell_channel_thread.manager
1318-
socket = subshell_manager.get_shell_channel_stream(subshell_id)
1322+
socket = subshell_manager.get_shell_channel_to_subshell_socket(subshell_id)
13191323
assert socket is not None
13201324

13211325
msg = b"stop aborting" # Magic string for dummy message.

ipykernel/socket_pair.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
from __future__ import annotations
2+
3+
from typing import Any
4+
5+
import zmq
6+
from tornado.ioloop import IOLoop
7+
from zmq.eventloop.zmqstream import ZMQStream
8+
9+
10+
class SocketPair:
11+
"""Pair of ZMQ inproc sockets for one-direction communication between 2 threads.
12+
13+
One of the threads is always the shell_channel_thread, the other may be the control
14+
thread, main thread or a subshell thread.
15+
"""
16+
17+
from_socket: zmq.Socket[Any]
18+
to_socket: zmq.Socket[Any]
19+
to_stream: ZMQStream
20+
on_recv_callback: Any
21+
on_recv_copy: bool
22+
23+
def __init__(self, context: zmq.Context[Any], name: str):
24+
self.from_socket = context.socket(zmq.PAIR)
25+
self.to_socket = context.socket(zmq.PAIR)
26+
address = self._address(name)
27+
self.from_socket.bind(address)
28+
self.to_socket.connect(address) # Or do I need to do this in another thread?
29+
30+
def close(self):
31+
self.from_socket.close()
32+
33+
if self.to_stream is not None:
34+
self.to_stream.close()
35+
self.to_socket.close()
36+
37+
def on_recv(self, io_loop: IOLoop, on_recv_callback, copy: bool = False):
38+
# io_loop is that of the 'to' thread.
39+
self.on_recv_callback = on_recv_callback
40+
self.on_recv_copy = copy
41+
self.to_stream = ZMQStream(self.to_socket, io_loop)
42+
self.resume_on_recv()
43+
44+
def pause_on_recv(self):
45+
self.to_stream.stop_on_recv()
46+
47+
def resume_on_recv(self):
48+
self.to_stream.on_recv(self.on_recv_callback, copy=self.on_recv_copy)
49+
50+
def _address(self, name) -> str:
51+
return f"inproc://subshell{name}"

ipykernel/subshell.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,9 @@
33
from typing import Any
44

55
import zmq
6-
from zmq.eventloop.zmqstream import ZMQStream
76

7+
from .socket_pair import SocketPair
88
from .thread import BaseThread
9-
from .utils import create_inproc_pair_socket
109

1110

1211
class SubshellThread(BaseThread):
@@ -21,12 +20,8 @@ def __init__(
2120
"""Initialize the thread."""
2221
super().__init__(name=f"subshell-{subshell_id}", **kwargs)
2322

24-
shell_channel_socket = create_inproc_pair_socket(context, subshell_id, True)
25-
# io_loop will be current io_loop which is of ShellChannelThread
26-
self.shell_channel_stream = ZMQStream(shell_channel_socket)
27-
28-
subshell_socket = create_inproc_pair_socket(context, subshell_id, False)
29-
self.subshell_stream = ZMQStream(subshell_socket, self.io_loop)
23+
self.shell_channel_to_subshell = SocketPair(context, subshell_id)
24+
self.subshell_to_shell_channel = SocketPair(context, subshell_id + "-reverse")
3025

3126
# When aborting flag is set, execute_request messages to this subshell will be aborted.
3227
self.aborting = False
@@ -35,5 +30,5 @@ def run(self) -> None:
3530
try:
3631
super().run()
3732
finally:
38-
self.subshell_stream.close()
39-
self.shell_channel_stream.close()
33+
self.shell_channel_to_subshell.close()
34+
self.subshell_to_shell_channel.close()

ipykernel/subshell_manager.py

Lines changed: 35 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,10 @@
99

1010
import zmq
1111
from tornado.ioloop import IOLoop
12-
from zmq.eventloop.zmqstream import ZMQStream
1312

13+
from .socket_pair import SocketPair
1414
from .subshell import SubshellThread
1515
from .thread import SHELL_CHANNEL_THREAD_NAME
16-
from .utils import create_inproc_pair_socket
1716

1817

1918
class SubshellManager:
@@ -46,26 +45,23 @@ def __init__(
4645
self._lock_cache = Lock()
4746
self._lock_shell_socket = Lock()
4847

49-
# Inproc pair sockets for control channel and main shell (parent subshell).
50-
# Each inproc pair has a "shell_channel" socket used in the shell channel
51-
# thread, and an "other" socket used in the other thread.
52-
control_shell_channel_socket = create_inproc_pair_socket(self._context, "control", True)
53-
self._control_shell_channel_stream = ZMQStream(
54-
control_shell_channel_socket, self._shell_channel_io_loop
48+
# Inproc socket pair for communication from control thread to shell channel thread,
49+
# such as for create_subshell_request messages. Reply messages are returned straight away.
50+
self.control_to_shell_channel = SocketPair(self._context, "control")
51+
self.control_to_shell_channel.on_recv(
52+
self._shell_channel_io_loop, self._process_control_request, copy=True
5553
)
56-
self._control_shell_channel_stream.on_recv(self._process_control_request, copy=True)
5754

58-
self._control_other_socket = create_inproc_pair_socket(self._context, "control", False)
55+
# Inproc socket pair for communication from shell channel thread to main thread,
56+
# such as for execute_request messages.
57+
self._shell_channel_to_main = SocketPair(self._context, "main")
5958

60-
parent_shell_channel_socket = create_inproc_pair_socket(self._context, None, True)
61-
self._parent_shell_channel_stream = ZMQStream(
62-
parent_shell_channel_socket, self._shell_channel_io_loop
59+
# Inproc socket pair for communication from main thread to shell channel thread.
60+
# such as for execute_reply messages.
61+
self._main_to_shell_channel = SocketPair(self._context, "main-reverse")
62+
self._main_to_shell_channel.on_recv(
63+
self._shell_channel_io_loop, self._send_on_shell_channel
6364
)
64-
self._parent_shell_channel_stream.on_recv(self._send_on_shell_channel, copy=False)
65-
66-
# Initialised in set_on_recv_callback
67-
self._on_recv_callback: t.Any = None # Callback for ZMQStream.on_recv for "other" sockets.
68-
self._parent_other_stream: ZMQStream | None = None
6965

7066
def close(self) -> None:
7167
"""Stop all subshells and close all resources."""
@@ -78,40 +74,24 @@ def close(self) -> None:
7874
break
7975
self._stop_subshell(subshell_thread)
8076

81-
for socket_or_stream in (
82-
self._control_shell_channel_stream,
83-
self._parent_shell_channel_stream,
84-
self._parent_other_stream,
85-
):
86-
if socket_or_stream is not None:
87-
socket_or_stream.close()
88-
89-
if self._control_other_socket is not None:
90-
self._control_other_socket.close()
91-
92-
def get_control_other_socket(self) -> zmq.Socket[t.Any]:
93-
return self._control_other_socket
77+
self.control_to_shell_channel.close()
78+
self._main_to_shell_channel.close()
79+
self._shell_channel_to_main.close()
9480

95-
def get_other_stream(self, subshell_id: str | None) -> ZMQStream:
96-
"""Return the other inproc pair socket for a subshell.
97-
98-
This socket is accessed from the subshell thread.
99-
"""
81+
def get_shell_channel_to_subshell_pair(self, subshell_id: str | None) -> SocketPair:
10082
if subshell_id is None:
101-
assert self._parent_other_stream is not None
102-
return self._parent_other_stream
83+
return self._shell_channel_to_main
10384
with self._lock_cache:
104-
return self._cache[subshell_id].subshell_stream
85+
return self._cache[subshell_id].shell_channel_to_subshell
10586

106-
def get_shell_channel_stream(self, subshell_id: str | None) -> ZMQStream:
107-
"""Return the stream for the shell channel inproc pair socket for a subshell.
108-
109-
This stream is accessed from the shell channel thread.
110-
"""
87+
def get_subshell_to_shell_channel_socket(self, subshell_id: str | None) -> zmq.Socket[t.Any]:
11188
if subshell_id is None:
112-
return self._parent_shell_channel_stream
89+
return self._main_to_shell_channel.from_socket
11390
with self._lock_cache:
114-
return self._cache[subshell_id].shell_channel_stream
91+
return self._cache[subshell_id].subshell_to_shell_channel.from_socket
92+
93+
def get_shell_channel_to_subshell_socket(self, subshell_id: str | None) -> zmq.Socket[t.Any]:
94+
return self.get_shell_channel_to_subshell_pair(subshell_id).from_socket
11595

11696
def get_subshell_aborting(self, subshell_id: str) -> bool:
11797
"""Get the aborting flag of the specified subshell."""
@@ -128,13 +108,7 @@ def list_subshell(self) -> list[str]:
128108
def set_on_recv_callback(self, on_recv_callback):
129109
assert current_thread() == main_thread()
130110
self._on_recv_callback = on_recv_callback
131-
if not self._parent_other_stream:
132-
parent_other_socket = create_inproc_pair_socket(self._context, None, False)
133-
self._parent_other_stream = ZMQStream(parent_other_socket)
134-
self._parent_other_stream.on_recv(
135-
partial(self._on_recv_callback, None),
136-
copy=False,
137-
)
111+
self._shell_channel_to_main.on_recv(IOLoop.current(), partial(self._on_recv_callback, None))
138112

139113
def set_subshell_aborting(self, subshell_id: str, aborting: bool) -> None:
140114
"""Set the aborting flag of the specified subshell."""
@@ -168,11 +142,14 @@ def _create_subshell(self) -> str:
168142
assert subshell_id not in self._cache
169143
self._cache[subshell_id] = subshell_thread
170144

171-
subshell_thread.subshell_stream.on_recv(
145+
subshell_thread.shell_channel_to_subshell.on_recv(
146+
subshell_thread.io_loop,
172147
partial(self._on_recv_callback, subshell_id),
173-
copy=False,
174148
)
175-
subshell_thread.shell_channel_stream.on_recv(self._send_on_shell_channel, copy=False)
149+
150+
subshell_thread.subshell_to_shell_channel.on_recv(
151+
self._shell_channel_io_loop, self._send_on_shell_channel
152+
)
176153

177154
subshell_thread.start()
178155
return subshell_id
@@ -219,7 +196,8 @@ def _process_control_request(
219196
"evalue": str(err),
220197
}
221198

222-
self._control_shell_channel_stream.send_json(reply)
199+
# Return the reply to the control thread.
200+
self.control_to_shell_channel.to_socket.send_json(reply)
223201

224202
def _send_on_shell_channel(self, msg) -> None:
225203
assert current_thread().name == SHELL_CHANNEL_THREAD_NAME

ipykernel/utils.py

Lines changed: 0 additions & 23 deletions
This file was deleted.

0 commit comments

Comments
 (0)