diff --git a/src/language_classifier_train.c b/src/language_classifier_train.c
index ddf73719..5e06dd27 100644
--- a/src/language_classifier_train.c
+++ b/src/language_classifier_train.c
@@ -2,9 +2,12 @@
 #include <stdlib.h>
 #include <string.h>
 #include <math.h>
+#include <float.h>
 
 #include "log/log.h"
 #include "address_dictionary.h"
+#include "cartesian_product.h"
+#include "collections.h"
 #include "language_classifier.h"
 #include "language_classifier_io.h"
 #include "logistic_regression.h"
@@ -12,31 +15,50 @@
 #include "shuffle.h"
 #include "sparse_matrix.h"
 #include "sparse_matrix_utils.h"
+#include "stochastic_gradient_descent.h"
+#include "transliterate.h"
 
-#define LANGUAGE_CLASSIFIER_FEATURE_COUNT_THRESHOLD 5.0
+#define LANGUAGE_CLASSIFIER_FEATURE_COUNT_THRESHOLD 3.0
 #define LANGUAGE_CLASSIFIER_LABEL_COUNT_THRESHOLD 100
 
 #define LOG_BATCH_INTERVAL 10
 #define COMPUTE_COST_INTERVAL 100
+#define COMPUTE_CV_INTERVAL 1000
 
-#define LANGUAGE_CLASSIFIER_HYPERPARAMETER_BATCHES 100
+#define LANGUAGE_CLASSIFIER_HYPERPARAMETER_BATCHES 50
+
+// Hyperparameters for stochastic gradient descent
 
 static double GAMMA_SCHEDULE[] = {0.01, 0.1, 0.2, 0.5, 1.0, 2.0, 5.0, 10.0};
 static const size_t GAMMA_SCHEDULE_SIZE = sizeof(GAMMA_SCHEDULE) / sizeof(double);
 #define DEFAULT_GAMMA_0 10.0
 
-static double LAMBDA_SCHEDULE[] = {0.0, 1e-5, 1e-4, 0.001, 0.01, 0.1, \
-                                   0.2, 0.5, 1.0};
-static const size_t LAMBDA_SCHEDULE_SIZE = sizeof(LAMBDA_SCHEDULE) / sizeof(double);
+#define REGULARIZATION_SCHEDULE {0.0, 1e-7, 1e-6, 1e-5, 1e-4, 0.001, 0.01, 0.1, \
+                                 0.2, 0.5, 1.0, 2.0, 5.0, 10.0}
 
-#define DEFAULT_LAMBDA 0.0
+static double L2_SCHEDULE[] = REGULARIZATION_SCHEDULE;
+static const size_t L2_SCHEDULE_SIZE = sizeof(L2_SCHEDULE) / sizeof(double);
+
+static double L1_SCHEDULE[] = REGULARIZATION_SCHEDULE;
+static const size_t L1_SCHEDULE_SIZE = sizeof(L1_SCHEDULE) / sizeof(double);
+
+#define DEFAULT_L2 1e-6
+#define DEFAULT_L1 1e-4
+
+// Hyperparameters for FTRL-Proximal
+
+static double ALPHA_SCHEDULE[] = {0.01, 0.1, 0.2, 0.5, 1.0, 2.0, 5.0, 10.0};
+static const size_t ALPHA_SCHEDULE_SIZE = sizeof(ALPHA_SCHEDULE) / sizeof(double);
+static double DEFAULT_BETA = 1.0;
+
+#define DEFAULT_ALPHA 10.0
 
 #define TRAIN_EPOCHS 10
-#define HYPERPARAMETER_EPOCHS 30
+#define HYPERPARAMETER_EPOCHS 5
 
-logistic_regression_trainer_t *language_classifier_init_thresholds(char *filename, double feature_count_threshold, uint32_t label_count_threshold) {
+logistic_regression_trainer_t *language_classifier_init_params(char *filename, double feature_count_threshold, uint32_t label_count_threshold, size_t minibatch_size, logistic_regression_optimizer_type optim_type, regularization_type_t reg_type) {
     if (filename == NULL) {
         log_error("Filename was NULL\n");
         return NULL;
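/* For reference: the per-coordinate FTRL-Proximal update that the alpha, beta,
   lambda1 and lambda2 hyperparameters above feed into (McMahan et al. 2013,
   "Ad Click Prediction: a View from the Trenches"). This is a minimal
   single-coordinate sketch for illustration only; the actual optimizer sits
   behind logistic_regression_trainer_init_ftrl() and is not part of this diff. */

#include <math.h>

typedef struct {
    double z;   // accumulated gradients minus the proximal adjustment terms
    double n;   // accumulated squared gradients
} ftrl_coord_t;

// One FTRL-Proximal step for a single weight, given gradient g.
// Returns the post-update weight implied by the new (z, n) state.
static double ftrl_update_coord(ftrl_coord_t *c, double g, double alpha,
                                double beta, double lambda1, double lambda2) {
    // Current weight implied by (z, n): zero unless |z| exceeds the L1 threshold
    double w = 0.0;
    if (fabs(c->z) > lambda1) {
        double sign_z = c->z < 0.0 ? -1.0 : 1.0;
        w = -(c->z - sign_z * lambda1) / ((beta + sqrt(c->n)) / alpha + lambda2);
    }

    // Per-coordinate adaptive learning-rate term
    double sigma = (sqrt(c->n + g * g) - sqrt(c->n)) / alpha;
    c->z += g - sigma * w;
    c->n += g * g;

    // Closed-form solution of the proximal step with the updated state
    if (fabs(c->z) <= lambda1) return 0.0;
    double sign_z = c->z < 0.0 ? -1.0 : 1.0;
    return -(c->z - sign_z * lambda1) / ((beta + sqrt(c->n)) / alpha + lambda2);
}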
@@ -51,14 +73,14 @@ logistic_regression_trainer_t *language_classifier_init_thresholds(char *filenam
     size_t num_batches = 0;
 
     // Count features and labels
-    while ((minibatch = language_classifier_data_set_get_minibatch(data_set, NULL)) != NULL) {
+    while ((minibatch = language_classifier_data_set_get_minibatch_with_size(data_set, NULL, minibatch_size)) != NULL) {
         if (!count_labels_minibatch(label_counts, minibatch->labels)) {
             log_error("Counting minibatch labels failed\n");
             exit(EXIT_FAILURE);
         }
 
-        if (num_batches % LOG_BATCH_INTERVAL == 0) {
-            log_info("Counting labels, did %zu batches\n", num_batches);
+        if (num_batches % LOG_BATCH_INTERVAL == 0 && num_batches > 0) {
+            log_info("Counting labels, did %zu examples\n", num_batches * minibatch_size);
         }
 
         num_batches++;
@@ -91,8 +113,8 @@ logistic_regression_trainer_t *language_classifier_init_thresholds(char *filenam
             exit(EXIT_FAILURE);
         }
 
-        if (num_batches % LOG_BATCH_INTERVAL == 0) {
-            log_info("Counting features, did %zu batches\n", num_batches);
+        if (num_batches % LOG_BATCH_INTERVAL == 0 && num_batches > 0) {
+            log_info("Counting features, did %zu examples\n", num_batches * minibatch_size);
         }
 
         num_batches++;
@@ -119,11 +141,34 @@ logistic_regression_trainer_t *language_classifier_init_thresholds(char *filenam
 
     kh_destroy(str_double, feature_counts);
 
-    return logistic_regression_trainer_init(feature_ids, label_ids, DEFAULT_GAMMA_0, DEFAULT_LAMBDA);
+    logistic_regression_trainer_t *trainer = NULL;
+
+    if (optim_type == LOGISTIC_REGRESSION_OPTIMIZER_SGD) {
+        bool fit_intercept = true;
+        double default_lambda = 0.0;
+        if (reg_type == REGULARIZATION_L2) {
+            default_lambda = DEFAULT_L2;
+        } else if (reg_type == REGULARIZATION_L1) {
+            default_lambda = DEFAULT_L1;
+        }
+        trainer = logistic_regression_trainer_init_sgd(feature_ids, label_ids, fit_intercept, reg_type, default_lambda, DEFAULT_GAMMA_0);
+    } else if (optim_type == LOGISTIC_REGRESSION_OPTIMIZER_FTRL) {
+        trainer = logistic_regression_trainer_init_ftrl(feature_ids, label_ids, DEFAULT_ALPHA, DEFAULT_BETA, DEFAULT_L1, DEFAULT_L2);
+    }
+
+    return trainer;
 }
 
-logistic_regression_trainer_t *language_classifier_init(char *filename) {
-    return language_classifier_init_thresholds(filename, LANGUAGE_CLASSIFIER_FEATURE_COUNT_THRESHOLD, LANGUAGE_CLASSIFIER_LABEL_COUNT_THRESHOLD);
+logistic_regression_trainer_t *language_classifier_init_optim_reg(char *filename, size_t minibatch_size, logistic_regression_optimizer_type optim_type, regularization_type_t reg_type) {
+    return language_classifier_init_params(filename, LANGUAGE_CLASSIFIER_FEATURE_COUNT_THRESHOLD, LANGUAGE_CLASSIFIER_LABEL_COUNT_THRESHOLD, minibatch_size, optim_type, reg_type);
+}
+
+logistic_regression_trainer_t *language_classifier_init_sgd_reg(char *filename, size_t minibatch_size, regularization_type_t reg_type) {
+    return language_classifier_init_params(filename, LANGUAGE_CLASSIFIER_FEATURE_COUNT_THRESHOLD, LANGUAGE_CLASSIFIER_LABEL_COUNT_THRESHOLD, minibatch_size, LOGISTIC_REGRESSION_OPTIMIZER_SGD, reg_type);
+}
+
+logistic_regression_trainer_t *language_classifier_init_ftrl(char *filename, size_t minibatch_size) {
+    return language_classifier_init_params(filename, LANGUAGE_CLASSIFIER_FEATURE_COUNT_THRESHOLD, LANGUAGE_CLASSIFIER_LABEL_COUNT_THRESHOLD, minibatch_size, LOGISTIC_REGRESSION_OPTIMIZER_FTRL, REGULARIZATION_NONE);
 }
 
 double compute_cv_accuracy(logistic_regression_trainer_t *trainer, char *filename) {
@@ -140,9 +185,20 @@ double compute_cv_accuracy(logistic_regression_trainer_t *trainer, char *filenam
         sparse_matrix_t *x = feature_matrix(trainer->feature_ids, minibatch->features);
         uint32_array *y = label_vector(trainer->label_ids, minibatch->labels);
 
-        double_matrix_resize(p_y, x->m, trainer->num_labels);
+        if (!double_matrix_resize_aligned(p_y, x->m, trainer->num_labels, 16)) {
+            log_error("resize p_y failed\n");
+            exit(EXIT_FAILURE);
+        }
+        double_matrix_zero(p_y);
 
-        if (!logistic_regression_model_expectation(trainer->weights, x, p_y)) {
+        if (!sparse_matrix_add_unique_columns(x, trainer->unique_columns, trainer->batch_columns)) {
+            log_error("Error adding unique columns\n");
+            exit(EXIT_FAILURE);
+        }
+
+        double_matrix_t *theta = logistic_regression_trainer_get_regularized_weights(trainer);
+
+        if (!logistic_regression_model_expectation(theta, x, p_y)) {
             log_error("Predict cv batch failed\n");
             exit(EXIT_FAILURE);
         }
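/* Note that the prediction path above fetches weights through
   logistic_regression_trainer_get_regularized_weights() rather than reading
   the trainer's weights directly. With sparse minibatches, regularization is
   typically applied lazily -- only to the columns a batch actually touches --
   so the weights must be brought up to date before anyone reads them. One
   common lazy-L2 scheme looks like the sketch below; this is illustrative
   only, and the real update lives inside the trainer, not in this diff. */

#include <math.h>
#include <stdint.h>

// Apply the L2 shrinkage a column missed while it went untouched.
// last_updated[j] records the iteration at which column j was last shrunk.
static void lazy_l2_catch_up(double *w, uint32_t *last_updated, uint32_t j,
                             uint32_t t, double gamma, double lambda) {
    uint32_t skipped = t - last_updated[j];
    if (skipped > 0) {
        // Each skipped step would have multiplied w[j] by (1 - gamma * lambda)
        w[j] *= pow(1.0 - gamma * lambda, (double)skipped);
        last_updated[j] = t;
    }
}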
@@ -187,9 +243,12 @@ double compute_total_cost(logistic_regression_trainer_t *trainer, char *filename
     double total_cost = 0.0;
     size_t num_batches = 0;
 
+    // Need to regularize the weights
+    double_matrix_t *theta = logistic_regression_trainer_get_regularized_weights(trainer);
+
     while ((minibatch = language_classifier_data_set_get_minibatch(data_set, trainer->label_ids)) != NULL) {
-        double batch_cost = logistic_regression_trainer_batch_cost(trainer, minibatch->features, minibatch->labels);
+        double batch_cost = logistic_regression_trainer_minibatch_cost(trainer, minibatch->features, minibatch->labels);
 
         total_cost += batch_cost;
 
         language_classifier_minibatch_destroy(minibatch);
@@ -207,16 +266,16 @@ double compute_total_cost(logistic_regression_trainer_t *trainer, char *filename
 }
 
-bool language_classifier_train_epoch(logistic_regression_trainer_t *trainer, char *filename, char *cv_filename, ssize_t train_batches) {
+bool language_classifier_train_epoch(logistic_regression_trainer_t *trainer, char *filename, char *cv_filename, ssize_t train_batches, size_t minibatch_size) {
     if (filename == NULL) {
         log_error("Filename was NULL\n");
         return false;
     }
 
-    #if defined(HAVE_SHUF)
+    #if defined(HAVE_SHUF) || defined(HAVE_GSHUF)
     log_info("Shuffling\n");
 
-    if (!shuffle_file(filename)) {
+    if (!shuffle_file_chunked_size(filename, DEFAULT_SHUFFLE_CHUNK_SIZE)) {
         log_error("Error in shuffle\n");
         logistic_regression_trainer_destroy(trainer);
         return false;
@@ -237,24 +296,25 @@ bool language_classifier_train_epoch(logistic_regression_trainer_t *trainer, cha
     double train_cost = 0.0;
     double cv_accuracy = 0.0;
 
-    while ((minibatch = language_classifier_data_set_get_minibatch(data_set, trainer->label_ids)) != NULL) {
-        bool compute_cost = num_batches % COMPUTE_COST_INTERVAL == 0 && num_batches > 0;
-
+    while ((minibatch = language_classifier_data_set_get_minibatch_with_size(data_set, trainer->label_ids, minibatch_size)) != NULL) {
+        bool compute_cost = num_batches % COMPUTE_COST_INTERVAL == 0;
+        bool compute_cv = num_batches % COMPUTE_CV_INTERVAL == 0 && num_batches > 0 && cv_filename != NULL;
+
         if (num_batches % LOG_BATCH_INTERVAL == 0 && num_batches > 0) {
-            log_info("Epoch %u, doing batch %zu\n", trainer->epochs, num_batches);
+            log_info("Epoch %u, doing %zu examples\n", trainer->epochs, num_batches * minibatch_size);
         }
 
         if (compute_cost) {
-            train_cost = logistic_regression_trainer_batch_cost(trainer, minibatch->features, minibatch->labels);
+            train_cost = logistic_regression_trainer_minibatch_cost(trainer, minibatch->features, minibatch->labels);
            log_info("cost = %f\n", train_cost);
         }
 
-        if (!logistic_regression_trainer_train_batch(trainer, minibatch->features, minibatch->labels)){
+        if (!logistic_regression_trainer_train_minibatch(trainer, minibatch->features, minibatch->labels)){
             log_error("Train batch failed\n");
             exit(EXIT_FAILURE);
         }
 
-        if (compute_cost && cv_filename != NULL) {
+        if (compute_cv) {
             cv_accuracy = compute_cv_accuracy(trainer, cv_filename);
             log_info("cv accuracy=%f\n", cv_accuracy);
         }
@@ -262,8 +322,8 @@ bool language_classifier_train_epoch(logistic_regression_trainer_t *trainer, cha
         num_batches++;
 
         if (train_batches > 0 && num_batches == (size_t)train_batches) {
-            log_info("Epoch %u, trained %zu batches\n", trainer->epochs, num_batches);
-            train_cost = logistic_regression_trainer_batch_cost(trainer, minibatch->features, minibatch->labels);
+            log_info("Epoch %u, trained %zu examples\n", trainer->epochs, num_batches * minibatch_size);
+            train_cost = logistic_regression_trainer_minibatch_cost(trainer, minibatch->features, minibatch->labels);
             log_info("cost = %f\n", train_cost);
             break;
         }
@@ -277,100 +337,236 @@ bool language_classifier_train_epoch(cha
     return true;
 }
 
-typedef struct language_classifier_params {
+static double language_classifier_cv_cost(logistic_regression_trainer_t *trainer, char *filename, char *cv_filename, size_t minibatch_size, bool *diverged) {
+    ssize_t cost_batches;
+    char *cost_file;
+
+    if (cv_filename == NULL) {
+        cost_file = filename;
+        cost_batches = LANGUAGE_CLASSIFIER_HYPERPARAMETER_BATCHES;
+    } else {
+        cost_file = cv_filename;
+        cost_batches = -1;
+    }
+
+    double initial_cost = compute_total_cost(trainer, cost_file, cost_batches);
+
+    for (size_t k = 0; k < HYPERPARAMETER_EPOCHS; k++) {
+        trainer->epochs = k;
+
+        if (!language_classifier_train_epoch(trainer, filename, NULL, LANGUAGE_CLASSIFIER_HYPERPARAMETER_BATCHES, minibatch_size)) {
+            log_error("Error in epoch\n");
+            logistic_regression_trainer_destroy(trainer);
+            exit(EXIT_FAILURE);
+        }
+    }
+
+    double final_cost = compute_total_cost(trainer, cost_file, cost_batches);
+
+    *diverged = final_cost > initial_cost;
+    log_info("final_cost = %f, initial_cost = %f\n", final_cost, initial_cost);
+
+    return final_cost;
+}
+
+typedef struct language_classifier_sgd_params {
     double lambda;
     double gamma_0;
-} language_classifier_params_t;
+} language_classifier_sgd_params_t;
 
-language_classifier_params_t language_classifier_parameter_sweep(char *filename, char *cv_filename) {
-    // Select features using the full data set
-    logistic_regression_trainer_t *trainer = language_classifier_init(filename);
+typedef struct language_classifier_ftrl_params {
+    double alpha;
+    double lambda1;
+    double lambda2;
+} language_classifier_ftrl_params_t;
 
-    double best_cost = 0.0;
+VECTOR_INIT(language_classifier_sgd_param_array, language_classifier_sgd_params_t)
+VECTOR_INIT(language_classifier_ftrl_param_array, language_classifier_ftrl_params_t)
 
-    language_classifier_params_t best_params = (language_classifier_params_t){0.0, 0.0};
+/* Uses the one standard-error rule (http://www.stat.cmu.edu/~ryantibs/datamining/lectures/19-val2.pdf).
+   A solution that's better regularized is preferred if it's within one standard error
+   of the solution with the lowest cross-validation error.
+*/
 
-    for (size_t i = 0; i < LAMBDA_SCHEDULE_SIZE; i++) {
-        for (size_t j = 0; j < GAMMA_SCHEDULE_SIZE; j++) {
-            trainer->lambda = LAMBDA_SCHEDULE[i];
-            trainer->gamma_0 = GAMMA_SCHEDULE[j];
+language_classifier_sgd_params_t language_classifier_parameter_sweep_sgd(logistic_regression_trainer_t *trainer, char *filename, char *cv_filename, size_t minibatch_size) {
+    double best_cost = DBL_MAX;
 
-            log_info("Optimizing hyperparameters. Trying lambda=%f, gamma_0=%f\n", trainer->lambda, trainer->gamma_0);
+    double default_lambda = 0.0;
+    size_t lambda_schedule_size = 0;
+    double *lambda_schedule = NULL;
 
-            for (int k = 0; k < HYPERPARAMETER_EPOCHS; k++) {
-                trainer->epochs = k;
+    sgd_trainer_t *sgd = trainer->optimizer.sgd;
 
-                if (!language_classifier_train_epoch(trainer, filename, NULL, LANGUAGE_CLASSIFIER_HYPERPARAMETER_BATCHES)) {
-                    log_error("Error in epoch\n");
-                    logistic_regression_trainer_destroy(trainer);
-                    exit(EXIT_FAILURE);
-                }
-            }
+    if (sgd->reg_type == REGULARIZATION_L2) {
+        default_lambda = DEFAULT_L2;
+        lambda_schedule_size = L2_SCHEDULE_SIZE;
+        lambda_schedule = L2_SCHEDULE;
+    } else if (sgd->reg_type == REGULARIZATION_L1) {
+        lambda_schedule_size = L1_SCHEDULE_SIZE;
+        lambda_schedule = L1_SCHEDULE;
+        default_lambda = DEFAULT_L1;
+    }
 
-            ssize_t cost_batches;
-            char *cost_file;
+    double_array *costs = double_array_new();
+    language_classifier_sgd_param_array *all_params = language_classifier_sgd_param_array_new();
 
-            if (cv_filename == NULL) {
-                cost_file = filename;
-                cost_batches = LANGUAGE_CLASSIFIER_HYPERPARAMETER_BATCHES;
-            } else {
-                cost_file = cv_filename;
-                cost_batches = -1;
-            }
+    language_classifier_sgd_params_t best_params = (language_classifier_sgd_params_t){default_lambda, DEFAULT_GAMMA_0};
+    double cost;
+    language_classifier_sgd_params_t params;
 
-            double cost = compute_total_cost(trainer, cost_file, cost_batches);
-            log_info("Total cost = %f\n", cost);
-            if ((i == 0 && j == 0) || cost < best_cost) {
-                log_info("Better than current best parameters: lambda=%f, gamma_0=%f\n", trainer->lambda, trainer->gamma_0);
-                best_cost = cost;
-                best_params.lambda = trainer->lambda;
-                best_params.gamma_0 = trainer->gamma_0;
+    cartesian_product_iterator_t *iter = cartesian_product_iterator_new(2, lambda_schedule_size, GAMMA_SCHEDULE_SIZE);
+    for (uint32_t *vals = cartesian_product_iterator_start(iter); !cartesian_product_iterator_done(iter); vals = cartesian_product_iterator_next(iter)) {
+        double lambda = lambda_schedule[vals[0]];
+        double gamma_0 = GAMMA_SCHEDULE[vals[1]];
+
+        params.lambda = lambda,
+        params.gamma_0 = gamma_0;
+
+        if (!logistic_regression_trainer_reset_params_sgd(trainer, lambda, gamma_0)) {
+            log_error("Error resetting params\n");
+            logistic_regression_trainer_destroy(trainer);
+            exit(EXIT_FAILURE);
+        }
+
+        log_info("Optimizing hyperparameters. Trying lambda=%.7f, gamma_0=%f\n", lambda, gamma_0);
+
+        bool diverged = false;
+        cost = language_classifier_cv_cost(trainer, filename, cv_filename, minibatch_size, &diverged);
+
+        if (!diverged) {
+            language_classifier_sgd_param_array_push(all_params, params);
+            double_array_push(costs, cost);
+        } else {
+            log_info("Diverged, cost = %f\n", cost);
+        }
+
+        log_info("Total cost = %f\n", cost);
+        if (cost < best_cost) {
+            log_info("Better than current best parameters: setting lambda=%.7f, gamma_0=%f\n", lambda, gamma_0);
+            best_cost = cost;
+            best_params.lambda = lambda;
+            best_params.gamma_0 = gamma_0;
+        }
+    }
+
+    size_t num_params = costs->n;
+    if (num_params > 0) {
+        language_classifier_sgd_params_t *param_values = all_params->a;
+        double *cost_values = costs->a;
+
+        double std_error = double_array_std(cost_values, num_params) / sqrt((double)num_params);
+
+        double max_cost = best_cost + std_error;
+        log_info("max_cost = %f using the one standard error rule\n", max_cost);
+
+        for (size_t i = 0; i < num_params; i++) {
+            cost = cost_values[i];
+            params = param_values[i];
+
+            if (cost < max_cost && params.lambda > best_params.lambda) {
+                best_params = params;
+                log_info("cost (%f) < max_cost and better regularized, setting lambda=%.7f, gamma_0=%f\n", cost, params.lambda, params.gamma_0);
             }
         }
     }
 
+    language_classifier_sgd_param_array_destroy(all_params);
+    double_array_destroy(costs);
+
     return best_params;
 }
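/* The selection loop above is the one standard-error rule in action: among
   all (lambda, gamma_0) pairs whose cross-validation cost falls within one
   standard error of the minimum, prefer the most strongly regularized one.
   Distilled into a standalone form below, purely for illustration. Whether
   double_array_std computes the population or sample deviation is internal
   to libpostal; this sketch assumes the population form, with the standard
   error taken as std / sqrt(n) exactly as in the code above. */

#include <math.h>
#include <stddef.h>

// Return the index of the largest lambda whose cost is within one standard
// error of the best cost; assumes n > 0.
static size_t one_standard_error_rule(const double *costs, const double *lambdas, size_t n) {
    double best_cost = costs[0];
    size_t best = 0;
    double mean = 0.0;

    for (size_t i = 0; i < n; i++) {
        mean += costs[i];
        if (costs[i] < best_cost) {
            best_cost = costs[i];
            best = i;
        }
    }
    mean /= (double)n;

    double var = 0.0;
    for (size_t i = 0; i < n; i++) {
        var += (costs[i] - mean) * (costs[i] - mean);
    }
    double std_error = sqrt(var / (double)n) / sqrt((double)n);

    // Any candidate under this cost ceiling is statistically indistinguishable
    // from the best; take the most regularized of them
    double max_cost = best_cost + std_error;
    for (size_t i = 0; i < n; i++) {
        if (costs[i] < max_cost && lambdas[i] > lambdas[best]) {
            best = i;
        }
    }
    return best;
}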
 
-language_classifier_t *language_classifier_train(char *filename, char *subset_filename, bool cross_validation_set, char *cv_filename, char *test_filename, uint32_t num_iterations) {
-    language_classifier_params_t params = language_classifier_parameter_sweep(subset_filename, cv_filename);
-    log_info("Best params: lambda=%f, gamma_0=%f\n", params.lambda, params.gamma_0);
-
-    logistic_regression_trainer_t *trainer = language_classifier_init(filename);
-    trainer->lambda = params.lambda;
-    trainer->gamma_0 = params.gamma_0;
+language_classifier_ftrl_params_t language_classifier_parameter_sweep_ftrl(logistic_regression_trainer_t *trainer, char *filename, char *cv_filename, size_t minibatch_size) {
+    double best_cost = DBL_MAX;
 
-    /* If there's not a distinct cross-validation set, e.g.
-       when training the production model, then the cross validation
-       file is just a subset of the training data and only used
-       for setting the hyperparameters, so ignore it after we're
-       done with the parameter sweep.
-    */
-    if (!cross_validation_set) {
-        cv_filename = NULL;
-    }
+    language_classifier_ftrl_params_t best_params = (language_classifier_ftrl_params_t){DEFAULT_ALPHA, DEFAULT_L1, DEFAULT_L2};
 
-    for (uint32_t epoch = 0; epoch < num_iterations; epoch++) {
-        log_info("Doing epoch %d\n", epoch);
+    double_array *costs = double_array_new();
+    language_classifier_ftrl_param_array *all_params = language_classifier_ftrl_param_array_new();
+    language_classifier_ftrl_params_t params;
+    double cost;
 
-        trainer->epochs = epoch;
-
-        if (!language_classifier_train_epoch(trainer, filename, cv_filename, -1)) {
-            log_error("Error in epoch\n");
+    cartesian_product_iterator_t *iter = cartesian_product_iterator_new(3, L1_SCHEDULE_SIZE, L2_SCHEDULE_SIZE, ALPHA_SCHEDULE_SIZE);
+    for (uint32_t *vals = cartesian_product_iterator_start(iter); !cartesian_product_iterator_done(iter); vals = cartesian_product_iterator_next(iter)) {
+        double lambda1 = L1_SCHEDULE[vals[0]];
+        double lambda2 = L2_SCHEDULE[vals[1]];
+        double alpha = ALPHA_SCHEDULE[vals[2]];
+
+        params.lambda1 = lambda1,
+        params.lambda2 = lambda2,
+        params.alpha = alpha;
+
+        if (!logistic_regression_trainer_reset_params_ftrl(trainer, alpha, DEFAULT_BETA, lambda1, lambda2)) {
+            log_error("Error resetting params\n");
             logistic_regression_trainer_destroy(trainer);
-            return NULL;
+            exit(EXIT_FAILURE);
+        }
+
+        log_info("Optimizing hyperparameters. Trying lambda1=%.7f, lambda2=%.7f, alpha=%f\n", lambda1, lambda2, alpha);
+
+        bool diverged = false;
+        cost = language_classifier_cv_cost(trainer, filename, cv_filename, minibatch_size, &diverged);
+
+        if (!diverged) {
+            language_classifier_ftrl_param_array_push(all_params, params);
+            double_array_push(costs, cost);
+        } else {
+            log_info("Diverged, cost = %f\n", cost);
+        }
+
+        log_info("Total cost = %f\n", cost);
+        if (cost < best_cost) {
+            log_info("Better than current best parameters: setting lambda1=%.7f, lambda2=%.7f, alpha=%f\n", lambda1, lambda2, alpha);
+            best_cost = cost;
+            best_params.lambda1 = lambda1;
+            best_params.lambda2 = lambda2;
+            best_params.alpha = alpha;
         }
     }
 
-    log_info("Done training\n");
+    size_t num_params = costs->n;
+    if (num_params > 0) {
+        language_classifier_ftrl_params_t *param_values = all_params->a;
+        double *cost_values = costs->a;
 
-    if (!logistic_regression_trainer_finalize(trainer)) {
-        log_error("Error in finalization\n");
-        logistic_regression_trainer_destroy(trainer);
-        return NULL;
+        double std_error = double_array_std(cost_values, num_params) / sqrt((double)num_params);
+
+        double max_cost = best_cost + std_error;
+        log_info("best_cost = %f, std_error = %f, max_cost = %f using the one standard error rule\n", best_cost, std_error, max_cost);
+
+        for (size_t i = 0; i < num_params; i++) {
+            cost = cost_values[i];
+            params = param_values[i];
+
+            log_info("cost = %f, lambda1 = %f, lambda2 = %f, alpha = %f\n", cost, params.lambda1, params.lambda2, params.alpha);
+
+            if (cost < max_cost &&
+                (params.lambda1 > best_params.lambda1 || double_equals(params.lambda1, best_params.lambda1)) &&
+                (params.lambda2 > best_params.lambda2 || double_equals(params.lambda2, best_params.lambda2))
+            ) {
+                if (double_equals(params.lambda1, best_params.lambda1) && double_equals(params.lambda2, best_params.lambda2) && params.alpha > best_params.alpha) {
+                    log_info("cost < max_cost but higher alpha\n");
+                    continue;
+                }
+                best_params = params;
+                log_info("cost (%f) < max_cost and better regularized, setting lambda1=%.7f, lambda2=%.7f, alpha=%f\n", cost, params.lambda1, params.lambda2, params.alpha);
+            }
+        }
    }
+
+    language_classifier_ftrl_param_array_destroy(all_params);
+    double_array_destroy(costs);
+
+    return best_params;
+}
+
+
+static language_classifier_t *trainer_finalize(logistic_regression_trainer_t *trainer, char *test_filename) {
+    if (trainer == NULL) return NULL;
+
+    log_info("Done training\n");
+
     if (test_filename != NULL) {
         double test_accuracy = compute_cv_accuracy(trainer, test_filename);
         log_info("Test accuracy = %f\n", test_accuracy);
@@ -380,12 +576,31 @@ language_classifier_t *language_classifier_train(char *filename, char *subset_fi
     if (classifier == NULL) {
         log_error("Error creating classifier\n");
         logistic_regression_trainer_destroy(trainer);
-         return NULL;
+        return NULL;
     }
 
     // Reassign weights and features to the classifier model
-    classifier->weights = trainer->weights;
-    trainer->weights = NULL;
+    // Extract the final weights from the optimizer
+
+    if (trainer->optimizer_type == LOGISTIC_REGRESSION_OPTIMIZER_SGD) {
+        sgd_trainer_t *sgd_trainer = trainer->optimizer.sgd;
+        if (sgd_trainer->reg_type == REGULARIZATION_L2 || sgd_trainer->reg_type == REGULARIZATION_NONE) {
+            double_matrix_t *weights = logistic_regression_trainer_final_weights(trainer);
+            classifier->weights_type = MATRIX_DENSE;
+            classifier->weights.dense = weights;
+        } else if (sgd_trainer->reg_type == REGULARIZATION_L1) {
+            sparse_matrix_t *sparse_weights = logistic_regression_trainer_final_weights_sparse(trainer);
+            classifier->weights_type = MATRIX_SPARSE;
+            classifier->weights.sparse = sparse_weights;
+            log_info("Weights sparse: %zu rows (m=%u), %zu cols, %zu elements\n", sparse_weights->indptr->n, sparse_weights->m, sparse_weights->n, sparse_weights->data->n);
+        }
+    } else if (trainer->optimizer_type == LOGISTIC_REGRESSION_OPTIMIZER_FTRL) {
+        sparse_matrix_t *sparse_weights = logistic_regression_trainer_final_weights_sparse(trainer);
+        classifier->weights_type = MATRIX_SPARSE;
+        classifier->weights.sparse = sparse_weights;
+        log_info("Weights sparse: %zu rows (m=%u), %zu cols, %zu elements\n", sparse_weights->indptr->n, sparse_weights->m, sparse_weights->n, sparse_weights->data->n);
+    }
+
     classifier->num_features = trainer->num_features;
     classifier->features = trainer->feature_ids;
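/* Why the L1 and FTRL branches above store MATRIX_SPARSE: both drive most
   feature weights to exactly zero, so a compressed sparse row (CSR) matrix
   -- the same indptr/indices/data layout logged above -- is far smaller than
   the dense weight matrix. A hypothetical standalone conversion for
   illustration only (error handling elided); libpostal's
   logistic_regression_trainer_final_weights_sparse() does the equivalent
   internally against its own matrix types: */

#include <stdlib.h>
#include <stddef.h>

typedef struct {
    size_t m, n;      // rows, columns
    size_t *indptr;   // m + 1 row offsets into indices/data
    size_t *indices;  // column index of each stored value
    double *data;     // nonzero values
} csr_matrix_t;

static csr_matrix_t *dense_to_csr(const double *dense, size_t m, size_t n) {
    // First pass: count nonzeros so the arrays can be sized exactly
    size_t nnz = 0;
    for (size_t i = 0; i < m * n; i++) {
        if (dense[i] != 0.0) nnz++;
    }

    csr_matrix_t *csr = malloc(sizeof(csr_matrix_t));
    if (csr == NULL) return NULL;
    csr->m = m;
    csr->n = n;
    csr->indptr = malloc((m + 1) * sizeof(size_t));
    csr->indices = malloc(nnz * sizeof(size_t));
    csr->data = malloc(nnz * sizeof(double));

    // Second pass: record each nonzero's column and value, row by row
    size_t k = 0;
    csr->indptr[0] = 0;
    for (size_t i = 0; i < m; i++) {
        for (size_t j = 0; j < n; j++) {
            double value = dense[i * n + j];
            if (value != 0.0) {
                csr->indices[k] = j;
                csr->data[k] = value;
                k++;
            }
        }
        csr->indptr[i + 1] = k;
    }
    return csr;
}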
@@ -420,7 +635,89 @@ language_classifier_t *language_classifier_train(char *filename, char *subset_fi
 }
 
-#define LANGUAGE_CLASSIFIER_TRAIN_USAGE "Usage: ./language_classifier_train [train|cv] filename [cv_filename] [test_filename] [output_dir]\n"
+language_classifier_t *language_classifier_train_sgd(char *filename, char *subset_filename, bool cross_validation_set, char *cv_filename, char *test_filename, uint32_t num_iterations, size_t minibatch_size, regularization_type_t reg_type) {
+    logistic_regression_trainer_t *trainer = language_classifier_init_sgd_reg(filename, minibatch_size, reg_type);
+
+    language_classifier_sgd_params_t params = language_classifier_parameter_sweep_sgd(trainer, subset_filename, cv_filename, minibatch_size);
+    log_info("Best params: lambda=%f, gamma_0=%f\n", params.lambda, params.gamma_0);
+
+    if (!logistic_regression_trainer_reset_params_sgd(trainer, params.lambda, params.gamma_0)) {
+        logistic_regression_trainer_destroy(trainer);
+        return NULL;
+    }
+
+    /* If there's not a distinct cross-validation set, e.g.
+       when training the production model, then the cross validation
+       file is just a subset of the training data and only used
+       for setting the hyperparameters, so ignore it after we're
+       done with the parameter sweep.
+    */
+    if (!cross_validation_set) {
+        cv_filename = NULL;
+    }
+
+    for (uint32_t epoch = 0; epoch < num_iterations; epoch++) {
+        log_info("Doing epoch %d\n", epoch);
+
+        trainer->epochs = epoch;
+
+        if (!language_classifier_train_epoch(trainer, filename, cv_filename, -1, minibatch_size)) {
+            log_error("Error in epoch\n");
+            logistic_regression_trainer_destroy(trainer);
+            return NULL;
+        }
+    }
+
+    return trainer_finalize(trainer, test_filename);
+}
+
+language_classifier_t *language_classifier_train_ftrl(char *filename, char *subset_filename, bool cross_validation_set, char *cv_filename, char *test_filename, uint32_t num_iterations, size_t minibatch_size) {
+    logistic_regression_trainer_t *trainer = language_classifier_init_ftrl(filename, minibatch_size);
+
+    language_classifier_ftrl_params_t params = language_classifier_parameter_sweep_ftrl(trainer, subset_filename, cv_filename, minibatch_size);
+    log_info("Best params: lambda1=%.7f, lambda2=%.7f, alpha=%f\n", params.lambda1, params.lambda2, params.alpha);
+
+    if (!logistic_regression_trainer_reset_params_ftrl(trainer, params.alpha, DEFAULT_BETA, params.lambda1, params.lambda2)) {
+        logistic_regression_trainer_destroy(trainer);
+        return NULL;
+    }
+
+    /* If there's not a distinct cross-validation set, e.g.
+       when training the production model, then the cross validation
+       file is just a subset of the training data and only used
+       for setting the hyperparameters, so ignore it after we're
+       done with the parameter sweep.
+    */
+    if (!cross_validation_set) {
+        cv_filename = NULL;
+    }
+
+    for (uint32_t epoch = 0; epoch < num_iterations; epoch++) {
+        log_info("Doing epoch %d\n", epoch);
+
+        trainer->epochs = epoch;
+
+        if (!language_classifier_train_epoch(trainer, filename, cv_filename, -1, minibatch_size)) {
+            log_error("Error in epoch\n");
+            logistic_regression_trainer_destroy(trainer);
+            return NULL;
+        }
+    }
+
+    return trainer_finalize(trainer, test_filename);
+}
+
+
+typedef enum {
+    LANGUAGE_CLASSIFIER_TRAIN_POSITIONAL_ARG,
+    LANGUAGE_CLASSIFIER_TRAIN_ARG_ITERATIONS,
+    LANGUAGE_CLASSIFIER_TRAIN_ARG_OPTIMIZER,
+    LANGUAGE_CLASSIFIER_TRAIN_ARG_REGULARIZATION,
+    LANGUAGE_CLASSIFIER_TRAIN_ARG_MINIBATCH_SIZE
+} language_classifier_train_keyword_arg_t;
+
+#define LANGUAGE_CLASSIFIER_TRAIN_USAGE "Usage: ./language_classifier_train [train|cv] filename [cv_filename] [test_filename] [output_dir] [--iterations number --opt (sgd|ftrl) --reg (l1|l2) --minibatch-size number]\n"
 
 int main(int argc, char **argv) {
     if (argc < 3) {
@@ -428,42 +725,127 @@ int main(int argc, char **argv) {
         exit(EXIT_FAILURE);
     }
 
-    char *command = argv[1];
-    bool cross_validation_set = false;
+    int pos_args = 1;
 
-    if (string_equals(command, "cv")) {
-        cross_validation_set = true;
-    } else if (!string_equals(command, "train")) {
-        printf(LANGUAGE_CLASSIFIER_TRAIN_USAGE);
-        exit(EXIT_FAILURE);
-    }
+    language_classifier_train_keyword_arg_t kwarg = LANGUAGE_CLASSIFIER_TRAIN_POSITIONAL_ARG;
+
+    size_t num_epochs = TRAIN_EPOCHS;
+    size_t minibatch_size = LANGUAGE_CLASSIFIER_DEFAULT_BATCH_SIZE;
+    logistic_regression_optimizer_type optim_type = LOGISTIC_REGRESSION_OPTIMIZER_SGD;
+    regularization_type_t reg_type = REGULARIZATION_L2;
+
+    size_t position = 0;
+
+    ssize_t arg_iterations;
+    ssize_t arg_minibatch_size;
+
+    char *command = NULL;
+    char *filename = NULL;
 
-    char *filename = argv[2];
     char *cv_filename = NULL;
     char *test_filename = NULL;
 
-    if (cross_validation_set && argc < 5) {
-        printf(LANGUAGE_CLASSIFIER_TRAIN_USAGE);
-        exit(EXIT_FAILURE);
-    } else if (cross_validation_set) {
-        cv_filename = argv[3];
-        test_filename = argv[4];
-    }
+    bool cross_validation_set = false;
 
     char *output_dir = LIBPOSTAL_LANGUAGE_CLASSIFIER_DIR;
-    int output_dir_arg = cross_validation_set ? 5 : 3;
-    if (argc > output_dir_arg) {
-        output_dir = argv[output_dir_arg];
+    for (int i = pos_args; i < argc; i++) {
+        char *arg = argv[i];
+
+        if (string_equals(arg, "--iterations")) {
+            kwarg = LANGUAGE_CLASSIFIER_TRAIN_ARG_ITERATIONS;
+            continue;
+        }
+
+        if (string_equals(arg, "--opt")) {
+            kwarg = LANGUAGE_CLASSIFIER_TRAIN_ARG_OPTIMIZER;
+            continue;
+        }
+
+        if (string_equals(arg, "--reg")) {
+            kwarg = LANGUAGE_CLASSIFIER_TRAIN_ARG_REGULARIZATION;
+            continue;
+        }
+
+        if (string_equals(arg, "--minibatch-size")) {
+            kwarg = LANGUAGE_CLASSIFIER_TRAIN_ARG_MINIBATCH_SIZE;
+            continue;
+        }
+
+        if (kwarg == LANGUAGE_CLASSIFIER_TRAIN_ARG_ITERATIONS) {
+            if (sscanf(arg, "%zd", &arg_iterations) != 1 || arg_iterations < 0) {
+                log_error("Bad arg for --iterations: %s\n", arg);
+                exit(EXIT_FAILURE);
+            }
+            num_epochs = (size_t)arg_iterations;
+        } else if (kwarg == LANGUAGE_CLASSIFIER_TRAIN_ARG_OPTIMIZER) {
+            if (string_equals(arg, "sgd")) {
+                optim_type = LOGISTIC_REGRESSION_OPTIMIZER_SGD;
+            } else if (string_equals(arg, "ftrl")) {
+                log_info("ftrl\n");
+                optim_type = LOGISTIC_REGRESSION_OPTIMIZER_FTRL;
+            } else {
+                log_error("Bad arg for --opt: %s\n", arg);
+                exit(EXIT_FAILURE);
+            }
+        } else if (kwarg == LANGUAGE_CLASSIFIER_TRAIN_ARG_REGULARIZATION) {
+            if (string_equals(arg, "l2")) {
+                reg_type = REGULARIZATION_L2;
+            } else if (string_equals(arg, "l1")) {
+                reg_type = REGULARIZATION_L1;
+            } else {
+                log_error("Bad arg for --reg: %s\n", arg);
+                exit(EXIT_FAILURE);
+            }
+        } else if (kwarg == LANGUAGE_CLASSIFIER_TRAIN_ARG_MINIBATCH_SIZE) {
+            if (sscanf(arg, "%zd", &arg_minibatch_size) != 1 || arg_minibatch_size < 0) {
+                log_error("Bad arg for --minibatch-size: %s\n", arg);
+                exit(EXIT_FAILURE);
+            }
+            minibatch_size = (size_t)arg_minibatch_size;
+        } else if (position == 0) {
+            command = arg;
+            if (string_equals(command, "cv")) {
+                cross_validation_set = true;
+            } else if (!string_equals(command, "train")) {
+                printf(LANGUAGE_CLASSIFIER_TRAIN_USAGE);
+                exit(EXIT_FAILURE);
+            }
+            position++;
+        } else if (position == 1) {
+            filename = arg;
+            position++;
+        } else if (position == 2 && cross_validation_set) {
+            cv_filename = arg;
+            position++;
+        } else if (position == 2 && !cross_validation_set) {
+            output_dir = arg;
+            position++;
+        } else if (position == 3 && cross_validation_set) {
+            test_filename = arg;
+            position++;
+        } else if (position == 4 && cross_validation_set) {
+            output_dir = arg;
+            position++;
+        }
+        kwarg = LANGUAGE_CLASSIFIER_TRAIN_POSITIONAL_ARG;
     }
 
-    #if !defined(HAVE_SHUF)
+    if ((command == NULL || filename == NULL) || (cross_validation_set && (cv_filename == NULL || test_filename == NULL))) {
+        printf(LANGUAGE_CLASSIFIER_TRAIN_USAGE);
+        exit(EXIT_FAILURE);
+    }
+
+    #if !defined(HAVE_SHUF) && !defined(HAVE_GSHUF)
     log_warn("shuf must be installed to train the language classifier effectively. If this is a production machine, please install shuf. No shuffling will be performed.\n");
     #endif
 
     if (!address_dictionary_module_setup(NULL)) {
         log_error("Could not load address dictionaries\n");
         exit(EXIT_FAILURE);
+    } else if (!transliteration_module_setup(NULL)) {
+        log_error("Could not load transliteration module\n");
+        exit(EXIT_FAILURE);
     }
 
     char_array *temp_file = char_array_new();
@@ -503,7 +885,13 @@ int main(int argc, char **argv) {
 
     char_array_destroy(head_command);
 
-    language_classifier_t *language_classifier = language_classifier_train(filename, temp_filename, cross_validation_set, cv_filename, test_filename, TRAIN_EPOCHS);
+    language_classifier_t *language_classifier = NULL;
+
+    if (optim_type == LOGISTIC_REGRESSION_OPTIMIZER_SGD) {
+        language_classifier = language_classifier_train_sgd(filename, temp_filename, cross_validation_set, cv_filename, test_filename, num_epochs, minibatch_size, reg_type);
+    } else if (optim_type == LOGISTIC_REGRESSION_OPTIMIZER_FTRL) {
+        language_classifier = language_classifier_train_ftrl(filename, temp_filename, cross_validation_set, cv_filename, test_filename, num_epochs, minibatch_size);
+    }
 
     remove(temp_filename);
     char_array_destroy(temp_file);
@@ -512,7 +900,7 @@ int main(int argc, char **argv) {
     }
 
     log_info("Done with classifier\n");
-    char_array *path = char_array_new_size(strlen(output_dir) + PATH_SEPARATOR_LEN + strlen(LANGUAGE_CLASSIFIER_COUNTRY_FILENAME));
+    char_array *path = char_array_new_size(strlen(output_dir) + PATH_SEPARATOR_LEN + strlen(LANGUAGE_CLASSIFIER_FILENAME));
 
     char *classifier_path;
     if (language_classifier != NULL) {
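/* Example invocations of the retooled CLI above (file paths are placeholders):
 *
 *   ./language_classifier_train train languages.tsv --opt ftrl --iterations 5
 *   ./language_classifier_train cv languages.tsv cv.tsv test.tsv --opt sgd --reg l1
 *
 * Both parameter sweeps walk their hyperparameter grids with the
 * cartesian_product_iterator API; a minimal standalone usage sketch follows.
 * The _new/_start/_done/_next calls all appear in this diff; the closing
 * _destroy call is an assumption about the same API, since the diff itself
 * never frees the iterator. */

#include <stdio.h>
#include <stdint.h>
#include "cartesian_product.h"

static void print_grid(void) {
    double lambdas[] = {1e-4, 1e-2};
    double gammas[] = {0.1, 1.0};

    // 2 dimensions, of sizes 2 and 2: visits all 4 (lambda, gamma_0) pairs
    cartesian_product_iterator_t *iter = cartesian_product_iterator_new(2, 2, 2);
    if (iter == NULL) return;

    for (uint32_t *vals = cartesian_product_iterator_start(iter);
         !cartesian_product_iterator_done(iter);
         vals = cartesian_product_iterator_next(iter)) {
        printf("lambda=%g, gamma_0=%g\n", lambdas[vals[0]], gammas[vals[1]]);
    }

    cartesian_product_iterator_destroy(iter);  // assumed counterpart to _new
}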