From 76e3c35d358cf47ab04d21aac9260cd7c921c83f Mon Sep 17 00:00:00 2001
From: jeremie du boisberranger
Date: Fri, 1 Feb 2019 10:20:36 +0100
Subject: [PATCH 1/9] replace cblas calls in utils/weight_vector

---
 sklearn/utils/weight_vector.pyx | 15 ++++++---------
 1 file changed, 6 insertions(+), 9 deletions(-)

diff --git a/sklearn/utils/weight_vector.pyx b/sklearn/utils/weight_vector.pyx
index 5d8e3b24f8273..edd6067059609 100644
--- a/sklearn/utils/weight_vector.pyx
+++ b/sklearn/utils/weight_vector.pyx
@@ -14,11 +14,8 @@ from libc.math cimport sqrt
 import numpy as np
 cimport numpy as np
 
-cdef extern from "cblas.h":
-    double ddot "cblas_ddot"(int, double *, int, double *, int) nogil
-    void dscal "cblas_dscal"(int, double, double *, int) nogil
-    void daxpy "cblas_daxpy" (int, double, const double*,
-                              int, double*, int) nogil
+from ._cython_blas cimport _dot, _scal, _axpy
+
 
 np.import_array()
 
@@ -59,7 +56,7 @@ cdef class WeightVector(object):
         self.w_data_ptr = wdata
         self.wscale = 1.0
         self.n_features = w.shape[0]
-        self.sq_norm = ddot(w.shape[0], wdata, 1, wdata, 1)
+        self.sq_norm = _dot(w.shape[0], wdata, 1, wdata, 1)
 
         self.aw = aw
         if self.aw is not None:
@@ -183,14 +180,14 @@ cdef class WeightVector(object):
     cdef void reset_wscale(self) nogil:
         """Scales each coef of ``w`` by ``wscale`` and resets it to 1. """
         if self.aw is not None:
-            daxpy(self.aw.shape[0], self.average_a,
+            _axpy(self.aw.shape[0], self.average_a,
                   self.w.data, 1, self.aw.data, 1)
-            dscal(self.aw.shape[0], 1.0 / self.average_b,
+            _scal(self.aw.shape[0], 1.0 / self.average_b,
                   self.aw.data, 1)
            self.average_a = 0.0
            self.average_b = 1.0
 
-        dscal(self.w.shape[0], self.wscale, self.w.data, 1)
+        _scal(self.w.shape[0], self.wscale, self.w.data, 1)
        self.wscale = 1.0
 
     cdef double norm(self) nogil:
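Note: `_dot`, `_scal` and `_axpy` are thin fused-type wrappers defined in
`sklearn/utils/_cython_blas.pyx` that dispatch to the BLAS symbols bundled
with SciPy (`scipy.linalg.cython_blas`), which is what removes the need for
a CBLAS header at build time. A NumPy sketch of what the three calls above
compute (illustration only; the real calls run at C level with no Python
overhead):

    import numpy as np

    w = np.array([1.0, 2.0, 3.0])
    sq_norm = w @ w        # _dot(n, w, 1, w, 1)        -> 14.0
    w *= 0.5               # _scal(n, 0.5, w, 1)        scales in place
    aw = np.zeros_like(w)
    aw += 2.0 * w          # _axpy(n, 2.0, w, 1, aw, 1) -> aw = 2*w + aw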
From 3a197e15a074245dcb0a88d09d984219ebcd7df0 Mon Sep 17 00:00:00 2001
From: jeremie du boisberranger
Date: Fri, 1 Feb 2019 12:03:33 +0100
Subject: [PATCH 2/9] add rot & rotg to cython_blas

---
 sklearn/utils/_cython_blas.pxd |  4 ++++
 sklearn/utils/_cython_blas.pyx | 19 +++++++++++++++++++
 2 files changed, 23 insertions(+)

diff --git a/sklearn/utils/_cython_blas.pxd b/sklearn/utils/_cython_blas.pxd
index 4d82c7b1aaf13..30b3ea5e4fc07 100644
--- a/sklearn/utils/_cython_blas.pxd
+++ b/sklearn/utils/_cython_blas.pxd
@@ -26,6 +26,10 @@ cdef void _copy(int, floating*, int, floating*, int) nogil
 
 cdef void _scal(int, floating, floating*, int) nogil
 
+cdef void _rotg(floating*, floating*, floating*, floating*) nogil
+
+cdef void _rot(int, floating*, int, floating*, int, floating, floating) nogil
+
 # BLAS Level 2 ################################################################
 cdef void _gemv(BLAS_Order, BLAS_Trans, int, int, floating, floating*, int,
                 floating*, int, floating, floating*, int) nogil

diff --git a/sklearn/utils/_cython_blas.pyx b/sklearn/utils/_cython_blas.pyx
index 7585105227f9f..a128aa97fb9cd 100644
--- a/sklearn/utils/_cython_blas.pyx
+++ b/sklearn/utils/_cython_blas.pyx
@@ -6,6 +6,8 @@ from scipy.linalg.cython_blas cimport saxpy, daxpy
 from scipy.linalg.cython_blas cimport snrm2, dnrm2
 from scipy.linalg.cython_blas cimport scopy, dcopy
 from scipy.linalg.cython_blas cimport sscal, dscal
+from scipy.linalg.cython_blas cimport srotg, drotg
+from scipy.linalg.cython_blas cimport srot, drot
 from scipy.linalg.cython_blas cimport sgemv, dgemv
 from scipy.linalg.cython_blas cimport sger, dger
 from scipy.linalg.cython_blas cimport sgemm, dgemm
@@ -89,6 +91,23 @@ cpdef _scal_memview(floating alpha, floating[::1] x):
     _scal(x.shape[0], alpha, &x[0], 1)
 
 
+cdef void _rotg(floating *a, floating *b, floating *c, floating *s) nogil:
+    """Generate plane rotation"""
+    if floating is float:
+        srotg(a, b, c, s)
+    else:
+        drotg(a, b, c, s)
+
+
+cdef void _rot(int n, floating *x, int incx, floating *y, int incy,
+               floating c, floating s) nogil:
+    """Apply plane rotation"""
+    if floating is float:
+        srot(&n, x, &incx, y, &incy, &c, &s)
+    else:
+        drot(&n, x, &incx, y, &incy, &c, &s)
+
+
 ################
 # BLAS Level 2 #
 ################
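Note: `rotg`/`rot` compute and apply a Givens rotation. Given (a, b), `rotg`
returns c, s such that the 2x2 rotation maps (a, b) to (r, 0); `rot` then
applies that rotation across two (possibly strided) vectors. This is exactly
what re-triangularizes a Cholesky factor after a row deletion in the next
patch. A NumPy check of the identity (sign conventions as in reference BLAS):

    import numpy as np

    a, b = 3.0, 4.0
    r = np.hypot(a, b)                  # 5.0
    c, s = a / r, b / r                 # 0.6, 0.8
    assert np.allclose([c * a + s * b, -s * a + c * b], [r, 0.0])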
From 61cfa89b66c49d00144e5a3c5a383c0b35e21fcf Mon Sep 17 00:00:00 2001
From: jeremie du boisberranger
Date: Fri, 1 Feb 2019 12:04:24 +0100
Subject: [PATCH 3/9] cholesky_delete -> cython to use cython_blas

---
 sklearn/utils/arrayfuncs.pyx        | 52 ++++++++++++------
 sklearn/utils/setup.py              | 24 +++------
 sklearn/utils/src/cholesky_delete.h | 76 -----------------------------
 3 files changed, 42 insertions(+), 110 deletions(-)
 delete mode 100644 sklearn/utils/src/cholesky_delete.h

diff --git a/sklearn/utils/arrayfuncs.pyx b/sklearn/utils/arrayfuncs.pyx
index 0ac50c76df5cd..4fc64724e5ec8 100644
--- a/sklearn/utils/arrayfuncs.pyx
+++ b/sklearn/utils/arrayfuncs.pyx
@@ -4,14 +4,12 @@ Small collection of auxiliary functions that operate on arrays
 """
 cimport numpy as np
 import numpy as np
-
 cimport cython
-
+from cython cimport floating
+from libc.math cimport fabs
 from libc.float cimport DBL_MAX, FLT_MAX
 
-cdef extern from "src/cholesky_delete.h":
-    int cholesky_delete_dbl(int m, int n, double *L, int go_out)
-    int cholesky_delete_flt(int m, int n, float *L, int go_out)
+from ._cython_blas cimport _copy, _rotg, _rot
 
 ctypedef np.float64_t DOUBLE
 
@@ -51,14 +49,36 @@ cdef double _double_min_pos(double *X, Py_ssize_t size):
     return min_val
 
 
-# we should be using np.npy_intp or Py_ssize_t for indices, but BLAS wants int
-def cholesky_delete(np.ndarray L, int go_out):
-    cdef int n = L.shape[0]
-    cdef int m = L.strides[0]
-
-    if L.dtype.name == 'float64':
-        cholesky_delete_dbl(m / sizeof(double), n, <double *> L.data, go_out)
-    elif L.dtype.name == 'float32':
-        cholesky_delete_flt(m / sizeof(float), n, <float *> L.data, go_out)
-    else:
-        raise TypeError("unsupported dtype %r." % L.dtype)
+# General Cholesky Delete.
+# Remove an element from the cholesky factorization
+# m = columns
+# n = rows
+#
+# TODO: put transpose as an option
+def cholesky_delete(np.ndarray[floating, ndim=2] L, int go_out):
+    cdef:
+        int n = L.shape[0]
+        int m = L.shape[1]
+        floating c, s
+        floating *L1
+        int i
+
+    # delete row go_out
+    L1 = &L[0, 0] + (go_out * m)
+    for i in range(go_out, n - 1):
+        _copy(i + 2, L1 + m, 1, L1, 1)
+        L1 += m
+
+    L1 = &L[0, 0] + (go_out * m)
+    for i in range(go_out, n-1):
+        _rotg(L1 + i, L1 + i + 1, &c, &s)
+        if L1[i] < 0:
+            # Diagonals cannot be negative
+            L1[i] = fabs(L1[i])
+            c = -c
+            s = -s
+
+        L1[i + 1] = 0.  # just for cleanup
+        L1 += m
+
+        _rot(n - i - 2, L1 + i, m, L1 + i + 1, m, c, s)
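Note: the algorithm above relies on the fact that if A = L L^T and row/column
k of A is removed, then L with row k deleted still reproduces the reduced
matrix; the Givens rotations only restore the triangular shape without
changing that product. A NumPy cross-check of the identity (an illustration,
not sklearn's API):

    import numpy as np

    rng = np.random.RandomState(0)
    A = rng.randn(5, 5)
    A = A @ A.T + 5 * np.eye(5)         # symmetric positive definite
    L = np.linalg.cholesky(A)
    k = 2
    L1 = np.delete(L, k, axis=0)        # the _copy loop above
    A1 = np.delete(np.delete(A, k, axis=0), k, axis=1)
    assert np.allclose(L1 @ L1.T, A1)   # rotations preserve this product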
diff --git a/sklearn/utils/setup.py b/sklearn/utils/setup.py
index 97aeb602408c4..c083180144bcd 100644
--- a/sklearn/utils/setup.py
+++ b/sklearn/utils/setup.py
@@ -1,8 +1,6 @@
 import os
 from os.path import join
 
-from sklearn._build_utils import get_blas_info
-
 
 def configuration(parent_package='', top_path=None):
     import numpy
@@ -10,18 +8,12 @@ def configuration(parent_package='', top_path=None):
 
     config = Configuration('utils', parent_package, top_path)
 
-    cblas_libs, blas_info = get_blas_info()
-    cblas_compile_args = blas_info.pop('extra_compile_args', [])
-    cblas_includes = [join('..', 'src', 'cblas'),
-                      numpy.get_include(),
-                      blas_info.pop('include_dirs', [])]
-
     libraries = []
     if os.name == 'posix':
         libraries.append('m')
-        cblas_libs.append('m')
 
-    config.add_extension('sparsefuncs_fast', sources=['sparsefuncs_fast.pyx'],
+    config.add_extension('sparsefuncs_fast',
+                         sources=['sparsefuncs_fast.pyx'],
                          libraries=libraries)
 
     config.add_extension('_cython_blas',
@@ -30,11 +22,8 @@ def configuration(parent_package='', top_path=None):
 
     config.add_extension('arrayfuncs',
                          sources=['arrayfuncs.pyx'],
-                         depends=[join('src', 'cholesky_delete.h')],
-                         libraries=cblas_libs,
-                         include_dirs=cblas_includes,
-                         extra_compile_args=cblas_compile_args,
-                         **blas_info)
+                         include_dirs=[numpy.get_include()],
+                         libraries=libraries)
 
     config.add_extension('murmurhash',
                          sources=['murmurhash.pyx', join(
@@ -62,9 +51,8 @@ def configuration(parent_package='', top_path=None):
 
     config.add_extension('weight_vector',
                          sources=['weight_vector.pyx'],
-                         include_dirs=cblas_includes,
-                         libraries=cblas_libs,
-                         **blas_info)
+                         include_dirs=[numpy.get_include()],
+                         libraries=libraries)
 
     config.add_extension("_random",
                          sources=["_random.pyx"],

diff --git a/sklearn/utils/src/cholesky_delete.h b/sklearn/utils/src/cholesky_delete.h
deleted file mode 100644
index 6e20a2b003ed7..0000000000000
--- a/sklearn/utils/src/cholesky_delete.h
+++ /dev/null
@@ -1,76 +0,0 @@
-#include <math.h>
-#include <cblas.h>
-
-#ifdef _MSC_VER
-# define inline __inline
-#endif
-
-
-/*
- * General Cholesky Delete.
- * Remove an element from the cholesky factorization
- * m = columns
- * n = rows
- *
- * TODO: put transpose as an option
- */
-static inline void cholesky_delete_dbl(int m, int n, double *L, int go_out)
-{
-    double c, s;
-
-    /* delete row go_out */
-    double *L1 = L + (go_out * m);
-    int i;
-    for (i = go_out; i < n - 1; ++i) {
-        cblas_dcopy (i + 2, L1 + m , 1, L1,  1);
-        L1 += m;
-    }
-
-    L1 = L + (go_out * m);
-    for (i=go_out; i < n - 1; ++i) {
-
-        cblas_drotg(L1 + i, L1 + i + 1, &c, &s);
-        if (L1[i] < 0) {
-            /* Diagonals cannot be negative */
-            L1[i] = fabs(L1[i]);
-            c = -c;
-            s = -s;
-        }
-        L1[i+1] = 0.; /* just for cleanup */
-        L1 += m;
-
-        cblas_drot(n - (i + 2), L1 + i, m, L1 + i + 1,
-                   m, c, s);
-    }
-}
-
-
-static inline void cholesky_delete_flt(int m, int n, float *L, int go_out)
-{
-    float c, s;
-
-    /* delete row go_out */
-    float *L1 = L + (go_out * m);
-    int i;
-    for (i = go_out; i < n - 1; ++i) {
-        cblas_scopy (i + 2, L1 + m , 1, L1,  1);
-        L1 += m;
-    }
-
-    L1 = L + (go_out * m);
-    for (i=go_out; i < n - 1; ++i) {
-
-        cblas_srotg(L1 + i, L1 + i + 1, &c, &s);
-        if (L1[i] < 0) {
-            /* Diagonals cannot be negative */
-            L1[i] = fabsf(L1[i]);
-            c = -c;
-            s = -s;
-        }
-        L1[i+1] = 0.; /* just for cleanup */
-        L1 += m;
-
-        cblas_srot(n - (i + 2), L1 + i, m, L1 + i + 1,
-                   m, c, s);
-    }
-}

From 309f8ef1a520ab880cb09f589663eec97de74d09 Mon Sep 17 00:00:00 2001
From: jeremie du boisberranger
Date: Fri, 1 Feb 2019 14:54:05 +0100
Subject: [PATCH 4/9] add python wrappers for rot & rotg to cython_blas

---
 sklearn/utils/_cython_blas.pyx | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/sklearn/utils/_cython_blas.pyx b/sklearn/utils/_cython_blas.pyx
index a128aa97fb9cd..c15e66ee02ce1 100644
--- a/sklearn/utils/_cython_blas.pyx
+++ b/sklearn/utils/_cython_blas.pyx
@@ -99,6 +99,11 @@ cdef void _rotg(floating *a, floating *b, floating *c, floating *s) nogil:
         drotg(a, b, c, s)
 
 
+cpdef _rotg_memview(floating a, floating b, floating c, floating s):
+    _rotg(&a, &b, &c, &s)
+    return a, b, c, s
+
+
 cdef void _rot(int n, floating *x, int incx, floating *y, int incy,
                floating c, floating s) nogil:
     """Apply plane rotation"""
@@ -108,6 +113,10 @@ cdef void _rot(int n, floating *x, int incx, floating *y, int incy,
         drot(&n, x, &incx, y, &incy, &c, &s)
 
 
+cpdef _rot_memview(floating[::1] x, floating[::1] y, floating c, floating s):
+    _rot(x.shape[0], &x[0], 1, &y[0], 1, c, s)
+
+
 ################
 # BLAS Level 2 #
 ################
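Note: the `cdef ... nogil` helpers cannot be imported from Python, so the
`cpdef` wrappers above exist purely to make the BLAS bindings testable. From
Python, a dtype specialization of a fused-type function is selected by
indexing; a sketch of the pattern the tests below use (assumes a compiled
scikit-learn; `cython.double` selects the float64 version):

    import cython
    import numpy as np
    from sklearn.utils._cython_blas import _rot_memview

    rot = _rot_memview[cython.double]
    x, y = np.array([3.0]), np.array([4.0])
    rot(x, y, 0.6, 0.8)                 # in place: x -> [5.0], y -> [0.0]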
From 71e60bf3d7146b791589be34e90c8fc9b4f12d61 Mon Sep 17 00:00:00 2001
From: jeremie du boisberranger
Date: Fri, 1 Feb 2019 14:54:31 +0100
Subject: [PATCH 5/9] rot & rotg cython_blas tests

---
 sklearn/utils/tests/test_cython_blas.py | 46 +++++++++++++++++++
 1 file changed, 46 insertions(+)

diff --git a/sklearn/utils/tests/test_cython_blas.py b/sklearn/utils/tests/test_cython_blas.py
index 0305e5a5476dc..cfb0c9c1e146d 100644
--- a/sklearn/utils/tests/test_cython_blas.py
+++ b/sklearn/utils/tests/test_cython_blas.py
@@ -11,6 +11,8 @@ from sklearn.utils._cython_blas import _nrm2_memview
 from sklearn.utils._cython_blas import _copy_memview
 from sklearn.utils._cython_blas import _scal_memview
+from sklearn.utils._cython_blas import _rotg_memview
+from sklearn.utils._cython_blas import _rot_memview
 from sklearn.utils._cython_blas import _gemv_memview
 from sklearn.utils._cython_blas import _ger_memview
 from sklearn.utils._cython_blas import _gemm_memview
@@ -110,6 +112,50 @@ def test_scal(dtype):
     assert_allclose(x, expected, rtol=RTOL[dtype])
 
 
+@pytest.mark.parametrize("dtype", [np.float32, np.float64])
+def test_rotg(dtype):
+    rotg = _rotg_memview[NUMPY_TO_CYTHON[dtype]]
+
+    rng = np.random.RandomState(0)
+    a = dtype(rng.randn())
+    b = dtype(rng.randn())
+    c, s = 0.0, 0.0
+
+    def expected_rotg(a, b):
+        roe = a if abs(a) > abs(b) else b
+        if a == 0 and b == 0:
+            c, s, r, z = (1, 0, 0, 0)
+        else:
+            r = np.sqrt(a**2 + b**2) * (1 if roe >= 0 else -1)
+            c, s = a/r, b/r
+            z = s if roe == a else (1 if c == 0 else 1 / c)
+        return r, z, c, s
+
+    expected = expected_rotg(a, b)
+    actual = rotg(a, b, c, s)
+
+    assert_allclose(actual, expected, rtol=RTOL[dtype])
+
+
+@pytest.mark.parametrize("dtype", [np.float32, np.float64])
+def test_rot(dtype):
+    rot = _rot_memview[NUMPY_TO_CYTHON[dtype]]
+
+    rng = np.random.RandomState(0)
+    x = rng.random_sample(10).astype(dtype, copy=False)
+    y = rng.random_sample(10).astype(dtype, copy=False)
+    c = dtype(rng.randn())
+    s = dtype(rng.randn())
+
+    expected_x = c * x + s * y
+    expected_y = c * y - s * x
+
+    rot(x, y, c, s)
+
+    assert_allclose(x, expected_x)
+    assert_allclose(y, expected_y)
+
+
 @pytest.mark.parametrize("dtype", [np.float32, np.float64])
 @pytest.mark.parametrize("opA, transA", [(_no_op, NoTrans),
                                          (np.transpose, Trans)],

From 0ac6ebc8e0f927d6d43002b291fdf843f3122c78 Mon Sep 17 00:00:00 2001
From: jeremie du boisberranger
Date: Fri, 1 Feb 2019 14:54:59 +0100
Subject: [PATCH 6/9] cholesky delete fix

---
 sklearn/utils/arrayfuncs.pyx | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/sklearn/utils/arrayfuncs.pyx b/sklearn/utils/arrayfuncs.pyx
index 4fc64724e5ec8..2958c6d14e638 100644
--- a/sklearn/utils/arrayfuncs.pyx
+++ b/sklearn/utils/arrayfuncs.pyx
@@ -58,14 +58,19 @@ cdef double _double_min_pos(double *X, Py_ssize_t size):
 def cholesky_delete(np.ndarray[floating, ndim=2] L, int go_out):
     cdef:
         int n = L.shape[0]
-        int m = L.shape[1]
+        int m = L.strides[0]
         floating c, s
         floating *L1
         int i
+
+    if floating is float:
+        m /= sizeof(float)
+    else:
+        m /= sizeof(double)
 
     # delete row go_out
     L1 = &L[0, 0] + (go_out * m)
-    for i in range(go_out, n - 1):
+    for i in range(go_out, n-1):
         _copy(i + 2, L1 + m, 1, L1, 1)
         L1 += m
 
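Note on the fix above: the LARS code path calls `cholesky_delete` on a
leading sub-block of a larger preallocated array, so the spacing between rows
is the parent's row stride, not `L.shape[1]`; that is why `m` must be derived
from `L.strides[0]`. A NumPy illustration:

    import numpy as np

    big = np.zeros((10, 10))
    L = big[:4, :4]                          # the kind of view LARS passes
    assert L.shape[1] == 4
    assert L.strides[0] // L.itemsize == 10  # row stride is 10, not 4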
From 1432c9f246d7eccd95dd1aaa4dcb2a744b8f98fb Mon Sep 17 00:00:00 2001
From: jeremie du boisberranger
Date: Fri, 1 Feb 2019 15:44:31 +0100
Subject: [PATCH 7/9] cblas to scipy cython_blas in linear_model/cd_fast

---
 sklearn/linear_model/cd_fast.pyx | 178 +++++++++----------------------
 sklearn/linear_model/setup.py    |  29 ++----
 2 files changed, 59 insertions(+), 148 deletions(-)

diff --git a/sklearn/linear_model/cd_fast.pyx b/sklearn/linear_model/cd_fast.pyx
index c75ad0f667d46..c512c02f6576a 100644
--- a/sklearn/linear_model/cd_fast.pyx
+++ b/sklearn/linear_model/cd_fast.pyx
@@ -16,6 +16,11 @@ from cpython cimport bool
 from cython cimport floating
 import warnings
 
+from ..utils._cython_blas cimport (_axpy, _dot, _asum, _ger, _gemv, _nrm2,
+                                   _copy, _scal)
+from ..utils._cython_blas cimport RowMajor, ColMajor, Trans, NoTrans
+
+
 ctypedef np.float64_t DOUBLE
 ctypedef np.uint32_t UINT32_t
 
@@ -94,50 +99,6 @@ cdef floating diff_abs_max(int n, floating* a, floating* b) nogil:
     return m
 
 
-cdef extern from "cblas.h":
-    enum CBLAS_ORDER:
-        CblasRowMajor=101
-        CblasColMajor=102
-    enum CBLAS_TRANSPOSE:
-        CblasNoTrans=111
-        CblasTrans=112
-        CblasConjTrans=113
-        AtlasConj=114
-
-    void daxpy "cblas_daxpy"(int N, double alpha, double *X, int incX,
-                             double *Y, int incY) nogil
-    void saxpy "cblas_saxpy"(int N, float alpha, float *X, int incX,
-                             float *Y, int incY) nogil
-    double ddot "cblas_ddot"(int N, double *X, int incX, double *Y, int incY
-                             ) nogil
-    float sdot "cblas_sdot"(int N, float *X, int incX, float *Y,
-                            int incY) nogil
-    double dasum "cblas_dasum"(int N, double *X, int incX) nogil
-    float sasum "cblas_sasum"(int N, float *X, int incX) nogil
-    void dger "cblas_dger"(CBLAS_ORDER Order, int M, int N, double alpha,
-                           double *X, int incX, double *Y, int incY,
-                           double *A, int lda) nogil
-    void sger "cblas_sger"(CBLAS_ORDER Order, int M, int N, float alpha,
-                           float *X, int incX, float *Y, int incY,
-                           float *A, int lda) nogil
-    void dgemv "cblas_dgemv"(CBLAS_ORDER Order, CBLAS_TRANSPOSE TransA,
-                             int M, int N, double alpha, double *A, int lda,
-                             double *X, int incX, double beta,
-                             double *Y, int incY) nogil
-    void sgemv "cblas_sgemv"(CBLAS_ORDER Order, CBLAS_TRANSPOSE TransA,
-                             int M, int N, float alpha, float *A, int lda,
-                             float *X, int incX, float beta,
-                             float *Y, int incY) nogil
-    double dnrm2 "cblas_dnrm2"(int N, double *X, int incX) nogil
-    float snrm2 "cblas_snrm2"(int N, float *X, int incX) nogil
-    void dcopy "cblas_dcopy"(int N, double *X, int incX, double *Y,
-                             int incY) nogil
-    void scopy "cblas_scopy"(int N, float *X, int incX, float *Y,
-                             int incY) nogil
-    void dscal "cblas_dscal"(int N, double alpha, double *X, int incX) nogil
-    void sscal "cblas_sscal"(int N, float alpha, float *X, int incX) nogil
-
-
 @cython.boundscheck(False)
 @cython.wraparound(False)
 @cython.cdivision(True)
@@ -156,21 +117,10 @@ def enet_coordinate_descent(floating[::1] w,
 
     """
 
-    # fused types version of BLAS functions
     if floating is float:
         dtype = np.float32
-        gemv = sgemv
-        dot = sdot
-        axpy = saxpy
-        asum = sasum
-        copy = scopy
     else:
         dtype = np.float64
-        gemv = dgemv
-        dot = ddot
-        axpy = daxpy
-        asum = dasum
-        copy = dcopy
 
     # get the data information into easy vars
     cdef unsigned int n_samples = X.shape[0]
@@ -209,14 +159,12 @@ def enet_coordinate_descent(floating[::1] w,
 
     with nogil:
         # R = y - np.dot(X, w)
-        copy(n_samples, &y[0], 1, &R[0], 1)
-        gemv(CblasColMajor, CblasNoTrans,
-             n_samples, n_features, -1.0, &X[0, 0], n_samples,
-             &w[0], 1,
-             1.0, &R[0], 1)
+        _copy(n_samples, &y[0], 1, &R[0], 1)
+        _gemv(ColMajor, NoTrans, n_samples, n_features, -1.0, &X[0, 0],
+              n_samples, &w[0], 1, 1.0, &R[0], 1)
 
         # tol *= np.dot(y, y)
-        tol *= dot(n_samples, &y[0], 1, &y[0], 1)
+        tol *= _dot(n_samples, &y[0], 1, &y[0], 1)
 
         for n_iter in range(max_iter):
             w_max = 0.0
@@ -234,10 +182,10 @@ def enet_coordinate_descent(floating[::1] w,
 
                 if w_ii != 0.0:
                     # R += w_ii * X[:,ii]
-                    axpy(n_samples, w_ii, &X[0, ii], 1, &R[0], 1)
+                    _axpy(n_samples, w_ii, &X[0, ii], 1, &R[0], 1)
 
                 # tmp = (X[:,ii]*R).sum()
-                tmp = dot(n_samples, &X[0, ii], 1, &R[0], 1)
+                tmp = _dot(n_samples, &X[0, ii], 1, &R[0], 1)
 
                 if positive and tmp < 0:
                     w[ii] = 0.0
@@ -247,7 +195,7 @@ def enet_coordinate_descent(floating[::1] w,
 
                 if w[ii] != 0.0:
                     # R -= w[ii] * X[:,ii] # Update residual
-                    axpy(n_samples, -w[ii], &X[0, ii], 1, &R[0], 1)
+                    _axpy(n_samples, -w[ii], &X[0, ii], 1, &R[0], 1)
 
                 # update the maximum absolute coefficient update
                 d_w_ii = fabs(w[ii] - w_ii)
@@ -264,7 +212,7 @@ def enet_coordinate_descent(floating[::1] w,
 
                 # XtA = np.dot(X.T, R) - beta * w
                 for i in range(n_features):
-                    XtA[i] = (dot(n_samples, &X[0, i], 1, &R[0], 1)
+                    XtA[i] = (_dot(n_samples, &X[0, i], 1, &R[0], 1)
                               - beta * w[i])
 
                 if positive:
@@ -273,10 +221,10 @@ def enet_coordinate_descent(floating[::1] w,
                     dual_norm_XtA = abs_max(n_features, &XtA[0])
 
                 # R_norm2 = np.dot(R, R)
-                R_norm2 = dot(n_samples, &R[0], 1, &R[0], 1)
+                R_norm2 = _dot(n_samples, &R[0], 1, &R[0], 1)
 
                 # w_norm2 = np.dot(w, w)
-                w_norm2 = dot(n_features, &w[0], 1, &w[0], 1)
+                w_norm2 = _dot(n_features, &w[0], 1, &w[0], 1)
 
                 if (dual_norm_XtA > alpha):
                     const = alpha / dual_norm_XtA
@@ -286,11 +234,11 @@ def enet_coordinate_descent(floating[::1] w,
                     const = 1.0
                     gap = R_norm2
 
-                l1_norm = asum(n_features, &w[0], 1)
+                l1_norm = _asum(n_features, &w[0], 1)
 
                 # np.dot(R.T, y)
                 gap += (alpha * l1_norm
-                        - const * dot(n_samples, &R[0], 1, &y[0], 1)
+                        - const * _dot(n_samples, &R[0], 1, &y[0], 1)
                         + 0.5 * beta * (1 + const ** 2) * (w_norm2))
 
                 if gap < tol:
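Note: the hunks above touch the standard elastic-net coordinate update; in
NumPy terms each feature update amounts to (a sketch matching the
`_axpy`/`_dot` calls above, not sklearn's API):

    import numpy as np

    def enet_update(X, R, w, ii, alpha, beta, norm_col_ii):
        R += w[ii] * X[:, ii]           # _axpy: add back old contribution
        tmp = X[:, ii] @ R              # _dot:  partial correlation
        w[ii] = (np.sign(tmp) * max(abs(tmp) - alpha, 0)
                 / (norm_col_ii + beta))
        R -= w[ii] * X[:, ii]           # _axpy: subtract new contribution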
@@ -336,15 +284,10 @@ def sparse_enet_coordinate_descent(floating [::1] w,
     cdef floating[:] X_T_R
     cdef floating[:] XtA
 
-    # fused types version of BLAS functions
     if floating is float:
         dtype = np.float32
-        dot = sdot
-        asum = sasum
     else:
         dtype = np.float64
-        dot = ddot
-        asum = dasum
 
     norm_cols_X = np.zeros(n_features, dtype=dtype)
     X_T_R = np.zeros(n_features, dtype=dtype)
@@ -397,7 +340,7 @@ def sparse_enet_coordinate_descent(floating [::1] w,
             startptr = endptr
 
         # tol *= np.dot(y, y)
-        tol *= dot(n_samples, &y[0], 1, &y[0], 1)
+        tol *= _dot(n_samples, &y[0], 1, &y[0], 1)
 
         for n_iter in range(max_iter):
 
@@ -486,10 +429,10 @@ def sparse_enet_coordinate_descent(floating [::1] w,
                 dual_norm_XtA = abs_max(n_features, &XtA[0])
 
                 # R_norm2 = np.dot(R, R)
-                R_norm2 = dot(n_samples, &R[0], 1, &R[0], 1)
+                R_norm2 = _dot(n_samples, &R[0], 1, &R[0], 1)
 
                 # w_norm2 = np.dot(w, w)
-                w_norm2 = dot(n_features, &w[0], 1, &w[0], 1)
+                w_norm2 = _dot(n_features, &w[0], 1, &w[0], 1)
                 if (dual_norm_XtA > alpha):
                     const = alpha / dual_norm_XtA
                     A_norm2 = R_norm2 * const**2
@@ -498,9 +441,9 @@ def sparse_enet_coordinate_descent(floating [::1] w,
                     const = 1.0
                     gap = R_norm2
 
-                l1_norm = asum(n_features, &w[0], 1)
+                l1_norm = _asum(n_features, &w[0], 1)
 
-                gap += (alpha * l1_norm - const * dot(
+                gap += (alpha * l1_norm - const * _dot(
                             n_samples,
                             &R[0], 1,
                             &y[0], 1
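Note: the duality gap evaluated above (the dense and sparse paths use the
same formula) scales the residual by `const` to obtain a dual-feasible
point. For the Lasso case (`beta = 0`, `positive=False`) the quantity
computed is, in NumPy form (a sketch of the formula, not sklearn's API):

    import numpy as np

    def lasso_dual_gap(X, y, w, alpha):
        R = y - X @ w
        const = min(1.0, alpha / np.max(np.abs(X.T @ R)))
        return (0.5 * (1 + const ** 2) * (R @ R)
                + alpha * np.abs(w).sum() - const * (R @ y))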
@@ -536,17 +479,10 @@ def enet_coordinate_descent_gram(floating[::1] w,
         q = X^T y
     """
 
-    # fused types version of BLAS functions
     if floating is float:
         dtype = np.float32
-        dot = sdot
-        axpy = saxpy
-        asum = sasum
     else:
         dtype = np.float64
-        dot = ddot
-        axpy = daxpy
-        asum = dasum
 
     # get the data information into easy vars
     cdef unsigned int n_samples = y.shape[0]
@@ -601,8 +537,8 @@ def enet_coordinate_descent_gram(floating[::1] w,
 
                 if w_ii != 0.0:
                     # H -= w_ii * Q[ii]
-                    axpy(n_features, -w_ii, Q_ptr + ii * n_features, 1,
-                         H_ptr, 1)
+                    _axpy(n_features, -w_ii, Q_ptr + ii * n_features, 1,
+                          H_ptr, 1)
 
                 tmp = q[ii] - H[ii]
 
@@ -614,8 +550,8 @@ def enet_coordinate_descent_gram(floating[::1] w,
 
                 if w[ii] != 0.0:
                     # H += w[ii] * Q[ii] # Update H = X.T X w
-                    axpy(n_features, w[ii], Q_ptr + ii * n_features, 1,
-                         H_ptr, 1)
+                    _axpy(n_features, w[ii], Q_ptr + ii * n_features, 1,
+                          H_ptr, 1)
 
                 # update the maximum absolute coefficient update
                 d_w_ii = fabs(w[ii] - w_ii)
@@ -631,7 +567,7 @@ def enet_coordinate_descent_gram(floating[::1] w,
             # criterion
 
             # q_dot_w = np.dot(w, q)
-            q_dot_w = dot(n_features, w_ptr, 1, q_ptr, 1)
+            q_dot_w = _dot(n_features, w_ptr, 1, q_ptr, 1)
 
             for ii in range(n_features):
                 XtA[ii] = q[ii] - H[ii] - beta * w[ii]
@@ -647,7 +583,7 @@ def enet_coordinate_descent_gram(floating[::1] w,
             R_norm2 = y_norm2 + tmp - 2.0 * q_dot_w
 
             # w_norm2 = np.dot(w, w)
-            w_norm2 = dot(n_features, &w[0], 1, &w[0], 1)
+            w_norm2 = _dot(n_features, &w[0], 1, &w[0], 1)
 
             if (dual_norm_XtA > alpha):
                 const = alpha / dual_norm_XtA
@@ -657,8 +593,8 @@ def enet_coordinate_descent_gram(floating[::1] w,
                 const = 1.0
                 gap = R_norm2
 
-            # The call to dasum is equivalent to the L1 norm of w
-            gap += (alpha * asum(n_features, &w[0], 1) -
+            # The call to asum is equivalent to the L1 norm of w
+            gap += (alpha * _asum(n_features, &w[0], 1) -
                     const * y_norm2 +  const * q_dot_w +
                     0.5 * beta * (1 + const ** 2) * w_norm2)
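Note: the multi-task variant below replaces the l1 penalty with the l2,1
norm — the sum of the euclidean norms of the coefficient groups (one group
per feature) — which is what the `_nrm2` accumulation loop computes. A small
NumPy illustration:

    import numpy as np

    W = np.array([[3.0, 4.0],
                  [0.0, 12.0]])
    l21 = np.sqrt((W ** 2).sum(axis=1)).sum()  # ||(3,4)|| + ||(0,12)|| = 17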
@@ -686,25 +622,11 @@ def enet_coordinate_descent_multi_task(floating[::1, :] W, floating l1_reg,
         (1/2) * norm(y - X w, 2)^2 + l1_reg ||w||_21 + (1/2) * l2_reg norm(w, 2)^2
 
     """
-    # fused types version of BLAS functions
+
    if floating is float:
         dtype = np.float32
-        dot = sdot
-        nrm2 = snrm2
-        asum = sasum
-        copy = scopy
-        scal = sscal
-        ger = sger
-        gemv = sgemv
     else:
         dtype = np.float64
-        dot = ddot
-        nrm2 = dnrm2
-        asum = dasum
-        copy = dcopy
-        scal = dscal
-        ger = dger
-        gemv = dgemv
 
     # get the data information into easy vars
     cdef unsigned int n_samples = X.shape[0]
@@ -759,11 +681,11 @@ def enet_coordinate_descent_multi_task(floating[::1, :] W, floating l1_reg,
         for ii in range(n_samples):
             for jj in range(n_tasks):
                 R[ii, jj] = Y[ii, jj] - (
-                    dot(n_features, X_ptr + ii, n_samples, W_ptr + jj, n_tasks)
+                    _dot(n_features, X_ptr + ii, n_samples, W_ptr + jj, n_tasks)
                 )
 
         # tol = tol * linalg.norm(Y, ord='fro') ** 2
-        tol = tol * nrm2(n_samples * n_tasks, Y_ptr, 1) ** 2
+        tol = tol * _nrm2(n_samples * n_tasks, Y_ptr, 1) ** 2
 
         for n_iter in range(max_iter):
             w_max = 0.0
@@ -778,33 +700,32 @@ def enet_coordinate_descent_multi_task(floating[::1, :] W, floating l1_reg,
                     continue
 
                 # w_ii = W[:, ii] # Store previous value
-                copy(n_tasks, W_ptr + ii * n_tasks, 1, wii_ptr, 1)
+                _copy(n_tasks, W_ptr + ii * n_tasks, 1, wii_ptr, 1)
 
                 # if np.sum(w_ii ** 2) != 0.0:  # can do better
-                if nrm2(n_tasks, wii_ptr, 1) != 0.0:
+                if _nrm2(n_tasks, wii_ptr, 1) != 0.0:
                     # R += np.dot(X[:, ii][:, None], w_ii[None, :]) # rank 1 update
-                    ger(CblasRowMajor, n_samples, n_tasks, 1.0,
-                        X_ptr + ii * n_samples, 1,
-                        wii_ptr, 1, &R[0, 0], n_tasks)
+                    _ger(RowMajor, n_samples, n_tasks, 1.0,
+                         X_ptr + ii * n_samples, 1,
+                         wii_ptr, 1, &R[0, 0], n_tasks)
 
                 # tmp = np.dot(X[:, ii][None, :], R).ravel()
-                gemv(CblasRowMajor, CblasTrans,
-                     n_samples, n_tasks, 1.0, &R[0, 0], n_tasks,
-                     X_ptr + ii * n_samples, 1, 0.0, &tmp[0], 1)
+                _gemv(RowMajor, Trans, n_samples, n_tasks, 1.0, &R[0, 0],
+                      n_tasks, X_ptr + ii * n_samples, 1, 0.0, &tmp[0], 1)
 
                 # nn = sqrt(np.sum(tmp ** 2))
-                nn = nrm2(n_tasks, &tmp[0], 1)
+                nn = _nrm2(n_tasks, &tmp[0], 1)
 
                 # W[:, ii] = tmp * fmax(1. - l1_reg / nn, 0) / (norm_cols_X[ii] + l2_reg)
-                copy(n_tasks, &tmp[0], 1, W_ptr + ii * n_tasks, 1)
-                scal(n_tasks, fmax(1. - l1_reg / nn, 0) / (norm_cols_X[ii] + l2_reg),
-                     W_ptr + ii * n_tasks, 1)
+                _copy(n_tasks, &tmp[0], 1, W_ptr + ii * n_tasks, 1)
+                _scal(n_tasks, fmax(1. - l1_reg / nn, 0) / (norm_cols_X[ii] + l2_reg),
+                      W_ptr + ii * n_tasks, 1)
 
                 # if np.sum(W[:, ii] ** 2) != 0.0:  # can do better
-                if nrm2(n_tasks, W_ptr + ii * n_tasks, 1) != 0.0:
+                if _nrm2(n_tasks, W_ptr + ii * n_tasks, 1) != 0.0:
                     # R -= np.dot(X[:, ii][:, None], W[:, ii][None, :])
                     # Update residual : rank 1 update
-                    ger(CblasRowMajor, n_samples, n_tasks, -1.0,
-                        X_ptr + ii * n_samples, 1, W_ptr + ii * n_tasks, 1,
-                        &R[0, 0], n_tasks)
+                    _ger(RowMajor, n_samples, n_tasks, -1.0,
+                         X_ptr + ii * n_samples, 1, W_ptr + ii * n_tasks, 1,
+                         &R[0, 0], n_tasks)
 
@@ -826,7 +747,7 @@ def enet_coordinate_descent_multi_task(floating[::1, :] W, floating l1_reg,
                 # XtA = np.dot(X.T, R) - l2_reg * W.T
                 for ii in range(n_features):
                     for jj in range(n_tasks):
-                        XtA[ii, jj] = dot(
+                        XtA[ii, jj] = _dot(
                             n_samples, X_ptr + ii * n_samples, 1,
                             &R[0, 0] + jj, n_tasks
                             ) - l2_reg * W[jj, ii]
@@ -835,15 +756,16 @@ def enet_coordinate_descent_multi_task(floating[::1, :] W, floating l1_reg,
                 dual_norm_XtA = 0.0
                 for ii in range(n_features):
                     # np.sqrt(np.sum(XtA ** 2, axis=1))
-                    XtA_axis1norm = nrm2(n_tasks, &XtA[0, 0] + ii * n_tasks, 1)
+                    XtA_axis1norm = _nrm2(n_tasks,
+                                          &XtA[0, 0] + ii * n_tasks, 1)
                     if XtA_axis1norm > dual_norm_XtA:
                         dual_norm_XtA = XtA_axis1norm
 
                 # TODO: use squared L2 norm directly
                 # R_norm = linalg.norm(R, ord='fro')
                 # w_norm = linalg.norm(W, ord='fro')
-                R_norm = nrm2(n_samples * n_tasks, &R[0, 0], 1)
-                w_norm = nrm2(n_features * n_tasks, W_ptr, 1)
+                R_norm = _nrm2(n_samples * n_tasks, &R[0, 0], 1)
+                w_norm = _nrm2(n_features * n_tasks, W_ptr, 1)
                 if (dual_norm_XtA > l1_reg):
                     const = l1_reg / dual_norm_XtA
                     A_norm = R_norm * const
@@ -862,7 +784,7 @@ def enet_coordinate_descent_multi_task(floating[::1, :] W, floating l1_reg,
                 l21_norm = 0.0
                 for ii in range(n_features):
                     # np.sqrt(np.sum(W ** 2, axis=0))
-                    l21_norm += nrm2(n_tasks, W_ptr + n_tasks * ii, 1)
+                    l21_norm += _nrm2(n_tasks, W_ptr + n_tasks * ii, 1)
 
                 gap += l1_reg * l21_norm - const * ry_sum + \
                        0.5 * l2_reg * (1 + const ** 2) * (w_norm ** 2)

diff --git a/sklearn/linear_model/setup.py b/sklearn/linear_model/setup.py
index 9c3822b8e7561..74245d547ec1e 100644
--- a/sklearn/linear_model/setup.py
+++ b/sklearn/linear_model/setup.py
@@ -1,38 +1,26 @@
 import os
-from os.path import join
 
 import numpy
 
-from sklearn._build_utils import get_blas_info
-
 
 def configuration(parent_package='', top_path=None):
     from numpy.distutils.misc_util import Configuration
 
     config = Configuration('linear_model', parent_package, top_path)
 
-    cblas_libs, blas_info = get_blas_info()
-
+    libraries = []
     if os.name == 'posix':
-        cblas_libs.append('m')
+        libraries.append('m')
 
-    config.add_extension('cd_fast', sources=['cd_fast.pyx'],
-                         libraries=cblas_libs,
-                         include_dirs=[join('..', 'src', 'cblas'),
-                                       numpy.get_include(),
-                                       blas_info.pop('include_dirs', [])],
-                         extra_compile_args=blas_info.pop('extra_compile_args',
-                                                          []), **blas_info)
+    config.add_extension('cd_fast',
+                         sources=['cd_fast.pyx'],
+                         include_dirs=numpy.get_include(),
+                         libraries=libraries)
 
     config.add_extension('sgd_fast',
                          sources=['sgd_fast.pyx'],
-                         include_dirs=[join('..', 'src', 'cblas'),
-                                       numpy.get_include(),
-                                       blas_info.pop('include_dirs', [])],
-                         libraries=cblas_libs,
-                         extra_compile_args=blas_info.pop('extra_compile_args',
-                                                          []),
-                         **blas_info)
+                         include_dirs=numpy.get_include(),
+                         libraries=libraries)
 
     config.add_extension('sag_fast',
                          sources=['sag_fast.pyx'],
@@ -43,6 +31,7 @@ def configuration(parent_package='', top_path=None):
 
     return config
 
+
 if __name__ == '__main__':
     from numpy.distutils.core import setup
     setup(**configuration(top_path='').todict())
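Note: the k-means patch below keeps only `_dot`, because label assignment
expands the squared distance as in the code comment, precomputing `||b||^2`
once per center so each sample/center pair costs a single dot product. NumPy
check of the identity:

    import numpy as np

    a = np.array([1.0, 2.0])
    b = np.array([3.0, 0.5])
    assert np.isclose(((a - b) ** 2).sum(),
                      (a ** 2).sum() + (b ** 2).sum() - 2 * (a @ b))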
From ff5f99341ebc31120e4640846f90d936ef50356e Mon Sep 17 00:00:00 2001
From: jeremie du boisberranger
Date: Fri, 1 Feb 2019 15:49:52 +0100
Subject: [PATCH 8/9] cblas to scipy cython_blas in cluster/_k_means

---
 sklearn/cluster/_k_means.pyx | 14 ++++----------
 sklearn/cluster/setup.py     | 18 ++++--------------
 2 files changed, 8 insertions(+), 24 deletions(-)

diff --git a/sklearn/cluster/_k_means.pyx b/sklearn/cluster/_k_means.pyx
index 66fd620a90cdb..cfd81bc5f9e83 100644
--- a/sklearn/cluster/_k_means.pyx
+++ b/sklearn/cluster/_k_means.pyx
@@ -16,13 +16,11 @@ cimport cython
 from cython cimport floating
 
 from sklearn.utils.sparsefuncs_fast import assign_rows_csr
+from ..utils._cython_blas cimport _dot
 
 ctypedef np.float64_t DOUBLE
 ctypedef np.int32_t INT
 
-cdef extern from "cblas.h":
-    double ddot "cblas_ddot"(int N, double *X, int incX, double *Y, int incY)
-    float sdot "cblas_sdot"(int N, float *X, int incX, float *Y, int incY)
 
 np.import_array()
 
@@ -60,18 +58,16 @@ cpdef DOUBLE _assign_labels_array(np.ndarray[floating, ndim=2] X,
         center_squared_norms = np.zeros(n_clusters, dtype=np.float32)
         x_stride = X.strides[1] / sizeof(float)
         center_stride = centers.strides[1] / sizeof(float)
-        dot = sdot
     else:
         center_squared_norms = np.zeros(n_clusters, dtype=np.float64)
         x_stride = X.strides[1] / sizeof(DOUBLE)
         center_stride = centers.strides[1] / sizeof(DOUBLE)
-        dot = ddot
 
     if n_samples == distances.shape[0]:
         store_distances = 1
 
     for center_idx in range(n_clusters):
-        center_squared_norms[center_idx] = dot(
+        center_squared_norms[center_idx] = _dot(
             n_features, &centers[center_idx, 0], center_stride,
             &centers[center_idx, 0], center_stride)
 
@@ -81,7 +77,7 @@ cpdef DOUBLE _assign_labels_array(np.ndarray[floating, ndim=2] X,
             dist = 0.0
             # hardcoded: minimize euclidean distance to cluster center:
             # ||a - b||^2 = ||a||^2 + ||b||^2 -2 <a, b>
-            dist += dot(n_features, &X[sample_idx, 0], x_stride,
+            dist += _dot(n_features, &X[sample_idx, 0], x_stride,
                         &centers[center_idx, 0], center_stride)
             dist *= -2
             dist += center_squared_norms[center_idx]
@@ -129,16 +125,14 @@ cpdef DOUBLE _assign_labels_csr(X, np.ndarray[floating, ndim=1] sample_weight,
 
     if floating is float:
         center_squared_norms = np.zeros(n_clusters, dtype=np.float32)
-        dot = sdot
     else:
         center_squared_norms = np.zeros(n_clusters, dtype=np.float64)
-        dot = ddot
 
     if n_samples == distances.shape[0]:
         store_distances = 1
 
     for center_idx in range(n_clusters):
-        center_squared_norms[center_idx] = dot(
+        center_squared_norms[center_idx] = _dot(
             n_features, &centers[center_idx, 0], 1,
             &centers[center_idx, 0], 1)

diff --git a/sklearn/cluster/setup.py b/sklearn/cluster/setup.py
index 99c4dcd6177b0..c65489b89863d 100644
--- a/sklearn/cluster/setup.py
+++ b/sklearn/cluster/setup.py
@@ -1,21 +1,15 @@
 # Author: Alexandre Gramfort
 # License: BSD 3 clause
 import os
-from os.path import join
 
 import numpy
 
-from sklearn._build_utils import get_blas_info
-
 
 def configuration(parent_package='', top_path=None):
     from numpy.distutils.misc_util import Configuration
 
-    cblas_libs, blas_info = get_blas_info()
-
     libraries = []
     if os.name == 'posix':
-        cblas_libs.append('m')
         libraries.append('m')
 
     config = Configuration('cluster', parent_package, top_path)
@@ -29,26 +23,22 @@ def configuration(parent_package='', top_path=None):
                          language="c++",
                          include_dirs=[numpy.get_include()],
                          libraries=libraries)
+
     config.add_extension('_k_means_elkan',
                          sources=['_k_means_elkan.pyx'],
                          include_dirs=[numpy.get_include()],
                          libraries=libraries)
 
     config.add_extension('_k_means',
-                         libraries=cblas_libs,
                          sources=['_k_means.pyx'],
-                         include_dirs=[join('..', 'src', 'cblas'),
-                                       numpy.get_include(),
-                                       blas_info.pop('include_dirs', [])],
-                         extra_compile_args=blas_info.pop(
-                             'extra_compile_args', []),
-                         **blas_info
-                         )
+                         include_dirs=numpy.get_include(),
+                         libraries=libraries)
 
     config.add_subpackage('tests')
 
     return config
 
+
 if __name__ == '__main__':
     from numpy.distutils.core import setup
     setup(**configuration(top_path='').todict())
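Note: in NumPy terms the assignment loop in `_assign_labels_array` above
amounts to (a sketch; the real code also accumulates inertia and optional
per-sample distances):

    import numpy as np

    def assign_labels(X, centers):
        sq = (centers ** 2).sum(axis=1)     # center_squared_norms
        # ||x - c||^2 = ||x||^2 - 2<x, c> + ||c||^2; ||x||^2 is constant
        # per sample, so the argmin over centers can ignore it
        d = -2 * X @ centers.T + sq
        return d.argmin(axis=1)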
From 6525cec703e32ab26f758ee3298bbd8cbab29dc0 Mon Sep 17 00:00:00 2001
From: jeremie du boisberranger
Date: Fri, 1 Feb 2019 16:40:34 +0100
Subject: [PATCH 9/9] cleanup manifold setup

---
 sklearn/manifold/setup.py | 19 ++++++++-----------
 1 file changed, 8 insertions(+), 11 deletions(-)

diff --git a/sklearn/manifold/setup.py b/sklearn/manifold/setup.py
index bec1e25eee77c..457332b11a5f4 100644
--- a/sklearn/manifold/setup.py
+++ b/sklearn/manifold/setup.py
@@ -1,31 +1,28 @@
 import os
-from os.path import join
 
 import numpy
-from numpy.distutils.misc_util import Configuration
-from sklearn._build_utils import get_blas_info
 
 
 def configuration(parent_package="", top_path=None):
+    from numpy.distutils.misc_util import Configuration
+
     config = Configuration("manifold", parent_package, top_path)
+
     libraries = []
     if os.name == 'posix':
         libraries.append('m')
+
     config.add_extension("_utils",
                          sources=["_utils.pyx"],
                          include_dirs=[numpy.get_include()],
                          libraries=libraries,
                          extra_compile_args=["-O3"])
 
-    cblas_libs, blas_info = get_blas_info()
-    eca = blas_info.pop('extra_compile_args', [])
-    eca.append("-O4")
+
     config.add_extension("_barnes_hut_tsne",
-                         libraries=cblas_libs,
                          sources=["_barnes_hut_tsne.pyx"],
-                         include_dirs=[join('..', 'src', 'cblas'),
-                                       numpy.get_include(),
-                                       blas_info.pop('include_dirs', [])],
-                         extra_compile_args=eca, **blas_info)
+                         include_dirs=[numpy.get_include()],
+                         libraries=libraries,
+                         extra_compile_args=['-O4'])
 
     config.add_subpackage('tests')