#########################################################################
# SIMULATE PREDICTIVE BANDS BASED ON NORMAL RESIDUALS AND BFS SCENARIOS #
#########################################################################

# Setup. ----------------------------------------------------------------------------

# Range of trend extrapolation points from crossvalidation (see 'cv_out.R').
load("data/output/RANGE.RData")

# Degrees of freedom for the Student t draws of the predictive bands. A series
# running from its range start RANGE$<x> to the first year of PAR$pint has
# (first year - start + 1) observations; differencing drops one and the single
# slope of the through-origin regressions costs another (presumably the same
# holds for 'tot' / fit_t -- confirm against FIT_DATA).
DF <- local({
  ref_year <- first(PAR$pint)
  tibble(nau = ref_year - RANGE$nau - 1,
         mau = ref_year - RANGE$mau - 1,
         mch = ref_year - RANGE$mch - 1,
         tot = ref_year - RANGE$tot - 1)
})

# Foundational analysis data (see 'prepare_inputs.R'): restricted to years up to the
# end of the projection interval, sorted by cell and year, and joined with ZAS by year.
A_DATA <- read_delim("data/output/A_DATA.csv")
A_DATA <- A_DATA %>%
  filter(year <= last(PAR$pint)) %>%
  arrange(sex, dom, type, year) %>%
  left_join(ZAS, by = "year")

# Estimate pension sum increment top-up. -------------------------------------------------

# Through-origin regression of the annual change in total expenditures (exp_tot) on
# the annual change of n * m * mp. FIT_DATA is prepared outside this script section.
fit_t <- lm(diff(exp_tot) ~ 0 + diff(n * m * mp), data = FIT_DATA)

# Slope estimate and residual standard deviation of fit_t, repeated for every year
# of the projection interval.
MOMENTS_TOTAL <- tibble(
  year = PAR$pint,
  mu_t = coef(fit_t),
  sd_t = glance(fit_t)$sigma
)

# Collect prediction moments for explanatory variables. ----------------------------------

# Containers: RES_PROJ holds interim moment tables per cell, RES_FIN the final
# band table per scenario.
RES_PROJ <- list()
RES_FIN  <- list()

# Attach historic data of the current year. Observed values are deterministic, so
# they enter with zero standard deviation.
RES_PROJ[["ref"]] <-
  A_DATA %>%
  filter(year == first(PAR$pint)) %>%
  transmute(scen, sex, dom, type, year,
            mu_n = n, sd_n = 0,
            mu_m = m, sd_m = 0)

# Loop over the demographic scenarios. For each scenario 'scn' this
#   1. collects prediction moments (mean, sd) of the annual increments of the
#      explanatory variables n (pension counts) and m (pension levels) for every
#      domain/type/sex cell,
#   2. simulates N perturbed increment paths from Student t distributions,
#   3. maps them onto total expenditures via 'fit_t' (adding that model's own
#      prediction uncertainty) and stores the yearly 5% / 95% quantiles in
#      RES_FIN[[scn]].
# The loop variable is named 'scn' rather than 'c' so that it is not confused
# with base::c(), which the loop body calls repeatedly.
for (scn in c("B", "C")) {

  for (d in unique(A_DATA$dom)) {
    for (t in unique(filter(A_DATA, type != "wit")$type)) {
      for (s in unique(A_DATA$sex)) {

        # Historic ("H") and scenario data of the current cell. The indicator
        # t_n is 1 from the start of the extrapolation range onwards and NA
        # before, so lm() (default na.action) only fits the recent years.
        DATA <- A_DATA %>%
          filter(scen %in% c("H", scn), sex == s, dom == d, type == t) %>%
          mutate(t_n = ifelse(year >= RANGE$nau, 1, NA))

        # Domain-specific extrapolation range for the pension level model.
        if (d == "au")
          DATA <- DATA %>%
            mutate(t_m = ifelse(year >= RANGE$mau, 1, NA))

        if (d == "ch")
          DATA <- DATA %>%
            mutate(t_m = ifelse(year >= RANGE$mch, 1, NA))

        # Moment table for the projection years (PAR$pint without its first,
        # historic year).
        MOMENTS_EXP <-
          tibble(year = PAR$pint[- 1], scen = scn, dom = d, sex = s, type = t)

        PROJ_DATA <- DATA %>%
          filter(year > first(PAR$pint))

        # Prediction moments for pension count increments. Custom function 'psd'
        # returns the standard deviation of predictions according to the
        # standard formula for simple linear regression without intercept.
        # Outside domain "au" the increments are taken deterministically from
        # the scenario data.
        if (d == "au") {

          fit_n <-
            lm(c(NA, diff(n)) ~ 0 + t_n, DATA)

          MOMENTS_EXP <- MOMENTS_EXP %>%
            mutate(mu_n = predict(fit_n, PROJ_DATA),
                   sd_n = psd(fit_n, PROJ_DATA$t_n))

        } else {

          FA_DATA <- A_DATA %>%
            filter(scen %in% c("H", scn), sex == s, dom == d, type == t,
                   year >= first(PAR$pint)) %>%
            mutate(n = c(NA, diff(n))) %>%
            slice(- 1)

          MOMENTS_EXP <- MOMENTS_EXP %>%
            mutate(mu_n = FA_DATA$n, sd_n = 0)
        }

        # Prediction moments for pension level increments. The predictive sd is
        # evaluated at t_m, the regressor of fit_m.
        fit_m <-
          lm(c(NA, diff(m)) ~ 0 + t_m, DATA)

        MOMENTS_EXP <- MOMENTS_EXP %>%
          mutate(mu_m = predict(fit_m, PROJ_DATA),
                 sd_m = psd(fit_m, PROJ_DATA$t_m))

        # Save moment table.
        RES_PROJ[[paste0(d, t, s, scn)]] <- MOMENTS_EXP

      }
    }
  }

  # Moments of the current scenario's projection years, joined with the
  # top-up moments of fit_t.
  SIGMA_DATA <-
    bind_rows(RES_PROJ) %>%
    left_join(MOMENTS_TOTAL, by = "year") %>%
    filter(scen == scn, year > first(PAR$pint)) %>%
    select(- scen)

  # Simulate response distributions across time. --------------------------------------

  # Number of simulation runs for the uncertainty bands (fixed by experimentation).
  N <- 50000

  SIM_LIST <- list()

  # Fix random seed for reproducibility (reset for every scenario).
  set.seed(95827323)

  for (d in sort(unique(SIGMA_DATA$dom))) {
    for (t in sort(unique(SIGMA_DATA$type))) {
      for (s in sort(unique(SIGMA_DATA$sex))) {
        for (y in sort(unique(SIGMA_DATA$year))) {

          # Moments of one cell in one year.
          PARAM <- SIGMA_DATA %>%
            filter(sex == s, dom == d, type == t, year == y)

          # 'PARAM %>% tibble(...)' splices PARAM's columns into a tibble of N
          # rows, so mu_* / sd_* below refer to PARAM. Increments are drawn as
          # mean + t-quantile * sd with the degrees of freedom fixed in DF.
          if (d == "au") {

            SIM_LIST[[paste0(scn, d, t, s, y)]] <-
              PARAM %>%
              tibble(
                sim_n = mu_n + rt(N, DF$nau) * sd_n,
                sim_m = mu_m + rt(N, DF$mau) * sd_m
                ) %>%
              select(sex, dom, type, year, sim_n, sim_m) %>%
              mutate(scen = scn, run = seq_len(n())) %>%
              relocate(scen)
          }

          # Count increments are deterministic here (sd_n = 0 by construction).
          if (d == "ch") {

            SIM_LIST[[paste0(scn, d, t, s, y)]] <-
              PARAM %>%
              tibble(
                sim_n = mu_n,
                sim_m = mu_m + rt(N, DF$mch) * sd_m
                ) %>%
              select(sex, dom, type, year, sim_n, sim_m) %>%
              mutate(scen = scn, run = seq_len(n())) %>%
              relocate(scen)
          }

        }
      }
    }
  }

  # Combine simulation runs.
  SIMULATIONS <-
    bind_rows(SIM_LIST) %>%
    group_by(scen, year) %>%
    arrange(type, .by_group = TRUE)

  # Rows for the last historic period with zero increments, one per run, so the
  # cumulative sums below start from the observed levels n and m.
  HIST_DATA <- A_DATA %>%
    filter(year == first(PAR$pint)) %>%
    select(dom, sex, type, year) %>%
    mutate(sim_n = 0, sim_m = 0) %>%
    expand(run = unique(SIMULATIONS$run), scen = scn,
           nesting(year, dom, sex, type, sim_n, sim_m)) %>%
    relocate(scen, dom, sex, type, year)

  # Consolidate data, derive perturbed predictions across simulation runs, and integrate
  # 13th AHV pension payment.
  SIMULATIONS <- SIMULATIONS %>%
    # Attach historic data.
    bind_rows(HIST_DATA) %>%
    left_join(A_DATA %>%
                filter(year == first(PAR$pint)) %>%
                select(dom, sex, type, n, m),
              by = c("dom", "sex", "type")) %>%
    left_join(MINIMAL_PENSION, by = "year") %>%
    group_by(scen, run, dom, sex, type) %>%
    # Order each path by year so the cumulative sums run forward in time.
    arrange(year, .by_group = TRUE) %>%
    # Enforce legal/logical bounds on projections: counts >= 0, levels in [0, 2].
    mutate(
      sim_n =         pmax(0, n + cumsum(sim_n)),
      sim_m = pmin(2, pmax(0, m + cumsum(sim_m)))
    ) %>%
    # 13th AHV pension payment for type "alt" from 2026 on; d_nmmp is the
    # year-on-year change of n * m * mp per cell (0 in the first year).
    mutate(sim_m =
             ifelse(type == "alt" & year >= 2026, sim_m * (1 + PAR$ahv13 * 1/12), sim_m),
           d_nmmp = c(0, diff(sim_n * sim_m * mp))) %>%
    ungroup() %>%
    # Map variables onto explanatory variable from 'fit_t', separately for each simulation
    # run per scenario.
    summarize(d_nmmp = sum(d_nmmp), .by = c("scen", "run", "year")) %>%
    mutate(p_exp = filter(ZAS, year == first(PAR$pint)) %>% pull(exp_tot)) %>%
    left_join(MOMENTS_TOTAL, by = "year", relationship = "many-to-one") %>%
    group_by(scen, run) %>%
    arrange(year, .by_group = TRUE)

  # Simulate final predictions incorporating uncertainty due to the rent sum top-up
  # estimation.

  # Reset seed for compatibility with Basismodell stand-alone implementation.
  set.seed(95827323)

  # d_nmmp becomes a draw of the expenditure increment: mu_t * x0 plus t noise
  # scaled by the through-origin prediction se, sd_t * sqrt(1 + x0^2 / sum(x^2)),
  # where x0 is the (old) d_nmmp and x the regressor values of fit_t. p_exp then
  # accumulates these increments from the first-year expenditure level.
  RES_FIN[[scn]] <- SIMULATIONS %>%
    mutate(d_nmmp = mu_t * d_nmmp + rt(n(), DF$tot) *
             sd_t *
             sqrt(1 + d_nmmp^2 / sum(model.frame(fit_t)[, 2]^2)),
           p_exp  = p_exp + cumsum((year != first(PAR$pint)) * d_nmmp)) %>%
    group_by(year) %>%
    # 5% and 95% quantiles across runs, i.e. a 90% band conditional on the
    # scenario (5% in each tail).
    mutate(low  = quantile(p_exp, .05),
           high = quantile(p_exp, .95)) %>%
    select(scen, year, low, high) %>%
    distinct()
}
 
# Stack the scenario-specific band tables (one per scenario) into one table.
BANDS <- RES_FIN %>% bind_rows()

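# Integrate the Liechtenstein effect and the complementary widow projections. The
# AHV21 cost vector adjustment is currently disabled (see the commented-out code
# below), and the widow projections of reference scenario "A" are applied to both
# demographic scenarios.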

# Express widow pension projections as incremental differences from the status quo.
# D_WIDOWS <- WIDOWS %>%
#   filter(scen %in% c("B", "C")) %>%
#   # Attach historic widow pension data for current year.
#   bind_rows(A_DATA %>%
#               select(dom, sex, type, year, m, n, mp) %>%
#               filter(type == "wit", year == first(PAR$pint)) %>%
#               right_join(tibble(year = first(PAR$pint), scen = c("B", "C")),
#                          by = "year",
#                          relationship = "many-to-many")) %>%
#   group_by(scen, dom, sex) %>%
#   arrange(year, by_group = TRUE) %>%
#   mutate(d_wid = c(0, diff(n * m * mp))) %>%
#   ungroup() %>%
#   summarize(d_wid = sum(d_wid), .by = c("scen", "year")) %>%
#   mutate(d_wid = cumsum(d_wid))

# Cumulated widow pension expenditure increments relative to the current year,
# based on reference scenario "A" of WIDOWS.
D_WIDOWS <- WIDOWS %>%
  filter(scen == "A") %>%
  # Anchor the increments at the widow data of the current year, the same
  # reference year (first(PAR$pint)) used throughout this script instead of a
  # hard-coded calendar year.
  bind_rows(A_DATA %>%
              select(dom, sex, type, year, m, n, mp) %>%
              filter(type == "wit", year == first(PAR$pint))) %>%
  group_by(dom, sex) %>%
  arrange(year, .by_group = TRUE) %>%
  # Year-on-year change of widow expenditures per cell (0 in the anchor year).
  mutate(d_wid = c(0, diff(n * m * mp))) %>%
  ungroup() %>%
  summarize(d_wid = sum(d_wid), .by = "year") %>%
  mutate(d_wid = cumsum(d_wid))

# BANDS <- BANDS %>% 
#   left_join(LIECHTENSTEIN %>%
#        select(year, savings_rel) %>%
#        summarize(save = sum(savings_rel), .by = "year"),
#      by = "year") %>%
#   left_join(D_WIDOWS, by = c("scen", "year")) %>%
#   left_join(PAR$ahv21_cost, by = "year") %>%
#   # Express results in millions and round.
#   mutate(low  = round((low  + PAR$ahv21 * cost + d_wid - save) / 1e6),
#          high = round((high + PAR$ahv21 * cost + d_wid - save) / 1e6)) %>%
#   select(- cost, - save, - d_wid)

# BANDS <- BANDS %>%
#   left_join(LIECHTENSTEIN %>%
#               select(year, savings_rel) %>%
#               summarize(save = sum(savings_rel), .by = "year"),
#             by = "year") %>%
#   left_join(D_WIDOWS, by = "year") %>%
#   left_join(PAR$ahv21_cost, by = "year") %>%
#   mutate(low  = low  + cost + d_wid - save,
#          high = high + cost + d_wid - save) %>%
#   select(- cost, - save, - d_wid)

# Shift both band limits by the cumulated widow increments and subtract the
# Liechtenstein savings (summed per year). The AHV21 cost adjustment is disabled.
BANDS <- BANDS %>%
  left_join(
    LIECHTENSTEIN %>%
      select(year, savings_rel) %>%
      summarize(save = sum(savings_rel), .by = "year"),
    by = "year"
  ) %>%
  left_join(D_WIDOWS, by = "year") %>%
  mutate(
    low  = low  + d_wid - save,
    high = high + d_wid - save
  ) %>%
  select(-c(save, d_wid))

# Deflate results, if desired: both band limits are multiplied by ECKWERTE's
# per-year factor 'df'.
if (PAR$real) {
  BANDS <- BANDS %>%
    left_join(ECKWERTE, by = "year") %>%
    mutate(low  = low * df,
           high = high * df) %>%
    select(-df)
}

# Attach the band limits to the results table. After widening, BANDS has columns
# low_B, low_C, high_B, high_C; the lower limit is scenario C's upper quantile and
# the upper limit scenario B's lower quantile.
# NOTE(review): this pairs each scenario with its inner quantile; confirm this is
# intended rather than low = low_C / high = high_B.
RES_TAB <- RES_TAB %>%
  left_join(
    BANDS %>%
      pivot_wider(names_from = scen, values_from = low:high) %>%
      select(year, low = high_C, high = low_B),
    by = "year"
  ) %>%
  # Express in millions.
  mutate(low  = round(low / 1e6),
         high = round(high / 1e6)) %>%
  # For the last six years of the projection interval, round to hundreds of millions.
  mutate(low  = ifelse(year >= last(PAR$pint) - 5, round(low , -2), low),
         high = ifelse(year >= last(PAR$pint) - 5, round(high, -2), high)) %>%
  relocate(year, low, ref, high)

# Visualize the reference projection with its uncertainty band. pivot_longer()
# stacks low/ref/high into 'name'; as.factor() orders its levels alphabetically
# (high, low, ref), which is the order the colour labels below follow.
RES_TAB %>%
  select(- mp, - m, - n, - s) %>%
  pivot_longer(low:high) %>%
  ggplot() +
  geom_ribbon(data = RES_TAB, aes(x = year, ymin = low, ymax = high), alpha = .2) +
  geom_line(aes(x = year, y = value, col = as.factor(name))) +
  geom_point(aes(x = year, y = value, col = as.factor(name))) +
  labs(x = NULL,
       y = if (PAR$real) {
         "Millionen Franken (real)\n"
       } else {
         "Millionen Franken (laufende Preise)\n"
       },
       title = "AHV-Ausgabenprojektion mit Unsicherheitsbändern") +
  theme_grey(base_size = 16, base_family = "Garamond") +
  scale_colour_viridis_d(labels = c("ungünstig", "günstig", "mittel"),
                         name = "Szenario", option = "A") +
  scale_y_continuous(labels = scales::comma)

# Final table: band limits around the reference projection.
RES_TAB %>% select(year, low, ref, high)
