##########################################################
# IMPUTE FUTURE FOREIGN POPULATION COUNTS AND MEAN RENTS #
################################################################################

# Reset the 'collapse' masking option (mask = NULL) to avoid name collisions
# with identically named functions from other attached packages.
set_collapse(mask = NULL)

# Range of trend extrapolation points from crossvalidation (see 'cv_out.R').
# 'load()' restores objects by name and succeeds silently even when the
# expected object is not part of the file. In that case 'range' would resolve
# to 'base::range' and the subsetting further below ('range[1]' to 'range[4]')
# would fail with an unrelated error message, so fail early instead.
.range_objs <- load("data/range.RData")
if (!"range" %in% .range_objs || length(range) < 4) {
  stop("'data/range.RData' must provide an object 'range' with at least ",
       "four elements.", call. = FALSE)
}

# Set up list to collect results.
rl <- list()

# Imputation over rent type, sex, and domicile for the BFS reference scenario 'A'.
# Domicile 'au': both the counts 'n' and the mean levels 'm' are extrapolated.
rl[["au"]] <- a_dat %>%
  filter(scen %in% c("H", "A"), dom == "au") %>%
  # Regressors 't_n' / 't_m' are 1 from 'range[1]' / 'range[2]' onwards and NA
  # before, which restricts the trend estimation to that window.
  mutate(t_n = ifelse(year >= range[1], 1, NA),
         t_m = ifelse(year >= range[2], 1, NA)) %>%
  group_by(sex, type) %>%
  # Sort by year so that the grouped first differences below follow time order.
  # The former 'by_group = TRUE' misspelled the option '.by_group': it was not
  # matched as an option but taken as one more (constant) sort key and thus had
  # no effect. It is dropped, which keeps the row order exactly as before.
  arrange(year) %>%
  mutate(d_n = c(NA, diff(n)), d_m = c(NA, diff(m))) %>%
  ungroup() %>%
  # Extrapolate pension counts and mean pension levels with a linear time trend
  # over the Cartesian product of pension type and sex: missing first
  # differences are imputed from a regression without intercept on the window
  # indicator, split by 'sex' and 'type'.
  impute_lm(d_n ~ 0 + t_n | sex + type) %>%
  impute_lm(d_m ~ 0 + t_m | sex + type) %>%
  group_by(sex, type) %>%
  # Only years after 'first(PAR$pint)' receive an increment; all other years
  # contribute 0 to the cumulative sums below.
  mutate(d_n = ifelse(year > first(PAR$pint), d_n, 0),
         d_m = ifelse(year > first(PAR$pint), d_m, 0)) %>%
  # Carry the last observed value forward and add the accumulated increments.
  # Prevent eventual negative predictions due to linear extrapolation; 'm' is
  # additionally capped at 2.
  mutate(m = pmin(2, pmax(0, na_locf(m) + cumsum(d_m))),
         n =         pmax(0, na_locf(n) + cumsum(d_n))) %>%
  select(- t_n, - d_n, - t_m, - d_m, - scen)

# Domicile 'ch': only the mean pension level 'm' is extrapolated, with a linear
# time trend fitted separately for each combination of sex and pension type.
rl[["ch"]] <- a_dat %>%
  filter(scen %in% c("H", "A") & dom == "ch") %>%
  arrange(sex, type, year) %>%
  group_by(sex, type) %>%
  # 't_m' marks the estimation window (1 from 'range[3]' onwards, NA before);
  # 'd_m' is the first difference of 'm' within each group.
  mutate(t_m = ifelse(year >= range[3], 1, NA),
         d_m = c(NA, diff(m))) %>%
  ungroup() %>%
  impute_lm(d_m ~ 0 + t_m | sex + type) %>%
  group_by(sex, type) %>%
  # Increments count only after 'first(PAR$pint)'; the level is the last
  # observed value plus the accumulated increments, bounded to [0, 2].
  mutate(d_m = ifelse(year > first(PAR$pint), d_m, 0),
         m   = pmin(2, pmax(0, na_locf(m) + cumsum(d_m)))) %>%
  select(-t_m, -d_m, -scen)

