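#' @title Project total AHV expenditure (base model)
#'
#' @description Documents `mod_ahv_ausgaben_basismodell()`, defined further
#'   below in this file. It imputes future pension counts and mean pension
#'   levels by domicile, sex and pension type. It then corrects old-age
#'   pensions for the 13th AHV pension payment and projects total AHV
#'   expenditure from a regression on historical expenditure.
#'
#' @param PARAM_GLOBAL Global parameters (current year `jahr_abr`, last
#'   projection year `jahr_ende_basismodell`, scenario `szenario_fhh`).
#' @param RANGE Start years of the estimation windows.
#' @param A_DAT Pension counts and mean pension levels by scenario.
#' @param HIST_MORT_RENT Historical within-year deaths of retirees.
#' @param RENTENENTWICKLUNG Development of the minimal pension.
#' @param AHV_ABRECHNUNG_DEF Observed total AHV expenditure.
#' @param WIDOW_BASISMODELL Exogenous projections for widow(er)s' pensions.
#'
#' @return A tibble with columns `jahr` and `aus_tot` (total expenditure per
#'   year). For scenarios "tief"/"hoch", the result of `mod_ahv_szen()`.
#'
#' @author [MAS BSV](mailto:sekretariat.mas@bsv.admin.ch)
#'
#' @export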
# 
# Interactive driver: load the inputs of mod_ahv_ausgaben_basismodell() from
# a fixed run directory.
# NOTE(review): `A_DAT`, a required argument of the module, is not loaded
# here. The packages providing read_delim(), impute_lm(), na_locf() and the
# dplyr verbs are not attached in this file either. Confirm that the calling
# session provides them.
options(readr.show_col_types = FALSE)

# Build explicit paths instead of calling setwd(), so the session's working
# directory is left untouched. path.expand() resolves "~" the same way
# setwd() did.
data_dir <- path.expand(
  "~/data/appl-wb/20_staff/kjo/fhh/2025-07-22T0923_u80874371_ahv_basis"
)

PARAM_GLOBAL <-
  read_delim(file.path(data_dir, "PARAM_GLOBAL.csv"))

# RANGE holds the start years of the estimation windows (column `value`).
RANGE <-
  read_delim(file.path(data_dir, "RANGE.csv"), delim = ";")$value

HIST_MORT_RENT <-
  read_delim(file.path(data_dir, "HIST_MORT_RENT.csv"))

RENTENENTWICKLUNG <-
  read_delim(file.path(data_dir, "RENTENENTWICKLUNG.csv"))

AHV_ABRECHNUNG_DEF <-
  read_delim(file.path(data_dir, "AHV_ABRECHNUNG_DEF.csv"))

WIDOW_BASISMODELL <-
  read_delim(file.path(data_dir, "WIDOW_BASISMODELL.csv"))

# Project total AHV expenditure for the base model.
#
# Steps:
#   1. Extrapolate pension counts (`n`) and mean pension levels (`m`) from
#      A_DAT (history "H" plus BFS reference scenario "A") by sex and
#      pension type. Abroad ("au"): counts and means. Switzerland ("ch"):
#      mean levels only.
#   2. From historical within-year deaths of retirees (HIST_MORT_RENT),
#      estimate the savings on the 13th AHV pension ("Liechtenstein model").
#      Fold them into old-age ("alt") mean levels from `jahr_ahv13` on.
#   3. Replace widow(er)s' pensions ("wit") after the current year with the
#      rows of WIDOW_BASISMODELL.
#   4. Regress yearly changes of observed total expenditure on yearly
#      changes of the pension volume n * m * mp. Project expenditure from
#      the current year onwards with that coefficient.
#   5. For scenarios "tief"/"hoch", return the result of mod_ahv_szen()
#      instead.
#
# Args:
#   PARAM_GLOBAL: reads `jahr_abr` (current year, start of the projection),
#     `jahr_ende_basismodell` (last projection year) and `szenario_fhh`.
#     Assumed to hold a single row (used with `:` and in `if`).
#   RANGE: start years of the estimation windows:
#     [1] trend of counts abroad,
#     [2] trend of mean levels abroad,
#     [3] trend of mean levels in Switzerland,
#     [4] start of the expenditure regression window.
#   A_DAT: pension data with at least the columns scen, dom, sex, type,
#     jahr, n, m and mp.
#   HIST_MORT_RENT: historical deaths of retirees with columns an_rr,
#     csex (1/2), recoded_cdom (100/900), anzahl, rentensumme_dez.
#   RENTENENTWICKLUNG: columns jahr, minimalrente.
#   AHV_ABRECHNUNG_DEF: columns jahr, aus_tot (observed total expenditure).
#   WIDOW_BASISMODELL: rows appended to the pension data for years after the
#     current year. Expected to provide jahr, n, m and mp like A_DAT.
#
# Returns:
#   A tibble with columns `jahr` and `aus_tot` over the projection interval,
#   or the result of mod_ahv_szen() for scenarios "tief"/"hoch".
mod_ahv_ausgaben_basismodell <- function(PARAM_GLOBAL,
                                         RANGE,
                                         A_DAT,
                                         HIST_MORT_RENT,
                                         RENTENENTWICKLUNG,
                                         AHV_ABRECHNUNG_DEF,
                                         WIDOW_BASISMODELL) {

  print("Run module: mod_ahv_ausgaben_basismodell")

  # 13th AHV pension: switch and first year of payment. These are local
  # constants rather than arguments because mod_ahv_szen() below does not
  # receive them, and both paths must stay consistent.
  ahv13 <- TRUE
  jahr_ahv13 <- 2026

  # Yearly minimal pension (12 times `minimalrente`) per year.
  MP <- RENTENENTWICKLUNG %>%
    mutate(mp = 12 * minimalrente) %>%
    select(jahr, mp)

  # Projection interval *including* the current year.
  pint <- PARAM_GLOBAL$jahr_abr:PARAM_GLOBAL$jahr_ende_basismodell

  # Collects the extrapolated pension data per domicile.
  RL <- list()

  # Pensions paid abroad: extrapolate counts and mean levels.
  RL[["au"]] <- A_DAT %>%
    filter(scen %in% c("H", "A"), dom == "au") %>%
    # Window indicators: NA outside the estimation window, so those rows
    # neither enter the fit nor get imputed by impute_lm().
    mutate(t_n = ifelse(jahr >= RANGE[1], 1, NA),
           t_m = ifelse(jahr >= RANGE[2], 1, NA)) %>%
    group_by(sex, type) %>%
    arrange(sex, type, jahr) %>%
    mutate(d_n = c(NA, diff(n)), d_m = c(NA, diff(m))) %>%
    ungroup() %>%
    # Fill missing yearly changes with the mean change in the window
    # (a linear trend in levels), separately per sex and pension type.
    impute_lm(d_n ~ 0 + t_n | sex + type) %>%
    impute_lm(d_m ~ 0 + t_m | sex + type) %>%
    group_by(sex, type) %>%
    # Only changes after the current year are accumulated.
    mutate(d_n = ifelse(jahr > first(pint), d_n, 0),
           d_m = ifelse(jahr > first(pint), d_m, 0)) %>%
    # Clamp at 0 to prevent possible negative values from the linear
    # extrapolation. Mean levels are also capped at 2 (presumably the
    # maximum pension in multiples of the minimal pension; confirm).
    mutate(m = pmin(2, pmax(0, na_locf(m) + cumsum(d_m))),
           n = pmax(0, na_locf(n) + cumsum(d_n))) %>%
    ungroup() %>%
    select(-t_n, -d_n, -t_m, -d_m)

  # Pensions paid in Switzerland: extrapolate mean levels only.
  RL[["ch"]] <- A_DAT %>%
    filter(scen %in% c("H", "A"), dom == "ch") %>%
    mutate(t_m = ifelse(jahr >= RANGE[3], 1, NA)) %>%
    group_by(sex, type) %>%
    arrange(sex, type, jahr) %>%
    mutate(d_m = c(NA, diff(m))) %>%
    ungroup() %>%
    impute_lm(d_m ~ 0 + t_m | sex + type) %>%
    group_by(sex, type) %>%
    mutate(d_m = ifelse(jahr > first(pint), d_m, 0)) %>%
    mutate(m = pmin(2, pmax(0, na_locf(m) + cumsum(d_m)))) %>%
    ungroup() %>%
    select(-t_m, -d_m)

  # Historical within-year deaths of retirees and their average pensions
  # (source: RR), shifted to the following year.
  HIST <- HIST_MORT_RENT %>%
    select(jahr = an_rr, sex = csex, dom = recoded_cdom, n = anzahl,
           m_dec = rentensumme_dez) %>%
    mutate(jahr = jahr + 1,
           sex  = recode(sex, `1` = "m", `2` = "f"),
           dom  = recode(dom, `100` = "ch", `900` = "au")) %>%
    # The factor 11/12 corrects for deaths counted in the preceding
    # December (implicit assumption: deaths are uniform over months and
    # pension levels).
    dplyr::summarize(m = 11 / 12 * sum(m_dec) / sum(n), n = sum(n),
                     .by = c("jahr", "sex", "dom"))

  # Complete grid of years x (sex, domicile) up to the last projection year.
  corr_years <- min(HIST$jahr):last(pint)
  GRID <- tibble(
    jahr = rep(corr_years, each = 4),
    sex  = rep(c("m",  "m",  "f",  "f"),  times = length(corr_years)),
    dom  = rep(c("ch", "au", "ch", "au"), times = length(corr_years))
  )

  # Savings on the 13th pension due to within-year deaths of retirees.
  CORR <- HIST %>%
    right_join(GRID, by = c("jahr", "sex", "dom")) %>%
    left_join(MP, by = "jahr") %>%
    # Express mean pensions as multiples of the minimal pension.
    mutate(m = m / mp) %>%
    group_by(sex, dom) %>%
    arrange(sex, dom, jahr) %>%
    # NOTE(review): the first change per group is 0 here (NA in RL above),
    # so it enters the mean change estimated by impute_lm() below.
    mutate(d_n = c(0, diff(n)), d_m = c(0, diff(m))) %>%
    ungroup() %>%
    # Fill missing changes with their mean, separately by sex and domicile.
    impute_lm(d_n + d_m ~ 1 | sex + dom) %>%
    group_by(sex, dom) %>%
    # Accumulate only the changes after the current year.
    mutate(d_n = cumsum((jahr > first(pint)) * d_n),
           d_m = cumsum((jahr > first(pint)) * d_m)) %>%
    mutate(m = pmin(2, pmax(0, na_locf(m) + d_m)),
           n = pmax(0, na_locf(n) + d_n)) %>%
    mutate(save = m * n * mp) %>%
    ungroup() %>%
    dplyr::summarize(save_tot = sum(save), .by = c("jahr", "sex", "dom")) %>%
    filter(jahr >= first(pint)) %>%
    # Effective savings relative to a yearly payout. The factor
    # 'mean(1:11) / 12' approximates the ratio of savings under monthly
    # versus yearly payout of the 13th pension (Liechtenstein model).
    mutate(save_rel = ifelse(jahr >= jahr_ahv13,
                             save_tot * (1 - mean(1:11) / 12), 0),
           type = "alt") %>%
    select(jahr, sex, dom, type, save_rel)

  # Observed total expenditure up to and including the current year.
  ZAS <- AHV_ABRECHNUNG_DEF %>%
    select(jahr, exp_tot = aus_tot) %>%
    filter(jahr <= first(pint))

  P_DAT <- bind_rows(RL) %>%
    left_join(CORR, by = c("jahr", "sex", "dom", "type")) %>%
    tidyr::replace_na(list(save_rel = 0)) %>%
    # 13th pension for old-age pensions: raise mean levels by 1/12. Then
    # subtract the within-year-death savings (Liechtenstein effect) as an
    # equivalent decrease of the surviving retirees' average level.
    mutate(m = ifelse(type == "alt" & jahr >= jahr_ahv13,
                      m * (1 + ahv13 * 1 / 12) -
                        ahv13 * save_rel / (n * mp),
                      m)) %>%
    # Replace widow(er)s' pensions after the current year with the exogenous
    # WIDOW_BASISMODELL rows. These are based on consolidated STATPOP and
    # pension-registry data (author: Thomas Borek @Math-BSV).
    # NOTE(review): the original note speaks of male widows, but this filter
    # drops "wit" rows of both sexes. Confirm WIDOW_BASISMODELL covers both.
    filter(!(type == "wit" & jahr > first(pint))) %>%
    bind_rows(filter(WIDOW_BASISMODELL,
                     jahr > first(pint), jahr <= last(pint))) %>%
    left_join(ZAS, by = "jahr") %>%
    ungroup()

  # Aggregate to one row per year (mp and exp_tot are expected to be
  # constant within a year). `.by` returns groups in order of first
  # appearance, so sort explicitly: diff()/cumsum()/first() below require
  # ascending years.
  P_AGG <- P_DAT %>%
    dplyr::summarize(m = weighted.mean(m, n), n = sum(n),
                     .by = c("jahr", "mp", "exp_tot")) %>%
    arrange(jahr)

  # Regression data: yearly changes of the pension volume n * m * mp.
  F_DAT <- P_AGG %>%
    filter(jahr %in% RANGE[4]:first(pint)) %>%
    mutate(d_nmmp = c(NA, diff(n * m * mp)))

  # No-intercept fit of expenditure changes on pension-volume changes.
  FIT <-
    lm(c(NA, diff(exp_tot)) ~ 0 + d_nmmp, F_DAT)

  # Project expenditure from the current-year value using the cumulative
  # change in pension volume since the current year.
  TOT_AUSG <- P_AGG %>%
    filter(jahr %in% pint) %>%
    mutate(d_nmmp = c(0, cumsum(diff(n * m * mp))),
           exp_tot = ifelse(jahr > first(pint),
                            first(exp_tot) + d_nmmp * coef(FIT),
                            exp_tot)) %>%
    select(jahr, aus_tot = exp_tot)

  if (PARAM_GLOBAL$szenario_fhh %in% c("tief", "hoch")) {
    TOT_AUSG <- mod_ahv_szen(PARAM_GLOBAL,
                             RANGE,
                             A_DAT,
                             AHV_ABRECHNUNG_DEF,
                             F_DAT,
                             CORR,
                             RENTENENTWICKLUNG,
                             WIDOW_BASISMODELL)
  }

  TOT_AUSG
}
