Commit f0eed80

align InfoNCE with Qwen3-Embedding (#6420)
1 parent e14bcd3 commit f0eed80

File tree

3 files changed: +113, -20 lines changed


docs/source/BestPractices/Embedding训练.md

Lines changed: 7 additions & 4 deletions
@@ -86,10 +86,13 @@ The source code of the loss can be found [here](https://github.com/modelscope/ms-swift/blob/ma
 ```

 The InfoNCE loss supports several environment variables:
-1. INFONCE_TEMPERATURE the temperature parameter; defaults to 0.01 if not set
-2. INFONCE_USE_BATCH whether to use the sample's own `negative_messages` (hard-negative examples) or the other samples in the batch as in-batch negatives; defaults to True, meaning the other samples in the batch are used as negatives
-3. INFONCE_HARD_NEGATIVES the number of hard negatives; if not set, all `negative_messages` provided in the data are used. Since their lengths may differ, the loss is computed with a for loop (slower). If set to a number, missing negatives are padded by random sampling and excess ones are truncated to the first `INFONCE_HARD_NEGATIVES`
-4. INFONCE_MASK_FAKE_NEGATIVE mask out fake negatives. Defaults to False. When enabled, the positive sample's similarity + 0.1 is used as a threshold, and any sample whose similarity is larger has its similarity set to -inf, preventing positive-sample leakage
+1. `INFONCE_TEMPERATURE`: the temperature parameter; defaults to 0.01 if not set
+2. `INFONCE_USE_BATCH`: whether to use the sample's own `negative_messages` (hard-negative examples) or the other samples in the batch as in-batch negatives; defaults to True, meaning the other samples in the batch are used as negatives
+3. `INFONCE_HARD_NEGATIVES`: the number of hard negatives; if not set, all `negative_messages` provided in the data are used. Since their lengths may differ, the loss is computed with a for loop (slower). If set to a number, missing negatives are padded by random sampling and excess ones are truncated to the first `INFONCE_HARD_NEGATIVES` (see the sketch after this diff)
+4. `INFONCE_MASK_FAKE_NEGATIVE`: mask out fake negatives. Defaults to False. When enabled, the threshold `positive_similarity + INFONCE_FAKE_NEG_MARGIN` is computed, and any sample whose similarity exceeds it has its similarity set to `-inf`, preventing positive-sample leakage
+5. `INFONCE_FAKE_NEG_MARGIN`: the margin used for fake-negative masking, default `0.1`
+6. `INFONCE_INCLUDE_QQ`: whether to include the q–q block (query–query similarities) in the denominator as negatives, default `False`
+7. `INFONCE_INCLUDE_DD`: whether to include the d–d block (similarities between the positive document and all in-batch documents) in the denominator as negatives, default `False`

 > You can also make the number of hard negatives equal across samples in the dataset; then the for loop is avoided even if this variable is not set, which speeds up the computation
 > `negative_messages` may also be omitted. In that case, keep `INFONCE_USE_BATCH=True` and the other samples in the batch will be used as negatives
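The padding/truncation rule described in item 3 of the list above can be sketched in isolation. This is a minimal illustration only, not the code in `swift/plugin/loss.py`; `resize_hard_negatives` is a hypothetical helper name:

```python
import random
from typing import List


def resize_hard_negatives(negatives: List[str], n: int) -> List[str]:
    """Hypothetical helper: return exactly n hard negatives for one sample.

    Assumes at least one negative is provided.
    """
    if len(negatives) >= n:
        # too many: keep the first n (mirrors "truncate to the first INFONCE_HARD_NEGATIVES")
        return negatives[:n]
    # too few: pad by randomly re-sampling from the provided negatives
    padding = [random.choice(negatives) for _ in range(n - len(negatives))]
    return negatives + padding


# e.g. INFONCE_HARD_NEGATIVES=4 with only 2 provided negatives
print(resize_hard_negatives(['neg_a', 'neg_b'], 4))  # 2 originals + 2 re-sampled
```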

docs/source_en/BestPractices/Embedding.md

Lines changed: 4 additions & 1 deletion
@@ -90,7 +90,10 @@ InfoNCE loss supports the following environment variables:
 1. `INFONCE_TEMPERATURE`: The temperature parameter. If not set, the default value is 0.01.
 2. `INFONCE_USE_BATCH`: Use `negative_messages` within the sample (hard negatives) or use other samples in the batch as in-batch negatives. The default is `True`, which means using in-batch negatives.
 3. `INFONCE_HARD_NEGATIVES`: The number of hard negatives. If not set, all provided `negative_messages` will be used. Since the lengths may vary, a for loop will be used to compute the loss (slower). If set to a specific number, missing items will be randomly sampled, and excess items will be truncated to the first `INFONCE_HARD_NEGATIVES`.
-4. `INFONCE_MASK_FAKE_NEGATIVE`: Masks out fake negatives. The default is `False`. When enabled, it checks `positive_similarity + 0.1`; any sample with similarity larger than this threshold will have its similarity set to `-inf` to prevent positive leakage.
+4. `INFONCE_MASK_FAKE_NEGATIVE`: Masks out fake negatives. The default is `False`. When enabled, it checks `positive_similarity + INFONCE_FAKE_NEG_MARGIN`; any sample with similarity larger than this threshold will have its similarity set to `-inf` to prevent positive leakage.
+5. `INFONCE_FAKE_NEG_MARGIN`: Margin used by the fake-negative mask. Default: `0.1`.
+6. `INFONCE_INCLUDE_QQ`: Include the q–q block (similarities among queries) in the denominator as additional negatives. Default: `False`.
+7. `INFONCE_INCLUDE_DD`: Include the d–d block (similarities of the positive doc to all in-batch docs) in the denominator as additional negatives. Default: `False`.

 > You can also make the number of hard negatives equal across samples in the dataset, which avoids the for-loop computation and speeds up training even if `INFONCE_HARD_NEGATIVES` is not set.
 >
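All of these variables are read from the environment inside the loss function, so they are usually exported in the shell before launching training. As a minimal sketch (the values below are illustrative only), they can also be set from Python before training starts:

```python
import os

# Illustrative values only; the variable names are the ones documented above.
os.environ['INFONCE_TEMPERATURE'] = '0.01'          # softmax temperature
os.environ['INFONCE_USE_BATCH'] = 'True'            # use other in-batch samples as negatives
os.environ['INFONCE_HARD_NEGATIVES'] = '4'          # pad/truncate each sample to 4 hard negatives
os.environ['INFONCE_MASK_FAKE_NEGATIVE'] = 'True'   # mask negatives scoring above positive + margin
os.environ['INFONCE_FAKE_NEG_MARGIN'] = '0.1'       # margin used by the fake-negative mask
os.environ['INFONCE_INCLUDE_QQ'] = 'False'          # optionally add query-query similarities
os.environ['INFONCE_INCLUDE_DD'] = 'False'          # optionally add doc-doc similarities
```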

swift/plugin/loss.py

Lines changed: 102 additions & 15 deletions
@@ -323,6 +323,11 @@ def infonce_loss(outputs, labels, loss_scale=None, num_items_in_batch=None, **kw
     hard_negatives = os.environ.get('INFONCE_HARD_NEGATIVES', None)  # how many negative prompts kept in one sample
     # mask out fake negatives
     infonce_mask_fake_negative = strtobool(os.environ.get('INFONCE_MASK_FAKE_NEGATIVE', 'False'))
+    fake_neg_margin = float(os.environ.get('INFONCE_FAKE_NEG_MARGIN', '0.1'))
+    # enhanced components to align with Qwen3-Embedding denominator; controlled individually
+    # defaults set to False for backward compatibility
+    infonce_include_qq = strtobool(os.environ.get('INFONCE_INCLUDE_QQ', 'False'))
+    infonce_include_dd = strtobool(os.environ.get('INFONCE_INCLUDE_DD', 'False'))
     if hard_negatives is not None:
         hard_negatives = int(hard_negatives)
     from swift.utils import get_dist_setting
@@ -376,39 +381,121 @@ def infonce_loss(outputs, labels, loss_scale=None, num_items_in_batch=None, **kw
         # avg between all batches in one gpu
         loss /= len(split_tensors)
     else:
-
-        def mask_fake_negative(sim_matrix, sim_labels):
-            thresholds = sim_matrix[torch.arange(sim_matrix.size(0)), sim_labels].view(-1, 1) + 0.1
-            thresholds = thresholds.detach()
-            mask = sim_matrix > thresholds
-            sim_matrix[mask] = float('-inf')
-
         if can_batched:
             # [B, neg+2, D]
             sentences = torch.stack(split_tensors, dim=0)
-            # [B, D] * [B*(neg+1), D]
-            similarity_matrix = torch.matmul(sentences[:, 0].squeeze(1), sentences[:,
-                                                                                   1:].reshape(-1, sentences.size(2)).T)
+            # base q->d similarities (includes own positive and all in-batch documents)
+            queries = sentences[:, 0].squeeze(1)  # [B, D]
+            docs_all = sentences[:, 1:].reshape(-1, sentences.size(2))  # [B*(neg+1), D]
+            qd_matrix = torch.matmul(queries, docs_all.T)  # [B, B*(neg+1)]
+            # target indices: start of each group's document block (its positive)
             labels = torch.tensor(range(0,
                                         sentences.size(0) * (sentences.size(1) - 1),
                                         sentences.size(1) - 1)).view(-1).to(sentences.device)
+
+            logits_list = [qd_matrix]
+
+            if infonce_include_qq:
+                # q->q similarities; exclude self via -inf on diagonal to avoid accidental positives
+                qq_matrix = torch.matmul(queries, queries.T)  # [B, B]
+                qq_matrix = qq_matrix.clone()
+                qq_matrix.fill_diagonal_(float('-inf'))
+                logits_list.append(qq_matrix)
+
+            if infonce_include_dd:
+                # d+ -> d (doc-doc) similarities; exclude self-positive column per row
+                pos_docs = sentences[:, 1].squeeze(1)  # [B, D]
+                dd_matrix = torch.matmul(pos_docs, docs_all.T)  # [B, B*(neg+1)]
+                # mask self positive per row: column index = row_idx * (neg+1)
+                block = sentences.size(1) - 1  # (neg+1)
+                if block > 0:
+                    row_idx = torch.arange(dd_matrix.size(0), device=dd_matrix.device)
+                    col_idx = row_idx * block
+                    dd_matrix[row_idx, col_idx] = float('-inf')
+                logits_list.append(dd_matrix)
+
             if infonce_mask_fake_negative:
-                mask_fake_negative(similarity_matrix, labels)
+                # thresholds derived from positive q->d scores per row
+                row_idx = torch.arange(qd_matrix.size(0), device=qd_matrix.device)
+                pos_scores = qd_matrix[row_idx, labels]
+                thresholds = pos_scores.view(-1, 1).detach() + fake_neg_margin
+
+                # qd block mask
+                qd_block = qd_matrix.clone()
+                qd_mask = qd_block > thresholds
+                qd_block[qd_mask] = float('-inf')
+
+                components = [qd_block]
+
+                # qq block mask (if present)
+                if infonce_include_qq:
+                    qq_block = qq_matrix.clone()
+                    qq_mask = qq_block > thresholds
+                    qq_block[qq_mask] = float('-inf')
+                    # diagonal already masked unconditionally at construction time
+                    components.append(qq_block)
+
+                # dd block (if present): self-positive column already masked unconditionally
+                if infonce_include_dd:
+                    # align with Qwen3-Embedding, no threshold masking for d-d
+                    components.append(dd_matrix)
+
+                similarity_matrix = torch.cat(components, dim=1)
+            else:
+                # concatenate all components without masking
+                similarity_matrix = torch.cat(logits_list, dim=1)
+            # temperature scaling and CE
             similarity_matrix = similarity_matrix / temperature
-            # every neg+1 is positive start from 0
             loss = nn.CrossEntropyLoss()(similarity_matrix, labels) / world_size  # avoid duplicate
         else:
             all_tensors = []
             for tensor in split_tensors:
                 all_tensors.append(tensor[1:])
             # cat all neg+1 tensors
             sentences = torch.cat(all_tensors, dim=0)
+            # prepare query anchors list if q-q is included
+            if infonce_include_qq:
+                queries_all = torch.stack([t[0] for t in split_tensors], dim=0)  # [B, D]
             length = 0
             for idx, tensor in enumerate(split_tensors):
                 # [D] * [B*(neg+1), D], neg numbers are different
-                similarity_matrix = torch.matmul(tensor[0], sentences.T) / temperature
-                labels = torch.tensor(length).to(tensor.device)
-                loss += nn.CrossEntropyLoss()(similarity_matrix, labels)
+                qd_vec = torch.matmul(tensor[0], sentences.T)
+                target = torch.tensor(length).to(tensor.device)
+                logits_parts = []
+
+                # compute threshold from positive q->d score
+                threshold = (qd_vec[target].detach() + fake_neg_margin)
+
+                # qd part with masking
+                if infonce_mask_fake_negative:
+                    qd_masked = torch.where(qd_vec > threshold, torch.tensor(float('-inf'), device=qd_vec.device),
+                                            qd_vec)
+                else:
+                    qd_masked = qd_vec
+                logits_parts.append(qd_masked)
+
+                # qq part
+                if infonce_include_qq:
+                    qq_vec = torch.matmul(tensor[0], queries_all.T)  # [B]
+                    # exclude self
+                    qq_vec = qq_vec.clone()
+                    qq_vec[idx] = float('-inf')
+                    if infonce_mask_fake_negative:
+                        qq_vec = torch.where(qq_vec > threshold, torch.tensor(float('-inf'), device=qq_vec.device),
+                                             qq_vec)
+                    logits_parts.append(qq_vec)
+
+                # dd part
+                if infonce_include_dd:
+                    dd_vec = torch.matmul(tensor[1], sentences.T)  # [B*(neg+1)]
+                    # mask self positive column for this row only (no threshold masking for d-d)
+                    block = split_tensors[idx].size(0) - 1  # (neg+1) for this group
+                    dd_vec[length] = float('-inf')
+                    logits_parts.append(dd_vec)
+
+                logits_row = torch.cat(logits_parts, dim=-1)
+                logits_row = logits_row / temperature
+                loss += nn.CrossEntropyLoss()(logits_row.unsqueeze(0), target.unsqueeze(0))
                 # next positive is neg+1
                 length += tensor.size(0) - 1
             loss /= len(split_tensors)
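To see the new denominator in isolation, the following is a minimal, self-contained sketch of the batched branch above (random embeddings, no ms-swift imports; the tensor names are mine, not the library's): the q-d block plus the optional q-q and d-d blocks are concatenated, fake-negative masking is applied to the q-d and q-q parts only, and cross-entropy is taken against each row's positive column.

```python
import torch
import torch.nn as nn

B, neg, D = 4, 2, 8                      # batch size, hard negatives per sample, embedding dim
temperature, margin = 0.01, 0.1          # INFONCE_TEMPERATURE, INFONCE_FAKE_NEG_MARGIN
include_qq, include_dd = True, True      # INFONCE_INCLUDE_QQ, INFONCE_INCLUDE_DD

# per sample: [query, positive, neg_1 .. neg_k]; L2-normalized random stand-ins
sentences = nn.functional.normalize(torch.randn(B, neg + 2, D), dim=-1)

queries = sentences[:, 0]                                # [B, D]
docs_all = sentences[:, 1:].reshape(-1, D)               # [B*(neg+1), D]
qd = queries @ docs_all.T                                # [B, B*(neg+1)]
labels = torch.arange(0, B * (neg + 1), neg + 1)         # each row's positive column in qd

logits = [qd]
if include_qq:
    qq = queries @ queries.T                             # [B, B]
    qq.fill_diagonal_(float('-inf'))                     # a query is never its own negative
    logits.append(qq)
if include_dd:
    pos_docs = sentences[:, 1]                           # [B, D]
    dd = pos_docs @ docs_all.T                           # [B, B*(neg+1)]
    dd[torch.arange(B), torch.arange(B) * (neg + 1)] = float('-inf')  # drop the self-positive column
    logits.append(dd)

# fake-negative masking: anything scoring above the positive score + margin is dropped
thresholds = qd[torch.arange(B), labels].detach().view(-1, 1) + margin
logits[0] = qd.masked_fill(qd > thresholds, float('-inf'))
if include_qq:
    logits[1] = logits[1].masked_fill(logits[1] > thresholds, float('-inf'))
# no threshold masking for the d-d block, matching the diff above

similarity = torch.cat(logits, dim=1) / temperature
loss = nn.CrossEntropyLoss()(similarity, labels)
print(similarity.shape, loss.item())  # torch.Size([4, 28]) when both optional blocks are enabled
```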
