Attention takes softmax over wrong dimension #282

Open
@Trezorro

Description

PLEASE CORRECT ME IF I'M WRONG.

I believe the line `attn = attn.softmax(dim=2)` is incorrect.

Dim 1 contains the index (i) over the query sequence entries, and dim 2 contains the index (j) over the key sequence entries.
If my understanding is correct, then for any query (dim 1) we would like the attention weights over its associated keys to sum to 1, so that the information copied to that query position keeps the same scale, and so that the query may ignore information from many of the keys.
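
For concreteness, here is a minimal single-head sketch of the layout I am describing (illustrative shapes and names, not the repo's actual code, which also carries a head dimension):

```python
import torch

B, N, D = 2, 5, 8          # batch, sequence length, feature dim
q = torch.randn(B, N, D)   # queries, indexed by i
k = torch.randn(B, N, D)   # keys, indexed by j

# attn[b, i, j] = <query i, key j> / sqrt(D);
# dim 1 indexes queries, dim 2 indexes keys
attn = torch.einsum("bid,bjd->bij", q, k) / D**0.5
print(attn.shape)  # torch.Size([2, 5, 5])
```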

However, the current implementation is such that for any key, the attention from all query positions to that key sums to 1 after the softmax. As a result, some query positions may end up close to 0 for all keys, while EVERY key is forced to be used by at least one query position.

We should take the softmax not over the key dimension (2) but over the query dimension (1).

This implementation, which is based on the current one, uses dim 1: https://github.com/pdearena/pdearena/blob/db7664bb8ba1fe6ec3217e4079979a5e4f800151/pdearena/modules/conditioned/twod_unet.py#L223

Or am I mistaken about the output of the softmax?
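
For anyone who wants to check, here is a quick toy snippet (my own, using the (batch, i, j) layout from above) showing which axis each choice of `dim` normalizes:

```python
import torch

attn = torch.randn(2, 5, 5)  # toy logits laid out as (batch, i = query, j = key)

# softmax(dim=2): for each fixed (batch, i), the weights over keys j sum to 1
print(attn.softmax(dim=2).sum(dim=2))  # all ones

# softmax(dim=1): for each fixed (batch, j), the weights over queries i sum to 1
print(attn.softmax(dim=1).sum(dim=1))  # all ones
```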