# Import historical intermediate-year deaths of retirees and their average
# pensions (source: RR). Factor 'mean(1:11) / 12' in 'save_rel' approximates the
# ratio of savings due to within-year pensioner deaths under monthly versus
# yearly payout of 13th AHV pension payment according to the Liechtenstein model.
# Result: one row per year (from 'first(PAR$pint)' on), sex, and domicile with
# the columns 'year', 'sex', 'dom', 'type' (always "alt"), and 'save_rel'.
corr <-
  read_delim(PAR$in_lim) %>%
  select(year = an_rr, sex = csex, dom = recoded_cdom, n = anzahl,
         m_dec = rentensumme_dez) %>%
  # Shift the year by one and translate the numeric codes into the labels used
  # in the rest of the script.
  # NOTE(review): the reason for 'year + 1' is not visible here; presumably the
  # registry year refers to the preceding period. Confirm.
  mutate(year = year + 1,
         sex  = recode(sex, `1` = "m", `2` = "f"),
         dom  = recode(dom, `100` = "ch", `900` = "au")) %>%
  # Multiplication by '11/12' corrects for death count in the preceding december
  # (implicit assumption: uniform death rate over months and pension levels).
  # 'm' must be computed before 'n' is overwritten by its group total.
  dplyr::summarize(m = 11 / 12 * sum(m_dec) / sum(n), n = sum(n),
                   .by = c("year", "sex", "dom")) %>%
  # Extend to the full grid year x sex x domicile from the first observed year
  # up to 'last(PAR$pint)'. The '.' refers to the summarized table piped in;
  # grid rows without observations get NA in 'm' and 'n'.
  right_join(
    tibble(year = rep(min(.$year):last(PAR$pint), each = 4),
           sex  = rep(c("m" ,  "m",  "f",  "f"), length(min(.$year):last(PAR$pint))),
           dom  = rep(c("ch", "au", "ch", "au"), length(min(.$year):last(PAR$pint)))),
             by = c("year", "sex", "dom")
    ) %>%
  # arrange(year) %>% 
  # Attach column 'mp' by year (used as divisor in the next step).
  left_join(mp, by = "year") %>%
  # Express mean monthly pensions as multiples of minimal rent.
  mutate(m = m / mp) %>%
  group_by(sex, dom) %>%
  arrange(sex, dom, year) %>% 
  # First differences within sex and domicile; the first year of each group is
  # set to 0 (not NA) and therefore enters the model below as an observation.
  # NOTE(review): the 'au'/'ch' trend blocks use a leading NA instead; confirm
  # that the leading 0 is intended here.
  mutate(d_n = c(0, diff(n)), d_m = c(0, diff(m))) %>%
  ungroup() %>%
  # Impute death counts and mean pensions linearly and separately by sex and domicile.
  # The intercept-only formula is applied per sex and domicile to the missing
  # differences of the unobserved years.
  impute_lm(d_n + d_m ~ 1 | sex + dom) %>%  
  group_by(sex, dom) %>% 
  # Accumulate the increments of the years after 'first(PAR$pint)' only; earlier
  # years are multiplied by FALSE (= 0) and do not contribute.
  mutate(d_n = cumsum((year > first(PAR$pint)) * d_n),
         d_m = cumsum((year > first(PAR$pint)) * d_m)) %>%
  # Last observed value carried forward plus accumulated increments; 'm' is
  # bounded to [0, 2], 'n' is bounded below by 0.
  mutate(m = pmin(2, pmax(0, na_locf(m) + d_m)), 
         n =         pmax(0, na_locf(n) + d_n)) %>%
  # Product of mean pension (in multiples of 'mp'), death count, and 'mp'.
  mutate(save = m * n * mp) %>%
  ungroup() %>% 
  dplyr::summarize(save_tot = sum(save), .by = c("year", "sex", "dom")) %>%
  filter(year >= first(PAR$pint)) %>%
  # Calculate effective savings relative to yearly payout.
  # '1 - mean(1:11) / 12' evaluates to 0.5; years before 2026 get 0. The fixed
  # 'type' value "alt" restricts later joins on 'type' to that pension type.
  mutate(save_rel = ifelse(year >= 2026, save_tot * (1 - mean(1:11) / 12), 0),
         type = "alt") %>%
  select(year, sex, dom, type, save_rel)

