##########################################################################################
# SIMULATE PREDICTIVE BANDS BASED ON NORMAL RESIDUALS AND BFS SCENARIOS                  #
##########################################################################################

# Create random samples from t-distributions for later resampling. ------------------

if (par$sim) {
  
  # Draw a random seed and keep it in 'sim_list' so a run can be reproduced later.
  seed <- sample(10000:100000, 1)
  set.seed(seed)
  # NOTE(review): 'dqrnorm()' draws from dqrng's own generator, which 'set.seed()'
  # does not touch -- seed it explicitly so the stored seed is actually sufficient
  # to reproduce the simulated variates (only 'rchisq()' uses the base RNG).
  dqset.seed(seed)
  N    <- 10000000 
  
  # Fix degrees of freedom for respective regressions, taking into consideration that
  # differencing and estimating one parameter decreases the effective sample size by 2.
  df_nau <- first(par$pint) - range[1] + 1 - 2
  df_mau <- first(par$pint) - range[2] + 1 - 2
  df_mch <- first(par$pint) - range[3] + 1 - 2
  df_tot <- first(par$pint) - range[4] + 1 - 2
  
  # Exploit that Student distributions emerge as centered normal variates with inverse
  # Gamma distributed variances: N(0, 1) / sqrt(chisq_df / df) ~ t_df.
  sim_list <- 
    list(seed = seed,
         sim_nau = dqrnorm(N) / sqrt(rchisq(N, df_nau) / df_nau),
         sim_mau = dqrnorm(N) / sqrt(rchisq(N, df_mau) / df_mau),
         sim_mch = dqrnorm(N) / sqrt(rchisq(N, df_mch) / df_mch),
         sim_tot = dqrnorm(N) / sqrt(rchisq(N, df_tot) / df_tot))
  
  save(sim_list, file = "~/data/appl-wb/20_staff/kjo/misc_data/sim_list.rdata")
}

# Load simulated distributions (written above when 'par$sim' is TRUE).
load("~/data/appl-wb/20_staff/kjo/misc_data/sim_list.rdata")


# Load processed data from 'basismodell'. -------------------------------------------

# Range of trend extrapolation points from crossvalidation.
# NOTE(review): 'range' shadows base::range() for the rest of the script; the
# positional indices range[1]..range[4] are assumed to correspond to
# nau/mau/mch/tot as in the degrees-of-freedom block above -- confirm.
load("data/range.RData")

# Historic "Abschlussrechnungen" (final accounts) for total AHV expenditures,
# truncated at the last historic year 'first(par$pint)'.
zas <- 
  read_delim("data/zas.csv") %>% 
  filter(year <= first(par$pint))

# Foundational analysis data, joined with total expenditures from 'zas' by year.
# Presumably one row per sex/dom/type/scen/year -- verify against 'basismodell'.
a_dat <- 
  read_delim("data/a_dat.csv") %>% 
  arrange(sex, dom, type, year) %>% 
  filter(year <= last(par$pint)) %>% 
  left_join(zas, by = "year")

# Historic and projected minimal pension levels, annualized (12 monthly payments).
mp <- 
  read_delim("data/mpen_VA26001.csv") %>% 
  select(year = jahr, mp = minimalrente) %>% 
  mutate(mp = 12 * mp)

# Estimate pension top-up. ----------------------------------------------------------

# Aggregate historic data per year and first-difference the product n * m * mi,
# restricted to the crossvalidated estimation window for the total series.
t_dat <- 
  a_dat %>% 
  # Constrain to historic data.
  filter(scen == "H") %>% 
  dplyr::summarize(m = weighted.mean(m, n), n = sum(n),
                   .by = c("year", "mi", "exp_tot")) %>% 
  filter(year %in% range[4]:first(par$pint)) %>% 
  mutate(d_nmmp = c(NA, diff(n * m * mi)))

# Regress differenced total expenditures on the differenced aggregate, without
# intercept.
fit_t <-
  lm(I(c(NA, diff(exp_tot))) ~ 0 + d_nmmp, data = t_dat)

# Prediction moments of the top-up regression, one row per prediction period.
mom_t <- 
  tibble(year = par$pint,
         mu_b = coef(fit_t),
         sd_t = glance(fit_t)$sigma)


# Collect prediction moments for explanatory variables. -----------------------------

res_p <- list()

