Skip to content

Commit 7982c9d

Browse files
committed
mcmc.py: safeguard against target logpdf evaluation failure
1 parent 5f12274 commit 7982c9d

1 file changed

Lines changed: 32 additions & 6 deletions

File tree

gpmp/misc/mcmc.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,28 @@
1515
from dataclasses import dataclass, field
1616

1717

18+
def sample_multivariate_normal_with_jitter(
19+
rng, mean, cov, initial_jitter=1e-8, max_attempts=5
20+
):
21+
try:
22+
return rng.multivariate_normal(mean, cov, method="cholesky")
23+
except np.linalg.LinAlgError:
24+
jitter = initial_jitter
25+
dim = cov.shape[0]
26+
cov += jitter * np.eye(dim)
27+
for _ in range(max_attempts):
28+
try:
29+
return rng.multivariate_normal(mean, cov, method="cholesky")
30+
except np.linalg.LinAlgError:
31+
jitter *= 10
32+
cov += jitter * np.eye(dim)
33+
__import__("pdb").set_trace()
34+
35+
raise LinAlgError(
36+
"Covariance matrix is not positive definite even after adding jitter."
37+
)
38+
39+
1840
@dataclass
1941
class MHOptions:
2042
"""
@@ -236,8 +258,8 @@ def default_prop_rnd(self, x: np.ndarray, chain_idx: int) -> np.ndarray:
236258
Default random-walk: draw from N(x, Cov) where Cov depends on proposal_params.
237259
"""
238260
cov = self._get_cov_parameter(chain_idx)
239-
perturbation = self.rng.multivariate_normal(
240-
np.zeros(self.dim), cov, method="cholesky"
261+
perturbation = sample_multivariate_normal_with_jitter(
262+
self.rng, np.zeros(self.dim), cov
241263
)
242264
return x + perturbation
243265

@@ -321,7 +343,11 @@ def mhstep(self, x_current: np.ndarray, chain_idx: int) -> Tuple[np.ndarray, boo
321343
If symmetric=False, includes reverse-proposal terms in acceptance.
322344
"""
323345
y = self.prop_rnd(x_current, chain_idx)
324-
log_a = self.log_target(y) - self.log_target(x_current)
346+
try:
347+
log_target_y = self.log_target(y)
348+
except:
349+
log_target_y = -np.inf
350+
log_a = log_target_y - self.log_target(x_current)
325351
if not self.symmetric:
326352
log_a += self._log_prop(y, x_current, chain_idx) - self._log_prop(
327353
x_current, y, chain_idx
@@ -987,7 +1013,7 @@ def plot_chains(self, burnin=None, parameter_indices=None, show_rate=True):
9871013
linestyle="--",
9881014
label="End Burn-in" if i == 0 else None,
9891015
)
990-
axes[i].legend(loc="best")
1016+
# axes[i].legend(loc="best")
9911017
if show_rate:
9921018
axr = axes[-1]
9931019
if self.rates is not None:
@@ -1003,7 +1029,7 @@ def plot_chains(self, burnin=None, parameter_indices=None, show_rate=True):
10031029
else:
10041030
print("No acceptance data to display.")
10051031
axr.set_ylabel("Acceptance")
1006-
axr.legend(loc="best")
1032+
# axr.legend(loc="best")
10071033
axr.set_xlabel("Iteration")
10081034
else:
10091035
axes[-1].set_xlabel("Iteration")
@@ -1049,7 +1075,7 @@ def plot_empirical_distributions(
10491075
ax.plot(xx, kde(xx), label=f"Chain {c+1}")
10501076
ax.set_xlabel(rf"$\theta_{{{param+1}}}$")
10511077
ax.set_ylabel("Density")
1052-
ax.legend(loc="best")
1078+
# ax.legend(loc="best")
10531079
plt.tight_layout()
10541080
plt.show()
10551081

0 commit comments

Comments
 (0)