Skip to content

Commit 4ca7ae3

Browse files
committed
feat: support mmap for model loading
1 parent 43a70e8 commit 4ca7ae3

File tree

8 files changed

+176
-9
lines changed

8 files changed

+176
-9
lines changed

examples/cli/main.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -814,4 +814,4 @@ int main(int argc, const char* argv[]) {
814814
release_all_resources();
815815

816816
return 0;
817-
}
817+
}

examples/common/common.hpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,7 @@ struct SDContextParams {
334334
rng_type_t rng_type = CUDA_RNG;
335335
rng_type_t sampler_rng_type = RNG_TYPE_COUNT;
336336
bool offload_params_to_cpu = false;
337+
bool enable_mmap = true;
337338
bool control_net_cpu = false;
338339
bool clip_on_cpu = false;
339340
bool vae_on_cpu = false;
@@ -469,6 +470,10 @@ struct SDContextParams {
469470
"--offload-to-cpu",
470471
"place the weights in RAM to save VRAM, and automatically load them into VRAM when needed",
471472
true, &offload_params_to_cpu},
473+
{"",
474+
"--no-mmap",
475+
"do not memory-map model files",
476+
false, &enable_mmap},
472477
{"",
473478
"--control-net-cpu",
474479
"keep controlnet in cpu (for low vram)",
@@ -750,6 +755,7 @@ struct SDContextParams {
750755
<< " sampler_rng_type: " << sd_rng_type_name(sampler_rng_type) << ",\n"
751756
<< " flow_shift: " << (std::isinf(flow_shift) ? "INF" : std::to_string(flow_shift)) << "\n"
752757
<< " offload_params_to_cpu: " << (offload_params_to_cpu ? "true" : "false") << ",\n"
758+
<< " enable_mmap: " << (enable_mmap ? "true" : "false") << ",\n"
753759
<< " control_net_cpu: " << (control_net_cpu ? "true" : "false") << ",\n"
754760
<< " clip_on_cpu: " << (clip_on_cpu ? "true" : "false") << ",\n"
755761
<< " vae_on_cpu: " << (vae_on_cpu ? "true" : "false") << ",\n"
@@ -810,6 +816,7 @@ struct SDContextParams {
810816
prediction,
811817
lora_apply_mode,
812818
offload_params_to_cpu,
819+
enable_mmap,
813820
clip_on_cpu,
814821
control_net_cpu,
815822
vae_on_cpu,
@@ -1800,4 +1807,4 @@ uint8_t* load_image_from_memory(const char* image_bytes,
18001807
int expected_height = 0,
18011808
int expected_channel = 3) {
18021809
return load_image_common(true, image_bytes, len, width, height, expected_width, expected_height, expected_channel);
1803-
}
1810+
}

model.cpp

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1340,7 +1340,7 @@ std::string ModelLoader::load_umt5_tokenizer_json() {
13401340
return json_str;
13411341
}
13421342

1343-
bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_threads_p) {
1343+
bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_threads_p, bool enable_mmap) {
13441344
int64_t process_time_ms = 0;
13451345
std::atomic<int64_t> read_time_ms(0);
13461346
std::atomic<int64_t> memcpy_time_ms(0);
@@ -1390,6 +1390,15 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread
13901390
}
13911391
}
13921392

1393+
std::shared_ptr<MmapWrapper> mmapped;
1394+
if (enable_mmap && !is_zip) {
1395+
LOG_DEBUG("using mmap for I/O");
1396+
mmapped = MmapWrapper::create(file_path);
1397+
if (!mmapped) {
1398+
LOG_WARN("failed to memory-map '%s'", file_path.c_str());
1399+
}
1400+
}
1401+
13931402
int n_threads = is_zip ? 1 : std::min(num_threads_to_use, (int)file_tensors.size());
13941403
if (n_threads < 1) {
13951404
n_threads = 1;
@@ -1411,7 +1420,7 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread
14111420
failed = true;
14121421
return;
14131422
}
1414-
} else {
1423+
} else if (!mmapped) {
14151424
file.open(file_path, std::ios::binary);
14161425
if (!file.is_open()) {
14171426
LOG_ERROR("failed to open '%s'", file_path.c_str());
@@ -1464,6 +1473,11 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread
14641473
zip_entry_noallocread(zip, (void*)buf, n);
14651474
}
14661475
zip_entry_close(zip);
1476+
} else if (mmapped) {
1477+
if (!mmapped->copy_data(buf, n, tensor_storage.offset)) {
1478+
LOG_ERROR("read tensor data failed: '%s'", file_path.c_str());
1479+
failed = true;
1480+
}
14671481
} else {
14681482
file.seekg(tensor_storage.offset);
14691483
file.read(buf, n);
@@ -1583,7 +1597,8 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread
15831597

15841598
bool ModelLoader::load_tensors(std::map<std::string, struct ggml_tensor*>& tensors,
15851599
std::set<std::string> ignore_tensors,
1586-
int n_threads) {
1600+
int n_threads,
1601+
bool enable_mmap) {
15871602
std::set<std::string> tensor_names_in_file;
15881603
std::mutex tensor_names_mutex;
15891604
auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool {
@@ -1626,7 +1641,7 @@ bool ModelLoader::load_tensors(std::map<std::string, struct ggml_tensor*>& tenso
16261641
return true;
16271642
};
16281643

1629-
bool success = load_tensors(on_new_tensor_cb, n_threads);
1644+
bool success = load_tensors(on_new_tensor_cb, n_threads, enable_mmap);
16301645
if (!success) {
16311646
LOG_ERROR("load tensors from file failed");
16321647
return false;

model.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -310,10 +310,11 @@ class ModelLoader {
310310
std::map<ggml_type, uint32_t> get_vae_wtype_stat();
311311
String2TensorStorage& get_tensor_storage_map() { return tensor_storage_map; }
312312
void set_wtype_override(ggml_type wtype, std::string tensor_type_rules = "");
313-
bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_threads = 0);
313+
bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_threads = 0, bool use_mmap = false);
314314
bool load_tensors(std::map<std::string, struct ggml_tensor*>& tensors,
315315
std::set<std::string> ignore_tensors = {},
316-
int n_threads = 0);
316+
int n_threads = 0,
317+
bool use_mmap = false);
317318

318319
std::vector<std::string> get_tensor_names() const {
319320
std::vector<std::string> names;

stable-diffusion.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -710,7 +710,7 @@ class StableDiffusionGGML {
710710
if (version == VERSION_SVD) {
711711
ignore_tensors.insert("conditioner.embedders.3");
712712
}
713-
bool success = model_loader.load_tensors(tensors, ignore_tensors, n_threads);
713+
bool success = model_loader.load_tensors(tensors, ignore_tensors, n_threads, sd_ctx_params->enable_mmap);
714714
if (!success) {
715715
LOG_ERROR("load tensors from model loader failed");
716716
ggml_free(ctx);
@@ -2507,6 +2507,7 @@ void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {
25072507
sd_ctx_params->prediction = PREDICTION_COUNT;
25082508
sd_ctx_params->lora_apply_mode = LORA_APPLY_AUTO;
25092509
sd_ctx_params->offload_params_to_cpu = false;
2510+
sd_ctx_params->enable_mmap = true;
25102511
sd_ctx_params->keep_clip_on_cpu = false;
25112512
sd_ctx_params->keep_control_net_on_cpu = false;
25122513
sd_ctx_params->keep_vae_on_cpu = false;

stable-diffusion.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ typedef struct {
182182
enum prediction_t prediction;
183183
enum lora_apply_mode_t lora_apply_mode;
184184
bool offload_params_to_cpu;
185+
bool enable_mmap;
185186
bool keep_clip_on_cpu;
186187
bool keep_control_net_on_cpu;
187188
bool keep_vae_on_cpu;

util.cpp

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,72 @@ bool is_directory(const std::string& path) {
9595
return (attributes != INVALID_FILE_ATTRIBUTES && (attributes & FILE_ATTRIBUTE_DIRECTORY));
9696
}
9797

98+
class MmapWrapperImpl : public MmapWrapper {
99+
public:
100+
MmapWrapperImpl(void* data, size_t size, HANDLE hfile, HANDLE hmapping)
101+
: MmapWrapper(data, size), hfile_(hfile), hmapping_(hmapping) {}
102+
103+
~MmapWrapperImpl() override {
104+
UnmapViewOfFile(data_);
105+
CloseHandle(hmapping_);
106+
CloseHandle(hfile_);
107+
}
108+
109+
private:
110+
HANDLE hfile_;
111+
HANDLE hmapping_;
112+
};
113+
114+
std::shared_ptr<MmapWrapper> MmapWrapper::create(const std::string& filename) {
115+
void* mapped_data = nullptr;
116+
size_t file_size = 0;
117+
118+
HANDLE file_handle = CreateFileA(
119+
filename.c_str(),
120+
GENERIC_READ,
121+
FILE_SHARE_READ,
122+
NULL,
123+
OPEN_EXISTING,
124+
FILE_ATTRIBUTE_NORMAL,
125+
NULL
126+
);
127+
128+
if (file_handle == INVALID_HANDLE_VALUE) {
129+
return nullptr;
130+
}
131+
132+
LARGE_INTEGER size;
133+
if (!GetFileSizeEx(file_handle, &size)) {
134+
CloseHandle(file_handle);
135+
return nullptr;
136+
}
137+
138+
file_size = static_cast<size_t>(size.QuadPart);
139+
140+
HANDLE mapping_handle = CreateFileMapping(file_handle, NULL, PAGE_READONLY, 0, 0, NULL);
141+
142+
if (mapping_handle == NULL) {
143+
CloseHandle(file_handle);
144+
return nullptr;
145+
}
146+
147+
mapped_data = MapViewOfFile(mapping_handle, FILE_MAP_READ, 0, 0, file_size);
148+
149+
if (mapped_data == NULL) {
150+
CloseHandle(mapping_handle);
151+
CloseHandle(file_handle);
152+
return nullptr;
153+
}
154+
155+
return std::make_shared<MmapWrapperImpl>(mapped_data, file_size, file_handle, mapping_handle);
156+
}
157+
98158
#else // Unix
99159
#include <dirent.h>
160+
#include <fcntl.h>
161+
#include <sys/mman.h>
100162
#include <sys/stat.h>
163+
#include <unistd.h>
101164

102165
bool file_exists(const std::string& filename) {
103166
struct stat buffer;
@@ -109,8 +172,64 @@ bool is_directory(const std::string& path) {
109172
return (stat(path.c_str(), &buffer) == 0 && S_ISDIR(buffer.st_mode));
110173
}
111174

175+
class MmapWrapperImpl : public MmapWrapper {
176+
public:
177+
MmapWrapperImpl(void* data, size_t size) : MmapWrapper(data, size) {}
178+
179+
~MmapWrapperImpl() override {
180+
munmap(data_, size_);
181+
}
182+
};
183+
184+
std::shared_ptr<MmapWrapper> MmapWrapper::create(const std::string& filename) {
185+
186+
int file_descriptor = open(filename.c_str(), O_RDONLY);
187+
if (file_descriptor == -1) {
188+
return nullptr;
189+
}
190+
191+
int mmap_flags = MAP_PRIVATE;
192+
193+
#ifdef __linux__
194+
// performance flags used by llama.cpp
195+
//posix_fadvise(file_descriptor, 0, 0, POSIX_FADV_SEQUENTIAL);
196+
//mmap_flags |= MAP_POPULATE;
197+
#endif
198+
199+
struct stat sb;
200+
if (fstat(file_descriptor, &sb) == -1) {
201+
close(file_descriptor);
202+
return nullptr;
203+
}
204+
205+
size_t file_size = sb.st_size;
206+
207+
void* mapped_data = mmap(NULL, file_size, PROT_READ, mmap_flags, file_descriptor, 0);
208+
209+
close(file_descriptor);
210+
211+
if (mapped_data == MAP_FAILED) {
212+
return nullptr;
213+
}
214+
215+
#ifdef __linux__
216+
// performance flags used by llama.cpp
217+
//posix_madvise(mapped_data, file_size, POSIX_MADV_WILLNEED);
218+
#endif
219+
220+
return std::make_shared<MmapWrapperImpl>(mapped_data, file_size);
221+
}
222+
112223
#endif
113224

225+
bool MmapWrapper::copy_data(void* buf, size_t n, size_t offset) const {
226+
if (offset >= size_ || n > (size_ - offset)) {
227+
return false;
228+
}
229+
std::memcpy(buf, data() + offset, n);
230+
return true;
231+
}
232+
114233
// get_num_physical_cores is copy from
115234
// https://github.com/ggerganov/llama.cpp/blob/master/examples/common.cpp
116235
// LICENSE: https://github.com/ggerganov/llama.cpp/blob/master/LICENSE

util.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#define __UTIL_H__
33

44
#include <cstdint>
5+
#include <memory>
56
#include <string>
67
#include <vector>
78

@@ -43,6 +44,28 @@ sd_image_f32_t resize_sd_image_f32_t(sd_image_f32_t image, int target_width, int
4344

4445
sd_image_f32_t clip_preprocess(sd_image_f32_t image, int target_width, int target_height);
4546

47+
// Read-only view over a memory-mapped file.
//
// Instances are obtained through create(); the constructor is protected so
// that only the platform-specific subclass can build one. The object is
// neither copyable nor movable.
class MmapWrapper {
public:
    // Maps `filename` and returns the wrapper, or nullptr when mapping fails.
    static std::shared_ptr<MmapWrapper> create(const std::string& filename);

    virtual ~MmapWrapper() = default;

    // Copy and move are both disabled.
    MmapWrapper(const MmapWrapper&)            = delete;
    MmapWrapper& operator=(const MmapWrapper&) = delete;
    MmapWrapper(MmapWrapper&&)                 = delete;
    MmapWrapper& operator=(MmapWrapper&&)      = delete;

    // First byte of the mapped region.
    const uint8_t* data() const {
        return static_cast<const uint8_t*>(data_);
    }

    // Length of the mapped region in bytes.
    size_t size() const {
        return size_;
    }

    // Copies `n` bytes starting at `offset` into `buf`; returns false when the
    // requested range does not fit inside the mapping.
    bool copy_data(void* buf, size_t n, size_t offset) const;

protected:
    MmapWrapper(void* data, size_t size)
        : data_(data), size_(size) {
    }

    void* data_  = nullptr;  // base address of the mapping
    size_t size_ = 0;        // mapping length in bytes
};
68+
4669
std::string path_join(const std::string& p1, const std::string& p2);
4770
std::vector<std::string> split_string(const std::string& str, char delimiter);
4871
void pretty_progress(int step, int steps, float time);

0 commit comments

Comments
 (0)