Commit 8dfd51d

mdouze authored and facebook-github-bot committed
Moved pytorch interop code to contrib
Summary: The pytorch interop code was in a test until now. However, it is better if people can rely on it to be updated when the API is updated. Therefore, we move it into contrib. Also added a README.md

Reviewed By: wickedfoo

Differential Revision: D23392962

fbshipit-source-id: 9b7c0e388a7ea3c0b73dc0018322138f49191673
1 parent f849680 commit 8dfd51d

File tree

3 files changed: +148 -96 lines changed

contrib/README.md

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
# The contrib modules

The contrib directory contains helper modules for Faiss for various tasks.

## Code structure

The contrib directory gets compiled into the module faiss.contrib.
Note that although some of the modules may depend on additional modules (e.g. GPU Faiss, pytorch, hdf5), they are not necessarily compiled in, to avoid adding dependencies. It is the user's responsibility to provide them.

In contrib, we are progressively dropping python2 support.

## List of contrib modules

### rpc.py

A very simple Remote Procedure Call library, where function parameters and results are pickled, for use with client_server.py.

### client_server.py

The server handles requests to a Faiss index. The client calls the remote index.
This is mainly intended to shard datasets over several machines, see [Distributed index](https://github.com/facebookresearch/faiss/wiki/Indexes-that-do-not-fit-in-RAM#distributed-index).
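For illustration, a hypothetical sketch of a sharded setup. The entry points `run_index_server` and `ClientIndex` and their signatures are assumptions, not confirmed by this commit:

```python
# Hypothetical sketch -- run_index_server / ClientIndex are assumed names.
import numpy as np
import faiss
from faiss.contrib.client_server import run_index_server, ClientIndex

SERVER = False  # set to True on the machines that serve a shard

if SERVER:
    # serve one shard of the dataset; this call blocks
    run_index_server(faiss.read_index('shard0.index'), 12010)
else:
    # on the client, the remote shards behave like one big local index
    client = ClientIndex([('server0', 12010), ('server1', 12010)])
    xq = np.random.rand(100, 64).astype('float32')
    D, I = client.search(xq, 10)
```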
### ondisk.py

Encloses the main logic to merge indexes into an on-disk index.
See [On-disk storage](https://github.com/facebookresearch/faiss/wiki/Indexes-that-do-not-fit-in-RAM#on-disk-storage)
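As a sketch of the merge step; the `merge_ondisk(trained_index, shard_fnames, ivfdata_fname)` name and signature follow the on-disk demo and should be treated as assumptions:

```python
# Sketch, assuming merge_ondisk(trained_index, shard_fnames, ivfdata_fname)
import faiss
from faiss.contrib.ondisk import merge_ondisk

index = faiss.read_index('trained.index')          # trained, empty IVF index
block_fnames = ['block_0.index', 'block_1.index']  # shards with vectors added
merge_ondisk(index, block_fnames, 'merged.ivfdata')
faiss.write_index(index, 'populated.index')        # inverted lists stay on disk
```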
### exhaustive_search.py

Computes the ground-truth search results for a dataset that possibly does not fit in RAM. Uses GPU if available.
Tested in `tests/test_contrib.TestComputeGT`
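A minimal sketch, assuming the module exposes a `knn_ground_truth(xq, db_iterator, k)` helper (name and signature are assumptions):

```python
import numpy as np
from faiss.contrib.exhaustive_search import knn_ground_truth  # assumed name

xq = np.random.rand(100, 32).astype('float32')   # queries

def block_iterator():
    # yield the database in blocks so it never has to fit in RAM
    for _ in range(10):
        yield np.random.rand(10000, 32).astype('float32')

D, I = knn_ground_truth(xq, block_iterator(), 10)
```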
### gpu.py

(requires GPU Faiss)

Interoperability functions for pytorch and Faiss: pass GPU data without copying back to CPU.
Tested in `gpu/test/test_pytorch_faiss`
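A minimal usage sketch of the interop path added in this commit (GPU Faiss and a CUDA-enabled pytorch build assumed):

```python
import torch
import faiss
from faiss.contrib.pytorch_tensors import search_raw_array_pytorch

res = faiss.StandardGpuResources()
xb = torch.rand(10000, 64, device='cuda')   # database vectors stay on GPU
xq = torch.rand(100, 64, device='cuda')     # query vectors stay on GPU
# brute-force k-nearest-neighbor search without copying back to CPU
D, I = search_raw_array_pytorch(res, xb, xq, 10)
```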
### datasets.py

(may require h5py)

Definition of how to access data for some standard datasets.

contrib/pytorch_tensors.py

Lines changed: 102 additions & 0 deletions
@@ -0,0 +1,102 @@
import faiss
import torch

def swig_ptr_from_FloatTensor(x):
    """ gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
    assert x.is_contiguous()
    assert x.dtype == torch.float32
    # the storage offset is in elements; float32 is 4 bytes wide
    return faiss.cast_integer_to_float_ptr(
        x.storage().data_ptr() + x.storage_offset() * 4)

def swig_ptr_from_LongTensor(x):
    """ gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
    assert x.is_contiguous()
    assert x.dtype == torch.int64, 'dtype=%s' % x.dtype
    # the storage offset is in elements; int64 is 8 bytes wide
    return faiss.cast_integer_to_long_ptr(
        x.storage().data_ptr() + x.storage_offset() * 8)


def search_index_pytorch(index, x, k, D=None, I=None):
    """call the search function of an index with pytorch tensor I/O (CPU
    and GPU supported)"""
    assert x.is_contiguous()
    n, d = x.size()
    assert d == index.d

    if D is None:
        D = torch.empty((n, k), dtype=torch.float32, device=x.device)
    else:
        assert D.size() == (n, k)

    if I is None:
        I = torch.empty((n, k), dtype=torch.int64, device=x.device)
    else:
        assert I.size() == (n, k)
    torch.cuda.synchronize()
    xptr = swig_ptr_from_FloatTensor(x)
    Iptr = swig_ptr_from_LongTensor(I)
    Dptr = swig_ptr_from_FloatTensor(D)
    index.search_c(n, xptr,
                   k, Dptr, Iptr)
    torch.cuda.synchronize()
    return D, I


def search_raw_array_pytorch(res, xb, xq, k, D=None, I=None,
                             metric=faiss.METRIC_L2):
    """search xq in xb, without building an index"""
    assert xb.device == xq.device

    nq, d = xq.size()
    if xq.is_contiguous():
        xq_row_major = True
    elif xq.t().is_contiguous():
        xq = xq.t()    # I initially wrote xq:t(), Lua is still haunting me :-)
        xq_row_major = False
    else:
        raise TypeError('matrix should be row or column-major')

    xq_ptr = swig_ptr_from_FloatTensor(xq)

    nb, d2 = xb.size()
    assert d2 == d
    if xb.is_contiguous():
        xb_row_major = True
    elif xb.t().is_contiguous():
        xb = xb.t()
        xb_row_major = False
    else:
        raise TypeError('matrix should be row or column-major')
    xb_ptr = swig_ptr_from_FloatTensor(xb)

    if D is None:
        D = torch.empty(nq, k, device=xb.device, dtype=torch.float32)
    else:
        assert D.shape == (nq, k)
        assert D.device == xb.device

    if I is None:
        I = torch.empty(nq, k, device=xb.device, dtype=torch.int64)
    else:
        assert I.shape == (nq, k)
        assert I.device == xb.device

    D_ptr = swig_ptr_from_FloatTensor(D)
    I_ptr = swig_ptr_from_LongTensor(I)

    # fill in the GpuDistanceParams struct that drives the brute-force kNN
    args = faiss.GpuDistanceParams()
    args.metric = metric
    args.k = k
    args.dims = d
    args.vectors = xb_ptr
    args.vectorsRowMajor = xb_row_major
    args.numVectors = nb
    args.queries = xq_ptr
    args.queriesRowMajor = xq_row_major
    args.numQueries = nq
    args.outDistances = D_ptr
    args.outIndices = I_ptr
    faiss.bfKnn(res, args)

    return D, I
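For reference, a minimal usage sketch of `search_index_pytorch` with illustrative data; a CUDA-enabled pytorch build is assumed, since the helper calls `torch.cuda.synchronize()`:

```python
import torch
import faiss
from faiss.contrib.pytorch_tensors import search_index_pytorch

d, nb, nq, k = 64, 10000, 100, 10
index = faiss.IndexFlatL2(d)
index.add(torch.rand(nb, d).numpy())       # populate the index via numpy
xq = torch.rand(nq, d)                     # contiguous float32 queries
D, I = search_index_pytorch(index, xq, k)  # results come back as torch tensors
```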

faiss/gpu/test/test_pytorch_faiss.py

Lines changed: 1 addition & 96 deletions
@@ -10,102 +10,7 @@
import faiss
import torch

(96 deleted lines: the swig_ptr_from_FloatTensor, swig_ptr_from_LongTensor,
search_index_pytorch and search_raw_array_pytorch helpers, moved verbatim to
contrib/pytorch_tensors.py above)

from faiss.contrib.pytorch_tensors import search_index_pytorch, search_raw_array_pytorch

def to_column_major(x):
    if hasattr(torch, 'contiguous_format'):
