Thanks for visiting codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions include/caffe/vision_layers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -361,9 +361,9 @@ class PoolingLayer : public Layer<Dtype> {
const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);

int max_top_blobs_;
int kernel_size_;
int stride_;
int pad_;
int kernel_h_, kernel_w_;
int stride_h_, stride_w_;
int pad_h_, pad_w_;
int channels_;
int height_;
int width_;
Expand Down
81 changes: 57 additions & 24 deletions src/caffe/layers/pooling_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,35 +28,68 @@ void PoolingLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
max_top_blobs_ = 1;
}
Layer<Dtype>::SetUp(bottom, top);
kernel_size_ = this->layer_param_.pooling_param().kernel_size();
stride_ = this->layer_param_.pooling_param().stride();
pad_ = this->layer_param_.pooling_param().pad();
if (pad_ != 0) {
PoolingParameter pool_param = this->layer_param_.pooling_param();
CHECK(!pool_param.has_kernel_size() !=
!(pool_param.has_kernel_h() && pool_param.has_kernel_w()))
<< "Filter size is kernel_size OR kernel_h and kernel_w; not both";
CHECK(pool_param.has_kernel_size() ||
(pool_param.has_kernel_h() && pool_param.has_kernel_w()))
<< "For non-square filters both kernel_h and kernel_w are required.";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that when kernel_size is specified, the other two are ignored. So why not both?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For compatibility with previous Caffe prototype files, we don't want to require users to update their network prototypes. So the user can provide either the single parameter kernel_size or the two parameters {kernel_h, kernel_w}, but not both. See #505 for the detailed reasoning.

CHECK((!pool_param.has_pad() && pool_param.has_pad_h()
&& pool_param.has_pad_w())
|| (!pool_param.has_pad_h() && !pool_param.has_pad_w()))
<< "pad is pad OR pad_h and pad_w are required.";
CHECK((!pool_param.has_stride() && pool_param.has_stride_h()
&& pool_param.has_stride_w())
|| (!pool_param.has_stride_h() && !pool_param.has_stride_w()))
<< "Stride is stride OR stride_h and stride_w are required.";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The conditions to decide which values of kernel_size, pad or stride are used should be consistent.

if (pool_param.has_kernel_size()) {
kernel_h_ = kernel_w_ = pool_param.kernel_size();
} else {
kernel_h_ = pool_param.kernel_h();
kernel_w_ = pool_param.kernel_w();
}
CHECK_GT(kernel_h_, 0) << "Filter dimensions cannot be zero.";
CHECK_GT(kernel_w_, 0) << "Filter dimensions cannot be zero.";
if (!pool_param.has_pad_h()) {
pad_h_ = pad_w_ = pool_param.pad();
} else {
pad_h_ = pool_param.pad_h();
pad_w_ = pool_param.pad_w();
}
if (!pool_param.has_stride_h()) {
stride_h_ = stride_w_ = pool_param.stride();
} else {
stride_h_ = pool_param.stride_h();
stride_w_ = pool_param.stride_w();
}
if (pad_h_ != 0 || pad_w_ != 0) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you check the kernel_size, then the pad and stride should be checked too.

CHECK(this->layer_param_.pooling_param().pool()
== PoolingParameter_PoolMethod_AVE
|| this->layer_param_.pooling_param().pool()
== PoolingParameter_PoolMethod_MAX)
<< "Padding implemented only for average and max pooling.";
CHECK_LT(pad_, kernel_size_);
CHECK_LT(pad_h_, kernel_h_);
CHECK_LT(pad_w_, kernel_w_);
}
channels_ = bottom[0]->channels();
height_ = bottom[0]->height();
width_ = bottom[0]->width();
pooled_height_ = static_cast<int>(ceil(static_cast<float>(
height_ + 2 * pad_ - kernel_size_) / stride_)) + 1;
height_ + 2 * pad_h_ - kernel_h_) / stride_h_)) + 1;
pooled_width_ = static_cast<int>(ceil(static_cast<float>(
width_ + 2 * pad_ - kernel_size_) / stride_)) + 1;
if (pad_) {
width_ + 2 * pad_w_ - kernel_w_) / stride_w_)) + 1;
if (pad_h_ || pad_w_) {
// If we have padding, ensure that the last pooling starts strictly
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Once checked, they can be used freely.

// inside the image (instead of at the padding); otherwise clip the last.
if ((pooled_height_ - 1) * stride_ >= height_ + pad_) {
if ((pooled_height_ - 1) * stride_h_ >= height_ + pad_h_) {
--pooled_height_;
}
if ((pooled_width_ - 1) * stride_ >= width_ + pad_) {
if ((pooled_width_ - 1) * stride_w_ >= width_ + pad_w_) {
--pooled_width_;
}
CHECK_LT((pooled_height_ - 1) * stride_, height_ + pad_);
CHECK_LT((pooled_width_ - 1) * stride_, width_ + pad_);
CHECK_LT((pooled_height_ - 1) * stride_h_, height_ + pad_h_);
CHECK_LT((pooled_width_ - 1) * stride_w_, width_ + pad_w_);
}
(*top)[0]->Reshape(bottom[0]->num(), channels_, pooled_height_,
pooled_width_);
Expand Down Expand Up @@ -107,10 +140,10 @@ Dtype PoolingLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
for (int c = 0; c < channels_; ++c) {
for (int ph = 0; ph < pooled_height_; ++ph) {
for (int pw = 0; pw < pooled_width_; ++pw) {
int hstart = ph * stride_ - pad_;
int wstart = pw * stride_ - pad_;
int hend = min(hstart + kernel_size_, height_);
int wend = min(wstart + kernel_size_, width_);
int hstart = ph * stride_h_ - pad_h_;
int wstart = pw * stride_w_ - pad_w_;
int hend = min(hstart + kernel_h_, height_);
int wend = min(wstart + kernel_w_, width_);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
const int pool_index = ph * pooled_width_ + pw;
Expand Down Expand Up @@ -149,10 +182,10 @@ Dtype PoolingLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
for (int c = 0; c < channels_; ++c) {
for (int ph = 0; ph < pooled_height_; ++ph) {
for (int pw = 0; pw < pooled_width_; ++pw) {
int hstart = ph * stride_ - pad_;
int wstart = pw * stride_ - pad_;
int hend = min(hstart + kernel_size_, height_ + pad_);
int wend = min(wstart + kernel_size_, width_ + pad_);
int hstart = ph * stride_h_ - pad_h_;
int wstart = pw * stride_w_ - pad_w_;
int hend = min(hstart + kernel_h_, height_ + pad_h_);
int wend = min(wstart + kernel_w_, width_ + pad_w_);
int pool_size = (hend - hstart) * (wend - wstart);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
Expand Down Expand Up @@ -231,10 +264,10 @@ void PoolingLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
for (int c = 0; c < channels_; ++c) {
for (int ph = 0; ph < pooled_height_; ++ph) {
for (int pw = 0; pw < pooled_width_; ++pw) {
int hstart = ph * stride_ - pad_;
int wstart = pw * stride_ - pad_;
int hend = min(hstart + kernel_size_, height_ + pad_);
int wend = min(wstart + kernel_size_, width_ + pad_);
int hstart = ph * stride_h_ - pad_h_;
int wstart = pw * stride_w_ - pad_w_;
int hend = min(hstart + kernel_h_, height_ + pad_h_);
int wend = min(wstart + kernel_w_, width_ + pad_w_);
int pool_size = (hend - hstart) * (wend - wstart);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
Expand Down
117 changes: 64 additions & 53 deletions src/caffe/layers/pooling_layer.cu
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,18 @@ template <typename Dtype>
__global__ void MaxPoolForward(const int nthreads, const Dtype* bottom_data,
const int num, const int channels, const int height,
const int width, const int pooled_height, const int pooled_width,
const int kernel_size, const int stride, const int pad, Dtype* top_data,
const int kernel_h, const int kernel_w, const int stride_h,
const int stride_w, const int pad_h, const int pad_w, Dtype* top_data,
int* mask, Dtype* top_mask) {
CUDA_KERNEL_LOOP(index, nthreads) {
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
int hstart = ph * stride - pad;
int wstart = pw * stride - pad;
int hend = min(hstart + kernel_size, height);
int wend = min(wstart + kernel_size, width);
int hstart = ph * stride_h - pad_h;
int wstart = pw * stride_w - pad_w;
int hend = min(hstart + kernel_h, height);
int wend = min(wstart + kernel_w, width);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
Dtype maxval = -FLT_MAX;
Expand All @@ -54,16 +55,17 @@ template <typename Dtype>
__global__ void AvePoolForward(const int nthreads, const Dtype* bottom_data,
const int num, const int channels, const int height,
const int width, const int pooled_height, const int pooled_width,
const int kernel_size, const int stride, const int pad, Dtype* top_data) {
const int kernel_h, const int kernel_w, const int stride_h,
const int stride_w, const int pad_h, const int pad_w, Dtype* top_data) {
CUDA_KERNEL_LOOP(index, nthreads) {
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
int hstart = ph * stride - pad;
int wstart = pw * stride - pad;
int hend = min(hstart + kernel_size, height + pad);
int wend = min(wstart + kernel_size, width + pad);
int hstart = ph * stride_h - pad_h;
int wstart = pw * stride_w - pad_w;
int hend = min(hstart + kernel_h, height + pad_h);
int wend = min(wstart + kernel_w, width + pad_w);
int pool_size = (hend - hstart) * (wend - wstart);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
Expand All @@ -85,16 +87,17 @@ __global__ void StoPoolForwardTrain(const int nthreads,
const Dtype* bottom_data,
const int num, const int channels, const int height,
const int width, const int pooled_height, const int pooled_width,
const int kernel_size, const int stride, Dtype* rand_idx, Dtype* top_data) {
const int kernel_h, const int kernel_w, const int stride_h,
const int stride_w, Dtype* rand_idx, Dtype* top_data) {
CUDA_KERNEL_LOOP(index, nthreads) {
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
int hstart = ph * stride;
int hend = min(hstart + kernel_size, height);
int wstart = pw * stride;
int wend = min(wstart + kernel_size, width);
int hstart = ph * stride_h;
int hend = min(hstart + kernel_h, height);
int wstart = pw * stride_w;
int wend = min(wstart + kernel_w, width);
Dtype cumsum = 0.;
bottom_data += (n * channels + c) * height * width;
// First pass: get sum
Expand Down Expand Up @@ -125,16 +128,17 @@ __global__ void StoPoolForwardTest(const int nthreads,
const Dtype* bottom_data,
const int num, const int channels, const int height,
const int width, const int pooled_height, const int pooled_width,
const int kernel_size, const int stride, Dtype* top_data) {
const int kernel_h, const int kernel_w, const int stride_h,
const int stride_w, Dtype* top_data) {
CUDA_KERNEL_LOOP(index, nthreads) {
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
int hstart = ph * stride;
int hend = min(hstart + kernel_size, height);
int wstart = pw * stride;
int wend = min(wstart + kernel_size, width);
int hstart = ph * stride_h;
int hend = min(hstart + kernel_h, height);
int wstart = pw * stride_w;
int wend = min(wstart + kernel_w, width);
// We set cumsum to be 0 to avoid divide-by-zero problems
Dtype cumsum = FLT_MIN;
Dtype cumvalues = 0.;
Expand Down Expand Up @@ -171,15 +175,16 @@ Dtype PoolingLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
// NOLINT_NEXT_LINE(whitespace/operators)
MaxPoolForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
count, bottom_data, bottom[0]->num(), channels_,
height_, width_, pooled_height_, pooled_width_, kernel_size_, stride_,
pad_, top_data, mask, top_mask);
height_, width_, pooled_height_, pooled_width_, kernel_h_,
kernel_w_, stride_h_, stride_w_, pad_h_, pad_w_, top_data,
mask, top_mask);
break;
case PoolingParameter_PoolMethod_AVE:
// NOLINT_NEXT_LINE(whitespace/operators)
AvePoolForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
count, bottom_data, bottom[0]->num(), channels_,
height_, width_, pooled_height_, pooled_width_, kernel_size_, stride_,
pad_, top_data);
height_, width_, pooled_height_, pooled_width_, kernel_h_,
kernel_w_, stride_h_, stride_w_, pad_h_, pad_w_, top_data);
break;
case PoolingParameter_PoolMethod_STOCHASTIC:
if (Caffe::phase() == Caffe::TRAIN) {
Expand All @@ -190,15 +195,16 @@ Dtype PoolingLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
StoPoolForwardTrain<Dtype><<<CAFFE_GET_BLOCKS(count),
CAFFE_CUDA_NUM_THREADS>>>(
count, bottom_data, bottom[0]->num(), channels_,
height_, width_, pooled_height_, pooled_width_, kernel_size_, stride_,
height_, width_, pooled_height_, pooled_width_, kernel_h_,
kernel_w_, stride_h_, stride_w_,
rand_idx_.mutable_gpu_data(), top_data);
} else {
// NOLINT_NEXT_LINE(whitespace/operators)
StoPoolForwardTest<Dtype><<<CAFFE_GET_BLOCKS(count),
CAFFE_CUDA_NUM_THREADS>>>(
count, bottom_data, bottom[0]->num(), channels_,
height_, width_, pooled_height_, pooled_width_, kernel_size_, stride_,
top_data);
height_, width_, pooled_height_, pooled_width_, kernel_h_,
kernel_w_, stride_h_, stride_w_, top_data);
}
break;
default:
Expand All @@ -213,8 +219,9 @@ template <typename Dtype>
__global__ void MaxPoolBackward(const int nthreads, const Dtype* top_diff,
const int* mask, const Dtype* top_mask, const int num, const int channels,
const int height, const int width, const int pooled_height,
const int pooled_width, const int kernel_size, const int stride,
const int pad, Dtype* bottom_diff) {
const int pooled_width, const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w, const int pad_h, const int pad_w,
Dtype* bottom_diff) {
CUDA_KERNEL_LOOP(index, nthreads) {
// find out the local index
// find out the local offset
Expand All @@ -223,11 +230,11 @@ __global__ void MaxPoolBackward(const int nthreads, const Dtype* top_diff,
int c = (index / width / height) % channels;
int n = index / width / height / channels;
int phstart =
(h + pad < kernel_size) ? 0 : (h + pad - kernel_size) / stride + 1;
int phend = min((h + pad) / stride + 1, pooled_height);
(h + pad_h < kernel_h) ? 0 : (h + pad_h - kernel_h) / stride_h + 1;
int phend = min((h + pad_h) / stride_h + 1, pooled_height);
int pwstart =
(w + pad < kernel_size) ? 0 : (w + pad - kernel_size) / stride + 1;
int pwend = min((w + pad) / stride + 1, pooled_width);
(w + pad_w < kernel_w) ? 0 : (w + pad_w - kernel_w) / stride_w + 1;
int pwend = min((w + pad_w) / stride_w + 1, pooled_width);
Dtype gradient = 0;
int offset = (n * channels + c) * pooled_height * pooled_width;
top_diff += offset;
Expand Down Expand Up @@ -258,28 +265,29 @@ template <typename Dtype>
__global__ void AvePoolBackward(const int nthreads, const Dtype* top_diff,
const int num, const int channels, const int height,
const int width, const int pooled_height, const int pooled_width,
const int kernel_size, const int stride, const int pad,
const int kernel_h, const int kernel_w, const int stride_h,
const int stride_w, const int pad_h, const int pad_w,
Dtype* bottom_diff) {
CUDA_KERNEL_LOOP(index, nthreads) {
// find out the local index
// find out the local offset
int w = index % width + pad;
int h = (index / width) % height + pad;
int w = index % width + pad_w;
int h = (index / width) % height + pad_h;
int c = (index / width / height) % channels;
int n = index / width / height / channels;
int phstart = (h < kernel_size) ? 0 : (h - kernel_size) / stride + 1;
int phend = min(h / stride + 1, pooled_height);
int pwstart = (w < kernel_size) ? 0 : (w - kernel_size) / stride + 1;
int pwend = min(w / stride + 1, pooled_width);
int phstart = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1;
int phend = min(h / stride_h + 1, pooled_height);
int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1;
int pwend = min(w / stride_w + 1, pooled_width);
Dtype gradient = 0;
top_diff += (n * channels + c) * pooled_height * pooled_width;
for (int ph = phstart; ph < phend; ++ph) {
for (int pw = pwstart; pw < pwend; ++pw) {
// figure out the pooling size
int hstart = ph * stride - pad;
int wstart = pw * stride - pad;
int hend = min(hstart + kernel_size, height + pad);
int wend = min(wstart + kernel_size, width + pad);
int hstart = ph * stride_h - pad_h;
int wstart = pw * stride_w - pad_w;
int hend = min(hstart + kernel_h, height + pad_h);
int wend = min(wstart + kernel_w, width + pad_w);
int pool_size = (hend - hstart) * (wend - wstart);
gradient += top_diff[ph * pooled_width + pw] / pool_size;
}
Expand All @@ -294,18 +302,19 @@ __global__ void StoPoolBackward(const int nthreads,
const Dtype* rand_idx, const Dtype* top_diff,
const int num, const int channels, const int height,
const int width, const int pooled_height, const int pooled_width,
const int kernel_size, const int stride, Dtype* bottom_diff) {
const int kernel_h, const int kernel_w, const int stride_h,
const int stride_w, Dtype* bottom_diff) {
CUDA_KERNEL_LOOP(index, nthreads) {
// find out the local index
// find out the local offset
int w = index % width;
int h = (index / width) % height;
int c = (index / width / height) % channels;
int n = index / width / height / channels;
int phstart = (h < kernel_size) ? 0 : (h - kernel_size) / stride + 1;
int phend = min(h / stride + 1, pooled_height);
int pwstart = (w < kernel_size) ? 0 : (w - kernel_size) / stride + 1;
int pwend = min(w / stride + 1, pooled_width);
int phstart = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1;
int phend = min(h / stride_h + 1, pooled_height);
int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1;
int pwend = min(w / stride_w + 1, pooled_width);
Dtype gradient = 0;
rand_idx += (n * channels + c) * pooled_height * pooled_width;
top_diff += (n * channels + c) * pooled_height * pooled_width;
Expand Down Expand Up @@ -345,21 +354,23 @@ void PoolingLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
MaxPoolBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
count, top_diff, mask, top_mask, top[0]->num(), channels_,
height_, width_, pooled_height_, pooled_width_,
kernel_size_, stride_, pad_, bottom_diff);
kernel_h_, kernel_w_, stride_h_, stride_w_, pad_h_, pad_w_,
bottom_diff);
break;
case PoolingParameter_PoolMethod_AVE:
// NOLINT_NEXT_LINE(whitespace/operators)
AvePoolBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
count, top_diff, top[0]->num(), channels_,
height_, width_, pooled_height_, pooled_width_, kernel_size_, stride_,
pad_, bottom_diff);
height_, width_, pooled_height_, pooled_width_, kernel_h_,
kernel_w_, stride_h_, stride_w_, pad_h_, pad_w_, bottom_diff);
break;
case PoolingParameter_PoolMethod_STOCHASTIC:
// NOLINT_NEXT_LINE(whitespace/operators)
StoPoolBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
count, rand_idx_.gpu_data(), top_diff,
top[0]->num(), channels_, height_, width_, pooled_height_,
pooled_width_, kernel_size_, stride_, bottom_diff);
pooled_width_, kernel_h_, kernel_w_, stride_h_, stride_w_,
bottom_diff);
break;
default:
LOG(FATAL) << "Unknown pooling method.";
Expand Down
Loading