Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit cec65f1

Browse files
authored
support tensor parallel (OpenNMT#1599)
* tensor parallel support * add docs * small fix * fix adding bias multiple times in layer output.
1 parent 1427722 commit cec65f1

43 files changed

Lines changed: 1313 additions & 42 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

CMakeLists.txt

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ option(ENABLE_PROFILING "Compile with profiling support" OFF)
2020
option(BUILD_CLI "Compile the clients" ON)
2121
option(BUILD_TESTS "Compile the tests" OFF)
2222
option(BUILD_SHARED_LIBS "Build shared libraries" ON)
23+
option(WITH_TENSOR_PARALLEL "Compile with NCCL and MPI backend" OFF)
2324

2425
if(ENABLE_PROFILING)
2526
message(STATUS "Enable profiling support")
@@ -179,6 +180,8 @@ set(SOURCES
179180
src/ops/topp_mask.cc
180181
src/ops/topp_mask_cpu.cc
181182
src/ops/transpose.cc
183+
src/ops/nccl_ops.cc
184+
src/ops/nccl_ops_cpu.cc
182185
src/padder.cc
183186
src/profiler.cc
184187
src/random.cc
@@ -191,7 +194,7 @@ set(SOURCES
191194
src/utils.cc
192195
src/vocabulary.cc
193196
src/vocabulary_map.cc
194-
)
197+
)
195198
set(LIBRARIES
196199
${CMAKE_THREAD_LIBS_INIT}
197200
spdlog::spdlog_header_only
@@ -419,6 +422,24 @@ endif()
419422

420423
if (WITH_CUDA)
421424
find_package(CUDA 11.0 REQUIRED)
425+
list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake)
426+
if (WITH_TENSOR_PARALLEL)
427+
find_package(MPI REQUIRED)
428+
find_package(NCCL REQUIRED)
429+
include_directories(${NCCL_INCLUDE_DIR})
430+
include_directories(${MPI_INCLUDE_PATH})
431+
if(CUDA_DYNAMIC_LOADING)
432+
list(APPEND SOURCES src/cuda/mpi_stub.cc)
433+
list(APPEND SOURCES src/cuda/nccl_stub.cc)
434+
add_definitions(-DCT2_WITH_CUDA_DYNAMIC_LOADING)
435+
else ()
436+
list(APPEND LIBRARIES ${NCCL_LIBRARY})
437+
list(APPEND LIBRARIES ${MPI_LIBRARIES})
438+
endif ()
439+
add_definitions(-DCT2_WITH_TENSOR_PARALLEL)
440+
endif ()
441+
include_directories(${CUDA_TOOLKIT_ROOT_DIR}/include)
442+
422443
add_definitions(-DCT2_WITH_CUDA)
423444
if(MSVC)
424445
if(BUILD_SHARED_LIBS)
@@ -522,7 +543,8 @@ if (WITH_CUDA)
522543
src/ops/topk_gpu.cu
523544
src/ops/topp_mask_gpu.cu
524545
src/ops/quantize_gpu.cu
525-
)
546+
src/ops/nccl_ops_gpu.cu
547+
)
526548
elseif(WITH_CUDNN)
527549
message(FATAL_ERROR "WITH_CUDNN=ON requires WITH_CUDA=ON")
528550
else()
@@ -546,6 +568,10 @@ target_include_directories(${PROJECT_NAME} BEFORE
546568
PRIVATE ${PRIVATE_INCLUDE_DIRECTORIES}
547569
)
548570

571+
if (WITH_TENSOR_PARALLEL AND CUDA_DYNAMIC_LOADING)
572+
target_compile_options(${PROJECT_NAME} PRIVATE -DOMPI_SKIP_MPICXX)
573+
endif()
574+
549575
if(BUILD_TESTS)
550576
add_subdirectory(tests)
551577
endif()
@@ -587,6 +613,11 @@ configure_file(cmake/${PROJECT_NAME}Config.cmake
587613
COPYONLY
588614
)
589615

616+
configure_file(cmake/FindNCCL.cmake
617+
"${CMAKE_CURRENT_BINARY_DIR}/${PROJECT_NAME}/FindNCCL.cmake"
618+
COPYONLY
619+
)
620+
590621
set(ConfigPackageLocation ${CMAKE_INSTALL_LIBDIR}/cmake/${PROJECT_NAME})
591622

