#########################################################################
# SIMULATE PREDICTIVE BANDS BASED ON NORMAL RESIDUALS AND BFS SCENARIOS #
#########################################################################

# Create random samples from t-distributions for later resampling (faster). --------------

# Base seed for all random number streams in this script.
seed <- 95827323

if (par$sim) {
  
  # Set seed for reproducibility (dqset.seed seeds the dqrng generators used below;
  # presumably from the 'dqrng' package -- confirm).
  set.seed(seed); dqset.seed(seed)
  # Pool size of pre-simulated draws per distribution.
  # NOTE(review): 1e9 doubles are ~8 GB per vector, ~32 GB for all four pools --
  # confirm this is intended and fits the target machine.
  N_sim <- 1000000000
  
  # Fix degrees of freedom for respective regressions, taking into consideration that
  # differencing and estimating one parameter decreases the effective sample size by 2.
  df_nau <- first(par$pint) - range[1] + 1 - 2
  df_mau <- first(par$pint) - range[2] + 1 - 2
  df_mch <- first(par$pint) - range[3] + 1 - 2
  df_tot <- first(par$pint) - range[4] + 1 - 2
  
  # Exploit that Student distributions emerge as centered normal variates with inverse
  # Gamma distributed variances: Z / sqrt(chisq_df / df) ~ t_df.
  sim_list <- 
    list(seed = seed,
         sim_nau = dqrnorm(N_sim) / sqrt(rchisq(N_sim, df_nau) / df_nau),
         sim_mau = dqrnorm(N_sim) / sqrt(rchisq(N_sim, df_mau) / df_mau),
         sim_mch = dqrnorm(N_sim) / sqrt(rchisq(N_sim, df_mch) / df_mch),
         sim_tot = dqrnorm(N_sim) / sqrt(rchisq(N_sim, df_tot) / df_tot))
  
  # Persist the pools so later runs can skip the expensive simulation step.
  save(sim_list, file = "~/data/appl-wb/20_staff/kjo/misc_data/sim_list.rdata")
}

# Load simulated distributions (from the save above, or from a previous run).
load("~/data/appl-wb/20_staff/kjo/misc_data/sim_list.rdata")


# Load processed data from 'basismodell'. -------------------------------------------

# Range of trend extrapolation points from crossvalidation (see 'cv_out.R').
load("data/range.RData")

# Foundational analysis data (see 'prepare_inputs.R'): read, restrict to the
# prediction horizon, order, and attach the 'zas' series by year.
a_dat <-
  read_delim("data/a_dat.csv") %>%
  filter(year <= last(par$pint)) %>%
  arrange(sex, dom, type, year) %>%
  left_join(zas, by = "year")

# Estimate pension top-up. ----------------------------------------------------------

# Regress first differences of total expenditure on d_nmmp, without intercept.
fit_t <- lm(I(c(NA, diff(exp_tot))) ~ 0 + d_nmmp, data = f_dat)

# Prediction moments of the top-up: slope as mean, residual sigma as standard
# deviation, recycled over every year of the prediction interval.
mom_t <- tibble(
  year = par$pint,
  mu_t = coef(fit_t),
  sd_t = glance(fit_t)$sigma
)

# Collect prediction moments for explanatory variables. -----------------------------

# One moment table per (scenario, domain, type, sex) cell, combined later.
res_p <- list()

