[parsing] Adding a training-data-derived index of complete phrases, from suburb up to country. Only adding bias and word features for non-phrases, and using UNKNOWN_WORD and UNKNOWN_NUMERIC for infrequent tokens (those below the minimum vocab count threshold).
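
For context, this is the rare-token rule in miniature (a sketch only; the standalone helper below is hypothetical, while word_vocab_frequency, token_t, NUMERIC, IDEOGRAPHIC_NUMBER and the sentinel macros all appear in the diff):

// Sketch: collapse out-of-vocabulary tokens to a sentinel, with numbers
// getting their own sentinel so unseen house numbers and postcodes
// still share a single feature.
static const char *feature_word(address_parser_t *parser, char *word, token_t token) {
    if (word_vocab_frequency(parser, word) > 0) {
        return word; // token survived the minimum vocab count cutoff
    }
    return (token.type != NUMERIC && token.type != IDEOGRAPHIC_NUMBER)
        ? UNKNOWN_WORD : UNKNOWN_NUMERIC;
}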

Al
2015-12-05 14:34:06 -05:00
parent f41158b8b3
commit 24208c209f
3 changed files with 411 additions and 73 deletions

View File

@@ -8,8 +8,10 @@
#define ADDRESS_PARSER_MODEL_FILENAME "address_parser.dat"
#define ADDRESS_PARSER_VOCAB_FILENAME "address_parser_vocab.trie"
#define ADDRESS_PARSER_PHRASE_FILENAME "address_parser_phrases.trie"
#define UNKNOWN_WORD "UNKNOWN"
#define UNKNOWN_NUMERIC "UNKNOWN_NUMERIC"
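// These sentinels stand in, at feature-extraction time, for tokens that fell
// below the training-time minimum vocab count (MIN_VOCAB_COUNT below).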
static address_parser_t *parser = NULL;
@@ -47,6 +49,15 @@ bool address_parser_save(address_parser_t *self, char *output_dir) {
return false;
}
char_array_clear(path);
char_array_add_joined(path, PATH_SEPARATOR, true, 2, output_dir, ADDRESS_PARSER_PHRASE_FILENAME);
char *phrases_path = char_array_get_string(path);
if (!trie_save(self->phrase_types, phrases_path)) {
return false;
}
char_array_destroy(path);
return true;
@@ -90,6 +101,22 @@ bool address_parser_load(char *dir) {
parser->vocab = vocab;
char_array_clear(path);
char_array_add_joined(path, PATH_SEPARATOR, true, 2, dir, ADDRESS_PARSER_PHRASE_FILENAME);
char *phrases_path = char_array_get_string(path);
trie_t *phrase_types = trie_load(phrases_path);
if (phrase_types == NULL) {
address_parser_destroy(parser);
char_array_destroy(path);
return false;
}
parser->phrase_types = phrase_types;
char_array_destroy(path);
return true;
}
@@ -105,6 +132,10 @@ void address_parser_destroy(address_parser_t *self) {
trie_destroy(self->vocab);
}
if (self->phrase_types != NULL) {
trie_destroy(self->phrase_types);
}
free(self);
}
@@ -162,6 +193,14 @@ void address_parser_context_destroy(address_parser_context_t *self) {
int64_array_destroy(self->geodb_phrase_memberships);
}
if (self->component_phrases != NULL) {
phrase_array_destroy(self->component_phrases);
}
if (self->component_phrase_memberships != NULL) {
int64_array_destroy(self->component_phrase_memberships);
}
free(self);
}
@@ -218,6 +257,16 @@ address_parser_context_t *address_parser_context_new(void) {
goto exit_address_parser_context_allocated;
}
context->component_phrases = phrase_array_new();
if (context->component_phrases == NULL) {
goto exit_address_parser_context_allocated;
}
context->component_phrase_memberships = int64_array_new();
if (context->component_phrase_memberships == NULL) {
goto exit_address_parser_context_allocated;
}
return context;
exit_address_parser_context_allocated:
@@ -225,7 +274,7 @@ exit_address_parser_context_allocated:
return NULL;
}
-void address_parser_context_fill(address_parser_context_t *context, tokenized_string_t *tokenized_str, char *language, char *country) {
+void address_parser_context_fill(address_parser_context_t *context, address_parser_t *parser, tokenized_string_t *tokenized_str, char *language, char *country) {
int64_t i, j;
uint32_t token_index;
@@ -301,6 +350,34 @@ void address_parser_context_fill(address_parser_context_t *context, tokenized_st
int64_array_push(geodb_phrase_memberships, NULL_PHRASE_MEMBERSHIP);
}
phrase_array_clear(context->component_phrases);
int64_array_clear(context->component_phrase_memberships);
i = 0;
phrase_array *component_phrases = context->component_phrases;
int64_array *component_phrase_memberships = context->component_phrase_memberships;
if (trie_search_tokens_with_phrases(parser->phrase_types, str, tokens, &component_phrases)) {
for (j = 0; j < component_phrases->n; j++) {
phrase = component_phrases->a[j];
for (; i < phrase.start; i++) {
log_debug("token i=%lld, null component phrase membership\n", i);
int64_array_push(component_phrase_memberships, NULL_PHRASE_MEMBERSHIP);
}
for (i = phrase.start; i < phrase.start + phrase.len; i++) {
log_debug("token i=%lld, component phrase membership=%lld\n", i, j);
int64_array_push(component_phrase_memberships, j);
}
}
}
for (; i < tokens->n; i++) {
log_debug("token i=%lld, null component phrase membership\n", i);
int64_array_push(component_phrase_memberships, NULL_PHRASE_MEMBERSHIP);
}
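// component_phrase_memberships is now parallel to the token array: entry i
// holds the index into component_phrases of the phrase covering token i,
// or NULL_PHRASE_MEMBERSHIP if token i is not part of any known phrase.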
}
@@ -384,7 +461,6 @@ static inline void add_phrase_features(cstring_array *features, uint32_t phrase_
}
}
/*
address_parser_features
-----------------------
@@ -426,15 +502,14 @@ bool address_parser_features(void *self, void *ctx, tokenized_string_t *tokenize
int64_array *address_phrase_memberships = context->address_phrase_memberships;
phrase_array *geodb_phrases = context->geodb_phrases;
int64_array *geodb_phrase_memberships = context->geodb_phrase_memberships;
phrase_array *component_phrases = context->component_phrases;
int64_array *component_phrase_memberships = context->component_phrase_memberships;
cstring_array *normalized = context->normalized;
uint32_array *separators = context->separators;
cstring_array_clear(features);
-// Bias unit, acts as an intercept
-feature_array_add(features, 1, "bias");
char *original_word = tokenized_string_get_token(tokenized, i);
token_t token = tokenized->tokens->a[i];
@@ -449,7 +524,6 @@ bool address_parser_features(void *self, void *ctx, tokenized_string_t *tokenize
}
size_t word_len = strlen(word);
-char *current_word = word;
log_debug("word=%s\n", word);
@@ -459,6 +533,7 @@ bool address_parser_features(void *self, void *ctx, tokenized_string_t *tokenize
char *phrase_string = NULL;
char *geo_phrase_string = NULL;
char *component_phrase_string = NULL;
int64_t address_phrase_index = address_phrase_memberships->a[i];
@@ -519,40 +594,96 @@ bool address_parser_features(void *self, void *ctx, tokenized_string_t *tokenize
}
}
bool add_word_feature = true;
int64_t component_phrase_index = component_phrase_memberships->a[i];
phrase = NULL_PHRASE;
address_parser_types_t types;
// Component phrases
if (component_phrase_index != NULL_PHRASE_MEMBERSHIP) {
phrase = component_phrases->a[component_phrase_index];
component_phrase_string = get_phrase_string(tokenized, phrase_tokens, phrase);
types.value = phrase.data;
uint32_t component_phrase_types = types.components;
uint32_t most_common = types.most_common;
if (last_index >= (ssize_t)phrase.start - 1 || next_index <= (ssize_t)phrase.start + phrase.len - 1) {
last_index = (ssize_t)phrase.start - 1;
next_index = (ssize_t)phrase.start + phrase.len;
}
if (component_phrase_string != NULL && component_phrase_types ^ ADDRESS_COMPONENT_POSTAL_CODE) {
feature_array_add(features, 2, "phrase", component_phrase_string);
add_word_feature = false;
}
if (component_phrase_types > 0) {
add_phrase_features(features, component_phrase_types, ADDRESS_COMPONENT_SUBURB, "suburb", component_phrase_string, prev2, prev);
add_phrase_features(features, component_phrase_types, ADDRESS_COMPONENT_CITY, "city", component_phrase_string, prev2, prev);
add_phrase_features(features, component_phrase_types, ADDRESS_COMPONENT_CITY_DISTRICT, "city_district", component_phrase_string, prev2, prev);
add_phrase_features(features, component_phrase_types, ADDRESS_COMPONENT_STATE_DISTRICT, "state_district", component_phrase_string, prev2, prev);
add_phrase_features(features, component_phrase_types, ADDRESS_COMPONENT_STATE, "state", component_phrase_string, prev2, prev);
add_phrase_features(features, component_phrase_types, ADDRESS_COMPONENT_POSTAL_CODE, "postal_code", component_phrase_string, prev2, prev);
add_phrase_features(features, component_phrase_types, ADDRESS_COMPONENT_COUNTRY, "country", component_phrase_string, prev2, prev);
}
if (most_common == ADDRESS_PARSER_CITY) {
feature_array_add(features, 2, "commonly city", component_phrase_string);
} else if (most_common == ADDRESS_PARSER_STATE) {
feature_array_add(features, 2, "commonly state", component_phrase_string);
} else if (most_common == ADDRESS_PARSER_COUNTRY) {
feature_array_add(features, 2, "commonly country", component_phrase_string);
} else if (most_common == ADDRESS_PARSER_STATE_DISTRICT) {
feature_array_add(features, 2, "commonly state_district", component_phrase_string);
} else if (most_common == ADDRESS_PARSER_SUBURB) {
feature_array_add(features, 2, "commonly suburb", component_phrase_string);
} else if (most_common == ADDRESS_PARSER_CITY_DISTRICT) {
feature_array_add(features, 2, "commonly city_district", component_phrase_string);
} else if (most_common == ADDRESS_PARSER_POSTAL_CODE) {
feature_array_add(features, 2, "commonly postal_code", component_phrase_string);
}
}
int64_t geodb_phrase_index = geodb_phrase_memberships->a[i];
phrase = NULL_PHRASE;
geodb_value_t geo;
// GeoDB phrases
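// Only consulted when the token is not already covered by a training-data
// component phrase (see the changed condition below).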
-if (geodb_phrase_index != NULL_PHRASE_MEMBERSHIP) {
+if (component_phrase_index == NULL_PHRASE_MEMBERSHIP && geodb_phrase_index != NULL_PHRASE_MEMBERSHIP) {
phrase = geodb_phrases->a[geodb_phrase_index];
geo_phrase_string = get_phrase_string(tokenized, phrase_tokens, phrase);
geo.value = phrase.data;
uint32_t geodb_phrase_types = geo.components;
-if (last_index <= (ssize_t)phrase.start - 1 && next_index >= (ssize_t)phrase.start + phrase.len - 1) {
+if (last_index >= (ssize_t)phrase.start - 1 || next_index <= (ssize_t)phrase.start + phrase.len) {
last_index = (ssize_t)phrase.start - 1;
next_index = (ssize_t)phrase.start + phrase.len;
if (geo_phrase_string != NULL && geodb_phrase_types ^ ADDRESS_POSTAL_CODE) {
word = geo_phrase_string;
}
}
if (geo_phrase_string != NULL && geodb_phrase_types ^ ADDRESS_POSTAL_CODE) {
feature_array_add(features, 2, "phrase", geo_phrase_string);
add_word_feature = false;
}
if (geodb_phrase_types ^ ADDRESS_ANY) {
add_phrase_features(features, geodb_phrase_types, ADDRESS_LOCALITY, "gn city", geo_phrase_string, prev2, prev);
add_phrase_features(features, geodb_phrase_types, ADDRESS_ADMIN1, "gn admin1", geo_phrase_string, prev2, prev);
add_phrase_features(features, geodb_phrase_types, ADDRESS_ADMIN2, "gn admin2", geo_phrase_string, prev2, prev);
add_phrase_features(features, geodb_phrase_types, ADDRESS_ADMIN3, "gn admin3", geo_phrase_string, prev2, prev);
add_phrase_features(features, geodb_phrase_types, ADDRESS_ADMIN4, "gn admin4", geo_phrase_string, prev2, prev);
add_phrase_features(features, geodb_phrase_types, ADDRESS_ADMIN_OTHER, "gn admin other", geo_phrase_string, prev2, prev);
add_phrase_features(features, geodb_phrase_types, ADDRESS_NEIGHBORHOOD, "gn neighborhood", geo_phrase_string, prev2, prev);
add_phrase_features(features, geodb_phrase_types, ADDRESS_LOCALITY, "city", geo_phrase_string, prev2, prev);
add_phrase_features(features, geodb_phrase_types, ADDRESS_ADMIN1, "admin1", geo_phrase_string, prev2, prev);
add_phrase_features(features, geodb_phrase_types, ADDRESS_ADMIN2, "admin2", geo_phrase_string, prev2, prev);
add_phrase_features(features, geodb_phrase_types, ADDRESS_ADMIN3, "admin3", geo_phrase_string, prev2, prev);
add_phrase_features(features, geodb_phrase_types, ADDRESS_ADMIN4, "admin4", geo_phrase_string, prev2, prev);
add_phrase_features(features, geodb_phrase_types, ADDRESS_ADMIN_OTHER, "admin other", geo_phrase_string, prev2, prev);
add_phrase_features(features, geodb_phrase_types, ADDRESS_NEIGHBORHOOD, "neighborhood", geo_phrase_string, prev2, prev);
add_phrase_features(features, geodb_phrase_types, ADDRESS_COUNTRY, "country", geo_phrase_string, prev2, prev);
add_phrase_features(features, geodb_phrase_types, ADDRESS_POSTAL_CODE, "postal code", geo_phrase_string, prev2, prev);
add_phrase_features(features, geodb_phrase_types, ADDRESS_COUNTRY, "gn country", geo_phrase_string, prev2, prev);
add_phrase_features(features, geodb_phrase_types, ADDRESS_POSTAL_CODE, "gn postal code", geo_phrase_string, prev2, prev);
}
@@ -560,24 +691,27 @@ bool address_parser_features(void *self, void *ctx, tokenized_string_t *tokenize
uint32_t word_freq = word_vocab_frequency(parser, word);
-if (phrase_string == NULL && geo_phrase_string == NULL) {
+if (add_word_feature) {
+// Bias unit, acts as an intercept
+feature_array_add(features, 1, "bias");
if (word_freq > 0) {
// The individual word
feature_array_add(features, 2, "word", word);
} else {
log_debug("word not in vocab: %s\n", original_word);
-word = UNKNOWN_WORD;
+word = (token.type != NUMERIC && token.type != IDEOGRAPHIC_NUMBER) ? UNKNOWN_WORD : UNKNOWN_NUMERIC;
}
}
-if (prev != NULL) {
+if (prev != NULL && last_index == i - 1) {
// Previous tag and current word
feature_array_add(features, 3, "i-1 tag+word", prev, current_word);
feature_array_add(features, 3, "i-1 tag+word", prev, word);
feature_array_add(features, 2, "i-1 tag", prev);
if (prev2 != NULL) {
// Previous two tags and current word
-feature_array_add(features, 4, "i-2 tag+i-1 tag+word", prev2, prev, current_word);
+feature_array_add(features, 4, "i-2 tag+i-1 tag+word", prev2, prev, word);
feature_array_add(features, 3, "i-2 tag+i-1 tag", prev2, prev);
}
}
@@ -587,15 +721,14 @@ bool address_parser_features(void *self, void *ctx, tokenized_string_t *tokenize
uint32_t prev_word_freq = word_vocab_frequency(parser, prev_word);
if (prev_word_freq == 0) {
-prev_word = UNKNOWN_WORD;
+token_t prev_token = tokenized->tokens->a[last_index];
+prev_word = (prev_token.type != NUMERIC && prev_token.type != IDEOGRAPHIC_NUMBER) ? UNKNOWN_WORD : UNKNOWN_NUMERIC;
}
// Previous word
feature_array_add(features, 2, "i-1 word", prev_word);
// Previous tag + previous word
-if (last_index == i - 1) {
-feature_array_add(features, 3, "i-1 tag+i-1 word", prev, prev_word);
-}
+feature_array_add(features, 3, "i-1 tag+i-1 word", prev, prev_word);
// Previous word and current word
feature_array_add(features, 3, "i-1 word+word", prev_word, word);
}
@@ -607,15 +740,34 @@ bool address_parser_features(void *self, void *ctx, tokenized_string_t *tokenize
uint32_t next_word_freq = word_vocab_frequency(parser, next_word);
if (next_word_freq == 0) {
-next_word = UNKNOWN_WORD;
+token_t next_token = tokenized->tokens->a[next_index];
+next_word = (next_token.type != NUMERIC && next_token.type != IDEOGRAPHIC_NUMBER) ? UNKNOWN_WORD : UNKNOWN_NUMERIC;
}
// Next word e.g. if the current word is unknown and the next word is "street"
feature_array_add(features, 2, "i+1 word", next_word);
// Current word and next word
feature_array_add(features, 3, "word+i+1 word", word, next_word);
}
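// Debug dump of the accumulated feature set; compiled in but unreachable
// (if (0)) unless PRINT_FEATURES is defined.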
#ifndef PRINT_FEATURES
if (0) {
#endif
uint32_t idx;
char *feature;
printf("{");
cstring_array_foreach(features, idx, feature, {
printf(" %s, ", feature);
})
printf("}\n");
#ifndef PRINT_FEATURES
}
#endif
return true;
}
@@ -681,7 +833,7 @@ address_parser_response_t *address_parser_parse(char *address, char *language, c
uint32_array_push(context->separators, ADDRESS_SEPARATOR_NONE);
}
-address_parser_context_fill(context, tokenized_str, language, country);
+address_parser_context_fill(context, parser, tokenized_str, language, country);
cstring_array *token_labels = cstring_array_new_size(tokens->n);

View File

@@ -10,7 +10,6 @@
#include "log/log.h"
typedef struct address_parser_test_results {
size_t num_errors;
size_t num_predictions;
@@ -67,7 +66,7 @@ bool address_parser_test(address_parser_t *parser, char *filename, address_parse
}
char *country = char_array_get_string(data_set->country);
-address_parser_context_fill(context, data_set->tokenized_str, language, country);
+address_parser_context_fill(context, parser, data_set->tokenized_str, language, country);
cstring_array *token_labels = cstring_array_new_size(data_set->tokenized_str->strings->str->n);
@@ -90,6 +89,7 @@ bool address_parser_test(address_parser_t *parser, char *filename, address_parse
uint32_t truth_index = get_class_index(parser, truth);
result->confusion[predicted_index * num_classes + truth_index]++;
}
result->num_predictions++;

View File

@@ -15,6 +15,14 @@
#define DEFAULT_ITERATIONS 5
#define MIN_VOCAB_COUNT 5
#define MIN_PHRASE_COUNT 1
typedef struct phrase_stats {
khash_t(int_uint32) *class_counts;
address_parser_types_t parser_types;
} phrase_stats_t;
KHASH_MAP_INIT_STR(phrase_stats, phrase_stats_t)
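// phrase_stats maps a normalized phrase string to its per-class frequency
// counts plus the bitset of address components it has been labeled as.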
address_parser_t *address_parser_init(char *filename) {
if (filename == NULL) {
@@ -37,21 +45,43 @@ address_parser_t *address_parser_init(char *filename) {
}
khash_t(str_uint32) *vocab = kh_init(str_uint32);
if (vocab == NULL) {
log_error("Could not allocate vocab\n");
return NULL;
}
khash_t(phrase_stats) *phrase_stats = kh_init(phrase_stats);
if (phrase_stats == NULL) {
log_error("Could not allocate phrase_stats\n");
return NULL;
}
khash_t(str_uint32) *phrase_types = kh_init(str_uint32);
if (phrase_types == NULL) {
log_error("Could not allocate phrase_types\n");
return NULL;
}
khiter_t k;
char *str;
phrase_stats_t stats;
khash_t(int_uint32) *class_counts;
uint32_t vocab_size = 0;
size_t examples = 0;
const char *word;
uint32_t i;
-char *token;
+const char *token;
char *normalized;
uint32_t count;
-char_array *token_array = char_array_new();
+char *key;
+int ret = 0;
+char_array *token_builder = char_array_new();
+char_array *phrase_builder = char_array_new();
while (address_parser_data_set_next(data_set)) {
tokenized_string_t *tokenized_str = data_set->tokenized_str;
@@ -59,38 +89,35 @@ address_parser_t *address_parser_init(char *filename) {
if (tokenized_str == NULL) {
log_error("tokenized str is NULL\n");
kh_destroy(str_uint32, vocab);
-return false;
+return NULL;
}
str = tokenized_str->str;
char *prev_label = NULL;
size_t num_strings = cstring_array_num_strings(tokenized_str->strings);
cstring_array_foreach(tokenized_str->strings, i, token, {
token_t t = tokenized_str->tokens->a[i];
-char_array_clear(token_array);
-add_normalized_token(token_array, str, t, ADDRESS_PARSER_NORMALIZE_TOKEN_OPTIONS);
-if (token_array->n == 0) {
+char_array_clear(token_builder);
+add_normalized_token(token_builder, str, t, ADDRESS_PARSER_NORMALIZE_TOKEN_OPTIONS);
+if (token_builder->n == 0) {
continue;
}
-normalized = char_array_get_string(token_array);
+normalized = char_array_get_string(token_builder);
k = kh_get(str_uint32, vocab, normalized);
if (k == kh_end(vocab)) {
-int ret;
-char *key = strdup(normalized);
+key = strdup(normalized);
k = kh_put(str_uint32, vocab, key, &ret);
if (ret < 0) {
log_error("Error in kh_put\n");
log_error("Error in kh_put in vocab\n");
free(key);
tokenized_string_destroy(tokenized_str);
-kh_foreach(vocab, word, count, {
-free((char *)word);
-})
-kh_destroy(str_uint32, vocab);
-char_array_destroy(token_array);
-return false;
+goto exit_hashes_allocated;
}
kh_value(vocab, k) = 1;
vocab_size++;
@@ -98,6 +125,111 @@ address_parser_t *address_parser_init(char *filename) {
kh_value(vocab, k)++;
}
char *label = cstring_array_get_string(data_set->labels, i);
if (label == NULL) {
continue;
}
if (string_equals(label, "road") || string_equals(label, "house_number") || string_equals(label, "house")) {
prev_label = NULL;
continue;
}
if (prev_label == NULL || !string_equals(label, prev_label) || i == num_strings - 1) {
if (i == num_strings - 1) {
if (!string_equals(label, prev_label)) {
char_array_clear(phrase_builder);
} else if (prev_label != NULL) {
char_array_cat(phrase_builder, " ");
}
char_array_cat(phrase_builder, normalized);
prev_label = label;
}
// End of phrase, add to hashtable
if (prev_label != NULL) {
uint32_t class_id;
uint32_t component = 0;
// Too many variations on these
if (string_equals(prev_label, "city")) {
class_id = ADDRESS_PARSER_CITY;
component = ADDRESS_COMPONENT_CITY;
} else if (string_equals(prev_label, "state")) {
class_id = ADDRESS_PARSER_STATE;
component = ADDRESS_COMPONENT_STATE;
} else if (string_equals(prev_label, "country")) {
class_id = ADDRESS_PARSER_COUNTRY;
component = ADDRESS_COMPONENT_COUNTRY;
} else if (string_equals(prev_label, "state_district")) {
class_id = ADDRESS_PARSER_STATE_DISTRICT;
component = ADDRESS_COMPONENT_STATE_DISTRICT;
} else if (string_equals(prev_label, "suburb")) {
class_id = ADDRESS_PARSER_SUBURB;
component = ADDRESS_COMPONENT_SUBURB;
} else if (string_equals(prev_label, "city_district")) {
class_id = ADDRESS_PARSER_CITY_DISTRICT;
component = ADDRESS_COMPONENT_CITY_DISTRICT;
} else if (string_equals(prev_label, "postcode")) {
class_id = ADDRESS_PARSER_POSTAL_CODE;
component = ADDRESS_COMPONENT_POSTAL_CODE;
}
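// NB: this assumes every label reaching this point is one of the seven
// classes above; road, house_number and house were skipped earlier, so
// class_id is always assigned before it is used.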
char *phrase = char_array_get_string(phrase_builder);
k = kh_get(phrase_stats, phrase_stats, phrase);
if (k == kh_end(phrase_stats)) {
key = strdup(phrase);
ret = 0;
k = kh_put(phrase_stats, phrase_stats, key, &ret);
if (ret < 0) {
log_error("Error in kh_put in phrase_stats\n");
free(key);
goto exit_hashes_allocated;
}
class_counts = kh_init(int_uint32);
stats.class_counts = class_counts;
stats.parser_types.components = component;
stats.parser_types.most_common = 0;
kh_value(phrase_stats, k) = stats;
} else {
stats = kh_value(phrase_stats, k);
class_counts = stats.class_counts;
stats.parser_types.components |= component;
}
k = kh_get(int_uint32, class_counts, (khint_t)class_id);
if (k == kh_end(class_counts)) {
ret = 0;
k = kh_put(int_uint32, class_counts, class_id, &ret);
if (ret < 0) {
log_error("Error in kh_put in class_counts\n");
goto exit_hashes_allocated;
}
kh_value(class_counts, k) = 1;
} else {
kh_value(class_counts, k)++;
}
}
char_array_clear(phrase_builder);
} else if (prev_label != NULL) {
char_array_cat(phrase_builder, " ");
}
char_array_cat(phrase_builder, normalized);
prev_label = label;
})
tokenized_string_destroy(tokenized_str);
@@ -110,38 +242,92 @@ address_parser_t *address_parser_init(char *filename) {
log_debug("Done with vocab, total size=%d\n", vocab_size);
for (k = kh_begin(vocab); k != kh_end(vocab); ++k) {
-char *word = (char *)kh_key(vocab, k);
+token = (char *)kh_key(vocab, k);
if (!kh_exist(vocab, k)) {
continue;
}
uint32_t count = kh_value(vocab, k);
if (count < MIN_VOCAB_COUNT) {
kh_del(str_uint32, vocab, k);
-free(word);
+free((char *)token);
}
}
kh_foreach(phrase_stats, token, stats, {
class_counts = stats.class_counts;
int most_common = -1;
uint32_t max_count = 0;
uint32_t total = 0;
for (int i = 0; i < NUM_ADDRESS_PARSER_TYPES; i++) {
k = kh_get(int_uint32, class_counts, (khint_t)i);
if (k != kh_end(class_counts)) {
count = kh_value(class_counts, k);
if (count > max_count) {
max_count = count;
most_common = i;
}
total += count;
}
}
if (most_common > -1 && total >= MIN_PHRASE_COUNT) {
stats.parser_types.most_common = (uint32_t)most_common;
ret = 0;
char *key = strdup(token);
k = kh_put(str_uint32, phrase_types, key, &ret);
if (ret < 0) {
log_error("Error on kh_put in phrase_types\n");
free(key);
goto exit_hashes_allocated;
}
kh_value(phrase_types, k) = stats.parser_types.value;
}
})
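// Each phrase kept above packs its component bitset and most-common class
// into a single uint32 (address_parser_types_t.value), which is consumed
// when building the phrase_types trie below.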
parser->vocab = trie_new_from_hash(vocab);
-for (k = kh_begin(vocab); k != kh_end(vocab); ++k) {
-if (!kh_exist(vocab, k)) {
-continue;
-}
-char *word = (char *)kh_key(vocab, k);
-free(word);
-}
-kh_destroy(str_uint32, vocab);
-char_array_destroy(token_array);
-address_parser_data_set_destroy(data_set);
if (parser->vocab == NULL) {
log_error("Error initializing vocabulary\n");
address_parser_destroy(parser);
-return NULL;
+parser = NULL;
+goto exit_hashes_allocated;
}
parser->phrase_types = trie_new_from_hash(phrase_types);
if (parser->phrase_types == NULL) {
log_error("Error converting phrase_types to trie\n");
address_parser_destroy(parser);
parser = NULL;
goto exit_hashes_allocated;
}
exit_hashes_allocated:
// Free memory for hashtables, etc.
char_array_destroy(token_builder);
char_array_destroy(phrase_builder);
address_parser_data_set_destroy(data_set);
kh_foreach(vocab, token, count, {
free((char *)token);
})
kh_destroy(str_uint32, vocab);
kh_foreach(phrase_stats, token, stats, {
kh_destroy(int_uint32, stats.class_counts);
free((char *)token);
})
kh_destroy(phrase_stats, phrase_stats);
kh_foreach(phrase_types, token, count, {
free((char *)token);
})
kh_destroy(str_uint32, phrase_types);
return parser;
}
@@ -174,7 +360,7 @@ bool address_parser_train_epoch(address_parser_t *self, averaged_perceptron_trai
}
char *country = char_array_get_string(data_set->country);
-address_parser_context_fill(context, data_set->tokenized_str, language, country);
+address_parser_context_fill(context, self, data_set->tokenized_str, language, country);
bool example_success = averaged_perceptron_trainer_train_example(trainer, self, context, context->features, &address_parser_features, data_set->tokenized_str, data_set->labels);