From f808f74271130e7a49a9139dc50834a67c5f36a8 Mon Sep 17 00:00:00 2001
From: Al
Date: Sun, 17 Jan 2016 21:11:37 -0500
Subject: [PATCH] [language_classification] Automatic hyperparameter
 optimization using either the cross-validation set or two distinct subsets
 of the training set

---
 src/language_classifier_io.c      |  31 +--
 src/language_classifier_io.h      |   4 +-
 src/language_classifier_train.c   | 315 ++++++++++++++++++++++++------
 src/logistic_regression.c         |  36 ++--
 src/logistic_regression_trainer.c |  14 +-
 src/logistic_regression_trainer.h |  11 +-
 6 files changed, 299 insertions(+), 112 deletions(-)

diff --git a/src/language_classifier_io.c b/src/language_classifier_io.c
index 6be0d84d..5efd8229 100644
--- a/src/language_classifier_io.c
+++ b/src/language_classifier_io.c
@@ -121,7 +121,7 @@ inline bool language_classifier_language_is_valid(char *language) {
     return !string_equals(language, AMBIGUOUS_LANGUAGE) && !string_equals(language, UNKNOWN_LANGUAGE);
 }
 
-language_classifier_minibatch_t *language_classifier_data_set_get_minibatch_with_size(language_classifier_data_set_t *self, size_t batch_size, bool with_country) {
+language_classifier_minibatch_t *language_classifier_data_set_get_minibatch_with_size(language_classifier_data_set_t *self, khash_t(str_uint32) *labels, size_t batch_size) {
     size_t in_batch = 0;
 
     language_classifier_minibatch_t *minibatch = NULL;
@@ -131,17 +131,19 @@ language_classifier_minibatch_t *language_classifier_data_set_get_minibatch_with
         if (strlen(address) == 0) {
             continue;
         }
 
-        char *country = NULL;
-        if (with_country) {
-            country = char_array_get_string(self->country);
-        }
+        char *country = NULL;
+        //char *country = char_array_get_string(self->country);
 
         char *language = char_array_get_string(self->language);
         if (!language_classifier_language_is_valid(language)) {
             continue;
         }
 
+        if (labels != NULL && kh_get(str_uint32, labels, language) == kh_end(labels)) {
+            continue;
+        }
+
         if (minibatch == NULL) {
             minibatch = language_classifier_minibatch_new();
             if (minibatch == NULL) {
@@ -150,13 +152,16 @@ language_classifier_minibatch_t *language_classifier_data_set_get_minibatch_with
             }
         }
 
-        khash_t(str_double) *feature_counts = extract_language_features(address, country, self->tokens, self->feature_array);
-        if (feature_counts == NULL) {
-            log_error("Could not extract features for: %s\n", address);
-            language_classifier_minibatch_destroy(minibatch);
-            return NULL;
+        if (labels != NULL) {
+            khash_t(str_double) *feature_counts = extract_language_features(address, country, self->tokens, self->feature_array);
+            if (feature_counts == NULL) {
+                log_error("Could not extract features for: %s\n", address);
+                language_classifier_minibatch_destroy(minibatch);
+                return NULL;
+            }
+            feature_count_array_push(minibatch->features, feature_counts);
         }
-        feature_count_array_push(minibatch->features, feature_counts);
+
         cstring_array_add_string(minibatch->labels, language);
         in_batch++;
     }
@@ -164,8 +169,8 @@ language_classifier_minibatch_t *language_classifier_data_set_get_minibatch_with
     return minibatch;
 }
 
-inline language_classifier_minibatch_t *language_classifier_data_set_get_minibatch(language_classifier_data_set_t *self, bool with_country) {
-    return language_classifier_data_set_get_minibatch_with_size(self, LANGUAGE_CLASSIFIER_DEFAULT_BATCH_SIZE, with_country);
+inline language_classifier_minibatch_t *language_classifier_data_set_get_minibatch(language_classifier_data_set_t *self, khash_t(str_uint32) *labels) {
+    return language_classifier_data_set_get_minibatch_with_size(self, labels, LANGUAGE_CLASSIFIER_DEFAULT_BATCH_SIZE);
 }
 
 void language_classifier_data_set_destroy(language_classifier_data_set_t *self) {
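The new labels parameter gives the minibatch reader two modes: with labels == NULL every valid example is returned and no features are extracted (cheap label counting), while a non-NULL khash set restricts the stream to selected languages and triggers feature extraction. A minimal sketch of the intended two-pass calling pattern; the file name is a placeholder and label_ids stands for a set built by select_labels_threshold:

    /* Sketch only: illustrates the calling convention of the functions above. */
    language_classifier_data_set_t *data_set = language_classifier_data_set_init("languages.tsv");
    language_classifier_minibatch_t *mb;

    /* Pass 1: labels == NULL, labels are collected but no features are extracted */
    while ((mb = language_classifier_data_set_get_minibatch(data_set, NULL)) != NULL) {
        /* ... count mb->labels ... */
        language_classifier_minibatch_destroy(mb);
    }
    language_classifier_data_set_destroy(data_set);

    /* Pass 2: only examples whose language is in label_ids are returned,
       and mb->features is populated for each of them */
    data_set = language_classifier_data_set_init("languages.tsv");
    while ((mb = language_classifier_data_set_get_minibatch(data_set, label_ids)) != NULL) {
        /* ... count or train on mb->features / mb->labels ... */
        language_classifier_minibatch_destroy(mb);
    }
    language_classifier_data_set_destroy(data_set);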
diff --git a/src/language_classifier_io.h b/src/language_classifier_io.h
index 09a4dead..8b1ac5c2 100644
--- a/src/language_classifier_io.h
+++ b/src/language_classifier_io.h
@@ -42,8 +42,8 @@ language_classifier_data_set_t *language_classifier_data_set_init(char *filename
 bool language_classifier_data_set_next(language_classifier_data_set_t *self);
 void language_classifier_data_set_destroy(language_classifier_data_set_t *self);
 
-language_classifier_minibatch_t *language_classifier_data_set_get_minibatch_with_size(language_classifier_data_set_t *self, size_t batch_size, bool with_country);
-language_classifier_minibatch_t *language_classifier_data_set_get_minibatch(language_classifier_data_set_t *self, bool with_country);
+language_classifier_minibatch_t *language_classifier_data_set_get_minibatch_with_size(language_classifier_data_set_t *self, khash_t(str_uint32) *labels, size_t batch_size);
+language_classifier_minibatch_t *language_classifier_data_set_get_minibatch(language_classifier_data_set_t *self, khash_t(str_uint32) *labels);
 void language_classifier_minibatch_destroy(language_classifier_minibatch_t *self);
 
 #endif
\ No newline at end of file
diff --git a/src/language_classifier_train.c b/src/language_classifier_train.c
index c0c9033f..6773f414 100644
--- a/src/language_classifier_train.c
+++ b/src/language_classifier_train.c
@@ -10,16 +10,33 @@
 #include "logistic_regression.h"
 #include "logistic_regression_trainer.h"
 #include "shuffle.h"
+#include "sparse_matrix.h"
+#include "sparse_matrix_utils.h"
 
 #define LANGUAGE_CLASSIFIER_FEATURE_COUNT_THRESHOLD 1.0
-#define LANGUAGE_CLASSIFIER_LABEL_COUNT_THRESHOLD 1
+#define LANGUAGE_CLASSIFIER_LABEL_COUNT_THRESHOLD 100
 
 #define LOG_BATCH_INTERVAL 10
 #define COMPUTE_COST_INTERVAL 100
 
+#define LANGUAGE_CLASSIFIER_HYPERPARAMETER_BATCHES 100
+
+static double GAMMA_SCHEDULE[] = {0.01, 0.1, 0.2, 0.5, 1.0, 2.0, 5.0, 10.0};
+static const size_t GAMMA_SCHEDULE_SIZE = sizeof(GAMMA_SCHEDULE) / sizeof(double);
+
+#define DEFAULT_GAMMA_0 10.0
+
+static double LAMBDA_SCHEDULE[] = {0.0, 1e-5, 1e-4, 0.001, 0.01, 0.1, \
+                                   0.2, 0.5, 1.0, 2.0, 5.0, 10.0};
+static const size_t LAMBDA_SCHEDULE_SIZE = sizeof(LAMBDA_SCHEDULE) / sizeof(double);
+
+#define DEFAULT_LAMBDA 0.0
+
 #define TRAIN_EPOCHS 10
 
-logistic_regression_trainer_t *language_classifier_init_thresholds(char *filename, bool with_country, double feature_count_threshold, uint32_t label_count_threshold) {
+#define HYPERPARAMETER_EPOCHS 30
+
+logistic_regression_trainer_t *language_classifier_init_thresholds(char *filename, double feature_count_threshold, uint32_t label_count_threshold) {
     if (filename == NULL) {
         log_error("Filename was NULL\n");
         return NULL;
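The schedules above define the search grid for the sweep further down: 12 lambda values by 8 gamma_0 values, i.e. 96 combinations, and each combination is trained for HYPERPARAMETER_EPOCHS (30) passes over LANGUAGE_CLASSIFIER_HYPERPARAMETER_BATCHES (100) minibatches before its cost is measured. A stand-alone check of that budget (the constants are copied by hand here; the per-minibatch example count depends on LANGUAGE_CLASSIFIER_DEFAULT_BATCH_SIZE and is not assumed):

    #include <stdio.h>

    int main(void) {
        const size_t num_lambda = 12;          /* LAMBDA_SCHEDULE_SIZE above */
        const size_t num_gamma = 8;            /* GAMMA_SCHEDULE_SIZE above */
        const size_t epochs = 30;              /* HYPERPARAMETER_EPOCHS */
        const size_t batches_per_epoch = 100;  /* LANGUAGE_CLASSIFIER_HYPERPARAMETER_BATCHES */

        size_t combinations = num_lambda * num_gamma;                /* 96 */
        size_t updates = combinations * epochs * batches_per_epoch;  /* 288,000 minibatch updates */

        printf("%zu combinations, %zu minibatch updates per sweep\n", combinations, updates);
        return 0;
    }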
log_info("Counting labels, did %zu batches\n", num_batches); } num_batches++; language_classifier_minibatch_destroy(minibatch); } - log_info("Done counting, finalizing\n"); + log_info("Done counting labels\n"); + + language_classifier_data_set_destroy(data_set); + + data_set = language_classifier_data_set_init(filename); + num_batches = 0; + + khash_t(str_uint32) *label_ids = select_labels_threshold(label_counts, label_count_threshold); + if (label_ids == NULL) { + log_error("Error creating labels\n"); + exit(EXIT_FAILURE); + } + + size_t num_labels = kh_size(label_ids); + log_info("num_labels=%zu\n", num_labels); + + // Don't free the label strings as the pointers are reused in select_labels_threshold + kh_destroy(str_uint32, label_counts); + + // Run through the training set again, counting only features which co-occur with valid classes + while ((minibatch = language_classifier_data_set_get_minibatch(data_set, label_ids)) != NULL) { + if (!count_features_minibatch(feature_counts, minibatch->features, true)){ + log_error("Counting minibatch features failed\n"); + exit(EXIT_FAILURE); + } + + if (num_batches % LOG_BATCH_INTERVAL == 0) { + log_info("Counting features, did %zu batches\n", num_batches); + } + + num_batches++; + + language_classifier_minibatch_destroy(minibatch); + } + + log_info("Done counting features, finalizing\n"); language_classifier_data_set_destroy(data_set); @@ -71,23 +118,15 @@ logistic_regression_trainer_t *language_classifier_init_thresholds(char *filenam }) kh_destroy(str_double, feature_counts); - khash_t(str_uint32) *label_ids = select_labels_threshold(label_counts, label_count_threshold); - if (label_ids == NULL) { - log_error("Error creating labels\n"); - exit(EXIT_FAILURE); - } - // Don't free the label strings as the pointers are reused in select_labels_threshold - kh_destroy(str_uint32, label_counts); - - return logistic_regression_trainer_init(feature_ids, label_ids); + return logistic_regression_trainer_init(feature_ids, label_ids, DEFAULT_GAMMA_0, DEFAULT_LAMBDA); } -logistic_regression_trainer_t *language_classifier_init(char *filename, bool with_country) { - return language_classifier_init_thresholds(filename, with_country, LANGUAGE_CLASSIFIER_FEATURE_COUNT_THRESHOLD, LANGUAGE_CLASSIFIER_LABEL_COUNT_THRESHOLD); +logistic_regression_trainer_t *language_classifier_init(char *filename) { + return language_classifier_init_thresholds(filename, LANGUAGE_CLASSIFIER_FEATURE_COUNT_THRESHOLD, LANGUAGE_CLASSIFIER_LABEL_COUNT_THRESHOLD); } -double compute_cv_accuracy(logistic_regression_trainer_t *trainer, char *filename, bool with_country) { +double compute_cv_accuracy(logistic_regression_trainer_t *trainer, char *filename) { language_classifier_data_set_t *data_set = language_classifier_data_set_init(filename); language_classifier_minibatch_t *minibatch; @@ -97,7 +136,7 @@ double compute_cv_accuracy(logistic_regression_trainer_t *trainer, char *filenam matrix_t *p_y = matrix_new_zeros(LANGUAGE_CLASSIFIER_DEFAULT_BATCH_SIZE, trainer->num_labels); - while ((minibatch = language_classifier_data_set_get_minibatch(data_set, with_country)) != NULL) { + while ((minibatch = language_classifier_data_set_get_minibatch(data_set, trainer->label_ids)) != NULL) { sparse_matrix_t *x = feature_matrix(trainer->feature_ids, minibatch->features); uint32_array *y = label_vector(trainer->label_ids, minibatch->labels); @@ -138,12 +177,54 @@ double compute_cv_accuracy(logistic_regression_trainer_t *trainer, char *filenam return accuracy; } -bool 
@@ -138,12 +177,54 @@
     return accuracy;
 }
 
-bool language_classifier_train_epoch(logistic_regression_trainer_t *trainer, char *filename, char *cv_filename, bool with_country) {
+
+
+double compute_total_cost(logistic_regression_trainer_t *trainer, char *filename, ssize_t compute_batches) {
+    language_classifier_data_set_t *data_set = language_classifier_data_set_init(filename);
+
+    language_classifier_minibatch_t *minibatch;
+
+    double total_cost = 0.0;
+    size_t num_batches = 0;
+
+    while ((minibatch = language_classifier_data_set_get_minibatch(data_set, trainer->label_ids)) != NULL) {
+
+        double batch_cost = logistic_regression_trainer_batch_cost(trainer, minibatch->features, minibatch->labels);
+        total_cost += batch_cost;
+
+        language_classifier_minibatch_destroy(minibatch);
+
+        num_batches++;
+
+        if (compute_batches > 0 && num_batches == (size_t)compute_batches) {
+            break;
+        }
+    }
+
+    language_classifier_data_set_destroy(data_set);
+
+    return total_cost;
+}
+
+
+bool language_classifier_train_epoch(logistic_regression_trainer_t *trainer, char *filename, char *cv_filename, ssize_t train_batches) {
     if (filename == NULL) {
         log_error("Filename was NULL\n");
         return false;
     }
 
+    #if defined(HAVE_SHUF)
+    log_info("Shuffling\n");
+
+    if (!shuffle_file(filename)) {
+        log_error("Error in shuffle\n");
+        return false;
+    }
+
+    log_info("Shuffle complete\n");
+    #endif
+
     language_classifier_data_set_t *data_set = language_classifier_data_set_init(filename);
 
     language_classifier_minibatch_t *minibatch;
@@ -156,16 +237,16 @@
     double train_cost = 0.0;
     double cv_accuracy = 0.0;
 
-    while ((minibatch = language_classifier_data_set_get_minibatch(data_set, with_country)) != NULL) {
+    while ((minibatch = language_classifier_data_set_get_minibatch(data_set, trainer->label_ids)) != NULL) {
         bool compute_cost = num_batches % COMPUTE_COST_INTERVAL == 0 && num_batches > 0;
 
         if (num_batches % LOG_BATCH_INTERVAL == 0 && num_batches > 0) {
-            log_info("Epoch %u, trained %zu batches\n", trainer->epochs, num_batches);
+            log_info("Epoch %u, doing batch %zu\n", trainer->epochs, num_batches);
         }
 
         if (compute_cost) {
             train_cost = logistic_regression_trainer_batch_cost(trainer, minibatch->features, minibatch->labels);
-            log_info("cost = %f\n", train_cost);
+            log_info("cost = %f\n", train_cost);
         }
 
         if (!logistic_regression_trainer_train_batch(trainer, minibatch->features, minibatch->labels)){
@@ -174,13 +255,21 @@
         }
 
         if (compute_cost && cv_filename != NULL) {
-            cv_accuracy = compute_cv_accuracy(trainer, cv_filename, with_country);
+            cv_accuracy = compute_cv_accuracy(trainer, cv_filename);
             log_info("cv accuracy=%f\n", cv_accuracy);
         }
 
         num_batches++;
 
+        if (train_batches > 0 && num_batches == (size_t)train_batches) {
+            log_info("Epoch %u, trained %zu batches\n", trainer->epochs, num_batches);
+            train_cost = logistic_regression_trainer_batch_cost(trainer, minibatch->features, minibatch->labels);
+            log_info("cost = %f\n", train_cost);
+            language_classifier_minibatch_destroy(minibatch);
+            break;
+        }
+
         language_classifier_minibatch_destroy(minibatch);
+
     }
 
     language_classifier_data_set_destroy(data_set);
@@ -188,28 +277,86 @@
     return true;
 }
 
+typedef struct language_classifier_params {
+    double lambda;
+    double gamma_0;
+} language_classifier_params_t;
 
-language_classifier_t *language_classifier_train(char *filename, char *cv_filename, uint32_t num_iterations, bool with_country) {
-    logistic_regression_trainer_t *trainer = language_classifier_init(filename, with_country);
+language_classifier_params_t language_classifier_parameter_sweep(char *filename, char *cv_filename) {
+    // Select features using the full data set
+    logistic_regression_trainer_t *trainer = language_classifier_init(filename);
+
+    double best_cost = 0.0;
+
+    language_classifier_params_t best_params = (language_classifier_params_t){0.0, 0.0};
+
+    for (size_t i = 0; i < LAMBDA_SCHEDULE_SIZE; i++) {
+        for (size_t j = 0; j < GAMMA_SCHEDULE_SIZE; j++) {
+            trainer->lambda = LAMBDA_SCHEDULE[i];
+            trainer->gamma_0 = GAMMA_SCHEDULE[j];
+
+            log_info("Optimizing hyperparameters. Trying lambda=%f, gamma_0=%f\n", trainer->lambda, trainer->gamma_0);
+
+            for (int k = 0; k < HYPERPARAMETER_EPOCHS; k++) {
+                trainer->epochs = k;
+
+                if (!language_classifier_train_epoch(trainer, filename, NULL, LANGUAGE_CLASSIFIER_HYPERPARAMETER_BATCHES)) {
+                    log_error("Error in epoch\n");
+                    logistic_regression_trainer_destroy(trainer);
+                    exit(EXIT_FAILURE);
+                }
+            }
+
+            ssize_t cost_batches;
+            char *cost_file;
+
+            if (cv_filename == NULL) {
+                cost_file = filename;
+                cost_batches = LANGUAGE_CLASSIFIER_HYPERPARAMETER_BATCHES;
+            } else {
+                cost_file = cv_filename;
+                cost_batches = -1;
+            }
+
+            double cost = compute_total_cost(trainer, cost_file, cost_batches);
+            log_info("Total cost = %f\n", cost);
+            if ((i == 0 && j == 0) || cost < best_cost) {
+                log_info("New best parameters: lambda=%f, gamma_0=%f\n", trainer->lambda, trainer->gamma_0);
+                best_cost = cost;
+                best_params.lambda = trainer->lambda;
+                best_params.gamma_0 = trainer->gamma_0;
+            }
+        }
+    }
+
+    return best_params;
+}
+
+
+language_classifier_t *language_classifier_train(char *filename, char *subset_filename, bool cross_validation_set, char *cv_filename, char *test_filename, uint32_t num_iterations) {
+    language_classifier_params_t params = language_classifier_parameter_sweep(subset_filename, cv_filename);
+    log_info("Best params: lambda=%f, gamma_0=%f\n", params.lambda, params.gamma_0);
+
+    logistic_regression_trainer_t *trainer = language_classifier_init(filename);
+    trainer->lambda = params.lambda;
+    trainer->gamma_0 = params.gamma_0;
+
+    /* If there's no distinct cross-validation set, e.g.
+       when training the production model, then the cross-validation
+       file is just a subset of the training data and only used
+       for setting the hyperparameters, so ignore it after we're
+       done with the parameter sweep.
+    */
+    if (!cross_validation_set) {
+        cv_filename = NULL;
+    }
 
     for (uint32_t epoch = 0; epoch < num_iterations; epoch++) {
         log_info("Doing epoch %d\n", epoch);
 
         trainer->epochs = epoch;
-
-        #if defined(HAVE_SHUF)
-        log_info("Shuffling\n");
-
-        if (!shuffle_file(filename)) {
-            log_error("Error in shuffle\n");
-            logistic_regression_trainer_destroy(trainer);
-            return NULL;
-        }
-
-        log_info("Shuffle complete\n");
-        #endif
-        if (!language_classifier_train_epoch(trainer, filename, cv_filename, with_country)) {
+        if (!language_classifier_train_epoch(trainer, filename, cv_filename, -1)) {
             log_error("Error in epoch\n");
             logistic_regression_trainer_destroy(trainer);
             return NULL;
         }
@@ -224,7 +371,17 @@
         return NULL;
     }
 
+    if (test_filename != NULL) {
+        double test_accuracy = compute_cv_accuracy(trainer, test_filename);
+        log_info("Test accuracy = %f\n", test_accuracy);
+    }
+
     language_classifier_t *classifier = language_classifier_new();
+    if (classifier == NULL) {
+        log_error("Error creating classifier\n");
+        logistic_regression_trainer_destroy(trainer);
+        return NULL;
+    }
 
     // Reassign weights and features to the classifier model
     classifier->weights = trainer->weights;
@@ -232,6 +389,7 @@
     classifier->num_features = trainer->num_features;
     classifier->features = trainer->feature_ids;
 
+    // Set trainer feature_ids to NULL so it doesn't get destroyed
     trainer->feature_ids = NULL;
 
     size_t num_labels = trainer->num_labels;
@@ -262,7 +420,7 @@
 }
 
-#define LANGUAGE_CLASSIFIER_TRAIN_USAGE "Usage: ./address_parser_train [train|cv] filename [cv_filename] [output_dir]\n"
+#define LANGUAGE_CLASSIFIER_TRAIN_USAGE "Usage: ./language_classifier_train [train|cv] filename [cv_filename] [test_filename] [output_dir]\n"
 
 int main(int argc, char **argv) {
     if (argc < 3) {
@@ -271,10 +429,10 @@
     }
 
     char *command = argv[1];
 
-    bool cross_validate = false;
+    bool cross_validation_set = false;
 
     if (string_equals(command, "cv")) {
-        cross_validate = true;
+        cross_validation_set = true;
     } else if (!string_equals(command, "train")) {
         printf(LANGUAGE_CLASSIFIER_TRAIN_USAGE);
         exit(EXIT_FAILURE);
@@ -282,16 +440,18 @@
     }
 
     char *filename = argv[2];
     char *cv_filename = NULL;
+    char *test_filename = NULL;
 
-    if (cross_validate && argc < 4) {
+    if (cross_validation_set && argc < 5) {
         printf(LANGUAGE_CLASSIFIER_TRAIN_USAGE);
         exit(EXIT_FAILURE);
-    } else if (cross_validate) {
+    } else if (cross_validation_set) {
         cv_filename = argv[3];
+        test_filename = argv[4];
     }
 
     char *output_dir = LIBPOSTAL_LANGUAGE_CLASSIFIER_DIR;
-    int output_dir_arg = cross_validate ? 4 : 3;
+    int output_dir_arg = cross_validation_set ? 5 : 3;
 
     if (argc > output_dir_arg) {
         output_dir = argv[output_dir_arg];
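With the reworked argument parsing above, the two modes take different argument counts: "train" needs only the training file, with an optional output directory at argv[3], while "cv" also requires a cross-validation file and a test file and pushes the optional output directory to argv[5]. Illustrative invocations; the data file names and output directory are placeholders and the binary name follows the usage string:

    ./language_classifier_train train languages_train.tsv /usr/local/share/libpostal
    ./language_classifier_train cv languages_train.tsv languages_cv.tsv languages_test.tsv /usr/local/share/libpostal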
@@ -306,7 +466,50 @@
         exit(EXIT_FAILURE);
     }
 
-    language_classifier_t *language_classifier = language_classifier_train(filename, cv_filename, TRAIN_EPOCHS, false);
+    char_array *temp_file = char_array_new();
+    char_array_cat_printf(temp_file, ".%s.tmp", filename);
+
+    char *temp_filename = char_array_get_string(temp_file);
+
+    char_array *head_command = char_array_new();
+
+    size_t subset_examples = LANGUAGE_CLASSIFIER_HYPERPARAMETER_BATCHES * LANGUAGE_CLASSIFIER_DEFAULT_BATCH_SIZE;
+
+    char_array_cat_printf(head_command, "head -n %zu %s > %s", subset_examples, filename, temp_filename);
+    int ret = system(char_array_get_string(head_command));
+
+    if (ret != 0) {
+        exit(EXIT_FAILURE);
+    }
+
+    char_array *temp_cv_file = NULL;
+    if (!cross_validation_set) {
+        char_array_clear(head_command);
+
+        temp_cv_file = char_array_new();
+        char_array_cat_printf(temp_cv_file, ".%s.cv.tmp", filename);
+
+        char *temp_cv_filename = char_array_get_string(temp_cv_file);
+
+        char_array_cat_printf(head_command, "head -n %zu %s | tail -n %zu > %s", subset_examples * 2, filename, subset_examples, temp_cv_filename);
+        ret = system(char_array_get_string(head_command));
+
+        cv_filename = temp_cv_filename;
+    }
+
+    if (ret != 0) {
+        exit(EXIT_FAILURE);
+    }
+
+    char_array_destroy(head_command);
+
+    language_classifier_t *language_classifier = language_classifier_train(filename, temp_filename, cross_validation_set, cv_filename, test_filename, TRAIN_EPOCHS);
+
+    remove(temp_filename);
+    char_array_destroy(temp_file);
+    if (temp_cv_file != NULL) {
+        char_array_destroy(temp_cv_file);
+    }
 
     log_info("Done with classifier\n");
 
     char_array *path = char_array_new_size(strlen(output_dir) + PATH_SEPARATOR_LEN + strlen(LANGUAGE_CLASSIFIER_COUNTRY_FILENAME));
@@ -320,18 +523,6 @@
         language_classifier_destroy(language_classifier);
     }
 
-    language_classifier_t *language_classifier_country = language_classifier_train(filename, cv_filename, TRAIN_EPOCHS, true);
-
-    if (language_classifier_country != NULL) {
-        char_array_clear(path);
-        char_array_cat_joined(path, PATH_SEPARATOR, true, 2, output_dir, LANGUAGE_CLASSIFIER_COUNTRY_FILENAME);
-
-        classifier_path = char_array_get_string(path);
-
-        language_classifier_save(language_classifier_country, classifier_path);
-        language_classifier_destroy(language_classifier_country);
-    }
-
     char_array_destroy(path);
 
     log_info("Success!\n");
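The first hunk of logistic_regression.c below changes how the L2 penalty enters the batch gradient: with regularization on, the penalty is applied either to just the feature columns present in the minibatch (x_cols) or, when no column list is given, to every feature. The quantity added per weight is the derivative of the usual ridge term, lambda times theta. A minimal dense sketch of that update over plain arrays (illustrative names, not libpostal's internals):

    #include <stddef.h>

    /* Adds d/dtheta of (lambda / 2) * sum(theta^2), i.e. lambda * theta,
       to a row-major [feature][class] gradient buffer. */
    static void add_l2_gradient(double *gradient, const double *theta,
                                size_t num_features, size_t num_classes, double lambda) {
        for (size_t i = 0; i < num_features; i++) {
            for (size_t j = 0; j < num_classes; j++) {
                size_t idx = i * num_classes + j;
                gradient[idx] += lambda * theta[idx];
            }
        }
    }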
diff --git a/src/logistic_regression.c b/src/logistic_regression.c
index ef970a73..3a865af8 100644
--- a/src/logistic_regression.c
+++ b/src/logistic_regression.c
@@ -18,7 +18,6 @@ bool logistic_regression_model_expectation(matrix_t *theta, sparse_matrix_t *x,
     return true;
 }
 
-
 double logistic_regression_cost_function(matrix_t *theta, sparse_matrix_t *x, uint32_array *y, matrix_t *p_y, double lambda) {
     size_t m = x->m;
     size_t n = x->n;
@@ -139,28 +138,29 @@ static bool logistic_regression_gradient_params(matrix_t *theta, matrix_t *gradi
     // Update the only the relevant columns in x
-    if (regularize && x_cols != NULL) {
-        size_t batch_rows = x_cols->n;
-        uint32_t *cols = x_cols->a;
-        for (i = 0; i < batch_rows; i++) {
-            col = cols[i];
+    if (regularize) {
+        size_t num_rows = num_features;
+        uint32_t *cols = NULL;
+
+        if (x_cols != NULL) {
+            cols = x_cols->a;
+            num_rows = x_cols->n;
+        }
+
+        for (i = 0; i < num_rows; i++) {
+            col = x_cols != NULL ? cols[i] : i;
             for (j = 0; j < num_classes; j++) {
-                size_t idx = i * num_classes + j;
+                size_t idx = col * num_classes + j;
                 double theta_ij = theta_values[idx];
-                double reg_value = theta_ij * lambda;
-                gradient_values[idx] += reg_value;
+                double reg_update = theta_ij * lambda;
+                gradient_values[idx] += reg_update;
             }
         }
-    } else if (regularize) {
-        for (i = 1; i < num_features; i++) {
-            for (j = 0; j < num_classes; j++) {
-                size_t idx = i * num_classes + j;
-                double theta_ij = theta_values[idx];
-                double reg_value = theta_ij * lambda;
-                gradient_values[idx] += reg_value;
-            }
-        }
     }
 
     return true;
diff --git a/src/logistic_regression_trainer.c b/src/logistic_regression_trainer.c
index 775636d7..b2df4533 100644
--- a/src/logistic_regression_trainer.c
+++ b/src/logistic_regression_trainer.c
@@ -35,7 +35,7 @@ void logistic_regression_trainer_destroy(logistic_regression_trainer_t *self) {
     free(self);
 }
 
-logistic_regression_trainer_t *logistic_regression_trainer_init(trie_t *feature_ids, khash_t(str_uint32) *label_ids) {
+logistic_regression_trainer_t *logistic_regression_trainer_init(trie_t *feature_ids, khash_t(str_uint32) *label_ids, double gamma_0, double lambda) {
     if (feature_ids == NULL || label_ids == NULL) return NULL;
 
     logistic_regression_trainer_t *trainer = malloc(sizeof(logistic_regression_trainer_t));
@@ -57,11 +57,10 @@ logistic_regression_trainer_t *logistic_regression_trainer_init(trie_t *feature_
     trainer->last_updated = uint32_array_new_zeros(trainer->num_features);
 
-    trainer->lambda = DEFAULT_LAMBDA;
+    trainer->lambda = lambda;
     trainer->iters = 0;
     trainer->epochs = 0;
-    trainer->gamma_0 = DEFAULT_GAMMA_0;
-    trainer->gamma = DEFAULT_GAMMA;
+    trainer->gamma_0 = gamma_0;
 
     return trainer;
 
@@ -118,7 +117,7 @@ bool logistic_regression_trainer_train_batch(logistic_regression_trainer_t *self
         goto exit_matrices_created;
     }
 
-    if (self->lambda > 0.0 && !stochastic_gradient_descent_sparse_regularize_weights(self->weights, self->batch_columns, self->last_updated, self->iters, self->lambda, self->gamma_0)) {
+    if (self->lambda > 0.0 && !stochastic_gradient_descent_regularize_weights(self->weights, self->batch_columns, self->last_updated, self->iters, self->lambda, self->gamma_0)) {
         log_error("Error regularizing weights\n");
         goto exit_matrices_created;
     }
@@ -130,7 +129,8 @@ bool logistic_regression_trainer_train_batch(logistic_regression_trainer_t *self
 
     size_t data_len = m * n;
 
-    ret = stochastic_gradient_descent_sparse(self->weights, gradient, self->batch_columns, self->gamma);
+    double gamma = stochastic_gradient_descent_gamma_t(self->gamma_0, self->lambda, self->iters);
+    ret = stochastic_gradient_descent_sparse(self->weights, gradient, self->batch_columns, gamma);
 
     self->iters++;
 
@@ -145,7 +145,7 @@ bool logistic_regression_trainer_finalize(logistic_regression_trainer_t *self) {
     if (self == NULL) return false;
 
     if (self->lambda > 0.0) {
-        return stochastic_gradient_descent_sparse_finalize_weights(self->weights, self->last_updated, self->iters, self->lambda, self->gamma_0);
+        return stochastic_gradient_descent_finalize_weights(self->weights, self->last_updated, self->iters, self->lambda, self->gamma_0);
     }
 
     return true;
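logistic_regression_trainer.c above now derives the step size per iteration from gamma_0, lambda, and iters instead of a fixed gamma. The exact formula lives in stochastic_gradient_descent_gamma_t, which this patch does not show; a common choice for L2-regularized SGD, and a plausible guess at what it computes, is the Bottou-style decay gamma_t = gamma_0 / (1 + gamma_0 * lambda * t). A sketch under that assumption:

    #include <stddef.h>

    /* Assumed decay schedule (not shown in the patch). With lambda == 0
       this reduces to a constant step size gamma_0. */
    static double sgd_gamma_t_sketch(double gamma_0, double lambda, size_t t) {
        if (lambda == 0.0) return gamma_0;
        return gamma_0 / (1.0 + gamma_0 * lambda * (double)t);
    }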
"tokens.h" #include "trie.h" -#define DEFAULT_GAMMA_SCHEDULE {0.1, 0.2, 0.5, 1.0, 2.0, 5.0, 10.0} -#define DEFAUlT_LAMBDA_SCHEDULE {0.0, 1e-5, 1e-4, 0.001, 0.01, 0.1, \ - 0.2, 0.5, 1.0, 2.0, 5.0, 10.0} - -#define DEFAULT_GAMMA_0 1.0 -#define DEFAULT_LAMBDA 0.0 -#define DEFAULT_GAMMA 0.1 - /** * Helper struct for training logistic regression model */ @@ -43,11 +35,10 @@ typedef struct logistic_regression_trainer { uint32_t iters; // Number of iterations, used to decay learning rate uint32_t epochs; // Number of epochs double gamma_0; // Initial learning rate - double gamma; // Simple scalar learning rate } logistic_regression_trainer_t; -logistic_regression_trainer_t *logistic_regression_trainer_init(trie_t *feature_ids, khash_t(str_uint32) *label_ids); +logistic_regression_trainer_t *logistic_regression_trainer_init(trie_t *feature_ids, khash_t(str_uint32) *label_ids, double gamma_0, double lambda); bool logistic_regression_trainer_train_batch(logistic_regression_trainer_t *self, feature_count_array *features, cstring_array *labels); double logistic_regression_trainer_batch_cost(logistic_regression_trainer_t *self, feature_count_array *features, cstring_array *labels);