Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
5bdf68d
Draft new abstract class for objective functions
santisoler Jan 15, 2024
97fc027
Remove mapping property
santisoler Jan 22, 2024
0477908
Fix class name in check
santisoler Jan 22, 2024
0571cf9
Add abstract deriv and deriv2 methods
santisoler Jan 22, 2024
405da1c
Remove the W method
santisoler Jan 22, 2024
4e11c7c
Make BaseObjectiveFunction an abstract class
santisoler Jan 22, 2024
6b4dcf3
Adjust ComboObjectiveFunction to work with the new class
santisoler Jan 22, 2024
f64c97c
Remove L2ObjectiveFunction
santisoler Jan 22, 2024
280bcce
Make BaseDataMisfit an abstract class
santisoler Jan 22, 2024
46a2a8b
Remove L2ObjectiveFunction from __all__ list
santisoler Jan 22, 2024
b2962f2
Remove L2ObjectiveFunction from the list in __init__.py
santisoler Jan 22, 2024
c11555a
Make BaseRegularization an abstract class
santisoler Jan 22, 2024
50216f5
Remove unused imports
santisoler Jan 22, 2024
d97085d
Recover the `nP` property for BaseRegularization
santisoler Jan 22, 2024
888b3cd
Improve check for has fields
santisoler Jan 22, 2024
39d7e82
Merge branch 'main' into objective-fun-redesign
santisoler Jan 22, 2024
fda863e
Merge branch 'main' into objective-fun-redesign
santisoler Feb 28, 2024
edb7f10
Update tests on objective functions operations
santisoler Feb 29, 2024
fdc093e
Create a MockRegularization class in regularization tests
santisoler Feb 29, 2024
bfd5740
Replace `.tests()` for `.test_derivatives()` in regularization tests
santisoler Feb 29, 2024
05d4b2a
Remove NotImplemented error tests on BaseRegularization
santisoler Feb 29, 2024
10d5c20
Replace `.test()` for `.test_derivatives()` in test_data_misfit.py
santisoler Feb 29, 2024
54def12
Replace .test for .test_derivatives in test_joint.py
santisoler Feb 29, 2024
fe8a14f
Revert BaseRegularization to a non-abstract class
santisoler Mar 1, 2024
a0bdaf0
Use .test_derivatives() in CrossGradient tests
santisoler Mar 1, 2024
6b767f6
Fix signature of deriv2 abstract method
santisoler Mar 1, 2024
ece66c0
Restore docstring of the __call__ abstract method
santisoler Mar 1, 2024
e47c0a0
Update tests for objective functions
santisoler Mar 1, 2024
7f4e213
Fix em tests: use reg.nP instead of reg.mapping.nP
santisoler Mar 1, 2024
162f830
Ditch usage of L2ObjectiveFunction in regularization test
santisoler Mar 1, 2024
2c7700d
Try to implement nP for the Volume regularization in example
santisoler Mar 1, 2024
625bf58
Merge branch 'main' into objective-fun-redesign
santisoler May 6, 2024
7cf64db
Merge branch 'main' into objective-fun-redesign
santisoler Aug 23, 2024
47e800d
Replace deprecated abstractproperty
santisoler Aug 23, 2024
cf4c49a
Add test to check error if adding non objective function
santisoler Aug 23, 2024
83a2385
Parametrize to test __radd__
santisoler Aug 23, 2024
883debe
Replace "sum" for "add" in tests
santisoler Aug 23, 2024
dc48072
Fix docstring
santisoler Aug 23, 2024
52f6f24
Add type hint to multipliers
santisoler Aug 23, 2024
c25a9e9
Extend tests
santisoler Aug 23, 2024
4f4be57
Add tests for the get_functions_of_type method
santisoler Aug 26, 2024
46d5cf4
Improve docstring for `fun_class` argument
santisoler Aug 26, 2024
3e670a3
Add tests for _need_to_pass_fields
santisoler Aug 26, 2024
ac3895f
Add tests for calling combos
santisoler Aug 26, 2024
d3606c4
Add read-only has_fields property to BaseDataMisfit
santisoler Aug 26, 2024
ed1294f
Fix docstring
santisoler Aug 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions examples/20-published/plot_tomo_joint_with_volume.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ def knownVolume(self):
def knownVolume(self, value):
self._knownVolume = utils.validate_float("knownVolume", value, min_val=0.0)

@property
def nP(self):
return "*"

def __call__(self, m):
return (self.estVol(m) - self.knownVolume) ** 2

Expand Down
1 change: 0 additions & 1 deletion simpeg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,6 @@

objective_function.BaseObjectiveFunction
objective_function.ComboObjectiveFunction
objective_function.L2ObjectiveFunction
data_misfit.BaseDataMisfit
data_misfit.L2DataMisfit

Expand Down
109 changes: 39 additions & 70 deletions simpeg/data_misfit.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,42 @@
from abc import abstractmethod
import numpy as np
from .utils import Counter, sdiag, timeIt, Identity, validate_type
from .utils import sdiag, Identity, validate_type
from .data import Data
from .simulation import BaseSimulation
from .objective_function import L2ObjectiveFunction
from .objective_function import BaseObjectiveFunction

__all__ = ["L2DataMisfit"]


