#' Estimation des EPRC à partir de la population
#'
#' Ce module calcule les équivalents pleine rente cumulés (EPRC) à
#' partir des données d'une certaine population en les faisant évoluer selon
#' les taux de mortalité.
#'
#' @param PARAM_GLOBAL un dataframe d'une seule ligne, dont nous utilisons les
#'   paramètres suivants:
#'   - `ra_f_2005`: Age de la retraite pour les femmes dès 2005 (64)
#'   - `ra_m`: Age de la retraite pour les hommes
#'   - `min_age_cot`: Age à partir duquel on verse des cotisations
#'
#' @param POPULATION_TOT data frame contenant les données de la population
#' calculé dans la fonction \code{\link{mod_population}}.
#'
#' @param MORTALITE data frame contenant les taux de mortalité, cf. fonction
#' \code{\link{mod_input_mortalite}}.
#'
#' @param list `tidylist`. Elément obligatoire dans tous les modules. Au lieu de
#'   fournir des dataframes au module, il est possible de l'alimenter uniquement
#'   avec une`tidylist` qui contient les tidy dataframes. Tous les datframes
#'   listés doivent être présents dans la tidylist, avec le même nom. De plus,
#'   la `tidylist` peut aussi contenir des dataframes qui ne sont pas utilisés
#'   par le module.
#'
#' @references \href{https://www.bsv.admin.ch/dam/bsv/fr/dokumente/ahv/finanzperspektiven/validierung-modellansatz-ahv.pdf.download.pdf/2018_07_09_definitif_ld_rapport_ofas.pdf}{Rapport de Prof. Dr Laurent Donzé}
#'
#' @return une `tidylist` contenant le data frame suivant:
#' - `EPRC_ESTIMATION`
#'
#' @author [MAS BSV](mailto:sekretariat.mas@bsv.admin.ch)
#'
#' @export
#'

source("~/delfinverse/dinput/R/utils.R")

options(readr.show_col_types = FALSE)

setwd("~/data/appl-wb/20_staff/kjo/fhh/2025-05-20T1628_u80874371_ahv_basis_kjo")

PARAM_GLOBAL <- 
  read_delim("PARAM_GLOBAL.csv")

BEVOELKERUNG <- 
  read_delim("BEVOELKERUNG.csv")

TAUX_MORTALITE <- 
  read_delim("TAUX_MORTALITE.csv")

RR_AVS <- 
  read_delim("RR_AVS.csv")

