Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 104 additions & 3 deletions examples/server/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,11 @@ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
log_print(level, log, svr_params->verbose, svr_params->color);
}

struct LoraEntry {
std::string name;
std::string path;
};

int main(int argc, const char** argv) {
if (argc > 1 && std::string(argv[1]) == "--version") {
std::cout << version_string() << "\n";
Expand Down Expand Up @@ -293,6 +298,54 @@ int main(int argc, const char** argv) {

std::mutex sd_ctx_mutex;

std::vector<LoraEntry> lora_cache;
std::mutex lora_mutex;

auto refresh_lora_cache = [&]() {
std::vector<LoraEntry> new_cache;

fs::path lora_dir = ctx_params.lora_model_dir;
if (fs::exists(lora_dir) && fs::is_directory(lora_dir)) {
auto is_lora_ext = [](const fs::path& p) {
auto ext = p.extension().string();
std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower);
return ext == ".gguf" || ext == ".pt" || ext == ".pth" || ext == ".safetensors";
};

for (auto& entry : fs::recursive_directory_iterator(lora_dir)) {
if (!entry.is_regular_file())
continue;
const fs::path& p = entry.path();
if (!is_lora_ext(p))
continue;

LoraEntry e;
e.name = p.stem().u8string();
std::string rel = fs::relative(p, lora_dir).u8string();
std::replace(rel.begin(), rel.end(), '\\', '/');
e.path = rel;

new_cache.push_back(std::move(e));
}
}

std::sort(new_cache.begin(), new_cache.end(),
[](const LoraEntry& a, const LoraEntry& b) {
return a.path < b.path;
});

{
std::lock_guard<std::mutex> lock(lora_mutex);
lora_cache = std::move(new_cache);
}
};

auto is_valid_lora_path = [&](const std::string& path) -> bool {
std::lock_guard<std::mutex> lock(lora_mutex);
return std::any_of(lora_cache.begin(), lora_cache.end(),
[&](const LoraEntry& e) { return e.path == path; });
};

httplib::Server svr;

svr.set_pre_routing_handler([](const httplib::Request& req, httplib::Response& res) {
Expand All @@ -312,7 +365,7 @@ int main(int argc, const char** argv) {
return httplib::Server::HandlerResponse::Unhandled;
});

// health
// root
svr.Get("/", [&](const httplib::Request&, httplib::Response& res) {
if (!svr_params.serve_html_path.empty()) {
std::ifstream file(svr_params.serve_html_path);
Expand Down Expand Up @@ -767,6 +820,37 @@ int main(int argc, const char** argv) {
return bad("prompt required");
}

std::vector<sd_lora_t> sd_loras;
std::vector<std::string> lora_path_storage;

if (j.contains("lora") && j["lora"].is_array()) {
for (const auto& item : j["lora"]) {
if (!item.is_object()) {
continue;
}

std::string path = item.value("path", "");
float multiplier = item.value("multiplier", 1.0f);
bool is_high_noise = item.value("is_high_noise", false);

if (path.empty()) {
return bad("lora.path required");
}

if (!is_valid_lora_path(path)) {
return bad("invalid lora path: " + path);
}

lora_path_storage.push_back(path);
sd_lora_t l;
l.is_high_noise = is_high_noise;
l.multiplier = multiplier;
l.path = lora_path_storage.back().c_str();

sd_loras.push_back(l);
}
}

auto get_sample_method = [](std::string name) -> enum sample_method_t {
enum sample_method_t result = str_to_sample_method(name.c_str());
if (result != SAMPLE_METHOD_COUNT) return result;
Expand Down Expand Up @@ -894,8 +978,8 @@ int main(int argc, const char** argv) {
}

sd_img_gen_params_t img_gen_params = {
gen_params.lora_vec.data(),
static_cast<uint32_t>(gen_params.lora_vec.size()),
sd_loras.data(),
static_cast<uint32_t>(sd_loras.size()),
gen_params.prompt.c_str(),
gen_params.negative_prompt.c_str(),
gen_params.clip_skip,
Expand Down Expand Up @@ -987,6 +1071,23 @@ int main(int argc, const char** argv) {
sdapi_any2img(req, res, true);
});

svr.Get("/sdapi/v1/loras", [&](const httplib::Request&, httplib::Response& res) {
refresh_lora_cache();

json result = json::array();
{
std::lock_guard<std::mutex> lock(lora_mutex);
for (const auto& e : lora_cache) {
json item;
item["name"] = e.name;
item["path"] = e.path;
result.push_back(item);
}
}

res.set_content(result.dump(), "application/json");
});

svr.Get("/sdapi/v1/samplers", [&](const httplib::Request&, httplib::Response& res) {
std::vector<std::string> sampler_names;
sampler_names.push_back("default");
Expand Down
Loading