
[BUG] Incorrect code generation with in-place division when using Warp 1.4.0+ #342

@romerojosh

Bug Description

I have a program that works fine with Warp 1.3.1 and breaks with Warp 1.4.0+. So far I've found one kernel that triggers the problem:

```python
from typing import Any

import warp as wp

@wp.kernel
def trisolve_periodic_multi(x: wp.array2d(dtype=Any),
                            q: wp.array2d(dtype=Any),
                            s: wp.array2d(dtype=Any),
                            qe: wp.array2d(dtype=Any),
                            ap: wp.array2d(dtype=Any),
                            am: wp.array2d(dtype=Any),
                            ac: wp.array2d(dtype=Any),
                            n: int):

  irhs = wp.tid()

  q[0, irhs] = -ap[0, irhs] / ac[0, irhs]
  s[0, irhs] = -am[0, irhs] / ac[0, irhs]
  fn = x[n - 1, irhs]
  x[0, irhs] /= ac[0, irhs]

  # forward elimination sweep
  for i in range(1, n):
    p = x.dtype(1.0) / (ac[i, irhs] + am[i, irhs]* q[i - 1, irhs])
    q[i, irhs] = -ap[i, irhs] * p
    s[i, irhs] = -am[i, irhs] * s[i - 1, irhs] * p
    x[i, irhs] = (x[i, irhs] - am[i, irhs] * x[i - 1, irhs]) * p

  s[n - 1, irhs] = x.dtype(1.0)
  qe[n - 1, irhs] = x.dtype(0.0)

  # backward pass
  for i in range(n - 2, -1, -1):
    s[i, irhs] += q[i, irhs] * s[i + 1, irhs]
    qe[i, irhs] = x[i, irhs] + q[i, irhs] * qe[i + 1, irhs]

  x[n - 1, irhs] = ((fn - ap[0, irhs] * qe[0, irhs] - am[0, irhs] * qe[n - 2, irhs]) /
                    (ap[0, irhs] * s[0, irhs] + am[0, irhs] * s[n - 2, irhs] + ac[0, irhs]))

  # backward elimination pass
  for i in range(n - 2, -1, -1):
    x[i, irhs] = x[n - 1, irhs] * s[i, irhs] + qe[i, irhs]
```
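A plain-Python version of the same algorithm for a single right-hand-side column can serve as a reference for checking kernel output independently of Warp's code generation. This is a sketch. The names `trisolve_periodic_ref` and `periodic_matvec` are hypothetical. The reference mirrors the kernel line by line, including the closing equation for `x[n - 1]`, which uses the row-0 coefficients `ap[0]`, `am[0]` and `ac[0]`. For that reason, the round-trip check below uses constant coefficients, where row 0 and the last row match.

```python
def trisolve_periodic_ref(x, am, ac, ap):
    """Solve one periodic tridiagonal system in place, mirroring the kernel.

    Row i is am[i]*x[i-1] + ac[i]*x[i] + ap[i]*x[i+1] = rhs[i], indices wrapping
    around. On entry x holds the right-hand side; on return it holds the solution.
    """
    n = len(x)
    q, s, qe = [0.0] * n, [0.0] * n, [0.0] * n

    q[0] = -ap[0] / ac[0]
    s[0] = -am[0] / ac[0]
    fn = x[n - 1]      # keep the last right-hand side before the sweep overwrites it
    x[0] /= ac[0]      # the same in-place division as in the kernel

    # forward elimination sweep, treating x[n-1] as an unknown parameter
    for i in range(1, n):
        p = 1.0 / (ac[i] + am[i] * q[i - 1])
        q[i] = -ap[i] * p
        s[i] = -am[i] * s[i - 1] * p
        x[i] = (x[i] - am[i] * x[i - 1]) * p

    s[n - 1] = 1.0
    qe[n - 1] = 0.0

    # backward pass: expresses each x_i as qe[i] + s[i] * x_{n-1}
    for i in range(n - 2, -1, -1):
        s[i] += q[i] * s[i + 1]
        qe[i] = x[i] + q[i] * qe[i + 1]

    # closing equation for x[n-1]; like the kernel, it uses row-0 coefficients
    x[n - 1] = ((fn - ap[0] * qe[0] - am[0] * qe[n - 2]) /
                (ap[0] * s[0] + am[0] * s[n - 2] + ac[0]))

    # backward substitution
    for i in range(n - 2, -1, -1):
        x[i] = x[n - 1] * s[i] + qe[i]
    return x


def periodic_matvec(am, ac, ap, x):
    """Multiply the periodic tridiagonal matrix by x (x[-1] wraps to x[n-1])."""
    n = len(x)
    return [am[i] * x[i - 1] + ac[i] * x[i] + ap[i] * x[(i + 1) % n]
            for i in range(n)]


# Round trip with constant coefficients: build rhs = A @ x_true, then solve.
n = 6
am, ac, ap = [1.0] * n, [-4.0] * n, [2.0] * n
x_true = [float(i + 1) for i in range(n)]
x = trisolve_periodic_ref(periodic_matvec(am, ac, ap, x_true), am, ac, ap)
err = max(abs(a - b) for a, b in zip(x, x_true))
print(f"max abs error: {err:.1e}")
```

Run the Warp kernel on the same inputs and compare its output with this reference to spot a miscompiled statement. If the division on row 0 were dropped, the error would in general propagate through both sweeps, so most entries would come out wrong, not just `x[0]`.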

After some digging, I think the issue might stem from the line with in-place division:

```python
x[0, irhs] /= ac[0, irhs]
```

If I look at the generated code, I find the following in 1.3.1 for that line:

```cpp
// x[0, irhs] /= ac[0, irhs]                                        <L 21>
var_19 = wp::address(var_x, var_1, var_0);
var_20 = wp::address(var_ac, var_1, var_0);
var_21 = wp::load(var_19);
var_22 = wp::load(var_20);
var_23 = wp::div(var_21, var_22);
wp::array_store(var_x, var_1, var_0, var_23);
```

In 1.4.0, by contrast, I see only a single `wp::address` call on `ac`, with no load, no division, and no store back to `x`, which is clearly not right:

```cpp
// x[0, irhs] /= ac[0, irhs]                                        <L 21>
var_19 = wp::address(var_ac, var_1, var_0);
```

When compiling the program, Warp 1.4.0 and 1.4.1 also print a warning like:

```
Warning: in-place op <ast.Div object at 0x7f71739d6aa0> is not differentiable
```

This also suggests that the in-place division is the source of the problem.
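The `<ast.Div object at ...>` in that message is an operator node from Python's standard `ast` module. This suggests the warning is raised while Warp walks an augmented-assignment (`AugAssign`) statement. A small scanner built on the same module could find other statements of this shape in kernel source. This is a sketch: `find_inplace_subscript_ops` is a hypothetical helper. It only flags the syntax and knows nothing about which operators Warp actually miscompiles.

```python
import ast


def find_inplace_subscript_ops(source):
    """Return sorted (line, operator) pairs for each `a[...] op= value` statement."""
    hits = []
    for node in ast.walk(ast.parse(source)):
        # AugAssign is the node type for `target op= value`
        if isinstance(node, ast.AugAssign) and isinstance(node.target, ast.Subscript):
            hits.append((node.lineno, type(node.op).__name__))
    return sorted(hits)


# Two in-place updates taken from the kernel above, plus a plain read.
snippet = """\
fn = x[n - 1, irhs]
x[0, irhs] /= ac[0, irhs]
s[i, irhs] += q[i, irhs] * s[i + 1, irhs]
"""
print(find_inplace_subscript_ops(snippet))  # [(2, 'Div'), (3, 'Add')]
```

The kernel also updates `s[i, irhs]` in place with `+=`, but the incorrect code described here was only observed for the `/=` line.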

As an additional experiment, I changed the in-place division line to:

```python
x[0, irhs] = x[0, irhs] / ac[0, irhs]
```

and that fixes it with Warp 1.4.0+.
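The workaround is simply the explicit form of the augmented assignment. If many kernels are affected, the same rewrite could be applied mechanically to kernel source with Python's `ast` module. The sketch below covers only the source-to-source step, not how the result would be loaded into Warp. Some caveats:

- `DesugarSubscriptAugAssign` and `desugar` are hypothetical names.
- `ast.unparse` needs Python 3.9+.
- The rewrite evaluates the index expression twice, which is harmless for side-effect-free indices such as `[0, irhs]`.

```python
import ast
import copy


class DesugarSubscriptAugAssign(ast.NodeTransformer):
    """Rewrite `a[idx] op= v` into `a[idx] = a[idx] op v` (hypothetical helper)."""

    def visit_AugAssign(self, node):
        if not isinstance(node.target, ast.Subscript):
            return node  # leave `name op= v` untouched
        load_target = copy.deepcopy(node.target)
        load_target.ctx = ast.Load()  # the right-hand copy reads the element
        new = ast.Assign(
            targets=[node.target],
            value=ast.BinOp(left=load_target, op=node.op, right=node.value),
        )
        return ast.copy_location(new, node)


def desugar(source):
    tree = DesugarSubscriptAugAssign().visit(ast.parse(source))
    ast.fix_missing_locations(tree)  # give the new BinOp node line info
    return ast.unparse(tree)         # requires Python 3.9+


print(desugar("x[0, irhs] /= ac[0, irhs]"))  # x[0, irhs] = x[0, irhs] / ac[0, irhs]
```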

I don't need this kernel to be differentiable, but it is alarming that a statement flagged only with a non-differentiability warning silently produces incorrect forward code. Is this expected, and should this warning be a fatal error?

System Information

No response
