#########################################################################
# SIMULATE PREDICTIVE BANDS BASED ON NORMAL RESIDUALS AND BFS SCENARIOS #
##########################################################################################

# Setup. ----------------------------------------------------------------------------

# Start years of the trend extrapolation windows from cross-validation (see
# 'cv_out.R'). Loaded first because the degrees of freedom below are derived from it;
# the loaded object 'range' masks base::range() in this script.
load("data/range.RData")

# Fix degrees of freedom for predictive band simulations: the number of historic years
# from each window start up to the last historic period,
# first(par$pint) - range[i] + 1, reduced by 2.
#   range[1]: pension counts, dom "au"   range[2]: pension levels, dom "au"
#   range[3]: pension levels, dom "ch"   range[4]: total-expenditure top-up
# NOTE(review): 'fit_t' is estimated on all of 'f_dat'; confirm that 'f_dat' is
# restricted to the range[4] window so that df_tot matches df.residual(fit_t).
df_nau <- first(par$pint) - range[1] + 1 - 2
df_mau <- first(par$pint) - range[2] + 1 - 2
df_mch <- first(par$pint) - range[3] + 1 - 2
df_tot <- first(par$pint) - range[4] + 1 - 2

# Foundational analysis data (see 'prepare_inputs.R'), sorted by sex, dom, type and
# year and cut at the last projection year.
a_dat <- 
  read_delim("data/a_dat.csv") %>% 
  arrange(sex, dom, type, year) %>% 
  filter(year <= last(par$pint)) %>% 
  left_join(zas, by = "year")

# Estimate pension top-up. ----------------------------------------------------------

# No-intercept regression of the yearly change in total expenditure on 'd_nmmp'.
fit_t <- lm(formula = c(NA, diff(exp_tot)) ~ d_nmmp - 1, data = f_dat)

# Top-up moments per projection year: the single slope ('mu_t') and the residual
# standard error ('sd_t') are recycled to every year in 'par$pint'.
mom_t <- tibble(
  year = par$pint,
  mu_t = coef(fit_t),
  sd_t = glance(fit_t)$sigma
)

# Collect prediction moments for explanatory variables. -----------------------------
#
# For every projection scenario (all 'scen' values except the historic "H"), every
# 'dom', every pension 'type' except widow pensions ("wit") and every 'sex', build a
# table 'mom' with one row per projection year (par$pint without its first element).
# Each row holds the mean ('mu_*') and standard deviation ('sd_*') of the yearly
# changes in pension counts ('n') and pension levels ('m'). Tables are stored in
# 'res_p' under the key paste0(dom, type, sex, scen).
#
# 'c' and 't' are loop variables here; calls such as c("H", c) still reach base::c()
# because R skips non-function objects when looking up a function name.

res_p <- list()

