|
11 | 11 | from torch._dynamo.utils import counters |
12 | 12 | from torch.fx.experimental.symbolic_shapes import has_free_symbols |
13 | 13 | from torch.fx.node import map_arg |
| 14 | +from torch.fx.passes.shape_prop import TensorMetadata |
14 | 15 |
|
15 | 16 | from ..lowering import lowerings as L, require_channels_last |
16 | | -from ..pattern_matcher import Arg, CallFunction, filter_nodes, KeywordArg, ListOf, Match |
| 17 | +from ..pattern_matcher import ( |
| 18 | + Arg, |
| 19 | + CallFunction, |
| 20 | + CallFunctionVarArgs, |
| 21 | + filter_nodes, |
| 22 | + is_backward_pattern, |
| 23 | + KeywordArg, |
| 24 | + ListOf, |
| 25 | + Match, |
| 26 | + MULTIPLE, |
| 27 | + Placeholder, |
| 28 | + register_graph_pattern, |
| 29 | +) |
17 | 30 | from ..utils import pad_listlike |
18 | 31 | from .freezing_patterns import register_freezing_graph_pattern |
| 32 | +from .group_batch_fusion import is_node_meta_valid |
19 | 33 | from .post_grad import register_lowering_pattern |
| 34 | +from .split_cat import construct_pattern_matcher_pass |
20 | 35 |
|
21 | 36 |
|
22 | 37 | aten = torch.ops.aten |
@@ -3590,3 +3605,100 @@ def maybe_replace_node(n: torch.fx.Node) -> torch.fx.Node: |
3590 | 3605 |
|
3591 | 3606 | graph_module.graph.lint() |
3592 | 3607 | graph_module.recompile() |
| 3608 | + |
| 3609 | + |
# Pass container ("pass dict") for the activation-quantization patterns on the
# post-grad aten graph. The @register_graph_pattern decorators below register
# their patterns into this pass. (Built by construct_pattern_matcher_pass from
# split_cat; presumably keyed by the same name used in
# post_grad_fusion_options — verify against the pass registration machinery.)
activation_quantization_aten_pass = construct_pattern_matcher_pass(
    "activation_quantization_aten_pass"
)
@register_graph_pattern(
    CallFunctionVarArgs(
        [
            torch.ops.aten.relu.default,
            torch.ops.aten.tanh.default,
            torch.ops.aten.sigmoid.default,
            torch.ops.aten.gelu.default,
        ],
        users=MULTIPLE,
    ),
    pass_dict=activation_quantization_aten_pass,
    extra_check=is_backward_pattern(activation_quantization_aten_pass, False),
)
def quantize_activation_fw(match: Match, *args, **kwargs) -> None:
    """Quantize matched activations that are returned from the forward graph.

    For each matched activation node (relu/tanh/sigmoid/gelu) that feeds the
    graph's output node (i.e. is saved for backward), insert a
    ``prims.convert_element_type`` cast to the configured quantization dtype
    and reroute *only* the output node's reference to the quantized value.
    All other users keep consuming the full-precision activation.

    Args:
        match: pattern-matcher match whose ``nodes`` are the candidate
            activation nodes in ``match.graph``.
    """
    graph = match.graph
    activation_nodes = match.nodes
    # Quantization dtype comes from the inductor post-grad fusion config;
    # defaults to float8_e5m2.
    quant_type = torch._inductor.config.post_grad_fusion_options[
        "activation_quantization_aten_pass"
    ].get("quant_type", torch.float8_e5m2)
    for activation_node in activation_nodes:
        # Snapshot users before mutating the graph.
        users = list(activation_node.users.keys())
        for user in users:
            # Only activations flowing into the graph output (saved tensors)
            # get quantized.
            if user.op == "output":
                if not is_node_meta_valid(activation_node):
                    continue
                # Insert the cast right after the activation so topological
                # order is preserved.
                with graph.inserting_after(activation_node):
                    quant_activation_node = graph.call_function(
                        torch.ops.prims.convert_element_type.default,
                        args=(activation_node, quant_type),
                    )
                # Propagate metadata, then fix up the dtype-dependent entries
                # so downstream passes see the quantized dtype.
                # NOTE(review): assumes "val" and "tensor_meta" are present in
                # meta (is_node_meta_valid only guarantees "val"/"example_value")
                # — confirm post-grad nodes always carry tensor_meta.
                quant_activation_node.meta.update(activation_node.meta)
                quant_activation_node.meta["val"] = quant_activation_node.meta[
                    "val"
                ].to(quant_type)
                # TensorMetadata is a NamedTuple; _replace swaps only the dtype
                # and keeps every other field (shape, stride, ...) intact.
                quant_activation_node.meta["tensor_meta"] = quant_activation_node.meta[
                    "tensor_meta"
                ]._replace(dtype=quant_type)
                # Rewrite only the output node's args; every other user still
                # consumes the full-precision activation. fx.Node compares by
                # identity, so use `is` explicitly.
                user_updated_args = tuple(
                    quant_activation_node if node is activation_node else node
                    for node in user.args[0]
                )
                user.update_arg(0, user_updated_args)
                # Defensive cleanup; in practice the quant node itself keeps
                # the activation alive, so this should not normally fire.
                if len(activation_node.users) == 0:
                    graph.erase_node(activation_node)
                counters["inductor"]["activation_quantization_aten_pass"] += 1
                break
