Skip to content

Commit 95b13dd

Browse files
Fix log(0) issue in CS divergence calculation by ensuring minimum numerator value
1 parent 6cd25dc commit 95b13dd

1 file changed

Lines changed: 2 additions & 1 deletion

File tree

src/sdialog/evaluation/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,8 @@ def _cs_divergence(p1, p2, resolution=100, bw_method=1):
109109
p2_vals = p2_kernel(r)
110110
numerator = np.sum(p1_vals * p2_vals)
111111
denominator = sqrt(np.sum(p1_vals ** 2) * np.sum(p2_vals ** 2))
112-
return -log(numerator / denominator)
112+
# Avoid log(0) by ensuring numerator has minimum value
113+
return -log(max(numerator, 1e-12) / denominator)
113114

114115

115116
def _kl_divergence(p1, p2, resolution=100, bw_method=1e-1):

0 commit comments

Comments
 (0)