intro_to_ml/19_nystroem_approximation_c...

71 lines
2.0 KiB
R
Raw Normal View History

2023-01-16 18:17:52 -05:00
# Nystroem Approximation Kernel Ridge Regression
#
# Draws a random subset of rows of X (the landmarks S), factors the landmark
# Gaussian kernel matrix K11 = K(S, S) by SVD, maps every observation to the
# Nystroem features Phi = K(X, S) %*% U %*% diag(1 / sqrt(d)), and fits
# ridge() on (Phi, y).
#
# Arguments:
#   X             predictors; coerced with as.matrix() and standardised with
#                 scale()
#   y             response; standardised with scale()
#   sigma2        denominator of the Gaussian kernel exponent; defaults to
#                 ncol(X)
#   lambdas       passed unchanged to ridge()
#   nb.landmarks  number of landmark rows; defaults to 15 * n^(1/3), capped
#                 at the number of rows n
#
# Returns a list of class "nakr" with the centring/scaling attributes of X
# and y, the landmark matrix S, sigma2, the projection matrix W and the
# object returned by ridge().
#
# ridge() and gausskernel.nakr() must be available when this is called;
# ridge() is not defined in this section of the file.
nakr <- function(X, y, sigma2 = NULL, lambdas = NULL, nb.landmarks = NULL) {
  # Fixed seed so the landmark draw is reproducible.
  # NOTE(review): this also resets the caller's global RNG stream.
  set.seed(1123)

  X <- as.matrix(X)
  n <- nrow(X)
  p <- ncol(X)
  if (is.null(sigma2)) {
    sigma2 <- p
  }
  if (is.null(nb.landmarks)) {
    # Cap at n: sampling without replacement cannot draw more than n rows
    # (the uncapped default exceeds n whenever n < 59).
    nb.landmarks <- min(n, floor(15 * n^(1 / 3)))
  }

  X <- scale(X)
  y <- scale(y)

  idx.landmarks <- sample.int(n, nb.landmarks, replace = FALSE)
  # drop = FALSE keeps S a matrix even with a single landmark or column.
  S <- X[idx.landmarks, , drop = FALSE]
  K11 <- gausskernel.nakr(S, S, sigma2)
  C <- gausskernel.nakr(X, S, sigma2)

  # K11 is often ill-conditioned: keep only the leading singular values
  # that are not below 1e-12 (svd() returns them in decreasing order).
  svd.K11 <- svd(K11)
  small <- which(svd.K11$d < 1E-12)
  if (length(small) > 0) {
    k <- small[1] - 1
  } else {
    k <- length(svd.K11$d)
  }
  if (k < 1) {
    stop("landmark kernel matrix has no singular value above 1e-12")
  }
  keep <- seq_len(k)
  # nrow = k forces a k x k diagonal matrix; diag() of a single number would
  # otherwise build an identity matrix of that size.
  W <- svd.K11$u[, keep, drop = FALSE] %*%
    diag(1 / sqrt(svd.K11$d[keep]), nrow = k)

  Phi <- C %*% W
  ridge.fit <- ridge(Phi, y, lambdas)

  r <- list(
    center.X = attr(X, "scaled:center"),
    scale.X = attr(X, "scaled:scale"),
    center.y = attr(y, "scaled:center"),
    scale.y = attr(y, "scaled:scale"),
    S = S,
    sigma2 = sigma2,
    W = W,
    ridge = ridge.fit
  )
  class(r) <- "nakr"
  r
}
# Predict with a fitted "nakr" model.
#
# Arguments:
#   o        object of class "nakr" (as built by nakr())
#   newdata  new predictors; coerced with as.matrix() and standardised with
#            the centre/scale stored in `o`
#
# Returns the predictions of the stored ridge model, multiplied by
# o$scale.y and shifted by o$center.y (i.e. back on the scale of the
# training response).
predict.nakr <- function(o, newdata) {
  if (!inherits(o, "nakr")) {
    warning("Object is not of class 'nakr'")
    # UseMethod() does not return: dispatch continues on the class of `o`,
    # so nothing after this call is evaluated.
    UseMethod("predict")
  }

  test <- as.matrix(newdata)
  test <- scale(test, center = o$center.X, scale = o$scale.X)

  # Same feature map as at fit time: kernel against the landmarks, then
  # projection with W.
  K.test <- gausskernel.nakr(test, o$S, o$sigma2)
  Phi.test <- K.test %*% o$W

  yh <- predict(o$ridge, Phi.test)
  # Returned visibly (the function used to end on an assignment, which made
  # the result invisible at the console).
  yh * o$scale.y + o$center.y
}
# Compute the Gaussian kernel between each row of X1 and each row of X2:
#   K[i, j] = exp(-||X1[i, ] - X2[j, ]||^2 / sigma2)
#
# Arguments:
#   X1, X2  numeric matrices with the same number of columns; anything that
#           is not already a matrix (plain vector, data frame) is coerced
#           with as.matrix(), so a plain vector becomes a single column
#   sigma2  denominator of the exponent
#
# Returns an nrow(X1) x nrow(X2) matrix. Stops if the column counts differ.
#
# Vectorised in R; could still be done more efficiently (C code, threads).
gausskernel.nakr <- function(X1, X2, sigma2) {
  if (!is.matrix(X1)) {
    X1 <- as.matrix(X1)
  }
  if (!is.matrix(X2)) {
    X2 <- as.matrix(X2)
  }
  if (ncol(X1) != ncol(X2)) {
    stop("X1 and X2 must have the same number of columns")
  }

  # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 <a, b>, for all pairs at once.
  # outer() and tcrossprod() also cope with zero-row inputs.
  sq.norm1 <- rowSums(X1 * X1)
  sq.norm2 <- rowSums(X2 * X2)
  sq.dist <- outer(sq.norm1, sq.norm2, "+") - 2 * tcrossprod(X1, X2)

  # Round-off can leave tiny negative distances, which would give kernel
  # values slightly above 1; which() keeps NA entries untouched.
  sq.dist[which(sq.dist < 0)] <- 0

  exp(-sq.dist / sigma2)
}