###############################################################
# TUNE TREND SAMPLE RANGE VIA SLIDING WINDOW CROSS-VALIDATION #
###############################################################

# Mask base/dplyr functions with their faster C++ 'collapse' equivalents
# and enable multithreaded computation on four threads.
set_collapse(nthreads = 4, mask = "all")

# Restrict to viable training data: historical scenario only, from the
# first training year onwards; drop the now-constant scenario column.
t_dat <- a_dat %>%
  filter(scen == "H", year >= par$sj) %>%
  select(-scen)

# Candidate numbers of estimation points for trend extrapolation, crossed
# into the full grid of range combinations to cross-validate.
tr_cand <- par$pr[["tr"]]
ran <- expand.grid(
  nau = tr_cand,
  mau = tr_cand,
  mch = tr_cand,
  tot = par$pr[["tot"]]
)

# Preallocated container for the forecast performances of the respective
# 'range' fits (avoids repeatedly growing a list inside the loop).
el <- vector("list", nrow(ran))

# Cross-validation loop over trend ranges, groups and time windows: for
# every candidate range combination (row of 'ran'), slide a window over
# the training years, mask the last 'par$out' years of each window,
# re-impute them from drift models fitted on the preceding years, and
# record the resulting k-step-ahead forecast errors.
for (z in 1:nrow(ran)) {
  
  # Current range combination, augmented with the largest of its four
  # ranges ('max') so every component model fits inside one window.
  rant <- 
    slice(ran, z) %>% 
    rowwise() %>% 
    dplyr::mutate(max = max(c_across(nau:tot)))
  
  # Collects one evaluation tibble per sliding window.
  eval <- list()
  
  # One window per feasible starting position; each window spans
  # rant$max fitting years plus par$out held-out forecast years.
  for (w in 1:(length(unique(t_dat$year)) - (rant$max + par$out) + 1)) {
    
    rl <- list()
    
      # Impute each sex (m/f) x domain (ch/au) x pension-type cell
      # separately.
      for (i in c("m", "f")) {
        for (j in c("ch", "au")) {
          for (k in c("alt", "kin", "wai", "wit")) {
          
            # Rows of the current cell restricted to the current window.
            temp <- 
              filter(t_dat, sex == i, dom == j, type == k) %>% 
              slice(w : (w + (rant$max + par$out) - 1))
            
            # Mask the forecast horizon of the average pension level 'm'
            # so it can be re-imputed out-of-sample below.
            indm <- 
              (rant$max + 1):(rant$max + par$out) 
            temp$m[indm] <- NA
            
            if (j == "ch") {
              
              # Native ('ch') average pension level: intercept-only model
              # on first differences over the last rant$mch fitting years,
              # i.e. a random walk with drift.
              fitm <- 
                lm(I(diff(m)) ~ 1, slice(temp, (rant$max - rant$mch + 1):rant$max)) %>% 
                predict(slice(temp, indm))
              
              # Extrapolate from the last observed level, floored at 0.
              # NOTE(review): na.omit(temp) drops whole rows with any NA;
              # this yields the last observed 'm' only if no other column
              # is NA in the fitting years -- confirm.
              temp$m[indm] <- 
                pmax(0, last(na.omit(temp)$m) + cumsum(fitm))
            }
            
            if (j == "au") {
              
              # Impute foreign average pension level (drift over the last
              # rant$mau fitting years).
              fitm <- 
                lm(I(diff(m)) ~ 1, slice(temp, (rant$max - rant$mau + 1):rant$max)) %>% 
                predict(slice(temp, indm))
              
              temp$m[indm] <- 
                pmax(0, last(na.omit(temp)$m) + cumsum(fitm))
  
              # Impute foreign pension count (drift over the last
              # rant$nau fitting years).
              temp$n[(rant$max + 1):(rant$max + par$out)] <- NA
              indn <- which(is.na(temp$n))
              
              fitn <- 
                lm(I(diff(n)) ~ 1, slice(temp, (rant$max - rant$nau + 1):rant$max)) %>% 
                predict(slice(temp, indn))
              
              temp$n[indn] <- 
                pmax(0, last(na.omit(temp)$n) + cumsum(fitn))
            }
    
            rl[[paste0(i, j, k)]] <- 
              temp
          }
        }
      }

    # Aggregate all cells to totals per year and 'mp', restrict to the
    # last rant$tot fitting years plus the forecast horizon, and split
    # into training (fitting years) and testing (horizon) parts.
    # NOTE(review): diff() returns one element fewer than there are rows;
    # confirm that mutate() here (possibly the collapse-masked version)
    # aligns 'd_nmmp' as intended.
    split <- 
      initial_time_split(
        bind_rows(rl) %>%
          dplyr::summarise(m = weighted.mean(m, n), n = sum(n), .by = c("year", "mp")) %>% 
          slice((rant$max - rant$tot + 1):(rant$max + par$out)) %>%
          left_join(zas, by = "year") %>% 
          rename(exp_p = exp_tot) %>% 
          mutate(d_nmmp = diff(n * m * mp)), 
        prop = rant$tot / (rant$tot + par$out))

    # Regress yearly changes in total pension expenditure on changes in
    # n * m * mp (no intercept) and accumulate the predicted changes into
    # a level path over the forecast horizon.
    fit <- 
      lm(I(diff(exp_p)) ~ 0 + d_nmmp, data = training(split)) %>% 
      predict(testing(split)) %>% 
      cumsum()
  
    # Observed vs. predicted expenditure for forecast steps k = 1..out.
    temp <- 
      tibble(k    = 1:par$out,
             obs  = testing(split)$exp_p[1:par$out], 
             pred = (last(training(split)$exp_p) + fit)[1:par$out])
    
    eval[[w]] <- 
      mutate(temp,
             # Evaluate predictions along within k-step error metric set in 'par'.
             err = par$err(pred, obs),
             # Label rows with the current range combination, e.g. "551010".
             ran = rep(paste0(select(rant, - max), collapse = ""), nrow(temp)))
  }
  
  # Save run for specific 'range' combination.
  el[[z]] <- eval
}

