#' Estimation des EPRC à partir de la population
#'
#' Ce module calcule les équivalents pleine rente cumulés (EPRC) à
#' partir des données d'une certaine population en les faisant évoluer selon
#' les taux de mortalité.
#'
#' @param PARAM_GLOBAL un dataframe d'une seule ligne, dont nous utilisons les
#'   paramètres suivants:
#'   - `ra_f_2005`: âge de la retraite pour les femmes dès 2005 (64)
#'   - `ra_m`: âge de la retraite pour les hommes
#'   - `min_age_cot`: âge à partir duquel on verse des cotisations (21 par
#'     défaut si la colonne est absente)
#'   - `jahr_rr`: année de base des facteurs de croissance `g_kv` exportés
#'
#' @param POPULATION_TOT data frame contenant les données de la population,
#' calculées dans la fonction \code{\link{mod_population}}.
#'
#' @param TAUX_MORTALITE data frame contenant les taux de mortalité (colonne
#' `quotients_mortalite`), cf. fonction \code{\link{mod_input_mortalite}}.
#'
#' @param list `tidylist`. Élément obligatoire dans tous les modules. Au lieu de
#'   fournir des dataframes au module, il est possible de l'alimenter uniquement
#'   avec une `tidylist` qui contient les tidy dataframes. Tous les dataframes
#'   listés doivent être présents dans la tidylist, avec le même nom. De plus,
#'   la `tidylist` peut aussi contenir des dataframes qui ne sont pas utilisés
#'   par le module.
#'
#' @references \href{https://www.bsv.admin.ch/dam/bsv/fr/dokumente/ahv/finanzperspektiven/validierung-modellansatz-ahv.pdf.download.pdf/2018_07_09_definitif_ld_rapport_ofas.pdf}{Rapport de Prof. Dr Laurent Donzé}
#'
#' @return une `tidylist` contenant le data frame suivant:
#' - `KV`
#'
#' @author [MAS BSV](mailto:sekretariat.mas@bsv.admin.ch)
#'
#' @export
#'

# Setup ----
# Project helpers; presumably also attaches the packages used below
# (readr, dplyr, tidyr, magrittr, ggplot2, ...) and defines
# tibble_to_matrix() / matrix_to_tibble() — TODO confirm.
source("~/delfinverse/dinput/R/utils.R")

options(readr.show_col_types = FALSE)

# Input directory. Files are addressed through file.path() instead of calling
# setwd(), so running this script does not change the session's working
# directory.
dir_input <- "~/data/appl-wb/20_staff/kjo/fhh/2025-05-20T1646_u80874371_ahv_basis_kjo"

PARAM_GLOBAL <- read_delim(file.path(dir_input, "PARAM_GLOBAL.csv"))

BEVOELKERUNG <- read_delim(file.path(dir_input, "BEVOELKERUNG.csv"))

TAUX_MORTALITE <- read_delim(file.path(dir_input, "TAUX_MORTALITE.csv"))

