
[Bug]: ManagedCollisionEmbeddingCollection returns all-zero embeddings after applying apply_optimizer_in_backward with RowWiseAdagrad #3030

@rayhuang90

Description


Hi there, a ManagedCollisionEmbeddingCollection with multiple tables and shared features returns all-zero embeddings after applying apply_optimizer_in_backward with RowWiseAdagrad. This result is unexpected.

This bug likely relates to how the RowWiseAdagrad optimizer state and the associated embedding rows are initialized and updated during eviction events in ManagedCollisionEmbeddingCollection.
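For background, RowWiseAdagrad keeps a single accumulator value per embedding row rather than one per element, so the whole row shares one adaptive step size; if an eviction resets a row's weights but the per-row accumulator is not handled consistently, subsequent updates for that row can be scaled incorrectly. A minimal NumPy sketch of the row-wise update rule (illustrative names only, not torchrec's actual implementation):

```python
import numpy as np

def rowwise_adagrad_step(param, grad, state, lr=0.01, eps=1e-8):
    """One row-wise Adagrad step: a single accumulator per embedding row.

    The accumulator grows by the mean of the squared gradient across the
    row, and every element in the row shares the same adaptive step size.
    (Sketch for illustration; not torchrec's implementation.)
    """
    state += np.mean(grad ** 2, axis=1)      # one scalar per row
    scale = lr / (np.sqrt(state) + eps)      # per-row step size
    param -= scale[:, None] * grad
    return param, state

# Hypothetical 4-row, 8-dim embedding table
rng = np.random.default_rng(0)
param = rng.normal(scale=0.03, size=(4, 8))
grad = rng.normal(size=(4, 8))
state = np.zeros(4)

param, state = rowwise_adagrad_step(param, grad, state)
# If an eviction zeroes a row's weights but the matching row of `state`
# is stale (or the row is never re-initialized), later steps for that
# row use a wrong scale -- consistent with the all-zero output above.
```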

Below is a minimal reproducible Python code example:

mch_rowrisegrad_bug.py.txt

torchrun --standalone --nnodes=1 --node-rank=0 --nproc-per-node=1 mch_rowrisegrad_bug.py

Unexpected Result

[RANK0] emb_result key: item_tag, jt: JaggedTensor({
    [[[-0.00694586057215929, 0.005635389592498541, 0.029554935172200203, -0.014213510788977146, 0.027853110805153847, 0.023257633671164513, 0.004495333414524794, -0.01736217364668846]]]
})

[RANK0] emb_result key: user_tag, jt: JaggedTensor({
    [[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]]    # --> all-zero embeddings, bug
})

[RANK0] emb_result key: item_id, jt: JaggedTensor({
    [[[0.026882518082857132, -0.008349019102752209, 0.025774799287319183, 0.010714510455727577, 0.022058645263314247, -0.02674921043217182, 0.029537828639149666, 0.007071810774505138]]]
})

[RANK0] remapped_ids: KeyedJaggedTensor({
    "item_tag": [[997]],
    "user_tag": [[998]],
    "item_id": [[998]]
})

My Current Environment

fbgemm_gpu==1.1.0+cu118
numpy==2.1.2
protobuf==3.19.6
torch==2.6.0+cu118
torchrec==1.1.0+cu118
transformers==4.48.0
triton==3.2.0
