##########################################################
# IMPUTE FUTURE FOREIGN POPULATION COUNTS AND MEAN RENTS #
################################################################################

# Switch off function masking by the 'collapse' package to avoid name
# collisions with the verbs used below.
set_collapse(mask = NULL)

# Restrict the ZAS Betriebsrechnung series of total AHV expenditures ('zas') to
# years up to and including the first element of 'par$pint'.
zas <- zas %>%
  filter(year <= first(par$pint))

# Load the cross-validation (CV) results. The code below indexes an object
# named 'range', which is presumably provided by this file.
load("data/range.RData")

# Container collecting the imputation results, one element per domicile.
rl <- list()

# Domicile "au": imputation by rent type and sex, using the historical rows
# ('H') and the BFS reference scenario 'A'.
rl$au <-
  a_dat %>%
  filter(scen %in% c("H", "A") & dom == "au") %>%
  # Regressors for the fits below: 1 from year 'range[1]' (counts) resp.
  # 'range[2]' (mean rents) onwards, NA before.
  mutate(
    tn = ifelse(year >= range[1], 1, NA),
    tm = ifelse(year >= range[2], 1, NA)
  ) %>%
  arrange(sex, type, year) %>%
  group_by(sex, type) %>%
  # Year-on-year changes of counts and mean rents within each sex/type cell.
  mutate(
    dn = c(NA, diff(n)),
    dm = c(NA, diff(m))
  ) %>%
  ungroup() %>%
  # Extrapolate rent counts and mean rents with a linear time trend by imputing
  # the missing changes; fitted separately for each combination of sex and rent
  # type (the domicile is fixed by the filter above).
  impute_lm(dn ~ 0 + tn | sex + type) %>%
  impute_lm(dm ~ 0 + tm | sex + type) %>%
  group_by(sex, type) %>%
  mutate(
    # Changes only take effect after the first year of 'par$pint'.
    dn = ifelse(year > first(par$pint), dn, 0),
    dm = ifelse(year > first(par$pint), dm, 0),
    # Carry the last observed level forward, add the cumulated changes, and
    # floor at zero to prevent negative predictions from the extrapolation.
    m = pmax(0, na_locf(m) + cumsum(dm)),
    n = pmax(0, na_locf(n) + cumsum(dn))
  ) %>%
  # Drop the helper columns and the scenario label.
  select(-c(tn, tm, dn, dm, scen))

# Domicile "ch": same data selection as for "au" (historical rows 'H' plus BFS
# reference scenario 'A').
rl$ch <-
  a_dat %>%
  filter(scen %in% c("H", "A") & dom == "ch") %>%
  mutate(
    # Regressor for the mean-rent fit: 1 from year 'range[3]' onwards, NA before.
    tm = ifelse(year >= range[3], 1, NA),
    # "Ehegattenrenten" have not been awarded since January 1997: blank out
    # their counts after the first year of 'par$pint' so that they are
    # re-derived from the linear fit below.
    n = ifelse(type == "ehe" & year > first(par$pint), NA, n),
    # Time index (years since the first year in the selection) for the count fit.
    tn = year - min(year)
  ) %>%
  arrange(sex, type, year) %>%
  group_by(sex, type) %>%
  # Year-on-year changes of counts and mean rents within each sex/type cell.
  mutate(
    dn = c(NA, diff(n)),
    dm = c(NA, diff(m))
  ) %>%
  ungroup() %>%
  # Extrapolate rent counts and mean rents with a linear time trend by imputing
  # the missing changes; fitted separately for each combination of sex and rent
  # type (the domicile is fixed by the filter above).
  impute_lm(dm ~ 0 + tm | sex + type) %>%
  impute_lm(dn ~ 0 + tn | sex + type) %>%
  group_by(sex, type) %>%
  mutate(
    # Cumulated changes after the first year of 'par$pint'. For counts, only
    # type "ehe" receives a non-zero adjustment.
    dn = cumsum(ifelse(year > first(par$pint) & type == "ehe", dn, 0)),
    dm = cumsum(ifelse(year > first(par$pint), dm, 0)),
    # Last observed level carried forward plus cumulated change, floored at zero.
    m = pmax(0, na_locf(m) + dm),
    n = pmax(0, na_locf(n) + dn)
  ) %>%
  # Drop the helper columns and the scenario label.
  select(-c(tn, tm, dn, dm, scen))

