#' Projektion der Durchschnittsrenten pro Vollrentenäquivalent bei der Erstverrentung
#' sowie der Folgeentwicklung nach Rentenbezug ab 65.
#'
#' @param PARAM_GLOBAL un dataframe d'une seule ligne, dont nous utilisons les
#'   paramètres suivants:
#'   - `ra_f_2005`: Age de la retraite pour les femmes dès 2005 (64)
#'   - `ra_m`: Age de la retraite pour les hommes
#'   - `min_age_cot`: Age à partir duquel on verse des cotisations
#'
#' @param POPULATION_TOT data frame contenant les données de la population
#' carzulé dans la fonction \code{\link{mod_population}}.
#'
#' @param MORTALITE data frame contenant les taux de mortalité, cf. fonction
#' \code{\link{mod_input_mortalite}}.
#'
#' @param list `tidylist`. Elément obligatoire dans tous les modules. Au lieu de
#'   fournir des dataframes au module, il est possible de l'alimenter uniquement
#'   avec une`tidylist` qui contient les tidy dataframes. Tous les datframes
#'   listés doivent être présents dans la tidylist, avec le même nom. De plus,
#'   la `tidylist` peut aussi contenir des dataframes qui ne sont pas utilisés
#'   par le module.
#'
#' @references \href{https://www.bsv.admin.ch/dam/bsv/fr/dokumente/ahv/finanzperspektiven/validierung-modellansatz-ahv.pdf.download.pdf/2018_07_09_definitif_ld_rapport_ofas.pdf}{Rapport de Prof. Dr Laurent Donzé}
#'
#' @return une `tidylist` contenant le data frame suivant:
#' - `CUMPROD_PARAM_ERSTRENTE`
#'
#' @author [MAS BSV](mailto:sekretariat.mas@bsv.admin.ch)
#'
#' @export
#'

setwd("~/data/appl-wb/20_staff/kjo/fhh/2025-05-20T1731_u80874371_ahv_basis_kjo")

PARAM_GLOBAL <- 
  read_delim("PARAM_GLOBAL.csv")

RENTENENTWICKLUNG <- 
  read_delim("RENTENENTWICKLUNG.csv")

RR_AVS <- 
  read_delim("RR_AVS.csv")

