@@ -33,10 +33,10 @@ layer transforms the values from the previous layer by a weighted linear summati
:math:`g(\cdot): R \rightarrow R` - like the hyperbolic tan function. The output layer
receives the values from the last hidden layer and transforms them into output values.

- The module contains the public attributes ``layers_coef_`` and ``layers_intercept_``.
- ``layers_coef_`` is a list of weight matrices, where weight matrix at index
+ The module contains the public attributes ``coefs_`` and ``intercepts_``.
+ ``coefs_`` is a list of weight matrices, where weight matrix at index
:math:`i` represents the weights between layer :math:`i` and layer
- :math:`i+1`. ``layers_intercept_`` is a list of bias vectors, where the vector
+ :math:`i+1`. ``intercepts_`` is a list of bias vectors, where the vector
at index :math:`i` represents the bias values added to layer :math:`i+1`.
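
Schematically, writing the values of layer :math:`i` as a row vector
:math:`h_i`, the transformation described above amounts to

.. math::

    h_{i+1} = g(h_i W_i + b_i)

where :math:`W_i` is the weight matrix ``coefs_[i]`` and :math:`b_i` is the
bias vector ``intercepts_[i]``; the output layer applies its own output
activation in place of :math:`g`.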

The advantages of Multi-layer Perceptron are:
@@ -68,45 +68,45 @@ some of these disadvantages.
Classification
==============

- Class :class:`MultilayerPerceptronClassifier` implements
+ Class :class:`MLPClassifier` implements
a multi-layer perceptron (MLP) algorithm that trains using backpropagation.

MLP trains on two arrays: array X of size (n_samples, n_features), which holds
the training samples represented as floating point feature vectors; and array
y of size (n_samples,), which holds the target values (class labels) for the
training samples::

- >>> from sklearn.neural_network import MultilayerPerceptronClassifier
+ >>> from sklearn.neural_network import MLPClassifier
>>> X = [[0., 0.], [1., 1.]]
>>> y = [0, 1]
- >>> clf = MultilayerPerceptronClassifier(hidden_layer_sizes=(5, 2), random_state=1)
- >>> clf.fit(X, y)
- MultilayerPerceptronClassifier(activation='relu', algorithm='l-bfgs',
-        alpha=1e-05, batch_size=200, hidden_layer_sizes=(5, 2),
-        learning_rate='constant', learning_rate_init=0.5,
-        max_iter=200, power_t=0.5, random_state=1, shuffle=False,
-        tol=1e-05, verbose=False, warm_start=False)
+ >>> clf = MLPClassifier(hidden_layer_sizes=(5, 2), random_state=1)
+ >>> clf.fit(X, y) # doctest: +NORMALIZE_WHITESPACE
+ MLPClassifier(activation='relu', algorithm='l-bfgs', alpha=1e-05,
+        batch_size=200, early_stopping=False, hidden_layer_sizes=(5, 2),
+        learning_rate='constant', learning_rate_init=0.2, max_iter=200,
+        momentum=0.9, nesterovs_momentum=True, power_t=0.5, random_state=1,
+        shuffle=True, tol=0.0001, verbose=False, warm_start=False)

After fitting (training), the model can predict labels for new samples::

>>> clf.predict([[2., 2.], [-1., -2.]])
array([1, 0])

- MLP can fit a non-linear model to the training data. ``clf.layers_coef_``
+ MLP can fit a non-linear model to the training data. ``clf.coefs_``
contains the weight matrices that constitute the model parameters::

- >>> [coef.shape for coef in clf.layers_coef_]
+ >>> [coef.shape for coef in clf.coefs_]
[(2, 5), (5, 2), (2, 1)]

To get the raw values before applying the output activation function,
- use :meth:`MultilayerPerceptronClassifier.decision_function`::
+ use :meth:`MLPClassifier.decision_function`::

>>> clf.decision_function([[2., 2.], [1., 2.]]) # doctest: +ELLIPSIS
- array([ 11.55..., 11.55...])
+ array([ 47.6..., 47.6...])

- Currently, :class:`MultilayerPerceptronClassifier` supports only the
+ Currently, :class:`MLPClassifier` supports only the
Cross-Entropy loss function, which allows probability estimates by running the
``predict_proba`` method.
@@ -115,36 +115,36 @@ Cross-Entropy loss function, giving a vector of probability estimates
:math:`P(y|x)` per sample :math:`x`::

>>> clf.predict_proba([[2., 2.], [1., 2.]]) # doctest: +ELLIPSIS
- array([[ 9.5...e-06,  9.99...e-01],
-        [ 9.5...e-06,  9.99...e-01]])
+ array([[ 0.,  1.],
+        [ 0.,  1.]])

- :class:`MultilayerPerceptronClassifier` supports multi-class classification by
+ :class:`MLPClassifier` supports multi-class classification by
applying `Softmax <http://en.wikipedia.org/wiki/Softmax_activation_function>`_
as the output function.
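
For reference, Softmax turns the raw output value :math:`z_k` computed for
each class :math:`k` into a probability:

.. math::

    P(y = k \mid x) = \frac{\exp(z_k)}{\sum_{l=1}^{K} \exp(z_l)}

where :math:`K` is the number of classes, so the estimates returned by
``predict_proba`` sum to one for each sample.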

Further, the algorithm supports :ref:`multi-label classification <multiclass>`
in which a sample can belong to more than one class. For each class, the output
- of :meth:`MultilayerPerceptronClassifier.decision_function` passes through the
+ of :meth:`MLPClassifier.decision_function` passes through the
logistic function. Values larger than or equal to `0.5` are rounded to `1`,
otherwise to `0`. For a predicted output of a sample, the indices where the
value is `1` represent the assigned classes of that sample::

>>> X = [[0., 0.], [1., 1.]]
- >>> y = [[0, 1], [1]]
- >>> clf = MultilayerPerceptronClassifier(hidden_layer_sizes=(15,), random_state=1)
+ >>> y = [[0, 1], [1, 1]]
+ >>> clf = MLPClassifier(hidden_layer_sizes=(15,), random_state=1)
>>> clf.fit(X, y) # doctest: +NORMALIZE_WHITESPACE
- MultilayerPerceptronClassifier(activation='relu', algorithm='l-bfgs',
-        alpha=1e-05, batch_size=200, hidden_layer_sizes=(15,),
-        learning_rate='constant', learning_rate_init=0.5,
-        max_iter=200, power_t=0.5, random_state=1, shuffle=False,
-        tol=1e-05, verbose=False, warm_start=False)
+ MLPClassifier(activation='relu', algorithm='l-bfgs', alpha=1e-05,
+        batch_size=200, early_stopping=False, hidden_layer_sizes=(15,),
+        learning_rate='constant', learning_rate_init=0.2, max_iter=200,
+        momentum=0.9, nesterovs_momentum=True, power_t=0.5, random_state=1,
+        shuffle=True, tol=0.0001, verbose=False, warm_start=False)
>>> clf.predict([[1., 2.]])
- [(1,)]
+ array([[1, 1]])
>>> clf.predict([[0., 0.]])
- [(0, 1)]
+ array([[0, 1]])

See the examples below and the doc string of
- :meth:`MultilayerPerceptronClassifier.fit` for further information.
+ :meth:`MLPClassifier.fit` for further information.

.. topic:: Examples:
@@ -155,12 +155,12 @@ See the examples below and the doc string of
Regression
==========

- Class :class:`MultilayerPerceptronRegressor` implements
+ Class :class:`MLPRegressor` implements
a multi-layer perceptron (MLP) that trains using backpropagation with no
activation function in the output layer. Therefore, it uses the square error as
the loss function, and the output is a set of continuous values.

- :class:`MultilayerPerceptronRegressor` also supports multi-output regression, in
+ :class:`MLPRegressor` also supports multi-output regression, in
which a sample can have more than one target.
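
Fitting and prediction follow the same estimator API as the classifier. A
minimal sketch (the data and hyperparameters here are arbitrary, and the
fitted repr and predicted values are elided)::

>>> from sklearn.neural_network import MLPRegressor
>>> X = [[0., 0.], [1., 1.]]
>>> y = [0.5, 2.5]
>>> regr = MLPRegressor(hidden_layer_sizes=(10,), random_state=1)
>>> regr.fit(X, y) # doctest: +ELLIPSIS
MLPRegressor(...)
>>> regr.predict([[2., 2.]]) # doctest: +ELLIPSIS
array([...])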
@@ -308,9 +308,22 @@ Tips on Practical Use
* Empirically, we observed that `L-BFGS` converges faster and
  with better solutions than `SGD`. Therefore, if mini-batch
  and online learning are unnecessary, it is best advised
- to set :meth:`MultilayerPerceptronClassifier.algorithm` as
+ to set :meth:`MLPClassifier.algorithm` as
  'l-bfgs', as in the sketch below.
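
For instance (a minimal sketch; all other parameters are left at their
defaults)::

>>> from sklearn.neural_network import MLPClassifier
>>> clf = MLPClassifier(algorithm='l-bfgs', hidden_layer_sizes=(5,), random_state=1)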

+ More control with warm_start
+ ============================
+ If you want more control over stopping criteria or learning rate in SGD,
+ or want to do additional monitoring, using ``warm_start=True`` and
+ ``max_iter=1`` and iterating yourself can be helpful::
+
+ >>> X = [[0., 0.], [1., 1.]]
+ >>> y = [0, 1]
+ >>> clf = MLPClassifier(hidden_layer_sizes=(15,), random_state=1, max_iter=1, warm_start=True)
+ >>> for i in range(10):
+ ...     clf.fit(X, y)
+ ...     # additional monitoring / inspection # doctest: +ELLIPSIS
+ MLPClassifier(...

.. topic:: References: