diff --git a/build_tools/travis/test_script.sh b/build_tools/travis/test_script.sh
index cdcfbe01b3b8b..f7d3ab2a32e0e 100755
--- a/build_tools/travis/test_script.sh
+++ b/build_tools/travis/test_script.sh
@@ -43,10 +43,13 @@ run_tests() {
     fi
     $TEST_CMD sklearn
 
-    # Test doc (only with nose until we switch completely to pytest)
-    if [[ "$USE_PYTEST" != "true" ]]; then
-        # Going back to git checkout folder needed for make test-doc
-        cd $OLDPWD
+    # Going back to git checkout folder needed to test documentation
+    cd $OLDPWD
+
+    if [[ "$USE_PYTEST" == "true" ]]; then
+        pytest $(find doc -name '*.rst' | sort)
+    else
+        # Makefile is using nose
         make test-doc
     fi
 }
diff --git a/conftest.py b/conftest.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/doc/datasets/conftest.py b/doc/datasets/conftest.py
new file mode 100644
--- /dev/null
+++ b/doc/datasets/conftest.py
@@ -0,0 +1,97 @@
+"""Pytest hooks for the doctests in the ``doc/datasets`` pages.
+
+For each collected item, the hooks below look at the path of the file it
+comes from and run the setup or teardown function written for that page.
+"""
+import os
+from os.path import exists
+from os.path import join
+
+import numpy as np
+
+from sklearn.utils.testing import SkipTest
+from sklearn.utils.testing import check_skip_network
+from sklearn.datasets import get_data_home
+from sklearn.utils.testing import install_mldata_mock
+from sklearn.utils.testing import uninstall_mldata_mock
+
+
+def setup_labeled_faces():
+    """Skip unless an ``lfw_home`` folder exists in the data home."""
+    data_home = get_data_home()
+    if not exists(join(data_home, 'lfw_home')):
+        raise SkipTest("Skipping dataset loading doctests")
+
+
+def setup_mldata():
+    """Install the mldata mock with the datasets listed below."""
+    # setup mock urllib2 module to avoid downloading from mldata.org
+    install_mldata_mock({
+        'mnist-original': {
+            'data': np.empty((70000, 784)),
+            'label': np.repeat(np.arange(10, dtype='d'), 7000),
+        },
+        'iris': {
+            'data': np.empty((150, 4)),
+        },
+        'datasets-uci-iris': {
+            'double0': np.empty((150, 4)),
+            'class': np.empty((150,)),
+        },
+    })
+
+
+def teardown_mldata():
+    """Remove the mock installed by ``setup_mldata``."""
+    uninstall_mldata_mock()
+
+
+def setup_rcv1():
+    """Run the network check; skip unless ``RCV1`` is in the data home."""
+    check_skip_network()
+    # skip the test in rcv1.rst if the dataset is not already loaded
+    rcv1_dir = join(get_data_home(), "RCV1")
+    if not exists(rcv1_dir):
+        raise SkipTest("Download RCV1 dataset to run this test.")
+
+
+def setup_twenty_newsgroups():
+    """Skip unless a ``20news_home`` folder exists in the data home."""
+    data_home = get_data_home()
+    if not exists(join(data_home, '20news_home')):
+        raise SkipTest("Skipping dataset loading doctests")
+
+
+def setup_working_with_text_data():
+    """Run the network check for working_with_text_data.rst."""
+    check_skip_network()
+
+
+def _posix_path(item):
+    """Return the path of the file ``item`` comes from, with ``/`` only."""
+    # The suffixes tested by the hooks are spelled with forward slashes, so
+    # turn the platform separator (a backslash on Windows) into ``/`` first;
+    # otherwise no page would ever match there.
+    return item.fspath.strpath.replace(os.sep, '/')
+
+
+def pytest_runtest_setup(item):
+    """Run the setup function of the page ``item`` belongs to, if any."""
+    fname = _posix_path(item)
+    if fname.endswith('datasets/labeled_faces.rst'):
+        setup_labeled_faces()
+    elif fname.endswith('datasets/mldata.rst'):
+        setup_mldata()
+    elif fname.endswith('datasets/rcv1.rst'):
+        setup_rcv1()
+    elif fname.endswith('datasets/twenty_newsgroups.rst'):
+        setup_twenty_newsgroups()
+    elif fname.endswith('datasets/working_with_text_data.rst'):
+        setup_working_with_text_data()
+
+
+def pytest_runtest_teardown(item):
+    """Undo the mldata mock once an mldata.rst item has run."""
+    fname = _posix_path(item)
+    if fname.endswith('datasets/mldata.rst'):
+        teardown_mldata()
diff --git a/doc/datasets/mldata.rst b/doc/datasets/mldata.rst
index 5083317cffc53..b94dfd7620a24 100644
--- a/doc/datasets/mldata.rst
+++ b/doc/datasets/mldata.rst
@@ -3,6 +3,11 @@
 
     >>> import numpy as np
     >>> import os
+    >>> import tempfile
+    >>> # Create a temporary folder for the data fetcher
+    >>> custom_data_home = tempfile.mkdtemp()
+    >>> os.makedirs(os.path.join(custom_data_home, 'mldata'))
+
 
 .. _mldata:
 
@@ -70,3 +75,8 @@ defaults to individual datasets:
     ...                      data_home=custom_data_home)
     >>> iris3 = fetch_mldata('datasets-UCI iris', target_name='class',
     ...                      data_name='double0', data_home=custom_data_home)
+
+
+..
+    >>> import shutil
+    >>> shutil.rmtree(custom_data_home)
diff --git a/doc/datasets/mldata_fixture.py b/doc/datasets/mldata_fixture.py
index 37d9f9af05dc3..0ee5cccaa0f5e 100644
--- a/doc/datasets/mldata_fixture.py
+++ b/doc/datasets/mldata_fixture.py
@@ -3,26 +3,12 @@
 Mock urllib2 access to mldata.org and create a temporary data folder.
 """
 
-from os import makedirs
-from os.path import join
 import numpy as np
-import tempfile
-import shutil
 
-from sklearn import datasets
 from sklearn.utils.testing import install_mldata_mock
 from sklearn.utils.testing import uninstall_mldata_mock
 
 
-def globs(globs):
-    # Create a temporary folder for the data fetcher
-    global custom_data_home
-    custom_data_home = tempfile.mkdtemp()
-    makedirs(join(custom_data_home, 'mldata'))
-    globs['custom_data_home'] = custom_data_home
-    return globs
-
-
 def setup_module():
     # setup mock urllib2 module to avoid downloading from mldata.org
     install_mldata_mock({
@@ -42,4 +28,3 @@ def setup_module():
 
 def teardown_module():
     uninstall_mldata_mock()
-    shutil.rmtree(custom_data_home)