diff --git a/examples/use_modernised.py b/examples/use_modernised.py index cfae1cc..6ca0d51 100644 --- a/examples/use_modernised.py +++ b/examples/use_modernised.py @@ -54,7 +54,7 @@ def __init__(self, symbols): """Intialise a checkpoint object. Upon initialisation, a checkpoint stores only a reference to the symbols that are passed into it. The symbols must be passed as a mapping symbolname->symbolobject.""" - + self.dtype = np.float32 if(isinstance(symbols, collections.Mapping)): self.symbols = symbols else: @@ -97,7 +97,7 @@ def size(self): v = Symbol((1)) fwdo = ForwardOperator(u, m) revo = ReverseOperator(u, m, v) -wrp = pr.Revolver(checkpoint, fwdo, revo, None, nSteps) +wrp = pr.Revolver(checkpoint, fwdo, revo, nSteps) wrp.apply_forward() print("u=%s" % u.data) wrp.apply_reverse() diff --git a/pyrevolve/pyrevolve.py b/pyrevolve/pyrevolve.py index 2d00d7c..6d856b5 100644 --- a/pyrevolve/pyrevolve.py +++ b/pyrevolve/pyrevolve.py @@ -2,8 +2,12 @@ import numpy as np from abc import ABCMeta, abstractproperty, abstractmethod +class Operator(object): + def apply(self, **kwargs): + pass -class Checkpoint: + +class Checkpoint(object): """Abstract base class, containing the methods and properties that any user-given Checkpoint class must have.""" __metaclass__ = ABCMeta @@ -24,7 +28,7 @@ def load(self, ptr): return NotImplemented -class CheckpointStorage: +class CheckpointStorage(object): """Holds a chunk of memory large enough to store all checkpoints. The []-operator is overloaded to return a pointer to the memory reserved for a given checkpoint number. Revolve will typically use this as LIFO, but the @@ -32,8 +36,8 @@ class CheckpointStorage: """Allocates memory on initialisation. Requires number of checkpoints and size of one checkpoint. Memory is allocated in C-contiguous style.""" - def __init__(self, size_ckp, n_ckp): - self.storage = np.zeros((n_ckp, size_ckp), order='C') + def __init__(self, size_ckp, n_ckp, dtype): + self.storage = np.zeros((n_ckp, size_ckp), order='C', dtype=dtype) """Returns a pointer to the contiguous chunk of memory reserved for the checkpoint with number `key`.""" @@ -57,8 +61,7 @@ class Revolver(object): """ def __init__(self, checkpoint, - fwd_operator, rev_operator, - n_checkpoints=None, n_timesteps=None): + fwd_operator, rev_operator, n_timesteps, n_checkpoints=None): """Initialise checkpointer for a given forward- and reverse operator, a given number of time steps, and a given storage strategy. The number of time steps must currently be provided explicitly, and the storage must @@ -71,8 +74,11 @@ def __init__(self, checkpoint, self.fwd_operator = fwd_operator self.rev_operator = rev_operator self.checkpoint = checkpoint - self.storage = CheckpointStorage(checkpoint.size, n_checkpoints) + checkpoint.revolver = self + self.storage = CheckpointStorage(checkpoint.size, n_checkpoints, checkpoint.dtype) self.n_timesteps = n_timesteps + self.fwd_args = {} + self.rev_args = {} storage_disk = None # this is not yet supported # We use the crevolve wrapper around the C++ Revolve library. self.ckp = cr.CRevolve(n_checkpoints, n_timesteps, storage_disk) @@ -86,18 +92,18 @@ def apply_forward(self): action = self.ckp.revolve() if(action == cr.Action.advance): # advance forward computation - self.fwd_operator.apply(t_start=self.ckp.oldcapo, - t_end=self.ckp.capo) + self.call_f(t_start=self.ckp.oldcapo, t_end=self.ckp.capo) elif(action == cr.Action.takeshot): # take a snapshot: copy from workspace into storage + #print("Taking snapshot number: %d"%self.ckp.check) self.checkpoint.save(self.storage[self.ckp.check]) elif(action == cr.Action.restore): # restore a snapshot: copy from storage into workspace + #print("Restoring snapshot number: %d"%self.ckp.check) self.checkpoint.load(self.storage[self.ckp.check]) elif(action == cr.Action.firstrun): # final step in the forward computation - self.fwd_operator.apply(t_start=self.ckp.oldcapo, - t_end=self.n_timesteps) + self.call_f(t_start=self.ckp.oldcapo, t_end=self.n_timesteps) break def apply_reverse(self): @@ -106,25 +112,33 @@ def apply_reverse(self): recompute sections of the trajectory that have not been stored in the forward run.""" - self.rev_operator.apply(t_start=self.ckp.capo, - t_end=self.ckp.capo+1) + self.call_r(t_start=self.ckp.capo, t_end=self.ckp.capo+1) while(True): # ask Revolve what to do next. action = self.ckp.revolve() if(action == cr.Action.advance): # advance forward computation - self.fwd_operator.apply(t_start=self.ckp.oldcapo, - t_end=self.ckp.capo) + self.call_f(t_start=self.ckp.oldcapo, t_end=self.ckp.capo) elif(action == cr.Action.takeshot): # take a snapshot: copy from workspace into storage + #print("Taking snapshot number: %d"%self.ckp.check) self.checkpoint.save(self.storage[self.ckp.check]) elif(action == cr.Action.restore): # restore a snapshot: copy from storage into workspace + #print("Restoring snapshot number: %d"%self.ckp.check) self.checkpoint.load(self.storage[self.ckp.check]) elif(action == cr.Action.youturn): # advance adjoint computation by a single step - self.rev_operator.apply(t_start=self.ckp.capo, - t_end=self.ckp.capo+1) + #print("t=%d, L2(u(t))=%d"%(self.ckp.capo+1, np.linalg.norm(self.checkpoint.symbols[0].data[(self.ckp.capo+1)%3, :, :]))) + self.call_f(t_start=self.ckp.capo, t_end=self.ckp.capo+1) + self.call_r(t_start=self.ckp.capo, t_end=self.ckp.capo+1) elif(action == cr.Action.terminate): break + + + def call_f(self, **kwargs): + self.fwd_operator.apply(**kwargs) + + def call_r(self, **kwargs): + self.rev_operator.apply(**kwargs)