Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit d128dd0

Browse files
committed
MNT: Use GEMV in enet_coordinate_descent
Make use of the BLAS GEMV operation in `enet_coordinate_descent` instead of calling DOT in a `for`-loop. The two approaches are semantically equivalent, but GEMV is likely to be multithreaded in BLAS implementations, whereas the `for`-loop over DOT is merely a serial loop.
1 parent a85eeb2 commit d128dd0

File tree

1 file changed

+14
-5
lines changed

1 file changed

+14
-5
lines changed

sklearn/linear_model/cd_fast.pyx

Lines changed: 14 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -159,14 +159,18 @@ def enet_coordinate_descent(floating[::1] w,
159159
# fused types version of BLAS functions
160160
if floating is float:
161161
dtype = np.float32
162+
gemv = sgemv
162163
dot = sdot
163164
axpy = saxpy
164165
asum = sasum
166+
copy = scopy
165167
else:
166168
dtype = np.float64
169+
gemv = dgemv
167170
dot = ddot
168171
axpy = daxpy
169172
asum = dasum
173+
copy = dcopy
170174

171175
# get the data information into easy vars
172176
cdef unsigned int n_samples = X.shape[0]
@@ -205,8 +209,11 @@ def enet_coordinate_descent(floating[::1] w,
205209

206210
with nogil:
207211
# R = y - np.dot(X, w)
208-
for i in range(n_samples):
209-
R[i] = y[i] - dot(n_features, &X[i, 0], n_samples, &w[0], 1)
212+
copy(n_samples, &y[0], 1, &R[0], 1)
213+
gemv(CblasColMajor, CblasNoTrans,
214+
n_samples, n_features, -1.0, &X[0, 0], n_samples,
215+
&w[0], 1,
216+
1.0, &R[0], 1)
210217

211218
# tol *= np.dot(y, y)
212219
tol *= dot(n_samples, &y[0], 1, &y[0], 1)
@@ -258,9 +265,11 @@ def enet_coordinate_descent(floating[::1] w,
258265
# stopping criterion
259266

260267
# XtA = np.dot(X.T, R) - beta * w
261-
for i in range(n_features):
262-
XtA[i] = (dot(n_samples, &X[0, i], 1, &R[0], 1)
263-
- beta * w[i])
268+
copy(n_features, &w[0], 1, &XtA[0], 1)
269+
gemv(CblasColMajor, CblasTrans,
270+
n_samples, n_features, 1.0, &X[0, 0], n_samples,
271+
&R[0], 1,
272+
-beta, &XtA[0], 1)
264273

265274
if positive:
266275
dual_norm_XtA = max(n_features, &XtA[0])

0 commit comments

Comments
 (0)