-
Notifications
You must be signed in to change notification settings - Fork 287
Expand file tree
/
Copy pathvar2.py
More file actions
217 lines (176 loc) · 7.04 KB
/
var2.py
File metadata and controls
217 lines (176 loc) · 7.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
r"""
Example: VAR(2) process
=======================
In this example, we demonstrate how to implement and perform Bayesian inference for a
Vector Autoregressive process of order 2 (VAR(2)). VAR models are widely used in
time series analysis, especially for capturing the dynamics between multiple variables.
A VAR(2) process for a multivariate time series :math:`y_t` with :math:`K` variables is defined as:

.. math::

    y_t = c + \Phi_1 y_{t-1} + \Phi_2 y_{t-2} + \epsilon_t

Here, :math:`c` is a constant vector, :math:`\Phi_1` and :math:`\Phi_2` are coefficient matrices for lag 1
and lag 2, respectively, and :math:`\epsilon_t` is a Gaussian noise term with zero mean and a
covariance matrix :math:`\Sigma`.
This example uses NumPyro's `scan` utility to efficiently model the temporal dependencies without
explicit Python loops.
For more general time series forecasting techniques and examples, refer to the
`Time Series Forecasting` tutorial:
https://num.pyro.ai/en/stable/tutorials/time_series_forecasting.html#Forecasting
Reference
---------
For more information on Vector Autoregressive models, see:
https://otexts.com/fpp2/VAR.html

.. image:: ../_static/img/examples/var2.png
    :align: center

"""
import argparse
import os
import time
import matplotlib.pyplot as plt
import numpy as np
from jax import random
import jax.numpy as jnp
import numpyro
from numpyro.contrib.control_flow import scan
import numpyro.distributions as dist
def var2_scan(y):
    """
    NumPyro model of a VAR(2) process, conditioned on the observed series.

    The first two rows of ``y`` seed the recursion; every later row is
    treated as an observation of a multivariate normal whose mean is
    ``c + Phi1 @ y[t-1] + Phi2 @ y[t-2]``.

    Args:
        y: Observed time series; it is unpacked as ``(T, K)`` (time steps,
            variables).

    Sample sites: ``c``, ``Phi1``, ``Phi2``, ``sigma``, ``L_omega`` and one
    observed site per scan step. The stacked per-step means are recorded as
    the deterministic site ``mu`` (one entry per step from index 2 onward).
    """
    num_steps, num_vars = y.shape

    # Standard-normal priors on the intercept vector and on both lag matrices.
    intercept_prior = dist.Normal(0, 1).expand([num_vars])
    coef_prior = dist.Normal(0, 1).expand([num_vars, num_vars]).to_event(2)
    c = numpyro.sample("c", intercept_prior)
    Phi1 = numpyro.sample("Phi1", coef_prior)
    Phi2 = numpyro.sample("Phi2", coef_prior)

    # Noise model: per-variable scales combined with an LKJ-distributed
    # Cholesky factor of the correlation matrix.
    scale_prior = dist.HalfNormal(1.0).expand([num_vars]).to_event(1)
    sigma = numpyro.sample("sigma", scale_prior)
    corr_chol_prior = dist.LKJCholesky(dimension=num_vars, concentration=1.0)
    L_omega = numpyro.sample("L_omega", corr_chol_prior)

    # Scale row i of the correlation factor by sigma[i]; the result is used
    # as the lower-triangular scale of the likelihood below.
    L_Sigma = jnp.expand_dims(sigma, -1) * L_omega

    def transition(carry, t):
        """Advance one step: score observation ``t`` and shift the lags."""
        lag1, lag2, observed = carry

        lag1_term = jnp.dot(Phi1, lag1)
        lag2_term = jnp.dot(Phi2, lag2)
        m_t = c + lag1_term + lag2_term

        likelihood = dist.MultivariateNormal(loc=m_t, scale_tril=L_Sigma)
        # NOTE(review): `t` comes from the scanned index array, so the site
        # name is built from whatever str(t) yields inside `scan` -- confirm
        # the resulting name is what is intended.
        y_t = numpyro.sample(f"y_{t}", likelihood, obs=observed[t])

        # The current value becomes lag 1 and the old lag 1 becomes lag 2;
        # the observed array rides along unchanged.
        return (y_t, lag1, observed), m_t

    # Lags start at y[1] (lag 1) and y[0] (lag 2); the scan indexes into
    # y[2:], hence T - 2 steps.
    _, mu = scan(
        transition,
        (y[1], y[0], y[2:]),
        jnp.arange(num_steps - 2),
    )

    # Expose the per-step means alongside the posterior samples.
    numpyro.deterministic("mu", mu)
