Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit e813e0e

Browse files
committed
Add optimized functions for linear search within byte arrays
In a similar vein to b6ef167, add pg_lfind8() and pg_lfind8_le() to search for bytes equal to, or less than or equal to, a given byte, respectively. To abstract away platform details, add helper functions and typedefs to simd.h.

John Naylor and Nathan Bossart, per suggestion from Andres Freund.

Discussion: https://www.postgresql.org/message-id/CAFBsxsGzaaGLF%3DNuq61iRXTyspbO9rOjhSqFN%3DV6ozzmta5mXg%40mail.gmail.com
1 parent bcc8b14 commit e813e0e

File tree

6 files changed

+358
-10
lines changed

6 files changed

+358
-10
lines changed

src/include/port/pg_lfind.h

Lines changed: 66 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
/*-------------------------------------------------------------------------
22
*
33
* pg_lfind.h
4-
* Optimized linear search routines.
4+
* Optimized linear search routines using SIMD intrinsics where
5+
* available.
56
*
67
* Copyright (c) 2022, PostgreSQL Global Development Group
78
*
@@ -15,6 +16,70 @@
1516

1617
#include "port/simd.h"
1718

19+
/*
20+
* pg_lfind8
21+
*
22+
* Return true if there is an element in 'base' that equals 'key', otherwise
23+
* return false.
24+
*/
25+
static inline bool
26+
pg_lfind8(uint8 key, uint8 *base, uint32 nelem)
27+
{
28+
uint32 i;
29+
30+
/* round down to multiple of vector length */
31+
uint32 tail_idx = nelem & ~(sizeof(Vector8) - 1);
32+
Vector8 chunk;
33+
34+
for (i = 0; i < tail_idx; i += sizeof(Vector8))
35+
{
36+
vector8_load(&chunk, &base[i]);
37+
if (vector8_has(chunk, key))
38+
return true;
39+
}
40+
41+
/* Process the remaining elements one at a time. */
42+
for (; i < nelem; i++)
43+
{
44+
if (key == base[i])
45+
return true;
46+
}
47+
48+
return false;
49+
}
50+
51+
/*
52+
* pg_lfind8_le
53+
*
54+
* Return true if there is an element in 'base' that is less than or equal to
55+
* 'key', otherwise return false.
56+
*/
57+
static inline bool
58+
pg_lfind8_le(uint8 key, uint8 *base, uint32 nelem)
59+
{
60+
uint32 i;
61+
62+
/* round down to multiple of vector length */
63+
uint32 tail_idx = nelem & ~(sizeof(Vector8) - 1);
64+
Vector8 chunk;
65+
66+
for (i = 0; i < tail_idx; i += sizeof(Vector8))
67+
{
68+
vector8_load(&chunk, &base[i]);
69+
if (vector8_has_le(chunk, key))
70+
return true;
71+
}
72+
73+
/* Process the remaining elements one at a time. */
74+
for (; i < nelem; i++)
75+
{
76+
if (base[i] <= key)
77+
return true;
78+
}
79+
80+
return false;
81+
}
82+
1883
/*
1984
* pg_lfind32
2085
*
@@ -26,7 +91,6 @@ pg_lfind32(uint32 key, uint32 *base, uint32 nelem)
2691
{
2792
uint32 i = 0;
2893

29-
/* Use SIMD intrinsics where available. */
3094
#ifdef USE_SSE2
3195

3296
/*

src/include/port/simd.h

Lines changed: 167 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,17 @@
88
*
99
* src/include/port/simd.h
1010
*
11+
* NOTES
12+
* - VectorN in this file refers to a register where the element operands
13+
* are N bits wide. The vector width is platform-specific, so users that care
14+
* about that will need to inspect "sizeof(VectorN)".
15+
*
1116
*-------------------------------------------------------------------------
1217
*/
1318
#ifndef SIMD_H
1419
#define SIMD_H
1520

