# Aufteilung der Ausgabenprojektionen (Basismodell, Stand August 2025) -------------------

# Load the necessary packages (the workspace itself is left untouched).
# 'character.only' is spelled out instead of relying on partial argument
# matching, and the list returned by lapply() is made invisible so that a
# non-interactive run does not print it.
invisible(lapply(c("tidyverse", "magrittr", "simputation", "readxl", "rsample",
                   "broom", "collapse", "dqrng", "ggshadow", "strucchange"),
                 library, character.only = TRUE))

# Silence the column-specification messages that 'readr' readers emit.
options(readr.show_col_types = FALSE)

# Run the project set-up scripts in order: baseline parameters, auxiliary
# helper functions, then reading and processing of all model inputs.
# source() evaluates each script in the global environment by default
# (local = FALSE), so the objects they define are visible to this script.
invisible(lapply(c("scripts/base_par.R",
                   "scripts/aux_fun.R",
                   "scripts/prepare_inputs.R"),
                 source))


# Extract projections by pension type. ---------------------------------------------------

# Restore the cross-validated trend-extrapolation start points (see
# 'cv_out.R'); presumably this defines the vector 'range' used below:
#   range[1] - first year of the count trend window (foreign domicile),
#   range[2] - first year of the mean-level trend window (foreign domicile),
#   range[3] - first year of the mean-level trend window (Swiss domicile),
#   range[4] - first year of the cost-model calibration window.
# NOTE: as a variable, 'range' shadows base::range() in this script.
load(file = "data/range.RData")

# Empty container for the per-domicile projection tibbles built below.
rl <- vector(mode = "list", length = 0L)

# Foreign domicile ('au'): extrapolate pension counts 'n' and mean pension
# levels 'm' for scenarios 'H' and 'A' (the BFS reference scenario) with a
# linear time trend, separately for every combination of sex and pension type.
rl[["au"]] <- a_dat %>%
  filter(scen %in% c("H", "A"), dom == "au") %>%
  # Trend-window indicators: 1 from the cross-validated start year on, NA
  # before, so that earlier years neither enter the trend fit nor get imputed.
  mutate(t_n = ifelse(year >= range[1], 1, NA),
         t_m = ifelse(year >= range[2], 1, NA)) %>%
  # Sort chronologically. The grouped mutate() below keeps this row order
  # within each group, so the differences are year-on-year changes.
  arrange(year) %>%
  group_by(sex, type) %>%
  mutate(d_n = c(NA, diff(n)), d_m = c(NA, diff(m))) %>%
  ungroup() %>%
  # With a constant regressor, impute_lm() fits the mean annual change inside
  # the trend window per sex and type and imputes the missing changes there.
  impute_lm(d_n ~ 0 + t_n | sex + type) %>%
  impute_lm(d_m ~ 0 + t_m | sex + type) %>%
  group_by(sex, type) %>%
  # Years up to first(PAR$pint) keep their (carried-forward) values; only
  # changes after that year accumulate.
  mutate(d_n = ifelse(year > first(PAR$pint), d_n, 0),
         d_m = ifelse(year > first(PAR$pint), d_m, 0)) %>%
  # Carry the last observed value forward and add the accumulated changes.
  # Clamping prevents possible negative values from the linear extrapolation;
  # mean levels are also capped at 2 (presumably the maximum pension in
  # multiples of the minimum pension - TODO confirm).
  mutate(m = pmin(2, pmax(0, na_locf(m) + cumsum(d_m))),
         n =         pmax(0, na_locf(n) + cumsum(d_n))) %>%
  select(- t_n, - d_n, - t_m, - d_m, - scen)

# Swiss domicile ('ch'): only the mean pension levels 'm' are extrapolated with
# a linear time trend per sex and pension type; counts 'n' pass through as
# they come from 'a_dat'.
rl[["ch"]] <- a_dat %>%
  filter(scen %in% c("H", "A"), dom == "ch") %>%
  # 1 inside the cross-validated trend window, NA before it.
  mutate(t_m = ifelse(year >= range[3], 1, NA)) %>%
  arrange(sex, type, year) %>%
  # Year-on-year change of the mean level within each sex x type cell
  # (mutate() with '.by' keeps the sorted row order).
  mutate(d_m = c(NA, diff(m)), .by = c("sex", "type")) %>%
  impute_lm(d_m ~ 0 + t_m | sex + type) %>%
  group_by(sex, type) %>%
  # Freeze years up to first(PAR$pint), then carry the last observed level
  # forward plus the accumulated trend, clamped to the interval [0, 2].
  mutate(d_m = ifelse(year > first(PAR$pint), d_m, 0),
         m   = pmin(2, pmax(0, na_locf(m) + cumsum(d_m)))) %>%
  select(- t_m, - d_m, - scen)

