Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 0987783

Browse files
committed
Merge branch 'master' of https://github.com/deepchem/deepchem into sdf_reader
2 parents a5018a1 + a88b70d commit 0987783

15 files changed

Lines changed: 1956 additions & 64 deletions

File tree

datasets/rev8020split_desc.csv

Lines changed: 1476 additions & 0 deletions
Large diffs are not rendered by default.

deepchem/datasets/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,8 @@ def get_label_stds(self):
277277

278278
def get_statistics(self):
279279
"""Computes and returns statistics of this dataset"""
280+
if len(self) == 0:
281+
return None, None, None, None
280282
self.update_moments()
281283
df = self.metadata_df
282284
X_means, X_stds, y_means, y_stds = compute_mean_and_std(df)

deepchem/datasets/bace_datasets.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
import sys
2+
import os
3+
import deepchem
4+
import tempfile, shutil
5+
from deepchem.utils.save import load_from_disk
6+
from deepchem.splits import SpecifiedSplitter
7+
from deepchem.featurizers.featurize import DataFeaturizer
8+
from deepchem.datasets import Dataset
9+
from deepchem.transformers import NormalizationTransformer
10+
from deepchem.transformers import ClippingTransformer
11+
from deepchem.hyperparameters import HyperparamOpt
12+
from sklearn.ensemble import RandomForestRegressor
13+
from deepchem.models.sklearn_models import SklearnModel
14+
from deepchem.datasets.bace_features import user_specified_features
15+
from deepchem import metrics
16+
from deepchem.metrics import Metric
17+
from deepchem.utils.evaluate import Evaluator
18+
19+
def load_bace(mode="regression", transform=True, split="20-80"):
20+
"""Load BACE-1 dataset as regression/classification problem."""
21+
reload = True
22+
verbosity = "high"
23+
assert split in ["20-80", "80-20"]
24+
25+
current_dir = os.path.dirname(os.path.realpath(__file__))
26+
if split == "20-80":
27+
dataset_file = os.path.join(
28+
current_dir, "../../datasets/desc_canvas_aug30.csv")
29+
elif split == "80-20":
30+
dataset_file = os.path.join(
31+
current_dir, "../../datasets/rev8020split_desc.csv")
32+
dataset = load_from_disk(dataset_file)
33+
num_display = 10
34+
pretty_columns = (
35+
"[" + ",".join(["'%s'" % column for column in
36+
dataset.columns.values[:num_display]])
37+
+ ",...]")
38+
39+
crystal_dataset_file = os.path.join(
40+
current_dir, "../../datasets/crystal_desc_canvas_aug30.csv")
41+
crystal_dataset = load_from_disk(crystal_dataset_file)
42+
43+
print("Columns of dataset: %s" % pretty_columns)
44+
print("Number of examples in dataset: %s" % str(dataset.shape[0]))
45+
print("Number of examples in crystal dataset: %s" %
46+
str(crystal_dataset.shape[0]))
47+
48+
#Make directories to store the raw and featurized datasets.
49+
base_dir = tempfile.mkdtemp()
50+
feature_dir = os.path.join(base_dir, "features")
51+
samples_dir = os.path.join(base_dir, "samples")
52+
full_dir = os.path.join(base_dir, "full_dataset")
53+
train_dir = os.path.join(base_dir, "train_dataset")
54+
valid_dir = os.path.join(base_dir, "valid_dataset")
55+
test_dir = os.path.join(base_dir, "test_dataset")
56+
model_dir = os.path.join(base_dir, "model")
57+
crystal_dir = os.path.join(base_dir, "crystal")
58+
crystal_feature_dir = os.path.join(base_dir, "crystal_feature")
59+
crystal_samples_dir = os.path.join(base_dir, "crystal_samples")
60+
61+
62+
if mode == "regression":
63+
bace_tasks = ["pIC50"]
64+
elif mode == "classification":
65+
bace_tasks = ["Class"]
66+
else:
67+
raise ValueError("Unknown mode %s" % mode)
68+
featurizer = DataFeaturizer(tasks=bace_tasks,
69+
smiles_field="mol",
70+
id_field="CID",
71+
user_specified_features=user_specified_features,
72+
split_field="Model")
73+
featurized_samples = featurizer.featurize(
74+
dataset_file, feature_dir, samples_dir, shard_size=2000,
75+
reload=reload)
76+
77+
crystal_featurized_samples = featurizer.featurize(
78+
crystal_dataset_file, crystal_feature_dir, crystal_samples_dir,
79+
shard_size=2000)
80+
81+
82+
splitter = SpecifiedSplitter(verbosity=verbosity)
83+
train_samples, valid_samples, test_samples = splitter.train_valid_test_split(
84+
featurized_samples, train_dir, valid_dir, test_dir,
85+
reload=reload)
86+
87+
#NOTE THE RENAMING:
88+
if split == "20-80":
89+
valid_samples, test_samples = test_samples, valid_samples
90+
91+
train_dataset = Dataset(data_dir=train_dir, samples=train_samples,
92+
featurizers=[], tasks=bace_tasks,
93+
use_user_specified_features=True)
94+
valid_dataset = Dataset(data_dir=valid_dir, samples=valid_samples,
95+
featurizers=[], tasks=bace_tasks,
96+
use_user_specified_features=True)
97+
test_dataset = Dataset(data_dir=test_dir, samples=test_samples,
98+
featurizers=[], tasks=bace_tasks,
99+
use_user_specified_features=True)
100+
crystal_dataset = Dataset(data_dir=crystal_dir,
101+
samples=crystal_featurized_samples,
102+
featurizers=[], tasks=bace_tasks,
103+
use_user_specified_features=True)
104+
print("Number of compounds in train set")
105+
print(len(train_dataset))
106+
print("Number of compounds in validation set")
107+
print(len(valid_dataset))
108+
print("Number of compounds in test set")
109+
print(len(test_dataset))
110+
print("Number of compounds in crystal set")
111+
print(len(crystal_dataset))
112+
113+
if transform:
114+
input_transformers = [
115+
NormalizationTransformer(transform_X=True, dataset=train_dataset),
116+
ClippingTransformer(transform_X=True, dataset=train_dataset)]
117+
output_transformers = []
118+
if mode == "regression":
119+
output_transformers = [
120+
NormalizationTransformer(transform_y=True, dataset=train_dataset)]
121+
else:
122+
output_transformers = []
123+
else:
124+
input_transformers, output_transformers = [], []
125+
126+
transformers = input_transformers + output_transformers
127+
for transformer in transformers:
128+
transformer.transform(train_dataset)
129+
for transformer in transformers:
130+
transformer.transform(valid_dataset)
131+
for transformer in transformers:
132+
transformer.transform(test_dataset)
133+
for transformer in transformers:
134+
transformer.transform(crystal_dataset)
135+
136+
return (bace_tasks, train_dataset, valid_dataset, test_dataset,
137+
crystal_dataset, output_transformers)