RR <- read_delim(file.path(dir_input, "RR_AVS.csv"))
    
    # Population by domicile, stacked into one long table:
    # - dom == "ch": `bevendejahr` taken as is from BEVOELKERUNG;
    # - dom == "au": `bevendejahr` replaced by the row-wise sum (NAs ignored)
    #   of `auswanderung`, `frontaliers` and `assures_facultatifs`.
    POPULATION_TOT <- bind_rows(
      BEVOELKERUNG %>%
        mutate(dom = "ch") %>%
        select(jahr, sex, nat, alt, dom, bevendejahr),
      BEVOELKERUNG %>%
        mutate(
          bevendejahr = rowSums(
            select(BEVOELKERUNG, auswanderung, frontaliers, assures_facultatifs),
            na.rm = TRUE
          ),
          dom = "au"
        ) %>%
        select(jahr, sex, nat, alt, dom, bevendejahr)
    )

    # Minimum contribution age: read from PARAM_GLOBAL when that column exists;
    # otherwise fall back to 21 and print a reminder to add the parameter.
    min_age_cot <- if ("min_age_cot" %in% colnames(PARAM_GLOBAL)) {
      as.numeric(PARAM_GLOBAL$min_age_cot)
    } else {
      print("Please add <<min_age_cot;21>> into PARAM_GLOBAL")
      21
    }
  
  # Bewertungsjahrfaktor (valuation-year factor), EPR and initial EPRC ----
  # Same factor formula for everyone (men and women); only the statutory
  # retirement age `.ret_age` differs by sex and birth cohort `coh`:
  #   men 65; women 62 (born <= 1938), 63 (born 1939-1941), 64 (born >= 1942).
  #
  # bewertungsjahrfaktor, for contribution ages min_age_cot <= alt < .ret_age:
  #   pmax(1 / (.ret_age - 21), 1 / (.ret_age - 21 + coh - 1927)), else 0.
  # All groups use the same half-open age window. Women born <= 1938 used to
  # be counted up to and including age 62 (42 years at 1/41 each), unlike
  # every other group.
  # NOTE(review): the constants 21 and 1927 in the formula are hard-coded and
  # presumably assume min_age_cot == 21 — confirm if min_age_cot ever changes.
  #
  # epr  = bewertungsjahrfaktor * bevendejahr.
  # eprc = initial cumulated EPR, only non-zero on the first year of data or
  #        at age 21: Swiss nationals get epr times the number of contribution
  #        years so far (capped at .ret_age), foreigners get epr * 5. Later
  #        years/ages are recomputed by calculate_eprc_extrapolation().
  POPU <- POPULATION_TOT %>%
    left_join(TAUX_MORTALITE, by = c("jahr", "alt", "sex", "nat")) %>%
    mutate(coh = jahr - alt) %>%
    mutate(
      # Temporary helper column; NA for unknown sex codes or cohorts.
      .ret_age = case_when(
        sex == "m"                      ~ 65,
        sex == "f" & coh <= 1938        ~ 62,
        sex == "f" & coh %in% 1939:1941 ~ 63,
        sex == "f" & coh >= 1942        ~ 64
      ),
      bewertungsjahrfaktor = case_when(
        min_age_cot <= alt & alt < .ret_age ~
          pmax(
            1 / (.ret_age - 21),
            1 / (.ret_age - 21 + coh - 1927)),
        # all other cases (not yet contributing, retired, unknown .ret_age)
        TRUE ~ 0
      )
    ) %>%
    mutate(
      epr = bewertungsjahrfaktor * bevendejahr,
      .first_obs = alt == 21 | jahr == min(jahr),
      eprc = epr *
        case_when(
          .first_obs & nat == "ch" & !is.na(.ret_age) ~ pmin(alt, .ret_age) - 21 + 1,
          .first_obs & nat != "ch"                    ~ 5,
          # all other cases
          TRUE ~ 0)
    ) %>%
    select(-.ret_age, -.first_obs) %>%
    ungroup()

  # EXTRAPOLATION OF THE EPRC ----
  # Propagates the cumulated full-pension equivalents (EPRC) along each cohort
  # diagonal:
  #   K[y, a] = (1 - Q[y - 1, a - 1]) * K[y - 1, a - 1] + V[y, a]
  # where Q holds quotients_mortalite, V holds epr and K starts from eprc.
  # The first year row and the first age column of K keep their starting
  # values from DTA$eprc.
  #
  # DTA: rows of POPU for one (sex, nat, dom) group, with columns jahr, alt,
  #   quotients_mortalite, epr and eprc.
  # Returns: matrix_to_tibble() of the updated K, using the (jahr, alt, eprc)
  #   columns of DTA as template.
  # NOTE(review): assumes tibble_to_matrix() (utils.R) returns a year x age
  #   matrix whose row/column order matches unique(DTA$jahr) /
  #   unique(DTA$alt) — confirm.
  calculate_eprc_extrapolation <- function(DTA) {
    n_years <- length(unique(DTA$jahr))
    n_ages  <- length(unique(DTA$alt))

    q_mat <- tibble_to_matrix(dplyr::select(DTA, jahr, alt, quotients_mortalite))
    v_mat <- tibble_to_matrix(dplyr::select(DTA, jahr, alt, epr))
    k_mat <- tibble_to_matrix(dplyr::select(DTA, jahr, alt, eprc))

    # Age columns to update (index >= 2) and their diagonal predecessors.
    age_cols  <- seq_len(n_ages)[-1]
    prev_cols <- age_cols - 1

    # Row y depends only on row y - 1, so each row is filled in one vectorised
    # step; rows themselves must still be processed in increasing order.
    for (y in seq_len(n_years)[-1]) {
      k_mat[y, age_cols] <-
        (1 - q_mat[y - 1, prev_cols]) * k_mat[y - 1, prev_cols] +
        v_mat[y, age_cols]
    }

    matrix_to_tibble(k_mat, dplyr::select(DTA, jahr, alt, eprc))
  }

  # Estimation of the EPRC for every (sex, nat, dom) combination:
  # 2 x 2 x 2 = 8 groups. For each combination, the matching rows of POPU
  # are passed to calculate_eprc_extrapolation(). The per-group results are
  # unnested into one long table, and rows with missing eprc are dropped.
  # bevendejahr is then re-attached from POPULATION_TOT; it is used as `pop`
  # further down.
  # NOTE(review): `do()` is superseded in dplyr >= 1.0 (group_modify() /
  # reframe() are the current equivalents). Inside filter(), `.$sex`, `.$nat`
  # and `.$dom` refer to the current one-row group of the crossing, not to
  # POPU's own columns.
  KV <- crossing(
    sex = c("f", "m"), nat = c("ch", "au"), dom = c("ch", "au")) %>%
    group_by(sex, nat, dom) %>%
    do(dta0 = calculate_eprc_extrapolation(filter(
      POPU, sex == .$sex, nat == .$nat, dom == .$dom
    ))) %>%
    unnest(cols = c(dta0)) %>%
    # presumably NA where the year x age grid had no matching POPU row
    filter(!is.na(eprc)) %>%
    left_join(POPULATION_TOT, by = c("jahr", "sex", "nat", "alt", "dom"))

  # return(KV = KV)