# Assemble the projection data set from the imputed 'au' and 'ch' tables.
# Rows of type 'wit' after 'first(PAR$pint)' are dropped and replaced by the
# scenario 'A' rows of the exogenous table 'wid' (per the original note:
# average pension projections for male widows, based on consolidated STATPOP &
# pension registry data; author: Thomas Borek @Math-BSV).
p_dat <-
  rl %>%
  bind_rows() %>%
  filter(type != "wit" | year <= first(PAR$pint)) %>%
  bind_rows(
    select(
      filter(wid,
             year > first(PAR$pint) & year <= last(PAR$pint) & scen == "A"),
      -scen
    )
  ) %>%
  # Attach the relative savings; combinations without a match get 0.
  left_join(corr, by = c("year", "sex", "dom", "type")) %>%
  tidyr::replace_na(list(save_rel = 0)) %>%
  # Incorporate Liechtenstein effect through an equivalent decrease of
  # surviving retirees' average pension level (type 'alt', from 2026 on).
  mutate(m = ifelse(year >= 2026 & type == "alt",
                    m * (1 + PAR$ahv13 * 1/12) -
                      PAR$ahv13 * save_rel / (n * mp),
                    m)) %>%
  left_join(zas, by = "year") %>%
  ungroup()

# Fit model and produce projections.
# 'f_dat' holds one row per combination of 'year', 'mp', and 'exp_tot' with
# the count-weighted mean 'm' and the total count 'n', restricted to the years
# 'range[4]' up to 'first(PAR$pint)'.
f_dat <- p_dat %>%
  dplyr::summarize(m = weighted.mean(m, n), n = sum(n),
                   .by = c("year", "mp", "exp_tot")) %>%
  filter(year %in% range[4]:first(PAR$pint)) %>%
  # The first differences here and in the model formula below are only
  # meaningful in time order. 'summarize(.by = )' returns the groups in order of
  # first appearance, so sort explicitly instead of relying on the row order of
  # 'p_dat'.
  arrange(year) %>%
  mutate(d_nmmp = c(NA, diff(n * m * mp)))

# Regression without intercept of the yearly change in 'exp_tot' on the yearly
# change in 'n * m * mp'; the single coefficient is used further below to scale
# the projected changes.
fit <-
  lm(I(c(NA, diff(exp_tot))) ~ 0 + d_nmmp, f_dat)

# Process results according to desired aggregation level.
if (PAR$agg) {

  # Aggregated output: one row per year with the count-weighted mean 'm' and
  # the total count 'n'. Set 's' to 1 to avoid later NA's due to AHV 21 cost
  # assignment.
  p_dat <- mutate(
    dplyr::summarize(p_dat, m = weighted.mean(m, n), n = sum(n),
                     .by = c("year", "mp", "exp_tot")),
    s = 1
  )

} else {

  # Disaggregated output: keep 'm' and total 'n' by sex, domicile, and type.
  p_dat <- dplyr::summarize(
    p_dat, m = m, n = sum(n),
    .by = c("year", "mp", "exp_tot", "sex", "dom", "type")
  )

  # Assignment of AHV21 cost to women only according to rent population shares
  # (heuristic): 's' is the within-year share of 'n' among female rows of type
  # 'alt'.
  s21 <- p_dat %>%
    filter(sex == "f" & type == "alt") %>%
    group_by(year) %>%
    mutate(s = n / sum(n)) %>%
    select(year, sex, dom, type, s)

  # All remaining rows receive the share 0.
  p_dat <- p_dat %>%
    left_join(s21, by = c("year", "sex", "dom", "type")) %>%
    tidyr::replace_na(list(s = 0))
}