def generate_var2_data(T, K, c, Phi1, Phi2, sigma, rng=None):
    """
    Generate time series data from a VAR(2) process.

    The first ``min(T, 2)`` rows are drawn from a zero-mean Gaussian with
    covariance ``sigma``; each later row follows
    ``y[t] = c + Phi1 @ y[t-1] + Phi2 @ y[t-2] + noise``.

    Args:
        T (int): Number of time steps. Values below 2 are allowed and simply
            yield the initial draws (or an empty array for ``T == 0``).
        K (int): Number of variables in the time series.
        c (array): Constants (shape: (K,)).
        Phi1 (array): Coefficients for lag 1 (shape: (K, K)).
        Phi2 (array): Coefficients for lag 2 (shape: (K, K)).
        sigma (array): Covariance matrix for the noise (shape: (K, K)).
        rng (optional): Object providing ``multivariate_normal(mean=, cov=,
            size=)``, e.g. ``numpy.random.default_rng(seed)`` or a
            ``numpy.random.RandomState``. When omitted, the global
            ``numpy.random`` state is used, exactly as before.

    Returns:
        np.ndarray: Generated time series data (shape: (T, K)).
    """
    # Convert once up front: inputs may be JAX arrays, and mixing them into
    # the loop would dispatch through JAX on every step and round the
    # recursion to JAX's default precision instead of NumPy float64.
    c = np.asarray(c, dtype=float)
    Phi1 = np.asarray(Phi1, dtype=float)
    Phi2 = np.asarray(Phi2, dtype=float)
    sigma = np.asarray(sigma, dtype=float)

    sampler = np.random if rng is None else rng
    zero_mean = np.zeros(K)

    y = np.zeros((T, K))

    # Initialize the first (up to two) rows with random values. Clamping the
    # count keeps T < 2 from failing on a shape mismatch.
    num_init = min(T, 2)
    if num_init > 0:
        y[:num_init] = sampler.multivariate_normal(
            mean=zero_mean, cov=sigma, size=num_init
        )

    # Generate the remaining rows with the VAR(2) recursion.
    for t in range(2, T):
        noise = sampler.multivariate_normal(mean=zero_mean, cov=sigma)
        y[t] = c + Phi1 @ y[t - 1] + Phi2 @ y[t - 2] + noise

    return y
def run_inference(model, args, rng_key, y):
    """
    Fit ``model`` to ``y`` with NUTS and return the posterior samples.

    Prints the MCMC summary and the wall-clock time spent since entering
    this function.

    Args:
        model: The probabilistic model to infer; it is called with ``y=y``.
        args: Parsed command-line arguments; ``num_warmup``, ``num_samples``
            and ``num_chains`` are read from it.
        rng_key: PRNG key passed to ``MCMC.run``.
        y: Observed time series data.

    Returns:
        The value of ``mcmc.get_samples()``.
    """
    t0 = time.time()

    # The progress bar is turned off when NUMPYRO_SPHINXBUILD is set in the
    # environment.
    show_progress = "NUMPYRO_SPHINXBUILD" not in os.environ

    mcmc_options = {
        "num_warmup": args.num_warmup,
        "num_samples": args.num_samples,
        "num_chains": args.num_chains,
        "progress_bar": show_progress,
    }
    kernel = numpyro.infer.NUTS(model)
    mcmc = numpyro.infer.MCMC(kernel, **mcmc_options)

    mcmc.run(rng_key, y=y)
    mcmc.print_summary()

    elapsed = time.time() - t0
    print("\nMCMC elapsed time:", elapsed)

    return mcmc.get_samples()
