Problem in position embedding #4

@jmercat

queries, keys, vals = self.pos_embed(queries, keys, vals)

It seems to me that the rotary position embedding is being applied along the head dimension (dim -2) of q and k instead of the sequence dimension (dim 1). In other words, the tensors appear to reach pos_embed with the sequence on dim 1 and the heads on dim -2, while the rotary embedding takes positions from dim -2.
I think the head and sequence dimensions should be swapped before calling the position embedding, and swapped back afterwards.
(see https://github.com/facebookresearch/xformers/blob/748c159096d4f9fcfe3eaf22801e5aed4777210b/xformers/components/positional_embedding/rotary.py#L85)

What I'm proposing is simply to rewrite RotaryWithCast as follows:

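class RotaryWithCast(RotaryEmbedding):
    def forward(self, q, k, v):
        # Swap dims 1 and 2 so the sequence sits on dim -2, where the rotary embedding reads positions.
        q, k = super().forward(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3))
        # Swap back to the layout the caller passed in.
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        # Cast q and k back to the dtype of v.
        return q.to(v.dtype), k.to(v.dtype), v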
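
To make the concern concrete, here is a minimal pure-Python sketch that uses nested lists instead of tensors, so it runs without PyTorch or xformers. The helpers `rotary_angles`, `apply_rotary`, and `swap_seq_and_heads` are hypothetical names written only for this illustration; they are not part of xformers or of this repository. The sketch assumes, as the linked line appears to, that the rotary layer takes positions from the second-to-last axis, and that q is laid out as (batch, seq, heads, head_dim).

```python
import math


def rotary_angles(length, dim, base=10000.0):
    """angles[p][i] = p * base**(-2i/dim) for position p and feature pair i."""
    return [[p * base ** (-2 * i / dim) for i in range(dim // 2)]
            for p in range(length)]


def apply_rotary(x):
    """Rotate every vector x[b][i][j] using j as its position.

    Assumption for this sketch: like a layer expecting
    (batch, heads, seq, head_dim), positions come from the second-to-last axis.
    """
    dim = len(x[0][0][0])
    half = dim // 2
    angles = rotary_angles(len(x[0][0]), dim)
    out = []
    for batch in x:
        out_batch = []
        for row in batch:
            out_row = []
            for j, vec in enumerate(row):
                rotated = [0.0] * dim
                for i in range(half):
                    c, s = math.cos(angles[j][i]), math.sin(angles[j][i])
                    a, b = vec[i], vec[i + half]
                    # Rotate the pair (a, b) by the angle for position j.
                    rotated[i] = a * c - b * s
                    rotated[i + half] = b * c + a * s
                out_row.append(rotated)
            out_batch.append(out_row)
        out.append(out_batch)
    return out


def swap_seq_and_heads(x):
    """Nested-list equivalent of tensor.permute(0, 2, 1, 3)."""
    return [[list(col) for col in zip(*batch)] for batch in x]


# Hypothetical input laid out as (batch=1, seq=3, heads=2, head_dim=4).
# Every vector is identical, so any difference in the output comes from
# the index that was used as the position.
batch_size, seq_len, n_heads = 1, 3, 2
q = [[[[1.0, 0.0, 1.0, 0.0] for _ in range(n_heads)]
      for _ in range(seq_len)]
     for _ in range(batch_size)]

# Without the swap: the head index is used as the position.
wrong = apply_rotary(q)

# With the proposed swap: the token index is used as the position.
fixed = swap_seq_and_heads(apply_rotary(swap_seq_and_heads(q)))

# Tokens 0 and 2 are indistinguishable without the swap ...
print(wrong[0][0][1] == wrong[0][2][1])
# ... but receive different rotations with it.
print(fixed[0][0][1] == fixed[0][2][1])
```

Under these assumptions, the unswapped call rotates each vector according to its head index, so every token in a given head gets the same rotation and no positional information is encoded. With the swap, the rotation depends on the token index and is the same across heads.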