# }
# 
#   KV %<>%
#     select(jahr, sex, nat, dom, alt, kv = eprc, pop = bevendejahr) %>%
#     filter(!(sex == "f" & alt < 62), !(sex == "m" & alt < 63),
#            jahr >= PARAM_GLOBAL$jahr_rr) %>%
#     group_by(sex, nat, dom, alt) %>%
#     arrange(sex, nat, dom, alt, jahr) %>%
#     mutate(type = "est", g_kv = kv / lag(kv), g_pop = pop / lag(pop)) %>%
#     replace_na(list(g_kv = 1, g_pop = 1)) %>%
#     select(type, jahr, sex, nat, dom, alt, kv, g_kv, g_pop)

  # KV %<>%
  #   bind_rows(
  #     c("ledig", "verheiratet", "geschieden", "verwitwet") %>%
  #       map(\(x) mutate(KV, zv = x)))
  #
  # KV %<>%
  #   bind_rows(
  #     62:70 %>%
  #       map(\(x) mutate(KV, age_ret = x)))

  # KV %<>%
  #   bind_rows(
  #     c("rorphelin_pere_simple", "renfant_pere_simple", "rorphelin_mere_simple",
  #       "rorphelin_double", "renfant_mere_simple", "rveuve", "rcompl_femme",
  #       "rvieillesse_simple") %>%
  #       map(\(x) mutate(KV, gpr = x))) %>%
  #   na.omit()
# 
#   KV %<>%
#     na.omit() %>%
#     mutate(gpr = "rvieillesse_simple") %>%
#     filter(!(sex == "m" & gpr == "rvieillesse_simple" & age_ret < 63),
#            !(gpr == "rvieillesse_simple" & alt < age_ret)) %>%
#     # mutate(age_ret = ifelse(gpr == "rvieillesse_simple", age_ret, NA),
#     #        g_kv = ifelse(gpr == "rvieillesse_simple", g_kv, g_pop)) %>%
#     distinct() %>%
#     group_by(sex, nat, dom, zv, age_ret, alt) %>%
#     arrange(sex, nat, dom, zv, age_ret, alt, jahr) %>%
#     mutate(g_kv = cumprod(g_kv)) %>%
#     select(jahr, gpr, sex, nat, dom, zv, age_ret, alt, kv, g_kv)
# 
#   write_delim(KV, "~/delfinverse/KV.csv", delim = ";")
  
# Model estimate ("est"): cumulated EPR (kv) and population (pop) at pension
# ages (women 62+, men 63+) from 1997 on. Within each (sex, nat, dom, alt)
# series, g_kv and g_pop are year-on-year ratios; missing ratios (e.g. the
# first year of a series) are set to 1.
KV %<>%
  select(jahr, sex, nat, dom, alt, kv = eprc, pop = bevendejahr) %>%
  filter(
    jahr >= 1997,
    !(sex == "f" & alt < 62),
    !(sex == "m" & alt < 63)
  ) %>%
  group_by(sex, nat, dom, alt) %>%
  arrange(sex, nat, dom, alt, jahr) %>%
  mutate(
    type = "est",
    g_kv = kv / lag(kv),
    g_pop = pop / lag(pop)
  ) %>%
  replace_na(list(g_kv = 1, g_pop = 1)) %>%
  select(type, jahr, sex, nat, dom, alt, kv, pop, g_kv, g_pop)
