# Aufteilung der Ausgabenprojektionen (Basismodell, Stand August 2025) -------------------

# Attach all packages required by this script (nothing here clears the workspace).
# 'character.only' is spelled out in full instead of relying on partial matching of
# 'char', and lapply() wrapped in invisible() replaces sapply() because only the side
# effect of attaching the packages is wanted, not a simplified return value.
invisible(lapply(
  c("tidyverse", "magrittr", "simputation", "readxl", "rsample", "broom",
    "dqrng", "ggshadow", "strucchange"),
  library, character.only = TRUE
))

# Suppress column type guessing messages from 'readr' functions.
options(readr.show_col_types = FALSE)

# Load baseline parameters.
# NOTE(review): 'PAR' is used right below without being defined in this file, so it is
# presumably created by this script -- confirm.
source("scripts/model/1_initialize_parameters.R")

# Define auxiliary functions.
source("scripts/other/aux_fun.R")

# Reference-age table by calendar year and sex, stored on the parameter list.
# Women: 64 in 2024, then a quarter year more each year from 2025 to 2028
# (reform AHV 21). Men: a single row with 65 for 2024.
PAR$reftab <-
  bind_rows(
    tibble(year   = c(2024, 2025, 2026, 2027, 2028),
           sex    = "f",
           refage = 64 + 0:4 / 4),
    tibble(year = 2024, sex = "m", refage = 65)
  )

# Read and process all necessary inputs.
# NOTE(review): A_DATA, WIDOWS, ZAS and MINIMAL_PENSION are used further down but are
# not defined in this file; presumably they are created by these two scripts -- confirm.
source("scripts/model/2_prepare_inputs.R")
source("scripts/model/3_adjust_scenarios.R")

# Extract projections by pension type. ---------------------------------------------------

# Range of trend extrapolation points from crossvalidation (see 'cv_out.R').
# load() restores the saved objects under their stored names; the code below expects an
# object 'RANGE' with the elements 'nau', 'mau', 'mch' and 'tot'.
load("data/output/RANGE.RData")

# Set up list to collect results.
RES_LIST <- list()

# Pensions with domicile "au" for the historical data plus the BFS reference scenario
# 'A' (scen "H" and "A"): extrapolate pension counts 'n' and mean pension levels 'm'
# separately by sex and pension type.
RES_LIST[["au"]] <- A_DATA %>%
  filter(scen %in% c("H", "A"), dom == "au") %>%
  # Trend regressors: 1 from the crossvalidated start year onwards, NA before.
  mutate(t_n = ifelse(year >= RANGE$nau, 1, NA),
         t_m = ifelse(year >= RANGE$mau, 1, NA)) %>%
  group_by(sex, type) %>%
  # Sort chronologically so that lag() below yields year-on-year changes within each
  # group. The previous call passed 'by_group = TRUE' without the leading dot; dplyr does
  # not read that as the '.by_group' flag but as one more (constant) sort key, so the
  # effective ordering was always a plain sort by year. That ordering is stated
  # explicitly here, which also keeps the row order of the result unchanged.
  arrange(year) %>%
  # Year-on-year changes within each sex x type group (NA for the first year).
  mutate(d_n = n - lag(n), d_m = m - lag(m)) %>%
  # Extrapolate pension counts and mean pension levels with a linear time trend over
  # the Cartesian product of pension type, sex, and domicile: missing changes are
  # imputed from a regression on the constant regressor, fitted per sex x type.
  ungroup() %>%
  impute_lm(d_n ~ 0 + t_n | sex + type) %>%
  impute_lm(d_m ~ 0 + t_m | sex + type) %>%
  group_by(sex, type) %>%
  # Zero all changes up to and including first(PAR$pint), so that the cumulative sums
  # below only accumulate changes of later years.
  mutate(d_n = ifelse(year > first(PAR$pint), d_n, 0),
         d_m = ifelse(year > first(PAR$pint), d_m, 0)) %>%
  # Carry the last available levels forward as the base for the cumulated changes.
  fill(m, n) %>%
  # Prevent eventual negative predictions, and enforce legal maximum pension of twice
  # the minimal pension (implicit assumption: pension supplements due to Rentenaufschub
  # will never increase the mean pension level above the upper limit).
  mutate(m = pmin(2, pmax(0, m + cumsum(d_m))),
         n =         pmax(0, n + cumsum(d_n))) %>%
  select(- t_n, - d_n, - t_m, - d_m, - scen)

