Enable generic dimensionality for input #41

Merged
tgale96 merged 1 commit into databricks:main from vchiley:enable_flat_inputs
Dec 4, 2023
@vchiley (Contributor) commented Dec 1, 2023

MegaBlocks currently assumes a 3D input of shape [sl, bs, hs] (sequence length, batch size, hidden size), when in fact it should be able to operate on inputs of any multi-dimensional shape [*, hs]. The changes in this PR enable that.

This is useful for padded inputs. If we flatten a padded input of shape [sl, bs, hs] into an input of shape [sl * bs - num_pad_tok, hs], the pad tokens are no longer part of the all-to-all routing, and the load-balancing loss (lbl) is not computed on them.
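
For illustration, the flattening described above can be sketched as follows. This is a minimal sketch, not the code from this PR: NumPy arrays stand in for framework tensors, the helper names `flatten_unpadded` and `unflatten` are hypothetical, and the padding mask is assumed to be a boolean array of shape [sl, bs]. The MoE layer itself is not called here; the point is only the shape bookkeeping.

```python
import numpy as np

def flatten_unpadded(x, pad_mask):
    """Drop padding tokens and flatten [sl, bs, hs] to [num_real_tok, hs].

    x:        array of shape [sl, bs, hs]
    pad_mask: boolean array of shape [sl, bs], True where the token is padding
              (assumed layout, not taken from the PR)
    """
    keep = ~pad_mask          # [sl, bs], True for real tokens
    flat = x[keep]            # [sl * bs - num_pad_tok, hs]
    return flat, keep

def unflatten(flat_out, keep, fill_value=0.0):
    """Scatter per-token outputs back into the padded [sl, bs, hs] layout."""
    sl, bs = keep.shape
    out = np.full((sl, bs, flat_out.shape[-1]), fill_value, dtype=flat_out.dtype)
    out[keep] = flat_out      # padding positions keep fill_value
    return out

sl, bs, hs = 4, 2, 8
x = np.arange(sl * bs * hs, dtype=np.float32).reshape(sl, bs, hs)

# Mark 3 positions as padding: one at the end of sequence 0, two in sequence 1.
pad_mask = np.zeros((sl, bs), dtype=bool)
pad_mask[3, 0] = True
pad_mask[2:, 1] = True

flat, keep = flatten_unpadded(x, pad_mask)
print(flat.shape)             # (5, 8), i.e. [sl * bs - num_pad_tok, hs]

# A layer that accepts [*, hs] inputs would be applied to `flat` here.
restored = unflatten(flat, keep)
```

Because the layer only ever sees `flat`, padding positions never reach token routing or the auxiliary loss; `unflatten` is needed only if downstream code expects the padded layout back.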

@vchiley vchiley marked this pull request as ready for review December 1, 2023 23:52
@mvpatel2000 mvpatel2000 requested a review from tgale96 December 2, 2023 00:09
@tgale96 tgale96 merged commit 6e09bc5 into databricks:main Dec 4, 2023