From 39fa8ff1a5de0b655d24775ab0f158c7697ad612 Mon Sep 17 00:00:00 2001 From: Al Date: Mon, 6 Mar 2017 15:17:52 -0500 Subject: [PATCH] [parser] counting num classes in address parser init for models where it is needed a priori --- src/address_parser.h | 1 + src/address_parser_train.c | 47 ++++++++++++++++++++++++++++++-------- 2 files changed, 39 insertions(+), 9 deletions(-) diff --git a/src/address_parser.h b/src/address_parser.h index b3df7837..dc19a249 100644 --- a/src/address_parser.h +++ b/src/address_parser.h @@ -186,6 +186,7 @@ typedef struct parser_options { // Can add other gazetteers as well typedef struct address_parser { parser_options_t options; + size_t num_classes; averaged_perceptron_t *model; trie_t *vocab; trie_t *phrases; diff --git a/src/address_parser_train.c b/src/address_parser_train.c index c567b878..ce80fe3a 100644 --- a/src/address_parser_train.c +++ b/src/address_parser_train.c @@ -344,6 +344,12 @@ address_parser_t *address_parser_init(char *filename) { return NULL; } + khash_t(str_set) *unique_classes = kh_init(str_set); + if (unique_classes == NULL) { + log_error("Could not allocate unique_classes\n"); + return NULL; + } + khash_t(phrase_stats) *phrase_stats = kh_init(phrase_stats); if (phrase_stats == NULL) { log_error("Could not allocate phrase_stats\n"); @@ -380,7 +386,7 @@ address_parser_t *address_parser_init(char *filename) { uint32_t i, j; phrase_stats_t stats; - khash_t(int_uint32) *class_counts; + khash_t(int_uint32) *place_class_counts; uint32_t vocab_size = 0; size_t examples = 0; @@ -469,6 +475,19 @@ address_parser_t *address_parser_init(char *filename) { } }) + cstring_array_foreach(phrase_labels, i, label, { + k = kh_get(str_set, unique_classes, label); + if (k == kh_end(unique_classes)) { + char *label_key = strdup(label); + k = kh_put(str_set, unique_classes, label_key, &ret); + if (ret < 0) { + log_error("Error in kh_put in unique_classes\n"); + free(label_key); + goto exit_hashes_allocated; + } + } + }) + cstring_array_foreach(phrases, i, phrase, { if (phrase == NULL) continue; @@ -617,19 +636,19 @@ address_parser_t *address_parser_init(char *filename) { free(key); goto exit_hashes_allocated; } - class_counts = kh_init(int_uint32); + place_class_counts = kh_init(int_uint32); - stats.class_counts = class_counts; + stats.class_counts = place_class_counts; stats.components = component; kh_value(phrase_stats, k) = stats; } else { stats = kh_value(phrase_stats, k); - class_counts = stats.class_counts; + place_class_counts = stats.class_counts; stats.components |= component; } - if (!int_uint32_hash_incr(class_counts, (khint_t)class_id)) { + if (!int_uint32_hash_incr(place_class_counts, (khint_t)class_id)) { log_error("Error in int_uint32_hash_incr in class_counts\n"); goto exit_hashes_allocated; } @@ -679,6 +698,11 @@ address_parser_t *address_parser_init(char *filename) { parser->model = NULL; + + size_t num_classes = kh_size(unique_classes); + log_info("num_classes = %zu\n", num_classes); + parser->num_classes = num_classes; + log_info("Creating vocab trie\n"); parser->vocab = trie_new_from_hash(vocab); @@ -754,14 +778,14 @@ address_parser_t *address_parser_init(char *filename) { stats = kh_value(phrase_stats, k); - class_counts = stats.class_counts; + place_class_counts = stats.class_counts; int32_t most_common = -1; uint32_t max_count = 0; uint32_t total = 0; for (uint32_t i = 0; i < NUM_ADDRESS_PARSER_BOUNDARY_TYPES; i++) { - k = kh_get(int_uint32, class_counts, (khint_t)i); - if (k != kh_end(class_counts)) { - count = kh_value(class_counts, k); + k = kh_get(int_uint32, place_class_counts, (khint_t)i); + if (k != kh_end(place_class_counts)) { + count = kh_value(place_class_counts, k); if (count > max_count) { max_count = count; @@ -914,6 +938,11 @@ exit_hashes_allocated: }) kh_destroy(str_uint32, vocab); + kh_foreach_key(unique_classes, token, { + free((char *)token); + }) + kh_destroy(str_set, unique_classes); + kh_foreach(phrase_stats, token, stats, { kh_destroy(int_uint32, stats.class_counts); free((char *)token);