From 7d727fc8f075a9c1d3024605dd30cf37f24f0ca1 Mon Sep 17 00:00:00 2001 From: Al Date: Sun, 17 Jan 2016 20:59:47 -0500 Subject: [PATCH] [optimization] Using adapted learning rate in stochastic gradient descent (if lambda > 0) --- src/stochastic_gradient_descent.c | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/stochastic_gradient_descent.c b/src/stochastic_gradient_descent.c index feb2207e..b9439e86 100644 --- a/src/stochastic_gradient_descent.c +++ b/src/stochastic_gradient_descent.c @@ -103,7 +103,9 @@ bool stochastic_gradient_descent_regularize_weights(matrix_t *theta, uint32_arra uint32_t row = rows[i]; double *theta_i = matrix_get_row(theta, row); uint32_t last_updated = updates[row]; - regularize_row(theta_i, n, lambda, last_updated, t, gamma_0); + + double gamma_t = stochastic_gradient_descent_gamma_t(gamma_0, lambda, t - last_updated); + regularize_row(theta_i, n, lambda, last_updated, t, gamma_t); updates[row] = t; } @@ -121,8 +123,9 @@ inline bool stochastic_gradient_descent_finalize_weights(matrix_t *theta, uint32 for (size_t i = 0; i < m; i++) { double *theta_i = matrix_get_row(theta, i); uint32_t last_updated = updates[i]; - regularize_row(theta_i, n, lambda, last_updated, t, gamma_0); + double gamma_t = stochastic_gradient_descent_gamma_t(gamma_0, lambda, t - last_updated); + regularize_row(theta_i, n, lambda, last_updated, t, gamma_t); updates[i] = t; } }