def main(args):
    """
    Simulate a two-variable VAR(2) series, fit the model, and plot the fit.

    Args:
        args: Parsed command-line arguments. ``num_data`` is read here; the
            whole object is forwarded to ``run_inference``.

    Writes the figure to ``var2.png`` in the current working directory.
    """
    # --- Synthetic data -------------------------------------------------
    num_steps = args.num_data
    num_vars = 2
    c_true = jnp.array([0.5, -0.3])
    Phi1_true = jnp.array([[0.7, 0.1], [0.2, 0.5]])
    Phi2_true = jnp.array([[0.2, -0.1], [-0.1, 0.2]])
    sigma_true = jnp.array([[0.1, 0.02], [0.02, 0.1]])  # noise covariance

    rng_key = random.key(0)
    y = generate_var2_data(
        T=num_steps,
        K=num_vars,
        c=c_true,
        Phi1=Phi1_true,
        Phi2=Phi2_true,
        sigma=sigma_true,
    )

    # --- Inference ------------------------------------------------------
    samples = run_inference(var2_scan, args, rng_key, y)

    # --- Posterior summary of the per-step mean "mu" ----------------------
    mu_draws = samples["mu"]
    mean_prediction = mu_draws.mean(axis=0)
    lower_bound = jnp.percentile(mu_draws, 2.5, axis=0)
    upper_bound = jnp.percentile(mu_draws, 97.5, axis=0)

    # --- Plot: one panel per variable -----------------------------------
    fig, axes = plt.subplots(num_vars, 1, figsize=(10, 6), sharex=True)
    time_steps = jnp.arange(num_steps)
    # "mu" has no entries for the first two time steps, so the prediction
    # curves are drawn from index 2 onward.
    predicted_steps = time_steps[2:]

    for idx, ax in zip(range(num_vars), axes):
        # Generated series
        ax.plot(
            time_steps,
            y[:, idx],
            label=f"True Variable {idx + 1}",
            color="blue",
        )
        # Posterior mean of mu
        ax.plot(
            predicted_steps,
            mean_prediction[:, idx],
            label=f"Predicted Mean Variable {idx + 1}",
            color="orange",
        )
        # Band between the 2.5th and 97.5th percentiles of the mu draws
        ax.fill_between(
            predicted_steps,
            lower_bound[:, idx],
            upper_bound[:, idx],
            color="orange",
            alpha=0.2,
            label="95% CI",
        )
        ax.set_title(f"Variable {idx + 1}")
        ax.legend()
        ax.grid(True)

    plt.xlabel("Time Steps")
    plt.tight_layout()
    plt.savefig("var2.png")
if __name__ == "__main__":
    # Script entry point: check the NumPyro version, parse the CLI options,
    # configure the backend, then run the example.

    # Stop immediately unless the installed NumPyro version string starts
    # with the release this example targets.
    assert numpyro.__version__.startswith("0.21.0")

    parser = argparse.ArgumentParser(description="VAR(2) example")
    # Length of the simulated series (read as args.num_data in main).
    parser.add_argument("--num-data", nargs="?", default=100, type=int)
    # MCMC settings forwarded to run_inference.
    parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
    parser.add_argument("--num-warmup", nargs="?", default=1000, type=int)
    parser.add_argument("--num-chains", nargs="?", default=1, type=int)
    parser.add_argument("--device", default="cpu", type=str, help='use "cpu" or "gpu".')
    args = parser.parse_args()

    # Backend configuration is done before main() so it is in place before
    # any model code runs.
    numpyro.set_platform(args.device)
    # One host device is requested per chain -- presumably so that the
    # chains can run in parallel; confirm against the NumPyro docs.
    numpyro.set_host_device_count(args.num_chains)

    main(args)