for (c in unique(filter(a_dat, !(scen %in% c("H")))$scen)) {
  for (d in unique(a_dat$dom)) {
    for (t in unique(filter(a_dat, type != "wit")$type)) {
      for (s in unique(a_dat$sex)) {
        
        # Constrain extrapolation data according to 'range': the trend indicators
        # 't_n' (counts) and 't_m' (levels) are 1 from the respective window start
        # onwards and NA before, so lm() drops the earlier years (default na.action).
        # NOTE(review): future rows of scenario 'c' also get indicator 1; this relies
        # on their 'n'/'m' being NA where they are to be extrapolated - confirm.
        data <- a_dat %>% 
          filter(dom == d, type == t, sex == s, scen %in% c("H", c)) %>%
            mutate(t_n = ifelse(year >= range[1], 1, NA))
        
        # Level windows differ by 'dom': range[2] for "au", range[3] for "ch".
        # NOTE(review): 't_m' is only created for these two values; any other 'dom'
        # would make 'fit_m' below fail.
        if (d == "au")
          data %<>% 
            mutate(t_m = ifelse(year >= range[2], 1, NA))
        
        if (d == "ch")
          data %<>%
            mutate(t_m = ifelse(year >= range[3], 1, NA))
        
        # Set moment collection table (one row per projection year).
        mom <-  
          tibble(year = par$pint[- 1], scen = c, dom = d, sex = s, type = t)
        
        # Projection-period rows used as prediction inputs.
        p_dat <- 
          filter(data, year > first(par$pint))
        
        # Collect prediction moments for pension counts. Custom function 'psd' returns
        # standard deviation of predictions according to standard formula for simple
        # linear regression without intercept.
        # For "au", the yearly change is regressed on the 1/NA indicator without
        # intercept, which estimates the mean yearly change within the window; that
        # mean is the prediction for every projection year. For other 'dom' values,
        # counts are taken from the scenario data as given and treated as
        # deterministic (sd_n = 0).
        if (d == "au") {
          
          fit_n <- 
            lm(c(NA, diff(n)) ~ 0 + t_n, data)
          
          mom %<>% 
            mutate(mu_n = predict(fit_n, p_dat), sd_n = psd(fit_n, p_dat$t_n))
          
        } else {
          
          # Scenario counts from the last historic year onwards; differencing and
          # dropping the first row leaves one yearly change per projection year.
          fa_dat <- 
            filter(a_dat, year >= first(par$pint), 
                   scen %in% c("H", c), sex == s, type == t, dom == d) %>% 
            mutate(n = c(NA, diff(n))) %>% 
            slice(- 1)
          
          mom %<>% 
            mutate(mu_n  = fa_dat$n, sd_n = 0)
          
        }
        
        # Collect prediction moments for pension levels (window-mean trend of the
        # yearly change, for every 'dom').
        fit_m <- 
          lm(c(NA, diff(m)) ~ 0 + t_m, data)
        
        mom %<>% 
          mutate(mu_m = predict(fit_m, p_dat), sd_m = psd(fit_m, p_dat$t_m))
        
        # Save moment table.
        res_p[[paste0(d, t, s, c)]] <- mom
        
      }
    }
  }
}

# Collect results and attach last historic period. Historic data is deterministic, so
# both standard deviations are zero; counts and levels are renamed to the moment
# column names used by the projection tables.
res_p[["ref"]] <- a_dat %>%
  filter(year == first(par$pint)) %>%
  select(scen, year, dom, sex, type, mu_n = n, mu_m = m) %>%
  mutate(sd_n = 0, sd_m = 0) %>%
  relocate(sd_n, .after = mu_n)

# Simulate response distributions across time. --------------------------------------
#
# For each scenario and every dom/type/sex/year cell, draw N scaled Student-t values
# around the predicted yearly changes: mu + rt(N, df) * sd. For "au", both counts
# ('sim_n', df_nau) and levels ('sim_m', df_mau) are perturbed; for "ch", counts are
# taken as given (sim_n = mu_n) and only levels are perturbed (df_mch). Each draw is
# indexed by 'run'; results are stored in 'res_s' under paste0(scen, dom, type, sex,
# year).
# NOTE(review): the scenario list is hard-coded here, whereas 'res_p' was built from
# all non-"H" scenarios in 'a_dat' - keep the two in sync.

#  Set number of simulation runs (fixed by experimentation).
N <- 50000

res_s <- list()

