[language_classification] Add options to language_classifier_train for using SGD with L1 or L2 regularization, or FTRL-Proximal, which uses both.
1. Creates a sparse weights matrix for L1-regularized SGD and for FTRL.
2. Uses the one standard-error rule during cross-validation.
Parameters within one standard error of the lowest-cost solution
are preferred if they are better regularized (see the sketch below).
3. Pulls the weights matrix for only the features that occurred
in a given batch. In the case of FTRL, the weights need to be
recomputed on each batch, so the sparsity helps here.
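
For illustration, here is a minimal, self-contained sketch of the one standard-error rule as it is applied in the parameter sweeps below. The names candidate_t and pick_one_std_error are hypothetical and not part of this commit; the real code operates on language_classifier_sgd_params_t / language_classifier_ftrl_params_t arrays, and lambda stands for whichever regularization strength is being swept.

    #include <math.h>
    #include <stddef.h>

    /* Hypothetical sweep candidate: its cross-validation cost and its
       regularization strength. */
    typedef struct {
        double cost;
        double lambda;
    } candidate_t;

    /* One standard-error rule: among candidates whose CV cost is within one
       standard error of the minimum, prefer the most regularized (largest
       lambda). Returns the index of the chosen candidate. */
    static size_t pick_one_std_error(const candidate_t *cands, size_t n) {
        if (n == 0) return 0;

        size_t best = 0;
        double mean = 0.0;
        for (size_t i = 0; i < n; i++) {
            mean += cands[i].cost;
            if (cands[i].cost < cands[best].cost) best = i;
        }
        mean /= (double)n;

        double var = 0.0;
        for (size_t i = 0; i < n; i++) {
            double d = cands[i].cost - mean;
            var += d * d;
        }
        /* std of the CV costs over sqrt(n), mirroring the
           double_array_std(costs) / sqrt(num_params) computation below */
        double std_error = sqrt(var / (double)n) / sqrt((double)n);
        double max_cost = cands[best].cost + std_error;

        for (size_t i = 0; i < n; i++) {
            if (cands[i].cost < max_cost && cands[i].lambda > cands[best].lambda) {
                best = i;
            }
        }
        return best;
    }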
@@ -2,9 +2,12 @@
#include <stdlib.h>
#include <stdint.h>
#include <stdbool.h>
#include <float.h>

#include "log/log.h"
#include "address_dictionary.h"
#include "cartesian_product.h"
#include "collections.h"
#include "language_classifier.h"
#include "language_classifier_io.h"
#include "logistic_regression.h"
@@ -12,31 +15,50 @@
#include "shuffle.h"
#include "sparse_matrix.h"
#include "sparse_matrix_utils.h"
#include "stochastic_gradient_descent.h"
#include "transliterate.h"

#define LANGUAGE_CLASSIFIER_FEATURE_COUNT_THRESHOLD 5.0
#define LANGUAGE_CLASSIFIER_FEATURE_COUNT_THRESHOLD 3.0
#define LANGUAGE_CLASSIFIER_LABEL_COUNT_THRESHOLD 100

#define LOG_BATCH_INTERVAL 10
#define COMPUTE_COST_INTERVAL 100
#define COMPUTE_CV_INTERVAL 1000

#define LANGUAGE_CLASSIFIER_HYPERPARAMETER_BATCHES 100
#define LANGUAGE_CLASSIFIER_HYPERPARAMETER_BATCHES 50

// Hyperparameters for stochastic gradient descent

static double GAMMA_SCHEDULE[] = {0.01, 0.1, 0.2, 0.5, 1.0, 2.0, 5.0, 10.0};
static const size_t GAMMA_SCHEDULE_SIZE = sizeof(GAMMA_SCHEDULE) / sizeof(double);

#define DEFAULT_GAMMA_0 10.0

static double LAMBDA_SCHEDULE[] = {0.0, 1e-5, 1e-4, 0.001, 0.01, 0.1, \
                                   0.2, 0.5, 1.0};
static const size_t LAMBDA_SCHEDULE_SIZE = sizeof(LAMBDA_SCHEDULE) / sizeof(double);
#define REGULARIZATION_SCHEDULE {0.0, 1e-7, 1e-6, 1e-5, 1e-4, 0.001, 0.01, 0.1, \
                                 0.2, 0.5, 1.0, 2.0, 5.0, 10.0}

#define DEFAULT_LAMBDA 0.0
static double L2_SCHEDULE[] = REGULARIZATION_SCHEDULE;
static const size_t L2_SCHEDULE_SIZE = sizeof(L2_SCHEDULE) / sizeof(double);

static double L1_SCHEDULE[] = REGULARIZATION_SCHEDULE;
static const size_t L1_SCHEDULE_SIZE = sizeof(L1_SCHEDULE) / sizeof(double);

#define DEFAULT_L2 1e-6
#define DEFAULT_L1 1e-4
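
/* For reference: with gamma_0 and lambda above, the regularized SGD step for
   a minibatch B presumably takes the usual form

       w <- w - gamma_t * ((1/|B|) * sum over (x,y) in B of grad L(w; x, y) + lambda * r'(w))

   where r'(w) = w for L2 and r'(w) = sign(w) for L1 (a subgradient), and the
   learning rate gamma_t decays from gamma_0. The exact decay schedule is
   defined by the SGD trainer (see stochastic_gradient_descent.h above). */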

// Hyperparameters for FTRL-Proximal

static double ALPHA_SCHEDULE[] = {0.01, 0.1, 0.2, 0.5, 1.0, 2.0, 5.0, 10.0};
static const size_t ALPHA_SCHEDULE_SIZE = sizeof(ALPHA_SCHEDULE) / sizeof(double);
static double DEFAULT_BETA = 1.0;

#define DEFAULT_ALPHA 10.0
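
/* For reference, a sketch of the FTRL-Proximal per-coordinate update these
   hyperparameters presumably feed (McMahan et al., 2013), where g_i is the
   gradient for coordinate i and z_i, n_i are per-coordinate accumulators:

       sigma_i = (sqrt(n_i + g_i^2) - sqrt(n_i)) / alpha
       z_i    += g_i - sigma_i * w_i
       n_i    += g_i^2

       w_i = 0                                        if |z_i| <= lambda1
       w_i = -(z_i - sign(z_i) * lambda1) /
             ((beta + sqrt(n_i)) / alpha + lambda2)   otherwise

   alpha scales the per-coordinate learning rate, beta smooths it early on,
   and lambda1/lambda2 are the L1/L2 strengths swept below. The actual
   implementation lives in the FTRL trainer, not in this file. */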

#define TRAIN_EPOCHS 10

#define HYPERPARAMETER_EPOCHS 30
#define HYPERPARAMETER_EPOCHS 5

logistic_regression_trainer_t *language_classifier_init_thresholds(char *filename, double feature_count_threshold, uint32_t label_count_threshold) {
logistic_regression_trainer_t *language_classifier_init_params(char *filename, double feature_count_threshold, uint32_t label_count_threshold, size_t minibatch_size, logistic_regression_optimizer_type optim_type, regularization_type_t reg_type) {
if (filename == NULL) {
log_error("Filename was NULL\n");
return NULL;
@@ -51,14 +73,14 @@ logistic_regression_trainer_t *language_classifier_init_thresholds(char *filenam
size_t num_batches = 0;

// Count features and labels
while ((minibatch = language_classifier_data_set_get_minibatch(data_set, NULL)) != NULL) {
while ((minibatch = language_classifier_data_set_get_minibatch_with_size(data_set, NULL, minibatch_size)) != NULL) {
if (!count_labels_minibatch(label_counts, minibatch->labels)) {
log_error("Counting minibatch labeles failed\n");
exit(EXIT_FAILURE);
}

if (num_batches % LOG_BATCH_INTERVAL == 0) {
log_info("Counting labels, did %zu batches\n", num_batches);
if (num_batches % LOG_BATCH_INTERVAL == 0 && num_batches > 0) {
log_info("Counting labels, did %zu examples\n", num_batches * minibatch_size);
}

num_batches++;
@@ -91,8 +113,8 @@ logistic_regression_trainer_t *language_classifier_init_thresholds(char *filenam
exit(EXIT_FAILURE);
}

if (num_batches % LOG_BATCH_INTERVAL == 0) {
log_info("Counting features, did %zu batches\n", num_batches);
if (num_batches % LOG_BATCH_INTERVAL == 0 && num_batches > 0) {
log_info("Counting features, did %zu examples\n", num_batches * minibatch_size);
}

num_batches++;
@@ -119,11 +141,34 @@ logistic_regression_trainer_t *language_classifier_init_thresholds(char *filenam
kh_destroy(str_double, feature_counts);


return logistic_regression_trainer_init(feature_ids, label_ids, DEFAULT_GAMMA_0, DEFAULT_LAMBDA);
logistic_regression_trainer_t *trainer = NULL;

if (optim_type == LOGISTIC_REGRESSION_OPTIMIZER_SGD) {
bool fit_intercept = true;
double default_lambda = 0.0;
if (reg_type == REGULARIZATION_L2) {
default_lambda = DEFAULT_L2;
} else if (reg_type == REGULARIZATION_L1) {
default_lambda = DEFAULT_L1;
}
trainer = logistic_regression_trainer_init_sgd(feature_ids, label_ids, fit_intercept, reg_type, default_lambda, DEFAULT_GAMMA_0);
} else if (optim_type == LOGISTIC_REGRESSION_OPTIMIZER_FTRL) {
trainer = logistic_regression_trainer_init_ftrl(feature_ids, label_ids, DEFAULT_ALPHA, DEFAULT_BETA, DEFAULT_L1, DEFAULT_L2);
}

return trainer;
}

logistic_regression_trainer_t *language_classifier_init(char *filename) {
return language_classifier_init_thresholds(filename, LANGUAGE_CLASSIFIER_FEATURE_COUNT_THRESHOLD, LANGUAGE_CLASSIFIER_LABEL_COUNT_THRESHOLD);
logistic_regression_trainer_t *language_classifier_init_optim_reg(char *filename, size_t minibatch_size, logistic_regression_optimizer_type optim_type, regularization_type_t reg_type) {
return language_classifier_init_params(filename, LANGUAGE_CLASSIFIER_FEATURE_COUNT_THRESHOLD, LANGUAGE_CLASSIFIER_LABEL_COUNT_THRESHOLD, minibatch_size, optim_type, reg_type);
}

logistic_regression_trainer_t *language_classifier_init_sgd_reg(char *filename, size_t minibatch_size, regularization_type_t reg_type) {
return language_classifier_init_params(filename, LANGUAGE_CLASSIFIER_FEATURE_COUNT_THRESHOLD, LANGUAGE_CLASSIFIER_LABEL_COUNT_THRESHOLD, minibatch_size, LOGISTIC_REGRESSION_OPTIMIZER_SGD, reg_type);
}

logistic_regression_trainer_t *language_classifier_init_ftrl(char *filename, size_t minibatch_size) {
return language_classifier_init_params(filename, LANGUAGE_CLASSIFIER_FEATURE_COUNT_THRESHOLD, LANGUAGE_CLASSIFIER_LABEL_COUNT_THRESHOLD, minibatch_size, LOGISTIC_REGRESSION_OPTIMIZER_FTRL, REGULARIZATION_NONE);
}

double compute_cv_accuracy(logistic_regression_trainer_t *trainer, char *filename) {
@@ -140,9 +185,20 @@ double compute_cv_accuracy(logistic_regression_trainer_t *trainer, char *filenam
sparse_matrix_t *x = feature_matrix(trainer->feature_ids, minibatch->features);
uint32_array *y = label_vector(trainer->label_ids, minibatch->labels);

double_matrix_resize(p_y, x->m, trainer->num_labels);
if (!double_matrix_resize_aligned(p_y, x->m, trainer->num_labels, 16)) {
log_error("resize p_y failed\n");
exit(EXIT_FAILURE);
}
double_matrix_zero(p_y);

if (!logistic_regression_model_expectation(trainer->weights, x, p_y)) {
if (!sparse_matrix_add_unique_columns(x, trainer->unique_columns, trainer->batch_columns)) {
log_error("Error adding unique columns\n");
exit(EXIT_FAILURE);
}

double_matrix_t *theta = logistic_regression_trainer_get_regularized_weights(trainer);

if (!logistic_regression_model_expectation(theta, x, p_y)) {
log_error("Predict cv batch failed\n");
exit(EXIT_FAILURE);
}
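
/* Sketch of the idea behind sparse_matrix_add_unique_columns and
   logistic_regression_trainer_get_regularized_weights (assumed semantics,
   inferred from their use here): collect the set of feature columns that
   actually occur in the minibatch, then apply any lazily-deferred
   regularization updates only to those rows of the weights matrix before
   predicting. With sparse features, each batch touches a small fraction of
   the columns, so this is much cheaper than regularizing the full matrix;
   for FTRL the per-batch weights are recomputed from the per-coordinate
   accumulators the same way. */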
@@ -187,9 +243,12 @@ double compute_total_cost(logistic_regression_trainer_t *trainer, char *filename
double total_cost = 0.0;
size_t num_batches = 0;

// Need to regularize the weights
double_matrix_t *theta = logistic_regression_trainer_get_regularized_weights(trainer);

while ((minibatch = language_classifier_data_set_get_minibatch(data_set, trainer->label_ids)) != NULL) {

double batch_cost = logistic_regression_trainer_batch_cost(trainer, minibatch->features, minibatch->labels);
double batch_cost = logistic_regression_trainer_minibatch_cost(trainer, minibatch->features, minibatch->labels);
total_cost += batch_cost;

language_classifier_minibatch_destroy(minibatch);
@@ -207,16 +266,16 @@ double compute_total_cost(logistic_regression_trainer_t *trainer, char *filename
}


bool language_classifier_train_epoch(logistic_regression_trainer_t *trainer, char *filename, char *cv_filename, ssize_t train_batches) {
bool language_classifier_train_epoch(logistic_regression_trainer_t *trainer, char *filename, char *cv_filename, ssize_t train_batches, size_t minibatch_size) {
if (filename == NULL) {
log_error("Filename was NULL\n");
return false;
}

#if defined(HAVE_SHUF)
#if defined(HAVE_SHUF) || defined(HAVE_GSHUF)
log_info("Shuffling\n");

if (!shuffle_file(filename)) {
if (!shuffle_file_chunked_size(filename, DEFAULT_SHUFFLE_CHUNK_SIZE)) {
log_error("Error in shuffle\n");
logistic_regression_trainer_destroy(trainer);
return false;
@@ -237,24 +296,25 @@ bool language_classifier_train_epoch(logistic_regression_trainer_t *trainer, cha
double train_cost = 0.0;
double cv_accuracy = 0.0;

while ((minibatch = language_classifier_data_set_get_minibatch(data_set, trainer->label_ids)) != NULL) {
bool compute_cost = num_batches % COMPUTE_COST_INTERVAL == 0 && num_batches > 0;
while ((minibatch = language_classifier_data_set_get_minibatch_with_size(data_set, trainer->label_ids, minibatch_size)) != NULL) {
bool compute_cost = num_batches % COMPUTE_COST_INTERVAL == 0;
bool compute_cv = num_batches % COMPUTE_CV_INTERVAL == 0 && num_batches > 0 && cv_filename != NULL;

if (num_batches % LOG_BATCH_INTERVAL == 0 && num_batches > 0) {
log_info("Epoch %u, doing batch %zu\n", trainer->epochs, num_batches);
log_info("Epoch %u, doing %zu examples\n", trainer->epochs, num_batches * minibatch_size);
}

if (compute_cost) {
train_cost = logistic_regression_trainer_batch_cost(trainer, minibatch->features, minibatch->labels);
train_cost = logistic_regression_trainer_minibatch_cost(trainer, minibatch->features, minibatch->labels);
log_info("cost = %f\n", train_cost);
}

if (!logistic_regression_trainer_train_batch(trainer, minibatch->features, minibatch->labels)){
if (!logistic_regression_trainer_train_minibatch(trainer, minibatch->features, minibatch->labels)){
log_error("Train batch failed\n");
exit(EXIT_FAILURE);
}

if (compute_cost && cv_filename != NULL) {
if (compute_cv) {
cv_accuracy = compute_cv_accuracy(trainer, cv_filename);
log_info("cv accuracy=%f\n", cv_accuracy);
}
@@ -262,8 +322,8 @@ bool language_classifier_train_epoch(logistic_regression_trainer_t *trainer, cha
num_batches++;

if (train_batches > 0 && num_batches == (size_t)train_batches) {
log_info("Epoch %u, trained %zu batches\n", trainer->epochs, num_batches);
train_cost = logistic_regression_trainer_batch_cost(trainer, minibatch->features, minibatch->labels);
log_info("Epoch %u, trained %zu examples\n", trainer->epochs, num_batches * minibatch_size);
train_cost = logistic_regression_trainer_minibatch_cost(trainer, minibatch->features, minibatch->labels);
log_info("cost = %f\n", train_cost);
break;
}
@@ -277,36 +337,7 @@ bool language_classifier_train_epoch(logistic_regression_trainer_t *trainer, cha
return true;
}

typedef struct language_classifier_params {
double lambda;
double gamma_0;
} language_classifier_params_t;

language_classifier_params_t language_classifier_parameter_sweep(char *filename, char *cv_filename) {
// Select features using the full data set
logistic_regression_trainer_t *trainer = language_classifier_init(filename);

double best_cost = 0.0;

language_classifier_params_t best_params = (language_classifier_params_t){0.0, 0.0};

for (size_t i = 0; i < LAMBDA_SCHEDULE_SIZE; i++) {
for (size_t j = 0; j < GAMMA_SCHEDULE_SIZE; j++) {
trainer->lambda = LAMBDA_SCHEDULE[i];
trainer->gamma_0 = GAMMA_SCHEDULE[j];

log_info("Optimizing hyperparameters. Trying lambda=%f, gamma_0=%f\n", trainer->lambda, trainer->gamma_0);

for (int k = 0; k < HYPERPARAMETER_EPOCHS; k++) {
trainer->epochs = k;

if (!language_classifier_train_epoch(trainer, filename, NULL, LANGUAGE_CLASSIFIER_HYPERPARAMETER_BATCHES)) {
log_error("Error in epoch\n");
logistic_regression_trainer_destroy(trainer);
exit(EXIT_FAILURE);
}
}

static double language_classifier_cv_cost(logistic_regression_trainer_t *trainer, char *filename, char *cv_filename, size_t minibatch_size, bool *diverged) {
ssize_t cost_batches;
char *cost_file;

@@ -318,59 +349,224 @@ language_classifier_params_t language_classifier_parameter_sweep(char *filename,
cost_batches = -1;
}

double cost = compute_total_cost(trainer, cost_file, cost_batches);
double initial_cost = compute_total_cost(trainer, cost_file, cost_batches);

for (size_t k = 0; k < HYPERPARAMETER_EPOCHS; k++) {
trainer->epochs = k;

if (!language_classifier_train_epoch(trainer, filename, NULL, LANGUAGE_CLASSIFIER_HYPERPARAMETER_BATCHES, minibatch_size)) {
log_error("Error in epoch\n");
logistic_regression_trainer_destroy(trainer);
exit(EXIT_FAILURE);
}
}

double final_cost = compute_total_cost(trainer, cost_file, cost_batches);

*diverged = final_cost > initial_cost;
log_info("final_cost = %f, initial_cost = %f\n", final_cost, initial_cost);

return final_cost;
}

typedef struct language_classifier_sgd_params {
double lambda;
double gamma_0;
} language_classifier_sgd_params_t;

typedef struct language_classifier_ftrl_params {
double alpha;
double lambda1;
double lambda2;
} language_classifier_ftrl_params_t;

VECTOR_INIT(language_classifier_sgd_param_array, language_classifier_sgd_params_t)
VECTOR_INIT(language_classifier_ftrl_param_array, language_classifier_ftrl_params_t)

/* Uses the one standard-error rule (http://www.stat.cmu.edu/~ryantibs/datamining/lectures/19-val2.pdf)
A solution that's better regularized is preferred if it's within one standard error
of the solution with the lowest cross-validation error.
*/

language_classifier_sgd_params_t language_classifier_parameter_sweep_sgd(logistic_regression_trainer_t *trainer, char *filename, char *cv_filename, size_t minibatch_size) {
double best_cost = DBL_MAX;

double default_lambda = 0.0;
size_t lambda_schedule_size = 0;
double *lambda_schedule = NULL;

sgd_trainer_t *sgd = trainer->optimizer.sgd;

if (sgd->reg_type == REGULARIZATION_L2) {
default_lambda = DEFAULT_L2;
lambda_schedule_size = L2_SCHEDULE_SIZE;
lambda_schedule = L2_SCHEDULE;
} else if (sgd->reg_type == REGULARIZATION_L1) {
lambda_schedule_size = L1_SCHEDULE_SIZE;
lambda_schedule = L1_SCHEDULE;
default_lambda = DEFAULT_L1;
}

double_array *costs = double_array_new();
language_classifier_sgd_param_array *all_params = language_classifier_sgd_param_array_new();

language_classifier_sgd_params_t best_params = (language_classifier_sgd_params_t){default_lambda, DEFAULT_GAMMA_0};
double cost;
language_classifier_sgd_params_t params;

cartesian_product_iterator_t *iter = cartesian_product_iterator_new(2, lambda_schedule_size, GAMMA_SCHEDULE_SIZE);
for (uint32_t *vals = cartesian_product_iterator_start(iter); !cartesian_product_iterator_done(iter); vals = cartesian_product_iterator_next(iter)) {
double lambda = lambda_schedule[vals[0]];
double gamma_0 = GAMMA_SCHEDULE[vals[1]];

params.lambda = lambda,
params.gamma_0 = gamma_0;

if (!logistic_regression_trainer_reset_params_sgd(trainer, lambda, gamma_0)) {
log_error("Error resetting params\n");
logistic_regression_trainer_destroy(trainer);
exit(EXIT_FAILURE);
}

log_info("Optimizing hyperparameters. Trying lambda=%.7f, gamma_0=%f\n", lambda, gamma_0);

bool diverged = false;
cost = language_classifier_cv_cost(trainer, filename, cv_filename, minibatch_size, &diverged);

if (!diverged) {
language_classifier_sgd_param_array_push(all_params, params);
double_array_push(costs, cost);
} else {
log_info("Diverged, cost = %f\n", cost);
}

log_info("Total cost = %f\n", cost);
if ((i == 0 && j == 0) || cost < best_cost) {
log_info("Better than current best parameters: lambda=%f, gamma_0=%f\n", trainer->lambda, trainer->gamma_0);
if (cost < best_cost) {
log_info("Better than current best parameters: setting lambda=%.7f, gamma_0=%f\n", lambda, gamma_0);
best_cost = cost;
best_params.lambda = trainer->lambda;
best_params.gamma_0 = trainer->gamma_0;
best_params.lambda = lambda;
best_params.gamma_0 = gamma_0;
}
}

size_t num_params = costs->n;
if (num_params > 0) {
language_classifier_sgd_params_t *param_values = all_params->a;
double *cost_values = costs->a;

double std_error = double_array_std(cost_values, num_params) / sqrt((double)num_params);

double max_cost = best_cost + std_error;
log_info("max_cost = %f using the one standard error rule\n", max_cost);

for (size_t i = 0; i < num_params; i++) {
cost = cost_values[i];
params = param_values[i];

if (cost < max_cost && params.lambda > best_params.lambda) {
best_params = params;
log_info("cost (%f) < max_cost and better regularized, setting lambda=%.7f, gamma_0=%f\n", cost, params.lambda, params.gamma_0);
}
}
}

language_classifier_sgd_param_array_destroy(all_params);
double_array_destroy(costs);

return best_params;
}


language_classifier_t *language_classifier_train(char *filename, char *subset_filename, bool cross_validation_set, char *cv_filename, char *test_filename, uint32_t num_iterations) {
language_classifier_params_t params = language_classifier_parameter_sweep(subset_filename, cv_filename);
log_info("Best params: lambda=%f, gamma_0=%f\n", params.lambda, params.gamma_0);
language_classifier_ftrl_params_t language_classifier_parameter_sweep_ftrl(logistic_regression_trainer_t *trainer, char *filename, char *cv_filename, size_t minibatch_size) {
double best_cost = DBL_MAX;

logistic_regression_trainer_t *trainer = language_classifier_init(filename);
trainer->lambda = params.lambda;
trainer->gamma_0 = params.gamma_0;
language_classifier_ftrl_params_t best_params = (language_classifier_ftrl_params_t){DEFAULT_ALPHA, DEFAULT_L1, DEFAULT_L2};

/* If there's not a distinct cross-validation set, e.g.
when training the production model, then the cross validation
file is just a subset of the training data and only used
for setting the hyperparameters, so ignore it after we're
done with the parameter sweep.
*/
if (!cross_validation_set) {
cv_filename = NULL;
}
double_array *costs = double_array_new();
language_classifier_ftrl_param_array *all_params = language_classifier_ftrl_param_array_new();
language_classifier_ftrl_params_t params;
double cost;

for (uint32_t epoch = 0; epoch < num_iterations; epoch++) {
log_info("Doing epoch %d\n", epoch);
cartesian_product_iterator_t *iter = cartesian_product_iterator_new(3, L1_SCHEDULE_SIZE, L2_SCHEDULE_SIZE, ALPHA_SCHEDULE_SIZE);
for (uint32_t *vals = cartesian_product_iterator_start(iter); !cartesian_product_iterator_done(iter); vals = cartesian_product_iterator_next(iter)) {
double lambda1 = L1_SCHEDULE[vals[0]];
double lambda2 = L2_SCHEDULE[vals[1]];
double alpha = ALPHA_SCHEDULE[vals[2]];

trainer->epochs = epoch;
params.lambda1 = lambda1,
params.lambda2 = lambda2,
params.alpha = alpha;

if (!language_classifier_train_epoch(trainer, filename, cv_filename, -1)) {
log_error("Error in epoch\n");
if (!logistic_regression_trainer_reset_params_ftrl(trainer, alpha, DEFAULT_BETA, lambda1, lambda2)) {
log_error("Error resetting params\n");
logistic_regression_trainer_destroy(trainer);
return NULL;
exit(EXIT_FAILURE);
}

log_info("Optimizing hyperparameters. Trying lambda1=%.7f, lambda2=%.7f, alpha=%f\n", lambda1, lambda2, alpha);

bool diverged = false;
cost = language_classifier_cv_cost(trainer, filename, cv_filename, minibatch_size, &diverged);

if (!diverged) {
language_classifier_ftrl_param_array_push(all_params, params);
double_array_push(costs, cost);
} else {
log_info("Diverged, cost = %f\n", cost);
}

log_info("Total cost = %f\n", cost);
if (cost < best_cost) {
log_info("Better than current best parameters: setting lambda1=%.7f, lambda2=%.7f, alpha=%f\n", lambda1, lambda2, alpha);
best_cost = cost;
best_params.lambda1 = lambda1;
best_params.lambda2 = lambda2;
best_params.alpha = alpha;
}
}

size_t num_params = costs->n;
if (num_params > 0) {
language_classifier_ftrl_params_t *param_values = all_params->a;
double *cost_values = costs->a;

double std_error = double_array_std(cost_values, num_params) / sqrt((double)num_params);

double max_cost = best_cost + std_error;
log_info("best_cost = %f, std_error = %f, max_cost = %f using the one standard error rule\n", best_cost, std_error, max_cost);

for (size_t i = 0; i < num_params; i++) {
cost = cost_values[i];
params = param_values[i];

log_info("cost = %f, lambda1 = %f, lambda2 = %f, alpha = %f\n", cost, params.lambda1, params.lambda2, params.alpha);

if (cost < max_cost &&
(params.lambda1 > best_params.lambda1 || double_equals(params.lambda1, best_params.lambda1)) &&
(params.lambda2 > best_params.lambda2 || double_equals(params.lambda2, best_params.lambda2))
) {
if (double_equals(params.lambda1, best_params.lambda1) && double_equals(params.lambda2, best_params.lambda2) && params.alpha > best_params.alpha) {
log_info("cost < max_cost but higher alpha\n");
continue;
}
best_params = params;
log_info("cost (%f) < max_cost and better regularized, setting lambda1=%.7f, lambda2=%.7f, alpha=%f\n", cost, params.lambda1, params.lambda2, params.alpha);
}
}
}

language_classifier_ftrl_param_array_destroy(all_params);
double_array_destroy(costs);

return best_params;
}


static language_classifier_t *trainer_finalize(logistic_regression_trainer_t *trainer, char *test_filename) {
if (trainer == NULL) return NULL;

log_info("Done training\n");

if (!logistic_regression_trainer_finalize(trainer)) {
log_error("Error in finalization\n");
logistic_regression_trainer_destroy(trainer);
return NULL;
}

if (test_filename != NULL) {
double test_accuracy = compute_cv_accuracy(trainer, test_filename);
log_info("Test accuracy = %f\n", test_accuracy);
@@ -384,8 +580,27 @@ language_classifier_t *language_classifier_train(char *filename, char *subset_fi
}

// Reassign weights and features to the classifier model
classifier->weights = trainer->weights;
trainer->weights = NULL;
// final_weights

if (trainer->optimizer_type == LOGISTIC_REGRESSION_OPTIMIZER_SGD) {
sgd_trainer_t *sgd_trainer = trainer->optimizer.sgd;
if (sgd_trainer->reg_type == REGULARIZATION_L2 || sgd_trainer->reg_type == REGULARIZATION_NONE) {
double_matrix_t *weights = logistic_regression_trainer_final_weights(trainer);
classifier->weights_type = MATRIX_DENSE;
classifier->weights.dense = weights;
} else if (sgd_trainer->reg_type == REGULARIZATION_L1) {
sparse_matrix_t *sparse_weights = logistic_regression_trainer_final_weights_sparse(trainer);
classifier->weights_type = MATRIX_SPARSE;
classifier->weights.sparse = sparse_weights;
log_info("Weights sparse: %zu rows (m=%u), %zu cols, %zu elements\n", sparse_weights->indptr->n, sparse_weights->m, sparse_weights->n, sparse_weights->data->n);
}
} else if (trainer->optimizer_type == LOGISTIC_REGRESSION_OPTIMIZER_FTRL) {
sparse_matrix_t *sparse_weights = logistic_regression_trainer_final_weights_sparse(trainer);
classifier->weights_type = MATRIX_SPARSE;
classifier->weights.sparse = sparse_weights;
log_info("Weights sparse: %zu rows (m=%u), %zu cols, %zu elements\n", sparse_weights->indptr->n, sparse_weights->m, sparse_weights->n, sparse_weights->data->n);
}


classifier->num_features = trainer->num_features;
classifier->features = trainer->feature_ids;
@@ -420,7 +635,89 @@ language_classifier_t *language_classifier_train(char *filename, char *subset_fi
}


#define LANGUAGE_CLASSIFIER_TRAIN_USAGE "Usage: ./language_classifier_train [train|cv] filename [cv_filename] [test_filename] [output_dir]\n"
language_classifier_t *language_classifier_train_sgd(char *filename, char *subset_filename, bool cross_validation_set, char *cv_filename, char *test_filename, uint32_t num_iterations, size_t minibatch_size, regularization_type_t reg_type) {
logistic_regression_trainer_t *trainer = language_classifier_init_sgd_reg(filename, minibatch_size, reg_type);

language_classifier_sgd_params_t params = language_classifier_parameter_sweep_sgd(trainer, subset_filename, cv_filename, minibatch_size);
log_info("Best params: lambda=%f, gamma_0=%f\n", params.lambda, params.gamma_0);

if (!logistic_regression_trainer_reset_params_sgd(trainer, params.lambda, params.gamma_0)) {
logistic_regression_trainer_destroy(trainer);
return NULL;
}

/* If there's not a distinct cross-validation set, e.g.
when training the production model, then the cross validation
file is just a subset of the training data and only used
for setting the hyperparameters, so ignore it after we're
done with the parameter sweep.
*/
if (!cross_validation_set) {
cv_filename = NULL;
}

for (uint32_t epoch = 0; epoch < num_iterations; epoch++) {
log_info("Doing epoch %d\n", epoch);

trainer->epochs = epoch;

if (!language_classifier_train_epoch(trainer, filename, cv_filename, -1, minibatch_size)) {
log_error("Error in epoch\n");
logistic_regression_trainer_destroy(trainer);
return NULL;
}
}

return trainer_finalize(trainer, test_filename);
}

language_classifier_t *language_classifier_train_ftrl(char *filename, char *subset_filename, bool cross_validation_set, char *cv_filename, char *test_filename, uint32_t num_iterations, size_t minibatch_size) {
logistic_regression_trainer_t *trainer = language_classifier_init_ftrl(filename, minibatch_size);

language_classifier_ftrl_params_t params = language_classifier_parameter_sweep_ftrl(trainer, subset_filename, cv_filename, minibatch_size);
log_info("Best params: lambda1=%.7f, lambda2=%.7f, alpha=%f\n", params.lambda1, params.lambda2, params.alpha);

if (!logistic_regression_trainer_reset_params_ftrl(trainer, params.alpha, DEFAULT_BETA, params.lambda1, params.lambda2)) {
logistic_regression_trainer_destroy(trainer);
return NULL;
}

/* If there's not a distinct cross-validation set, e.g.
when training the production model, then the cross validation
file is just a subset of the training data and only used
for setting the hyperparameters, so ignore it after we're
done with the parameter sweep.
*/
if (!cross_validation_set) {
cv_filename = NULL;
}

for (uint32_t epoch = 0; epoch < num_iterations; epoch++) {
log_info("Doing epoch %d\n", epoch);

trainer->epochs = epoch;

if (!language_classifier_train_epoch(trainer, filename, cv_filename, -1, minibatch_size)) {
log_error("Error in epoch\n");
logistic_regression_trainer_destroy(trainer);
return NULL;
}
}

return trainer_finalize(trainer, test_filename);
}



typedef enum {
LANGUAGE_CLASSIFIER_TRAIN_POSITIONAL_ARG,
LANGUAGE_CLASSIFIER_TRAIN_ARG_ITERATIONS,
LANGUAGE_CLASSIFIER_TRAIN_ARG_OPTIMIZER,
LANGUAGE_CLASSIFIER_TRAIN_ARG_REGULARIZATION,
LANGUAGE_CLASSIFIER_TRAIN_ARG_MINIBATCH_SIZE
} language_classifier_train_keyword_arg_t;

#define LANGUAGE_CLASSIFIER_TRAIN_USAGE "Usage: ./language_classifier_train [train|cv] filename [cv_filename] [test_filename] [output_dir] [--iterations number --opt (sgd|ftrl) --reg (l1|l2) --minibatch-size number]\n"
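
/* Per the argument parsing below, a hypothetical invocation with a held-out
   cross-validation and test set (file names invented for illustration):

       ./language_classifier_train cv train.tsv cv.tsv test.tsv ./model_dir \
           --iterations 5 --opt ftrl --minibatch-size 64

   and the minimal production form: ./language_classifier_train train train.tsv */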

int main(int argc, char **argv) {
if (argc < 3) {
@@ -428,42 +725,127 @@ int main(int argc, char **argv) {
exit(EXIT_FAILURE);
}

char *command = argv[1];
int pos_args = 1;

language_classifier_train_keyword_arg_t kwarg = LANGUAGE_CLASSIFIER_TRAIN_POSITIONAL_ARG;

size_t num_epochs = TRAIN_EPOCHS;
size_t minibatch_size = LANGUAGE_CLASSIFIER_DEFAULT_BATCH_SIZE;
logistic_regression_optimizer_type optim_type = LOGISTIC_REGRESSION_OPTIMIZER_SGD;
regularization_type_t reg_type = REGULARIZATION_L2;

size_t position = 0;

ssize_t arg_iterations;
ssize_t arg_minibatch_size;

char *command = NULL;
char *filename = NULL;

char *cv_filename = NULL;
char *test_filename = NULL;

bool cross_validation_set = false;

char *output_dir = LIBPOSTAL_LANGUAGE_CLASSIFIER_DIR;

for (int i = pos_args; i < argc; i++) {
char *arg = argv[i];

if (string_equals(arg, "--iterations")) {
kwarg = LANGUAGE_CLASSIFIER_TRAIN_ARG_ITERATIONS;
continue;
}

if (string_equals(arg, "--opt")) {
kwarg = LANGUAGE_CLASSIFIER_TRAIN_ARG_OPTIMIZER;
continue;
}

if (string_equals(arg, "--reg")) {
kwarg = LANGUAGE_CLASSIFIER_TRAIN_ARG_REGULARIZATION;
continue;
}

if (string_equals(arg, "--minibatch-size")) {
kwarg = LANGUAGE_CLASSIFIER_TRAIN_ARG_MINIBATCH_SIZE;
continue;
}

if (kwarg == LANGUAGE_CLASSIFIER_TRAIN_ARG_ITERATIONS) {
if (sscanf(arg, "%zd", &arg_iterations) != 1 || arg_iterations < 0) {
log_error("Bad arg for --iterations: %s\n", arg);
exit(EXIT_FAILURE);
}
num_epochs = (size_t)arg_iterations;
} else if (kwarg == LANGUAGE_CLASSIFIER_TRAIN_ARG_OPTIMIZER) {
if (string_equals(arg, "sgd")) {
optim_type = LOGISTIC_REGRESSION_OPTIMIZER_SGD;
} else if (string_equals(arg, "ftrl")) {
log_info("ftrl\n");
optim_type = LOGISTIC_REGRESSION_OPTIMIZER_FTRL;
} else {
log_error("Bad arg for --opt: %s\n", arg);
exit(EXIT_FAILURE);
}
} else if (kwarg == LANGUAGE_CLASSIFIER_TRAIN_ARG_REGULARIZATION) {
if (string_equals(arg, "l2")) {
reg_type = REGULARIZATION_L2;
} else if (string_equals(arg, "l1")) {
reg_type = REGULARIZATION_L1;
} else {
log_error("Bad arg for --reg: %s\n", arg);
exit(EXIT_FAILURE);
}
} else if (kwarg == LANGUAGE_CLASSIFIER_TRAIN_ARG_MINIBATCH_SIZE) {
if (sscanf(arg, "%zd", &arg_minibatch_size) != 1 || arg_minibatch_size < 0) {
log_error("Bad arg for --batch: %s\n", arg);
exit(EXIT_FAILURE);
}
minibatch_size = (size_t)arg_minibatch_size;
} else if (position == 0) {
command = arg;
if (string_equals(command, "cv")) {
cross_validation_set = true;
} else if (!string_equals(command, "train")) {
printf(LANGUAGE_CLASSIFIER_TRAIN_USAGE);
exit(EXIT_FAILURE);
}
position++;
} else if (position == 1) {
filename = arg;
position++;
} else if (position == 2 && cross_validation_set) {
cv_filename = arg;
position++;
} else if (position == 2 && !cross_validation_set) {
output_dir = arg;
position++;
} else if (position == 3 && cross_validation_set) {
test_filename = arg;
position++;
} else if (position == 4 && cross_validation_set) {
output_dir = arg;
position++;
}
kwarg = LANGUAGE_CLASSIFIER_TRAIN_POSITIONAL_ARG;
}

char *filename = argv[2];
char *cv_filename = NULL;
char *test_filename = NULL;

if (cross_validation_set && argc < 5) {
if ((command == NULL || filename == NULL) || (cross_validation_set && (cv_filename == NULL || test_filename == NULL))) {
printf(LANGUAGE_CLASSIFIER_TRAIN_USAGE);
exit(EXIT_FAILURE);
} else if (cross_validation_set) {
cv_filename = argv[3];
test_filename = argv[4];
}

char *output_dir = LIBPOSTAL_LANGUAGE_CLASSIFIER_DIR;
int output_dir_arg = cross_validation_set ? 5 : 3;

if (argc > output_dir_arg) {
output_dir = argv[output_dir_arg];
}

#if !defined(HAVE_SHUF)
#if !defined(HAVE_SHUF) && !defined(HAVE_GSHUF)
log_warn("shuf must be installed to train address parser effectively. If this is a production machine, please install shuf. No shuffling will be performed.\n");
#endif

if (!address_dictionary_module_setup(NULL)) {
log_error("Could not load address dictionaries\n");
exit(EXIT_FAILURE);
} else if (!transliteration_module_setup(NULL)) {
log_error("Could not load transliteration module\n");
exit(EXIT_FAILURE);
}

char_array *temp_file = char_array_new();
@@ -503,7 +885,13 @@ int main(int argc, char **argv) {

char_array_destroy(head_command);

language_classifier_t *language_classifier = language_classifier_train(filename, temp_filename, cross_validation_set, cv_filename, test_filename, TRAIN_EPOCHS);
language_classifier_t *language_classifier = NULL;

if (optim_type == LOGISTIC_REGRESSION_OPTIMIZER_SGD) {
language_classifier = language_classifier_train_sgd(filename, temp_filename, cross_validation_set, cv_filename, test_filename, num_epochs, minibatch_size, reg_type);
} else if (optim_type == LOGISTIC_REGRESSION_OPTIMIZER_FTRL) {
language_classifier = language_classifier_train_ftrl(filename, temp_filename, cross_validation_set, cv_filename, test_filename, num_epochs, minibatch_size);
}

remove(temp_filename);
char_array_destroy(temp_file);
@@ -512,7 +900,7 @@ int main(int argc, char **argv) {
}

log_info("Done with classifier\n");
char_array *path = char_array_new_size(strlen(output_dir) + PATH_SEPARATOR_LEN + strlen(LANGUAGE_CLASSIFIER_COUNTRY_FILENAME));
char_array *path = char_array_new_size(strlen(output_dir) + PATH_SEPARATOR_LEN + strlen(LANGUAGE_CLASSIFIER_FILENAME));

char *classifier_path;
if (language_classifier != NULL) {