# Loop indices renamed from c/d/t/s: 'c' and 't' shadowed base::c() and base::t(),
# which is error-prone in a body that itself calls c(NA, diff(...)).
for (scn in unique(filter(a_dat, scen != "H")$scen)) {
  for (dm in unique(a_dat$dom)) {
    for (tp in unique(a_dat$type)) {
      for (sx in unique(a_dat$sex)) {
        
        # Constrain extrapolation data according to 'range': the indicator columns
        # 'tn'/'tm' are 1 inside the estimation window and NA outside, so that lm()
        # drops the excluded years.
        data <- 
          filter(a_dat, dom == dm, type == tp, sex == sx, scen %in% c("H", scn)) %>%
            mutate(tn = ifelse(year >= range[1], 1, NA))
        
        if (dm == "au")
          data %<>% 
            mutate(tm = ifelse(year >= range[2], 1, NA))
        
        if (dm == "ch")
          data %<>%
            mutate(tm = ifelse(year >= range[3], 1, NA))
        
        # Set moment collection table (prediction periods only).
        mom <-  
          tibble(year = par$pint[-1],
                 scen = scn, dom = dm, sex = sx, type = tp)
        
        pdat <- 
          filter(data, year > first(par$pint))
        
        # Collect prediction moments for pension counts. Custom function 'psd' returns
        # standard deviation of predictions according to standard formula for simple
        # linear regression without intercept.
        if (dm == "au") {
          fit.n <- 
            lm(I(c(NA, diff(n))) ~ 0 + tn, data)
          
          mom %<>% 
            mutate(mu_n = predict(fit.n, pdat),
                   sd_n = psd(fit.n, pdat$tn))
        } else {
          # Domestic counts are taken deterministically from the scenario data.
          fa_dat <- 
            filter(a_dat, year >= first(par$pint), 
                   scen %in% c("H", scn), sex == sx, type == tp, dom == dm) %>% 
            mutate(n = c(NA, diff(n))) %>% 
            slice(-1)
          
          mom %<>% 
            mutate(mu_n  = fa_dat$n,
                   sd_n = 0)
        }
        
        # Collect prediction moments for pension levels.
        fit.m <- 
          lm(I(c(NA, diff(m))) ~ 0 + tm, data)
        
        mom %<>% 
          # Fixed: the prediction sd of 'fit.m' must be based on the model's own
          # regressor 'tm', not 'tn'. (Both indicators equal 1 for all prediction
          # years, so the previous code worked only by that coincidence.)
          mutate(mu_m = predict(fit.m, pdat),
                 sd_m = psd(fit.m, pdat$tm))
        
        # Save moment table.
        res_p[[paste0(dm, tp, sx, scn)]] <- mom
      }
    }
  }
}

# Collect results and attach last historic period.
res_p[["ref"]] <- 
  a_dat %>% 
  filter(year == first(par$pint)) %>% 
  # Historic data is deterministic, hence zero prediction sd.
  mutate(sd_n = 0, sd_m = 0) %>% 
  dplyr::rename(mu_n = n, mu_m = m) %>% 
  select(year, scen, dom, sex, type, mu_n, sd_n, mu_m, sd_m)

# Stack all moment tables, attach the top-up moments, and keep prediction
# periods only (result stays grouped by year and scenario).
sig_data <- 
  res_p %>% 
  bind_rows() %>% 
  left_join(mom_t, by = "year") %>% 
  group_by(year, scen) %>% 
  filter(year > first(par$pint))

# Simulate response distributions across time. --------------------------------------

# Fix random seed for reproducibility.
set.seed(5483412)
# NOTE(review): 'dqsample()' draws from dqrng's own generator, which 'set.seed()'
# does not affect -- seed it explicitly so the simulation runs are reproducible.
dqset.seed(5483412)

# Set number of simulation runs.
N <- 5000

res_s <- list()

# Loop indices renamed from c/d/t/s: 'c' and 't' shadowed base::c() and base::t().
for (scn in unique(sig_data$scen)) {
  for (dm in unique(sig_data$dom)) {
    for (tp in unique(sig_data$type)) {
      for (sx in unique(sig_data$sex)) {
        for (yr in unique(sig_data$year)) {
          
          # One-row moment table for this cell; 'tibble(param, ...)' below recycles
          # its columns across the N simulated rows.
          param <- 
            filter(sig_data, year == yr, sex == sx, type == tp, dom == dm, scen == scn)
          
          # Abroad ('au'): both pension counts and levels are stochastic.
          if (dm == "au") {
            res_s[[paste0(scn, dm, tp, sx, yr)]] <-
              param %>% 
              tibble(
                sim_n = mu_n + dqsample(sim_list$sim_nau, N) * sd_n,
                sim_m = mu_m + dqsample(sim_list$sim_mau, N) * sd_m) %>% 
              select(year, scen, dom, sex, type, sim_n, sim_m) %>% 
              mutate(run = 1:n())
          }
          
          # Domestic ('ch'): counts are deterministic, only levels are stochastic.
          if (dm == "ch") {
            res_s[[paste0(scn, dm, tp, sx, yr)]] <-
              param %>% 
              tibble(
                sim_n = mu_n,
                sim_m = mu_m + dqsample(sim_list$sim_mch, N) * sd_m) %>% 
              select(year, scen, dom, sex, type, sim_n, sim_m) %>% 
              mutate(run = 1:n())
          }
        }
      }
    }
  }
}

# Combine simulations.
sim <- 
  bind_rows(res_s) %>%
  arrange(year, type, scen) %>% 
  group_by(year, scen)

