[parser] adding polymorphic (as much as C does polymorphism) model type for the parser to allow it to handle either the greedy averaged perceptron or a CRF. During training, saving, and loading, we use a different filename for a parser trained with a CRF, which is still backward-compatible with models previously trained in parser-data. Making necessary modifications to address_parser.c, address_parser_train.c, and address_parser_test.c. Also adding an option in address_parser_test to print individual errors in addition to the confusion matrix.
This commit is contained in:
@@ -7,6 +7,7 @@
|
||||
#include "log/log.h"
|
||||
|
||||
#define ADDRESS_PARSER_MODEL_FILENAME "address_parser.dat"
|
||||
#define ADDRESS_PARSER_MODEL_FILENAME_CRF "address_parser_crf.dat"
|
||||
#define ADDRESS_PARSER_VOCAB_FILENAME "address_parser_vocab.trie"
|
||||
#define ADDRESS_PARSER_PHRASE_FILENAME "address_parser_phrases.dat"
|
||||
#define ADDRESS_PARSER_POSTAL_CODES_FILENAME "address_parser_postal_codes.dat"
|
||||
@@ -48,15 +49,33 @@ address_parser_t *get_address_parser(void) {
|
||||
bool address_parser_save(address_parser_t *self, char *output_dir) {
|
||||
if (self == NULL || output_dir == NULL) return false;
|
||||
|
||||
char *model_filename = NULL;
|
||||
if (self->model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON) {
|
||||
model_filename = ADDRESS_PARSER_MODEL_FILENAME;
|
||||
} else if (self->model_type == ADDRESS_PARSER_TYPE_CRF) {
|
||||
model_filename = ADDRESS_PARSER_MODEL_FILENAME_CRF;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
||||
char_array *path = char_array_new_size(strlen(output_dir));
|
||||
|
||||
char_array_add_joined(path, PATH_SEPARATOR, true, 2, output_dir, ADDRESS_PARSER_MODEL_FILENAME);
|
||||
char_array_add_joined(path, PATH_SEPARATOR, true, 2, output_dir, model_filename);
|
||||
char *model_path = char_array_get_string(path);
|
||||
|
||||
if (!averaged_perceptron_save(self->model, model_path)) {
|
||||
if (self->model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON) {
|
||||
if (!averaged_perceptron_save(self->model.ap, model_path)) {
|
||||
log_info("Error in averaged_perceptron_save\n");
|
||||
char_array_destroy(path);
|
||||
return false;
|
||||
}
|
||||
} else if (self->model_type == ADDRESS_PARSER_TYPE_CRF) {
|
||||
if (!crf_save(self->model.crf, model_path)) {
|
||||
log_info("Error in crf_save\n");
|
||||
char_array_destroy(path);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
char_array_clear(path);
|
||||
|
||||
@@ -148,15 +167,47 @@ bool address_parser_load(char *dir) {
|
||||
char_array_add_joined(path, PATH_SEPARATOR, true, 2, dir, ADDRESS_PARSER_MODEL_FILENAME);
|
||||
char *model_path = char_array_get_string(path);
|
||||
|
||||
averaged_perceptron_t *model = averaged_perceptron_load(model_path);
|
||||
|
||||
if (model == NULL) {
|
||||
char_array_destroy(path);
|
||||
if (file_exists(model_path)) {
|
||||
averaged_perceptron_t *ap_model = averaged_perceptron_load(model_path);
|
||||
if (ap_model != NULL) {
|
||||
parser = address_parser_new();
|
||||
parser->model_type = ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON;
|
||||
parser->model.ap = ap_model;
|
||||
} else {
|
||||
char_array_destroy(model_path);
|
||||
log_error("Averaged perceptron model could not be loaded\n");
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
model_path = NULL;
|
||||
}
|
||||
|
||||
if (model_path == NULL) {
|
||||
char_array_clear(path);
|
||||
char_array_add_joined(path, PATH_SEPARATOR, true, 2, dir, ADDRESS_PARSER_MODEL_FILENAME_CRF);
|
||||
model_path = char_array_get_string(path);
|
||||
|
||||
if (file_exists(model_path)) {
|
||||
crf_t *crf_model = crf_load(model_path);
|
||||
if (crf_model != NULL) {
|
||||
parser = address_parser_new();
|
||||
parser->model = model;
|
||||
parser->model_type = ADDRESS_PARSER_TYPE_CRF;
|
||||
parser->model.crf = crf_model;
|
||||
} else {
|
||||
char_array_destroy(model_path);
|
||||
log_error("Averaged perceptron model could not be loaded\n");
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
model_path == NULL;
|
||||
}
|
||||
}
|
||||
|
||||
if (parser == NULL) {
|
||||
char_array_destroy(path);
|
||||
log_error("Could not find parser model file of known type\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
char_array_clear(path);
|
||||
|
||||
@@ -276,8 +327,10 @@ exit_address_parser_created:
|
||||
void address_parser_destroy(address_parser_t *self) {
|
||||
if (self == NULL) return;
|
||||
|
||||
if (self->model != NULL) {
|
||||
averaged_perceptron_destroy(self->model);
|
||||
if (self->model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON && self->model.ap != NULL) {
|
||||
averaged_perceptron_destroy(self->model.ap);
|
||||
} else if (self->model_type == ADDRESS_PARSER_TYPE_CRF && self->model.crf != NULL) {
|
||||
crf_destroy(self->model.crf);
|
||||
}
|
||||
|
||||
if (self->vocab != NULL) {
|
||||
@@ -313,7 +366,7 @@ inline void address_parser_normalize_token(cstring_array *array, char *str, toke
|
||||
normalize_token(array, str, token, ADDRESS_PARSER_NORMALIZE_TOKEN_OPTIONS);
|
||||
}
|
||||
|
||||
inline void address_parser_normalize_phrase_token(cstring_array *array, char *str, token_t token) {
|
||||
static inline void address_parser_normalize_phrase_token(cstring_array *array, char *str, token_t token) {
|
||||
normalize_token(array, str, token, ADDRESS_PARSER_NORMALIZE_ADMIN_TOKEN_OPTIONS);
|
||||
}
|
||||
|
||||
@@ -814,7 +867,7 @@ typedef struct address_parser_phrase {
|
||||
phrase_t phrase;
|
||||
} address_parser_phrase_t;
|
||||
|
||||
inline bool is_plain_word_phrase_type(address_parser_phrase_type_t type) {
|
||||
static inline bool is_plain_word_phrase_type(address_parser_phrase_type_t type) {
|
||||
return type == ADDRESS_PARSER_NULL_PHRASE || type == ADDRESS_PARSER_SUFFIX_PHRASE || type == ADDRESS_PARSER_PREFIX_PHRASE;
|
||||
}
|
||||
|
||||
@@ -1003,8 +1056,8 @@ address_parser_features
|
||||
|
||||
This is a feature function similar to those found in MEMM and CRF models.
|
||||
|
||||
Follows the signature of an ap_feature_function so it can be called
|
||||
as a function pointer by the averaged perceptron model.
|
||||
Follows the signature of a tagger_feature_function so it can be called
|
||||
as a function pointer by the averaged perceptron or CRF model.
|
||||
|
||||
Parameters:
|
||||
|
||||
@@ -1399,7 +1452,7 @@ bool address_parser_features(void *self, void *ctx, tokenized_string_t *tokenize
|
||||
}
|
||||
}
|
||||
|
||||
if (idx == 0) {
|
||||
if (idx == 0 && !is_unknown_word) {
|
||||
feature_array_add(features, 2, "first word", word);
|
||||
//feature_array_add(features, 3, "first word+next word", word, next_word);
|
||||
}
|
||||
@@ -1413,12 +1466,18 @@ bool address_parser_features(void *self, void *ctx, tokenized_string_t *tokenize
|
||||
if (last_index == idx - 1) {
|
||||
// Previous tag and current word
|
||||
feature_array_add(prev_tag_features, 2, "word", word);
|
||||
feature_array_add(prev_tag_features, 1, "trans");
|
||||
|
||||
// Previous two tags and current word
|
||||
if (parser->model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON) {
|
||||
// In the CRF this is accounted for by the transition weights
|
||||
// so only need it for the averaged perceptron
|
||||
feature_array_add(prev_tag_features, 1, "trans");
|
||||
|
||||
// Averaged perceptron uses two tags of history, CRF uses one
|
||||
feature_array_add(prev2_tag_features, 2, "word", word);
|
||||
feature_array_add(prev2_tag_features, 1, "trans");
|
||||
}
|
||||
}
|
||||
|
||||
if (last_index >= 0) {
|
||||
address_parser_phrase_t prev_word_or_phrase = word_or_phrase_at_index(parser, tokenized, context, last_index, false);
|
||||
@@ -1583,6 +1642,17 @@ bool address_parser_features(void *self, void *ctx, tokenized_string_t *tokenize
|
||||
|
||||
}
|
||||
|
||||
bool address_parser_predict(address_parser_t *self, address_parser_context_t *context, cstring_array *token_labels, tagger_feature_function feature_function, tokenized_string_t *tokenized_str) {
|
||||
if (self->model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON) {
|
||||
return averaged_perceptron_tagger_predict(self->model.ap, self, context, context->features, context->prev_tag_features, context->prev2_tag_features, token_labels, feature_function, tokenized_str, self->options.print_features);
|
||||
} else if (self->model_type == ADDRESS_PARSER_TYPE_CRF) {
|
||||
return crf_tagger_predict(self->model.crf, self, context, context->features, context->prev_tag_features, token_labels, feature_function, tokenized_str, self->options.print_features);
|
||||
} else {
|
||||
log_error("Parser has unknown model type\n");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
address_parser_response_t *address_parser_response_new(void) {
|
||||
address_parser_response_t *response = malloc(sizeof(address_parser_response_t));
|
||||
return response;
|
||||
@@ -1603,8 +1673,6 @@ address_parser_response_t *address_parser_parse(char *address, char *language, c
|
||||
normalized = address;
|
||||
}
|
||||
|
||||
averaged_perceptron_t *model = parser->model;
|
||||
|
||||
token_array *tokens = tokenize(normalized);
|
||||
|
||||
tokenized_string_t *tokenized_str = tokenized_string_new_from_str_size(normalized, strlen(normalized), tokens->n);
|
||||
@@ -1706,7 +1774,9 @@ address_parser_response_t *address_parser_parse(char *address, char *language, c
|
||||
|
||||
char *prev_label = NULL;
|
||||
|
||||
if (averaged_perceptron_tagger_predict(model, parser, context, context->features, context->prev_tag_features, context->prev2_tag_features, token_labels, &address_parser_features, tokenized_str, parser->options.print_features)) {
|
||||
bool prediction_success = address_parser_predict(parser, context, token_labels, &address_parser_features, tokenized_str);
|
||||
|
||||
if (prediction_success) {
|
||||
response = address_parser_response_new();
|
||||
|
||||
size_t num_strings = cstring_array_num_strings(tokenized_str->strings);
|
||||
@@ -1738,6 +1808,8 @@ address_parser_response_t *address_parser_parse(char *address, char *language, c
|
||||
response->components = cstring_array_to_strings(components);
|
||||
response->labels = cstring_array_to_strings(labels);
|
||||
|
||||
} else {
|
||||
log_error("Error in prediction\n");
|
||||
}
|
||||
|
||||
token_array_destroy(tokens);
|
||||
|
||||
@@ -46,11 +46,13 @@ with the general error-driven averaged perceptron.
|
||||
#include <stdint.h>
|
||||
#include <stdbool.h>
|
||||
|
||||
#include "averaged_perceptron.h"
|
||||
#include "averaged_perceptron_tagger.h"
|
||||
#include "libpostal.h"
|
||||
#include "libpostal_config.h"
|
||||
|
||||
#include "averaged_perceptron.h"
|
||||
#include "averaged_perceptron_tagger.h"
|
||||
#include "collections.h"
|
||||
#include "crf.h"
|
||||
#include "normalize.h"
|
||||
#include "string_utils.h"
|
||||
|
||||
@@ -178,6 +180,11 @@ typedef union postal_code_context_value {
|
||||
|
||||
#define POSTAL_CODE_CONTEXT(pc, ad) ((postal_code_context_value_t){.postcode = (pc), .admin = (ad) })
|
||||
|
||||
typedef enum address_parser_model_type {
|
||||
ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON,
|
||||
ADDRESS_PARSER_TYPE_CRF
|
||||
} address_parser_model_type_t;
|
||||
|
||||
typedef struct parser_options {
|
||||
uint64_t rare_word_threshold;
|
||||
bool print_features;
|
||||
@@ -187,7 +194,11 @@ typedef struct parser_options {
|
||||
typedef struct address_parser {
|
||||
parser_options_t options;
|
||||
size_t num_classes;
|
||||
averaged_perceptron_t *model;
|
||||
address_parser_model_type_t model_type;
|
||||
union {
|
||||
averaged_perceptron_t *ap;
|
||||
crf_t *crf;
|
||||
} model;
|
||||
trie_t *vocab;
|
||||
trie_t *phrases;
|
||||
address_parser_types_array *phrase_types;
|
||||
@@ -208,6 +219,8 @@ void address_parser_destroy(address_parser_t *self);
|
||||
char *address_parser_normalize_string(char *str);
|
||||
void address_parser_normalize_token(cstring_array *array, char *str, token_t token);
|
||||
|
||||
bool address_parser_predict(address_parser_t *self, address_parser_context_t *context, cstring_array *token_labels, tagger_feature_function feature_function, tokenized_string_t *tokenized_str);
|
||||
|
||||
address_parser_context_t *address_parser_context_new(void);
|
||||
void address_parser_context_destroy(address_parser_context_t *self);
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
#include "address_parser.h"
|
||||
#include "address_parser_io.h"
|
||||
#include "address_dictionary.h"
|
||||
#include "averaged_perceptron_trainer.h"
|
||||
#include "collections.h"
|
||||
#include "constants.h"
|
||||
#include "file_utils.h"
|
||||
@@ -10,8 +9,6 @@
|
||||
|
||||
#include "log/log.h"
|
||||
|
||||
//#define ADDRESS_PARSER_TEST_PRINT_ERRORS
|
||||
|
||||
typedef struct address_parser_test_results {
|
||||
size_t num_errors;
|
||||
size_t num_predictions;
|
||||
@@ -21,19 +18,44 @@ typedef struct address_parser_test_results {
|
||||
} address_parser_test_results_t;
|
||||
|
||||
|
||||
uint32_t get_class_index(address_parser_t *parser, char *name) {
|
||||
static uint32_t address_parser_num_classes(address_parser_t *parser) {
|
||||
if (parser->model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON) {
|
||||
averaged_perceptron_t *ap = parser->model.ap;
|
||||
return parser->model.ap->num_classes;
|
||||
} else if (parser->model_type == ADDRESS_PARSER_TYPE_CRF) {
|
||||
return parser->model.crf->num_classes;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
static cstring_array *address_parser_class_strings(address_parser_t *parser) {
|
||||
if (parser->model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON) {
|
||||
cstring_array *classes = parser->model.ap->classes;
|
||||
return parser->model.ap->classes;
|
||||
} else if (parser->model_type == ADDRESS_PARSER_TYPE_CRF) {
|
||||
return parser->model.crf->classes;
|
||||
}
|
||||
return NULL;
|
||||
}
|
||||
|
||||
static uint32_t address_parser_get_class_index(address_parser_t *parser, char *name) {
|
||||
uint32_t i;
|
||||
char *str;
|
||||
|
||||
cstring_array_foreach(parser->model->classes, i, str, {
|
||||
cstring_array *classes = address_parser_class_strings(parser);
|
||||
uint32_t num_classes = address_parser_num_classes(parser);
|
||||
if (classes != NULL) {
|
||||
cstring_array_foreach(classes, i, str, {
|
||||
if (strcmp(name, str) == 0) {
|
||||
return i;
|
||||
}
|
||||
})
|
||||
|
||||
return parser->model->num_classes;
|
||||
}
|
||||
|
||||
return num_classes;
|
||||
}
|
||||
|
||||
|
||||
#define EMPTY_ADDRESS_PARSER_TEST_RESULT (address_parser_test_results_t){0, 0, 0, 0, NULL}
|
||||
|
||||
bool address_parser_test(address_parser_t *parser, char *filename, address_parser_test_results_t *result, bool print_errors) {
|
||||
@@ -42,7 +64,7 @@ bool address_parser_test(address_parser_t *parser, char *filename, address_parse
|
||||
return NULL;
|
||||
}
|
||||
|
||||
uint32_t num_classes = parser->model->num_classes;
|
||||
uint32_t num_classes = address_parser_num_classes(parser);
|
||||
|
||||
result->confusion = calloc(num_classes * num_classes, sizeof(uint32_t));
|
||||
|
||||
@@ -80,7 +102,9 @@ bool address_parser_test(address_parser_t *parser, char *filename, address_parse
|
||||
|
||||
size_t starting_errors = result->num_errors;
|
||||
|
||||
if (averaged_perceptron_tagger_predict(parser->model, parser, context, context->features, context->prev_tag_features, context->prev2_tag_features, token_labels, &address_parser_features, data_set->tokenized_str, false)) {
|
||||
bool prediction_success = address_parser_predict(parser, context, token_labels, &address_parser_features, data_set->tokenized_str);
|
||||
|
||||
if (prediction_success) {
|
||||
uint32_t i;
|
||||
char *predicted;
|
||||
cstring_array_foreach(token_labels, i, predicted, {
|
||||
@@ -89,8 +113,8 @@ bool address_parser_test(address_parser_t *parser, char *filename, address_parse
|
||||
if (strcmp(predicted, truth) != 0) {
|
||||
result->num_errors++;
|
||||
|
||||
uint32_t predicted_index = get_class_index(parser, predicted);
|
||||
uint32_t truth_index = get_class_index(parser, truth);
|
||||
uint32_t predicted_index = address_parser_get_class_index(parser, predicted);
|
||||
uint32_t truth_index = address_parser_get_class_index(parser, truth);
|
||||
|
||||
result->confusion[predicted_index * num_classes + truth_index]++;
|
||||
|
||||
@@ -102,6 +126,10 @@ bool address_parser_test(address_parser_t *parser, char *filename, address_parse
|
||||
|
||||
})
|
||||
|
||||
} else {
|
||||
log_error("Error in prediction\n");
|
||||
tokenized_string_destroy(data_set->tokenized_str);
|
||||
break;
|
||||
}
|
||||
|
||||
if (result->num_errors > starting_errors) {
|
||||
@@ -184,6 +212,12 @@ int main(int argc, char **argv) {
|
||||
|
||||
address_parser_t *parser = get_address_parser();
|
||||
|
||||
if (parser->model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON) {
|
||||
printf("averaged perceptron parser\n");
|
||||
} else if (parser->model_type == ADDRESS_PARSER_TYPE_CRF) {
|
||||
printf("crf parser\n");
|
||||
}
|
||||
|
||||
address_parser_test_results_t results = EMPTY_ADDRESS_PARSER_TEST_RESULT;
|
||||
|
||||
if (!address_parser_test(parser, filename, &results, print_errors)) {
|
||||
@@ -196,7 +230,7 @@ int main(int argc, char **argv) {
|
||||
|
||||
|
||||
printf("Confusion matrix:\n\n");
|
||||
uint32_t num_classes = parser->model->num_classes;
|
||||
uint32_t num_classes = address_parser_num_classes(parser);
|
||||
|
||||
size_t *confusion_sorted = uint32_array_argsort(results.confusion, num_classes * num_classes);
|
||||
|
||||
@@ -211,8 +245,9 @@ int main(int argc, char **argv) {
|
||||
if (i == j) continue;
|
||||
|
||||
if (class_errors > 0) {
|
||||
char *predicted = cstring_array_get_string(parser->model->classes, i);
|
||||
char *truth = cstring_array_get_string(parser->model->classes, j);
|
||||
cstring_array *classes = address_parser_class_strings(parser);
|
||||
char *predicted = cstring_array_get_string(classes, i);
|
||||
char *truth = cstring_array_get_string(classes, j);
|
||||
|
||||
printf("(%s, %s): %d\n", predicted, truth, class_errors);
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#include "address_parser_io.h"
|
||||
#include "address_dictionary.h"
|
||||
#include "averaged_perceptron_trainer.h"
|
||||
#include "crf_trainer_averaged_perceptron.h"
|
||||
#include "collections.h"
|
||||
#include "constants.h"
|
||||
#include "file_utils.h"
|
||||
@@ -30,6 +31,7 @@ KHASH_MAP_INIT_STR(phrase_types, address_parser_types_t)
|
||||
|
||||
#define DEFAULT_ITERATIONS 5
|
||||
#define DEFAULT_MIN_UPDATES 5
|
||||
#define DEFAULT_MODEL_TYPE ADDRESS_PARSER_TYPE_CRF
|
||||
|
||||
#define MIN_VOCAB_COUNT 5
|
||||
#define MIN_PHRASE_COUNT 1
|
||||
@@ -691,8 +693,6 @@ address_parser_t *address_parser_init(char *filename) {
|
||||
|
||||
log_info("Calculating phrase types\n");
|
||||
|
||||
parser->model = NULL;
|
||||
|
||||
size_t num_classes = kh_size(class_counts);
|
||||
log_info("num_classes = %zu\n", num_classes);
|
||||
parser->num_classes = num_classes;
|
||||
@@ -977,9 +977,77 @@ exit_hashes_allocated:
|
||||
return parser;
|
||||
}
|
||||
|
||||
static inline bool address_parser_train_example(address_parser_t *self, void *trainer, address_parser_context_t *context, address_parser_data_set_t *data_set) {
|
||||
if (self->model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON) {
|
||||
return averaged_perceptron_trainer_train_example((averaged_perceptron_trainer_t *)trainer, self, context, context->features, context->prev_tag_features, context->prev2_tag_features, &address_parser_features, data_set->tokenized_str, data_set->labels);
|
||||
} else if (self->model_type == ADDRESS_PARSER_TYPE_CRF) {
|
||||
return crf_averaged_perceptron_trainer_train_example((crf_averaged_perceptron_trainer_t *)trainer, self, context, context->features, context->prev_tag_features, &address_parser_features, data_set->tokenized_str, data_set->labels);
|
||||
} else {
|
||||
log_error("Parser model is of unknown type\n");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static inline void address_parser_trainer_destroy(address_parser_t *self, void *trainer) {
|
||||
if (self->model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON) {
|
||||
averaged_perceptron_trainer_destroy((averaged_perceptron_trainer_t *)trainer);
|
||||
} else if (self->model_type == ADDRESS_PARSER_TYPE_CRF) {
|
||||
crf_averaged_perceptron_trainer_destroy((crf_averaged_perceptron_trainer_t *)trainer);
|
||||
}
|
||||
}
|
||||
|
||||
bool address_parser_train_epoch(address_parser_t *self, averaged_perceptron_trainer_t *trainer, char *filename) {
|
||||
static inline bool address_parser_finalize_model(address_parser_t *self, void *trainer) {
|
||||
if (self->model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON) {
|
||||
self->model.ap = averaged_perceptron_trainer_finalize((averaged_perceptron_trainer_t *)trainer);
|
||||
return self->model.ap != NULL;
|
||||
} else if (self->model_type == ADDRESS_PARSER_TYPE_CRF) {
|
||||
self->model.crf = crf_averaged_perceptron_trainer_finalize((crf_averaged_perceptron_trainer_t *)trainer);
|
||||
return self->model.crf != NULL;
|
||||
} else {
|
||||
log_error("Parser model is of unknown type\n");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static inline uint32_t address_parser_train_num_iterations(address_parser_t *self, void *trainer) {
|
||||
if (self->model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON) {
|
||||
averaged_perceptron_trainer_t *ap_trainer = (averaged_perceptron_trainer_t *)trainer;
|
||||
return ap_trainer->iterations;
|
||||
} else if (self->model_type == ADDRESS_PARSER_TYPE_CRF) {
|
||||
crf_averaged_perceptron_trainer_t *crf_trainer = (crf_averaged_perceptron_trainer_t *)trainer;
|
||||
return crf_trainer->iterations;
|
||||
} else {
|
||||
log_error("Parser model is of unknown type\n");
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
static inline void address_parser_train_set_iterations(address_parser_t *self, void *trainer, uint32_t iterations) {
|
||||
if (self->model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON) {
|
||||
averaged_perceptron_trainer_t *ap_trainer = (averaged_perceptron_trainer_t *)trainer;
|
||||
ap_trainer->iterations = iterations;
|
||||
} else if (self->model_type == ADDRESS_PARSER_TYPE_CRF) {
|
||||
crf_averaged_perceptron_trainer_t *crf_trainer = (crf_averaged_perceptron_trainer_t *)trainer;
|
||||
crf_trainer->iterations = iterations;
|
||||
} else {
|
||||
log_error("Parser model is of unknown type\n");
|
||||
}
|
||||
}
|
||||
|
||||
static inline uint64_t address_parser_train_num_errors(address_parser_t *self, void *trainer) {
|
||||
if (self->model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON) {
|
||||
averaged_perceptron_trainer_t *ap_trainer = (averaged_perceptron_trainer_t *)trainer;
|
||||
return ap_trainer->num_updates;
|
||||
} else if (self->model_type == ADDRESS_PARSER_TYPE_CRF) {
|
||||
crf_averaged_perceptron_trainer_t *crf_trainer = (crf_averaged_perceptron_trainer_t *)trainer;
|
||||
return crf_trainer->num_updates;
|
||||
} else {
|
||||
log_error("Parser model is of unknown type\n");
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
bool address_parser_train_epoch(address_parser_t *self, void *trainer, char *filename) {
|
||||
if (filename == NULL) {
|
||||
log_error("Filename was NULL\n");
|
||||
return false;
|
||||
@@ -994,7 +1062,9 @@ bool address_parser_train_epoch(address_parser_t *self, averaged_perceptron_trai
|
||||
address_parser_context_t *context = address_parser_context_new();
|
||||
|
||||
size_t examples = 0;
|
||||
size_t errors = trainer->num_errors;
|
||||
uint64_t errors = address_parser_train_num_errors(self, trainer);
|
||||
|
||||
uint32_t iteration = address_parser_train_num_iterations(self, trainer);
|
||||
|
||||
bool logged = false;
|
||||
|
||||
@@ -1007,7 +1077,7 @@ bool address_parser_train_epoch(address_parser_t *self, averaged_perceptron_trai
|
||||
|
||||
address_parser_context_fill(context, self, data_set->tokenized_str, language, country);
|
||||
|
||||
bool example_success = averaged_perceptron_trainer_train_example(trainer, self, context, context->features, context->prev_tag_features, context->prev2_tag_features, &address_parser_features, data_set->tokenized_str, data_set->labels);
|
||||
bool example_success = address_parser_train_example(self, trainer, context, data_set);
|
||||
|
||||
if (!example_success) {
|
||||
log_error("Error training example\n");
|
||||
@@ -1024,10 +1094,11 @@ bool address_parser_train_epoch(address_parser_t *self, averaged_perceptron_trai
|
||||
|
||||
examples++;
|
||||
if (examples % 1000 == 0 && examples > 0) {
|
||||
log_info("Iter %d: Did %zu examples with %llu errors\n", trainer->iterations, examples, trainer->num_errors - errors);
|
||||
errors = trainer->num_errors;
|
||||
}
|
||||
uint64_t prev_errors = errors;
|
||||
errors = address_parser_train_num_errors(self, trainer);
|
||||
|
||||
log_info("Iter %d: Did %zu examples with %llu errors\n", iteration, examples, errors - prev_errors);
|
||||
}
|
||||
}
|
||||
|
||||
exit_epoch_training_started:
|
||||
@@ -1037,20 +1108,29 @@ exit_epoch_training_started:
|
||||
return true;
|
||||
}
|
||||
|
||||
bool address_parser_train(address_parser_t *self, char *filename, uint32_t num_iterations, size_t min_updates) {
|
||||
averaged_perceptron_trainer_t *trainer = averaged_perceptron_trainer_new(min_updates);
|
||||
|
||||
bool address_parser_train(address_parser_t *self, char *filename, address_parser_model_type_t model_type, uint32_t num_iterations, size_t min_updates) {
|
||||
self->model_type = model_type;
|
||||
void *trainer;
|
||||
if (model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON) {
|
||||
averaged_perceptron_trainer_t *ap_trainer = averaged_perceptron_trainer_new(min_updates);
|
||||
trainer = (void *)ap_trainer;
|
||||
} else if (model_type == ADDRESS_PARSER_TYPE_CRF) {
|
||||
crf_averaged_perceptron_trainer_t *crf_trainer = crf_averaged_perceptron_trainer_new(self->num_classes, min_updates);
|
||||
trainer = (void *)crf_trainer;
|
||||
}
|
||||
|
||||
for (uint32_t iter = 0; iter < num_iterations; iter++) {
|
||||
log_info("Doing epoch %d\n", iter);
|
||||
|
||||
trainer->iterations = iter;
|
||||
address_parser_train_set_iterations(self, trainer, iter);
|
||||
|
||||
#if defined(HAVE_SHUF) || defined(HAVE_GSHUF)
|
||||
log_info("Shuffling\n");
|
||||
|
||||
if (!shuffle_file_chunked_size(filename, DEFAULT_SHUFFLE_CHUNK_SIZE)) {
|
||||
log_error("Error in shuffle\n");
|
||||
averaged_perceptron_trainer_destroy(trainer);
|
||||
address_parser_trainer_destroy(self, trainer);
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -1059,14 +1139,17 @@ bool address_parser_train(address_parser_t *self, char *filename, uint32_t num_i
|
||||
|
||||
if (!address_parser_train_epoch(self, trainer, filename)) {
|
||||
log_error("Error in epoch\n");
|
||||
averaged_perceptron_trainer_destroy(trainer);
|
||||
address_parser_trainer_destroy(self, trainer);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
log_debug("Done with training, averaging weights\n");
|
||||
|
||||
self->model = averaged_perceptron_trainer_finalize(trainer);
|
||||
if (!address_parser_finalize_model(self, trainer)) {
|
||||
log_error("model was NULL\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
@@ -1074,10 +1157,11 @@ bool address_parser_train(address_parser_t *self, char *filename, uint32_t num_i
|
||||
typedef enum {
|
||||
ADDRESS_PARSER_TRAIN_POSITIONAL_ARG,
|
||||
ADDRESS_PARSER_TRAIN_ARG_ITERATIONS,
|
||||
ADDRESS_PARSER_TRAIN_ARG_MIN_UPDATES
|
||||
ADDRESS_PARSER_TRAIN_ARG_MIN_UPDATES,
|
||||
ADDRESS_PARSER_TRAIN_ARG_MODEL_TYPE
|
||||
} address_parser_train_keyword_arg_t;
|
||||
|
||||
#define USAGE "Usage: ./address_parser_train filename output_dir [--iterations number]\n"
|
||||
#define USAGE "Usage: ./address_parser_train filename output_dir [--iterations number --min-updates number --model (crf|greedyap)]\n"
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
if (argc < 3) {
|
||||
@@ -1103,6 +1187,8 @@ int main(int argc, char **argv) {
|
||||
char *filename = NULL;
|
||||
char *output_dir = NULL;
|
||||
|
||||
address_parser_model_type_t model_type = DEFAULT_MODEL_TYPE;
|
||||
|
||||
for (int i = pos_args; i < argc; i++) {
|
||||
char *arg = argv[i];
|
||||
|
||||
@@ -1116,6 +1202,11 @@ int main(int argc, char **argv) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (string_equals(arg, "--model")) {
|
||||
kwarg = ADDRESS_PARSER_TRAIN_ARG_MODEL_TYPE;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (kwarg == ADDRESS_PARSER_TRAIN_ARG_ITERATIONS) {
|
||||
if (sscanf(arg, "%zd", &arg_iterations) != 1 || arg_iterations < 0) {
|
||||
log_error("Bad arg for --iterations: %s\n", arg);
|
||||
@@ -1129,6 +1220,15 @@ int main(int argc, char **argv) {
|
||||
}
|
||||
min_updates = arg_min_updates;
|
||||
log_info("min_updates = %llu\n", min_updates);
|
||||
} else if (kwarg == ADDRESS_PARSER_TRAIN_ARG_MODEL_TYPE) {
|
||||
if (string_equals(arg, "crf")) {
|
||||
model_type = ADDRESS_PARSER_TYPE_CRF;
|
||||
} else if (string_equals(arg, "greedyap")) {
|
||||
model_type = ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON;
|
||||
} else {
|
||||
log_error("Bad arg for --model, valid values are [crf, greedyap]\n");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
} else if (position == 0) {
|
||||
filename = arg;
|
||||
position++;
|
||||
@@ -1169,7 +1269,7 @@ int main(int argc, char **argv) {
|
||||
|
||||
log_info("Finished initialization\n");
|
||||
|
||||
if (!address_parser_train(parser, filename, num_iterations, min_updates)) {
|
||||
if (!address_parser_train(parser, filename, model_type, num_iterations, min_updates)) {
|
||||
log_error("Error in training\n");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user