Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit ea2ba09

Browse files
committed
fix: fixes checkpointing issues for GaussKronrod Adjoint
1 parent 6f6c42a commit ea2ba09

File tree

4 files changed

+1443
-10
lines changed

4 files changed

+1443
-10
lines changed

src/gauss_adjoint.jl

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
AbstractGAdjoint = Union{GaussAdjoint, GaussKronrodAdjoint} # needs to be supertype of GaussAdjoint && GaussKronrodAdjoint !!!
2-
3-
41
mutable struct GaussIntegrand{pType, uType, lType, rateType, S, PF, PJC, PJT, DGP,
52
G, SAlg <: AbstractGAdjoint}
63
sol::S
@@ -579,8 +576,6 @@ function _adjoint_sensitivities(sol, sensealg::AbstractGAdjoint, alg; t = nothin
579576
elseif sensealg isa GaussKronrodAdjoint
580577
cb = IntegratingGKSumCallback((out, u, t, integrator) -> integrand(out, t, u),
581578
integrand_values, allocate_vjp(tunables))
582-
else
583-
print("\n\nerror\n\n typeof(sensealg) = ", typeof(sensealg),"\n\n")
584579
end
585580
rcb = nothing
586581
cb2 = nothing

src/gauss_adjoint.jl~

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
AbstractGAdjoint = Union{GaussAdjoint, GaussKronrodAdjoint} # needs to be supertype of GaussAdjoint && GaussKronrodAdjoint !!!
2-
1+
# AbstractGAdjoint = Union{GaussAdjoint, GaussKronrodAdjoint} # needs to be supertype of GaussAdjoint && GaussKronrodAdjoint !!!
32

43
mutable struct GaussIntegrand{pType, uType, lType, rateType, S, PF, PJC, PJT, DGP,
54
G, SAlg <: AbstractGAdjoint}
@@ -573,14 +572,14 @@ function _adjoint_sensitivities(sol, sensealg::AbstractGAdjoint, alg; t = nothin
573572
end
574573
integrand = GaussIntegrand(sol, sensealg, checkpoints, dgdp_continuous)
575574
integrand_values = IntegrandValuesSum(allocate_zeros(tunables))
576-
if typeof(sensealg) == GaussAdjoint
575+
if sensealg isa GaussAdjoint
577576
cb = IntegratingSumCallback((out, u, t, integrator) -> integrand(out, t, u),
578577
integrand_values, allocate_vjp(tunables))
579-
elseif typeof(sensealg) == GaussKronrodAdjoint
578+
elseif sensealg isa GaussKronrodAdjoint
580579
cb = IntegratingGKSumCallback((out, u, t, integrator) -> integrand(out, t, u),
581580
integrand_values, allocate_vjp(tunables))
582581
else
583-
print("\n\nerror\n\n typeof(sensealg) = ", typeof(sensealg))
582+
print("\n\nerror\n\n typeof(sensealg) = ", typeof(sensealg),"\n\n")
584583
end
585584
rcb = nothing
586585
cb2 = nothing

src/sensitivity_algorithms.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -667,6 +667,8 @@ function setvjp(sensealg::GaussKronrodAdjoint{CS, AD, FDT, Nothing}, vjp) where
667667
GaussKronrodAdjoint{CS, AD, FDT, typeof(vjp)}(vjp, sensealg.checkpointing)
668668
end
669669

670+
AbstractGAdjoint = Union{GaussAdjoint, GaussKronrodAdjoint}
671+
670672
"""
671673
```julia
672674
TrackerAdjoint <: AbstractAdjointSensitivityAlgorithm{nothing, true, nothing}
@@ -1384,6 +1386,8 @@ end
13841386
@inline ischeckpointing(alg::InterpolatingAdjoint, sol) = alg.checkpointing || !sol.dense
13851387
@inline ischeckpointing(alg::GaussAdjoint) = alg.checkpointing
13861388
@inline ischeckpointing(alg::GaussAdjoint, sol) = alg.checkpointing || !sol.dense
1389+
@inline ischeckpointing(alg::GaussKronrodAdjoint) = alg.checkpointing
1390+
@inline ischeckpointing(alg::GaussKronrodAdjoint, sol) = alg.checkpointing || !sol.dense
13871391
@inline ischeckpointing(alg::BacksolveAdjoint, sol = nothing) = alg.checkpointing
13881392

13891393
@inline isnoisemixing(alg::AbstractSensitivityAlgorithm) = false

0 commit comments

Comments
 (0)