Description
PR #2036 adds standard model architectures to `torchao.testing.model_architectures` (`model_architectures.py`). Replace the existing model definitions in torchao and its tests so that they reuse the definitions from `model_architectures.py`. If you find a definition that is not yet covered there, add it to `model_architectures.py`.
For example, replace the following definition in `ao/test/quantization/test_quant_api.py` (lines 122 to 138 at commit 34421b1):
```python
import torch


class ToyLinearModel(torch.nn.Module):
    def __init__(self, m=64, n=32, k=64, bias=False):
        super().__init__()
        self.linear1 = torch.nn.Linear(m, n, bias=bias).to(torch.float)
        self.linear2 = torch.nn.Linear(n, k, bias=bias).to(torch.float)

    def example_inputs(self, batch_size=1, dtype=torch.float, device="cpu"):
        # A single input tensor shaped (batch_size, in_features of linear1).
        return (
            torch.randn(
                batch_size, self.linear1.in_features, dtype=dtype, device=device
            ),
        )

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x
```
with `torchao.testing.model_architectures.ToyLinearModel`.