Add threshold to ControlCostArea; bug fix in sum over controls #15
dkweiss31 merged 12 commits into dkweiss31:main
qontrol/cost.py
I see that you can now call prefactor with an array, making the vmap unnecessary. Nice! (I don't recall if that was always the case...)
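A minimal sketch of why the vmap becomes unnecessary once prefactor accepts an array. If a function is built only from element-wise jax.numpy operations, calling it once on the full array of save times gives the same values as mapping it over each time with jax.vmap. The prefactor function below is a hypothetical stand-in, not the dynamiqs implementation.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in for a time-dependent prefactor f(t). It uses only
# element-wise jnp operations, so it broadcasts over an array of times.
def prefactor(t):
    return jnp.cos(2.0 * jnp.pi * t) * jnp.exp(-t)


tsave = jnp.linspace(0.0, 1.0, 11)

# Mapping the scalar function over the time axis explicitly...
values_vmap = jax.vmap(prefactor)(tsave)
# ...gives the same values as a single call on the whole array.
values_direct = prefactor(tsave)

print(jnp.allclose(values_vmap, values_direct))  # True
```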
qontrol/cost.py
Shouldn't we make the same change in ControlCostNorm? Maybe this logic wants to be in the ControlCost _evaluate_at_tsave function?
Of course! That makes sense.
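A minimal sketch of the refactor suggested above, using hypothetical class names (ControlCostSketch, NormCostSketch, AreaCostSketch) rather than qontrol's actual classes. The dt weighting lives once in the parent class, and each subclass only decides what to integrate. The left Riemann sum and the threshold semantics in AreaCostSketch (penalize only the part of each control's area that exceeds the threshold) are assumptions for illustration.

```python
import jax.numpy as jnp


class ControlCostSketch:
    """Hypothetical base class: the dt weighting is implemented once here."""

    def __init__(self, weight: float = 1.0):
        self.weight = weight

    @staticmethod
    def _integrate(values: jnp.ndarray, tsave: jnp.ndarray) -> jnp.ndarray:
        # Left Riemann sum over the last (time) axis: sum_k values[..., k] * dt_k.
        dt = jnp.diff(tsave)
        return jnp.sum(values[..., :-1] * dt, axis=-1)


class NormCostSketch(ControlCostSketch):
    def evaluate(self, controls: jnp.ndarray, tsave: jnp.ndarray) -> jnp.ndarray:
        # controls has shape (n_controls, n_times): integrate |u|^2 in time,
        # then sum over controls.
        energies = self._integrate(jnp.abs(controls) ** 2, tsave)
        return self.weight * jnp.sum(energies)


class AreaCostSketch(ControlCostSketch):
    def __init__(self, weight: float = 1.0, threshold: float = 0.0):
        super().__init__(weight)
        self.threshold = threshold

    def evaluate(self, controls: jnp.ndarray, tsave: jnp.ndarray) -> jnp.ndarray:
        # Absolute area of each control; only the excess above the
        # threshold is penalized, then summed over controls.
        areas = jnp.abs(self._integrate(controls, tsave))
        excess = jnp.maximum(areas - self.threshold, 0.0)
        return self.weight * jnp.sum(excess)


tsave = jnp.linspace(0.0, 1.0, 101)
controls = jnp.ones((2, 101))  # two constant controls with unit area each
norm_cost = NormCostSketch().evaluate(controls, tsave)
area_cost = AreaCostSketch(threshold=0.5).evaluate(controls, tsave)
```

Keeping the integration in the parent means that a fix to the dt factor, or a later switch to a trapezoidal rule, applies to every cost at once.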
import dynamiqs as dq
import jax
import jax.numpy as jnp
import numpy as np
I think you'll want to undo the changes that removed numpy. You'll notice that all of the calls to numpy are outside of jitted regions. Calls to jax or jax.numpy outside of jitted regions just slow you down, AFAIK.
I agree this would be the case on a CPU. Is this also true on a GPU?
I presume numpy commands are always run on the CPU. If we use numpy, the process would be bottlenecked by data transfer back and forth between a GPU and a CPU.
However, I presume jax.numpy could perform all operations on a GPU directly.
I guess we should compare the slowdown due to jax.numpy versus that due to CPU-GPU data transfer.
I think the bottleneck, as you noticed, will be saving data and updating the plots, both of which must happen on a CPU. So there is no harm in using numpy in that context. Again, if the jax code is not in a jitted region, then no matter where you run it (CPU, GPU, TPU), it will be slower than pure numpy (since the code doesn't get compiled).
That makes sense. Will undo the np->jnp changes.
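A minimal sketch of the split discussed in this thread, assuming a toy loss and a hypothetical step function rather than qontrol's optimizer. The whole optimization step is jitted, so its jax.numpy work is compiled and stays on the device; the per-iteration bookkeeping (standing in for saving data and updating plots) pulls results back to the host and uses plain numpy.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def step(x):
    # Hot path: loss, gradient and update are compiled together by XLA and
    # run on the default device (CPU, GPU or TPU).
    loss, grad = jax.value_and_grad(lambda y: jnp.sum(jnp.sin(y) ** 2))(x)
    return x - 0.1 * grad, loss


x = jnp.linspace(0.0, 1.0, 8)
history = []  # host-side bookkeeping
for _ in range(5):
    x, loss = step(x)
    # float() transfers the scalar to the host; everything after this point
    # never enters a jitted region, so plain numpy is the natural choice.
    history.append(float(loss))

losses = np.asarray(history)
print(type(losses).__name__, losses.shape)  # ndarray (5,)
```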
import jax
import matplotlib.pyplot as plt
import numpy as np
In this file as well, we shouldn't be calling jnp.
dkweiss31 left a comment
Looks good so far! See my comments: I think the dt logic wants to be in ControlCost, and you'll want to restore most or all calls to numpy, since those calls are outside of jitted regions.
Also, for all of these PRs, you'll want to run "task format" and "task lint" in the terminal. The first call autoformats the code, and the second call will optimize some imports but also tell you various things you need to fix (lines can't be longer than 88 characters, stray whitespace must be removed, arguments must be annotated, etc.). This is the reason all of the checks are failing.
Implemented the requested changes.
dkweiss31 left a comment
LGTM! Thanks @HarshBabla99!
Bug fixes:
- The dt factor was missing. Fixed this bug.
- The TimeQArray.prefactor() method in the parent class now adds a prefactor to all children, including ConstantTimeQArray. Previously, qontrol checked for the presence or absence of a prefactor attribute, but now I check whether a Hamiltonian is a ConstantTimeQArray.
Optimizations:
Feature:
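A minimal sketch of why the prefactor bug fix above had to replace the attribute check with a type check. It uses hypothetical stand-in classes (TimeQArraySketch, ConstantSketch, ModulatedSketch), not the real dynamiqs TimeQArray hierarchy. Once the parent class defines prefactor, hasattr returns True for every subclass, whereas isinstance still singles out the constant case.

```python
class TimeQArraySketch:
    """Hypothetical parent: after the change, every child has prefactor."""

    def prefactor(self, t):
        raise NotImplementedError


class ConstantSketch(TimeQArraySketch):
    def prefactor(self, t):
        return 1.0


class ModulatedSketch(TimeQArraySketch):
    def __init__(self, f):
        self.f = f

    def prefactor(self, t):
        return self.f(t)


H_const = ConstantSketch()
H_mod = ModulatedSketch(lambda t: 2.0 * t)

# Old-style check: once the parent defines prefactor, it is True for every
# child, so it can no longer tell constant Hamiltonians apart.
print(hasattr(H_const, "prefactor"), hasattr(H_mod, "prefactor"))  # True True

# New-style check: test the concrete type instead.
print(isinstance(H_const, ConstantSketch), isinstance(H_mod, ConstantSketch))  # True False
```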