From 44908ff95a064d6161cafafbc98eb8b184c51ca6 Mon Sep 17 00:00:00 2001 From: Al Date: Wed, 22 Jun 2016 11:50:42 -0400 Subject: [PATCH] [parser] No digit normalization in training data-derived parser phrases (for postcodes, etc.), phrases include the new island type, house number phrases if any are valid. Adjacent words are now full phrases if they are part of a multiword token like a city name. For hyphenated names like Carmel-by-the-Sea, adding a version to the phrase dictionary where the hyphens are replaced with spaces --- src/address_parser.c | 321 +++++++++++++++++++++++++++++-------- src/address_parser.h | 37 +++-- src/address_parser_train.c | 227 ++++++++++++++++++-------- 3 files changed, 439 insertions(+), 146 deletions(-) diff --git a/src/address_parser.c b/src/address_parser.c index d1593fa4..f3cf95d8 100644 --- a/src/address_parser.c +++ b/src/address_parser.c @@ -15,6 +15,14 @@ static address_parser_t *parser = NULL; +//#define PRINT_ADDRESS_PARSER_FEATURES + +typedef enum { + ADDRESS_PARSER_NULL_PHRASE, + ADDRESS_PARSER_DICTIONARY_PHRASE, + ADDRESS_PARSER_COMPONENT_PHRASE, + ADDRESS_PARSER_GEODB_PHRASE +} address_parser_phrase_type_t; address_parser_t *address_parser_new(void) { address_parser_t *parser = malloc(sizeof(address_parser_t)); @@ -149,6 +157,10 @@ inline void address_parser_normalize_token(cstring_array *array, char *str, toke normalize_token(array, str, token, ADDRESS_PARSER_NORMALIZE_TOKEN_OPTIONS); } +inline void address_parser_normalize_phrase_token(cstring_array *array, char *str, token_t token) { + normalize_token(array, str, token, ADDRESS_PARSER_NORMALIZE_PHRASE_TOKEN_OPTIONS); +} + inline char *address_parser_normalize_string(char *str) { return normalize_string_latin(str, strlen(str), ADDRESS_PARSER_NORMALIZE_STRING_OPTIONS); } @@ -161,14 +173,46 @@ void address_parser_context_destroy(address_parser_context_t *self) { char_array_destroy(self->phrase); } + if (self->context_phrase != NULL) { + 
char_array_destroy(self->context_phrase); + } + + if (self->long_context_phrase != NULL) { + char_array_destroy(self->long_context_phrase); + } + if (self->component_phrase != NULL) { char_array_destroy(self->component_phrase); } + if (self->context_component_phrase != NULL) { + char_array_destroy(self->context_component_phrase); + } + + if (self->long_context_component_phrase != NULL) { + char_array_destroy(self->long_context_component_phrase); + } + if (self->geodb_phrase != NULL) { char_array_destroy(self->geodb_phrase); } + if (self->context_geodb_phrase != NULL) { + char_array_destroy(self->context_geodb_phrase); + } + + if (self->long_context_geodb_phrase != NULL) { + char_array_destroy(self->long_context_geodb_phrase); + } + + if (self->sub_token != NULL) { + char_array_destroy(self->sub_token); + } + + if (self->sub_tokens != NULL) { + token_array_destroy(self->sub_tokens); + } + if (self->separators != NULL) { uint32_array_destroy(self->separators); } @@ -225,16 +269,56 @@ address_parser_context_t *address_parser_context_new(void) { goto exit_address_parser_context_allocated; } + context->context_phrase = char_array_new(); + if (context->context_phrase == NULL) { + goto exit_address_parser_context_allocated; + } + + context->long_context_phrase = char_array_new(); + if (context->long_context_phrase == NULL) { + goto exit_address_parser_context_allocated; + } + context->component_phrase = char_array_new(); if (context->component_phrase == NULL) { goto exit_address_parser_context_allocated; } + context->context_component_phrase = char_array_new(); + if (context->context_component_phrase == NULL) { + goto exit_address_parser_context_allocated; + } + + context->long_context_component_phrase = char_array_new(); + if (context->long_context_component_phrase == NULL) { + goto exit_address_parser_context_allocated; + } + context->geodb_phrase = char_array_new(); if (context->geodb_phrase == NULL) { goto exit_address_parser_context_allocated; } + 
context->context_geodb_phrase = char_array_new(); + if (context->context_geodb_phrase == NULL) { + goto exit_address_parser_context_allocated; + } + + context->long_context_geodb_phrase = char_array_new(); + if (context->long_context_geodb_phrase == NULL) { + goto exit_address_parser_context_allocated; + } + + context->sub_token = char_array_new(); + if (context->sub_token == NULL) { + goto exit_address_parser_context_allocated; + } + + context->sub_tokens = token_array_new(); + if (context->sub_tokens == NULL) { + goto exit_address_parser_context_allocated; + } + context->separators = uint32_array_new(); if (context->separators == NULL) { goto exit_address_parser_context_allocated; @@ -313,9 +397,21 @@ void address_parser_context_fill(address_parser_context_t *context, address_pars address_parser_normalize_token(normalized, str, token); }) + + /* + Address dictionary phrases + -------------------------- + Recognizing phrases that occur in libpostal's dictionaries. + + Note: if the dictionaries are updated to try to improve the parser, + we'll need to retrain. This can be done without rebuilding the + training data (a long-running process which can take up to a week), + but will require running address_parser_train, the main training script. 
+ */ + phrase_array_clear(context->address_dictionary_phrases); int64_array_clear(context->address_phrase_memberships); - + i = 0; phrase_array *address_dictionary_phrases = context->address_dictionary_phrases; int64_array *address_phrase_memberships = context->address_phrase_memberships; @@ -363,6 +459,7 @@ void address_parser_context_fill(address_parser_context_t *context, address_pars } } } + for (; i < tokens->n; i++) { log_debug("token i=%lld, null geo phrase membership\n", i); int64_array_push(geodb_phrase_memberships, NULL_PHRASE_MEMBERSHIP); @@ -390,24 +487,21 @@ void address_parser_context_fill(address_parser_context_t *context, address_pars } } } + for (; i < tokens->n; i++) { log_debug("token i=%lld, null component phrase membership\n", i); int64_array_push(component_phrase_memberships, NULL_PHRASE_MEMBERSHIP); } - - } - -static inline char *get_phrase_string(tokenized_string_t *str, char_array *phrase_tokens, phrase_t phrase) { - size_t phrase_len = 0; +static inline char *get_phrase_string_array(cstring_array *str, char_array *phrase_tokens, phrase_t phrase) { char_array_clear(phrase_tokens); size_t phrase_end = phrase.start + phrase.len; for (int k = phrase.start; k < phrase_end; k++) { - char *w = tokenized_string_get_token(str, k); + char *w = cstring_array_get_string(str, k); char_array_append(phrase_tokens, w); if (k < phrase_end - 1) { char_array_append(phrase_tokens, " "); @@ -419,55 +513,127 @@ static inline char *get_phrase_string(tokenized_string_t *str, char_array *phras } -/* +static inline char *get_phrase_string(tokenized_string_t *str, char_array *phrase_tokens, phrase_t phrase) { + return get_phrase_string_array(str->strings, phrase_tokens, phrase); +} -typedef struct adjacent_phrase { +static inline phrase_t get_phrase(phrase_array *phrases, int64_array *phrase_memberships, uint32_t i) { + if (phrases == NULL || phrase_memberships == NULL || i > phrases->n - 1) { + return NULL_PHRASE; + } + + int64_t phrase_index = 
phrase_memberships->a[i]; + if (phrase_index != NULL_PHRASE_MEMBERSHIP) { + phrase_t phrase = phrases->a[phrase_index]; + return phrase; + } + + return NULL_PHRASE; +} + +typedef struct address_parser_phrase { + char *str; + address_parser_phrase_type_t type; phrase_t phrase; - uint32_t num_separators; -} adjacent_phrase_t; +} address_parser_phrase_t; -#define NULL_ADJACENT_PHRASE (adjacent_phrase_t){NULL_PHRASE, 0}; +static inline address_parser_phrase_t word_or_phrase_at_index(tokenized_string_t *tokenized, address_parser_context_t *context, uint32_t i) { + phrase_t phrase; + address_parser_phrase_t response; + char *phrase_string = NULL; + + phrase = get_phrase(context->address_dictionary_phrases, context->address_phrase_memberships, i); + if (phrase.len > 0) { + phrase_string = get_phrase_string(tokenized, context->context_phrase, phrase), + + response = (address_parser_phrase_t){ + phrase_string, + ADDRESS_PARSER_DICTIONARY_PHRASE, + phrase + }; + return response; + } + + address_parser_types_t types; + + phrase = get_phrase(context->component_phrases, context->component_phrase_memberships, i); + if (phrase.len > 0) { + types.value = phrase.data; + uint32_t component_phrase_types = types.components; + + if (component_phrase_types != ADDRESS_COMPONENT_POSTAL_CODE) { + phrase_string = get_phrase_string(tokenized, context->context_component_phrase, phrase); + } else { + phrase_string = get_phrase_string_array(context->normalized, context->context_component_phrase, phrase); + } + + response = (address_parser_phrase_t){ + phrase_string, + ADDRESS_PARSER_COMPONENT_PHRASE, + phrase + }; + return response; + } + + geodb_value_t geo; + + phrase = get_phrase(context->geodb_phrases, context->geodb_phrase_memberships, i); + if (phrase.len > 0) { + geo.value = phrase.data; + uint32_t geodb_phrase_types = geo.components; + + if (geodb_phrase_types != GEONAMES_ADDRESS_COMPONENT_POSTCODE) { + phrase_string = get_phrase_string(tokenized, context->context_geodb_phrase, phrase); 
+ } else { + phrase_string = get_phrase_string_array(context->normalized, context->context_geodb_phrase, phrase); + } + + response = (address_parser_phrase_t){ + phrase_string, + ADDRESS_PARSER_GEODB_PHRASE, + phrase + }; + return response; + + } + + cstring_array *normalized = context->normalized; + + char *word = cstring_array_get_string(normalized, i); + response = (address_parser_phrase_t){ + word, + ADDRESS_PARSER_NULL_PHRASE, + NULL_PHRASE + }; + return response; + +} + +static inline int64_t phrase_index(int64_array *phrase_memberships, size_t start, int8_t direction) { + if (phrase_memberships == NULL) { + return -1; + } -static inline adjacent_phrase_t get_adjacent_phrase(int64_array *phrase_memberships, phrase_array *phrases, uint32_array *separator_positions, uint32_t i, int32_t direction) { - uint32_t *separators = separator_positions->a; int64_t *memberships = phrase_memberships->a; + int64_t membership; - uint32_t num_strings = (uint32_t)phrase_memberships->n; - - adjacent_phrase_t adjacent = NULL_ADJACENT_PHRASE; - - if (direction == -1) { - for (uint32_t idx = i; idx >= 0; idx--) { - uint32_t separator = separators[idx]; - if (separator > ADDRESS_SEPARATOR_NONE) { - adjacent.num_separators++; + if (direction == -1) { + for (size_t idx = start; idx >= 0; idx--) { + if (memberships[idx] != NULL_PHRASE_MEMBERSHIP) { + return (int64_t)idx; } - - int64_t membership = memberships[ids]; - if (membership != NULL_PHRASE_MEMBERSHIP) { - adjacent.phrase = phrases->a[membership]; - break; - } - } } else if (direction == 1) { - for (uint32_t idx = i; idx < num_strings; idx++) { - uint32_t separator = separators[idx]; - if (separator > ADDRESS_SEPARATOR_NONE) { - adjacent.num_separators++; - } - - int64_t membership = memberships[ids]; - if (membership != NULL_PHRASE_MEMBERSHIP) { - adjacent.phrase = phrases->a[membership]; - break; + size_t n = phrase_memberships->n; + for (size_t idx = start; idx < n; idx++) { + if (memberships[idx] != NULL_PHRASE_MEMBERSHIP) { 
+ return (int64_t)idx; } } } - return adjacent; + return -1; } -*/ static inline void add_phrase_features(cstring_array *features, uint32_t phrase_types, uint32_t component, char *phrase_type, char *phrase_string, char *prev2, char *prev) { if (phrase_types == component) { @@ -572,7 +738,7 @@ bool address_parser_features(void *self, void *ctx, tokenized_string_t *tokenize log_debug("expansion=%d\n", expansion.value); - if (address_phrase_types & (ADDRESS_STREET | ADDRESS_NAME)) { + if (address_phrase_types & (ADDRESS_STREET | ADDRESS_HOUSE_NUMBER | ADDRESS_NAME)) { phrase_string = get_phrase_string(tokenized, phrase_tokens, phrase); add_word_feature = false; @@ -580,6 +746,8 @@ bool address_parser_features(void *self, void *ctx, tokenized_string_t *tokenize add_phrase_features(features, address_phrase_types, ADDRESS_STREET, "street", phrase_string, prev2, prev); add_phrase_features(features, address_phrase_types, ADDRESS_NAME, "name", phrase_string, prev2, prev); + add_phrase_features(features, address_phrase_types, ADDRESS_HOUSE_NUMBER, "house_number", phrase_string, prev2, prev); + } } @@ -615,6 +783,8 @@ bool address_parser_features(void *self, void *ctx, tokenized_string_t *tokenize address_parser_types_t types; + bool possible_postal_code = false; + // Component phrases if (component_phrase_index != NULL_PHRASE_MEMBERSHIP) { phrase = component_phrases->a[component_phrase_index]; @@ -640,6 +810,7 @@ bool address_parser_features(void *self, void *ctx, tokenized_string_t *tokenize add_phrase_features(features, component_phrase_types, ADDRESS_COMPONENT_SUBURB, "suburb", component_phrase_string, prev2, prev); add_phrase_features(features, component_phrase_types, ADDRESS_COMPONENT_CITY, "city", component_phrase_string, prev2, prev); add_phrase_features(features, component_phrase_types, ADDRESS_COMPONENT_CITY_DISTRICT, "city_district", component_phrase_string, prev2, prev); + add_phrase_features(features, component_phrase_types, ADDRESS_COMPONENT_ISLAND, "island", 
component_phrase_string, prev2, prev); add_phrase_features(features, component_phrase_types, ADDRESS_COMPONENT_STATE_DISTRICT, "state_district", component_phrase_string, prev2, prev); add_phrase_features(features, component_phrase_types, ADDRESS_COMPONENT_STATE, "state", component_phrase_string, prev2, prev); add_phrase_features(features, component_phrase_types, ADDRESS_COMPONENT_POSTAL_CODE, "postal_code", component_phrase_string, prev2, prev); @@ -654,12 +825,15 @@ bool address_parser_features(void *self, void *ctx, tokenized_string_t *tokenize feature_array_add(features, 2, "commonly country", component_phrase_string); } else if (most_common == ADDRESS_PARSER_STATE_DISTRICT) { feature_array_add(features, 2, "commonly state_district", component_phrase_string); + } else if (most_common == ADDRESS_PARSER_ISLAND) { + feature_array_add(features, 2, "commonly island", component_phrase_string); } else if (most_common == ADDRESS_PARSER_SUBURB) { feature_array_add(features, 2, "commonly suburb", component_phrase_string); } else if (most_common == ADDRESS_PARSER_CITY_DISTRICT) { feature_array_add(features, 2, "commonly city_district", component_phrase_string); } else if (most_common == ADDRESS_PARSER_POSTAL_CODE) { feature_array_add(features, 2, "commonly postal_code", component_phrase_string); + possible_postal_code = true; } } @@ -701,10 +875,14 @@ bool address_parser_features(void *self, void *ctx, tokenized_string_t *tokenize } + possible_postal_code = geodb_phrase_types & GEONAMES_ADDRESS_COMPONENT_POSTCODE; + } uint32_t word_freq = word_vocab_frequency(parser, word); + bool is_unknown_word = false; + if (add_word_feature) { // Bias unit, acts as an intercept feature_array_add(features, 1, "bias"); @@ -737,12 +915,15 @@ bool address_parser_features(void *self, void *ctx, tokenized_string_t *tokenize } if (last_index >= 0) { - char *prev_word = cstring_array_get_string(normalized, last_index); + address_parser_phrase_t prev_word_or_phrase = 
word_or_phrase_at_index(tokenized, context, last_index); + char *prev_word = prev_word_or_phrase.str; - uint32_t prev_word_freq = word_vocab_frequency(parser, prev_word); - if (prev_word_freq == 0) { - token_t prev_token = tokenized->tokens->a[last_index]; - prev_word = (prev_token.type != NUMERIC && prev_token.type != IDEOGRAPHIC_NUMBER) ? UNKNOWN_WORD : UNKNOWN_NUMERIC; + if (prev_word_or_phrase.type == ADDRESS_PARSER_NULL_PHRASE) { + uint32_t prev_word_freq = word_vocab_frequency(parser, prev_word); + if (prev_word_freq == 0) { + token_t prev_token = tokenized->tokens->a[last_index]; + prev_word = (prev_token.type != NUMERIC && prev_token.type != IDEOGRAPHIC_NUMBER) ? UNKNOWN_WORD : UNKNOWN_NUMERIC; + } } // Previous word @@ -759,12 +940,18 @@ bool address_parser_features(void *self, void *ctx, tokenized_string_t *tokenize size_t num_tokens = tokenized->tokens->n; if (next_index < num_tokens) { - char *next_word = cstring_array_get_string(normalized, next_index); + address_parser_phrase_t next_word_or_phrase = word_or_phrase_at_index(tokenized, context, next_index); + char *next_word = next_word_or_phrase.str; + size_t next_word_len = 1; - uint32_t next_word_freq = word_vocab_frequency(parser, next_word); - if (next_word_freq == 0) { - token_t next_token = tokenized->tokens->a[next_index]; - next_word = (next_token.type != NUMERIC && next_token.type != IDEOGRAPHIC_NUMBER) ? UNKNOWN_WORD : UNKNOWN_NUMERIC; + if (next_word_or_phrase.type == ADDRESS_PARSER_NULL_PHRASE) { + uint32_t next_word_freq = word_vocab_frequency(parser, next_word); + if (next_word_freq == 0) { + token_t next_token = tokenized->tokens->a[next_index]; + next_word = (next_token.type != NUMERIC && next_token.type != IDEOGRAPHIC_NUMBER) ? UNKNOWN_WORD : UNKNOWN_NUMERIC; + } + } else { + next_word_len = next_word_or_phrase.phrase.len; } // Next word e.g. 
if the current word is unknown and the next word is "street" @@ -774,7 +961,7 @@ bool address_parser_features(void *self, void *ctx, tokenized_string_t *tokenize feature_array_add(features, 3, "word+i+1 word", word, next_word); } - #ifndef PRINT_FEATURES + #ifndef PRINT_ADDRESS_PARSER_FEATURES if (0) { #endif @@ -787,7 +974,7 @@ bool address_parser_features(void *self, void *ctx, tokenized_string_t *tokenize }) printf("}\n"); - #ifndef PRINT_FEATURES + #ifndef PRINT_ADDRESS_PARSER_FEATURES } #endif @@ -868,18 +1055,21 @@ address_parser_response_t *address_parser_parse(char *address, char *language, c label = strdup(ADDRESS_PARSER_LABEL_POSTAL_CODE); } - char **single_label = malloc(sizeof(char *)); - single_label[0] = label; - char **single_component = malloc(sizeof(char *)); - single_component[0] = strdup(normalized); + // Implicit: if most_common is not one of the above, ignore and parse regularly + if (label != NULL) { + char **single_label = malloc(sizeof(char *)); + single_label[0] = label; + char **single_component = malloc(sizeof(char *)); + single_component[0] = strdup(normalized); - response->num_components = 1; - response->labels = single_label; - response->components = single_component; + response->num_components = 1; + response->labels = single_label; + response->components = single_component; - token_array_destroy(tokens); - tokenized_string_destroy(tokenized_str); - return response; + token_array_destroy(tokens); + tokenized_string_destroy(tokenized_str); + return response; + } } } @@ -895,7 +1085,6 @@ address_parser_response_t *address_parser_parse(char *address, char *language, c cstring_array *labels = cstring_array_new_size(num_strings); cstring_array *components = cstring_array_new_size(strlen(address) + num_strings); - for (int i = 0; i < num_strings; i++) { char *str = tokenized_string_get_token(tokenized_str, i); char *label = cstring_array_get_string(token_labels, i); diff --git a/src/address_parser.h b/src/address_parser.h index 
bd70d1c8..5fcd4e98 100644 --- a/src/address_parser.h +++ b/src/address_parser.h @@ -59,7 +59,8 @@ with the general error-driven averaged perceptron. #define NULL_PHRASE_MEMBERSHIP -1 #define ADDRESS_PARSER_NORMALIZE_STRING_OPTIONS NORMALIZE_STRING_COMPOSE | NORMALIZE_STRING_LOWERCASE | NORMALIZE_STRING_LATIN_ASCII -#define ADDRESS_PARSER_NORMALIZE_TOKEN_OPTIONS NORMALIZE_TOKEN_DELETE_HYPHENS | NORMALIZE_TOKEN_DELETE_FINAL_PERIOD | NORMALIZE_TOKEN_DELETE_ACRONYM_PERIODS | NORMALIZE_TOKEN_REPLACE_DIGITS +#define ADDRESS_PARSER_NORMALIZE_TOKEN_OPTIONS NORMALIZE_TOKEN_DELETE_FINAL_PERIOD | NORMALIZE_TOKEN_DELETE_ACRONYM_PERIODS | NORMALIZE_TOKEN_REPLACE_DIGITS +#define ADDRESS_PARSER_NORMALIZE_PHRASE_TOKEN_OPTIONS ADDRESS_PARSER_NORMALIZE_TOKEN_OPTIONS ^ NORMALIZE_TOKEN_REPLACE_DIGITS #define ADDRESS_SEPARATOR_NONE 0 #define ADDRESS_SEPARATOR_FIELD_INTERNAL 1 << 0 @@ -77,10 +78,11 @@ with the general error-driven averaged perceptron. #define ADDRESS_COMPONENT_SUBURB 1 << 7 #define ADDRESS_COMPONENT_CITY_DISTRICT 1 << 8 #define ADDRESS_COMPONENT_CITY 1 << 9 -#define ADDRESS_COMPONENT_STATE_DISTRICT 1 << 10 -#define ADDRESS_COMPONENT_STATE 1 << 11 -#define ADDRESS_COMPONENT_POSTAL_CODE 1 << 12 -#define ADDRESS_COMPONENT_COUNTRY 1 << 13 +#define ADDRESS_COMPONENT_ISLAND 1 << 10 +#define ADDRESS_COMPONENT_STATE_DISTRICT 1 << 11 +#define ADDRESS_COMPONENT_STATE 1 << 12 +#define ADDRESS_COMPONENT_POSTAL_CODE 1 << 13 +#define ADDRESS_COMPONENT_COUNTRY 1 << 14 typedef enum { ADDRESS_PARSER_HOUSE, @@ -90,6 +92,7 @@ typedef enum { ADDRESS_PARSER_CITY_DISTRICT, ADDRESS_PARSER_CITY, ADDRESS_PARSER_STATE_DISTRICT, + ADDRESS_PARSER_ISLAND, ADDRESS_PARSER_STATE, ADDRESS_PARSER_POSTAL_CODE, ADDRESS_PARSER_COUNTRY, @@ -103,8 +106,9 @@ typedef enum { #define ADDRESS_PARSER_LABEL_CITY_DISTRICT "city_district" #define ADDRESS_PARSER_LABEL_CITY "city" #define ADDRESS_PARSER_LABEL_STATE_DISTRICT "state_district" +#define ADDRESS_PARSER_LABEL_ISLAND "island" #define 
ADDRESS_PARSER_LABEL_STATE "state" -#define ADDRESS_PARSER_LABEL_POSTAL_CODE "postal_code" +#define ADDRESS_PARSER_LABEL_POSTAL_CODE "postcode" #define ADDRESS_PARSER_LABEL_COUNTRY "country" typedef union address_parser_types { @@ -120,20 +124,29 @@ typedef struct address_parser_context { char *language; char *country; cstring_array *features; + // Temporary strings used at each token during feature extraction char_array *phrase; + char_array *context_phrase; + char_array *long_context_phrase; char_array *component_phrase; + char_array *context_component_phrase; + char_array *long_context_component_phrase; char_array *geodb_phrase; + char_array *context_geodb_phrase; + char_array *long_context_geodb_phrase; + // For hyphenated words + char_array *sub_token; + token_array *sub_tokens; + // Strings/arrays relating to the sentence uint32_array *separators; cstring_array *normalized; + // Known phrases phrase_array *address_dictionary_phrases; - // Index in address_dictionary_phrases or -1 - int64_array *address_phrase_memberships; + int64_array *address_phrase_memberships; // Index in address_dictionary_phrases or -1 phrase_array *geodb_phrases; - // Index in gedob_phrases or -1 - int64_array *geodb_phrase_memberships; + int64_array *geodb_phrase_memberships; // Index in geodb_phrases or -1 phrase_array *component_phrases; - // Index in component_phrases or -1 - int64_array *component_phrase_memberships; + int64_array *component_phrase_memberships; // Index in component_phrases or -1 tokenized_string_t *tokenized_str; } address_parser_context_t; diff --git a/src/address_parser_train.c b/src/address_parser_train.c index 11b1cdee..220c5d66 100644 --- a/src/address_parser_train.c +++ b/src/address_parser_train.c @@ -11,12 +11,6 @@ #include "log/log.h" -// Training - -#define DEFAULT_ITERATIONS 5 - -#define MIN_VOCAB_COUNT 5 -#define MIN_PHRASE_COUNT 1 typedef struct phrase_stats { khash_t(int_uint32) *class_counts; @@ -25,6 +19,24 @@ typedef struct phrase_stats { 
KHASH_MAP_INIT_STR(phrase_stats, phrase_stats_t) +// Training + +#define DEFAULT_ITERATIONS 5 + +#define MIN_VOCAB_COUNT 5 +#define MIN_PHRASE_COUNT 1 + +static inline bool is_phrase_component(char *label) { + return (string_equals(label, ADDRESS_PARSER_LABEL_SUBURB) || + string_equals(label, ADDRESS_PARSER_LABEL_CITY_DISTRICT) || + string_equals(label, ADDRESS_PARSER_LABEL_CITY) || + string_equals(label, ADDRESS_PARSER_LABEL_STATE_DISTRICT) || + string_equals(label, ADDRESS_PARSER_LABEL_ISLAND) || + string_equals(label, ADDRESS_PARSER_LABEL_STATE) || + string_equals(label, ADDRESS_PARSER_LABEL_POSTAL_CODE) || + string_equals(label, ADDRESS_PARSER_LABEL_COUNTRY)); +} + address_parser_t *address_parser_init(char *filename) { if (filename == NULL) { log_error("Filename was NULL\n"); @@ -102,37 +114,43 @@ address_parser_t *address_parser_init(char *filename) { cstring_array_foreach(tokenized_str->strings, i, token, { token_t t = tokenized_str->tokens->a[i]; + char *label = cstring_array_get_string(data_set->labels, i); + if (label == NULL) { + continue; + } + char_array_clear(token_builder); - add_normalized_token(token_builder, str, t, ADDRESS_PARSER_NORMALIZE_TOKEN_OPTIONS); + + bool is_phrase_label = is_phrase_component(label); + + uint64_t normalize_token_options = is_phrase_label ? 
ADDRESS_PARSER_NORMALIZE_PHRASE_TOKEN_OPTIONS : ADDRESS_PARSER_NORMALIZE_TOKEN_OPTIONS; + + add_normalized_token(token_builder, str, t, normalize_token_options); if (token_builder->n == 0) { continue; } normalized = char_array_get_string(token_builder); - k = kh_get(str_uint32, vocab, normalized); - - if (k == kh_end(vocab)) { - key = strdup(normalized); - k = kh_put(str_uint32, vocab, key, &ret); - if (ret < 0) { - log_error("Error in kh_put in vocab\n"); - free(key); - goto exit_hashes_allocated; + if (!is_phrase_component(label)) { + k = kh_get(str_uint32, vocab, normalized); + + if (k == kh_end(vocab)) { + key = strdup(normalized); + k = kh_put(str_uint32, vocab, key, &ret); + if (ret < 0) { + log_error("Error in kh_put in vocab\n"); + free(key); + goto exit_hashes_allocated; + } + kh_value(vocab, k) = 1; + vocab_size++; + } else { + kh_value(vocab, k)++; } - kh_value(vocab, k) = 1; - vocab_size++; - } else { - kh_value(vocab, k)++; - } - char *label = cstring_array_get_string(data_set->labels, i); - if (label == NULL) { - continue; - } - - if (string_equals(label, "road") || string_equals(label, "house_number") || string_equals(label, "house")) { prev_label = NULL; + continue; } @@ -155,68 +173,103 @@ address_parser_t *address_parser_init(char *filename) { uint32_t component = 0; // Too many variations on these - if (string_equals(prev_label, "city")) { + if (string_equals(prev_label, ADDRESS_PARSER_LABEL_CITY)) { class_id = ADDRESS_PARSER_CITY; component = ADDRESS_COMPONENT_CITY; - } else if (string_equals(prev_label, "state")) { + } else if (string_equals(prev_label, ADDRESS_PARSER_LABEL_STATE)) { class_id = ADDRESS_PARSER_STATE; component = ADDRESS_COMPONENT_STATE; - } else if (string_equals(prev_label, "country")) { + } else if (string_equals(prev_label, ADDRESS_PARSER_LABEL_COUNTRY)) { class_id = ADDRESS_PARSER_COUNTRY; component = ADDRESS_COMPONENT_COUNTRY; - } else if (string_equals(prev_label, "state_district")) { + } else if (string_equals(prev_label, 
ADDRESS_PARSER_LABEL_STATE_DISTRICT)) { class_id = ADDRESS_PARSER_STATE_DISTRICT; component = ADDRESS_COMPONENT_STATE_DISTRICT; - } else if (string_equals(prev_label, "suburb")) { + } else if (string_equals(prev_label, ADDRESS_PARSER_LABEL_SUBURB)) { class_id = ADDRESS_PARSER_SUBURB; component = ADDRESS_COMPONENT_SUBURB; - } else if (string_equals(prev_label, "city_district")) { + } else if (string_equals(prev_label, ADDRESS_PARSER_LABEL_CITY_DISTRICT)) { class_id = ADDRESS_PARSER_CITY_DISTRICT; component = ADDRESS_COMPONENT_CITY_DISTRICT; - } else if (string_equals(prev_label, "postcode")) { + } else if (string_equals(prev_label, ADDRESS_PARSER_LABEL_ISLAND)) { + class_id = ADDRESS_PARSER_ISLAND; + component = ADDRESS_COMPONENT_ISLAND; + } else if (string_equals(prev_label, ADDRESS_PARSER_LABEL_POSTAL_CODE)) { class_id = ADDRESS_PARSER_POSTAL_CODE; component = ADDRESS_COMPONENT_POSTAL_CODE; + } else { + // Shouldn't happen but just in case + prev_label = NULL; + continue; } char *phrase = char_array_get_string(phrase_builder); - k = kh_get(phrase_stats, phrase_stats, phrase); + char *normalized_phrase = NULL; - if (k == kh_end(phrase_stats)) { - key = strdup(phrase); - ret = 0; - k = kh_put(phrase_stats, phrase_stats, key, &ret); - if (ret < 0) { - log_error("Error in kh_put in phrase_stats\n"); - free(key); + if (string_contains_hyphen(phrase)) { + char *phrase_copy = strdup(phrase); + if (phrase_copy == NULL) { goto exit_hashes_allocated; } - class_counts = kh_init(int_uint32); - - stats.class_counts = class_counts; - stats.parser_types.components = component; - stats.parser_types.most_common = 0; - - kh_value(phrase_stats, k) = stats; - } else { - stats = kh_value(phrase_stats, k); - class_counts = stats.class_counts; - stats.parser_types.components |= component; + normalized_phrase = normalize_string_utf8(phrase_copy, NORMALIZE_STRING_REPLACE_HYPHENS); } - k = kh_get(int_uint32, class_counts, (khint_t)class_id); + char *phrases[2]; + phrases[0] = phrase; + 
phrases[1] = normalized_phrase; + + for (int i = 0; i < sizeof(phrases) / sizeof(char *); i++) { + phrase = phrases[i]; + if (phrase == NULL) continue; + + k = kh_get(phrase_stats, phrase_stats, phrase); + + if (k == kh_end(phrase_stats)) { + key = strdup(phrase); + ret = 0; + k = kh_put(phrase_stats, phrase_stats, key, &ret); + if (ret < 0) { + log_error("Error in kh_put in phrase_stats\n"); + free(key); + if (normalized_phrase != NULL) { + free(normalized_phrase); + } + goto exit_hashes_allocated; + } + class_counts = kh_init(int_uint32); + + stats.class_counts = class_counts; + stats.parser_types.components = component; + stats.parser_types.most_common = 0; + + kh_value(phrase_stats, k) = stats; + } else { + stats = kh_value(phrase_stats, k); + class_counts = stats.class_counts; + stats.parser_types.components |= component; - if (k == kh_end(class_counts)) { - ret = 0; - k = kh_put(int_uint32, class_counts, class_id, &ret); - if (ret < 0) { - log_error("Error in kh_put in class_counts\n"); - goto exit_hashes_allocated; } - kh_value(class_counts, k) = 1; - } else { - kh_value(class_counts, k)++; + + k = kh_get(int_uint32, class_counts, (khint_t)class_id); + + if (k == kh_end(class_counts)) { + ret = 0; + k = kh_put(int_uint32, class_counts, class_id, &ret); + if (ret < 0) { + log_error("Error in kh_put in class_counts\n"); + goto exit_hashes_allocated; + } + kh_value(class_counts, k) = 1; + } else { + kh_value(class_counts, k)++; + } + + } + + if (normalized_phrase != NULL) { + free(normalized_phrase); } } @@ -228,7 +281,6 @@ address_parser_t *address_parser_init(char *filename) { char_array_cat(phrase_builder, normalized); - prev_label = label; }) @@ -430,10 +482,16 @@ bool address_parser_train(address_parser_t *self, char *filename, uint32_t num_i return true; } +typedef enum { + ADDRESS_PARSER_TRAIN_POSITIONAL_ARG, + ADDRESS_PARSER_TRAIN_ARG_ITERATIONS +} address_parser_train_keyword_arg_t; + +#define USAGE "Usage: ./address_parser_train filename output_dir 
[--iterations number]\n" int main(int argc, char **argv) { if (argc < 3) { - printf("Usage: ./address_parser_train filename output_dir\n"); + printf(USAGE); exit(EXIT_FAILURE); } @@ -441,8 +499,41 @@ int main(int argc, char **argv) { log_warn("shuf must be installed to train address parser effectively. If this is a production machine, please install shuf. No shuffling will be performed.\n"); #endif - char *filename = argv[1]; - char *output_dir = argv[2]; + int pos_args = 1; + + address_parser_train_keyword_arg_t kwarg = ADDRESS_PARSER_TRAIN_POSITIONAL_ARG; + + size_t num_iterations = DEFAULT_ITERATIONS; + size_t position = 0; + + char *filename = NULL; + char *output_dir = NULL; + + for (int i = pos_args; i < argc; i++) { + char *arg = argv[i]; + + if (string_equals(arg, "--iterations")) { + kwarg = ADDRESS_PARSER_TRAIN_ARG_ITERATIONS; + continue; + } + + if (kwarg == ADDRESS_PARSER_TRAIN_ARG_ITERATIONS) { + num_iterations = (size_t)atoi(arg); + } else if (position == 0) { + filename = arg; + position++; + } else if (position == 1) { + output_dir = arg; + position++; + } + kwarg = ADDRESS_PARSER_TRAIN_POSITIONAL_ARG; + + } + + if (filename == NULL || output_dir == NULL) { + printf(USAGE); + exit(EXIT_FAILURE); + } if (!address_dictionary_module_setup(NULL)) { log_error("Could not load address dictionaries\n"); @@ -475,7 +566,7 @@ int main(int argc, char **argv) { log_info("Finished initialization\n"); - if (!address_parser_train(parser, filename, DEFAULT_ITERATIONS)) { + if (!address_parser_train(parser, filename, num_iterations)) { log_error("Error in training\n"); exit(EXIT_FAILURE); }