Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Make conversions transitive and make getitem more comprehensive #38

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ _This project uses semantic versioning. Before 1.0.0, this means that every brea
- Upgraded `egg-smol` dependency ([changes](https://github.com/saulshanabrook/egg-smol/compare/353c4387640019bd2066991ee0488dc6d5c54168...2ac80cb1162c61baef295d8e6d00351bfe84883f))

- Add support for functions which mutates their args, like `__setitem__` [#35](https://github.com/metadsl/egglog-python/pull/35)
- Makes conversions transitive [#38](https://github.com/metadsl/egglog-python/pull/38)

## 0.5.1 (2023-07-18)

Expand Down
2 changes: 2 additions & 0 deletions docs/reference/egglog-translation.md
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,8 @@ Math(2) + 30 + "x"
Math(2) + Math(i64(30)) + Math.var(String("x"))
```

Regstering a conversion from A to B will also register all transitively reachable conversions from A to B.

### Declarations

In egglog, the `(declare ...)` command is syntactic sugar for a nullary function. In Python, these can be declare either as class variables or with the toplevel `egraph.constant` function:
Expand Down
44 changes: 14 additions & 30 deletions docs/tutorials/array-api.ipynb

Large diffs are not rendered by default.

109 changes: 97 additions & 12 deletions python/egglog/exp/array_api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
# mypy: disable-error-code=empty-body

from __future__ import annotations

import itertools
Expand All @@ -13,6 +11,9 @@
# Pretend that exprs are numbers b/c scikit learn does isinstance checks
from egglog.runtime import RuntimeExpr

# mypy: disable-error-code=empty-body


numbers.Integral.register(RuntimeExpr)

egraph = EGraph()
Expand Down Expand Up @@ -111,7 +112,6 @@ def isdtype(dtype: DType, kind: IsDtypeKind) -> Bool:
...


converter(np.dtype, IsDtypeKind, lambda x: IsDtypeKind.dtype(convert(x, DType)))
converter(DType, IsDtypeKind, lambda x: IsDtypeKind.dtype(x))
converter(str, IsDtypeKind, lambda x: IsDtypeKind.string(x))
converter(
Expand Down Expand Up @@ -286,23 +286,108 @@ def _tuple_int(ti: TupleInt, ti2: TupleInt, i: Int, i2: Int, k: i64):
]


# HANDLED_FUNCTIONS = {}
@egraph.class_
class OptionalInt(Expr):
none: ClassVar[OptionalInt]

@classmethod
def some(cls, value: Int) -> OptionalInt:
...


converter(type(None), OptionalInt, lambda x: OptionalInt.none)
converter(Int, OptionalInt, OptionalInt.some)


@egraph.class_
class IndexKey(Expr):
class Slice(Expr):
def __init__(self, start: OptionalInt, stop: OptionalInt, step: OptionalInt) -> None:
...


converter(
slice,
Slice,
lambda x: Slice(convert(x.start, OptionalInt), convert(x.stop, OptionalInt), convert(x.step, OptionalInt)),
)


@egraph.class_
class MultiAxisIndexKeyItem(Expr):
ELLIPSIS: ClassVar[MultiAxisIndexKeyItem]
NONE: ClassVar[MultiAxisIndexKeyItem]

@classmethod
def tuple_int(cls, ti: TupleInt) -> IndexKey:
def int(cls, i: Int) -> MultiAxisIndexKeyItem:
...

@classmethod
def slice(cls, slice: Slice) -> MultiAxisIndexKeyItem:
...


converter(type(...), MultiAxisIndexKeyItem, lambda x: MultiAxisIndexKeyItem.ELLIPSIS)
converter(type(None), MultiAxisIndexKeyItem, lambda x: MultiAxisIndexKeyItem.NONE)
converter(Int, MultiAxisIndexKeyItem, MultiAxisIndexKeyItem.int)
converter(Slice, MultiAxisIndexKeyItem, MultiAxisIndexKeyItem.slice)


@egraph.class_
class MultiAxisIndexKey(Expr):
def __init__(self, item: MultiAxisIndexKeyItem) -> None:
...

EMPTY: ClassVar[MultiAxisIndexKey]

def __add__(self, other: MultiAxisIndexKey) -> MultiAxisIndexKey:
...


converter(
tuple,
MultiAxisIndexKey,
lambda x: MultiAxisIndexKey(convert(x[0], MultiAxisIndexKeyItem)) + convert(x[1:], MultiAxisIndexKey)
if x
else MultiAxisIndexKey.EMPTY,
)


@egraph.class_
class IndexKey(Expr):
"""
A key for indexing into an array

https://data-apis.org/array-api/2022.12/API_specification/indexing.html

It is equivalent to the following type signature:

Union[int, slice, ellipsis, Tuple[Union[int, slice, ellipsis, None], ...], array]
"""

ELLIPSIS: ClassVar[IndexKey]

@classmethod
def int(cls, i: Int) -> IndexKey:
...

@classmethod
def slice(cls, slice: Slice) -> IndexKey:
...

# Disabled until we support late binding
# @classmethod
# def boolean_array(cls, b: NDArray) -> IndexKey:
# ...

@classmethod
def multi_axis(cls, key: MultiAxisIndexKey) -> IndexKey:
...


converter(tuple, IndexKey, lambda x: IndexKey.tuple_int(convert(x, TupleInt)))
converter(int, IndexKey, lambda x: IndexKey.int(Int(x)))
converter(Int, IndexKey, lambda x: IndexKey.int(x))
converter(type(...), IndexKey, lambda x: IndexKey.ELLIPSIS)
converter(Int, IndexKey, IndexKey.int)
converter(Slice, IndexKey, IndexKey.slice)
converter(MultiAxisIndexKey, IndexKey, IndexKey.multi_axis)


@egraph.class_
Expand Down Expand Up @@ -400,8 +485,8 @@ def ndarray_index(x: NDArray) -> IndexKey:
converter(NDArray, IndexKey, ndarray_index)


converter(float, NDArray, lambda x: NDArray.scalar_float(Float(x)))
converter(int, NDArray, lambda x: NDArray.scalar_int(Int(x)))
converter(Float, NDArray, NDArray.scalar_float)
converter(Int, NDArray, NDArray.scalar_int)


@egraph.register
Expand Down Expand Up @@ -478,7 +563,6 @@ def some(cls, value: Bool) -> OptionalBool:

converter(type(None), OptionalBool, lambda x: OptionalBool.none)
converter(Bool, OptionalBool, lambda x: OptionalBool.some(x))
converter(bool, OptionalBool, lambda x: OptionalBool.some(convert(x, Bool)))


@egraph.class_
Expand Down Expand Up @@ -518,6 +602,7 @@ def some(cls, value: TupleInt) -> OptionalTupleInt:

converter(type(None), OptionalTupleInt, lambda x: OptionalTupleInt.none)
converter(TupleInt, OptionalTupleInt, lambda x: OptionalTupleInt.some(x))
# TODO: Don't allow ints to be converted to OptionalTupleInt, and have another type that also unions ints
converter(int, OptionalTupleInt, lambda x: OptionalTupleInt.some(TupleInt(Int(x))))


Expand Down
36 changes: 35 additions & 1 deletion python/egglog/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,41 @@ def converter(from_type: Type[T], to_type: Type[V], fn: Callable[[T], V]) -> Non
to_type_name = process_tp(to_type)
if not isinstance(to_type_name, JustTypeRef):
raise TypeError(f"Expected return type to be a egglog type, got {to_type_name}")
CONVERSIONS[(process_tp(from_type), to_type_name)] = fn
_register_converter(process_tp(from_type), to_type_name, fn)


def _register_converter(a: Type | JustTypeRef, b: JustTypeRef, a_b: Callable) -> None:
"""
Registers a converter from some type to an egglog type, if not already registered.

Also adds transitive converters, i.e. if registering A->B and there is already B->C, then A->C will be registered.
Also, if registering A->B and there is already D->A, then D->B will be registered.
"""
if a == b or (a, b) in CONVERSIONS:
return
CONVERSIONS[(a, b)] = a_b
for (c, d), c_d in list(CONVERSIONS.items()):
if b == c:
_register_converter(a, d, _ComposedConverter(a_b, c_d))
if a == d:
_register_converter(c, b, _ComposedConverter(c_d, a_b))


@dataclass
class _ComposedConverter:
"""
A converter which is composed of multiple converters.

_ComposeConverter(a_b, b_c) is equivalent to lambda x: b_c(a_b(x))

We use the dataclass instead of the lambda to make it easier to debug.
"""

a_b: Callable
b_c: Callable

def __call__(self, x: object) -> object:
return self.b_c(self.a_b(x))


def convert(source: object, target: type[V]) -> V:
Expand Down
78 changes: 78 additions & 0 deletions python/tests/test_convert.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,17 @@
import copy

import egglog.runtime
import pytest
from egglog import *


@pytest.fixture(autouse=True)
def reset_conversions():
old_conversions = copy.copy(egglog.runtime.CONVERSIONS)
yield
egglog.runtime.CONVERSIONS = old_conversions


def test_conversion_custom_metaclass():
class MyMeta(type):
pass
Expand Down Expand Up @@ -33,3 +44,70 @@ def __init__(self):
converter(MyType, MyTypeExpr, lambda x: MyTypeExpr())

assert expr_parts(convert(MyType(), MyTypeExpr)) == expr_parts(MyTypeExpr())


def test_conversion_transitive_forward():
egraph = EGraph()

class MyType:
pass

@egraph.class_
class MyTypeExpr(Expr):
def __init__(self):
...

@egraph.class_
class MyTypeExpr2(Expr):
def __init__(self):
...

converter(MyType, MyTypeExpr, lambda x: MyTypeExpr())
converter(MyTypeExpr, MyTypeExpr2, lambda x: MyTypeExpr2())

assert expr_parts(convert(MyType(), MyTypeExpr2)) == expr_parts(MyTypeExpr2())


def test_conversion_transitive_backward():
egraph = EGraph()

class MyType:
pass

@egraph.class_
class MyTypeExpr(Expr):
def __init__(self):
...

@egraph.class_
class MyTypeExpr2(Expr):
def __init__(self):
...

converter(MyTypeExpr, MyTypeExpr2, lambda x: MyTypeExpr2())
converter(MyType, MyTypeExpr, lambda x: MyTypeExpr())
assert expr_parts(convert(MyType(), MyTypeExpr2)) == expr_parts(MyTypeExpr2())


def test_conversion_transitive_cycle():
egraph = EGraph()

class MyType:
pass

@egraph.class_
class MyTypeExpr(Expr):
def __init__(self):
...

@egraph.class_
class MyTypeExpr2(Expr):
def __init__(self):
...

converter(MyType, MyTypeExpr, lambda x: MyTypeExpr())
converter(MyTypeExpr, MyTypeExpr2, lambda x: MyTypeExpr2())
converter(MyTypeExpr2, MyTypeExpr, lambda x: MyTypeExpr())

assert expr_parts(convert(MyType(), MyTypeExpr2)) == expr_parts(MyTypeExpr2())
assert expr_parts(convert(MyType(), MyTypeExpr)) == expr_parts(MyTypeExpr())