Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
571f4af
Make helper functions in cd use fused types
yenchenlin Jun 21, 2016
a2756a3
Import cblas float functions
yenchenlin Jun 21, 2016
d204dea
Make enet_coordinate_descent support fused types
yenchenlin Jun 21, 2016
c9aa51e
Make dense case work
yenchenlin Jul 2, 2016
4540ebc
Refactor format
yenchenlin Jul 2, 2016
4c1829c
Remove redundant change
yenchenlin Jul 2, 2016
efdda45
Add cblas files
yenchenlin Jul 6, 2016
a5249b6
Avoid redundant code
yenchenlin Jul 7, 2016
8339aa1
Remove redundant c files and import
yenchenlin Jul 7, 2016
9b8f470
Recover unnecessary change
yenchenlin Jul 7, 2016
9ac624d
Update comment
yenchenlin Jul 7, 2016
f992209
Make coef_ type consistent
yenchenlin Jul 8, 2016
5bfeb93
Test float32 input
yenchenlin Jul 8, 2016
2310766
Add user warning when fitting float32 data with small alpha
yenchenlin Jul 8, 2016
2ff201e
Fix bug
yenchenlin Jul 8, 2016
38c4d06
Change variable to floating type
yenchenlin Jul 8, 2016
75da365
Make cd sparse support fused types
yenchenlin Jul 8, 2016
e9bee9d
Make CD support fused types when data is sparse
yenchenlin Jul 8, 2016
cc9df4a
Add referenced src files
yenchenlin Jul 18, 2016
b22d676
Avoid type casting
yenchenlin Jul 18, 2016
d5c8d37
Fix indentation in test
yenchenlin Jul 18, 2016
0dcf4da
Avoid duplicated code
yenchenlin Jul 18, 2016
0c0eef8
Avoid type casting in sparse implementation
yenchenlin Jul 18, 2016
cde1d2b
Fix indentation
yenchenlin Jul 18, 2016
7bfe714
Fix duplicated intialization code
yenchenlin Jul 18, 2016
e65bec0
Follow PEP8
yenchenlin Jul 19, 2016
f4b247b
Raise tmp precision to double
yenchenlin Jul 22, 2016
e948157
Add 64 bit computer check
yenchenlin Jul 24, 2016
6a15fa6
Fix test
yenchenlin Jul 25, 2016
1591b0c
Add constraint
yenchenlin Jul 25, 2016
4ffaac0
PEP 8
yenchenlin Jul 27, 2016
3d2790e
Make saxpy have the same structure as daxpy
fabianp Aug 11, 2016
c745af4
Remove wrong hardware test
yenchenlin Aug 12, 2016
8b04c53
Remove dsdot
yenchenlin Aug 12, 2016
b035f34
Remove redundant asarray
yenchenlin Aug 15, 2016
c967912
Add test for fit_intercept
yenchenlin Aug 15, 2016
45b4aaa
Make _preprocess_data support other dtypes
yenchenlin Aug 16, 2016
dd4a42e
Add concrete value
yenchenlin Aug 16, 2016
23b6c2a
Workaround
yenchenlin Aug 18, 2016
116ec79
Fix error msg
yenchenlin Aug 21, 2016
470d8ab
Move declarartion
yenchenlin Aug 21, 2016
f868af7
Remove redundant comment
yenchenlin Aug 21, 2016
0e88af2
Add tests
yenchenlin Aug 21, 2016
14237e8
Test normalize
yenchenlin Aug 22, 2016
b4b9cf1
Delete warning
yenchenlin Aug 22, 2016
9348ad7
Fix comment
yenchenlin Aug 23, 2016
82fdf09
Add error msg
yenchenlin Aug 23, 2016
d0b56bb
Add error msg
yenchenlin Aug 24, 2016
611b412
Add what's new
yenchenlin Aug 24, 2016
00cadb6
Fix error msg
yenchenlin Aug 25, 2016
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,12 @@ Enhancements
generating attribute ``estimators_samples_`` only when it is needed.
By `David Staub`_.

- :class:`linear_model.ElasticNet` and :class:`linear_model.Lasso`
now works with ``np.float32`` input data without converting it
into ``np.float64``. This allows to reduce the memory
consumption.
(`#6913 <https://github.com/scikit-learn/scikit-learn/pull/6913>`_)
By `YenChen Lin`_.

Bug fixes
.........
Expand Down
Loading