# Pensions with domicile "ch" (scen "H" and "A"): only the mean pension level 'm' is
# extrapolated here, with a linear time trend per sex x pension type.
RES_LIST[["ch"]] <- A_DATA %>%
  filter(scen %in% c("H", "A") & dom == "ch") %>%
  # Trend regressor: 1 from the crossvalidated start year onwards, NA before.
  mutate(t_m = ifelse(year >= RANGE$mch, 1, NA)) %>%
  # Drop any incoming grouping, then order by group and year so that lag() yields
  # year-on-year changes within each sex x type group.
  ungroup() %>%
  arrange(sex, type, year) %>%
  mutate(d_m = m - lag(m), .by = c(sex, type)) %>%
  # Impute the missing changes from the per-group regression on the constant regressor.
  impute_lm(d_m ~ 0 + t_m | sex + type) %>%
  group_by(sex, type) %>%
  # Changes up to and including first(PAR$pint) are zeroed before cumulating.
  mutate(d_m = ifelse(year > first(PAR$pint), d_m, 0)) %>%
  # Carry the last available level forward, add the cumulated changes, and clamp the
  # result to the interval [0, 2] (multiples of the minimal pension).
  fill(m) %>%
  mutate(m = pmin(2, pmax(0, m + cumsum(d_m)))) %>%
  select(-c(t_m, d_m, scen))

# Import historical intermediate-year deaths of retirees and their average
# pensions (source: RR). Factor 'mean(1:11) / 12' in 'savings_rel' approximates the
# ratio of savings due to within-year pensioner deaths under monthly versus
# yearly payout of 13th AHV pension payment according to the Liechtenstein model.
LIECHTENSTEIN <-
  read_delim(PAR$in_lim) %>%
  # Keep and rename the needed columns; the registry year is shifted forward by one and
  # the numeric sex / domicile codes are mapped to the labels used elsewhere.
  transmute(year  = an_rr + 1,
            sex   = recode(csex, `1` = "m", `2` = "f"),
            dom   = recode(recoded_cdom, `100` = "ch", `900` = "au"),
            n     = anzahl,
            m_dec = rentensumme_dez) %>%
  # Multiplication by '11/12' corrects for death count in the preceding december
  # (implicit assumption: uniform death rate over months and pension levels).
  # 'm' must be computed before 'n' is overwritten by its sum.
  summarize(m = 11 / 12 * sum(m_dec) / sum(n),
            n = sum(n),
            .by = c(year, sex, dom)) %>%
  # Extend to a full year x sex x domicile grid up to the last element of PAR$pint;
  # cells without data get NA in 'm' and 'n'.
  right_join(
    expand_grid(year = min(.$year):last(PAR$pint),
                sex  = c("m", "f"),
                dom  = c("ch", "au")),
    by = c("year", "sex", "dom")
  ) %>%
  left_join(MINIMAL_PENSION, by = "year") %>%
  # Express mean monthly pensions as multiples of minimal rent.
  mutate(m = m / mp) %>%
  group_by(sex, dom) %>%
  arrange(year, .by_group = TRUE) %>%
  # Year-on-year changes, with the first year of each group set to 0.
  # NOTE(review): that leading 0 is a non-missing value and therefore presumably enters
  # the intercept-only regression below, pulling the estimated mean change towards
  # zero; the "au"/"ch" extrapolations use 'x - lag(x)' (NA in the first year) instead
  # -- confirm which is intended.
  mutate(d_n = c(0, diff(n)), d_m = c(0, diff(m))) %>%
  ungroup() %>%
  # Impute death counts and mean pensions linearly and separately by sex and domicile.
  impute_lm(d_n + d_m ~ 1 | sex + dom) %>%
  group_by(sex, dom) %>%
  fill(m, n) %>%
  # Cumulate only the changes after first(PAR$pint), add them to the carried-forward
  # levels, and clamp ('m' to [0, 2], 'n' to non-negative values).
  mutate(d_n = cumsum((year > first(PAR$pint)) * d_n),
         d_m = cumsum((year > first(PAR$pint)) * d_m)) %>%
  mutate(m = pmin(2, pmax(0, m + d_m)),
         n =         pmax(0, n + d_n),
         savings_tot = m * n * mp) %>%
  ungroup() %>%
  summarize(savings_tot = sum(savings_tot), .by = c("sex", "dom", "year")) %>%
  filter(year >= first(PAR$pint)) %>%
  # Calculate effective savings relative to yearly payout (zero before 2026), and label
  # the rows with pension type "alt" so they can be joined onto the projections.
  transmute(year, sex, dom,
            type = "alt",
            savings_rel = ifelse(year >= 2026,
                                 savings_tot * (1 - mean(1:11) / 12), 0))

