Noticed a difference between jax.image.resize and F.interpolate when using "bicubic" #15768

@Haotian-Zhang

Description

Hi, I am trying to convert a PyTorch model to a JAX model, and I find that the "bicubic" implementations differ between the two libraries.

Here is a script I have used to confirm that the outputs are different.

import torch
import torch.nn as nn
import jax.numpy as jnp
import jax
import numpy as np

input_arr = jnp.array([
    [0, 1, 2, 3, 4],
    [5, 6, 7, 8, 9],
]) / 9

torch_input_arr = torch.from_numpy(np.array(input_arr)).unsqueeze(0).unsqueeze(0)
output = jax.image.resize(
    image=input_arr,
    shape=(4,10),
    method="bicubic")

interpolate_without_align_corners = torch.nn.functional.interpolate(
    torch_input_arr.float(),
    size=(4, 10),
    mode="bicubic",
    align_corners=False,
)

print("jax: ", output)
print("torch interpolate without align corners",interpolate_without_align_corners)

From the prints I get:

jax:  [[-0.05882353 -0.03036592  0.0298608   0.08986928  0.14542481  0.20098035
   0.25653598  0.31654444  0.37677112  0.40522873]
 [ 0.10527544  0.13373306  0.19395977  0.25396827  0.30952382  0.36507937
   0.42063496  0.48064342  0.54087013  0.5693277 ]
 [ 0.4306723   0.45912993  0.5193566   0.57936513  0.6349207   0.69047624
   0.74603176  0.8060403   0.86626697  0.89472455]
 [ 0.5947712   0.62322885  0.6834555   0.74346405  0.79901963  0.85457516
   0.9101307   0.9701392   1.0303658   1.0588235 ]]
torch interpolate without align corners tensor([[[[-0.0703, -0.0373,  0.0156,  0.0855,  0.1306,  0.1966,  0.2418,
            0.3116,  0.3646,  0.3976],
          [ 0.1141,  0.1471,  0.2001,  0.2700,  0.3151,  0.3811,  0.4262,
            0.4961,  0.5490,  0.5820],
          [ 0.4180,  0.4510,  0.5039,  0.5738,  0.6189,  0.6849,  0.7300,
            0.7999,  0.8529,  0.8859],
          [ 0.6024,  0.6354,  0.6884,  0.7582,  0.8034,  0.8694,  0.9145,
            0.9844,  1.0373,  1.0703]]]])

I tried some other resizing methods, e.g., linear and bilinear, and their outputs match. Does anyone have a workaround for bicubic?
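A possible starting point for a workaround, offered as a sketch rather than a confirmed answer. It rests on two assumptions that this issue does not verify: that PyTorch's bicubic mode uses the cubic convolution kernel with coefficient a = -0.75 and clamps out-of-range taps to the border, and that jax.image.resize uses the Keys kernel (a = -0.5) and renormalizes the weights at the edges. The helper names `cubic_weights` and `resize_bicubic_torch_like` are hypothetical, not part of either library. The sketch builds one weight matrix per axis in NumPy and applies them as two matrix products.

```python
import numpy as np


def cubic_weights(in_size, out_size, a=-0.75):
    """Return an (out_size, in_size) matrix of 1-D cubic convolution weights.

    Assumption: a = -0.75 and border clamping mimic PyTorch's bicubic mode
    with align_corners=False.
    """
    scale = in_size / out_size
    # Half-pixel centres: output pixel j samples the input at this coordinate.
    src = (np.arange(out_size) + 0.5) * scale - 0.5
    base = np.floor(src).astype(int)
    t = src - base
    rows = np.arange(out_size)
    weights = np.zeros((out_size, in_size))
    for k in range(-1, 3):  # the four taps at base-1, base, base+1, base+2
        d = np.abs(t - k)  # distance from the sample point to this tap
        w = np.where(
            d <= 1,
            ((a + 2) * d - (a + 3)) * d * d + 1,
            ((a * d - 5 * a) * d + 8 * a) * d - 4 * a,
        )
        # Out-of-range taps are clamped to the nearest edge pixel.
        idx = np.clip(base + k, 0, in_size - 1)
        weights[rows, idx] += w
    return weights


def resize_bicubic_torch_like(image, shape, a=-0.75):
    """Resize a 2-D array by applying the 1-D weights along each axis."""
    wh = cubic_weights(image.shape[0], shape[0], a)
    ww = cubic_weights(image.shape[1], shape[1], a)
    return wh @ image @ ww.T


image = np.array([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]) / 9
out = resize_bicubic_torch_like(image, (4, 10))
print(np.round(out, 4))
```

On the input from the script above, the top-left entry comes out at about -0.0703, which agrees with the PyTorch print rather than the JAX one. Since the weight matrices are constants that depend only on the sizes, the same two matrix products could presumably be done in JAX after converting them with `jnp.asarray`.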

What jax/jaxlib version are you using?

No response

Which accelerator(s) are you using?

No response

Additional system info

No response

NVIDIA GPU info

No response

Labels

bug (Something isn't working)
