diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index c3dca338a..749366e75 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -21,6 +21,7 @@ on: "**/*.c", "**/*.cpp", "**/*.cu", + "examples/server/frontend", "examples/server/frontend/**", ] pull_request: @@ -35,6 +36,7 @@ on: "**/*.c", "**/*.cpp", "**/*.cu", + "examples/server/frontend", "examples/server/frontend/**", ] @@ -174,6 +176,7 @@ jobs: build-and-push-docker-images: name: Build and push container images + if: ${{ github.event_name != 'pull_request' }} runs-on: ubuntu-latest permissions: @@ -239,6 +242,7 @@ jobs: id: build-push uses: docker/build-push-action@v6 with: + context: . platforms: linux/amd64 push: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} file: Dockerfile.${{ matrix.variant }} diff --git a/.gitmodules b/.gitmodules index 91cde1f28..5ccdc3824 100644 --- a/.gitmodules +++ b/.gitmodules @@ -3,7 +3,10 @@ url = https://github.com/ggml-org/ggml.git [submodule "examples/server/frontend"] path = examples/server/frontend - url = https://github.com/leejet/stable-ui.git + url = https://github.com/leejet/sdcpp-webui.git [submodule "thirdparty/libwebp"] path = thirdparty/libwebp url = https://github.com/webmproject/libwebp.git +[submodule "thirdparty/libwebm"] + path = thirdparty/libwebm + url = https://github.com/webmproject/libwebm.git diff --git a/CMakeLists.txt b/CMakeLists.txt index 9098f827b..48ce456ea 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -11,6 +11,10 @@ endif() if (MSVC) add_compile_definitions(_CRT_SECURE_NO_WARNINGS) add_compile_definitions(_SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING) + add_compile_options( + $<$<COMPILE_LANGUAGE:C>:/MP> + $<$<COMPILE_LANGUAGE:CXX>:/MP> + ) endif() set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) @@ -22,6 +26,26 @@ else() set(SD_STANDALONE OFF) endif() +set(SD_SUBMODULE_WEBP FALSE) +if(EXISTS 
"${CMAKE_CURRENT_SOURCE_DIR}/thirdparty/libwebp/CMakeLists.txt") + set(SD_SUBMODULE_WEBP TRUE) +endif() +if(SD_SUBMODULE_WEBP) + set(SD_WEBP_DEFAULT ON) +else() + set(SD_WEBP_DEFAULT ${SD_USE_SYSTEM_WEBP}) +endif() + +set(SD_SUBMODULE_WEBM FALSE) +if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/thirdparty/libwebm/CMakeLists.txt") + set(SD_SUBMODULE_WEBM TRUE) +endif() +if(SD_SUBMODULE_WEBM) + set(SD_WEBM_DEFAULT ON) +else() + set(SD_WEBM_DEFAULT ${SD_USE_SYSTEM_WEBM}) +endif() + # # Option list # @@ -29,7 +53,10 @@ endif() # general #option(SD_BUILD_TESTS "sd: build tests" ${SD_STANDALONE}) option(SD_BUILD_EXAMPLES "sd: build examples" ${SD_STANDALONE}) -option(SD_WEBP "sd: enable WebP image I/O support" ON) +option(SD_WEBP "sd: enable WebP image I/O support" ${SD_WEBP_DEFAULT}) +option(SD_USE_SYSTEM_WEBP "sd: link against system libwebp" OFF) +option(SD_WEBM "sd: enable WebM video output support" ${SD_WEBM_DEFAULT}) +option(SD_USE_SYSTEM_WEBM "sd: link against system libwebm" OFF) option(SD_CUDA "sd: cuda backend" OFF) option(SD_HIPBLAS "sd: rocm backend" OFF) option(SD_METAL "sd: metal backend" OFF) @@ -45,51 +72,94 @@ option(SD_USE_SYSTEM_GGML "sd: use system-installed GGML library" OFF if(SD_CUDA) message("-- Use CUDA as backend stable-diffusion") set(GGML_CUDA ON) - add_definitions(-DSD_USE_CUDA) endif() if(SD_METAL) message("-- Use Metal as backend stable-diffusion") set(GGML_METAL ON) - add_definitions(-DSD_USE_METAL) endif() if (SD_VULKAN) message("-- Use Vulkan as backend stable-diffusion") set(GGML_VULKAN ON) - add_definitions(-DSD_USE_VULKAN) endif () if (SD_OPENCL) message("-- Use OpenCL as backend stable-diffusion") set(GGML_OPENCL ON) - add_definitions(-DSD_USE_OPENCL) endif () if (SD_HIPBLAS) message("-- Use HIPBLAS as backend stable-diffusion") set(GGML_HIP ON) - add_definitions(-DSD_USE_CUDA) endif () if(SD_MUSA) message("-- Use MUSA as backend stable-diffusion") set(GGML_MUSA ON) - add_definitions(-DSD_USE_CUDA) endif() if(SD_WEBP) - 
add_compile_definitions(SD_USE_WEBP) + if(NOT SD_SUBMODULE_WEBP AND NOT SD_USE_SYSTEM_WEBP) + message(FATAL_ERROR "WebP support enabled but no source found. + Either initialize the submodule:\n git submodule update --init thirdparty/libwebp\n\n" + "Or link against system library:\n cmake (...) -DSD_USE_SYSTEM_WEBP=ON") + endif() + if(SD_USE_SYSTEM_WEBP) + find_package(WebP REQUIRED) + add_library(webp ALIAS WebP::webp) + # libwebp CMake target naming is not consistent across versions/distros. + # Some export WebP::libwebpmux, others export WebP::webpmux. + if(TARGET WebP::libwebpmux) + add_library(libwebpmux ALIAS WebP::libwebpmux) + elseif(TARGET WebP::webpmux) + add_library(libwebpmux ALIAS WebP::webpmux) + else() + message(FATAL_ERROR + "Could not find a compatible webpmux target in system WebP package. " + "Expected WebP::libwebpmux or WebP::webpmux." + ) + endif() + endif() +endif() + +if(SD_WEBM) + if(NOT SD_WEBP) + message(FATAL_ERROR "SD_WEBM requires SD_WEBP because WebM output reuses libwebp VP8 encoding.") + endif() + if(NOT SD_SUBMODULE_WEBM AND NOT SD_USE_SYSTEM_WEBM) + message(FATAL_ERROR "WebM support enabled but no source found. + Either initialize the submodule:\n git submodule update --init thirdparty/libwebm\n\n" + "Or link against system library:\n cmake (...) 
-DSD_USE_SYSTEM_WEBM=ON") + endif() + if(SD_USE_SYSTEM_WEBM) + find_path(WEBM_INCLUDE_DIR + NAMES mkvmuxer/mkvmuxer.h mkvparser/mkvparser.h common/webmids.h + PATH_SUFFIXES webm + REQUIRED) + find_library(WEBM_LIBRARY + NAMES webm libwebm + REQUIRED) + + add_library(webm UNKNOWN IMPORTED) + set_target_properties(webm PROPERTIES + IMPORTED_LOCATION "${WEBM_LIBRARY}" + INTERFACE_INCLUDE_DIRECTORIES "${WEBM_INCLUDE_DIR}") + endif() endif() set(SD_LIB stable-diffusion) -file(GLOB SD_LIB_SOURCES +file(GLOB SD_LIB_SOURCES CONFIGURE_DEPENDS "src/*.h" "src/*.cpp" "src/*.hpp" - "src/vocab/*.h" - "src/vocab/*.cpp" + "src/model_io/*.h" + "src/model_io/*.cpp" + "src/tokenizers/*.h" + "src/tokenizers/*.cpp" + "src/tokenizers/vocab/*.h" + "src/tokenizers/vocab/*.cpp" ) find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH) @@ -146,7 +216,6 @@ if(SD_SYCL) message("-- Use SYCL as backend stable-diffusion") set(GGML_SYCL ON) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-narrowing -fsycl") - add_definitions(-DSD_USE_SYCL) # disable fast-math on host, see: # https://www.intel.com/content/www/us/en/docs/cpp-compiler/developer-guide-reference/2021-10/fp-model-fp.html if (WIN32) @@ -182,7 +251,7 @@ endif() add_subdirectory(thirdparty) target_link_libraries(${SD_LIB} PUBLIC ggml zip) -target_include_directories(${SD_LIB} PUBLIC . include) +target_include_directories(${SD_LIB} PUBLIC . src include) target_include_directories(${SD_LIB} PUBLIC . thirdparty) target_compile_features(${SD_LIB} PUBLIC c_std_11 cxx_std_17) diff --git a/README.md b/README.md index b5bb49751..8afdeb20a 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,9 @@ API and command-line option may change frequently.*** ## ๐Ÿ”ฅImportant News +* **2026/04/11** ๐Ÿš€ stable-diffusion.cpp now uses a brand-new embedded web UI. 
+ ๐Ÿ‘‰ Details: [PR #1408](https://github.com/leejet/stable-diffusion.cpp/pull/1408) + * **2026/01/18** ๐Ÿš€ stable-diffusion.cpp now supports **FLUX.2-klein** ๐Ÿ‘‰ Details: [PR #1193](https://github.com/leejet/stable-diffusion.cpp/pull/1193) @@ -54,6 +57,7 @@ API and command-line option may change frequently.*** - [Z-Image](./docs/z_image.md) - [Ovis-Image](./docs/ovis_image.md) - [Anima](./docs/anima.md) + - [ERNIE-Image](./docs/ernie_image.md) - Image Edit Models - [FLUX.1-Kontext-dev](./docs/kontext.md) - [Qwen Image Edit series](./docs/qwen_image_edit.md) @@ -73,9 +77,10 @@ API and command-line option may change frequently.*** - OpenCL - SYCL - Supported weight formats - - Pytorch checkpoint (`.ckpt` or `.pth`) + - Pytorch checkpoint (`.ckpt` or `.pth` or `.pt`) - Safetensors (`.safetensors`) - GGUF (`.gguf`) +- Convert mode supports converting model weights to `.gguf` or `.safetensors` - Supported platforms - Linux - Mac OS @@ -93,6 +98,7 @@ API and command-line option may change frequently.*** - `DPM++ 2M` - [`DPM++ 2M v2`](https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions/8457) - `DPM++ 2S a` + - `ER-SDE` - [`LCM`](https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/13952) - Cross-platform reproducibility - `--rng cuda`, default, consistent with the `stable-diffusion-webui GPU RNG` @@ -141,6 +147,7 @@ If you want to improve performance or reduce VRAM/RAM usage, please refer to [pe - [๐Ÿ”ฅZ-Image](./docs/z_image.md) - [Ovis-Image](./docs/ovis_image.md) - [Anima](./docs/anima.md) +- [ERNIE-Image](./docs/ernie_image.md) - [LoRA](./docs/lora.md) - [LCM/LCM-LoRA](./docs/lcm.md) - [Using PhotoMaker to personalize image generation](./docs/photo_maker.md) diff --git a/assets/ernie_image/example.png b/assets/ernie_image/example.png new file mode 100644 index 000000000..3f5ed652f Binary files /dev/null and b/assets/ernie_image/example.png differ diff --git a/assets/ernie_image/turbo_example.png b/assets/ernie_image/turbo_example.png new 
file mode 100644 index 000000000..15318b3e9 Binary files /dev/null and b/assets/ernie_image/turbo_example.png differ diff --git a/docs/build.md b/docs/build.md index eabb51ac3..dbc87691e 100644 --- a/docs/build.md +++ b/docs/build.md @@ -16,15 +16,23 @@ git submodule init git submodule update ``` -## WebP Support in Examples +## WebP and WebM Support in Examples -The example applications (`examples/cli` and `examples/server`) use `libwebp` to support WebP image I/O. This is enabled by default. +The example applications (`examples/cli` and `examples/server`) use `libwebp` to support WebP image I/O, and `examples/cli` can also use `libwebm` for `.webm` video output. Both are enabled by default. WebM output currently reuses `libwebp` to encode each frame as VP8 before muxing with `libwebm`. -If you do not want WebP support, you can disable it at configure time: +If you do not want WebP/WebM support, you can disable them at configure time: ```shell mkdir build && cd build -cmake .. -DSD_WEBP=OFF +cmake .. -DSD_WEBP=OFF -DSD_WEBM=OFF +cmake --build . --config Release +``` + +If the submodules are not available, you can also link against system packages instead: + +```shell +mkdir build && cd build +cmake .. -DSD_USE_SYSTEM_WEBP=ON -DSD_USE_SYSTEM_WEBM=ON cmake --build . 
--config Release ``` diff --git a/docs/caching.md b/docs/caching.md index b02a541b7..01f019744 100644 --- a/docs/caching.md +++ b/docs/caching.md @@ -131,8 +131,6 @@ sd-cli -m model.safetensors -p "a cat" --cache-mode spectrum | `warmup` | Steps to always compute before caching starts | 4 | | `stop` | Stop caching at this fraction of total steps | 0.9 | -``` - ### Performance Tips - Start with default thresholds and adjust based on output quality diff --git a/docs/distilled_sd.md b/docs/distilled_sd.md index 3174b18f8..7aa8fbede 100644 --- a/docs/distilled_sd.md +++ b/docs/distilled_sd.md @@ -87,51 +87,32 @@ pipe.save_pretrained("segmindtiny-sd", safe_serialization=True) ```bash python convert_diffusers_to_original_stable_diffusion.py \ --model_path ./segmindtiny-sd \ - --checkpoint_path ./segmind_tiny-sd.ckpt --half + --checkpoint_path ./segmind_tiny-sd.safetensors --half --use_safetensors ``` -The file segmind_tiny-sd.ckpt will be generated and is now ready for use with sd.cpp. You can follow a similar process for the other models mentioned above. +The file segmind_tiny-sd.safetensors will be generated and is now ready for use with sd.cpp. You can follow a similar process for the other models mentioned above. -##### Another available .ckpt file: - - * https://huggingface.co/ClashSAN/small-sd/resolve/main/tinySDdistilled.ckpt - -To use this file, you must first adjust its non-contiguous tensors: - -```python -import torch -ckpt = torch.load("tinySDdistilled.ckpt", map_location=torch.device('cpu')) -for key, value in ckpt['state_dict'].items(): - if isinstance(value, torch.Tensor): - ckpt['state_dict'][key] = value.contiguous() -torch.save(ckpt, "tinySDdistilled_fixed.ckpt") -``` - - -### SDXS-512 +### SDXS-512-DreamShaper Another very tiny and **incredibly fast** model is SDXS by IDKiro et al. The authors refer to it as *"Real-Time One-Step Latent Diffusion Models with Image Conditions"*. For details read the paper: https://arxiv.org/pdf/2403.16627 . 
Once again the authors removed some more blocks of U-Net part and unlike other SD1 models they use an adjusted _AutoEncoderTiny_ instead of default _AutoEncoderKL_ for the VAE part. +##### Some ready-to-run SDXS-512 model files are available online, such as: -##### 1. Download the diffusers model from Hugging Face using Python: - -```python -from diffusers import StableDiffusionPipeline -pipe = StableDiffusionPipeline.from_pretrained("IDKiro/sdxs-512-dreamshaper") -pipe.save_pretrained(save_directory="sdxs") -``` -##### 2. Create a safetensors file - -```bash -python convert_diffusers_to_original_stable_diffusion.py \ - --model_path sdxs --checkpoint_path sdxs.safetensors --half --use_safetensors -``` - -##### 3. Run the model as follows: +* https://huggingface.co/akleine/sdxs-512 +* https://huggingface.co/concedo/sdxs-512-tinySDdistilled-GGUF +##### Run the model as follows: ```bash ~/stable-diffusion.cpp/build/bin/sd-cli -m sdxs.safetensors -p "portrait of a lovely cat" \ --cfg-scale 1 --steps 1 ``` +Both options: ``` --cfg-scale 1 ``` and ``` --steps 1 ``` are mandatory here. + +### SDXS-512-0.9 + +Even though the name "SDXS-512-0.9" is similar to "SDXS-512-DreamShaper", it is *completely different* but also **incredibly fast**. Sometimes it is preferred, so try it yourself. +##### Download a ready-to-run file from here: + +* https://huggingface.co/akleine/sdxs-09 -Both options: ``` --cfg-scale 1 ``` and ``` --steps 1 ``` are mandatory here. +For the use of this model, both options ``` --cfg-scale 1 ``` and ``` --steps 1 ``` are again absolutely necessary. diff --git a/docs/ernie_image.md b/docs/ernie_image.md new file mode 100644 index 000000000..d68da3966 --- /dev/null +++ b/docs/ernie_image.md @@ -0,0 +1,35 @@ +# How to Use + +You can run ERNIE-Image with stable-diffusion.cpp on GPUs with 4GB of VRAM โ€” or even less. 
+ +## Download weights + +- Download ERNIE-Image-Turbo + - safetensors: https://huggingface.co/Comfy-Org/ERNIE-Image/tree/main/diffusion_models + - gguf: https://huggingface.co/unsloth/ERNIE-Image-Turbo-GGUF/tree/main +- Download ERNIE-Image + - safetensors: https://huggingface.co/Comfy-Org/ERNIE-Image/tree/main/diffusion_models + - gguf: https://huggingface.co/unsloth/ERNIE-Image-GGUF/tree/main +- Download vae + - safetensors: https://huggingface.co/Comfy-Org/ERNIE-Image/tree/main/vae +- Download ministral 3b + - safetensors: https://huggingface.co/Comfy-Org/ERNIE-Image/tree/main/text_encoders + - gguf: https://huggingface.co/unsloth/Ministral-3-3B-Instruct-2512-GGUF/tree/main + +## Examples + +### ERNIE-Image-Turbo + +``` +.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\ernie-image-turbo.safetensors --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\ministral-3-3b.safetensors -p "a lovely cat" --cfg-scale 1.0 --steps 8 -v --offload-to-cpu --diffusion-fa +``` + +ERNIE-Image Turbo example + +### ERNIE-Image + +``` +.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\ernie-image-UD-Q4_K_M.gguf --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\ministral-3-3b.safetensors -p "a lovely cat" --cfg-scale 5.0 -v --offload-to-cpu --diffusion-fa +``` + +ERNIE-Image example diff --git a/docs/flux2.md b/docs/flux2.md index 1524478cc..11202e919 100644 --- a/docs/flux2.md +++ b/docs/flux2.md @@ -8,6 +8,8 @@ - gguf: https://huggingface.co/city96/FLUX.2-dev-gguf/tree/main - Download vae - safetensors: https://huggingface.co/black-forest-labs/FLUX.2-dev/tree/main +- Download FLUX.2-small-decoder (full_encoder_small_decoder.safetensors) as an alternative VAE option + - safetensors: https://huggingface.co/black-forest-labs/FLUX.2-small-decoder/tree/main - Download Mistral-Small-3.2-24B-Instruct-2506-GGUF - gguf: 
https://huggingface.co/unsloth/Mistral-Small-3.2-24B-Instruct-2506-GGUF/tree/main @@ -31,6 +33,8 @@ - gguf: https://huggingface.co/leejet/FLUX.2-klein-base-4B-GGUF/tree/main - Download vae - safetensors: https://huggingface.co/black-forest-labs/FLUX.2-dev/tree/main +- Download FLUX.2-small-decoder (full_encoder_small_decoder.safetensors) as an alternative VAE option + - safetensors: https://huggingface.co/black-forest-labs/FLUX.2-small-decoder/tree/main - Download Qwen3 4b - safetensors: https://huggingface.co/Comfy-Org/flux2-klein-4B/tree/main/split_files/text_encoders - gguf: https://huggingface.co/unsloth/Qwen3-4B-GGUF/tree/main diff --git a/examples/cli/CMakeLists.txt b/examples/cli/CMakeLists.txt index e4acaac87..db1f4ca37 100644 --- a/examples/cli/CMakeLists.txt +++ b/examples/cli/CMakeLists.txt @@ -1,6 +1,7 @@ set(TARGET sd-cli) add_executable(${TARGET} + ../common/common.cpp ../common/log.cpp ../common/media_io.cpp image_metadata.cpp @@ -9,6 +10,11 @@ add_executable(${TARGET} install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE stable-diffusion zip ${CMAKE_THREAD_LIBS_INIT}) if(SD_WEBP) + target_compile_definitions(${TARGET} PRIVATE SD_USE_WEBP) target_link_libraries(${TARGET} PRIVATE webp libwebpmux) endif() +if(SD_WEBM) + target_compile_definitions(${TARGET} PRIVATE SD_USE_WEBM) + target_link_libraries(${TARGET} PRIVATE webm) +endif() target_compile_features(${TARGET} PUBLIC c_std_11 cxx_std_17) diff --git a/examples/cli/README.md b/examples/cli/README.md index 25fcce692..b32fe37f9 100644 --- a/examples/cli/README.md +++ b/examples/cli/README.md @@ -4,26 +4,29 @@ usage: ./bin/sd-cli [options] CLI Options: - -o, --output path to write result image to. you can use printf-style %d format specifiers for image sequences (default: - ./output.png) (eg. output_%03d.png). For video generation, single-file outputs support .avi and animated .webp - --preview-path path to write preview image to (default: ./preview.png). 
Multi-frame previews support .avi and animated .webp - --preview-interval interval in denoising steps between consecutive updates of the image preview file (default is 1, meaning updating at - every step) - --output-begin-idx starting index for output image sequence, must be non-negative (default 0 if specified %d in output path, 1 otherwise) - --image path to the image to inspect (for metadata mode) - --metadata-format metadata output format, one of [text, json] (default: text) - --canny apply canny preprocessor (edge detection) - --convert-name convert tensor name (for convert mode) - -v, --verbose print extra info - --color colors the logging tags according to level - --taesd-preview-only prevents usage of taesd for decoding the final image. (for use with --preview tae) - --preview-noisy enables previewing noisy inputs of the models rather than the denoised outputs - --metadata-raw include raw hex previews for unparsed metadata payloads - --metadata-brief truncate long metadata text values in text output - --metadata-all include structural/container entries such as IHDR, IDAT, and non-metadata JPEG segments - -M, --mode run mode, one of [img_gen, vid_gen, upscale, convert, metadata], default: img_gen - --preview preview method. must be one of the following [none, proj, tae, vae] (default is none) - -h, --help show this help message and exit + -o, --output path to write result image to. you can use printf-style %d format specifiers for image + sequences (default: ./output.png) (eg. output_%03d.png). Single-file video outputs + support .avi, .webm, and animated .webp + --image path to the image to inspect (for metadata mode) + --metadata-format metadata output format, one of [text, json] (default: text) + --preview-path path to write preview image to (default: ./preview.png). 
Multi-frame previews support + .avi, .webm, and animated .webp + --preview-interval interval in denoising steps between consecutive updates of the image preview file + (default is 1, meaning updating at every step) + --output-begin-idx starting index for output image sequence, must be non-negative (default 0 if specified + %d in output path, 1 otherwise) + --canny apply canny preprocessor (edge detection) + --convert-name convert tensor name (for convert mode) + -v, --verbose print extra info + --color colors the logging tags according to level + --taesd-preview-only prevents usage of taesd for decoding the final image. (for use with --preview tae) + --preview-noisy enables previewing noisy inputs of the models rather than the denoised outputs + --metadata-raw include raw hex previews for unparsed metadata payloads + --metadata-brief truncate long metadata text values in text output + --metadata-all include structural/container entries such as IHDR, IDAT, and non-metadata JPEG segments + -M, --mode run mode, one of [img_gen, vid_gen, upscale, convert, metadata], default: img_gen + --preview preview method. must be one of the following [none, proj, tae, vae] (default is none) + -h, --help show this help message and exit Context Options: -m, --model path to full model @@ -31,7 +34,8 @@ Context Options: --clip_g path to the clip-g text encoder --clip_vision path to the clip-vision encoder --t5xxl path to the t5xxl text encoder - --llm path to the llm text encoder. For example: (qwenvl2.5 for qwen-image, mistral-small3.2 for flux2, ...) + --llm path to the llm text encoder. For example: (qwenvl2.5 for qwen-image, + mistral-small3.2 for flux2, ...) --llm_vision path to the llm vit --qwen2vl alias of --llm. Deprecated. --qwen2vl_vision alias of --llm_vision. Deprecated. 
@@ -43,16 +47,18 @@ Context Options: --control-net path to control net model --embd-dir embeddings directory --lora-model-dir lora model directory + --hires-upscalers-dir highres fix upscaler model directory --tensor-type-rules weight type per tensor pattern (example: "^vae\.=f16,model\.=q8_0") --photo-maker path to PHOTOMAKER model --upscale-model path to esrgan model. - -t, --threads number of threads to use during computation (default: -1). If threads <= 0, then threads will be set to the number of - CPU physical cores + -t, --threads number of threads to use during computation (default: -1). If threads <= 0, + then threads will be set to the number of CPU physical cores --chroma-t5-mask-pad t5 mask pad size of chroma - --vae-tile-overlap tile overlap for vae tiling, in fraction of tile size (default: 0.5) - --vae-tiling process vae in tiles to reduce memory usage + --max-vram maximum VRAM budget in GiB for graph-cut segmented execution. 0 disables + graph splitting --force-sdxl-vae-conv-scale force use of conv scale on sdxl vae - --offload-to-cpu place the weights in RAM to save VRAM, and automatically load them into VRAM when needed + --offload-to-cpu place the weights in RAM to save VRAM, and automatically load them into VRAM + when needed --mmap whether to memory-map model --control-net-cpu keep controlnet in cpu (for low vram) --clip-on-cpu keep clip in cpu (for low vram) @@ -67,20 +73,19 @@ Context Options: --chroma-disable-dit-mask disable dit mask for chroma --qwen-image-zero-cond-t enable zero_cond_t for qwen image --chroma-enable-t5-mask enable t5 mask for chroma - --type weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K). If not specified, the default is the - type of the weight file + --type weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, + q4_K). 
If not specified, the default is the type of the weight file --rng RNG, one of [std_default, cuda, cpu], default: cuda(sd-webui), cpu(comfyui) --sampler-rng sampler RNG, one of [std_default, cuda, cpu]. If not specified, use --rng - --prediction prediction type override, one of [eps, v, edm_v, sd3_flow, flux_flow, flux2_flow] - --lora-apply-mode the way to apply LoRA, one of [auto, immediately, at_runtime], default is auto. In auto mode, if the model weights - contain any quantized parameters, the at_runtime mode will be used; otherwise, - immediately will be used.The immediately mode may have precision and - compatibility issues with quantized parameters, but it usually offers faster inference - speed and, in some cases, lower memory usage. The at_runtime mode, on the - other hand, is exactly the opposite. - --vae-tile-size tile size for vae tiling, format [X]x[Y] (default: 32x32) - --vae-relative-tile-size relative tile size for vae tiling, format [X]x[Y], in fraction of image size if < 1, in number of tiles per dim if >=1 - (overrides --vae-tile-size) + --prediction prediction type override, one of [eps, v, edm_v, sd3_flow, flux_flow, + flux2_flow] + --lora-apply-mode the way to apply LoRA, one of [auto, immediately, at_runtime], default is + auto. In auto mode, if the model weights contain any quantized parameters, + the at_runtime mode will be used; otherwise, immediately will be used.The + immediately mode may have precision and compatibility issues with quantized + parameters, but it usually offers faster inference speed and, in some cases, + lower memory usage. The at_runtime mode, on the other hand, is exactly the + opposite. Generation Options: -p, --prompt the prompt to render @@ -89,69 +94,99 @@ Generation Options: --end-img path to the end image, required by flf2v --mask path to the mask image --control-image path to control image, control net - --control-video path to control video frames, It must be a directory path. 
The video frames inside should be stored as images in - lexicographical (character) order. For example, if the control video path is - `frames`, the directory contain images such as 00.png, 01.png, ... etc. + --control-video path to control video frames, It must be a directory path. The video frames + inside should be stored as images in lexicographical (character) order. For + example, if the control video path is `frames`, the directory contain images + such as 00.png, 01.png, ... etc. --pm-id-images-dir path to PHOTOMAKER input id images dir --pm-id-embed-path path to PHOTOMAKER v2 id embed + --hires-upscaler highres fix upscaler, Lanczos, Nearest, Latent, Latent (nearest), Latent + (nearest-exact), Latent (antialiased), Latent (bicubic), Latent (bicubic + antialiased), or a model name under --hires-upscalers-dir (default: Latent) -H, --height image height, in pixel space (default: 512) -W, --width image width, in pixel space (default: 512) --steps number of sample steps (default: 20) --high-noise-steps (high noise) number of sample steps (default: -1 = auto) - --clip-skip ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1). <= 0 represents unspecified, - will be 1 for SD1.x, 2 for SD2.x + --clip-skip ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer + (default: -1). <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x -b, --batch-count batch count --video-frames video frames (default: 1) --fps fps (default: 24) - --timestep-shift shift timestep for NitroFusion models (default: 0). recommended N for NitroSD-Realism around 250 and 500 for - NitroSD-Vibrant + --timestep-shift shift timestep for NitroFusion models (default: 0). 
recommended N for + NitroSD-Realism around 250 and 500 for NitroSD-Vibrant --upscale-repeats Run the ESRGAN upscaler this many times (default: 1) --upscale-tile-size tile size for ESRGAN upscaling (default: 128) + --hires-width highres fix target width, 0 to use --hires-scale (default: 0) + --hires-height highres fix target height, 0 to use --hires-scale (default: 0) + --hires-steps highres fix second pass sample steps, 0 to reuse --steps (default: 0) + --hires-upscale-tile-size highres fix upscaler tile size, reserved for model-backed upscalers (default: + 128) --cfg-scale unconditional guidance scale: (default: 7.0) - --img-cfg-scale image guidance scale for inpaint or instruct-pix2pix models: (default: same as --cfg-scale) + --img-cfg-scale image guidance scale for inpaint or instruct-pix2pix models: (default: same + as --cfg-scale) --guidance distilled guidance scale for models with guidance input (default: 3.5) - --slg-scale skip layer guidance (SLG) scale, only for DiT models: (default: 0). 0 means disabled, a value of 2.5 is nice for sd3.5 - medium + --slg-scale skip layer guidance (SLG) scale, only for DiT models: (default: 0). 
0 means + disabled, a value of 2.5 is nice for sd3.5 medium --skip-layer-start SLG enabling point (default: 0.01) --skip-layer-end SLG disabling point (default: 0.2) - --eta noise multiplier (default: 0 for ddim_trailing, tcd, res_multistep and res_2s; 1 for euler_a and dpm++2s_a) + --eta noise multiplier (default: 0 for ddim_trailing, tcd, res_multistep and + res_2s; 1 for euler_a, er_sde and dpm++2s_a) --flow-shift shift value for Flow models like SD3.x or WAN (default: auto) --high-noise-cfg-scale (high noise) unconditional guidance scale: (default: 7.0) - --high-noise-img-cfg-scale (high noise) image guidance scale for inpaint or instruct-pix2pix models (default: same as --cfg-scale) - --high-noise-guidance (high noise) distilled guidance scale for models with guidance input (default: 3.5) - --high-noise-slg-scale (high noise) skip layer guidance (SLG) scale, only for DiT models: (default: 0) + --high-noise-img-cfg-scale (high noise) image guidance scale for inpaint or instruct-pix2pix models + (default: same as --cfg-scale) + --high-noise-guidance (high noise) distilled guidance scale for models with guidance input + (default: 3.5) + --high-noise-slg-scale (high noise) skip layer guidance (SLG) scale, only for DiT models: (default: + 0) --high-noise-skip-layer-start (high noise) SLG enabling point (default: 0.01) --high-noise-skip-layer-end (high noise) SLG disabling point (default: 0.2) - --high-noise-eta (high noise) noise multiplier (default: 0 for ddim_trailing, tcd, res_multistep and res_2s; 1 for euler_a and dpm++2s_a) + --high-noise-eta (high noise) noise multiplier (default: 0 for ddim_trailing, tcd, + res_multistep and res_2s; 1 for euler_a, er_sde and dpm++2s_a) --strength strength for noising/unnoising (default: 0.75) - --pm-style-strength - --control-strength strength to apply Control Net (default: 0.9). 1.0 corresponds to full destruction of information in init image - --moe-boundary timestep boundary for Wan2.2 MoE model. (default: 0.875). 
Only enabled if `--high-noise-steps` is set to -1 + --pm-style-strength + --control-strength strength to apply Control Net (default: 0.9). 1.0 corresponds to full + destruction of information in init image + --moe-boundary timestep boundary for Wan2.2 MoE model. (default: 0.875). Only enabled if + `--high-noise-steps` is set to -1 --vace-strength wan vace strength - --increase-ref-index automatically increase the indices of references images based on the order they are listed (starting with 1). + --vae-tile-overlap tile overlap for vae tiling, in fraction of tile size (default: 0.5) + --hires-scale highres fix scale when target size is not set (default: 2.0) + --hires-denoising-strength highres fix second pass denoising strength (default: 0.7) + --increase-ref-index automatically increase the indices of references images based on the order + they are listed (starting with 1). --disable-auto-resize-ref-image disable auto resize of ref images --disable-image-metadata do not embed generation metadata on image files + --vae-tiling process vae in tiles to reduce memory usage + --hires enable highres fix -s, --seed RNG seed (default: 42, use random seed for < 0) - --sampling-method sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, - tcd, res_multistep, res_2s] (default: euler for Flux/SD3/Wan, euler_a - otherwise) - --high-noise-sampling-method (high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, - ddim_trailing, tcd, res_multistep, res_2s] default: euler for Flux/SD3/Wan, - euler_a otherwise - --scheduler denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, - kl_optimal, lcm, bong_tangent], default: discrete - --sigmas custom sigma values for the sampler, comma-separated (e.g., "14.61,7.8,3.5,0.0"). 
+ --sampling-method sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, + dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd, res_multistep, res_2s, + er_sde] (default: euler for Flux/SD3/Wan, euler_a otherwise) + --high-noise-sampling-method (high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, + dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd, res_multistep, + res_2s, er_sde] default: euler for Flux/SD3/Wan, euler_a otherwise + --scheduler denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, + smoothstep, sgm_uniform, simple, kl_optimal, lcm, bong_tangent], default: + discrete + --sigmas custom sigma values for the sampler, comma-separated (e.g., + "14.61,7.8,3.5,0.0"). --skip-layers layers to skip for SLG steps (default: [7,8,9]) --high-noise-skip-layers (high noise) layers to skip for SLG steps (default: [7,8,9]) -r, --ref-image reference image for Flux Kontext models (can be used multiple times) - --cache-mode caching method: 'easycache' (DiT), 'ucache' (UNET), 'dbcache'/'taylorseer'/'cache-dit' (DiT block-level), - 'spectrum' (UNET/DiT Chebyshev+Taylor forecasting) + --cache-mode caching method: 'easycache' (DiT), 'ucache' (UNET), + 'dbcache'/'taylorseer'/'cache-dit' (DiT block-level), 'spectrum' (UNET/DiT + Chebyshev+Taylor forecasting) --cache-option named cache params (key=value format, comma-separated). easycache/ucache: - threshold=,start=,end=,decay=,relative=,reset=; dbcache/taylorseer/cache-dit: Fn=,Bn=,threshold=,warmup=; - spectrum: w=,m=,lam=,window=,flex=,warmup=,stop=. Examples: - "threshold=0.25" or "threshold=1.5,reset=0" or "w=0.4,window=2" - --scm-mask SCM steps mask for cache-dit: comma-separated 0/1 (e.g., "1,1,1,0,0,1,0,0,1,0") - 1=compute, 0=can cache + threshold=,start=,end=,decay=,relative=,reset=; dbcache/taylorseer/cache-dit: + Fn=,Bn=,threshold=,warmup=; spectrum: w=,m=,lam=,window=,flex=,warmup=,stop=. 
+ Examples: "threshold=0.25" or "threshold=1.5,reset=0" + --scm-mask SCM steps mask for cache-dit: comma-separated 0/1 (e.g., + "1,1,1,0,0,1,0,0,1,0") - 1=compute, 0=can cache --scm-policy SCM policy: 'dynamic' (default) or 'static' + --vae-tile-size tile size for vae tiling, format [X]x[Y] (default: 32x32) + --vae-relative-tile-size relative tile size for vae tiling, format [X]x[Y], in fraction of image size + if < 1, in number of tiles per dim if >=1 (overrides --vae-tile-size) ``` Metadata mode inspects PNG/JPEG container metadata without loading any model: diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index b4a3c343e..27513f475 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -15,10 +15,13 @@ // #include "preprocessing.hpp" #include "stable-diffusion.h" -#include "common/common.hpp" +#include "common/common.h" #include "common/media_io.h" +#include "common/resource_owners.hpp" #include "image_metadata.h" +namespace fs = std::filesystem; + const char* previews_str[] = { "none", "proj", @@ -58,7 +61,7 @@ struct SDCliParams { options.string_options = { {"-o", "--output", - "path to write result image to. you can use printf-style %d format specifiers for image sequences (default: ./output.png) (eg. output_%03d.png)", + "path to write result image to. you can use printf-style %d format specifiers for image sequences (default: ./output.png) (eg. output_%03d.png). Single-file video outputs support .avi, .webm, and animated .webp", &output_path}, {"", "--image", @@ -70,7 +73,7 @@ struct SDCliParams { &metadata_format}, {"", "--preview-path", - "path to write preview image to (default: ./preview.png)", + "path to write preview image to (default: ./preview.png). 
Multi-frame previews support .avi, .webm, and animated .webp", &preview_path}, }; @@ -189,17 +192,22 @@ struct SDCliParams { return options; }; - bool process_and_check() { - if (mode != METADATA && output_path.length() == 0) { - LOG_ERROR("error: the following arguments are required: output_path"); - return false; - } - + bool resolve() { if (mode == CONVERT) { if (output_path == "output.png") { output_path = "output.gguf"; } - } else if (mode == METADATA) { + } + return true; + } + + bool validate() { + if (mode != METADATA) { + if (output_path.length() == 0) { + LOG_ERROR("error: the following arguments are required: output_path"); + return false; + } + } else { if (image_path.empty()) { LOG_ERROR("error: metadata mode needs an image path (--image)"); return false; @@ -213,6 +221,16 @@ struct SDCliParams { return true; } + bool resolve_and_validate() { + if (!resolve()) { + return false; + } + if (!validate()) { + return false; + } + return true; + } + std::string to_string() const { std::ostringstream oss; oss << "SDCliParams {\n" @@ -257,10 +275,12 @@ void parse_args(int argc, const char** argv, SDCliParams& cli_params, SDContextP exit(cli_params.normal_exit ? 
0 : 1); } - bool valid = cli_params.process_and_check(); + bool valid = cli_params.resolve_and_validate(); if (valid && cli_params.mode != METADATA) { - valid = ctx_params.process_and_check(cli_params.mode) && - gen_params.process_and_check(cli_params.mode, ctx_params.lora_model_dir); + valid = ctx_params.resolve_and_validate(cli_params.mode) && + gen_params.resolve_and_validate(cli_params.mode, + ctx_params.lora_model_dir, + ctx_params.hires_upscalers_dir); } if (!valid) { @@ -275,7 +295,7 @@ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) { } bool load_images_from_dir(const std::string dir, - std::vector& images, + std::vector& images, int expected_width = 0, int expected_height = 0, int max_image_num = 0, @@ -312,12 +332,12 @@ bool load_images_from_dir(const std::string dir, return false; } - images.push_back({(uint32_t)width, - (uint32_t)height, - 3, - image_buffer}); + images.emplace_back(sd_image_t{(uint32_t)width, + (uint32_t)height, + 3, + image_buffer}); - if (max_image_num > 0 && images.size() >= max_image_num) { + if (max_image_num > 0 && static_cast(images.size()) >= max_image_num) { break; } } @@ -396,7 +416,9 @@ bool save_results(const SDCliParams& cli_params, if (!ext.empty()) { if (output_format == EncodedImageFormat::JPEG || output_format == EncodedImageFormat::PNG || - output_format == EncodedImageFormat::WEBP) { + output_format == EncodedImageFormat::WEBP || + ext_lower == ".avi" || + ext_lower == ".webm") { base_path.replace_extension(); } } @@ -411,10 +433,11 @@ bool save_results(const SDCliParams& cli_params, if (!img.data) return false; - std::string params = gen_params.embed_image_metadata - ? get_image_params(ctx_params, gen_params, gen_params.seed + idx) - : ""; - const bool ok = write_image_to_file(path.string(), img.data, img.width, img.height, img.channel, params, 90); + const int64_t metadata_seed = cli_params.mode == VID_GEN ? 
gen_params.seed : gen_params.seed + idx; + std::string params = gen_params.embed_image_metadata + ? get_image_params(ctx_params, gen_params, metadata_seed, cli_params.mode) + : ""; + const bool ok = write_image_to_file(path.string(), img.data, img.width, img.height, img.channel, params, 90); LOG_INFO("save result image %d to '%s' (%s)", idx, path.string().c_str(), ok ? "success" : "failure"); return ok; }; @@ -438,7 +461,7 @@ bool save_results(const SDCliParams& cli_params, } if (cli_params.mode == VID_GEN && num_results > 1) { - if (ext_lower != ".avi" && ext_lower != ".webp") + if (ext_lower != ".avi" && ext_lower != ".webp" && ext_lower != ".webm") ext = ".avi"; fs::path video_path = base_path; video_path += ext; @@ -552,39 +575,10 @@ int main(int argc, const char* argv[]) { } } - bool vae_decode_only = true; - sd_image_t init_image = {0, 0, 3, nullptr}; - sd_image_t end_image = {0, 0, 3, nullptr}; - sd_image_t control_image = {0, 0, 3, nullptr}; - sd_image_t mask_image = {0, 0, 1, nullptr}; - std::vector ref_images; - std::vector pmid_images; - std::vector control_frames; - - auto release_all_resources = [&]() { - free(init_image.data); - free(end_image.data); - free(control_image.data); - free(mask_image.data); - for (auto image : ref_images) { - free(image.data); - image.data = nullptr; - } - ref_images.clear(); - for (auto image : pmid_images) { - free(image.data); - image.data = nullptr; - } - pmid_images.clear(); - for (auto image : control_frames) { - free(image.data); - image.data = nullptr; - } - control_frames.clear(); - }; + bool vae_decode_only = true; auto load_image_and_update_size = [&](const std::string& path, - sd_image_t& image, + SDImageOwner& image, bool resize_image = true, int expected_channel = 3) -> bool { int expected_width = 0; @@ -594,74 +588,73 @@ int main(int argc, const char* argv[]) { expected_height = gen_params.height; } - if (!load_sd_image_from_file(&image, path.c_str(), expected_width, expected_height, expected_channel)) { + 
if (!load_sd_image_from_file(image.put(), path.c_str(), expected_width, expected_height, expected_channel)) { LOG_ERROR("load image from '%s' failed", path.c_str()); - release_all_resources(); return false; } - gen_params.set_width_and_height_if_unset(image.width, image.height); + gen_params.set_width_and_height_if_unset(image.get().width, image.get().height); return true; }; if (gen_params.init_image_path.size() > 0) { vae_decode_only = false; - if (!load_image_and_update_size(gen_params.init_image_path, init_image)) { + if (!load_image_and_update_size(gen_params.init_image_path, gen_params.init_image)) { return 1; } } if (gen_params.end_image_path.size() > 0) { vae_decode_only = false; - if (!load_image_and_update_size(gen_params.end_image_path, end_image)) { + if (!load_image_and_update_size(gen_params.end_image_path, gen_params.end_image)) { return 1; } } if (gen_params.ref_image_paths.size() > 0) { vae_decode_only = false; + gen_params.ref_images.clear(); for (auto& path : gen_params.ref_image_paths) { - sd_image_t ref_image = {0, 0, 3, nullptr}; + SDImageOwner ref_image({0, 0, 3, nullptr}); if (!load_image_and_update_size(path, ref_image, false)) { return 1; } - ref_images.push_back(ref_image); + gen_params.ref_images.push_back(std::move(ref_image)); } } if (gen_params.mask_image_path.size() > 0) { - if (!load_sd_image_from_file(&mask_image, + if (!load_sd_image_from_file(gen_params.mask_image.put(), gen_params.mask_image_path.c_str(), gen_params.get_resolved_width(), gen_params.get_resolved_height(), 1)) { LOG_ERROR("load image from '%s' failed", gen_params.mask_image_path.c_str()); - release_all_resources(); return 1; } } else { - mask_image.data = (uint8_t*)malloc(gen_params.get_resolved_width() * gen_params.get_resolved_height()); - if (mask_image.data == nullptr) { + sd_image_t generated_mask = {0, 0, 1, nullptr}; + generated_mask.data = (uint8_t*)malloc(gen_params.get_resolved_width() * gen_params.get_resolved_height()); + if (generated_mask.data == 
nullptr) { LOG_ERROR("malloc mask image failed"); - release_all_resources(); return 1; } - mask_image.width = gen_params.get_resolved_width(); - mask_image.height = gen_params.get_resolved_height(); - memset(mask_image.data, 255, gen_params.get_resolved_width() * gen_params.get_resolved_height()); + generated_mask.width = gen_params.get_resolved_width(); + generated_mask.height = gen_params.get_resolved_height(); + memset(generated_mask.data, 255, gen_params.get_resolved_width() * gen_params.get_resolved_height()); + gen_params.mask_image.reset(generated_mask); } if (gen_params.control_image_path.size() > 0) { - if (!load_sd_image_from_file(&control_image, + if (!load_sd_image_from_file(gen_params.control_image.put(), gen_params.control_image_path.c_str(), gen_params.get_resolved_width(), gen_params.get_resolved_height())) { LOG_ERROR("load image from '%s' failed", gen_params.control_image_path.c_str()); - release_all_resources(); return 1; } if (cli_params.canny_preprocess) { // apply preprocessor - preprocess_canny(control_image, + preprocess_canny(gen_params.control_image.get(), 0.08f, 0.08f, 0.8f, @@ -671,25 +664,25 @@ int main(int argc, const char* argv[]) { } if (!gen_params.control_video_path.empty()) { + gen_params.control_frames.clear(); if (!load_images_from_dir(gen_params.control_video_path, - control_frames, + gen_params.control_frames, gen_params.get_resolved_width(), gen_params.get_resolved_height(), gen_params.video_frames, cli_params.verbose)) { - release_all_resources(); return 1; } } if (!gen_params.pm_id_images_dir.empty()) { + gen_params.pm_id_images.clear(); if (!load_images_from_dir(gen_params.pm_id_images_dir, - pmid_images, + gen_params.pm_id_images, 0, 0, 0, cli_params.verbose)) { - release_all_resources(); return 1; } } @@ -698,119 +691,65 @@ int main(int argc, const char* argv[]) { vae_decode_only = false; } + if (gen_params.hires_enabled && + (gen_params.resolved_hires_upscaler == SD_HIRES_UPSCALER_MODEL || + 
gen_params.resolved_hires_upscaler == SD_HIRES_UPSCALER_LANCZOS || + gen_params.resolved_hires_upscaler == SD_HIRES_UPSCALER_NEAREST)) { + vae_decode_only = false; + } + sd_ctx_params_t sd_ctx_params = ctx_params.to_sd_ctx_params_t(vae_decode_only, true, cli_params.taesd_preview); - sd_image_t* results = nullptr; - int num_results = 0; + SDImageVec results; + int num_results = 0; if (cli_params.mode == UPSCALE) { num_results = 1; - results = (sd_image_t*)calloc(num_results, sizeof(sd_image_t)); - if (results == nullptr) { - LOG_INFO("failed to allocate results array"); - release_all_resources(); - return 1; - } - - results[0] = init_image; - init_image.data = nullptr; + results.push_back(gen_params.init_image.release()); } else { - sd_ctx_t* sd_ctx = new_sd_ctx(&sd_ctx_params); + SDCtxPtr sd_ctx(new_sd_ctx(&sd_ctx_params)); if (sd_ctx == nullptr) { LOG_INFO("new_sd_ctx_t failed"); - release_all_resources(); return 1; } if (gen_params.sample_params.sample_method == SAMPLE_METHOD_COUNT) { - gen_params.sample_params.sample_method = sd_get_default_sample_method(sd_ctx); + gen_params.sample_params.sample_method = sd_get_default_sample_method(sd_ctx.get()); } if (gen_params.high_noise_sample_params.sample_method == SAMPLE_METHOD_COUNT) { - gen_params.high_noise_sample_params.sample_method = sd_get_default_sample_method(sd_ctx); + gen_params.high_noise_sample_params.sample_method = sd_get_default_sample_method(sd_ctx.get()); } if (gen_params.sample_params.scheduler == SCHEDULER_COUNT) { - gen_params.sample_params.scheduler = sd_get_default_scheduler(sd_ctx, gen_params.sample_params.sample_method); + gen_params.sample_params.scheduler = sd_get_default_scheduler(sd_ctx.get(), gen_params.sample_params.sample_method); } if (cli_params.mode == IMG_GEN) { - sd_img_gen_params_t img_gen_params = { - gen_params.lora_vec.data(), - static_cast(gen_params.lora_vec.size()), - gen_params.prompt.c_str(), - gen_params.negative_prompt.c_str(), - gen_params.clip_skip, - init_image, - 
ref_images.data(), - (int)ref_images.size(), - gen_params.auto_resize_ref_image, - gen_params.increase_ref_index, - mask_image, - gen_params.get_resolved_width(), - gen_params.get_resolved_height(), - gen_params.sample_params, - gen_params.strength, - gen_params.seed, - gen_params.batch_count, - control_image, - gen_params.control_strength, - { - pmid_images.data(), - (int)pmid_images.size(), - gen_params.pm_id_embed_path.c_str(), - gen_params.pm_style_strength, - }, // pm_params - gen_params.vae_tiling_params, - gen_params.cache_params, - }; - - results = generate_image(sd_ctx, &img_gen_params); + sd_img_gen_params_t img_gen_params = gen_params.to_sd_img_gen_params_t(); + num_results = gen_params.batch_count; + results.adopt(generate_image(sd_ctx.get(), &img_gen_params), num_results); } else if (cli_params.mode == VID_GEN) { - sd_vid_gen_params_t vid_gen_params = { - gen_params.lora_vec.data(), - static_cast(gen_params.lora_vec.size()), - gen_params.prompt.c_str(), - gen_params.negative_prompt.c_str(), - gen_params.clip_skip, - init_image, - end_image, - control_frames.data(), - (int)control_frames.size(), - gen_params.get_resolved_width(), - gen_params.get_resolved_height(), - gen_params.sample_params, - gen_params.high_noise_sample_params, - gen_params.moe_boundary, - gen_params.strength, - gen_params.seed, - gen_params.video_frames, - gen_params.vace_strength, - gen_params.vae_tiling_params, - gen_params.cache_params, - }; - - results = generate_video(sd_ctx, &vid_gen_params, &num_results); - } - - if (results == nullptr) { + sd_vid_gen_params_t vid_gen_params = gen_params.to_sd_vid_gen_params_t(); + sd_image_t* generated_video = generate_video(sd_ctx.get(), &vid_gen_params, &num_results); + results.adopt(generated_video, num_results); + } + + if (!results) { LOG_ERROR("generate failed"); - free_sd_ctx(sd_ctx); return 1; } - - free_sd_ctx(sd_ctx); } int upscale_factor = 4; // unused for RealESRGAN_x4plus_anime_6B.pth if (ctx_params.esrgan_path.size() > 0 && 
gen_params.upscale_repeats > 0) { - upscaler_ctx_t* upscaler_ctx = new_upscaler_ctx(ctx_params.esrgan_path.c_str(), - ctx_params.offload_params_to_cpu, - ctx_params.diffusion_conv_direct, - ctx_params.n_threads, - gen_params.upscale_tile_size); + UpscalerCtxPtr upscaler_ctx(new_upscaler_ctx(ctx_params.esrgan_path.c_str(), + ctx_params.offload_params_to_cpu, + ctx_params.diffusion_conv_direct, + ctx_params.n_threads, + gen_params.upscale_tile_size)); if (upscaler_ctx == nullptr) { LOG_ERROR("new_upscaler_ctx failed"); @@ -819,32 +758,24 @@ int main(int argc, const char* argv[]) { if (results[i].data == nullptr) { continue; } - sd_image_t current_image = results[i]; + SDImageOwner current_image(results[i]); + results[i] = {0, 0, 0, nullptr}; for (int u = 0; u < gen_params.upscale_repeats; ++u) { - sd_image_t upscaled_image = upscale(upscaler_ctx, current_image, upscale_factor); - if (upscaled_image.data == nullptr) { + SDImageOwner upscaled_image(upscale(upscaler_ctx.get(), current_image.get(), upscale_factor)); + if (upscaled_image.get().data == nullptr) { LOG_ERROR("upscale failed"); break; } - free(current_image.data); - current_image = upscaled_image; + current_image = std::move(upscaled_image); } - results[i] = current_image; // Set the final upscaled image as the result + results[i] = current_image.release(); // Set the final upscaled image as the result } } } - if (!save_results(cli_params, ctx_params, gen_params, results, num_results)) { + if (!save_results(cli_params, ctx_params, gen_params, results.data(), num_results)) { return 1; } - for (int i = 0; i < num_results; i++) { - free(results[i].data); - results[i].data = nullptr; - } - free(results); - - release_all_resources(); - return 0; } diff --git a/examples/common/common.cpp b/examples/common/common.cpp new file mode 100644 index 000000000..d4c8a72b8 --- /dev/null +++ b/examples/common/common.cpp @@ -0,0 +1,2544 @@ +#include "common.h" + +#include +#include +#include +#include +#include +#include 
+#include +#include +#include +#include +#include +#include + +#include + +#if defined(_WIN32) +#define NOMINMAX +#include +#endif // _WIN32 + +#include "log.h" +#include "media_io.h" +#include "resource_owners.hpp" + +using json = nlohmann::json; +namespace fs = std::filesystem; + +const char* const modes_str[] = { + "img_gen", + "vid_gen", + "convert", + "upscale", + "metadata", +}; + +#if defined(_WIN32) +static std::string utf16_to_utf8(const std::wstring& wstr) { + if (wstr.empty()) + return {}; + int size_needed = WideCharToMultiByte(CP_UTF8, 0, wstr.data(), (int)wstr.size(), + nullptr, 0, nullptr, nullptr); + if (size_needed <= 0) + throw std::runtime_error("UTF-16 to UTF-8 conversion failed"); + + std::string utf8(size_needed, 0); + WideCharToMultiByte(CP_UTF8, 0, wstr.data(), (int)wstr.size(), + (char*)utf8.data(), size_needed, nullptr, nullptr); + return utf8; +} + +static std::string argv_to_utf8(int index, const char** argv) { + (void)argv; + int argc; + wchar_t** argv_w = CommandLineToArgvW(GetCommandLineW(), &argc); + if (!argv_w) + throw std::runtime_error("Failed to parse command line"); + + std::string result; + if (index < argc) { + result = utf16_to_utf8(argv_w[index]); + } + LocalFree(argv_w); + return result; +} + +#else // Linux / macOS +static std::string argv_to_utf8(int index, const char** argv) { + return std::string(argv[index]); +} + +#endif + +template +static std::string vec_to_string(const std::vector& v) { + std::ostringstream oss; + oss << "["; + for (size_t i = 0; i < v.size(); i++) { + oss << v[i]; + if (i + 1 < v.size()) + oss << ", "; + } + oss << "]"; + return oss.str(); +} + +static std::string vec_str_to_string(const std::vector& v) { + std::ostringstream oss; + oss << "["; + for (size_t i = 0; i < v.size(); i++) { + oss << "\"" << v[i] << "\""; + if (i + 1 < v.size()) + oss << ", "; + } + oss << "]"; + return oss.str(); +} + +static bool is_absolute_path(const std::string& p) { +#ifdef _WIN32 + return p.size() > 1 && 
std::isalpha(static_cast(p[0])) && p[1] == ':'; +#else + return !p.empty() && p[0] == '/'; +#endif +} + +std::string ArgOptions::wrap_text(const std::string& text, size_t width, size_t indent) { + std::ostringstream oss; + size_t pos = 0; + size_t line_len = 0; + + while (pos < text.size()) { + if (text[pos] == '\n') { + oss << '\n' + << std::string(indent, ' '); + line_len = 0; + ++pos; + continue; + } + + if (std::isspace(static_cast(text[pos]))) { + ++pos; + continue; + } + + size_t word_start = pos; + while (pos < text.size() && + text[pos] != '\n' && + !std::isspace(static_cast(text[pos]))) { + ++pos; + } + + std::string word = text.substr(word_start, pos - word_start); + while (!word.empty()) { + size_t separator_len = line_len == 0 ? 0 : 1; + if (line_len + separator_len + word.size() <= width) { + if (separator_len > 0) { + oss << ' '; + ++line_len; + } + oss << word; + line_len += word.size(); + word.clear(); + continue; + } + + if (line_len > 0) { + oss << '\n' + << std::string(indent, ' '); + line_len = 0; + continue; + } + + size_t chunk_len = std::min(width, word.size()); + oss << word.substr(0, chunk_len); + line_len = chunk_len; + word.erase(0, chunk_len); + if (!word.empty()) { + oss << '\n' + << std::string(indent, ' '); + line_len = 0; + } + } + } + + return oss.str(); +} + +void ArgOptions::print() const { + constexpr size_t max_line_width = 120; + + struct Entry { + std::string names; + std::string desc; + }; + std::vector entries; + + auto add_entry = [&](const std::string& s, const std::string& l, + const std::string& desc, const std::string& hint = "") { + std::ostringstream ss; + if (!s.empty()) + ss << s; + if (!s.empty() && !l.empty()) + ss << ", "; + if (!l.empty()) + ss << l; + if (!hint.empty()) + ss << " " << hint; + entries.push_back({ss.str(), desc}); + }; + + for (auto& o : string_options) + add_entry(o.short_name, o.long_name, o.desc, ""); + for (auto& o : int_options) + add_entry(o.short_name, o.long_name, o.desc, ""); + for 
(auto& o : float_options) + add_entry(o.short_name, o.long_name, o.desc, ""); + for (auto& o : bool_options) + add_entry(o.short_name, o.long_name, o.desc, ""); + for (auto& o : manual_options) + add_entry(o.short_name, o.long_name, o.desc); + + size_t max_name_width = 0; + for (auto& e : entries) + max_name_width = std::max(max_name_width, e.names.size()); + + for (auto& e : entries) { + size_t indent = 2 + max_name_width + 4; + size_t desc_width = (max_line_width > indent ? max_line_width - indent : 40); + std::string wrapped_desc = wrap_text(e.desc, desc_width, indent); + std::cout << " " << std::left << std::setw(static_cast(max_name_width) + 4) + << e.names << wrapped_desc << "\n"; + } +} + +bool parse_options(int argc, const char** argv, const std::vector& options_list) { + bool invalid_arg = false; + std::string arg; + + auto match_and_apply = [&](auto& opts, auto&& apply_fn) -> bool { + for (auto& option : opts) { + if ((option.short_name.size() > 0 && arg == option.short_name) || + (option.long_name.size() > 0 && arg == option.long_name)) { + apply_fn(option); + return true; + } + } + return false; + }; + + for (int i = 1; i < argc; i++) { + arg = argv[i]; + bool found_arg = false; + + for (auto& options : options_list) { + if (match_and_apply(options.string_options, [&](auto& option) { + if (++i >= argc) { + invalid_arg = true; + return; + } + *option.target = argv_to_utf8(i, argv); + found_arg = true; + })) + break; + + if (match_and_apply(options.int_options, [&](auto& option) { + if (++i >= argc) { + invalid_arg = true; + return; + } + *option.target = std::stoi(argv[i]); + found_arg = true; + })) + break; + + if (match_and_apply(options.float_options, [&](auto& option) { + if (++i >= argc) { + invalid_arg = true; + return; + } + *option.target = std::stof(argv[i]); + found_arg = true; + })) + break; + + if (match_and_apply(options.bool_options, [&](auto& option) { + *option.target = option.keep_true ? 
true : false; + found_arg = true; + })) + break; + + if (match_and_apply(options.manual_options, [&](auto& option) { + int ret = option.cb(argc, argv, i); + if (ret < 0) { + invalid_arg = true; + return; + } + i += ret; + found_arg = true; + })) + break; + } + + if (invalid_arg) { + LOG_ERROR("error: invalid parameter for argument: %s", arg.c_str()); + return false; + } + if (!found_arg) { + LOG_ERROR("error: unknown argument: %s", arg.c_str()); + return false; + } + } + + return true; +} + +ArgOptions SDContextParams::get_options() { + ArgOptions options; + options.string_options = { + {"-m", + "--model", + "path to full model", + &model_path}, + {"", + "--clip_l", + "path to the clip-l text encoder", &clip_l_path}, + {"", "--clip_g", + "path to the clip-g text encoder", + &clip_g_path}, + {"", + "--clip_vision", + "path to the clip-vision encoder", + &clip_vision_path}, + {"", + "--t5xxl", + "path to the t5xxl text encoder", + &t5xxl_path}, + {"", + "--llm", + "path to the llm text encoder. For example: (qwenvl2.5 for qwen-image, mistral-small3.2 for flux2, ...)", + &llm_path}, + {"", + "--llm_vision", + "path to the llm vit", + &llm_vision_path}, + {"", + "--qwen2vl", + "alias of --llm. Deprecated.", + &llm_path}, + {"", + "--qwen2vl_vision", + "alias of --llm_vision. Deprecated.", + &llm_vision_path}, + {"", + "--diffusion-model", + "path to the standalone diffusion model", + &diffusion_model_path}, + {"", + "--high-noise-diffusion-model", + "path to the standalone high noise diffusion model", + &high_noise_diffusion_model_path}, + {"", + "--vae", + "path to standalone vae model", + &vae_path}, + {"", + "--taesd", + "path to taesd. 
Using Tiny AutoEncoder for fast decoding (low quality)", + &taesd_path}, + {"", + "--tae", + "alias of --taesd", + &taesd_path}, + {"", + "--control-net", + "path to control net model", + &control_net_path}, + {"", + "--embd-dir", + "embeddings directory", + &embedding_dir}, + {"", + "--lora-model-dir", + "lora model directory", + &lora_model_dir}, + {"", + "--hires-upscalers-dir", + "highres fix upscaler model directory", + &hires_upscalers_dir}, + {"", + "--tensor-type-rules", + "weight type per tensor pattern (example: \"^vae\\.=f16,model\\.=q8_0\")", + &tensor_type_rules}, + {"", + "--photo-maker", + "path to PHOTOMAKER model", + &photo_maker_path}, + {"", + "--upscale-model", + "path to esrgan model.", + &esrgan_path}, + }; + + options.int_options = { + {"-t", + "--threads", + "number of threads to use during computation (default: -1). " + "If threads <= 0, then threads will be set to the number of CPU physical cores", + &n_threads}, + {"", + "--chroma-t5-mask-pad", + "t5 mask pad size of chroma", + &chroma_t5_mask_pad}, + }; + + options.float_options = { + {"", + "--max-vram", + "maximum VRAM budget in GiB for graph-cut segmented execution. 
0 disables graph splitting", + &max_vram}, + }; + + options.bool_options = { + {"", + "--force-sdxl-vae-conv-scale", + "force use of conv scale on sdxl vae", + true, &force_sdxl_vae_conv_scale}, + {"", + "--offload-to-cpu", + "place the weights in RAM to save VRAM, and automatically load them into VRAM when needed", + true, &offload_params_to_cpu}, + {"", + "--mmap", + "whether to memory-map model", + true, &enable_mmap}, + {"", + "--control-net-cpu", + "keep controlnet in cpu (for low vram)", + true, &control_net_cpu}, + {"", + "--clip-on-cpu", + "keep clip in cpu (for low vram)", + true, &clip_on_cpu}, + {"", + "--vae-on-cpu", + "keep vae in cpu (for low vram)", + true, &vae_on_cpu}, + {"", + "--fa", + "use flash attention", + true, &flash_attn}, + {"", + "--diffusion-fa", + "use flash attention in the diffusion model only", + true, &diffusion_flash_attn}, + {"", + "--diffusion-conv-direct", + "use ggml_conv2d_direct in the diffusion model", + true, &diffusion_conv_direct}, + {"", + "--vae-conv-direct", + "use ggml_conv2d_direct in the vae model", + true, &vae_conv_direct}, + {"", + "--circular", + "enable circular padding for convolutions", + true, &circular}, + {"", + "--circularx", + "enable circular RoPE wrapping on x-axis (width) only", + true, &circular_x}, + {"", + "--circulary", + "enable circular RoPE wrapping on y-axis (height) only", + true, &circular_y}, + {"", + "--chroma-disable-dit-mask", + "disable dit mask for chroma", + false, &chroma_use_dit_mask}, + {"", + "--qwen-image-zero-cond-t", + "enable zero_cond_t for qwen image", + true, &qwen_image_zero_cond_t}, + {"", + "--chroma-enable-t5-mask", + "enable t5 mask for chroma", + true, &chroma_use_t5_mask}, + }; + + auto on_type_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; + } + const char* arg = argv[index]; + wtype = str_to_sd_type(arg); + if (wtype == SD_TYPE_COUNT) { + LOG_ERROR("error: invalid weight format %s", + arg); + return -1; + } + return 1; + 
}; + + auto on_rng_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; + } + const char* arg = argv[index]; + rng_type = str_to_rng_type(arg); + if (rng_type == RNG_TYPE_COUNT) { + LOG_ERROR("error: invalid rng type %s", + arg); + return -1; + } + return 1; + }; + + auto on_sampler_rng_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; + } + const char* arg = argv[index]; + sampler_rng_type = str_to_rng_type(arg); + if (sampler_rng_type == RNG_TYPE_COUNT) { + LOG_ERROR("error: invalid sampler rng type %s", + arg); + return -1; + } + return 1; + }; + + auto on_prediction_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; + } + const char* arg = argv[index]; + prediction = str_to_prediction(arg); + if (prediction == PREDICTION_COUNT) { + LOG_ERROR("error: invalid prediction type %s", + arg); + return -1; + } + return 1; + }; + + auto on_lora_apply_mode_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; + } + const char* arg = argv[index]; + lora_apply_mode = str_to_lora_apply_mode(arg); + if (lora_apply_mode == LORA_APPLY_MODE_COUNT) { + LOG_ERROR("error: invalid lora apply model %s", + arg); + return -1; + } + return 1; + }; + + options.manual_options = { + {"", + "--type", + "weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K). " + "If not specified, the default is the type of the weight file", + on_type_arg}, + {"", + "--rng", + "RNG, one of [std_default, cuda, cpu], default: cuda(sd-webui), cpu(comfyui)", + on_rng_arg}, + {"", + "--sampler-rng", + "sampler RNG, one of [std_default, cuda, cpu]. 
If not specified, use --rng", + on_sampler_rng_arg}, + {"", + "--prediction", + "prediction type override, one of [eps, v, edm_v, sd3_flow, flux_flow, flux2_flow]", + on_prediction_arg}, + {"", + "--lora-apply-mode", + "the way to apply LoRA, one of [auto, immediately, at_runtime], default is auto. " + "In auto mode, if the model weights contain any quantized parameters, the at_runtime mode will be used; otherwise, immediately will be used." + "The immediately mode may have precision and compatibility issues with quantized parameters, " + "but it usually offers faster inference speed and, in some cases, lower memory usage. " + "The at_runtime mode, on the other hand, is exactly the opposite.", + on_lora_apply_mode_arg}, + }; + + return options; +} + +void SDContextParams::build_embedding_map() { + static const std::vector valid_ext = {".gguf", ".safetensors", ".pt"}; + + if (!fs::exists(embedding_dir) || !fs::is_directory(embedding_dir)) { + return; + } + + for (auto& p : fs::directory_iterator(embedding_dir)) { + if (!p.is_regular_file()) + continue; + + auto path = p.path(); + std::string ext = path.extension().string(); + + bool valid = false; + for (auto& e : valid_ext) { + if (ext == e) { + valid = true; + break; + } + } + if (!valid) + continue; + + std::string key = path.stem().string(); + std::string value = path.string(); + + embedding_map[key] = value; + } +} + +bool SDContextParams::resolve(SDMode mode) { + if (n_threads <= 0) { + n_threads = sd_get_num_physical_cores(); + } + + build_embedding_map(); + + return true; +} + +bool SDContextParams::validate(SDMode mode) { + if (mode != UPSCALE && mode != METADATA && model_path.length() == 0 && diffusion_model_path.length() == 0) { + LOG_ERROR("error: the following arguments are required: model_path/diffusion_model\n"); + return false; + } + + if (mode == UPSCALE) { + if (esrgan_path.length() == 0) { + LOG_ERROR("error: upscale mode needs an upscaler model (--upscale-model)\n"); + return false; + } + } + + 
return true; +} + +bool SDContextParams::resolve_and_validate(SDMode mode) { + if (!resolve(mode)) { + return false; + } + if (!validate(mode)) { + return false; + } + return true; +} + +std::string SDContextParams::to_string() const { + std::ostringstream emb_ss; + emb_ss << "{\n"; + for (auto it = embedding_map.begin(); it != embedding_map.end(); ++it) { + emb_ss << " \"" << it->first << "\": \"" << it->second << "\""; + if (std::next(it) != embedding_map.end()) { + emb_ss << ","; + } + emb_ss << "\n"; + } + emb_ss << " }"; + + std::string embeddings_str = emb_ss.str(); + std::ostringstream oss; + oss << "SDContextParams {\n" + << " n_threads: " << n_threads << ",\n" + << " model_path: \"" << model_path << "\",\n" + << " clip_l_path: \"" << clip_l_path << "\",\n" + << " clip_g_path: \"" << clip_g_path << "\",\n" + << " clip_vision_path: \"" << clip_vision_path << "\",\n" + << " t5xxl_path: \"" << t5xxl_path << "\",\n" + << " llm_path: \"" << llm_path << "\",\n" + << " llm_vision_path: \"" << llm_vision_path << "\",\n" + << " diffusion_model_path: \"" << diffusion_model_path << "\",\n" + << " high_noise_diffusion_model_path: \"" << high_noise_diffusion_model_path << "\",\n" + << " vae_path: \"" << vae_path << "\",\n" + << " taesd_path: \"" << taesd_path << "\",\n" + << " esrgan_path: \"" << esrgan_path << "\",\n" + << " control_net_path: \"" << control_net_path << "\",\n" + << " embedding_dir: \"" << embedding_dir << "\",\n" + << " embeddings: " << embeddings_str << "\n" + << " wtype: " << sd_type_name(wtype) << ",\n" + << " tensor_type_rules: \"" << tensor_type_rules << "\",\n" + << " lora_model_dir: \"" << lora_model_dir << "\",\n" + << " hires_upscalers_dir: \"" << hires_upscalers_dir << "\",\n" + << " photo_maker_path: \"" << photo_maker_path << "\",\n" + << " rng_type: " << sd_rng_type_name(rng_type) << ",\n" + << " sampler_rng_type: " << sd_rng_type_name(sampler_rng_type) << ",\n" + << " offload_params_to_cpu: " << (offload_params_to_cpu ? 
"true" : "false") << ",\n" + << " max_vram: " << max_vram << ",\n" + << " enable_mmap: " << (enable_mmap ? "true" : "false") << ",\n" + << " control_net_cpu: " << (control_net_cpu ? "true" : "false") << ",\n" + << " clip_on_cpu: " << (clip_on_cpu ? "true" : "false") << ",\n" + << " vae_on_cpu: " << (vae_on_cpu ? "true" : "false") << ",\n" + << " flash_attn: " << (flash_attn ? "true" : "false") << ",\n" + << " diffusion_flash_attn: " << (diffusion_flash_attn ? "true" : "false") << ",\n" + << " diffusion_conv_direct: " << (diffusion_conv_direct ? "true" : "false") << ",\n" + << " vae_conv_direct: " << (vae_conv_direct ? "true" : "false") << ",\n" + << " circular: " << (circular ? "true" : "false") << ",\n" + << " circular_x: " << (circular_x ? "true" : "false") << ",\n" + << " circular_y: " << (circular_y ? "true" : "false") << ",\n" + << " chroma_use_dit_mask: " << (chroma_use_dit_mask ? "true" : "false") << ",\n" + << " qwen_image_zero_cond_t: " << (qwen_image_zero_cond_t ? "true" : "false") << ",\n" + << " chroma_use_t5_mask: " << (chroma_use_t5_mask ? "true" : "false") << ",\n" + << " chroma_t5_mask_pad: " << chroma_t5_mask_pad << ",\n" + << " prediction: " << sd_prediction_name(prediction) << ",\n" + << " lora_apply_mode: " << sd_lora_apply_mode_name(lora_apply_mode) << ",\n" + << " force_sdxl_vae_conv_scale: " << (force_sdxl_vae_conv_scale ? 
"true" : "false") << "\n" + << "}"; + return oss.str(); +} + +sd_ctx_params_t SDContextParams::to_sd_ctx_params_t(bool vae_decode_only, bool free_params_immediately, bool taesd_preview) { + embedding_vec.clear(); + embedding_vec.reserve(embedding_map.size()); + for (const auto& kv : embedding_map) { + sd_embedding_t item; + item.name = kv.first.c_str(); + item.path = kv.second.c_str(); + embedding_vec.emplace_back(item); + } + + sd_ctx_params_t sd_ctx_params = { + model_path.c_str(), + clip_l_path.c_str(), + clip_g_path.c_str(), + clip_vision_path.c_str(), + t5xxl_path.c_str(), + llm_path.c_str(), + llm_vision_path.c_str(), + diffusion_model_path.c_str(), + high_noise_diffusion_model_path.c_str(), + vae_path.c_str(), + taesd_path.c_str(), + control_net_path.c_str(), + embedding_vec.data(), + static_cast(embedding_vec.size()), + photo_maker_path.c_str(), + tensor_type_rules.c_str(), + vae_decode_only, + free_params_immediately, + n_threads, + wtype, + rng_type, + sampler_rng_type, + prediction, + lora_apply_mode, + offload_params_to_cpu, + enable_mmap, + clip_on_cpu, + control_net_cpu, + vae_on_cpu, + flash_attn, + diffusion_flash_attn, + taesd_preview, + diffusion_conv_direct, + vae_conv_direct, + circular || circular_x, + circular || circular_y, + force_sdxl_vae_conv_scale, + chroma_use_dit_mask, + chroma_use_t5_mask, + chroma_t5_mask_pad, + qwen_image_zero_cond_t, + max_vram, + }; + return sd_ctx_params; +} + +SDGenerationParams::SDGenerationParams() { + sd_sample_params_init(&sample_params); + sd_sample_params_init(&high_noise_sample_params); +} + +ArgOptions SDGenerationParams::get_options() { + ArgOptions options; + options.string_options = { + {"-p", + "--prompt", + "the prompt to render", + &prompt}, + {"-n", + "--negative-prompt", + "the negative prompt (default: \"\")", + &negative_prompt}, + {"-i", + "--init-img", + "path to the init image", + &init_image_path}, + {"", + "--end-img", + "path to the end image, required by flf2v", + &end_image_path}, + {"", 
+ "--mask", + "path to the mask image", + &mask_image_path}, + {"", + "--control-image", + "path to control image, control net", + &control_image_path}, + {"", + "--control-video", + "path to control video frames, It must be a directory path. The video frames inside should be stored as images in " + "lexicographical (character) order. For example, if the control video path is `frames`, the directory contain images " + "such as 00.png, 01.png, ... etc.", + &control_video_path}, + {"", + "--pm-id-images-dir", + "path to PHOTOMAKER input id images dir", + &pm_id_images_dir}, + {"", + "--pm-id-embed-path", + "path to PHOTOMAKER v2 id embed", + &pm_id_embed_path}, + {"", + "--hires-upscaler", + "highres fix upscaler, Lanczos, Nearest, Latent, Latent (nearest), Latent (nearest-exact), " + "Latent (antialiased), Latent (bicubic), Latent (bicubic antialiased), or a model name " + "under --hires-upscalers-dir (default: Latent)", + &hires_upscaler}, + }; + + options.int_options = { + {"-H", + "--height", + "image height, in pixel space (default: 512)", + &height}, + {"-W", + "--width", + "image width, in pixel space (default: 512)", + &width}, + {"", + "--steps", + "number of sample steps (default: 20)", + &sample_params.sample_steps}, + {"", + "--high-noise-steps", + "(high noise) number of sample steps (default: -1 = auto)", + &high_noise_sample_params.sample_steps}, + {"", + "--clip-skip", + "ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1). " + "<= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x", + &clip_skip}, + {"-b", + "--batch-count", + "batch count", + &batch_count}, + {"", + "--video-frames", + "video frames (default: 1)", + &video_frames}, + {"", + "--fps", + "fps (default: 24)", + &fps}, + {"", + "--timestep-shift", + "shift timestep for NitroFusion models (default: 0). 
" + "recommended N for NitroSD-Realism around 250 and 500 for NitroSD-Vibrant", + &sample_params.shifted_timestep}, + {"", + "--upscale-repeats", + "Run the ESRGAN upscaler this many times (default: 1)", + &upscale_repeats}, + {"", + "--upscale-tile-size", + "tile size for ESRGAN upscaling (default: 128)", + &upscale_tile_size}, + {"", + "--hires-width", + "highres fix target width, 0 to use --hires-scale (default: 0)", + &hires_width}, + {"", + "--hires-height", + "highres fix target height, 0 to use --hires-scale (default: 0)", + &hires_height}, + {"", + "--hires-steps", + "highres fix second pass sample steps, 0 to reuse --steps (default: 0)", + &hires_steps}, + {"", + "--hires-upscale-tile-size", + "highres fix upscaler tile size, reserved for model-backed upscalers (default: 128)", + &hires_upscale_tile_size}, + }; + + options.float_options = { + {"", + "--cfg-scale", + "unconditional guidance scale: (default: 7.0)", + &sample_params.guidance.txt_cfg}, + {"", + "--img-cfg-scale", + "image guidance scale for inpaint or instruct-pix2pix models: (default: same as --cfg-scale)", + &sample_params.guidance.img_cfg}, + {"", + "--guidance", + "distilled guidance scale for models with guidance input (default: 3.5)", + &sample_params.guidance.distilled_guidance}, + {"", + "--slg-scale", + "skip layer guidance (SLG) scale, only for DiT models: (default: 0). 
0 means disabled, a value of 2.5 is nice for sd3.5 medium", + &sample_params.guidance.slg.scale}, + {"", + "--skip-layer-start", + "SLG enabling point (default: 0.01)", + &sample_params.guidance.slg.layer_start}, + {"", + "--skip-layer-end", + "SLG disabling point (default: 0.2)", + &sample_params.guidance.slg.layer_end}, + {"", + "--eta", + "noise multiplier (default: 0 for ddim_trailing, tcd, res_multistep and res_2s; 1 for euler_a, er_sde and dpm++2s_a)", + &sample_params.eta}, + {"", + "--flow-shift", + "shift value for Flow models like SD3.x or WAN (default: auto)", + &sample_params.flow_shift}, + {"", + "--high-noise-cfg-scale", + "(high noise) unconditional guidance scale: (default: 7.0)", + &high_noise_sample_params.guidance.txt_cfg}, + {"", + "--high-noise-img-cfg-scale", + "(high noise) image guidance scale for inpaint or instruct-pix2pix models (default: same as --cfg-scale)", + &high_noise_sample_params.guidance.img_cfg}, + {"", + "--high-noise-guidance", + "(high noise) distilled guidance scale for models with guidance input (default: 3.5)", + &high_noise_sample_params.guidance.distilled_guidance}, + {"", + "--high-noise-slg-scale", + "(high noise) skip layer guidance (SLG) scale, only for DiT models: (default: 0)", + &high_noise_sample_params.guidance.slg.scale}, + {"", + "--high-noise-skip-layer-start", + "(high noise) SLG enabling point (default: 0.01)", + &high_noise_sample_params.guidance.slg.layer_start}, + {"", + "--high-noise-skip-layer-end", + "(high noise) SLG disabling point (default: 0.2)", + &high_noise_sample_params.guidance.slg.layer_end}, + {"", + "--high-noise-eta", + "(high noise) noise multiplier (default: 0 for ddim_trailing, tcd, res_multistep and res_2s; 1 for euler_a, er_sde and dpm++2s_a)", + &high_noise_sample_params.eta}, + {"", + "--strength", + "strength for noising/unnoising (default: 0.75)", + &strength}, + {"", + "--pm-style-strength", + "", + &pm_style_strength}, + {"", + "--control-strength", + "strength to apply 
Control Net (default: 0.9). 1.0 corresponds to full destruction of information in init image", + &control_strength}, + {"", + "--moe-boundary", + "timestep boundary for Wan2.2 MoE model. (default: 0.875). Only enabled if `--high-noise-steps` is set to -1", + &moe_boundary}, + {"", + "--vace-strength", + "wan vace strength", + &vace_strength}, + {"", + "--vae-tile-overlap", + "tile overlap for vae tiling, in fraction of tile size (default: 0.5)", + &vae_tiling_params.target_overlap}, + {"", + "--hires-scale", + "highres fix scale when target size is not set (default: 2.0)", + &hires_scale}, + {"", + "--hires-denoising-strength", + "highres fix second pass denoising strength (default: 0.7)", + &hires_denoising_strength}, + }; + + options.bool_options = { + {"", + "--increase-ref-index", + "automatically increase the indices of references images based on the order they are listed (starting with 1).", + true, + &increase_ref_index}, + {"", + "--disable-auto-resize-ref-image", + "disable auto resize of ref images", + false, + &auto_resize_ref_image}, + {"", + "--disable-image-metadata", + "do not embed generation metadata on image files", + false, + &embed_image_metadata}, + {"", + "--vae-tiling", + "process vae in tiles to reduce memory usage", + true, + &vae_tiling_params.enabled}, + {"", + "--hires", + "enable highres fix", + true, + &hires_enabled}, + }; + + auto on_seed_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; + } + seed = std::stoll(argv[index]); + return 1; + }; + + auto on_sample_method_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; + } + const char* arg = argv[index]; + sample_params.sample_method = str_to_sample_method(arg); + if (sample_params.sample_method == SAMPLE_METHOD_COUNT) { + LOG_ERROR("error: invalid sample method %s", + arg); + return -1; + } + return 1; + }; + + auto on_high_noise_sample_method_arg = [&](int argc, const char** argv, int index) { + if 
(++index >= argc) { + return -1; + } + const char* arg = argv[index]; + high_noise_sample_params.sample_method = str_to_sample_method(arg); + if (high_noise_sample_params.sample_method == SAMPLE_METHOD_COUNT) { + LOG_ERROR("error: invalid high noise sample method %s", + arg); + return -1; + } + return 1; + }; + + auto on_scheduler_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; + } + const char* arg = argv[index]; + sample_params.scheduler = str_to_scheduler(arg); + if (sample_params.scheduler == SCHEDULER_COUNT) { + LOG_ERROR("error: invalid scheduler %s", + arg); + return -1; + } + return 1; + }; + + auto on_skip_layers_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; + } + std::string layers_str = argv[index]; + if (layers_str[0] != '[' || layers_str[layers_str.size() - 1] != ']') { + return -1; + } + + layers_str = layers_str.substr(1, layers_str.size() - 2); + + std::regex regex("[, ]+"); + std::sregex_token_iterator iter(layers_str.begin(), layers_str.end(), regex, -1); + std::sregex_token_iterator end; + std::vector tokens(iter, end); + std::vector layers; + for (const auto& token : tokens) { + try { + layers.push_back(std::stoi(token)); + } catch (const std::invalid_argument&) { + return -1; + } + } + skip_layers = layers; + return 1; + }; + + auto on_high_noise_skip_layers_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; + } + std::string layers_str = argv[index]; + if (layers_str[0] != '[' || layers_str[layers_str.size() - 1] != ']') { + return -1; + } + + layers_str = layers_str.substr(1, layers_str.size() - 2); + + std::regex regex("[, ]+"); + std::sregex_token_iterator iter(layers_str.begin(), layers_str.end(), regex, -1); + std::sregex_token_iterator end; + std::vector tokens(iter, end); + std::vector layers; + for (const auto& token : tokens) { + try { + layers.push_back(std::stoi(token)); + } catch (const 
std::invalid_argument&) { + return -1; + } + } + high_noise_skip_layers = layers; + return 1; + }; + + auto on_sigmas_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; + } + std::string sigmas_str = argv[index]; + if (!sigmas_str.empty() && sigmas_str.front() == '[') { + sigmas_str.erase(0, 1); + } + if (!sigmas_str.empty() && sigmas_str.back() == ']') { + sigmas_str.pop_back(); + } + + std::stringstream ss(sigmas_str); + std::string item; + while (std::getline(ss, item, ',')) { + item.erase(0, item.find_first_not_of(" \t\n\r\f\v")); + item.erase(item.find_last_not_of(" \t\n\r\f\v") + 1); + if (!item.empty()) { + try { + custom_sigmas.push_back(std::stof(item)); + } catch (const std::invalid_argument&) { + LOG_ERROR("error: invalid float value '%s' in --sigmas", item.c_str()); + return -1; + } catch (const std::out_of_range&) { + LOG_ERROR("error: float value '%s' out of range in --sigmas", item.c_str()); + return -1; + } + } + } + + if (custom_sigmas.empty() && !sigmas_str.empty()) { + LOG_ERROR("error: could not parse any sigma values from '%s'", argv[index]); + return -1; + } + return 1; + }; + + auto on_ref_image_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; + } + ref_image_paths.push_back(argv[index]); + return 1; + }; + + auto on_cache_mode_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; + } + cache_mode = argv_to_utf8(index, argv); + if (cache_mode != "easycache" && cache_mode != "ucache" && + cache_mode != "dbcache" && cache_mode != "taylorseer" && cache_mode != "cache-dit" && cache_mode != "spectrum") { + fprintf(stderr, "error: invalid cache mode '%s', must be 'easycache', 'ucache', 'dbcache', 'taylorseer', 'cache-dit', or 'spectrum'\n", cache_mode.c_str()); + return -1; + } + return 1; + }; + + auto on_cache_option_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; + } + cache_option = 
argv_to_utf8(index, argv); + return 1; + }; + + auto on_scm_mask_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; + } + scm_mask = argv_to_utf8(index, argv); + return 1; + }; + + auto on_scm_policy_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; + } + std::string policy = argv_to_utf8(index, argv); + if (policy == "dynamic") { + scm_policy_dynamic = true; + } else if (policy == "static") { + scm_policy_dynamic = false; + } else { + fprintf(stderr, "error: invalid scm policy '%s', must be 'dynamic' or 'static'\n", policy.c_str()); + return -1; + } + return 1; + }; + + auto on_tile_size_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; + } + std::string tile_size_str = argv[index]; + size_t x_pos = tile_size_str.find('x'); + try { + if (x_pos != std::string::npos) { + std::string tile_x_str = tile_size_str.substr(0, x_pos); + std::string tile_y_str = tile_size_str.substr(x_pos + 1); + vae_tiling_params.tile_size_x = std::stoi(tile_x_str); + vae_tiling_params.tile_size_y = std::stoi(tile_y_str); + } else { + vae_tiling_params.tile_size_x = vae_tiling_params.tile_size_y = std::stoi(tile_size_str); + } + } catch (const std::invalid_argument&) { + return -1; + } catch (const std::out_of_range&) { + return -1; + } + return 1; + }; + + auto on_relative_tile_size_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; + } + std::string rel_size_str = argv[index]; + size_t x_pos = rel_size_str.find('x'); + try { + if (x_pos != std::string::npos) { + std::string rel_x_str = rel_size_str.substr(0, x_pos); + std::string rel_y_str = rel_size_str.substr(x_pos + 1); + vae_tiling_params.rel_size_x = std::stof(rel_x_str); + vae_tiling_params.rel_size_y = std::stof(rel_y_str); + } else { + vae_tiling_params.rel_size_x = vae_tiling_params.rel_size_y = std::stof(rel_size_str); + } + } catch (const std::invalid_argument&) { + return 
-1; + } catch (const std::out_of_range&) { + return -1; + } + return 1; + }; + + options.manual_options = { + {"-s", + "--seed", + "RNG seed (default: 42, use random seed for < 0)", + on_seed_arg}, + {"", + "--sampling-method", + "sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd, res_multistep, res_2s, er_sde] " + "(default: euler for Flux/SD3/Wan, euler_a otherwise)", + on_sample_method_arg}, + {"", + "--high-noise-sampling-method", + "(high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd, res_multistep, res_2s, er_sde]" + " default: euler for Flux/SD3/Wan, euler_a otherwise", + on_high_noise_sample_method_arg}, + {"", + "--scheduler", + "denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, kl_optimal, lcm, bong_tangent], default: discrete", + on_scheduler_arg}, + {"", + "--sigmas", + "custom sigma values for the sampler, comma-separated (e.g., \"14.61,7.8,3.5,0.0\").", + on_sigmas_arg}, + {"", + "--skip-layers", + "layers to skip for SLG steps (default: [7,8,9])", + on_skip_layers_arg}, + {"", + "--high-noise-skip-layers", + "(high noise) layers to skip for SLG steps (default: [7,8,9])", + on_high_noise_skip_layers_arg}, + {"-r", + "--ref-image", + "reference image for Flux Kontext models (can be used multiple times)", + on_ref_image_arg}, + {"", + "--cache-mode", + "caching method: 'easycache' (DiT), 'ucache' (UNET), 'dbcache'/'taylorseer'/'cache-dit' (DiT block-level), 'spectrum' (UNET/DiT Chebyshev+Taylor forecasting)", + on_cache_mode_arg}, + {"", + "--cache-option", + "named cache params (key=value format, comma-separated). easycache/ucache: threshold=,start=,end=,decay=,relative=,reset=; dbcache/taylorseer/cache-dit: Fn=,Bn=,threshold=,warmup=; spectrum: w=,m=,lam=,window=,flex=,warmup=,stop=. 
Examples: \"threshold=0.25\" or \"threshold=1.5,reset=0\"", + on_cache_option_arg}, + {"", + "--scm-mask", + "SCM steps mask for cache-dit: comma-separated 0/1 (e.g., \"1,1,1,0,0,1,0,0,1,0\") - 1=compute, 0=can cache", + on_scm_mask_arg}, + {"", + "--scm-policy", + "SCM policy: 'dynamic' (default) or 'static'", + on_scm_policy_arg}, + {"", + "--vae-tile-size", + "tile size for vae tiling, format [X]x[Y] (default: 32x32)", + on_tile_size_arg}, + {"", + "--vae-relative-tile-size", + "relative tile size for vae tiling, format [X]x[Y], in fraction of image size if < 1, in number of tiles per dim if >=1 (overrides --vae-tile-size)", + on_relative_tile_size_arg}, + + }; + + return options; +} + +static const std::string k_base64_chars = + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz" + "0123456789+/"; + +static bool is_base64(unsigned char c) { + return std::isalnum(c) || c == '+' || c == '/'; +} + +static std::vector decode_base64_bytes(const std::string& encoded_string) { + int in_len = static_cast(encoded_string.size()); + int i = 0; + int j = 0; + int in_ = 0; + uint8_t char_array_4[4]; + uint8_t char_array_3[3]; + std::vector ret; + + while (in_len-- && encoded_string[in_] != '=' && is_base64(encoded_string[in_])) { + char_array_4[i++] = encoded_string[in_]; + in_++; + if (i == 4) { + for (i = 0; i < 4; i++) { + char_array_4[i] = static_cast(k_base64_chars.find(char_array_4[i])); + } + + char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4); + char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); + char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; + + for (i = 0; i < 3; i++) { + ret.push_back(char_array_3[i]); + } + i = 0; + } + } + + if (i) { + for (j = i; j < 4; j++) { + char_array_4[j] = 0; + } + + for (j = 0; j < 4; j++) { + char_array_4[j] = static_cast(k_base64_chars.find(char_array_4[j])); + } + + char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4); + 
char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); + char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; + + for (j = 0; j < i - 1; j++) { + ret.push_back(char_array_3[j]); + } + } + + return ret; +} + +bool decode_base64_image(const std::string& encoded_input, + int target_channels, + int expected_width, + int expected_height, + SDImageOwner& out_image) { + std::string encoded = encoded_input; + auto comma_pos = encoded.find(','); + if (comma_pos != std::string::npos) { + encoded = encoded.substr(comma_pos + 1); + } + + std::vector image_bytes = decode_base64_bytes(encoded); + if (image_bytes.empty()) { + return false; + } + + int decoded_width = 0; + int decoded_height = 0; + uint8_t* raw_data = load_image_from_memory(reinterpret_cast(image_bytes.data()), + static_cast(image_bytes.size()), + decoded_width, + decoded_height, + expected_width, + expected_height, + target_channels); + if (raw_data == nullptr) { + return false; + } + + out_image.reset({(uint32_t)decoded_width, (uint32_t)decoded_height, (uint32_t)target_channels, raw_data}); + return true; +} + +static bool parse_image_json_field(const json& parent, + const char* key, + int channels, + int expected_width, + int expected_height, + SDImageOwner& out_image) { + if (!parent.contains(key)) { + return true; + } + if (parent.at(key).is_null()) { + out_image.reset({0, 0, (uint32_t)channels, nullptr}); + return true; + } + if (!parent.at(key).is_string()) { + return false; + } + return decode_base64_image(parent.at(key).get(), channels, expected_width, expected_height, out_image); +} + +static bool parse_image_array_json_field(const json& parent, + const char* key, + int channels, + int expected_width, + int expected_height, + std::vector& out_images) { + if (!parent.contains(key)) { + return true; + } + if (parent.at(key).is_null()) { + out_images.clear(); + return true; + } + if (!parent.at(key).is_array()) { + return false; + } + + out_images.clear(); + for 
(const auto& item : parent.at(key)) { + if (!item.is_string()) { + return false; + } + SDImageOwner image; + if (!decode_base64_image(item.get(), channels, expected_width, expected_height, image)) { + return false; + } + out_images.push_back(std::move(image)); + } + return true; +} + +static bool parse_lora_json_field(const json& parent, + const std::function& lora_path_resolver, + std::map& lora_map, + std::map& high_noise_lora_map) { + if (!parent.contains("lora")) { + return true; + } + if (!parent.at("lora").is_array()) { + return false; + } + + lora_map.clear(); + high_noise_lora_map.clear(); + for (const auto& item : parent.at("lora")) { + if (!item.is_object()) { + return false; + } + + std::string path = item.value("path", ""); + if (path.empty()) { + return false; + } + + std::string resolved_path = lora_path_resolver ? lora_path_resolver(path) : path; + if (resolved_path.empty()) { + return false; + } + + const float multiplier = item.value("multiplier", 1.0f); + const bool is_high_noise = item.value("is_high_noise", false); + if (is_high_noise) { + high_noise_lora_map[resolved_path] += multiplier; + } else { + lora_map[resolved_path] += multiplier; + } + } + + return true; +} + +static bool resolve_model_file_from_dir(const std::string& model_name, + const std::string& model_dir, + const std::vector& valid_ext, + const char* label, + std::string& resolved_path) { + if (model_dir.empty()) { + LOG_ERROR("%s directory is empty", label); + return false; + } + if (model_name.empty() || + model_name.find('/') != std::string::npos || + model_name.find('\\') != std::string::npos || + fs::path(model_name).has_root_path() || + fs::path(model_name).has_extension()) { + LOG_ERROR("%s must be a model name without path or extension: %s", label, model_name.c_str()); + return false; + } + + fs::path model_dir_path = model_dir; + for (const auto& ext : valid_ext) { + fs::path try_path = model_dir_path / (model_name + ext); + if (fs::exists(try_path) && 
fs::is_regular_file(try_path)) { + resolved_path = try_path.lexically_normal().string(); + return true; + } + } + + LOG_ERROR("can not find %s %s in %s", label, model_name.c_str(), model_dir_path.lexically_normal().string().c_str()); + return false; +} + +bool SDGenerationParams::from_json_str( + const std::string& json_str, + const std::function& lora_path_resolver) { + json j; + try { + j = json::parse(json_str); + } catch (...) { + LOG_ERROR("json parse failed %s", json_str.c_str()); + return false; + } + + auto load_if_exists = [&](const char* key, auto& out) { + if (j.contains(key)) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + if (j[key].is_string()) + out = j[key]; + } else if constexpr (std::is_same_v || std::is_same_v) { + if (j[key].is_number_integer()) + out = j[key]; + } else if constexpr (std::is_same_v) { + if (j[key].is_number()) + out = j[key]; + } else if constexpr (std::is_same_v) { + if (j[key].is_boolean()) + out = j[key]; + } else if constexpr (std::is_same_v>) { + if (j[key].is_array()) + out = j[key].get>(); + } else if constexpr (std::is_same_v>) { + if (j[key].is_array()) + out = j[key].get>(); + } else if constexpr (std::is_same_v>) { + if (j[key].is_array()) + out = j[key].get>(); + } + } + }; + + load_if_exists("prompt", prompt); + load_if_exists("negative_prompt", negative_prompt); + load_if_exists("cache_mode", cache_mode); + load_if_exists("cache_option", cache_option); + load_if_exists("scm_mask", scm_mask); + + load_if_exists("clip_skip", clip_skip); + load_if_exists("width", width); + load_if_exists("height", height); + load_if_exists("batch_count", batch_count); + load_if_exists("video_frames", video_frames); + load_if_exists("fps", fps); + load_if_exists("upscale_repeats", upscale_repeats); + load_if_exists("seed", seed); + + load_if_exists("strength", strength); + load_if_exists("control_strength", control_strength); + load_if_exists("moe_boundary", moe_boundary); + load_if_exists("vace_strength", 
vace_strength); + + load_if_exists("auto_resize_ref_image", auto_resize_ref_image); + load_if_exists("increase_ref_index", increase_ref_index); + load_if_exists("embed_image_metadata", embed_image_metadata); + + if (j.contains("hires") && j["hires"].is_object()) { + const json& hires_json = j["hires"]; + if (hires_json.contains("enabled") && hires_json["enabled"].is_boolean()) { + hires_enabled = hires_json["enabled"]; + } + if (hires_json.contains("upscaler") && hires_json["upscaler"].is_string()) { + hires_upscaler = hires_json["upscaler"]; + } + if (hires_json.contains("scale") && hires_json["scale"].is_number()) { + hires_scale = hires_json["scale"]; + } + if (hires_json.contains("target_width") && hires_json["target_width"].is_number_integer()) { + hires_width = hires_json["target_width"]; + } + if (hires_json.contains("target_height") && hires_json["target_height"].is_number_integer()) { + hires_height = hires_json["target_height"]; + } + if (hires_json.contains("steps") && hires_json["steps"].is_number_integer()) { + hires_steps = hires_json["steps"]; + } + if (hires_json.contains("denoising_strength") && hires_json["denoising_strength"].is_number()) { + hires_denoising_strength = hires_json["denoising_strength"]; + } + if (hires_json.contains("upscale_tile_size") && hires_json["upscale_tile_size"].is_number_integer()) { + hires_upscale_tile_size = hires_json["upscale_tile_size"]; + } + } + + auto parse_sample_params_json = [&](const json& sample_json, + sd_sample_params_t& target_params, + std::vector& target_skip_layers, + std::vector* target_custom_sigmas) { + if (sample_json.contains("sample_steps") && sample_json["sample_steps"].is_number_integer()) { + target_params.sample_steps = sample_json["sample_steps"]; + } + if (sample_json.contains("eta") && sample_json["eta"].is_number()) { + target_params.eta = sample_json["eta"]; + } + if (sample_json.contains("shifted_timestep") && sample_json["shifted_timestep"].is_number_integer()) { + 
target_params.shifted_timestep = sample_json["shifted_timestep"]; + } + if (sample_json.contains("flow_shift") && sample_json["flow_shift"].is_number()) { + target_params.flow_shift = sample_json["flow_shift"]; + } + if (target_custom_sigmas != nullptr && + sample_json.contains("custom_sigmas") && + sample_json["custom_sigmas"].is_array()) { + *target_custom_sigmas = sample_json["custom_sigmas"].get>(); + } + if (sample_json.contains("sample_method") && sample_json["sample_method"].is_string()) { + enum sample_method_t tmp = str_to_sample_method(sample_json["sample_method"].get().c_str()); + if (tmp != SAMPLE_METHOD_COUNT) { + target_params.sample_method = tmp; + } + } + if (sample_json.contains("scheduler") && sample_json["scheduler"].is_string()) { + enum scheduler_t tmp = str_to_scheduler(sample_json["scheduler"].get().c_str()); + if (tmp != SCHEDULER_COUNT) { + target_params.scheduler = tmp; + } + } + if (sample_json.contains("guidance") && sample_json["guidance"].is_object()) { + const json& guidance_json = sample_json["guidance"]; + if (guidance_json.contains("txt_cfg") && guidance_json["txt_cfg"].is_number()) { + target_params.guidance.txt_cfg = guidance_json["txt_cfg"]; + } + if (guidance_json.contains("img_cfg") && guidance_json["img_cfg"].is_number()) { + target_params.guidance.img_cfg = guidance_json["img_cfg"]; + } + if (guidance_json.contains("distilled_guidance") && guidance_json["distilled_guidance"].is_number()) { + target_params.guidance.distilled_guidance = guidance_json["distilled_guidance"]; + } + if (guidance_json.contains("slg") && guidance_json["slg"].is_object()) { + const json& slg_json = guidance_json["slg"]; + if (slg_json.contains("layers") && slg_json["layers"].is_array()) { + target_skip_layers = slg_json["layers"].get>(); + } + if (slg_json.contains("layer_start") && slg_json["layer_start"].is_number()) { + target_params.guidance.slg.layer_start = slg_json["layer_start"]; + } + if (slg_json.contains("layer_end") && 
slg_json["layer_end"].is_number()) { + target_params.guidance.slg.layer_end = slg_json["layer_end"]; + } + if (slg_json.contains("scale") && slg_json["scale"].is_number()) { + target_params.guidance.slg.scale = slg_json["scale"]; + } + } + } + }; + + if (j.contains("sample_params") && j["sample_params"].is_object()) { + parse_sample_params_json(j["sample_params"], sample_params, skip_layers, &custom_sigmas); + } + if (j.contains("high_noise_sample_params") && j["high_noise_sample_params"].is_object()) { + parse_sample_params_json(j["high_noise_sample_params"], + high_noise_sample_params, + high_noise_skip_layers, + nullptr); + } + + if (j.contains("vae_tiling_params") && j["vae_tiling_params"].is_object()) { + const json& tiling_json = j["vae_tiling_params"]; + if (tiling_json.contains("enabled") && tiling_json["enabled"].is_boolean()) { + vae_tiling_params.enabled = tiling_json["enabled"]; + } + if (tiling_json.contains("tile_size_x") && tiling_json["tile_size_x"].is_number_integer()) { + vae_tiling_params.tile_size_x = tiling_json["tile_size_x"]; + } + if (tiling_json.contains("tile_size_y") && tiling_json["tile_size_y"].is_number_integer()) { + vae_tiling_params.tile_size_y = tiling_json["tile_size_y"]; + } + if (tiling_json.contains("target_overlap") && tiling_json["target_overlap"].is_number()) { + vae_tiling_params.target_overlap = tiling_json["target_overlap"]; + } + if (tiling_json.contains("rel_size_x") && tiling_json["rel_size_x"].is_number()) { + vae_tiling_params.rel_size_x = tiling_json["rel_size_x"]; + } + if (tiling_json.contains("rel_size_y") && tiling_json["rel_size_y"].is_number()) { + vae_tiling_params.rel_size_y = tiling_json["rel_size_y"]; + } + } + + if (!parse_lora_json_field(j, lora_path_resolver, lora_map, high_noise_lora_map)) { + LOG_ERROR("invalid lora"); + return false; + } + if (!parse_image_json_field(j, "init_image", 3, width, height, init_image)) { + LOG_ERROR("invalid init_image"); + return false; + } + if 
(!parse_image_json_field(j, "end_image", 3, width, height, end_image)) { + LOG_ERROR("invalid end_image"); + return false; + } + if (!parse_image_array_json_field(j, "ref_images", 3, width, height, ref_images)) { + LOG_ERROR("invalid ref_images"); + return false; + } + if (!parse_image_array_json_field(j, "control_frames", 3, width, height, control_frames)) { + LOG_ERROR("invalid control_frames"); + return false; + } + if (!parse_image_json_field(j, "mask_image", 1, width, height, mask_image)) { + LOG_ERROR("invalid mask_image"); + return false; + } + if (!parse_image_json_field(j, "control_image", 3, width, height, control_image)) { + LOG_ERROR("invalid control_image"); + return false; + } + + return true; +} + +void SDGenerationParams::extract_and_remove_lora(const std::string& lora_model_dir) { + if (lora_model_dir.empty()) { + return; + } + static const std::regex re(R"(]+):([^>]+)>)"); + static const std::vector valid_ext = {".gguf", ".safetensors", ".pt"}; + std::smatch m; + + std::string tmp = prompt; + + while (std::regex_search(tmp, m, re)) { + std::string raw_path = m[1].str(); + const std::string raw_mul = m[2].str(); + + float mul = 0.f; + try { + mul = std::stof(raw_mul); + } catch (...) 
{ + tmp = m.suffix().str(); + prompt = std::regex_replace(prompt, re, "", std::regex_constants::format_first_only); + continue; + } + + bool is_high_noise = false; + static const std::string prefix = "|high_noise|"; + if (raw_path.rfind(prefix, 0) == 0) { + raw_path.erase(0, prefix.size()); + is_high_noise = true; + } + + fs::path final_path; + if (is_absolute_path(raw_path)) { + final_path = raw_path; + } else { + final_path = fs::path(lora_model_dir) / raw_path; + } + if (!fs::exists(final_path)) { + bool found = false; + for (const auto& ext : valid_ext) { + fs::path try_path = final_path; + try_path += ext; + if (fs::exists(try_path)) { + final_path = try_path; + found = true; + break; + } + } + if (!found) { + LOG_WARN("can not found lora %s", final_path.lexically_normal().string().c_str()); + tmp = m.suffix().str(); + prompt = std::regex_replace(prompt, re, "", std::regex_constants::format_first_only); + continue; + } + } + + const std::string key = final_path.lexically_normal().string(); + + if (is_high_noise) + high_noise_lora_map[key] += mul; + else + lora_map[key] += mul; + + prompt = std::regex_replace(prompt, re, "", std::regex_constants::format_first_only); + + tmp = m.suffix().str(); + } +} + +bool SDGenerationParams::width_and_height_are_set() const { + return width > 0 && height > 0; +} + +void SDGenerationParams::set_width_and_height_if_unset(int w, int h) { + if (!width_and_height_are_set()) { + LOG_INFO("set width x height to %d x %d", w, h); + width = w; + height = h; + } +} + +int SDGenerationParams::get_resolved_width() const { + return (width > 0) ? width : 512; +} + +int SDGenerationParams::get_resolved_height() const { + return (height > 0) ? 
height : 512; +} + +bool SDGenerationParams::initialize_cache_params() { + sd_cache_params_init(&cache_params); + + auto parse_named_params = [&](const std::string& opt_str) -> bool { + std::stringstream ss(opt_str); + std::string token; + while (std::getline(ss, token, ',')) { + size_t eq_pos = token.find('='); + if (eq_pos == std::string::npos) { + LOG_ERROR("error: cache option '%s' missing '=' separator", token.c_str()); + return false; + } + std::string key = token.substr(0, eq_pos); + std::string val = token.substr(eq_pos + 1); + try { + if (key == "threshold") { + if (cache_mode == "easycache" || cache_mode == "ucache") { + cache_params.reuse_threshold = std::stof(val); + } else { + cache_params.residual_diff_threshold = std::stof(val); + } + } else if (key == "start") { + cache_params.start_percent = std::stof(val); + } else if (key == "end") { + cache_params.end_percent = std::stof(val); + } else if (key == "decay") { + cache_params.error_decay_rate = std::stof(val); + } else if (key == "relative") { + cache_params.use_relative_threshold = (std::stof(val) != 0.0f); + } else if (key == "reset") { + cache_params.reset_error_on_compute = (std::stof(val) != 0.0f); + } else if (key == "Fn" || key == "fn") { + cache_params.Fn_compute_blocks = std::stoi(val); + } else if (key == "Bn" || key == "bn") { + cache_params.Bn_compute_blocks = std::stoi(val); + } else if (key == "warmup") { + if (cache_mode == "spectrum") { + cache_params.spectrum_warmup_steps = std::stoi(val); + } else { + cache_params.max_warmup_steps = std::stoi(val); + } + } else if (key == "w") { + cache_params.spectrum_w = std::stof(val); + } else if (key == "m") { + cache_params.spectrum_m = std::stoi(val); + } else if (key == "lam") { + cache_params.spectrum_lam = std::stof(val); + } else if (key == "window") { + cache_params.spectrum_window_size = std::stoi(val); + } else if (key == "flex") { + cache_params.spectrum_flex_window = std::stof(val); + } else if (key == "stop") { + 
cache_params.spectrum_stop_percent = std::stof(val); + } else { + LOG_ERROR("error: unknown cache parameter '%s'", key.c_str()); + return false; + } + } catch (const std::exception&) { + LOG_ERROR("error: invalid value '%s' for parameter '%s'", val.c_str(), key.c_str()); + return false; + } + } + return true; + }; + + if (!cache_mode.empty()) { + if (cache_mode == "disabled") { + cache_params.mode = SD_CACHE_DISABLED; + } else if (cache_mode == "easycache") { + cache_params.mode = SD_CACHE_EASYCACHE; + } else if (cache_mode == "ucache") { + cache_params.mode = SD_CACHE_UCACHE; + } else if (cache_mode == "dbcache") { + cache_params.mode = SD_CACHE_DBCACHE; + } else if (cache_mode == "taylorseer") { + cache_params.mode = SD_CACHE_TAYLORSEER; + } else if (cache_mode == "cache-dit") { + cache_params.mode = SD_CACHE_CACHE_DIT; + } else if (cache_mode == "spectrum") { + cache_params.mode = SD_CACHE_SPECTRUM; + } else { + LOG_ERROR("error: invalid cache mode '%s'", cache_mode.c_str()); + return false; + } + } + + if (!cache_option.empty() && !parse_named_params(cache_option)) { + return false; + } + + if (cache_params.mode == SD_CACHE_DBCACHE || + cache_params.mode == SD_CACHE_TAYLORSEER || + cache_params.mode == SD_CACHE_CACHE_DIT) { + cache_params.scm_policy_dynamic = scm_policy_dynamic; + } + + return true; +} + +bool SDGenerationParams::resolve(const std::string& lora_model_dir, const std::string& hires_upscalers_dir, bool strict) { + if (high_noise_sample_params.sample_steps <= 0) { + high_noise_sample_params.sample_steps = -1; + } + + if (!initialize_cache_params()) { + return false; + } + + if (seed < 0) { + srand((int)time(nullptr)); + seed = rand(); + } + + if (strict) { + batch_count = std::clamp(batch_count, 1, 8); + sample_params.sample_steps = std::clamp(sample_params.sample_steps, 1, 100); + } + + hires_upscaler_model_path.clear(); + if (hires_enabled) { + if (hires_upscaler.empty()) { + hires_upscaler = "Latent"; + } + resolved_hires_upscaler = 
str_to_sd_hires_upscaler(hires_upscaler.c_str()); + if (resolved_hires_upscaler == SD_HIRES_UPSCALER_NONE) { + hires_enabled = false; + } else if (resolved_hires_upscaler == SD_HIRES_UPSCALER_COUNT) { + static const std::vector valid_ext = {".gguf", ".safetensors", ".pt", ".pth"}; + if (!resolve_model_file_from_dir(hires_upscaler, + hires_upscalers_dir, + valid_ext, + "hires upscaler", + hires_upscaler_model_path)) { + return false; + } + resolved_hires_upscaler = SD_HIRES_UPSCALER_MODEL; + } + } + + prompt_with_lora = prompt; + if (!lora_model_dir.empty()) { + extract_and_remove_lora(lora_model_dir); + } + return true; +} + +bool SDGenerationParams::validate(SDMode mode) { + if (batch_count <= 0) { + LOG_ERROR("error: batch_count must be greater than 0"); + return false; + } + + if (sample_params.sample_steps <= 0) { + LOG_ERROR("error: the sample_steps must be greater than 0\n"); + return false; + } + + if (strength < 0.f || strength > 1.f) { + LOG_ERROR("error: can only work with strength in [0.0, 1.0]\n"); + return false; + } + + if (sample_params.guidance.txt_cfg < 0.f) { + LOG_ERROR("error: cfg_scale must be positive"); + return false; + } + + if (!cache_mode.empty()) { + if (cache_mode == "easycache" || cache_mode == "ucache") { + if (cache_params.reuse_threshold < 0.0f) { + LOG_ERROR("error: cache threshold must be non-negative"); + return false; + } + if (cache_params.start_percent < 0.0f || cache_params.start_percent >= 1.0f || + cache_params.end_percent <= 0.0f || cache_params.end_percent > 1.0f || + cache_params.start_percent >= cache_params.end_percent) { + LOG_ERROR("error: cache start/end percents must satisfy 0.0 <= start < end <= 1.0"); + return false; + } + } + } + + if (mode == VID_GEN && video_frames <= 0) { + return false; + } + + if (mode == VID_GEN && fps <= 0) { + return false; + } + + if (sample_params.shifted_timestep < 0 || sample_params.shifted_timestep > 1000) { + LOG_ERROR("error: shifted_timestep must be in range [0, 1000]"); + return 
false; + } + + if (upscale_repeats < 1) { + return false; + } + + if (upscale_tile_size < 1) { + return false; + } + + if (hires_enabled) { + if (hires_width < 0 || hires_height < 0) { + LOG_ERROR("error: hires target width and height must be >= 0"); + return false; + } + if (hires_scale <= 0.f && hires_width <= 0 && hires_height <= 0) { + LOG_ERROR("error: hires scale must be positive when target size is not set"); + return false; + } + if (hires_steps < 0) { + LOG_ERROR("error: hires steps must be >= 0"); + return false; + } + if (hires_denoising_strength <= 0.f || hires_denoising_strength > 1.f) { + LOG_ERROR("error: hires denoising strength must be in (0.0, 1.0]"); + return false; + } + if (hires_upscale_tile_size < 1) { + LOG_ERROR("error: hires upscale tile size must be positive"); + return false; + } + } + + if (mode == UPSCALE) { + if (init_image_path.length() == 0) { + LOG_ERROR("error: upscale mode needs an init image (--init-img)\n"); + return false; + } + } + + return true; +} + +bool SDGenerationParams::resolve_and_validate(SDMode mode, + const std::string& lora_model_dir, + const std::string& hires_upscalers_dir, + bool strict) { + if (!resolve(lora_model_dir, hires_upscalers_dir, strict)) { + return false; + } + if (!validate(mode)) { + return false; + } + return true; +} + +sd_img_gen_params_t SDGenerationParams::to_sd_img_gen_params_t() { + sd_img_gen_params_t params; + sd_img_gen_params_init(¶ms); + + lora_vec.clear(); + lora_vec.reserve(lora_map.size() + high_noise_lora_map.size()); + for (const auto& kv : lora_map) { + lora_vec.push_back({false, kv.second, kv.first.c_str()}); + } + for (const auto& kv : high_noise_lora_map) { + lora_vec.push_back({true, kv.second, kv.first.c_str()}); + } + + ref_image_views.clear(); + ref_image_views.reserve(ref_images.size()); + for (auto& ref_image : ref_images) { + ref_image_views.push_back(ref_image.get()); + } + + pm_id_image_views.clear(); + pm_id_image_views.reserve(pm_id_images.size()); + for (auto& 
image : pm_id_images) { + pm_id_image_views.push_back(image.get()); + } + + sample_params.guidance.slg.layers = skip_layers.empty() ? nullptr : skip_layers.data(); + sample_params.guidance.slg.layer_count = skip_layers.size(); + high_noise_sample_params.guidance.slg.layers = high_noise_skip_layers.empty() ? nullptr : high_noise_skip_layers.data(); + high_noise_sample_params.guidance.slg.layer_count = high_noise_skip_layers.size(); + sample_params.custom_sigmas = custom_sigmas.empty() ? nullptr : custom_sigmas.data(); + sample_params.custom_sigmas_count = static_cast(custom_sigmas.size()); + cache_params.scm_mask = scm_mask.empty() ? nullptr : scm_mask.c_str(); + + sd_pm_params_t pm_params = { + pm_id_image_views.empty() ? nullptr : pm_id_image_views.data(), + static_cast(pm_id_image_views.size()), + pm_id_embed_path.empty() ? nullptr : pm_id_embed_path.c_str(), + pm_style_strength, + }; + + params.loras = lora_vec.empty() ? nullptr : lora_vec.data(); + params.lora_count = static_cast(lora_vec.size()); + params.prompt = prompt.c_str(); + params.negative_prompt = negative_prompt.c_str(); + params.clip_skip = clip_skip; + params.init_image = init_image.get(); + params.ref_images = ref_image_views.empty() ? 
nullptr : ref_image_views.data(); + params.ref_images_count = static_cast(ref_image_views.size()); + params.auto_resize_ref_image = auto_resize_ref_image; + params.increase_ref_index = increase_ref_index; + params.mask_image = mask_image.get(); + params.width = get_resolved_width(); + params.height = get_resolved_height(); + params.sample_params = sample_params; + params.strength = strength; + params.seed = seed; + params.batch_count = batch_count; + params.control_image = control_image.get(); + params.control_strength = control_strength; + params.pm_params = pm_params; + params.vae_tiling_params = vae_tiling_params; + params.cache = cache_params; + + params.hires.enabled = hires_enabled; + params.hires.upscaler = resolved_hires_upscaler; + params.hires.model_path = hires_upscaler_model_path.empty() ? nullptr : hires_upscaler_model_path.c_str(); + params.hires.scale = hires_scale; + params.hires.target_width = hires_width; + params.hires.target_height = hires_height; + params.hires.steps = hires_steps; + params.hires.denoising_strength = hires_denoising_strength; + params.hires.upscale_tile_size = hires_upscale_tile_size; + return params; +} + +sd_vid_gen_params_t SDGenerationParams::to_sd_vid_gen_params_t() { + sd_vid_gen_params_t params; + sd_vid_gen_params_init(¶ms); + + lora_vec.clear(); + lora_vec.reserve(lora_map.size() + high_noise_lora_map.size()); + for (const auto& kv : lora_map) { + lora_vec.push_back({false, kv.second, kv.first.c_str()}); + } + for (const auto& kv : high_noise_lora_map) { + lora_vec.push_back({true, kv.second, kv.first.c_str()}); + } + + control_frame_views.clear(); + control_frame_views.reserve(control_frames.size()); + for (auto& frame : control_frames) { + control_frame_views.push_back(frame.get()); + } + + sample_params.guidance.slg.layers = skip_layers.empty() ? 
nullptr : skip_layers.data(); + sample_params.guidance.slg.layer_count = skip_layers.size(); + high_noise_sample_params.guidance.slg.layers = high_noise_skip_layers.empty() ? nullptr : high_noise_skip_layers.data(); + high_noise_sample_params.guidance.slg.layer_count = high_noise_skip_layers.size(); + sample_params.custom_sigmas = custom_sigmas.empty() ? nullptr : custom_sigmas.data(); + sample_params.custom_sigmas_count = static_cast(custom_sigmas.size()); + cache_params.scm_mask = scm_mask.empty() ? nullptr : scm_mask.c_str(); + + params.loras = lora_vec.empty() ? nullptr : lora_vec.data(); + params.lora_count = static_cast(lora_vec.size()); + params.prompt = prompt.c_str(); + params.negative_prompt = negative_prompt.c_str(); + params.clip_skip = clip_skip; + params.init_image = init_image.get(); + params.end_image = end_image.get(); + params.control_frames = control_frame_views.empty() ? nullptr : control_frame_views.data(); + params.control_frames_size = static_cast(control_frame_views.size()); + params.width = get_resolved_width(); + params.height = get_resolved_height(); + params.sample_params = sample_params; + params.high_noise_sample_params = high_noise_sample_params; + params.moe_boundary = moe_boundary; + params.strength = strength; + params.seed = seed; + params.video_frames = video_frames; + params.vace_strength = vace_strength; + params.vae_tiling_params = vae_tiling_params; + params.cache = cache_params; + return params; +} + +std::string SDGenerationParams::to_string() const { + FreeUniquePtr sample_params_str(sd_sample_params_to_str(&sample_params)); + FreeUniquePtr high_noise_sample_params_str(sd_sample_params_to_str(&high_noise_sample_params)); + + std::ostringstream lora_ss; + lora_ss << "{\n"; + for (auto it = lora_map.begin(); it != lora_map.end(); ++it) { + lora_ss << " \"" << it->first << "\": \"" << it->second << "\""; + if (std::next(it) != lora_map.end()) { + lora_ss << ","; + } + lora_ss << "\n"; + } + lora_ss << " }"; + std::string 
loras_str = lora_ss.str(); + + lora_ss = std::ostringstream(); + ; + lora_ss << "{\n"; + for (auto it = high_noise_lora_map.begin(); it != high_noise_lora_map.end(); ++it) { + lora_ss << " \"" << it->first << "\": \"" << it->second << "\""; + if (std::next(it) != high_noise_lora_map.end()) { + lora_ss << ","; + } + lora_ss << "\n"; + } + lora_ss << " }"; + std::string high_noise_loras_str = lora_ss.str(); + + std::ostringstream oss; + oss << "SDGenerationParams {\n" + << " loras: \"" << loras_str << "\",\n" + << " high_noise_loras: \"" << high_noise_loras_str << "\",\n" + << " prompt: \"" << prompt << "\",\n" + << " negative_prompt: \"" << negative_prompt << "\",\n" + << " clip_skip: " << clip_skip << ",\n" + << " width: " << width << ",\n" + << " height: " << height << ",\n" + << " batch_count: " << batch_count << ",\n" + << " init_image_path: \"" << init_image_path << "\",\n" + << " end_image_path: \"" << end_image_path << "\",\n" + << " mask_image_path: \"" << mask_image_path << "\",\n" + << " control_image_path: \"" << control_image_path << "\",\n" + << " ref_image_paths: " << vec_str_to_string(ref_image_paths) << ",\n" + << " control_video_path: \"" << control_video_path << "\",\n" + << " auto_resize_ref_image: " << (auto_resize_ref_image ? "true" : "false") << ",\n" + << " increase_ref_index: " << (increase_ref_index ? 
"true" : "false") << ",\n" + << " pm_id_images_dir: \"" << pm_id_images_dir << "\",\n" + << " pm_id_embed_path: \"" << pm_id_embed_path << "\",\n" + << " pm_style_strength: " << pm_style_strength << ",\n" + << " skip_layers: " << vec_to_string(skip_layers) << ",\n" + << " sample_params: " << SAFE_STR(sample_params_str.get()) << ",\n" + << " high_noise_skip_layers: " << vec_to_string(high_noise_skip_layers) << ",\n" + << " high_noise_sample_params: " << SAFE_STR(high_noise_sample_params_str.get()) << ",\n" + << " custom_sigmas: " << vec_to_string(custom_sigmas) << ",\n" + << " cache_mode: \"" << cache_mode << "\",\n" + << " cache_option: \"" << cache_option << "\",\n" + << " cache: " + << (cache_params.mode != SD_CACHE_DISABLED ? "enabled" : "disabled") + << " (threshold=" << cache_params.reuse_threshold + << ", start=" << cache_params.start_percent + << ", end=" << cache_params.end_percent << "),\n" + << " moe_boundary: " << moe_boundary << ",\n" + << " video_frames: " << video_frames << ",\n" + << " fps: " << fps << ",\n" + << " vace_strength: " << vace_strength << ",\n" + << " strength: " << strength << ",\n" + << " control_strength: " << control_strength << ",\n" + << " seed: " << seed << ",\n" + << " upscale_repeats: " << upscale_repeats << ",\n" + << " upscale_tile_size: " << upscale_tile_size << ",\n" + << " hires: { enabled: " << (hires_enabled ? 
"true" : "false") + << ", upscaler: \"" << hires_upscaler << "\"" + << ", model_path: \"" << hires_upscaler_model_path << "\"" + << ", scale: " << hires_scale + << ", target_width: " << hires_width + << ", target_height: " << hires_height + << ", steps: " << hires_steps + << ", denoising_strength: " << hires_denoising_strength + << ", upscale_tile_size: " << hires_upscale_tile_size << " },\n" + << " vae_tiling_params: { " + << vae_tiling_params.enabled << ", " + << vae_tiling_params.tile_size_x << ", " + << vae_tiling_params.tile_size_y << ", " + << vae_tiling_params.target_overlap << ", " + << vae_tiling_params.rel_size_x << ", " + << vae_tiling_params.rel_size_y << " },\n" + << "}"; + return oss.str(); +} + +std::string version_string() { + return std::string("stable-diffusion.cpp version ") + sd_version() + ", commit " + sd_commit(); +} + +static std::string safe_json_string(const char* value) { + return value ? value : ""; +} + +static void set_json_basename_if_not_empty(json& target, const char* key, const std::string& path) { + if (!path.empty()) { + target[key] = sd_basename(path); + } +} + +static json build_sampling_metadata_json(const sd_sample_params_t& sample_params, + const std::vector& skip_layers, + const std::vector* custom_sigmas = nullptr) { + json sampling = { + {"steps", sample_params.sample_steps}, + {"eta", sample_params.eta}, + {"shifted_timestep", sample_params.shifted_timestep}, + {"flow_shift", sample_params.flow_shift}, + {"guidance", + { + {"txt_cfg", sample_params.guidance.txt_cfg}, + {"img_cfg", sample_params.guidance.img_cfg}, + {"distilled_guidance", sample_params.guidance.distilled_guidance}, + {"slg", + { + {"scale", sample_params.guidance.slg.scale}, + {"layers", skip_layers}, + {"start", sample_params.guidance.slg.layer_start}, + {"end", sample_params.guidance.slg.layer_end}, + }}, + }}, + }; + if (sample_params.sample_method != SAMPLE_METHOD_COUNT) { + sampling["method"] = 
safe_json_string(sd_sample_method_name(sample_params.sample_method)); + } + if (sample_params.scheduler != SCHEDULER_COUNT) { + sampling["scheduler"] = safe_json_string(sd_scheduler_name(sample_params.scheduler)); + } + if (custom_sigmas != nullptr) { + sampling["custom_sigmas"] = *custom_sigmas; + } + return sampling; +} + +std::string build_sdcpp_image_metadata_json(const SDContextParams& ctx_params, + const SDGenerationParams& gen_params, + int64_t seed, + SDMode mode) { + json root; + root["schema"] = "sdcpp.image.params/v1"; + root["mode"] = mode == VID_GEN ? "vid_gen" : "img_gen"; + root["generator"] = { + {"name", "stable-diffusion.cpp"}, + {"version", safe_json_string(sd_version())}, + {"commit", safe_json_string(sd_commit())}, + }; + root["seed"] = seed; + root["width"] = gen_params.get_resolved_width(); + root["height"] = gen_params.get_resolved_height(); + + root["prompt"] = { + {"positive", gen_params.prompt}, + {"negative", gen_params.negative_prompt}, + }; + root["sampling"] = build_sampling_metadata_json(gen_params.sample_params, + gen_params.skip_layers, + &gen_params.custom_sigmas); + + json models; + set_json_basename_if_not_empty(models, "model", ctx_params.model_path); + set_json_basename_if_not_empty(models, "clip_l", ctx_params.clip_l_path); + set_json_basename_if_not_empty(models, "clip_g", ctx_params.clip_g_path); + set_json_basename_if_not_empty(models, "clip_vision", ctx_params.clip_vision_path); + set_json_basename_if_not_empty(models, "t5xxl", ctx_params.t5xxl_path); + set_json_basename_if_not_empty(models, "llm", ctx_params.llm_path); + set_json_basename_if_not_empty(models, "llm_vision", ctx_params.llm_vision_path); + set_json_basename_if_not_empty(models, "diffusion_model", ctx_params.diffusion_model_path); + set_json_basename_if_not_empty(models, "high_noise_diffusion_model", ctx_params.high_noise_diffusion_model_path); + set_json_basename_if_not_empty(models, "vae", ctx_params.vae_path); + set_json_basename_if_not_empty(models, 
"taesd", ctx_params.taesd_path); + set_json_basename_if_not_empty(models, "control_net", ctx_params.control_net_path); + root["models"] = std::move(models); + + root["clip_skip"] = gen_params.clip_skip; + root["strength"] = gen_params.strength; + root["control_strength"] = gen_params.control_strength; + root["auto_resize_ref_image"] = gen_params.auto_resize_ref_image; + root["increase_ref_index"] = gen_params.increase_ref_index; + if (mode == VID_GEN) { + root["video"] = { + {"frame_count", gen_params.video_frames}, + {"fps", gen_params.fps}, + }; + root["moe_boundary"] = gen_params.moe_boundary; + root["vace_strength"] = gen_params.vace_strength; + root["high_noise_sampling"] = build_sampling_metadata_json(gen_params.high_noise_sample_params, + gen_params.high_noise_skip_layers); + } + + root["rng"] = safe_json_string(sd_rng_type_name(ctx_params.rng_type)); + if (ctx_params.sampler_rng_type != RNG_TYPE_COUNT) { + root["sampler_rng"] = safe_json_string(sd_rng_type_name(ctx_params.sampler_rng_type)); + } + + json loras = json::array(); + for (const auto& entry : gen_params.lora_map) { + loras.push_back({ + {"name", sd_basename(entry.first)}, + {"multiplier", entry.second}, + {"is_high_noise", false}, + }); + } + for (const auto& entry : gen_params.high_noise_lora_map) { + loras.push_back({ + {"name", sd_basename(entry.first)}, + {"multiplier", entry.second}, + {"is_high_noise", true}, + }); + } + if (!loras.empty()) { + root["loras"] = std::move(loras); + } + + if (gen_params.hires_enabled) { + root["hires"] = { + {"enabled", gen_params.hires_enabled}, + {"upscaler", gen_params.hires_upscaler}, + {"model", gen_params.hires_upscaler_model_path.empty() ? 
"" : sd_basename(gen_params.hires_upscaler_model_path)}, + {"scale", gen_params.hires_scale}, + {"target_width", gen_params.hires_width}, + {"target_height", gen_params.hires_height}, + {"steps", gen_params.hires_steps}, + {"denoising_strength", gen_params.hires_denoising_strength}, + {"upscale_tile_size", gen_params.hires_upscale_tile_size}, + }; + } + + if (gen_params.cache_params.mode != SD_CACHE_DISABLED) { + root["cache"] = { + {"requested_mode", gen_params.cache_mode}, + {"requested_option", gen_params.cache_option}, + {"mode", gen_params.cache_params.mode}, + {"scm_mask", gen_params.scm_mask}, + {"scm_policy_dynamic", gen_params.scm_policy_dynamic}, + {"reuse_threshold", gen_params.cache_params.reuse_threshold}, + {"start_percent", gen_params.cache_params.start_percent}, + {"end_percent", gen_params.cache_params.end_percent}, + {"error_decay_rate", gen_params.cache_params.error_decay_rate}, + {"use_relative_threshold", gen_params.cache_params.use_relative_threshold}, + {"reset_error_on_compute", gen_params.cache_params.reset_error_on_compute}, + {"Fn_compute_blocks", gen_params.cache_params.Fn_compute_blocks}, + {"Bn_compute_blocks", gen_params.cache_params.Bn_compute_blocks}, + {"residual_diff_threshold", gen_params.cache_params.residual_diff_threshold}, + {"max_warmup_steps", gen_params.cache_params.max_warmup_steps}, + {"max_cached_steps", gen_params.cache_params.max_cached_steps}, + {"max_continuous_cached_steps", gen_params.cache_params.max_continuous_cached_steps}, + {"taylorseer_n_derivatives", gen_params.cache_params.taylorseer_n_derivatives}, + {"taylorseer_skip_interval", gen_params.cache_params.taylorseer_skip_interval}, + {"spectrum_w", gen_params.cache_params.spectrum_w}, + {"spectrum_m", gen_params.cache_params.spectrum_m}, + {"spectrum_lam", gen_params.cache_params.spectrum_lam}, + {"spectrum_window_size", gen_params.cache_params.spectrum_window_size}, + {"spectrum_flex_window", gen_params.cache_params.spectrum_flex_window}, + 
{"spectrum_warmup_steps", gen_params.cache_params.spectrum_warmup_steps}, + {"spectrum_stop_percent", gen_params.cache_params.spectrum_stop_percent}, + }; + } + + if (gen_params.vae_tiling_params.enabled) { + root["vae_tiling"] = { + {"enabled", gen_params.vae_tiling_params.enabled}, + {"tile_size_x", gen_params.vae_tiling_params.tile_size_x}, + {"tile_size_y", gen_params.vae_tiling_params.tile_size_y}, + {"target_overlap", gen_params.vae_tiling_params.target_overlap}, + {"rel_size_x", gen_params.vae_tiling_params.rel_size_x}, + {"rel_size_y", gen_params.vae_tiling_params.rel_size_y}, + }; + } + + return root.dump(); +} + +std::string get_image_params(const SDContextParams& ctx_params, + const SDGenerationParams& gen_params, + int64_t seed, + SDMode mode) { + std::string parameter_string; + if (gen_params.prompt_with_lora.size() != 0) { + parameter_string += gen_params.prompt_with_lora + "\n"; + } else { + parameter_string += gen_params.prompt + "\n"; + } + if (gen_params.negative_prompt.size() != 0) { + parameter_string += "Negative prompt: " + gen_params.negative_prompt + "\n"; + } + parameter_string += "Steps: " + std::to_string(gen_params.sample_params.sample_steps) + ", "; + parameter_string += "CFG scale: " + std::to_string(gen_params.sample_params.guidance.txt_cfg) + ", "; + if (gen_params.sample_params.guidance.slg.scale != 0 && gen_params.skip_layers.size() != 0) { + parameter_string += "SLG scale: " + std::to_string(gen_params.sample_params.guidance.slg.scale) + ", "; + parameter_string += "Skip layers: ["; + for (const auto& layer : gen_params.skip_layers) { + parameter_string += std::to_string(layer) + ", "; + } + parameter_string += "], "; + parameter_string += "Skip layer start: " + std::to_string(gen_params.sample_params.guidance.slg.layer_start) + ", "; + parameter_string += "Skip layer end: " + std::to_string(gen_params.sample_params.guidance.slg.layer_end) + ", "; + } + parameter_string += "Guidance: " + 
std::to_string(gen_params.sample_params.guidance.distilled_guidance) + ", "; + parameter_string += "Eta: " + std::to_string(gen_params.sample_params.eta) + ", "; + parameter_string += "Seed: " + std::to_string(seed) + ", "; + parameter_string += "Size: " + std::to_string(gen_params.get_resolved_width()) + "x" + std::to_string(gen_params.get_resolved_height()) + ", "; + parameter_string += "Model: " + sd_basename(ctx_params.model_path) + ", "; + parameter_string += "RNG: " + std::string(sd_rng_type_name(ctx_params.rng_type)) + ", "; + if (ctx_params.sampler_rng_type != RNG_TYPE_COUNT) { + parameter_string += "Sampler RNG: " + std::string(sd_rng_type_name(ctx_params.sampler_rng_type)) + ", "; + } + parameter_string += "Sampler: " + std::string(sd_sample_method_name(gen_params.sample_params.sample_method)); + if (!gen_params.custom_sigmas.empty()) { + parameter_string += ", Custom Sigmas: ["; + for (size_t i = 0; i < gen_params.custom_sigmas.size(); ++i) { + std::ostringstream oss; + oss << std::fixed << std::setprecision(4) << gen_params.custom_sigmas[i]; + parameter_string += oss.str() + (i == gen_params.custom_sigmas.size() - 1 ? 
"" : ", "); + } + parameter_string += "]"; + } else if (gen_params.sample_params.scheduler != SCHEDULER_COUNT) { // Only show schedule if not using custom sigmas + parameter_string += " " + std::string(sd_scheduler_name(gen_params.sample_params.scheduler)); + } + parameter_string += ", "; + for (const auto& te : {ctx_params.clip_l_path, ctx_params.clip_g_path, ctx_params.t5xxl_path, ctx_params.llm_path, ctx_params.llm_vision_path}) { + if (!te.empty()) { + parameter_string += "TE: " + sd_basename(te) + ", "; + } + } + if (!ctx_params.diffusion_model_path.empty()) { + parameter_string += "Unet: " + sd_basename(ctx_params.diffusion_model_path) + ", "; + } + if (!ctx_params.vae_path.empty()) { + parameter_string += "VAE: " + sd_basename(ctx_params.vae_path) + ", "; + } + if (gen_params.clip_skip != -1) { + parameter_string += "Clip skip: " + std::to_string(gen_params.clip_skip) + ", "; + } + if (gen_params.hires_enabled) { + parameter_string += "Hires upscale: " + gen_params.hires_upscaler + ", "; + parameter_string += "Hires scale: " + std::to_string(gen_params.hires_scale) + ", "; + parameter_string += "Hires resize: " + std::to_string(gen_params.hires_width) + "x" + std::to_string(gen_params.hires_height) + ", "; + parameter_string += "Hires steps: " + std::to_string(gen_params.hires_steps) + ", "; + parameter_string += "Denoising strength: " + std::to_string(gen_params.hires_denoising_strength) + ", "; + } + parameter_string += "Version: stable-diffusion.cpp"; + parameter_string += ", SDCPP: " + build_sdcpp_image_metadata_json(ctx_params, gen_params, seed, mode); + return parameter_string; +} diff --git a/examples/common/common.h b/examples/common/common.h new file mode 100644 index 000000000..f87293f3e --- /dev/null +++ b/examples/common/common.h @@ -0,0 +1,262 @@ +#ifndef __EXAMPLES_COMMON_COMMON_H__ +#define __EXAMPLES_COMMON_COMMON_H__ + +#include +#include +#include +#include +#include +#include + +#include "log.h" +#include "resource_owners.hpp" +#include 
"stable-diffusion.h" + +#define SAFE_STR(s) ((s) ? (s) : "") +#define BOOL_STR(b) ((b) ? "true" : "false") + +extern const char* const modes_str[]; +#define SD_ALL_MODES_STR "img_gen, vid_gen, convert, upscale, metadata" + +enum SDMode { + IMG_GEN, + VID_GEN, + CONVERT, + UPSCALE, + METADATA, + MODE_COUNT +}; + +struct StringOption { + std::string short_name; + std::string long_name; + std::string desc; + std::string* target; +}; + +struct IntOption { + std::string short_name; + std::string long_name; + std::string desc; + int* target; +}; + +struct FloatOption { + std::string short_name; + std::string long_name; + std::string desc; + float* target; +}; + +struct BoolOption { + std::string short_name; + std::string long_name; + std::string desc; + bool keep_true; + bool* target; +}; + +struct ManualOption { + std::string short_name; + std::string long_name; + std::string desc; + std::function cb; +}; + +struct ArgOptions { + std::vector string_options; + std::vector int_options; + std::vector float_options; + std::vector bool_options; + std::vector manual_options; + + static std::string wrap_text(const std::string& text, size_t width, size_t indent); + void print() const; +}; + +bool parse_options(int argc, const char** argv, const std::vector& options_list); +bool decode_base64_image(const std::string& encoded_input, + int target_channels, + int expected_width, + int expected_height, + SDImageOwner& out_image); + +struct SDContextParams { + int n_threads = -1; + std::string model_path; + std::string clip_l_path; + std::string clip_g_path; + std::string clip_vision_path; + std::string t5xxl_path; + std::string llm_path; + std::string llm_vision_path; + std::string diffusion_model_path; + std::string high_noise_diffusion_model_path; + std::string vae_path; + std::string taesd_path; + std::string esrgan_path; + std::string control_net_path; + std::string embedding_dir; + std::string photo_maker_path; + sd_type_t wtype = SD_TYPE_COUNT; + std::string tensor_type_rules; 
+ std::string lora_model_dir = "."; + std::string hires_upscalers_dir; + + std::map embedding_map; + std::vector embedding_vec; + + rng_type_t rng_type = CUDA_RNG; + rng_type_t sampler_rng_type = RNG_TYPE_COUNT; + bool offload_params_to_cpu = false; + float max_vram = 0.f; + bool enable_mmap = false; + bool control_net_cpu = false; + bool clip_on_cpu = false; + bool vae_on_cpu = false; + bool flash_attn = false; + bool diffusion_flash_attn = false; + bool diffusion_conv_direct = false; + bool vae_conv_direct = false; + + bool circular = false; + bool circular_x = false; + bool circular_y = false; + + bool chroma_use_dit_mask = true; + bool chroma_use_t5_mask = false; + int chroma_t5_mask_pad = 1; + + bool qwen_image_zero_cond_t = false; + + prediction_t prediction = PREDICTION_COUNT; + lora_apply_mode_t lora_apply_mode = LORA_APPLY_AUTO; + + bool force_sdxl_vae_conv_scale = false; + + float flow_shift = INFINITY; + ArgOptions get_options(); + void build_embedding_map(); + bool resolve(SDMode mode); + bool validate(SDMode mode); + bool resolve_and_validate(SDMode mode); + std::string to_string() const; + sd_ctx_params_t to_sd_ctx_params_t(bool vae_decode_only, bool free_params_immediately, bool taesd_preview); +}; + +struct SDGenerationParams { + // User-facing input fields. 
+ std::string prompt; + std::string negative_prompt; + int clip_skip = -1; // <= 0 represents unspecified + int width = -1; + int height = -1; + int batch_count = 1; + int64_t seed = 42; + float strength = 0.75f; + float control_strength = 0.9f; + bool auto_resize_ref_image = true; + bool increase_ref_index = false; + bool embed_image_metadata = true; + + std::string init_image_path; + std::string end_image_path; + std::string mask_image_path; + std::string control_image_path; + std::vector ref_image_paths; + std::string control_video_path; + + sd_sample_params_t sample_params; + sd_sample_params_t high_noise_sample_params; + std::vector skip_layers = {7, 8, 9}; + std::vector high_noise_skip_layers = {7, 8, 9}; + + std::vector custom_sigmas; + + std::string cache_mode; + std::string cache_option; + std::string scm_mask; + bool scm_policy_dynamic = true; + sd_cache_params_t cache_params{}; + + float moe_boundary = 0.875f; + int video_frames = 1; + int fps = 16; + float vace_strength = 1.f; + sd_tiling_params_t vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f}; + + std::string pm_id_images_dir; + std::string pm_id_embed_path; + float pm_style_strength = 20.f; + + int upscale_repeats = 1; + int upscale_tile_size = 128; + + bool hires_enabled = false; + std::string hires_upscaler = "Latent"; + std::string hires_upscaler_model_path; + float hires_scale = 2.f; + int hires_width = 0; + int hires_height = 0; + int hires_steps = 0; + float hires_denoising_strength = 0.7f; + int hires_upscale_tile_size = 128; + + std::map lora_map; + std::map high_noise_lora_map; + + // Derived and normalized fields. + std::string prompt_with_lora; // for metadata record only + std::vector lora_vec; + sd_hires_upscaler_t resolved_hires_upscaler; + + // Owned execution payload. 
+ SDImageOwner init_image; + SDImageOwner end_image; + std::vector ref_images; + SDImageOwner mask_image; + SDImageOwner control_image; + std::vector pm_id_images; + std::vector control_frames; + + // Backing storage for sd_img_gen_params_t view fields. + std::vector ref_image_views; + std::vector pm_id_image_views; + std::vector control_frame_views; + + SDGenerationParams(); + SDGenerationParams(const SDGenerationParams& other) = default; + SDGenerationParams& operator=(const SDGenerationParams& other) = default; + SDGenerationParams(SDGenerationParams&& other) noexcept = default; + SDGenerationParams& operator=(SDGenerationParams&& other) noexcept = default; + ArgOptions get_options(); + bool from_json_str(const std::string& json_str, + const std::function& lora_path_resolver = {}); + bool initialize_cache_params(); + void extract_and_remove_lora(const std::string& lora_model_dir); + bool width_and_height_are_set() const; + void set_width_and_height_if_unset(int w, int h); + int get_resolved_width() const; + int get_resolved_height() const; + bool resolve(const std::string& lora_model_dir, const std::string& hires_upscalers_dir, bool strict = false); + bool validate(SDMode mode); + bool resolve_and_validate(SDMode mode, + const std::string& lora_model_dir, + const std::string& hires_upscalers_dir, + bool strict = false); + sd_img_gen_params_t to_sd_img_gen_params_t(); + sd_vid_gen_params_t to_sd_vid_gen_params_t(); + std::string to_string() const; +}; + +std::string version_string(); +std::string build_sdcpp_image_metadata_json(const SDContextParams& ctx_params, + const SDGenerationParams& gen_params, + int64_t seed, + SDMode mode = IMG_GEN); +std::string get_image_params(const SDContextParams& ctx_params, + const SDGenerationParams& gen_params, + int64_t seed, + SDMode mode = IMG_GEN); + +#endif // __EXAMPLES_COMMON_COMMON_H__ diff --git a/examples/common/common.hpp b/examples/common/common.hpp deleted file mode 100644 index 7beef9d58..000000000 --- 
a/examples/common/common.hpp +++ /dev/null @@ -1,1902 +0,0 @@ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -using json = nlohmann::json; -namespace fs = std::filesystem; - -#if defined(_WIN32) -#define NOMINMAX -#include -#endif // _WIN32 - -#include "log.h" -#include "stable-diffusion.h" - -#define SAFE_STR(s) ((s) ? (s) : "") -#define BOOL_STR(b) ((b) ? "true" : "false") - -const char* modes_str[] = { - "img_gen", - "vid_gen", - "convert", - "upscale", - "metadata", -}; -#define SD_ALL_MODES_STR "img_gen, vid_gen, convert, upscale, metadata" - -enum SDMode { - IMG_GEN, - VID_GEN, - CONVERT, - UPSCALE, - METADATA, - MODE_COUNT -}; - -#if defined(_WIN32) -static std::string utf16_to_utf8(const std::wstring& wstr) { - if (wstr.empty()) - return {}; - int size_needed = WideCharToMultiByte(CP_UTF8, 0, wstr.data(), (int)wstr.size(), - nullptr, 0, nullptr, nullptr); - if (size_needed <= 0) - throw std::runtime_error("UTF-16 to UTF-8 conversion failed"); - - std::string utf8(size_needed, 0); - WideCharToMultiByte(CP_UTF8, 0, wstr.data(), (int)wstr.size(), - (char*)utf8.data(), size_needed, nullptr, nullptr); - return utf8; -} - -static std::string argv_to_utf8(int index, const char** argv) { - int argc; - wchar_t** argv_w = CommandLineToArgvW(GetCommandLineW(), &argc); - if (!argv_w) - throw std::runtime_error("Failed to parse command line"); - - std::string result; - if (index < argc) { - result = utf16_to_utf8(argv_w[index]); - } - LocalFree(argv_w); - return result; -} - -#else // Linux / macOS -static std::string argv_to_utf8(int index, const char** argv) { - return std::string(argv[index]); -} - -#endif - -struct StringOption { - std::string short_name; - std::string long_name; - std::string desc; - std::string* target; -}; - -struct IntOption { - std::string short_name; - std::string long_name; - std::string desc; - int* target; -}; - -struct FloatOption { - std::string short_name; - std::string 
long_name; - std::string desc; - float* target; -}; - -struct BoolOption { - std::string short_name; - std::string long_name; - std::string desc; - bool keep_true; - bool* target; -}; - -struct ManualOption { - std::string short_name; - std::string long_name; - std::string desc; - std::function cb; -}; - -struct ArgOptions { - std::vector string_options; - std::vector int_options; - std::vector float_options; - std::vector bool_options; - std::vector manual_options; - - static std::string wrap_text(const std::string& text, size_t width, size_t indent) { - std::ostringstream oss; - size_t line_len = 0; - size_t pos = 0; - - while (pos < text.size()) { - // Preserve manual newlines - if (text[pos] == '\n') { - oss << '\n' - << std::string(indent, ' '); - line_len = indent; - ++pos; - continue; - } - - // Add the character - oss << text[pos]; - ++line_len; - ++pos; - - // If the current line exceeds width, try to break at the last space - if (line_len >= width) { - std::string current = oss.str(); - size_t back = current.size(); - - // Find the last space (for a clean break) - while (back > 0 && current[back - 1] != ' ' && current[back - 1] != '\n') - --back; - - // If found a space to break on - if (back > 0 && current[back - 1] != '\n') { - std::string before = current.substr(0, back - 1); - std::string after = current.substr(back); - oss.str(""); - oss.clear(); - oss << before << "\n" - << std::string(indent, ' ') << after; - } else { - // If no space found, just break at width - oss << "\n" - << std::string(indent, ' '); - } - line_len = indent; - } - } - - return oss.str(); - } - - void print() const { - constexpr size_t max_line_width = 120; - - struct Entry { - std::string names; - std::string desc; - }; - std::vector entries; - - auto add_entry = [&](const std::string& s, const std::string& l, - const std::string& desc, const std::string& hint = "") { - std::ostringstream ss; - if (!s.empty()) - ss << s; - if (!s.empty() && !l.empty()) - ss << ", "; - if 
(!l.empty()) - ss << l; - if (!hint.empty()) - ss << " " << hint; - entries.push_back({ss.str(), desc}); - }; - - for (auto& o : string_options) - add_entry(o.short_name, o.long_name, o.desc, ""); - for (auto& o : int_options) - add_entry(o.short_name, o.long_name, o.desc, ""); - for (auto& o : float_options) - add_entry(o.short_name, o.long_name, o.desc, ""); - for (auto& o : bool_options) - add_entry(o.short_name, o.long_name, o.desc, ""); - for (auto& o : manual_options) - add_entry(o.short_name, o.long_name, o.desc); - - size_t max_name_width = 0; - for (auto& e : entries) - max_name_width = std::max(max_name_width, e.names.size()); - - for (auto& e : entries) { - size_t indent = 2 + max_name_width + 4; - size_t desc_width = (max_line_width > indent ? max_line_width - indent : 40); - std::string wrapped_desc = wrap_text(e.desc, max_line_width, indent); - std::cout << " " << std::left << std::setw(static_cast(max_name_width) + 4) - << e.names << wrapped_desc << "\n"; - } - } -}; - -static bool parse_options(int argc, const char** argv, const std::vector& options_list) { - bool invalid_arg = false; - std::string arg; - - auto match_and_apply = [&](auto& opts, auto&& apply_fn) -> bool { - for (auto& option : opts) { - if ((option.short_name.size() > 0 && arg == option.short_name) || - (option.long_name.size() > 0 && arg == option.long_name)) { - apply_fn(option); - return true; - } - } - return false; - }; - - for (int i = 1; i < argc; i++) { - arg = argv[i]; - bool found_arg = false; - - for (auto& options : options_list) { - if (match_and_apply(options.string_options, [&](auto& option) { - if (++i >= argc) { - invalid_arg = true; - return; - } - *option.target = argv_to_utf8(i, argv); - found_arg = true; - })) - break; - - if (match_and_apply(options.int_options, [&](auto& option) { - if (++i >= argc) { - invalid_arg = true; - return; - } - *option.target = std::stoi(argv[i]); - found_arg = true; - })) - break; - - if (match_and_apply(options.float_options, 
[&](auto& option) { - if (++i >= argc) { - invalid_arg = true; - return; - } - *option.target = std::stof(argv[i]); - found_arg = true; - })) - break; - - if (match_and_apply(options.bool_options, [&](auto& option) { - *option.target = option.keep_true ? true : false; - found_arg = true; - })) - break; - - if (match_and_apply(options.manual_options, [&](auto& option) { - int ret = option.cb(argc, argv, i); - if (ret < 0) { - invalid_arg = true; - return; - } - i += ret; - found_arg = true; - })) - break; - } - - if (invalid_arg) { - LOG_ERROR("error: invalid parameter for argument: %s", arg.c_str()); - return false; - } - if (!found_arg) { - LOG_ERROR("error: unknown argument: %s", arg.c_str()); - return false; - } - } - - return true; -} - -struct SDContextParams { - int n_threads = -1; - std::string model_path; - std::string clip_l_path; - std::string clip_g_path; - std::string clip_vision_path; - std::string t5xxl_path; - std::string llm_path; - std::string llm_vision_path; - std::string diffusion_model_path; - std::string high_noise_diffusion_model_path; - std::string vae_path; - std::string taesd_path; - std::string esrgan_path; - std::string control_net_path; - std::string embedding_dir; - std::string photo_maker_path; - sd_type_t wtype = SD_TYPE_COUNT; - std::string tensor_type_rules; - std::string lora_model_dir = "."; - - std::map embedding_map; - std::vector embedding_vec; - - rng_type_t rng_type = CUDA_RNG; - rng_type_t sampler_rng_type = RNG_TYPE_COUNT; - bool offload_params_to_cpu = false; - bool enable_mmap = false; - bool control_net_cpu = false; - bool clip_on_cpu = false; - bool vae_on_cpu = false; - bool flash_attn = false; - bool diffusion_flash_attn = false; - bool diffusion_conv_direct = false; - bool vae_conv_direct = false; - - bool circular = false; - bool circular_x = false; - bool circular_y = false; - - bool chroma_use_dit_mask = true; - bool chroma_use_t5_mask = false; - int chroma_t5_mask_pad = 1; - - bool qwen_image_zero_cond_t = 
false; - - prediction_t prediction = PREDICTION_COUNT; - lora_apply_mode_t lora_apply_mode = LORA_APPLY_AUTO; - - bool force_sdxl_vae_conv_scale = false; - - float flow_shift = INFINITY; - - ArgOptions get_options() { - ArgOptions options; - options.string_options = { - {"-m", - "--model", - "path to full model", - &model_path}, - {"", - "--clip_l", - "path to the clip-l text encoder", &clip_l_path}, - {"", "--clip_g", - "path to the clip-g text encoder", - &clip_g_path}, - {"", - "--clip_vision", - "path to the clip-vision encoder", - &clip_vision_path}, - {"", - "--t5xxl", - "path to the t5xxl text encoder", - &t5xxl_path}, - {"", - "--llm", - "path to the llm text encoder. For example: (qwenvl2.5 for qwen-image, mistral-small3.2 for flux2, ...)", - &llm_path}, - {"", - "--llm_vision", - "path to the llm vit", - &llm_vision_path}, - {"", - "--qwen2vl", - "alias of --llm. Deprecated.", - &llm_path}, - {"", - "--qwen2vl_vision", - "alias of --llm_vision. Deprecated.", - &llm_vision_path}, - {"", - "--diffusion-model", - "path to the standalone diffusion model", - &diffusion_model_path}, - {"", - "--high-noise-diffusion-model", - "path to the standalone high noise diffusion model", - &high_noise_diffusion_model_path}, - {"", - "--vae", - "path to standalone vae model", - &vae_path}, - {"", - "--taesd", - "path to taesd. 
Using Tiny AutoEncoder for fast decoding (low quality)", - &taesd_path}, - {"", - "--tae", - "alias of --taesd", - &taesd_path}, - {"", - "--control-net", - "path to control net model", - &control_net_path}, - {"", - "--embd-dir", - "embeddings directory", - &embedding_dir}, - {"", - "--lora-model-dir", - "lora model directory", - &lora_model_dir}, - - {"", - "--tensor-type-rules", - "weight type per tensor pattern (example: \"^vae\\.=f16,model\\.=q8_0\")", - &tensor_type_rules}, - {"", - "--photo-maker", - "path to PHOTOMAKER model", - &photo_maker_path}, - {"", - "--upscale-model", - "path to esrgan model.", - &esrgan_path}, - }; - - options.int_options = { - {"-t", - "--threads", - "number of threads to use during computation (default: -1). " - "If threads <= 0, then threads will be set to the number of CPU physical cores", - &n_threads}, - {"", - "--chroma-t5-mask-pad", - "t5 mask pad size of chroma", - &chroma_t5_mask_pad}, - }; - - options.float_options = {}; - - options.bool_options = { - {"", - "--force-sdxl-vae-conv-scale", - "force use of conv scale on sdxl vae", - true, &force_sdxl_vae_conv_scale}, - {"", - "--offload-to-cpu", - "place the weights in RAM to save VRAM, and automatically load them into VRAM when needed", - true, &offload_params_to_cpu}, - {"", - "--mmap", - "whether to memory-map model", - true, &enable_mmap}, - {"", - "--control-net-cpu", - "keep controlnet in cpu (for low vram)", - true, &control_net_cpu}, - {"", - "--clip-on-cpu", - "keep clip in cpu (for low vram)", - true, &clip_on_cpu}, - {"", - "--vae-on-cpu", - "keep vae in cpu (for low vram)", - true, &vae_on_cpu}, - {"", - "--fa", - "use flash attention", - true, &flash_attn}, - {"", - "--diffusion-fa", - "use flash attention in the diffusion model only", - true, &diffusion_flash_attn}, - {"", - "--diffusion-conv-direct", - "use ggml_conv2d_direct in the diffusion model", - true, &diffusion_conv_direct}, - {"", - "--vae-conv-direct", - "use ggml_conv2d_direct in the vae model", - 
true, &vae_conv_direct}, - {"", - "--circular", - "enable circular padding for convolutions", - true, &circular}, - {"", - "--circularx", - "enable circular RoPE wrapping on x-axis (width) only", - true, &circular_x}, - {"", - "--circulary", - "enable circular RoPE wrapping on y-axis (height) only", - true, &circular_y}, - {"", - "--chroma-disable-dit-mask", - "disable dit mask for chroma", - false, &chroma_use_dit_mask}, - {"", - "--qwen-image-zero-cond-t", - "enable zero_cond_t for qwen image", - true, &qwen_image_zero_cond_t}, - {"", - "--chroma-enable-t5-mask", - "enable t5 mask for chroma", - true, &chroma_use_t5_mask}, - }; - - auto on_type_arg = [&](int argc, const char** argv, int index) { - if (++index >= argc) { - return -1; - } - const char* arg = argv[index]; - wtype = str_to_sd_type(arg); - if (wtype == SD_TYPE_COUNT) { - LOG_ERROR("error: invalid weight format %s", - arg); - return -1; - } - return 1; - }; - - auto on_rng_arg = [&](int argc, const char** argv, int index) { - if (++index >= argc) { - return -1; - } - const char* arg = argv[index]; - rng_type = str_to_rng_type(arg); - if (rng_type == RNG_TYPE_COUNT) { - LOG_ERROR("error: invalid rng type %s", - arg); - return -1; - } - return 1; - }; - - auto on_sampler_rng_arg = [&](int argc, const char** argv, int index) { - if (++index >= argc) { - return -1; - } - const char* arg = argv[index]; - sampler_rng_type = str_to_rng_type(arg); - if (sampler_rng_type == RNG_TYPE_COUNT) { - LOG_ERROR("error: invalid sampler rng type %s", - arg); - return -1; - } - return 1; - }; - - auto on_prediction_arg = [&](int argc, const char** argv, int index) { - if (++index >= argc) { - return -1; - } - const char* arg = argv[index]; - prediction = str_to_prediction(arg); - if (prediction == PREDICTION_COUNT) { - LOG_ERROR("error: invalid prediction type %s", - arg); - return -1; - } - return 1; - }; - - auto on_lora_apply_mode_arg = [&](int argc, const char** argv, int index) { - if (++index >= argc) { - return -1; 
- } - const char* arg = argv[index]; - lora_apply_mode = str_to_lora_apply_mode(arg); - if (lora_apply_mode == LORA_APPLY_MODE_COUNT) { - LOG_ERROR("error: invalid lora apply model %s", - arg); - return -1; - } - return 1; - }; - - options.manual_options = { - {"", - "--type", - "weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K). " - "If not specified, the default is the type of the weight file", - on_type_arg}, - {"", - "--rng", - "RNG, one of [std_default, cuda, cpu], default: cuda(sd-webui), cpu(comfyui)", - on_rng_arg}, - {"", - "--sampler-rng", - "sampler RNG, one of [std_default, cuda, cpu]. If not specified, use --rng", - on_sampler_rng_arg}, - {"", - "--prediction", - "prediction type override, one of [eps, v, edm_v, sd3_flow, flux_flow, flux2_flow]", - on_prediction_arg}, - {"", - "--lora-apply-mode", - "the way to apply LoRA, one of [auto, immediately, at_runtime], default is auto. " - "In auto mode, if the model weights contain any quantized parameters, the at_runtime mode will be used; otherwise, immediately will be used." - "The immediately mode may have precision and compatibility issues with quantized parameters, " - "but it usually offers faster inference speed and, in some cases, lower memory usage. 
" - "The at_runtime mode, on the other hand, is exactly the opposite.", - on_lora_apply_mode_arg}, - }; - - return options; - } - - void build_embedding_map() { - static const std::vector valid_ext = {".gguf", ".safetensors", ".pt"}; - - if (!fs::exists(embedding_dir) || !fs::is_directory(embedding_dir)) { - return; - } - - for (auto& p : fs::directory_iterator(embedding_dir)) { - if (!p.is_regular_file()) - continue; - - auto path = p.path(); - std::string ext = path.extension().string(); - - bool valid = false; - for (auto& e : valid_ext) { - if (ext == e) { - valid = true; - break; - } - } - if (!valid) - continue; - - std::string key = path.stem().string(); - std::string value = path.string(); - - embedding_map[key] = value; - } - } - - bool process_and_check(SDMode mode) { - if (mode != UPSCALE && mode != METADATA && model_path.length() == 0 && diffusion_model_path.length() == 0) { - LOG_ERROR("error: the following arguments are required: model_path/diffusion_model\n"); - return false; - } - - if (mode == UPSCALE) { - if (esrgan_path.length() == 0) { - LOG_ERROR("error: upscale mode needs an upscaler model (--upscale-model)\n"); - return false; - } - } - - if (n_threads <= 0) { - n_threads = sd_get_num_physical_cores(); - } - - build_embedding_map(); - - return true; - } - - std::string to_string() const { - std::ostringstream emb_ss; - emb_ss << "{\n"; - for (auto it = embedding_map.begin(); it != embedding_map.end(); ++it) { - emb_ss << " \"" << it->first << "\": \"" << it->second << "\""; - if (std::next(it) != embedding_map.end()) { - emb_ss << ","; - } - emb_ss << "\n"; - } - emb_ss << " }"; - - std::string embeddings_str = emb_ss.str(); - std::ostringstream oss; - oss << "SDContextParams {\n" - << " n_threads: " << n_threads << ",\n" - << " model_path: \"" << model_path << "\",\n" - << " clip_l_path: \"" << clip_l_path << "\",\n" - << " clip_g_path: \"" << clip_g_path << "\",\n" - << " clip_vision_path: \"" << clip_vision_path << "\",\n" - << " 
t5xxl_path: \"" << t5xxl_path << "\",\n" - << " llm_path: \"" << llm_path << "\",\n" - << " llm_vision_path: \"" << llm_vision_path << "\",\n" - << " diffusion_model_path: \"" << diffusion_model_path << "\",\n" - << " high_noise_diffusion_model_path: \"" << high_noise_diffusion_model_path << "\",\n" - << " vae_path: \"" << vae_path << "\",\n" - << " taesd_path: \"" << taesd_path << "\",\n" - << " esrgan_path: \"" << esrgan_path << "\",\n" - << " control_net_path: \"" << control_net_path << "\",\n" - << " embedding_dir: \"" << embedding_dir << "\",\n" - << " embeddings: " << embeddings_str << "\n" - << " wtype: " << sd_type_name(wtype) << ",\n" - << " tensor_type_rules: \"" << tensor_type_rules << "\",\n" - << " lora_model_dir: \"" << lora_model_dir << "\",\n" - << " photo_maker_path: \"" << photo_maker_path << "\",\n" - << " rng_type: " << sd_rng_type_name(rng_type) << ",\n" - << " sampler_rng_type: " << sd_rng_type_name(sampler_rng_type) << ",\n" - << " offload_params_to_cpu: " << (offload_params_to_cpu ? "true" : "false") << ",\n" - << " enable_mmap: " << (enable_mmap ? "true" : "false") << ",\n" - << " control_net_cpu: " << (control_net_cpu ? "true" : "false") << ",\n" - << " clip_on_cpu: " << (clip_on_cpu ? "true" : "false") << ",\n" - << " vae_on_cpu: " << (vae_on_cpu ? "true" : "false") << ",\n" - << " flash_attn: " << (flash_attn ? "true" : "false") << ",\n" - << " diffusion_flash_attn: " << (diffusion_flash_attn ? "true" : "false") << ",\n" - << " diffusion_conv_direct: " << (diffusion_conv_direct ? "true" : "false") << ",\n" - << " vae_conv_direct: " << (vae_conv_direct ? "true" : "false") << ",\n" - << " circular: " << (circular ? "true" : "false") << ",\n" - << " circular_x: " << (circular_x ? "true" : "false") << ",\n" - << " circular_y: " << (circular_y ? "true" : "false") << ",\n" - << " chroma_use_dit_mask: " << (chroma_use_dit_mask ? "true" : "false") << ",\n" - << " qwen_image_zero_cond_t: " << (qwen_image_zero_cond_t ? 
"true" : "false") << ",\n" - << " chroma_use_t5_mask: " << (chroma_use_t5_mask ? "true" : "false") << ",\n" - << " chroma_t5_mask_pad: " << chroma_t5_mask_pad << ",\n" - << " prediction: " << sd_prediction_name(prediction) << ",\n" - << " lora_apply_mode: " << sd_lora_apply_mode_name(lora_apply_mode) << ",\n" - << " force_sdxl_vae_conv_scale: " << (force_sdxl_vae_conv_scale ? "true" : "false") << "\n" - << "}"; - return oss.str(); - } - - sd_ctx_params_t to_sd_ctx_params_t(bool vae_decode_only, bool free_params_immediately, bool taesd_preview) { - embedding_vec.clear(); - embedding_vec.reserve(embedding_map.size()); - for (const auto& kv : embedding_map) { - sd_embedding_t item; - item.name = kv.first.c_str(); - item.path = kv.second.c_str(); - embedding_vec.emplace_back(item); - } - - sd_ctx_params_t sd_ctx_params = { - model_path.c_str(), - clip_l_path.c_str(), - clip_g_path.c_str(), - clip_vision_path.c_str(), - t5xxl_path.c_str(), - llm_path.c_str(), - llm_vision_path.c_str(), - diffusion_model_path.c_str(), - high_noise_diffusion_model_path.c_str(), - vae_path.c_str(), - taesd_path.c_str(), - control_net_path.c_str(), - embedding_vec.data(), - static_cast(embedding_vec.size()), - photo_maker_path.c_str(), - tensor_type_rules.c_str(), - vae_decode_only, - free_params_immediately, - n_threads, - wtype, - rng_type, - sampler_rng_type, - prediction, - lora_apply_mode, - offload_params_to_cpu, - enable_mmap, - clip_on_cpu, - control_net_cpu, - vae_on_cpu, - flash_attn, - diffusion_flash_attn, - taesd_preview, - diffusion_conv_direct, - vae_conv_direct, - circular || circular_x, - circular || circular_y, - force_sdxl_vae_conv_scale, - chroma_use_dit_mask, - chroma_use_t5_mask, - chroma_t5_mask_pad, - qwen_image_zero_cond_t, - }; - return sd_ctx_params; - } -}; - -template -static std::string vec_to_string(const std::vector& v) { - std::ostringstream oss; - oss << "["; - for (size_t i = 0; i < v.size(); i++) { - oss << v[i]; - if (i + 1 < v.size()) - oss << ", "; - } 
- oss << "]"; - return oss.str(); -} - -static std::string vec_str_to_string(const std::vector& v) { - std::ostringstream oss; - oss << "["; - for (size_t i = 0; i < v.size(); i++) { - oss << "\"" << v[i] << "\""; - if (i + 1 < v.size()) - oss << ", "; - } - oss << "]"; - return oss.str(); -} - -static bool is_absolute_path(const std::string& p) { -#ifdef _WIN32 - // Windows: C:/path or C:\path - return p.size() > 1 && std::isalpha(static_cast(p[0])) && p[1] == ':'; -#else - return !p.empty() && p[0] == '/'; -#endif -} - -struct SDGenerationParams { - std::string prompt; - std::string prompt_with_lora; // for metadata record only - std::string negative_prompt; - int clip_skip = -1; // <= 0 represents unspecified - int width = -1; - int height = -1; - int batch_count = 1; - std::string init_image_path; - std::string end_image_path; - std::string mask_image_path; - std::string control_image_path; - std::vector ref_image_paths; - std::string control_video_path; - bool auto_resize_ref_image = true; - bool increase_ref_index = false; - bool embed_image_metadata = true; - - std::vector skip_layers = {7, 8, 9}; - sd_sample_params_t sample_params; - - std::vector high_noise_skip_layers = {7, 8, 9}; - sd_sample_params_t high_noise_sample_params; - - std::vector custom_sigmas; - - std::string cache_mode; - std::string cache_option; - std::string scm_mask; - bool scm_policy_dynamic = true; - sd_cache_params_t cache_params{}; - - float moe_boundary = 0.875f; - int video_frames = 1; - int fps = 16; - float vace_strength = 1.f; - - float strength = 0.75f; - float control_strength = 0.9f; - - int64_t seed = 42; - - sd_tiling_params_t vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f}; - - // Photo Maker - std::string pm_id_images_dir; - std::string pm_id_embed_path; - float pm_style_strength = 20.f; - - int upscale_repeats = 1; - int upscale_tile_size = 128; - - std::map lora_map; - std::map high_noise_lora_map; - std::vector lora_vec; - - SDGenerationParams() { - 
sd_sample_params_init(&sample_params); - sd_sample_params_init(&high_noise_sample_params); - } - - ArgOptions get_options() { - ArgOptions options; - options.string_options = { - {"-p", - "--prompt", - "the prompt to render", - &prompt}, - {"-n", - "--negative-prompt", - "the negative prompt (default: \"\")", - &negative_prompt}, - {"-i", - "--init-img", - "path to the init image", - &init_image_path}, - {"", - "--end-img", - "path to the end image, required by flf2v", - &end_image_path}, - {"", - "--mask", - "path to the mask image", - &mask_image_path}, - {"", - "--control-image", - "path to control image, control net", - &control_image_path}, - {"", - "--control-video", - "path to control video frames, It must be a directory path. The video frames inside should be stored as images in " - "lexicographical (character) order. For example, if the control video path is `frames`, the directory contain images " - "such as 00.png, 01.png, ... etc.", - &control_video_path}, - {"", - "--pm-id-images-dir", - "path to PHOTOMAKER input id images dir", - &pm_id_images_dir}, - {"", - "--pm-id-embed-path", - "path to PHOTOMAKER v2 id embed", - &pm_id_embed_path}, - }; - - options.int_options = { - {"-H", - "--height", - "image height, in pixel space (default: 512)", - &height}, - {"-W", - "--width", - "image width, in pixel space (default: 512)", - &width}, - {"", - "--steps", - "number of sample steps (default: 20)", - &sample_params.sample_steps}, - {"", - "--high-noise-steps", - "(high noise) number of sample steps (default: -1 = auto)", - &high_noise_sample_params.sample_steps}, - {"", - "--clip-skip", - "ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1). 
" - "<= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x", - &clip_skip}, - {"-b", - "--batch-count", - "batch count", - &batch_count}, - {"", - "--video-frames", - "video frames (default: 1)", - &video_frames}, - {"", - "--fps", - "fps (default: 24)", - &fps}, - {"", - "--timestep-shift", - "shift timestep for NitroFusion models (default: 0). " - "recommended N for NitroSD-Realism around 250 and 500 for NitroSD-Vibrant", - &sample_params.shifted_timestep}, - {"", - "--upscale-repeats", - "Run the ESRGAN upscaler this many times (default: 1)", - &upscale_repeats}, - {"", - "--upscale-tile-size", - "tile size for ESRGAN upscaling (default: 128)", - &upscale_tile_size}, - }; - - options.float_options = { - {"", - "--cfg-scale", - "unconditional guidance scale: (default: 7.0)", - &sample_params.guidance.txt_cfg}, - {"", - "--img-cfg-scale", - "image guidance scale for inpaint or instruct-pix2pix models: (default: same as --cfg-scale)", - &sample_params.guidance.img_cfg}, - {"", - "--guidance", - "distilled guidance scale for models with guidance input (default: 3.5)", - &sample_params.guidance.distilled_guidance}, - {"", - "--slg-scale", - "skip layer guidance (SLG) scale, only for DiT models: (default: 0). 
0 means disabled, a value of 2.5 is nice for sd3.5 medium", - &sample_params.guidance.slg.scale}, - {"", - "--skip-layer-start", - "SLG enabling point (default: 0.01)", - &sample_params.guidance.slg.layer_start}, - {"", - "--skip-layer-end", - "SLG disabling point (default: 0.2)", - &sample_params.guidance.slg.layer_end}, - {"", - "--eta", - "noise multiplier (default: 0 for ddim_trailing, tcd, res_multistep and res_2s; 1 for euler_a and dpm++2s_a)", - &sample_params.eta}, - {"", - "--flow-shift", - "shift value for Flow models like SD3.x or WAN (default: auto)", - &sample_params.flow_shift}, - {"", - "--high-noise-cfg-scale", - "(high noise) unconditional guidance scale: (default: 7.0)", - &high_noise_sample_params.guidance.txt_cfg}, - {"", - "--high-noise-img-cfg-scale", - "(high noise) image guidance scale for inpaint or instruct-pix2pix models (default: same as --cfg-scale)", - &high_noise_sample_params.guidance.img_cfg}, - {"", - "--high-noise-guidance", - "(high noise) distilled guidance scale for models with guidance input (default: 3.5)", - &high_noise_sample_params.guidance.distilled_guidance}, - {"", - "--high-noise-slg-scale", - "(high noise) skip layer guidance (SLG) scale, only for DiT models: (default: 0)", - &high_noise_sample_params.guidance.slg.scale}, - {"", - "--high-noise-skip-layer-start", - "(high noise) SLG enabling point (default: 0.01)", - &high_noise_sample_params.guidance.slg.layer_start}, - {"", - "--high-noise-skip-layer-end", - "(high noise) SLG disabling point (default: 0.2)", - &high_noise_sample_params.guidance.slg.layer_end}, - {"", - "--high-noise-eta", - "(high noise) noise multiplier (default: 0 for ddim_trailing, tcd, res_multistep and res_2s; 1 for euler_a and dpm++2s_a)", - &high_noise_sample_params.eta}, - {"", - "--strength", - "strength for noising/unnoising (default: 0.75)", - &strength}, - {"", - "--pm-style-strength", - "", - &pm_style_strength}, - {"", - "--control-strength", - "strength to apply Control Net (default: 
0.9). 1.0 corresponds to full destruction of information in init image", - &control_strength}, - {"", - "--moe-boundary", - "timestep boundary for Wan2.2 MoE model. (default: 0.875). Only enabled if `--high-noise-steps` is set to -1", - &moe_boundary}, - {"", - "--vace-strength", - "wan vace strength", - &vace_strength}, - {"", - "--vae-tile-overlap", - "tile overlap for vae tiling, in fraction of tile size (default: 0.5)", - &vae_tiling_params.target_overlap}, - }; - - options.bool_options = { - {"", - "--increase-ref-index", - "automatically increase the indices of references images based on the order they are listed (starting with 1).", - true, - &increase_ref_index}, - {"", - "--disable-auto-resize-ref-image", - "disable auto resize of ref images", - false, - &auto_resize_ref_image}, - {"", - "--disable-image-metadata", - "do not embed generation metadata on image files", - false, - &embed_image_metadata}, - {"", - "--vae-tiling", - "process vae in tiles to reduce memory usage", - true, - &vae_tiling_params.enabled}, - }; - - auto on_seed_arg = [&](int argc, const char** argv, int index) { - if (++index >= argc) { - return -1; - } - seed = std::stoll(argv[index]); - return 1; - }; - - auto on_sample_method_arg = [&](int argc, const char** argv, int index) { - if (++index >= argc) { - return -1; - } - const char* arg = argv[index]; - sample_params.sample_method = str_to_sample_method(arg); - if (sample_params.sample_method == SAMPLE_METHOD_COUNT) { - LOG_ERROR("error: invalid sample method %s", - arg); - return -1; - } - return 1; - }; - - auto on_high_noise_sample_method_arg = [&](int argc, const char** argv, int index) { - if (++index >= argc) { - return -1; - } - const char* arg = argv[index]; - high_noise_sample_params.sample_method = str_to_sample_method(arg); - if (high_noise_sample_params.sample_method == SAMPLE_METHOD_COUNT) { - LOG_ERROR("error: invalid high noise sample method %s", - arg); - return -1; - } - return 1; - }; - - auto on_scheduler_arg = 
[&](int argc, const char** argv, int index) { - if (++index >= argc) { - return -1; - } - const char* arg = argv[index]; - sample_params.scheduler = str_to_scheduler(arg); - if (sample_params.scheduler == SCHEDULER_COUNT) { - LOG_ERROR("error: invalid scheduler %s", - arg); - return -1; - } - return 1; - }; - - auto on_skip_layers_arg = [&](int argc, const char** argv, int index) { - if (++index >= argc) { - return -1; - } - std::string layers_str = argv[index]; - if (layers_str[0] != '[' || layers_str[layers_str.size() - 1] != ']') { - return -1; - } - - layers_str = layers_str.substr(1, layers_str.size() - 2); - - std::regex regex("[, ]+"); - std::sregex_token_iterator iter(layers_str.begin(), layers_str.end(), regex, -1); - std::sregex_token_iterator end; - std::vector tokens(iter, end); - std::vector layers; - for (const auto& token : tokens) { - try { - layers.push_back(std::stoi(token)); - } catch (const std::invalid_argument&) { - return -1; - } - } - skip_layers = layers; - return 1; - }; - - auto on_high_noise_skip_layers_arg = [&](int argc, const char** argv, int index) { - if (++index >= argc) { - return -1; - } - std::string layers_str = argv[index]; - if (layers_str[0] != '[' || layers_str[layers_str.size() - 1] != ']') { - return -1; - } - - layers_str = layers_str.substr(1, layers_str.size() - 2); - - std::regex regex("[, ]+"); - std::sregex_token_iterator iter(layers_str.begin(), layers_str.end(), regex, -1); - std::sregex_token_iterator end; - std::vector tokens(iter, end); - std::vector layers; - for (const auto& token : tokens) { - try { - layers.push_back(std::stoi(token)); - } catch (const std::invalid_argument&) { - return -1; - } - } - high_noise_skip_layers = layers; - return 1; - }; - - auto on_sigmas_arg = [&](int argc, const char** argv, int index) { - if (++index >= argc) { - return -1; - } - std::string sigmas_str = argv[index]; - if (!sigmas_str.empty() && sigmas_str.front() == '[') { - sigmas_str.erase(0, 1); - } - if 
(!sigmas_str.empty() && sigmas_str.back() == ']') { - sigmas_str.pop_back(); - } - - std::stringstream ss(sigmas_str); - std::string item; - while (std::getline(ss, item, ',')) { - item.erase(0, item.find_first_not_of(" \t\n\r\f\v")); - item.erase(item.find_last_not_of(" \t\n\r\f\v") + 1); - if (!item.empty()) { - try { - custom_sigmas.push_back(std::stof(item)); - } catch (const std::invalid_argument&) { - LOG_ERROR("error: invalid float value '%s' in --sigmas", item.c_str()); - return -1; - } catch (const std::out_of_range&) { - LOG_ERROR("error: float value '%s' out of range in --sigmas", item.c_str()); - return -1; - } - } - } - - if (custom_sigmas.empty() && !sigmas_str.empty()) { - LOG_ERROR("error: could not parse any sigma values from '%s'", argv[index]); - return -1; - } - return 1; - }; - - auto on_ref_image_arg = [&](int argc, const char** argv, int index) { - if (++index >= argc) { - return -1; - } - ref_image_paths.push_back(argv[index]); - return 1; - }; - - auto on_cache_mode_arg = [&](int argc, const char** argv, int index) { - if (++index >= argc) { - return -1; - } - cache_mode = argv_to_utf8(index, argv); - if (cache_mode != "easycache" && cache_mode != "ucache" && - cache_mode != "dbcache" && cache_mode != "taylorseer" && cache_mode != "cache-dit" && cache_mode != "spectrum") { - fprintf(stderr, "error: invalid cache mode '%s', must be 'easycache', 'ucache', 'dbcache', 'taylorseer', 'cache-dit', or 'spectrum'\n", cache_mode.c_str()); - return -1; - } - return 1; - }; - - auto on_cache_option_arg = [&](int argc, const char** argv, int index) { - if (++index >= argc) { - return -1; - } - cache_option = argv_to_utf8(index, argv); - return 1; - }; - - auto on_scm_mask_arg = [&](int argc, const char** argv, int index) { - if (++index >= argc) { - return -1; - } - scm_mask = argv_to_utf8(index, argv); - return 1; - }; - - auto on_scm_policy_arg = [&](int argc, const char** argv, int index) { - if (++index >= argc) { - return -1; - } - std::string 
policy = argv_to_utf8(index, argv); - if (policy == "dynamic") { - scm_policy_dynamic = true; - } else if (policy == "static") { - scm_policy_dynamic = false; - } else { - fprintf(stderr, "error: invalid scm policy '%s', must be 'dynamic' or 'static'\n", policy.c_str()); - return -1; - } - return 1; - }; - - auto on_tile_size_arg = [&](int argc, const char** argv, int index) { - if (++index >= argc) { - return -1; - } - std::string tile_size_str = argv[index]; - size_t x_pos = tile_size_str.find('x'); - try { - if (x_pos != std::string::npos) { - std::string tile_x_str = tile_size_str.substr(0, x_pos); - std::string tile_y_str = tile_size_str.substr(x_pos + 1); - vae_tiling_params.tile_size_x = std::stoi(tile_x_str); - vae_tiling_params.tile_size_y = std::stoi(tile_y_str); - } else { - vae_tiling_params.tile_size_x = vae_tiling_params.tile_size_y = std::stoi(tile_size_str); - } - } catch (const std::invalid_argument&) { - return -1; - } catch (const std::out_of_range&) { - return -1; - } - return 1; - }; - - auto on_relative_tile_size_arg = [&](int argc, const char** argv, int index) { - if (++index >= argc) { - return -1; - } - std::string rel_size_str = argv[index]; - size_t x_pos = rel_size_str.find('x'); - try { - if (x_pos != std::string::npos) { - std::string rel_x_str = rel_size_str.substr(0, x_pos); - std::string rel_y_str = rel_size_str.substr(x_pos + 1); - vae_tiling_params.rel_size_x = std::stof(rel_x_str); - vae_tiling_params.rel_size_y = std::stof(rel_y_str); - } else { - vae_tiling_params.rel_size_x = vae_tiling_params.rel_size_y = std::stof(rel_size_str); - } - } catch (const std::invalid_argument&) { - return -1; - } catch (const std::out_of_range&) { - return -1; - } - return 1; - }; - - options.manual_options = { - {"-s", - "--seed", - "RNG seed (default: 42, use random seed for < 0)", - on_seed_arg}, - {"", - "--sampling-method", - "sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, 
ddim_trailing, tcd, res_multistep, res_2s] " - "(default: euler for Flux/SD3/Wan, euler_a otherwise)", - on_sample_method_arg}, - {"", - "--high-noise-sampling-method", - "(high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd, res_multistep, res_2s]" - " default: euler for Flux/SD3/Wan, euler_a otherwise", - on_high_noise_sample_method_arg}, - {"", - "--scheduler", - "denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, kl_optimal, lcm, bong_tangent], default: discrete", - on_scheduler_arg}, - {"", - "--sigmas", - "custom sigma values for the sampler, comma-separated (e.g., \"14.61,7.8,3.5,0.0\").", - on_sigmas_arg}, - {"", - "--skip-layers", - "layers to skip for SLG steps (default: [7,8,9])", - on_skip_layers_arg}, - {"", - "--high-noise-skip-layers", - "(high noise) layers to skip for SLG steps (default: [7,8,9])", - on_high_noise_skip_layers_arg}, - {"-r", - "--ref-image", - "reference image for Flux Kontext models (can be used multiple times)", - on_ref_image_arg}, - {"", - "--cache-mode", - "caching method: 'easycache' (DiT), 'ucache' (UNET), 'dbcache'/'taylorseer'/'cache-dit' (DiT block-level), 'spectrum' (UNET/DiT Chebyshev+Taylor forecasting)", - on_cache_mode_arg}, - {"", - "--cache-option", - "named cache params (key=value format, comma-separated). easycache/ucache: threshold=,start=,end=,decay=,relative=,reset=; dbcache/taylorseer/cache-dit: Fn=,Bn=,threshold=,warmup=; spectrum: w=,m=,lam=,window=,flex=,warmup=,stop=. 
Examples: \"threshold=0.25\" or \"threshold=1.5,reset=0\"", - on_cache_option_arg}, - {"", - "--scm-mask", - "SCM steps mask for cache-dit: comma-separated 0/1 (e.g., \"1,1,1,0,0,1,0,0,1,0\") - 1=compute, 0=can cache", - on_scm_mask_arg}, - {"", - "--scm-policy", - "SCM policy: 'dynamic' (default) or 'static'", - on_scm_policy_arg}, - {"", - "--vae-tile-size", - "tile size for vae tiling, format [X]x[Y] (default: 32x32)", - on_tile_size_arg}, - {"", - "--vae-relative-tile-size", - "relative tile size for vae tiling, format [X]x[Y], in fraction of image size if < 1, in number of tiles per dim if >=1 (overrides --vae-tile-size)", - on_relative_tile_size_arg}, - - }; - - return options; - } - - bool from_json_str(const std::string& json_str) { - json j; - try { - j = json::parse(json_str); - } catch (...) { - LOG_ERROR("json parse failed %s", json_str.c_str()); - return false; - } - - auto load_if_exists = [&](const char* key, auto& out) { - if (j.contains(key)) { - using T = std::decay_t; - if constexpr (std::is_same_v) { - if (j[key].is_string()) - out = j[key]; - } else if constexpr (std::is_same_v || std::is_same_v) { - if (j[key].is_number_integer()) - out = j[key]; - } else if constexpr (std::is_same_v) { - if (j[key].is_number()) - out = j[key]; - } else if constexpr (std::is_same_v) { - if (j[key].is_boolean()) - out = j[key]; - } else if constexpr (std::is_same_v>) { - if (j[key].is_array()) - out = j[key].get>(); - } else if constexpr (std::is_same_v>) { - if (j[key].is_array()) - out = j[key].get>(); - } - } - }; - - load_if_exists("prompt", prompt); - load_if_exists("negative_prompt", negative_prompt); - load_if_exists("cache_mode", cache_mode); - load_if_exists("cache_option", cache_option); - load_if_exists("scm_mask", scm_mask); - - load_if_exists("clip_skip", clip_skip); - load_if_exists("width", width); - load_if_exists("height", height); - load_if_exists("batch_count", batch_count); - load_if_exists("video_frames", video_frames); - 
load_if_exists("fps", fps); - load_if_exists("upscale_repeats", upscale_repeats); - load_if_exists("seed", seed); - - load_if_exists("strength", strength); - load_if_exists("control_strength", control_strength); - load_if_exists("pm_style_strength", pm_style_strength); - load_if_exists("moe_boundary", moe_boundary); - load_if_exists("vace_strength", vace_strength); - - load_if_exists("auto_resize_ref_image", auto_resize_ref_image); - load_if_exists("increase_ref_index", increase_ref_index); - load_if_exists("embed_image_metadata", embed_image_metadata); - - load_if_exists("skip_layers", skip_layers); - load_if_exists("high_noise_skip_layers", high_noise_skip_layers); - - load_if_exists("steps", sample_params.sample_steps); - load_if_exists("high_noise_steps", high_noise_sample_params.sample_steps); - load_if_exists("cfg_scale", sample_params.guidance.txt_cfg); - load_if_exists("img_cfg_scale", sample_params.guidance.img_cfg); - load_if_exists("guidance", sample_params.guidance.distilled_guidance); - load_if_exists("flow_shift", sample_params.flow_shift); - - auto load_sampler_if_exists = [&](const char* key, enum sample_method_t& out) { - if (j.contains(key) && j[key].is_string()) { - enum sample_method_t tmp = str_to_sample_method(j[key].get().c_str()); - if (tmp != SAMPLE_METHOD_COUNT) { - out = tmp; - } - } - }; - load_sampler_if_exists("sample_method", sample_params.sample_method); - load_sampler_if_exists("high_noise_sample_method", high_noise_sample_params.sample_method); - - if (j.contains("scheduler") && j["scheduler"].is_string()) { - enum scheduler_t tmp = str_to_scheduler(j["scheduler"].get().c_str()); - if (tmp != SCHEDULER_COUNT) { - sample_params.scheduler = tmp; - } - } - - return true; - } - - void extract_and_remove_lora(const std::string& lora_model_dir) { - if (lora_model_dir.empty()) { - return; - } - static const std::regex re(R"(]+):([^>]+)>)"); - static const std::vector valid_ext = {".gguf", ".safetensors", ".pt"}; - std::smatch m; - - 
std::string tmp = prompt; - - while (std::regex_search(tmp, m, re)) { - std::string raw_path = m[1].str(); - const std::string raw_mul = m[2].str(); - - float mul = 0.f; - try { - mul = std::stof(raw_mul); - } catch (...) { - tmp = m.suffix().str(); - prompt = std::regex_replace(prompt, re, "", std::regex_constants::format_first_only); - continue; - } - - bool is_high_noise = false; - static const std::string prefix = "|high_noise|"; - if (raw_path.rfind(prefix, 0) == 0) { - raw_path.erase(0, prefix.size()); - is_high_noise = true; - } - - fs::path final_path; - if (is_absolute_path(raw_path)) { - final_path = raw_path; - } else { - final_path = fs::path(lora_model_dir) / raw_path; - } - if (!fs::exists(final_path)) { - bool found = false; - for (const auto& ext : valid_ext) { - fs::path try_path = final_path; - try_path += ext; - if (fs::exists(try_path)) { - final_path = try_path; - found = true; - break; - } - } - if (!found) { - LOG_WARN("can not found lora %s", final_path.lexically_normal().string().c_str()); - tmp = m.suffix().str(); - prompt = std::regex_replace(prompt, re, "", std::regex_constants::format_first_only); - continue; - } - } - - const std::string key = final_path.lexically_normal().string(); - - if (is_high_noise) - high_noise_lora_map[key] += mul; - else - lora_map[key] += mul; - - prompt = std::regex_replace(prompt, re, "", std::regex_constants::format_first_only); - - tmp = m.suffix().str(); - } - - for (const auto& kv : lora_map) { - sd_lora_t item; - item.is_high_noise = false; - item.path = kv.first.c_str(); - item.multiplier = kv.second; - lora_vec.emplace_back(item); - } - - for (const auto& kv : high_noise_lora_map) { - sd_lora_t item; - item.is_high_noise = true; - item.path = kv.first.c_str(); - item.multiplier = kv.second; - lora_vec.emplace_back(item); - } - } - - bool width_and_height_are_set() const { - return width > 0 && height > 0; - } - - void set_width_and_height_if_unset(int w, int h) { - if (!width_and_height_are_set()) { 
- LOG_INFO("set width x height to %d x %d", w, h); - width = w; - height = h; - } - } - - int get_resolved_width() const { return (width > 0) ? width : 512; } - - int get_resolved_height() const { return (height > 0) ? height : 512; } - - bool process_and_check(SDMode mode, const std::string& lora_model_dir) { - prompt_with_lora = prompt; - - if (sample_params.sample_steps <= 0) { - LOG_ERROR("error: the sample_steps must be greater than 0\n"); - return false; - } - - if (high_noise_sample_params.sample_steps <= 0) { - high_noise_sample_params.sample_steps = -1; - } - - if (strength < 0.f || strength > 1.f) { - LOG_ERROR("error: can only work with strength in [0.0, 1.0]\n"); - return false; - } - - sd_cache_params_init(&cache_params); - - auto parse_named_params = [&](const std::string& opt_str) -> bool { - std::stringstream ss(opt_str); - std::string token; - while (std::getline(ss, token, ',')) { - size_t eq_pos = token.find('='); - if (eq_pos == std::string::npos) { - LOG_ERROR("error: cache option '%s' missing '=' separator", token.c_str()); - return false; - } - std::string key = token.substr(0, eq_pos); - std::string val = token.substr(eq_pos + 1); - try { - if (key == "threshold") { - if (cache_mode == "easycache" || cache_mode == "ucache") { - cache_params.reuse_threshold = std::stof(val); - } else { - cache_params.residual_diff_threshold = std::stof(val); - } - } else if (key == "start") { - cache_params.start_percent = std::stof(val); - } else if (key == "end") { - cache_params.end_percent = std::stof(val); - } else if (key == "decay") { - cache_params.error_decay_rate = std::stof(val); - } else if (key == "relative") { - cache_params.use_relative_threshold = (std::stof(val) != 0.0f); - } else if (key == "reset") { - cache_params.reset_error_on_compute = (std::stof(val) != 0.0f); - } else if (key == "Fn" || key == "fn") { - cache_params.Fn_compute_blocks = std::stoi(val); - } else if (key == "Bn" || key == "bn") { - cache_params.Bn_compute_blocks = 
std::stoi(val); - } else if (key == "warmup") { - if (cache_mode == "spectrum") { - cache_params.spectrum_warmup_steps = std::stoi(val); - } else { - cache_params.max_warmup_steps = std::stoi(val); - } - } else if (key == "w") { - cache_params.spectrum_w = std::stof(val); - } else if (key == "m") { - cache_params.spectrum_m = std::stoi(val); - } else if (key == "lam") { - cache_params.spectrum_lam = std::stof(val); - } else if (key == "window") { - cache_params.spectrum_window_size = std::stoi(val); - } else if (key == "flex") { - cache_params.spectrum_flex_window = std::stof(val); - } else if (key == "stop") { - cache_params.spectrum_stop_percent = std::stof(val); - } else { - LOG_ERROR("error: unknown cache parameter '%s'", key.c_str()); - return false; - } - } catch (const std::exception&) { - LOG_ERROR("error: invalid value '%s' for parameter '%s'", val.c_str(), key.c_str()); - return false; - } - } - return true; - }; - - if (!cache_mode.empty()) { - if (cache_mode == "easycache") { - cache_params.mode = SD_CACHE_EASYCACHE; - } else if (cache_mode == "ucache") { - cache_params.mode = SD_CACHE_UCACHE; - } else if (cache_mode == "dbcache") { - cache_params.mode = SD_CACHE_DBCACHE; - } else if (cache_mode == "taylorseer") { - cache_params.mode = SD_CACHE_TAYLORSEER; - } else if (cache_mode == "cache-dit") { - cache_params.mode = SD_CACHE_CACHE_DIT; - } else if (cache_mode == "spectrum") { - cache_params.mode = SD_CACHE_SPECTRUM; - } - - if (!cache_option.empty()) { - if (!parse_named_params(cache_option)) { - return false; - } - } - - if (cache_mode == "easycache" || cache_mode == "ucache") { - if (cache_params.reuse_threshold < 0.0f) { - LOG_ERROR("error: cache threshold must be non-negative"); - return false; - } - if (cache_params.start_percent < 0.0f || cache_params.start_percent >= 1.0f || - cache_params.end_percent <= 0.0f || cache_params.end_percent > 1.0f || - cache_params.start_percent >= cache_params.end_percent) { - LOG_ERROR("error: cache start/end 
percents must satisfy 0.0 <= start < end <= 1.0"); - return false; - } - } - } - - if (cache_params.mode == SD_CACHE_DBCACHE || - cache_params.mode == SD_CACHE_TAYLORSEER || - cache_params.mode == SD_CACHE_CACHE_DIT) { - if (!scm_mask.empty()) { - cache_params.scm_mask = scm_mask.c_str(); - } - cache_params.scm_policy_dynamic = scm_policy_dynamic; - } - - sample_params.guidance.slg.layers = skip_layers.data(); - sample_params.guidance.slg.layer_count = skip_layers.size(); - sample_params.custom_sigmas = custom_sigmas.data(); - sample_params.custom_sigmas_count = static_cast(custom_sigmas.size()); - high_noise_sample_params.guidance.slg.layers = high_noise_skip_layers.data(); - high_noise_sample_params.guidance.slg.layer_count = high_noise_skip_layers.size(); - - if (mode == VID_GEN && video_frames <= 0) { - return false; - } - - if (mode == VID_GEN && fps <= 0) { - return false; - } - - if (sample_params.shifted_timestep < 0 || sample_params.shifted_timestep > 1000) { - return false; - } - - if (upscale_repeats < 1) { - return false; - } - - if (upscale_tile_size < 1) { - return false; - } - - if (mode == UPSCALE) { - if (init_image_path.length() == 0) { - LOG_ERROR("error: upscale mode needs an init image (--init-img)\n"); - return false; - } - } - - if (seed < 0) { - srand((int)time(nullptr)); - seed = rand(); - } - - extract_and_remove_lora(lora_model_dir); - - return true; - } - - std::string to_string() const { - char* sample_params_str = sd_sample_params_to_str(&sample_params); - char* high_noise_sample_params_str = sd_sample_params_to_str(&high_noise_sample_params); - - std::ostringstream lora_ss; - lora_ss << "{\n"; - for (auto it = lora_map.begin(); it != lora_map.end(); ++it) { - lora_ss << " \"" << it->first << "\": \"" << it->second << "\""; - if (std::next(it) != lora_map.end()) { - lora_ss << ","; - } - lora_ss << "\n"; - } - lora_ss << " }"; - std::string loras_str = lora_ss.str(); - - lora_ss = std::ostringstream(); - ; - lora_ss << "{\n"; - for 
(auto it = high_noise_lora_map.begin(); it != high_noise_lora_map.end(); ++it) { - lora_ss << " \"" << it->first << "\": \"" << it->second << "\""; - if (std::next(it) != high_noise_lora_map.end()) { - lora_ss << ","; - } - lora_ss << "\n"; - } - lora_ss << " }"; - std::string high_noise_loras_str = lora_ss.str(); - - std::ostringstream oss; - oss << "SDGenerationParams {\n" - << " loras: \"" << loras_str << "\",\n" - << " high_noise_loras: \"" << high_noise_loras_str << "\",\n" - << " prompt: \"" << prompt << "\",\n" - << " negative_prompt: \"" << negative_prompt << "\",\n" - << " clip_skip: " << clip_skip << ",\n" - << " width: " << width << ",\n" - << " height: " << height << ",\n" - << " batch_count: " << batch_count << ",\n" - << " init_image_path: \"" << init_image_path << "\",\n" - << " end_image_path: \"" << end_image_path << "\",\n" - << " mask_image_path: \"" << mask_image_path << "\",\n" - << " control_image_path: \"" << control_image_path << "\",\n" - << " ref_image_paths: " << vec_str_to_string(ref_image_paths) << ",\n" - << " control_video_path: \"" << control_video_path << "\",\n" - << " auto_resize_ref_image: " << (auto_resize_ref_image ? "true" : "false") << ",\n" - << " increase_ref_index: " << (increase_ref_index ? "true" : "false") << ",\n" - << " pm_id_images_dir: \"" << pm_id_images_dir << "\",\n" - << " pm_id_embed_path: \"" << pm_id_embed_path << "\",\n" - << " pm_style_strength: " << pm_style_strength << ",\n" - << " skip_layers: " << vec_to_string(skip_layers) << ",\n" - << " sample_params: " << sample_params_str << ",\n" - << " high_noise_skip_layers: " << vec_to_string(high_noise_skip_layers) << ",\n" - << " high_noise_sample_params: " << high_noise_sample_params_str << ",\n" - << " custom_sigmas: " << vec_to_string(custom_sigmas) << ",\n" - << " cache_mode: \"" << cache_mode << "\",\n" - << " cache_option: \"" << cache_option << "\",\n" - << " cache: " - << (cache_params.mode != SD_CACHE_DISABLED ? 
"enabled" : "disabled") - << " (threshold=" << cache_params.reuse_threshold - << ", start=" << cache_params.start_percent - << ", end=" << cache_params.end_percent << "),\n" - << " moe_boundary: " << moe_boundary << ",\n" - << " video_frames: " << video_frames << ",\n" - << " fps: " << fps << ",\n" - << " vace_strength: " << vace_strength << ",\n" - << " strength: " << strength << ",\n" - << " control_strength: " << control_strength << ",\n" - << " seed: " << seed << ",\n" - << " upscale_repeats: " << upscale_repeats << ",\n" - << " upscale_tile_size: " << upscale_tile_size << ",\n" - << " vae_tiling_params: { " - << vae_tiling_params.enabled << ", " - << vae_tiling_params.tile_size_x << ", " - << vae_tiling_params.tile_size_y << ", " - << vae_tiling_params.target_overlap << ", " - << vae_tiling_params.rel_size_x << ", " - << vae_tiling_params.rel_size_y << " },\n" - << "}"; - free(sample_params_str); - free(high_noise_sample_params_str); - return oss.str(); - } -}; - -static std::string version_string() { - return std::string("stable-diffusion.cpp version ") + sd_version() + ", commit " + sd_commit(); -} - -std::string get_image_params(const SDContextParams& ctx_params, const SDGenerationParams& gen_params, int64_t seed) { - std::string parameter_string; - if (gen_params.prompt_with_lora.size() != 0) { - parameter_string += gen_params.prompt_with_lora + "\n"; - } else { - parameter_string += gen_params.prompt + "\n"; - } - if (gen_params.negative_prompt.size() != 0) { - parameter_string += "Negative prompt: " + gen_params.negative_prompt + "\n"; - } - parameter_string += "Steps: " + std::to_string(gen_params.sample_params.sample_steps) + ", "; - parameter_string += "CFG scale: " + std::to_string(gen_params.sample_params.guidance.txt_cfg) + ", "; - if (gen_params.sample_params.guidance.slg.scale != 0 && gen_params.skip_layers.size() != 0) { - parameter_string += "SLG scale: " + std::to_string(gen_params.sample_params.guidance.txt_cfg) + ", "; - parameter_string += 
"Skip layers: ["; - for (const auto& layer : gen_params.skip_layers) { - parameter_string += std::to_string(layer) + ", "; - } - parameter_string += "], "; - parameter_string += "Skip layer start: " + std::to_string(gen_params.sample_params.guidance.slg.layer_start) + ", "; - parameter_string += "Skip layer end: " + std::to_string(gen_params.sample_params.guidance.slg.layer_end) + ", "; - } - parameter_string += "Guidance: " + std::to_string(gen_params.sample_params.guidance.distilled_guidance) + ", "; - parameter_string += "Eta: " + std::to_string(gen_params.sample_params.eta) + ", "; - parameter_string += "Seed: " + std::to_string(seed) + ", "; - parameter_string += "Size: " + std::to_string(gen_params.get_resolved_width()) + "x" + std::to_string(gen_params.get_resolved_height()) + ", "; - parameter_string += "Model: " + sd_basename(ctx_params.model_path) + ", "; - parameter_string += "RNG: " + std::string(sd_rng_type_name(ctx_params.rng_type)) + ", "; - if (ctx_params.sampler_rng_type != RNG_TYPE_COUNT) { - parameter_string += "Sampler RNG: " + std::string(sd_rng_type_name(ctx_params.sampler_rng_type)) + ", "; - } - parameter_string += "Sampler: " + std::string(sd_sample_method_name(gen_params.sample_params.sample_method)); - if (!gen_params.custom_sigmas.empty()) { - parameter_string += ", Custom Sigmas: ["; - for (size_t i = 0; i < gen_params.custom_sigmas.size(); ++i) { - std::ostringstream oss; - oss << std::fixed << std::setprecision(4) << gen_params.custom_sigmas[i]; - parameter_string += oss.str() + (i == gen_params.custom_sigmas.size() - 1 ? 
"" : ", "); - } - parameter_string += "]"; - } else if (gen_params.sample_params.scheduler != SCHEDULER_COUNT) { // Only show schedule if not using custom sigmas - parameter_string += " " + std::string(sd_scheduler_name(gen_params.sample_params.scheduler)); - } - parameter_string += ", "; - for (const auto& te : {ctx_params.clip_l_path, ctx_params.clip_g_path, ctx_params.t5xxl_path, ctx_params.llm_path, ctx_params.llm_vision_path}) { - if (!te.empty()) { - parameter_string += "TE: " + sd_basename(te) + ", "; - } - } - if (!ctx_params.diffusion_model_path.empty()) { - parameter_string += "Unet: " + sd_basename(ctx_params.diffusion_model_path) + ", "; - } - if (!ctx_params.vae_path.empty()) { - parameter_string += "VAE: " + sd_basename(ctx_params.vae_path) + ", "; - } - if (gen_params.clip_skip != -1) { - parameter_string += "Clip skip: " + std::to_string(gen_params.clip_skip) + ", "; - } - parameter_string += "Version: stable-diffusion.cpp"; - return parameter_string; -} diff --git a/examples/common/log.cpp b/examples/common/log.cpp index 44fcd1e43..2c4343912 100644 --- a/examples/common/log.cpp +++ b/examples/common/log.cpp @@ -1,5 +1,7 @@ #include "log.h" +#include + bool log_verbose = false; bool log_color = false; @@ -34,17 +36,12 @@ void print_utf8(FILE* stream, const char* utf8) { return; } - wchar_t* wbuf = (wchar_t*)malloc(wlen * sizeof(wchar_t)); - if (!wbuf) { - return; - } + std::vector wbuf(static_cast(wlen)); - MultiByteToWideChar(CP_UTF8, 0, utf8, -1, wbuf, wlen); + MultiByteToWideChar(CP_UTF8, 0, utf8, -1, wbuf.data(), wlen); DWORD written; - WriteConsoleW(h, wbuf, wlen - 1, &written, NULL); - - free(wbuf); + WriteConsoleW(h, wbuf.data(), wlen - 1, &written, NULL); } else { DWORD written; WriteFile(h, utf8, (DWORD)strlen(utf8), &written, NULL); diff --git a/examples/common/media_io.cpp b/examples/common/media_io.cpp index a38513b9d..e2e1ca5a3 100644 --- a/examples/common/media_io.cpp +++ b/examples/common/media_io.cpp @@ -1,5 +1,6 @@ -#include "log.h" 
#include "media_io.h" +#include "log.h" +#include "resource_owners.hpp" #include #include @@ -30,9 +31,121 @@ #include "webp/mux.h" #endif +#ifdef SD_USE_WEBM +#include "mkvmuxer/mkvmuxer.h" +#include "mkvmuxer/mkvwriter.h" +#endif + namespace fs = std::filesystem; -namespace { +#ifdef SD_USE_WEBP +struct WebPFreeDeleter { + void operator()(void* ptr) const { + if (ptr != nullptr) { + WebPFree(ptr); + } + } +}; + +struct WebPMuxDeleter { + void operator()(WebPMux* mux) const { + if (mux != nullptr) { + WebPMuxDelete(mux); + } + } +}; + +struct WebPAnimEncoderDeleter { + void operator()(WebPAnimEncoder* enc) const { + if (enc != nullptr) { + WebPAnimEncoderDelete(enc); + } + } +}; + +struct WebPDataGuard { + WebPDataGuard() { + WebPDataInit(&data); + } + + ~WebPDataGuard() { + WebPDataClear(&data); + } + + WebPData data; +}; + +struct WebPPictureGuard { + WebPPictureGuard() + : initialized(WebPPictureInit(&picture) != 0) { + } + + ~WebPPictureGuard() { + if (initialized) { + WebPPictureFree(&picture); + } + } + + WebPPicture picture; + bool initialized; +}; + +using WebPBufferPtr = std::unique_ptr; +using WebPMuxPtr = std::unique_ptr; +using WebPAnimEncoderPtr = std::unique_ptr; +#endif + +#ifdef SD_USE_WEBM +class MemoryMkvWriter : public mkvmuxer::IMkvWriter { +public: + mkvmuxer::int32 Write(const void* buf, mkvmuxer::uint32 len) override { + if (buf == nullptr && len > 0) { + return -1; + } + const size_t end_pos = position_ + static_cast(len); + if (end_pos > data_.size()) { + data_.resize(end_pos); + } + if (len > 0) { + memcpy(data_.data() + position_, buf, len); + } + position_ = end_pos; + return 0; + } + + mkvmuxer::int64 Position() const override { + return static_cast(position_); + } + + mkvmuxer::int32 Position(mkvmuxer::int64 position) override { + if (position < 0) { + return -1; + } + const size_t target = static_cast(position); + if (target > data_.size()) { + data_.resize(target); + } + position_ = target; + return 0; + } + + bool Seekable() const 
override { + return true; + } + + void ElementStartNotify(mkvmuxer::uint64, mkvmuxer::int64) override { + } + + const std::vector& data() const { + return data_; + } + +private: + std::vector data_; + size_t position_ = 0; +}; +#endif + bool read_binary_file_bytes(const char* path, std::vector& data) { std::ifstream fin(fs::path(path), std::ios::binary); if (!fin) { @@ -71,6 +184,13 @@ bool write_binary_file_bytes(const std::string& path, const std::vector return true; } +uint32_t read_u32_le_bytes(const uint8_t* data) { + return static_cast(data[0]) | + (static_cast(data[1]) << 8) | + (static_cast(data[2]) << 16) | + (static_cast(data[3]) << 24); +} + int stbi_ext_write_png_to_func(stbi_write_func* func, void* context, int x, @@ -146,27 +266,25 @@ uint8_t* decode_webp_image_to_buffer(const uint8_t* data, if (expected_channel == 1) { int decoded_width = width; int decoded_height = height; - uint8_t* decoded = features.has_alpha - ? WebPDecodeRGBA(data, size, &decoded_width, &decoded_height) - : WebPDecodeRGB(data, size, &decoded_width, &decoded_height); + WebPBufferPtr decoded(features.has_alpha + ? WebPDecodeRGBA(data, size, &decoded_width, &decoded_height) + : WebPDecodeRGB(data, size, &decoded_width, &decoded_height)); if (decoded == nullptr) { return nullptr; } - uint8_t* grayscale = (uint8_t*)malloc(pixel_count); + FreeUniquePtr grayscale((uint8_t*)malloc(pixel_count)); if (grayscale == nullptr) { - WebPFree(decoded); return nullptr; } const int decoded_channels = features.has_alpha ? 
4 : 3; for (size_t i = 0; i < pixel_count; ++i) { - const uint8_t* src = decoded + i * decoded_channels; - grayscale[i] = static_cast((77 * src[0] + 150 * src[1] + 29 * src[2] + 128) >> 8); + const uint8_t* src = decoded.get() + i * decoded_channels; + grayscale.get()[i] = static_cast((77 * src[0] + 150 * src[1] + 29 * src[2] + 128) >> 8); } - WebPFree(decoded); - return grayscale; + return grayscale.release(); } if (expected_channel != 3 && expected_channel != 4) { @@ -175,23 +293,21 @@ uint8_t* decode_webp_image_to_buffer(const uint8_t* data, int decoded_width = width; int decoded_height = height; - uint8_t* decoded = (expected_channel == 4) - ? WebPDecodeRGBA(data, size, &decoded_width, &decoded_height) - : WebPDecodeRGB(data, size, &decoded_width, &decoded_height); + WebPBufferPtr decoded((expected_channel == 4) + ? WebPDecodeRGBA(data, size, &decoded_width, &decoded_height) + : WebPDecodeRGB(data, size, &decoded_width, &decoded_height)); if (decoded == nullptr) { return nullptr; } const size_t out_size = pixel_count * static_cast(expected_channel); - uint8_t* output = (uint8_t*)malloc(out_size); + FreeUniquePtr output((uint8_t*)malloc(out_size)); if (output == nullptr) { - WebPFree(decoded); return nullptr; } - memcpy(output, decoded, out_size); - WebPFree(decoded); - return output; + memcpy(output.get(), decoded.get(), out_size); + return output.release(); } std::string build_webp_xmp_packet(const std::string& parameters) { @@ -243,30 +359,29 @@ bool encode_webp_image_to_vector(const uint8_t* image, return false; } - uint8_t* encoded = nullptr; - size_t encoded_size = (input_channels == 4) - ? WebPEncodeRGBA(input_image, width, height, width * input_channels, static_cast(quality), &encoded) - : WebPEncodeRGB(input_image, width, height, width * input_channels, static_cast(quality), &encoded); + uint8_t* encoded_raw = nullptr; + size_t encoded_size = (input_channels == 4) + ? 
WebPEncodeRGBA(input_image, width, height, width * input_channels, static_cast(quality), &encoded_raw) + : WebPEncodeRGB(input_image, width, height, width * input_channels, static_cast(quality), &encoded_raw); + WebPBufferPtr encoded(encoded_raw); if (encoded == nullptr || encoded_size == 0) { return false; } - out.assign(encoded, encoded + encoded_size); - WebPFree(encoded); + out.assign(encoded.get(), encoded.get() + encoded_size); if (parameters.empty()) { return true; } WebPData image_data; - WebPData assembled_data; WebPDataInit(&image_data); - WebPDataInit(&assembled_data); + WebPDataGuard assembled_data; image_data.bytes = out.data(); image_data.size = out.size(); - WebPMux* mux = WebPMuxNew(); + WebPMuxPtr mux(WebPMuxNew()); if (mux == nullptr) { return false; } @@ -277,18 +392,86 @@ bool encode_webp_image_to_vector(const uint8_t* image, xmp_data.bytes = reinterpret_cast(xmp_packet.data()); xmp_data.size = xmp_packet.size(); - const bool ok = WebPMuxSetImage(mux, &image_data, 1) == WEBP_MUX_OK && - WebPMuxSetChunk(mux, "XMP ", &xmp_data, 1) == WEBP_MUX_OK && - WebPMuxAssemble(mux, &assembled_data) == WEBP_MUX_OK; + const bool ok = WebPMuxSetImage(mux.get(), &image_data, 1) == WEBP_MUX_OK && + WebPMuxSetChunk(mux.get(), "XMP ", &xmp_data, 1) == WEBP_MUX_OK && + WebPMuxAssemble(mux.get(), &assembled_data.data) == WEBP_MUX_OK; if (ok) { - out.assign(assembled_data.bytes, assembled_data.bytes + assembled_data.size); + out.assign(assembled_data.data.bytes, assembled_data.data.bytes + assembled_data.data.size); } - WebPDataClear(&assembled_data); - WebPMuxDelete(mux); return ok; } + +#ifdef SD_USE_WEBM +bool extract_vp8_frame_from_webp(const std::vector& webp_data, std::vector& vp8_frame) { + if (!is_webp_signature(webp_data.data(), webp_data.size())) { + return false; + } + + size_t offset = 12; + while (offset + 8 <= webp_data.size()) { + const uint8_t* chunk = webp_data.data() + offset; + const uint32_t chunk_len = read_u32_le_bytes(chunk + 4); + const size_t 
chunk_start = offset + 8; + const size_t padded_len = static_cast(chunk_len) + (chunk_len & 1u); + + if (chunk_start + chunk_len > webp_data.size()) { + return false; + } + + if (memcmp(chunk, "VP8 ", 4) == 0) { + vp8_frame.assign(webp_data.data() + chunk_start, + webp_data.data() + chunk_start + chunk_len); + return !vp8_frame.empty(); + } + + offset = chunk_start + padded_len; + } + + return false; +} + +bool encode_sd_image_to_vp8_frame(const sd_image_t& image, int quality, std::vector& vp8_frame) { + if (image.data == nullptr || image.width == 0 || image.height == 0) { + return false; + } + + const int width = static_cast(image.width); + const int height = static_cast(image.height); + const int input_channel = static_cast(image.channel); + if (input_channel != 1 && input_channel != 3 && input_channel != 4) { + return false; + } + + std::vector rgb_buffer; + const uint8_t* rgb_data = image.data; + if (input_channel == 1) { + rgb_buffer.resize(static_cast(width) * static_cast(height) * 3); + for (int i = 0; i < width * height; ++i) { + rgb_buffer[i * 3 + 0] = image.data[i]; + rgb_buffer[i * 3 + 1] = image.data[i]; + rgb_buffer[i * 3 + 2] = image.data[i]; + } + rgb_data = rgb_buffer.data(); + } else if (input_channel == 4) { + rgb_buffer.resize(static_cast(width) * static_cast(height) * 3); + for (int i = 0; i < width * height; ++i) { + rgb_buffer[i * 3 + 0] = image.data[i * 4 + 0]; + rgb_buffer[i * 3 + 1] = image.data[i * 4 + 1]; + rgb_buffer[i * 3 + 2] = image.data[i * 4 + 2]; + } + rgb_data = rgb_buffer.data(); + } + + std::vector encoded_webp; + if (!encode_webp_image_to_vector(rgb_data, width, height, 3, "", quality, encoded_webp)) { + return false; + } + + return extract_vp8_frame_from_webp(encoded_webp, vp8_frame); +} +#endif #endif uint8_t* load_image_common(bool from_memory, @@ -300,19 +483,19 @@ uint8_t* load_image_common(bool from_memory, int expected_height, int expected_channel) { const char* image_path; - uint8_t* image_buffer = nullptr; + 
FreeUniquePtr image_buffer; int source_channel_count = 0; #ifdef SD_USE_WEBP if (from_memory) { image_path = "memory"; if (len > 0 && is_webp_signature(reinterpret_cast(image_path_or_bytes), static_cast(len))) { - image_buffer = decode_webp_image_to_buffer(reinterpret_cast(image_path_or_bytes), - static_cast(len), - width, - height, - expected_channel, - source_channel_count); + image_buffer.reset(decode_webp_image_to_buffer(reinterpret_cast(image_path_or_bytes), + static_cast(len), + width, + height, + expected_channel, + source_channel_count)); } } else { image_path = image_path_or_bytes; @@ -326,12 +509,12 @@ uint8_t* load_image_common(bool from_memory, LOG_ERROR("load image from '%s' failed", image_path_or_bytes); return nullptr; } - image_buffer = decode_webp_image_to_buffer(file_bytes.data(), - file_bytes.size(), - width, - height, - expected_channel, - source_channel_count); + image_buffer.reset(decode_webp_image_to_buffer(file_bytes.data(), + file_bytes.size(), + width, + height, + expected_channel, + source_channel_count)); } } #endif @@ -339,15 +522,15 @@ uint8_t* load_image_common(bool from_memory, if (from_memory) { image_path = "memory"; if (image_buffer == nullptr) { - int c = 0; - image_buffer = (uint8_t*)stbi_load_from_memory((const stbi_uc*)image_path_or_bytes, len, &width, &height, &c, expected_channel); + int c = 0; + image_buffer.reset((uint8_t*)stbi_load_from_memory((const stbi_uc*)image_path_or_bytes, len, &width, &height, &c, expected_channel)); source_channel_count = c; } } else { image_path = image_path_or_bytes; if (image_buffer == nullptr) { - int c = 0; - image_buffer = (uint8_t*)stbi_load(image_path_or_bytes, &width, &height, &c, expected_channel); + int c = 0; + image_buffer.reset((uint8_t*)stbi_load(image_path_or_bytes, &width, &height, &c, expected_channel)); source_channel_count = c; } } @@ -362,17 +545,14 @@ uint8_t* load_image_common(bool from_memory, expected_channel, source_channel_count, image_path); - free(image_buffer); 
return nullptr; } if (width <= 0) { LOG_ERROR("error: the width of image must be greater than 0, image_path = %s", image_path); - free(image_buffer); return nullptr; } if (height <= 0) { LOG_ERROR("error: the height of image must be greater than 0, image_path = %s", image_path); - free(image_buffer); return nullptr; } @@ -393,43 +573,39 @@ uint8_t* load_image_common(bool from_memory, if (crop_x != 0 || crop_y != 0) { LOG_INFO("crop input image from %dx%d to %dx%d, image_path = %s", width, height, crop_w, crop_h, image_path); - uint8_t* cropped_image_buffer = (uint8_t*)malloc(crop_w * crop_h * expected_channel); + FreeUniquePtr cropped_image_buffer((uint8_t*)malloc(crop_w * crop_h * expected_channel)); if (cropped_image_buffer == nullptr) { LOG_ERROR("error: allocate memory for crop\n"); - free(image_buffer); return nullptr; } for (int row = 0; row < crop_h; row++) { - uint8_t* src = image_buffer + ((crop_y + row) * width + crop_x) * expected_channel; - uint8_t* dst = cropped_image_buffer + (row * crop_w) * expected_channel; + uint8_t* src = image_buffer.get() + ((crop_y + row) * width + crop_x) * expected_channel; + uint8_t* dst = cropped_image_buffer.get() + (row * crop_w) * expected_channel; memcpy(dst, src, crop_w * expected_channel); } - width = crop_w; - height = crop_h; - free(image_buffer); - image_buffer = cropped_image_buffer; + width = crop_w; + height = crop_h; + image_buffer = std::move(cropped_image_buffer); } LOG_INFO("resize input image from %dx%d to %dx%d", width, height, expected_width, expected_height); - uint8_t* resized_image_buffer = (uint8_t*)malloc(expected_height * expected_width * expected_channel); + FreeUniquePtr resized_image_buffer((uint8_t*)malloc(expected_height * expected_width * expected_channel)); if (resized_image_buffer == nullptr) { LOG_ERROR("error: allocate memory for resize input image\n"); - free(image_buffer); return nullptr; } - stbir_resize(image_buffer, width, height, 0, - resized_image_buffer, expected_width, 
expected_height, 0, STBIR_TYPE_UINT8, + stbir_resize(image_buffer.get(), width, height, 0, + resized_image_buffer.get(), expected_width, expected_height, 0, STBIR_TYPE_UINT8, expected_channel, STBIR_ALPHA_CHANNEL_NONE, 0, STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP, STBIR_FILTER_BOX, STBIR_FILTER_BOX, STBIR_COLORSPACE_SRGB, nullptr); - width = expected_width; - height = expected_height; - free(image_buffer); - image_buffer = resized_image_buffer; + width = expected_width; + height = expected_height; + image_buffer = std::move(resized_image_buffer); } - return image_buffer; + return image_buffer.release(); } typedef struct { @@ -444,7 +620,32 @@ void write_u32_le(FILE* f, uint32_t val) { void write_u16_le(FILE* f, uint16_t val) { fwrite(&val, 2, 1, f); } -} // namespace + +void write_u32_le(std::vector& data, uint32_t val) { + data.push_back(static_cast(val & 0xFF)); + data.push_back(static_cast((val >> 8) & 0xFF)); + data.push_back(static_cast((val >> 16) & 0xFF)); + data.push_back(static_cast((val >> 24) & 0xFF)); +} + +void write_u16_le(std::vector& data, uint16_t val) { + data.push_back(static_cast(val & 0xFF)); + data.push_back(static_cast((val >> 8) & 0xFF)); +} + +void patch_u32_le(std::vector& data, size_t offset, uint32_t val) { + if (offset + 4 > data.size()) { + return; + } + data[offset + 0] = static_cast(val & 0xFF); + data[offset + 1] = static_cast((val >> 8) & 0xFF); + data[offset + 2] = static_cast((val >> 16) & 0xFF); + data[offset + 3] = static_cast((val >> 24) & 0xFF); +} + +void write_fourcc(std::vector& data, const char* fourcc) { + data.insert(data.end(), fourcc, fourcc + 4); +} EncodedImageFormat encoded_image_format_from_path(const std::string& path) { std::string ext = fs::path(path).extension().string(); @@ -559,8 +760,9 @@ bool load_sd_image_from_file(sd_image_t* image, if (image->data == nullptr) { return false; } - image->width = width; - image->height = height; + image->width = width; + image->height = height; + image->channel = 
expected_channel; return true; } @@ -574,16 +776,10 @@ uint8_t* load_image_from_memory(const char* image_bytes, return load_image_common(true, image_bytes, len, width, height, expected_width, expected_height, expected_channel); } -int create_mjpg_avi_from_sd_images(const char* filename, sd_image_t* images, int num_images, int fps, int quality) { +std::vector create_mjpg_avi_from_sd_images_to_vector(sd_image_t* images, int num_images, int fps, int quality) { if (num_images == 0) { fprintf(stderr, "Error: Image array is empty.\n"); - return -1; - } - - FILE* f = fopen(filename, "wb"); - if (!f) { - perror("Error opening file for writing"); - return -1; + return {}; } uint32_t width = images[0].width; @@ -591,152 +787,153 @@ int create_mjpg_avi_from_sd_images(const char* filename, sd_image_t* images, int uint32_t channels = images[0].channel; if (channels != 3 && channels != 4) { fprintf(stderr, "Error: Unsupported channel count: %u\n", channels); - fclose(f); - return -1; - } - - fwrite("RIFF", 4, 1, f); - long riff_size_pos = ftell(f); - write_u32_le(f, 0); - fwrite("AVI ", 4, 1, f); - - fwrite("LIST", 4, 1, f); - write_u32_le(f, 4 + 8 + 56 + 8 + 4 + 8 + 56 + 8 + 40); - fwrite("hdrl", 4, 1, f); - - fwrite("avih", 4, 1, f); - write_u32_le(f, 56); - write_u32_le(f, 1000000 / fps); - write_u32_le(f, 0); - write_u32_le(f, 0); - write_u32_le(f, 0x110); - write_u32_le(f, num_images); - write_u32_le(f, 0); - write_u32_le(f, 1); - write_u32_le(f, width * height * 3); - write_u32_le(f, width); - write_u32_le(f, height); - write_u32_le(f, 0); - write_u32_le(f, 0); - write_u32_le(f, 0); - write_u32_le(f, 0); - - fwrite("LIST", 4, 1, f); - write_u32_le(f, 4 + 8 + 56 + 8 + 40); - fwrite("strl", 4, 1, f); - - fwrite("strh", 4, 1, f); - write_u32_le(f, 56); - fwrite("vids", 4, 1, f); - fwrite("MJPG", 4, 1, f); - write_u32_le(f, 0); - write_u16_le(f, 0); - write_u16_le(f, 0); - write_u32_le(f, 0); - write_u32_le(f, 1); - write_u32_le(f, fps); - write_u32_le(f, 0); - write_u32_le(f, 
num_images); - write_u32_le(f, width * height * 3); - write_u32_le(f, (uint32_t)-1); - write_u32_le(f, 0); - write_u16_le(f, 0); - write_u16_le(f, 0); - write_u16_le(f, 0); - write_u16_le(f, 0); - - fwrite("strf", 4, 1, f); - write_u32_le(f, 40); - write_u32_le(f, 40); - write_u32_le(f, width); - write_u32_le(f, height); - write_u16_le(f, 1); - write_u16_le(f, 24); - fwrite("MJPG", 4, 1, f); - write_u32_le(f, width * height * 3); - write_u32_le(f, 0); - write_u32_le(f, 0); - write_u32_le(f, 0); - write_u32_le(f, 0); - - fwrite("LIST", 4, 1, f); - long movi_size_pos = ftell(f); - write_u32_le(f, 0); - fwrite("movi", 4, 1, f); - - avi_index_entry* index = (avi_index_entry*)malloc(sizeof(avi_index_entry) * num_images); - if (!index) { - fclose(f); - return -1; + return {}; } - struct { - uint8_t* buf; - size_t size; - } jpeg_data; + // stb_image_write changes JPEG sampling behavior above quality 90. + // MJPG AVI playback is more compatible when we keep the encoder on the + // <= 90 path. + const int mjpg_quality = std::clamp(quality, 1, 90); + + std::vector avi_data; + avi_data.reserve(static_cast(num_images) * 1024); + + write_fourcc(avi_data, "RIFF"); + const size_t riff_size_pos = avi_data.size(); + write_u32_le(avi_data, 0); + write_fourcc(avi_data, "AVI "); + + write_fourcc(avi_data, "LIST"); + write_u32_le(avi_data, 4 + 8 + 56 + 8 + 4 + 8 + 56 + 8 + 40); + write_fourcc(avi_data, "hdrl"); + + write_fourcc(avi_data, "avih"); + write_u32_le(avi_data, 56); + write_u32_le(avi_data, 1000000 / fps); + write_u32_le(avi_data, 0); + write_u32_le(avi_data, 0); + write_u32_le(avi_data, 0x110); + write_u32_le(avi_data, num_images); + write_u32_le(avi_data, 0); + write_u32_le(avi_data, 1); + write_u32_le(avi_data, width * height * 3); + write_u32_le(avi_data, width); + write_u32_le(avi_data, height); + write_u32_le(avi_data, 0); + write_u32_le(avi_data, 0); + write_u32_le(avi_data, 0); + write_u32_le(avi_data, 0); + + write_fourcc(avi_data, "LIST"); + write_u32_le(avi_data, 
4 + 8 + 56 + 8 + 40); + write_fourcc(avi_data, "strl"); + + write_fourcc(avi_data, "strh"); + write_u32_le(avi_data, 56); + write_fourcc(avi_data, "vids"); + write_fourcc(avi_data, "MJPG"); + write_u32_le(avi_data, 0); + write_u16_le(avi_data, 0); + write_u16_le(avi_data, 0); + write_u32_le(avi_data, 0); + write_u32_le(avi_data, 1); + write_u32_le(avi_data, fps); + write_u32_le(avi_data, 0); + write_u32_le(avi_data, num_images); + write_u32_le(avi_data, width * height * 3); + write_u32_le(avi_data, static_cast(-1)); + write_u32_le(avi_data, 0); + write_u16_le(avi_data, 0); + write_u16_le(avi_data, 0); + write_u16_le(avi_data, 0); + write_u16_le(avi_data, 0); + + write_fourcc(avi_data, "strf"); + write_u32_le(avi_data, 40); + write_u32_le(avi_data, 40); + write_u32_le(avi_data, width); + write_u32_le(avi_data, height); + write_u16_le(avi_data, 1); + write_u16_le(avi_data, 24); + write_fourcc(avi_data, "MJPG"); + write_u32_le(avi_data, width * height * 3); + write_u32_le(avi_data, 0); + write_u32_le(avi_data, 0); + write_u32_le(avi_data, 0); + write_u32_le(avi_data, 0); + + write_fourcc(avi_data, "LIST"); + const size_t movi_size_pos = avi_data.size(); + write_u32_le(avi_data, 0); + write_fourcc(avi_data, "movi"); + + std::vector index(static_cast(num_images)); + std::vector jpeg_data; for (int i = 0; i < num_images; i++) { - jpeg_data.buf = nullptr; - jpeg_data.size = 0; + jpeg_data.clear(); auto write_to_buf = [](void* context, void* data, int size) { - auto jd = (decltype(jpeg_data)*)context; - jd->buf = (uint8_t*)realloc(jd->buf, jd->size + size); - memcpy(jd->buf + jd->size, data, size); - jd->size += size; + auto* buffer = reinterpret_cast*>(context); + const uint8_t* src = reinterpret_cast(data); + buffer->insert(buffer->end(), src, src + size); }; - stbi_write_jpg_to_func(write_to_buf, &jpeg_data, images[i].width, images[i].height, channels, images[i].data, quality); + if (!stbi_write_jpg_to_func(write_to_buf, &jpeg_data, images[i].width, images[i].height, 
channels, images[i].data, mjpg_quality)) { + fprintf(stderr, "Error: Failed to encode JPEG frame.\n"); + return {}; + } - fwrite("00dc", 4, 1, f); - write_u32_le(f, (uint32_t)jpeg_data.size); - index[i].offset = ftell(f) - 8; - index[i].size = (uint32_t)jpeg_data.size; - fwrite(jpeg_data.buf, 1, jpeg_data.size, f); + index[i].offset = static_cast(avi_data.size()); + write_fourcc(avi_data, "00dc"); + write_u32_le(avi_data, static_cast(jpeg_data.size())); + index[i].size = (uint32_t)jpeg_data.size(); + avi_data.insert(avi_data.end(), jpeg_data.begin(), jpeg_data.end()); - if (jpeg_data.size % 2) { - fputc(0, f); + if (jpeg_data.size() % 2) { + avi_data.push_back(0); } - - free(jpeg_data.buf); } - long cur_pos = ftell(f); - long movi_size = cur_pos - movi_size_pos - 4; - fseek(f, movi_size_pos, SEEK_SET); - write_u32_le(f, movi_size); - fseek(f, cur_pos, SEEK_SET); + const size_t movi_size = avi_data.size() - movi_size_pos - 4; + patch_u32_le(avi_data, movi_size_pos, static_cast(movi_size)); - fwrite("idx1", 4, 1, f); - write_u32_le(f, num_images * 16); + write_fourcc(avi_data, "idx1"); + write_u32_le(avi_data, num_images * 16); for (int i = 0; i < num_images; i++) { - fwrite("00dc", 4, 1, f); - write_u32_le(f, 0x10); - write_u32_le(f, index[i].offset); - write_u32_le(f, index[i].size); + write_fourcc(avi_data, "00dc"); + write_u32_le(avi_data, 0x10); + write_u32_le(avi_data, index[i].offset); + write_u32_le(avi_data, index[i].size); } - cur_pos = ftell(f); - long file_size = cur_pos - riff_size_pos - 4; - fseek(f, riff_size_pos, SEEK_SET); - write_u32_le(f, file_size); - fseek(f, cur_pos, SEEK_SET); + const size_t file_size = avi_data.size() - riff_size_pos - 4; + patch_u32_le(avi_data, riff_size_pos, static_cast(file_size)); - fclose(f); - free(index); + return avi_data; +} +int create_mjpg_avi_from_sd_images(const char* filename, sd_image_t* images, int num_images, int fps, int quality) { + std::vector avi_data = create_mjpg_avi_from_sd_images_to_vector(images, 
num_images, fps, quality); + if (avi_data.empty()) { + return -1; + } + if (!write_binary_file_bytes(filename, avi_data)) { + perror("Error opening file for writing"); + return -1; + } return 0; } #ifdef SD_USE_WEBP -int create_animated_webp_from_sd_images(const char* filename, sd_image_t* images, int num_images, int fps, int quality) { +std::vector create_animated_webp_from_sd_images_to_vector(sd_image_t* images, int num_images, int fps, int quality) { if (num_images == 0) { fprintf(stderr, "Error: Image array is empty.\n"); - return -1; + return {}; } if (fps <= 0) { fprintf(stderr, "Error: FPS must be positive.\n"); - return -1; + return {}; } const int width = static_cast(images[0].width); @@ -744,14 +941,14 @@ int create_animated_webp_from_sd_images(const char* filename, sd_image_t* images const int channels = static_cast(images[0].channel); if (channels != 1 && channels != 3 && channels != 4) { fprintf(stderr, "Error: Unsupported channel count: %d\n", channels); - return -1; + return {}; } WebPAnimEncoderOptions anim_options; WebPConfig config; if (!WebPAnimEncoderOptionsInit(&anim_options) || !WebPConfigInit(&config)) { fprintf(stderr, "Error: Failed to initialize WebP animation encoder.\n"); - return -1; + return {}; } config.quality = static_cast(quality); @@ -762,34 +959,33 @@ int create_animated_webp_from_sd_images(const char* filename, sd_image_t* images } if (!WebPValidateConfig(&config)) { fprintf(stderr, "Error: Invalid WebP encoder configuration.\n"); - return -1; + return {}; } - WebPAnimEncoder* enc = WebPAnimEncoderNew(width, height, &anim_options); + WebPAnimEncoderPtr enc(WebPAnimEncoderNew(width, height, &anim_options)); if (enc == nullptr) { fprintf(stderr, "Error: Could not create WebPAnimEncoder object.\n"); - return -1; + return {}; } const int frame_duration_ms = std::max(1, static_cast(std::lround(1000.0 / static_cast(fps)))); int timestamp_ms = 0; - int ret = -1; for (int i = 0; i < num_images; ++i) { const sd_image_t& image = 
images[i]; if (static_cast(image.width) != width || static_cast(image.height) != height) { fprintf(stderr, "Error: Frame dimensions do not match.\n"); - goto cleanup; + return {}; } - WebPPicture picture; - if (!WebPPictureInit(&picture)) { + WebPPictureGuard picture; + if (!picture.initialized) { fprintf(stderr, "Error: Failed to initialize WebPPicture.\n"); - goto cleanup; + return {}; } - picture.use_argb = 1; - picture.width = width; - picture.height = height; + picture.picture.use_argb = 1; + picture.picture.width = width; + picture.picture.height = height; bool picture_ok = false; std::vector rgb_buffer; @@ -800,80 +996,194 @@ int create_animated_webp_from_sd_images(const char* filename, sd_image_t* images rgb_buffer[p * 3 + 1] = image.data[p]; rgb_buffer[p * 3 + 2] = image.data[p]; } - picture_ok = WebPPictureImportRGB(&picture, rgb_buffer.data(), width * 3) != 0; + picture_ok = WebPPictureImportRGB(&picture.picture, rgb_buffer.data(), width * 3) != 0; } else if (image.channel == 4) { - picture_ok = WebPPictureImportRGBA(&picture, image.data, width * 4) != 0; + picture_ok = WebPPictureImportRGBA(&picture.picture, image.data, width * 4) != 0; } else { - picture_ok = WebPPictureImportRGB(&picture, image.data, width * 3) != 0; + picture_ok = WebPPictureImportRGB(&picture.picture, image.data, width * 3) != 0; } if (!picture_ok) { fprintf(stderr, "Error: Failed to import frame into WebPPicture.\n"); - WebPPictureFree(&picture); - goto cleanup; + return {}; } - if (!WebPAnimEncoderAdd(enc, &picture, timestamp_ms, &config)) { - fprintf(stderr, "Error: Failed to add frame to animated WebP: %s\n", WebPAnimEncoderGetError(enc)); - WebPPictureFree(&picture); - goto cleanup; + if (!WebPAnimEncoderAdd(enc.get(), &picture.picture, timestamp_ms, &config)) { + fprintf(stderr, "Error: Failed to add frame to animated WebP: %s\n", WebPAnimEncoderGetError(enc.get())); + return {}; } - WebPPictureFree(&picture); timestamp_ms += frame_duration_ms; } - if (!WebPAnimEncoderAdd(enc, 
nullptr, timestamp_ms, nullptr)) { - fprintf(stderr, "Error: Failed to finalize animated WebP frames: %s\n", WebPAnimEncoderGetError(enc)); - goto cleanup; + if (!WebPAnimEncoderAdd(enc.get(), nullptr, timestamp_ms, nullptr)) { + fprintf(stderr, "Error: Failed to finalize animated WebP frames: %s\n", WebPAnimEncoderGetError(enc.get())); + return {}; + } + + WebPDataGuard webp_data; + if (!WebPAnimEncoderAssemble(enc.get(), &webp_data.data)) { + fprintf(stderr, "Error: Failed to assemble animated WebP: %s\n", WebPAnimEncoderGetError(enc.get())); + return {}; + } + + return std::vector(webp_data.data.bytes, webp_data.data.bytes + webp_data.data.size); +} + +int create_animated_webp_from_sd_images(const char* filename, sd_image_t* images, int num_images, int fps, int quality) { + std::vector webp_data = create_animated_webp_from_sd_images_to_vector(images, num_images, fps, quality); + if (webp_data.empty()) { + return -1; } + if (!write_binary_file_bytes(filename, webp_data)) { + perror("Error opening file for writing"); + return -1; + } + return 0; +} +#endif - { - WebPData webp_data; - WebPDataInit(&webp_data); - if (!WebPAnimEncoderAssemble(enc, &webp_data)) { - fprintf(stderr, "Error: Failed to assemble animated WebP: %s\n", WebPAnimEncoderGetError(enc)); - WebPDataClear(&webp_data); - goto cleanup; +#ifdef SD_USE_WEBM +std::vector create_webm_from_sd_images_to_vector(sd_image_t* images, int num_images, int fps, int quality) { + if (num_images == 0) { + fprintf(stderr, "Error: Image array is empty.\n"); + return {}; + } + if (fps <= 0) { + fprintf(stderr, "Error: FPS must be positive.\n"); + return {}; + } + + const int width = static_cast(images[0].width); + const int height = static_cast(images[0].height); + if (width <= 0 || height <= 0) { + fprintf(stderr, "Error: Invalid frame dimensions.\n"); + return {}; + } + + MemoryMkvWriter writer; + + const int ret = [&]() -> int { + mkvmuxer::Segment segment; + if (!segment.Init(&writer)) { + fprintf(stderr, "Error: 
Failed to initialize WebM muxer.\n"); + return -1; } - FILE* f = fopen(filename, "wb"); - if (!f) { - perror("Error opening file for writing"); - WebPDataClear(&webp_data); - goto cleanup; + segment.set_mode(mkvmuxer::Segment::kFile); + segment.OutputCues(true); + + const uint64_t track_number = segment.AddVideoTrack(width, height, 0); + if (track_number == 0) { + fprintf(stderr, "Error: Failed to add VP8 video track.\n"); + return -1; } - if (webp_data.size > 0 && fwrite(webp_data.bytes, 1, webp_data.size, f) != webp_data.size) { - fprintf(stderr, "Error: Failed to write animated WebP file.\n"); - fclose(f); - WebPDataClear(&webp_data); - goto cleanup; + if (!segment.CuesTrack(track_number)) { + fprintf(stderr, "Error: Failed to set WebM cues track.\n"); + return -1; } - fclose(f); - WebPDataClear(&webp_data); - } - ret = 0; + mkvmuxer::VideoTrack* video_track = static_cast(segment.GetTrackByNumber(track_number)); + if (video_track != nullptr) { + video_track->set_display_width(static_cast(width)); + video_track->set_display_height(static_cast(height)); + video_track->set_frame_rate(static_cast(fps)); + } + segment.GetSegmentInfo()->set_writing_app("stable-diffusion.cpp"); + segment.GetSegmentInfo()->set_muxing_app("stable-diffusion.cpp"); + + const uint64_t frame_duration_ns = std::max( + 1, static_cast(std::llround(1000000000.0 / static_cast(fps)))); + uint64_t timestamp_ns = 0; + + for (int i = 0; i < num_images; ++i) { + const sd_image_t& image = images[i]; + if (static_cast(image.width) != width || static_cast(image.height) != height) { + fprintf(stderr, "Error: Frame dimensions do not match.\n"); + return -1; + } + + std::vector vp8_frame; + if (!encode_sd_image_to_vp8_frame(image, quality, vp8_frame)) { + fprintf(stderr, "Error: Failed to encode frame %d as VP8.\n", i); + return -1; + } + + if (!segment.AddFrame(vp8_frame.data(), + static_cast(vp8_frame.size()), + track_number, + timestamp_ns, + true)) { + fprintf(stderr, "Error: Failed to mux frame %d into 
WebM.\n", i); + return -1; + } + + timestamp_ns += frame_duration_ns; + } + + if (!segment.Finalize()) { + fprintf(stderr, "Error: Failed to finalize WebM output.\n"); + return -1; + } + return 0; + }(); + if (ret != 0) { + return {}; + } + return writer.data(); +} -cleanup: - WebPAnimEncoderDelete(enc); - return ret; +int create_webm_from_sd_images(const char* filename, sd_image_t* images, int num_images, int fps, int quality) { + std::vector webm_data = create_webm_from_sd_images_to_vector(images, num_images, fps, quality); + if (webm_data.empty()) { + return -1; + } + if (!write_binary_file_bytes(filename, webm_data)) { + perror("Error opening file for writing"); + return -1; + } + return 0; } #endif -int create_video_from_sd_images(const char* filename, sd_image_t* images, int num_images, int fps, int quality) { - std::string path = filename ? filename : ""; - auto pos = path.find_last_of('.'); - std::string ext = pos == std::string::npos ? "" : path.substr(pos); - for (char& ch : ext) { - ch = static_cast(tolower(static_cast(ch))); +std::vector create_video_from_sd_images_to_vector(const std::string& output_format, + sd_image_t* images, + int num_images, + int fps, + int quality) { + std::string format = output_format; + std::transform(format.begin(), format.end(), format.begin(), + [](unsigned char c) { return static_cast(tolower(c)); }); + if (!format.empty() && format[0] == '.') { + format.erase(format.begin()); + } + +#ifdef SD_USE_WEBM + if (format == "webm") { + return create_webm_from_sd_images_to_vector(images, num_images, fps, quality); } +#endif #ifdef SD_USE_WEBP - if (ext == ".webp") { - return create_animated_webp_from_sd_images(filename, images, num_images, fps, quality); + if (format == "webp") { + return create_animated_webp_from_sd_images_to_vector(images, num_images, fps, quality); } #endif - return create_mjpg_avi_from_sd_images(filename, images, num_images, fps, quality); + return create_mjpg_avi_from_sd_images_to_vector(images, num_images, 
fps, quality); +} + +int create_video_from_sd_images(const char* filename, sd_image_t* images, int num_images, int fps, int quality) { + std::string path = filename ? filename : ""; + auto pos = path.find_last_of('.'); + std::string ext = pos == std::string::npos ? "" : path.substr(pos); + std::vector video_data = create_video_from_sd_images_to_vector(ext, images, num_images, fps, quality); + if (video_data.empty()) { + return -1; + } + if (!write_binary_file_bytes(filename, video_data)) { + perror("Error opening file for writing"); + return -1; + } + return 0; } diff --git a/examples/common/media_io.h b/examples/common/media_io.h index cb8302906..6b3f6f883 100644 --- a/examples/common/media_io.h +++ b/examples/common/media_io.h @@ -58,6 +58,10 @@ int create_mjpg_avi_from_sd_images(const char* filename, int num_images, int fps, int quality = 90); +std::vector create_mjpg_avi_from_sd_images_to_vector(sd_image_t* images, + int num_images, + int fps, + int quality = 90); #ifdef SD_USE_WEBP int create_animated_webp_from_sd_images(const char* filename, @@ -65,6 +69,22 @@ int create_animated_webp_from_sd_images(const char* filename, int num_images, int fps, int quality = 90); +std::vector create_animated_webp_from_sd_images_to_vector(sd_image_t* images, + int num_images, + int fps, + int quality = 90); +#endif + +#ifdef SD_USE_WEBM +int create_webm_from_sd_images(const char* filename, + sd_image_t* images, + int num_images, + int fps, + int quality = 90); +std::vector create_webm_from_sd_images_to_vector(sd_image_t* images, + int num_images, + int fps, + int quality = 90); #endif int create_video_from_sd_images(const char* filename, @@ -72,5 +92,10 @@ int create_video_from_sd_images(const char* filename, int num_images, int fps, int quality = 90); +std::vector create_video_from_sd_images_to_vector(const std::string& output_format, + sd_image_t* images, + int num_images, + int fps, + int quality = 90); #endif // __MEDIA_IO_H__ diff --git 
a/examples/common/resource_owners.hpp b/examples/common/resource_owners.hpp new file mode 100644 index 000000000..d47134abe --- /dev/null +++ b/examples/common/resource_owners.hpp @@ -0,0 +1,236 @@ +#ifndef __EXAMPLE_RESOURCE_OWNERS_H__ +#define __EXAMPLE_RESOURCE_OWNERS_H__ + +#include +#include +#include +#include +#include +#include + +#include "stable-diffusion.h" + +struct FreeDeleter { + void operator()(void* ptr) const { + free(ptr); + } +}; + +struct FileCloser { + void operator()(FILE* file) const { + if (file != nullptr) { + fclose(file); + } + } +}; + +struct SDCtxDeleter { + void operator()(sd_ctx_t* ctx) const { + if (ctx != nullptr) { + free_sd_ctx(ctx); + } + } +}; + +struct UpscalerCtxDeleter { + void operator()(upscaler_ctx_t* ctx) const { + if (ctx != nullptr) { + free_upscaler_ctx(ctx); + } + } +}; + +template +using FreeUniquePtr = std::unique_ptr; + +using FilePtr = std::unique_ptr; +using SDCtxPtr = std::unique_ptr; +using UpscalerCtxPtr = std::unique_ptr; + +class SDImageOwner { +private: + static sd_image_t copy_image(const sd_image_t& image) { + if (image.data == nullptr) { + return {image.width, image.height, image.channel, nullptr}; + } + + const size_t byte_count = static_cast(image.width) * image.height * image.channel; + uint8_t* raw_copy = static_cast(malloc(byte_count)); + if (raw_copy == nullptr) { + return {0, 0, 0, nullptr}; + } + + std::memcpy(raw_copy, image.data, byte_count); + return {image.width, image.height, image.channel, raw_copy}; + } + + sd_image_t image_ = {0, 0, 0, nullptr}; + +public: + SDImageOwner() = default; + explicit SDImageOwner(sd_image_t image) + : image_(image) { + } + + SDImageOwner(const SDImageOwner& other) + : image_(copy_image(other.image_)) { + } + + SDImageOwner& operator=(const SDImageOwner& other) { + if (this != &other) { + reset(copy_image(other.image_)); + } + return *this; + } + + SDImageOwner(SDImageOwner&& other) noexcept + : image_(other.release()) { + } + + SDImageOwner& 
operator=(SDImageOwner&& other) noexcept { + if (this != &other) { + reset(); + image_ = other.release(); + } + return *this; + } + + ~SDImageOwner() { + reset(); + } + + sd_image_t* put() { + if (image_.data != nullptr) { + free(image_.data); + image_.data = nullptr; + } + image_.width = 0; + image_.height = 0; + image_.channel = 0; + return &image_; + } + + sd_image_t& get() { + return image_; + } + + const sd_image_t& get() const { + return image_; + } + + sd_image_t release() { + sd_image_t image = image_; + image_ = {0, 0, 0, nullptr}; + return image; + } + + void reset(sd_image_t image = {0, 0, 0, nullptr}) { + if (image_.data != nullptr) { + free(image_.data); + } + image_ = image; + } +}; + +class SDImageVec { +private: + std::vector images_; + +public: + SDImageVec() = default; + + SDImageVec(const SDImageVec&) = delete; + SDImageVec& operator=(const SDImageVec&) = delete; + + SDImageVec(SDImageVec&& other) noexcept + : images_(std::move(other.images_)) { + } + + SDImageVec& operator=(SDImageVec&& other) noexcept { + if (this != &other) { + clear(); + images_ = std::move(other.images_); + } + return *this; + } + + ~SDImageVec() { + clear(); + } + + void push_back(sd_image_t image) { + images_.push_back(image); + } + + void push_back(SDImageOwner&& image) { + images_.push_back(image.release()); + } + + void reserve(size_t count) { + images_.reserve(count); + } + + void adopt(sd_image_t* images, int count) { + clear(); + if (images == nullptr || count <= 0) { + free(images); + return; + } + + images_.reserve(static_cast(count)); + for (int i = 0; i < count; ++i) { + images_.push_back(images[i]); + } + free(images); + } + + size_t size() const { + return images_.size(); + } + + bool empty() const { + return images_.empty(); + } + + int count() const { + return static_cast(images_.size()); + } + + explicit operator bool() const { + return !images_.empty(); + } + + sd_image_t* data() { + return images_.data(); + } + + const sd_image_t* data() const { + return 
images_.data(); + } + + sd_image_t& operator[](size_t index) { + return images_[index]; + } + + const sd_image_t& operator[](size_t index) const { + return images_[index]; + } + + std::vector& raw() { + return images_; + } + + const std::vector& raw() const { + return images_; + } + + void clear() { + for (sd_image_t& image : images_) { + free(image.data); + image.data = nullptr; + } + images_.clear(); + } +}; + +#endif // __EXAMPLE_RESOURCE_OWNERS_H__ diff --git a/examples/server/CMakeLists.txt b/examples/server/CMakeLists.txt index bf2b252bb..b70b525e5 100644 --- a/examples/server/CMakeLists.txt +++ b/examples/server/CMakeLists.txt @@ -50,16 +50,29 @@ if(SD_SERVER_BUILD_FRONTEND AND EXISTS "${FRONTEND_DIR}") set_source_files_properties("${GENERATED_HTML_HEADER}" PROPERTIES GENERATED TRUE) else() - message(WARNING "pnpm not found, frontend build disabled") + if(EXISTS "${GENERATED_HTML_HEADER}") + message(STATUS "pnpm not found; using pre-built frontend header detected at ${GENERATED_HTML_HEADER}") + set(HAVE_FRONTEND_BUILD ON) + add_custom_target(${TARGET}_frontend) + else() + message(WARNING "pnpm not found; frontend build disabled.") + endif() endif() else() message(STATUS "Frontend disabled or directory not found: ${FRONTEND_DIR}") endif() add_executable(${TARGET} + ../common/common.cpp ../common/log.cpp ../common/media_io.cpp main.cpp + runtime.cpp + async_jobs.cpp + routes_index.cpp + routes_openai.cpp + routes_sdapi.cpp + routes_sdcpp.cpp ) if(HAVE_FRONTEND_BUILD) @@ -75,8 +88,13 @@ endif() install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE stable-diffusion ${CMAKE_THREAD_LIBS_INIT}) if(SD_WEBP) + target_compile_definitions(${TARGET} PRIVATE SD_USE_WEBP) target_link_libraries(${TARGET} PRIVATE webp libwebpmux) endif() +if(SD_WEBM) + target_compile_definitions(${TARGET} PRIVATE SD_USE_WEBM) + target_link_libraries(${TARGET} PRIVATE webm) +endif() # due to httplib; it contains a pragma for MSVC, but other things need explicit flags 
if(WIN32 AND NOT MSVC) diff --git a/examples/server/README.md b/examples/server/README.md index 620586d2e..23b79c9d8 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -1,3 +1,33 @@ +# Example + +The following example starts `sd-server` with a standalone diffusion model, VAE, and LLM text encoder: + +``` +.\bin\Release\sd-server.exe --diffusion-model ..\models\diffusion_models\z_image_turbo_bf16.safetensors --vae ..\models\vae\ae.sft --llm ..\models\text_encoders\qwen_3_4b.safetensors --diffusion-fa --offload-to-cpu -v --cfg-scale 1.0 +``` + +What this example does: + +* `--diffusion-model` selects the standalone diffusion model +* `--vae` selects the VAE decoder +* `--llm` selects the text encoder / language model used by this pipeline +* `--diffusion-fa` enables flash attention in the diffusion model +* `--offload-to-cpu` reduces VRAM pressure by keeping weights in RAM when possible +* `-v` enables verbose logging +* `--cfg-scale 1.0` sets the default CFG scale for generation + +After the server starts successfully: + +* the web UI is available at `http://127.0.0.1:1234/` +* the native async API is available under `/sdcpp/v1/...` +* the compatibility APIs are available under `/v1/...` and `/sdapi/v1/...` + +If you want to use a different host or port, pass: + +```bash +--listen-ip --listen-port +``` + # Frontend ## Build with Frontend @@ -8,7 +38,7 @@ The server can optionally build the web frontend and embed it into the binary as Install the following tools: -* **Node.js** โ‰ฅ 22.18 +* **Node.js** โ‰ฅ 20 https://nodejs.org/ * **pnpm** โ‰ฅ 10 @@ -54,7 +84,7 @@ and embed the generated frontend into the server binary. ## Frontend Repository -The web frontend is maintained in a **separate repository**, https://github.com/leejet/stable-ui. +The web frontend is maintained in a **separate repository**, https://github.com/leejet/sdcpp-webui. If you want to modify the UI or frontend logic, please submit pull requests to the **frontend repository**. 
@@ -93,11 +123,11 @@ In this case, the server will load and serve the specified `index.html` file ins usage: ./bin/sd-server [options] Svr Options: - -l, --listen-ip server listen ip (default: 127.0.0.1) + -l, --listen-ip server listen ip (default: 127.0.0.1) --serve-html-path path to HTML file to serve at root (optional) --listen-port server listen port (default: 1234) -v, --verbose print extra info - --color colors the logging tags according to level + --color colors the logging tags according to level -h, --help show this help message and exit Context Options: @@ -106,7 +136,8 @@ Context Options: --clip_g path to the clip-g text encoder --clip_vision path to the clip-vision encoder --t5xxl path to the t5xxl text encoder - --llm path to the llm text encoder. For example: (qwenvl2.5 for qwen-image, mistral-small3.2 for flux2, ...) + --llm path to the llm text encoder. For example: (qwenvl2.5 for qwen-image, + mistral-small3.2 for flux2, ...) --llm_vision path to the llm vit --qwen2vl alias of --llm. Deprecated. --qwen2vl_vision alias of --llm_vision. Deprecated. @@ -118,16 +149,18 @@ Context Options: --control-net path to control net model --embd-dir embeddings directory --lora-model-dir lora model directory + --hires-upscalers-dir highres fix upscaler model directory --tensor-type-rules weight type per tensor pattern (example: "^vae\.=f16,model\.=q8_0") --photo-maker path to PHOTOMAKER model --upscale-model path to esrgan model. - -t, --threads number of threads to use during computation (default: -1). If threads <= 0, then threads will be set to the number of - CPU physical cores + -t, --threads number of threads to use during computation (default: -1). 
If threads <= 0, + then threads will be set to the number of CPU physical cores --chroma-t5-mask-pad t5 mask pad size of chroma - --vae-tile-overlap tile overlap for vae tiling, in fraction of tile size (default: 0.5) - --vae-tiling process vae in tiles to reduce memory usage + --max-vram maximum VRAM budget in GiB for graph-cut segmented execution. 0 disables + graph splitting --force-sdxl-vae-conv-scale force use of conv scale on sdxl vae - --offload-to-cpu place the weights in RAM to save VRAM, and automatically load them into VRAM when needed + --offload-to-cpu place the weights in RAM to save VRAM, and automatically load them into VRAM + when needed --mmap whether to memory-map model --control-net-cpu keep controlnet in cpu (for low vram) --clip-on-cpu keep clip in cpu (for low vram) @@ -142,20 +175,19 @@ Context Options: --chroma-disable-dit-mask disable dit mask for chroma --qwen-image-zero-cond-t enable zero_cond_t for qwen image --chroma-enable-t5-mask enable t5 mask for chroma - --type weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K). If not specified, the default is the - type of the weight file + --type weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, + q4_K). If not specified, the default is the type of the weight file --rng RNG, one of [std_default, cuda, cpu], default: cuda(sd-webui), cpu(comfyui) --sampler-rng sampler RNG, one of [std_default, cuda, cpu]. If not specified, use --rng - --prediction prediction type override, one of [eps, v, edm_v, sd3_flow, flux_flow, flux2_flow] - --lora-apply-mode the way to apply LoRA, one of [auto, immediately, at_runtime], default is auto. 
In auto mode, if the model weights - contain any quantized parameters, the at_runtime mode will be used; otherwise, - immediately will be used.The immediately mode may have precision and - compatibility issues with quantized parameters, but it usually offers faster inference - speed and, in some cases, lower memory usage. The at_runtime mode, on the - other hand, is exactly the opposite. - --vae-tile-size tile size for vae tiling, format [X]x[Y] (default: 32x32) - --vae-relative-tile-size relative tile size for vae tiling, format [X]x[Y], in fraction of image size if < 1, in number of tiles per dim if >=1 - (overrides --vae-tile-size) + --prediction prediction type override, one of [eps, v, edm_v, sd3_flow, flux_flow, + flux2_flow] + --lora-apply-mode the way to apply LoRA, one of [auto, immediately, at_runtime], default is + auto. In auto mode, if the model weights contain any quantized parameters, + the at_runtime mode will be used; otherwise, immediately will be used.The + immediately mode may have precision and compatibility issues with quantized + parameters, but it usually offers faster inference speed and, in some cases, + lower memory usage. The at_runtime mode, on the other hand, is exactly the + opposite. Default Generation Options: -p, --prompt the prompt to render @@ -164,65 +196,97 @@ Default Generation Options: --end-img path to the end image, required by flf2v --mask path to the mask image --control-image path to control image, control net - --control-video path to control video frames, It must be a directory path. The video frames inside should be stored as images in - lexicographical (character) order. For example, if the control video path is - `frames`, the directory contain images such as 00.png, 01.png, ... etc. + --control-video path to control video frames, It must be a directory path. The video frames + inside should be stored as images in lexicographical (character) order. 
For + example, if the control video path is `frames`, the directory contain images + such as 00.png, 01.png, ... etc. --pm-id-images-dir path to PHOTOMAKER input id images dir --pm-id-embed-path path to PHOTOMAKER v2 id embed + --hires-upscaler highres fix upscaler, Lanczos, Nearest, Latent, Latent (nearest), Latent + (nearest-exact), Latent (antialiased), Latent (bicubic), Latent (bicubic + antialiased), or a model name under --hires-upscalers-dir (default: Latent) -H, --height image height, in pixel space (default: 512) -W, --width image width, in pixel space (default: 512) --steps number of sample steps (default: 20) --high-noise-steps (high noise) number of sample steps (default: -1 = auto) - --clip-skip ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1). <= 0 represents unspecified, - will be 1 for SD1.x, 2 for SD2.x + --clip-skip ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer + (default: -1). <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x -b, --batch-count batch count --video-frames video frames (default: 1) --fps fps (default: 24) - --timestep-shift shift timestep for NitroFusion models (default: 0). recommended N for NitroSD-Realism around 250 and 500 for - NitroSD-Vibrant + --timestep-shift shift timestep for NitroFusion models (default: 0). 
recommended N for + NitroSD-Realism around 250 and 500 for NitroSD-Vibrant --upscale-repeats Run the ESRGAN upscaler this many times (default: 1) --upscale-tile-size tile size for ESRGAN upscaling (default: 128) + --hires-width highres fix target width, 0 to use --hires-scale (default: 0) + --hires-height highres fix target height, 0 to use --hires-scale (default: 0) + --hires-steps highres fix second pass sample steps, 0 to reuse --steps (default: 0) + --hires-upscale-tile-size highres fix upscaler tile size, reserved for model-backed upscalers (default: + 128) --cfg-scale unconditional guidance scale: (default: 7.0) - --img-cfg-scale image guidance scale for inpaint or instruct-pix2pix models: (default: same as --cfg-scale) + --img-cfg-scale image guidance scale for inpaint or instruct-pix2pix models: (default: same + as --cfg-scale) --guidance distilled guidance scale for models with guidance input (default: 3.5) - --slg-scale skip layer guidance (SLG) scale, only for DiT models: (default: 0). 0 means disabled, a value of 2.5 is nice for sd3.5 - medium + --slg-scale skip layer guidance (SLG) scale, only for DiT models: (default: 0). 
0 means + disabled, a value of 2.5 is nice for sd3.5 medium --skip-layer-start SLG enabling point (default: 0.01) --skip-layer-end SLG disabling point (default: 0.2) - --eta noise multiplier (default: 0 for ddim_trailing, tcd, res_multistep and res_2s; 1 for euler_a and dpm++2s_a) + --eta noise multiplier (default: 0 for ddim_trailing, tcd, res_multistep and + res_2s; 1 for euler_a, er_sde and dpm++2s_a) --flow-shift shift value for Flow models like SD3.x or WAN (default: auto) --high-noise-cfg-scale (high noise) unconditional guidance scale: (default: 7.0) - --high-noise-img-cfg-scale (high noise) image guidance scale for inpaint or instruct-pix2pix models (default: same as --cfg-scale) - --high-noise-guidance (high noise) distilled guidance scale for models with guidance input (default: 3.5) - --high-noise-slg-scale (high noise) skip layer guidance (SLG) scale, only for DiT models: (default: 0) + --high-noise-img-cfg-scale (high noise) image guidance scale for inpaint or instruct-pix2pix models + (default: same as --cfg-scale) + --high-noise-guidance (high noise) distilled guidance scale for models with guidance input + (default: 3.5) + --high-noise-slg-scale (high noise) skip layer guidance (SLG) scale, only for DiT models: (default: + 0) --high-noise-skip-layer-start (high noise) SLG enabling point (default: 0.01) --high-noise-skip-layer-end (high noise) SLG disabling point (default: 0.2) - --high-noise-eta (high noise) noise multiplier (default: 0 for ddim_trailing, tcd, res_multistep and res_2s; 1 for euler_a and dpm++2s_a) + --high-noise-eta (high noise) noise multiplier (default: 0 for ddim_trailing, tcd, + res_multistep and res_2s; 1 for euler_a, er_sde and dpm++2s_a) --strength strength for noising/unnoising (default: 0.75) - --pm-style-strength - --control-strength strength to apply Control Net (default: 0.9). 1.0 corresponds to full destruction of information in init image - --moe-boundary timestep boundary for Wan2.2 MoE model. (default: 0.875). 
Only enabled if `--high-noise-steps` is set to -1 + --pm-style-strength + --control-strength strength to apply Control Net (default: 0.9). 1.0 corresponds to full + destruction of information in init image + --moe-boundary timestep boundary for Wan2.2 MoE model. (default: 0.875). Only enabled if + `--high-noise-steps` is set to -1 --vace-strength wan vace strength - --increase-ref-index automatically increase the indices of references images based on the order they are listed (starting with 1). + --vae-tile-overlap tile overlap for vae tiling, in fraction of tile size (default: 0.5) + --hires-scale highres fix scale when target size is not set (default: 2.0) + --hires-denoising-strength highres fix second pass denoising strength (default: 0.7) + --increase-ref-index automatically increase the indices of references images based on the order + they are listed (starting with 1). --disable-auto-resize-ref-image disable auto resize of ref images --disable-image-metadata do not embed generation metadata on image files + --vae-tiling process vae in tiles to reduce memory usage + --hires enable highres fix -s, --seed RNG seed (default: 42, use random seed for < 0) - --sampling-method sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, - tcd, res_multistep, res_2s] (default: euler for Flux/SD3/Wan, euler_a - otherwise) - --high-noise-sampling-method (high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, - ddim_trailing, tcd, res_multistep, res_2s] default: euler for Flux/SD3/Wan, - euler_a otherwise - --scheduler denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, - kl_optimal, lcm, bong_tangent], default: discrete - --sigmas custom sigma values for the sampler, comma-separated (e.g., "14.61,7.8,3.5,0.0"). 
+ --sampling-method sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, + dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd, res_multistep, res_2s, + er_sde] (default: euler for Flux/SD3/Wan, euler_a otherwise) + --high-noise-sampling-method (high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, + dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd, res_multistep, + res_2s, er_sde] default: euler for Flux/SD3/Wan, euler_a otherwise + --scheduler denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, + smoothstep, sgm_uniform, simple, kl_optimal, lcm, bong_tangent], default: + discrete + --sigmas custom sigma values for the sampler, comma-separated (e.g., + "14.61,7.8,3.5,0.0"). --skip-layers layers to skip for SLG steps (default: [7,8,9]) --high-noise-skip-layers (high noise) layers to skip for SLG steps (default: [7,8,9]) -r, --ref-image reference image for Flux Kontext models (can be used multiple times) - --cache-mode caching method: 'easycache' (DiT), 'ucache' (UNET), 'dbcache'/'taylorseer'/'cache-dit' (DiT block-level), 'spectrum' (UNET/DiT Chebyshev+Taylor forecasting) + --cache-mode caching method: 'easycache' (DiT), 'ucache' (UNET), + 'dbcache'/'taylorseer'/'cache-dit' (DiT block-level), 'spectrum' (UNET/DiT + Chebyshev+Taylor forecasting) --cache-option named cache params (key=value format, comma-separated). easycache/ucache: - threshold=,start=,end=,decay=,relative=,reset=; dbcache/taylorseer/cache-dit: Fn=,Bn=,threshold=,warmup=. Examples: - "threshold=0.25" or "threshold=1.5,reset=0" - --scm-mask SCM steps mask for cache-dit: comma-separated 0/1 (e.g., "1,1,1,0,0,1,0,0,1,0") - 1=compute, 0=can cache + threshold=,start=,end=,decay=,relative=,reset=; dbcache/taylorseer/cache-dit: + Fn=,Bn=,threshold=,warmup=; spectrum: w=,m=,lam=,window=,flex=,warmup=,stop=. 
+ Examples: "threshold=0.25" or "threshold=1.5,reset=0" + --scm-mask SCM steps mask for cache-dit: comma-separated 0/1 (e.g., + "1,1,1,0,0,1,0,0,1,0") - 1=compute, 0=can cache --scm-policy SCM policy: 'dynamic' (default) or 'static' + --vae-tile-size tile size for vae tiling, format [X]x[Y] (default: 32x32) + --vae-relative-tile-size relative tile size for vae tiling, format [X]x[Y], in fraction of image size + if < 1, in number of tiles per dim if >=1 (overrides --vae-tile-size) ``` diff --git a/examples/server/api.md b/examples/server/api.md new file mode 100644 index 000000000..483daa041 --- /dev/null +++ b/examples/server/api.md @@ -0,0 +1,1288 @@ +# stable-diffusion.cpp Server APIs + +This document describes the server-facing APIs exposed by `examples/server`. + +The server currently exposes three API families: + +- `OpenAI API` under `/v1/...` +- `Stable Diffusion WebUI API` under `/sdapi/v1/...` +- `sdcpp API` under `/sdcpp/v1/...` + +The `sdcpp API` is the native API surface. +Its request schema is the same schema used by `sd_cpp_extra_args`. + +Global LoRA rule: + +- Server APIs do not parse LoRA tags embedded inside `prompt`. +- `` prompt syntax is intentionally unsupported in `OpenAI API`, `sdapi`, and `sdcpp API`. +- LoRA must be passed through structured API fields when the API supports it. + +## Overview + +### OpenAI API + +Compatibility API shaped like OpenAI image endpoints. + +Current generation-related endpoints include: + +- `POST /v1/images/generations` +- `POST /v1/images/edits` +- `GET /v1/models` + +### Stable Diffusion WebUI API + +Compatibility API shaped like the AUTOMATIC1111 / WebUI endpoints. 
+ +Current generation-related endpoints include: + +- `POST /sdapi/v1/txt2img` +- `POST /sdapi/v1/img2img` +- `GET /sdapi/v1/loras` +- `GET /sdapi/v1/upscalers` +- `GET /sdapi/v1/latent-upscale-modes` +- `GET /sdapi/v1/samplers` +- `GET /sdapi/v1/schedulers` +- `GET /sdapi/v1/sd-models` +- `GET /sdapi/v1/options` + +### sdcpp API + +Native async API for `stable-diffusion.cpp`. + +Current endpoints include: + +- `GET /sdcpp/v1/capabilities` +- `POST /sdcpp/v1/img_gen` +- `GET /sdcpp/v1/jobs/{id}` +- `POST /sdcpp/v1/jobs/{id}/cancel` +- `POST /sdcpp/v1/vid_gen` + +## `sd_cpp_extra_args` + +`sd_cpp_extra_args` is an extension mechanism for the compatibility APIs. + +Rules: + +- Its JSON schema is the same schema used by the native `sdcpp API`. +- `OpenAI API` and `sdapi` can embed it inside `prompt`. +- `sdcpp API` does not need it, because the request body already uses the native schema directly. + +Embedding format: + +```text +normal prompt text {"sample_params":{"sample_steps":28}} +``` + +Behavior: + +- The server extracts the JSON block. +- The JSON block is parsed using the same field rules as the `sdcpp API`. +- The block is removed from the final prompt before generation. + +Supported use: + +- extend `OpenAI API` requests with native `stable-diffusion.cpp` controls +- extend `sdapi` requests with native `stable-diffusion.cpp` controls + +Unsupported use: + +- do not use `sd_cpp_extra_args` with `/sdcpp/v1/*` + +## OpenAI API + +### Purpose + +This family exists for client compatibility. + +Use it when you want OpenAI-style request and response shapes. + +### Native Extension + +`OpenAI API` supports `sd_cpp_extra_args` embedded inside `prompt`. + +The embedded JSON follows the `sdcpp API` request schema. 
+ +### Supported Fields + +#### `POST /v1/images/generations` + +Currently supported top-level request fields: + +| Field | Type | Notes | +| --- | --- | --- | +| `prompt` | `string` | Required | +| `n` | `integer` | Number of images | +| `size` | `string` | Format `WIDTHxHEIGHT` | +| `output_format` | `string` | `png`, `jpeg`, or `webp` | +| `output_compression` | `integer` | Range is clamped to `0..100` | + +Native extension fields: + +- any `sdcpp API` fields embedded through `sd_cpp_extra_args` inside `prompt` + +Response fields: + +| Field | Type | Notes | +| --- | --- | --- | +| `created` | `integer` | Unix timestamp | +| `output_format` | `string` | Final encoded image format | +| `data` | `array` | Generated image list | +| `data[].b64_json` | `string` | Base64-encoded image bytes | + +#### `POST /v1/images/edits` + +Currently supported multipart form fields: + +| Field | Type | Notes | +| --- | --- | --- | +| `prompt` | `string` | Required | +| `image[]` | `file[]` | Preferred image upload field | +| `image` | `file` | Legacy single-image upload field | +| `mask` | `file` | Optional mask image | +| `n` | `integer` | Number of images | +| `size` | `string` | Format `WIDTHxHEIGHT` | +| `output_format` | `string` | `png` or `jpeg` | +| `output_compression` | `integer` | Range is clamped to `0..100` | + +Native extension fields: + +- any `sdcpp API` fields embedded through `sd_cpp_extra_args` inside `prompt` + +Response fields: + +| Field | Type | Notes | +| --- | --- | --- | +| `created` | `integer` | Unix timestamp | +| `output_format` | `string` | Final encoded image format | +| `data` | `array` | Generated image list | +| `data[].b64_json` | `string` | Base64-encoded image bytes | + +#### `GET /v1/models` + +Response fields: + +| Field | Type | Notes | +| --- | --- | --- | +| `data` | `array` | Available local models | +| `data[].id` | `string` | Currently fixed to `sd-cpp-local` | +| `data[].object` | `string` | Currently fixed to `model` | +| 
`data[].owned_by` | `string` | Currently fixed to `local` | + +### Output Options + +`OpenAI API` supports response serialization controls such as: + +- `output_format` +- `output_compression` + +### Notes + +- `OpenAI API` is synchronous from the HTTP client's perspective. +- Native async job polling is not exposed through this family. +- Prompt-embedded `` tags are intentionally unsupported. + +## Stable Diffusion WebUI API + +### Purpose + +This family exists for client compatibility with WebUI-style tools. + +Use it when you want `txt2img` / `img2img`-style endpoints and response shapes. + +### Native Extension + +`sdapi` supports `sd_cpp_extra_args` embedded inside `prompt`. + +The embedded JSON follows the `sdcpp API` request schema. + +This allows `sdapi` clients to use native `stable-diffusion.cpp` controls without changing the outer request format. + +### Supported Fields + +#### `POST /sdapi/v1/txt2img` + +Currently supported request fields: + +| Field | Type | Notes | +| --- | --- | --- | +| `prompt` | `string` | Required | +| `negative_prompt` | `string` | Optional | +| `width` | `integer` | Positive image width | +| `height` | `integer` | Positive image height | +| `steps` | `integer` | Sampling steps | +| `cfg_scale` | `number` | Text CFG scale | +| `seed` | `integer` | `-1` means random | +| `batch_size` | `integer` | Number of images | +| `clip_skip` | `integer` | Optional | +| `sampler_name` | `string` | WebUI sampler name | +| `scheduler` | `string` | Scheduler name | +| `lora` | `array` | Structured LoRA list | +| `extra_images` | `array` | Base64 or data URL images | +| `enable_hr` | `boolean` | Enable highres fix for `txt2img` | +| `hr_upscaler` | `string` | `Lanczos`, `Nearest`, a latent mode such as `Latent (nearest-exact)`, or an upscaler model name from `/sdapi/v1/upscalers` | +| `hr_scale` | `number` | Highres scale when resize target is not set | +| `hr_resize_x` | `integer` | Highres target width, `0` to use scale | +| `hr_resize_y` | 
`integer` | Highres target height, `0` to use scale | +| `hr_steps` | `integer` | Highres second-pass sample steps, `0` to reuse `steps` | +| `denoising_strength` | `number` | Highres denoising strength for `txt2img` | + +Native extension fields: + +- any `sdcpp API` fields embedded through `sd_cpp_extra_args` inside `prompt` + +Response fields: + +| Field | Type | Notes | +| --- | --- | --- | +| `images` | `array` | Base64-encoded PNG images | +| `parameters` | `object` | Echo of the parsed outer request body | +| `info` | `string` | Currently empty string | + +#### `POST /sdapi/v1/img2img` + +Currently supported request fields: + +| Field | Type | Notes | +| --- | --- | --- | +| all currently supported `txt2img` fields | same as above | Reused | +| `init_images` | `array` | Base64 or data URL images | +| `mask` | `string` | Base64 or data URL image | +| `inpainting_mask_invert` | `integer` or `boolean` | Treated as invert flag | +| `denoising_strength` | `number` | Clamped to `0.0..1.0` | + +Highres fix fields are currently handled for `txt2img`; `img2img` uses `denoising_strength` as image-to-image strength. 
+ +Native extension fields: + +- any `sdcpp API` fields embedded through `sd_cpp_extra_args` inside `prompt` + +Response fields: + +| Field | Type | Notes | +| --- | --- | --- | +| `images` | `array` | Base64-encoded PNG images | +| `parameters` | `object` | Echo of the parsed outer request body | +| `info` | `string` | Currently empty string | + +#### Discovery / Compatibility Endpoints + +Currently exposed: + +- `GET /sdapi/v1/loras` +- `GET /sdapi/v1/upscalers` +- `GET /sdapi/v1/latent-upscale-modes` +- `GET /sdapi/v1/samplers` +- `GET /sdapi/v1/schedulers` +- `GET /sdapi/v1/sd-models` +- `GET /sdapi/v1/options` + +Response fields: + +`GET /sdapi/v1/loras` + +| Field | Type | Notes | +| --- | --- | --- | +| `[].name` | `string` | Display name derived from file stem | +| `[].path` | `string` | Relative path under the configured LoRA directory | + +`GET /sdapi/v1/upscalers` + +| Field | Type | Notes | +| --- | --- | --- | +| `[].name` | `string` | Built-in name or model stem | +| `[].model_name` | `string \| null` | Model family label for model-backed upscalers | +| `[].model_path` | `string \| null` | Absolute model path for model-backed upscalers | +| `[].model_url` | `string \| null` | Currently always null | +| `[].scale` | `integer` | Currently `4` | + +Built-in entries include `None`, `Lanczos`, and `Nearest`. Model-backed entries are scanned from the top level of `--hires-upscalers-dir`; subdirectories are not scanned. + +`GET /sdapi/v1/latent-upscale-modes` + +| Field | Type | Notes | +| --- | --- | --- | +| `[].name` | `string` | WebUI-compatible latent upscale mode name | + +Built-in latent modes include `Latent`, `Latent (nearest)`, `Latent (nearest-exact)`, `Latent (antialiased)`, `Latent (bicubic)`, and `Latent (bicubic antialiased)`. 
+
+`GET /sdapi/v1/samplers`
+
+| Field | Type | Notes |
+| --- | --- | --- |
+| `[].name` | `string` | Sampler name |
+| `[].aliases` | `array` | Currently contains the same single sampler name |
+| `[].options` | `object` | Currently empty object |
+
+`GET /sdapi/v1/schedulers`
+
+| Field | Type | Notes |
+| --- | --- | --- |
+| `[].name` | `string` | Scheduler name |
+| `[].label` | `string` | Same value as `name` |
+
+`GET /sdapi/v1/sd-models`
+
+| Field | Type | Notes |
+| --- | --- | --- |
+| `[].title` | `string` | Model stem |
+| `[].model_name` | `string` | Same value as `title` |
+| `[].filename` | `string` | Model filename |
+| `[].hash` | `string` | Placeholder compatibility value |
+| `[].sha256` | `string` | Placeholder compatibility value |
+| `[].config` | `null` | Currently always null |
+
+`GET /sdapi/v1/options`
+
+| Field | Type | Notes |
+| --- | --- | --- |
+| `samples_format` | `string` | Currently fixed to `png` |
+| `sd_model_checkpoint` | `string` | Model stem |
+
+### Notes
+
+- `sdapi` is synchronous from the HTTP client's perspective.
+- Prompt-embedded `<lora:...>` tags are intentionally unsupported.
+
+## sdcpp API
+
+### Purpose
+
+This is the native `stable-diffusion.cpp` API.
+
+Use it when you want:
+
+- async job submission
+- explicit native parameter control
+- frontend-oriented capability discovery
+
+### Job Model
+
+All async generation requests create a job. 
+ +Job states: + +- `queued` +- `generating` +- `completed` +- `failed` +- `cancelled` + +Common job shape: + +```json +{ + "id": "job_01HTXYZABC", + "kind": "img_gen", + "status": "queued", + "created": 1775401200, + "started": null, + "completed": null, + "queue_position": 2, + "result": null, + "error": null +} +``` + +Field types: + +| Field | Type | +| --- | --- | +| `id` | `string` | +| `kind` | `string` | +| `status` | `string` | +| `created` | `integer` | +| `started` | `integer \| null` | +| `completed` | `integer \| null` | +| `queue_position` | `integer` | +| `result` | `object \| null` | +| `error` | `object \| null` | + +### Endpoints + +#### `GET /sdcpp/v1/capabilities` + +Returns frontend-friendly capability metadata. + +The mode-aware fields are the primary interface. The top-level compatibility fields are deprecated mirrors kept for older clients. + +Top-level fields: + +| Field | Type | Notes | +| --- | --- | --- | +| `model` | `object` | Loaded model metadata | +| `current_mode` | `string` | The native generation mode mirrored by top-level compatibility fields | +| `supported_modes` | `array` | Supported native modes such as `img_gen` or `vid_gen` | +| `defaults` | `object` | Deprecated compatibility mirror of `defaults_by_mode[current_mode]` | +| `output_formats` | `array` | Deprecated compatibility mirror of `output_formats_by_mode[current_mode]` | +| `features` | `object` | Deprecated compatibility mirror of `features_by_mode[current_mode]` | +| `defaults_by_mode` | `object` | Explicit defaults for each supported mode | +| `output_formats_by_mode` | `object` | Explicit output formats for each supported mode | +| `features_by_mode` | `object` | Explicit feature flags for each supported mode | +| `samplers` | `array` | Available sampling methods | +| `schedulers` | `array` | Available schedulers | +| `loras` | `array` | Available LoRA entries | +| `upscalers` | `array` | Available model-backed highres upscalers | +| `limits` | `object` | Shared 
queue and size limits | + +`model` + +| Field | Type | +| --- | --- | +| `model.name` | `string` | +| `model.stem` | `string` | +| `model.path` | `string` | + +Compatibility rules: + +- `defaults`, `output_formats`, and `features` are deprecated compatibility mirrors +- those three top-level fields always mirror `current_mode` +- `supported_modes`, `defaults_by_mode`, `output_formats_by_mode`, and `features_by_mode` are the mode-aware fields + +Mode-aware objects: + +| Field | Type | +| --- | --- | +| `defaults_by_mode.img_gen` | `object` | +| `defaults_by_mode.vid_gen` | `object` | +| `output_formats_by_mode.img_gen` | `array` | +| `output_formats_by_mode.vid_gen` | `array` | +| `features_by_mode.img_gen` | `object` | +| `features_by_mode.vid_gen` | `object` | + +Shared nested fields: + +`loras` + +| Field | Type | +| --- | --- | +| `loras[].name` | `string` | +| `loras[].path` | `string` | + +`upscalers` + +| Field | Type | Notes | +| --- | --- | --- | +| `upscalers[].name` | `string` | Built-in name or model stem; use this value in `hires.upscaler` | + +Built-in entries include `None`, `Lanczos`, `Nearest`, `Latent`, `Latent (nearest)`, `Latent (nearest-exact)`, `Latent (antialiased)`, `Latent (bicubic)`, and `Latent (bicubic antialiased)`. Model-backed entries are scanned from the top level of `--hires-upscalers-dir`; subdirectories are not scanned. 
+ +`limits` + +| Field | Type | +| --- | --- | +| `limits.min_width` | `integer` | +| `limits.max_width` | `integer` | +| `limits.min_height` | `integer` | +| `limits.max_height` | `integer` | +| `limits.max_batch_count` | `integer` | +| `limits.max_queue_size` | `integer` | + +Shared default fields used by both `img_gen` and `vid_gen`: + +| Field | Type | +| --- | --- | +| `prompt` | `string` | +| `negative_prompt` | `string` | +| `clip_skip` | `integer` | +| `width` | `integer` | +| `height` | `integer` | +| `strength` | `number` | +| `seed` | `integer` | +| `sample_params` | `object` | +| `sample_params.scheduler` | `string` | +| `sample_params.sample_method` | `string` | +| `sample_params.sample_steps` | `integer` | +| `sample_params.eta` | `number \| null` | +| `sample_params.shifted_timestep` | `integer` | +| `sample_params.flow_shift` | `number \| null` | +| `sample_params.guidance.txt_cfg` | `number` | +| `sample_params.guidance.img_cfg` | `number \| null` | +| `sample_params.guidance.distilled_guidance` | `number` | +| `sample_params.guidance.slg.layers` | `array` | +| `sample_params.guidance.slg.layer_start` | `number` | +| `sample_params.guidance.slg.layer_end` | `number` | +| `sample_params.guidance.slg.scale` | `number` | +| `vae_tiling_params` | `object` | +| `vae_tiling_params.enabled` | `boolean` | +| `vae_tiling_params.tile_size_x` | `integer` | +| `vae_tiling_params.tile_size_y` | `integer` | +| `vae_tiling_params.target_overlap` | `number` | +| `vae_tiling_params.rel_size_x` | `number` | +| `vae_tiling_params.rel_size_y` | `number` | +| `cache_mode` | `string` | +| `cache_option` | `string` | +| `scm_mask` | `string` | +| `scm_policy_dynamic` | `boolean` | +| `output_format` | `string` | +| `output_compression` | `integer` | + +`img_gen`-specific default fields: + +| Field | Type | +| --- | --- | +| `batch_count` | `integer` | +| `auto_resize_ref_image` | `boolean` | +| `increase_ref_index` | `boolean` | +| `control_strength` | `number` | +| 
`hires` | `object` | +| `hires.enabled` | `boolean` | +| `hires.upscaler` | `string` | +| `hires.scale` | `number` | +| `hires.target_width` | `integer` | +| `hires.target_height` | `integer` | +| `hires.steps` | `integer` | +| `hires.denoising_strength` | `number` | +| `hires.upscale_tile_size` | `integer` | + +`vid_gen`-specific default fields: + +| Field | Type | +| --- | --- | +| `video_frames` | `integer` | +| `fps` | `integer` | +| `moe_boundary` | `number` | +| `vace_strength` | `number` | +| `high_noise_sample_params` | `object` | +| `high_noise_sample_params.scheduler` | `string` | +| `high_noise_sample_params.sample_method` | `string` | +| `high_noise_sample_params.sample_steps` | `integer` | +| `high_noise_sample_params.eta` | `number \| null` | +| `high_noise_sample_params.shifted_timestep` | `integer` | +| `high_noise_sample_params.flow_shift` | `number \| null` | +| `high_noise_sample_params.guidance.txt_cfg` | `number` | +| `high_noise_sample_params.guidance.img_cfg` | `number \| null` | +| `high_noise_sample_params.guidance.distilled_guidance` | `number` | +| `high_noise_sample_params.guidance.slg.layers` | `array` | +| `high_noise_sample_params.guidance.slg.layer_start` | `number` | +| `high_noise_sample_params.guidance.slg.layer_end` | `number` | +| `high_noise_sample_params.guidance.slg.scale` | `number` | + +Fields returned in `features_by_mode.img_gen`: + +- `init_image` +- `mask_image` +- `control_image` +- `ref_images` +- `lora` +- `vae_tiling` +- `hires` +- `cache` +- `cancel_queued` +- `cancel_generating` + +Fields returned in `features_by_mode.vid_gen`: + +- `init_image` +- `end_image` +- `control_frames` +- `high_noise_sample_params` +- `lora` +- `vae_tiling` +- `cache` +- `cancel_queued` +- `cancel_generating` + +#### `POST /sdcpp/v1/img_gen` + +Submits an async image generation job. + +Successful submission returns `202 Accepted`. 
+ +Example response: + +```json +{ + "id": "job_01HTXYZABC", + "kind": "img_gen", + "status": "queued", + "created": 1775401200, + "poll_url": "/sdcpp/v1/jobs/job_01HTXYZABC" +} +``` + +Response fields: + +| Field | Type | +| --- | --- | +| `id` | `string` | +| `kind` | `string` | +| `status` | `string` | +| `created` | `integer` | +| `poll_url` | `string` | + +#### `GET /sdcpp/v1/jobs/{id}` + +Returns current job status. + +Typical status codes: + +- `200 OK` +- `404 Not Found` +- `410 Gone` + +#### `POST /sdcpp/v1/jobs/{id}/cancel` + +Attempts to cancel an accepted job. + +Typical status codes: + +- `200 OK` +- `404 Not Found` +- `409 Conflict` +- `410 Gone` + +### Request Body + +Example: + +```json +{ + "prompt": "a cat sitting on a chair", + "negative_prompt": "", + "clip_skip": -1, + "width": 1024, + "height": 1024, + "strength": 0.75, + "seed": -1, + "batch_count": 1, + "auto_resize_ref_image": true, + "increase_ref_index": false, + "control_strength": 0.9, + "embed_image_metadata": true, + + "init_image": null, + "ref_images": [], + "mask_image": null, + "control_image": null, + + "sample_params": { + "scheduler": "discrete", + "sample_method": "euler_a", + "sample_steps": 28, + "eta": 1.0, + "shifted_timestep": 0, + "custom_sigmas": [], + "flow_shift": 0.0, + "guidance": { + "txt_cfg": 7.0, + "img_cfg": 7.0, + "distilled_guidance": 3.5, + "slg": { + "layers": [7, 8, 9], + "layer_start": 0.01, + "layer_end": 0.2, + "scale": 0.0 + } + } + }, + + "lora": [], + "hires": { + "enabled": false, + "upscaler": "Latent", + "scale": 2.0, + "target_width": 0, + "target_height": 0, + "steps": 0, + "denoising_strength": 0.7, + "upscale_tile_size": 128 + }, + + "vae_tiling_params": { + "enabled": false, + "tile_size_x": 0, + "tile_size_y": 0, + "target_overlap": 0.5, + "rel_size_x": 0.0, + "rel_size_y": 0.0 + }, + + "cache_mode": "disabled", + "cache_option": "", + "scm_mask": "", + "scm_policy_dynamic": true, + + "output_format": "png", + "output_compression": 100 +} 
+```
+
+### LoRA Rules
+
+- The server only accepts explicit LoRA entries from the `lora` field.
+- Prompt-embedded `<lora:...>` tags are intentionally unsupported.
+- Clients should resolve LoRA usage through the structured `lora` array.
+
+### Image Encoding Rules
+
+Any image field accepts:
+
+- a raw base64 string, or
+- a data URL such as `data:image/png;base64,...`
+
+Channel expectations:
+
+- `init_image`: 3 channels
+- `ref_images[]`: 3 channels
+- `control_image`: 3 channels
+- `mask_image`: 1 channel
+
+If omitted or null:
+
+- single-image fields map to an empty `sd_image_t`
+- array fields map to an empty C-style array, represented as `pointer = nullptr` and `count = 0`
+
+### Field Mapping Summary
+
+Top-level scalar fields:
+
+| Field | Type |
+| --- | --- |
+| `prompt` | `string` |
+| `negative_prompt` | `string` |
+| `clip_skip` | `integer` |
+| `width` | `integer` |
+| `height` | `integer` |
+| `strength` | `number` |
+| `seed` | `integer` |
+| `batch_count` | `integer` |
+| `auto_resize_ref_image` | `boolean` |
+| `increase_ref_index` | `boolean` |
+| `control_strength` | `number` |
+| `embed_image_metadata` | `boolean` |
+
+Image fields:
+
+| Field | Type |
+| --- | --- |
+| `init_image` | `string \| null` |
+| `ref_images` | `array` |
+| `mask_image` | `string \| null` |
+| `control_image` | `string \| null` |
+
+LoRA fields:
+
+| Field | Type |
+| --- | --- |
+| `lora[].path` | `string` |
+| `lora[].multiplier` | `number` |
+| `lora[].is_high_noise` | `boolean` |
+
+Sampling fields:
+
+| Field | Type |
+| --- | --- |
+| `sample_params.scheduler` | `string` |
+| `sample_params.sample_method` | `string` |
+| `sample_params.sample_steps` | `integer` |
+| `sample_params.eta` | `number` |
+| `sample_params.shifted_timestep` | `integer` |
+| `sample_params.custom_sigmas` | `array` |
+| `sample_params.flow_shift` | `number` |
+| `sample_params.guidance.txt_cfg` | `number` |
+| `sample_params.guidance.img_cfg` | `number` |
+| 
`sample_params.guidance.distilled_guidance` | `number` | +| `sample_params.guidance.slg.layers` | `array` | +| `sample_params.guidance.slg.layer_start` | `number` | +| `sample_params.guidance.slg.layer_end` | `number` | +| `sample_params.guidance.slg.scale` | `number` | + +Other native fields: + +| Field | Type | +| --- | --- | +| `hires` | `object` | +| `hires.enabled` | `boolean` | +| `hires.upscaler` | `string` | +| `hires.scale` | `number` | +| `hires.target_width` | `integer` | +| `hires.target_height` | `integer` | +| `hires.steps` | `integer` | +| `hires.denoising_strength` | `number` | +| `hires.upscale_tile_size` | `integer` | +| `vae_tiling_params` | `object` | +| `cache_mode` | `string` | +| `cache_option` | `string` | +| `scm_mask` | `string` | +| `scm_policy_dynamic` | `boolean` | + +For `hires.upscaler`, use `Lanczos`, `Nearest`, `Latent`, `Latent (nearest)`, `Latent (nearest-exact)`, `Latent (antialiased)`, `Latent (bicubic)`, `Latent (bicubic antialiased)`, or an `upscalers[].name` value from `GET /sdcpp/v1/capabilities`. Model-backed upscalers are resolved as `--hires-upscalers-dir / (name + ext)` and must live directly in that directory. + +HTTP-only output fields: + +| Field | Type | +| --- | --- | +| `output_format` | `string` | +| `output_compression` | `integer` | + +### Optional Field Handling + +Optional sampling fields may be omitted. + +When omitted, backend defaults apply to these fields: + +- `sample_params.scheduler` +- `sample_params.sample_method` +- `sample_params.eta` +- `sample_params.flow_shift` +- `sample_params.guidance.img_cfg` + +### Completion Result + +Example completed job: + +```json +{ + "id": "job_01HTXYZABC", + "kind": "img_gen", + "status": "completed", + "created": 1775401200, + "started": 1775401203, + "completed": 1775401215, + "queue_position": 0, + "result": { + "output_format": "png", + "images": [ + { + "index": 0, + "b64_json": "iVBORw0KGgoAAA..." 
+ } + ] + }, + "error": null +} +``` + +### Failure Result + +Example failed job: + +```json +{ + "id": "job_01HTXYZABC", + "kind": "img_gen", + "status": "failed", + "created": 1775401200, + "started": 1775401203, + "completed": 1775401204, + "queue_position": 0, + "result": null, + "error": { + "code": "generation_failed", + "message": "generate_image returned empty results" + } +} +``` + +### Cancelled Result + +Example cancelled job: + +```json +{ + "id": "job_01HTXYZABC", + "kind": "img_gen", + "status": "cancelled", + "created": 1775401200, + "started": null, + "completed": 1775401202, + "queue_position": 0, + "result": null, + "error": { + "code": "cancelled", + "message": "job cancelled by client" + } +} +``` + +### Submission Errors + +`POST /sdcpp/v1/img_gen` may return: + +- `202 Accepted` when the job is created +- `400 Bad Request` for an empty body, unsupported model mode, invalid JSON, or invalid generation parameters +- `429 Too Many Requests` when the job queue is full +- `500 Internal Server Error` for unexpected server exceptions during submission + +### `vid_gen` + +The following section documents the native async contract for video generation. + +#### `POST /sdcpp/v1/vid_gen` + +Submits an async video generation job. + +Successful submission returns `202 Accepted`. 
+ +Example response: + +```json +{ + "id": "job_01HTXYZVID", + "kind": "vid_gen", + "status": "queued", + "created": 1775401200, + "poll_url": "/sdcpp/v1/jobs/job_01HTXYZVID" +} +``` + +Response fields: + +| Field | Type | +| --- | --- | +| `id` | `string` | +| `kind` | `string` | +| `status` | `string` | +| `created` | `integer` | +| `poll_url` | `string` | + +### Request Body + +Compared with `img_gen`, the `vid_gen` request body: + +- `vid_gen` is a single video sequence job, so `batch_count` is not part of the request schema +- `ref_images`, `mask_image`, `control_image`, `control_strength`, and `embed_image_metadata` are not part of the request schema +- `vid_gen` adds `end_image`, `control_frames`, `high_noise_sample_params`, `video_frames`, `fps`, `moe_boundary`, and `vace_strength` + +Example: + +```json +{ + "prompt": "a cat walking through a rainy alley", + "negative_prompt": "", + "clip_skip": -1, + "width": 832, + "height": 480, + "strength": 0.75, + "seed": -1, + "video_frames": 33, + "fps": 16, + "moe_boundary": 0.875, + "vace_strength": 1.0, + + "init_image": null, + "end_image": null, + "control_frames": [], + + "sample_params": { + "scheduler": "discrete", + "sample_method": "euler", + "sample_steps": 28, + "eta": 1.0, + "shifted_timestep": 0, + "custom_sigmas": [], + "flow_shift": 0.0, + "guidance": { + "txt_cfg": 7.0, + "img_cfg": 7.0, + "distilled_guidance": 3.5, + "slg": { + "layers": [7, 8, 9], + "layer_start": 0.01, + "layer_end": 0.2, + "scale": 0.0 + } + } + }, + + "high_noise_sample_params": { + "scheduler": "discrete", + "sample_method": "euler", + "sample_steps": -1, + "eta": 1.0, + "shifted_timestep": 0, + "flow_shift": 0.0, + "guidance": { + "txt_cfg": 7.0, + "img_cfg": 7.0, + "distilled_guidance": 3.5, + "slg": { + "layers": [7, 8, 9], + "layer_start": 0.01, + "layer_end": 0.2, + "scale": 0.0 + } + } + }, + + "lora": [], + + "vae_tiling_params": { + "enabled": false, + "tile_size_x": 0, + "tile_size_y": 0, + "target_overlap": 0.5, + 
"rel_size_x": 0.0,
+    "rel_size_y": 0.0
+  },
+
+  "cache_mode": "disabled",
+  "cache_option": "",
+  "scm_mask": "",
+  "scm_policy_dynamic": true,
+
+  "output_format": "webm",
+  "output_compression": 100
+}
+```
+
+### LoRA Rules
+
+- The server only accepts explicit LoRA entries from the `lora` field.
+- Prompt-embedded `<lora:...>` tags are intentionally unsupported.
+- `lora[].is_high_noise` controls whether a LoRA applies only to the high-noise stage.
+
+### Image and Frame Encoding Rules
+
+Any image field accepts:
+
+- a raw base64 string, or
+- a data URL such as `data:image/png;base64,...`
+
+Channel expectations:
+
+- `init_image`: 3 channels
+- `end_image`: 3 channels
+- `control_frames[]`: 3 channels
+
+Frame ordering rules:
+
+- `control_frames[]` order is the conditioning frame order
+- `control_frames[]` is preserved in request order
+
+If omitted or null:
+
+- single-image fields map to an empty `sd_image_t`
+- array fields map to an empty C-style array, represented as `pointer = nullptr` and `count = 0`
+
+### Field Mapping Summary
+
+Top-level scalar fields:
+
+| Field | Type |
+| --- | --- |
+| `prompt` | `string` |
+| `negative_prompt` | `string` |
+| `clip_skip` | `integer` |
+| `width` | `integer` |
+| `height` | `integer` |
+| `strength` | `number` |
+| `seed` | `integer` |
+| `video_frames` | `integer` |
+| `fps` | `integer` |
+| `moe_boundary` | `number` |
+| `vace_strength` | `number` |
+
+Image and frame fields:
+
+| Field | Type |
+| --- | --- |
+| `init_image` | `string \| null` |
+| `end_image` | `string \| null` |
+| `control_frames` | `array` |
+
+LoRA fields:
+
+| Field | Type |
+| --- | --- |
+| `lora[].path` | `string` |
+| `lora[].multiplier` | `number` |
+| `lora[].is_high_noise` | `boolean` |
+
+Sampling fields:
+
+| Field | Type |
+| --- | --- |
+| `sample_params.scheduler` | `string` |
+| `sample_params.sample_method` | `string` |
+| `sample_params.sample_steps` | `integer` |
+| `sample_params.eta` | `number` |
+| 
`sample_params.shifted_timestep` | `integer` | +| `sample_params.custom_sigmas` | `array` | +| `sample_params.flow_shift` | `number` | +| `sample_params.guidance.txt_cfg` | `number` | +| `sample_params.guidance.img_cfg` | `number` | +| `sample_params.guidance.distilled_guidance` | `number` | +| `sample_params.guidance.slg.layers` | `array` | +| `sample_params.guidance.slg.layer_start` | `number` | +| `sample_params.guidance.slg.layer_end` | `number` | +| `sample_params.guidance.slg.scale` | `number` | + +High-noise sampling fields: + +| Field | Type | +| --- | --- | +| `high_noise_sample_params.scheduler` | `string` | +| `high_noise_sample_params.sample_method` | `string` | +| `high_noise_sample_params.sample_steps` | `integer` | +| `high_noise_sample_params.eta` | `number` | +| `high_noise_sample_params.shifted_timestep` | `integer` | +| `high_noise_sample_params.flow_shift` | `number` | +| `high_noise_sample_params.guidance.txt_cfg` | `number` | +| `high_noise_sample_params.guidance.img_cfg` | `number` | +| `high_noise_sample_params.guidance.distilled_guidance` | `number` | +| `high_noise_sample_params.guidance.slg.layers` | `array` | +| `high_noise_sample_params.guidance.slg.layer_start` | `number` | +| `high_noise_sample_params.guidance.slg.layer_end` | `number` | +| `high_noise_sample_params.guidance.slg.scale` | `number` | + +Other native fields: + +| Field | Type | +| --- | --- | +| `vae_tiling_params` | `object` | +| `cache_mode` | `string` | +| `cache_option` | `string` | +| `scm_mask` | `string` | +| `scm_policy_dynamic` | `boolean` | + +HTTP-only output fields: + +| Field | Type | +| --- | --- | +| `output_format` | `string` | +| `output_compression` | `integer` | + +For `vid_gen`, `output_format` and `output_compression` control container encoding. +`fps` is request metadata for the generated sequence and is echoed in the completed job result. 
+ +Allowed `output_format` values: + +- `webm` +- `webp` +- `avi` + +Output format behavior: + +- `output_format` defaults to `webm` +- `webp` means animated WebP +- `avi` means MJPG AVI +- `webm` requires the server to be built with WebM support; otherwise the request returns `400` + +### Result Payload + +Completed jobs return one encoded container payload, not a list of per-frame images. + +Result fields: + +- `result.b64_json` contains the whole encoded container file as base64 +- `result.mime_type` identifies the media type +- `result.output_format` echoes the selected container format +- `result.fps` echoes the effective playback FPS +- `result.frame_count` reports the actual decoded frame count used to build the container + +Expected MIME types: + +| `output_format` | `mime_type` | +| --- | --- | +| `webm` | `video/webm` | +| `webp` | `image/webp` | +| `avi` | `video/x-msvideo` | + +### Optional Field Handling + +Optional sampling fields may be omitted. + +When omitted, backend defaults apply to these fields: + +- `sample_params.scheduler` +- `sample_params.sample_method` +- `sample_params.eta` +- `sample_params.flow_shift` +- `sample_params.guidance.img_cfg` +- `high_noise_sample_params.scheduler` +- `high_noise_sample_params.sample_method` +- `high_noise_sample_params.eta` +- `high_noise_sample_params.flow_shift` +- `high_noise_sample_params.guidance.img_cfg` + +`high_noise_sample_params` may also be omitted entirely. + +### Frame Count Semantics + +`video_frames` is the requested target length, but the current core video path internally normalizes the effective frame count to the largest `4n + 1` value that does not exceed the requested count. + +Examples: + +- `video_frames = 33` stays `33` +- `video_frames = 34` becomes `33` +- `video_frames = 32` becomes `29` + +The completed job payload includes the actual decoded `frame_count`. 
+ +### Completion Result + +Example completed job: + +```json +{ + "id": "job_01HTXYZVID", + "kind": "vid_gen", + "status": "completed", + "created": 1775401200, + "started": 1775401203, + "completed": 1775401215, + "queue_position": 0, + "result": { + "output_format": "webm", + "mime_type": "video/webm", + "fps": 16, + "frame_count": 33, + "b64_json": "GkXfo59ChoEBQveBAULygQRC84EIQo..." + }, + "error": null +} +``` + +The response returns the encoded `.webm`, animated `.webp`, or `.avi` container payload directly. + +### Failure Result + +Example failed job: + +```json +{ + "id": "job_01HTXYZVID", + "kind": "vid_gen", + "status": "failed", + "created": 1775401200, + "started": 1775401203, + "completed": 1775401204, + "queue_position": 0, + "result": null, + "error": { + "code": "generation_failed", + "message": "generate_video returned no results" + } +} +``` + +### Cancelled Result + +Example cancelled job: + +```json +{ + "id": "job_01HTXYZVID", + "kind": "vid_gen", + "status": "cancelled", + "created": 1775401200, + "started": null, + "completed": 1775401202, + "queue_position": 0, + "result": null, + "error": { + "code": "cancelled", + "message": "job cancelled by client" + } +} +``` + +### Submission Errors + +`POST /sdcpp/v1/vid_gen` may return: + +- `202 Accepted` when the job is created +- `400 Bad Request` for an empty body, unsupported model mode, invalid JSON, invalid generation parameters, or an unsupported output format +- `429 Too Many Requests` when the job queue is full +- `500 Internal Server Error` for unexpected server exceptions during submission diff --git a/examples/server/async_jobs.cpp b/examples/server/async_jobs.cpp new file mode 100644 index 000000000..e8e9d8ada --- /dev/null +++ b/examples/server/async_jobs.cpp @@ -0,0 +1,349 @@ +// Extracted from main.cpp during server refactor. 
+ +#include "async_jobs.h" + +#include +#include + +#include "common/log.h" +#include "common/media_io.h" +#include "common/resource_owners.hpp" + +const char* async_job_kind_name(AsyncJobKind kind) { + switch (kind) { + case AsyncJobKind::ImgGen: + return "img_gen"; + case AsyncJobKind::VidGen: + return "vid_gen"; + default: + return "img_gen"; + } +} + +const char* async_job_status_name(AsyncJobStatus status) { + switch (status) { + case AsyncJobStatus::Queued: + return "queued"; + case AsyncJobStatus::Generating: + return "generating"; + case AsyncJobStatus::Completed: + return "completed"; + case AsyncJobStatus::Failed: + return "failed"; + case AsyncJobStatus::Cancelled: + return "cancelled"; + default: + return "failed"; + } +} + +void purge_expired_jobs(AsyncJobManager& manager) { + const int64_t now = unix_timestamp_now(); + + for (auto it = manager.expired_jobs.begin(); it != manager.expired_jobs.end();) { + if (it->second <= now) { + it = manager.expired_jobs.erase(it); + } else { + ++it; + } + } + + for (auto it = manager.jobs.begin(); it != manager.jobs.end();) { + const auto& job = it->second; + if (job->completed_at == 0) { + ++it; + continue; + } + + int64_t ttl_seconds = job->status == AsyncJobStatus::Completed + ? 
manager.completed_ttl_seconds + : manager.failed_ttl_seconds; + if (now - job->completed_at >= ttl_seconds) { + manager.expired_jobs[job->id] = now + std::max(ttl_seconds, 60); + it = manager.jobs.erase(it); + } else { + ++it; + } + } +} + +size_t count_pending_jobs(const AsyncJobManager& manager) { + size_t pending = 0; + for (const auto& entry : manager.jobs) { + if (entry.second->status == AsyncJobStatus::Queued || + entry.second->status == AsyncJobStatus::Generating) { + ++pending; + } + } + return pending; +} + +std::string make_async_job_id(AsyncJobManager& manager) { + std::ostringstream oss; + oss << "job_" << std::hex << unix_timestamp_now() << "_" << std::setw(8) + << std::setfill('0') << manager.next_id++; + return oss.str(); +} + +bool cancel_queued_job(AsyncJobManager& manager, AsyncGenerationJob& job) { + auto new_end = std::remove(manager.queue.begin(), manager.queue.end(), job.id); + if (new_end == manager.queue.end()) { + return false; + } + + manager.queue.erase(new_end, manager.queue.end()); + job.status = AsyncJobStatus::Cancelled; + job.completed_at = unix_timestamp_now(); + job.result_images_b64.clear(); + job.result_media_b64.clear(); + job.result_media_mime_type.clear(); + job.result_frame_count = 0; + job.result_fps = 0; + job.error_code = "cancelled"; + job.error_message = "job cancelled by client"; + return true; +} + +json make_async_job_json(const AsyncJobManager& manager, const AsyncGenerationJob& job) { + json result; + result["id"] = job.id; + result["kind"] = async_job_kind_name(job.kind); + result["status"] = async_job_status_name(job.status); + result["created"] = job.created_at; + result["started"] = job.started_at == 0 ? json(nullptr) : json(job.started_at); + result["completed"] = job.completed_at == 0 ? 
json(nullptr) : json(job.completed_at); + result["queue_position"] = 0; + + if (job.status == AsyncJobStatus::Queued) { + size_t position = 1; + for (const auto& queued_id : manager.queue) { + if (queued_id == job.id) { + result["queue_position"] = position; + break; + } + ++position; + } + } + + if (job.status == AsyncJobStatus::Completed) { + if (job.kind == AsyncJobKind::VidGen) { + result["result"] = { + {"output_format", job.vid_gen.output_format}, + {"mime_type", job.result_media_mime_type}, + {"fps", job.result_fps}, + {"frame_count", job.result_frame_count}, + {"b64_json", job.result_media_b64}, + }; + } else { + json images = json::array(); + for (size_t i = 0; i < job.result_images_b64.size(); ++i) { + images.push_back({{"index", i}, {"b64_json", job.result_images_b64[i]}}); + } + result["result"] = { + {"output_format", job.img_gen.output_format}, + {"images", images}, + }; + } + result["error"] = nullptr; + } else if (job.status == AsyncJobStatus::Failed || + job.status == AsyncJobStatus::Cancelled) { + result["result"] = nullptr; + result["error"] = { + {"code", + job.error_code.empty() + ? (job.status == AsyncJobStatus::Cancelled ? 
"cancelled" : "generation_failed") + : job.error_code}, + {"message", job.error_message}, + }; + } else { + result["result"] = nullptr; + result["error"] = nullptr; + } + + return result; +} + +bool execute_img_gen_job(ServerRuntime& runtime, + AsyncGenerationJob& job, + std::vector& output_images, + std::string& error_message) { + sd_img_gen_params_t params = job.img_gen.to_sd_img_gen_params_t(); + + SDImageVec results; + + { + std::lock_guard lock(*runtime.sd_ctx_mutex); + sd_image_t* raw_results = generate_image(runtime.sd_ctx, ¶ms); + results.adopt(raw_results, params.batch_count); + } + + const int num_results = results.count(); + if (num_results <= 0) { + error_message = "generate_image returned no results"; + return false; + } + + EncodedImageFormat encoded_format = EncodedImageFormat::PNG; + if (job.img_gen.output_format == "jpeg") { + encoded_format = EncodedImageFormat::JPEG; + } else if (job.img_gen.output_format == "webp") { + encoded_format = EncodedImageFormat::WEBP; + } + + for (int i = 0; i < num_results; ++i) { + if (results[i].data == nullptr) { + continue; + } + + const std::string metadata = job.img_gen.gen_params.embed_image_metadata + ? 
get_image_params(*runtime.ctx_params, + job.img_gen.gen_params, + job.img_gen.gen_params.seed + i) + : ""; + auto image_bytes = encode_image_to_vector(encoded_format, + results[i].data, + results[i].width, + results[i].height, + results[i].channel, + metadata, + job.img_gen.output_compression); + if (image_bytes.empty()) { + continue; + } + output_images.push_back(base64_encode(image_bytes)); + } + + if (output_images.empty()) { + error_message = "generate_image returned empty encoded outputs"; + return false; + } + + return true; +} + +bool execute_vid_gen_job(ServerRuntime& runtime, + AsyncGenerationJob& job, + std::string& output_media_b64, + std::string& output_media_mime_type, + int& output_frame_count, + int& output_fps, + std::string& error_message) { + sd_vid_gen_params_t params = job.vid_gen.to_sd_vid_gen_params_t(); + + SDImageVec results; + int num_results = 0; + + { + std::lock_guard lock(*runtime.sd_ctx_mutex); + sd_image_t* raw_results = generate_video(runtime.sd_ctx, ¶ms, &num_results); + results.adopt(raw_results, num_results); + } + + num_results = results.count(); + if (num_results <= 0) { + error_message = "generate_video returned no results"; + return false; + } + + std::vector video_bytes = create_video_from_sd_images_to_vector(job.vid_gen.output_format, + results.data(), + num_results, + job.vid_gen.gen_params.fps, + job.vid_gen.output_compression); + if (video_bytes.empty()) { + error_message = "failed to encode generated video container"; + return false; + } + + output_media_b64 = base64_encode(video_bytes); + output_media_mime_type = video_mime_type(job.vid_gen.output_format); + output_frame_count = num_results; + output_fps = job.vid_gen.gen_params.fps; + return true; +} + +void async_job_worker(ServerRuntime& runtime) { + AsyncJobManager& manager = *runtime.async_job_manager; + + while (true) { + std::shared_ptr job; + { + std::unique_lock lock(manager.mutex); + manager.cv.wait(lock, [&]() { return manager.stop || !manager.queue.empty(); 
}); + + if (manager.stop && manager.queue.empty()) { + break; + } + + purge_expired_jobs(manager); + if (manager.queue.empty()) { + continue; + } + + const std::string job_id = manager.queue.front(); + manager.queue.pop_front(); + + auto it = manager.jobs.find(job_id); + if (it == manager.jobs.end()) { + continue; + } + + job = it->second; + job->status = AsyncJobStatus::Generating; + job->started_at = unix_timestamp_now(); + } + + std::vector output_images; + std::string output_media_b64; + std::string output_media_mime_type; + int output_frame_count = 0; + int output_fps = 0; + std::string error_message; + bool ok = false; + + if (job->kind == AsyncJobKind::ImgGen) { + ok = execute_img_gen_job(runtime, *job, output_images, error_message); + } else if (job->kind == AsyncJobKind::VidGen) { + ok = execute_vid_gen_job(runtime, + *job, + output_media_b64, + output_media_mime_type, + output_frame_count, + output_fps, + error_message); + } else { + error_message = "unsupported job kind"; + } + + { + std::lock_guard lock(manager.mutex); + auto it = manager.jobs.find(job->id); + if (it == manager.jobs.end()) { + continue; + } + + job->completed_at = unix_timestamp_now(); + if (ok) { + job->status = AsyncJobStatus::Completed; + job->result_images_b64 = std::move(output_images); + job->result_media_b64 = std::move(output_media_b64); + job->result_media_mime_type = std::move(output_media_mime_type); + job->result_frame_count = output_frame_count; + job->result_fps = output_fps; + job->error_code.clear(); + job->error_message.clear(); + } else { + job->status = AsyncJobStatus::Failed; + job->error_code = "generation_failed"; + job->error_message = error_message.empty() ? 
"unknown generation error" : error_message; + job->result_images_b64.clear(); + job->result_media_b64.clear(); + job->result_media_mime_type.clear(); + job->result_frame_count = 0; + job->result_fps = 0; + } + + purge_expired_jobs(manager); + } + } +} diff --git a/examples/server/async_jobs.h b/examples/server/async_jobs.h new file mode 100644 index 000000000..89997a3b4 --- /dev/null +++ b/examples/server/async_jobs.h @@ -0,0 +1,78 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + + +#include "runtime.h" + +enum class AsyncJobKind { + ImgGen, + VidGen, +}; + +enum class AsyncJobStatus { + Queued, + Generating, + Completed, + Failed, + Cancelled, +}; + +const char* async_job_kind_name(AsyncJobKind kind); +const char* async_job_status_name(AsyncJobStatus status); + +struct AsyncGenerationJob { + std::string id; + AsyncJobKind kind = AsyncJobKind::ImgGen; + AsyncJobStatus status = AsyncJobStatus::Queued; + int64_t created_at = unix_timestamp_now(); + int64_t started_at = 0; + int64_t completed_at = 0; + ImgGenJobRequest img_gen; + VidGenJobRequest vid_gen; + std::vector result_images_b64; + std::string result_media_b64; + std::string result_media_mime_type; + int result_frame_count = 0; + int result_fps = 0; + std::string error_code; + std::string error_message; +}; + +struct AsyncJobManager { + std::mutex mutex; + std::condition_variable cv; + std::unordered_map> jobs; + std::unordered_map expired_jobs; + std::deque queue; + uint64_t next_id = 0; + bool stop = false; + size_t max_pending_jobs = 64; + int64_t completed_ttl_seconds = 600; + int64_t failed_ttl_seconds = 600; +}; + +void purge_expired_jobs(AsyncJobManager& manager); +size_t count_pending_jobs(const AsyncJobManager& manager); +std::string make_async_job_id(AsyncJobManager& manager); +bool cancel_queued_job(AsyncJobManager& manager, AsyncGenerationJob& job); +json make_async_job_json(const AsyncJobManager& manager, const AsyncGenerationJob& job); +bool 
execute_img_gen_job(ServerRuntime& runtime, + AsyncGenerationJob& job, + std::vector& output_images, + std::string& error_message); +bool execute_vid_gen_job(ServerRuntime& runtime, + AsyncGenerationJob& job, + std::string& output_media_b64, + std::string& output_media_mime_type, + int& output_frame_count, + int& output_fps, + std::string& error_message); +void async_job_worker(ServerRuntime& runtime); diff --git a/examples/server/frontend b/examples/server/frontend index 1a34176cd..797ccf808 160000 --- a/examples/server/frontend +++ b/examples/server/frontend @@ -1 +1 @@ -Subproject commit 1a34176cd6d39ad3a226b2b69047e71f6797f6bc +Subproject commit 797ccf80825cc035508ba9b599b2a21953e7f835 diff --git a/examples/server/main.cpp b/examples/server/main.cpp index 8d4e644b5..114d526a8 100644 --- a/examples/server/main.cpp +++ b/examples/server/main.cpp @@ -1,181 +1,25 @@ -// main.cpp -#include -#include -#include -#include +#include #include #include -#include +#include +#include #include #include "httplib.h" -#include "stable-diffusion.h" -#include "common/common.hpp" -#include "common/media_io.h" +#include "async_jobs.h" +#include "common/common.h" +#include "common/resource_owners.hpp" +#include "routes.h" +#include "runtime.h" #ifdef HAVE_INDEX_HTML #include "frontend/dist/gen_index_html.h" #endif -namespace fs = std::filesystem; - -// ----------------------- helpers ----------------------- -static const std::string base64_chars = - "ABCDEFGHIJKLMNOPQRSTUVWXYZ" - "abcdefghijklmnopqrstuvwxyz" - "0123456789+/"; - -std::string base64_encode(const std::vector& bytes) { - std::string ret; - int val = 0, valb = -6; - for (uint8_t c : bytes) { - val = (val << 8) + c; - valb += 8; - while (valb >= 0) { - ret.push_back(base64_chars[(val >> valb) & 0x3F]); - valb -= 6; - } - } - if (valb > -6) - ret.push_back(base64_chars[((val << 8) >> (valb + 8)) & 0x3F]); - while (ret.size() % 4) - ret.push_back('='); - return ret; -} - -inline bool is_base64(unsigned char c) { - return 
(isalnum(c) || (c == '+') || (c == '/')); -} - -std::vector base64_decode(const std::string& encoded_string) { - int in_len = static_cast(encoded_string.size()); - int i = 0; - int j = 0; - int in_ = 0; - uint8_t char_array_4[4], char_array_3[3]; - std::vector ret; - - while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_])) { - char_array_4[i++] = encoded_string[in_]; - in_++; - if (i == 4) { - for (i = 0; i < 4; i++) - char_array_4[i] = static_cast(base64_chars.find(char_array_4[i])); - - char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4); - char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); - char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; - - for (i = 0; i < 3; i++) - ret.push_back(char_array_3[i]); - i = 0; - } - } - - if (i) { - for (j = i; j < 4; j++) - char_array_4[j] = 0; - - for (j = 0; j < 4; j++) - char_array_4[j] = static_cast(base64_chars.find(char_array_4[j])); - - char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4); - char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); - char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; - - for (j = 0; j < i - 1; j++) - ret.push_back(char_array_3[j]); - } - - return ret; -} - -struct SDSvrParams { - std::string listen_ip = "127.0.0.1"; - int listen_port = 1234; - std::string serve_html_path; - bool normal_exit = false; - bool verbose = false; - bool color = false; - - ArgOptions get_options() { - ArgOptions options; - - options.string_options = { - {"-l", - "--listen-ip", - "server listen ip (default: 127.0.0.1)", - &listen_ip}, - {"", - "--serve-html-path", - "path to HTML file to serve at root (optional)", - &serve_html_path}}; - - options.int_options = { - {"", - "--listen-port", - "server listen port (default: 1234)", - &listen_port}, - }; - - options.bool_options = { - {"-v", - "--verbose", - "print extra info", - true, &verbose}, - {"", 
- "--color", - "colors the logging tags according to level", - true, &color}, - }; - - auto on_help_arg = [&](int argc, const char** argv, int index) { - normal_exit = true; - return -1; - }; - - options.manual_options = { - {"-h", - "--help", - "show this help message and exit", - on_help_arg}, - }; - return options; - }; - - bool process_and_check() { - if (listen_ip.empty()) { - LOG_ERROR("error: the following arguments are required: listen_ip"); - return false; - } - - if (listen_port < 0 || listen_port > 65535) { - LOG_ERROR("error: listen_port should be in the range [0, 65535]"); - return false; - } - - if (!serve_html_path.empty() && !fs::exists(serve_html_path)) { - LOG_ERROR("error: serve_html_path file does not exist: %s", serve_html_path.c_str()); - return false; - } - return true; - } - - std::string to_string() const { - std::ostringstream oss; - oss << "SDSvrParams {\n" - << " listen_ip: " << listen_ip << ",\n" - << " listen_port: \"" << listen_port << "\",\n" - << " serve_html_path: \"" << serve_html_path << "\",\n" - << "}"; - return oss.str(); - } -}; - -void print_usage(int argc, const char* argv[], const std::vector& options_list) { +static void print_usage(const char* argv0, const std::vector& options_list) { std::cout << version_string() << "\n"; - std::cout << "Usage: " << argv[0] << " [options]\n\n"; + std::cout << "Usage: " << argv0 << " [options]\n\n"; std::cout << "Svr Options:\n"; options_list[0].print(); std::cout << "\nContext Options:\n"; @@ -184,20 +28,30 @@ void print_usage(int argc, const char* argv[], const std::vector& op options_list[2].print(); } -void parse_args(int argc, const char** argv, SDSvrParams& svr_params, SDContextParams& ctx_params, SDGenerationParams& default_gen_params) { - std::vector options_vec = {svr_params.get_options(), ctx_params.get_options(), default_gen_params.get_options()}; +static void parse_args(int argc, + const char** argv, + SDSvrParams& svr_params, + SDContextParams& ctx_params, + 
SDGenerationParams& default_gen_params) { + std::vector options_vec = { + svr_params.get_options(), + ctx_params.get_options(), + default_gen_params.get_options(), + }; if (!parse_options(argc, argv, options_vec)) { - print_usage(argc, argv, options_vec); + print_usage(argv[0], options_vec); exit(svr_params.normal_exit ? 0 : 1); } const bool random_seed_requested = default_gen_params.seed < 0; - if (!svr_params.process_and_check() || - !ctx_params.process_and_check(IMG_GEN) || - !default_gen_params.process_and_check(IMG_GEN, ctx_params.lora_model_dir)) { - print_usage(argc, argv, options_vec); + if (!svr_params.resolve_and_validate() || + !ctx_params.resolve_and_validate(IMG_GEN) || + !default_gen_params.resolve_and_validate(IMG_GEN, + ctx_params.lora_model_dir, + ctx_params.hires_upscalers_dir)) { + print_usage(argv[0], options_vec); exit(1); } @@ -206,957 +60,11 @@ void parse_args(int argc, const char** argv, SDSvrParams& svr_params, SDContextP } } -std::string extract_and_remove_sd_cpp_extra_args(std::string& text) { - std::regex re("(.*?)"); - std::smatch match; - - std::string extracted; - if (std::regex_search(text, match, re)) { - extracted = match[1].str(); - text = std::regex_replace(text, re, ""); - } - return extracted; -} - void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) { SDSvrParams* svr_params = (SDSvrParams*)data; log_print(level, log, svr_params->verbose, svr_params->color); } -struct LoraEntry { - std::string name; - std::string path; - std::string fullpath; -}; - -struct ServerRuntime { - sd_ctx_t* sd_ctx; - std::mutex* sd_ctx_mutex; - const SDSvrParams* svr_params; - const SDContextParams* ctx_params; - const SDGenerationParams* default_gen_params; - std::vector* lora_cache; - std::mutex* lora_mutex; -}; - -void refresh_lora_cache(ServerRuntime& rt) { - std::vector new_cache; - - fs::path lora_dir = rt.ctx_params->lora_model_dir; - if (fs::exists(lora_dir) && fs::is_directory(lora_dir)) { - auto is_lora_ext = [](const 
fs::path& p) { - auto ext = p.extension().string(); - std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower); - return ext == ".gguf" || ext == ".pt" || ext == ".pth" || ext == ".safetensors"; - }; - - for (auto& entry : fs::recursive_directory_iterator(lora_dir)) { - if (!entry.is_regular_file()) - continue; - const fs::path& p = entry.path(); - if (!is_lora_ext(p)) - continue; - - LoraEntry e; - e.name = p.stem().u8string(); - e.fullpath = p.u8string(); - std::string rel = p.lexically_relative(lora_dir).u8string(); - std::replace(rel.begin(), rel.end(), '\\', '/'); - e.path = rel; - - new_cache.push_back(std::move(e)); - } - } - - std::sort(new_cache.begin(), new_cache.end(), - [](const LoraEntry& a, const LoraEntry& b) { - return a.path < b.path; - }); - - { - std::lock_guard lock(*rt.lora_mutex); - *rt.lora_cache = std::move(new_cache); - } -} - -std::string get_lora_full_path(ServerRuntime& rt, const std::string& path) { - std::lock_guard lock(*rt.lora_mutex); - auto it = std::find_if(rt.lora_cache->begin(), rt.lora_cache->end(), - [&](const LoraEntry& e) { return e.path == path; }); - return (it != rt.lora_cache->end()) ? 
it->fullpath : ""; -} - -void free_results(sd_image_t* result_images, int num_results) { - if (result_images) { - for (int i = 0; i < num_results; ++i) { - if (result_images[i].data) { - free(result_images[i].data); - result_images[i].data = nullptr; - } - } - } - free(result_images); -} - -void register_index_endpoints(httplib::Server& svr, const SDSvrParams& svr_params, const std::string& index_html) { - const std::string serve_html_path = svr_params.serve_html_path; - svr.Get("/", [serve_html_path, index_html](const httplib::Request&, httplib::Response& res) { - if (!serve_html_path.empty()) { - std::ifstream file(serve_html_path); - if (file) { - std::string content((std::istreambuf_iterator(file)), std::istreambuf_iterator()); - res.set_content(content, "text/html"); - } else { - res.status = 500; - res.set_content("Error: Unable to read HTML file", "text/plain"); - } - } else { - res.set_content(index_html, "text/html"); - } - }); -} - -void register_openai_api_endpoints(httplib::Server& svr, ServerRuntime& rt) { - ServerRuntime* runtime = &rt; - - svr.Get("/v1/models", [runtime](const httplib::Request&, httplib::Response& res) { - json r; - r["data"] = json::array(); - r["data"].push_back({{"id", "sd-cpp-local"}, {"object", "model"}, {"owned_by", "local"}}); - res.set_content(r.dump(), "application/json"); - }); - - svr.Post("/v1/images/generations", [runtime](const httplib::Request& req, httplib::Response& res) { - try { - if (req.body.empty()) { - res.status = 400; - res.set_content(R"({"error":"empty body"})", "application/json"); - return; - } - - json j = json::parse(req.body); - std::string prompt = j.value("prompt", ""); - int n = std::max(1, j.value("n", 1)); - std::string size = j.value("size", ""); - std::string output_format = j.value("output_format", "png"); - int output_compression = j.value("output_compression", 100); - int width = runtime->default_gen_params->width > 0 ? 
runtime->default_gen_params->width : 512; - int height = runtime->default_gen_params->width > 0 ? runtime->default_gen_params->height : 512; - if (!size.empty()) { - auto pos = size.find('x'); - if (pos != std::string::npos) { - try { - width = std::stoi(size.substr(0, pos)); - height = std::stoi(size.substr(pos + 1)); - } catch (...) { - } - } - } - - if (prompt.empty()) { - res.status = 400; - res.set_content(R"({"error":"prompt required"})", "application/json"); - return; - } - - std::string sd_cpp_extra_args_str = extract_and_remove_sd_cpp_extra_args(prompt); - - if (output_format != "png" && output_format != "jpeg" && output_format != "webp") { - res.status = 400; - res.set_content(R"({"error":"invalid output_format, must be one of [png, jpeg, webp]"})", "application/json"); - return; - } - if (n <= 0) - n = 1; - if (n > 8) - n = 8; - if (output_compression > 100) { - output_compression = 100; - } - if (output_compression < 0) { - output_compression = 0; - } - - json out; - out["created"] = static_cast(std::time(nullptr)); - out["data"] = json::array(); - out["output_format"] = output_format; - - SDGenerationParams gen_params = *runtime->default_gen_params; - gen_params.prompt = prompt; - gen_params.width = width; - gen_params.height = height; - gen_params.batch_count = n; - - if (!sd_cpp_extra_args_str.empty() && !gen_params.from_json_str(sd_cpp_extra_args_str)) { - res.status = 400; - res.set_content(R"({"error":"invalid sd_cpp_extra_args"})", "application/json"); - return; - } - - if (gen_params.sample_params.sample_steps > 100) - gen_params.sample_params.sample_steps = 100; - - if (!gen_params.process_and_check(IMG_GEN, "")) { - res.status = 400; - res.set_content(R"({"error":"invalid params"})", "application/json"); - return; - } - - LOG_DEBUG("%s\n", gen_params.to_string().c_str()); - - sd_image_t init_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr}; - sd_image_t control_image = {(uint32_t)gen_params.width, 
(uint32_t)gen_params.height, 3, nullptr}; - sd_image_t mask_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 1, nullptr}; - std::vector pmid_images; - - sd_img_gen_params_t img_gen_params = { - gen_params.lora_vec.data(), - static_cast(gen_params.lora_vec.size()), - gen_params.prompt.c_str(), - gen_params.negative_prompt.c_str(), - gen_params.clip_skip, - init_image, - nullptr, - 0, - gen_params.auto_resize_ref_image, - gen_params.increase_ref_index, - mask_image, - gen_params.width, - gen_params.height, - gen_params.sample_params, - gen_params.strength, - gen_params.seed, - gen_params.batch_count, - control_image, - gen_params.control_strength, - { - pmid_images.data(), - (int)pmid_images.size(), - gen_params.pm_id_embed_path.c_str(), - gen_params.pm_style_strength, - }, - gen_params.vae_tiling_params, - gen_params.cache_params, - }; - - sd_image_t* results = nullptr; - int num_results = 0; - - { - std::lock_guard lock(*runtime->sd_ctx_mutex); - results = generate_image(runtime->sd_ctx, &img_gen_params); - num_results = gen_params.batch_count; - } - - for (int i = 0; i < num_results; i++) { - if (results[i].data == nullptr) { - continue; - } - std::string params = gen_params.embed_image_metadata - ? get_image_params(*runtime->ctx_params, gen_params, gen_params.seed + i) - : ""; - auto image_bytes = encode_image_to_vector(output_format == "jpeg" - ? EncodedImageFormat::JPEG - : output_format == "webp" - ? 
EncodedImageFormat::WEBP - : EncodedImageFormat::PNG, - results[i].data, - results[i].width, - results[i].height, - results[i].channel, - params, - output_compression); - if (image_bytes.empty()) { - LOG_ERROR("write image to mem failed"); - continue; - } - - std::string b64 = base64_encode(image_bytes); - json item; - item["b64_json"] = b64; - out["data"].push_back(item); - } - free_results(results, num_results); - - res.set_content(out.dump(), "application/json"); - res.status = 200; - - } catch (const std::exception& e) { - res.status = 500; - json err; - err["error"] = "server_error"; - err["message"] = e.what(); - res.set_content(err.dump(), "application/json"); - } - }); - - svr.Post("/v1/images/edits", [runtime](const httplib::Request& req, httplib::Response& res) { - try { - if (!req.is_multipart_form_data()) { - res.status = 400; - res.set_content(R"({"error":"Content-Type must be multipart/form-data"})", "application/json"); - return; - } - - std::string prompt = req.form.get_field("prompt"); - if (prompt.empty()) { - res.status = 400; - res.set_content(R"({"error":"prompt required"})", "application/json"); - return; - } - - std::string sd_cpp_extra_args_str = extract_and_remove_sd_cpp_extra_args(prompt); - - size_t image_count = req.form.get_file_count("image[]"); - bool has_legacy_image = req.form.has_file("image"); - if (image_count == 0 && !has_legacy_image) { - res.status = 400; - res.set_content(R"({"error":"at least one image[] required"})", "application/json"); - return; - } - - std::vector> images_bytes; - for (size_t i = 0; i < image_count; i++) { - auto file = req.form.get_file("image[]", i); - images_bytes.emplace_back(file.content.begin(), file.content.end()); - } - if (image_count == 0 && has_legacy_image) { - auto file = req.form.get_file("image"); - images_bytes.emplace_back(file.content.begin(), file.content.end()); - } - - std::vector mask_bytes; - if (req.form.has_file("mask")) { - auto file = req.form.get_file("mask"); - 
mask_bytes.assign(file.content.begin(), file.content.end()); - } - - int n = 1; - if (req.form.has_field("n")) { - try { - n = std::stoi(req.form.get_field("n")); - } catch (...) { - } - } - n = std::clamp(n, 1, 8); - - std::string size = req.form.get_field("size"); - int width = -1, height = -1; - if (!size.empty()) { - auto pos = size.find('x'); - if (pos != std::string::npos) { - try { - width = std::stoi(size.substr(0, pos)); - height = std::stoi(size.substr(pos + 1)); - } catch (...) { - } - } - } - - std::string output_format = "png"; - if (req.form.has_field("output_format")) - output_format = req.form.get_field("output_format"); - if (output_format != "png" && output_format != "jpeg") { - res.status = 400; - res.set_content(R"({"error":"invalid output_format, must be one of [png, jpeg]"})", "application/json"); - return; - } - - std::string output_compression_str = req.form.get_field("output_compression"); - int output_compression = 100; - try { - output_compression = std::stoi(output_compression_str); - } catch (...) 
{ - } - if (output_compression > 100) { - output_compression = 100; - } - if (output_compression < 0) { - output_compression = 0; - } - - SDGenerationParams gen_params = *runtime->default_gen_params; - gen_params.prompt = prompt; - gen_params.width = width; - gen_params.height = height; - gen_params.batch_count = n; - - if (!sd_cpp_extra_args_str.empty() && !gen_params.from_json_str(sd_cpp_extra_args_str)) { - res.status = 400; - res.set_content(R"({"error":"invalid sd_cpp_extra_args"})", "application/json"); - return; - } - - if (gen_params.sample_params.sample_steps > 100) - gen_params.sample_params.sample_steps = 100; - - if (!gen_params.process_and_check(IMG_GEN, "")) { - res.status = 400; - res.set_content(R"({"error":"invalid params"})", "application/json"); - return; - } - - LOG_DEBUG("%s\n", gen_params.to_string().c_str()); - - sd_image_t init_image = {0, 0, 3, nullptr}; - sd_image_t control_image = {0, 0, 3, nullptr}; - std::vector pmid_images; - - auto get_resolved_width = [&gen_params, runtime]() -> int { - if (gen_params.width > 0) - return gen_params.width; - if (runtime->default_gen_params->width > 0) - return runtime->default_gen_params->width; - return 512; - }; - auto get_resolved_height = [&gen_params, runtime]() -> int { - if (gen_params.height > 0) - return gen_params.height; - if (runtime->default_gen_params->height > 0) - return runtime->default_gen_params->height; - return 512; - }; - - std::vector ref_images; - ref_images.reserve(images_bytes.size()); - for (auto& bytes : images_bytes) { - int img_w; - int img_h; - - uint8_t* raw_pixels = load_image_from_memory( - reinterpret_cast(bytes.data()), - static_cast(bytes.size()), - img_w, img_h, - width, height, 3); - - if (!raw_pixels) { - continue; - } - - sd_image_t img{(uint32_t)img_w, (uint32_t)img_h, 3, raw_pixels}; - gen_params.set_width_and_height_if_unset(img.width, img.height); - ref_images.push_back(img); - } - - sd_image_t mask_image = {0}; - if (!mask_bytes.empty()) { - int 
expected_width = 0; - int expected_height = 0; - if (gen_params.width_and_height_are_set()) { - expected_width = gen_params.width; - expected_height = gen_params.height; - } - int mask_w; - int mask_h; - - uint8_t* mask_raw = load_image_from_memory( - reinterpret_cast(mask_bytes.data()), - static_cast(mask_bytes.size()), - mask_w, mask_h, - expected_width, expected_height, 1); - mask_image = {(uint32_t)mask_w, (uint32_t)mask_h, 1, mask_raw}; - gen_params.set_width_and_height_if_unset(mask_image.width, mask_image.height); - } else { - mask_image.width = get_resolved_width(); - mask_image.height = get_resolved_height(); - mask_image.channel = 1; - mask_image.data = nullptr; - } - - sd_img_gen_params_t img_gen_params = { - gen_params.lora_vec.data(), - static_cast(gen_params.lora_vec.size()), - gen_params.prompt.c_str(), - gen_params.negative_prompt.c_str(), - gen_params.clip_skip, - init_image, - ref_images.data(), - (int)ref_images.size(), - gen_params.auto_resize_ref_image, - gen_params.increase_ref_index, - mask_image, - get_resolved_width(), - get_resolved_height(), - gen_params.sample_params, - gen_params.strength, - gen_params.seed, - gen_params.batch_count, - control_image, - gen_params.control_strength, - { - pmid_images.data(), - (int)pmid_images.size(), - gen_params.pm_id_embed_path.c_str(), - gen_params.pm_style_strength, - }, - gen_params.vae_tiling_params, - gen_params.cache_params, - }; - - sd_image_t* results = nullptr; - int num_results = 0; - - { - std::lock_guard lock(*runtime->sd_ctx_mutex); - results = generate_image(runtime->sd_ctx, &img_gen_params); - num_results = gen_params.batch_count; - } - - json out; - out["created"] = static_cast(std::time(nullptr)); - out["data"] = json::array(); - out["output_format"] = output_format; - - for (int i = 0; i < num_results; i++) { - if (results[i].data == nullptr) - continue; - std::string params = gen_params.embed_image_metadata - ? 
get_image_params(*runtime->ctx_params, gen_params, gen_params.seed + i) - : ""; - auto image_bytes = encode_image_to_vector(output_format == "jpeg" - ? EncodedImageFormat::JPEG - : output_format == "webp" - ? EncodedImageFormat::WEBP - : EncodedImageFormat::PNG, - results[i].data, - results[i].width, - results[i].height, - results[i].channel, - params, - output_compression); - std::string b64 = base64_encode(image_bytes); - json item; - item["b64_json"] = b64; - out["data"].push_back(item); - } - free_results(results, num_results); - - res.set_content(out.dump(), "application/json"); - res.status = 200; - - if (init_image.data) { - free(init_image.data); - } - if (mask_image.data) { - free(mask_image.data); - } - for (auto ref_image : ref_images) { - free(ref_image.data); - } - } catch (const std::exception& e) { - res.status = 500; - json err; - err["error"] = "server_error"; - err["message"] = e.what(); - res.set_content(err.dump(), "application/json"); - } - }); -} - -void register_sdapi_endpoints(httplib::Server& svr, ServerRuntime& rt) { - ServerRuntime* runtime = &rt; - - auto sdapi_any2img = [runtime](const httplib::Request& req, httplib::Response& res, bool img2img) { - try { - if (req.body.empty()) { - res.status = 400; - res.set_content(R"({"error":"empty body"})", "application/json"); - return; - } - - json j = json::parse(req.body); - - std::string prompt = j.value("prompt", ""); - std::string negative_prompt = j.value("negative_prompt", ""); - int width = j.value("width", 512); - int height = j.value("height", 512); - int steps = j.value("steps", runtime->default_gen_params->sample_params.sample_steps); - float cfg_scale = j.value("cfg_scale", runtime->default_gen_params->sample_params.guidance.txt_cfg); - int64_t seed = j.value("seed", -1); - int batch_size = j.value("batch_size", 1); - int clip_skip = j.value("clip_skip", -1); - std::string sampler_name = j.value("sampler_name", ""); - std::string scheduler_name = j.value("scheduler", ""); - - auto 
bad = [&](const std::string& msg) { - res.status = 400; - res.set_content("{\"error\":\"" + msg + "\"}", "application/json"); - return; - }; - - if (width <= 0 || height <= 0) { - return bad("width and height must be positive"); - } - - if (steps < 1 || steps > 150) { - return bad("steps must be in range [1, 150]"); - } - - if (batch_size < 1 || batch_size > 8) { - return bad("batch_size must be in range [1, 8]"); - } - - if (cfg_scale < 0.f) { - return bad("cfg_scale must be positive"); - } - - if (prompt.empty()) { - return bad("prompt required"); - } - - std::vector sd_loras; - std::vector lora_path_storage; - - if (j.contains("lora") && j["lora"].is_array()) { - for (const auto& item : j["lora"]) { - if (!item.is_object()) { - continue; - } - - std::string path = item.value("path", ""); - float multiplier = item.value("multiplier", 1.0f); - bool is_high_noise = item.value("is_high_noise", false); - - if (path.empty()) { - return bad("lora.path required"); - } - - std::string fullpath = get_lora_full_path(*runtime, path); - if (fullpath.empty()) { - return bad("invalid lora path: " + path); - } - - lora_path_storage.push_back(fullpath); - sd_lora_t l; - l.is_high_noise = is_high_noise; - l.multiplier = multiplier; - l.path = lora_path_storage.back().c_str(); - - sd_loras.push_back(l); - } - } - - auto get_sample_method = [](std::string name) -> enum sample_method_t { - enum sample_method_t result = str_to_sample_method(name.c_str()); - if (result != SAMPLE_METHOD_COUNT) return result; - std::transform(name.begin(), name.end(), name.begin(), - [](unsigned char c) { return std::tolower(c); }); - static const std::unordered_map hardcoded{ - {"euler a", EULER_A_SAMPLE_METHOD}, - {"k_euler_a", EULER_A_SAMPLE_METHOD}, - {"euler", EULER_SAMPLE_METHOD}, - {"k_euler", EULER_SAMPLE_METHOD}, - {"heun", HEUN_SAMPLE_METHOD}, - {"k_heun", HEUN_SAMPLE_METHOD}, - {"dpm2", DPM2_SAMPLE_METHOD}, - {"k_dpm_2", DPM2_SAMPLE_METHOD}, - {"lcm", LCM_SAMPLE_METHOD}, - {"ddim", 
DDIM_TRAILING_SAMPLE_METHOD}, - {"dpm++ 2m", DPMPP2M_SAMPLE_METHOD}, - {"k_dpmpp_2m", DPMPP2M_SAMPLE_METHOD}, - {"res multistep", RES_MULTISTEP_SAMPLE_METHOD}, - {"k_res_multistep", RES_MULTISTEP_SAMPLE_METHOD}, - {"res 2s", RES_2S_SAMPLE_METHOD}, - {"k_res_2s", RES_2S_SAMPLE_METHOD}}; - auto it = hardcoded.find(name); - if (it != hardcoded.end()) return it->second; - return SAMPLE_METHOD_COUNT; - }; - - enum sample_method_t sample_method = get_sample_method(sampler_name); - enum scheduler_t scheduler = str_to_scheduler(scheduler_name.c_str()); - - SDGenerationParams gen_params = *runtime->default_gen_params; - gen_params.prompt = prompt; - gen_params.negative_prompt = negative_prompt; - gen_params.seed = seed; - gen_params.sample_params.sample_steps = steps; - gen_params.batch_count = batch_size; - gen_params.sample_params.guidance.txt_cfg = cfg_scale; - - if (clip_skip > 0) { - gen_params.clip_skip = clip_skip; - } - - if (sample_method != SAMPLE_METHOD_COUNT) { - gen_params.sample_params.sample_method = sample_method; - } - - if (scheduler != SCHEDULER_COUNT) { - gen_params.sample_params.scheduler = scheduler; - } - - gen_params.width = j.value("width", -1); - gen_params.height = j.value("height", -1); - - LOG_DEBUG("%s\n", gen_params.to_string().c_str()); - - sd_image_t init_image = {0, 0, 3, nullptr}; - sd_image_t control_image = {0, 0, 3, nullptr}; - sd_image_t mask_image = {0, 0, 1, nullptr}; - std::vector mask_data; - std::vector pmid_images; - std::vector ref_images; - - auto get_resolved_width = [&gen_params, runtime]() -> int { - if (gen_params.width > 0) - return gen_params.width; - if (runtime->default_gen_params->width > 0) - return runtime->default_gen_params->width; - return 512; - }; - auto get_resolved_height = [&gen_params, runtime]() -> int { - if (gen_params.height > 0) - return gen_params.height; - if (runtime->default_gen_params->height > 0) - return runtime->default_gen_params->height; - return 512; - }; - - auto decode_image = 
[&gen_params](sd_image_t& image, std::string encoded) -> bool { - auto comma_pos = encoded.find(','); - if (comma_pos != std::string::npos) { - encoded = encoded.substr(comma_pos + 1); - } - std::vector img_data = base64_decode(encoded); - if (!img_data.empty()) { - int expected_width = 0; - int expected_height = 0; - if (gen_params.width_and_height_are_set()) { - expected_width = gen_params.width; - expected_height = gen_params.height; - } - int img_w; - int img_h; - - uint8_t* raw_data = load_image_from_memory( - (const char*)img_data.data(), (int)img_data.size(), - img_w, img_h, - expected_width, expected_height, image.channel); - if (raw_data) { - image = {(uint32_t)img_w, (uint32_t)img_h, image.channel, raw_data}; - gen_params.set_width_and_height_if_unset(image.width, image.height); - return true; - } - } - return false; - }; - - if (img2img) { - if (j.contains("init_images") && j["init_images"].is_array() && !j["init_images"].empty()) { - std::string encoded = j["init_images"][0].get(); - decode_image(init_image, encoded); - } - - if (j.contains("mask") && j["mask"].is_string()) { - std::string encoded = j["mask"].get(); - decode_image(mask_image, encoded); - bool inpainting_mask_invert = j.value("inpainting_mask_invert", 0) != 0; - if (inpainting_mask_invert && mask_image.data != nullptr) { - for (uint32_t i = 0; i < mask_image.width * mask_image.height; i++) { - mask_image.data[i] = 255 - mask_image.data[i]; - } - } - } else { - int m_width = get_resolved_width(); - int m_height = get_resolved_height(); - mask_data = std::vector(m_width * m_height, 255); - mask_image.width = m_width; - mask_image.height = m_height; - mask_image.channel = 1; - mask_image.data = mask_data.data(); - } - - float denoising_strength = j.value("denoising_strength", -1.f); - if (denoising_strength >= 0.f) { - denoising_strength = std::min(denoising_strength, 1.0f); - gen_params.strength = denoising_strength; - } - } - - if (j.contains("extra_images") && 
j["extra_images"].is_array()) { - for (auto extra_image : j["extra_images"]) { - std::string encoded = extra_image.get(); - sd_image_t tmp_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr}; - if (decode_image(tmp_image, encoded)) { - ref_images.push_back(tmp_image); - } - } - } - - sd_img_gen_params_t img_gen_params = { - sd_loras.data(), - static_cast(sd_loras.size()), - gen_params.prompt.c_str(), - gen_params.negative_prompt.c_str(), - gen_params.clip_skip, - init_image, - ref_images.data(), - (int)ref_images.size(), - gen_params.auto_resize_ref_image, - gen_params.increase_ref_index, - mask_image, - get_resolved_width(), - get_resolved_height(), - gen_params.sample_params, - gen_params.strength, - gen_params.seed, - gen_params.batch_count, - control_image, - gen_params.control_strength, - { - pmid_images.data(), - (int)pmid_images.size(), - gen_params.pm_id_embed_path.c_str(), - gen_params.pm_style_strength, - }, - gen_params.vae_tiling_params, - gen_params.cache_params, - }; - - sd_image_t* results = nullptr; - int num_results = 0; - - { - std::lock_guard lock(*runtime->sd_ctx_mutex); - results = generate_image(runtime->sd_ctx, &img_gen_params); - num_results = gen_params.batch_count; - } - - json out; - out["images"] = json::array(); - out["parameters"] = j; - out["info"] = ""; - - for (int i = 0; i < num_results; i++) { - if (results[i].data == nullptr) { - continue; - } - - std::string params = gen_params.embed_image_metadata - ? 
get_image_params(*runtime->ctx_params, gen_params, gen_params.seed + i) - : ""; - auto image_bytes = encode_image_to_vector(EncodedImageFormat::PNG, - results[i].data, - results[i].width, - results[i].height, - results[i].channel, - params); - - if (image_bytes.empty()) { - LOG_ERROR("write image to mem failed"); - continue; - } - - std::string b64 = base64_encode(image_bytes); - out["images"].push_back(b64); - } - free_results(results, num_results); - - res.set_content(out.dump(), "application/json"); - res.status = 200; - - if (init_image.data) { - free(init_image.data); - } - if (mask_image.data && mask_data.empty()) { - free(mask_image.data); - } - for (auto ref_image : ref_images) { - free(ref_image.data); - } - - } catch (const std::exception& e) { - res.status = 500; - json err; - err["error"] = "server_error"; - err["message"] = e.what(); - res.set_content(err.dump(), "application/json"); - } - }; - - svr.Post("/sdapi/v1/txt2img", [sdapi_any2img](const httplib::Request& req, httplib::Response& res) { - sdapi_any2img(req, res, false); - }); - - svr.Post("/sdapi/v1/img2img", [sdapi_any2img](const httplib::Request& req, httplib::Response& res) { - sdapi_any2img(req, res, true); - }); - - svr.Get("/sdapi/v1/loras", [runtime](const httplib::Request&, httplib::Response& res) { - refresh_lora_cache(*runtime); - - json result = json::array(); - { - std::lock_guard lock(*runtime->lora_mutex); - for (const auto& e : *runtime->lora_cache) { - json item; - item["name"] = e.name; - item["path"] = e.path; - result.push_back(item); - } - } - - res.set_content(result.dump(), "application/json"); - }); - - svr.Get("/sdapi/v1/samplers", [runtime](const httplib::Request&, httplib::Response& res) { - std::vector sampler_names; - sampler_names.push_back("default"); - for (int i = 0; i < SAMPLE_METHOD_COUNT; i++) { - sampler_names.push_back(sd_sample_method_name((sample_method_t)i)); - } - json r = json::array(); - for (auto name : sampler_names) { - json entry; - entry["name"] 
= name; - entry["aliases"] = json::array({name}); - entry["options"] = json::object(); - r.push_back(entry); - } - res.set_content(r.dump(), "application/json"); - }); - - svr.Get("/sdapi/v1/schedulers", [runtime](const httplib::Request&, httplib::Response& res) { - std::vector scheduler_names; - scheduler_names.push_back("default"); - for (int i = 0; i < SCHEDULER_COUNT; i++) { - scheduler_names.push_back(sd_scheduler_name((scheduler_t)i)); - } - json r = json::array(); - for (auto name : scheduler_names) { - json entry; - entry["name"] = name; - entry["label"] = name; - r.push_back(entry); - } - res.set_content(r.dump(), "application/json"); - }); - - svr.Get("/sdapi/v1/sd-models", [runtime](const httplib::Request&, httplib::Response& res) { - fs::path model_path = runtime->ctx_params->model_path; - json entry; - entry["title"] = model_path.stem(); - entry["model_name"] = model_path.stem(); - entry["filename"] = model_path.filename(); - entry["hash"] = "8888888888"; - entry["sha256"] = "8888888888888888888888888888888888888888888888888888888888888888"; - entry["config"] = nullptr; - json r = json::array(); - r.push_back(entry); - res.set_content(r.dump(), "application/json"); - }); - - svr.Get("/sdapi/v1/options", [runtime](const httplib::Request&, httplib::Response& res) { - fs::path model_path = runtime->ctx_params->model_path; - json r; - r["samples_format"] = "png"; - r["sd_model_checkpoint"] = model_path.stem(); - res.set_content(r.dump(), "application/json"); - }); -} - int main(int argc, const char** argv) { if (argc > 1 && std::string(argv[1]) == "--version") { std::cout << version_string() << "\n"; @@ -1178,7 +86,7 @@ int main(int argc, const char** argv) { LOG_DEBUG("%s", default_gen_params.to_string().c_str()); sd_ctx_params_t sd_ctx_params = ctx_params.to_sd_ctx_params_t(false, false, false); - sd_ctx_t* sd_ctx = new_sd_ctx(&sd_ctx_params); + SDCtxPtr sd_ctx(new_sd_ctx(&sd_ctx_params)); if (sd_ctx == nullptr) { LOG_ERROR("new_sd_ctx_t failed"); @@ 
-1189,16 +97,24 @@ int main(int argc, const char** argv) { std::vector lora_cache; std::mutex lora_mutex; + std::vector upscaler_cache; + std::mutex upscaler_mutex; + AsyncJobManager async_job_manager; ServerRuntime runtime = { - sd_ctx, + sd_ctx.get(), &sd_ctx_mutex, &svr_params, &ctx_params, &default_gen_params, &lora_cache, &lora_mutex, + &upscaler_cache, + &upscaler_mutex, + &async_job_manager, }; + std::thread async_worker(async_job_worker, std::ref(runtime)); + httplib::Server svr; svr.set_pre_routing_handler([](const httplib::Request& req, httplib::Response& res) { @@ -1227,10 +143,16 @@ int main(int argc, const char** argv) { register_index_endpoints(svr, svr_params, index_html); register_openai_api_endpoints(svr, runtime); register_sdapi_endpoints(svr, runtime); + register_sdcpp_api_endpoints(svr, runtime); LOG_INFO("listening on: %s:%d\n", svr_params.listen_ip.c_str(), svr_params.listen_port); svr.listen(svr_params.listen_ip, svr_params.listen_port); - free_sd_ctx(sd_ctx); + { + std::lock_guard lock(async_job_manager.mutex); + async_job_manager.stop = true; + } + async_job_manager.cv.notify_all(); + async_worker.join(); return 0; } diff --git a/examples/server/routes.h b/examples/server/routes.h new file mode 100644 index 000000000..1a4efb5b7 --- /dev/null +++ b/examples/server/routes.h @@ -0,0 +1,11 @@ +#pragma once + +#include + +#include "httplib.h" +#include "runtime.h" + +void register_index_endpoints(httplib::Server& svr, const SDSvrParams& svr_params, const std::string& index_html); +void register_openai_api_endpoints(httplib::Server& svr, ServerRuntime& rt); +void register_sdapi_endpoints(httplib::Server& svr, ServerRuntime& rt); +void register_sdcpp_api_endpoints(httplib::Server& svr, ServerRuntime& rt); diff --git a/examples/server/routes_index.cpp b/examples/server/routes_index.cpp new file mode 100644 index 000000000..1341ff84a --- /dev/null +++ b/examples/server/routes_index.cpp @@ -0,0 +1,22 @@ +#include "routes.h" + +#include +#include + 
+void register_index_endpoints(httplib::Server& svr, const SDSvrParams& svr_params, const std::string& index_html) { + const std::string serve_html_path = svr_params.serve_html_path; + svr.Get("/", [serve_html_path, index_html](const httplib::Request&, httplib::Response& res) { + if (!serve_html_path.empty()) { + std::ifstream file(serve_html_path); + if (file) { + std::string content((std::istreambuf_iterator(file)), std::istreambuf_iterator()); + res.set_content(content, "text/html"); + } else { + res.status = 500; + res.set_content("Error: Unable to read HTML file", "text/plain"); + } + } else { + res.set_content(index_html, "text/html"); + } + }); +} diff --git a/examples/server/routes_openai.cpp b/examples/server/routes_openai.cpp new file mode 100644 index 000000000..a24383d67 --- /dev/null +++ b/examples/server/routes_openai.cpp @@ -0,0 +1,388 @@ +#include "routes.h" + +#include +#include +#include + +#include "common/common.h" +#include "common/media_io.h" +#include "common/resource_owners.hpp" + +static std::string extract_and_remove_sd_cpp_extra_args(std::string& text) { + std::regex re("(.*?)"); + std::smatch match; + + std::string extracted; + if (std::regex_search(text, match, re)) { + extracted = match[1].str(); + text = std::regex_replace(text, re, ""); + } + return extracted; +} + +static bool build_openai_generation_request(const httplib::Request& req, + ServerRuntime& runtime, + ImgGenJobRequest& request, + std::string& error_message) { + if (req.body.empty()) { + error_message = "empty body"; + return false; + } + + json j = json::parse(req.body); + std::string prompt = j.value("prompt", ""); + int n = std::max(1, j.value("n", 1)); + std::string size = j.value("size", ""); + std::string output_format = j.value("output_format", "png"); + int output_compression = j.value("output_compression", 100); + int width = runtime.default_gen_params->width > 0 ? runtime.default_gen_params->width : 512; + int height = runtime.default_gen_params->width > 0 ? 
runtime.default_gen_params->height : 512; + if (!size.empty()) { + auto pos = size.find('x'); + if (pos != std::string::npos) { + try { + width = std::stoi(size.substr(0, pos)); + height = std::stoi(size.substr(pos + 1)); + } catch (...) { + } + } + } + + if (prompt.empty()) { + error_message = "prompt required"; + return false; + } + + request.gen_params = *runtime.default_gen_params; + if (!assign_output_options(request, output_format, output_compression, true, error_message)) { + return false; + } + + request.gen_params.prompt = prompt; + request.gen_params.width = width; + request.gen_params.height = height; + request.gen_params.batch_count = n; + + std::string sd_cpp_extra_args_str = extract_and_remove_sd_cpp_extra_args(request.gen_params.prompt); + if (!sd_cpp_extra_args_str.empty() && !request.gen_params.from_json_str(sd_cpp_extra_args_str)) { + error_message = "invalid sd_cpp_extra_args"; + return false; + } + + // Intentionally disable prompt-embedded LoRA tag parsing for server APIs. 
+ if (!request.gen_params.resolve_and_validate(IMG_GEN, "", runtime.ctx_params->hires_upscalers_dir, true)) { + error_message = "invalid params"; + return false; + } + return true; +} + +static bool build_openai_edit_request(const httplib::Request& req, + ServerRuntime& runtime, + ImgGenJobRequest& request, + std::string& error_message) { + if (!req.is_multipart_form_data()) { + error_message = "Content-Type must be multipart/form-data"; + return false; + } + + std::string prompt = req.form.get_field("prompt"); + if (prompt.empty()) { + error_message = "prompt required"; + return false; + } + + size_t image_count = req.form.get_file_count("image[]"); + bool has_legacy_image = req.form.has_file("image"); + if (image_count == 0 && !has_legacy_image) { + error_message = "at least one image[] required"; + return false; + } + + std::vector> images_bytes; + for (size_t i = 0; i < image_count; ++i) { + auto file = req.form.get_file("image[]", i); + images_bytes.emplace_back(file.content.begin(), file.content.end()); + } + if (image_count == 0 && has_legacy_image) { + auto file = req.form.get_file("image"); + images_bytes.emplace_back(file.content.begin(), file.content.end()); + } + + std::vector mask_bytes; + if (req.form.has_file("mask")) { + auto file = req.form.get_file("mask"); + mask_bytes.assign(file.content.begin(), file.content.end()); + } + + int n = 1; + if (req.form.has_field("n")) { + try { + n = std::stoi(req.form.get_field("n")); + } catch (...) { + } + } + + std::string size = req.form.get_field("size"); + int width = -1; + int height = -1; + if (!size.empty()) { + auto pos = size.find('x'); + if (pos != std::string::npos) { + try { + width = std::stoi(size.substr(0, pos)); + height = std::stoi(size.substr(pos + 1)); + } catch (...) { + } + } + } + + std::string output_format = req.form.has_field("output_format") + ? 
req.form.get_field("output_format") + : "png"; + + int output_compression = 100; + try { + output_compression = std::stoi(req.form.get_field("output_compression")); + } catch (...) { + } + + request.gen_params = *runtime.default_gen_params; + if (!assign_output_options(request, output_format, output_compression, false, error_message)) { + return false; + } + + request.gen_params.prompt = prompt; + request.gen_params.width = width; + request.gen_params.height = height; + request.gen_params.batch_count = n; + + for (auto& bytes : images_bytes) { + int img_w = 0; + int img_h = 0; + uint8_t* raw_pixels = load_image_from_memory( + reinterpret_cast(bytes.data()), + static_cast(bytes.size()), + img_w, img_h, + width, height, 3); + if (raw_pixels == nullptr) { + continue; + } + + SDImageOwner image_owner({(uint32_t)img_w, (uint32_t)img_h, 3, raw_pixels}); + request.gen_params.set_width_and_height_if_unset(image_owner.get().width, image_owner.get().height); + request.gen_params.ref_images.push_back(std::move(image_owner)); + } + + if (!request.gen_params.ref_images.empty()) { + request.gen_params.init_image = request.gen_params.ref_images.front(); + } + + if (!mask_bytes.empty()) { + int expected_width = 0; + int expected_height = 0; + if (request.gen_params.width_and_height_are_set()) { + expected_width = request.gen_params.width; + expected_height = request.gen_params.height; + } + int mask_w = 0; + int mask_h = 0; + + uint8_t* mask_raw = load_image_from_memory( + reinterpret_cast(mask_bytes.data()), + static_cast(mask_bytes.size()), + mask_w, mask_h, + expected_width, expected_height, 1); + request.gen_params.mask_image.reset({(uint32_t)mask_w, (uint32_t)mask_h, 1, mask_raw}); + const sd_image_t& mask_image = request.gen_params.mask_image.get(); + request.gen_params.set_width_and_height_if_unset(mask_image.width, mask_image.height); + } else { + request.gen_params.mask_image.reset({ + (uint32_t)request.gen_params.get_resolved_width(), + 
(uint32_t)request.gen_params.get_resolved_height(), + 1, + nullptr, + }); + } + + std::string sd_cpp_extra_args_str = extract_and_remove_sd_cpp_extra_args(request.gen_params.prompt); + if (!sd_cpp_extra_args_str.empty() && !request.gen_params.from_json_str(sd_cpp_extra_args_str)) { + error_message = "invalid sd_cpp_extra_args"; + return false; + } + + // Intentionally disable prompt-embedded LoRA tag parsing for server APIs. + if (!request.gen_params.resolve_and_validate(IMG_GEN, "", runtime.ctx_params->hires_upscalers_dir, true)) { + error_message = "invalid params"; + return false; + } + + return true; +} + +static bool execute_sync_img_gen_request(ServerRuntime& runtime, + ImgGenJobRequest& request, + SDImageVec& results, + std::string& error_message) { + sd_img_gen_params_t img_gen_params = request.to_sd_img_gen_params_t(); + int num_results = 0; + + { + std::lock_guard lock(*runtime.sd_ctx_mutex); + sd_image_t* raw_results = generate_image(runtime.sd_ctx, &img_gen_params); + num_results = request.gen_params.batch_count; + results.adopt(raw_results, num_results); + } + + if (results.empty()) { + error_message = "generate_image returned no results"; + return false; + } + return true; +} + +void register_openai_api_endpoints(httplib::Server& svr, ServerRuntime& rt) { + ServerRuntime* runtime = &rt; + + svr.Get("/v1/models", [runtime](const httplib::Request&, httplib::Response& res) { + json r; + r["data"] = json::array(); + r["data"].push_back({{"id", "sd-cpp-local"}, {"object", "model"}, {"owned_by", "local"}}); + res.set_content(r.dump(), "application/json"); + }); + + svr.Post("/v1/images/generations", [runtime](const httplib::Request& req, httplib::Response& res) { + try { + if (!runtime_supports_generation_mode(*runtime, IMG_GEN)) { + res.status = 400; + res.set_content(json({{"error", unsupported_generation_mode_error(IMG_GEN)}}).dump(), "application/json"); + return; + } + + ImgGenJobRequest request; + std::string error_message; + if 
(!build_openai_generation_request(req, *runtime, request, error_message)) { + res.status = 400; + res.set_content(json({{"error", error_message}}).dump(), "application/json"); + return; + } + + LOG_DEBUG("%s\n", request.gen_params.to_string().c_str()); + + SDImageVec results; + if (!execute_sync_img_gen_request(*runtime, request, results, error_message)) { + res.status = 500; + res.set_content(json({{"error", error_message}}).dump(), "application/json"); + return; + } + + json out; + out["created"] = static_cast(std::time(nullptr)); + out["data"] = json::array(); + out["output_format"] = request.output_format; + + for (int i = 0; i < request.gen_params.batch_count; ++i) { + if (results[i].data == nullptr) { + continue; + } + std::string params = request.gen_params.embed_image_metadata + ? get_image_params(*runtime->ctx_params, + request.gen_params, + request.gen_params.seed + i) + : ""; + auto image_bytes = encode_image_to_vector(request.output_format == "jpeg" + ? EncodedImageFormat::JPEG + : request.output_format == "webp" + ? 
EncodedImageFormat::WEBP + : EncodedImageFormat::PNG, + results[i].data, + results[i].width, + results[i].height, + results[i].channel, + params, + request.output_compression); + if (image_bytes.empty()) { + LOG_ERROR("write image to mem failed"); + continue; + } + + json item; + item["b64_json"] = base64_encode(image_bytes); + out["data"].push_back(item); + } + + res.set_content(out.dump(), "application/json"); + res.status = 200; + + } catch (const std::exception& e) { + res.status = 500; + json err; + err["error"] = "server_error"; + err["message"] = e.what(); + res.set_content(err.dump(), "application/json"); + } + }); + + svr.Post("/v1/images/edits", [runtime](const httplib::Request& req, httplib::Response& res) { + try { + if (!runtime_supports_generation_mode(*runtime, IMG_GEN)) { + res.status = 400; + res.set_content(json({{"error", unsupported_generation_mode_error(IMG_GEN)}}).dump(), "application/json"); + return; + } + + ImgGenJobRequest request; + std::string error_message; + if (!build_openai_edit_request(req, *runtime, request, error_message)) { + res.status = 400; + res.set_content(json({{"error", error_message}}).dump(), "application/json"); + return; + } + + LOG_DEBUG("%s\n", request.gen_params.to_string().c_str()); + + SDImageVec results; + if (!execute_sync_img_gen_request(*runtime, request, results, error_message)) { + res.status = 500; + res.set_content(json({{"error", error_message}}).dump(), "application/json"); + return; + } + + json out; + out["created"] = static_cast(std::time(nullptr)); + out["data"] = json::array(); + out["output_format"] = request.output_format; + + for (int i = 0; i < request.gen_params.batch_count; ++i) { + if (results[i].data == nullptr) { + continue; + } + std::string params = request.gen_params.embed_image_metadata + ? get_image_params(*runtime->ctx_params, + request.gen_params, + request.gen_params.seed + i) + : ""; + auto image_bytes = encode_image_to_vector(request.output_format == "jpeg" ? 
EncodedImageFormat::JPEG : EncodedImageFormat::PNG, + results[i].data, + results[i].width, + results[i].height, + results[i].channel, + params, + request.output_compression); + json item; + item["b64_json"] = base64_encode(image_bytes); + out["data"].push_back(item); + } + + res.set_content(out.dump(), "application/json"); + res.status = 200; + + } catch (const std::exception& e) { + res.status = 500; + json err; + err["error"] = "server_error"; + err["message"] = e.what(); + res.set_content(err.dump(), "application/json"); + } + }); +} diff --git a/examples/server/routes_sdapi.cpp b/examples/server/routes_sdapi.cpp new file mode 100644 index 000000000..1e01d2921 --- /dev/null +++ b/examples/server/routes_sdapi.cpp @@ -0,0 +1,469 @@ +#include "routes.h" + +#include +#include +#include +#include +#include +#include + +#include "common/common.h" +#include "common/media_io.h" +#include "common/resource_owners.hpp" + +namespace fs = std::filesystem; + +static std::string extract_and_remove_sd_cpp_extra_args(std::string& text) { + std::regex re("(.*?)"); + std::smatch match; + + std::string extracted; + if (std::regex_search(text, match, re)) { + extracted = match[1].str(); + text = std::regex_replace(text, re, ""); + } + return extracted; +} + +static fs::path resolve_display_model_path(const ServerRuntime& runtime) { + const auto& ctx = *runtime.ctx_params; + if (!ctx.model_path.empty()) { + return fs::path(ctx.model_path); + } + if (!ctx.diffusion_model_path.empty()) { + return fs::path(ctx.diffusion_model_path); + } + return {}; +} + +static std::string lower_ascii(std::string value) { + std::transform(value.begin(), value.end(), value.begin(), [](unsigned char c) { + return static_cast(std::tolower(c)); + }); + return value; +} + +static enum sample_method_t get_sdapi_sample_method(std::string name) { + enum sample_method_t result = str_to_sample_method(name.c_str()); + if (result != SAMPLE_METHOD_COUNT) { + return result; + } + + name = lower_ascii(name); + static 
const std::unordered_map hardcoded{ + {"euler a", EULER_A_SAMPLE_METHOD}, + {"k_euler_a", EULER_A_SAMPLE_METHOD}, + {"euler", EULER_SAMPLE_METHOD}, + {"k_euler", EULER_SAMPLE_METHOD}, + {"heun", HEUN_SAMPLE_METHOD}, + {"k_heun", HEUN_SAMPLE_METHOD}, + {"dpm2", DPM2_SAMPLE_METHOD}, + {"k_dpm_2", DPM2_SAMPLE_METHOD}, + {"lcm", LCM_SAMPLE_METHOD}, + {"ddim", DDIM_TRAILING_SAMPLE_METHOD}, + {"dpm++ 2m", DPMPP2M_SAMPLE_METHOD}, + {"k_dpmpp_2m", DPMPP2M_SAMPLE_METHOD}, + {"res multistep", RES_MULTISTEP_SAMPLE_METHOD}, + {"k_res_multistep", RES_MULTISTEP_SAMPLE_METHOD}, + {"res 2s", RES_2S_SAMPLE_METHOD}, + {"k_res_2s", RES_2S_SAMPLE_METHOD}, + }; + auto it = hardcoded.find(name); + return it != hardcoded.end() ? it->second : SAMPLE_METHOD_COUNT; +} + +static void assign_solid_mask(SDImageOwner& mask_owner, int width, int height) { + const size_t pixel_count = static_cast(width) * static_cast(height); + uint8_t* raw_mask = static_cast(malloc(pixel_count)); + if (raw_mask == nullptr) { + mask_owner.reset({0, 0, 1, nullptr}); + return; + } + std::memset(raw_mask, 255, pixel_count); + mask_owner.reset({(uint32_t)width, (uint32_t)height, 1, raw_mask}); +} + +static bool build_sdapi_img_gen_request(const json& j, + ServerRuntime& runtime, + bool img2img, + ImgGenJobRequest& request, + std::string& error_message) { + std::string prompt = j.value("prompt", ""); + std::string negative_prompt = j.value("negative_prompt", ""); + int width = j.value("width", 512); + int height = j.value("height", 512); + int steps = j.value("steps", runtime.default_gen_params->sample_params.sample_steps); + float cfg_scale = j.value("cfg_scale", runtime.default_gen_params->sample_params.guidance.txt_cfg); + int64_t seed = j.value("seed", -1); + int batch_size = j.value("batch_size", 1); + int clip_skip = j.value("clip_skip", -1); + std::string sampler_name = j.value("sampler_name", ""); + std::string scheduler_name = j.value("scheduler", ""); + + if (width <= 0 || height <= 0) { + error_message = 
"width and height must be positive"; + return false; + } + + if (prompt.empty()) { + error_message = "prompt required"; + return false; + } + + request.gen_params = *runtime.default_gen_params; + + request.gen_params.prompt = prompt; + request.gen_params.negative_prompt = negative_prompt; + request.gen_params.seed = seed; + request.gen_params.sample_params.sample_steps = steps; + request.gen_params.batch_count = batch_size; + request.gen_params.sample_params.guidance.txt_cfg = cfg_scale; + request.gen_params.width = j.value("width", -1); + request.gen_params.height = j.value("height", -1); + + if (!img2img && j.value("enable_hr", false)) { + request.gen_params.hires_enabled = true; + request.gen_params.hires_scale = j.value("hr_scale", request.gen_params.hires_scale); + request.gen_params.hires_width = j.value("hr_resize_x", request.gen_params.hires_width); + request.gen_params.hires_height = j.value("hr_resize_y", request.gen_params.hires_height); + request.gen_params.hires_steps = j.value("hr_steps", request.gen_params.hires_steps); + request.gen_params.hires_denoising_strength = + j.value("denoising_strength", request.gen_params.hires_denoising_strength); + + request.gen_params.hires_upscaler = j.value("hr_upscaler", request.gen_params.hires_upscaler); + } + + std::string sd_cpp_extra_args_str = extract_and_remove_sd_cpp_extra_args(request.gen_params.prompt); + if (!sd_cpp_extra_args_str.empty() && !request.gen_params.from_json_str(sd_cpp_extra_args_str)) { + error_message = "invalid sd_cpp_extra_args"; + return false; + } + + if (clip_skip > 0) { + request.gen_params.clip_skip = clip_skip; + } + + enum sample_method_t sample_method = get_sdapi_sample_method(sampler_name); + if (sample_method != SAMPLE_METHOD_COUNT) { + request.gen_params.sample_params.sample_method = sample_method; + } + + enum scheduler_t scheduler = str_to_scheduler(scheduler_name.c_str()); + if (scheduler != SCHEDULER_COUNT) { + request.gen_params.sample_params.scheduler = scheduler; + } + + 
if (j.contains("lora") && j["lora"].is_array()) { + request.gen_params.lora_map.clear(); + request.gen_params.high_noise_lora_map.clear(); + + for (const auto& item : j["lora"]) { + if (!item.is_object()) { + continue; + } + + std::string path = item.value("path", ""); + float multiplier = item.value("multiplier", 1.0f); + bool is_high_noise = item.value("is_high_noise", false); + + if (path.empty()) { + error_message = "lora.path required"; + return false; + } + + std::string fullpath = get_lora_full_path(runtime, path); + if (fullpath.empty()) { + error_message = "invalid lora path: " + path; + return false; + } + + if (is_high_noise) { + request.gen_params.high_noise_lora_map[fullpath] += multiplier; + } else { + request.gen_params.lora_map[fullpath] += multiplier; + } + } + } + + if (img2img) { + const int expected_width = request.gen_params.width_and_height_are_set() ? request.gen_params.width : 0; + const int expected_height = request.gen_params.width_and_height_are_set() ? request.gen_params.height : 0; + + if (j.contains("init_images") && j["init_images"].is_array() && !j["init_images"].empty()) { + if (decode_base64_image(j["init_images"][0].get(), + 3, + expected_width, + expected_height, + request.gen_params.init_image)) { + const sd_image_t& image = request.gen_params.init_image.get(); + request.gen_params.set_width_and_height_if_unset(image.width, image.height); + } + } + + if (j.contains("mask") && j["mask"].is_string()) { + if (decode_base64_image(j["mask"].get(), + 1, + expected_width, + expected_height, + request.gen_params.mask_image)) { + const sd_image_t& image = request.gen_params.mask_image.get(); + request.gen_params.set_width_and_height_if_unset(image.width, image.height); + } + sd_image_t& mask_image = request.gen_params.mask_image.get(); + bool inpainting_mask_invert = j.value("inpainting_mask_invert", 0) != 0; + if (inpainting_mask_invert && mask_image.data != nullptr) { + for (uint32_t i = 0; i < mask_image.width * mask_image.height; 
++i) { + mask_image.data[i] = 255 - mask_image.data[i]; + } + } + } else { + const int resolved_width = request.gen_params.get_resolved_width(); + const int resolved_height = request.gen_params.get_resolved_height(); + assign_solid_mask(request.gen_params.mask_image, resolved_width, resolved_height); + } + + float denoising_strength = j.value("denoising_strength", -1.f); + if (denoising_strength >= 0.f) { + request.gen_params.strength = std::min(denoising_strength, 1.0f); + } + } + + if (j.contains("extra_images") && j["extra_images"].is_array()) { + for (const auto& extra_image : j["extra_images"]) { + if (!extra_image.is_string()) { + continue; + } + SDImageOwner image_owner; + if (decode_base64_image(extra_image.get(), + 3, + request.gen_params.width_and_height_are_set() ? request.gen_params.width : 0, + request.gen_params.width_and_height_are_set() ? request.gen_params.height : 0, + image_owner)) { + const sd_image_t& image = image_owner.get(); + request.gen_params.set_width_and_height_if_unset(image.width, image.height); + request.gen_params.ref_images.push_back(std::move(image_owner)); + } + } + } + + // Intentionally disable prompt-embedded LoRA tag parsing for server APIs. 
+ if (!request.gen_params.resolve_and_validate(IMG_GEN, "", runtime.ctx_params->hires_upscalers_dir, true)) { + error_message = "invalid params"; + return false; + } + + return true; +} + +void register_sdapi_endpoints(httplib::Server& svr, ServerRuntime& rt) { + ServerRuntime* runtime = &rt; + + auto sdapi_any2img = [runtime](const httplib::Request& req, httplib::Response& res, bool img2img) { + try { + if (req.body.empty()) { + res.status = 400; + res.set_content(R"({"error":"empty body"})", "application/json"); + return; + } + if (!runtime_supports_generation_mode(*runtime, IMG_GEN)) { + res.status = 400; + res.set_content(json({{"error", unsupported_generation_mode_error(IMG_GEN)}}).dump(), "application/json"); + return; + } + + json j = json::parse(req.body); + ImgGenJobRequest request; + std::string error_message; + if (!build_sdapi_img_gen_request(j, *runtime, img2img, request, error_message)) { + res.status = 400; + res.set_content(json({{"error", error_message}}).dump(), "application/json"); + return; + } + + LOG_DEBUG("%s\n", request.gen_params.to_string().c_str()); + + sd_img_gen_params_t img_gen_params = request.to_sd_img_gen_params_t(); + SDImageVec results; + int num_results = 0; + + { + std::lock_guard lock(*runtime->sd_ctx_mutex); + sd_image_t* raw_results = generate_image(runtime->sd_ctx, &img_gen_params); + num_results = request.gen_params.batch_count; + results.adopt(raw_results, num_results); + } + + if (results.empty()) { + res.status = 500; + res.set_content(R"({"error":"generate_image returned no results"})", "application/json"); + return; + } + + json out; + out["images"] = json::array(); + out["parameters"] = j; + out["info"] = ""; + + for (int i = 0; i < num_results; ++i) { + if (results[i].data == nullptr) { + continue; + } + + std::string params = request.gen_params.embed_image_metadata + ? 
get_image_params(*runtime->ctx_params, + request.gen_params, + request.gen_params.seed + i) + : ""; + auto image_bytes = encode_image_to_vector(EncodedImageFormat::PNG, + results[i].data, + results[i].width, + results[i].height, + results[i].channel, + params); + + if (image_bytes.empty()) { + LOG_ERROR("write image to mem failed"); + continue; + } + + out["images"].push_back(base64_encode(image_bytes)); + } + + res.set_content(out.dump(), "application/json"); + res.status = 200; + + } catch (const std::exception& e) { + res.status = 500; + json err; + err["error"] = "server_error"; + err["message"] = e.what(); + res.set_content(err.dump(), "application/json"); + } + }; + + svr.Post("/sdapi/v1/txt2img", [sdapi_any2img](const httplib::Request& req, httplib::Response& res) { + sdapi_any2img(req, res, false); + }); + + svr.Post("/sdapi/v1/img2img", [sdapi_any2img](const httplib::Request& req, httplib::Response& res) { + sdapi_any2img(req, res, true); + }); + + svr.Get("/sdapi/v1/loras", [runtime](const httplib::Request&, httplib::Response& res) { + refresh_lora_cache(*runtime); + + json result = json::array(); + { + std::lock_guard lock(*runtime->lora_mutex); + for (const auto& e : *runtime->lora_cache) { + json item; + item["name"] = e.name; + item["path"] = e.path; + result.push_back(item); + } + } + + res.set_content(result.dump(), "application/json"); + }); + + svr.Get("/sdapi/v1/upscalers", [runtime](const httplib::Request&, httplib::Response& res) { + refresh_upscaler_cache(*runtime); + + auto make_builtin = [](const char* name) { + json item; + item["name"] = name; + item["model_name"] = nullptr; + item["model_path"] = nullptr; + item["model_url"] = nullptr; + item["scale"] = 4; + return item; + }; + + json result = json::array(); + result.push_back(make_builtin("None")); + result.push_back(make_builtin("Lanczos")); + result.push_back(make_builtin("Nearest")); + + { + std::lock_guard lock(*runtime->upscaler_mutex); + for (const auto& e : 
*runtime->upscaler_cache) { + json item; + item["name"] = e.name; + item["model_name"] = e.model_name; + item["model_path"] = e.fullpath; + item["model_url"] = nullptr; + item["scale"] = e.scale; + result.push_back(item); + } + } + + res.set_content(result.dump(), "application/json"); + }); + + svr.Get("/sdapi/v1/latent-upscale-modes", [](const httplib::Request&, httplib::Response& res) { + json result = json::array({ + {{"name", "Latent"}}, + {{"name", "Latent (nearest)"}}, + {{"name", "Latent (nearest-exact)"}}, + {{"name", "Latent (antialiased)"}}, + {{"name", "Latent (bicubic)"}}, + {{"name", "Latent (bicubic antialiased)"}}, + }); + res.set_content(result.dump(), "application/json"); + }); + + svr.Get("/sdapi/v1/samplers", [runtime](const httplib::Request&, httplib::Response& res) { + std::vector sampler_names; + sampler_names.push_back("default"); + for (int i = 0; i < SAMPLE_METHOD_COUNT; i++) { + sampler_names.push_back(sd_sample_method_name((sample_method_t)i)); + } + json r = json::array(); + for (auto name : sampler_names) { + json entry; + entry["name"] = name; + entry["aliases"] = json::array({name}); + entry["options"] = json::object(); + r.push_back(entry); + } + res.set_content(r.dump(), "application/json"); + }); + + svr.Get("/sdapi/v1/schedulers", [runtime](const httplib::Request&, httplib::Response& res) { + std::vector scheduler_names; + scheduler_names.push_back("default"); + for (int i = 0; i < SCHEDULER_COUNT; i++) { + scheduler_names.push_back(sd_scheduler_name((scheduler_t)i)); + } + json r = json::array(); + for (auto name : scheduler_names) { + json entry; + entry["name"] = name; + entry["label"] = name; + r.push_back(entry); + } + res.set_content(r.dump(), "application/json"); + }); + + svr.Get("/sdapi/v1/sd-models", [runtime](const httplib::Request&, httplib::Response& res) { + fs::path model_path = resolve_display_model_path(*runtime); + json entry; + entry["title"] = model_path.stem(); + entry["model_name"] = model_path.stem(); + 
entry["filename"] = model_path.filename(); + entry["hash"] = "8888888888"; + entry["sha256"] = "8888888888888888888888888888888888888888888888888888888888888888"; + entry["config"] = nullptr; + json r = json::array(); + r.push_back(entry); + res.set_content(r.dump(), "application/json"); + }); + + svr.Get("/sdapi/v1/options", [runtime](const httplib::Request&, httplib::Response& res) { + fs::path model_path = resolve_display_model_path(*runtime); + json r; + r["samples_format"] = "png"; + r["sd_model_checkpoint"] = model_path.stem(); + res.set_content(r.dump(), "application/json"); + }); +} diff --git a/examples/server/routes_sdcpp.cpp b/examples/server/routes_sdcpp.cpp new file mode 100644 index 000000000..16fe0af40 --- /dev/null +++ b/examples/server/routes_sdcpp.cpp @@ -0,0 +1,588 @@ +#include "routes.h" + +#include +#include +#include + +#include "async_jobs.h" +#include "common/common.h" + +namespace fs = std::filesystem; + +static bool parse_cache_mode(const std::string& mode_str, sd_cache_mode_t& mode_out) { + if (mode_str == "disabled") { + mode_out = SD_CACHE_DISABLED; + return true; + } + if (mode_str == "easycache") { + mode_out = SD_CACHE_EASYCACHE; + return true; + } + if (mode_str == "ucache") { + mode_out = SD_CACHE_UCACHE; + return true; + } + if (mode_str == "dbcache") { + mode_out = SD_CACHE_DBCACHE; + return true; + } + if (mode_str == "taylorseer") { + mode_out = SD_CACHE_TAYLORSEER; + return true; + } + if (mode_str == "cache-dit") { + mode_out = SD_CACHE_CACHE_DIT; + return true; + } + if (mode_str == "spectrum") { + mode_out = SD_CACHE_SPECTRUM; + return true; + } + return false; +} + +static json finite_number_or_null(float value) { + return std::isfinite(value) ? json(value) : json(nullptr); +} + +static const char* capability_scheduler_name(enum scheduler_t scheduler) { + return scheduler < SCHEDULER_COUNT ? 
sd_scheduler_name(scheduler) : "default"; +} + +static const char* capability_sample_method_name(enum sample_method_t sample_method) { + return sample_method < SAMPLE_METHOD_COUNT ? sd_sample_method_name(sample_method) : "default"; +} + +static json make_vae_tiling_json(const sd_tiling_params_t& params) { + return { + {"enabled", params.enabled}, + {"tile_size_x", params.tile_size_x}, + {"tile_size_y", params.tile_size_y}, + {"target_overlap", params.target_overlap}, + {"rel_size_x", params.rel_size_x}, + {"rel_size_y", params.rel_size_y}, + }; +} + +static fs::path resolve_display_model_path(const ServerRuntime& runtime) { + const auto& ctx = *runtime.ctx_params; + if (!ctx.model_path.empty()) { + return fs::path(ctx.model_path); + } + if (!ctx.diffusion_model_path.empty()) { + return fs::path(ctx.diffusion_model_path); + } + return {}; +} + +static json make_sample_params_json(const sd_sample_params_t& sample_params, const std::vector& skip_layers) { + const auto& guidance = sample_params.guidance; + return { + {"scheduler", capability_scheduler_name(sample_params.scheduler)}, + {"sample_method", capability_sample_method_name(sample_params.sample_method)}, + {"sample_steps", sample_params.sample_steps}, + {"eta", finite_number_or_null(sample_params.eta)}, + {"shifted_timestep", sample_params.shifted_timestep}, + {"flow_shift", finite_number_or_null(sample_params.flow_shift)}, + {"guidance", + { + {"txt_cfg", guidance.txt_cfg}, + {"img_cfg", finite_number_or_null(guidance.img_cfg)}, + {"distilled_guidance", guidance.distilled_guidance}, + {"slg", + { + {"layers", skip_layers}, + {"layer_start", guidance.slg.layer_start}, + {"layer_end", guidance.slg.layer_end}, + {"scale", guidance.slg.scale}, + }}, + }}, + }; +} + +static json make_img_gen_defaults_json(const SDGenerationParams& defaults, const std::string& output_format) { + return { + {"prompt", defaults.prompt}, + {"negative_prompt", defaults.negative_prompt}, + {"clip_skip", defaults.clip_skip}, + {"width", 
defaults.width > 0 ? defaults.width : 512}, + {"height", defaults.height > 0 ? defaults.height : 512}, + {"strength", defaults.strength}, + {"seed", defaults.seed}, + {"batch_count", defaults.batch_count}, + {"auto_resize_ref_image", defaults.auto_resize_ref_image}, + {"increase_ref_index", defaults.increase_ref_index}, + {"control_strength", defaults.control_strength}, + {"sample_params", make_sample_params_json(defaults.sample_params, defaults.skip_layers)}, + {"hires", + { + {"enabled", defaults.hires_enabled}, + {"upscaler", defaults.hires_upscaler}, + {"scale", defaults.hires_scale}, + {"target_width", defaults.hires_width}, + {"target_height", defaults.hires_height}, + {"steps", defaults.hires_steps}, + {"denoising_strength", defaults.hires_denoising_strength}, + {"upscale_tile_size", defaults.hires_upscale_tile_size}, + }}, + {"vae_tiling_params", make_vae_tiling_json(defaults.vae_tiling_params)}, + {"cache_mode", defaults.cache_mode}, + {"cache_option", defaults.cache_option}, + {"scm_mask", defaults.scm_mask}, + {"scm_policy_dynamic", defaults.scm_policy_dynamic}, + {"output_format", output_format}, + {"output_compression", 100}, + }; +} + +static json make_vid_gen_defaults_json(const SDGenerationParams& defaults, const std::string& output_format) { + return { + {"prompt", defaults.prompt}, + {"negative_prompt", defaults.negative_prompt}, + {"clip_skip", defaults.clip_skip}, + {"width", defaults.width > 0 ? defaults.width : 512}, + {"height", defaults.height > 0 ? 
defaults.height : 512}, + {"strength", defaults.strength}, + {"seed", defaults.seed}, + {"video_frames", defaults.video_frames}, + {"fps", defaults.fps}, + {"moe_boundary", defaults.moe_boundary}, + {"vace_strength", defaults.vace_strength}, + {"sample_params", make_sample_params_json(defaults.sample_params, defaults.skip_layers)}, + {"high_noise_sample_params", make_sample_params_json(defaults.high_noise_sample_params, defaults.high_noise_skip_layers)}, + {"vae_tiling_params", make_vae_tiling_json(defaults.vae_tiling_params)}, + {"cache_mode", defaults.cache_mode}, + {"cache_option", defaults.cache_option}, + {"scm_mask", defaults.scm_mask}, + {"scm_policy_dynamic", defaults.scm_policy_dynamic}, + {"output_format", output_format}, + {"output_compression", 100}, + }; +} + +static json make_img_gen_features_json() { + return { + {"init_image", true}, + {"mask_image", true}, + {"control_image", true}, + {"ref_images", true}, + {"lora", true}, + {"vae_tiling", true}, + {"hires", true}, + {"cache", true}, + {"cancel_queued", true}, + {"cancel_generating", false}, + }; +} + +static json make_vid_gen_features_json() { + return { + {"init_image", true}, + {"end_image", true}, + {"control_frames", true}, + {"high_noise_sample_params", true}, + {"lora", true}, + {"vae_tiling", true}, + {"cache", true}, + {"cancel_queued", true}, + {"cancel_generating", false}, + }; +} + +static json make_capabilities_json(ServerRuntime& runtime) { + refresh_lora_cache(runtime); + refresh_upscaler_cache(runtime); + + AsyncJobManager& manager = *runtime.async_job_manager; + const auto& defaults = *runtime.default_gen_params; + const fs::path model_path = resolve_display_model_path(runtime); + const bool supports_img = runtime_supports_generation_mode(runtime, IMG_GEN); + const bool supports_vid = runtime_supports_generation_mode(runtime, VID_GEN); + json samplers = json::array(); + json schedulers = json::array(); + json image_output_formats = supported_img_output_formats(); + json 
video_output_formats = supported_vid_output_formats(); + json available_loras = json::array(); + json available_upscalers = json::array(); + json supported_modes = json::array(); + + for (int i = 0; i < SAMPLE_METHOD_COUNT; ++i) { + samplers.push_back(sd_sample_method_name((sample_method_t)i)); + } + + for (int i = 0; i < SCHEDULER_COUNT; ++i) { + schedulers.push_back(sd_scheduler_name((scheduler_t)i)); + } + + { + std::lock_guard lock(*runtime.lora_mutex); + for (const auto& entry : *runtime.lora_cache) { + available_loras.push_back({ + {"name", entry.name}, + {"path", entry.path}, + }); + } + } + + available_upscalers.push_back({ + {"name", "None"}, + }); + available_upscalers.push_back({ + {"name", "Lanczos"}, + }); + available_upscalers.push_back({ + {"name", "Nearest"}, + }); + available_upscalers.push_back({ + {"name", "Latent"}, + }); + available_upscalers.push_back({ + {"name", "Latent (nearest)"}, + }); + available_upscalers.push_back({ + {"name", "Latent (nearest-exact)"}, + }); + available_upscalers.push_back({ + {"name", "Latent (antialiased)"}, + }); + available_upscalers.push_back({ + {"name", "Latent (bicubic)"}, + }); + available_upscalers.push_back({ + {"name", "Latent (bicubic antialiased)"}, + }); + { + std::lock_guard lock(*runtime.upscaler_mutex); + for (const auto& entry : *runtime.upscaler_cache) { + available_upscalers.push_back({ + {"name", entry.name}, + }); + } + } + + if (supports_img) { + supported_modes.push_back("img_gen"); + } + if (supports_vid) { + supported_modes.push_back("vid_gen"); + } + + std::string default_img_output_format = "png"; + std::string default_vid_output_format = "avi"; + if (!image_output_formats.empty()) { + default_img_output_format = image_output_formats[0].get(); + } + if (!video_output_formats.empty()) { + default_vid_output_format = video_output_formats[0].get(); + } + + json defaults_by_mode = json::object(); + json output_formats_by_mode = json::object(); + json features_by_mode = json::object(); + if 
(supports_img) { + defaults_by_mode["img_gen"] = make_img_gen_defaults_json(defaults, default_img_output_format); + output_formats_by_mode["img_gen"] = image_output_formats; + features_by_mode["img_gen"] = make_img_gen_features_json(); + } + if (supports_vid) { + defaults_by_mode["vid_gen"] = make_vid_gen_defaults_json(defaults, default_vid_output_format); + output_formats_by_mode["vid_gen"] = video_output_formats; + features_by_mode["vid_gen"] = make_vid_gen_features_json(); + } + + json top_level_defaults = json::object(); + json top_level_output_formats = json::array(); + json top_level_features = { + {"cancel_queued", true}, + {"cancel_generating", false}, + }; + std::string current_mode = ""; + if (supports_img) { + current_mode = "img_gen"; + top_level_defaults = defaults_by_mode["img_gen"]; + top_level_output_formats = output_formats_by_mode["img_gen"]; + top_level_features = features_by_mode["img_gen"]; + } else if (supports_vid) { + current_mode = "vid_gen"; + top_level_defaults = defaults_by_mode["vid_gen"]; + top_level_output_formats = output_formats_by_mode["vid_gen"]; + top_level_features = features_by_mode["vid_gen"]; + } + + json result; + result["model"] = { + {"name", model_path.filename().u8string()}, + {"stem", model_path.stem().u8string()}, + {"path", model_path.u8string()}, + }; + result["current_mode"] = current_mode; + result["supported_modes"] = supported_modes; + result["defaults"] = top_level_defaults; + result["defaults_by_mode"] = defaults_by_mode; + result["limits"] = { + {"min_width", 64}, + {"max_width", 4096}, + {"min_height", 64}, + {"max_height", 4096}, + {"max_batch_count", 8}, + {"max_queue_size", manager.max_pending_jobs}, + }; + result["samplers"] = samplers; + result["schedulers"] = schedulers; + result["output_formats"] = top_level_output_formats; + result["output_formats_by_mode"] = output_formats_by_mode; + result["features"] = top_level_features; + result["features_by_mode"] = features_by_mode; + result["loras"] = 
available_loras; + result["upscalers"] = available_upscalers; + return result; +} + +static bool parse_img_gen_request(const json& body, + ServerRuntime& runtime, + ImgGenJobRequest& request, + std::string& error_message) { + request.gen_params = *runtime.default_gen_params; + + refresh_lora_cache(runtime); + if (!request.gen_params.from_json_str(body.dump(), [&](const std::string& path) { + return get_lora_full_path(runtime, path); + })) { + error_message = "invalid generation parameters"; + return false; + } + + std::string output_format = body.value("output_format", "png"); + int output_compression = body.value("output_compression", 100); + if (!assign_output_options(request, output_format, output_compression, true, error_message)) { + return false; + } + // Intentionally disable prompt-embedded LoRA tag parsing for server APIs. + if (!request.gen_params.resolve_and_validate(IMG_GEN, "", runtime.ctx_params->hires_upscalers_dir, true)) { + error_message = "invalid generation parameters"; + return false; + } + return true; +} + +static bool parse_vid_gen_request(const json& body, + ServerRuntime& runtime, + VidGenJobRequest& request, + std::string& error_message) { + request.gen_params = *runtime.default_gen_params; + + refresh_lora_cache(runtime); + if (!request.gen_params.from_json_str(body.dump(), [&](const std::string& path) { + return get_lora_full_path(runtime, path); + })) { + error_message = "invalid generation parameters"; + return false; + } + + std::string output_format = body.value("output_format", "webm"); + int output_compression = body.value("output_compression", 100); + if (!assign_output_options(request, output_format, output_compression, error_message)) { + return false; + } + // Intentionally disable prompt-embedded LoRA tag parsing for server APIs. 
+ if (!request.gen_params.resolve_and_validate(VID_GEN, "", runtime.ctx_params->hires_upscalers_dir, true)) { + error_message = "invalid generation parameters"; + return false; + } + return true; +} + +void register_sdcpp_api_endpoints(httplib::Server& svr, ServerRuntime& rt) { + ServerRuntime* runtime = &rt; + + svr.Get("/sdcpp/v1/capabilities", [runtime](const httplib::Request&, httplib::Response& res) { + res.status = 200; + res.set_content(make_capabilities_json(*runtime).dump(), "application/json"); + }); + + svr.Post("/sdcpp/v1/img_gen", [runtime](const httplib::Request& req, httplib::Response& res) { + try { + if (req.body.empty()) { + res.status = 400; + res.set_content(R"({"error":"empty body"})", "application/json"); + return; + } + if (!runtime_supports_generation_mode(*runtime, IMG_GEN)) { + res.status = 400; + res.set_content(json({{"error", unsupported_generation_mode_error(IMG_GEN)}}).dump(), "application/json"); + return; + } + + json body = json::parse(req.body); + ImgGenJobRequest request; + std::string error_message; + if (!parse_img_gen_request(body, *runtime, request, error_message)) { + res.status = 400; + res.set_content(json({{"error", error_message}}).dump(), "application/json"); + return; + } + + AsyncJobManager& manager = *runtime->async_job_manager; + std::shared_ptr job = std::make_shared(); + job->kind = AsyncJobKind::ImgGen; + job->status = AsyncJobStatus::Queued; + job->created_at = unix_timestamp_now(); + job->img_gen = std::move(request); + + { + std::lock_guard lock(manager.mutex); + purge_expired_jobs(manager); + if (count_pending_jobs(manager) >= manager.max_pending_jobs) { + res.status = 429; + res.set_content(R"({"error":"job queue is full"})", "application/json"); + return; + } + job->id = make_async_job_id(manager); + manager.jobs[job->id] = job; + manager.queue.push_back(job->id); + } + + manager.cv.notify_one(); + + json out; + out["id"] = job->id; + out["kind"] = async_job_kind_name(job->kind); + out["status"] = 
async_job_status_name(job->status); + out["created"] = job->created_at; + out["poll_url"] = "/sdcpp/v1/jobs/" + job->id; + + res.status = 202; + res.set_content(out.dump(), "application/json"); + } catch (const json::parse_error& e) { + res.status = 400; + res.set_content(json({{"error", "invalid json"}, {"message", e.what()}}).dump(), "application/json"); + } catch (const std::exception& e) { + res.status = 500; + res.set_content(json({{"error", "server_error"}, {"message", e.what()}}).dump(), "application/json"); + } + }); + + svr.Post("/sdcpp/v1/vid_gen", [runtime](const httplib::Request& req, httplib::Response& res) { + try { + if (req.body.empty()) { + res.status = 400; + res.set_content(R"({"error":"empty body"})", "application/json"); + return; + } + if (!runtime_supports_generation_mode(*runtime, VID_GEN)) { + res.status = 400; + res.set_content(json({{"error", unsupported_generation_mode_error(VID_GEN)}}).dump(), "application/json"); + return; + } + + json body = json::parse(req.body); + VidGenJobRequest request; + std::string error_message; + if (!parse_vid_gen_request(body, *runtime, request, error_message)) { + res.status = 400; + res.set_content(json({{"error", error_message}}).dump(), "application/json"); + return; + } + + AsyncJobManager& manager = *runtime->async_job_manager; + std::shared_ptr job = std::make_shared(); + job->kind = AsyncJobKind::VidGen; + job->status = AsyncJobStatus::Queued; + job->created_at = unix_timestamp_now(); + job->vid_gen = std::move(request); + + { + std::lock_guard lock(manager.mutex); + purge_expired_jobs(manager); + if (count_pending_jobs(manager) >= manager.max_pending_jobs) { + res.status = 429; + res.set_content(R"({"error":"job queue is full"})", "application/json"); + return; + } + job->id = make_async_job_id(manager); + manager.jobs[job->id] = job; + manager.queue.push_back(job->id); + } + + manager.cv.notify_one(); + + json out; + out["id"] = job->id; + out["kind"] = async_job_kind_name(job->kind); + 
out["status"] = async_job_status_name(job->status); + out["created"] = job->created_at; + out["poll_url"] = "/sdcpp/v1/jobs/" + job->id; + + res.status = 202; + res.set_content(out.dump(), "application/json"); + } catch (const json::parse_error& e) { + res.status = 400; + res.set_content(json({{"error", "invalid json"}, {"message", e.what()}}).dump(), "application/json"); + } catch (const std::exception& e) { + res.status = 500; + res.set_content(json({{"error", "server_error"}, {"message", e.what()}}).dump(), "application/json"); + } + }); + + svr.Get(R"(/sdcpp/v1/jobs/([A-Za-z0-9_\-]+))", [runtime](const httplib::Request& req, httplib::Response& res) { + AsyncJobManager& manager = *runtime->async_job_manager; + std::lock_guard lock(manager.mutex); + purge_expired_jobs(manager); + + std::string job_id = req.matches[1]; + auto it = manager.jobs.find(job_id); + if (it == manager.jobs.end()) { + if (manager.expired_jobs.find(job_id) != manager.expired_jobs.end()) { + res.status = 410; + res.set_content(R"({"error":"job expired"})", "application/json"); + } else { + res.status = 404; + res.set_content(R"({"error":"job not found"})", "application/json"); + } + return; + } + + res.status = 200; + res.set_content(make_async_job_json(manager, *it->second).dump(), "application/json"); + }); + + svr.Post(R"(/sdcpp/v1/jobs/([A-Za-z0-9_\-]+)/cancel)", [runtime](const httplib::Request& req, httplib::Response& res) { + AsyncJobManager& manager = *runtime->async_job_manager; + std::lock_guard lock(manager.mutex); + purge_expired_jobs(manager); + + std::string job_id = req.matches[1]; + auto it = manager.jobs.find(job_id); + if (it == manager.jobs.end()) { + if (manager.expired_jobs.find(job_id) != manager.expired_jobs.end()) { + res.status = 410; + res.set_content(R"({"error":"job expired"})", "application/json"); + } else { + res.status = 404; + res.set_content(R"({"error":"job not found"})", "application/json"); + } + return; + } + + auto& job = *it->second; + if (job.status 
== AsyncJobStatus::Queued) { + if (!cancel_queued_job(manager, job)) { + res.status = 409; + res.set_content(R"({"error":"job queue state changed before cancellation"})", "application/json"); + return; + } + res.status = 200; + res.set_content(make_async_job_json(manager, job).dump(), "application/json"); + return; + } + + if (job.status == AsyncJobStatus::Generating) { + res.status = 409; + res.set_content(R"({"error":"job is currently generating and cannot be interrupted yet"})", "application/json"); + return; + } + + res.status = 200; + res.set_content(make_async_job_json(manager, job).dump(), "application/json"); + }); +} diff --git a/examples/server/runtime.cpp b/examples/server/runtime.cpp new file mode 100644 index 000000000..afadb62ae --- /dev/null +++ b/examples/server/runtime.cpp @@ -0,0 +1,332 @@ +#include "runtime.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "common/common.h" +#include "common/log.h" + +namespace fs = std::filesystem; + +static std::string lower_ascii(std::string value) { + std::transform(value.begin(), value.end(), value.begin(), [](unsigned char c) { + return static_cast(std::tolower(c)); + }); + return value; +} + +static bool is_supported_model_ext(const fs::path& p) { + auto ext = lower_ascii(p.extension().string()); + return ext == ".gguf" || ext == ".pt" || ext == ".pth" || ext == ".safetensors"; +} + +static const std::string k_base64_chars = + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz" + "0123456789+/"; + +std::string base64_encode(const std::vector& bytes) { + std::string ret; + int val = 0; + int valb = -6; + for (uint8_t c : bytes) { + val = (val << 8) + c; + valb += 8; + while (valb >= 0) { + ret.push_back(k_base64_chars[(val >> valb) & 0x3F]); + valb -= 6; + } + } + if (valb > -6) { + ret.push_back(k_base64_chars[((val << 8) >> (valb + 8)) & 0x3F]); + } + while (ret.size() % 4) { + ret.push_back('='); + } + return ret; +} + +std::string 
normalize_output_format(std::string output_format) { + std::transform(output_format.begin(), output_format.end(), output_format.begin(), + [](unsigned char c) { return static_cast(std::tolower(c)); }); + return output_format; +} + +std::vector supported_img_output_formats(bool allow_webp) { + std::vector formats = {"png", "jpeg"}; +#ifdef SD_USE_WEBP + if (allow_webp) { + formats.push_back("webp"); + } +#else + (void)allow_webp; +#endif + return formats; +} + +std::vector supported_vid_output_formats() { + std::vector formats; +#ifdef SD_USE_WEBM + formats.push_back("webm"); +#endif +#ifdef SD_USE_WEBP + formats.push_back("webp"); +#endif + formats.push_back("avi"); + return formats; +} + +static std::string valid_vid_output_formats_message() { + const std::vector formats = supported_vid_output_formats(); + + std::string message = "invalid output_format, must be one of ["; + for (size_t i = 0; i < formats.size(); ++i) { + if (i > 0) { + message += ", "; + } + message += formats[i]; + } + message += "]"; + return message; +} + +bool assign_output_options(ImgGenJobRequest& request, + std::string output_format, + int output_compression, + bool allow_webp, + std::string& error_message) { + request.output_format = normalize_output_format(std::move(output_format)); + request.output_compression = std::clamp(output_compression, 0, 100); + + const std::vector valid_formats = supported_img_output_formats(allow_webp); + const bool valid_format = std::find(valid_formats.begin(), + valid_formats.end(), + request.output_format) != valid_formats.end(); + if (!valid_format) { + error_message = "invalid output_format, must be one of ["; + for (size_t i = 0; i < valid_formats.size(); ++i) { + if (i > 0) { + error_message += ", "; + } + error_message += valid_formats[i]; + } + error_message += "]"; + return false; + } + + return true; +} + +bool assign_output_options(VidGenJobRequest& request, + std::string output_format, + int output_compression, + std::string& error_message) { + 
request.output_format = normalize_output_format(std::move(output_format)); + request.output_compression = std::clamp(output_compression, 0, 100); + + if (request.output_format == "avi") { + return true; + } + + if (request.output_format == "webm") { +#ifdef SD_USE_WEBM + return true; +#else + error_message = valid_vid_output_formats_message(); + return false; +#endif + } + + if (request.output_format == "webp") { +#ifdef SD_USE_WEBP + return true; +#else + error_message = valid_vid_output_formats_message(); + return false; +#endif + } + + error_message = valid_vid_output_formats_message(); + return false; +} + +std::string video_mime_type(const std::string& output_format) { + if (output_format == "webm") { + return "video/webm"; + } + if (output_format == "webp") { + return "image/webp"; + } + return "video/x-msvideo"; +} + +bool runtime_supports_generation_mode(const ServerRuntime& runtime, SDMode mode) { + if (mode == VID_GEN) { + return sd_ctx_supports_video_generation(runtime.sd_ctx); + } + if (mode == IMG_GEN) { + return sd_ctx_supports_image_generation(runtime.sd_ctx); + } + return true; +} + +std::string unsupported_generation_mode_error(SDMode mode) { + if (mode == VID_GEN) { + return "loaded model does not support vid_gen"; + } + if (mode == IMG_GEN) { + return "loaded model does not support img_gen"; + } + return "loaded model does not support requested mode"; +} + +ArgOptions SDSvrParams::get_options() { + ArgOptions options; + + options.string_options = { + {"-l", "--listen-ip", "server listen ip (default: 127.0.0.1)", &listen_ip}, + {"", "--serve-html-path", "path to HTML file to serve at root (optional)", &serve_html_path}, + }; + + options.int_options = { + {"", "--listen-port", "server listen port (default: 1234)", &listen_port}, + }; + + options.bool_options = { + {"-v", "--verbose", "print extra info", true, &verbose}, + {"", "--color", "colors the logging tags according to level", true, &color}, + }; + + auto on_help_arg = [&](int, const char**, 
int) { + normal_exit = true; + return -1; + }; + + options.manual_options = { + {"-h", "--help", "show this help message and exit", on_help_arg}, + }; + return options; +} + +bool SDSvrParams::validate() { + if (listen_ip.empty()) { + LOG_ERROR("error: the following arguments are required: listen_ip"); + return false; + } + + if (listen_port < 0 || listen_port > 65535) { + LOG_ERROR("error: listen_port should be in the range [0, 65535]"); + return false; + } + + if (!serve_html_path.empty() && !fs::exists(serve_html_path)) { + LOG_ERROR("error: serve_html_path file does not exist: %s", serve_html_path.c_str()); + return false; + } + return true; +} + +bool SDSvrParams::resolve_and_validate() { + if (!validate()) { + return false; + } + return true; +} + +std::string SDSvrParams::to_string() const { + std::ostringstream oss; + oss << "SDSvrParams {\n" + << " listen_ip: " << listen_ip << ",\n" + << " listen_port: \"" << listen_port << "\",\n" + << " serve_html_path: \"" << serve_html_path << "\",\n" + << "}"; + return oss.str(); +} + +void refresh_lora_cache(ServerRuntime& rt) { + std::vector new_cache; + + fs::path lora_dir = rt.ctx_params->lora_model_dir; + if (fs::exists(lora_dir) && fs::is_directory(lora_dir)) { + for (auto& entry : fs::recursive_directory_iterator(lora_dir)) { + if (!entry.is_regular_file()) { + continue; + } + const fs::path& p = entry.path(); + if (!is_supported_model_ext(p)) { + continue; + } + + LoraEntry lora_entry; + lora_entry.name = p.stem().u8string(); + lora_entry.fullpath = p.u8string(); + std::string rel = p.lexically_relative(lora_dir).u8string(); + std::replace(rel.begin(), rel.end(), '\\', '/'); + lora_entry.path = rel; + + new_cache.push_back(std::move(lora_entry)); + } + } + + std::sort(new_cache.begin(), new_cache.end(), [](const LoraEntry& a, const LoraEntry& b) { + return a.path < b.path; + }); + + { + std::lock_guard lock(*rt.lora_mutex); + *rt.lora_cache = std::move(new_cache); + } +} + +std::string 
get_lora_full_path(ServerRuntime& rt, const std::string& path) { + std::lock_guard lock(*rt.lora_mutex); + auto it = std::find_if(rt.lora_cache->begin(), rt.lora_cache->end(), + [&](const LoraEntry& entry) { return entry.path == path; }); + return it != rt.lora_cache->end() ? it->fullpath : ""; +} + +void refresh_upscaler_cache(ServerRuntime& rt) { + std::vector new_cache; + + fs::path upscaler_dir = rt.ctx_params->hires_upscalers_dir; + if (fs::exists(upscaler_dir) && fs::is_directory(upscaler_dir)) { + for (auto& entry : fs::directory_iterator(upscaler_dir)) { + if (!entry.is_regular_file()) { + continue; + } + const fs::path& p = entry.path(); + if (!is_supported_model_ext(p)) { + continue; + } + + UpscalerEntry upscaler_entry; + upscaler_entry.name = p.stem().u8string(); + upscaler_entry.fullpath = fs::absolute(p).lexically_normal().u8string(); + upscaler_entry.model_name = "ESRGAN_4x"; + upscaler_entry.path = p.filename().u8string(); + + new_cache.push_back(std::move(upscaler_entry)); + } + } + + std::sort(new_cache.begin(), new_cache.end(), [](const UpscalerEntry& a, const UpscalerEntry& b) { + return a.name < b.name; + }); + + { + std::lock_guard lock(*rt.upscaler_mutex); + *rt.upscaler_cache = std::move(new_cache); + } +} + +int64_t unix_timestamp_now() { + return std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(); +} diff --git a/examples/server/runtime.h b/examples/server/runtime.h new file mode 100644 index 000000000..5c5f2d480 --- /dev/null +++ b/examples/server/runtime.h @@ -0,0 +1,100 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include "common/common.h" +#include "common/resource_owners.hpp" +#include "stable-diffusion.h" + +using json = nlohmann::json; + +struct ArgOptions; +struct SDContextParams; +struct AsyncJobManager; + +struct SDSvrParams { + std::string listen_ip = "127.0.0.1"; + int listen_port = 1234; + std::string serve_html_path; + bool normal_exit = false; + 
bool verbose = false; + bool color = false; + + ArgOptions get_options(); + bool validate(); + bool resolve_and_validate(); + std::string to_string() const; +}; + +struct LoraEntry { + std::string name; + std::string path; + std::string fullpath; +}; + +struct UpscalerEntry { + std::string name; + std::string path; + std::string fullpath; + std::string model_name; + int scale = 4; +}; + +struct ServerRuntime { + sd_ctx_t* sd_ctx; + std::mutex* sd_ctx_mutex; + const SDSvrParams* svr_params; + const SDContextParams* ctx_params; + const SDGenerationParams* default_gen_params; + std::vector* lora_cache; + std::mutex* lora_mutex; + std::vector* upscaler_cache; + std::mutex* upscaler_mutex; + AsyncJobManager* async_job_manager; +}; + +struct ImgGenJobRequest { + SDGenerationParams gen_params; + std::string output_format = "png"; + int output_compression = 100; + + sd_img_gen_params_t to_sd_img_gen_params_t() { + return gen_params.to_sd_img_gen_params_t(); + } +}; + +struct VidGenJobRequest { + SDGenerationParams gen_params; + std::string output_format = "webm"; + int output_compression = 100; + + sd_vid_gen_params_t to_sd_vid_gen_params_t() { + return gen_params.to_sd_vid_gen_params_t(); + } +}; + +std::string base64_encode(const std::vector& bytes); +std::string normalize_output_format(std::string output_format); +std::vector supported_img_output_formats(bool allow_webp = true); +std::vector supported_vid_output_formats(); +bool assign_output_options(ImgGenJobRequest& request, + std::string output_format, + int output_compression, + bool allow_webp, + std::string& error_message); +bool assign_output_options(VidGenJobRequest& request, + std::string output_format, + int output_compression, + std::string& error_message); +std::string video_mime_type(const std::string& output_format); +bool runtime_supports_generation_mode(const ServerRuntime& runtime, SDMode mode); +std::string unsupported_generation_mode_error(SDMode mode); +void refresh_lora_cache(ServerRuntime& rt); 
+std::string get_lora_full_path(ServerRuntime& rt, const std::string& path); +void refresh_upscaler_cache(ServerRuntime& rt); +int64_t unix_timestamp_now(); diff --git a/format-code.sh b/format-code.sh index ac5fd340b..8aa422bca 100644 --- a/format-code.sh +++ b/format-code.sh @@ -1,4 +1,6 @@ -for f in src/*.cpp src/*.h src/*.hpp src/vocab/*.h src/vocab/*.cpp examples/cli/*.cpp examples/common/*.hpp examples/cli/*.h examples/server/*.cpp; do +for f in src/*.cpp src/*.h src/*.hpp src/tokenizers/*.h src/tokenizers/*.cpp src/tokenizers/vocab/*.h src/tokenizers/vocab/*.cpp \ + src/model_io/*.h src/model_io/*.cpp examples/cli/*.cpp examples/cli/*.h examples/server/*.cpp \ + examples/common/*.hpp examples/common/*.h examples/common/*.cpp; do [[ "$f" == vocab* ]] && continue echo "formatting '$f'" # if [ "$f" != "stable-diffusion.h" ]; then diff --git a/include/stable-diffusion.h b/include/stable-diffusion.h index f093bb56c..c4c14949c 100644 --- a/include/stable-diffusion.h +++ b/include/stable-diffusion.h @@ -50,6 +50,7 @@ enum sample_method_t { TCD_SAMPLE_METHOD, RES_MULTISTEP_SAMPLE_METHOD, RES_2S_SAMPLE_METHOD, + ER_SDE_SAMPLE_METHOD, SAMPLE_METHOD_COUNT }; @@ -202,6 +203,7 @@ typedef struct { bool chroma_use_t5_mask; int chroma_t5_mask_pad; bool qwen_image_zero_cond_t; + float max_vram; } sd_ctx_params_t; typedef struct { @@ -288,6 +290,32 @@ typedef struct { const char* path; } sd_lora_t; +enum sd_hires_upscaler_t { + SD_HIRES_UPSCALER_NONE, + SD_HIRES_UPSCALER_LATENT, + SD_HIRES_UPSCALER_LATENT_NEAREST, + SD_HIRES_UPSCALER_LATENT_NEAREST_EXACT, + SD_HIRES_UPSCALER_LATENT_ANTIALIASED, + SD_HIRES_UPSCALER_LATENT_BICUBIC, + SD_HIRES_UPSCALER_LATENT_BICUBIC_ANTIALIASED, + SD_HIRES_UPSCALER_LANCZOS, + SD_HIRES_UPSCALER_NEAREST, + SD_HIRES_UPSCALER_MODEL, + SD_HIRES_UPSCALER_COUNT, +}; + +typedef struct { + bool enabled; + enum sd_hires_upscaler_t upscaler; + const char* model_path; + float scale; + int target_width; + int target_height; + int steps; + float 
denoising_strength; + int upscale_tile_size; +} sd_hires_params_t; + typedef struct { const sd_lora_t* loras; uint32_t lora_count; @@ -311,6 +339,7 @@ typedef struct { sd_pm_params_t pm_params; sd_tiling_params_t vae_tiling_params; sd_cache_params_t cache; + sd_hires_params_t hires; } sd_img_gen_params_t; typedef struct { @@ -347,6 +376,8 @@ SD_API void sd_set_progress_callback(sd_progress_cb_t cb, void* data); SD_API void sd_set_preview_callback(sd_preview_cb_t cb, enum preview_t mode, int interval, bool denoised, bool noisy, void* data); SD_API int32_t sd_get_num_physical_cores(); SD_API const char* sd_get_system_info(); +SD_API bool sd_ctx_supports_image_generation(const sd_ctx_t* sd_ctx); +SD_API bool sd_ctx_supports_video_generation(const sd_ctx_t* sd_ctx); SD_API const char* sd_type_name(enum sd_type_t type); SD_API enum sd_type_t str_to_sd_type(const char* str); @@ -362,8 +393,11 @@ SD_API const char* sd_preview_name(enum preview_t preview); SD_API enum preview_t str_to_preview(const char* str); SD_API const char* sd_lora_apply_mode_name(enum lora_apply_mode_t mode); SD_API enum lora_apply_mode_t str_to_lora_apply_mode(const char* str); +SD_API const char* sd_hires_upscaler_name(enum sd_hires_upscaler_t upscaler); +SD_API enum sd_hires_upscaler_t str_to_sd_hires_upscaler(const char* str); SD_API void sd_cache_params_init(sd_cache_params_t* cache_params); +SD_API void sd_hires_params_init(sd_hires_params_t* hires_params); SD_API void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params); SD_API char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params); diff --git a/src/anima.hpp b/src/anima.hpp index 5850cc3e6..4bfc04749 100644 --- a/src/anima.hpp +++ b/src/anima.hpp @@ -499,9 +499,15 @@ namespace Anima { encoder_hidden_states = adapted_context; } + sd::ggml_graph_cut::mark_graph_cut(x, "anima.prelude", "x"); + sd::ggml_graph_cut::mark_graph_cut(embedded_timestep, "anima.prelude", "embedded_timestep"); + sd::ggml_graph_cut::mark_graph_cut(temb, 
"anima.prelude", "temb"); + sd::ggml_graph_cut::mark_graph_cut(encoder_hidden_states, "anima.prelude", "context"); + for (int i = 0; i < num_layers; i++) { auto block = std::dynamic_pointer_cast(blocks["blocks." + std::to_string(i)]); x = block->forward(ctx, x, encoder_hidden_states, embedded_timestep, temb, image_pe); + sd::ggml_graph_cut::mark_graph_cut(x, "anima.blocks." + std::to_string(i), "x"); } x = final_layer->forward(ctx, x, embedded_timestep, temb); // [N, h*w, ph*pw*C] diff --git a/src/auto_encoder_kl.hpp b/src/auto_encoder_kl.hpp index 039fb9df3..4fb28a16f 100644 --- a/src/auto_encoder_kl.hpp +++ b/src/auto_encoder_kl.hpp @@ -328,6 +328,7 @@ class Encoder : public GGMLBlock { auto conv_out = std::dynamic_pointer_cast(blocks["conv_out"]); auto h = conv_in->forward(ctx, x); // [N, ch, h, w] + // sd::ggml_graph_cut::mark_graph_cut(h, "vae.encoder.prelude", "h"); // downsampling size_t num_resolutions = ch_mult.size(); @@ -337,12 +338,14 @@ class Encoder : public GGMLBlock { auto down_block = std::dynamic_pointer_cast(blocks[name]); h = down_block->forward(ctx, h); + // sd::ggml_graph_cut::mark_graph_cut(h, "vae.encoder.down." + std::to_string(i) + ".block." + std::to_string(j), "h"); } if (i != num_resolutions - 1) { std::string name = "down." + std::to_string(i) + ".downsample"; auto down_sample = std::dynamic_pointer_cast(blocks[name]); h = down_sample->forward(ctx, h); + // sd::ggml_graph_cut::mark_graph_cut(h, "vae.encoder.down." 
+ std::to_string(i) + ".downsample", "h"); } } @@ -350,6 +353,7 @@ class Encoder : public GGMLBlock { h = mid_block_1->forward(ctx, h); h = mid_attn_1->forward(ctx, h); h = mid_block_2->forward(ctx, h); // [N, block_in, h, w] + // sd::ggml_graph_cut::mark_graph_cut(h, "vae.encoder.mid", "h"); // end h = norm_out->forward(ctx, h); @@ -450,6 +454,7 @@ class Decoder : public GGMLBlock { // conv_in auto h = conv_in->forward(ctx, z); // [N, block_in, h, w] + // sd::ggml_graph_cut::mark_graph_cut(h, "vae.decoder.prelude", "h"); // middle h = mid_block_1->forward(ctx, h); @@ -457,6 +462,7 @@ class Decoder : public GGMLBlock { h = mid_attn_1->forward(ctx, h); h = mid_block_2->forward(ctx, h); // [N, block_in, h, w] + // sd::ggml_graph_cut::mark_graph_cut(h, "vae.decoder.mid", "h"); // upsampling int num_resolutions = static_cast(ch_mult.size()); @@ -466,12 +472,14 @@ class Decoder : public GGMLBlock { auto up_block = std::dynamic_pointer_cast(blocks[name]); h = up_block->forward(ctx, h); + // sd::ggml_graph_cut::mark_graph_cut(h, "vae.decoder.up." + std::to_string(i) + ".block." + std::to_string(j), "h"); } if (i != 0) { std::string name = "up." + std::to_string(i) + ".upsample"; auto up_sample = std::dynamic_pointer_cast(blocks[name]); h = up_sample->forward(ctx, h); + // sd::ggml_graph_cut::mark_graph_cut(h, "vae.decoder.up." + std::to_string(i) + ".upsample", "h"); } } @@ -501,14 +509,39 @@ class AutoEncoderKLModel : public GGMLBlock { bool double_z = true; } dd_config; + static std::string get_tensor_name(const std::string& prefix, const std::string& name) { + return prefix.empty() ? name : prefix + "." 
+ name; + } + + void detect_decoder_ch(const String2TensorStorage& tensor_storage_map, + const std::string& prefix, + int& decoder_ch) { + auto conv_in_iter = tensor_storage_map.find(get_tensor_name(prefix, "decoder.conv_in.weight")); + if (conv_in_iter != tensor_storage_map.end() && conv_in_iter->second.n_dims >= 4 && conv_in_iter->second.ne[3] > 0) { + int last_ch_mult = dd_config.ch_mult.back(); + int64_t conv_in_out_channels = conv_in_iter->second.ne[3]; + if (last_ch_mult > 0 && conv_in_out_channels % last_ch_mult == 0) { + decoder_ch = static_cast(conv_in_out_channels / last_ch_mult); + LOG_INFO("vae decoder: ch = %d", decoder_ch); + } else { + LOG_WARN("vae decoder: failed to infer ch from %s (%" PRId64 " / %d)", + get_tensor_name(prefix, "decoder.conv_in.weight").c_str(), + conv_in_out_channels, + last_ch_mult); + } + } + } + public: - AutoEncoderKLModel(SDVersion version = VERSION_SD1, - bool decode_only = true, - bool use_linear_projection = false, - bool use_video_decoder = false) + AutoEncoderKLModel(SDVersion version = VERSION_SD1, + bool decode_only = true, + bool use_linear_projection = false, + bool use_video_decoder = false, + const String2TensorStorage& tensor_storage_map = {}, + const std::string& prefix = "") : version(version), decode_only(decode_only), use_video_decoder(use_video_decoder) { if (sd_version_is_dit(version)) { - if (sd_version_is_flux2(version)) { + if (sd_version_uses_flux2_vae(version)) { dd_config.z_channels = 32; embed_dim = 32; } else { @@ -519,7 +552,9 @@ class AutoEncoderKLModel : public GGMLBlock { if (use_video_decoder) { use_quant = false; } - blocks["decoder"] = std::shared_ptr(new Decoder(dd_config.ch, + int decoder_ch = dd_config.ch; + detect_decoder_ch(tensor_storage_map, prefix, decoder_ch); + blocks["decoder"] = std::shared_ptr(new Decoder(decoder_ch, dd_config.out_ch, dd_config.ch_mult, dd_config.num_res_blocks, @@ -551,7 +586,7 @@ class AutoEncoderKLModel : public GGMLBlock { ggml_tensor* 
decode(GGMLRunnerContext* ctx, ggml_tensor* z) { // z: [N, z_channels, h, w] - if (sd_version_is_flux2(version)) { + if (sd_version_uses_flux2_vae(version)) { // [N, C*p*p, h, w] -> [N, C, h*p, w*p] int64_t p = 2; @@ -572,6 +607,7 @@ class AutoEncoderKLModel : public GGMLBlock { if (use_quant) { auto post_quant_conv = std::dynamic_pointer_cast(blocks["post_quant_conv"]); z = post_quant_conv->forward(ctx, z); // [N, z_channels, h, w] + // sd::ggml_graph_cut::mark_graph_cut(z, "vae.decode.prelude", "z"); } auto decoder = std::dynamic_pointer_cast(blocks["decoder"]); @@ -589,8 +625,9 @@ class AutoEncoderKLModel : public GGMLBlock { if (use_quant) { auto quant_conv = std::dynamic_pointer_cast(blocks["quant_conv"]); z = quant_conv->forward(ctx, z); // [N, 2*embed_dim, h/8, w/8] + // sd::ggml_graph_cut::mark_graph_cut(z, "vae.encode.final", "z"); } - if (sd_version_is_flux2(version)) { + if (sd_version_uses_flux2_vae(version)) { z = ggml_ext_chunk(ctx->ggml_ctx, z, 2, 2)[0]; // [N, C, H, W] -> [N, C*p*p, H/p, W/p] @@ -613,7 +650,7 @@ class AutoEncoderKLModel : public GGMLBlock { int get_encoder_output_channels() { int factor = dd_config.double_z ? 
2 : 1; - if (sd_version_is_flux2(version)) { + if (sd_version_uses_flux2_vae(version)) { return dd_config.z_channels * 4; } return dd_config.z_channels * factor; @@ -646,7 +683,7 @@ struct AutoEncoderKL : public VAE { } else if (sd_version_is_flux(version) || sd_version_is_z_image(version)) { scale_factor = 0.3611f; shift_factor = 0.1159f; - } else if (sd_version_is_flux2(version)) { + } else if (sd_version_uses_flux2_vae(version)) { scale_factor = 1.0f; shift_factor = 0.f; } @@ -662,7 +699,7 @@ struct AutoEncoderKL : public VAE { break; } } - ae = AutoEncoderKLModel(version, decode_only, use_linear_projection, use_video_decoder); + ae = AutoEncoderKLModel(version, decode_only, use_linear_projection, use_video_decoder, tensor_storage_map, prefix); ae.init(params_ctx, tensor_storage_map, prefix); } @@ -720,7 +757,7 @@ struct AutoEncoderKL : public VAE { } sd::Tensor vae_output_to_latents(const sd::Tensor& vae_output, std::shared_ptr rng) override { - if (sd_version_is_flux2(version)) { + if (sd_version_uses_flux2_vae(version)) { return vae_output; } else if (version == VERSION_SD1_PIX2PIX) { return sd::ops::chunk(vae_output, 2, 2)[0]; @@ -731,7 +768,7 @@ struct AutoEncoderKL : public VAE { std::pair, sd::Tensor> get_latents_mean_std(const sd::Tensor& latents, int channel_dim) { GGML_ASSERT(channel_dim >= 0 && static_cast(channel_dim) < static_cast(latents.dim())); - if (sd_version_is_flux2(version)) { + if (sd_version_uses_flux2_vae(version)) { GGML_ASSERT(latents.shape()[channel_dim] == 128); std::vector stats_shape(static_cast(latents.dim()), 1); stats_shape[static_cast(channel_dim)] = latents.shape()[channel_dim]; @@ -777,7 +814,7 @@ struct AutoEncoderKL : public VAE { } sd::Tensor diffusion_to_vae_latents(const sd::Tensor& latents) override { - if (sd_version_is_flux2(version)) { + if (sd_version_uses_flux2_vae(version)) { int channel_dim = 2; auto [mean_tensor, std_tensor] = get_latents_mean_std(latents, channel_dim); return (latents * std_tensor) / 
scale_factor + mean_tensor; @@ -786,7 +823,7 @@ struct AutoEncoderKL : public VAE { } sd::Tensor vae_to_diffusion_latents(const sd::Tensor& latents) override { - if (sd_version_is_flux2(version)) { + if (sd_version_uses_flux2_vae(version)) { int channel_dim = 2; auto [mean_tensor, std_tensor] = get_latents_mean_std(latents, channel_dim); return ((latents - mean_tensor) * scale_factor) / std_tensor; diff --git a/src/clip.hpp b/src/clip.hpp index 8f2ac0643..8b2084c49 100644 --- a/src/clip.hpp +++ b/src/clip.hpp @@ -3,455 +3,7 @@ #include "ggml_extend.hpp" #include "model.h" -#include "tokenize_util.h" -#include "vocab/vocab.h" - -/*================================================== CLIPTokenizer ===================================================*/ - -__STATIC_INLINE__ std::vector> bytes_to_unicode() { - std::vector> byte_unicode_pairs; - std::set byte_set; - for (int b = static_cast('!'); b <= static_cast('~'); ++b) { - byte_set.insert(b); - byte_unicode_pairs.push_back(std::pair(b, unicode_value_to_utf32(b))); - } - for (int b = 161; b <= 172; ++b) { - byte_set.insert(b); - byte_unicode_pairs.push_back(std::pair(b, unicode_value_to_utf32(b))); - } - for (int b = 174; b <= 255; ++b) { - byte_set.insert(b); - byte_unicode_pairs.push_back(std::pair(b, unicode_value_to_utf32(b))); - } - int n = 0; - for (int b = 0; b < 256; ++b) { - if (byte_set.find(b) == byte_set.end()) { - byte_unicode_pairs.push_back(std::pair(b, unicode_value_to_utf32(n + 256))); - ++n; - } - } - // LOG_DEBUG("byte_unicode_pairs %d", byte_unicode_pairs.size()); - return byte_unicode_pairs; -} - -// Ref: https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py - -typedef std::function&)> on_new_token_cb_t; - -class CLIPTokenizer { -private: - std::map byte_encoder; - std::map byte_decoder; - std::map encoder; - std::map decoder; - std::map, int> bpe_ranks; - std::regex pat; - int encoder_len; - int bpe_len; - - std::vector special_tokens; - -public: - const std::string UNK_TOKEN = 
"<|endoftext|>"; - const std::string BOS_TOKEN = "<|startoftext|>"; - const std::string EOS_TOKEN = "<|endoftext|>"; - const std::string PAD_TOKEN = "<|endoftext|>"; - - const int UNK_TOKEN_ID = 49407; - const int BOS_TOKEN_ID = 49406; - const int EOS_TOKEN_ID = 49407; - const int PAD_TOKEN_ID = 49407; - -private: - static std::string strip(const std::string& str) { - std::string::size_type start = str.find_first_not_of(" \t\n\r\v\f"); - std::string::size_type end = str.find_last_not_of(" \t\n\r\v\f"); - - if (start == std::string::npos) { - // String contains only whitespace characters - return ""; - } - - return str.substr(start, end - start + 1); - } - - static std::string whitespace_clean(std::string text) { - text = std::regex_replace(text, std::regex(R"(\s+)"), " "); - text = strip(text); - return text; - } - - static std::set> get_pairs(const std::vector& subwords) { - std::set> pairs; - if (subwords.size() == 0) { - return pairs; - } - std::u32string prev_subword = subwords[0]; - for (int i = 1; i < subwords.size(); i++) { - std::u32string subword = subwords[i]; - std::pair pair(prev_subword, subword); - pairs.insert(pair); - prev_subword = subword; - } - return pairs; - } - - bool is_special_token(const std::string& token) { - for (auto& special_token : special_tokens) { - if (special_token == token) { - return true; - } - } - return false; - } - -public: - CLIPTokenizer(int pad_token_id = 49407, const std::string& merges_utf8_str = "") - : PAD_TOKEN_ID(pad_token_id) { - if (merges_utf8_str.size() > 0) { - load_from_merges(merges_utf8_str); - } else { - load_from_merges(load_clip_merges()); - } - add_special_token("<|startoftext|>"); - add_special_token("<|endoftext|>"); - } - - void load_from_merges(const std::string& merges_utf8_str) { - auto byte_unicode_pairs = bytes_to_unicode(); - // printf("byte_unicode_pairs have %lu pairs \n", byte_unicode_pairs.size()); - byte_encoder = std::map(byte_unicode_pairs.begin(), byte_unicode_pairs.end()); - for (auto& 
pair : byte_unicode_pairs) { - byte_decoder[pair.second] = pair.first; - } - // for (auto & pair: byte_unicode_pairs) { - // std::cout << pair.first << ": " << pair.second << std::endl; - // } - std::vector merges; - size_t start = 0; - size_t pos; - std::u32string merges_utf32_str = utf8_to_utf32(merges_utf8_str); - while ((pos = merges_utf32_str.find('\n', start)) != std::string::npos) { - merges.push_back(merges_utf32_str.substr(start, pos - start)); - start = pos + 1; - } - // LOG_DEBUG("merges size %llu", merges.size()); - GGML_ASSERT(merges.size() == 48895); - merges = std::vector(merges.begin() + 1, merges.end()); - std::vector> merge_pairs; - for (const auto& merge : merges) { - size_t space_pos = merge.find(' '); - merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1)); - // LOG_DEBUG("%s", utf32_to_utf8(merge.substr(space_pos + 1)).c_str()); - // printf("%s :: %s | %s \n", utf32_to_utf8(merge).c_str(), utf32_to_utf8(merge.substr(0, space_pos)).c_str(), - // utf32_to_utf8(merge.substr(space_pos + 1)).c_str()); - } - std::vector vocab; - for (const auto& pair : byte_unicode_pairs) { - vocab.push_back(pair.second); - } - for (const auto& pair : byte_unicode_pairs) { - vocab.push_back(pair.second + utf8_to_utf32("")); - } - for (const auto& merge : merge_pairs) { - vocab.push_back(merge.first + merge.second); - } - vocab.push_back(utf8_to_utf32("<|startoftext|>")); - vocab.push_back(utf8_to_utf32("<|endoftext|>")); - LOG_DEBUG("vocab size: %llu", vocab.size()); - int i = 0; - for (const auto& token : vocab) { - encoder[token] = i; - decoder[i] = token; - i++; - } - encoder_len = i; - - auto it = encoder.find(utf8_to_utf32("img")); - if (it != encoder.end()) { - LOG_DEBUG("trigger word img already in vocab"); - } else { - LOG_DEBUG("trigger word img not in vocab yet"); - } - - int rank = 0; - for (const auto& merge : merge_pairs) { - bpe_ranks[merge] = rank++; - } - bpe_len = rank; - }; - - void add_token(const std::string& text) { - 
std::u32string token = utf8_to_utf32(text); - auto it = encoder.find(token); - if (it != encoder.end()) { - encoder[token] = encoder_len; - decoder[encoder_len] = token; - encoder_len++; - } - } - - void add_special_token(const std::string& token) { - special_tokens.push_back(token); - } - - std::u32string bpe(const std::u32string& token) { - std::vector word; - - for (int i = 0; i < token.size() - 1; i++) { - word.emplace_back(1, token[i]); - } - word.push_back(token.substr(token.size() - 1) + utf8_to_utf32("")); - - std::set> pairs = get_pairs(word); - - if (pairs.empty()) { - return token + utf8_to_utf32(""); - } - - while (true) { - auto min_pair_iter = std::min_element(pairs.begin(), - pairs.end(), - [&](const std::pair& a, - const std::pair& b) { - if (bpe_ranks.find(a) == bpe_ranks.end()) { - return false; - } else if (bpe_ranks.find(b) == bpe_ranks.end()) { - return true; - } - return bpe_ranks.at(a) < bpe_ranks.at(b); - }); - - const std::pair& bigram = *min_pair_iter; - - if (bpe_ranks.find(bigram) == bpe_ranks.end()) { - break; - } - - std::u32string first = bigram.first; - std::u32string second = bigram.second; - std::vector new_word; - int32_t i = 0; - - while (i < word.size()) { - auto it = std::find(word.begin() + i, word.end(), first); - if (it == word.end()) { - new_word.insert(new_word.end(), word.begin() + i, word.end()); - break; - } - new_word.insert(new_word.end(), word.begin() + i, it); - i = static_cast(std::distance(word.begin(), it)); - - if (word[i] == first && i < static_cast(word.size()) - 1 && word[i + 1] == second) { - new_word.push_back(first + second); - i += 2; - } else { - new_word.push_back(word[i]); - i += 1; - } - } - - word = new_word; - - if (word.size() == 1) { - break; - } - pairs = get_pairs(word); - } - - std::u32string result; - for (int i = 0; i < word.size(); i++) { - result += word[i]; - if (i != word.size() - 1) { - result += utf8_to_utf32(" "); - } - } - - return result; - } - - std::vector tokenize(std::string 
text, - on_new_token_cb_t on_new_token_cb, - size_t max_length = 0, - bool padding = false) { - std::vector tokens = encode(text, on_new_token_cb); - - tokens.insert(tokens.begin(), BOS_TOKEN_ID); - if (max_length > 0) { - if (tokens.size() > max_length - 1) { - tokens.resize(max_length - 1); - tokens.push_back(EOS_TOKEN_ID); - } else { - tokens.push_back(EOS_TOKEN_ID); - if (padding) { - tokens.insert(tokens.end(), max_length - tokens.size(), PAD_TOKEN_ID); - } - } - } - - return tokens; - } - - void pad_tokens(std::vector& tokens, - std::vector& weights, - size_t max_length = 0, - bool padding = false) { - if (max_length > 0 && padding) { - size_t n = static_cast(std::ceil(tokens.size() * 1.0 / (max_length - 2))); - if (n == 0) { - n = 1; - } - size_t length = max_length * n; - LOG_DEBUG("token length: %llu", length); - std::vector new_tokens; - std::vector new_weights; - new_tokens.push_back(BOS_TOKEN_ID); - new_weights.push_back(1.0); - int token_idx = 0; - for (int i = 1; i < length; i++) { - if (token_idx >= tokens.size()) { - break; - } - if (i % max_length == 0) { - new_tokens.push_back(BOS_TOKEN_ID); - new_weights.push_back(1.0); - } else if (i % max_length == max_length - 1) { - new_tokens.push_back(EOS_TOKEN_ID); - new_weights.push_back(1.0); - } else { - new_tokens.push_back(tokens[token_idx]); - new_weights.push_back(weights[token_idx]); - token_idx++; - } - } - - new_tokens.push_back(EOS_TOKEN_ID); - new_weights.push_back(1.0); - tokens = new_tokens; - weights = new_weights; - - if (padding) { - tokens.insert(tokens.end(), length - tokens.size(), PAD_TOKEN_ID); - weights.insert(weights.end(), length - weights.size(), 1.0); - } - } - } - - std::string clean_up_tokenization(std::string& text) { - std::regex pattern(R"( ,)"); - // Replace " ," with "," - std::string result = std::regex_replace(text, pattern, ","); - return result; - } - - std::string decode(const std::vector& tokens) { - std::string text = ""; - for (int t : tokens) { - if (t == 49406 || 
t == 49407) - continue; - std::u32string ts = decoder[t]; - // printf("%d, %s \n", t, utf32_to_utf8(ts).c_str()); - std::string s = utf32_to_utf8(ts); - if (s.length() >= 4) { - if (ends_with(s, "")) { - text += s.replace(s.length() - 4, s.length() - 1, "") + " "; - } else { - text += s; - } - } else { - text += " " + s; - } - } - // std::vector bytes; - // for (auto c : text){ - // bytes.push_back(byte_decoder[c]); - // } - - // std::string s((char *)bytes.data()); - // std::string s = ""; - text = clean_up_tokenization(text); - return trim(text); - } - - std::vector token_split(const std::string& text) { - std::regex pat(R"('s|'t|'re|'ve|'m|'ll|'d|[[:alpha:]]+|[[:digit:]]|[^[:space:][:alpha:][:digit:]]+)", - std::regex::icase); - std::sregex_iterator iter(text.begin(), text.end(), pat); - std::sregex_iterator end; - - std::vector result; - for (; iter != end; ++iter) { - result.emplace_back(iter->str()); - } - - return result; - } - - std::vector encode(std::string text, on_new_token_cb_t on_new_token_cb) { - std::string original_text = text; - std::vector bpe_tokens; - text = whitespace_clean(text); - std::transform(text.begin(), text.end(), text.begin(), [](unsigned char c) { return std::tolower(c); }); - - std::string str = text; - std::vector token_strs; - - auto splited_texts = split_with_special_tokens(text, special_tokens); - - for (auto& splited_text : splited_texts) { - LOG_DEBUG("token %s", splited_text.c_str()); - if (is_special_token(splited_text)) { - LOG_DEBUG("special %s", splited_text.c_str()); - bool skip = on_new_token_cb(splited_text, bpe_tokens); - if (skip) { - token_strs.push_back(splited_text); - continue; - } - continue; - } - - auto tokens = token_split(splited_text); - for (auto& token : tokens) { - if (on_new_token_cb != nullptr) { - bool skip = on_new_token_cb(token, bpe_tokens); - if (skip) { - token_strs.push_back(token); - continue; - } - } - - std::string token_str = token; - std::u32string utf32_token; - for (int i = 0; i < 
token_str.length(); i++) { - unsigned char b = token_str[i]; - utf32_token += byte_encoder[b]; - } - auto bpe_strs = bpe(utf32_token); - size_t start = 0; - size_t pos; - while ((pos = bpe_strs.find(' ', start)) != std::u32string::npos) { - auto bpe_str = bpe_strs.substr(start, pos - start); - bpe_tokens.push_back(encoder[bpe_str]); - token_strs.push_back(utf32_to_utf8(bpe_str)); - - start = pos + 1; - } - auto bpe_str = bpe_strs.substr(start, bpe_strs.size() - start); - bpe_tokens.push_back(encoder[bpe_str]); - token_strs.push_back(utf32_to_utf8(bpe_str)); - } - } - // std::stringstream ss; - // ss << "["; - // for (auto token : token_strs) { - // ss << "\"" << token << "\", "; - // } - // ss << "]"; - // LOG_DEBUG("split prompt \"%s\" to tokens %s", original_text.c_str(), ss.str().c_str()); - // printf("split prompt \"%s\" to tokens %s \n", original_text.c_str(), ss.str().c_str()); - return bpe_tokens; - } -}; +#include "tokenizers/clip_tokenizer.h" /*================================================ FrozenCLIPEmbedder ================================================*/ @@ -543,8 +95,9 @@ struct CLIPEncoder : public GGMLBlock { ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x, - ggml_tensor* mask = nullptr, - int clip_skip = -1) { + ggml_tensor* mask = nullptr, + int clip_skip = -1, + const std::string& graph_cut_prefix = "") { // x: [N, n_token, d_model] int layer_idx = n_layer - 1; // LOG_DEBUG("clip_skip %d", clip_skip); @@ -560,6 +113,9 @@ struct CLIPEncoder : public GGMLBlock { std::string name = "layers." + std::to_string(i); auto layer = std::dynamic_pointer_cast(blocks[name]); x = layer->forward(ctx, x, mask); // [N, n_token, d_model] + if (!graph_cut_prefix.empty()) { + sd::ggml_graph_cut::mark_graph_cut(x, graph_cut_prefix + ".layers." 
+ std::to_string(i), "x"); + } // LOG_DEBUG("layer %d", i); } return x; @@ -752,7 +308,8 @@ class CLIPTextModel : public GGMLBlock { auto final_layer_norm = std::dynamic_pointer_cast(blocks["final_layer_norm"]); auto x = embeddings->forward(ctx, input_ids, tkn_embeddings); // [N, n_token, hidden_size] - x = encoder->forward(ctx, x, mask, return_pooled ? -1 : clip_skip); + sd::ggml_graph_cut::mark_graph_cut(x, "clip_text.prelude", "x"); + x = encoder->forward(ctx, x, mask, return_pooled ? -1 : clip_skip, "clip_text"); if (return_pooled || with_final_ln) { x = final_layer_norm->forward(ctx, x); } @@ -816,7 +373,8 @@ class CLIPVisionModel : public GGMLBlock { auto x = embeddings->forward(ctx, pixel_values); // [N, num_positions, embed_dim] x = pre_layernorm->forward(ctx, x); - x = encoder->forward(ctx, x, nullptr, clip_skip); + sd::ggml_graph_cut::mark_graph_cut(x, "clip_vision.prelude", "x"); + x = encoder->forward(ctx, x, nullptr, clip_skip, "clip_vision"); auto last_hidden_state = x; diff --git a/src/common_block.hpp b/src/common_block.hpp index 2cef389af..e6c0b06bd 100644 --- a/src/common_block.hpp +++ b/src/common_block.hpp @@ -1,7 +1,9 @@ #ifndef __COMMON_BLOCK_HPP__ #define __COMMON_BLOCK_HPP__ +#include "ggml-backend.h" #include "ggml_extend.hpp" +#include "util.h" class DownSampleBlock : public GGMLBlock { protected: @@ -248,9 +250,6 @@ class FeedForward : public GGMLBlock { float scale = 1.f; if (precision_fix) { scale = 1.f / 128.f; -#ifdef SD_USE_VULKAN - force_prec_f32 = true; -#endif } // The purpose of the scale here is to prevent NaN issues in certain situations. 
// For example, when using Vulkan without enabling force_prec_f32, @@ -264,6 +263,9 @@ class FeedForward : public GGMLBlock { auto net_0 = std::dynamic_pointer_cast(blocks["net.0"]); auto net_2 = std::dynamic_pointer_cast(blocks["net.2"]); + if (sd_backend_is(ctx->backend, "Vulkan")) { + net_2->set_force_prec_f32(true); + } x = net_0->forward(ctx, x); // [ne3, ne2, ne1, inner_dim] x = net_2->forward(ctx, x); // [ne3, ne2, ne1, dim_out] @@ -277,6 +279,7 @@ class CrossAttention : public GGMLBlock { int64_t context_dim; int64_t n_head; int64_t d_head; + bool xtra_dim = false; public: CrossAttention(int64_t query_dim, @@ -288,7 +291,11 @@ class CrossAttention : public GGMLBlock { query_dim(query_dim), context_dim(context_dim) { int64_t inner_dim = d_head * n_head; - + if (context_dim == 320 && d_head == 320) { + // LOG_DEBUG("CrossAttention: temp set dim to 1024 for sdxs_09"); + xtra_dim = true; + context_dim = 1024; + } blocks["to_q"] = std::shared_ptr(new Linear(query_dim, inner_dim, false)); blocks["to_k"] = std::shared_ptr(new Linear(context_dim, inner_dim, false)); blocks["to_v"] = std::shared_ptr(new Linear(context_dim, inner_dim, false)); @@ -313,10 +320,16 @@ class CrossAttention : public GGMLBlock { int64_t n_context = context->ne[1]; int64_t inner_dim = d_head * n_head; - auto q = to_q->forward(ctx, x); // [N, n_token, inner_dim] + auto q = to_q->forward(ctx, x); // [N, n_token, inner_dim] + if (xtra_dim) { + // LOG_DEBUG("CrossAttention: temp set dim to 1024 for sdxs_09"); + context->ne[0] = 1024; // patch dim + } auto k = to_k->forward(ctx, context); // [N, n_context, inner_dim] auto v = to_v->forward(ctx, context); // [N, n_context, inner_dim] - + if (xtra_dim) { + context->ne[0] = 320; // reset dim to orig + } x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, inner_dim] x = to_out_0->forward(ctx, x); // [N, n_token, query_dim] diff --git a/src/conditioner.hpp 
b/src/conditioner.hpp index 5564373eb..4907938b0 100644 --- a/src/conditioner.hpp +++ b/src/conditioner.hpp @@ -85,7 +85,8 @@ struct Conditioner { virtual void free_params_buffer() = 0; virtual void get_param_tensors(std::map& tensors) = 0; virtual size_t get_params_buffer_size() = 0; - virtual void set_flash_attention_enabled(bool enabled) = 0; + virtual void set_max_graph_vram_bytes(size_t max_vram_bytes) {} + virtual void set_flash_attention_enabled(bool enabled) = 0; virtual void set_weight_adapter(const std::shared_ptr& adapter) {} virtual std::tuple> get_learned_condition_with_trigger(int n_threads, const ConditionerParams& conditioner_params) { @@ -165,6 +166,13 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { return buffer_size; } + void set_max_graph_vram_bytes(size_t max_vram_bytes) override { + text_model->set_max_graph_vram_bytes(max_vram_bytes); + if (sd_version_is_sdxl(version)) { + text_model2->set_max_graph_vram_bytes(max_vram_bytes); + } + } + void set_flash_attention_enabled(bool enabled) override { text_model->set_flash_attention_enabled(enabled); if (sd_version_is_sdxl(version)) { @@ -256,15 +264,6 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { return true; } - std::tuple, std::vector, std::vector> - tokenize_with_trigger_token(std::string text, - int num_input_imgs, - int32_t image_token, - bool padding = false) { - return tokenize_with_trigger_token(text, num_input_imgs, image_token, - text_model->model.n_token, padding); - } - std::vector convert_token_to_id(std::string text) { auto on_new_token_cb = [&](std::string& str, std::vector& bpe_tokens) -> bool { auto iter = embedding_map.find(str); @@ -288,9 +287,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { std::tuple, std::vector, std::vector> tokenize_with_trigger_token(std::string text, int num_input_imgs, - int32_t image_token, - size_t max_length = 0, - bool padding = false) { + int32_t image_token) { auto parsed_attention = 
parse_prompt_attention(text); { @@ -377,7 +374,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { // tokens.insert(tokens.begin(), tokenizer.BOS_TOKEN_ID); // weights.insert(weights.begin(), 1.0); - tokenizer.pad_tokens(tokens, weights, max_length, padding); + tokenizer.pad_tokens(tokens, &weights, nullptr, text_model->model.n_token, text_model->model.n_token, true); int offset = pm_version == PM_VERSION_2 ? 2 * num_input_imgs : num_input_imgs; for (int i = 0; i < tokens.size(); i++) { // if (class_idx + 1 <= i && i < class_idx + 1 + 2*num_input_imgs) // photomaker V2 has num_tokens(=2)*num_input_imgs @@ -403,13 +400,9 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { } std::pair, std::vector> tokenize(std::string text, - bool padding = false) { - return tokenize(text, text_model->model.n_token, padding); - } - - std::pair, std::vector> tokenize(std::string text, - size_t max_length = 0, - bool padding = false) { + size_t min_length = 0, + size_t max_length = 0, + bool allow_overflow_expand = true) { auto parsed_attention = parse_prompt_attention(text); { @@ -460,7 +453,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { weights.insert(weights.end(), curr_tokens.size(), curr_weight); } - tokenizer.pad_tokens(tokens, weights, max_length, padding); + tokenizer.pad_tokens(tokens, &weights, nullptr, min_length, max_length, allow_overflow_expand); // for (int i = 0; i < tokens.size(); i++) { // std::cout << tokens[i] << ":" << weights[i] << ", "; @@ -603,8 +596,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { GGML_ASSERT(image_tokens.size() == 1); auto tokens_and_weights = tokenize_with_trigger_token(conditioner_params.text, conditioner_params.num_input_imgs, - image_tokens[0], - true); + image_tokens[0]); std::vector& tokens = std::get<0>(tokens_and_weights); std::vector& weights = std::get<1>(tokens_and_weights); std::vector& clsm = std::get<2>(tokens_and_weights); @@ -630,7 +622,7 @@ struct 
FrozenCLIPEmbedderWithCustomWords : public Conditioner { std::string remove_trigger_from_prompt(const std::string& prompt) override { auto image_tokens = convert_token_to_id(trigger_word); GGML_ASSERT(image_tokens.size() == 1); - auto tokens_and_weights = tokenize(prompt, false); + auto tokens_and_weights = tokenize(prompt); std::vector& tokens = tokens_and_weights.first; auto it = std::find(tokens.begin(), tokens.end(), image_tokens[0]); GGML_ASSERT(it != tokens.end()); // prompt must have trigger word @@ -640,7 +632,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { SDCondition get_learned_condition(int n_threads, const ConditionerParams& conditioner_params) override { - auto tokens_and_weights = tokenize(conditioner_params.text, true); + auto tokens_and_weights = tokenize(conditioner_params.text, text_model->model.n_token, text_model->model.n_token, true); std::vector& tokens = tokens_and_weights.first; std::vector& weights = tokens_and_weights.second; return get_learned_condition_common(n_threads, @@ -797,6 +789,18 @@ struct SD3CLIPEmbedder : public Conditioner { return buffer_size; } + void set_max_graph_vram_bytes(size_t max_vram_bytes) override { + if (clip_l) { + clip_l->set_max_graph_vram_bytes(max_vram_bytes); + } + if (clip_g) { + clip_g->set_max_graph_vram_bytes(max_vram_bytes); + } + if (t5) { + t5->set_max_graph_vram_bytes(max_vram_bytes); + } + } + void set_flash_attention_enabled(bool enabled) override { if (clip_l) { clip_l->set_flash_attention_enabled(enabled); @@ -822,8 +826,9 @@ struct SD3CLIPEmbedder : public Conditioner { } std::vector, std::vector>> tokenize(std::string text, - size_t max_length = 0, - bool padding = false) { + size_t min_length = 0, + size_t max_length = 0, + bool allow_overflow_expand = true) { auto parsed_attention = parse_prompt_attention(text); { @@ -860,20 +865,20 @@ struct SD3CLIPEmbedder : public Conditioner { clip_g_weights.insert(clip_g_weights.end(), curr_tokens.size(), curr_weight); } if (t5) { - 
std::vector curr_tokens = t5_tokenizer.Encode(curr_text, true); + std::vector curr_tokens = t5_tokenizer.encode(curr_text); t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end()); t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight); } } if (clip_l) { - clip_l_tokenizer.pad_tokens(clip_l_tokens, clip_l_weights, max_length, padding); + clip_l_tokenizer.pad_tokens(clip_l_tokens, &clip_l_weights, nullptr, min_length, max_length, allow_overflow_expand); } if (clip_g) { - clip_g_tokenizer.pad_tokens(clip_g_tokens, clip_g_weights, max_length, padding); + clip_g_tokenizer.pad_tokens(clip_g_tokens, &clip_g_weights, nullptr, min_length, max_length, allow_overflow_expand); } if (t5) { - t5_tokenizer.pad_tokens(t5_tokens, t5_weights, nullptr, max_length, padding); + t5_tokenizer.pad_tokens(t5_tokens, &t5_weights, nullptr, min_length, max_length, true); } // for (int i = 0; i < clip_l_tokens.size(); i++) { @@ -1056,7 +1061,7 @@ struct SD3CLIPEmbedder : public Conditioner { SDCondition get_learned_condition(int n_threads, const ConditionerParams& conditioner_params) override { - auto tokens_and_weights = tokenize(conditioner_params.text, 77, true); + auto tokens_and_weights = tokenize(conditioner_params.text, 77, 77, true); return get_learned_condition_common(n_threads, tokens_and_weights, conditioner_params.clip_skip, @@ -1139,6 +1144,15 @@ struct FluxCLIPEmbedder : public Conditioner { return buffer_size; } + void set_max_graph_vram_bytes(size_t max_vram_bytes) override { + if (clip_l) { + clip_l->set_max_graph_vram_bytes(max_vram_bytes); + } + if (t5) { + t5->set_max_graph_vram_bytes(max_vram_bytes); + } + } + void set_flash_attention_enabled(bool enabled) override { if (clip_l) { clip_l->set_flash_attention_enabled(enabled); @@ -1158,8 +1172,8 @@ struct FluxCLIPEmbedder : public Conditioner { } std::vector, std::vector>> tokenize(std::string text, - size_t max_length = 0, - bool padding = false) { + size_t min_length = 0, + size_t 
max_length = 0) { auto parsed_attention = parse_prompt_attention(text); { @@ -1189,17 +1203,17 @@ struct FluxCLIPEmbedder : public Conditioner { clip_l_weights.insert(clip_l_weights.end(), curr_tokens.size(), curr_weight); } if (t5) { - std::vector curr_tokens = t5_tokenizer.Encode(curr_text, true); + std::vector curr_tokens = t5_tokenizer.encode(curr_text); t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end()); t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight); } } if (clip_l) { - clip_l_tokenizer.pad_tokens(clip_l_tokens, clip_l_weights, 77, padding); + clip_l_tokenizer.pad_tokens(clip_l_tokens, &clip_l_weights, nullptr, 77, 77, true); } if (t5) { - t5_tokenizer.pad_tokens(t5_tokens, t5_weights, nullptr, max_length, padding); + t5_tokenizer.pad_tokens(t5_tokens, &t5_weights, nullptr, min_length, max_length, true); } // for (int i = 0; i < clip_l_tokens.size(); i++) { @@ -1300,7 +1314,7 @@ struct FluxCLIPEmbedder : public Conditioner { SDCondition get_learned_condition(int n_threads, const ConditionerParams& conditioner_params) override { - auto tokens_and_weights = tokenize(conditioner_params.text, chunk_len, true); + auto tokens_and_weights = tokenize(conditioner_params.text, chunk_len, chunk_len); return get_learned_condition_common(n_threads, tokens_and_weights, conditioner_params.clip_skip, @@ -1364,6 +1378,12 @@ struct T5CLIPEmbedder : public Conditioner { return buffer_size; } + void set_max_graph_vram_bytes(size_t max_vram_bytes) override { + if (t5) { + t5->set_max_graph_vram_bytes(max_vram_bytes); + } + } + void set_flash_attention_enabled(bool enabled) override { if (t5) { t5->set_flash_attention_enabled(enabled); @@ -1377,8 +1397,8 @@ struct T5CLIPEmbedder : public Conditioner { } std::tuple, std::vector, std::vector> tokenize(std::string text, - size_t max_length = 0, - bool padding = false) { + size_t min_length = 0, + size_t max_length = 0) { auto parsed_attention = parse_prompt_attention(text); { @@ -1403,12 
+1423,15 @@ struct T5CLIPEmbedder : public Conditioner { const std::string& curr_text = item.first; float curr_weight = item.second; - std::vector curr_tokens = t5_tokenizer.Encode(curr_text, true); + std::vector curr_tokens = t5_tokenizer.encode(curr_text); t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end()); t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight); } - t5_tokenizer.pad_tokens(t5_tokens, t5_weights, &t5_mask, max_length, padding); + t5_tokenizer.pad_tokens(t5_tokens, &t5_weights, &t5_mask, min_length, max_length, true); + for (auto& mask_value : t5_mask) { + mask_value = mask_value > 0.0f ? 0.0f : -HUGE_VALF; + } } return {t5_tokens, t5_weights, t5_mask}; } @@ -1496,7 +1519,7 @@ struct T5CLIPEmbedder : public Conditioner { SDCondition get_learned_condition(int n_threads, const ConditionerParams& conditioner_params) override { - auto tokens_and_weights = tokenize(conditioner_params.text, chunk_len, true); + auto tokens_and_weights = tokenize(conditioner_params.text, chunk_len, chunk_len); return get_learned_condition_common(n_threads, tokens_and_weights, conditioner_params.clip_skip, @@ -1505,14 +1528,14 @@ struct T5CLIPEmbedder : public Conditioner { }; struct AnimaConditioner : public Conditioner { - std::shared_ptr qwen_tokenizer; + std::shared_ptr qwen_tokenizer; T5UniGramTokenizer t5_tokenizer; std::shared_ptr llm; AnimaConditioner(ggml_backend_t backend, bool offload_params_to_cpu, const String2TensorStorage& tensor_storage_map = {}) { - qwen_tokenizer = std::make_shared(); + qwen_tokenizer = std::make_shared(); llm = std::make_shared(LLM::LLMArch::QWEN3, backend, offload_params_to_cpu, @@ -1537,6 +1560,10 @@ struct AnimaConditioner : public Conditioner { return llm->get_params_buffer_size(); } + void set_max_graph_vram_bytes(size_t max_vram_bytes) override { + llm->set_max_graph_vram_bytes(max_vram_bytes); + } + void set_flash_attention_enabled(bool enabled) override { 
llm->set_flash_attention_enabled(enabled); } @@ -1578,7 +1605,7 @@ struct AnimaConditioner : public Conditioner { for (const auto& item : parsed_attention) { const std::string& curr_text = item.first; float curr_weight = item.second; - std::vector curr_tokens = t5_tokenizer.Encode(curr_text, true); + std::vector curr_tokens = t5_tokenizer.tokenize(curr_text, nullptr, true); t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end()); t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight); } @@ -1620,7 +1647,7 @@ struct AnimaConditioner : public Conditioner { struct LLMEmbedder : public Conditioner { SDVersion version; - std::shared_ptr tokenizer; + std::shared_ptr tokenizer; std::shared_ptr llm; LLMEmbedder(ggml_backend_t backend, @@ -1633,13 +1660,15 @@ struct LLMEmbedder : public Conditioner { LLM::LLMArch arch = LLM::LLMArch::QWEN2_5_VL; if (version == VERSION_FLUX2) { arch = LLM::LLMArch::MISTRAL_SMALL_3_2; + } else if (sd_version_is_ernie_image(version)) { + arch = LLM::LLMArch::MINISTRAL_3_3B; } else if (sd_version_is_z_image(version) || version == VERSION_OVIS_IMAGE || version == VERSION_FLUX2_KLEIN) { arch = LLM::LLMArch::QWEN3; } - if (arch == LLM::LLMArch::MISTRAL_SMALL_3_2) { - tokenizer = std::make_shared(); + if (arch == LLM::LLMArch::MISTRAL_SMALL_3_2 || arch == LLM::LLMArch::MINISTRAL_3_3B) { + tokenizer = std::make_shared(); } else { - tokenizer = std::make_shared(); + tokenizer = std::make_shared(); } llm = std::make_shared(arch, backend, @@ -1667,6 +1696,10 @@ struct LLMEmbedder : public Conditioner { return buffer_size; } + void set_max_graph_vram_bytes(size_t max_vram_bytes) override { + llm->set_max_graph_vram_bytes(max_vram_bytes); + } + void set_flash_attention_enabled(bool enabled) override { llm->set_flash_attention_enabled(enabled); } @@ -1677,20 +1710,24 @@ struct LLMEmbedder : public Conditioner { } } - std::tuple, std::vector> tokenize(std::string text, - const std::pair& attn_range, - size_t max_length = 0, 
- bool padding = false) { + std::tuple, std::vector, std::vector> tokenize(std::string text, + const std::pair& attn_range, + size_t min_length = 0, + size_t max_length = 100000000) { std::vector> parsed_attention; if (attn_range.first >= 0 && attn_range.second > 0) { - parsed_attention.emplace_back(text.substr(0, attn_range.first), 1.f); + if (attn_range.first > 0) { + parsed_attention.emplace_back(text.substr(0, attn_range.first), 1.f); + } if (attn_range.second - attn_range.first > 0) { auto new_parsed_attention = parse_prompt_attention(text.substr(attn_range.first, attn_range.second - attn_range.first)); parsed_attention.insert(parsed_attention.end(), new_parsed_attention.begin(), new_parsed_attention.end()); } - parsed_attention.emplace_back(text.substr(attn_range.second), 1.f); + if (attn_range.second < text.size()) { + parsed_attention.emplace_back(text.substr(attn_range.second), 1.f); + } } else { parsed_attention.emplace_back(text, 1.f); } @@ -1710,39 +1747,34 @@ struct LLMEmbedder : public Conditioner { for (const auto& item : parsed_attention) { const std::string& curr_text = item.first; float curr_weight = item.second; - std::vector curr_tokens = tokenizer->tokenize(curr_text, nullptr); + std::vector curr_tokens = tokenizer->encode(curr_text, nullptr); tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end()); weights.insert(weights.end(), curr_tokens.size(), curr_weight); } - tokenizer->pad_tokens(tokens, weights, max_length, padding); + std::vector mask; + tokenizer->pad_tokens(tokens, &weights, &mask, min_length, max_length); // for (int i = 0; i < tokens.size(); i++) { // std::cout << tokens[i] << ":" << weights[i] << ", " << i << std::endl; // } // std::cout << std::endl; - return {tokens, weights}; + return {tokens, weights, mask}; } sd::Tensor encode_prompt(int n_threads, const std::string prompt, const std::pair& prompt_attn_range, - int max_length, int min_length, + int hidden_states_min_length, const std::vector>>& image_embeds, 
const std::set& out_layers, int prompt_template_encode_start_idx) { - auto tokens_and_weights = tokenize(prompt, prompt_attn_range); - auto& tokens = std::get<0>(tokens_and_weights); - auto& weights = std::get<1>(tokens_and_weights); - std::vector mask; - - if (max_length > 0 && tokens.size() < max_length) { - mask.insert(mask.end(), tokens.size(), 1.f); - mask.insert(mask.end(), max_length - tokens.size(), 0.f); - tokenizer->pad_tokens(tokens, weights, max_length, true); - } + auto tokens_weights_mask = tokenize(prompt, prompt_attn_range, min_length); + auto& tokens = std::get<0>(tokens_weights_mask); + auto& weights = std::get<1>(tokens_weights_mask); + auto& mask = std::get<2>(tokens_weights_mask); sd::Tensor input_ids({static_cast(tokens.size())}, tokens); sd::Tensor attention_mask; @@ -1769,9 +1801,9 @@ struct LLMEmbedder : public Conditioner { GGML_ASSERT(hidden_states.shape()[1] > prompt_template_encode_start_idx); int64_t zero_pad_len = 0; - if (min_length > 0) { - if (hidden_states.shape()[1] - prompt_template_encode_start_idx < min_length) { - zero_pad_len = min_length - hidden_states.shape()[1] + prompt_template_encode_start_idx; + if (hidden_states_min_length > 0) { + if (hidden_states.shape()[1] - prompt_template_encode_start_idx < hidden_states_min_length) { + zero_pad_len = hidden_states_min_length - hidden_states.shape()[1] + prompt_template_encode_start_idx; } } @@ -1798,8 +1830,8 @@ struct LLMEmbedder : public Conditioner { std::vector> extra_prompts_attn_range; std::vector>> image_embeds; int prompt_template_encode_start_idx = 34; - int max_length = 0; // pad tokens - int min_length = 0; // zero pad hidden_states + int min_length = 0; // pad tokens + int hidden_states_min_length = 0; // zero pad hidden_states std::set out_layers; int64_t t0 = ggml_time_ms(); @@ -1874,7 +1906,7 @@ struct LLMEmbedder : public Conditioner { } } else if (version == VERSION_FLUX2) { prompt_template_encode_start_idx = 0; - min_length = 512; + hidden_states_min_length = 
512; out_layers = {10, 20, 30}; prompt = "[SYSTEM_PROMPT]You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object\nattribution and actions without speculation.[/SYSTEM_PROMPT][INST]"; @@ -1884,6 +1916,13 @@ struct LLMEmbedder : public Conditioner { prompt_attn_range.second = static_cast(prompt.size()); prompt += "[/INST]"; + } else if (sd_version_is_ernie_image(version)) { + prompt_template_encode_start_idx = 0; + out_layers = {25}; // -2 + + prompt_attn_range.first = 0; + prompt += conditioner_params.text; + prompt_attn_range.second = static_cast(prompt.size()); } else if (sd_version_is_z_image(version)) { prompt_template_encode_start_idx = 0; out_layers = {35}; // -2 @@ -1907,7 +1946,7 @@ struct LLMEmbedder : public Conditioner { } } else if (version == VERSION_FLUX2_KLEIN) { prompt_template_encode_start_idx = 0; - max_length = 512; + min_length = 512; out_layers = {9, 18, 27}; prompt = "<|im_start|>user\n"; @@ -1919,7 +1958,7 @@ struct LLMEmbedder : public Conditioner { prompt += "<|im_end|>\n<|im_start|>assistant\n\n\n\n\n"; } else if (version == VERSION_OVIS_IMAGE) { prompt_template_encode_start_idx = 28; - max_length = prompt_template_encode_start_idx + 256; + min_length = prompt_template_encode_start_idx + 256; prompt = "<|im_start|>user\nDescribe the image by detailing the color, quantity, text, shape, size, texture, spatial relationships of the objects and background:"; @@ -1935,8 +1974,8 @@ struct LLMEmbedder : public Conditioner { auto hidden_states = encode_prompt(n_threads, prompt, prompt_attn_range, - max_length, min_length, + hidden_states_min_length, image_embeds, out_layers, prompt_template_encode_start_idx); @@ -1945,8 +1984,8 @@ struct LLMEmbedder : public Conditioner { auto extra_hidden_states = encode_prompt(n_threads, extra_prompts[i], extra_prompts_attn_range[i], - max_length, min_length, + hidden_states_min_length, image_embeds, out_layers, prompt_template_encode_start_idx); 
diff --git a/src/convert.cpp b/src/convert.cpp new file mode 100644 index 000000000..7cae8df0f --- /dev/null +++ b/src/convert.cpp @@ -0,0 +1,138 @@ +#include +#include +#include +#include + +#include "model.h" +#include "model_io/gguf_io.h" +#include "model_io/safetensors_io.h" +#include "util.h" + +#include "ggml-cpu.h" + +static ggml_type get_export_tensor_type(ModelLoader& model_loader, + const TensorStorage& tensor_storage, + ggml_type type, + const TensorTypeRules& tensor_type_rules) { + const std::string& name = tensor_storage.name; + ggml_type tensor_type = tensor_storage.type; + ggml_type dst_type = type; + + for (const auto& tensor_type_rule : tensor_type_rules) { + std::regex pattern(tensor_type_rule.first); + if (std::regex_search(name, pattern)) { + dst_type = tensor_type_rule.second; + break; + } + } + + if (model_loader.tensor_should_be_converted(tensor_storage, dst_type)) { + tensor_type = dst_type; + } + + return tensor_type; +} + +static bool load_tensors_for_export(ModelLoader& model_loader, + ggml_context* ggml_ctx, + ggml_type type, + const TensorTypeRules& tensor_type_rules, + std::vector& tensors) { + std::mutex tensor_mutex; + auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool { + const std::string& name = tensor_storage.name; + ggml_type tensor_type = get_export_tensor_type(model_loader, tensor_storage, type, tensor_type_rules); + + std::lock_guard lock(tensor_mutex); + ggml_tensor* tensor = ggml_new_tensor(ggml_ctx, tensor_type, tensor_storage.n_dims, tensor_storage.ne); + if (tensor == nullptr) { + LOG_ERROR("ggml_new_tensor failed"); + return false; + } + ggml_set_name(tensor, name.c_str()); + + if (!tensor->data) { + GGML_ASSERT(ggml_nelements(tensor) == 0); + // Avoid crashing writers by setting a dummy pointer for zero-sized tensors. 
+ LOG_DEBUG("setting dummy pointer for zero-sized tensor %s", name.c_str()); + tensor->data = ggml_get_mem_buffer(ggml_ctx); + } + + TensorWriteInfo write_info; + write_info.tensor = tensor; + write_info.n_dims = tensor_storage.n_dims; + for (int i = 0; i < tensor_storage.n_dims; ++i) { + write_info.ne[i] = tensor_storage.ne[i]; + } + + *dst_tensor = tensor; + tensors.push_back(std::move(write_info)); + + return true; + }; + + bool success = model_loader.load_tensors(on_new_tensor_cb); + LOG_INFO("load tensors done"); + return success; +} + +bool convert(const char* input_path, + const char* vae_path, + const char* output_path, + sd_type_t output_type, + const char* tensor_type_rules, + bool convert_name) { + ModelLoader model_loader; + + if (!model_loader.init_from_file(input_path)) { + LOG_ERROR("init model loader from file failed: '%s'", input_path); + return false; + } + + if (vae_path != nullptr && strlen(vae_path) > 0) { + if (!model_loader.init_from_file(vae_path, "vae.")) { + LOG_ERROR("init model loader from file failed: '%s'", vae_path); + return false; + } + } + if (convert_name) { + model_loader.convert_tensors_name(); + } + + ggml_type type = (ggml_type)output_type; + bool output_is_safetensors = ends_with(output_path, ".safetensors"); + TensorTypeRules type_rules = parse_tensor_type_rules(tensor_type_rules); + + auto backend = ggml_backend_cpu_init(); + size_t mem_size = 1 * 1024 * 1024; // for padding + mem_size += model_loader.get_tensor_storage_map().size() * ggml_tensor_overhead(); + mem_size += model_loader.get_params_mem_size(backend, type); + LOG_INFO("model tensors mem size: %.2fMB", mem_size / 1024.f / 1024.f); + ggml_context* ggml_ctx = ggml_init({mem_size, nullptr, false}); + + if (ggml_ctx == nullptr) { + LOG_ERROR("ggml_init failed for converter"); + ggml_backend_free(backend); + return false; + } + + std::vector tensors; + bool success = load_tensors_for_export(model_loader, ggml_ctx, type, type_rules, tensors); + 
ggml_backend_free(backend); + + std::string error; + if (success) { + if (output_is_safetensors) { + success = write_safetensors_file(output_path, tensors, &error); + } else { + success = write_gguf_file(output_path, tensors, &error); + } + } + + if (!success && !error.empty()) { + LOG_ERROR("%s", error.c_str()); + } + + ggml_free(ggml_ctx); + return success; +} diff --git a/src/denoiser.hpp b/src/denoiser.hpp index 59b8c41b9..831da2580 100644 --- a/src/denoiser.hpp +++ b/src/denoiser.hpp @@ -658,32 +658,22 @@ inline float time_snr_shift(float alpha, float t) { } struct DiscreteFlowDenoiser : public Denoiser { - float sigmas[TIMESTEPS]; float shift = 3.0f; - float sigma_data = 1.0f; - DiscreteFlowDenoiser(float shift = 3.0f) { set_shift(shift); } - void set_parameters() { - for (int i = 1; i < TIMESTEPS + 1; i++) { - sigmas[i - 1] = t_to_sigma(static_cast(i)); - } - } - void set_shift(float shift) { this->shift = shift; - set_parameters(); } float sigma_min() override { - return sigmas[0]; + return t_to_sigma(0); } float sigma_max() override { - return sigmas[TIMESTEPS - 1]; + return t_to_sigma(TIMESTEPS - 1); } float sigma_to_t(float sigma) override { @@ -818,6 +808,18 @@ static std::tuple get_ancestral_step_flow(float sigma_from, return {sigma_down, sigma_up, alpha_scale}; } +static std::tuple get_ancestral_step(float sigma_from, + float sigma_to, + float eta, + bool is_flow_denoiser) { + if (is_flow_denoiser) { + return get_ancestral_step_flow(sigma_from, sigma_to, eta); + } else { + auto [sigma_down, sigma_up] = get_ancestral_step(sigma_from, sigma_to, eta); + return {sigma_down, sigma_up, 1.0f}; + } +} + static sd::Tensor sample_euler_ancestral(denoise_cb_t model, sd::Tensor x, const std::vector& sigmas, @@ -963,8 +965,9 @@ static sd::Tensor sample_dpmpp_2s_ancestral(denoise_cb_t model, float t_next = t_fn(sigma_down); float h = t_next - t; float s = t + 0.5f * h; - sd::Tensor x2 = (sigma_fn(s) / sigma_fn(t)) * x - (exp(-h * 0.5f) - 1) * denoised; - auto 
denoised2_opt = model(x2, sigmas[i + 1], i + 1); + float sigma_s = sigma_fn(s); + sd::Tensor x2 = (sigma_s / sigma_fn(t)) * x - (exp(-h * 0.5f) - 1) * denoised; + auto denoised2_opt = model(x2, sigma_s, i + 1); if (denoised2_opt.empty()) { return {}; } @@ -979,6 +982,100 @@ static sd::Tensor sample_dpmpp_2s_ancestral(denoise_cb_t model, return x; } +static sd::Tensor sample_dpmpp_2s_ancestral_flow(denoise_cb_t model, + sd::Tensor x, + const std::vector& sigmas, + std::shared_ptr rng, + float eta = 1.0f) { + int steps = static_cast(sigmas.size()) - 1; + for (int i = 0; i < steps; i++) { + float sigma = sigmas[i]; + float sigma_to = sigmas[i + 1]; + + bool opt_first_step = (1.0 - sigma < 1e-6); + + auto denoised_opt = model(x, sigma, (opt_first_step ? 1 : -1) * (i + 1)); + if (denoised_opt.empty()) { + return {}; + } + sd::Tensor denoised = std::move(denoised_opt); + + if (sigma_to == 0.0f) { + // Euler method (final step, no noise) + // sigma_to == 0 --> sigma_down = 0, so: + // x + d * (sigma_down - sigma) + // = x + ((x - denoised) / sigma) * (sigma_down - sigma) + // = x + ((x - denoised) / sigma) * ( 0 - sigma) + // = x + ((x - denoised) ) * -1 + // = x -x + denoised + x = denoised; + + } else { + auto [sigma_down, sigma_up, alpha_scale] = get_ancestral_step_flow(sigma, sigma_to, eta); + sd::Tensor D_i; + + if (opt_first_step) { + // the reformulated exp_s calc already accounts for this, but we can avoid + // a redundant model call for the typical sigma 1 at the first step: + // exp_s = sqrt((1-sigma)/sigma * (1-sigma_down)/sigma_down) + // = sqrt((1- 1)/ 1 * (1-sigma_down)/sigma_down) + // = 0 + // so sigma_s = 1 = sigma, and sigma_s_i_ratio = sigma_s / sigma = 1 + // u = (x*sigma_s_i_ratio)+(denoised*(1.0f-sigma_s_i_ratio)) + // = (x*1)+(denoised*0) = x + // so D_i = model(u, sigma_s, i + 1) + // = model(x, sigma, i + 1) + // = denoised + D_i = denoised; + + } else { + float sigma_s; + + // ref implementation would be: + // auto lambda_fn = [](float sigma) -> 
float { + // return std::log((1.0f - sigma) / sigma); }; + // auto sigma_fn = [](float lbda) -> float { + // return 1.0f / (std::exp(lbda) + 1.0f); }; + // t_i = lambda_fn(sigma); + // t_down = lambda_fn(sigma_down); + // float r = 0.5f; + // h = t_down - t_i; + // s = t_i + r * h; + // sigma_s = sigma_fn(s); + + // assuming r is constant, we sidestep the singularity at sigma -> 1 by: + // s = 0.5 * (lambda_fn(sigma) + lambda_fn(sigma_down)) + // = 0.5 * (log((1-sigma)/sigma) + log((1-sigma_down)/sigma_down)) + // = 0.5 * log(((1-sigma)/sigma) * ((1-sigma_down)/sigma_down)) + // = log(sqrt (((1-sigma)/sigma) * ((1-sigma_down)/sigma_down))) + // so exp(s) = sqrt((1-sigma)/sigma * (1-sigma_down)/sigma_down) + // and sigma_s = sigma_fn(s) = 1.0f / (exp(s) + 1.0f) + + float exp_s = std::sqrt(((1 - sigma) / sigma) * ((1 - sigma_down) / sigma_down)); + sigma_s = 1.0f / (exp_s + 1.0f); + + float sigma_s_i_ratio = sigma_s / sigma; + sd::Tensor u = (x * sigma_s_i_ratio) + (denoised * (1.0f - sigma_s_i_ratio)); + + auto denoised2_opt = model(u, sigma_s, i + 1); + if (denoised2_opt.empty()) { + return {}; + } + D_i = std::move(denoised2_opt); + } + + float sigma_down_i_ratio = sigma_down / sigma; + x = (x * sigma_down_i_ratio) + (D_i * (1.0f - sigma_down_i_ratio)); + + if (sigma_to > 0.0f && eta > 0.0f) { + x = alpha_scale * x + sd::Tensor::randn_like(x, rng) * sigma_up; + } + } + } + + return x; +} + static sd::Tensor sample_dpmpp_2m(denoise_cb_t model, sd::Tensor x, const std::vector& sigmas) { @@ -1050,7 +1147,8 @@ static sd::Tensor sample_dpmpp_2m_v2(denoise_cb_t model, static sd::Tensor sample_lcm(denoise_cb_t model, sd::Tensor x, const std::vector& sigmas, - std::shared_ptr rng) { + std::shared_ptr rng, + bool is_flow_denoiser) { int steps = static_cast(sigmas.size()) - 1; for (int i = 0; i < steps; i++) { auto denoised_opt = model(x, sigmas[i], i + 1); @@ -1059,6 +1157,9 @@ static sd::Tensor sample_lcm(denoise_cb_t model, } x = std::move(denoised_opt); if (sigmas[i + 
1] > 0) { + if (is_flow_denoiser) { + x *= (1 - sigmas[i + 1]); + } x += sd::Tensor::randn_like(x, rng) * sigmas[i + 1]; } } @@ -1158,6 +1259,7 @@ static sd::Tensor sample_res_multistep(denoise_cb_t model, sd::Tensor x, const std::vector& sigmas, std::shared_ptr rng, + bool is_flow_denoiser, float eta) { sd::Tensor old_denoised = x; bool have_old_sigma = false; @@ -1189,7 +1291,8 @@ static sd::Tensor sample_res_multistep(denoise_cb_t model, float sigma_from = sigmas[i]; float sigma_to = sigmas[i + 1]; - auto [sigma_down, sigma_up] = get_ancestral_step(sigma_from, sigma_to, eta); + + auto [sigma_down, sigma_up, alpha_scale] = get_ancestral_step(sigma_from, sigma_to, eta, is_flow_denoiser); if (sigma_down == 0.0f || !have_old_sigma) { x += ((x - denoised) / sigma_from) * (sigma_down - sigma_from); @@ -1216,7 +1319,10 @@ static sd::Tensor sample_res_multistep(denoise_cb_t model, x = sigma_fn(h) * x + h * (b1 * denoised + b2 * old_denoised); } - if (sigmas[i + 1] > 0 && sigma_up > 0.0f) { + if (sigma_to > 0.0f && sigma_up > 0.0f) { + if (is_flow_denoiser) { + x *= alpha_scale; + } x += sd::Tensor::randn_like(x, rng) * sigma_up; } @@ -1231,6 +1337,7 @@ static sd::Tensor sample_res_2s(denoise_cb_t model, sd::Tensor x, const std::vector& sigmas, std::shared_ptr rng, + bool is_flow_denoiser, float eta) { const float c2 = 0.5f; auto t_fn = [](float sigma) -> float { return -logf(sigma); }; @@ -1259,7 +1366,7 @@ static sd::Tensor sample_res_2s(denoise_cb_t model, } sd::Tensor denoised = std::move(denoised_opt); - auto [sigma_down, sigma_up] = get_ancestral_step(sigma_from, sigma_to, eta); + auto [sigma_down, sigma_up, alpha_scale] = get_ancestral_step(sigma_from, sigma_to, eta, is_flow_denoiser); sd::Tensor x0 = x; if (sigma_down == 0.0f || sigma_from == 0.0f) { @@ -1288,44 +1395,159 @@ static sd::Tensor sample_res_2s(denoise_cb_t model, x = x0 + h * (b1 * eps1 + b2 * eps2); } - if (sigmas[i + 1] > 0 && sigma_up > 0.0f) { + if (sigma_to > 0.0f && sigma_up > 0.0f) { + if 
(is_flow_denoiser) { + x *= alpha_scale; + } x += sd::Tensor::randn_like(x, rng) * sigma_up; } } return x; } +static sd::Tensor sample_er_sde(denoise_cb_t model, + sd::Tensor x, + std::vector sigmas, + std::shared_ptr rng, + bool is_flow_denoiser, + float eta) { + constexpr int max_stage = 3; + constexpr int num_integration_points = 200; + constexpr float num_integration_points_f = 200.0f; + float s_noise = eta; + + auto er_sde_flow_sigma = [](float sigma) -> float { + sigma = std::max(sigma, 1e-6f); + sigma = std::min(sigma, 1.0f - 1e-4f); + return sigma; + }; + + auto sigma_to_er_sde_lambda = [&](float sigma, bool is_flow_denoiser) -> float { + if (is_flow_denoiser) { + sigma = er_sde_flow_sigma(sigma); + return sigma / std::max(1.0f - sigma, 1e-6f); + } + return std::max(sigma, 1e-6f); + }; + + auto sigma_to_er_sde_alpha = [&](float sigma, bool is_flow_denoiser) -> float { + if (is_flow_denoiser) { + sigma = er_sde_flow_sigma(sigma); + return 1.0f - sigma; + } + return 1.0f; + }; + + auto er_sde_noise_scaler = [](float x) -> float { + x = std::max(x, 0.0f); + return x * (std::exp(std::pow(x, 0.3f)) + 10.0f); + }; + + if (is_flow_denoiser) { + for (size_t i = 0; i + 1 < sigmas.size(); ++i) { + if (sigmas[i] > 1.0f) { + sigmas[i] = er_sde_flow_sigma(sigmas[i]); + } + } + } + + std::vector er_lambdas(sigmas.size(), 0.0f); + for (size_t i = 0; i < sigmas.size(); ++i) { + er_lambdas[i] = sigma_to_er_sde_lambda(sigmas[i], is_flow_denoiser); + } + + sd::Tensor old_denoised = x; + sd::Tensor old_denoised_d = x; + bool have_old_denoised = false; + bool have_old_denoised_d = false; + + int steps = static_cast(sigmas.size()) - 1; + for (int i = 0; i < steps; i++) { + sd::Tensor denoised = model(x, sigmas[i], i + 1); + if (denoised.empty()) { + return {}; + } + + int stage_used = std::min(max_stage, i + 1); + + if (sigmas[i + 1] == 0.0f) { + x = denoised; + } else { + float er_lambda_s = er_lambdas[i]; + float er_lambda_t = er_lambdas[i + 1]; + float alpha_s = 
sigma_to_er_sde_alpha(sigmas[i], is_flow_denoiser); + float alpha_t = sigma_to_er_sde_alpha(sigmas[i + 1], is_flow_denoiser); + float scaled_s = er_sde_noise_scaler(er_lambda_s); + float scaled_t = er_sde_noise_scaler(er_lambda_t); + float r_alpha = alpha_s > 0.0f ? alpha_t / alpha_s : 0.0f; + float r = scaled_s > 0.0f ? scaled_t / scaled_s : 0.0f; + + x = r_alpha * r * x + alpha_t * (1.0f - r) * denoised; + + if (stage_used >= 2 && have_old_denoised) { + float dt = er_lambda_t - er_lambda_s; + float lambda_step_size = -dt / num_integration_points_f; + float s = 0.0f; + float s_u = 0.0f; + + for (int p = 0; p < num_integration_points; ++p) { + float lambda_pos = er_lambda_t + p * lambda_step_size; + float scaled_pos = er_sde_noise_scaler(lambda_pos); + if (scaled_pos <= 0.0f) { + continue; + } + + s += 1.0f / scaled_pos; + if (stage_used >= 3 && have_old_denoised_d) { + s_u += (lambda_pos - er_lambda_s) / scaled_pos; + } + } + + s *= lambda_step_size; + + float denom_d = er_lambda_s - er_lambdas[i - 1]; + if (std::fabs(denom_d) > 1e-12f) { + float coeff_d = alpha_t * (dt + s * scaled_t); + sd::Tensor denoised_d = (denoised - old_denoised) / denom_d; + x += coeff_d * denoised_d; + + if (stage_used >= 3 && have_old_denoised_d) { + float denom_u = (er_lambda_s - er_lambdas[i - 2]) * 0.5f; + if (std::fabs(denom_u) > 1e-12f) { + s_u *= lambda_step_size; + float coeff_u = alpha_t * (0.5f * dt * dt + s_u * scaled_t); + sd::Tensor denoised_u = (denoised_d - old_denoised_d) / denom_u; + x += coeff_u * denoised_u; + } + } + + old_denoised_d = denoised_d; + have_old_denoised_d = true; + } + } + + float noise_scale_sq = er_lambda_t * er_lambda_t - er_lambda_s * er_lambda_s * r * r; + if (s_noise > 0.0f && noise_scale_sq > 0.0f) { + float noise_scale = alpha_t * std::sqrt(std::max(noise_scale_sq, 0.0f)); + x += sd::Tensor::randn_like(x, rng) * noise_scale; + } + } + + old_denoised = denoised; + have_old_denoised = true; + } + return x; +} + static sd::Tensor 
sample_ddim_trailing(denoise_cb_t model, sd::Tensor x, const std::vector& sigmas, std::shared_ptr rng, float eta) { - float beta_start = 0.00085f; - float beta_end = 0.0120f; - std::vector alphas_cumprod(TIMESTEPS); - std::vector compvis_sigmas(TIMESTEPS); - for (int i = 0; i < TIMESTEPS; i++) { - alphas_cumprod[i] = - (i == 0 ? 1.0f : alphas_cumprod[i - 1]) * - (1.0f - - std::pow(sqrtf(beta_start) + - (sqrtf(beta_end) - sqrtf(beta_start)) * - ((float)i / (TIMESTEPS - 1)), - 2)); - compvis_sigmas[i] = - std::sqrt((1 - alphas_cumprod[i]) / alphas_cumprod[i]); - } - int steps = static_cast(sigmas.size()) - 1; for (int i = 0; i < steps; i++) { - int timestep = static_cast(roundf(TIMESTEPS - i * ((float)TIMESTEPS / steps))) - 1; - int prev_timestep = timestep - TIMESTEPS / steps; - float sigma = static_cast(compvis_sigmas[timestep]); - if (i == 0) { - x *= std::sqrt(sigma * sigma + 1) / sigma; - } else { - x *= std::sqrt(sigma * sigma + 1); - } + float sigma = sigmas[i]; + float sigma_to = sigmas[i + 1]; auto model_output_opt = model(x, sigma, i + 1); if (model_output_opt.empty()) { @@ -1334,8 +1556,8 @@ static sd::Tensor sample_ddim_trailing(denoise_cb_t model, sd::Tensor model_output = std::move(model_output_opt); model_output = (x - model_output) * (1.0f / sigma); - float alpha_prod_t = static_cast(alphas_cumprod[timestep]); - float alpha_prod_t_prev = static_cast(prev_timestep >= 0 ? 
alphas_cumprod[prev_timestep] : alphas_cumprod[0]); + float alpha_prod_t = 1.0f / (sigma * sigma + 1.0f); + float alpha_prod_t_prev = 1.0f / (sigma_to * sigma_to + 1.0f); float beta_prod_t = 1.0f - alpha_prod_t; sd::Tensor pred_original_sample = ((x / std::sqrt(sigma * sigma + 1)) - @@ -1347,11 +1569,11 @@ static sd::Tensor sample_ddim_trailing(denoise_cb_t model, (1.0f - alpha_prod_t / alpha_prod_t_prev); float std_dev_t = eta * std::sqrt(variance); - x = std::sqrt(alpha_prod_t_prev) * pred_original_sample + - std::sqrt(1.0f - alpha_prod_t_prev - std::pow(std_dev_t, 2)) * model_output; + x = pred_original_sample + + std::sqrt((1.0f - alpha_prod_t_prev - std::pow(std_dev_t, 2)) / alpha_prod_t_prev) * model_output; if (eta > 0) { - x += std_dev_t * sd::Tensor::randn_like(x, rng); + x += std_dev_t / std::sqrt(alpha_prod_t_prev) * sd::Tensor::randn_like(x, rng); } } return x; @@ -1378,19 +1600,26 @@ static sd::Tensor sample_tcd(denoise_cb_t model, std::sqrt((1 - alphas_cumprod[i]) / alphas_cumprod[i]); } - int original_steps = 50; - int steps = static_cast(sigmas.size()) - 1; + auto get_timestep_from_sigma = [&](float s) -> int { + auto it = std::lower_bound(compvis_sigmas.begin(), compvis_sigmas.end(), s); + if (it == compvis_sigmas.begin()) + return 0; + if (it == compvis_sigmas.end()) + return TIMESTEPS - 1; + int idx_high = static_cast(std::distance(compvis_sigmas.begin(), it)); + int idx_low = idx_high - 1; + if (std::abs(compvis_sigmas[idx_high] - s) < std::abs(compvis_sigmas[idx_low] - s)) { + return idx_high; + } + return idx_low; + }; + + int steps = static_cast(sigmas.size()) - 1; for (int i = 0; i < steps; i++) { - int timestep = TIMESTEPS - 1 - (TIMESTEPS / original_steps) * (int)floor(i * ((float)original_steps / steps)); - int prev_timestep = i >= steps - 1 ? 
0 : TIMESTEPS - 1 - (TIMESTEPS / original_steps) * (int)floor((i + 1) * ((float)original_steps / steps)); + float sigma_to = sigmas[i + 1]; + int prev_timestep = get_timestep_from_sigma(sigma_to); int timestep_s = (int)floor((1 - eta) * prev_timestep); - float sigma = static_cast(compvis_sigmas[timestep]); - - if (i == 0) { - x *= std::sqrt(sigma * sigma + 1) / sigma; - } else { - x *= std::sqrt(sigma * sigma + 1); - } + float sigma = sigmas[i]; auto model_output_opt = model(x, sigma, i + 1); if (model_output_opt.empty()) { @@ -1399,9 +1628,9 @@ static sd::Tensor sample_tcd(denoise_cb_t model, sd::Tensor model_output = std::move(model_output_opt); model_output = (x - model_output) * (1.0f / sigma); - float alpha_prod_t = static_cast(alphas_cumprod[timestep]); + float alpha_prod_t = 1.0f / (sigma * sigma + 1.0f); float beta_prod_t = 1.0f - alpha_prod_t; - float alpha_prod_t_prev = static_cast(prev_timestep >= 0 ? alphas_cumprod[prev_timestep] : alphas_cumprod[0]); + float alpha_prod_t_prev = 1.0f / (sigma_to * sigma_to + 1.0f); float alpha_prod_s = static_cast(alphas_cumprod[timestep_s]); float beta_prod_s = 1.0f - alpha_prod_s; @@ -1409,12 +1638,12 @@ static sd::Tensor sample_tcd(denoise_cb_t model, std::sqrt(beta_prod_t) * model_output) * (1.0f / std::sqrt(alpha_prod_t)); - x = std::sqrt(alpha_prod_s) * pred_original_sample + - std::sqrt(beta_prod_s) * model_output; + x = std::sqrt(alpha_prod_s / alpha_prod_t_prev) * pred_original_sample + + std::sqrt(beta_prod_s / alpha_prod_t_prev) * model_output; - if (eta > 0 && i != steps - 1) { + if (eta > 0 && sigma_to > 0.0f) { x = std::sqrt(alpha_prod_t_prev / alpha_prod_s) * x + - std::sqrt(1.0f - alpha_prod_t_prev / alpha_prod_s) * sd::Tensor::randn_like(x, rng); + std::sqrt(1.0f / alpha_prod_t_prev - 1.0f / alpha_prod_s) * sd::Tensor::randn_like(x, rng); } } return x; @@ -1441,21 +1670,26 @@ static sd::Tensor sample_k_diffusion(sample_method_t method, case DPM2_SAMPLE_METHOD: return sample_dpm2(model, std::move(x), 
sigmas); case DPMPP2S_A_SAMPLE_METHOD: - return sample_dpmpp_2s_ancestral(model, std::move(x), sigmas, rng, eta); + if (is_flow_denoiser) + return sample_dpmpp_2s_ancestral_flow(model, std::move(x), sigmas, rng, eta); + else + return sample_dpmpp_2s_ancestral(model, std::move(x), sigmas, rng, eta); case DPMPP2M_SAMPLE_METHOD: return sample_dpmpp_2m(model, std::move(x), sigmas); case DPMPP2Mv2_SAMPLE_METHOD: return sample_dpmpp_2m_v2(model, std::move(x), sigmas); case LCM_SAMPLE_METHOD: - return sample_lcm(model, std::move(x), sigmas, rng); + return sample_lcm(model, std::move(x), sigmas, rng, is_flow_denoiser); case IPNDM_SAMPLE_METHOD: return sample_ipndm(model, std::move(x), sigmas); case IPNDM_V_SAMPLE_METHOD: return sample_ipndm_v(model, std::move(x), sigmas); case RES_MULTISTEP_SAMPLE_METHOD: - return sample_res_multistep(model, std::move(x), sigmas, rng, eta); + return sample_res_multistep(model, std::move(x), sigmas, rng, is_flow_denoiser, eta); case RES_2S_SAMPLE_METHOD: - return sample_res_2s(model, std::move(x), sigmas, rng, eta); + return sample_res_2s(model, std::move(x), sigmas, rng, is_flow_denoiser, eta); + case ER_SDE_SAMPLE_METHOD: + return sample_er_sde(model, std::move(x), sigmas, rng, is_flow_denoiser, eta); case DDIM_TRAILING_SAMPLE_METHOD: return sample_ddim_trailing(model, std::move(x), sigmas, rng, eta); case TCD_SAMPLE_METHOD: diff --git a/src/diffusion_model.hpp b/src/diffusion_model.hpp index eb0debffc..1a202a1a7 100644 --- a/src/diffusion_model.hpp +++ b/src/diffusion_model.hpp @@ -3,6 +3,7 @@ #include #include "anima.hpp" +#include "ernie_image.hpp" #include "flux.hpp" #include "mmdit.hpp" #include "qwen_image.hpp" @@ -48,6 +49,7 @@ struct DiffusionModel { virtual void set_weight_adapter(const std::shared_ptr& adapter){}; virtual int64_t get_adm_in_channels() = 0; virtual void set_flash_attention_enabled(bool enabled) = 0; + virtual void set_max_graph_vram_bytes(size_t max_vram_bytes) = 0; virtual void set_circular_axes(bool circular_x, 
bool circular_y) = 0; }; @@ -97,6 +99,10 @@ struct UNetModel : public DiffusionModel { unet.set_flash_attention_enabled(enabled); } + void set_max_graph_vram_bytes(size_t max_vram_bytes) override { + unet.set_max_graph_vram_bytes(max_vram_bytes); + } + void set_circular_axes(bool circular_x, bool circular_y) override { unet.set_circular_axes(circular_x, circular_y); } @@ -163,6 +169,10 @@ struct MMDiTModel : public DiffusionModel { mmdit.set_flash_attention_enabled(enabled); } + void set_max_graph_vram_bytes(size_t max_vram_bytes) override { + mmdit.set_max_graph_vram_bytes(max_vram_bytes); + } + void set_circular_axes(bool circular_x, bool circular_y) override { mmdit.set_circular_axes(circular_x, circular_y); } @@ -228,6 +238,10 @@ struct FluxModel : public DiffusionModel { flux.set_flash_attention_enabled(enabled); } + void set_max_graph_vram_bytes(size_t max_vram_bytes) override { + flux.set_max_graph_vram_bytes(max_vram_bytes); + } + void set_circular_axes(bool circular_x, bool circular_y) override { flux.set_circular_axes(circular_x, circular_y); } @@ -298,6 +312,10 @@ struct AnimaModel : public DiffusionModel { anima.set_flash_attention_enabled(enabled); } + void set_max_graph_vram_bytes(size_t max_vram_bytes) override { + anima.set_max_graph_vram_bytes(max_vram_bytes); + } + void set_circular_axes(bool circular_x, bool circular_y) override { anima.set_circular_axes(circular_x, circular_y); } @@ -363,6 +381,10 @@ struct WanModel : public DiffusionModel { wan.set_flash_attention_enabled(enabled); } + void set_max_graph_vram_bytes(size_t max_vram_bytes) override { + wan.set_max_graph_vram_bytes(max_vram_bytes); + } + void set_circular_axes(bool circular_x, bool circular_y) override { wan.set_circular_axes(circular_x, circular_y); } @@ -432,6 +454,10 @@ struct QwenImageModel : public DiffusionModel { qwen_image.set_flash_attention_enabled(enabled); } + void set_max_graph_vram_bytes(size_t max_vram_bytes) override { + 
qwen_image.set_max_graph_vram_bytes(max_vram_bytes); + } + void set_circular_axes(bool circular_x, bool circular_y) override { qwen_image.set_circular_axes(circular_x, circular_y); } @@ -498,6 +524,10 @@ struct ZImageModel : public DiffusionModel { z_image.set_flash_attention_enabled(enabled); } + void set_max_graph_vram_bytes(size_t max_vram_bytes) override { + z_image.set_max_graph_vram_bytes(max_vram_bytes); + } + void set_circular_axes(bool circular_x, bool circular_y) override { z_image.set_circular_axes(circular_x, circular_y); } @@ -516,4 +546,70 @@ struct ZImageModel : public DiffusionModel { } }; +struct ErnieImageModel : public DiffusionModel { + std::string prefix; + ErnieImage::ErnieImageRunner ernie_image; + + ErnieImageModel(ggml_backend_t backend, + bool offload_params_to_cpu, + const String2TensorStorage& tensor_storage_map = {}, + const std::string prefix = "model.diffusion_model") + : prefix(prefix), ernie_image(backend, offload_params_to_cpu, tensor_storage_map, prefix) { + } + + std::string get_desc() override { + return ernie_image.get_desc(); + } + + void alloc_params_buffer() override { + ernie_image.alloc_params_buffer(); + } + + void free_params_buffer() override { + ernie_image.free_params_buffer(); + } + + void free_compute_buffer() override { + ernie_image.free_compute_buffer(); + } + + void get_param_tensors(std::map& tensors) override { + ernie_image.get_param_tensors(tensors, prefix); + } + + size_t get_params_buffer_size() override { + return ernie_image.get_params_buffer_size(); + } + + void set_weight_adapter(const std::shared_ptr& adapter) override { + ernie_image.set_weight_adapter(adapter); + } + + int64_t get_adm_in_channels() override { + return 768; + } + + void set_flash_attention_enabled(bool enabled) { + ernie_image.set_flash_attention_enabled(enabled); + } + + void set_max_graph_vram_bytes(size_t max_vram_bytes) override { + ernie_image.set_max_graph_vram_bytes(max_vram_bytes); + } + + void set_circular_axes(bool 
circular_x, bool circular_y) override { + ernie_image.set_circular_axes(circular_x, circular_y); + } + + sd::Tensor compute(int n_threads, + const DiffusionParams& diffusion_params) override { + GGML_ASSERT(diffusion_params.x != nullptr); + GGML_ASSERT(diffusion_params.timesteps != nullptr); + return ernie_image.compute(n_threads, + *diffusion_params.x, + *diffusion_params.timesteps, + tensor_or_empty(diffusion_params.context)); + } +}; + #endif diff --git a/src/ernie_image.hpp b/src/ernie_image.hpp new file mode 100644 index 000000000..931794f1a --- /dev/null +++ b/src/ernie_image.hpp @@ -0,0 +1,441 @@ +#ifndef __SD_ERNIE_IMAGE_HPP__ +#define __SD_ERNIE_IMAGE_HPP__ + +#include +#include + +#include "common_dit.hpp" +#include "flux.hpp" +#include "qwen_image.hpp" +#include "rope.hpp" + +namespace ErnieImage { + constexpr int ERNIE_IMAGE_GRAPH_SIZE = 40960; + + __STATIC_INLINE__ ggml_tensor* timestep_embedding_sin_cos(ggml_context* ctx, + ggml_tensor* timesteps, + int dim, + int max_period = 10000) { + auto emb = ggml_ext_timestep_embedding(ctx, timesteps, dim, max_period, 1.0f); + int64_t half = dim / 2; + auto cos_part = ggml_view_2d(ctx, emb, half, emb->ne[1], emb->nb[1], 0); + auto sin_part = ggml_view_2d(ctx, emb, half, emb->ne[1], emb->nb[1], half * emb->nb[0]); + auto sin_first = ggml_concat(ctx, sin_part, cos_part, 0); + return sin_first; + } + + __STATIC_INLINE__ ggml_tensor* apply_rotary_emb(ggml_context* ctx, ggml_tensor* x, ggml_tensor* pe) { + // x: [N, S, heads, head_dim] + // pe: [2, S, 1, head_dim], stored as ggml [head_dim, 1, S, 2]. + int64_t head_dim = x->ne[0]; + int64_t heads = x->ne[1]; + int64_t S = x->ne[2]; + int64_t N = x->ne[3]; + int64_t rot_dim = pe->ne[0]; + GGML_ASSERT(rot_dim <= head_dim); + GGML_ASSERT(rot_dim % 2 == 0); + GGML_ASSERT(pe->ne[1] == 1 && pe->ne[2] == S && pe->ne[3] == 2); + + x = ggml_cont(ctx, x); + auto x_rot = ggml_ext_slice(ctx, x, 0, 0, rot_dim, false); + auto x_pass = rot_dim < head_dim ? 
ggml_ext_slice(ctx, x, 0, rot_dim, head_dim, false) : nullptr; + + int64_t half = rot_dim / 2; + auto x1 = ggml_view_4d(ctx, x_rot, half, heads, S, N, x_rot->nb[1], x_rot->nb[2], x_rot->nb[3], 0); + auto x2 = ggml_view_4d(ctx, x_rot, half, heads, S, N, x_rot->nb[1], x_rot->nb[2], x_rot->nb[3], half * x_rot->nb[0]); + x1 = ggml_cont(ctx, x1); + x2 = ggml_cont(ctx, x2); + auto rotated = ggml_concat(ctx, ggml_neg(ctx, x2), x1, 0); + + auto cos_emb = ggml_ext_slice(ctx, pe, 3, 0, 1, false); + auto sin_emb = ggml_ext_slice(ctx, pe, 3, 1, 2, false); + + auto out = ggml_add(ctx, ggml_mul(ctx, x_rot, cos_emb), ggml_mul(ctx, rotated, sin_emb)); + if (x_pass != nullptr) { + out = ggml_concat(ctx, out, x_pass, 0); + } + return out; + } + + struct ErnieImageAttention : public GGMLBlock { + int64_t num_heads; + int64_t head_dim; + + ErnieImageAttention(int64_t query_dim, + int64_t heads, + int64_t dim_head, + float eps = 1e-6f) + : num_heads(heads), head_dim(dim_head) { + int64_t inner_dim = heads * dim_head; + blocks["to_q"] = std::make_shared(query_dim, inner_dim, false); + blocks["to_k"] = std::make_shared(query_dim, inner_dim, false); + blocks["to_v"] = std::make_shared(query_dim, inner_dim, false); + blocks["norm_q"] = std::make_shared(dim_head, eps); + blocks["norm_k"] = std::make_shared(dim_head, eps); + blocks["to_out.0"] = std::make_shared(inner_dim, query_dim, false); + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, + ggml_tensor* x, + ggml_tensor* pe, + ggml_tensor* attention_mask = nullptr) { + // x: [N, S, hidden_size] + // pe: [S, head_dim/2, 2, 2], generated in image-token-first order. 
+ auto to_q = std::dynamic_pointer_cast(blocks["to_q"]); + auto to_k = std::dynamic_pointer_cast(blocks["to_k"]); + auto to_v = std::dynamic_pointer_cast(blocks["to_v"]); + auto norm_q = std::dynamic_pointer_cast(blocks["norm_q"]); + auto norm_k = std::dynamic_pointer_cast(blocks["norm_k"]); + auto to_out_0 = std::dynamic_pointer_cast(blocks["to_out.0"]); + + int64_t S = x->ne[1]; + int64_t N = x->ne[2]; + + auto q = to_q->forward(ctx, x); + auto k = to_k->forward(ctx, x); + auto v = to_v->forward(ctx, x); + + q = ggml_reshape_4d(ctx->ggml_ctx, q, head_dim, num_heads, S, N); // [N, S, heads, head_dim] + k = ggml_reshape_4d(ctx->ggml_ctx, k, head_dim, num_heads, S, N); // [N, S, heads, head_dim] + v = ggml_reshape_4d(ctx->ggml_ctx, v, head_dim, num_heads, S, N); // [N, S, heads, head_dim] + + q = norm_q->forward(ctx, q); + k = norm_k->forward(ctx, k); + + q = apply_rotary_emb(ctx->ggml_ctx, q, pe); + k = apply_rotary_emb(ctx->ggml_ctx, k, pe); + + q = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, q, 0, 2, 1, 3)); // [N, heads, S, head_dim] + q = ggml_reshape_3d(ctx->ggml_ctx, q, q->ne[0], q->ne[1], q->ne[2] * q->ne[3]); + + k = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, k, 0, 2, 1, 3)); // [N, heads, S, head_dim] + k = ggml_reshape_3d(ctx->ggml_ctx, k, k->ne[0], k->ne[1], k->ne[2] * k->ne[3]); + + x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, attention_mask, true, ctx->flash_attn_enabled); // [N, S, hidden_size] + x = to_out_0->forward(ctx, x); + return x; + } + }; + + struct ErnieImageFeedForward : public GGMLBlock { + public: + ErnieImageFeedForward(int64_t hidden_size, int64_t ffn_hidden_size) { + blocks["gate_proj"] = std::make_shared(hidden_size, ffn_hidden_size, false); + blocks["up_proj"] = std::make_shared(hidden_size, ffn_hidden_size, false); + blocks["linear_fc2"] = std::make_shared(ffn_hidden_size, hidden_size, false); + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) { + auto gate_proj = 
std::dynamic_pointer_cast(blocks["gate_proj"]); + auto up_proj = std::dynamic_pointer_cast(blocks["up_proj"]); + auto linear_fc2 = std::dynamic_pointer_cast(blocks["linear_fc2"]); + + auto gate = gate_proj->forward(ctx, x); + gate = ggml_ext_gelu(ctx->ggml_ctx, gate); + x = up_proj->forward(ctx, x); + x = ggml_mul(ctx->ggml_ctx, x, gate); + x = linear_fc2->forward(ctx, x); + return x; + } + }; + + struct ErnieImageSharedAdaLNBlock : public GGMLBlock { + public: + ErnieImageSharedAdaLNBlock(int64_t hidden_size, + int64_t num_heads, + int64_t ffn_hidden_size, + float eps = 1e-6f) { + blocks["adaLN_sa_ln"] = std::make_shared(hidden_size, eps); + blocks["self_attention"] = std::make_shared(hidden_size, + num_heads, + hidden_size / num_heads, + eps); + blocks["adaLN_mlp_ln"] = std::make_shared(hidden_size, eps); + blocks["mlp"] = std::make_shared(hidden_size, ffn_hidden_size); + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, + ggml_tensor* x, + ggml_tensor* pe, + const std::vector& temb, + ggml_tensor* attention_mask = nullptr) { + // x: [N, image_tokens + text_tokens, hidden_size] + auto adaLN_sa_ln = std::dynamic_pointer_cast(blocks["adaLN_sa_ln"]); + auto self_attention = std::dynamic_pointer_cast(blocks["self_attention"]); + auto adaLN_mlp_ln = std::dynamic_pointer_cast(blocks["adaLN_mlp_ln"]); + auto mlp = std::dynamic_pointer_cast(blocks["mlp"]); + + auto shift_msa = temb[0]; + auto scale_msa = temb[1]; + auto gate_msa = temb[2]; + auto shift_mlp = temb[3]; + auto scale_mlp = temb[4]; + auto gate_mlp = temb[5]; + + auto residual = x; + x = adaLN_sa_ln->forward(ctx, x); + x = Flux::modulate(ctx->ggml_ctx, x, shift_msa, scale_msa, true); + auto attn_out = self_attention->forward(ctx, x, pe, attention_mask); + x = ggml_add(ctx->ggml_ctx, residual, ggml_mul(ctx->ggml_ctx, attn_out, gate_msa)); + + residual = x; + x = adaLN_mlp_ln->forward(ctx, x); + x = Flux::modulate(ctx->ggml_ctx, x, shift_mlp, scale_mlp, true); + x = ggml_add(ctx->ggml_ctx, residual, 
ggml_mul(ctx->ggml_ctx, mlp->forward(ctx, x), gate_mlp)); + return x; + } + }; + + struct ErnieImageAdaLNContinuous : public GGMLBlock { + public: + ErnieImageAdaLNContinuous(int64_t hidden_size, float eps = 1e-6f) { + blocks["norm"] = std::make_shared(hidden_size, eps, false); + blocks["linear"] = std::make_shared(hidden_size, hidden_size * 2, true); + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x, ggml_tensor* conditioning) { + auto norm = std::dynamic_pointer_cast(blocks["norm"]); + auto linear = std::dynamic_pointer_cast(blocks["linear"]); + + auto mods = ggml_ext_chunk(ctx->ggml_ctx, linear->forward(ctx, conditioning), 2, 0); + auto scale = mods[0]; + auto shift = mods[1]; + + x = norm->forward(ctx, x); + x = Flux::modulate(ctx->ggml_ctx, x, shift, scale); + return x; + } + }; + + struct ErnieImageParams { + int64_t hidden_size = 4096; + int64_t num_heads = 32; + int64_t num_layers = 36; + int64_t ffn_hidden_size = 12288; + int64_t in_channels = 128; + int64_t out_channels = 128; + int patch_size = 1; + int64_t text_in_dim = 3072; + int theta = 256; + std::vector axes_dim = {32, 48, 48}; + int axes_dim_sum = 128; + float eps = 1e-6f; + }; + + class ErnieImageModel : public GGMLBlock { + public: + ErnieImageParams params; + + ErnieImageModel() = default; + ErnieImageModel(ErnieImageParams params) + : params(params) { + blocks["x_embedder.proj"] = std::make_shared(params.in_channels, + params.hidden_size, + std::pair{params.patch_size, params.patch_size}, + std::pair{params.patch_size, params.patch_size}, + std::pair{0, 0}, + std::pair{1, 1}, + true); + if (params.text_in_dim != params.hidden_size) { + blocks["text_proj"] = std::make_shared(params.text_in_dim, params.hidden_size, false); + } + blocks["time_embedding"] = std::make_shared(params.hidden_size, params.hidden_size); + blocks["adaLN_modulation.1"] = std::make_shared(params.hidden_size, 6 * params.hidden_size, true); + + for (int i = 0; i < params.num_layers; i++) { + 
blocks["layers." + std::to_string(i)] = std::make_shared(params.hidden_size, + params.num_heads, + params.ffn_hidden_size, + params.eps); + } + + blocks["final_norm"] = std::make_shared(params.hidden_size, params.eps); + blocks["final_linear"] = std::make_shared(params.hidden_size, + params.patch_size * params.patch_size * params.out_channels, + true); + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, + ggml_tensor* x, + ggml_tensor* timestep, + ggml_tensor* context, + ggml_tensor* pe) { + // x: [N, C, H, W] + // context: [N, text_tokens, 3072] + // pe: [image_tokens + text_tokens, head_dim/2, 2, 2] + GGML_ASSERT(context != nullptr); + GGML_ASSERT(x->ne[1] % params.patch_size == 0 && x->ne[0] % params.patch_size == 0); + + int64_t W = x->ne[0]; + int64_t H = x->ne[1]; + int64_t Hp = H / params.patch_size; + int64_t Wp = W / params.patch_size; + int64_t n_img = Hp * Wp; + int64_t N = x->ne[3]; + + auto x_embedder_proj = std::dynamic_pointer_cast(blocks["x_embedder.proj"]); + auto time_embedding = std::dynamic_pointer_cast(blocks["time_embedding"]); + auto adaLN_mod = std::dynamic_pointer_cast(blocks["adaLN_modulation.1"]); + auto final_norm = std::dynamic_pointer_cast(blocks["final_norm"]); + auto final_linear = std::dynamic_pointer_cast(blocks["final_linear"]); + + auto img = x_embedder_proj->forward(ctx, x); // [N, hidden_size, Hp, Wp] + img = ggml_reshape_3d(ctx->ggml_ctx, img, img->ne[0] * img->ne[1], img->ne[2], N); // [N, hidden_size, image_tokens] + img = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, img, 1, 0, 2, 3)); // [N, image_tokens, hidden_size] + + auto txt = context; + auto text_proj = std::dynamic_pointer_cast(blocks["text_proj"]); + if (text_proj) { + txt = text_proj->forward(ctx, txt); + } + + auto hidden_states = ggml_concat(ctx->ggml_ctx, img, txt, 1); // [N, image_tokens + text_tokens, hidden_size] + + auto sample = timestep_embedding_sin_cos(ctx->ggml_ctx, timestep, static_cast(params.hidden_size)); + auto c = 
time_embedding->forward(ctx, sample); // [N, hidden_size] + + auto mod_params = adaLN_mod->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, 6 * hidden_size] + sd::ggml_graph_cut::mark_graph_cut(hidden_states, "ernie_image.prelude", "hidden_states"); + // sd::ggml_graph_cut::mark_graph_cut(mod_params, "ernie_image.prelude", "mod_params"); + auto chunks = ggml_ext_chunk(ctx->ggml_ctx, mod_params, 6, 0); + std::vector temb; + temb.reserve(6); + for (auto chunk : chunks) { + temb.push_back(ggml_reshape_3d(ctx->ggml_ctx, chunk, chunk->ne[0], 1, chunk->ne[1])); // [N, 1, hidden_size] + } + + for (int i = 0; i < params.num_layers; i++) { + auto layer = std::dynamic_pointer_cast(blocks["layers." + std::to_string(i)]); + hidden_states = layer->forward(ctx, hidden_states, pe, temb); + sd::ggml_graph_cut::mark_graph_cut(hidden_states, "ernie_image.layers." + std::to_string(i), "hidden_states"); + } + + hidden_states = final_norm->forward(ctx, hidden_states, c); + hidden_states = final_linear->forward(ctx, hidden_states); // [N, image_tokens, p*p*out_channels] + auto patches = ggml_ext_slice(ctx->ggml_ctx, hidden_states, 1, 0, n_img); // [N, image_tokens, hidden_size] + + auto out = DiT::unpatchify(ctx->ggml_ctx, + patches, + Hp, + Wp, + params.patch_size, + params.patch_size, + false); // [N, out_channels, H, W] + return out; + } + }; + + struct ErnieImageRunner : public GGMLRunner { + ErnieImageParams ernie_params; + ErnieImageModel ernie_image; + std::vector pe_vec; + + ErnieImageRunner(ggml_backend_t backend, + bool offload_params_to_cpu, + const String2TensorStorage& tensor_storage_map = {}, + const std::string prefix = "") + : GGMLRunner(backend, offload_params_to_cpu) { + ernie_params.num_layers = 0; + for (const auto& [name, tensor_storage] : tensor_storage_map) { + if (!starts_with(name, prefix)) { + continue; + } + if (ends_with(name, "x_embedder.proj.weight") && tensor_storage.n_dims == 4) { + ernie_params.patch_size = static_cast(tensor_storage.ne[0]); + 
ernie_params.in_channels = tensor_storage.ne[2]; + ernie_params.hidden_size = tensor_storage.ne[3]; + } else if (ends_with(name, "text_proj.weight") && tensor_storage.n_dims == 2) { + ernie_params.text_in_dim = tensor_storage.ne[0]; + } else if (ends_with(name, "layers.0.self_attention.norm_q.weight")) { + int64_t head_dim = tensor_storage.ne[0]; + ernie_params.num_heads = ernie_params.hidden_size / head_dim; + } else if (ends_with(name, "layers.0.mlp.gate_proj.weight") && tensor_storage.n_dims == 2) { + ernie_params.ffn_hidden_size = tensor_storage.ne[1]; + } else if (ends_with(name, "final_linear.weight") && tensor_storage.n_dims == 2) { + int64_t out_dim = tensor_storage.ne[1]; + ernie_params.out_channels = out_dim / ernie_params.patch_size / ernie_params.patch_size; + } + + size_t pos = name.find("layers."); + if (pos != std::string::npos) { + std::string layer_name = name.substr(pos); + auto items = split_string(layer_name, '.'); + if (items.size() > 1) { + int block_index = atoi(items[1].c_str()); + if (block_index + 1 > ernie_params.num_layers) { + ernie_params.num_layers = block_index + 1; + } + } + } + } + if (ernie_params.num_layers == 0) { + ernie_params.num_layers = 36; + } + ernie_params.axes_dim_sum = 0; + for (int axis_dim : ernie_params.axes_dim) { + ernie_params.axes_dim_sum += axis_dim; + } + + LOG_INFO("ernie_image: layers = %" PRId64 ", hidden_size = %" PRId64 ", heads = %" PRId64 + ", ffn_hidden_size = %" PRId64 ", in_channels = %" PRId64 ", out_channels = %" PRId64, + ernie_params.num_layers, + ernie_params.hidden_size, + ernie_params.num_heads, + ernie_params.ffn_hidden_size, + ernie_params.in_channels, + ernie_params.out_channels); + + ernie_image = ErnieImageModel(ernie_params); + ernie_image.init(params_ctx, tensor_storage_map, prefix); + } + + std::string get_desc() override { + return "ernie_image"; + } + + void get_param_tensors(std::map& tensors, const std::string prefix) { + ernie_image.get_param_tensors(tensors, prefix); + } + + 
ggml_cgraph* build_graph(const sd::Tensor& x_tensor, + const sd::Tensor& timesteps_tensor, + const sd::Tensor& context_tensor) { + ggml_cgraph* gf = new_graph_custom(ERNIE_IMAGE_GRAPH_SIZE); + ggml_tensor* x = make_input(x_tensor); + ggml_tensor* timesteps = make_input(timesteps_tensor); + GGML_ASSERT(x->ne[3] == 1); + GGML_ASSERT(!context_tensor.empty()); + ggml_tensor* context = make_input(context_tensor); + + pe_vec = Rope::gen_ernie_image_pe(static_cast(x->ne[1]), + static_cast(x->ne[0]), + ernie_params.patch_size, + static_cast(x->ne[3]), + static_cast(context->ne[1]), + ernie_params.theta, + circular_y_enabled, + circular_x_enabled, + ernie_params.axes_dim); + int pos_len = static_cast(pe_vec.size() / ernie_params.axes_dim_sum / 2); + auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, ernie_params.axes_dim_sum, 1, pos_len, 2); + set_backend_tensor_data(pe, pe_vec.data()); + + auto runner_ctx = get_context(); + ggml_tensor* out = ernie_image.forward(&runner_ctx, x, timesteps, context, pe); + ggml_build_forward_expand(gf, out); + return gf; + } + + sd::Tensor compute(int n_threads, + const sd::Tensor& x, + const sd::Tensor& timesteps, + const sd::Tensor& context) { + auto get_graph = [&]() -> ggml_cgraph* { + return build_graph(x, timesteps, context); + }; + return restore_trailing_singleton_dims(GGMLRunner::compute(get_graph, n_threads, false), x.dim()); + } + }; +} // namespace ErnieImage + +#endif // __SD_ERNIE_IMAGE_HPP__ diff --git a/src/esrgan.hpp b/src/esrgan.hpp index 26c46f5b3..f84b77a29 100644 --- a/src/esrgan.hpp +++ b/src/esrgan.hpp @@ -124,27 +124,33 @@ class RRDBNet : public GGMLBlock { auto conv_hr = std::dynamic_pointer_cast(blocks["conv_hr"]); auto conv_last = std::dynamic_pointer_cast(blocks["conv_last"]); - auto feat = conv_first->forward(ctx, x); + auto feat = conv_first->forward(ctx, x); + sd::ggml_graph_cut::mark_graph_cut(feat, "esrgan.prelude", "feat"); auto body_feat = feat; for (int i = 0; i < num_block; i++) { std::string name = 
"body." + std::to_string(i); auto block = std::dynamic_pointer_cast(blocks[name]); body_feat = block->forward(ctx, body_feat); + sd::ggml_graph_cut::mark_graph_cut(body_feat, "esrgan.body." + std::to_string(i), "feat"); } body_feat = conv_body->forward(ctx, body_feat); feat = ggml_add(ctx->ggml_ctx, feat, body_feat); + sd::ggml_graph_cut::mark_graph_cut(feat, "esrgan.body.out", "feat"); // upsample if (scale >= 2) { auto conv_up1 = std::dynamic_pointer_cast(blocks["conv_up1"]); feat = lrelu(ctx, conv_up1->forward(ctx, ggml_upscale(ctx->ggml_ctx, feat, 2, GGML_SCALE_MODE_NEAREST))); + sd::ggml_graph_cut::mark_graph_cut(feat, "esrgan.up1", "feat"); if (scale == 4) { auto conv_up2 = std::dynamic_pointer_cast(blocks["conv_up2"]); feat = lrelu(ctx, conv_up2->forward(ctx, ggml_upscale(ctx->ggml_ctx, feat, 2, GGML_SCALE_MODE_NEAREST))); + sd::ggml_graph_cut::mark_graph_cut(feat, "esrgan.up2", "feat"); } } // for all scales auto out = conv_last->forward(ctx, lrelu(ctx, conv_hr->forward(ctx, feat))); + sd::ggml_graph_cut::mark_graph_cut(out, "esrgan.final", "out"); return out; } }; diff --git a/src/flux.hpp b/src/flux.hpp index e6bf002fb..732a37197 100644 --- a/src/flux.hpp +++ b/src/flux.hpp @@ -928,6 +928,9 @@ namespace Flux { } txt = txt_in->forward(ctx, txt); + sd::ggml_graph_cut::mark_graph_cut(img, "flux.prelude", "img"); + sd::ggml_graph_cut::mark_graph_cut(txt, "flux.prelude", "txt"); + sd::ggml_graph_cut::mark_graph_cut(vec, "flux.prelude", "vec"); for (int i = 0; i < params.depth; i++) { if (skip_layers.size() > 0 && std::find(skip_layers.begin(), skip_layers.end(), i) != skip_layers.end()) { @@ -939,6 +942,8 @@ namespace Flux { auto img_txt = block->forward(ctx, img, txt, vec, pe, txt_img_mask, ds_img_mods, ds_txt_mods); img = img_txt.first; // [N, n_img_token, hidden_size] txt = img_txt.second; // [N, n_txt_token, hidden_size] + sd::ggml_graph_cut::mark_graph_cut(img, "flux.double_blocks." 
+ std::to_string(i), "img"); + sd::ggml_graph_cut::mark_graph_cut(txt, "flux.double_blocks." + std::to_string(i), "txt"); } auto txt_img = ggml_concat(ctx->ggml_ctx, txt, img, 1); // [N, n_txt_token + n_img_token, hidden_size] @@ -949,6 +954,7 @@ namespace Flux { auto block = std::dynamic_pointer_cast(blocks["single_blocks." + std::to_string(i)]); txt_img = block->forward(ctx, txt_img, vec, pe, txt_img_mask, ss_mods); + sd::ggml_graph_cut::mark_graph_cut(txt_img, "flux.single_blocks." + std::to_string(i), "txt_img"); } img = ggml_view_3d(ctx->ggml_ctx, diff --git a/src/ggml_extend.hpp b/src/ggml_extend.hpp index 859270cbd..362303229 100644 --- a/src/ggml_extend.hpp +++ b/src/ggml_extend.hpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -24,32 +25,13 @@ #include "ggml-alloc.h" #include "ggml-backend.h" -#include "ggml-cpu.h" #include "ggml.h" +#include "ggml_extend_backend.hpp" +#include "ggml_graph_cut.h" #include "model.h" #include "tensor.hpp" -#ifdef SD_USE_CUDA -#include "ggml-cuda.h" -#endif - -#ifdef SD_USE_METAL -#include "ggml-metal.h" -#endif - -#ifdef SD_USE_VULKAN -#include "ggml-vulkan.h" -#endif - -#ifdef SD_USE_OPENCL -#include "ggml-opencl.h" -#endif - -#ifdef SD_USE_SYCL -#include "ggml-sycl.h" -#endif - #include "rng.hpp" #include "tensor_ggml.hpp" #include "util.h" @@ -91,6 +73,48 @@ __STATIC_INLINE__ void ggml_log_callback_default(ggml_log_level level, const cha } } +__STATIC_INLINE__ bool backend_name_exists(std::string name) { + ggml_backend_load_all_once(); + const size_t device_count = ggml_backend_dev_count(); + for (size_t i = 0; i < device_count; ++i) { + if (name == ggml_backend_dev_name(ggml_backend_dev_get(i))) { + return true; + } + } + return false; +} + +__STATIC_INLINE__ std::string sanitize_backend_name(std::string name) { + if (name == "" || backend_name_exists(name)) { + return name; + } else { + LOG_WARN("Backend %s not found, using default backend", name.c_str()); + return ""; + } +} + 
+__STATIC_INLINE__ std::string get_default_backend_name() { + ggml_backend_load_all_once(); + // should pick the same backend as ggml_backend_init_best + ggml_backend_dev_t dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU); + dev = dev ? dev : ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_IGPU); + dev = dev ? dev : ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (dev == nullptr) { + return ""; + } + return ggml_backend_dev_name(dev); +} + +__STATIC_INLINE__ ggml_backend_t init_named_backend(std::string name = "") { + ggml_backend_load_all_once(); + LOG_DEBUG("Initializing backend: %s", name.c_str()); + if (name.empty()) { + return ggml_backend_init_best(); + } else { + return ggml_backend_init_by_name(name.c_str(), nullptr); + } +} + static_assert(GGML_MAX_NAME >= 128, "GGML_MAX_NAME must be at least 128"); // n-mode tensor-matrix product @@ -1286,25 +1310,25 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_ones_like(ggml_context* ctx, return ggml_ext_ones(ctx, x->ne[0], x->ne[1], x->ne[2], x->ne[3]); } -__STATIC_INLINE__ ggml_tensor* ggml_ext_cast_f32(ggml_context* ctx, ggml_tensor* a) { -#ifdef SD_USE_VULKAN - auto zero_index = ggml_get_tensor(ctx, "ggml_runner_build_in_tensor:zero_int"); - auto out = ggml_reshape_1d(ctx, a, ggml_nelements(a)); - out = ggml_get_rows(ctx, out, zero_index); - out = ggml_reshape(ctx, out, a); - // auto out = ggml_cast(ctx, a, GGML_TYPE_F32); - return out; -#else - auto out = ggml_reshape_2d(ctx, a, 1, ggml_nelements(a)); - ggml_tensor* one = ggml_ext_ones(ctx, 1, 1, 1, 1); // [1,] - if (ggml_is_transposed(out)) { - out = ggml_mul_mat(ctx, one, out); +__STATIC_INLINE__ ggml_tensor* ggml_ext_cast_f32(ggml_context* ctx, ggml_backend_t backend, ggml_tensor* a) { + if (sd_backend_is(backend, "Vulkan")) { + auto zero_index = ggml_get_tensor(ctx, "ggml_runner_build_in_tensor:zero_int"); + auto out = ggml_reshape_1d(ctx, a, ggml_nelements(a)); + out = ggml_get_rows(ctx, out, zero_index); + out = ggml_reshape(ctx, out, 
a); + // auto out = ggml_cast(ctx, a, GGML_TYPE_F32); + return out; } else { - out = ggml_mul_mat(ctx, out, one); + auto out = ggml_reshape_2d(ctx, a, 1, ggml_nelements(a)); + ggml_tensor* one = ggml_ext_ones(ctx, 1, 1, 1, 1); // [1,] + if (ggml_is_transposed(out)) { + out = ggml_mul_mat(ctx, one, out); + } else { + out = ggml_mul_mat(ctx, out, one); + } + out = ggml_reshape(ctx, out, a); + return out; } - out = ggml_reshape(ctx, out, a); -#endif - return out; } // q: [N, L_q, C(n_head*d_head)] or [N*n_head, L_q, d_head] @@ -1496,16 +1520,14 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_group_norm(ggml_context* ctx, } __STATIC_INLINE__ void ggml_ext_backend_tensor_get_and_sync(ggml_backend_t backend, const ggml_tensor* tensor, void* data, size_t offset, size_t size) { -#if defined(SD_USE_CUDA) || defined(SD_USE_SYCL) - if (!ggml_backend_is_cpu(backend)) { + if ((sd_backend_is(backend, "ROCm") || sd_backend_is(backend, "CUDA") || sd_backend_is(backend, "SYCL")) && + !ggml_backend_is_cpu(backend)) { ggml_backend_tensor_get_async(backend, tensor, data, offset, size); ggml_backend_synchronize(backend); - } else { - ggml_backend_tensor_get(tensor, data, offset, size); + return; } -#else + ggml_backend_tensor_get(tensor, data, offset, size); -#endif } __STATIC_INLINE__ float ggml_ext_backend_tensor_get_f32(ggml_tensor* tensor) { @@ -1664,14 +1686,15 @@ struct WeightAdapter { float scale = 1.f; } conv2d; }; - virtual ggml_tensor* patch_weight(ggml_context* ctx, ggml_tensor* weight, const std::string& weight_name) = 0; + virtual ggml_tensor* patch_weight(ggml_context* ctx, ggml_backend_t backend, ggml_tensor* weight, const std::string& weight_name) = 0; virtual ggml_tensor* forward_with_lora(ggml_context* ctx, + ggml_backend_t backend, ggml_tensor* x, ggml_tensor* w, ggml_tensor* b, const std::string& prefix, - ForwardParams forward_params) = 0; - virtual size_t get_extra_graph_size() = 0; + ForwardParams forward_params) = 0; + virtual size_t get_extra_graph_size() = 0; }; 
struct GGMLRunnerContext { @@ -1687,6 +1710,8 @@ struct GGMLRunnerContext { struct GGMLRunner { protected: typedef std::function get_graph_cb_t; + using GraphCutSegment = sd::ggml_graph_cut::Segment; + using GraphCutPlan = sd::ggml_graph_cut::Plan; ggml_backend_t params_backend = nullptr; ggml_backend_t runtime_backend = nullptr; @@ -1703,6 +1728,11 @@ struct GGMLRunner { ggml_context* compute_ctx = nullptr; ggml_gallocr* compute_allocr = nullptr; + ggml_context* partial_offload_ctx = nullptr; + ggml_backend_buffer_t partial_runtime_params_buffer = nullptr; + std::vector> partial_offload_pairs; + size_t max_graph_vram_bytes = 0; + std::shared_ptr weight_adapter = nullptr; std::vector one_vec = {1.f}; @@ -1720,6 +1750,9 @@ struct GGMLRunner { bool circular_x_enabled = false; bool circular_y_enabled = false; + sd::ggml_graph_cut::PlanCache graph_cut_plan_cache_; + std::unordered_set params_tensor_set_; + template static sd::Tensor take_or_empty(std::optional> tensor) { if (!tensor.has_value()) { @@ -1754,6 +1787,7 @@ struct GGMLRunner { params_ctx = ggml_init(params); GGML_ASSERT(params_ctx != nullptr); + params_tensor_set_.clear(); if (params_backend != runtime_backend) { offload_ctx = ggml_init(params); GGML_ASSERT(offload_ctx != nullptr); @@ -1765,10 +1799,15 @@ struct GGMLRunner { ggml_free(params_ctx); params_ctx = nullptr; } + params_tensor_set_.clear(); if (offload_ctx != nullptr) { ggml_free(offload_ctx); offload_ctx = nullptr; } + if (partial_offload_ctx != nullptr) { + ggml_free(partial_offload_ctx); + partial_offload_ctx = nullptr; + } } void alloc_cache_ctx() { @@ -1803,6 +1842,17 @@ struct GGMLRunner { ggml_free(compute_ctx); compute_ctx = nullptr; } + backend_tensor_data_map.clear(); + } + + void rebuild_params_tensor_set() { + params_tensor_set_.clear(); + if (params_ctx == nullptr) { + return; + } + for (ggml_tensor* t = ggml_get_first_tensor(params_ctx); t != nullptr; t = ggml_get_next_tensor(params_ctx, t)) { + params_tensor_set_.insert(t); + } } 
void prepare_build_in_tensor_before() { @@ -1838,13 +1888,25 @@ struct GGMLRunner { return gf; } - bool alloc_compute_buffer(get_graph_cb_t get_graph) { + bool prepare_compute_graph(get_graph_cb_t get_graph, + ggml_cgraph** gf_out) { + GGML_ASSERT(gf_out != nullptr); + + reset_compute_ctx(); + ggml_cgraph* gf = get_compute_graph(get_graph); + if (gf == nullptr) { + free_compute_ctx(); + return false; + } + + *gf_out = gf; + return true; + } + + bool alloc_compute_buffer(ggml_cgraph* gf) { if (compute_allocr != nullptr) { return true; } - reset_compute_ctx(); - ggml_cgraph* gf = get_compute_graph(get_graph); - backend_tensor_data_map.clear(); compute_allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(runtime_backend)); if (!ggml_gallocr_reserve(compute_allocr, gf)) { @@ -1870,47 +1932,132 @@ struct GGMLRunner { } } - void copy_cache_tensors_to_cache_buffer() { - if (cache_tensor_map.size() == 0) { - return; + bool copy_cache_tensors_to_cache_buffer(const std::unordered_set* cache_keep_names = nullptr) { + ggml_context* old_cache_ctx = cache_ctx; + ggml_backend_buffer_t old_cache_buffer = cache_buffer; + cache_ctx = nullptr; + cache_buffer = nullptr; + std::map merged_cache_sources; + if (old_cache_ctx != nullptr) { + for (ggml_tensor* tensor = ggml_get_first_tensor(old_cache_ctx); tensor != nullptr; tensor = ggml_get_next_tensor(old_cache_ctx, tensor)) { + if (cache_keep_names != nullptr && cache_keep_names->find(tensor->name) == cache_keep_names->end()) { + continue; + } + merged_cache_sources[tensor->name] = tensor; + } } - free_cache_ctx_and_buffer(); + for (const auto& kv : cache_tensor_map) { + if (cache_keep_names != nullptr && cache_keep_names->find(kv.first) == cache_keep_names->end()) { + continue; + } + merged_cache_sources[kv.first] = kv.second; + } + cache_tensor_map.clear(); + if (merged_cache_sources.empty()) { + if (old_cache_buffer != nullptr) { + ggml_backend_buffer_free(old_cache_buffer); + } + if (old_cache_ctx != nullptr) { + 
ggml_free(old_cache_ctx); + } + return true; + } + alloc_cache_ctx(); - GGML_ASSERT(cache_buffer == nullptr); - std::map runtime_tensor_to_cache_tensor; - for (auto kv : cache_tensor_map) { - auto cache_tensor = ggml_dup_tensor(cache_ctx, kv.second); + std::vector> source_to_cache_tensors; + source_to_cache_tensors.reserve(merged_cache_sources.size()); + for (const auto& kv : merged_cache_sources) { + ggml_tensor* source_tensor = sd::ggml_graph_cut::cache_source_tensor(kv.second); + auto cache_tensor = ggml_dup_tensor(cache_ctx, source_tensor); ggml_set_name(cache_tensor, kv.first.c_str()); - runtime_tensor_to_cache_tensor[kv.second] = cache_tensor; + source_to_cache_tensors.push_back({source_tensor, cache_tensor}); } size_t num_tensors = ggml_tensor_num(cache_ctx); cache_buffer = ggml_backend_alloc_ctx_tensors(cache_ctx, runtime_backend); GGML_ASSERT(cache_buffer != nullptr); - for (auto kv : runtime_tensor_to_cache_tensor) { - ggml_backend_tensor_copy(kv.first, kv.second); + for (const auto& kv : source_to_cache_tensors) { + ggml_tensor* src = kv.first; + ggml_tensor* dst = kv.second; + ggml_backend_buffer_t src_buf = sd::ggml_graph_cut::tensor_buffer(src); + ggml_backend_buffer_t dst_buf = sd::ggml_graph_cut::tensor_buffer(dst); + if (src_buf == nullptr || dst_buf == nullptr) { + LOG_ERROR("%s cache copy tensor buffer missing: name=%s src_buffer=%p src_view_src=%p src_view_src_buffer=%p dst_buffer=%p", + get_desc().c_str(), + src && src->name[0] != '\0' ? src->name : "", + src ? src->buffer : nullptr, + src ? src->view_src : nullptr, + (src && src->view_src) ? src->view_src->buffer : nullptr, + dst ? 
dst->buffer : nullptr); + return false; + } + const bool use_staging_copy = src->view_src != nullptr || !ggml_is_contiguous(src) || src->buffer == nullptr; + if (use_staging_copy) { + std::vector host_data(ggml_nbytes(src)); + ggml_backend_tensor_get(src, host_data.data(), 0, host_data.size()); + ggml_backend_tensor_set(dst, host_data.data(), 0, host_data.size()); + } else { + ggml_backend_tensor_copy(src, dst); + } } ggml_backend_synchronize(runtime_backend); - cache_tensor_map.clear(); size_t cache_buffer_size = ggml_backend_buffer_get_size(cache_buffer); LOG_DEBUG("%s cache backend buffer size = % 6.2f MB(%s) (%i tensors)", get_desc().c_str(), cache_buffer_size / (1024.f * 1024.f), ggml_backend_is_cpu(runtime_backend) ? "RAM" : "VRAM", num_tensors); + if (old_cache_buffer != nullptr) { + ggml_backend_buffer_free(old_cache_buffer); + } + if (old_cache_ctx != nullptr) { + ggml_free(old_cache_ctx); + } + return true; } - void copy_data_to_backend_tensor() { + void copy_data_to_backend_tensor(ggml_cgraph* gf, bool clear_after_copy = true) { + GGML_ASSERT(gf != nullptr); + std::unordered_set graph_tensor_set; + const int n_leafs = sd::ggml_graph_cut::leaf_count(gf); + const int n_nodes = ggml_graph_n_nodes(gf); + graph_tensor_set.reserve(static_cast(n_leafs + n_nodes)); + for (int i = 0; i < n_leafs; ++i) { + graph_tensor_set.insert(sd::ggml_graph_cut::leaf_tensor(gf, i)); + } + for (int i = 0; i < n_nodes; ++i) { + graph_tensor_set.insert(ggml_graph_node(gf, i)); + } + for (auto& kv : backend_tensor_data_map) { auto tensor = kv.first; auto data = kv.second; + if (graph_tensor_set.find(tensor) == graph_tensor_set.end()) { + continue; + } + + ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; + if (buf == nullptr) { + LOG_WARN("%s graph exec skip tensor copy: name=%s op=%s reason=buffer_not_set data=%p view_src=%p view_src_buffer=%p", + get_desc().c_str(), + tensor && tensor->name[0] != '\0' ? tensor->name : "", + tensor ? 
ggml_op_name(tensor->op) : "", + data, + tensor ? tensor->view_src : nullptr, + (tensor && tensor->view_src) ? tensor->view_src->buffer : nullptr); + continue; + } + ggml_backend_tensor_set(tensor, data, 0, ggml_nbytes(tensor)); } - backend_tensor_data_map.clear(); + if (clear_after_copy) { + backend_tensor_data_map.clear(); + } } - bool offload_params_to_runtime_backend() { + bool offload_all_params() { + restore_partial_params(); if (params_backend == runtime_backend) { return true; } @@ -1937,6 +2084,7 @@ struct GGMLRunner { num_tensors); return false; } + ggml_backend_buffer_set_usage(runtime_params_buffer, GGML_BACKEND_BUFFER_USAGE_WEIGHTS); ggml_tensor* t = ggml_get_first_tensor(params_ctx); ggml_tensor* offload_t = ggml_get_first_tensor(offload_ctx); @@ -1966,7 +2114,85 @@ struct GGMLRunner { return true; } - void offload_params_to_params_backend() { + bool offload_partial_params(const std::vector& tensors) { + restore_partial_params(); + if (params_backend == runtime_backend) { + return true; + } + if (tensors.empty()) { + return true; + } + GGML_ASSERT(!params_on_runtime_backend); + GGML_ASSERT(partial_runtime_params_buffer == nullptr); + + std::vector unique_tensors; + std::unordered_set seen_tensors; + unique_tensors.reserve(tensors.size()); + seen_tensors.reserve(tensors.size()); + for (ggml_tensor* tensor : tensors) { + if (tensor == nullptr) { + continue; + } + if (seen_tensors.insert(tensor).second) { + unique_tensors.push_back(tensor); + } + } + if (unique_tensors.empty()) { + return true; + } + + ggml_init_params params; + params.mem_size = std::max(1, unique_tensors.size()) * ggml_tensor_overhead(); + params.mem_buffer = nullptr; + params.no_alloc = true; + + partial_offload_ctx = ggml_init(params); + GGML_ASSERT(partial_offload_ctx != nullptr); + + partial_offload_pairs.clear(); + partial_offload_pairs.reserve(unique_tensors.size()); + + for (ggml_tensor* tensor : unique_tensors) { + GGML_ASSERT(tensor->view_src == nullptr); + ggml_tensor* 
offload_tensor = ggml_dup_tensor(partial_offload_ctx, tensor); + ggml_set_name(offload_tensor, tensor->name); + partial_offload_pairs.push_back({tensor, offload_tensor}); + } + + partial_runtime_params_buffer = ggml_backend_alloc_ctx_tensors(partial_offload_ctx, runtime_backend); + if (partial_runtime_params_buffer == nullptr) { + LOG_ERROR("%s alloc partial runtime params backend buffer failed, num_tensors = %zu", + get_desc().c_str(), + partial_offload_pairs.size()); + ggml_free(partial_offload_ctx); + partial_offload_ctx = nullptr; + partial_offload_pairs.clear(); + return false; + } + ggml_backend_buffer_set_usage(partial_runtime_params_buffer, GGML_BACKEND_BUFFER_USAGE_WEIGHTS); + + for (auto& pair : partial_offload_pairs) { + ggml_tensor* tensor = pair.first; + ggml_tensor* offload_tensor = pair.second; + + ggml_backend_tensor_copy(tensor, offload_tensor); + std::swap(tensor->buffer, offload_tensor->buffer); + std::swap(tensor->data, offload_tensor->data); + std::swap(tensor->extra, offload_tensor->extra); + } + + size_t params_buffer_size = ggml_backend_buffer_get_size(partial_runtime_params_buffer); + LOG_DEBUG("%s offload partial params (%6.2f MB, %zu tensors) to runtime backend (%s)", + get_desc().c_str(), + params_buffer_size / (1024.f * 1024.f), + partial_offload_pairs.size(), + ggml_backend_name(runtime_backend)); + + return true; + } + + void restore_all_params() { + restore_partial_params(); if (!params_on_runtime_backend) { return; } @@ -1992,17 +2218,323 @@ struct GGMLRunner { params_on_runtime_backend = false; } + void restore_partial_params() { + if (partial_offload_pairs.empty()) { + if (partial_runtime_params_buffer != nullptr) { + ggml_backend_buffer_free(partial_runtime_params_buffer); + partial_runtime_params_buffer = nullptr; + } + if (partial_offload_ctx != nullptr) { + ggml_free(partial_offload_ctx); + partial_offload_ctx = nullptr; + } + return; + } + + for (auto& pair : partial_offload_pairs) { + ggml_tensor* tensor = pair.first; + 
ggml_tensor* offload_tensor = pair.second; + + tensor->buffer = offload_tensor->buffer; + tensor->data = offload_tensor->data; + tensor->extra = offload_tensor->extra; + offload_tensor->buffer = nullptr; + offload_tensor->data = nullptr; + offload_tensor->extra = nullptr; + } + + if (partial_runtime_params_buffer != nullptr) { + ggml_backend_buffer_free(partial_runtime_params_buffer); + partial_runtime_params_buffer = nullptr; + } + partial_offload_pairs.clear(); + + if (partial_offload_ctx != nullptr) { + ggml_free(partial_offload_ctx); + partial_offload_ctx = nullptr; + } + } + + bool should_use_graph_cut_segmented_compute(const GraphCutPlan& plan) { + return plan.has_cuts && + plan.valid && + max_graph_vram_bytes > 0 && + plan.segments.size() > 1 && + params_backend != runtime_backend && + !ggml_backend_is_cpu(runtime_backend); + } + + bool can_attempt_graph_cut_segmented_compute() const { + return max_graph_vram_bytes > 0 && + params_backend != runtime_backend && + !ggml_backend_is_cpu(runtime_backend); + } + + bool resolve_graph_cut_plan(ggml_cgraph* gf, + GraphCutPlan* plan_out) { + GGML_ASSERT(plan_out != nullptr); + GGML_ASSERT(gf != nullptr); + *plan_out = sd::ggml_graph_cut::resolve_plan(runtime_backend, + gf, + &graph_cut_plan_cache_, + max_graph_vram_bytes, + params_tensor_set_, + get_desc().c_str()); + return true; + } + + void reset_segment_runtime_tensors(const GraphCutSegment& segment, + ggml_cgraph* gf) { + GGML_ASSERT(gf != nullptr); + + for (const auto& input : segment.input_refs) { + ggml_tensor* input_tensor = sd::ggml_graph_cut::input_tensor(gf, input); + if (input_tensor == nullptr) { + continue; + } + switch (input.type) { + case GraphCutSegment::INPUT_PREVIOUS_CUT: + case GraphCutSegment::INPUT_EXTERNAL: + input_tensor->buffer = nullptr; + input_tensor->data = nullptr; + input_tensor->extra = nullptr; + break; + case GraphCutSegment::INPUT_PARAM: + break; + } + } + + for (int node_idx : segment.internal_node_indices) { + ggml_tensor* node = 
ggml_graph_node(gf, node_idx); + if (node == nullptr) { + continue; + } + node->buffer = nullptr; + node->data = nullptr; + node->extra = nullptr; + } + } + + bool bind_segment_cached_inputs(ggml_cgraph* gf, const GraphCutSegment& segment) { + GGML_ASSERT(gf != nullptr); + for (const auto& input : segment.input_refs) { + ggml_tensor* input_tensor = sd::ggml_graph_cut::input_tensor(gf, input); + if (input_tensor == nullptr) { + continue; + } + switch (input.type) { + case GraphCutSegment::INPUT_PREVIOUS_CUT: { + ggml_tensor* cache_tensor = get_cache_tensor_by_name(input.display_name); + if (cache_tensor == nullptr) { + LOG_ERROR("%s missing graph cut cache tensor: %s", + get_desc().c_str(), + input.display_name.c_str()); + return false; + } + if (input_tensor->view_src != nullptr) { + input_tensor->view_src = cache_tensor; + input_tensor->buffer = nullptr; + input_tensor->data = cache_tensor->data == nullptr + ? nullptr + : static_cast(static_cast(cache_tensor->data) + input_tensor->view_offs); + input_tensor->extra = cache_tensor->extra; + } else { + input_tensor->buffer = cache_tensor->buffer; + input_tensor->data = cache_tensor->data; + input_tensor->extra = cache_tensor->extra; + } + for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) { + input_tensor->src[src_idx] = nullptr; + } + input_tensor->op = GGML_OP_NONE; + break; + } + case GraphCutSegment::INPUT_EXTERNAL: + case GraphCutSegment::INPUT_PARAM: + break; + } + } + return true; + } + + template + std::optional> execute_graph(ggml_cgraph* gf, + int n_threads, + bool free_compute_buffer_immediately, + const std::vector& runtime_param_tensors, + bool preserve_backend_tensor_data_map, + bool no_return = false, + const std::unordered_set* cache_keep_names = nullptr) { + int64_t t_execute_begin = ggml_time_ms(); + const bool use_partial_param_offload = !runtime_param_tensors.empty(); + int64_t t_offload_begin = ggml_time_ms(); + if (use_partial_param_offload) { + if 
(!offload_partial_params(runtime_param_tensors)) { + LOG_ERROR("%s offload partial params to runtime backend failed", get_desc().c_str()); + return std::nullopt; + } + } else { + if (!offload_all_params()) { + LOG_ERROR("%s offload params to runtime backend failed", get_desc().c_str()); + return std::nullopt; + } + } + int64_t t_offload_end = ggml_time_ms(); + + int64_t t_alloc_begin = ggml_time_ms(); + if (!alloc_compute_buffer(gf)) { + LOG_ERROR("%s alloc compute buffer failed", get_desc().c_str()); + if (use_partial_param_offload) { + restore_partial_params(); + } + return std::nullopt; + } + + if (!ggml_gallocr_alloc_graph(compute_allocr, gf)) { + LOG_ERROR("%s alloc compute graph failed", get_desc().c_str()); + if (free_compute_buffer_immediately) { + free_compute_buffer(); + } else if (use_partial_param_offload) { + restore_partial_params(); + } + return std::nullopt; + } + int64_t t_alloc_end = ggml_time_ms(); + + int64_t t_copy_begin = ggml_time_ms(); + copy_data_to_backend_tensor(gf, !preserve_backend_tensor_data_map); + int64_t t_copy_end = ggml_time_ms(); + if (ggml_backend_is_cpu(runtime_backend)) { + ggml_backend_cpu_set_n_threads(runtime_backend, n_threads); + } + + int64_t t_compute_begin = ggml_time_ms(); + ggml_status status = ggml_backend_graph_compute(runtime_backend, gf); + int64_t t_compute_end = ggml_time_ms(); + if (status != GGML_STATUS_SUCCESS) { + LOG_ERROR("%s compute failed: %s", get_desc().c_str(), ggml_status_to_string(status)); + if (free_compute_buffer_immediately) { + free_compute_buffer(); + } else if (use_partial_param_offload) { + restore_partial_params(); + } + return std::nullopt; + } + + int64_t t_cache_begin = ggml_time_ms(); + if (!copy_cache_tensors_to_cache_buffer(cache_keep_names)) { + if (free_compute_buffer_immediately) { + free_compute_buffer(); + } else if (use_partial_param_offload) { + restore_partial_params(); + } + return std::nullopt; + } + int64_t t_cache_end = ggml_time_ms(); + auto result = 
ggml_get_tensor(compute_ctx, final_result_name.c_str()); + std::optional> output; + if (!no_return) { + output = sd::make_sd_tensor_from_ggml(result); + } else { + output = sd::Tensor(); + } + + if (free_compute_buffer_immediately) { + free_compute_buffer(); + } else if (use_partial_param_offload) { + restore_partial_params(); + } + if (use_partial_param_offload) { + LOG_DEBUG("%s execute_graph timing: offload=%lld ms alloc=%lld ms copy_in=%lld ms compute=%lld ms cache=%lld ms total=%lld ms", + get_desc().c_str(), + t_offload_end - t_offload_begin, + t_alloc_end - t_alloc_begin, + t_copy_end - t_copy_begin, + t_compute_end - t_compute_begin, + t_cache_end - t_cache_begin, + ggml_time_ms() - t_execute_begin); + } + return output; + } + + template + std::optional> compute_with_graph_cuts(ggml_cgraph* gf, + const GraphCutPlan& plan, + int n_threads, + bool free_compute_buffer_immediately, + bool no_return = false) { + GGML_ASSERT(gf != nullptr); + + free_compute_buffer(); + free_cache_ctx_and_buffer(); + + std::optional> output = sd::Tensor(); + for (size_t seg_idx = 0; seg_idx < plan.segments.size(); ++seg_idx) { + int64_t t_segment_begin = ggml_time_ms(); + const auto& segment = plan.segments[seg_idx]; + auto future_cut_names = sd::ggml_graph_cut::collect_future_input_names(gf, plan, seg_idx); + LOG_DEBUG("%s graph cut executing segment %zu/%zu: %s", + get_desc().c_str(), + seg_idx + 1, + plan.segments.size(), + segment.group_name.c_str()); + + reset_segment_runtime_tensors(segment, gf); + if (!bind_segment_cached_inputs(gf, segment)) { + free_cache_ctx_and_buffer(); + free_compute_buffer(); + free_compute_ctx(); + return std::nullopt; + } + + const bool is_last_segment = seg_idx + 1 == plan.segments.size(); + if (!is_last_segment) { + for (size_t output_idx = 0; output_idx < segment.output_node_indices.size(); ++output_idx) { + ggml_tensor* output_tensor = sd::ggml_graph_cut::output_tensor(gf, segment, output_idx); + if (output_tensor != nullptr && + 
sd::ggml_graph_cut::is_graph_cut_tensor(output_tensor) && + future_cut_names.find(output_tensor->name) != future_cut_names.end()) { + cache(output_tensor->name, output_tensor); + } + } + } + + ggml_context* segment_graph_ctx = nullptr; + ggml_cgraph* segment_graph = sd::ggml_graph_cut::build_segment_graph(gf, segment, &segment_graph_ctx); + auto segment_output = execute_graph(segment_graph, + n_threads, + true, + sd::ggml_graph_cut::runtime_param_tensors(gf, segment, get_desc().c_str()), + true, + !is_last_segment || no_return, + &future_cut_names); + ggml_free(segment_graph_ctx); + if (!segment_output.has_value()) { + free_cache_ctx_and_buffer(); + free_compute_buffer(); + free_compute_ctx(); + return std::nullopt; + } + output = std::move(segment_output); + } + + backend_tensor_data_map.clear(); + free_cache_ctx_and_buffer(); + free_compute_ctx(); + return output; + } + public: virtual std::string get_desc() = 0; GGMLRunner(ggml_backend_t backend, bool offload_params_to_cpu = false) : runtime_backend(backend) { - alloc_params_ctx(); if (!ggml_backend_is_cpu(runtime_backend) && offload_params_to_cpu) { params_backend = ggml_backend_cpu_init(); } else { params_backend = runtime_backend; } + alloc_params_ctx(); } virtual ~GGMLRunner() { @@ -2042,6 +2574,8 @@ struct GGMLRunner { num_tensors); return false; } + rebuild_params_tensor_set(); + ggml_backend_buffer_set_usage(params_buffer, GGML_BACKEND_BUFFER_USAGE_WEIGHTS); size_t params_buffer_size = ggml_backend_buffer_get_size(params_buffer); LOG_DEBUG("%s params backend buffer size = % 6.2f MB(%s) (%i tensors)", get_desc().c_str(), @@ -2075,7 +2609,8 @@ struct GGMLRunner { ggml_gallocr_free(compute_allocr); compute_allocr = nullptr; } - offload_params_to_params_backend(); + restore_partial_params(); + restore_all_params(); } // do copy after alloc graph @@ -2139,41 +2674,36 @@ struct GGMLRunner { int n_threads, bool free_compute_buffer_immediately, bool no_return = false) { - if (!offload_params_to_runtime_backend()) 
{ - LOG_ERROR("%s offload params to runtime backend failed", get_desc().c_str()); - return std::nullopt; - } - if (!alloc_compute_buffer(get_graph)) { - LOG_ERROR("%s alloc compute buffer failed", get_desc().c_str()); - return std::nullopt; - } - reset_compute_ctx(); - ggml_cgraph* gf = get_compute_graph(get_graph); - if (!ggml_gallocr_alloc_graph(compute_allocr, gf)) { - LOG_ERROR("%s alloc compute graph failed", get_desc().c_str()); + ggml_cgraph* gf = nullptr; + if (!prepare_compute_graph(get_graph, &gf)) { return std::nullopt; } - copy_data_to_backend_tensor(); - if (ggml_backend_is_cpu(runtime_backend)) { - ggml_backend_cpu_set_n_threads(runtime_backend, n_threads); - } + GGML_ASSERT(gf != nullptr); - ggml_status status = ggml_backend_graph_compute(runtime_backend, gf); - if (status != GGML_STATUS_SUCCESS) { - LOG_ERROR("%s compute failed: %s", get_desc().c_str(), ggml_status_to_string(status)); - return std::nullopt; - } - copy_cache_tensors_to_cache_buffer(); - auto result = ggml_get_tensor(compute_ctx, final_result_name.c_str()); - std::optional> output; - if (!no_return) { - output = sd::make_sd_tensor_from_ggml(result); + if (can_attempt_graph_cut_segmented_compute()) { + GraphCutPlan plan; + if (!resolve_graph_cut_plan(gf, &plan)) { + free_compute_ctx(); + return std::nullopt; + } + if (should_use_graph_cut_segmented_compute(plan)) { + return compute_with_graph_cuts(gf, + plan, + n_threads, + free_compute_buffer_immediately, + no_return); + } } - - if (free_compute_buffer_immediately) { - free_compute_buffer(); + if (!alloc_compute_buffer(gf)) { + LOG_ERROR("%s alloc compute buffer failed", get_desc().c_str()); + return std::nullopt; } - return output; + return execute_graph(gf, + n_threads, + free_compute_buffer_immediately, + {}, + false, + no_return); } void set_flash_attention_enabled(bool enabled) { @@ -2192,6 +2722,18 @@ struct GGMLRunner { void set_weight_adapter(const std::shared_ptr& adapter) { weight_adapter = adapter; } + + void 
set_max_graph_vram_bytes(size_t max_vram_bytes) { + max_graph_vram_bytes = max_vram_bytes; + } + + ggml_backend_t get_runtime_backend() { + return runtime_backend; + } + + ggml_backend_t get_params_backend() { + return params_backend; + } }; class GGMLBlock { @@ -2336,6 +2878,14 @@ class Linear : public UnaryBlock { force_prec_f32(force_prec_f32), scale(scale) {} + void set_scale(float scale_) { + scale = scale_; + } + + void set_force_prec_f32(bool force_prec_f32_) { + force_prec_f32 = force_prec_f32_; + } + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) { ggml_tensor* w = params["weight"]; ggml_tensor* b = nullptr; @@ -2347,7 +2897,7 @@ class Linear : public UnaryBlock { forward_params.op_type = WeightAdapter::ForwardParams::op_type_t::OP_LINEAR; forward_params.linear.force_prec_f32 = force_prec_f32; forward_params.linear.scale = scale; - return ctx->weight_adapter->forward_with_lora(ctx->ggml_ctx, x, w, b, prefix, forward_params); + return ctx->weight_adapter->forward_with_lora(ctx->ggml_ctx, ctx->backend, x, w, b, prefix, forward_params); } return ggml_ext_linear(ctx->ggml_ctx, x, w, b, force_prec_f32, scale); } @@ -2463,7 +3013,7 @@ class Conv2d : public UnaryBlock { forward_params.conv2d.circular_x = ctx->circular_x_enabled; forward_params.conv2d.circular_y = ctx->circular_y_enabled; forward_params.conv2d.scale = scale; - return ctx->weight_adapter->forward_with_lora(ctx->ggml_ctx, x, w, b, prefix, forward_params); + return ctx->weight_adapter->forward_with_lora(ctx->ggml_ctx, ctx->backend, x, w, b, prefix, forward_params); } return ggml_ext_conv_2d(ctx->ggml_ctx, x, @@ -2527,7 +3077,7 @@ class Conv3d : public UnaryBlock { ggml_tensor* w = params["weight"]; ggml_tensor* b = nullptr; if (ctx->weight_adapter) { - w = ctx->weight_adapter->patch_weight(ctx->ggml_ctx, w, prefix + "weight"); + w = ctx->weight_adapter->patch_weight(ctx->ggml_ctx, ctx->backend, w, prefix + "weight"); if (w->type != GGML_TYPE_F16) { w = ggml_cast(ctx->ggml_ctx, w, 
GGML_TYPE_F16); } @@ -2535,7 +3085,7 @@ class Conv3d : public UnaryBlock { if (bias) { b = params["bias"]; if (ctx->weight_adapter) { - b = ctx->weight_adapter->patch_weight(ctx->ggml_ctx, b, prefix + "bias"); + b = ctx->weight_adapter->patch_weight(ctx->ggml_ctx, ctx->backend, b, prefix + "bias"); } } return ggml_ext_conv_3d(ctx->ggml_ctx, x, w, b, in_channels, @@ -2582,12 +3132,12 @@ class LayerNorm : public UnaryBlock { if (elementwise_affine) { w = params["weight"]; if (ctx->weight_adapter) { - w = ctx->weight_adapter->patch_weight(ctx->ggml_ctx, w, prefix + "weight"); + w = ctx->weight_adapter->patch_weight(ctx->ggml_ctx, ctx->backend, w, prefix + "weight"); } if (bias) { b = params["bias"]; if (ctx->weight_adapter) { - b = ctx->weight_adapter->patch_weight(ctx->ggml_ctx, b, prefix + "bias"); + b = ctx->weight_adapter->patch_weight(ctx->ggml_ctx, ctx->backend, b, prefix + "bias"); } } } @@ -2630,8 +3180,8 @@ class GroupNorm : public GGMLBlock { w = params["weight"]; b = params["bias"]; if (ctx->weight_adapter) { - w = ctx->weight_adapter->patch_weight(ctx->ggml_ctx, w, prefix + "weight"); - b = ctx->weight_adapter->patch_weight(ctx->ggml_ctx, b, prefix + "bias"); + w = ctx->weight_adapter->patch_weight(ctx->ggml_ctx, ctx->backend, w, prefix + "weight"); + b = ctx->weight_adapter->patch_weight(ctx->ggml_ctx, ctx->backend, b, prefix + "bias"); } } return ggml_ext_group_norm(ctx->ggml_ctx, x, w, b, num_groups); @@ -2665,7 +3215,7 @@ class RMSNorm : public UnaryBlock { ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) { ggml_tensor* w = params["weight"]; if (ctx->weight_adapter) { - w = ctx->weight_adapter->patch_weight(ctx->ggml_ctx, w, prefix + "weight"); + w = ctx->weight_adapter->patch_weight(ctx->ggml_ctx, ctx->backend, w, prefix + "weight"); } x = ggml_rms_norm(ctx->ggml_ctx, x, eps); x = ggml_mul_inplace(ctx->ggml_ctx, x, w); @@ -2748,6 +3298,7 @@ class MultiheadAttention : public GGMLBlock { __STATIC_INLINE__ ggml_tensor* ggml_ext_lokr_forward( 
ggml_context* ctx, + ggml_backend_t backend, ggml_tensor* h, // Input: [q, batch] or [W, H, q, batch] ggml_tensor* w1, // Outer C (Full rank) ggml_tensor* w1a, // Outer A (Low rank part 1) @@ -2758,17 +3309,17 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_lokr_forward( bool is_conv, WeightAdapter::ForwardParams::conv2d_params_t conv_params, float scale) { - GGML_ASSERT((w1 != NULL || (w1a != NULL && w1b != NULL))); - GGML_ASSERT((w2 != NULL || (w2a != NULL && w2b != NULL))); + GGML_ASSERT((w1 != nullptr || (w1a != nullptr && w1b != nullptr))); + GGML_ASSERT((w2 != nullptr || (w2a != nullptr && w2b != nullptr))); - int uq = (w1 != NULL) ? (int)w1->ne[0] : (int)w1a->ne[0]; - int up = (w1 != NULL) ? (int)w1->ne[1] : (int)w1b->ne[1]; + int uq = (w1 != nullptr) ? (int)w1->ne[0] : (int)w1a->ne[0]; + int up = (w1 != nullptr) ? (int)w1->ne[1] : (int)w1b->ne[1]; int q_actual = is_conv ? (int)h->ne[2] : (int)h->ne[0]; int vq = q_actual / uq; - int vp = (w2 != NULL) ? (is_conv ? (int)w2->ne[3] : (int)w2->ne[1]) - : (int)w2a->ne[1]; + int vp = (w2 != nullptr) ? (is_conv ? 
(int)w2->ne[3] : (int)w2->ne[1]) + : (int)w2a->ne[1]; GGML_ASSERT(q_actual == (uq * vq) && "Input dimension mismatch for LoKR split"); ggml_tensor* hb; @@ -2778,32 +3329,32 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_lokr_forward( int merge_batch_uq = batch; int merge_batch_vp = batch; -#if SD_USE_VULKAN - if (batch > 1) { - // no access to backend here, worst case is slightly worse perfs for other backends when built alongside Vulkan backend - int max_batch = 65535; - int max_batch_uq = max_batch / uq; - merge_batch_uq = 1; - for (int i = max_batch_uq; i > 0; i--) { - if (batch % i == 0) { - merge_batch_uq = i; - break; + if (sd_backend_is(backend, "Vulkan")) { + if (batch > 1) { + // no access to backend here, worst case is slightly worse perfs for other backends when built alongside Vulkan backend + int max_batch = 65535; + int max_batch_uq = max_batch / uq; + merge_batch_uq = 1; + for (int i = max_batch_uq; i > 0; i--) { + if (batch % i == 0) { + merge_batch_uq = i; + break; + } } - } - int max_batch_vp = max_batch / vp; - merge_batch_vp = 1; - for (int i = max_batch_vp; i > 0; i--) { - if (batch % i == 0) { - merge_batch_vp = i; - break; + int max_batch_vp = max_batch / vp; + merge_batch_vp = 1; + for (int i = max_batch_vp; i > 0; i--) { + if (batch % i == 0) { + merge_batch_vp = i; + break; + } } } } -#endif ggml_tensor* h_split = ggml_reshape_3d(ctx, h, vq, uq * merge_batch_uq, batch / merge_batch_uq); - if (w2 != NULL) { + if (w2 != nullptr) { hb = ggml_mul_mat(ctx, w2, h_split); } else { hb = ggml_mul_mat(ctx, w2b, ggml_mul_mat(ctx, w2a, h_split)); @@ -2816,7 +3367,7 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_lokr_forward( hb_t = ggml_reshape_3d(ctx, hb_t, uq, vp * merge_batch_vp, batch / merge_batch_vp); ggml_tensor* hc_t; - if (w1 != NULL) { + if (w1 != nullptr) { hc_t = ggml_mul_mat(ctx, w1, hb_t); } else { hc_t = ggml_mul_mat(ctx, w1b, ggml_mul_mat(ctx, w1a, hb_t)); @@ -2834,7 +3385,7 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_lokr_forward( // 1. 
Reshape input: [W, H, vq*uq, batch] -> [W, H, vq, uq * batch] ggml_tensor* h_split = ggml_reshape_4d(ctx, h, h->ne[0], h->ne[1], vq, uq * batch); - if (w2 != NULL) { + if (w2 != nullptr) { hb = ggml_ext_conv_2d(ctx, h_split, w2, nullptr, conv_params.s0, conv_params.s1, @@ -2902,7 +3453,7 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_lokr_forward( ggml_tensor* hb_merged = ggml_reshape_2d(ctx, hb, w_out * h_out * vp, uq * batch); ggml_tensor* hc_t; ggml_tensor* hb_merged_t = ggml_cont(ctx, ggml_transpose(ctx, hb_merged)); - if (w1 != NULL) { + if (w1 != nullptr) { // Would be great to be able to transpose w1 instead to avoid transposing both hb and hc hc_t = ggml_mul_mat(ctx, w1, hb_merged_t); } else { diff --git a/src/ggml_extend_backend.hpp b/src/ggml_extend_backend.hpp new file mode 100644 index 000000000..50158c883 --- /dev/null +++ b/src/ggml_extend_backend.hpp @@ -0,0 +1,298 @@ +#ifndef __GGML_EXTEND_BACKEND_HPP__ +#define __GGML_EXTEND_BACKEND_HPP__ + +#include +#include + +#include "ggml-backend.h" +#include "ggml.h" + +#ifndef __STATIC_INLINE__ +#define __STATIC_INLINE__ static inline +#endif + +inline void ggml_backend_load_all_once() { + // If the registry already has devices and the CPU backend is present, + // assume either static registration or explicit host-side preloading has + // completed and avoid rescanning the default paths. + if (ggml_backend_dev_count() > 0 && ggml_backend_reg_by_name("CPU") != nullptr) { + return; + } + // In dynamic-backend mode the backend modules are discovered at runtime, + // so we must load them before asking for the CPU backend or its proc table. + // If the host preloaded only a subset of backends, allow one default-path + // scan so missing modules can still be discovered. 
+ static std::once_flag once; + std::call_once(once, []() { + if (ggml_backend_dev_count() > 0 && ggml_backend_reg_by_name("CPU") != nullptr) { + return; + } + ggml_backend_load_all(); + }); +} + +// Do not gate this branch on GGML_CPU or GGML_CPU_ALL_VARIANTS: +// those are CMake options used to configure ggml itself, but they are not +// exported as PUBLIC compile definitions to stable-diffusion in backend-DL mode. +// In practice, this target can reliably see GGML_BACKEND_DL, but not whether +// the CPU backend was compiled as a loadable module. We therefore use runtime +// backend discovery instead of compile-time assumptions. + +__STATIC_INLINE__ ggml_backend_reg_t ggml_backend_cpu_reg() { + ggml_backend_reg_t reg = ggml_backend_reg_by_name("CPU"); + if (reg != nullptr) { + return reg; + } + + ggml_backend_load_all_once(); + return ggml_backend_reg_by_name("CPU"); +} + +__STATIC_INLINE__ ggml_backend_reg_t ggml_backend_reg_from_backend(ggml_backend_t backend) { + if (backend != nullptr) { + ggml_backend_dev_t device = ggml_backend_get_device(backend); + if (device != nullptr) { + return ggml_backend_dev_backend_reg(device); + } + } + + return ggml_backend_cpu_reg(); +} + +__STATIC_INLINE__ ggml_backend_t ggml_backend_cpu_init() { + ggml_backend_t backend = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr); + if (backend != nullptr) { + return backend; + } + + ggml_backend_load_all_once(); + return ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr); +} + +__STATIC_INLINE__ bool ggml_backend_is_cpu(ggml_backend_t backend) { + if (backend == nullptr) { + return false; + } + + ggml_backend_dev_t device = ggml_backend_get_device(backend); + if (device != nullptr) { + return ggml_backend_dev_type(device) == GGML_BACKEND_DEVICE_TYPE_CPU; + } + + const char* backend_name = ggml_backend_name(backend); + return backend_name != nullptr && std::strcmp(backend_name, "CPU") == 0; +} + +__STATIC_INLINE__ void 
ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads) { + ggml_backend_reg_t reg = ggml_backend_reg_from_backend(backend_cpu); + if (reg == nullptr) { + return; + } + + auto fn = reinterpret_cast(ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads")); + if (fn != nullptr) { + fn(backend_cpu, n_threads); + } +} + +using __ggml_backend_cpu_set_threadpool_t = void (*)(ggml_backend_t backend_cpu, ggml_threadpool_t threadpool); + +__STATIC_INLINE__ void ggml_backend_cpu_set_threadpool(ggml_backend_t backend_cpu, ggml_threadpool_t threadpool) { + ggml_backend_reg_t reg = ggml_backend_reg_from_backend(backend_cpu); + if (reg == nullptr) { + return; + } + + auto fn = reinterpret_cast<__ggml_backend_cpu_set_threadpool_t>(ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_set_threadpool")); + if (fn != nullptr) { + fn(backend_cpu, threadpool); + } +} + +__STATIC_INLINE__ void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void* abort_callback_data) { + ggml_backend_reg_t reg = ggml_backend_reg_from_backend(backend_cpu); + if (reg == nullptr) { + return; + } + + auto fn = reinterpret_cast(ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback")); + if (fn != nullptr) { + fn(backend_cpu, abort_callback, abort_callback_data); + } +} + +__STATIC_INLINE__ ggml_backend_buffer_t ggml_backend_tensor_buffer(const struct ggml_tensor* tensor) { + if (tensor == nullptr) { + return nullptr; + } + + return tensor->view_src ? 
tensor->view_src->buffer : tensor->buffer; +} + +__STATIC_INLINE__ bool ggml_backend_tensor_is_host_accessible(const struct ggml_tensor* tensor) { + if (tensor == nullptr || tensor->data == nullptr) { + return false; + } + + ggml_backend_buffer_t buffer = ggml_backend_tensor_buffer(tensor); + return buffer == nullptr || ggml_backend_buffer_is_host(buffer); +} + +__STATIC_INLINE__ size_t ggml_backend_tensor_offset(const struct ggml_tensor* tensor, int64_t i0, int64_t i1, int64_t i2, int64_t i3) { + return (size_t)(i0 * tensor->nb[0] + i1 * tensor->nb[1] + i2 * tensor->nb[2] + i3 * tensor->nb[3]); +} + +template +__STATIC_INLINE__ void ggml_backend_tensor_write_scalar(const struct ggml_tensor* tensor, int64_t i0, int64_t i1, int64_t i2, int64_t i3, T value) { + const size_t offset = ggml_backend_tensor_offset(tensor, i0, i1, i2, i3); + + if (ggml_backend_tensor_is_host_accessible(tensor)) { + auto* dst = reinterpret_cast(reinterpret_cast(tensor->data) + offset); + *dst = value; + return; + } + + ggml_backend_tensor_set(const_cast(tensor), &value, offset, sizeof(T)); +} + +__STATIC_INLINE__ void ggml_set_f32_nd(const struct ggml_tensor* tensor, int64_t i0, int64_t i1, int64_t i2, int64_t i3, float value) { + switch (tensor->type) { + case GGML_TYPE_I8: + ggml_backend_tensor_write_scalar(tensor, i0, i1, i2, i3, static_cast(value)); + break; + case GGML_TYPE_I16: + ggml_backend_tensor_write_scalar(tensor, i0, i1, i2, i3, static_cast(value)); + break; + case GGML_TYPE_I32: + ggml_backend_tensor_write_scalar(tensor, i0, i1, i2, i3, static_cast(value)); + break; + case GGML_TYPE_F16: + ggml_backend_tensor_write_scalar(tensor, i0, i1, i2, i3, ggml_fp32_to_fp16(value)); + break; + case GGML_TYPE_BF16: + ggml_backend_tensor_write_scalar(tensor, i0, i1, i2, i3, ggml_fp32_to_bf16(value)); + break; + case GGML_TYPE_F32: + ggml_backend_tensor_write_scalar(tensor, i0, i1, i2, i3, value); + break; + default: + GGML_ABORT("fatal error"); + } +} + +__STATIC_INLINE__ void 
ggml_set_f32_1d(const struct ggml_tensor* tensor, int i, float value) { + if (!ggml_is_contiguous(tensor)) { + int64_t id[4] = {0, 0, 0, 0}; + ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]); + ggml_set_f32_nd(tensor, id[0], id[1], id[2], id[3], value); + return; + } + + switch (tensor->type) { + case GGML_TYPE_I8: + ggml_backend_tensor_write_scalar(tensor, i, 0, 0, 0, static_cast(value)); + break; + case GGML_TYPE_I16: + ggml_backend_tensor_write_scalar(tensor, i, 0, 0, 0, static_cast(value)); + break; + case GGML_TYPE_I32: + ggml_backend_tensor_write_scalar(tensor, i, 0, 0, 0, static_cast(value)); + break; + case GGML_TYPE_F16: + ggml_backend_tensor_write_scalar(tensor, i, 0, 0, 0, ggml_fp32_to_fp16(value)); + break; + case GGML_TYPE_BF16: + ggml_backend_tensor_write_scalar(tensor, i, 0, 0, 0, ggml_fp32_to_bf16(value)); + break; + case GGML_TYPE_F32: + ggml_backend_tensor_write_scalar(tensor, i, 0, 0, 0, value); + break; + default: + GGML_ABORT("fatal error"); + } +} + +__STATIC_INLINE__ enum ggml_status ggml_graph_compute_with_ctx(struct ggml_context* ctx, struct ggml_cgraph* cgraph, int n_threads) { + (void)ctx; + + // The legacy ggml_graph_compute_with_ctx() symbol lives in ggml-cpu, but + // the backend proc table does not expose it in GGML_BACKEND_DL mode. + // Recreate the old behavior by initializing the CPU backend explicitly and + // executing the graph through the generic backend API. 
+ ggml_backend_t backend = ggml_backend_cpu_init(); + if (backend == nullptr) { + return GGML_STATUS_ALLOC_FAILED; + } + + ggml_backend_cpu_set_n_threads(backend, n_threads); + + const enum ggml_status status = ggml_backend_graph_compute(backend, cgraph); + ggml_backend_free(backend); + + return status; +} + +__STATIC_INLINE__ ggml_tensor* ggml_set_f32(struct ggml_tensor* tensor, float value) { + GGML_ASSERT(tensor != nullptr); + + if (ggml_backend_tensor_is_host_accessible(tensor) && ggml_is_contiguous(tensor)) { + const int64_t nelements = ggml_nelements(tensor); + + switch (tensor->type) { + case GGML_TYPE_I8: { + auto* data = reinterpret_cast(tensor->data); + const int8_t v = static_cast(value); + for (int64_t i = 0; i < nelements; ++i) { + data[i] = v; + } + } break; + case GGML_TYPE_I16: { + auto* data = reinterpret_cast(tensor->data); + const int16_t v = static_cast(value); + for (int64_t i = 0; i < nelements; ++i) { + data[i] = v; + } + } break; + case GGML_TYPE_I32: { + auto* data = reinterpret_cast(tensor->data); + const int32_t v = static_cast(value); + for (int64_t i = 0; i < nelements; ++i) { + data[i] = v; + } + } break; + case GGML_TYPE_F16: { + auto* data = reinterpret_cast(tensor->data); + const ggml_fp16_t v = ggml_fp32_to_fp16(value); + for (int64_t i = 0; i < nelements; ++i) { + data[i] = v; + } + } break; + case GGML_TYPE_BF16: { + auto* data = reinterpret_cast(tensor->data); + const ggml_bf16_t v = ggml_fp32_to_bf16(value); + for (int64_t i = 0; i < nelements; ++i) { + data[i] = v; + } + } break; + case GGML_TYPE_F32: { + auto* data = reinterpret_cast(tensor->data); + for (int64_t i = 0; i < nelements; ++i) { + data[i] = value; + } + } break; + default: + GGML_ABORT("fatal error"); + } + + return tensor; + } + + const int64_t nelements = ggml_nelements(tensor); + for (int64_t i = 0; i < nelements; ++i) { + ggml_set_f32_1d(tensor, static_cast(i), value); + } + + return tensor; +} + +#endif diff --git a/src/ggml_graph_cut.cpp 
b/src/ggml_graph_cut.cpp new file mode 100644 index 000000000..f206f2d2d --- /dev/null +++ b/src/ggml_graph_cut.cpp @@ -0,0 +1,676 @@ +#include "ggml_graph_cut.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "ggml-alloc.h" +#include "ggml-backend.h" +#include "util.h" + +#include "../ggml/src/ggml-impl.h" + +namespace sd::ggml_graph_cut { + + static std::string graph_cut_tensor_display_name(const ggml_tensor* tensor) { + if (tensor == nullptr) { + return ""; + } + if (tensor->name[0] != '\0') { + return tensor->name; + } + return sd_format("", (const void*)tensor); + } + + static int graph_leaf_index(ggml_cgraph* gf, const ggml_tensor* tensor) { + GGML_ASSERT(gf != nullptr); + GGML_ASSERT(tensor != nullptr); + for (int i = 0; i < gf->n_leafs; ++i) { + if (gf->leafs[i] == tensor) { + return i; + } + } + return -1; + } + + static bool is_params_tensor(const std::unordered_set& params_tensor_set, + const ggml_tensor* tensor) { + if (tensor == nullptr) { + return false; + } + return params_tensor_set.find(tensor) != params_tensor_set.end(); + } + + static Plan::InputShape input_shape(const ggml_tensor* tensor) { + Plan::InputShape shape; + if (tensor == nullptr) { + return shape; + } + shape.type = tensor->type; + for (int i = 0; i < GGML_MAX_DIMS; ++i) { + shape.ne[static_cast(i)] = tensor->ne[i]; + } + return shape; + } + + static size_t graph_cut_segment_vram_bytes(const Segment& segment) { + return segment.compute_buffer_size + + segment.input_param_bytes + + segment.input_previous_cut_bytes + + segment.output_bytes; + } + + static Segment make_segment_seed(const Plan& plan, + size_t start_segment_index, + size_t end_segment_index) { + GGML_ASSERT(start_segment_index < plan.segments.size()); + GGML_ASSERT(end_segment_index < plan.segments.size()); + GGML_ASSERT(start_segment_index <= end_segment_index); + + Segment seed; + const auto& start_segment = plan.segments[start_segment_index]; + const auto& target_segment = 
plan.segments[end_segment_index]; + std::unordered_set seen_output_node_indices; + for (size_t seg_idx = start_segment_index; seg_idx <= end_segment_index; ++seg_idx) { + for (int output_node_index : plan.segments[seg_idx].output_node_indices) { + if (seen_output_node_indices.insert(output_node_index).second) { + seed.output_node_indices.push_back(output_node_index); + } + } + } + if (start_segment_index == end_segment_index) { + seed.group_name = target_segment.group_name; + } else { + seed.group_name = sd_format("%s..%s", + start_segment.group_name.c_str(), + target_segment.group_name.c_str()); + } + return seed; + } + + static void build_segment(ggml_cgraph* gf, + Plan& plan, + Segment& segment, + const std::unordered_map& producer_index, + std::unordered_set& available_cut_output_node_indices, + ggml_backend_t backend, + const std::unordered_set& params_tensor_set, + const char* log_desc) { + std::set internal_nodes; + std::unordered_set input_seen; + std::vector input_refs; + + std::stack work_stack; + for (int output_node_index : segment.output_node_indices) { + ggml_tensor* output = ggml_graph_node(gf, output_node_index); + if (output != nullptr) { + work_stack.push(output); + } + } + + while (!work_stack.empty()) { + ggml_tensor* tensor = work_stack.top(); + work_stack.pop(); + + if (tensor == nullptr) { + continue; + } + + auto producer_it = producer_index.find(tensor); + if (producer_it == producer_index.end()) { + if (input_seen.insert(tensor).second) { + Segment::InputRef input_ref; + input_ref.type = is_params_tensor(params_tensor_set, tensor) ? 
Segment::INPUT_PARAM : Segment::INPUT_EXTERNAL; + input_ref.display_name = graph_cut_tensor_display_name(tensor); + input_ref.leaf_index = graph_leaf_index(gf, tensor); + input_refs.push_back(std::move(input_ref)); + } + continue; + } + + int node_idx = producer_it->second; + if (available_cut_output_node_indices.find(node_idx) != available_cut_output_node_indices.end()) { + if (input_seen.insert(tensor).second) { + Segment::InputRef input_ref; + input_ref.type = Segment::INPUT_PREVIOUS_CUT; + input_ref.display_name = graph_cut_tensor_display_name(tensor); + input_ref.node_index = node_idx; + input_refs.push_back(std::move(input_ref)); + } + continue; + } + + if (!internal_nodes.insert(node_idx).second) { + continue; + } + + ggml_tensor* node = ggml_graph_node(gf, node_idx); + for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) { + if (node->src[src_idx] != nullptr) { + work_stack.push(node->src[src_idx]); + } + } + } + + if (!internal_nodes.empty()) { + segment.internal_node_indices.assign(internal_nodes.begin(), internal_nodes.end()); + } + + std::sort(input_refs.begin(), + input_refs.end(), + [](const Segment::InputRef& a, const Segment::InputRef& b) { + if (a.type != b.type) { + return a.type < b.type; + } + return a.display_name < b.display_name; + }); + segment.input_refs = input_refs; + for (const auto& input : input_refs) { + ggml_tensor* current_input = input_tensor(gf, input); + size_t tensor_bytes = current_input == nullptr + ? 0 + : (input.type == Segment::INPUT_PREVIOUS_CUT + ? 
cache_tensor_bytes(current_input) + : ggml_nbytes(current_input)); + switch (input.type) { + case Segment::INPUT_PREVIOUS_CUT: + segment.input_previous_cut_bytes += tensor_bytes; + break; + case Segment::INPUT_PARAM: + segment.input_param_bytes += tensor_bytes; + break; + case Segment::INPUT_EXTERNAL: + default: + segment.input_external_bytes += tensor_bytes; + break; + } + } + for (int output_node_index : segment.output_node_indices) { + ggml_tensor* output = ggml_graph_node(gf, output_node_index); + segment.output_bytes += cache_tensor_bytes(output); + } + segment.compute_buffer_size = measure_segment_compute_buffer(backend, gf, segment, log_desc); + + for (int output_node_index : segment.output_node_indices) { + available_cut_output_node_indices.insert(output_node_index); + } + plan.segments.push_back(std::move(segment)); + } + + bool is_graph_cut_tensor(const ggml_tensor* tensor) { + if (tensor == nullptr || tensor->name[0] == '\0') { + return false; + } + return std::strncmp(tensor->name, GGML_RUNNER_CUT_PREFIX, std::strlen(GGML_RUNNER_CUT_PREFIX)) == 0; + } + + std::string make_graph_cut_name(const std::string& group, const std::string& output) { + return std::string(GGML_RUNNER_CUT_PREFIX) + group + "|" + output; + } + + void mark_graph_cut(ggml_tensor* tensor, const std::string& group, const std::string& output) { + if (tensor == nullptr) { + return; + } + auto name = make_graph_cut_name(group, output); + ggml_set_name(tensor, name.c_str()); + } + + int leaf_count(ggml_cgraph* gf) { + GGML_ASSERT(gf != nullptr); + return gf->n_leafs; + } + + ggml_tensor* leaf_tensor(ggml_cgraph* gf, int leaf_index) { + GGML_ASSERT(gf != nullptr); + if (leaf_index < 0 || leaf_index >= gf->n_leafs) { + return nullptr; + } + return gf->leafs[leaf_index]; + } + + ggml_backend_buffer_t tensor_buffer(const ggml_tensor* tensor) { + if (tensor == nullptr) { + return nullptr; + } + return tensor->view_src ? 
tensor->view_src->buffer : tensor->buffer; + } + + ggml_tensor* cache_source_tensor(ggml_tensor* tensor) { + if (tensor == nullptr) { + return nullptr; + } + return tensor->view_src ? tensor->view_src : tensor; + } + + size_t cache_tensor_bytes(const ggml_tensor* tensor) { + if (tensor == nullptr) { + return 0; + } + const ggml_tensor* cache_src = tensor->view_src ? tensor->view_src : tensor; + return ggml_nbytes(cache_src); + } + + bool plan_matches_graph(ggml_cgraph* gf, const Plan& plan) { + GGML_ASSERT(gf != nullptr); + if (ggml_graph_n_nodes(gf) != plan.n_nodes || gf->n_leafs != plan.n_leafs) { + return false; + } + for (const auto& input_shape_ref : plan.input_shapes) { + if (input_shape_ref.leaf_index < 0 || input_shape_ref.leaf_index >= gf->n_leafs) { + return false; + } + ggml_tensor* leaf = gf->leafs[input_shape_ref.leaf_index]; + if (leaf == nullptr || input_shape_ref.type != leaf->type) { + return false; + } + for (int d = 0; d < GGML_MAX_DIMS; ++d) { + if (input_shape_ref.ne[static_cast(d)] != leaf->ne[d]) { + return false; + } + } + } + return true; + } + + ggml_tensor* output_tensor(ggml_cgraph* gf, const Segment& segment, size_t output_index) { + GGML_ASSERT(gf != nullptr); + if (output_index >= segment.output_node_indices.size()) { + return nullptr; + } + int node_index = segment.output_node_indices[output_index]; + if (node_index < 0 || node_index >= ggml_graph_n_nodes(gf)) { + return nullptr; + } + return ggml_graph_node(gf, node_index); + } + + ggml_tensor* input_tensor(ggml_cgraph* gf, const Segment::InputRef& input_ref) { + GGML_ASSERT(gf != nullptr); + if (input_ref.type == Segment::INPUT_PREVIOUS_CUT) { + if (input_ref.node_index < 0 || input_ref.node_index >= ggml_graph_n_nodes(gf)) { + return nullptr; + } + return ggml_graph_node(gf, input_ref.node_index); + } + if (input_ref.leaf_index < 0 || input_ref.leaf_index >= gf->n_leafs) { + return nullptr; + } + return leaf_tensor(gf, input_ref.leaf_index); + } + + std::vector 
param_tensors(ggml_cgraph* gf, const Segment& segment) { + GGML_ASSERT(gf != nullptr); + std::vector tensors; + std::unordered_set seen_tensors; + tensors.reserve(segment.input_refs.size()); + seen_tensors.reserve(segment.input_refs.size()); + for (const auto& input_ref : segment.input_refs) { + if (input_ref.type != Segment::INPUT_PARAM) { + continue; + } + ggml_tensor* tensor = input_tensor(gf, input_ref); + if (tensor == nullptr) { + continue; + } + if (seen_tensors.insert(tensor).second) { + tensors.push_back(tensor); + } + } + return tensors; + } + + std::vector runtime_param_tensors(ggml_cgraph* gf, const Segment& segment, const char* log_desc) { + std::vector tensors = param_tensors(gf, segment); + std::vector filtered_tensors; + filtered_tensors.reserve(tensors.size()); + for (ggml_tensor* tensor : tensors) { + if (tensor_buffer(tensor) == nullptr) { + LOG_WARN("%s graph cut skipping param input without buffer: segment=%s tensor=%s", + log_desc == nullptr ? "unknown" : log_desc, + segment.group_name.c_str(), + tensor->name); + continue; + } + filtered_tensors.push_back(tensor); + } + return filtered_tensors; + } + + std::unordered_set collect_future_input_names(ggml_cgraph* gf, + const Plan& plan, + size_t current_segment_index) { + GGML_ASSERT(gf != nullptr); + std::unordered_set future_input_names; + for (size_t seg_idx = current_segment_index + 1; seg_idx < plan.segments.size(); ++seg_idx) { + const auto& segment = plan.segments[seg_idx]; + for (const auto& input_ref : segment.input_refs) { + if (input_ref.type != Segment::INPUT_PREVIOUS_CUT) { + continue; + } + ggml_tensor* current_input = input_tensor(gf, input_ref); + if (current_input != nullptr && current_input->name[0] != '\0') { + future_input_names.insert(current_input->name); + } + } + } + return future_input_names; + } + + ggml_cgraph* build_segment_graph(ggml_cgraph* gf, + const Segment& segment, + ggml_context** graph_ctx_out) { + GGML_ASSERT(gf != nullptr); + GGML_ASSERT(graph_ctx_out != 
nullptr); + + const size_t graph_size = segment.internal_node_indices.size() + segment.input_refs.size() + 8; + ggml_init_params params = { + /*.mem_size =*/ggml_graph_overhead_custom(graph_size, false) + 1024, + /*.mem_buffer =*/nullptr, + /*.no_alloc =*/true, + }; + ggml_context* graph_ctx = ggml_init(params); + GGML_ASSERT(graph_ctx != nullptr); + ggml_cgraph* segment_graph = ggml_new_graph_custom(graph_ctx, graph_size, false); + GGML_ASSERT(segment_graph != nullptr); + + for (const auto& input : segment.input_refs) { + ggml_tensor* current_input = input_tensor(gf, input); + if (current_input == nullptr) { + continue; + } + GGML_ASSERT(segment_graph->n_leafs < segment_graph->size); + segment_graph->leafs[segment_graph->n_leafs++] = current_input; + } + + for (int output_node_index : segment.output_node_indices) { + ggml_tensor* output = ggml_graph_node(gf, output_node_index); + if (output == nullptr) { + continue; + } + ggml_set_output(output); + } + for (int node_idx : segment.internal_node_indices) { + ggml_graph_add_node(segment_graph, ggml_graph_node(gf, node_idx)); + } + *graph_ctx_out = graph_ctx; + return segment_graph; + } + + size_t measure_segment_compute_buffer(ggml_backend_t backend, + ggml_cgraph* gf, + const Segment& segment, + const char* log_desc) { + GGML_ASSERT(backend != nullptr); + GGML_ASSERT(gf != nullptr); + if (segment.internal_node_indices.empty()) { + return 0; + } + + ggml_context* graph_ctx = nullptr; + ggml_cgraph* segment_graph = build_segment_graph(gf, segment, &graph_ctx); + ggml_gallocr_t allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend)); + + size_t sizes[1] = {0}; + ggml_gallocr_reserve_n_size( + allocr, + segment_graph, + nullptr, + nullptr, + sizes); + size_t buffer_size = sizes[0]; + + ggml_gallocr_free(allocr); + ggml_free(graph_ctx); + return buffer_size; + } + + Plan build_plan(ggml_backend_t backend, + ggml_cgraph* gf, + const std::unordered_set& params_tensor_set, + const char* log_desc) { + 
GGML_ASSERT(backend != nullptr); + GGML_ASSERT(gf != nullptr); + Plan plan; + plan.available = true; + const int n_nodes = ggml_graph_n_nodes(gf); + if (n_nodes <= 0) { + return plan; + } + plan.n_nodes = n_nodes; + plan.n_leafs = gf->n_leafs; + for (int i = 0; i < gf->n_leafs; ++i) { + ggml_tensor* leaf = gf->leafs[i]; + if (is_params_tensor(params_tensor_set, leaf)) { + continue; + } + auto shape = input_shape(leaf); + shape.leaf_index = i; + plan.input_shapes.push_back(shape); + } + + std::unordered_map producer_index; + producer_index.reserve(static_cast(n_nodes)); + for (int i = 0; i < n_nodes; ++i) { + producer_index[ggml_graph_node(gf, i)] = i; + } + + std::vector grouped_segments; + std::unordered_map group_to_segment; + for (int i = 0; i < n_nodes; ++i) { + ggml_tensor* node = ggml_graph_node(gf, i); + if (!is_graph_cut_tensor(node)) { + continue; + } + + plan.has_cuts = true; + std::string full_name(node->name); + std::string payload = full_name.substr(std::strlen(GGML_RUNNER_CUT_PREFIX)); + size_t sep = payload.find('|'); + std::string group = sep == std::string::npos ? 
payload : payload.substr(0, sep); + + auto it = group_to_segment.find(group); + if (it == group_to_segment.end()) { + Segment segment; + segment.group_name = group; + segment.output_node_indices.push_back(i); + group_to_segment[group] = grouped_segments.size(); + grouped_segments.push_back(std::move(segment)); + } else { + auto& segment = grouped_segments[it->second]; + segment.output_node_indices.push_back(i); + } + } + + if (!plan.has_cuts) { + return plan; + } + + std::unordered_set available_cut_output_node_indices; + available_cut_output_node_indices.reserve(static_cast(n_nodes)); + for (auto& segment : grouped_segments) { + build_segment(gf, + plan, + segment, + producer_index, + available_cut_output_node_indices, + backend, + params_tensor_set, + log_desc); + } + + ggml_tensor* final_output = ggml_graph_node(gf, -1); + if (final_output != nullptr && available_cut_output_node_indices.find(n_nodes - 1) == available_cut_output_node_indices.end()) { + Segment final_segment; + final_segment.group_name = "ggml_runner.final"; + final_segment.output_node_indices.push_back(n_nodes - 1); + build_segment(gf, + plan, + final_segment, + producer_index, + available_cut_output_node_indices, + backend, + params_tensor_set, + log_desc); + } + + return plan; + } + + Plan apply_max_vram_budget(ggml_cgraph* gf, + const Plan& base_plan, + size_t max_graph_vram_bytes, + ggml_backend_t backend, + const std::unordered_set& params_tensor_set, + const char* log_desc) { + GGML_ASSERT(backend != nullptr); + GGML_ASSERT(gf != nullptr); + int64_t t_budget_begin = ggml_time_ms(); + if (max_graph_vram_bytes == 0 || !base_plan.has_cuts || base_plan.segments.size() <= 1) { + return base_plan; + } + + const int n_nodes = ggml_graph_n_nodes(gf); + std::unordered_map producer_index; + producer_index.reserve(static_cast(n_nodes)); + for (int i = 0; i < n_nodes; ++i) { + producer_index[ggml_graph_node(gf, i)] = i; + } + + Plan merged_plan; + merged_plan.available = true; + merged_plan.has_cuts = 
base_plan.has_cuts; + merged_plan.valid = base_plan.valid; + merged_plan.n_nodes = base_plan.n_nodes; + merged_plan.n_leafs = base_plan.n_leafs; + + std::unordered_set available_cut_output_node_indices; + available_cut_output_node_indices.reserve(static_cast(n_nodes)); + + size_t start_segment_index = 0; + while (start_segment_index < base_plan.segments.size()) { + Plan single_plan; + auto single_available_cut_output_node_indices = available_cut_output_node_indices; + auto single_seed = make_segment_seed(base_plan, + start_segment_index, + start_segment_index); + build_segment(gf, + single_plan, + single_seed, + producer_index, + single_available_cut_output_node_indices, + backend, + params_tensor_set, + log_desc); + GGML_ASSERT(!single_plan.segments.empty()); + + size_t best_end_segment_index = start_segment_index; + bool can_merge_next_segment = graph_cut_segment_vram_bytes(single_plan.segments.back()) <= max_graph_vram_bytes; + + while (can_merge_next_segment && best_end_segment_index + 1 < base_plan.segments.size()) { + const size_t next_end_segment_index = best_end_segment_index + 1; + Plan candidate_plan; + auto candidate_available_cut_output_node_indices = available_cut_output_node_indices; + auto candidate_seed = make_segment_seed(base_plan, + start_segment_index, + next_end_segment_index); + build_segment(gf, + candidate_plan, + candidate_seed, + producer_index, + candidate_available_cut_output_node_indices, + backend, + params_tensor_set, + log_desc); + GGML_ASSERT(!candidate_plan.segments.empty()); + + const auto& candidate_segment = candidate_plan.segments.back(); + if (graph_cut_segment_vram_bytes(candidate_segment) > max_graph_vram_bytes) { + break; + } + + best_end_segment_index = next_end_segment_index; + } + + auto best_seed = make_segment_seed(base_plan, + start_segment_index, + best_end_segment_index); + build_segment(gf, + merged_plan, + best_seed, + producer_index, + available_cut_output_node_indices, + backend, + params_tensor_set, + 
log_desc); + start_segment_index = best_end_segment_index + 1; + } + + if (log_desc != nullptr && merged_plan.segments.size() != base_plan.segments.size()) { + LOG_INFO("%s graph cut max_vram=%.2f MB merged %zu segments -> %zu segments", + log_desc, + max_graph_vram_bytes / 1024.0 / 1024.0, + base_plan.segments.size(), + merged_plan.segments.size()); + } + + if (log_desc != nullptr) { + LOG_INFO("%s graph cut max_vram budget merge took %lld ms", + log_desc, + ggml_time_ms() - t_budget_begin); + } + + return merged_plan; + } + + Plan resolve_plan(ggml_backend_t backend, + ggml_cgraph* gf, + PlanCache* cache, + size_t max_graph_vram_bytes, + const std::unordered_set& params_tensor_set, + const char* log_desc) { + GGML_ASSERT(backend != nullptr); + GGML_ASSERT(gf != nullptr); + GGML_ASSERT(cache != nullptr); + + int64_t t_prepare_begin = ggml_time_ms(); + Plan base_plan; + int64_t t_plan_begin = ggml_time_ms(); + if (cache->graph_cut_plan.available && plan_matches_graph(gf, cache->graph_cut_plan)) { + base_plan = cache->graph_cut_plan; + } else { + base_plan = build_plan(backend, gf, params_tensor_set, log_desc); + cache->graph_cut_plan = base_plan; + cache->graph_cut_plan.available = true; + cache->budgeted_graph_cut_plan.available = false; + if (log_desc != nullptr) { + LOG_INFO("%s build cached graph cut plan done (taking %lld ms)", log_desc, ggml_time_ms() - t_plan_begin); + } + } + + Plan resolved_plan = base_plan; + if (max_graph_vram_bytes > 0 && base_plan.has_cuts) { + if (cache->budgeted_graph_cut_plan.available && + cache->budgeted_graph_cut_plan_max_vram_bytes == max_graph_vram_bytes && + plan_matches_graph(gf, cache->budgeted_graph_cut_plan)) { + resolved_plan = cache->budgeted_graph_cut_plan; + } else { + resolved_plan = apply_max_vram_budget(gf, + base_plan, + max_graph_vram_bytes, + backend, + params_tensor_set, + log_desc); + cache->budgeted_graph_cut_plan = resolved_plan; + cache->budgeted_graph_cut_plan.available = true; + 
cache->budgeted_graph_cut_plan_max_vram_bytes = max_graph_vram_bytes; + } + } + return resolved_plan; + } + +} // namespace sd::ggml_graph_cut diff --git a/src/ggml_graph_cut.h b/src/ggml_graph_cut.h new file mode 100644 index 000000000..e42859c58 --- /dev/null +++ b/src/ggml_graph_cut.h @@ -0,0 +1,104 @@ +#ifndef __SD_GGML_GRAPH_CUT_H__ +#define __SD_GGML_GRAPH_CUT_H__ + +#include +#include +#include +#include + +#include "ggml-backend.h" +#include "ggml.h" + +namespace sd::ggml_graph_cut { + + struct Segment { + enum InputType { + INPUT_EXTERNAL = 0, + INPUT_PREVIOUS_CUT, + INPUT_PARAM, + }; + + struct InputRef { + InputType type = INPUT_EXTERNAL; + std::string display_name; + int leaf_index = -1; + int node_index = -1; + }; + + size_t compute_buffer_size = 0; + size_t output_bytes = 0; + size_t input_external_bytes = 0; + size_t input_previous_cut_bytes = 0; + size_t input_param_bytes = 0; + std::string group_name; + std::vector internal_node_indices; + std::vector output_node_indices; + std::vector input_refs; + }; + + struct Plan { + struct InputShape { + int leaf_index = -1; + ggml_type type = GGML_TYPE_COUNT; + std::array ne = {0, 0, 0, 0}; + }; + + bool available = false; + bool has_cuts = false; + bool valid = true; + int n_nodes = 0; + int n_leafs = 0; + std::vector input_shapes; + std::vector segments; + }; + + struct PlanCache { + Plan graph_cut_plan; + Plan budgeted_graph_cut_plan; + size_t budgeted_graph_cut_plan_max_vram_bytes = 0; + }; + + static constexpr const char* GGML_RUNNER_CUT_PREFIX = "ggml_runner_cut:"; + + bool is_graph_cut_tensor(const ggml_tensor* tensor); + std::string make_graph_cut_name(const std::string& group, const std::string& output); + void mark_graph_cut(ggml_tensor* tensor, const std::string& group, const std::string& output); + int leaf_count(ggml_cgraph* gf); + ggml_tensor* leaf_tensor(ggml_cgraph* gf, int leaf_index); + ggml_backend_buffer_t tensor_buffer(const ggml_tensor* tensor); + ggml_tensor* 
cache_source_tensor(ggml_tensor* tensor); + size_t cache_tensor_bytes(const ggml_tensor* tensor); + bool plan_matches_graph(ggml_cgraph* gf, const Plan& plan); + ggml_tensor* output_tensor(ggml_cgraph* gf, const Segment& segment, size_t output_index); + ggml_tensor* input_tensor(ggml_cgraph* gf, const Segment::InputRef& input_ref); + std::vector param_tensors(ggml_cgraph* gf, const Segment& segment); + std::vector runtime_param_tensors(ggml_cgraph* gf, const Segment& segment, const char* log_desc); + std::unordered_set collect_future_input_names(ggml_cgraph* gf, + const Plan& plan, + size_t current_segment_index); + ggml_cgraph* build_segment_graph(ggml_cgraph* gf, + const Segment& segment, + ggml_context** graph_ctx_out); + size_t measure_segment_compute_buffer(ggml_backend_t backend, + ggml_cgraph* gf, + const Segment& segment, + const char* log_desc); + Plan build_plan(ggml_backend_t backend, + ggml_cgraph* gf, + const std::unordered_set& params_tensor_set, + const char* log_desc); + Plan apply_max_vram_budget(ggml_cgraph* gf, + const Plan& base_plan, + size_t max_graph_vram_bytes, + ggml_backend_t backend, + const std::unordered_set& params_tensor_set, + const char* log_desc); + Plan resolve_plan(ggml_backend_t backend, + ggml_cgraph* gf, + PlanCache* cache, + size_t max_graph_vram_bytes, + const std::unordered_set& params_tensor_set, + const char* log_desc); +} // namespace sd::ggml_graph_cut + +#endif diff --git a/src/llm.hpp b/src/llm.hpp index c6c296149..a67b4ebf3 100644 --- a/src/llm.hpp +++ b/src/llm.hpp @@ -14,469 +14,21 @@ #include #include -#include "clip.hpp" #include "ggml_extend.hpp" #include "json.hpp" #include "rope.hpp" -#include "tokenize_util.h" -#include "vocab/vocab.h" +#include "tokenizers/bpe_tokenizer.h" +#include "tokenizers/mistral_tokenizer.h" +#include "tokenizers/qwen2_tokenizer.h" namespace LLM { constexpr int LLM_GRAPH_SIZE = 10240; - class BPETokenizer { - protected: - std::map byte_encoder; - std::map byte_decoder; - std::map 
encoder; - std::map decoder; - std::map, int> bpe_ranks; - std::regex pat; - int encoder_len; - int bpe_len; - - std::string UNK_TOKEN; - std::string BOS_TOKEN; - std::string EOS_TOKEN; - std::string PAD_TOKEN; - - int UNK_TOKEN_ID; - int BOS_TOKEN_ID; - int EOS_TOKEN_ID; - int PAD_TOKEN_ID; - - std::vector special_tokens; - - bool add_bos_token = false; - - protected: - static std::string strip(const std::string& str) { - std::string::size_type start = str.find_first_not_of(" \t\n\r\v\f"); - std::string::size_type end = str.find_last_not_of(" \t\n\r\v\f"); - - if (start == std::string::npos) { - // String contains only whitespace characters - return ""; - } - - return str.substr(start, end - start + 1); - } - - static std::string whitespace_clean(std::string text) { - text = std::regex_replace(text, std::regex(R"(\s+)"), " "); - text = strip(text); - return text; - } - - static std::set> get_pairs(const std::vector& subwords) { - std::set> pairs; - if (subwords.size() == 0) { - return pairs; - } - std::u32string prev_subword = subwords[0]; - for (int i = 1; i < subwords.size(); i++) { - std::u32string subword = subwords[i]; - std::pair pair(prev_subword, subword); - pairs.insert(pair); - prev_subword = subword; - } - return pairs; - } - - bool is_special_token(const std::string& token) { - for (auto& special_token : special_tokens) { - if (special_token == token) { - return true; - } - } - return false; - } - - public: - BPETokenizer() = default; - - std::u32string bpe(const std::u32string& token) { - std::vector word; - - for (int i = 0; i < token.size(); i++) { - word.emplace_back(1, token[i]); - } - - std::set> pairs = get_pairs(word); - - if (pairs.empty()) { - return token; - } - - while (true) { - auto min_pair_iter = std::min_element(pairs.begin(), - pairs.end(), - [&](const std::pair& a, - const std::pair& b) { - if (bpe_ranks.find(a) == bpe_ranks.end()) { - return false; - } else if (bpe_ranks.find(b) == bpe_ranks.end()) { - return true; - } - return 
bpe_ranks.at(a) < bpe_ranks.at(b); - }); - - const std::pair& bigram = *min_pair_iter; - - if (bpe_ranks.find(bigram) == bpe_ranks.end()) { - break; - } - - std::u32string first = bigram.first; - std::u32string second = bigram.second; - std::vector new_word; - int32_t i = 0; - - while (i < word.size()) { - auto it = std::find(word.begin() + i, word.end(), first); - if (it == word.end()) { - new_word.insert(new_word.end(), word.begin() + i, word.end()); - break; - } - new_word.insert(new_word.end(), word.begin() + i, it); - i = static_cast(std::distance(word.begin(), it)); - - if (word[i] == first && i < static_cast(word.size()) - 1 && word[i + 1] == second) { - new_word.push_back(first + second); - i += 2; - } else { - new_word.push_back(word[i]); - i += 1; - } - } - - word = new_word; - - if (word.size() == 1) { - break; - } - pairs = get_pairs(word); - } - - std::u32string result; - for (int i = 0; i < word.size(); i++) { - result += word[i]; - if (i != word.size() - 1) { - result += utf8_to_utf32(" "); - } - } - - return result; - } - - std::vector tokenize(std::string text, - on_new_token_cb_t on_new_token_cb = nullptr, - size_t max_length = 0, - bool padding = false) { - std::vector tokens = encode(text, on_new_token_cb); - - if (max_length > 0) { - if (tokens.size() < max_length) { - tokens.resize(max_length); - } else { - if (padding) { - tokens.insert(tokens.end(), max_length - tokens.size(), PAD_TOKEN_ID); - } - } - } - - return tokens; - } - - void pad_tokens(std::vector& tokens, - std::vector& weights, - size_t max_length = 0, - bool padding = false) { - if (add_bos_token) { - tokens.insert(tokens.begin(), BOS_TOKEN_ID); - weights.insert(weights.begin(), 1.f); - } - if (max_length > 0 && padding) { - size_t n = static_cast(std::ceil(tokens.size() * 1.f / max_length)); - if (n == 0) { - n = 1; - } - size_t length = max_length * n; - LOG_DEBUG("token length: %llu", length); - tokens.insert(tokens.end(), length - tokens.size(), PAD_TOKEN_ID); - 
weights.insert(weights.end(), length - weights.size(), 1.f); - } - } - - std::vector encode(std::string text, on_new_token_cb_t on_new_token_cb = nullptr) { - std::string original_text = text; - std::vector bpe_tokens; - std::vector token_strs; - - auto splited_texts = split_with_special_tokens(text, special_tokens); - - for (auto& splited_text : splited_texts) { - if (is_special_token(splited_text)) { - bpe_tokens.push_back(encoder[utf8_to_utf32(splited_text)]); - token_strs.push_back(splited_text); - continue; - } - auto tokens = token_split(splited_text); - for (auto& token : tokens) { - if (on_new_token_cb != nullptr) { - bool skip = on_new_token_cb(token, bpe_tokens); - if (skip) { - continue; - } - } - - std::string token_str = token; - std::u32string utf32_token; - for (int i = 0; i < token_str.length(); i++) { - unsigned char b = token_str[i]; - utf32_token += byte_encoder[b]; - } - auto bpe_strs = bpe(utf32_token); - size_t start = 0; - size_t pos; - while ((pos = bpe_strs.find(' ', start)) != std::u32string::npos) { - auto bpe_str = bpe_strs.substr(start, pos - start); - bpe_tokens.push_back(encoder[bpe_str]); - token_strs.push_back(utf32_to_utf8(bpe_str)); - - start = pos + 1; - } - auto bpe_str = bpe_strs.substr(start, bpe_strs.size() - start); - bpe_tokens.push_back(encoder[bpe_str]); - token_strs.push_back(utf32_to_utf8(bpe_str)); - } - } - - std::stringstream ss; - ss << "["; - for (auto token : token_strs) { - ss << "\"" << token << "\", "; - } - ss << "]"; - LOG_DEBUG("split prompt \"%s\" to tokens %s", original_text.c_str(), ss.str().c_str()); - // printf("split prompt \"%s\" to tokens %s \n", original_text.c_str(), ss.str().c_str()); - return bpe_tokens; - } - }; - - class Qwen2Tokenizer : public BPETokenizer { - protected: - void load_from_merges(const std::string& merges_utf8_str) { - auto byte_unicode_pairs = bytes_to_unicode(); - // printf("byte_unicode_pairs have %lu pairs \n", byte_unicode_pairs.size()); - byte_encoder = 
std::map(byte_unicode_pairs.begin(), byte_unicode_pairs.end()); - for (auto& pair : byte_unicode_pairs) { - byte_decoder[pair.second] = pair.first; - } - // for (auto & pair: byte_unicode_pairs) { - // std::cout << pair.first << ": " << pair.second << std::endl; - // } - std::vector merges; - size_t start = 0; - size_t pos; - std::u32string merges_utf32_str = utf8_to_utf32(merges_utf8_str); - while ((pos = merges_utf32_str.find('\n', start)) != std::string::npos) { - merges.push_back(merges_utf32_str.substr(start, pos - start)); - start = pos + 1; - } - LOG_DEBUG("merges size %llu", merges.size()); - merges = std::vector(merges.begin(), merges.end()); - std::vector> merge_pairs; - // int print_num = 10; - for (const auto& merge : merges) { - size_t space_pos = merge.find(' '); - merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1)); - // if (print_num > 0) { - // print_num--; - // printf("%s :: %s | %s \n", utf32_to_utf8(merge).c_str(), utf32_to_utf8(merge.substr(0, space_pos)).c_str(), - // utf32_to_utf8(merge.substr(space_pos + 1)).c_str()); - // } - } - - std::vector tokens; - for (const auto& pair : byte_unicode_pairs) { - tokens.push_back(pair.second); - } - for (const auto& merge : merge_pairs) { - tokens.push_back(merge.first + merge.second); - } - for (auto& special_token : special_tokens) { - tokens.push_back(utf8_to_utf32(special_token)); - } - - int i = 0; - for (const auto& token : tokens) { - encoder[token] = i; - decoder[i] = token; - i++; - } - encoder_len = i; - LOG_DEBUG("vocab size: %d", encoder_len); - - int rank = 0; - for (const auto& merge : merge_pairs) { - bpe_ranks[merge] = rank++; - } - bpe_len = rank; - }; - - public: - explicit Qwen2Tokenizer(const std::string& merges_utf8_str = "") { - UNK_TOKEN = "<|endoftext|>"; - EOS_TOKEN = "<|endoftext|>"; - PAD_TOKEN = "<|endoftext|>"; - - UNK_TOKEN_ID = 151643; - EOS_TOKEN_ID = 151643; - PAD_TOKEN_ID = 151643; - - special_tokens = { - "<|endoftext|>", - "<|im_start|>", 
- "<|im_end|>", - "<|object_ref_start|>", - "<|object_ref_end|>", - "<|box_start|>", - "<|box_end|>", - "<|quad_start|>", - "<|quad_end|>", - "<|vision_start|>", - "<|vision_end|>", - "<|vision_pad|>", - "<|image_pad|>", - "<|video_pad|>", - "", - "", - "<|fim_prefix|>", - "<|fim_middle|>", - "<|fim_suffix|>", - "<|fim_pad|>", - "<|repo_name|>", - "<|file_sep|>", - "", - "", - "", - "", - }; - - if (merges_utf8_str.size() > 0) { - load_from_merges(merges_utf8_str); - } else { - load_from_merges(load_qwen2_merges()); - } - } - }; - - class MistralTokenizer : public BPETokenizer { - protected: - void load_from_merges(const std::string& merges_utf8_str, const std::string& vocab_utf8_str) { - nlohmann::json vocab; - - try { - vocab = nlohmann::json::parse(vocab_utf8_str); - } catch (const nlohmann::json::parse_error&) { - GGML_ABORT("invalid vocab json str"); - } - for (const auto& [key, value] : vocab.items()) { - std::u32string token = utf8_to_utf32(key); - int i = value; - encoder[token] = i; - decoder[i] = token; - } - encoder_len = static_cast(vocab.size()); - LOG_DEBUG("vocab size: %d", encoder_len); - - auto byte_unicode_pairs = bytes_to_unicode(); - byte_encoder = std::map(byte_unicode_pairs.begin(), byte_unicode_pairs.end()); - for (auto& pair : byte_unicode_pairs) { - byte_decoder[pair.second] = pair.first; - } - std::vector merges; - size_t start = 0; - size_t pos; - std::u32string merges_utf32_str = utf8_to_utf32(merges_utf8_str); - while ((pos = merges_utf32_str.find('\n', start)) != std::string::npos) { - merges.push_back(merges_utf32_str.substr(start, pos - start)); - start = pos + 1; - } - LOG_DEBUG("merges size %llu", merges.size()); - merges = std::vector(merges.begin(), merges.end()); - std::vector> merge_pairs; - // int print_num = 10; - for (const auto& merge : merges) { - size_t space_pos = merge.find(' '); - merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1)); - // if (print_num > 0) { - // print_num--; - // 
printf("%s :: %s | %s \n", utf32_to_utf8(merge).c_str(), utf32_to_utf8(merge.substr(0, space_pos)).c_str(), - // utf32_to_utf8(merge.substr(space_pos + 1)).c_str()); - // } - } - - int rank = 0; - for (const auto& merge : merge_pairs) { - bpe_ranks[merge] = rank++; - } - bpe_len = rank; - }; - - public: - explicit MistralTokenizer(const std::string& merges_utf8_str = "", const std::string& vocab_utf8_str = "") { - add_bos_token = true; - - UNK_TOKEN = ""; - BOS_TOKEN = ""; - EOS_TOKEN = ""; - PAD_TOKEN = ""; - - UNK_TOKEN_ID = 0; - BOS_TOKEN_ID = 1; - EOS_TOKEN_ID = 2; - PAD_TOKEN_ID = 11; - - special_tokens = { - "", - "", - "", - "[INST]", - "[/INST]", - "[AVAILABLE_TOOLS]", - "[/AVAILABLE_TOOLS]", - "[TOOL_RESULTS]", - "[/TOOL_RESULTS]", - "[TOOL_CALLS]", - "[IMG]", - "", - "[IMG_BREAK]", - "[IMG_END]", - "[PREFIX]", - "[MIDDLE]", - "[SUFFIX]", - "[SYSTEM_PROMPT]", - "[/SYSTEM_PROMPT]", - "[TOOL_CONTENT]", - }; - for (int i = 20; i < 1000; i++) { - special_tokens.push_back(""); - } - - if (merges_utf8_str.size() > 0 && vocab_utf8_str.size() > 0) { - load_from_merges(merges_utf8_str, vocab_utf8_str); - } else { - load_from_merges(load_mistral_merges(), load_mistral_vocab_json()); - } - } - }; - enum class LLMArch { QWEN2_5_VL, QWEN3, MISTRAL_SMALL_3_2, + MINISTRAL_3_3B, ARCH_COUNT, }; @@ -484,6 +36,7 @@ namespace LLM { "qwen2.5vl", "qwen3", "mistral_small3.2", + "ministral3.3b", }; struct LLMVisionParams { @@ -793,6 +346,7 @@ namespace LLM { auto merger = std::dynamic_pointer_cast(blocks["merger"]); auto x = patch_embed->forward(ctx, pixel_values); + sd::ggml_graph_cut::mark_graph_cut(x, "llm.vision.prelude", "x"); x = ggml_reshape_4d(ctx->ggml_ctx, x, x->ne[0] * spatial_merge_size * spatial_merge_size, x->ne[1] / spatial_merge_size / spatial_merge_size, x->ne[2], x->ne[3]); x = ggml_get_rows(ctx->ggml_ctx, x, window_index); @@ -806,9 +360,11 @@ namespace LLM { mask = nullptr; } x = block->forward(ctx, x, pe, mask); + sd::ggml_graph_cut::mark_graph_cut(x, 
"llm.vision.blocks." + std::to_string(i), "x"); } x = merger->forward(ctx, x); + sd::ggml_graph_cut::mark_graph_cut(x, "llm.vision.final", "x"); x = ggml_get_rows(ctx->ggml_ctx, x, window_inverse_index); @@ -868,6 +424,9 @@ namespace LLM { if (arch == LLMArch::MISTRAL_SMALL_3_2) { q = ggml_rope_ext(ctx->ggml_ctx, q, input_pos, nullptr, 128, GGML_ROPE_TYPE_NORMAL, 8192, 1000000000.f, 1.f, 0.f, 1.f, 32.f, 1.f); k = ggml_rope_ext(ctx->ggml_ctx, k, input_pos, nullptr, 128, GGML_ROPE_TYPE_NORMAL, 8192, 1000000000.f, 1.f, 0.f, 1.f, 32.f, 1.f); + } else if (arch == LLMArch::MINISTRAL_3_3B) { + q = ggml_rope_ext(ctx->ggml_ctx, q, input_pos, nullptr, 128, GGML_ROPE_TYPE_NEOX, 262144, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f); + k = ggml_rope_ext(ctx->ggml_ctx, k, input_pos, nullptr, 128, GGML_ROPE_TYPE_NEOX, 262144, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f); } else if (arch == LLMArch::QWEN3) { q = ggml_rope_ext(ctx->ggml_ctx, q, input_pos, nullptr, 128, GGML_ROPE_TYPE_NEOX, 40960, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f); k = ggml_rope_ext(ctx->ggml_ctx, k, input_pos, nullptr, 128, GGML_ROPE_TYPE_NEOX, 40960, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f); @@ -950,6 +509,7 @@ namespace LLM { auto norm = std::dynamic_pointer_cast(blocks["norm"]); auto x = embed_tokens->forward(ctx, input_ids); + sd::ggml_graph_cut::mark_graph_cut(x, "llm.text.prelude", "x"); std::vector intermediate_outputs; @@ -996,6 +556,10 @@ namespace LLM { auto block = std::dynamic_pointer_cast(blocks["layers." + std::to_string(i)]); x = block->forward(ctx, x, input_pos, attention_mask); + if (out_layers.size() > 1) { + x = ggml_cont(ctx->ggml_ctx, x); + } + sd::ggml_graph_cut::mark_graph_cut(x, "llm.text.layers." 
+ std::to_string(i), "x"); if (out_layers.find(i + 1) != out_layers.end()) { intermediate_outputs.push_back(x); } @@ -1083,7 +647,7 @@ namespace LLM { bool enable_vision_ = false) : GGMLRunner(backend, offload_params_to_cpu), enable_vision(enable_vision_) { params.arch = arch; - if (arch == LLMArch::MISTRAL_SMALL_3_2) { + if (arch == LLMArch::MISTRAL_SMALL_3_2 || arch == LLMArch::MINISTRAL_3_3B) { params.head_dim = 128; params.num_heads = 32; params.num_kv_heads = 8; @@ -1195,7 +759,7 @@ namespace LLM { } int64_t n_tokens = input_ids->ne[0]; - if (params.arch == LLMArch::MISTRAL_SMALL_3_2 || params.arch == LLMArch::QWEN3) { + if (params.arch == LLMArch::MISTRAL_SMALL_3_2 || params.arch == LLMArch::MINISTRAL_3_3B || params.arch == LLMArch::QWEN3) { input_pos_vec.resize(n_tokens); for (int i = 0; i < n_tokens; ++i) { input_pos_vec[i] = i; @@ -1431,7 +995,7 @@ namespace LLM { const std::string prefix = "", bool enable_vision = false) : model(arch, backend, offload_params_to_cpu, tensor_storage_map, prefix, enable_vision) { - if (arch == LLMArch::MISTRAL_SMALL_3_2) { + if (arch == LLMArch::MISTRAL_SMALL_3_2 || arch == LLMArch::MINISTRAL_3_3B) { tokenizer = std::make_shared(); } else { tokenizer = std::make_shared(); @@ -1479,7 +1043,7 @@ namespace LLM { weights.insert(weights.end(), curr_tokens.size(), curr_weight); } - tokenizer->pad_tokens(tokens, weights, max_length, padding); + tokenizer->pad_tokens(tokens, &weights, nullptr, padding ? max_length : 0, padding ? 
max_length : 100000000, padding); // for (int i = 0; i < tokens.size(); i++) { // std::cout << tokens[i] << ":" << weights[i] << ", "; diff --git a/src/lora.hpp b/src/lora.hpp index d4a749ef9..b57bc4226 100644 --- a/src/lora.hpp +++ b/src/lora.hpp @@ -129,7 +129,7 @@ struct LoraModel : public GGMLRunner { } } - ggml_tensor* get_lora_weight_diff(const std::string& model_tensor_name, ggml_context* ctx) { + ggml_tensor* get_lora_weight_diff(const std::string& model_tensor_name, ggml_context* ctx, ggml_backend_t backend) { ggml_tensor* updown = nullptr; int index = 0; while (true) { @@ -152,17 +152,17 @@ struct LoraModel : public GGMLRunner { auto iter = lora_tensors.find(lora_up_name); if (iter != lora_tensors.end()) { - lora_up = ggml_ext_cast_f32(ctx, iter->second); + lora_up = ggml_ext_cast_f32(ctx, backend, iter->second); } iter = lora_tensors.find(lora_mid_name); if (iter != lora_tensors.end()) { - lora_mid = ggml_ext_cast_f32(ctx, iter->second); + lora_mid = ggml_ext_cast_f32(ctx, backend, iter->second); } iter = lora_tensors.find(lora_down_name); if (iter != lora_tensors.end()) { - lora_down = ggml_ext_cast_f32(ctx, iter->second); + lora_down = ggml_ext_cast_f32(ctx, backend, iter->second); } if (lora_up == nullptr || lora_down == nullptr) { @@ -208,7 +208,7 @@ struct LoraModel : public GGMLRunner { return updown; } - ggml_tensor* get_raw_weight_diff(const std::string& model_tensor_name, ggml_context* ctx) { + ggml_tensor* get_raw_weight_diff(const std::string& model_tensor_name, ggml_context* ctx, ggml_backend_t backend) { ggml_tensor* updown = nullptr; int index = 0; while (true) { @@ -225,7 +225,7 @@ struct LoraModel : public GGMLRunner { auto iter = lora_tensors.find(diff_name); if (iter != lora_tensors.end()) { - curr_updown = ggml_ext_cast_f32(ctx, iter->second); + curr_updown = ggml_ext_cast_f32(ctx, backend, iter->second); } else { break; } @@ -248,7 +248,7 @@ struct LoraModel : public GGMLRunner { return updown; } - ggml_tensor* 
get_loha_weight_diff(const std::string& model_tensor_name, ggml_context* ctx) { + ggml_tensor* get_loha_weight_diff(const std::string& model_tensor_name, ggml_context* ctx, ggml_backend_t backend) { ggml_tensor* updown = nullptr; int index = 0; while (true) { @@ -276,33 +276,33 @@ struct LoraModel : public GGMLRunner { auto iter = lora_tensors.find(hada_1_down_name); if (iter != lora_tensors.end()) { - hada_1_down = ggml_ext_cast_f32(ctx, iter->second); + hada_1_down = ggml_ext_cast_f32(ctx, backend, iter->second); } iter = lora_tensors.find(hada_1_up_name); if (iter != lora_tensors.end()) { - hada_1_up = ggml_ext_cast_f32(ctx, iter->second); + hada_1_up = ggml_ext_cast_f32(ctx, backend, iter->second); } iter = lora_tensors.find(hada_1_mid_name); if (iter != lora_tensors.end()) { - hada_1_mid = ggml_ext_cast_f32(ctx, iter->second); + hada_1_mid = ggml_ext_cast_f32(ctx, backend, iter->second); hada_1_up = ggml_cont(ctx, ggml_transpose(ctx, hada_1_up)); } iter = lora_tensors.find(hada_2_down_name); if (iter != lora_tensors.end()) { - hada_2_down = ggml_ext_cast_f32(ctx, iter->second); + hada_2_down = ggml_ext_cast_f32(ctx, backend, iter->second); } iter = lora_tensors.find(hada_2_up_name); if (iter != lora_tensors.end()) { - hada_2_up = ggml_ext_cast_f32(ctx, iter->second); + hada_2_up = ggml_ext_cast_f32(ctx, backend, iter->second); } iter = lora_tensors.find(hada_2_mid_name); if (iter != lora_tensors.end()) { - hada_2_mid = ggml_ext_cast_f32(ctx, iter->second); + hada_2_mid = ggml_ext_cast_f32(ctx, backend, iter->second); hada_2_up = ggml_cont(ctx, ggml_transpose(ctx, hada_2_up)); } @@ -351,7 +351,7 @@ struct LoraModel : public GGMLRunner { return updown; } - ggml_tensor* get_lokr_weight_diff(const std::string& model_tensor_name, ggml_context* ctx) { + ggml_tensor* get_lokr_weight_diff(const std::string& model_tensor_name, ggml_context* ctx, ggml_backend_t backend) { ggml_tensor* updown = nullptr; int index = 0; while (true) { @@ -378,24 +378,24 @@ struct LoraModel 
: public GGMLRunner { auto iter = lora_tensors.find(lokr_w1_name); if (iter != lora_tensors.end()) { - lokr_w1 = ggml_ext_cast_f32(ctx, iter->second); + lokr_w1 = ggml_ext_cast_f32(ctx, backend, iter->second); } iter = lora_tensors.find(lokr_w2_name); if (iter != lora_tensors.end()) { - lokr_w2 = ggml_ext_cast_f32(ctx, iter->second); + lokr_w2 = ggml_ext_cast_f32(ctx, backend, iter->second); } int64_t rank = 1; if (lokr_w1 == nullptr) { iter = lora_tensors.find(lokr_w1_a_name); if (iter != lora_tensors.end()) { - lokr_w1_a = ggml_ext_cast_f32(ctx, iter->second); + lokr_w1_a = ggml_ext_cast_f32(ctx, backend, iter->second); } iter = lora_tensors.find(lokr_w1_b_name); if (iter != lora_tensors.end()) { - lokr_w1_b = ggml_ext_cast_f32(ctx, iter->second); + lokr_w1_b = ggml_ext_cast_f32(ctx, backend, iter->second); } if (lokr_w1_a == nullptr || lokr_w1_b == nullptr) { @@ -410,12 +410,12 @@ struct LoraModel : public GGMLRunner { if (lokr_w2 == nullptr) { iter = lora_tensors.find(lokr_w2_a_name); if (iter != lora_tensors.end()) { - lokr_w2_a = ggml_ext_cast_f32(ctx, iter->second); + lokr_w2_a = ggml_ext_cast_f32(ctx, backend, iter->second); } iter = lora_tensors.find(lokr_w2_b_name); if (iter != lora_tensors.end()) { - lokr_w2_b = ggml_ext_cast_f32(ctx, iter->second); + lokr_w2_b = ggml_ext_cast_f32(ctx, backend, iter->second); } if (lokr_w2_a == nullptr || lokr_w2_b == nullptr) { @@ -468,23 +468,23 @@ struct LoraModel : public GGMLRunner { return updown; } - ggml_tensor* get_weight_diff(const std::string& model_tensor_name, ggml_context* ctx, ggml_tensor* model_tensor, bool with_lora_and_lokr = true) { + ggml_tensor* get_weight_diff(const std::string& model_tensor_name, ggml_backend_t backend, ggml_context* ctx, ggml_tensor* model_tensor, bool with_lora_and_lokr = true) { // lora ggml_tensor* diff = nullptr; if (with_lora_and_lokr) { - diff = get_lora_weight_diff(model_tensor_name, ctx); + diff = get_lora_weight_diff(model_tensor_name, ctx, backend); } // diff if (diff == 
nullptr) { - diff = get_raw_weight_diff(model_tensor_name, ctx); + diff = get_raw_weight_diff(model_tensor_name, ctx, backend); } // loha if (diff == nullptr) { - diff = get_loha_weight_diff(model_tensor_name, ctx); + diff = get_loha_weight_diff(model_tensor_name, ctx, backend); } // lokr if (diff == nullptr && with_lora_and_lokr) { - diff = get_lokr_weight_diff(model_tensor_name, ctx); + diff = get_lokr_weight_diff(model_tensor_name, ctx, backend); } if (diff != nullptr) { if (ggml_nelements(diff) < ggml_nelements(model_tensor)) { @@ -502,6 +502,7 @@ struct LoraModel : public GGMLRunner { } ggml_tensor* get_out_diff(ggml_context* ctx, + ggml_backend_t backend, ggml_tensor* x, WeightAdapter::ForwardParams forward_params, const std::string& model_tensor_name) { @@ -590,7 +591,7 @@ struct LoraModel : public GGMLRunner { } scale_value *= multiplier; - auto curr_out_diff = ggml_ext_lokr_forward(ctx, x, lokr_w1, lokr_w1_a, lokr_w1_b, lokr_w2, lokr_w2_a, lokr_w2_b, is_conv2d, forward_params.conv2d, scale_value); + auto curr_out_diff = ggml_ext_lokr_forward(ctx, backend, x, lokr_w1, lokr_w1_a, lokr_w1_b, lokr_w2, lokr_w2_a, lokr_w2_b, is_conv2d, forward_params.conv2d, scale_value); if (out_diff == nullptr) { out_diff = curr_out_diff; } else { @@ -761,7 +762,7 @@ struct LoraModel : public GGMLRunner { ggml_tensor* model_tensor = it.second; // lora - ggml_tensor* diff = get_weight_diff(model_tensor_name, compute_ctx, model_tensor); + ggml_tensor* diff = get_weight_diff(model_tensor_name, runtime_backend, compute_ctx, model_tensor); if (diff == nullptr) { continue; } @@ -774,7 +775,7 @@ struct LoraModel : public GGMLRunner { ggml_tensor* final_tensor; if (model_tensor->type != GGML_TYPE_F32 && model_tensor->type != GGML_TYPE_F16) { - final_tensor = ggml_ext_cast_f32(compute_ctx, model_tensor); + final_tensor = ggml_ext_cast_f32(compute_ctx, runtime_backend, model_tensor); final_tensor = ggml_add_inplace(compute_ctx, final_tensor, diff); final_tensor = ggml_cpy(compute_ctx, 
final_tensor, model_tensor); } else { @@ -841,34 +842,35 @@ struct MultiLoraAdapter : public WeightAdapter { : lora_models(lora_models) { } - ggml_tensor* patch_weight(ggml_context* ctx, ggml_tensor* weight, const std::string& weight_name, bool with_lora_and_lokr) { + ggml_tensor* patch_weight(ggml_context* ctx, ggml_backend_t backend, ggml_tensor* weight, const std::string& weight_name, bool with_lora_and_lokr) { for (auto& lora_model : lora_models) { - ggml_tensor* diff = lora_model->get_weight_diff(weight_name, ctx, weight, with_lora_and_lokr); + ggml_tensor* diff = lora_model->get_weight_diff(weight_name, backend, ctx, weight, with_lora_and_lokr); if (diff == nullptr) { continue; } if (weight->type != GGML_TYPE_F32 && weight->type != GGML_TYPE_F16) { - weight = ggml_ext_cast_f32(ctx, weight); + weight = ggml_ext_cast_f32(ctx, backend, weight); } weight = ggml_add(ctx, weight, diff); } return weight; } - ggml_tensor* patch_weight(ggml_context* ctx, ggml_tensor* weight, const std::string& weight_name) override { - return patch_weight(ctx, weight, weight_name, true); + ggml_tensor* patch_weight(ggml_context* ctx, ggml_backend_t backend, ggml_tensor* weight, const std::string& weight_name) override { + return patch_weight(ctx, backend, weight, weight_name, true); } ggml_tensor* forward_with_lora(ggml_context* ctx, + ggml_backend_t backend, ggml_tensor* x, ggml_tensor* w, ggml_tensor* b, const std::string& prefix, WeightAdapter::ForwardParams forward_params) override { - w = patch_weight(ctx, w, prefix + "weight", false); + w = patch_weight(ctx, backend, w, prefix + "weight", false); if (b) { - b = patch_weight(ctx, b, prefix + "bias", false); + b = patch_weight(ctx, backend, b, prefix + "bias", false); } ggml_tensor* out; if (forward_params.op_type == ForwardParams::op_type_t::OP_LINEAR) { @@ -890,7 +892,7 @@ struct MultiLoraAdapter : public WeightAdapter { forward_params.conv2d.scale); } for (auto& lora_model : lora_models) { - ggml_tensor* out_diff = 
lora_model->get_out_diff(ctx, x, forward_params, prefix + "weight"); + ggml_tensor* out_diff = lora_model->get_out_diff(ctx, backend, x, forward_params, prefix + "weight"); if (out_diff == nullptr) { continue; } diff --git a/src/mmdit.hpp b/src/mmdit.hpp index e75736c5d..e57041dc9 100644 --- a/src/mmdit.hpp +++ b/src/mmdit.hpp @@ -767,6 +767,8 @@ struct MMDiT : public GGMLBlock { auto context_x = block->forward(ctx, context, x, c_mod); context = context_x.first; x = context_x.second; + sd::ggml_graph_cut::mark_graph_cut(context, "mmdit.joint_blocks." + std::to_string(i), "context"); + sd::ggml_graph_cut::mark_graph_cut(x, "mmdit.joint_blocks." + std::to_string(i), "x"); } x = final_layer->forward(ctx, x, c_mod); // (N, T, patch_size ** 2 * out_channels) @@ -809,6 +811,11 @@ struct MMDiT : public GGMLBlock { context = context_embedder->forward(ctx, context); // [N, L, D] aka [N, L, 1536] } + sd::ggml_graph_cut::mark_graph_cut(x, "mmdit.prelude", "x"); + sd::ggml_graph_cut::mark_graph_cut(c, "mmdit.prelude", "c"); + if (context != nullptr) { + sd::ggml_graph_cut::mark_graph_cut(context, "mmdit.prelude", "context"); + } x = forward_core_with_concat(ctx, x, c, context, skip_layers); // (N, H*W, patch_size ** 2 * out_channels) diff --git a/src/model.cpp b/src/model.cpp index 1ccb03cf3..8fdde3b76 100644 --- a/src/model.cpp +++ b/src/model.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include #include @@ -12,64 +13,21 @@ #include #include -#include "gguf_reader.hpp" #include "model.h" +#include "model_io/gguf_io.h" +#include "model_io/safetensors_io.h" +#include "model_io/torch_legacy_io.h" +#include "model_io/torch_zip_io.h" #include "stable-diffusion.h" #include "util.h" #include "ggml-alloc.h" #include "ggml-backend.h" -#include "ggml-cpu.h" #include "ggml.h" +#include "ggml_extend_backend.hpp" +#include "zip.h" #include "name_conversion.h" -#include "stable-diffusion.h" - -#ifdef SD_USE_METAL -#include "ggml-metal.h" -#endif - -#ifdef 
SD_USE_VULKAN -#include "ggml-vulkan.h" -#endif - -#ifdef SD_USE_OPENCL -#include "ggml-opencl.h" -#endif - -#define ST_HEADER_SIZE_LEN 8 - -uint64_t read_u64(uint8_t* buffer) { - // little endian - uint64_t value = 0; - value |= static_cast(buffer[7]) << 56; - value |= static_cast(buffer[6]) << 48; - value |= static_cast(buffer[5]) << 40; - value |= static_cast(buffer[4]) << 32; - value |= static_cast(buffer[3]) << 24; - value |= static_cast(buffer[2]) << 16; - value |= static_cast(buffer[1]) << 8; - value |= static_cast(buffer[0]); - return value; -} - -int32_t read_int(uint8_t* buffer) { - // little endian - int value = 0; - value |= buffer[3] << 24; - value |= buffer[2] << 16; - value |= buffer[1] << 8; - value |= buffer[0]; - return value; -} - -uint16_t read_short(uint8_t* buffer) { - // little endian - uint16_t value = 0; - value |= buffer[1] << 8; - value |= buffer[0]; - return value; -} /*================================================= Preprocess ==================================================*/ @@ -110,7 +68,7 @@ const char* unused_tensors[] = { "first_stage_model.bn.", }; -bool is_unused_tensor(std::string name) { +bool is_unused_tensor(const std::string& name) { for (size_t i = 0; i < sizeof(unused_tensors) / sizeof(const char*); i++) { if (starts_with(name, unused_tensors[i])) { return true; @@ -250,78 +208,6 @@ void ModelLoader::add_tensor_storage(const TensorStorage& tensor_storage) { tensor_storage_map[tensor_storage.name] = tensor_storage; } -bool is_zip_file(const std::string& file_path) { - zip_t* zip = zip_open(file_path.c_str(), 0, 'r'); - if (zip == nullptr) { - return false; - } - zip_close(zip); - return true; -} - -bool is_gguf_file(const std::string& file_path) { - std::ifstream file(file_path, std::ios::binary); - if (!file.is_open()) { - return false; - } - - char magic[4]; - - file.read(magic, sizeof(magic)); - if (!file) { - return false; - } - for (uint32_t i = 0; i < sizeof(magic); i++) { - if (magic[i] != GGUF_MAGIC[i]) { - 
return false; - } - } - - return true; -} - -bool is_safetensors_file(const std::string& file_path) { - std::ifstream file(file_path, std::ios::binary); - if (!file.is_open()) { - return false; - } - - // get file size - file.seekg(0, file.end); - size_t file_size_ = file.tellg(); - file.seekg(0, file.beg); - - // read header size - if (file_size_ <= ST_HEADER_SIZE_LEN) { - return false; - } - - uint8_t header_size_buf[ST_HEADER_SIZE_LEN]; - file.read((char*)header_size_buf, ST_HEADER_SIZE_LEN); - if (!file) { - return false; - } - - size_t header_size_ = read_u64(header_size_buf); - if (header_size_ >= file_size_ || header_size_ <= 2) { - return false; - } - - // read header - std::vector header_buf; - header_buf.resize(header_size_ + 1); - header_buf[header_size_] = '\0'; - file.read(header_buf.data(), header_size_); - if (!file) { - return false; - } - nlohmann::json header_ = nlohmann::json::parse(header_buf.data()); - if (header_.is_discarded()) { - return false; - } - return true; -} - bool ModelLoader::init_from_file(const std::string& file_path, const std::string& prefix) { if (is_directory(file_path)) { LOG_INFO("load %s using diffusers format", file_path.c_str()); @@ -332,9 +218,12 @@ bool ModelLoader::init_from_file(const std::string& file_path, const std::string } else if (is_safetensors_file(file_path)) { LOG_INFO("load %s using safetensors format", file_path.c_str()); return init_from_safetensors_file(file_path, prefix); - } else if (is_zip_file(file_path)) { - LOG_INFO("load %s using checkpoint format", file_path.c_str()); - return init_from_ckpt_file(file_path, prefix); + } else if (is_torch_zip_file(file_path)) { + LOG_INFO("load %s using torch zip format", file_path.c_str()); + return init_from_torch_zip_file(file_path, prefix); + } else if (init_from_torch_legacy_file(file_path, prefix)) { + LOG_INFO("load %s using torch legacy format", file_path.c_str()); + return true; } else { if (file_exists(file_path)) { LOG_WARN("unknown format %s", 
file_path.c_str()); @@ -374,230 +263,121 @@ bool ModelLoader::init_from_file_and_convert_name(const std::string& file_path, bool ModelLoader::init_from_gguf_file(const std::string& file_path, const std::string& prefix) { LOG_DEBUG("init from '%s'", file_path.c_str()); - file_paths_.push_back(file_path); - size_t file_index = file_paths_.size() - 1; - gguf_context* ctx_gguf_ = nullptr; - ggml_context* ctx_meta_ = nullptr; - - ctx_gguf_ = gguf_init_from_file(file_path.c_str(), {true, &ctx_meta_}); - if (!ctx_gguf_) { - LOG_ERROR("failed to open '%s' with gguf_init_from_file. Try to open it with GGUFReader.", file_path.c_str()); - GGUFReader gguf_reader; - if (!gguf_reader.load(file_path)) { - LOG_ERROR("failed to open '%s' with GGUFReader.", file_path.c_str()); - return false; - } - - size_t data_offset = gguf_reader.data_offset(); - for (const auto& gguf_tensor_info : gguf_reader.tensors()) { - std::string name = gguf_tensor_info.name; - if (!starts_with(name, prefix)) { - name = prefix + name; - } - - TensorStorage tensor_storage( - name, - gguf_tensor_info.type, - gguf_tensor_info.shape.data(), - static_cast(gguf_tensor_info.shape.size()), - file_index, - data_offset + gguf_tensor_info.offset); - - // LOG_DEBUG("%s %s", name.c_str(), tensor_storage.to_string().c_str()); - - add_tensor_storage(tensor_storage); - } - - return true; + std::vector tensor_storages; + std::string error; + if (!read_gguf_file(file_path, tensor_storages, &error)) { + LOG_ERROR("%s", error.c_str()); + return false; } - int n_tensors = static_cast(gguf_get_n_tensors(ctx_gguf_)); - - size_t total_size = 0; - size_t data_offset = gguf_get_data_offset(ctx_gguf_); - for (int i = 0; i < n_tensors; i++) { - std::string name = gguf_get_tensor_name(ctx_gguf_, i); - ggml_tensor* dummy = ggml_get_tensor(ctx_meta_, name.c_str()); - size_t offset = data_offset + gguf_get_tensor_offset(ctx_gguf_, i); + file_paths_.push_back(file_path); + size_t file_index = file_paths_.size() - 1; - // LOG_DEBUG("%s", 
name.c_str()); + for (auto& tensor_storage : tensor_storages) { + // LOG_DEBUG("%s", tensor_storage.name.c_str()); - if (!starts_with(name, prefix)) { - name = prefix + name; + if (!starts_with(tensor_storage.name, prefix)) { + tensor_storage.name = prefix + tensor_storage.name; } - - TensorStorage tensor_storage(name, dummy->type, dummy->ne, ggml_n_dims(dummy), file_index, offset); - - GGML_ASSERT(ggml_nbytes(dummy) == tensor_storage.nbytes()); + tensor_storage.file_index = file_index; add_tensor_storage(tensor_storage); } - gguf_free(ctx_gguf_); - ggml_free(ctx_meta_); - return true; } /*================================================= SafeTensorsModelLoader ==================================================*/ -ggml_type str_to_ggml_type(const std::string& dtype) { - ggml_type ttype = GGML_TYPE_COUNT; - if (dtype == "F16") { - ttype = GGML_TYPE_F16; - } else if (dtype == "BF16") { - ttype = GGML_TYPE_BF16; - } else if (dtype == "F32") { - ttype = GGML_TYPE_F32; - } else if (dtype == "F64") { - ttype = GGML_TYPE_F32; - } else if (dtype == "F8_E4M3") { - ttype = GGML_TYPE_F16; - } else if (dtype == "F8_E5M2") { - ttype = GGML_TYPE_F16; - } else if (dtype == "I64") { - ttype = GGML_TYPE_I32; - } - return ttype; -} - -// https://huggingface.co/docs/safetensors/index bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const std::string& prefix) { LOG_DEBUG("init from '%s', prefix = '%s'", file_path.c_str(), prefix.c_str()); - file_paths_.push_back(file_path); - size_t file_index = file_paths_.size() - 1; - std::ifstream file(file_path, std::ios::binary); - if (!file.is_open()) { - LOG_ERROR("failed to open '%s'", file_path.c_str()); - file_paths_.pop_back(); + + std::vector tensor_storages; + std::string error; + if (!read_safetensors_file(file_path, tensor_storages, &error)) { + LOG_ERROR("%s", error.c_str()); return false; } - // get file size - file.seekg(0, file.end); - size_t file_size_ = file.tellg(); - file.seekg(0, file.beg); + 
file_paths_.push_back(file_path); + size_t file_index = file_paths_.size() - 1; - // read header size - if (file_size_ <= ST_HEADER_SIZE_LEN) { - LOG_ERROR("invalid safetensor file '%s'", file_path.c_str()); - file_paths_.pop_back(); - return false; - } + for (auto& tensor_storage : tensor_storages) { + if (is_unused_tensor(tensor_storage.name)) { + continue; + } - uint8_t header_size_buf[ST_HEADER_SIZE_LEN]; - file.read((char*)header_size_buf, ST_HEADER_SIZE_LEN); - if (!file) { - LOG_ERROR("read safetensors header size failed: '%s'", file_path.c_str()); - return false; - } + if (!starts_with(tensor_storage.name, prefix)) { + tensor_storage.name = prefix + tensor_storage.name; + } + tensor_storage.file_index = file_index; - size_t header_size_ = read_u64(header_size_buf); - if (header_size_ >= file_size_) { - LOG_ERROR("invalid safetensor file '%s'", file_path.c_str()); - file_paths_.pop_back(); - return false; - } + add_tensor_storage(tensor_storage); - // read header - std::vector header_buf; - header_buf.resize(header_size_ + 1); - header_buf[header_size_] = '\0'; - file.read(header_buf.data(), header_size_); - if (!file) { - LOG_ERROR("read safetensors header failed: '%s'", file_path.c_str()); - file_paths_.pop_back(); - return false; + // LOG_DEBUG("%s", tensor_storage.to_string().c_str()); } - nlohmann::json header_ = nlohmann::json::parse(header_buf.data()); + return true; +} - for (auto& item : header_.items()) { - std::string name = item.key(); - nlohmann::json tensor_info = item.value(); - // LOG_DEBUG("%s %s\n", name.c_str(), tensor_info.dump().c_str()); +/*================================================= TorchLegacyModelLoader ==================================================*/ - if (name == "__metadata__") { - continue; - } +bool ModelLoader::init_from_torch_legacy_file(const std::string& file_path, const std::string& prefix) { + LOG_DEBUG("init from torch legacy '%s'", file_path.c_str()); - if (is_unused_tensor(name)) { - continue; + std::vector 
tensor_storages; + std::string error; + if (!read_torch_legacy_file(file_path, tensor_storages, &error)) { + if ((!error.empty()) && (ends_with(file_path, ".pt") || ends_with(file_path, ".pth"))) { + LOG_WARN("%s", error.c_str()); } + return false; + } - std::string dtype = tensor_info["dtype"]; - nlohmann::json shape = tensor_info["shape"]; + file_paths_.push_back(file_path); + size_t file_index = file_paths_.size() - 1; - if (dtype == "U8") { + for (auto& tensor_storage : tensor_storages) { + if (is_unused_tensor(tensor_storage.name)) { continue; } - size_t begin = tensor_info["data_offsets"][0].get(); - size_t end = tensor_info["data_offsets"][1].get(); - - ggml_type type = str_to_ggml_type(dtype); - if (type == GGML_TYPE_COUNT) { - LOG_ERROR("unsupported dtype '%s' (tensor '%s')", dtype.c_str(), name.c_str()); - return false; - } - - if (shape.size() > SD_MAX_DIMS) { - LOG_ERROR("invalid tensor '%s'", name.c_str()); - return false; + if (!starts_with(tensor_storage.name, prefix)) { + tensor_storage.name = prefix + tensor_storage.name; } + tensor_storage.file_index = file_index; - int n_dims = (int)shape.size(); - int64_t ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1}; - for (int i = 0; i < n_dims; i++) { - ne[i] = shape[i].get(); - } + add_tensor_storage(tensor_storage); + } - if (n_dims == 5) { - n_dims = 4; - ne[0] = ne[0] * ne[1]; - ne[1] = ne[2]; - ne[2] = ne[3]; - ne[3] = ne[4]; - } + return true; +} - // ggml_n_dims returns 1 for scalars - if (n_dims == 0) { - n_dims = 1; - } +/*================================================= TorchZipModelLoader ==================================================*/ - if (!starts_with(name, prefix)) { - name = prefix + name; - } +bool ModelLoader::init_from_torch_zip_file(const std::string& file_path, const std::string& prefix) { + LOG_DEBUG("init from '%s'", file_path.c_str()); - TensorStorage tensor_storage(name, type, ne, n_dims, file_index, ST_HEADER_SIZE_LEN + header_size_ + begin); - tensor_storage.reverse_ne(); + std::vector 
tensor_storages; + std::string error; + if (!read_torch_zip_file(file_path, tensor_storages, &error)) { + LOG_ERROR("%s", error.c_str()); + return false; + } - size_t tensor_data_size = end - begin; + file_paths_.push_back(file_path); + size_t file_index = file_paths_.size() - 1; - if (dtype == "F8_E4M3") { - tensor_storage.is_f8_e4m3 = true; - // f8 -> f16 - GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2); - } else if (dtype == "F8_E5M2") { - tensor_storage.is_f8_e5m2 = true; - // f8 -> f16 - GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2); - } else if (dtype == "F64") { - tensor_storage.is_f64 = true; - // f64 -> f32 - GGML_ASSERT(tensor_storage.nbytes() * 2 == tensor_data_size); - } else if (dtype == "I64") { - tensor_storage.is_i64 = true; - // i64 -> i32 - GGML_ASSERT(tensor_storage.nbytes() * 2 == tensor_data_size); - } else { - GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size); + for (auto& tensor_storage : tensor_storages) { + if (!starts_with(tensor_storage.name, prefix)) { + tensor_storage.name = prefix + tensor_storage.name; } + tensor_storage.file_index = file_index; add_tensor_storage(tensor_storage); - // LOG_DEBUG("%s %s", tensor_storage.to_string().c_str(), dtype.c_str()); + // LOG_DEBUG("%s", tensor_storage.to_string().c_str()); } return true; @@ -629,367 +409,6 @@ bool ModelLoader::init_from_diffusers_file(const std::string& file_path, const s return true; } -/*================================================= CkptModelLoader ==================================================*/ - -// $ python -m pickletools sd-v1-4/archive/data.pkl | head -n 100 -// 0: \x80 PROTO 2 -// 2: } EMPTY_DICT -// 3: q BINPUT 0 -// 5: ( MARK -// 6: X BINUNICODE 'epoch' -// 16: q BINPUT 1 -// 18: K BININT1 6 -// 20: X BINUNICODE 'global_step' -// 36: q BINPUT 2 -// 38: J BININT 470000 -// 43: X BINUNICODE 'pytorch-lightning_version' -// 73: q BINPUT 3 -// 75: X BINUNICODE '1.4.2' -// 85: q BINPUT 4 -// 87: X BINUNICODE 'state_dict' -// 102: 
q BINPUT 5 -// 104: } EMPTY_DICT -// 105: q BINPUT 6 -// 107: ( MARK -// 108: X BINUNICODE 'betas' -// 118: q BINPUT 7 -// 120: c GLOBAL 'torch._utils _rebuild_tensor_v2' -// 153: q BINPUT 8 -// 155: ( MARK -// 156: ( MARK -// 157: X BINUNICODE 'storage' -// 169: q BINPUT 9 -// 171: c GLOBAL 'torch FloatStorage' -// 191: q BINPUT 10 -// 193: X BINUNICODE '0' -// 199: q BINPUT 11 -// 201: X BINUNICODE 'cpu' -// 209: q BINPUT 12 -// 211: M BININT2 1000 -// 214: t TUPLE (MARK at 156) -// 215: q BINPUT 13 -// 217: Q BINPERSID -// 218: K BININT1 0 -// 220: M BININT2 1000 -// ............................... -// 3201: q BINPUT 250 -// 3203: R REDUCE -// 3204: q BINPUT 251 -// 3206: X BINUNICODE 'model.diffusion_model.input_blocks.1.1.proj_in.weight' -// 3264: q BINPUT 252 -// 3266: h BINGET 8 -// 3268: ( MARK -// 3269: ( MARK -// 3270: h BINGET 9 -// 3272: h BINGET 10 -// 3274: X BINUNICODE '30' -// 3281: q BINPUT 253 -// 3283: h BINGET 12 -// 3285: J BININT 102400 -// 3290: t TUPLE (MARK at 3269) -// 3291: q BINPUT 254 -// 3293: Q BINPERSID -// 3294: K BININT1 0 -// 3296: ( MARK -// 3297: M BININT2 320 -// 3300: M BININT2 320 -// 3303: K BININT1 1 -// 3305: K BININT1 1 -// 3307: t TUPLE (MARK at 3296) -// 3308: q BINPUT 255 -// 3310: ( MARK -// 3311: M BININT2 320 -// 3314: K BININT1 1 -// 3316: K BININT1 1 -// 3318: K BININT1 1 -// 3320: t TUPLE (MARK at 3310) -// 3321: r LONG_BINPUT 256 -// 3326: \x89 NEWFALSE -// 3327: h BINGET 16 -// 3329: ) EMPTY_TUPLE -// 3330: R REDUCE -// 3331: r LONG_BINPUT 257 -// 3336: t TUPLE (MARK at 3268) -// 3337: r LONG_BINPUT 258 -// 3342: R REDUCE -// 3343: r LONG_BINPUT 259 -// 3348: X BINUNICODE 'model.diffusion_model.input_blocks.1.1.proj_in.bias' -// 3404: r LONG_BINPUT 260 -// 3409: h BINGET 8 -// 3411: ( MARK -// 3412: ( MARK -// 3413: h BINGET 9 -// 3415: h BINGET 10 -// 3417: X BINUNICODE '31' - -struct PickleTensorReader { - enum ReadPhase { - READ_NAME, - READ_DATA, - CHECK_SIZE, - READ_DIMENS - }; - ReadPhase phase = 
READ_NAME; - size_t entry_size = 0; - int32_t nelements = 0; - - TensorStorage tensor_storage; - - static ggml_type global_type; // all pickle_tensors data type - static bool read_global_type; - - bool read_int_value(uint32_t value) { - if (phase == CHECK_SIZE) { - if (entry_size == value * ggml_type_size(tensor_storage.type)) { - nelements = value; - phase = READ_DIMENS; - return true; - } else { - phase = READ_NAME; - } - } else if (phase == READ_DIMENS) { - if (tensor_storage.n_dims + 1 > SD_MAX_DIMS) { // too many dimens - phase = READ_NAME; - tensor_storage.n_dims = 0; - } - if (nelements % value == 0) { - tensor_storage.ne[tensor_storage.n_dims] = value; - tensor_storage.n_dims++; - } - } - return false; - } - - void read_global(const std::string& str) { - if (str == "FloatStorage") { - if (read_global_type) { - global_type = GGML_TYPE_F32; - read_global_type = false; - } - tensor_storage.type = GGML_TYPE_F32; - } else if (str == "HalfStorage") { - if (read_global_type) { - global_type = GGML_TYPE_F16; - read_global_type = false; - } - tensor_storage.type = GGML_TYPE_F16; - } - } - - void read_string(const std::string& str, zip_t* zip, std::string dir) { - if (str == "storage") { - read_global_type = true; - } else if (str != "state_dict") { - if (phase == READ_DATA) { - std::string entry_name = dir + "data/" + std::string(str); - - size_t i, n = zip_entries_total(zip); - for (i = 0; i < n; ++i) { - zip_entry_openbyindex(zip, i); - { - std::string name = zip_entry_name(zip); - if (name == entry_name) { - tensor_storage.index_in_zip = (int)i; - entry_size = zip_entry_size(zip); - zip_entry_close(zip); - break; - } - } - zip_entry_close(zip); - } - - phase = entry_size > 0 ? 
CHECK_SIZE : READ_NAME; - } - if (!read_global_type && phase == READ_NAME) { - tensor_storage.name = str; - phase = READ_DATA; - tensor_storage.type = global_type; - } - } - } -}; - -ggml_type PickleTensorReader::global_type = GGML_TYPE_F32; // all pickle_tensors data type -bool PickleTensorReader::read_global_type = false; - -int find_char(uint8_t* buffer, int len, char c) { - for (int pos = 0; pos < len; pos++) { - if (buffer[pos] == c) { - return pos; - } - } - return -1; -} - -#define MAX_STRING_BUFFER 512 - -bool ModelLoader::parse_data_pkl(uint8_t* buffer, - size_t buffer_size, - zip_t* zip, - std::string dir, - size_t file_index, - const std::string prefix) { - uint8_t* buffer_end = buffer + buffer_size; - if (buffer[0] == 0x80) { // proto - if (buffer[1] != 2) { - LOG_ERROR("Unsupported protocol\n"); - return false; - } - buffer += 2; // 0x80 and version - char string_buffer[MAX_STRING_BUFFER]; - bool finish = false; - PickleTensorReader reader; - // read pickle binary file - while (!finish && buffer < buffer_end) { - uint8_t opcode = *buffer; - buffer++; - // https://github.com/python/cpython/blob/3.7/Lib/pickletools.py#L1048 - // https://github.com/python/cpython/blob/main/Lib/pickle.py#L105 - switch (opcode) { - case '}': // EMPTY_DICT = b'}' # push empty dict - break; - case ']': // EMPTY_LIST = b']' # push empty list - break; - // skip unused sections - case 'h': // BINGET = b'h' # " " " " " " ; " " 1-byte arg - case 'q': // BINPUT = b'q' # " " " " " ; " " 1-byte arg - case 'Q': // BINPERSID = b'Q' # " " " ; " " " " stack - buffer++; - break; - case 'r': // LONG_BINPUT = b'r' # " " " " " ; " " 4-byte arg - buffer += 4; - break; - case 0x95: // FRAME = b'\x95' # indicate the beginning of a new frame - buffer += 8; - break; - case 0x94: // MEMOIZE = b'\x94' # store top of the stack in memo - break; - case '(': // MARK = b'(' # push special markobject on stack - break; - case 'K': // BININT1 = b'K' # push 1-byte unsigned int - { - uint8_t value = *buffer; 
- if (reader.read_int_value(value)) { - buffer++; - } - buffer++; - } break; - case 'M': // BININT2 = b'M' # push 2-byte unsigned int - { - uint16_t value = read_short(buffer); - if (reader.read_int_value(value)) { - buffer++; - } - buffer += 2; - } break; - case 'J': // BININT = b'J' # push four-byte signed int - { - const int32_t value = read_int(buffer); - if (reader.read_int_value(value)) { - buffer++; // skip tuple after read num_elements - } - buffer += 4; - } break; - case 'X': // BINUNICODE = b'X' # " " " ; counted UTF-8 string argument - { - const int32_t len = read_int(buffer); - buffer += 4; - memset(string_buffer, 0, MAX_STRING_BUFFER); - if (len > MAX_STRING_BUFFER) { - LOG_WARN("tensor name very large"); - } - memcpy(string_buffer, buffer, len < MAX_STRING_BUFFER ? len : (MAX_STRING_BUFFER - 1)); - buffer += len; - reader.read_string(string_buffer, zip, dir); - } break; - case 0x8C: // SHORT_BINUNICODE = b'\x8c' # push short string; UTF-8 length < 256 bytes - { - const int8_t len = *buffer; - buffer++; - memset(string_buffer, 0, MAX_STRING_BUFFER); - memcpy(string_buffer, buffer, len); - buffer += len; - // printf("String: '%s'\n", string_buffer); - } break; - case 'c': // GLOBAL = b'c' # push self.find_class(modname, name); 2 string args - { - int len = find_char(buffer, MAX_STRING_BUFFER, '\n'); - - buffer += len + 1; - len = find_char(buffer, MAX_STRING_BUFFER, '\n'); - - memset(string_buffer, 0, MAX_STRING_BUFFER); - memcpy(string_buffer, buffer, len); - buffer += len + 1; - reader.read_global(string_buffer); - } break; - case 0x86: // TUPLE2 = b'\x86' # build 2-tuple from two topmost stack items - case 0x85: // TUPLE1 = b'\x85' # build 1-tuple from stack top - case 't': // TUPLE = b't' # build tuple from topmost stack items - if (reader.phase == PickleTensorReader::READ_DIMENS) { - reader.tensor_storage.reverse_ne(); - reader.tensor_storage.file_index = file_index; - // if(strcmp(prefix.c_str(), "scarlett") == 0) - // printf(" ZIP got tensor %s 
\n ", reader.tensor_storage.name.c_str()); - std::string name = reader.tensor_storage.name; - if (!starts_with(name, prefix)) { - name = prefix + name; - } - reader.tensor_storage.name = name; - add_tensor_storage(reader.tensor_storage); - - // LOG_DEBUG("%s", reader.tensor_storage.name.c_str()); - // reset - reader = PickleTensorReader(); - } - break; - case '.': // STOP = b'.' # every pickle ends with STOP - finish = true; - break; - default: - break; - } - } - } - return true; -} - -bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::string& prefix) { - LOG_DEBUG("init from '%s'", file_path.c_str()); - file_paths_.push_back(file_path); - size_t file_index = file_paths_.size() - 1; - - zip_t* zip = zip_open(file_path.c_str(), 0, 'r'); - if (zip == nullptr) { - LOG_ERROR("failed to open '%s'", file_path.c_str()); - return false; - } - int n = (int)zip_entries_total(zip); - for (int i = 0; i < n; ++i) { - zip_entry_openbyindex(zip, i); - { - std::string name = zip_entry_name(zip); - size_t pos = name.find("data.pkl"); - if (pos != std::string::npos) { - std::string dir = name.substr(0, pos); - printf("ZIP %d, name = %s, dir = %s \n", i, name.c_str(), dir.c_str()); - void* pkl_data = nullptr; - size_t pkl_size; - zip_entry_read(zip, &pkl_data, &pkl_size); - - // LOG_DEBUG("%lld", pkl_size); - - parse_data_pkl((uint8_t*)pkl_data, pkl_size, zip, dir, file_index, prefix); - - free(pkl_data); - } - } - zip_entry_close(zip); - } - zip_close(zip); - return true; -} - SDVersion ModelLoader::get_sd_version() { TensorStorage token_embedding_weight, input_block_weight; @@ -1006,64 +425,66 @@ SDVersion ModelLoader::get_sd_version() { bool has_middle_block_1 = false; bool has_output_block_311 = false; bool has_output_block_71 = false; + bool has_attn_1024 = false; for (auto& [name, tensor_storage] : tensor_storage_map) { - if (!(is_xl)) { - if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) { - is_flux = true; - 
} - if (tensor_storage.name.find("model.diffusion_model.nerf_final_layer_conv.") != std::string::npos) { - return VERSION_CHROMA_RADIANCE; - } - if (tensor_storage.name.find("model.diffusion_model.joint_blocks.") != std::string::npos) { - return VERSION_SD3; - } - if (tensor_storage.name.find("model.diffusion_model.transformer_blocks.0.img_mod.1.weight") != std::string::npos) { - return VERSION_QWEN_IMAGE; - } - if (tensor_storage.name.find("llm_adapter.blocks.0.cross_attn.q_proj.weight") != std::string::npos) { - return VERSION_ANIMA; - } - if (tensor_storage.name.find("model.diffusion_model.double_stream_modulation_img.lin.weight") != std::string::npos) { - is_flux2 = true; - } - if (tensor_storage.name.find("single_blocks.47.linear1.weight") != std::string::npos) { - has_single_block_47 = true; - } - if (tensor_storage.name.find("model.diffusion_model.double_blocks.0.img_mlp.gate_proj.weight") != std::string::npos) { - return VERSION_OVIS_IMAGE; - } - if (tensor_storage.name.find("model.diffusion_model.cap_embedder.0.weight") != std::string::npos) { - return VERSION_Z_IMAGE; - } - if (tensor_storage.name.find("model.diffusion_model.blocks.0.cross_attn.norm_k.weight") != std::string::npos) { - is_wan = true; - } - if (tensor_storage.name.find("model.diffusion_model.patch_embedding.weight") != std::string::npos) { - patch_embedding_channels = tensor_storage.ne[3]; - } - if (tensor_storage.name.find("model.diffusion_model.img_emb") != std::string::npos) { - has_img_emb = true; - } - if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos || - tensor_storage.name.find("unet.down_blocks.") != std::string::npos) { - is_unet = true; - if (has_multiple_encoders) { - is_xl = true; - } - } - if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos || - tensor_storage.name.find("cond_stage_model.1") != std::string::npos || - tensor_storage.name.find("te.1") != std::string::npos) { - has_multiple_encoders = true; - if 
(is_unet) { - is_xl = true; - } + if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) { + is_flux = true; + } + if (tensor_storage.name.find("model.diffusion_model.nerf_final_layer_conv.") != std::string::npos) { + return VERSION_CHROMA_RADIANCE; + } + if (tensor_storage.name.find("model.diffusion_model.joint_blocks.") != std::string::npos) { + return VERSION_SD3; + } + if (tensor_storage.name.find("model.diffusion_model.transformer_blocks.0.img_mod.1.weight") != std::string::npos) { + return VERSION_QWEN_IMAGE; + } + if (tensor_storage.name.find("llm_adapter.blocks.0.cross_attn.q_proj.weight") != std::string::npos) { + return VERSION_ANIMA; + } + if (tensor_storage.name.find("model.diffusion_model.double_stream_modulation_img.lin.weight") != std::string::npos) { + is_flux2 = true; + } + if (tensor_storage.name.find("single_blocks.47.linear1.weight") != std::string::npos) { + has_single_block_47 = true; + } + if (tensor_storage.name.find("model.diffusion_model.double_blocks.0.img_mlp.gate_proj.weight") != std::string::npos) { + return VERSION_OVIS_IMAGE; + } + if (tensor_storage.name.find("model.diffusion_model.cap_embedder.0.weight") != std::string::npos) { + return VERSION_Z_IMAGE; + } + if (tensor_storage.name.find("model.diffusion_model.layers.0.adaLN_sa_ln.weight") != std::string::npos) { + return VERSION_ERNIE_IMAGE; + } + if (tensor_storage.name.find("model.diffusion_model.blocks.0.cross_attn.norm_k.weight") != std::string::npos) { + is_wan = true; + } + if (tensor_storage.name.find("model.diffusion_model.patch_embedding.weight") != std::string::npos) { + patch_embedding_channels = tensor_storage.ne[3]; + } + if (tensor_storage.name.find("model.diffusion_model.img_emb") != std::string::npos) { + has_img_emb = true; + } + if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos || + tensor_storage.name.find("unet.down_blocks.") != std::string::npos) { + is_unet = true; + if 
(has_multiple_encoders) { + is_xl = true; } - if (tensor_storage.name.find("model.diffusion_model.input_blocks.8.0.time_mixer.mix_factor") != std::string::npos) { - return VERSION_SVD; + } + if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos || + tensor_storage.name.find("cond_stage_model.1") != std::string::npos || + tensor_storage.name.find("te.1") != std::string::npos) { + has_multiple_encoders = true; + if (is_unet) { + is_xl = true; } } + if (tensor_storage.name.find("model.diffusion_model.input_blocks.8.0.time_mixer.mix_factor") != std::string::npos) { + return VERSION_SVD; + } if (tensor_storage.name.find("model.diffusion_model.middle_block.1.") != std::string::npos || tensor_storage.name.find("unet.mid_block.resnets.1.") != std::string::npos) { has_middle_block_1 = true; @@ -1075,6 +496,10 @@ SDVersion ModelLoader::get_sd_version() { if (tensor_storage.name.find("model.diffusion_model.output_blocks.7.1") != std::string::npos || tensor_storage.name.find("unet.up_blocks.2.attentions.1") != std::string::npos) { has_output_block_71 = true; + if (tensor_storage.name.find("model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_k.weight") != std::string::npos) { + if (tensor_storage.ne[0] == 1024) + has_attn_1024 = true; + } } if (tensor_storage.name == "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight" || tensor_storage.name == "cond_stage_model.model.token_embedding.weight" || @@ -1148,7 +573,7 @@ SDVersion ModelLoader::get_sd_version() { } if (!has_middle_block_1) { if (!has_output_block_71) { - return VERSION_SDXS; + return VERSION_SDXS_512_DS; } return VERSION_SD1_TINY_UNET; } @@ -1158,7 +583,7 @@ SDVersion ModelLoader::get_sd_version() { return VERSION_SD2_INPAINT; } if (!has_middle_block_1) { - return VERSION_SD2_TINY_UNET; + return has_attn_1024 ? 
VERSION_SDXS_09 : VERSION_SD2_TINY_UNET; } return VERSION_SD2; } @@ -1249,8 +674,8 @@ std::map ModelLoader::get_vae_wtype_stat() { return wtype_stat; } -static std::vector> parse_tensor_type_rules(const std::string& tensor_type_rules) { - std::vector> result; +TensorTypeRules parse_tensor_type_rules(const std::string& tensor_type_rules) { + TensorTypeRules result; for (const auto& item : split_string(tensor_type_rules, ',')) { if (item.size() == 0) continue; @@ -1683,76 +1108,6 @@ bool ModelLoader::tensor_should_be_converted(const TensorStorage& tensor_storage return false; } -bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type, const std::string& tensor_type_rules_str) { - auto backend = ggml_backend_cpu_init(); - size_t mem_size = 1 * 1024 * 1024; // for padding - mem_size += tensor_storage_map.size() * ggml_tensor_overhead(); - mem_size += get_params_mem_size(backend, type); - LOG_INFO("model tensors mem size: %.2fMB", mem_size / 1024.f / 1024.f); - ggml_context* ggml_ctx = ggml_init({mem_size, nullptr, false}); - - gguf_context* gguf_ctx = gguf_init_empty(); - - auto tensor_type_rules = parse_tensor_type_rules(tensor_type_rules_str); - - std::mutex tensor_mutex; - auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool { - const std::string& name = tensor_storage.name; - ggml_type tensor_type = tensor_storage.type; - ggml_type dst_type = type; - - for (const auto& tensor_type_rule : tensor_type_rules) { - std::regex pattern(tensor_type_rule.first); - if (std::regex_search(name, pattern)) { - dst_type = tensor_type_rule.second; - break; - } - } - - if (tensor_should_be_converted(tensor_storage, dst_type)) { - tensor_type = dst_type; - } - - std::lock_guard lock(tensor_mutex); - ggml_tensor* tensor = ggml_new_tensor(ggml_ctx, tensor_type, tensor_storage.n_dims, tensor_storage.ne); - if (tensor == nullptr) { - LOG_ERROR("ggml_new_tensor failed"); - return false; - } - ggml_set_name(tensor, 
name.c_str()); - - // LOG_DEBUG("%s %d %s %d[%d %d %d %d] %d[%d %d %d %d]", name.c_str(), - // ggml_nbytes(tensor), ggml_type_name(tensor_type), - // tensor_storage.n_dims, - // tensor_storage.ne[0], tensor_storage.ne[1], tensor_storage.ne[2], tensor_storage.ne[3], - // tensor->n_dims, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); - - if (!tensor->data) { - GGML_ASSERT(ggml_nelements(tensor) == 0); - // avoid crashing the gguf writer by setting a dummy pointer for zero-sized tensors - LOG_DEBUG("setting dummy pointer for zero-sized tensor %s", name.c_str()); - tensor->data = ggml_get_mem_buffer(ggml_ctx); - } - - *dst_tensor = tensor; - - gguf_add_tensor(gguf_ctx, tensor); - - return true; - }; - - bool success = load_tensors(on_new_tensor_cb); - ggml_backend_free(backend); - LOG_INFO("load tensors done"); - LOG_INFO("trying to save tensors to %s", file_path.c_str()); - if (success) { - gguf_write_to_file(gguf_ctx, file_path.c_str(), false); - } - ggml_free(ggml_ctx); - gguf_free(gguf_ctx); - return success; -} - int64_t ModelLoader::get_params_mem_size(ggml_backend_t backend, ggml_type type) { size_t alignment = 128; if (backend != nullptr) { @@ -1772,29 +1127,3 @@ int64_t ModelLoader::get_params_mem_size(ggml_backend_t backend, ggml_type type) return mem_size; } - -bool convert(const char* input_path, - const char* vae_path, - const char* output_path, - sd_type_t output_type, - const char* tensor_type_rules, - bool convert_name) { - ModelLoader model_loader; - - if (!model_loader.init_from_file(input_path)) { - LOG_ERROR("init model loader from file failed: '%s'", input_path); - return false; - } - - if (vae_path != nullptr && strlen(vae_path) > 0) { - if (!model_loader.init_from_file(vae_path, "vae.")) { - LOG_ERROR("init model loader from file failed: '%s'", vae_path); - return false; - } - } - if (convert_name) { - model_loader.convert_tensors_name(); - } - bool success = model_loader.save_to_gguf_file(output_path, (ggml_type)output_type, 
tensor_type_rules); - return success; -} diff --git a/src/model.h b/src/model.h index 3af35eb7e..65bc6c367 100644 --- a/src/model.h +++ b/src/model.h @@ -5,20 +5,13 @@ #include #include #include -#include #include -#include -#include #include #include "ggml-backend.h" #include "ggml.h" -#include "gguf.h" -#include "json.hpp" +#include "model_io/tensor_storage.h" #include "ordered_map.hpp" -#include "zip.h" - -#define SD_MAX_DIMS 5 enum SDVersion { VERSION_SD1, @@ -28,7 +21,8 @@ enum SDVersion { VERSION_SD2, VERSION_SD2_INPAINT, VERSION_SD2_TINY_UNET, - VERSION_SDXS, + VERSION_SDXS_512_DS, + VERSION_SDXS_09, VERSION_SDXL, VERSION_SDXL_INPAINT, VERSION_SDXL_PIX2PIX, @@ -50,18 +44,19 @@ enum SDVersion { VERSION_FLUX2_KLEIN, VERSION_Z_IMAGE, VERSION_OVIS_IMAGE, + VERSION_ERNIE_IMAGE, VERSION_COUNT, }; static inline bool sd_version_is_sd1(SDVersion version) { - if (version == VERSION_SD1 || version == VERSION_SD1_INPAINT || version == VERSION_SD1_PIX2PIX || version == VERSION_SD1_TINY_UNET || version == VERSION_SDXS) { + if (version == VERSION_SD1 || version == VERSION_SD1_INPAINT || version == VERSION_SD1_PIX2PIX || version == VERSION_SD1_TINY_UNET || version == VERSION_SDXS_512_DS) { return true; } return false; } static inline bool sd_version_is_sd2(SDVersion version) { - if (version == VERSION_SD2 || version == VERSION_SD2_INPAINT || version == VERSION_SD2_TINY_UNET) { + if (version == VERSION_SD2 || version == VERSION_SD2_INPAINT || version == VERSION_SD2_TINY_UNET || version == VERSION_SDXS_09) { return true; } return false; @@ -137,6 +132,20 @@ static inline bool sd_version_is_z_image(SDVersion version) { return false; } +static inline bool sd_version_is_ernie_image(SDVersion version) { + if (version == VERSION_ERNIE_IMAGE) { + return true; + } + return false; +} + +static inline bool sd_version_uses_flux2_vae(SDVersion version) { + if (sd_version_is_flux2(version) || sd_version_is_ernie_image(version)) { + return true; + } + return false; +} + static inline bool 
sd_version_is_inpaint(SDVersion version) { if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || @@ -155,7 +164,8 @@ static inline bool sd_version_is_dit(SDVersion version) { sd_version_is_wan(version) || sd_version_is_qwen_image(version) || sd_version_is_anima(version) || - sd_version_is_z_image(version)) { + sd_version_is_z_image(version) || + sd_version_is_ernie_image(version)) { return true; } return false; @@ -178,116 +188,10 @@ enum PMVersion { PM_VERSION_2, }; -struct TensorStorage { - std::string name; - ggml_type type = GGML_TYPE_F32; - ggml_type expected_type = GGML_TYPE_COUNT; - bool is_f8_e4m3 = false; - bool is_f8_e5m2 = false; - bool is_f64 = false; - bool is_i64 = false; - int64_t ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1}; - int n_dims = 0; - - size_t file_index = 0; - int index_in_zip = -1; // >= means stored in a zip file - uint64_t offset = 0; // offset in file - - TensorStorage() = default; - - TensorStorage(std::string name, ggml_type type, const int64_t* ne, int n_dims, size_t file_index, size_t offset = 0) - : name(std::move(name)), type(type), n_dims(n_dims), file_index(file_index), offset(offset) { - for (int i = 0; i < n_dims; i++) { - this->ne[i] = ne[i]; - } - } - - int64_t nelements() const { - int64_t n = 1; - for (int i = 0; i < SD_MAX_DIMS; i++) { - n *= ne[i]; - } - return n; - } - - int64_t nbytes() const { - return nelements() * ggml_type_size(type) / ggml_blck_size(type); - } - - int64_t nbytes_to_read() const { - if (is_f8_e4m3 || is_f8_e5m2) { - return nbytes() / 2; - } else if (is_f64 || is_i64) { - return nbytes() * 2; - } else { - return nbytes(); - } - } - - void unsqueeze() { - if (n_dims == 2) { - n_dims = 4; - ne[3] = ne[1]; - ne[2] = ne[0]; - ne[1] = 1; - ne[0] = 1; - } - } - - std::vector chunk(size_t n) { - std::vector chunks; - uint64_t chunk_size = nbytes_to_read() / n; - // printf("%d/%d\n", chunk_size, nbytes_to_read()); - reverse_ne(); - for (size_t i = 0; i < n; i++) { - TensorStorage chunk_i = *this; - 
chunk_i.ne[0] = ne[0] / n; - chunk_i.offset = offset + i * chunk_size; - chunk_i.reverse_ne(); - chunks.push_back(chunk_i); - } - reverse_ne(); - return chunks; - } - - void reverse_ne() { - int64_t new_ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1}; - for (int i = 0; i < n_dims; i++) { - new_ne[i] = ne[n_dims - 1 - i]; - } - for (int i = 0; i < n_dims; i++) { - ne[i] = new_ne[i]; - } - } - - std::string to_string() const { - std::stringstream ss; - const char* type_name = ggml_type_name(type); - if (is_f8_e4m3) { - type_name = "f8_e4m3"; - } else if (is_f8_e5m2) { - type_name = "f8_e5m2"; - } else if (is_f64) { - type_name = "f64"; - } else if (is_i64) { - type_name = "i64"; - } - ss << name << " | " << type_name << " | "; - ss << n_dims << " ["; - for (int i = 0; i < SD_MAX_DIMS; i++) { - ss << ne[i]; - if (i != SD_MAX_DIMS - 1) { - ss << ", "; - } - } - ss << "]"; - return ss.str(); - } -}; - -typedef std::function on_new_tensor_cb_t; - typedef OrderedMap String2TensorStorage; +using TensorTypeRules = std::vector>; + +TensorTypeRules parse_tensor_type_rules(const std::string& tensor_type_rules); class ModelLoader { protected: @@ -297,16 +201,10 @@ class ModelLoader { void add_tensor_storage(const TensorStorage& tensor_storage); - bool parse_data_pkl(uint8_t* buffer, - size_t buffer_size, - zip_t* zip, - std::string dir, - size_t file_index, - const std::string prefix); - bool init_from_gguf_file(const std::string& file_path, const std::string& prefix = ""); bool init_from_safetensors_file(const std::string& file_path, const std::string& prefix = ""); - bool init_from_ckpt_file(const std::string& file_path, const std::string& prefix = ""); + bool init_from_torch_zip_file(const std::string& file_path, const std::string& prefix = ""); + bool init_from_torch_legacy_file(const std::string& file_path, const std::string& prefix = ""); bool init_from_diffusers_file(const std::string& file_path, const std::string& prefix = ""); public: @@ -336,7 +234,6 @@ class ModelLoader { return 
names; } - bool save_to_gguf_file(const std::string& file_path, ggml_type type, const std::string& tensor_type_rules); bool tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type); int64_t get_params_mem_size(ggml_backend_t backend, ggml_type type = GGML_TYPE_COUNT); ~ModelLoader() = default; diff --git a/src/model_io/binary_io.h b/src/model_io/binary_io.h new file mode 100644 index 000000000..9093eeaf9 --- /dev/null +++ b/src/model_io/binary_io.h @@ -0,0 +1,57 @@ +#ifndef __SD_MODEL_IO_BINARY_IO_H__ +#define __SD_MODEL_IO_BINARY_IO_H__ + +#include +#include + +namespace model_io { + + inline int32_t read_int(const uint8_t* buffer) { + uint32_t value = 0; + value |= static_cast(buffer[3]) << 24; + value |= static_cast(buffer[2]) << 16; + value |= static_cast(buffer[1]) << 8; + value |= static_cast(buffer[0]); + return static_cast(value); + } + + inline uint16_t read_short(const uint8_t* buffer) { + uint16_t value = 0; + value |= static_cast(buffer[1]) << 8; + value |= static_cast(buffer[0]); + return value; + } + + inline uint64_t read_u64(const uint8_t* buffer) { + uint64_t value = 0; + value |= static_cast(buffer[7]) << 56; + value |= static_cast(buffer[6]) << 48; + value |= static_cast(buffer[5]) << 40; + value |= static_cast(buffer[4]) << 32; + value |= static_cast(buffer[3]) << 24; + value |= static_cast(buffer[2]) << 16; + value |= static_cast(buffer[1]) << 8; + value |= static_cast(buffer[0]); + return value; + } + + inline void write_u64(std::ostream& stream, uint64_t value) { + uint8_t buffer[8]; + for (int i = 0; i < 8; ++i) { + buffer[i] = static_cast((value >> (8 * i)) & 0xFF); + } + stream.write((const char*)buffer, sizeof(buffer)); + } + + inline int find_char(const uint8_t* buffer, int len, char c) { + for (int pos = 0; pos < len; pos++) { + if (buffer[pos] == (uint8_t)c) { + return pos; + } + } + return -1; + } + +} // namespace model_io + +#endif // __SD_MODEL_IO_BINARY_IO_H__ diff --git a/src/model_io/gguf_io.cpp 
b/src/model_io/gguf_io.cpp new file mode 100644 index 000000000..378694d8e --- /dev/null +++ b/src/model_io/gguf_io.cpp @@ -0,0 +1,123 @@ +#include "gguf_io.h" + +#include +#include +#include +#include + +#include "gguf.h" +#include "gguf_reader_ext.h" +#include "util.h" + +static void set_error(std::string* error, const std::string& message) { + if (error != nullptr) { + *error = message; + } +} + +bool is_gguf_file(const std::string& file_path) { + std::ifstream file(file_path, std::ios::binary); + if (!file.is_open()) { + return false; + } + + char magic[4]; + + file.read(magic, sizeof(magic)); + if (!file) { + return false; + } + for (uint32_t i = 0; i < sizeof(magic); i++) { + if (magic[i] != GGUF_MAGIC[i]) { + return false; + } + } + + return true; +} + +bool read_gguf_file(const std::string& file_path, + std::vector& tensor_storages, + std::string* error) { + tensor_storages.clear(); + + gguf_context* ctx_gguf_ = nullptr; + ggml_context* ctx_meta_ = nullptr; + + ctx_gguf_ = gguf_init_from_file(file_path.c_str(), {true, &ctx_meta_}); + if (!ctx_gguf_) { + GGUFReader gguf_reader; + if (!gguf_reader.load(file_path)) { + set_error(error, "failed to open '" + file_path + "' with GGUFReader"); + return false; + } + + size_t data_offset = gguf_reader.data_offset(); + for (const auto& gguf_tensor_info : gguf_reader.tensors()) { + TensorStorage tensor_storage( + gguf_tensor_info.name, + gguf_tensor_info.type, + gguf_tensor_info.shape.data(), + static_cast(gguf_tensor_info.shape.size()), + 0, + data_offset + gguf_tensor_info.offset); + + tensor_storages.push_back(tensor_storage); + } + + return true; + } + + int n_tensors = static_cast(gguf_get_n_tensors(ctx_gguf_)); + + size_t data_offset = gguf_get_data_offset(ctx_gguf_); + for (int i = 0; i < n_tensors; i++) { + std::string name = gguf_get_tensor_name(ctx_gguf_, i); + ggml_tensor* dummy = ggml_get_tensor(ctx_meta_, name.c_str()); + size_t offset = data_offset + gguf_get_tensor_offset(ctx_gguf_, i); + + 
TensorStorage tensor_storage(name, dummy->type, dummy->ne, ggml_n_dims(dummy), 0, offset); + + if (ggml_nbytes(dummy) != tensor_storage.nbytes()) { + gguf_free(ctx_gguf_); + ggml_free(ctx_meta_); + set_error(error, "size mismatch for tensor '" + name + "'"); + return false; + } + + tensor_storages.push_back(tensor_storage); + } + + gguf_free(ctx_gguf_); + ggml_free(ctx_meta_); + + return true; +} + +bool write_gguf_file(const std::string& file_path, + const std::vector& tensors, + std::string* error) { + gguf_context* gguf_ctx = gguf_init_empty(); + if (gguf_ctx == nullptr) { + set_error(error, "gguf_init_empty failed"); + return false; + } + + for (const TensorWriteInfo& write_tensor : tensors) { + ggml_tensor* tensor = write_tensor.tensor; + if (tensor == nullptr) { + set_error(error, "null tensor cannot be written to GGUF"); + gguf_free(gguf_ctx); + return false; + } + gguf_add_tensor(gguf_ctx, tensor); + } + + LOG_INFO("trying to save tensors to %s", file_path.c_str()); + bool success = gguf_write_to_file(gguf_ctx, file_path.c_str(), false); + if (!success) { + set_error(error, "failed to write GGUF file '" + file_path + "'"); + } + gguf_free(gguf_ctx); + return success; +} diff --git a/src/model_io/gguf_io.h b/src/model_io/gguf_io.h new file mode 100644 index 000000000..81c981145 --- /dev/null +++ b/src/model_io/gguf_io.h @@ -0,0 +1,17 @@ +#ifndef __SD_MODEL_IO_GGUF_IO_H__ +#define __SD_MODEL_IO_GGUF_IO_H__ + +#include +#include + +#include "tensor_storage.h" + +bool is_gguf_file(const std::string& file_path); +bool read_gguf_file(const std::string& file_path, + std::vector& tensor_storages, + std::string* error = nullptr); +bool write_gguf_file(const std::string& file_path, + const std::vector& tensors, + std::string* error = nullptr); + +#endif // __SD_MODEL_IO_GGUF_IO_H__ diff --git a/src/gguf_reader.hpp b/src/model_io/gguf_reader_ext.h similarity index 97% rename from src/gguf_reader.hpp rename to src/model_io/gguf_reader_ext.h index 2cc4d9d9c..95f0027fc 
100644 --- a/src/gguf_reader.hpp +++ b/src/model_io/gguf_reader_ext.h @@ -1,5 +1,5 @@ -#ifndef __GGUF_READER_HPP__ -#define __GGUF_READER_HPP__ +#ifndef __SD_MODEL_IO_GGUF_READER_EXT_H__ +#define __SD_MODEL_IO_GGUF_READER_EXT_H__ #include #include @@ -59,6 +59,9 @@ class GGUFReader { if (!safe_read(fin, key_len)) return false; + if (key_len > 4096) + return false; + std::string key(key_len, '\0'); if (!safe_read(fin, (char*)key.data(), key_len)) return false; @@ -228,4 +231,4 @@ class GGUFReader { size_t data_offset() const { return data_offset_; } }; -#endif // __GGUF_READER_HPP__ +#endif // __SD_MODEL_IO_GGUF_READER_EXT_H__ diff --git a/src/model_io/pickle_io.cpp b/src/model_io/pickle_io.cpp new file mode 100644 index 000000000..3a978178a --- /dev/null +++ b/src/model_io/pickle_io.cpp @@ -0,0 +1,1064 @@ +#include "pickle_io.h" + +#include +#include +#include +#include +#include +#include + +#include "binary_io.h" +#include "util.h" + +// $ python -m pickletools sd-v1-4/archive/data.pkl | head -n 100 +// 0: \x80 PROTO 2 +// 2: } EMPTY_DICT +// 3: q BINPUT 0 +// 5: ( MARK +// 6: X BINUNICODE 'epoch' +// 16: q BINPUT 1 +// 18: K BININT1 6 +// 20: X BINUNICODE 'global_step' +// 36: q BINPUT 2 +// 38: J BININT 470000 +// 43: X BINUNICODE 'pytorch-lightning_version' +// 73: q BINPUT 3 +// 75: X BINUNICODE '1.4.2' +// 85: q BINPUT 4 +// 87: X BINUNICODE 'state_dict' +// 102: q BINPUT 5 +// 104: } EMPTY_DICT +// 105: q BINPUT 6 +// 107: ( MARK +// 108: X BINUNICODE 'betas' +// 118: q BINPUT 7 +// 120: c GLOBAL 'torch._utils _rebuild_tensor_v2' +// 153: q BINPUT 8 +// 155: ( MARK +// 156: ( MARK +// 157: X BINUNICODE 'storage' +// 169: q BINPUT 9 +// 171: c GLOBAL 'torch FloatStorage' +// 191: q BINPUT 10 +// 193: X BINUNICODE '0' +// 199: q BINPUT 11 +// 201: X BINUNICODE 'cpu' +// 209: q BINPUT 12 +// 211: M BININT2 1000 +// 214: t TUPLE (MARK at 156) +// 215: q BINPUT 13 +// 217: Q BINPERSID +// 218: K BININT1 0 +// 220: M BININT2 1000 +// 
............................... +// 3201: q BINPUT 250 +// 3203: R REDUCE +// 3204: q BINPUT 251 +// 3206: X BINUNICODE 'model.diffusion_model.input_blocks.1.1.proj_in.weight' +// 3264: q BINPUT 252 +// 3266: h BINGET 8 +// 3268: ( MARK +// 3269: ( MARK +// 3270: h BINGET 9 +// 3272: h BINGET 10 +// 3274: X BINUNICODE '30' +// 3281: q BINPUT 253 +// 3283: h BINGET 12 +// 3285: J BININT 102400 +// 3290: t TUPLE (MARK at 3269) +// 3291: q BINPUT 254 +// 3293: Q BINPERSID +// 3294: K BININT1 0 +// 3296: ( MARK +// 3297: M BININT2 320 +// 3300: M BININT2 320 +// 3303: K BININT1 1 +// 3305: K BININT1 1 +// 3307: t TUPLE (MARK at 3296) +// 3308: q BINPUT 255 +// 3310: ( MARK +// 3311: M BININT2 320 +// 3314: K BININT1 1 +// 3316: K BININT1 1 +// 3318: K BININT1 1 +// 3320: t TUPLE (MARK at 3310) +// 3321: r LONG_BINPUT 256 +// 3326: \x89 NEWFALSE +// 3327: h BINGET 16 +// 3329: ) EMPTY_TUPLE +// 3330: R REDUCE +// 3331: r LONG_BINPUT 257 +// 3336: t TUPLE (MARK at 3268) +// 3337: r LONG_BINPUT 258 +// 3342: R REDUCE +// 3343: r LONG_BINPUT 259 +// 3348: X BINUNICODE 'model.diffusion_model.input_blocks.1.1.proj_in.bias' +// 3404: r LONG_BINPUT 260 +// 3409: h BINGET 8 +// 3411: ( MARK +// 3412: ( MARK +// 3413: h BINGET 9 +// 3415: h BINGET 10 +// 3417: X BINUNICODE '31' +// https://github.com/python/cpython/blob/3.7/Lib/pickletools.py#L1048 +// https://github.com/python/cpython/blob/main/Lib/pickle.py#L105 + +using model_io::find_char; +using model_io::read_int; +using model_io::read_short; +using model_io::read_u64; + +static void set_error(std::string* error, const std::string& message) { + if (error != nullptr) { + *error = message; + } +} + +bool skip_pickle_object(const uint8_t* buffer, size_t buffer_size, size_t* object_size) { + const uint8_t* p = buffer; + const uint8_t* end = buffer + buffer_size; + + while (p < end) { + uint8_t opcode = *p++; + switch (opcode) { + case '.': // STOP = b'.' 
# every pickle ends with STOP + *object_size = (size_t)(p - buffer); + return true; + case 0x80: // PROTO = b'\x80' # protocol version indicator + case 'K': // BININT1 = b'K' # push 1-byte unsigned int + case 'h': // BINGET = b'h' # read memo index, 1-byte arg + case 'q': // BINPUT = b'q' # write memo index, 1-byte arg + case 'C': // SHORT_BINBYTES = b'C' # push bytes; length < 256 + case 0x82: // EXT1 = b'\x82' # extension code, 1-byte arg + p += 1; + break; + case 'M': // BININT2 = b'M' # push 2-byte unsigned int + case 0x83: // EXT2 = b'\x83' # extension code, 2-byte arg + p += 2; + break; + case 'J': // BININT = b'J' # push 4-byte signed int + case 'j': // LONG_BINGET = b'j' # read memo index, 4-byte arg + case 'r': // LONG_BINPUT = b'r' # write memo index, 4-byte arg + case 0x84: // EXT4 = b'\x84' # extension code, 4-byte arg + p += 4; + break; + case 'I': // INT = b'I' # push decimal integer line + case 'L': // LONG = b'L' # push decimal long integer line + case 'F': // FLOAT = b'F' # push decimal float line + case 'S': // STRING = b'S' # push quoted string line + case 'V': { // UNICODE = b'V' # push raw-unicode string line + int len = find_char(p, (int)(end - p), '\n'); + if (len < 0) { + return false; + } + p += len + 1; + } break; + case 'G': // BINFLOAT = b'G' # push 8-byte binary float + p += 8; + break; + case 0x8A: // LONG1 = b'\x8a' # push long integer; 1-byte length + if (p >= end) { + return false; + } + p += 1 + p[0]; + break; + case 0x8B: { // LONG4 = b'\x8b' # push long integer; 4-byte length + if (p + 4 > end) { + return false; + } + uint32_t n = (uint32_t)read_int(p); + p += 4 + n; + } break; + case 'B': { // BINBYTES = b'B' # push bytes; 4-byte length + if (p + 4 > end) { + return false; + } + uint32_t n = (uint32_t)read_int(p); + p += 4 + n; + } break; + case 'T': // BINSTRING = b'T' # push string; 4-byte length + case 'X': { // BINUNICODE = b'X' # push UTF-8 string; 4-byte length + if (p + 4 > end) { + return false; + } + uint32_t n = 
(uint32_t)read_int(p); + p += 4 + n; + } break; + case 0x8D: // BINUNICODE8 = b'\x8d' # push UTF-8 string; 8-byte length + case 0x8E: // BINBYTES8 = b'\x8e' # push bytes; 8-byte length + case 0x96: { // BYTEARRAY8 = b'\x96' # push bytearray; 8-byte length + if (p + 8 > end) { + return false; + } + uint64_t n = read_u64(p); + p += 8; + if (n > (uint64_t)(end - p)) { + return false; + } + p += n; + } break; + case 'U': // SHORT_BINSTRING = b'U' # push string; length < 256 + case 0x8C: // SHORT_BINUNICODE = b'\x8c' # push UTF-8 string; length < 256 + if (p >= end) { + return false; + } + p += 1 + p[0]; + break; + case 'P': { // PERSID = b'P' # persistent id, newline-terminated + int len = find_char(p, (int)(end - p), '\n'); + if (len < 0) { + return false; + } + p += len + 1; + } break; + case 0x95: // FRAME = b'\x95' # indicate the beginning of a new frame + p += 8; + break; + case 'c': { // GLOBAL = b'c' # push module/name global reference + int len = find_char(p, (int)(end - p), '\n'); + if (len < 0) { + return false; + } + p += len + 1; + len = find_char(p, (int)(end - p), '\n'); + if (len < 0) { + return false; + } + p += len + 1; + } break; + case '}': // EMPTY_DICT = b'}' # push empty dict + case ']': // EMPTY_LIST = b']' # push empty list + case '(': // MARK = b'(' # push markobject + case 't': // TUPLE = b't' # build tuple from mark + case 0x85: // TUPLE1 = b'\x85' # build 1-tuple from stack + case 0x86: // TUPLE2 = b'\x86' # build 2-tuple from stack + case 0x87: // TUPLE3 = b'\x87' # build 3-tuple from stack + case ')': // EMPTY_TUPLE = b')' # push empty tuple + case 'l': // LIST = b'l' # build list from mark + case 'Q': // BINPERSID = b'Q' # persistent id from stack + case 0x94: // MEMOIZE = b'\x94' # store top of stack in memo + case 0x88: // NEWTRUE = b'\x88' # push True + case 0x89: // NEWFALSE = b'\x89' # push False + case 'R': // REDUCE = b'R' # apply callable to args + case 'u': // SETITEMS = b'u' # add mark-delimited items to dict + case 's': // 
SETITEM = b's' # add key/value to dict + case 'e': // APPENDS = b'e' # extend list with mark-delimited items + case 'a': // APPEND = b'a' # append item to list + case 'b': // BUILD = b'b' # build object state + case 0x81: // NEWOBJ = b'\x81' # build object via __new__ + case 0x8F: // EMPTY_SET = b'\x8f' # push empty set + case 0x90: // ADDITEMS = b'\x90' # add mark-delimited items to set + case 0x91: // FROZENSET = b'\x91' # build frozenset from mark + case 0x92: // NEWOBJ_EX = b'\x92' # build object with kwargs + case 0x93: // STACK_GLOBAL = b'\x93' # build global from module/name strings + case 0x97: // NEXT_BUFFER = b'\x97' # out-of-band buffer marker + case 0x98: // READONLY_BUFFER = b'\x98' # mark buffer readonly + case 'N': // NONE = b'N' # push None + case '0': // POP = b'0' # discard top stack item + case '1': // POP_MARK = b'1' # discard stack through topmost mark + case '2': // DUP = b'2' # duplicate top stack item + case 'o': // OBJ = b'o' # build class instance from mark + break; + case 'i': { // INST = b'i' # build class instance from module/name + int len = find_char(p, (int)(end - p), '\n'); + if (len < 0) { + return false; + } + p += len + 1; + len = find_char(p, (int)(end - p), '\n'); + if (len < 0) { + return false; + } + p += len + 1; + } break; + default: + return false; + } + if (p > end) { + return false; + } + } + + return false; +} + +bool pickle_object_is_torch_magic_number(const uint8_t* buffer, size_t buffer_size) { + static const uint8_t torch_magic_bytes[] = {0x6C, 0xFC, 0x9C, 0x46, 0xF9, 0x20, 0x6A, 0xA8, 0x50, 0x19}; + + if (buffer_size < 5 || buffer[0] != 0x80) { + return false; + } + + size_t pos = 2; + if (pos >= buffer_size) { + return false; + } + + uint8_t opcode = buffer[pos++]; + if (opcode != 0x8A || pos >= buffer_size) { + return false; + } + + uint8_t len = buffer[pos++]; + if (len != sizeof(torch_magic_bytes) || pos + len >= buffer_size) { + return false; + } + + if (memcmp(buffer + pos, torch_magic_bytes, 
sizeof(torch_magic_bytes)) != 0) { + return false; + } + pos += len; + + return pos < buffer_size && buffer[pos] == '.'; +} + +bool parse_pickle_uint32_object(const uint8_t* buffer, size_t buffer_size, uint32_t* value) { + if (buffer_size < 4 || buffer[0] != 0x80) { + return false; + } + + size_t pos = 2; + if (pos >= buffer_size) { + return false; + } + + uint8_t opcode = buffer[pos++]; + switch (opcode) { + case 'K': // BININT1 = b'K' # push 1-byte unsigned int + if (pos + 1 >= buffer_size) { + return false; + } + *value = buffer[pos]; + pos += 1; + break; + case 'M': // BININT2 = b'M' # push 2-byte unsigned int + if (pos + 2 >= buffer_size) { + return false; + } + *value = read_short(buffer + pos); + pos += 2; + break; + case 'J': // BININT = b'J' # push 4-byte signed int + if (pos + 4 >= buffer_size) { + return false; + } + *value = (uint32_t)read_int(buffer + pos); + pos += 4; + break; + default: + return false; + } + + return pos < buffer_size && buffer[pos] == '.'; +} + +struct PickleStorageInfo { + std::string key; + ggml_type type = GGML_TYPE_COUNT; + bool is_f64 = false; + bool is_i64 = false; + uint64_t raw_element_nbytes = 0; + uint64_t nbytes = 0; +}; + +struct PickleTensorInfo { + TensorStorage tensor_storage; + int stride_n_dims = 0; + int64_t stride[SD_MAX_DIMS]{1, 1, 1, 1, 1}; +}; + +struct PickleValue { + enum Kind { + MARK, + NONE, + BOOL, + INT, + STRING, + GLOBAL, + TUPLE, + LIST, + DICT, + ORDERED_DICT, + STORAGE, + TENSOR, + }; + + Kind kind = NONE; + int64_t int_value = 0; + bool bool_value = false; + std::string str_value; + std::vector items; + std::vector> dict_items; + PickleStorageInfo storage; + PickleTensorInfo tensor; +}; + +static PickleValue make_mark_value() { + PickleValue value; + value.kind = PickleValue::MARK; + return value; +} + +static PickleValue make_none_value() { + PickleValue value; + value.kind = PickleValue::NONE; + return value; +} + +static PickleValue make_bool_value(bool b) { + PickleValue value; + value.kind = 
PickleValue::BOOL; + value.bool_value = b; + return value; +} + +static PickleValue make_int_value(int64_t x) { + PickleValue value; + value.kind = PickleValue::INT; + value.int_value = x; + return value; +} + +static PickleValue make_string_value(const std::string& s) { + PickleValue value; + value.kind = PickleValue::STRING; + value.str_value = s; + return value; +} + +static PickleValue make_global_value(const std::string& s) { + PickleValue value; + value.kind = PickleValue::GLOBAL; + value.str_value = s; + return value; +} + +static PickleValue make_tuple_value(std::vector items) { + PickleValue value; + value.kind = PickleValue::TUPLE; + value.items = std::move(items); + return value; +} + +static PickleValue make_list_value() { + PickleValue value; + value.kind = PickleValue::LIST; + return value; +} + +static PickleValue make_dict_value(bool ordered) { + PickleValue value; + value.kind = ordered ? PickleValue::ORDERED_DICT : PickleValue::DICT; + return value; +} + +static PickleValue make_storage_value(const PickleStorageInfo& storage) { + PickleValue value; + value.kind = PickleValue::STORAGE; + value.storage = storage; + return value; +} + +static PickleValue make_tensor_value(const PickleTensorInfo& tensor) { + PickleValue value; + value.kind = PickleValue::TENSOR; + value.tensor = tensor; + return value; +} + +static std::string pickle_value_to_string(const PickleValue& value) { + if (value.kind == PickleValue::STRING) { + return value.str_value; + } + if (value.kind == PickleValue::INT) { + return std::to_string(value.int_value); + } + return ""; +} + +static bool parse_storage_type(const std::string& global_name, PickleStorageInfo* storage) { + if (global_name == "torch.FloatStorage") { + storage->type = GGML_TYPE_F32; + storage->raw_element_nbytes = 4; + return true; + } + if (global_name == "torch.DoubleStorage") { + storage->type = GGML_TYPE_F32; + storage->is_f64 = true; + storage->raw_element_nbytes = 8; + return true; + } + if (global_name == 
"torch.HalfStorage") { + storage->type = GGML_TYPE_F16; + storage->raw_element_nbytes = 2; + return true; + } + if (global_name == "torch.BFloat16Storage") { + storage->type = GGML_TYPE_BF16; + storage->raw_element_nbytes = 2; + return true; + } + if (global_name == "torch.IntStorage") { + storage->type = GGML_TYPE_I32; + storage->raw_element_nbytes = 4; + return true; + } + if (global_name == "torch.LongStorage") { + storage->type = GGML_TYPE_I32; + storage->is_i64 = true; + storage->raw_element_nbytes = 8; + return true; + } + return false; +} + +static bool tensor_is_contiguous(const PickleTensorInfo& tensor) { + if (tensor.tensor_storage.nelements() == 0) { + return true; + } + if (tensor.stride_n_dims != tensor.tensor_storage.n_dims) { + return false; + } + + int64_t expected_stride = 1; + for (int i = tensor.tensor_storage.n_dims - 1; i >= 0; --i) { + if (tensor.stride[i] != expected_stride) { + return false; + } + expected_stride *= tensor.tensor_storage.ne[i]; + } + return true; +} + +static void collect_tensors_from_pickle_value(const PickleValue& value, + std::vector& tensor_storages) { + if (value.kind != PickleValue::DICT && value.kind != PickleValue::ORDERED_DICT) { + return; + } + + for (const auto& item : value.dict_items) { + if (item.first.kind == PickleValue::STRING && item.second.kind == PickleValue::TENSOR) { + TensorStorage tensor_storage = item.second.tensor.tensor_storage; + tensor_storage.name = item.first.str_value; + tensor_storage.reverse_ne(); + tensor_storages.push_back(tensor_storage); + } else if (item.second.kind == PickleValue::DICT || item.second.kind == PickleValue::ORDERED_DICT) { + collect_tensors_from_pickle_value(item.second, tensor_storages); + } + } +} + +bool parse_torch_state_dict_pickle(const uint8_t* buffer, + size_t buffer_size, + std::vector& tensor_storages, + std::unordered_map& storage_nbytes, + std::string* error) { + if (buffer_size < 2 || buffer[0] != 0x80 || buffer[1] < 2 || buffer[1] > 5) { + set_error(error, 
"unsupported torch pickle protocol"); + return false; + } + + const uint8_t* p = buffer + 2; + const uint8_t* end = buffer + buffer_size; + std::vector stack; + std::unordered_map memo; + + while (p < end) { + uint8_t opcode = *p++; + switch (opcode) { + case '.': { // STOP = b'.' # every pickle ends with STOP + if (stack.empty()) { + set_error(error, "empty torch pickle stack"); + return false; + } + size_t old_tensor_count = tensor_storages.size(); + collect_tensors_from_pickle_value(stack.back(), tensor_storages); + if (tensor_storages.size() == old_tensor_count) { + set_error(error, "torch pickle does not contain a supported state_dict"); + return false; + } + return true; + } + case '}': // EMPTY_DICT = b'}' # push empty dict + stack.push_back(make_dict_value(false)); + break; + case ']': // EMPTY_LIST = b']' # push empty list + stack.push_back(make_list_value()); + break; + case 'l': { // LIST = b'l' # build list from mark + int mark_idx = (int)stack.size() - 1; + while (mark_idx >= 0 && stack[mark_idx].kind != PickleValue::MARK) { + --mark_idx; + } + if (mark_idx < 0) { + set_error(error, "torch pickle list without mark"); + return false; + } + std::vector items(stack.begin() + mark_idx + 1, stack.end()); + stack.erase(stack.begin() + mark_idx, stack.end()); + PickleValue list_value = make_list_value(); + list_value.items = std::move(items); + stack.push_back(std::move(list_value)); + } break; + case '(': // MARK = b'(' # push markobject + stack.push_back(make_mark_value()); + break; + case ')': // EMPTY_TUPLE = b')' # push empty tuple + stack.push_back(make_tuple_value({})); + break; + case 'N': // NONE = b'N' # push None + stack.push_back(make_none_value()); + break; + case 0x88: // NEWTRUE = b'\x88' # push True + stack.push_back(make_bool_value(true)); + break; + case 0x89: // NEWFALSE = b'\x89' # push False + stack.push_back(make_bool_value(false)); + break; + case 'K': // BININT1 = b'K' # push 1-byte unsigned int + if (p >= end) { + return false; + } + 
stack.push_back(make_int_value(*p++)); + break; + case 'M': // BININT2 = b'M' # push 2-byte unsigned int + if (p + 2 > end) { + return false; + } + stack.push_back(make_int_value(read_short(p))); + p += 2; + break; + case 'J': // BININT = b'J' # push 4-byte signed int + if (p + 4 > end) { + return false; + } + stack.push_back(make_int_value(read_int(p))); + p += 4; + break; + case 'I': { // INT = b'I' # push decimal integer line + int len = find_char(p, (int)(end - p), '\n'); + if (len < 0) { + return false; + } + std::string s((const char*)p, len); + p += len + 1; + if (s == "01") { + stack.push_back(make_bool_value(true)); + } else if (s == "00") { + stack.push_back(make_bool_value(false)); + } else { + stack.push_back(make_int_value(std::strtoll(s.c_str(), nullptr, 10))); + } + } break; + case 'L': { // LONG = b'L' # push decimal long integer line + int len = find_char(p, (int)(end - p), '\n'); + if (len < 0) { + return false; + } + std::string s((const char*)p, len); + p += len + 1; + if (!s.empty() && s.back() == 'L') { + s.pop_back(); + } + stack.push_back(make_int_value(std::strtoll(s.c_str(), nullptr, 10))); + } break; + case 'F': { // FLOAT = b'F' # push decimal float line + int len = find_char(p, (int)(end - p), '\n'); + if (len < 0) { + return false; + } + p += len + 1; + stack.push_back(make_none_value()); + } break; + case 'G': // BINFLOAT = b'G' # push 8-byte binary float + if (p + 8 > end) { + return false; + } + p += 8; + stack.push_back(make_none_value()); + break; + case 0x8A: { // LONG1 = b'\x8a' # push long integer; 1-byte length + if (p >= end) { + return false; + } + uint8_t n = *p++; + if (p + n > end || n > 8) { + return false; + } + int64_t value = 0; + for (uint8_t i = 0; i < n; ++i) { + value |= (int64_t)p[i] << (i * 8); + } + p += n; + stack.push_back(make_int_value(value)); + } break; + case 'C': { // SHORT_BINBYTES = b'C' # push bytes; length < 256 + if (p >= end) { + return false; + } + uint8_t len = *p++; + if (p + len > end) { + 
return false; + } + stack.push_back(make_string_value(std::string((const char*)p, len))); + p += len; + } break; + case 'B': { // BINBYTES = b'B' # push bytes; 4-byte length + if (p + 4 > end) { + return false; + } + int32_t len = read_int(p); + p += 4; + if (len < 0 || p + len > end) { + return false; + } + stack.push_back(make_string_value(std::string((const char*)p, len))); + p += len; + } break; + case 'T': // BINSTRING = b'T' # push string; 4-byte length + case 'X': { // BINUNICODE = b'X' # push UTF-8 string; 4-byte length + if (p + 4 > end) { + return false; + } + int32_t len = read_int(p); + p += 4; + if (len < 0 || p + len > end) { + return false; + } + stack.push_back(make_string_value(std::string((const char*)p, len))); + p += len; + } break; + case 0x8D: // BINUNICODE8 = b'\x8d' # push UTF-8 string; 8-byte length + case 0x8E: // BINBYTES8 = b'\x8e' # push bytes; 8-byte length + case 0x96: { // BYTEARRAY8 = b'\x96' # push bytearray; 8-byte length + if (p + 8 > end) { + return false; + } + uint64_t len = read_u64(p); + p += 8; + if (len > (uint64_t)(end - p)) { + return false; + } + stack.push_back(make_string_value(std::string((const char*)p, (size_t)len))); + p += len; + } break; + case 'U': // SHORT_BINSTRING = b'U' # push string; length < 256 + case 0x8C: { // SHORT_BINUNICODE = b'\x8c' # push UTF-8 string; length < 256 + if (p >= end) { + return false; + } + uint8_t len = *p++; + if (p + len > end) { + return false; + } + stack.push_back(make_string_value(std::string((const char*)p, len))); + p += len; + } break; + case 'S': { // STRING = b'S' # push quoted string line + int len = find_char(p, (int)(end - p), '\n'); + if (len < 0) { + return false; + } + std::string s((const char*)p, len); + p += len + 1; + if (s.size() >= 2 && (s[0] == '\'' || s[0] == '"') && s.back() == s[0]) { + s = s.substr(1, s.size() - 2); + } + stack.push_back(make_string_value(s)); + } break; + case 'V': { // UNICODE = b'V' # push raw-unicode string line + int len = 
find_char(p, (int)(end - p), '\n'); + if (len < 0) { + return false; + } + stack.push_back(make_string_value(std::string((const char*)p, len))); + p += len + 1; + } break; + case 'c': { // GLOBAL = b'c' # push module/name global reference + int len = find_char(p, (int)(end - p), '\n'); + if (len < 0) { + return false; + } + std::string module((const char*)p, len); + p += len + 1; + len = find_char(p, (int)(end - p), '\n'); + if (len < 0) { + return false; + } + std::string name((const char*)p, len); + p += len + 1; + stack.push_back(make_global_value(module + "." + name)); + } break; + case 0x93: { // STACK_GLOBAL = b'\x93' # build global from module/name strings + if (stack.size() < 2 || stack[stack.size() - 2].kind != PickleValue::STRING || + stack.back().kind != PickleValue::STRING) { + return false; + } + std::string name = stack.back().str_value; + stack.pop_back(); + std::string module = stack.back().str_value; + stack.pop_back(); + stack.push_back(make_global_value(module + "." + name)); + } break; + case 'h': // BINGET = b'h' # read memo index, 1-byte arg + if (p >= end || !memo.count(*p)) { + return false; + } + stack.push_back(memo[*p++]); + break; + case 'j': { // LONG_BINGET = b'j' # read memo index, 4-byte arg + if (p + 4 > end) { + return false; + } + int32_t memo_idx = read_int(p); + if (!memo.count(memo_idx)) { + return false; + } + stack.push_back(memo[memo_idx]); + p += 4; + } break; + case 'q': // BINPUT = b'q' # write memo index, 1-byte arg + if (p >= end || stack.empty()) { + return false; + } + memo[*p++] = stack.back(); + break; + case 'r': // LONG_BINPUT = b'r' # write memo index, 4-byte arg + if (p + 4 > end || stack.empty()) { + return false; + } + memo[read_int(p)] = stack.back(); + p += 4; + break; + case 0x94: // MEMOIZE = b'\x94' # store top of stack in memo + if (stack.empty()) { + return false; + } + memo[(int32_t)memo.size()] = stack.back(); + break; + case 0x95: // FRAME = b'\x95' # indicate the beginning of a new frame + if (p + 8 
> end) { + return false; + } + p += 8; + break; + case '0': // POP = b'0' # discard top stack item + if (stack.empty()) { + return false; + } + stack.pop_back(); + break; + case '1': { // POP_MARK = b'1' # discard stack through topmost mark + int mark_idx = (int)stack.size() - 1; + while (mark_idx >= 0 && stack[mark_idx].kind != PickleValue::MARK) { + --mark_idx; + } + if (mark_idx < 0) { + return false; + } + stack.erase(stack.begin() + mark_idx, stack.end()); + } break; + case '2': // DUP = b'2' # duplicate top stack item + if (stack.empty()) { + return false; + } + stack.push_back(stack.back()); + break; + case 0x8F: // EMPTY_SET = b'\x8f' # push empty set + stack.push_back(make_list_value()); + break; + case 0x90: { // ADDITEMS = b'\x90' # add mark-delimited items to set + int mark_idx = (int)stack.size() - 1; + while (mark_idx >= 0 && stack[mark_idx].kind != PickleValue::MARK) { + --mark_idx; + } + if (mark_idx <= 0 || stack[mark_idx - 1].kind != PickleValue::LIST) { + return false; + } + PickleValue& set_value = stack[mark_idx - 1]; + set_value.items.insert(set_value.items.end(), stack.begin() + mark_idx + 1, stack.end()); + stack.erase(stack.begin() + mark_idx, stack.end()); + } break; + case 0x91: { // FROZENSET = b'\x91' # build frozenset from mark + int mark_idx = (int)stack.size() - 1; + while (mark_idx >= 0 && stack[mark_idx].kind != PickleValue::MARK) { + --mark_idx; + } + if (mark_idx < 0) { + return false; + } + PickleValue set_value = make_list_value(); + set_value.items.insert(set_value.items.end(), stack.begin() + mark_idx + 1, stack.end()); + stack.erase(stack.begin() + mark_idx, stack.end()); + stack.push_back(std::move(set_value)); + } break; + case 0x85: // TUPLE1 = b'\x85' # build 1-tuple from stack + case 0x86: // TUPLE2 = b'\x86' # build 2-tuple from stack + case 0x87: { // TUPLE3 = b'\x87' # build 3-tuple from stack + int tuple_size = opcode == 0x85 ? 1 : (opcode == 0x86 ? 
2 : 3); + if ((int)stack.size() < tuple_size) { + return false; + } + std::vector items(stack.end() - tuple_size, stack.end()); + stack.erase(stack.end() - tuple_size, stack.end()); + stack.push_back(make_tuple_value(std::move(items))); + } break; + case 't': { // TUPLE = b't' # build tuple from mark + int mark_idx = (int)stack.size() - 1; + while (mark_idx >= 0 && stack[mark_idx].kind != PickleValue::MARK) { + --mark_idx; + } + if (mark_idx < 0) { + return false; + } + std::vector items(stack.begin() + mark_idx + 1, stack.end()); + stack.erase(stack.begin() + mark_idx, stack.end()); + stack.push_back(make_tuple_value(std::move(items))); + } break; + case 'Q': { // BINPERSID = b'Q' # persistent id from stack + if (stack.empty()) { + return false; + } + PickleValue pid = stack.back(); + stack.pop_back(); + if (pid.kind != PickleValue::TUPLE || pid.items.size() < 5 || pid.items[0].kind != PickleValue::STRING || + pid.items[1].kind != PickleValue::GLOBAL || pid.items[4].kind != PickleValue::INT || + pid.items[0].str_value != "storage") { + return false; + } + + PickleStorageInfo storage; + storage.key = pickle_value_to_string(pid.items[2]); + if (storage.key.empty() || !parse_storage_type(pid.items[1].str_value, &storage)) { + return false; + } + storage.nbytes = (uint64_t)pid.items[4].int_value * storage.raw_element_nbytes; + storage_nbytes[storage.key] = storage.nbytes; + stack.push_back(make_storage_value(storage)); + } break; + case 'R': { // REDUCE = b'R' # apply callable to args + if (stack.size() < 2) { + return false; + } + PickleValue args = stack.back(); + stack.pop_back(); + PickleValue callable = stack.back(); + stack.pop_back(); + if (callable.kind != PickleValue::GLOBAL || args.kind != PickleValue::TUPLE) { + stack.push_back(make_none_value()); + break; + } + + if (callable.str_value == "collections.OrderedDict" && args.items.empty()) { + stack.push_back(make_dict_value(true)); + break; + } + + if ((callable.str_value == "torch._utils._rebuild_tensor_v2" 
|| callable.str_value == "torch._utils._rebuild_tensor") && + args.items.size() >= 4 && args.items[0].kind == PickleValue::STORAGE && + args.items[1].kind == PickleValue::INT && args.items[2].kind == PickleValue::TUPLE && + args.items[3].kind == PickleValue::TUPLE) { + PickleTensorInfo tensor; + tensor.tensor_storage.type = args.items[0].storage.type; + tensor.tensor_storage.is_f64 = args.items[0].storage.is_f64; + tensor.tensor_storage.is_i64 = args.items[0].storage.is_i64; + tensor.tensor_storage.storage_key = args.items[0].storage.key; + tensor.tensor_storage.offset = (uint64_t)args.items[1].int_value * args.items[0].storage.raw_element_nbytes; + + for (const auto& item : args.items[2].items) { + if (item.kind != PickleValue::INT || tensor.tensor_storage.n_dims >= SD_MAX_DIMS) { + return false; + } + tensor.tensor_storage.ne[tensor.tensor_storage.n_dims++] = item.int_value; + } + + for (const auto& item : args.items[3].items) { + if (item.kind != PickleValue::INT || tensor.stride_n_dims >= SD_MAX_DIMS) { + return false; + } + tensor.stride[tensor.stride_n_dims++] = item.int_value; + } + + if (!tensor_is_contiguous(tensor)) { + return false; + } + stack.push_back(make_tensor_value(tensor)); + break; + } + + // Non-tensor checkpoint metadata can use REDUCE for arbitrary + // Python objects. Do not execute it; keep stack shape only. 
+ stack.push_back(make_none_value()); + break; + } + case 'b': // BUILD = b'b' # build object state + if (stack.size() < 2) { + return false; + } + stack.pop_back(); + break; + case 'u': { // SETITEMS = b'u' # add mark-delimited items to dict + int mark_idx = (int)stack.size() - 1; + while (mark_idx >= 0 && stack[mark_idx].kind != PickleValue::MARK) { + --mark_idx; + } + if (mark_idx <= 0) { + return false; + } + PickleValue& dict = stack[mark_idx - 1]; + if (dict.kind != PickleValue::DICT && dict.kind != PickleValue::ORDERED_DICT) { + return false; + } + for (int i = mark_idx + 1; i + 1 < (int)stack.size(); i += 2) { + dict.dict_items.emplace_back(stack[i], stack[i + 1]); + } + stack.erase(stack.begin() + mark_idx, stack.end()); + } break; + case 's': { // SETITEM = b's' # add key/value to dict + if (stack.size() < 3) { + return false; + } + PickleValue value = stack.back(); + stack.pop_back(); + PickleValue key = stack.back(); + stack.pop_back(); + PickleValue& dict = stack.back(); + if (dict.kind != PickleValue::DICT && dict.kind != PickleValue::ORDERED_DICT) { + return false; + } + dict.dict_items.emplace_back(key, value); + } break; + case 'e': { // APPENDS = b'e' # extend list with mark-delimited items + int mark_idx = (int)stack.size() - 1; + while (mark_idx >= 0 && stack[mark_idx].kind != PickleValue::MARK) { + --mark_idx; + } + if (mark_idx <= 0 || stack[mark_idx - 1].kind != PickleValue::LIST) { + return false; + } + PickleValue& list_value = stack[mark_idx - 1]; + list_value.items.insert(list_value.items.end(), stack.begin() + mark_idx + 1, stack.end()); + stack.erase(stack.begin() + mark_idx, stack.end()); + } break; + case 'a': { // APPEND = b'a' # append item to list + if (stack.size() < 2) { + return false; + } + PickleValue item = stack.back(); + stack.pop_back(); + if (stack.back().kind != PickleValue::LIST) { + return false; + } + stack.back().items.push_back(item); + } break; + default: + set_error(error, + "unsupported torch pickle opcode 0x" + 
sd_format("%02X", opcode) + + " at offset " + std::to_string((p - buffer) - 1)); + return false; + } + } + + set_error(error, "unterminated torch state_dict pickle"); + return false; +} diff --git a/src/model_io/pickle_io.h b/src/model_io/pickle_io.h new file mode 100644 index 000000000..6a3db37b9 --- /dev/null +++ b/src/model_io/pickle_io.h @@ -0,0 +1,21 @@ +#ifndef __SD_MODEL_IO_PICKLE_IO_H__ +#define __SD_MODEL_IO_PICKLE_IO_H__ + +#include +#include +#include +#include +#include + +#include "tensor_storage.h" + +bool skip_pickle_object(const uint8_t* buffer, size_t buffer_size, size_t* object_size); +bool pickle_object_is_torch_magic_number(const uint8_t* buffer, size_t buffer_size); +bool parse_pickle_uint32_object(const uint8_t* buffer, size_t buffer_size, uint32_t* value); +bool parse_torch_state_dict_pickle(const uint8_t* buffer, + size_t buffer_size, + std::vector& tensor_storages, + std::unordered_map& storage_nbytes, + std::string* error = nullptr); + +#endif // __SD_MODEL_IO_PICKLE_IO_H__ diff --git a/src/model_io/safetensors_io.cpp b/src/model_io/safetensors_io.cpp new file mode 100644 index 000000000..889352218 --- /dev/null +++ b/src/model_io/safetensors_io.cpp @@ -0,0 +1,316 @@ +#include "safetensors_io.h" + +#include +#include +#include +#include +#include + +#include "binary_io.h" +#include "json.hpp" +#include "util.h" + +static constexpr size_t ST_HEADER_SIZE_LEN = 8; + +static void set_error(std::string* error, const std::string& message) { + if (error != nullptr) { + *error = message; + } +} + +bool is_safetensors_file(const std::string& file_path) { + std::ifstream file(file_path, std::ios::binary); + if (!file.is_open()) { + return false; + } + + // get file size + file.seekg(0, file.end); + size_t file_size_ = file.tellg(); + file.seekg(0, file.beg); + + // read header size + if (file_size_ <= ST_HEADER_SIZE_LEN) { + return false; + } + + uint8_t header_size_buf[ST_HEADER_SIZE_LEN]; + file.read((char*)header_size_buf, ST_HEADER_SIZE_LEN); + 
if (!file) { + return false; + } + + size_t header_size_ = model_io::read_u64(header_size_buf); + if (header_size_ >= file_size_ || header_size_ <= 2) { + return false; + } + + // read header + std::vector header_buf; + header_buf.resize(header_size_ + 1); + header_buf[header_size_] = '\0'; + file.read(header_buf.data(), header_size_); + if (!file) { + return false; + } + try { + nlohmann::json header_ = nlohmann::json::parse(header_buf.data()); + } catch (const std::exception&) { + return false; + } + return true; +} + +static ggml_type safetensors_dtype_to_ggml_type(const std::string& dtype) { + ggml_type ttype = GGML_TYPE_COUNT; + if (dtype == "F16") { + ttype = GGML_TYPE_F16; + } else if (dtype == "BF16") { + ttype = GGML_TYPE_BF16; + } else if (dtype == "F32") { + ttype = GGML_TYPE_F32; + } else if (dtype == "F64") { + ttype = GGML_TYPE_F32; + } else if (dtype == "F8_E4M3") { + ttype = GGML_TYPE_F16; + } else if (dtype == "F8_E5M2") { + ttype = GGML_TYPE_F16; + } else if (dtype == "I32") { + ttype = GGML_TYPE_I32; + } else if (dtype == "I64") { + ttype = GGML_TYPE_I32; + } + return ttype; +} + +// https://huggingface.co/docs/safetensors/index +bool read_safetensors_file(const std::string& file_path, + std::vector& tensor_storages, + std::string* error) { + std::ifstream file(file_path, std::ios::binary); + if (!file.is_open()) { + set_error(error, "failed to open '" + file_path + "'"); + return false; + } + + // get file size + file.seekg(0, file.end); + size_t file_size_ = file.tellg(); + file.seekg(0, file.beg); + + // read header size + if (file_size_ <= ST_HEADER_SIZE_LEN) { + set_error(error, "invalid safetensor file '" + file_path + "'"); + return false; + } + + uint8_t header_size_buf[ST_HEADER_SIZE_LEN]; + file.read((char*)header_size_buf, ST_HEADER_SIZE_LEN); + if (!file) { + set_error(error, "read safetensors header size failed: '" + file_path + "'"); + return false; + } + + size_t header_size_ = model_io::read_u64(header_size_buf); + if 
(header_size_ >= file_size_) { + set_error(error, "invalid safetensor file '" + file_path + "'"); + return false; + } + + // read header + std::vector header_buf; + header_buf.resize(header_size_ + 1); + header_buf[header_size_] = '\0'; + file.read(header_buf.data(), header_size_); + if (!file) { + set_error(error, "read safetensors header failed: '" + file_path + "'"); + return false; + } + + nlohmann::json header_; + try { + header_ = nlohmann::json::parse(header_buf.data()); + } catch (const std::exception&) { + set_error(error, "parsing safetensors header failed: '" + file_path + "'"); + return false; + } + + tensor_storages.clear(); + for (auto& item : header_.items()) { + std::string name = item.key(); + nlohmann::json tensor_info = item.value(); + // LOG_DEBUG("%s %s\n", name.c_str(), tensor_info.dump().c_str()); + + if (name == "__metadata__") { + continue; + } + + std::string dtype = tensor_info["dtype"]; + nlohmann::json shape = tensor_info["shape"]; + + if (dtype == "U8") { + continue; + } + + size_t begin = tensor_info["data_offsets"][0].get(); + size_t end = tensor_info["data_offsets"][1].get(); + + ggml_type type = safetensors_dtype_to_ggml_type(dtype); + if (type == GGML_TYPE_COUNT) { + set_error(error, "unsupported dtype '" + dtype + "' (tensor '" + name + "')"); + return false; + } + + if (shape.size() > SD_MAX_DIMS) { + set_error(error, "invalid tensor '" + name + "'"); + return false; + } + + int n_dims = (int)shape.size(); + int64_t ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1}; + for (int i = 0; i < n_dims; i++) { + ne[i] = shape[i].get(); + } + + if (n_dims == 5) { + n_dims = 4; + ne[0] = ne[0] * ne[1]; + ne[1] = ne[2]; + ne[2] = ne[3]; + ne[3] = ne[4]; + } + + // ggml_n_dims returns 1 for scalars + if (n_dims == 0) { + n_dims = 1; + } + + TensorStorage tensor_storage(name, type, ne, n_dims, 0, ST_HEADER_SIZE_LEN + header_size_ + begin); + tensor_storage.reverse_ne(); + + size_t tensor_data_size = end - begin; + + bool tensor_size_ok; + if (dtype == 
"F8_E4M3") { + tensor_storage.is_f8_e4m3 = true; + // f8 -> f16 + tensor_size_ok = (tensor_storage.nbytes() == tensor_data_size * 2); + } else if (dtype == "F8_E5M2") { + tensor_storage.is_f8_e5m2 = true; + // f8 -> f16 + tensor_size_ok = (tensor_storage.nbytes() == tensor_data_size * 2); + } else if (dtype == "F64") { + tensor_storage.is_f64 = true; + // f64 -> f32 + tensor_size_ok = (tensor_storage.nbytes() * 2 == tensor_data_size); + } else if (dtype == "I64") { + tensor_storage.is_i64 = true; + // i64 -> i32 + tensor_size_ok = (tensor_storage.nbytes() * 2 == tensor_data_size); + } else { + tensor_size_ok = (tensor_storage.nbytes() == tensor_data_size); + } + if (!tensor_size_ok) { + set_error(error, "size mismatch for tensor '" + name + "' (" + dtype + ")"); + return false; + } + + tensor_storages.push_back(tensor_storage); + + // LOG_DEBUG("%s %s", tensor_storage.to_string().c_str(), dtype.c_str()); + } + + return true; +} + +static bool ggml_type_to_safetensors_dtype(ggml_type type, std::string* dtype) { + switch (type) { + case GGML_TYPE_F16: + *dtype = "F16"; + return true; + case GGML_TYPE_BF16: + *dtype = "BF16"; + return true; + case GGML_TYPE_F32: + *dtype = "F32"; + return true; + case GGML_TYPE_I32: + *dtype = "I32"; + return true; + default: + return false; + } +} + +bool write_safetensors_file(const std::string& file_path, + const std::vector& tensors, + std::string* error) { + nlohmann::ordered_json header = nlohmann::ordered_json::object(); + + uint64_t data_offset = 0; + for (const TensorWriteInfo& write_tensor : tensors) { + ggml_tensor* tensor = write_tensor.tensor; + if (tensor == nullptr) { + set_error(error, "null tensor cannot be written to safetensors"); + return false; + } + + const std::string name = ggml_get_name(tensor); + std::string dtype; + if (!ggml_type_to_safetensors_dtype(tensor->type, &dtype)) { + set_error(error, + "unsupported safetensors dtype '" + std::string(ggml_type_name(tensor->type)) + + "' for tensor '" + name + "'"); 
+ return false; + } + + const uint64_t tensor_nbytes = ggml_nbytes(tensor); + + nlohmann::ordered_json json_tensor_info = nlohmann::ordered_json::object(); + json_tensor_info["dtype"] = dtype; + + nlohmann::ordered_json shape = nlohmann::ordered_json::array(); + for (int i = 0; i < write_tensor.n_dims; ++i) { + shape.push_back(write_tensor.ne[write_tensor.n_dims - 1 - i]); + } + json_tensor_info["shape"] = shape; + + nlohmann::ordered_json data_offsets = nlohmann::ordered_json::array(); + data_offsets.push_back(data_offset); + data_offsets.push_back(data_offset + tensor_nbytes); + json_tensor_info["data_offsets"] = data_offsets; + + header[name] = json_tensor_info; + data_offset += tensor_nbytes; + } + + const std::string header_str = header.dump(); + + std::ofstream file(file_path, std::ios::binary); + if (!file.is_open()) { + set_error(error, "failed to open '" + file_path + "' for writing"); + return false; + } + + LOG_INFO("trying to save tensors to %s", file_path.c_str()); + model_io::write_u64(file, header_str.size()); + file.write(header_str.data(), header_str.size()); + if (!file) { + set_error(error, "failed to write safetensors header to '" + file_path + "'"); + return false; + } + + for (const TensorWriteInfo& write_tensor : tensors) { + ggml_tensor* tensor = write_tensor.tensor; + const std::string name = ggml_get_name(tensor); + const size_t tensor_nbytes = ggml_nbytes(tensor); + file.write((const char*)tensor->data, tensor_nbytes); + if (!file) { + set_error(error, + "failed to write tensor '" + name + "' to '" + file_path + "'"); + return false; + } + } + + return true; +} diff --git a/src/model_io/safetensors_io.h b/src/model_io/safetensors_io.h new file mode 100644 index 000000000..08a1bc1f3 --- /dev/null +++ b/src/model_io/safetensors_io.h @@ -0,0 +1,17 @@ +#ifndef __SD_MODEL_IO_SAFETENSORS_IO_H__ +#define __SD_MODEL_IO_SAFETENSORS_IO_H__ + +#include +#include + +#include "tensor_storage.h" + +bool is_safetensors_file(const std::string& 
file_path); +bool read_safetensors_file(const std::string& file_path, + std::vector& tensor_storages, + std::string* error = nullptr); +bool write_safetensors_file(const std::string& file_path, + const std::vector& tensors, + std::string* error = nullptr); + +#endif // __SD_MODEL_IO_SAFETENSORS_IO_H__ diff --git a/src/model_io/tensor_storage.h b/src/model_io/tensor_storage.h new file mode 100644 index 000000000..c0cf079c5 --- /dev/null +++ b/src/model_io/tensor_storage.h @@ -0,0 +1,132 @@ +#ifndef __SD_TENSOR_STORAGE_H__ +#define __SD_TENSOR_STORAGE_H__ + +#include +#include +#include +#include +#include +#include +#include + +#include "ggml.h" + +#define SD_MAX_DIMS 5 + +struct TensorStorage { + std::string name; + ggml_type type = GGML_TYPE_F32; + ggml_type expected_type = GGML_TYPE_COUNT; + bool is_f8_e4m3 = false; + bool is_f8_e5m2 = false; + bool is_f64 = false; + bool is_i64 = false; + int64_t ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1}; + int n_dims = 0; + + std::string storage_key; + size_t file_index = 0; + int index_in_zip = -1; // >= means stored in a zip file + uint64_t offset = 0; // offset in file + + TensorStorage() = default; + + TensorStorage(std::string name, ggml_type type, const int64_t* ne, int n_dims, size_t file_index, size_t offset = 0) + : name(std::move(name)), type(type), n_dims(n_dims), file_index(file_index), offset(offset) { + for (int i = 0; i < n_dims; i++) { + this->ne[i] = ne[i]; + } + } + + int64_t nelements() const { + int64_t n = 1; + for (int i = 0; i < SD_MAX_DIMS; i++) { + n *= ne[i]; + } + return n; + } + + int64_t nbytes() const { + return nelements() * ggml_type_size(type) / ggml_blck_size(type); + } + + int64_t nbytes_to_read() const { + if (is_f8_e4m3 || is_f8_e5m2) { + return nbytes() / 2; + } else if (is_f64 || is_i64) { + return nbytes() * 2; + } else { + return nbytes(); + } + } + + void unsqueeze() { + if (n_dims == 2) { + n_dims = 4; + ne[3] = ne[1]; + ne[2] = ne[0]; + ne[1] = 1; + ne[0] = 1; + } + } + + std::vector 
chunk(size_t n) { + std::vector chunks; + uint64_t chunk_size = nbytes_to_read() / n; + // printf("%d/%d\n", chunk_size, nbytes_to_read()); + reverse_ne(); + for (size_t i = 0; i < n; i++) { + TensorStorage chunk_i = *this; + chunk_i.ne[0] = ne[0] / n; + chunk_i.offset = offset + i * chunk_size; + chunk_i.reverse_ne(); + chunks.push_back(chunk_i); + } + reverse_ne(); + return chunks; + } + + void reverse_ne() { + int64_t new_ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1}; + for (int i = 0; i < n_dims; i++) { + new_ne[i] = ne[n_dims - 1 - i]; + } + for (int i = 0; i < n_dims; i++) { + ne[i] = new_ne[i]; + } + } + + std::string to_string() const { + std::stringstream ss; + const char* type_name = ggml_type_name(type); + if (is_f8_e4m3) { + type_name = "f8_e4m3"; + } else if (is_f8_e5m2) { + type_name = "f8_e5m2"; + } else if (is_f64) { + type_name = "f64"; + } else if (is_i64) { + type_name = "i64"; + } + ss << name << " | " << type_name << " | "; + ss << n_dims << " ["; + for (int i = 0; i < SD_MAX_DIMS; i++) { + ss << ne[i]; + if (i != SD_MAX_DIMS - 1) { + ss << ", "; + } + } + ss << "]"; + return ss.str(); + } +}; + +struct TensorWriteInfo { + int64_t ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1}; + int n_dims = 0; + ggml_tensor* tensor = nullptr; +}; + +typedef std::function on_new_tensor_cb_t; + +#endif // __SD_TENSOR_STORAGE_H__ diff --git a/src/model_io/torch_legacy_io.cpp b/src/model_io/torch_legacy_io.cpp new file mode 100644 index 000000000..816547252 --- /dev/null +++ b/src/model_io/torch_legacy_io.cpp @@ -0,0 +1,252 @@ +#include "torch_legacy_io.h" + +#include +#include +#include +#include +#include +#include + +#include "pickle_io.h" +#include "util.h" + +// torch.save format background: +// +// - Before PyTorch 1.6.0, torch.save used this legacy non-zip format by +// default. +// - Since PyTorch 1.6.0, torch.save defaults to an uncompressed ZIP64 archive +// containing data.pkl, data/, version, and, since PyTorch 2.1.0, byteorder. 
+// - The old format can still be produced explicitly with: +// torch.save(obj, path, _use_new_zipfile_serialization=False) +// +// Whether obj is a state_dict or a whole nn.Module does not change the outer +// container format selected by torch.save. It changes the pickled object inside: +// +// - state_dict: usually an OrderedDict[str, Tensor]. pickle_io.cpp supports a +// restricted subset of this layout because tensor metadata and raw storages +// can be recovered without executing pickle callables. +// - whole module/checkpoint object: arbitrary Python object graph. This may +// require importing user classes and executing pickle GLOBAL/REDUCE rebuild +// logic, so it is intentionally not supported here. +// +// Legacy non-zip PyTorch files are not a single pickle object: +// +// 1. pickle object: PyTorch legacy magic number +// 2. pickle object: legacy protocol version, expected to be 1001 +// 3. pickle object: sys_info metadata, ignored by this reader +// 4. pickle object: state_dict metadata, parsed by pickle_io.cpp +// 5. pickle object: serialized storage key list, skipped here +// 6. 
raw storage data payloads +// - PyTorch writes storages after the pickles, ordered by storage key +// - each storage has an 8-byte legacy storage header followed by raw bytes +static constexpr size_t LEGACY_STORAGE_HEADER_SIZE = 8; + +static void set_error(std::string* error, const std::string& message) { + if (error != nullptr) { + *error = message; + } +} + +static std::string bytes_to_hex(const std::vector& bytes) { + static const char* hex = "0123456789ABCDEF"; + std::string result; + result.reserve(bytes.size() * 3); + for (size_t i = 0; i < bytes.size(); ++i) { + if (i > 0) { + result.push_back('-'); + } + result.push_back(hex[(bytes[i] >> 4) & 0x0F]); + result.push_back(hex[bytes[i] & 0x0F]); + } + return result; +} + +static bool is_probably_tar_file(const std::vector& header) { + return header.size() >= 262 && + header[257] == 'u' && + header[258] == 's' && + header[259] == 't' && + header[260] == 'a' && + header[261] == 'r'; +} + +static std::string torch_legacy_diagnostics(const std::string& file_path, const std::vector& buffer) { + if (!ends_with(file_path, ".pt") && !ends_with(file_path, ".pth")) { + return ""; + } + if (buffer.empty()) { + return "unsupported PyTorch file '" + file_path + "': empty file"; + } + + size_t short_len = std::min(buffer.size(), 32); + std::vector short_header(buffer.begin(), buffer.begin() + short_len); + const bool raw_pickle = buffer[0] == 0x80; + const bool tar_file = is_probably_tar_file(buffer); + + std::string message = "unsupported PyTorch file '" + file_path + "': first bytes " + + bytes_to_hex(short_header) + + ", raw_pickle=" + (raw_pickle ? "true" : "false") + + ", tar=" + (tar_file ? 
"true" : "false"); + if (raw_pickle) { + message += "; raw pickle did not match the restricted state_dict layouts currently supported"; + } else if (tar_file) { + message += "; legacy tar PyTorch checkpoints are not supported yet"; + } + return message; +} + +bool read_torch_legacy_file(const std::string& file_path, + std::vector& tensor_storages, + std::string* error) { + std::ifstream file(file_path, std::ios::binary); + if (!file.is_open()) { + set_error(error, "failed to open '" + file_path + "'"); + return false; + } + + file.seekg(0, file.end); + size_t file_size = (size_t)file.tellg(); + file.seekg(0, file.beg); + if (file_size == 0) { + set_error(error, "empty file '" + file_path + "'"); + return false; + } + + std::vector buffer(file_size); + file.read((char*)buffer.data(), file_size); + if (!file) { + set_error(error, "failed to read '" + file_path + "'"); + return false; + } + + auto finalize_tensor_offsets = [&](size_t storage_data_offset, + const std::unordered_map& legacy_storage_map) -> bool { + if (storage_data_offset > file_size) { + return false; + } + + std::vector storage_keys; + storage_keys.reserve(legacy_storage_map.size()); + for (const auto& [storage_key, _] : legacy_storage_map) { + storage_keys.push_back(storage_key); + } + std::sort(storage_keys.begin(), storage_keys.end()); + + std::unordered_map storage_offsets; + uint64_t current_offset = storage_data_offset; + for (const auto& storage_key : storage_keys) { + auto it = legacy_storage_map.find(storage_key); + if (it == legacy_storage_map.end()) { + return false; + } + if (current_offset + LEGACY_STORAGE_HEADER_SIZE + it->second > file_size) { + return false; + } + storage_offsets[storage_key] = current_offset + LEGACY_STORAGE_HEADER_SIZE; + current_offset += LEGACY_STORAGE_HEADER_SIZE + it->second; + } + + for (auto& tensor_storage : tensor_storages) { + if (tensor_storage.storage_key.empty()) { + continue; + } + + auto it_offset = storage_offsets.find(tensor_storage.storage_key); + 
auto it_size = legacy_storage_map.find(tensor_storage.storage_key); + if (it_offset == storage_offsets.end() || it_size == legacy_storage_map.end()) { + return false; + } + + uint64_t base_offset = it_offset->second; + uint64_t storage_nbytes = it_size->second; + uint64_t tensor_nbytes = tensor_storage.nbytes_to_read(); + if (tensor_storage.offset + tensor_nbytes > storage_nbytes) { + return false; + } + + tensor_storage.offset = base_offset + tensor_storage.offset; + tensor_storage.storage_key.clear(); + } + + return true; + }; + + auto parse_state_dict_at = [&](size_t state_dict_offset, size_t state_dict_size, size_t* storage_data_offset) -> bool { + tensor_storages.clear(); + std::unordered_map legacy_storage_map; + if (!parse_torch_state_dict_pickle(buffer.data() + state_dict_offset, + state_dict_size, + tensor_storages, + legacy_storage_map, + error)) { + return false; + } + + size_t offset_after_state_dict = state_dict_offset + state_dict_size; + size_t storage_keys_size = 0; + if (!skip_pickle_object(buffer.data() + offset_after_state_dict, + buffer.size() - offset_after_state_dict, + &storage_keys_size)) { + return false; + } + + *storage_data_offset = offset_after_state_dict + storage_keys_size; + return finalize_tensor_offsets(*storage_data_offset, legacy_storage_map); + }; + + size_t object_size_1 = 0; + size_t offset = 0; + + if (skip_pickle_object(buffer.data(), buffer.size(), &object_size_1) && + pickle_object_is_torch_magic_number(buffer.data(), object_size_1)) { + offset += object_size_1; + + size_t object_size_2 = 0; + if (!skip_pickle_object(buffer.data() + offset, buffer.size() - offset, &object_size_2)) { + set_error(error, torch_legacy_diagnostics(file_path, buffer)); + return false; + } + uint32_t protocol_version = 0; + if (!parse_pickle_uint32_object(buffer.data() + offset, object_size_2, &protocol_version) || protocol_version != 1001) { + set_error(error, torch_legacy_diagnostics(file_path, buffer)); + return false; + } + offset += 
object_size_2; + + size_t object_size_3 = 0; + if (!skip_pickle_object(buffer.data() + offset, buffer.size() - offset, &object_size_3)) { + set_error(error, torch_legacy_diagnostics(file_path, buffer)); + return false; + } + offset += object_size_3; + + size_t state_dict_size = 0; + if (!skip_pickle_object(buffer.data() + offset, buffer.size() - offset, &state_dict_size)) { + set_error(error, torch_legacy_diagnostics(file_path, buffer)); + return false; + } + + size_t storage_data_offset = 0; + if (parse_state_dict_at(offset, state_dict_size, &storage_data_offset)) { + return true; + } + + if (error != nullptr && error->empty()) { + set_error(error, torch_legacy_diagnostics(file_path, buffer)); + } + return false; + } + + size_t state_dict_size = 0; + if (skip_pickle_object(buffer.data(), buffer.size(), &state_dict_size)) { + size_t storage_data_offset = 0; + if (parse_state_dict_at(0, state_dict_size, &storage_data_offset)) { + return true; + } + } + + if (error != nullptr && error->empty()) { + set_error(error, torch_legacy_diagnostics(file_path, buffer)); + } + return false; +} diff --git a/src/model_io/torch_legacy_io.h b/src/model_io/torch_legacy_io.h new file mode 100644 index 000000000..6680e02a1 --- /dev/null +++ b/src/model_io/torch_legacy_io.h @@ -0,0 +1,13 @@ +#ifndef __SD_MODEL_IO_TORCH_LEGACY_IO_H__ +#define __SD_MODEL_IO_TORCH_LEGACY_IO_H__ + +#include +#include + +#include "tensor_storage.h" + +bool read_torch_legacy_file(const std::string& file_path, + std::vector& tensor_storages, + std::string* error = nullptr); + +#endif // __SD_MODEL_IO_TORCH_LEGACY_IO_H__ diff --git a/src/model_io/torch_zip_io.cpp b/src/model_io/torch_zip_io.cpp new file mode 100644 index 000000000..9eaf6c53a --- /dev/null +++ b/src/model_io/torch_zip_io.cpp @@ -0,0 +1,140 @@ +#include "torch_zip_io.h" + +#include +#include +#include +#include +#include + +#include "pickle_io.h" + +#include "zip.h" + +static void set_error(std::string* error, const std::string& message) { + if 
(error != nullptr) { + *error = message; + } +} + +bool is_torch_zip_file(const std::string& file_path) { + zip_t* zip = zip_open(file_path.c_str(), 0, 'r'); + if (zip == nullptr) { + return false; + } + zip_close(zip); + return true; +} + +static bool find_zip_entry(zip_t* zip, const std::string& entry_name, int* index, uint64_t* size) { + size_t n = zip_entries_total(zip); + for (size_t i = 0; i < n; ++i) { + zip_entry_openbyindex(zip, i); + std::string name = zip_entry_name(zip); + if (name == entry_name) { + *index = (int)i; + *size = zip_entry_size(zip); + zip_entry_close(zip); + return true; + } + zip_entry_close(zip); + } + return false; +} + +static bool parse_zip_data_pkl(const uint8_t* buffer, + size_t buffer_size, + zip_t* zip, + const std::string& dir, + std::vector& tensor_storages, + std::string* error) { + std::vector parsed_tensors; + std::unordered_map storage_nbytes; + if (!parse_torch_state_dict_pickle(buffer, buffer_size, parsed_tensors, storage_nbytes, error)) { + if (error != nullptr && error->empty()) { + *error = "failed to parse torch zip pickle metadata"; + } + return false; + } + + for (auto& tensor_storage : parsed_tensors) { + if (tensor_storage.storage_key.empty()) { + set_error(error, "tensor '" + tensor_storage.name + "' has no storage key"); + return false; + } + + const std::string entry_name = dir + "data/" + tensor_storage.storage_key; + int zip_index = -1; + uint64_t entry_size = 0; + if (!find_zip_entry(zip, entry_name, &zip_index, &entry_size)) { + set_error(error, "storage entry '" + entry_name + "' was not found"); + return false; + } + + auto it_storage_size = storage_nbytes.find(tensor_storage.storage_key); + if (it_storage_size != storage_nbytes.end() && entry_size < it_storage_size->second) { + set_error(error, "storage entry '" + entry_name + "' is smaller than pickle metadata"); + return false; + } + + uint64_t tensor_nbytes = tensor_storage.nbytes_to_read(); + if (tensor_storage.offset + tensor_nbytes > entry_size) { 
+ set_error(error, "tensor '" + tensor_storage.name + "' exceeds storage entry '" + entry_name + "'"); + return false; + } + + tensor_storage.index_in_zip = zip_index; + tensor_storage.storage_key.clear(); + tensor_storages.push_back(tensor_storage); + } + + return true; +} + +bool read_torch_zip_file(const std::string& file_path, + std::vector& tensor_storages, + std::string* error) { + zip_t* zip = zip_open(file_path.c_str(), 0, 'r'); + if (zip == nullptr) { + set_error(error, "failed to open '" + file_path + "'"); + return false; + } + + tensor_storages.clear(); + bool success = true; + bool found_data_pkl = false; + int n = (int)zip_entries_total(zip); + for (int i = 0; i < n; ++i) { + zip_entry_openbyindex(zip, i); + std::string name = zip_entry_name(zip); + size_t pos = name.find("data.pkl"); + if (pos != std::string::npos) { + found_data_pkl = true; + std::string dir = name.substr(0, pos); + void* pkl_data = nullptr; + size_t pkl_size = 0; + zip_entry_read(zip, &pkl_data, &pkl_size); + + if (pkl_data == nullptr || pkl_size == 0) { + set_error(error, "failed to read '" + name + "' from '" + file_path + "'"); + success = false; + } else if (!parse_zip_data_pkl((const uint8_t*)pkl_data, pkl_size, zip, dir, tensor_storages, error)) { + success = false; + } + + free(pkl_data); + } + zip_entry_close(zip); + + if (!success) { + break; + } + } + + if (success && !found_data_pkl) { + set_error(error, "data.pkl was not found in '" + file_path + "'"); + success = false; + } + + zip_close(zip); + return success; +} diff --git a/src/model_io/torch_zip_io.h b/src/model_io/torch_zip_io.h new file mode 100644 index 000000000..54fb099a7 --- /dev/null +++ b/src/model_io/torch_zip_io.h @@ -0,0 +1,14 @@ +#ifndef __SD_MODEL_IO_TORCH_ZIP_IO_H__ +#define __SD_MODEL_IO_TORCH_ZIP_IO_H__ + +#include +#include + +#include "tensor_storage.h" + +bool is_torch_zip_file(const std::string& file_path); +bool read_torch_zip_file(const std::string& file_path, + std::vector& tensor_storages, + 
std::string* error = nullptr); + +#endif // __SD_MODEL_IO_TORCH_ZIP_IO_H__ diff --git a/src/name_conversion.cpp b/src/name_conversion.cpp index d5d5e052c..618c7f6e9 100644 --- a/src/name_conversion.cpp +++ b/src/name_conversion.cpp @@ -1120,7 +1120,7 @@ std::string convert_tensor_name(std::string name, SDVersion version) { for (const auto& prefix : first_stage_model_prefix_vec) { if (starts_with(name, prefix)) { name = convert_first_stage_model_name(name.substr(prefix.size()), prefix); - if (version == VERSION_SDXS) { + if (version == VERSION_SDXS_512_DS || version == VERSION_SDXS_09) { name = "tae." + name; } else { name = prefix + name; diff --git a/src/preprocessing.hpp b/src/preprocessing.hpp index 7c83a289d..57ab0cec7 100644 --- a/src/preprocessing.hpp +++ b/src/preprocessing.hpp @@ -24,6 +24,75 @@ static inline void preprocessing_set_4d(sd::Tensor& tensor, float value, tensor.values()[static_cast(preprocessing_offset_4d(tensor, i0, i1, i2, i3))] = value; } +static inline uint8_t preprocessing_float_to_u8(float value) { + if (value <= 0.0f) { + return 0; + } + if (value >= 1.0f) { + return 255; + } + return static_cast(value * 255.0f + 0.5f); +} + +static inline void preprocessing_tensor_frame_to_sd_image(const sd::Tensor& tensor, int frame_index, uint8_t* image_data) { + const auto& shape = tensor.shape(); + GGML_ASSERT(shape.size() == 4 || shape.size() == 5); + GGML_ASSERT(image_data != nullptr); + + const int width = static_cast(shape[0]); + const int height = static_cast(shape[1]); + const int channel = static_cast(shape[shape.size() == 5 ? 
3 : 2]); + const size_t pixels = static_cast(width) * static_cast(height); + const float* src = tensor.data(); + + if (shape.size() == 4) { + GGML_ASSERT(frame_index >= 0 && frame_index < shape[3]); + const size_t frame_stride = pixels * static_cast(channel); + const float* frame_ptr = src + static_cast(frame_index) * frame_stride; + if (channel == 3) { + const float* c0 = frame_ptr; + const float* c1 = frame_ptr + pixels; + const float* c2 = frame_ptr + pixels * 2; + for (size_t i = 0; i < pixels; ++i) { + image_data[i * 3 + 0] = preprocessing_float_to_u8(c0[i]); + image_data[i * 3 + 1] = preprocessing_float_to_u8(c1[i]); + image_data[i * 3 + 2] = preprocessing_float_to_u8(c2[i]); + } + return; + } + + for (size_t i = 0; i < pixels; ++i) { + for (int c = 0; c < channel; ++c) { + image_data[i * static_cast(channel) + static_cast(c)] = + preprocessing_float_to_u8(frame_ptr[i + pixels * static_cast(c)]); + } + } + return; + } + + GGML_ASSERT(frame_index >= 0 && frame_index < shape[2]); + const size_t channel_stride = pixels * static_cast(shape[2]); + const float* frame_ptr = src + static_cast(frame_index) * pixels; + if (channel == 3) { + const float* c0 = frame_ptr; + const float* c1 = frame_ptr + channel_stride; + const float* c2 = frame_ptr + channel_stride * 2; + for (size_t i = 0; i < pixels; ++i) { + image_data[i * 3 + 0] = preprocessing_float_to_u8(c0[i]); + image_data[i * 3 + 1] = preprocessing_float_to_u8(c1[i]); + image_data[i * 3 + 2] = preprocessing_float_to_u8(c2[i]); + } + return; + } + + for (size_t i = 0; i < pixels; ++i) { + for (int c = 0; c < channel; ++c) { + image_data[i * static_cast(channel) + static_cast(c)] = + preprocessing_float_to_u8(frame_ptr[i + channel_stride * static_cast(c)]); + } + } +} + static inline sd::Tensor sd_image_to_preprocessing_tensor(sd_image_t image) { sd::Tensor tensor({static_cast(image.width), static_cast(image.height), static_cast(image.channel), 1}); for (uint32_t y = 0; y < image.height; ++y) { @@ -39,20 +108,7 @@ 
static inline sd::Tensor sd_image_to_preprocessing_tensor(sd_image_t imag static inline void preprocessing_tensor_to_sd_image(const sd::Tensor& tensor, uint8_t* image_data) { GGML_ASSERT(tensor.dim() == 4); GGML_ASSERT(tensor.shape()[3] == 1); - GGML_ASSERT(image_data != nullptr); - - int width = static_cast(tensor.shape()[0]); - int height = static_cast(tensor.shape()[1]); - int channel = static_cast(tensor.shape()[2]); - for (int y = 0; y < height; ++y) { - for (int x = 0; x < width; ++x) { - for (int c = 0; c < channel; ++c) { - float value = preprocessing_get_4d(tensor, x, y, c, 0); - value = std::min(1.0f, std::max(0.0f, value)); - image_data[(y * width + x) * channel + c] = static_cast(std::round(value * 255.0f)); - } - } - } + preprocessing_tensor_frame_to_sd_image(tensor, 0, image_data); } static inline sd::Tensor gaussian_kernel_tensor(int kernel_size) { diff --git a/src/qwen_image.hpp b/src/qwen_image.hpp index 83c8cec66..35d32109e 100644 --- a/src/qwen_image.hpp +++ b/src/qwen_image.hpp @@ -95,9 +95,7 @@ namespace Qwen { float scale = 1.f / 32.f; bool force_prec_f32 = false; -#ifdef SD_USE_VULKAN - force_prec_f32 = true; -#endif + // The purpose of the scale here is to prevent NaN issues in certain situations. // For example when using CUDA but the weights are k-quants (not all prompts). 
blocks["to_out.0"] = std::shared_ptr(new Linear(inner_dim, out_dim, out_bias, false, force_prec_f32, scale)); @@ -124,6 +122,10 @@ namespace Qwen { auto to_v = std::dynamic_pointer_cast(blocks["to_v"]); auto to_out_0 = std::dynamic_pointer_cast(blocks["to_out.0"]); + if (sd_backend_is(ctx->backend, "Vulkan")) { + to_out_0->set_force_prec_f32(true); + } + auto norm_added_q = std::dynamic_pointer_cast(blocks["norm_added_q"]); auto norm_added_k = std::dynamic_pointer_cast(blocks["norm_added_k"]); @@ -410,6 +412,9 @@ namespace Qwen { auto img = img_in->forward(ctx, x); auto txt = txt_norm->forward(ctx, context); txt = txt_in->forward(ctx, txt); + sd::ggml_graph_cut::mark_graph_cut(img, "qwen_image.prelude", "img"); + sd::ggml_graph_cut::mark_graph_cut(txt, "qwen_image.prelude", "txt"); + // sd::ggml_graph_cut::mark_graph_cut(t_emb, "qwen_image.prelude", "t_emb"); for (int i = 0; i < params.num_layers; i++) { auto block = std::dynamic_pointer_cast(blocks["transformer_blocks." + std::to_string(i)]); @@ -417,6 +422,8 @@ namespace Qwen { auto result = block->forward(ctx, img, txt, t_emb, pe, modulate_index); img = result.first; txt = result.second; + sd::ggml_graph_cut::mark_graph_cut(img, "qwen_image.transformer_blocks." + std::to_string(i), "img"); + sd::ggml_graph_cut::mark_graph_cut(txt, "qwen_image.transformer_blocks." 
+ std::to_string(i), "txt"); } if (params.zero_cond_t) { diff --git a/src/rope.hpp b/src/rope.hpp index db577f5d3..f84fac885 100644 --- a/src/rope.hpp +++ b/src/rope.hpp @@ -7,6 +7,11 @@ #include "ggml_extend.hpp" namespace Rope { + enum class EmbedNDLayout { + Matrix, + ErnieImage, + }; + template __STATIC_INLINE__ std::vector linspace(T start, T end, int num) { std::vector result(num); @@ -169,7 +174,8 @@ namespace Rope { int bs, const std::vector& axis_thetas, const std::vector& axes_dim, - const std::vector>& wrap_dims = {}) { + const std::vector>& wrap_dims = {}, + EmbedNDLayout layout = EmbedNDLayout::Matrix) { std::vector> trans_ids = transpose(ids); size_t pos_len = ids.size() / bs; size_t num_axes = axes_dim.size(); @@ -204,6 +210,24 @@ namespace Rope { offset += rope_emb[0].size(); } + if (layout == EmbedNDLayout::ErnieImage) { + int head_dim = emb_dim * 2; + std::vector ernie_emb(bs * pos_len * head_dim * 2, 0.0f); + for (size_t pos_idx = 0; pos_idx < bs * pos_len; ++pos_idx) { + for (int i = 0; i < emb_dim; ++i) { + float cos_val = emb[pos_idx][4 * i]; + float sin_val = emb[pos_idx][4 * i + 2]; + size_t cos_offset = pos_idx * head_dim + 2 * i; + size_t sin_offset = bs * pos_len * head_dim + cos_offset; + ernie_emb[cos_offset] = cos_val; + ernie_emb[cos_offset + 1] = cos_val; + ernie_emb[sin_offset] = sin_val; + ernie_emb[sin_offset + 1] = sin_val; + } + } + return ernie_emb; + } + return flatten(emb); } @@ -211,9 +235,10 @@ namespace Rope { int bs, float theta, const std::vector& axes_dim, - const std::vector>& wrap_dims = {}) { + const std::vector>& wrap_dims = {}, + EmbedNDLayout layout = EmbedNDLayout::Matrix) { std::vector axis_thetas(axes_dim.size(), theta); - return embed_nd(ids, bs, axis_thetas, axes_dim, wrap_dims); + return embed_nd(ids, bs, axis_thetas, axes_dim, wrap_dims, layout); } __STATIC_INLINE__ std::vector> gen_refs_ids(int patch_size, @@ -437,6 +462,74 @@ namespace Rope { return embed_nd(ids, bs, static_cast(theta), axes_dim, 
wrap_dims); } + __STATIC_INLINE__ std::vector> gen_ernie_image_ids(int h, + int w, + int patch_size, + int bs, + int context_len) { + int h_len = h / patch_size; + int w_len = w / patch_size; + + std::vector> img_ids(h_len * w_len, std::vector(3, 0.0f)); + std::vector h_ids = linspace(0.f, static_cast(h_len - 1), h_len); + std::vector w_ids = linspace(0.f, static_cast(w_len - 1), w_len); + for (int i = 0; i < h_len; ++i) { + for (int j = 0; j < w_len; ++j) { + img_ids[i * w_len + j][0] = static_cast(context_len); + img_ids[i * w_len + j][1] = h_ids[i]; + img_ids[i * w_len + j][2] = w_ids[j]; + } + } + + std::vector> img_ids_repeated(bs * img_ids.size(), std::vector(3, 0.0f)); + for (int i = 0; i < bs; ++i) { + for (int j = 0; j < static_cast(img_ids.size()); ++j) { + img_ids_repeated[i * img_ids.size() + j] = img_ids[j]; + } + } + + std::vector> txt_ids(bs * context_len, std::vector(3, 0.0f)); + for (int i = 0; i < bs; ++i) { + for (int j = 0; j < context_len; ++j) { + txt_ids[i * context_len + j][0] = static_cast(j); + } + } + + return concat_ids(img_ids_repeated, txt_ids, bs); + } + + __STATIC_INLINE__ std::vector gen_ernie_image_pe(int h, + int w, + int patch_size, + int bs, + int context_len, + int theta, + bool circular_h, + bool circular_w, + const std::vector& axes_dim) { + std::vector> ids = gen_ernie_image_ids(h, w, patch_size, bs, context_len); + std::vector> wrap_dims; + if ((circular_h || circular_w) && bs > 0 && axes_dim.size() >= 3) { + int h_len = h / patch_size; + int w_len = w / patch_size; + if (h_len > 0 && w_len > 0) { + size_t pos_len = ids.size() / bs; + wrap_dims.assign(axes_dim.size(), std::vector(pos_len, 0)); + const size_t img_tokens = static_cast(h_len) * static_cast(w_len); + for (size_t token_i = 0; token_i < img_tokens; ++token_i) { + if (circular_h) { + wrap_dims[1][token_i] = h_len; + } + if (circular_w) { + wrap_dims[2][token_i] = w_len; + } + } + } + } + + return embed_nd(ids, bs, static_cast(theta), axes_dim, wrap_dims, 
EmbedNDLayout::ErnieImage); + } + __STATIC_INLINE__ std::vector> gen_vid_ids(int t, int h, int w, diff --git a/src/stable-diffusion.cpp b/src/stable-diffusion.cpp index ae34530b0..fd439ff1d 100644 --- a/src/stable-diffusion.cpp +++ b/src/stable-diffusion.cpp @@ -17,6 +17,7 @@ #include "pmid.hpp" #include "sample-cache.h" #include "tae.hpp" +#include "upscaler.h" #include "vae.hpp" #include "latent-preview.h" @@ -30,7 +31,8 @@ const char* model_version_to_str[] = { "SD 2.x", "SD 2.x Inpaint", "SD 2.x Tiny UNet", - "SDXS", + "SDXS (512-DS)", + "SDXS (09)", "SDXL", "SDXL Inpaint", "SDXL Instruct-Pix2Pix", @@ -52,6 +54,7 @@ const char* model_version_to_str[] = { "Flux.2 klein", "Z-Image", "Ovis Image", + "Ernie Image", }; const char* sampling_methods_str[] = { @@ -69,6 +72,7 @@ const char* sampling_methods_str[] = { "TCD", "Res Multistep", "Res 2s", + "ER-SDE", }; /*================================================== Helper Functions ================================================*/ @@ -140,6 +144,7 @@ class StableDiffusionGGML { std::string taesd_path; sd_tiling_params_t vae_tiling_params = {false, 0, 0, 0.5f, 0, 0}; bool offload_params_to_cpu = false; + float max_vram = 0.f; bool use_pmid = false; bool is_using_v_parameterization = false; @@ -168,60 +173,7 @@ class StableDiffusionGGML { } void init_backend() { -#ifdef SD_USE_CUDA - LOG_DEBUG("Using CUDA backend"); - backend = ggml_backend_cuda_init(0); -#endif -#ifdef SD_USE_METAL - LOG_DEBUG("Using Metal backend"); - backend = ggml_backend_metal_init(); -#endif -#ifdef SD_USE_VULKAN - LOG_DEBUG("Using Vulkan backend"); - size_t device = 0; - const int device_count = ggml_backend_vk_get_device_count(); - if (device_count) { - const char* SD_VK_DEVICE = getenv("SD_VK_DEVICE"); - if (SD_VK_DEVICE != nullptr) { - std::string sd_vk_device_str = SD_VK_DEVICE; - try { - device = std::stoull(sd_vk_device_str); - } catch (const std::invalid_argument&) { - LOG_WARN("SD_VK_DEVICE environment variable is not a valid integer 
(%s). Falling back to device 0.", SD_VK_DEVICE); - device = 0; - } catch (const std::out_of_range&) { - LOG_WARN("SD_VK_DEVICE environment variable value is out of range for `unsigned long long` type (%s). Falling back to device 0.", SD_VK_DEVICE); - device = 0; - } - if (device >= device_count) { - LOG_WARN("Cannot find targeted vulkan device (%llu). Falling back to device 0.", device); - device = 0; - } - } - LOG_INFO("Vulkan: Using device %llu", device); - backend = ggml_backend_vk_init(device); - } - if (!backend) { - LOG_WARN("Failed to initialize Vulkan backend"); - } -#endif -#ifdef SD_USE_OPENCL - LOG_DEBUG("Using OpenCL backend"); - // ggml_log_set(ggml_log_callback_default, nullptr); // Optional ggml logs - backend = ggml_backend_opencl_init(); - if (!backend) { - LOG_WARN("Failed to initialize OpenCL backend"); - } -#endif -#ifdef SD_USE_SYCL - LOG_DEBUG("Using SYCL backend"); - backend = ggml_backend_sycl_init(0); -#endif - - if (!backend) { - LOG_DEBUG("Using CPU backend"); - backend = ggml_backend_cpu_init(); - } + backend = sd_get_default_backend(); } std::shared_ptr get_rng(rng_type_t rng_type) { @@ -239,6 +191,7 @@ class StableDiffusionGGML { vae_decode_only = sd_ctx_params->vae_decode_only; free_params_immediately = sd_ctx_params->free_params_immediately; offload_params_to_cpu = sd_ctx_params->offload_params_to_cpu; + max_vram = sd_ctx_params->max_vram; bool use_tae = false; @@ -413,7 +366,7 @@ class StableDiffusionGGML { } bool tae_preview_only = sd_ctx_params->tae_preview_only; - if (version == VERSION_SDXS) { + if (version == VERSION_SDXS_512_DS || version == VERSION_SDXS_09) { tae_preview_only = false; use_tae = true; } @@ -424,6 +377,10 @@ class StableDiffusionGGML { bool clip_on_cpu = sd_ctx_params->keep_clip_on_cpu; + const size_t max_graph_vram_bytes = max_vram <= 0.f + ? 
0 + : static_cast(static_cast(max_vram) * 1024.0 * 1024.0 * 1024.0); + { clip_backend = backend; if (clip_on_cpu && !ggml_backend_is_cpu(backend)) { @@ -513,6 +470,7 @@ class StableDiffusionGGML { clip_vision = std::make_shared(backend, offload_params_to_cpu, tensor_storage_map); + clip_vision->set_max_graph_vram_bytes(max_graph_vram_bytes); clip_vision->alloc_params_buffer(); clip_vision->get_param_tensors(tensors); } @@ -551,6 +509,15 @@ class StableDiffusionGGML { tensor_storage_map, "model.diffusion_model", version); + } else if (sd_version_is_ernie_image(version)) { + cond_stage_model = std::make_shared(clip_backend, + offload_params_to_cpu, + tensor_storage_map, + version); + diffusion_model = std::make_shared(backend, + offload_params_to_cpu, + tensor_storage_map, + "model.diffusion_model"); } else { // SD1.x SD2.x SDXL std::map embbeding_map; for (uint32_t i = 0; i < sd_ctx_params->embedding_count; i++) { @@ -580,9 +547,11 @@ class StableDiffusionGGML { } } + cond_stage_model->set_max_graph_vram_bytes(max_graph_vram_bytes); cond_stage_model->alloc_params_buffer(); cond_stage_model->get_param_tensors(tensors); + diffusion_model->set_max_graph_vram_bytes(max_graph_vram_bytes); diffusion_model->alloc_params_buffer(); diffusion_model->get_param_tensors(tensors); @@ -591,6 +560,7 @@ class StableDiffusionGGML { } if (high_noise_diffusion_model) { + high_noise_diffusion_model->set_max_graph_vram_bytes(max_graph_vram_bytes); high_noise_diffusion_model->alloc_params_buffer(); high_noise_diffusion_model->get_param_tensors(tensors); } @@ -663,16 +633,19 @@ class StableDiffusionGGML { } else if (use_tae && !tae_preview_only) { LOG_INFO("using TAE for encoding / decoding"); first_stage_model = create_tae(); + first_stage_model->set_max_graph_vram_bytes(max_graph_vram_bytes); first_stage_model->alloc_params_buffer(); first_stage_model->get_param_tensors(tensors, "tae"); } else { LOG_INFO("using VAE for encoding / decoding"); first_stage_model = create_vae(); + 
first_stage_model->set_max_graph_vram_bytes(max_graph_vram_bytes); first_stage_model->alloc_params_buffer(); first_stage_model->get_param_tensors(tensors, "first_stage_model"); if (use_tae && tae_preview_only) { LOG_INFO("using TAE for preview"); preview_vae = create_tae(); + preview_vae->set_max_graph_vram_bytes(max_graph_vram_bytes); preview_vae->alloc_params_buffer(); preview_vae->get_param_tensors(tensors, "tae"); } @@ -819,6 +792,10 @@ class StableDiffusionGGML { if (version == VERSION_SVD) { ignore_tensors.insert("conditioner.embedders.3"); } + if (sd_version_is_ernie_image(version)) { + ignore_tensors.insert("text_encoders.llm.vision_tower."); + ignore_tensors.insert("text_encoders.llm.multi_modal_projector."); + } bool success = model_loader.load_tensors(tensors, ignore_tensors, n_threads, sd_ctx_params->enable_mmap); if (!success) { LOG_ERROR("load tensors from model loader failed"); @@ -922,10 +899,13 @@ class StableDiffusionGGML { sd_version_is_wan(version) || sd_version_is_qwen_image(version) || sd_version_is_anima(version) || + sd_version_is_ernie_image(version) || sd_version_is_z_image(version)) { pred_type = FLOW_PRED; if (sd_version_is_wan(version)) { default_flow_shift = 5.f; + } else if (sd_version_is_ernie_image(version)) { + default_flow_shift = 4.f; } else { default_flow_shift = 3.f; } @@ -1137,8 +1117,13 @@ class StableDiffusionGGML { cond_stage_lora_models.push_back(lora); } } - auto multi_lora_adapter = std::make_shared(cond_stage_lora_models); - cond_stage_model->set_weight_adapter(multi_lora_adapter); + // Only attach the adapter when there are LoRAs targeting the cond_stage model. + // An empty MultiLoraAdapter still routes every linear/conv through + // forward_with_lora() instead of the direct kernel path โ€” slower for no benefit. 
+ if (!cond_stage_lora_models.empty()) { + auto multi_lora_adapter = std::make_shared(cond_stage_lora_models); + cond_stage_model->set_weight_adapter(multi_lora_adapter); + } } if (diffusion_model) { std::vector> lora_models; @@ -1169,10 +1154,12 @@ class StableDiffusionGGML { diffusion_lora_models.push_back(lora); } } - auto multi_lora_adapter = std::make_shared(diffusion_lora_models); - diffusion_model->set_weight_adapter(multi_lora_adapter); - if (high_noise_diffusion_model) { - high_noise_diffusion_model->set_weight_adapter(multi_lora_adapter); + if (!diffusion_lora_models.empty()) { + auto multi_lora_adapter = std::make_shared(diffusion_lora_models); + diffusion_model->set_weight_adapter(multi_lora_adapter); + if (high_noise_diffusion_model) { + high_noise_diffusion_model->set_weight_adapter(multi_lora_adapter); + } } } @@ -1205,8 +1192,10 @@ class StableDiffusionGGML { first_stage_lora_models.push_back(lora); } } - auto multi_lora_adapter = std::make_shared(first_stage_lora_models); - first_stage_model->set_weight_adapter(multi_lora_adapter); + if (!first_stage_lora_models.empty()) { + auto multi_lora_adapter = std::make_shared(first_stage_lora_models); + first_stage_model->set_weight_adapter(multi_lora_adapter); + } } } @@ -1395,7 +1384,7 @@ class StableDiffusionGGML { uint32_t dim = is_video ? 
static_cast(latents.shape()[3]) : static_cast(latents.shape()[2]); if (dim == 128) { - if (sd_version_is_flux2(version)) { + if (sd_version_uses_flux2_vae(version)) { latent_rgb_proj = flux2_latent_rgb_proj; latent_rgb_bias = flux2_latent_rgb_bias; patch_sz = 2; @@ -1844,7 +1833,7 @@ class StableDiffusionGGML { latent_channel = 48; } else if (version == VERSION_CHROMA_RADIANCE) { latent_channel = 3; - } else if (sd_version_is_flux2(version)) { + } else if (sd_version_uses_flux2_vae(version)) { latent_channel = 128; } else { latent_channel = 16; @@ -1975,6 +1964,7 @@ const char* sample_method_to_str[] = { "tcd", "res_multistep", "res_2s", + "er_sde", }; const char* sd_sample_method_name(enum sample_method_t sample_method) { @@ -2093,6 +2083,35 @@ enum lora_apply_mode_t str_to_lora_apply_mode(const char* str) { return LORA_APPLY_MODE_COUNT; } +const char* hires_upscaler_to_str[] = { + "None", + "Latent", + "Latent (nearest)", + "Latent (nearest-exact)", + "Latent (antialiased)", + "Latent (bicubic)", + "Latent (bicubic antialiased)", + "Lanczos", + "Nearest", + "Model", +}; + +const char* sd_hires_upscaler_name(enum sd_hires_upscaler_t upscaler) { + if (upscaler >= SD_HIRES_UPSCALER_NONE && upscaler < SD_HIRES_UPSCALER_COUNT) { + return hires_upscaler_to_str[upscaler]; + } + return NONE_STR; +} + +enum sd_hires_upscaler_t str_to_sd_hires_upscaler(const char* str) { + for (int i = 0; i < SD_HIRES_UPSCALER_COUNT; i++) { + if (!strcmp(str, hires_upscaler_to_str[i])) { + return (enum sd_hires_upscaler_t)i; + } + } + return SD_HIRES_UPSCALER_COUNT; +} + void sd_cache_params_init(sd_cache_params_t* cache_params) { *cache_params = {}; cache_params->mode = SD_CACHE_DISABLED; @@ -2121,6 +2140,19 @@ void sd_cache_params_init(sd_cache_params_t* cache_params) { cache_params->spectrum_stop_percent = 0.9f; } +void sd_hires_params_init(sd_hires_params_t* hires_params) { + *hires_params = {}; + hires_params->enabled = false; + hires_params->upscaler = SD_HIRES_UPSCALER_LATENT; + 
hires_params->model_path = nullptr; + hires_params->scale = 2.0f; + hires_params->target_width = 0; + hires_params->target_height = 0; + hires_params->steps = 0; + hires_params->denoising_strength = 0.7f; + hires_params->upscale_tile_size = 128; +} + void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) { *sd_ctx_params = {}; sd_ctx_params->vae_decode_only = true; @@ -2132,6 +2164,7 @@ void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) { sd_ctx_params->prediction = PREDICTION_COUNT; sd_ctx_params->lora_apply_mode = LORA_APPLY_AUTO; sd_ctx_params->offload_params_to_cpu = false; + sd_ctx_params->max_vram = 0.f; sd_ctx_params->enable_mmap = false; sd_ctx_params->keep_clip_on_cpu = false; sd_ctx_params->keep_control_net_on_cpu = false; @@ -2173,6 +2206,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) { "sampler_rng_type: %s\n" "prediction: %s\n" "offload_params_to_cpu: %s\n" + "max_vram: %.3f\n" "keep_clip_on_cpu: %s\n" "keep_control_net_on_cpu: %s\n" "keep_vae_on_cpu: %s\n" @@ -2205,6 +2239,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) { sd_rng_type_name(sd_ctx_params->sampler_rng_type), sd_prediction_name(sd_ctx_params->prediction), BOOL_STR(sd_ctx_params->offload_params_to_cpu), + sd_ctx_params->max_vram, BOOL_STR(sd_ctx_params->keep_clip_on_cpu), BOOL_STR(sd_ctx_params->keep_control_net_on_cpu), BOOL_STR(sd_ctx_params->keep_vae_on_cpu), @@ -2290,6 +2325,7 @@ void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params) { sd_img_gen_params->pm_params = {nullptr, 0, nullptr, 20.f}; sd_img_gen_params->vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f}; sd_cache_params_init(&sd_img_gen_params->cache); + sd_hires_params_init(&sd_img_gen_params->hires); } char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) { @@ -2316,7 +2352,8 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) { "increase_ref_index: %s\n" "control_strength: %.2f\n" "photo maker: 
{style_strength = %.2f, id_images_count = %d, id_embed_path = %s}\n" - "VAE tiling: %s\n", + "VAE tiling: %s\n" + "hires: {enabled=%s, upscaler=%s, model_path=%s, scale=%.2f, target=%dx%d, steps=%d, denoising_strength=%.2f}\n", SAFE_STR(sd_img_gen_params->prompt), SAFE_STR(sd_img_gen_params->negative_prompt), sd_img_gen_params->clip_skip, @@ -2333,7 +2370,15 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) { sd_img_gen_params->pm_params.style_strength, sd_img_gen_params->pm_params.id_images_count, SAFE_STR(sd_img_gen_params->pm_params.id_embed_path), - BOOL_STR(sd_img_gen_params->vae_tiling_params.enabled)); + BOOL_STR(sd_img_gen_params->vae_tiling_params.enabled), + BOOL_STR(sd_img_gen_params->hires.enabled), + sd_hires_upscaler_name(sd_img_gen_params->hires.upscaler), + SAFE_STR(sd_img_gen_params->hires.model_path), + sd_img_gen_params->hires.scale, + sd_img_gen_params->hires.target_width, + sd_img_gen_params->hires.target_height, + sd_img_gen_params->hires.steps, + sd_img_gen_params->hires.denoising_strength); const char* cache_mode_str = "disabled"; if (sd_img_gen_params->cache.mode == SD_CACHE_EASYCACHE) { cache_mode_str = "easycache"; @@ -2370,6 +2415,14 @@ struct sd_ctx_t { StableDiffusionGGML* sd = nullptr; }; +static bool sd_version_supports_video_generation(SDVersion version) { + return version == VERSION_SVD || sd_version_is_wan(version); +} + +static bool sd_version_supports_image_generation(SDVersion version) { + return !sd_version_supports_video_generation(version); +} + sd_ctx_t* new_sd_ctx(const sd_ctx_params_t* sd_ctx_params) { sd_ctx_t* sd_ctx = (sd_ctx_t*)malloc(sizeof(sd_ctx_t)); if (sd_ctx == nullptr) { @@ -2399,6 +2452,20 @@ void free_sd_ctx(sd_ctx_t* sd_ctx) { free(sd_ctx); } +SD_API bool sd_ctx_supports_image_generation(const sd_ctx_t* sd_ctx) { + if (sd_ctx == nullptr || sd_ctx->sd == nullptr) { + return false; + } + return sd_version_supports_image_generation(sd_ctx->sd->version); +} + +SD_API bool 
sd_ctx_supports_video_generation(const sd_ctx_t* sd_ctx) { + if (sd_ctx == nullptr || sd_ctx->sd == nullptr) { + return false; + } + return sd_version_supports_video_generation(sd_ctx->sd->version); +} + enum sample_method_t sd_get_default_sample_method(const sd_ctx_t* sd_ctx) { if (sd_ctx != nullptr && sd_ctx->sd != nullptr) { if (sd_version_is_dit(sd_ctx->sd->version)) { @@ -2415,8 +2482,10 @@ enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx, enum sample_me return EXPONENTIAL_SCHEDULER; } } - if (sample_method == LCM_SAMPLE_METHOD) { + if (sample_method == LCM_SAMPLE_METHOD || sample_method == TCD_SAMPLE_METHOD) { return LCM_SCHEDULER; + } else if (sample_method == DDIM_TRAILING_SAMPLE_METHOD) { + return SIMPLE_SCHEDULER; } return DISCRETE_SCHEDULER; } @@ -2457,6 +2526,7 @@ static float resolve_eta(sd_ctx_t* sd_ctx, return 0.0f; case EULER_A_SAMPLE_METHOD: case DPMPP2S_A_SAMPLE_METHOD: + case ER_SDE_SAMPLE_METHOD: return 1.0f; default:; } @@ -2489,6 +2559,7 @@ struct GenerationRequest { sd_guidance_params_t guidance = {}; sd_guidance_params_t high_noise_guidance = {}; sd_pm_params_t pm_params = {}; + sd_hires_params_t hires = {}; int frames = -1; float vace_strength = 1.f; @@ -2510,6 +2581,7 @@ struct GenerationRequest { auto_resize_ref_image = sd_img_gen_params->auto_resize_ref_image; guidance = sd_img_gen_params->sample_params.guidance; pm_params = sd_img_gen_params->pm_params; + hires = sd_img_gen_params->hires; cache_params = &sd_img_gen_params->cache; resolve(sd_ctx); } @@ -2532,26 +2604,76 @@ struct GenerationRequest { } void align_generation_request_size() { + align_image_size(&width, &height, "generation request"); + } + + void align_image_size(int* target_width, int* target_height, const char* label) { int spatial_multiple = vae_scale_factor * diffusion_model_down_factor; - int width_offset = align_up_offset(width, spatial_multiple); - int height_offset = align_up_offset(height, spatial_multiple); + int width_offset = 
align_up_offset(*target_width, spatial_multiple); + int height_offset = align_up_offset(*target_height, spatial_multiple); if (width_offset <= 0 && height_offset <= 0) { return; } - int original_width = width; - int original_height = height; + int original_width = *target_width; + int original_height = *target_height; - width += width_offset; - height += height_offset; - LOG_WARN("align up %dx%d to %dx%d (multiple=%d)", + *target_width += width_offset; + *target_height += height_offset; + LOG_WARN("align %s up %dx%d to %dx%d (multiple=%d)", + label, original_width, original_height, - width, - height, + *target_width, + *target_height, spatial_multiple); } + void resolve_hires() { + if (!hires.enabled) { + return; + } + if (hires.upscaler == SD_HIRES_UPSCALER_NONE) { + hires.enabled = false; + return; + } + if (hires.upscaler < SD_HIRES_UPSCALER_NONE || hires.upscaler >= SD_HIRES_UPSCALER_COUNT) { + LOG_WARN("hires upscaler '%d' is invalid, disabling hires", hires.upscaler); + hires.enabled = false; + return; + } + if (hires.upscaler == SD_HIRES_UPSCALER_MODEL && strlen(SAFE_STR(hires.model_path)) == 0) { + LOG_WARN("hires model upscaler requires a model path, disabling hires"); + hires.enabled = false; + return; + } + if (hires.scale <= 0.f && hires.target_width <= 0 && hires.target_height <= 0) { + LOG_WARN("hires scale must be positive when no target size is set, disabling hires"); + hires.enabled = false; + return; + } + hires.denoising_strength = std::clamp(hires.denoising_strength, 0.0001f, 1.f); + hires.steps = std::max(0, hires.steps); + + if (hires.target_width > 0 && hires.target_height > 0) { + // pass + } else if (hires.target_width > 0) { + hires.target_height = hires.target_width; + } else if (hires.target_height > 0) { + hires.target_width = hires.target_height; + } else { + hires.target_width = static_cast(std::round(width * hires.scale)); + hires.target_height = static_cast(std::round(height * hires.scale)); + } + + if (hires.target_width <= 0 || 
hires.target_height <= 0) { + LOG_WARN("hires target size is not positive, disabling hires"); + hires.enabled = false; + return; + } + align_image_size(&hires.target_width, &hires.target_height, "hires target"); + } + static void resolve_guidance(sd_ctx_t* sd_ctx, sd_guidance_params_t* guidance, bool* use_uncond, @@ -2592,6 +2714,7 @@ struct GenerationRequest { void resolve(sd_ctx_t* sd_ctx) { align_generation_request_size(); + resolve_hires(); seed = resolve_seed(seed); resolve_guidance(sd_ctx, &guidance, &use_uncond, &use_img_cond); @@ -2846,7 +2969,8 @@ static std::optional prepare_image_generation_latents(sd {request->width / request->vae_scale_factor, request->height / request->vae_scale_factor, 1, - 1}); + 1}, + sd::ops::InterpolateMode::NearestMax); sd::Tensor init_latent; sd::Tensor control_latent; @@ -2991,8 +3115,12 @@ static std::optional prepare_image_generation_latents(sd latents.ref_latents = std::move(ref_latents); if (sd_version_is_inpaint(sd_ctx->sd->version)) { - latents.denoise_mask = std::move(latent_mask); + latent_mask = sd::ops::max_pool_2d(latent_mask, + {3, 3}, + {1, 1}, + {1, 1}); } + latents.denoise_mask = std::move(latent_mask); return latents; } @@ -3077,7 +3205,7 @@ static sd_image_t* decode_image_outputs(sd_ctx_t* sd_ctx, } decoded_images.push_back(std::move(image)); int64_t t2 = ggml_time_ms(); - LOG_INFO("latent %" PRId64 " decoded, taking %.2fs", i + 1, (t2 - t1) * 1.0f / 1000); + LOG_INFO("latent %zu decoded, taking %.2fs", i + 1, (t2 - t1) * 1.0f / 1000); } int64_t t4 = ggml_time_ms(); @@ -3099,6 +3227,135 @@ static sd_image_t* decode_image_outputs(sd_ctx_t* sd_ctx, return result_images; } +static sd::Tensor upscale_hires_latent(sd_ctx_t* sd_ctx, + const sd::Tensor& latent, + const GenerationRequest& request, + UpscalerGGML* upscaler) { + auto get_hires_latent_target_shape = [&]() { + std::vector target_shape = latent.shape(); + if (target_shape.size() < 2) { + target_shape.clear(); + return target_shape; + } + target_shape[0] = 
request.hires.target_width / request.vae_scale_factor; + target_shape[1] = request.hires.target_height / request.vae_scale_factor; + return target_shape; + }; + + if (request.hires.upscaler == SD_HIRES_UPSCALER_LATENT || + request.hires.upscaler == SD_HIRES_UPSCALER_LATENT_NEAREST || + request.hires.upscaler == SD_HIRES_UPSCALER_LATENT_NEAREST_EXACT || + request.hires.upscaler == SD_HIRES_UPSCALER_LATENT_ANTIALIASED || + request.hires.upscaler == SD_HIRES_UPSCALER_LATENT_BICUBIC || + request.hires.upscaler == SD_HIRES_UPSCALER_LATENT_BICUBIC_ANTIALIASED) { + std::vector target_shape = get_hires_latent_target_shape(); + if (target_shape.empty()) { + LOG_ERROR("latent has invalid shape for hires upscale"); + return {}; + } + + sd::ops::InterpolateMode mode = sd::ops::InterpolateMode::Nearest; + bool antialias = false; + switch (request.hires.upscaler) { + case SD_HIRES_UPSCALER_LATENT: + mode = sd::ops::InterpolateMode::Bilinear; + break; + case SD_HIRES_UPSCALER_LATENT_NEAREST: + mode = sd::ops::InterpolateMode::Nearest; + break; + case SD_HIRES_UPSCALER_LATENT_NEAREST_EXACT: + mode = sd::ops::InterpolateMode::NearestExact; + break; + case SD_HIRES_UPSCALER_LATENT_ANTIALIASED: + mode = sd::ops::InterpolateMode::Bilinear; + antialias = true; + break; + case SD_HIRES_UPSCALER_LATENT_BICUBIC: + mode = sd::ops::InterpolateMode::Bicubic; + break; + case SD_HIRES_UPSCALER_LATENT_BICUBIC_ANTIALIASED: + mode = sd::ops::InterpolateMode::Bicubic; + antialias = true; + break; + default: + break; + } + + LOG_INFO("hires %s upscale %" PRId64 "x%" PRId64 " -> %" PRId64 "x%" PRId64, + sd_hires_upscaler_name(request.hires.upscaler), + latent.shape()[0], + latent.shape()[1], + target_shape[0], + target_shape[1]); + + return sd::ops::interpolate(latent, target_shape, mode, false, antialias); + } else if (request.hires.upscaler == SD_HIRES_UPSCALER_MODEL || + request.hires.upscaler == SD_HIRES_UPSCALER_LANCZOS || + request.hires.upscaler == SD_HIRES_UPSCALER_NEAREST) { + if 
(sd_ctx->sd->vae_decode_only) { + LOG_ERROR("hires %s upscaler requires VAE encoder weights; create the context with vae_decode_only=false", + sd_hires_upscaler_name(request.hires.upscaler)); + return {}; + } + if (request.hires.upscaler == SD_HIRES_UPSCALER_MODEL && upscaler == nullptr) { + LOG_ERROR("hires model upscaler context is null"); + return {}; + } + + sd::Tensor decoded = sd_ctx->sd->decode_first_stage(latent); + if (decoded.empty()) { + LOG_ERROR("decode_first_stage failed before hires %s upscale", + sd_hires_upscaler_name(request.hires.upscaler)); + return {}; + } + + sd::Tensor upscaled_tensor; + if (request.hires.upscaler == SD_HIRES_UPSCALER_MODEL) { + upscaled_tensor = upscaler->upscale_tensor(decoded); + if (upscaled_tensor.empty()) { + LOG_ERROR("hires model upscale failed"); + return {}; + } + + if (upscaled_tensor.shape()[0] != request.hires.target_width || + upscaled_tensor.shape()[1] != request.hires.target_height) { + upscaled_tensor = sd::ops::interpolate(upscaled_tensor, + {request.hires.target_width, + request.hires.target_height, + upscaled_tensor.shape()[2], + upscaled_tensor.shape()[3]}); + } + } else { + sd::ops::InterpolateMode mode = request.hires.upscaler == SD_HIRES_UPSCALER_LANCZOS + ? 
sd::ops::InterpolateMode::Lanczos + : sd::ops::InterpolateMode::Nearest; + LOG_INFO("hires %s image upscale %" PRId64 "x%" PRId64 " -> %dx%d", + sd_hires_upscaler_name(request.hires.upscaler), + decoded.shape()[0], + decoded.shape()[1], + request.hires.target_width, + request.hires.target_height); + upscaled_tensor = sd::ops::interpolate(decoded, + {request.hires.target_width, + request.hires.target_height, + decoded.shape()[2], + decoded.shape()[3]}, + mode); + upscaled_tensor = sd::ops::clamp(upscaled_tensor, 0.0f, 1.0f); + } + + sd::Tensor upscaled_latent = sd_ctx->sd->encode_first_stage(upscaled_tensor); + if (upscaled_latent.empty()) { + LOG_ERROR("encode_first_stage failed after hires %s upscale", + sd_hires_upscaler_name(request.hires.upscaler)); + } + return upscaled_latent; + } + + LOG_ERROR("unsupported hires upscaler '%s'", sd_hires_upscaler_name(request.hires.upscaler)); + return {}; +} + SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params) { if (sd_ctx == nullptr || sd_img_gen_params == nullptr) { return nullptr; @@ -3186,14 +3443,143 @@ SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* s } return nullptr; } - if (sd_ctx->sd->free_params_immediately) { + if (sd_ctx->sd->free_params_immediately && !request.hires.enabled) { sd_ctx->sd->diffusion_model->free_params_buffer(); } int64_t denoise_end = ggml_time_ms(); - LOG_INFO("generating %" PRId64 " latent images completed, taking %.2fs", + LOG_INFO("generating %zu latent images completed, taking %.2fs", final_latents.size(), (denoise_end - denoise_start) * 1.0f / 1000); + if (request.hires.enabled && request.hires.target_width > 0) { + LOG_INFO("hires fix: upscaling to %dx%d", request.hires.target_width, request.hires.target_height); + + std::unique_ptr hires_upscaler; + if (request.hires.upscaler == SD_HIRES_UPSCALER_MODEL) { + LOG_INFO("hires fix: loading model upscaler from '%s'", request.hires.model_path); + hires_upscaler = 
std::make_unique(sd_ctx->sd->n_threads, + false, + request.hires.upscale_tile_size); + const size_t max_graph_vram_bytes = sd_ctx->sd->max_vram <= 0.f + ? 0 + : static_cast(static_cast(sd_ctx->sd->max_vram) * 1024.0 * 1024.0 * 1024.0); + hires_upscaler->set_max_graph_vram_bytes(max_graph_vram_bytes); + if (!hires_upscaler->load_from_file(request.hires.model_path, + sd_ctx->sd->offload_params_to_cpu, + sd_ctx->sd->n_threads)) { + LOG_ERROR("load hires model upscaler failed"); + if (sd_ctx->sd->free_params_immediately) { + sd_ctx->sd->diffusion_model->free_params_buffer(); + } + return nullptr; + } + } + + int hires_steps = request.hires.steps > 0 ? request.hires.steps : plan.sample_steps; + + // sd-webui behavior: scale up total steps so trimming by denoising_strength yields exactly hires_steps effective steps, + // unlike img2img which trims from a fixed step count + hires_steps = static_cast(hires_steps / request.hires.denoising_strength); + + std::vector hires_sigmas = sd_ctx->sd->denoiser->get_sigmas( + hires_steps, + sd_ctx->sd->get_image_seq_len(request.hires.target_height, request.hires.target_width), + sd_img_gen_params->sample_params.scheduler, + sd_ctx->sd->version); + + size_t t_enc = static_cast(hires_steps * request.hires.denoising_strength); + if (t_enc >= static_cast(hires_steps)) { + t_enc = static_cast(hires_steps) - 1; + } + std::vector hires_sigma_sched(hires_sigmas.begin() + hires_steps - static_cast(t_enc) - 1, + hires_sigmas.end()); + LOG_INFO("hires fix: %d steps, denoising_strength=%.2f, sigma_sched_size=%zu", + hires_steps, + request.hires.denoising_strength, + hires_sigma_sched.size()); + + std::vector> hires_final_latents; + int64_t hires_denoise_start = ggml_time_ms(); + for (int b = 0; b < (int)final_latents.size(); b++) { + int64_t cur_seed = request.seed + b; + sd_ctx->sd->rng->manual_seed(cur_seed); + sd_ctx->sd->sampler_rng->manual_seed(cur_seed); + + sd::Tensor upscaled = upscale_hires_latent(sd_ctx, + final_latents[b], + request, + 
hires_upscaler.get()); + if (upscaled.empty()) { + if (sd_ctx->sd->free_params_immediately) { + sd_ctx->sd->diffusion_model->free_params_buffer(); + } + return nullptr; + } + + sd::Tensor noise = sd::randn_like(upscaled, sd_ctx->sd->rng); + + sd::Tensor hires_denoise_mask; + if (!latents.denoise_mask.empty()) { + std::vector mask_shape = latents.denoise_mask.shape(); + mask_shape[0] = upscaled.shape()[0]; + mask_shape[1] = upscaled.shape()[1]; + hires_denoise_mask = sd::ops::interpolate(latents.denoise_mask, + mask_shape, + sd::ops::InterpolateMode::NearestMax); + } + + int64_t hires_sample_start = ggml_time_ms(); + sd::Tensor x_0 = sd_ctx->sd->sample(sd_ctx->sd->diffusion_model, + true, + upscaled, + std::move(noise), + embeds.cond, + embeds.uncond, + embeds.img_cond, + embeds.id_cond, + latents.control_image, + request.control_strength, + request.guidance, + plan.eta, + request.shifted_timestep, + plan.sample_method, + sd_ctx->sd->is_flow_denoiser(), + hires_sigma_sched, + plan.start_merge_step, + latents.ref_latents, + request.increase_ref_index, + hires_denoise_mask, + sd::Tensor(), + 1.f, + request.cache_params); + int64_t hires_sample_end = ggml_time_ms(); + if (!x_0.empty()) { + LOG_INFO("hires sampling %d/%d completed, taking %.2fs", + b + 1, + (int)final_latents.size(), + (hires_sample_end - hires_sample_start) * 1.0f / 1000); + hires_final_latents.push_back(std::move(x_0)); + continue; + } + + LOG_ERROR("hires sampling for image %d/%d failed after %.2fs", + b + 1, + (int)final_latents.size(), + (hires_sample_end - hires_sample_start) * 1.0f / 1000); + if (sd_ctx->sd->free_params_immediately) { + sd_ctx->sd->diffusion_model->free_params_buffer(); + } + return nullptr; + } + if (sd_ctx->sd->free_params_immediately) { + sd_ctx->sd->diffusion_model->free_params_buffer(); + } + int64_t hires_denoise_end = ggml_time_ms(); + LOG_INFO("hires fix completed, taking %.2fs", (hires_denoise_end - hires_denoise_start) * 1.0f / 1000); + + final_latents = 
std::move(hires_final_latents); + } + auto result = decode_image_outputs(sd_ctx, request, final_latents); if (result == nullptr) { return nullptr; diff --git a/src/t5.hpp b/src/t5.hpp index f64d0b6d7..71545e522 100644 --- a/src/t5.hpp +++ b/src/t5.hpp @@ -1,4 +1,4 @@ -๏ปฟ#ifndef __T5_HPP__ +#ifndef __T5_HPP__ #define __T5_HPP__ #include @@ -10,452 +10,9 @@ #include #include -#include "darts.h" #include "ggml_extend.hpp" -#include "json.hpp" #include "model.h" -#include "vocab/vocab.h" - -// Port from: https://github.com/google/sentencepiece/blob/master/src/unigram_model.h -// and https://github.com/google/sentencepiece/blob/master/src/unigram_model.h. -// Original License: https://github.com/google/sentencepiece/blob/master/LICENSE -// -// Since tokenization is not the bottleneck in SD, performance was not a major consideration -// during the migration. -class MetaspacePreTokenizer { -private: - std::string replacement; - bool add_prefix_space; - -public: - MetaspacePreTokenizer(const std::string replacement = " ", bool add_prefix_space = true) - : replacement(replacement), add_prefix_space(add_prefix_space) {} - - std::string tokenize(const std::string& input) const { - std::string tokens; - std::stringstream ss(input); - - if (add_prefix_space) { - tokens += replacement; - } - - std::string token; - bool firstToken = true; - while (std::getline(ss, token, ' ')) { - if (!firstToken) - tokens += replacement + token; - else - tokens += token; - - firstToken = false; - } - - return tokens; - } -}; - -using EncodeResult = std::vector>; -class T5UniGramTokenizer { -public: - enum Status { - OK, - NO_PIECES_LOADED, - NO_ENTRY_FOUND, - BUILD_DOUBLE_ARRAY_FAILED, - PIECE_ALREADY_DEFINED, - INVLIAD_JSON - }; - -protected: - MetaspacePreTokenizer pre_tokenizer; - - // all pairs - std::vector> piece_score_pairs; - - float min_score_ = 0.0; - float max_score_ = 0.0; - std::unique_ptr trie_; - - // Maximum size of the return value of Trie, which corresponds - // to the maximum 
size of shared common prefix in the sentence pieces. - int trie_results_size_; - // unknown id. - int unk_id_ = 2; - std::string eos_token_ = ""; - int eos_id_ = 1; - int pad_id_ = 0; - // status. - Status status_ = OK; - - float kUnkPenalty = 10.0; - - std::string replacement; - bool add_prefix_space = true; - - void InitializePieces(const std::string& json_str) { - nlohmann::json data; - - try { - data = nlohmann::json::parse(json_str); - } catch (const nlohmann::json::parse_error&) { - status_ = INVLIAD_JSON; - return; - } - if (!data.contains("model")) { - status_ = INVLIAD_JSON; - return; - } - nlohmann::json model = data["model"]; - if (!model.contains("vocab")) { - status_ = INVLIAD_JSON; - return; - } - if (model.contains("unk_id")) { - unk_id_ = model["unk_id"]; - } - - replacement = data["pre_tokenizer"]["replacement"]; - add_prefix_space = data["pre_tokenizer"]["add_prefix_space"]; - - pre_tokenizer = MetaspacePreTokenizer(replacement, add_prefix_space); - - for (const auto& item : model["vocab"]) { - if (item.size() != 2 || !item[0].is_string() || !item[1].is_number_float()) { - status_ = INVLIAD_JSON; - return; - } - std::string piece = item[0]; - if (piece.empty()) { - piece = ""; - } - float score = item[1]; - piece_score_pairs.emplace_back(piece, score); - } - } - - // Builds a Trie index. - void BuildTrie(std::vector>* pieces) { - if (status_ != OK) - return; - - if (pieces->empty()) { - status_ = NO_PIECES_LOADED; - return; - } - - // sort by sentencepiece since DoubleArray::build() - // only accepts sorted strings. - sort(pieces->begin(), pieces->end()); - - // Makes key/value set for DoubleArrayTrie. - std::vector key(pieces->size()); - std::vector value(pieces->size()); - for (size_t i = 0; i < pieces->size(); ++i) { - // LOG_DEBUG("%s %d", (*pieces)[i].first.c_str(), (*pieces)[i].second); - key[i] = (*pieces)[i].first.data(); // sorted piece. 
- value[i] = (*pieces)[i].second; // vocab_id - } - - trie_ = std::unique_ptr(new Darts::DoubleArray()); - if (trie_->build(key.size(), const_cast(&key[0]), nullptr, - &value[0]) != 0) { - status_ = BUILD_DOUBLE_ARRAY_FAILED; - return; - } - - // Computes the maximum number of shared prefixes in the trie. - const int kMaxTrieResultsSize = 1024; - std::vector results( - kMaxTrieResultsSize); - trie_results_size_ = 0; - for (const auto& p : *pieces) { - const size_t num_nodes = trie_->commonPrefixSearch( - p.first.data(), results.data(), results.size(), p.first.size()); - trie_results_size_ = std::max(trie_results_size_, static_cast(num_nodes)); - } - - if (trie_results_size_ == 0) - status_ = NO_ENTRY_FOUND; - } - - // Non-virtual (inlined) implementation for faster execution. - inline float GetScoreInlined(int id) const { - return piece_score_pairs[id].second; - } - - inline bool IsUnusedInlined(int id) const { - return false; // TODO - } - - inline bool IsUserDefinedInlined(int id) const { - return false; // TODO - } - - inline size_t OneCharLen(const char* src) const { - return "\1\1\1\1\1\1\1\1\1\1\1\1\2\2\3\4"[(*src & 0xFF) >> 4]; - } - - // The optimized Viterbi encode. - // Main differences from the original function: - // 1. Memorizes the best path at each postion so far, - // 2. No need to store the Lattice nodes, - // 3. Works in utf-8 directly, - // 4. Defines a new struct with fewer fields than Lattice, - // 5. Does not depend on `class Lattice` nor call `SetSentence()`, - // `PopulateNodes()`, or `Viterbi()`. It does everything in one function. - // For detailed explanations please see the comments inside the function body. - EncodeResult EncodeOptimized(const std::string& normalized) const { - // An optimized Viterbi algorithm for unigram language models. Benchmarking - // results show that it generates almost identical outputs and achieves 2.1x - // speedup on average for 102 languages compared to the original - // implementation. 
It's based on the following three ideas: - // - // 1. Because it uses the *unigram* model: - // best_score(x1, x2, ... xt) = best_score(x1, x2, ... x{t-1}) + score(xt) - // Deciding the best path (and score) can be decoupled into two isolated - // terms: (a) the best path ended before the last token `best_score(x1, x2, ...)` - // x{t-1})`, and (b) the last token and its `score(xt)`. The two terms are - // not related to each other at all. - // - // Therefore, we can compute once and store the *best_path ending at - // each character position*. In this way, when we know best_path_ends_at[M], - // we can reuse it to compute all the best_path_ends_at_[...] where the last - // token starts at the same character position M. - // - // This improves the time complexity from O(n*k*k) to O(n*k) because it - // eliminates the extra loop of recomputing the best path ending at the same - // position, where n is the input length and k is the maximum number of tokens - // that can be recognized starting at each position. - // - // 2. Again, because it uses the *unigram* model, we don't need to actually - // store the lattice nodes. We still recognize all the tokens and lattice - // nodes from the input, but along identifying them, we use and discard them - // on the fly. There is no need to actually store them for best path Viterbi - // decoding. The only thing we need to store is the best_path ending at - // each character position. - // - // This improvement reduces the things needed to store in memory from O(n*k) - // to O(n), where n is the input length and k is the maximum number of tokens - // that can be recognized starting at each position. - // - // It also avoids the need of dynamic-size lattice node pool, because the - // number of things to store is fixed as n. - // - // 3. SentencePiece is designed to work with unicode, taking utf-8 encoding - // inputs. In the original implementation, the lattice positions are based on - // unicode positions. 
A mapping from unicode position to the utf-8 position is - // maintained to recover the utf-8 string piece. - // - // We found that it is sufficient and beneficial to directly work with utf-8 - // positions: - // - // Firstly, it saves the conversion and mapping between unicode positions and - // utf-8 positions. - // - // Secondly, it reduces the number of fields we need to maintain in the - // node/path structure. Specifically, there are 8 fields defined in - // `Lattice::Node` used by the original encoder, but here in the optimized - // encoder we only need to define 3 fields in `BestPathNode`. - - if (status() != OK || normalized.empty()) { - return {}; - } - // Represents the last node of the best path. - struct BestPathNode { - int id = -1; // The vocab id. (maybe -1 for UNK) - float best_path_score = - 0; // The total score of the best path ending at this node. - int starts_at = - -1; // The starting position (in utf-8) of this node. The entire best - // path can be constructed by backtracking along this link. - }; - const int size = static_cast(normalized.size()); - const float unk_score = min_score() - kUnkPenalty; - // The ends are exclusive. - std::vector best_path_ends_at(size + 1); - // Generate lattice on-the-fly (not stored) and update best_path_ends_at. - int starts_at = 0; - while (starts_at < size) { - std::size_t node_pos = 0; - std::size_t key_pos = starts_at; - const auto best_path_score_till_here = - best_path_ends_at[starts_at].best_path_score; - bool has_single_node = false; - const int mblen = - std::min(static_cast(OneCharLen(normalized.data() + starts_at)), - size - starts_at); - while (key_pos < size) { - const int ret = - trie_->traverse(normalized.data(), node_pos, key_pos, key_pos + 1); - if (ret == -2) - break; - if (ret >= 0) { - if (IsUnusedInlined(ret)) - continue; - // Update the best path node. 
- auto& target_node = best_path_ends_at[key_pos]; - const auto length = (key_pos - starts_at); - // User defined symbol receives extra bonus to always be selected. - const auto score = IsUserDefinedInlined(ret) - ? (length * max_score_ - 0.1) - : GetScoreInlined(ret); - const auto candidate_best_path_score = - score + best_path_score_till_here; - if (target_node.starts_at == -1 || - candidate_best_path_score > target_node.best_path_score) { - target_node.best_path_score = static_cast(candidate_best_path_score); - target_node.starts_at = starts_at; - target_node.id = ret; - } - if (!has_single_node && length == mblen) { - has_single_node = true; - } - } - } - if (!has_single_node) { - auto& target_node = best_path_ends_at[starts_at + mblen]; - const auto candidate_best_path_score = - unk_score + best_path_score_till_here; - if (target_node.starts_at == -1 || - candidate_best_path_score > target_node.best_path_score) { - target_node.best_path_score = candidate_best_path_score; - target_node.starts_at = starts_at; - target_node.id = unk_id_; - } - } - // Move by one unicode character. - starts_at += mblen; - } - // Backtrack to identify the best path. 
- EncodeResult results; - int ends_at = size; - while (ends_at > 0) { - const auto& node = best_path_ends_at[ends_at]; - results.emplace_back( - normalized.substr(node.starts_at, ends_at - node.starts_at), node.id); - ends_at = node.starts_at; - } - std::reverse(results.begin(), results.end()); - return results; - } - -public: - explicit T5UniGramTokenizer(bool is_umt5 = false) { - if (is_umt5) { - InitializePieces(load_umt5_tokenizer_json()); - } else { - InitializePieces(load_t5_tokenizer_json()); - } - - min_score_ = FLT_MAX; - max_score_ = FLT_MIN; - - std::vector> pieces; - for (int i = 0; i < piece_score_pairs.size(); i++) { - const auto& sp = piece_score_pairs[i]; - - min_score_ = std::min(min_score_, sp.second); - max_score_ = std::max(max_score_, sp.second); - - pieces.emplace_back(sp.first, i); - } - - BuildTrie(&pieces); - } - ~T5UniGramTokenizer(){}; - - std::string Normalize(const std::string& input) const { - // Ref: https://github.com/huggingface/tokenizers/blob/1ff56c0c70b045f0cd82da1af9ac08cd4c7a6f9f/bindings/python/py_src/tokenizers/implementations/sentencepiece_unigram.py#L29 - // TODO: nmt-nfkc - std::string normalized = std::regex_replace(input, std::regex(" {2,}"), " "); - return normalized; - } - - std::vector Encode(const std::string& input, bool append_eos_if_not_present = true) const { - std::string normalized = Normalize(input); - normalized = pre_tokenizer.tokenize(normalized); - EncodeResult result = EncodeOptimized(normalized); - if (result.size() > 0 && append_eos_if_not_present) { - auto item = result[result.size() - 1]; - if (item.first != eos_token_) { - result.emplace_back(eos_token_, eos_id_); - } - } - std::vector tokens; - for (auto item : result) { - tokens.push_back(item.second); - } - return tokens; - } - - void pad_tokens(std::vector& tokens, - std::vector& weights, - std::vector* attention_mask, - size_t max_length = 0, - bool padding = false) { - if (max_length > 0 && padding) { - size_t orig_token_num = tokens.size() - 
1; - size_t n = static_cast(std::ceil(orig_token_num * 1.0 / (max_length - 1))); - if (n == 0) { - n = 1; - } - size_t length = max_length * n; - LOG_DEBUG("token length: %llu", length); - std::vector new_tokens; - std::vector new_weights; - std::vector new_attention_mask; - int token_idx = 0; - for (int i = 0; i < length; i++) { - if (token_idx >= orig_token_num) { - break; - } - if (attention_mask != nullptr) { - new_attention_mask.push_back(0.0); - } - if (i % max_length == max_length - 1) { - new_tokens.push_back(eos_id_); - new_weights.push_back(1.0); - } else { - new_tokens.push_back(tokens[token_idx]); - new_weights.push_back(weights[token_idx]); - token_idx++; - } - } - - new_tokens.push_back(eos_id_); - new_weights.push_back(1.0); - if (attention_mask != nullptr) { - new_attention_mask.push_back(0.0); - } - - tokens = new_tokens; - weights = new_weights; - if (attention_mask != nullptr) { - *attention_mask = new_attention_mask; - } - - if (padding) { - int pad_token_id = pad_id_; - tokens.insert(tokens.end(), length - tokens.size(), pad_token_id); - weights.insert(weights.end(), length - weights.size(), 1.0); - if (attention_mask != nullptr) { - // maybe keep some padding tokens unmasked? - attention_mask->insert(attention_mask->end(), length - attention_mask->size(), -HUGE_VALF); - } - } - } - } - - // Returns the minimum score in sentence pieces. - // min_score() - 10 is used for the cost of unknown sentence. - float min_score() const { return min_score_; } - - // Returns the maximum score in sentence pieces. - // max_score() is used for the cost of user defined symbols. 
- float max_score() const { return max_score_; } - - Status status() const { return status_; } -}; +#include "tokenizers/t5_unigram_tokenizer.h" class T5LayerNorm : public UnaryBlock { protected: @@ -694,7 +251,8 @@ struct T5Stack : public GGMLBlock { ggml_tensor* x, ggml_tensor* past_bias = nullptr, ggml_tensor* attention_mask = nullptr, - ggml_tensor* relative_position_bucket = nullptr) { + ggml_tensor* relative_position_bucket = nullptr, + const std::string& graph_cut_prefix = "") { // x: [N, n_token, model_dim] for (int i = 0; i < num_layers; i++) { auto block = std::dynamic_pointer_cast(blocks["block." + std::to_string(i)]); @@ -702,6 +260,9 @@ struct T5Stack : public GGMLBlock { auto ret = block->forward(ctx, x, past_bias, attention_mask, relative_position_bucket); x = ret.first; past_bias = ret.second; + if (!graph_cut_prefix.empty()) { + sd::ggml_graph_cut::mark_graph_cut(x, graph_cut_prefix + ".block." + std::to_string(i), "x"); + } } auto final_layer_norm = std::dynamic_pointer_cast(blocks["final_layer_norm"]); @@ -748,7 +309,8 @@ struct T5 : public GGMLBlock { auto encoder = std::dynamic_pointer_cast(blocks["encoder"]); auto x = shared->forward(ctx, input_ids); - x = encoder->forward(ctx, x, past_bias, attention_mask, relative_position_bucket); + sd::ggml_graph_cut::mark_graph_cut(x, "t5.prelude", "x"); + x = encoder->forward(ctx, x, past_bias, attention_mask, relative_position_bucket, "t5"); return x; } }; @@ -937,18 +499,17 @@ struct T5Embedder { for (const auto& item : parsed_attention) { const std::string& curr_text = item.first; float curr_weight = item.second; - std::vector curr_tokens = tokenizer.Encode(curr_text, false); + std::vector curr_tokens = tokenizer.encode(curr_text); tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end()); weights.insert(weights.end(), curr_tokens.size(), curr_weight); } - int EOS_TOKEN_ID = 1; - tokens.push_back(EOS_TOKEN_ID); - weights.push_back(1.0); - std::vector attention_mask; - 
tokenizer.pad_tokens(tokens, weights, &attention_mask, max_length, padding); + tokenizer.pad_tokens(tokens, &weights, &attention_mask, padding ? max_length : 0, padding ? max_length : 100000000, padding); + for (auto& mask_value : attention_mask) { + mask_value = mask_value > 0.0f ? 0.0f : -HUGE_VALF; + } // for (int i = 0; i < tokens.size(); i++) { // std::cout << tokens[i] << ":" << weights[i] << ", "; diff --git a/src/tensor.hpp b/src/tensor.hpp index 33a2bdeaa..f45551940 100644 --- a/src/tensor.hpp +++ b/src/tensor.hpp @@ -815,8 +815,202 @@ namespace sd { namespace ops { enum class InterpolateMode { Nearest, + NearestExact, + NearestMax, + NearestMin, + NearestAvg, + Bilinear, + Bicubic, + Lanczos, }; + inline bool is_nearest_like_interpolate_mode(InterpolateMode mode) { + return mode == InterpolateMode::Nearest || + mode == InterpolateMode::NearestExact || + mode == InterpolateMode::NearestMax || + mode == InterpolateMode::NearestMin || + mode == InterpolateMode::NearestAvg; + } + + inline bool is_2d_filter_interpolate_mode(InterpolateMode mode) { + return mode == InterpolateMode::Bilinear || + mode == InterpolateMode::Bicubic || + mode == InterpolateMode::Lanczos; + } + + inline int64_t nearest_exact_interpolate_index(int64_t output_index, + int64_t input_size, + int64_t output_size) { + const double scale = static_cast(input_size) / static_cast(output_size); + const double center = (static_cast(output_index) + 0.5) * scale - 0.5; + return std::min(std::max(static_cast(std::floor(center + 0.5)), 0), input_size - 1); + } + + inline double linear_interpolate_weight(double x) { + x = std::abs(x); + return x < 1.0 ? 1.0 - x : 0.0; + } + + inline double cubic_interpolate_weight(double x) { + constexpr double a = -0.75; // Match PyTorch bicubic interpolation. 
+ x = std::abs(x); + if (x <= 1.0) { + return ((a + 2.0) * x - (a + 3.0)) * x * x + 1.0; + } + if (x < 2.0) { + return ((a * x - 5.0 * a) * x + 8.0 * a) * x - 4.0 * a; + } + return 0.0; + } + + inline double sinc(double x) { + constexpr double pi = 3.14159265358979323846; + if (std::abs(x) < 1e-12) { + return 1.0; + } + const double pix = pi * x; + return std::sin(pix) / pix; + } + + inline double lanczos_interpolate_weight(double x) { + constexpr double radius = 3.0; + x = std::abs(x); + if (x >= radius) { + return 0.0; + } + return sinc(x) * sinc(x / radius); + } + + struct InterpolateContributor { + int64_t index; + double weight; + }; + + inline std::vector> make_interpolate_contributors( + int64_t input_size, + int64_t output_size, + InterpolateMode mode, + bool antialias) { + std::vector> contributors(static_cast(output_size)); + const double scale = static_cast(input_size) / static_cast(output_size); + const double filter_scale = antialias ? std::max(1.0, scale) : 1.0; + + for (int64_t out = 0; out < output_size; ++out) { + const double center = (static_cast(out) + 0.5) * scale - 0.5; + int64_t start = 0; + int64_t end = 0; + + if (mode == InterpolateMode::Bilinear) { + const double support = filter_scale; + start = static_cast(std::ceil(center - support)); + end = static_cast(std::floor(center + support)); + } else if (mode == InterpolateMode::Bicubic) { + const double support = 2.0 * filter_scale; + start = static_cast(std::ceil(center - support)); + end = static_cast(std::floor(center + support)); + } else if (mode == InterpolateMode::Lanczos) { + const double support = 3.0 * filter_scale; + start = static_cast(std::ceil(center - support)); + end = static_cast(std::floor(center + support)); + } else { + tensor_throw_invalid_argument("Unsupported 2D filter interpolate mode: mode=" + + std::to_string(static_cast(mode))); + } + + double weight_sum = 0.0; + std::vector& axis_contributors = contributors[static_cast(out)]; + 
axis_contributors.reserve(static_cast(end - start + 1)); + + for (int64_t in = start; in <= end; ++in) { + double weight = 0.0; + if (mode == InterpolateMode::Bilinear) { + weight = linear_interpolate_weight((center - static_cast(in)) / filter_scale); + } else if (mode == InterpolateMode::Bicubic) { + weight = cubic_interpolate_weight((center - static_cast(in)) / filter_scale); + } else { + weight = lanczos_interpolate_weight((center - static_cast(in)) / filter_scale); + } + + if (weight == 0.0) { + continue; + } + + const int64_t clamped_index = std::min(std::max(in, 0), input_size - 1); + axis_contributors.push_back({clamped_index, weight}); + weight_sum += weight; + } + + if ((antialias || mode == InterpolateMode::Lanczos) && + std::abs(weight_sum) > 1e-12) { + for (auto& contributor : axis_contributors) { + contributor.weight /= weight_sum; + } + } + + if (axis_contributors.empty()) { + const int64_t nearest = std::min( + std::max(static_cast(std::floor(center + 0.5)), 0), + input_size - 1); + axis_contributors.push_back({nearest, 1.0}); + } + } + + return contributors; + } + + template + inline Tensor interpolate_2d_filter(const Tensor& input, + const std::vector& output_shape, + InterpolateMode mode, + bool antialias) { + if (input.dim() < 2) { + tensor_throw_invalid_argument("2D filter interpolate requires rank >= 2: input_shape=" + + tensor_shape_to_string(input.shape()) + ", output_shape=" + + tensor_shape_to_string(output_shape)); + } + for (size_t i = 2; i < output_shape.size(); ++i) { + if (input.shape()[i] != output_shape[i]) { + tensor_throw_invalid_argument("2D filter interpolate only supports resizing dimensions 0 and 1: input_shape=" + + tensor_shape_to_string(input.shape()) + ", output_shape=" + + tensor_shape_to_string(output_shape)); + } + } + + Tensor output(output_shape); + const int64_t input_width = input.shape()[0]; + const int64_t input_height = input.shape()[1]; + const int64_t output_width = output_shape[0]; + const int64_t output_height 
= output_shape[1]; + const int64_t input_plane = input_width * input_height; + const int64_t output_plane = output_width * output_height; + const int64_t plane_count = input.numel() / input_plane; + + auto x_contributors = make_interpolate_contributors(input_width, output_width, mode, antialias); + auto y_contributors = make_interpolate_contributors(input_height, output_height, mode, antialias); + + for (int64_t plane = 0; plane < plane_count; ++plane) { + const int64_t input_plane_offset = plane * input_plane; + const int64_t output_plane_offset = plane * output_plane; + for (int64_t y = 0; y < output_height; ++y) { + const auto& y_axis = y_contributors[static_cast(y)]; + for (int64_t x = 0; x < output_width; ++x) { + const auto& x_axis = x_contributors[static_cast(x)]; + double value = 0.0; + for (const auto& yc : y_axis) { + const int64_t input_row_offset = input_plane_offset + yc.index * input_width; + for (const auto& xc : x_axis) { + value += static_cast(input.data()[input_row_offset + xc.index]) * + xc.weight * yc.weight; + } + } + output.data()[output_plane_offset + y * output_width + x] = static_cast(value); + } + } + } + + return output; + } + inline int64_t normalize_slice_bound(int64_t index, int64_t dim_size) { if (index < 0) { index += dim_size; @@ -1011,13 +1205,20 @@ namespace sd { inline Tensor interpolate(const Tensor& input, std::vector output_shape, InterpolateMode mode = InterpolateMode::Nearest, - bool align_corners = false) { - if (mode != InterpolateMode::Nearest) { - tensor_throw_invalid_argument("Only nearest interpolate mode is implemented, got mode=" + + bool align_corners = false, + bool antialias = false) { + const bool is_nearest_like_mode = is_nearest_like_interpolate_mode(mode); + const bool is_2d_filter_mode = is_2d_filter_interpolate_mode(mode); + if (!is_nearest_like_mode && !is_2d_filter_mode) { + tensor_throw_invalid_argument("Unsupported interpolate mode: mode=" + + std::to_string(static_cast(mode))); + } + if (antialias && 
!is_2d_filter_mode) { + tensor_throw_invalid_argument("Tensor interpolate antialias requires a 2D filter mode: mode=" + std::to_string(static_cast(mode))); } if (align_corners) { - tensor_throw_invalid_argument("align_corners is not supported for nearest interpolate: input_shape=" + + tensor_throw_invalid_argument("align_corners is not supported for tensor interpolate: input_shape=" + tensor_shape_to_string(input.shape()) + ", output_shape=" + tensor_shape_to_string(output_shape)); } @@ -1044,14 +1245,126 @@ namespace sd { } } + if (is_2d_filter_mode) { + return interpolate_2d_filter(input, output_shape, mode, antialias); + } + + bool has_downsampling = false; + for (int64_t i = 0; i < input.dim(); ++i) { + if (input.shape()[i] > output_shape[i]) { + has_downsampling = true; + break; + } + } + Tensor output(std::move(output_shape)); - for (int64_t flat = 0; flat < output.numel(); ++flat) { - std::vector output_coord = tensor_unravel_index(flat, output.shape()); - std::vector input_coord(static_cast(input.dim()), 0); - for (size_t i = 0; i < static_cast(input.dim()); ++i) { - input_coord[i] = output_coord[i] * input.shape()[i] / output.shape()[i]; + if (mode == InterpolateMode::Nearest || + mode == InterpolateMode::NearestExact || + !has_downsampling) { + for (int64_t flat = 0; flat < output.numel(); ++flat) { + std::vector output_coord = tensor_unravel_index(flat, output.shape()); + std::vector input_coord(static_cast(input.dim()), 0); + for (size_t i = 0; i < static_cast(input.dim()); ++i) { + if (mode == InterpolateMode::NearestExact) { + input_coord[i] = nearest_exact_interpolate_index(output_coord[i], + input.shape()[i], + output.shape()[i]); + } else { + input_coord[i] = output_coord[i] * input.shape()[i] / output.shape()[i]; + } + } + output[flat] = input.index(input_coord); + } + + return output; + } + + auto init_reduction = [&]() -> T { + switch (mode) { + case InterpolateMode::NearestMax: + return std::numeric_limits::lowest(); + case 
InterpolateMode::NearestMin: + return std::numeric_limits::max(); + case InterpolateMode::NearestAvg: + return T(0); + case InterpolateMode::Nearest: + return T(0); + case InterpolateMode::NearestExact: + return T(0); + case InterpolateMode::Bilinear: + case InterpolateMode::Bicubic: + case InterpolateMode::Lanczos: + break; + } + + tensor_throw_invalid_argument("Unsupported interpolate mode: mode=" + + std::to_string(static_cast(mode))); + }; + + auto reduce_value = [&](T& acc, const T& sample) { + switch (mode) { + case InterpolateMode::NearestMax: + acc = std::max(acc, sample); + break; + case InterpolateMode::NearestMin: + acc = std::min(acc, sample); + break; + case InterpolateMode::NearestAvg: + acc += sample; + break; + case InterpolateMode::Nearest: + break; + case InterpolateMode::NearestExact: + break; + case InterpolateMode::Bilinear: + case InterpolateMode::Bicubic: + case InterpolateMode::Lanczos: + break; + } + }; + + // Reduction modes only differ from nearest mode when downsampling. 
+ for (int64_t flat_out = 0; flat_out < output.numel(); ++flat_out) { + std::vector output_coord = tensor_unravel_index(flat_out, output.shape()); + + std::vector input_start(output.dim(), 0); + std::vector input_end(output.dim(), 0); + + for (size_t i = 0; i < static_cast(output.dim()); ++i) { + const int64_t input_dim = input.shape()[i]; + const int64_t output_dim = output.shape()[i]; + + input_start[i] = std::max(int64_t(0), static_cast(output_coord[i] * input_dim / output_dim)); + input_end[i] = std::min(input_dim, ((output_coord[i] + 1) * input_dim + output_dim - 1) / output_dim); + } + + T value = init_reduction(); + bool done_window = false; + std::vector current_in_coord = input_start; + + while (!done_window) { + reduce_value(value, input.index(current_in_coord)); + + for (int d = static_cast(output.dim()) - 1; d >= 0; --d) { + if (++current_in_coord[d] < input_end[d]) { + break; + } + current_in_coord[d] = input_start[d]; + if (d == 0) { + done_window = true; + } + } + } + + if (mode == InterpolateMode::NearestAvg) { + int64_t window_size = 1; + for (size_t i = 0; i < static_cast(output.dim()); ++i) { + window_size *= (input_end[i] - input_start[i]); + } + value /= static_cast(window_size); } - output[flat] = input.index(input_coord); + + output[flat_out] = value; } return output; @@ -1062,13 +1375,20 @@ namespace sd { const std::optional>& size, const std::optional>& scale_factor, InterpolateMode mode = InterpolateMode::Nearest, - bool align_corners = false) { - if (mode != InterpolateMode::Nearest) { - tensor_throw_invalid_argument("Only nearest interpolate mode is implemented, got mode=" + + bool align_corners = false, + bool antialias = false) { + const bool is_nearest_like_mode = is_nearest_like_interpolate_mode(mode); + const bool is_2d_filter_mode = is_2d_filter_interpolate_mode(mode); + if (!is_nearest_like_mode && !is_2d_filter_mode) { + tensor_throw_invalid_argument("Unsupported interpolate mode: mode=" + + std::to_string(static_cast(mode))); + 
} + if (antialias && !is_2d_filter_mode) { + tensor_throw_invalid_argument("Tensor interpolate antialias requires a 2D filter mode: mode=" + std::to_string(static_cast(mode))); } if (align_corners) { - tensor_throw_invalid_argument("align_corners is not supported for nearest interpolate: input_shape=" + + tensor_throw_invalid_argument("align_corners is not supported for tensor interpolate: input_shape=" + tensor_shape_to_string(input.shape())); } if (size.has_value() == scale_factor.has_value()) { @@ -1112,7 +1432,7 @@ namespace sd { } } - return interpolate(input, std::move(output_shape), mode, align_corners); + return interpolate(input, std::move(output_shape), mode, align_corners, antialias); } template @@ -1120,12 +1440,88 @@ namespace sd { const std::optional>& size, double scale_factor, InterpolateMode mode = InterpolateMode::Nearest, - bool align_corners = false) { + bool align_corners = false, + bool antialias = false) { return interpolate(input, size, std::vector(size.has_value() ? 
size->size() : input.dim(), scale_factor), mode, - align_corners); + align_corners, + antialias); + } + + template + inline Tensor max_pool_2d(const Tensor& input, + std::vector kernel_size, + std::vector stride, + std::vector padding) { + if (input.dim() < 2) { + tensor_throw_invalid_argument("Tensor max_pool_2d requires input_dim >= 2: input_dim=" + + std::to_string(input.dim()) + ", input_shape=" + + tensor_shape_to_string(input.shape())); + } + if (kernel_size.size() != 2 || stride.size() != 2 || padding.size() != 2) { + tensor_throw_invalid_argument("Tensor max_pool_2d requires kernel_size, stride, and padding to have length 2"); + } + for (size_t i = 0; i < 2; ++i) { + if (kernel_size[i] <= 0) { + tensor_throw_invalid_argument("Tensor max_pool_2d kernel_size must be positive: kernel_size=" + + tensor_shape_to_string(kernel_size)); + } + if (stride[i] <= 0) { + tensor_throw_invalid_argument("Tensor max_pool_2d stride must be positive: stride=" + + tensor_shape_to_string(stride)); + } + if (padding[i] < 0) { + tensor_throw_invalid_argument("Tensor max_pool_2d padding must be non-negative: padding=" + + tensor_shape_to_string(padding)); + } + } + + const int64_t in_height = input.shape()[0]; + const int64_t in_width = input.shape()[1]; + + const int64_t out_height = (in_height + 2 * padding[0] - kernel_size[0]) / stride[0] + 1; + const int64_t out_width = (in_width + 2 * padding[1] - kernel_size[1]) / stride[1] + 1; + + if (out_height <= 0 || out_width <= 0) { + tensor_throw_invalid_argument("max_pool_2d results in invalid output dimensions: " + + std::to_string(out_height) + "x" + std::to_string(out_width)); + } + + std::vector output_shape = input.shape(); + output_shape[0] = out_height; + output_shape[1] = out_width; + + Tensor output(std::move(output_shape)); + + for (int64_t flat_out = 0; flat_out < output.numel(); ++flat_out) { + std::vector output_coord = tensor_unravel_index(flat_out, output.shape()); + std::vector input_coord = output_coord; + + const 
int64_t oh = output_coord[0]; + const int64_t ow = output_coord[1]; + + T max_val = std::numeric_limits::lowest(); + bool has_valid_input = false; + + for (int64_t kh = 0; kh < kernel_size[0]; ++kh) { + for (int64_t kw = 0; kw < kernel_size[1]; ++kw) { + const int64_t ih = oh * stride[0] + kh - padding[0]; + const int64_t iw = ow * stride[1] + kw - padding[1]; + + if (ih >= 0 && ih < in_height && iw >= 0 && iw < in_width) { + input_coord[0] = ih; + input_coord[1] = iw; + max_val = std::max(max_val, input.index(input_coord)); + has_valid_input = true; + } + } + } + + output[flat_out] = has_valid_input ? max_val : T(0); + } + return output; } template diff --git a/src/tokenizers/bpe_tokenizer.cpp b/src/tokenizers/bpe_tokenizer.cpp new file mode 100644 index 000000000..1ad5d9428 --- /dev/null +++ b/src/tokenizers/bpe_tokenizer.cpp @@ -0,0 +1,189 @@ +#include "bpe_tokenizer.h" + +#include +#include + +#include "tokenize_util.h" +#include "util.h" + +std::vector> BPETokenizer::bytes_to_unicode() { + std::vector> byte_unicode_pairs; + std::set byte_set; + for (int b = static_cast('!'); b <= static_cast('~'); ++b) { + byte_set.insert(b); + byte_unicode_pairs.push_back(std::pair(b, unicode_value_to_utf32(b))); + } + for (int b = 161; b <= 172; ++b) { + byte_set.insert(b); + byte_unicode_pairs.push_back(std::pair(b, unicode_value_to_utf32(b))); + } + for (int b = 174; b <= 255; ++b) { + byte_set.insert(b); + byte_unicode_pairs.push_back(std::pair(b, unicode_value_to_utf32(b))); + } + int n = 0; + for (int b = 0; b < 256; ++b) { + if (byte_set.find(b) == byte_set.end()) { + byte_unicode_pairs.push_back(std::pair(b, unicode_value_to_utf32(n + 256))); + ++n; + } + } + return byte_unicode_pairs; +} + +std::vector BPETokenizer::token_split(const std::string& text) const { + return ::token_split(text); +} + +std::vector BPETokenizer::split_utf32(const std::string& text, char32_t delimiter) { + std::vector result; + size_t start = 0; + size_t pos = 0; + std::u32string utf32_text = 
utf8_to_utf32(text); + while ((pos = utf32_text.find(delimiter, start)) != std::u32string::npos) { + result.push_back(utf32_text.substr(start, pos - start)); + start = pos + 1; + } + return result; +} + +static std::set> get_pairs(const std::vector& subwords) { + std::set> pairs; + if (subwords.empty()) { + return pairs; + } + + std::u32string prev_subword = subwords[0]; + for (int i = 1; i < static_cast(subwords.size()); i++) { + std::u32string subword = subwords[i]; + std::pair pair(prev_subword, subword); + pairs.insert(pair); + prev_subword = subword; + } + return pairs; +} + +std::vector BPETokenizer::bpe(const std::u32string& token) const { + std::vector word; + + for (int i = 0; i < static_cast(token.size()) - 1; i++) { + word.emplace_back(1, token[i]); + } + word.push_back(token.substr(token.size() - 1) + utf8_to_utf32(end_of_word_suffix)); + + std::set> pairs = get_pairs(word); + + if (pairs.empty()) { + return {token + utf8_to_utf32(end_of_word_suffix)}; + } + + while (true) { + auto min_pair_iter = std::min_element(pairs.begin(), + pairs.end(), + [&](const std::pair& a, + const std::pair& b) { + if (bpe_ranks.find(a) == bpe_ranks.end()) { + return false; + } else if (bpe_ranks.find(b) == bpe_ranks.end()) { + return true; + } + return bpe_ranks.at(a) < bpe_ranks.at(b); + }); + + const std::pair& bigram = *min_pair_iter; + + if (bpe_ranks.find(bigram) == bpe_ranks.end()) { + break; + } + + std::u32string first = bigram.first; + std::u32string second = bigram.second; + std::vector new_word; + int32_t i = 0; + + while (i < static_cast(word.size())) { + auto it = std::find(word.begin() + i, word.end(), first); + if (it == word.end()) { + new_word.insert(new_word.end(), word.begin() + i, word.end()); + break; + } + new_word.insert(new_word.end(), word.begin() + i, it); + i = static_cast(std::distance(word.begin(), it)); + + if (word[i] == first && i < static_cast(word.size()) - 1 && word[i + 1] == second) { + new_word.push_back(first + second); + i += 2; + } 
else { + new_word.push_back(word[i]); + i += 1; + } + } + + word = new_word; + + if (word.size() == 1) { + break; + } + pairs = get_pairs(word); + } + + return word; +} + +std::vector BPETokenizer::encode(const std::string& text, on_new_token_cb_t on_new_token_cb) { + std::string normalized_text = normalize(text); + std::vector bpe_tokens; + std::vector token_strs; + + auto splited_texts = split_with_special_tokens(normalized_text, special_tokens); + + for (auto& splited_text : splited_texts) { + if (is_special_token(splited_text)) { + if (on_new_token_cb != nullptr) { + bool skip = on_new_token_cb(splited_text, bpe_tokens); + if (skip) { + token_strs.push_back(splited_text); + continue; + } + } + bpe_tokens.push_back(encoder[utf8_to_utf32(splited_text)]); + token_strs.push_back(splited_text); + continue; + } + auto tokens = token_split(splited_text); + for (auto& token : tokens) { + if (on_new_token_cb != nullptr) { + bool skip = on_new_token_cb(token, bpe_tokens); + if (skip) { + token_strs.push_back(splited_text); + continue; + } + } + + std::string token_str = token; + std::u32string utf32_token; + for (int i = 0; i < static_cast(token_str.length()); i++) { + unsigned char b = token_str[i]; + utf32_token += byte_encoder[b]; + } + auto bpe_strs = bpe(utf32_token); + for (auto bpe_str : bpe_strs) { + bpe_tokens.push_back(encoder[bpe_str]); + token_strs.push_back(utf32_to_utf8(bpe_str)); + } + } + } + + std::stringstream ss; + ss << "["; + for (auto token : token_strs) { + ss << "\"" << token << "\", "; + } + ss << "]"; + LOG_DEBUG("split prompt \"%s\" to tokens %s", text.c_str(), ss.str().c_str()); + return bpe_tokens; +} + +std::string BPETokenizer::decode_token(int token_id) const { + return utf32_to_utf8(decoder.at(token_id)); +} diff --git a/src/tokenizers/bpe_tokenizer.h b/src/tokenizers/bpe_tokenizer.h new file mode 100644 index 000000000..4dca4e97a --- /dev/null +++ b/src/tokenizers/bpe_tokenizer.h @@ -0,0 +1,40 @@ +#ifndef 
__SD_TOKENIZERS_BPE_TOKENIZER_H__ +#define __SD_TOKENIZERS_BPE_TOKENIZER_H__ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tokenizer.h" + +class BPETokenizer : public Tokenizer { +protected: + std::map byte_encoder; + std::map byte_decoder; + std::map encoder; + std::map decoder; + std::map, int> bpe_ranks; + int encoder_len = 0; + int bpe_len = 0; + +protected: + static std::vector> bytes_to_unicode(); + static std::vector split_utf32(const std::string& text, char32_t delimiter = U'\n'); + virtual std::vector token_split(const std::string& text) const; + std::vector bpe(const std::u32string& token) const; + std::string decode_token(int token_id) const override; + +public: + BPETokenizer() = default; + virtual ~BPETokenizer() = default; + + std::vector encode(const std::string& text, on_new_token_cb_t on_new_token_cb = nullptr) override; +}; + +#endif // __SD_TOKENIZERS_BPE_TOKENIZER_H__ diff --git a/src/tokenizers/clip_tokenizer.cpp b/src/tokenizers/clip_tokenizer.cpp new file mode 100644 index 000000000..70d637724 --- /dev/null +++ b/src/tokenizers/clip_tokenizer.cpp @@ -0,0 +1,116 @@ +#include "clip_tokenizer.h" + +#include +#include +#include +#include +#include + +#include "ggml.h" +#include "tokenize_util.h" +#include "util.h" +#include "vocab/vocab.h" + +CLIPTokenizer::CLIPTokenizer(int pad_token_id, const std::string& merges_utf8_str) { + UNK_TOKEN = "<|endoftext|>"; + BOS_TOKEN = "<|startoftext|>"; + EOS_TOKEN = "<|endoftext|>"; + PAD_TOKEN = "<|endoftext|>"; + + UNK_TOKEN_ID = 49407; + BOS_TOKEN_ID = 49406; + EOS_TOKEN_ID = 49407; + PAD_TOKEN_ID = pad_token_id; + + end_of_word_suffix = ""; + add_bos_token = true; + add_eos_token = true; + + if (merges_utf8_str.size() > 0) { + load_from_merges(merges_utf8_str); + } else { + load_from_merges(load_clip_merges()); + } + add_special_token("<|startoftext|>"); + add_special_token("<|endoftext|>"); +} + +void CLIPTokenizer::load_from_merges(const 
std::string& merges_utf8_str) { + auto byte_unicode_pairs = bytes_to_unicode(); + byte_encoder = std::map(byte_unicode_pairs.begin(), byte_unicode_pairs.end()); + for (auto& pair : byte_unicode_pairs) { + byte_decoder[pair.second] = pair.first; + } + + std::vector merges = split_utf32(merges_utf8_str); + GGML_ASSERT(merges.size() == 48895); + merges = std::vector(merges.begin() + 1, merges.end()); + std::vector> merge_pairs; + for (const auto& merge : merges) { + size_t space_pos = merge.find(' '); + merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1)); + } + std::vector vocab; + for (const auto& pair : byte_unicode_pairs) { + vocab.push_back(pair.second); + } + for (const auto& pair : byte_unicode_pairs) { + vocab.push_back(pair.second + utf8_to_utf32("")); + } + for (const auto& merge : merge_pairs) { + vocab.push_back(merge.first + merge.second); + } + vocab.push_back(utf8_to_utf32("<|startoftext|>")); + vocab.push_back(utf8_to_utf32("<|endoftext|>")); + LOG_DEBUG("vocab size: %zu", vocab.size()); + int i = 0; + for (const auto& token : vocab) { + encoder[token] = i; + decoder[i] = token; + i++; + } + encoder_len = i; + + int rank = 0; + for (const auto& merge : merge_pairs) { + bpe_ranks[merge] = rank++; + } + bpe_len = rank; +} + +static std::string strip(const std::string& str) { + std::string::size_type start = str.find_first_not_of(" \t\n\r\v\f"); + std::string::size_type end = str.find_last_not_of(" \t\n\r\v\f"); + + if (start == std::string::npos) { + return ""; + } + + return str.substr(start, end - start + 1); +} + +static std::string whitespace_clean(const std::string& text) { + auto result = std::regex_replace(text, std::regex(R"(\s+)"), " "); + result = strip(result); + return result; +} + +std::string CLIPTokenizer::normalize(const std::string& text) const { + auto normalized_text = whitespace_clean(text); + std::transform(normalized_text.begin(), normalized_text.end(), normalized_text.begin(), [](unsigned char c) { 
return static_cast(std::tolower(c)); }); + return normalized_text; +} + +std::vector CLIPTokenizer::token_split(const std::string& text) const { + std::regex clip_pat(R"('s|'t|'re|'ve|'m|'ll|'d|[[:alpha:]]+|[[:digit:]]|[^[:space:][:alpha:][:digit:]]+)", + std::regex::icase); + std::sregex_iterator iter(text.begin(), text.end(), clip_pat); + std::sregex_iterator end; + + std::vector result; + for (; iter != end; ++iter) { + result.emplace_back(iter->str()); + } + + return result; +} diff --git a/src/tokenizers/clip_tokenizer.h b/src/tokenizers/clip_tokenizer.h new file mode 100644 index 000000000..d4d71ae77 --- /dev/null +++ b/src/tokenizers/clip_tokenizer.h @@ -0,0 +1,20 @@ +#ifndef __SD_TOKENIZERS_CLIP_TOKENIZER_H__ +#define __SD_TOKENIZERS_CLIP_TOKENIZER_H__ + +#include +#include +#include + +#include "bpe_tokenizer.h" + +class CLIPTokenizer : public BPETokenizer { +protected: + void load_from_merges(const std::string& merges_utf8_str); + std::string normalize(const std::string& text) const override; + std::vector token_split(const std::string& text) const override; + +public: + explicit CLIPTokenizer(int pad_token_id = 49407, const std::string& merges_utf8_str = ""); +}; + +#endif // __SD_TOKENIZERS_CLIP_TOKENIZER_H__ diff --git a/src/tokenizers/mistral_tokenizer.cpp b/src/tokenizers/mistral_tokenizer.cpp new file mode 100644 index 000000000..9b0624e3a --- /dev/null +++ b/src/tokenizers/mistral_tokenizer.cpp @@ -0,0 +1,89 @@ +#include "mistral_tokenizer.h" + +#include "ggml.h" +#include "json.hpp" +#include "util.h" +#include "vocab/vocab.h" + +void MistralTokenizer::load_from_merges(const std::string& merges_utf8_str, const std::string& vocab_utf8_str) { + nlohmann::json vocab; + + try { + vocab = nlohmann::json::parse(vocab_utf8_str); + } catch (const nlohmann::json::parse_error&) { + GGML_ABORT("invalid vocab json str"); + } + for (const auto& [key, value] : vocab.items()) { + std::u32string token = utf8_to_utf32(key); + int i = value; + encoder[token] = i; + 
decoder[i] = token; + } + encoder_len = static_cast(vocab.size()); + LOG_DEBUG("vocab size: %d", encoder_len); + + auto byte_unicode_pairs = bytes_to_unicode(); + byte_encoder = std::map(byte_unicode_pairs.begin(), byte_unicode_pairs.end()); + for (auto& pair : byte_unicode_pairs) { + byte_decoder[pair.second] = pair.first; + } + std::vector merges = split_utf32(merges_utf8_str); + LOG_DEBUG("merges size %zu", merges.size()); + std::vector> merge_pairs; + for (const auto& merge : merges) { + size_t space_pos = merge.find(' '); + merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1)); + } + + int rank = 0; + for (const auto& merge : merge_pairs) { + bpe_ranks[merge] = rank++; + } + bpe_len = rank; +} + +MistralTokenizer::MistralTokenizer(const std::string& merges_utf8_str, const std::string& vocab_utf8_str) { + add_bos_token = true; + + UNK_TOKEN = ""; + BOS_TOKEN = ""; + EOS_TOKEN = ""; + PAD_TOKEN = ""; + + UNK_TOKEN_ID = 0; + BOS_TOKEN_ID = 1; + EOS_TOKEN_ID = 2; + PAD_TOKEN_ID = 11; + + special_tokens = { + "", + "", + "", + "[INST]", + "[/INST]", + "[AVAILABLE_TOOLS]", + "[/AVAILABLE_TOOLS]", + "[TOOL_RESULTS]", + "[/TOOL_RESULTS]", + "[TOOL_CALLS]", + "[IMG]", + "", + "[IMG_BREAK]", + "[IMG_END]", + "[PREFIX]", + "[MIDDLE]", + "[SUFFIX]", + "[SYSTEM_PROMPT]", + "[/SYSTEM_PROMPT]", + "[TOOL_CONTENT]", + }; + for (int i = 20; i < 1000; i++) { + special_tokens.push_back(""); + } + + if (merges_utf8_str.size() > 0 && vocab_utf8_str.size() > 0) { + load_from_merges(merges_utf8_str, vocab_utf8_str); + } else { + load_from_merges(load_mistral_merges(), load_mistral_vocab_json()); + } +} diff --git a/src/tokenizers/mistral_tokenizer.h b/src/tokenizers/mistral_tokenizer.h new file mode 100644 index 000000000..6749f56f1 --- /dev/null +++ b/src/tokenizers/mistral_tokenizer.h @@ -0,0 +1,16 @@ +#ifndef __SD_TOKENIZERS_MISTRAL_TOKENIZER_H__ +#define __SD_TOKENIZERS_MISTRAL_TOKENIZER_H__ + +#include + +#include "bpe_tokenizer.h" + +class 
MistralTokenizer : public BPETokenizer { +protected: + void load_from_merges(const std::string& merges_utf8_str, const std::string& vocab_utf8_str); + +public: + explicit MistralTokenizer(const std::string& merges_utf8_str = "", const std::string& vocab_utf8_str = ""); +}; + +#endif // __SD_TOKENIZERS_MISTRAL_TOKENIZER_H__ diff --git a/src/tokenizers/qwen2_tokenizer.cpp b/src/tokenizers/qwen2_tokenizer.cpp new file mode 100644 index 000000000..9929ea387 --- /dev/null +++ b/src/tokenizers/qwen2_tokenizer.cpp @@ -0,0 +1,91 @@ +#include "qwen2_tokenizer.h" + +#include "util.h" +#include "vocab/vocab.h" + +void Qwen2Tokenizer::load_from_merges(const std::string& merges_utf8_str) { + auto byte_unicode_pairs = bytes_to_unicode(); + byte_encoder = std::map(byte_unicode_pairs.begin(), byte_unicode_pairs.end()); + for (auto& pair : byte_unicode_pairs) { + byte_decoder[pair.second] = pair.first; + } + + std::vector merges = split_utf32(merges_utf8_str); + LOG_DEBUG("merges size %zu", merges.size()); + std::vector> merge_pairs; + for (const auto& merge : merges) { + size_t space_pos = merge.find(' '); + merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1)); + } + + std::vector tokens; + for (const auto& pair : byte_unicode_pairs) { + tokens.push_back(pair.second); + } + for (const auto& merge : merge_pairs) { + tokens.push_back(merge.first + merge.second); + } + for (auto& special_token : special_tokens) { + tokens.push_back(utf8_to_utf32(special_token)); + } + + int i = 0; + for (const auto& token : tokens) { + encoder[token] = i; + decoder[i] = token; + i++; + } + encoder_len = i; + LOG_DEBUG("vocab size: %d", encoder_len); + + int rank = 0; + for (const auto& merge : merge_pairs) { + bpe_ranks[merge] = rank++; + } + bpe_len = rank; +} + +Qwen2Tokenizer::Qwen2Tokenizer(const std::string& merges_utf8_str) { + UNK_TOKEN = "<|endoftext|>"; + EOS_TOKEN = "<|endoftext|>"; + PAD_TOKEN = "<|endoftext|>"; + + UNK_TOKEN_ID = 151643; + EOS_TOKEN_ID = 
151643; + PAD_TOKEN_ID = 151643; + + special_tokens = { + "<|endoftext|>", + "<|im_start|>", + "<|im_end|>", + "<|object_ref_start|>", + "<|object_ref_end|>", + "<|box_start|>", + "<|box_end|>", + "<|quad_start|>", + "<|quad_end|>", + "<|vision_start|>", + "<|vision_end|>", + "<|vision_pad|>", + "<|image_pad|>", + "<|video_pad|>", + "", + "", + "<|fim_prefix|>", + "<|fim_middle|>", + "<|fim_suffix|>", + "<|fim_pad|>", + "<|repo_name|>", + "<|file_sep|>", + "", + "", + "", + "", + }; + + if (merges_utf8_str.size() > 0) { + load_from_merges(merges_utf8_str); + } else { + load_from_merges(load_qwen2_merges()); + } +} diff --git a/src/tokenizers/qwen2_tokenizer.h b/src/tokenizers/qwen2_tokenizer.h new file mode 100644 index 000000000..04e92c2c3 --- /dev/null +++ b/src/tokenizers/qwen2_tokenizer.h @@ -0,0 +1,16 @@ +#ifndef __SD_TOKENIZERS_QWEN2_TOKENIZER_H__ +#define __SD_TOKENIZERS_QWEN2_TOKENIZER_H__ + +#include + +#include "bpe_tokenizer.h" + +class Qwen2Tokenizer : public BPETokenizer { +protected: + void load_from_merges(const std::string& merges_utf8_str); + +public: + explicit Qwen2Tokenizer(const std::string& merges_utf8_str = ""); +}; + +#endif // __SD_TOKENIZERS_QWEN2_TOKENIZER_H__ diff --git a/src/tokenizers/t5_unigram_tokenizer.cpp b/src/tokenizers/t5_unigram_tokenizer.cpp new file mode 100644 index 000000000..8ed4df539 --- /dev/null +++ b/src/tokenizers/t5_unigram_tokenizer.cpp @@ -0,0 +1,339 @@ +#include "t5_unigram_tokenizer.h" + +#include +#include +#include +#include +#include + +#include "json.hpp" +#include "tokenize_util.h" +#include "util.h" +#include "vocab/vocab.h" + +// Port from: https://github.com/google/sentencepiece/blob/master/src/unigram_model.h +// and https://github.com/google/sentencepiece/blob/master/src/unigram_model.h. +// Original License: https://github.com/google/sentencepiece/blob/master/LICENSE +// +// Since tokenization is not the bottleneck in SD, performance was not a major consideration +// during the migration. 
+ +MetaspacePreTokenizer::MetaspacePreTokenizer(const std::string replacement, bool add_prefix_space) + : replacement(replacement), add_prefix_space(add_prefix_space) {} + +std::string MetaspacePreTokenizer::tokenize(const std::string& input) const { + std::string tokens; + std::stringstream ss(input); + + if (add_prefix_space) { + tokens += replacement; + } + + std::string token; + bool first_token = true; + while (std::getline(ss, token, ' ')) { + if (!first_token) { + tokens += replacement + token; + } else { + tokens += token; + } + + first_token = false; + } + + return tokens; +} + +void T5UniGramTokenizer::InitializePieces(const std::string& json_str) { + nlohmann::json data; + + try { + data = nlohmann::json::parse(json_str); + } catch (const nlohmann::json::parse_error&) { + status_ = INVLIAD_JSON; + return; + } + if (!data.contains("model")) { + status_ = INVLIAD_JSON; + return; + } + nlohmann::json model = data["model"]; + if (!model.contains("vocab")) { + status_ = INVLIAD_JSON; + return; + } + if (model.contains("unk_id")) { + UNK_TOKEN_ID = model["unk_id"]; + } + + replacement = data["pre_tokenizer"]["replacement"]; + add_prefix_space = data["pre_tokenizer"]["add_prefix_space"]; + + pre_tokenizer = MetaspacePreTokenizer(replacement, add_prefix_space); + + for (const auto& item : model["vocab"]) { + if (item.size() != 2 || !item[0].is_string() || !item[1].is_number_float()) { + status_ = INVLIAD_JSON; + return; + } + std::string piece = item[0]; + if (piece.empty()) { + piece = ""; + } + float score = item[1]; + piece_score_pairs.emplace_back(piece, score); + } +} + +void T5UniGramTokenizer::BuildTrie(std::vector>* pieces) { + if (status_ != OK) { + return; + } + + if (pieces->empty()) { + status_ = NO_PIECES_LOADED; + return; + } + + std::sort(pieces->begin(), pieces->end()); + + std::vector key(pieces->size()); + std::vector value(pieces->size()); + for (size_t i = 0; i < pieces->size(); ++i) { + key[i] = (*pieces)[i].first.data(); + value[i] = 
(*pieces)[i].second; + } + + trie_ = std::unique_ptr(new Darts::DoubleArray()); + if (trie_->build(key.size(), const_cast(&key[0]), nullptr, &value[0]) != 0) { + status_ = BUILD_DOUBLE_ARRAY_FAILED; + return; + } + + const int kMaxTrieResultsSize = 1024; + std::vector results(kMaxTrieResultsSize); + trie_results_size_ = 0; + for (const auto& p : *pieces) { + const size_t num_nodes = trie_->commonPrefixSearch( + p.first.data(), results.data(), results.size(), p.first.size()); + trie_results_size_ = std::max(trie_results_size_, static_cast(num_nodes)); + } + + if (trie_results_size_ == 0) { + status_ = NO_ENTRY_FOUND; + } +} + +float T5UniGramTokenizer::GetScoreInlined(int id) const { + return piece_score_pairs[id].second; +} + +bool T5UniGramTokenizer::IsUnusedInlined(int id) const { + (void)id; + return false; +} + +bool T5UniGramTokenizer::IsUserDefinedInlined(int id) const { + (void)id; + return false; +} + +size_t T5UniGramTokenizer::OneCharLen(const char* src) const { + return "\1\1\1\1\1\1\1\1\1\1\1\1\2\2\3\4"[(*src & 0xFF) >> 4]; +} + +EncodeResult T5UniGramTokenizer::EncodeOptimized(const std::string& normalized) const { + if (status() != OK || normalized.empty()) { + return {}; + } + + struct BestPathNode { + int id = -1; + float best_path_score = 0; + int starts_at = -1; + }; + + const int size = static_cast(normalized.size()); + const float unk_score = min_score() - kUnkPenalty; + std::vector best_path_ends_at(size + 1); + + int starts_at = 0; + while (starts_at < size) { + std::size_t node_pos = 0; + std::size_t key_pos = starts_at; + const auto best_path_score_till_here = best_path_ends_at[starts_at].best_path_score; + bool has_single_node = false; + const int mblen = std::min(static_cast(OneCharLen(normalized.data() + starts_at)), size - starts_at); + while (key_pos < static_cast(size)) { + const int ret = trie_->traverse(normalized.data(), node_pos, key_pos, key_pos + 1); + if (ret == -2) { + break; + } + if (ret >= 0) { + if (IsUnusedInlined(ret)) { 
+ continue; + } + auto& target_node = best_path_ends_at[key_pos]; + const auto length = static_cast(key_pos - starts_at); + const auto score = IsUserDefinedInlined(ret) ? (length * max_score_ - 0.1f) : GetScoreInlined(ret); + const auto candidate_best_path_score = score + best_path_score_till_here; + if (target_node.starts_at == -1 || candidate_best_path_score > target_node.best_path_score) { + target_node.best_path_score = static_cast(candidate_best_path_score); + target_node.starts_at = starts_at; + target_node.id = ret; + } + if (!has_single_node && length == mblen) { + has_single_node = true; + } + } + } + if (!has_single_node) { + auto& target_node = best_path_ends_at[starts_at + mblen]; + const auto candidate_best_path_score = unk_score + best_path_score_till_here; + if (target_node.starts_at == -1 || candidate_best_path_score > target_node.best_path_score) { + target_node.best_path_score = candidate_best_path_score; + target_node.starts_at = starts_at; + target_node.id = UNK_TOKEN_ID; + } + } + starts_at += mblen; + } + + EncodeResult results; + int ends_at = size; + while (ends_at > 0) { + const auto& node = best_path_ends_at[ends_at]; + results.emplace_back(normalized.substr(node.starts_at, ends_at - node.starts_at), node.id); + ends_at = node.starts_at; + } + std::reverse(results.begin(), results.end()); + return results; +} + +T5UniGramTokenizer::T5UniGramTokenizer(bool is_umt5) { + add_bos_token = false; + add_eos_token = true; + + if (is_umt5) { + PAD_TOKEN_ID = 0; + EOS_TOKEN_ID = 1; + BOS_TOKEN_ID = 2; + UNK_TOKEN_ID = 3; + + PAD_TOKEN = ""; + EOS_TOKEN = ""; + BOS_TOKEN = ""; + UNK_TOKEN = ""; + } else { + PAD_TOKEN_ID = 0; + EOS_TOKEN_ID = 1; + UNK_TOKEN_ID = 2; + + PAD_TOKEN = ""; + EOS_TOKEN = ""; + UNK_TOKEN = ""; + } + + special_tokens = { + "", + "", + "", + }; + + if (is_umt5) { + special_tokens.push_back(""); + } + + if (is_umt5) { + InitializePieces(load_umt5_tokenizer_json()); + } else { + InitializePieces(load_t5_tokenizer_json()); + } + 
+ min_score_ = FLT_MAX; + max_score_ = FLT_MIN; + + std::vector> pieces; + for (int i = 0; i < static_cast(piece_score_pairs.size()); i++) { + const auto& sp = piece_score_pairs[i]; + + min_score_ = std::min(min_score_, sp.second); + max_score_ = std::max(max_score_, sp.second); + + pieces.emplace_back(sp.first, i); + } + + BuildTrie(&pieces); +} + +T5UniGramTokenizer::~T5UniGramTokenizer() = default; + +std::string T5UniGramTokenizer::decode_token(int token_id) const { + if (token_id < 0 || token_id >= static_cast(piece_score_pairs.size())) { + return ""; + } + + const std::string& piece = piece_score_pairs[token_id].first; + if (piece == "") { + return ""; + } + return piece; +} + +std::string T5UniGramTokenizer::normalize(const std::string& input) const { + // Ref: https://github.com/huggingface/tokenizers/blob/1ff56c0c70b045f0cd82da1af9ac08cd4c7a6f9f/bindings/python/py_src/tokenizers/implementations/sentencepiece_unigram.py#L29 + // TODO: nmt-nfkc + std::string normalized = std::regex_replace(input, std::regex(" {2,}"), " "); + return normalized; +} + +std::vector T5UniGramTokenizer::encode(const std::string& input, on_new_token_cb_t on_new_token_cb) { + std::vector tokens; + std::vector token_strs; + std::string normalized = normalize(input); + auto splited_texts = split_with_special_tokens(normalized, special_tokens); + if (splited_texts.empty()) { + splited_texts.push_back(normalized); // for empty string + } + + for (auto& splited_text : splited_texts) { + if (is_special_token(splited_text)) { + if (on_new_token_cb != nullptr) { + bool skip = on_new_token_cb(splited_text, tokens); + if (skip) { + token_strs.push_back(splited_text); + continue; + } + } + + if (splited_text == UNK_TOKEN) { + tokens.push_back(UNK_TOKEN_ID); + token_strs.push_back(UNK_TOKEN); + } else if (splited_text == EOS_TOKEN) { + tokens.push_back(EOS_TOKEN_ID); + token_strs.push_back(EOS_TOKEN); + } else if (splited_text == PAD_TOKEN) { + tokens.push_back(PAD_TOKEN_ID); + 
token_strs.push_back(PAD_TOKEN); + } + continue; + } + + std::string pretokenized = pre_tokenizer.tokenize(splited_text); + EncodeResult result = EncodeOptimized(pretokenized); + for (const auto& item : result) { + tokens.push_back(item.second); + token_strs.push_back(item.first); + } + } + + std::stringstream ss; + ss << "["; + for (const auto& token_str : token_strs) { + ss << "\"" << token_str << "\", "; + } + ss << "]"; + LOG_DEBUG("split prompt \"%s\" to tokens %s", input.c_str(), ss.str().c_str()); + + return tokens; +} diff --git a/src/tokenizers/t5_unigram_tokenizer.h b/src/tokenizers/t5_unigram_tokenizer.h new file mode 100644 index 000000000..9c9f13f8b --- /dev/null +++ b/src/tokenizers/t5_unigram_tokenizer.h @@ -0,0 +1,70 @@ +#ifndef __SD_TOKENIZERS_T5_UNIGRAM_TOKENIZER_H__ +#define __SD_TOKENIZERS_T5_UNIGRAM_TOKENIZER_H__ + +#include +#include +#include +#include +#include + +#include "darts.h" +#include "tokenizer.h" + +class MetaspacePreTokenizer { +private: + std::string replacement; + bool add_prefix_space; + +public: + MetaspacePreTokenizer(const std::string replacement = " ", bool add_prefix_space = true); + + std::string tokenize(const std::string& input) const; +}; + +using EncodeResult = std::vector>; + +class T5UniGramTokenizer : public Tokenizer { +public: + enum Status { + OK, + NO_PIECES_LOADED, + NO_ENTRY_FOUND, + BUILD_DOUBLE_ARRAY_FAILED, + PIECE_ALREADY_DEFINED, + INVLIAD_JSON + }; + +protected: + MetaspacePreTokenizer pre_tokenizer; + std::vector> piece_score_pairs; + float min_score_ = 0.0f; + float max_score_ = 0.0f; + std::unique_ptr trie_; + int trie_results_size_ = 0; + Status status_ = OK; + float kUnkPenalty = 10.0f; + std::string replacement; + bool add_prefix_space = true; + + void InitializePieces(const std::string& json_str); + void BuildTrie(std::vector>* pieces); + float GetScoreInlined(int id) const; + bool IsUnusedInlined(int id) const; + bool IsUserDefinedInlined(int id) const; + size_t OneCharLen(const char* src) 
const; + EncodeResult EncodeOptimized(const std::string& normalized) const; + + float min_score() const { return min_score_; } + float max_score() const { return max_score_; } + Status status() const { return status_; } + std::string decode_token(int token_id) const override; + std::string normalize(const std::string& input) const override; + +public: + explicit T5UniGramTokenizer(bool is_umt5 = false); + ~T5UniGramTokenizer(); + + std::vector encode(const std::string& input, on_new_token_cb_t on_new_token_cb = nullptr) override; +}; + +#endif // __SD_TOKENIZERS_T5_UNIGRAM_TOKENIZER_H__ diff --git a/src/tokenize_util.cpp b/src/tokenizers/tokenize_util.cpp similarity index 96% rename from src/tokenize_util.cpp rename to src/tokenizers/tokenize_util.cpp index 33fdad266..770bfb5fe 100644 --- a/src/tokenize_util.cpp +++ b/src/tokenizers/tokenize_util.cpp @@ -1,4 +1,4 @@ -๏ปฟ#include +#include #include #include #include diff --git a/src/tokenize_util.h b/src/tokenizers/tokenize_util.h similarity index 61% rename from src/tokenize_util.h rename to src/tokenizers/tokenize_util.h index e744d7503..efb0a1cc6 100644 --- a/src/tokenize_util.h +++ b/src/tokenizers/tokenize_util.h @@ -1,5 +1,5 @@ -#ifndef __TOKENIZE_UTIL__ -#define __TOKENIZE_UTIL__ +#ifndef __SD_TOKENIZERS_BPE_TOKENIZE_UTIL_H__ +#define __SD_TOKENIZERS_BPE_TOKENIZE_UTIL_H__ #include #include @@ -7,4 +7,4 @@ std::vector token_split(const std::string& text); std::vector split_with_special_tokens(const std::string& text, const std::vector& special_tokens); -#endif // __TOKENIZE_UTIL__ \ No newline at end of file +#endif // __SD_TOKENIZERS_BPE_TOKENIZE_UTIL_H__ \ No newline at end of file diff --git a/src/tokenizers/tokenizer.cpp b/src/tokenizers/tokenizer.cpp new file mode 100644 index 000000000..556cadd84 --- /dev/null +++ b/src/tokenizers/tokenizer.cpp @@ -0,0 +1,222 @@ +#include "tokenizer.h" + +#include +#include +#include + +#include "util.h" + +void Tokenizer::add_special_token(const std::string& token) { + 
special_tokens.push_back(token); +} + +bool Tokenizer::is_special_token(const std::string& token) const { + for (const auto& special_token : special_tokens) { + if (special_token == token) { + return true; + } + } + return false; +} + +std::string Tokenizer::normalize(const std::string& text) const { + return text; +} + +std::vector Tokenizer::tokenize(const std::string& text, + on_new_token_cb_t on_new_token_cb, + bool padding, + size_t min_length, + size_t max_length, + bool allow_overflow_expand) { + std::vector tokens = encode(text, on_new_token_cb); + if (padding) { + pad_tokens(tokens, nullptr, nullptr, min_length, max_length, allow_overflow_expand); + } + return tokens; +} + +void Tokenizer::pad_tokens(std::vector& tokens, + std::vector* weights, + std::vector* mask, + size_t min_length, + size_t max_length, + bool allow_overflow_expand) { + const bool use_weights = weights != nullptr; + const bool use_mask = mask != nullptr; + + if (use_weights && tokens.size() != weights->size()) { + LOG_ERROR("tokens size != weights size"); + return; + } + + const size_t bos_count = add_bos_token ? 1 : 0; + const size_t eos_count = add_eos_token ? 
1 : 0; + const size_t special_token_count = bos_count + eos_count; + + auto build_sequence = [&](size_t begin, + size_t count, + size_t target_length, + std::vector& out_tokens, + std::vector& out_weights, + std::vector& out_mask) { + const size_t base_length = count + special_token_count; + const size_t final_length = std::max(target_length, base_length); + + out_tokens.clear(); + out_weights.clear(); + out_mask.clear(); + + out_tokens.reserve(final_length); + if (use_weights) { + out_weights.reserve(final_length); + } + if (use_mask) { + out_mask.reserve(final_length); + } + + if (add_bos_token) { + out_tokens.push_back(BOS_TOKEN_ID); + if (use_weights) { + out_weights.push_back(1.0f); + } + if (use_mask) { + out_mask.push_back(1.0f); + } + } + + for (size_t i = 0; i < count; ++i) { + out_tokens.push_back(tokens[begin + i]); + if (use_weights) { + out_weights.push_back((*weights)[begin + i]); + } + if (use_mask) { + out_mask.push_back(1.0f); + } + } + + if (add_eos_token) { + out_tokens.push_back(EOS_TOKEN_ID); + if (use_weights) { + out_weights.push_back(1.0f); + } + if (use_mask) { + out_mask.push_back(1.0f); + } + } + + if (final_length > out_tokens.size()) { + const size_t pad_count = final_length - out_tokens.size(); + if (pad_left) { + out_tokens.insert(out_tokens.begin(), pad_count, PAD_TOKEN_ID); + + if (use_weights) { + out_weights.insert(out_weights.begin(), pad_count, 1.0f); + } + if (use_mask) { + out_mask.insert(out_mask.begin(), pad_count, 0.0f); + } + } else { + out_tokens.insert(out_tokens.end(), pad_count, PAD_TOKEN_ID); + + if (use_weights) { + out_weights.insert(out_weights.end(), pad_count, 1.0f); + } + if (use_mask) { + out_mask.insert(out_mask.end(), pad_count, 0.0f); + } + } + } + }; + + const size_t single_length = std::max(min_length, tokens.size() + special_token_count); + const bool exceeds_max_length = max_length > 0 && single_length > max_length; + + std::vector new_tokens; + std::vector new_weights; + std::vector new_mask; + + if 
(!exceeds_max_length) { + build_sequence(0, tokens.size(), min_length, new_tokens, new_weights, new_mask); + } else if (!allow_overflow_expand) { + build_sequence(0, tokens.size(), 0, new_tokens, new_weights, new_mask); + + new_tokens.resize(max_length); + if (use_weights) { + new_weights.resize(max_length); + } + if (use_mask) { + new_mask.resize(max_length); + } + + if (add_eos_token && !new_tokens.empty()) { + new_tokens.back() = EOS_TOKEN_ID; + if (use_weights) { + new_weights.back() = 1.0f; + } + if (use_mask) { + new_mask.back() = 1.0f; + } + } + } else if (min_length > special_token_count) { + const size_t tokens_per_chunk = min_length - special_token_count; + size_t offset = 0; + + while (offset < tokens.size()) { + const size_t remaining = tokens.size() - offset; + const size_t take = std::min(tokens_per_chunk, remaining); + + std::vector chunk_tokens; + std::vector chunk_weights; + std::vector chunk_mask; + + build_sequence(offset, take, min_length, chunk_tokens, chunk_weights, chunk_mask); + + new_tokens.insert(new_tokens.end(), chunk_tokens.begin(), chunk_tokens.end()); + if (use_weights) { + new_weights.insert(new_weights.end(), chunk_weights.begin(), chunk_weights.end()); + } + if (use_mask) { + new_mask.insert(new_mask.end(), chunk_mask.begin(), chunk_mask.end()); + } + + offset += take; + } + } else { + build_sequence(0, tokens.size(), min_length, new_tokens, new_weights, new_mask); + } + + tokens = std::move(new_tokens); + if (use_weights) { + *weights = std::move(new_weights); + } + if (use_mask) { + *mask = std::move(new_mask); + } +} + +static std::string clean_up_tokenization(std::string& text) { + std::regex pattern(R"( ,)"); + return std::regex_replace(text, pattern, ","); +} + +std::string Tokenizer::decode(const std::vector& tokens) const { + std::string text; + + for (int token_id : tokens) { + if (token_id == BOS_TOKEN_ID || token_id == EOS_TOKEN_ID || token_id == PAD_TOKEN_ID) { + continue; + } + + std::string piece = 
decode_token(token_id); + if (!end_of_word_suffix.empty() && ends_with(piece, end_of_word_suffix)) { + piece.erase(piece.size() - end_of_word_suffix.size()); + text += piece + " "; + } else { + text += piece; + } + } + + text = clean_up_tokenization(text); + return trim(text); +} diff --git a/src/tokenizers/tokenizer.h b/src/tokenizers/tokenizer.h new file mode 100644 index 000000000..e044285bb --- /dev/null +++ b/src/tokenizers/tokenizer.h @@ -0,0 +1,53 @@ +#ifndef __SD_TOKENIZERS_TOKENIZER_H__ +#define __SD_TOKENIZERS_TOKENIZER_H__ + +#include +#include +#include +#include +#include + +using on_new_token_cb_t = std::function&)>; + +class Tokenizer { +protected: + std::vector special_tokens; + bool add_bos_token = false; + bool add_eos_token = false; + bool pad_left = false; + std::string end_of_word_suffix; + + virtual std::string decode_token(int token_id) const = 0; + virtual std::string normalize(const std::string& text) const; + +public: + std::string UNK_TOKEN; + std::string BOS_TOKEN; + std::string EOS_TOKEN; + std::string PAD_TOKEN; + int UNK_TOKEN_ID = 0; + int BOS_TOKEN_ID = 0; + int EOS_TOKEN_ID = 0; + int PAD_TOKEN_ID = 0; + + virtual ~Tokenizer() = default; + + void add_special_token(const std::string& token); + bool is_special_token(const std::string& token) const; + virtual std::vector encode(const std::string& text, on_new_token_cb_t on_new_token_cb = nullptr) = 0; + std::vector tokenize(const std::string& text, + on_new_token_cb_t on_new_token_cb = nullptr, + bool padding = false, + size_t min_length = 0, + size_t max_length = 100000000, + bool allow_overflow_expand = false); + void pad_tokens(std::vector& tokens, + std::vector* weights, + std::vector* mask, + size_t min_length = 0, + size_t max_length = 100000000, + bool allow_overflow_expand = false); + std::string decode(const std::vector& tokens) const; +}; + +#endif // __SD_TOKENIZERS_TOKENIZER_H__ diff --git a/src/vocab/clip_t5.hpp b/src/tokenizers/vocab/clip_t5.hpp similarity index 100% 
rename from src/vocab/clip_t5.hpp rename to src/tokenizers/vocab/clip_t5.hpp diff --git a/src/vocab/mistral.hpp b/src/tokenizers/vocab/mistral.hpp similarity index 100% rename from src/vocab/mistral.hpp rename to src/tokenizers/vocab/mistral.hpp diff --git a/src/vocab/qwen.hpp b/src/tokenizers/vocab/qwen.hpp similarity index 100% rename from src/vocab/qwen.hpp rename to src/tokenizers/vocab/qwen.hpp diff --git a/src/vocab/umt5.hpp b/src/tokenizers/vocab/umt5.hpp similarity index 100% rename from src/vocab/umt5.hpp rename to src/tokenizers/vocab/umt5.hpp diff --git a/src/vocab/vocab.cpp b/src/tokenizers/vocab/vocab.cpp similarity index 100% rename from src/vocab/vocab.cpp rename to src/tokenizers/vocab/vocab.cpp diff --git a/src/vocab/vocab.h b/src/tokenizers/vocab/vocab.h similarity index 66% rename from src/vocab/vocab.h rename to src/tokenizers/vocab/vocab.h index cfa033a49..de7a76406 100644 --- a/src/vocab/vocab.h +++ b/src/tokenizers/vocab/vocab.h @@ -1,5 +1,5 @@ -#ifndef __VOCAB_H__ -#define __VOCAB_H__ +#ifndef __SD_TOKENIZERS_VOCAB_VOCAB_H__ +#define __SD_TOKENIZERS_VOCAB_VOCAB_H__ #include @@ -10,4 +10,4 @@ std::string load_mistral_vocab_json(); std::string load_t5_tokenizer_json(); std::string load_umt5_tokenizer_json(); -#endif // __VOCAB_H__ \ No newline at end of file +#endif // __SD_TOKENIZERS_VOCAB_VOCAB_H__ \ No newline at end of file diff --git a/src/unet.hpp b/src/unet.hpp index 63e23eb93..d7ea8c3fa 100644 --- a/src/unet.hpp +++ b/src/unet.hpp @@ -217,11 +217,11 @@ class UnetModelBlock : public GGMLBlock { } else if (sd_version_is_unet_edit(version)) { in_channels = 8; } - if (version == VERSION_SD1_TINY_UNET || version == VERSION_SD2_TINY_UNET || version == VERSION_SDXS) { + if (version == VERSION_SD1_TINY_UNET || version == VERSION_SD2_TINY_UNET || version == VERSION_SDXS_512_DS || version == VERSION_SDXS_09) { num_res_blocks = 1; channel_mult = {1, 2, 4}; tiny_unet = true; - if (version == VERSION_SDXS) { + if (version == VERSION_SDXS_512_DS) { 
attention_resolutions = {4, 2}; // here just like SDXL } } @@ -264,6 +264,10 @@ class UnetModelBlock : public GGMLBlock { if (version == VERSION_SVD) { return new SpatialVideoTransformer(in_channels, n_head, d_head, depth, context_dim, use_linear_projection); } else { + if (version == VERSION_SDXS_09 && n_head == 5) { + n_head = 1; // to carry a special case of sdxs_09 into CrossAttentionLayer, + d_head = 320; // works as long the product remains equal (5*64 == 1*320) + } return new SpatialTransformer(in_channels, n_head, d_head, depth, context_dim, use_linear_projection); } }; @@ -478,12 +482,14 @@ class UnetModelBlock : public GGMLBlock { emb = ggml_add(ctx->ggml_ctx, emb, label_emb); // [N, time_embed_dim] } + // sd::ggml_graph_cut::mark_graph_cut(emb, "unet.prelude", "emb"); // input_blocks std::vector hs; // input block 0 auto h = input_blocks_0_0->forward(ctx, x); + sd::ggml_graph_cut::mark_graph_cut(h, "unet.input_blocks.0", "h"); ggml_set_name(h, "bench-start"); hs.push_back(h); @@ -501,6 +507,7 @@ class UnetModelBlock : public GGMLBlock { std::string name = "input_blocks." + std::to_string(input_block_idx) + ".1"; h = attention_layer_forward(name, ctx, h, context, num_video_frames); // [N, mult*model_channels, h, w] } + sd::ggml_graph_cut::mark_graph_cut(h, "unet.input_blocks." + std::to_string(input_block_idx), "h"); hs.push_back(h); } if (tiny_unet) { @@ -514,6 +521,7 @@ class UnetModelBlock : public GGMLBlock { auto block = std::dynamic_pointer_cast(blocks[name]); h = block->forward(ctx, h); // [N, mult*model_channels, h/(2^(i+1)), w/(2^(i+1))] + // sd::ggml_graph_cut::mark_graph_cut(h, "unet.input_blocks." 
+ std::to_string(input_block_idx), "h"); hs.push_back(h); } } @@ -527,6 +535,7 @@ class UnetModelBlock : public GGMLBlock { h = resblock_forward("middle_block.2", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8] } } + sd::ggml_graph_cut::mark_graph_cut(h, "unet.middle_block", "h"); if (controls.size() > 0) { auto cs = ggml_ext_scale(ctx->ggml_ctx, controls[controls.size() - 1], control_strength, true); h = ggml_add(ctx->ggml_ctx, h, cs); // middle control @@ -577,6 +586,7 @@ class UnetModelBlock : public GGMLBlock { } output_block_idx += 1; + sd::ggml_graph_cut::mark_graph_cut(h, "unet.output_blocks." + std::to_string(output_block_idx - 1), "h"); } } diff --git a/src/upscaler.cpp b/src/upscaler.cpp index 03f7714e5..25fc0c5df 100644 --- a/src/upscaler.cpp +++ b/src/upscaler.cpp @@ -1,125 +1,106 @@ -#include "esrgan.hpp" +#include "upscaler.h" #include "ggml_extend.hpp" #include "model.h" #include "stable-diffusion.h" #include "util.h" -struct UpscalerGGML { - ggml_backend_t backend = nullptr; // general backend - ggml_type model_data_type = GGML_TYPE_F16; - std::shared_ptr esrgan_upscaler; - std::string esrgan_path; - int n_threads; - bool direct = false; - int tile_size = 128; - - UpscalerGGML(int n_threads, - bool direct = false, - int tile_size = 128) - : n_threads(n_threads), - direct(direct), - tile_size(tile_size) { +UpscalerGGML::UpscalerGGML(int n_threads, + bool direct, + int tile_size) + : n_threads(n_threads), + direct(direct), + tile_size(tile_size) { +} + +void UpscalerGGML::set_max_graph_vram_bytes(size_t max_vram_bytes) { + max_graph_vram_bytes = max_vram_bytes; + if (esrgan_upscaler) { + esrgan_upscaler->set_max_graph_vram_bytes(max_vram_bytes); } +} + +bool UpscalerGGML::load_from_file(const std::string& esrgan_path, + bool offload_params_to_cpu, + int n_threads) { + ggml_log_set(ggml_log_callback_default, nullptr); - bool load_from_file(const std::string& esrgan_path, - bool offload_params_to_cpu, - int n_threads) { - 
ggml_log_set(ggml_log_callback_default, nullptr); -#ifdef SD_USE_CUDA - LOG_DEBUG("Using CUDA backend"); - backend = ggml_backend_cuda_init(0); -#endif -#ifdef SD_USE_METAL - LOG_DEBUG("Using Metal backend"); - backend = ggml_backend_metal_init(); -#endif -#ifdef SD_USE_VULKAN - LOG_DEBUG("Using Vulkan backend"); - backend = ggml_backend_vk_init(0); -#endif -#ifdef SD_USE_OPENCL - LOG_DEBUG("Using OpenCL backend"); - backend = ggml_backend_opencl_init(); -#endif -#ifdef SD_USE_SYCL - LOG_DEBUG("Using SYCL backend"); - backend = ggml_backend_sycl_init(0); -#endif - ModelLoader model_loader; - if (!model_loader.init_from_file_and_convert_name(esrgan_path)) { - LOG_ERROR("init model loader from file failed: '%s'", esrgan_path.c_str()); - } - model_loader.set_wtype_override(model_data_type); - if (!backend) { - LOG_DEBUG("Using CPU backend"); - backend = ggml_backend_cpu_init(); - } - LOG_INFO("Upscaler weight type: %s", ggml_type_name(model_data_type)); - esrgan_upscaler = std::make_shared(backend, offload_params_to_cpu, tile_size, model_loader.get_tensor_storage_map()); - if (direct) { - esrgan_upscaler->set_conv2d_direct_enabled(true); - } - if (!esrgan_upscaler->load_from_file(esrgan_path, n_threads)) { - return false; - } - return true; + backend = sd_get_default_backend(); + + ModelLoader model_loader; + if (!model_loader.init_from_file_and_convert_name(esrgan_path)) { + LOG_ERROR("init model loader from file failed: '%s'", esrgan_path.c_str()); + } + model_loader.set_wtype_override(model_data_type); + if (!backend) { + LOG_DEBUG("Using CPU backend"); + backend = ggml_backend_cpu_init(); + } + LOG_INFO("Upscaler weight type: %s", ggml_type_name(model_data_type)); + esrgan_upscaler = std::make_shared(backend, offload_params_to_cpu, tile_size, model_loader.get_tensor_storage_map()); + esrgan_upscaler->set_max_graph_vram_bytes(max_graph_vram_bytes); + if (direct) { + esrgan_upscaler->set_conv2d_direct_enabled(true); + } + if 
(!esrgan_upscaler->load_from_file(esrgan_path, n_threads)) { + return false; } + return true; +} + +sd::Tensor UpscalerGGML::upscale_tensor(const sd::Tensor& input_tensor) { + sd::Tensor upscaled; + if (tile_size <= 0 || (input_tensor.shape()[0] <= tile_size && input_tensor.shape()[1] <= tile_size)) { + upscaled = esrgan_upscaler->compute(n_threads, input_tensor); + } else { + auto on_processing = [&](const sd::Tensor& input_tile) -> sd::Tensor { + auto output_tile = esrgan_upscaler->compute(n_threads, input_tile); + if (output_tile.empty()) { + LOG_ERROR("esrgan compute failed while processing a tile"); + return {}; + } + return output_tile; + }; - sd::Tensor upscale_tensor(const sd::Tensor& input_tensor) { - sd::Tensor upscaled; - if (tile_size <= 0 || (input_tensor.shape()[0] <= tile_size && input_tensor.shape()[1] <= tile_size)) { - upscaled = esrgan_upscaler->compute(n_threads, input_tensor); - } else { - auto on_processing = [&](const sd::Tensor& input_tile) -> sd::Tensor { - auto output_tile = esrgan_upscaler->compute(n_threads, input_tile); - if (output_tile.empty()) { - LOG_ERROR("esrgan compute failed while processing a tile"); - return {}; - } - return output_tile; - }; - - upscaled = process_tiles_2d(input_tensor, - static_cast(input_tensor.shape()[0] * esrgan_upscaler->scale), - static_cast(input_tensor.shape()[1] * esrgan_upscaler->scale), - esrgan_upscaler->scale, - tile_size, - tile_size, - 0.25f, - false, - false, - on_processing); - } - esrgan_upscaler->free_compute_buffer(); - if (upscaled.empty()) { - LOG_ERROR("esrgan compute failed"); - return {}; - } - return upscaled; + upscaled = process_tiles_2d(input_tensor, + static_cast(input_tensor.shape()[0] * esrgan_upscaler->scale), + static_cast(input_tensor.shape()[1] * esrgan_upscaler->scale), + esrgan_upscaler->scale, + tile_size, + tile_size, + 0.25f, + false, + false, + on_processing); } + esrgan_upscaler->free_compute_buffer(); + if (upscaled.empty()) { + LOG_ERROR("esrgan compute failed"); + 
return {}; + } + return upscaled; +} - sd_image_t upscale(sd_image_t input_image, uint32_t upscale_factor) { - // upscale_factor, unused for RealESRGAN_x4plus_anime_6B.pth - sd_image_t upscaled_image = {0, 0, 0, nullptr}; - int output_width = (int)input_image.width * esrgan_upscaler->scale; - int output_height = (int)input_image.height * esrgan_upscaler->scale; - LOG_INFO("upscaling from (%i x %i) to (%i x %i)", - input_image.width, input_image.height, output_width, output_height); - - sd::Tensor input_tensor = sd_image_to_tensor(input_image); - sd::Tensor upscaled; - int64_t t0 = ggml_time_ms(); - upscaled = upscale_tensor(input_tensor); - if (upscaled.empty()) { - return upscaled_image; - } - sd_image_t upscaled_data = tensor_to_sd_image(upscaled); - int64_t t3 = ggml_time_ms(); - LOG_INFO("input_image_tensor upscaled, taking %.2fs", (t3 - t0) / 1000.0f); - upscaled_image = upscaled_data; +sd_image_t UpscalerGGML::upscale(sd_image_t input_image, uint32_t upscale_factor) { + // upscale_factor, unused for RealESRGAN_x4plus_anime_6B.pth + sd_image_t upscaled_image = {0, 0, 0, nullptr}; + int output_width = (int)input_image.width * esrgan_upscaler->scale; + int output_height = (int)input_image.height * esrgan_upscaler->scale; + LOG_INFO("upscaling from (%i x %i) to (%i x %i)", + input_image.width, input_image.height, output_width, output_height); + + sd::Tensor input_tensor = sd_image_to_tensor(input_image); + sd::Tensor upscaled; + int64_t t0 = ggml_time_ms(); + upscaled = upscale_tensor(input_tensor); + if (upscaled.empty()) { return upscaled_image; } -}; + sd_image_t upscaled_data = tensor_to_sd_image(upscaled); + int64_t t3 = ggml_time_ms(); + LOG_INFO("input_image_tensor upscaled, taking %.2fs", (t3 - t0) / 1000.0f); + upscaled_image = upscaled_data; + return upscaled_image; +} struct upscaler_ctx_t { UpscalerGGML* upscaler = nullptr; diff --git a/src/upscaler.h b/src/upscaler.h new file mode 100644 index 000000000..d667a6f15 --- /dev/null +++ b/src/upscaler.h 
@@ -0,0 +1,33 @@ +#ifndef __SD_UPSCALER_H__ +#define __SD_UPSCALER_H__ + +#include "esrgan.hpp" +#include "stable-diffusion.h" +#include "tensor.hpp" + +#include +#include + +struct UpscalerGGML { + ggml_backend_t backend = nullptr; // general backend + ggml_type model_data_type = GGML_TYPE_F16; + std::shared_ptr esrgan_upscaler; + std::string esrgan_path; + int n_threads; + bool direct = false; + int tile_size = 128; + size_t max_graph_vram_bytes = 0; + + UpscalerGGML(int n_threads, + bool direct = false, + int tile_size = 128); + + bool load_from_file(const std::string& esrgan_path, + bool offload_params_to_cpu, + int n_threads); + void set_max_graph_vram_bytes(size_t max_vram_bytes); + sd::Tensor upscale_tensor(const sd::Tensor& input_tensor); + sd_image_t upscale(sd_image_t input_image, uint32_t upscale_factor); +}; + +#endif // __SD_UPSCALER_H__ diff --git a/src/util.cpp b/src/util.cpp index e01876268..0b514bb73 100644 --- a/src/util.cpp +++ b/src/util.cpp @@ -23,8 +23,9 @@ #include #endif -#include "ggml-cpu.h" +#include "ggml-backend.h" #include "ggml.h" +#include "ggml_extend_backend.hpp" #include "stable-diffusion.h" bool ends_with(const std::string& str, const std::string& ending) { @@ -119,10 +120,10 @@ std::unique_ptr MmapWrapper::create(const std::string& filename) { filename.c_str(), GENERIC_READ, FILE_SHARE_READ, - NULL, + nullptr, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, - NULL); + nullptr); if (file_handle == INVALID_HANDLE_VALUE) { return nullptr; @@ -136,16 +137,16 @@ std::unique_ptr MmapWrapper::create(const std::string& filename) { file_size = static_cast(size.QuadPart); - HANDLE mapping_handle = CreateFileMapping(file_handle, NULL, PAGE_READONLY, 0, 0, NULL); + HANDLE mapping_handle = CreateFileMapping(file_handle, nullptr, PAGE_READONLY, 0, 0, nullptr); - if (mapping_handle == NULL) { + if (mapping_handle == nullptr) { CloseHandle(file_handle); return nullptr; } mapped_data = MapViewOfFile(mapping_handle, FILE_MAP_READ, 0, 0, file_size); - if 
(mapped_data == NULL) { + if (mapped_data == nullptr) { CloseHandle(mapping_handle); CloseHandle(file_handle); return nullptr; @@ -203,7 +204,7 @@ std::unique_ptr MmapWrapper::create(const std::string& filename) { size_t file_size = sb.st_size; - void* mapped_data = mmap(NULL, file_size, PROT_READ, mmap_flags, file_descriptor, 0); + void* mapped_data = mmap(nullptr, file_size, PROT_READ, mmap_flags, file_descriptor, 0); close(file_descriptor); @@ -495,26 +496,6 @@ sd_progress_cb_t sd_get_progress_callback() { void* sd_get_progress_callback_data() { return sd_progress_cb_data; } -const char* sd_get_system_info() { - static char buffer[1024]; - std::stringstream ss; - ss << "System Info: \n"; - ss << " SSE3 = " << ggml_cpu_has_sse3() << " | "; - ss << " AVX = " << ggml_cpu_has_avx() << " | "; - ss << " AVX2 = " << ggml_cpu_has_avx2() << " | "; - ss << " AVX512 = " << ggml_cpu_has_avx512() << " | "; - ss << " AVX512_VBMI = " << ggml_cpu_has_avx512_vbmi() << " | "; - ss << " AVX512_VNNI = " << ggml_cpu_has_avx512_vnni() << " | "; - ss << " FMA = " << ggml_cpu_has_fma() << " | "; - ss << " NEON = " << ggml_cpu_has_neon() << " | "; - ss << " ARM_FMA = " << ggml_cpu_has_arm_fma() << " | "; - ss << " F16C = " << ggml_cpu_has_f16c() << " | "; - ss << " FP16_VA = " << ggml_cpu_has_fp16_va() << " | "; - ss << " WASM_SIMD = " << ggml_cpu_has_wasm_simd() << " | "; - ss << " VSX = " << ggml_cpu_has_vsx() << " | "; - snprintf(buffer, sizeof(buffer), "%s", ss.str().c_str()); - return buffer; -} sd_image_t tensor_to_sd_image(const sd::Tensor& tensor, int frame_index) { const auto& shape = tensor.shape(); @@ -524,17 +505,7 @@ sd_image_t tensor_to_sd_image(const sd::Tensor& tensor, int frame_index) int channel = static_cast(shape[shape.size() == 5 ? 
3 : 2]); uint8_t* data = (uint8_t*)malloc(static_cast(width * height * channel)); GGML_ASSERT(data != nullptr); - - for (int iw = 0; iw < width; ++iw) { - for (int ih = 0; ih < height; ++ih) { - for (int ic = 0; ic < channel; ++ic) { - float value = shape.size() == 5 ? tensor.index(iw, ih, frame_index, ic, 0) - : tensor.index(iw, ih, ic, frame_index); - value = std::clamp(value, 0.0f, 1.0f); - data[(ih * width + iw) * channel + ic] = static_cast(std::round(value * 255.0f)); - } - } - } + preprocessing_tensor_frame_to_sd_image(tensor, frame_index, data); return { static_cast(width), static_cast(height), @@ -718,3 +689,100 @@ std::vector> parse_prompt_attention(const std::str return res; } + +// test if the backend is a specific one, e.g. "CUDA", "ROCm", "Vulkan" etc. +bool sd_backend_is(ggml_backend_t backend, const std::string& name) { + if (!backend) { + return false; + } + ggml_backend_dev_t dev = ggml_backend_get_device(backend); + if (!dev) + return false; + std::string dev_name = ggml_backend_dev_name(dev); + return dev_name.find(name) != std::string::npos; +} + +ggml_backend_t sd_get_default_backend() { + ggml_backend_load_all_once(); + static std::once_flag once; + std::call_once(once, []() { + size_t dev_count = ggml_backend_dev_count(); + if (dev_count == 0) { + LOG_ERROR("No devices found!"); + } else { + LOG_DEBUG("Found %zu backend devices:", dev_count); + for (size_t i = 0; i < dev_count; ++i) { + auto dev = ggml_backend_dev_get(i); + LOG_DEBUG("#%zu: %s", i, ggml_backend_dev_name(dev)); + } + } + }); + ggml_backend_t backend = nullptr; + const char* SD_VK_DEVICE = getenv("SD_VK_DEVICE"); + if (SD_VK_DEVICE != nullptr) { + std::string sd_vk_device_str = SD_VK_DEVICE; + try { + unsigned long long device = std::stoull(sd_vk_device_str); + std::string vk_device_name = "Vulkan" + std::to_string(device); + if (backend_name_exists(vk_device_name)) { + LOG_INFO("Selecting %s as main device by env var SD_VK_DEVICE", vk_device_name.c_str()); + backend = 
init_named_backend(vk_device_name); + if (!backend) { + LOG_WARN("Device %s requested by SD_VK_DEVICE failed to init. Falling back to the default device.", vk_device_name.c_str()); + } + } else { + LOG_WARN("Device %s requested by SD_VK_DEVICE was not found. Falling back to the default device.", vk_device_name.c_str()); + } + } catch (const std::invalid_argument&) { + LOG_WARN("SD_VK_DEVICE environment variable is not a valid integer (%s). Falling back to the default device.", SD_VK_DEVICE); + } catch (const std::out_of_range&) { + LOG_WARN("SD_VK_DEVICE environment variable value is out of range for `unsigned long long` type (%s). Falling back to the default device.", SD_VK_DEVICE); + } + } + + if (!backend) { + std::string dev_name = get_default_backend_name(); + backend = init_named_backend(dev_name); + if (!backend && !dev_name.empty()) { + LOG_WARN("device %s failed to init", dev_name.c_str()); + } + } + + if (!backend) { + LOG_WARN("loading CPU backend"); + backend = ggml_backend_cpu_init(); + } + + if (ggml_backend_is_cpu(backend)) { + LOG_DEBUG("Using CPU backend"); + } + + return backend; +} + +// namespace is needed to avoid conflicts with ggml_backend_extend.hpp +namespace ggml_cpu { +#include "ggml-cpu.h" +} + +const char* sd_get_system_info() { + using namespace ggml_cpu; + static char buffer[1024]; + std::stringstream ss; + ss << "System Info: \n"; + ss << " SSE3 = " << ggml_cpu_has_sse3() << " | "; + ss << " AVX = " << ggml_cpu_has_avx() << " | "; + ss << " AVX2 = " << ggml_cpu_has_avx2() << " | "; + ss << " AVX512 = " << ggml_cpu_has_avx512() << " | "; + ss << " AVX512_VBMI = " << ggml_cpu_has_avx512_vbmi() << " | "; + ss << " AVX512_VNNI = " << ggml_cpu_has_avx512_vnni() << " | "; + ss << " FMA = " << ggml_cpu_has_fma() << " | "; + ss << " NEON = " << ggml_cpu_has_neon() << " | "; + ss << " ARM_FMA = " << ggml_cpu_has_arm_fma() << " | "; + ss << " F16C = " << ggml_cpu_has_f16c() << " | "; + ss << " FP16_VA = " << ggml_cpu_has_fp16_va() << " | "; + 
ss << " WASM_SIMD = " << ggml_cpu_has_wasm_simd() << " | "; + ss << " VSX = " << ggml_cpu_has_vsx() << " | "; + snprintf(buffer, sizeof(buffer), "%s", ss.str().c_str()); + return buffer; +} diff --git a/src/util.h b/src/util.h index 2468cb93d..72c8a815d 100644 --- a/src/util.h +++ b/src/util.h @@ -6,6 +6,7 @@ #include #include +#include "ggml-backend.h" #include "stable-diffusion.h" #include "tensor.hpp" @@ -82,6 +83,10 @@ int sd_get_preview_interval(); bool sd_should_preview_denoised(); bool sd_should_preview_noisy(); +// test if the backend is a specific one, e.g. "CUDA", "ROCm", "Vulkan" etc. +bool sd_backend_is(ggml_backend_t backend, const std::string& name); +ggml_backend_t sd_get_default_backend(); + #define LOG_DEBUG(format, ...) log_printf(SD_LOG_DEBUG, __FILE__, __LINE__, format, ##__VA_ARGS__) #define LOG_INFO(format, ...) log_printf(SD_LOG_INFO, __FILE__, __LINE__, format, ##__VA_ARGS__) #define LOG_WARN(format, ...) log_printf(SD_LOG_WARN, __FILE__, __LINE__, format, ##__VA_ARGS__) diff --git a/src/vae.hpp b/src/vae.hpp index 22be8867a..54bd88abf 100644 --- a/src/vae.hpp +++ b/src/vae.hpp @@ -69,7 +69,7 @@ struct VAE : public GGMLRunner { int scale_factor = 8; if (version == VERSION_WAN2_2_TI2V) { scale_factor = 16; - } else if (sd_version_is_flux2(version)) { + } else if (sd_version_uses_flux2_vae(version)) { scale_factor = 16; } else if (version == VERSION_CHROMA_RADIANCE) { scale_factor = 1; @@ -142,9 +142,10 @@ struct VAE : public GGMLRunner { "vae encode compute failed while processing a tile"); } else { output = _compute(n_threads, input, false); - free_compute_buffer(); } + free_compute_buffer(); + if (output.empty()) { LOG_ERROR("vae encode compute failed"); return {}; diff --git a/src/wan.hpp b/src/wan.hpp index 6860262c5..261453301 100644 --- a/src/wan.hpp +++ b/src/wan.hpp @@ -692,6 +692,7 @@ namespace WAN { } else { x = conv1->forward(ctx, x); } + // sd::ggml_graph_cut::mark_graph_cut(x, "wan_vae.encoder.prelude", "x"); // downsamples 
std::vector dims = {dim}; @@ -717,12 +718,14 @@ namespace WAN { x = layer->forward(ctx, x, b, feat_cache, feat_idx, chunk_idx); } } + // sd::ggml_graph_cut::mark_graph_cut(x, "wan_vae.encoder.down." + std::to_string(i), "x"); } // middle x = middle_0->forward(ctx, x, b, feat_cache, feat_idx); x = middle_1->forward(ctx, x, b); x = middle_2->forward(ctx, x, b, feat_cache, feat_idx); + // sd::ggml_graph_cut::mark_graph_cut(x, "wan_vae.encoder.mid", "x"); // head x = head_0->forward(ctx, x); @@ -863,11 +866,13 @@ namespace WAN { } else { x = conv1->forward(ctx, x); } + // sd::ggml_graph_cut::mark_graph_cut(x, "wan_vae.decoder.prelude", "x"); // middle x = middle_0->forward(ctx, x, b, feat_cache, feat_idx); x = middle_1->forward(ctx, x, b); x = middle_2->forward(ctx, x, b, feat_cache, feat_idx); + // sd::ggml_graph_cut::mark_graph_cut(x, "wan_vae.decoder.mid", "x"); // upsamples std::vector dims = {dim_mult[dim_mult.size() - 1] * dim}; @@ -893,6 +898,7 @@ namespace WAN { x = layer->forward(ctx, x, b, feat_cache, feat_idx, chunk_idx); } } + // sd::ggml_graph_cut::mark_graph_cut(x, "wan_vae.decoder.up." 
+ std::to_string(i), "x"); } // head @@ -1031,6 +1037,7 @@ namespace WAN { if (wan2_2) { x = patchify(ctx->ggml_ctx, x, 2, b); } + // sd::ggml_graph_cut::mark_graph_cut(x, "wan_vae.encode.prelude", "x"); auto encoder = std::dynamic_pointer_cast(blocks["encoder"]); auto conv1 = std::dynamic_pointer_cast(blocks["conv1"]); @@ -1051,6 +1058,7 @@ namespace WAN { } out = conv1->forward(ctx, out); auto mu = ggml_ext_chunk(ctx->ggml_ctx, out, 2, 3)[0]; + // sd::ggml_graph_cut::mark_graph_cut(mu, "wan_vae.encode.final", "mu"); clear_cache(); return mu; } @@ -1068,6 +1076,7 @@ namespace WAN { int64_t iter_ = z->ne[2]; auto x = conv2->forward(ctx, z); + // sd::ggml_graph_cut::mark_graph_cut(x, "wan_vae.decode.prelude", "x"); ggml_tensor* out; for (int i = 0; i < iter_; i++) { _conv_idx = 0; @@ -1083,6 +1092,7 @@ namespace WAN { if (wan2_2) { out = unpatchify(ctx->ggml_ctx, out, 2, b); } + // sd::ggml_graph_cut::mark_graph_cut(out, "wan_vae.decode.final", "out"); clear_cache(); return out; } @@ -1097,13 +1107,15 @@ namespace WAN { auto decoder = std::dynamic_pointer_cast(blocks["decoder"]); auto conv2 = std::dynamic_pointer_cast(blocks["conv2"]); - auto x = conv2->forward(ctx, z); + auto x = conv2->forward(ctx, z); + // sd::ggml_graph_cut::mark_graph_cut(x, "wan_vae.decode_partial.prelude", "x"); auto in = ggml_ext_slice(ctx->ggml_ctx, x, 2, i, i + 1); // [b*c, 1, h, w] _conv_idx = 0; auto out = decoder->forward(ctx, in, b, _feat_map, _conv_idx, i); if (wan2_2) { out = unpatchify(ctx->ggml_ctx, out, 2, b); } + // sd::ggml_graph_cut::mark_graph_cut(out, "wan_vae.decode_partial.final", "out"); return out; } }; @@ -1984,6 +1996,13 @@ namespace WAN { c = ggml_reshape_3d(ctx->ggml_ctx, c, c->ne[0] * c->ne[1] * c->ne[2], c->ne[3] / N, N); // [N, dim, t_len*h_len*w_len] c = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, c, 1, 0, 2, 3)); // [N, t_len*h_len*w_len, dim] } + sd::ggml_graph_cut::mark_graph_cut(x, "wan.prelude", "x"); + // 
sd::ggml_graph_cut::mark_graph_cut(e, "wan.prelude", "e"); + // sd::ggml_graph_cut::mark_graph_cut(e0, "wan.prelude", "e0"); + // sd::ggml_graph_cut::mark_graph_cut(context, "wan.prelude", "context"); + if (c != nullptr) { + sd::ggml_graph_cut::mark_graph_cut(c, "wan.prelude", "c"); + } auto x_orig = x; @@ -2004,6 +2023,10 @@ namespace WAN { c_skip = ggml_ext_scale(ctx->ggml_ctx, c_skip, vace_strength); x = ggml_add(ctx->ggml_ctx, x, c_skip); } + sd::ggml_graph_cut::mark_graph_cut(x, "wan.blocks." + std::to_string(i), "x"); + if (c != nullptr) { + sd::ggml_graph_cut::mark_graph_cut(c, "wan.blocks." + std::to_string(i), "c"); + } } x = head->forward(ctx, x, e); // [N, t_len*h_len*w_len, pt*ph*pw*out_dim] diff --git a/src/z_image.hpp b/src/z_image.hpp index 363ce5f4f..00b69c264 100644 --- a/src/z_image.hpp +++ b/src/z_image.hpp @@ -31,10 +31,6 @@ namespace ZImage { : head_dim(head_dim), num_heads(num_heads), num_kv_heads(num_kv_heads), qk_norm(qk_norm) { blocks["qkv"] = std::make_shared(hidden_size, (num_heads + num_kv_heads * 2) * head_dim, false); float scale = 1.f; -#if GGML_USE_HIP - // Prevent NaN issues with certain ROCm setups - scale = 1.f / 16.f; -#endif blocks["out"] = std::make_shared(num_heads * head_dim, hidden_size, false, false, false, scale); if (qk_norm) { blocks["q_norm"] = std::make_shared(head_dim); @@ -52,6 +48,10 @@ namespace ZImage { auto qkv_proj = std::dynamic_pointer_cast(blocks["qkv"]); auto out_proj = std::dynamic_pointer_cast(blocks["out"]); + if (sd_backend_is(ctx->backend, "ROCm")) { + out_proj->set_scale(1.f / 16.f); + } + auto qkv = qkv_proj->forward(ctx, x); // [N, n_token, (num_heads + num_kv_heads*2)*head_dim] qkv = ggml_reshape_4d(ctx->ggml_ctx, qkv, head_dim, num_heads + num_kv_heads * 2, qkv->ne[1], qkv->ne[2]); // [N, n_token, num_heads + num_kv_heads*2, head_dim] @@ -115,9 +115,7 @@ namespace ZImage { bool force_prec_f32 = false; float scale = 1.f / 128.f; -#ifdef SD_USE_VULKAN - force_prec_f32 = true; -#endif + // The purpose 
of the scale here is to prevent NaN issues in certain situations. // For example, when using CUDA but the weights are k-quants. blocks["w2"] = std::make_shared(hidden_dim, dim, false, false, force_prec_f32, scale); @@ -129,6 +127,10 @@ namespace ZImage { auto w2 = std::dynamic_pointer_cast(blocks["w2"]); auto w3 = std::dynamic_pointer_cast(blocks["w3"]); + if (sd_backend_is(ctx->backend, "Vulkan")) { + w2->set_force_prec_f32(true); + } + auto x1 = w1->forward(ctx, x); auto x3 = w3->forward(ctx, x); x = ggml_swiglu_split(ctx->ggml_ctx, x1, x3); @@ -369,6 +371,9 @@ namespace ZImage { auto txt = cap_embedder_1->forward(ctx, cap_embedder_0->forward(ctx, context)); // [N, n_txt_token, hidden_size] auto img = x_embedder->forward(ctx, x); // [N, n_img_token, hidden_size] + sd::ggml_graph_cut::mark_graph_cut(txt, "z_image.prelude", "txt"); + sd::ggml_graph_cut::mark_graph_cut(img, "z_image.prelude", "img"); + sd::ggml_graph_cut::mark_graph_cut(t_emb, "z_image.prelude", "t_emb"); int64_t n_txt_pad_token = Rope::bound_mod(static_cast(n_txt_token), SEQ_MULTI_OF); if (n_txt_pad_token > 0) { @@ -391,20 +396,24 @@ namespace ZImage { auto block = std::dynamic_pointer_cast(blocks["context_refiner." + std::to_string(i)]); txt = block->forward(ctx, txt, txt_pe, nullptr, nullptr); + sd::ggml_graph_cut::mark_graph_cut(txt, "z_image.context_refiner." + std::to_string(i), "txt"); } for (int i = 0; i < z_image_params.num_refiner_layers; i++) { auto block = std::dynamic_pointer_cast(blocks["noise_refiner." + std::to_string(i)]); img = block->forward(ctx, img, img_pe, nullptr, t_emb); + sd::ggml_graph_cut::mark_graph_cut(img, "z_image.noise_refiner." 
+ std::to_string(i), "img"); } auto txt_img = ggml_concat(ctx->ggml_ctx, txt, img, 1); // [N, n_txt_token + n_txt_pad_token + n_img_token + n_img_pad_token, hidden_size] + sd::ggml_graph_cut::mark_graph_cut(txt_img, "z_image.prelude", "txt_img"); for (int i = 0; i < z_image_params.num_layers; i++) { auto block = std::dynamic_pointer_cast(blocks["layers." + std::to_string(i)]); txt_img = block->forward(ctx, txt_img, pe, nullptr, t_emb); + sd::ggml_graph_cut::mark_graph_cut(txt_img, "z_image.layers." + std::to_string(i), "txt_img"); } txt_img = final_layer->forward(ctx, txt_img, t_emb); // [N, n_txt_token + n_txt_pad_token + n_img_token + n_img_pad_token, ph*pw*C] diff --git a/thirdparty/CMakeLists.txt b/thirdparty/CMakeLists.txt index a17178507..4dfdf0d29 100644 --- a/thirdparty/CMakeLists.txt +++ b/thirdparty/CMakeLists.txt @@ -2,7 +2,7 @@ set(Z_TARGET zip) add_library(${Z_TARGET} OBJECT zip.c zip.h miniz.h) target_include_directories(${Z_TARGET} PUBLIC .) -if(SD_WEBP) +if(SD_WEBP AND NOT SD_USE_SYSTEM_WEBP) set(WEBP_BUILD_ANIM_UTILS OFF) set(WEBP_BUILD_CWEBP OFF) set(WEBP_BUILD_DWEBP OFF) @@ -18,3 +18,28 @@ if(SD_WEBP) add_subdirectory(libwebp EXCLUDE_FROM_ALL) endif() + +if(SD_WEBM AND NOT SD_USE_SYSTEM_WEBM) + if(MSVC) + set(MSVC_RUNTIME dll) + endif() + set(ENABLE_WEBMTS OFF) + set(ENABLE_WEBMINFO OFF) + set(ENABLE_TESTS OFF) + set(ENABLE_WEBM_PARSER OFF) + set(ENABLE_SAMPLE_PROGRAMS OFF) + + set(SD_LIBWEBM_PARENT_CXX_FLAGS "${CMAKE_CXX_FLAGS}") + + add_subdirectory(libwebm EXCLUDE_FROM_ALL) + + # libwebm mutates the global CMAKE_CXX_FLAGS for non-MSVC compilers to force + # C++11. Restore the parent flags so the main project keeps its own C++17 + # requirements, then pin the libwebm targets to C++17 explicitly. 
+ set(CMAKE_CXX_FLAGS "${SD_LIBWEBM_PARENT_CXX_FLAGS}" CACHE STRING "" FORCE) + target_compile_features(mkvmuxer PRIVATE cxx_std_17) + target_compile_features(mkvparser PRIVATE cxx_std_17) + target_compile_features(webm PRIVATE cxx_std_17) + + target_include_directories(webm INTERFACE ${CMAKE_CURRENT_SOURCE_DIR}/libwebm) +endif() diff --git a/thirdparty/libwebm b/thirdparty/libwebm new file mode 160000 index 000000000..5bf12267e --- /dev/null +++ b/thirdparty/libwebm @@ -0,0 +1 @@ +Subproject commit 5bf12267eea773a32fcf4949de52b0add158a8d5