# mod_scenario_erstrenten <- function(PARAM_GLOBAL,
#                                     RENTENENTWICKLUNG,
#                                     RR_AVS) {
#   
  sapply(c("tidyverse", "magrittr", "tsibble", "fable", "simputation", "ggshadow",
           "imputeTS" , "spatstat.utils", "Amelia", "strucchange", "timeplyr"), 
         library, char = TRUE)
  
  # Obergrenze des Lebensyzklus in Jahren nach Lebensalter 65. Die Grenze wurde anhand
  # der Datenverfügbarkeit seit 1997 gewählt. Die Ergebnisse sind nicht sensitiv gegenüber
  # leichten Variationen.
  rz_b <- 65 + PARAM_GLOBAL$jahr_lj - 2005
  
  # Projektion der Rentenniveaus vor Lebensalter 65. -------------------------------------
  RR_0 <- RR_AVS %>% 
    mutate(coh = jahr - alt) %>%
    filter(gpr == "rvieillesse_simple", 
           # Entfernung inkohärenter Einzelfälle.
           eprc > 0, alt >= 62, age_ret > 61, !(sex == "m" & age_ret < 63)) %>% 
    dplyr::summarise(pen  = sum(monatliche_rentensumme), 
                     eprc = sum(eprc),
                     .by = c("jahr", "coh", "sex", "nat", "dom", "alt")) %>% 
    left_join(select(RENTENENTWICKLUNG, jahr, minimalrente), 
              by = "jahr", relationship = "many-to-one") %>%
    # Ausdruck der durchschnittlichen Renten pro Vollrentenäquivalent relativ zur
    # kontemporären Minimalrente.
    mutate(pen = pen / (minimalrente * eprc)) %>% 
    select(coh, jahr, sex, nat, dom, alt, pen) %>% 
    filter(alt %in% 62:65) %>% 
    # Expansion der Daten um zukünftige Kohorten-IDs und Kalenderjahre.
    mutate(jahr = factor(jahr, levels =  (min(.$coh) + 65):PARAM_GLOBAL$jahr_lastoutput),
           coh  = factor(coh , levels = c(min(.$coh):(PARAM_GLOBAL$jahr_lastoutput - 62)))) %>%
    complete(coh, jahr, sex, nat, dom, alt) %>%
    mutate(jahr = parse_number(as.character(jahr)),
           coh  = parse_number(as.character(coh))) %>%
    # Entfernung inkohärenter Kombinationen.
    filter(jahr - coh == alt,
           !(sex == "m" & alt == 62), !(sex == "m" & jahr < 2001 & alt == 63)) %>%
    select(coh, jahr, sex, nat, dom, alt, pen) %>% 
    group_by(sex, nat, dom, alt) %>% 
    arrange(sex, nat, dom, alt, jahr) %>% 
    # Eingrenzung der zur Extrapolation verwendeten Datenpunkte. Das Fenster wurde 
    # gewählt, um Kontaminationen durch die vorteilhaften Vorbezugssätze der 10. AHV 
    # Revision zu vermeiden, werzhe 2012 ausgelaufen sind.
    mutate(pen_c = ifelse(jahr %in% 2014:2023, pen, NA)) %>% 
    ungroup() %>% 
    filter(!(sex == "f" & coh > 1969 & alt == 62)) %>% 
    impute_lm(pen_c ~ jahr | sex + nat + dom + alt) %>% 
    group_by(sex, nat, dom, alt) %>% 
    arrange(jahr) %>% 
    # Zusammenführen der historischen und projizierten Rentenniveaus.
    mutate(pen = coalesce(pen, pen_c)) %>% 
    select(- pen_c) %>%  
    select(coh, jahr, sex, nat, dom, alt, pen) %>% 
    ungroup()
  
  ggplot(RR_0, aes(x = jahr, y = pen, col = alt)) +
    geom_shadowpoint() +
    facet_grid(sex ~ nat + dom, labeller = label_both)
  
  # Projektion des Rentenniveaus der konsolidierten Gruppe im Alter > 65 + rz_b. -----------
  RR_F <- RR_AVS %>% 
    mutate(coh = jahr - alt) %>%
    filter(gpr == "rvieillesse_simple", 
           eprc > 0, alt >= 62, age_ret > 61, !(sex == "m" & age_ret < 63)) %>%
    dplyr::summarise(pen  = sum(monatliche_rentensumme),
                     eprc = sum(eprc),
                     n    = sum(bez_av),
                     .by = c("coh", "jahr", "sex", "nat", "dom", "zv", "alt")) %>% 
    left_join(select(RENTENENTWICKLUNG, jahr, minimalrente), 
              by = "jahr", relationship = "many-to-one")
  
  # Imputation der kontrafaktischen Wachstumsraten der laufenden Renten in 2001, als die
  # die 10. AHV Revision die Anpassung aller laufenden Renten erforderte. Wichtig: 
  # aufgrund des Prinzips der Besitzstandswahrung waren nur Anpassungen nach oben möglich.
  # Die exakten Regeln zur Auwertung sind zu kompliziert für eine Implementierung.
  corr <- RR_F %>% 
    # Eingrenzung auf Renten, werzhe bereits vor 1997 bezogen wurden (ab 1997 wurden
    # Neurenten bereits im neuen Rentensystem behandelt). Ledige werden ausgeschlossen,
    # da deren Renten nicht von dieser Anpassung betroffen waren.
    filter(!(sex == "f" & coh >= 1935), !(sex == "m" & coh >= 1932),
           jahr %in% 1999:2002, alt >= 62, zv != "ledig") %>%
    group_by(coh, sex, nat, dom, zv) %>%
    # Restriktion auf Beobachtungen, die von 1999:2002 auch beobachtbar sind.
    filter(n() == 4) %>%
    mutate(pen = pen / (minimalrente * eprc)) %>%
    arrange(coh, sex, nat, dom, zv, jahr) %>%
    mutate(w_f = pen / lag(pen) - 1) %>%
    na.omit() %>%
    # Löschen des Wachstumssprungs in 2001 zur nachfolgenden Imputation.
    mutate(w_c = ifelse(jahr == 2001, NA, w_f)) %>%
    ungroup() %>%
    impute_lm(w_c ~ 1 | coh + sex + nat + dom + zv) %>%
    filter(jahr == 2001) %>%
    # Berechnung des kontrafaktischen Wachstumsfaktors.
    mutate(corr = (1 + w_c) / (1 + w_f)) %>%
    select(coh, sex, nat, dom, zv, corr)
  
  RR_F %<>%
    left_join(corr, by = c("coh", "sex", "nat", "dom", "zv")) %>%
    # Ersetzung nötig für Ledige sowie Kohorten, werzhe ab 1997 oder später erstmals
    # bezogen haben.
    tidyr::replace_na(list(corr = 1)) %>%
    # Anwendung der Korrektur auf alte laufende Renten ab 2001.
    mutate(
      pen_c = pen / (minimalrente * eprc),
      pen_c = ifelse(jahr >= 2001, pen_c * corr, pen_c),
      pen_c = pen_c * (eprc * minimalrente)) %>%
    select(- corr) %>%
    # Einschränkung auf Altersgruppen oberharz_b der Lebenszyklus-Obergrenze.
    filter(alt > rz_b) %>% 
    select(jahr, sex, nat, dom, eprc, pen, pen_c, minimalrente) %>% 
    dplyr::summarize(
      pen  = sum(pen), pen_c = sum(pen_c), eprc = sum(eprc),
      .by  = c("jahr", "sex", "nat", "dom", "minimalrente")) %>%  
    mutate(
      pen   = pen   / (minimalrente * eprc),
      pen_c = pen_c / (minimalrente * eprc),
      jahr = factor(jahr, levels = min(RR_AVS$jahr):PARAM_GLOBAL$jahr_lastoutput)) %>%
    complete(jahr, sex, nat, dom) %>%
    mutate(jahr = parse_number(as.character(jahr))) %>%
    select(jahr, sex, nat, dom, eprc, pen, pen_c) %>% 
    group_by(sex, nat, dom) %>% 
    arrange(sex, nat, dom, jahr) %>% 
    filter(jahr >= 2001) %>% 
    mutate(d_pen = pen_c - lag(pen_c)) %>%
    ungroup()
  
  # Projektion zukünftiger Rentenniveaus via NNETAR-Methode (basierend auf einlagigem
  # neuralem Netzwerk).
  NN_PRED <-
    model(filter(as_tsibble(RR_F, index = jahr, key = c(sex, nat, dom)),
                 !(is.na(d_pen))), net = NNETAR(d_pen)) %>%
    forecast(h = PARAM_GLOBAL$jahr_lastoutput - PARAM_GLOBAL$jahr_lj + 1,
             times = 0) %>%
    as_tibble(mod) %>%
    select(jahr, sex, nat, dom, pred = .mean)
  
  RR_F %<>%
    left_join(NN_PRED, by = c("jahr", "sex", "nat", "dom")) %>%
    mutate(d_pen = coalesce(d_pen, pred)) %>%
    tidyr::replace_na(list(d_pen = 0)) %>% 
    select(- pred) %>%
    group_by(sex, nat, dom) %>%
    arrange(sex, nat, dom, jahr) %>%
    mutate(d_pen = is.na(pen) * d_pen) %>%
    fill(pen) %>%
    mutate(pen = pen + cumsum(d_pen),
           pen_c = coalesce(pen_c, pen)) %>%
    group_by(sex, nat, dom) %>%
    arrange(sex, nat, dom, jahr) %>%
    filter(jahr <= PARAM_GLOBAL$jahr_lastoutput) %>%
    # Die Kohorten ID 9999 wird zur Kennzeichnung der konsolidierten Altersgruppe
    # verwendet.
    mutate(coh = 9999, alt = rz_b + 1) %>%
    select(coh, jahr, sex, nat, dom, alt, pen)
  
  # Imputation des Rentenzyklus vom Alter 65 bis 'rz_b' ------------------------------------
  RR_EV <- RR_AVS %>% 
    mutate(coh = jahr - alt) %>%
    filter(gpr == "rvieillesse_simple", 
           eprc > 0, alt >= 62, age_ret > 61) %>% 
    dplyr::summarise(pen  = sum(monatliche_rentensumme), 
                     eprc = sum(eprc),
                     .by = c("coh", "jahr", "sex", "alt")) %>% 
    left_join(select(RENTENENTWICKLUNG, jahr, minimalrente), 
              by = "jahr", relationship = "many-to-one") %>% 
    mutate(pen = pen / (minimalrente * eprc)) %>% 
    # Restriktion auf Kohorten, werzhe frühestens seit 1997 beziehen. Vorige Kohorten sind
    # zu kontaminiert für die Schätzung des Zyklus.
    filter(!(sex == "f" & coh < 1935), !(sex == "m" & coh < 1932), alt >= 65) %>% 
    select(coh, jahr, sex, alt, pen)
  
  # Manuelle Korrektur der männlichen Erstrentenniveaus, werzhe aufgrund der weiblichen
  # Referenzalterverschiebung Sprünge aufweisen. Diese Sprünge sollten nicht extrapoliert
  # werden.
  for (x in 65:68) {
    
    M_DAT <-
      filter(RR_EV, sex == "m", alt == x) %>%
      ungroup() %>%
      mutate(ind = case_when(
        # Die Zeiträume umfassen die Referenzaltererhöhungen sowie den Auslauf der 
        # vorteilhaften Vorbezugssätze in 2012.
        jahr %in% 1997:2000 ~ "a",
        jahr %in% 2001:2004 ~ "b",
        jahr %in% 2005:2011 ~ "c",
        jahr %in% 2012:2024 ~ "d"))
    
    # Der Kontrafakt besteht darin, unterschiedliche Achsenabschnitte für die gewählten
    # Zeiträume zu schätzen, und nachfolgend ausschliesslich den letzten geschätzen
    # Achsenabschnitt zur Rückwärtsextrapolation zu verwenden. Das Subskript "_c" wird
    # verwendet, um 'counterfactual' abzukürzen.
    M_DAT %<>%
      mutate(pen_c = lm(pen ~ 0 + jahr + ind, M_DAT) %>% predict(mutate(M_DAT, ind = "d")),
             pen_c = ifelse(jahr <= 2011, pen_c, pen)) %>%
      arrange(jahr) %>%
      select(jahr, sex, alt, pen, pen_c)
    
    RR_EV %<>%
      left_join(select(M_DAT, - pen), by = c("jahr", "sex", "alt")) %>%
      mutate(pen = ifelse(alt == x & sex == "m", pen_c, pen)) %>%
      select(- pen_c)
  }
  
  RR_EV %<>%
    arrange(sex, alt, jahr) %>%
    mutate(pen_ref = ifelse(alt == 65, pen, NA)) %>%
    group_by(sex, coh) %>% 
    fill(pen_ref, .direction = "downup") %>%
    mutate(rz = pen / pen_ref) %>% 
    filter(alt <= rz_b) %>% 
    select(- pen, - pen_ref)
  
  RR_EV %<>% 
    ungroup() %>%
    # Jahres-/Kohortengrenzen werden über den Projektionshorizont gesetzt, um Randpunkt-
    # probleme in der späteren LOESS-Glättung zu vermeiden.
    mutate(jahr = factor(jahr, levels = 1997:(PARAM_GLOBAL$jahr_lastoutput + 10)),
           coh  = factor(coh , levels = c(min(.$coh):((PARAM_GLOBAL$jahr_lastoutput + 20) - 65)))) %>%
    complete(coh, jahr, sex, alt) %>%
    mutate(jahr = parse_number(as.character(jahr)),
           coh  = parse_number(as.character(coh))) %>%
    # Ausschluss inkohärenter Fälle sowie einzelner Männer-Kohorte wegen abnormalem Verlauf.
    filter(jahr - coh == alt, !(sex == "m" & coh == 1938),
           !(jahr <= 2024 & is.na(rz))) %>%
    group_by(sex, alt) %>%
    arrange(sex, alt, jahr) %>%
    # Bruchpunktanalyse zur Selektion der Extrapolationspunke pro Zykluszeitreihe 
    # konditional auf die vergangenen Lebensjahre seit 65. Die minimale Anzahl von Punkten
    # pro Segment 'h' wurde auf 3 gesetzt, um Cauchy-verteilte Prognosen zu vermeiden 
    # (gegeben normalverteilter Störterme). Die Restriktion auf Alter bis 65 + 16 ist nötig,
    # um die minimalen Datenanforderungen der Methode zu erfüllen.
    mutate(w_rz = rz / lag(rz) - 1) %>%
    slice(-1) %>% 
    mutate(bp = ifelse(alt <= 65 + 13,
                       max(breakpoints(w_rz ~ jahr, h = 3)$breakpoints, 0, na.rm = TRUE), 
                       0),
           rz_c = ifelse(1:n() > bp, rz, NA)) %>%
    mutate(w_rz = rz_c / lag(rz_c) - 1) %>%
    # mutate(w_rz = rz / lag(rz) - 1) %>%
    # Restriktion auf Altersgruppen unterhalb der Zyklus-Altersobergrenze.
    filter(alt <= rz_b) %>%
    # Definition von 't' notwendig aufgrund der händisch entfernten Männer-Kohorte von 1938.
    mutate(t = c(1, diff(jahr - 1997)), fl = is.na(rz)) %>%
    ungroup() %>% 
    impute_lm(w_rz ~ 0 + t | sex + alt) %>%
    group_by(sex, alt) %>%
    arrange(sex, jahr) %>%
    mutate(w_rz = is.na(rz) * w_rz) %>%
    fill(rz) %>%
    mutate(rz = rz * cumprod(1 + w_rz)) %>%
    select(coh, jahr, sex, alt, rz, rz_c) %>% 
    ungroup()
  
  # Glättung der Zyklus-Prognosen anhand der LOESS-Methode. 
  RR_L <- list()
  
  for (c in unique(RR_EV$coh)) {
    for (s in unique(RR_EV$sex)) {
      
      suppressWarnings({
        try({
          
          temp <-
            filter(RR_EV, coh == c, sex == s) %>%
            mutate(fl = ifelse(alt == 65, FALSE, TRUE)) %>%
            arrange(alt)
          
          temp$rz[temp$fl] <-
            predict(loess(rz ~ alt, temp))[temp$fl]
          
          RR_L[[paste0(c, s)]] <- temp
          
        }, silent = TRUE)
      })
    }
  }
  
  RR_EV <-
    bind_rows(RR_L) %>% 
    filter(jahr <= PARAM_GLOBAL$jahr_lastoutput)
  
  # Erweiterung der Prognosen um Nationalität und Domizil zur Kompatibilität mit Skript
  # 'mod_ahv_rentensumme_go'.
  RR_EV %<>%
    expand(nat = c("ch", "au"), dom = c("ch", "au"),
           nesting(coh, jahr, sex, alt, rz)) %>% 
    select(coh, jahr, sex, nat, dom, alt, rz) %>% 
    filter(alt > 65) %>% 
    left_join(filter(RR_0, alt == 65) %>% select(coh, sex, nat, dom, pen), 
              by = c("coh", "sex", "nat", "dom"), relationship = "many-to-one") %>% 
    mutate(pen = pen * rz) %>% 
    select(- rz)
  
  # Zusammenführung der Prognosen zu den Rentenniveaus im Alter 62-65 (RR_0), der 
  # Durchschnittsrenten im Alter > 65 + rz_b (RR_F) sowie des Rentenzyklus (RR_EV) ---------
  RR_C <- 
    bind_rows(RR_0, RR_EV, RR_F) %>% 
    filter(jahr >= PARAM_GLOBAL$jahr_rr) %>% 
    group_by(sex, nat, dom, alt) %>% 
    arrange(sex, nat, dom, alt, jahr) %>% 
    mutate(w_pen = cumprod(pen / lag(pen, def = first(pen)))) %>%
    ungroup() %>% 
    expand(age_ret = 62:70, nesting(coh, jahr, sex, nat, dom, alt, w_pen)) %>% 
    filter(alt >= age_ret, !(sex == "m" & age_ret < 63)) %>% 
    select(jahr, sex, nat, dom, age_ret, alt, cumprod_param_erstrente = w_pen)
  
  # Komplettierung der Altersreichweite auf 99+.
  att <- 
    bind_rows((rz_b + 1):99 %>% map(\(x) mutate(
      filter(RR_C, alt == rz_b + 1) %>% select( - alt), alt = x))) %>% 
    select(jahr, sex, nat, dom, age_ret, alt, cumprod_param_erstrente)
  
  RR_C <- 
    bind_rows(RR_C, att) %>% 
    arrange(age_ret, alt, jahr) %>% 
    distinct()
  
  return(CUMPROD_PARAM_ERSTRENTE = RR_C)
}
