#' @title   SIMULATE PREDICTIVE BANDS BASED ON NORMAL RESIDUALS AND BFS SCENARIOS
#'
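#' @description Simulates, for the selected scenario (`szenario_fhh`), one boundary of
#'   the 90% predictive band of total AHV expenditures over the projection interval
#'   `jahr_abr` to `jahr_ende_basismodell`. Pension counts and pension levels are
#'   extrapolated per domicile, pension type and sex, perturbed with Student-t
#'   distributed prediction errors, and mapped onto total expenditures.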
#'
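#' @param PARAM_GLOBAL Global parameters; the elements `jahr_abr`,
#'   `jahr_ende_basismodell` and `szenario_fhh` are used.
#' @param RANGE First years of the estimation windows: element 1 for pension counts,
#'   element 2 for pension levels abroad (`dom == "au"`), element 3 for pension levels
#'   in Switzerland (`dom == "ch"`), element 4 for the total expenditure regression.
#' @param A_DAT Analysis data with the columns `jahr`, `sex`, `dom`, `type`, `scen`,
#'   `n`, `m` and `mp`.
#' @param AHV_ABRECHNUNG_DEF Settlement data with the columns `jahr` and `aus_tot`.
#' @param F_DAT Regression data with the columns `jahr`, `exp_tot` and `d_nmmp`.
#' @param CORR Correction data with the columns `jahr` and `save_rel`.
#' @param RENTENENTWICKLUNG Pension development with the columns `jahr` and
#'   `minimalrente`.
#' @param WIDOW_BASISMODELL Widow pension projections with the columns `jahr`, `sex`,
#'   `dom`, `n`, `m` and `mp`.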
#'
#' @return A data frame with the following columns:
#'   - `jahr`: year.
#'   - `aus_tot`: rounded boundary of the predictive band of total AHV expenditures.
#'
#' @author [MAS BSV](mailto:sekretariat.mas@bsv.admin.ch)
#'
#' @export

# options(readr. show_col_types = FALSE)
# setwd("~/data/appl-wb/20_staff/kjo/fhh/2025-07-22T0923_u80874371_ahv_basis")
# 
# PARAM_GLOBAL <- 
#   read_delim("PARAM_GLOBAL.csv")
# 
# RANGE <- 
#   read_delim("RANGE.csv")
# 
# AHV_ABRECHNUNG_DEF <- 
#   read_delim("AHV_ABRECHNUNG_DEF.csv")
# 
# CORR <- 
#   read_delim("CORR.csv")
# 
# RENTENENTWICKLUNG <- 
#   read_delim("RENTENENTWICKLUNG.csv")
# 
# WIDOW_BASISMODELL <- 
#   read_delim("WIDOW_BASISMODELL.csv")

# Simulate one boundary of the 90% predictive band of total AHV expenditures.
#
# Steps:
#   1. Regress the yearly change of total expenditures on `d_nmmp` (`F_DAT`).
#   2. Collect yearly prediction moments (mean, standard deviation) for the changes in
#      pension counts `n` and pension levels `m` per domicile, pension type and sex.
#   3. Draw `N` Student-t perturbed paths per combination and year.
#   4. Map the paths onto total expenditures and take a yearly quantile (5% for
#      `szenario_fhh == "hoch"`, 95% otherwise).
#   5. Add the widow projections and subtract the savings from `CORR`.
#
# Arguments (only the columns/elements referenced in the body are listed):
#   PARAM_GLOBAL        `jahr_abr`, `jahr_ende_basismodell`, `szenario_fhh`.
#   RANGE               elements 1-4: first years of the estimation windows for counts,
#                       levels abroad ("au"), levels in Switzerland ("ch") and the
#                       total expenditure regression.
#   A_DAT               `jahr`, `sex`, `dom`, `type`, `scen`, `n`, `m`, `mp`.
#   AHV_ABRECHNUNG_DEF  `jahr`, `aus_tot`.
#   F_DAT               `jahr`, `exp_tot`, `d_nmmp`.
#   CORR                `jahr`, `save_rel`.
#   RENTENENTWICKLUNG   `jahr`, `minimalrente`.
#   WIDOW_BASISMODELL   bound to widow rows of `A_DAT`; `jahr`, `sex`, `dom`, `n`, `m`
#                       and `mp` are used.
#
# Returns a data frame (still grouped by `jahr`) with the columns `jahr` and `aus_tot`.
#
# Side effects: prints a status line and resets the global RNG state via `set.seed()`.
mod_ahv_szen <- function(PARAM_GLOBAL,
                         RANGE,
                         A_DAT,
                         AHV_ABRECHNUNG_DEF,
                         F_DAT,
                         CORR,
                         RENTENENTWICKLUNG,
                         WIDOW_BASISMODELL) {

  print("Run module: mod_ahv_szen")

  # Base value of the random seed (see `set.seed()` below).
  seed <- 95827323

  # Projection interval *including* the current year.
  pint <- PARAM_GLOBAL$jahr_abr:PARAM_GLOBAL$jahr_ende_basismodell

  # Fix degrees of freedom for respective regressions, taking into consideration that
  # differencing and estimating one parameter decreases the effective sample size by 2.
  df_nau <- first(pint) - RANGE[1] + 1 - 2
  df_mau <- first(pint) - RANGE[2] + 1 - 2
  df_mch <- first(pint) - RANGE[3] + 1 - 2
  df_tot <- first(pint) - RANGE[4] + 1 - 2

  # Load processed data from 'basismodell'. ----------------------------------------------

  ZAS <- AHV_ABRECHNUNG_DEF %>%
    select(jahr, exp_tot = aus_tot) %>%
    filter(jahr <= first(pint))

  # Foundational analysis data (see 'prepare_inputs.R').
  A_DAT <- A_DAT %>%
    arrange(sex, dom, type, jahr) %>%
    filter(jahr <= last(pint)) %>%
    left_join(ZAS, by = "jahr")

  # Keep the rows of scenario "H" plus the scenario matching `szenario_fhh`.
  # NOTE(review): any other value of `szenario_fhh` leaves all scenarios in `A_DAT`;
  # presumably callers only pass "hoch" or "tief" -- confirm.
  if (PARAM_GLOBAL$szenario_fhh == "hoch") {
    A_DAT <- filter(A_DAT, scen %in% c("H", "B"))
  }

  if (PARAM_GLOBAL$szenario_fhh == "tief") {
    A_DAT <- filter(A_DAT, scen %in% c("H", "C"))
  }

  A_DAT <- select(A_DAT, - scen)

  # Estimate pension top-up. -------------------------------------------------------------
  FIT_T <-
    lm(c(NA, diff(exp_tot)) ~ 0 + d_nmmp, F_DAT)

  MOM_T <-
    tibble(jahr = pint, mu_t = coef(FIT_T), sd_t = summary(FIT_T)$sigma)

  # Sum of squared regressor values of the top-up regression. Hoisted here because it is
  # constant across simulation runs.
  ssq_t <- sum(model.frame(FIT_T)[, 2]^2)

  # Collect prediction moments for explanatory variables. --------------------------------

  # Define function for extracting prediction errors from a simple regression without
  # intercept: residual sigma times sqrt(1 + h^2 / sum of squared regressor values).
  psd <- function(fit, h) {
    summary(fit)$sigma * sqrt(1 + h^2 / sum(model.frame(fit)[, 2]^2))
  }

  RES_P <- list()

  for (d in unique(A_DAT$dom)) {
    for (t in unique(filter(A_DAT, type != "wit")$type)) {
      for (s in unique(A_DAT$sex)) {

        # Constrain extrapolation data according to 'range': the regressor is 1 inside
        # the estimation window and NA (dropped by `lm`) before it.
        DATA <- A_DAT %>%
          filter(dom == d, type == t, sex == s) %>%
          mutate(t_n = ifelse(jahr >= RANGE[1], 1, NA))

        if (d == "au") {
          DATA <- DATA %>%
            mutate(t_m = ifelse(jahr >= RANGE[2], 1, NA))
        }

        if (d == "ch") {
          DATA <- DATA %>%
            mutate(t_m = ifelse(jahr >= RANGE[3], 1, NA))
        }

        # Set moment collection table.
        MOM <-
          tibble(jahr = pint[- 1], dom = d, sex = s, type = t)

        P_DAT <-
          filter(DATA, jahr > first(pint))

        # Collect prediction moments for pension counts.
        if (d == "au") {

          # Counts abroad are extrapolated by regression.
          FIT_N <-
            lm(c(NA, diff(n)) ~ 0 + t_n, DATA)

          MOM <- MOM %>%
            mutate(mu_n = predict(FIT_N, P_DAT), sd_n = psd(FIT_N, P_DAT$t_n))

        } else {

          # Otherwise the yearly changes of `n` in `A_DAT` are taken as given, without
          # prediction error.
          FA_DAT <-
            filter(A_DAT, jahr >= first(pint), sex == s, type == t, dom == d) %>%
            mutate(n = c(NA, diff(n))) %>%
            slice(- 1)

          MOM <- MOM %>%
            mutate(mu_n = FA_DAT$n, sd_n = 0)

        }

        # Collect prediction moments for pension levels. The prediction error is
        # evaluated at the regressor of this fit (`t_m`).
        FIT_M <-
          lm(c(NA, diff(m)) ~ 0 + t_m, DATA)

        MOM <- MOM %>%
          mutate(mu_m = predict(FIT_M, P_DAT), sd_m = psd(FIT_M, P_DAT$t_m))

        # Save moment table.
        RES_P[[paste0(d, t, s)]] <- MOM

      }
    }
  }

  # Collect results and attach last historic period.
  RES_P[["ref"]] <-
    filter(A_DAT, jahr == first(pint)) %>%
    # Respect that historic data is deterministic.
    mutate(sd_n = 0, sd_m = 0) %>%
    dplyr::rename(mu_m = m, mu_n = n) %>%
    select(jahr, dom, sex, type, mu_n, sd_n, mu_m, sd_m)

  SIG_DATA <-
    bind_rows(RES_P) %>%
    left_join(MOM_T, by = "jahr") %>%
    filter(jahr > first(pint)) %>%
    group_by(jahr)

  # Simulate response distributions across time. -----------------------------------------

  # Fix random seed for reproducibility.
  set.seed(seed + 1)

  #  Set number of simulation runs (fixed by experimentation).
  N <- 20000

  RES_S <- list()

  for (d in unique(SIG_DATA$dom)) {
    for (t in unique(SIG_DATA$type)) {
      for (s in unique(SIG_DATA$sex)) {
        for (y in unique(SIG_DATA$jahr)) {

          PARAM <-
            filter(SIG_DATA, jahr == y, sex == s, type == t, dom == d)

          if (d == "au") {

            # Counts and levels are both perturbed with Student-t errors.
            RES_S[[paste0(d, t, s, y)]] <-
              PARAM %>%
              tibble(
                sim_n = mu_n + rt(N, df_nau) * sd_n,
                sim_m = mu_m + rt(N, df_mau) * sd_m
              ) %>%
              select(jahr, dom, sex, type, sim_n, sim_m) %>%
              mutate(run = seq_len(n()))
          }

          if (d == "ch") {

            # Counts are deterministic; levels use the degrees of freedom of the "ch"
            # level regression (`df_mch`).
            RES_S[[paste0(d, t, s, y)]] <-
              PARAM %>%
              tibble(
                sim_n = mu_n,
                sim_m = mu_m + rt(N, df_mch) * sd_m
              ) %>%
              select(jahr, dom, sex, type, sim_n, sim_m) %>%
              mutate(run = seq_len(n()))
          }
        }
      }
    }
  }

  # Combine simulation runs.
  SIM <-
    bind_rows(RES_S) %>%
    arrange(jahr, type) %>%
    group_by(jahr)

  # Prepare data from last historic period: zero increments for every simulation run.
  H_DAT <- A_DAT %>%
    filter(jahr == first(pint)) %>%
    select(jahr, dom, sex, type) %>%
    mutate(sim_n = 0, sim_m = 0) %>%
    expand(run = unique(SIM$run),
           nesting(jahr, dom, sex, type, sim_n, sim_m)) %>%
    relocate(jahr, dom, sex, type)

  MP <- RENTENENTWICKLUNG %>%
    mutate(mp = 12 * minimalrente) %>%
    select(jahr, mp)

  # Switch for the 13th AHV pension payment.
  ahv13 <- TRUE

  # Quantile defining the reported band boundary.
  # NOTE(review): "hoch" maps to the 5% quantile and everything else to the 95%
  # quantile -- confirm that this direction is intended.
  prob <- if (PARAM_GLOBAL$szenario_fhh == "hoch") .05 else .95

  # Consolidate data, derive perturbed predictions across simulation runs, and integrate
  # 13th AHV pension payment.
  SIM <- SIM %>%
    # Attach historic data.
    bind_rows(H_DAT) %>%
    left_join(A_DAT %>%
                filter(jahr == first(pint)) %>%
                select(dom, sex, type, n, m),
              by = c("dom", "sex", "type")) %>%
    left_join(MP, by = "jahr") %>%
    group_by(run, dom, sex, type) %>%
    # A global sort by year puts the rows of every group in chronological order, which
    # is all the grouped cumsum/diff below rely on.
    arrange(jahr) %>%
    # Enforce legal bounds on projections.
    mutate(sim_n =         pmax(0, n + cumsum(sim_n)),
           sim_m = pmin(2, pmax(0, m + cumsum(sim_m)))
    ) %>%
    # 13th AHV pension payment.
    mutate(sim_m =
             ifelse(type == "alt" & jahr >= 2026, sim_m * (1 + ahv13 * 1/12), sim_m),
           d_nmmp = c(0, diff(sim_n * sim_m * mp))) %>%
    ungroup() %>%
    # Map variables onto explanatory variable from 'fit_t', separately for each simulation
    # run per scenario.
    dplyr::summarize(d_nmmp = sum(d_nmmp),
                     .by = c("run", "jahr")) %>%
    mutate(p_exp = filter(F_DAT, jahr == first(pint))$exp_tot) %>%
    left_join(MOM_T,
              by = "jahr", relationship = "many-to-one") %>%
    group_by(run) %>%
    arrange(jahr) %>%
    # Simulate final predictions incorporating uncertainty due to the rent sum top-up
    # estimation. The increment of the first year is masked out in the cumulative sum.
    mutate(d_nmmp = mu_t * d_nmmp + rt(n(), df_tot) *
             sd_t *
             sqrt(1 + d_nmmp^2 / ssq_t),
           p_exp  = p_exp + cumsum((jahr != first(pint)) * d_nmmp)) %>%
    group_by(jahr) %>%
    # Calculate boundaries of conditional confidence bands at the 90% confidence level.
    mutate(scen_1 = quantile(p_exp, prob)) %>%
    select(jahr, scen_1) %>%
    distinct()

  # Integrate Liechtenstein effect, AHV21 cost vector and complementary widow projections.
  # Neither of these ex post correction is included in the uncertainty quantification.
  D_WID <- WIDOW_BASISMODELL %>%
    bind_rows(A_DAT %>%
                select(jahr, sex, dom, type, m, n, mp) %>%
                filter(jahr == PARAM_GLOBAL$jahr_abr, type == "wit")) %>%
    group_by(sex, dom) %>%
    # The global sort by year also fixes the row order of the summary below, on which
    # the final `cumsum()` depends.
    arrange(jahr) %>%
    mutate(d_wid = c(0, diff(n * m * mp))) %>%
    ungroup() %>%
    dplyr::summarize(d_wid = sum(d_wid), .by = "jahr") %>%
    mutate(d_wid = cumsum(d_wid))

  SIM <- SIM %>%
    left_join(CORR %>%
                select(jahr, save_rel) %>%
                dplyr::summarize(save = sum(save_rel), .by = "jahr"),
              by = "jahr") %>%
    left_join(D_WID, by = "jahr") %>%
    mutate(scen_1 = scen_1 + d_wid - save) %>%
    select(- save, - d_wid)

  # Round the results. No unit conversion happens here.
  # NOTE(review): presumably the inputs are already expressed in millions -- confirm.
  TOT_AUSG <- SIM %>%
    mutate(aus_tot  = round(scen_1)) %>%
    select(jahr, aus_tot)

  TOT_AUSG
}
