Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 80bb49c

Browse files
author
YoYo000HKUST
committed
1. Training with BlendedMVS & ETH3D. 2. Online photometric augmentation. 3. Set tf verbosity level to ERROR
1 parent 771ff63 commit 80bb49c

File tree

7 files changed

+224
-36
lines changed

7 files changed

+224
-36
lines changed

cnn_wrapper/mvsnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ class RegNetUS0(Network):
113113
"""network for regularizing 3D cost volume in a encoder-decoder style. Keeping original size."""
114114

115115
def setup(self):
116-
print ('3D with 8 filters')
116+
print ('Shallow 3D UNet with 8 channel input')
117117
base_filter = 8
118118
(self.feed('data')
119119
.conv_bn(3, base_filter * 2, 2, center=True, scale=True, name='3dconv1_0')
37.1 MB
Binary file not shown.

mvsnet/photometric_augmentation.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
import cv2 as cv
2+
import numpy as np
3+
import tensorflow as tf
4+
5+
# Names of the augmentation primitives applied by `online_augmentation`.
# Every active entry must be the name of a function defined in this module.
# The commented-out entries are primitives that are defined below but are
# currently switched off.
augmentations = [
    # 'additive_gaussian_noise',
    # 'additive_speckle_noise',
    'random_brightness',
    'random_contrast',
    # 'additive_shade',
    'motion_blur',
]
13+
14+
def additive_gaussian_noise(image, stddev_range=[5, 95]):
    """Add zero-mean Gaussian noise whose strength is drawn at random.

    Args:
        image: image tensor. The result is clipped to [0, 255], so values
            are presumably on a 0-255 scale (confirm against the caller).
        stddev_range: bounds forwarded to `tf.random_uniform` to draw the
            standard deviation of the noise.

    Returns:
        The noisy image, clipped to [0, 255].
    """
    sigma = tf.random_uniform((), *stddev_range)
    perturbation = tf.random_normal(tf.shape(image), stddev=sigma)
    return tf.clip_by_value(image + perturbation, 0, 255)
19+
20+
21+
def additive_speckle_noise(image, prob_range=[0.0, 0.005]):
    """Randomly force individual elements of `image` to 0 or to 255.

    A threshold `prob` is drawn uniformly from `prob_range`, and a uniform
    sample in [0, 1) is drawn for every element of `image`. Elements whose
    sample is <= `prob` become 0; elements whose sample is >= 1 - `prob`
    become 255; all others are left untouched.

    Args:
        image: image tensor.
        prob_range: bounds forwarded to `tf.random_uniform` to draw `prob`.

    Returns:
        Tensor with the same shape as `image`.
    """
    prob = tf.random_uniform((), *prob_range)
    sample = tf.random_uniform(tf.shape(image))
    black = tf.zeros_like(image)
    white = 255. * tf.ones_like(image)
    with_dark_specks = tf.where(sample <= prob, black, image)
    return tf.where(sample >= (1. - prob), white, with_dark_specks)
27+
28+
29+
def random_brightness(image, max_abs_change=50):
    """Apply `tf.image.random_brightness` and clip the result to [0, 255].

    Args:
        image: image tensor.
        max_abs_change: passed to `tf.image.random_brightness` as its
            maximum delta.

    Returns:
        The brightness-shifted image, clipped to [0, 255].
    """
    shifted = tf.image.random_brightness(image, max_abs_change)
    return tf.clip_by_value(shifted, 0, 255)
31+
32+
33+
def random_contrast(image, strength_range=[0.5, 1.5]):
    """Apply `tf.image.random_contrast` and clip the result to [0, 255].

    Args:
        image: image tensor.
        strength_range: bounds forwarded positionally to
            `tf.image.random_contrast` (lower, upper contrast factor).

    Returns:
        The contrast-adjusted image, clipped to [0, 255].
    """
    adjusted = tf.image.random_contrast(image, *strength_range)
    return tf.clip_by_value(adjusted, 0, 255)
35+
36+
37+
def additive_shade(image, nb_ellipses=20, transparency_range=[-0.5, 0.8],
                   kernel_size_range=[250, 350]):
    """Darken or brighten random blurred elliptical regions of `image`.

    A mask of `nb_ellipses` random filled ellipses is drawn with OpenCV,
    smoothed with a Gaussian blur, and used to scale the image by
    `1 - transparency * mask / 255`. A negative transparency therefore
    brightens the masked regions, a positive one darkens them.

    Args:
        image: image tensor; the NumPy side indexes `shape[0]` and `shape[1]`
            as height and width and adds a trailing axis to the mask, so an
            (H, W, C) layout is presumably expected (confirm against caller).
        nb_ellipses: number of ellipses drawn into the mask.
        transparency_range: bounds for the uniformly drawn transparency.
        kernel_size_range: bounds passed to `np.random.randint` for the
            Gaussian blur kernel size (made odd if even).

    Returns:
        float32 tensor reshaped to the shape of `image`, with values clipped
        to [0, 255].
    """

    def _py_additive_shade(img):
        height, width = img.shape[:2]
        min_dim = min(height, width) / 4
        mask = np.zeros((height, width), np.uint8)
        for _ in range(nb_ellipses):
            ax = int(max(np.random.rand() * min_dim, min_dim / 5))
            ay = int(max(np.random.rand() * min_dim, min_dim / 5))
            max_rad = max(ax, ay)
            # Keep the centre at least `max_rad` away from the borders.
            x = np.random.randint(max_rad, width - max_rad)
            y = np.random.randint(max_rad, height - max_rad)
            angle = np.random.rand() * 90
            cv.ellipse(mask, (x, y), (ax, ay), angle, 0, 360, 255, -1)

        transparency = np.random.uniform(*transparency_range)
        kernel_size = np.random.randint(*kernel_size_range)
        if (kernel_size % 2) == 0:  # kernel_size has to be odd
            kernel_size += 1
        mask = cv.GaussianBlur(mask.astype(np.float32), (kernel_size, kernel_size), 0)
        shaded = img * (1 - transparency * mask[..., np.newaxis] / 255.)
        # The TF wrapper below declares a float32 output; enforce it here so
        # a non-float32 input does not fail with a dtype mismatch at run time.
        return np.clip(shaded, 0, 255).astype(np.float32)

    # tf.numpy_function (rather than the deprecated tf.py_func) matches what
    # motion_blur in this module already uses.
    shaded = tf.numpy_function(_py_additive_shade, [image], tf.float32)
    # The wrapped NumPy call loses shape information; restore it.
    return tf.reshape(shaded, tf.shape(image))
63+
64+
65+
def motion_blur(image, max_kernel_size=10):
    """Blur `image` along a random direction with a Gaussian-weighted line kernel.

    The direction is chosen uniformly among horizontal, vertical and the two
    diagonals. The kernel size is a random odd integer; a size of 1 leaves
    the image unchanged.

    Args:
        image: image tensor. The wrapped NumPy function is declared to return
            float32 and `cv.filter2D` keeps the input depth, so `image` is
            presumably float32 (confirm against the caller).
        max_kernel_size: upper bound used when drawing the kernel size.
            Values below 1 yield a 1x1 (identity) kernel.

    Returns:
        The blurred image, reshaped to the shape of `image`.
    """

    def _py_motion_blur(img):
        # Either vertical, horizontal or diagonal blur.
        mode = np.random.choice(['h', 'v', 'diag_down', 'diag_up'])
        # Floor division keeps the bound an integer on Python 3 instead of
        # relying on NumPy truncating a float; the clamp keeps the randint
        # range non-empty when max_kernel_size < 1.
        half_steps = max((max_kernel_size + 1) // 2, 1)
        ksize = np.random.randint(0, half_steps) * 2 + 1  # always odd
        center = (ksize - 1) // 2

        if mode == 'h':
            kernel = np.zeros((ksize, ksize))
            kernel[center, :] = 1.
        elif mode == 'v':
            kernel = np.zeros((ksize, ksize))
            kernel[:, center] = 1.
        elif mode == 'diag_down':
            kernel = np.eye(ksize)
        else:  # 'diag_up'
            kernel = np.flip(np.eye(ksize), 0)

        # Weight the line by an isotropic Gaussian centred on the kernel so
        # that taps far from the centre contribute less.
        var = ksize * ksize / 16.
        grid = np.repeat(np.arange(ksize)[:, np.newaxis], ksize, axis=-1)
        gaussian = np.exp(-(np.square(grid - center) + np.square(grid.T - center)) / (2. * var))
        kernel *= gaussian
        kernel /= np.sum(kernel)  # normalise so the kernel sums to 1
        return cv.filter2D(img, -1, kernel)

    blurred = tf.numpy_function(_py_motion_blur, [image], tf.float32)
    # The wrapped NumPy call loses shape information; restore it.
    return tf.reshape(blurred, tf.shape(image))
91+
92+
def online_augmentation(image, random_order=True):
    """Apply every primitive listed in `augmentations` to `image` once.

    The primitives are looked up by name among this module's functions and
    called with the per-primitive keyword arguments configured below. They
    are applied one after another inside a `tf.while_loop`; with
    `random_order=True` the order is shuffled by `tf.random.shuffle`.

    Args:
        image: image tensor handed to the first primitive.
        random_order: if True, shuffle the order in which the primitives
            are applied; otherwise apply them in list order.

    Returns:
        The augmented image tensor (or `image` itself when no primitive is
        enabled).

    Raises:
        KeyError: if `augmentations` names a primitive that is not defined
            in this module.
    """
    primitives = augmentations
    config = {
        'random_brightness': {'max_abs_change': 50},
        'random_contrast': {'strength_range': [0.3, 1.5]},
        'additive_gaussian_noise': {'stddev_range': [0, 10]},
        'additive_speckle_noise': {'prob_range': [0, 0.0035]},
        'additive_shade': {'transparency_range': [-0.5, 0.5],
                           'kernel_size_range': [100, 150]},
        'motion_blur': {'max_kernel_size': 3},
    }

    # Nothing enabled: tf.case cannot be built from an empty branch list.
    if not primitives:
        return image

    # The primitives are defined in this very module, so resolve them from
    # the module namespace (the previous `photaug` alias was never defined
    # here and raised NameError).
    module_scope = globals()
    prim_fns = [module_scope[name] for name in primitives]
    prim_configs = [config.get(name, {}) for name in primitives]

    with tf.name_scope('online_augmentation'):
        indices = tf.range(len(primitives))
        if random_order:
            indices = tf.random.shuffle(indices)

        def step(i, img):
            # One branch per primitive; the branch whose position equals
            # indices[i] is the one executed on this iteration. `fn`/`cfg`
            # are bound as defaults to avoid late-binding closures.
            fn_pairs = [(tf.equal(indices[i], j), lambda fn=fn, cfg=cfg: fn(img, **cfg))
                        for j, (fn, cfg) in enumerate(zip(prim_fns, prim_configs))]
            augmented = tf.case(fn_pairs)
            return i + 1, augmented

        # parallel_iterations=1 keeps the primitives strictly sequential.
        _, aug_image = tf.while_loop(lambda i, img: tf.less(i, len(primitives)),
                                     step, [0, image], parallel_iterations=1)

    return aug_image

mvsnet/preprocess.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,7 @@ def gen_dtu_mvs_path(dtu_data_folder, mode='training'):
404404

405405
return sample_list
406406

407-
def gen_blended_mvs_path(blendedmvs_data_folder, mode='training'):
407+
def gen_blendedmvs_path(blendedmvs_data_folder, mode='training'):
408408
""" generate data paths for blendedmvs dataset """
409409

410410
# read data list

mvsnet/test.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,21 @@
1515

1616
import cv2
1717
import tensorflow as tf
18+
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
1819

1920
sys.path.append("../")
2021
from tools.common import Notify
2122
from preprocess import *
2223
from model import *
2324
from loss import *
2425

25-
# dataset parameters
26+
# input path
2627
tf.app.flags.DEFINE_string('dense_folder', None,
2728
"""Root path to dense folder.""")
28-
tf.app.flags.DEFINE_string('model_dir',
29-
'/data/tf_model',
29+
tf.app.flags.DEFINE_string('pretrained_model_ckpt_path',
30+
'/data/tf_model/3DCNNs/BlendedMVS/blended_augmented/model.ckpt',
3031
"""Path to restore the model.""")
31-
tf.app.flags.DEFINE_integer('ckpt_step', 100000,
32+
tf.app.flags.DEFINE_integer('ckpt_step', 150000,
3233
"""ckpt step.""")
3334

3435
# input parameters
@@ -146,7 +147,7 @@ def __iter__(self):
146147
def mvsnet_pipeline(mvs_list):
147148

148149
""" mvsnet in altizure pipeline """
149-
print ('sample number: ', len(mvs_list))
150+
print ('Testing sample number: ', len(mvs_list))
150151

151152
# create output folder
152153
output_folder = os.path.join(FLAGS.dense_folder, 'depths_mvsnet')
@@ -213,12 +214,12 @@ def mvsnet_pipeline(mvs_list):
213214
total_step = 0
214215

215216
# load model
216-
if FLAGS.model_dir is not None:
217-
pretrained_model_ckpt_path = os.path.join(FLAGS.model_dir, FLAGS.regularization, 'model.ckpt')
217+
if FLAGS.pretrained_model_ckpt_path is not None:
218218
restorer = tf.train.Saver(tf.global_variables())
219-
restorer.restore(sess, '-'.join([pretrained_model_ckpt_path, str(FLAGS.ckpt_step)]))
219+
restorer.restore(
220+
sess, '-'.join([FLAGS.pretrained_model_ckpt_path, str(FLAGS.ckpt_step)]))
220221
print(Notify.INFO, 'Pre-trained model restored from %s' %
221-
('-'.join([pretrained_model_ckpt_path, str(FLAGS.ckpt_step)])), Notify.ENDC)
222+
('-'.join([FLAGS.pretrained_model_ckpt_path, str(FLAGS.ckpt_step)])), Notify.ENDC)
222223
total_step = FLAGS.ckpt_step
223224

224225
# run inference for each reference view
@@ -270,4 +271,5 @@ def main(_): # pylint: disable=unused-argument
270271

271272

272273
if __name__ == '__main__':
274+
print ('Testing MVSNet with totally %d view inputs (including reference view)' % FLAGS.view_num)
273275
tf.app.run()

0 commit comments

Comments
 (0)