Description
TSNE with method="exact" ignores the n_iter_without_progress parameter.
As far as I can tell, the value is never passed to _gradient_descent: opt_args['n_iter_without_progress'] is only set when method="barnes_hut", not when method="exact".
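Below is a minimal pure-Python sketch of the suspected problem and the straightforward fix. It is not scikit-learn's code. `build_opt_args_buggy`, `build_opt_args`, `gradient_descent` and `flat` are hypothetical names. The patience rule (stop once the error has not improved for more than `n_iter_without_progress` iterations) and its default of 50 are assumptions, chosen to mirror the "last 50 episodes" message in the actual output below.

```python
# Hypothetical sketch, NOT scikit-learn's implementation: it only illustrates
# why a patience option that is set for one method but not the other
# silently falls back to the optimizer's default.


def gradient_descent(objective, p0, n_iter, n_iter_without_progress=50,
                     learning_rate=0.1):
    """Plain gradient descent with a patience-based early stop (assumed rule).

    Returns the final parameters and the iteration at which it stopped.
    """
    p = list(p0)
    best_error = float("inf")
    best_iter = 0
    for i in range(n_iter):
        error, grad = objective(p)
        if error < best_error:
            best_error, best_iter = error, i
        elif i - best_iter > n_iter_without_progress:
            # Patience exhausted: stop early.
            return p, i
        p = [pj - learning_rate * gj for pj, gj in zip(p, grad)]
    return p, n_iter


def build_opt_args_buggy(method, n_iter, n_iter_without_progress):
    # Mirrors the reported behavior: patience is only forwarded for barnes_hut,
    # so "exact" falls back to gradient_descent's default of 50.
    opt_args = {"n_iter": n_iter}
    if method == "barnes_hut":
        opt_args["n_iter_without_progress"] = n_iter_without_progress
    return opt_args


def build_opt_args(method, n_iter, n_iter_without_progress):
    # Fix: options shared by both methods are set unconditionally.
    if method not in ("exact", "barnes_hut"):
        raise ValueError("unknown method: %r" % (method,))
    return {"n_iter": n_iter,
            "n_iter_without_progress": n_iter_without_progress}


def flat(p):
    # Objective that never improves after the first evaluation.
    return 1.0, [0.0 for _ in p]


if __name__ == "__main__":
    _, stop_buggy = gradient_descent(
        flat, [0.0, 0.0], **build_opt_args_buggy("exact", 1000, 400))
    _, stop_fixed = gradient_descent(
        flat, [0.0, 0.0], **build_opt_args("exact", 1000, 400))
    print(stop_buggy)  # 51: the default patience of 50 was used
    print(stop_fixed)  # 401: the requested patience of 400 was honored
```

With a flat objective, the buggy path stops after the default patience even though 400 was requested. Setting the key for every method lets "exact" honor the user's value in the same way as "barnes_hut".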
Steps/Code to Reproduce
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.datasets import fetch_mldata
import numpy as np
# Load the MNIST dataset and scale pixel values to [0, 1]
mnist = fetch_mldata("MNIST original")
X, y = mnist.data / 255.0, mnist.target
# Project onto the first 50 principal components, then take a random subset of 1000 samples
indices = np.arange(X.shape[0])
np.random.shuffle(indices)
n_train_samples = 1000
X_pca = PCA(n_components=50).fit_transform(X)
X_train = X_pca[indices[:n_train_samples]]
# Exact t-SNE with a patience of 400 iterations; it should not stop after only 50 iterations without progress
X_train_embedded = TSNE(n_components=2, perplexity=40, verbose=2,
                        method='exact', n_iter_without_progress=400,
                        n_iter=1000).fit_transform(X_train)
Expected Results
[t-SNE] Iteration 225: did not make any progress during the last 400 episodes. Finished.
Actual Results
[t-SNE] Computing pairwise distances...
[t-SNE] Computed conditional probabilities for sample 1000 / 1000
[t-SNE] Mean sigma: 2.524818
[t-SNE] Iteration 25: error = 18.2748800, gradient norm = 0.0848871
[t-SNE] Iteration 50: error = 17.5069131, gradient norm = 0.0767555
[t-SNE] Iteration 75: error = 17.4644246, gradient norm = 0.0764557
[t-SNE] Iteration 100: error = 17.5864772, gradient norm = 0.0684978
[t-SNE] KL divergence after 100 iterations with early exaggeration: 17.586477
[t-SNE] Iteration 125: error = 1.0545128, gradient norm = 0.0067884
[t-SNE] Iteration 150: error = 1.0181383, gradient norm = 0.0080407
[t-SNE] Iteration 175: error = 1.0443270, gradient norm = 0.0086719
[t-SNE] Iteration 200: error = 1.0717048, gradient norm = 0.0089992
[t-SNE] Iteration 225: error = 1.0598099, gradient norm = 0.0094873
[t-SNE] Iteration 225: did not make any progress during the last 50 episodes. Finished.
[t-SNE] Error after 225 iterations: 1.059810
Versions
Darwin-15.5.0-x86_64-i386-64bit
('Python', '2.7.12 |Anaconda 4.2.0 (x86_64)| (default, Jul 2 2016, 17:43:17) \n[GCC 4.2.1 (Based on Apple Inc. build 5658) (LLVM build 2336.11.00)]')
('NumPy', '1.11.2')
('SciPy', '0.17.1')
('Scikit-Learn', '0.18')