import numpy as np
import matplotlib.pyplot as plt
from dataclasses import dataclass
from typing import Tuple, Dict

@dataclass
class DPAHParams:
    """Configuration bundle for the DPAH model.

    Fields fall into four groups: grid resolution per axis, grid extent per
    axis, the film-suppression strength together with the regime selector,
    and one target state per regime. Units are not enforced here; the
    plotting code in this file labels the axes as P in hPa, T and D in K,
    and ELR in K/km.
    """

    # --- Grid resolution: number of grid points along each axis ---
    n_bins_P: int = 12
    n_bins_T: int = 15
    n_bins_D: int = 12      # D is the dewpoint axis
    n_bins_ELR: int = 8

    # --- Grid extent: (low, high) endpoints of each axis ---
    P_range: Tuple[float, float] = (997.5, 1032.5)
    T_range: Tuple[float, float] = (285.0, 320.0)
    D_range: Tuple[float, float] = (280.0, 300.0)
    ELR_range: Tuple[float, float] = (4.0, 10.0)

    # --- Physics switches ---
    film_suppression: float = 0.0   # 0.0 leaves the preferred temperature unscaled
    regime: str = "Ascent"          # the model compares this against "Ascent" / "Descent"

    # --- Target state used when regime == "Ascent" ---
    target_P_ascent: float = 1010.0
    target_T_ascent: float = 299.0
    target_D_ascent: float = 294.0
    target_ELR_ascent: float = 6.3

    # --- Target state used for the other regime ---
    target_P_descent: float = 1025.0
    target_T_descent: float = 305.0
    target_D_descent: float = 288.0
    target_ELR_descent: float = 8.0

