# Source code for pandas_cat.profile

import base64
import logging
import re
import warnings
from io import BytesIO
from importlib.metadata import version as _pkg_version, PackageNotFoundError as _PackageNotFoundError
from pathlib import Path

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import scipy.stats as ss
from jinja2 import Environment, FileSystemLoader

from pandas_cat.statistics import _cramers_corrected_stat, _theils_u
from pandas_cat.visualisation import _plot_histogram, _plot_continuous_histogram, _plot_memory_bar
from pandas_cat.preparation import _is_continuous, _to_float_codes, prepare, handle_missing_values

# Module-level logger, used for debug-level progress messages.
_log = logging.getLogger(__name__)

# Matches the mode tag a template may declare, e.g.
# ``{# pandas-cat: mode=interactive #}``; group 1 captures the mode word.
_TEMPLATE_MODE_RE = re.compile(r'\{#\s*pandas-cat:\s*mode\s*=\s*(\w+)\s*#\}')

# Version of the installed "pandas-cat" distribution, exposed to templates via
# the rendering context; falls back to "unknown" when no package metadata
# is found.
try:
    version_string = _pkg_version("pandas-cat")
except _PackageNotFoundError:
    version_string = "unknown"

# File name of the template used when no template (or "default") is requested.
template_name = "default.html.j2"


def _humanbytes(B: float) -> str:
    """Return bytes as a human-friendly string (B, KB, MB, GB, TB)."""
    power = 2 ** 10
    n = 0
    labels = {0: "B", 1: "KB", 2: "MB", 3: "GB", 4: "TB"}
    while B > power:
        B /= power
        n += 1
    return f"{B:.2f} {labels[n]}"


def _prepare_profile_data(
    df: pd.DataFrame,
    options: dict,
    verbose: bool = True,
) -> dict:
    """Handle missing values, filter columns, and build per-column metadata.

    Single source of filtering and per-column statistics for the rendering
    context, so every renderer sees the same columns and numbers.

    Columns are dropped when they hold no valid value at all, or when they
    are not continuous and have more categories than ``options["cat_limit"]``.
    Each drop emits a ``UserWarning`` and is recorded in ``warning_info``.

    :param df: DataFrame to profile.
    :param options: Settings dict; reads ``na_values``, ``na_ignore``,
        ``keep_default_na`` and ``cat_limit``.
    :param verbose: Print progress messages when true.
    :returns: A dict with keys
        ``df``, ``detected_missings``, ``replaced_counts``, ``warning_info``,
        ``excluded_attributes``, ``columns_meta``, ``cat_cols``, ``cont_cols``.
    """
    if verbose:
        print("Handling missing values...")
    df, detected_missings, replaced_counts = handle_missing_values(
        df,
        options["na_values"],
        options["na_ignore"],
        options["keep_default_na"],
    )

    cat_limit = options["cat_limit"]
    warning_info: list = []
    excluded_attributes: list = []
    cols_to_drop: list = []

    if verbose:
        print("Profiling columns...")
    for col in df.columns:
        uniq = df[col].unique()
        _log.debug("column %s: %d unique values", col, len(uniq))

        # Remove all-NA columns
        if len(uniq) == 0 or (len(uniq) == 1 and pd.isna(uniq[0])):
            msg = f"Variable '{col}' removed — no valid values."
            warnings.warn(msg, UserWarning, stacklevel=2)
            warning_info.append({"type": "alert-warning", "text": msg})
            cols_to_drop.append(col)
            continue

        # Apply cat-limit on categorical variables
        if not _is_continuous(df[col]):
            n_cats = len(df[col].value_counts())
            if n_cats > cat_limit:
                msg = (
                    f"Variable '{col}' removed — {n_cats} categories exceeds "
                    f"limit of {cat_limit}. Increase cat_limit to include it."
                )
                warnings.warn(msg, UserWarning, stacklevel=2)
                warning_info.append({"type": "alert-warning", "text": msg})
                excluded_attributes.append({"attribute": col, "categories": n_cats})
                cols_to_drop.append(col)

    if cols_to_drop:
        df = df.drop(columns=cols_to_drop)

    # Measure the frame once instead of once per column.  ``index=False``
    # leaves out the "Index" entry, so a column that happens to be named
    # "Index" still resolves to a single value.
    mem_usage = df.memory_usage(deep=True, index=False)

    columns_meta: list = []
    for col in df.columns:
        series = df[col]
        is_cont = _is_continuous(series)
        missing_count = int(series.isna().sum())

        # Only a non-continuous column with a ``.cat`` accessor can be ordered.
        is_ordered = bool(
            not is_cont and hasattr(series, "cat") and series.cat.ordered
        )
        meta: dict = {
            "name": col,
            "is_continuous": is_cont,
            "is_ordered": is_ordered,
            "missing": missing_count,
            "ram": _humanbytes(mem_usage[col]),
            "detected": [str(v) for v in detected_missings.get(col, [])],
            "replaced": [int(v) for v in replaced_counts.get(col, [])],
        }

        if not is_cont:
            cat_counts = series.value_counts()
            if is_ordered:
                # Keep the declared category order, including categories
                # that never occur (count 0).
                cat_counts = cat_counts.reindex(series.cat.categories, fill_value=0)
            raw_counts = cat_counts.values.tolist()
            # Percentages are relative to all rows, missing ones included.
            denominator = int(cat_counts.values.sum()) + missing_count
            meta["categories"] = cat_counts.index.tolist()
            meta["counts"] = [int(v) for v in raw_counts]
            meta["percentages"] = [
                float(round(v / denominator * 100, 2)) for v in raw_counts
            ]
            meta["_cat_counts"] = cat_counts  # avoids recomputing in renderers

        columns_meta.append(meta)

    cat_cols = [m["name"] for m in columns_meta if not m["is_continuous"]]
    cont_cols = [m["name"] for m in columns_meta if m["is_continuous"]]

    return {
        "df": df,
        "detected_missings": detected_missings,
        "replaced_counts": replaced_counts,
        "warning_info": warning_info,
        "excluded_attributes": excluded_attributes,
        "columns_meta": columns_meta,
        "cat_cols": cat_cols,
        "cont_cols": cont_cols,
    }


