#' @title Canonical Correlation Analysis and Panel VAR
#' @name cca_pvar
#' @description Functions for PCA-preprocessed canonical correlation analysis (CCA) and panel/aggregated vector autoregression (VAR).
NULL

#' Run PCA-Preprocessed CCA
#'
#' Performs canonical correlation analysis (\code{stats::cancor}) on
#' PCA-reduced, standardized price matrices. Despite the function name, no
#' sparsity penalty is applied.
#'
#' @param X_matrix Numeric matrix (or object coercible with \code{as.matrix})
#'   of direct prices; rows are observations, columns are variables.
#' @param Y_matrix Numeric matrix of production prices with the same number
#'   of rows as \code{X_matrix}.
#' @param n_components Number of canonical components to extract. Default 3.
#'   Capped at the number of canonical pairs \code{cancor} can return.
#' @param variance_threshold Cumulative variance threshold for PCA. Default 0.90.
#' @param min_pcs Minimum number of PCs to retain. Default 8.
#' @param max_pcs Maximum number of PCs to retain. Default 12.
#'
#' @return A list containing:
#' \describe{
#'   \item{method}{Description of the method used}
#'   \item{correlations}{Canonical correlations}
#'   \item{sum_r_squared}{Sum of the squared canonical correlations}
#'   \item{U_loadings}{X canonical coefficients on the standardized PC scores}
#'   \item{V_loadings}{Y canonical coefficients on the standardized PC scores}
#'   \item{W_X_original}{X coefficients on the standardized original variables}
#'   \item{W_Y_original}{Y coefficients on the standardized original variables}
#'   \item{n_pcs_x}{Number of PCs used for X}
#'   \item{n_pcs_y}{Number of PCs used for Y}
#'   \item{pca_x}{The \code{prcomp} fit for X}
#'   \item{pca_y}{The \code{prcomp} fit for Y}
#' }
#'
#' @details
#' Rows with any non-finite value (NA, NaN, Inf) in either matrix are dropped
#' with a warning; at least 2 complete rows are required. Each matrix is then
#' centred, scaled and reduced with \code{stats::prcomp}, the number of PCs is
#' chosen by \code{choose_n_pcs}, and the retained PC scores are rescaled to
#' unit variance before \code{stats::cancor} is applied.
#'
#' \code{W_X_original} is \code{rotation \%*\% diag(1 / sdev) \%*\% U_loadings}
#' (and likewise for Y). This undoes the rescaling of the PC scores, so
#' \code{scale(X_matrix) \%*\% W_X_original} reproduces the X canonical
#' variates computed on the retained rows.
#'
#' @examples
#' set.seed(123)
#' n <- 50
#' p <- 20
#' X <- matrix(rnorm(n * p), n, p)
#' Y <- X %*% matrix(rnorm(p * 5), p, 5) + matrix(rnorm(n * 5, 0, 0.5), n, 5)
#' colnames(X) <- paste0("X", 1:p)
#' colnames(Y) <- paste0("Y", 1:5)
#'
#' result <- run_sparse_cca(X, Y, n_components = 2)
#' print(result$correlations)
#'
#' @export
run_sparse_cca <- function(X_matrix,
                           Y_matrix,
                           n_components = 3L,
                           variance_threshold = 0.90,
                           min_pcs = 8L,
                           max_pcs = 12L) {

    # as.matrix() returns matrices unchanged.
    X_matrix <- as.matrix(X_matrix)
    Y_matrix <- as.matrix(Y_matrix)

    if (nrow(X_matrix) != nrow(Y_matrix)) {
        stop("X_matrix and Y_matrix must have same number of rows.",
             call. = FALSE)
    }

    # is.finite() is FALSE for NA, NaN and +/-Inf, so this single test also
    # covers what complete.cases() would catch.
    combined <- cbind(X_matrix, Y_matrix)
    complete_rows <- rowSums(!is.finite(combined)) == 0L

    if (!all(complete_rows)) {
        n_removed <- sum(!complete_rows)
        warning(sprintf("Removed %d incomplete rows.", n_removed),
                call. = FALSE)
        X_matrix <- X_matrix[complete_rows, , drop = FALSE]
        Y_matrix <- Y_matrix[complete_rows, , drop = FALSE]
    }

    # prcomp(scale. = TRUE) needs a standard deviation, i.e. >= 2 rows.
    if (nrow(X_matrix) < 2L) {
        stop("At least 2 complete rows are required for CCA.", call. = FALSE)
    }

    pca_x <- stats::prcomp(X_matrix, center = TRUE, scale. = TRUE)
    pca_y <- stats::prcomp(Y_matrix, center = TRUE, scale. = TRUE)

    k_x <- choose_n_pcs(pca_x, min_pcs, max_pcs, variance_threshold)
    k_y <- choose_n_pcs(pca_y, min_pcs, max_pcs, variance_threshold)

    X_pc <- scale(as.matrix(pca_x$x[, seq_len(k_x), drop = FALSE]))
    Y_pc <- scale(as.matrix(pca_y$x[, seq_len(k_y), drop = FALSE]))

    cca_result <- stats::cancor(X_pc, Y_pc)

    # cancor() returns min(rank(X_pc), rank(Y_pc)) canonical pairs, which can
    # be fewer than k_x / k_y when the scores are numerically rank-deficient.
    n_comp_actual <- min(n_components, length(cca_result$cor))

    correlations <- cca_result$cor[seq_len(n_comp_actual)]

    U_load <- cca_result$xcoef[, seq_len(n_comp_actual), drop = FALSE]
    V_load <- cca_result$ycoef[, seq_len(n_comp_actual), drop = FALSE]

    # Map coefficients on the unit-variance PC scores back to the standardized
    # original variables. cancor() names coefficient rows after the PC columns
    # and may pivot/truncate them when rank-deficient, so align rows by name.
    # Dividing by the PC sdevs undoes the scale() applied above.
    to_original <- function(pca, coef) {
        pc_idx <- if (is.null(rownames(coef))) {
            seq_len(nrow(coef))
        } else {
            match(rownames(coef), colnames(pca$rotation))
        }
        pca$rotation[, pc_idx, drop = FALSE] %*% (coef / pca$sdev[pc_idx])
    }

    W_X_orig <- to_original(pca_x, U_load)
    W_Y_orig <- to_original(pca_y, V_load)

    list(
        method = "base::cancor (on PCA scores)",
        correlations = correlations,
        sum_r_squared = sum(correlations^2),
        U_loadings = U_load,
        V_loadings = V_load,
        W_X_original = W_X_orig,
        W_Y_original = W_Y_orig,
        n_pcs_x = k_x,
        n_pcs_y = k_y,
        pca_x = pca_x,
        pca_y = pca_y
    )
}


