[classification] correcting cost functions in SGD and FTRL for use in parameter sweeps
@@ -191,12 +191,12 @@ double compute_cv_accuracy(logistic_regression_trainer_t *trainer, char *filenam
         }
         double_matrix_zero(p_y);
 
-        if (!sparse_matrix_add_unique_columns(x, trainer->unique_columns, trainer->batch_columns)) {
+        if (!sparse_matrix_add_unique_columns_alias(x, trainer->unique_columns, trainer->batch_columns)) {
            log_error("Error adding unique columns\n");
            exit(EXIT_FAILURE);
        }
 
-        double_matrix_t *theta = logistic_regression_trainer_get_regularized_weights(trainer);
+        double_matrix_t *theta = logistic_regression_trainer_get_weights(trainer);
 
        if (!logistic_regression_model_expectation(theta, x, p_y)) {
            log_error("Predict cv batch failed\n");
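The two changes above swap in the aliasing variant of the unique-columns helper and score the held-out batch with the trainer's current weights rather than the regularized view. The diff does not show how the expectation matrix p_y is reduced to an accuracy number; below is a minimal standalone sketch of that reduction, with the function name, the row-major layout, and the labels array all assumed rather than taken from libpostal:

    #include <stddef.h>
    #include <stdint.h>

    /* Sketch only: fold an (m x n) row-major class-probability matrix,
       like p_y after logistic_regression_model_expectation, into a
       cross-validation accuracy. labels[i] is the true class in [0, n). */
    static double cv_accuracy_sketch(const double *probs, const uint32_t *labels,
                                     size_t m, size_t n) {
        size_t correct = 0;
        for (size_t i = 0; i < m; i++) {
            const double *row = probs + i * n;
            size_t argmax = 0;
            for (size_t j = 1; j < n; j++) {
                if (row[j] > row[argmax]) argmax = j;
            }
            if ((uint32_t)argmax == labels[i]) correct++;
        }
        return m > 0 ? (double)correct / (double)m : 0.0;
    }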
@@ -242,6 +242,7 @@ double compute_total_cost(logistic_regression_trainer_t *trainer, char *filename
 
     double total_cost = 0.0;
     size_t num_batches = 0;
+    size_t num_examples = 0;
 
     // Need to regularize the weights
     double_matrix_t *theta = logistic_regression_trainer_get_regularized_weights(trainer);
@@ -251,6 +252,8 @@ double compute_total_cost(logistic_regression_trainer_t *trainer, char *filename
         double batch_cost = logistic_regression_trainer_minibatch_cost(trainer, minibatch->features, minibatch->labels);
         total_cost += batch_cost;
 
+        num_examples += minibatch->features->n;
+
         language_classifier_minibatch_destroy(minibatch);
 
         num_batches++;
@@ -260,6 +263,10 @@ double compute_total_cost(logistic_regression_trainer_t *trainer, char *filename
         }
     }
 
+    double reg_cost = logistic_regression_trainer_regularization_cost(trainer, num_examples);
+    log_info("cost = %f, reg_cost = %f, m = %zu\n", total_cost, reg_cost, num_examples);
+    total_cost += reg_cost;
+
    language_classifier_data_set_destroy(data_set);
 
    return total_cost;
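compute_total_cost now sums the unregularized per-batch costs, counts the examples it has seen, and adds the regularization penalty exactly once at the end, which keeps the reported cost comparable across the regularization strengths tried in a sweep. The body of logistic_regression_trainer_regularization_cost is not shown here; the sketch below is the usual L2 penalty with this call shape, where the lambda name and the 1/(2m) scaling are assumptions (the num_examples argument in the diff suggests some per-example scaling):

    #include <stddef.h>

    /* Sketch only: common L2 penalty, (lambda / (2m)) * sum_j theta[j]^2.
       Whether libpostal uses exactly this scaling is an assumption. */
    static double l2_regularization_cost_sketch(const double *theta,
                                                size_t num_weights,
                                                double lambda, size_t m) {
        double sum_sq = 0.0;
        for (size_t j = 0; j < num_weights; j++) {
            sum_sq += theta[j] * theta[j];
        }
        return m > 0 ? (lambda / (2.0 * (double)m)) * sum_sq : 0.0;
    }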
@@ -305,7 +312,7 @@ bool language_classifier_train_epoch(logistic_regression_trainer_t *trainer, cha
         }
 
         if (compute_cost) {
-            train_cost = logistic_regression_trainer_minibatch_cost(trainer, minibatch->features, minibatch->labels);
+            train_cost = logistic_regression_trainer_minibatch_cost_regularized(trainer, minibatch->features, minibatch->labels);
            log_info("cost = %f\n", train_cost);
        }
 
@@ -323,7 +330,7 @@ bool language_classifier_train_epoch(logistic_regression_trainer_t *trainer, cha
 
         if (train_batches > 0 && num_batches == (size_t)train_batches) {
             log_info("Epoch %u, trained %zu examples\n", trainer->epochs, num_batches * minibatch_size);
-            train_cost = logistic_regression_trainer_minibatch_cost(trainer, minibatch->features, minibatch->labels);
+            train_cost = logistic_regression_trainer_minibatch_cost_regularized(trainer, minibatch->features, minibatch->labels);
            log_info("cost = %f\n", train_cost);
            break;
        }
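Both in-epoch logging sites switch to the _regularized minibatch cost, so the numbers printed during training are on the same scale as what compute_total_cost reports, which is what a parameter sweep compares. The diff does not show the regularized variant's body; under the assumption that it is a cross-entropy data term plus the penalty sketched above, the data term would look like:

    #include <math.h>
    #include <stddef.h>
    #include <stdint.h>

    /* Sketch only: mean cross-entropy over a minibatch, the data term the
       _regularized variant presumably adds the L2 penalty on top of.
       probs is an (m x n) row-major expectation matrix. */
    static double minibatch_cross_entropy_sketch(const double *probs,
                                                 const uint32_t *labels,
                                                 size_t m, size_t n) {
        double cost = 0.0;
        for (size_t i = 0; i < m; i++) {
            cost -= log(probs[i * n + labels[i]]);
        }
        return m > 0 ? cost / (double)m : 0.0;
    }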