diff --git a/src/averaged_perceptron_trainer.c b/src/averaged_perceptron_trainer.c
index 3039156a..ff96244a 100644
--- a/src/averaged_perceptron_trainer.c
+++ b/src/averaged_perceptron_trainer.c
@@ -1,16 +1,38 @@
 #include "averaged_perceptron_trainer.h"
+#include "klib/ksort.h"
+
+// Placeholder tag history handed to the feature function before any tag has been predicted
+#define START "START"
+#define START2 "START2"
+
+KSORT_INIT_STR
 
 void averaged_perceptron_trainer_destroy(averaged_perceptron_trainer_t *self) {
     if (self == NULL) return;
 
+    const char *key;
+    uint32_t id;
+
     if (self->features != NULL) {
-        trie_destroy(self->features);
+        // Keys are the strdup'd copies made in get_feature_id, free them before the table
+        kh_foreach(self->features, key, id, {
+            free((char *)key);
+        })
+        kh_destroy(str_uint32, self->features);
     }
 
     if (self->classes != NULL) {
+        // Keys are the strdup'd copies made in get_class_id
+        kh_foreach(self->classes, key, id, {
+            free((char *)key);
+        })
         kh_destroy(str_uint32, self->classes);
     }
 
+    if (self->class_strings != NULL) {
+        cstring_array_destroy(self->class_strings);
+    }
+
     uint32_t feature_id;
     khash_t(class_weights) *weights;
 
@@ -33,6 +55,11 @@ void averaged_perceptron_trainer_destroy(averaged_perceptron_trainer_t *self) {
 bool averaged_perceptron_trainer_get_class_id(averaged_perceptron_trainer_t *self, char *class_name, uint32_t *class_id, bool add_if_missing) {
     khiter_t k;
 
+    if (class_name == NULL) {
+        log_error("class_name was NULL\n");
+        return false;
+    }
+
     khash_t(str_uint32) *classes = self->classes;
 
     k = kh_get(str_uint32, classes, class_name);
@@ -40,9 +67,18 @@ bool averaged_perceptron_trainer_get_class_id(averaged_perceptron_trainer_t *sel
         *class_id = kh_value(classes, k);
         return true;
     } else if (add_if_missing) {
-        uint32_t new_id = kh_size(classes);
+        uint32_t new_id = (uint32_t)kh_size(classes);
         int ret;
-        k = kh_put(str_uint32, classes, class_name, &ret);
+        // The table keeps the key pointer, so store a private copy of the caller's string
+        char *key = strdup(class_name);
+        if (key == NULL) {
+            return false;
+        }
+        k = kh_put(str_uint32, classes, key, &ret);
+        if (ret < 0) {
+            free(key);
+            return false;
+        }
         kh_value(classes, k) = new_id;
         *class_id = new_id;
 
@@ -54,23 +90,39 @@ bool averaged_perceptron_trainer_get_class_id(averaged_perceptron_trainer_t *sel
 }
 
 bool averaged_perceptron_trainer_get_feature_id(averaged_perceptron_trainer_t *self, char *feature, uint32_t *feature_id, bool add_if_missing) {
-    trie_t *features = self->features;
+    khiter_t k;
 
-    bool in_trie = trie_get_data(features, feature, feature_id);
-
-    if (add_if_missing && !in_trie) {
-        uint32_t new_id = features->num_keys;
-        *feature_id = new_id;
-        if (!trie_add(features, feature, new_id)) {
-            return false;
-        }
-        self->num_features++;
-        return true;
-    } else if (in_trie) {
-        return true;
+    if (feature == NULL) {
+        log_error("feature was NULL\n");
+        return false;
     }
 
+    khash_t(str_uint32) *features = self->features;
+
+    k = kh_get(str_uint32, features, feature);
+    if (k != kh_end(features)) {
+        *feature_id = kh_value(features, k);
+        return true;
+    } else if (add_if_missing) {
+        uint32_t new_id = (uint32_t)kh_size(features);
+        int ret;
+        // The table keeps the key pointer, so store a private copy of the caller's string
+        char *key = strdup(feature);
+        if (key == NULL) {
+            return false;
+        }
+        k = kh_put(str_uint32, features, key, &ret);
+        if (ret < 0) {
+            free(key);
+            return false;
+        }
+        kh_value(features, k) = new_id;
+        *feature_id = new_id;
+
+        self->num_features++;
+        return true;
+    }
     return false;
 }
 
 averaged_perceptron_t *averaged_perceptron_trainer_finalize(averaged_perceptron_trainer_t *self) {
@@ -108,6 +160,60 @@ averaged_perceptron_t *averaged_perceptron_trainer_finalize(averaged_perceptron_
 
     perceptron->weights = averaged_weights;
 
+    trie_t *features = trie_new();
+    if (features == NULL) {
+        log_error("Error creating trie\n");
+        averaged_perceptron_destroy(perceptron);
+        return NULL;
+    }
+
+    const char *key;
+    uint32_t feature_id;
+
+    string_array *feature_keys = string_array_new_size(kh_size(self->features));
+    if (feature_keys == NULL) {
+        log_error("Error allocating feature keys\n");
+        trie_destroy(features);
+        averaged_perceptron_destroy(perceptron);
+        return NULL;
+    }
+
+    kh_foreach(self->features, key, feature_id, {
+        string_array_push(feature_keys, (char *)key);
+    })
+
+    // Keys are added to the trie in sorted order
+    ks_introsort(str, feature_keys->n, (const char **)feature_keys->a);
+
+    khiter_t k;
+
+    for (size_t i = 0; i < feature_keys->n; i++) {
+        char *str = feature_keys->a[i];
+        k = kh_get(str_uint32, self->features, str);
+        if (k == kh_end(self->features)) {
+            log_error("Key not found\n");
+            string_array_destroy(feature_keys);
+            trie_destroy(features);
+            averaged_perceptron_destroy(perceptron);
+            return NULL;
+        }
+
+        feature_id = kh_value(self->features, k);
+
+        if (!trie_add(features, str, feature_id)) {
+            log_error("Error adding to trie\n");
+            string_array_destroy(feature_keys);
+            trie_destroy(features);
+            averaged_perceptron_destroy(perceptron);
+            return NULL;
+        }
+    }
+
+    // The key strings stay owned by self->features and are freed in trainer_destroy
+    string_array_destroy(feature_keys);
+
+    perceptron->features = features;
+
     perceptron->num_features = self->num_features;
     perceptron->num_classes = self->num_classes;
 
@@ -117,9 +223,6 @@ averaged_perceptron_t *averaged_perceptron_trainer_finalize(averaged_perceptron_
     perceptron->classes = self->class_strings;
     self->class_strings = NULL;
 
-    perceptron->features = self->features;
-    self->features = NULL;
-
     averaged_perceptron_trainer_destroy(self);
 
     return perceptron;
@@ -276,20 +379,65 @@ bool averaged_perceptron_trainer_update_counts(averaged_perceptron_trainer_t *se
     return true;
 }
 
-bool averaged_perceptron_trainer_train_example(averaged_perceptron_trainer_t *self, cstring_array *features, char *label) {
-    uint32_t truth;
-    bool add_if_missing = true;
+bool averaged_perceptron_trainer_train_example(averaged_perceptron_trainer_t *self, void *tagger, cstring_array *features, ap_tagger_feature_function feature_function, tokenized_string_t *tokenized, cstring_array *labels) {
+    // Keep two tags of history in training
+    char *prev = START;
+    char *prev2 = START2;
 
-    if (!averaged_perceptron_trainer_get_class_id(self, label, &truth, add_if_missing)) {
+    uint32_t prev_id = 0;
+    uint32_t prev2_id = 0;
+
+    size_t num_tokens = tokenized->tokens->n;
+    if (cstring_array_num_strings(labels) != num_tokens) {
         return false;
     }
 
-    uint32_t guess = averaged_perceptron_trainer_predict(self, features);
+    bool add_if_missing = true;
+
+    for (uint32_t i = 0; i < num_tokens; i++) {
+        cstring_array_clear(features);
+
+        char *label = cstring_array_get_string(labels, i);
+        if (label == NULL) {
+            log_error("label is NULL\n");
+            return false;
+        }
+
+        // prev_id/prev2_id are predicted class ids, so they index class_strings, not labels
+        if (i > 0) {
+            prev = cstring_array_get_string(self->class_strings, prev_id);
+        }
+
+        if (i > 1) {
+            prev2 = cstring_array_get_string(self->class_strings, prev2_id);
+        }
+
+        if (!feature_function(tagger, features, tokenized, i, prev, prev2)) {
+            log_error("Could not add address parser features\n");
+            return false;
+        }
+
+        uint32_t truth;
+
+        if (!averaged_perceptron_trainer_get_class_id(self, label, &truth, add_if_missing)) {
+            log_error("Get class id failed\n");
+            return false;
+        }
+
+        uint32_t guess = averaged_perceptron_trainer_predict(self, features);
+
+        // Online error-driven learning, only needs to update weights when it gets a wrong answer, making training fast
+        if (guess != truth) {
+            self->num_errors++;
+            if (!averaged_perceptron_trainer_update(self, guess, truth, features)) {
+                log_error("Trainer update failed\n");
+                return false;
+            }
+        }
+
+        prev2_id = prev_id;
+        prev_id = guess;
 
-    // Online error-driven learning, only needs to update weights when it gets a wrong answer, making training fast
-    if (guess != truth) {
-        self->num_errors++;
-        return averaged_perceptron_trainer_update(self, guess, truth, features);
     }
 
     return true;
@@ -306,7 +454,7 @@ averaged_perceptron_trainer_t *averaged_perceptron_trainer_new(void) {
     self->num_updates = 0;
     self->num_errors = 0;
 
-    self->features = trie_new();
+    self->features = kh_init(str_uint32);
     if (self->features == NULL) {
         goto exit_trainer_created;
     }
@@ -335,3 +483,4 @@ exit_trainer_created:
     averaged_perceptron_trainer_destroy(self);
     return NULL;
 }
+
diff --git a/src/averaged_perceptron_trainer.h b/src/averaged_perceptron_trainer.h
index b1ed3a1a..2e54d59f 100644
--- a/src/averaged_perceptron_trainer.h
+++ b/src/averaged_perceptron_trainer.h
@@ -36,6 +36,7 @@ Link: http://www.cs.columbia.edu/~mcollins/papers/tagperc.pdf
 #include "averaged_perceptron.h"
 #include "collections.h"
 #include "string_utils.h"
+#include "tokens.h"
 #include "trie.h"
 
 typedef struct class_weight {
@@ -50,12 +51,17 @@ KHASH_MAP_INIT_INT(class_weights, class_weight_t)
 
 KHASH_MAP_INIT_INT(feature_class_weights, khash_t(class_weights) *)
 
+// Fills the features array for one token of a tokenized sequence and returns false on failure.
+// Arguments, in order: tagger, features, tokenized, token index,
+// previous tag, tag before the previous one.
+typedef bool (*ap_tagger_feature_function)(void *, cstring_array *, tokenized_string_t *, uint32_t, char *, char *);
+
 typedef struct averaged_perceptron_trainer {
     uint32_t num_features;
     uint32_t num_classes;
     uint64_t num_updates;
     uint64_t num_errors;
-    trie_t *features;
+    khash_t(str_uint32) *features;
     khash_t(str_uint32) *classes;
     cstring_array *class_strings;
     // {feature_id => {class_id => class_weight_t}}
@@ -66,10 +72,18 @@ typedef struct averaged_perceptron_trainer {
 averaged_perceptron_trainer_t *averaged_perceptron_trainer_new(void);
 
 uint32_t averaged_perceptron_trainer_predict(averaged_perceptron_trainer_t *self, cstring_array *features);
-bool averaged_perceptron_trainer_train_example(averaged_perceptron_trainer_t *trainer, cstring_array *features, char *label);
+
+// Trains on one tokenized sequence; labels must hold exactly one tag per token.
+// features is scratch space that is cleared and refilled for every token.
+bool averaged_perceptron_trainer_train_example(averaged_perceptron_trainer_t *self,
+                                               void *tagger,
+                                               cstring_array *features,
+                                               ap_tagger_feature_function feature_function,
+                                               tokenized_string_t *tokenized,
+                                               cstring_array *labels);
 
 averaged_perceptron_t *averaged_perceptron_trainer_finalize(averaged_perceptron_trainer_t *self);
 
 void averaged_perceptron_trainer_destroy(averaged_perceptron_trainer_t *self);
 
 #endif