#' Choose Number of Principal Components
#'
#' Internal function to select how many PCs to retain. The choice is based on
#' a cumulative variance threshold, bounded by \code{min_k}/\code{max_k} and
#' by the number of PCs that actually carry variance.
#'
#' @param pca_obj Result from \code{stats::prcomp}.
#' @param min_k Minimum components.
#' @param max_k Maximum components.
#' @param threshold Cumulative proportion of variance to reach.
#'
#' @return Integer number of components. It is at least 1, prefers at least 2
#'   when that many informative PCs exist and \code{max_k} allows, and never
#'   exceeds the number of informative PCs in \code{pca_obj}.
#'
#' @keywords internal
choose_n_pcs <- function(pca_obj, min_k, max_k, threshold) {

    eigenvalues <- pca_obj$sdev^2
    cumvar <- cumsum(eigenvalues) / sum(eigenvalues)

    k_threshold <- which(cumvar >= threshold)[1L]
    if (is.na(k_threshold)) {
        k_threshold <- length(eigenvalues)
    }

    k_raw <- min(max(min_k, k_threshold), max_k)

    # PCs whose sdev is negligible relative to the largest carry only rounding
    # noise (e.g. the n-th PC of centred data when n <= p). Rescaling them to
    # unit variance downstream would manufacture spurious signal and
    # rank-deficient score matrices.
    tol <- max(pca_obj$sdev) * sqrt(.Machine$double.eps)
    n_informative <- sum(pca_obj$sdev > tol)

    k_avail <- min(nrow(pca_obj$x), ncol(pca_obj$x), n_informative)

    k_safe <- min(k_raw, k_avail)

    # Prefer at least 2 PCs, but never more than exist or than max_k allows.
    if (is.na(k_safe) || k_safe < 2L) {
        k_safe <- min(2L, max_k, k_avail, na.rm = TRUE)
    }

    as.integer(max(k_safe, 1L))
}


