diff --git a/src/address_parser_train.c b/src/address_parser_train.c index 622e1a7d..bfc08cc6 100644 --- a/src/address_parser_train.c +++ b/src/address_parser_train.c @@ -11,7 +11,7 @@ #include "log/log.h" - + typedef struct phrase_stats { khash_t(int_uint32) *class_counts; address_parser_types_t parser_types; @@ -26,16 +26,20 @@ KHASH_MAP_INIT_STR(phrase_stats, phrase_stats_t) #define MIN_VOCAB_COUNT 5 #define MIN_PHRASE_COUNT 1 -static inline bool is_phrase_component(char *label) { +static inline bool is_postal_code(char *label) { + return string_equals(label, ADDRESS_PARSER_LABEL_POSTAL_CODE); +} + +static inline bool is_admin_component(char *label) { return (string_equals(label, ADDRESS_PARSER_LABEL_SUBURB) || string_equals(label, ADDRESS_PARSER_LABEL_CITY_DISTRICT) || string_equals(label, ADDRESS_PARSER_LABEL_CITY) || string_equals(label, ADDRESS_PARSER_LABEL_STATE_DISTRICT) || string_equals(label, ADDRESS_PARSER_LABEL_ISLAND) || string_equals(label, ADDRESS_PARSER_LABEL_STATE) || - string_equals(label, ADDRESS_PARSER_LABEL_POSTAL_CODE) || string_equals(label, ADDRESS_PARSER_LABEL_COUNTRY_REGION) || - string_equals(label, ADDRESS_PARSER_LABEL_COUNTRY)); + string_equals(label, ADDRESS_PARSER_LABEL_COUNTRY) || + string_equals(label, ADDRESS_PARSER_LABEL_WORLD_REGION)); } address_parser_t *address_parser_init(char *filename) { @@ -51,7 +55,6 @@ address_parser_t *address_parser_init(char *filename) { return NULL; } - address_parser_t *parser = address_parser_new(); if (parser == NULL) { log_error("Error allocating parser\n"); @@ -64,6 +67,12 @@ address_parser_t *address_parser_init(char *filename) { return NULL; } + khash_t(str_uint32) *phrase_counts = kh_init(str_uint32); + if (vocab == NULL) { + log_error("Could not allocate vocab\n"); + return NULL; + } + khash_t(phrase_stats) *phrase_stats = kh_init(phrase_stats); if (phrase_stats == NULL) { log_error("Could not allocate phrase_stats\n"); @@ -79,13 +88,14 @@ address_parser_t *address_parser_init(char *filename) { khiter_t k; char *str; + uint32_t i; + phrase_stats_t stats; khash_t(int_uint32) *class_counts; uint32_t vocab_size = 0; size_t examples = 0; - uint32_t i; const char *token; char *normalized; uint32_t count; @@ -93,12 +103,33 @@ address_parser_t *address_parser_init(char *filename) { char *key; int ret = 0; + bool is_postal; + + char *label; + char *prev_label; + char_array *token_builder = char_array_new(); + char_array *postcode_token_builder = char_array_new(); + char_array *sub_token_builder = char_array_new(); char_array *phrase_builder = char_array_new(); + cstring_array *phrases = cstring_array_new(); + cstring_array *phrase_labels = cstring_array_new(); + + char *phrase; + + phrase_array *dictionary_phrases = phrase_array_new(); + + token_array *sub_tokens = token_array_new(); + + trie_t *phrase_counts_trie = NULL; + + tokenized_string_t *tokenized_str; + token_array *tokens; + while (address_parser_data_set_next(data_set)) { - tokenized_string_t *tokenized_str = data_set->tokenized_str; + tokenized_str = data_set->tokenized_str; if (tokenized_str == NULL) { log_error("tokenized str is NULL\n"); @@ -106,25 +137,39 @@ address_parser_t *address_parser_init(char *filename) { return NULL; } - str = tokenized_str->str; + char *language = char_array_get_string(data_set->language); + if (string_equals(language, UNKNOWN_LANGUAGE) || string_equals(language, AMBIGUOUS_LANGUAGE)) { + language = NULL; + } - char *prev_label = NULL; + str = tokenized_str->str; + tokens = tokenized_str->tokens; + + prev_label = NULL; size_t num_strings = cstring_array_num_strings(tokenized_str->strings); - cstring_array_foreach(tokenized_str->strings, i, token, { - token_t t = tokenized_str->tokens->a[i]; + cstring_array_clear(phrases); + cstring_array_clear(phrase_labels); - char *label = cstring_array_get_string(data_set->labels, i); + cstring_array_foreach(tokenized_str->strings, i, token, { + token_t t = tokens->a[i]; + + label = cstring_array_get_string(data_set->labels, i); if (label == NULL) { continue; } char_array_clear(token_builder); - bool is_phrase_label = is_phrase_component(label); + bool is_admin = is_admin_component(label); + is_postal = !is_admin && is_postal_code(label); - uint64_t normalize_token_options = is_phrase_label ? ADDRESS_PARSER_NORMALIZE_PHRASE_TOKEN_OPTIONS : ADDRESS_PARSER_NORMALIZE_TOKEN_OPTIONS; + uint64_t normalize_token_options = ADDRESS_PARSER_NORMALIZE_TOKEN_OPTIONS; + + if (is_admin || is_postal) { + normalize_token_options = ADDRESS_PARSER_NORMALIZE_ADMIN_TOKEN_OPTIONS; + } add_normalized_token(token_builder, str, t, normalize_token_options); if (token_builder->n == 0) { @@ -133,9 +178,9 @@ address_parser_t *address_parser_init(char *filename) { normalized = char_array_get_string(token_builder); - if (!is_phrase_component(label)) { + if (!is_admin && !is_postal) { k = kh_get(str_uint32, vocab, normalized); - + if (k == kh_end(vocab)) { key = strdup(normalized); k = kh_put(str_uint32, vocab, key, &ret); @@ -155,127 +200,100 @@ address_parser_t *address_parser_init(char *filename) { continue; } - if (prev_label == NULL || !string_equals(label, prev_label) || i == num_strings - 1) { + if (is_postal) { + char_array_clear(postcode_token_builder); + add_normalized_token(postcode_token_builder, str, t, ADDRESS_PARSER_NORMALIZE_POSTAL_CODE_TOKEN_OPTIONS); + char *postcode_normalized = char_array_get_string(postcode_token_builder); - if (i == num_strings - 1) { - if (!string_equals(label, prev_label)) { - char_array_clear(phrase_builder); - } else if (prev_label != NULL) { + token_array_clear(sub_tokens); + phrase_array_clear(dictionary_phrases); + tokenize_add_tokens(sub_tokens, postcode_normalized, strlen(postcode_normalized), false); + + // One specific case where "CP" or "CEP" can be concatenated onto the front of the token + if (sub_tokens->n > 1 && search_address_dictionaries_tokens_with_phrases(postcode_normalized, sub_tokens, language, &dictionary_phrases) && dictionary_phrases->n > 0) { + phrase_t first_postcode_phrase = dictionary_phrases->a[0]; + address_expansion_value_t *value = address_dictionary_get_expansions(first_postcode_phrase.data); + if (value != NULL && value->components & ADDRESS_POSTAL_CODE) { + char_array_clear(token_builder); + size_t first_real_token_index = first_postcode_phrase.start + first_postcode_phrase.len; + token_t first_real_token = sub_tokens->a[first_real_token_index]; + char_array_cat(token_builder, postcode_normalized + first_real_token.offset); + normalized = char_array_get_string(token_builder); + } + } + } + + + bool same_as_previous_label = string_equals(label, prev_label); + + if (prev_label == NULL || !same_as_previous_label || i == num_strings - 1) { + if (i == num_strings - 1 && (same_as_previous_label || prev_label == NULL)) { + if (prev_label != NULL) { char_array_cat(phrase_builder, " "); } char_array_cat(phrase_builder, normalized); - prev_label = label; } // End of phrase, add to hashtable if (prev_label != NULL) { - uint32_t class_id; - uint32_t component = 0; + bool last_was_postal = string_equals(prev_label, ADDRESS_PARSER_LABEL_POSTAL_CODE); - // Too many variations on these - if (string_equals(prev_label, ADDRESS_PARSER_LABEL_CITY)) { - class_id = ADDRESS_PARSER_CITY; - component = ADDRESS_COMPONENT_CITY; - } else if (string_equals(prev_label, ADDRESS_PARSER_LABEL_STATE)) { - class_id = ADDRESS_PARSER_STATE; - component = ADDRESS_COMPONENT_STATE; - } else if (string_equals(prev_label, ADDRESS_PARSER_LABEL_COUNTRY)) { - class_id = ADDRESS_PARSER_COUNTRY; - component = ADDRESS_COMPONENT_COUNTRY; - } else if (string_equals(prev_label, ADDRESS_PARSER_LABEL_COUNTRY_REGION)) { - class_id = ADDRESS_PARSER_COUNTRY_REGION; - component = ADDRESS_COMPONENT_COUNTRY; - } else if (string_equals(prev_label, ADDRESS_PARSER_LABEL_STATE_DISTRICT)) { - class_id = ADDRESS_PARSER_STATE_DISTRICT; - component = ADDRESS_COMPONENT_STATE_DISTRICT; - } else if (string_equals(prev_label, ADDRESS_PARSER_LABEL_SUBURB)) { - class_id = ADDRESS_PARSER_SUBURB; - component = ADDRESS_COMPONENT_SUBURB; - } else if (string_equals(prev_label, ADDRESS_PARSER_LABEL_CITY_DISTRICT)) { - class_id = ADDRESS_PARSER_CITY_DISTRICT; - component = ADDRESS_COMPONENT_CITY_DISTRICT; - } else if (string_equals(prev_label, ADDRESS_PARSER_LABEL_ISLAND)) { - class_id = ADDRESS_PARSER_ISLAND; - component = ADDRESS_COMPONENT_ISLAND; - } else if (string_equals(prev_label, ADDRESS_PARSER_LABEL_POSTAL_CODE)) { - class_id = ADDRESS_PARSER_POSTAL_CODE; - component = ADDRESS_COMPONENT_POSTAL_CODE; - } else { - // Shouldn't happen but just in case - prev_label = NULL; - continue; - } + phrase = char_array_get_string(phrase_builder); - char *phrase = char_array_get_string(phrase_builder); + if (last_was_postal) { + token_array_clear(sub_tokens); + phrase_array_clear(dictionary_phrases); - char *normalized_phrase = NULL; + tokenize_add_tokens(sub_tokens, phrase, strlen(phrase), false); - if (string_contains_hyphen(phrase)) { - char *phrase_copy = strdup(phrase); - if (phrase_copy == NULL) { - goto exit_hashes_allocated; - } - normalized_phrase = normalize_string_utf8(phrase_copy, NORMALIZE_STRING_REPLACE_HYPHENS); + if (sub_tokens->n > 0 && search_address_dictionaries_tokens_with_phrases(phrase, sub_tokens, language, &dictionary_phrases) && dictionary_phrases->n > 0) { + char_array_clear(sub_token_builder); - } + phrase_t current_phrase = NULL_PHRASE; + phrase_t prev_phrase = NULL_PHRASE; + token_t current_sub_token; - char *phrases[2]; - phrases[0] = phrase; - phrases[1] = normalized_phrase; + for (size_t pc = 0; pc < dictionary_phrases->n; pc++) { + current_phrase = dictionary_phrases->a[pc]; + for (size_t j = prev_phrase.start + prev_phrase.len; j < current_phrase.start; j++) { + current_sub_token = sub_tokens->a[j]; - for (int i = 0; i < sizeof(phrases) / sizeof(char *); i++) { - phrase = phrases[i]; - if (phrase == NULL) continue; + char_array_cat_len(sub_token_builder, phrase + current_sub_token.offset, current_sub_token.len); - k = kh_get(phrase_stats, phrase_stats, phrase); - - if (k == kh_end(phrase_stats)) { - key = strdup(phrase); - ret = 0; - k = kh_put(phrase_stats, phrase_stats, key, &ret); - if (ret < 0) { - log_error("Error in kh_put in phrase_stats\n"); - free(key); - if (normalized_phrase != NULL) { - free(normalized_phrase); + if (j < current_phrase.start - 1) { + char_array_cat(sub_token_builder, " "); + } } - goto exit_hashes_allocated; + prev_phrase = current_phrase; } - class_counts = kh_init(int_uint32); - - stats.class_counts = class_counts; - stats.parser_types.components = component; - stats.parser_types.most_common = 0; - - kh_value(phrase_stats, k) = stats; - } else { - stats = kh_value(phrase_stats, k); - class_counts = stats.class_counts; - stats.parser_types.components |= component; - } + if (prev_phrase.len > 0) { + for (size_t j = prev_phrase.start + prev_phrase.len; j < sub_tokens->n; j++) { + current_sub_token = sub_tokens->a[j]; - k = kh_get(int_uint32, class_counts, (khint_t)class_id); + char_array_cat_len(sub_token_builder, phrase + current_sub_token.offset, current_sub_token.len); - if (k == kh_end(class_counts)) { - ret = 0; - k = kh_put(int_uint32, class_counts, class_id, &ret); - if (ret < 0) { - log_error("Error in kh_put in class_counts\n"); - goto exit_hashes_allocated; + if (j < sub_tokens->n - 1) { + char_array_cat(sub_token_builder, " "); + } + } } - kh_value(class_counts, k) = 1; - } else { - kh_value(class_counts, k)++; - } + phrase = char_array_get_string(sub_token_builder); + } } - if (normalized_phrase != NULL) { - free(normalized_phrase); - } + log_info("phrase=%s\n", phrase); + cstring_array_add_string(phrases, phrase); + cstring_array_add_string(phrase_labels, prev_label); + } + + if (i == num_strings - 1 && !same_as_previous_label && prev_label != NULL) { + log_info("phrase=%s\n", normalized); + cstring_array_add_string(phrases, normalized); + cstring_array_add_string(phrase_labels, label); } char_array_clear(phrase_builder); @@ -289,6 +307,135 @@ address_parser_t *address_parser_init(char *filename) { }) + cstring_array_foreach(phrases, i, phrase, { + if (phrase == NULL) continue; + + uint32_t class_id; + uint32_t component = 0; + + char *phrase_label = cstring_array_get_string(phrase_labels, i); + if (phrase_label == NULL) continue; + + is_postal = false; + + // Too many variations on these + if (string_equals(phrase_label, ADDRESS_PARSER_LABEL_CITY)) { + class_id = ADDRESS_PARSER_BOUNDARY_CITY; + component = ADDRESS_COMPONENT_CITY; + } else if (string_equals(phrase_label, ADDRESS_PARSER_LABEL_STATE)) { + class_id = ADDRESS_PARSER_BOUNDARY_STATE; + component = ADDRESS_COMPONENT_STATE; + } else if (string_equals(phrase_label, ADDRESS_PARSER_LABEL_COUNTRY)) { + class_id = ADDRESS_PARSER_BOUNDARY_COUNTRY; + component = ADDRESS_COMPONENT_COUNTRY; + } else if (string_equals(phrase_label, ADDRESS_PARSER_LABEL_POSTAL_CODE)) { + class_id = ADDRESS_PARSER_BOUNDARY_POSTAL_CODE; + component = ADDRESS_COMPONENT_POSTAL_CODE; + is_postal = true; + } else if (string_equals(phrase_label, ADDRESS_PARSER_LABEL_COUNTRY_REGION)) { + class_id = ADDRESS_PARSER_BOUNDARY_COUNTRY_REGION; + component = ADDRESS_COMPONENT_COUNTRY_REGION; + } else if (string_equals(phrase_label, ADDRESS_PARSER_LABEL_STATE_DISTRICT)) { + class_id = ADDRESS_PARSER_BOUNDARY_STATE_DISTRICT; + component = ADDRESS_COMPONENT_STATE_DISTRICT; + } else if (string_equals(phrase_label, ADDRESS_PARSER_LABEL_SUBURB)) { + class_id = ADDRESS_PARSER_BOUNDARY_SUBURB; + component = ADDRESS_COMPONENT_SUBURB; + } else if (string_equals(phrase_label, ADDRESS_PARSER_LABEL_CITY_DISTRICT)) { + class_id = ADDRESS_PARSER_BOUNDARY_CITY_DISTRICT; + component = ADDRESS_COMPONENT_CITY_DISTRICT; + } else if (string_equals(label, ADDRESS_PARSER_LABEL_WORLD_REGION)) { + class_id = ADDRESS_PARSER_BOUNDARY_WORLD_REGION; + component = ADDRESS_COMPONENT_WORLD_REGION; + } else if (string_equals(label, ADDRESS_PARSER_LABEL_ISLAND)) { + class_id = ADDRESS_PARSER_BOUNDARY_ISLAND; + component = ADDRESS_COMPONENT_ISLAND; + } else { + // Shouldn't happen but just in case + continue; + } + + char *normalized_phrase = NULL; + + if (string_contains_hyphen(phrase) && !is_postal) { + normalized_phrase = normalize_string_utf8(phrase, NORMALIZE_STRING_REPLACE_HYPHENS); + } + + char *phrases[2]; + phrases[0] = phrase; + phrases[1] = normalized_phrase; + + for (int p_i = 0; p_i < sizeof(phrases) / sizeof(char *); p_i++) { + phrase = phrases[p_i]; + if (phrase == NULL) continue; + log_info("adding: %s\n", phrase); + + k = kh_get(phrase_stats, phrase_stats, phrase); + + if (k == kh_end(phrase_stats)) { + key = strdup(phrase); + ret = 0; + k = kh_put(phrase_stats, phrase_stats, key, &ret); + if (ret < 0) { + log_error("Error in kh_put in phrase_stats\n"); + free(key); + goto exit_hashes_allocated; + } + class_counts = kh_init(int_uint32); + + stats.class_counts = class_counts; + stats.parser_types.components = component; + stats.parser_types.most_common = 0; + + kh_value(phrase_stats, k) = stats; + } else { + stats = kh_value(phrase_stats, k); + class_counts = stats.class_counts; + stats.parser_types.components |= component; + } + + k = kh_get(int_uint32, class_counts, (khint_t)class_id); + + if (k == kh_end(class_counts)) { + ret = 0; + k = kh_put(int_uint32, class_counts, class_id, &ret); + if (ret < 0) { + log_error("Error in kh_put in class_counts\n"); + goto exit_hashes_allocated; + } + kh_value(class_counts, k) = 1; + } else { + kh_value(class_counts, k)++; + } + + k = kh_get(str_uint32, phrase_counts, phrase); + + if (k != kh_end(phrase_counts)) { + kh_value(phrase_counts, k)++; + } else { + key = strdup(phrase); + ret = 0; + k = kh_put(str_uint32, phrase_counts, key, &ret); + if (ret < 0) { + log_error("Error in kh_put in phrase_counts\n"); + free(key); + if (normalized_phrase != NULL) { + free(normalized_phrase); + } + goto exit_hashes_allocated; + } + kh_value(phrase_counts, k) = 1; + } + } + + if (normalized_phrase != NULL) { + log_info("freeing\n"); + free(normalized_phrase); + normalized_phrase = NULL; + } + + }) + tokenized_string_destroy(tokenized_str); examples++; if (examples % 10000 == 0 && examples != 0) { @@ -310,12 +457,14 @@ address_parser_t *address_parser_init(char *filename) { } } + phrase_counts_trie = trie_new_from_hash(phrase_counts); + kh_foreach(phrase_stats, token, stats, { class_counts = stats.class_counts; int most_common = -1; uint32_t max_count = 0; uint32_t total = 0; - for (int i = 0; i < NUM_ADDRESS_PARSER_TYPES; i++) { + for (int i = 0; i < NUM_ADDRESS_PARSER_BOUNDARY_TYPES; i++) { k = kh_get(int_uint32, class_counts, (khint_t)i); if (k != kh_end(class_counts)) { count = kh_value(class_counts, k); @@ -343,6 +492,8 @@ address_parser_t *address_parser_init(char *filename) { } }) + parser->model = NULL; + parser->vocab = trie_new_from_hash(vocab); if (parser->vocab == NULL) { log_error("Error initializing vocabulary\n"); @@ -363,9 +514,18 @@ exit_hashes_allocated: // Free memory for hashtables, etc. char_array_destroy(token_builder); + char_array_destroy(postcode_token_builder); char_array_destroy(phrase_builder); + cstring_array_destroy(phrases); + cstring_array_destroy(phrase_labels); + phrase_array_destroy(dictionary_phrases); + token_array_destroy(sub_tokens); address_parser_data_set_destroy(data_set); + if (phrase_counts_trie != NULL) { + trie_destroy(phrase_counts_trie); + } + kh_foreach(vocab, token, count, { free((char *)token); }) @@ -378,6 +538,12 @@ exit_hashes_allocated: kh_destroy(phrase_stats, phrase_stats); + kh_foreach(phrase_counts, token, count, { + free((char *)token); + }) + + kh_destroy(str_uint32, phrase_counts); + kh_foreach(phrase_types, token, count, { free((char *)token); }) @@ -510,6 +676,8 @@ int main(int argc, char **argv) { size_t num_iterations = DEFAULT_ITERATIONS; size_t position = 0; + ssize_t arg_iterations; + char *filename = NULL; char *output_dir = NULL; @@ -522,7 +690,11 @@ int main(int argc, char **argv) { } if (kwarg == ADDRESS_PARSER_TRAIN_ARG_ITERATIONS) { - num_iterations = (size_t)atoi(arg); + if (sscanf(arg, "%zd", &arg_iterations) != 1 || arg_iterations < 0) { + log_error("Bad arg for --iterations: %s\n", arg); + exit(EXIT_FAILURE); + } + num_iterations = (size_t)arg_iterations; } else if (position == 0) { filename = arg; position++;