[language_classification] Runtime language classifier can now use dense or sparse weights, with a different header signature for the sparse version (using old signature for the dense version, so backward-compatible)

This commit is contained in:
Al
2017-04-02 23:51:54 -04:00
parent 835d851310
commit 5dfdd4b7eb
2 changed files with 53 additions and 16 deletions

View File

@@ -9,6 +9,7 @@
#include "unicode_scripts.h"
#define LANGUAGE_CLASSIFIER_SIGNATURE 0xCCCCCCCC
#define LANGUAGE_CLASSIFIER_SPARSE_SIGNATURE 0xC0C0C0C0
#define LANGUAGE_CLASSIFIER_SETUP_ERROR "language_classifier not loaded, run libpostal_setup_language_classifier()\n"
@@ -27,8 +28,10 @@ void language_classifier_destroy(language_classifier_t *self) {
cstring_array_destroy(self->labels);
}
if (self->weights != NULL) {
double_matrix_destroy(self->weights);
if (self->weights_type == MATRIX_DENSE && self->weights.dense != NULL) {
double_matrix_destroy(self->weights.dense);
} else if (self->weights_type == MATRIX_SPARSE && self->weights.sparse != NULL) {
sparse_matrix_destroy(self->weights.sparse);
}
free(self);
@@ -86,7 +89,14 @@ language_classifier_response_t *classify_languages(char *address) {
double_matrix_t *p_y = double_matrix_new_zeros(1, n);
language_classifier_response_t *response = NULL;
if (logistic_regression_model_expectation(classifier->weights, x, p_y)) {
bool model_exp = false;
if (classifier->weights_type == MATRIX_DENSE) {
model_exp = logistic_regression_model_expectation(classifier->weights.dense, x, p_y);
} else if (classifier->weights_type == MATRIX_SPARSE) {
model_exp = logistic_regression_model_expectation_sparse(classifier->weights.sparse, x, p_y);
}
if (model_exp) {
double *predictions = double_matrix_get_row(p_y, 0);
size_t *indices = double_array_argsort(predictions, n);
size_t num_languages = 0;
@@ -145,7 +155,11 @@ language_classifier_t *language_classifier_read(FILE *f) {
uint32_t signature;
if (!file_read_uint32(f, &signature) || signature != LANGUAGE_CLASSIFIER_SIGNATURE) {
if (!file_read_uint32(f, &signature)) {
goto exit_file_read;
}
if (signature != LANGUAGE_CLASSIFIER_SIGNATURE && signature != LANGUAGE_CLASSIFIER_SPARSE_SIGNATURE) {
goto exit_file_read;
}
@@ -190,14 +204,22 @@ language_classifier_t *language_classifier_read(FILE *f) {
}
classifier->num_labels = cstring_array_num_strings(classifier->labels);
double_matrix_t *weights = double_matrix_read(f);
if (weights == NULL) {
goto exit_classifier_created;
if (signature == LANGUAGE_CLASSIFIER_SIGNATURE) {
double_matrix_t *weights = double_matrix_read(f);
if (weights == NULL) {
goto exit_classifier_created;
}
classifier->weights_type = MATRIX_DENSE;
classifier->weights.dense = weights;
} else if (signature == LANGUAGE_CLASSIFIER_SPARSE_SIGNATURE) {
sparse_matrix_t *sparse_weights = sparse_matrix_read(f);
if (sparse_weights == NULL) {
goto exit_classifier_created;
}
classifier->weights_type = MATRIX_SPARSE;
classifier->weights.sparse = sparse_weights;
}
classifier->weights = weights;
return classifier;
exit_classifier_created:
@@ -223,12 +245,22 @@ language_classifier_t *language_classifier_load(char *path) {
bool language_classifier_write(language_classifier_t *self, FILE *f) {
if (f == NULL || self == NULL) return false;
if (!file_write_uint32(f, LANGUAGE_CLASSIFIER_SIGNATURE) ||
!trie_write(self->features, f) ||
if (self->weights_type == MATRIX_DENSE && !file_write_uint32(f, LANGUAGE_CLASSIFIER_SIGNATURE)) {
return false;
} else if (self->weights_type == MATRIX_SPARSE && !file_write_uint32(f, LANGUAGE_CLASSIFIER_SPARSE_SIGNATURE)) {
return false;
}
if (!trie_write(self->features, f) ||
!file_write_uint64(f, self->num_features) ||
!file_write_uint64(f, self->labels->str->n) ||
!file_write_chars(f, (const char *)self->labels->str->a, self->labels->str->n) ||
!double_matrix_write(self->weights, f)) {
!file_write_chars(f, (const char *)self->labels->str->a, self->labels->str->n)) {
return false;
}
if (self->weights_type == MATRIX_DENSE && !double_matrix_write(self->weights.dense, f)) {
return false;
} else if (self->weights_type == MATRIX_SPARSE && !sparse_matrix_write(self->weights.sparse, f)) {
return false;
}

View File

@@ -9,19 +9,24 @@
#include "collections.h"
#include "language_features.h"
#include "logistic_regression.h"
#include "matrix.h"
#include "tokens.h"
#include "sparse_matrix.h"
#include "string_utils.h"
#include "trie.h"
#define LANGUAGE_CLASSIFIER_FILENAME "language_classifier.dat"
#define LANGUAGE_CLASSIFIER_COUNTRY_FILENAME "language_classifier_country.dat"
/*
 * Runtime language classifier model.
 *
 * The weight matrix is held as a tagged union: weights_type says which
 * member of `weights` is the valid one (MATRIX_DENSE -> weights.dense,
 * MATRIX_SPARSE -> weights.sparse). Only the member selected by
 * weights_type should be read or destroyed.
 *
 * The struct must not also declare a plain `double_matrix_t *weights`
 * member: it would share its name with the union below, and member names
 * within a struct have to be unique.
 */
typedef struct language_classifier {
    /* Number of strings in `labels` (set from cstring_array_num_strings() on load). */
    size_t num_labels;
    /* Feature count; serialized as a uint64 right after the feature trie. */
    size_t num_features;
    /* Feature trie, serialized with trie_write(). */
    trie_t *features;
    /* Label strings. */
    cstring_array *labels;
    /* Discriminator for `weights`: MATRIX_DENSE or MATRIX_SPARSE. */
    matrix_type_t weights_type;
    union {
        /* Valid when weights_type == MATRIX_DENSE; freed with double_matrix_destroy(). */
        double_matrix_t *dense;
        /* Valid when weights_type == MATRIX_SPARSE; freed with sparse_matrix_destroy(). */
        sparse_matrix_t *sparse;
    } weights;
} language_classifier_t;