#' Extract Top CCA Loadings
#'
#' Extracts, for each canonical component, the variables with the largest
#' absolute coefficients in \code{W_X_original} and/or \code{W_Y_original}.
#'
#' @param cca_result Result from run_sparse_cca.
#' @param n_top Number of top variables to show per component. Default 5.
#'   Must be a single number >= 1; capped at the number of variables.
#' @param which_matrix Character. "both", "X", or "Y". Default "both".
#'
#' @return Data frame with columns \code{matrix}, \code{component},
#'   \code{variable} and \code{loading} (X components first, then Y), or
#'   \code{NULL} if no requested loading matrix with row names is available.
#'
#' @examples
#' set.seed(123)
#' n <- 50
#' p <- 20
#' X <- matrix(rnorm(n * p), n, p)
#' Y <- X %*% matrix(rnorm(p * 5), p, 5) + matrix(rnorm(n * 5, 0, 0.5), n, 5)
#' colnames(X) <- paste0("X", 1:p)
#' colnames(Y) <- paste0("Y", 1:5)
#'
#' cca_res <- run_sparse_cca(X, Y, n_components = 2)
#' top_loads <- extract_cca_loadings(cca_res)
#' print(top_loads)
#'
#' @export
extract_cca_loadings <- function(cca_result, n_top = 5L,
                                 which_matrix = c("both", "X", "Y")) {

    which_matrix <- match.arg(which_matrix)

    if (!is.numeric(n_top) || length(n_top) != 1L || is.na(n_top) ||
        n_top < 1) {
        stop("n_top must be a single number >= 1.", call. = FALSE)
    }

    results <- list()

    if (which_matrix %in% c("both", "X")) {
        results <- c(results, top_loadings_by_component(
            cca_result$W_X_original, "X", n_top
        ))
    }

    if (which_matrix %in% c("both", "Y")) {
        results <- c(results, top_loadings_by_component(
            cca_result$W_Y_original, "Y", n_top
        ))
    }

    if (length(results) == 0L) {
        return(NULL)
    }

    result_df <- do.call(rbind, results)
    rownames(result_df) <- NULL
    result_df
}


#' Top loadings per component for one coefficient matrix
#'
#' Returns one data frame per column of \code{W}, holding the \code{n_top}
#' rows with the largest absolute value. Returns an empty list when \code{W}
#' is NULL or has no row names (variables cannot be labelled).
#'
#' @param W Numeric matrix, variables in rows and components in columns.
#' @param label Value for the \code{matrix} column ("X" or "Y").
#' @param n_top Number of rows to keep per component.
#'
#' @return List of data frames.
#'
#' @keywords internal
#' @noRd
top_loadings_by_component <- function(W, label, n_top) {
    if (is.null(W) || is.null(rownames(W))) {
        return(list())
    }

    n_keep <- min(n_top, nrow(W))

    lapply(seq_len(ncol(W)), function(k) {
        loadings_k <- W[, k]
        top_idx <- order(abs(loadings_k), decreasing = TRUE)[seq_len(n_keep)]
        data.frame(
            matrix = label,
            component = k,
            variable = rownames(W)[top_idx],
            loading = loadings_k[top_idx],
            stringsAsFactors = FALSE
        )
    })
}