# mod_eprc_estimation_ <-
mod_eprc_estimation <- function(PARAM_GLOBAL,
                                BEVOELKERUNG,
                                TAUX_MORTALITE,
                                RR_AVS) {

    print("Run module: mod_eprc_estimation")
    
    POPULATION_TOT <- BEVOELKERUNG %>%
        mutate(bevendejahr = bevendejahr, dom = "ch") %>%
        select(jahr, sex, nat, alt, dom, bevendejahr) %>%
        bind_rows(BEVOELKERUNG %>%
                      mutate(bevendejahr = rowSums(select(., auswanderung, frontaliers, saisonniers, assures_facultatifs), na.rm = TRUE), 
                             dom = "au") %>%
                      select(jahr, sex, nat, alt, dom, bevendejahr))

    if ("min_age_cot" %in% colnames(PARAM_GLOBAL) == FALSE) {
      min_age_cot <- 21
      print("Please add <<min_age_cot;21>> into PARAM_GLOBAL")
    } else {
      min_age_cot <- PARAM_GLOBAL$min_age_cot
    }
  
  # Bewertungsjahrfaktor
  # Version avec le même facteur pour tout le monde (pour les hommes et les femmes)

  POPU <- POPULATION_TOT %>%
    left_join(TAUX_MORTALITE, by = c("jahr", "alt", "sex", "nat")) %>%
    mutate(coh = jahr - alt) %>% 
    mutate(
      bewertungsjahrfaktor = case_when(
        alt %in% 21:64 & sex == "m" ~ 
          pmax(
            1 / (65 - 21             ),
            1 / (65 - 21 + coh - 1927)),
        
        alt %in% 21:62 & sex == "f" & coh <= 1938 ~ 
          pmax(
            1 / (62 - 21             ),
            1 / (62 - 21 + coh - 1927)),
        
        min_age_cot <= alt & sex == "f" & coh %in% 1939:1941 & alt < 63 ~ 
          pmax(
            1 / (63 - 21             ),
            1 / (63 - 21 + coh - 1927)),
        
        min_age_cot <= alt & sex == "f" & coh >=        1942 & alt < 64 ~ 
          pmax(
            1 / (64 - 21             ),
            1 / (64 - 21 + coh - 1927)),
        
        # all other cases
        TRUE ~ 0
      )
    ) %>% 
    mutate(
      epr = bewertungsjahrfaktor * bevendejahr,
      eprc = epr * 
        case_when(
          (alt == 21 | jahr == min(jahr)) & nat == "ch" & sex == "m"                      ~ pmin(alt, 65) - 21 + 1,
          (alt == 21 | jahr == min(jahr)) & nat == "ch" & sex == "f" & coh <= 1938        ~ pmin(alt, 62) - 21 + 1,
          (alt == 21 | jahr == min(jahr)) & nat == "ch" & sex == "f" & coh %in% 1939:1941 ~ pmin(alt, 63) - 21 + 1,
          (alt == 21 | jahr == min(jahr)) & nat == "ch" & sex == "f" & coh >= 1942        ~ pmin(alt, 64) - 21 + 1,
          (alt == 21 | jahr == min(jahr)) & nat != "ch"                                   ~                      1,
          # all other cases
          TRUE ~ 0)
    ) %>%
    ungroup()

  # EXTRAPOLATION DES EPRC : Fonction de calcul des EPRC
  calculate_eprc_extrapolation <- function(DTA) {

    # Jahr und Alter
    jahr <- unique(DTA$jahr)
    alt <- unique(DTA$alt)


    # Taux de mortalité
    Q <- tibble_to_matrix(dplyr::select(DTA, jahr, alt, quotients_mortalite))

    # Equivalents pleine rente
    V <- tibble_to_matrix(dplyr::select(DTA, jahr, alt, epr))

    # Equivalents pleine rente cumulés
    K <- tibble_to_matrix(dplyr::select(DTA, jahr, alt, eprc))

    for (i in (seq_along(jahr[-1]) + 1)) {
      for (j in (seq_along(alt[-1]) + 1)) {
        K[i, j] <- (1 - Q[i - 1, j - 1]) * K[i - 1, j - 1] + V[i, j]
      }
    }


    matrix_to_tibble(K, dplyr::select(DTA, jahr, alt, eprc))
  }

  # Estimation des EPRC pour les 4 groupes
  EPRC_ESTIMATION <- crossing(
    sex = c("f", "m"),
    nat = c("ch", "au"),
    dom = c("ch", "au")
  ) %>%
    group_by(sex, nat, dom) %>%
    do(dta0 = calculate_eprc_extrapolation(filter(
      POPU, sex == .$sex,
      nat == .$nat,
      dom == .$dom
    ))) %>%
    unnest(cols = c(dta0)) %>%
    filter(!is.na(eprc)) %>%
    left_join(POPULATION_TOT,
      by = c("jahr", "sex", "nat", "alt", "dom")
    )

  ADJ <- RR_AVS %>% 
    filter(gpr == "rvieillesse_simple", alt >= age_ret,
           !(sex == "f" & alt < 62), !(sex == "m" & alt < 63), eprc > 0) %>%
    select(jahr, sex, nat, dom, alt, eprc, bez_av) %>%
    dplyr::summarize(eprc = sum(eprc),  
                     .by = c("jahr", "sex", "nat", "dom", "alt")) %>%
    group_by(sex, nat, dom, alt) %>%
    arrange(sex, nat, dom, alt, jahr) %>% 
    filter(jahr == PARAM_GLOBAL$jahr_rr) %>% 
    select(jahr, sex, nat, dom, alt, eprc) %>%
    left_join(select(EPRC_ESTIMATION, jahr, sex, nat, dom, alt, eprc_est = eprc), 
              by = c("jahr", "sex" , "nat", "dom", "alt")) %>%
    mutate(adj = eprc / eprc_est) %>%
    select(sex, nat, dom, alt, adj) %>% 
    na.omit()
  
  EPRC_ESTIMATION <- EPRC_ESTIMATION %>% 
    left_join(ADJ, by = c("sex", "nat", "dom", "alt"),
              relationship = "many-to-one") %>%
    tidyr::replace_na(list(adj = 1)) %>% 
    mutate(eprc = eprc * adj) %>%
    select(- adj)

  return(EPRC_ESTIMATION = EPRC_ESTIMATION)
}