for (c in c("A", "B", "C")) {
  
  # Fix random seed for reproducibility. The seed is reset to the same value for
  # every scenario, so each scenario's draws do not depend on the others.
  set.seed(95827323)
  
  # Projection-year moments of this scenario (historic reference rows excluded).
  sig_data <- 
    bind_rows(res_p) %>% 
    left_join(mom_t, by = "year") %>% 
    filter(year > first(par$pint), scen == c) %>% 
    select(- scen)
  
  # Cells are visited in sorted order so the random stream maps to cells
  # deterministically.
  for (d in sort(unique(sig_data$dom))) {
    for (t in sort(unique(sig_data$type))) {
      for (s in sort(unique(sig_data$sex))) {
        for (y in sort(unique(sig_data$year))) {
          
          # Moments of one cell; presumably exactly one row - confirm.
          param <- sig_data %>% 
            filter(year == y, sex == s, type == t, dom == d)
          
          # 'param' is spliced as the first (unnamed) argument of tibble(), so its
          # columns are visible to the following expressions and its row is recycled
          # to N rows. 'sim_n' is drawn before 'sim_m'.
          if (d == "au") {
            
            res_s[[paste0(c, d, t, s, y)]] <-
              param %>% 
              tibble(
                sim_n = mu_n + rt(N, df_nau) * sd_n,
                sim_m = mu_m + rt(N, df_mau) * sd_m
                ) %>% 
              # select(year, scen, dom, sex, type, sim_n, sim_m) %>% 
              select(year, dom, sex, type, sim_n, sim_m) %>% 
              mutate(scen = c, run = 1:n())
          }
          
          # "ch": counts are deterministic, only levels are perturbed.
          if (d == "ch") {
            
            res_s[[paste0(c, d, t, s, y)]] <-
              param %>% 
              tibble(
                sim_n = mu_n,
                sim_m = mu_m + rt(N, df_mch) * sd_m
                ) %>% 
              # select(year, scen, dom, sex, type, sim_n, sim_m) %>% 
              select(year, dom, sex, type, sim_n, sim_m) %>% 
              mutate(scen = c, run = 1:n())
          }
          
        }
      }
    }
  }
}

# Combine simulation runs into one table, ordered by year, type and scenario and
# grouped by year and scenario.
sim <- group_by(
  arrange(bind_rows(res_s), year, type, scen),
  year, scen
)

# Prepare data from last historic period: one zero-increment row per simulation run,
# scenario and (year, dom, sex, type) cell, so that the cumulative sums over simulated
# changes start from the observed levels.
h_dat <- a_dat %>%
  filter(year == first(par$pint)) %>%
  mutate(sim_n = 0, sim_m = 0) %>%
  select(year, dom, sex, type, sim_n, sim_m) %>%
  expand(run = unique(sim$run), scen = c("A", "B", "C"),
         nesting(year, dom, sex, type, sim_n, sim_m)) %>%
  relocate(year, scen, dom, sex, type)

# Consolidate data, derive perturbed predictions across simulation runs, and integrate
# 13th AHV pension payment.
#
# Per scenario, run and (dom, sex, type) cell, the simulated yearly changes are
# accumulated onto the observed levels of the last historic period ('n', 'm'),
# bounded, and turned into yearly changes of n * m * mp ('d_nmmp'). These are then
# summed per scenario, run and year.
sim %<>%
  # Attach historic data.
  bind_rows(h_dat) %>% 
  left_join(a_dat %>% 
              filter(year == first(par$pint)) %>% 
              select(dom, sex, type, n, m), 
            by = c("dom", "sex", "type")) %>% 
  left_join(mp, by = "year") %>% 
  group_by(scen, run, dom, sex, type) %>%
  # Sort by year within each cell so the cumulative sums run forward in time
  # (dplyr's argument is '.by_group', with the leading dot).
  arrange(year, .by_group = TRUE) %>%
  # Enforce legal/logical bounds on projections: counts >= 0, levels in [0, 2].
  mutate(
    sim_n =         pmax(0, n + cumsum(sim_n)), 
    sim_m = pmin(2, pmax(0, m + cumsum(sim_m)))
  ) %>% 
  # 13th AHV pension payment: from 2026 on, levels of type "alt" are raised by
  # par$ahv13 twelfths.
  mutate(sim_m = 
           ifelse(type == "alt" & year >= 2026, sim_m * (1 + par$ahv13 * 1/12), sim_m),
         d_nmmp = c(0, diff(sim_n * sim_m * mp))) %>% 
  ungroup() %>% 
  # Map variables onto explanatory variable from 'fit_t', separately for each simulation 
  # run per scenario.
  dplyr::summarize(d_nmmp = sum(d_nmmp), .by = c("scen", "run", "year")) %>% 
  # Observed total expenditure of the last historic period as starting point.
  mutate(p_exp = filter(f_dat, year == first(par$pint)) %>% pull(exp_tot)) %>% 
  left_join(mom_t, by = "year", relationship = "many-to-one") %>% 
  group_by(scen, run) %>%
  arrange(year, .by_group = TRUE)