#' Fit Panel VAR Model
#'
#' Fits panel vector autoregressions of \code{log_production} and
#' \code{log_direct} with a first-difference transformation
#' (\code{panelvar::pvarfeols}) for each lag order \code{1..max_lags}. The
#' lag with the lowest approximate BIC (\code{compute_pvar_bic}) is kept.
#'
#' @param panel_data Data frame in panel format with columns \code{sector},
#'   \code{year}, \code{log_direct} and \code{log_production}.
#' @param max_lags Maximum lag order to consider (a single number >= 1).
#'   Default 2.
#' @param verbose Logical. Print progress (and per-lag failures). Default TRUE.
#'
#' @return A list containing:
#' \describe{
#'   \item{model}{The fitted panelvar model (NULL if every lag failed)}
#'   \item{best_lag}{Selected lag order (NA if every lag failed)}
#'   \item{bic_values}{BIC for each lag order tested; Inf where fitting failed}
#' }
#'
#' @details
#' NOTE(review): different lag orders may use slightly different effective
#' samples, so the BIC comparison across lags is approximate.
#'
#' @examples
#' \donttest{
#' if (requireNamespace("panelvar", quietly = TRUE)) {
#'   set.seed(123)
#'   panel <- data.frame(
#'     year = rep(2000:2019, 5),
#'     sector = rep(LETTERS[1:5], each = 20),
#'     log_direct = rnorm(100, 5, 0.5),
#'     log_production = rnorm(100, 5, 0.5)
#'   )
#'
#'   result <- fit_panel_var(panel)
#'   print(result$best_lag)
#' }
#' }
#'
#' @export
fit_panel_var <- function(panel_data, max_lags = 2L, verbose = TRUE) {

    # Validate first: with max_lags < 1 no model would be attempted and the
    # caller would get a misleading "all models failed" warning.
    if (!is.numeric(max_lags) || length(max_lags) != 1L ||
        is.na(max_lags) || max_lags < 1) {
        stop("max_lags must be a single number >= 1.", call. = FALSE)
    }
    max_lags <- as.integer(max_lags)

    check_package("panelvar", "panel VAR estimation")

    validate_panel_data(panel_data, require_log = TRUE)

    panel_var_data <- panel_data[, c("sector", "year", "log_direct", "log_production")]
    panel_var_data <- panel_var_data[order(panel_var_data$sector, panel_var_data$year), ]

    bic_values <- rep(Inf, max_lags)
    models <- vector("list", max_lags)
    fit_errors <- rep(NA_character_, max_lags)

    for (lag in seq_len(max_lags)) {

        if (verbose) {
            message(sprintf("Fitting PVAR with lag = %d...", lag))
        }

        # Return the condition object so the failure reason is not lost.
        model_fit <- tryCatch(
            panelvar::pvarfeols(
                dependent_vars = c("log_production", "log_direct"),
                lags = lag,
                transformation = "fd",
                data = panel_var_data,
                panel_identifier = c("sector", "year")
            ),
            error = function(e) e
        )

        if (inherits(model_fit, "error")) {
            fit_errors[lag] <- conditionMessage(model_fit)
            if (verbose) {
                message(sprintf("  lag %d failed: %s", lag, fit_errors[lag]))
            }
            next
        }

        models[[lag]] <- model_fit
        bic_values[lag] <- compute_pvar_bic(model_fit)
    }

    if (!any(is.finite(bic_values))) {
        failed <- which(!is.na(fit_errors))
        detail <- if (length(failed) > 0L) {
            sprintf(" First error (lag %d): %s", failed[1L], fit_errors[failed[1L]])
        } else {
            ""
        }
        warning("All PVAR models failed to fit.", detail, call. = FALSE)
        return(list(model = NULL, best_lag = NA_integer_, bic_values = bic_values))
    }

    best_lag <- which.min(bic_values)

    list(
        model = models[[best_lag]],
        best_lag = best_lag,
        bic_values = bic_values
    )
}


#' Compute BIC for Panel VAR Model
#'
#' Internal function to compute an approximate BIC,
#' \eqn{\log(\hat\sigma^2) + k \log(n) / n}. Here \eqn{\hat\sigma^2} is the
#' mean squared residual pooled over all equations, \eqn{n} is the number of
#' finite residuals and \eqn{k} is the number of coefficients. The pooled
#' variance ignores cross-equation covariance, unlike the log-determinant form
#' of a multivariate BIC.
#'
#' @param pvar_obj Fitted pvarfeols object (anything that \code{residuals()}
#'   and \code{coef()} can be called on).
#' @param fallback_k Parameter count used when \code{coef()} fails or reports
#'   no coefficients. Default 10.
#'
#' @return Numeric BIC value; \code{Inf} when no usable residuals exist.
#'
#' @keywords internal
compute_pvar_bic <- function(pvar_obj, fallback_k = 10L) {

    resids <- tryCatch(
        stats::residuals(pvar_obj),
        error = function(e) NULL
    )

    if (is.null(resids)) {
        return(Inf)
    }

    r_vec <- as.numeric(as.matrix(resids))
    r_vec <- r_vec[is.finite(r_vec)]

    n <- length(r_vec)
    if (n == 0L) {
        return(Inf)
    }

    s2 <- mean(r_vec^2)

    k <- tryCatch(
        length(stats::coef(pvar_obj)),
        error = function(e) NA_integer_
    )

    # A zero count means coef() found nothing (e.g. the default method on an
    # object without $coefficients). That would silently drop the penalty and
    # always favour the largest lag, so treat it like a failure.
    if (!is.finite(k) || k < 1L) {
        k <- fallback_k
    }

    log(s2) + (k * log(n)) / n
}


