|
27 | 27 | #include <stdlib.h> |
28 | 28 | #include <regex> |
29 | 29 |
|
30 | | -// Names of the sampler method, same order as enum sample_method in stable-diffusion.h |
31 | | -const char* sample_method_str[] = { |
32 | | - "euler", |
33 | | - "euler_a", |
34 | | - "heun", |
35 | | - "dpm2", |
36 | | - "dpm++2s_a", |
37 | | - "dpm++2m", |
38 | | - "dpm++2mv2", |
39 | | - "ipndm", |
40 | | - "ipndm_v", |
41 | | - "lcm", |
42 | | - "ddim_trailing", |
43 | | - "tcd", |
44 | | - "res_multistep", |
45 | | - "res_2s", |
46 | | -}; |
47 | | - |
48 | | -static_assert(std::size(sample_method_str) == SAMPLE_METHOD_COUNT, "sample method mismatch"); |
49 | | - |
50 | | -// Names of the sigma schedule overrides, same order as sample_schedule in stable-diffusion.h |
51 | | -const char* schedulers[] = { |
52 | | - "discrete", |
53 | | - "karras", |
54 | | - "exponential", |
55 | | - "ays", |
56 | | - "gits", |
57 | | - "sgm_uniform", |
58 | | - "simple", |
59 | | - "smoothstep", |
60 | | - "kl_optimal", |
61 | | - "lcm", |
62 | | - "bong_tangent", |
63 | | -}; |
64 | | - |
65 | | -static_assert(std::size(schedulers) == SCHEDULER_COUNT, "schedulers mismatch"); |
66 | | - |
67 | | -// New enum string arrays |
68 | | -const char* rng_type_str[] = { |
69 | | - "std_default", |
70 | | - "cuda", |
71 | | - "cpu", |
72 | | -}; |
73 | | -static_assert(std::size(rng_type_str) == RNG_TYPE_COUNT, "rng type mismatch"); |
74 | | - |
75 | | -const char* prediction_str[] = { |
76 | | - "epsilon", |
77 | | - "v", |
78 | | - "edm_v", |
79 | | - "flow", |
80 | | - "flux_flow", |
81 | | - "flux2_flow", |
82 | | -}; |
83 | | -static_assert(std::size(prediction_str) == PREDICTION_COUNT, "prediction mismatch"); |
84 | | - |
85 | | -const char* lora_apply_mode_str[] = { |
86 | | - "auto", |
87 | | - "immediately", |
88 | | - "at_runtime", |
89 | | -}; |
90 | | -static_assert(std::size(lora_apply_mode_str) == LORA_APPLY_MODE_COUNT, "lora apply mode mismatch"); |
91 | | - |
92 | | -constexpr const char* sd_type_str[] = { |
93 | | - "f32", // 0 |
94 | | - "f16", // 1 |
95 | | - "q4_0", // 2 |
96 | | - "q4_1", // 3 |
97 | | - nullptr, // 4 |
98 | | - nullptr, // 5 |
99 | | - "q5_0", // 6 |
100 | | - "q5_1", // 7 |
101 | | - "q8_0", // 8 |
102 | | - "q8_1", // 9 |
103 | | - "q2_k", // 10 |
104 | | - "q3_k", // 11 |
105 | | - "q4_k", // 12 |
106 | | - "q5_k", // 13 |
107 | | - "q6_k", // 14 |
108 | | - "q8_k", // 15 |
109 | | - "iq2_xxs", // 16 |
110 | | - "iq2_xs", // 17 |
111 | | - "iq3_xxs", // 18 |
112 | | - "iq1_s", // 19 |
113 | | - "iq4_nl", // 20 |
114 | | - "iq3_s", // 21 |
115 | | - "iq2_s", // 22 |
116 | | - "iq4_xs", // 23 |
117 | | - "i8", // 24 |
118 | | - "i16", // 25 |
119 | | - "i32", // 26 |
120 | | - "i64", // 27 |
121 | | - "f64", // 28 |
122 | | - "iq1_m", // 29 |
123 | | - "bf16", // 30 |
124 | | - nullptr, nullptr, nullptr, // 31-33 |
125 | | - "tq1_0", // 34 |
126 | | - "tq2_0", // 35 |
127 | | - nullptr, nullptr, nullptr, // 36-38 |
128 | | - "mxfp4" // 39 |
129 | | -}; |
130 | | -static_assert(std::size(sd_type_str) == SD_TYPE_COUNT, "sd type mismatch"); |
| 30 | + |
131 | 31 |
|
132 | 32 | sd_ctx_params_t ctx_params; |
133 | 33 | sd_ctx_t* sd_c; |
@@ -596,75 +496,45 @@ int load_model(const char *model, char *model_path, char* options[], int threads |
596 | 496 | if (!strcmp(optname, "flow_shift")) flow_shift = atof(optval); |
597 | 497 |
|
598 | 498 | if (!strcmp(optname, "rng_type")) { |
599 | | - int found = -1; |
600 | | - for (int m = 0; m < RNG_TYPE_COUNT; m++) { |
601 | | - if (!strcmp(optval, rng_type_str[m])) { |
602 | | - found = m; |
603 | | - break; |
604 | | - } |
605 | | - } |
606 | | - if (found != -1) { |
607 | | - rng_type = (rng_type_t)found; |
| 499 | + rng_type_t parsed = str_to_rng_type(optval); |
| 500 | + if (parsed != RNG_TYPE_COUNT) { |
| 501 | + rng_type = parsed; |
608 | 502 | fprintf(stderr, "Found rng_type: %s\n", optval); |
609 | 503 | } else { |
610 | 504 | fprintf(stderr, "Invalid rng_type: %s, using default\n", optval); |
611 | 505 | } |
612 | 506 | } |
613 | 507 | if (!strcmp(optname, "sampler_rng_type")) { |
614 | | - int found = -1; |
615 | | - for (int m = 0; m < RNG_TYPE_COUNT; m++) { |
616 | | - if (!strcmp(optval, rng_type_str[m])) { |
617 | | - found = m; |
618 | | - break; |
619 | | - } |
620 | | - } |
621 | | - if (found != -1) { |
622 | | - sampler_rng_type = (rng_type_t)found; |
| 508 | + rng_type_t parsed = str_to_rng_type(optval); |
| 509 | + if (parsed != RNG_TYPE_COUNT) { |
| 510 | + sampler_rng_type = parsed; |
623 | 511 | fprintf(stderr, "Found sampler_rng_type: %s\n", optval); |
624 | 512 | } else { |
625 | 513 | fprintf(stderr, "Invalid sampler_rng_type: %s, using default\n", optval); |
626 | 514 | } |
627 | 515 | } |
628 | 516 | if (!strcmp(optname, "prediction")) { |
629 | | - int found = -1; |
630 | | - for (int m = 0; m < PREDICTION_COUNT; m++) { |
631 | | - if (!strcmp(optval, prediction_str[m])) { |
632 | | - found = m; |
633 | | - break; |
634 | | - } |
635 | | - } |
636 | | - if (found != -1) { |
637 | | - prediction = (prediction_t)found; |
| 517 | + prediction_t parsed = str_to_prediction(optval); |
| 518 | + if (parsed != PREDICTION_COUNT) { |
| 519 | + prediction = parsed; |
638 | 520 | fprintf(stderr, "Found prediction: %s\n", optval); |
639 | 521 | } else { |
640 | 522 | fprintf(stderr, "Invalid prediction: %s, using default\n", optval); |
641 | 523 | } |
642 | 524 | } |
643 | 525 | if (!strcmp(optname, "lora_apply_mode")) { |
644 | | - int found = -1; |
645 | | - for (int m = 0; m < LORA_APPLY_MODE_COUNT; m++) { |
646 | | - if (!strcmp(optval, lora_apply_mode_str[m])) { |
647 | | - found = m; |
648 | | - break; |
649 | | - } |
650 | | - } |
651 | | - if (found != -1) { |
652 | | - lora_apply_mode = (lora_apply_mode_t)found; |
| 526 | + lora_apply_mode_t parsed = str_to_lora_apply_mode(optval); |
| 527 | + if (parsed != LORA_APPLY_MODE_COUNT) { |
| 528 | + lora_apply_mode = parsed; |
653 | 529 | fprintf(stderr, "Found lora_apply_mode: %s\n", optval); |
654 | 530 | } else { |
655 | 531 | fprintf(stderr, "Invalid lora_apply_mode: %s, using default\n", optval); |
656 | 532 | } |
657 | 533 | } |
658 | 534 | if (!strcmp(optname, "wtype")) { |
659 | | - int found = -1; |
660 | | - for (int m = 0; m < SD_TYPE_COUNT; m++) { |
661 | | - if (sd_type_str[m] && !strcmp(optval, sd_type_str[m])) { |
662 | | - found = m; |
663 | | - break; |
664 | | - } |
665 | | - } |
666 | | - if (found != -1) { |
667 | | - wtype = (sd_type_t)found; |
| 535 | + sd_type_t parsed = str_to_sd_type(optval); |
| 536 | + if (parsed != SD_TYPE_COUNT) { |
| 537 | + wtype = parsed; |
668 | 538 | fprintf(stderr, "Found wtype: %s\n", optval); |
669 | 539 | } else { |
670 | 540 | fprintf(stderr, "Invalid wtype: %s, using default\n", optval); |
@@ -735,27 +605,25 @@ int load_model(const char *model, char *model_path, char* options[], int threads |
735 | 605 | fprintf (stderr, "Created context: OK\n"); |
736 | 606 |
|
737 | 607 | int sample_method_found = -1; |
738 | | - for (int m = 0; m < SAMPLE_METHOD_COUNT; m++) { |
739 | | - if (!strcmp(sampler, sample_method_str[m])) { |
740 | | - sample_method_found = m; |
741 | | - fprintf(stderr, "Found sampler: %s\n", sampler); |
742 | | - } |
| 608 | + sample_method_t sm = str_to_sample_method(sampler); |
| 609 | + if (sm != SAMPLE_METHOD_COUNT) { |
| 610 | + sample_method_found = (int)sm; |
| 611 | + fprintf(stderr, "Found sampler: %s\n", sampler); |
743 | 612 | } |
744 | 613 | if (sample_method_found == -1) { |
745 | 614 | sample_method_found = sd_get_default_sample_method(sd_ctx); |
746 | | - fprintf(stderr, "Invalid sample method, using default: %s\n", sample_method_str[sample_method_found]); |
| 615 | + fprintf(stderr, "Invalid sample method, using default: %s\n", sd_sample_method_name((sample_method_t)sample_method_found)); |
747 | 616 | } |
748 | 617 | sample_method = (sample_method_t)sample_method_found; |
749 | 618 |
|
750 | | - for (int d = 0; d < SCHEDULER_COUNT; d++) { |
751 | | - if (!strcmp(scheduler_str, schedulers[d])) { |
752 | | - scheduler = (scheduler_t)d; |
753 | | - fprintf (stderr, "Found scheduler: %s\n", scheduler_str); |
754 | | - } |
| 619 | + scheduler_t sched = str_to_scheduler(scheduler_str); |
| 620 | + if (sched != SCHEDULER_COUNT) { |
| 621 | + scheduler = sched; |
| 622 | + fprintf(stderr, "Found scheduler: %s\n", scheduler_str); |
755 | 623 | } |
756 | 624 | if (scheduler == SCHEDULER_COUNT) { |
757 | | - scheduler = sd_get_default_scheduler(sd_ctx, sample_method); |
758 | | - fprintf(stderr, "Invalid scheduler, using default: %s\n", schedulers[scheduler]); |
| 625 | + scheduler = sd_get_default_scheduler(sd_ctx, sample_method); |
| 626 | + fprintf(stderr, "Invalid scheduler, using default: %s\n", sd_scheduler_name(scheduler)); |
759 | 627 | } |
760 | 628 |
|
761 | 629 | sd_c = sd_ctx; |
|
0 commit comments