Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit d4b67e1

Browse files
zhoukunshengluotao1
authored andcommitted
Add Where Op(PaddlePaddle#16793)
1 parent 1bfff02 commit d4b67e1

6 files changed

Lines changed: 363 additions & 0 deletions

File tree

paddle/fluid/API.spec

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,7 @@ paddle.fluid.layers.npair_loss (ArgSpec(args=['anchor', 'positive', 'labels', 'l
234234
paddle.fluid.layers.pixel_shuffle (ArgSpec(args=['x', 'upscale_factor'], varargs=None, keywords=None, defaults=None), ('document', '132b6e74ff642a392bd6b14c10aedc65'))
235235
paddle.fluid.layers.fsp_matrix (ArgSpec(args=['x', 'y'], varargs=None, keywords=None, defaults=None), ('document', 'b76ccca3735bea4a58a0dbf0d77c5393'))
236236
paddle.fluid.layers.continuous_value_model (ArgSpec(args=['input', 'cvm', 'use_cvm'], varargs=None, keywords=None, defaults=(True,)), ('document', 'a07a44c2bacdcd09c1f5f35a96a0514e'))
237+
paddle.fluid.layers.where (ArgSpec(args=['condition'], varargs=None, keywords=None, defaults=None), ('document', '3126e3039e752ce26077f1efaca355c6'))
237238
paddle.fluid.layers.data (ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)), ('document', 'adf285346e23316097f7789b572491e9'))
238239
paddle.fluid.layers.open_files (ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)), ('document', 'cf12066a3139026119f97f9d4381a1bd'))
239240
paddle.fluid.layers.read_file (ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None), ('document', 'b0a1c2fc51c27a106da28f3308c41f5e'))

paddle/fluid/operators/where_op.cc

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/operators/where_op.h"
16+
17+
namespace paddle {
18+
namespace operators {
19+
20+
class WhereOp : public framework::OperatorWithKernel {
21+
public:
22+
using framework::OperatorWithKernel::OperatorWithKernel;
23+
24+
void InferShape(framework::InferShapeContext* ctx) const override {
25+
PADDLE_ENFORCE(ctx->HasInput("Condition"),
26+
"Input(Condition) of WhereOp should not be null.");
27+
PADDLE_ENFORCE(
28+
ctx->GetInputDim("Condition").size() >= 1,
29+
"Input(Condition) should have number of dimension at least 1");
30+
PADDLE_ENFORCE(ctx->HasOutput("Out"),
31+
"Output(OUt) of WhereOp should not be null.");
32+
ctx->SetOutputDim("Out", {-1, ctx->GetInputDim("Condition").size()});
33+
}
34+
35+
protected:
36+
framework::OpKernelType GetExpectedKernelType(
37+
const framework::ExecutionContext& ctx) const override {
38+
auto output_type = framework::proto::VarType::INT64;
39+
return framework::OpKernelType(output_type, ctx.device_context());
40+
}
41+
};
42+
43+
class WhereOpMaker : public framework::OpProtoAndCheckerMaker {
44+
public:
45+
void Make() override {
46+
AddInput("Condition", "A bool tensor whose rank is at least 1");
47+
AddOutput("Out", "An int64 tensor of rank 2");
48+
AddComment(R"DOC(
49+
Return a int64 tensor with rank 2, specifying the coordinate of true element in `Condition`.
50+
)DOC");
51+
}
52+
};
53+
} // namespace operators
54+
} // namespace paddle
55+
56+
namespace ops = paddle::operators;
57+
REGISTER_OP_WITHOUT_GRADIENT(where, ops::WhereOp, ops::WhereOpMaker);
58+
REGISTER_OP_CPU_KERNEL(where, ops::CPUWhereKernel<int64_t>);