# NOTE: loop variables 'c' and 't' shadow base::c()/base::t(); calls such as
# c("H", c) still resolve correctly because R looks up functions separately.
for (c in unique(filter(a_dat, !(scen %in% c("H")))$scen)) {
  for (d in unique(a_dat$dom)) {
    for (t in unique(filter(a_dat, type != "wit")$type)) {
      for (s in unique(a_dat$sex)) {
        
        # Constrain extrapolation data according to 'range': t_n/t_m are 1 inside
        # the trend-estimation window and NA before it, so lm() drops early years.
        data <- a_dat %>% 
          filter(dom == d, type == t, sex == s, scen %in% c("H", c)) %>%
          mutate(t_n = ifelse(year >= range[1], 1, NA))
        
        # Domain-specific estimation window for the pension-level trend.
        if (d == "au")
          data %<>% 
            mutate(t_m = ifelse(year >= range[2], 1, NA))
        
        if (d == "ch")
          data %<>%
            mutate(t_m = ifelse(year >= range[3], 1, NA))
        
        # Set moment collection table (prediction years only, hence pint[- 1]).
        mom <-  
          tibble(year = par$pint[- 1], scen = c, dom = d, sex = s, type = t)
        
        # Rows to predict for: everything after the last historic period.
        p_dat <- 
          filter(data, year > first(par$pint))
        
        # Collect prediction moments for pension counts. Custom function 'psd' returns
        # standard deviation of predictions according to standard formula for simple
        # linear regression without intercept.
        if (d == "au") {
          
          fit_n <- 
            lm(c(NA, diff(n)) ~ 0 + t_n, data)
          
          mom %<>% 
            mutate(mu_n = predict(fit_n, p_dat), sd_n = psd(fit_n, p_dat$t_n))
          
        } else {
          
          # Outside 'au' the counts come deterministically from the scenario data.
          fa_dat <- 
            filter(a_dat, year >= first(par$pint), 
                   scen %in% c("H", c), sex == s, type == t, dom == d) %>% 
            mutate(n = c(NA, diff(n))) %>% 
            slice(- 1)
          
          mom %<>% 
            mutate(mu_n  = fa_dat$n, sd_n = 0)
          
        }
        
        # Collect prediction moments for pension levels.
        fit_m <- 
          lm(c(NA, diff(m)) ~ 0 + t_m, data)
        
        # Fix: the prediction sd of the level regression must use the level regressor
        # t_m (was p_dat$t_n, a copy-paste slip). Both indicators equal 1 for all
        # prediction years whenever the prediction window lies inside both estimation
        # windows, so results are unchanged in that (typical) case.
        mom %<>% 
          mutate(mu_m = predict(fit_m, p_dat), sd_m = psd(fit_m, p_dat$t_m))
        
        # Save moment table.
        res_p[[paste0(d, t, s, c)]] <- mom
        
      }
    }
  }
}

# Collect results and attach last historic period.
res_p[["ref"]] <- 
  filter(a_dat, year == first(par$pint)) %>% 
  # Respect that historic data is deterministic.
  mutate(sd_n = 0, sd_m = 0) %>% 
  dplyr::rename(mu_m = m, mu_n = n) %>% 
  select(scen, year, dom, sex, type, mu_n, sd_n, mu_m, sd_m)

# Moment table across all scenarios/cells, joined with the top-up moments.
# NOTE(review): the final filter removes year == first(par$pint) again, i.e. the
# "ref" rows attached above do not survive into 'sig_data' -- confirm whether they
# are needed here or only kept for other uses of 'res_p'.
sig_data <- 
  bind_rows(res_p) %>% 
  left_join(mom_t, by = "year") %>% 
  group_by(year, scen) %>% 
  filter(year > first(par$pint))

# Simulate response distributions across time. --------------------------------------

# Fix random seed for reproducibility.
set.seed(seed + 1); dqset.seed(seed + 1)

# Set number of simulation runs (fixed by experimentation).
N <- 20000

# One simulated tibble per (scenario, domain, type, sex, year) cell.
res_s <- list()

