|
1 | | -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. |
| 1 | +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. |
2 | 2 | # |
3 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); |
4 | 4 | # you may not use this file except in compliance with the License. |
|
17 | 17 | from op_test import OpTest |
18 | 18 |
|
19 | 19 |
|
20 | | -class TestScaleOp(OpTest): |
| 20 | +class TestHashOp(OpTest): |
21 | 21 | def setUp(self): |
22 | 22 | self.op_type = "hash" |
23 | 23 | self.init_test_case() |
24 | 24 | self.inputs = {'X': (self.in_seq, self.lod)} |
25 | | - self.attrs = {'num_hash': 4, 'mod_by': 10000} |
| 25 | + self.attrs = {'num_hash': 2, 'mod_by': 10000} |
26 | 26 | self.outputs = {'Out': (self.out_seq, self.lod)} |
27 | 27 |
|
28 | 28 | def init_test_case(self): |
29 | | - np.random.seed = 1 |
30 | | - self.in_seq = np.random.randint(0, 10, (30, 1)).astype("int32") |
31 | | - self.lod = [[9, 4, 11, 6]] |
32 | | - # self.out_seq = np.ones([30, 4, 1], dtype=np.int32) |
33 | | - self.out_seq = [ |
34 | | - [[9662], [9217], [1129], [8487]], [[9662], [9217], [1129], [8487]], |
35 | | - [[8310], [1327], [1654], [4567]], [[6897], [3218], [2013], [1241]], |
36 | | - [[9407], [6715], [6949], [8094]], [[8473], [694], [5142], [2479]], |
37 | | - [[8310], [1327], [1654], [4567]], [[6897], [3218], [2013], [1241]], |
38 | | - [[4372], [9456], [8204], [6695]], [[6897], [3218], [2013], [1241]], |
39 | | - [[8473], [694], [5142], [2479]], [[4372], [9456], [8204], [6695]], |
40 | | - [[4372], [9456], [8204], [6695]], [[8473], [694], [5142], [2479]], |
41 | | - [[9407], [6715], [6949], [8094]], [[9369], [4525], [8935], [9210]], |
42 | | - [[4372], [9456], [8204], [6695]], [[4372], [9456], [8204], [6695]], |
43 | | - [[9369], [4525], [8935], [9210]], [[6897], [3218], [2013], [1241]], |
44 | | - [[9038], [7951], [5953], [8657]], [[9407], [6715], [6949], [8094]], |
45 | | - [[9662], [9217], [1129], [8487]], [[9369], [4525], [8935], [9210]], |
46 | | - [[9038], [7951], [5953], [8657]], [[9662], [9217], [1129], [8487]], |
47 | | - [[9369], [4525], [8935], [9210]], [[1719], [5986], [9919], [3421]], |
48 | | - [[4372], [9456], [8204], [6695]], [[9038], [7951], [5953], [8657]] |
49 | | - ] |
| 29 | + np.random.seed(1) |
| 30 | + self.in_seq = np.random.randint(0, 10, (8, 1)).astype("int32") |
| 31 | + self.lod = [[2, 6]] |
| 32 | + self.out_seq = [[[3481], [7475]], [[1719], [5986]], [[8473], [694]], |
| 33 | + [[3481], [7475]], [[4372], [9456]], [[4372], [9456]], |
| 34 | + [[6897], [3218]], [[9038], [7951]]] |
| 35 | + self.out_seq = np.array(self.out_seq) |
| 36 | + |
| 37 | + def test_check_output(self): |
| 38 | + self.check_output() |
| 39 | + |
| 40 | + |
| 41 | +class TestHashNotLoDOp(TestHashOp): |
| 42 | + def setUp(self): |
| 43 | + self.op_type = "hash" |
| 44 | + self.init_test_case() |
| 45 | + self.inputs = {'X': self.in_seq} |
| 46 | + self.attrs = {'num_hash': 2, 'mod_by': 10000} |
| 47 | + self.outputs = {'Out': self.out_seq} |
| 48 | + |
| 49 | + def init_test_case(self): |
| 50 | + np.random.seed(1) |
| 51 | + self.in_seq = np.random.randint(0, 10, (8, 1)).astype("int32") |
| 52 | + self.out_seq = [[[3481], [7475]], [[1719], [5986]], [[8473], [694]], |
| 53 | + [[3481], [7475]], [[4372], [9456]], [[4372], [9456]], |
| 54 | + [[6897], [3218]], [[9038], [7951]]] |
50 | 55 | self.out_seq = np.array(self.out_seq) |
51 | 56 |
|
52 | 57 | def test_check_output(self): |
|
0 commit comments