# Savings correction for the 13th AHV pension payment ("Liechtenstein model").
# Imports historical within-year deaths of retirees and their December pension
# sums (source: RR) from the file given by 'par$in_lim', extends them to the
# end of 'par$pint', and derives the savings 'save_rel' per year, sex and
# domicile. The factor 'mean(1:11) / 12' in 'save_rel' approximates the ratio
# of savings due to within-year pensioner deaths under monthly versus yearly
# payout of the 13th AHV pension payment.
corr <-
  read_delim(par$in_lim) %>%
  select(
    year  = an_rr,
    sex   = csex,
    dom   = recoded_cdom,
    n     = anzahl,
    m_dec = rentensumme_dez
  ) %>%
  mutate(
    # Shift the registry year forward by one.
    year = year + 1,
    sex  = recode(sex, `1` = "m", `2` = "f"),
    dom  = recode(dom, `100` = "ch", `900` = "au")
  ) %>%
  # Multiplication by '11/12' corrects for death count in the preceding December
  # (implicit assumption: uniform death rate over months and pension levels).
  # 'm' must be computed before 'n' is overwritten by its sum.
  dplyr::summarize(
    m = 11 / 12 * sum(m_dec) / sum(n),
    n = sum(n),
    .by = c("year", "sex", "dom")
  ) %>%
  # Extend to a complete year x sex x domicile grid running up to the last
  # year of 'par$pint'; added cells carry NA and are imputed below.
  right_join(
    tibble(
      year = rep(min(.$year):last(par$pint), each = 4),
      sex  = rep(c("m", "m", "f", "f"), length.out = length(year)),
      dom  = rep(c("ch", "au", "ch", "au"), length.out = length(year))
    ),
    by = c("year", "sex", "dom")
  ) %>%
  arrange(year) %>%
  left_join(mi, by = "year") %>%
  # Express mean pensions as multiples of minimal rent.
  mutate(m = 12 * m / mi) %>%
  arrange(sex, dom, year) %>%
  group_by(sex, dom) %>%
  # Year-on-year changes within each sex/domicile cell.
  mutate(
    dn = c(NA, diff(n)),
    dm = c(NA, diff(m))
  ) %>%
  ungroup() %>%
  # Impute death counts and mean pensions linearly and separately by sex and
  # domicile.
  impute_lm(dm ~ 1 | sex + dom) %>%
  impute_lm(dn ~ 1 | sex + dom) %>%
  group_by(sex, dom) %>%
  mutate(
    # Cumulated changes after the first year of 'par$pint'.
    dn = cumsum(ifelse(year > first(par$pint), dn, 0)),
    dm = cumsum(ifelse(year > first(par$pint), dm, 0)),
    # Last observed level carried forward plus cumulated change, floored at zero.
    m = pmax(0, na_locf(m) + dm),
    n = pmax(0, na_locf(n) + dn),
    # Back from multiples of the minimal rent to an amount.
    save = m * n * mi
  ) %>%
  ungroup() %>%
  dplyr::summarize(save_tot = sum(save), .by = c("year", "sex", "dom")) %>%
  filter(year > first(par$pint)) %>%
  # Calculate effective savings relative to yearly payout; zero before 2026.
  mutate(
    save_rel = ifelse(year >= 2026, save_tot * (1 - mean(1:11) / 12), 0),
    type = "alt"
  ) %>%
  select(year, sex, dom, type, save_rel)

