Description
2D temporaries could be developed quickly as an experimental feature, with support in the `dace:X` and `debug` backends.
The `gt:X` backends are locked behind an assumption that temporaries are 3D; it is unclear how easy that is to undo.
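For illustration only, the difference between the two kinds of temporary shows up in their allocation shape. A 3D temporary holds one value per (i, j, k) point, while a 2D one holds only one value per (i, j) column. The `temporary_shape` helper below is hypothetical (not gt4py code), and the sketch ignores halos:

```python
from math import prod
from typing import Tuple


def temporary_shape(domain: Tuple[int, int, int], axes: str) -> Tuple[int, ...]:
    """Allocation shape of a temporary spanning `axes` ("IJK" or "IJ"), halos ignored."""
    sizes = dict(zip("IJK", domain))
    return tuple(sizes[axis] for axis in axes)


domain = (128, 128, 80)
for axes in ("IJK", "IJ"):
    shape = temporary_shape(domain, axes)
    print(axes, shape, prod(shape))
# IJK (128, 128, 80) 1310720
# IJ (128, 128) 16384
```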
Syntax remains to be decided. We propose reusing the type annotations of stencil parameters for temporaries, e.g.:
```python
import numpy as np
from gt4py.cartesian.gtscript import FORWARD, PARALLEL, IJ, IJK, Field, computation, interval

def the_stencil(in_field: Field[IJK, np.float64], out_field: Field[IJK, np.float64]):
    with computation(FORWARD), interval(0, 1):
        # Proposed syntax: a 2D (IJ) temporary typed like a stencil parameter
        tmp_2d: Field[IJ, np.float64] = in_field
    with computation(PARALLEL), interval(...):
        out_field = in_field + tmp_2d
```

To make the types more compact, we also propose upstreaming the shortcuts introduced in our NDSL layer, `FloatField` and `FloatFieldIJ`, rewriting the above as:
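```python
from gt4py.cartesian.gtscript import FORWARD, PARALLEL, computation, interval
# FloatField / FloatFieldIJ: the proposed shortcuts (currently defined in NDSL)

def the_stencil(in_field: FloatField, out_field: FloatField):
    with computation(FORWARD), interval(0, 1):
        tmp_2d: FloatFieldIJ = in_field
    with computation(PARALLEL), interval(...):
        out_field = in_field + tmp_2d
```

There is one caveat. The current mixed-precision implementation introduced a quick way to declare a non-standard precision on 3D temporaries, e.g.: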
```python
tmp_3d: float64 = in_field
```

With the introduction of this feature, we should undo (deprecate, then remove) this shortcut and move to stating the full type explicitly, e.g.:
```python
tmp_3d: FloatField64 = in_field
```

In contrast, we would keep the current default behavior: any temporary defined without a type hint is a `FloatField`.
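As a reading aid, here is a NumPy sketch of what the example stencil computes. It assumes 2D temporaries follow the same rule as 2D field parameters: a 2D value read inside a 3D computation is broadcast along K. `the_stencil_reference` is a hypothetical name, and halos, origin and domain arguments are ignored.

```python
import numpy as np


def the_stencil_reference(in_field: np.ndarray) -> np.ndarray:
    """NumPy reference for the example stencil (no halos, full domain)."""
    # computation(FORWARD), interval(0, 1): only level k = 0 executes, so the
    # 2D temporary holds one value per (i, j) column.
    tmp_2d = in_field[:, :, 0].copy()

    # computation(PARALLEL), interval(...): every level reads its column's
    # tmp_2d value, i.e. tmp_2d is broadcast along K.
    return in_field + tmp_2d[:, :, np.newaxis]


in_field = np.arange(2 * 2 * 3, dtype=np.float64).reshape(2, 2, 3)
out_field = the_stencil_reference(in_field)
print(out_field[0, 1])  # [6. 7. 8.]
```

The point of the 2D temporary is visible in the shapes: `tmp_2d` is `(ni, nj)`, not `(ni, nj, nk)`.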
Dev NOTE.
The frontend work is easy:
- Annotations have already been intercepted since the mixed-precision work.
- Temporary declarations can see the annotations and reason about them.
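The information the frontend needs can be illustrated with the standard `ast` module alone. The sketch below walks a stencil's source and records, for each temporary, the source text of its annotation, or `None` when there is none (which, under the proposal, would default to `FloatField`). It is a stand-alone illustration, not how `IRMaker` is implemented; `collect_temporary_annotations` and `STENCIL_SOURCE` are hypothetical names.

```python
import ast
import textwrap
from typing import Dict, Optional

# Example stencil source; it is only parsed, never executed.
STENCIL_SOURCE = textwrap.dedent(
    """
    def the_stencil(in_field: FloatField, out_field: FloatField):
        with computation(FORWARD), interval(0, 1):
            tmp_2d: FloatFieldIJ = in_field
        with computation(PARALLEL), interval(...):
            tmp_3d = in_field * 2.0
            out_field = tmp_3d + tmp_2d
    """
)


def collect_temporary_annotations(source: str) -> Dict[str, Optional[str]]:
    """Map each temporary to its annotation source, or None if unannotated."""
    func = ast.parse(source).body[0]
    params = {arg.arg for arg in func.args.args}
    temporaries: Dict[str, Optional[str]] = {}
    for node in ast.walk(func):
        if isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name):
            # Annotated declaration: keep the annotation's source text.
            temporaries[node.target.id] = ast.unparse(node.annotation)
        elif isinstance(node, ast.Assign):
            for target in node.targets:
                if isinstance(target, ast.Name):
                    # Unannotated assignment: record it without overriding an annotation.
                    temporaries.setdefault(target.id, None)
    # Parameters are API fields, not temporaries.
    return {name: ann for name, ann in temporaries.items() if name not in params}


print(collect_temporary_annotations(STENCIL_SOURCE))
# {'tmp_2d': 'FloatFieldIJ', 'tmp_3d': None}
```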
Here is a hard-coded version that works with `dace:X`:
```diff
diff --git a/src/gt4py/cartesian/frontend/gtscript_frontend.py b/src/gt4py/cartesian/frontend/gtscript_frontend.py
index 8f461e8c2..3d6d63a8b 100644
--- a/src/gt4py/cartesian/frontend/gtscript_frontend.py
+++ b/src/gt4py/cartesian/frontend/gtscript_frontend.py
@@ -1599,10 +1599,18 @@ class IRMaker(ast.NodeVisitor):
                         loc=nodes.Location.from_ast_node(t),
                     )
                 dtype = nodes.DataType.AUTO
+                axes = nodes.Domain.LatLonGrid().axes_names
                 if target_annotation is not None:
                     source = ast.unparse(target_annotation)
+                    if source.startswith("IJTemporary"):
+                        axes = nodes.Domain.LatLonGrid()
+                        axes.sequential_axis = None
+                        axes = axes.axes_names
+                        dtype_to_translate = ast.unparse(target_annotation.slice)
+                    else:
+                        dtype_to_translate = source
                     try:
-                        dtype = eval(source, self.temporary_type_to_native_type)
+                        dtype = eval(dtype_to_translate, self.temporary_type_to_native_type)
                     except NameError:
                         raise GTScriptSyntaxError(
                             message=f"Failed to recognize type {source} for local symbol {name}."
@@ -1612,7 +1620,7 @@ class IRMaker(ast.NodeVisitor):
                 field_decl = nodes.FieldDecl(
                     name=name,
                     data_type=dtype,
-                    axes=nodes.Domain.LatLonGrid().axes_names,
+                    axes=axes,
                     is_api=False,
                     loc=nodes.Location.from_ast_node(t),
                 )
```

Doing the work properly would mean evaluating `target_annotation` against a restricted namespace holding a subset of the `gtscript.Field` and/or `FloatField` symbols, then capturing the axes from the resulting objects, which do carry them.
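The "proper" approach can be sketched outside gt4py. Everything here is a stand-in: `FieldType`, `_FieldStandIn`, `TEMPORARY_TYPE_NAMESPACE`, `resolve_temporary_type` and `annotation_of` are hypothetical names, not gt4py API. `np` is replaced by a namespace of dtype strings, and the default precision is assumed to be `float64`. The annotation node is compiled and evaluated against a restricted namespace, and the axes are read off the resulting object instead of being inferred from the annotation's source text.

```python
import ast
import warnings
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Optional, Tuple


@dataclass(frozen=True)
class FieldType:
    """Stand-in for a type object that carries its own axes and dtype."""

    axes: Tuple[str, ...]
    dtype: str


class _FieldStandIn:
    """Stand-in for gtscript.Field: Field[axes, dtype] builds a FieldType."""

    def __getitem__(self, args: Tuple[Tuple[str, ...], str]) -> FieldType:
        axes, dtype = args
        return FieldType(tuple(axes), dtype)


IJK = ("I", "J", "K")
IJ = ("I", "J")

# Hypothetical restricted namespace: the only symbols allowed in a
# temporary's annotation (default precision assumed to be float64).
TEMPORARY_TYPE_NAMESPACE = {
    "Field": _FieldStandIn(),
    "IJK": IJK,
    "IJ": IJ,
    "np": SimpleNamespace(float32="float32", float64="float64"),
    "FloatField": FieldType(IJK, "float64"),
    "FloatFieldIJ": FieldType(IJ, "float64"),
    "FloatField64": FieldType(IJK, "float64"),
    # Bare dtypes: the mixed-precision shortcut slated for deprecation.
    "float32": "float32",
    "float64": "float64",
}


def resolve_temporary_type(annotation: Optional[ast.expr]) -> FieldType:
    """Evaluate a temporary's annotation and read the axes off the result."""
    if annotation is None:
        # Default kept by the proposal: an unannotated temporary is a FloatField.
        return TEMPORARY_TYPE_NAMESPACE["FloatField"]
    code = compile(ast.Expression(body=annotation), "<annotation>", "eval")
    try:
        # Empty __builtins__ limits name resolution to the namespace; not a sandbox.
        result = eval(code, {"__builtins__": {}}, dict(TEMPORARY_TYPE_NAMESPACE))
    except NameError as err:
        raise ValueError(f"Failed to recognize type {ast.unparse(annotation)}") from err
    if isinstance(result, str):
        warnings.warn(
            f"Bare dtype annotation '{result}' is deprecated; use a full field type.",
            DeprecationWarning,
            stacklevel=2,
        )
        return FieldType(IJK, result)
    if not isinstance(result, FieldType):
        raise ValueError(f"{ast.unparse(annotation)} is not a field type")
    return result


def annotation_of(statement: str) -> Optional[ast.expr]:
    """Parse one assignment statement and return its annotation, if any."""
    node = ast.parse(statement).body[0]
    return node.annotation if isinstance(node, ast.AnnAssign) else None


print(resolve_temporary_type(annotation_of("tmp_2d: FloatFieldIJ = in_field")).axes)
# ('I', 'J')
print(resolve_temporary_type(annotation_of("tmp_2d: Field[IJ, np.float64] = in_field")).axes)
# ('I', 'J')
print(resolve_temporary_type(annotation_of("tmp_3d = in_field")).axes)
# ('I', 'J', 'K')
```

Keeping the namespace explicit means unknown names still fail loudly, as the `NameError` branch in the current frontend does, and the bare-dtype path gives a natural place to emit the deprecation warning for the mixed-precision shortcut.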