|
1 | 1 | import itertools
|
2 | 2 | import re
|
| 3 | +import warnings |
3 | 4 | from collections import defaultdict
|
4 | 5 |
|
5 | 6 | import numpy as np
|
@@ -620,19 +621,44 @@ def test_argkmin_factory_method_wrong_usages():
|
620 | 621 | with pytest.raises(ValueError, match="ndarray is not C-contiguous"):
|
621 | 622 | ArgKmin.compute(X=np.asfortranarray(X), Y=Y, k=k, metric=metric)
|
622 | 623 |
|
| 624 | + # A UserWarning must be raised in this case. |
623 | 625 | unused_metric_kwargs = {"p": 3}
|
624 | 626 |
|
625 |
| - message = ( |
626 |
| - r"Some metric_kwargs have been passed \({'p': 3}\) but aren't usable for this" |
627 |
| - r" case \(" |
628 |
| - r"EuclideanArgKmin64." |
629 |
| - ) |
| 627 | + message = r"Some metric_kwargs have been passed \({'p': 3}\) but" |
630 | 628 |
|
631 | 629 | with pytest.warns(UserWarning, match=message):
|
632 | 630 | ArgKmin.compute(
|
633 | 631 | X=X, Y=Y, k=k, metric=metric, metric_kwargs=unused_metric_kwargs
|
634 | 632 | )
|
635 | 633 |
|
| 634 | + # A UserWarning must be raised in this case. |
| 635 | + metric_kwargs = { |
| 636 | + "p": 3, # unused |
| 637 | + "Y_norm_squared": sqeuclidean_row_norms(Y, num_threads=2), |
| 638 | + } |
| 639 | + |
| 640 | + message = r"Some metric_kwargs have been passed \({'p': 3, 'Y_norm_squared'" |
| 641 | + |
| 642 | + with pytest.warns(UserWarning, match=message): |
| 643 | + ArgKmin.compute(X=X, Y=Y, k=k, metric=metric, metric_kwargs=metric_kwargs) |
| 644 | + |
| 645 | + # No user warning must be raised in this case. |
| 646 | + metric_kwargs = { |
| 647 | + "X_norm_squared": sqeuclidean_row_norms(X, num_threads=2), |
| 648 | + } |
| 649 | + with warnings.catch_warnings(): |
| 650 | + warnings.simplefilter("error", category=UserWarning) |
| 651 | + ArgKmin.compute(X=X, Y=Y, k=k, metric=metric, metric_kwargs=metric_kwargs) |
| 652 | + |
| 653 | + # No user warning must be raised in this case. |
| 654 | + metric_kwargs = { |
| 655 | + "X_norm_squared": sqeuclidean_row_norms(X, num_threads=2), |
| 656 | + "Y_norm_squared": sqeuclidean_row_norms(Y, num_threads=2), |
| 657 | + } |
| 658 | + with warnings.catch_warnings(): |
| 659 | + warnings.simplefilter("error", category=UserWarning) |
| 660 | + ArgKmin.compute(X=X, Y=Y, k=k, metric=metric, metric_kwargs=metric_kwargs) |
| 661 | + |
636 | 662 |
|
637 | 663 | def test_radius_neighbors_factory_method_wrong_usages():
|
638 | 664 | rng = np.random.RandomState(1)
|
@@ -683,16 +709,48 @@ def test_radius_neighbors_factory_method_wrong_usages():
|
683 | 709 |
|
684 | 710 | unused_metric_kwargs = {"p": 3}
|
685 | 711 |
|
686 |
| - message = ( |
687 |
| - r"Some metric_kwargs have been passed \({'p': 3}\) but aren't usable for this" |
688 |
| - r" case \(EuclideanRadiusNeighbors64" |
689 |
| - ) |
| 712 | + # A UserWarning must be raised in this case. |
| 713 | + message = r"Some metric_kwargs have been passed \({'p': 3}\) but" |
690 | 714 |
|
691 | 715 | with pytest.warns(UserWarning, match=message):
|
692 | 716 | RadiusNeighbors.compute(
|
693 | 717 | X=X, Y=Y, radius=radius, metric=metric, metric_kwargs=unused_metric_kwargs
|
694 | 718 | )
|
695 | 719 |
|
| 720 | + # A UserWarning must be raised in this case. |
| 721 | + metric_kwargs = { |
| 722 | + "p": 3, # unused |
| 723 | + "Y_norm_squared": sqeuclidean_row_norms(Y, num_threads=2), |
| 724 | + } |
| 725 | + |
| 726 | + message = r"Some metric_kwargs have been passed \({'p': 3, 'Y_norm_squared'" |
| 727 | + |
| 728 | + with pytest.warns(UserWarning, match=message): |
| 729 | + RadiusNeighbors.compute( |
| 730 | + X=X, Y=Y, radius=radius, metric=metric, metric_kwargs=metric_kwargs |
| 731 | + ) |
| 732 | + |
| 733 | + # No user warning must be raised in this case. |
| 734 | + metric_kwargs = { |
| 735 | + "X_norm_squared": sqeuclidean_row_norms(X, num_threads=2), |
| 736 | + "Y_norm_squared": sqeuclidean_row_norms(Y, num_threads=2), |
| 737 | + } |
| 738 | + with warnings.catch_warnings(): |
| 739 | + warnings.simplefilter("error", category=UserWarning) |
| 740 | + RadiusNeighbors.compute( |
| 741 | + X=X, Y=Y, radius=radius, metric=metric, metric_kwargs=metric_kwargs |
| 742 | + ) |
| 743 | + |
| 744 | + # No user warning must be raised in this case. |
| 745 | + metric_kwargs = { |
| 746 | + "X_norm_squared": sqeuclidean_row_norms(X, num_threads=2), |
| 747 | + } |
| 748 | + with warnings.catch_warnings(): |
| 749 | + warnings.simplefilter("error", category=UserWarning) |
| 750 | + RadiusNeighbors.compute( |
| 751 | + X=X, Y=Y, radius=radius, metric=metric, metric_kwargs=metric_kwargs |
| 752 | + ) |
| 753 | + |
696 | 754 |
|
697 | 755 | @pytest.mark.parametrize(
|
698 | 756 | "n_samples_X, n_samples_Y", [(100, 100), (500, 100), (100, 500)]
|
|
0 commit comments