paddle/fluid/operators/where_op.cu

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include <thrust/device_vector.h>
16+
#include "paddle/fluid/framework/ddim.h"
17+
#include "paddle/fluid/framework/op_registry.h"
18+
#include "paddle/fluid/operators/where_op.h"
19+
#include "paddle/fluid/platform/cuda_primitives.h"
20+
#include "paddle/fluid/platform/for_range.h"
21+
22+
namespace paddle {
23+
namespace operators {
24+
25+
using CUDADeviceContext = paddle::platform::CUDADeviceContext;
26+
27+
template <typename T>
28+
class CUDAWhereKernel : public framework::OpKernel<T> {
29+
public:
30+
void Compute(const framework::ExecutionContext& context) const override {
31+
auto* condition = context.Input<framework::Tensor>("Condition");
32+
auto* out = context.Output<framework::Tensor>("Out");
33+
34+
// TODO(zhoukunsheng): Should optimize to ensure GPU is faster than CPU.
35+
framework::Tensor cond_cpu;
36+
framework::TensorCopy(*condition, platform::CPUPlace(), &cond_cpu);
37+
38+
const bool* cond_data = cond_cpu.data<bool>();
39+
int64_t numel = cond_cpu.numel();
40+
auto dims = cond_cpu.dims();
41+
int rank = dims.size();
42+
43+
thrust::host_vector<int> h_true_index;
44+
for (int64_t i = 0; i < numel; i++) {
45+
if (cond_data[i]) {
46+
h_true_index.push_back(i);
47+
}
48+
}
49+
thrust::device_vector<int> d_true_index = h_true_index;
50+
int* ptr_true_index = thrust::raw_pointer_cast(d_true_index.data());
51+
52+
size_t true_num = h_true_index.size();
53+
54+
out->Resize(framework::make_ddim({static_cast<int64_t>(true_num), rank}));
55+
auto out_ptr = out->mutable_data<T>(context.GetPlace());
56+
57+
if (true_num == 0) {
58+
return;
59+
}
60+
61+
thrust::host_vector<int> h_stride(rank, 0);
62+
h_stride[rank - 1] = 1;
63+
for (int i = rank - 2; i >= 0; i--) {
64+
h_stride[i] = h_stride[i + 1] * dims[i + 1];
65+
}
66+
thrust::device_vector<int> d_stride = h_stride;
67+
int* ptr_stride = thrust::raw_pointer_cast(d_stride.data());
68+
69+
auto& dev_ctx = context.template device_context<CUDADeviceContext>();
70+
WhereFunctor<int*> functor(ptr_true_index, true_num, ptr_stride, rank,
71+
out_ptr);
72+
platform::ForRange<CUDADeviceContext> for_range(dev_ctx, true_num);
73+
for_range(functor);
74+
}
75+
};
76+
77+
} // namespace operators
78+
} // namespace paddle
79+
80+
namespace ops = paddle::operators;
81+
REGISTER_OP_CUDA_KERNEL(where, ops::CUDAWhereKernel<int64_t>);

paddle/fluid/operators/where_op.h

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
#include <functional>
17+
#include <vector>
18+
#include "paddle/fluid/framework/eigen.h"
19+
#include "paddle/fluid/framework/op_registry.h"
20+
#include "paddle/fluid/operators/math/math_function.h"
21+
#include "paddle/fluid/platform/for_range.h"
22+
23+
namespace paddle {
24+
namespace operators {
25+
26+
template <typename T>
27+
struct WhereFunctor {
28+
WhereFunctor(const T& true_index, int true_num, const T& stride, int rank,
29+
int64_t* out)
30+
: true_index_(true_index),
31+
true_num_(true_num),
32+
stride_(stride),
33+
rank_(rank),
34+
out_ptr_(out) {}
35+
36+
HOSTDEVICE void operator()(size_t idx) const {
37+
int index = true_index_[idx];
38+
for (int j = 0; j < rank_; j++) {
39+
out_ptr_[idx * rank_ + j] = index / stride_[j];
40+
index -= out_ptr_[idx * rank_ + j] * stride_[j];
41+
}
42+
}
43+
44+
const T true_index_;
45+
int true_num_;
46+
const T stride_;
47+
int rank_;
48+
int64_t* out_ptr_;
49+
};
50+
51+
using CPUDeviceContext = paddle::platform::CPUDeviceContext;
52+
53+
template <typename T>
54+
class CPUWhereKernel : public framework::OpKernel<T> {
55+
public:
56+
void Compute(const framework::ExecutionContext& context) const override {
57+
auto* condition = context.Input<framework::Tensor>("Condition");
58+
auto* out = context.Output<framework::Tensor>("Out");
59+
60+
const bool* cond_data = condition->data<bool>();
61+
auto numel = condition->numel();
62+
auto dims = condition->dims();
63+
const int rank = dims.size();
64+
65+
std::vector<int> true_index;
66+
for (auto i = 0; i < numel; i++) {
67+
if (cond_data[i]) {
68+
true_index.push_back(i);
69+
}
70+
}
71+
auto true_num = true_index.size();
72+
73+
out->Resize(framework::make_ddim({static_cast<int64_t>(true_num), rank}));
74+
auto out_ptr = out->mutable_data<T>(context.GetPlace());
75+
76+
if (true_num == 0) {
77+
return;
78+
}
79+
80+
std::vector<int> stride(rank);
81+
stride[rank - 1] = 1;
82+
for (int i = rank - 2; i >= 0; i--) {
83+
stride[i] = stride[i + 1] * dims[i + 1];
84+
}
85+
86+
auto& dev_ctx = context.template device_context<CPUDeviceContext>();
87+
WhereFunctor<int*> functor(true_index.data(), true_num, stride.data(), rank,
88+
out_ptr);
89+
platform::ForRange<CPUDeviceContext> for_range(dev_ctx, true_num);
90+
for_range(functor);
91+
}
92+
};
93+
94+
} // namespace operators
95+
} // namespace paddle