def _build_context(
    df: pd.DataFrame,
    options: dict,
    dataset_name: str | None,
    verbose: bool = True,
) -> dict:
    """Build the unified rendering context consumed by renderers.

    Calls _prepare_profile_data internally, then assembles raw attribute
    profiles, correlations, and summary data.  Private keys (``_``-prefixed)
    carry data needed by _add_svg_charts and are not passed to templates.
    """
    if not isinstance(df, pd.DataFrame):
        raise TypeError(f"df must be a pandas DataFrame, got {type(df).__name__}")

    prepared = _prepare_profile_data(df, options, verbose)
    working_df = prepared["df"]
    columns_meta = prepared["columns_meta"]
    cat_cols = prepared["cat_cols"]
    cont_cols = prepared["cont_cols"]
    cat_set = set(cat_cols)
    size_total = working_df.memory_usage(deep=True).sum()
    nrows = len(working_df)

    # --- Attribute profiles (raw data, no SVGs) ---
    if verbose:
        print("Preparing summary...")
    attribute_profiles: list = []
    summ_vars: list = []
    chart_names: list = []
    chart_values: list = []

    for meta in columns_meta:
        col = meta["name"]
        var_size = working_df[col].memory_usage(deep=True)
        col_size = working_df[[col]].memory_usage(deep=True).sum()
        missings = meta["missing"]
        missings_pct = missings / nrows * 100 if nrows > 0 else 0.0

        entry: dict = {
            "attribute": col,
            "is_continuous": meta["is_continuous"],
            "is_ordered": meta["is_ordered"],
            "missing": missings,
            "ram": meta["ram"],
            "detected": meta["detected"],
            "replaced": meta["replaced"],
            "categories": meta.get("categories", []),
            "counts": meta.get("counts", []),
            "percentages": meta.get("percentages", []),
        }

        if meta["is_continuous"]:
            series = working_df[col].dropna()
            if len(series) > 0:
                hist_counts, bin_edges = np.histogram(series, bins="auto")
                bin_mids = [
                    float((bin_edges[i] + bin_edges[i + 1]) / 2)
                    for i in range(len(hist_counts))
                ]
                denom = int(hist_counts.sum()) + missings
                entry["histogram_bins"] = bin_mids
                entry["histogram_counts"] = [int(c) for c in hist_counts]
                entry["histogram_percentages"] = [
                    round(float(c) / denom * 100, 2) if denom > 0 else 0.0
                    for c in hist_counts
                ]
                entry["stats"] = {
                    "mean":   round(float(series.mean()),         6),
                    "std":    round(float(series.std()),          6),
                    "min":    round(float(series.min()),          6),
                    "max":    round(float(series.max()),          6),
                    "median": round(float(series.median()),       6),
                    "q1":     round(float(series.quantile(0.25)), 6),
                    "q3":     round(float(series.quantile(0.75)), 6),
                }
                entry["summary_tbl"] = {
                    "Mean":     f"{series.mean():,.4g}",
                    "Median":   f"{series.median():,.4g}",
                    "Std Dev":  f"{series.std():,.4g}",
                    "Min":      f"{series.min():,.4g}",
                    "Max":      f"{series.max():,.4g}",
                    "Q1 (25%)": f"{series.quantile(0.25):,.4g}",
                    "Q3 (75%)": f"{series.quantile(0.75):,.4g}",
                    "Missings": f"{missings:,} ({missings_pct:.2f}%)",
                    "Memory":   _humanbytes(col_size),
                }
            else:
                entry["histogram_bins"] = []
                entry["histogram_counts"] = []
                entry["histogram_percentages"] = []
                entry["stats"] = {
                    "mean": 0.0, "std": 0.0, "min": 0.0, "max": 0.0,
                    "median": 0.0, "q1": 0.0, "q3": 0.0,
                }
                entry["summary_tbl"] = {
                    "Mean": "—", "Median": "—", "Std Dev": "—",
                    "Min": "—", "Max": "—", "Q1 (25%)": "—", "Q3 (75%)": "—",
                    "Missings": f"{missings:,} ({missings_pct:.2f}%)",
                    "Memory":   _humanbytes(col_size),
                }
            entry["freq_tbl"] = []
            cat_display = "Continuous"
            info_display = (
                f"mean={working_df[col].mean():.4g}, "
                f"std={working_df[col].std():.4g}"
            )
        else:
            cat_counts = meta["_cat_counts"]
            most_frequent = int(cat_counts.max())
            freq_tbl = []
            for cat_name, count in cat_counts.items():
                count = int(count)
                pct = count / nrows * 100
                freq_tbl.append({
                    "name": cat_name,
                    "count": count,
                    "pct": f"{pct:.2f}%",
                    "pct_num": pct,
                    "fmt_width": f"{count / most_frequent * 100:.2f}%",
                })
            vc = working_df[col].value_counts()
            idxmax, idxmin = vc.idxmax(), vc.idxmin()
            cnt_max, cnt_min = int(vc.max()), int(vc.min())
            pct_max = cnt_max / nrows * 100
            pct_min = cnt_min / nrows * 100
            entry["freq_tbl"] = freq_tbl
            entry["summary_tbl"] = {
                "Categories":     str(working_df[col].nunique()),
                "Most frequent":  f"{idxmax} ({cnt_max:,} values, {pct_max:.2f}%)",
                "Least frequent": f"{idxmin} ({cnt_min:,} values, {pct_min:.2f}%)",
                "Missings":       f"{missings:,} ({missings_pct:.2f}%)",
                "Memory":         _humanbytes(col_size),
            }
            cat_display = str(len(meta["categories"]))
            info_display = ", ".join(str(c) for c in meta["categories"])

        attribute_profiles.append(entry)
        summ_vars.append({
            "Attribute": col,
            "Categories": cat_display,
            "Categories_list": info_display,
            "Memory_usage": var_size,
            "Memory_usage_hr": _humanbytes(var_size),
        })
        chart_names.append(col)
        chart_values.append(var_size)

    # Scale chart values to a common unit for the memory bar chart.
    unit = "Bytes"
    chart_values_scaled = list(chart_values)
    tot = sum(chart_values)
    for threshold, label in [(3e12, "TB"), (3e9, "GB"), (3e6, "MB"), (3e3, "KB")]:
        if tot > threshold:
            divisor = {"TB": 1e12, "GB": 1e9, "MB": 1e6, "KB": 1e3}[label]
            chart_values_scaled = [v / divisor for v in chart_values]
            unit = label
            break

    if verbose:
        print("Preparing summary...done")

    # --- Correlations (raw {x,y,v} lists) ---
    if verbose:
        print("Preparing correlations...")
    correlations_data: dict = {"Cramers V": [], "Spearman Rank": [], "Theils U": []}
    all_cols = cat_cols + cont_cols
    for col_a in all_cols:
        for col_b in all_cols:
            a_is_cat = col_a in cat_set
            b_is_cat = col_b in cat_set

            cramers_v = 0.0
            if a_is_cat and b_is_cat:
                cm = pd.crosstab(working_df[col_a], working_df[col_b])
                cramers_v = round(float(_cramers_corrected_stat(cm)), 3)
            correlations_data["Cramers V"].append({"x": col_a, "y": col_b, "v": cramers_v})

            arr_a = (
                _to_float_codes(working_df[col_a]) if a_is_cat
                else working_df[col_a].to_numpy(dtype=float)
            )
            arr_b = (
                _to_float_codes(working_df[col_b]) if b_is_cat
                else working_df[col_b].to_numpy(dtype=float)
            )
            mask = ~(np.isnan(arr_a) | np.isnan(arr_b))
            spearman = (
                round(float(ss.spearmanr(arr_a[mask], arr_b[mask])[0]), 3)
                if mask.sum() >= 2 else 0.0
            )
            correlations_data["Spearman Rank"].append({"x": col_a, "y": col_b, "v": spearman})

            tu = 0.0
            if a_is_cat and b_is_cat:
                tu = round(float(_theils_u(working_df[col_a], working_df[col_b])), 3)
            correlations_data["Theils U"].append({"x": col_a, "y": col_b, "v": tu})

    for col_a in cat_cols:
        for col_b in cat_cols:
            cm = pd.crosstab(working_df[col_a], working_df[col_b])
            crosstab_data = cm.to_dict(orient="split")
            key = f"{col_a} x {col_b}"
            if key not in correlations_data:
                correlations_data[key] = []
            for k, cat_a in enumerate(crosstab_data["index"]):
                for l, cat_b in enumerate(crosstab_data["columns"]):
                    correlations_data[key].append(
                        {"x": cat_a, "y": cat_b, "v": float(crosstab_data["data"][k][l])}
                    )

    if verbose:
        print("Preparing correlations...done")

    return {
        "title": dataset_name or "DataFrame",
        "version_string": version_string,
        "warning_info": prepared["warning_info"],
        "excluded_attributes": prepared["excluded_attributes"],
        "records_count": nrows,
        "attribute_count": len(working_df.columns),
        "missing_count": int(working_df.isnull().sum().sum()),
        "total_ram": _humanbytes(size_total),
        "attribute_profiles": attribute_profiles,
        "correlations_data": correlations_data,
        "df_summary": {
            "overall_table": {
                "Records": f"{nrows:,}",
                "Columns": f"{len(working_df.columns):,}",
                "Memory usage": _humanbytes(size_total),
            },
            "Profiles": summ_vars,
        },
        # Private: consumed by _add_svg_charts, not passed to templates.
        "_df": working_df,
        "_cat_cols": cat_cols,
        "_cont_cols": cont_cols,
        "_chart_names": chart_names,
        "_chart_values_scaled": chart_values_scaled,
        "_unit": unit,
    }