21+
#if (defined(__x86_64__) || defined(_M_AMD64))
1622
/*
1723
* SSE2 instructions are part of the spec for the 64-bit x86 ISA. We assume
1824
* that compilers targeting this architecture understand SSE2 intrinsics.
@@ -22,9 +28,169 @@
2228
* will allow the use of intrinsics that haven't been enabled at compile
2329
* time.
2430
*/
25-
#if (defined(__x86_64__) || defined(_M_AMD64))
2631
#include <emmintrin.h>
2732
#define USE_SSE2
33+
typedef __m128i Vector8;
34+
35+
#else
36+
/*
37+
* If no SIMD instructions are available, we can in some cases emulate vector
38+
* operations using bitwise operations on unsigned integers.
39+
*/
40+
#define USE_NO_SIMD
41+
typedef uint64 Vector8;
42+
#endif
43+
44+
45+
/* load/store operations */
46+
static inline void vector8_load(Vector8 *v, const uint8 *s);
47+
48+
/* assignment operations */
49+
static inline Vector8 vector8_broadcast(const uint8 c);
50+
51+
/* element-wise comparisons to a scalar */
52+
static inline bool vector8_has(const Vector8 v, const uint8 c);
53+
static inline bool vector8_has_zero(const Vector8 v);
54+
static inline bool vector8_has_le(const Vector8 v, const uint8 c);
55+
56+
57+
/*
58+
* Load a chunk of memory into the given vector.
59+
*/
60+
static inline void
61+
vector8_load(Vector8 *v, const uint8 *s)
62+
{
63+
#if defined(USE_SSE2)
64+
*v = _mm_loadu_si128((const __m128i *) s);
65+
#else
66+
memcpy(v, s, sizeof(Vector8));
2867
#endif
68+
}
69+
70+
71+
/*
72+
* Create a vector with all elements set to the same value.
73+
*/
74+
static inline Vector8
75+
vector8_broadcast(const uint8 c)
76+
{
77+
#if defined(USE_SSE2)
78+
return _mm_set1_epi8(c);
79+
#else
80+
return ~UINT64CONST(0) / 0xFF * c;
81+
#endif
82+
}
83+
84+
/*
85+
* Return true if any elements in the vector are equal to the given scalar.
86+
*/
87+
static inline bool
88+
vector8_has(const Vector8 v, const uint8 c)
89+
{
90+
bool result;
91+
92+
/* pre-compute the result for assert checking */
93+
#ifdef USE_ASSERT_CHECKING
94+
bool assert_result = false;
95+
96+
for (int i = 0; i < sizeof(Vector8); i++)
97+
{
98+
if (((const uint8 *) &v)[i] == c)
99+
{
100+
assert_result = true;
101+
break;
102+
}
103+
}
104+
#endif /* USE_ASSERT_CHECKING */
105+
106+
#if defined(USE_NO_SIMD)
107+
/* any bytes in v equal to c will evaluate to zero via XOR */
108+
result = vector8_has_zero(v ^ vector8_broadcast(c));
109+
#elif defined(USE_SSE2)
110+
result = _mm_movemask_epi8(_mm_cmpeq_epi8(v, vector8_broadcast(c)));
111+
#endif
112+
113+
Assert(assert_result == result);
114+
return result;
115+
}
116+
117+
/*
118+
* Convenience function equivalent to vector8_has(v, 0)
119+
*/
120+
static inline bool
121+
vector8_has_zero(const Vector8 v)
122+
{
123+
#if defined(USE_NO_SIMD)
124+
/*
125+
* We cannot call vector8_has() here, because that would lead to a circular
126+
* definition.
127+
*/
128+
return vector8_has_le(v, 0);
129+
#elif defined(USE_SSE2)
130+
return vector8_has(v, 0);
131+
#endif
132+
}
133+
134+
/*
135+
* Return true if any elements in the vector are less than or equal to the
136+
* given scalar.
137+
*/
138+
static inline bool
139+
vector8_has_le(const Vector8 v, const uint8 c)
140+
{
141+
bool result = false;
142+
#if defined(USE_SSE2)
143+
__m128i sub;
144+
#endif
145+
146+
/* pre-compute the result for assert checking */
147+
#ifdef USE_ASSERT_CHECKING
148+
bool assert_result = false;
149+
150+
for (int i = 0; i < sizeof(Vector8); i++)
151+
{
152+
if (((const uint8 *) &v)[i] <= c)
153+
{
154+
assert_result = true;
155+
break;
156+
}
157+
}
158+
#endif /* USE_ASSERT_CHECKING */
159+
160+
#if defined(USE_NO_SIMD)
161+
162+
/*
163+
* To find bytes <= c, we can use bitwise operations to find bytes < c+1,
164+
* but it only works if c+1 <= 128 and if the highest bit in v is not set.
165+
* Adapted from
166+
* https://graphics.stanford.edu/~seander/bithacks.html#HasLessInWord
167+
*/
168+
if ((int64) v >= 0 && c < 0x80)
169+
result = (v - vector8_broadcast(c + 1)) & ~v & vector8_broadcast(0x80);
170+
else
171+
{
172+
/* one byte at a time */
173+
for (int i = 0; i < sizeof(Vector8); i++)
174+
{
175+
if (((const uint8 *) &v)[i] <= c)
176+
{
177+
result = true;
178+
break;
179+
}
180+
}
181+
}
182+
#elif defined(USE_SSE2)
183+
184+
/*
185+
* Use saturating subtraction to find bytes <= c, which will present as
186+
* NUL bytes in 'sub'.
187+
*/
188+
sub = _mm_subs_epu8(v, vector8_broadcast(c));
189+
result = vector8_has_zero(sub);
190+
#endif
191+
192+
Assert(assert_result == result);
193+
return result;
194+
}
29195