592623
if(BUILD_SHARED_LIBS)
@@ -603,6 +634,7 @@ endif()
603634
install(
604635
FILES
605636
cmake/${PROJECT_NAME}Config.cmake
637+
cmake/FindNCCL.cmake
606638
"${CMAKE_CURRENT_BINARY_DIR}/${PROJECT_NAME}/${PROJECT_NAME}ConfigVersion.cmake"
607639
DESTINATION
608640
${ConfigPackageLocation}

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ The project is production-oriented and comes with [backward compatibility guaran
3434
* **Lightweight on disk**<br/>Quantization can make the models 4 times smaller on disk with minimal accuracy loss.
3535
* **Simple integration**<br/>The project has few dependencies and exposes simple APIs in [Python](https://opennmt.net/CTranslate2/python/overview.html) and C++ to cover most integration needs.
3636
* **Configurable and interactive decoding**<br/>[Advanced decoding features](https://opennmt.net/CTranslate2/decoding.html) allow autocompleting a partial sequence and returning alternatives at a specific location in the sequence.
37+
* **Tensor parallelism**<br/>Models can be split across multiple GPUs for distributed inference.
3738

3839
Some of these features are difficult to achieve with standard deep learning frameworks and are the motivation for this project.
3940

cmake/FindNCCL.cmake

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# Find the NCCL libraries
2+
#
3+
# The following variables are optionally searched for defaults
4+
# NCCL_ROOT_DIR: Base directory where all NCCL components are found
5+
#
6+
# The following are set after configuration is done:
7+
# NCCL_FOUND
8+
# NCCL_INCLUDE_DIR
9+
# NCCL_LIBRARY
10+
11+
find_path(NCCL_INCLUDE_DIR NAMES nccl.h
12+
PATHS ${NCCL_ROOT_DIR}/include
13+
)
14+
15+
find_library(NCCL_LIBRARY NAMES nccl
16+
PATHS ${NCCL_ROOT_DIR}/lib ${NCCL_ROOT_DIR}/lib64)
17+
18+
include(FindPackageHandleStandardArgs)
19+
find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIR
20+
NCCL_LIBRARY)
21+
22+
if (NCCL_FOUND)
23+
message(STATUS "Found NCCL (include: ${NCCL_INCLUDE_DIR}, library:
24+
${NCCL_LIBRARY})")
25+
mark_as_advanced(NCCL_INCLUDE_DIR NCCL_LIBRARY)
26+
set(NCCL_VERSION "${NCCL_MAJOR}.${NCCL_MINOR}.${NCCL_PATCH}")
27+
28+
endif ()

docker/Dockerfile

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,16 @@ RUN wget -q https://github.com/oneapi-src/oneDNN/archive/refs/tags/v${ONEDNN_VER
3535
cd .. && \
3636
rm -r oneDNN-*
3737

38+
ENV OPENMPI_VERSION=4.1.6
39+
RUN wget -q https://download.open-mpi.org/release/open-mpi/v4.1/openmpi-${OPENMPI_VERSION}.tar.bz2 && \
40+
tar xf *.tar.bz2 && \
41+
rm *.tar.bz2 && \
42+
cd openmpi-* && \
43+
./configure && \
44+
make -j$(nproc) install && \
45+
cd .. && \
46+
rm -r openmpi-*
47+
3848
COPY third_party third_party
3949
COPY cli cli
4050
COPY include include
@@ -50,13 +60,14 @@ ENV CUDA_NVCC_FLAGS=${CUDA_NVCC_FLAGS:-"-Xfatbin=-compress-all"}
5060
ARG CUDA_ARCH_LIST
5161
ENV CUDA_ARCH_LIST=${CUDA_ARCH_LIST:-"Common"}
5262
ENV CTRANSLATE2_ROOT=/opt/ctranslate2
63+
ENV LD_LIBRARY_PATH=/usr/local/lib/:${LD_LIBRARY_PATH}
5364

54-
RUN mkdir build && \
55-
cd build && \
65+
RUN mkdir build_tmp && \
66+
cd build_tmp && \
5667
cmake -DCMAKE_INSTALL_PREFIX=${CTRANSLATE2_ROOT} \
5768
-DWITH_CUDA=ON -DWITH_CUDNN=ON -DWITH_MKL=ON -DWITH_DNNL=ON -DOPENMP_RUNTIME=COMP \
5869
-DCMAKE_BUILD_TYPE=Release -DCMAKE_CXX_FLAGS="${CXX_FLAGS}" \
59-
-DCUDA_NVCC_FLAGS="${CUDA_NVCC_FLAGS}" -DCUDA_ARCH_LIST="${CUDA_ARCH_LIST}" .. && \
70+
-DCUDA_NVCC_FLAGS="${CUDA_NVCC_FLAGS}" -DCUDA_ARCH_LIST="${CUDA_ARCH_LIST}" -DWITH_TENSOR_PARALLEL=ON .. && \
6071
VERBOSE=1 make -j$(nproc) install
6172

6273
ENV LANG=en_US.UTF-8
@@ -74,6 +85,9 @@ RUN apt-get update && \
7485
apt-get install -y --no-install-recommends \
7586
libcublas-12-2 \
7687
libcudnn8=8.9.7.29-1+cuda12.2 \
88+
libnccl2=2.19.3-1+cuda12.2 \
89+
libopenmpi3=4.0.3-0ubuntu1 \
90+
openmpi-bin \
7791
libgomp1 \
7892
python3-pip \
7993
&& \

docs/parallel.md

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,50 @@ Parallelization with multiple Python threads is possible because all computation
4242
```
4343

4444
## Model and tensor parallelism
45+
Models used with [`Translator`](python/ctranslate2.Translator.rst) and [`Generator`](python/ctranslate2.Generator.rst) can be split across multiple GPUs.
46+
This is very useful when the model is too big to fit on a single GPU.
4547

46-
These types of parallelism are not yet implemented in CTranslate2.
48+
```python
49+
translator = ctranslate2.Translator(model_path, device="cuda", tensor_parallel=True)
50+
```
51+
52+
Setup environment:
53+
* Install [open-mpi](https://www.open-mpi.org/)
54+
* Configure open-mpi by creating a host file (for example, `hostfile`):
55+
```bash
56+
[ipaddress or dns] slots=nbGPU1
57+
[other ipaddress or dns] slots=NbGPU2
58+
```
59+
60+
Run:
61+
* Run the application as multiple processes with `mpirun` to use tensor parallelism:
62+
```bash
63+
mpirun -np nbGPUExpected -hostfile hostfile python3 script
64+
```
65+
66+
If you're trying to use tensor parallelism across multiple machines, some additional configuration is needed:
67+
* Make sure the Master and Slave machines can connect to each other over SSH using public-key authentication
68+
* Export all necessary environment variables from Master to Slave like the example below:
69+
```bash
70+
mpirun -x VIRTUAL_ENV_PROMPT -x PATH -x VIRTUAL_ENV -x _ -x LD_LIBRARY_PATH -np nbGPUExpected -hostfile hostfile python3 script
71+
```
72+
See the [open-mpi docs](https://www.open-mpi.org/doc/) for more information.
73+
74+
* In this mode, the application runs as multiple processes. To execute code only in the master process (rank 0), use:
75+
```python
76+
if ctranslate2.MpiInfo.getCurRank() == 0:
77+
print(...)
78+
```
79+
80+
```{note}
81+
Running a model in tensor parallel mode on a single machine can boost performance, but if the model is shared between multiple machines
82+
it could be slower because of the latency of the network connection.
83+
```
84+
85+
```{note}
86+
In tensor parallel mode, `inter_threads` is still supported to run multiple workers. However, `device_index` no longer has any effect
87+
because tensor parallel mode only checks for the GPUs available on the system and the number of GPUs you want to use.
88+
```
4789

4890
## Asynchronous execution
4991

include/ctranslate2/devices.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22

33
#include <stdexcept>
44
#include <string>
5+
#include <vector>
6+
#ifdef CT2_WITH_TENSOR_PARALLEL
7+
# include <nccl.h>
8+
#endif
59

610
namespace ctranslate2 {
711

@@ -45,4 +49,30 @@ namespace ctranslate2 {
4549
int _new_index;
4650
};
4751

52+
extern int my_rank;
53+
extern int local_rank;
54+
extern int n_ranks;
55+
56+
class ScopedMPISetter {
57+
public:
58+
ScopedMPISetter();
59+
~ScopedMPISetter();
60+
61+
static int getNRanks();
62+
static int getCurRank();
63+
static int getLocalRank();
64+
65+
#ifdef CT2_WITH_TENSOR_PARALLEL
66+
static ncclComm_t getNcclComm();
67+
#endif
68+
69+
static void finalize();
70+
71+
private:
72+
#ifdef CT2_WITH_TENSOR_PARALLEL
73+
static uint64_t getHostHash(const char *string);
74+
static void getHostName(char *hostname, int maxlen);
75+
static std::vector<ncclComm_t*> _nccl_comms;
76+
#endif
77+
};
4878
}

include/ctranslate2/layers/attention.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ namespace ctranslate2 {
4343
}
4444

4545
bool multi_query() const {
46-
return _num_heads_kv == 1;
46+
return _multi_query;
4747
}
4848

4949
static StorageView prepare_length_mask(const StorageView& lengths,
@@ -53,6 +53,7 @@ namespace ctranslate2 {
5353
const bool multi_query = false);
5454

5555
private:
56+
const bool _tensor_parallel;
5657
const dim_t _num_heads;
5758
const bool _self_attention;
5859
const bool _is_decoder;
@@ -68,6 +69,7 @@ namespace ctranslate2 {
6869
const StorageView* _relative_position_values;
6970
dim_t _maximum_relative_position;
7071
const float _queries_scale;
72+
const bool _multi_query;
7173
const dim_t _num_heads_kv;
7274
const bool _merge_time_and_head_dims;
7375
const dim_t _cache_time_dim;

include/ctranslate2/layers/common.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,8 @@ namespace ctranslate2 {
127127
public:
128128
Dense(const models::Model& model,
129129
const std::string& scope,
130-
const ops::ActivationType* activation_type = nullptr);
130+
const ops::ActivationType* activation_type = nullptr,
131+
const bool is_layer_out = false);
131132
DataType output_type() const override;
132133
dim_t output_size() const override;
133134
void operator()(const StorageView& input, StorageView& output) const;
@@ -147,6 +148,7 @@ namespace ctranslate2 {
147148
const ops::Gemm _gemm_op;
148149
const ops::Quantize _quantize_op;
149150
const ops::Dequantize _dequantize_op;
151+
const bool _is_layer_out;
150152
};
151153

152154
class LayerNorm : public Layer

include/ctranslate2/layers/transformer.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ namespace ctranslate2 {
3434
const Dense _ff1;
3535
const std::unique_ptr<const Dense> _ff1_noact;
3636
const Dense _ff2;
37+
const bool _tensor_parallel;
3738
};
3839

3940
class TransformerEncoderLayer : public Layer
@@ -149,6 +150,7 @@ namespace ctranslate2 {
149150
const std::unique_ptr<const LayerNorm> _output_norm;
150151
const std::vector<std::unique_ptr<const TransformerEncoderLayer>> _layers;
151152
const std::unique_ptr<PositionEncoder> _position_encoder;
153+
const bool _tensor_parallel;
152154
};
153155

154156
class TransformerDecoder : public Decoder
@@ -211,6 +213,7 @@ namespace ctranslate2 {
211213
bool _average_alignment_heads;
212214
Dense _proj;
213215
const dim_t _sliding_window;
216+
const bool _tensor_parallel;
214217
};
215218

216219
}

include/ctranslate2/models/model.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,13 @@ namespace ctranslate2 {
2626
static std::shared_ptr<const Model> load(const std::string& path,
2727
Device device = Device::CPU,
2828
int device_index = 0,
29-
ComputeType compute_type = ComputeType::DEFAULT);
29+
ComputeType compute_type = ComputeType::DEFAULT,
30+
bool tensor_parallel = false);
3031
static std::shared_ptr<const Model> load(ModelReader& model_reader,
3132
Device device = Device::CPU,
3233
int device_index = 0,
33-
ComputeType compute_type = ComputeType::DEFAULT);
34+
ComputeType compute_type = ComputeType::DEFAULT,
35+
bool tensor_parallel = false);
3436

3537
virtual std::unique_ptr<SequenceToSequenceReplica> as_sequence_to_sequence() const;
3638
virtual std::unique_ptr<SequenceGeneratorReplica> as_sequence_generator() const;
@@ -78,6 +80,10 @@ namespace ctranslate2 {
7880
return _binary_version >= 5;
7981
}
8082

83+
bool tensor_parallel() const {
84+
return _tensor_parallel;
85+
}
86+
8187
virtual bool use_global_int16_scale() const {
8288
return true;
8389
}
@@ -163,6 +169,7 @@ namespace ctranslate2 {
163169
ComputeType _effective_compute_type = ComputeType::DEFAULT;
164170
dim_t _preferred_size_multiple = 1;
165171
std::unordered_map<std::string, std::shared_ptr<StorageView>> _variable_index;
172+
bool _tensor_parallel = false;
166173
};
167174

168175
template<>
@@ -191,6 +198,7 @@ namespace ctranslate2 {
191198
std::vector<int> device_indices = {0};
192199
size_t num_replicas_per_device = 1;
193200
ComputeType compute_type = ComputeType::DEFAULT;
201+
bool tensor_parallel = false;
194202
};
195203

196204
// Base class for replicas.

0 commit comments

Comments
 (0)