for (c in unique(sig_data$scen)) {
  for (d in unique(sig_data$dom)) {
    for (t in unique(sig_data$type)) {
      for (s in unique(sig_data$sex)) {
        for (y in unique(sig_data$year)) {
          
          # Prediction moments for this cell (expected: a single row).
          param <- 
            filter(sig_data, year == y, sex == s, type == t, dom == d, scen == c)
          
          # 'au': both counts and levels carry estimation uncertainty; perturb the
          # means with resampled draws from the precomputed t-distribution pools.
          # NOTE(review): piping 'param' into tibble() relies on tibble splicing
          # the data frame's columns so mu_n/sd_n/mu_m/sd_m are in scope -- confirm
          # with the installed tibble version.
          if (d == "au") {
            
            res_s[[paste0(c, d, t, s, y)]] <-
              param %>% 
              tibble(
                sim_n = mu_n + dqsample(sim_list$sim_nau, N, replace = TRUE) * sd_n,
                sim_m = mu_m + dqsample(sim_list$sim_mau, N, replace = TRUE) * sd_m
                ) %>% 
              select(year, scen, dom, sex, type, sim_n, sim_m) %>% 
              mutate(run = 1:n())
          }
          
          # 'ch': counts are deterministic (sd_n = 0 above); only levels perturbed.
          if (d == "ch") {
            
            res_s[[paste0(c, d, t, s, y)]] <-
              param %>% 
              tibble(
                sim_n = mu_n,
                sim_m = mu_m + dqsample(sim_list$sim_mch, N, replace = TRUE) * sd_m
                ) %>% 
              select(year, scen, dom, sex, type, sim_n, sim_m) %>% 
              mutate(run = 1:n())
          }
        }
      }
    }
  }
}

# Combine simulation runs.
sim <- 
  bind_rows(res_s) %>%
  arrange(year, type, scen) %>% 
  group_by(year, scen)

# Prepare data from last historic period: deterministic zero increments, expanded
# over all runs and scenarios so the cumulative sums below start at the base year.
# NOTE(review): scenario labels "A", "B", "C" are hard-coded -- confirm they match
# the scenario set actually present in 'a_dat'.
h_dat <- a_dat %>% 
  filter(year == first(par$pint)) %>% 
  select(year, dom, sex, type) %>% 
  mutate(sim_n = 0, sim_m = 0) %>% 
  expand(run = unique(sim$run), scen = c("A", "B", "C"),
         nesting(year, dom, sex, type, sim_n, sim_m)) %>% 
  relocate(year, scen, dom, sex, type)

# Consolidate data, derive perturbed predictions across simulation runs, and integrate
# 13th AHV pension payment.
sim %<>%
  # Attach historic data.
  bind_rows(h_dat) %>% 
  # Base levels (n, m) of the last historic year, joined into every projection row.
  left_join(a_dat %>% 
      filter(year == first(par$pint)) %>% 
      select(dom, sex, type, n, m), 
    by = c("dom", "sex", "type")) %>% 
  left_join(mp, by = "year") %>% 
  group_by(scen, run, dom, sex, type) %>%
  # NOTE(review): dplyr::arrange() expects '.by_group'; as written 'by_group' is a
  # constant sort key and the data is sorted globally by year. Within each group the
  # years are then still ascending, so the cumsum below works -- confirm intent.
  arrange(year, by_group = TRUE) %>%
  # Enforce legal/logical bounds on projections.
  mutate(sim_n =         pmax(0, n + cumsum(sim_n)), 
         sim_m = pmin(2, pmax(0, m + cumsum(sim_m)))
  ) %>% 
  # 13th AHV pension payment.
  mutate(sim_m = 
           ifelse(type == "alt" & year >= 2026, sim_m * (1 + par$ahv13 * 1/12), sim_m),
         d_nmmp = c(0, diff(sim_n * sim_m * mp))) %>% 
  ungroup() %>% 
  # Map variables onto explanatory variable from 'fit_t', separately for each simulation 
  # run per scenario.
  dplyr::summarize(d_nmmp = sum(d_nmmp), .by = c("run", "year", "scen")) %>% 
  # Last observed total expenditure anchors the cumulative projection below.
  mutate(p_exp = filter(f_dat, year == first(par$pint)) %>% pull(exp_tot)) %>% 
  left_join(mom_t, by = "year", relationship = "many-to-one") %>% 
  group_by(scen, run) %>%
  # NOTE(review): same 'by_group' vs '.by_group' spelling as above.
  arrange(year, by_group = TRUE) %>% 
  # Simulate final predictions incorporating uncertainty due to the rent sum top-up
  # estimation (prediction-interval factor of the no-intercept simple regression).
  mutate(d_nmmp = mu_t * d_nmmp + dqsample(sim_list$sim_tot, n(), replace = TRUE) *
                  sd_t *
                  sqrt(1 + d_nmmp^2 / sum(model.frame(fit_t)[, 2]^2)), 
         p_exp  = p_exp + cumsum((year != first(par$pint)) * d_nmmp)) %>% 
  group_by(scen, year) %>% 
  # Calculate conditional confidence bands at the 5% level.
  # mutate(low  = quantile(p_exp, .05),
  #        high = quantile(p_exp, .95),
  #        mid  = mean(p_exp)) %>%
  # NOTE(review): relative to the commented-out variant above, 'low' and 'high' are
  # swapped (low = 95th percentile, high = 5th) -- confirm this inversion is intended
  # and not an accidental edit.
  mutate(low  = quantile(p_exp, .95),
         high = quantile(p_exp, .05),
         mid  = mean(p_exp)) %>%
  select(year, scen, low, mid, high) %>% 
  distinct()