30196
#endif /* SIMD_H */

src/test/modules/test_lfind/expected/test_lfind.out

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,21 @@ CREATE EXTENSION test_lfind;
44
-- the operations complete without crashing or hanging and that none of their
55
-- internal sanity tests fail.
66
--
7-
SELECT test_lfind();
8-
test_lfind
9-
------------
7+
SELECT test_lfind8();
8+
test_lfind8
9+
-------------
10+
11+
(1 row)
12+
13+
SELECT test_lfind8_le();
14+
test_lfind8_le
15+
----------------
16+
17+
(1 row)
18+
19+
SELECT test_lfind32();
20+
test_lfind32
21+
--------------
1022

1123
(1 row)
1224

src/test/modules/test_lfind/sql/test_lfind.sql

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,6 @@ CREATE EXTENSION test_lfind;
55
-- the operations complete without crashing or hanging and that none of their
66
-- internal sanity tests fail.
77
--
8-
SELECT test_lfind();
8+
SELECT test_lfind8();
9+
SELECT test_lfind8_le();
10+
SELECT test_lfind32();

src/test/modules/test_lfind/test_lfind--1.0.sql

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,14 @@
33
-- complain if script is sourced in psql, rather than via CREATE EXTENSION
44
\echo Use "CREATE EXTENSION test_lfind" to load this file. \quit
55

6-
CREATE FUNCTION test_lfind()
6+
CREATE FUNCTION test_lfind32()
7+
RETURNS pg_catalog.void
8+
AS 'MODULE_PATHNAME' LANGUAGE C;
9+
10+
CREATE FUNCTION test_lfind8()
11+
RETURNS pg_catalog.void
12+
AS 'MODULE_PATHNAME' LANGUAGE C;
13+
14+
CREATE FUNCTION test_lfind8_le()
715
RETURNS pg_catalog.void
816
AS 'MODULE_PATHNAME' LANGUAGE C;

0 commit comments

Comments
 (0)