################################################################################
# TUNE TREND SAMPLE RANGE VIA SLIDING-WINDOW CROSS-VALIDATION                  #
################################################################################

# Speed up computations: mask core tidyverse verbs with their C++ 'collapse'
# equivalents and enable multithreading.
set_collapse(nthreads = 4, mask = "all")

# Calculate breakpoint for starting year selection: deflate total expenditure
# by 'mp', restrict to 2005 onward, and locate a single structural break in
# the first differences.
bp <- breakpoints(
  diff(exp_tot) ~ 1,
  data = left_join(zas, mp, by = "year") %>%
    na.omit() %>%
    mutate(exp_tot = exp_tot / mp) %>%
    filter(year >= 2005),
  breaks = 1
)$breakpoints

# Restrict to viable training data: scenario "H" from the break year onward,
# dropping the now-constant scenario column.
t_dat <- a_dat %>%
  filter(scen == "H", year >= 2005 + (bp - 1)) %>%
  select(- scen)

# Set discount factors for error weighting: 'k' counts years from the first
# observation; keep the forecast horizon 1..par$out with its factor 'dc'.
d <- inf %>%
  mutate(k = year - first(year)) %>%
  filter(k %in% 1:par$out) %>%
  select(k, dc)

# Candidate numbers of estimation points for trend extrapolation — one row
# per combination of the three per-series ranges and the total range.
ran <- expand.grid(
  nau = par$pr[["tr"]],
  mau = par$pr[["tr"]],
  mch = par$pr[["tr"]],
  tot = par$pr[["tot"]]
)

# List to store forecast performances of respective 'range' fits.
el <- list()

# Constant to harmonize the number of training windows across parameter sets.
tmax <- max(max(par$pr))

# Cross-validation loop over trend ranges, groups and time windows.
# NOTE: because of set_collapse(mask = "all"), the unprefixed verbs below
# (filter, slice, mutate, ...) are the collapse versions, not dplyr's.
for (z in seq_len(nrow(ran))) {
  
  # Current candidate combination of trend estimation ranges.
  rant <- 
    slice(ran, z)
  
  # Per-window error tibbles ('win_eval' avoids shadowing base::eval).
  win_eval <- list()

  # Slide a window of tmax training + par$out holdout years over the data.
  for (w in seq_len(length(unique(t_dat$year)) - (tmax + par$out) + 1)) {
    
    rl <- list()
    
      for (i in c("m", "f")) {
        for (j in c("ch", "au")) {
          for (k in c("alt", "kin", "wai", "wit")) {
          
            # Extract this group's slice of the current time window.
            temp <- 
              filter(t_dat, sex == i, dom == j, type == k) %>% 
              slice(w : (w + (tmax + par$out) - 1))
            
            # Holdout positions; mask the average pension level there.
            indm <- 
              (tmax + 1):(tmax + par$out) 
            temp$m[indm] <- NA
            
            if (j == "ch") {
              
              # Impute native average pension level: extrapolate the mean
              # first difference of the last 'mch' training years.
              fitm <- 
                lm(I(diff(m)) ~ 1, slice(temp, (tmax - rant$mch + 1):tmax)) %>% 
                predict(slice(temp, indm))
              
              temp$m[indm] <- 
                pmax(0, last(na.omit(temp)$m) + cumsum(fitm))
            }
            
            if (j == "au") {
              
              # Impute foreign average pension level (last 'mau' years).
              fitm <- 
                lm(I(diff(m)) ~ 1, slice(temp, (tmax - rant$mau + 1):tmax)) %>% 
                predict(slice(temp, indm))
              
              temp$m[indm] <- 
                pmax(0, last(na.omit(temp)$m) + cumsum(fitm))
  
              # Impute foreign pension count over the same holdout years.
              temp$n[indm] <- NA
              # which(is.na(...)) also picks up any pre-existing NA counts.
              indn <- which(is.na(temp$n))
              
              fitn <- 
                lm(I(diff(n)) ~ 1, slice(temp, (tmax - rant$nau + 1):tmax)) %>% 
                predict(slice(temp, indn))
              
              temp$n[indn] <- 
                pmax(0, last(na.omit(temp)$n) + cumsum(fitn))
            }
    
            rl[[paste0(i, j, k)]] <- 
              temp
          }
        }
      }

    # Aggregate groups to totals, attach observed expenditure, and split
    # into training (rant$tot years) vs. holdout (par$out years).
    # 'ts_split' avoids shadowing base::split.
    ts_split <- 
      initial_time_split(
        bind_rows(rl) %>%
          dplyr::summarise(m = weighted.mean(m, n), n = sum(n), .by = c("year", "mp")) %>% 
          slice((tmax - rant$tot + 1):(tmax + par$out)) %>%
          left_join(zas, by = "year") %>% 
          rename(exp_p = exp_tot) %>% 
          mutate(d_nmmp = diff(n * m * mp)), 
        prop = rant$tot / (rant$tot + par$out))

    # Regress expenditure changes on the imputed driver and accumulate the
    # predicted differences into a level forecast path.
    fit <- 
      lm(I(diff(exp_p)) ~ 0 + d_nmmp, data = training(ts_split)) %>% 
      predict(testing(ts_split)) %>% 
      cumsum()
  
    temp <- 
      tibble(k    = 1:par$out,
             obs  = testing(ts_split)$exp_p[1:par$out], 
             pred = (last(training(ts_split)$exp_p) + fit)[1:par$out])
    
    win_eval[[w]] <- 
      mutate(temp,
             # Evaluate predictions with the k-step error metric set in 'par'.
             err = par$err(pred, obs),
             # Concatenated range label identifies this parameter set later.
             ran = rep(paste0(rant, collapse = ""), nrow(temp)))
  }
  
  # Save run for specific 'range' combination.
  el[[z]] <- win_eval
}

# Arrange and aggregate k-step errors of different trend estimation sample
# sizes: pool all windows, aggregate per range/horizon with 'par$glo',
# discount by 'dc', and collapse across horizons — one column per range
# combination. Computed once and reused below (was previously duplicated).
err_agg <- 
  bind_rows(el) %>%
  dplyr::summarize(err = par$glo(err), .by = c("ran", "k")) %>%
  left_join(d, by = "k") %>% 
  mutate(err = err * dc) %>% 
  select(- dc) %>% 
  pivot_wider(names_from = ran, values_from = err) %>%
  dplyr::summarize(across(- 1, par$glo))

# Index of the minimal-error column; its names attribute carries the
# concatenated range label.
best <- which.min(err_agg)

# Translate the winning row of 'ran' into first estimation years.
# NOTE(review): 'range' shadows base::range, but the name is part of the
# saved file's contract, so it must stay.
range <- 
  max(t_dat$year) + 1 - as.numeric(slice(ran, best))

save(range, file = "data/range.RData")

# Errors of the winning range combination: scale to millions, tag each block
# of par$out horizon rows with the first year of its training window, and
# apply the discount factors (generalizes the former hard-coded 10, which
# held only when par$out == 10).
x <- 
  bind_rows(el) %>% 
  filter(ran == parse_number(names(best))) %>% 
  mutate(err = err / 1e6,
         year = rep(min(t_dat$year):(min(t_dat$year) + nrow(.) / par$out - 1), 
                    each = par$out)) %>%     
  left_join(d, by = "k") %>% 
  mutate(err = err * dc) %>% 
  select(- dc)

# Visualize discounted k-step errors, colored by training window start year.
ggplot(x, aes(x = as.factor(k), y = err, col = as.factor(year))) +
  geom_shadowpoint()