# Arrange and aggregate k-step errors of different trend estimation sample
# sizes, and select the 'range' combination with minimal average error:
# errors are first aggregated over windows per (ran, k) via par$glo, then
# over k per combination (across(- 1, ...) drops the k column), and
# which.min() returns the index of the best combination's column.
# NOTE(review): slice(ran, <column index>) assumes pivot_wider() orders
# its columns exactly like the rows of 'ran' (first appearance order in
# the loop above) -- confirm. The selected year counts are converted into
# start years via max(year) + 1 - count. Note that this 'range' object
# shadows base::range() for the remainder of the session.
range <- 
  max(t_dat$year) + 1 - 
  slice(ran, 
    bind_rows(el) %>%
    dplyr::summarize(err = par$glo(err) * 100, .by = c("ran", "k")) %>%
    pivot_wider(names_from = ran, values_from = err) %>%
    dplyr::summarize(across(- 1, par$glo)) %>%
    which.min()) %>% as.numeric()

# Persist the tuned start years for downstream scripts.
save(range, file = "data/range.RData")

# Name of the winning 'range' combination (computed once instead of
# inline, mirroring the selection pipeline used for 'range' above).
best <- 
  bind_rows(el) %>%
  dplyr::summarize(err = par$glo(err) * 100, .by = c("ran", "k")) %>%
  pivot_wider(names_from = ran, values_from = err) %>%
  dplyr::summarize(across(- 1, par$glo)) %>% 
  which.min() %>% 
  names()

# Per-window k-step errors of the winning combination, labeled by the
# starting year of each sliding window for plotting.
# NOTE(review): parse_number() coerces the combination name to a number
# before comparing against the character 'ran' column -- this relies on
# no leading zeros in the combination strings; confirm.
x <- 
  bind_rows(el) %>% 
  filter(ran == parse_number(best)) %>% 
  mutate(err = err * 100,
         # Each window contributes par$out k-step errors (generalizes the
         # previously hard-coded block size of 10).
         year = rep(par$sj:(par$sj + nrow(.) / par$out - 1), each = par$out))

# Visualize k-step errors of the winning range combination, colored by
# sliding-window start year.
ggplot(x) +
  geom_shadowpoint(aes(x = as.factor(k), y = err, col = as.factor(year)))
