Skip to content

Commit c8c2edd

Browse files
Add Interpreter.bind().
1 parent 9f83af7 commit c8c2edd

File tree

5 files changed

+136
-12
lines changed

5 files changed

+136
-12
lines changed

Lib/test/support/interpreters.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,11 @@ def close(self):
103103
"""
104104
return _interpreters.destroy(self._id)
105105

106+
def bind(self, ns=None, /, **kwargs):
107+
"""Bind the given values into the interpreter's __main__."""
108+
ns = dict(ns, **kwargs) if ns is not None else kwargs
109+
_interpreters.bind(self._id, ns)
110+
106111
# XXX Rename "run" to "exec"?
107112
# XXX Do not allow init to overwrite (by default)?
108113
def run(self, src_str, /, *, init=None):

Lib/test/test__xxinterpchannels.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -587,12 +587,12 @@ def test_run_string_arg_unresolved(self):
587587
cid = channels.create()
588588
interp = interpreters.create()
589589

590+
interpreters.bind(interp, dict(cid=cid.send))
590591
out = _run_output(interp, dedent("""
591592
import _xxinterpchannels as _channels
592593
print(cid.end)
593594
_channels.send(cid, b'spam', blocking=False)
594-
"""),
595-
dict(cid=cid.send))
595+
"""))
596596
obj = channels.recv(cid)
597597

598598
self.assertEqual(obj, b'spam')

Lib/test/test__xxsubinterpreters.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,10 @@ def _captured_script(script):
3131
return wrapped, open(r, encoding="utf-8")
3232

3333

34-
def _run_output(interp, request, shared=None):
34+
def _run_output(interp, request):
3535
script, rpipe = _captured_script(request)
3636
with rpipe:
37-
interpreters.run_string(interp, script, shared)
37+
interpreters.run_string(interp, script)
3838
return rpipe.read()
3939

4040

@@ -659,10 +659,10 @@ def test_shareable_types(self):
659659
]
660660
for obj in objects:
661661
with self.subTest(obj):
662+
interpreters.bind(interp, dict(obj=obj))
662663
interpreters.run_string(
663664
interp,
664665
f'assert(obj == {obj!r})',
665-
shared=dict(obj=obj),
666666
)
667667

668668
def test_os_exec(self):
@@ -790,7 +790,8 @@ def test_with_shared(self):
790790
with open({w}, 'wb') as chan:
791791
pickle.dump(ns, chan)
792792
""")
793-
interpreters.run_string(self.id, script, shared)
793+
interpreters.bind(self.id, shared)
794+
interpreters.run_string(self.id, script)
794795
with open(r, 'rb') as chan:
795796
ns = pickle.load(chan)
796797

