

Optimize printing sympy expressions during logging and cache key computation #151823



Closed
laithsakka opened this issue Apr 21, 2025 · 0 comments
Labels: module: dynamic shapes, oncall: pt2, triaged


laithsakka commented Apr 21, 2025

Repro:


```python
import torch

def _cumsum(o):
    ret = [0] * (len(o) + 1)
    for i in range(len(o)):
        ret[i + 1] = ret[i] + o[i]
    return ret

@torch.compile(dynamic=True)
def func(o):
    out = _cumsum(o)
    return out

func([i for i in range(2000)])
```
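To see why printing dominates here, a minimal sympy-only sketch of what `_cumsum` builds, assuming (for illustration only) that each list element is traced as its own symbol `s0, s1, ...` under `dynamic=True`. The helpers `traced_cumsum` and `count_terms` are hypothetical. If every intermediate sum gets printed (for logging or a cache key), the total number of printed terms grows quadratically with the list length.

```python
import sympy


def traced_cumsum(n: int) -> list:
    """Mirror _cumsum from the repro, but on sympy symbols instead of ints.

    Assumption: each list element becomes its own symbol s0, s1, ...
    """
    syms = sympy.symbols(f"s0:{n}")  # tuple (s0, s1, ..., s{n-1})
    ret = [sympy.Integer(0)] * (n + 1)
    for i in range(n):
        ret[i + 1] = ret[i] + syms[i]
    return ret


def count_terms(expr: sympy.Expr) -> int:
    # An Add with k summands has k args; a lone symbol counts as one term.
    return len(expr.args) if isinstance(expr, sympy.Add) else 1


n = 200
exprs = traced_cumsum(n)
total_terms = sum(count_terms(e) for e in exprs[1:])
print(total_terms)  # n * (n + 1) / 2 = 20100
```

With n = 2000 as in the repro, the same count is about 2 million terms.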

Inductor already has a fast printer for sympy expressions, `torch._inductor.utils.sympy_str()`:

```python
import sympy

from torch.utils._sympy.functions import CleanDiv, FloorDiv, Identity, ModularIndexing


def sympy_str(expr: sympy.Expr) -> str:
    """
    Normal sympy str is very slow, this is a lot faster. The results are
    somewhat worse, as it doesn't do as much simplification. So don't
    use this for final codegen.
    """
    if isinstance(expr, sympy.Symbol):
        return expr.name
    if isinstance(expr, sympy.Add):
        return " + ".join(map(sympy_str, expr.args))
    if isinstance(expr, sympy.Mul):
        return " * ".join(map(sympy_str, expr.args))
    if isinstance(expr, (ModularIndexing, CleanDiv, FloorDiv, Identity)):
        return f"{expr.func.__name__}({', '.join(map(sympy_str, expr.args))})"
    return str(expr)
```

Maybe we can reuse it for logging and cache key computation?
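A minimal, torch-free sketch of the trade-off described in the docstring: `fast_sympy_str` is a hypothetical sympy-only subset of the Inductor printer (the `ModularIndexing`/`CleanDiv`/`FloorDiv`/`Identity` branch is dropped so it runs without torch). It skips the term ordering and sign folding that `str()` performs, so it is cheaper but produces less polished output. The timing line is only a rough local comparison, not a benchmark.

```python
import timeit

import sympy


def fast_sympy_str(expr: sympy.Expr) -> str:
    """Sympy-only subset of the fast printer: join args without reordering."""
    if isinstance(expr, sympy.Symbol):
        return expr.name
    if isinstance(expr, sympy.Add):
        return " + ".join(map(fast_sympy_str, expr.args))
    if isinstance(expr, sympy.Mul):
        return " * ".join(map(fast_sympy_str, expr.args))
    return str(expr)


x, y = sympy.symbols("x y")
print(str(x - y))             # x - y
print(fast_sympy_str(x - y))  # no sign folding, e.g. "x + -1 * y"

# One 2000-term sum, roughly the size of the last expression in the repro.
big = sympy.Add(*sympy.symbols("s0:2000"))
slow = timeit.timeit(lambda: str(big), number=3)
fast = timeit.timeit(lambda: fast_sympy_str(big), number=3)
print(f"str: {slow:.3f}s  fast: {fast:.3f}s")
```

The less readable output is acceptable for log lines and cache keys, which only need to be deterministic, but not for final codegen, as the docstring warns.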

Profile: https://fburl.com/scuba/pyperf_experimental/on_demand/vo6ru8ty

Internal xref:
https://fb.workplace.com/groups/1075192433118967/permalink/23929961646604309/

Note that this part is currently disabled in the model compilation; we could enable it once this is fixed.

Even with it disabled, we still see a 10% cost for printing sympy expressions in the full model compilation:
https://docs.google.com/document/d/1H-jueMz5VJuX6qVzyBl10OhlWWkxhAjp74JGtl7JhKg/edit?ouid=111904611073736927346&usp=docs_home&ths=true

cc @chauhang @penguinwu @ezyang @bobrenjc93

eellison added the triaged label Apr 28, 2025
aorenste added a commit that referenced this issue May 5, 2025
Teach the graph printer to allow overriding how SymTypes (`SymInt`, `SymFloat`, `SymBool`) are printed, and use that to reuse the fast SymNode printing from `torch._inductor.utils.sympy_str()` to make computing the cache key faster.

On my computer the repro from #151823 goes from 480s -> 80s (still terrible... but better).

Fixes #151823 




cc ezyang SherlockNoMad EikanWang jgong5 wenzhe-nrv voznesenskym penguinwu Guobing-Chen XiaobingSuper zhuhaozhe blzheng jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
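The commit makes the printer for SymTypes overridable so the cache key path can use the fast printer. A minimal sketch of that idea under stated assumptions: `cache_key` and `fast_sympy_str` are hypothetical helpers, not the graph printer's actual API. The key is a hash of each expression's printed form, with the printer passed in as a parameter.

```python
import hashlib
from typing import Callable, Iterable

import sympy


def fast_sympy_str(expr: sympy.Expr) -> str:
    # Hypothetical stand-in for torch._inductor.utils.sympy_str (Symbol/Add/Mul only).
    if isinstance(expr, sympy.Symbol):
        return expr.name
    if isinstance(expr, sympy.Add):
        return " + ".join(map(fast_sympy_str, expr.args))
    if isinstance(expr, sympy.Mul):
        return " * ".join(map(fast_sympy_str, expr.args))
    return str(expr)


def cache_key(
    exprs: Iterable[sympy.Expr],
    printer: Callable[[sympy.Expr], str] = str,
) -> str:
    """Hash the printed form of each expression; the printer is overridable."""
    h = hashlib.sha256()
    for expr in exprs:
        h.update(printer(expr).encode("utf-8"))
        h.update(b"\0")  # separator so ["a", "bc"] and ["ab", "c"] hash differently
    return h.hexdigest()


# The cumulative sums from the repro, on 100 symbols.
s = sympy.symbols("s0:100")
exprs = [sympy.Add(*s[: i + 1]) for i in range(len(s))]
key = cache_key(exprs, printer=fast_sympy_str)
```

A cache key only needs to be deterministic, so any stable printer works; keys produced with different printers are not interchangeable, though, so one printer must be used consistently for both storing and looking up entries.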