# Replace relative average pension projections for widows with exogenous vector based on
# consolidated STATPOP & pension registry data (author: Thomas Borek @BSV, mathematics
# department).
# NOTE(review): widow rows of the extrapolation are kept up to and including
# first(PAR$pint), while WIDOWS contributes every year in PAR$pint; if WIDOWS contains
# first(PAR$pint), that year has widow rows from both sources -- confirm.
PROJ_DATA <-
  bind_rows(RES_LIST) %>%
  filter(type != "wit" | year <= first(PAR$pint)) %>%
  bind_rows(
    select(filter(WIDOWS, year %in% PAR$pint, scen == "A"), -scen)
  ) %>%
  # Attach the Liechtenstein savings; rows without a match get zero savings.
  left_join(LIECHTENSTEIN, by = c("sex", "dom", "type", "year")) %>%
  replace_na(list(savings_rel = 0)) %>%
  # Incorporate Liechtenstein effect through an equivalent decrease of surviving
  # retirees' average pension level (old-age pensions from 2026 onwards only).
  mutate(m = ifelse(year >= 2026 & type == "alt",
                    m * (1 + PAR$ahv13 / 12) - PAR$ahv13 * savings_rel / (n * mp),
                    m)) %>%
  left_join(ZAS, by = "year") %>%
  ungroup()

# Fit model and produce projections.
# Aggregate to one row per year (pension-weighted mean level 'm', total count 'n') and
# restrict to the fitting window from RANGE[["tot"]] to first(PAR$pint).
FIT_DATA <- PROJ_DATA %>%
  summarize(m = weighted.mean(m, n), n = sum(n), .by = c("mp", "exp_tot", "year")) %>%
  filter(year %in% RANGE[["tot"]]:first(PAR$pint)) %>%
  # summarize(.by = ) returns groups in order of first appearance, not sorted. The
  # regression below differences consecutive rows, so enforce chronological order
  # explicitly instead of relying on the row order of PROJ_DATA.
  arrange(year)

# Regress yearly changes in total expenditure on yearly changes in the modelled pension
# sum (n * m * mp), without intercept; the coefficient scales modelled to total costs.
fit <-
  lm(diff(exp_tot) ~ 0 + diff(n * m * mp), FIT_DATA)

# Modelled expenditure (divided by 1e6) by pension type and year from 2024 onwards;
# pension types "kin" and "wai" are pooled into "other".
# NOTE(review): 'm' of old-age pensions was already reduced by the Liechtenstein savings
# when PROJ_DATA was built, and 'savings_rel' is subtracted here once more -- confirm
# that this is intended.
PROJ <- PROJ_DATA %>%
  filter(year >= 2024) %>%
  mutate(type = fct_recode(type, other = "kin", other = "wai")) %>%
  summarize(exp = sum(n * m * mp - savings_rel) / 1e6, .by = c("type", "year"))

# Residual expenditure per year: the share implied by the fitted top-up coefficient
# (coef(fit) - 1), summed over all pension types and booked under "other".
REST <- PROJ %>%
  summarize(exp = sum(exp * (coef(fit) - 1)), .by = "year") %>%
  mutate(type = "other", .before = year)

# Add the residual to the modelled amounts and re-aggregate by type and year.
PROJ <- bind_rows(PROJ, REST) %>%
  summarize(exp = sum(exp), .by = c("type", "year"))

# Multiplicative adjustment factor per expenditure type.
# NOTE(review): the ratios look like reference value over modelled value for a base
# year; their source is not documented here -- confirm.
ADJ <- tibble(
  type = c("alt", "wit", "other"),
  adj  = c(47945 / 47607, 1914 / 1897, 1082 / 1467)
)

