3
3
from typing import Optional
4
4
5
5
import pytest
6
- from sklearn .utils . _testing import all_estimators
6
+ from sklearn .utils import all_estimators
7
7
8
8
numpydoc_validation = pytest .importorskip ("numpydoc.validate" )
9
9
10
- # List of whitelisted modules and methods; regexp are supported.
11
- DOCSTRING_WHITELIST = [
12
- "LogisticRegression$" ,
13
- "LogisticRegression.fit" ,
14
- "LogisticRegression.decision_function" ,
15
- "Birch.predict" ,
16
- "Birch.transform" ,
10
+ # List of estimators (regex patterns) ignored when checking for numpydoc validation.
11
+ DOCSTRING_IGNORE_LIST = [
12
+ "ARDRegression" ,
13
+ "AdaBoostClassifier" ,
14
+ "AdaBoostRegressor" ,
15
+ "AdditiveChi2Sampler" ,
16
+ "AffinityPropagation" ,
17
+ "AgglomerativeClustering" ,
18
+ "BaggingClassifier" ,
19
+ "BaggingRegressor" ,
20
+ "BayesianGaussianMixture" ,
21
+ "BayesianRidge" ,
22
+ "BernoulliNB" ,
23
+ "BernoulliRBM" ,
24
+ "Binarizer" ,
25
+ "Birch" ,
26
+ "CCA" ,
27
+ "CalibratedClassifierCV" ,
28
+ "CategoricalNB" ,
29
+ "ClassifierChain" ,
30
+ "ColumnTransformer" ,
31
+ "ComplementNB" ,
32
+ "CountVectorizer" ,
33
+ "DBSCAN" ,
34
+ "DecisionTreeClassifier" ,
35
+ "DecisionTreeRegressor" ,
36
+ "DictVectorizer" ,
37
+ "DictionaryLearning" ,
38
+ "DummyClassifier" ,
39
+ "DummyRegressor" ,
40
+ "ElasticNet" ,
41
+ "ElasticNetCV" ,
42
+ "EllipticEnvelope" ,
43
+ "EmpiricalCovariance" ,
44
+ "ExtraTreeClassifier" ,
45
+ "ExtraTreeRegressor" ,
46
+ "ExtraTreesClassifier" ,
47
+ "ExtraTreesRegressor" ,
48
+ "FactorAnalysis" ,
49
+ "FastICA" ,
50
+ "FeatureAgglomeration" ,
51
+ "FeatureHasher" ,
52
+ "FeatureUnion" ,
53
+ "FunctionTransformer" ,
54
+ "GammaRegressor" ,
55
+ "GaussianMixture" ,
56
+ "GaussianNB" ,
57
+ "GaussianProcessClassifier" ,
58
+ "GaussianProcessRegressor" ,
59
+ "GaussianRandomProjection" ,
60
+ "GenericUnivariateSelect" ,
17
61
"GradientBoostingClassifier" ,
18
62
"GradientBoostingRegressor" ,
19
- "LinearDiscriminantAnalysis.decision_function" ,
20
- "LinearSVC.decision_function" ,
21
- "LogisticRegressionCV.decision_function" ,
22
- "OPTICS" ,
23
- "OPTICS.fit" ,
24
- "PassiveAggressiveClassifier.decision_function" ,
25
- "Perceptron.decision_function" ,
26
- "RidgeClassifier.decision_function" ,
27
- "RidgeClassifier.fit" ,
28
- "RidgeClassifierCV.decision_function" ,
63
+ "GraphicalLasso" ,
64
+ "GraphicalLassoCV" ,
65
+ "GridSearchCV" ,
66
+ "HalvingGridSearchCV" ,
67
+ "HalvingRandomSearchCV" ,
68
+ "HashingVectorizer" ,
69
+ "HistGradientBoostingClassifier" ,
70
+ "HistGradientBoostingRegressor" ,
71
+ "HuberRegressor" ,
72
+ "IncrementalPCA" ,
73
+ "IsolationForest" ,
74
+ "Isomap" ,
75
+ "IsotonicRegression" ,
76
+ "IterativeImputer" ,
77
+ "KBinsDiscretizer" ,
78
+ "KMeans" ,
79
+ "KNNImputer" ,
80
+ "KNeighborsClassifier" ,
81
+ "KNeighborsRegressor" ,
82
+ "KNeighborsTransformer" ,
83
+ "KernelCenterer" ,
29
84
"KernelDensity" ,
30
- "KernelDensity.fit" ,
31
- "KernelDensity.score" ,
32
- "DecisionTreeClassifier" ,
33
- "DecisionTreeRegressor" ,
34
- "LinearRegression$" ,
35
- "SGDClassifier.decision_function" ,
36
- "SGDClassifier.set_params" ,
37
- "SGDClassifier.get_params" ,
38
- "SGDClassifier.fit" ,
39
- "SGDClassifier.partial_fit" ,
40
- "SGDClassifier.predict$" , # $ to avoid match w/ predict_proba (regex)
41
- "SGDClassifier.score" ,
42
- "SGDClassifier.sparsify" ,
43
- "SGDClassifier.densify" ,
44
- "VotingClassifier.fit" ,
45
- "VotingClassifier.transform" ,
46
- "VotingClassifier.predict" ,
47
- "VotingClassifier.score" ,
48
- "VotingClassifier.predict_proba" ,
49
- "VotingClassifier.set_params" ,
50
- "VotingClassifier.get_params" ,
51
- "VotingClassifier.named_estimators" ,
52
- "VotingClassifier$" ,
85
+ "KernelPCA" ,
86
+ "KernelRidge" ,
87
+ "LabelBinarizer" ,
88
+ "LabelEncoder" ,
89
+ "LabelPropagation" ,
90
+ "LabelSpreading" ,
91
+ "Lars" ,
92
+ "LarsCV" ,
93
+ "Lasso" ,
94
+ "LassoCV" ,
95
+ "LassoLars" ,
96
+ "LassoLarsCV" ,
97
+ "LassoLarsIC" ,
98
+ "LatentDirichletAllocation" ,
99
+ "LedoitWolf" ,
100
+ "LinearDiscriminantAnalysis" ,
101
+ "LinearRegression" ,
102
+ "LinearSVC" ,
103
+ "LinearSVR" ,
104
+ "LocalOutlierFactor" ,
105
+ "LocallyLinearEmbedding" ,
106
+ "LogisticRegression" ,
107
+ "LogisticRegressionCV" ,
108
+ "MDS" ,
109
+ "MLPClassifier" ,
110
+ "MLPRegressor" ,
111
+ "MaxAbsScaler" ,
112
+ "MeanShift" ,
113
+ "MinCovDet" ,
114
+ "MinMaxScaler" ,
115
+ "MiniBatchDictionaryLearning" ,
116
+ "MiniBatchKMeans" ,
117
+ "MiniBatchSparsePCA" ,
118
+ "MissingIndicator" ,
119
+ "MultiLabelBinarizer" ,
120
+ "MultiOutputClassifier" ,
121
+ "MultiOutputRegressor" ,
122
+ "MultiTaskElasticNet" ,
123
+ "MultiTaskElasticNetCV" ,
124
+ "MultiTaskLasso" ,
125
+ "MultiTaskLassoCV" ,
126
+ "MultinomialNB" ,
127
+ "NMF" ,
128
+ "NearestCentroid" ,
129
+ "NearestNeighbors" ,
130
+ "NeighborhoodComponentsAnalysis" ,
131
+ "Normalizer" ,
132
+ "NuSVC" ,
133
+ "NuSVR" ,
134
+ "Nystroem" ,
135
+ "OAS" ,
136
+ "OPTICS" ,
137
+ "OneClassSVM" ,
138
+ "OneHotEncoder" ,
139
+ "OneVsOneClassifier" ,
140
+ "OneVsRestClassifier" ,
141
+ "OrdinalEncoder" ,
142
+ "OrthogonalMatchingPursuit" ,
143
+ "OrthogonalMatchingPursuitCV" ,
144
+ "OutputCodeClassifier" ,
145
+ "PCA" ,
146
+ "PLSCanonical" ,
147
+ "PLSRegression" ,
148
+ "PLSSVD" ,
149
+ "PassiveAggressiveClassifier" ,
150
+ "PassiveAggressiveRegressor" ,
151
+ "PatchExtractor" ,
152
+ "Perceptron" ,
153
+ "Pipeline" ,
154
+ "PoissonRegressor" ,
155
+ "PolynomialCountSketch" ,
156
+ "PolynomialFeatures" ,
157
+ "PowerTransformer" ,
158
+ "QuadraticDiscriminantAnalysis" ,
159
+ "QuantileRegressor" ,
160
+ "QuantileTransformer" ,
161
+ "RANSACRegressor" ,
162
+ "RBFSampler" ,
163
+ "RFE" ,
164
+ "RFECV" ,
165
+ "RadiusNeighborsClassifier" ,
166
+ "RadiusNeighborsRegressor" ,
167
+ "RadiusNeighborsTransformer" ,
168
+ "RandomForestClassifier" ,
169
+ "RandomForestRegressor" ,
170
+ "RandomTreesEmbedding" ,
171
+ "RandomizedSearchCV" ,
172
+ "RegressorChain" ,
173
+ "Ridge" ,
174
+ "RidgeCV" ,
175
+ "RidgeClassifier" ,
176
+ "RidgeClassifierCV" ,
177
+ "RobustScaler" ,
178
+ "SGDOneClassSVM" ,
179
+ "SGDRegressor" ,
180
+ "SVC" ,
181
+ "SVR" ,
182
+ "SelectFdr" ,
183
+ "SelectFpr" ,
184
+ "SelectFromModel" ,
185
+ "SelectFwe" ,
186
+ "SelectKBest" ,
187
+ "SelectPercentile" ,
188
+ "SelfTrainingClassifier" ,
189
+ "SequentialFeatureSelector" ,
190
+ "ShrunkCovariance" ,
191
+ "SimpleImputer" ,
192
+ "SkewedChi2Sampler" ,
193
+ "SparseCoder" ,
194
+ "SparsePCA" ,
195
+ "SparseRandomProjection" ,
196
+ "SpectralBiclustering" ,
197
+ "SpectralClustering" ,
198
+ "SpectralCoclustering" ,
199
+ "SpectralEmbedding" ,
200
+ "SplineTransformer" ,
201
+ "StackingClassifier" ,
202
+ "StackingRegressor" ,
203
+ "StandardScaler" ,
204
+ "TSNE" ,
205
+ "TfidfTransformer" ,
206
+ "TfidfVectorizer" ,
207
+ "TheilSenRegressor" ,
208
+ "TransformedTargetRegressor" ,
209
+ "TruncatedSVD" ,
210
+ "TweedieRegressor" ,
211
+ "VarianceThreshold" ,
212
+ "VotingClassifier" ,
213
+ "VotingRegressor" ,
53
214
]
54
215
55
216
@@ -72,7 +233,7 @@ def get_all_methods():
72
233
yield Estimator , method
73
234
74
235
75
- def filter_errors (errors , method ):
236
+ def filter_errors (errors , method , Estimator = None ):
76
237
"""
77
238
Ignore some errors based on the method type.
78
239
@@ -90,6 +251,13 @@ def filter_errors(errors, method):
90
251
if code in ["RT02" , "GL01" ]:
91
252
continue
92
253
254
+ # Ignore PR02: Unknown parameters for properties. We sometimes use
255
+ # properties for ducktyping, e.g. SGDClassifier.predict_proba
256
+ if code == "PR02" and Estimator is not None and method is not None :
257
+ method_obj = getattr (Estimator , method )
258
+ if isinstance (method_obj , property ):
259
+ continue
260
+
93
261
# Following codes are only taken into account for the
94
262
# top level class docstrings:
95
263
# - ES01: No extended summary found
@@ -165,14 +333,14 @@ def test_docstring(Estimator, method, request):
165
333
166
334
import_path = "." .join (import_path )
167
335
168
- if not any (re .search (regex , import_path ) for regex in DOCSTRING_WHITELIST ):
336
+ if any (re .search (regex , import_path ) for regex in DOCSTRING_IGNORE_LIST ):
169
337
request .applymarker (
170
338
pytest .mark .xfail (run = False , reason = "TODO pass numpydoc validation" )
171
339
)
172
340
173
341
res = numpydoc_validation .validate (import_path )
174
342
175
- res ["errors" ] = list (filter_errors (res ["errors" ], method ))
343
+ res ["errors" ] = list (filter_errors (res ["errors" ], method , Estimator = Estimator ))
176
344
177
345
if res ["errors" ]:
178
346
msg = repr_errors (res , Estimator , method )
0 commit comments