Skip to content

Commit 87a6331

Browse files
Copilotmudler
andauthored
stablediffusion-ggml: replace hand-maintained enum string arrays with upstream API calls (#9192)
* Initial plan * Remove hand-maintained enum string arrays in gosd.cpp, use upstream API functions Agent-Logs-Url: https://github.com/mudler/LocalAI/sessions/561fb489-89ed-4588-8f1e-7b967d91ba37 Co-authored-by: mudler <2420543+mudler@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: mudler <2420543+mudler@users.noreply.github.com>
1 parent efdcbbe commit 87a6331

2 files changed

Lines changed: 28 additions & 160 deletions

File tree

backend/go/stablediffusion-ggml/Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
88

99
# stablediffusion.cpp (ggml)
1010
STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
11-
STABLEDIFFUSION_GGML_VERSION?=f16a110f8776398ef23a2a6b7b57522c2471637a
11+
STABLEDIFFUSION_GGML_VERSION?=1d6cb0f8c33ddadf1bff8aff40ec2e5b1ccb4940
1212

1313
CMAKE_ARGS+=-DGGML_MAX_NAME=128
1414

backend/go/stablediffusion-ggml/gosd.cpp

Lines changed: 27 additions & 159 deletions
Original file line numberDiff line numberDiff line change
@@ -27,107 +27,7 @@
2727
#include <stdlib.h>
2828
#include <regex>
2929

30-
// Names of the sampler method, same order as enum sample_method in stable-diffusion.h
31-
const char* sample_method_str[] = {
32-
"euler",
33-
"euler_a",
34-
"heun",
35-
"dpm2",
36-
"dpm++2s_a",
37-
"dpm++2m",
38-
"dpm++2mv2",
39-
"ipndm",
40-
"ipndm_v",
41-
"lcm",
42-
"ddim_trailing",
43-
"tcd",
44-
"res_multistep",
45-
"res_2s",
46-
};
47-
48-
static_assert(std::size(sample_method_str) == SAMPLE_METHOD_COUNT, "sample method mismatch");
49-
50-
// Names of the sigma schedule overrides, same order as sample_schedule in stable-diffusion.h
51-
const char* schedulers[] = {
52-
"discrete",
53-
"karras",
54-
"exponential",
55-
"ays",
56-
"gits",
57-
"sgm_uniform",
58-
"simple",
59-
"smoothstep",
60-
"kl_optimal",
61-
"lcm",
62-
"bong_tangent",
63-
};
64-
65-
static_assert(std::size(schedulers) == SCHEDULER_COUNT, "schedulers mismatch");
66-
67-
// New enum string arrays
68-
const char* rng_type_str[] = {
69-
"std_default",
70-
"cuda",
71-
"cpu",
72-
};
73-
static_assert(std::size(rng_type_str) == RNG_TYPE_COUNT, "rng type mismatch");
74-
75-
const char* prediction_str[] = {
76-
"epsilon",
77-
"v",
78-
"edm_v",
79-
"flow",
80-
"flux_flow",
81-
"flux2_flow",
82-
};
83-
static_assert(std::size(prediction_str) == PREDICTION_COUNT, "prediction mismatch");
84-
85-
const char* lora_apply_mode_str[] = {
86-
"auto",
87-
"immediately",
88-
"at_runtime",
89-
};
90-
static_assert(std::size(lora_apply_mode_str) == LORA_APPLY_MODE_COUNT, "lora apply mode mismatch");
91-
92-
constexpr const char* sd_type_str[] = {
93-
"f32", // 0
94-
"f16", // 1
95-
"q4_0", // 2
96-
"q4_1", // 3
97-
nullptr, // 4
98-
nullptr, // 5
99-
"q5_0", // 6
100-
"q5_1", // 7
101-
"q8_0", // 8
102-
"q8_1", // 9
103-
"q2_k", // 10
104-
"q3_k", // 11
105-
"q4_k", // 12
106-
"q5_k", // 13
107-
"q6_k", // 14
108-
"q8_k", // 15
109-
"iq2_xxs", // 16
110-
"iq2_xs", // 17
111-
"iq3_xxs", // 18
112-
"iq1_s", // 19
113-
"iq4_nl", // 20
114-
"iq3_s", // 21
115-
"iq2_s", // 22
116-
"iq4_xs", // 23
117-
"i8", // 24
118-
"i16", // 25
119-
"i32", // 26
120-
"i64", // 27
121-
"f64", // 28
122-
"iq1_m", // 29
123-
"bf16", // 30
124-
nullptr, nullptr, nullptr, // 31-33
125-
"tq1_0", // 34
126-
"tq2_0", // 35
127-
nullptr, nullptr, nullptr, // 36-38
128-
"mxfp4" // 39
129-
};
130-
static_assert(std::size(sd_type_str) == SD_TYPE_COUNT, "sd type mismatch");
30+
13131

13232
sd_ctx_params_t ctx_params;
13333
sd_ctx_t* sd_c;
@@ -596,75 +496,45 @@ int load_model(const char *model, char *model_path, char* options[], int threads
596496
if (!strcmp(optname, "flow_shift")) flow_shift = atof(optval);
597497

598498
if (!strcmp(optname, "rng_type")) {
599-
int found = -1;
600-
for (int m = 0; m < RNG_TYPE_COUNT; m++) {
601-
if (!strcmp(optval, rng_type_str[m])) {
602-
found = m;
603-
break;
604-
}
605-
}
606-
if (found != -1) {
607-
rng_type = (rng_type_t)found;
499+
rng_type_t parsed = str_to_rng_type(optval);
500+
if (parsed != RNG_TYPE_COUNT) {
501+
rng_type = parsed;
608502
fprintf(stderr, "Found rng_type: %s\n", optval);
609503
} else {
610504
fprintf(stderr, "Invalid rng_type: %s, using default\n", optval);
611505
}
612506
}
613507
if (!strcmp(optname, "sampler_rng_type")) {
614-
int found = -1;
615-
for (int m = 0; m < RNG_TYPE_COUNT; m++) {
616-
if (!strcmp(optval, rng_type_str[m])) {
617-
found = m;
618-
break;
619-
}
620-
}
621-
if (found != -1) {
622-
sampler_rng_type = (rng_type_t)found;
508+
rng_type_t parsed = str_to_rng_type(optval);
509+
if (parsed != RNG_TYPE_COUNT) {
510+
sampler_rng_type = parsed;
623511
fprintf(stderr, "Found sampler_rng_type: %s\n", optval);
624512
} else {
625513
fprintf(stderr, "Invalid sampler_rng_type: %s, using default\n", optval);
626514
}
627515
}
628516
if (!strcmp(optname, "prediction")) {
629-
int found = -1;
630-
for (int m = 0; m < PREDICTION_COUNT; m++) {
631-
if (!strcmp(optval, prediction_str[m])) {
632-
found = m;
633-
break;
634-
}
635-
}
636-
if (found != -1) {
637-
prediction = (prediction_t)found;
517+
prediction_t parsed = str_to_prediction(optval);
518+
if (parsed != PREDICTION_COUNT) {
519+
prediction = parsed;
638520
fprintf(stderr, "Found prediction: %s\n", optval);
639521
} else {
640522
fprintf(stderr, "Invalid prediction: %s, using default\n", optval);
641523
}
642524
}
643525
if (!strcmp(optname, "lora_apply_mode")) {
644-
int found = -1;
645-
for (int m = 0; m < LORA_APPLY_MODE_COUNT; m++) {
646-
if (!strcmp(optval, lora_apply_mode_str[m])) {
647-
found = m;
648-
break;
649-
}
650-
}
651-
if (found != -1) {
652-
lora_apply_mode = (lora_apply_mode_t)found;
526+
lora_apply_mode_t parsed = str_to_lora_apply_mode(optval);
527+
if (parsed != LORA_APPLY_MODE_COUNT) {
528+
lora_apply_mode = parsed;
653529
fprintf(stderr, "Found lora_apply_mode: %s\n", optval);
654530
} else {
655531
fprintf(stderr, "Invalid lora_apply_mode: %s, using default\n", optval);
656532
}
657533
}
658534
if (!strcmp(optname, "wtype")) {
659-
int found = -1;
660-
for (int m = 0; m < SD_TYPE_COUNT; m++) {
661-
if (sd_type_str[m] && !strcmp(optval, sd_type_str[m])) {
662-
found = m;
663-
break;
664-
}
665-
}
666-
if (found != -1) {
667-
wtype = (sd_type_t)found;
535+
sd_type_t parsed = str_to_sd_type(optval);
536+
if (parsed != SD_TYPE_COUNT) {
537+
wtype = parsed;
668538
fprintf(stderr, "Found wtype: %s\n", optval);
669539
} else {
670540
fprintf(stderr, "Invalid wtype: %s, using default\n", optval);
@@ -735,27 +605,25 @@ int load_model(const char *model, char *model_path, char* options[], int threads
735605
fprintf (stderr, "Created context: OK\n");
736606

737607
int sample_method_found = -1;
738-
for (int m = 0; m < SAMPLE_METHOD_COUNT; m++) {
739-
if (!strcmp(sampler, sample_method_str[m])) {
740-
sample_method_found = m;
741-
fprintf(stderr, "Found sampler: %s\n", sampler);
742-
}
608+
sample_method_t sm = str_to_sample_method(sampler);
609+
if (sm != SAMPLE_METHOD_COUNT) {
610+
sample_method_found = (int)sm;
611+
fprintf(stderr, "Found sampler: %s\n", sampler);
743612
}
744613
if (sample_method_found == -1) {
745614
sample_method_found = sd_get_default_sample_method(sd_ctx);
746-
fprintf(stderr, "Invalid sample method, using default: %s\n", sample_method_str[sample_method_found]);
615+
fprintf(stderr, "Invalid sample method, using default: %s\n", sd_sample_method_name((sample_method_t)sample_method_found));
747616
}
748617
sample_method = (sample_method_t)sample_method_found;
749618

750-
for (int d = 0; d < SCHEDULER_COUNT; d++) {
751-
if (!strcmp(scheduler_str, schedulers[d])) {
752-
scheduler = (scheduler_t)d;
753-
fprintf (stderr, "Found scheduler: %s\n", scheduler_str);
754-
}
619+
scheduler_t sched = str_to_scheduler(scheduler_str);
620+
if (sched != SCHEDULER_COUNT) {
621+
scheduler = sched;
622+
fprintf(stderr, "Found scheduler: %s\n", scheduler_str);
755623
}
756624
if (scheduler == SCHEDULER_COUNT) {
757-
scheduler = sd_get_default_scheduler(sd_ctx, sample_method);
758-
fprintf(stderr, "Invalid scheduler, using default: %s\n", schedulers[scheduler]);
625+
scheduler = sd_get_default_scheduler(sd_ctx, sample_method);
626+
fprintf(stderr, "Invalid scheduler, using default: %s\n", sd_scheduler_name(scheduler));
759627
}
760628

761629
sd_c = sd_ctx;

0 commit comments

Comments
 (0)