Skip to content

Commit 02b6342

Browse files
Enhance BaseDatasetScoreEvaluator.plot() to accept additional keyword arguments and forward them to __plot__()
For instance, the `zoom` argument of KDEDistanceEvaluator (default: zoom=False)
1 parent 12f16a0 commit 02b6342

2 files changed

Lines changed: 23 additions & 5 deletions

File tree

src/sdialog/evaluation/__init__.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def _kl_divergence(p1, p2, resolution=100, bw_method=1e-1):
117117
"""
118118
Estimate KL divergence KL(p1 || p2) between two 1D distributions via KDE.
119119
120-
KL(p1||p2) is nonsymmetric and >= 0 (0 means identical).
120+
KL(p1||p2) is non-symmetric and >= 0 (0 means identical).
121121
122122
:param p1: First sample (treat as true distribution).
123123
:type p1: array-like
@@ -2009,7 +2009,7 @@ def __init__(self,
20092009
leave=verbose)]
20102010
self.reference_scores = np.array([s for s in self.reference_scores if s is not None])
20112011

2012-
def __plot__(self, dialog_scores: Dict[str, np.ndarray], plot: Optional[plt.Axes] = None):
2012+
def __plot__(self, dialog_scores: Dict[str, np.ndarray], plot: Optional[plt.Axes] = None, zoom: bool = False):
20132013
"""
20142014
Plot KDE curves of reference and candidate score distributions.
20152015
@@ -2038,6 +2038,21 @@ def __plot__(self, dialog_scores: Dict[str, np.ndarray], plot: Optional[plt.Axes
20382038
color_idx += 1
20392039
except ValueError as e:
20402040
logger.error(f"Error plotting KDE for {dataset_name}: {e}")
2041+
2042+
if zoom:
2043+
# Percentile-based zoom
2044+
all_scores = []
2045+
if self.reference_scores is not None:
2046+
all_scores.append(self.reference_scores)
2047+
for scores in dialog_scores.values():
2048+
all_scores.append(scores)
2049+
2050+
if all_scores:
2051+
all_scores = np.concatenate(all_scores)
2052+
low, high = np.percentile(all_scores, [2, 98]) # tweak if needed
2053+
pad = 0.05 * (high - low)
2054+
plt.gca().set_xlim(low - pad, high + pad)
2055+
20412056
plot.xlabel(self.plot_xlabel if self.plot_xlabel else self.dialog_score.name)
20422057
plot.ylabel(self.plot_ylabel if self.plot_ylabel else "Density")
20432058
plot.legend(loc='best', frameon=True, fancybox=False, edgecolor='black', framealpha=1.0)

src/sdialog/evaluation/base.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -569,14 +569,17 @@ def clear(self):
569569

570570
def plot(self,
571571
show: bool = True,
572-
save_path: str = None):
572+
save_path: str = None,
573+
**kwargs):
573574
"""
574575
Generate plots for stored dataset scores.
575576
576577
:param show: Whether to display the plot(s).
577578
:type show: bool
578579
:param save_path: If provided, save figure(s) to this path (metric name appended when multi-metric).
579580
:type save_path: Optional[str]
581+
:param kwargs: Additional keyword arguments for plotting.
582+
:type kwargs: dict
580583
:return: None
581584
:rtype: None
582585
"""
@@ -587,7 +590,7 @@ def plot(self,
587590
if self.datasets_scores and isinstance(next(iter(self.datasets_scores.values())), dict):
588591
for metric in self.datasets_scores:
589592
plt.figure(figsize=(8, 5))
590-
self.__plot__(self.datasets_scores[metric], plot=plt, metric=metric)
593+
self.__plot__(self.datasets_scores[metric], plot=plt, metric=metric, **kwargs)
591594
if save_path:
592595
# Append metric name to filename before saving
593596
if "." in save_path.split("/")[-1]:
@@ -601,7 +604,7 @@ def plot(self,
601604
plt.show()
602605
else:
603606
plt.figure(figsize=(8, 5))
604-
self.__plot__(self.datasets_scores, plot=plt)
607+
self.__plot__(self.datasets_scores, plot=plt, **kwargs)
605608
if save_path:
606609
os.makedirs(os.path.dirname(save_path), exist_ok=True)
607610
plt.savefig(save_path, dpi=300)

0 commit comments

Comments
 (0)