diff --git a/qml/qmlearn/preprocessing.py b/qml/qmlearn/preprocessing.py index 2a28bb4a5..e669c4a7e 100644 --- a/qml/qmlearn/preprocessing.py +++ b/qml/qmlearn/preprocessing.py @@ -206,6 +206,21 @@ def _transform(self, data, features, y): else: return delta_y + def _revert_transform(self, data, features, y): + """ + Reverts the work of the transform method. + """ + + full_y = y + self.model.predict(features) + + if data: + # Force copy + data.energies = data.energies.copy() + data.energies[data._indices] = full_y + return data + else: + return full_y + def _check_elements(self, nuclear_charges): """ Check that the elements in the given nuclear_charges was @@ -261,3 +276,27 @@ def transform(self, X, y=None): features = self._featurizer(nuclear_charges) return self._transform(data, features, y) + + def revert_transform(self, X, y=None): + """ + Transforms data back to what it originally would have been if it hadn't been transformed with the fitted linear + model. Supports three different types of input. + 1) X is a list of nuclear charges and y is values to transform. + 2) X is an array of indices of which to transform. + 3) X is a data object + + :param X: List with nuclear charges or Data object. + :type X: list + :param y: Values to revert to before transform + :type y: array or None + :return: Array of untransformed values or Data object, depending on input + :rtype: array or Data object + """ + + data, nuclear_charges, y = self._parse_input(X, y) + + self._check_elements(nuclear_charges) + + features = self._featurizer(nuclear_charges) + + return self._revert_transform(data, features, y) \ No newline at end of file diff --git a/test/test_armp.py b/test/test_armp.py index 7cb184c0b..eeaf0688e 100644 --- a/test/test_armp.py +++ b/test/test_armp.py @@ -228,7 +228,7 @@ def test_predict_fromxyz(): pred1 = estimator.predict(idx) pred2 = estimator.predict_from_xyz(xyz, zs) - assert np.all(np.isclose(pred1, pred2, rtol=1.e-6)) + assert np.all(np.isclose(pred1, pred2, rtol=1.e-5)) estimator.save_nn(save_dir="temp") @@ -243,11 +243,11 @@ def test_predict_fromxyz(): pred3 = new_estimator.predict(idx) pred4 = new_estimator.predict_from_xyz(xyz, zs) - assert np.all(np.isclose(pred3, pred4, rtol=1.e-6)) - assert np.all(np.isclose(pred1, pred3, rtol=1.e-6)) - shutil.rmtree("temp") + assert np.all(np.isclose(pred3, pred4, rtol=1.e-5)) + assert np.all(np.isclose(pred1, pred3, rtol=1.e-5)) + def test_retraining(): xyz = np.array([[[0, 1, 0], [0, 1, 1], [1, 0, 1]], [[1, 2, 2], [3, 1, 2], [1, 3, 4]], @@ -291,8 +291,8 @@ def test_retraining(): pred4 = new_estimator.predict(idx) - assert np.all(np.isclose(pred1, pred3, rtol=1.e-6)) - assert np.all(np.isclose(pred2, pred4, rtol=1.e-6)) + assert np.all(np.isclose(pred1, pred3, rtol=1.e-5)) + assert np.all(np.isclose(pred2, pred4, rtol=1.e-5)) shutil.rmtree("temp")