Skip to content

Commit d397adf

Browse files
committed
Refactor imatrix implementation into main example
1 parent 6d0d214 commit d397adf

File tree

8 files changed

+82
-1192
lines changed

8 files changed

+82
-1192
lines changed

examples/CMakeLists.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
11
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
22

3-
add_subdirectory(cli)
4-
add_subdirectory(imatrix)
3+
add_subdirectory(cli)

examples/cli/main.cpp

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@
2424
#define STB_IMAGE_RESIZE_STATIC
2525
#include "stb_image_resize.h"
2626

27+
#define IMATRIX_IMPL
28+
#include "imatrix.hpp"
29+
static IMatrixCollector g_collector;
30+
2731
#define SAFE_STR(s) ((s) ? (s) : "")
2832
#define BOOL_STR(b) ((b) ? "true" : "false")
2933

@@ -109,6 +113,12 @@ struct SDParams {
109113
bool chroma_use_dit_mask = true;
110114
bool chroma_use_t5_mask = false;
111115
int chroma_t5_mask_pad = 1;
116+
117+
/* Imatrix params */
118+
119+
std::string imatrix_out = "";
120+
121+
std::vector<std::string> imatrix_in = {};
112122
};
113123

114124
void print_params(SDParams params) {
@@ -193,6 +203,8 @@ void print_usage(int argc, const char* argv[]) {
193203
printf(" --type [TYPE] weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K)\n");
194204
printf(" If not specified, the default is the type of the weight file\n");
195205
printf(" --tensor-type-rules [EXPRESSION] weight type per tensor pattern (example: \"^vae\\.=f16,model\\.=q8_0\")\n");
206+
printf(" --imat-out [PATH] If set, compute the imatrix for this run and save it to the provided path");
207+
printf(" --imat-in [PATH] Use imatrix for quantization.");
196208
printf(" --lora-model-dir [DIR] lora model directory\n");
197209
printf(" -i, --init-img [IMAGE] path to the input image, required by img2img\n");
198210
printf(" --mask [MASK] path to the mask image, required by img2img with mask\n");
@@ -388,6 +400,8 @@ void parse_args(int argc, const char** argv, SDParams& params) {
388400
{"-n", "--negative-prompt", "", &params.negative_prompt},
389401

390402
{"", "--upscale-model", "", &params.esrgan_path},
403+
{"", "--imat-out", "", &params.imatrix_out},
404+
391405
};
392406

393407
options.int_options = {
@@ -557,6 +571,14 @@ void parse_args(int argc, const char** argv, SDParams& params) {
557571
return 1;
558572
};
559573

574+
auto on_imatrix_in_arg = [&](int argc, const char** argv, int index) {
575+
if (++index >= argc) {
576+
return -1;
577+
}
578+
params.imatrix_in.push_back(argv[index]);
579+
return 1;
580+
};
581+
560582
options.manual_options = {
561583
{"-M", "--mode", "", on_mode_arg},
562584
{"", "--type", "", on_type_arg},
@@ -567,6 +589,8 @@ void parse_args(int argc, const char** argv, SDParams& params) {
567589
{"", "--skip-layers", "", on_skip_layers_arg},
568590
{"-r", "--ref-image", "", on_ref_image_arg},
569591
{"-h", "--help", "", on_help_arg},
592+
{"", "--imat-in", "", on_imatrix_in_arg},
593+
570594
};
571595

572596
if (!parse_options(argc, argv, options)) {
@@ -728,6 +752,10 @@ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
728752
fflush(out_stream);
729753
}
730754

755+
static bool collect_imatrix(struct ggml_tensor* t, bool ask, void* user_data) {
756+
return g_collector.collect_imatrix(t, ask, user_data);
757+
}
758+
731759
int main(int argc, const char* argv[]) {
732760
SDParams params;
733761

@@ -752,8 +780,21 @@ int main(int argc, const char* argv[]) {
752780
printf("%s", sd_get_system_info());
753781
}
754782

783+
if (params.imatrix_out != "") {
784+
sd_set_backend_eval_callback((sd_graph_eval_callback_t)collect_imatrix, &params);
785+
}
786+
if (params.imatrix_out != "" || params.mode == CONVERT || params.wtype != SD_TYPE_COUNT) {
787+
setConvertImatrixCollector((void*)&g_collector);
788+
for (const auto& in_file : params.imatrix_in) {
789+
printf("loading imatrix from '%s'\n", in_file.c_str());
790+
if (!g_collector.load_imatrix(in_file.c_str())) {
791+
printf("Failed to load %s\n", in_file.c_str());
792+
}
793+
}
794+
}
795+
755796
if (params.mode == CONVERT) {
756-
bool success = convert(params.model_path.c_str(), params.vae_path.c_str(), params.output_path.c_str(), params.wtype, params.tensor_type_rules.c_str(),NULL);
797+
bool success = convert(params.model_path.c_str(), params.vae_path.c_str(), params.output_path.c_str(), params.wtype, params.tensor_type_rules.c_str());
757798
if (!success) {
758799
fprintf(stderr,
759800
"convert '%s'/'%s' to '%s' failed\n",
@@ -1060,6 +1101,9 @@ int main(int argc, const char* argv[]) {
10601101
free(results[i].data);
10611102
results[i].data = NULL;
10621103
}
1104+
if (params.imatrix_out != "") {
1105+
g_collector.save_imatrix(params.imatrix_out);
1106+
}
10631107
free(results);
10641108
free_sd_ctx(sd_ctx);
10651109
free(control_image_buffer);

examples/imatrix/CMakeLists.txt

Lines changed: 0 additions & 7 deletions
This file was deleted.

0 commit comments

Comments
 (0)