# Persist the correction table.
save(corr, file = "data/li_corr.rdata")

# Assemble the projection data set 'p_dat'. The core pipeline is the same with
# and without 'par$wid'; the widower replacement only adds steps around it.

if (par$wid) {
  # Exogenous widower series based on consolidated STATPOP & rent registry data
  # (author: Thomas Borek): attach the minimal rent and cut at the last year of
  # 'par$pint'.
  wid <- wid %>%
    left_join(mi, by = "year") %>%
    filter(year <= last(par$pint))
}

p_dat <-
  bind_rows(rl) %>%
  left_join(corr, by = c("year", "sex", "dom", "type")) %>%
  # Incorporate Liechtenstein effect through an equivalent decrease of surviving
  # retirees' average pension level (type "alt", from 2026 onwards).
  mutate(
    m = ifelse(
      type == "alt" & year >= 2026,
      m * (1 + par$ahv13 * 1/12) - par$ahv13 * save_rel / (n * mi),
      m
    )
  )

if (par$wid) {
  # Replace the projected rows (years after the first year of 'par$pint') for
  # type "wit", domicile "au", sex "m" by the exogenous series.
  p_dat <- p_dat %>%
    filter(!(type == "wit" & dom == "au" & sex == "m" & year > first(par$pint))) %>%
    bind_rows(filter(wid, dom == "au", sex == "m"))
}

# Attach the 'zas' series by year and drop any grouping.
p_dat <- p_dat %>%
  left_join(zas, by = "year") %>%
  ungroup()

# Fit model and produce projections.
#
# 'fdat' holds one row per year with the count-weighted mean rent 'm', the total
# count 'n', and 'd_nmmi', the cumulated change of 'n * m * mi' relative to the
# first year of the fitting window 'range[4]:first(par$pint)'.
fdat <- 
  p_dat %>%
  dplyr::summarize(m = weighted.mean(m, n), n = sum(n),
                   .by = c("year", "mi", "exp_tot")) %>%
  filter(year %in% range[4]:first(par$pint)) %>% 
  # 'summarize(.by = ...)' returns groups in order of first appearance, which
  # depends on how 'p_dat' was bound together. The cumulated differences below
  # (and the ones on 'exp_tot' in the model formula) are only meaningful on
  # chronologically ordered rows, so sort explicitly.
  arrange(year) %>% 
  mutate(d_nmmi = c(NA, cumsum(diff(n * m * mi))))

# No-intercept regression of the cumulated change in total expenditures on the
# cumulated change in 'n * m * mi'; the first row is NA on both sides and is
# therefore dropped by 'lm'.
fit <-
  lm(I(c(NA, cumsum(diff(exp_tot)))) ~ 0 + d_nmmi, fdat)

# Bring 'p_dat' to the aggregation level requested via 'par$agg' and attach the
# share 's' used later to assign the AHV 21 cost.
if (par$agg) {

  # One row per year: count-weighted mean rent and total count.
  p_dat <- p_dat %>%
    dplyr::summarize(
      m = weighted.mean(m, n),
      n = sum(n),
      .by = c("year", "mi", "exp_tot")
    ) %>%
    # Set 's' to 1 to avoid later NA's due to AHV 21 cost assignment.
    mutate(s = 1)

} else {

  # Keep the sex / domicile / rent-type resolution.
  p_dat <- p_dat %>%
    dplyr::summarize(
      m = m,
      n = sum(n),
      .by = c("year", "mi", "exp_tot", "sex", "dom", "type")
    )

  # Assignment of AHV21 cost to women only according to rent population shares
  # (heuristic): within each year, the share of every female "alt" row in the
  # total female "alt" count.
  s21 <-
    p_dat %>%
    filter(sex == "f" & type == "alt") %>%
    group_by(year) %>%
    mutate(s = n / sum(n)) %>%
    select(year, sex, dom, type, s)

  # All remaining rows receive a share of zero.
  p_dat <- p_dat %>%
    left_join(s21, by = c("year", "sex", "dom", "type")) %>%
    tidyr::replace_na(list(s = 0))
}

