This is the most computationally significant call in the entire transformer evaluation, so we have to be sure that it is running optimally.
It computes the matrix multiplication: `z = x * y`

- `x` is quantized
- `y` is F32
- `z` is F32
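To make the `x` side concrete, here is a minimal sketch of a 4-bit block format and its dequantization. The layout is an assumption loosely modelled on ggml's Q4_0: `QK = 32` values per block, one F32 scale `d`, two 4-bit indices per byte, and each index `q` in `[0, 15]` stands for the value `(q - 8) * d`. `block_q4` and `dequantize_row_q4` are hypothetical names, not the ggml API.

```c
#include <assert.h>
#include <stdint.h>

// Hypothetical 4-bit block layout (assumed, loosely modelled on Q4_0):
// QK consecutive weights share one F32 scale `d`; each weight is an
// unsigned 4-bit index q in [0, 15] representing the value (q - 8) * d.
#define QK 32

typedef struct {
    float   d;           // per-block scale
    uint8_t qs[QK / 2];  // two 4-bit indices packed per byte
} block_q4;

// Dequantize n values (n must be a multiple of QK) from blocks x into y.
static void dequantize_row_q4(const block_q4 *x, float *y, int n) {
    assert(n % QK == 0);
    for (int b = 0; b < n / QK; b++) {
        for (int j = 0; j < QK / 2; j++) {
            const uint8_t byte = x[b].qs[j];
            const int lo = byte & 0x0F;  // low nibble: even-indexed value
            const int hi = byte >> 4;    // high nibble: odd-indexed value
            y[b * QK + 2 * j + 0] = (float)(lo - 8) * x[b].d;
            y[b * QK + 2 * j + 1] = (float)(hi - 8) * x[b].d;
        }
    }
}
```

Mode (A) amounts to running a routine like this over every row of `x` and handing the resulting F32 matrix to `sgemm`.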
Currently, it runs in 2 modes, depending on the tensor shapes:
- (A) for bigger tensors, if BLAS is available, `x` is dequantized to F32 and we use `sgemm` to perform the matrix multiplication
- (B) for smaller tensors, or if BLAS is not available, `y` is quantized to 4-bits on-the-fly and we use integer-based dot products to perform the matrix multiplication
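A minimal sketch of mode (B), under the same assumed block layout: `y` is quantized with the same 4-bit scheme (scale `d = amax / 7`, indices `round(v / d) + 8`), and each block then contributes `dx * dy * sum((qx - 8) * (qy - 8))`, with the inner sum accumulated in integers. All names are hypothetical, and the real routines are SIMD-vectorized and differ in detail.

```c
#include <assert.h>
#include <stdint.h>

#define QK 32  // assumed block size

typedef struct {
    float   d;           // per-block scale
    uint8_t qs[QK / 2];  // two 4-bit indices per byte, value = (q - 8) * d
} block_q4;

// Round to nearest, ties away from zero (like roundf), without libm.
static int round_nearest(float v) {
    return (int)(v >= 0.0f ? v + 0.5f : v - 0.5f);
}

static int clamp_q4(int q) { return q < 0 ? 0 : (q > 15 ? 15 : q); }

// Quantize n floats (n a multiple of QK) into 4-bit blocks.
static void quantize_row_q4(const float *x, block_q4 *y, int n) {
    assert(n % QK == 0);
    for (int b = 0; b < n / QK; b++) {
        float amax = 0.0f;  // largest magnitude in the block
        for (int j = 0; j < QK; j++) {
            const float v = x[b * QK + j];
            const float a = v < 0.0f ? -v : v;
            if (a > amax) amax = a;
        }
        const float d  = amax / 7.0f;
        const float id = d != 0.0f ? 1.0f / d : 0.0f;
        y[b].d = d;
        for (int j = 0; j < QK / 2; j++) {
            const int q0 = clamp_q4(round_nearest(x[b * QK + 2 * j + 0] * id) + 8);
            const int q1 = clamp_q4(round_nearest(x[b * QK + 2 * j + 1] * id) + 8);
            y[b].qs[j] = (uint8_t)(q0 | (q1 << 4));
        }
    }
}

// Mode (B): both operands are 4-bit; integer multiply-accumulate per block,
// then one float multiply by both block scales.
static float vec_dot_q4_q4(const block_q4 *x, const block_q4 *y, int n) {
    assert(n % QK == 0);
    float sum = 0.0f;
    for (int b = 0; b < n / QK; b++) {
        int isum = 0;
        for (int j = 0; j < QK / 2; j++) {
            const int x0 = (x[b].qs[j] & 0x0F) - 8, x1 = (x[b].qs[j] >> 4) - 8;
            const int y0 = (y[b].qs[j] & 0x0F) - 8, y1 = (y[b].qs[j] >> 4) - 8;
            isum += x0 * y0 + x1 * y1;
        }
        sum += x[b].d * y[b].d * (float)isum;
    }
    return sum;
}
```

Because `y` gets only 15 levels per block here, its rounding error adds to that of `x`, which is where the accuracy gap to (A) comes from.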
The former method is much more accurate than the latter. This can be clearly observed during perplexity computations.
However, during text generation (i.e. batch = 1), it is not practical to use (A): in my experience, calling BLAS has significant overhead for the small tensor shapes typical of single-token inference calls.
There are at least two alternative modes of operation that can be explored:
- (C) for smaller tensors, or if BLAS is not available, `x` is dequantized to F32 and we use `ggml_vec_dot_f32()` to perform the multiplication
- (D) for smaller tensors, or if BLAS is not available, `x` is dequantized to F16, `y` is converted to F16 and we use `ggml_vec_dot_f16()` to perform the multiplication
- (E) for smaller tensors, or if BLAS is not available, `y` is quantized on-the-fly to 8-bits and we use a new `ggml` dot-product call that operates on 4-bit `x` and 8-bit `y`. This call will still unpack `x` into 8-bits as usual and perform the 8-bit dot-product as in the existing routines, but in contrast to (B), `y` will already be unpacked to 8-bits and the precision loss will be significantly lower
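Mode (E) can be sketched under the same assumed 4-bit layout, plus a hypothetical 8-bit block for `y` (one F32 scale `d = amax / 127` and 32 signed bytes). The inner loop has the same shape as in (B): `x` is unpacked to small integers and multiplied against 8-bit integers; only the `y` operand now carries 255 levels instead of 15. `block_q8`, `quantize_row_q8` and `vec_dot_q4_q8` are illustrative names, not an existing ggml API.

```c
#include <assert.h>
#include <stdint.h>

#define QK 32  // assumed block size

typedef struct { float d; uint8_t qs[QK / 2]; } block_q4;  // x: (q - 8) * d
typedef struct { float d; int8_t  qs[QK];     } block_q8;  // y: q * d (hypothetical)

// Round to nearest, ties away from zero (like roundf), without libm.
static int round_nearest(float v) {
    return (int)(v >= 0.0f ? v + 0.5f : v - 0.5f);
}

// Quantize n floats (n a multiple of QK) to 8-bit blocks, q in [-127, 127].
static void quantize_row_q8(const float *x, block_q8 *y, int n) {
    assert(n % QK == 0);
    for (int b = 0; b < n / QK; b++) {
        float amax = 0.0f;
        for (int j = 0; j < QK; j++) {
            const float v = x[b * QK + j];
            const float a = v < 0.0f ? -v : v;
            if (a > amax) amax = a;
        }
        const float d  = amax / 127.0f;
        const float id = d != 0.0f ? 1.0f / d : 0.0f;
        y[b].d = d;
        for (int j = 0; j < QK; j++) {
            int q = round_nearest(x[b * QK + j] * id);
            q = q < -127 ? -127 : (q > 127 ? 127 : q);
            y[b].qs[j] = (int8_t)q;
        }
    }
}

// Mode (E): unpack 4-bit x to integers as in (B), multiply against the
// already 8-bit y, accumulate in integers, scale once per block.
static float vec_dot_q4_q8(const block_q4 *x, const block_q8 *y, int n) {
    assert(n % QK == 0);
    float sum = 0.0f;
    for (int b = 0; b < n / QK; b++) {
        int isum = 0;
        for (int j = 0; j < QK / 2; j++) {
            const int x0 = (x[b].qs[j] & 0x0F) - 8;
            const int x1 = (x[b].qs[j] >> 4) - 8;
            isum += x0 * y[b].qs[2 * j + 0] + x1 * y[b].qs[2 * j + 1];
        }
        sum += x[b].d * y[b].d * (float)isum;
    }
    return sum;
}
```

The per-element work is still one small-integer multiply-add, which is why (E) could plausibly match the speed of (B); the extra cost is that `y` occupies twice as many bytes per block.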
It is not immediately clear to me whether (C) or (D) would be significantly slower than (B), but they should be much more accurate than (B) and probably as accurate as (A).
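For (C), a minimal sketch under the same assumed 4-bit layout: each value of `x` is dequantized to F32 and accumulated against the unquantized F32 `y`, so the only quantization error left is that of `x`. This fuses the dequantization into the dot product instead of first writing a dequantized row and then calling `ggml_vec_dot_f32()`; the arithmetic is the same up to summation order. `vec_dot_q4_f32` is a hypothetical name.

```c
#include <assert.h>
#include <stdint.h>

#define QK 32  // assumed block size

typedef struct {
    float   d;           // per-block scale
    uint8_t qs[QK / 2];  // two 4-bit indices per byte, value = (q - 8) * d
} block_q4;

// Mode (C), fused: dequantize x on the fly and accumulate in F32 against y.
// y is never quantized.
static float vec_dot_q4_f32(const block_q4 *x, const float *y, int n) {
    assert(n % QK == 0);
    float sum = 0.0f;
    for (int b = 0; b < n / QK; b++) {
        float bsum = 0.0f;  // block partial sum before applying the scale
        for (int j = 0; j < QK / 2; j++) {
            const float x0 = (float)((x[b].qs[j] & 0x0F) - 8);
            const float x1 = (float)((x[b].qs[j] >> 4) - 8);
            bsum += x0 * y[b * QK + 2 * j + 0] + x1 * y[b * QK + 2 * j + 1];
        }
        sum += x[b].d * bsum;  // apply the block scale once per block
    }
    return sum;
}
```

Compared with (B), the inner loop uses float multiply-adds on unpacked values, which is where any speed difference would come from.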
I think one has to be careful to choose the mode based on the tensor shapes, trying to find a good balance between speed and accuracy. Ideally, I hope that this investigation lets us achieve a noticeable perplexity gain without using BLAS, at the cost of a slightly slower single-token (i.e. batch = 1) computation.
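A shape-based selection could look like the sketch below. Everything here is hypothetical: the enum, the function name, and especially the cut-off `BLAS_MIN_DIM = 32`, which is only a placeholder for "big enough that BLAS call overhead no longer dominates". The non-BLAS path is a parameter, since which of (B) through (E) should fill that slot is exactly what this investigation is meant to decide.

```c
#include <assert.h>
#include <stdbool.h>
#include <stdint.h>

typedef enum {
    MUL_MAT_A_DEQUANT_SGEMM,  // (A) dequantize x to F32, call BLAS sgemm
    MUL_MAT_B_Q4_Q4,          // (B) quantize y to 4 bits, integer dot products
    MUL_MAT_C_F32_DOT,        // (C) dequantize x to F32, F32 dot products
    MUL_MAT_D_F16_DOT,        // (D) x and y in F16, F16 dot products
    MUL_MAT_E_Q4_Q8,          // (E) quantize y to 8 bits, 4-bit x 8-bit dots
} mul_mat_mode;

// Placeholder threshold (assumed): below this in any dimension, BLAS call
// overhead is taken to outweigh its speed advantage.
#define BLAS_MIN_DIM 32

// m: rows of z, n: columns of z (n == 1 for single-token generation),
// k: shared inner dimension. small_mode: the non-BLAS path to use.
static mul_mat_mode choose_mul_mat_mode(int64_t m, int64_t n, int64_t k,
                                        bool have_blas, mul_mat_mode small_mode) {
    assert(m > 0 && n > 0 && k > 0);
    if (have_blas && m >= BLAS_MIN_DIM && n >= BLAS_MIN_DIM && k >= BLAS_MIN_DIM) {
        return MUL_MAT_A_DEQUANT_SGEMM;
    }
    return small_mode;
}
```

With BLAS available, a batched prompt evaluation would take path (A), while a single-token call (`n == 1`) would fall through to whichever small-tensor mode is configured.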
Edit: after the analysis and discussion in #896, I added a new mode (E), which I think is very important to explore. Unless I am missing something, I believe this mode can be exactly as efficient as (B) but with significantly higher accuracy, much higher than what can be achieved by improving the quantization RMS.
So I believe we should investigate this as a very high priority.