diff --git a/src/address_parser.c b/src/address_parser.c index a15e998a..79c8129f 100644 --- a/src/address_parser.c +++ b/src/address_parser.c @@ -8,8 +8,10 @@ #define ADDRESS_PARSER_MODEL_FILENAME "address_parser.dat" #define ADDRESS_PARSER_VOCAB_FILENAME "address_parser_vocab.trie" +#define ADDRESS_PARSER_PHRASE_FILENAME "address_parser_phrases.trie" #define UNKNOWN_WORD "UNKNOWN" +#define UNKNOWN_NUMERIC "UNKNOWN_NUMERIC" static address_parser_t *parser = NULL; @@ -47,6 +49,15 @@ bool address_parser_save(address_parser_t *self, char *output_dir) { return false; } + char_array_clear(path); + + char_array_add_joined(path, PATH_SEPARATOR, true, 2, output_dir, ADDRESS_PARSER_PHRASE_FILENAME); + char *phrases_path = char_array_get_string(path); + + if (!trie_save(self->phrase_types, phrases_path)) { + return false; + } + char_array_destroy(path); return true; @@ -90,6 +101,22 @@ bool address_parser_load(char *dir) { parser->vocab = vocab; + char_array_clear(path); + + char_array_add_joined(path, PATH_SEPARATOR, true, 2, dir, ADDRESS_PARSER_PHRASE_FILENAME); + + char *phrases_path = char_array_get_string(path); + + trie_t *phrase_types = trie_load(phrases_path); + + if (phrase_types == NULL) { + address_parser_destroy(parser); + char_array_destroy(path); + return false; + } + + parser->phrase_types = phrase_types; + char_array_destroy(path); return true; } @@ -105,6 +132,10 @@ void address_parser_destroy(address_parser_t *self) { trie_destroy(self->vocab); } + if (self->phrase_types != NULL) { + trie_destroy(self->phrase_types); + } + free(self); } @@ -162,6 +193,14 @@ void address_parser_context_destroy(address_parser_context_t *self) { int64_array_destroy(self->geodb_phrase_memberships); } + if (self->component_phrases != NULL) { + phrase_array_destroy(self->component_phrases); + } + + if (self->component_phrase_memberships != NULL) { + int64_array_destroy(self->component_phrase_memberships); + } + free(self); } @@ -218,6 +257,16 @@ address_parser_context_t *address_parser_context_new(void) { goto exit_address_parser_context_allocated; } + context->component_phrases = phrase_array_new(); + if (context->component_phrases == NULL) { + goto exit_address_parser_context_allocated; + } + + context->component_phrase_memberships = int64_array_new(); + if (context->component_phrase_memberships == NULL) { + goto exit_address_parser_context_allocated; + } + return context; exit_address_parser_context_allocated: @@ -225,7 +274,7 @@ exit_address_parser_context_allocated: return NULL; } -void address_parser_context_fill(address_parser_context_t *context, tokenized_string_t *tokenized_str, char *language, char *country) { +void address_parser_context_fill(address_parser_context_t *context, address_parser_t *parser, tokenized_string_t *tokenized_str, char *language, char *country) { int64_t i, j; uint32_t token_index; @@ -301,6 +350,34 @@ void address_parser_context_fill(address_parser_context_t *context, tokenized_st int64_array_push(geodb_phrase_memberships, NULL_PHRASE_MEMBERSHIP); } + phrase_array_clear(context->component_phrases); + int64_array_clear(context->component_phrase_memberships); + i = 0; + + phrase_array *component_phrases = context->component_phrases; + int64_array *component_phrase_memberships = context->component_phrase_memberships; + + if (trie_search_tokens_with_phrases(parser->phrase_types, str, tokens, &component_phrases)) { + for (j = 0; j < component_phrases->n; j++) { + phrase = component_phrases->a[j]; + + for (; i < phrase.start; i++) { + log_debug("token i=%lld, null component phrase membership\n", i); + int64_array_push(component_phrase_memberships, NULL_PHRASE_MEMBERSHIP); + } + + for (i = phrase.start; i < phrase.start + phrase.len; i++) { + log_debug("token i=%lld, component phrase membership=%lld\n", i, j); + int64_array_push(component_phrase_memberships, j); + } + } + } + for (; i < tokens->n; i++) { + log_debug("token i=%lld, null component phrase membership\n", i); + int64_array_push(component_phrase_memberships, NULL_PHRASE_MEMBERSHIP); + } + + } @@ -384,7 +461,6 @@ static inline void add_phrase_features(cstring_array *features, uint32_t phrase_ } } - /* address_parser_features ----------------------- @@ -426,15 +502,14 @@ bool address_parser_features(void *self, void *ctx, tokenized_string_t *tokenize int64_array *address_phrase_memberships = context->address_phrase_memberships; phrase_array *geodb_phrases = context->geodb_phrases; int64_array *geodb_phrase_memberships = context->geodb_phrase_memberships; + phrase_array *component_phrases = context->component_phrases; + int64_array *component_phrase_memberships = context->component_phrase_memberships; cstring_array *normalized = context->normalized; uint32_array *separators = context->separators; cstring_array_clear(features); - // Bias unit, acts as an intercept - feature_array_add(features, 1, "bias"); - char *original_word = tokenized_string_get_token(tokenized, i); token_t token = tokenized->tokens->a[i]; @@ -449,7 +524,6 @@ bool address_parser_features(void *self, void *ctx, tokenized_string_t *tokenize } size_t word_len = strlen(word); - char *current_word = word; log_debug("word=%s\n", word); @@ -459,6 +533,7 @@ bool address_parser_features(void *self, void *ctx, tokenized_string_t *tokenize char *phrase_string = NULL; char *geo_phrase_string = NULL; + char *component_phrase_string = NULL; int64_t address_phrase_index = address_phrase_memberships->a[i]; @@ -519,40 +594,96 @@ bool address_parser_features(void *self, void *ctx, tokenized_string_t *tokenize } } + bool add_word_feature = true; + + int64_t component_phrase_index = component_phrase_memberships->a[i]; + phrase = NULL_PHRASE; + + address_parser_types_t types; + + // Component phrases + if (component_phrase_index != NULL_PHRASE_MEMBERSHIP) { + phrase = component_phrases->a[component_phrase_index]; + + component_phrase_string = get_phrase_string(tokenized, phrase_tokens, phrase); + + types.value = phrase.data; + uint32_t component_phrase_types = types.components; + uint32_t most_common = types.most_common; + + if (last_index >= (ssize_t)phrase.start - 1 || next_index <= (ssize_t)phrase.start + phrase.len - 1) { + last_index = (ssize_t)phrase.start - 1; + next_index = (ssize_t)phrase.start + phrase.len; + + } + + if (component_phrase_string != NULL && component_phrase_types ^ ADDRESS_COMPONENT_POSTAL_CODE) { + feature_array_add(features, 2, "phrase", component_phrase_string); + add_word_feature = false; + } + + if (component_phrase_types > 0) { + add_phrase_features(features, component_phrase_types, ADDRESS_COMPONENT_SUBURB, "suburb", component_phrase_string, prev2, prev); + add_phrase_features(features, component_phrase_types, ADDRESS_COMPONENT_CITY, "city", component_phrase_string, prev2, prev); + add_phrase_features(features, component_phrase_types, ADDRESS_COMPONENT_CITY_DISTRICT, "city_district", component_phrase_string, prev2, prev); + add_phrase_features(features, component_phrase_types, ADDRESS_COMPONENT_STATE_DISTRICT, "state_district", component_phrase_string, prev2, prev); + add_phrase_features(features, component_phrase_types, ADDRESS_COMPONENT_STATE, "state", component_phrase_string, prev2, prev); + add_phrase_features(features, component_phrase_types, ADDRESS_COMPONENT_POSTAL_CODE, "postal_code", component_phrase_string, prev2, prev); + add_phrase_features(features, component_phrase_types, ADDRESS_COMPONENT_COUNTRY, "country", component_phrase_string, prev2, prev); + } + + if (most_common == ADDRESS_PARSER_CITY) { + feature_array_add(features, 2, "commonly city", component_phrase_string); + } else if (most_common == ADDRESS_PARSER_STATE) { + feature_array_add(features, 2, "commonly state", component_phrase_string); + } else if (most_common == ADDRESS_PARSER_COUNTRY) { + feature_array_add(features, 2, "commonly country", component_phrase_string); + } else if (most_common == ADDRESS_PARSER_STATE_DISTRICT) { + feature_array_add(features, 2, "commonly state_district", component_phrase_string); + } else if (most_common == ADDRESS_PARSER_SUBURB) { + feature_array_add(features, 2, "commonly suburb", component_phrase_string); + } else if (most_common == ADDRESS_PARSER_CITY_DISTRICT) { + feature_array_add(features, 2, "commonly city_district", component_phrase_string); + } else if (most_common == ADDRESS_PARSER_POSTAL_CODE) { + feature_array_add(features, 2, "commonly postal_code", component_phrase_string); + } + + } + int64_t geodb_phrase_index = geodb_phrase_memberships->a[i]; phrase = NULL_PHRASE; geodb_value_t geo; // GeoDB phrases - if (geodb_phrase_index != NULL_PHRASE_MEMBERSHIP) { + if (component_phrase_index == NULL_PHRASE_MEMBERSHIP && geodb_phrase_index != NULL_PHRASE_MEMBERSHIP) { phrase = geodb_phrases->a[geodb_phrase_index]; geo_phrase_string = get_phrase_string(tokenized, phrase_tokens, phrase); geo.value = phrase.data; uint32_t geodb_phrase_types = geo.components; - if (last_index <= (ssize_t)phrase.start - 1 && next_index >= (ssize_t)phrase.start + phrase.len - 1) { + if (last_index >= (ssize_t)phrase.start - 1 || next_index <= (ssize_t)phrase.start + phrase.len) { last_index = (ssize_t)phrase.start - 1; next_index = (ssize_t)phrase.start + phrase.len; - if (geo_phrase_string != NULL && geodb_phrase_types ^ ADDRESS_POSTAL_CODE) { - word = geo_phrase_string; - } + } + if (geo_phrase_string != NULL && geodb_phrase_types ^ ADDRESS_POSTAL_CODE) { + feature_array_add(features, 2, "phrase", geo_phrase_string); + add_word_feature = false; } if (geodb_phrase_types ^ ADDRESS_ANY) { + add_phrase_features(features, geodb_phrase_types, ADDRESS_LOCALITY, "gn city", geo_phrase_string, prev2, prev); + add_phrase_features(features, geodb_phrase_types, ADDRESS_ADMIN1, "gn admin1", geo_phrase_string, prev2, prev); + add_phrase_features(features, geodb_phrase_types, ADDRESS_ADMIN2, "gn admin2", geo_phrase_string, prev2, prev); + add_phrase_features(features, geodb_phrase_types, ADDRESS_ADMIN3, "gn admin3", geo_phrase_string, prev2, prev); + add_phrase_features(features, geodb_phrase_types, ADDRESS_ADMIN4, "gn admin4", geo_phrase_string, prev2, prev); + add_phrase_features(features, geodb_phrase_types, ADDRESS_ADMIN_OTHER, "gn admin other", geo_phrase_string, prev2, prev); + add_phrase_features(features, geodb_phrase_types, ADDRESS_NEIGHBORHOOD, "gn neighborhood", geo_phrase_string, prev2, prev); - add_phrase_features(features, geodb_phrase_types, ADDRESS_LOCALITY, "city", geo_phrase_string, prev2, prev); - add_phrase_features(features, geodb_phrase_types, ADDRESS_ADMIN1, "admin1", geo_phrase_string, prev2, prev); - add_phrase_features(features, geodb_phrase_types, ADDRESS_ADMIN2, "admin2", geo_phrase_string, prev2, prev); - add_phrase_features(features, geodb_phrase_types, ADDRESS_ADMIN3, "admin3", geo_phrase_string, prev2, prev); - add_phrase_features(features, geodb_phrase_types, ADDRESS_ADMIN4, "admin4", geo_phrase_string, prev2, prev); - add_phrase_features(features, geodb_phrase_types, ADDRESS_ADMIN_OTHER, "admin other", geo_phrase_string, prev2, prev); - add_phrase_features(features, geodb_phrase_types, ADDRESS_NEIGHBORHOOD, "neighborhood", geo_phrase_string, prev2, prev); - - add_phrase_features(features, geodb_phrase_types, ADDRESS_COUNTRY, "country", geo_phrase_string, prev2, prev); - add_phrase_features(features, geodb_phrase_types, ADDRESS_POSTAL_CODE, "postal code", geo_phrase_string, prev2, prev); + add_phrase_features(features, geodb_phrase_types, ADDRESS_COUNTRY, "gn country", geo_phrase_string, prev2, prev); + add_phrase_features(features, geodb_phrase_types, ADDRESS_POSTAL_CODE, "gn postal code", geo_phrase_string, prev2, prev); } @@ -560,24 +691,27 @@ bool address_parser_features(void *self, void *ctx, tokenized_string_t *tokenize uint32_t word_freq = word_vocab_frequency(parser, word); - if (phrase_string == NULL && geo_phrase_string == NULL) { + if (add_word_feature) { + // Bias unit, acts as an intercept + feature_array_add(features, 1, "bias"); + if (word_freq > 0) { // The individual word feature_array_add(features, 2, "word", word); } else { log_debug("word not in vocab: %s\n", original_word); - word = UNKNOWN_WORD; + word = (token.type != NUMERIC && token.type != IDEOGRAPHIC_NUMBER) ? UNKNOWN_WORD : UNKNOWN_NUMERIC; } } - - if (prev != NULL) { + + if (prev != NULL && last_index == i - 1) { // Previous tag and current word - feature_array_add(features, 3, "i-1 tag+word", prev, current_word); + feature_array_add(features, 3, "i-1 tag+word", prev, word); feature_array_add(features, 2, "i-1 tag", prev); if (prev2 != NULL) { // Previous two tags and current word - feature_array_add(features, 4, "i-2 tag+i-1 tag+word", prev2, prev, current_word); + feature_array_add(features, 4, "i-2 tag+i-1 tag+word", prev2, prev, word); feature_array_add(features, 3, "i-2 tag+i-1 tag", prev2, prev); } } @@ -587,15 +721,14 @@ bool address_parser_features(void *self, void *ctx, tokenized_string_t *tokenize uint32_t prev_word_freq = word_vocab_frequency(parser, prev_word); if (prev_word_freq == 0) { - prev_word = UNKNOWN_WORD; + token_t prev_token = tokenized->tokens->a[last_index]; + prev_word = (prev_token.type != NUMERIC && prev_token.type != IDEOGRAPHIC_NUMBER) ? UNKNOWN_WORD : UNKNOWN_NUMERIC; } // Previous word feature_array_add(features, 2, "i-1 word", prev_word); - // Previous tag + previous word - if (last_index == i - 1) { - feature_array_add(features, 3, "i-1 tag+i-1 word", prev, prev_word); - } + feature_array_add(features, 3, "i-1 tag+i-1 word", prev, prev_word); + // Previous word and current word feature_array_add(features, 3, "i-1 word+word", prev_word, word); } @@ -607,15 +740,34 @@ bool address_parser_features(void *self, void *ctx, tokenized_string_t *tokenize uint32_t next_word_freq = word_vocab_frequency(parser, next_word); if (next_word_freq == 0) { - next_word = UNKNOWN_WORD; + token_t next_token = tokenized->tokens->a[next_index]; + next_word = (next_token.type != NUMERIC && next_token.type != IDEOGRAPHIC_NUMBER) ? UNKNOWN_WORD : UNKNOWN_NUMERIC; } // Next word e.g. if the current word is unknown and the next word is "street" feature_array_add(features, 2, "i+1 word", next_word); + // Current word and next word feature_array_add(features, 3, "word+i+1 word", word, next_word); } + #ifndef PRINT_FEATURES + if (0) { + #endif + + uint32_t idx; + char *feature; + + printf("{"); + cstring_array_foreach(features, idx, feature, { + printf(" %s, ", feature); + }) + printf("}\n"); + + #ifndef PRINT_FEATURES + } + #endif + return true; } @@ -681,7 +833,7 @@ address_parser_response_t *address_parser_parse(char *address, char *language, c uint32_array_push(context->separators, ADDRESS_SEPARATOR_NONE); } - address_parser_context_fill(context, tokenized_str, language, country); + address_parser_context_fill(context, parser, tokenized_str, language, country); cstring_array *token_labels = cstring_array_new_size(tokens->n); diff --git a/src/address_parser_test.c b/src/address_parser_test.c index 2ec1021a..7a3039c5 100644 --- a/src/address_parser_test.c +++ b/src/address_parser_test.c @@ -10,7 +10,6 @@ #include "log/log.h" - typedef struct address_parser_test_results { size_t num_errors; size_t num_predictions; @@ -67,7 +66,7 @@ bool address_parser_test(address_parser_t *parser, char *filename, address_parse } char *country = char_array_get_string(data_set->country); - address_parser_context_fill(context, data_set->tokenized_str, language, country); + address_parser_context_fill(context, parser, data_set->tokenized_str, language, country); cstring_array *token_labels = cstring_array_new_size(data_set->tokenized_str->strings->str->n); @@ -90,6 +89,7 @@ bool address_parser_test(address_parser_t *parser, char *filename, address_parse uint32_t truth_index = get_class_index(parser, truth); result->confusion[predicted_index * num_classes + truth_index]++; + } result->num_predictions++; diff --git a/src/address_parser_train.c b/src/address_parser_train.c index f96662aa..39a1a436 100644 --- a/src/address_parser_train.c +++ b/src/address_parser_train.c @@ -15,6 +15,14 @@ #define DEFAULT_ITERATIONS 5 #define MIN_VOCAB_COUNT 5 +#define MIN_PHRASE_COUNT 1 + +typedef struct phrase_stats { + khash_t(int_uint32) *class_counts; + address_parser_types_t parser_types; +} phrase_stats_t; + +KHASH_MAP_INIT_STR(phrase_stats, phrase_stats_t) address_parser_t *address_parser_init(char *filename) { if (filename == NULL) { @@ -37,21 +45,43 @@ address_parser_t *address_parser_init(char *filename) { } khash_t(str_uint32) *vocab = kh_init(str_uint32); + if (vocab == NULL) { + log_error("Could not allocate vocab\n"); + return NULL; + } + + khash_t(phrase_stats) *phrase_stats = kh_init(phrase_stats); + if (phrase_stats == NULL) { + log_error("Could not allocate phrase_stats\n"); + return NULL; + } + + khash_t(str_uint32) *phrase_types = kh_init(str_uint32); + if (phrase_types == NULL) { + log_error("Could not allocate phrase_types\n"); + return NULL; + } khiter_t k; char *str; + phrase_stats_t stats; + khash_t(int_uint32) *class_counts; + uint32_t vocab_size = 0; size_t examples = 0; - const char *word; - uint32_t i; - char *token; + const char *token; char *normalized; uint32_t count; - char_array *token_array = char_array_new(); + char *key; + int ret = 0; + + char_array *token_builder = char_array_new(); + + char_array *phrase_builder = char_array_new(); while (address_parser_data_set_next(data_set)) { tokenized_string_t *tokenized_str = data_set->tokenized_str; @@ -59,38 +89,35 @@ address_parser_t *address_parser_init(char *filename) { if (tokenized_str == NULL) { log_error("tokenized str is NULL\n"); kh_destroy(str_uint32, vocab); - return false; + return NULL; } str = tokenized_str->str; + char *prev_label = NULL; + + size_t num_strings = cstring_array_num_strings(tokenized_str->strings); + cstring_array_foreach(tokenized_str->strings, i, token, { token_t t = tokenized_str->tokens->a[i]; - char_array_clear(token_array); - add_normalized_token(token_array, str, t, ADDRESS_PARSER_NORMALIZE_TOKEN_OPTIONS); - if (token_array->n == 0) { + char_array_clear(token_builder); + add_normalized_token(token_builder, str, t, ADDRESS_PARSER_NORMALIZE_TOKEN_OPTIONS); + if (token_builder->n == 0) { continue; } - normalized = char_array_get_string(token_array); - + normalized = char_array_get_string(token_builder); k = kh_get(str_uint32, vocab, normalized); + if (k == kh_end(vocab)) { - int ret; - char *key = strdup(normalized); + key = strdup(normalized); k = kh_put(str_uint32, vocab, key, &ret); if (ret < 0) { - log_error("Error in kh_put\n"); + log_error("Error in kh_put in vocab\n"); free(key); - tokenized_string_destroy(tokenized_str); - kh_foreach(vocab, word, count, { - free((char *)word); - }) - kh_destroy(str_uint32, vocab); - char_array_destroy(token_array); - return false; + goto exit_hashes_allocated; } kh_value(vocab, k) = 1; vocab_size++; @@ -98,6 +125,111 @@ address_parser_t *address_parser_init(char *filename) { kh_value(vocab, k)++; } + char *label = cstring_array_get_string(data_set->labels, i); + if (label == NULL) { + continue; + } + + if (string_equals(label, "road") || string_equals(label, "house_number") || string_equals(label, "house")) { + prev_label = NULL; + continue; + } + + if (prev_label == NULL || !string_equals(label, prev_label) || i == num_strings - 1) { + + if (i == num_strings - 1) { + if (!string_equals(label, prev_label)) { + char_array_clear(phrase_builder); + } else if (prev_label != NULL) { + char_array_cat(phrase_builder, " "); + } + + char_array_cat(phrase_builder, normalized); + prev_label = label; + } + + // End of phrase, add to hashtable + if (prev_label != NULL) { + uint32_t class_id; + uint32_t component = 0; + + // Too many variations on these + if (string_equals(prev_label, "city")) { + class_id = ADDRESS_PARSER_CITY; + component = ADDRESS_COMPONENT_CITY; + } else if (string_equals(prev_label, "state")) { + class_id = ADDRESS_PARSER_STATE; + component = ADDRESS_COMPONENT_STATE; + } else if (string_equals(prev_label, "country")) { + class_id = ADDRESS_PARSER_COUNTRY; + component = ADDRESS_COMPONENT_COUNTRY; + } else if (string_equals(prev_label, "state_district")) { + class_id = ADDRESS_PARSER_STATE_DISTRICT; + component = ADDRESS_COMPONENT_STATE_DISTRICT; + } else if (string_equals(prev_label, "suburb")) { + class_id = ADDRESS_PARSER_SUBURB; + component = ADDRESS_COMPONENT_SUBURB; + } else if (string_equals(prev_label, "city_district")) { + class_id = ADDRESS_PARSER_CITY_DISTRICT; + component = ADDRESS_COMPONENT_CITY_DISTRICT; + } else if (string_equals(prev_label, "postcode")) { + class_id = ADDRESS_PARSER_POSTAL_CODE; + component = ADDRESS_COMPONENT_POSTAL_CODE; + } + + char *phrase = char_array_get_string(phrase_builder); + + k = kh_get(phrase_stats, phrase_stats, phrase); + + if (k == kh_end(phrase_stats)) { + key = strdup(phrase); + ret = 0; + k = kh_put(phrase_stats, phrase_stats, key, &ret); + if (ret < 0) { + log_error("Error in kh_put in phrase_stats\n"); + free(key); + goto exit_hashes_allocated; + } + class_counts = kh_init(int_uint32); + + stats.class_counts = class_counts; + stats.parser_types.components = component; + stats.parser_types.most_common = 0; + + kh_value(phrase_stats, k) = stats; + } else { + stats = kh_value(phrase_stats, k); + class_counts = stats.class_counts; + stats.parser_types.components |= component; + + } + + k = kh_get(int_uint32, class_counts, (khint_t)class_id); + + if (k == kh_end(class_counts)) { + ret = 0; + k = kh_put(int_uint32, class_counts, class_id, &ret); + if (ret < 0) { + log_error("Error in kh_put in class_counts\n"); + goto exit_hashes_allocated; + } + kh_value(class_counts, k) = 1; + } else { + kh_value(class_counts, k)++; + } + + } + + char_array_clear(phrase_builder); + } else if (prev_label != NULL) { + char_array_cat(phrase_builder, " "); + } + + char_array_cat(phrase_builder, normalized); + + + prev_label = label; + }) tokenized_string_destroy(tokenized_str); @@ -110,38 +242,92 @@ address_parser_t *address_parser_init(char *filename) { log_debug("Done with vocab, total size=%d\n", vocab_size); for (k = kh_begin(vocab); k != kh_end(vocab); ++k) { - char *word = (char *)kh_key(vocab, k); + token = (char *)kh_key(vocab, k); if (!kh_exist(vocab, k)) { continue; } uint32_t count = kh_value(vocab, k); if (count < MIN_VOCAB_COUNT) { kh_del(str_uint32, vocab, k); - free(word); + free((char *)token); } } + kh_foreach(phrase_stats, token, stats, { + class_counts = stats.class_counts; + int most_common = -1; + uint32_t max_count = 0; + uint32_t total = 0; + for (int i = 0; i < NUM_ADDRESS_PARSER_TYPES; i++) { + k = kh_get(int_uint32, class_counts, (khint_t)i); + if (k != kh_end(class_counts)) { + count = kh_value(class_counts, k); + + if (count > max_count) { + max_count = count; + most_common = i; + } + total += count; + } + } + + if (most_common > -1 && total >= MIN_PHRASE_COUNT) { + stats.parser_types.most_common = (uint32_t)most_common; + ret = 0; + char *key = strdup(token); + k = kh_put(str_uint32, phrase_types, key, &ret); + if (ret < 0) { + log_error("Error on kh_put in phrase_types\n"); + free(key); + goto exit_hashes_allocated; + } + + kh_value(phrase_types, k) = stats.parser_types.value; + } + }) + parser->vocab = trie_new_from_hash(vocab); - - for (k = kh_begin(vocab); k != kh_end(vocab); ++k) { - if (!kh_exist(vocab, k)) { - continue; - } - char *word = (char *)kh_key(vocab, k); - free(word); - } - - kh_destroy(str_uint32, vocab); - - char_array_destroy(token_array); - address_parser_data_set_destroy(data_set); if (parser->vocab == NULL) { log_error("Error initializing vocabulary\n"); address_parser_destroy(parser); - return NULL; + parser = NULL; + goto exit_hashes_allocated; } + parser->phrase_types = trie_new_from_hash(phrase_types); + if (parser->phrase_types == NULL) { + log_error("Error converting phrase_types to trie\n"); + address_parser_destroy(parser); + parser = NULL; + goto exit_hashes_allocated; + } + +exit_hashes_allocated: + // Free memory for hashtables, etc. + + char_array_destroy(token_builder); + char_array_destroy(phrase_builder); + address_parser_data_set_destroy(data_set); + + kh_foreach(vocab, token, count, { + free((char *)token); + }) + kh_destroy(str_uint32, vocab); + + kh_foreach(phrase_stats, token, stats, { + kh_destroy(int_uint32, stats.class_counts); + free((char *)token); + }) + + kh_destroy(phrase_stats, phrase_stats); + + kh_foreach(phrase_types, token, count, { + free((char *)token); + }) + kh_destroy(str_uint32, phrase_types); + return parser; + } @@ -174,7 +360,7 @@ bool address_parser_train_epoch(address_parser_t *self, averaged_perceptron_trai } char *country = char_array_get_string(data_set->country); - address_parser_context_fill(context, data_set->tokenized_str, language, country); + address_parser_context_fill(context, self, data_set->tokenized_str, language, country); bool example_success = averaged_perceptron_trainer_train_example(trainer, self, context, context->features, &address_parser_features, data_set->tokenized_str, data_set->labels);