#' Estimation des EPRC à partir de la population
#'
#' Ce module calcule les équivalents pleine rente cumulés (EPRC) à
#' partir des données d'une certaine population en les faisant évoluer selon
#' les taux de mortalité.
#'
#' @param PARAM_GLOBAL un dataframe d'une seule ligne, dont nous utilisons les
#'   paramètres suivants:
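#'   - `ra_f_2005`: Âge de la retraite pour les femmes dès 2005 (64)
#'   - `ra_m`: Âge de la retraite pour les hommes
#'   - `min_age_cot`: Âge à partir duquel on verse des cotisations (21 par
#'     défaut s'il est absent)
#'   - `jahr_rr`: première année retenue pour les indices de croissance
#'     (filtre `jahr >= jahr_rr`)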
#'
#' @param POPULATION_TOT data frame contenant les données de la population
#' calculé dans la fonction \code{\link{mod_population}}.
#'
#' @param MORTALITE data frame contenant les taux de mortalité, cf. fonction
#' \code{\link{mod_input_mortalite}}.
#'
#' @param list `tidylist`. Élément obligatoire dans tous les modules. Au lieu de
#'   fournir des dataframes au module, il est possible de l'alimenter uniquement
#'   avec une `tidylist` qui contient les tidy dataframes. Tous les dataframes
#'   listés doivent être présents dans la tidylist, avec le même nom. De plus,
#'   la `tidylist` peut aussi contenir des dataframes qui ne sont pas utilisés
#'   par le module.
#'
#' @references \href{https://www.bsv.admin.ch/dam/bsv/fr/dokumente/ahv/finanzperspektiven/validierung-modellansatz-ahv.pdf.download.pdf/2018_07_09_definitif_ld_rapport_ofas.pdf}{Rapport de Prof. Dr Laurent Donzé}
#'
#' @return une `tidylist` contenant le data frame suivant:
#' - `KV`
#'
#' @author [MAS BSV](mailto:sekretariat.mas@bsv.admin.ch)
#'
#' @export
#'

source("~/delfinverse/dinput/R/utils.R")

options(readr.show_col_types = FALSE)

# Directory holding the model input tables (one CSV per table).
# Files are read through explicit paths instead of calling setwd(), so running
# this script does not change the session's working directory as a side effect.
# path.expand() resolves the leading "~" before the file name is appended.
data_dir <- path.expand(
  "~/data/appl-wb/20_staff/kjo/fhh/2025-05-20T1646_u80874371_ahv_basis_kjo"
)

PARAM_GLOBAL <-
  read_delim(file.path(data_dir, "PARAM_GLOBAL.csv"))

BEVOELKERUNG <-
  read_delim(file.path(data_dir, "BEVOELKERUNG.csv"))

TAUX_MORTALITE <-
  read_delim(file.path(data_dir, "TAUX_MORTALITE.csv"))

RR_AVS <-
  read_delim(file.path(data_dir, "RR_AVS.csv"))


    # Total population split by domicile `dom`:
    #   - "ch": the `bevendejahr` column of BEVOELKERUNG, taken unchanged;
    #   - "au": row-wise sum of the auswanderung, frontaliers, saisonniers and
    #     assures_facultatifs columns (NA counted as 0), stored as bevendejahr.
    # The "ch" rows come first, followed by the "au" rows; both carry the
    # columns jahr, sex, nat, alt, dom, bevendejahr in that order.
    POPULATION_TOT <- bind_rows(
      BEVOELKERUNG %>%
        mutate(dom = "ch") %>%
        select(jahr, sex, nat, alt, dom, bevendejahr),
      BEVOELKERUNG %>%
        mutate(
          bevendejahr = rowSums(
            BEVOELKERUNG[c("auswanderung", "frontaliers",
                           "saisonniers", "assures_facultatifs")],
            na.rm = TRUE
          ),
          dom = "au"
        ) %>%
        select(jahr, sex, nat, alt, dom, bevendejahr)
    )

    # Minimum contribution age, read from PARAM_GLOBAL.
    # When the column is missing, fall back to 21 and raise a warning rather
    # than print(), so the omission is reported as a condition and is not lost
    # among ordinary console output.
    if (!"min_age_cot" %in% colnames(PARAM_GLOBAL)) {
      min_age_cot <- 21
      warning("Please add <<min_age_cot;21>> into PARAM_GLOBAL", call. = FALSE)
    } else {
      min_age_cot <- as.numeric(PARAM_GLOBAL$min_age_cot)
    }

    # min_age_cot is compared element-wise against ages further down. A
    # vector or NA here would silently yield wrong or zero factors, so fail
    # early instead.
    if (length(min_age_cot) != 1 || is.na(min_age_cot)) {
      stop("PARAM_GLOBAL$min_age_cot must be a single non-missing number",
           call. = FALSE)
    }
  
  # Bewertungsjahrfaktor (valuation-year factor).
  # Version with the same factor for everyone (men and women).
  #
  # Statutory retirement age `ret_age`:
  #   - 65 for men;
  #   - for women, 62 / 63 / 64 for birth cohorts <= 1938 / 1939-1941 / >= 1942.
  #
  # Which ages get a factor:
  #   - Every age with min_age_cot <= alt < ret_age gets
  #     1 / (ret_age - min_age_cot), so a complete career sums to exactly one
  #     full pension.
  #   - Cohorts born before 1927 get the larger factor
  #     1 / (ret_age - min_age_cot + coh - 1927). Presumably this reflects
  #     their shorter possible contribution span (AHV start in 1948); confirm.
  #   - All other ages get 0.
  #
  # Changes versus the previous version:
  #   - The branch for women born <= 1938 used `alt %in% 21:62`. That counted
  #     42 ages at 1/41 each, i.e. more than one full pension. It now stops
  #     at ret_age - 1 like every other branch.
  #   - The entry age is taken from `min_age_cot` everywhere, instead of a
  #     mix of `min_age_cot` and a hard-coded 21. Results are identical when
  #     min_age_cot is 21.
  #
  # NOTE(review): PARAM_GLOBAL documents ra_m / ra_f_2005, but the retirement
  # ages below are still hard-coded; consider reading them from PARAM_GLOBAL.
  #
  # .env$min_age_cot makes sure the script-level scalar is used even if a
  # joined table ever carries a column of the same name.
  POPU <- POPULATION_TOT %>%
    left_join(TAUX_MORTALITE, by = c("jahr", "alt", "sex", "nat")) %>%
    mutate(coh = jahr - alt) %>%
    mutate(
      ret_age_tmp = case_when(
        sex == "m"                      ~ 65,
        sex == "f" & coh <= 1938        ~ 62,
        sex == "f" & coh %in% 1939:1941 ~ 63,
        sex == "f" & coh >= 1942        ~ 64,
        TRUE                            ~ NA_real_
      ),
      n_cot_tmp = ret_age_tmp - .env$min_age_cot,
      bewertungsjahrfaktor = case_when(
        !is.na(ret_age_tmp) & .env$min_age_cot <= alt & alt < ret_age_tmp ~
          pmax(1 / n_cot_tmp, 1 / (n_cot_tmp + coh - 1927)),
        # all other cases (too young, retired, unknown sex/cohort)
        TRUE ~ 0
      )
    ) %>%
    mutate(
      epr = bewertungsjahrfaktor * bevendejahr,
      # eprc is only seeded in the first data year or at age min_age_cot.
      # Other cells are 0 here; they are presumably overwritten by the cohort
      # recursion in calculate_eprc_extrapolation().
      is_start_tmp = alt == .env$min_age_cot | jahr == min(jahr),
      eprc = epr *
        case_when(
          # Swiss nationals: epr times the number of contribution years from
          # min_age_cot up to min(alt, ret_age), inclusive.
          is_start_tmp & nat == "ch" & !is.na(ret_age_tmp) ~
            pmin(alt, ret_age_tmp) - .env$min_age_cot + 1,
          # Foreign nationals: only the current year's epr.
          is_start_tmp & nat != "ch" ~ 1,
          TRUE ~ 0
        )
    ) %>%
    select(-c(ret_age_tmp, n_cot_tmp, is_start_tmp)) %>%
    ungroup()

  # EPRC extrapolation: roll the cumulated full-pension equivalents (eprc)
  # forward along birth cohorts for one subset of POPU.
  #
  # With K = eprc, V = epr and Q = quotients_mortalite, each turned into a
  # year x age matrix by tibble_to_matrix() (from utils.R), every cell outside
  # the first row and first column is recomputed as
  #   K[i, j] = (1 - Q[i - 1, j - 1]) * K[i - 1, j - 1] + V[i, j]
  # That is last year's stock one age younger, reduced by mortality, plus this
  # year's new equivalents. The first row and first column keep their input
  # values. Row i depends only on row i - 1, so each row is updated in a single
  # vectorised step.
  #
  # NOTE(review): this assumes tibble_to_matrix() puts years on rows and ages
  # on columns, both ascending and consecutive; confirm in utils.R.
  #
  # DTA: data frame with columns jahr, alt, quotients_mortalite, epr, eprc.
  # Returns: whatever matrix_to_tibble() (utils.R) builds from the updated K,
  #   using dplyr::select(DTA, jahr, alt, eprc) as template.
  calculate_eprc_extrapolation <- function(DTA) {
    n_years <- length(unique(DTA$jahr))
    n_ages  <- length(unique(DTA$alt))

    mortality <- tibble_to_matrix(dplyr::select(DTA, jahr, alt, quotients_mortalite))
    new_epr   <- tibble_to_matrix(dplyr::select(DTA, jahr, alt, epr))
    stock     <- tibble_to_matrix(dplyr::select(DTA, jahr, alt, eprc))

    # Age columns 2..n_ages (empty when there is a single age).
    older <- seq_len(n_ages)[-1]

    for (yr in seq_len(n_years)[-1]) {
      stock[yr, older] <-
        (1 - mortality[yr - 1, older - 1]) * stock[yr - 1, older - 1] +
        new_epr[yr, older]
    }

    matrix_to_tibble(stock, dplyr::select(DTA, jahr, alt, eprc))
  }

  # Estimate the EPRC for every (sex, nat, dom) combination.
  # crossing() yields 2 x 2 x 2 = 8 combinations. The former French comment
  # said "4 groupes", which does not match this code.
  #
  # Steps:
  #   1. Grouping by all three keys gives one row per group. Inside do(), `.`
  #      is that one-row group, so `.$sex` etc. are the group's keys used to
  #      subset POPU.
  #   2. calculate_eprc_extrapolation() runs on each subset.
  #   3. The results are unnested and rows whose eprc is NA are dropped.
  #   4. The population column (bevendejahr) is joined back from
  #      POPULATION_TOT.
  #
  # NOTE(review): do() is superseded in dplyr (group_modify()/purrr::pmap()
  # are the modern forms). It is kept here because the later steps were
  # written against its exact output structure.
  KV <- crossing(
    sex = c("f", "m"), nat = c("ch", "au"), dom = c("ch", "au")) %>%
    group_by(sex, nat, dom) %>%
    do(dta0 = calculate_eprc_extrapolation(filter(
      POPU, sex == .$sex, nat == .$nat, dom == .$dom
    ))) %>%
    unnest(cols = c(dta0)) %>%
    filter(!is.na(eprc)) %>%
    left_join(POPULATION_TOT, by = c("jahr", "sex", "nat", "alt", "dom"))

  # return(KV = KV)
# }

  # Growth indices for pensionable ages (women >= 62, men >= 63), from
  # PARAM_GLOBAL$jahr_rr onwards.
  #
  # kv is the estimated eprc; pop is bevendejahr. For each (sex, nat, dom, alt)
  # series ordered by year:
  #   - g_kv is the cumulative product of year-over-year ratios. This
  #     telescopes to kv / first(kv), an index equal to 1 in the first retained
  #     year.
  #   - g_pop is built the same way from pop.
  #
  # The indices are then replicated for every civil status `zv` and every
  # retirement age 62-70. Only ages at or above the retirement age are kept,
  # and retirement ages below 63 are dropped for men.
  KV %<>%
    select(jahr, sex, nat, dom, alt, kv = eprc, pop = bevendejahr) %>%
    filter(
      !(sex == "f" & alt < 62),
      !(sex == "m" & alt < 63),
      jahr >= PARAM_GLOBAL$jahr_rr
    ) %>%
    group_by(sex, nat, dom, alt) %>%
    arrange(sex, nat, dom, alt, jahr) %>%
    mutate(
      # `default =` is spelled out; the previous `def =` only worked through
      # R's partial argument matching.
      g_kv  = cumprod(kv  / lag(kv,  default = first(kv))),
      g_pop = cumprod(pop / lag(pop, default = first(pop)))
    ) %>%
    select(jahr, sex, nat, dom, alt, g_kv, g_pop) %>%
    ungroup() %>%
    expand(nesting(jahr, sex, nat, dom, alt, g_kv, g_pop),
           zv = c("ledig", "verheiratet", "geschieden", "verwitwet"),
           age_ret = 62:70) %>%
    filter(alt >= age_ret, !(sex == "m" & age_ret < 63)) %>%
    mutate(gpr = "rvieillesse_simple") %>%
    select(jahr, sex, nat, dom, zv, gpr, age_ret, alt, g_kv, g_pop)
  
  # OT <- 
  #   filter(RR, !(gpr %in% c("rvieillesse_simple", "rcompl_femme"))) %>% 
  #   select(jahr, sex, nat, zv, gpr, alt, monatliche_rentensumme) %>% 
  #   dplyr::summarize(pen = sum(monatliche_rentensumme),
  #                    .by = c("jahr", "sex", "nat", "zv", "gpr")) %>%
  #   filter(jahr == PARAM_GLOBAL$jahr_rr) %>% 
  #   select(- jahr)
  # 
  # OT <-
  #   bind_rows(PARAM_GLOBAL$jahr_rr:2075 %>% map(\(x) mutate(OT, jahr = x))) %>%
  #   relocate(jahr)
  # 
  # BEV <-
  #   mutate(BEVOELKERUNG, 
  #          pop = bevendejahr + frontaliers + saisonniers + assures_facultatifs) %>% 
  #   select(jahr, sex, nat, alt, pop) %>%
  #   filter(jahr >= PARAM_GLOBAL$jahr_rr) %>% 
  #   mutate(pop_19 = ifelse(alt >= 19, pop, 0), pop_25 = ifelse(alt <= 25, pop, 0)) %>% 
  #   dplyr::summarize(pop_19 = sum(pop_19), pop_25 = sum(pop_25),
  #                    .by = c("jahr", "sex", "nat")) %>% 
  #   group_by(sex, nat) %>% 
  #   arrange(jahr) %>% 
  #   mutate(g_pop19 = cumprod(pop_19 / lag(pop_19, def = first(pop_19))),
  #          g_pop25 = cumprod(pop_25 / lag(pop_25, def = first(pop_25)))) %>% 
  #   select(jahr, sex, nat, g_pop19, g_pop25) %>% 
  #   ungroup() %>% 
  #   expand(nesting(jahr, sex, nat, g_pop19, g_pop25), 
  #          zv = c("ledig", "verheiratet", "geschieden", "verwitwet"), 
  #          gpr = c("rorphelin_pere_simple", "renfant_pere_simple",   
  #                  "rorphelin_mere_simple", "rorphelin_double"   , "renfant_mere_simple")) %>% 
  #   mutate(g_kv = ifelse(gpr == "rveuve", g_pop19, g_pop25), age_ret = NA) %>% 
  #   select(jahr, sex, nat, zv, gpr, age_ret, g_kv)
  # 
  # filter(BEV, sex == "f", nat == "au", zv == "ledig", gpr == "rorphelin_pere_simple")
  # filter(OT , sex == "f", nat == "au", zv == "ledig", gpr == "rorphelin_pere_simple")
  # 
  # OT %<>%
  #   left_join(BEV, by = c("jahr", "sex", "nat", "zv", "gpr")) %>% 
  #   mutate(pen = pen * g_kv / 1e6)
  # 
  # ggplot(filter(OT, gpr == "rveuve", zv == "verwitwet", sex == "f",
  #               jahr <= 2040), aes(x = jahr, y = pen, col = sex, group = sex)) +
  #   geom_shadowline() +
  #   geom_shadowpoint() +
  #   facet_grid(nat ~ .)
  
  # KV %<>%
  #   bind_rows(
  #     c("rorphelin_pere_simple", "renfant_pere_simple", "rorphelin_mere_simple",
  #       "rorphelin_double", "renfant_mere_simple", "rveuve", "rcompl_femme",
  #       "rvieillesse_simple") %>%
  #       map(\(x) mutate(KV, gpr = x))) %>%
  #   na.omit()

  # Final layout of KV, written to KV.csv (";"-separated). Steps:
  #   - Drop incomplete rows. na.omit() also removes NaN indices, e.g. from a
  #     series whose first kv is 0.
  #   - Keep only simple old-age pensions with a consistent retirement age.
  #   - De-duplicate, then sort by group and year.
  #
  # g_kv is exported exactly as computed upstream, i.e. cumprod of
  # year-over-year ratios (an index against the first retained year).
  # A further `cumprod(g_kv)` per (sex, nat, dom, zv, age_ret, alt) used to be
  # applied here. It compounded the already-cumulative index a second time,
  # producing a product of index levels, so it has been removed.
  #
  # NOTE(review): confirm that the consumer of KV.csv expects this level index
  # and not year-over-year factors.
  KV %<>%
    na.omit() %>%
    mutate(gpr = "rvieillesse_simple") %>%
    filter(!(sex == "m" & gpr == "rvieillesse_simple" & age_ret < 63),
           !(gpr == "rvieillesse_simple" & alt < age_ret)) %>%
    distinct() %>%
    arrange(sex, nat, dom, zv, age_ret, alt, jahr) %>%
    select(jahr, gpr, sex, nat, dom, zv, age_ret, alt, g_kv) %>%
    ungroup()

  write_delim(KV, "~/delfinverse/KV.csv", delim = ";")
  
# KV %<>%
#   select(jahr, sex, nat, dom, alt, kv = eprc, pop = bevendejahr) %>%
#   filter(!(sex == "f" & alt < 62), !(sex == "m" & alt < 63),
#          jahr >= 1997) %>%
#   group_by(sex, nat, dom, alt) %>%
#   arrange(sex, nat, dom, alt, jahr) %>%
#   mutate(type = "est", g_kv = kv / lag(kv), g_pop = pop / lag(pop)) %>%
#   replace_na(list(g_kv = 1, g_pop = 1)) %>%
#   select(type, jahr, sex, nat, dom, alt, kv, pop, g_kv, g_pop)
# 
# KV %<>%
#   bind_rows(
#     c("ledig", "verheiratet", "geschieden", "verwitwet") %>%
#       map(\(x) mutate(KV, zv = x)))
# 
# KV %<>%
#   bind_rows(
#     62:70 %>%
#       map(\(x) mutate(KV, age_ret = x)))
# 
# KV %<>%
#   na.omit() %>%
#   filter(!(sex == "m" & age_ret == 62)) %>%
#   mutate(gpr = "rvieillesse_simple") %>%
#   select(jahr, gpr, sex, nat, dom, zv, age_ret, alt, g_kv, g_pop)
# 
# # write_delim(KV, "~/delfinverse/KV.csv", delim = ";")
# 
# Reference series for simple old-age pensions, built from RR_AVS (presumably
# the pension register).
#
# Columns produced, per year, sex, nationality, domicile and age:
#   - kv: sum of eprc;
#   - pop: sum of bez_av.
#
# Rows kept: pensionable ages at or above age_ret, with positive eprc.
#
# Each series gets year-over-year ratios g_kv and g_pop; missing ratios,
# e.g. in the first year, are set to 1. type = "rr" separates these rows from
# the estimates when plotting.
#
# NOTE(review): this used to start with `RR %<>%`, but no object `RR` is
# created in this script, while RR_AVS is read at the top and otherwise unused.
# RR_AVS is assumed to be the intended source; confirm it has gpr, age_ret,
# eprc, bez_av and dom.
RR <- RR_AVS %>%
  filter(gpr == "rvieillesse_simple", alt >= age_ret,
         !(sex == "f" & alt < 62), !(sex == "m" & alt < 63), eprc > 0) %>%
  select(jahr, sex, nat, dom, alt, eprc, bez_av) %>%
  dplyr::summarize(kv = sum(eprc), pop = sum(bez_av),
                   .by = c("jahr", "sex", "nat", "dom", "alt")) %>%
  group_by(sex, nat, dom, alt) %>%
  arrange(sex, nat, dom, alt, jahr) %>%
  mutate(type = "rr", g_kv = kv / lag(kv), g_pop = pop / lag(pop)) %>%
  replace_na(list(g_kv = 1, g_pop = 1)) %>%
  ungroup()

# Calibration factors per sex, nationality, domicile and age: the ratio of the
# register value RR$kv to the estimated KV$kv in the year 2024.
#
# NOTE(review): as written to KV.csv earlier in this script, KV only keeps
# jahr, gpr, sex, nat, dom, zv, age_ret, alt, g_kv. It has no `kv` column, so
# `select(KV, ..., kv_est = kv)` cannot work against it. This section appears
# to target an older KV layout in which kv = eprc; verify before running.
adj <-
  filter(RR, jahr == 2024) %>%
  select(jahr, sex, nat, dom, alt, kv) %>%
  left_join(select(KV, jahr, sex, nat, dom, alt, kv_est = kv), by = c("jahr", "sex" , "nat", "dom", "alt")) %>%
  mutate(adj = kv / kv_est) %>%
  select(sex, nat, dom, alt, adj)

# Plot data: the estimates in KV, rescaled by the 2024 calibration factor
# `adj`, stacked on top of the register series RR. The helper columns adj and
# pop are then dropped.
#
# NOTE(review): this expects KV to carry `kv`, `pop` and a `type` column.
# The commented-out older layout set type = "est", and the plots colour by
# `type`. The KV exported earlier in this script has none of these columns.
viz <-
  left_join(KV, adj, by = c("sex", "nat", "dom", "alt"),
            relationship = "many-to-one") %>%
  mutate(kv = kv * adj) %>%
  bind_rows(RR) %>% 
  select(- adj, - pop)

# Women aged 64 from 2012 onwards, coloured by series type, with one panel per
# domicile (rows) x nationality (columns). The vertical line marks the
# 2023/2024 boundary.
viz %>%
  filter(jahr >= 2012, sex == "f", alt == 64) %>%
  ggplot(mapping = aes(x = jahr, y = kv, colour = type)) +
  geom_hline(yintercept = 1) +
  geom_vline(xintercept = 2023.5) +
  geom_shadowpoint() +
  facet_grid(dom ~ nat, labeller = label_both)

# Men of birth cohort 1938, followed across ages, coloured by series type,
# with one panel per nationality (rows) x domicile (columns).
viz2 <- viz %>%
  mutate(coh = jahr - alt) %>%
  filter(sex == "m", coh == 1938)

viz2 %>%
  ggplot(mapping = aes(x = alt, y = kv, colour = type)) +
  geom_shadowpoint() +
  facet_grid(nat ~ dom, labeller = label_both)

# ggplot(filter(KV, sex == "f", nat == "au", dom == "au", alt == 80, jahr <= 2040),
#        aes(x = jahr, y = g_kv)) +
#   geom_hline(yintercept = 1) +
#   geom_shadowpoint()
# 
# viz <-
#   select(KV, year = jahr, sex, nat, dom, age = alt, kv = eprc, pop = bevendejahr) %>%
#   filter(!(sex == "m" & age < 65), !(sex == "f" & age < 64)) %>%
#   # filter(age >= 65) %>%
#   dplyr::summarize(kv = sum(kv), .by = c("year", "sex", "nat", "dom", "age")) %>%
#   left_join(RR, by = c("year", "sex", "nat", "dom", "age"))
# 
# adj <-
#   filter(viz, year == 2023) %>%
#   mutate(adj = eprc / kv) %>%
#   select(sex, nat, dom, age, adj)
# 
# viz %<>%
#   left_join(adj, by = c("sex", "nat", "dom", "age"),
#             relationship = "many-to-one") %>%
#   mutate(kv = kv * adj) %>%
#   select(- adj)
# 
# viz2 <-
#   dplyr::summarize(viz, kv = sum(kv), eprc = sum(eprc),
#                    .by = c("year", "sex", "nat", "dom")) %>%
#   pivot_longer(kv:eprc)
# 
# viz3 <-
#   dplyr::summarize(viz, kv = sum(kv), eprc = sum(eprc),
#                    .by = c("year", "sex")) %>%
#   pivot_longer(kv:eprc)
# 
# viz %<>%
#   pivot_longer(kv:eprc)
# 
# ggplot(filter(viz, age == 65, year %in% 2015:2040), aes(x = as.factor(year), y = value, col = name)) +
#   geom_vline(xintercept = 2022.5) +
#   geom_shadowpoint() +
#   facet_grid(sex ~ nat + dom, labeller = label_both)
# 
# ggplot(filter(viz2, year %in% 2010:2040), aes(x = as.factor(year), y = value, col = name)) +
#   geom_vline(xintercept = 2022.5) +
#   geom_shadowpoint() +
#   facet_grid(sex ~ nat + dom, labeller = label_both)
# 
# ggplot(filter(viz3, year %in% 2010:2040), aes(x = as.factor(year), y = value, col = name)) +
#   geom_vline(xintercept = 2022.5) +
#   geom_shadowpoint() +
#   facet_grid(sex ~ ., labeller = label_both)



# viz <-
#   select(KV, year = jahr, sex, nat, dom, age = alt, kv = eprc, pop = bevendejahr) %>%
#   # filter(!(sex == "m" & age < 65), !(sex == "f" & age < 64)) %>%
#   filter(age >= 65) %>%
#   dplyr::summarize(kv = sum(kv), pop = sum(pop), .by = c("year")) %>%
#   pivot_longer(kv:pop) %>%
#   filter(year <= 2040)
#
# ggplot(viz, aes(x = year, y = value, col = name)) +
#   geom_vline(xintercept = 2023.5) +
#   geom_shadowpoint() +
#   scale_y_continuous(breaks = seq(1e5, 40e5, 1e5))

# viz <-
#   select(POPULATION_TOT, year = jahr, sex, nat, dom, age = alt, pop = bevendejahr) %>%
#   filter(age >= 65) %>%
#   dplyr::summarize(pop = sum(pop), .by = c("year", "sex", "nat", "dom"))
#
# ggplot(viz, aes(x = year, y = pop)) +
#   geom_vline(xintercept = 2023.5) +
#   geom_shadowpoint() +
#   facet_grid(sex ~ nat + dom, labeller = label_both)
#
# viz <-
#   select(POPU, year = jahr, sex, nat, dom, age = alt, pop = bevendejahr,
#          qm = quotients_mortalite, bw = bewertungsjahrfaktor, v = epr, kv = eprc) %>%
#   mutate(coh = year - age) %>%
#   filter(nat == "ch", dom == "ch", year <= 2040)
#
# ggplot(filter(viz, age == 55), aes(x = year, y = qm)) +
#   geom_shadowpoint() +
#   facet_grid(sex ~ nat + dom, labeller = label_both)
