Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add tests
  • Loading branch information
scottgigante committed Jan 22, 2018
commit 6b8e950b229abc2744a6892a80f58313231d3c41
60 changes: 59 additions & 1 deletion test/test_pyemd.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np
import pytest

from pyemd import emd, emd_with_flow
from pyemd import emd, emd_with_flow, emd_samples


EMD_PRECISION = 5
Expand Down Expand Up @@ -146,6 +146,64 @@ def test_extra_mass_penalty_flow():
[0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 1.0, 1.0]])


def test_case_samples():
first_array = [1,2,3,4]
second_array = [2,3,4,5]
emd_value = emd_samples(first_array, second_array)
assert round(emd_value, EMD_PRECISION) == 0.75


def test_case_samples_binsize():
first_array = [1,2,3,4]
second_array = [2,3,4,5]
emd_value = emd_samples(first_array, second_array, bins=2)
assert round(emd_value, EMD_PRECISION) == 0.5


def test_case_samples_manual_range():
first_array = [1,2,3,4]
second_array = [2,3,4,5]
emd_value = emd_samples(first_array, second_array, range=(0,10))
assert round(emd_value, EMD_PRECISION) == 1.0


def test_case_samples_not_normalized():
first_array = [1,2,3,4]
second_array = [2,3,4,5]
emd_value = emd_samples(first_array, second_array, normalized=False)
assert round(emd_value, EMD_PRECISION) == 3.0


def test_case_samples_custom_distance():
dist = lambda x : np.array([[0. if i == j else 1. for i in x] for j in x])
first_array = [1,2,3,4]
second_array = [2,3,4,5]
emd_value = emd_samples(first_array, second_array, distance=dist)
assert round(emd_value, EMD_PRECISION) == 0.25


def test_case_samples_2():
first_array = [1]
second_array = [2]
emd_value = emd_samples(first_array, second_array)
assert round(emd_value, EMD_PRECISION) == 0.5


def test_case_samples_3():
first_array = [1,1,1,2,3]
second_array = [1,2,2,2,3]
emd_value = emd_samples(first_array, second_array)
assert round(emd_value, EMD_PRECISION) == 0.32


def test_case_samples_4():
first_array = [1,2,3,4,5]
second_array = [99,98,97,96,95]
emd_value = emd_samples(first_array, second_array)
assert round(emd_value, EMD_PRECISION) == 78.4


# Validation testing
# ~~~~~~~~~~~~~~~~~~

Expand Down