[parser] adding polymorphic (as much as C does polymorphism) model type for the parser to allow it to handle either the greedy averaged perceptron or a CRF. During training, saving, and loading, we use a different filename for a parser trained with a CRF, which is still backward-compatible with models previously trained in parser-data. Making necessary modifications to address_parser.c, address_parser_train.c, and address_parser_test.c. Also adding an option in address_parser_test to print individual errors in addition to the confusion matrix.

This commit is contained in:
Al
2017-03-10 19:19:40 -05:00
parent 1bd4689c5f
commit 8deb1716cb
4 changed files with 281 additions and 61 deletions

View File

@@ -7,6 +7,7 @@
#include "log/log.h"
#define ADDRESS_PARSER_MODEL_FILENAME "address_parser.dat"
#define ADDRESS_PARSER_MODEL_FILENAME_CRF "address_parser_crf.dat"
#define ADDRESS_PARSER_VOCAB_FILENAME "address_parser_vocab.trie"
#define ADDRESS_PARSER_PHRASE_FILENAME "address_parser_phrases.dat"
#define ADDRESS_PARSER_POSTAL_CODES_FILENAME "address_parser_postal_codes.dat"
@@ -48,15 +49,33 @@ address_parser_t *get_address_parser(void) {
bool address_parser_save(address_parser_t *self, char *output_dir) {
if (self == NULL || output_dir == NULL) return false;
char *model_filename = NULL;
if (self->model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON) {
model_filename = ADDRESS_PARSER_MODEL_FILENAME;
} else if (self->model_type == ADDRESS_PARSER_TYPE_CRF) {
model_filename = ADDRESS_PARSER_MODEL_FILENAME_CRF;
} else {
return false;
}
char_array *path = char_array_new_size(strlen(output_dir));
char_array_add_joined(path, PATH_SEPARATOR, true, 2, output_dir, ADDRESS_PARSER_MODEL_FILENAME);
char_array_add_joined(path, PATH_SEPARATOR, true, 2, output_dir, model_filename);
char *model_path = char_array_get_string(path);
if (!averaged_perceptron_save(self->model, model_path)) {
if (self->model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON) {
if (!averaged_perceptron_save(self->model.ap, model_path)) {
log_info("Error in averaged_perceptron_save\n");
char_array_destroy(path);
return false;
}
} else if (self->model_type == ADDRESS_PARSER_TYPE_CRF) {
if (!crf_save(self->model.crf, model_path)) {
log_info("Error in crf_save\n");
char_array_destroy(path);
return false;
}
}
char_array_clear(path);
@@ -148,15 +167,47 @@ bool address_parser_load(char *dir) {
char_array_add_joined(path, PATH_SEPARATOR, true, 2, dir, ADDRESS_PARSER_MODEL_FILENAME);
char *model_path = char_array_get_string(path);
averaged_perceptron_t *model = averaged_perceptron_load(model_path);
if (model == NULL) {
char_array_destroy(path);
if (file_exists(model_path)) {
averaged_perceptron_t *ap_model = averaged_perceptron_load(model_path);
if (ap_model != NULL) {
parser = address_parser_new();
parser->model_type = ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON;
parser->model.ap = ap_model;
} else {
char_array_destroy(model_path);
log_error("Averaged perceptron model could not be loaded\n");
return false;
}
} else {
model_path = NULL;
}
if (model_path == NULL) {
char_array_clear(path);
char_array_add_joined(path, PATH_SEPARATOR, true, 2, dir, ADDRESS_PARSER_MODEL_FILENAME_CRF);
model_path = char_array_get_string(path);
if (file_exists(model_path)) {
crf_t *crf_model = crf_load(model_path);
if (crf_model != NULL) {
parser = address_parser_new();
parser->model = model;
parser->model_type = ADDRESS_PARSER_TYPE_CRF;
parser->model.crf = crf_model;
} else {
char_array_destroy(model_path);
log_error("Averaged perceptron model could not be loaded\n");
return false;
}
} else {
model_path == NULL;
}
}
if (parser == NULL) {
char_array_destroy(path);
log_error("Could not find parser model file of known type\n");
return false;
}
char_array_clear(path);
@@ -276,8 +327,10 @@ exit_address_parser_created:
void address_parser_destroy(address_parser_t *self) {
if (self == NULL) return;
if (self->model != NULL) {
averaged_perceptron_destroy(self->model);
if (self->model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON && self->model.ap != NULL) {
averaged_perceptron_destroy(self->model.ap);
} else if (self->model_type == ADDRESS_PARSER_TYPE_CRF && self->model.crf != NULL) {
crf_destroy(self->model.crf);
}
if (self->vocab != NULL) {
@@ -313,7 +366,7 @@ inline void address_parser_normalize_token(cstring_array *array, char *str, toke
normalize_token(array, str, token, ADDRESS_PARSER_NORMALIZE_TOKEN_OPTIONS);
}
inline void address_parser_normalize_phrase_token(cstring_array *array, char *str, token_t token) {
static inline void address_parser_normalize_phrase_token(cstring_array *array, char *str, token_t token) {
normalize_token(array, str, token, ADDRESS_PARSER_NORMALIZE_ADMIN_TOKEN_OPTIONS);
}
@@ -814,7 +867,7 @@ typedef struct address_parser_phrase {
phrase_t phrase;
} address_parser_phrase_t;
inline bool is_plain_word_phrase_type(address_parser_phrase_type_t type) {
static inline bool is_plain_word_phrase_type(address_parser_phrase_type_t type) {
return type == ADDRESS_PARSER_NULL_PHRASE || type == ADDRESS_PARSER_SUFFIX_PHRASE || type == ADDRESS_PARSER_PREFIX_PHRASE;
}
@@ -1003,8 +1056,8 @@ address_parser_features
This is a feature function similar to those found in MEMM and CRF models.
Follows the signature of an ap_feature_function so it can be called
as a function pointer by the averaged perceptron model.
Follows the signature of a tagger_feature_function so it can be called
as a function pointer by the averaged perceptron or CRF model.
Parameters:
@@ -1399,7 +1452,7 @@ bool address_parser_features(void *self, void *ctx, tokenized_string_t *tokenize
}
}
if (idx == 0) {
if (idx == 0 && !is_unknown_word) {
feature_array_add(features, 2, "first word", word);
//feature_array_add(features, 3, "first word+next word", word, next_word);
}
@@ -1413,12 +1466,18 @@ bool address_parser_features(void *self, void *ctx, tokenized_string_t *tokenize
if (last_index == idx - 1) {
// Previous tag and current word
feature_array_add(prev_tag_features, 2, "word", word);
feature_array_add(prev_tag_features, 1, "trans");
// Previous two tags and current word
if (parser->model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON) {
// In the CRF this is accounted for by the transition weights
// so only need it for the averaged perceptron
feature_array_add(prev_tag_features, 1, "trans");
// Averaged perceptron uses two tags of history, CRF uses one
feature_array_add(prev2_tag_features, 2, "word", word);
feature_array_add(prev2_tag_features, 1, "trans");
}
}
if (last_index >= 0) {
address_parser_phrase_t prev_word_or_phrase = word_or_phrase_at_index(parser, tokenized, context, last_index, false);
@@ -1583,6 +1642,17 @@ bool address_parser_features(void *self, void *ctx, tokenized_string_t *tokenize
}
bool address_parser_predict(address_parser_t *self, address_parser_context_t *context, cstring_array *token_labels, tagger_feature_function feature_function, tokenized_string_t *tokenized_str) {
if (self->model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON) {
return averaged_perceptron_tagger_predict(self->model.ap, self, context, context->features, context->prev_tag_features, context->prev2_tag_features, token_labels, feature_function, tokenized_str, self->options.print_features);
} else if (self->model_type == ADDRESS_PARSER_TYPE_CRF) {
return crf_tagger_predict(self->model.crf, self, context, context->features, context->prev_tag_features, token_labels, feature_function, tokenized_str, self->options.print_features);
} else {
log_error("Parser has unknown model type\n");
}
return false;
}
address_parser_response_t *address_parser_response_new(void) {
address_parser_response_t *response = malloc(sizeof(address_parser_response_t));
return response;
@@ -1603,8 +1673,6 @@ address_parser_response_t *address_parser_parse(char *address, char *language, c
normalized = address;
}
averaged_perceptron_t *model = parser->model;
token_array *tokens = tokenize(normalized);
tokenized_string_t *tokenized_str = tokenized_string_new_from_str_size(normalized, strlen(normalized), tokens->n);
@@ -1706,7 +1774,9 @@ address_parser_response_t *address_parser_parse(char *address, char *language, c
char *prev_label = NULL;
if (averaged_perceptron_tagger_predict(model, parser, context, context->features, context->prev_tag_features, context->prev2_tag_features, token_labels, &address_parser_features, tokenized_str, parser->options.print_features)) {
bool prediction_success = address_parser_predict(parser, context, token_labels, &address_parser_features, tokenized_str);
if (prediction_success) {
response = address_parser_response_new();
size_t num_strings = cstring_array_num_strings(tokenized_str->strings);
@@ -1738,6 +1808,8 @@ address_parser_response_t *address_parser_parse(char *address, char *language, c
response->components = cstring_array_to_strings(components);
response->labels = cstring_array_to_strings(labels);
} else {
log_error("Error in prediction\n");
}
token_array_destroy(tokens);

View File

@@ -46,11 +46,13 @@ with the general error-driven averaged perceptron.
#include <stdint.h>
#include <stdbool.h>
#include "averaged_perceptron.h"
#include "averaged_perceptron_tagger.h"
#include "libpostal.h"
#include "libpostal_config.h"
#include "averaged_perceptron.h"
#include "averaged_perceptron_tagger.h"
#include "collections.h"
#include "crf.h"
#include "normalize.h"
#include "string_utils.h"
@@ -178,6 +180,11 @@ typedef union postal_code_context_value {
#define POSTAL_CODE_CONTEXT(pc, ad) ((postal_code_context_value_t){.postcode = (pc), .admin = (ad) })
typedef enum address_parser_model_type {
ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON,
ADDRESS_PARSER_TYPE_CRF
} address_parser_model_type_t;
typedef struct parser_options {
uint64_t rare_word_threshold;
bool print_features;
@@ -187,7 +194,11 @@ typedef struct parser_options {
typedef struct address_parser {
parser_options_t options;
size_t num_classes;
averaged_perceptron_t *model;
address_parser_model_type_t model_type;
union {
averaged_perceptron_t *ap;
crf_t *crf;
} model;
trie_t *vocab;
trie_t *phrases;
address_parser_types_array *phrase_types;
@@ -208,6 +219,8 @@ void address_parser_destroy(address_parser_t *self);
char *address_parser_normalize_string(char *str);
void address_parser_normalize_token(cstring_array *array, char *str, token_t token);
bool address_parser_predict(address_parser_t *self, address_parser_context_t *context, cstring_array *token_labels, tagger_feature_function feature_function, tokenized_string_t *tokenized_str);
address_parser_context_t *address_parser_context_new(void);
void address_parser_context_destroy(address_parser_context_t *self);

View File

@@ -1,7 +1,6 @@
#include "address_parser.h"
#include "address_parser_io.h"
#include "address_dictionary.h"
#include "averaged_perceptron_trainer.h"
#include "collections.h"
#include "constants.h"
#include "file_utils.h"
@@ -10,8 +9,6 @@
#include "log/log.h"
//#define ADDRESS_PARSER_TEST_PRINT_ERRORS
typedef struct address_parser_test_results {
size_t num_errors;
size_t num_predictions;
@@ -21,19 +18,44 @@ typedef struct address_parser_test_results {
} address_parser_test_results_t;
uint32_t get_class_index(address_parser_t *parser, char *name) {
static uint32_t address_parser_num_classes(address_parser_t *parser) {
if (parser->model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON) {
averaged_perceptron_t *ap = parser->model.ap;
return parser->model.ap->num_classes;
} else if (parser->model_type == ADDRESS_PARSER_TYPE_CRF) {
return parser->model.crf->num_classes;
}
return 0;
}
static cstring_array *address_parser_class_strings(address_parser_t *parser) {
if (parser->model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON) {
cstring_array *classes = parser->model.ap->classes;
return parser->model.ap->classes;
} else if (parser->model_type == ADDRESS_PARSER_TYPE_CRF) {
return parser->model.crf->classes;
}
return NULL;
}
static uint32_t address_parser_get_class_index(address_parser_t *parser, char *name) {
uint32_t i;
char *str;
cstring_array_foreach(parser->model->classes, i, str, {
cstring_array *classes = address_parser_class_strings(parser);
uint32_t num_classes = address_parser_num_classes(parser);
if (classes != NULL) {
cstring_array_foreach(classes, i, str, {
if (strcmp(name, str) == 0) {
return i;
}
})
return parser->model->num_classes;
}
return num_classes;
}
#define EMPTY_ADDRESS_PARSER_TEST_RESULT (address_parser_test_results_t){0, 0, 0, 0, NULL}
bool address_parser_test(address_parser_t *parser, char *filename, address_parser_test_results_t *result, bool print_errors) {
@@ -42,7 +64,7 @@ bool address_parser_test(address_parser_t *parser, char *filename, address_parse
return NULL;
}
uint32_t num_classes = parser->model->num_classes;
uint32_t num_classes = address_parser_num_classes(parser);
result->confusion = calloc(num_classes * num_classes, sizeof(uint32_t));
@@ -80,7 +102,9 @@ bool address_parser_test(address_parser_t *parser, char *filename, address_parse
size_t starting_errors = result->num_errors;
if (averaged_perceptron_tagger_predict(parser->model, parser, context, context->features, context->prev_tag_features, context->prev2_tag_features, token_labels, &address_parser_features, data_set->tokenized_str, false)) {
bool prediction_success = address_parser_predict(parser, context, token_labels, &address_parser_features, data_set->tokenized_str);
if (prediction_success) {
uint32_t i;
char *predicted;
cstring_array_foreach(token_labels, i, predicted, {
@@ -89,8 +113,8 @@ bool address_parser_test(address_parser_t *parser, char *filename, address_parse
if (strcmp(predicted, truth) != 0) {
result->num_errors++;
uint32_t predicted_index = get_class_index(parser, predicted);
uint32_t truth_index = get_class_index(parser, truth);
uint32_t predicted_index = address_parser_get_class_index(parser, predicted);
uint32_t truth_index = address_parser_get_class_index(parser, truth);
result->confusion[predicted_index * num_classes + truth_index]++;
@@ -102,6 +126,10 @@ bool address_parser_test(address_parser_t *parser, char *filename, address_parse
})
} else {
log_error("Error in prediction\n");
tokenized_string_destroy(data_set->tokenized_str);
break;
}
if (result->num_errors > starting_errors) {
@@ -184,6 +212,12 @@ int main(int argc, char **argv) {
address_parser_t *parser = get_address_parser();
if (parser->model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON) {
printf("averaged perceptron parser\n");
} else if (parser->model_type == ADDRESS_PARSER_TYPE_CRF) {
printf("crf parser\n");
}
address_parser_test_results_t results = EMPTY_ADDRESS_PARSER_TEST_RESULT;
if (!address_parser_test(parser, filename, &results, print_errors)) {
@@ -196,7 +230,7 @@ int main(int argc, char **argv) {
printf("Confusion matrix:\n\n");
uint32_t num_classes = parser->model->num_classes;
uint32_t num_classes = address_parser_num_classes(parser);
size_t *confusion_sorted = uint32_array_argsort(results.confusion, num_classes * num_classes);
@@ -211,8 +245,9 @@ int main(int argc, char **argv) {
if (i == j) continue;
if (class_errors > 0) {
char *predicted = cstring_array_get_string(parser->model->classes, i);
char *truth = cstring_array_get_string(parser->model->classes, j);
cstring_array *classes = address_parser_class_strings(parser);
char *predicted = cstring_array_get_string(classes, i);
char *truth = cstring_array_get_string(classes, j);
printf("(%s, %s): %d\n", predicted, truth, class_errors);
}

View File

@@ -4,6 +4,7 @@
#include "address_parser_io.h"
#include "address_dictionary.h"
#include "averaged_perceptron_trainer.h"
#include "crf_trainer_averaged_perceptron.h"
#include "collections.h"
#include "constants.h"
#include "file_utils.h"
@@ -30,6 +31,7 @@ KHASH_MAP_INIT_STR(phrase_types, address_parser_types_t)
#define DEFAULT_ITERATIONS 5
#define DEFAULT_MIN_UPDATES 5
#define DEFAULT_MODEL_TYPE ADDRESS_PARSER_TYPE_CRF
#define MIN_VOCAB_COUNT 5
#define MIN_PHRASE_COUNT 1
@@ -691,8 +693,6 @@ address_parser_t *address_parser_init(char *filename) {
log_info("Calculating phrase types\n");
parser->model = NULL;
size_t num_classes = kh_size(class_counts);
log_info("num_classes = %zu\n", num_classes);
parser->num_classes = num_classes;
@@ -977,9 +977,77 @@ exit_hashes_allocated:
return parser;
}
static inline bool address_parser_train_example(address_parser_t *self, void *trainer, address_parser_context_t *context, address_parser_data_set_t *data_set) {
if (self->model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON) {
return averaged_perceptron_trainer_train_example((averaged_perceptron_trainer_t *)trainer, self, context, context->features, context->prev_tag_features, context->prev2_tag_features, &address_parser_features, data_set->tokenized_str, data_set->labels);
} else if (self->model_type == ADDRESS_PARSER_TYPE_CRF) {
return crf_averaged_perceptron_trainer_train_example((crf_averaged_perceptron_trainer_t *)trainer, self, context, context->features, context->prev_tag_features, &address_parser_features, data_set->tokenized_str, data_set->labels);
} else {
log_error("Parser model is of unknown type\n");
}
return false;
}
static inline void address_parser_trainer_destroy(address_parser_t *self, void *trainer) {
if (self->model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON) {
averaged_perceptron_trainer_destroy((averaged_perceptron_trainer_t *)trainer);
} else if (self->model_type == ADDRESS_PARSER_TYPE_CRF) {
crf_averaged_perceptron_trainer_destroy((crf_averaged_perceptron_trainer_t *)trainer);
}
}
bool address_parser_train_epoch(address_parser_t *self, averaged_perceptron_trainer_t *trainer, char *filename) {
static inline bool address_parser_finalize_model(address_parser_t *self, void *trainer) {
if (self->model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON) {
self->model.ap = averaged_perceptron_trainer_finalize((averaged_perceptron_trainer_t *)trainer);
return self->model.ap != NULL;
} else if (self->model_type == ADDRESS_PARSER_TYPE_CRF) {
self->model.crf = crf_averaged_perceptron_trainer_finalize((crf_averaged_perceptron_trainer_t *)trainer);
return self->model.crf != NULL;
} else {
log_error("Parser model is of unknown type\n");
}
return false;
}
static inline uint32_t address_parser_train_num_iterations(address_parser_t *self, void *trainer) {
if (self->model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON) {
averaged_perceptron_trainer_t *ap_trainer = (averaged_perceptron_trainer_t *)trainer;
return ap_trainer->iterations;
} else if (self->model_type == ADDRESS_PARSER_TYPE_CRF) {
crf_averaged_perceptron_trainer_t *crf_trainer = (crf_averaged_perceptron_trainer_t *)trainer;
return crf_trainer->iterations;
} else {
log_error("Parser model is of unknown type\n");
}
return 0;
}
static inline void address_parser_train_set_iterations(address_parser_t *self, void *trainer, uint32_t iterations) {
if (self->model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON) {
averaged_perceptron_trainer_t *ap_trainer = (averaged_perceptron_trainer_t *)trainer;
ap_trainer->iterations = iterations;
} else if (self->model_type == ADDRESS_PARSER_TYPE_CRF) {
crf_averaged_perceptron_trainer_t *crf_trainer = (crf_averaged_perceptron_trainer_t *)trainer;
crf_trainer->iterations = iterations;
} else {
log_error("Parser model is of unknown type\n");
}
}
static inline uint64_t address_parser_train_num_errors(address_parser_t *self, void *trainer) {
if (self->model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON) {
averaged_perceptron_trainer_t *ap_trainer = (averaged_perceptron_trainer_t *)trainer;
return ap_trainer->num_updates;
} else if (self->model_type == ADDRESS_PARSER_TYPE_CRF) {
crf_averaged_perceptron_trainer_t *crf_trainer = (crf_averaged_perceptron_trainer_t *)trainer;
return crf_trainer->num_updates;
} else {
log_error("Parser model is of unknown type\n");
}
return 0;
}
bool address_parser_train_epoch(address_parser_t *self, void *trainer, char *filename) {
if (filename == NULL) {
log_error("Filename was NULL\n");
return false;
@@ -994,7 +1062,9 @@ bool address_parser_train_epoch(address_parser_t *self, averaged_perceptron_trai
address_parser_context_t *context = address_parser_context_new();
size_t examples = 0;
size_t errors = trainer->num_errors;
uint64_t errors = address_parser_train_num_errors(self, trainer);
uint32_t iteration = address_parser_train_num_iterations(self, trainer);
bool logged = false;
@@ -1007,7 +1077,7 @@ bool address_parser_train_epoch(address_parser_t *self, averaged_perceptron_trai
address_parser_context_fill(context, self, data_set->tokenized_str, language, country);
bool example_success = averaged_perceptron_trainer_train_example(trainer, self, context, context->features, context->prev_tag_features, context->prev2_tag_features, &address_parser_features, data_set->tokenized_str, data_set->labels);
bool example_success = address_parser_train_example(self, trainer, context, data_set);
if (!example_success) {
log_error("Error training example\n");
@@ -1024,10 +1094,11 @@ bool address_parser_train_epoch(address_parser_t *self, averaged_perceptron_trai
examples++;
if (examples % 1000 == 0 && examples > 0) {
log_info("Iter %d: Did %zu examples with %llu errors\n", trainer->iterations, examples, trainer->num_errors - errors);
errors = trainer->num_errors;
}
uint64_t prev_errors = errors;
errors = address_parser_train_num_errors(self, trainer);
log_info("Iter %d: Did %zu examples with %llu errors\n", iteration, examples, errors - prev_errors);
}
}
exit_epoch_training_started:
@@ -1037,20 +1108,29 @@ exit_epoch_training_started:
return true;
}
bool address_parser_train(address_parser_t *self, char *filename, uint32_t num_iterations, size_t min_updates) {
averaged_perceptron_trainer_t *trainer = averaged_perceptron_trainer_new(min_updates);
bool address_parser_train(address_parser_t *self, char *filename, address_parser_model_type_t model_type, uint32_t num_iterations, size_t min_updates) {
self->model_type = model_type;
void *trainer;
if (model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON) {
averaged_perceptron_trainer_t *ap_trainer = averaged_perceptron_trainer_new(min_updates);
trainer = (void *)ap_trainer;
} else if (model_type == ADDRESS_PARSER_TYPE_CRF) {
crf_averaged_perceptron_trainer_t *crf_trainer = crf_averaged_perceptron_trainer_new(self->num_classes, min_updates);
trainer = (void *)crf_trainer;
}
for (uint32_t iter = 0; iter < num_iterations; iter++) {
log_info("Doing epoch %d\n", iter);
trainer->iterations = iter;
address_parser_train_set_iterations(self, trainer, iter);
#if defined(HAVE_SHUF) || defined(HAVE_GSHUF)
log_info("Shuffling\n");
if (!shuffle_file_chunked_size(filename, DEFAULT_SHUFFLE_CHUNK_SIZE)) {
log_error("Error in shuffle\n");
averaged_perceptron_trainer_destroy(trainer);
address_parser_trainer_destroy(self, trainer);
return false;
}
@@ -1059,14 +1139,17 @@ bool address_parser_train(address_parser_t *self, char *filename, uint32_t num_i
if (!address_parser_train_epoch(self, trainer, filename)) {
log_error("Error in epoch\n");
averaged_perceptron_trainer_destroy(trainer);
address_parser_trainer_destroy(self, trainer);
return false;
}
}
log_debug("Done with training, averaging weights\n");
self->model = averaged_perceptron_trainer_finalize(trainer);
if (!address_parser_finalize_model(self, trainer)) {
log_error("model was NULL\n");
return false;
}
return true;
}
@@ -1074,10 +1157,11 @@ bool address_parser_train(address_parser_t *self, char *filename, uint32_t num_i
typedef enum {
ADDRESS_PARSER_TRAIN_POSITIONAL_ARG,
ADDRESS_PARSER_TRAIN_ARG_ITERATIONS,
ADDRESS_PARSER_TRAIN_ARG_MIN_UPDATES
ADDRESS_PARSER_TRAIN_ARG_MIN_UPDATES,
ADDRESS_PARSER_TRAIN_ARG_MODEL_TYPE
} address_parser_train_keyword_arg_t;
#define USAGE "Usage: ./address_parser_train filename output_dir [--iterations number]\n"
#define USAGE "Usage: ./address_parser_train filename output_dir [--iterations number --min-updates number --model (crf|greedyap)]\n"
int main(int argc, char **argv) {
if (argc < 3) {
@@ -1103,6 +1187,8 @@ int main(int argc, char **argv) {
char *filename = NULL;
char *output_dir = NULL;
address_parser_model_type_t model_type = DEFAULT_MODEL_TYPE;
for (int i = pos_args; i < argc; i++) {
char *arg = argv[i];
@@ -1116,6 +1202,11 @@ int main(int argc, char **argv) {
continue;
}
if (string_equals(arg, "--model")) {
kwarg = ADDRESS_PARSER_TRAIN_ARG_MODEL_TYPE;
continue;
}
if (kwarg == ADDRESS_PARSER_TRAIN_ARG_ITERATIONS) {
if (sscanf(arg, "%zd", &arg_iterations) != 1 || arg_iterations < 0) {
log_error("Bad arg for --iterations: %s\n", arg);
@@ -1129,6 +1220,15 @@ int main(int argc, char **argv) {
}
min_updates = arg_min_updates;
log_info("min_updates = %llu\n", min_updates);
} else if (kwarg == ADDRESS_PARSER_TRAIN_ARG_MODEL_TYPE) {
if (string_equals(arg, "crf")) {
model_type = ADDRESS_PARSER_TYPE_CRF;
} else if (string_equals(arg, "greedyap")) {
model_type = ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON;
} else {
log_error("Bad arg for --model, valid values are [crf, greedyap]\n");
exit(EXIT_FAILURE);
}
} else if (position == 0) {
filename = arg;
position++;
@@ -1169,7 +1269,7 @@ int main(int argc, char **argv) {
log_info("Finished initialization\n");
if (!address_parser_train(parser, filename, num_iterations, min_updates)) {
if (!address_parser_train(parser, filename, model_type, num_iterations, min_updates)) {
log_error("Error in training\n");
exit(EXIT_FAILURE);
}