def _heatmap_svg_b64(
    data: pd.DataFrame,
    figsize: tuple,
    *,
    title: str | None = None,
    tight_layout: bool = False,
    **heatmap_kwargs,
) -> str:
    """Render *data* as a seaborn heatmap and return it as base64-encoded SVG.

    The figure is closed even when drawing or saving fails, so a failing
    chart does not leave an open matplotlib figure behind.
    """
    fig, ax = plt.subplots(figsize=figsize)
    try:
        sns.heatmap(data, ax=ax, **heatmap_kwargs)
        if title is not None:
            ax.set_title(title)
        if tight_layout:
            fig.tight_layout()
        buf = BytesIO()
        fig.savefig(buf, format="svg")
    finally:
        plt.close(fig)
    return base64.b64encode(buf.getvalue()).decode("utf-8")


def _add_svg_charts(context: dict, verbose: bool = True) -> None:
    """Add SVG chart fields to context in-place. Called only by static renderers.

    Adds ``fcont`` to every attribute profile, ``mem_usg_svg`` to
    ``context["df_summary"]``, and a ``corr`` dict holding the heatmaps
    (base64-encoded SVG) together with the Spearman and Theil's U lists.

    :param context: Context produced by _build_context; its ``_``-prefixed
        private keys and ``correlations_data`` are read here.
    :param verbose: Print progress messages when true.
    """
    working_df = context["_df"]
    cat_cols = context["_cat_cols"]
    cont_cols = context["_cont_cols"]
    correlations = context["correlations_data"]

    # Per-column charts
    if verbose:
        print("Preparing individual profiles...")
    for entry in context["attribute_profiles"]:
        col = entry["attribute"]
        if entry["is_continuous"]:
            entry["fcont"] = _plot_continuous_histogram(working_df[col], col)
        else:
            entry["fcont"] = _plot_histogram(
                working_df[[col]], col, sort=False, save=False, rotate=False
            )
    if verbose:
        print("Preparing individual profiles...done")

    # Memory bar chart
    context["df_summary"]["mem_usg_svg"] = (
        _plot_memory_bar(
            context["_chart_names"], context["_chart_values_scaled"], context["_unit"]
        ) if context["_chart_names"] else ""
    )

    # Cramér's V heatmap SVG (rebuilt from correlations_data to avoid recomputing)
    overall_corr = None
    if cat_cols:
        cat_set = set(cat_cols)  # O(1) membership while filtering the pairs
        cv_map = {
            (p["x"], p["y"]): p["v"]
            for p in correlations["Cramers V"]
            if p["x"] in cat_set and p["y"] in cat_set
        }
        ct = pd.DataFrame(
            [[cv_map.get((i, j), 0.0) for j in cat_cols] for i in cat_cols],
            index=cat_cols, columns=cat_cols, dtype=float,
        )
        sns.set_theme(style="whitegrid")
        overall_corr = _heatmap_svg_b64(
            ct, (16, 4),
            title="Cramér's V (categorical variables)", tight_layout=True,
            annot=True, cmap="Blues", fmt=".2f", linewidth=1,
        )

    # Pearson correlation heatmap SVG
    pearson_corr = None
    if len(cont_cols) >= 2:
        pearson_matrix = working_df[cont_cols].corr(method="pearson")
        sns.set_theme(style="whitegrid")
        pearson_corr = _heatmap_svg_b64(
            pearson_matrix, (16, max(4, len(cont_cols))),
            title="Pearson Correlation (continuous variables)", tight_layout=True,
            annot=True, cmap="RdBu_r", center=0,
            fmt=".2f", linewidth=1, vmin=-1, vmax=1,
        )

    # Per-pair crosstab heatmap SVGs — reuse data already computed in _build_context
    indiv_corr: dict = {}
    if cat_cols:
        # The theme is the same for every pair, so set it once.
        sns.set_theme(style="whitegrid")
    for i in cat_cols:
        _log.debug("crosstabs for %s...", i)
        dict2 = {}
        for j in cat_cols:
            points = correlations[f"{i} x {j}"]
            # dict.fromkeys de-duplicates while keeping first-seen order.
            rows = list(dict.fromkeys(p["x"] for p in points))
            cols = list(dict.fromkeys(p["y"] for p in points))
            lookup = {(p["x"], p["y"]): p["v"] for p in points}
            ct = pd.DataFrame(
                [[lookup[(r, c)] for c in cols] for r in rows],
                index=rows, columns=cols,
            )
            # Figure height grows with the number of rows, minimum 4.
            n_rows = max(4, ct.shape[0] * 0.4 + 1)
            dict2[j] = _heatmap_svg_b64(
                ct, (16, n_rows), annot=True, cmap="Blues", fmt="g",
            )
        indiv_corr[i] = {"attribute": i, "vars": dict2}

    context["corr"] = {
        "overall_corr": overall_corr,
        "pearson_corr": pearson_corr,
        "indiv_corr": indiv_corr,
        "spearman_rank": correlations["Spearman Rank"],
        "theils_u": correlations["Theils U"],
    }


