################################################################################
# TUNE TREND SAMPLE RANGE VIA SLIDING-WINDOW CROSS-VALIDATION                  #
################################################################################

# Speed up computations through masking of functions by C++ equivalents and 
# multithreading.
# NOTE: mask = "all" also replaces base diff() with collapse::fdiff(), which
# preserves the input length (leading NA) — several pipelines below, e.g.
# mutate(d_nmmp = diff(...)), depend on this length-preserving behavior.
set_collapse(mask = "all", nthreads = 4)

# Calculate the breakpoint for starting-year selection. The initial year 2005
# has been chosen to avoid contamination from the stepwise increases in the
# women's reference (retirement) age introduced by the 10th AHV reform.
bp <- local({
  # Total expenditures per pension point, restricted to post-reform years.
  series <- zas %>%
    left_join(mp, by = "year") %>%
    na.omit() %>%
    mutate(exp_tot = exp_tot / mp) %>%
    filter(year >= 2005)
  # Single structural break in the differenced series; fall back to 0 when
  # no breakpoint is detected (NA).
  fit <- breakpoints(diff(exp_tot) ~ 1, data = series, breaks = 1)
  max(fit$breakpoints, 0, na.rm = TRUE)
})

# Restrict to viable training data: the "H" scenario, starting at the year
# implied by the breakpoint (2005 + bp - 1 = 2004 + bp).
t_dat <- a_dat %>%
  filter(scen == "H") %>%
  filter(year >= 2004 + bp) %>%
  select(!scen)

# Number of estimation points to consider for trend extrapolation: the full
# cross of candidate window lengths per imputed series.
ran <-
  expand_grid(
    nau = par$pr[["tr"]],
    mau = par$pr[["tr"]],
    mch = par$pr[["tr"]],
    tot = par$pr[["tot"]]
  )

# List to store forecast performances of respective 'range' fits; preallocated
# with one slot per parameter combination so the loop below fills it instead
# of growing it.
el <- vector("list", nrow(ran))

# Constant to harmonize the number of training windows across parameter sets.
# NOTE(review): the nested max() relies on par$pr being a structure whose
# elements max() accepts (e.g. a data frame of numeric columns) — confirm.
tmax <- max(max(par$pr))

