#########################################################################
# SIMULATE PREDICTIVE BANDS BASED ON NORMAL RESIDUALS AND BFS SCENARIOS #
#########################################################################

# Create random samples from t-distributions for later resampling (faster). --------------
seed <- 95827323
set.seed(seed)

# BUGFIX: 'range' (trend extrapolation points from crossvalidation, see 'cv_out.R')
# was only loaded further below, i.e. AFTER its first use here — unless a caller has
# already defined it, 'range' resolves to the base function and 'range[1]' fails.
# Loading here is idempotent, so the later load is harmless.
load("data/range.RData")

# Fix degrees of freedom for the respective regressions, taking into consideration
# that differencing and estimating one parameter decrease the effective sample size
# by 2 (cf. the commented-out simulation block below).
df_nau <- first(par$pint) - range[1] + 1 - 2
df_mau <- first(par$pint) - range[2] + 1 - 2
df_mch <- first(par$pint) - range[3] + 1 - 2
df_tot <- first(par$pint) - range[4] + 1 - 2

# if (par$sim) {
# 
#   # Set seed for reproducibility.
#   seed <- sample(10000:100000, 1)
#   set.seed(seed)
#   N_sim <- 10000000
# 
#   # Fix degrees of freedom for respective regressions, taking into consideration that
#   # differencing and estimating one parameter decreases the effective sample size by 2.
#   df_nau <- first(par$pint) - range[1] + 1 - 2
#   df_mau <- first(par$pint) - range[2] + 1 - 2
#   df_mch <- first(par$pint) - range[3] + 1 - 2
#   df_tot <- first(par$pint) - range[4] + 1 - 2
# 
#   # Exploit that Student distributions emerge as centered normal variates with inverse
#   # Gamma distributed variances.
#   sim_list <-
#     list(seed = seed,
#          sim_nau = dqrnorm(N_sim) / sqrt(rchisq(N_sim, df_nau) / df_nau),
#          sim_mau = dqrnorm(N_sim) / sqrt(rchisq(N_sim, df_mau) / df_mau),
#          sim_mch = dqrnorm(N_sim) / sqrt(rchisq(N_sim, df_mch) / df_mch),
#          sim_tot = dqrnorm(N_sim) / sqrt(rchisq(N_sim, df_tot) / df_tot))
# 
#   save(sim_list, file = "R:/Prod/wb/20_staff/kjo/misc_data/sim_list.rdata")
# }
# 
# # Load simulated distributions.
# load("R:/Prod/wb/20_staff/kjo/misc_data/sim_list.rdata")


# Load processed data from 'basismodell'. -------------------------------------------

# Range of trend extrapolation points from crossvalidation (see 'cv_out.R').
load("data/range.RData")

# Foundational analysis data (see 'prepare_inputs.R'): restrict to the projection
# horizon, sort, and attach 'zas' by year.
a_dat <-
  "data/a_dat.csv" %>%
  read_delim() %>%
  filter(year <= last(par$pint)) %>%
  arrange(sex, dom, type, year) %>%
  left_join(zas, by = "year")

# Estimate pension top-up. ----------------------------------------------------------

# Regress first differences of total expenditure on 'd_nmmp' without intercept; the
# undefined first difference enters as NA and is dropped by lm().
fit_t <- lm(c(NA, diff(exp_tot)) ~ 0 + d_nmmp, data = f_dat)

# Prediction moments per year: the single slope and residual sd are recycled across
# all years of the projection interval.
mom_t <- tibble(
  year = par$pint,
  mu_t = coef(fit_t),
  sd_t = glance(fit_t)$sigma
)

# Collect prediction moments for explanatory variables. -----------------------------

res_p <- list()

