[parser/crf] adding an initial training algorithm for CRFs, the averaged
perceptron (FTW!)
Though it does not generate scores suitable for use as probabilties, and
might achieve slightly lower accuracy on some tasks than its
gradient-based counterparts like SGD (a possibility for libpostal)
or LBFGS (prohibitive on this much data), the averaged perceptron is
appealing for two reasons: speed and low memory usage i.e. we can still use
all the same tricks as in the greedy model like sparse construction of
the weight matrix. In this case we can go even sparser than in the
original because the state-transition features are separate from the
state features, and we need to be able to iterate over all of them
instead of simply creating new string keys in the feature space. The
solution to this is quite simple: we simply treat the weights for each
state-transition feature as if they have L * L output labels instead of
simply L. So instead of:
{
"prev|road|word|DD": {1: 1.0, 2: -1.0}
...
}
We'd have:
{
"word|DD": {(0, 1): 1.0, (0, 2): -1.0}
...
}
As usual we compress the features to a trie, and the weights to
compressed-sparse row (CSR) format sparse matrix after the weights have
been averaged. These representations are smaller, faster to load from
disk, and faster to use at runtime (contiguous arrays vs hashtables).
This also includes the min_updates variation from the greedy perceptron,
so features that participate in fewer than N updates are discarded at
the end (and also not used in scoring until they meet the threshold so
the model doesn't become dependent on features it doesn't really have).
This tends to discard irrelevant features, keeping the model small
without hurting accuracy much (within a tenth of a percent or so in my
tests on the greedy perceptron).
This commit is contained in:
943
src/crf_trainer_averaged_perceptron.c
Normal file
943
src/crf_trainer_averaged_perceptron.c
Normal file
@@ -0,0 +1,943 @@
|
||||
#include "crf_trainer_averaged_perceptron.h"
|
||||
|
||||
void crf_averaged_perceptron_trainer_destroy(crf_averaged_perceptron_trainer_t *self) {
|
||||
if (self == NULL) return;
|
||||
|
||||
uint32_t feature_id;
|
||||
khash_t(class_weights) *weights;
|
||||
|
||||
if (self->weights != NULL) {
|
||||
kh_foreach(self->weights, feature_id, weights, {
|
||||
if (weights != NULL) {
|
||||
kh_destroy(class_weights, weights);
|
||||
}
|
||||
})
|
||||
kh_destroy(feature_class_weights, self->weights);
|
||||
}
|
||||
|
||||
if (self->prev_tag_weights != NULL) {
|
||||
khash_t(prev_tag_class_weights) *prev_tag_weights;
|
||||
|
||||
kh_foreach(self->prev_tag_weights, feature_id, prev_tag_weights, {
|
||||
if (prev_tag_weights != NULL) {
|
||||
kh_destroy(prev_tag_class_weights, prev_tag_weights);
|
||||
}
|
||||
})
|
||||
|
||||
kh_destroy(feature_prev_tag_class_weights, self->prev_tag_weights);
|
||||
}
|
||||
|
||||
if (self->trans_weights != NULL) {
|
||||
kh_destroy(prev_tag_class_weights, self->trans_weights);
|
||||
}
|
||||
|
||||
if (self->update_counts != NULL) {
|
||||
uint64_array_destroy(self->update_counts);
|
||||
}
|
||||
|
||||
if (self->prev_tag_update_counts != NULL) {
|
||||
uint64_array_destroy(self->prev_tag_update_counts);
|
||||
}
|
||||
|
||||
if (self->sequence_features != NULL) {
|
||||
cstring_array_destroy(self->sequence_features);
|
||||
}
|
||||
|
||||
if (self->sequence_features_indptr != NULL) {
|
||||
uint32_array_destroy(self->sequence_features_indptr);
|
||||
}
|
||||
|
||||
if (self->sequence_prev_tag_features != NULL) {
|
||||
cstring_array_destroy(self->sequence_prev_tag_features);
|
||||
}
|
||||
|
||||
if (self->sequence_prev_tag_features_indptr != NULL) {
|
||||
uint32_array_destroy(self->sequence_prev_tag_features_indptr);
|
||||
}
|
||||
|
||||
if (self->label_ids != NULL) {
|
||||
uint32_array_destroy(self->label_ids);
|
||||
}
|
||||
|
||||
if (self->viterbi != NULL) {
|
||||
uint32_array_destroy(self->viterbi);
|
||||
}
|
||||
|
||||
if (self->base_trainer != NULL) {
|
||||
crf_trainer_destroy(self->base_trainer);
|
||||
}
|
||||
|
||||
free(self);
|
||||
}
|
||||
|
||||
crf_averaged_perceptron_trainer_t *crf_averaged_perceptron_trainer_new(size_t num_classes, size_t min_updates) {
|
||||
crf_averaged_perceptron_trainer_t *self = calloc(1, sizeof(crf_averaged_perceptron_trainer_t));
|
||||
|
||||
if (self == NULL) return NULL;
|
||||
|
||||
log_info("num_classes %zu\n", num_classes);
|
||||
|
||||
self->num_updates = 0;
|
||||
self->num_errors = 0;
|
||||
self->iterations = 0;
|
||||
self->min_updates = min_updates;
|
||||
|
||||
self->base_trainer = crf_trainer_new(num_classes);
|
||||
if (self->base_trainer == NULL) {
|
||||
goto exit_trainer_created;
|
||||
}
|
||||
|
||||
self->weights = kh_init(feature_class_weights);
|
||||
|
||||
if (self->weights == NULL) {
|
||||
goto exit_trainer_created;
|
||||
}
|
||||
|
||||
self->prev_tag_weights = kh_init(feature_prev_tag_class_weights);
|
||||
|
||||
if (self->prev_tag_weights == NULL) {
|
||||
goto exit_trainer_created;
|
||||
}
|
||||
|
||||
self->trans_weights = kh_init(prev_tag_class_weights);
|
||||
if (self->trans_weights == NULL) {
|
||||
goto exit_trainer_created;
|
||||
}
|
||||
|
||||
self->update_counts = uint64_array_new();
|
||||
if (self->update_counts == NULL) {
|
||||
goto exit_trainer_created;
|
||||
}
|
||||
|
||||
self->prev_tag_update_counts = uint64_array_new();
|
||||
if (self->prev_tag_update_counts == NULL) {
|
||||
goto exit_trainer_created;
|
||||
}
|
||||
|
||||
self->sequence_features = cstring_array_new();
|
||||
if (self->sequence_features == NULL) {
|
||||
goto exit_trainer_created;
|
||||
}
|
||||
|
||||
self->sequence_features_indptr = uint32_array_new();
|
||||
if (self->sequence_features_indptr == NULL) {
|
||||
goto exit_trainer_created;
|
||||
}
|
||||
|
||||
self->sequence_prev_tag_features = cstring_array_new();
|
||||
if (self->sequence_prev_tag_features == NULL) {
|
||||
goto exit_trainer_created;
|
||||
}
|
||||
|
||||
self->sequence_prev_tag_features_indptr = uint32_array_new();
|
||||
if (self->sequence_prev_tag_features_indptr == NULL) {
|
||||
goto exit_trainer_created;
|
||||
}
|
||||
|
||||
self->label_ids = uint32_array_new();
|
||||
if (self->label_ids == NULL) {
|
||||
goto exit_trainer_created;
|
||||
}
|
||||
|
||||
self->viterbi = uint32_array_new();
|
||||
if (self->viterbi == NULL) {
|
||||
goto exit_trainer_created;
|
||||
}
|
||||
|
||||
return self;
|
||||
|
||||
exit_trainer_created:
|
||||
crf_averaged_perceptron_trainer_destroy(self);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
static inline uint32_t tag_bigram_class_id(crf_averaged_perceptron_trainer_t *self, tag_bigram_t tag_bigram) {
|
||||
return tag_bigram.prev_class_id * self->base_trainer->num_classes + tag_bigram.class_id;
|
||||
}
|
||||
|
||||
khash_t(class_weights) *crf_averaged_perceptron_trainer_get_class_weights(crf_averaged_perceptron_trainer_t *self, uint32_t feature_id, bool add_if_missing) {
|
||||
khiter_t k;
|
||||
k = kh_get(feature_class_weights, self->weights, feature_id);
|
||||
if (k != kh_end(self->weights)) {
|
||||
return kh_value(self->weights, k);
|
||||
} else if (add_if_missing) {
|
||||
khash_t(class_weights) *weights = kh_init(class_weights);
|
||||
int ret;
|
||||
k = kh_put(feature_class_weights, self->weights, feature_id, &ret);
|
||||
if (ret < 0) {
|
||||
kh_destroy(class_weights, weights);
|
||||
return NULL;
|
||||
}
|
||||
kh_value(self->weights, k) = weights;
|
||||
return weights;
|
||||
}
|
||||
|
||||
return NULL;
|
||||
}
|
||||
|
||||
khash_t(prev_tag_class_weights) *crf_averaged_perceptron_trainer_get_prev_tag_class_weights(crf_averaged_perceptron_trainer_t *self, uint32_t feature_id, bool add_if_missing) {
|
||||
khiter_t k;
|
||||
k = kh_get(feature_prev_tag_class_weights, self->prev_tag_weights, feature_id);
|
||||
if (k != kh_end(self->prev_tag_weights)) {
|
||||
return kh_value(self->prev_tag_weights, k);
|
||||
} else if (add_if_missing) {
|
||||
khash_t(prev_tag_class_weights) *weights = kh_init(prev_tag_class_weights);
|
||||
int ret;
|
||||
k = kh_put(feature_prev_tag_class_weights, self->prev_tag_weights, feature_id, &ret);
|
||||
if (ret < 0) {
|
||||
kh_destroy(prev_tag_class_weights, weights);
|
||||
return NULL;
|
||||
}
|
||||
kh_value(self->prev_tag_weights, k) = weights;
|
||||
return weights;
|
||||
}
|
||||
|
||||
return NULL;
|
||||
}
|
||||
|
||||
|
||||
static inline bool crf_averaged_perceptron_trainer_update_weight(khash_t(class_weights) *weights, uint64_t iter, uint32_t class_id, double value) {
|
||||
class_weight_t weight = NULL_WEIGHT;
|
||||
|
||||
khiter_t k;
|
||||
k = kh_get(class_weights, weights, class_id);
|
||||
if (k != kh_end(weights)) {
|
||||
weight = kh_value(weights, k);
|
||||
}
|
||||
|
||||
weight.total += (iter - weight.last_updated) * weight.value;
|
||||
weight.last_updated = iter;
|
||||
weight.value += value;
|
||||
|
||||
int ret;
|
||||
k = kh_put(class_weights, weights, class_id, &ret);
|
||||
if (ret < 0) return false;
|
||||
kh_value(weights, k) = weight;
|
||||
|
||||
return true;
|
||||
|
||||
}
|
||||
|
||||
static inline bool crf_averaged_perceptron_trainer_update_prev_tag_weight(khash_t(prev_tag_class_weights) *weights, uint64_t iter, uint32_t prev_class_id, uint32_t class_id, double value) {
|
||||
class_weight_t weight = NULL_WEIGHT;
|
||||
|
||||
tag_bigram_t tag_bigram = {.prev_class_id = prev_class_id, .class_id = class_id};
|
||||
|
||||
uint64_t key = tag_bigram.value;
|
||||
|
||||
khiter_t k;
|
||||
k = kh_get(prev_tag_class_weights, weights, key);
|
||||
if (k != kh_end(weights)) {
|
||||
weight = kh_value(weights, k);
|
||||
}
|
||||
|
||||
weight.total += (iter - weight.last_updated) * weight.value;
|
||||
weight.last_updated = iter;
|
||||
weight.value += value;
|
||||
|
||||
int ret;
|
||||
k = kh_put(prev_tag_class_weights, weights, key, &ret);
|
||||
if (ret < 0) return false;
|
||||
kh_value(weights, k) = weight;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
static inline bool crf_averaged_perceptron_trainer_update_feature(crf_averaged_perceptron_trainer_t *self, uint32_t feature_id, uint32_t guess, uint32_t truth, double value) {
|
||||
bool add_if_missing = true;
|
||||
|
||||
khash_t(class_weights) *weights = crf_averaged_perceptron_trainer_get_class_weights(self, feature_id, add_if_missing);
|
||||
|
||||
if (weights == NULL) {
|
||||
return false;
|
||||
}
|
||||
|
||||
uint64_t updates = self->num_updates;
|
||||
|
||||
if (!crf_averaged_perceptron_trainer_update_weight(weights, updates, guess, -1.0 * value) ||
|
||||
!crf_averaged_perceptron_trainer_update_weight(weights, updates, truth, value)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
static inline bool crf_averaged_perceptron_trainer_update_prev_tag_feature(crf_averaged_perceptron_trainer_t *self, uint32_t feature_id, uint32_t prev_guess, uint32_t prev_truth, uint32_t guess, uint32_t truth, double value) {
|
||||
bool add_if_missing = true;
|
||||
khash_t(prev_tag_class_weights) *weights = crf_averaged_perceptron_trainer_get_prev_tag_class_weights(self, feature_id, add_if_missing);
|
||||
|
||||
if (weights == NULL) {
|
||||
return false;
|
||||
}
|
||||
|
||||
uint64_t updates = self->num_updates;
|
||||
|
||||
if (!crf_averaged_perceptron_trainer_update_prev_tag_weight(weights, updates, prev_guess, guess, -1.0 * value) ||
|
||||
!crf_averaged_perceptron_trainer_update_prev_tag_weight(weights, updates, prev_truth, truth, value)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static inline bool crf_averaged_perceptron_trainer_update_trans_feature(crf_averaged_perceptron_trainer_t *self, uint32_t prev_guess, uint32_t prev_truth, uint32_t guess, uint32_t truth, double value) {
|
||||
bool add_if_missing = true;
|
||||
khash_t(prev_tag_class_weights) *weights = self->trans_weights;
|
||||
|
||||
if (weights == NULL) {
|
||||
return false;
|
||||
}
|
||||
|
||||
uint64_t updates = self->num_updates;
|
||||
|
||||
if (!crf_averaged_perceptron_trainer_update_prev_tag_weight(weights, updates, prev_guess, guess, -1.0 * value) ||
|
||||
!crf_averaged_perceptron_trainer_update_prev_tag_weight(weights, updates, prev_truth, truth, value)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
static inline bool crf_averaged_perceptron_trainer_cache_features(crf_averaged_perceptron_trainer_t *self, cstring_array *features) {
|
||||
size_t i;
|
||||
char *feature;
|
||||
uint32_t feature_id;
|
||||
|
||||
cstring_array_foreach(features, i, feature, {
|
||||
cstring_array_add_string(self->sequence_features, feature);
|
||||
})
|
||||
|
||||
size_t num_strings = cstring_array_num_strings(self->sequence_features);
|
||||
uint32_array_push(self->sequence_features_indptr, num_strings);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
static inline bool crf_averaged_perceptron_trainer_cache_prev_tag_features(crf_averaged_perceptron_trainer_t *self, cstring_array *features) {
|
||||
size_t i;
|
||||
char *feature;
|
||||
uint32_t feature_id;
|
||||
|
||||
cstring_array_foreach(features, i, feature, {
|
||||
cstring_array_add_string(self->sequence_prev_tag_features, feature);
|
||||
})
|
||||
|
||||
size_t num_strings = cstring_array_num_strings(self->sequence_prev_tag_features);
|
||||
uint32_array_push(self->sequence_prev_tag_features_indptr, num_strings);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
static bool crf_averaged_perceptron_trainer_state_score(crf_averaged_perceptron_trainer_t *self) {
|
||||
if (self == NULL || self->base_trainer == NULL ||
|
||||
self->sequence_features == NULL || self->sequence_features_indptr == NULL) {
|
||||
return false;
|
||||
}
|
||||
crf_context_t *context = self->base_trainer->context;
|
||||
|
||||
uint32_t class_id;
|
||||
|
||||
class_weight_t weight;
|
||||
|
||||
cstring_array *sequence_features = self->sequence_features;
|
||||
|
||||
uint64_t *update_counts = self->update_counts->a;
|
||||
|
||||
size_t num_tokens = self->sequence_features_indptr->n - 1;
|
||||
uint32_t *indptr = self->sequence_features_indptr->a;
|
||||
|
||||
for (size_t t = 0; t < num_tokens; t++) {
|
||||
uint32_t idx = indptr[t];
|
||||
uint32_t next_start = indptr[t + 1];
|
||||
|
||||
double *scores = state_score(context, t);
|
||||
|
||||
for (uint32_t j = idx; j < next_start; j++) {
|
||||
char *feature = cstring_array_get_string(sequence_features, j);
|
||||
|
||||
uint32_t feature_id;
|
||||
if (!crf_trainer_get_feature_id(self->base_trainer, feature, &feature_id)) {
|
||||
continue;
|
||||
}
|
||||
uint64_t update_count = update_counts[feature_id];
|
||||
bool keep_feature = update_count >= self->min_updates;
|
||||
|
||||
if (keep_feature) {
|
||||
bool add_if_missing = false;
|
||||
khash_t(class_weights) *weights = crf_averaged_perceptron_trainer_get_class_weights(self, feature_id, add_if_missing);
|
||||
|
||||
if (weights == NULL) {
|
||||
continue;
|
||||
}
|
||||
|
||||
kh_foreach(weights, class_id, weight, {
|
||||
scores[class_id] += weight.value;
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool crf_averaged_perceptron_trainer_state_trans_score(crf_averaged_perceptron_trainer_t *self) {
|
||||
if (self == NULL || self->base_trainer == NULL ||
|
||||
self->sequence_prev_tag_features == NULL || self->sequence_features_indptr == NULL) {
|
||||
return false;
|
||||
}
|
||||
crf_context_t* context = self->base_trainer->context;
|
||||
|
||||
uint32_t t = 0;
|
||||
uint32_t idx = 0;
|
||||
uint32_t length = 0;
|
||||
|
||||
bool add_if_missing = false;
|
||||
|
||||
class_weight_t weight;
|
||||
|
||||
cstring_array *sequence_features = self->sequence_prev_tag_features;
|
||||
uint64_t *update_counts = self->prev_tag_update_counts->a;
|
||||
|
||||
size_t num_tokens = self->sequence_prev_tag_features_indptr->n - 1;
|
||||
uint32_t *indptr = self->sequence_prev_tag_features_indptr->a;
|
||||
|
||||
for (size_t t = 0; t < num_tokens; t++) {
|
||||
uint32_t idx = indptr[t];
|
||||
uint32_t next_start = indptr[t + 1];
|
||||
|
||||
double *scores = state_trans_score_all(context, t);
|
||||
|
||||
for (uint32_t j = idx; j < next_start; j++) {
|
||||
char *feature = cstring_array_get_string(sequence_features, j);
|
||||
|
||||
uint32_t feature_id;
|
||||
if (!crf_trainer_get_prev_tag_feature_id(self->base_trainer, feature, &feature_id)) {
|
||||
continue;
|
||||
}
|
||||
uint64_t update_count = update_counts[feature_id];
|
||||
bool keep_feature = update_count >= self->min_updates;
|
||||
|
||||
if (keep_feature) {
|
||||
bool add_if_missing = false;
|
||||
khash_t(prev_tag_class_weights) *prev_tag_weights = crf_averaged_perceptron_trainer_get_prev_tag_class_weights(self, feature_id, add_if_missing);
|
||||
|
||||
if (prev_tag_weights == NULL) {
|
||||
continue;
|
||||
}
|
||||
|
||||
tag_bigram_t tag_bigram;
|
||||
uint64_t tag_bigram_key;
|
||||
|
||||
kh_foreach(prev_tag_weights, tag_bigram_key, weight, {
|
||||
tag_bigram.value = tag_bigram_key;
|
||||
uint32_t class_id = tag_bigram_class_id(self, tag_bigram);
|
||||
scores[class_id] += weight.value;
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool crf_averaged_perceptron_trainer_trans_score(crf_averaged_perceptron_trainer_t *self) {
|
||||
if (self == NULL || self->base_trainer == NULL || self->trans_weights == NULL) return false;
|
||||
crf_context_t *context = self->base_trainer->context;
|
||||
|
||||
khash_t(prev_tag_class_weights) *trans_weights = self->trans_weights;
|
||||
|
||||
class_weight_t weight;
|
||||
tag_bigram_t tag_bigram;
|
||||
uint64_t tag_bigram_key;
|
||||
|
||||
double *scores = context->trans->values;
|
||||
|
||||
kh_foreach(trans_weights, tag_bigram_key, weight, {
|
||||
tag_bigram.value = tag_bigram_key;
|
||||
uint32_t class_id = tag_bigram_class_id(self, tag_bigram);
|
||||
scores[class_id] += weight.value;
|
||||
})
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool crf_averaged_perceptron_trainer_update(crf_averaged_perceptron_trainer_t *self, double value) {
|
||||
if (self->viterbi == NULL || self->label_ids == NULL || self->label_ids->n != self->viterbi->n ||
|
||||
self->sequence_features == NULL || self->sequence_features_indptr == NULL ||
|
||||
self->label_ids->n != self->sequence_features_indptr->n - 1 ||
|
||||
self->sequence_prev_tag_features == NULL || self->sequence_prev_tag_features_indptr == NULL ||
|
||||
self->label_ids->n != self->sequence_prev_tag_features_indptr->n - 1 ||
|
||||
self->update_counts == NULL || self->prev_tag_update_counts == NULL) {
|
||||
log_error("Something was NULL\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
uint32_t t, idx, length;
|
||||
|
||||
bool add_if_missing = false;
|
||||
|
||||
uint32_t *viterbi = self->viterbi->a;
|
||||
uint32_t *labels = self->label_ids->a;
|
||||
|
||||
uint32_t truth, guess;
|
||||
|
||||
size_t num_tokens = self->sequence_features_indptr->n - 1;
|
||||
uint32_t *indptr = self->sequence_features_indptr->a;
|
||||
|
||||
cstring_array *sequence_features = self->sequence_features;
|
||||
|
||||
for (size_t t = 0; t < num_tokens; t++) {
|
||||
truth = labels[t];
|
||||
guess = viterbi[t];
|
||||
|
||||
if (guess != truth) {
|
||||
uint32_t idx = indptr[t];
|
||||
uint32_t next_start = indptr[t + 1];
|
||||
|
||||
for (uint32_t j = idx; j < next_start; j++) {
|
||||
char *feature = cstring_array_get_string(sequence_features, j);
|
||||
if (feature == NULL) {
|
||||
log_error("feature NULL, j = %u, len = %zu\n", j, cstring_array_num_strings(sequence_features));
|
||||
return false;
|
||||
}
|
||||
|
||||
uint32_t feature_id;
|
||||
bool exists;
|
||||
if (!crf_trainer_hash_feature_to_id_exists(self->base_trainer, feature, &feature_id, &exists)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!crf_averaged_perceptron_trainer_update_feature(self, feature_id, guess, truth, value)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (exists) {
|
||||
self->update_counts->a[feature_id]++;
|
||||
} else {
|
||||
uint64_array_push(self->update_counts, 1);
|
||||
}
|
||||
}
|
||||
// This is shared between the state and state-trans features, only increment once
|
||||
self->num_updates++;
|
||||
self->num_errors++;
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t prev_truth, prev_guess;
|
||||
|
||||
uint64_t *prev_tag_update_counts = self->prev_tag_update_counts->a;
|
||||
|
||||
sequence_features = self->sequence_prev_tag_features;
|
||||
|
||||
num_tokens = self->sequence_prev_tag_features_indptr->n - 1;
|
||||
indptr = self->sequence_prev_tag_features_indptr->a;
|
||||
|
||||
for (size_t t = 0; t < num_tokens; t++) {
|
||||
truth = labels[t];
|
||||
guess = viterbi[t];
|
||||
|
||||
if (t > 0 && guess != truth) {
|
||||
uint32_t idx = indptr[t];
|
||||
uint32_t next_start = indptr[t + 1];
|
||||
|
||||
for (uint32_t j = idx; j < next_start; j++) {
|
||||
char *feature = cstring_array_get_string(sequence_features, j);
|
||||
|
||||
if (feature == NULL) {
|
||||
log_error("feature NULL, j = %u, len = %zu\n", j, cstring_array_num_strings(sequence_features));
|
||||
return false;
|
||||
}
|
||||
|
||||
uint32_t feature_id;
|
||||
bool exists;
|
||||
if (!crf_trainer_hash_prev_tag_feature_to_id_exists(self->base_trainer, feature, &feature_id, &exists)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!crf_averaged_perceptron_trainer_update_prev_tag_feature(self, feature_id, prev_guess, prev_truth, guess, truth, value)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (exists) {
|
||||
self->prev_tag_update_counts->a[feature_id]++;
|
||||
} else {
|
||||
uint64_array_push(self->prev_tag_update_counts, 1);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
prev_truth = truth;
|
||||
prev_guess = guess;
|
||||
}
|
||||
|
||||
size_t sequence_len = self->label_ids->n;
|
||||
|
||||
for (t = 0; t < sequence_len; t++) {
|
||||
truth = labels[t];
|
||||
guess = viterbi[t];
|
||||
|
||||
if (t > 0 && guess != truth) {
|
||||
if (!crf_averaged_perceptron_trainer_update_trans_feature(self, prev_guess, prev_truth, guess, truth, value)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
prev_truth = truth;
|
||||
prev_guess = guess;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool crf_averaged_perceptron_trainer_train_example(crf_averaged_perceptron_trainer_t *self, void *tagger, void *tagger_context, cstring_array *features, cstring_array *prev_tag_features, tagger_feature_function feature_function, tokenized_string_t *tokenized, cstring_array *labels) {
|
||||
if (self == NULL || self->base_trainer == NULL) return false;
|
||||
|
||||
size_t num_tokens = tokenized->tokens->n;
|
||||
if (cstring_array_num_strings(labels) != num_tokens) {
|
||||
return false;
|
||||
}
|
||||
|
||||
uint32_array_clear(self->sequence_features_indptr);
|
||||
uint32_array_push(self->sequence_features_indptr, 0);
|
||||
cstring_array_clear(self->sequence_features);
|
||||
|
||||
uint32_array_clear(self->sequence_prev_tag_features_indptr);
|
||||
uint32_array_push(self->sequence_prev_tag_features_indptr, 0);
|
||||
cstring_array_clear(self->sequence_prev_tag_features);
|
||||
|
||||
crf_context_t *crf_context = self->base_trainer->context;
|
||||
|
||||
if (!uint32_array_resize(self->label_ids, num_tokens)) {
|
||||
log_error("Resizing label_ids failed\n");
|
||||
return false;
|
||||
}
|
||||
uint32_array_clear(self->label_ids);
|
||||
|
||||
if (!crf_context_set_num_items(crf_context, num_tokens)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
crf_context_reset(crf_context, CRF_CONTEXT_RESET_ALL);
|
||||
|
||||
bool add_if_missing = true;
|
||||
|
||||
for (uint32_t i = 0; i < num_tokens; i++) {
|
||||
cstring_array_clear(features);
|
||||
cstring_array_clear(prev_tag_features);
|
||||
|
||||
if (!feature_function(tagger, tagger_context, tokenized, i)) {
|
||||
log_error("Could not add address parser features\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
char *label = cstring_array_get_string(labels, i);
|
||||
if (label == NULL) {
|
||||
log_error("label is NULL\n");
|
||||
}
|
||||
|
||||
uint32_t class_id;
|
||||
|
||||
if (!crf_trainer_get_class_id(self->base_trainer, label, &class_id, add_if_missing)) {
|
||||
log_error("Get class id failed\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
uint32_array_push(self->label_ids, class_id);
|
||||
|
||||
if (!crf_averaged_perceptron_trainer_cache_features(self, features) ||
|
||||
!crf_averaged_perceptron_trainer_cache_prev_tag_features(self, prev_tag_features)) {
|
||||
log_error("Caching features failed\n");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (!crf_averaged_perceptron_trainer_state_score(self)) {
|
||||
log_error("Error in state score\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!crf_averaged_perceptron_trainer_state_trans_score(self)) {
|
||||
log_error("Error in state_trans score\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!crf_averaged_perceptron_trainer_trans_score(self)) {
|
||||
log_error("Error in trans score\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
uint32_array_resize_fixed(self->viterbi, num_tokens);
|
||||
|
||||
uint32_t *viterbi = self->viterbi->a;
|
||||
double viterbi_score = crf_context_viterbi(crf_context, viterbi);
|
||||
|
||||
if (self->viterbi->n != num_tokens || self->label_ids->n != num_tokens) {
|
||||
log_error("self->viterbi->n=%zu, num_tokens=%zu, self->label_ids->n=%zu\n", self->viterbi->n, num_tokens, self->label_ids->n);
|
||||
return false;
|
||||
}
|
||||
|
||||
uint32_t *true_labels = self->label_ids->a;
|
||||
|
||||
for (uint32_t i = 0; i < num_tokens; i++) {
|
||||
uint32_t truth = true_labels[i];
|
||||
|
||||
// Technically this is supposed to be updated all at once
|
||||
uint32_t guess = viterbi[i];
|
||||
|
||||
if (guess != truth) {
|
||||
if (!crf_averaged_perceptron_trainer_update(self, 1.0)) {
|
||||
log_error("Error in crf_averaged_perceptron_trainer_update\n");
|
||||
return false;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
crf_t *crf_averaged_perceptron_trainer_finalize(crf_averaged_perceptron_trainer_t *self) {
|
||||
if (self == NULL || self->base_trainer == NULL || self->base_trainer->num_classes == 0) {
|
||||
log_error("Something was NULL\n");
|
||||
return NULL;
|
||||
}
|
||||
|
||||
uint32_t class_id;
|
||||
class_weight_t weight;
|
||||
|
||||
khiter_t k;
|
||||
|
||||
size_t num_features = kh_size(self->base_trainer->features);
|
||||
|
||||
sparse_matrix_t *averaged_weights = sparse_matrix_new();
|
||||
if (averaged_weights == NULL) {
|
||||
log_error("Error creating averaged_weights\n");
|
||||
return NULL;
|
||||
}
|
||||
|
||||
log_info("Finalizing trainer, num_features=%zu\n", num_features);
|
||||
|
||||
char **feature_keys = malloc(sizeof(char *) * num_features);
|
||||
uint32_t feature_id;
|
||||
const char *feature;
|
||||
|
||||
kh_foreach(self->base_trainer->features, feature, feature_id, {
|
||||
if (feature_id >= num_features) {
|
||||
free(feature_keys);
|
||||
log_error("Error populating feature_keys, feature_id=%u, num_features=%zu\n", feature_id, num_features);
|
||||
return NULL;
|
||||
}
|
||||
feature_keys[feature_id] = (char *)feature;
|
||||
})
|
||||
|
||||
khash_t(str_uint32) *features = self->base_trainer->features;
|
||||
khash_t(str_uint32) *prev_tag_features = self->base_trainer->prev_tag_features;
|
||||
|
||||
uint64_t updates = self->num_updates;
|
||||
khash_t(class_weights) *weights;
|
||||
|
||||
uint32_t next_feature_id = 0;
|
||||
uint64_t *update_counts = self->update_counts->a;
|
||||
|
||||
log_info("Pruning weights with < min_updates = %llu\n", self->min_updates);
|
||||
|
||||
for (feature_id = 0; feature_id < num_features; feature_id++) {
|
||||
k = kh_get(feature_class_weights, self->weights, feature_id);
|
||||
if (k == kh_end(self->weights)) {
|
||||
sparse_matrix_destroy(averaged_weights);
|
||||
free(feature_keys);
|
||||
log_error("Error in kh_get on self->weights, feature_id=%u, num_features=%zu\n", feature_id, num_features);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
weights = kh_value(self->weights, k);
|
||||
uint32_t class_id;
|
||||
|
||||
uint64_t update_count = update_counts[feature_id];
|
||||
bool keep_feature = update_count >= self->min_updates;
|
||||
|
||||
uint32_t new_feature_id = next_feature_id;
|
||||
|
||||
if (keep_feature) {
|
||||
kh_foreach(weights, class_id, weight, {
|
||||
weight.total += (updates - weight.last_updated) * weight.value;
|
||||
double value = weight.total / updates;
|
||||
sparse_matrix_append(averaged_weights, class_id, value);
|
||||
})
|
||||
|
||||
sparse_matrix_finalize_row(averaged_weights);
|
||||
next_feature_id++;
|
||||
}
|
||||
|
||||
|
||||
if (!keep_feature || new_feature_id != feature_id) {
|
||||
feature = feature_keys[feature_id];
|
||||
k = kh_get(str_uint32, features, feature);
|
||||
if (k != kh_end(features)) {
|
||||
if (keep_feature) {
|
||||
kh_value(features, k) = new_feature_id;
|
||||
} else {
|
||||
kh_del(str_uint32, features, k);
|
||||
}
|
||||
} else {
|
||||
log_error("Error in kh_get on features\n");
|
||||
crf_averaged_perceptron_trainer_destroy(self);
|
||||
free(feature_keys);
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
free(feature_keys);
|
||||
|
||||
num_features = kh_size(features);
|
||||
log_info("After pruning, num_features=%zu\n", num_features);
|
||||
|
||||
sparse_matrix_t *averaged_state_trans_weights = sparse_matrix_new();
|
||||
if (averaged_state_trans_weights == NULL) {
|
||||
log_error("Error creating averaged_state_trans_weights\n");
|
||||
return NULL;
|
||||
}
|
||||
|
||||
size_t num_prev_tag_features = kh_size(prev_tag_features);
|
||||
|
||||
char **prev_tag_feature_keys = malloc(sizeof(char *) * num_prev_tag_features);
|
||||
|
||||
|
||||
kh_foreach(prev_tag_features, feature, feature_id, {
|
||||
if (feature_id >= num_prev_tag_features) {
|
||||
free(prev_tag_feature_keys);
|
||||
log_error("Error populating prev_tag_feature_keys\n");
|
||||
return NULL;
|
||||
}
|
||||
prev_tag_feature_keys[feature_id] = (char *)feature;
|
||||
})
|
||||
|
||||
khash_t(prev_tag_class_weights) *prev_tag_weights;
|
||||
|
||||
log_info("Pruning previous tag features, num_prev_tag_features=%zu\n", num_prev_tag_features);
|
||||
|
||||
uint32_t next_prev_tag_feature_id = 0;
|
||||
|
||||
uint64_t *prev_tag_update_counts = self->prev_tag_update_counts->a;
|
||||
|
||||
tag_bigram_t tag_bigram;
|
||||
uint64_t tag_bigram_key;
|
||||
|
||||
for (feature_id = 0; feature_id < num_prev_tag_features; feature_id++) {
|
||||
k = kh_get(feature_prev_tag_class_weights, self->prev_tag_weights, feature_id);
|
||||
if (k == kh_end(self->prev_tag_weights)) {
|
||||
sparse_matrix_destroy(averaged_state_trans_weights);
|
||||
free(prev_tag_feature_keys);
|
||||
log_error("Error in kh_get self->prev_tag_weights\n");
|
||||
return NULL;
|
||||
}
|
||||
|
||||
prev_tag_weights = kh_value(self->prev_tag_weights, k);
|
||||
|
||||
uint64_t update_count = prev_tag_update_counts[feature_id];
|
||||
bool keep_feature = update_count >= self->min_updates;
|
||||
|
||||
uint32_t new_feature_id = next_prev_tag_feature_id;
|
||||
|
||||
if (keep_feature) {
|
||||
kh_foreach(prev_tag_weights, tag_bigram_key, weight, {
|
||||
tag_bigram.value = tag_bigram_key;
|
||||
weight.total += (updates - weight.last_updated) * weight.value;
|
||||
double value = weight.total / updates;
|
||||
class_id = tag_bigram_class_id(self, tag_bigram);
|
||||
sparse_matrix_append(averaged_state_trans_weights, class_id, value);
|
||||
})
|
||||
|
||||
sparse_matrix_finalize_row(averaged_state_trans_weights);
|
||||
|
||||
next_prev_tag_feature_id++;
|
||||
}
|
||||
|
||||
if (!keep_feature || new_feature_id != feature_id) {
|
||||
feature = prev_tag_feature_keys[feature_id];
|
||||
k = kh_get(str_uint32, prev_tag_features, feature);
|
||||
if (k != kh_end(prev_tag_features)) {
|
||||
if (keep_feature) {
|
||||
kh_value(prev_tag_features, k) = new_feature_id;
|
||||
} else {
|
||||
kh_del(str_uint32, prev_tag_features, k);
|
||||
}
|
||||
} else {
|
||||
log_error("Error in kh_get on prev_tag_features\n");
|
||||
crf_averaged_perceptron_trainer_destroy(self);
|
||||
free(prev_tag_feature_keys);
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
free(prev_tag_feature_keys);
|
||||
|
||||
num_prev_tag_features = kh_size(prev_tag_features);
|
||||
log_info("After pruning, num_prev_tag_features=%zu\n", num_prev_tag_features);
|
||||
|
||||
|
||||
size_t num_classes = self->base_trainer->num_classes;
|
||||
|
||||
double_matrix_t *averaged_trans_weights = double_matrix_new_zeros(num_classes, num_classes);
|
||||
if (averaged_trans_weights == NULL) {
|
||||
log_error("Error creating double matrix for transition weights\n");
|
||||
return NULL;
|
||||
}
|
||||
|
||||
double *trans = averaged_trans_weights->values;
|
||||
|
||||
kh_foreach(self->trans_weights, tag_bigram_key, weight, {
|
||||
tag_bigram.value = tag_bigram_key;
|
||||
weight.total += (updates - weight.last_updated) * weight.value;
|
||||
double value = weight.total / updates;
|
||||
class_id = tag_bigram_class_id(self, tag_bigram);
|
||||
trans[class_id] = value;
|
||||
})
|
||||
|
||||
crf_t *crf = malloc(sizeof(crf_t));
|
||||
|
||||
crf->num_classes = num_classes;
|
||||
crf->weights = averaged_weights;
|
||||
crf->state_trans_weights = averaged_state_trans_weights;
|
||||
crf->trans_weights = averaged_trans_weights;
|
||||
crf->classes = self->base_trainer->class_strings;
|
||||
self->base_trainer->class_strings = NULL;
|
||||
|
||||
trie_t *state_features = trie_new_from_hash(features);
|
||||
if (state_features == NULL) {
|
||||
crf_averaged_perceptron_trainer_destroy(self);
|
||||
log_error("Error creating state_features\n");
|
||||
return NULL;
|
||||
}
|
||||
|
||||
crf->state_features = state_features;
|
||||
|
||||
trie_t *state_trans_features = trie_new_from_hash(prev_tag_features);
|
||||
if (state_trans_features == NULL) {
|
||||
crf_averaged_perceptron_trainer_destroy(self);
|
||||
log_error("Error creating state_trans_features\n");
|
||||
return NULL;
|
||||
}
|
||||
|
||||
crf->state_trans_features = state_trans_features;
|
||||
|
||||
crf->viterbi = self->viterbi;
|
||||
self->viterbi = NULL;
|
||||
|
||||
crf->context = self->base_trainer->context;
|
||||
self->base_trainer->context = NULL;
|
||||
|
||||
crf_averaged_perceptron_trainer_destroy(self);
|
||||
|
||||
return crf;
|
||||
}
|
||||
67
src/crf_trainer_averaged_perceptron.h
Normal file
67
src/crf_trainer_averaged_perceptron.h
Normal file
@@ -0,0 +1,67 @@
|
||||
#ifndef CRF_AVERAGED_PERCEPTRON_TRAINER_H
|
||||
#define CRF_AVERAGED_PERCEPTRON_TRAINER_H
|
||||
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include "averaged_perceptron_trainer.h"
|
||||
#include "crf.h"
|
||||
#include "crf_trainer.h"
|
||||
#include "collections.h"
|
||||
#include "string_utils.h"
|
||||
#include "tokens.h"
|
||||
#include "trie.h"
|
||||
#include "trie_utils.h"
|
||||
|
||||
typedef union tag_bigram {
|
||||
uint64_t value;
|
||||
struct {
|
||||
uint32_t prev_class_id:32;
|
||||
uint32_t class_id:32;
|
||||
};
|
||||
} tag_bigram_t;
|
||||
|
||||
KHASH_MAP_INIT_INT64(prev_tag_class_weights, class_weight_t)
|
||||
|
||||
KHASH_MAP_INIT_INT(feature_prev_tag_class_weights, khash_t(prev_tag_class_weights) *)
|
||||
|
||||
typedef struct crf_averaged_perceptron_trainer {
|
||||
crf_trainer_t *base_trainer;
|
||||
uint64_t num_updates;
|
||||
uint64_t num_errors;
|
||||
uint32_t iterations;
|
||||
uint64_t min_updates;
|
||||
// {feature_id => {class_id => class_weight_t}}
|
||||
khash_t(feature_class_weights) *weights;
|
||||
khash_t(feature_prev_tag_class_weights) *prev_tag_weights;
|
||||
khash_t(prev_tag_class_weights) *trans_weights;
|
||||
uint64_array *update_counts;
|
||||
uint64_array *prev_tag_update_counts;
|
||||
cstring_array *sequence_features;
|
||||
uint32_array *sequence_features_indptr;
|
||||
cstring_array *sequence_prev_tag_features;
|
||||
uint32_array *sequence_prev_tag_features_indptr;
|
||||
uint32_array *label_ids;
|
||||
uint32_array *viterbi;
|
||||
} crf_averaged_perceptron_trainer_t;
|
||||
|
||||
crf_averaged_perceptron_trainer_t *crf_averaged_perceptron_trainer_new(size_t num_classes, size_t min_updates);
|
||||
|
||||
uint32_t crf_averaged_perceptron_trainer_predict(crf_averaged_perceptron_trainer_t *self, cstring_array *features);
|
||||
|
||||
bool crf_averaged_perceptron_trainer_train_example(crf_averaged_perceptron_trainer_t *self,
|
||||
void *tagger,
|
||||
void *context,
|
||||
cstring_array *features,
|
||||
cstring_array *prev_tag_features,
|
||||
tagger_feature_function feature_function,
|
||||
tokenized_string_t *tokenized,
|
||||
cstring_array *labels
|
||||
);
|
||||
|
||||
crf_t *crf_averaged_perceptron_trainer_finalize(crf_averaged_perceptron_trainer_t *self);
|
||||
|
||||
void crf_averaged_perceptron_trainer_destroy(crf_averaged_perceptron_trainer_t *self);
|
||||
|
||||
|
||||
#endif
|
||||
Reference in New Issue
Block a user