# 
# KV %<>%
#   bind_rows(
#     c("ledig", "verheiratet", "geschieden", "verwitwet") %>%
#       map(\(x) mutate(KV, zv = x)))
# 
# KV %<>%
#   bind_rows(
#     62:70 %>%
#       map(\(x) mutate(KV, age_ret = x)))
# 
# KV %<>%
#   na.omit() %>%
#   filter(!(sex == "m" & age_ret == 62)) %>%
#   mutate(gpr = "rvieillesse_simple") %>%
#   select(jahr, gpr, sex, nat, dom, zv, age_ret, alt, g_kv, g_pop)
# 
# # write_delim(KV, "~/delfinverse/KV.csv", delim = ";")

# Observed series ("rr") from RR_AVS.csv. Keep simple old-age pensions
# (gpr == "rvieillesse_simple") at or above age_ret, at pension ages
# (women 62+, men 63+), with strictly positive eprc. Then sum eprc -> kv and
# bez_av -> pop per (jahr, sex, nat, dom, alt), and add year-on-year ratios
# as for the estimate, with missing ratios set to 1.
RR %<>%
  filter(
    gpr == "rvieillesse_simple",
    alt >= age_ret,
    !(sex == "f" & alt < 62),
    !(sex == "m" & alt < 63),
    eprc > 0
  ) %>%
  select(jahr, sex, nat, dom, alt, eprc, bez_av) %>%
  dplyr::summarize(
    kv = sum(eprc),
    pop = sum(bez_av),
    .by = c("jahr", "sex", "nat", "dom", "alt")
  ) %>%
  group_by(sex, nat, dom, alt) %>%
  arrange(sex, nat, dom, alt, jahr) %>%
  mutate(
    type = "rr",
    g_kv = kv / lag(kv),
    g_pop = pop / lag(pop)
  ) %>%
  replace_na(list(g_kv = 1, g_pop = 1))

# Calibration factor per (sex, nat, dom, alt): ratio of the observed "rr" kv
# to the estimated kv in 2024. A missing estimate yields an NA factor.
adj <- RR %>%
  filter(jahr == 2024) %>%
  select(jahr, sex, nat, dom, alt, kv) %>%
  left_join(
    KV %>% select(jahr, sex, nat, dom, alt, kv_est = kv),
    by = c("jahr", "sex", "nat", "dom", "alt")
  ) %>%
  mutate(adj = kv / kv_est) %>%
  select(sex, nat, dom, alt, adj)

# Scale every estimated kv by its 2024 calibration factor, append the observed
# "rr" rows below, and drop the helper columns adj and pop.
KV %<>%
  left_join(
    adj,
    by = c("sex", "nat", "dom", "alt"),
    relationship = "many-to-one"
  ) %>%
  mutate(kv = kv * adj) %>%
  bind_rows(RR) %>%
  select(-c(adj, pop))

# Visual check: women, Swiss nationals domiciled in Switzerland, ages 65-70,
# years 2010-2040. Point colour = age and point shape = series type
# ("est"/"rr"); the vertical line sits between 2023 and 2024.
viz <- KV

viz %>%
  filter(jahr %in% 2010:2040, sex == "f", alt %in% 65:70,
         nat == "ch", dom == "ch") %>%
  ggplot(aes(x = jahr, y = kv, col = as.factor(alt), shape = type)) +
  geom_hline(yintercept = 1) +
  geom_vline(xintercept = 2023.5) +
  geom_shadowpoint()

# Final output: cumulative growth factors of the estimated kv, from the base
# year PARAM_GLOBAL$jahr_rr onwards, written to ~/delfinverse/KV.csv.
# Within each (sex, nat, dom, alt) series, g_kv is the cumulative product of
# the year-on-year ratios kv / lag(kv). The base year's own ratio is kv / kv
# (i.e. 1 when kv is non-zero), because lag() defaults to first(kv).
KV %<>%
  filter(type == "est") %>%
  ungroup() %>%
  select(jahr, sex, nat, dom, alt, kv) %>%
  group_by(sex, nat, dom, alt) %>%
  arrange(sex, nat, dom, alt, jahr) %>%
  filter(jahr >= PARAM_GLOBAL$jahr_rr) %>%
  mutate(
    # `default` spelled out: `def` relied on partial argument matching.
    g_kv = cumprod(kv / lag(kv, default = first(kv))),
    # Must match the pension-type code used for RR above
    # ("rvieillesse_simple"); it was misspelled "rvieilleisse_simple".
    gpr = "rvieillesse_simple"
  ) %>%
  select(gpr, jahr, sex, nat, dom, alt, g_kv)

write_delim(KV, "~/delfinverse/KV.csv", delim = ";")