# Loop over every scenario ('sc', excluding the historic scenario "H"), domain,
# type (excluding "wit"), and sex combination. NOTE: the scenario loop variable is
# named 'sc' (not 'c') so it does not shadow base::c(), which is called inside.
for (sc in unique(filter(a_dat, !(scen %in% c("H")))$scen)) {
  for (d in unique(a_dat$dom)) {
    for (t in unique(filter(a_dat, type != "wit")$type)) {
      for (s in unique(a_dat$sex)) {
        
        # Constrain extrapolation data according to 'range': 't_n' flags the
        # estimation window for pension counts, 't_m' the one for pension levels.
        # The level window start differs between the "au" and "ch" domains.
        # NOTE(review): 't_m' is only created for dom "au"/"ch" — presumably these
        # are the only domains in 'a_dat'; verify against 'prepare_inputs.R'.
        data <- a_dat %>% 
          filter(dom == d, type == t, sex == s, scen %in% c("H", sc)) %>%
          mutate(t_n = ifelse(year >= range[1], 1, NA))
        
        if (d == "au")
          data %<>% 
          mutate(t_m = ifelse(year >= range[2], 1, NA))
        
        if (d == "ch")
          data %<>%
          mutate(t_m = ifelse(year >= range[3], 1, NA))
        
        # Set moment collection table (prediction periods only, hence drop the
        # first, historic year of the projection interval).
        mom <-  
          tibble(year = par$pint[- 1], scen = sc, dom = d, sex = s, type = t)
        
        p_dat <- 
          filter(data, year > first(par$pint))
        
        # Collect prediction moments for pension counts. Custom function 'psd' returns
        # standard deviation of predictions according to standard formula for simple
        # linear regression without intercept.
        if (d == "au") {
          
          fit_n <- 
            lm(c(NA, diff(n)) ~ 0 + t_n, data)
          
          mom %<>% 
            mutate(mu_n = predict(fit_n, p_dat), sd_n = psd(fit_n, p_dat$t_n))
          
        } else {
          
          # Outside "au" the counts are deterministic: take first differences of the
          # scenario data directly (zero prediction sd); slice(-1) drops the row
          # with the undefined first difference.
          fa_dat <- 
            filter(a_dat, year >= first(par$pint), 
                   scen %in% c("H", sc), sex == s, type == t, dom == d) %>% 
            mutate(n = c(NA, diff(n))) %>% 
            slice(- 1)
          
          mom %<>% 
            mutate(mu_n  = fa_dat$n, sd_n = 0)
          
        }
        
        # Collect prediction moments for pension levels. BUGFIX: the prediction sd
        # of 'fit_m' (a regression on 't_m') must be evaluated at the 't_m'
        # regressor values, not 't_n' as before. Both indicators equal 1 over the
        # prediction window, so results are unchanged in practice, but the
        # regressor now matches the fitted model.
        fit_m <- 
          lm(c(NA, diff(m)) ~ 0 + t_m, data)
        
        mom %<>% 
          mutate(mu_m = predict(fit_m, p_dat), sd_m = psd(fit_m, p_dat$t_m))
        
        # Save moment table under a unique combination key.
        res_p[[paste0(d, t, s, sc)]] <- mom
        
      }
    }
  }
}

# Collect results and attach last historic period.
res_p[["ref"]] <- 
  a_dat %>%
  filter(year == first(par$pint)) %>% 
  # Historic data is deterministic, hence zero prediction standard deviations.
  mutate(sd_n = 0, sd_m = 0) %>% 
  dplyr::rename(mu_m = m, mu_n = n) %>% 
  select(scen, year, dom, sex, type, mu_n, sd_n, mu_m, sd_m)

# Stack all moment tables, attach the top-up moments by year, and keep the
# prediction periods only (grouped by year and scenario for downstream use).
sig_data <- 
  res_p %>%
  bind_rows() %>% 
  left_join(mom_t, by = "year") %>% 
  group_by(year, scen) %>% 
  filter(year > first(par$pint))

# Simulate response distributions across time. --------------------------------------

# Fix random seed for reproducibility (offset from the seed used above so the two
# sampling stages are independent but both reproducible).
set.seed(seed + 1)

# Set number of simulation runs (fixed by experimentation).
N <- 20000

# Simulation results, one element per scenario/domain/type/sex/year combination.
res_s <- list()

# NOTE(review): the loop variable 'c' shadows base::c() inside the loop body;
# calls to c() still resolve correctly (R searches for a function), but renaming
# would be safer.
for (c in unique(sig_data$scen)) {
  for (d in unique(sig_data$dom)) {
    for (t in unique(sig_data$type)) {
      for (s in unique(sig_data$sex)) {
        for (y in unique(sig_data$year)) {
          
          # One-row table of prediction moments for this combination (may be empty
          # if the combination does not occur in 'sig_data').
          param <- 
            filter(sig_data, year == y, sex == s, type == t, dom == d, scen == c)
          
          # "au": perturb both counts and levels with Student-t noise scaled by the
          # respective prediction standard deviations.
          # NOTE(review): 'param %>% tibble(...)' relies on tibble() splicing an
          # unnamed data-frame input so that mu_n/mu_m are visible to later
          # arguments; recent tibble versions pack data frames as df-columns
          # instead — verify the installed tibble version / intended behavior
          # (mutate-style expansion may have been intended).
          if (d == "au") {
            
            res_s[[paste0(c, d, t, s, y)]] <-
              param %>% 
              tibble(
                sim_n = mu_n + rt(N, df_nau) * sd_n,
                sim_m = mu_m + rt(N, df_mau) * sd_m
              ) %>% 
              select(year, scen, dom, sex, type, sim_n, sim_m) %>% 
              mutate(run = 1:n())
          }
          
          # "ch": counts are deterministic (sd_n == 0 upstream), only levels are
          # perturbed. NOTE(review): domains other than "au"/"ch" are silently
          # skipped here — confirm that no such domains exist in 'sig_data'.
          if (d == "ch") {
            
            res_s[[paste0(c, d, t, s, y)]] <-
              param %>% 
              tibble(
                sim_n = mu_n,
                sim_m = mu_m + rt(N, df_mch) * sd_m
              ) %>% 
              select(year, scen, dom, sex, type, sim_n, sim_m) %>% 
              mutate(run = 1:n())
          }
        }
      }
    }
  }
}

