[parser] Adding a polymorphic model type (as far as C allows polymorphism) so the parser can use either the greedy averaged perceptron or a CRF. During training, saving, and loading, a parser trained with a CRF uses a different filename, so loading remains backward-compatible with models previously trained in parser-data. Making the necessary modifications to address_parser.c, address_parser_train.c, and address_parser_test.c. Also adding an option to address_parser_test to print individual errors in addition to the confusion matrix.

This commit is contained in:
Al
2017-03-10 19:19:40 -05:00
parent 1bd4689c5f
commit 8deb1716cb
4 changed files with 281 additions and 61 deletions

View File

@@ -4,6 +4,7 @@
#include "address_parser_io.h"
#include "address_dictionary.h"
#include "averaged_perceptron_trainer.h"
#include "crf_trainer_averaged_perceptron.h"
#include "collections.h"
#include "constants.h"
#include "file_utils.h"
@@ -30,6 +31,7 @@ KHASH_MAP_INIT_STR(phrase_types, address_parser_types_t)
#define DEFAULT_ITERATIONS 5
#define DEFAULT_MIN_UPDATES 5
#define DEFAULT_MODEL_TYPE ADDRESS_PARSER_TYPE_CRF
#define MIN_VOCAB_COUNT 5
#define MIN_PHRASE_COUNT 1
@@ -691,8 +693,6 @@ address_parser_t *address_parser_init(char *filename) {
log_info("Calculating phrase types\n");
parser->model = NULL;
size_t num_classes = kh_size(class_counts);
log_info("num_classes = %zu\n", num_classes);
parser->num_classes = num_classes;
@@ -977,9 +977,77 @@ exit_hashes_allocated:
return parser;
}
/*
 * Train on a single example, dispatching on self->model_type.
 *
 * trainer is an untyped handle: it is treated as an
 * averaged_perceptron_trainer_t * for the greedy averaged perceptron and as a
 * crf_averaged_perceptron_trainer_t * for the CRF. The greedy call receives
 * both context->prev_tag_features and context->prev2_tag_features; the CRF
 * call receives only context->prev_tag_features.
 *
 * Returns whatever the selected train_example call returns, or false (after
 * logging an error) when the model type is not one of the two known types.
 */
static inline bool address_parser_train_example(address_parser_t *self, void *trainer, address_parser_context_t *context, address_parser_data_set_t *data_set) {
    switch (self->model_type) {
        case ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON: {
            averaged_perceptron_trainer_t *ap_trainer = (averaged_perceptron_trainer_t *)trainer;
            return averaged_perceptron_trainer_train_example(ap_trainer, self, context, context->features, context->prev_tag_features, context->prev2_tag_features, &address_parser_features, data_set->tokenized_str, data_set->labels);
        }
        case ADDRESS_PARSER_TYPE_CRF: {
            crf_averaged_perceptron_trainer_t *crf_trainer = (crf_averaged_perceptron_trainer_t *)trainer;
            return crf_averaged_perceptron_trainer_train_example(crf_trainer, self, context, context->features, context->prev_tag_features, &address_parser_features, data_set->tokenized_str, data_set->labels);
        }
        default:
            log_error("Parser model is of unknown type\n");
            return false;
    }
}
/*
 * Destroy a trainer handle using the destroy function that matches
 * self->model_type. trainer is cast to the concrete trainer type for that
 * model. For any other model type nothing is called and nothing is logged.
 */