# Historical within-year deaths of retirees and their average pensions (source:
# RR), extrapolated per sex and domicile and turned into the savings 'save_rel'
# on old-age pensions ('alt'). The factor 'mean(1:11) / 12' in 'save_rel'
# approximates the ratio of savings due to within-year pensioner deaths under
# monthly versus yearly payout of the 13th AHV pension according to the
# Liechtenstein model.
corr <-
  read_delim(PAR$in_lim) %>%
  select(year = an_rr, sex = csex, dom = recoded_cdom, n = anzahl,
         m_dec = rentensumme_dez) %>%
  mutate(year = year + 1,
         sex  = recode(sex, `1` = "m", `2` = "f"),
         dom  = recode(dom, `100` = "ch", `900` = "au")) %>%
  # The factor 11/12 corrects for the death count in the preceding December
  # (implicit assumption: uniform death rate over months and pension levels).
  dplyr::summarize(m = 11 / 12 * sum(m_dec) / sum(n), n = sum(n),
                   .by = c("year", "sex", "dom")) %>%
  # Complete the panel: one row per year, sex and domicile from the first
  # observed year up to the end of the projection horizon.
  right_join(
    tidyr::expand_grid(year = min(.$year):last(PAR$pint),
                       sex  = c("m", "f"),
                       dom  = c("ch", "au")),
    by = c("year", "sex", "dom")
  ) %>%
  left_join(mp, by = "year") %>%
  # Express mean monthly pensions as multiples of the minimum pension.
  mutate(m = m / mp) %>%
  arrange(sex, dom, year) %>%
  # Year-on-year changes per sex and domicile. The first year has no
  # predecessor, so its change is NA rather than 0: a spurious zero would
  # enter the intercept-only fit below as an observation and bias the
  # estimated mean annual change towards zero.
  mutate(d_n = c(NA, diff(n)), d_m = c(NA, diff(m)), .by = c("sex", "dom")) %>%
  # Linear extrapolation: missing changes are imputed with the mean observed
  # change, separately by sex and domicile.
  impute_lm(d_n + d_m ~ 1 | sex + dom) %>%
  group_by(sex, dom) %>%
  # Only changes after first(PAR$pint) accumulate; ifelse() also keeps any
  # change left missing in earlier years from poisoning the cumulative sum.
  mutate(d_n = cumsum(ifelse(year > first(PAR$pint), d_n, 0)),
         d_m = cumsum(ifelse(year > first(PAR$pint), d_m, 0))) %>%
  # Carry the last observation forward plus the accumulated change, clamped
  # to be non-negative (and mean levels to at most 2).
  mutate(m    = pmin(2, pmax(0, na_locf(m) + d_m)),
         n    =         pmax(0, na_locf(n) + d_n),
         save = m * n * mp) %>%
  ungroup() %>%
  dplyr::summarize(save_tot = sum(save), .by = c("year", "sex", "dom")) %>%
  filter(year >= first(PAR$pint)) %>%
  # Effective savings relative to yearly payout; zero before 2026.
  mutate(save_rel = ifelse(year >= 2026, save_tot * (1 - mean(1:11) / 12), 0),
         type = "alt") %>%
  select(year, sex, dom, type, save_rel)

# Combine the per-domicile projections. Widow(er)s' pensions ('wit') after
# first(PAR$pint) are dropped and replaced by the exogenous projection 'wid'
# (scenario 'A'), based on consolidated STATPOP & pension registry data
# (author: Thomas Borek @Math-BSV).
# NOTE(review): the filter drops projected 'wit' rows of both sexes, so 'wid'
# is assumed to cover both sexes - confirm.
p_dat <-
  bind_rows(rl) %>%
  filter(type != "wit" | year <= first(PAR$pint)) %>%
  bind_rows(
    wid %>%
      filter(scen == "A", year > first(PAR$pint), year <= last(PAR$pint)) %>%
      select(- scen)
  ) %>%
  left_join(corr, by = c("year", "sex", "dom", "type")) %>%
  # Rows without a matching saving (non-'alt' types, earlier years) get 0.
  # For old-age pensions from 2026 on, the 13th AHV pension (PAR$ahv13
  # twelfths on top) and the Liechtenstein-model savings are folded into the
  # mean pension level as an equivalent change of surviving retirees' 'm'.
  mutate(save_rel = coalesce(save_rel, 0),
         m = ifelse(type == "alt" & year >= 2026,
                    m * (1 + PAR$ahv13 / 12) - PAR$ahv13 * save_rel / (n * mp),
                    m)) %>%
  left_join(zas, by = "year") %>%
  ungroup()

