Source code for skchem.test.test_cross_validation.test_similarity_threshold
#! /usr/bin/env python
#
# Copyright (C) 2016 Rich Lewis <rl403@cam.ac.uk>
# License: 3-clause BSD
"""
## skchem.tests.test_cross_validation.test_similarity_threshold
Tests for similarity threshold dataset partitioning functionality.
"""
import pytest
from scipy.spatial.distance import cdist
import numpy as np
from ...data import Diversity
from ...cross_validation import SimThresholdSplit
@pytest.fixture
[docs]def x():
return Diversity.read_frame('feats/X_morg')
@pytest.fixture
[docs]def cv(x):
return SimThresholdSplit(fper=None, block_width=500, n_jobs=1).fit(x)
[docs]def test_split(cv, x):
train, test = cv.split((8, 2))
assert (1 - cdist(x[train], x[test]) > cv.threshold_).sum() == 0
assert np.allclose([train.sum()], [len(x) * 0.8], rtol=0.05)
[docs]def test_k_fold(cv, x):
kfold = [fold for fold in cv.k_fold(5)]
assert len(kfold) == 5
i, j = kfold[0]
assert (i != j).all()