diff --git a/src/address_parser.c b/src/address_parser.c index a9199dd9..3306f653 100644 --- a/src/address_parser.c +++ b/src/address_parser.c @@ -7,6 +7,7 @@ #include "log/log.h" #define ADDRESS_PARSER_MODEL_FILENAME "address_parser.dat" +#define ADDRESS_PARSER_MODEL_FILENAME_CRF "address_parser_crf.dat" #define ADDRESS_PARSER_VOCAB_FILENAME "address_parser_vocab.trie" #define ADDRESS_PARSER_PHRASE_FILENAME "address_parser_phrases.dat" #define ADDRESS_PARSER_POSTAL_CODES_FILENAME "address_parser_postal_codes.dat" @@ -48,14 +49,32 @@ address_parser_t *get_address_parser(void) { bool address_parser_save(address_parser_t *self, char *output_dir) { if (self == NULL || output_dir == NULL) return false; + char *model_filename = NULL; + if (self->model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON) { + model_filename = ADDRESS_PARSER_MODEL_FILENAME; + } else if (self->model_type == ADDRESS_PARSER_TYPE_CRF) { + model_filename = ADDRESS_PARSER_MODEL_FILENAME_CRF; + } else { + return false; + } + char_array *path = char_array_new_size(strlen(output_dir)); - char_array_add_joined(path, PATH_SEPARATOR, true, 2, output_dir, ADDRESS_PARSER_MODEL_FILENAME); + char_array_add_joined(path, PATH_SEPARATOR, true, 2, output_dir, model_filename); char *model_path = char_array_get_string(path); - if (!averaged_perceptron_save(self->model, model_path)) { - char_array_destroy(path); - return false; + if (self->model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON) { + if (!averaged_perceptron_save(self->model.ap, model_path)) { + log_info("Error in averaged_perceptron_save\n"); + char_array_destroy(path); + return false; + } + } else if (self->model_type == ADDRESS_PARSER_TYPE_CRF) { + if (!crf_save(self->model.crf, model_path)) { + log_info("Error in crf_save\n"); + char_array_destroy(path); + return false; + } } char_array_clear(path); @@ -148,15 +167,47 @@ bool address_parser_load(char *dir) { char_array_add_joined(path, PATH_SEPARATOR, true, 2, dir, 
ADDRESS_PARSER_MODEL_FILENAME); char *model_path = char_array_get_string(path); - averaged_perceptron_t *model = averaged_perceptron_load(model_path); - - if (model == NULL) { - char_array_destroy(path); - return false; + if (file_exists(model_path)) { + averaged_perceptron_t *ap_model = averaged_perceptron_load(model_path); + if (ap_model != NULL) { + parser = address_parser_new(); + parser->model_type = ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON; + parser->model.ap = ap_model; + } else { + char_array_destroy(path); /* destroy the char_array; model_path is only the char * obtained from it */ + log_error("Averaged perceptron model could not be loaded\n"); + return false; + } + } else { + model_path = NULL; } - parser = address_parser_new(); - parser->model = model; + if (model_path == NULL) { /* perceptron model file absent: try the CRF model file instead */ + char_array_clear(path); + char_array_add_joined(path, PATH_SEPARATOR, true, 2, dir, ADDRESS_PARSER_MODEL_FILENAME_CRF); + model_path = char_array_get_string(path); + + if (file_exists(model_path)) { + crf_t *crf_model = crf_load(model_path); + if (crf_model != NULL) { + parser = address_parser_new(); + parser->model_type = ADDRESS_PARSER_TYPE_CRF; + parser->model.crf = crf_model; + } else { + char_array_destroy(path); /* destroy the char_array; model_path is only the char * obtained from it */ + log_error("CRF model could not be loaded\n"); + return false; + } + } else { + model_path = NULL; /* assignment, mirroring the perceptron branch (was a no-op comparison) */ + } + } + + if (parser == NULL) { + char_array_destroy(path); + log_error("Could not find parser model file of known type\n"); + return false; + } char_array_clear(path); @@ -276,8 +327,10 @@ exit_address_parser_created: void address_parser_destroy(address_parser_t *self) { if (self == NULL) return; - if (self->model != NULL) { - averaged_perceptron_destroy(self->model); + if (self->model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON && self->model.ap != NULL) { + averaged_perceptron_destroy(self->model.ap); + } else if (self->model_type == ADDRESS_PARSER_TYPE_CRF && self->model.crf != NULL) { + crf_destroy(self->model.crf); } if (self->vocab != NULL) { @@ -313,7 +366,7 @@ inline void 
address_parser_normalize_token(cstring_array *array, char *str, toke normalize_token(array, str, token, ADDRESS_PARSER_NORMALIZE_TOKEN_OPTIONS); } -inline void address_parser_normalize_phrase_token(cstring_array *array, char *str, token_t token) { +static inline void address_parser_normalize_phrase_token(cstring_array *array, char *str, token_t token) { normalize_token(array, str, token, ADDRESS_PARSER_NORMALIZE_ADMIN_TOKEN_OPTIONS); } @@ -814,7 +867,7 @@ typedef struct address_parser_phrase { phrase_t phrase; } address_parser_phrase_t; -inline bool is_plain_word_phrase_type(address_parser_phrase_type_t type) { +static inline bool is_plain_word_phrase_type(address_parser_phrase_type_t type) { return type == ADDRESS_PARSER_NULL_PHRASE || type == ADDRESS_PARSER_SUFFIX_PHRASE || type == ADDRESS_PARSER_PREFIX_PHRASE; } @@ -1003,8 +1056,8 @@ address_parser_features This is a feature function similar to those found in MEMM and CRF models. -Follows the signature of an ap_feature_function so it can be called -as a function pointer by the averaged perceptron model. +Follows the signature of a tagger_feature_function so it can be called +as a function pointer by the averaged perceptron or CRF model. 
Parameters: @@ -1399,7 +1452,7 @@ bool address_parser_features(void *self, void *ctx, tokenized_string_t *tokenize } } - if (idx == 0) { + if (idx == 0 && !is_unknown_word) { feature_array_add(features, 2, "first word", word); //feature_array_add(features, 3, "first word+next word", word, next_word); } @@ -1413,11 +1466,17 @@ bool address_parser_features(void *self, void *ctx, tokenized_string_t *tokenize if (last_index == idx - 1) { // Previous tag and current word feature_array_add(prev_tag_features, 2, "word", word); - feature_array_add(prev_tag_features, 1, "trans"); // Previous two tags and current word - feature_array_add(prev2_tag_features, 2, "word", word); - feature_array_add(prev2_tag_features, 1, "trans"); + if (parser->model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON) { + // In the CRF this is accounted for by the transition weights + // so only need it for the averaged perceptron + feature_array_add(prev_tag_features, 1, "trans"); + + // Averaged perceptron uses two tags of history, CRF uses one + feature_array_add(prev2_tag_features, 2, "word", word); + feature_array_add(prev2_tag_features, 1, "trans"); + } } if (last_index >= 0) { @@ -1583,6 +1642,17 @@ bool address_parser_features(void *self, void *ctx, tokenized_string_t *tokenize } +bool address_parser_predict(address_parser_t *self, address_parser_context_t *context, cstring_array *token_labels, tagger_feature_function feature_function, tokenized_string_t *tokenized_str) { + if (self->model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON) { + return averaged_perceptron_tagger_predict(self->model.ap, self, context, context->features, context->prev_tag_features, context->prev2_tag_features, token_labels, feature_function, tokenized_str, self->options.print_features); + } else if (self->model_type == ADDRESS_PARSER_TYPE_CRF) { + return crf_tagger_predict(self->model.crf, self, context, context->features, context->prev_tag_features, token_labels, feature_function, tokenized_str, 
self->options.print_features); + } else { + log_error("Parser has unknown model type\n"); + } + return false; +} + address_parser_response_t *address_parser_response_new(void) { address_parser_response_t *response = malloc(sizeof(address_parser_response_t)); return response; @@ -1603,8 +1673,6 @@ address_parser_response_t *address_parser_parse(char *address, char *language, c normalized = address; } - averaged_perceptron_t *model = parser->model; - token_array *tokens = tokenize(normalized); tokenized_string_t *tokenized_str = tokenized_string_new_from_str_size(normalized, strlen(normalized), tokens->n); @@ -1706,7 +1774,9 @@ address_parser_response_t *address_parser_parse(char *address, char *language, c char *prev_label = NULL; - if (averaged_perceptron_tagger_predict(model, parser, context, context->features, context->prev_tag_features, context->prev2_tag_features, token_labels, &address_parser_features, tokenized_str, parser->options.print_features)) { + bool prediction_success = address_parser_predict(parser, context, token_labels, &address_parser_features, tokenized_str); + + if (prediction_success) { response = address_parser_response_new(); size_t num_strings = cstring_array_num_strings(tokenized_str->strings); @@ -1738,6 +1808,8 @@ address_parser_response_t *address_parser_parse(char *address, char *language, c response->components = cstring_array_to_strings(components); response->labels = cstring_array_to_strings(labels); + } else { + log_error("Error in prediction\n"); } token_array_destroy(tokens); diff --git a/src/address_parser.h b/src/address_parser.h index dc19a249..ec9a9f72 100644 --- a/src/address_parser.h +++ b/src/address_parser.h @@ -46,11 +46,13 @@ with the general error-driven averaged perceptron. 
#include #include -#include "averaged_perceptron.h" -#include "averaged_perceptron_tagger.h" #include "libpostal.h" #include "libpostal_config.h" + +#include "averaged_perceptron.h" +#include "averaged_perceptron_tagger.h" #include "collections.h" +#include "crf.h" #include "normalize.h" #include "string_utils.h" @@ -178,6 +180,11 @@ typedef union postal_code_context_value { #define POSTAL_CODE_CONTEXT(pc, ad) ((postal_code_context_value_t){.postcode = (pc), .admin = (ad) }) +typedef enum address_parser_model_type { + ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON, + ADDRESS_PARSER_TYPE_CRF +} address_parser_model_type_t; + typedef struct parser_options { uint64_t rare_word_threshold; bool print_features; @@ -187,7 +194,11 @@ typedef struct parser_options { typedef struct address_parser { parser_options_t options; size_t num_classes; - averaged_perceptron_t *model; + address_parser_model_type_t model_type; + union { + averaged_perceptron_t *ap; + crf_t *crf; + } model; trie_t *vocab; trie_t *phrases; address_parser_types_array *phrase_types; @@ -208,6 +219,8 @@ void address_parser_destroy(address_parser_t *self); char *address_parser_normalize_string(char *str); void address_parser_normalize_token(cstring_array *array, char *str, token_t token); +bool address_parser_predict(address_parser_t *self, address_parser_context_t *context, cstring_array *token_labels, tagger_feature_function feature_function, tokenized_string_t *tokenized_str); + address_parser_context_t *address_parser_context_new(void); void address_parser_context_destroy(address_parser_context_t *self); diff --git a/src/address_parser_test.c b/src/address_parser_test.c index a51e2bf9..966e60fc 100644 --- a/src/address_parser_test.c +++ b/src/address_parser_test.c @@ -1,7 +1,6 @@ #include "address_parser.h" #include "address_parser_io.h" #include "address_dictionary.h" -#include "averaged_perceptron_trainer.h" #include "collections.h" #include "constants.h" #include "file_utils.h" @@ -10,8 +9,6 @@ 
#include "log/log.h" -//#define ADDRESS_PARSER_TEST_PRINT_ERRORS - typedef struct address_parser_test_results { size_t num_errors; size_t num_predictions; @@ -21,19 +18,44 @@ typedef struct address_parser_test_results { } address_parser_test_results_t; -uint32_t get_class_index(address_parser_t *parser, char *name) { +static uint32_t address_parser_num_classes(address_parser_t *parser) { /* class count of whichever model the parser holds; 0 for an unknown model type */ + if (parser->model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON) { + averaged_perceptron_t *ap = parser->model.ap; + return ap->num_classes; + } else if (parser->model_type == ADDRESS_PARSER_TYPE_CRF) { + return parser->model.crf->num_classes; + } + return 0; +} + +static cstring_array *address_parser_class_strings(address_parser_t *parser) { /* class-name array of whichever model the parser holds; NULL for an unknown model type */ + if (parser->model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON) { + cstring_array *classes = parser->model.ap->classes; + return classes; + } else if (parser->model_type == ADDRESS_PARSER_TYPE_CRF) { + return parser->model.crf->classes; + } + return NULL; +} + +static uint32_t address_parser_get_class_index(address_parser_t *parser, char *name) { uint32_t i; char *str; - cstring_array_foreach(parser->model->classes, i, str, { - if (strcmp(name, str) == 0) { - return i; - } - }) + cstring_array *classes = address_parser_class_strings(parser); + uint32_t num_classes = address_parser_num_classes(parser); + if (classes != NULL) { + cstring_array_foreach(classes, i, str, { + if (strcmp(name, str) == 0) { + return i; + } + }) + } - return parser->model->num_classes; + return num_classes; } + #define EMPTY_ADDRESS_PARSER_TEST_RESULT (address_parser_test_results_t){0, 0, 0, 0, NULL} bool address_parser_test(address_parser_t *parser, char *filename, address_parser_test_results_t *result, bool print_errors) { @@ -42,7 +64,7 @@ bool address_parser_test(address_parser_t *parser, char *filename, address_parse return NULL; } - uint32_t num_classes = parser->model->num_classes; + uint32_t num_classes = 
address_parser_num_classes(parser); result->confusion = calloc(num_classes * num_classes, sizeof(uint32_t)); @@ -80,7 +102,9 @@ bool address_parser_test(address_parser_t *parser, char *filename, address_parse size_t starting_errors = result->num_errors; - if (averaged_perceptron_tagger_predict(parser->model, parser, context, context->features, context->prev_tag_features, context->prev2_tag_features, token_labels, &address_parser_features, data_set->tokenized_str, false)) { + bool prediction_success = address_parser_predict(parser, context, token_labels, &address_parser_features, data_set->tokenized_str); + + if (prediction_success) { uint32_t i; char *predicted; cstring_array_foreach(token_labels, i, predicted, { @@ -89,8 +113,8 @@ bool address_parser_test(address_parser_t *parser, char *filename, address_parse if (strcmp(predicted, truth) != 0) { result->num_errors++; - uint32_t predicted_index = get_class_index(parser, predicted); - uint32_t truth_index = get_class_index(parser, truth); + uint32_t predicted_index = address_parser_get_class_index(parser, predicted); + uint32_t truth_index = address_parser_get_class_index(parser, truth); result->confusion[predicted_index * num_classes + truth_index]++; @@ -102,6 +126,10 @@ bool address_parser_test(address_parser_t *parser, char *filename, address_parse }) + } else { + log_error("Error in prediction\n"); + tokenized_string_destroy(data_set->tokenized_str); + break; } if (result->num_errors > starting_errors) { @@ -184,6 +212,12 @@ int main(int argc, char **argv) { address_parser_t *parser = get_address_parser(); + if (parser->model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON) { + printf("averaged perceptron parser\n"); + } else if (parser->model_type == ADDRESS_PARSER_TYPE_CRF) { + printf("crf parser\n"); + } + address_parser_test_results_t results = EMPTY_ADDRESS_PARSER_TEST_RESULT; if (!address_parser_test(parser, filename, &results, print_errors)) { @@ -196,7 +230,7 @@ int main(int argc, char **argv) { 
printf("Confusion matrix:\n\n"); - uint32_t num_classes = parser->model->num_classes; + uint32_t num_classes = address_parser_num_classes(parser); size_t *confusion_sorted = uint32_array_argsort(results.confusion, num_classes * num_classes); @@ -211,8 +245,9 @@ int main(int argc, char **argv) { if (i == j) continue; if (class_errors > 0) { - char *predicted = cstring_array_get_string(parser->model->classes, i); - char *truth = cstring_array_get_string(parser->model->classes, j); + cstring_array *classes = address_parser_class_strings(parser); + char *predicted = cstring_array_get_string(classes, i); + char *truth = cstring_array_get_string(classes, j); printf("(%s, %s): %d\n", predicted, truth, class_errors); } diff --git a/src/address_parser_train.c b/src/address_parser_train.c index 1818c8c7..e08d5aae 100644 --- a/src/address_parser_train.c +++ b/src/address_parser_train.c @@ -4,6 +4,7 @@ #include "address_parser_io.h" #include "address_dictionary.h" #include "averaged_perceptron_trainer.h" +#include "crf_trainer_averaged_perceptron.h" #include "collections.h" #include "constants.h" #include "file_utils.h" @@ -30,6 +31,7 @@ KHASH_MAP_INIT_STR(phrase_types, address_parser_types_t) #define DEFAULT_ITERATIONS 5 #define DEFAULT_MIN_UPDATES 5 +#define DEFAULT_MODEL_TYPE ADDRESS_PARSER_TYPE_CRF #define MIN_VOCAB_COUNT 5 #define MIN_PHRASE_COUNT 1 @@ -691,8 +693,6 @@ address_parser_t *address_parser_init(char *filename) { log_info("Calculating phrase types\n"); - parser->model = NULL; - size_t num_classes = kh_size(class_counts); log_info("num_classes = %zu\n", num_classes); parser->num_classes = num_classes; @@ -977,9 +977,77 @@ exit_hashes_allocated: return parser; } +static inline bool address_parser_train_example(address_parser_t *self, void *trainer, address_parser_context_t *context, address_parser_data_set_t *data_set) { + if (self->model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON) { + return 
averaged_perceptron_trainer_train_example((averaged_perceptron_trainer_t *)trainer, self, context, context->features, context->prev_tag_features, context->prev2_tag_features, &address_parser_features, data_set->tokenized_str, data_set->labels); + } else if (self->model_type == ADDRESS_PARSER_TYPE_CRF) { + return crf_averaged_perceptron_trainer_train_example((crf_averaged_perceptron_trainer_t *)trainer, self, context, context->features, context->prev_tag_features, &address_parser_features, data_set->tokenized_str, data_set->labels); + } else { + log_error("Parser model is of unknown type\n"); + } + return false; +} +static inline void address_parser_trainer_destroy(address_parser_t *self, void *trainer) { + if (self->model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON) { + averaged_perceptron_trainer_destroy((averaged_perceptron_trainer_t *)trainer); + } else if (self->model_type == ADDRESS_PARSER_TYPE_CRF) { + crf_averaged_perceptron_trainer_destroy((crf_averaged_perceptron_trainer_t *)trainer); + } +} -bool address_parser_train_epoch(address_parser_t *self, averaged_perceptron_trainer_t *trainer, char *filename) { +static inline bool address_parser_finalize_model(address_parser_t *self, void *trainer) { + if (self->model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON) { + self->model.ap = averaged_perceptron_trainer_finalize((averaged_perceptron_trainer_t *)trainer); + return self->model.ap != NULL; + } else if (self->model_type == ADDRESS_PARSER_TYPE_CRF) { + self->model.crf = crf_averaged_perceptron_trainer_finalize((crf_averaged_perceptron_trainer_t *)trainer); + return self->model.crf != NULL; + } else { + log_error("Parser model is of unknown type\n"); + } + return false; +} + +static inline uint32_t address_parser_train_num_iterations(address_parser_t *self, void *trainer) { + if (self->model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON) { + averaged_perceptron_trainer_t *ap_trainer = (averaged_perceptron_trainer_t *)trainer; + 
return ap_trainer->iterations; + } else if (self->model_type == ADDRESS_PARSER_TYPE_CRF) { + crf_averaged_perceptron_trainer_t *crf_trainer = (crf_averaged_perceptron_trainer_t *)trainer; + return crf_trainer->iterations; + } else { + log_error("Parser model is of unknown type\n"); + } + return 0; +} + +static inline void address_parser_train_set_iterations(address_parser_t *self, void *trainer, uint32_t iterations) { + if (self->model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON) { + averaged_perceptron_trainer_t *ap_trainer = (averaged_perceptron_trainer_t *)trainer; + ap_trainer->iterations = iterations; + } else if (self->model_type == ADDRESS_PARSER_TYPE_CRF) { + crf_averaged_perceptron_trainer_t *crf_trainer = (crf_averaged_perceptron_trainer_t *)trainer; + crf_trainer->iterations = iterations; + } else { + log_error("Parser model is of unknown type\n"); + } +} + +static inline uint64_t address_parser_train_num_errors(address_parser_t *self, void *trainer) { + if (self->model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON) { + averaged_perceptron_trainer_t *ap_trainer = (averaged_perceptron_trainer_t *)trainer; + return ap_trainer->num_updates; + } else if (self->model_type == ADDRESS_PARSER_TYPE_CRF) { + crf_averaged_perceptron_trainer_t *crf_trainer = (crf_averaged_perceptron_trainer_t *)trainer; + return crf_trainer->num_updates; + } else { + log_error("Parser model is of unknown type\n"); + } + return 0; +} + +bool address_parser_train_epoch(address_parser_t *self, void *trainer, char *filename) { if (filename == NULL) { log_error("Filename was NULL\n"); return false; @@ -994,7 +1062,9 @@ bool address_parser_train_epoch(address_parser_t *self, averaged_perceptron_trai address_parser_context_t *context = address_parser_context_new(); size_t examples = 0; - size_t errors = trainer->num_errors; + uint64_t errors = address_parser_train_num_errors(self, trainer); + + uint32_t iteration = address_parser_train_num_iterations(self, trainer); 
bool logged = false; @@ -1007,7 +1077,7 @@ bool address_parser_train_epoch(address_parser_t *self, averaged_perceptron_trai address_parser_context_fill(context, self, data_set->tokenized_str, language, country); - bool example_success = averaged_perceptron_trainer_train_example(trainer, self, context, context->features, context->prev_tag_features, context->prev2_tag_features, &address_parser_features, data_set->tokenized_str, data_set->labels); + bool example_success = address_parser_train_example(self, trainer, context, data_set); if (!example_success) { log_error("Error training example\n"); @@ -1024,10 +1094,11 @@ bool address_parser_train_epoch(address_parser_t *self, averaged_perceptron_trai examples++; if (examples % 1000 == 0 && examples > 0) { - log_info("Iter %d: Did %zu examples with %llu errors\n", trainer->iterations, examples, trainer->num_errors - errors); - errors = trainer->num_errors; - } + uint64_t prev_errors = errors; + errors = address_parser_train_num_errors(self, trainer); + log_info("Iter %u: Did %zu examples with %llu errors\n", iteration, examples, (unsigned long long)(errors - prev_errors)); /* iteration is uint32_t; the uint64_t delta is cast to match %llu */ + } } exit_epoch_training_started: @@ -1037,20 +1108,29 @@ exit_epoch_training_started: return true; } -bool address_parser_train(address_parser_t *self, char *filename, uint32_t num_iterations, size_t min_updates) { - averaged_perceptron_trainer_t *trainer = averaged_perceptron_trainer_new(min_updates); + +bool address_parser_train(address_parser_t *self, char *filename, address_parser_model_type_t model_type, uint32_t num_iterations, size_t min_updates) { + self->model_type = model_type; + void *trainer = NULL; /* stays NULL for an unrecognized model_type instead of being read uninitialized */ + if (model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON) { + averaged_perceptron_trainer_t *ap_trainer = averaged_perceptron_trainer_new(min_updates); + trainer = (void *)ap_trainer; + } else if (model_type == ADDRESS_PARSER_TYPE_CRF) { + crf_averaged_perceptron_trainer_t *crf_trainer = crf_averaged_perceptron_trainer_new(self->num_classes, min_updates); + 
trainer = (void *)crf_trainer; + } for (uint32_t iter = 0; iter < num_iterations; iter++) { log_info("Doing epoch %d\n", iter); - trainer->iterations = iter; + address_parser_train_set_iterations(self, trainer, iter); #if defined(HAVE_SHUF) || defined(HAVE_GSHUF) log_info("Shuffling\n"); if (!shuffle_file_chunked_size(filename, DEFAULT_SHUFFLE_CHUNK_SIZE)) { log_error("Error in shuffle\n"); - averaged_perceptron_trainer_destroy(trainer); + address_parser_trainer_destroy(self, trainer); return false; } @@ -1059,14 +1139,17 @@ bool address_parser_train(address_parser_t *self, char *filename, uint32_t num_i if (!address_parser_train_epoch(self, trainer, filename)) { log_error("Error in epoch\n"); - averaged_perceptron_trainer_destroy(trainer); + address_parser_trainer_destroy(self, trainer); return false; } } log_debug("Done with training, averaging weights\n"); - self->model = averaged_perceptron_trainer_finalize(trainer); + if (!address_parser_finalize_model(self, trainer)) { + log_error("model was NULL\n"); + return false; + } return true; } @@ -1074,10 +1157,11 @@ bool address_parser_train(address_parser_t *self, char *filename, uint32_t num_i typedef enum { ADDRESS_PARSER_TRAIN_POSITIONAL_ARG, ADDRESS_PARSER_TRAIN_ARG_ITERATIONS, - ADDRESS_PARSER_TRAIN_ARG_MIN_UPDATES + ADDRESS_PARSER_TRAIN_ARG_MIN_UPDATES, + ADDRESS_PARSER_TRAIN_ARG_MODEL_TYPE } address_parser_train_keyword_arg_t; -#define USAGE "Usage: ./address_parser_train filename output_dir [--iterations number]\n" +#define USAGE "Usage: ./address_parser_train filename output_dir [--iterations number --min-updates number --model (crf|greedyap)]\n" int main(int argc, char **argv) { if (argc < 3) { @@ -1103,6 +1187,8 @@ int main(int argc, char **argv) { char *filename = NULL; char *output_dir = NULL; + address_parser_model_type_t model_type = DEFAULT_MODEL_TYPE; + for (int i = pos_args; i < argc; i++) { char *arg = argv[i]; @@ -1116,6 +1202,11 @@ int main(int argc, char **argv) { continue; } + if 
(string_equals(arg, "--model")) { + kwarg = ADDRESS_PARSER_TRAIN_ARG_MODEL_TYPE; + continue; + } + if (kwarg == ADDRESS_PARSER_TRAIN_ARG_ITERATIONS) { if (sscanf(arg, "%zd", &arg_iterations) != 1 || arg_iterations < 0) { log_error("Bad arg for --iterations: %s\n", arg); @@ -1129,6 +1220,15 @@ int main(int argc, char **argv) { } min_updates = arg_min_updates; log_info("min_updates = %llu\n", min_updates); + } else if (kwarg == ADDRESS_PARSER_TRAIN_ARG_MODEL_TYPE) { + if (string_equals(arg, "crf")) { + model_type = ADDRESS_PARSER_TYPE_CRF; + } else if (string_equals(arg, "greedyap")) { + model_type = ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON; + } else { + log_error("Bad arg for --model, valid values are [crf, greedyap]\n"); + exit(EXIT_FAILURE); + } } else if (position == 0) { filename = arg; position++; @@ -1169,7 +1269,7 @@ int main(int argc, char **argv) { log_info("Finished initialization\n"); - if (!address_parser_train(parser, filename, num_iterations, min_updates)) { + if (!address_parser_train(parser, filename, model_type, num_iterations, min_updates)) { log_error("Error in training\n"); exit(EXIT_FAILURE); }