# Apply the adjustment, round to whole units, and switch to the German output labels
# and column names expected in the export.
PROJ <- PROJ %>%
  left_join(ADJ, by = "type") %>%
  transmute(
    Ausgabentyp = fct_recode(type,
                             Altersrenten      = "alt",
                             Verwitwetenrenten = "wit",
                             `Andere Ausgaben` = "other"),
    Jahr   = year,
    Betrag = round(adj * exp)
  )

# Export as semicolon-delimited file.
write_delim(PROJ,
            "~/anfragen/outputs/kostenaufteilung_proj_basismodell_aug25.csv", delim = ";")

# Dekomposition des Ausgabenwachstums in Volumen versus Rentenniveaus --------------------

# Overwrite PROJ_DATA with old-age pensions only, aggregated by domicile and year:
# pension-weighted mean level 'm' (computed before 'n' is replaced by its sum) and
# total count 'n'. Rows containing NA are dropped.
PROJ_DATA <- PROJ_DATA %>%
  filter(type == "alt") %>%
  dplyr::summarise(m = weighted.mean(m, n),
                   n = sum(n),
                   .by = c(dom, year)) %>%
  na.omit()

# p_dat %<>%
#   filter(type == "alt") %>%
#   dplyr::summarize(m = weighted.mean(m, n), n = sum(n),
#                    .by = c("year", "mp", "exp_tot"))

# p_dat %<>%
#   filter(type == "alt") %>%
#   dplyr::summarize(m = weighted.mean(m, n), n = sum(n),
#                    .by = c("dom", "year", "mp", "exp_tot"))

# rtab <- p_dat %>%
#   filter(year %in% PAR$pint) %>%
#   mutate(d_nmmp = c(0, cumsum(diff(n * m * mp)))) %>%
#   # Adjust projections by estimated cost top-up.
#   mutate(exp_tot =
#            ifelse(year > first(PAR$pint),
#                   first(exp_tot) + d_nmmp * coef(fit),
#                   exp_tot)) %>%
#   dplyr::rename(ref = exp_tot) %>%
#   select(- d_nmmp)

# Year-on-year growth rates of pension counts ('wr_pen') and mean pension levels
# ('wr_ave') by domicile, for years up to 2024. The first year of each domicile has no
# predecessor and is removed by na.omit(). The result stays grouped by 'dom'.
WR <- PROJ_DATA %>%
  filter(year <= 2024) %>%
  # bind_rows(rtab) %>%
  select(dom, year, n, m) %>%
  arrange(dom, year) %>%
  group_by(dom) %>%
  mutate(wr_pen = n / lag(n) - 1,
         wr_ave = m / lag(m) - 1) %>%
  na.omit()

# BEV <-
#   read_delim("~/data/appl-wb/20_staff/kjo/fhh/2025-09-18T0941_u80874371_ahv_basis/BEVOELKERUNG.csv") %>%
#   select(sex, alt, jahr, bevendejahr) %>%
#   filter(jahr %in% 2015:2025, alt >= 65) %>%
#   summarize(bev = sum(bevendejahr), .by = c("sex", "jahr"))

# Resident population at year end up to 2024, restricted to women aged 64 and over and
# men aged 65 and over, summed per year, with its year-on-year growth rate 'wr_bev'.
BEV <-
  read_delim("~/data/appl-wb/20_staff/kjo/fhh/2025-09-18T0941_u80874371_ahv_basis/BEVOELKERUNG.csv") %>%
  select(sex, alt, year = jahr, bevendejahr) %>%
  filter(year <= 2024,
         !(sex == "f" & alt < 64),
         !(sex == "m" & alt < 65)) %>%
  summarize(bev = sum(bevendejahr), .by = c("year")) %>%
  # summarize(.by = ) keeps the order in which the years first appear in the file, so
  # sort explicitly: lag() below must refer to the preceding calendar year.
  arrange(year) %>%
  mutate(wr_bev = bev / lag(bev) - 1) %>%
  # Drops the first year, which has no predecessor.
  na.omit()

