[parser/crf] adding runtime CRF tagger, which can be loaded/used once trained. Currently only does Viterbi inference, can add top-N and/or sequence probabilities later
This commit is contained in:
331
src/crf.c
Normal file
331
src/crf.c
Normal file
@@ -0,0 +1,331 @@
|
||||
#include "crf.h"
|
||||
#include "log/log.h"
|
||||
|
||||
#define CRF_SIGNATURE 0xCFCFCFCF
|
||||
|
||||
|
||||
/*
 * Resolve a state feature string to its numeric id.
 *
 * Looks `feature` up in self->state_features with trie_get_data, handing it
 * `feature_id` as the output location, and returns that call's result
 * unchanged. `self` is dereferenced without a NULL check.
 */
static inline bool crf_get_feature_id(crf_t *self, char *feature, uint32_t *feature_id) {
    trie_t *state_feature_ids = self->state_features;
    bool found = trie_get_data(state_feature_ids, feature, feature_id);
    return found;
}
|
||||
|
||||
/*
 * Resolve a state-transition feature string to its numeric id.
 *
 * Same lookup as crf_get_feature_id, but against
 * self->state_trans_features. The result of trie_get_data is returned
 * unchanged; `self` is dereferenced without a NULL check.
 */
static inline bool crf_get_state_trans_feature_id(crf_t *self, char *feature, uint32_t *feature_id) {
    trie_t *state_trans_feature_ids = self->state_trans_features;
    bool found = trie_get_data(state_trans_feature_ids, feature, feature_id);
    return found;
}
|
||||
|
||||
bool crf_tagger_score(crf_t *self, void *tagger, void *tagger_context, cstring_array *features, cstring_array *prev_tag_features, tagger_feature_function feature_function, tokenized_string_t *tokenized, bool print_features) {
|
||||
if (self == NULL || feature_function == NULL || tokenized == NULL ) {
|
||||
return false;
|
||||
}
|
||||
size_t num_tokens = tokenized->tokens->n;
|
||||
|
||||
crf_context_t *crf_context = self->context;
|
||||
crf_context_set_num_items(crf_context, num_tokens);
|
||||
crf_context_reset(crf_context, CRF_CONTEXT_RESET_ALL);
|
||||
|
||||
if (!double_matrix_copy(self->trans_weights, crf_context->trans)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (uint32_t t = 0; t < num_tokens; t++) {
|
||||
cstring_array_clear(features);
|
||||
cstring_array_clear(prev_tag_features);
|
||||
|
||||
if (!feature_function(tagger, tagger_context, tokenized, t)) {
|
||||
log_error("Could not add address parser features\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
uint32_t fidx;
|
||||
char *feature;
|
||||
|
||||
if (print_features) {
|
||||
printf("{ ");
|
||||
size_t num_features = cstring_array_num_strings(features);
|
||||
cstring_array_foreach(features, fidx, feature, {
|
||||
printf("%s", feature);
|
||||
if (fidx < num_features - 1) printf(", ");
|
||||
})
|
||||
size_t num_prev_tag_features = cstring_array_num_strings(prev_tag_features);
|
||||
cstring_array_foreach(prev_tag_features, fidx, feature, {
|
||||
printf("%s", feature);
|
||||
if (fidx < num_prev_tag_features - 1) printf(", ");
|
||||
})
|
||||
printf(" }\n");
|
||||
}
|
||||
|
||||
uint32_t feature_id;
|
||||
|
||||
double *state_scores = state_score(crf_context, t);
|
||||
|
||||
uint32_t *indptr = self->weights->indptr->a;
|
||||
uint32_t *indices = self->weights->indices->a;
|
||||
double *data = self->weights->data->a;
|
||||
|
||||
cstring_array_foreach(features, fidx, feature, {
|
||||
if (!crf_get_feature_id(self, feature, &feature_id)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
for (int col = indptr[feature_id]; col < indptr[feature_id + 1]; col++) {
|
||||
uint32_t class_id = indices[col];
|
||||
state_scores[class_id] += data[col];
|
||||
}
|
||||
})
|
||||
|
||||
double *state_trans_scores = state_trans_score_all(crf_context, t);
|
||||
|
||||
indptr = self->state_trans_weights->indptr->a;
|
||||
indices = self->state_trans_weights->indices->a;
|
||||
data = self->state_trans_weights->data->a;
|
||||
|
||||
cstring_array_foreach(prev_tag_features, fidx, feature, {
|
||||
if (!crf_get_state_trans_feature_id(self, feature, &feature_id)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
for (int col = indptr[feature_id]; col < indptr[feature_id + 1]; col++) {
|
||||
// Note: here there are L * L classes
|
||||
uint32_t class_id = indices[col];
|
||||
state_trans_scores[class_id] += data[col];
|
||||
}
|
||||
})
|
||||
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/*
 * Score the sequence and run Viterbi decoding over it.
 *
 * Calls crf_tagger_score with the same arguments, resizes self->viterbi to
 * one entry per token and lets crf_context_viterbi fill it with the decoded
 * class ids. The value returned by crf_context_viterbi is stored in *score.
 *
 * `score` may be NULL when the caller does not need the value.
 *
 * Returns false if scoring fails or if self->viterbi does not end up with
 * one slot per token; true otherwise.
 */
bool crf_tagger_score_viterbi(crf_t *self, void *tagger, void *tagger_context, cstring_array *features, cstring_array *prev_tag_features, tagger_feature_function feature_function, tokenized_string_t *tokenized, double *score, bool print_features) {
    if (!crf_tagger_score(self, tagger, tagger_context, features, prev_tag_features, feature_function, tokenized, print_features)) {
        return false;
    }

    size_t num_tokens = tokenized->tokens->n;

    uint32_array_resize_fixed(self->viterbi, num_tokens);
    // Decoding writes one id per token into the buffer, so refuse to go on
    // unless the array really has the requested length.
    if (self->viterbi->n != num_tokens) {
        log_error("Could not resize viterbi array\n");
        return false;
    }

    double viterbi_score = crf_context_viterbi(self->context, self->viterbi->a);

    if (score != NULL) {
        *score = viterbi_score;
    }

    return true;
}
|
||||
|
||||
|
||||
/*
 * Predict a label for every token of `tokenized`.
 *
 * Runs crf_tagger_score_viterbi and then, for each decoded class id in
 * self->viterbi, appends the matching string from self->classes to
 * `labels`. `labels` is appended to, not cleared.
 *
 * An empty Viterbi path is logged as an error but still returns true with
 * nothing appended.
 *
 * Returns false if labels is NULL, if scoring/decoding fails, or if a
 * decoded class id has no string in self->classes (labels appended before
 * that point are left in place).
 */
bool crf_tagger_predict(crf_t *self, void *tagger, void *context, cstring_array *features, cstring_array *prev_tag_features, cstring_array *labels, tagger_feature_function feature_function, tokenized_string_t *tokenized, bool print_features) {
    if (labels == NULL) return false;

    double score;

    if (!crf_tagger_score_viterbi(self, tagger, context, features, prev_tag_features, feature_function, tokenized, &score, print_features)) {
        return false;
    }

    uint32_t *viterbi = self->viterbi->a;
    size_t num_predicted = self->viterbi->n;

    if (num_predicted == 0) {
        log_error("self->viterbi->n == 0\n");
    }

    for (size_t i = 0; i < num_predicted; i++) {
        char *predicted = cstring_array_get_string(self->classes, viterbi[i]);
        // Never hand a NULL string to cstring_array_add_string.
        if (predicted == NULL) {
            log_error("No class label for predicted class id\n");
            return false;
        }
        cstring_array_add_string(labels, predicted);
    }

    return true;
}
|
||||
|
||||
|
||||
/*
 * Serialize a CRF model to an open stream.
 *
 * Layout, in write order:
 *   1. CRF_SIGNATURE                 (uint32)
 *   2. num_classes                   (uint32)
 *   3. byte length of the class strings (uint64), then those bytes
 *   4. state_features trie
 *   5. weights sparse matrix
 *   6. state_trans_features trie
 *   7. state_trans_weights sparse matrix
 *   8. trans_weights dense matrix
 *
 * Every member that gets written is checked for NULL up front, so nothing is
 * emitted for an incomplete model. The stream is neither flushed nor closed
 * here.
 *
 * Returns true if every piece was written, false (after logging) otherwise.
 * On a write failure the stream may hold a partial model.
 */
bool crf_write(crf_t *self, FILE *f) {
    if (self == NULL || f == NULL || self->weights == NULL || self->classes == NULL ||
        self->state_features == NULL || self->state_trans_features == NULL ||
        self->state_trans_weights == NULL || self->trans_weights == NULL) {
        log_error("something was NULL\n");
        return false;
    }

    if (!file_write_uint32(f, CRF_SIGNATURE) ||
        !file_write_uint32(f, self->num_classes)) {
        log_error("error writing header\n");
        return false;
    }

    uint64_t classes_str_len = (uint64_t) cstring_array_used(self->classes);
    if (!file_write_uint64(f, classes_str_len)) {
        log_error("error writing classes_str_len\n");
        return false;
    }

    if (!file_write_chars(f, self->classes->str->a, classes_str_len)) {
        log_error("error writing chars\n");
        return false;
    }

    if (!trie_write(self->state_features, f)) {
        log_error("error state_features\n");
        return false;
    }

    if (!sparse_matrix_write(self->weights, f)) {
        log_error("error weights\n");
        return false;
    }

    if (!trie_write(self->state_trans_features, f)) {
        log_error("error state_trans_features\n");
        return false;
    }

    if (!sparse_matrix_write(self->state_trans_weights, f)) {
        log_error("error state_trans_weights\n");
        return false;
    }

    if (!double_matrix_write(self->trans_weights, f)) {
        log_error("error trans_weights\n");
        return false;
    }

    return true;
}
|
||||
|
||||
|
||||
/*
 * Write a CRF model to `filename`, creating or truncating the file.
 *
 * The file is opened in binary mode, filled by crf_write and closed again.
 * Closing a write stream flushes buffered data, so a failing fclose means
 * the file on disk may be incomplete; that case is reported as failure
 * instead of being ignored.
 *
 * Returns true only if the model was written and the file closed cleanly.
 * Returns false if self or filename is NULL, if the file cannot be opened,
 * if crf_write fails, or if fclose fails.
 */
bool crf_save(crf_t *self, char *filename) {
    if (self == NULL || filename == NULL) {
        log_error("crf or filename was NULL\n");
        return false;
    }

    FILE *f = fopen(filename, "wb");
    if (f == NULL) return false;

    bool ret_val = crf_write(self, f);

    // Always close, but let a failed close override an otherwise good write.
    if (fclose(f) != 0) {
        ret_val = false;
    }

    return ret_val;
}
|
||||
|
||||
|
||||
crf_t *crf_read(FILE *f) {
|
||||
if (f == NULL) return NULL;
|
||||
|
||||
uint32_t signature;
|
||||
|
||||
if (!file_read_uint32(f, &signature) || signature != CRF_SIGNATURE) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
crf_t *crf = calloc(1, sizeof(crf_t));
|
||||
if (crf == NULL) return NULL;
|
||||
|
||||
if (!file_read_uint32(f, &crf->num_classes) ||
|
||||
crf->num_classes == 0) {
|
||||
free(crf);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
uint64_t classes_str_len;
|
||||
|
||||
if (!file_read_uint64(f, &classes_str_len)) {
|
||||
goto exit_crf_created;
|
||||
}
|
||||
|
||||
char_array *array = char_array_new_size(classes_str_len);
|
||||
|
||||
if (array == NULL) {
|
||||
goto exit_crf_created;
|
||||
}
|
||||
|
||||
if (!file_read_chars(f, array->a, classes_str_len)) {
|
||||
char_array_destroy(array);
|
||||
goto exit_crf_created;
|
||||
}
|
||||
|
||||
array->n = classes_str_len;
|
||||
|
||||
crf->classes = cstring_array_from_char_array(array);
|
||||
if (crf->classes == NULL) {
|
||||
goto exit_crf_created;
|
||||
}
|
||||
|
||||
crf->state_features = trie_read(f);
|
||||
if (crf->state_features == NULL) {
|
||||
goto exit_crf_created;
|
||||
}
|
||||
|
||||
crf->weights = sparse_matrix_read(f);
|
||||
if (crf->weights == NULL) {
|
||||
goto exit_crf_created;
|
||||
}
|
||||
|
||||
crf->state_trans_features = trie_read(f);
|
||||
if (crf->state_trans_features == NULL) {
|
||||
goto exit_crf_created;
|
||||
}
|
||||
|
||||
crf->state_trans_weights = sparse_matrix_read(f);
|
||||
if (crf->state_trans_weights == NULL) {
|
||||
goto exit_crf_created;
|
||||
}
|
||||
|
||||
crf->trans_weights = double_matrix_read(f);
|
||||
if (crf->trans_weights == NULL) {
|
||||
goto exit_crf_created;
|
||||
}
|
||||
|
||||
crf->viterbi = uint32_array_new();
|
||||
if (crf->viterbi == NULL) {
|
||||
goto exit_crf_created;
|
||||
}
|
||||
|
||||
crf->context = crf_context_new(CRF_CONTEXT_VITERBI | CRF_CONTEXT_MARGINALS, crf->num_classes, CRF_CONTEXT_DEFAULT_NUM_ITEMS);
|
||||
if (crf->context == NULL) {
|
||||
goto exit_crf_created;
|
||||
}
|
||||
|
||||
return crf;
|
||||
|
||||
exit_crf_created:
|
||||
crf_destroy(crf);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
/*
 * Load a CRF model from the file at `filename`.
 *
 * Opens the file in binary mode, delegates to crf_read and closes the file
 * again. Returns the model from crf_read, or NULL if filename is NULL, the
 * file cannot be opened, or crf_read fails.
 */
crf_t *crf_load(char *filename) {
    crf_t *crf = NULL;

    if (filename != NULL) {
        FILE *f = fopen(filename, "rb");
        if (f != NULL) {
            crf = crf_read(f);
            fclose(f);
        }
    }

    return crf;
}
|
||||
|
||||
/*
 * Release a CRF model and everything it holds.
 *
 * Each member is destroyed only if it is set, so a partially constructed
 * model (as produced by a failed crf_read) is handled. Passing NULL is a
 * no-op. The members are independent of each other; they are released here
 * in the reverse of the order in which crf_read creates them.
 */
void crf_destroy(crf_t *self) {
    if (!self) return;

    if (self->context) {
        crf_context_destroy(self->context);
    }
    if (self->viterbi) {
        uint32_array_destroy(self->viterbi);
    }
    if (self->trans_weights) {
        double_matrix_destroy(self->trans_weights);
    }
    if (self->state_trans_weights) {
        sparse_matrix_destroy(self->state_trans_weights);
    }
    if (self->state_trans_features) {
        trie_destroy(self->state_trans_features);
    }
    if (self->weights) {
        sparse_matrix_destroy(self->weights);
    }
    if (self->state_features) {
        trie_destroy(self->state_features);
    }
    if (self->classes) {
        cstring_array_destroy(self->classes);
    }

    free(self);
}
|
||||
53
src/crf.h
Normal file
53
src/crf.h
Normal file
@@ -0,0 +1,53 @@
|
||||
/*
|
||||
crf.h
|
||||
---------------------------------------------------------------
|
||||
A linear-chain CRF tagger tries to find the best labeling
|
||||
for a sequence. The feature function can use the current token,
|
||||
surrounding tokens and n (typically n=2) previous predictions
|
||||
to predict the current transition matrix.
|
||||
|
||||
*/
|
||||
|
||||
#ifndef CRF_H
|
||||
#define CRF_H
|
||||
|
||||
#include <stdlib.h>
|
||||
#include <stdint.h>
|
||||
#include <stdbool.h>
|
||||
#include <string.h>
|
||||
|
||||
#include "collections.h"
|
||||
#include "crf_context.h"
|
||||
#include "matrix.h"
|
||||
#include "sparse_matrix.h"
|
||||
#include "tagger.h"
|
||||
#include "trie.h"
|
||||
|
||||
typedef struct crf {
|
||||
uint32_t num_classes;
|
||||
cstring_array *classes;
|
||||
trie_t *state_features;
|
||||
sparse_matrix_t *weights;
|
||||
trie_t *state_trans_features;
|
||||
sparse_matrix_t *state_trans_weights;
|
||||
double_matrix_t *trans_weights;
|
||||
uint32_array *viterbi;
|
||||
crf_context_t *context;
|
||||
} crf_t;
|
||||
|
||||
/*
 * Inference.
 *
 * crf_tagger_score fills the model's context with per-token scores.
 * crf_tagger_score_viterbi does that and then decodes the best path into
 * self->viterbi, reporting the path score through `score`.
 * crf_tagger_predict does both and appends one label string per token to
 * `labels`.
 *
 * All three return false on failure. crf_tagger_predict is declared once
 * here; an earlier duplicate declaration with different parameter names was
 * removed.
 */
bool crf_tagger_score(crf_t *self, void *tagger, void *tagger_context, cstring_array *features, cstring_array *prev_tag_features, tagger_feature_function feature_function, tokenized_string_t *tokenized, bool print_features);
bool crf_tagger_score_viterbi(crf_t *self, void *tagger, void *tagger_context, cstring_array *features, cstring_array *prev_tag_features, tagger_feature_function feature_function, tokenized_string_t *tokenized, double *score, bool print_features);
bool crf_tagger_predict(crf_t *self, void *tagger, void *context, cstring_array *features, cstring_array *prev_tag_features, cstring_array *labels, tagger_feature_function feature_function, tokenized_string_t *tokenized, bool print_features);

/*
 * Serialization. crf_write/crf_read work on an already open stream and leave
 * it open; crf_save/crf_load open and close the named file themselves.
 * The write functions return false on failure, the read functions NULL.
 */
bool crf_write(crf_t *self, FILE *f);
bool crf_save(crf_t *self, char *filename);

crf_t *crf_read(FILE *f);
crf_t *crf_load(char *filename);

/* Releases a model returned by crf_read/crf_load. NULL is accepted. */
void crf_destroy(crf_t *self);
|
||||
|
||||
#endif
|
||||
Reference in New Issue
Block a user