Bug Description
I have a program that works fine with Warp 1.3.1 and breaks with Warp 1.4.0+. I've found one kernel so far that leads to issues:
from typing import Any

import warp as wp


@wp.kernel
def trisolve_periodic_multi(x: wp.array2d(dtype=Any),
                            q: wp.array2d(dtype=Any),
                            s: wp.array2d(dtype=Any),
                            qe: wp.array2d(dtype=Any),
                            ap: wp.array2d(dtype=Any),
                            am: wp.array2d(dtype=Any),
                            ac: wp.array2d(dtype=Any),
                            n: int):
    # one thread per right-hand side (column of x)
    irhs = wp.tid()
    q[0, irhs] = -ap[0, irhs] / ac[0, irhs]
    s[0, irhs] = -am[0, irhs] / ac[0, irhs]
    fn = x[n - 1, irhs]
    x[0, irhs] /= ac[0, irhs]
    # forward elimination sweep
    for i in range(1, n):
        p = x.dtype(1.0) / (ac[i, irhs] + am[i, irhs] * q[i - 1, irhs])
        q[i, irhs] = -ap[i, irhs] * p
        s[i, irhs] = -am[i, irhs] * s[i - 1, irhs] * p
        x[i, irhs] = (x[i, irhs] - am[i, irhs] * x[i - 1, irhs]) * p
    s[n - 1, irhs] = x.dtype(1.0)
    qe[n - 1, irhs] = x.dtype(0.0)
    # backward pass
    for i in range(n - 2, -1, -1):
        s[i, irhs] += q[i, irhs] * s[i + 1, irhs]
        qe[i, irhs] = x[i, irhs] + q[i, irhs] * qe[i + 1, irhs]
    x[n - 1, irhs] = ((fn - ap[0, irhs] * qe[0, irhs] - am[0, irhs] * qe[n - 2, irhs]) /
                      (ap[0, irhs] * s[0, irhs] + am[0, irhs] * s[n - 2, irhs] + ac[0, irhs]))
    # backward elimination pass
    for i in range(n - 2, -1, -1):
        x[i, irhs] = x[n - 1, irhs] * s[i, irhs] + qe[i, irhs]
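For comparing results across Warp versions, a plain-Python mirror of the kernel for a single right-hand side may be useful. This is only a sketch: the name `trisolve_periodic_reference` is hypothetical, and it assumes the kernel solves the periodic tridiagonal system `am[i]*x[i-1] + ac[i]*x[i] + ap[i]*x[i+1] = f[i]` with the same coefficients in every row (the kernel reads index 0 coefficients in the last-row update, which suggests that, but the issue does not state it).

```python
def trisolve_periodic_reference(f, am, ac, ap):
    """Plain-Python mirror of the kernel for one right-hand side (hypothetical helper)."""
    n = len(f)
    x = list(f)          # work on a copy; the kernel overwrites x in place
    q = [0.0] * n
    s = [0.0] * n
    qe = [0.0] * n

    q[0] = -ap[0] / ac[0]
    s[0] = -am[0] / ac[0]
    fn = x[n - 1]        # save the last RHS entry before the sweep overwrites it
    x[0] = x[0] / ac[0]  # explicit form of the in-place division

    # forward elimination sweep
    for i in range(1, n):
        p = 1.0 / (ac[i] + am[i] * q[i - 1])
        q[i] = -ap[i] * p
        s[i] = -am[i] * s[i - 1] * p
        x[i] = (x[i] - am[i] * x[i - 1]) * p
    s[n - 1] = 1.0
    qe[n - 1] = 0.0

    # backward pass
    for i in range(n - 2, -1, -1):
        s[i] += q[i] * s[i + 1]
        qe[i] = x[i] + q[i] * qe[i + 1]
    x[n - 1] = ((fn - ap[0] * qe[0] - am[0] * qe[n - 2]) /
                (ap[0] * s[0] + am[0] * s[n - 2] + ac[0]))

    # backward elimination pass
    for i in range(n - 2, -1, -1):
        x[i] = x[n - 1] * s[i] + qe[i]
    return x


# Constant-coefficient example whose exact solution is x = [1, 2, 3]
n = 3
am, ac, ap = [1.0] * n, [4.0] * n, [1.0] * n
f = [9.0, 12.0, 15.0]
x_ref = trisolve_periodic_reference(f, am, ac, ap)
print(x_ref)
```

The printed values should be close to 1, 2, 3 up to floating-point rounding. Running the Warp kernel on the same inputs under each version and comparing one column of `x` against `x_ref` would show which versions compute the wrong result.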
After some digging, I think the issue might stem from the line with in-place division:
x[0, irhs] /= ac[0, irhs]
If I look at the generated code, I find the following in 1.3.1 for that line:
// x[0, irhs] /= ac[0, irhs] <L 21>
var_19 = wp::address(var_x, var_1, var_0);
var_20 = wp::address(var_ac, var_1, var_0);
var_21 = wp::load(var_19);
var_22 = wp::load(var_20);
var_23 = wp::div(var_21, var_22);
wp::array_store(var_x, var_1, var_0, var_23);
while in 1.4.0 I see only a single address computation, with no load, division, or store, which is clearly not right:
// x[0, irhs] /= ac[0, irhs] <L 21>
var_19 = wp::address(var_ac, var_1, var_0);
Compiling the program with 1.4.0 and 1.4.1 also prints a warning like:
Warning: in-place op <ast.Div object at 0x7f71739d6aa0> is not differentiable
which also indicates this in-place division might be causing issues.
As an additional experiment, I also changed the in-place division line to:
x[0, irhs] = x[0, irhs] / ac[0, irhs]
and that fixes it with Warp 1.4.0+.
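The `<ast.Div object ...>` in the warning suggests that the in-place operator reaches the code generator as a node from Python's `ast` module. The sketch below uses only the standard library, not Warp's own code path (which is not shown in this issue), to illustrate that the two spellings parse to different statement types. That would be consistent with them being handled by different branches of a code generator.

```python
import ast

# The in-place form from the failing kernel
inplace = ast.parse("x[0, irhs] /= ac[0, irhs]").body[0]
# The explicit form that works with 1.4.0+
explicit = ast.parse("x[0, irhs] = x[0, irhs] / ac[0, irhs]").body[0]

# In-place division is an augmented assignment carrying the operator directly
print(type(inplace).__name__, type(inplace.op).__name__)

# The explicit form is a plain assignment whose value is a binary operation
print(type(explicit).__name__,
      type(explicit.value).__name__,
      type(explicit.value.op).__name__)
```

The first line prints `AugAssign Div` and the second prints `Assign BinOp Div`. So a handler for augmented assignment that skips `Div` would leave the explicit form unaffected.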
I don't care that my kernel is not differentiable, so it is a little alarming that the in-place division seems to produce incorrect code with only a warning. Is this expected? If in-place division is not supported, should this warning be a fatal error instead?
System Information
No response