-
Notifications
You must be signed in to change notification settings - Fork 11
Expand file tree
/
Copy pathfour_f_optimizer.py
More file actions
117 lines (90 loc) · 3.79 KB
/
Copy pathfour_f_optimizer.py
File metadata and controls
117 lines (90 loc) · 3.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import os
import sys
# Setting the path for XLuminA modules:
# Put the directory expected to contain the XLuminA package on sys.path so
# the imports below resolve. NOTE: '..' is resolved against the current
# working directory, not this file's location — presumably the script is
# meant to be launched from its own folder (confirm).
current_path = os.path.abspath('..')
module_path = current_path
if module_path not in sys.path:
    sys.path.append(module_path)
from four_f_optical_table import *
from xlumina.toolbox import MultiHDF5DataLoader
import time
import jax
import optax
from jax import jit
import numpy as np
import jax.numpy as jnp
"""
OPTIMIZER FOR THE OPTICAL TELESCOPE (4F-SYSTEM).
"""
# Print device info (GPU or CPU)
print(jax.devices(), flush=True)
# Data loader over "training_data_4f" with batches of 10; fit() pulls
# (input_fields, target_fields) pairs from it with next().
# NOTE(review): relative path, presumably resolved against the CWD — confirm.
dataloader = MultiHDF5DataLoader("training_data_4f", batch_size=10)
# JIT-compiled loss. Gradients are not computed here: fit() wraps this in
# jax.value_and_grad inside its jitted update step.
loss_function = jit(loss_dualSLM)
# ----------------------------------------------------
def fit(params: optax.Params, optimizer: optax.GradientTransformation, num_iterations: int) -> tuple:
    """Optimize ``params`` with ``optimizer`` on batches from the module-level ``dataloader``.

    Each step draws one ``(input_fields, target_fields)`` batch via
    ``next(dataloader)``, evaluates the module-level ``loss_function`` and
    its gradient at the current ``params``, and applies one optimizer update.
    The lowest loss seen and the parameters that produced it are tracked, and
    optimization stops early once ``n_best`` consecutive steps pass without
    improving on it.

    Args:
        params: Initial parameter pytree (in this script, a list of arrays).
        optimizer: Optax gradient transformation used to update ``params``.
        num_iterations: Maximum number of optimization steps.

    Returns:
        ``(best_params, best_loss, iteration_steps, loss_list)``: the
        parameters that achieved the lowest loss, that loss, the step indices
        that were run, and the loss recorded at each step. If no loss ever
        compares below +inf (e.g. every loss is NaN, or ``num_iterations`` is
        0), ``best_params`` is ``None`` and ``best_loss`` is ``inf``.
    """
    opt_state = optimizer.init(params)

    @jit
    def update(params, opt_state, input_fields, target_fields):
        # One optimization step. The returned loss is that of the *incoming*
        # params, i.e. before the update is applied.
        loss_value, grads = jax.value_and_grad(loss_function, allow_int=True)(params, input_fields, target_fields)
        # Update the state of the optimizer
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss_value

    iteration_steps = []
    loss_list = []

    # Early-stopping patience: stop after this many steps without a new best.
    n_best = 500
    # Start at +inf so the first comparable loss always becomes the best.
    best_loss = float('inf')
    best_params = None
    best_step = 0

    print('Starting Optimization', flush=True)
    for step in range(num_iterations):
        # Load data:
        input_fields, target_fields = next(dataloader)
        # `update` reports the loss of the params it receives, so keep a
        # reference to them: they, not the post-update params, achieved it.
        evaluated_params = params
        params, opt_state, loss_value = update(params, opt_state, input_fields, target_fields)
        print(f"Step {step}")
        print(f"Loss {loss_value}")
        iteration_steps.append(step)
        loss_list.append(loss_value)

        if loss_value < best_loss:
            best_loss = loss_value
            best_params = evaluated_params
            best_step = step
            print('Best loss value is updated')

        # Checked every step so the patience is exactly n_best steps.
        if step - best_step >= n_best:
            print(f'Stopping criterion: no improvement in loss value for {n_best} steps')
            break

    print(f'Best loss: {best_loss} at step {best_step}')
    print(f'Best parameters: {best_params}')
    return best_params, best_loss, iteration_steps, loss_list
# ----------------------------------------------------
# Optimizer settings
num_iterations = 50000
num_samples = 1
# Step size engineering:
STEP_SIZE = 0.01
WEIGHT_DECAY = 0.0001

for _ in range(num_samples):
    tic = time.perf_counter()

    # Random initial parameters. `shape` is not defined in this file; it
    # presumably comes from the star import of four_f_optical_table (confirm).
    # NOTE(review): dtype=float64 only takes effect if JAX x64 mode is enabled
    # somewhere; otherwise JAX downcasts to float32.
    # Draw order (SLM1, SLM2, then the three distances) matches the original.
    phase_mask_slm1 = jnp.asarray(np.random.uniform(0, 1, (shape, shape)), dtype=jnp.float64)
    phase_mask_slm2 = jnp.asarray(np.random.uniform(0, 1, (shape, shape)), dtype=jnp.float64)
    # Each distance is a 1-element array drawn uniformly from [0.027, 1).
    distance_0 = jnp.array([np.random.uniform(0.027, 1)], dtype=jnp.float64)
    distance_1 = jnp.array([np.random.uniform(0.027, 1)], dtype=jnp.float64)
    distance_2 = jnp.array([np.random.uniform(0.027, 1)], dtype=jnp.float64)
    init_params = [distance_0, distance_1, distance_2, phase_mask_slm1, phase_mask_slm2]

    # Init optimizer:
    optimizer = optax.adamw(STEP_SIZE, weight_decay=WEIGHT_DECAY)

    # Apply fit function. NOTE(review): results are rebound on every sample
    # and not persisted here; only the last sample's values survive the loop.
    best_params, best_loss, iteration_steps, loss_list = fit(init_params, optimizer, num_iterations)
    print("Time taken to optimize one sample - in seconds", time.perf_counter() - tic)