# Cross-validation loop over trend ranges, groups and time windows.
# For each candidate parameter row of 'ran' (z) and each sliding time window
# (w), the hold-out years of every sex x domicile x pension-type cell are
# imputed via intercept-only regressions on first differences, the cells are
# aggregated to a national series, and the out-of-sample forecast is scored.
# NOTE: set_collapse(mask = "all") above masks base diff() with
# collapse::fdiff(), which keeps the vector length (leading NA) — the
# lm()/mutate() calls on diff(...) below depend on that.
for (z in seq_len(nrow(ran))) {  # seq_len() avoids a bogus 1:0 sequence

  # Current combination of trend-range parameters (one row of 'ran').
  rant <-
    slice(ran, z)

  # Per-window forecast-error tables for this parameter combination.
  eval <- list()

  # Slide a window of tmax training years plus par$out hold-out years across
  # the available years; seq_len() guards the degenerate zero-window case.
  for (w in seq_len(length(unique(t_dat$year)) - (tmax + par$out) + 1)) {

    rl <- list()

    for (i in c("m", "f")) {                      # sex
      for (j in c("ch", "au")) {                  # domicile: native / foreign
        for (k in c("alt", "kin", "wai", "wit")) {  # pension type

          # Extract the current window for this cell.
          temp <- t_dat %>%
            filter(sex == i, dom == j, type == k) %>%
            slice(w : (w + (tmax + par$out) - 1))

          # Mask the hold-out years of the average pension level 'm'.
          ind_m <- (tmax + 1):(tmax + par$out)
          temp$m[ind_m] <- NA

          if (j == "ch") {

            # Impute native average pension level: extrapolate the mean first
            # difference estimated over the last rant$mch training years,
            # floored at zero.
            fit_m <-
              lm(I(diff(m)) ~ 1, slice(temp, (tmax - rant$mch + 1):tmax)) %>%
              predict(slice(temp, ind_m))

            temp$m[ind_m] <-
              pmax(0, last(na.omit(temp)$m) + cumsum(fit_m))
          }

          if (j == "au") {

            # Impute foreign average pension level analogously (window
            # rant$mau), additionally capped at 2.
            fit_m <-
              lm(I(diff(m)) ~ 1, slice(temp, (tmax - rant$mau + 1):tmax)) %>%
              predict(slice(temp, ind_m))

            temp$m[ind_m] <-
              pmin(2, pmax(0, last(na.omit(temp)$m) + cumsum(fit_m)))

            # Impute foreign pension count over the rant$nau-year window,
            # floored at zero.
            temp$n[(tmax + 1):(tmax + par$out)] <- NA
            ind_n <- which(is.na(temp$n))

            fit_n <-
              lm(I(diff(n)) ~ 1, slice(temp, (tmax - rant$nau + 1):tmax)) %>%
              predict(slice(temp, ind_n))

            temp$n[ind_n] <-
              pmax(0, last(na.omit(temp)$n) + cumsum(fit_n))
          }

          rl[[paste0(i, j, k)]] <- temp
        }
      }
    }

    # Aggregate all cells to a national series, keep the last rant$tot
    # training years plus the hold-out years, and split chronologically.
    split <-
      initial_time_split(
        bind_rows(rl) %>%
          dplyr::summarise(m = weighted.mean(m, n), n = sum(n), .by = c("year", "mp")) %>%
          slice((tmax - rant$tot + 1):(tmax + par$out)) %>%
          left_join(zas, by = "year") %>%
          rename(exp_p = exp_tot) %>%
          mutate(d_nmmp = diff(n * m * mp)),
        prop = rant$tot / (rant$tot + par$out))

    # Regress expenditure differences on the differenced volume proxy and
    # cumulate the predicted differences into level forecasts.
    fit <-
      lm(I(diff(exp_p)) ~ 0 + d_nmmp, data = training(split)) %>%
      predict(testing(split)) %>%
      cumsum()

    # k-step-ahead observed vs. predicted levels, scored with par$err and
    # tagged with the flattened parameter combination.
    temp <-
      tibble(k    = 1:par$out,
             obs  = testing(split)$exp_p[1:par$out],
             pred = (last(training(split)$exp_p) + fit)[1:par$out]) %>%
      mutate(err = par$err(pred, obs),
             ran = rep(paste0(rant, collapse = ""), nrow(.)))

    # Discount forecast errors with an inflation-based factor so near-term
    # errors weigh more.
    # NOTE(review): 'inf' inside mutate() is expected to resolve to a column
    # of the 'inf' tibble (not the tibble itself) — confirm the column name.
    d <- inf %>%
      filter(year %in% testing(split)$year) %>%
      mutate(k = 1:par$out, df = cumprod(1 / (inf / first(inf) + .02))) %>%
      select(k, df)

    temp <- temp %>%
      left_join(d, by = "k") %>%
      mutate(err = err * df) %>%
      select(- df)

    eval[[w]] <- temp
  }

  # Save run for specific 'range' combination.
  el[[z]] <- eval
}

# Arrange and aggregate k-step errors of different trend estimation sample
# sizes, and select the 'range' combination with minimal average error.
# NOTE: %>% has higher precedence than binary '-', so this expression reads as
#   max(t_dat$year) + 1 - as.numeric(slice(ran, <index of minimal error>)).
# The selected one-row slice is coerced column-wise by as.numeric() to the four
# window lengths (nau, mau, mch, tot), so 'range' holds the first estimation
# year of each trend sample.
# NOTE(review): which.min() is applied to a one-row summary data frame and the
# chosen column index is used as a row index into 'ran' — this relies on
# pivot_wider() preserving the loop order of 'el'; confirm. Also note that
# 'range' masks base::range() for the remainder of the session.
range <- 
  max(t_dat$year) + 1 - 
  slice(ran, 
    bind_rows(el) %>%
    dplyr::summarize(err = par$glo(err), .by = c("ran", "k")) %>%
    pivot_wider(names_from = ran, values_from = err) %>%
    dplyr::summarize(across(- 1, par$glo)) %>%
    which.min()) %>% as.numeric()

# Persist the selected starting years for downstream scripts.
save(range, file = "data/range.RData")