| 3666 | + |
| 3667 | + |
@register_graph_pattern(
    Placeholder(["tanh", "relu", "sigmoid", "gelu"], users=MULTIPLE),
    pass_dict=activation_quantization_aten_pass,
    extra_check=is_backward_pattern(activation_quantization_aten_pass, True),
)
def quantize_activation_bw(match: Match, *args, **kwargs) -> None:
    """Dequantize saved activations at the entry of the backward graph.

    Mirror of ``quantize_activation_fw``: each matched placeholder carries an
    activation that the forward graph saved. Insert a
    ``prims.convert_element_type`` cast back to the placeholder's original
    dtype, redirect all users to the dequantized value, then rewrite the
    placeholder's own metadata to the quantized dtype so it stays consistent
    with what the forward graph actually outputs.

    Args:
        match: pattern-matcher match whose ``nodes`` are the matched
            placeholder nodes in ``match.graph``.
    """
    graph = match.graph
    inputs = match.nodes
    # Must match the dtype the forward pass quantized to.
    quant_type = torch._inductor.config.post_grad_fusion_options[
        "activation_quantization_aten_pass"
    ].get("quant_type", torch.float8_e5m2)
    # `input_node` rather than `input` to avoid shadowing the builtin.
    for input_node in inputs:
        if not is_node_meta_valid(input_node):
            continue
        # Insert the dequantization (cast back to the original dtype, read
        # from meta before it is rewritten below) right after the placeholder.
        with graph.inserting_after(input_node):
            dequant_activation_node = graph.call_function(
                torch.ops.prims.convert_element_type.default,
                args=(input_node, input_node.meta["val"].dtype),
            )
        # Point every user at the dequantized value. This also rewrites the
        # dequant node's own argument to itself, ...
        input_node.replace_all_uses_with(dequant_activation_node)
        # ... so restore its input back to the placeholder.
        dequant_activation_node.replace_input_with(
            dequant_activation_node, input_node
        )
        dequant_activation_node.meta.update(input_node.meta)

        # Rewrite the placeholder's metadata to the quant dtype to keep it in
        # sync with the forward pass.
        # NOTE(review): assumes "tensor_meta" is present in meta — confirm.
        input_node.meta["val"] = input_node.meta["val"].to(quant_type)
        # TensorMetadata is a NamedTuple; _replace swaps only the dtype.
        input_node.meta["tensor_meta"] = input_node.meta["tensor_meta"]._replace(
            dtype=quant_type
        )
        counters["inductor"]["activation_quantization_aten_pass"] += 1
0 commit comments