Sincere thanks for sharing the training code!
When I set `args.hc = 1`, I get the following error:
Traceback (most recent call last):
  File "scripts/train_lcfed.py", line 223, in
    pred, x5, hmap = net_current(volume_batch, client_idx, 1, seg_heads)
  File "/home/anaconda3/envs/torch11/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/FedLC-main/scripts/../networks/lcrepnet.py", line 243, in forward
    o = self.forward_for_train(x, site_index, stage, seg_heads)
  File "/home/FedLC-main/scripts/../networks/lcrepnet.py", line 236, in forward_for_train
    o = self.hc(uncertainty.detach(), preds.detach(), fea.detach())
  File "/home/anaconda3/envs/torch11/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/FedLC-main/scripts/../networks/lcrepnet.py", line 30, in forward
    fea2 = self.soft(uncertainty[:, c].unsqueeze(1), fea) + fea
  File "/home/anaconda3/envs/torch11/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
TypeError: forward() takes 2 positional arguments but 3 were given
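The last frame suggests that `self.soft` is called with two tensors (`uncertainty[:, c].unsqueeze(1)` and `fea`), while its `forward` appears to accept only one. Because `self` counts as the first positional argument, Python reports "takes 2 ... but 3 were given". The pure-Python sketch below mimics how `nn.Module.__call__` passes every positional argument straight to `forward`, so the mismatch can be seen in isolation. `MiniModule`, `SingleInput`, `TwoInput` and `call_with_two` are hypothetical names; the traceback does not show what `self.soft` actually is (a single-input module such as `torch.nn.Softmax` would produce exactly this message, but that is only an assumption).

```python
# Minimal pure-Python sketch of how torch.nn.Module dispatches calls.
# MiniModule is a hypothetical stand-in for nn.Module, not the real class.


class MiniModule:
    def __call__(self, *inputs, **kwargs):
        # Like nn.Module._call_impl, pass every positional argument
        # unchanged to self.forward.
        return self.forward(*inputs, **kwargs)


class SingleInput(MiniModule):
    """Hypothetical module whose forward accepts one input (like nn.Softmax)."""

    def forward(self, x):
        return x


class TwoInput(MiniModule):
    """Hypothetical module whose forward accepts (uncertainty, fea)."""

    def forward(self, uncertainty, fea):
        # Placeholder arithmetic standing in for a real gating operation.
        return uncertainty * fea


def call_with_two(module, a, b):
    """Mimic `self.soft(uncertainty[:, c].unsqueeze(1), fea)`.

    Returns the module's result, or the TypeError message if the
    module's forward does not accept two inputs.
    """
    try:
        return module(a, b)
    except TypeError as err:
        return str(err)


if __name__ == "__main__":
    print(call_with_two(SingleInput(), 0.5, 2.0))  # argument-count TypeError text
    print(call_with_two(TwoInput(), 0.5, 2.0))
```

The sketch deliberately avoids a torch dependency; with real PyTorch the same message surfaces through `_call_impl`, as in the traceback above. Newer Python versions prefix the message with the class name (for example `SingleInput.forward()`), but the argument counts are the same.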