class DPAHMarkovModel:
    """Score-based distribution over a discretised (P, T, D, ELR) state grid.

    Each axis is sampled with ``np.linspace`` between the endpoints given in
    the parameter object, endpoints included. A "preferred" surface
    temperature is computed for every grid cell, combined with regime-specific
    target biases into a score array, and normalised. That normalised score
    array is what this class exposes as ``stationary_dist``; no transition
    matrix is built here.

    Attributes:
        params: The parameter object passed to the constructor.
        n_bins: Number of grid points per axis, keyed 'P', 'T', 'D', 'ELR'.
        ranges: (low, high) endpoints per axis, same keys.
        state_shape: Grid shape in (P, T, D, ELR) order.
        total_states: Number of grid cells.
        preferred_T: Preferred temperature per cell, or None before
            ``precompute_physics`` has run.
        scores: Normalised scores per cell, or None before ``build_scores``.
        stationary_dist: Result of ``compute_stationary``, or None before it.
    """

    # Regime names understood by the model.
    _REGIMES = ("Ascent", "Descent")

    # Bar width as a fraction of the grid spacing. The fractions are chosen so
    # that the default grids give the widths 2.2 (T), 2.8 (P), 1.7 (D) and
    # 0.6 (ELR), while other bin counts or ranges scale the bars with the grid.
    _BAR_FILL = {'T': 0.88, 'P': 0.88, 'D': 0.935, 'ELR': 0.7}

    def __init__(self, params: "DPAHParams"):
        """Store the parameters and derive the grid layout from them.

        Raises:
            ValueError: If ``params.regime`` is neither "Ascent" nor "Descent".
        """
        self.params = params
        self.n_bins = {'P': params.n_bins_P, 'T': params.n_bins_T,
                       'D': params.n_bins_D, 'ELR': params.n_bins_ELR}
        self.ranges = {'P': params.P_range, 'T': params.T_range,
                       'D': params.D_range, 'ELR': params.ELR_range}
        self.state_shape = tuple(self.n_bins.values())
        # Plain int rather than a numpy scalar.
        self.total_states = int(np.prod(self.state_shape))
        self.preferred_T = None
        self.scores = None
        self.stationary_dist = None
        # Fail fast: an unrecognised regime would otherwise pick the descent
        # targets in build_scores while skipping the descent cap in
        # limited_descent.
        self._check_regime()

    def _check_regime(self) -> str:
        """Return the configured regime, raising ValueError if it is unknown."""
        regime = self.params.regime
        if regime not in self._REGIMES:
            raise ValueError(
                f"Unknown regime {regime!r}; expected one of {self._REGIMES}"
            )
        return regime

    def _axis_values(self, axis: str) -> np.ndarray:
        """Return the evenly spaced grid values (endpoints included) for one axis."""
        low, high = self.ranges[axis]
        return np.linspace(low, high, self.n_bins[axis])

    def _stationary(self) -> np.ndarray:
        """Return the stationary distribution, computing it on first use."""
        if self.stationary_dist is None:
            self.compute_stationary()
        return self.stationary_dist

    def _bar_width(self, axis: str, vals: np.ndarray) -> float:
        """Return a bar width proportional to the grid spacing of ``axis``."""
        if len(vals) > 1:
            spacing = abs(vals[1] - vals[0])
        else:
            # A single grid point has no spacing; fall back to the axis extent.
            low, high = self.ranges[axis]
            spacing = abs(high - low)
        if spacing > 0:
            return self._BAR_FILL[axis] * spacing
        return 0.8  # degenerate axis: use matplotlib's default bar width

    def moist_adiabat_upward(self, P_sfc: float, T_sfc: float, D_sfc: float) -> Tuple[float, float]:
        """Estimate the LCL pressure and temperature from a surface state.

        The LCL height is a linear function of the (non-negative) dewpoint
        depression, the lapse rate used up to the LCL decreases linearly with
        dewpoint, and the pressure follows from a power law in the temperature
        ratio. The height is divided by 1000 before being multiplied by the
        lapse rate, which suggests metres and K/km, but units are not enforced.

        Args:
            P_sfc: Surface pressure.
            T_sfc: Surface temperature.
            D_sfc: Surface dewpoint.

        Returns:
            ``(P_lcl, T_lcl)``.
        """
        depression = T_sfc - D_sfc
        # A negative depression (D above T) is treated as zero.
        lcl_height = 125 * max(depression, 0.0) + 100
        gamma_moist = 6.5 - 0.5 * (D_sfc - 280) / 20
        T_lcl = T_sfc - gamma_moist * (lcl_height / 1000)
        P_lcl = P_sfc * (T_lcl / T_sfc) ** (9.81 / (287 * 0.0065))
        return P_lcl, T_lcl

    def limited_descent(self, P_sfc: float, T_lcl: float, P_lcl: float, ELR: float) -> float:
        """Return the preferred surface temperature, clipped to the T range.

        The LCL temperature is increased in proportion to ``ELR`` and the
        pressure difference between the surface and the LCL. In the "Descent"
        regime, any excess above 302 is subtracted twice, which reflects the
        value back below 302. The result is clipped to ``ranges['T']``.

        Args:
            P_sfc: Surface pressure.
            T_lcl: Temperature at the LCL.
            P_lcl: Pressure at the LCL.
            ELR: Lapse rate applied over the pressure difference.

        Returns:
            The clipped preferred surface temperature.
        """
        T_surface_preferred = T_lcl + ELR * (P_sfc - P_lcl) * 0.01
        if self.params.regime == "Descent":
            T_surface_preferred -= 2.0 * max(0, T_surface_preferred - 302)
        t_low, t_high = self.ranges['T']
        return np.clip(T_surface_preferred, t_low, t_high)

    def precompute_physics(self):
        """Fill ``preferred_T`` with the preferred temperature for every grid cell.

        Each cell value is the ``limited_descent`` result scaled by
        ``1 - film_suppression * (D / 295)``.
        """
        preferred = np.zeros(self.state_shape)
        P_vals = self._axis_values('P')
        T_vals = self._axis_values('T')
        D_vals = self._axis_values('D')
        ELR_vals = self._axis_values('ELR')
        film = self.params.film_suppression
        for i_p, P in enumerate(P_vals):
            for i_t, T in enumerate(T_vals):
                for i_d, D in enumerate(D_vals):
                    # The LCL state and the suppression factor do not depend on
                    # ELR, so they are computed once per (P, T, D) triple.
                    P_lcl, T_lcl = self.moist_adiabat_upward(P, T, D)
                    suppression_factor = 1.0 - film * (D / 295.0)
                    for i_elr, ELR in enumerate(ELR_vals):
                        pref_T = self.limited_descent(P, T_lcl, P_lcl, ELR)
                        preferred[i_p, i_t, i_d, i_elr] = pref_T * suppression_factor
        self.preferred_T = preferred

    def build_scores(self):
        """Compute the normalised score for every grid cell into ``scores``.

        Runs ``precompute_physics`` first if ``preferred_T`` has not been
        computed yet. The scores are divided by their sum whenever that sum is
        positive.

        Raises:
            ValueError: If ``params.regime`` is neither "Ascent" nor "Descent".
        """
        if self.preferred_T is None:
            # Allow build_scores to be called on its own.
            self.precompute_physics()

        cfg = self.params
        if self._check_regime() == "Ascent":
            target_P, target_T, target_D, target_ELR = (
                cfg.target_P_ascent, cfg.target_T_ascent,
                cfg.target_D_ascent, cfg.target_ELR_ascent)
            bias_scale = 14.0
        else:
            target_P, target_T, target_D, target_ELR = (
                cfg.target_P_descent, cfg.target_T_descent,
                cfg.target_D_descent, cfg.target_ELR_descent)
            bias_scale = 12.0

        # Broadcast each axis along its own dimension of the 4-D grid.
        p_grid = self._axis_values('P')[:, None, None, None]
        t_grid = self._axis_values('T')[None, :, None, None]
        d_grid = self._axis_values('D')[None, None, :, None]
        elr_grid = self._axis_values('ELR')[None, None, None, :]

        # Temperature terms: closeness to the preferred temperature and to the
        # regime target, with exponential penalties below 297 and above 305.
        physics_t_score = np.exp(-np.abs(t_grid - self.preferred_T) / 0.32)
        warm_bias = np.exp(-np.abs(t_grid - target_T) / 0.85)
        cold_penalty = np.where(t_grid < 297, np.exp((t_grid - 297) / 2.2), 1.0)
        hot_penalty = np.where(t_grid > 305, np.exp((305 - t_grid) / 1.8), 1.0)
        # The dewpoint boost only applies where T exceeds 297.
        moist_boost = np.exp(-np.abs(d_grid - target_D) / 1.5) * (t_grid > 297)

        p_score = np.exp(-np.abs(p_grid - target_P) / 8.0)
        d_score = np.exp(-np.abs(d_grid - target_D) / 1.6)
        elr_score = np.exp(-np.abs(elr_grid - target_ELR) / 1.3)

        # bias_scale is a uniform factor, so it cancels in the normalisation
        # below whenever the total is positive.
        self.scores = (physics_t_score * 0.55 + warm_bias * 0.30 + moist_boost * 0.15) * \
                      cold_penalty * hot_penalty * p_score * d_score * elr_score * bias_scale
        total = self.scores.sum()
        if total > 0:
            self.scores /= total

    def compute_stationary(self):
        """Run the full pipeline and return the normalised distribution.

        Recomputes ``preferred_T``, prints its mean as a diagnostic, rebuilds
        the scores and stores them as ``stationary_dist``.

        Returns:
            The array stored in ``stationary_dist``.
        """
        self.precompute_physics()
        mean_pref_T = np.mean(self.preferred_T)
        print(f"  [Diagnostic] Mean preferred_T (pure physics): {mean_pref_T:.2f} K")
        self.build_scores()
        self.stationary_dist = self.scores
        return self.stationary_dist

    def summarize(self) -> Dict:
        """Print and return the peak and mean temperature of the distribution.

        Computes the distribution first if ``compute_stationary`` has not been
        called yet.

        Returns:
            A dict with ``"mean_T"`` (distribution-weighted sum of T) and
            ``"peak_T"`` (grid value at the maximum of the T marginal).
        """
        dist = self._stationary()
        T_vals = self._axis_values('T')
        marg_T_idx = np.argmax(dist.sum(axis=(0, 2, 3)))
        mean_T = np.sum(dist * T_vals[None, :, None, None])
        T_peak = T_vals[marg_T_idx]
        print(f"Regime: {self.params.regime} | Peak T = {T_peak:.1f} K (bin {marg_T_idx})")
        print(f"Mean stationary T: {mean_T:.2f} K | Film suppression: {self.params.film_suppression}")
        return {"mean_T": mean_T, "peak_T": T_peak}

    def plot_marginals(self, title_suffix="", *, show=True):
        """Plot the four 1-D marginals in a 2x2 figure and save it as a PNG.

        Computes the distribution first if ``compute_stationary`` has not been
        called yet. The figure is written to
        ``Run401_<regime>_<title_suffix>_marginals.png`` in the current
        working directory.

        Args:
            title_suffix: Text appended to the figure title and embedded in
                the file name.
            show: When True (the default), display the figure with
                ``plt.show()`` after saving; pass False for headless runs.
        """
        dist = self._stationary()
        regime = self.params.regime
        fig, axs = plt.subplots(2, 2, figsize=(14, 10))
        try:
            # (axes, grid axis, dimensions summed out, panel title)
            panels = (
                (axs[0, 0], 'T', (0, 2, 3), f"{regime} Surface Temperature (K)"),
                (axs[0, 1], 'P', (1, 2, 3), "Surface Pressure (hPa)"),
                (axs[1, 0], 'D', (0, 1, 3), "Dewpoint D (K)"),
                (axs[1, 1], 'ELR', (0, 1, 2), "ELR (K/km)"),
            )
            for ax, axis, summed_dims, title in panels:
                vals = self._axis_values(axis)
                ax.bar(vals, dist.sum(axis=summed_dims),
                       width=self._bar_width(axis, vals))
                ax.set_title(title)

            mean_T = np.sum(dist * self._axis_values('T')[None, :, None, None])
            fig.suptitle(f"Run401 {regime} Marginals {title_suffix} | Mean T = {mean_T:.2f} K")
            fig.tight_layout()

            # AUTOMATIC PNG SAVE
            filename = f"Run401_{regime}_{title_suffix}_marginals.png"
            fig.savefig(filename, dpi=300, bbox_inches='tight')
            print(f"   → Plot saved: {filename}")

            if show:
                plt.show()
        finally:
            # Close only this figure so figures owned by the caller survive.
            plt.close(fig)

# ============== Test Run (with automatic PNG saving) ==============
def main() -> None:
    """Run both regimes at three film-suppression strengths.

    For every (regime, film) pair this builds a model, computes its
    distribution, prints the summary, saves the marginal plot and prints the
    state count together with the distribution sum.
    """
    print("=== Run401 v6 – Stable Baseline + Auto PNG Save ===")
    for regime_name in ("Ascent", "Descent"):
        print(f"\n--- {regime_name} Regime ---")
        for film in (0.0, 0.05, 0.10):
            params = DPAHParams(regime=regime_name, film_suppression=film)
            model = DPAHMarkovModel(params)
            dist = model.compute_stationary()
            model.summarize()
            model.plot_marginals(f"film{film}")
            # Use the array returned above for the normalisation check.
            print(f"Total states: {model.total_states} | Dist sum: {dist.sum():.6f}\n")
    print("\n✅ Run401 v6 complete — all plots saved as PNG files automatically!")


# Keep the driver inside main() so running the script does not leave
# params/model/film/dist behind as module-level globals.
if __name__ == "__main__":
    main()