From ef8768281b7c35621ea6e39c40b69c77401ab984 Mon Sep 17 00:00:00 2001
From: Al
Date: Fri, 10 Mar 2017 02:06:45 -0500
Subject: [PATCH] [parser/crf] adding runtime CRF tagger, which can be loaded/used once trained. Currently only does Viterbi inference, can add top-N and/or sequence probabilities later

---
Reviewer note: this is a revised copy of the patch (NULL guards, fclose
check in crf_save, leak fix in crf_read, restored system includes in crf.h).
The blob ids on the "index" lines are those of the original commit and do
not describe the revised file contents.

 src/crf.c | 401 ++++++++++++++++++++++++++++++++++++++++++++++++++++++
 src/crf.h |  58 ++++++++
 2 files changed, 459 insertions(+)
 create mode 100644 src/crf.c
 create mode 100644 src/crf.h

diff --git a/src/crf.c b/src/crf.c
new file mode 100644
index 00000000..6e63d9d2
--- /dev/null
+++ b/src/crf.c
@@ -0,0 +1,401 @@
+#include "crf.h"
+#include "log/log.h"
+
+#define CRF_SIGNATURE 0xCFCFCFCF
+
+
+/* Fetch the id stored for feature in the state-feature trie (result of trie_get_data). */
+static inline bool crf_get_feature_id(crf_t *self, char *feature, uint32_t *feature_id) {
+    return trie_get_data(self->state_features, feature, feature_id);
+}
+
+/* Fetch the id stored for feature in the state-transition feature trie. */
+static inline bool crf_get_state_trans_feature_id(crf_t *self, char *feature, uint32_t *feature_id) {
+    return trie_get_data(self->state_trans_features, feature, feature_id);
+}
+
+/*
+ * Add each stored entry of row feature_id of a sparse weight matrix to the
+ * slot of scores selected by that entry's index.
+ */
+static inline void crf_add_feature_weights(sparse_matrix_t *weights, uint32_t feature_id, double *scores) {
+    uint32_t *indptr = weights->indptr->a;
+    uint32_t *indices = weights->indices->a;
+    double *data = weights->data->a;
+
+    /* uint32_t counter: both bounds are read from the uint32_t indptr array */
+    for (uint32_t col = indptr[feature_id]; col < indptr[feature_id + 1]; col++) {
+        scores[indices[col]] += data[col];
+    }
+}
+
+/* Print both feature lists of the current token as one "{ a, b, c }" line. */
+static void crf_print_features(cstring_array *features, cstring_array *prev_tag_features) {
+    uint32_t fidx;
+    char *feature;
+    bool first = true;
+
+    printf("{ ");
+    cstring_array_foreach(features, fidx, feature, {
+        if (!first) printf(", ");
+        printf("%s", feature);
+        first = false;
+    })
+    /* "first" carries over so the two lists are separated by ", " as well */
+    cstring_array_foreach(prev_tag_features, fidx, feature, {
+        if (!first) printf(", ");
+        printf("%s", feature);
+        first = false;
+    })
+    printf(" }\n");
+}
+
+/*
+ * Compute the scores for every token of tokenized into self->context.
+ *
+ * The context is sized to the token count and reset, and trans_weights is
+ * copied into its transition matrix. Then, for each token, feature_function
+ * is called; the weights of every string in features that is found in the
+ * state-feature trie are added to the token's state scores, and likewise
+ * prev_tag_features / state_trans_weights for its state-transition scores.
+ * Strings that are not found in the tries are skipped.
+ *
+ * Returns false if self, features, prev_tag_features, feature_function or
+ * tokenized is NULL, if the transition weights cannot be copied, or if
+ * feature_function fails for a token. With print_features set, the features
+ * of every token are also printed to stdout.
+ */
+bool crf_tagger_score(crf_t *self, void *tagger, void *tagger_context, cstring_array *features, cstring_array *prev_tag_features, tagger_feature_function feature_function, tokenized_string_t *tokenized, bool print_features) {
+    if (self == NULL || features == NULL || prev_tag_features == NULL ||
+        feature_function == NULL || tokenized == NULL) {
+        return false;
+    }
+    size_t num_tokens = tokenized->tokens->n;
+
+    crf_context_t *crf_context = self->context;
+    crf_context_set_num_items(crf_context, num_tokens);
+    crf_context_reset(crf_context, CRF_CONTEXT_RESET_ALL);
+
+    if (!double_matrix_copy(self->trans_weights, crf_context->trans)) {
+        return false;
+    }
+
+    for (uint32_t t = 0; t < num_tokens; t++) {
+        cstring_array_clear(features);
+        cstring_array_clear(prev_tag_features);
+
+        if (!feature_function(tagger, tagger_context, tokenized, t)) {
+            log_error("Could not add address parser features\n");
+            return false;
+        }
+
+        if (print_features) {
+            crf_print_features(features, prev_tag_features);
+        }
+
+        uint32_t fidx;
+        char *feature;
+        uint32_t feature_id;
+
+        double *state_scores = state_score(crf_context, t);
+
+        cstring_array_foreach(features, fidx, feature, {
+            if (!crf_get_feature_id(self, feature, &feature_id)) {
+                continue;
+            }
+            crf_add_feature_weights(self->weights, feature_id, state_scores);
+        })
+
+        double *state_trans_scores = state_trans_score_all(crf_context, t);
+
+        cstring_array_foreach(prev_tag_features, fidx, feature, {
+            if (!crf_get_state_trans_feature_id(self, feature, &feature_id)) {
+                continue;
+            }
+            // Note: here there are L * L classes
+            crf_add_feature_weights(self->state_trans_weights, feature_id, state_trans_scores);
+        })
+    }
+    return true;
+}
+
+/*
+ * Run crf_tagger_score, then Viterbi-decode self->context into self->viterbi
+ * (resized to one entry per token). The value returned by crf_context_viterbi
+ * is stored in *score; score may be NULL when the caller does not need it.
+ * Returns false if scoring fails.
+ */
+bool crf_tagger_score_viterbi(crf_t *self, void *tagger, void *tagger_context, cstring_array *features, cstring_array *prev_tag_features, tagger_feature_function feature_function, tokenized_string_t *tokenized, double *score, bool print_features) {
+    if (!crf_tagger_score(self, tagger, tagger_context, features, prev_tag_features, feature_function, tokenized, print_features)) {
+        return false;
+    }
+
+    size_t num_tokens = tokenized->tokens->n;
+
+    uint32_array_resize_fixed(self->viterbi, num_tokens);
+    double viterbi_score = crf_context_viterbi(self->context, self->viterbi->a);
+
+    if (score != NULL) {
+        *score = viterbi_score;
+    }
+
+    return true;
+}
+
+
+/*
+ * Tag tokenized and append one class name per entry of self->viterbi to labels.
+ * Returns false if labels is NULL, if scoring/decoding fails, or if a
+ * predicted class id has no name in self->classes (labels appended before
+ * that point are left in place).
+ */
+bool crf_tagger_predict(crf_t *self, void *tagger, void *context, cstring_array *features, cstring_array *prev_tag_features, cstring_array *labels, tagger_feature_function feature_function, tokenized_string_t *tokenized, bool print_features) {
+    double score;
+
+    if (labels == NULL) return false;
+    if (!crf_tagger_score_viterbi(self, tagger, context, features, prev_tag_features, feature_function, tokenized, &score, print_features)) {
+        return false;
+    }
+
+    uint32_t *viterbi = self->viterbi->a;
+
+    if (self->viterbi->n == 0) {
+        log_error("self->viterbi->n == 0\n");
+    }
+
+    for (size_t i = 0; i < self->viterbi->n; i++) {
+        char *predicted = cstring_array_get_string(self->classes, viterbi[i]);
+        /* never hand a NULL string to cstring_array_add_string */
+        if (predicted == NULL) {
+            log_error("predicted class has no label\n");
+            return false;
+        }
+        cstring_array_add_string(labels, predicted);
+    }
+
+    return true;
+}
+
+
+/*
+ * Serialize the model to f: signature, num_classes, the length and bytes of
+ * the class-name buffer, then state_features, weights, state_trans_features,
+ * state_trans_weights and trans_weights, in that order (crf_read expects the
+ * same order). Returns false if any required member is NULL or a write fails.
+ */
+bool crf_write(crf_t *self, FILE *f) {
+    if (self == NULL || f == NULL || self->weights == NULL || self->classes == NULL ||
+        self->state_features == NULL || self->state_trans_features == NULL ||
+        self->state_trans_weights == NULL || self->trans_weights == NULL) {
+        log_info("something was NULL\n");
+        return false;
+    }
+
+    if (!file_write_uint32(f, CRF_SIGNATURE) ||
+        !file_write_uint32(f, self->num_classes)) {
+        log_info("error writing header\n");
+        return false;
+    }
+
+    uint64_t classes_str_len = (uint64_t) cstring_array_used(self->classes);
+    if (!file_write_uint64(f, classes_str_len)) {
+        log_info("error writing classes_str_len\n");
+        return false;
+    }
+
+    if (!file_write_chars(f, self->classes->str->a, classes_str_len)) {
+        log_info("error writing chars\n");
+        return false;
+    }
+
+    if (!trie_write(self->state_features, f)) {
+        log_info("error state_features\n");
+        return false;
+    }
+
+    if (!sparse_matrix_write(self->weights, f)) {
+        log_info("error weights\n");
+        return false;
+    }
+
+    if (!trie_write(self->state_trans_features, f)) {
+        log_info("error state_trans_features\n");
+        return false;
+    }
+
+    if (!sparse_matrix_write(self->state_trans_weights, f)) {
+        log_info("error state_trans_weights\n");
+        return false;
+    }
+
+    if (!double_matrix_write(self->trans_weights, f)) {
+        log_info("error trans_weights\n");
+        return false;
+    }
+
+    return true;
+}
+
+
+/*
+ * Write the model to filename (opened with "wb"). Returns false if an
+ * argument is NULL, the file cannot be opened, crf_write fails, or the file
+ * cannot be closed cleanly.
+ */
+bool crf_save(crf_t *self, char *filename) {
+    if (self == NULL || filename == NULL) {
+        log_info("crf or filename was NULL\n");
+        return false;
+    }
+    FILE *f = fopen(filename, "wb");
+    if (f == NULL) return false;
+    bool ret_val = crf_write(self, f);
+    /* buffered data is flushed on close, so a failing fclose is a failed save */
+    if (fclose(f) != 0) {
+        ret_val = false;
+    }
+    return ret_val;
+}
+
+
+/*
+ * Read a model from f in the layout produced by crf_write and create the
+ * viterbi array and CRF context for it. Returns NULL on a NULL stream, a
+ * signature mismatch, num_classes == 0, an empty class-name buffer, or any
+ * read/allocation failure. The caller releases the result with crf_destroy.
+ */
+crf_t *crf_read(FILE *f) {
+    if (f == NULL) return NULL;
+
+    uint32_t signature;
+
+    if (!file_read_uint32(f, &signature) || signature != CRF_SIGNATURE) {
+        return NULL;
+    }
+
+    crf_t *crf = calloc(1, sizeof(crf_t));
+    if (crf == NULL) return NULL;
+
+    if (!file_read_uint32(f, &crf->num_classes) ||
+        crf->num_classes == 0) {
+        free(crf);
+        return NULL;
+    }
+
+    uint64_t classes_str_len;
+
+    /* num_classes > 0 here, so a zero-length class-name buffer is invalid */
+    if (!file_read_uint64(f, &classes_str_len) || classes_str_len == 0) {
+        goto exit_crf_created;
+    }
+
+    char_array *array = char_array_new_size(classes_str_len);
+
+    if (array == NULL) {
+        goto exit_crf_created;
+    }
+
+    if (!file_read_chars(f, array->a, classes_str_len)) {
+        char_array_destroy(array);
+        goto exit_crf_created;
+    }
+
+    array->n = classes_str_len;
+
+    crf->classes = cstring_array_from_char_array(array);
+    if (crf->classes == NULL) {
+        /* NOTE(review): assumes a failed conversion leaves array owned by us - confirm */
+        char_array_destroy(array);
+        goto exit_crf_created;
+    }
+
+    crf->state_features = trie_read(f);
+    if (crf->state_features == NULL) {
+        goto exit_crf_created;
+    }
+
+    crf->weights = sparse_matrix_read(f);
+    if (crf->weights == NULL) {
+        goto exit_crf_created;
+    }
+
+    crf->state_trans_features = trie_read(f);
+    if (crf->state_trans_features == NULL) {
+        goto exit_crf_created;
+    }
+
+    crf->state_trans_weights = sparse_matrix_read(f);
+    if (crf->state_trans_weights == NULL) {
+        goto exit_crf_created;
+    }
+
+    crf->trans_weights = double_matrix_read(f);
+    if (crf->trans_weights == NULL) {
+        goto exit_crf_created;
+    }
+
+    crf->viterbi = uint32_array_new();
+    if (crf->viterbi == NULL) {
+        goto exit_crf_created;
+    }
+
+    crf->context = crf_context_new(CRF_CONTEXT_VITERBI | CRF_CONTEXT_MARGINALS, crf->num_classes, CRF_CONTEXT_DEFAULT_NUM_ITEMS);
+    if (crf->context == NULL) {
+        goto exit_crf_created;
+    }
+
+    return crf;
+
+exit_crf_created:
+    crf_destroy(crf);
+    return NULL;
+}
+
+/* Open filename with "rb" and read a model from it; NULL on any failure. */
+crf_t *crf_load(char *filename) {
+    if (filename == NULL) return NULL;
+    FILE *f = fopen(filename, "rb");
+    if (f == NULL) return NULL;
+    crf_t *crf = crf_read(f);
+    fclose(f);
+    return crf;
+}
+
+/* Release every member that is set, then the model itself. NULL is a no-op. */
+void crf_destroy(crf_t *self) {
+    if (self == NULL) return;
+
+    if (self->classes != NULL) {
+        cstring_array_destroy(self->classes);
+    }
+
+    if (self->state_features != NULL) {
+        trie_destroy(self->state_features);
+    }
+
+    if (self->weights != NULL) {
+        sparse_matrix_destroy(self->weights);
+    }
+
+    if (self->state_trans_features != NULL) {
+        trie_destroy(self->state_trans_features);
+    }
+
+    if (self->state_trans_weights != NULL) {
+        sparse_matrix_destroy(self->state_trans_weights);
+    }
+
+    if (self->trans_weights != NULL) {
+        double_matrix_destroy(self->trans_weights);
+    }
+
+    if (self->viterbi != NULL) {
+        uint32_array_destroy(self->viterbi);
+    }
+
+    if (self->context != NULL) {
+        crf_context_destroy(self->context);
+    }
+
+    free(self);
+}
diff --git a/src/crf.h b/src/crf.h
new file mode 100644
index 00000000..af0dabeb
--- /dev/null
+++ b/src/crf.h
@@ -0,0 +1,58 @@
+/*
+crf.h
+---------------------------------------------------------------
+A linear-chain CRF tagger tries to find the best labeling
+for a sequence. The feature function can use the current token,
+surrounding tokens and n (typically n=2) previous predictions
+to predict the current transition matrix.
+
+*/
+
+#ifndef CRF_H
+#define CRF_H
+
+#include <stdio.h>
+#include <stdlib.h>
+#include <stdint.h>
+#include <stdbool.h>
+
+#include "collections.h"
+#include "crf_context.h"
+#include "matrix.h"
+#include "sparse_matrix.h"
+#include "tagger.h"
+#include "trie.h"
+
+/* A CRF model as read by crf_read, plus the scratch state used while tagging. */
+typedef struct crf {
+    uint32_t num_classes;                  /* stored in the model file; crf_read rejects 0 */
+    cstring_array *classes;                /* class names, looked up by predicted class id */
+    trie_t *state_features;                /* state feature string -> row of weights */
+    sparse_matrix_t *weights;              /* state feature weights */
+    trie_t *state_trans_features;          /* feature string -> row of state_trans_weights */
+    sparse_matrix_t *state_trans_weights;  /* state-transition feature weights */
+    double_matrix_t *trans_weights;        /* copied into the context's transition matrix before scoring */
+    uint32_array *viterbi;                 /* one class id per token from the last Viterbi decode */
+    crf_context_t *context;                /* scoring/decoding workspace */
+} crf_t;
+
+/* Compute the per-token scores into self->context; false on failure. */
+bool crf_tagger_score(crf_t *self, void *tagger, void *tagger_context, cstring_array *features, cstring_array *prev_tag_features, tagger_feature_function feature_function, tokenized_string_t *tokenized, bool print_features);
+/* crf_tagger_score followed by Viterbi decoding into self->viterbi; score may be NULL. */
+bool crf_tagger_score_viterbi(crf_t *self, void *tagger, void *tagger_context, cstring_array *features, cstring_array *prev_tag_features, tagger_feature_function feature_function, tokenized_string_t *tokenized, double *score, bool print_features);
+
+/* Decode tokenized and append one class name per decoded token to labels. */
+bool crf_tagger_predict(crf_t *self, void *tagger, void *context, cstring_array *features, cstring_array *prev_tag_features, cstring_array *labels, tagger_feature_function feature_function, tokenized_string_t *tokenized, bool print_features);
+
+/* Serialize the model to an open stream (crf_write) or to a file (crf_save). */
+bool crf_write(crf_t *self, FILE *f);
+bool crf_save(crf_t *self, char *filename);
+
+/* Deserialize a model; NULL on failure. Release the result with crf_destroy. */
+crf_t *crf_read(FILE *f);
+crf_t *crf_load(char *filename);
+
+/* Free a model and each of its members that is set; NULL is accepted. */
+void crf_destroy(crf_t *self);
+
+#endif