##############################################################################
# OPTIMIZE TREND EXTRAPOLATION DATA RANGE VIA SLIDING WINDOW CROSSVALIDATION #
##############################################################################

# Calculate breakpoint for starting year selection. The initial year 2005 has been
# chosen to avoid contaminations due to the reference age increases for women along
# the 10th AHV reform. If no breakpoint is found, fall back to 0.
year_init <-
  ZAS %>%
  left_join(MINIMAL_PENSION, by = "year") %>%
  na.omit() %>%
  # Express total expenditure relative to the minimal pension.
  mutate(exp_tot = exp_tot / mp) %>%
  filter(year >= 2005) %>%
  breakpoints(diff(exp_tot) ~ 1, data = .) %>%
  {max(.$breakpoints, 0, na.rm = TRUE)}

# Restrict to viable training data: high scenario only, years from the detected
# breakpoint onwards.
T_DATA <- A_DATA %>%
  filter(scen == "H") %>%
  filter(year >= 2005 + (year_init - 1)) %>%
  select(-scen)

# Number of estimation points ('Trendpunkte') to consider for trend extrapolation:
# full cross of the candidate ranges for each trend component.
RANGE_PAR <- expand_grid(
  nau = PAR$pr[["tr"]],
  mau = PAR$pr[["tr"]],
  mch = PAR$pr[["tr"]],
  tot = PAR$pr[["tot"]]
)

# Collects the forecast performances of the respective 'range' fits, one entry
# per parameter combination.
PER_LIST <- list()

# Largest training-window length across all parameter sets; harmonizes the set
# of sliding training windows so every combination sees the same data spans.
tmax <- PAR$pr %>%
  bind_rows() %>%
  max()

# Cross-validation loop over trend ranges ('Trendpunkte' combinations), groups
# (sex x domestic/foreign x pension type) and sliding time windows.
for (z in seq_len(nrow(RANGE_PAR))) {
  
  # Extract relevant 'Trendpunkte' combination.
  RANGE_SEL <- dplyr::slice(RANGE_PAR, z)
  
  # List to store final results, one element per sliding window.
  EVAL <- list()

  for (w in seq_len(length(unique(T_DATA$year)) - (tmax + PAR$out) + 1)) {
    
    RES_LIST <- list()
    
    # Forecast horizon within each window: the last PAR$out observations are
    # treated as out-of-sample and imputed from the preceding trend.
    ind_fc <- (tmax + 1):(tmax + PAR$out)
    
      for (i in c("m", "f")) {
        for (j in c("ch", "au")) {
          for (k in c("alt", "kin", "wai", "wit")) {
          
            TEMP <- T_DATA %>% 
              filter(sex == i, dom == j, type == k) %>% 
              slice(w : (w + (tmax + PAR$out) - 1))
            
            # Mask the forecast horizon of the average pension level relative
            # to the minimal pension.
            TEMP$m[ind_fc] <- NA
            
            if (j == "ch") {
              
              # Impute native average pension level relative to minimal pension
              # via a drift-only model on the last 'mch' observed differences.
              fit_m <- 
                lm(diff(m) ~ 1, slice(TEMP, (tmax - RANGE_SEL$mch + 1) : tmax)) %>% 
                predict(slice(TEMP, ind_fc))
              
              # NOTE(review): the floor now applies to the whole projected path
              # (matching the 'au' branch below); previously only the anchor
              # value was floored, so projections could drift below zero.
              TEMP$m[ind_fc] <- 
                pmin(2, pmax(0, last(na.omit(TEMP)$m) + cumsum(fit_m)))
            }
            
            if (j == "au") {
              
              # Impute foreign average pension level relative to minimal
              # pension, bounded to [0, 2] minimal pensions.
              fit_m <- 
                lm(diff(m) ~ 1, slice(TEMP, (tmax - RANGE_SEL$mau + 1) : tmax)) %>% 
                predict(slice(TEMP, ind_fc))
              
              TEMP$m[ind_fc] <- 
                pmin(2, pmax(0, last(na.omit(TEMP)$m) + cumsum(fit_m)))
  
              # Impute foreign pension count over the same forecast indices
              # (explicit range instead of which(is.na()), which could silently
              # pick up pre-existing NA rows outside the horizon).
              TEMP$n[ind_fc] <- NA
              
              fit_n <- 
                lm(diff(n) ~ 1, slice(TEMP, (tmax - RANGE_SEL$nau + 1) : tmax)) %>% 
                predict(slice(TEMP, ind_fc))
              
              TEMP$n[ind_fc] <- 
                pmax(0, last(na.omit(TEMP)$n) + cumsum(fit_n))
            }
    
            RES_LIST[[paste0(i, j, k)]] <- TEMP
          }
        }
      }

    # Aggregate the group-level forecasts to the total and split each window
    # into training (observed) and testing (forecast horizon) parts.
    SPLIT <- 
      rsample::initial_time_split(
        bind_rows(RES_LIST) %>%
          summarize(m = weighted.mean(m, n), n = sum(n), .by = c("mp", "year")) %>% 
          slice((tmax - RANGE_SEL$tot + 1) : (tmax + PAR$out)) %>%
          left_join(ZAS, by = "year") %>%
          mutate(d_exp  = exp_tot - lag(exp_tot),
                 d_nmmp = n * m * mp - lag(n * m * mp)),
        prop = RANGE_SEL$tot / (RANGE_SEL$tot + PAR$out)
      )

    # Regress expenditure changes on modeled pension-volume changes (no
    # intercept) and accumulate the predicted changes over the horizon.
    fit <- 
      lm(d_exp ~ 0 + d_nmmp, data = training(SPLIT)) %>% 
      predict(testing(SPLIT)) %>% 
      cumsum()
  
    # Out-of-sample errors per projection step 'k'. The run id 'ran' uses a
    # dash separator so distinct parameter combinations stay distinct (plain
    # concatenation could collide, e.g. c(1, 12) vs c(11, 2) -> "112").
    TEMP <- 
      tibble(k    = seq_len(PAR$out),
             obs  = testing(SPLIT)$exp_tot[seq_len(PAR$out)], 
             pred = (last(training(SPLIT)$exp_tot) + fit)[seq_len(PAR$out)]) %>% 
      mutate(err = PAR$err(pred, obs),
             ran = paste0(RANGE_SEL, collapse = "-"))
    
    # Discount forecast errors. Variable 'k' refers to the number of years
    # projected into the future.
    DISCOUNT <- INFLATION %>% 
      filter(year %in% testing(SPLIT)$year) %>% 
      mutate(k = seq_len(PAR$out), df = cumprod(1 / (inf / first(inf) + .02))) %>% 
      select(k, df)
    
    TEMP <- TEMP %>% 
      left_join(DISCOUNT, by = "k") %>% 
      mutate(err = err * df) %>% 
      select(- df)
    
    EVAL[[w]] <- TEMP
  }
  
  # Save run for specific 'Trendpunkte' combination.
  PER_LIST[[z]] <- EVAL
}

# Derive initial years for the respective upcoming trend extrapolations according
# to the minimal global error criterion regarding the out-of-sample total AHV
# expenditure projections. The which.min() index maps back onto RANGE_PAR rows
# because 'ran' ids first appear in RANGE_PAR row order.
RANGE <-
  max(T_DATA$year) + 1 -
  slice(
    RANGE_PAR,
    bind_rows(PER_LIST) %>%
      # Average k-step errors across runs according to aggregator 'PAR$glo'.
      summarize(err = PAR$glo(err), .by = "ran") %>%
      pull(err) %>%
      which.min()
  )

save(RANGE, file = "data/output/RANGE.RData")
