[parser] count the number of classes during address parser init, for models that need it a priori

This commit is contained in:
Al
2017-03-06 15:17:52 -05:00
parent 5f19e63cbe
commit 39fa8ff1a5
2 changed files with 39 additions and 9 deletions

View File

@@ -186,6 +186,7 @@ typedef struct parser_options {
// Can add other gazetteers as well
typedef struct address_parser {
parser_options_t options;
size_t num_classes;
averaged_perceptron_t *model;
trie_t *vocab;
trie_t *phrases;

View File

@@ -344,6 +344,12 @@ address_parser_t *address_parser_init(char *filename) {
return NULL;
}
khash_t(str_set) *unique_classes = kh_init(str_set);
if (unique_classes == NULL) {
log_error("Could not allocate unique_classes\n");
return NULL;
}
khash_t(phrase_stats) *phrase_stats = kh_init(phrase_stats);
if (phrase_stats == NULL) {
log_error("Could not allocate phrase_stats\n");
@@ -380,7 +386,7 @@ address_parser_t *address_parser_init(char *filename) {
uint32_t i, j;
phrase_stats_t stats;
khash_t(int_uint32) *class_counts;
khash_t(int_uint32) *place_class_counts;
uint32_t vocab_size = 0;
size_t examples = 0;
@@ -469,6 +475,19 @@ address_parser_t *address_parser_init(char *filename) {
}
})
cstring_array_foreach(phrase_labels, i, label, {
k = kh_get(str_set, unique_classes, label);
if (k == kh_end(unique_classes)) {
char *label_key = strdup(label);
k = kh_put(str_set, unique_classes, label_key, &ret);
if (ret < 0) {
log_error("Error in kh_put in unique_classes\n");
free(label_key);
goto exit_hashes_allocated;
}
}
})
cstring_array_foreach(phrases, i, phrase, {
if (phrase == NULL) continue;
@@ -617,19 +636,19 @@ address_parser_t *address_parser_init(char *filename) {
free(key);
goto exit_hashes_allocated;
}
class_counts = kh_init(int_uint32);
place_class_counts = kh_init(int_uint32);
stats.class_counts = class_counts;
stats.class_counts = place_class_counts;
stats.components = component;
kh_value(phrase_stats, k) = stats;
} else {
stats = kh_value(phrase_stats, k);
class_counts = stats.class_counts;
place_class_counts = stats.class_counts;
stats.components |= component;
}
if (!int_uint32_hash_incr(class_counts, (khint_t)class_id)) {
if (!int_uint32_hash_incr(place_class_counts, (khint_t)class_id)) {
log_error("Error in int_uint32_hash_incr in class_counts\n");
goto exit_hashes_allocated;
}
@@ -679,6 +698,11 @@ address_parser_t *address_parser_init(char *filename) {
parser->model = NULL;
size_t num_classes = kh_size(unique_classes);
log_info("num_classes = %zu\n", num_classes);
parser->num_classes = num_classes;
log_info("Creating vocab trie\n");
parser->vocab = trie_new_from_hash(vocab);
@@ -754,14 +778,14 @@ address_parser_t *address_parser_init(char *filename) {
stats = kh_value(phrase_stats, k);
class_counts = stats.class_counts;
place_class_counts = stats.class_counts;
int32_t most_common = -1;
uint32_t max_count = 0;
uint32_t total = 0;
for (uint32_t i = 0; i < NUM_ADDRESS_PARSER_BOUNDARY_TYPES; i++) {
k = kh_get(int_uint32, class_counts, (khint_t)i);
if (k != kh_end(class_counts)) {
count = kh_value(class_counts, k);
k = kh_get(int_uint32, place_class_counts, (khint_t)i);
if (k != kh_end(place_class_counts)) {
count = kh_value(place_class_counts, k);
if (count > max_count) {
max_count = count;
@@ -914,6 +938,11 @@ exit_hashes_allocated:
})
kh_destroy(str_uint32, vocab);
kh_foreach_key(unique_classes, token, {
free((char *)token);
})
kh_destroy(str_set, unique_classes);
kh_foreach(phrase_stats, token, stats, {
kh_destroy(int_uint32, stats.class_counts);
free((char *)token);