def _resolve_template(template: str | None) -> tuple[Environment, str, Path]:
    """Turn the ``template`` argument into ``(environment, name, path)``.

    ``None``/``'default'``, ``'modern'`` and ``'interactive'`` select the
    templates shipped in the ``templates`` directory next to this module.
    Any other value is taken as a path to a custom template file.

    :raises ValueError: If a custom template path does not point to a file.
    """
    bundled_root = Path(__file__).parent / "templates"
    # Built-in name -> (directory searched by the loader, template file name)
    bundled = {
        "default": (bundled_root, template_name),
        "modern": (bundled_root, "modern.html.j2"),
        "interactive": (bundled_root / "interactive", "interactive.html.j2"),
    }

    choice = "default" if template is None else template
    if choice in bundled:
        search_dir, tmpl_name = bundled[choice]
        tmpl_path = search_dir / tmpl_name
    else:
        custom_path = Path(template)
        if not custom_path.is_file():
            raise ValueError(f"Template file not found: {custom_path}")
        search_dir = custom_path.parent
        tmpl_name = custom_path.name
        tmpl_path = custom_path.resolve()

    env = Environment(loader=FileSystemLoader(search_dir), autoescape=True)
    return env, tmpl_name, tmpl_path


def _read_template_mode(tmpl_path: Path) -> str:
    """Return the rendering mode declared by the template's pandas-cat mode tag.

    The template file is read as UTF-8 and searched for the first
    ``{# pandas-cat: mode=<word> #}`` tag; the captured word is returned
    unchanged (e.g. ``'interactive'``).  Without a tag the result is
    ``'static'``.
    """
    source = tmpl_path.read_text(encoding="utf-8")
    found = _TEMPLATE_MODE_RE.search(source)
    if found is None:
        return "static"
    return found.group(1)