# Consolidate results: restrict to the years in 'PAR$pint' and rebuild
# 'exp_tot' (renamed to 'ref') for all years after the first one.
# NOTE(review): the cumulative difference is taken over all rows in their
# current order without grouping; with disaggregated data (several rows per
# year) confirm that this is intended.
rtab <- p_dat %>%
  filter(year %in% PAR$pint) %>%
  mutate(
    # Change of 'n * m * mp' relative to the first row.
    d_nmmp  = c(0, cumsum(diff(n * m * mp))),
    # Adjust projections by estimated cost top-up: first 'exp_tot' value plus
    # the change scaled by the fitted coefficient.
    exp_tot = ifelse(year > first(PAR$pint),
                     first(exp_tot) + d_nmmp * coef(fit),
                     exp_tot)
  ) %>%
  select(-d_nmmp) %>%
  dplyr::rename(ref = exp_tot)

# Adjust predictions by the adapted Delfin AHV21 cost projections.
if (PAR$ahv21) {
  rtab <- rtab %>%
    left_join(PAR$ahv21_cost, by = "year") %>%
    # Assignment of reform cost differential via variable 's'.
    mutate(ref = ref + s * cost) %>%
    select(-cost)
}

# Multiply the projections by the yearly factor 'df' taken from 'eck'
# (presumably the deflator for real values; confirm against 'eck').
if (PAR$real) {
  rtab <- rtab %>%
    left_join(eck %>% select(year, df), by = "year") %>%
    mutate(ref = ref * df) %>%
    select(-df)
}

# Express results in millions and round.
rtab %<>%
  mutate(ref = round(ref / 1e6))

# Print and save final projections.
# The output file depends on the aggregation level ('PAR$agg') and on whether
# values are real or nominal ('PAR$real'). 'PAR$real' is a scalar switch, so the
# path is chosen with plain 'if'/'else' rather than the vectorised 'ifelse()',
# which would pass an NA condition on as an NA file name.
if (PAR$write) {
  if (PAR$agg) {

    write_delim(select(rtab, year, ref), delim = ";",
                file = if (PAR$real) "data/proj_base_agg_real.csv" else
                                     "data/proj_base_agg_nomi.csv")

  } else {

    write_delim(select(rtab, year, sex, dom, type, ref), delim = ";",
                file = if (PAR$real) "data/proj_base_disagg_real.csv" else
                                     "data/proj_base_disagg_nomi.csv")

  }
}

if (PAR$bands) {

  # Produce predictive intervals for total AHV expenditures. The sourced script
  # is expected to provide the table 'sim' used below.
  source("scripts/scenarios.R", echo = TRUE)

  # Attach the band limits by year and move them next to the reference value.
  # NOTE(review): 'low' is taken from 'high_C' and 'high' from 'low_B'; confirm
  # that this pairing of bound and scenario is intended.
  rtab <- rtab %>%
    left_join(
      select(pivot_wider(sim, names_from = scen, values_from = low:high),
             year, low = high_C, high = low_B),
      by = "year"
    ) %>%
    relocate(year, low, ref, high)

  # Visualize reference projection with uncertainty bands: ribbon from the wide
  # table, lines and points from the long table of 'low', 'ref', and 'high'.
  rtab %>%
    select(-mp, -m, -n, -s) %>%
    pivot_longer(low:high) %>%
    ggplot() +
    geom_ribbon(data = rtab, aes(x = year, ymin = low, ymax = high),
                alpha = .2) +
    geom_line(aes(x = year, y = value, col = as.factor(name))) +
    geom_shadowpoint(aes(x = year, y = value, col = as.factor(name))) +
    labs(x = NULL,
         y = if (PAR$real) {
           "Millionen Franken (real)\n"
         } else {
           "Millionen Franken (laufende Preise)\n"
         },
         title = "AHV-Ausgabenprojektion mit Unsicherheitsbändern") +
    theme_grey(base_size = 16, base_family = "Garamond") +
    scale_colour_viridis_d(labels = c("ungünstig", "günstig", "mittel"),
                           name = "Szenario", option = "A") +
    scale_y_continuous(labels = scales::comma)
}

# Print results without the helper columns 'mp', 'm', 'n', and 's'.
rtab %>% select(-mp, -m, -n, -s)
  # Disabled post-processing steps, kept for reference:
  # mutate(low = coalesce(low, ref), high = coalesce(high, ref)) %>% 
  # mutate(ref = ifelse(year > first(PAR$pint) + 10, round(ref, -2), ref)) %>% 

