diff --git a/DashAI/front/src/components/notebooks/dataset/tabs/CorrelationsTab.jsx b/DashAI/front/src/components/notebooks/dataset/tabs/CorrelationsTab.jsx index f9139e4a0..18822e507 100644 --- a/DashAI/front/src/components/notebooks/dataset/tabs/CorrelationsTab.jsx +++ b/DashAI/front/src/components/notebooks/dataset/tabs/CorrelationsTab.jsx @@ -1,36 +1,48 @@ -import React from "react"; +import React, { useMemo } from "react"; import { Box, Typography, CardContent, Card } from "@mui/material"; import { useTheme } from "@mui/material/styles"; -import { - BarChart, - Bar, - Cell, - XAxis, - YAxis, - CartesianGrid, - Tooltip, - ResponsiveContainer, -} from "recharts"; +import Plot from "react-plotly.js"; import { useTranslation } from "react-i18next"; const CorrelationsTab = ({ correlations }) => { const { t } = useTranslation(["datasets"]); - const theme = useTheme(); - const corrData = []; - Object.entries(correlations).forEach(([col1, corrs]) => { - Object.entries(corrs).forEach(([col2, value]) => { - if (col1 < col2) { - // Avoid duplicates - corrData.push({ - pair: `${col1} - ${col2}`, - correlation: value, - }); - } + + const { columns, zValues, strongCorrelations, leftMargin } = useMemo(() => { + const cols = Object.keys(correlations); + + // Build symmetric matrix + const z = cols.map((col1) => + cols.map((col2) => { + if (col1 === col2) return 1; + return correlations[col1]?.[col2] ?? correlations[col2]?.[col1] ?? 0; + }), + ); + + // Extract strong correlations (|r| > 0.5, no self-correlations) + const strong = []; + cols.forEach((col1, i) => { + cols.forEach((col2, j) => { + if (i < j && Math.abs(z[i][j]) > 0.5) { + strong.push({ + pair: `${col1} - ${col2}`, + correlation: z[i][j], + }); + } + }); }); - }); + strong.sort((a, b) => Math.abs(b.correlation) - Math.abs(a.correlation)); - corrData.sort((a, b) => Math.abs(b.correlation) - Math.abs(a.correlation)); + const maxLabelLen = Math.max(...cols.map((c) => c.length), 0); + const leftMargin = Math.min(maxLabelLen * 7 + 20, 300); + + return { + columns: cols, + zValues: z, + strongCorrelations: strong, + leftMargin, + }; + }, [correlations]); return ( @@ -40,37 +52,57 @@ const CorrelationsTab = ({ correlations }) => { - - - - - - row.map((val) => val.toFixed(3))), + texttemplate: "%{text}", + textfont: { color: theme.palette.text.primary, - }} - labelStyle={{ color: theme.palette.text.primary }} - /> - - {corrData.map((entry, index) => ( - 0 - ? theme.palette.success.main - : theme.palette.error.main - } - /> - ))} - - - + size: 11, + }, + hovertemplate: "%{x} — %{y}
r = %{z:.3f}", + showscale: true, + colorbar: { + title: { + text: "r", + font: { color: theme.palette.text.primary }, + }, + tickfont: { color: theme.palette.text.secondary }, + }, + }, + ]} + layout={{ + autosize: true, + height: Math.max(400, columns.length * 40 + 150), + margin: { l: leftMargin, r: 40, t: 20, b: 120 }, + paper_bgcolor: "rgba(0,0,0,0)", + plot_bgcolor: "rgba(0,0,0,0)", + xaxis: { + tickangle: -45, + tickfont: { color: theme.palette.text.secondary, size: 11 }, + side: "bottom", + }, + yaxis: { + tickfont: { color: theme.palette.text.secondary, size: 11 }, + autorange: "reversed", + }, + }} + config={{ displayModeBar: false, responsive: true }} + useResizeHandler + style={{ width: "100%" }} + />
@@ -78,36 +110,33 @@ const CorrelationsTab = ({ correlations }) => { {t("datasets:label.strongCorrelations")} (|r| > 0.5) - {corrData - .filter((d) => Math.abs(d.correlation) > 0.5) - .map((d, idx) => ( - ( + + + {d.pair} + + 0 ? "success.main" : "error.main", }} > - - {d.pair} - - 0 ? "success.main" : "error.main", - }} - > - {d.correlation.toFixed(3)} - - - ))} - {corrData.filter((d) => Math.abs(d.correlation) > 0.5).length === - 0 && ( + {d.correlation.toFixed(3)} + + + ))} + {strongCorrelations.length === 0 && ( sum + v, 0)}`} + label={`Total: ${Object.values(nan).reduce( + (sum, v) => sum + v, + 0, + )}`} size="small" variant="outlined" sx={{ fontWeight: "bold" }} @@ -110,9 +113,15 @@ const OverviewTab = ({ - + - +