# Prepare and attach data from last historic period.
# NOTE(review): 2023 is hardcoded here and below -- presumably this should equal
# first(par$pint); confirm and consider replacing the literal.
h_dat <- 
  filter(a_dat, year == 2023) %>% 
  select(year, dom, sex, type) %>% 
  mutate(sim_n = 0, sim_m = 0)

# Replicate the historic anchor row set once per simulation run.
h_dat %<>%
  bind_rows(
     unique(sim$run) %>% map(\(x) mutate(h_dat, run = x)))

# Replicate across scenarios.
# NOTE(review): scenario labels "A"/"B"/"C" are hardcoded -- verify they match
# unique(sig_data$scen).
h_dat %<>%
  bind_rows(
    c("A", "B", "C") %>% map(\(x) mutate(h_dat, scen = x))) %>% 
  na.omit() %>% distinct() %>% 
  relocate(year, scen, dom, sex, type)

# Consolidate data, derive perturbed predictions across simulation runs, and integrate
# 13th AHV pension payment.
sim %<>%
  bind_rows(h_dat) %>% 
  left_join(filter(a_dat, year == 2023) %>% 
              select(dom, sex, type, n, m, mp_0 = mi), by = c("dom", "sex", "type")) %>% 
  left_join(mp, by = "year") %>% 
  arrange(scen, dom, sex, type, run, year) %>% 
  group_by(scen, run, dom, sex, type) %>% 
  # Accumulate simulated first differences onto the 2023 anchor levels; pmax()
  # guards against negative counts/levels.
  mutate(sim_n = pmax(0, n + cumsum(sim_n)), 
         sim_m = pmax(0, m + cumsum(sim_m))) %>% 
  # 13th AHV pension payment: old-age pensions from 2026 on are scaled up by one
  # additional monthly payment, weighted by 'par$ahv13'.
  mutate(sim_m = ifelse(type == "alt" & year >= 2026, 
                        sim_m * (1 + par$ahv13 * 1/12), sim_m),
         p_nmmp = sim_n * sim_m * mp, 
         d_nmmp = c(0, diff(p_nmmp))) %>% 
  ungroup() %>% 
  # Map variables onto explanatory variable from 'fit_t', separately for each simulation 
  # run and scenario.
  dplyr::summarize(d_nmmp = sum(d_nmmp), .by = c("run", "year", "scen")) %>% 
  mutate(p_exp = filter(t_dat, year == 2023)$exp_tot) %>% 
  left_join(mom_t, by = "year", relationship = "many-to-one") %>% 
  group_by(run, scen) %>% 
  # Simulate final predictions: regression mean plus resampled t-distributed noise,
  # scaled by the prediction-sd formula for a no-intercept simple regression.
  mutate(d_nmmp = mu_b * d_nmmp + dqsample(sim_list$sim_tot, n()) *
                  sd_t *
                  sqrt(1 + d_nmmp^2 / sum(model.frame(fit_t)[, 2]^2))) %>% 
  mutate(d_nmmp = ifelse(year == 2023, 0, d_nmmp),
         p_exp  = (p_exp + cumsum(d_nmmp)) / 1e6) %>% 
  group_by(scen, year) %>% 
  # Calculate 99% prediction bands (0.5% and 99.5% quantiles across runs).
  mutate(low = quantile(p_exp, .005), high = quantile(p_exp, .995)) %>% 
  select(year, scen, low, high) %>% 
  distinct()

# Integrate Liechtenstein effect and Delfin AHV21 cost vector.
sim %<>%
  left_join(corr %>% 
              select(year, save_rel) %>% 
              dplyr::summarize(save_rel = sum(save_rel), .by = "year"), 
            by = "year") %>% 
  # Subtract the summed annual Liechtenstein savings, converted to millions.
  mutate(across(c(low, high), \(x) x - save_rel / 1e6)) %>% 
  select(-save_rel) %>% 
  left_join(par$ahv21_cost, by = "year") %>% 
  # Add the AHV21 cost vector, converted to millions.
  mutate(across(c(low, high), \(x) x + cost / 1e6)) %>% 
  select(-cost)

# Plot the uncertainty bands (optional). Braced and wrapped in print() so the
# plot also renders when the script is executed via source(), where bare
# top-level ggplot values inside if() are not auto-printed.
if (par$band_plot) {
  print(
    ggplot(sim, aes(x = year, ymin = low, ymax = high, col = scen, fill = scen)) +
      theme_minimal(base_size = 14) +
      geom_ribbon(alpha = .25) +
      # geom_shadowline(alpha = .5, data = filter(sig_data, year <= 2040, scen == "A")) +
      labs(title = "Unsicherheitsquantifizierung der AHV-Ausgabenprojektionen",
           x = NULL, y = "Mil. Franken (laufende Preise)") +
      theme(legend.position = "top") +
      guides(fill = "none", color = "none") +
      scale_x_continuous(breaks = 2023:2040)
  )
}
