[language_classification] Automatic hyperparameter optimization using either the cross-validation set or two distinct subsets of the training set

commit f808f74271 (parent af5689ee52)
Author: Al
Date:   2016-01-17 21:11:37 -05:00

6 changed files with 299 additions and 112 deletions

View File

@@ -121,7 +121,7 @@ inline bool language_classifier_language_is_valid(char *language) {
     return !string_equals(language, AMBIGUOUS_LANGUAGE) && !string_equals(language, UNKNOWN_LANGUAGE);
 }
-language_classifier_minibatch_t *language_classifier_data_set_get_minibatch_with_size(language_classifier_data_set_t *self, size_t batch_size, bool with_country) {
+language_classifier_minibatch_t *language_classifier_data_set_get_minibatch_with_size(language_classifier_data_set_t *self, khash_t(str_uint32) *labels, size_t batch_size) {
     size_t in_batch = 0;
     language_classifier_minibatch_t *minibatch = NULL;
@@ -131,17 +131,19 @@ language_classifier_minibatch_t *language_classifier_data_set_get_minibatch_with
         if (strlen(address) == 0) {
             continue;
         }
-        char *country = NULL;
-        if (with_country) {
-            country = char_array_get_string(self->country);
-        }
+        char *country = NULL;
+        //char *country = char_array_get_string(self->country);
         char *language = char_array_get_string(self->language);
         if (!language_classifier_language_is_valid(language)) {
             continue;
         }
+        if (labels != NULL && kh_get(str_uint32, labels, language) == kh_end(labels)) {
+            continue;
+        }
         if (minibatch == NULL) {
             minibatch = language_classifier_minibatch_new();
             if (minibatch == NULL) {
@@ -150,6 +152,7 @@ language_classifier_minibatch_t *language_classifier_data_set_get_minibatch_with
             }
         }
+        if (labels != NULL) {
         khash_t(str_double) *feature_counts = extract_language_features(address, country, self->tokens, self->feature_array);
         if (feature_counts == NULL) {
             log_error("Could not extract features for: %s\n", address);
@@ -157,6 +160,8 @@ language_classifier_minibatch_t *language_classifier_data_set_get_minibatch_with
             return NULL;
         }
         feature_count_array_push(minibatch->features, feature_counts);
+        }
         cstring_array_add_string(minibatch->labels, language);
         in_batch++;
     }
@@ -164,8 +169,8 @@ language_classifier_minibatch_t *language_classifier_data_set_get_minibatch_with
     return minibatch;
 }
-inline language_classifier_minibatch_t *language_classifier_data_set_get_minibatch(language_classifier_data_set_t *self, bool with_country) {
-    return language_classifier_data_set_get_minibatch_with_size(self, LANGUAGE_CLASSIFIER_DEFAULT_BATCH_SIZE, with_country);
+inline language_classifier_minibatch_t *language_classifier_data_set_get_minibatch(language_classifier_data_set_t *self, khash_t(str_uint32) *labels) {
+    return language_classifier_data_set_get_minibatch_with_size(self, labels, LANGUAGE_CLASSIFIER_DEFAULT_BATCH_SIZE);
 }
 void language_classifier_data_set_destroy(language_classifier_data_set_t *self) {
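With the with_country flag gone (the country feature line is commented out above), the minibatch API is driven purely by an optional label filter. A minimal sketch of a caller under the new signatures (hypothetical counting loop; it assumes libpostal's cstring_array_num_strings helper):

    language_classifier_data_set_t *data_set = language_classifier_data_set_init(filename);
    language_classifier_minibatch_t *minibatch;
    size_t num_examples = 0;
    /* NULL accepts every valid language; passing a khash_t(str_uint32)
       of label ids restricts minibatches to those languages */
    while ((minibatch = language_classifier_data_set_get_minibatch(data_set, NULL)) != NULL) {
        num_examples += cstring_array_num_strings(minibatch->labels);
        language_classifier_minibatch_destroy(minibatch);
    }
    language_classifier_data_set_destroy(data_set);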

View File

@@ -42,8 +42,8 @@ language_classifier_data_set_t *language_classifier_data_set_init(char *filename
 bool language_classifier_data_set_next(language_classifier_data_set_t *self);
 void language_classifier_data_set_destroy(language_classifier_data_set_t *self);
-language_classifier_minibatch_t *language_classifier_data_set_get_minibatch_with_size(language_classifier_data_set_t *self, size_t batch_size, bool with_country);
-language_classifier_minibatch_t *language_classifier_data_set_get_minibatch(language_classifier_data_set_t *self, bool with_country);
+language_classifier_minibatch_t *language_classifier_data_set_get_minibatch_with_size(language_classifier_data_set_t *self, khash_t(str_uint32) *labels, size_t batch_size);
+language_classifier_minibatch_t *language_classifier_data_set_get_minibatch(language_classifier_data_set_t *self, khash_t(str_uint32) *labels);
 void language_classifier_minibatch_destroy(language_classifier_minibatch_t *self);
 #endif

View File

@@ -10,16 +10,33 @@
#include "logistic_regression.h"
#include "logistic_regression_trainer.h"
#include "shuffle.h"
#include "sparse_matrix.h"
#include "sparse_matrix_utils.h"
#define LANGUAGE_CLASSIFIER_FEATURE_COUNT_THRESHOLD 1.0
#define LANGUAGE_CLASSIFIER_LABEL_COUNT_THRESHOLD 1
#define LANGUAGE_CLASSIFIER_LABEL_COUNT_THRESHOLD 100
#define LOG_BATCH_INTERVAL 10
#define COMPUTE_COST_INTERVAL 100
#define LANGUAGE_CLASSIFIER_HYPERPARAMETER_BATCHES 100
static double GAMMA_SCHEDULE[] = {0.01, 0.1, 0.2, 0.5, 1.0, 2.0, 5.0, 10.0};
static const size_t GAMMA_SCHEDULE_SIZE = sizeof(GAMMA_SCHEDULE) / sizeof(double);
#define DEFAULT_GAMMA_0 10.0
static double LAMBDA_SCHEDULE[] = {0.0, 1e-5, 1e-4, 0.001, 0.01, 0.1, \
0.2, 0.5, 1.0, 2.0, 5.0, 10.0};
static const size_t LAMBDA_SCHEDULE_SIZE = sizeof(LAMBDA_SCHEDULE) / sizeof(double);
#define DEFAULT_LAMBDA 0.0
#define TRAIN_EPOCHS 10
logistic_regression_trainer_t *language_classifier_init_thresholds(char *filename, bool with_country, double feature_count_threshold, uint32_t label_count_threshold) {
#define HYPERPARAMETER_EPOCHS 30
logistic_regression_trainer_t *language_classifier_init_thresholds(char *filename, double feature_count_threshold, uint32_t label_count_threshold) {
if (filename == NULL) {
log_error("Filename was NULL\n");
return NULL;
@@ -34,26 +51,56 @@ logistic_regression_trainer_t *language_classifier_init_thresholds(char *filenam
     size_t num_batches = 0;
     // Count features and labels
-    while ((minibatch = language_classifier_data_set_get_minibatch(data_set, with_country)) != NULL) {
-        if (!count_features_minibatch(feature_counts, minibatch->features, true)){
-            log_error("Counting minibatch features failed\n");
-            exit(EXIT_FAILURE);
-        }
+    while ((minibatch = language_classifier_data_set_get_minibatch(data_set, NULL)) != NULL) {
         if (!count_labels_minibatch(label_counts, minibatch->labels)) {
             log_error("Counting minibatch labels failed\n");
             exit(EXIT_FAILURE);
         }
         if (num_batches % LOG_BATCH_INTERVAL == 0) {
-            log_info("Counted %zu batches\n", num_batches);
+            log_info("Counting labels, did %zu batches\n", num_batches);
         }
         num_batches++;
         language_classifier_minibatch_destroy(minibatch);
     }
-    log_info("Done counting, finalizing\n");
+    log_info("Done counting labels\n");
+    language_classifier_data_set_destroy(data_set);
+    data_set = language_classifier_data_set_init(filename);
+    num_batches = 0;
+    khash_t(str_uint32) *label_ids = select_labels_threshold(label_counts, label_count_threshold);
+    if (label_ids == NULL) {
+        log_error("Error creating labels\n");
+        exit(EXIT_FAILURE);
+    }
+    size_t num_labels = kh_size(label_ids);
+    log_info("num_labels=%zu\n", num_labels);
+    // Don't free the label strings as the pointers are reused in select_labels_threshold
+    kh_destroy(str_uint32, label_counts);
+    // Run through the training set again, counting only features which co-occur with valid classes
+    while ((minibatch = language_classifier_data_set_get_minibatch(data_set, label_ids)) != NULL) {
+        if (!count_features_minibatch(feature_counts, minibatch->features, true)){
+            log_error("Counting minibatch features failed\n");
+            exit(EXIT_FAILURE);
+        }
+        if (num_batches % LOG_BATCH_INTERVAL == 0) {
+            log_info("Counting features, did %zu batches\n", num_batches);
+        }
+        num_batches++;
+        language_classifier_minibatch_destroy(minibatch);
+    }
+    log_info("Done counting features, finalizing\n");
     language_classifier_data_set_destroy(data_set);
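The init routine now makes two passes over the training file: the first counts label frequencies only, select_labels_threshold then keeps languages with at least label_count_threshold (raised from 1 to 100) examples, and the second pass counts only features that co-occur with a surviving label. A sketch of the thresholding step's presumed contract (hypothetical reimplementation; the real select_labels_threshold is defined elsewhere in the tree):

    /* Keep labels whose count meets the threshold, assigning dense
       ids 0..k-1; key pointers are reused rather than copied */
    khash_t(str_uint32) *select_labels_threshold_sketch(khash_t(str_uint32) *label_counts, uint32_t threshold) {
        khash_t(str_uint32) *label_ids = kh_init(str_uint32);
        if (label_ids == NULL) return NULL;
        uint32_t next_id = 0;
        char *label;
        uint32_t count;
        kh_foreach(label_counts, label, count, {
            if (count >= threshold) {
                int ret;
                khiter_t key = kh_put(str_uint32, label_ids, label, &ret);
                kh_value(label_ids, key) = next_id++;
            }
        })
        return label_ids;
    }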
@@ -71,23 +118,15 @@ logistic_regression_trainer_t *language_classifier_init_thresholds(char *filenam
     })
     kh_destroy(str_double, feature_counts);
-    khash_t(str_uint32) *label_ids = select_labels_threshold(label_counts, label_count_threshold);
-    if (label_ids == NULL) {
-        log_error("Error creating labels\n");
-        exit(EXIT_FAILURE);
-    }
-    // Don't free the label strings as the pointers are reused in select_labels_threshold
-    kh_destroy(str_uint32, label_counts);
-    return logistic_regression_trainer_init(feature_ids, label_ids);
+    return logistic_regression_trainer_init(feature_ids, label_ids, DEFAULT_GAMMA_0, DEFAULT_LAMBDA);
 }
-logistic_regression_trainer_t *language_classifier_init(char *filename, bool with_country) {
-    return language_classifier_init_thresholds(filename, with_country, LANGUAGE_CLASSIFIER_FEATURE_COUNT_THRESHOLD, LANGUAGE_CLASSIFIER_LABEL_COUNT_THRESHOLD);
-}
+logistic_regression_trainer_t *language_classifier_init(char *filename) {
+    return language_classifier_init_thresholds(filename, LANGUAGE_CLASSIFIER_FEATURE_COUNT_THRESHOLD, LANGUAGE_CLASSIFIER_LABEL_COUNT_THRESHOLD);
+}
-double compute_cv_accuracy(logistic_regression_trainer_t *trainer, char *filename, bool with_country) {
+double compute_cv_accuracy(logistic_regression_trainer_t *trainer, char *filename) {
     language_classifier_data_set_t *data_set = language_classifier_data_set_init(filename);
     language_classifier_minibatch_t *minibatch;
@@ -97,7 +136,7 @@ double compute_cv_accuracy(logistic_regression_trainer_t *trainer, char *filenam
     matrix_t *p_y = matrix_new_zeros(LANGUAGE_CLASSIFIER_DEFAULT_BATCH_SIZE, trainer->num_labels);
-    while ((minibatch = language_classifier_data_set_get_minibatch(data_set, with_country)) != NULL) {
+    while ((minibatch = language_classifier_data_set_get_minibatch(data_set, trainer->label_ids)) != NULL) {
         sparse_matrix_t *x = feature_matrix(trainer->feature_ids, minibatch->features);
         uint32_array *y = label_vector(trainer->label_ids, minibatch->labels);
@@ -138,64 +177,41 @@ double compute_cv_accuracy(logistic_regression_trainer_t *trainer, char *filenam
     return accuracy;
 }
-bool language_classifier_train_epoch(logistic_regression_trainer_t *trainer, char *filename, char *cv_filename, bool with_country) {
-    if (filename == NULL) {
-        log_error("Filename was NULL\n");
-        return false;
-    }
+double compute_total_cost(logistic_regression_trainer_t *trainer, char *filename, ssize_t compute_batches) {
     language_classifier_data_set_t *data_set = language_classifier_data_set_init(filename);
     language_classifier_minibatch_t *minibatch;
-    size_t num_batches = 0;
-    double train_cost = 0.0;
-    double cv_accuracy = 0.0;
+    size_t num_batches = 0;
+    double batch_cost = 0.0;
+    double total_cost = 0.0;
+    double last_cost = 0.0;
-    while ((minibatch = language_classifier_data_set_get_minibatch(data_set, with_country)) != NULL) {
-        bool compute_cost = num_batches % COMPUTE_COST_INTERVAL == 0 && num_batches > 0;
-        if (num_batches % LOG_BATCH_INTERVAL == 0 && num_batches > 0) {
-            log_info("Epoch %u, trained %zu batches\n", trainer->epochs, num_batches);
-        }
-        if (compute_cost) {
-            train_cost = logistic_regression_trainer_batch_cost(trainer, minibatch->features, minibatch->labels);
-            log_info("cost = %f\n", train_cost);
-        }
-        if (!logistic_regression_trainer_train_batch(trainer, minibatch->features, minibatch->labels)){
-            log_error("Train batch failed\n");
-            exit(EXIT_FAILURE);
-        }
-        if (compute_cost && cv_filename != NULL) {
-            cv_accuracy = compute_cv_accuracy(trainer, cv_filename, with_country);
-            log_info("cv accuracy=%f\n", cv_accuracy);
-        }
-        language_classifier_minibatch_destroy(minibatch);
+    while ((minibatch = language_classifier_data_set_get_minibatch(data_set, trainer->label_ids)) != NULL) {
+        double batch_cost = logistic_regression_trainer_batch_cost(trainer, minibatch->features, minibatch->labels);
+        total_cost += batch_cost;
         num_batches++;
+        language_classifier_minibatch_destroy(minibatch);
+        if (compute_batches > 0 && num_batches == (size_t)compute_batches) {
+            break;
+        }
     }
     language_classifier_data_set_destroy(data_set);
-    return true;
+    return total_cost;
 }
-language_classifier_t *language_classifier_train(char *filename, char *cv_filename, uint32_t num_iterations, bool with_country) {
-    logistic_regression_trainer_t *trainer = language_classifier_init(filename, with_country);
-    for (uint32_t epoch = 0; epoch < num_iterations; epoch++) {
-        log_info("Doing epoch %d\n", epoch);
-        trainer->epochs = epoch;
+bool language_classifier_train_epoch(logistic_regression_trainer_t *trainer, char *filename, char *cv_filename, ssize_t train_batches) {
+    if (filename == NULL) {
+        log_error("Filename was NULL\n");
+        return false;
+    }
 #if defined(HAVE_SHUF)
     log_info("Shuffling\n");
@@ -209,7 +225,138 @@ language_classifier_t *language_classifier_train(char *filename, char *cv_filena
log_info("Shuffle complete\n");
#endif
if (!language_classifier_train_epoch(trainer, filename, cv_filename, with_country)) {
language_classifier_data_set_t *data_set = language_classifier_data_set_init(filename);
language_classifier_minibatch_t *minibatch;
size_t num_batches = 0;
double batch_cost = 0.0;
double total_cost = 0.0;
double last_cost = 0.0;
double train_cost = 0.0;
double cv_accuracy = 0.0;
while ((minibatch = language_classifier_data_set_get_minibatch(data_set, trainer->label_ids)) != NULL) {
bool compute_cost = num_batches % COMPUTE_COST_INTERVAL == 0 && num_batches > 0;
if (num_batches % LOG_BATCH_INTERVAL == 0 && num_batches > 0) {
log_info("Epoch %u, doing batch %zu\n", trainer->epochs, num_batches);
}
if (compute_cost) {
train_cost = logistic_regression_trainer_batch_cost(trainer, minibatch->features, minibatch->labels);
log_info("cost = %f\n", train_cost);
}
if (!logistic_regression_trainer_train_batch(trainer, minibatch->features, minibatch->labels)){
log_error("Train batch failed\n");
exit(EXIT_FAILURE);
}
if (compute_cost && cv_filename != NULL) {
cv_accuracy = compute_cv_accuracy(trainer, cv_filename);
log_info("cv accuracy=%f\n", cv_accuracy);
}
num_batches++;
if (train_batches > 0 && num_batches == (size_t)train_batches) {
log_info("Epoch %u, trained %zu batches\n", trainer->epochs, num_batches);
train_cost = logistic_regression_trainer_batch_cost(trainer, minibatch->features, minibatch->labels);
log_info("cost = %f\n", train_cost);
break;
}
language_classifier_minibatch_destroy(minibatch);
}
language_classifier_data_set_destroy(data_set);
return true;
}
typedef struct language_classifier_params {
double lambda;
double gamma_0;
} language_classifier_params_t;
language_classifier_params_t language_classifier_parameter_sweep(char *filename, char *cv_filename) {
// Select features using the full data set
logistic_regression_trainer_t *trainer = language_classifier_init(filename);
double best_cost = 0.0;
language_classifier_params_t best_params = (language_classifier_params_t){0.0, 0.0};
for (size_t i = 0; i < LAMBDA_SCHEDULE_SIZE; i++) {
for (size_t j = 0; j < GAMMA_SCHEDULE_SIZE; j++) {
trainer->lambda = LAMBDA_SCHEDULE[i];
trainer->gamma_0 = GAMMA_SCHEDULE[j];
log_info("Optimizing hyperparameters. Trying lambda=%f, gamma_0=%f\n", trainer->lambda, trainer->gamma_0);
for (int k = 0; k < HYPERPARAMETER_EPOCHS; k++) {
trainer->epochs = k;
if (!language_classifier_train_epoch(trainer, filename, NULL, LANGUAGE_CLASSIFIER_HYPERPARAMETER_BATCHES)) {
log_error("Error in epoch\n");
logistic_regression_trainer_destroy(trainer);
exit(EXIT_FAILURE);
}
}
ssize_t cost_batches;
char *cost_file;
if (cv_filename == NULL) {
cost_file = filename;
cost_batches = LANGUAGE_CLASSIFIER_HYPERPARAMETER_BATCHES;
} else {
cost_file = cv_filename;
cost_batches = -1;
}
double cost = compute_total_cost(trainer, cost_file, cost_batches);
log_info("Total cost = %f\n", cost);
if ((i == 0 && j == 0) || cost < best_cost) {
log_info("Better than current best parameters: lambda=%f, gamma_0=%f\n", trainer->lambda, trainer->gamma_0);
best_cost = cost;
best_params.lambda = trainer->lambda;
best_params.gamma_0 = trainer->gamma_0;
}
}
}
return best_params;
}
language_classifier_t *language_classifier_train(char *filename, char *subset_filename, bool cross_validation_set, char *cv_filename, char *test_filename, uint32_t num_iterations) {
language_classifier_params_t params = language_classifier_parameter_sweep(subset_filename, cv_filename);
log_info("Best params: lambda=%f, gamma_0=%f\n", params.lambda, params.gamma_0);
logistic_regression_trainer_t *trainer = language_classifier_init(filename);
trainer->lambda = params.lambda;
trainer->gamma_0 = params.gamma_0;
/* If there's not a distinct cross-validation set, e.g.
when training the production model, then the cross validation
file is just a subset of the training data and only used
for setting the hyperparameters, so ignore it after we're
done with the parameter sweep.
*/
if (!cross_validation_set) {
cv_filename = NULL;
}
for (uint32_t epoch = 0; epoch < num_iterations; epoch++) {
log_info("Doing epoch %d\n", epoch);
trainer->epochs = epoch;
if (!language_classifier_train_epoch(trainer, filename, cv_filename, -1)) {
log_error("Error in epoch\n");
logistic_regression_trainer_destroy(trainer);
return NULL;
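For scale: LAMBDA_SCHEDULE has 12 entries and GAMMA_SCHEDULE has 8, so the sweep visits 12 x 8 = 96 settings, each trained for HYPERPARAMETER_EPOCHS x LANGUAGE_CLASSIFIER_HYPERPARAMETER_BATCHES = 30 x 100 = 3,000 minibatches, i.e. 288,000 minibatch updates in all before the winning (lambda, gamma_0) pair is retrained on the full training set.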
@@ -224,7 +371,17 @@ language_classifier_t *language_classifier_train(char *filename, char *cv_filena
         return NULL;
     }
+    if (test_filename != NULL) {
+        double test_accuracy = compute_cv_accuracy(trainer, test_filename);
+        log_info("Test accuracy = %f\n", test_accuracy);
+    }
     language_classifier_t *classifier = language_classifier_new();
     if (classifier == NULL) {
         log_error("Error creating classifier\n");
         logistic_regression_trainer_destroy(trainer);
         return NULL;
     }
     // Reassign weights and features to the classifier model
     classifier->weights = trainer->weights;
@@ -232,6 +389,7 @@ language_classifier_t *language_classifier_train(char *filename, char *cv_filena
     classifier->num_features = trainer->num_features;
     classifier->features = trainer->feature_ids;
+    // Set trainer feature_ids to NULL so it doesn't get destroyed
     trainer->feature_ids = NULL;
     size_t num_labels = trainer->num_labels;
@@ -262,7 +420,7 @@ language_classifier_t *language_classifier_train(char *filename, char *cv_filena
 }
-#define LANGUAGE_CLASSIFIER_TRAIN_USAGE "Usage: ./address_parser_train [train|cv] filename [cv_filename] [output_dir]\n"
+#define LANGUAGE_CLASSIFIER_TRAIN_USAGE "Usage: ./address_parser_train [train|cv] filename [cv_filename] [test_filename] [output_dir]\n"
 int main(int argc, char **argv) {
     if (argc < 3) {
@@ -271,10 +429,10 @@ int main(int argc, char **argv) {
     }
     char *command = argv[1];
-    bool cross_validate = false;
+    bool cross_validation_set = false;
     if (string_equals(command, "cv")) {
-        cross_validate = true;
+        cross_validation_set = true;
     } else if (!string_equals(command, "train")) {
         printf(LANGUAGE_CLASSIFIER_TRAIN_USAGE);
         exit(EXIT_FAILURE);
@@ -282,16 +440,18 @@ int main(int argc, char **argv) {
     char *filename = argv[2];
     char *cv_filename = NULL;
+    char *test_filename = NULL;
-    if (cross_validate && argc < 4) {
+    if (cross_validation_set && argc < 5) {
         printf(LANGUAGE_CLASSIFIER_TRAIN_USAGE);
         exit(EXIT_FAILURE);
-    } else if (cross_validate) {
+    } else if (cross_validation_set) {
         cv_filename = argv[3];
+        test_filename = argv[4];
     }
     char *output_dir = LIBPOSTAL_LANGUAGE_CLASSIFIER_DIR;
-    int output_dir_arg = cross_validate ? 4 : 3;
+    int output_dir_arg = cross_validation_set ? 5 : 3;
     if (argc > output_dir_arg) {
         output_dir = argv[output_dir_arg];
@@ -306,7 +466,50 @@ int main(int argc, char **argv) {
         exit(EXIT_FAILURE);
     }
-    language_classifier_t *language_classifier = language_classifier_train(filename, cv_filename, TRAIN_EPOCHS, false);
+    char_array *temp_file = char_array_new();
+    char_array_cat_printf(temp_file, ".%s.tmp", filename);
+    char *temp_filename = char_array_get_string(temp_file);
+    char_array *head_command = char_array_new();
+    size_t subset_examples = LANGUAGE_CLASSIFIER_HYPERPARAMETER_BATCHES * LANGUAGE_CLASSIFIER_DEFAULT_BATCH_SIZE;
+    char_array_cat_printf(head_command, "head -n %zu %s > %s", subset_examples, filename, temp_filename);
+    int ret = system(char_array_get_string(head_command));
+    if (ret != 0) {
+        exit(EXIT_FAILURE);
+    }
+    char_array *temp_cv_file = NULL;
+    if (!cross_validation_set) {
+        char_array_clear(head_command);
+        temp_cv_file = char_array_new();
+        char_array_cat_printf(temp_cv_file, ".%s.cv.tmp", filename);
+        char *temp_cv_filename = char_array_get_string(temp_cv_file);
+        char_array_cat_printf(head_command, "head -n %zu %s | tail -n %zu > %s", subset_examples * 2, filename, subset_examples, temp_cv_filename);
+        ret = system(char_array_get_string(head_command));
+        cv_filename = temp_cv_filename;
+    }
+    if (ret != 0) {
+        exit(EXIT_FAILURE);
+    }
+    char_array_destroy(head_command);
+    language_classifier_t *language_classifier = language_classifier_train(filename, temp_filename, cross_validation_set, cv_filename, test_filename, TRAIN_EPOCHS);
+    remove(temp_filename);
+    char_array_destroy(temp_file);
+    if (temp_cv_file != NULL) {
+        char_array_destroy(temp_cv_file);
+    }
     log_info("Done with classifier\n");
     char_array *path = char_array_new_size(strlen(output_dir) + PATH_SEPARATOR_LEN + strlen(LANGUAGE_CLASSIFIER_COUNTRY_FILENAME));
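The sweep subset and the pseudo-cross-validation subset are carved off the training file with head and tail. As a worked example, if LANGUAGE_CLASSIFIER_DEFAULT_BATCH_SIZE is 32 (the constant is defined elsewhere, so that value is an assumption here), then subset_examples = 100 * 32 = 3200 and the generated commands expand to roughly:

    head -n 3200 train.tsv > .train.tsv.tmp
    head -n 6400 train.tsv | tail -n 3200 > .train.tsv.cv.tmp

so the first block of lines becomes the hyperparameter training subset and the next, disjoint block becomes the pseudo-CV set (train.tsv is a hypothetical input name).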
@@ -320,18 +523,6 @@ int main(int argc, char **argv) {
         language_classifier_destroy(language_classifier);
     }
-    language_classifier_t *language_classifier_country = language_classifier_train(filename, cv_filename, TRAIN_EPOCHS, true);
-    if (language_classifier_country != NULL) {
-        char_array_clear(path);
-        char_array_cat_joined(path, PATH_SEPARATOR, true, 2, output_dir, LANGUAGE_CLASSIFIER_COUNTRY_FILENAME);
-        classifier_path = char_array_get_string(path);
-        language_classifier_save(language_classifier_country, classifier_path);
-        language_classifier_destroy(language_classifier_country);
-    }
char_array_destroy(path);
log_info("Success!\n");

View File

@@ -18,7 +18,6 @@ bool logistic_regression_model_expectation(matrix_t *theta, sparse_matrix_t *x,
     return true;
 }
 double logistic_regression_cost_function(matrix_t *theta, sparse_matrix_t *x, uint32_array *y, matrix_t *p_y, double lambda) {
     size_t m = x->m;
     size_t n = x->n;
@@ -139,27 +138,28 @@ static bool logistic_regression_gradient_params(matrix_t *theta, matrix_t *gradi
     // Update only the relevant columns in x
-    if (regularize && x_cols != NULL) {
-        size_t batch_rows = x_cols->n;
-        uint32_t *cols = x_cols->a;
-        for (i = 0; i < batch_rows; i++) {
-            col = cols[i];
+    if (regularize) {
+        size_t num_rows = num_features;
+        uint32_t *cols = NULL;
+        if (x_cols != NULL) {
+            cols = x_cols->a;
+            num_rows = x_cols->n;
+        }
+        for (i = 0; i < num_rows; i++) {
+            col = x_cols != NULL ? cols[i] : i;
             for (j = 0; j < num_classes; j++) {
-                size_t idx = i * num_classes + j;
+                size_t idx = col * num_classes + j;
                 double theta_ij = theta_values[idx];
-                double reg_value = theta_ij * lambda;
-                gradient_values[idx] += reg_value;
+                double reg_update = theta_ij * lambda;
+                double current_value = gradient_values[idx];
+                double updated_value = current_value + reg_update;
+                if (fabs(updated_value) == fabs(current_value)) {
+                    gradient_values[idx] = updated_value;
+                }
             }
         }
-    } else if (regularize) {
-        for (i = 1; i < num_features; i++) {
-            for (j = 0; j < num_classes; j++) {
-                size_t idx = i * num_classes + j;
-                double theta_ij = theta_values[idx];
-                double reg_value = theta_ij * lambda;
-                gradient_values[idx] += reg_value;
-            }
-        }
     }
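The consolidated branch adds the penalty's gradient contribution only at the feature columns active in the sparse batch, falling back to every row when x_cols is NULL; also note the old dense path started its loop at i = 1, skipping row 0, while the new fallback starts at 0. In the usual notation for an L2 penalty, the update this block appears to implement is:

    \frac{\partial J}{\partial \theta_{cj}} = \frac{\partial \mathrm{NLL}}{\partial \theta_{cj}} + \lambda \theta_{cj}, \quad c \in \text{columns present in the batch}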

View File

@@ -35,7 +35,7 @@ void logistic_regression_trainer_destroy(logistic_regression_trainer_t *self) {
     free(self);
 }
-logistic_regression_trainer_t *logistic_regression_trainer_init(trie_t *feature_ids, khash_t(str_uint32) *label_ids) {
+logistic_regression_trainer_t *logistic_regression_trainer_init(trie_t *feature_ids, khash_t(str_uint32) *label_ids, double gamma_0, double lambda) {
     if (feature_ids == NULL || label_ids == NULL) return NULL;
     logistic_regression_trainer_t *trainer = malloc(sizeof(logistic_regression_trainer_t));
@@ -57,11 +57,10 @@ logistic_regression_trainer_t *logistic_regression_trainer_init(trie_t *feature_
     trainer->last_updated = uint32_array_new_zeros(trainer->num_features);
-    trainer->lambda = DEFAULT_LAMBDA;
+    trainer->lambda = lambda;
     trainer->iters = 0;
     trainer->epochs = 0;
-    trainer->gamma_0 = DEFAULT_GAMMA_0;
-    trainer->gamma = DEFAULT_GAMMA;
+    trainer->gamma_0 = gamma_0;
     return trainer;
@@ -118,7 +117,7 @@ bool logistic_regression_trainer_train_batch(logistic_regression_trainer_t *self
         goto exit_matrices_created;
     }
-    if (self->lambda > 0.0 && !stochastic_gradient_descent_sparse_regularize_weights(self->weights, self->batch_columns, self->last_updated, self->iters, self->lambda, self->gamma_0)) {
+    if (self->lambda > 0.0 && !stochastic_gradient_descent_regularize_weights(self->weights, self->batch_columns, self->last_updated, self->iters, self->lambda, self->gamma_0)) {
         log_error("Error regularizing weights\n");
         goto exit_matrices_created;
     }
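The rename drops "sparse" from the regularizer's name, but the scheme is still lazy: last_updated records the iteration at which each feature column last appeared in a batch, and a column's weights are decayed for the iterations it sat out just before being updated again. A minimal sketch of that pattern (hypothetical helper under assumed semantics; the real stochastic_gradient_descent_regularize_weights may differ in detail):

    #include <math.h>

    /* Apply the L2 decay each batch column skipped since it was last
       touched, then mark it current at iteration t */
    void lazy_regularize_sketch(matrix_t *weights, uint32_array *batch_columns,
                                uint32_array *last_updated, uint32_t t,
                                double lambda, double gamma_0) {
        for (size_t i = 0; i < batch_columns->n; i++) {
            uint32_t col = batch_columns->a[i];
            uint32_t skipped = t - last_updated->a[col];
            /* one multiplicative shrink per skipped iteration; a real
               implementation would use the per-iteration learning rate */
            double decay = pow(1.0 - gamma_0 * lambda, (double)skipped);
            for (size_t j = 0; j < weights->n; j++) {
                weights->values[col * weights->n + j] *= decay;
            }
            last_updated->a[col] = t;
        }
    }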
@@ -130,7 +129,8 @@ bool logistic_regression_trainer_train_batch(logistic_regression_trainer_t *self
     size_t data_len = m * n;
-    ret = stochastic_gradient_descent_sparse(self->weights, gradient, self->batch_columns, self->gamma);
+    double gamma = stochastic_gradient_descent_gamma_t(self->gamma_0, self->lambda, self->iters);
+    ret = stochastic_gradient_descent_sparse(self->weights, gradient, self->batch_columns, gamma);
     self->iters++;
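The scalar self->gamma is gone; the step size is now recomputed for every batch from gamma_0, lambda, and the iteration count. The body of stochastic_gradient_descent_gamma_t is not part of this diff, but a standard schedule matching this signature (and consistent with DEFAULT_LAMBDA = 0.0 giving a constant rate) is the Bottou-style decay:

    \gamma_t = \frac{\gamma_0}{1 + \gamma_0 \lambda t}

which reduces to gamma_0 when lambda = 0.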
@@ -145,7 +145,7 @@ bool logistic_regression_trainer_finalize(logistic_regression_trainer_t *self) {
     if (self == NULL) return false;
     if (self->lambda > 0.0) {
-        return stochastic_gradient_descent_sparse_finalize_weights(self->weights, self->last_updated, self->iters, self->lambda, self->gamma_0);
+        return stochastic_gradient_descent_finalize_weights(self->weights, self->last_updated, self->iters, self->lambda, self->gamma_0);
     }
     return true;

View File

@@ -17,14 +17,6 @@
#include "tokens.h"
#include "trie.h"
#define DEFAULT_GAMMA_SCHEDULE {0.1, 0.2, 0.5, 1.0, 2.0, 5.0, 10.0}
#define DEFAUlT_LAMBDA_SCHEDULE {0.0, 1e-5, 1e-4, 0.001, 0.01, 0.1, \
0.2, 0.5, 1.0, 2.0, 5.0, 10.0}
#define DEFAULT_GAMMA_0 1.0
#define DEFAULT_LAMBDA 0.0
#define DEFAULT_GAMMA 0.1
/**
* Helper struct for training logistic regression model
*/
@@ -43,11 +35,10 @@ typedef struct logistic_regression_trainer {
     uint32_t iters; // Number of iterations, used to decay learning rate
     uint32_t epochs; // Number of epochs
     double gamma_0; // Initial learning rate
-    double gamma; // Simple scalar learning rate
 } logistic_regression_trainer_t;
-logistic_regression_trainer_t *logistic_regression_trainer_init(trie_t *feature_ids, khash_t(str_uint32) *label_ids);
+logistic_regression_trainer_t *logistic_regression_trainer_init(trie_t *feature_ids, khash_t(str_uint32) *label_ids, double gamma_0, double lambda);
 bool logistic_regression_trainer_train_batch(logistic_regression_trainer_t *self, feature_count_array *features, cstring_array *labels);
 double logistic_regression_trainer_batch_cost(logistic_regression_trainer_t *self, feature_count_array *features, cstring_array *labels);