Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions src/caffe/net.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -427,12 +427,11 @@ int Net<Dtype>::AppendBottom(const NetParameter& param, const int layer_id,
bottom_vecs_[layer_id].push_back(blobs_[blob_id].get());
bottom_id_vecs_[layer_id].push_back(blob_id);
available_blobs->erase(blob_name);
bool propagate_down = true;
bool need_backward = blob_need_backward_[blob_id];
// Check if the backpropagation on bottom_id should be skipped
if (layer_param.propagate_down_size() > 0)
propagate_down = layer_param.propagate_down(bottom_id);
const bool need_backward = blob_need_backward_[blob_id] &&
propagate_down;
if (layer_param.propagate_down_size() > 0) {
need_backward = layer_param.propagate_down(bottom_id);
}
bottom_need_backward_[layer_id].push_back(need_backward);
return blob_id;
}
Expand Down
7 changes: 6 additions & 1 deletion src/caffe/proto/caffe.proto
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,12 @@ message LayerParameter {
// The blobs containing the numeric parameters of the layer.
repeated BlobProto blobs = 7;

// Specifies on which bottoms the backpropagation should be skipped.
// Specifies whether to backpropagate to each bottom. If unspecified,
// Caffe will automatically infer whether each input needs backpropagation
// to compute parameter gradients. If set to true for some inputs,
// backpropagation to those inputs is forced; if set to false for some inputs,
// backpropagation to those inputs is skipped.
//
// The size must be either 0 or equal to the number of bottoms.
repeated bool propagate_down = 11;

Expand Down
102 changes: 102 additions & 0 deletions src/caffe/test/test_net.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -716,6 +716,61 @@ class NetTest : public MultiDeviceTest<TypeParam> {
InitNetFromProtoString(proto);
}

virtual void InitForcePropNet(bool test_force_true) {
string proto =
"name: 'ForcePropTestNetwork' "
"layer { "
" name: 'data' "
" type: 'DummyData' "
" dummy_data_param { "
" shape { "
" dim: 5 "
" dim: 2 "
" dim: 3 "
" dim: 4 "
" } "
" data_filler { "
" type: 'gaussian' "
" std: 0.01 "
" } "
" shape { "
" dim: 5 "
" } "
" data_filler { "
" type: 'constant' "
" value: 0 "
" } "
" } "
" top: 'data' "
" top: 'label' "
"} "
"layer { "
" name: 'innerproduct' "
" type: 'InnerProduct' "
" inner_product_param { "
" num_output: 1 "
" weight_filler { "
" type: 'gaussian' "
" std: 0.01 "
" } "
" } "
" bottom: 'data' "
" top: 'innerproduct' ";
if (test_force_true) {
proto += " propagate_down: true ";
}
proto +=
"} "
"layer { "
" name: 'loss' "
" bottom: 'innerproduct' "
" bottom: 'label' "
" top: 'cross_entropy_loss' "
" type: 'SigmoidCrossEntropyLoss' "
"} ";
InitNetFromProtoString(proto);
}

int seed_;
shared_ptr<Net<Dtype> > net_;
};
Expand Down Expand Up @@ -2371,4 +2426,51 @@ TYPED_TEST(NetTest, TestSkipPropagateDown) {
}
}

TYPED_TEST(NetTest, TestForcePropagateDown) {
this->InitForcePropNet(false);
vector<bool> layer_need_backward = this->net_->layer_need_backward();
for (int layer_id = 0; layer_id < this->net_->layers().size(); ++layer_id) {
const string& layer_name = this->net_->layer_names()[layer_id];
const vector<bool> need_backward =
this->net_->bottom_need_backward()[layer_id];
if (layer_name == "data") {
ASSERT_EQ(need_backward.size(), 0);
EXPECT_FALSE(layer_need_backward[layer_id]);
} else if (layer_name == "innerproduct") {
ASSERT_EQ(need_backward.size(), 1);
EXPECT_FALSE(need_backward[0]); // data
EXPECT_TRUE(layer_need_backward[layer_id]);
} else if (layer_name == "loss") {
ASSERT_EQ(need_backward.size(), 2);
EXPECT_TRUE(need_backward[0]); // innerproduct
EXPECT_FALSE(need_backward[1]); // label
EXPECT_TRUE(layer_need_backward[layer_id]);
} else {
LOG(FATAL) << "Unknown layer: " << layer_name;
}
}
this->InitForcePropNet(true);
layer_need_backward = this->net_->layer_need_backward();
for (int layer_id = 0; layer_id < this->net_->layers().size(); ++layer_id) {
const string& layer_name = this->net_->layer_names()[layer_id];
const vector<bool> need_backward =
this->net_->bottom_need_backward()[layer_id];
if (layer_name == "data") {
ASSERT_EQ(need_backward.size(), 0);
EXPECT_FALSE(layer_need_backward[layer_id]);
} else if (layer_name == "innerproduct") {
ASSERT_EQ(need_backward.size(), 1);
EXPECT_TRUE(need_backward[0]); // data
EXPECT_TRUE(layer_need_backward[layer_id]);
} else if (layer_name == "loss") {
ASSERT_EQ(need_backward.size(), 2);
EXPECT_TRUE(need_backward[0]); // innerproduct
EXPECT_FALSE(need_backward[1]); // label
EXPECT_TRUE(layer_need_backward[layer_id]);
} else {
LOG(FATAL) << "Unknown layer: " << layer_name;
}
}
}

} // namespace caffe