1
1
import matplotlib
2
2
matplotlib .use ('Agg' )
3
- import matplotlib .pyplot as plt
4
- from matplotlib .backends .backend_agg import FigureCanvasAgg as FigureCanvas
5
- from matplotlib .figure import Figure
6
- import keras
7
- import numpy as np
8
3
import wandb
4
+ import numpy as np
5
+ import keras
6
+ from matplotlib .figure import Figure
7
+ from matplotlib .backends .backend_agg import FigureCanvasAgg as FigureCanvas
8
+ import matplotlib .pyplot as plt
9
9
10
- def fig2data ( fig ):
10
+
11
+ def fig2data (fig ):
11
12
"""
12
13
@brief Convert a Matplotlib figure to a 4D numpy array with RGBA channels and return it
13
14
@param fig a matplotlib figure
14
15
@return a numpy 3D array of RGBA values
15
16
"""
16
17
# draw the renderer
17
- fig .canvas .draw ( )
18
-
18
+ fig .canvas .draw ( )
19
+
19
20
# Get the RGBA buffer from the figure
20
- w ,h = fig .canvas .get_width_height ()
21
- buf = np .fromstring ( fig .canvas .tostring_argb (), dtype = np .uint8 )
22
- buf .shape = ( w , h ,4 )
23
-
21
+ w , h = fig .canvas .get_width_height ()
22
+ buf = np .fromstring ( fig .canvas .tostring_argb (), dtype = np .uint8 )
23
+ buf .shape = (w , h , 4 )
24
+
24
25
# canvas.tostring_argb give pixmap in ARGB mode. Roll the ALPHA channel to have it in RGBA mode
25
- buf = np .roll ( buf , 3 , axis = 2 )
26
+ buf = np .roll ( buf , 3 , axis = 2 )
26
27
return buf
27
28
28
29
29
-
30
- def repeated_predictions (model , data , look_back , steps = 100 ):
30
+ def repeated_predictions (model , data , look_back , steps = 100 ):
31
31
predictions = []
32
32
for i in range (steps ):
33
- input_data = data [np .newaxis ,:, np .newaxis ]
33
+ input_data = data [np .newaxis , :, np .newaxis ]
34
34
generated = model .predict (input_data )[0 ]
35
35
data = np .append (data , generated )[- look_back :]
36
36
predictions .append (generated )
37
37
return predictions
38
38
39
+
39
40
class PlotCallback (keras .callbacks .Callback ):
40
41
def __init__ (self , trainX , trainY , testX , testY , look_back ):
41
42
self .repeat_predictions = True
@@ -44,22 +45,24 @@ def __init__(self, trainX, trainY, testX, testY, look_back):
44
45
self .testX = testX
45
46
self .testY = testY
46
47
self .look_back = look_back
47
-
48
+
48
49
def on_epoch_end (self , epoch , logs ):
49
50
if self .repeat_predictions :
50
- preds = repeated_predictions (self .model , self .trainX [- 1 ,:,0 ], self .look_back , self .testX .shape [0 ])
51
+ preds = repeated_predictions (
52
+ self .model , self .trainX [- 1 , :, 0 ], self .look_back , self .testX .shape [0 ])
51
53
else :
52
54
preds = model .predict (testX )
53
55
54
56
# Generate a figure with matplotlib</font>
55
- figure = matplotlib .pyplot .figure ( figsize = (10 ,10 ) )
56
- plot = figure .add_subplot ( 111 )
57
+ figure = matplotlib .pyplot .figure (figsize = (10 , 10 ))
58
+ plot = figure .add_subplot ( 111 )
57
59
58
- plot .plot ( self .trainY )
59
- plot .plot ( np .append (np .empty_like (self .trainY ) * np .nan , self .testY ))
60
- plot .plot ( np .append (np .empty_like (self .trainY ) * np .nan , preds ))
60
+ plot .plot ( self .trainY )
61
+ plot .plot ( np .append (np .empty_like (self .trainY ) * np .nan , self .testY ))
62
+ plot .plot ( np .append (np .empty_like (self .trainY ) * np .nan , preds ))
61
63
62
- data = fig2data ( figure )
64
+ data = fig2data ( figure )
63
65
matplotlib .pyplot .close (figure )
64
66
65
- wandb .log ({"image" : wandb .Image (data )}, commit = False )
67
+ if epoch % 4 == 0 :
68
+ wandb .log ({"image" : wandb .Image (data )}, commit = False )
0 commit comments