Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 690af57

Browse files
committed
Enable modular station selection
1 parent d35f034 commit 690af57

File tree

10 files changed

+426
-256
lines changed

10 files changed

+426
-256
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Neural networks for kidney segmentation in UK Biobank neck-to-knee body MRI
1+
# Neural networks for semantic segmentation of UK Biobank neck-to-knee body MRI
22

33
This repository contains PyTorch code for cross-validation and inference with neural networks for kidney segmentation on UK Biobank neck-to-knee body MRI, as described in:
44
[_"Kidney segmentation in neck-to-knee body MRI of 40,000 UK Biobank participants"_](https://arxiv.org/abs/2006.06996) [1]
@@ -8,7 +8,7 @@ The included inference pipeline and trained snapshot enables measurements of lef
88
Contents:
99
- 2.5D U-Net architecture with residual connections
1010
- Infrastructure for training and *cross-validation*
11-
- Pipeline for *inference* on MRI DICOMs
11+
- Pipeline for *inference* on neck-to-knee body MRI DICOMs
1212
- Code for *quality_controls* based on numerical metrics
1313
- Trained snapshot for parenchymal kidney tissue can be found at *TODO*
1414

cross_validation/README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,5 @@ Using the code samples in the *scripts* subfolder, a segmentation model can be t
66
3. Run *crossValidate.py* to train and evaluate a model, with the results stored to the directory "networks"
77

88
To re-train a network for inference using all data, the cross-validation split can simply be set to contain all images in one split set.
9+
10+
Note that the first run on new training data may be very slow, whereas subsequent runs will benefit from caching by the data loader.

cross_validation/scripts/createNewSplit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# Any ids that should be added only for training can be listed
66
# in a separate file at a later stage.
77

8-
split_name = "kidney_64_8fold"
8+
split_name = "liver_99_8fold"
99
split_path = "../splits/" + split_name + "/"
1010

1111
id_list = split_path + "id_list.txt"

cross_validation/scripts/createTrainingSlices.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,22 @@
2424
def main(argv):
2525

2626
#
27-
path_img = "/media/taro/DATA/Taro/UKBiobank/segmentations/kidney/combined_128/signals/NRRD/"
28-
path_seg = "/media/taro/DATA/Taro/UKBiobank/segmentations/kidney/combined_128/segmentations/NRRD/"
29-
path_ids = "/media/taro/DATA/Taro/UKBiobank/segmentations/kidney/combined_128/subject_ids.txt"
30-
output_path = "../image_data/kidney_128/"
27+
#path_img = "/media/taro/DATA/Taro/UKBiobank/segmentations/kidney/combined_128/signals/NRRD/"
28+
#path_seg = "/media/taro/DATA/Taro/UKBiobank/segmentations/kidney/combined_128/segmentations/NRRD/"
29+
#path_ids = "/media/taro/DATA/Taro/UKBiobank/segmentations/kidney/combined_128/subject_ids.txt"
30+
#output_path = "../image_data/kidney_128/"
31+
32+
#
33+
#path_img = "/media/taro/DATA/Taro/Projects/ukb_segmentation/github/temp_volumes/liver/signals/NRRD_3/"
34+
#path_seg = "/media/taro/DATA/Taro/Projects/ukb_segmentation/github/temp_volumes/liver/segmentations/NRRD_fixedHeaders/"
35+
#path_ids = "/media/taro/DATA/Taro/Projects/ukb_segmentation/github/temp_volumes/liver/ids.txt"
36+
#output_path = "../image_data/liver_allStations/"
37+
38+
#
39+
path_img = "/media/taro/DATA/Taro/UKBiobank/segmentations/liver/Andres_refined/signals/"
40+
path_seg = "/media/taro/DATA/Taro/UKBiobank/segmentations/liver/Andres_refined/segmentations/"
41+
path_ids = "/media/taro/DATA/Taro/UKBiobank/segmentations/liver/Andres_refined/ids_add.txt"
42+
output_path = "../image_data/liver_refined_99_add/"
3143

3244
#####
3345
createFolders(output_path, overwrite=True)
@@ -92,7 +104,7 @@ def convertSubject(subject_id, files_img, files_seg, path_img, path_seg, output_
92104
(slices_seg, shape_seg) = formatSeg(path_seg + file_s, shape_img)
93105

94106
if not np.array_equal(shape_img, shape_seg):
95-
print("ERROR: Mismatching dimensions for img and seg of {}".format(name))
107+
print("ERROR: Mismatching dimensions for img and seg of {} ({} vs {})".format(name, shape_img, shape_seg))
96108
sys.exit()
97109

98110
# For each axial slice, save outputs

cross_validation/scripts/crossValidate.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,11 @@
2121
import dataLoading
2222

2323

24+
# After the slice-wise training, the validation fuses all specified imaging stations
25+
# to a common space to calculate subject-wise Dice scores and other evaluation metrics
26+
c_target_spacing = np.array((2.23214293, 2.23214293, 4.5)) # abdominal spacing
27+
#c_target_spacing = np.array((2.23214293, 2.23214293, 3.)) # top station spacing
28+
2429
def main(argv):
2530

2631
path_network_out = "../networks/kidney_64_8fold/"
@@ -32,14 +37,18 @@ def main(argv):
3237
path_stations_img = "/media/taro/DATA/Taro/UKBiobank/segmentations/kidney/combined_128/signals/NRRD/"
3338
path_stations_gt = "/media/taro/DATA/Taro/UKBiobank/segmentations/kidney/combined_128/segmentations/NRRD_fixedHeaders/"
3439

35-
# Optional path to list of ids which are to be used as additional training samples on each split.
40+
# Select which MRI stations to use for training and evaluation
41+
station_ids = [0, 1, 2]
42+
43+
# Optional name of list file in split path with ids
44+
# which are to be used as additional training samples, in each split.
3645
# Set to None for conventional cross-validation
3746
path_train_ids_add = None
3847

39-
runExperiment(path_network_out, path_training_slices, path_split, path_stations_img, path_stations_gt, path_train_ids_add)
48+
runExperiment(path_network_out, path_training_slices, path_split, path_stations_img, path_stations_gt, path_train_ids_add, station_ids)
4049

4150

42-
def runExperiment(path_network_out, path_training_slices, path_split, path_stations_img, path_stations_gt, path_train_ids_add):
51+
def runExperiment(path_network_out, path_training_slices, path_split, path_stations_img, path_stations_gt, path_train_ids_add, station_ids):
4352

4453
I = 80000 # Training iterations
4554
save_step = 5000 # Iterations between checkpoint saving
@@ -50,7 +59,7 @@ def runExperiment(path_network_out, path_training_slices, path_split, path_stati
5059
class_count = 2 # Number of labels, including background
5160
class_weights = torch.FloatTensor([1, 1]) # Background, L1, L2...
5261

53-
start_k = 0 # First cross-validation set to validate against
62+
start_k = 0 # First cross-validation set to train and validate against
5463

5564
do_train = True
5665
do_predict = True
@@ -97,13 +106,13 @@ def runExperiment(path_network_out, path_training_slices, path_split, path_stati
97106
os.makedirs(path_out_k)
98107
os.makedirs(path_checkpoints)
99108

100-
loader_train = getDataloader(path_training_slices + "data/", path_out_k + "train_files.txt", train_subsets, path_split, B=1, sigma=2, points=8, path_train_ids_add=path_train_ids_add)
109+
loader_train = getDataloader(path_training_slices + "data/", path_out_k + "train_files.txt", train_subsets, path_split, B=1, sigma=2, points=8, path_train_ids_add=path_train_ids_add, station_ids=station_ids)
101110
time = train.train(net, loader_train, I, path_checkpoints, save_step, class_weights, I_reduce_lr)
102111

103112
with open(path_out_k + "training_time.txt", "w") as f: f.write("{}".format(time))
104113

105114
if do_predict:
106-
evaluate.evaluateSnapshots(path_checkpoints, path_stations_img, path_stations_gt, path_split, val_subset, path_out_k + "eval/", net)
115+
evaluate.evaluateSnapshots(path_checkpoints, path_stations_img, path_stations_gt, path_split, val_subset, path_out_k + "eval/", net, station_ids, c_target_spacing)
107116

108117
evaluate.writeSubsetTrainingCurve(path_out_k)
109118

@@ -121,7 +130,7 @@ def createDocumentation(network_path, split_path):
121130

122131

123132
#
124-
def getDataloader(input_path, output_path, subsets, path_split, B, sigma, points, path_train_ids_add):
133+
def getDataloader(input_path, output_path, subsets, path_split, B, sigma, points, path_train_ids_add, station_ids):
125134

126135
# Get chosen volumes
127136
subject_ids = []
@@ -140,10 +149,10 @@ def getDataloader(input_path, output_path, subsets, path_split, B, sigma, points
140149

141150
print("Loading data for {} subjects".format(len(subject_ids)))
142151

143-
# For each subject, use stations 1 and 2
152+
# For each subject, use the specified stations
144153
stations = []
145-
stations.extend([f + "_station1" for f in subject_ids])
146-
stations.extend([f + "_station2" for f in subject_ids])
154+
for s in station_ids:
155+
stations.extend([f + "_station{}".format(s) for f in subject_ids])
147156

148157
# Get training samples
149158
files = [f for f in os.listdir(input_path) if os.path.isfile(os.path.join(input_path, f))]

cross_validation/scripts/evaluate.py

Lines changed: 57 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ def plotCurve(path_out, x, y, label_x, label_y):
217217
plt.close()
218218

219219

220-
def evaluateSnapshots(path_checkpoints, path_stations_img, path_stations_gt, path_split, val_subset, path_out, net):
220+
def evaluateSnapshots(path_checkpoints, path_stations_img, path_stations_gt, path_split, val_subset, path_out, net, station_ids, target_spacing):
221221

222222
time_start = time.time()
223223

@@ -236,8 +236,7 @@ def evaluateSnapshots(path_checkpoints, path_stations_img, path_stations_gt, pat
236236

237237
# Fuse and store reference segmentation
238238
for i in range(N):
239-
240-
fuseStationsGt(val_subjects[i], path_stations_gt, path_out + "volumes/")
239+
fuseStationsGt(val_subjects[i], path_stations_img, path_stations_gt, path_out + "volumes/", station_ids, target_spacing)
241240

242241
# Find checkpoints
243242
checkpoint_files = [f for f in os.listdir(path_checkpoints) if os.path.isfile(os.path.join(path_checkpoints, f))]
@@ -250,7 +249,7 @@ def evaluateSnapshots(path_checkpoints, path_stations_img, path_stations_gt, pat
250249
print(" Evaluating snapshot {}...".format(i))
251250
checkpoint_i = path_checkpoints + checkpoint_files[i]
252251

253-
predictWithCheckpoint(checkpoint_i, path_stations_img, val_subjects, net, path_out + "volumes/")
252+
predictWithCheckpoint(checkpoint_i, path_stations_img, val_subjects, net, path_out + "volumes/", station_ids, target_spacing)
254253

255254
iteration = checkpoint_files[i].split("_")[1].split(".")[0]
256255
evaluateAgreement(path_out, iteration, val_subjects)
@@ -281,6 +280,7 @@ def evaluateAgreement(path_out, iteration, val_subjects):
281280
voxel_dim = np.array((space_dir[0][0], space_dir[1][1], space_dir[2][2]))
282281
voxel_scale = np.prod(voxel_dim) / (10*10*10)
283282

283+
284284
# Get positives, true positives, false positives
285285
p = np.count_nonzero(gt)
286286
tp = np.count_nonzero(np.multiply(gt, out))
@@ -295,7 +295,7 @@ def evaluateAgreement(path_out, iteration, val_subjects):
295295
f.write("{},{},{},{},{},{}\n".format(val_subjects[i],dice,p,tp,fp,voxel_scale))
296296

297297

298-
def predictWithCheckpoint(path_checkpoint, path_stations_img, val_subjects, net, path_out):
298+
def predictWithCheckpoint(path_checkpoint, path_stations_img, val_subjects, net, path_out, station_ids, target_spacing):
299299

300300
# Load network weights
301301
checkpoint = torch.load(path_checkpoint, map_location={"cuda" : "cpu"})
@@ -309,32 +309,71 @@ def predictWithCheckpoint(path_checkpoint, path_stations_img, val_subjects, net,
309309

310310
print("Subject {}".format(val_subjects[i]))
311311

312-
(img_1, header_1) = nrrd.read(path_stations_img + "{}_station1_W.nrrd".format(val_subjects[i]))
313-
(img_2, header_2) = nrrd.read(path_stations_img + "{}_station2_W.nrrd".format(val_subjects[i]))
312+
stations = []
313+
headers = []
314314

315-
(img, out, header, _, _) = predictForSubject.predictForSubject([img_1, img_2], [header_1, header_2], net)
315+
for s in station_ids:
316+
317+
(station, header) = nrrd.read(path_stations_img + "{}_station{}_W.nrrd".format(val_subjects[i], s))
318+
stations.append(station)
319+
headers.append(header)
320+
#(img_2, header_2) = nrrd.read(path_stations_img + "{}_station2_W.nrrd".format(val_subjects[i]))
316321

317322
if not os.path.exists(path_out + "{}_img.nrrd".format(val_subjects[i])):
323+
fuse_img = True
324+
else:
325+
fuse_img = False
326+
327+
(img, out, header, _, _) = predictForSubject.predictForSubject(stations, headers, net, target_spacing, fuse_img)
328+
329+
if fuse_img:
318330
nrrd.write(path_out + "{}_img.nrrd".format(val_subjects[i]), img, header, compression_level=1)
319331

320332
nrrd.write(path_out + "{}_out.nrrd".format(val_subjects[i]), out, header, compression_level=1)
321333

322334

323-
def fuseStationsGt(subject_id, path_stations_gt, path_out):
335+
def fuseStationsGt(subject_id, path_stations_img, path_stations_gt, path_out, station_ids, target_spacing):
324336

325-
(gt_1, header_1) = nrrd.read(path_stations_gt + "{}_station1.nrrd".format(subject_id))
326-
(gt_2, header_2) = nrrd.read(path_stations_gt + "{}_station2.nrrd".format(subject_id))
337+
volumes_gt = []
338+
headers_gt = []
339+
positions = []
340+
spacings = []
327341

328-
# Rounding before fusion appears to give best results for SmartPaint values
329-
gt_1 = np.around(gt_1)
330-
gt_2 = np.around(gt_2)
342+
for s in station_ids:
331343

332-
#
333-
(W, W_size, W_end, scalings, offsets) = fuseVolumes.getResamplingParameters([gt_1, gt_2], [header_1, header_2])
344+
path_s = path_stations_gt + "{}_station{}.nrrd".format(subject_id, s)
345+
346+
if not os.path.exists(path_s):
347+
print("WARNING: Could not find ground truth segmentation, assuming empty segmentation for {}".format(path_s))
348+
path_s = path_stations_img + "{}_station{}_W.nrrd".format(subject_id, s)
349+
350+
# Load signal instead and set values to 0
351+
(volume_gt, header) = nrrd.read(path_s)
352+
volume_gt[:] = 0
334353

335-
(gt, seg_fusion_cost) = fuseVolumes.fuseStations(gt_1, gt_2, W, W_size, W_end, scalings, offsets, False)
354+
else:
355+
(volume_gt, header) = nrrd.read(path_s)
356+
357+
# Round volumes to binarize segmentations from SmartPaint. Using the float values appears to provide no benefit
358+
volume_gt = np.around(volume_gt)
359+
360+
volumes_gt.append(volume_gt)
361+
headers_gt.append(header)
362+
363+
#
364+
positions.append(header["space origin"])
365+
366+
spacing = header["space directions"]
367+
spacing = np.array((spacing[0][0], spacing[1][1], spacing[2][2]))
368+
369+
spacings.append(spacing)
370+
371+
#
372+
(gt, gt_origin,seg_fusion_cost) = fuseVolumes.fuseStations(volumes_gt, positions, spacings, target_spacing, False)
336373

337-
header = header_1
374+
header = headers_gt[0]
338375
header["sizes"] = gt.shape
376+
header["space origin"] = gt_origin
377+
for i in range(3): header["space directions"][i][i] = target_spacing[i]
339378

340379
nrrd.write(path_out + "{}_gt.nrrd".format(subject_id), gt, header, compression_level=1)

image_fusion/dicomToVolume.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
c_use_gpu = True # If yes, use numba for gpu access, otherwise use scipy on cpu
4343

4444

45-
def dicomToVolume(input_path_zip):
45+
def dicomToVolume(input_path_zip, station_ids):
4646

4747
if not os.path.exists(input_path_zip):
4848
print("Could not find input file {}".format(input_path_zip))
@@ -59,7 +59,7 @@ def dicomToVolume(input_path_zip):
5959
headers = []
6060

6161
# Only use the imaging stations selected via station_ids
62-
for i in range(1, 3):
62+
for i in station_ids:
6363

6464
#
6565
voxel_data_w[i] = np.flip(voxel_data_w[i], 2)

0 commit comments

Comments
 (0)