@@ -262,7 +262,8 @@ def __repr__(self):
262262class FunctionKind (Enum ):
263263 UNARY = 0
264264 BINARY = 1
265- TYPE = 2
265+ TERNARY = 2
266+ TYPE = 3
266267
267268
268269class UnaryFnType :
@@ -339,6 +340,33 @@ class BinaryFn:
339340 powf = BinaryFnType ("powf" )
340341
341342
343+ class TernaryFnType :
344+ """Ternary function.
345+
346+ A ternary function takes three tensor expressions and returns the
347+ function evaluation result.
348+ """
349+
350+ def __init__ (self , fn_name : str ):
351+ self .fn_name = fn_name
352+
353+ def __call__ (
354+ self , arg0 : TensorExpression , arg1 : TensorExpression , arg2 : TensorExpression
355+ ) -> "TensorFn" :
356+ return TensorFn (
357+ FunctionKind .TERNARY , self .fn_name , None , None , [arg0 , arg1 , arg2 ]
358+ )
359+
360+ def __repr__ (self ):
361+ return f"{ self .fn_name } "
362+
363+
364+ class TernaryFn :
365+ """Ternary function namespace."""
366+
367+ select = TernaryFnType ("select" )
368+
369+
342370class TypeFnType :
343371 """Type conversion function.
344372
@@ -437,7 +465,8 @@ class OperandKind(Enum):
437465 INDEX_ATTR = 3
438466 UNARY_FN_ATTR = 4
439467 BINARY_FN_ATTR = 5
440- TYPE_FN_ATTR = 6
468+ TERNARY_FN_ATTR = 6
469+ TYPE_FN_ATTR = 7
441470
442471
443472class OperandDef :
@@ -489,6 +518,7 @@ def is_attribute(self) -> bool:
489518 self .kind == OperandKind .INDEX_ATTR
490519 or self .kind == OperandKind .UNARY_FN_ATTR
491520 or self .kind == OperandKind .BINARY_FN_ATTR
521+ or self .kind == OperandKind .TERNARY_FN_ATTR
492522 or self .kind == OperandKind .TYPE_FN_ATTR
493523 )
494524
@@ -670,6 +700,33 @@ def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse:
670700 return ReduceFnUse (None , self , * reduce_dims )
671701
672702
703+ class TernaryFnAttrDef :
704+ """Ternary function attribute definition.
705+
706+ Ternary function attributes provide a way to make the arithmetic computation
707+ parametrizable. Every attribute specifies a default Ternary function
708+ that may be overwritten at operation instantiation time.
709+ """
710+
711+ def __init__ (self , default : "TernaryFnType" ):
712+ if not isinstance (default , TernaryFnType ):
713+ raise ValueError (
714+ f"TernaryFnAttrDef requires default of type TernaryFnType "
715+ f"but got { default } "
716+ )
717+ self .operand_def = OperandDef (
718+ OperandKind .TERNARY_FN_ATTR , default_fn = default .fn_name
719+ )
720+
721+ def __call__ (self , arg0 : TensorExpression , arg1 : TensorExpression ) -> TensorFn :
722+ return TensorFn (
723+ FunctionKind .TERNARY , None , self .operand_def , None , [arg0 , arg1 ]
724+ )
725+
726+ def __getitem__ (self , reduce_dims : Tuple [DimDef ]) -> ReduceFnUse :
727+ return ReduceFnUse (None , self , * reduce_dims )
728+
729+
673730class TypeFnAttrDef :
674731 """Type conversion function attribute definition.
675732
0 commit comments