# viz2 <-
#   mutate(viz, coh = jahr - alt) %>%
#   filter(sex == "m", coh == 1945)
# 
# ggplot(filter(viz2, alt >= 65), aes(x = alt, y = kv, col = type)) +
#   geom_shadowpoint() +
#   facet_grid(dom ~ nat, labeller = label_both)

# ggplot(filter(KV, sex == "f", nat == "au", dom == "au", alt == 80, jahr <= 2040),
#        aes(x = jahr, y = g_kv)) +
#   geom_hline(yintercept = 1) +
#   geom_shadowpoint()
# 
# viz <-
#   select(KV, year = jahr, sex, nat, dom, age = alt, kv = eprc, pop = bevendejahr) %>%
#   filter(!(sex == "m" & age < 65), !(sex == "f" & age < 64)) %>%
#   # filter(age >= 65) %>%
#   dplyr::summarize(kv = sum(kv), .by = c("year", "sex", "nat", "dom", "age")) %>%
#   left_join(RR, by = c("year", "sex", "nat", "dom", "age"))
# 
# adj <-
#   filter(viz, year == 2023) %>%
#   mutate(adj = eprc / kv) %>%
#   select(sex, nat, dom, age, adj)
# 
# viz %<>%
#   left_join(adj, by = c("sex", "nat", "dom", "age"),
#             relationship = "many-to-one") %>%
#   mutate(kv = kv * adj) %>%
#   select(- adj)
# 
# viz2 <-
#   dplyr::summarize(viz, kv = sum(kv), eprc = sum(eprc),
#                    .by = c("year", "sex", "nat", "dom")) %>%
#   pivot_longer(kv:eprc)
# 
# viz3 <-
#   dplyr::summarize(viz, kv = sum(kv), eprc = sum(eprc),
#                    .by = c("year", "sex")) %>%
#   pivot_longer(kv:eprc)
# 
# viz %<>%
#   pivot_longer(kv:eprc)
# 
# ggplot(filter(viz, age == 65, year %in% 2015:2040), aes(x = as.factor(year), y = value, col = name)) +
#   geom_vline(xintercept = 2022.5) +
#   geom_shadowpoint() +
#   facet_grid(sex ~ nat + dom, labeller = label_both)
# 
# ggplot(filter(viz2, year %in% 2010:2040), aes(x = as.factor(year), y = value, col = name)) +
#   geom_vline(xintercept = 2022.5) +
#   geom_shadowpoint() +
#   facet_grid(sex ~ nat + dom, labeller = label_both)
# 
# ggplot(filter(viz3, year %in% 2010:2040), aes(x = as.factor(year), y = value, col = name)) +
#   geom_vline(xintercept = 2022.5) +
#   geom_shadowpoint() +
#   facet_grid(sex ~ ., labeller = label_both)



# viz <-
#   select(KV, year = jahr, sex, nat, dom, age = alt, kv = eprc, pop = bevendejahr) %>%
#   # filter(!(sex == "m" & age < 65), !(sex == "f" & age < 64)) %>%
#   filter(age >= 65) %>%
#   dplyr::summarize(kv = sum(kv), pop = sum(pop), .by = c("year")) %>%
#   pivot_longer(kv:pop) %>%
#   filter(year <= 2040)
#
# ggplot(viz, aes(x = year, y = value, col = name)) +
#   geom_vline(xintercept = 2023.5) +
#   geom_shadowpoint() +
#   scale_y_continuous(breaks = seq(1e5, 40e5, 1e5))

# viz <-
#   select(POPULATION_TOT, year = jahr, sex, nat, dom, age = alt, pop = bevendejahr) %>%
#   filter(age >= 65) %>%
#   dplyr::summarize(pop = sum(pop), .by = c("year", "sex", "nat", "dom"))
#
# ggplot(viz, aes(x = year, y = pop)) +
#   geom_vline(xintercept = 2023.5) +
#   geom_shadowpoint() +
#   facet_grid(sex ~ nat + dom, labeller = label_both)
#
# viz <-
#   select(POPU, year = jahr, sex, nat, dom, age = alt, pop = bevendejahr,
#          qm = quotients_mortalite, bw = bewertungsjahrfaktor, v = epr, kv = eprc) %>%
#   mutate(coh = year - age) %>%
#   filter(nat == "ch", dom == "ch", year <= 2040)
#
# ggplot(filter(viz, age == 55), aes(x = year, y = qm)) +
#   geom_shadowpoint() +
#   facet_grid(sex ~ nat + dom, labeller = label_both)
