Skip to content

Commit 0364178

Browse files
ngxson and ggerganov
authored
clip : refactor clip_init, add tests (#12757)
* refactor clip_init * fix loading file * fix style * test ok * better test with report * add missing headers * clarify * add KEY_MM_PATCH_MERGE_TYPE * remove bool has_* pattern * Apply suggestions from code review Co-authored-by: Georgi Gerganov <[email protected]> * Update examples/llava/clip.cpp Co-authored-by: Georgi Gerganov <[email protected]> * use ggml_soft_max_ext * refactor logging system * add minicpm-v-o 2.6 for testing * use nullptr everywhere * fix Yi-VL model --------- Co-authored-by: Georgi Gerganov <[email protected]>
1 parent c6ff5d2 commit 0364178

File tree

9 files changed

+920
-891
lines changed

9 files changed

+920
-891
lines changed

examples/llava/clip-impl.h

+273
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,273 @@
1+
#include "ggml.h"
2+
#include "gguf.h"
3+
4+
#include <climits>
5+
#include <cstdarg>
6+
#include <string>
7+
#include <map>
8+
#include <sstream>
9+
#include <vector>
10+
11+
// Internal header for clip.cpp
12+
13+
// GGUF metadata key names.
// Keys containing "%s" are printf-style templates that must be formatted before use
// (presumably with the sub-model name such as "vision" or "text" -- confirm against clip.cpp).
#define KEY_FTYPE               "general.file_type"
#define KEY_NAME                "general.name"
#define KEY_DESCRIPTION         "general.description"
#define KEY_HAS_TEXT_ENC        "clip.has_text_encoder"
#define KEY_HAS_VIS_ENC         "clip.has_vision_encoder"
#define KEY_HAS_LLAVA_PROJ      "clip.has_llava_projector"
#define KEY_HAS_MINICPMV_PROJ   "clip.has_minicpmv_projector"
#define KEY_HAS_GLM_PROJ        "clip.has_glm_projector"
#define KEY_MINICPMV_VERSION    "clip.minicpmv_version"
#define KEY_HAS_QWEN2VL_MERGER  "clip.has_qwen2vl_merger"
#define KEY_USE_GELU            "clip.use_gelu"
#define KEY_USE_SILU            "clip.use_silu"
#define KEY_N_EMBD              "clip.%s.embedding_length"
#define KEY_N_FF                "clip.%s.feed_forward_length"
#define KEY_N_BLOCK             "clip.%s.block_count"
#define KEY_N_HEAD              "clip.%s.attention.head_count"
#define KEY_LAYER_NORM_EPS      "clip.%s.attention.layer_norm_epsilon"
#define KEY_PROJ_DIM            "clip.%s.projection_dim"
#define KEY_TOKENS              "tokenizer.ggml.tokens"
#define KEY_N_POSITIONS         "clip.text.context_length"
#define KEY_IMAGE_SIZE          "clip.vision.image_size"
#define KEY_PATCH_SIZE          "clip.vision.patch_size"
#define KEY_IMAGE_MEAN          "clip.vision.image_mean"
#define KEY_IMAGE_STD           "clip.vision.image_std"
#define KEY_PROJ_TYPE           "clip.projector_type"
#define KEY_FEATURE_LAYER       "clip.vision.feature_layer"

#define KEY_MM_PATCH_MERGE_TYPE   "clip.vision.mm_patch_merge_type"
#define KEY_IMAGE_GRID_PINPOINTS  "clip.vision.image_grid_pinpoints"
#define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution"
43+
44+
45+
//
46+
// tensor name constants
47+
//
48+
49+
// Tensor name templates. "%s" / "%d" placeholders are printf-style and must be
// formatted before the name is used for a lookup.
#define TN_TOKEN_EMBD      "%s.token_embd.weight"
#define TN_POS_EMBD        "%s.position_embd.weight"
#define TN_CLASS_EMBD      "v.class_embd"
#define TN_PATCH_EMBD      "v.patch_embd.weight"  // the tensor is not renamed with a ".0" postfix, for backward compatibility
#define TN_PATCH_EMBD_1    "v.patch_embd.weight.1"
#define TN_PATCH_BIAS      "v.patch_embd.bias"
#define TN_ATTN_K          "%s.blk.%d.attn_k.%s"
#define TN_ATTN_Q          "%s.blk.%d.attn_q.%s"
#define TN_ATTN_V          "%s.blk.%d.attn_v.%s"
#define TN_ATTN_OUTPUT     "%s.blk.%d.attn_out.%s"
#define TN_FFN_DOWN        "%s.blk.%d.ffn_down.%s"
#define TN_FFN_UP          "%s.blk.%d.ffn_up.%s"
#define TN_LN_1            "%s.blk.%d.ln1.%s"
#define TN_LN_2            "%s.blk.%d.ln2.%s"
#define TN_LN_PRE          "%s.pre_ln.%s"
#define TN_LN_POST         "%s.post_ln.%s"
#define TN_TEXT_PROJ       "text_projection.weight"
#define TN_VIS_PROJ        "visual_projection.weight"
#define TN_LLAVA_PROJ      "mm.%d.%s"
#define TN_MVLM_PROJ_MLP   "mm.model.mlp.%d.%s"
#define TN_MVLM_PROJ_BLOCK "mm.model.mb_block.%d.block.%d.%s"
#define TN_MVLM_PROJ_PEG   "mm.model.peg.%d.%s"
#define TN_IMAGE_NEWLINE   "model.image_newline"
#define TN_MM_INP_PROJ     "mm.input_projection.weight" // gemma3
#define TN_MM_SOFT_EMB_N   "mm.soft_emb_norm.weight"    // gemma3

// minicpmv
#define TN_MINICPMV_POS_EMBD_K "resampler.pos_embed_k"
#define TN_MINICPMV_QUERY      "resampler.query"
#define TN_MINICPMV_PROJ       "resampler.proj.weight"
#define TN_MINICPMV_KV_PROJ    "resampler.kv.weight"
#define TN_MINICPMV_ATTN       "resampler.attn.%s.%s"
#define TN_MINICPMV_LN         "resampler.ln_%s.%s"

// NOTE: "ADAPER" in the first macro name below is a historical misspelling; the
// identifier is kept as-is so that existing users of the macro keep compiling.
#define TN_GLM_ADAPER_CONV      "adapter.conv.%s"
#define TN_GLM_ADAPTER_LINEAR   "adapter.linear.linear.%s"
#define TN_GLM_ADAPTER_NORM_1   "adapter.linear.norm1.%s"
#define TN_GLM_ADAPTER_D_H_2_4H "adapter.linear.dense_h_to_4h.%s"
#define TN_GLM_ADAPTER_GATE     "adapter.linear.gate.%s"
#define TN_GLM_ADAPTER_D_4H_2_H "adapter.linear.dense_4h_to_h.%s"
#define TN_GLM_BOI_W            "adapter.boi"
#define TN_GLM_EOI_W            "adapter.eoi"
91+
92+
// Identifies the kind of projector used by a model.
// Enumerator order (and therefore the numeric values) is unchanged.
enum projector_type {
    PROJECTOR_TYPE_MLP,
    PROJECTOR_TYPE_MLP_NORM,
    PROJECTOR_TYPE_LDP,
    PROJECTOR_TYPE_LDPV2,
    PROJECTOR_TYPE_RESAMPLER,
    PROJECTOR_TYPE_GLM_EDGE,
    PROJECTOR_TYPE_MERGER,
    PROJECTOR_TYPE_GEMMA3,
    PROJECTOR_TYPE_UNKNOWN, // sentinel: no known projector type
};
103+
104+
static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
105+
{ PROJECTOR_TYPE_MLP, "mlp" },
106+
{ PROJECTOR_TYPE_LDP, "ldp" },
107+
{ PROJECTOR_TYPE_LDPV2, "ldpv2"},
108+
{ PROJECTOR_TYPE_RESAMPLER, "resampler"},
109+
{ PROJECTOR_TYPE_GLM_EDGE, "adapter"},
110+
{ PROJECTOR_TYPE_MERGER, "qwen2vl_merger"},
111+
{ PROJECTOR_TYPE_GEMMA3, "gemma3"},
112+
};
113+
114+
// Reverse lookup in PROJECTOR_TYPE_NAMES: returns the projector type whose name is
// exactly `str`, or PROJECTOR_TYPE_UNKNOWN when no entry matches.
static projector_type clip_projector_type_from_string(const std::string & str) {
    for (const auto & pair : PROJECTOR_TYPE_NAMES) {
        if (pair.second == str) {
            return pair.first;
        }
    }
    return PROJECTOR_TYPE_UNKNOWN;
}
122+
123+
//
124+
// logging
125+
//
126+
127+
// Log callback that writes `text` to stderr and flushes immediately.
// `level` and `user_data` are accepted to match the callback signature but are ignored.
static void clip_log_callback_default(enum ggml_log_level level, const char * text, void * user_data) {
    (void) level;
    (void) user_data;
    fputs(text, stderr);
    fflush(stderr);
}
133+
134+
struct clip_logger_state {
135+
ggml_log_level verbosity_thold;
136+
ggml_log_callback log_callback;
137+
void * log_callback_user_data;
138+
};
139+
140+
extern struct clip_logger_state g_logger_state;
141+
142+
// Formats `format` with `args` and hands the result to g_logger_state.log_callback.
//
// Messages shorter than the 128-byte stack buffer are formatted in place; longer ones
// are formatted a second time into a heap buffer of the exact required size.
// A null `format` is ignored. If formatting fails (vsnprintf returns a negative value)
// the message is dropped, because the buffer contents are not usable in that case.
// If the heap allocation fails, the truncated stack-buffer text is logged instead.
static void clip_log_internal_v(enum ggml_log_level level, const char * format, va_list args) {
    if (format == nullptr) {
        return;
    }

    // `args` is consumed by the first vsnprintf call, so keep a copy for the retry
    va_list args_copy;
    va_copy(args_copy, args);

    char buffer[128];
    const int len = vsnprintf(buffer, sizeof(buffer), format, args);

    if (len < 0) {
        // formatting error: nothing meaningful to log
    } else if ((size_t) len < sizeof(buffer)) {
        g_logger_state.log_callback(level, buffer, g_logger_state.log_callback_user_data);
    } else {
        const size_t needed = (size_t) len + 1; // +1 for the terminating NUL
        char * buffer2 = (char *) calloc(needed, sizeof(char));
        if (buffer2 != nullptr) {
            vsnprintf(buffer2, needed, format, args_copy);
            buffer2[len] = 0;
            g_logger_state.log_callback(level, buffer2, g_logger_state.log_callback_user_data);
            free(buffer2);
        } else {
            // out of memory: `buffer` still holds a NUL-terminated truncated message
            g_logger_state.log_callback(level, buffer, g_logger_state.log_callback_user_data);
        }
    }

    va_end(args_copy);
}
161+
162+
// printf-style front end for clip_log_internal_v(): packs the variadic arguments into a
// va_list and forwards them together with `level` and `format`.
static void clip_log_internal(enum ggml_log_level level, const char * format, ...) {
    va_list args;
    va_start(args, format);
    clip_log_internal_v(level, format, args);
    va_end(args);
}
168+
169+
// Logging macros. LOG_TMPL forwards to clip_log_internal() only when `level` is at or
// above g_logger_state.verbosity_thold, so the format arguments of suppressed messages
// are never evaluated. The do/while(0) wrapper makes the macro behave as one statement.
#define LOG_TMPL(level, ...) \
    do { \
        if ((level) >= g_logger_state.verbosity_thold) { \
            clip_log_internal((level), __VA_ARGS__); \
        } \
    } while (0)
