diff --git a/README.md b/README.md index c363d41f2..eb5bcf8c1 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,7 @@ +**ATTENTION**: This branch is no longer used and it belongs to pre-1.0 versions of DiffSharp. Use the [dev](https://github.com/DiffSharp/DiffSharp/tree/dev) branch for the latest version. The text below and the code in this branch are kept as a historical record. + + + DiffSharp: Differentiable Functional Programming ------------------------------------------------ diff --git a/appveyor.yml b/appveyor.yml index 209ab7bd7..735f3570e 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -1,6 +1,5 @@ -image: Visual Studio 2017 +image: Visual Studio 2019 build_script: - - cmd: .\build.cmd - -test: off \ No newline at end of file + - cmd: dotnet build + - cmd: dotnet test diff --git a/build.cmd b/build.cmd index f725e6fb4..ded6cb77d 100644 --- a/build.cmd +++ b/build.cmd @@ -3,3 +3,4 @@ cls dotnet --version dotnet build DiffSharp.sln -c release -v:n dotnet test tests/DiffSharp.Tests -c release -v:n +dotnet pack DiffSharp.sln -c release \ No newline at end of file diff --git a/build.sh b/build.sh index 9d77d2324..c7c15ac03 100755 --- a/build.sh +++ b/build.sh @@ -3,7 +3,7 @@ case "$(uname -s)" in Darwin) - brew install homebrew/science/openblas + brew install openblas ;; CYGWIN*|MINGW32*|MSYS*|Linux) @@ -25,7 +25,9 @@ if [ -d "MONO" ]; then # not currently testing on mono # mono ./packages/NUnit.Runners/tools/nunit-console.exe ./tests/DiffSharp.Tests/bin/Release/DiffSharp.Tests.dll else + dotnet --version dotnet build DiffSharp.sln -c debug dotnet test tests/DiffSharp.Tests -c release -f netcoreapp2.0 + dotnet pack DiffSharp.sln -c release fi diff --git a/docs/BuildDocs.fsx b/docs/BuildDocs.fsx index 3cb72d3ef..e7477c47c 100644 --- a/docs/BuildDocs.fsx +++ b/docs/BuildDocs.fsx @@ -79,7 +79,7 @@ Literate.ProcessScriptFile(relative "input/examples-stochasticgradientdescent.fs // Generate API reference // -let library = relative 
"../src/DiffSharp/bin/Debug/DiffSharp.dll" +let library = relative "../src/DiffSharp/bin/Debug/netstandard2.0/DiffSharp.dll" let layoutRoots = [relative "input/templates"; relative "input/templates/reference" ] MetadataFormat.Generate(library, relative "output/reference", layoutRoots, tags, markDownComments = true) diff --git a/docs/input/api-overview.fsx b/docs/input/api-overview.fsx index 40a77978a..2ed89eb5a 100644 --- a/docs/input/api-overview.fsx +++ b/docs/input/api-overview.fsx @@ -1,5 +1,5 @@ (*** hide ***) -#r "../../src/DiffSharp/bin/Debug/DiffSharp.dll" +#r "../../src/DiffSharp/bin/Debug/netstandard2.0/DiffSharp.dll" (** API Overview diff --git a/docs/input/download.fsx b/docs/input/download.fsx index bf92c8b39..7cd002a17 100644 --- a/docs/input/download.fsx +++ b/docs/input/download.fsx @@ -1,5 +1,5 @@ (*** hide ***) -#r "../../src/DiffSharp/bin/Debug/DiffSharp.dll" +#r "../../src/DiffSharp/bin/Debug/netstandard2.0/DiffSharp.dll" (** Download diff --git a/docs/input/examples-gradientdescent.fsx b/docs/input/examples-gradientdescent.fsx index e07b0f1d9..c3c504891 100644 --- a/docs/input/examples-gradientdescent.fsx +++ b/docs/input/examples-gradientdescent.fsx @@ -1,5 +1,5 @@ (*** hide ***) -#r "../../src/DiffSharp/bin/Debug/DiffSharp.dll" +#r "../../src/DiffSharp/bin/Debug/netstandard2.0/DiffSharp.dll" (** Gradient Descent diff --git a/docs/input/examples-hamiltonianmontecarlo.fsx b/docs/input/examples-hamiltonianmontecarlo.fsx index 6d90a1f77..3aca0503c 100644 --- a/docs/input/examples-hamiltonianmontecarlo.fsx +++ b/docs/input/examples-hamiltonianmontecarlo.fsx @@ -1,5 +1,5 @@ (*** hide ***) -#r "../../src/DiffSharp/bin/Debug/DiffSharp.dll" +#r "../../src/DiffSharp/bin/Debug/netstandard2.0/DiffSharp.dll" #load "../../packages/FSharp.Charting/FSharp.Charting.fsx" (** diff --git a/docs/input/examples-helmholtzenergyfunction.fsx b/docs/input/examples-helmholtzenergyfunction.fsx index 61c8b8ab3..798f5fa2f 100644 --- 
a/docs/input/examples-helmholtzenergyfunction.fsx +++ b/docs/input/examples-helmholtzenergyfunction.fsx @@ -1,5 +1,5 @@ (*** hide ***) -#r "../../src/DiffSharp/bin/Debug/DiffSharp.dll" +#r "../../src/DiffSharp/bin/Debug/netstandard2.0/DiffSharp.dll" #load "../../packages/FSharp.Charting/FSharp.Charting.fsx" (** diff --git a/docs/input/examples-inversekinematics.fsx b/docs/input/examples-inversekinematics.fsx index fd4bdfbbf..279844ed8 100644 --- a/docs/input/examples-inversekinematics.fsx +++ b/docs/input/examples-inversekinematics.fsx @@ -1,5 +1,5 @@ (*** hide ***) -#r "../../src/DiffSharp/bin/Debug/DiffSharp.dll" +#r "../../src/DiffSharp/bin/Debug/netstandard2.0/DiffSharp.dll" #load "../../packages/FSharp.Charting/FSharp.Charting.fsx" #load "EventEx-0.1.fsx" diff --git a/docs/input/examples-kinematics.fsx b/docs/input/examples-kinematics.fsx index e3d7840a1..44ace66cd 100644 --- a/docs/input/examples-kinematics.fsx +++ b/docs/input/examples-kinematics.fsx @@ -1,5 +1,5 @@ (*** hide ***) -#r "../../src/DiffSharp/bin/Debug/DiffSharp.dll" +#r "../../src/DiffSharp/bin/Debug/netstandard2.0/DiffSharp.dll" #load "../../packages/FSharp.Charting/FSharp.Charting.fsx" (** diff --git a/docs/input/examples-kmeansclustering.fsx b/docs/input/examples-kmeansclustering.fsx index 882b26b0b..6bf772428 100644 --- a/docs/input/examples-kmeansclustering.fsx +++ b/docs/input/examples-kmeansclustering.fsx @@ -1,5 +1,5 @@ (*** hide ***) -#r "../../src/DiffSharp/bin/Debug/DiffSharp.dll" +#r "../../src/DiffSharp/bin/Debug/netstandard2.0/DiffSharp.dll" #r "../../packages/FSharp.Data/lib/net40/FSharp.Data.dll" #load "../../packages/FSharp.Charting/FSharp.Charting.fsx" diff --git a/docs/input/examples-lhopitalsrule.fsx b/docs/input/examples-lhopitalsrule.fsx index 4a414ef46..680fcaf33 100644 --- a/docs/input/examples-lhopitalsrule.fsx +++ b/docs/input/examples-lhopitalsrule.fsx @@ -1,5 +1,5 @@ (*** hide ***) -#r "../../src/DiffSharp/bin/Debug/DiffSharp.dll" +#r 
"../../src/DiffSharp/bin/Debug/netstandard2.0/DiffSharp.dll" (** diff --git a/docs/input/examples-neuralnetworks.fsx b/docs/input/examples-neuralnetworks.fsx index 00fb8330d..ed4d92721 100644 --- a/docs/input/examples-neuralnetworks.fsx +++ b/docs/input/examples-neuralnetworks.fsx @@ -1,5 +1,5 @@ (*** hide ***) -#r "../../src/DiffSharp/bin/Debug/DiffSharp.dll" +#r "../../src/DiffSharp/bin/Debug/netstandard2.0/DiffSharp.dll" #load "../../packages/FSharp.Charting/FSharp.Charting.fsx" (** diff --git a/docs/input/examples-newtonsmethod.fsx b/docs/input/examples-newtonsmethod.fsx index 049dc9e64..2197c7773 100644 --- a/docs/input/examples-newtonsmethod.fsx +++ b/docs/input/examples-newtonsmethod.fsx @@ -1,5 +1,5 @@ (*** hide ***) -#r "../../src/DiffSharp/bin/Debug/DiffSharp.dll" +#r "../../src/DiffSharp/bin/Debug/netstandard2.0/DiffSharp.dll" (** Newton's Method diff --git a/docs/input/examples-stochasticgradientdescent.fsx b/docs/input/examples-stochasticgradientdescent.fsx index 9abf1a553..d38fd51d9 100644 --- a/docs/input/examples-stochasticgradientdescent.fsx +++ b/docs/input/examples-stochasticgradientdescent.fsx @@ -1,5 +1,5 @@ (*** hide ***) -#r "../../src/DiffSharp/bin/Debug/DiffSharp.dll" +#r "../../src/DiffSharp/bin/Debug/netstandard2.0/DiffSharp.dll" #load "../../packages/FSharp.Charting/FSharp.Charting.fsx" (** diff --git a/docs/input/gettingstarted-nestedad.fsx b/docs/input/gettingstarted-nestedad.fsx index d7dd72862..5e7515e43 100644 --- a/docs/input/gettingstarted-nestedad.fsx +++ b/docs/input/gettingstarted-nestedad.fsx @@ -1,5 +1,5 @@ (*** hide ***) -#r "../../src/DiffSharp/bin/Debug/DiffSharp.dll" +#r "../../src/DiffSharp/bin/Debug/netstandard2.0/DiffSharp.dll" (** Nested AD diff --git a/docs/input/gettingstarted-numericaldifferentiation.fsx b/docs/input/gettingstarted-numericaldifferentiation.fsx index b619491a9..cf4ffc727 100644 --- a/docs/input/gettingstarted-numericaldifferentiation.fsx +++ b/docs/input/gettingstarted-numericaldifferentiation.fsx 
@@ -1,5 +1,5 @@ (*** hide ***) -#r "../../src/DiffSharp/bin/Debug/DiffSharp.dll" +#r "../../src/DiffSharp/bin/Debug/netstandard2.0/DiffSharp.dll" (** Numerical Differentiation diff --git a/docs/input/gettingstarted-symbolicdifferentiation.fsx b/docs/input/gettingstarted-symbolicdifferentiation.fsx index 9d7b0f0f1..dc54c5372 100644 --- a/docs/input/gettingstarted-symbolicdifferentiation.fsx +++ b/docs/input/gettingstarted-symbolicdifferentiation.fsx @@ -1,5 +1,5 @@ (*** hide ***) -#r "../../src/DiffSharp/bin/Debug/DiffSharp.dll" +#r "../../src/DiffSharp/bin/Debug/netstandard2.0/DiffSharp.dll" (** Symbolic Differentiation diff --git a/docs/input/gettingstarted-typeinference.fsx b/docs/input/gettingstarted-typeinference.fsx index 2e411125b..99506919d 100644 --- a/docs/input/gettingstarted-typeinference.fsx +++ b/docs/input/gettingstarted-typeinference.fsx @@ -1,5 +1,5 @@ (*** hide ***) -#r "../../src/DiffSharp/bin/Debug/DiffSharp.dll" +#r "../../src/DiffSharp/bin/Debug/netstandard2.0/DiffSharp.dll" (** Type Inference diff --git a/docs/input/index.fsx b/docs/input/index.fsx index dd54e768d..3a7960128 100644 --- a/docs/input/index.fsx +++ b/docs/input/index.fsx @@ -1,10 +1,12 @@ (*** hide ***) -#r "../../src/DiffSharp/bin/Debug/DiffSharp.dll" +#r "../../src/DiffSharp/bin/Debug/netstandard2.0/DiffSharp.dll" (** -DiffSharp: Differentiable Functional Programming +DiffSharp: Differentiable Functional Programming (v0.8) ================================================ +**For DiffSharp 1.0 see https://diffsharp.github.io/** + DiffSharp is a functional [automatic differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation) (AD) library. AD allows exact and efficient calculation of derivatives, by systematically invoking the chain rule of calculus at the elementary operator level during program execution. 
AD is different from [numerical differentiation](https://en.wikipedia.org/wiki/Numerical_differentiation), which is prone to truncation and round-off errors, and [symbolic differentiation](https://en.wikipedia.org/wiki/Symbolic_computation), which is affected by expression swell and cannot fully handle algorithmic control flow. @@ -24,10 +26,10 @@ DiffSharp is implemented in the F# language and [can be used from C#](csharp.htm -Current Features and Roadmap +Features ---------------------------- -The following features are up and running: +The following features are in v0.8: - _Functional nested differentiation with linear algebra primitives, supporting forward and reverse AD, or any combination thereof, up to any level_ - _Matrix-free Jacobian- and Hessian-vector products_ @@ -35,14 +37,7 @@ The following features are up and running: - _Parallel implementations of non-BLAS operations (e.g. Hadamard products, matrix transpose)_ - _Support for 32- and 64-bit floating point precision (32 bit float operations run significantly faster on many systems)_ -Possible future features include: - -- _GPU backend using CUDA/OpenCL_ -- _Generalization to tensors/multidimensional arrays_ -- _Improved Hessian calculations exploiting sparsity structure (e.g. matrix-coloring)_ -- _AD via syntax tree transformation, using code quotations_ - -At this point we are debugging algorithmic complexity and the APIs. We are hoping the community will help us get the API right and ensure that the latest models can make use of DiffSharp as succinctly and as cleanly as possible, which would make it convenient to use in production. 
+For DiffSharp 1.0 see https://diffsharp.github.io/ How to Get ---------- diff --git a/src/DiffSharp/AD.Float32.fs b/src/DiffSharp/AD.Float32.fs index 5ccd15a81..cab99d13d 100644 --- a/src/DiffSharp/AD.Float32.fs +++ b/src/DiffSharp/AD.Float32.fs @@ -12,13 +12,13 @@ module DiffSharp.AD.Float32 open DiffSharp.Util open DiffSharp.Config -open System.Threading.Tasks open System.Collections.Generic type number = float32 let inline Backend<'T> = GlobalConfig.Float32Backend let inline VisualizationContrast<'T> = GlobalConfig.Float32VisualizationContrast let inline FixedPointEpsilon<'T> = GlobalConfig.Float32FixedPointEpsilon + module N = let inline toNumber x = float32 x let inline failWithInvalidTypeMessage () = failwith "Unsupported type. Expecting D, float32, or int." @@ -35,24 +35,27 @@ module N = /// with nesting capability, using tags to avoid perturbation confusion [] type D = + /// Primal | D of number + /// Primal, tangent, layer tag (for forward mode) - | DF of primal: D * tanget: D * tag: uint32 + | DF of primal: D * tanget: D * tag: uint32 + /// Primal, parent, layer tag (for reverse mode) - | DR of primal: D * parentOperation: TraceOp * tag: uint32 * uniq: int32 + | DR of primal: D * adjoint: (D ref) * parentOperation: TraceOp * fanOutCounter: (uint32 ref) * tag: uint32 interface dobj /// Make a reverse node - static member R(d, op, ai) = DR(d, op, ai, UniqueTagger.Next()) + static member R(d,op,ai) = DR(d, ref D.Zero, op, ref 0u, ai) /// Primal value of this D member d.P = match d with | D _ -> d | DF(ap, _, _) -> ap - | DR(ap, _, _, _) -> ap + | DR(ap, _, _, _, _) -> ap /// Deepest primal value of this D member d.PD = @@ -60,7 +63,7 @@ type D = match x with | D _ -> x | DF(xp, _, _) -> prec xp - | DR(xp, _, _, _) -> prec xp + | DR(xp, _, _, _, _) -> prec xp prec d /// Tangent value of this D @@ -70,11 +73,37 @@ type D = | DF(_, at, _) -> at | DR _ -> failwith "Cannot get tangent value of DR." 
+ /// Adjoint value of this D + member d.A + with get() : D = + match d with + | D _ -> D.Zero + | DF(_,_,_) -> failwith "Cannot get adjoint value of DF." + | DR(_,a,_,_,_) -> !a + and set(v: D) = + match d with + | D _ -> () + | DF (_,_,_) -> failwith "Cannot set adjoint value of DF." + | DR (_,a,_,_,_) -> a := v + + /// Fan-out counter of this D + member d.F + with get() = + match d with + | D _ -> failwith "Cannot get fan-out value of D." + | DF (_,_,_) -> failwith "Cannot get fan-out value of DF." + | DR (_,_,_,f,_) -> !f + and set(v) = + match d with + | D _ -> failwith "Cannot set fan-out value of D." + | DF (_,_,_) -> failwith "Cannot set fan-out value of DF." + | DR (_,_,_,f,_) -> f := v + member d.GetForward(t:D, i:uint32) = DF(d, t, i) member d.GetReverse(i:uint32) = D.R(d, Noop, i) - static member Zero = D N.zero + static member Zero : D = D N.zero static member One = D N.one @@ -83,7 +112,7 @@ type D = match x with | D(p) -> p | DF(xp, _, _) -> prec xp - | DR(xp, _, _, _) -> prec xp + | DR(xp, _, _, _, _) -> prec xp prec d interface System.IComparable with @@ -101,7 +130,7 @@ type D = match d with | D(ap) -> hash [|ap|] | DF(ap, at, ai) -> hash [|ap; at; ai|] - | DR(ap, ao, ai, _) -> hash [|ap; ao; ai|] + | DR(ap, _, ao, _, ai) -> hash [|ap; ao; ai|] override d.ToString() = let (d':number) = D.op_Explicit(d) @@ -112,9 +141,9 @@ type D = static member inline Op_D_D (a, ff, fd, df, r) = match a with - | D(ap) -> D(ff(ap)) - | DF(ap, at, ai) -> let cp = fd(ap) in DF(cp, df(cp, ap, at), ai) - | DR(ap, _, ai, _) -> D.R(fd(ap), r(a), ai) + | D(ap) -> D(ff(ap)) + | DF(ap, at, ai) -> let cp = fd(ap) in DF(cp, df(cp, ap, at), ai) + | DR(ap,_,_,_,ai) -> D.R(fd(ap), r(a), ai) static member inline Op_D_D_D (a, b, ff, fd, df_da, df_db, df_dab, r_d_d, r_d_c, r_c_d) = match a with @@ -122,7 +151,7 @@ type D = match b with | D(bp) -> D(ff(ap, bp)) | DF(bp, bt, bi) -> let cp = fd(a, bp) in DF(cp, df_db(cp, bp, bt), bi) - | DR(bp, _, bi, _) -> D.R(fd(a, bp), r_c_d(a, 
b), bi) + | DR(bp, _, _, _, bi) -> D.R(fd(a, bp), r_c_d(a, b), bi) | DF(ap, at, ai) -> match b with | D _ -> let cp = fd(ap, b) in DF(cp, df_da(cp, ap, at), ai) @@ -131,12 +160,12 @@ type D = | 0 -> let cp = fd(ap, bp) in DF(cp, df_dab(cp, ap, at, bp, bt), ai) // ai = bi | -1 -> let cp = fd(a, bp) in DF(cp, df_db(cp, bp, bt), bi) // ai < bi | _ -> let cp = fd(ap, b) in DF(cp, df_da(cp, ap, at), ai) // ai > bi - | DR(bp, _, bi, _) -> + | DR(bp, _, _, _, bi) -> match compare ai bi with | -1 -> D.R(fd(a, bp), r_c_d(a, b), bi) // ai < bi | 1 -> let cp = fd(ap, b) in DF(cp, df_da(cp, ap, at), ai) // ai > bi | _ -> failwith "Forward and reverse AD cannot run on the same level." - | DR(ap, _, ai, _) -> + | DR(ap, _, _, _, ai) -> match b with | D _ -> D.R(fd(ap, b), r_d_c(a, b), ai) | DF(bp, bt, bi) -> @@ -144,7 +173,7 @@ type D = | -1 -> let cp = fd(a, bp) in DF(cp, df_db(cp, bp, bt), bi) // ai < bi | 1 -> D.R(fd(ap, b), r_d_c(a, b), ai) // ai > bi | _ -> failwith "Forward and reverse AD cannot run on the same level." 
- | DR(bp, _, bi, _) -> + | DR(bp, _, _, _, bi) -> match compare ai bi with | 0 -> D.R(fd(ap, bp), r_d_d(a, b), ai) // ai = bi | -1 -> D.R(fd(a, bp), r_c_d(a, b), bi) // ai < bi @@ -439,7 +468,7 @@ type D = i <- imax a <- aa DF(a.P, a.T, bi) - | DR(bp, _, bi, _) -> + | DR(bp,_,_,_,bi) -> let bfirst = D.R(bp, Noop, bi) // Cut the connection between b and bfirst ("switch of graph construction" involving b beyond this point) while i < imax do i <- i + 1 @@ -461,7 +490,7 @@ type D = and DV = | DV of number[] // Primal | DVF of DV * DV * uint32 // Primal, tangent, layer tag - | DVR of DV * TraceOp * uint32 * int32 // Primal, parent operation, layer tag, unique + | DVR of primal: DV * adjoint: (DV ref) * TraceOp * (uint32 ref) * uint32 // Primal, adjoint, parent operation, fan-out counter, tag interface dobj @@ -470,7 +499,7 @@ and DV = match d with | DV _ -> d | DVF(ap, _, _) -> ap - | DVR(ap, _, _, _) -> ap + | DVR(ap,_,_,_,_) -> ap /// Deepest primal value of this DV member d.PD = @@ -478,15 +507,41 @@ and DV = match x with | DV _ -> x | DVF(xp, _, _) -> prec xp - | DVR(xp, _, _, _) -> prec xp + | DVR(xp,_,_,_,_) -> prec xp prec d /// Tangent value of this DV member d.T = match d with - | DV _ -> DV.ZeroN d.Length - | DVF(_, at, _) -> at - | DVR _ -> failwith "Cannot get tangent value of DVR." + | DV(_) -> DV.ZeroN d.Length + | DVF(_,at,_) -> at + | DVR(_,_,_,_,_) -> failwith "Cannot get tangent value of DVR." + + /// Adjoint value of this DV + member d.A + with get() : DV = + match d with + | DV(_) -> DV.ZeroN d.Length + | DVF(_,_,_) -> failwith "Cannot get adjoint value of DVF." + | DVR(_,a,_,_,_) -> !a + and set(v: DV) = + match d with + | DV(_) -> () + | DVF(_,_,_) -> failwith "Cannot set adjoint value of DVF." + | DVR(_,a,_,_,_) -> a := v + + /// Fan-out counter of this DV + member d.F + with get() = + match d with + | DV(_) -> failwith "Cannot get fan-out value of DV." + | DVF(_,_,_) -> failwith "Cannot get fan-out value of DVF." 
+ | DVR(_,_,_,f,_) -> !f + and set(v) = + match d with + | DV(_) -> failwith "Cannot set fan-out value of DV." + | DVF(_,_,_) -> failwith "Cannot set fan-out value of DVF." + | DVR(_,_,_,f,_) -> f := v /// Convert to use forward AD at this layer member d.GetForward(t:DV, i:uint32) = DVF(d, t, i) @@ -495,20 +550,20 @@ and DV = member d.GetReverse(i:uint32) = DV.R(d, Noop, i) /// Make a reverse node - static member R(d, op, ai) = DVR(d, op, ai, UniqueTagger.Next()) + static member R(d,op,ai) = DVR(d, ref (DV.ZeroN d.Length), op, ref 0u, ai) member d.Length = match d with | DV(ap) -> ap.Length | DVF(ap, _, _) -> ap.Length - | DVR(ap, _, _, _) -> ap.Length + | DVR(ap, _, _, _, _) -> ap.Length member d.Item with get i = match d with | DV(ap) -> D(ap.[i]) | DVF(ap, at, ai) -> DF(ap.[i], at.[i], ai) - | DVR(ap, _, ai, _) -> D.R(ap.[i], Item_DV(d, i), ai) + | DVR(ap, _, _, _, ai) -> D.R(ap.[i], Item_DV(d, i), ai) member d.GetSlice(lower, upper) = let l = defaultArg lower 0 @@ -516,21 +571,21 @@ and DV = match d with | DV(ap) -> DV(ap.[l..u]) | DVF(ap, at, ai) -> DVF(ap.[l..u], at.[l..u], ai) - | DVR(ap, _, ai, _) -> let cp = ap.[l..u] in DV.R(cp, Slice_DV(d, l), ai) + | DVR(ap, _, _, _, ai) -> let cp = ap.[l..u] in DV.R(cp, Slice_DV(d, l), ai) member d.ToArray() = match d with | DV(ap) -> ap |> Array.map D | DVF(ap, at, ai) -> Array.init ap.Length (fun i -> DF(ap.[i], at.[i], ai)) - | DVR(ap, _, ai, _) -> + | DVR(ap, _, _, _, ai) -> Array.init ap.Length (fun i -> D.R(ap.[i], Item_DV(d, i), ai)) member d.ToRowDM() = match d with | DV(ap) -> seq [ap] |> array2D |> DM | DVF(ap, at, ai) -> DMF(ap.ToRowDM(), at.ToRowDM(), ai) - | DVR(ap, _, ai, _) -> let cp = ap.ToRowDM() in DM.R(cp, RowMatrix_DV(d), ai) + | DVR(ap, _, _, _, ai) -> let cp = ap.ToRowDM() in DM.R(cp, RowMatrix_DV(d), ai) member d.ToColDM() = DM.Transpose(d.ToRowDM()) @@ -574,7 +629,7 @@ and DV = match x with | DV(p) -> p | DVF(xp, _, _) -> prec xp - | DVR(xp, _, _, _) -> prec xp + | DVR(xp,_,_,_,_) -> prec xp 
prec d static member op_Explicit(d) = DV(d) @@ -587,7 +642,7 @@ and DV = let ap = a |> Array.map (fun x -> x.P) let at = a |> Array.map (fun x -> x.T) DVF(DV.OfArray(ap), DV.OfArray(at), ai) - | DR(_, _, ai, _) -> + | DR(_,_,_,_,ai) -> let ap = a |> Array.map (fun x -> x.P) let cp = DV.OfArray(ap) in DV.R(cp, Make_DV_ofDs(a), ai) @@ -600,7 +655,7 @@ and DV = let aps = DV.Split(ap, n) let ats = DV.Split(at, n) Seq.map2 (fun p t -> DVF(p, t, ai)) aps ats - | DVR(ap, _, ai, _) -> + | DVR(ap, _, _, _, ai) -> let aps = DV.Split(ap, n) let ii = n |> Seq.mapFold (fun s i -> s, s + i) 0 |> fst |> Array.ofSeq Seq.mapi (fun i p -> DV.R(p, Split_DV(d, ii.[i]), ai)) aps @@ -610,19 +665,19 @@ and DV = match a with | DV(ap) -> DV(ff(ap)) | DVF(ap, at, ai) -> let cp = fd(ap) in DVF(cp, df(cp, ap, at), ai) - | DVR(ap, _, ai, _) -> let cp = fd(ap) in DV.R(cp, r(a), ai) + | DVR(ap,_,_,_,ai) -> let cp = fd(ap) in DV.R(cp, r(a), ai) static member inline Op_DV_DM (a, ff, fd, df, r) = match a with | DV(ap) -> DM(ff(ap)) | DVF(ap, at, ai) -> let cp = fd(ap) in DMF(cp, df(cp, ap, at), ai) - | DVR(ap, _, ai, _) -> let cp = fd(ap) in DM.R(cp, r(a), ai) + | DVR(ap,_,_,_,ai) -> let cp = fd(ap) in DM.R(cp, r(a), ai) static member inline Op_DV_D (a, ff, fd, df, r) = match a with | DV(ap) -> D(ff(ap)) | DVF(ap, at, ai) -> let cp = fd(ap) in DF(cp, df(cp, ap, at), ai) - | DVR(ap, _, ai, _) -> let cp = fd(ap) in D.R(cp, r(a), ai) + | DVR(ap,_,_,_,ai) -> let cp = fd(ap) in D.R(cp, r(a), ai) static member inline Op_DV_DV_DV (a, b, ff, fd, df_da, df_db, df_dab, r_d_d, r_d_c, r_c_d) = match a with @@ -630,7 +685,7 @@ and DV = match b with | DV(bp) -> DV(ff(ap, bp)) | DVF(bp, bt, bi) -> let cp = fd(a, bp) in DVF(cp, df_db(cp, bp, bt), bi) - | DVR(bp, _, bi, _) -> let cp = fd(a, bp) in DV.R(cp, r_c_d(a, b), bi) + | DVR(bp, _, _, _, bi) -> let cp = fd(a, bp) in DV.R(cp, r_c_d(a, b), bi) | DVF(ap, at, ai) -> match b with | DV _ -> let cp = fd(ap, b) in DVF(cp, df_da(cp, ap, at), ai) @@ -639,12 +694,12 @@ 
and DV = | 0 -> let cp = fd(ap, bp) in DVF(cp, df_dab(cp, ap, at, bp, bt), ai) // ai = bi | -1 -> let cp = fd(a, bp) in DVF(cp, df_db(cp, bp, bt), bi) // ai < bi | _ -> let cp = fd(ap, b) in DVF(cp, df_da(cp, ap, at), ai) // ai > bi - | DVR(bp, _, bi, _) -> + | DVR(bp, _, _, _, bi) -> match compare ai bi with | -1 -> let cp = fd(a, bp) in DV.R(cp, r_c_d(a, b), bi) // ai < bi | 1 -> let cp = fd(ap, b) in DVF(cp, df_da(cp, ap, at), ai) // ai > bi | _ -> failwith "Forward and reverse AD cannot run on the same level." - | DVR(ap, _, ai, _) -> + | DVR(ap, _, _, _, ai) -> match b with | DV _ -> let cp = fd(ap, b) in DV.R(cp, r_d_c(a, b), ai) | DVF(bp, bt, bi) -> @@ -652,7 +707,7 @@ and DV = | -1 -> let cp = fd(a, bp) in DVF(cp, df_db(cp, bp, bt), bi) // ai < bi | 1 -> let cp = fd(ap, b) in DV.R(cp, r_d_c(a, b), ai) // ai > bi | _ -> failwith "Forward and reverse AD cannot run on the same level." - | DVR(bp, _, bi, _) -> + | DVR(bp, _, _, _, bi) -> match compare ai bi with | 0 -> let cp = fd(ap, bp) in DV.R(cp, r_d_d(a, b), ai) // ai = bi | -1 -> let cp = fd(a, bp) in DV.R(cp, r_c_d(a, b), bi) // ai < bi @@ -664,7 +719,7 @@ and DV = match b with | DV(bp) -> DM(ff(ap, bp)) | DVF(bp, bt, bi) -> let cp = fd(a, bp) in DMF(cp, df_db(cp, bp, bt), bi) - | DVR(bp, _, bi, _) -> let cp = fd(a, bp) in DM.R(cp, r_c_d(a, b), bi) + | DVR(bp, _, _, _, bi) -> let cp = fd(a, bp) in DM.R(cp, r_c_d(a, b), bi) | DVF(ap, at, ai) -> match b with | DV _ -> let cp = fd(ap, b) in DMF(cp, df_da(cp, ap, at), ai) @@ -673,12 +728,12 @@ and DV = | 0 -> let cp = fd(ap, bp) in DMF(cp, df_dab(cp, ap, at, bp, bt), ai) // ai = bi | -1 -> let cp = fd(a, bp) in DMF(cp, df_db(cp, bp, bt), bi) // ai < bi | _ -> let cp = fd(ap, b) in DMF(cp, df_da(cp, ap, at), ai) // ai > bi - | DVR(bp, _, bi, _) -> + | DVR(bp, _, _, _, bi) -> match compare ai bi with | -1 -> let cp = fd(a, bp) in DM.R(cp, r_c_d(a, b), bi) // ai < bi | 1 -> let cp = fd(ap, b) in DMF(cp, df_da(cp, ap, at), ai) // ai > bi | _ -> failwith "Forward 
and reverse AD cannot run on the same level." - | DVR(ap, _, ai, _) -> + | DVR(ap, _, _, _, ai) -> match b with | DV _ -> let cp = fd(ap, b) in DM.R(cp, r_d_c(a, b), ai) | DVF(bp, bt, bi) -> @@ -686,7 +741,7 @@ and DV = | -1 -> let cp = fd(a, bp) in DMF(cp, df_db(cp, bp, bt), bi) // ai < bi | 1 -> let cp = fd(ap, b) in DM.R(cp, r_d_c(a, b), ai) // ai > bi | _ -> failwith "Forward and reverse AD cannot run on the same level." - | DVR(bp, _, bi, _) -> + | DVR(bp, _, _, _, bi) -> match compare ai bi with | 0 -> let cp = fd(ap, bp) in DM.R(cp, r_d_d(a, b), ai) // ai = bi | -1 -> let cp = fd(a, bp) in DM.R(cp, r_c_d(a, b), bi) // ai < bi @@ -698,7 +753,7 @@ and DV = match b with | DV(bp) -> D(ff(ap, bp)) | DVF(bp, bt, bi) -> let cp = fd(a, bp) in DF(cp, df_db(cp, bp, bt), bi) - | DVR(bp, _, bi, _) -> D.R(fd(a, bp), r_c_d(a, b), bi) + | DVR(bp, _, _, _, bi) -> D.R(fd(a, bp), r_c_d(a, b), bi) | DVF(ap, at, ai) -> match b with | DV _ -> let cp = fd(ap, b) in DF(cp, df_da(cp, ap, at), ai) @@ -707,12 +762,12 @@ and DV = | 0 -> let cp = fd(ap, bp) in DF(cp, df_dab(cp, ap, at, bp, bt), ai) // ai = bi | -1 -> let cp = fd(a, bp) in DF(cp, df_db(cp, bp, bt), bi) // ai < bi | _ -> let cp = fd(ap, b) in DF(cp, df_da(cp, ap, at), ai) // ai > bi - | DVR(bp, _, bi, _) -> + | DVR(bp, _, _, _, bi) -> match compare ai bi with | -1 -> D.R(fd(a, bp), r_c_d(a, b), bi) // ai < bi | 1 -> let cp = fd(ap, b) in DF(cp, df_da(cp, ap, at), ai) // ai > bi | _ -> failwith "Forward and reverse AD cannot run on the same level." - | DVR(ap, _, ai, _) -> + | DVR(ap, _, _, _, ai) -> match b with | DV _ -> D.R(fd(ap, b), r_d_c(a, b), ai) | DVF(bp, bt, bi) -> @@ -720,7 +775,7 @@ and DV = | -1 -> let cp = fd(a, bp) in DF(cp, df_db(cp, bp, bt), bi) // ai < bi | 1 -> D.R(fd(ap, b), r_d_c(a, b), ai) // ai > bi | _ -> failwith "Forward and reverse AD cannot run on the same level." 
- | DVR(bp, _, bi, _) -> + | DVR(bp, _, _, _, bi) -> match compare ai bi with | 0 -> D.R(fd(ap, bp), r_d_d(a, b), ai) // ai = bi | -1 -> D.R(fd(a, bp), r_c_d(a, b), bi) // ai < bi @@ -732,7 +787,7 @@ and DV = match b with | D(bp) -> DV(ff(ap, bp)) | DF(bp, bt, bi) -> let cp = fd(a, bp) in DVF(cp, df_db(cp, bp, bt), bi) - | DR(bp, _, bi, _) -> let cp = fd(a, bp) in DV.R(cp, r_c_d(a, b), bi) + | DR(bp, _, _, _, bi) -> let cp = fd(a, bp) in DV.R(cp, r_c_d(a, b), bi) | DVF(ap, at, ai) -> match b with | D _ -> let cp = fd(ap, b) in DVF(cp, df_da(cp, ap, at), ai) @@ -741,12 +796,12 @@ and DV = | 0 -> let cp = fd(ap, bp) in DVF(cp, df_dab(cp, ap, at, bp, bt), ai) // ai = bi | -1 -> let cp = fd(a, bp) in DVF(cp, df_db(cp, bp, bt), bi) // ai < bi | _ -> let cp = fd(ap, b) in DVF(cp, df_da(cp, ap, at), ai) // ai > bi - | DR(bp, _, bi, _) -> + | DR(bp, _, _, _, bi) -> match compare ai bi with | -1 -> let cp = fd(a, bp) in DV.R(cp, r_c_d(a, b), bi) // ai < bi | 1 -> let cp = fd(ap, b) in DVF(cp, df_da(cp, ap, at), ai) // ai > bi | _ -> failwith "Forward and reverse AD cannot run on the same level." - | DVR(ap, _, ai, _) -> + | DVR(ap, _, _, _, ai) -> match b with | D _ -> let cp = fd(ap, b) in DV.R(cp, r_d_c(a, b), ai) | DF(bp, bt, bi) -> @@ -754,7 +809,7 @@ and DV = | -1 -> let cp = fd(a, bp) in DVF(cp, df_db(cp, bp, bt), bi) // ai < bi | 1 -> let cp = fd(ap, b) in DV.R(cp, r_d_c(a, b), ai) // ai > bi | _ -> failwith "Forward and reverse AD cannot run on the same level." 
- | DR(bp, _, bi, _) -> + | DR(bp, _, _, _, bi) -> match compare ai bi with | 0 -> let cp = fd(ap, bp) in DV.R(cp, r_d_d(a, b), ai) // ai = bi | -1 -> let cp = fd(a, bp) in DV.R(cp, r_c_d(a, b), bi) // ai < bi @@ -767,7 +822,7 @@ and DV = match b with | DV(bp) -> DV(ff(ap, bp)) | DVF(bp, bt, bi) -> let cp = fd(a, bp) in DVF(cp, df_db(cp, bp, bt), bi) - | DVR(bp, _, bi, _) -> let cp = fd(a, bp) in DV.R(cp, r_c_d(a, b), bi) + | DVR(bp, _, _, _, bi) -> let cp = fd(a, bp) in DV.R(cp, r_c_d(a, b), bi) | DF(ap, at, ai) -> match b with | DV _ -> let cp = fd(ap, b) in DVF(cp, df_da(cp, ap, at), ai) @@ -776,12 +831,12 @@ and DV = | 0 -> let cp = fd(ap, bp) in DVF(cp, df_dab(cp, ap, at, bp, bt), ai) // ai = bi | -1 -> let cp = fd(a, bp) in DVF(cp, df_db(cp, bp, bt), bi) // ai < bi | _ -> let cp = fd(ap, b) in DVF(cp, df_da(cp, ap, at), ai) // ai > bi - | DVR(bp, _, bi, _) -> + | DVR(bp, _, _, _, bi) -> match compare ai bi with | -1 -> let cp = fd(a, bp) in DV.R(cp, r_c_d(a, b), bi) // ai < bi | 1 -> let cp = fd(ap, b) in DVF(cp, df_da(cp, ap, at), ai) // ai > bi | _ -> failwith "Forward and reverse AD cannot run on the same level." - | DR(ap, _, ai, _) -> + | DR(ap, _, _, _, ai) -> match b with | DV _ -> let cp = fd(ap, b) in DV.R(cp, r_d_c(a, b), ai) | DVF(bp, bt, bi) -> @@ -789,7 +844,7 @@ and DV = | -1 -> let cp = fd(a, bp) in DVF(cp, df_db(cp, bp, bt), bi) // ai < bi | 1 -> let cp = fd(ap, b) in DV.R(cp, r_d_c(a, b), ai) // ai > bi | _ -> failwith "Forward and reverse AD cannot run on the same level." 
- | DVR(bp, _, bi, _) -> + | DVR(bp, _, _, _, bi) -> match compare ai bi with | 0 -> let cp = fd(ap, bp) in DV.R(cp, r_d_d(a, b), ai) // ai = bi | -1 -> let cp = fd(a, bp) in DV.R(cp, r_c_d(a, b), bi) // ai < bi @@ -1407,7 +1462,7 @@ and DM = /// Primal, tangent, layer tag (for forward mode) | DMF of primal: DM * tanget: DM * tag: uint32 /// Primal, parent, layer tag (for reverse mode) - | DMR of primal: DM * parentOperation: TraceOp * tag: uint32 * uniq: int32 + | DMR of primal: DM * adjoint: (DM ref) * parentOperation: TraceOp * fanOutCounter: (uint32 ref) * tag: uint32 interface dobj @@ -1416,7 +1471,7 @@ and DM = match d with | DM(_) -> d | DMF(ap, _, _) -> ap - | DMR(ap, _, _, _) -> ap + | DMR(ap,_,_,_,_) -> ap /// Deepest primal value of this DM member d.PD = @@ -1424,47 +1479,73 @@ and DM = match x with | DM(_) -> x | DMF(xp, _, _) -> prec xp - | DMR(xp, _, _, _) -> prec xp + | DMR(xp,_,_,_,_) -> prec xp prec d /// Tangent value of this DM member d.T = match d with - | DM(_) -> DM.ZeroMN d.Rows d.Cols + | DM _ -> DM.ZeroMN d.Rows d.Cols | DMF(_, at, _) -> at | DMR _ -> failwith "Cannot get tangent value of DMR." + /// Adjoint value of this DM + member d.A + with get() : DM = + match d with + | DM _ -> DM.ZeroMN d.Rows d.Cols + | DMF(_,_,_) -> failwith "Cannot get adjoint value of DMF." + | DMR(_,a,_,_,_) -> !a + and set(v: DM) = + match d with + | DM _ -> () + | DMF(_,_,_) -> failwith "Cannot set adjoint value of DMF." + | DMR(_,a,_,_,_) -> a := v + + /// Fan-out value of this DM + member d.F + with get() = + match d with + | DM _ -> failwith "Cannot get fan-out value of DM." + | DMF(_,_,_) -> failwith "Cannot get fan-out value of DMF." + | DMR(_,_,_,f,_) -> !f + and set(v) = + match d with + | DM(_) -> failwith "Cannot set fan-out value of DM." + | DMF(_,_,_) -> failwith "Cannot set fan-out value of DMF." 
+ | DMR(_,_,_,f,_) -> f := v + member d.GetForward(t:DM, i:uint32) = DMF(d, t, i) member d.GetReverse(i:uint32) = DM.R(d, Noop, i) /// Make a reverse node - static member R(cp, op, ai) = DMR(cp, op, ai, UniqueTagger.Next()) + static member R(cp,op,ai) = DMR(cp, ref (DM.ZeroMN cp.Rows cp.Cols), op, ref 0u, ai) member d.Length = match d with | DM(ap) -> ap.Length | DMF(ap, _, _) -> ap.Length - | DMR(ap, _, _, _) -> ap.Length + | DMR(ap,_,_,_,_) -> ap.Length member d.Rows = match d with | DM(ap) -> Array2D.length1 ap | DMF(ap, _, _) -> ap.Rows - | DMR(ap, _, _, _) -> ap.Rows + | DMR(ap, _, _, _, _) -> ap.Rows member d.Cols = match d with | DM(ap) -> Array2D.length2 ap | DMF(ap, _, _) -> ap.Cols - | DMR(ap, _, _, _) -> ap.Cols + | DMR(ap, _, _, _, _) -> ap.Cols member d.Item with get (i, j) = match d with | DM(ap) -> D(ap.[i, j]) | DMF(ap, at, ai) -> DF(ap.[i, j], at.[i, j], ai) - | DMR(ap, _, ai, _) -> D.R(ap.[i, j], Item_DM(d, i, j), ai) + | DMR(ap, _, _, _, ai) -> D.R(ap.[i, j], Item_DM(d, i, j), ai) member d.GetSlice(rowStart, rowFinish, colStart, colFinish) = let rowStart = defaultArg rowStart 0 @@ -1474,7 +1555,7 @@ and DM = match d with | DM(ap) -> DM(ap.[rowStart..rowFinish, colStart..colFinish]) | DMF(ap, at, ai) -> DMF(ap.[rowStart..rowFinish, colStart..colFinish], at.[rowStart..rowFinish, colStart..colFinish], ai) - | DMR(ap, _, ai, _) -> let cp = ap.[rowStart..rowFinish, colStart..colFinish] in DM.R(cp, Slice_DM(d, rowStart, rowFinish), ai) + | DMR(ap, _, _, _, ai) -> let cp = ap.[rowStart..rowFinish, colStart..colFinish] in DM.R(cp, Slice_DM(d, rowStart, colStart), ai) member d.GetSlice(row, colStart, colFinish) = let colStart = defaultArg colStart 0 @@ -1482,7 +1563,7 @@ and DM = match d with | DM(ap) -> DV(ap.[row, colStart..colFinish]) | DMF(ap, at, ai) -> DVF(ap.[row, colStart..colFinish], at.[row, colStart..colFinish], ai) - | DMR(ap, _, ai, _) -> let cp = ap.[row, colStart..colFinish] in DV.R(cp, SliceRow_DM(d, row, colStart), ai) + | DMR(ap, _, _, 
_, ai) -> let cp = ap.[row, colStart..colFinish] in DV.R(cp, SliceRow_DM(d, row, colStart), ai) member d.GetSlice(rowStart, rowFinish, col) = let rowStart = defaultArg rowStart 0 @@ -1490,7 +1571,7 @@ and DM = match d with | DM(ap) -> DV(ap.[rowStart..rowFinish, col]) | DMF(ap, at, ai) -> DVF(ap.[rowStart..rowFinish, col], at.[rowStart..rowFinish, col], ai) - | DMR(ap, _, ai, _) -> let cp = ap.[rowStart..rowFinish, col] in DV.R(cp, SliceCol_DM(d, rowStart, col), ai) + | DMR(ap, _, _, _, ai) -> let cp = ap.[rowStart..rowFinish, col] in DV.R(cp, SliceCol_DM(d, rowStart, col), ai) member d.GetRows() = seq {for i = 0 to d.Rows - 1 do yield d.[i, *]} @@ -1546,7 +1627,7 @@ and DM = match x with | DM(p) -> p | DMF(xp, _, _) -> prec xp - | DMR(xp, _, _, _) -> prec xp + | DMR(xp, _, _, _, _) -> prec xp prec d static member op_Explicit(d:number[, ]) = DM(d) @@ -1559,7 +1640,7 @@ and DM = let ap = a |> Array2D.map (fun x -> x.P) let at = a |> Array2D.map (fun x -> x.T) DMF(DM.OfArray2D(ap), DM.OfArray2D(at), ai) - | DR(_, _, ai, _) -> + | DR(_, _, _, _, ai) -> let ap = a |> Array2D.map (fun x -> x.P) let cp = DM.OfArray2D(ap) in DM.R(cp, Make_DM_ofDs(a), ai) @@ -1577,7 +1658,7 @@ and DM = let ap = s |> Seq.map (fun x -> x.P) let at = s |> Seq.map (fun x -> x.T) DMF(DM.OfRows(ap), DM.OfRows(at), ai) - | DVR(_, _, ai, _) -> + | DVR(_, _, _, _, ai) -> let ap = s |> Seq.map (fun x -> x.P) let cp = DM.OfRows(ap) in DM.R(cp, Make_DMRows_ofDVs(s |> Seq.toArray), ai) @@ -1585,33 +1666,33 @@ and DM = match a with | DV(ap) -> DM(Backend.RepeatReshapeCopy_V_MRows(m, ap)) | DVF(ap, at, ai) -> DMF(DM.OfRows(m, ap), DM.OfRows(m, at), ai) - | DVR(ap, _, ai, _) -> + | DVR(ap, _, _, _, ai) -> let cp = DM.OfRows(m, ap) in DM.R(cp, Make_DMRows_ofDV(a), ai) static member OfCols (n:int, a:DV) = match a with | DV(ap) -> DM(Backend.RepeatReshapeCopy_V_MCols(n, ap)) | DVF(ap, at, ai) -> DMF(DM.OfCols(n, ap), DM.OfCols(n, at), ai) - | DVR(ap, _, ai, _) -> + | DVR(ap, _, _, _, ai) -> let cp = 
DM.OfCols(n, ap) in DM.R(cp, Make_DMCols_ofDV(a), ai) static member inline Op_DM_DM (a, ff, fd, df, r) = match a with | DM(ap) -> DM(ff(ap)) | DMF(ap, at, ai) -> let cp = fd(ap) in DMF(cp, df(cp, ap, at), ai) - | DMR(ap, _, ai, _) -> let cp = fd(ap) in DM.R(cp, r(a), ai) + | DMR(ap, _, _, _, ai) -> let cp = fd(ap) in DM.R(cp, r(a), ai) static member inline Op_DM_DV (a, ff, fd, df, r) = match a with | DM(ap) -> DV(ff(ap)) | DMF(ap, at, ai) -> let cp = fd(ap) in DVF(cp, df(cp, ap, at), ai) - | DMR(ap, _, ai, _) -> let cp = fd(ap) in DV.R(cp, r(a), ai) + | DMR(ap, _, _, _, ai) -> let cp = fd(ap) in DV.R(cp, r(a), ai) static member inline Op_DM_D (a, ff, fd, df, r) = match a with | DM(ap) -> D(ff(ap)) | DMF(ap, at, ai) -> let cp = fd(ap) in DF(cp, df(cp, ap, at), ai) - | DMR(ap, _, ai, _) -> let cp = fd(ap) in D.R(cp, r(a), ai) + | DMR(ap, _, _, _, ai) -> let cp = fd(ap) in D.R(cp, r(a), ai) static member inline Op_DM_DM_DM (a, b, ff, fd, df_da, df_db, df_dab, r_d_d, r_d_c, r_c_d) = match a with @@ -1619,7 +1700,7 @@ and DM = match b with | DM(bp) -> DM(ff(ap, bp)) | DMF(bp, bt, bi) -> let cp = fd(a, bp) in DMF(cp, df_db(cp, bp, bt), bi) - | DMR(bp, _, bi, _) -> let cp = fd(a, bp) in DM.R(cp, r_c_d(a, b), bi) + | DMR(bp, _, _, _, bi) -> let cp = fd(a, bp) in DM.R(cp, r_c_d(a, b), bi) | DMF(ap, at, ai) -> match b with | DM(_) -> let cp = fd(ap, b) in DMF(cp, df_da(cp, ap, at), ai) @@ -1628,12 +1709,12 @@ and DM = | 0 -> let cp = fd(ap, bp) in DMF(cp, df_dab(cp, ap, at, bp, bt), ai) // ai = bi | -1 -> let cp = fd(a, bp) in DMF(cp, df_db(cp, bp, bt), bi) // ai < bi | _ -> let cp = fd(ap, b) in DMF(cp, df_da(cp, ap, at), ai) // ai > bi - | DMR(bp, _, bi, _) -> + | DMR(bp, _, _, _, bi) -> match compare ai bi with | -1 -> let cp = fd(a, bp) in DM.R(cp, r_c_d(a, b), bi) // ai < bi | 1 -> let cp = fd(ap, b) in DMF(cp, df_da(cp, ap, at), ai) // ai > bi | _ -> failwith "Forward and reverse AD cannot run on the same level." 
- | DMR(ap, _, ai, _) -> + | DMR(ap, _, _, _, ai) -> match b with | DM(_) -> let cp = fd(ap, b) in DM.R(cp, r_d_c(a, b), ai) | DMF(bp, bt, bi) -> @@ -1641,7 +1722,7 @@ and DM = | -1 -> let cp = fd(a, bp) in DMF(cp, df_db(cp, bp, bt), bi) // ai < bi | 1 -> let cp = fd(ap, b) in DM.R(cp, r_d_c(a, b), ai) // ai > bi | _ -> failwith "Forward and reverse AD cannot run on the same level." - | DMR(bp, _, bi, _) -> + | DMR(bp, _, _, _, bi) -> match compare ai bi with | 0 -> let cp = fd(ap, bp) in DM.R(cp, r_d_d(a, b), ai) // ai = bi | -1 -> let cp = fd(a, bp) in DM.R(cp, r_c_d(a, b), bi) // ai < bi @@ -1653,7 +1734,7 @@ and DM = match b with | D(bp) -> DM(ff(ap, bp)) | DF(bp, bt, bi) -> let cp = fd(a, bp) in DMF(cp, df_db(cp, bp, bt), bi) - | DR(bp, _, bi, _) -> let cp = fd(a, bp) in DM.R(cp, r_c_d(a, b), bi) + | DR(bp, _, _, _, bi) -> let cp = fd(a, bp) in DM.R(cp, r_c_d(a, b), bi) | DMF(ap, at, ai) -> match b with | D _ -> let cp = fd(ap, b) in DMF(cp, df_da(cp, ap, at), ai) @@ -1662,12 +1743,12 @@ and DM = | 0 -> let cp = fd(ap, bp) in DMF(cp, df_dab(cp, ap, at, bp, bt), ai) // ai = bi | -1 -> let cp = fd(a, bp) in DMF(cp, df_db(cp, bp, bt), bi) // ai < bi | _ -> let cp = fd(ap, b) in DMF(cp, df_da(cp, ap, at), ai) // ai > bi - | DR(bp, _, bi, _) -> + | DR(bp, _, _, _, bi) -> match compare ai bi with | -1 -> let cp = fd(a, bp) in DM.R(cp, r_c_d(a, b), bi) // ai < bi | 1 -> let cp = fd(ap, b) in DMF(cp, df_da(cp, ap, at), ai) // ai > bi | _ -> failwith "Forward and reverse AD cannot run on the same level." - | DMR(ap, _, ai, _) -> + | DMR(ap, _, _, _, ai) -> match b with | D _ -> let cp = fd(ap, b) in DM.R(cp, r_d_c(a, b), ai) | DF(bp, bt, bi) -> @@ -1675,7 +1756,7 @@ and DM = | -1 -> let cp = fd(a, bp) in DMF(cp, df_db(cp, bp, bt), bi) // ai < bi | 1 -> let cp = fd(ap, b) in DM.R(cp, r_d_c(a, b), ai) // ai > bi | _ -> failwith "Forward and reverse AD cannot run on the same level." 
- | DR(bp, _, bi, _) -> + | DR(bp, _, _, _, bi) -> match compare ai bi with | 0 -> let cp = fd(ap, bp) in DM.R(cp, r_d_d(a, b), ai) // ai = bi | -1 -> let cp = fd(a, bp) in DM.R(cp, r_c_d(a, b), bi) // ai < bi @@ -1687,7 +1768,7 @@ and DM = match b with | DM(bp) -> DM(ff(ap, bp)) | DMF(bp, bt, bi) -> let cp = fd(a, bp) in DMF(cp, df_db(cp, bp, bt), bi) - | DMR(bp, _, bi, _) -> let cp = fd(a, bp) in DM.R(cp, r_c_d(a, b), bi) + | DMR(bp, _, _, _, bi) -> let cp = fd(a, bp) in DM.R(cp, r_c_d(a, b), bi) | DF(ap, at, ai) -> match b with | DM(_) -> let cp = fd(ap, b) in DMF(cp, df_da(cp, ap, at), ai) @@ -1696,12 +1777,12 @@ and DM = | 0 -> let cp = fd(ap, bp) in DMF(cp, df_dab(cp, ap, at, bp, bt), ai) // ai = bi | -1 -> let cp = fd(a, bp) in DMF(cp, df_db(cp, bp, bt), bi) // ai < bi | _ -> let cp = fd(ap, b) in DMF(cp, df_da(cp, ap, at), ai) // ai > bi - | DMR(bp, _, bi, _) -> + | DMR(bp, _, _, _, bi) -> match compare ai bi with | -1 -> let cp = fd(a, bp) in DM.R(cp, r_c_d(a, b), bi) // ai < bi | 1 -> let cp = fd(ap, b) in DMF(cp, df_da(cp, ap, at), ai) // ai > bi | _ -> failwith "Forward and reverse AD cannot run on the same level." - | DR(ap, _, ai, _) -> + | DR(ap, _, _, _, ai) -> match b with | DM(_) -> let cp = fd(ap, b) in DM.R(cp, r_d_c(a, b), ai) | DMF(bp, bt, bi) -> @@ -1709,7 +1790,7 @@ and DM = | -1 -> let cp = fd(a, bp) in DMF(cp, df_db(cp, bp, bt), bi) // ai < bi | 1 -> let cp = fd(ap, b) in DM.R(cp, r_d_c(a, b), ai) // ai > bi | _ -> failwith "Forward and reverse AD cannot run on the same level." 
- | DMR(bp, _, bi, _) -> + | DMR(bp, _, _, _, bi) -> match compare ai bi with | 0 -> let cp = fd(ap, bp) in DM.R(cp, r_d_d(a, b), ai) // ai = bi | -1 -> let cp = fd(a, bp) in DM.R(cp, r_c_d(a, b), bi) // ai < bi @@ -1721,7 +1802,7 @@ and DM = match b with | DV(bp) -> DV(ff(ap, bp)) | DVF(bp, bt, bi) -> let cp = fd(a, bp) in DVF(cp, df_db(cp, bp, bt), bi) - | DVR(bp, _, bi, _) -> let cp = fd(a, bp) in DV.R(cp, r_c_d(a, b), bi) + | DVR(bp, _, _, _, bi) -> let cp = fd(a, bp) in DV.R(cp, r_c_d(a, b), bi) | DMF(ap, at, ai) -> match b with | DV _ -> let cp = fd(ap, b) in DVF(cp, df_da(cp, ap, at), ai) @@ -1730,12 +1811,12 @@ and DM = | 0 -> let cp = fd(ap, bp) in DVF(cp, df_dab(cp, ap, at, bp, bt), ai) // ai = bi | -1 -> let cp = fd(a, bp) in DVF(cp, df_db(cp, bp, bt), bi) // ai < bi | _ -> let cp = fd(ap, b) in DVF(cp, df_da(cp, ap, at), ai) // ai > bi - | DVR(bp, _, bi, _) -> + | DVR(bp, _, _, _, bi) -> match compare ai bi with | -1 -> let cp = fd(a, bp) in DV.R(cp, r_c_d(a, b), bi) // ai < bi | 1 -> let cp = fd(ap, b) in DVF(cp, df_da(cp, ap, at), ai) // ai > bi | _ -> failwith "Forward and reverse AD cannot run on the same level." - | DMR(ap, _, ai, _) -> + | DMR(ap, _, _, _, ai) -> match b with | DV _ -> let cp = fd(ap, b) in DV.R(cp, r_d_c(a, b), ai) | DVF(bp, bt, bi) -> @@ -1743,7 +1824,7 @@ and DM = | -1 -> let cp = fd(a, bp) in DVF(cp, df_db(cp, bp, bt), bi) // ai < bi | 1 -> let cp = fd(ap, b) in DV.R(cp, r_d_c(a, b), ai) // ai > bi | _ -> failwith "Forward and reverse AD cannot run on the same level." 
- | DVR(bp, _, bi, _) -> + | DVR(bp, _, _, _, bi) -> match compare ai bi with | 0 -> let cp = fd(ap, bp) in DV.R(cp, r_d_d(a, b), ai) // ai = bi | -1 -> let cp = fd(a, bp) in DV.R(cp, r_c_d(a, b), bi) // ai < bi @@ -1755,7 +1836,7 @@ and DM = match b with | DM(bp) -> DV(ff(ap, bp)) | DMF(bp, bt, bi) -> let cp = fd(a, bp) in DVF(cp, df_db(cp, bp, bt), bi) - | DMR(bp, _, bi, _) -> let cp = fd(a, bp) in DV.R(cp, r_c_d(a, b), bi) + | DMR(bp, _, _, _, bi) -> let cp = fd(a, bp) in DV.R(cp, r_c_d(a, b), bi) | DVF(ap, at, ai) -> match b with | DM(_) -> let cp = fd(ap, b) in DVF(cp, df_da(cp, ap, at), ai) @@ -1764,12 +1845,12 @@ and DM = | 0 -> let cp = fd(ap, bp) in DVF(cp, df_dab(cp, ap, at, bp, bt), ai) // ai = bi | -1 -> let cp = fd(a, bp) in DVF(cp, df_db(cp, bp, bt), bi) // ai < bi | _ -> let cp = fd(ap, b) in DVF(cp, df_da(cp, ap, at), ai) // ai > bi - | DMR(bp, _, bi, _) -> + | DMR(bp, _, _, _, bi) -> match compare ai bi with | -1 -> let cp = fd(a, bp) in DV.R(cp, r_c_d(a, b), bi) // ai < bi | 1 -> let cp = fd(ap, b) in DVF(cp, df_da(cp, ap, at), ai) // ai > bi | _ -> failwith "Forward and reverse AD cannot run on the same level." - | DVR(ap, _, ai, _) -> + | DVR(ap, _, _, _, ai) -> match b with | DM(_) -> let cp = fd(ap, b) in DV.R(cp, r_d_c(a, b), ai) | DMF(bp, bt, bi) -> @@ -1777,7 +1858,7 @@ and DM = | -1 -> let cp = fd(a, bp) in DVF(cp, df_db(cp, bp, bt), bi) // ai < bi | 1 -> let cp = fd(ap, b) in DV.R(cp, r_d_c(a, b), ai) // ai > bi | _ -> failwith "Forward and reverse AD cannot run on the same level." 
- | DMR(bp, _, bi, _) -> + | DMR(bp, _, _, _, bi) -> match compare ai bi with | 0 -> let cp = fd(ap, bp) in DV.R(cp, r_d_d(a, b), ai) // ai = bi | -1 -> let cp = fd(a, bp) in DV.R(cp, r_c_d(a, b), bi) // ai < bi @@ -1789,7 +1870,7 @@ and DM = match b with | DV(bp) -> DM(ff(ap, bp)) | DVF(bp, bt, bi) -> let cp = fd(a, bp) in DMF(cp, df_db(cp, bp, bt), bi) - | DVR(bp, _, bi, _) -> let cp = fd(a, bp) in DM.R(cp, r_c_d(a, b), bi) + | DVR(bp, _, _, _, bi) -> let cp = fd(a, bp) in DM.R(cp, r_c_d(a, b), bi) | DMF(ap, at, ai) -> match b with | DV _ -> let cp = fd(ap, b) in DMF(cp, df_da(cp, ap, at), ai) @@ -1798,12 +1879,12 @@ and DM = | 0 -> let cp = fd(ap, bp) in DMF(cp, df_dab(cp, ap, at, bp, bt), ai) // ai = bi | -1 -> let cp = fd(a, bp) in DMF(cp, df_db(cp, bp, bt), bi) // ai < bi | _ -> let cp = fd(ap, b) in DMF(cp, df_da(cp, ap, at), ai) // ai > bi - | DVR(bp, _, bi, _) -> + | DVR(bp, _, _, _, bi) -> match compare ai bi with | -1 -> let cp = fd(a, bp) in DM.R(cp, r_c_d(a, b), bi) // ai < bi | 1 -> let cp = fd(ap, b) in DMF(cp, df_da(cp, ap, at), ai) // ai > bi | _ -> failwith "Forward and reverse AD cannot run on the same level." - | DMR(ap, _, ai, _) -> + | DMR(ap, _, _, _, ai) -> match b with | DV _ -> let cp = fd(ap, b) in DM.R(cp, r_d_c(a, b), ai) | DVF(bp, bt, bi) -> @@ -1811,7 +1892,7 @@ and DM = | -1 -> let cp = fd(a, bp) in DMF(cp, df_db(cp, bp, bt), bi) // ai < bi | 1 -> let cp = fd(ap, b) in DM.R(cp, r_d_c(a, b), ai) // ai > bi | _ -> failwith "Forward and reverse AD cannot run on the same level." 
- | DVR(bp, _, bi, _) -> + | DVR(bp, _, _, _, bi) -> match compare ai bi with | 0 -> let cp = fd(ap, bp) in DM.R(cp, r_d_d(a, b), ai) // ai = bi | -1 -> let cp = fd(a, bp) in DM.R(cp, r_c_d(a, b), bi) // ai < bi @@ -1823,7 +1904,7 @@ and DM = match b with | DM(bp) -> DM(ff(ap, bp)) | DMF(bp, bt, bi) -> let cp = fd(a, bp) in DMF(cp, df_db(cp, bp, bt), bi) - | DMR(bp, _, bi, _) -> let cp = fd(a, bp) in DM.R(cp, r_c_d(a, b), bi) + | DMR(bp, _, _, _, bi) -> let cp = fd(a, bp) in DM.R(cp, r_c_d(a, b), bi) | DVF(ap, at, ai) -> match b with | DM(_) -> let cp = fd(ap, b) in DMF(cp, df_da(cp, ap, at), ai) @@ -1832,12 +1913,12 @@ and DM = | 0 -> let cp = fd(ap, bp) in DMF(cp, df_dab(cp, ap, at, bp, bt), ai) // ai = bi | -1 -> let cp = fd(a, bp) in DMF(cp, df_db(cp, bp, bt), bi) // ai < bi | _ -> let cp = fd(ap, b) in DMF(cp, df_da(cp, ap, at), ai) // ai > bi - | DMR(bp, _, bi, _) -> + | DMR(bp, _, _, _, bi) -> match compare ai bi with | -1 -> let cp = fd(a, bp) in DM.R(cp, r_c_d(a, b), bi) // ai < bi | 1 -> let cp = fd(ap, b) in DMF(cp, df_da(cp, ap, at), ai) // ai > bi | _ -> failwith "Forward and reverse AD cannot run on the same level." - | DVR(ap, _, ai, _) -> + | DVR(ap, _, _, _, ai) -> match b with | DM(_) -> let cp = fd(ap, b) in DM.R(cp, r_d_c(a, b), ai) | DMF(bp, bt, bi) -> @@ -1845,7 +1926,7 @@ and DM = | -1 -> let cp = fd(a, bp) in DMF(cp, df_db(cp, bp, bt), bi) // ai < bi | 1 -> let cp = fd(ap, b) in DM.R(cp, r_d_c(a, b), ai) // ai > bi | _ -> failwith "Forward and reverse AD cannot run on the same level." - | DMR(bp, _, bi, _) -> + | DMR(bp, _, _, _, bi) -> match compare ai bi with | 0 -> let cp = fd(ap, bp) in DM.R(cp, r_d_d(a, b), ai) // ai = bi | -1 -> let cp = fd(a, bp) in DM.R(cp, r_c_d(a, b), bi) // ai < bi @@ -2751,6 +2832,7 @@ and TraceOp = /// A constraint used to ensure the evaluation stack is only over D, DV or DM and dobj = interface end +let bxd (x : dobj) = x /// Functional-oriented operations on vectors. 
Implementing functionality similar to FSharp.Collections.Array. [] @@ -3116,123 +3198,6 @@ module DM = let inline visualize (m:DM) = m.Visualize() let inline visualizeAsDV (m:DM) = DM.ReshapeToDV(m).Visualize() - -// Scripts for adjusting the adjoint -type Delta = - | X of D - //| XNeg of DeltaV - interface delta -and DeltaV = - | XV of DV - //| XNegV of DeltaV - interface delta -and DeltaM = - | XM of DM - | XNegM of DeltaM - interface delta -and delta = interface end - -/// Represents the computed adjoints for reverse AD. This is a table indexed by node ID. -/// The table is destructively updated as the adjoints are accumulated. -type Adjoints() = - let dict = Dictionary() - - let rec eval d = - match d with - | X d -> d - and evalV d = - match d with - | XV d -> d - and evalM d = - match d with - | XM v -> v - | XNegM v -> -(evalM v) - - member internal __.GetD(uniq: int) = dict.[uniq] :?> D - member internal __.SetD(uniq:int, v:D) = dict.[uniq] <- v - member internal __.GetDV(uniq: int) = dict.[uniq] :?> DV - member internal __.SetDV(uniq:int, v:DV) = dict.[uniq] <- v - member internal __.GetDM(uniq: int) = dict.[uniq] :?> DM - member internal __.SetDM(uniq:int, v:DM) = dict.[uniq] <- v - - // adj <- adj + interp(delta) - member internal __.ApplyDelta(uniq: int, x:Delta) = - let adj = dict.[uniq] :?> D - let res = eval x + adj - dict.[uniq] <- res - res - - // adj <- adj + interp(delta) - member internal __.ApplyDeltaV(uniq: int, x:DeltaV) = - let adj = dict.[uniq] :?> DV - match adj,x with - | DV adjv, XV (DV xv) -> - Backend.Add_V_V_Inplace(xv, adjv) - adj - | _ -> - let res = DV.Add_V_V_Inplace(evalV x,adj) - dict.[uniq] <- res - res - - // adj <- adj + interp(delta) - member internal __.ApplyDeltaM(uniq: int, x:DeltaM) = - let adj = dict.[uniq] :?> DM - match adj,x with - | DM adjm, XM (DM xm) -> - Backend.AlphaAdd_M_M_Inplace(N.one, xm, adjm) - adj - | DM adjm, XNegM (XM (DM xm)) -> - // TODO: also perform the inplace update in the case where adj is not 
"DM adj" - // However this needs care. - Backend.AlphaAdd_M_M_Inplace(N.minus1, xm, adjm) - adj - | _ -> - let adj = DM.Add_M_M_Inplace(evalM x,adj) - dict.[uniq] <- adj - adj - - /// Lookup the adjoint for a value - member this.Item - with get (d:D) : D = - match d with - | D _ -> D.Zero - | DF _ -> failwith "Cannot get adjoint value of DF. Use makeReverse on this node when composing the computation." - | DR (_, _, _, uniq) -> this.GetD(uniq) - and set (d:D) (v : D) = - match d with - | D _ -> () - | DF _ -> failwith "Cannot set adjoint value of DF. Use makeReverse on this node when composing the computation." - | DR (_, _, _, uniq) -> this.SetD(uniq, v) - - /// Lookup the adjoint for a vector - member this.Item - with get (d:DV) : DV = - match d with - | DV _ -> DV.ZeroN d.Length - | DVF _ -> failwith "Cannot get adjoint value of DVF. Use makeReverse on this node when composing the computation." - | DVR (_, _, _, uniq) -> this.GetDV(uniq) - and set (d:DV) (v : DV) = - match d with - | DV _ -> () - | DVF _ -> failwith "Cannot set adjoint value of DVF. Use makeReverse on this node when composing the computation." - | DVR (_, _, _, uniq) -> this.SetDV(uniq, v) - - /// Lookup the adjoint for a matrix - member this.Item - with get (d:DM) : DM = - match d with - | DM(_) -> DM.ZeroMN d.Rows d.Cols - | DMF _ -> failwith "Cannot get adjoint value of DMF. Use makeReverse on this node when composing the computation." - | DMR (_, _, _, uniq) -> this.GetDM(uniq) - and set (d:DM) (v : DM) = - match d with - | DM _ -> () - | DMF _ -> failwith "Cannot set adjoint value of DMF. Use makeReverse on this node when composing the computation." 
- | DMR (_, _, _, uniq) -> this.SetDM(uniq, v) - - override __.ToString() = sprintf "(%d computed adjoints)" dict.Count - - /// D, DV, DM operations (automatically opened) [] module DOps = @@ -3274,31 +3239,18 @@ module DOps = let inline tangent (d:^a when ^a : (member T : ^a)) = (^a : (member T : ^a) d) /// Get the adjoint value of `d` - let adjoint (adjoints: Adjoints) (d : 'T :> dobj) : 'T = + let adjoint (d : 'T :> dobj) : 'T = match box d with - | :? D as d -> adjoints.[d] |> box :?> 'T - | :? DV as d -> adjoints.[d] |> box :?> 'T - | :? DM as d -> adjoints.[d] |> box :?> 'T + | :? D as d -> box d.A :?> 'T + | :? DV as d -> box d.A :?> 'T + | :? DM as d -> box d.A :?> 'T | _ -> failwith "invalid dobj type" /// Get the primal and tangent values of `d`, as a tuple let inline primalTangent d = d |> primal, d |> tangent - - type Fanouts = Dictionary - let incrementFanout (fanouts: Fanouts) d = - match fanouts.TryGetValue(d) with - | true,fanout -> - fanouts.[d] <- fanout + 1u - fanout + 1u - | _ -> - fanouts.[d] <- 1u - 1u - /// Resets the adjoints of all the values in the evaluation trace of `d`, preparing for a new reverse propagation - let reverseReset (adjoints: Adjoints) (d:dobj) = - let bxd (x : dobj) = x - let fanouts = Fanouts() + let reverseReset (d:dobj) = // Note, this uses an explicit worklist over (D|DV|DM) to make it tail-recursive let rec resetRec (ds:dobj list) = match ds with @@ -3307,10 +3259,10 @@ module DOps = match d with | :? D as d -> match d with - | DR(_, o, _, uniq) -> - let fanout = incrementFanout fanouts uniq - if fanout = 1u then - adjoints.SetD(uniq,D.Zero) + | DR(_,_,o,_,_) -> + d.A <- D.Zero + d.F <- d.F + 1u + if d.F = 1u then match o with | Add_D_D(a, b) -> resetRec (bxd a :: bxd b :: t) | Add_D_DCons(a) -> resetRec (bxd a :: t) @@ -3366,10 +3318,10 @@ module DOps = | _ -> resetRec t | :? 
DV as d -> match d with - | DVR(_, o, _, uniq) -> - let fanout = incrementFanout fanouts uniq - if fanout = 1u then - adjoints.SetDV(uniq,DV.ZeroN d.Length) + | DVR(_,_,o,_,_) -> + d.A <- DV.ZeroN d.Length + d.F <- d.F + 1u + if d.F = 1u then match o with | Add_DV_DV(a, b) -> resetRec (bxd a :: bxd b :: t) | Add_DV_DVCons(a) -> resetRec (bxd a :: t) @@ -3468,10 +3420,10 @@ module DOps = | _ -> resetRec t | :? DM as d -> match d with - | DMR(_, o, _, uniq) -> - let fanout = incrementFanout fanouts uniq - if fanout = 1u then - adjoints.SetDM(uniq,DM.ZeroMN d.Rows d.Cols) + | DMR(_,_,o,_,_) -> + d.A <- DM.ZeroMN d.Rows d.Cols + d.F <- d.F + 1u + if d.F = 1u then match o with | Add_DM_DM(a, b) -> resetRec (bxd a :: bxd b :: t) | Add_DM_DMCons(a) -> resetRec (bxd a :: t) @@ -3480,6 +3432,7 @@ module DOps = | Sub_DMCons_DM(a) -> resetRec (bxd a :: t) | Mul_DM_DM(a, b) -> resetRec (bxd a :: bxd b :: t) | Mul_DM_DMCons(a, _) -> resetRec (bxd a :: t) + | Mul_DMCons_DM(_, b) -> resetRec (bxd b :: t) | Mul_Had_DM_DM(a, b) -> resetRec (bxd a :: bxd b :: t) | Mul_Had_DM_DMCons(a, _) -> resetRec (bxd a :: t) | Mul_DM_D(a, b) -> resetRec (bxd a :: bxd b :: t) @@ -3571,89 +3524,79 @@ module DOps = | _ -> resetRec t | _ -> resetRec t resetRec [d] - fanouts /// Propagates the adjoint `v` backwards through the evaluation trace of `d`. The adjoints in the trace are reset before the push. 
- let rec reverseProp (adjoints: Adjoints) (v:dobj) (d:dobj) = - let fanouts = reverseReset adjoints d - let inline bxd (x : dobj) = x - let inline bxdelta (x : delta) = x - - let inline bd (v: Delta) (d:D) = bxdelta v, bxd d - let inline bdv (v: DeltaV) (d:DV) = bxdelta v, bxd d - let inline bdm (v: DeltaM) (d:DM) = bxdelta v, bxd d - - let inline bx (v: D) (d:D) = bd (X v) d - let inline bxv (v: DV) (d:DV) = bdv (XV v) d - let inline bxm (v: DM) (d:DM) = bdm (XM v) d + let rec reverseProp (v:dobj) (d:dobj) = + let inline bx (v: D) d = (v :> dobj), bxd d + let inline bxv (v: DV) d = (v :> dobj), bxd d + let inline bxm (v: DM) d = (v :> dobj), bxd d // Note, this uses an explicit worklist over (D*D|DV*DV|DM*DM) to make it tail-recursive - let rec pushRec (ds:(delta*dobj) list) = + let rec pushRec (ds:(dobj*dobj) list) = match ds with | [] -> () | (v, d) :: t -> - match v, d with - | (:? Delta as delta), (:? D as d) -> + match d, v with + | (:? D as d), (:? D as v) -> match d with - | DR(_, o, _, uniq) -> - let dA = adjoints.ApplyDelta(uniq, delta) - let fanout = fanouts.[uniq] - 1u - fanouts.[uniq] <- fanout - // If all incoming parts of the adjoint have been received, then proceed to the parent - if fanout = 0u then + | DR(_,_,o,_,_) -> + d.A <- d.A + v + d.F <- d.F - 1u + // If all incoming parts of the adjoint have been received, then proceed to the children + if d.F = 0u then match o with - | Add_D_D(a, b) -> pushRec ((bx dA a) :: (bx dA b) :: t) - | Add_D_DCons(a) -> pushRec ((bx dA a) :: t) - | Sub_D_D(a, b) -> pushRec ((bx dA a) :: (bx -dA b) :: t) - | Sub_D_DCons(a) -> pushRec ((bx dA a) :: t) - | Sub_DCons_D(b) -> pushRec ((bx -dA b) :: t) - | Mul_D_D(a, b) -> pushRec ((bx (dA * b.P) a) :: (bx (dA * a.P) b) :: t) - | Mul_D_DCons(a, cons) -> pushRec ((bx (dA * cons) a) :: t) - | Div_D_D(a, b) -> pushRec ((bx (dA / b.P) a) :: (bx (dA * (-a.P / (b.P * b.P))) b) :: t) - | Div_D_DCons(a, cons) -> pushRec ((bx (dA / cons) a) :: t) - | Div_DCons_D(cons, b) -> 
pushRec ((bx (dA * (-cons / (b.P * b.P))) b) :: t) - | Pow_D_D(a, b) -> pushRec ((bx (dA * (a.P ** (b.P - D.One)) * b.P) a) :: (bx (dA * (a.P ** b.P) * log a.P) b) :: t) - | Pow_D_DCons(a, cons) -> pushRec ((bx (dA * (a.P ** (cons - D.One)) * cons) a) :: t) - | Pow_DCons_D(cons, b) -> pushRec ((bx (dA * (cons ** b.P) * log cons) b) :: t) - | Atan2_D_D(a, b) -> let denom = a.P * a.P + b.P * b.P in pushRec ((bx (dA * b.P / denom) a) :: (bx (dA * (-a.P) / denom) b) :: t) - | Atan2_D_DCons(a, cons) -> pushRec ((bx (dA * cons / (a.P * a.P + cons * cons)) a) :: t) - | Atan2_DCons_D(cons, b) -> pushRec ((bx (dA * (-cons) / (cons * cons + b.P * b.P)) b) :: t) - | Log_D(a) -> pushRec ((bx (dA / a.P) a) :: t) - | Log10_D(a) -> pushRec ((bx (dA / (a.P * N.log10Val)) a) :: t) - | Exp_D(a) -> pushRec ((bx (dA * d.P) a) :: t) // d.P = exp a.P - | Sin_D(a) -> pushRec ((bx (dA * cos a.P) a) :: t) - | Cos_D(a) -> pushRec ((bx (dA * (-sin a.P)) a) :: t) - | Tan_D(a) -> let seca = D.One / cos a.P in pushRec ((bx (dA * seca * seca) a) :: t) - | Neg_D(a) -> pushRec ((bx -dA a) :: t) - | Sqrt_D(a) -> pushRec ((bx (dA / (D N.two * d.P)) a) :: t) // d.P = sqrt a.P - | Sinh_D(a) -> pushRec ((bx (dA * cosh a.P) a) :: t) - | Cosh_D(a) -> pushRec ((bx (dA * sinh a.P) a) :: t) - | Tanh_D(a) -> let secha = D.One / cosh a.P in pushRec ((bx (dA * secha * secha) a) :: t) - | Asin_D(a) -> pushRec ((bx (dA / sqrt (D.One - a.P * a.P)) a) :: t) - | Acos_D(a) -> pushRec ((bx (-dA / sqrt (D.One - a.P * a.P)) a) :: t) - | Atan_D(a) -> pushRec ((bx (dA / (D.One + a.P * a.P)) a) :: t) - | Abs_D(a) -> pushRec ((bx (dA * D.Sign(a.P)) a) :: t) + | Add_D_D(a, b) -> pushRec ((bx d.A a) :: (bx d.A b) :: t) + | Add_D_DCons(a) -> pushRec ((bx d.A a) :: t) + | Sub_D_D(a, b) -> pushRec ((bx d.A a) :: (bx -d.A b) :: t) + | Sub_D_DCons(a) -> pushRec ((bx d.A a) :: t) + | Sub_DCons_D(b) -> pushRec ((bx -d.A b) :: t) + | Mul_D_D(a, b) -> pushRec ((bx (d.A * b.P) a) :: (bx (d.A * a.P) b) :: t) + | Mul_D_DCons(a, cons) -> 
pushRec ((bx (d.A * cons) a) :: t) + | Div_D_D(a, b) -> pushRec ((bx (d.A / b.P) a) :: (bx (d.A * (-a.P / (b.P * b.P))) b) :: t) + | Div_D_DCons(a, cons) -> pushRec ((bx (d.A / cons) a) :: t) + | Div_DCons_D(cons, b) -> pushRec ((bx (d.A * (-cons / (b.P * b.P))) b) :: t) + | Pow_D_D(a, b) -> pushRec ((bx (d.A * (a.P ** (b.P - D.One)) * b.P) a) :: (bx (d.A * (a.P ** b.P) * log a.P) b) :: t) + | Pow_D_DCons(a, cons) -> pushRec ((bx (d.A * (a.P ** (cons - D.One)) * cons) a) :: t) + | Pow_DCons_D(cons, b) -> pushRec ((bx (d.A * (cons ** b.P) * log cons) b) :: t) + | Atan2_D_D(a, b) -> let denom = a.P * a.P + b.P * b.P in pushRec ((bx (d.A * b.P / denom) a) :: (bx (d.A * (-a.P) / denom) b) :: t) + | Atan2_D_DCons(a, cons) -> pushRec ((bx (d.A * cons / (a.P * a.P + cons * cons)) a) :: t) + | Atan2_DCons_D(cons, b) -> pushRec ((bx (d.A * (-cons) / (cons * cons + b.P * b.P)) b) :: t) + | Log_D(a) -> pushRec ((bx (d.A / a.P) a) :: t) + | Log10_D(a) -> pushRec ((bx (d.A / (a.P * N.log10Val)) a) :: t) + | Exp_D(a) -> pushRec ((bx (d.A * d.P) a) :: t) // d.P = exp a.P + | Sin_D(a) -> pushRec ((bx (d.A * cos a.P) a) :: t) + | Cos_D(a) -> pushRec ((bx (d.A * (-sin a.P)) a) :: t) + | Tan_D(a) -> let seca = D.One / cos a.P in pushRec ((bx (d.A * seca * seca) a) :: t) + | Neg_D(a) -> pushRec ((bx -d.A a) :: t) + | Sqrt_D(a) -> pushRec ((bx (d.A / (D N.two * d.P)) a) :: t) // d.P = sqrt a.P + | Sinh_D(a) -> pushRec ((bx (d.A * cosh a.P) a) :: t) + | Cosh_D(a) -> pushRec ((bx (d.A * sinh a.P) a) :: t) + | Tanh_D(a) -> let secha = D.One / cosh a.P in pushRec ((bx (d.A * secha * secha) a) :: t) + | Asin_D(a) -> pushRec ((bx (d.A / sqrt (D.One - a.P * a.P)) a) :: t) + | Acos_D(a) -> pushRec ((bx (-d.A / sqrt (D.One - a.P * a.P)) a) :: t) + | Atan_D(a) -> pushRec ((bx (d.A / (D.One + a.P * a.P)) a) :: t) + | Abs_D(a) -> pushRec ((bx (d.A * D.Sign(a.P)) a) :: t) | Sign_D(a) -> pushRec ((bx D.Zero a) :: t) | Floor_D(a) -> pushRec ((bx D.Zero a) :: t) | Ceil_D(a) -> pushRec ((bx D.Zero a) 
:: t) | Round_D(a) -> pushRec ((bx D.Zero a) :: t) - | Mul_Dot_DV_DV(a, b) -> pushRec ((bxv (dA * b.P) a) :: (bxv (dA * a.P) b) :: t) - | Mul_Dot_DV_DVCons(a, cons) -> pushRec ((bxv (dA * cons) a) :: t) - | Sum_DV(a) -> pushRec ((bxv (DV.create a.Length dA) a) :: t) - | L1Norm_DV(a) -> pushRec ((bxv (dA * DV.Sign a.P) a) :: t) - | L2NormSq_DV(a) -> pushRec ((bxv (dA * (D N.two) * a.P) a) :: t) - | L2Norm_DV(a) -> pushRec ((bxv ((dA / d.P) * a.P) a) :: t) + | Mul_Dot_DV_DV(a, b) -> pushRec ((bxv (d.A * b.P) a) :: (bxv (d.A * a.P) b) :: t) + | Mul_Dot_DV_DVCons(a, cons) -> pushRec ((bxv (d.A * cons) a) :: t) + | Sum_DV(a) -> pushRec ((bxv (DV.create a.Length d.A) a) :: t) + | L1Norm_DV(a) -> pushRec ((bxv (d.A * DV.Sign a.P) a) :: t) + | L2NormSq_DV(a) -> pushRec ((bxv (d.A * (D N.two) * a.P) a) :: t) + | L2Norm_DV(a) -> pushRec ((bxv ((d.A / d.P) * a.P) a) :: t) | Item_DV(a, i) -> - adjoints.[a] <- DV.AddItem(adjoints.[a], i, dA); + a.A <- DV.AddItem(a.A, i, d.A) pushRec ((bxv DV.Zero a) :: t) - | Sum_DM(a) -> pushRec ((bxm (DM.create a.Rows a.Cols dA) a) :: t) + | Sum_DM(a) -> pushRec ((bxm (DM.create a.Rows a.Cols d.A) a) :: t) | Item_DM(a, i, j) -> - adjoints.[a] <- DM.AddItem(adjoints.[a], i, j, dA); + a.A <- DM.AddItem(a.A, i, j, d.A) pushRec ((bxm DM.Zero a) :: t) | Det_DM(a) -> pushRec ((bxm (d.T * d.P * DM.Transpose(DM.Inverse(a))) a) :: t) // Check this - | ReLU_D(a) -> pushRec ((bx (dA * ((D.Sign(a.P) + N.one) / N.two)) a) :: t) - | Sigmoid_D(a) -> pushRec ((bx (dA * d.P * (N.one - d.P)) a) :: t) // d.P = D.Sigmoid(a.P) - | LogSumExp_DV(a) -> pushRec ((bxv ((dA / exp d.P) * exp a.P) a) :: t) // d.P = DV.LogSumExp(a.P) + | ReLU_D(a) -> pushRec ((bx (d.A * ((D.Sign(a.P) + N.one) / N.two)) a) :: t) + | Sigmoid_D(a) -> pushRec ((bx (d.A * d.P * (N.one - d.P)) a) :: t) // d.P = D.Sigmoid(a.P) + | LogSumExp_DV(a) -> pushRec ((bxv ((d.A / exp d.P) * exp a.P) a) :: t) // d.P = DV.LogSumExp(a.P) | FixedPoint_D(b, bfirst, aprev, alast) -> // Christianson (1994) let 
imax = DiffSharp.Config.GlobalConfig.FixedPointMaxIterations @@ -3661,8 +3604,8 @@ module DOps = let mutable i = 0 - let r = dA - reverseProp adjoints r alast + let r = d.A + reverseProp r alast while i < imax do i <- i + 1 @@ -3670,306 +3613,294 @@ module DOps = //printfn "Fixed point reverse iteration timeout, i = %i" i ignore() else - if abs (adjoints.[aprev] + r - adjoints.[alast]) <= eps then + if abs (aprev.A + r - alast.A) <= eps then //printfn "Fixed point reverse iteration converged, i = %i" i i <- imax else - reverseProp adjoints (r + adjoints.[aprev]) alast + reverseProp (r + aprev.A) alast - pushRec ((bx (adjoints.[bfirst]) b) :: t) // Propogate converged adjoint back towards the original b at the beginning of the fixed point iteration + pushRec ((bx bfirst.A b) :: t) // Propogate converged adjoint back towards the original b at the beginning of the fixed point iteration | _ -> pushRec t else pushRec t | _ -> pushRec t - | (:? DeltaV as delta), (:? DV as d) -> + | (:? DV as d), (:? 
DV as v) -> match d with - | DVR(_, o, _, uniq) -> - let dA = adjoints.ApplyDeltaV(uniq,delta) - let fanout = fanouts.[uniq] - 1u - fanouts.[uniq] <- fanout - // If all incoming parts of the adjoint have been received, then proceed to the parent - if fanout = 0u then + | DVR(_,_,o,_,_) -> + d.A <- d.A + v + d.F <- d.F - 1u + // If all incoming parts of the adjoint have been received, then proceed to the children + if d.F = 0u then match o with - | Add_DV_DV(a, b) -> pushRec ((bxv dA a) :: (bxv dA b) :: t) - | Add_DV_DVCons(a) -> pushRec ((bxv dA a) :: t) - | Add_DV_D(a, b) -> pushRec ((bxv dA a) :: (bx (DV.Sum(dA)) b) :: t) - | Add_DV_DCons(a) -> pushRec ((bxv dA a) :: t) - | Add_DVCons_D(b) -> pushRec ((bx (DV.Sum(dA)) b) :: t) - | Sub_DV_DV(a, b) -> pushRec ((bxv dA a) :: (bxv -dA b) :: t) - | Sub_DV_DVCons(a) -> pushRec ((bxv dA a) :: t) - | Sub_DVCons_DV(a) -> pushRec ((bxv -dA a) :: t) - | Sub_DV_D(a, b) -> pushRec ((bxv dA a) :: (bx -(DV.Sum(dA)) b) :: t) - | Sub_DV_DCons(a) -> pushRec ((bxv dA a) :: t) - | Sub_DVCons_D(b) -> pushRec ((bx -(DV.Sum(dA)) b) :: t) - | Sub_D_DV(a, b) -> pushRec ((bx (DV.Sum(dA)) a) :: (bxv -dA b) :: t) - | Sub_D_DVCons(a) -> pushRec ((bx (DV.Sum(dA)) a) :: t) - | Sub_DCons_DV(b) -> pushRec ((bxv -dA b) :: t) - | Mul_Had_DV_DV(a, b) -> pushRec ((bxv (dA .* b.P) a) :: (bxv (dA .* a.P) b) :: t) - | Mul_Had_DV_DVCons(a, cons) -> pushRec ((bxv (dA .* cons) a) :: t) - | Mul_DV_D(a, b) -> pushRec ((bxv (dA * b.P) a) :: (bx (dA * a.P) b) :: t) - | Mul_DV_DCons(a, cons) -> pushRec ((bxv (dA * cons) a) :: t) - | Mul_DVCons_D(cons, b) -> pushRec ((bx (dA * cons) b) :: t) - | Mul_DM_DV(a, b) -> pushRec ((bxm (dA &* b.P) a) :: (bxv (DM.Transpose(a.P) * dA) b) :: t) - | Mul_DM_DVCons(a, cons) -> pushRec ((bxm (dA &* cons) a) :: t) - | Mul_DMCons_DV(cons, b) -> pushRec ((bxv (DM.Transpose(cons) * dA) b) :: t) - | Mul_DV_DM(a, b) -> pushRec ((bxv (dA * DM.Transpose(b.P)) a) :: (bxm (a.P &* dA) b) :: t) - | Mul_DV_DMCons(a, cons) -> pushRec ((bxv 
(dA * DM.Transpose(cons)) a) :: t) - | Mul_DVCons_DM(cons, b) -> pushRec ((bxm (cons &* dA) b) :: t) - | Div_Had_DV_DV(a, b) -> pushRec ((bxv (dA ./ b.P) a) :: (bxv (dA .* (-a.P ./ (b.P .* b.P))) b) :: t) - | Div_Had_DV_DVCons(a, cons) -> pushRec ((bxv (dA ./ cons) a) :: t) - | Div_Had_DVCons_DV(cons, b) -> pushRec ((bxv (dA .* (-cons ./ (b.P .* b.P))) b) :: t) - | Div_DV_D(a, b) -> pushRec ((bxv (dA / b.P) a) :: (bx (dA * (-a.P / (b.P * b.P))) b) :: t) - | Div_DV_DCons(a, cons) -> pushRec ((bxv (dA / cons) a) :: t) - | Div_DVCons_D(cons, b) -> pushRec ((bx (dA * (-cons / (b.P * b.P))) b) :: t) - | Div_D_DV(a, b) -> pushRec ((bx (DV.Sum(dA ./ b.P)) a) :: (bxv (dA .* (-a.P / (b.P .* b.P))) b) :: t) - | Div_D_DVCons(a, cons) -> pushRec ((bx (DV.Sum(dA ./ cons)) a) :: t) - | Div_DCons_DV(cons, b) -> pushRec ((bxv (dA .* (-cons / (b.P .* b.P))) b) :: t) - | Pow_DV_DV(a, b) -> pushRec ((bxv (dA .* (a.P ** (b.P - D.One)) .* b.P) a) :: (bxv (dA .* (a.P ** b.P) .* log a.P) b) :: t) - | Pow_DV_DVCons(a, cons) -> pushRec ((bxv (dA .* (a.P ** (cons - D.One)) .* cons) a) :: t) - | Pow_DVCons_DV(cons, b) -> pushRec ((bxv (dA .* (cons ** b.P) .* log cons) b) :: t) - | Atan2_DV_DV(a, b) -> let denom = (a.P .* a.P) + (b.P .* b.P) in pushRec ((bxv (dA .* b.P ./ denom) a) :: (bxv (dA .* (-a.P) ./ denom) b) :: t) - | Atan2_DV_DVCons(a, cons) -> pushRec ((bxv (dA .* cons ./ ((a.P .* a.P) + (cons .* cons))) a) :: t) - | Atan2_DVCons_DV(cons, b) -> pushRec ((bxv (dA .* (-cons) ./ ((cons .* cons) + (b.P .* b.P))) b) :: t) - | Pow_DV_D(a, b) -> pushRec ((bxv (dA .* (a.P ** (b.P - D.One)) * b.P) a) :: (bx (DV.Sum(dA .* (a.P ** b.P) .* log a.P)) b) :: t) - | Pow_DV_DCons(a, cons) -> pushRec ((bxv (dA .* (a.P ** (cons - D.One)) * cons) a) :: t) - | Pow_DVCons_D(cons, b) -> pushRec ((bx (DV.Sum(dA .* (cons ** b.P) .* log cons)) b) :: t) - | Pow_D_DV(a, b) -> pushRec ((bx (DV.Sum(dA .* (DV.Pow(a.P, b.P - D.One)) .* b.P)) a) :: (bxv (dA .* (DV.Pow(a.P, b.P)) * log a.P) b) :: t) - | 
Pow_D_DVCons(a, cons) -> pushRec ((bx (DV.Sum(dA .* (DV.Pow(a.P, cons - D.One)) .* cons)) a) :: t) - | Pow_DCons_DV(cons, b) -> pushRec ((bxv (dA .* (DV.Pow(cons, b.P)) * log cons) b) :: t) - | Atan2_DV_D(a, b) -> let denom = (a.P .* a.P) + (b.P * b.P) in pushRec ((bxv (dA * b.P ./ denom) a) :: (bx (DV.Sum(dA .* (-a.P) ./ denom)) b) :: t) - | Atan2_DV_DCons(a, cons) -> pushRec ((bxv (dA * cons ./ ((a.P .* a.P) + (cons * cons))) a) :: t) - | Atan2_DVCons_D(cons, b) -> pushRec ((bx (DV.Sum(dA .* (-cons) ./ ((cons .* cons) + (b.P * b.P)))) b) :: t) - | Atan2_D_DV(a, b) -> let denom = (a.P * a.P) + (b.P .* b.P) in pushRec ((bx (DV.Sum(dA .* b.P ./ denom)) a) :: (bxv (dA * (-a.P) ./ denom) b) :: t) - | Atan2_D_DVCons(a, cons) -> pushRec ((bx (DV.Sum(dA .* cons ./ ((a.P * a.P) + (cons .* cons)))) a) :: t) - | Atan2_DCons_DV(cons, b) -> pushRec ((bxv (dA * (-cons) ./ ((cons * cons) + (b.P .* b.P))) b) :: t) - | Log_DV(a) -> pushRec ((bxv (dA ./ a.P) a) :: t) - | Log10_DV(a) -> pushRec ((bxv (dA ./ (a.P * N.log10Val)) a) :: t) - | Exp_DV(a) -> pushRec ((bxv (dA .* d.P) a) :: t) // d.P = exp a.P - | Sin_DV(a) -> pushRec ((bxv (dA .* cos a.P) a) :: t) - | Cos_DV(a) -> pushRec ((bxv (-dA .* sin a.P) a) :: t) - | Tan_DV(a) -> let seca = D.One / cos a.P in pushRec ((bxv (dA .* seca .* seca) a) :: t) - | Neg_DV(a) -> pushRec ((bxv -dA a) :: t) - | Sqrt_DV(a) -> pushRec ((bxv (dA ./ (N.two * d.P)) a) :: t) // d.P = sqrt a.P - | Sinh_DV(a) -> pushRec ((bxv (dA .* cosh a.P) a) :: t) - | Cosh_DV(a) -> pushRec ((bxv (dA .* sinh a.P) a) :: t) - | Tanh_DV(a) -> let secha = D.One / cosh a.P in pushRec ((bxv (dA .* secha .* secha) a) :: t) - | Asin_DV(a) -> pushRec ((bxv (dA ./ sqrt (D.One - (a.P .* a.P))) a) :: t) - | Acos_DV(a) -> pushRec ((bxv (-dA ./ sqrt (D.One - (a.P .* a.P))) a) :: t) - | Atan_DV(a) -> pushRec ((bxv (dA ./ (D.One + (a.P .* a.P))) a) :: t) - | Abs_DV(a) -> pushRec ((bxv (dA .* DV.Sign a.P) a) :: t) + | Add_DV_DV(a, b) -> pushRec ((bxv d.A a) :: (bxv d.A b) :: t) + 
| Add_DV_DVCons(a) -> pushRec ((bxv d.A a) :: t) + | Add_DV_D(a, b) -> pushRec ((bxv d.A a) :: (bx (DV.Sum(d.A)) b) :: t) + | Add_DV_DCons(a) -> pushRec ((bxv d.A a) :: t) + | Add_DVCons_D(b) -> pushRec ((bx (DV.Sum(d.A)) b) :: t) + | Sub_DV_DV(a, b) -> pushRec ((bxv d.A a) :: (bxv -d.A b) :: t) + | Sub_DV_DVCons(a) -> pushRec ((bxv d.A a) :: t) + | Sub_DVCons_DV(a) -> pushRec ((bxv -d.A a) :: t) + | Sub_DV_D(a, b) -> pushRec ((bxv d.A a) :: (bx -(DV.Sum(d.A)) b) :: t) + | Sub_DV_DCons(a) -> pushRec ((bxv d.A a) :: t) + | Sub_DVCons_D(b) -> pushRec ((bx -(DV.Sum(d.A)) b) :: t) + | Sub_D_DV(a, b) -> pushRec ((bx (DV.Sum(d.A)) a) :: (bxv -d.A b) :: t) + | Sub_D_DVCons(a) -> pushRec ((bx (DV.Sum(d.A)) a) :: t) + | Sub_DCons_DV(b) -> pushRec ((bxv -d.A b) :: t) + | Mul_Had_DV_DV(a, b) -> pushRec ((bxv (d.A .* b.P) a) :: (bxv (d.A .* a.P) b) :: t) + | Mul_Had_DV_DVCons(a, cons) -> pushRec ((bxv (d.A .* cons) a) :: t) + | Mul_DV_D(a, b) -> pushRec ((bxv (d.A * b.P) a) :: (bx (d.A * a.P) b) :: t) + | Mul_DV_DCons(a, cons) -> pushRec ((bxv (d.A * cons) a) :: t) + | Mul_DVCons_D(cons, b) -> pushRec ((bx (d.A * cons) b) :: t) + | Mul_DM_DV(a, b) -> pushRec ((bxm (d.A &* b.P) a) :: (bxv (DM.Transpose(a.P) * d.A) b) :: t) + | Mul_DM_DVCons(a, cons) -> pushRec ((bxm (d.A &* cons) a) :: t) + | Mul_DMCons_DV(cons, b) -> pushRec ((bxv (DM.Transpose(cons) * d.A) b) :: t) + | Mul_DV_DM(a, b) -> pushRec ((bxv (d.A * DM.Transpose(b.P)) a) :: (bxm (a.P &* d.A) b) :: t) + | Mul_DV_DMCons(a, cons) -> pushRec ((bxv (d.A * DM.Transpose(cons)) a) :: t) + | Mul_DVCons_DM(cons, b) -> pushRec ((bxm (cons &* d.A) b) :: t) + | Div_Had_DV_DV(a, b) -> pushRec ((bxv (d.A ./ b.P) a) :: (bxv (d.A .* (-a.P ./ (b.P .* b.P))) b) :: t) + | Div_Had_DV_DVCons(a, cons) -> pushRec ((bxv (d.A ./ cons) a) :: t) + | Div_Had_DVCons_DV(cons, b) -> pushRec ((bxv (d.A .* (-cons ./ (b.P .* b.P))) b) :: t) + | Div_DV_D(a, b) -> pushRec ((bxv (d.A / b.P) a) :: (bx (d.A * (-a.P / (b.P * b.P))) b) :: t) + | 
Div_DV_DCons(a, cons) -> pushRec ((bxv (d.A / cons) a) :: t) + | Div_DVCons_D(cons, b) -> pushRec ((bx (d.A * (-cons / (b.P * b.P))) b) :: t) + | Div_D_DV(a, b) -> pushRec ((bx (DV.Sum(d.A ./ b.P)) a) :: (bxv (d.A .* (-a.P / (b.P .* b.P))) b) :: t) + | Div_D_DVCons(a, cons) -> pushRec ((bx (DV.Sum(d.A ./ cons)) a) :: t) + | Div_DCons_DV(cons, b) -> pushRec ((bxv (d.A .* (-cons / (b.P .* b.P))) b) :: t) + | Pow_DV_DV(a, b) -> pushRec ((bxv (d.A .* (a.P ** (b.P - D.One)) .* b.P) a) :: (bxv (d.A .* (a.P ** b.P) .* log a.P) b) :: t) + | Pow_DV_DVCons(a, cons) -> pushRec ((bxv (d.A .* (a.P ** (cons - D.One)) .* cons) a) :: t) + | Pow_DVCons_DV(cons, b) -> pushRec ((bxv (d.A .* (cons ** b.P) .* log cons) b) :: t) + | Atan2_DV_DV(a, b) -> let denom = (a.P .* a.P) + (b.P .* b.P) in pushRec ((bxv (d.A .* b.P ./ denom) a) :: (bxv (d.A .* (-a.P) ./ denom) b) :: t) + | Atan2_DV_DVCons(a, cons) -> pushRec ((bxv (d.A .* cons ./ ((a.P .* a.P) + (cons .* cons))) a) :: t) + | Atan2_DVCons_DV(cons, b) -> pushRec ((bxv (d.A .* (-cons) ./ ((cons .* cons) + (b.P .* b.P))) b) :: t) + | Pow_DV_D(a, b) -> pushRec ((bxv (d.A .* (a.P ** (b.P - D.One)) * b.P) a) :: (bx (DV.Sum(d.A .* (a.P ** b.P) .* log a.P)) b) :: t) + | Pow_DV_DCons(a, cons) -> pushRec ((bxv (d.A .* (a.P ** (cons - D.One)) * cons) a) :: t) + | Pow_DVCons_D(cons, b) -> pushRec ((bx (DV.Sum(d.A .* (cons ** b.P) .* log cons)) b) :: t) + | Pow_D_DV(a, b) -> pushRec ((bx (DV.Sum(d.A .* (DV.Pow(a.P, b.P - D.One)) .* b.P)) a) :: (bxv (d.A .* (DV.Pow(a.P, b.P)) * log a.P) b) :: t) + | Pow_D_DVCons(a, cons) -> pushRec ((bx (DV.Sum(d.A .* (DV.Pow(a.P, cons - D.One)) .* cons)) a) :: t) + | Pow_DCons_DV(cons, b) -> pushRec ((bxv (d.A .* (DV.Pow(cons, b.P)) * log cons) b) :: t) + | Atan2_DV_D(a, b) -> let denom = (a.P .* a.P) + (b.P * b.P) in pushRec ((bxv (d.A * b.P ./ denom) a) :: (bx (DV.Sum(d.A .* (-a.P) ./ denom)) b) :: t) + | Atan2_DV_DCons(a, cons) -> pushRec ((bxv (d.A * cons ./ ((a.P .* a.P) + (cons * cons))) a) :: t) + | 
Atan2_DVCons_D(cons, b) -> pushRec ((bx (DV.Sum(d.A .* (-cons) ./ ((cons .* cons) + (b.P * b.P)))) b) :: t) + | Atan2_D_DV(a, b) -> let denom = (a.P * a.P) + (b.P .* b.P) in pushRec ((bx (DV.Sum(d.A .* b.P ./ denom)) a) :: (bxv (d.A * (-a.P) ./ denom) b) :: t) + | Atan2_D_DVCons(a, cons) -> pushRec ((bx (DV.Sum(d.A .* cons ./ ((a.P * a.P) + (cons .* cons)))) a) :: t) + | Atan2_DCons_DV(cons, b) -> pushRec ((bxv (d.A * (-cons) ./ ((cons * cons) + (b.P .* b.P))) b) :: t) + | Log_DV(a) -> pushRec ((bxv (d.A ./ a.P) a) :: t) + | Log10_DV(a) -> pushRec ((bxv (d.A ./ (a.P * N.log10Val)) a) :: t) + | Exp_DV(a) -> pushRec ((bxv (d.A .* d.P) a) :: t) // d.P = exp a.P + | Sin_DV(a) -> pushRec ((bxv (d.A .* cos a.P) a) :: t) + | Cos_DV(a) -> pushRec ((bxv (-d.A .* sin a.P) a) :: t) + | Tan_DV(a) -> let seca = D.One / cos a.P in pushRec ((bxv (d.A .* seca .* seca) a) :: t) + | Neg_DV(a) -> pushRec ((bxv -d.A a) :: t) + | Sqrt_DV(a) -> pushRec ((bxv (d.A ./ (N.two * d.P)) a) :: t) // d.P = sqrt a.P + | Sinh_DV(a) -> pushRec ((bxv (d.A .* cosh a.P) a) :: t) + | Cosh_DV(a) -> pushRec ((bxv (d.A .* sinh a.P) a) :: t) + | Tanh_DV(a) -> let secha = D.One / cosh a.P in pushRec ((bxv (d.A .* secha .* secha) a) :: t) + | Asin_DV(a) -> pushRec ((bxv (d.A ./ sqrt (D.One - (a.P .* a.P))) a) :: t) + | Acos_DV(a) -> pushRec ((bxv (-d.A ./ sqrt (D.One - (a.P .* a.P))) a) :: t) + | Atan_DV(a) -> pushRec ((bxv (d.A ./ (D.One + (a.P .* a.P))) a) :: t) + | Abs_DV(a) -> pushRec ((bxv (d.A .* DV.Sign a.P) a) :: t) | Sign_DV(a) -> pushRec ((bxv DV.Zero a) :: t) | Floor_DV(a) -> pushRec ((bxv DV.Zero a) :: t) | Ceil_DV(a) -> pushRec ((bxv DV.Zero a) :: t) | Round_DV(a) -> pushRec ((bxv DV.Zero a) :: t) - | Make_DV_ofDs(a) -> pushRec (t |> List.append (a |> Array.mapi (fun i v -> (bx dA.[i] v)) |> List.ofArray)) + | Make_DV_ofDs(a) -> pushRec (t |> List.append (a |> Array.mapi (fun i v -> (bx d.A.[i] v)) |> List.ofArray)) | SliceRow_DM(a, i, j) -> - adjoints.[a] <- DM.AddSubMatrix(adjoints.[a], i, j, 
dA.ToRowDM()) + a.A <- DM.AddSubMatrix(a.A, i, j, d.A.ToRowDM()) pushRec ((bxm DM.Zero a) :: t) | SliceCol_DM(a, i, j) -> - adjoints.[a] <- DM.AddSubMatrix(adjoints.[a], i, j, dA.ToColDM()) + a.A <- DM.AddSubMatrix(a.A, i, j, d.A.ToColDM()) pushRec ((bxm DM.Zero a) :: t) - | Solve_DM_DV(a, b) -> let ba = DM.Solve(DM.Transpose(a), dA) in pushRec ((bxm (-ba &* dA) a) :: (bxv (ba) b) :: t) - | Solve_DM_DVCons(a, cons) -> let ba = DM.Solve(DM.Transpose(a), dA) in pushRec ((bxm (-ba &* dA) a) :: t) - | Solve_DMCons_DV(cons, b) -> let ba = DM.Solve(DM.Transpose(cons), dA) in pushRec ((bxv ba b) :: t) + | Solve_DM_DV(a, b) -> let ba = DM.Solve(DM.Transpose(a), d.A) in pushRec ((bxm (-ba &* d.P) a) :: (bxv (ba) b) :: t) + | Solve_DM_DVCons(a, cons) -> let ba = DM.Solve(DM.Transpose(a), d.A) in pushRec ((bxm (-ba &* d.P) a) :: t) + | Solve_DMCons_DV(cons, b) -> let ba = DM.Solve(DM.Transpose(cons), d.A) in pushRec ((bxv ba b) :: t) | Append_DV_DV(a, b) -> - adjoints.[a] <- adjoints.[a] + dA.[..(a.Length - 1)] - adjoints.[b] <- adjoints.[b] + dA.[a.Length..] + a.A <- a.A + d.A.[..(a.Length - 1)] + b.A <- b.A + d.A.[a.Length..] pushRec ((bxv DV.Zero a) :: (bxv DV.Zero b) :: t) | Append_DV_DVCons(a) -> - adjoints.[a] <- adjoints.[a] + dA.[..(a.Length - 1)] + a.A <- a.A + d.A.[..(a.Length - 1)] pushRec ((bxv DV.Zero a) :: t) | Append_DVCons_DV(b) -> - adjoints.[b] <- adjoints.[b] + dA.[(d.Length - b.Length)..] + b.A <- b.A + d.A.[(d.Length - b.Length)..]
pushRec ((bxv DV.Zero b) :: t) | Split_DV(a, i) -> - adjoints.[a] <- DV.AddSubVector(adjoints.[a], i, dA) + a.A <- DV.AddSubVector(a.A, i, d.A) pushRec ((bxv DV.Zero a) :: t) - | AddItem_DV_D(a, i, b) -> pushRec ((bxv dA a) :: (bx (dA.[i]) b) :: t) - | AddItem_DV_DCons(a) -> pushRec ((bxv dA a) :: t) - | AddItem_DVCons_D(i, b) -> pushRec ((bx dA.[i] b) :: t) - | AddSubVector_DV_DV(a, i, b) -> pushRec ((bxv dA a) :: (bxv (dA.[i..(i + b.Length - 1)]) b) :: t) - | AddSubVector_DV_DVCons(a) -> pushRec ((bxv dA a) :: t) - | AddSubVector_DVCons_DV(i, b) -> pushRec ((bxv (dA.[i..(i + b.Length - 1)]) b) :: t) - | ReshapeCopy_DM_DV(a) -> pushRec ((bxm (DV.ReshapeToDM(a.Rows, dA)) a) :: t) + | AddItem_DV_D(a, i, b) -> pushRec ((bxv d.A a) :: (bx (d.A.[i]) b) :: t) + | AddItem_DV_DCons(a) -> pushRec ((bxv d.A a) :: t) + | AddItem_DVCons_D(i, b) -> pushRec ((bx d.A.[i] b) :: t) + | AddSubVector_DV_DV(a, i, b) -> pushRec ((bxv d.A a) :: (bxv (d.A.[i..(i + b.Length - 1)]) b) :: t) + | AddSubVector_DV_DVCons(a) -> pushRec ((bxv d.A a) :: t) + | AddSubVector_DVCons_DV(i, b) -> pushRec ((bxv (d.A.[i..(i + b.Length - 1)]) b) :: t) + | ReshapeCopy_DM_DV(a) -> pushRec ((bxm (DV.ReshapeToDM(a.Rows, d.A)) a) :: t) | Slice_DV(a, i) -> - adjoints.[a] <- DV.AddSubVector(adjoints.[a], i, dA) + a.A <- DV.AddSubVector(a.A, i, d.A) pushRec ((bxv DV.Zero a) :: t) | Diagonal_DM(a) -> - adjoints.[a] <- DM.AddDiagonal(adjoints.[a], dA) + a.A <- DM.AddDiagonal(a.A, d.A) pushRec ((bxm DM.Zero a) :: t) - | ReLU_DV(a) -> pushRec ((bxv (dA .* ((DV.Sign(a.P) + N.one) / N.two)) a) :: t) - | Sigmoid_DV(a) -> pushRec ((bxv (dA .* d.P .* (N.one - d.P)) a) :: t) // d.P = DV.Sigmoid(a.P) + | ReLU_DV(a) -> pushRec ((bxv (d.A .* ((DV.Sign(a.P) + N.one) / N.two)) a) :: t) + | Sigmoid_DV(a) -> pushRec ((bxv (d.A .* d.P .* (N.one - d.P)) a) :: t) // d.P = DV.Sigmoid(a.P) | _ -> pushRec t else pushRec t | _ -> pushRec t - | (:? DeltaM as delta), (:? DM as d) -> + | (:? DM as d), (:? 
DM as v) -> match d with - | DMR(_, o, _, uniq) -> - let dA = adjoints.ApplyDeltaM(uniq,delta) - let fanout = fanouts.[uniq] - 1u - fanouts.[uniq] <- fanout - // If all incoming parts of the adjoint have been received, then proceed to the parent - if fanout = 0u then + | DMR(_,_,o,_,_) -> + d.A <- d.A + v + d.F <- d.F - 1u + // If all incoming parts of the adjoint have been received, then proceed to the children + if d.F = 0u then match o with - | Add_DM_DM(a, b) -> pushRec ((bxm dA a) :: (bxm dA b) :: t) - | Add_DM_DMCons(a) -> pushRec ((bxm dA a) :: t) + | Add_DM_DM(a, b) -> pushRec ((bxm d.A a) :: (bxm d.A b) :: t) + | Add_DM_DMCons(a) -> pushRec ((bxm d.A a) :: t) - // When pushing "-dA" as adjoint increment for b, the operation - // "b.Adjoint <- -1.0 * dA + b.Adjoint" + // When pushing "-d.A" as adjoint increment for b, the operation + // "b.Adjoint <- -1.0 * d.A + b.Adjoint" // can be performed directly in-place. Instead of pushing a D|DV|DM we should a // structured expression about how to compute the D|DV|DM which can be interpreted // to do an in-place update - | Sub_DM_DM(a, b) -> pushRec ((bxm dA a) :: (bdm (XNegM (XM dA)) b) :: t) + | Sub_DM_DM(a, b) -> pushRec ((bxm d.A a) :: (bxm (-d.A) b) :: t) + + | Sub_DM_DMCons(a) -> pushRec ((bxm d.A a) :: t) + | Sub_DMCons_DM(a) -> pushRec ((bxm -d.A a) :: t) // TODO: also avoid the inplace operations in most of the below.
- | Sub_DM_DMCons(a) -> pushRec ((bxm dA a) :: t) - | Sub_DMCons_DM(a) -> pushRec ((bxm -dA a) :: t) - | Mul_DM_DM(a, b) -> pushRec ((bxm (dA * DM.Transpose(b.P)) a) :: (bxm (DM.Transpose(a.P) * dA) b) :: t) - | Mul_DM_DMCons(a, cons) -> pushRec ((bxm (dA * DM.Transpose(cons)) a) :: t) - | Mul_DMCons_DM(cons, b) -> pushRec ((bxm (DM.Transpose(cons) * dA) b) :: t) - | Mul_Had_DM_DM(a, b) -> pushRec ((bxm (dA .* b.P) a) :: (bxm (dA .* a.P) b) :: t) - | Mul_Had_DM_DMCons(a, cons) -> pushRec ((bxm (dA .* cons) a) :: t) - | Mul_DM_D(a, b) -> pushRec ((bxm (dA * b.P) a) :: (bx (DM.Sum(dA .* a.P)) b) :: t) - | Mul_DM_DCons(a, cons) -> pushRec ((bxm (dA * cons) a) :: t) - | Mul_DMCons_D(cons, b) -> pushRec ((bx (DM.Sum(dA .* cons)) b) :: t) - | Mul_Out_DV_DV(a, b) -> pushRec ((bxv (dA * b.P) a) :: (bxv (DM.Transpose(dA) * a.P) b) :: t) - | Mul_Out_DV_DVCons(a, cons) -> pushRec ((bxv (dA * cons) a) :: t) - | Mul_Out_DVCons_DV(cons, b) -> pushRec ((bxv (DM.Transpose(dA) * cons) b) :: t) - | Div_Had_DM_DM(a, b) -> pushRec ((bxm (dA ./ b.P) a) :: (bxm (dA .* (-a.P ./ (b.P .* b.P))) b) :: t) - | Div_Had_DM_DMCons(a, cons) -> pushRec ((bxm (dA ./ cons) a) :: t) - | Div_Had_DMCons_DM(cons, b) -> pushRec ((bxm (dA .* (-cons ./ (b.P .* b.P))) b) :: t) - | Pow_DM_DM(a, b) -> pushRec ((bxm (dA .* (a.P ** (b.P - D.One)) .* b.P) a) :: (bxm (dA .* (a.P ** b.P) .* log a.P) b) :: t) - | Pow_DM_DMCons(a, cons) -> pushRec ((bxm (dA .* (a.P ** (cons - D.One)) .* cons) a) :: t) - | Pow_DMCons_DM(cons, b) -> pushRec ((bxm (dA .* (cons ** b.P) .* log cons) b) :: t) - | Atan2_DM_DM(a, b) -> let denom = (a.P .* a.P) + (b.P .* b.P) in pushRec ((bxm (dA .* b.P ./ denom) a) :: (bxm (dA .* (-a.P) ./ denom) b) :: t) - | Atan2_DM_DMCons(a, cons) -> pushRec ((bxm (dA .* cons ./ ((a.P .* a.P) + (cons .* cons))) a) :: t) - | Atan2_DMCons_DM(cons, b) -> pushRec ((bxm (dA .* (-cons) ./ ((cons .* cons) + (b.P .* b.P))) b) :: t) - | Add_DM_D(a, b) -> pushRec ((bxm dA a) :: (bx (DM.Sum(dA)) b) :: t) - | 
Add_DM_DCons(a) -> pushRec ((bxm dA a) :: t) - | Add_DMCons_D(b) -> pushRec ((bx (DM.Sum(dA)) b) :: t) + | Mul_DM_DM(a, b) -> pushRec ((bxm (d.A * DM.Transpose(b.P)) a) :: (bxm (DM.Transpose(a.P) * d.A) b) :: t) + | Mul_DM_DMCons(a, cons) -> pushRec ((bxm (d.A * DM.Transpose(cons)) a) :: t) + | Mul_DMCons_DM(cons, b) -> pushRec ((bxm (DM.Transpose(cons) * d.A) b) :: t) + | Mul_Had_DM_DM(a, b) -> pushRec ((bxm (d.A .* b.P) a) :: (bxm (d.A .* a.P) b) :: t) + | Mul_Had_DM_DMCons(a, cons) -> pushRec ((bxm (d.A .* cons) a) :: t) + | Mul_DM_D(a, b) -> pushRec ((bxm (d.A * b.P) a) :: (bx (DM.Sum(d.A .* a.P)) b) :: t) + | Mul_DM_DCons(a, cons) -> pushRec ((bxm (d.A * cons) a) :: t) + | Mul_DMCons_D(cons, b) -> pushRec ((bx (DM.Sum(d.A .* cons)) b) :: t) + | Mul_Out_DV_DV(a, b) -> pushRec ((bxv (d.A * b.P) a) :: (bxv (DM.Transpose(d.A) * a.P) b) :: t) + | Mul_Out_DV_DVCons(a, cons) -> pushRec ((bxv (d.A * cons) a) :: t) + | Mul_Out_DVCons_DV(cons, b) -> pushRec ((bxv (DM.Transpose(d.A) * cons) b) :: t) + | Div_Had_DM_DM(a, b) -> pushRec ((bxm (d.A ./ b.P) a) :: (bxm (d.A .* (-a.P ./ (b.P .* b.P))) b) :: t) + | Div_Had_DM_DMCons(a, cons) -> pushRec ((bxm (d.A ./ cons) a) :: t) + | Div_Had_DMCons_DM(cons, b) -> pushRec ((bxm (d.A .* (-cons ./ (b.P .* b.P))) b) :: t) + | Pow_DM_DM(a, b) -> pushRec ((bxm (d.A .* (a.P ** (b.P - D.One)) .* b.P) a) :: (bxm (d.A .* (a.P ** b.P) .* log a.P) b) :: t) + | Pow_DM_DMCons(a, cons) -> pushRec ((bxm (d.A .* (a.P ** (cons - D.One)) .* cons) a) :: t) + | Pow_DMCons_DM(cons, b) -> pushRec ((bxm (d.A .* (cons ** b.P) .* log cons) b) :: t) + | Atan2_DM_DM(a, b) -> let denom = (a.P .* a.P) + (b.P .* b.P) in pushRec ((bxm (d.A .* b.P ./ denom) a) :: (bxm (d.A .* (-a.P) ./ denom) b) :: t) + | Atan2_DM_DMCons(a, cons) -> pushRec ((bxm (d.A .* cons ./ ((a.P .* a.P) + (cons .* cons))) a) :: t) + | Atan2_DMCons_DM(cons, b) -> pushRec ((bxm (d.A .* (-cons) ./ ((cons .* cons) + (b.P .* b.P))) b) :: t) + | Add_DM_D(a, b) -> pushRec ((bxm d.A a) :: (bx 
(DM.Sum(d.A)) b) :: t) + | Add_DM_DCons(a) -> pushRec ((bxm d.A a) :: t) + | Add_DMCons_D(b) -> pushRec ((bx (DM.Sum(d.A)) b) :: t) | Add_DMCols_DV(a, b) -> - dA.GetCols() |> Seq.iter (fun v -> adjoints.[b] <- adjoints.[b] + v) - pushRec ((bxm dA a) :: (bxv DV.Zero b) :: t) + d.A.GetCols() |> Seq.iter (fun v -> b.A <- b.A + v) + pushRec ((bxm d.A a) :: (bxv DV.Zero b) :: t) | Add_DMCols_DVCons(a) -> - pushRec ((bxm dA a) :: t) + pushRec ((bxm d.A a) :: t) | Add_DMColsCons_DV(b) -> - dA.GetCols() |> Seq.iter (fun v -> adjoints.[b] <- adjoints.[b] + v) + d.A.GetCols() |> Seq.iter (fun v -> b.A <- b.A + v) pushRec ((bxv DV.Zero b) :: t) - | Sub_DM_D(a, b) -> pushRec ((bxm dA a) :: (bx -(DM.Sum(dA)) b) :: t) - | Sub_DM_DCons(a) -> pushRec ((bxm dA a) :: t) - | Sub_DMCons_D(b) -> pushRec ((bx -(DM.Sum(dA)) b) :: t) - | Sub_D_DM(a, b) -> pushRec ((bx (DM.Sum(dA)) a) :: (bxm -dA b) :: t) - | Sub_D_DMCons(a) -> pushRec ((bx (DM.Sum(dA)) a) :: t) - | Sub_DCons_DM(b) -> pushRec ((bxm -dA b) :: t) - | Div_DM_D(a, b) -> pushRec ((bxm (dA / b.P) a) :: (bx (DM.Sum (dA .* (-a.P / b.P * b.P))) b) :: t) - | Div_DM_DCons(a, cons) -> pushRec ((bxm (dA / cons) a) :: t) - | Div_DMCons_D(cons, b) -> pushRec ((bx (DM.Sum (dA .* (-cons / (b.P * b.P)))) b) :: t) - | Div_D_DM(a, b) -> pushRec ((bx (DM.Sum(dA ./ b.P)) a) :: (bxm (dA .* (-a.P / (b.P .* b.P))) b) :: t) - | Div_D_DMCons(a, cons) -> pushRec ((bx (DM.Sum(dA ./ cons)) a) :: t) - | Div_DCons_DM(cons, b) -> pushRec ((bxm (dA .* (-cons / (b.P .* b.P))) b) :: t) - | Pow_DM_D(a, b) -> pushRec ((bxm (dA .* (a.P ** (b.P - D.One)) * b.P) a) :: (bx (DM.Sum(dA .* (a.P ** b.P) .* log a.P)) b) :: t) - | Pow_DM_DCons(a, cons) -> pushRec ((bxm (dA .* (a.P ** (cons - D.One)) * cons) a) :: t) - | Pow_DMCons_D(cons, b) -> pushRec ((bx (DM.Sum(dA .* (cons ** b.P) .* log cons)) b) :: t) - | Pow_D_DM(a, b) -> pushRec ((bx (DM.Sum(dA .* (DM.Pow(a.P, b.P - D.One)) .* b.P)) a) :: (bxm (dA .* (DM.Pow(a.P, b.P)) * log a.P) b) :: t) - | Pow_D_DMCons(a, 
cons) -> pushRec ((bx (DM.Sum(dA .* (DM.Pow(a.P, cons - D.One)) .* cons)) a) :: t) - | Pow_DCons_DM(cons, b) -> pushRec ((bxm (dA .* (DM.Pow(cons, b.P)) * log cons) b) :: t) - | Atan2_DM_D(a, b) -> let denom = (a.P .* a.P) + (b.P * b.P) in pushRec ((bxm (dA * b.P ./ denom) a) :: (bx (DM.Sum(dA .* (-a.P) ./ denom)) b) :: t) - | Atan2_DM_DCons(a, cons) -> pushRec ((bxm (dA * cons ./ ((a.P .* a.P) + (cons * cons))) a) :: t) - | Atan2_DMCons_D(cons, b) ->pushRec ((bx (DM.Sum(dA .* (-cons) ./ ((cons .* cons) + (b.P * b.P)))) b) :: t) - | Atan2_D_DM(a, b) -> let denom = (a.P * a.P) + (b.P .* b.P) in pushRec ((bx (DM.Sum(dA .* b.P ./ denom)) a) :: (bxm (dA * (-a.P) ./ denom) b) :: t) - | Atan2_D_DMCons(a, cons) -> pushRec ((bx (DM.Sum(dA .* cons ./ ((a.P * a.P) + (cons .* cons)))) a) :: t) - | Atan2_DCons_DM(cons, b) -> pushRec ((bxm (dA * (-cons) ./ ((cons * cons) + (b.P .* b.P))) b) :: t) - | Log_DM(a) -> pushRec ((bxm (dA ./ a.P) a) :: t) - | Log10_DM(a) -> pushRec ((bxm (dA ./ (a.P * N.log10Val)) a) :: t) - | Exp_DM(a) -> pushRec ((bxm (dA .* d.P) a) :: t) // d.P = exp a.P - | Sin_DM(a) -> pushRec ((bxm (dA .* cos a.P) a) :: t) - | Cos_DM(a) -> pushRec ((bxm (-dA .* sin a.P) a) :: t) - | Tan_DM(a) -> let seca = D.One / cos a.P in pushRec ((bxm (dA .* seca .* seca) a) :: t) - | Neg_DM(a) -> pushRec ((bxm -dA a) :: t) - | Sqrt_DM(a) -> pushRec ((bxm (dA ./ (N.two * d.P)) a) :: t) // d.P = sqrt a.P - | Sinh_DM(a) -> pushRec ((bxm (dA .* cosh a.P) a) :: t) - | Cosh_DM(a) -> pushRec ((bxm (dA .* sinh a.P) a) :: t) - | Tanh_DM(a) -> let secha = D.One / cosh a.P in pushRec ((bxm (dA .* secha .* secha) a) :: t) - | Asin_DM(a) -> pushRec ((bxm (dA ./ sqrt (D.One - (a.P .* a.P))) a) :: t) - | Acos_DM(a) -> pushRec ((bxm (-dA ./ sqrt (D.One - (a.P .* a.P))) a) :: t) - | Atan_DM(a) -> pushRec ((bxm (dA ./ (D.One + (a.P .* a.P))) a) :: t) - | Abs_DM(a) -> pushRec ((bxm (dA .* DM.Sign a.P) a) :: t) + | Sub_DM_D(a, b) -> pushRec ((bxm d.A a) :: (bx -(DM.Sum(d.A)) b) :: t) + | 
Sub_DM_DCons(a) -> pushRec ((bxm d.A a) :: t) + | Sub_DMCons_D(b) -> pushRec ((bx -(DM.Sum(d.A)) b) :: t) + | Sub_D_DM(a, b) -> pushRec ((bx (DM.Sum(d.A)) a) :: (bxm -d.A b) :: t) + | Sub_D_DMCons(a) -> pushRec ((bx (DM.Sum(d.A)) a) :: t) + | Sub_DCons_DM(b) -> pushRec ((bxm -d.A b) :: t) + | Div_DM_D(a, b) -> pushRec ((bxm (d.A / b.P) a) :: (bx (DM.Sum (d.A .* (-a.P / (b.P * b.P)))) b) :: t) + | Div_DM_DCons(a, cons) -> pushRec ((bxm (d.A / cons) a) :: t) + | Div_DMCons_D(cons, b) -> pushRec ((bx (DM.Sum (d.A .* (-cons / (b.P * b.P)))) b) :: t) + | Div_D_DM(a, b) -> pushRec ((bx (DM.Sum(d.A ./ b.P)) a) :: (bxm (d.A .* (-a.P / (b.P .* b.P))) b) :: t) + | Div_D_DMCons(a, cons) -> pushRec ((bx (DM.Sum(d.A ./ cons)) a) :: t) + | Div_DCons_DM(cons, b) -> pushRec ((bxm (d.A .* (-cons / (b.P .* b.P))) b) :: t) + | Pow_DM_D(a, b) -> pushRec ((bxm (d.A .* (a.P ** (b.P - D.One)) * b.P) a) :: (bx (DM.Sum(d.A .* (a.P ** b.P) .* log a.P)) b) :: t) + | Pow_DM_DCons(a, cons) -> pushRec ((bxm (d.A .* (a.P ** (cons - D.One)) * cons) a) :: t) + | Pow_DMCons_D(cons, b) -> pushRec ((bx (DM.Sum(d.A .* (cons ** b.P) .* log cons)) b) :: t) + | Pow_D_DM(a, b) -> pushRec ((bx (DM.Sum(d.A .* (DM.Pow(a.P, b.P - D.One)) .* b.P)) a) :: (bxm (d.A .* (DM.Pow(a.P, b.P)) * log a.P) b) :: t) + | Pow_D_DMCons(a, cons) -> pushRec ((bx (DM.Sum(d.A .* (DM.Pow(a.P, cons - D.One)) .* cons)) a) :: t) + | Pow_DCons_DM(cons, b) -> pushRec ((bxm (d.A .* (DM.Pow(cons, b.P)) * log cons) b) :: t) + | Atan2_DM_D(a, b) -> let denom = (a.P .* a.P) + (b.P * b.P) in pushRec ((bxm (d.A * b.P ./ denom) a) :: (bx (DM.Sum(d.A .* (-a.P) ./ denom)) b) :: t) + | Atan2_DM_DCons(a, cons) -> pushRec ((bxm (d.A * cons ./ ((a.P .* a.P) + (cons * cons))) a) :: t) + | Atan2_DMCons_D(cons, b) ->pushRec ((bx (DM.Sum(d.A .* (-cons) ./ ((cons .* cons) + (b.P * b.P)))) b) :: t) + | Atan2_D_DM(a, b) -> let denom = (a.P * a.P) + (b.P .* b.P) in pushRec ((bx (DM.Sum(d.A .* b.P ./ denom)) a) :: (bxm (d.A * (-a.P) ./ denom) b) :: t) + |
Atan2_D_DMCons(a, cons) -> pushRec ((bx (DM.Sum(d.A .* cons ./ ((a.P * a.P) + (cons .* cons)))) a) :: t) + | Atan2_DCons_DM(cons, b) -> pushRec ((bxm (d.A * (-cons) ./ ((cons * cons) + (b.P .* b.P))) b) :: t) + | Log_DM(a) -> pushRec ((bxm (d.A ./ a.P) a) :: t) + | Log10_DM(a) -> pushRec ((bxm (d.A ./ (a.P * N.log10Val)) a) :: t) + | Exp_DM(a) -> pushRec ((bxm (d.A .* d.P) a) :: t) // d.P = exp a.P + | Sin_DM(a) -> pushRec ((bxm (d.A .* cos a.P) a) :: t) + | Cos_DM(a) -> pushRec ((bxm (-d.A .* sin a.P) a) :: t) + | Tan_DM(a) -> let seca = D.One / cos a.P in pushRec ((bxm (d.A .* seca .* seca) a) :: t) + | Neg_DM(a) -> pushRec ((bxm -d.A a) :: t) + | Sqrt_DM(a) -> pushRec ((bxm (d.A ./ (N.two * d.P)) a) :: t) // d.P = sqrt a.P + | Sinh_DM(a) -> pushRec ((bxm (d.A .* cosh a.P) a) :: t) + | Cosh_DM(a) -> pushRec ((bxm (d.A .* sinh a.P) a) :: t) + | Tanh_DM(a) -> let secha = D.One / cosh a.P in pushRec ((bxm (d.A .* secha .* secha) a) :: t) + | Asin_DM(a) -> pushRec ((bxm (d.A ./ sqrt (D.One - (a.P .* a.P))) a) :: t) + | Acos_DM(a) -> pushRec ((bxm (-d.A ./ sqrt (D.One - (a.P .* a.P))) a) :: t) + | Atan_DM(a) -> pushRec ((bxm (d.A ./ (D.One + (a.P .* a.P))) a) :: t) + | Abs_DM(a) -> pushRec ((bxm (d.A .* DM.Sign a.P) a) :: t) | Sign_DM(a) -> pushRec ((bxm DM.Zero a) :: t) | Floor_DM(a) -> pushRec ((bxm DM.Zero a) :: t) | Ceil_DM(a) -> pushRec ((bxm DM.Zero a) :: t) | Round_DM(a) -> pushRec ((bxm DM.Zero a) :: t) - | Transpose_DM(a) -> pushRec ((bxm (DM.Transpose(dA)) a) :: t) - | Make_DM_ofDs(a) -> pushRec (t |> List.append (List.map2 (fun v dd -> (bx v dd)) (dA |> DM.toDV |> DV.toArray |> Array.toList) (a |> Array2D.toArray |> List.ofArray))) + | Transpose_DM(a) -> pushRec ((bxm (DM.Transpose(d.A)) a) :: t) + | Make_DM_ofDs(a) -> pushRec (t |> List.append (List.map2 (fun v dd -> (bx v dd)) (d.A |> DM.toDV |> DV.toArray |> Array.toList) (a |> Array2D.toArray |> List.ofArray))) | Make_DMRows_ofDV(a) -> - dA.GetRows() |> Seq.iter (fun v -> adjoints.[a] <- adjoints.[a] + 
v) + d.A.GetRows() |> Seq.iter (fun v -> a.A <- a.A + v) pushRec ((bxv DV.Zero a) :: t) | Make_DMCols_ofDV(a) -> - dA.GetCols() |> Seq.iter (fun v -> adjoints.[a] <- adjoints.[a] + v) + d.A.GetCols() |> Seq.iter (fun v -> a.A <- a.A + v) pushRec ((bxv DV.Zero a) :: t) - | Make_DMRows_ofDVs(a) -> pushRec (t |> List.append (a |> List.ofArray |> List.mapi (fun i v -> (bxv dA.[i, *] v)))) - | AddItem_DM_D(a, i, j, b) -> pushRec ((bxm dA a) :: (bx (dA.[i, j]) b) :: t) - | AddItem_DM_DCons(a) -> pushRec ((bxm dA a) :: t) - | AddItem_DMCons_D(i, j, b) -> pushRec ((bx dA.[i, j] b) :: t) - | AddSubMatrix_DM_DM(a, i, j, b) -> pushRec ((bxm dA a) :: (bxm (dA.[i..(i + b.Rows - 1), j..(j + b.Cols - 1)]) b) :: t) - | AddSubMatrix_DM_DMCons(a) -> pushRec ((bxm dA a) :: t) - | AddSubMatrix_DMCons_DM(i, j, b) -> pushRec ((bxm (dA.[i..(i + b.Rows - 1), j..(j + b.Cols - 1)]) b) :: t) + | Make_DMRows_ofDVs(a) -> pushRec (t |> List.append (a |> List.ofArray |> List.mapi (fun i v -> (bxv d.A.[i, *] v)))) + | AddItem_DM_D(a, i, j, b) -> pushRec ((bxm d.A a) :: (bx (d.A.[i, j]) b) :: t) + | AddItem_DM_DCons(a) -> pushRec ((bxm d.A a) :: t) + | AddItem_DMCons_D(i, j, b) -> pushRec ((bx d.A.[i, j] b) :: t) + | AddSubMatrix_DM_DM(a, i, j, b) -> pushRec ((bxm d.A a) :: (bxm (d.A.[i..(i + b.Rows - 1), j..(j + b.Cols - 1)]) b) :: t) + | AddSubMatrix_DM_DMCons(a) -> pushRec ((bxm d.A a) :: t) + | AddSubMatrix_DMCons_DM(i, j, b) -> pushRec ((bxm (d.A.[i..(i + b.Rows - 1), j..(j + b.Cols - 1)]) b) :: t) | Slice_DM(a, i, j) -> - adjoints.[a] <- DM.AddSubMatrix(adjoints.[a], i, j, dA) + a.A <- DM.AddSubMatrix(a.A, i, j, d.A) pushRec ((bxm DM.Zero a) :: t) - | RowMatrix_DV(a) -> pushRec ((bxv (dA.[0, *]) a) :: t) - | AddDiagonal_DM_DV(a, b) -> pushRec ((bxm dA a) :: (bxv (DM.Diagonal(dA)) b) :: t) - | AddDiagonal_DM_DVCons(a) -> pushRec ((bxm dA a) :: t) - | AddDiagonal_DMCons_DV(b) -> pushRec ((bxv (DM.Diagonal(dA)) b) :: t) - | ReshapeCopy_DV_DM(a) -> pushRec ((bxv (DM.ReshapeToDV(dA)) a) :: t) - | 
Inverse_DM(a) -> let dpt = DM.Transpose(d.P) in pushRec ((bxm (-dpt * dA * dpt) a) :: t) // d.P = DM.Inverse(a.P) - | ReLU_DM(a) -> pushRec ((bxm (dA .* ((DM.Sign(a.P) + N.one) / N.two)) a) :: t) - | Sigmoid_DM(a) -> pushRec ((bxm (dA .* d.P .* (N.one - d.P)) a) :: t) // d.P = DM.Sigmoid(a.P) + | RowMatrix_DV(a) -> pushRec ((bxv (d.A.[0, *]) a) :: t) + | AddDiagonal_DM_DV(a, b) -> pushRec ((bxm d.A a) :: (bxv (DM.Diagonal(d.A)) b) :: t) + | AddDiagonal_DM_DVCons(a) -> pushRec ((bxm d.A a) :: t) + | AddDiagonal_DMCons_DV(b) -> pushRec ((bxv (DM.Diagonal(d.A)) b) :: t) + | ReshapeCopy_DV_DM(a) -> pushRec ((bxv (DM.ReshapeToDV(d.A)) a) :: t) + | Inverse_DM(a) -> let dpt = DM.Transpose(d.P) in pushRec ((bxm (-dpt * d.A * dpt) a) :: t) // d.P = DM.Inverse(a.P) + | ReLU_DM(a) -> pushRec ((bxm (d.A .* ((DM.Sign(a.P) + N.one) / N.two)) a) :: t) + | Sigmoid_DM(a) -> pushRec ((bxm (d.A .* d.P .* (N.one - d.P)) a) :: t) // d.P = DM.Sigmoid(a.P) | _ -> pushRec t else pushRec t | _ -> pushRec t | _ -> pushRec t - let initialv = - match v with - | :? D as v -> bxdelta (X v) - | :? DV as v -> bxdelta (XV v) - | :? DM as v -> bxdelta (XM v) - | _ -> failwith "invalid dobj" - pushRec [(initialv, d)] + pushRec [(v, d)] /// Forward and reverse differentiation operations module (automatically opened) [] module DiffOps = - let inline computeAdjoints (d: 'T :> dobj) = - let adjoints = Adjoints() - let one = LanguagePrimitives.GenericOne<'T> - reverseProp adjoints one d - adjoints - /// Original value and first derivative of a scalar-to-scalar function `f`, at point `x`. Forward AD. - let inline diff' f x = - x |> makeForward GlobalTagger.Next (D.One) |> f |> primalTangent + let diff' (f: D -> D) x = + let dx = makeForward GlobalTagger.Next (D.One) x + dx |> f |> primalTangent /// First derivative of a scalar-to-scalar function `f`, at point `x`. Forward AD. 
- let inline diff f x = diff' f x |> snd + let diff (f: D -> D) x = diff' f x |> snd /// Second derivative of a scalar-to-scalar function `f`, at point `x`. Forward AD. - let inline diff2 f x = + let diff2 (f: D -> D) x : D = diff (diff f) x /// Original value, first derivative, and second derivative of a scalar-to-scalar function `f`, at point `x`. Forward AD. - let inline diff2'' f x = + let diff2'' (f: D -> D) x : D * D * D = let v, d = diff' f x let d2 = diff2 f x (v, d, d2) /// Original value and second derivative of a scalar-to-scalar function `f`, at point `x`. Forward AD. - let inline diff2' f x = + let diff2' (f: D -> D) x : D * D = diff2'' f x |> drop2Of3 /// `n`-th derivative of a scalar-to-scalar function `f`, at point `x`. Forward AD. - let inline diffn n f x = + let diffn n (f: D -> D) x : D = if n < 0 then ErrorMessages.InvalidArgDiffn() elif n = 0 then x |> f else @@ -3980,57 +3911,59 @@ module DiffOps = x |> d n f /// Original value and `n`-th derivative of a scalar-to-scalar function `f`, at point `x`. Forward AD. - let inline diffn' n f x = + let diffn' n (f: D -> D) x : D * D = (x |> f, diffn n f x) /// Original value and gradient of a vector-to-scalar function `f`, at point `x`. Reverse AD. - let inline grad' f x = + let grad' (f: DV -> D) x : D * DV = let xa = x |> makeReverse GlobalTagger.Next let z:D = f xa - let adjoints = computeAdjoints z - (z |> primal, xa |> adjoint adjoints ) + z |> reverseProp D.One + (z |> primal, xa |> adjoint) /// Gradient of a vector-to-scalar function `f`, at point `x`. Reverse AD. - let inline grad f x = + let grad (f: DV -> D) x : DV = grad' f x |> snd + /// Original value and gradient-vector product (directional derivative) of a vector-to-scalar function `f`, at point `x`, along vector `v`. Forward AD. 
+ let gradv' (f: DV -> D) (x: DV) (v: DV) : D * D = + let dvx = makeForward GlobalTagger.Next v x + dvx |> f |> primalTangent + + /// Gradient-vector product (directional derivative) of a vector-to-scalar function `f`, at point `x`, along vector `v`. Forward AD. + let gradv (f: DV -> D) x v : D = + gradv' f x v |> snd + /// Original value and Jacobian-vector product of a vector-to-vector function `f`, at point `x`, along vector `v`. Forward AD. - let inline jacobianv' f x v = + let jacobianv' (f: DV -> DV) x v : DV * DV = x |> makeForward GlobalTagger.Next v |> f |> primalTangent /// Jacobian-vector product of a vector-to-vector function `f`, at point `x`, along vector `v`. Forward AD. - let inline jacobianv f x v = + let jacobianv (f: DV -> DV) x v : DV = jacobianv' f x v |> snd - /// Gradient-vector product (directional derivative) of a vector-to-scalar function `f`, at point `x`, along vector `v`. Forward AD. - let inline gradv f x v = jacobianv f x v - - /// Original value and gradient-vector product (directional derivative) of a vector-to-scalar function `f`, at point `x`, along vector `v`. Forward AD. - let inline gradv' f x v = jacobianv' f x v - /// Original value and a function for evaluating the transposed Jacobian-vector product of a vector-to-vector function `f`, at point `x`. Of the returned pair, the first is the original value of function `f` at point `x` (the result of the forward pass of the reverse mode AD) and the second is a function (the reverse evaluator) that can compute the transposed Jacobian-vector product many times along many different vectors (performing a new reverse pass of reverse mode AD, with the given vector, without repeating the forward pass). Reverse AD. 
- let inline jacobianTv'' (f:'a->'b) (x:'a) = + let jacobianTv'' (f: DV -> DV) (x:DV) = let xa = x |> makeReverse GlobalTagger.Next let z = f xa let r1 = z |> primal let r2 = - fun (v:'b) -> - let adjoints = Adjoints() - z |> reverseProp adjoints v - xa |> adjoint adjoints + fun (v:DV) -> + z |> reverseProp v + xa |> adjoint (r1, r2) /// Original value and transposed Jacobian-vector product of a vector-to-vector function `f`, at point `x`, along vector `v`. Reverse AD. - let inline jacobianTv' f x v = + let jacobianTv' (f: DV -> DV) x v = let r1, r2 = jacobianTv'' f x (r1, r2 v) /// Transposed Jacobian-vector product of a vector-to-vector function `f`, at point `x`, along vector `v`. Reverse AD. - let inline jacobianTv f x v = + let jacobianTv (f: DV -> DV) x v = jacobianTv' f x v |> snd /// Original value and Jacobian of a vector-to-vector function `f`, at point `x`. Forward or reverse AD, depending on input and output dimensions. - let inline jacobian' f (x:DV) = + let jacobian' (f: DV -> DV) (x:DV) : DV * DM = let o:DV = x |> f |> primal if x.Length > o.Length then let r = jacobianTv f x @@ -4038,88 +3971,87 @@ module DiffOps = else (o, Array.init x.Length (fun i -> jacobianv f x (DV.standardBasis x.Length i)) |> DM.ofCols) - /// Jacobian of a vector-to-vector function `f`, at point `x`. Forward or reverse AD, depending on input and output dimensions. - let inline jacobian f x = + let jacobian (f: DV -> DV) x : DM = jacobian' f x |> snd /// Original value and transposed Jacobian of a vector-to-vector function `f`, at point `x`. Forward or reverse AD, depending on input and output dimensions. - let inline jacobianT' f x = + let jacobianT' (f: DV -> DV) x = jacobian' f x |> fun (r, j) -> (r, DM.transpose j) /// Transposed Jacobian of a vector-to-vector function `f`, at point `x`. Forward or reverse AD, depending on input and output dimensions. 
- let inline jacobianT f x = + let jacobianT (f: DV -> DV) x : DM = jacobianT' f x |> snd /// Gradient and Hessian of a vector-to-scalar function `f`, at point `x`. Forward-on-reverse AD. - let inline gradhessian f x = + let gradhessian (f: DV -> D) x : DV * DM = jacobian' (grad f) x /// Original value, gradient, and Hessian of a vector-to-scalar function `f`, at point `x`. Forward-on-reverse AD. - let inline gradhessian' f x = + let gradhessian' (f: DV -> D) x : D * DV * DM = let g, h = gradhessian f x (x |> f , g, h) /// Hessian of a vector-to-scalar function `f`, at point `x`. Forward-on-reverse AD. - let inline hessian f x = + let hessian (f: DV -> D) x : DM = jacobian (grad f) x /// Original value and Hessian of a vector-to-scalar function `f`, at point `x`. Forward-on-reverse AD. - let inline hessian' f x = + let hessian' (f: DV -> D) x : D * DM = (x |> f, hessian f x) /// Original value, gradient-vector product (directional derivative), and Hessian-vector product of a vector-to-scalar function `f`, at point `x`, along vector `v`. Reverse-on-forward AD. - let inline gradhessianv' f x v = - let gv, hv = grad' (fun xx -> jacobianv f xx v) x + let gradhessianv' (f: DV -> D) x v = + let gv, hv = grad' (fun xx -> gradv f xx v) x (x |> f, gv, hv) /// Gradient-vector product (directional derivative) and Hessian-vector product of a vector-to-scalar function `f`, at point `x`, along vector `v`. Reverse-on-forward AD. - let inline gradhessianv f x v = + let gradhessianv (f: DV -> D) x v : D * DV = gradhessianv' f x v |> drop1Of3 /// Original value and Hessian-vector product of a vector-to-scalar function `f`, at point `x`, along vector `v`. Reverse-on-forward AD. - let inline hessianv' f x v = + let hessianv' (f: DV -> D) x v = gradhessianv' f x v |> drop2Of3 /// Hessian-vector product of a vector-to-scalar function `f`, at point `x`, along vector `v`. Reverse-on-forward AD. 
- let inline hessianv f x v = + let hessianv (f: DV -> D) x v : DV = hessianv' f x v |> snd /// Original value and Laplacian of a vector-to-scalar function `f`, at point `x`. Reverse-on-forward AD. - let inline laplacian' f x = // TODO: reimplement faster + let laplacian' (f: DV -> D) x : D * D = // TODO: reimplement faster let v, h = hessian' f x (v, DM.trace h) /// Laplacian of a vector-to-scalar function `f`, at point `x`. Reverse-on-forward AD. - let inline laplacian f x = + let laplacian (f: DV -> D) x : D = laplacian' f x |> snd /// Original value and curl of a vector-to-vector function `f`, at point `x`. Supported only for functions with a three-by-three Jacobian matrix. Forward AD. - let inline curl' f x = + let curl' (f: DV -> DV) x = let v, j = jacobianT' f x if (j.Rows, j.Cols) <> (3, 3) then ErrorMessages.InvalidArgCurl() v, toDV [|j.[1, 2] - j.[2, 1]; j.[2, 0] - j.[0, 2]; j.[0, 1] - j.[1, 0]|] /// Curl of a vector-to-vector function `f`, at point `x`. Supported only for functions with a three-by-three Jacobian matrix. Forward AD. - let inline curl f x = + let curl (f: DV -> DV) x : DV = curl' f x |> snd /// Original value and divergence of a vector-to-vector function `f`, at point `x`. Defined only for functions with a square Jacobian matrix. Forward AD. - let inline div' f x = + let div' (f: DV -> DV) x = let v, j = jacobianT' f x if j.Rows <> j.Cols then ErrorMessages.InvalidArgDiv() v, DM.trace j /// Divergence of a vector-to-vector function `f`, at point `x`. Defined only for functions with a square Jacobian matrix. Forward AD. - let inline div f x = + let div (f: DV -> DV) x : D = div' f x |> snd /// Original value, curl, and divergence of a vector-to-vector function `f`, at point `x`. Supported only for functions with a three-by-three Jacobian matrix. Forward AD. 
- let inline curldiv' f x = + let curldiv' (f: DV -> DV) x = let v, j = jacobianT' f x if (j.Rows, j.Cols) <> (3, 3) then ErrorMessages.InvalidArgCurlDiv() v, toDV [|j.[1, 2] - j.[2, 1]; j.[2, 0] - j.[0, 2]; j.[0, 1] - j.[1, 0]|], DM.trace j /// Curl and divergence of a vector-to-vector function `f`, at point `x`. Supported only for functions with a three-by-three Jacobian matrix. Forward AD. - let inline curldiv f x = + let curldiv (f: DV -> DV) x : DV * D = curldiv' f x |> drop1Of3 diff --git a/src/DiffSharp/AD.Float64.fs b/src/DiffSharp/AD.Float64.fs index f146c178e..1eef5b169 100644 --- a/src/DiffSharp/AD.Float64.fs +++ b/src/DiffSharp/AD.Float64.fs @@ -12,13 +12,13 @@ module DiffSharp.AD.Float64 open DiffSharp.Util open DiffSharp.Config -open System.Threading.Tasks open System.Collections.Generic type number = float let inline Backend<'T> = GlobalConfig.Float64Backend let inline VisualizationContrast<'T> = GlobalConfig.Float64VisualizationContrast let inline FixedPointEpsilon<'T> = GlobalConfig.Float64FixedPointEpsilon + module N = let inline toNumber x = float x let inline failWithInvalidTypeMessage () = failwith "Unsupported type. Expecting D, float, or int." 
@@ -35,24 +35,27 @@ module N = /// with nesting capability, using tags to avoid perturbation confusion [] type D = + /// Primal | D of number + /// Primal, tangent, layer tag (for forward mode) - | DF of primal: D * tanget: D * tag: uint32 + | DF of primal: D * tanget: D * tag: uint32 + /// Primal, parent, layer tag (for reverse mode) - | DR of primal: D * parentOperation: TraceOp * tag: uint32 * uniq: int32 + | DR of primal: D * adjoint: (D ref) * parentOperation: TraceOp * fanOutCounter: (uint32 ref) * tag: uint32 interface dobj /// Make a reverse node - static member R(d, op, ai) = DR(d, op, ai, UniqueTagger.Next()) + static member R(d,op,ai) = DR(d, ref D.Zero, op, ref 0u, ai) /// Primal value of this D member d.P = match d with | D _ -> d | DF(ap, _, _) -> ap - | DR(ap, _, _, _) -> ap + | DR(ap, _, _, _, _) -> ap /// Deepest primal value of this D member d.PD = @@ -60,7 +63,7 @@ type D = match x with | D _ -> x | DF(xp, _, _) -> prec xp - | DR(xp, _, _, _) -> prec xp + | DR(xp, _, _, _, _) -> prec xp prec d /// Tangent value of this D @@ -70,11 +73,37 @@ type D = | DF(_, at, _) -> at | DR _ -> failwith "Cannot get tangent value of DR." + /// Adjoint script of this D + member d.A + with get() : D = + match d with + | D _ -> D.Zero + | DF(_,_,_) -> failwith "Cannot get adjoint value of DF." + | DR(_,a,_,_,_) -> !a + and set(v: D) = + match d with + | D _ -> () + | DF (_,_,_) -> failwith "Cannot set adjoint value of DF." + | DR (_,a,_,_,_) -> a := v + + /// Fan-out counter of this D + member d.F + with get() = + match d with + | D _ -> failwith "Cannot get fan-out value of D." + | DF (_,_,_) -> failwith "Cannot get fan-out value of DF." + | DR (_,_,_,f,_) -> !f + and set(v) = + match d with + | D _ -> failwith "Cannot set fan-out value of D." + | DF (_,_,_) -> failwith "Cannot set fan-out value of DF." 
+ | DR (_,_,_,f,_) -> f := v + member d.GetForward(t:D, i:uint32) = DF(d, t, i) member d.GetReverse(i:uint32) = D.R(d, Noop, i) - static member Zero = D N.zero + static member Zero : D = D N.zero static member One = D N.one @@ -83,7 +112,7 @@ type D = match x with | D(p) -> p | DF(xp, _, _) -> prec xp - | DR(xp, _, _, _) -> prec xp + | DR(xp, _, _, _, _) -> prec xp prec d interface System.IComparable with @@ -101,7 +130,7 @@ type D = match d with | D(ap) -> hash [|ap|] | DF(ap, at, ai) -> hash [|ap; at; ai|] - | DR(ap, ao, ai, _) -> hash [|ap; ao; ai|] + | DR(ap, ao, ai, _, _) -> hash [|ap; ao; ai|] override d.ToString() = let (d':number) = D.op_Explicit(d) @@ -112,9 +141,9 @@ type D = static member inline Op_D_D (a, ff, fd, df, r) = match a with - | D(ap) -> D(ff(ap)) - | DF(ap, at, ai) -> let cp = fd(ap) in DF(cp, df(cp, ap, at), ai) - | DR(ap, _, ai, _) -> D.R(fd(ap), r(a), ai) + | D(ap) -> D(ff(ap)) + | DF(ap, at, ai) -> let cp = fd(ap) in DF(cp, df(cp, ap, at), ai) + | DR(ap,_,_,_,ai) -> D.R(fd(ap), r(a), ai) static member inline Op_D_D_D (a, b, ff, fd, df_da, df_db, df_dab, r_d_d, r_d_c, r_c_d) = match a with @@ -122,7 +151,7 @@ type D = match b with | D(bp) -> D(ff(ap, bp)) | DF(bp, bt, bi) -> let cp = fd(a, bp) in DF(cp, df_db(cp, bp, bt), bi) - | DR(bp, _, bi, _) -> D.R(fd(a, bp), r_c_d(a, b), bi) + | DR(bp, _, _, _, bi) -> D.R(fd(a, bp), r_c_d(a, b), bi) | DF(ap, at, ai) -> match b with | D _ -> let cp = fd(ap, b) in DF(cp, df_da(cp, ap, at), ai) @@ -131,12 +160,12 @@ type D = | 0 -> let cp = fd(ap, bp) in DF(cp, df_dab(cp, ap, at, bp, bt), ai) // ai = bi | -1 -> let cp = fd(a, bp) in DF(cp, df_db(cp, bp, bt), bi) // ai < bi | _ -> let cp = fd(ap, b) in DF(cp, df_da(cp, ap, at), ai) // ai > bi - | DR(bp, _, bi, _) -> + | DR(bp, _, _, _, bi) -> match compare ai bi with | -1 -> D.R(fd(a, bp), r_c_d(a, b), bi) // ai < bi | 1 -> let cp = fd(ap, b) in DF(cp, df_da(cp, ap, at), ai) // ai > bi | _ -> failwith "Forward and reverse AD cannot run on the same 
level." - | DR(ap, _, ai, _) -> + | DR(ap, _, _, _, ai) -> match b with | D _ -> D.R(fd(ap, b), r_d_c(a, b), ai) | DF(bp, bt, bi) -> @@ -144,7 +173,7 @@ type D = | -1 -> let cp = fd(a, bp) in DF(cp, df_db(cp, bp, bt), bi) // ai < bi | 1 -> D.R(fd(ap, b), r_d_c(a, b), ai) // ai > bi | _ -> failwith "Forward and reverse AD cannot run on the same level." - | DR(bp, _, bi, _) -> + | DR(bp, _, _, _, bi) -> match compare ai bi with | 0 -> D.R(fd(ap, bp), r_d_d(a, b), ai) // ai = bi | -1 -> D.R(fd(a, bp), r_c_d(a, b), bi) // ai < bi @@ -439,7 +468,7 @@ type D = i <- imax a <- aa DF(a.P, a.T, bi) - | DR(bp, _, bi, _) -> + | DR(bp,_,_,_,bi) -> let bfirst = D.R(bp, Noop, bi) // Cut the connection between b and bfirst ("switch of graph construction" involving b beyond this point) while i < imax do i <- i + 1 @@ -461,7 +490,7 @@ type D = and DV = | DV of number[] // Primal | DVF of DV * DV * uint32 // Primal, tangent, layer tag - | DVR of DV * TraceOp * uint32 * int32 // Primal, parent operation, layer tag, unique + | DVR of primal: DV * adjoint: (DV ref) * TraceOp * (uint32 ref) * uint32 // Primal, adjoint, parent operation, fan-out counter, tag interface dobj @@ -470,7 +499,7 @@ and DV = match d with | DV _ -> d | DVF(ap, _, _) -> ap - | DVR(ap, _, _, _) -> ap + | DVR(ap,_,_,_,_) -> ap /// Deepest primal value of this DV member d.PD = @@ -478,15 +507,41 @@ and DV = match x with | DV _ -> x | DVF(xp, _, _) -> prec xp - | DVR(xp, _, _, _) -> prec xp + | DVR(xp,_,_,_,_) -> prec xp prec d /// Tangent value of this DV member d.T = match d with - | DV _ -> DV.ZeroN d.Length - | DVF(_, at, _) -> at - | DVR _ -> failwith "Cannot get tangent value of DVR." + | DV(_) -> DV.ZeroN d.Length + | DVF(_,at,_) -> at + | DVR(_,_,_,_,_) -> failwith "Cannot get tangent value of DVR." + + /// Adjoint value of this DV + member d.A + with get() : DV = + match d with + | DV(_) -> DV.ZeroN d.Length + | DVF(_,_,_) -> failwith "Cannot get adjoint value of DVF." 
+ | DVR(_,a,_,_,_) -> !a + and set(v: DV) = + match d with + | DV(_) -> () + | DVF(_,_,_) -> failwith "Cannot set adjoint value of DVF." + | DVR(_,a,_,_,_) -> a := v + + /// Fan-out counter of this DV + member d.F + with get() = + match d with + | DV(_) -> failwith "Cannot get fan-out value of DV." + | DVF(_,_,_) -> failwith "Cannot get fan-out value of DVF." + | DVR(_,_,_,f,_) -> !f + and set(v) = + match d with + | DV(_) -> failwith "Cannot set fan-out value of DV." + | DVF(_,_,_) -> failwith "Cannot set fan-out value of DVF." + | DVR(_,_,_,f,_) -> f := v /// Convert to use forward AD at this layer member d.GetForward(t:DV, i:uint32) = DVF(d, t, i) @@ -495,20 +550,20 @@ and DV = member d.GetReverse(i:uint32) = DV.R(d, Noop, i) /// Make a reverse node - static member R(d, op, ai) = DVR(d, op, ai, UniqueTagger.Next()) + static member R(d,op,ai) = DVR(d, ref (DV.ZeroN d.Length), op, ref 0u, ai) member d.Length = match d with | DV(ap) -> ap.Length | DVF(ap, _, _) -> ap.Length - | DVR(ap, _, _, _) -> ap.Length + | DVR(ap, _, _, _, _) -> ap.Length member d.Item with get i = match d with | DV(ap) -> D(ap.[i]) | DVF(ap, at, ai) -> DF(ap.[i], at.[i], ai) - | DVR(ap, _, ai, _) -> D.R(ap.[i], Item_DV(d, i), ai) + | DVR(ap, _, _, _, ai) -> D.R(ap.[i], Item_DV(d, i), ai) member d.GetSlice(lower, upper) = let l = defaultArg lower 0 @@ -516,21 +571,21 @@ and DV = match d with | DV(ap) -> DV(ap.[l..u]) | DVF(ap, at, ai) -> DVF(ap.[l..u], at.[l..u], ai) - | DVR(ap, _, ai, _) -> let cp = ap.[l..u] in DV.R(cp, Slice_DV(d, l), ai) + | DVR(ap, _, _, _, ai) -> let cp = ap.[l..u] in DV.R(cp, Slice_DV(d, l), ai) member d.ToArray() = match d with | DV(ap) -> ap |> Array.map D | DVF(ap, at, ai) -> Array.init ap.Length (fun i -> DF(ap.[i], at.[i], ai)) - | DVR(ap, _, ai, _) -> + | DVR(ap, _, _, _, ai) -> Array.init ap.Length (fun i -> D.R(ap.[i], Item_DV(d, i), ai)) member d.ToRowDM() = match d with | DV(ap) -> seq [ap] |> array2D |> DM | DVF(ap, at, ai) -> DMF(ap.ToRowDM(), at.ToRowDM(), 
ai) - | DVR(ap, _, ai, _) -> let cp = ap.ToRowDM() in DM.R(cp, RowMatrix_DV(d), ai) + | DVR(ap, _, _, _, ai) -> let cp = ap.ToRowDM() in DM.R(cp, RowMatrix_DV(d), ai) member d.ToColDM() = DM.Transpose(d.ToRowDM()) @@ -574,7 +629,7 @@ and DV = match x with | DV(p) -> p | DVF(xp, _, _) -> prec xp - | DVR(xp, _, _, _) -> prec xp + | DVR(xp,_,_,_,_) -> prec xp prec d static member op_Explicit(d) = DV(d) @@ -587,7 +642,7 @@ and DV = let ap = a |> Array.map (fun x -> x.P) let at = a |> Array.map (fun x -> x.T) DVF(DV.OfArray(ap), DV.OfArray(at), ai) - | DR(_, _, ai, _) -> + | DR(_,_,_,_,ai) -> let ap = a |> Array.map (fun x -> x.P) let cp = DV.OfArray(ap) in DV.R(cp, Make_DV_ofDs(a), ai) @@ -600,7 +655,7 @@ and DV = let aps = DV.Split(ap, n) let ats = DV.Split(at, n) Seq.map2 (fun p t -> DVF(p, t, ai)) aps ats - | DVR(ap, _, ai, _) -> + | DVR(ap, _, _, _, ai) -> let aps = DV.Split(ap, n) let ii = n |> Seq.mapFold (fun s i -> s, s + i) 0 |> fst |> Array.ofSeq Seq.mapi (fun i p -> DV.R(p, Split_DV(d, ii.[i]), ai)) aps @@ -610,19 +665,19 @@ and DV = match a with | DV(ap) -> DV(ff(ap)) | DVF(ap, at, ai) -> let cp = fd(ap) in DVF(cp, df(cp, ap, at), ai) - | DVR(ap, _, ai, _) -> let cp = fd(ap) in DV.R(cp, r(a), ai) + | DVR(ap,_,_,_,ai) -> let cp = fd(ap) in DV.R(cp, r(a), ai) static member inline Op_DV_DM (a, ff, fd, df, r) = match a with | DV(ap) -> DM(ff(ap)) | DVF(ap, at, ai) -> let cp = fd(ap) in DMF(cp, df(cp, ap, at), ai) - | DVR(ap, _, ai, _) -> let cp = fd(ap) in DM.R(cp, r(a), ai) + | DVR(ap,_,_,_,ai) -> let cp = fd(ap) in DM.R(cp, r(a), ai) static member inline Op_DV_D (a, ff, fd, df, r) = match a with | DV(ap) -> D(ff(ap)) | DVF(ap, at, ai) -> let cp = fd(ap) in DF(cp, df(cp, ap, at), ai) - | DVR(ap, _, ai, _) -> let cp = fd(ap) in D.R(cp, r(a), ai) + | DVR(ap,_,_,_,ai) -> let cp = fd(ap) in D.R(cp, r(a), ai) static member inline Op_DV_DV_DV (a, b, ff, fd, df_da, df_db, df_dab, r_d_d, r_d_c, r_c_d) = match a with @@ -630,7 +685,7 @@ and DV = match b with | DV(bp) 
-> DV(ff(ap, bp)) | DVF(bp, bt, bi) -> let cp = fd(a, bp) in DVF(cp, df_db(cp, bp, bt), bi) - | DVR(bp, _, bi, _) -> let cp = fd(a, bp) in DV.R(cp, r_c_d(a, b), bi) + | DVR(bp, _, _, _, bi) -> let cp = fd(a, bp) in DV.R(cp, r_c_d(a, b), bi) | DVF(ap, at, ai) -> match b with | DV _ -> let cp = fd(ap, b) in DVF(cp, df_da(cp, ap, at), ai) @@ -639,12 +694,12 @@ and DV = | 0 -> let cp = fd(ap, bp) in DVF(cp, df_dab(cp, ap, at, bp, bt), ai) // ai = bi | -1 -> let cp = fd(a, bp) in DVF(cp, df_db(cp, bp, bt), bi) // ai < bi | _ -> let cp = fd(ap, b) in DVF(cp, df_da(cp, ap, at), ai) // ai > bi - | DVR(bp, _, bi, _) -> + | DVR(bp, _, _, _, bi) -> match compare ai bi with | -1 -> let cp = fd(a, bp) in DV.R(cp, r_c_d(a, b), bi) // ai < bi | 1 -> let cp = fd(ap, b) in DVF(cp, df_da(cp, ap, at), ai) // ai > bi | _ -> failwith "Forward and reverse AD cannot run on the same level." - | DVR(ap, _, ai, _) -> + | DVR(ap, _, _, _, ai) -> match b with | DV _ -> let cp = fd(ap, b) in DV.R(cp, r_d_c(a, b), ai) | DVF(bp, bt, bi) -> @@ -652,7 +707,7 @@ and DV = | -1 -> let cp = fd(a, bp) in DVF(cp, df_db(cp, bp, bt), bi) // ai < bi | 1 -> let cp = fd(ap, b) in DV.R(cp, r_d_c(a, b), ai) // ai > bi | _ -> failwith "Forward and reverse AD cannot run on the same level." 
- | DVR(bp, _, bi, _) -> + | DVR(bp, _, _, _, bi) -> match compare ai bi with | 0 -> let cp = fd(ap, bp) in DV.R(cp, r_d_d(a, b), ai) // ai = bi | -1 -> let cp = fd(a, bp) in DV.R(cp, r_c_d(a, b), bi) // ai < bi @@ -664,7 +719,7 @@ and DV = match b with | DV(bp) -> DM(ff(ap, bp)) | DVF(bp, bt, bi) -> let cp = fd(a, bp) in DMF(cp, df_db(cp, bp, bt), bi) - | DVR(bp, _, bi, _) -> let cp = fd(a, bp) in DM.R(cp, r_c_d(a, b), bi) + | DVR(bp, _, _, _, bi) -> let cp = fd(a, bp) in DM.R(cp, r_c_d(a, b), bi) | DVF(ap, at, ai) -> match b with | DV _ -> let cp = fd(ap, b) in DMF(cp, df_da(cp, ap, at), ai) @@ -673,12 +728,12 @@ and DV = | 0 -> let cp = fd(ap, bp) in DMF(cp, df_dab(cp, ap, at, bp, bt), ai) // ai = bi | -1 -> let cp = fd(a, bp) in DMF(cp, df_db(cp, bp, bt), bi) // ai < bi | _ -> let cp = fd(ap, b) in DMF(cp, df_da(cp, ap, at), ai) // ai > bi - | DVR(bp, _, bi, _) -> + | DVR(bp, _, _, _, bi) -> match compare ai bi with | -1 -> let cp = fd(a, bp) in DM.R(cp, r_c_d(a, b), bi) // ai < bi | 1 -> let cp = fd(ap, b) in DMF(cp, df_da(cp, ap, at), ai) // ai > bi | _ -> failwith "Forward and reverse AD cannot run on the same level." - | DVR(ap, _, ai, _) -> + | DVR(ap, _, _, _, ai) -> match b with | DV _ -> let cp = fd(ap, b) in DM.R(cp, r_d_c(a, b), ai) | DVF(bp, bt, bi) -> @@ -686,7 +741,7 @@ and DV = | -1 -> let cp = fd(a, bp) in DMF(cp, df_db(cp, bp, bt), bi) // ai < bi | 1 -> let cp = fd(ap, b) in DM.R(cp, r_d_c(a, b), ai) // ai > bi | _ -> failwith "Forward and reverse AD cannot run on the same level." 
- | DVR(bp, _, bi, _) -> + | DVR(bp, _, _, _, bi) -> match compare ai bi with | 0 -> let cp = fd(ap, bp) in DM.R(cp, r_d_d(a, b), ai) // ai = bi | -1 -> let cp = fd(a, bp) in DM.R(cp, r_c_d(a, b), bi) // ai < bi @@ -698,7 +753,7 @@ and DV = match b with | DV(bp) -> D(ff(ap, bp)) | DVF(bp, bt, bi) -> let cp = fd(a, bp) in DF(cp, df_db(cp, bp, bt), bi) - | DVR(bp, _, bi, _) -> D.R(fd(a, bp), r_c_d(a, b), bi) + | DVR(bp, _, _, _, bi) -> D.R(fd(a, bp), r_c_d(a, b), bi) | DVF(ap, at, ai) -> match b with | DV _ -> let cp = fd(ap, b) in DF(cp, df_da(cp, ap, at), ai) @@ -707,12 +762,12 @@ and DV = | 0 -> let cp = fd(ap, bp) in DF(cp, df_dab(cp, ap, at, bp, bt), ai) // ai = bi | -1 -> let cp = fd(a, bp) in DF(cp, df_db(cp, bp, bt), bi) // ai < bi | _ -> let cp = fd(ap, b) in DF(cp, df_da(cp, ap, at), ai) // ai > bi - | DVR(bp, _, bi, _) -> + | DVR(bp, _, _, _, bi) -> match compare ai bi with | -1 -> D.R(fd(a, bp), r_c_d(a, b), bi) // ai < bi | 1 -> let cp = fd(ap, b) in DF(cp, df_da(cp, ap, at), ai) // ai > bi | _ -> failwith "Forward and reverse AD cannot run on the same level." - | DVR(ap, _, ai, _) -> + | DVR(ap, _, _, _, ai) -> match b with | DV _ -> D.R(fd(ap, b), r_d_c(a, b), ai) | DVF(bp, bt, bi) -> @@ -720,7 +775,7 @@ and DV = | -1 -> let cp = fd(a, bp) in DF(cp, df_db(cp, bp, bt), bi) // ai < bi | 1 -> D.R(fd(ap, b), r_d_c(a, b), ai) // ai > bi | _ -> failwith "Forward and reverse AD cannot run on the same level." 
- | DVR(bp, _, bi, _) -> + | DVR(bp, _, _, _, bi) -> match compare ai bi with | 0 -> D.R(fd(ap, bp), r_d_d(a, b), ai) // ai = bi | -1 -> D.R(fd(a, bp), r_c_d(a, b), bi) // ai < bi @@ -732,7 +787,7 @@ and DV = match b with | D(bp) -> DV(ff(ap, bp)) | DF(bp, bt, bi) -> let cp = fd(a, bp) in DVF(cp, df_db(cp, bp, bt), bi) - | DR(bp, _, bi, _) -> let cp = fd(a, bp) in DV.R(cp, r_c_d(a, b), bi) + | DR(bp, _, _, _, bi) -> let cp = fd(a, bp) in DV.R(cp, r_c_d(a, b), bi) | DVF(ap, at, ai) -> match b with | D _ -> let cp = fd(ap, b) in DVF(cp, df_da(cp, ap, at), ai) @@ -741,12 +796,12 @@ and DV = | 0 -> let cp = fd(ap, bp) in DVF(cp, df_dab(cp, ap, at, bp, bt), ai) // ai = bi | -1 -> let cp = fd(a, bp) in DVF(cp, df_db(cp, bp, bt), bi) // ai < bi | _ -> let cp = fd(ap, b) in DVF(cp, df_da(cp, ap, at), ai) // ai > bi - | DR(bp, _, bi, _) -> + | DR(bp, _, _, _, bi) -> match compare ai bi with | -1 -> let cp = fd(a, bp) in DV.R(cp, r_c_d(a, b), bi) // ai < bi | 1 -> let cp = fd(ap, b) in DVF(cp, df_da(cp, ap, at), ai) // ai > bi | _ -> failwith "Forward and reverse AD cannot run on the same level." - | DVR(ap, _, ai, _) -> + | DVR(ap, _, _, _, ai) -> match b with | D _ -> let cp = fd(ap, b) in DV.R(cp, r_d_c(a, b), ai) | DF(bp, bt, bi) -> @@ -754,7 +809,7 @@ and DV = | -1 -> let cp = fd(a, bp) in DVF(cp, df_db(cp, bp, bt), bi) // ai < bi | 1 -> let cp = fd(ap, b) in DV.R(cp, r_d_c(a, b), ai) // ai > bi | _ -> failwith "Forward and reverse AD cannot run on the same level." 
- | DR(bp, _, bi, _) -> + | DR(bp, _, _, _, bi) -> match compare ai bi with | 0 -> let cp = fd(ap, bp) in DV.R(cp, r_d_d(a, b), ai) // ai = bi | -1 -> let cp = fd(a, bp) in DV.R(cp, r_c_d(a, b), bi) // ai < bi @@ -767,7 +822,7 @@ and DV = match b with | DV(bp) -> DV(ff(ap, bp)) | DVF(bp, bt, bi) -> let cp = fd(a, bp) in DVF(cp, df_db(cp, bp, bt), bi) - | DVR(bp, _, bi, _) -> let cp = fd(a, bp) in DV.R(cp, r_c_d(a, b), bi) + | DVR(bp, _, _, _, bi) -> let cp = fd(a, bp) in DV.R(cp, r_c_d(a, b), bi) | DF(ap, at, ai) -> match b with | DV _ -> let cp = fd(ap, b) in DVF(cp, df_da(cp, ap, at), ai) @@ -776,12 +831,12 @@ and DV = | 0 -> let cp = fd(ap, bp) in DVF(cp, df_dab(cp, ap, at, bp, bt), ai) // ai = bi | -1 -> let cp = fd(a, bp) in DVF(cp, df_db(cp, bp, bt), bi) // ai < bi | _ -> let cp = fd(ap, b) in DVF(cp, df_da(cp, ap, at), ai) // ai > bi - | DVR(bp, _, bi, _) -> + | DVR(bp, _, _, _, bi) -> match compare ai bi with | -1 -> let cp = fd(a, bp) in DV.R(cp, r_c_d(a, b), bi) // ai < bi | 1 -> let cp = fd(ap, b) in DVF(cp, df_da(cp, ap, at), ai) // ai > bi | _ -> failwith "Forward and reverse AD cannot run on the same level." - | DR(ap, _, ai, _) -> + | DR(ap, _, _, _, ai) -> match b with | DV _ -> let cp = fd(ap, b) in DV.R(cp, r_d_c(a, b), ai) | DVF(bp, bt, bi) -> @@ -789,7 +844,7 @@ and DV = | -1 -> let cp = fd(a, bp) in DVF(cp, df_db(cp, bp, bt), bi) // ai < bi | 1 -> let cp = fd(ap, b) in DV.R(cp, r_d_c(a, b), ai) // ai > bi | _ -> failwith "Forward and reverse AD cannot run on the same level." 
- | DVR(bp, _, bi, _) -> + | DVR(bp, _, _, _, bi) -> match compare ai bi with | 0 -> let cp = fd(ap, bp) in DV.R(cp, r_d_d(a, b), ai) // ai = bi | -1 -> let cp = fd(a, bp) in DV.R(cp, r_c_d(a, b), bi) // ai < bi @@ -1407,7 +1462,7 @@ and DM = /// Primal, tangent, layer tag (for forward mode) | DMF of primal: DM * tanget: DM * tag: uint32 /// Primal, parent, layer tag (for reverse mode) - | DMR of primal: DM * parentOperation: TraceOp * tag: uint32 * uniq: int32 + | DMR of primal: DM * adjoint: (DM ref) * parentOperation: TraceOp * fanOutCounter: (uint32 ref) * tag: uint32 interface dobj @@ -1416,7 +1471,7 @@ and DM = match d with | DM(_) -> d | DMF(ap, _, _) -> ap - | DMR(ap, _, _, _) -> ap + | DMR(ap,_,_,_,_) -> ap /// Deepest primal value of this DM member d.PD = @@ -1424,47 +1479,73 @@ and DM = match x with | DM(_) -> x | DMF(xp, _, _) -> prec xp - | DMR(xp, _, _, _) -> prec xp + | DMR(xp,_,_,_,_) -> prec xp prec d /// Tangent value of this DM member d.T = match d with - | DM(_) -> DM.ZeroMN d.Rows d.Cols + | DM _ -> DM.ZeroMN d.Rows d.Cols | DMF(_, at, _) -> at | DMR _ -> failwith "Cannot get tangent value of DMR." + /// Adjoint value of this DM + member d.A + with get() : DM = + match d with + | DM _ -> DM.ZeroMN d.Rows d.Cols + | DMF(_,_,_) -> failwith "Cannot get adjoint value of DMF." + | DMR(_,a,_,_,_) -> !a + and set(v: DM) = + match d with + | DM _ -> () + | DMF(_,_,_) -> failwith "Cannot set adjoint value of DMF." + | DMR(_,a,_,_,_) -> a := v + + /// Fan-out value of this DM + member d.F + with get() = + match d with + | DM _ -> failwith "Cannot get fan-out value of DM." + | DMF(_,_,_) -> failwith "Cannot get fan-out value of DMF." + | DMR(_,_,_,f,_) -> !f + and set(v) = + match d with + | DM(_) -> failwith "Cannot set fan-out value of DM." + | DMF(_,_,_) -> failwith "Cannot set fan-out value of DMF." 
+ | DMR(_,_,_,f,_) -> f := v + member d.GetForward(t:DM, i:uint32) = DMF(d, t, i) member d.GetReverse(i:uint32) = DM.R(d, Noop, i) /// Make a reverse node - static member R(cp, op, ai) = DMR(cp, op, ai, UniqueTagger.Next()) + static member R(cp,op,ai) = DMR(cp, ref (DM.ZeroMN cp.Rows cp.Cols), op, ref 0u, ai) member d.Length = match d with | DM(ap) -> ap.Length | DMF(ap, _, _) -> ap.Length - | DMR(ap, _, _, _) -> ap.Length + | DMR(ap,_,_,_,_) -> ap.Length member d.Rows = match d with | DM(ap) -> Array2D.length1 ap | DMF(ap, _, _) -> ap.Rows - | DMR(ap, _, _, _) -> ap.Rows + | DMR(ap, _, _, _, _) -> ap.Rows member d.Cols = match d with | DM(ap) -> Array2D.length2 ap | DMF(ap, _, _) -> ap.Cols - | DMR(ap, _, _, _) -> ap.Cols + | DMR(ap, _, _, _, _) -> ap.Cols member d.Item with get (i, j) = match d with | DM(ap) -> D(ap.[i, j]) | DMF(ap, at, ai) -> DF(ap.[i, j], at.[i, j], ai) - | DMR(ap, _, ai, _) -> D.R(ap.[i, j], Item_DM(d, i, j), ai) + | DMR(ap, _, _, _, ai) -> D.R(ap.[i, j], Item_DM(d, i, j), ai) member d.GetSlice(rowStart, rowFinish, colStart, colFinish) = let rowStart = defaultArg rowStart 0 @@ -1474,7 +1555,7 @@ and DM = match d with | DM(ap) -> DM(ap.[rowStart..rowFinish, colStart..colFinish]) | DMF(ap, at, ai) -> DMF(ap.[rowStart..rowFinish, colStart..colFinish], at.[rowStart..rowFinish, colStart..colFinish], ai) - | DMR(ap, _, ai, _) -> let cp = ap.[rowStart..rowFinish, colStart..colFinish] in DM.R(cp, Slice_DM(d, rowStart, rowFinish), ai) + | DMR(ap, _, _, _, ai) -> let cp = ap.[rowStart..rowFinish, colStart..colFinish] in DM.R(cp, Slice_DM(d, rowStart, colStart), ai) member d.GetSlice(row, colStart, colFinish) = let colStart = defaultArg colStart 0 @@ -1482,7 +1563,7 @@ and DM = match d with | DM(ap) -> DV(ap.[row, colStart..colFinish]) | DMF(ap, at, ai) -> DVF(ap.[row, colStart..colFinish], at.[row, colStart..colFinish], ai) - | DMR(ap, _, ai, _) -> let cp = ap.[row, colStart..colFinish] in DV.R(cp, SliceRow_DM(d, row, colStart), ai) + | DMR(ap, _, _, 
_, ai) -> let cp = ap.[row, colStart..colFinish] in DV.R(cp, SliceRow_DM(d, row, colStart), ai) member d.GetSlice(rowStart, rowFinish, col) = let rowStart = defaultArg rowStart 0 @@ -1490,7 +1571,7 @@ and DM = match d with | DM(ap) -> DV(ap.[rowStart..rowFinish, col]) | DMF(ap, at, ai) -> DVF(ap.[rowStart..rowFinish, col], at.[rowStart..rowFinish, col], ai) - | DMR(ap, _, ai, _) -> let cp = ap.[rowStart..rowFinish, col] in DV.R(cp, SliceCol_DM(d, rowStart, col), ai) + | DMR(ap, _, _, _, ai) -> let cp = ap.[rowStart..rowFinish, col] in DV.R(cp, SliceCol_DM(d, rowStart, col), ai) member d.GetRows() = seq {for i = 0 to d.Rows - 1 do yield d.[i, *]} @@ -1546,7 +1627,7 @@ and DM = match x with | DM(p) -> p | DMF(xp, _, _) -> prec xp - | DMR(xp, _, _, _) -> prec xp + | DMR(xp, _, _, _, _) -> prec xp prec d static member op_Explicit(d:number[, ]) = DM(d) @@ -1559,7 +1640,7 @@ and DM = let ap = a |> Array2D.map (fun x -> x.P) let at = a |> Array2D.map (fun x -> x.T) DMF(DM.OfArray2D(ap), DM.OfArray2D(at), ai) - | DR(_, _, ai, _) -> + | DR(_, _, _, _, ai) -> let ap = a |> Array2D.map (fun x -> x.P) let cp = DM.OfArray2D(ap) in DM.R(cp, Make_DM_ofDs(a), ai) @@ -1577,7 +1658,7 @@ and DM = let ap = s |> Seq.map (fun x -> x.P) let at = s |> Seq.map (fun x -> x.T) DMF(DM.OfRows(ap), DM.OfRows(at), ai) - | DVR(_, _, ai, _) -> + | DVR(_, _, _, _, ai) -> let ap = s |> Seq.map (fun x -> x.P) let cp = DM.OfRows(ap) in DM.R(cp, Make_DMRows_ofDVs(s |> Seq.toArray), ai) @@ -1585,33 +1666,33 @@ and DM = match a with | DV(ap) -> DM(Backend.RepeatReshapeCopy_V_MRows(m, ap)) | DVF(ap, at, ai) -> DMF(DM.OfRows(m, ap), DM.OfRows(m, at), ai) - | DVR(ap, _, ai, _) -> + | DVR(ap, _, _, _, ai) -> let cp = DM.OfRows(m, ap) in DM.R(cp, Make_DMRows_ofDV(a), ai) static member OfCols (n:int, a:DV) = match a with | DV(ap) -> DM(Backend.RepeatReshapeCopy_V_MCols(n, ap)) | DVF(ap, at, ai) -> DMF(DM.OfCols(n, ap), DM.OfCols(n, at), ai) - | DVR(ap, _, ai, _) -> + | DVR(ap, _, _, _, ai) -> let cp = 
DM.OfCols(n, ap) in DM.R(cp, Make_DMCols_ofDV(a), ai) static member inline Op_DM_DM (a, ff, fd, df, r) = match a with | DM(ap) -> DM(ff(ap)) | DMF(ap, at, ai) -> let cp = fd(ap) in DMF(cp, df(cp, ap, at), ai) - | DMR(ap, _, ai, _) -> let cp = fd(ap) in DM.R(cp, r(a), ai) + | DMR(ap, _, _, _, ai) -> let cp = fd(ap) in DM.R(cp, r(a), ai) static member inline Op_DM_DV (a, ff, fd, df, r) = match a with | DM(ap) -> DV(ff(ap)) | DMF(ap, at, ai) -> let cp = fd(ap) in DVF(cp, df(cp, ap, at), ai) - | DMR(ap, _, ai, _) -> let cp = fd(ap) in DV.R(cp, r(a), ai) + | DMR(ap, _, _, _, ai) -> let cp = fd(ap) in DV.R(cp, r(a), ai) static member inline Op_DM_D (a, ff, fd, df, r) = match a with | DM(ap) -> D(ff(ap)) | DMF(ap, at, ai) -> let cp = fd(ap) in DF(cp, df(cp, ap, at), ai) - | DMR(ap, _, ai, _) -> let cp = fd(ap) in D.R(cp, r(a), ai) + | DMR(ap, _, _, _, ai) -> let cp = fd(ap) in D.R(cp, r(a), ai) static member inline Op_DM_DM_DM (a, b, ff, fd, df_da, df_db, df_dab, r_d_d, r_d_c, r_c_d) = match a with @@ -1619,7 +1700,7 @@ and DM = match b with | DM(bp) -> DM(ff(ap, bp)) | DMF(bp, bt, bi) -> let cp = fd(a, bp) in DMF(cp, df_db(cp, bp, bt), bi) - | DMR(bp, _, bi, _) -> let cp = fd(a, bp) in DM.R(cp, r_c_d(a, b), bi) + | DMR(bp, _, _, _, bi) -> let cp = fd(a, bp) in DM.R(cp, r_c_d(a, b), bi) | DMF(ap, at, ai) -> match b with | DM(_) -> let cp = fd(ap, b) in DMF(cp, df_da(cp, ap, at), ai) @@ -1628,12 +1709,12 @@ and DM = | 0 -> let cp = fd(ap, bp) in DMF(cp, df_dab(cp, ap, at, bp, bt), ai) // ai = bi | -1 -> let cp = fd(a, bp) in DMF(cp, df_db(cp, bp, bt), bi) // ai < bi | _ -> let cp = fd(ap, b) in DMF(cp, df_da(cp, ap, at), ai) // ai > bi - | DMR(bp, _, bi, _) -> + | DMR(bp, _, _, _, bi) -> match compare ai bi with | -1 -> let cp = fd(a, bp) in DM.R(cp, r_c_d(a, b), bi) // ai < bi | 1 -> let cp = fd(ap, b) in DMF(cp, df_da(cp, ap, at), ai) // ai > bi | _ -> failwith "Forward and reverse AD cannot run on the same level." 
- | DMR(ap, _, ai, _) -> + | DMR(ap, _, _, _, ai) -> match b with | DM(_) -> let cp = fd(ap, b) in DM.R(cp, r_d_c(a, b), ai) | DMF(bp, bt, bi) -> @@ -1641,7 +1722,7 @@ and DM = | -1 -> let cp = fd(a, bp) in DMF(cp, df_db(cp, bp, bt), bi) // ai < bi | 1 -> let cp = fd(ap, b) in DM.R(cp, r_d_c(a, b), ai) // ai > bi | _ -> failwith "Forward and reverse AD cannot run on the same level." - | DMR(bp, _, bi, _) -> + | DMR(bp, _, _, _, bi) -> match compare ai bi with | 0 -> let cp = fd(ap, bp) in DM.R(cp, r_d_d(a, b), ai) // ai = bi | -1 -> let cp = fd(a, bp) in DM.R(cp, r_c_d(a, b), bi) // ai < bi @@ -1653,7 +1734,7 @@ and DM = match b with | D(bp) -> DM(ff(ap, bp)) | DF(bp, bt, bi) -> let cp = fd(a, bp) in DMF(cp, df_db(cp, bp, bt), bi) - | DR(bp, _, bi, _) -> let cp = fd(a, bp) in DM.R(cp, r_c_d(a, b), bi) + | DR(bp, _, _, _, bi) -> let cp = fd(a, bp) in DM.R(cp, r_c_d(a, b), bi) | DMF(ap, at, ai) -> match b with | D _ -> let cp = fd(ap, b) in DMF(cp, df_da(cp, ap, at), ai) @@ -1662,12 +1743,12 @@ and DM = | 0 -> let cp = fd(ap, bp) in DMF(cp, df_dab(cp, ap, at, bp, bt), ai) // ai = bi | -1 -> let cp = fd(a, bp) in DMF(cp, df_db(cp, bp, bt), bi) // ai < bi | _ -> let cp = fd(ap, b) in DMF(cp, df_da(cp, ap, at), ai) // ai > bi - | DR(bp, _, bi, _) -> + | DR(bp, _, _, _, bi) -> match compare ai bi with | -1 -> let cp = fd(a, bp) in DM.R(cp, r_c_d(a, b), bi) // ai < bi | 1 -> let cp = fd(ap, b) in DMF(cp, df_da(cp, ap, at), ai) // ai > bi | _ -> failwith "Forward and reverse AD cannot run on the same level." - | DMR(ap, _, ai, _) -> + | DMR(ap, _, _, _, ai) -> match b with | D _ -> let cp = fd(ap, b) in DM.R(cp, r_d_c(a, b), ai) | DF(bp, bt, bi) -> @@ -1675,7 +1756,7 @@ and DM = | -1 -> let cp = fd(a, bp) in DMF(cp, df_db(cp, bp, bt), bi) // ai < bi | 1 -> let cp = fd(ap, b) in DM.R(cp, r_d_c(a, b), ai) // ai > bi | _ -> failwith "Forward and reverse AD cannot run on the same level." 
- | DR(bp, _, bi, _) -> + | DR(bp, _, _, _, bi) -> match compare ai bi with | 0 -> let cp = fd(ap, bp) in DM.R(cp, r_d_d(a, b), ai) // ai = bi | -1 -> let cp = fd(a, bp) in DM.R(cp, r_c_d(a, b), bi) // ai < bi @@ -1687,7 +1768,7 @@ and DM = match b with | DM(bp) -> DM(ff(ap, bp)) | DMF(bp, bt, bi) -> let cp = fd(a, bp) in DMF(cp, df_db(cp, bp, bt), bi) - | DMR(bp, _, bi, _) -> let cp = fd(a, bp) in DM.R(cp, r_c_d(a, b), bi) + | DMR(bp, _, _, _, bi) -> let cp = fd(a, bp) in DM.R(cp, r_c_d(a, b), bi) | DF(ap, at, ai) -> match b with | DM(_) -> let cp = fd(ap, b) in DMF(cp, df_da(cp, ap, at), ai) @@ -1696,12 +1777,12 @@ and DM = | 0 -> let cp = fd(ap, bp) in DMF(cp, df_dab(cp, ap, at, bp, bt), ai) // ai = bi | -1 -> let cp = fd(a, bp) in DMF(cp, df_db(cp, bp, bt), bi) // ai < bi | _ -> let cp = fd(ap, b) in DMF(cp, df_da(cp, ap, at), ai) // ai > bi - | DMR(bp, _, bi, _) -> + | DMR(bp, _, _, _, bi) -> match compare ai bi with | -1 -> let cp = fd(a, bp) in DM.R(cp, r_c_d(a, b), bi) // ai < bi | 1 -> let cp = fd(ap, b) in DMF(cp, df_da(cp, ap, at), ai) // ai > bi | _ -> failwith "Forward and reverse AD cannot run on the same level." - | DR(ap, _, ai, _) -> + | DR(ap, _, _, _, ai) -> match b with | DM(_) -> let cp = fd(ap, b) in DM.R(cp, r_d_c(a, b), ai) | DMF(bp, bt, bi) -> @@ -1709,7 +1790,7 @@ and DM = | -1 -> let cp = fd(a, bp) in DMF(cp, df_db(cp, bp, bt), bi) // ai < bi | 1 -> let cp = fd(ap, b) in DM.R(cp, r_d_c(a, b), ai) // ai > bi | _ -> failwith "Forward and reverse AD cannot run on the same level." 
- | DMR(bp, _, bi, _) -> + | DMR(bp, _, _, _, bi) -> match compare ai bi with | 0 -> let cp = fd(ap, bp) in DM.R(cp, r_d_d(a, b), ai) // ai = bi | -1 -> let cp = fd(a, bp) in DM.R(cp, r_c_d(a, b), bi) // ai < bi @@ -1721,7 +1802,7 @@ and DM = match b with | DV(bp) -> DV(ff(ap, bp)) | DVF(bp, bt, bi) -> let cp = fd(a, bp) in DVF(cp, df_db(cp, bp, bt), bi) - | DVR(bp, _, bi, _) -> let cp = fd(a, bp) in DV.R(cp, r_c_d(a, b), bi) + | DVR(bp, _, _, _, bi) -> let cp = fd(a, bp) in DV.R(cp, r_c_d(a, b), bi) | DMF(ap, at, ai) -> match b with | DV _ -> let cp = fd(ap, b) in DVF(cp, df_da(cp, ap, at), ai) @@ -1730,12 +1811,12 @@ and DM = | 0 -> let cp = fd(ap, bp) in DVF(cp, df_dab(cp, ap, at, bp, bt), ai) // ai = bi | -1 -> let cp = fd(a, bp) in DVF(cp, df_db(cp, bp, bt), bi) // ai < bi | _ -> let cp = fd(ap, b) in DVF(cp, df_da(cp, ap, at), ai) // ai > bi - | DVR(bp, _, bi, _) -> + | DVR(bp, _, _, _, bi) -> match compare ai bi with | -1 -> let cp = fd(a, bp) in DV.R(cp, r_c_d(a, b), bi) // ai < bi | 1 -> let cp = fd(ap, b) in DVF(cp, df_da(cp, ap, at), ai) // ai > bi | _ -> failwith "Forward and reverse AD cannot run on the same level." - | DMR(ap, _, ai, _) -> + | DMR(ap, _, _, _, ai) -> match b with | DV _ -> let cp = fd(ap, b) in DV.R(cp, r_d_c(a, b), ai) | DVF(bp, bt, bi) -> @@ -1743,7 +1824,7 @@ and DM = | -1 -> let cp = fd(a, bp) in DVF(cp, df_db(cp, bp, bt), bi) // ai < bi | 1 -> let cp = fd(ap, b) in DV.R(cp, r_d_c(a, b), ai) // ai > bi | _ -> failwith "Forward and reverse AD cannot run on the same level." 
- | DVR(bp, _, bi, _) -> + | DVR(bp, _, _, _, bi) -> match compare ai bi with | 0 -> let cp = fd(ap, bp) in DV.R(cp, r_d_d(a, b), ai) // ai = bi | -1 -> let cp = fd(a, bp) in DV.R(cp, r_c_d(a, b), bi) // ai < bi @@ -1755,7 +1836,7 @@ and DM = match b with | DM(bp) -> DV(ff(ap, bp)) | DMF(bp, bt, bi) -> let cp = fd(a, bp) in DVF(cp, df_db(cp, bp, bt), bi) - | DMR(bp, _, bi, _) -> let cp = fd(a, bp) in DV.R(cp, r_c_d(a, b), bi) + | DMR(bp, _, _, _, bi) -> let cp = fd(a, bp) in DV.R(cp, r_c_d(a, b), bi) | DVF(ap, at, ai) -> match b with | DM(_) -> let cp = fd(ap, b) in DVF(cp, df_da(cp, ap, at), ai) @@ -1764,12 +1845,12 @@ and DM = | 0 -> let cp = fd(ap, bp) in DVF(cp, df_dab(cp, ap, at, bp, bt), ai) // ai = bi | -1 -> let cp = fd(a, bp) in DVF(cp, df_db(cp, bp, bt), bi) // ai < bi | _ -> let cp = fd(ap, b) in DVF(cp, df_da(cp, ap, at), ai) // ai > bi - | DMR(bp, _, bi, _) -> + | DMR(bp, _, _, _, bi) -> match compare ai bi with | -1 -> let cp = fd(a, bp) in DV.R(cp, r_c_d(a, b), bi) // ai < bi | 1 -> let cp = fd(ap, b) in DVF(cp, df_da(cp, ap, at), ai) // ai > bi | _ -> failwith "Forward and reverse AD cannot run on the same level." - | DVR(ap, _, ai, _) -> + | DVR(ap, _, _, _, ai) -> match b with | DM(_) -> let cp = fd(ap, b) in DV.R(cp, r_d_c(a, b), ai) | DMF(bp, bt, bi) -> @@ -1777,7 +1858,7 @@ and DM = | -1 -> let cp = fd(a, bp) in DVF(cp, df_db(cp, bp, bt), bi) // ai < bi | 1 -> let cp = fd(ap, b) in DV.R(cp, r_d_c(a, b), ai) // ai > bi | _ -> failwith "Forward and reverse AD cannot run on the same level." 
- | DMR(bp, _, bi, _) -> + | DMR(bp, _, _, _, bi) -> match compare ai bi with | 0 -> let cp = fd(ap, bp) in DV.R(cp, r_d_d(a, b), ai) // ai = bi | -1 -> let cp = fd(a, bp) in DV.R(cp, r_c_d(a, b), bi) // ai < bi @@ -1789,7 +1870,7 @@ and DM = match b with | DV(bp) -> DM(ff(ap, bp)) | DVF(bp, bt, bi) -> let cp = fd(a, bp) in DMF(cp, df_db(cp, bp, bt), bi) - | DVR(bp, _, bi, _) -> let cp = fd(a, bp) in DM.R(cp, r_c_d(a, b), bi) + | DVR(bp, _, _, _, bi) -> let cp = fd(a, bp) in DM.R(cp, r_c_d(a, b), bi) | DMF(ap, at, ai) -> match b with | DV _ -> let cp = fd(ap, b) in DMF(cp, df_da(cp, ap, at), ai) @@ -1798,12 +1879,12 @@ and DM = | 0 -> let cp = fd(ap, bp) in DMF(cp, df_dab(cp, ap, at, bp, bt), ai) // ai = bi | -1 -> let cp = fd(a, bp) in DMF(cp, df_db(cp, bp, bt), bi) // ai < bi | _ -> let cp = fd(ap, b) in DMF(cp, df_da(cp, ap, at), ai) // ai > bi - | DVR(bp, _, bi, _) -> + | DVR(bp, _, _, _, bi) -> match compare ai bi with | -1 -> let cp = fd(a, bp) in DM.R(cp, r_c_d(a, b), bi) // ai < bi | 1 -> let cp = fd(ap, b) in DMF(cp, df_da(cp, ap, at), ai) // ai > bi | _ -> failwith "Forward and reverse AD cannot run on the same level." - | DMR(ap, _, ai, _) -> + | DMR(ap, _, _, _, ai) -> match b with | DV _ -> let cp = fd(ap, b) in DM.R(cp, r_d_c(a, b), ai) | DVF(bp, bt, bi) -> @@ -1811,7 +1892,7 @@ and DM = | -1 -> let cp = fd(a, bp) in DMF(cp, df_db(cp, bp, bt), bi) // ai < bi | 1 -> let cp = fd(ap, b) in DM.R(cp, r_d_c(a, b), ai) // ai > bi | _ -> failwith "Forward and reverse AD cannot run on the same level." 
- | DVR(bp, _, bi, _) -> + | DVR(bp, _, _, _, bi) -> match compare ai bi with | 0 -> let cp = fd(ap, bp) in DM.R(cp, r_d_d(a, b), ai) // ai = bi | -1 -> let cp = fd(a, bp) in DM.R(cp, r_c_d(a, b), bi) // ai < bi @@ -1823,7 +1904,7 @@ and DM = match b with | DM(bp) -> DM(ff(ap, bp)) | DMF(bp, bt, bi) -> let cp = fd(a, bp) in DMF(cp, df_db(cp, bp, bt), bi) - | DMR(bp, _, bi, _) -> let cp = fd(a, bp) in DM.R(cp, r_c_d(a, b), bi) + | DMR(bp, _, _, _, bi) -> let cp = fd(a, bp) in DM.R(cp, r_c_d(a, b), bi) | DVF(ap, at, ai) -> match b with | DM(_) -> let cp = fd(ap, b) in DMF(cp, df_da(cp, ap, at), ai) @@ -1832,12 +1913,12 @@ and DM = | 0 -> let cp = fd(ap, bp) in DMF(cp, df_dab(cp, ap, at, bp, bt), ai) // ai = bi | -1 -> let cp = fd(a, bp) in DMF(cp, df_db(cp, bp, bt), bi) // ai < bi | _ -> let cp = fd(ap, b) in DMF(cp, df_da(cp, ap, at), ai) // ai > bi - | DMR(bp, _, bi, _) -> + | DMR(bp, _, _, _, bi) -> match compare ai bi with | -1 -> let cp = fd(a, bp) in DM.R(cp, r_c_d(a, b), bi) // ai < bi | 1 -> let cp = fd(ap, b) in DMF(cp, df_da(cp, ap, at), ai) // ai > bi | _ -> failwith "Forward and reverse AD cannot run on the same level." - | DVR(ap, _, ai, _) -> + | DVR(ap, _, _, _, ai) -> match b with | DM(_) -> let cp = fd(ap, b) in DM.R(cp, r_d_c(a, b), ai) | DMF(bp, bt, bi) -> @@ -1845,7 +1926,7 @@ and DM = | -1 -> let cp = fd(a, bp) in DMF(cp, df_db(cp, bp, bt), bi) // ai < bi | 1 -> let cp = fd(ap, b) in DM.R(cp, r_d_c(a, b), ai) // ai > bi | _ -> failwith "Forward and reverse AD cannot run on the same level." - | DMR(bp, _, bi, _) -> + | DMR(bp, _, _, _, bi) -> match compare ai bi with | 0 -> let cp = fd(ap, bp) in DM.R(cp, r_d_d(a, b), ai) // ai = bi | -1 -> let cp = fd(a, bp) in DM.R(cp, r_c_d(a, b), bi) // ai < bi @@ -2751,6 +2832,7 @@ and TraceOp = /// A constraint used to ensure the evaluation stack is only over D, DV or DM and dobj = interface end +let bxd (x : dobj) = x /// Functional-oriented operations on vectors. 
Implementing functionality similar to FSharp.Collections.Array. [] @@ -3116,123 +3198,6 @@ module DM = let inline visualize (m:DM) = m.Visualize() let inline visualizeAsDV (m:DM) = DM.ReshapeToDV(m).Visualize() - -// Scripts for adjusting the adjoint -type Delta = - | X of D - //| XNeg of DeltaV - interface delta -and DeltaV = - | XV of DV - //| XNegV of DeltaV - interface delta -and DeltaM = - | XM of DM - | XNegM of DeltaM - interface delta -and delta = interface end - -/// Represents the computed adjoints for reverse AD. This is a table indexed by node ID. -/// The table is destructively updated as the adjoints are accumulated. -type Adjoints() = - let dict = Dictionary() - - let rec eval d = - match d with - | X d -> d - and evalV d = - match d with - | XV d -> d - and evalM d = - match d with - | XM v -> v - | XNegM v -> -(evalM v) - - member internal __.GetD(uniq: int) = dict.[uniq] :?> D - member internal __.SetD(uniq:int, v:D) = dict.[uniq] <- v - member internal __.GetDV(uniq: int) = dict.[uniq] :?> DV - member internal __.SetDV(uniq:int, v:DV) = dict.[uniq] <- v - member internal __.GetDM(uniq: int) = dict.[uniq] :?> DM - member internal __.SetDM(uniq:int, v:DM) = dict.[uniq] <- v - - // adj <- adj + interp(delta) - member internal __.ApplyDelta(uniq: int, x:Delta) = - let adj = dict.[uniq] :?> D - let res = eval x + adj - dict.[uniq] <- res - res - - // adj <- adj + interp(delta) - member internal __.ApplyDeltaV(uniq: int, x:DeltaV) = - let adj = dict.[uniq] :?> DV - match adj,x with - | DV adjv, XV (DV xv) -> - Backend.Add_V_V_Inplace(xv, adjv) - adj - | _ -> - let res = DV.Add_V_V_Inplace(evalV x,adj) - dict.[uniq] <- res - res - - // adj <- adj + interp(delta) - member internal __.ApplyDeltaM(uniq: int, x:DeltaM) = - let adj = dict.[uniq] :?> DM - match adj,x with - | DM adjm, XM (DM xm) -> - Backend.AlphaAdd_M_M_Inplace(N.one, xm, adjm) - adj - | DM adjm, XNegM (XM (DM xm)) -> - // TODO: also perform the inplace update in the case where adj is not 
"DM adj" - // However this needs care. - Backend.AlphaAdd_M_M_Inplace(N.minus1, xm, adjm) - adj - | _ -> - let adj = DM.Add_M_M_Inplace(evalM x,adj) - dict.[uniq] <- adj - adj - - /// Lookup the adjoint for a value - member this.Item - with get (d:D) : D = - match d with - | D _ -> D.Zero - | DF _ -> failwith "Cannot get adjoint value of DF. Use makeReverse on this node when composing the computation." - | DR (_, _, _, uniq) -> this.GetD(uniq) - and set (d:D) (v : D) = - match d with - | D _ -> () - | DF _ -> failwith "Cannot set adjoint value of DF. Use makeReverse on this node when composing the computation." - | DR (_, _, _, uniq) -> this.SetD(uniq, v) - - /// Lookup the adjoint for a vector - member this.Item - with get (d:DV) : DV = - match d with - | DV _ -> DV.ZeroN d.Length - | DVF _ -> failwith "Cannot get adjoint value of DVF. Use makeReverse on this node when composing the computation." - | DVR (_, _, _, uniq) -> this.GetDV(uniq) - and set (d:DV) (v : DV) = - match d with - | DV _ -> () - | DVF _ -> failwith "Cannot set adjoint value of DVF. Use makeReverse on this node when composing the computation." - | DVR (_, _, _, uniq) -> this.SetDV(uniq, v) - - /// Lookup the adjoint for a matrix - member this.Item - with get (d:DM) : DM = - match d with - | DM(_) -> DM.ZeroMN d.Rows d.Cols - | DMF _ -> failwith "Cannot get adjoint value of DMF. Use makeReverse on this node when composing the computation." - | DMR (_, _, _, uniq) -> this.GetDM(uniq) - and set (d:DM) (v : DM) = - match d with - | DM _ -> () - | DMF _ -> failwith "Cannot set adjoint value of DMF. Use makeReverse on this node when composing the computation." 
- | DMR (_, _, _, uniq) -> this.SetDM(uniq, v) - - override __.ToString() = sprintf "(%d computed adjoints)" dict.Count - - /// D, DV, DM operations (automatically opened) [] module DOps = @@ -3274,31 +3239,18 @@ module DOps = let inline tangent (d:^a when ^a : (member T : ^a)) = (^a : (member T : ^a) d) /// Get the adjoint value of `d` - let adjoint (adjoints: Adjoints) (d : 'T :> dobj) : 'T = + let adjoint (d : 'T :> dobj) : 'T = match box d with - | :? D as d -> adjoints.[d] |> box :?> 'T - | :? DV as d -> adjoints.[d] |> box :?> 'T - | :? DM as d -> adjoints.[d] |> box :?> 'T + | :? D as d -> box d.A :?> 'T + | :? DV as d -> box d.A :?> 'T + | :? DM as d -> box d.A :?> 'T | _ -> failwith "invalid dobj type" /// Get the primal and tangent values of `d`, as a tuple let inline primalTangent d = d |> primal, d |> tangent - - type Fanouts = Dictionary - let incrementFanout (fanouts: Fanouts) d = - match fanouts.TryGetValue(d) with - | true,fanout -> - fanouts.[d] <- fanout + 1u - fanout + 1u - | _ -> - fanouts.[d] <- 1u - 1u - /// Resets the adjoints of all the values in the evaluation trace of `d`, preparing for a new reverse propagation - let reverseReset (adjoints: Adjoints) (d:dobj) = - let bxd (x : dobj) = x - let fanouts = Fanouts() + let reverseReset (d:dobj) = // Note, this uses an explicit worklist over (D|DV|DM) to make it tail-recursive let rec resetRec (ds:dobj list) = match ds with @@ -3307,10 +3259,10 @@ module DOps = match d with | :? D as d -> match d with - | DR(_, o, _, uniq) -> - let fanout = incrementFanout fanouts uniq - if fanout = 1u then - adjoints.SetD(uniq,D.Zero) + | DR(_,_,o,_,_) -> + d.A <- D.Zero + d.F <- d.F + 1u + if d.F = 1u then match o with | Add_D_D(a, b) -> resetRec (bxd a :: bxd b :: t) | Add_D_DCons(a) -> resetRec (bxd a :: t) @@ -3366,10 +3318,10 @@ module DOps = | _ -> resetRec t | :? 
DV as d -> match d with - | DVR(_, o, _, uniq) -> - let fanout = incrementFanout fanouts uniq - if fanout = 1u then - adjoints.SetDV(uniq,DV.ZeroN d.Length) + | DVR(_,_,o,_,_) -> + d.A <- DV.ZeroN d.Length + d.F <- d.F + 1u + if d.F = 1u then match o with | Add_DV_DV(a, b) -> resetRec (bxd a :: bxd b :: t) | Add_DV_DVCons(a) -> resetRec (bxd a :: t) @@ -3468,10 +3420,10 @@ module DOps = | _ -> resetRec t | :? DM as d -> match d with - | DMR(_, o, _, uniq) -> - let fanout = incrementFanout fanouts uniq - if fanout = 1u then - adjoints.SetDM(uniq,DM.ZeroMN d.Rows d.Cols) + | DMR(_,_,o,_,_) -> + d.A <- DM.ZeroMN d.Rows d.Cols + d.F <- d.F + 1u + if d.F = 1u then match o with | Add_DM_DM(a, b) -> resetRec (bxd a :: bxd b :: t) | Add_DM_DMCons(a) -> resetRec (bxd a :: t) @@ -3480,6 +3432,7 @@ module DOps = | Sub_DMCons_DM(a) -> resetRec (bxd a :: t) | Mul_DM_DM(a, b) -> resetRec (bxd a :: bxd b :: t) | Mul_DM_DMCons(a, _) -> resetRec (bxd a :: t) + | Mul_DMCons_DM(_, b) -> resetRec (bxd b :: t) | Mul_Had_DM_DM(a, b) -> resetRec (bxd a :: bxd b :: t) | Mul_Had_DM_DMCons(a, _) -> resetRec (bxd a :: t) | Mul_DM_D(a, b) -> resetRec (bxd a :: bxd b :: t) @@ -3571,89 +3524,79 @@ module DOps = | _ -> resetRec t | _ -> resetRec t resetRec [d] - fanouts /// Propagates the adjoint `v` backwards through the evaluation trace of `d`. The adjoints in the trace are reset before the push. 
- let rec reverseProp (adjoints: Adjoints) (v:dobj) (d:dobj) = - let fanouts = reverseReset adjoints d - let inline bxd (x : dobj) = x - let inline bxdelta (x : delta) = x - - let inline bd (v: Delta) (d:D) = bxdelta v, bxd d - let inline bdv (v: DeltaV) (d:DV) = bxdelta v, bxd d - let inline bdm (v: DeltaM) (d:DM) = bxdelta v, bxd d - - let inline bx (v: D) (d:D) = bd (X v) d - let inline bxv (v: DV) (d:DV) = bdv (XV v) d - let inline bxm (v: DM) (d:DM) = bdm (XM v) d + let rec reverseProp (v:dobj) (d:dobj) = + let inline bx (v: D) d = (v :> dobj), bxd d + let inline bxv (v: DV) d = (v :> dobj), bxd d + let inline bxm (v: DM) d = (v :> dobj), bxd d // Note, this uses an explicit worklist over (D*D|DV*DV|DM*DM) to make it tail-recursive - let rec pushRec (ds:(delta*dobj) list) = + let rec pushRec (ds:(dobj*dobj) list) = match ds with | [] -> () | (v, d) :: t -> - match v, d with - | (:? Delta as delta), (:? D as d) -> + match d, v with + | (:? D as d), (:? D as v) -> match d with - | DR(_, o, _, uniq) -> - let dA = adjoints.ApplyDelta(uniq, delta) - let fanout = fanouts.[uniq] - 1u - fanouts.[uniq] <- fanout - // If all incoming parts of the adjoint have been received, then proceed to the parent - if fanout = 0u then + | DR(_,_,o,_,_) -> + d.A <- d.A + v + d.F <- d.F - 1u + // If all incoming parts of the adjoint have been received, then proceed to the children + if d.F = 0u then match o with - | Add_D_D(a, b) -> pushRec ((bx dA a) :: (bx dA b) :: t) - | Add_D_DCons(a) -> pushRec ((bx dA a) :: t) - | Sub_D_D(a, b) -> pushRec ((bx dA a) :: (bx -dA b) :: t) - | Sub_D_DCons(a) -> pushRec ((bx dA a) :: t) - | Sub_DCons_D(b) -> pushRec ((bx -dA b) :: t) - | Mul_D_D(a, b) -> pushRec ((bx (dA * b.P) a) :: (bx (dA * a.P) b) :: t) - | Mul_D_DCons(a, cons) -> pushRec ((bx (dA * cons) a) :: t) - | Div_D_D(a, b) -> pushRec ((bx (dA / b.P) a) :: (bx (dA * (-a.P / (b.P * b.P))) b) :: t) - | Div_D_DCons(a, cons) -> pushRec ((bx (dA / cons) a) :: t) - | Div_DCons_D(cons, b) -> 
pushRec ((bx (dA * (-cons / (b.P * b.P))) b) :: t) - | Pow_D_D(a, b) -> pushRec ((bx (dA * (a.P ** (b.P - D.One)) * b.P) a) :: (bx (dA * (a.P ** b.P) * log a.P) b) :: t) - | Pow_D_DCons(a, cons) -> pushRec ((bx (dA * (a.P ** (cons - D.One)) * cons) a) :: t) - | Pow_DCons_D(cons, b) -> pushRec ((bx (dA * (cons ** b.P) * log cons) b) :: t) - | Atan2_D_D(a, b) -> let denom = a.P * a.P + b.P * b.P in pushRec ((bx (dA * b.P / denom) a) :: (bx (dA * (-a.P) / denom) b) :: t) - | Atan2_D_DCons(a, cons) -> pushRec ((bx (dA * cons / (a.P * a.P + cons * cons)) a) :: t) - | Atan2_DCons_D(cons, b) -> pushRec ((bx (dA * (-cons) / (cons * cons + b.P * b.P)) b) :: t) - | Log_D(a) -> pushRec ((bx (dA / a.P) a) :: t) - | Log10_D(a) -> pushRec ((bx (dA / (a.P * N.log10Val)) a) :: t) - | Exp_D(a) -> pushRec ((bx (dA * d.P) a) :: t) // d.P = exp a.P - | Sin_D(a) -> pushRec ((bx (dA * cos a.P) a) :: t) - | Cos_D(a) -> pushRec ((bx (dA * (-sin a.P)) a) :: t) - | Tan_D(a) -> let seca = D.One / cos a.P in pushRec ((bx (dA * seca * seca) a) :: t) - | Neg_D(a) -> pushRec ((bx -dA a) :: t) - | Sqrt_D(a) -> pushRec ((bx (dA / (D N.two * d.P)) a) :: t) // d.P = sqrt a.P - | Sinh_D(a) -> pushRec ((bx (dA * cosh a.P) a) :: t) - | Cosh_D(a) -> pushRec ((bx (dA * sinh a.P) a) :: t) - | Tanh_D(a) -> let secha = D.One / cosh a.P in pushRec ((bx (dA * secha * secha) a) :: t) - | Asin_D(a) -> pushRec ((bx (dA / sqrt (D.One - a.P * a.P)) a) :: t) - | Acos_D(a) -> pushRec ((bx (-dA / sqrt (D.One - a.P * a.P)) a) :: t) - | Atan_D(a) -> pushRec ((bx (dA / (D.One + a.P * a.P)) a) :: t) - | Abs_D(a) -> pushRec ((bx (dA * D.Sign(a.P)) a) :: t) + | Add_D_D(a, b) -> pushRec ((bx d.A a) :: (bx d.A b) :: t) + | Add_D_DCons(a) -> pushRec ((bx d.A a) :: t) + | Sub_D_D(a, b) -> pushRec ((bx d.A a) :: (bx -d.A b) :: t) + | Sub_D_DCons(a) -> pushRec ((bx d.A a) :: t) + | Sub_DCons_D(b) -> pushRec ((bx -d.A b) :: t) + | Mul_D_D(a, b) -> pushRec ((bx (d.A * b.P) a) :: (bx (d.A * a.P) b) :: t) + | Mul_D_DCons(a, cons) -> 
pushRec ((bx (d.A * cons) a) :: t) + | Div_D_D(a, b) -> pushRec ((bx (d.A / b.P) a) :: (bx (d.A * (-a.P / (b.P * b.P))) b) :: t) + | Div_D_DCons(a, cons) -> pushRec ((bx (d.A / cons) a) :: t) + | Div_DCons_D(cons, b) -> pushRec ((bx (d.A * (-cons / (b.P * b.P))) b) :: t) + | Pow_D_D(a, b) -> pushRec ((bx (d.A * (a.P ** (b.P - D.One)) * b.P) a) :: (bx (d.A * (a.P ** b.P) * log a.P) b) :: t) + | Pow_D_DCons(a, cons) -> pushRec ((bx (d.A * (a.P ** (cons - D.One)) * cons) a) :: t) + | Pow_DCons_D(cons, b) -> pushRec ((bx (d.A * (cons ** b.P) * log cons) b) :: t) + | Atan2_D_D(a, b) -> let denom = a.P * a.P + b.P * b.P in pushRec ((bx (d.A * b.P / denom) a) :: (bx (d.A * (-a.P) / denom) b) :: t) + | Atan2_D_DCons(a, cons) -> pushRec ((bx (d.A * cons / (a.P * a.P + cons * cons)) a) :: t) + | Atan2_DCons_D(cons, b) -> pushRec ((bx (d.A * (-cons) / (cons * cons + b.P * b.P)) b) :: t) + | Log_D(a) -> pushRec ((bx (d.A / a.P) a) :: t) + | Log10_D(a) -> pushRec ((bx (d.A / (a.P * N.log10Val)) a) :: t) + | Exp_D(a) -> pushRec ((bx (d.A * d.P) a) :: t) // d.P = exp a.P + | Sin_D(a) -> pushRec ((bx (d.A * cos a.P) a) :: t) + | Cos_D(a) -> pushRec ((bx (d.A * (-sin a.P)) a) :: t) + | Tan_D(a) -> let seca = D.One / cos a.P in pushRec ((bx (d.A * seca * seca) a) :: t) + | Neg_D(a) -> pushRec ((bx -d.A a) :: t) + | Sqrt_D(a) -> pushRec ((bx (d.A / (D N.two * d.P)) a) :: t) // d.P = sqrt a.P + | Sinh_D(a) -> pushRec ((bx (d.A * cosh a.P) a) :: t) + | Cosh_D(a) -> pushRec ((bx (d.A * sinh a.P) a) :: t) + | Tanh_D(a) -> let secha = D.One / cosh a.P in pushRec ((bx (d.A * secha * secha) a) :: t) + | Asin_D(a) -> pushRec ((bx (d.A / sqrt (D.One - a.P * a.P)) a) :: t) + | Acos_D(a) -> pushRec ((bx (-d.A / sqrt (D.One - a.P * a.P)) a) :: t) + | Atan_D(a) -> pushRec ((bx (d.A / (D.One + a.P * a.P)) a) :: t) + | Abs_D(a) -> pushRec ((bx (d.A * D.Sign(a.P)) a) :: t) | Sign_D(a) -> pushRec ((bx D.Zero a) :: t) | Floor_D(a) -> pushRec ((bx D.Zero a) :: t) | Ceil_D(a) -> pushRec ((bx D.Zero a) 
:: t) | Round_D(a) -> pushRec ((bx D.Zero a) :: t) - | Mul_Dot_DV_DV(a, b) -> pushRec ((bxv (dA * b.P) a) :: (bxv (dA * a.P) b) :: t) - | Mul_Dot_DV_DVCons(a, cons) -> pushRec ((bxv (dA * cons) a) :: t) - | Sum_DV(a) -> pushRec ((bxv (DV.create a.Length dA) a) :: t) - | L1Norm_DV(a) -> pushRec ((bxv (dA * DV.Sign a.P) a) :: t) - | L2NormSq_DV(a) -> pushRec ((bxv (dA * (D N.two) * a.P) a) :: t) - | L2Norm_DV(a) -> pushRec ((bxv ((dA / d.P) * a.P) a) :: t) + | Mul_Dot_DV_DV(a, b) -> pushRec ((bxv (d.A * b.P) a) :: (bxv (d.A * a.P) b) :: t) + | Mul_Dot_DV_DVCons(a, cons) -> pushRec ((bxv (d.A * cons) a) :: t) + | Sum_DV(a) -> pushRec ((bxv (DV.create a.Length d.A) a) :: t) + | L1Norm_DV(a) -> pushRec ((bxv (d.A * DV.Sign a.P) a) :: t) + | L2NormSq_DV(a) -> pushRec ((bxv (d.A * (D N.two) * a.P) a) :: t) + | L2Norm_DV(a) -> pushRec ((bxv ((d.A / d.P) * a.P) a) :: t) | Item_DV(a, i) -> - adjoints.[a] <- DV.AddItem(adjoints.[a], i, dA); + a.A <- DV.AddItem(a.A, i, d.A) pushRec ((bxv DV.Zero a) :: t) - | Sum_DM(a) -> pushRec ((bxm (DM.create a.Rows a.Cols dA) a) :: t) + | Sum_DM(a) -> pushRec ((bxm (DM.create a.Rows a.Cols d.A) a) :: t) | Item_DM(a, i, j) -> - adjoints.[a] <- DM.AddItem(adjoints.[a], i, j, dA); + a.A <- DM.AddItem(a.A, i, j, d.A) pushRec ((bxm DM.Zero a) :: t) | Det_DM(a) -> pushRec ((bxm (d.T * d.P * DM.Transpose(DM.Inverse(a))) a) :: t) // Check this - | ReLU_D(a) -> pushRec ((bx (dA * ((D.Sign(a.P) + N.one) / N.two)) a) :: t) - | Sigmoid_D(a) -> pushRec ((bx (dA * d.P * (N.one - d.P)) a) :: t) // d.P = D.Sigmoid(a.P) - | LogSumExp_DV(a) -> pushRec ((bxv ((dA / exp d.P) * exp a.P) a) :: t) // d.P = DV.LogSumExp(a.P) + | ReLU_D(a) -> pushRec ((bx (d.A * ((D.Sign(a.P) + N.one) / N.two)) a) :: t) + | Sigmoid_D(a) -> pushRec ((bx (d.A * d.P * (N.one - d.P)) a) :: t) // d.P = D.Sigmoid(a.P) + | LogSumExp_DV(a) -> pushRec ((bxv ((d.A / exp d.P) * exp a.P) a) :: t) // d.P = DV.LogSumExp(a.P) | FixedPoint_D(b, bfirst, aprev, alast) -> // Christianson (1994) let 
imax = DiffSharp.Config.GlobalConfig.FixedPointMaxIterations @@ -3661,8 +3604,8 @@ module DOps = let mutable i = 0 - let r = dA - reverseProp adjoints r alast + let r = d.A + reverseProp r alast while i < imax do i <- i + 1 @@ -3670,306 +3613,294 @@ module DOps = //printfn "Fixed point reverse iteration timeout, i = %i" i ignore() else - if abs (adjoints.[aprev] + r - adjoints.[alast]) <= eps then + if abs (aprev.A + r - alast.A) <= eps then //printfn "Fixed point reverse iteration converged, i = %i" i i <- imax else - reverseProp adjoints (r + adjoints.[aprev]) alast + reverseProp (r + aprev.A) alast - pushRec ((bx (adjoints.[bfirst]) b) :: t) // Propogate converged adjoint back towards the original b at the beginning of the fixed point iteration + pushRec ((bx bfirst.A b) :: t) // Propogate converged adjoint back towards the original b at the beginning of the fixed point iteration | _ -> pushRec t else pushRec t | _ -> pushRec t - | (:? DeltaV as delta), (:? DV as d) -> + | (:? DV as d), (:? 
DV as v) -> match d with - | DVR(_, o, _, uniq) -> - let dA = adjoints.ApplyDeltaV(uniq,delta) - let fanout = fanouts.[uniq] - 1u - fanouts.[uniq] <- fanout - // If all incoming parts of the adjoint have been received, then proceed to the parent - if fanout = 0u then + | DVR(_,_,o,_,_) -> + d.A <- d.A + v + d.F <- d.F - 1u + // If all incoming parts of the adjoint have been received, then proceed to the children + if d.F = 0u then match o with - | Add_DV_DV(a, b) -> pushRec ((bxv dA a) :: (bxv dA b) :: t) - | Add_DV_DVCons(a) -> pushRec ((bxv dA a) :: t) - | Add_DV_D(a, b) -> pushRec ((bxv dA a) :: (bx (DV.Sum(dA)) b) :: t) - | Add_DV_DCons(a) -> pushRec ((bxv dA a) :: t) - | Add_DVCons_D(b) -> pushRec ((bx (DV.Sum(dA)) b) :: t) - | Sub_DV_DV(a, b) -> pushRec ((bxv dA a) :: (bxv -dA b) :: t) - | Sub_DV_DVCons(a) -> pushRec ((bxv dA a) :: t) - | Sub_DVCons_DV(a) -> pushRec ((bxv -dA a) :: t) - | Sub_DV_D(a, b) -> pushRec ((bxv dA a) :: (bx -(DV.Sum(dA)) b) :: t) - | Sub_DV_DCons(a) -> pushRec ((bxv dA a) :: t) - | Sub_DVCons_D(b) -> pushRec ((bx -(DV.Sum(dA)) b) :: t) - | Sub_D_DV(a, b) -> pushRec ((bx (DV.Sum(dA)) a) :: (bxv -dA b) :: t) - | Sub_D_DVCons(a) -> pushRec ((bx (DV.Sum(dA)) a) :: t) - | Sub_DCons_DV(b) -> pushRec ((bxv -dA b) :: t) - | Mul_Had_DV_DV(a, b) -> pushRec ((bxv (dA .* b.P) a) :: (bxv (dA .* a.P) b) :: t) - | Mul_Had_DV_DVCons(a, cons) -> pushRec ((bxv (dA .* cons) a) :: t) - | Mul_DV_D(a, b) -> pushRec ((bxv (dA * b.P) a) :: (bx (dA * a.P) b) :: t) - | Mul_DV_DCons(a, cons) -> pushRec ((bxv (dA * cons) a) :: t) - | Mul_DVCons_D(cons, b) -> pushRec ((bx (dA * cons) b) :: t) - | Mul_DM_DV(a, b) -> pushRec ((bxm (dA &* b.P) a) :: (bxv (DM.Transpose(a.P) * dA) b) :: t) - | Mul_DM_DVCons(a, cons) -> pushRec ((bxm (dA &* cons) a) :: t) - | Mul_DMCons_DV(cons, b) -> pushRec ((bxv (DM.Transpose(cons) * dA) b) :: t) - | Mul_DV_DM(a, b) -> pushRec ((bxv (dA * DM.Transpose(b.P)) a) :: (bxm (a.P &* dA) b) :: t) - | Mul_DV_DMCons(a, cons) -> pushRec ((bxv 
(dA * DM.Transpose(cons)) a) :: t) - | Mul_DVCons_DM(cons, b) -> pushRec ((bxm (cons &* dA) b) :: t) - | Div_Had_DV_DV(a, b) -> pushRec ((bxv (dA ./ b.P) a) :: (bxv (dA .* (-a.P ./ (b.P .* b.P))) b) :: t) - | Div_Had_DV_DVCons(a, cons) -> pushRec ((bxv (dA ./ cons) a) :: t) - | Div_Had_DVCons_DV(cons, b) -> pushRec ((bxv (dA .* (-cons ./ (b.P .* b.P))) b) :: t) - | Div_DV_D(a, b) -> pushRec ((bxv (dA / b.P) a) :: (bx (dA * (-a.P / (b.P * b.P))) b) :: t) - | Div_DV_DCons(a, cons) -> pushRec ((bxv (dA / cons) a) :: t) - | Div_DVCons_D(cons, b) -> pushRec ((bx (dA * (-cons / (b.P * b.P))) b) :: t) - | Div_D_DV(a, b) -> pushRec ((bx (DV.Sum(dA ./ b.P)) a) :: (bxv (dA .* (-a.P / (b.P .* b.P))) b) :: t) - | Div_D_DVCons(a, cons) -> pushRec ((bx (DV.Sum(dA ./ cons)) a) :: t) - | Div_DCons_DV(cons, b) -> pushRec ((bxv (dA .* (-cons / (b.P .* b.P))) b) :: t) - | Pow_DV_DV(a, b) -> pushRec ((bxv (dA .* (a.P ** (b.P - D.One)) .* b.P) a) :: (bxv (dA .* (a.P ** b.P) .* log a.P) b) :: t) - | Pow_DV_DVCons(a, cons) -> pushRec ((bxv (dA .* (a.P ** (cons - D.One)) .* cons) a) :: t) - | Pow_DVCons_DV(cons, b) -> pushRec ((bxv (dA .* (cons ** b.P) .* log cons) b) :: t) - | Atan2_DV_DV(a, b) -> let denom = (a.P .* a.P) + (b.P .* b.P) in pushRec ((bxv (dA .* b.P ./ denom) a) :: (bxv (dA .* (-a.P) ./ denom) b) :: t) - | Atan2_DV_DVCons(a, cons) -> pushRec ((bxv (dA .* cons ./ ((a.P .* a.P) + (cons .* cons))) a) :: t) - | Atan2_DVCons_DV(cons, b) -> pushRec ((bxv (dA .* (-cons) ./ ((cons .* cons) + (b.P .* b.P))) b) :: t) - | Pow_DV_D(a, b) -> pushRec ((bxv (dA .* (a.P ** (b.P - D.One)) * b.P) a) :: (bx (DV.Sum(dA .* (a.P ** b.P) .* log a.P)) b) :: t) - | Pow_DV_DCons(a, cons) -> pushRec ((bxv (dA .* (a.P ** (cons - D.One)) * cons) a) :: t) - | Pow_DVCons_D(cons, b) -> pushRec ((bx (DV.Sum(dA .* (cons ** b.P) .* log cons)) b) :: t) - | Pow_D_DV(a, b) -> pushRec ((bx (DV.Sum(dA .* (DV.Pow(a.P, b.P - D.One)) .* b.P)) a) :: (bxv (dA .* (DV.Pow(a.P, b.P)) * log a.P) b) :: t) - | 
Pow_D_DVCons(a, cons) -> pushRec ((bx (DV.Sum(dA .* (DV.Pow(a.P, cons - D.One)) .* cons)) a) :: t) - | Pow_DCons_DV(cons, b) -> pushRec ((bxv (dA .* (DV.Pow(cons, b.P)) * log cons) b) :: t) - | Atan2_DV_D(a, b) -> let denom = (a.P .* a.P) + (b.P * b.P) in pushRec ((bxv (dA * b.P ./ denom) a) :: (bx (DV.Sum(dA .* (-a.P) ./ denom)) b) :: t) - | Atan2_DV_DCons(a, cons) -> pushRec ((bxv (dA * cons ./ ((a.P .* a.P) + (cons * cons))) a) :: t) - | Atan2_DVCons_D(cons, b) -> pushRec ((bx (DV.Sum(dA .* (-cons) ./ ((cons .* cons) + (b.P * b.P)))) b) :: t) - | Atan2_D_DV(a, b) -> let denom = (a.P * a.P) + (b.P .* b.P) in pushRec ((bx (DV.Sum(dA .* b.P ./ denom)) a) :: (bxv (dA * (-a.P) ./ denom) b) :: t) - | Atan2_D_DVCons(a, cons) -> pushRec ((bx (DV.Sum(dA .* cons ./ ((a.P * a.P) + (cons .* cons)))) a) :: t) - | Atan2_DCons_DV(cons, b) -> pushRec ((bxv (dA * (-cons) ./ ((cons * cons) + (b.P .* b.P))) b) :: t) - | Log_DV(a) -> pushRec ((bxv (dA ./ a.P) a) :: t) - | Log10_DV(a) -> pushRec ((bxv (dA ./ (a.P * N.log10Val)) a) :: t) - | Exp_DV(a) -> pushRec ((bxv (dA .* d.P) a) :: t) // d.P = exp a.P - | Sin_DV(a) -> pushRec ((bxv (dA .* cos a.P) a) :: t) - | Cos_DV(a) -> pushRec ((bxv (-dA .* sin a.P) a) :: t) - | Tan_DV(a) -> let seca = D.One / cos a.P in pushRec ((bxv (dA .* seca .* seca) a) :: t) - | Neg_DV(a) -> pushRec ((bxv -dA a) :: t) - | Sqrt_DV(a) -> pushRec ((bxv (dA ./ (N.two * d.P)) a) :: t) // d.P = sqrt a.P - | Sinh_DV(a) -> pushRec ((bxv (dA .* cosh a.P) a) :: t) - | Cosh_DV(a) -> pushRec ((bxv (dA .* sinh a.P) a) :: t) - | Tanh_DV(a) -> let secha = D.One / cosh a.P in pushRec ((bxv (dA .* secha .* secha) a) :: t) - | Asin_DV(a) -> pushRec ((bxv (dA ./ sqrt (D.One - (a.P .* a.P))) a) :: t) - | Acos_DV(a) -> pushRec ((bxv (-dA ./ sqrt (D.One - (a.P .* a.P))) a) :: t) - | Atan_DV(a) -> pushRec ((bxv (dA ./ (D.One + (a.P .* a.P))) a) :: t) - | Abs_DV(a) -> pushRec ((bxv (dA .* DV.Sign a.P) a) :: t) + | Add_DV_DV(a, b) -> pushRec ((bxv d.A a) :: (bxv d.A b) :: t) + 
| Add_DV_DVCons(a) -> pushRec ((bxv d.A a) :: t) + | Add_DV_D(a, b) -> pushRec ((bxv d.A a) :: (bx (DV.Sum(d.A)) b) :: t) + | Add_DV_DCons(a) -> pushRec ((bxv d.A a) :: t) + | Add_DVCons_D(b) -> pushRec ((bx (DV.Sum(d.A)) b) :: t) + | Sub_DV_DV(a, b) -> pushRec ((bxv d.A a) :: (bxv -d.A b) :: t) + | Sub_DV_DVCons(a) -> pushRec ((bxv d.A a) :: t) + | Sub_DVCons_DV(a) -> pushRec ((bxv -d.A a) :: t) + | Sub_DV_D(a, b) -> pushRec ((bxv d.A a) :: (bx -(DV.Sum(d.A)) b) :: t) + | Sub_DV_DCons(a) -> pushRec ((bxv d.A a) :: t) + | Sub_DVCons_D(b) -> pushRec ((bx -(DV.Sum(d.A)) b) :: t) + | Sub_D_DV(a, b) -> pushRec ((bx (DV.Sum(d.A)) a) :: (bxv -d.A b) :: t) + | Sub_D_DVCons(a) -> pushRec ((bx (DV.Sum(d.A)) a) :: t) + | Sub_DCons_DV(b) -> pushRec ((bxv -d.A b) :: t) + | Mul_Had_DV_DV(a, b) -> pushRec ((bxv (d.A .* b.P) a) :: (bxv (d.A .* a.P) b) :: t) + | Mul_Had_DV_DVCons(a, cons) -> pushRec ((bxv (d.A .* cons) a) :: t) + | Mul_DV_D(a, b) -> pushRec ((bxv (d.A * b.P) a) :: (bx (d.A * a.P) b) :: t) + | Mul_DV_DCons(a, cons) -> pushRec ((bxv (d.A * cons) a) :: t) + | Mul_DVCons_D(cons, b) -> pushRec ((bx (d.A * cons) b) :: t) + | Mul_DM_DV(a, b) -> pushRec ((bxm (d.A &* b.P) a) :: (bxv (DM.Transpose(a.P) * d.A) b) :: t) + | Mul_DM_DVCons(a, cons) -> pushRec ((bxm (d.A &* cons) a) :: t) + | Mul_DMCons_DV(cons, b) -> pushRec ((bxv (DM.Transpose(cons) * d.A) b) :: t) + | Mul_DV_DM(a, b) -> pushRec ((bxv (d.A * DM.Transpose(b.P)) a) :: (bxm (a.P &* d.A) b) :: t) + | Mul_DV_DMCons(a, cons) -> pushRec ((bxv (d.A * DM.Transpose(cons)) a) :: t) + | Mul_DVCons_DM(cons, b) -> pushRec ((bxm (cons &* d.A) b) :: t) + | Div_Had_DV_DV(a, b) -> pushRec ((bxv (d.A ./ b.P) a) :: (bxv (d.A .* (-a.P ./ (b.P .* b.P))) b) :: t) + | Div_Had_DV_DVCons(a, cons) -> pushRec ((bxv (d.A ./ cons) a) :: t) + | Div_Had_DVCons_DV(cons, b) -> pushRec ((bxv (d.A .* (-cons ./ (b.P .* b.P))) b) :: t) + | Div_DV_D(a, b) -> pushRec ((bxv (d.A / b.P) a) :: (bx (d.A * (-a.P / (b.P * b.P))) b) :: t) + | 
Div_DV_DCons(a, cons) -> pushRec ((bxv (d.A / cons) a) :: t) + | Div_DVCons_D(cons, b) -> pushRec ((bx (d.A * (-cons / (b.P * b.P))) b) :: t) + | Div_D_DV(a, b) -> pushRec ((bx (DV.Sum(d.A ./ b.P)) a) :: (bxv (d.A .* (-a.P / (b.P .* b.P))) b) :: t) + | Div_D_DVCons(a, cons) -> pushRec ((bx (DV.Sum(d.A ./ cons)) a) :: t) + | Div_DCons_DV(cons, b) -> pushRec ((bxv (d.A .* (-cons / (b.P .* b.P))) b) :: t) + | Pow_DV_DV(a, b) -> pushRec ((bxv (d.A .* (a.P ** (b.P - D.One)) .* b.P) a) :: (bxv (d.A .* (a.P ** b.P) .* log a.P) b) :: t) + | Pow_DV_DVCons(a, cons) -> pushRec ((bxv (d.A .* (a.P ** (cons - D.One)) .* cons) a) :: t) + | Pow_DVCons_DV(cons, b) -> pushRec ((bxv (d.A .* (cons ** b.P) .* log cons) b) :: t) + | Atan2_DV_DV(a, b) -> let denom = (a.P .* a.P) + (b.P .* b.P) in pushRec ((bxv (d.A .* b.P ./ denom) a) :: (bxv (d.A .* (-a.P) ./ denom) b) :: t) + | Atan2_DV_DVCons(a, cons) -> pushRec ((bxv (d.A .* cons ./ ((a.P .* a.P) + (cons .* cons))) a) :: t) + | Atan2_DVCons_DV(cons, b) -> pushRec ((bxv (d.A .* (-cons) ./ ((cons .* cons) + (b.P .* b.P))) b) :: t) + | Pow_DV_D(a, b) -> pushRec ((bxv (d.A .* (a.P ** (b.P - D.One)) * b.P) a) :: (bx (DV.Sum(d.A .* (a.P ** b.P) .* log a.P)) b) :: t) + | Pow_DV_DCons(a, cons) -> pushRec ((bxv (d.A .* (a.P ** (cons - D.One)) * cons) a) :: t) + | Pow_DVCons_D(cons, b) -> pushRec ((bx (DV.Sum(d.A .* (cons ** b.P) .* log cons)) b) :: t) + | Pow_D_DV(a, b) -> pushRec ((bx (DV.Sum(d.A .* (DV.Pow(a.P, b.P - D.One)) .* b.P)) a) :: (bxv (d.A .* (DV.Pow(a.P, b.P)) * log a.P) b) :: t) + | Pow_D_DVCons(a, cons) -> pushRec ((bx (DV.Sum(d.A .* (DV.Pow(a.P, cons - D.One)) .* cons)) a) :: t) + | Pow_DCons_DV(cons, b) -> pushRec ((bxv (d.A .* (DV.Pow(cons, b.P)) * log cons) b) :: t) + | Atan2_DV_D(a, b) -> let denom = (a.P .* a.P) + (b.P * b.P) in pushRec ((bxv (d.A * b.P ./ denom) a) :: (bx (DV.Sum(d.A .* (-a.P) ./ denom)) b) :: t) + | Atan2_DV_DCons(a, cons) -> pushRec ((bxv (d.A * cons ./ ((a.P .* a.P) + (cons * cons))) a) :: t) + | 
Atan2_DVCons_D(cons, b) -> pushRec ((bx (DV.Sum(d.A .* (-cons) ./ ((cons .* cons) + (b.P * b.P)))) b) :: t) + | Atan2_D_DV(a, b) -> let denom = (a.P * a.P) + (b.P .* b.P) in pushRec ((bx (DV.Sum(d.A .* b.P ./ denom)) a) :: (bxv (d.A * (-a.P) ./ denom) b) :: t) + | Atan2_D_DVCons(a, cons) -> pushRec ((bx (DV.Sum(d.A .* cons ./ ((a.P * a.P) + (cons .* cons)))) a) :: t) + | Atan2_DCons_DV(cons, b) -> pushRec ((bxv (d.A * (-cons) ./ ((cons * cons) + (b.P .* b.P))) b) :: t) + | Log_DV(a) -> pushRec ((bxv (d.A ./ a.P) a) :: t) + | Log10_DV(a) -> pushRec ((bxv (d.A ./ (a.P * N.log10Val)) a) :: t) + | Exp_DV(a) -> pushRec ((bxv (d.A .* d.P) a) :: t) // d.P = exp a.P + | Sin_DV(a) -> pushRec ((bxv (d.A .* cos a.P) a) :: t) + | Cos_DV(a) -> pushRec ((bxv (-d.A .* sin a.P) a) :: t) + | Tan_DV(a) -> let seca = D.One / cos a.P in pushRec ((bxv (d.A .* seca .* seca) a) :: t) + | Neg_DV(a) -> pushRec ((bxv -d.A a) :: t) + | Sqrt_DV(a) -> pushRec ((bxv (d.A ./ (N.two * d.P)) a) :: t) // d.P = sqrt a.P + | Sinh_DV(a) -> pushRec ((bxv (d.A .* cosh a.P) a) :: t) + | Cosh_DV(a) -> pushRec ((bxv (d.A .* sinh a.P) a) :: t) + | Tanh_DV(a) -> let secha = D.One / cosh a.P in pushRec ((bxv (d.A .* secha .* secha) a) :: t) + | Asin_DV(a) -> pushRec ((bxv (d.A ./ sqrt (D.One - (a.P .* a.P))) a) :: t) + | Acos_DV(a) -> pushRec ((bxv (-d.A ./ sqrt (D.One - (a.P .* a.P))) a) :: t) + | Atan_DV(a) -> pushRec ((bxv (d.A ./ (D.One + (a.P .* a.P))) a) :: t) + | Abs_DV(a) -> pushRec ((bxv (d.A .* DV.Sign a.P) a) :: t) | Sign_DV(a) -> pushRec ((bxv DV.Zero a) :: t) | Floor_DV(a) -> pushRec ((bxv DV.Zero a) :: t) | Ceil_DV(a) -> pushRec ((bxv DV.Zero a) :: t) | Round_DV(a) -> pushRec ((bxv DV.Zero a) :: t) - | Make_DV_ofDs(a) -> pushRec (t |> List.append (a |> Array.mapi (fun i v -> (bx dA.[i] v)) |> List.ofArray)) + | Make_DV_ofDs(a) -> pushRec (t |> List.append (a |> Array.mapi (fun i v -> (bx d.A.[i] v)) |> List.ofArray)) | SliceRow_DM(a, i, j) -> - adjoints.[a] <- DM.AddSubMatrix(adjoints.[a], i, j, 
dA.ToRowDM()) + a.A <- DM.AddSubMatrix(a.A, i, j, d.A.ToRowDM()) pushRec ((bxm DM.Zero a) :: t) | SliceCol_DM(a, i, j) -> - adjoints.[a] <- DM.AddSubMatrix(adjoints.[a], i, j, dA.ToColDM()) + a.A <- DM.AddSubMatrix(a.A, i, j, d.A.ToColDM()) pushRec ((bxm DM.Zero a) :: t) - | Solve_DM_DV(a, b) -> let ba = DM.Solve(DM.Transpose(a), dA) in pushRec ((bxm (-ba &* dA) a) :: (bxv (ba) b) :: t) - | Solve_DM_DVCons(a, cons) -> let ba = DM.Solve(DM.Transpose(a), dA) in pushRec ((bxm (-ba &* dA) a) :: t) - | Solve_DMCons_DV(cons, b) -> let ba = DM.Solve(DM.Transpose(cons), dA) in pushRec ((bxv ba b) :: t) + | Solve_DM_DV(a, b) -> let ba = DM.Solve(DM.Transpose(a), d.A) in pushRec ((bxm (-ba &* d.A) a) :: (bxv (ba) b) :: t) + | Solve_DM_DVCons(a, cons) -> let ba = DM.Solve(DM.Transpose(a), d.A) in pushRec ((bxm (-ba &* d.A) a) :: t) + | Solve_DMCons_DV(cons, b) -> let ba = DM.Solve(DM.Transpose(cons), d.A) in pushRec ((bxv ba b) :: t) | Append_DV_DV(a, b) -> - adjoints.[a] <- adjoints.[a] + dA.[..(a.Length - 1)] - adjoints.[b] <- adjoints.[b] + dA.[a.Length..] + a.A <- a.A + d.A.[..(a.Length - 1)] + b.A <- b.A + d.A.[a.Length..] pushRec ((bxv DV.Zero a) :: (bxv DV.Zero b) :: t) | Append_DV_DVCons(a) -> - adjoints.[a] <- adjoints.[a] + dA.[..(a.Length - 1)] + a.A <- a.A + d.A.[..(a.Length - 1)] pushRec ((bxv DV.Zero a) :: t) | Append_DVCons_DV(b) -> - adjoints.[b] <- adjoints.[b] + dA.[(d.Length - b.Length)..] + b.A <- b.A + d.A.[(d.Length - b.Length)..] 
pushRec ((bxv DV.Zero b) :: t) | Split_DV(a, i) -> - adjoints.[a] <- DV.AddSubVector(adjoints.[a], i, dA) + a.A <- DV.AddSubVector(a.A, i, d.A) pushRec ((bxv DV.Zero a) :: t) - | AddItem_DV_D(a, i, b) -> pushRec ((bxv dA a) :: (bx (dA.[i]) b) :: t) - | AddItem_DV_DCons(a) -> pushRec ((bxv dA a) :: t) - | AddItem_DVCons_D(i, b) -> pushRec ((bx dA.[i] b) :: t) - | AddSubVector_DV_DV(a, i, b) -> pushRec ((bxv dA a) :: (bxv (dA.[i..(i + b.Length - 1)]) b) :: t) - | AddSubVector_DV_DVCons(a) -> pushRec ((bxv dA a) :: t) - | AddSubVector_DVCons_DV(i, b) -> pushRec ((bxv (dA.[i..(i + b.Length - 1)]) b) :: t) - | ReshapeCopy_DM_DV(a) -> pushRec ((bxm (DV.ReshapeToDM(a.Rows, dA)) a) :: t) + | AddItem_DV_D(a, i, b) -> pushRec ((bxv d.A a) :: (bx (d.A.[i]) b) :: t) + | AddItem_DV_DCons(a) -> pushRec ((bxv d.A a) :: t) + | AddItem_DVCons_D(i, b) -> pushRec ((bx d.A.[i] b) :: t) + | AddSubVector_DV_DV(a, i, b) -> pushRec ((bxv d.A a) :: (bxv (d.A.[i..(i + b.Length - 1)]) b) :: t) + | AddSubVector_DV_DVCons(a) -> pushRec ((bxv d.A a) :: t) + | AddSubVector_DVCons_DV(i, b) -> pushRec ((bxv (d.A.[i..(i + b.Length - 1)]) b) :: t) + | ReshapeCopy_DM_DV(a) -> pushRec ((bxm (DV.ReshapeToDM(a.Rows, d.A)) a) :: t) | Slice_DV(a, i) -> - adjoints.[a] <- DV.AddSubVector(adjoints.[a], i, dA) + a.A <- DV.AddSubVector(a.A, i, d.A) pushRec ((bxv DV.Zero a) :: t) | Diagonal_DM(a) -> - adjoints.[a] <- DM.AddDiagonal(adjoints.[a], dA) + a.A <- DM.AddDiagonal(a.A, d.A) pushRec ((bxm DM.Zero a) :: t) - | ReLU_DV(a) -> pushRec ((bxv (dA .* ((DV.Sign(a.P) + N.one) / N.two)) a) :: t) - | Sigmoid_DV(a) -> pushRec ((bxv (dA .* d.P .* (N.one - d.P)) a) :: t) // d.P = DV.Sigmoid(a.P) + | ReLU_DV(a) -> pushRec ((bxv (d.A .* ((DV.Sign(a.P) + N.one) / N.two)) a) :: t) + | Sigmoid_DV(a) -> pushRec ((bxv (d.A .* d.P .* (N.one - d.P)) a) :: t) // d.P = DV.Sigmoid(a.P) | _ -> pushRec t else pushRec t | _ -> pushRec t - | (:? DeltaM as delta), (:? DM as d) -> + | (:? DM as d), (:? 
DM as v) -> match d with - | DMR(_, o, _, uniq) -> - let dA = adjoints.ApplyDeltaM(uniq,delta) - let fanout = fanouts.[uniq] - 1u - fanouts.[uniq] <- fanout - // If all incoming parts of the adjoint have been received, then proceed to the parent - if fanout = 0u then + | DMR(_,_,o,_,_) -> + d.A <- d.A + v + d.F <- d.F - 1u + // If all incoming parts of the adjoint have been received, then proceed to the children + if d.F = 0u then match o with - | Add_DM_DM(a, b) -> pushRec ((bxm dA a) :: (bxm dA b) :: t) - | Add_DM_DMCons(a) -> pushRec ((bxm dA a) :: t) + | Add_DM_DM(a, b) -> pushRec ((bxm d.A a) :: (bxm d.A b) :: t) + | Add_DM_DMCons(a) -> pushRec ((bxm d.A a) :: t) - // When pushing "-dA" as adjoint increment for b, the operation - // "b.Adjoint <- -1.0 * dA + b.Adjoint" + // When pushing "-d.A" as adjoint increment for b, the operation + // "b.Adjoint <- -1.0 * d.A + b.Adjoint" // can be performed directly in-place. Instead of pushing a D|DV|DM we should a // structured expression about how to compute the D|DV|DM which can be interpreted // to do an in-place update - | Sub_DM_DM(a, b) -> pushRec ((bxm dA a) :: (bdm (XNegM (XM dA)) b) :: t) + | Sub_DM_DM(a, b) -> pushRec ((bxm d.A a) :: (bxm (-d.A) b) :: t) + + | Sub_DM_DMCons(a) -> pushRec ((bxm d.A a) :: t) + | Sub_DMCons_DM(a) -> pushRec ((bxm d.A a) :: t) // TODO: also avoid the inplace operations in most of the below. 
- | Sub_DM_DMCons(a) -> pushRec ((bxm dA a) :: t) - | Sub_DMCons_DM(a) -> pushRec ((bxm -dA a) :: t) - | Mul_DM_DM(a, b) -> pushRec ((bxm (dA * DM.Transpose(b.P)) a) :: (bxm (DM.Transpose(a.P) * dA) b) :: t) - | Mul_DM_DMCons(a, cons) -> pushRec ((bxm (dA * DM.Transpose(cons)) a) :: t) - | Mul_DMCons_DM(cons, b) -> pushRec ((bxm (DM.Transpose(cons) * dA) b) :: t) - | Mul_Had_DM_DM(a, b) -> pushRec ((bxm (dA .* b.P) a) :: (bxm (dA .* a.P) b) :: t) - | Mul_Had_DM_DMCons(a, cons) -> pushRec ((bxm (dA .* cons) a) :: t) - | Mul_DM_D(a, b) -> pushRec ((bxm (dA * b.P) a) :: (bx (DM.Sum(dA .* a.P)) b) :: t) - | Mul_DM_DCons(a, cons) -> pushRec ((bxm (dA * cons) a) :: t) - | Mul_DMCons_D(cons, b) -> pushRec ((bx (DM.Sum(dA .* cons)) b) :: t) - | Mul_Out_DV_DV(a, b) -> pushRec ((bxv (dA * b.P) a) :: (bxv (DM.Transpose(dA) * a.P) b) :: t) - | Mul_Out_DV_DVCons(a, cons) -> pushRec ((bxv (dA * cons) a) :: t) - | Mul_Out_DVCons_DV(cons, b) -> pushRec ((bxv (DM.Transpose(dA) * cons) b) :: t) - | Div_Had_DM_DM(a, b) -> pushRec ((bxm (dA ./ b.P) a) :: (bxm (dA .* (-a.P ./ (b.P .* b.P))) b) :: t) - | Div_Had_DM_DMCons(a, cons) -> pushRec ((bxm (dA ./ cons) a) :: t) - | Div_Had_DMCons_DM(cons, b) -> pushRec ((bxm (dA .* (-cons ./ (b.P .* b.P))) b) :: t) - | Pow_DM_DM(a, b) -> pushRec ((bxm (dA .* (a.P ** (b.P - D.One)) .* b.P) a) :: (bxm (dA .* (a.P ** b.P) .* log a.P) b) :: t) - | Pow_DM_DMCons(a, cons) -> pushRec ((bxm (dA .* (a.P ** (cons - D.One)) .* cons) a) :: t) - | Pow_DMCons_DM(cons, b) -> pushRec ((bxm (dA .* (cons ** b.P) .* log cons) b) :: t) - | Atan2_DM_DM(a, b) -> let denom = (a.P .* a.P) + (b.P .* b.P) in pushRec ((bxm (dA .* b.P ./ denom) a) :: (bxm (dA .* (-a.P) ./ denom) b) :: t) - | Atan2_DM_DMCons(a, cons) -> pushRec ((bxm (dA .* cons ./ ((a.P .* a.P) + (cons .* cons))) a) :: t) - | Atan2_DMCons_DM(cons, b) -> pushRec ((bxm (dA .* (-cons) ./ ((cons .* cons) + (b.P .* b.P))) b) :: t) - | Add_DM_D(a, b) -> pushRec ((bxm dA a) :: (bx (DM.Sum(dA)) b) :: t) - | 
Add_DM_DCons(a) -> pushRec ((bxm dA a) :: t) - | Add_DMCons_D(b) -> pushRec ((bx (DM.Sum(dA)) b) :: t) + | Mul_DM_DM(a, b) -> pushRec ((bxm (d.A * DM.Transpose(b.P)) a) :: (bxm (DM.Transpose(a.P) * d.A) b) :: t) + | Mul_DM_DMCons(a, cons) -> pushRec ((bxm (d.A * DM.Transpose(cons)) a) :: t) + | Mul_DMCons_DM(cons, b) -> pushRec ((bxm (DM.Transpose(cons) * d.A) b) :: t) + | Mul_Had_DM_DM(a, b) -> pushRec ((bxm (d.A .* b.P) a) :: (bxm (d.A .* a.P) b) :: t) + | Mul_Had_DM_DMCons(a, cons) -> pushRec ((bxm (d.A .* cons) a) :: t) + | Mul_DM_D(a, b) -> pushRec ((bxm (d.A * b.P) a) :: (bx (DM.Sum(d.A .* a.P)) b) :: t) + | Mul_DM_DCons(a, cons) -> pushRec ((bxm (d.A * cons) a) :: t) + | Mul_DMCons_D(cons, b) -> pushRec ((bx (DM.Sum(d.A .* cons)) b) :: t) + | Mul_Out_DV_DV(a, b) -> pushRec ((bxv (d.A * b.P) a) :: (bxv (DM.Transpose(d.A) * a.P) b) :: t) + | Mul_Out_DV_DVCons(a, cons) -> pushRec ((bxv (d.A * cons) a) :: t) + | Mul_Out_DVCons_DV(cons, b) -> pushRec ((bxv (DM.Transpose(d.A) * cons) b) :: t) + | Div_Had_DM_DM(a, b) -> pushRec ((bxm (d.A ./ b.P) a) :: (bxm (d.A .* (-a.P ./ (b.P .* b.P))) b) :: t) + | Div_Had_DM_DMCons(a, cons) -> pushRec ((bxm (d.A ./ cons) a) :: t) + | Div_Had_DMCons_DM(cons, b) -> pushRec ((bxm (d.A .* (-cons ./ (b.P .* b.P))) b) :: t) + | Pow_DM_DM(a, b) -> pushRec ((bxm (d.A .* (a.P ** (b.P - D.One)) .* b.P) a) :: (bxm (d.A .* (a.P ** b.P) .* log a.P) b) :: t) + | Pow_DM_DMCons(a, cons) -> pushRec ((bxm (d.A .* (a.P ** (cons - D.One)) .* cons) a) :: t) + | Pow_DMCons_DM(cons, b) -> pushRec ((bxm (d.A .* (cons ** b.P) .* log cons) b) :: t) + | Atan2_DM_DM(a, b) -> let denom = (a.P .* a.P) + (b.P .* b.P) in pushRec ((bxm (d.A .* b.P ./ denom) a) :: (bxm (d.A .* (-a.P) ./ denom) b) :: t) + | Atan2_DM_DMCons(a, cons) -> pushRec ((bxm (d.A .* cons ./ ((a.P .* a.P) + (cons .* cons))) a) :: t) + | Atan2_DMCons_DM(cons, b) -> pushRec ((bxm (d.A .* (-cons) ./ ((cons .* cons) + (b.P .* b.P))) b) :: t) + | Add_DM_D(a, b) -> pushRec ((bxm d.A a) :: (bx 
(DM.Sum(d.A)) b) :: t) + | Add_DM_DCons(a) -> pushRec ((bxm d.A a) :: t) + | Add_DMCons_D(b) -> pushRec ((bx (DM.Sum(d.A)) b) :: t) | Add_DMCols_DV(a, b) -> - dA.GetCols() |> Seq.iter (fun v -> adjoints.[b] <- adjoints.[b] + v) - pushRec ((bxm dA a) :: (bxv DV.Zero b) :: t) + d.A.GetCols() |> Seq.iter (fun v -> b.A <- b.A + v) + pushRec ((bxm d.A a) :: (bxv DV.Zero b) :: t) | Add_DMCols_DVCons(a) -> - pushRec ((bxm dA a) :: t) + pushRec ((bxm d.A a) :: t) | Add_DMColsCons_DV(b) -> - dA.GetCols() |> Seq.iter (fun v -> adjoints.[b] <- adjoints.[b] + v) + d.A.GetCols() |> Seq.iter (fun v -> b.A <- b.A + v) pushRec ((bxv DV.Zero b) :: t) - | Sub_DM_D(a, b) -> pushRec ((bxm dA a) :: (bx -(DM.Sum(dA)) b) :: t) - | Sub_DM_DCons(a) -> pushRec ((bxm dA a) :: t) - | Sub_DMCons_D(b) -> pushRec ((bx -(DM.Sum(dA)) b) :: t) - | Sub_D_DM(a, b) -> pushRec ((bx (DM.Sum(dA)) a) :: (bxm -dA b) :: t) - | Sub_D_DMCons(a) -> pushRec ((bx (DM.Sum(dA)) a) :: t) - | Sub_DCons_DM(b) -> pushRec ((bxm -dA b) :: t) - | Div_DM_D(a, b) -> pushRec ((bxm (dA / b.P) a) :: (bx (DM.Sum (dA .* (-a.P / b.P * b.P))) b) :: t) - | Div_DM_DCons(a, cons) -> pushRec ((bxm (dA / cons) a) :: t) - | Div_DMCons_D(cons, b) -> pushRec ((bx (DM.Sum (dA .* (-cons / (b.P * b.P)))) b) :: t) - | Div_D_DM(a, b) -> pushRec ((bx (DM.Sum(dA ./ b.P)) a) :: (bxm (dA .* (-a.P / (b.P .* b.P))) b) :: t) - | Div_D_DMCons(a, cons) -> pushRec ((bx (DM.Sum(dA ./ cons)) a) :: t) - | Div_DCons_DM(cons, b) -> pushRec ((bxm (dA .* (-cons / (b.P .* b.P))) b) :: t) - | Pow_DM_D(a, b) -> pushRec ((bxm (dA .* (a.P ** (b.P - D.One)) * b.P) a) :: (bx (DM.Sum(dA .* (a.P ** b.P) .* log a.P)) b) :: t) - | Pow_DM_DCons(a, cons) -> pushRec ((bxm (dA .* (a.P ** (cons - D.One)) * cons) a) :: t) - | Pow_DMCons_D(cons, b) -> pushRec ((bx (DM.Sum(dA .* (cons ** b.P) .* log cons)) b) :: t) - | Pow_D_DM(a, b) -> pushRec ((bx (DM.Sum(dA .* (DM.Pow(a.P, b.P - D.One)) .* b.P)) a) :: (bxm (dA .* (DM.Pow(a.P, b.P)) * log a.P) b) :: t) - | Pow_D_DMCons(a, 
cons) -> pushRec ((bx (DM.Sum(dA .* (DM.Pow(a.P, cons - D.One)) .* cons)) a) :: t) - | Pow_DCons_DM(cons, b) -> pushRec ((bxm (dA .* (DM.Pow(cons, b.P)) * log cons) b) :: t) - | Atan2_DM_D(a, b) -> let denom = (a.P .* a.P) + (b.P * b.P) in pushRec ((bxm (dA * b.P ./ denom) a) :: (bx (DM.Sum(dA .* (-a.P) ./ denom)) b) :: t) - | Atan2_DM_DCons(a, cons) -> pushRec ((bxm (dA * cons ./ ((a.P .* a.P) + (cons * cons))) a) :: t) - | Atan2_DMCons_D(cons, b) ->pushRec ((bx (DM.Sum(dA .* (-cons) ./ ((cons .* cons) + (b.P * b.P)))) b) :: t) - | Atan2_D_DM(a, b) -> let denom = (a.P * a.P) + (b.P .* b.P) in pushRec ((bx (DM.Sum(dA .* b.P ./ denom)) a) :: (bxm (dA * (-a.P) ./ denom) b) :: t) - | Atan2_D_DMCons(a, cons) -> pushRec ((bx (DM.Sum(dA .* cons ./ ((a.P * a.P) + (cons .* cons)))) a) :: t) - | Atan2_DCons_DM(cons, b) -> pushRec ((bxm (dA * (-cons) ./ ((cons * cons) + (b.P .* b.P))) b) :: t) - | Log_DM(a) -> pushRec ((bxm (dA ./ a.P) a) :: t) - | Log10_DM(a) -> pushRec ((bxm (dA ./ (a.P * N.log10Val)) a) :: t) - | Exp_DM(a) -> pushRec ((bxm (dA .* d.P) a) :: t) // d.P = exp a.P - | Sin_DM(a) -> pushRec ((bxm (dA .* cos a.P) a) :: t) - | Cos_DM(a) -> pushRec ((bxm (-dA .* sin a.P) a) :: t) - | Tan_DM(a) -> let seca = D.One / cos a.P in pushRec ((bxm (dA .* seca .* seca) a) :: t) - | Neg_DM(a) -> pushRec ((bxm -dA a) :: t) - | Sqrt_DM(a) -> pushRec ((bxm (dA ./ (N.two * d.P)) a) :: t) // d.P = sqrt a.P - | Sinh_DM(a) -> pushRec ((bxm (dA .* cosh a.P) a) :: t) - | Cosh_DM(a) -> pushRec ((bxm (dA .* sinh a.P) a) :: t) - | Tanh_DM(a) -> let secha = D.One / cosh a.P in pushRec ((bxm (dA .* secha .* secha) a) :: t) - | Asin_DM(a) -> pushRec ((bxm (dA ./ sqrt (D.One - (a.P .* a.P))) a) :: t) - | Acos_DM(a) -> pushRec ((bxm (-dA ./ sqrt (D.One - (a.P .* a.P))) a) :: t) - | Atan_DM(a) -> pushRec ((bxm (dA ./ (D.One + (a.P .* a.P))) a) :: t) - | Abs_DM(a) -> pushRec ((bxm (dA .* DM.Sign a.P) a) :: t) + | Sub_DM_D(a, b) -> pushRec ((bxm d.A a) :: (bx -(DM.Sum(d.A)) b) :: t) + | 
Sub_DM_DCons(a) -> pushRec ((bxm d.A a) :: t) + | Sub_DMCons_D(b) -> pushRec ((bx -(DM.Sum(d.A)) b) :: t) + | Sub_D_DM(a, b) -> pushRec ((bx (DM.Sum(d.A)) a) :: (bxm -d.A b) :: t) + | Sub_D_DMCons(a) -> pushRec ((bx (DM.Sum(d.A)) a) :: t) + | Sub_DCons_DM(b) -> pushRec ((bxm -d.A b) :: t) + | Div_DM_D(a, b) -> pushRec ((bxm (d.A / b.P) a) :: (bx (DM.Sum (d.A .* (-a.P / b.P * b.P))) b) :: t) + | Div_DM_DCons(a, cons) -> pushRec ((bxm (d.A / cons) a) :: t) + | Div_DMCons_D(cons, b) -> pushRec ((bx (DM.Sum (d.A .* (-cons / (b.P * b.P)))) b) :: t) + | Div_D_DM(a, b) -> pushRec ((bx (DM.Sum(d.A ./ b.P)) a) :: (bxm (d.A .* (-a.P / (b.P .* b.P))) b) :: t) + | Div_D_DMCons(a, cons) -> pushRec ((bx (DM.Sum(d.A ./ cons)) a) :: t) + | Div_DCons_DM(cons, b) -> pushRec ((bxm (d.A .* (-cons / (b.P .* b.P))) b) :: t) + | Pow_DM_D(a, b) -> pushRec ((bxm (d.A .* (a.P ** (b.P - D.One)) * b.P) a) :: (bx (DM.Sum(d.A .* (a.P ** b.P) .* log a.P)) b) :: t) + | Pow_DM_DCons(a, cons) -> pushRec ((bxm (d.A .* (a.P ** (cons - D.One)) * cons) a) :: t) + | Pow_DMCons_D(cons, b) -> pushRec ((bx (DM.Sum(d.A .* (cons ** b.P) .* log cons)) b) :: t) + | Pow_D_DM(a, b) -> pushRec ((bx (DM.Sum(d.A .* (DM.Pow(a.P, b.P - D.One)) .* b.P)) a) :: (bxm (d.A .* (DM.Pow(a.P, b.P)) * log a.P) b) :: t) + | Pow_D_DMCons(a, cons) -> pushRec ((bx (DM.Sum(d.A .* (DM.Pow(a.P, cons - D.One)) .* cons)) a) :: t) + | Pow_DCons_DM(cons, b) -> pushRec ((bxm (d.A .* (DM.Pow(cons, b.P)) * log cons) b) :: t) + | Atan2_DM_D(a, b) -> let denom = (a.P .* a.P) + (b.P * b.P) in pushRec ((bxm (d.A * b.P ./ denom) a) :: (bx (DM.Sum(d.A .* (-a.P) ./ denom)) b) :: t) + | Atan2_DM_DCons(a, cons) -> pushRec ((bxm (d.A * cons ./ ((a.P .* a.P) + (cons * cons))) a) :: t) + | Atan2_DMCons_D(cons, b) ->pushRec ((bx (DM.Sum(d.A .* (-cons) ./ ((cons .* cons) + (b.P * b.P)))) b) :: t) + | Atan2_D_DM(a, b) -> let denom = (a.P * a.P) + (b.P .* b.P) in pushRec ((bx (DM.Sum(d.A .* b.P ./ denom)) a) :: (bxm (d.A * (-a.P) ./ denom) b) :: t) + | 
Atan2_D_DMCons(a, cons) -> pushRec ((bx (DM.Sum(d.A .* cons ./ ((a.P * a.P) + (cons .* cons)))) a) :: t) + | Atan2_DCons_DM(cons, b) -> pushRec ((bxm (d.A * (-cons) ./ ((cons * cons) + (b.P .* b.P))) b) :: t) + | Log_DM(a) -> pushRec ((bxm (d.A ./ a.P) a) :: t) + | Log10_DM(a) -> pushRec ((bxm (d.A ./ (a.P * N.log10Val)) a) :: t) + | Exp_DM(a) -> pushRec ((bxm (d.A .* d.P) a) :: t) // d.P = exp a.P + | Sin_DM(a) -> pushRec ((bxm (d.A .* cos a.P) a) :: t) + | Cos_DM(a) -> pushRec ((bxm (-d.A .* sin a.P) a) :: t) + | Tan_DM(a) -> let seca = D.One / cos a.P in pushRec ((bxm (d.A .* seca .* seca) a) :: t) + | Neg_DM(a) -> pushRec ((bxm -d.A a) :: t) + | Sqrt_DM(a) -> pushRec ((bxm (d.A ./ (N.two * d.P)) a) :: t) // d.P = sqrt a.P + | Sinh_DM(a) -> pushRec ((bxm (d.A .* cosh a.P) a) :: t) + | Cosh_DM(a) -> pushRec ((bxm (d.A .* sinh a.P) a) :: t) + | Tanh_DM(a) -> let secha = D.One / cosh a.P in pushRec ((bxm (d.A .* secha .* secha) a) :: t) + | Asin_DM(a) -> pushRec ((bxm (d.A ./ sqrt (D.One - (a.P .* a.P))) a) :: t) + | Acos_DM(a) -> pushRec ((bxm (-d.A ./ sqrt (D.One - (a.P .* a.P))) a) :: t) + | Atan_DM(a) -> pushRec ((bxm (d.A ./ (D.One + (a.P .* a.P))) a) :: t) + | Abs_DM(a) -> pushRec ((bxm (d.A .* DM.Sign a.P) a) :: t) | Sign_DM(a) -> pushRec ((bxm DM.Zero a) :: t) | Floor_DM(a) -> pushRec ((bxm DM.Zero a) :: t) | Ceil_DM(a) -> pushRec ((bxm DM.Zero a) :: t) | Round_DM(a) -> pushRec ((bxm DM.Zero a) :: t) - | Transpose_DM(a) -> pushRec ((bxm (DM.Transpose(dA)) a) :: t) - | Make_DM_ofDs(a) -> pushRec (t |> List.append (List.map2 (fun v dd -> (bx v dd)) (dA |> DM.toDV |> DV.toArray |> Array.toList) (a |> Array2D.toArray |> List.ofArray))) + | Transpose_DM(a) -> pushRec ((bxm (DM.Transpose(d.A)) a) :: t) + | Make_DM_ofDs(a) -> pushRec (t |> List.append (List.map2 (fun v dd -> (bx v dd)) (d.A |> DM.toDV |> DV.toArray |> Array.toList) (a |> Array2D.toArray |> List.ofArray))) | Make_DMRows_ofDV(a) -> - dA.GetRows() |> Seq.iter (fun v -> adjoints.[a] <- adjoints.[a] + 
v) + d.A.GetRows() |> Seq.iter (fun v -> a.A <- a.A + v) pushRec ((bxv DV.Zero a) :: t) | Make_DMCols_ofDV(a) -> - dA.GetCols() |> Seq.iter (fun v -> adjoints.[a] <- adjoints.[a] + v) + d.A.GetCols() |> Seq.iter (fun v -> a.A <- a.A + v) pushRec ((bxv DV.Zero a) :: t) - | Make_DMRows_ofDVs(a) -> pushRec (t |> List.append (a |> List.ofArray |> List.mapi (fun i v -> (bxv dA.[i, *] v)))) - | AddItem_DM_D(a, i, j, b) -> pushRec ((bxm dA a) :: (bx (dA.[i, j]) b) :: t) - | AddItem_DM_DCons(a) -> pushRec ((bxm dA a) :: t) - | AddItem_DMCons_D(i, j, b) -> pushRec ((bx dA.[i, j] b) :: t) - | AddSubMatrix_DM_DM(a, i, j, b) -> pushRec ((bxm dA a) :: (bxm (dA.[i..(i + b.Rows - 1), j..(j + b.Cols - 1)]) b) :: t) - | AddSubMatrix_DM_DMCons(a) -> pushRec ((bxm dA a) :: t) - | AddSubMatrix_DMCons_DM(i, j, b) -> pushRec ((bxm (dA.[i..(i + b.Rows - 1), j..(j + b.Cols - 1)]) b) :: t) + | Make_DMRows_ofDVs(a) -> pushRec (t |> List.append (a |> List.ofArray |> List.mapi (fun i v -> (bxv d.A.[i, *] v)))) + | AddItem_DM_D(a, i, j, b) -> pushRec ((bxm d.A a) :: (bx (d.A.[i, j]) b) :: t) + | AddItem_DM_DCons(a) -> pushRec ((bxm d.A a) :: t) + | AddItem_DMCons_D(i, j, b) -> pushRec ((bx d.A.[i, j] b) :: t) + | AddSubMatrix_DM_DM(a, i, j, b) -> pushRec ((bxm d.A a) :: (bxm (d.A.[i..(i + b.Rows - 1), j..(j + b.Cols - 1)]) b) :: t) + | AddSubMatrix_DM_DMCons(a) -> pushRec ((bxm d.A a) :: t) + | AddSubMatrix_DMCons_DM(i, j, b) -> pushRec ((bxm (d.A.[i..(i + b.Rows - 1), j..(j + b.Cols - 1)]) b) :: t) | Slice_DM(a, i, j) -> - adjoints.[a] <- DM.AddSubMatrix(adjoints.[a], i, j, dA) + a.A <- DM.AddSubMatrix(a.A, i, j, d.A) pushRec ((bxm DM.Zero a) :: t) - | RowMatrix_DV(a) -> pushRec ((bxv (dA.[0, *]) a) :: t) - | AddDiagonal_DM_DV(a, b) -> pushRec ((bxm dA a) :: (bxv (DM.Diagonal(dA)) b) :: t) - | AddDiagonal_DM_DVCons(a) -> pushRec ((bxm dA a) :: t) - | AddDiagonal_DMCons_DV(b) -> pushRec ((bxv (DM.Diagonal(dA)) b) :: t) - | ReshapeCopy_DV_DM(a) -> pushRec ((bxv (DM.ReshapeToDV(dA)) a) :: t) - | 
Inverse_DM(a) -> let dpt = DM.Transpose(d.P) in pushRec ((bxm (-dpt * dA * dpt) a) :: t) // d.P = DM.Inverse(a.P) - | ReLU_DM(a) -> pushRec ((bxm (dA .* ((DM.Sign(a.P) + N.one) / N.two)) a) :: t) - | Sigmoid_DM(a) -> pushRec ((bxm (dA .* d.P .* (N.one - d.P)) a) :: t) // d.P = DM.Sigmoid(a.P) + | RowMatrix_DV(a) -> pushRec ((bxv (d.A.[0, *]) a) :: t) + | AddDiagonal_DM_DV(a, b) -> pushRec ((bxm d.A a) :: (bxv (DM.Diagonal(d.A)) b) :: t) + | AddDiagonal_DM_DVCons(a) -> pushRec ((bxm d.A a) :: t) + | AddDiagonal_DMCons_DV(b) -> pushRec ((bxv (DM.Diagonal(d.A)) b) :: t) + | ReshapeCopy_DV_DM(a) -> pushRec ((bxv (DM.ReshapeToDV(d.A)) a) :: t) + | Inverse_DM(a) -> let dpt = DM.Transpose(d.P) in pushRec ((bxm (-dpt * d.A * dpt) a) :: t) // d.P = DM.Inverse(a.P) + | ReLU_DM(a) -> pushRec ((bxm (d.A .* ((DM.Sign(a.P) + N.one) / N.two)) a) :: t) + | Sigmoid_DM(a) -> pushRec ((bxm (d.A .* d.P .* (N.one - d.P)) a) :: t) // d.P = DM.Sigmoid(a.P) | _ -> pushRec t else pushRec t | _ -> pushRec t | _ -> pushRec t - let initialv = - match v with - | :? D as v -> bxdelta (X v) - | :? DV as v -> bxdelta (XV v) - | :? DM as v -> bxdelta (XM v) - | _ -> failwith "invalid dobj" - pushRec [(initialv, d)] + pushRec [(v, d)] /// Forward and reverse differentiation operations module (automatically opened) [] module DiffOps = - let inline computeAdjoints (d: 'T :> dobj) = - let adjoints = Adjoints() - let one = LanguagePrimitives.GenericOne<'T> - reverseProp adjoints one d - adjoints - /// Original value and first derivative of a scalar-to-scalar function `f`, at point `x`. Forward AD. - let inline diff' f x = - x |> makeForward GlobalTagger.Next (D.One) |> f |> primalTangent + let diff' (f: D -> D) x = + let dx = makeForward GlobalTagger.Next (D.One) x + dx |> f |> primalTangent /// First derivative of a scalar-to-scalar function `f`, at point `x`. Forward AD. 
- let inline diff f x = diff' f x |> snd + let diff (f: D -> D) x = diff' f x |> snd /// Second derivative of a scalar-to-scalar function `f`, at point `x`. Forward AD. - let inline diff2 f x = + let diff2 (f: D -> D) x : D = diff (diff f) x /// Original value, first derivative, and second derivative of a scalar-to-scalar function `f`, at point `x`. Forward AD. - let inline diff2'' f x = + let diff2'' (f: D -> D) x : D * D * D = let v, d = diff' f x let d2 = diff2 f x (v, d, d2) /// Original value and second derivative of a scalar-to-scalar function `f`, at point `x`. Forward AD. - let inline diff2' f x = + let diff2' (f: D -> D) x : D * D = diff2'' f x |> drop2Of3 /// `n`-th derivative of a scalar-to-scalar function `f`, at point `x`. Forward AD. - let inline diffn n f x = + let diffn n (f: D -> D) x : D = if n < 0 then ErrorMessages.InvalidArgDiffn() elif n = 0 then x |> f else @@ -3980,57 +3911,59 @@ module DiffOps = x |> d n f /// Original value and `n`-th derivative of a scalar-to-scalar function `f`, at point `x`. Forward AD. - let inline diffn' n f x = + let diffn' n (f: D -> D) x : D * D = (x |> f, diffn n f x) /// Original value and gradient of a vector-to-scalar function `f`, at point `x`. Reverse AD. - let inline grad' f x = + let grad' (f: DV -> D) x : D * DV = let xa = x |> makeReverse GlobalTagger.Next let z:D = f xa - let adjoints = computeAdjoints z - (z |> primal, xa |> adjoint adjoints ) + z |> reverseProp D.One + (z |> primal, xa |> adjoint) /// Gradient of a vector-to-scalar function `f`, at point `x`. Reverse AD. - let inline grad f x = + let grad (f: DV -> D) x : DV = grad' f x |> snd + /// Original value and gradient-vector product (directional derivative) of a vector-to-scalar function `f`, at point `x`, along vector `v`. Forward AD. 
+ let gradv' (f: DV -> D) (x: DV) (v: DV) : D * D = + let dvx = makeForward GlobalTagger.Next v x + dvx |> f |> primalTangent + + /// Gradient-vector product (directional derivative) of a vector-to-scalar function `f`, at point `x`, along vector `v`. Forward AD. + let gradv (f: DV -> D) x v : D = + gradv' f x v |> snd + /// Original value and Jacobian-vector product of a vector-to-vector function `f`, at point `x`, along vector `v`. Forward AD. - let inline jacobianv' f x v = + let jacobianv' (f: DV -> DV) x v : DV * DV = x |> makeForward GlobalTagger.Next v |> f |> primalTangent /// Jacobian-vector product of a vector-to-vector function `f`, at point `x`, along vector `v`. Forward AD. - let inline jacobianv f x v = + let jacobianv (f: DV -> DV) x v : DV = jacobianv' f x v |> snd - /// Gradient-vector product (directional derivative) of a vector-to-scalar function `f`, at point `x`, along vector `v`. Forward AD. - let inline gradv f x v = jacobianv f x v - - /// Original value and gradient-vector product (directional derivative) of a vector-to-scalar function `f`, at point `x`, along vector `v`. Forward AD. - let inline gradv' f x v = jacobianv' f x v - /// Original value and a function for evaluating the transposed Jacobian-vector product of a vector-to-vector function `f`, at point `x`. Of the returned pair, the first is the original value of function `f` at point `x` (the result of the forward pass of the reverse mode AD) and the second is a function (the reverse evaluator) that can compute the transposed Jacobian-vector product many times along many different vectors (performing a new reverse pass of reverse mode AD, with the given vector, without repeating the forward pass). Reverse AD. 
- let inline jacobianTv'' (f:'a->'b) (x:'a) = + let jacobianTv'' (f: DV -> DV) (x:DV) = let xa = x |> makeReverse GlobalTagger.Next let z = f xa let r1 = z |> primal let r2 = - fun (v:'b) -> - let adjoints = Adjoints() - z |> reverseProp adjoints v - xa |> adjoint adjoints + fun (v:DV) -> + z |> reverseProp v + xa |> adjoint (r1, r2) /// Original value and transposed Jacobian-vector product of a vector-to-vector function `f`, at point `x`, along vector `v`. Reverse AD. - let inline jacobianTv' f x v = + let jacobianTv' (f: DV -> DV) x v = let r1, r2 = jacobianTv'' f x (r1, r2 v) /// Transposed Jacobian-vector product of a vector-to-vector function `f`, at point `x`, along vector `v`. Reverse AD. - let inline jacobianTv f x v = + let jacobianTv (f: DV -> DV) x v = jacobianTv' f x v |> snd /// Original value and Jacobian of a vector-to-vector function `f`, at point `x`. Forward or reverse AD, depending on input and output dimensions. - let inline jacobian' f (x:DV) = + let jacobian' (f: DV -> DV) (x:DV) : DV * DM = let o:DV = x |> f |> primal if x.Length > o.Length then let r = jacobianTv f x @@ -4038,88 +3971,87 @@ module DiffOps = else (o, Array.init x.Length (fun i -> jacobianv f x (DV.standardBasis x.Length i)) |> DM.ofCols) - /// Jacobian of a vector-to-vector function `f`, at point `x`. Forward or reverse AD, depending on input and output dimensions. - let inline jacobian f x = + let jacobian (f: DV -> DV) x : DM = jacobian' f x |> snd /// Original value and transposed Jacobian of a vector-to-vector function `f`, at point `x`. Forward or reverse AD, depending on input and output dimensions. - let inline jacobianT' f x = + let jacobianT' (f: DV -> DV) x = jacobian' f x |> fun (r, j) -> (r, DM.transpose j) /// Transposed Jacobian of a vector-to-vector function `f`, at point `x`. Forward or reverse AD, depending on input and output dimensions. 
- let inline jacobianT f x = + let jacobianT (f: DV -> DV) x : DM = jacobianT' f x |> snd /// Gradient and Hessian of a vector-to-scalar function `f`, at point `x`. Forward-on-reverse AD. - let inline gradhessian f x = + let gradhessian (f: DV -> D) x : DV * DM = jacobian' (grad f) x /// Original value, gradient, and Hessian of a vector-to-scalar function `f`, at point `x`. Forward-on-reverse AD. - let inline gradhessian' f x = + let gradhessian' (f: DV -> D) x : D * DV * DM = let g, h = gradhessian f x (x |> f , g, h) /// Hessian of a vector-to-scalar function `f`, at point `x`. Forward-on-reverse AD. - let inline hessian f x = + let hessian (f: DV -> D) x : DM = jacobian (grad f) x /// Original value and Hessian of a vector-to-scalar function `f`, at point `x`. Forward-on-reverse AD. - let inline hessian' f x = + let hessian' (f: DV -> D) x : D * DM = (x |> f, hessian f x) /// Original value, gradient-vector product (directional derivative), and Hessian-vector product of a vector-to-scalar function `f`, at point `x`, along vector `v`. Reverse-on-forward AD. - let inline gradhessianv' f x v = - let gv, hv = grad' (fun xx -> jacobianv f xx v) x + let gradhessianv' (f: DV -> D) x v = + let gv, hv = grad' (fun xx -> gradv f xx v) x (x |> f, gv, hv) /// Gradient-vector product (directional derivative) and Hessian-vector product of a vector-to-scalar function `f`, at point `x`, along vector `v`. Reverse-on-forward AD. - let inline gradhessianv f x v = + let gradhessianv (f: DV -> D) x v : D * DV = gradhessianv' f x v |> drop1Of3 /// Original value and Hessian-vector product of a vector-to-scalar function `f`, at point `x`, along vector `v`. Reverse-on-forward AD. - let inline hessianv' f x v = + let hessianv' (f: DV -> D) x v = gradhessianv' f x v |> drop2Of3 /// Hessian-vector product of a vector-to-scalar function `f`, at point `x`, along vector `v`. Reverse-on-forward AD. 
- let inline hessianv f x v = + let hessianv (f: DV -> D) x v : DV = hessianv' f x v |> snd /// Original value and Laplacian of a vector-to-scalar function `f`, at point `x`. Reverse-on-forward AD. - let inline laplacian' f x = // TODO: reimplement faster + let laplacian' (f: DV -> D) x : D * D = // TODO: reimplement faster let v, h = hessian' f x (v, DM.trace h) /// Laplacian of a vector-to-scalar function `f`, at point `x`. Reverse-on-forward AD. - let inline laplacian f x = + let laplacian (f: DV -> D) x : D = laplacian' f x |> snd /// Original value and curl of a vector-to-vector function `f`, at point `x`. Supported only for functions with a three-by-three Jacobian matrix. Forward AD. - let inline curl' f x = + let curl' (f: DV -> DV) x = let v, j = jacobianT' f x if (j.Rows, j.Cols) <> (3, 3) then ErrorMessages.InvalidArgCurl() v, toDV [|j.[1, 2] - j.[2, 1]; j.[2, 0] - j.[0, 2]; j.[0, 1] - j.[1, 0]|] /// Curl of a vector-to-vector function `f`, at point `x`. Supported only for functions with a three-by-three Jacobian matrix. Forward AD. - let inline curl f x = + let curl (f: DV -> DV) x : DV = curl' f x |> snd /// Original value and divergence of a vector-to-vector function `f`, at point `x`. Defined only for functions with a square Jacobian matrix. Forward AD. - let inline div' f x = + let div' (f: DV -> DV) x = let v, j = jacobianT' f x if j.Rows <> j.Cols then ErrorMessages.InvalidArgDiv() v, DM.trace j /// Divergence of a vector-to-vector function `f`, at point `x`. Defined only for functions with a square Jacobian matrix. Forward AD. - let inline div f x = + let div (f: DV -> DV) x : D = div' f x |> snd /// Original value, curl, and divergence of a vector-to-vector function `f`, at point `x`. Supported only for functions with a three-by-three Jacobian matrix. Forward AD. 
- let inline curldiv' f x = + let curldiv' (f: DV -> DV) x = let v, j = jacobianT' f x if (j.Rows, j.Cols) <> (3, 3) then ErrorMessages.InvalidArgCurlDiv() v, toDV [|j.[1, 2] - j.[2, 1]; j.[2, 0] - j.[0, 2]; j.[0, 1] - j.[1, 0]|], DM.trace j /// Curl and divergence of a vector-to-vector function `f`, at point `x`. Supported only for functions with a three-by-three Jacobian matrix. Forward AD. - let inline curldiv f x = + let curldiv (f: DV -> DV) x : DV * D = curldiv' f x |> drop1Of3 diff --git a/src/DiffSharp/Backend.OpenBLAS.fs b/src/DiffSharp/Backend.OpenBLAS.fs index 175ace22e..10bb3c1ae 100644 --- a/src/DiffSharp/Backend.OpenBLAS.fs +++ b/src/DiffSharp/Backend.OpenBLAS.fs @@ -715,7 +715,8 @@ module OpenBLAS = let xl2 = Array2D.length2 x let yl1 = Array2D.length1 y let yl2 = Array2D.length2 y - if (xl1 <> yl1) || (xl2 <> yl2) then + if xl1 * xl2 = 0 then () + elif (xl1 <> yl1) || (xl2 <> yl2) then ErrorMessages.InvalidArgMM() else Stats.InplaceOp(yl1 * yl2) diff --git a/src/DiffSharp/DiffSharp.fsproj b/src/DiffSharp/DiffSharp.fsproj index b2764a099..2ae7d082f 100644 --- a/src/DiffSharp/DiffSharp.fsproj +++ b/src/DiffSharp/DiffSharp.fsproj @@ -3,6 +3,33 @@ netstandard2.0 /warnon:1182 x64 + 0.8.4 + 0.8.4-beta + Atılım Güneş Baydin,Barak A. Pearlmutter + BSD-2-Clause + http://diffsharp.github.io/DiffSharp/img/diffsharp-logo.png + https://github.com/DiffSharp/DiffSharp + Codestin Search App + DiffSharp is an automatic differentiation (AD) library. + +AD allows exact and efficient calculation of derivatives, by systematically invoking the chain rule of calculus at the elementary operator level during program execution. AD is different from numerical differentiation, which is prone to truncation and round-off errors, and symbolic differentiation, which is affected by expression swell and cannot fully handle algorithmic control flow. 
+ +Using the DiffSharp library, derivative calculations (gradients, Hessians, Jacobians, directional derivatives, and matrix-free Hessian- and Jacobian-vector products) can be incorporated with minimal change into existing algorithms. Diffsharp supports nested forward and reverse AD up to any level, meaning that you can compute exact higher-order derivatives or differentiate functions that are internally making use of differentiation. Please see the API Overview page for a list of available operations. + +The library is under active development by Atılım Güneş Baydin and Barak A. Pearlmutter mainly for research applications in machine learning, as part of their work at the Brain and Computation Lab, Hamilton Institute, National University of Ireland Maynooth. + +DiffSharp is implemented in the F# language and can be used from C# and the other languages running on .NET Core, Mono, or the .NET Framework; targeting the 64 bit platform. It is tested on Linux and Windows. We are working on interfaces/ports to other languages. + Please visit + +https://github.com/DiffSharp/DiffSharp/releases + +for the latest release notes. + Copyright (c) 2016- University of Oxford (Atilim Gunes Baydin) +Copyright (c) 2017- Microsoft Research, Cambridge, UK (Don Syme) +Copyright (c) 2014- National University of Ireland Maynooth (Barak A. 
Pearlmutter) +Copyright (c) 2014-2016 National University of Ireland Maynooth (Atilim Gunes Baydin) + Differentiation;Automatic;Symbolic;Numerical;Optimization;Machine Learning;FSharp;F# + @@ -19,21 +46,35 @@ + true + runtimes\win\native\ PreserveNewest + true + runtimes\win\native\ PreserveNewest + true + runtimes\win\native\ PreserveNewest + true + runtimes\win\native\ PreserveNewest PreserveNewest + true + PreserveNewest + + + true + build\ PreserveNewest diff --git a/src/DiffSharp/DiffSharp.targets b/src/DiffSharp/DiffSharp.targets new file mode 100644 index 000000000..29d32fe58 --- /dev/null +++ b/src/DiffSharp/DiffSharp.targets @@ -0,0 +1,10 @@ + + + + + + %(RecursiveDir)%(FileName)%(Extension) + PreserveNewest + + + \ No newline at end of file diff --git a/src/DiffSharp/Interop.Float32.fs b/src/DiffSharp/Interop.Float32.fs index 8961af29b..4ac69bf9c 100644 --- a/src/DiffSharp/Interop.Float32.fs +++ b/src/DiffSharp/Interop.Float32.fs @@ -32,7 +32,7 @@ type D(x:ADD) = match d with | AD.D(p) -> sprintf "D %A" p | AD.DF(p, t, _) -> sprintf "DF (%A, %A)" (s p) (s t) - | AD.DR(p, op, _, _) -> sprintf "DR (%A, %A)" (s p) (op.ToString()) + | AD.DR(p, op, _, _, _) -> sprintf "DR (%A, %A)" (s p) (op.ToString()) s (d.toADD()) static member op_Explicit(d:D):AD.number = ADD.op_Explicit (d.toADD()) static member op_Implicit(a:AD.number):D = D(a) @@ -126,7 +126,7 @@ and DV(v:ADDV) = match d with | AD.DV(p) -> sprintf "DV %A" p | AD.DVF(p, t, _) -> sprintf "DVF (%A, %A)" (s p) (s t) - | AD.DVR(p, op, _, _) -> sprintf "DVR (%A, %A)" (s p) (op.ToString()) + | AD.DVR(p, op, _, _, _) -> sprintf "DVR (%A, %A)" (s p) (op.ToString()) s (d.toADDV()) member d.Visualize() = d.toADDV().Visualize() static member op_Explicit(d:DV):AD.number[] = ADDV.op_Explicit(d.toADDV()) @@ -239,7 +239,7 @@ and DM(m:ADDM) = match d with | AD.DM(p) -> sprintf "DM %A" p | AD.DMF(p, t, _) -> sprintf "DMF (%A, %A)" (s p) (s t) - | AD.DMR(p, op, _, _) -> sprintf "DMR (%A, %A)" (s p) (op.ToString()) + | 
AD.DMR(p, op, _, _, _) -> sprintf "DMR (%A, %A)" (s p) (op.ToString()) s (d.toADDM()) member d.Visualize() = d.toADDM().Visualize() static member op_Explicit(d:DM):AD.number[, ] = ADDM.op_Explicit(d.toADDM()) @@ -351,16 +351,6 @@ and DM(m:ADDM) = static member Normalize (a:DM) = DM(ADDM.Normalize(a.toADDM())) static member Standardize (a:DM) = DM(ADDM.Standardize(a.toADDM())) -and ADAdjoints = AD.Adjoints - -and Adjoints() = - let m = ADAdjoints() - member internal this.toADAdjoints() = m - member this.Item with get (d:D) = m.[d.toADD()] |> D.ADDtoD - member this.Item with get (d:DV) = m.[d.toADDV()] |> DV.ADDVtoDV - member this.Item with get (d:DM) = m.[d.toADDM()] |> DM.ADDMtoDM - - /// Nested forward and reverse mode automatic differentiation module type AD = diff --git a/src/DiffSharp/Interop.Float64.fs b/src/DiffSharp/Interop.Float64.fs index c7169dc31..f9aa3a7a2 100644 --- a/src/DiffSharp/Interop.Float64.fs +++ b/src/DiffSharp/Interop.Float64.fs @@ -32,7 +32,7 @@ type D(x:ADD) = match d with | AD.D(p) -> sprintf "D %A" p | AD.DF(p, t, _) -> sprintf "DF (%A, %A)" (s p) (s t) - | AD.DR(p, op, _, _) -> sprintf "DR (%A, %A)" (s p) (op.ToString()) + | AD.DR(p, op, _, _, _) -> sprintf "DR (%A, %A)" (s p) (op.ToString()) s (d.toADD()) static member op_Explicit(d:D):AD.number = ADD.op_Explicit (d.toADD()) static member op_Implicit(a:AD.number):D = D(a) @@ -126,7 +126,7 @@ and DV(v:ADDV) = match d with | AD.DV(p) -> sprintf "DV %A" p | AD.DVF(p, t, _) -> sprintf "DVF (%A, %A)" (s p) (s t) - | AD.DVR(p, op, _, _) -> sprintf "DVR (%A, %A)" (s p) (op.ToString()) + | AD.DVR(p, op, _, _, _) -> sprintf "DVR (%A, %A)" (s p) (op.ToString()) s (d.toADDV()) member d.Visualize() = d.toADDV().Visualize() static member op_Explicit(d:DV):AD.number[] = ADDV.op_Explicit(d.toADDV()) @@ -239,7 +239,7 @@ and DM(m:ADDM) = match d with | AD.DM(p) -> sprintf "DM %A" p | AD.DMF(p, t, _) -> sprintf "DMF (%A, %A)" (s p) (s t) - | AD.DMR(p, op, _, _) -> sprintf "DMR (%A, %A)" (s p) 
(op.ToString()) + | AD.DMR(p, op, _, _, _) -> sprintf "DMR (%A, %A)" (s p) (op.ToString()) s (d.toADDM()) member d.Visualize() = d.toADDM().Visualize() static member op_Explicit(d:DM):AD.number[, ] = ADDM.op_Explicit(d.toADDM()) @@ -351,16 +351,6 @@ and DM(m:ADDM) = static member Normalize (a:DM) = DM(ADDM.Normalize(a.toADDM())) static member Standardize (a:DM) = DM(ADDM.Standardize(a.toADDM())) -and ADAdjoints = AD.Adjoints - -and Adjoints() = - let m = ADAdjoints() - member internal this.toADAdjoints() = m - member this.Item with get (d:D) = m.[d.toADD()] |> D.ADDtoD - member this.Item with get (d:DV) = m.[d.toADDV()] |> DV.ADDVtoDV - member this.Item with get (d:DM) = m.[d.toADDM()] |> DM.ADDMtoDM - - /// Nested forward and reverse mode automatic differentiation module type AD = diff --git a/tests/DiffSharp.Tests/AD.Float32.fs b/tests/DiffSharp.Tests/AD.Float32.fs index 06b4cf9b8..89c36177b 100644 --- a/tests/DiffSharp.Tests/AD.Float32.fs +++ b/tests/DiffSharp.Tests/AD.Float32.fs @@ -12,8 +12,7 @@ open DiffSharp.Util open DiffSharp.Tests open DiffSharp.AD.Float32 - -[] +(*[] let ``AD.32.F.D.FixedPoint``() = let g (a:D) (b:D) = (a + b / a) / (D 2.f) let p, t = jacobianv' (D.FixedPoint g (D 1.2f)) (D 25.f) (D 1.f) @@ -23,4 +22,41 @@ let ``AD.32.F.D.FixedPoint``() = let ``AD.32.R.D.FixedPoint``() = let g (a:D) (b:D) = (a + b / a) / (D 2.f) let p, t = jacobianTv' (D.FixedPoint g (D 1.2f)) (D 25.f) (D 1.f) - Util.(=~)(p, D 5.f) && Util.(=~)(t, D 0.1f) + Util.(=~)(p, D 5.f) && Util.(=~)(t, D 0.1f)*) + +[] +let ``Gradient descent``() = + + let minimize (f:DV->D) (x0:DV) = + let eta = 0.01f + let mutable W = x0 + for _ in [0..10] do + let L,g = grad' f W + W <- W - eta*g + + let lossFunction (w:DV) = + let x = toDM [[1.0; 0.0]] + let Wg = w.[0..3] |> DM.ofDV 2 + let g = (x*Wg) + cos g.[0,0] + + minimize lossFunction (DV.create 5 1.0f) //Smoke test + + +[] +let ``Gradient descent (with arrays)``() = + + let minimize (f:DV->D) (x0:DV) = + let eta = 0.01f + let mutable W = 
x0 + for _ in [0..10] do + let L,g = grad' f W + W <- W - eta*g + + let n = 5 + let lossFunction (w:DV) = + let x = DM.init n n (fun i j -> w.[n*i+j]) + let x' = x.GetSlice(None, None, None, None) + cos x'.[0,0] + + minimize lossFunction (DV.create (n*n) 1.0f) //Smoke test \ No newline at end of file diff --git a/tests/DiffSharp.Tests/Script.fsx b/tests/DiffSharp.Tests/Script.fsx index 0bf0a0ce7..dc0413dea 100644 --- a/tests/DiffSharp.Tests/Script.fsx +++ b/tests/DiffSharp.Tests/Script.fsx @@ -1,4 +1,4 @@ -#r "../../src/DiffSharp/bin/Debug/DiffSharp.dll" +#r "../../src/DiffSharp/bin/Debug/netstandard2.0/DiffSharp.dll" open DiffSharp.AD.Float32 open DiffSharp.Config diff --git a/tests/Dsbench/Program.fs b/tests/Dsbench/Program.fs index dd7ec9b13..7d9eb59ff 100644 --- a/tests/Dsbench/Program.fs +++ b/tests/Dsbench/Program.fs @@ -356,10 +356,6 @@ let main argv = let run_grad_AD = duration "grad (auto)" n (fun () -> DiffSharp.AD.Float64.DiffOps.grad fvsD xvD) let run_grad_N = duration "grad (numeric)" n (fun () -> DiffSharp.Numerical.Float64.DiffOps.grad fvs xv) - let run_computeAdjoints_AD = duration "computeAdjoints (auto)" n (fun () -> DiffSharp.AD.Float64.DiffOps.computeAdjoints (fvsD (DiffSharp.AD.Float64.DOps.makeReverse 100u xvD))) - - let run_computeAdjoints_AD_m = duration "computeAdjoints (auto)" n (fun () -> DiffSharp.AD.Float64.DiffOps.computeAdjoints (fmsD (DiffSharp.AD.Float64.DOps.makeReverse 100u xmD))) - let run_gradv_AD = duration "gradv (auto)" n (fun () -> DiffSharp.AD.Float64.DiffOps.gradv fvsD xvD vvD) let run_gradv_N = duration "gradv (numeric)" n (fun () -> DiffSharp.Numerical.Float64.DiffOps.gradv fvs xv vv) @@ -451,13 +447,13 @@ let main argv = let k v () = v let K (vl: Lazy<_>) () = vl.Force() - let opNames = ["diff"; "diff2"; "diffn"; "computeAdjointsV"; "computeAdjointsM"; "gradV"; "gradv"; "hessian"; "hessianv"; "gradhessian"; "gradhessianv"; "laplacian"; "jacobian"; "jacobianv"; "jacobianT"; "jacobianTv"] - let row_originals = [K fss; K 
fss; K fss; K fvs; K fms; K fvs; K fvs; K fvs; K fvs; K fvs; K fvs; K fvs; K fvv; K fvv; K fvv; K fvv] - let row_originalsD = [K fssD; K fssD; K fssD; K fvsD; K fmsD; K fvsD; K fvsD; K fvsD; K fvsD; K fvsD; K fvsD; K fvsD; K fvvD; K fvvD; K fvvD; K fvvD] - let row_AD = [run_diff_AD; run_diff2_AD; run_diffn_AD; run_computeAdjoints_AD; run_computeAdjoints_AD_m; run_grad_AD; run_gradv_AD; run_hessian_AD; run_hessianv_AD; run_gradhessian_AD; run_gradhessianv_AD; run_laplacian_AD; run_jacobian_AD; run_jacobianv_AD; run_jacobianT_AD; run_jacobianTv_AD] - let row_N = [run_diff_N; run_diff2_N; run_diffn_N; k 0.0; k 0.0; run_grad_N; run_gradv_N; run_hessian_N; run_hessianv_N; run_gradhessian_N; run_gradhessianv_N; run_laplacian_N; run_jacobian_N; run_jacobianv_N; run_jacobianT_N; run_jacobianTv_N] - let row'_AD = [run_diff'_AD; run_diff2'_AD; run_diffn'_AD; k 0.0; k 0.0; run_grad'_AD; run_gradv'_AD; run_hessian'_AD; run_hessianv'_AD; run_gradhessian'_AD; run_gradhessianv'_AD; run_laplacian'_AD; run_jacobian'_AD; run_jacobianv'_AD; run_jacobianT'_AD; run_jacobianTv'_AD] - let row'_N = [run_diff'_N; run_diff2'_N; run_diffn'_N; k 0.0; k 0.0; run_grad'_N; run_gradv'_N; run_hessian'_N; run_hessianv'_N; run_gradhessian'_N; run_gradhessianv'_N; run_laplacian'_N; run_jacobian'_N; run_jacobianv'_N; run_jacobianT'_N; run_jacobianTv'_N] + let opNames = ["diff"; "diff2"; "diffn"; "gradV"; "gradv"; "hessian"; "hessianv"; "gradhessian"; "gradhessianv"; "laplacian"; "jacobian"; "jacobianv"; "jacobianT"; "jacobianTv"] + let row_originals = [K fss; K fss; K fss; K fvs; K fvs; K fvs; K fvs; K fvs; K fvs; K fvs; K fvv; K fvv; K fvv; K fvv] + let row_originalsD = [K fssD; K fssD; K fssD; K fvsD; K fvsD; K fvsD; K fvsD; K fvsD; K fvsD; K fvsD; K fvvD; K fvvD; K fvvD; K fvvD] + let row_AD = [run_diff_AD; run_diff2_AD; run_diffn_AD; run_gradv_AD; run_hessian_AD; run_hessianv_AD; run_gradhessian_AD; run_gradhessianv_AD; run_laplacian_AD; run_jacobian_AD; run_jacobianv_AD; run_jacobianT_AD; 
run_jacobianTv_AD] + let row_N = [run_diff_N; run_diff2_N; run_diffn_N; run_grad_N; run_gradv_N; run_hessian_N; run_hessianv_N; run_gradhessian_N; run_gradhessianv_N; run_laplacian_N; run_jacobian_N; run_jacobianv_N; run_jacobianT_N; run_jacobianTv_N] + let row'_AD = [run_diff'_AD; run_diff2'_AD; run_diffn'_AD; run_grad'_AD; run_gradv'_AD; run_hessian'_AD; run_hessianv'_AD; run_gradhessian'_AD; run_gradhessianv'_AD; run_laplacian'_AD; run_jacobian'_AD; run_jacobianv'_AD; run_jacobianT'_AD; run_jacobianTv'_AD] + let row'_N = [run_diff'_N; run_diff2'_N; run_diffn'_N; run_grad'_N; run_gradv'_N; run_hessian'_N; run_hessianv'_N; run_gradhessian'_N; run_gradhessianv'_N; run_laplacian'_N; run_jacobian'_N; run_jacobianv'_N; run_jacobianT'_N; run_jacobianTv'_N] if Set.ofList opNames <> Set.ofList opArgNames then failwith "opNames <> opArgNames - fix the benchmark program please" let doOp op = List.contains op opsToRun