diff --git a/src/crf_trainer_averaged_perceptron.c b/src/crf_trainer_averaged_perceptron.c new file mode 100644 index 00000000..d34269b3 --- /dev/null +++ b/src/crf_trainer_averaged_perceptron.c @@ -0,0 +1,943 @@ +#include "crf_trainer_averaged_perceptron.h" + +void crf_averaged_perceptron_trainer_destroy(crf_averaged_perceptron_trainer_t *self) { + if (self == NULL) return; + + uint32_t feature_id; + khash_t(class_weights) *weights; + + if (self->weights != NULL) { + kh_foreach(self->weights, feature_id, weights, { + if (weights != NULL) { + kh_destroy(class_weights, weights); + } + }) + kh_destroy(feature_class_weights, self->weights); + } + + if (self->prev_tag_weights != NULL) { + khash_t(prev_tag_class_weights) *prev_tag_weights; + + kh_foreach(self->prev_tag_weights, feature_id, prev_tag_weights, { + if (prev_tag_weights != NULL) { + kh_destroy(prev_tag_class_weights, prev_tag_weights); + } + }) + + kh_destroy(feature_prev_tag_class_weights, self->prev_tag_weights); + } + + if (self->trans_weights != NULL) { + kh_destroy(prev_tag_class_weights, self->trans_weights); + } + + if (self->update_counts != NULL) { + uint64_array_destroy(self->update_counts); + } + + if (self->prev_tag_update_counts != NULL) { + uint64_array_destroy(self->prev_tag_update_counts); + } + + if (self->sequence_features != NULL) { + cstring_array_destroy(self->sequence_features); + } + + if (self->sequence_features_indptr != NULL) { + uint32_array_destroy(self->sequence_features_indptr); + } + + if (self->sequence_prev_tag_features != NULL) { + cstring_array_destroy(self->sequence_prev_tag_features); + } + + if (self->sequence_prev_tag_features_indptr != NULL) { + uint32_array_destroy(self->sequence_prev_tag_features_indptr); + } + + if (self->label_ids != NULL) { + uint32_array_destroy(self->label_ids); + } + + if (self->viterbi != NULL) { + uint32_array_destroy(self->viterbi); + } + + if (self->base_trainer != NULL) { + crf_trainer_destroy(self->base_trainer); + } + + free(self); +} + +crf_averaged_perceptron_trainer_t *crf_averaged_perceptron_trainer_new(size_t num_classes, size_t min_updates) { + crf_averaged_perceptron_trainer_t *self = calloc(1, sizeof(crf_averaged_perceptron_trainer_t)); + + if (self == NULL) return NULL; + + log_info("num_classes %zu\n", num_classes); + + self->num_updates = 0; + self->num_errors = 0; + self->iterations = 0; + self->min_updates = min_updates; + + self->base_trainer = crf_trainer_new(num_classes); + if (self->base_trainer == NULL) { + goto exit_trainer_created; + } + + self->weights = kh_init(feature_class_weights); + + if (self->weights == NULL) { + goto exit_trainer_created; + } + + self->prev_tag_weights = kh_init(feature_prev_tag_class_weights); + + if (self->prev_tag_weights == NULL) { + goto exit_trainer_created; + } + + self->trans_weights = kh_init(prev_tag_class_weights); + if (self->trans_weights == NULL) { + goto exit_trainer_created; + } + + self->update_counts = uint64_array_new(); + if (self->update_counts == NULL) { + goto exit_trainer_created; + } + + self->prev_tag_update_counts = uint64_array_new(); + if (self->prev_tag_update_counts == NULL) { + goto exit_trainer_created; + } + + self->sequence_features = cstring_array_new(); + if (self->sequence_features == NULL) { + goto exit_trainer_created; + } + + self->sequence_features_indptr = uint32_array_new(); + if (self->sequence_features_indptr == NULL) { + goto exit_trainer_created; + } + + self->sequence_prev_tag_features = cstring_array_new(); + if (self->sequence_prev_tag_features == NULL) { + goto exit_trainer_created; + } + + self->sequence_prev_tag_features_indptr = uint32_array_new(); + if (self->sequence_prev_tag_features_indptr == NULL) { + goto exit_trainer_created; + } + + self->label_ids = uint32_array_new(); + if (self->label_ids == NULL) { + goto exit_trainer_created; + } + + self->viterbi = uint32_array_new(); + if (self->viterbi == NULL) { + goto exit_trainer_created; + } + + return self; + +exit_trainer_created: + crf_averaged_perceptron_trainer_destroy(self); + return NULL; +} + +static inline uint32_t tag_bigram_class_id(crf_averaged_perceptron_trainer_t *self, tag_bigram_t tag_bigram) { + return tag_bigram.prev_class_id * self->base_trainer->num_classes + tag_bigram.class_id; +} + +khash_t(class_weights) *crf_averaged_perceptron_trainer_get_class_weights(crf_averaged_perceptron_trainer_t *self, uint32_t feature_id, bool add_if_missing) { + khiter_t k; + k = kh_get(feature_class_weights, self->weights, feature_id); + if (k != kh_end(self->weights)) { + return kh_value(self->weights, k); + } else if (add_if_missing) { + khash_t(class_weights) *weights = kh_init(class_weights); + int ret; + k = kh_put(feature_class_weights, self->weights, feature_id, &ret); + if (ret < 0) { + kh_destroy(class_weights, weights); + return NULL; + } + kh_value(self->weights, k) = weights; + return weights; + } + + return NULL; +} + +khash_t(prev_tag_class_weights) *crf_averaged_perceptron_trainer_get_prev_tag_class_weights(crf_averaged_perceptron_trainer_t *self, uint32_t feature_id, bool add_if_missing) { + khiter_t k; + k = kh_get(feature_prev_tag_class_weights, self->prev_tag_weights, feature_id); + if (k != kh_end(self->prev_tag_weights)) { + return kh_value(self->prev_tag_weights, k); + } else if (add_if_missing) { + khash_t(prev_tag_class_weights) *weights = kh_init(prev_tag_class_weights); + int ret; + k = kh_put(feature_prev_tag_class_weights, self->prev_tag_weights, feature_id, &ret); + if (ret < 0) { + kh_destroy(prev_tag_class_weights, weights); + return NULL; + } + kh_value(self->prev_tag_weights, k) = weights; + return weights; + } + + return NULL; +} + + +static inline bool crf_averaged_perceptron_trainer_update_weight(khash_t(class_weights) *weights, uint64_t iter, uint32_t class_id, double value) { + class_weight_t weight = NULL_WEIGHT; + + khiter_t k; + k = kh_get(class_weights, weights, class_id); + if (k != kh_end(weights)) { + weight = kh_value(weights, k); + } + + weight.total += (iter - weight.last_updated) * weight.value; + weight.last_updated = iter; + weight.value += value; + + int ret; + k = kh_put(class_weights, weights, class_id, &ret); + if (ret < 0) return false; + kh_value(weights, k) = weight; + + return true; + +} + +static inline bool crf_averaged_perceptron_trainer_update_prev_tag_weight(khash_t(prev_tag_class_weights) *weights, uint64_t iter, uint32_t prev_class_id, uint32_t class_id, double value) { + class_weight_t weight = NULL_WEIGHT; + + tag_bigram_t tag_bigram = {.prev_class_id = prev_class_id, .class_id = class_id}; + + uint64_t key = tag_bigram.value; + + khiter_t k; + k = kh_get(prev_tag_class_weights, weights, key); + if (k != kh_end(weights)) { + weight = kh_value(weights, k); + } + + weight.total += (iter - weight.last_updated) * weight.value; + weight.last_updated = iter; + weight.value += value; + + int ret; + k = kh_put(prev_tag_class_weights, weights, key, &ret); + if (ret < 0) return false; + kh_value(weights, k) = weight; + + return true; +} + + +static inline bool crf_averaged_perceptron_trainer_update_feature(crf_averaged_perceptron_trainer_t *self, uint32_t feature_id, uint32_t guess, uint32_t truth, double value) { + bool add_if_missing = true; + + khash_t(class_weights) *weights = crf_averaged_perceptron_trainer_get_class_weights(self, feature_id, add_if_missing); + + if (weights == NULL) { + return false; + } + + uint64_t updates = self->num_updates; + + if (!crf_averaged_perceptron_trainer_update_weight(weights, updates, guess, -1.0 * value) || + !crf_averaged_perceptron_trainer_update_weight(weights, updates, truth, value)) { + return false; + } + + return true; +} + + +static inline bool crf_averaged_perceptron_trainer_update_prev_tag_feature(crf_averaged_perceptron_trainer_t *self, uint32_t feature_id, uint32_t prev_guess, uint32_t prev_truth, uint32_t guess, uint32_t truth, double value) { + bool add_if_missing = true; + khash_t(prev_tag_class_weights) *weights = crf_averaged_perceptron_trainer_get_prev_tag_class_weights(self, feature_id, add_if_missing); + + if (weights == NULL) { + return false; + } + + uint64_t updates = self->num_updates; + + if (!crf_averaged_perceptron_trainer_update_prev_tag_weight(weights, updates, prev_guess, guess, -1.0 * value) || + !crf_averaged_perceptron_trainer_update_prev_tag_weight(weights, updates, prev_truth, truth, value)) { + return false; + } + + return true; +} + +static inline bool crf_averaged_perceptron_trainer_update_trans_feature(crf_averaged_perceptron_trainer_t *self, uint32_t prev_guess, uint32_t prev_truth, uint32_t guess, uint32_t truth, double value) { + bool add_if_missing = true; + khash_t(prev_tag_class_weights) *weights = self->trans_weights; + + if (weights == NULL) { + return false; + } + + uint64_t updates = self->num_updates; + + if (!crf_averaged_perceptron_trainer_update_prev_tag_weight(weights, updates, prev_guess, guess, -1.0 * value) || + !crf_averaged_perceptron_trainer_update_prev_tag_weight(weights, updates, prev_truth, truth, value)) { + return false; + } + + return true; +} + + +static inline bool crf_averaged_perceptron_trainer_cache_features(crf_averaged_perceptron_trainer_t *self, cstring_array *features) { + size_t i; + char *feature; + uint32_t feature_id; + + cstring_array_foreach(features, i, feature, { + cstring_array_add_string(self->sequence_features, feature); + }) + + size_t num_strings = cstring_array_num_strings(self->sequence_features); + uint32_array_push(self->sequence_features_indptr, num_strings); + return true; +} + + +static inline bool crf_averaged_perceptron_trainer_cache_prev_tag_features(crf_averaged_perceptron_trainer_t *self, cstring_array *features) { + size_t i; + char *feature; + uint32_t feature_id; + + cstring_array_foreach(features, i, feature, { + cstring_array_add_string(self->sequence_prev_tag_features, feature); + }) + + size_t num_strings = cstring_array_num_strings(self->sequence_prev_tag_features); + uint32_array_push(self->sequence_prev_tag_features_indptr, num_strings); + return true; +} + + +static bool crf_averaged_perceptron_trainer_state_score(crf_averaged_perceptron_trainer_t *self) { + if (self == NULL || self->base_trainer == NULL || + self->sequence_features == NULL || self->sequence_features_indptr == NULL) { + return false; + } + crf_context_t *context = self->base_trainer->context; + + uint32_t class_id; + + class_weight_t weight; + + cstring_array *sequence_features = self->sequence_features; + + uint64_t *update_counts = self->update_counts->a; + + size_t num_tokens = self->sequence_features_indptr->n - 1; + uint32_t *indptr = self->sequence_features_indptr->a; + + for (size_t t = 0; t < num_tokens; t++) { + uint32_t idx = indptr[t]; + uint32_t next_start = indptr[t + 1]; + + double *scores = state_score(context, t); + + for (uint32_t j = idx; j < next_start; j++) { + char *feature = cstring_array_get_string(sequence_features, j); + + uint32_t feature_id; + if (!crf_trainer_get_feature_id(self->base_trainer, feature, &feature_id)) { + continue; + } + uint64_t update_count = update_counts[feature_id]; + bool keep_feature = update_count >= self->min_updates; + + if (keep_feature) { + bool add_if_missing = false; + khash_t(class_weights) *weights = crf_averaged_perceptron_trainer_get_class_weights(self, feature_id, add_if_missing); + + if (weights == NULL) { + continue; + } + + kh_foreach(weights, class_id, weight, { + scores[class_id] += weight.value; + }) + } + } + } + + return true; +} + +static bool crf_averaged_perceptron_trainer_state_trans_score(crf_averaged_perceptron_trainer_t *self) { + if (self == NULL || self->base_trainer == NULL || + self->sequence_prev_tag_features == NULL || self->sequence_features_indptr == NULL) { + return false; + } + crf_context_t* context = self->base_trainer->context; + + uint32_t t = 0; + uint32_t idx = 0; + uint32_t length = 0; + + bool add_if_missing = false; + + class_weight_t weight; + + cstring_array *sequence_features = self->sequence_prev_tag_features; + uint64_t *update_counts = self->prev_tag_update_counts->a; + + size_t num_tokens = self->sequence_prev_tag_features_indptr->n - 1; + uint32_t *indptr = self->sequence_prev_tag_features_indptr->a; + + for (size_t t = 0; t < num_tokens; t++) { + uint32_t idx = indptr[t]; + uint32_t next_start = indptr[t + 1]; + + double *scores = state_trans_score_all(context, t); + + for (uint32_t j = idx; j < next_start; j++) { + char *feature = cstring_array_get_string(sequence_features, j); + + uint32_t feature_id; + if (!crf_trainer_get_prev_tag_feature_id(self->base_trainer, feature, &feature_id)) { + continue; + } + uint64_t update_count = update_counts[feature_id]; + bool keep_feature = update_count >= self->min_updates; + + if (keep_feature) { + bool add_if_missing = false; + khash_t(prev_tag_class_weights) *prev_tag_weights = crf_averaged_perceptron_trainer_get_prev_tag_class_weights(self, feature_id, add_if_missing); + + if (prev_tag_weights == NULL) { + continue; + } + + tag_bigram_t tag_bigram; + uint64_t tag_bigram_key; + + kh_foreach(prev_tag_weights, tag_bigram_key, weight, { + tag_bigram.value = tag_bigram_key; + uint32_t class_id = tag_bigram_class_id(self, tag_bigram); + scores[class_id] += weight.value; + }) + } + } + } + + return true; +} + +static bool crf_averaged_perceptron_trainer_trans_score(crf_averaged_perceptron_trainer_t *self) { + if (self == NULL || self->base_trainer == NULL || self->trans_weights == NULL) return false; + crf_context_t *context = self->base_trainer->context; + + khash_t(prev_tag_class_weights) *trans_weights = self->trans_weights; + + class_weight_t weight; + tag_bigram_t tag_bigram; + uint64_t tag_bigram_key; + + double *scores = context->trans->values; + + kh_foreach(trans_weights, tag_bigram_key, weight, { + tag_bigram.value = tag_bigram_key; + uint32_t class_id = tag_bigram_class_id(self, tag_bigram); + scores[class_id] += weight.value; + }) + + return true; +} + +bool crf_averaged_perceptron_trainer_update(crf_averaged_perceptron_trainer_t *self, double value) { + if (self->viterbi == NULL || self->label_ids == NULL || self->label_ids->n != self->viterbi->n || + self->sequence_features == NULL || self->sequence_features_indptr == NULL || + self->label_ids->n != self->sequence_features_indptr->n - 1 || + self->sequence_prev_tag_features == NULL || self->sequence_prev_tag_features_indptr == NULL || + self->label_ids->n != self->sequence_prev_tag_features_indptr->n - 1 || + self->update_counts == NULL || self->prev_tag_update_counts == NULL) { + log_error("Something was NULL\n"); + return false; + } + + uint32_t t, idx, length; + + bool add_if_missing = false; + + uint32_t *viterbi = self->viterbi->a; + uint32_t *labels = self->label_ids->a; + + uint32_t truth, guess; + + size_t num_tokens = self->sequence_features_indptr->n - 1; + uint32_t *indptr = self->sequence_features_indptr->a; + + cstring_array *sequence_features = self->sequence_features; + + for (size_t t = 0; t < num_tokens; t++) { + truth = labels[t]; + guess = viterbi[t]; + + if (guess != truth) { + uint32_t idx = indptr[t]; + uint32_t next_start = indptr[t + 1]; + + for (uint32_t j = idx; j < next_start; j++) { + char *feature = cstring_array_get_string(sequence_features, j); + if (feature == NULL) { + log_error("feature NULL, j = %u, len = %zu\n", j, cstring_array_num_strings(sequence_features)); + return false; + } + + uint32_t feature_id; + bool exists; + if (!crf_trainer_hash_feature_to_id_exists(self->base_trainer, feature, &feature_id, &exists)) { + return false; + } + + if (!crf_averaged_perceptron_trainer_update_feature(self, feature_id, guess, truth, value)) { + return false; + } + + if (exists) { + self->update_counts->a[feature_id]++; + } else { + uint64_array_push(self->update_counts, 1); + } + } + // This is shared between the state and state-trans features, only increment once + self->num_updates++; + self->num_errors++; + } + } + + uint32_t prev_truth, prev_guess; + + uint64_t *prev_tag_update_counts = self->prev_tag_update_counts->a; + + sequence_features = self->sequence_prev_tag_features; + + num_tokens = self->sequence_prev_tag_features_indptr->n - 1; + indptr = self->sequence_prev_tag_features_indptr->a; + + for (size_t t = 0; t < num_tokens; t++) { + truth = labels[t]; + guess = viterbi[t]; + + if (t > 0 && guess != truth) { + uint32_t idx = indptr[t]; + uint32_t next_start = indptr[t + 1]; + + for (uint32_t j = idx; j < next_start; j++) { + char *feature = cstring_array_get_string(sequence_features, j); + + if (feature == NULL) { + log_error("feature NULL, j = %u, len = %zu\n", j, cstring_array_num_strings(sequence_features)); + return false; + } + + uint32_t feature_id; + bool exists; + if (!crf_trainer_hash_prev_tag_feature_to_id_exists(self->base_trainer, feature, &feature_id, &exists)) { + return false; + } + + if (!crf_averaged_perceptron_trainer_update_prev_tag_feature(self, feature_id, prev_guess, prev_truth, guess, truth, value)) { + return false; + } + + if (exists) { + self->prev_tag_update_counts->a[feature_id]++; + } else { + uint64_array_push(self->prev_tag_update_counts, 1); + } + } + + } + + prev_truth = truth; + prev_guess = guess; + } + + size_t sequence_len = self->label_ids->n; + + for (t = 0; t < sequence_len; t++) { + truth = labels[t]; + guess = viterbi[t]; + + if (t > 0 && guess != truth) { + if (!crf_averaged_perceptron_trainer_update_trans_feature(self, prev_guess, prev_truth, guess, truth, value)) { + return false; + } + } + + prev_truth = truth; + prev_guess = guess; + } + + return true; +} + +bool crf_averaged_perceptron_trainer_train_example(crf_averaged_perceptron_trainer_t *self, void *tagger, void *tagger_context, cstring_array *features, cstring_array *prev_tag_features, tagger_feature_function feature_function, tokenized_string_t *tokenized, cstring_array *labels) { + if (self == NULL || self->base_trainer == NULL) return false; + + size_t num_tokens = tokenized->tokens->n; + if (cstring_array_num_strings(labels) != num_tokens) { + return false; + } + + uint32_array_clear(self->sequence_features_indptr); + uint32_array_push(self->sequence_features_indptr, 0); + cstring_array_clear(self->sequence_features); + + uint32_array_clear(self->sequence_prev_tag_features_indptr); + uint32_array_push(self->sequence_prev_tag_features_indptr, 0); + cstring_array_clear(self->sequence_prev_tag_features); + + crf_context_t *crf_context = self->base_trainer->context; + + if (!uint32_array_resize(self->label_ids, num_tokens)) { + log_error("Resizing label_ids failed\n"); + return false; + } + uint32_array_clear(self->label_ids); + + if (!crf_context_set_num_items(crf_context, num_tokens)) { + return false; + } + + crf_context_reset(crf_context, CRF_CONTEXT_RESET_ALL); + + bool add_if_missing = true; + + for (uint32_t i = 0; i < num_tokens; i++) { + cstring_array_clear(features); + cstring_array_clear(prev_tag_features); + + if (!feature_function(tagger, tagger_context, tokenized, i)) { + log_error("Could not add address parser features\n"); + return false; + } + + char *label = cstring_array_get_string(labels, i); + if (label == NULL) { + log_error("label is NULL\n"); + } + + uint32_t class_id; + + if (!crf_trainer_get_class_id(self->base_trainer, label, &class_id, add_if_missing)) { + log_error("Get class id failed\n"); + return false; + } + + uint32_array_push(self->label_ids, class_id); + + if (!crf_averaged_perceptron_trainer_cache_features(self, features) || + !crf_averaged_perceptron_trainer_cache_prev_tag_features(self, prev_tag_features)) { + log_error("Caching features failed\n"); + return false; + } + } + + if (!crf_averaged_perceptron_trainer_state_score(self)) { + log_error("Error in state score\n"); + return false; + } + + if (!crf_averaged_perceptron_trainer_state_trans_score(self)) { + log_error("Error in state_trans score\n"); + return false; + } + + if (!crf_averaged_perceptron_trainer_trans_score(self)) { + log_error("Error in trans score\n"); + return false; + } + + uint32_array_resize_fixed(self->viterbi, num_tokens); + + uint32_t *viterbi = self->viterbi->a; + double viterbi_score = crf_context_viterbi(crf_context, viterbi); + + if (self->viterbi->n != num_tokens || self->label_ids->n != num_tokens) { + log_error("self->viterbi->n=%zu, num_tokens=%zu, self->label_ids->n=%zu\n", self->viterbi->n, num_tokens, self->label_ids->n); + return false; + } + + uint32_t *true_labels = self->label_ids->a; + + for (uint32_t i = 0; i < num_tokens; i++) { + uint32_t truth = true_labels[i]; + + // Technically this is supposed to be updated all at once + uint32_t guess = viterbi[i]; + + if (guess != truth) { + if (!crf_averaged_perceptron_trainer_update(self, 1.0)) { + log_error("Error in crf_averaged_perceptron_trainer_update\n"); + return false; + } + break; + } + + } + + return true; +} + + +crf_t *crf_averaged_perceptron_trainer_finalize(crf_averaged_perceptron_trainer_t *self) { + if (self == NULL || self->base_trainer == NULL || self->base_trainer->num_classes == 0) { + log_error("Something was NULL\n"); + return NULL; + } + + uint32_t class_id; + class_weight_t weight; + + khiter_t k; + + size_t num_features = kh_size(self->base_trainer->features); + + sparse_matrix_t *averaged_weights = sparse_matrix_new(); + if (averaged_weights == NULL) { + log_error("Error creating averaged_weights\n"); + return NULL; + } + + log_info("Finalizing trainer, num_features=%zu\n", num_features); + + char **feature_keys = malloc(sizeof(char *) * num_features); + uint32_t feature_id; + const char *feature; + + kh_foreach(self->base_trainer->features, feature, feature_id, { + if (feature_id >= num_features) { + free(feature_keys); + log_error("Error populating feature_keys, feature_id=%u, num_features=%zu\n", feature_id, num_features); + return NULL; + } + feature_keys[feature_id] = (char *)feature; + }) + + khash_t(str_uint32) *features = self->base_trainer->features; + khash_t(str_uint32) *prev_tag_features = self->base_trainer->prev_tag_features; + + uint64_t updates = self->num_updates; + khash_t(class_weights) *weights; + + uint32_t next_feature_id = 0; + uint64_t *update_counts = self->update_counts->a; + + log_info("Pruning weights with < min_updates = %llu\n", self->min_updates); + + for (feature_id = 0; feature_id < num_features; feature_id++) { + k = kh_get(feature_class_weights, self->weights, feature_id); + if (k == kh_end(self->weights)) { + sparse_matrix_destroy(averaged_weights); + free(feature_keys); + log_error("Error in kh_get on self->weights, feature_id=%u, num_features=%zu\n", feature_id, num_features); + return NULL; + } + + weights = kh_value(self->weights, k); + uint32_t class_id; + + uint64_t update_count = update_counts[feature_id]; + bool keep_feature = update_count >= self->min_updates; + + uint32_t new_feature_id = next_feature_id; + + if (keep_feature) { + kh_foreach(weights, class_id, weight, { + weight.total += (updates - weight.last_updated) * weight.value; + double value = weight.total / updates; + sparse_matrix_append(averaged_weights, class_id, value); + }) + + sparse_matrix_finalize_row(averaged_weights); + next_feature_id++; + } + + + if (!keep_feature || new_feature_id != feature_id) { + feature = feature_keys[feature_id]; + k = kh_get(str_uint32, features, feature); + if (k != kh_end(features)) { + if (keep_feature) { + kh_value(features, k) = new_feature_id; + } else { + kh_del(str_uint32, features, k); + } + } else { + log_error("Error in kh_get on features\n"); + crf_averaged_perceptron_trainer_destroy(self); + free(feature_keys); + return NULL; + } + } + + } + + free(feature_keys); + + num_features = kh_size(features); + log_info("After pruning, num_features=%zu\n", num_features); + + sparse_matrix_t *averaged_state_trans_weights = sparse_matrix_new(); + if (averaged_state_trans_weights == NULL) { + log_error("Error creating averaged_state_trans_weights\n"); + return NULL; + } + + size_t num_prev_tag_features = kh_size(prev_tag_features); + + char **prev_tag_feature_keys = malloc(sizeof(char *) * num_prev_tag_features); + + + kh_foreach(prev_tag_features, feature, feature_id, { + if (feature_id >= num_prev_tag_features) { + free(prev_tag_feature_keys); + log_error("Error populating prev_tag_feature_keys\n"); + return NULL; + } + prev_tag_feature_keys[feature_id] = (char *)feature; + }) + + khash_t(prev_tag_class_weights) *prev_tag_weights; + + log_info("Pruning previous tag features, num_prev_tag_features=%zu\n", num_prev_tag_features); + + uint32_t next_prev_tag_feature_id = 0; + + uint64_t *prev_tag_update_counts = self->prev_tag_update_counts->a; + + tag_bigram_t tag_bigram; + uint64_t tag_bigram_key; + + for (feature_id = 0; feature_id < num_prev_tag_features; feature_id++) { + k = kh_get(feature_prev_tag_class_weights, self->prev_tag_weights, feature_id); + if (k == kh_end(self->prev_tag_weights)) { + sparse_matrix_destroy(averaged_state_trans_weights); + free(prev_tag_feature_keys); + log_error("Error in kh_get self->prev_tag_weights\n"); + return NULL; + } + + prev_tag_weights = kh_value(self->prev_tag_weights, k); + + uint64_t update_count = prev_tag_update_counts[feature_id]; + bool keep_feature = update_count >= self->min_updates; + + uint32_t new_feature_id = next_prev_tag_feature_id; + + if (keep_feature) { + kh_foreach(prev_tag_weights, tag_bigram_key, weight, { + tag_bigram.value = tag_bigram_key; + weight.total += (updates - weight.last_updated) * weight.value; + double value = weight.total / updates; + class_id = tag_bigram_class_id(self, tag_bigram); + sparse_matrix_append(averaged_state_trans_weights, class_id, value); + }) + + sparse_matrix_finalize_row(averaged_state_trans_weights); + + next_prev_tag_feature_id++; + } + + if (!keep_feature || new_feature_id != feature_id) { + feature = prev_tag_feature_keys[feature_id]; + k = kh_get(str_uint32, prev_tag_features, feature); + if (k != kh_end(prev_tag_features)) { + if (keep_feature) { + kh_value(prev_tag_features, k) = new_feature_id; + } else { + kh_del(str_uint32, prev_tag_features, k); + } + } else { + log_error("Error in kh_get on prev_tag_features\n"); + crf_averaged_perceptron_trainer_destroy(self); + free(prev_tag_feature_keys); + return NULL; + } + } + + } + + free(prev_tag_feature_keys); + + num_prev_tag_features = kh_size(prev_tag_features); + log_info("After pruning, num_prev_tag_features=%zu\n", num_prev_tag_features); + + + size_t num_classes = self->base_trainer->num_classes; + + double_matrix_t *averaged_trans_weights = double_matrix_new_zeros(num_classes, num_classes); + if (averaged_trans_weights == NULL) { + log_error("Error creating double matrix for transition weights\n"); + return NULL; + } + + double *trans = averaged_trans_weights->values; + + kh_foreach(self->trans_weights, tag_bigram_key, weight, { + tag_bigram.value = tag_bigram_key; + weight.total += (updates - weight.last_updated) * weight.value; + double value = weight.total / updates; + class_id = tag_bigram_class_id(self, tag_bigram); + trans[class_id] = value; + }) + + crf_t *crf = malloc(sizeof(crf_t)); + + crf->num_classes = num_classes; + crf->weights = averaged_weights; + crf->state_trans_weights = averaged_state_trans_weights; + crf->trans_weights = averaged_trans_weights; + crf->classes = self->base_trainer->class_strings; + self->base_trainer->class_strings = NULL; + + trie_t *state_features = trie_new_from_hash(features); + if (state_features == NULL) { + crf_averaged_perceptron_trainer_destroy(self); + log_error("Error creating state_features\n"); + return NULL; + } + + crf->state_features = state_features; + + trie_t *state_trans_features = trie_new_from_hash(prev_tag_features); + if (state_trans_features == NULL) { + crf_averaged_perceptron_trainer_destroy(self); + log_error("Error creating state_trans_features\n"); + return NULL; + } + + crf->state_trans_features = state_trans_features; + + crf->viterbi = self->viterbi; + self->viterbi = NULL; + + crf->context = self->base_trainer->context; + self->base_trainer->context = NULL; + + crf_averaged_perceptron_trainer_destroy(self); + + return crf; +} diff --git a/src/crf_trainer_averaged_perceptron.h b/src/crf_trainer_averaged_perceptron.h new file mode 100644 index 00000000..a6cd5e68 --- /dev/null +++ b/src/crf_trainer_averaged_perceptron.h @@ -0,0 +1,67 @@ +#ifndef CRF_AVERAGED_PERCEPTRON_TRAINER_H +#define CRF_AVERAGED_PERCEPTRON_TRAINER_H + +#include +#include + +#include "averaged_perceptron_trainer.h" +#include "crf.h" +#include "crf_trainer.h" +#include "collections.h" +#include "string_utils.h" +#include "tokens.h" +#include "trie.h" +#include "trie_utils.h" + +typedef union tag_bigram { + uint64_t value; + struct { + uint32_t prev_class_id:32; + uint32_t class_id:32; + }; +} tag_bigram_t; + +KHASH_MAP_INIT_INT64(prev_tag_class_weights, class_weight_t) + +KHASH_MAP_INIT_INT(feature_prev_tag_class_weights, khash_t(prev_tag_class_weights) *) + +typedef struct crf_averaged_perceptron_trainer { + crf_trainer_t *base_trainer; + uint64_t num_updates; + uint64_t num_errors; + uint32_t iterations; + uint64_t min_updates; + // {feature_id => {class_id => class_weight_t}} + khash_t(feature_class_weights) *weights; + khash_t(feature_prev_tag_class_weights) *prev_tag_weights; + khash_t(prev_tag_class_weights) *trans_weights; + uint64_array *update_counts; + uint64_array *prev_tag_update_counts; + cstring_array *sequence_features; + uint32_array *sequence_features_indptr; + cstring_array *sequence_prev_tag_features; + uint32_array *sequence_prev_tag_features_indptr; + uint32_array *label_ids; + uint32_array *viterbi; +} crf_averaged_perceptron_trainer_t; + +crf_averaged_perceptron_trainer_t *crf_averaged_perceptron_trainer_new(size_t num_classes, size_t min_updates); + +uint32_t crf_averaged_perceptron_trainer_predict(crf_averaged_perceptron_trainer_t *self, cstring_array *features); + +bool crf_averaged_perceptron_trainer_train_example(crf_averaged_perceptron_trainer_t *self, + void *tagger, + void *context, + cstring_array *features, + cstring_array *prev_tag_features, + tagger_feature_function feature_function, + tokenized_string_t *tokenized, + cstring_array *labels + ); + +crf_t *crf_averaged_perceptron_trainer_finalize(crf_averaged_perceptron_trainer_t *self); + +void crf_averaged_perceptron_trainer_destroy(crf_averaged_perceptron_trainer_t *self); + + +#endif