# Cost model: aggregate to one row per year (count-weighted mean level, total
# count), restrict to the calibration window range[4] .. first(PAR$pint), and
# regress total expenditure on the pension volume n * m * mp without an
# intercept, so the slope maps volume to total expenditure.
f_dat <- p_dat %>%
  dplyr::summarize(m = weighted.mean(m, n),
                   n = sum(n),
                   .by = c("year", "mp", "exp_tot")) %>%
  filter(year %in% range[4]:first(PAR$pint)) %>%
  # Year-on-year change in volume (not used by the fit itself).
  mutate(d_nmmp = c(NA, diff(n * m * mp)))

fit <- lm(exp_tot ~ 0 + I(n * m * mp), data = f_dat)

# Expenditure projection by type from 2024 on, in millions. Pension types
# 'kin' and 'wai' are pooled into 'other'. dplyr::summarize() is namespaced
# explicitly, as everywhere else in this script, to be immune to masking.
PROJ <- p_dat %>%
  filter(year >= 2024) %>%
  mutate(type = fct_recode(type, other = "kin", other = "wai")) %>%
  # NOTE(review): for 'alt' rows from 2026 on, 'm' in 'p_dat' already has
  # PAR$ahv13 * save_rel / (n * mp) deducted, so subtracting 'save_rel' here
  # deducts the Liechtenstein savings once more (the decomposition 'rtab'
  # below uses n * m * mp only) - confirm this is intended.
  dplyr::summarize(exp = sum(n * m * mp - save_rel), .by = c("type", "year")) %>%
  mutate(exp = exp / 1e6)

# Expenditure beyond the pension volume: the fitted cost factor minus one,
# applied to the total over all types and booked as 'other'.
REST <- PROJ %>%
  mutate(type = "other", exp = exp * (coef(fit) - 1)) %>%
  dplyr::summarize(exp = sum(exp), .by = c("type", "year"))

PROJ <- PROJ %>%
  bind_rows(REST) %>%
  dplyr::summarize(exp = sum(exp), .by = c("type", "year"))

# Multiplicative calibration factor per expenditure type.
# NOTE(review): the ratios look like reference amounts divided by modelled
# amounts, but their source is not documented here - TODO confirm and cite.
ADJ <- tibble(
  type = c("alt", "wit", "other"),
  adj  = c(47945 / 47607, 1914 / 1897, 1082 / 1408)
)

# Apply the calibration, replace type codes by German labels and rename the
# columns for the exported table.
PROJ <- PROJ %>%
  left_join(ADJ, by = "type") %>%
  mutate(exp  = exp * adj,
         type = fct_recode(type,
                           Altersrenten      = "alt",
                           Verwitwetenrenten = "wit",
                           `Andere Ausgaben` = "other")) %>%
  select(- adj) %>%
  rename(Ausgabentyp = type, Jahr = year, Betrag = exp)

# Visual check: one panel per expenditure type, each with its own y-scale.
ggplot(data = PROJ, mapping = aes(x = Jahr, y = Betrag)) +
  geom_col() +
  facet_grid(Ausgabentyp ~ ., scales = "free_y")

write_delim(PROJ,
            "~/anfragen/outputs/kostenaufteilung_proj_basismodell_aug25.csv",
            delim = ";")

# Dekomposition des Ausgabenwachstums in Volumen versus Rentenniveaus --------------------

# Collapse 'p_dat' to one row per year (count-weighted mean level, total
# count); assumes 'mp' and 'exp_tot' are constant within a year.
p_dat %<>%
  dplyr::summarize(m = weighted.mean(m, n), n = sum(n),
                   .by = c("year", "mp", "exp_tot"))

# Reference path over PAR$pint: observed 'exp_tot' in the first year, then
# that value plus the cumulative change in volume n * m * mp scaled by the
# fitted cost factor.
rtab <- p_dat %>%
  filter(year %in% PAR$pint) %>%
  # summarize(.by = ) orders rows by first appearance, but the cumulative
  # differences and first(exp_tot) below require chronological order.
  arrange(year) %>%
  mutate(d_nmmp = c(0, cumsum(diff(n * m * mp)))) %>%
  # Adjust projections by estimated cost top-up.
  mutate(exp_tot =
           ifelse(year > first(PAR$pint),
                  first(exp_tot) + d_nmmp * coef(fit),
                  exp_tot)) %>%
  dplyr::rename(ref = exp_tot) %>%
  select(- d_nmmp)