#' Fit Aggregated VAR Model
#'
#' Fits a VAR model on time-aggregated (mean across sectors) data.
#'
#' @param panel_data Data frame in panel format.
#' @param max_lags Maximum lag order (a single number >= 1). Default 6.
#'   Further capped at a quarter of the number of observations.
#' @param difference Logical. Apply first differencing. Default TRUE.
#'
#' @return A list containing:
#' \describe{
#'   \item{model}{The fitted vars::VAR model}
#'   \item{selected_lag}{Lag order of the fitted model}
#'   \item{irf}{List with bootstrapped impulse responses
#'     \code{direct_to_production} and \code{production_to_direct}, or NULL
#'     if they could not be computed}
#'   \item{fevd}{Forecast error variance decomposition, or NULL}
#'   \item{aggregated_data}{The aggregated time series used for fitting}
#' }
#' When fewer than 10 observations remain or fitting fails, \code{model},
#' \code{irf} and \code{fevd} are NULL, \code{selected_lag} is NA, and no
#' \code{aggregated_data} element is returned.
#'
#' @details
#' The lag order is taken from the first criterion reported by
#' \code{vars::VARselect} (AIC(n) in current versions of vars) and bounded by
#' the lag cap. If selection fails, \code{min(2, lag cap)} is used. If fitting
#' at the selected lag fails, a VAR(1) is tried. Series are named
#' \code{dx}/\code{dy} when differenced and \code{x}/\code{y} otherwise
#' (x = direct, y = production). IRFs use 12 steps ahead and 200 bootstrap
#' runs.
#'
#' @examples
#' \donttest{
#' if (requireNamespace("vars", quietly = TRUE)) {
#'   set.seed(123)
#'   panel <- data.frame(
#'     year = rep(2000:2019, 5),
#'     sector = rep(LETTERS[1:5], each = 20),
#'     log_direct = rnorm(100, 5, 0.5),
#'     log_production = rnorm(100, 5, 0.5)
#'   )
#'
#'   result <- fit_aggregated_var(panel)
#'   print(result$selected_lag)
#' }
#' }
#'
#' @export
fit_aggregated_var <- function(panel_data, max_lags = 6L, difference = TRUE) {

    if (!is.numeric(max_lags) || length(max_lags) != 1L ||
        is.na(max_lags) || max_lags < 1) {
        stop("max_lags must be a single number >= 1.", call. = FALSE)
    }

    check_package("vars", "VAR model estimation")

    validate_panel_data(panel_data, require_log = TRUE)

    agg_ts <- aggregate_to_timeseries(
        panel_data,
        vars = c("log_direct", "log_production")
    )

    agg_ts <- agg_ts[order(agg_ts$year), ]

    var_data <- data.frame(
        x = agg_ts$log_direct_mean,
        y = agg_ts$log_production_mean
    )

    if (difference) {
        var_data <- data.frame(
            dx = c(NA, diff(var_data$x)),
            dy = c(NA, diff(var_data$y))
        )
        # Drop the leading NA row introduced by differencing.
        var_data <- var_data[-1L, , drop = FALSE]
    }

    empty_result <- list(model = NULL, selected_lag = NA, irf = NULL, fevd = NULL)

    if (nrow(var_data) < 10L) {
        warning("Too few observations for VAR estimation.", call. = FALSE)
        return(empty_result)
    }

    lag_max <- min(max_lags, floor(nrow(var_data) / 4L))

    # Fallback lag when selection fails or is inconclusive; never above the cap.
    p_default <- as.integer(min(2L, lag_max))

    p_selected <- tryCatch(
        {
            sel_result <- vars::VARselect(var_data, lag.max = lag_max, type = "const")
            selection <- sel_result$selection
            p_candidates <- if (is.null(selection)) {
                numeric(0)
            } else {
                as.numeric(stats::na.omit(selection))
            }
            p <- if (length(p_candidates) > 0L) p_candidates[1L] else p_default
            if (!is.finite(p) || p < 1) {
                p <- p_default
            }
            min(p, lag_max)
        },
        error = function(e) {
            warning(sprintf("VAR lag selection failed; using p = %d.", p_default),
                    call. = FALSE)
            p_default
        }
    )

    var_fit <- tryCatch(
        vars::VAR(var_data, p = p_selected, type = "const"),
        error = function(e) {
            # Fall back to the most parsimonious model.
            tryCatch(
                vars::VAR(var_data, p = 1L, type = "const"),
                error = function(e2) NULL
            )
        }
    )

    if (is.null(var_fit)) {
        warning("VAR model fitting failed.", call. = FALSE)
        return(empty_result)
    }

    # Use the actual series names: they are dx/dy only when differenced.
    # Hard-coding "dx"/"dy" made the IRF fail whenever difference = FALSE.
    x_name <- colnames(var_data)[1L]
    y_name <- colnames(var_data)[2L]

    irf_result <- tryCatch(
        list(
            direct_to_production = vars::irf(var_fit, impulse = x_name, response = y_name,
                                             n.ahead = 12L, boot = TRUE, runs = 200L),
            production_to_direct = vars::irf(var_fit, impulse = y_name, response = x_name,
                                             n.ahead = 12L, boot = TRUE, runs = 200L)
        ),
        error = function(e) {
            warning("Impulse response computation failed: ", conditionMessage(e),
                    call. = FALSE)
            NULL
        }
    )

    fevd_result <- tryCatch(
        vars::fevd(var_fit, n.ahead = 12L),
        error = function(e) {
            warning("FEVD computation failed: ", conditionMessage(e), call. = FALSE)
            NULL
        }
    )

    list(
        model = var_fit,
        selected_lag = var_fit$p,
        irf = irf_result,
        fevd = fevd_result,
        aggregated_data = agg_ts
    )
}