# Consolidate results: restrict to the years in 'par$pint' and turn the
# observed expenditure of the first row into a projected path.
# NOTE(review): the cumulated difference below is taken over all rows in their
# current order, without grouping. When 'p_dat' is disaggregated (several rows
# per year) it therefore spans different sex/dom/type cells - confirm that this
# is intended.
rtab <-
  p_dat %>%
  filter(year %in% par$pint) %>%
  mutate(
    # Cumulated change of 'n * m * mi' relative to the first row.
    d_nmmi = c(0, cumsum(diff(n * m * mi))),
    # Adjust projections by estimated cost top-up: after the first year of
    # 'par$pint', expenditure is the first row's value plus the cumulated
    # change scaled by the fitted coefficient.
    exp_tot = ifelse(
      year > first(par$pint),
      first(exp_tot) + d_nmmi * coef(fit),
      exp_tot
    )
  ) %>%
  left_join(inf, by = "year") %>%
  dplyr::rename(exp_p = exp_tot) %>%
  select(-d_nmmi)

# Adjust predictions by the adapted Delfin AHV21 cost projections. The reform
# cost differential is assigned via variable 's'; years without a cost entry
# contribute zero.
if (par$ahv21) {
  rtab <- rtab %>%
    left_join(par$ahv21_cost, by = "year") %>%
    tidyr::replace_na(list(cost = 0)) %>%
    mutate(exp_p = exp_p + s * cost) %>%
    select(-cost)
}

# Fix prices at the latest observed year by scaling with 'df' (presumably the
# deflator joined in from 'inf' - verify).
if (par$real) {
  rtab <- rtab %>%
    mutate(exp_p = exp_p * df)
}

# Convert results to millions and round.
rtab <- rtab %>%
  mutate(exp_p = round(exp_p / 1e6))

# Save final projections as ';'-delimited files when 'par$write' is set. The
# target file depends on the aggregation level ('par$agg') and on whether
# results are in real or nominal terms ('par$real'). 'par$real' is a scalar
# switch, so the file name is chosen with plain 'if'/'else' rather than the
# vectorised 'ifelse()'.
if (par$write) {
  if (par$agg) {
    # NOTE(review): the nominal aggregate is written to a user-specific path
    # outside 'data/', unlike the other three outputs - confirm this is still
    # intended.
    fn <- if (par$real) {
      "data/proj_base_agg_real.csv"
    } else {
      "~/data/appl-wb/20_staff/kjo/check_fhh_19032025/proj_base_agg_nomi.csv"
    }
    write_delim(select(rtab, year, exp_p), file = fn, delim = ";")
  } else {
    fn <- if (par$real) {
      "data/proj_base_disagg_real.csv"
    } else {
      "data/proj_base_disagg_nomi.csv"
    }
    write_delim(select(rtab, year, sex, dom, type, exp_p), file = fn, delim = ";")
  }
}

if (par$pband) {

  # Produce predictive intervals for total AHV expenditures. The sourced script
  # is expected to provide 'sim' (it is used right below).
  source("scripts/scenarios.R")

  # Spread scenarios "B" and "C" into columns and keep the lower bound of
  # scenario "C" and the upper bound of scenario "B" as the band.
  sim <- sim %>%
    filter(scen %in% c("B", "C")) %>%
    pivot_wider(names_from = scen, values_from = low:high) %>%
    select(year, low = low_C, high = high_B)

  # Attach the band to the projections by year.
  rtab <- left_join(rtab, sim, by = "year")
}
  
# Return the projections without the auxiliary columns when 'par$show' is set.
if (par$show) {
  rtab %>%
    select(-c(mi, m, n, s, df))
}