# Simulate final predictions incorporating uncertainty due to the rent sum top-up
# estimation. Per scenario and run, the explanatory change 'd_nmmp' is mapped through
# 'fit_t' and perturbed with a scaled Student-t draw using the prediction standard
# error of a no-intercept regression, sd_t * sqrt(1 + x0^2 / sum(x^2)). The perturbed
# changes are accumulated onto the last observed expenditure 'p_exp'.

# Sum of squared regressor values of 'fit_t' (column 2 of its model frame is
# 'd_nmmp'). Loop invariant, so computed once instead of once per run group.
ssx_t <- sum(model.frame(fit_t)[, 2]^2)

res_f <- list()

for (c in c("A", "B", "C")) {
  
  set.seed(95827323)
  
  temp <- 
    filter(sim, scen == c) %>% 
    group_by(run) %>% 
    select(- scen)
  
  res_f[[c]] <- temp %>% 
    mutate(d_nmmp = mu_t * d_nmmp + rt(n(), df_tot) *
             sd_t *
             sqrt(1 + d_nmmp^2 / ssx_t), 
           # The last historic period is observed: its change is excluded.
           p_exp  = p_exp + cumsum((year != first(par$pint)) * d_nmmp)) %>% 
    group_by(year) %>% 
    # Two-sided 90% predictive band across simulation runs: 5% quantile as lower
    # and 95% quantile as upper bound, mean as central estimate.
    mutate(low  = quantile(p_exp, .05),
           high = quantile(p_exp, .95),
           mid  = mean(p_exp)) %>%
    select(year, low, mid, high) %>% 
    distinct() %>% 
    mutate(scen = c)
}

# Collect final simulation results.
sim <- bind_rows(res_f)

# Integrate Liechtenstein effect, AHV21 cost vector and complementary widow projections. 
# Only the widow projections are adapted to the demographic scenarios.
#
# 'd_wid': cumulative change in widow pension expenditure (n * m * mp of type "wit")
# relative to the last historic period (first(par$pint), as everywhere else in this
# script), summed over sex and dom per year.
# NOTE(review): the aggregation is by year only and the result is joined to 'sim' by
# year; if 'wid' carries a scenario column, scenarios would be pooled here - confirm.
d_wid <- wid %>% 
  bind_rows(a_dat %>%
              select(year, sex, dom, type, m, n, mp) %>%
              filter(year == first(par$pint), type == "wit")) %>%
  group_by(sex, dom) %>% 
  arrange(year, .by_group = TRUE) %>% 
  mutate(d_wid = c(0, diff(n * m * mp))) %>% 
  ungroup() %>% 
  dplyr::summarize(d_wid = sum(d_wid), .by = "year") %>% 
  # summarize(.by) keeps order of first appearance; sort before accumulating.
  arrange(year) %>% 
  mutate(d_wid = cumsum(d_wid))

# Shift all three band statistics by the same yearly amount: the AHV21 cost (scaled by
# par$ahv21) plus the cumulative widow effect 'd_wid', minus the yearly 'save' summed
# from 'corr'.
sim %<>%
  left_join(
    corr %>%
      select(year, save_rel) %>%
      dplyr::summarize(save = sum(save_rel), .by = "year"),
    by = "year"
  ) %>%
  left_join(d_wid, by = "year") %>%
  left_join(par$ahv21_cost, by = "year") %>%
  mutate(across(c(low, mid, high), ~ .x + par$ahv21 * cost + d_wid - save)) %>%
  select(-c(cost, save, d_wid))

# Deflate results, if desired: all three band statistics (including the central
# estimate 'mid') are multiplied by the year-specific factor 'df' from 'eck', so the
# band stays internally consistent.
if (par$real) {
  sim %<>%
    left_join(eck, by = "year") %>%
    mutate(low = low * df, mid = mid * df, high = high * df) %>%
    select(- df)
}

# Express results in millions and round to whole millions.
sim %<>%
  mutate(across(c(low, mid, high), ~ round(.x / 1e6)))

