forked from ltkong218/IFRNet
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_with_ft.py
More file actions
113 lines (76 loc) · 3.18 KB
/
Copy pathtest_with_ft.py
File metadata and controls
113 lines (76 loc) · 3.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import os
import numpy as np
import torch
from models.quantize.IFRNet_without_ft import Model
from utils import read
from imageio import mimsave
from imageio import imwrite
import time
import torch.nn as nn
import torch.nn.functional as F
from utils import warp, get_robust_weight
from loss import *
from datasets import Vimeo90K_Train_Dataset, Vimeo90K_Test_Dataset
from torchinfo import summary
from torch.utils.data import DataLoader
import torch.quantization as tq
from collections import OrderedDict
from torch.quantization import fuse_modules
# device = "cuda"
# model = Model().to(device)
# torch.save(model.state_dict(), "./testing_weight/model_weights_without_ft.pth")
def find_conv_relu_pairs(model):
    """Collect adjacent ``Conv2d`` -> ``ReLU`` children of every ``nn.Sequential``.

    Walks ``model.named_modules()`` (which includes ``model`` itself) and, for
    each ``nn.Sequential`` encountered, inspects consecutive child pairs in
    registration order.

    Args:
        model: The ``nn.Module`` to scan.

    Returns:
        A list of ``[conv_name, relu_name]`` pairs. Each name is the dotted
        module path relative to ``model``, i.e. the list-of-lists form that
        ``fuse_modules`` accepts.
    """
    pairs = []
    for name, module in model.named_modules():
        if not isinstance(module, nn.Sequential):
            continue
        # The root module is reported with an empty name; prefixing it
        # unconditionally would produce paths such as ".0" that cannot be
        # resolved from the root.
        prefix = f"{name}." if name else ""
        children = list(module._modules.items())
        for (first_key, first), (second_key, second) in zip(children, children[1:]):
            if isinstance(first, nn.Conv2d) and isinstance(second, nn.ReLU):
                pairs.append([f"{prefix}{first_key}", f"{prefix}{second_key}"])
    return pairs
# Build the model on CPU, load its weights and convert it with eager-mode
# post-training static quantization (fbgemm backend).
device = "cpu"
model = Model().to(device)
# map_location keeps the load working when the checkpoint was written from a
# CUDA model and this machine has no GPU.
model.load_state_dict(
    torch.load("./testing_weight/model_weights_without_ft.pth", map_location=device)
)
model.eval()

# Opt every transposed convolution out of quantization: an explicit
# ``qconfig = None`` on a submodule overrides the model-wide qconfig set below.
for module in model.modules():
    if isinstance(module, torch.nn.ConvTranspose2d):
        module.qconfig = None

# Fuse each Conv2d+ReLU pair so it is quantized as a single op.
fuse_list = find_conv_relu_pairs(model)
# An empty list would be interpreted as one empty fusion group, for which no
# fuser exists, so skip the call when nothing was found.
if fuse_list:
    fuse_modules(model, fuse_list, inplace=True)

model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
tq.prepare(model, inplace=True)
# NOTE(review): no calibration data is run through the prepared model before
# convert(), so the inserted observers have recorded no activation statistics
# when the quantization parameters are computed. Presumably a few
# representative frame pairs should be passed through ``model.inference``
# here -- confirm whether skipping calibration is intentional.
tq.convert(model, inplace=True)
# Walk the numbered input frames in order, predict one intermediate frame for
# every consecutive pair, and write originals and predictions interleaved:
# img_0001 = frame 1, img_0002 = prediction(1, 2), img_0003 = frame 2, ...
source_path = './testing_data/input_occlusion2'
save_path = './testing_data/output_occlusion2'
file_type = 'png'
# imwrite does not create missing directories.
os.makedirs(save_path, exist_ok=True)

i = 1
prev_img_np = read(f'{source_path}/frame_{str(i).zfill(4)}.{file_type}')
prev_img = (torch.tensor(prev_img_np.transpose(2, 0, 1)).float() / 255.0).unsqueeze(0).to(device)
imwrite(f'{save_path}/img_0001.{file_type}', prev_img_np)

# Loop-invariant time-step tensor handed to ``model.inference``.
# NOTE(review): presumably 0.5 selects the temporal midpoint -- confirm
# against the model's inference signature.
embt = torch.tensor(0.5, device=device).view(1, 1, 1, 1)

total_time = 0
pair_count = 0  # number of frame pairs actually interpolated
while True:
    next_path = f'{source_path}/frame_{str(i + 1).zfill(4)}.{file_type}'
    # Stop at the first missing frame instead of relying on a bare
    # ``except`` that would also hide real inference or write errors.
    if not os.path.isfile(next_path):
        break
    i += 1

    next_img_np = read(next_path)
    next_img = (torch.tensor(next_img_np.transpose(2, 0, 1)).float() / 255.0).unsqueeze(0).to(device)

    start = time.perf_counter()
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        imgt_pred = model.inference(prev_img, next_img, embt)
    end = time.perf_counter()
    print(f"Elapsed time: {end - start:.4f} seconds")
    total_time += end - start
    pair_count += 1

    # Clamp before the uint8 cast so out-of-range values cannot wrap around.
    imgt_pred_np = (
        imgt_pred[0].detach().clamp(0, 1).permute(1, 2, 0).cpu().numpy() * 255.0
    ).astype(np.uint8)
    imwrite(f'{save_path}/img_{str((i-1)*2).zfill(4)}.{file_type}', imgt_pred_np)
    imwrite(f'{save_path}/img_{str((i-1)*2 + 1).zfill(4)}.{file_type}', next_img_np)

    prev_img = next_img
    print(f'i_th : {i}')

# Average over the pairs that were really processed (the previous divisor was
# one too large because the index had already advanced to the missing frame).
if pair_count:
    print(f'average time : {total_time/pair_count}')
else:
    print('average time : n/a (no frame pairs found)')