###############################################################################
# OPTIMIZE TREND EXTRAPOLATION DATA RANGE VIA SLIDING WINDOW CROSS-VALIDATION #
###############################################################################

# Detect the structural breakpoint used to choose the starting year of the
# cross-validation procedure: total AHV expenditures are scaled by the minimal
# pension, first-differenced, and regressed on a constant; the break in that
# mean is located via strucchange::breakpoints(). If no break is found (NA),
# fall back to 0 via max(..., 0, na.rm = TRUE). The starting year avoids
# contamination from the reference-age increase for women under the 10th AHV
# revision. NOTE(review): the original comment cites 2005 while the data are
# filtered from 2001 on — presumably the detected break lands on 2005; confirm.
breakpoint <-
  max(
    (left_join(ZAS, MINIMAL_PENSION, by = "year") %>%
     na.omit() %>%
     mutate(exp_tot = exp_tot / mp) %>%
     filter(year >= 2001) %>%
     breakpoints(diff(exp_tot) ~ 1, data = .))$breakpoints,
  0, na.rm = TRUE)

# Restrict the cross-validation input to viable historic observations. Scenario
# code "H" marks historic data; years before the detected structural break are
# dropped (first retained year = 2000 + breakpoint, i.e. 2001 + breakpoint - 1).
TRAIN_DATA <- A_DATA %>%
  filter(scen == "H") %>%
  filter(year >= 2000 + breakpoint) %>%
  select(-scen)

# Candidate grid of 'Trendpunkte' (numbers of years entering each linear
# extrapolation; see Methodenbeschrieb of the Basismodell): one candidate set
# per explanatory variable, crossed into all combinations.
RANGE_PAR <-
  expand_grid(
    nau = PAR$pr[["tr"]],   # pension counts abroad
    mau = PAR$pr[["tr"]],   # mean pension level abroad
    mch = PAR$pr[["tr"]],   # mean pension level in Switzerland
    tot = PAR$pr[["tot"]]   # pension-sum top-up regression
  )

# List to store forecast performances of the respective Trendpunkt choices
# (one element per row of RANGE_PAR, filled by the cross-validation loop below).
PER_LIST <- list()

# Constant to harmonize the set of training windows considered across Trendpunkt
# sets: the overall maximum over all candidate vectors collected in 'PAR$pr'.
tmax <- max(bind_rows(PAR$pr))

