Files
libpostal/src/averaged_perceptron_trainer.c

515 lines
15 KiB
C

#include "averaged_perceptron_trainer.h"
void averaged_perceptron_trainer_destroy(averaged_perceptron_trainer_t *self) {
if (self == NULL) return;
const char *key;
uint32_t id;
if (self->features != NULL) {
kh_foreach(self->features, key, id, {
free((char *)key);
})
kh_destroy(str_uint32, self->features);
}
if (self->classes != NULL) {
kh_foreach(self->classes, key, id, {
free((char *)key);
})
kh_destroy(str_uint32, self->classes);
}
if (self->class_strings != NULL) {
cstring_array_destroy(self->class_strings);
}
uint32_t feature_id;
khash_t(class_weights) *weights;
kh_foreach(self->weights, feature_id, weights, {
kh_destroy(class_weights, weights);
})
if (self->weights != NULL) {
kh_destroy(feature_class_weights, self->weights);
}
if (self->update_counts != NULL) {
uint64_array_destroy(self->update_counts);
}
if (self->scores != NULL) {
double_array_destroy(self->scores);
}
free(self);
}
bool averaged_perceptron_trainer_get_class_id(averaged_perceptron_trainer_t *self, char *class_name, uint32_t *class_id, bool add_if_missing) {
khiter_t k;
if (class_name == NULL) {
log_error("class_name was NULL\n");
return false;
}
khash_t(str_uint32) *classes = self->classes;
k = kh_get(str_uint32, classes, class_name);
if (k != kh_end(classes)) {
*class_id = kh_value(classes, k);
return true;
} else if (add_if_missing) {
uint32_t new_id = (uint32_t)kh_size(classes);
int ret;
char *key = strdup(class_name);
if (key == NULL) {
return false;
}
k = kh_put(str_uint32, classes, key, &ret);
if (ret < 0) {
return false;
}
kh_value(classes, k) = new_id;
*class_id = new_id;
cstring_array_add_string(self->class_strings, class_name);
self->num_classes++;
return true;
}
return false;
}
bool averaged_perceptron_trainer_get_feature_id(averaged_perceptron_trainer_t *self, char *feature, uint32_t *feature_id, bool add_if_missing) {
khiter_t k;
if (feature == NULL) {
log_error("feature was NULL\n");
return false;
}
khash_t(str_uint32) *features = self->features;
k = kh_get(str_uint32, features, feature);
if (k != kh_end(features)) {
*feature_id = kh_value(features, k);
return true;
} else if (add_if_missing) {
uint32_t new_id = (uint32_t)kh_size(features);
int ret;
char *key = strdup(feature);
if (key == NULL) {
return false;
}
k = kh_put(str_uint32, features, key, &ret);
if (ret < 0) {
return false;
}
kh_value(features, k) = new_id;
*feature_id = new_id;
uint64_array_push(self->update_counts, 0);
self->num_features++;
return true;
}
return false;
}
averaged_perceptron_t *averaged_perceptron_trainer_finalize(averaged_perceptron_trainer_t *self) {
if (self == NULL || self->num_classes == 0) return NULL;
uint32_t class_id;
class_weight_t weight;
uint64_t updates = self->num_updates;
khash_t(class_weights) *weights;
char **feature_keys = malloc(sizeof(char *) * self->num_features);
uint32_t feature_id;
const char *feature;
kh_foreach(self->features, feature, feature_id, {
if (feature_id >= self->num_features) {
free(feature_keys);
return NULL;
}
feature_keys[feature_id] = (char *)feature;
})
sparse_matrix_t *averaged_weights = sparse_matrix_new();
uint32_t next_feature_id = 0;
khiter_t k;
uint64_t *update_counts = self->update_counts->a;
log_info("Finalizing trainer, num_features=%u\n", self->num_features);
log_info("Pruning weights with < min_updates = %llu\n", self->min_updates);
for (feature_id = 0; feature_id < self->num_features; feature_id++) {
k = kh_get(feature_class_weights, self->weights, feature_id);
if (k == kh_end(self->weights)) {
sparse_matrix_destroy(averaged_weights);
free(feature_keys);
return NULL;
}
weights = kh_value(self->weights, k);
uint32_t class_id;
uint64_t update_count = update_counts[feature_id];
bool keep_feature = update_count >= self->min_updates;
uint32_t new_feature_id = next_feature_id;
if (keep_feature) {
kh_foreach(weights, class_id, weight, {
weight.total += (updates - weight.last_updated) * weight.value;
double value = weight.total / updates;
sparse_matrix_append(averaged_weights, class_id, value);
})
sparse_matrix_finalize_row(averaged_weights);
next_feature_id++;
}
if (!keep_feature || new_feature_id != feature_id) {
feature = feature_keys[feature_id];
k = kh_get(str_uint32, self->features, feature);
if (k != kh_end(self->features)) {
if (keep_feature) {
kh_value(self->features, k) = new_feature_id;
} else {
kh_del(str_uint32, self->features, k);
}
} else {
log_error("Error in kh_get on self->features\n");
averaged_perceptron_trainer_destroy(self);
return NULL;
}
}
}
free(feature_keys);
self->num_features = kh_size(self->features);
log_info("After pruning, num_features=%u\n", self->num_features);
averaged_perceptron_t *perceptron = malloc(sizeof(averaged_perceptron_t));
perceptron->weights = averaged_weights;
trie_t *features = trie_new_from_hash(self->features);
if (features == NULL) {
averaged_perceptron_trainer_destroy(self);
return NULL;
}
perceptron->features = features;
perceptron->num_features = self->num_features;
perceptron->num_classes = self->num_classes;
perceptron->scores = double_array_new_zeros(perceptron->num_classes);
// Set our pointers to NULL so they don't get free'd on destroy
perceptron->classes = self->class_strings;
self->class_strings = NULL;
averaged_perceptron_trainer_destroy(self);
return perceptron;
}
khash_t(class_weights) *averaged_perceptron_trainer_get_class_weights(averaged_perceptron_trainer_t *self, uint32_t feature_id, bool add_if_missing) {
khiter_t k;
k = kh_get(feature_class_weights, self->weights, feature_id);
if (k != kh_end(self->weights)) {
return kh_value(self->weights, k);
} else if (add_if_missing) {
khash_t(class_weights) *weights = kh_init(class_weights);
int ret;
k = kh_put(feature_class_weights, self->weights, feature_id, &ret);
if (ret < 0) {
kh_destroy(class_weights, weights);
return NULL;
}
kh_value(self->weights, k) = weights;
return weights;
}
return NULL;
}
static inline bool averaged_perceptron_trainer_update_weight(khash_t(class_weights) *weights, uint64_t iter, uint32_t class_id, double value) {
class_weight_t weight = NULL_WEIGHT;
size_t index;
khiter_t k;
k = kh_get(class_weights, weights, class_id);
if (k != kh_end(weights)) {
weight = kh_value(weights, k);
}
weight.total += (iter - weight.last_updated) * weight.value;
weight.last_updated = iter;
weight.value += value;
int ret;
k = kh_put(class_weights, weights, class_id, &ret);
if (ret < 0) return false;
kh_value(weights, k) = weight;
return true;
}
static inline bool averaged_perceptron_trainer_update_feature(averaged_perceptron_trainer_t *self, uint32_t feature_id, uint32_t guess, uint32_t truth, double value) {
bool add_if_missing = true;
khash_t(class_weights) *weights = averaged_perceptron_trainer_get_class_weights(self, feature_id, add_if_missing);
if (weights == NULL) {
return false;
}
uint64_t updates = self->num_updates;
if (!averaged_perceptron_trainer_update_weight(weights, updates, guess, -1.0 * value) ||
!averaged_perceptron_trainer_update_weight(weights, updates, truth, value)) {
return false;
}
uint64_t *update_counts = self->update_counts->a;
update_counts[feature_id]++;
return true;
}
uint32_t averaged_perceptron_trainer_predict(averaged_perceptron_trainer_t *self, cstring_array *features) {
double_array *scores = self->scores;
size_t num_classes = (size_t)self->num_classes;
uint32_t i = 0;
char *feature = NULL;
bool add_if_missing = false;
uint32_t feature_id;
khash_t(class_weights) *weights;
uint32_t class_id;
class_weight_t weight;
if (scores->m < num_classes) {
double_array_resize(scores, num_classes);
}
if (scores->n < num_classes) {
scores->n = num_classes;
}
double_array_zero(scores->a, scores->n);
uint64_t *update_counts = self->update_counts->a;
cstring_array_foreach(features, i, feature, {
if (!averaged_perceptron_trainer_get_feature_id(self, feature, &feature_id, add_if_missing)) {
continue;
}
uint64_t update_count = update_counts[feature_id];
bool keep_feature = update_count >= self->min_updates;
if (keep_feature) {
weights = averaged_perceptron_trainer_get_class_weights(self, feature_id, add_if_missing);
if (weights == NULL) {
continue;
}
kh_foreach(weights, class_id, weight, {
scores->a[class_id] += weight.value;
})
}
})
int64_t max_score = double_array_argmax(scores->a, scores->n);
return (uint32_t)max_score;
}
bool averaged_perceptron_trainer_update(averaged_perceptron_trainer_t *self, uint32_t guess, uint32_t truth, cstring_array *features) {
uint32_t i = 0;
char *feature = NULL;
uint32_t feature_id;
bool add_if_missing = true;
cstring_array_foreach(features, i, feature, {
if (!averaged_perceptron_trainer_get_feature_id(self, feature, &feature_id, add_if_missing)) {
return false;
}
if (!averaged_perceptron_trainer_update_feature(self, feature_id, guess, truth, 1.0)) {
return false;
}
})
self->num_updates++;
return true;
}
bool averaged_perceptron_trainer_update_counts(averaged_perceptron_trainer_t *self, uint32_t guess, uint32_t truth, khash_t(str_uint32) *feature_counts) {
const char *feature;
uint32_t feature_id;
uint32_t count;
bool add_if_missing = true;
kh_foreach(feature_counts, feature, count, {
if (!averaged_perceptron_trainer_get_feature_id(self, (char *)feature, &feature_id, add_if_missing)) {
return false;
}
if (!averaged_perceptron_trainer_update_feature(self, feature_id, guess, truth, (double)count)) {
return false;
}
})
self->num_updates++;
return true;
}
bool averaged_perceptron_trainer_train_example(averaged_perceptron_trainer_t *self, void *tagger, void *context, cstring_array *features, cstring_array *prev_tag_features, cstring_array *prev2_tag_features, tagger_feature_function feature_function, tokenized_string_t *tokenized, cstring_array *labels) {
// Keep two tags of history in training
char *prev = NULL;
char *prev2 = NULL;
uint32_t prev_id = 0;
uint32_t prev2_id = 0;
size_t num_tokens = tokenized->tokens->n;
if (cstring_array_num_strings(labels) != num_tokens) {
return false;
}
bool add_if_missing = true;
for (uint32_t i = 0; i < num_tokens; i++) {
cstring_array_clear(features);
cstring_array_clear(prev_tag_features);
cstring_array_clear(prev2_tag_features);
char *label = cstring_array_get_string(labels, i);
if (label == NULL) {
log_error("label is NULL\n");
}
if (!feature_function(tagger, context, tokenized, i)) {
log_error("Could not add address parser features\n");
return false;
}
uint32_t truth;
if (!averaged_perceptron_trainer_get_class_id(self, label, &truth, add_if_missing)) {
log_error("Get class id failed\n");
return false;
}
uint32_t fidx;
const char *feature;
if (i > 0) {
prev = cstring_array_get_string(self->class_strings, prev_id);
cstring_array_foreach(prev_tag_features, fidx, feature, {
feature_array_add(features, 3, "prev", prev, (char *)feature);
})
if (i > 1) {
prev2 = cstring_array_get_string(self->class_strings, prev2_id);
cstring_array_foreach(prev2_tag_features, fidx, feature, {
feature_array_add(features, 5, "prev2", prev2, "prev", prev, (char *)feature);
})
}
}
uint32_t guess = averaged_perceptron_trainer_predict(self, features);
// Online error-driven learning, only needs to update weights when it gets a wrong answer, making training fast
if (guess != truth) {
self->num_errors++;
if (!averaged_perceptron_trainer_update(self, guess, truth, features)) {
log_error("Trainer update failed\n");
return false;
}
}
prev2_id = prev_id;
prev_id = guess;
}
return true;
}
averaged_perceptron_trainer_t *averaged_perceptron_trainer_new(uint64_t min_updates) {
averaged_perceptron_trainer_t *self = calloc(1, sizeof(averaged_perceptron_trainer_t));
if (self == NULL) return NULL;
self->num_features = 0;
self->num_classes = 0;
self->num_updates = 0;
self->num_errors = 0;
self->iterations = 0;
self->min_updates = min_updates;
self->features = kh_init(str_uint32);
if (self->features == NULL) {
goto exit_trainer_created;
}
self->classes = kh_init(str_uint32);
if (self->classes == NULL) {
goto exit_trainer_created;
}
self->class_strings = cstring_array_new();
if (self->class_strings == NULL) {
goto exit_trainer_created;
}
self->weights = kh_init(feature_class_weights);
if (self->weights == NULL) {
goto exit_trainer_created;
}
self->update_counts = uint64_array_new();
if (self->update_counts == NULL) {
goto exit_trainer_created;
}
self->scores = double_array_new();
if (self->scores == NULL) {
goto exit_trainer_created;
}
return self;
exit_trainer_created:
averaged_perceptron_trainer_destroy(self);
return NULL;
}