static inline void address_parser_trainer_destroy(address_parser_t *self, void *trainer) {
    switch (self->model_type) {
        case ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON:
            averaged_perceptron_trainer_destroy((averaged_perceptron_trainer_t *)trainer);
            break;
        case ADDRESS_PARSER_TYPE_CRF:
            crf_averaged_perceptron_trainer_destroy((crf_averaged_perceptron_trainer_t *)trainer);
            break;
        default:
            /* Unknown model type: no destroy function is known for the handle. */
            break;
    }
}
bool address_parser_train_epoch(address_parser_t *self, averaged_perceptron_trainer_t *trainer, char *filename) {
/*
 * Finalize the trainer and store the resulting model on the parser.
 *
 * For the greedy averaged perceptron the result is stored in self->model.ap;
 * for the CRF it is stored in self->model.crf. Only the member matching
 * self->model_type is written.
 *
 * Returns true when the stored model pointer is non-NULL. Returns false when
 * finalization produced NULL, or (after logging an error) when the model type
 * is unknown.
 *
 * NOTE(review): whether the *_finalize calls release the trainer is not
 * visible from this function -- confirm before destroying it afterwards.
 */
static inline bool address_parser_finalize_model(address_parser_t *self, void *trainer) {
    bool have_model = false;

    if (self->model_type == ADDRESS_PARSER_TYPE_CRF) {
        crf_averaged_perceptron_trainer_t *crf_trainer = (crf_averaged_perceptron_trainer_t *)trainer;
        self->model.crf = crf_averaged_perceptron_trainer_finalize(crf_trainer);
        have_model = (self->model.crf != NULL);
    } else if (self->model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON) {
        averaged_perceptron_trainer_t *ap_trainer = (averaged_perceptron_trainer_t *)trainer;
        self->model.ap = averaged_perceptron_trainer_finalize(ap_trainer);
        have_model = (self->model.ap != NULL);
    } else {
        log_error("Parser model is of unknown type\n");
    }

    return have_model;
}
/*
 * Read the iterations field of the trainer behind the untyped handle.
 *
 * The handle is interpreted according to self->model_type. Returns the
 * trainer's iterations value, or 0 (after logging an error) when the model
 * type is unknown.
 */
static inline uint32_t address_parser_train_num_iterations(address_parser_t *self, void *trainer) {
    switch (self->model_type) {
        case ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON:
            return ((averaged_perceptron_trainer_t *)trainer)->iterations;
        case ADDRESS_PARSER_TYPE_CRF:
            return ((crf_averaged_perceptron_trainer_t *)trainer)->iterations;
        default:
            log_error("Parser model is of unknown type\n");
            return 0;
    }
}
/*
 * Write the iterations field of the trainer behind the untyped handle.
 *
 * The handle is interpreted according to self->model_type. For an unknown
 * model type nothing is written and an error is logged.
 */
static inline void address_parser_train_set_iterations(address_parser_t *self, void *trainer, uint32_t iterations) {
    switch (self->model_type) {
        case ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON:
            ((averaged_perceptron_trainer_t *)trainer)->iterations = iterations;
            break;
        case ADDRESS_PARSER_TYPE_CRF:
            ((crf_averaged_perceptron_trainer_t *)trainer)->iterations = iterations;
            break;
        default:
            log_error("Parser model is of unknown type\n");
            break;
    }
}
/*
 * Read the running error count of the trainer behind the untyped handle.
 *
 * The handle is interpreted according to self->model_type. Returns the
 * trainer's num_errors value, or 0 (after logging an error) when the model
 * type is unknown.
 *
 * The epoch loop logs the difference between successive return values as the
 * number of "errors", and the code this helper replaced read
 * trainer->num_errors directly, so the error counter is returned here rather
 * than num_updates.
 *
 * NOTE(review): assumes crf_averaged_perceptron_trainer_t exposes a num_errors
 * field like averaged_perceptron_trainer_t does -- confirm against
 * crf_trainer_averaged_perceptron.h.
 */
static inline uint64_t address_parser_train_num_errors(address_parser_t *self, void *trainer) {
    if (self->model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON) {
        averaged_perceptron_trainer_t *ap_trainer = (averaged_perceptron_trainer_t *)trainer;
        return ap_trainer->num_errors;
    } else if (self->model_type == ADDRESS_PARSER_TYPE_CRF) {
        crf_averaged_perceptron_trainer_t *crf_trainer = (crf_averaged_perceptron_trainer_t *)trainer;
        return crf_trainer->num_errors;
    } else {
        log_error("Parser model is of unknown type\n");
    }
    return 0;
}
bool address_parser_train_epoch(address_parser_t *self, void *trainer, char *filename) {
if (filename == NULL) {
log_error("Filename was NULL\n");
return false;
@@ -994,7 +1062,9 @@ bool address_parser_train_epoch(address_parser_t *self, averaged_perceptron_trai
address_parser_context_t *context = address_parser_context_new();
size_t examples = 0;
size_t errors = trainer->num_errors;
uint64_t errors = address_parser_train_num_errors(self, trainer);
uint32_t iteration = address_parser_train_num_iterations(self, trainer);
bool logged = false;
@@ -1007,7 +1077,7 @@ bool address_parser_train_epoch(address_parser_t *self, averaged_perceptron_trai
address_parser_context_fill(context, self, data_set->tokenized_str, language, country);
bool example_success = averaged_perceptron_trainer_train_example(trainer, self, context, context->features, context->prev_tag_features, context->prev2_tag_features, &address_parser_features, data_set->tokenized_str, data_set->labels);
bool example_success = address_parser_train_example(self, trainer, context, data_set);
if (!example_success) {
log_error("Error training example\n");
@@ -1024,10 +1094,11 @@ bool address_parser_train_epoch(address_parser_t *self, averaged_perceptron_trai
examples++;
if (examples % 1000 == 0 && examples > 0) {
log_info("Iter %d: Did %zu examples with %llu errors\n", trainer->iterations, examples, trainer->num_errors - errors);
errors = trainer->num_errors;
}
uint64_t prev_errors = errors;
errors = address_parser_train_num_errors(self, trainer);
log_info("Iter %d: Did %zu examples with %llu errors\n", iteration, examples, errors - prev_errors);
}
}
exit_epoch_training_started:
@@ -1037,20 +1108,29 @@ exit_epoch_training_started:
return true;
}
bool address_parser_train(address_parser_t *self, char *filename, uint32_t num_iterations, size_t min_updates) {
averaged_perceptron_trainer_t *trainer = averaged_perceptron_trainer_new(min_updates);
bool address_parser_train(address_parser_t *self, char *filename, address_parser_model_type_t model_type, uint32_t num_iterations, size_t min_updates) {
self->model_type = model_type;
void *trainer;
if (model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON) {
averaged_perceptron_trainer_t *ap_trainer = averaged_perceptron_trainer_new(min_updates);
trainer = (void *)ap_trainer;
} else if (model_type == ADDRESS_PARSER_TYPE_CRF) {
crf_averaged_perceptron_trainer_t *crf_trainer = crf_averaged_perceptron_trainer_new(self->num_classes, min_updates);
trainer = (void *)crf_trainer;
}
for (uint32_t iter = 0; iter < num_iterations; iter++) {
log_info("Doing epoch %d\n", iter);
trainer->iterations = iter;
address_parser_train_set_iterations(self, trainer, iter);
#if defined(HAVE_SHUF) || defined(HAVE_GSHUF)
log_info("Shuffling\n");
if (!shuffle_file_chunked_size(filename, DEFAULT_SHUFFLE_CHUNK_SIZE)) {
log_error("Error in shuffle\n");
averaged_perceptron_trainer_destroy(trainer);
address_parser_trainer_destroy(self, trainer);
return false;
}
@@ -1059,14 +1139,17 @@ bool address_parser_train(address_parser_t *self, char *filename, uint32_t num_i
if (!address_parser_train_epoch(self, trainer, filename)) {
log_error("Error in epoch\n");
averaged_perceptron_trainer_destroy(trainer);
address_parser_trainer_destroy(self, trainer);
return false;
}
}
log_debug("Done with training, averaging weights\n");
self->model = averaged_perceptron_trainer_finalize(trainer);
if (!address_parser_finalize_model(self, trainer)) {
log_error("model was NULL\n");
return false;
}
return true;
}
@@ -1074,10 +1157,11 @@ bool address_parser_train(address_parser_t *self, char *filename, uint32_t num_i
/*
 * Command-line parsing state for main: main stores one of these in kwarg
 * when it sees a keyword flag, and the branch taken for a later argv entry
 * is chosen by comparing kwarg against these values.
 */
typedef enum {
ADDRESS_PARSER_TRAIN_POSITIONAL_ARG,
ADDRESS_PARSER_TRAIN_ARG_ITERATIONS, /* value is parsed into arg_iterations */
ADDRESS_PARSER_TRAIN_ARG_MIN_UPDATES
ADDRESS_PARSER_TRAIN_ARG_MIN_UPDATES,
ADDRESS_PARSER_TRAIN_ARG_MODEL_TYPE /* set by --model; value must be "crf" or "greedyap" */
} address_parser_train_keyword_arg_t;
#define USAGE "Usage: ./address_parser_train filename output_dir [--iterations number]\n"
#define USAGE "Usage: ./address_parser_train filename output_dir [--iterations number --min-updates number --model (crf|greedyap)]\n"
int main(int argc, char **argv) {
if (argc < 3) {
@@ -1103,6 +1187,8 @@ int main(int argc, char **argv) {
char *filename = NULL;
char *output_dir = NULL;
address_parser_model_type_t model_type = DEFAULT_MODEL_TYPE;
for (int i = pos_args; i < argc; i++) {
char *arg = argv[i];
@@ -1116,6 +1202,11 @@ int main(int argc, char **argv) {
continue;
}
if (string_equals(arg, "--model")) {
kwarg = ADDRESS_PARSER_TRAIN_ARG_MODEL_TYPE;
continue;
}
if (kwarg == ADDRESS_PARSER_TRAIN_ARG_ITERATIONS) {
if (sscanf(arg, "%zd", &arg_iterations) != 1 || arg_iterations < 0) {
log_error("Bad arg for --iterations: %s\n", arg);
@@ -1129,6 +1220,15 @@ int main(int argc, char **argv) {
}
min_updates = arg_min_updates;
log_info("min_updates = %llu\n", min_updates);
} else if (kwarg == ADDRESS_PARSER_TRAIN_ARG_MODEL_TYPE) {
if (string_equals(arg, "crf")) {
model_type = ADDRESS_PARSER_TYPE_CRF;
} else if (string_equals(arg, "greedyap")) {
model_type = ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON;
} else {
log_error("Bad arg for --model, valid values are [crf, greedyap]\n");
exit(EXIT_FAILURE);
}
} else if (position == 0) {
filename = arg;
position++;
@@ -1169,7 +1269,7 @@ int main(int argc, char **argv) {
log_info("Finished initialization\n");
if (!address_parser_train(parser, filename, num_iterations, min_updates)) {
if (!address_parser_train(parser, filename, model_type, num_iterations, min_updates)) {
log_error("Error in training\n");
exit(EXIT_FAILURE);
}