1515from dataclasses import dataclass , field
1616
1717
18+ def sample_multivariate_normal_with_jitter (
19+ rng , mean , cov , initial_jitter = 1e-8 , max_attempts = 5
20+ ):
21+ try :
22+ return rng .multivariate_normal (mean , cov , method = "cholesky" )
23+ except np .linalg .LinAlgError :
24+ jitter = initial_jitter
25+ dim = cov .shape [0 ]
26+ cov += jitter * np .eye (dim )
27+ for _ in range (max_attempts ):
28+ try :
29+ return rng .multivariate_normal (mean , cov , method = "cholesky" )
30+ except np .linalg .LinAlgError :
31+ jitter *= 10
32+ cov += jitter * np .eye (dim )
33+ __import__ ("pdb" ).set_trace ()
34+
35+ raise LinAlgError (
36+ "Covariance matrix is not positive definite even after adding jitter."
37+ )
38+
39+
1840@dataclass
1941class MHOptions :
2042 """
@@ -236,8 +258,8 @@ def default_prop_rnd(self, x: np.ndarray, chain_idx: int) -> np.ndarray:
236258 Default random-walk: draw from N(x, Cov) where Cov depends on proposal_params.
237259 """
238260 cov = self ._get_cov_parameter (chain_idx )
239- perturbation = self . rng . multivariate_normal (
240- np .zeros (self .dim ), cov , method = "cholesky"
261+ perturbation = sample_multivariate_normal_with_jitter (
262+ self . rng , np .zeros (self .dim ), cov
241263 )
242264 return x + perturbation
243265
@@ -321,7 +343,11 @@ def mhstep(self, x_current: np.ndarray, chain_idx: int) -> Tuple[np.ndarray, boo
321343 If symmetric=False, includes reverse-proposal terms in acceptance.
322344 """
323345 y = self .prop_rnd (x_current , chain_idx )
324- log_a = self .log_target (y ) - self .log_target (x_current )
346+ try :
347+ log_target_y = self .log_target (y )
348+ except :
349+ log_target_y = - np .inf
350+ log_a = log_target_y - self .log_target (x_current )
325351 if not self .symmetric :
326352 log_a += self ._log_prop (y , x_current , chain_idx ) - self ._log_prop (
327353 x_current , y , chain_idx
@@ -987,7 +1013,7 @@ def plot_chains(self, burnin=None, parameter_indices=None, show_rate=True):
9871013 linestyle = "--" ,
9881014 label = "End Burn-in" if i == 0 else None ,
9891015 )
990- axes [i ].legend (loc = "best" )
1016+ # axes[i].legend(loc="best")
9911017 if show_rate :
9921018 axr = axes [- 1 ]
9931019 if self .rates is not None :
@@ -1003,7 +1029,7 @@ def plot_chains(self, burnin=None, parameter_indices=None, show_rate=True):
10031029 else :
10041030 print ("No acceptance data to display." )
10051031 axr .set_ylabel ("Acceptance" )
1006- axr .legend (loc = "best" )
1032+ # axr.legend(loc="best")
10071033 axr .set_xlabel ("Iteration" )
10081034 else :
10091035 axes [- 1 ].set_xlabel ("Iteration" )
@@ -1049,7 +1075,7 @@ def plot_empirical_distributions(
10491075 ax .plot (xx , kde (xx ), label = f"Chain { c + 1 } " )
10501076 ax .set_xlabel (rf"$\theta_{{{ param + 1 } }}$" )
10511077 ax .set_ylabel ("Density" )
1052- ax .legend (loc = "best" )
1078+ # ax.legend(loc="best")
10531079 plt .tight_layout ()
10541080 plt .show ()
10551081
0 commit comments