# Combine simulation runs into one table, ordered and grouped for later processing.
sim <- 
  res_s %>%
  bind_rows() %>%
  arrange(year, type, scen) %>% 
  group_by(year, scen)

# Prepare data from last historic period: zero increments for every simulation run
# and every scenario, so the cumulative sums below start at the historic values.
h_dat <- a_dat %>% 
  filter(year == first(par$pint)) %>% 
  select(year, dom, sex, type) %>% 
  mutate(sim_n = 0, sim_m = 0) %>% 
  # Cross every run with every scenario while keeping the existing combinations of
  # year/dom/sex/type together (nesting()). NOTE(review): the scenario labels
  # "A", "B", "C" are hard-coded here — keep in sync with the scenarios actually
  # present in 'sim'.
  expand(run = unique(sim$run), scen = c("A", "B", "C"),
         nesting(year, dom, sex, type, sim_n, sim_m)) %>% 
  relocate(year, scen, dom, sex, type)

# Consolidate data, derive perturbed predictions across simulation runs, and integrate
# 13th AHV pension payment.
sim %<>%
  # Attach historic data.
  bind_rows(h_dat) %>% 
  # Baseline counts 'n' and levels 'm' of the last historic period, used below as
  # anchor values for the cumulative sums.
  left_join(filter(a_dat, year == first(par$pint)) %>% select(dom, sex, type, n, m), 
            by = c("dom", "sex", "type")) %>% 
  # 'mp' joined by year — presumably pension amounts/minimal pension; verify source.
  left_join(mp, by = "year") %>% 
  group_by(scen, run, dom, sex, type) %>%
  # Sorting by year within each group is required for cumsum() to accumulate the
  # simulated first differences in chronological order.
  arrange(scen, run, dom, sex, type, year) %>%
  # Enforce legal bounds on projections: counts are non-negative, levels are
  # clamped to [0, 2].
  mutate(
    sim_n =         pmax(0, n + cumsum(sim_n)), 
    sim_m = pmin(2, pmax(0, m + cumsum(sim_m)))
  ) %>% 
  # 13th AHV pension payment: old-age pensions from 2026 onward are scaled up by
  # one twelfth, weighted by par$ahv13.
  # NOTE(review): 'd_nmmp' is computed here but dropped again by the final
  # select(); together with the unused mu_t/sd_t columns from 'mom_t' this looks
  # like a vestigial remnant of the pension top-up estimation — confirm.
  mutate(sim_m = 
           ifelse(type == "alt" & year >= 2026, sim_m * (1 + par$ahv13 * 1/12), sim_m),
         d_nmmp = c(0, diff(sim_n * sim_m * mp))) %>% 
  ungroup() %>% 
  filter(type == "alt") %>% 
  select(run, scen, dom, sex, year, sim_m, sim_n, mp)

# Scenario summaries of the simulated mean pension level: per run, the mean level is
# weighted by simulated counts; across runs, scenario A takes the mean, B/C the
# 0.5% / 99.5% quantiles (predictive band).
m_sc <- sim %>% 
  dplyr::summarize(m = weighted.mean(sim_m, sim_n), n = sum(sim_n), 
                   .by = c("run", "scen", "year", "mp")) %>% 
  group_by(scen, year) %>% 
  # NOTE(review): base stats::quantile() has no 'weights' argument — it is
  # silently absorbed by '...', so these quantiles are UNWEIGHTED unless a
  # weighted quantile method (e.g. Hmisc::wtd.quantile) is attached and
  # dispatched. Verify which quantile() is in scope here.
  mutate(m = case_when(
    scen == "A" ~ mean(m),
    scen == "B" ~ quantile(m, .005, weights = n),
    scen == "C" ~ quantile(m, .995, weights = n)
  )) %>%
  # The mutate above makes 'm' constant within each (scen, year) group, so after
  # dropping run and n, distinct() collapses to one row per group.
  select(- run, - n) %>% 
  distinct() %>% 
  mutate(m = round(m, 3)) 

# Visualize the scenario bands of the mean pension level from 2026 onward
# (top-level expression, so the plot autoprints when the script is run).
m_sc %>%
  filter(year >= 2026) %>%
  ggplot(aes(x = year, y = m, col = scen)) +
  geom_shadowpoint()
