
Commit 0d38434

Correct the format
1 parent f0f5d79 commit 0d38434

4 files changed: 35 additions & 40 deletions


modelcache/embedding/data2vec.py

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@


 def mean_pooling(model_output, attention_mask):
-    token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
+    token_embeddings = model_output[0]
     input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
     return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

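The change only drops a trailing comment from mean_pooling, which averages token embeddings weighted by the attention mask so padded positions are ignored. Below is a minimal, runnable sketch of the same computation; the dummy tensor shapes in the example call are illustrative assumptions, not taken from the repository.

import torch

def mean_pooling(model_output, attention_mask):
    # model_output[0]: token embeddings with shape (batch, seq_len, hidden)
    token_embeddings = model_output[0]
    # Expand the mask to (batch, seq_len, hidden) so padded tokens contribute zero
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    # Sum the real-token embeddings and divide by the real-token count (clamped to avoid /0)
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

# Illustrative call with dummy tensors (shapes are hypothetical)
dummy_output = (torch.randn(2, 5, 8),)   # batch=2, seq_len=5, hidden=8
dummy_mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
print(mean_pooling(dummy_output, dummy_mask).shape)   # torch.Size([2, 8])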

modelcache_mm/adapter/adapter_query.py

Lines changed: 30 additions & 31 deletions
@@ -47,37 +47,36 @@ def adapt_query(cache_data_convert, *args, **kwargs):
     else:
         raise MultiTypeError

-    embedding_data = None
-    mm_type = None
-    if cache_enable:
-        if pre_multi_type == 'IMG_TEXT':
-            embedding_data_resp = time_cal(
-                chat_cache.embedding_func,
-                func_name="iat_embedding",
-                report_func=chat_cache.report.embedding,
-            )(data_dict)
-        else:
-            embedding_data_resp = time_cal(
-                chat_cache.embedding_func,
-                func_name="iat_embedding",
-                report_func=chat_cache.report.embedding,
-            )(data_dict)
-        image_embeddings = embedding_data_resp['image_embedding']
-        text_embeddings = embedding_data_resp['text_embeddings']
-
-        if len(image_embeddings) > 0 and len(image_embeddings) > 0:
-            embedding_data = np.concatenate((image_embeddings, text_embeddings))
-            mm_type = 'mm'
-        elif len(image_embeddings) > 0:
-            image_embedding = np.array(image_embeddings[0])
-            embedding_data = image_embedding
-            mm_type = 'image'
-        elif len(text_embeddings) > 0:
-            text_embedding = np.array(text_embeddings[0])
-            embedding_data = text_embedding
-            mm_type = 'text'
-        else:
-            raise ValueError('maya embedding service return both empty list, please check!')
+    # embedding_data = None
+    # mm_type = None
+    if pre_multi_type == 'IMG_TEXT':
+        embedding_data_resp = time_cal(
+            chat_cache.embedding_func,
+            func_name="mm_embedding",
+            report_func=chat_cache.report.embedding,
+        )(data_dict)
+    else:
+        embedding_data_resp = time_cal(
+            chat_cache.embedding_func,
+            func_name="mm_embedding",
+            report_func=chat_cache.report.embedding,
+        )(data_dict)
+    image_embeddings = embedding_data_resp['image_embedding']
+    text_embeddings = embedding_data_resp['text_embeddings']
+
+    if len(image_embeddings) > 0 and len(image_embeddings) > 0:
+        embedding_data = np.concatenate((image_embeddings, text_embeddings))
+        # mm_type = 'mm'
+    elif len(image_embeddings) > 0:
+        image_embedding = np.array(image_embeddings[0])
+        embedding_data = image_embedding
+        # mm_type = 'image'
+    elif len(text_embeddings) > 0:
+        text_embedding = np.array(text_embeddings[0])
+        embedding_data = text_embedding
+        # mm_type = 'text'
+    else:
+        raise ValueError('maya embedding service return both empty list, please check!')

     if cache_enable:
         cache_data_list = time_cal(
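The reworked block selects a single embedding for the request: when the service returns both image and text embeddings they are concatenated, otherwise the first embedding of whichever modality is non-empty is used, and an empty response raises. Below is a minimal sketch of that selection logic, assuming the first branch is meant to check both modalities; the helper name select_embedding and the dummy inputs are illustrative, not code from the repository.

import numpy as np

def select_embedding(image_embeddings, text_embeddings):
    # Assumed intent of the first branch: both modalities present -> concatenate them
    if len(image_embeddings) > 0 and len(text_embeddings) > 0:
        return np.concatenate((image_embeddings, text_embeddings))
    elif len(image_embeddings) > 0:
        # Image only: use the first image embedding
        return np.array(image_embeddings[0])
    elif len(text_embeddings) > 0:
        # Text only: use the first text embedding
        return np.array(text_embeddings[0])
    else:
        raise ValueError('maya embedding service return both empty list, please check!')

# Illustrative usage with dummy vectors
print(select_embedding([np.ones(4)], [np.zeros(4)]).shape)   # (2, 4): concatenated modalities
print(select_embedding([], [np.zeros(4)]).shape)             # (4,): text-only path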

modelcache_mm/embedding/clip.py

Lines changed: 2 additions & 6 deletions
@@ -18,8 +18,8 @@ def to_embeddings(self, data_dict, **_):
         text_list = data_dict['text']
         image_data = data_dict['image']

-        img_data = None
-        txt_data = None
+        # img_data = None
+        # txt_data = None

         if image_data:
             input_img = load_image(image_data)
@@ -46,8 +46,4 @@ def post_proc(self, token_embeddings, inputs):

     @property
     def dimension(self):
-        """Embedding dimension.
-
-        :return: embedding dimension
-        """
         return self.__dimension
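The removed docstring belonged to the dimension property, which simply exposes a private attribute set at construction time. A minimal sketch of that read-only property pattern, assuming a hypothetical class name and default value:

class ClipLikeEmbedding:
    def __init__(self, dimension: int = 512):
        # Name-mangled private attribute backing the read-only property
        self.__dimension = dimension

    @property
    def dimension(self):
        return self.__dimension

emb = ClipLikeEmbedding()
print(emb.dimension)   # 512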

modelcache_mm/manager/data_manager.py

Lines changed: 2 additions & 2 deletions
@@ -34,7 +34,7 @@ def save_query_resp(self, query_resp_dict, **kwargs):
     @abstractmethod
     def import_data(
             self, texts: List[Any], image_urls: List[Any], image_ids: List[Any], answers: List[Answer],
-            embeddings: List[Any], model: Any, iat_type: Any
+            embeddings: List[Any], model: Any, mm_type: Any
    ):
         pass

@@ -96,7 +96,7 @@ def save_query_resp(self, query_resp_dict, **kwargs):

     def import_data(
             self, texts: List[Any], image_urls: List[Any], image_ids: List[Any], answers: List[Answer],
-            embeddings: List[Any], model: Any, iat_type: Any
+            embeddings: List[Any], model: Any, mm_type: Any
    ):
         pass

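Both hunks rename the iat_type parameter to mm_type in the import_data signature, once in the abstract declaration and once in the concrete stub. As a rough sketch of that pattern, here is how such an abstract method and its implementation typically line up; the class names and the use of List[Any] in place of List[Answer] are assumptions made to keep the example self-contained.

from abc import ABCMeta, abstractmethod
from typing import Any, List

class DataManagerSketch(metaclass=ABCMeta):
    @abstractmethod
    def import_data(
            self, texts: List[Any], image_urls: List[Any], image_ids: List[Any], answers: List[Any],
            embeddings: List[Any], model: Any, mm_type: Any
    ):
        pass

class MapDataManagerSketch(DataManagerSketch):
    def import_data(
            self, texts: List[Any], image_urls: List[Any], image_ids: List[Any], answers: List[Any],
            embeddings: List[Any], model: Any, mm_type: Any
    ):
        # No-op stub; a real manager would persist the records
        pass

MapDataManagerSketch().import_data([], [], [], [], [], model='m', mm_type='mm')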