class BaseDataMisfit(L2ObjectiveFunction):
class BaseDataMisfit(BaseObjectiveFunction):
r"""Base data misfit class.

Inherit this class to build your own data misfit function. The ``BaseDataMisfit``
class inherits the :py:class:`simpeg.objective_function.L2ObjectiveFunction`.
And as a result, it is limited to building data misfit functions of the form:
Inherit this class to build your own data misfit function.

.. important::

This class is not meant to be instantiated. You should inherit from it to
create your own data misfit class.

.. math::
\phi_d (\mathbf{m}) = \| \mathbf{W} f(\mathbf{m}) \|_2^2

where :math:`\mathbf{m}` is the model vector, :math:`\mathbf{W}` is a linear weighting
matrix, and :math:`f` is a mapping function that acts on the model.

Parameters
----------
data : simpeg.data.Data
A SimPEG data object.
simulation : simpeg.simulation.BaseSimulation
A SimPEG simulation object.
debug : bool
Print debugging information.
counter : None or simpeg.utils.Counter
Assign a SimPEG ``Counter`` object to store iterations and run-times.
"""

def __init__(self, data, simulation, debug=False, counter=None, **kwargs):
super().__init__(has_fields=True, debug=debug, counter=counter, **kwargs)

def __init__(self, data, simulation):
self.data = data
self.simulation = simulation

@property
def has_fields(self):
"""
Data misfits always have fields.
"""
return True

@property
def data(self):
"""A SimPEG data object.
Expand Down Expand Up @@ -74,36 +69,26 @@
"simulation", value, BaseSimulation, cast=False
)

@property
def debug(self):
"""Print debugging information.

Returns
-------
bool
Print debugging information.
@abstractmethod
def __call__(self, model, f=None) -> float:
"""
return self._debug

@debug.setter
def debug(self, value):
self._debug = validate_type("debug", value, bool)

@property
def counter(self):
"""SimPEG ``Counter`` object to store iterations and run-times.
Evaluate the objective function for a given model.
"""
pass

Returns
-------
None or simpeg.utils.Counter
SimPEG ``Counter`` object to store iterations and run-times.
@abstractmethod
def deriv(self, model):
"""
Gradient of the objective function evaluated on a given model.
"""
return self._counter
pass

@counter.setter
def counter(self, value):
if value is not None:
value = validate_type("counter", value, Counter, cast=False)
@abstractmethod
def deriv2(self, model):
"""
Hessian of the objective function evaluated on a given model.
"""
pass

@property
def nP(self):
Expand All @@ -114,9 +99,7 @@
int
Number of model parameters.
"""
if self._mapping is not None:
return self.mapping.nP
elif self.simulation.model is not None:
if self.simulation.model is not None:
return len(self.simulation.model)
else:
return "*"
Expand Down Expand Up @@ -147,15 +130,8 @@

@property
def W(self):
r"""The data weighting matrix.

For a discrete least-squares data misfit function of the form:

.. math::
\phi_d (\mathbf{m}) = \| \mathbf{W} \mathbf{f}(\mathbf{m}) \|_2^2

:math:`\mathbf{W}` is a linear weighting matrix, :math:`\mathbf{m}` is the model vector,
and :math:`\mathbf{f}` is a discrete mapping function that acts on the model vector.
"""
The data weighting matrix.

Returns
-------
Expand All @@ -165,22 +141,22 @@

if getattr(self, "_W", None) is None:
if self.data is None:
raise Exception(
raise TypeError(

Check warning on line 144 in simpeg/data_misfit.py

View check run for this annotation

Codecov / codecov/patch

simpeg/data_misfit.py#L144

Added line #L144 was not covered by tests
"data with standard deviations must be set before the data "
"misfit can be constructed. Please set the data: "
"dmis.data = Data(dobs=dobs, relative_error=rel"
", noise_floor=eps)"
)
standard_deviation = self.data.standard_deviation
if standard_deviation is None:
raise Exception(
raise TypeError(
"data standard deviations must be set before the data misfit "

Check warning on line 153 in simpeg/data_misfit.py

View check run for this annotation

Codecov / codecov/patch

simpeg/data_misfit.py#L152-L153

Added lines #L152 - L153 were not covered by tests
"can be constructed (data.relative_error = 0.05, "
"data.noise_floor = 1e-5), alternatively, the W matrix "
"can be set directly (dmisfit.W = 1./standard_deviation)"
)
if any(standard_deviation <= 0):
raise Exception(
raise ValueError(
"data.standard_deviation must be strictly positive to construct "
"the W matrix. Please set data.relative_error and or "
"data.noise_floor."
Expand All @@ -205,9 +181,9 @@
def residual(self, m, f=None):
r"""Computes the data residual vector for a given model.

Where :math:`\mathbf{d}_\text{obs}` is the observed data vector and :math:`\mathbf{d}_\text{pred}`
is the predicted data vector for a model vector :math:`\mathbf{m}`, this function
computes the data residual:
Where :math:`\mathbf{d}_\text{obs}` is the observed data vector and
:math:`\mathbf{d}_\text{pred}` is the predicted data vector for a model
vector :math:`\mathbf{m}`, this method computes the data residual:

.. math::
\mathbf{r} = \mathbf{d}_\text{pred} - \mathbf{d}_\text{obs}
Expand Down Expand Up @@ -255,20 +231,14 @@
A SimPEG data object that has observed data and uncertainties.
simulation : simpeg.simulation.BaseSimulation
A SimPEG simulation object.
debug : bool
Print debugging information.
counter : None or simpeg.utils.Counter
Assign a SimPEG ``Counter`` object to store iterations and run-times.
"""

@timeIt
def __call__(self, m, f=None):
"""Evaluate the residual for a given model."""

R = self.W * self.residual(m, f=f)
return np.vdot(R, R)

@timeIt
def deriv(self, m, f=None):
r"""Gradient of the data misfit function evaluated for the model provided.

Expand Down Expand Up @@ -297,7 +267,6 @@
m, self.W.T * (self.W * self.residual(m, f=f)), f=f
)

@timeIt
def deriv2(self, m, v, f=None):
r"""Hessian of the data misfit function evaluated for the model provided.

Expand Down
Loading