deepchem/datasets/bace_features.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
user_specified_features = ['MW','AlogP','HBA','HBD','RB','HeavyAtomCount','ChiralCenterCount','ChiralCenterCountAllPossible','RingCount','PSA','Estate','MR','Polar','sLi_Key','ssBe_Key','ssssBem_Key','sBH2_Key','ssBH_Key','sssB_Key','ssssBm_Key','sCH3_Key','dCH2_Key','ssCH2_Key','tCH_Key','dsCH_Key','aaCH_Key','sssCH_Key','ddC_Key','tsC_Key','dssC_Key','aasC_Key','aaaC_Key','ssssC_Key','sNH3_Key','sNH2_Key','ssNH2_Key','dNH_Key','ssNH_Key','aaNH_Key','tN_Key','sssNH_Key','dsN_Key','aaN_Key','sssN_Key','ddsN_Key','aasN_Key','ssssN_Key','daaN_Key','sOH_Key','dO_Key','ssO_Key','aaO_Key','aOm_Key','sOm_Key','sF_Key','sSiH3_Key','ssSiH2_Key','sssSiH_Key','ssssSi_Key','sPH2_Key','ssPH_Key','sssP_Key','dsssP_Key','ddsP_Key','sssssP_Key','sSH_Key','dS_Key','ssS_Key','aaS_Key','dssS_Key','ddssS_Key','ssssssS_Key','Sm_Key','sCl_Key','sGeH3_Key','ssGeH2_Key','sssGeH_Key','ssssGe_Key','sAsH2_Key','ssAsH_Key','sssAs_Key','dsssAs_Key','ddsAs_Key','sssssAs_Key','sSeH_Key','dSe_Key','ssSe_Key','aaSe_Key','dssSe_Key','ssssssSe_Key','ddssSe_Key','sBr_Key','sSnH3_Key','ssSnH2_Key','sssSnH_Key','ssssSn_Key','sI_Key','sPbH3_Key','ssPbH2_Key','sssPbH_Key','ssssPb_Key','sLi_Cnt','ssBe_Cnt','ssssBem_Cnt','sBH2_Cnt','ssBH_Cnt','sssB_Cnt','ssssBm_Cnt','sCH3_Cnt','dCH2_Cnt','ssCH2_Cnt','tCH_Cnt','dsCH_Cnt','aaCH_Cnt','sssCH_Cnt','ddC_Cnt','tsC_Cnt','dssC_Cnt','aasC_Cnt','aaaC_Cnt','ssssC_Cnt','sNH3_Cnt','sNH2_Cnt','ssNH2_Cnt','dNH_Cnt','ssNH_Cnt','aaNH_Cnt','tN_Cnt','sssNH_Cnt','dsN_Cnt','aaN_Cnt','sssN_Cnt','ddsN_Cnt','aasN_Cnt','ssssN_Cnt','daaN_Cnt','sOH_Cnt','dO_Cnt','ssO_Cnt','aaO_Cnt','aOm_Cnt','sOm_Cnt','sF_Cnt','sSiH3_Cnt','ssSiH2_Cnt','sssSiH_Cnt','ssssSi_Cnt','sPH2_Cnt','ssPH_Cnt','sssP_Cnt','dsssP_Cnt','ddsP_Cnt','sssssP_Cnt','sSH_Cnt','dS_Cnt','ssS_Cnt','aaS_Cnt','dssS_Cnt','ddssS_Cnt','ssssssS_Cnt','Sm_Cnt','sCl_Cnt','sGeH3_Cnt','ssGeH2_Cnt','sssGeH_Cnt','ssssGe_Cnt','sAsH2_Cnt','ssAsH_Cnt','sssAs_Cnt','dsssAs_Cnt','ddsAs_Cnt','sssssAs_Cnt','sSeH_Cnt','dSe_Cnt','ssSe_Cnt','aaSe_Cnt','dssSe_Cnt','ssssssSe_Cnt','ddssSe_Cnt','sBr_Cnt','sSnH3_Cnt','ssSnH2_Cnt','sssSnH_Cnt','ssssSn_Cnt','sI_Cnt','sPbH3_Cnt','ssPbH2_Cnt','sssPbH_Cnt','ssssPb_Cnt','sLi_Sum','ssBe_Sum','ssssBem_Sum','sBH2_Sum','ssBH_Sum','sssB_Sum','ssssBm_Sum','sCH3_Sum','dCH2_Sum','ssCH2_Sum','tCH_Sum','dsCH_Sum','aaCH_Sum','sssCH_Sum','ddC_Sum','tsC_Sum','dssC_Sum','aasC_Sum','aaaC_Sum','ssssC_Sum','sNH3_Sum','sNH2_Sum','ssNH2_Sum','dNH_Sum','ssNH_Sum','aaNH_Sum','tN_Sum','sssNH_Sum','dsN_Sum','aaN_Sum','sssN_Sum','ddsN_Sum','aasN_Sum','ssssN_Sum','daaN_Sum','sOH_Sum','dO_Sum','ssO_Sum','aaO_Sum','aOm_Sum','sOm_Sum','sF_Sum','sSiH3_Sum','ssSiH2_Sum','sssSiH_Sum','ssssSi_Sum','sPH2_Sum','ssPH_Sum','sssP_Sum','dsssP_Sum','ddsP_Sum','sssssP_Sum','sSH_Sum','dS_Sum','ssS_Sum','aaS_Sum','dssS_Sum','ddssS_Sum','ssssssS_Sum','Sm_Sum','sCl_Sum','sGeH3_Sum','ssGeH2_Sum','sssGeH_Sum','ssssGe_Sum','sAsH2_Sum','ssAsH_Sum','sssAs_Sum','dsssAs_Sum','ddsAs_Sum','sssssAs_Sum','sSeH_Sum','dSe_Sum','ssSe_Sum','aaSe_Sum','dssSe_Sum','ssssssSe_Sum','ddssSe_Sum','sBr_Sum','sSnH3_Sum','ssSnH2_Sum','sssSnH_Sum','ssssSn_Sum','sI_Sum','sPbH3_Sum','ssPbH2_Sum','sssPbH_Sum','ssssPb_Sum','sLi_Avg','ssBe_Avg','ssssBem_Avg','sBH2_Avg','ssBH_Avg','sssB_Avg','ssssBm_Avg','sCH3_Avg','dCH2_Avg','ssCH2_Avg','tCH_Avg','dsCH_Avg','aaCH_Avg','sssCH_Avg','ddC_Avg','tsC_Avg','dssC_Avg','aasC_Avg','aaaC_Avg','ssssC_Avg','sNH3_Avg','sNH2_Avg','ssNH2_Avg','dNH_Avg','ssNH_Avg','aaNH_Avg','tN_Avg','sssNH_Avg','dsN_Avg','aaN_Avg','sssN_Avg','ddsN_Avg','aasN_Avg','ssssN_Avg','daaN_Avg','sOH_Avg','dO_Avg','ssO_Avg','aaO_Avg','aOm_Avg','sOm_Avg','sF_Avg','sSiH3_Avg','ssSiH2_Avg','sssSiH_Avg','ssssSi_Avg','sPH2_Avg','ssPH_Avg','sssP_Avg','dsssP_Avg','ddsP_Avg','sssssP_Avg','sSH_Avg','dS_Avg','ssS_Avg','aaS_Avg','dssS_Avg','ddssS_Avg','ssssssS_Avg','Sm_Avg','sCl_Avg','sGeH3_Avg','ssGeH2_Avg','sssGeH_Avg','ssssGe_Avg','sAsH2_Avg','ssAsH_Avg','sssAs_Avg','dsssAs_Avg','ddsAs_Avg','sssssAs_Avg','sSeH_Avg','dSe_Avg','ssSe_Avg','aaSe_Avg','dssSe_Avg','ssssssSe_Avg','ddssSe_Avg','sBr_Avg','sSnH3_Avg','ssSnH2_Avg','sssSnH_Avg','ssssSn_Avg','sI_Avg','sPbH3_Avg','ssPbH2_Avg','sssPbH_Avg','ssssPb_Avg','First Zagreb (ZM1)','First Zagreb index by valence vertex degrees (ZM1V)','Second Zagreb (ZM2)','Second Zagreb index by valence vertex degrees (ZM2V)','Polarity (Pol)','Narumi Simple Topological (NST)','Narumi Harmonic Topological (NHT)','Narumi Geometric Topological (NGT)','Total structure connectivity (TSC)','Wiener (W)','Mean Wiener (MW)','Xu (Xu)','Quadratic (QIndex)','Radial centric (RC)','Mean Square Distance Balaban (MSDB)','Superpendentic (SP)','Harary (Har)','Log of product of row sums (LPRS)','Pogliani (Pog)','Schultz Molecular Topological (SMT)','Schultz Molecular Topological by valence vertex degrees (SMTV)','Mean Distance Degree Deviation (MDDD)','Ramification (Ram)','Gutman Molecular Topological (GMT)','Gutman MTI by valence vertex degrees (GMTV)','Average vertex distance degree (AVDD)','Unipolarity (UP)','Centralization (CENT)','Variation (VAR)','Molecular electrotopological variation (MEV)','Maximal electrotopological positive variation (MEPV)','Maximal electrotopological negative variation (MENV)','Eccentric connectivity (ECCc)','Eccentricity (ECC)','Average eccentricity (AECC)','Eccentric (DECC)','Valence connectivity index chi-0 (vX0)','Valence connectivity index chi-1 (vX1)','Valence connectivity index chi-2 (vX2)','Valence connectivity index chi-3 (vX3)','Valence connectivity index chi-4 (vX4)','Valence connectivity index chi-5 (vX5)','Average valence connectivity index chi-0 (AvX0)','Average valence connectivity index chi-1 (AvX1)','Average valence connectivity index chi-2 (AvX2)','Average valence connectivity index chi-3 (AvX3)','Average valence connectivity index chi-4 (AvX4)','Average valence connectivity index chi-5 (AvX5)','Quasi Wiener (QW)','First Mohar (FM)','Second Mohar (SM)','Spanning tree number (STN)','Kier benzene-likeliness index (KBLI)','Topological charge index of order 1 (TCI1)','Topological charge index of order 2 (TCI2)','Topological charge index of order 3 (TCI3)','Topological charge index of order 4 (TCI4)','Topological charge index of order 5 (TCI5)','Topological charge index of order 6 (TCI6)','Topological charge index of order 7 (TCI7)','Topological charge index of order 8 (TCI8)','Topological charge index of order 9 (TCI9)','Topological charge index of order 10 (TCI10)','Mean topological charge index of order 1 (MTCI1)','Mean topological charge index of order 2 (MTCI2)','Mean topological charge index of order 3 (MTCI3)','Mean topological charge index of order 4 (MTCI4)','Mean topological charge index of order 5 (MTCI5)','Mean topological charge index of order 6 (MTCI6)','Mean topological charge index of order 7 (MTCI7)','Mean topological charge index of order 8 (MTCI8)','Mean topological charge index of order 9 (MTCI9)','Mean topological charge index of order 10 (MTCI10)','Global topological charge (GTC)','Hyper-distance-path index (HDPI)','Reciprocal hyper-distance-path index (RHDPI)','Square reciprocal distance sum (SRDS)','Modified Randic connectivity (MRC)','Balaban centric (BC)','Lopping centric (LC)','Kier Hall electronegativity (KHE)','Sum of topological distances between N..N (STD(N N))','Sum of topological distances between N..O (STD(N O))','Sum of topological distances between N..S (STD(N S))','Sum of topological distances between N..P (STD(N P))','Sum of topological distances between N..F (STD(N F))','Sum of topological distances between N..Cl (STD(N Cl))','Sum of topological distances between N..Br (STD(N Br))','Sum of topological distances between N..I (STD(N I))','Sum of topological distances between O..O (STD(O O))','Sum of topological distances between O..S (STD(O S))','Sum of topological distances between O..P (STD(O P))','Sum of topological distances between O..F (STD(O F))','Sum of topological distances between O..Cl (STD(O Cl))','Sum of topological distances between O..Br (STD(O Br))','Sum of topological distances between O..I (STD(O I))','Sum of topological distances between S..S (STD(S S))','Sum of topological distances between S..P (STD(S P))','Sum of topological distances between S..F (STD(S F))','Sum of topological distances between S..Cl (STD(S Cl))','Sum of topological distances between S..Br (STD(S Br))','Sum of topological distances between S..I (STD(S I))','Sum of topological distances between P..P (STD(P P))','Sum of topological distances between P..F (STD(P F))','Sum of topological distances between P..Cl (STD(P Cl))','Sum of topological distances between P..Br (STD(P Br))','Sum of topological distances between P..I (STD(P I))','Sum of topological distances between F..F (STD(F F))','Sum of topological distances between F..Cl (STD(F Cl))','Sum of topological distances between F..Br (STD(F Br))','Sum of topological distances between F..I (STD(F I))','Sum of topological distances between Cl..Cl (STD(Cl Cl))','Sum of topological distances between Cl..Br (STD(Cl Br))','Sum of topological distances between Cl..I (STD(Cl I))','Sum of topological distances between Br..Br (STD(Br Br))','Sum of topological distances between Br..I (STD(Br I))','Sum of topological distances between I..I (STD(I I))','Wiener-type index from Z weighted distance matrix - Barysz matrix (WhetZ)','Wiener-type index from electronegativity weighted distance matrix (Whete)','Wiener-type index from mass weighted distance matrix (Whetm)','Wiener-type index from van der waals weighted distance matrix (Whetv)','Wiener-type index from polarizability weighted distance matrix (Whetp)','Balaban-type index from Z weighted distance matrix - Barysz matrix (JhetZ)','Balaban-type index from electronegativity weighted distance matrix (Jhete)','Balaban-type index from mass weighted distance matrix (Jhetm)','Balaban-type index from van der waals weighted distance matrix (Jhetv)','Balaban-type index from polarizability weighted distance matrix (Jhetp)','Topological diameter (TD)','Topological radius (TR)','Petitjean 2D shape (PJ2DS)','Balaban distance connectivity index (J)','Solvation connectivity index chi-0 (SCIX0)','Solvation connectivity index chi-1 (SCIX1)','Solvation connectivity index chi-2 (SCIX2)','Solvation connectivity index chi-3 (SCIX3)','Solvation connectivity index chi-4 (SCIX4)','Solvation connectivity index chi-5 (SCIX5)','Connectivity index chi-0 (CIX0)','Connectivity chi-1 [Randic connectivity] (CIX1)','Connectivity index chi-2 (CIX2)','Connectivity index chi-3 (CIX3)','Connectivity index chi-4 (CIX4)','Connectivity index chi-5 (CIX5)','Average connectivity index chi-0 (ACIX0)','Average connectivity index chi-1 (ACIX1)','Average connectivity index chi-2 (ACIX2)','Average connectivity index chi-3 (ACIX3)','Average connectivity index chi-4 (ACIX4)','Average connectivity index chi-5 (ACIX5)','reciprocal distance Randic-type index (RDR)','reciprocal distance square Randic-type index (RDSR)','1-path Kier alpha-modified shape index (KAMS1)','2-path Kier alpha-modified shape index (KAMS2)','3-path Kier alpha-modified shape index (KAMS3)','Kier flexibility (KF)','path/walk 2 - Randic shape index (RSIpw2)','path/walk 3 - Randic shape index (RSIpw3)','path/walk 4 - Randic shape index (RSIpw4)','path/walk 5 - Randic shape index (RSIpw5)','E-state topological parameter (ETP)','Ring Count 3 (RNGCNT3)','Ring Count 4 (RNGCNT4)','Ring Count 5 (RNGCNT5)','Ring Count 6 (RNGCNT6)','Ring Count 7 (RNGCNT7)','Ring Count 8 (RNGCNT8)','Ring Count 9 (RNGCNT9)','Ring Count 10 (RNGCNT10)','Ring Count 11 (RNGCNT11)','Ring Count 12 (RNGCNT12)','Ring Count 13 (RNGCNT13)','Ring Count 14 (RNGCNT14)','Ring Count 15 (RNGCNT15)','Ring Count 16 (RNGCNT16)','Ring Count 17 (RNGCNT17)','Ring Count 18 (RNGCNT18)','Ring Count 19 (RNGCNT19)','Ring Count 20 (RNGCNT20)','Atom Count (ATMCNT)','Bond Count (BNDCNT)','Atoms in Ring System (ATMRNGCNT)','Bonds in Ring System (BNDRNGCNT)','Cyclomatic number (CYCLONUM)','Number of ring systems (NRS)','Normalized number of ring systems (NNRS)','Ring Fusion degree (RFD)','Ring perimeter (RNGPERM)','Ring bridge count (RNGBDGE)','Molecule cyclized degree (MCD)','Ring Fusion density (RFDELTA)','Ring complexity index (RCI)','Van der Waals surface area (VSA)','MR1 (MR1)','MR2 (MR2)','MR3 (MR3)','MR4 (MR4)','MR5 (MR5)','MR6 (MR6)','MR7 (MR7)','MR8 (MR8)','ALOGP1 (ALOGP1)','ALOGP2 (ALOGP2)','ALOGP3 (ALOGP3)','ALOGP4 (ALOGP4)','ALOGP5 (ALOGP5)','ALOGP6 (ALOGP6)','ALOGP7 (ALOGP7)','ALOGP8 (ALOGP8)','ALOGP9 (ALOGP9)','ALOGP10 (ALOGP10)','PEOE1 (PEOE1)','PEOE2 (PEOE2)','PEOE3 (PEOE3)','PEOE4 (PEOE4)','PEOE5 (PEOE5)','PEOE6 (PEOE6)','PEOE7 (PEOE7)','PEOE8 (PEOE8)','PEOE9 (PEOE9)','PEOE10 (PEOE10)','PEOE11 (PEOE11)','PEOE12 (PEOE12)','PEOE13 (PEOE13)','PEOE14 (PEOE14)']

