Description
Hi, I am trying to convert a PyTorch model to a JAX model, and I find that the "bicubic" resize implementations give different results.
Here is a script I have used to confirm that the outputs are different.
import torch
import torch.nn as nn
import jax.numpy as jnp
import jax
import numpy as np

input_arr = jnp.array([
    [0, 1, 2, 3, 4],
    [5, 6, 7, 8, 9],
]) / 9
torch_input_arr = torch.from_numpy(np.array(input_arr)).unsqueeze(0).unsqueeze(0)

output = jax.image.resize(
    image=input_arr,
    shape=(4, 10),
    method="bicubic")

interpolate_without_align_corners = torch.nn.functional.interpolate(
    torch_input_arr.float(),
    size=(4, 10),
    mode="bicubic",
    align_corners=False,
)

print("jax: ", output)
print("torch interpolate without align corners", interpolate_without_align_corners)
From the prints I get:
jax: [[-0.05882353 -0.03036592 0.0298608 0.08986928 0.14542481 0.20098035
0.25653598 0.31654444 0.37677112 0.40522873]
[ 0.10527544 0.13373306 0.19395977 0.25396827 0.30952382 0.36507937
0.42063496 0.48064342 0.54087013 0.5693277 ]
[ 0.4306723 0.45912993 0.5193566 0.57936513 0.6349207 0.69047624
0.74603176 0.8060403 0.86626697 0.89472455]
[ 0.5947712 0.62322885 0.6834555 0.74346405 0.79901963 0.85457516
0.9101307 0.9701392 1.0303658 1.0588235 ]]
torch interpolate without align corners tensor([[[[-0.0703, -0.0373, 0.0156, 0.0855, 0.1306, 0.1966, 0.2418,
0.3116, 0.3646, 0.3976],
[ 0.1141, 0.1471, 0.2001, 0.2700, 0.3151, 0.3811, 0.4262,
0.4961, 0.5490, 0.5820],
[ 0.4180, 0.4510, 0.5039, 0.5738, 0.6189, 0.6849, 0.7300,
0.7999, 0.8529, 0.8859],
[ 0.6024, 0.6354, 0.6884, 0.7582, 0.8034, 0.8694, 0.9145,
0.9844, 1.0373, 1.0703]]]])
I tried some other resizing methods, e.g., linear and bilinear, and they look fine. Does anyone have a workaround for this?
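One possible workaround, offered as a sketch rather than a confirmed fix: as far as I can tell, the gap comes from two things. `jax.image.resize` uses the Keys cubic kernel (coefficient a = -0.5) and renormalizes weights at the borders, while PyTorch's bicubic mode uses a = -0.75 and clamps out-of-range sample indices to the edge. The helper names below (`cubic_weight_matrix`, `resize_bicubic_torch_like`) are hypothetical and not part of either library. The code handles only 2-D upsampling with `align_corners=False` semantics, and does no antialiasing.

```python
import jax
import jax.numpy as jnp

HIGHEST = jax.lax.Precision.HIGHEST  # avoid reduced-precision matmuls on accelerators


def cubic_weight_matrix(in_size, out_size, a=-0.75):
    """Return an (out_size, in_size) matrix of 1-D cubic-convolution weights."""
    scale = in_size / out_size
    # Half-pixel-centre mapping, i.e. align_corners=False.
    src = (jnp.arange(out_size) + 0.5) * scale - 0.5
    base = jnp.floor(src)
    t = src - base

    def near(x):  # kernel branch for |x| <= 1
        return ((a + 2.0) * x - (a + 3.0)) * x * x + 1.0

    def far(x):  # kernel branch for 1 < |x| < 2
        return ((a * x - 5.0 * a) * x + 8.0 * a) * x - 4.0 * a

    # Weights of the four taps at base-1, base, base+1, base+2.
    coeffs = jnp.stack([far(t + 1.0), near(t), near(1.0 - t), far(2.0 - t)], axis=1)
    taps = base.astype(jnp.int32)[:, None] + jnp.arange(-1, 3)[None, :]
    # Assumption: out-of-range taps are clamped to the edge pixel.
    taps = jnp.clip(taps, 0, in_size - 1)
    one_hot = jax.nn.one_hot(taps, in_size)
    return jnp.einsum("ok,oki->oi", coeffs, one_hot, precision=HIGHEST)


def resize_bicubic_torch_like(image, shape, a=-0.75):
    """Separable bicubic resize of a 2-D array to `shape` (height, width)."""
    wh = cubic_weight_matrix(image.shape[0], shape[0], a)
    ww = cubic_weight_matrix(image.shape[1], shape[1], a)
    return jnp.einsum("hi,ij,wj->hw", wh, image, ww, precision=HIGHEST)


input_arr = jnp.array([
    [0, 1, 2, 3, 4],
    [5, 6, 7, 8, 9],
]) / 9
out = resize_bicubic_torch_like(input_arr, (4, 10))
print(out)  # out[0, 0] is approximately -0.0703, as in the PyTorch output above
```

On the 2x5 example from this report, this sketch reproduces the PyTorch values printed above to the four decimals shown. I have not checked it for downsampling or against `align_corners=True`.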
What jax/jaxlib version are you using?
No response
Which accelerator(s) are you using?
No response
Additional system info
No response
NVIDIA GPU info
No response