Skip to content

Commit e2ca0c1

Browse files
committed
staticdata: Memoize type_in_worklist query (#57917)
When pre-compiling `stdlib/`, this cache has a 91% hit rate, so the memoization seems fairly profitable. It also dramatically improves some pathological cases, a few of which have been hit in the wild (arguably due to inference bugs).

Without this PR, the following package takes exponentially long to pre-compile:

```julia
function BigType(N)
    (N == 0) && return Nothing
    T = BigType(N-1)
    return Pair{T,T}
end

foo(::Type{T}) where T = T

precompile(foo, (Type{BigType(40)},))
```

For an in-the-wild test case hit by a customer, this reduces pre-compilation time from over an hour to just ~two and a half minutes.

Resolves #53331.
1 parent 269369c commit e2ca0c1

2 files changed

Lines changed: 93 additions & 51 deletions

File tree

src/staticdata.c

Lines changed: 49 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,22 @@ External links:
8989
#include "valgrind.h"
9090
#include "julia_assert.h"
9191

92+
// This structure is used to store hash tables for the memoization
93+
// of queries in staticdata.c (currently only `type_in_worklist`).
94+
typedef struct {
95+
htable_t type_in_worklist;
96+
} jl_query_cache;
97+
98+
static void init_query_cache(jl_query_cache *cache)
99+
{
100+
htable_new(&cache->type_in_worklist, 0);
101+
}
102+
103+
static void destroy_query_cache(jl_query_cache *cache)
104+
{
105+
htable_free(&cache->type_in_worklist);
106+
}
107+
92108
#include "staticdata_utils.c"
93109
#include "precompile_utils.c"
94110

@@ -479,6 +495,7 @@ typedef struct {
479495
jl_array_t *link_ids_gctags;
480496
jl_array_t *link_ids_gvars;
481497
jl_array_t *link_ids_external_fnvars;
498+
jl_query_cache *query_cache;
482499
jl_ptls_t ptls;
483500
htable_t callers_with_edges;
484501
jl_image_t *image;
@@ -621,38 +638,37 @@ static int jl_needs_serialization(jl_serializer_state *s, jl_value_t *v) JL_NOTS
621638
return 1;
622639
}
623640

624-
625-
static int caching_tag(jl_value_t *v) JL_NOTSAFEPOINT
641+
static int caching_tag(jl_value_t *v, jl_query_cache *query_cache) JL_NOTSAFEPOINT
626642
{
627643
if (jl_is_method_instance(v)) {
628644
jl_method_instance_t *mi = (jl_method_instance_t*)v;
629645
jl_value_t *m = mi->def.value;
630646
if (jl_is_method(m) && jl_object_in_image(m))
631-
return 1 + type_in_worklist(mi->specTypes);
647+
return 1 + type_in_worklist(mi->specTypes, query_cache);
632648
}
633649
if (jl_is_datatype(v)) {
634650
jl_datatype_t *dt = (jl_datatype_t*)v;
635651
if (jl_is_tuple_type(dt) ? !dt->isconcretetype : dt->hasfreetypevars)
636652
return 0; // aka !is_cacheable from jltypes.c
637653
if (jl_object_in_image((jl_value_t*)dt->name))
638-
return 1 + type_in_worklist(v);
654+
return 1 + type_in_worklist(v, query_cache);
639655
}
640656
jl_value_t *dtv = jl_typeof(v);
641657
if (jl_is_datatype_singleton((jl_datatype_t*)dtv)) {
642-
return 1 - type_in_worklist(dtv); // these are already recached in the datatype in the image
658+
return 1 - type_in_worklist(dtv, query_cache); // these are already recached in the datatype in the image
643659
}
644660
return 0;
645661
}
646662

647-
static int needs_recaching(jl_value_t *v) JL_NOTSAFEPOINT
663+
static int needs_recaching(jl_value_t *v, jl_query_cache *query_cache) JL_NOTSAFEPOINT
648664
{
649-
return caching_tag(v) == 2;
665+
return caching_tag(v, query_cache) == 2;
650666
}
651667

652-
static int needs_uniquing(jl_value_t *v) JL_NOTSAFEPOINT
668+
static int needs_uniquing(jl_value_t *v, jl_query_cache *query_cache) JL_NOTSAFEPOINT
653669
{
654670
assert(!jl_object_in_image(v));
655-
return caching_tag(v) == 1;
671+
return caching_tag(v, query_cache) == 1;
656672
}
657673

658674
static void record_field_change(jl_value_t **addr, jl_value_t *newval) JL_NOTSAFEPOINT
@@ -738,7 +754,7 @@ static void jl_insert_into_serialization_queue(jl_serializer_state *s, jl_value_
738754
// ensure all type parameters are recached
739755
jl_queue_for_serialization_(s, (jl_value_t*)dt->parameters, 1, 1);
740756
jl_value_t *singleton = dt->instance;
741-
if (singleton && needs_uniquing(singleton)) {
757+
if (singleton && needs_uniquing(singleton, s->query_cache)) {
742758
assert(jl_needs_serialization(s, singleton)); // should be true, since we visited dt
743759
// do not visit dt->instance for our template object as it leads to unwanted cycles here
744760
// (it may get serialized from elsewhere though)
@@ -749,7 +765,7 @@ static void jl_insert_into_serialization_queue(jl_serializer_state *s, jl_value_
749765
if (s->incremental && jl_is_method_instance(v)) {
750766
jl_method_instance_t *mi = (jl_method_instance_t*)v;
751767
jl_value_t *def = mi->def.value;
752-
if (needs_uniquing(v)) {
768+
if (needs_uniquing(v, s->query_cache)) {
753769
// we only need 3 specific fields of this (the rest are not used)
754770
jl_queue_for_serialization(s, mi->def.value);
755771
jl_queue_for_serialization(s, mi->specTypes);
@@ -767,7 +783,7 @@ static void jl_insert_into_serialization_queue(jl_serializer_state *s, jl_value_
767783
record_field_change((jl_value_t**)&mi->cache, NULL);
768784
}
769785
else {
770-
assert(!needs_recaching(v));
786+
assert(!needs_recaching(v, s->query_cache));
771787
}
772788
// n.b. opaque closures cannot be inspected and relied upon like a
773789
// normal method since they can get improperly introduced by generated
@@ -901,9 +917,9 @@ static void jl_queue_for_serialization_(jl_serializer_state *s, jl_value_t *v, i
901917
// Items that require postorder traversal must visit their children prior to insertion into
902918
// the worklist/serialization_order (and also before their first use)
903919
if (s->incremental && !immediate) {
904-
if (jl_is_datatype(t) && needs_uniquing(v))
920+
if (jl_is_datatype(t) && needs_uniquing(v, s->query_cache))
905921
immediate = 1;
906-
if (jl_is_datatype_singleton((jl_datatype_t*)t) && needs_uniquing(v))
922+
if (jl_is_datatype_singleton((jl_datatype_t*)t) && needs_uniquing(v, s->query_cache))
907923
immediate = 1;
908924
}
909925

@@ -1067,7 +1083,7 @@ static uintptr_t _backref_id(jl_serializer_state *s, jl_value_t *v, jl_array_t *
10671083

10681084
static void record_uniquing(jl_serializer_state *s, jl_value_t *fld, uintptr_t offset) JL_NOTSAFEPOINT
10691085
{
1070-
if (s->incremental && jl_needs_serialization(s, fld) && needs_uniquing(fld)) {
1086+
if (s->incremental && jl_needs_serialization(s, fld) && needs_uniquing(fld, s->query_cache)) {
10711087
if (jl_is_datatype(fld) || jl_is_datatype_singleton((jl_datatype_t*)jl_typeof(fld)))
10721088
arraylist_push(&s->uniquing_types, (void*)(uintptr_t)offset);
10731089
else if (jl_is_method_instance(fld))
@@ -1211,7 +1227,7 @@ static void jl_write_values(jl_serializer_state *s) JL_GC_DISABLED
12111227
write_padding(f, LLT_ALIGN(skip_header_pos, 16) - skip_header_pos);
12121228

12131229
// write header
1214-
if (s->incremental && jl_needs_serialization(s, (jl_value_t*)t) && needs_uniquing((jl_value_t*)t))
1230+
if (s->incremental && jl_needs_serialization(s, (jl_value_t*)t) && needs_uniquing((jl_value_t*)t, s->query_cache))
12151231
arraylist_push(&s->uniquing_types, (void*)(uintptr_t)(ios_pos(f)|1));
12161232
if (f == s->const_data)
12171233
write_uint(s->const_data, ((uintptr_t)t->smalltag << 4) | GC_OLD_MARKED);
@@ -1222,7 +1238,7 @@ static void jl_write_values(jl_serializer_state *s) JL_GC_DISABLED
12221238
layout_table.items[item] = (void*)(reloc_offset | (f == s->const_data)); // store the inverse mapping of `serialization_order` (`id` => object-as-streampos)
12231239

12241240
if (s->incremental) {
1225-
if (needs_uniquing(v)) {
1241+
if (needs_uniquing(v, s->query_cache)) {
12261242
if (jl_is_method_instance(v)) {
12271243
assert(f == s->s);
12281244
jl_method_instance_t *mi = (jl_method_instance_t*)v;
@@ -1243,7 +1259,7 @@ static void jl_write_values(jl_serializer_state *s) JL_GC_DISABLED
12431259
assert(jl_is_datatype_singleton(t) && "unreachable");
12441260
}
12451261
}
1246-
else if (needs_recaching(v)) {
1262+
else if (needs_recaching(v, s->query_cache)) {
12471263
arraylist_push(jl_is_datatype(v) ? &s->fixup_types : &s->fixup_objs, (void*)reloc_offset);
12481264
}
12491265
else if (jl_typetagis(v, jl_binding_type)) {
@@ -1606,7 +1622,7 @@ static void jl_write_values(jl_serializer_state *s) JL_GC_DISABLED
16061622
}
16071623
}
16081624
void *superidx = ptrhash_get(&serialization_order, dt->super);
1609-
if (s->incremental && superidx != HT_NOTFOUND && (char*)superidx - 1 - (char*)HT_NOTFOUND > item && needs_uniquing((jl_value_t*)dt->super))
1625+
if (s->incremental && superidx != HT_NOTFOUND && (char*)superidx - 1 - (char*)HT_NOTFOUND > item && needs_uniquing((jl_value_t*)dt->super, s->query_cache))
16101626
arraylist_push(&s->uniquing_super, dt->super);
16111627
}
16121628
else if (jl_is_typename(v)) {
@@ -2351,7 +2367,8 @@ JL_DLLEXPORT jl_value_t *jl_as_global_root(jl_value_t *val JL_MAYBE_UNROOTED)
23512367

23522368
static void jl_prepare_serialization_data(jl_array_t *mod_array, jl_array_t *newly_inferred, uint64_t worklist_key,
23532369
/* outputs */ jl_array_t **extext_methods, jl_array_t **new_specializations,
2354-
jl_array_t **method_roots_list, jl_array_t **ext_targets, jl_array_t **edges)
2370+
jl_array_t **method_roots_list, jl_array_t **ext_targets, jl_array_t **edges,
2371+
jl_query_cache *query_cache)
23552372
{
23562373
// extext_methods: [method1, ...], worklist-owned "extending external" methods added to functions owned by modules outside the worklist
23572374
// ext_targets: [invokesig1, callee1, matches1, ...] non-worklist callees of worklist-owned methods
@@ -2362,7 +2379,7 @@ static void jl_prepare_serialization_data(jl_array_t *mod_array, jl_array_t *new
23622379
assert(edges_map == NULL);
23632380

23642381
// Save the inferred code from newly inferred, external methods
2365-
*new_specializations = queue_external_cis(newly_inferred);
2382+
*new_specializations = queue_external_cis(newly_inferred, query_cache);
23662383

23672384
// Collect method extensions and edges data
23682385
JL_GC_PUSH1(&edges_map);
@@ -2401,7 +2418,8 @@ static void jl_prepare_serialization_data(jl_array_t *mod_array, jl_array_t *new
24012418
static void jl_save_system_image_to_stream(ios_t *f, jl_array_t *mod_array,
24022419
jl_array_t *worklist, jl_array_t *extext_methods,
24032420
jl_array_t *new_specializations, jl_array_t *method_roots_list,
2404-
jl_array_t *ext_targets, jl_array_t *edges) JL_GC_DISABLED
2421+
jl_array_t *ext_targets, jl_array_t *edges,
2422+
jl_query_cache *query_cache) JL_GC_DISABLED
24052423
{
24062424
htable_new(&field_replace, 0);
24072425
// strip metadata and IR when requested
@@ -2428,6 +2446,7 @@ static void jl_save_system_image_to_stream(ios_t *f, jl_array_t *mod_array,
24282446
ios_mem(&gvar_record, 0);
24292447
ios_mem(&fptr_record, 0);
24302448
jl_serializer_state s = {0};
2449+
s.query_cache = query_cache;
24312450
s.incremental = !(worklist == NULL);
24322451
s.s = &sysimg;
24332452
s.const_data = &const_data;
@@ -2743,12 +2762,15 @@ JL_DLLEXPORT void jl_create_system_image(void **_native_data, jl_array_t *workli
27432762
int64_t datastartpos = 0;
27442763
JL_GC_PUSH6(&mod_array, &extext_methods, &new_specializations, &method_roots_list, &ext_targets, &edges);
27452764

2765+
jl_query_cache query_cache;
2766+
init_query_cache(&query_cache);
2767+
27462768
if (worklist) {
27472769
mod_array = jl_get_loaded_modules(); // __toplevel__ modules loaded in this session (from Base.loaded_modules_array)
27482770
// Generate `_native_data`
27492771
if (_native_data != NULL) {
27502772
jl_prepare_serialization_data(mod_array, newly_inferred, jl_worklist_key(worklist),
2751-
&extext_methods, &new_specializations, NULL, NULL, NULL);
2773+
&extext_methods, &new_specializations, NULL, NULL, NULL, &query_cache);
27522774
jl_precompile_toplevel_module = (jl_module_t*)jl_array_ptr_ref(worklist, jl_array_len(worklist)-1);
27532775
*_native_data = jl_precompile_worklist(worklist, extext_methods, new_specializations);
27542776
jl_precompile_toplevel_module = NULL;
@@ -2777,7 +2799,7 @@ JL_DLLEXPORT void jl_create_system_image(void **_native_data, jl_array_t *workli
27772799
if (worklist) {
27782800
htable_new(&relocatable_ext_cis, 0);
27792801
jl_prepare_serialization_data(mod_array, newly_inferred, jl_worklist_key(worklist),
2780-
&extext_methods, &new_specializations, &method_roots_list, &ext_targets, &edges);
2802+
&extext_methods, &new_specializations, &method_roots_list, &ext_targets, &edges, &query_cache);
27812803
if (!emit_split) {
27822804
write_int32(f, 0); // No clone_targets
27832805
write_padding(f, LLT_ALIGN(ios_pos(f), JL_CACHE_BYTE_ALIGNMENT) - ios_pos(f));
@@ -2789,7 +2811,7 @@ JL_DLLEXPORT void jl_create_system_image(void **_native_data, jl_array_t *workli
27892811
}
27902812
if (_native_data != NULL)
27912813
native_functions = *_native_data;
2792-
jl_save_system_image_to_stream(ff, mod_array, worklist, extext_methods, new_specializations, method_roots_list, ext_targets, edges);
2814+
jl_save_system_image_to_stream(ff, mod_array, worklist, extext_methods, new_specializations, method_roots_list, ext_targets, edges, &query_cache);
27932815
if (_native_data != NULL)
27942816
native_functions = NULL;
27952817
if (worklist)
@@ -2820,6 +2842,8 @@ JL_DLLEXPORT void jl_create_system_image(void **_native_data, jl_array_t *workli
28202842
}
28212843
}
28222844

2845+
destroy_query_cache(&query_cache);
2846+
28232847
JL_GC_POP();
28242848
*s = f;
28252849
if (emit_split)

src/staticdata_utils.c

Lines changed: 44 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -103,62 +103,80 @@ JL_DLLEXPORT void jl_push_newly_inferred(jl_value_t* ci)
103103
JL_UNLOCK(&newly_inferred_mutex);
104104
}
105105

106-
107106
// compute whether a type references something internal to worklist
108107
// and thus could not have existed before deserialize
109108
// and thus does not need delayed unique-ing
110-
static int type_in_worklist(jl_value_t *v) JL_NOTSAFEPOINT
109+
static int type_in_worklist(jl_value_t *v, jl_query_cache *cache) JL_NOTSAFEPOINT
111110
{
112111
if (jl_object_in_image(v))
113112
return 0; // fast-path for rejection
113+
114+
void *cached = HT_NOTFOUND;
115+
if (cache != NULL)
116+
cached = ptrhash_get(&cache->type_in_worklist, v);
117+
118+
// fast-path for memoized results
119+
if (cached != HT_NOTFOUND)
120+
return cached == v;
121+
122+
int result = 0;
114123
if (jl_is_uniontype(v)) {
115124
jl_uniontype_t *u = (jl_uniontype_t*)v;
116-
return type_in_worklist(u->a) ||
117-
type_in_worklist(u->b);
125+
result = type_in_worklist(u->a, cache) ||
126+
type_in_worklist(u->b, cache);
118127
}
119128
else if (jl_is_unionall(v)) {
120129
jl_unionall_t *ua = (jl_unionall_t*)v;
121-
return type_in_worklist((jl_value_t*)ua->var) ||
122-
type_in_worklist(ua->body);
130+
result = type_in_worklist((jl_value_t*)ua->var, cache) ||
131+
type_in_worklist(ua->body, cache);
123132
}
124133
else if (jl_is_typevar(v)) {
125134
jl_tvar_t *tv = (jl_tvar_t*)v;
126-
return type_in_worklist(tv->lb) ||
127-
type_in_worklist(tv->ub);
135+
result = type_in_worklist(tv->lb, cache) ||
136+
type_in_worklist(tv->ub, cache);
128137
}
129138
else if (jl_is_vararg(v)) {
130139
jl_vararg_t *tv = (jl_vararg_t*)v;
131-
if (tv->T && type_in_worklist(tv->T))
132-
return 1;
133-
if (tv->N && type_in_worklist(tv->N))
134-
return 1;
140+
result = ((tv->T && type_in_worklist(tv->T, cache)) ||
141+
(tv->N && type_in_worklist(tv->N, cache)));
135142
}
136143
else if (jl_is_datatype(v)) {
137144
jl_datatype_t *dt = (jl_datatype_t*)v;
138-
if (!jl_object_in_image((jl_value_t*)dt->name))
139-
return 1;
140-
jl_svec_t *tt = dt->parameters;
141-
size_t i, l = jl_svec_len(tt);
142-
for (i = 0; i < l; i++)
143-
if (type_in_worklist(jl_tparam(dt, i)))
144-
return 1;
145+
if (!jl_object_in_image((jl_value_t*)dt->name)) {
146+
result = 1;
147+
}
148+
else {
149+
jl_svec_t *tt = dt->parameters;
150+
size_t i, l = jl_svec_len(tt);
151+
for (i = 0; i < l; i++) {
152+
if (type_in_worklist(jl_tparam(dt, i), cache)) {
153+
result = 1;
154+
break;
155+
}
156+
}
157+
}
145158
}
146159
else {
147-
return type_in_worklist(jl_typeof(v));
160+
return type_in_worklist(jl_typeof(v), cache);
148161
}
149-
return 0;
162+
163+
// Memoize result
164+
if (cache != NULL)
165+
ptrhash_put(&cache->type_in_worklist, (void*)v, result ? (void*)v : NULL);
166+
167+
return result;
150168
}
151169

152170
// When we infer external method instances, ensure they link back to the
153171
// package. Otherwise they might be, e.g., for external macros.
154172
// Implements Tarjan's SCC (strongly connected components) algorithm, simplified to remove the count variable
155-
static int has_backedge_to_worklist(jl_method_instance_t *mi, htable_t *visited, arraylist_t *stack)
173+
static int has_backedge_to_worklist(jl_method_instance_t *mi, htable_t *visited, arraylist_t *stack, jl_query_cache *query_cache)
156174
{
157175
jl_module_t *mod = mi->def.module;
158176
if (jl_is_method(mod))
159177
mod = ((jl_method_t*)mod)->module;
160178
assert(jl_is_module(mod));
161-
if (mi->precompiled || !jl_object_in_image((jl_value_t*)mod) || type_in_worklist(mi->specTypes)) {
179+
if (mi->precompiled || !jl_object_in_image((jl_value_t*)mod) || type_in_worklist(mi->specTypes, query_cache)) {
162180
return 1;
163181
}
164182
if (!mi->backedges) {
@@ -181,7 +199,7 @@ static int has_backedge_to_worklist(jl_method_instance_t *mi, htable_t *visited,
181199
while (i < n) {
182200
jl_method_instance_t *be;
183201
i = get_next_edge(mi->backedges, i, NULL, &be);
184-
int child_found = has_backedge_to_worklist(be, visited, stack);
202+
int child_found = has_backedge_to_worklist(be, visited, stack, query_cache);
185203
if (child_found == 1 || child_found == 2) {
186204
// found what we were looking for, so terminate early
187205
found = 1;
@@ -224,7 +242,7 @@ static int is_relocatable_ci(htable_t *relocatable_ext_cis, jl_code_instance_t *
224242
// from the worklist or explicitly added by a `precompile` statement, and
225243
// (4) are the most recently computed result for that method.
226244
// These will be preserved in the image.
227-
static jl_array_t *queue_external_cis(jl_array_t *list)
245+
static jl_array_t *queue_external_cis(jl_array_t *list, jl_query_cache *query_cache)
228246
{
229247
if (list == NULL)
230248
return NULL;
@@ -245,7 +263,7 @@ static jl_array_t *queue_external_cis(jl_array_t *list)
245263
jl_method_instance_t *mi = ci->def;
246264
jl_method_t *m = mi->def.method;
247265
if (ci->inferred && jl_is_method(m) && jl_object_in_image((jl_value_t*)m->module)) {
248-
int found = has_backedge_to_worklist(mi, &visited, &stack);
266+
int found = has_backedge_to_worklist(mi, &visited, &stack, query_cache);
249267
assert(found == 0 || found == 1 || found == 2);
250268
assert(stack.len == 0);
251269
if (found == 1 && ci->max_world == ~(size_t)0) {

0 commit comments

Comments
 (0)