deepchem/hyperparameters/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def hyperparam_search(self, params_dict, train_dataset, valid_dataset,
7777

7878
evaluator = Evaluator(model, valid_dataset, output_transformers)
7979
multitask_scores = evaluator.compute_model_performance(
80-
[metric], valid_csv_out.name, valid_stats_out)
80+
[metric], valid_csv_out.name, valid_stats_out.name)
8181
valid_score = multitask_scores[metric.name]
8282
all_scores[str(hyperparameter_tuple)] = valid_score
8383

@@ -100,12 +100,14 @@ def hyperparam_search(self, params_dict, train_dataset, valid_dataset,
100100

101101
if best_model is None:
102102
log("No models trained correctly.", self.verbosity, "low")
103+
# arbitrarily return last model
104+
best_model, best_hyperparams = model, hyperparameter_tuple
103105
return best_model, best_hyperparams, all_scores
104106
train_csv_out = tempfile.NamedTemporaryFile()
105107
train_stats_out = tempfile.NamedTemporaryFile()
106108
train_evaluator = Evaluator(best_model, train_dataset, output_transformers)
107109
multitask_scores = train_evaluator.compute_model_performance(
108-
[metric], train_csv_out.name, train_stats_out)
110+
[metric], train_csv_out.name, train_stats_out.name)
109111
train_score = multitask_scores[metric.name]
110112
log("Best hyperparameters: %s" % str(best_hyperparams),
111113
self.verbosity, "low")

0 commit comments

Comments
 (0)