# Cross-validation loop over trend ranges, groups and time windows. For each
# 'Trendpunkt' combination (row of RANGE_PAR) and each sliding data window, the
# explanatory variables are extrapolated, aggregated into a pension sum,
# topped-up and evaluated against the withheld total AHV expenditures.
for (z in seq_len(nrow(RANGE_PAR))) {
  
  # Extract relevant 'Trendpunkt' combination.
  RANGE_SEL <- slice(RANGE_PAR, z)
  
  # List to store the forecast performances of the respective runs.
  EVAL <- list()

  # One iteration per sliding window of 'tmax + PAR$out' consecutive years.
  # seq_len() (instead of 1:n) guards against a non-positive trip count.
  for (w in seq_len(length(unique(TRAIN_DATA$year)) - (tmax + PAR$out) + 1)) {
    
    # Set list to store the projected explanatory variables per group.
    RES_LIST <- list()
    
      for (i in c("m", "f")) {
        for (j in c("ch", "au")) {
          for (k in c("alt", "kin", "wai", "wit")) {
            
            # Subset the relevant group's data (sex x domicile x pension type),
            # restricted to the current sliding window.
            TEMP <- TRAIN_DATA %>% 
              filter(sex == i, dom == j, type == k) %>% 
              slice(w : (w + (tmax + PAR$out) - 1))
            
            # Delete historic mean pension data beyond the training range determined by
            # the Trendpunkte up until the chosen prediction horizon 'PAR$out'.
            ind_m <- (tmax + 1):(tmax + PAR$out) 
            TEMP$m[ind_m] <- NA
            
            if (j == "ch") {
              
              # Predict the previously deleted annual changes of Swiss residents' mean 
              # pension conditional on sex 'i' and pension type 'k': a constant-drift
              # model fitted on the last 'RANGE_SEL$mch' first differences.
              fit_m <- 
                lm(diff(m) ~ 1, slice(TEMP, (tmax - RANGE_SEL$mch + 1) : tmax)) %>% 
                predict(slice(TEMP, ind_m))
              
              # Assign mean pension predictions with enforced boundary constraints
              # reflecting the legal minimum/maximum of pension entitlements. Eventual
              # crossings above 2 due to (massive) pension deferrements are ignored.
              # FIX: the lower bound now wraps the full projected path (anchor plus
              # cumulated drift), mirroring the 'au' branch below; previously pmax()
              # wrapped only the anchor value, leaving projections unbounded below.
              TEMP$m[ind_m] <- 
                pmin(2, pmax(0, last(na.omit(TEMP)$m) + cumsum(fit_m)))
            }
            
            if (j == "au") {
              
              # Predict the previously deleted average pension level abroad relative to 
              # the legal minimal pension conditional on sex 'i' and pension type 'k'.
              fit_m <- 
                lm(diff(m) ~ 1, slice(TEMP, (tmax - RANGE_SEL$mau + 1) : tmax)) %>% 
                predict(slice(TEMP, ind_m))
              
              # Assign mean pension predictions with enforced boundary constraints
              # reflecting the legal minimum/maximum of pension entitlements.
              TEMP$m[ind_m] <- 
                pmin(2, pmax(0, last(na.omit(TEMP)$m) + cumsum(fit_m)))
  
              # Impute foreign pension count. Note this is only done for pensions paid
              # abroad since the Swiss pension stocks are derived from the BFS demographic
              # scenarios (see script '3_adjust_scenarios.R'). The deleted horizon equals
              # 'ind_m'; 'ind_n' additionally picks up any pre-existing NA counts.
              TEMP$n[ind_m] <- NA
              ind_n <- which(is.na(TEMP$n))
              
              fit_n <- 
                lm(diff(n) ~ 1, slice(TEMP, (tmax - RANGE_SEL$nau + 1) : tmax)) %>% 
                predict(slice(TEMP, ind_n))
              
              # Assign stock projection with enforced non-negativity constraint on the 
              # aggregate projection.
              TEMP$n[ind_n] <- 
                pmax(0, last(na.omit(TEMP)$n) + cumsum(fit_n))
            }
            
            # Store projections for the considered subgroup.
            RES_LIST[[paste0(i, j, k)]] <- TEMP
          }
        }
      }
    
    # Combine results across groups and aggregate them into yearly projections for
    # the pension sum (count-weighted mean pension, summed counts). Afterwards,
    # subset the data according to the Trendpunkte for pension top-up estimation
    # and attach historic AHV total expenditures. Lastly, separate the data into
    # training and testing sets to evaluate the final regression's prediction.
    SPLIT <- 
      bind_rows(RES_LIST) %>%
      summarize(m = weighted.mean(m, n), n = sum(n), .by = c("mp", "year")) %>% 
      slice((tmax - RANGE_SEL$tot + 1) : (tmax + PAR$out)) %>%
      left_join(ZAS, by = "year") %>%
      mutate(d_exp  = exp_tot - lag(exp_tot),
             d_nmmp = n * m * mp - lag(n * m * mp)) %>% 
      rsample::initial_time_split(prop = RANGE_SEL$tot / (RANGE_SEL$tot + PAR$out))
    
    # Estimate the pension-sum top-up necessary to arrive at the total AHV expenditures on 
    # the basis of the training data, and apply the resulting fit onto the withheld 
    # testing data. cumsum() turns the predicted annual changes into level shifts.
    fit <- 
      lm(d_exp ~ 0 + d_nmmp, data = training(SPLIT)) %>% 
      predict(testing(SPLIT)) %>% 
      cumsum()
    
    # Collect historic as well as projected data conditional on the relevant projection
    # horizon 'k', calculate the horizon-specific forecast error along parameter function
    # 'PAR$err' and record the corresponding Trendpunkt combination.
    TEMP <- 
      tibble(k    = seq_len(PAR$out),
             obs  = testing(SPLIT)$exp_tot[seq_len(PAR$out)], 
             pred = (last(training(SPLIT)$exp_tot) + fit)[seq_len(PAR$out)]) %>% 
      mutate(err = PAR$err(pred, obs),
             ran = rep(paste0(RANGE_SEL, collapse = ""), nrow(.)))
    
    # Discount the forecast errors along the approximate expected annual return of the AHV 
    # fonds (inflation plus two percentage points, in line with historic returns).
    DISCOUNT <- INFLATION %>% 
      filter(year %in% testing(SPLIT)$year) %>% 
      mutate(k = seq_len(PAR$out), df = cumprod(1 / (inf / first(inf) + .02))) %>% 
      select(k, df)
    
    TEMP <- TEMP %>% 
      left_join(DISCOUNT, by = "k") %>% 
      mutate(err = err * df) %>% 
      select(- df)
    
    # Store the forecast performance for the relevant Trendpunkt combination and data
    # window selected by 'w'.
    EVAL[[w]] <- TEMP
  }
  
  # Save all runs across data windows for the relevant Trendpunkt combination.
  PER_LIST[[z]] <- EVAL
}

# Select the initial years for the upcoming trend extrapolations of the
# explanatory variables (pension stocks abroad, mean relative pension levels
# both within Switzerland and abroad, pension sum top-up) according to the
# minimal global error criterion regarding the out-of-sample total AHV
# expenditure projections. Parameter function 'PAR$glo' determines how the
# horizon-specific forecast errors are consolidated.
best_row <-
  bind_rows(PER_LIST) %>%
  # Average k-step errors across runs according to aggregator 'PAR$glo'.
  summarize(err = PAR$glo(err), .by = "ran") %>%
  pull(err) %>%
  which.min()

# Convert the winning Trendpunkte (numbers of years) into initial calendar years.
RANGE <- max(TRAIN_DATA$year) + 1 - slice(RANGE_PAR, best_row)

# Persist the selected Trendpunkt combination (expressed as initial years) for
# the subsequent production of the reference projection as well as the
# uncertainty bands (see script '6_uncertainty_bands.R').
save(RANGE, file = "data/output/RANGE.RData")
