@@ -217,7 +217,7 @@ def plotCurve(path_out, x, y, label_x, label_y):
217217 plt .close ()
218218
219219
220- def evaluateSnapshots (path_checkpoints , path_stations_img , path_stations_gt , path_split , val_subset , path_out , net ):
220+ def evaluateSnapshots (path_checkpoints , path_stations_img , path_stations_gt , path_split , val_subset , path_out , net , station_ids , target_spacing ):
221221
222222 time_start = time .time ()
223223
@@ -236,8 +236,7 @@ def evaluateSnapshots(path_checkpoints, path_stations_img, path_stations_gt, pat
236236
237237 # Fuse and store reference segmentation
238238 for i in range (N ):
239-
240- fuseStationsGt (val_subjects [i ], path_stations_gt , path_out + "volumes/" )
239+ fuseStationsGt (val_subjects [i ], path_stations_img , path_stations_gt , path_out + "volumes/" , station_ids , target_spacing )
241240
242241 # Find checkpoints
243242 checkpoint_files = [f for f in os .listdir (path_checkpoints ) if os .path .isfile (os .path .join (path_checkpoints , f ))]
@@ -250,7 +249,7 @@ def evaluateSnapshots(path_checkpoints, path_stations_img, path_stations_gt, pat
250249 print (" Evaluating snapshot {}..." .format (i ))
251250 checkpoint_i = path_checkpoints + checkpoint_files [i ]
252251
253- predictWithCheckpoint (checkpoint_i , path_stations_img , val_subjects , net , path_out + "volumes/" )
252+ predictWithCheckpoint (checkpoint_i , path_stations_img , val_subjects , net , path_out + "volumes/" , station_ids , target_spacing )
254253
255254 iteration = checkpoint_files [i ].split ("_" )[1 ].split ("." )[0 ]
256255 evaluateAgreement (path_out , iteration , val_subjects )
@@ -281,6 +280,7 @@ def evaluateAgreement(path_out, iteration, val_subjects):
281280 voxel_dim = np .array ((space_dir [0 ][0 ], space_dir [1 ][1 ], space_dir [2 ][2 ]))
282281 voxel_scale = np .prod (voxel_dim ) / (10 * 10 * 10 )
283282
283+
284284 # Get positives, true positives, false positives
285285 p = np .count_nonzero (gt )
286286 tp = np .count_nonzero (np .multiply (gt , out ))
@@ -295,7 +295,7 @@ def evaluateAgreement(path_out, iteration, val_subjects):
295295 f .write ("{},{},{},{},{},{}\n " .format (val_subjects [i ],dice ,p ,tp ,fp ,voxel_scale ))
296296
297297
298- def predictWithCheckpoint (path_checkpoint , path_stations_img , val_subjects , net , path_out ):
298+ def predictWithCheckpoint (path_checkpoint , path_stations_img , val_subjects , net , path_out , station_ids , target_spacing ):
299299
300300 # Load network weights
301301 checkpoint = torch .load (path_checkpoint , map_location = {"cuda" : "cpu" })
@@ -309,32 +309,71 @@ def predictWithCheckpoint(path_checkpoint, path_stations_img, val_subjects, net,
309309
310310 print ("Subject {}" .format (val_subjects [i ]))
311311
312- ( img_1 , header_1 ) = nrrd . read ( path_stations_img + "{}_station1_W.nrrd" . format ( val_subjects [ i ]))
313- ( img_2 , header_2 ) = nrrd . read ( path_stations_img + "{}_station2_W.nrrd" . format ( val_subjects [ i ]))
312+ stations = []
313+ headers = []
314314
315- (img , out , header , _ , _ ) = predictForSubject .predictForSubject ([img_1 , img_2 ], [header_1 , header_2 ], net )
315+ for s in station_ids :
316+
317+ (station , header ) = nrrd .read (path_stations_img + "{}_station{}_W.nrrd" .format (val_subjects [i ], s ))
318+ stations .append (station )
319+ headers .append (header )
320+ #(img_2, header_2) = nrrd.read(path_stations_img + "{}_station2_W.nrrd".format(val_subjects[i]))
316321
317322 if not os .path .exists (path_out + "{}_img.nrrd" .format (val_subjects [i ])):
323+ fuse_img = True
324+ else :
325+ fuse_img = False
326+
327+ (img , out , header , _ , _ ) = predictForSubject .predictForSubject (stations , headers , net , target_spacing , fuse_img )
328+
329+ if fuse_img :
318330 nrrd .write (path_out + "{}_img.nrrd" .format (val_subjects [i ]), img , header , compression_level = 1 )
319331
320332 nrrd .write (path_out + "{}_out.nrrd" .format (val_subjects [i ]), out , header , compression_level = 1 )
321333
322334
323- def fuseStationsGt (subject_id , path_stations_gt , path_out ):
335+ def fuseStationsGt (subject_id , path_stations_img , path_stations_gt , path_out , station_ids , target_spacing ):
324336
325- (gt_1 , header_1 ) = nrrd .read (path_stations_gt + "{}_station1.nrrd" .format (subject_id ))
326- (gt_2 , header_2 ) = nrrd .read (path_stations_gt + "{}_station2.nrrd" .format (subject_id ))
337+ volumes_gt = []
338+ headers_gt = []
339+ positions = []
340+ spacings = []
327341
328- # Rounding before fusion appears to give best results for SmartPaint values
329- gt_1 = np .around (gt_1 )
330- gt_2 = np .around (gt_2 )
342+ for s in station_ids :
331343
332- #
333- (W , W_size , W_end , scalings , offsets ) = fuseVolumes .getResamplingParameters ([gt_1 , gt_2 ], [header_1 , header_2 ])
344+ path_s = path_stations_gt + "{}_station{}.nrrd" .format (subject_id , s )
345+
346+ if not os .path .exists (path_s ):
347+ print ("WARNING: Could not find ground truth segmentation, assuming empty segmentation for {}" .format (path_s ))
348+ path_s = path_stations_img + "{}_station{}_W.nrrd" .format (subject_id , s )
349+
350+ # Load signal instead and set values to 0
351+ (volume_gt , header ) = nrrd .read (path_s )
352+ volume_gt [:] = 0
334353
335- (gt , seg_fusion_cost ) = fuseVolumes .fuseStations (gt_1 , gt_2 , W , W_size , W_end , scalings , offsets , False )
354+ else :
355+ (volume_gt , header ) = nrrd .read (path_s )
356+
357+ # Round volumes to binarize segmentations from SmartPaint. Using the float values appears to provide no benefit
358+ volume_gt = np .around (volume_gt )
359+
360+ volumes_gt .append (volume_gt )
361+ headers_gt .append (header )
362+
363+ #
364+ positions .append (header ["space origin" ])
365+
366+ spacing = header ["space directions" ]
367+ spacing = np .array ((spacing [0 ][0 ], spacing [1 ][1 ], spacing [2 ][2 ]))
368+
369+ spacings .append (spacing )
370+
371+ #
372+ (gt , gt_origin ,seg_fusion_cost ) = fuseVolumes .fuseStations (volumes_gt , positions , spacings , target_spacing , False )
336373
337- header = header_1
374+ header = headers_gt [ 0 ]
338375 header ["sizes" ] = gt .shape
376+ header ["space origin" ] = gt_origin
377+ for i in range (3 ): header ["space directions" ][i ][i ] = target_spacing [i ]
339378
340379 nrrd .write (path_out + "{}_gt.nrrd" .format (subject_id ), gt , header , compression_level = 1 )
0 commit comments