# Plot data in long format: per domicile and year, the growth rate of the population
# ('wr_bev') next to the growth rate of the pension count ('wr_pen'). Years lacking
# either rate are dropped; the domicile codes are replaced by German facet labels.
VIZ <- ungroup(WR) %>%
  select(dom, year, wr_pen) %>%
  left_join(select(BEV, year, wr_bev), by = "year") %>%
  select(dom, year, wr_bev, wr_pen) %>%
  na.omit() %>%
  pivot_longer(cols = c(wr_bev, wr_pen)) %>%
  mutate(dom = fct_recode(dom, `Domizil: Ausland` = "au", `Domizil: Inland` = "ch"))

# VIZ <- WR %>%
#   ungroup() %>%
#   select(year, wr_pen) %>%
#   left_join(BEV, by = "year") %>%
#   select(year, wr_bev, wr_pen) %>%
#   na.omit() %>%
#   pivot_longer(cols = wr_bev:wr_pen)

# ggplot(VIZ, aes(x = as.factor(year), y = value, fill = name)) +
#   geom_col(position = "dodge") +
#   labs(x = NULL, y = "Wachstumsrate zum Vorjahr") +
#   scale_fill_brewer(type = "qual",
#     labels = c("Inländische Wohnbevölkerung Alter 65+",
#                "Anzahl Altersrentenansprüche im Ausland"), name = NULL) +
#   theme_grey(base_size = 13) +
#   theme(legend.position = "top") +
#   scale_y_continuous(labels = scales::percent)

# Dodged bar chart of the two growth rates per year, one facet row per domicile, with
# percent-formatted y axis, legend on top, and rotated year labels.
ggplot(VIZ, aes(x = as.factor(year), y = value, fill = name)) +
  geom_col(position = "dodge") +
  facet_grid(dom ~ .) +
  scale_y_continuous(labels = scales::percent) +
  scale_fill_brewer(name = NULL, type = "qual",
                    labels = c("Inländische Wohnbevölkerung Alter 64+",
                               "Anzahl ausbezahlte Altersrenten")) +
  labs(x = NULL, y = "Wachstumsrate zum Vorjahr") +
  theme_grey(base_size = 16) +
  theme(legend.position = "top",
        axis.text.x = element_text(angle = 45, vjust = .5))

# ggplot(VIZ, aes(x = as.factor(year), y = value, fill = name)) +
#   geom_col(position = "dodge") +
#   labs(x = NULL, y = "Wachstumsrate zum Vorjahr") +
#   scale_fill_brewer(type = "qual",
#                     labels = c("Inländische Wohnbevölkerung Alter 65+",
#                                "Anzahl Altersrentenansprüche (Inland + Ausland)"), name = NULL,
#                     palette = 3) +
#   theme_grey(base_size = 16) +
#   theme(legend.position = "top") +
#   scale_y_continuous(labels = scales::percent)

# Hard-coded yearly totals for 2005-2021 in long format, with columns 'Jahr', 'name'
# (one of "N", "N_Dez", "delta") and 'value', where delta = N - N_Dez.
# Only the three columns actually used are built here; the pasted source table also
# carried constant descriptor columns ("Altersrenten", "Total", -99999) that were
# discarded right away.
# NOTE(review): 'N_Dez' is presumably the December count and 'N' the count over the
# whole year -- confirm against the data source.
DATA <-
  tibble(
    Jahr  = 2005:2021,
    N     = c(1754253L, 1805285L, 1865617L, 1927320L, 1990175L, 2042893L,
              2093534L, 2154004L, 2209851L, 2262842L, 2312358L, 2355519L,
              2398725L, 2439095L, 2480753L, 2521231L, 2551585L),
    N_Dez = c(1698329L, 1749177L, 1808234L, 1868973L, 1929643L, 1981278L,
              2031367L, 2088379L, 2142728L, 2196409L, 2239836L, 2285453L,
              2324854L, 2363790L, 2403779L, 2438787L, 2470754L)
  ) %>%
  mutate(delta = N - N_Dez) %>%
  pivot_longer(cols = c(N, N_Dez, delta))

# Both count series ('N' and 'N_Dez') side by side per year.
ggplot(filter(DATA, name != "delta"),
       aes(x = as.factor(Jahr), y = value, fill = name)) +
  geom_col(position = "dodge")

# Their difference ('delta') per year.
ggplot(filter(DATA, name == "delta"),
       aes(x = as.factor(Jahr), y = value)) +
  geom_col()

