diff --git a/examples/server/main.cpp b/examples/server/main.cpp index def499755..4de46fa6a 100644 --- a/examples/server/main.cpp +++ b/examples/server/main.cpp @@ -263,6 +263,11 @@ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) { log_print(level, log, svr_params->verbose, svr_params->color); } +struct LoraEntry { + std::string name; + std::string path; +}; + int main(int argc, const char** argv) { if (argc > 1 && std::string(argv[1]) == "--version") { std::cout << version_string() << "\n"; @@ -293,6 +298,54 @@ int main(int argc, const char** argv) { std::mutex sd_ctx_mutex; + std::vector lora_cache; + std::mutex lora_mutex; + + auto refresh_lora_cache = [&]() { + std::vector new_cache; + + fs::path lora_dir = ctx_params.lora_model_dir; + if (fs::exists(lora_dir) && fs::is_directory(lora_dir)) { + auto is_lora_ext = [](const fs::path& p) { + auto ext = p.extension().string(); + std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower); + return ext == ".gguf" || ext == ".pt" || ext == ".pth" || ext == ".safetensors"; + }; + + for (auto& entry : fs::recursive_directory_iterator(lora_dir)) { + if (!entry.is_regular_file()) + continue; + const fs::path& p = entry.path(); + if (!is_lora_ext(p)) + continue; + + LoraEntry e; + e.name = p.stem().u8string(); + std::string rel = fs::relative(p, lora_dir).u8string(); + std::replace(rel.begin(), rel.end(), '\\', '/'); + e.path = rel; + + new_cache.push_back(std::move(e)); + } + } + + std::sort(new_cache.begin(), new_cache.end(), + [](const LoraEntry& a, const LoraEntry& b) { + return a.path < b.path; + }); + + { + std::lock_guard lock(lora_mutex); + lora_cache = std::move(new_cache); + } + }; + + auto is_valid_lora_path = [&](const std::string& path) -> bool { + std::lock_guard lock(lora_mutex); + return std::any_of(lora_cache.begin(), lora_cache.end(), + [&](const LoraEntry& e) { return e.path == path; }); + }; + httplib::Server svr; svr.set_pre_routing_handler([](const httplib::Request& req, httplib::Response& res) { @@ -312,7 +365,7 @@ int main(int argc, const char** argv) { return httplib::Server::HandlerResponse::Unhandled; }); - // health + // root svr.Get("/", [&](const httplib::Request&, httplib::Response& res) { if (!svr_params.serve_html_path.empty()) { std::ifstream file(svr_params.serve_html_path); @@ -767,6 +820,37 @@ int main(int argc, const char** argv) { return bad("prompt required"); } + std::vector sd_loras; + std::vector lora_path_storage; + + if (j.contains("lora") && j["lora"].is_array()) { + for (const auto& item : j["lora"]) { + if (!item.is_object()) { + continue; + } + + std::string path = item.value("path", ""); + float multiplier = item.value("multiplier", 1.0f); + bool is_high_noise = item.value("is_high_noise", false); + + if (path.empty()) { + return bad("lora.path required"); + } + + if (!is_valid_lora_path(path)) { + return bad("invalid lora path: " + path); + } + + lora_path_storage.push_back(path); + sd_lora_t l; + l.is_high_noise = is_high_noise; + l.multiplier = multiplier; + l.path = lora_path_storage.back().c_str(); + + sd_loras.push_back(l); + } + } + auto get_sample_method = [](std::string name) -> enum sample_method_t { enum sample_method_t result = str_to_sample_method(name.c_str()); if (result != SAMPLE_METHOD_COUNT) return result; @@ -894,8 +978,8 @@ int main(int argc, const char** argv) { } sd_img_gen_params_t img_gen_params = { - gen_params.lora_vec.data(), - static_cast(gen_params.lora_vec.size()), + sd_loras.data(), + static_cast(sd_loras.size()), gen_params.prompt.c_str(), gen_params.negative_prompt.c_str(), gen_params.clip_skip, @@ -987,6 +1071,23 @@ int main(int argc, const char** argv) { sdapi_any2img(req, res, true); }); + svr.Get("/sdapi/v1/loras", [&](const httplib::Request&, httplib::Response& res) { + refresh_lora_cache(); + + json result = json::array(); + { + std::lock_guard lock(lora_mutex); + for (const auto& e : lora_cache) { + json item; + item["name"] = e.name; + item["path"] = e.path; + result.push_back(item); + } + } + + res.set_content(result.dump(), "application/json"); + }); + svr.Get("/sdapi/v1/samplers", [&](const httplib::Request&, httplib::Response& res) { std::vector sampler_names; sampler_names.push_back("default");