2424#define STB_IMAGE_RESIZE_STATIC
2525#include " stb_image_resize.h"
2626
27+ #define IMATRIX_IMPL
28+ #include " imatrix.hpp"
29+ static IMatrixCollector g_collector;
30+
2731#define SAFE_STR (s ) ((s) ? (s) : " " )
2832#define BOOL_STR (b ) ((b) ? " true" : " false" )
2933
@@ -109,6 +113,12 @@ struct SDParams {
109113 bool chroma_use_dit_mask = true ;
110114 bool chroma_use_t5_mask = false ;
111115 int chroma_t5_mask_pad = 1 ;
116+
117+ /* Imatrix params */
118+
119+ std::string imatrix_out = " " ;
120+
121+ std::vector<std::string> imatrix_in = {};
112122};
113123
114124void print_params (SDParams params) {
@@ -193,6 +203,8 @@ void print_usage(int argc, const char* argv[]) {
193203 printf (" --type [TYPE] weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K)\n " );
194204 printf (" If not specified, the default is the type of the weight file\n " );
195205 printf (" --tensor-type-rules [EXPRESSION] weight type per tensor pattern (example: \" ^vae\\ .=f16,model\\ .=q8_0\" )\n " );
206+ printf (" --imat-out [PATH] If set, compute the imatrix for this run and save it to the provided path" );
207+ printf (" --imat-in [PATH] Use imatrix for quantization." );
196208 printf (" --lora-model-dir [DIR] lora model directory\n " );
197209 printf (" -i, --init-img [IMAGE] path to the input image, required by img2img\n " );
198210 printf (" --mask [MASK] path to the mask image, required by img2img with mask\n " );
@@ -388,6 +400,8 @@ void parse_args(int argc, const char** argv, SDParams& params) {
388400 {" -n" , " --negative-prompt" , " " , ¶ms.negative_prompt },
389401
390402 {" " , " --upscale-model" , " " , ¶ms.esrgan_path },
403+ {" " , " --imat-out" , " " , ¶ms.imatrix_out },
404+
391405 };
392406
393407 options.int_options = {
@@ -557,6 +571,14 @@ void parse_args(int argc, const char** argv, SDParams& params) {
557571 return 1 ;
558572 };
559573
574+ auto on_imatrix_in_arg = [&](int argc, const char ** argv, int index) {
575+ if (++index >= argc) {
576+ return -1 ;
577+ }
578+ params.imatrix_in .push_back (argv[index]);
579+ return 1 ;
580+ };
581+
560582 options.manual_options = {
561583 {" -M" , " --mode" , " " , on_mode_arg},
562584 {" " , " --type" , " " , on_type_arg},
@@ -567,6 +589,8 @@ void parse_args(int argc, const char** argv, SDParams& params) {
567589 {" " , " --skip-layers" , " " , on_skip_layers_arg},
568590 {" -r" , " --ref-image" , " " , on_ref_image_arg},
569591 {" -h" , " --help" , " " , on_help_arg},
592+ {" " , " --imat-in" , " " , on_imatrix_in_arg},
593+
570594 };
571595
572596 if (!parse_options (argc, argv, options)) {
@@ -728,6 +752,10 @@ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
728752 fflush (out_stream);
729753}
730754
755+ static bool collect_imatrix (struct ggml_tensor * t, bool ask, void * user_data) {
756+ return g_collector.collect_imatrix (t, ask, user_data);
757+ }
758+
731759int main (int argc, const char * argv[]) {
732760 SDParams params;
733761
@@ -752,8 +780,21 @@ int main(int argc, const char* argv[]) {
752780 printf (" %s" , sd_get_system_info ());
753781 }
754782
783+ if (params.imatrix_out != " " ) {
784+ sd_set_backend_eval_callback ((sd_graph_eval_callback_t )collect_imatrix, ¶ms);
785+ }
786+ if (params.imatrix_out != " " || params.mode == CONVERT || params.wtype != SD_TYPE_COUNT) {
787+ setConvertImatrixCollector ((void *)&g_collector);
788+ for (const auto & in_file : params.imatrix_in ) {
789+ printf (" loading imatrix from '%s'\n " , in_file.c_str ());
790+ if (!g_collector.load_imatrix (in_file.c_str ())) {
791+ printf (" Failed to load %s\n " , in_file.c_str ());
792+ }
793+ }
794+ }
795+
755796 if (params.mode == CONVERT) {
756- bool success = convert (params.model_path .c_str (), params.vae_path .c_str (), params.output_path .c_str (), params.wtype , params.tensor_type_rules .c_str (), NULL );
797+ bool success = convert (params.model_path .c_str (), params.vae_path .c_str (), params.output_path .c_str (), params.wtype , params.tensor_type_rules .c_str ());
757798 if (!success) {
758799 fprintf (stderr,
759800 " convert '%s'/'%s' to '%s' failed\n " ,
@@ -1060,6 +1101,9 @@ int main(int argc, const char* argv[]) {
10601101 free (results[i].data );
10611102 results[i].data = NULL ;
10621103 }
1104+ if (params.imatrix_out != " " ) {
1105+ g_collector.save_imatrix (params.imatrix_out );
1106+ }
10631107 free (results);
10641108 free_sd_ctx (sd_ctx);
10651109 free (control_image_buffer);
0 commit comments