@@ -811,7 +812,8 @@ def test_shared_overwrites(self):
811812
ns2 = dict(vars())
812813
del ns2['__builtins__']
813814
""")
814-
interpreters.run_string(self.id, script, shared)
815+
interpreters.bind(self.id, shared)
816+
interpreters.run_string(self.id, script)
815817

816818
r, w = os.pipe()
817819
script = dedent(f"""
@@ -842,7 +844,8 @@ def test_shared_overwrites_default_vars(self):
842844
with open({w}, 'wb') as chan:
843845
pickle.dump(ns, chan)
844846
""")
845-
interpreters.run_string(self.id, script, shared)
847+
interpreters.bind(self.id, shared)
848+
interpreters.run_string(self.id, script)
846849
with open(r, 'rb') as chan:
847850
ns = pickle.load(chan)
848851

@@ -948,7 +951,8 @@ def script():
948951
with open(w, 'w', encoding="utf-8") as spipe:
949952
with contextlib.redirect_stdout(spipe):
950953
print('it worked!', end='')
951-
interpreters.run_func(self.id, script, shared=dict(w=w))
954+
interpreters.bind(self.id, dict(w=w))
955+
interpreters.run_func(self.id, script)
952956

953957
with open(r, encoding="utf-8") as outfile:
954958
out = outfile.read()
@@ -964,7 +968,8 @@ def script():
964968
with contextlib.redirect_stdout(spipe):
965969
print('it worked!', end='')
966970
def f():
967-
interpreters.run_func(self.id, script, shared=dict(w=w))
971+
interpreters.bind(self.id, dict(w=w))
972+
interpreters.run_func(self.id, script)
968973
t = threading.Thread(target=f)
969974
t.start()
970975
t.join()
@@ -984,7 +989,8 @@ def script():
984989
with contextlib.redirect_stdout(spipe):
985990
print('it worked!', end='')
986991
code = script.__code__
987-
interpreters.run_func(self.id, code, shared=dict(w=w))
992+
interpreters.bind(self.id, dict(w=w))
993+
interpreters.run_func(self.id, code)
988994

989995
with open(r, encoding="utf-8") as outfile:
990996
out = outfile.read()

Lib/test/test_interpreters.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,9 @@ def clean_up_interpreters():
4242
def _run_output(interp, request, init=None):
4343
script, rpipe = _captured_script(request)
4444
with rpipe:
45-
interp.run(script, init=init)
45+
if init:
46+
interp.bind(init)
47+
interp.run(script)
4648
return rpipe.read()
4749

4850

@@ -467,6 +469,63 @@ def task():
467469
self.assertEqual(os.read(r_interp, 1), FINISHED)
468470

469471

472+
class TestInterpreterBind(TestBase):
473+
474+
def test_empty(self):
475+
interp = interpreters.create()
476+
with self.assertRaises(ValueError):
477+
interp.bind()
478+
479+
def test_dict(self):
480+
values = {'spam': 42, 'eggs': 'ham'}
481+
interp = interpreters.create()
482+
interp.bind(values)
483+
out = _run_output(interp, dedent("""
484+
print(spam, eggs)
485+
"""))
486+
self.assertEqual(out.strip(), '42 ham')
487+
488+
def test_tuple(self):
489+
values = {'spam': 42, 'eggs': 'ham'}
490+
values = tuple(values.items())
491+
interp = interpreters.create()
492+
interp.bind(values)
493+
out = _run_output(interp, dedent("""
494+
print(spam, eggs)
495+
"""))
496+
self.assertEqual(out.strip(), '42 ham')
497+
498+
def test_kwargs(self):
499+
values = {'spam': 42, 'eggs': 'ham'}
500+
interp = interpreters.create()
501+
interp.bind(**values)
502+
out = _run_output(interp, dedent("""
503+
print(spam, eggs)
504+
"""))
505+
self.assertEqual(out.strip(), '42 ham')
506+
507+
def test_dict_and_kwargs(self):
508+
values = {'spam': 42, 'eggs': 'ham'}
509+
interp = interpreters.create()
510+
interp.bind(values, foo='bar')
511+
out = _run_output(interp, dedent("""
512+
print(spam, eggs, foo)
513+
"""))
514+
self.assertEqual(out.strip(), '42 ham bar')
515+
516+
def test_not_shareable(self):
517+
interp = interpreters.create()
518+
# XXX TypeError?
519+
with self.assertRaises(ValueError):
520+
interp.bind(spam={'spam': 'eggs', 'foo': 'bar'})
521+
522+
# Make sure neither was actually bound.
523+
with self.assertRaises(RuntimeError):
524+
interp.run('print(foo)')
525+
with self.assertRaises(RuntimeError):
526+
interp.run('print(spam)')
527+
528+
470529
class TestInterpreterRun(TestBase):
471530

472531
def test_success(self):

Modules/_xxsubinterpretersmodule.c

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,58 @@ PyDoc_STRVAR(get_main_doc,
402402
\n\
403403
Return the ID of main interpreter.");
404404

405+
static PyObject *
406+
interp_bind(PyObject *self, PyObject *args)
407+
{
408+
PyObject *id, *updates;
409+
if (!PyArg_ParseTuple(args, "OO:" MODULE_NAME ".bind", &id, &updates)) {
410+
return NULL;
411+
}
412+
413+
// Look up the interpreter.
414+
PyInterpreterState *interp = PyInterpreterID_LookUp(id);
415+
if (interp == NULL) {
416+
return NULL;
417+
}
418+
419+
// Check the updates.
420+
if (updates != Py_None) {
421+
Py_ssize_t size = PyObject_Size(updates);
422+
if (size < 0) {
423+
return NULL;
424+
}
425+
if (size == 0) {
426+
PyErr_SetString(PyExc_ValueError,
427+
"arg 2 must be a non-empty mapping");
428+
return NULL;
429+
}
430+
}
431+
432+
_PyXI_session session = {0};
433+
434+
// Prep and switch interpreters, including apply the updates.
435+
if (_PyXI_Enter(&session, interp, updates) < 0) {
436+
if (!PyErr_Occurred()) {
437+
_PyXI_ApplyCapturedException(&session, NULL);
438+
assert(PyErr_Occurred());
439+
}
440+
else {
441+
assert(!_PyXI_HasCapturedException(&session));
442+
}
443+
return NULL;
444+
}
445+
446+
// Clean up and switch back.
447+
_PyXI_Exit(&session);
448+
449+
Py_RETURN_NONE;
450+
}
451+
452+
PyDoc_STRVAR(bind_doc,
453+
"bind(id, ns)\n\
454+
\n\
455+
Bind the given attributes in the interpreter's __main__ module.");
456+
405457
static PyUnicodeObject *
406458
convert_script_arg(PyObject *arg, const char *fname, const char *displayname,
407459
const char *expected)
@@ -698,6 +750,8 @@ static PyMethodDef module_functions[] = {
698750
{"run_func", _PyCFunction_CAST(interp_run_func),
699751
METH_VARARGS | METH_KEYWORDS, run_func_doc},
700752

753+
{"bind", _PyCFunction_CAST(interp_bind),
754+
METH_VARARGS, bind_doc},
701755
{"is_shareable", _PyCFunction_CAST(object_is_shareable),
702756
METH_VARARGS | METH_KEYWORDS, is_shareable_doc},
703757

0 commit comments

Comments
 (0)