#define LOG_INF(...) LOG_TMPL(GGML_LOG_LEVEL_INFO,  __VA_ARGS__)
#define LOG_WRN(...) LOG_TMPL(GGML_LOG_LEVEL_WARN,  __VA_ARGS__)
#define LOG_ERR(...) LOG_TMPL(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
#define LOG_DBG(...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__)
#define LOG_CNT(...) LOG_TMPL(GGML_LOG_LEVEL_CONT,  __VA_ARGS__)
180+
181+
//
182+
// common utils
183+
//
184+
185+
// printf-style formatting into a std::string.
//
// The format is evaluated twice: once with a null buffer to measure the output, then
// into a buffer of exactly that size. Aborts via GGML_ASSERT if vsnprintf reports an
// error or the two passes disagree on the length.
//
// The returned string holds exactly the formatted characters. The terminating NUL that
// vsnprintf writes is NOT part of the result, so size(), operator== and concatenation
// behave as expected.
static std::string string_format(const char * fmt, ...) {
    va_list ap;
    va_list ap2;
    va_start(ap, fmt);
    va_copy(ap2, ap); // the measuring pass consumes `ap`, the writing pass needs its own copy

    const int size = vsnprintf(nullptr, 0, fmt, ap);
    GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT

    std::vector<char> buf(size + 1); // +1 for the terminating NUL
    const int size2 = vsnprintf(buf.data(), buf.size(), fmt, ap2);
    GGML_ASSERT(size2 == size);

    va_end(ap2);
    va_end(ap);

    // use `size`, not buf.size(): the latter would embed the trailing NUL in the string
    return std::string(buf.data(), size);
}
199+
200+
// Replaces every non-overlapping occurrence of `search` in `s` with `replace`, in place.
//
// The scan runs left to right over the ORIGINAL text, so text inserted by a replacement
// is never re-scanned (replacing "a" with "aa" terminates). An empty `search` leaves
// `s` untouched.
static void string_replace_all(std::string & s, const std::string & search, const std::string & replace) {
    if (search.empty()) {
        return;
    }

    std::string builder;
    builder.reserve(s.length());

    size_t pos      = 0;
    size_t last_pos = 0;
    while ((pos = s.find(search, last_pos)) != std::string::npos) {
        builder.append(s, last_pos, pos - last_pos); // text between the previous match and this one
        builder.append(replace);
        last_pos = pos + search.length();
    }
    builder.append(s, last_pos, std::string::npos); // tail after the last match

    s = std::move(builder);
}
216+
217+
//
218+
// gguf utils
219+
//
220+
221+
static std::string gguf_data_to_str(enum gguf_type type, const void * data, int i) {
222+
switch (type) {
223+
case GGUF_TYPE_UINT8: return std::to_string(((const uint8_t *)data)[i]);
224+
case GGUF_TYPE_INT8: return std::to_string(((const int8_t *)data)[i]);
225+
case GGUF_TYPE_UINT16: return std::to_string(((const uint16_t *)data)[i]);
226+
case GGUF_TYPE_INT16: return std::to_string(((const int16_t *)data)[i]);
227+
case GGUF_TYPE_UINT32: return std::to_string(((const uint32_t *)data)[i]);
228+
case GGUF_TYPE_INT32: return std::to_string(((const int32_t *)data)[i]);
229+
case GGUF_TYPE_UINT64: return std::to_string(((const uint64_t *)data)[i]);
230+
case GGUF_TYPE_INT64: return std::to_string(((const int64_t *)data)[i]);
231+
case GGUF_TYPE_FLOAT32: return std::to_string(((const float *)data)[i]);
232+
case GGUF_TYPE_FLOAT64: return std::to_string(((const double *)data)[i]);
233+
case GGUF_TYPE_BOOL: return ((const bool *)data)[i] ? "true" : "false";
234+
default: return string_format("unknown type %d", type);
235+
}
236+
}
237+
238+
static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) {
239+
const enum gguf_type type = gguf_get_kv_type(ctx_gguf, i);
240+
241+
switch (type) {
242+
case GGUF_TYPE_STRING:
243+
return gguf_get_val_str(ctx_gguf, i);
244+
case GGUF_TYPE_ARRAY:
245+
{
246+
const enum gguf_type arr_type = gguf_get_arr_type(ctx_gguf, i);
247+
int arr_n = gguf_get_arr_n(ctx_gguf, i);
248+
const void * data = arr_type == GGUF_TYPE_STRING ? nullptr : gguf_get_arr_data(ctx_gguf, i);
249+
std::stringstream ss;
250+
ss << "[";
251+
for (int j = 0; j < arr_n; j++) {
252+
if (arr_type == GGUF_TYPE_STRING) {
253+
std::string val = gguf_get_arr_str(ctx_gguf, i, j);
254+
// escape quotes
255+
string_replace_all(val, "\\", "\\\\");
256+
string_replace_all(val, "\"", "\\\"");
257+
ss << '"' << val << '"';
258+
} else if (arr_type == GGUF_TYPE_ARRAY) {
259+
ss << "???";
260+
} else {
261+
ss << gguf_data_to_str(arr_type, data, j);
262+
}
263+
if (j < arr_n - 1) {
264+
ss << ", ";
265+
}
266+
}
267+
ss << "]";
268+
return ss.str();
269+
}
270+
default:
271+
return gguf_data_to_str(type, gguf_get_val_data(ctx_gguf, i), 0);
272+
}
273+
}

0 commit comments

Comments
 (0)