# Integrate Liechtenstein effect, AHV21 cost vector and complementary widow projections. 
# Neither of these ex post corrections is included in the uncertainty quantification.

# Cumulative change of the widow ("wit") rent sum relative to the base year.
d_wid <- wid %>% 
  bind_rows(a_dat %>%
              select(year, sex, dom, type, m, n, mp) %>%
              # NOTE(review): base year 2024 is hard-coded here while the rest of the
              # script uses first(par$pint) -- confirm these coincide.
              filter(year == 2024, type == "wit")) %>%
  group_by(sex, dom) %>% 
  # NOTE(review): dplyr::arrange() expects '.by_group'; 'by_group' here sorts
  # globally by year (still ascending within groups) -- confirm intent.
  arrange(year, by_group = TRUE) %>% 
  mutate(d_wid = c(0, diff(n * m * mp))) %>% 
  ungroup() %>% 
  dplyr::summarize(d_wid = sum(d_wid), .by = "year") %>% 
  mutate(d_wid = cumsum(d_wid))

sim %<>%
  # Aggregate relative savings ('corr') per year and attach all correction series.
  left_join(corr %>% 
     select(year, save_rel) %>% dplyr::summarize(save = sum(save_rel), .by = "year"), 
     by = "year") %>% 
  left_join(d_wid, by = "year") %>%
  left_join(par$ahv21_cost, by = "year") %>% 
  # mutate(low  = low  + cost + d_wid - save,
  #        mid  = mid  + cost + d_wid - save,
  #        high = high + cost + d_wid - save) %>% 
  # NOTE(review): the active variant below omits '+ cost', so the AHV21 cost vector
  # is joined but never added (only dropped again by the select). This contradicts
  # the section comment above -- confirm which variant is intended.
  mutate(low  = low  + d_wid - save,
         mid  = mid  + d_wid - save,
         high = high + d_wid - save) %>% 
  select(- cost, - save, - d_wid)

# Deflate results, if desired. The deflator 'df' from 'eck' converts nominal into
# real values. Fix: apply it to all three band statistics -- previously 'mid' was
# left nominal while 'low'/'high' were deflated, making the bands inconsistent.
if (par$real)
  sim %<>%
    left_join(eck, by = "year") %>%
    mutate(low = low * df, mid = mid * df, high = high * df) %>%
    select(- df)

# Express results in millions of francs, rounded to whole millions.
sim %<>%
  mutate(across(c(low, mid, high), ~ round(.x / 1e6)))
