[classification] correcting cost functions in SGD and FTRL for use in parameter sweeps
Changed file: src/ftrl.c (22 lines changed)
@@ -178,22 +178,33 @@ bool ftrl_update_gradient(ftrl_trainer_t *self, double_matrix_t *gradient, doubl
     return true;
 }
 
-double ftrl_reg_cost(ftrl_trainer_t *self, double_matrix_t *theta, uint32_array *indices, size_t batch_size) {
+double ftrl_reg_cost(ftrl_trainer_t *self, double_matrix_t *theta, uint32_array *update_indices, size_t batch_size) {
     double cost = 0.0;
 
     size_t m = theta->m;
     size_t n = theta->n;
 
+    uint32_t *indices = NULL;
+    size_t num_indices = m;
+
+    if (update_indices != NULL) {
+        indices = update_indices->a;
+        num_indices = update_indices->n;
+    }
+    size_t i_start = self->fit_intercept ? 1 : 0;
+
     double lambda1 = self->lambda1;
     double lambda2 = self->lambda2;
 
-    size_t i_start = self->fit_intercept ? 1 : 0;
 
     double l2_cost = 0.0;
     double l1_cost = 0.0;
 
     for (size_t i = 0; i < m; i++) {
-        uint32_t row_idx = indices->a[i];
+        uint32_t row_idx = i;
+        if (indices != NULL) {
+            row_idx = indices[i];
+        }
 
         if (row_idx >= i_start) {
             double *theta_i = double_matrix_get_row(theta, i);
@@ -205,9 +216,10 @@ double ftrl_reg_cost(ftrl_trainer_t *self, double_matrix_t *theta, uint32_array
     cost += lambda2 / 2.0 * l2_cost;
     cost += lambda1 * l1_cost;
 
-    return cost;
+    return cost * 1.0 / (double)batch_size;
 }
 
 
 double_matrix_t *ftrl_weights_finalize(ftrl_trainer_t *self) {
     if (!ftrl_set_weights(self, self->z, NULL)) {
         return NULL;
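
The quantity ftrl_reg_cost now returns is the elastic-net penalty averaged over the minibatch: (lambda2 / 2 * sum of squared weights + lambda1 * sum of absolute weights) / batch_size, with the intercept row skipped when fit_intercept is set. Below is a minimal standalone sketch of that computation over a plain row-major array; the function and parameter names are illustrative only, not the library's API.

#include <math.h>
#include <stddef.h>

/* Standalone sketch of the batch-averaged elastic-net penalty computed above.
   theta is an m x n row-major weight array; row 0 is skipped when fit_intercept
   is nonzero, mirroring the i_start logic in the diff. Names are illustrative. */
static double reg_cost_sketch(const double *theta, size_t m, size_t n,
                              double lambda1, double lambda2,
                              int fit_intercept, size_t batch_size) {
    double l1_cost = 0.0;
    double l2_cost = 0.0;
    size_t i_start = fit_intercept ? 1 : 0;

    for (size_t i = i_start; i < m; i++) {
        const double *theta_i = theta + i * n;
        for (size_t j = 0; j < n; j++) {
            l1_cost += fabs(theta_i[j]);         /* L1 term */
            l2_cost += theta_i[j] * theta_i[j];  /* squared L2 term */
        }
    }

    double cost = lambda2 / 2.0 * l2_cost + lambda1 * l1_cost;
    return cost / (double)batch_size;  /* averaged so costs are comparable across batch sizes */
}
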
@@ -79,7 +79,7 @@ int main(int argc, char **argv) {
         filename = argv[1];
     }
 
-    if (!language_classifier_module_setup(argv[1]) || !address_dictionary_module_setup(NULL)) {
+    if (!language_classifier_module_setup(dir) || !address_dictionary_module_setup(NULL) || !transliteration_module_setup(NULL)) {
         log_error("Error setting up classifier\n");
     }
 
@@ -191,12 +191,12 @@ double compute_cv_accuracy(logistic_regression_trainer_t *trainer, char *filenam
         }
         double_matrix_zero(p_y);
 
-        if (!sparse_matrix_add_unique_columns(x, trainer->unique_columns, trainer->batch_columns)) {
+        if (!sparse_matrix_add_unique_columns_alias(x, trainer->unique_columns, trainer->batch_columns)) {
             log_error("Error adding unique columns\n");
             exit(EXIT_FAILURE);
         }
 
-        double_matrix_t *theta = logistic_regression_trainer_get_regularized_weights(trainer);
+        double_matrix_t *theta = logistic_regression_trainer_get_weights(trainer);
 
         if (!logistic_regression_model_expectation(theta, x, p_y)) {
             log_error("Predict cv batch failed\n");
@@ -242,6 +242,7 @@ double compute_total_cost(logistic_regression_trainer_t *trainer, char *filename
 
     double total_cost = 0.0;
     size_t num_batches = 0;
+    size_t num_examples = 0;
 
     // Need to regularize the weights
     double_matrix_t *theta = logistic_regression_trainer_get_regularized_weights(trainer);
@@ -251,6 +252,8 @@ double compute_total_cost(logistic_regression_trainer_t *trainer, char *filename
         double batch_cost = logistic_regression_trainer_minibatch_cost(trainer, minibatch->features, minibatch->labels);
         total_cost += batch_cost;
+
+        num_examples += minibatch->features->n;
 
         language_classifier_minibatch_destroy(minibatch);
 
         num_batches++;
@@ -260,6 +263,10 @@ double compute_total_cost(logistic_regression_trainer_t *trainer, char *filename
         }
     }
 
+    double reg_cost = logistic_regression_trainer_regularization_cost(trainer, num_examples);
+    log_info("cost = %f, reg_cost = %f, m = %zu\n", total_cost, reg_cost, num_examples);
+    total_cost += reg_cost;
+
     language_classifier_data_set_destroy(data_set);
 
     return total_cost;
@@ -305,7 +312,7 @@ bool language_classifier_train_epoch(logistic_regression_trainer_t *trainer, cha
         }
 
         if (compute_cost) {
-            train_cost = logistic_regression_trainer_minibatch_cost(trainer, minibatch->features, minibatch->labels);
+            train_cost = logistic_regression_trainer_minibatch_cost_regularized(trainer, minibatch->features, minibatch->labels);
             log_info("cost = %f\n", train_cost);
         }
 
@@ -323,7 +330,7 @@ bool language_classifier_train_epoch(logistic_regression_trainer_t *trainer, cha
 
         if (train_batches > 0 && num_batches == (size_t)train_batches) {
             log_info("Epoch %u, trained %zu examples\n", trainer->epochs, num_batches * minibatch_size);
-            train_cost = logistic_regression_trainer_minibatch_cost(trainer, minibatch->features, minibatch->labels);
+            train_cost = logistic_regression_trainer_minibatch_cost_regularized(trainer, minibatch->features, minibatch->labels);
             log_info("cost = %f\n", train_cost);
             break;
         }
 
@@ -122,7 +122,7 @@ bool logistic_regression_trainer_reset_params_ftrl(logistic_regression_trainer_t
     return ftrl_trainer_reset_params(ftrl_trainer, alpha, beta, lambda1, lambda2);
 }
 
-double logistic_regression_trainer_minibatch_cost(logistic_regression_trainer_t *self, feature_count_array *features, cstring_array *labels) {
+static double logistic_regression_trainer_minibatch_cost_params(logistic_regression_trainer_t *self, feature_count_array *features, cstring_array *labels, bool regularized) {
     size_t n = self->num_labels;
 
     sparse_matrix_t *x = feature_matrix(self->feature_ids, features);
@@ -141,14 +141,16 @@ double logistic_regression_trainer_minibatch_cost(logistic_regression_trainer_t
 
     cost = logistic_regression_cost_function(weights, x, y, p_y);
 
-    if (self->optimizer_type == LOGISTIC_REGRESSION_OPTIMIZER_SGD) {
-        sgd_trainer_t *sgd_trainer = self->optimizer.sgd;
-        double reg_cost = stochastic_gradient_descent_reg_cost(sgd_trainer, self->batch_columns, x->m);
-        cost += reg_cost;
-    } else if (self->optimizer_type == LOGISTIC_REGRESSION_OPTIMIZER_FTRL) {
-        ftrl_trainer_t *ftrl_trainer = self->optimizer.ftrl;
-        double reg_cost = ftrl_reg_cost(ftrl_trainer, weights, self->batch_columns, x->m);
-        cost += reg_cost;
+    if (regularized) {
+        if (self->optimizer_type == LOGISTIC_REGRESSION_OPTIMIZER_SGD) {
+            sgd_trainer_t *sgd_trainer = self->optimizer.sgd;
+            double reg_cost = stochastic_gradient_descent_reg_cost(sgd_trainer, self->batch_columns, x->m);
+            cost += reg_cost;
+        } else if (self->optimizer_type == LOGISTIC_REGRESSION_OPTIMIZER_FTRL) {
+            ftrl_trainer_t *ftrl_trainer = self->optimizer.ftrl;
+            double reg_cost = ftrl_reg_cost(ftrl_trainer, weights, self->batch_columns, x->m);
+            cost += reg_cost;
+        }
     }
 
 exit_cost_matrices_created:
@@ -158,6 +160,27 @@ exit_cost_matrices_created:
     return cost;
 }
 
+inline double logistic_regression_trainer_minibatch_cost(logistic_regression_trainer_t *self, feature_count_array *features, cstring_array *labels) {
+    return logistic_regression_trainer_minibatch_cost_params(self, features, labels, false);
+}
+
+inline double logistic_regression_trainer_minibatch_cost_regularized(logistic_regression_trainer_t *self, feature_count_array *features, cstring_array *labels) {
+    return logistic_regression_trainer_minibatch_cost_params(self, features, labels, true);
+}
+
+double logistic_regression_trainer_regularization_cost(logistic_regression_trainer_t *self, size_t m) {
+    if (self->optimizer_type == LOGISTIC_REGRESSION_OPTIMIZER_SGD) {
+        sgd_trainer_t *sgd_trainer = self->optimizer.sgd;
+        return stochastic_gradient_descent_reg_cost(sgd_trainer, NULL, m);
+    } else if (self->optimizer_type == LOGISTIC_REGRESSION_OPTIMIZER_FTRL) {
+        ftrl_trainer_t *ftrl_trainer = self->optimizer.ftrl;
+        double_matrix_t *weights = logistic_regression_trainer_get_weights(self);
+        return ftrl_reg_cost(ftrl_trainer, weights, NULL, m);
+    }
+    return 0.0;
+}
+
 
 bool logistic_regression_trainer_train_minibatch(logistic_regression_trainer_t *self, feature_count_array *features, cstring_array *labels) {
     double_matrix_t *gradient = self->gradient;
 
@@ -233,16 +256,9 @@ double_matrix_t *logistic_regression_trainer_get_weights(logistic_regression_tra
 
     if (self->optimizer_type == LOGISTIC_REGRESSION_OPTIMIZER_SGD) {
         if (self->optimizer.sgd == NULL) return NULL;
-        double_matrix_t *full_weights = self->optimizer.sgd->theta;
-        uint32_t *columns = self->batch_columns->a;
-
-        for (size_t i = 0; i < m; i++) {
-            uint32_t col = columns[i];
-            double *theta_row = double_matrix_get_row(full_weights, col);
-            double *row = double_matrix_get_row(batch_weights, i);
-            for (size_t j = 0; j < n; j++) {
-                row[j] = theta_row[j];
-            }
-        }
+        if (!stochastic_gradient_descent_set_regularized_weights(self->optimizer.sgd, self->batch_weights, self->batch_columns)) {
+            return NULL;
+        }
 
         return batch_weights;
 
@@ -50,6 +50,8 @@ bool logistic_regression_trainer_reset_params_sgd(logistic_regression_trainer_t
 bool logistic_regression_trainer_reset_params_ftrl(logistic_regression_trainer_t *self, double alpha, double beta, double lambda1, double lambda2);
 bool logistic_regression_trainer_train_minibatch(logistic_regression_trainer_t *self, feature_count_array *features, cstring_array *labels);
 double logistic_regression_trainer_minibatch_cost(logistic_regression_trainer_t *self, feature_count_array *features, cstring_array *labels);
+double logistic_regression_trainer_minibatch_cost_regularized(logistic_regression_trainer_t *self, feature_count_array *features, cstring_array *labels);
+double logistic_regression_trainer_regularization_cost(logistic_regression_trainer_t *self, size_t m);
 
 double_matrix_t *logistic_regression_trainer_get_weights(logistic_regression_trainer_t *self);
 double_matrix_t *logistic_regression_trainer_get_regularized_weights(logistic_regression_trainer_t *self);
 
@@ -28,6 +28,8 @@ sgd_trainer_t *sgd_trainer_new(size_t m, size_t n, bool fit_intercept, regulariz
         if (sgd->penalties == NULL) {
             goto exit_sgd_trainer_created;
         }
+        // Penalty for last_updated == 0 is 0
+        double_array_push(sgd->penalties, 0.0);
     } else {
         sgd->last_updated = NULL;
         sgd->penalties = NULL;
@@ -60,6 +62,7 @@ bool sgd_trainer_reset_params(sgd_trainer_t *self, double lambda, double gamma_0
         } else {
             double_array_clear(self->penalties);
         }
+        double_array_push(self->penalties, 0.0);
     }
 
     double_matrix_zero(self->theta);
@@ -70,7 +73,7 @@ bool sgd_trainer_reset_params(sgd_trainer_t *self, double lambda, double gamma_0
 }
 
 
-inline double stochastic_gradient_descent_gamma_t(double gamma_0, double lambda, uint32_t t) {
+static inline double stochastic_gradient_descent_gamma_t(double gamma_0, double lambda, uint32_t t) {
     return gamma_0 / (1.0 + lambda * gamma_0 * (double)t);
 }
 
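
The learning-rate schedule above decays as gamma_t = gamma_0 / (1 + lambda * gamma_0 * t). A tiny standalone example of how the step size shrinks over iterations; the helper name and the values used are illustrative only.

#include <stdio.h>
#include <stdint.h>

/* Same decay formula as stochastic_gradient_descent_gamma_t above. */
static double gamma_t_sketch(double gamma_0, double lambda, uint32_t t) {
    return gamma_0 / (1.0 + lambda * gamma_0 * (double)t);
}

int main(void) {
    double gamma_0 = 0.1;
    double lambda = 0.01;
    /* With these illustrative values: t = 0 -> 0.1, t = 100 -> ~0.0909, t = 1000 -> 0.05 */
    for (uint32_t t = 0; t <= 1000; t += 100) {
        printf("t = %u, gamma_t = %f\n", (unsigned)t, gamma_t_sketch(gamma_0, lambda, t));
    }
    return 0;
}
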
@@ -194,18 +197,16 @@ bool stochastic_gradient_descent_update_sparse(sgd_trainer_t *self, double_matri
     double lambda_update = 0.0;
     double penalty = 0.0;
 
+    double *penalties = self->penalties->a;
+
     if (reg_type != REGULARIZATION_NONE) {
         lambda_update = lambda / (double)batch_size * gamma_t;
 
         if (self->iterations > 0) {
-            uint32_t penalty_index = t - 1;
-
-            if (penalty_index >= self->penalties->n) {
-                log_info("t = %zu, penalty_index = %u, penalties->n = %zu\n", t, penalty_index, self->penalties->n);
-                return false;
-            }
-            penalty = self->penalties->a[penalty_index];
+            if (t > self->penalties->n) {
+                log_info("t = %zu, penalties->n = %zu\n", t, self->penalties->n);
+                return false;
+            }
+            penalty = self->penalties->a[t];
         }
 
     for (size_t i = 0; i < num_updated; i++) {
@@ -218,20 +219,23 @@ bool stochastic_gradient_descent_update_sparse(sgd_trainer_t *self, double_matri
 
         if (self->iterations > 0) {
             if (last_updated >= self->penalties->n) {
-                log_info("t = %zu, last_updated = %zu, penalties->n = %zu\n", col, indices[i - 1], t, last_updated, self->penalties->n);
+                log_info("col = %u, t = %zu, last_updated = %zu, penalties->n = %zu\n", col, t, last_updated, self->penalties->n);
                 return false;
             }
-            last_update_penalty = self->penalties->a[last_updated];
 
+            last_update_penalty = penalties[last_updated];
 
             // Update the weights to what they would have been
             // if all the regularization updates were applied
-            double penalty_update = penalty - last_update_penalty;
+            if (last_updated < t) {
+                double penalty_update = penalty - last_update_penalty;
 
-            if (reg_type == REGULARIZATION_L2 && col >= i_start) {
-                regularize_l2(theta_i, n, penalty_update);
-            } else if (reg_type == REGULARIZATION_L1 && col >= i_start) {
-                regularize_l1(theta_i, n, penalty_update);
+                if (reg_type == REGULARIZATION_L2 && col >= i_start) {
+                    regularize_l2(theta_i, n, penalty_update);
+                } else if (reg_type == REGULARIZATION_L1 && col >= i_start) {
+                    regularize_l1(theta_i, n, penalty_update);
+                }
+            }
             }
         }
 
@@ -248,8 +252,9 @@ bool stochastic_gradient_descent_update_sparse(sgd_trainer_t *self, double_matri
             regularize_l1(theta_i, n, lambda_update);
         }
 
-        // Set the last updated timestep for this feature to time t
-        updates[col] = t;
+        // Set the last updated timestep for this feature to time t + 1
+        // since we're updating the iteration count
+        updates[col] = t + 1;
     }
 
     if (reg_type != REGULARIZATION_NONE) {
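
The penalties / last_updated bookkeeping above implements lazy (just-in-time) regularization for sparse updates: a cumulative per-timestep penalty is recorded, and when a feature column is next touched, the difference between the current cumulative penalty and the one recorded at its last update is applied in one shot. Below is a minimal standalone sketch of that idea for L2; all names are hypothetical and plain arrays stand in for the library's matrix and array types.

#include <stddef.h>
#include <stdint.h>

/* penalties[t] is the cumulative regularization "debt" accrued up to step t;
   last_updated[col] is the step at which column col last settled that debt. */

/* Advance the cumulative penalty by one SGD step. */
static void push_penalty(double *penalties, uint32_t t,
                         double lambda, double gamma_t, size_t batch_size) {
    penalties[t + 1] = penalties[t] + lambda / (double)batch_size * gamma_t;
}

/* When column col is touched at step t, apply everything it missed in one shot. */
static void catch_up_l2(double *w_col, size_t n,
                        const double *penalties, uint32_t *last_updated,
                        uint32_t col, uint32_t t) {
    double missed = penalties[t] - penalties[last_updated[col]];
    for (size_t j = 0; j < n; j++) {
        /* One multiplicative shrink standing in for the skipped per-step updates;
           the library's regularize_l2 may apply the accumulated penalty differently. */
        w_col[j] *= 1.0 - missed;
    }
    last_updated[col] = t;
}
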
@@ -270,19 +275,28 @@ double stochastic_gradient_descent_reg_cost(sgd_trainer_t *self, uint32_array *u
     if (reg_type == REGULARIZATION_NONE) return cost;
 
     double_matrix_t *theta = self->theta;
+    size_t m = theta->m;
+    size_t n = theta->n;
 
-    uint32_t *indices = update_indices->a;
-    size_t num_indices = update_indices->n;
+    uint32_t *indices = NULL;
+    size_t num_indices = m;
 
-    size_t n = self->theta->n;
+    if (update_indices != NULL) {
+        indices = update_indices->a;
+        num_indices = update_indices->n;
+    }
+    size_t i_start = self->fit_intercept ? 1 : 0;
 
     for (size_t i = 0; i < num_indices; i++) {
-        uint32_t row = indices[i];
+        uint32_t row = i;
+        if (indices != NULL) {
+            row = indices[i];
+        }
         double *theta_i = double_matrix_get_row(theta, row);
 
-        if (reg_type == REGULARIZATION_L2) {
+        if (reg_type == REGULARIZATION_L2 && row >= i_start) {
             cost += double_array_l2_norm(theta_i, n);
-        } else if (reg_type == REGULARIZATION_L1) {
+        } else if (reg_type == REGULARIZATION_L1 && row >= i_start) {
             cost += double_array_l1_norm(theta_i, n);
         }
     }
@@ -293,11 +307,10 @@ double stochastic_gradient_descent_reg_cost(sgd_trainer_t *self, uint32_array *u
         cost *= self->lambda;
     }
 
-    return cost * 1.0 / (double)batch_size;
+    return cost / (double)batch_size;
 }
 
 
-bool stochastic_gradient_descent_regularize_weights(sgd_trainer_t *self) {
-    if (self == NULL || self->theta == NULL) {
+bool stochastic_gradient_descent_set_regularized_weights(sgd_trainer_t *self, double_matrix_t *w, uint32_array *indices) {
+    if (self->theta == NULL) {
         log_info("stochastic_gradient_descent_regularize_weights theta NULL\n");
@@ -306,58 +319,89 @@ bool stochastic_gradient_descent_regularize_weights(sgd_trainer_t *self) {
     }
 
     double lambda = self->lambda;
     double gamma_0 = self->gamma_0;
     regularization_type_t reg_type = self->reg_type;
 
-    if (lambda > 0.0 && reg_type != REGULARIZATION_NONE) {
-        double_matrix_t *theta = self->theta;
+    double_matrix_t *theta = self->theta;
 
-        size_t m = theta->m;
-        size_t n = theta->n;
+    size_t m = theta->m;
+    size_t n = theta->n;
 
-        size_t i_start = self->fit_intercept ? 1 : 0;
+    uint32_t *row_indices = NULL;
+    size_t num_indices = m;
 
-        double prev_penalty = 0.0;
-        if (reg_type != REGULARIZATION_NONE) {
-            if (self->iterations > 0) {
-                uint32_t penalty_index = self->iterations - 1;
-                if (penalty_index >= self->penalties->n) {
-                    log_error("penalty_index (%u) >= self->penalties->n (%zu)\n", penalty_index, self->penalties->n);
-                    return false;
-                }
-                prev_penalty = self->penalties->a[penalty_index];
-            } else {
-                prev_penalty = 1.0;
-            }
+    if (indices != NULL) {
+        row_indices = indices->a;
+        num_indices = indices->n;
+    }
 
+    uint32_t *updates = self->last_updated->a;
+    double *penalties = self->penalties->a;
+
+    if (w != NULL && !double_matrix_resize(w, num_indices, n)) {
+        log_error("Resizing weights failed\n");
+        return false;
+    }
+
+    size_t i_start = self->fit_intercept ? 1 : 0;
+    bool regularize = lambda > 0.0 && reg_type != REGULARIZATION_NONE;
+
+    for (size_t i = 0; i < num_indices; i++) {
+        uint32_t row_idx = i;
+        if (indices != NULL) {
+            row_idx = row_indices[i];
+        }
 
-        uint32_t *updates = self->last_updated->a;
+        double *theta_i = double_matrix_get_row(theta, row_idx);
+        double *w_i = theta_i;
+        if (w != NULL) {
+            w_i = double_matrix_get_row(w, i);
+            double_array_raw_copy(w_i, theta_i, n);
+        }
+
+        if (regularize && i >= i_start) {
+            double most_recent_penalty = 0.0;
+            uint32_t most_recent_iter = 0;
+
+            if (self->iterations > 0) {
+                most_recent_iter = self->iterations;
+                if (most_recent_iter >= self->penalties->n) {
+                    log_error("penalty_index (%u) >= self->penalties->n (%zu)\n", most_recent_iter, self->penalties->n);
+                    return false;
+                }
+                most_recent_penalty = penalties[most_recent_iter];
+            } else {
+                most_recent_penalty = lambda / gamma_0;
+            }
+
-        for (size_t i = 0; i < m; i++) {
-            double *theta_i = double_matrix_get_row(theta, i);
             uint32_t last_updated = updates[i];
-            if (last_updated > self->penalties->n) {
+            if (last_updated >= self->penalties->n) {
                 log_error("last_updated (%zu) >= self->penalties-> (%zu)\n", last_updated, self->penalties->n);
                 return false;
             }
-            double last_update_penalty = self->penalties->a[last_updated];
+            double last_update_penalty = penalties[last_updated];
 
-            double penalty_update = prev_penalty - last_update_penalty;
+            if (last_updated < most_recent_iter) {
+                double penalty_update = most_recent_penalty - last_update_penalty;
 
-            if (reg_type == REGULARIZATION_L2 && i >= i_start) {
-                regularize_l2(theta_i, n, penalty_update);
-            } else if (reg_type == REGULARIZATION_L1 && i >= i_start) {
-                regularize_l1(theta_i, n, penalty_update);
+                if (reg_type == REGULARIZATION_L2) {
+                    regularize_l2(w_i, n, penalty_update);
+                } else if (reg_type == REGULARIZATION_L1) {
+                    regularize_l1(w_i, n, penalty_update);
+                }
             }
         }
     }
 
     if (reg_type != REGULARIZATION_NONE) {
         uint32_array_set(self->last_updated->a, (self->iterations > 0 ? self->iterations - 1 : 0), self->last_updated->n);
     }
 
     return true;
 }
 
 
+bool stochastic_gradient_descent_regularize_weights(sgd_trainer_t *self) {
+    return stochastic_gradient_descent_set_regularized_weights(self, NULL, NULL);
+}
+
 double_matrix_t *stochastic_gradient_descent_get_weights(sgd_trainer_t *self) {
     if (!stochastic_gradient_descent_regularize_weights(self)) {
         log_info("stochastic_gradient_descent_regularize_weights returned false\n");
@@ -38,6 +38,7 @@ bool sgd_trainer_reset_params(sgd_trainer_t *self, double lambda, double gamma_0
 bool stochastic_gradient_descent_update(sgd_trainer_t *self, double_matrix_t *gradient, size_t batch_size);
 bool stochastic_gradient_descent_update_sparse(sgd_trainer_t *self, double_matrix_t *gradient, uint32_array *update_indices, size_t batch_size);
 double stochastic_gradient_descent_reg_cost(sgd_trainer_t *self, uint32_array *indices, size_t batch_size);
+bool stochastic_gradient_descent_set_regularized_weights(sgd_trainer_t *self, double_matrix_t *w, uint32_array *indices);
 bool stochastic_gradient_descent_regularize_weights(sgd_trainer_t *self);
 double_matrix_t *stochastic_gradient_descent_get_weights(sgd_trainer_t *self);
 sparse_matrix_t *stochastic_gradient_descent_get_weights_sparse(sgd_trainer_t *self);
 
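
Taken together, these changes make reported costs comparable across settings: the penalty is averaged over the batch, and it can be computed with or without the data term (minibatch_cost vs. minibatch_cost_regularized, plus regularization_cost), which is what a hyperparameter sweep over lambda1 and lambda2 needs. A standalone sketch of such a sweep loop follows; train_and_score is a hypothetical stand-in for training with the given penalties and returning a held-out per-example cost, not a function from this codebase.

#include <stdio.h>
#include <stddef.h>
#include <float.h>

/* Hypothetical stand-in: train with the given penalties and return a held-out
   per-example cost. A real sweep would call the reset-params, train-epoch and
   cost functions touched by this commit. */
static double train_and_score(double lambda1, double lambda2) {
    return lambda1 + lambda2;  /* placeholder so the sketch compiles */
}

int main(void) {
    double lambda1_values[] = {0.0, 1e-6, 1e-4};
    double lambda2_values[] = {0.0, 1e-6, 1e-4};

    double best_cost = DBL_MAX;
    double best_l1 = 0.0, best_l2 = 0.0;

    for (size_t i = 0; i < sizeof(lambda1_values) / sizeof(lambda1_values[0]); i++) {
        for (size_t j = 0; j < sizeof(lambda2_values) / sizeof(lambda2_values[0]); j++) {
            double cost = train_and_score(lambda1_values[i], lambda2_values[j]);
            printf("lambda1 = %g, lambda2 = %g, cost = %f\n",
                   lambda1_values[i], lambda2_values[j], cost);
            if (cost < best_cost) {
                best_cost = cost;
                best_l1 = lambda1_values[i];
                best_l2 = lambda2_values[j];
            }
        }
    }

    printf("best: lambda1 = %g, lambda2 = %g, cost = %f\n", best_l1, best_l2, best_cost);
    return 0;
}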