diff --git a/README.md b/README.md
index 55f9a2f..38fdc7f 100644
--- a/README.md
+++ b/README.md
@@ -2,6 +2,10 @@
 This repository contains the code (in PyTorch) for "[Pyramid Stereo Matching Network](https://arxiv.org/abs/1803.08669)" paper (CVPR 2018) by [Jia-Ren Chang](https://jiarenchang.github.io/) and [Yong-Sheng Chen](https://people.cs.nctu.edu.tw/~yschen/).
 
+#### Changelog
+2020/12/20: Updated PSMNet: it now supports torch 1.6.0 / torchvision 0.5.0 and Python 3.7; inconsistent indentation has been removed.
+
+2020/12/20: Our proposed real-time stereo model can be found at [Real-time Stereo](https://github.com/JiaRenChang/RealtimeStereo).
+
 ### Citation
 ```
 @inproceedings{chang2018pyramid,
@@ -30,9 +34,9 @@ Recent work has shown that depth estimation from a stereo pair of images can be
 
 ### Dependencies
 
-- [Python2.7](https://www.python.org/downloads/)
-- [PyTorch(0.4.0+)](http://pytorch.org)
-- torchvision 0.2.0 (higher version may cause issues)
+- [Python 3.7](https://www.python.org/downloads/)
+- [PyTorch (1.6.0+)](http://pytorch.org)
+- torchvision 0.5.0
 - [KITTI Stereo](http://www.cvlibs.net/datasets/kitti/eval_stereo.php)
 - [Scene Flow](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html)
 
@@ -42,6 +46,9 @@ Download the RGB cleanpass images and their disparity maps for the three subsets: FlyingThings3D
 Put them in the same folder, and rename the folders as: "driving_frames_cleanpass", "driving_disparity", "monkaa_frames_cleanpass", "monkaa_disparity", "frames_cleanpass", "frames_disparity".
 ```
+### Notice
+1. Warning about the upsample function in PyTorch 0.4.1+: add "align_corners=True" to the upsample calls (see the short sketch near the pretrained-model links below).
+2. The output disparity may be better when multiplied by 1.17, as reported in issues [#135](https://github.com/JiaRenChang/PSMNet/issues/135) and [#113](https://github.com/JiaRenChang/PSMNet/issues/113).
 
 ### Train
 As an example, use the following command to train PSMNet on Scene Flow
@@ -51,7 +58,7 @@ python main.py --maxdisp 192 \
                --model stackhourglass \
                --datapath (your scene flow data folder)\
                --epochs 10 \
-               --loadmodel (optional)\
+               --loadmodel (optional)\
                --savemodel (path for saving model)
 ```
@@ -66,7 +73,7 @@ python finetune.py --maxdisp 192 \
                    --loadmodel (pretrained PSMNet) \
                    --savemodel (path for saving model)
 ```
-You can alse see those example in run.sh
+You can also see those examples in run.sh.
 
 ### Evaluation
 Use the following command to evaluate the trained PSMNet on KITTI 2015 test data
@@ -84,17 +91,23 @@ python submission.py --maxdisp 192 \
 
 Update: 2018/9/6 We released the pre-trained KITTI 2012 model.
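+
+A minimal sketch of the two notices above (`cost`, `pred_disp`, `maxdisp`, `height`, and `width` are stand-ins for the cost volume, the network output, and its target size):
+
+```
+import torch.nn.functional as F
+
+# Notice 1: pass align_corners=True explicitly; PyTorch 0.4.1+ changed
+# the default behaviour of the upsample/interpolate functions.
+cost = F.upsample(cost, [maxdisp, height, width], mode='trilinear', align_corners=True)
+
+# Notice 2: a constant rescale of the predicted disparity may reduce the
+# error, with 1.17 being the factor reported in issues #135 and #113.
+pred_disp = pred_disp * 1.17
+```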
-| KITTI 2015 | Scene Flow | KITTI 2012|
-|---|---|---|
-|[Google Drive](https://drive.google.com/file/d/1pHWjmhKMG4ffCrpcsp_MTXMJXhgl3kF9/view?usp=sharing)|[Google Drive](https://drive.google.com/file/d/1xoqkQ2NXik1TML_FMUTNZJFAHrhLdKZG/view?usp=sharing)|[Google Drive](https://drive.google.com/file/d/1p4eJ2xDzvQxaqB20A_MmSP9-KORBX1pZ/view)|
+Update: 2021/9/22 We released a Scene Flow model pretrained with torch 1.8.1 (the previous model weights were trained with torch 0.4.1).
+
+| KITTI 2015 | Scene Flow | KITTI 2012 | Scene Flow (torch 1.8.1) |
+|---|---|---|---|
+|[Google Drive](https://drive.google.com/file/d/1pHWjmhKMG4ffCrpcsp_MTXMJXhgl3kF9/view?usp=sharing)|[Google Drive](https://drive.google.com/file/d/1xoqkQ2NXik1TML_FMUTNZJFAHrhLdKZG/view?usp=sharing)|[Google Drive](https://drive.google.com/file/d/1p4eJ2xDzvQxaqB20A_MmSP9-KORBX1pZ/view?usp=sharing)|[Google Drive](https://drive.google.com/file/d/1NDKrWHkwgMKtDwynXVU12emK3G5d5kkp/view?usp=sharing)|
+
+### Test on your own stereo pair
+```
+python Test_img.py --loadmodel (finetuned PSMNet) --leftimg ./left.png --rightimg ./right.png
+```
 
 ## Results
 
-### Evalutation of PSMNet with different settings
+### Evaluation of PSMNet with different settings
 
-※Note that the reported 3-px validation errors were calculated using KITTI's offical matlab code, not our code.
+※Note that the reported 3-px validation errors were calculated using KITTI's official matlab code, not our code.
 
 ### Results on KITTI 2015 leaderboard
 [Leaderboard Link](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=stereo)
diff --git a/Test_img.py b/Test_img.py
new file mode 100644
index 0000000..31340ef
--- /dev/null
+++ b/Test_img.py
@@ -0,0 +1,132 @@
+from __future__ import print_function
+import argparse
+import os
+import random
+import torch
+import torch.nn as nn
+import torchvision.transforms as transforms
+import torch.nn.functional as F
+import numpy as np
+import time
+import math
+from models import *
+import cv2
+from PIL import Image
+
+# 2012 data /media/jiaren/ImageNet/data_scene_flow_2012/testing/
+
+parser = argparse.ArgumentParser(description='PSMNet')
+parser.add_argument('--KITTI', default='2015',
+                    help='KITTI version')
+parser.add_argument('--datapath', default='/media/jiaren/ImageNet/data_scene_flow_2015/testing/',
+                    help='data path')
+parser.add_argument('--loadmodel', default='./trained/pretrained_model_KITTI2015.tar',
+                    help='loading model')
+parser.add_argument('--leftimg', default='./VO04_L.png',
+                    help='path to the left image')
+parser.add_argument('--rightimg', default='./VO04_R.png',
+                    help='path to the right image')
+parser.add_argument('--model', default='stackhourglass',
+                    help='select model')
+parser.add_argument('--maxdisp', type=int, default=192,
+                    help='maximum disparity')
+parser.add_argument('--no-cuda', action='store_true', default=False,
+                    help='disables CUDA')
+parser.add_argument('--seed', type=int, default=1, metavar='S',
+                    help='random seed (default: 1)')
+args = parser.parse_args()
+args.cuda = not args.no_cuda and torch.cuda.is_available()
+
+torch.manual_seed(args.seed)
+if args.cuda:
+    torch.cuda.manual_seed(args.seed)
+
+if args.model == 'stackhourglass':
+    model = stackhourglass(args.maxdisp)
+elif args.model == 'basic':
+    model = basic(args.maxdisp)
+else:
+    print('no model')
+
+model = nn.DataParallel(model, device_ids=[0])
+model.cuda()
+
+if args.loadmodel is not None:
+    print('load PSMNet')
+    state_dict = torch.load(args.loadmodel)
+    model.load_state_dict(state_dict['state_dict'])
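+
+# Note: the released checkpoints store the weights under the 'state_dict' key
+# and expect the DataParallel wrapper applied above, so keep the wrapping in
+# place before calling load_state_dict.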
+print('Number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()])))
+
+def test(imgL, imgR):
+    model.eval()
+
+    if args.cuda:
+        imgL = imgL.cuda()
+        imgR = imgR.cuda()
+
+    with torch.no_grad():
+        disp = model(imgL, imgR)
+
+    disp = torch.squeeze(disp)
+    pred_disp = disp.data.cpu().numpy()
+
+    return pred_disp
+
+
+def main():
+
+    normal_mean_var = {'mean': [0.485, 0.456, 0.406],
+                       'std': [0.229, 0.224, 0.225]}
+    infer_transform = transforms.Compose([transforms.ToTensor(),
+                                          transforms.Normalize(**normal_mean_var)])
+
+    imgL_o = Image.open(args.leftimg).convert('RGB')
+    imgR_o = Image.open(args.rightimg).convert('RGB')
+
+    imgL = infer_transform(imgL_o)
+    imgR = infer_transform(imgR_o)
+
+    # pad height and width up to multiples of 16
+    if imgL.shape[1] % 16 != 0:
+        times = imgL.shape[1]//16
+        top_pad = (times+1)*16 - imgL.shape[1]
+    else:
+        top_pad = 0
+
+    if imgL.shape[2] % 16 != 0:
+        times = imgL.shape[2]//16
+        right_pad = (times+1)*16 - imgL.shape[2]
+    else:
+        right_pad = 0
+
+    imgL = F.pad(imgL, (0, right_pad, top_pad, 0)).unsqueeze(0)
+    imgR = F.pad(imgR, (0, right_pad, top_pad, 0)).unsqueeze(0)
+
+    start_time = time.time()
+    pred_disp = test(imgL, imgR)
+    print('time = %.2f' % (time.time() - start_time))
+
+    if top_pad != 0 and right_pad != 0:
+        img = pred_disp[top_pad:, :-right_pad]
+    elif top_pad == 0 and right_pad != 0:
+        img = pred_disp[:, :-right_pad]
+    elif top_pad != 0 and right_pad == 0:
+        img = pred_disp[top_pad:, :]
+    else:
+        img = pred_disp
+
+    img = (img*256).astype('uint16')
+    img = Image.fromarray(img)
+    img.save('Test_disparity.png')
+
+if __name__ == '__main__':
+    main()
diff --git a/dataloader/KITTI_submission_loader.py b/dataloader/KITTI_submission_loader.py
index ad73745..cd4252e 100644
--- a/dataloader/KITTI_submission_loader.py
+++ b/dataloader/KITTI_submission_loader.py
@@ -16,14 +16,14 @@ def is_image_file(filename):
 
 def dataloader(filepath):
 
-  left_fold = 'image_2/'
-  right_fold = 'image_3/'
+    left_fold = 'image_2/'
+    right_fold = 'image_3/'
 
-  image = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1]
+    image = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1]
 
-  left_test = [filepath+left_fold+img for img in image]
-  right_test = [filepath+right_fold+img for img in image]
+    left_test = [filepath+left_fold+img for img in image]
+    right_test = [filepath+right_fold+img for img in image]
 
-  return left_test, right_test
+    return left_test, right_test
diff --git a/dataloader/KITTI_submission_loader2012.py b/dataloader/KITTI_submission_loader2012.py
index 7ca9859..0767ab6 100644
--- a/dataloader/KITTI_submission_loader2012.py
+++ b/dataloader/KITTI_submission_loader2012.py
@@ -16,14 +16,14 @@ def is_image_file(filename):
 
 def dataloader(filepath):
 
-  left_fold = 'colored_0/'
-  right_fold = 'colored_1/'
+    left_fold = 'colored_0/'
+    right_fold = 'colored_1/'
 
-  image = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1]
+    image = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1]
 
-  left_test = [filepath+left_fold+img for img in image]
-  right_test = [filepath+right_fold+img for img in image]
+    left_test = [filepath+left_fold+img for img in image]
+    right_test = [filepath+right_fold+img for img in image]
 
-  return left_test, right_test
+    return left_test, right_test
diff --git a/dataloader/KITTIloader2012.py b/dataloader/KITTIloader2012.py
index d651a89..6a0f944 100644
--- a/dataloader/KITTIloader2012.py
+++ b/dataloader/KITTIloader2012.py
@@ -16,22 +16,22 @@ def is_image_file(filename):
 
 def
dataloader(filepath): - left_fold = 'colored_0/' - right_fold = 'colored_1/' - disp_noc = 'disp_occ/' + left_fold = 'colored_0/' + right_fold = 'colored_1/' + disp_noc = 'disp_occ/' - image = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1] + image = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1] - train = image[:] - val = image[160:] + train = image[:] + val = image[160:] - left_train = [filepath+left_fold+img for img in train] - right_train = [filepath+right_fold+img for img in train] - disp_train = [filepath+disp_noc+img for img in train] + left_train = [filepath+left_fold+img for img in train] + right_train = [filepath+right_fold+img for img in train] + disp_train = [filepath+disp_noc+img for img in train] - left_val = [filepath+left_fold+img for img in val] - right_val = [filepath+right_fold+img for img in val] - disp_val = [filepath+disp_noc+img for img in val] + left_val = [filepath+left_fold+img for img in val] + right_val = [filepath+right_fold+img for img in val] + disp_val = [filepath+disp_noc+img for img in val] - return left_train, right_train, disp_train, left_val, right_val, disp_val + return left_train, right_train, disp_train, left_val, right_val, disp_val diff --git a/dataloader/KITTIloader2015.py b/dataloader/KITTIloader2015.py index c443189..0eb1cf4 100644 --- a/dataloader/KITTIloader2015.py +++ b/dataloader/KITTIloader2015.py @@ -16,24 +16,24 @@ def is_image_file(filename): def dataloader(filepath): - left_fold = 'image_2/' - right_fold = 'image_3/' - disp_L = 'disp_occ_0/' - disp_R = 'disp_occ_1/' + left_fold = 'image_2/' + right_fold = 'image_3/' + disp_L = 'disp_occ_0/' + disp_R = 'disp_occ_1/' - image = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1] + image = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1] - train = image[:160] - val = image[160:] + train = image[:160] + val = image[160:] - left_train = [filepath+left_fold+img for img in train] - right_train = [filepath+right_fold+img for img in train] - disp_train_L = [filepath+disp_L+img for img in train] - #disp_train_R = [filepath+disp_R+img for img in train] + left_train = [filepath+left_fold+img for img in train] + right_train = [filepath+right_fold+img for img in train] + disp_train_L = [filepath+disp_L+img for img in train] + #disp_train_R = [filepath+disp_R+img for img in train] - left_val = [filepath+left_fold+img for img in val] - right_val = [filepath+right_fold+img for img in val] - disp_val_L = [filepath+disp_L+img for img in val] - #disp_val_R = [filepath+disp_R+img for img in val] + left_val = [filepath+left_fold+img for img in val] + right_val = [filepath+right_fold+img for img in val] + disp_val_L = [filepath+disp_L+img for img in val] + #disp_val_R = [filepath+disp_R+img for img in val] - return left_train, right_train, disp_train_L, left_val, right_val, disp_val_L + return left_train, right_train, disp_train_L, left_val, right_val, disp_val_L diff --git a/dataloader/SecenFlowLoader.py b/dataloader/SecenFlowLoader.py index af6f30a..a6ba537 100644 --- a/dataloader/SecenFlowLoader.py +++ b/dataloader/SecenFlowLoader.py @@ -5,9 +5,9 @@ import torchvision.transforms as transforms import random from PIL import Image, ImageOps -import preprocess -import listflowfile as lt -import readpfm as rp +from . import preprocess +from . import listflowfile as lt +from . 
import readpfm as rp import numpy as np IMG_EXTENSIONS = [ @@ -24,7 +24,6 @@ def default_loader(path): def disparity_loader(path): return rp.readPFM(path) - class myImageFloder(data.Dataset): def __init__(self, left, right, left_disparity, training, loader=default_loader, dploader= disparity_loader): @@ -46,34 +45,28 @@ def __getitem__(self, index): dataL, scaleL = self.dploader(disp_L) dataL = np.ascontiguousarray(dataL,dtype=np.float32) - - if self.training: - w, h = left_img.size - th, tw = 256, 512 - - x1 = random.randint(0, w - tw) - y1 = random.randint(0, h - th) + w, h = left_img.size + th, tw = 256, 512 - left_img = left_img.crop((x1, y1, x1 + tw, y1 + th)) - right_img = right_img.crop((x1, y1, x1 + tw, y1 + th)) + x1 = random.randint(0, w - tw) + y1 = random.randint(0, h - th) - dataL = dataL[y1:y1 + th, x1:x1 + tw] + left_img = left_img.crop((x1, y1, x1 + tw, y1 + th)) + right_img = right_img.crop((x1, y1, x1 + tw, y1 + th)) - processed = preprocess.get_transform(augment=False) - left_img = processed(left_img) - right_img = processed(right_img) + dataL = dataL[y1:y1 + th, x1:x1 + tw] - return left_img, right_img, dataL - else: - w, h = left_img.size - left_img = left_img.crop((w-960, h-544, w, h)) - right_img = right_img.crop((w-960, h-544, w, h)) - processed = preprocess.get_transform(augment=False) - left_img = processed(left_img) - right_img = processed(right_img) + processed = preprocess.get_transform(augment=False) + left_img = processed(left_img) + right_img = processed(right_img) - return left_img, right_img, dataL + return left_img, right_img, dataL + else: + processed = preprocess.get_transform(augment=False) + left_img = processed(left_img) + right_img = processed(right_img) + return left_img, right_img, dataL def __len__(self): return len(self.left) diff --git a/dataloader/listflowfile.py b/dataloader/listflowfile.py index d16556e..fa606bb 100644 --- a/dataloader/listflowfile.py +++ b/dataloader/listflowfile.py @@ -15,93 +15,94 @@ def is_image_file(filename): def dataloader(filepath): - classes = [d for d in os.listdir(filepath) if os.path.isdir(os.path.join(filepath, d))] - image = [img for img in classes if img.find('frames_cleanpass') > -1] - disp = [dsp for dsp in classes if dsp.find('disparity') > -1] + classes = [d for d in os.listdir(filepath) if os.path.isdir(os.path.join(filepath, d))] + image = [img for img in classes if img.find('frames_cleanpass') > -1] + disp = [dsp for dsp in classes if dsp.find('disparity') > -1] - monkaa_path = filepath + [x for x in image if 'monkaa' in x][0] - monkaa_disp = filepath + [x for x in disp if 'monkaa' in x][0] + monkaa_path = filepath + [x for x in image if 'monkaa' in x][0] + monkaa_disp = filepath + [x for x in disp if 'monkaa' in x][0] - - monkaa_dir = os.listdir(monkaa_path) - all_left_img=[] - all_right_img=[] - all_left_disp = [] - test_left_img=[] - test_right_img=[] - test_left_disp = [] + monkaa_dir = os.listdir(monkaa_path) + all_left_img=[] + all_right_img=[] + all_left_disp = [] + test_left_img=[] + test_right_img=[] + test_left_disp = [] - for dd in monkaa_dir: - for im in os.listdir(monkaa_path+'/'+dd+'/left/'): - if is_image_file(monkaa_path+'/'+dd+'/left/'+im): - all_left_img.append(monkaa_path+'/'+dd+'/left/'+im) - all_left_disp.append(monkaa_disp+'/'+dd+'/left/'+im.split(".")[0]+'.pfm') - for im in os.listdir(monkaa_path+'/'+dd+'/right/'): - if is_image_file(monkaa_path+'/'+dd+'/right/'+im): - all_right_img.append(monkaa_path+'/'+dd+'/right/'+im) + for dd in monkaa_dir: + for im in 
os.listdir(monkaa_path+'/'+dd+'/left/'): + if is_image_file(monkaa_path+'/'+dd+'/left/'+im): + all_left_img.append(monkaa_path+'/'+dd+'/left/'+im) + all_left_disp.append(monkaa_disp+'/'+dd+'/left/'+im.split(".")[0]+'.pfm') - flying_path = filepath + [x for x in image if x == 'frames_cleanpass'][0] - flying_disp = filepath + [x for x in disp if x == 'frames_disparity'][0] - flying_dir = flying_path+'/TRAIN/' - subdir = ['A','B','C'] + for im in os.listdir(monkaa_path+'/'+dd+'/right/'): + if is_image_file(monkaa_path+'/'+dd+'/right/'+im): + all_right_img.append(monkaa_path+'/'+dd+'/right/'+im) - for ss in subdir: - flying = os.listdir(flying_dir+ss) + flying_path = filepath + [x for x in image if x == 'frames_cleanpass'][0] + flying_disp = filepath + [x for x in disp if x == 'frames_disparity'][0] + flying_dir = flying_path+'/TRAIN/' + subdir = ['A','B','C'] - for ff in flying: - imm_l = os.listdir(flying_dir+ss+'/'+ff+'/left/') - for im in imm_l: - if is_image_file(flying_dir+ss+'/'+ff+'/left/'+im): - all_left_img.append(flying_dir+ss+'/'+ff+'/left/'+im) + for ss in subdir: + flying = os.listdir(flying_dir+ss) - all_left_disp.append(flying_disp+'/TRAIN/'+ss+'/'+ff+'/left/'+im.split(".")[0]+'.pfm') + for ff in flying: + imm_l = os.listdir(flying_dir+ss+'/'+ff+'/left/') + for im in imm_l: + if is_image_file(flying_dir+ss+'/'+ff+'/left/'+im): + all_left_img.append(flying_dir+ss+'/'+ff+'/left/'+im) + + all_left_disp.append(flying_disp+'/TRAIN/'+ss+'/'+ff+'/left/'+im.split(".")[0]+'.pfm') - if is_image_file(flying_dir+ss+'/'+ff+'/right/'+im): - all_right_img.append(flying_dir+ss+'/'+ff+'/right/'+im) + if is_image_file(flying_dir+ss+'/'+ff+'/right/'+im): + all_right_img.append(flying_dir+ss+'/'+ff+'/right/'+im) - flying_dir = flying_path+'/TEST/' + flying_dir = flying_path+'/TEST/' - subdir = ['A','B','C'] + subdir = ['A','B','C'] - for ss in subdir: - flying = os.listdir(flying_dir+ss) + for ss in subdir: + flying = os.listdir(flying_dir+ss) - for ff in flying: - imm_l = os.listdir(flying_dir+ss+'/'+ff+'/left/') - for im in imm_l: - if is_image_file(flying_dir+ss+'/'+ff+'/left/'+im): - test_left_img.append(flying_dir+ss+'/'+ff+'/left/'+im) + for ff in flying: + imm_l = os.listdir(flying_dir+ss+'/'+ff+'/left/') + for im in imm_l: + if is_image_file(flying_dir+ss+'/'+ff+'/left/'+im): + test_left_img.append(flying_dir+ss+'/'+ff+'/left/'+im) + + test_left_disp.append(flying_disp+'/TEST/'+ss+'/'+ff+'/left/'+im.split(".")[0]+'.pfm') - test_left_disp.append(flying_disp+'/TEST/'+ss+'/'+ff+'/left/'+im.split(".")[0]+'.pfm') + if is_image_file(flying_dir+ss+'/'+ff+'/right/'+im): + test_right_img.append(flying_dir+ss+'/'+ff+'/right/'+im) - if is_image_file(flying_dir+ss+'/'+ff+'/right/'+im): - test_right_img.append(flying_dir+ss+'/'+ff+'/right/'+im) + driving_dir = filepath + [x for x in image if 'driving' in x][0] + '/' + driving_disp = filepath + [x for x in disp if 'driving' in x][0] - driving_dir = filepath + [x for x in image if 'driving' in x][0] + '/' - driving_disp = filepath + [x for x in disp if 'driving' in x][0] + subdir1 = ['35mm_focallength','15mm_focallength'] + subdir2 = ['scene_backwards','scene_forwards'] + subdir3 = ['fast','slow'] - subdir1 = ['35mm_focallength','15mm_focallength'] - subdir2 = ['scene_backwards','scene_forwards'] - subdir3 = ['fast','slow'] + for i in subdir1: + for j in subdir2: + for k in subdir3: + imm_l = os.listdir(driving_dir+i+'/'+j+'/'+k+'/left/') + for im in imm_l: + if is_image_file(driving_dir+i+'/'+j+'/'+k+'/left/'+im): + 
all_left_img.append(driving_dir+i+'/'+j+'/'+k+'/left/'+im)
-    for i in subdir1:
-        for j in subdir2:
-            for k in subdir3:
-                imm_l = os.listdir(driving_dir+i+'/'+j+'/'+k+'/left/')
-                for im in imm_l:
-                    if is_image_file(driving_dir+i+'/'+j+'/'+k+'/left/'+im):
-                        all_left_img.append(driving_dir+i+'/'+j+'/'+k+'/left/'+im)
-                        all_left_disp.append(driving_disp+'/'+i+'/'+j+'/'+k+'/left/'+im.split(".")[0]+'.pfm')
+                        all_left_disp.append(driving_disp+'/'+i+'/'+j+'/'+k+'/left/'+im.split(".")[0]+'.pfm')
 
-                    if is_image_file(driving_dir+i+'/'+j+'/'+k+'/right/'+im):
-                        all_right_img.append(driving_dir+i+'/'+j+'/'+k+'/right/'+im)
+                    if is_image_file(driving_dir+i+'/'+j+'/'+k+'/right/'+im):
+                        all_right_img.append(driving_dir+i+'/'+j+'/'+k+'/right/'+im)
 
-    return all_left_img, all_right_img, all_left_disp, test_left_img, test_right_img, test_left_disp
+    return all_left_img, all_right_img, all_left_disp, test_left_img, test_right_img, test_left_disp
diff --git a/dataloader/readpfm.py b/dataloader/readpfm.py
index c4b1536..6223627 100644
--- a/dataloader/readpfm.py
+++ b/dataloader/readpfm.py
@@ -1,7 +1,7 @@
 import re
 import numpy as np
 import sys
-
+import chardet
 
 def readPFM(file):
     file = open(file, 'rb')
@@ -13,6 +13,8 @@ def readPFM(file):
     endian = None
 
     header = file.readline().rstrip()
+    encode_type = chardet.detect(header)
+    header = header.decode(encode_type['encoding'])
     if header == 'PF':
         color = True
     elif header == 'Pf':
@@ -20,13 +22,13 @@ def readPFM(file):
     else:
         raise Exception('Not a PFM file.')
 
-    dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline())
+    dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode(encode_type['encoding']))
     if dim_match:
         width, height = map(int, dim_match.groups())
     else:
         raise Exception('Malformed PFM header.')
 
-    scale = float(file.readline().rstrip())
+    scale = float(file.readline().rstrip().decode(encode_type['encoding']))
     if scale < 0: # little-endian
         endian = '<'
         scale = -scale
@@ -40,3 +42,4 @@ def readPFM(file):
     data = np.flipud(data)
 
     return data, scale
+
diff --git a/dataset/Readme.md b/dataset/Readme.md
new file mode 100644
index 0000000..b0bcf10
--- /dev/null
+++ b/dataset/Readme.md
@@ -0,0 +1,22 @@
+* SceneFlow includes three datasets: FlyingThings3D, Driving, and Monkaa.
+* You can train PSMNet with some of the three datasets, or with all of them.
+* The following is a description of the six subfolders.
+```
+# the disparity folder of the Driving dataset
+driving_disparity
+# the image folder of the Driving dataset
+driving_frames_cleanpass
+
+# the image folder of the FlyingThings3D dataset
+frames_cleanpass
+# the disparity folder of the FlyingThings3D dataset
+frames_disparity
+
+# the disparity folder of the Monkaa dataset
+monkaa_disparity
+# the image folder of the Monkaa dataset
+monkaa_frames_cleanpass
+```
+* Download the datasets from [this page](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html) and unzip them into the corresponding folders.
+
+* `data_scene_flow_2015` is the folder for KITTI 2015. You can unzip KITTI 2015 into this folder; it is used in the **test** phase.
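+
+* A minimal sketch for sanity-checking the layout before training (assuming the `dataset/` root used in run.sh):
+```
+import os
+
+expected = ['driving_frames_cleanpass', 'driving_disparity',
+            'monkaa_frames_cleanpass', 'monkaa_disparity',
+            'frames_cleanpass', 'frames_disparity']
+missing = [d for d in expected if not os.path.isdir(os.path.join('dataset/', d))]
+print('missing folders:', missing if missing else 'none')
+```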
diff --git a/finetune.py b/finetune.py index c944c9a..27a76f9 100644 --- a/finetune.py +++ b/finetune.py @@ -16,6 +16,7 @@ import numpy as np import time import math +import copy from dataloader import KITTIloader2015 as ls from dataloader import KITTILoader as DA @@ -125,7 +126,7 @@ def test(imgL,imgR,disp_true): pred_disp = output3.data.cpu() #computing 3-px error# - true_disp = disp_true + true_disp = copy.deepcopy(disp_true) index = np.argwhere(true_disp>0) disp_true[index[0][:], index[1][:], index[2][:]] = np.abs(true_disp[index[0][:], index[1][:], index[2][:]]-pred_disp[index[0][:], index[1][:], index[2][:]]) correct = (disp_true[index[0][:], index[1][:], index[2][:]] < 3)|(disp_true[index[0][:], index[1][:], index[2][:]] < true_disp[index[0][:], index[1][:], index[2][:]]*0.05) diff --git a/main.py b/main.py index e22b15d..e3f6c7e 100644 --- a/main.py +++ b/main.py @@ -4,10 +4,7 @@ import random import torch import torch.nn as nn -import torch.nn.parallel -import torch.backends.cudnn as cudnn import torch.optim as optim -import torch.utils.data from torch.autograd import Variable import torch.nn.functional as F import numpy as np @@ -64,8 +61,9 @@ model.cuda() if args.loadmodel is not None: - state_dict = torch.load(args.loadmodel) - model.load_state_dict(state_dict['state_dict']) + print('Load pretrained model') + pretrain_dict = torch.load(args.loadmodel) + model.load_state_dict(pretrain_dict['state_dict']) print('Number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()]))) @@ -73,9 +71,6 @@ def train(imgL,imgR, disp_L): model.train() - imgL = Variable(torch.FloatTensor(imgL)) - imgR = Variable(torch.FloatTensor(imgR)) - disp_L = Variable(torch.FloatTensor(disp_L)) if args.cuda: imgL, imgR, disp_true = imgL.cuda(), imgR.cuda(), disp_L.cuda() @@ -100,30 +95,48 @@ def train(imgL,imgR, disp_L): loss.backward() optimizer.step() - return loss.data[0] + return loss.data def test(imgL,imgR,disp_true): + model.eval() - imgL = Variable(torch.FloatTensor(imgL)) - imgR = Variable(torch.FloatTensor(imgR)) + if args.cuda: - imgL, imgR = imgL.cuda(), imgR.cuda() - + imgL, imgR, disp_true = imgL.cuda(), imgR.cuda(), disp_true.cuda() #--------- mask = disp_true < 192 #---- + if imgL.shape[2] % 16 != 0: + times = imgL.shape[2]//16 + top_pad = (times+1)*16 -imgL.shape[2] + else: + top_pad = 0 + + if imgL.shape[3] % 16 != 0: + times = imgL.shape[3]//16 + right_pad = (times+1)*16-imgL.shape[3] + else: + right_pad = 0 + + imgL = F.pad(imgL,(0,right_pad, top_pad,0)) + imgR = F.pad(imgR,(0,right_pad, top_pad,0)) + with torch.no_grad(): output3 = model(imgL,imgR) - - output = torch.squeeze(output3.data.cpu(),1)[:,4:,:] + output3 = torch.squeeze(output3) + + if top_pad !=0: + img = output3[:,top_pad:,:] + else: + img = output3 if len(disp_true[mask])==0: loss = 0 else: - loss = torch.mean(torch.abs(output[mask]-disp_true[mask])) # end-point-error + loss = F.l1_loss(img[mask],disp_true[mask]) #torch.mean(torch.abs(img[mask]-disp_true[mask])) # end-point-error - return loss + return loss.data.cpu() def adjust_learning_rate(optimizer, epoch): lr = 0.001 @@ -135,7 +148,7 @@ def adjust_learning_rate(optimizer, epoch): def main(): start_full_time = time.time() - for epoch in range(1, args.epochs+1): + for epoch in range(0, args.epochs): print('This is %d-th epoch' %(epoch)) total_train_loss = 0 adjust_learning_rate(optimizer,epoch) diff --git a/models/basic.py b/models/basic.py index 5ee5a59..bbaefd5 100644 --- a/models/basic.py +++ b/models/basic.py @@ -5,7 +5,7 @@ from 
torch.autograd import Variable
 import torch.nn.functional as F
 import math
-from submodule import *
+from .submodule import *
 
 class PSMNet(nn.Module):
     def __init__(self, maxdisp):
diff --git a/models/stackhourglass.py b/models/stackhourglass.py
index 8430242..48238de 100644
--- a/models/stackhourglass.py
+++ b/models/stackhourglass.py
@@ -5,7 +5,7 @@ from torch.autograd import Variable
 import torch.nn.functional as F
 import math
-from submodule import *
+from .submodule import *
 
 class hourglass(nn.Module):
     def __init__(self, inplanes):
@@ -107,9 +107,9 @@ def forward(self, left, right):
 
         #matching
-        cost = Variable(torch.FloatTensor(refimg_fea.size()[0], refimg_fea.size()[1]*2, self.maxdisp/4, refimg_fea.size()[2], refimg_fea.size()[3]).zero_()).cuda()
+        cost = Variable(torch.FloatTensor(refimg_fea.size()[0], refimg_fea.size()[1]*2, self.maxdisp//4, refimg_fea.size()[2], refimg_fea.size()[3]).zero_()).cuda()
 
-        for i in range(self.maxdisp/4):
+        for i in range(self.maxdisp//4):
             if i > 0 :
                 cost[:, :refimg_fea.size()[1], i, :,i:] = refimg_fea[:,:,:,i:]
                 cost[:, refimg_fea.size()[1]:, i, :,i:] = targetimg_fea[:,:,:,:-i]
@@ -135,20 +135,23 @@ def forward(self, left, right):
         cost3 = self.classif3(out3) + cost2
 
         if self.training:
-            cost1 = F.upsample(cost1, [self.maxdisp,left.size()[2],left.size()[3]], mode='trilinear')
-            cost2 = F.upsample(cost2, [self.maxdisp,left.size()[2],left.size()[3]], mode='trilinear')
+            cost1 = F.upsample(cost1, [self.maxdisp,left.size()[2],left.size()[3]], mode='trilinear')
+            cost2 = F.upsample(cost2, [self.maxdisp,left.size()[2],left.size()[3]], mode='trilinear')
 
-            cost1 = torch.squeeze(cost1,1)
-            pred1 = F.softmax(cost1,dim=1)
-            pred1 = disparityregression(self.maxdisp)(pred1)
+            cost1 = torch.squeeze(cost1,1)
+            pred1 = F.softmax(cost1,dim=1)
+            pred1 = disparityregression(self.maxdisp)(pred1)
 
-            cost2 = torch.squeeze(cost2,1)
-            pred2 = F.softmax(cost2,dim=1)
-            pred2 = disparityregression(self.maxdisp)(pred2)
+            cost2 = torch.squeeze(cost2,1)
+            pred2 = F.softmax(cost2,dim=1)
+            pred2 = disparityregression(self.maxdisp)(pred2)
 
         cost3 = F.upsample(cost3, [self.maxdisp,left.size()[2],left.size()[3]], mode='trilinear')
         cost3 = torch.squeeze(cost3,1)
         pred3 = F.softmax(cost3,dim=1)
+        # For your information: this formulation 'softmax(c)' learns "similarity",
+        # while 'softmax(-c)' learns "matching cost", as mentioned in the paper.
+        # However, 'c' or '-c' does not affect performance, because the
+        # feature-based cost volume provides this flexibility.
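+        # disparityregression below is a soft-argmin readout: it takes the
+        # expectation over the softmax weights, pred = sum_d d * softmax(cost)_d,
+        # which yields sub-pixel disparity estimates.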
        pred3 = disparityregression(self.maxdisp)(pred3)
 
         if self.training:
diff --git a/models/submodule.py b/models/submodule.py
index a7c8d57..2953bfd 100644
--- a/models/submodule.py
+++ b/models/submodule.py
@@ -42,25 +42,13 @@ def forward(self, x):
 
         return out
 
-class matchshifted(nn.Module):
-    def __init__(self):
-        super(matchshifted, self).__init__()
-
-    def forward(self, left, right, shift):
-        batch, filters, height, width = left.size()
-        shifted_left = F.pad(torch.index_select(left, 3, Variable(torch.LongTensor([i for i in range(shift,width)])).cuda()),(shift,0,0,0))
-        shifted_right = F.pad(torch.index_select(right, 3, Variable(torch.LongTensor([i for i in range(width-shift)])).cuda()),(shift,0,0,0))
-        out = torch.cat((shifted_left,shifted_right),1).view(batch,filters*2,1,height,width)
-        return out
-
 class disparityregression(nn.Module):
     def __init__(self, maxdisp):
         super(disparityregression, self).__init__()
-        self.disp = Variable(torch.Tensor(np.reshape(np.array(range(maxdisp)),[1,maxdisp,1,1])).cuda(), requires_grad=False)
+        self.disp = torch.Tensor(np.reshape(np.array(range(maxdisp)),[1, maxdisp,1,1])).cuda()
 
     def forward(self, x):
-        disp = self.disp.repeat(x.size()[0],1,x.size()[2],x.size()[3])
-        out = torch.sum(x*disp,1)
+        out = torch.sum(x*self.disp.data,1, keepdim=True)
         return out
 
 class feature_extraction(nn.Module):
diff --git a/run.sh b/run.sh
index ce059ea..847c84d 100644
--- a/run.sh
+++ b/run.sh
@@ -2,7 +2,7 @@
 
 python main.py --maxdisp 192 \
                --model stackhourglass \
-               --datapath /media/jiaren/ImageNet/SceneFlowData/ \
+               --datapath dataset/ \
                --epochs 0 \
                --loadmodel ./trained/checkpoint_10.tar \
                --savemodel ./trained/
@@ -12,7 +12,7 @@
 python finetune.py --maxdisp 192 \
                    --model stackhourglass \
                    --datatype 2015 \
-                   --datapath /media/jiaren/ImageNet/data_scene_flow_2015/training/ \
+                   --datapath dataset/data_scene_flow_2015/training/ \
                    --epochs 300 \
                    --loadmodel ./trained/checkpoint_10.tar \
                    --savemodel ./trained/
diff --git a/submission.py b/submission.py
index ab74207..e80ab02 100644
--- a/submission.py
+++ b/submission.py
@@ -4,33 +4,24 @@ import random
 import torch
 import torch.nn as nn
-import torch.nn.parallel
-import torch.backends.cudnn as cudnn
-import torch.optim as optim
-import torch.utils.data
-from torch.autograd import Variable
+import torchvision.transforms as transforms
 import torch.nn.functional as F
-import skimage
-import skimage.io
-import skimage.transform
 import numpy as np
 import time
 import math
-from utils import preprocess
 from models import *
-
-# 2012 data /media/jiaren/ImageNet/data_scene_flow_2012/testing/
+from PIL import Image
 
 parser = argparse.ArgumentParser(description='PSMNet')
 parser.add_argument('--KITTI', default='2015',
                     help='KITTI version')
 parser.add_argument('--datapath', default='/media/jiaren/ImageNet/data_scene_flow_2015/testing/',
                     help='data path')
-parser.add_argument('--loadmodel', default=None,
+parser.add_argument('--loadmodel', default='./trained/pretrained_model_KITTI2015.tar',
                     help='loading model')
 parser.add_argument('--model', default='stackhourglass',
                     help='select model')
 parser.add_argument('--maxdisp', type=int, default=192,
                     help='maximum disparity')
 parser.add_argument('--no-cuda', action='store_true', default=False,
                     help='disables CUDA')
@@ -48,7 +39,6 @@
 else:
     from dataloader import KITTI_submission_loader2012 as DA
 
-
 test_left_img, test_right_img = DA.dataloader(args.datapath)
 
 if args.model == 'stackhourglass':
@@ -68,51 +58,63 @@
 print('Number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()])))
 
 def test(imgL,imgR):
-    model.eval()
+    model.eval()
 
-    if args.cuda:
-        imgL = torch.FloatTensor(imgL).cuda()
-        imgR = torch.FloatTensor(imgR).cuda()
-
-    imgL, imgR = Variable(imgL), Variable(imgR)
+    if args.cuda:
+        imgL = imgL.cuda()
+        imgR = imgR.cuda()
 
     with torch.no_grad():
-        output = model(imgL,imgR)
-    output = torch.squeeze(output)
-    pred_disp = output.data.cpu().numpy()
-
-    return pred_disp
+        output = model(imgL,imgR)
+    output = torch.squeeze(output).data.cpu().numpy()
+    return output
 
 def main():
-    processed = preprocess.get_transform(augment=False)
+    normal_mean_var = {'mean': [0.485, 0.456, 0.406],
+                       'std': [0.229, 0.224, 0.225]}
+    infer_transform = transforms.Compose([transforms.ToTensor(),
+                                          transforms.Normalize(**normal_mean_var)])
 
     for inx in range(len(test_left_img)):
 
-        imgL_o = (skimage.io.imread(test_left_img[inx]).astype('float32'))
-        imgR_o = (skimage.io.imread(test_right_img[inx]).astype('float32'))
-        imgL = processed(imgL_o).numpy()
-        imgR = processed(imgR_o).numpy()
-        imgL = np.reshape(imgL,[1,3,imgL.shape[1],imgL.shape[2]])
-        imgR = np.reshape(imgR,[1,3,imgR.shape[1],imgR.shape[2]])
+        imgL_o = Image.open(test_left_img[inx]).convert('RGB')
+        imgR_o = Image.open(test_right_img[inx]).convert('RGB')
+        imgL = infer_transform(imgL_o)
+        imgR = infer_transform(imgR_o)
 
-        # pad to (384, 1248)
-        top_pad = 384-imgL.shape[2]
-        left_pad = 1248-imgL.shape[3]
-        imgL = np.lib.pad(imgL,((0,0),(0,0),(top_pad,0),(0,left_pad)),mode='constant',constant_values=0)
-        imgR = np.lib.pad(imgR,((0,0),(0,0),(top_pad,0),(0,left_pad)),mode='constant',constant_values=0)
+        # pad height and width up to multiples of 16
+        if imgL.shape[1] % 16 != 0:
+            times = imgL.shape[1]//16
+            top_pad = (times+1)*16 - imgL.shape[1]
+        else:
+            top_pad = 0
+
+        if imgL.shape[2] % 16 != 0:
+            times = imgL.shape[2]//16
+            right_pad = (times+1)*16 - imgL.shape[2]
+        else:
+            right_pad = 0
+
+        imgL = F.pad(imgL,(0,right_pad, top_pad,0)).unsqueeze(0)
+        imgR = F.pad(imgR,(0,right_pad, top_pad,0)).unsqueeze(0)
 
         start_time = time.time()
         pred_disp = test(imgL,imgR)
         print('time = %.2f' %(time.time() - start_time))
 
-        top_pad = 384-imgL_o.shape[0]
-        left_pad = 1248-imgL_o.shape[1]
-        img = pred_disp[top_pad:,:-left_pad]
-        skimage.io.imsave(test_left_img[inx].split('/')[-1],(img*256).astype('uint16'))
+        if top_pad != 0 and right_pad != 0:
+            img = pred_disp[top_pad:,:-right_pad]
+        elif top_pad == 0 and right_pad != 0:
+            img = pred_disp[:,:-right_pad]
+        elif top_pad != 0 and right_pad == 0:
+            img = pred_disp[top_pad:,:]
+        else:
+            img = pred_disp
+
+        img = (img*256).astype('uint16')
+        img = Image.fromarray(img)
+        img.save(test_left_img[inx].split('/')[-1])
 
 if __name__ == '__main__':
-   main()
+    main()
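+
+# Example usage (paths are illustrative; adjust to your setup):
+#   python submission.py --KITTI 2015 \
+#       --datapath dataset/data_scene_flow_2015/testing/ \
+#       --loadmodel ./trained/pretrained_model_KITTI2015.tar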