python/paddle/fluid/layers/nn.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@
200200
'pixel_shuffle',
201201
'fsp_matrix',
202202
'continuous_value_model',
203+
'where',
203204
]
204205

205206
kIgnoreIndex = -100
@@ -11341,3 +11342,38 @@ def continuous_value_model(input, cvm, use_cvm=True):
1134111342
outputs={'Y': [out]},
1134211343
attrs={"use_cvm": use_cvm})
1134311344
return out
11345+
11346+
11347+
def where(condition):
11348+
"""
11349+
Return an int64 tensor with rank 2, specifying the coordinate of true element in `condition`.
11350+
11351+
Output's first dimension is the number of true element, second dimension is rank(number of dimension) of `condition`.
11352+
If there is zero true element, then an empty tensor will be generated.
11353+
11354+
Args:
11355+
condition(Variable): A bool tensor with rank at least 1.
11356+
11357+
Returns:
11358+
Variable: The tensor variable storing a 2-D tensor.
11359+
11360+
Examples:
11361+
.. code-block:: python
11362+
11363+
# condition is a tensor [True, False, True]
11364+
out = fluid.layers.where(condition) # [[0], [2]]
11365+
11366+
# condition is a tensor [[True, False], [False, True]]
11367+
out = fluid.layers.where(condition) # [[0, 0], [1, 1]]
11368+
11369+
# condition is a tensor [False, False, False]
11370+
out = fluid.layers.where(condition) # [[]]
11371+
"""
11372+
helper = LayerHelper("where", **locals())
11373+
11374+
out = helper.create_variable_for_type_inference(
11375+
dtype=core.VarDesc.VarType.INT64)
11376+
11377+
helper.append_op(
11378+
type='where', inputs={'Condition': condition}, outputs={'Out': [out]})
11379+
return out
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import print_function
16+
17+
import unittest
18+
import numpy as np
19+
from op_test import OpTest
20+
import paddle.fluid.core as core
21+
from paddle.fluid.op import Operator
22+
23+
24+
class TestWhereOp(OpTest):
25+
def setUp(self):
26+
self.op_type = "where"
27+
self.init_config()
28+
29+
def test_check_output(self):
30+
self.check_output()
31+
32+
def init_config(self):
33+
self.inputs = {'Condition': np.array([True, False, True]), }
34+
35+
self.outputs = {'Out': np.array([[0], [2]], dtype='int64')}
36+
37+
38+
class TestAllFalse(unittest.TestCase):
39+
def setUp(self):
40+
self.op_type = "where"
41+
self.init_config()
42+
43+
def check_with_place(self, place):
44+
scope = core.Scope()
45+
condition = scope.var('Condition').get_tensor()
46+
condition.set(self.cond_data, place)
47+
48+
out = scope.var("Out").get_tensor()
49+
out.set(np.full(self.shape, 0).astype('int64'), place)
50+
51+
op = Operator("where", Condition="Condition", Out="Out")
52+
op.run(scope, place)
53+
54+
out_array = np.array(out)
55+
self.assertTrue((out_array == self.out_data).all())
56+
57+
def init_config(self):
58+
self.cond_data = np.array([False, False, False])
59+
self.shape = (3, 1)
60+
self.out_data = np.array([], dtype='int64')
61+
62+
def test_all_false(self):
63+
self.check_with_place(core.CPUPlace())
64+
65+
if core.is_compiled_with_cuda():
66+
self.check_with_place(core.CUDAPlace(0))
67+
68+
69+
class TestRank2(TestWhereOp):
70+
def init_config(self):
71+
self.inputs = {'Condition': np.array([[True, False], [False, True]]), }
72+
73+
self.outputs = {'Out': np.array([[0, 0], [1, 1]], dtype='int64')}
74+
75+
76+
class TestRank3(TestWhereOp):
77+
def init_config(self):
78+
self.inputs = {
79+
'Condition': np.array([[[True, False], [False, True]],
80+
[[False, True], [True, False]],
81+
[[False, False], [False, True]]]),
82+
}
83+
84+
self.outputs = {
85+
'Out': np.array(
86+
[[0, 0, 0], [0, 1, 1], [1, 0, 1], [1, 1, 0], [2, 1, 1]],
87+
dtype='int64')
88+
}
89+
90+
91+
if __name__ == "__main__":
92+
unittest.main()

0 commit comments

Comments
 (0)