def _profile_render(
    df: pd.DataFrame,
    dataset_name: str | None,
    out_html: str,
    options: dict,
    env: Environment,
    tmpl_name: str,
    mode: str,
    verbose: bool = True,
) -> None:
    """Build the context, render the template, and write the HTML report.

    SVG charts are added for every mode other than ``"interactive"``.  The
    report is written as UTF-8 to ``<cwd>/report/<out_html>``; the ``report``
    directory is created when missing.
    """
    context = _build_context(df, options, dataset_name, verbose)
    needs_svg = mode != "interactive"
    if needs_svg:
        _add_svg_charts(context, verbose)

    if verbose:
        print("Preparing report...")

    # Keys starting with "_" are internal and are withheld from the template.
    template_vars = {}
    for key, value in context.items():
        if key.startswith("_"):
            continue
        template_vars[key] = value
    html = env.get_template(tmpl_name).render(**template_vars)

    target_dir = Path.cwd() / "report"
    target_dir.mkdir(exist_ok=True)
    out_path = target_dir / out_html
    out_path.write_text(html, encoding="utf-8")

    if verbose:
        print(f"Report ready: {out_path}")


[docs] def profile( df: pd.DataFrame | None = None, dataset_name: str | None = None, template: str | None = None, out_html: str = "report.html", opts: dict | None = None, verbose: bool = True, ) -> None: """Profile a dataset and write an HTML report. The report is written to ``<cwd>/report/<out_html>``. The directory is created automatically if it does not exist. Categorical columns produce frequency bar charts and crosstab heatmaps. Numeric (continuous) columns produce histograms with mean/median overlays and descriptive statistics. :param df: DataFrame to profile. :param dataset_name: Title shown in the report header. :param template: Built-in name (``None``/``'default'``, ``'modern'``, ``'interactive'``) or a file-system path to a custom ``.html.j2`` template. Custom templates declare their rendering mode with ``{# pandas-cat: mode=interactive #}``; anything without that tag renders as static (SVG charts). :param out_html: Output filename (basename only). :param opts: Optional settings dict: * **auto_prepare** (*bool*, default ``True``) * **cat_limit** (*int*, default ``20``) * **na_values** (*list*) * **na_ignore** (*list*) * **keep_default_na** (*bool*, default ``True``) :returns: ``None``. """ if not isinstance(df, pd.DataFrame): raise TypeError(f"df must be a pandas DataFrame, got {type(df).__name__}") default_opts: dict = { "auto_prepare": True, "cat_limit": 20, "na_values": None, "na_ignore": None, "keep_default_na": True, } options = default_opts if opts is None else {**default_opts, **opts} working_df = df if options["auto_prepare"]: if verbose: print("Auto-preparing data...") working_df = prepare(df=working_df, opts=opts, verbose=verbose) if verbose: print("Auto-prepare done.") env, tmpl_name, tmpl_path = _resolve_template(template) mode = _read_template_mode(tmpl_path) _profile_render(working_df, dataset_name, out_html, options, env, tmpl_name, mode, verbose)