Add GLU support#38
Conversation
|
I've updated the PR based on @tgale96 's feedback, mainly:
I've left a few of the issues unresolved as I work through them, namely |
| memory_optimized_mlp : bool = False | ||
| mlp_type: str = 'mlp' | ||
| grouped_mlp: bool = False | ||
| mlp_type : str = 'mlp' |
There was a problem hiding this comment.
Did you mean to add these spaces? Looks like we're actually mixed on having them and not having them in this file...
There was a problem hiding this comment.
hmmm yea I thought that the spaces would match the existing style of the file
|
|
||
| MlpType = Union[mlp.SparseMLP, glu.SparseGLU] | ||
|
|
||
| class dMlpRegistry: |
There was a problem hiding this comment.
Stylistic thing - can we remove this class and have get just be a function on the module? Then REGISTRY can be a private, global? i.e., _REGISTRY?
There was a problem hiding this comment.
I've refactored it to get rid of the class, let me know if it looks good now!
|
LGTM! Ready to merge? |
|
@tgale96 great, yes! |
|
Thanks for the contribution Sasha! This is awesome. |
This change adds GLU blocks to megablocks (replacing vanilla MLPs), and does some refactoring around the mlp types, including an
MLP_TYPE_REGISTRY.Note, this is unoptimized at the moment: