[parser] count the number of classes during address parser init, for models that need it a priori
This commit is contained in:
@@ -186,6 +186,7 @@ typedef struct parser_options {
|
|||||||
// Can add other gazetteers as well
|
// Can add other gazetteers as well
|
||||||
typedef struct address_parser {
|
typedef struct address_parser {
|
||||||
parser_options_t options;
|
parser_options_t options;
|
||||||
|
size_t num_classes;
|
||||||
averaged_perceptron_t *model;
|
averaged_perceptron_t *model;
|
||||||
trie_t *vocab;
|
trie_t *vocab;
|
||||||
trie_t *phrases;
|
trie_t *phrases;
|
||||||
|
|||||||
@@ -344,6 +344,12 @@ address_parser_t *address_parser_init(char *filename) {
|
|||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
khash_t(str_set) *unique_classes = kh_init(str_set);
|
||||||
|
if (unique_classes == NULL) {
|
||||||
|
log_error("Could not allocate unique_classes\n");
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
|
||||||
khash_t(phrase_stats) *phrase_stats = kh_init(phrase_stats);
|
khash_t(phrase_stats) *phrase_stats = kh_init(phrase_stats);
|
||||||
if (phrase_stats == NULL) {
|
if (phrase_stats == NULL) {
|
||||||
log_error("Could not allocate phrase_stats\n");
|
log_error("Could not allocate phrase_stats\n");
|
||||||
@@ -380,7 +386,7 @@ address_parser_t *address_parser_init(char *filename) {
|
|||||||
uint32_t i, j;
|
uint32_t i, j;
|
||||||
|
|
||||||
phrase_stats_t stats;
|
phrase_stats_t stats;
|
||||||
khash_t(int_uint32) *class_counts;
|
khash_t(int_uint32) *place_class_counts;
|
||||||
|
|
||||||
uint32_t vocab_size = 0;
|
uint32_t vocab_size = 0;
|
||||||
size_t examples = 0;
|
size_t examples = 0;
|
||||||
@@ -469,6 +475,19 @@ address_parser_t *address_parser_init(char *filename) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
cstring_array_foreach(phrase_labels, i, label, {
|
||||||
|
k = kh_get(str_set, unique_classes, label);
|
||||||
|
if (k == kh_end(unique_classes)) {
|
||||||
|
char *label_key = strdup(label);
|
||||||
|
k = kh_put(str_set, unique_classes, label_key, &ret);
|
||||||
|
if (ret < 0) {
|
||||||
|
log_error("Error in kh_put in unique_classes\n");
|
||||||
|
free(label_key);
|
||||||
|
goto exit_hashes_allocated;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
cstring_array_foreach(phrases, i, phrase, {
|
cstring_array_foreach(phrases, i, phrase, {
|
||||||
if (phrase == NULL) continue;
|
if (phrase == NULL) continue;
|
||||||
|
|
||||||
@@ -617,19 +636,19 @@ address_parser_t *address_parser_init(char *filename) {
|
|||||||
free(key);
|
free(key);
|
||||||
goto exit_hashes_allocated;
|
goto exit_hashes_allocated;
|
||||||
}
|
}
|
||||||
class_counts = kh_init(int_uint32);
|
place_class_counts = kh_init(int_uint32);
|
||||||
|
|
||||||
stats.class_counts = class_counts;
|
stats.class_counts = place_class_counts;
|
||||||
stats.components = component;
|
stats.components = component;
|
||||||
|
|
||||||
kh_value(phrase_stats, k) = stats;
|
kh_value(phrase_stats, k) = stats;
|
||||||
} else {
|
} else {
|
||||||
stats = kh_value(phrase_stats, k);
|
stats = kh_value(phrase_stats, k);
|
||||||
class_counts = stats.class_counts;
|
place_class_counts = stats.class_counts;
|
||||||
stats.components |= component;
|
stats.components |= component;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!int_uint32_hash_incr(class_counts, (khint_t)class_id)) {
|
if (!int_uint32_hash_incr(place_class_counts, (khint_t)class_id)) {
|
||||||
log_error("Error in int_uint32_hash_incr in class_counts\n");
|
log_error("Error in int_uint32_hash_incr in class_counts\n");
|
||||||
goto exit_hashes_allocated;
|
goto exit_hashes_allocated;
|
||||||
}
|
}
|
||||||
@@ -679,6 +698,11 @@ address_parser_t *address_parser_init(char *filename) {
|
|||||||
|
|
||||||
parser->model = NULL;
|
parser->model = NULL;
|
||||||
|
|
||||||
|
|
||||||
|
size_t num_classes = kh_size(unique_classes);
|
||||||
|
log_info("num_classes = %zu\n", num_classes);
|
||||||
|
parser->num_classes = num_classes;
|
||||||
|
|
||||||
log_info("Creating vocab trie\n");
|
log_info("Creating vocab trie\n");
|
||||||
|
|
||||||
parser->vocab = trie_new_from_hash(vocab);
|
parser->vocab = trie_new_from_hash(vocab);
|
||||||
@@ -754,14 +778,14 @@ address_parser_t *address_parser_init(char *filename) {
|
|||||||
|
|
||||||
stats = kh_value(phrase_stats, k);
|
stats = kh_value(phrase_stats, k);
|
||||||
|
|
||||||
class_counts = stats.class_counts;
|
place_class_counts = stats.class_counts;
|
||||||
int32_t most_common = -1;
|
int32_t most_common = -1;
|
||||||
uint32_t max_count = 0;
|
uint32_t max_count = 0;
|
||||||
uint32_t total = 0;
|
uint32_t total = 0;
|
||||||
for (uint32_t i = 0; i < NUM_ADDRESS_PARSER_BOUNDARY_TYPES; i++) {
|
for (uint32_t i = 0; i < NUM_ADDRESS_PARSER_BOUNDARY_TYPES; i++) {
|
||||||
k = kh_get(int_uint32, class_counts, (khint_t)i);
|
k = kh_get(int_uint32, place_class_counts, (khint_t)i);
|
||||||
if (k != kh_end(class_counts)) {
|
if (k != kh_end(place_class_counts)) {
|
||||||
count = kh_value(class_counts, k);
|
count = kh_value(place_class_counts, k);
|
||||||
|
|
||||||
if (count > max_count) {
|
if (count > max_count) {
|
||||||
max_count = count;
|
max_count = count;
|
||||||
@@ -914,6 +938,11 @@ exit_hashes_allocated:
|
|||||||
})
|
})
|
||||||
kh_destroy(str_uint32, vocab);
|
kh_destroy(str_uint32, vocab);
|
||||||
|
|
||||||
|
kh_foreach_key(unique_classes, token, {
|
||||||
|
free((char *)token);
|
||||||
|
})
|
||||||
|
kh_destroy(str_set, unique_classes);
|
||||||
|
|
||||||
kh_foreach(phrase_stats, token, stats, {
|
kh_foreach(phrase_stats, token, stats, {
|
||||||
kh_destroy(int_uint32, stats.class_counts);
|
kh_destroy(int_uint32, stats.class_counts);
|
||||||
free((char *)token);
|
free((char *)token);
|
||||||
|
|||||||
Reference in New Issue
Block a user