#' Run Complete CCA and VAR Analysis
#'
#' Convenience function that runs the PCA-based CCA (\code{run_sparse_cca}).
#' When the optional packages are installed, it also runs the panel VAR
#' (panelvar), the aggregated VAR (vars) and the panel Granger tests (plm).
#'
#' @param direct_prices Data frame with direct prices.
#' @param production_prices Data frame with production prices.
#' @param panel_data Data frame in panel format.
#' @param cca_components Number of CCA components. Default 3.
#' @param verbose Logical. Print progress. Default TRUE.
#'
#' @return A list with elements \code{cca}, \code{pvar}, \code{agg_var} and
#'   \code{granger}. An optional stage is NULL when its package is not
#'   installed or when it fails (failures are reported with a warning).
#'
#' @examples
#' \donttest{
#' set.seed(123)
#' years <- 2000:2019
#' sectors <- LETTERS[1:5]
#'
#' direct <- data.frame(Year = years)
#' production <- data.frame(Year = years)
#' for (s in sectors) {
#'   direct[[s]] <- 100 + cumsum(rnorm(20, 2, 1))
#'   production[[s]] <- 102 + cumsum(rnorm(20, 2, 1))
#' }
#'
#' panel <- prepare_panel_data(direct, production)
#'
#' result <- run_cca_var_analysis(
#'   direct, production, panel,
#'   cca_components = 2
#' )
#' }
#'
#' @export
run_cca_var_analysis <- function(direct_prices,
                                 production_prices,
                                 panel_data,
                                 cca_components = 3L,
                                 verbose = TRUE) {

    # Run one optional stage. It is skipped (NULL) when `pkg` is missing, and
    # its progress is announced when verbose. Any error becomes a warning so
    # one failing stage does not abort the others. `expr` is a lazily
    # evaluated promise, so it only runs inside the tryCatch() below.
    run_optional_stage <- function(pkg, progress, failure, expr) {
        if (!requireNamespace(pkg, quietly = TRUE)) {
            return(NULL)
        }
        if (verbose) {
            message(progress)
        }
        tryCatch(
            expr,
            error = function(e) {
                warning(failure, ": ", conditionMessage(e), call. = FALSE)
                NULL
            }
        )
    }

    if (verbose) {
        message("Running sparse CCA analysis...")
    }

    matrices <- prepare_log_matrices(direct_prices, production_prices)

    cca_result <- run_sparse_cca(
        matrices$X_clean,
        matrices$Y_clean,
        n_components = cca_components
    )

    pvar_result <- run_optional_stage(
        "panelvar", "Fitting panel VAR...", "Panel VAR failed",
        fit_panel_var(panel_data, verbose = verbose)
    )

    agg_var_result <- run_optional_stage(
        "vars", "Fitting aggregated VAR...", "Aggregated VAR failed",
        fit_aggregated_var(panel_data)
    )

    granger_result <- run_optional_stage(
        "plm", "Running panel Granger tests...", "Panel Granger tests failed",
        panel_granger_test(panel_data)
    )

    list(
        cca = cca_result,
        pvar = pvar_result,
        agg_var = agg_var_result,
        granger = granger_result
    )
}
