[language_classification] Runtime language classifier can now use dense or sparse weights, with a different header signature for the sparse version (using old signature for the dense version, so backward-compatible)
This commit is contained in:
@@ -9,6 +9,7 @@
|
||||
#include "unicode_scripts.h"
|
||||
|
||||
#define LANGUAGE_CLASSIFIER_SIGNATURE 0xCCCCCCCC
|
||||
#define LANGUAGE_CLASSIFIER_SPARSE_SIGNATURE 0xC0C0C0C0
|
||||
|
||||
#define LANGUAGE_CLASSIFIER_SETUP_ERROR "language_classifier not loaded, run libpostal_setup_language_classifier()\n"
|
||||
|
||||
@@ -27,8 +28,10 @@ void language_classifier_destroy(language_classifier_t *self) {
|
||||
cstring_array_destroy(self->labels);
|
||||
}
|
||||
|
||||
if (self->weights != NULL) {
|
||||
double_matrix_destroy(self->weights);
|
||||
if (self->weights_type == MATRIX_DENSE && self->weights.dense != NULL) {
|
||||
double_matrix_destroy(self->weights.dense);
|
||||
} else if (self->weights_type == MATRIX_SPARSE && self->weights.sparse != NULL) {
|
||||
sparse_matrix_destroy(self->weights.sparse);
|
||||
}
|
||||
|
||||
free(self);
|
||||
@@ -86,7 +89,14 @@ language_classifier_response_t *classify_languages(char *address) {
|
||||
double_matrix_t *p_y = double_matrix_new_zeros(1, n);
|
||||
|
||||
language_classifier_response_t *response = NULL;
|
||||
if (logistic_regression_model_expectation(classifier->weights, x, p_y)) {
|
||||
bool model_exp = false;
|
||||
if (classifier->weights_type == MATRIX_DENSE) {
|
||||
model_exp = logistic_regression_model_expectation(classifier->weights.dense, x, p_y);
|
||||
} else if (classifier->weights_type == MATRIX_SPARSE) {
|
||||
model_exp = logistic_regression_model_expectation_sparse(classifier->weights.sparse, x, p_y);
|
||||
}
|
||||
|
||||
if (model_exp) {
|
||||
double *predictions = double_matrix_get_row(p_y, 0);
|
||||
size_t *indices = double_array_argsort(predictions, n);
|
||||
size_t num_languages = 0;
|
||||
@@ -145,7 +155,11 @@ language_classifier_t *language_classifier_read(FILE *f) {
|
||||
|
||||
uint32_t signature;
|
||||
|
||||
if (!file_read_uint32(f, &signature) || signature != LANGUAGE_CLASSIFIER_SIGNATURE) {
|
||||
if (!file_read_uint32(f, &signature)) {
|
||||
goto exit_file_read;
|
||||
}
|
||||
|
||||
if (signature != LANGUAGE_CLASSIFIER_SIGNATURE && signature != LANGUAGE_CLASSIFIER_SPARSE_SIGNATURE) {
|
||||
goto exit_file_read;
|
||||
}
|
||||
|
||||
@@ -190,13 +204,21 @@ language_classifier_t *language_classifier_read(FILE *f) {
|
||||
}
|
||||
classifier->num_labels = cstring_array_num_strings(classifier->labels);
|
||||
|
||||
if (signature == LANGUAGE_CLASSIFIER_SIGNATURE) {
|
||||
double_matrix_t *weights = double_matrix_read(f);
|
||||
|
||||
if (weights == NULL) {
|
||||
goto exit_classifier_created;
|
||||
}
|
||||
|
||||
classifier->weights = weights;
|
||||
classifier->weights_type = MATRIX_DENSE;
|
||||
classifier->weights.dense = weights;
|
||||
} else if (signature == LANGUAGE_CLASSIFIER_SPARSE_SIGNATURE) {
|
||||
sparse_matrix_t *sparse_weights = sparse_matrix_read(f);
|
||||
if (sparse_weights == NULL) {
|
||||
goto exit_classifier_created;
|
||||
}
|
||||
classifier->weights_type = MATRIX_SPARSE;
|
||||
classifier->weights.sparse = sparse_weights;
|
||||
}
|
||||
|
||||
return classifier;
|
||||
|
||||
@@ -223,12 +245,22 @@ language_classifier_t *language_classifier_load(char *path) {
|
||||
bool language_classifier_write(language_classifier_t *self, FILE *f) {
|
||||
if (f == NULL || self == NULL) return false;
|
||||
|
||||
if (!file_write_uint32(f, LANGUAGE_CLASSIFIER_SIGNATURE) ||
|
||||
!trie_write(self->features, f) ||
|
||||
if (self->weights_type == MATRIX_DENSE && !file_write_uint32(f, LANGUAGE_CLASSIFIER_SIGNATURE)) {
|
||||
return false;
|
||||
} else if (self->weights_type == MATRIX_SPARSE && !file_write_uint32(f, LANGUAGE_CLASSIFIER_SPARSE_SIGNATURE)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!trie_write(self->features, f) ||
|
||||
!file_write_uint64(f, self->num_features) ||
|
||||
!file_write_uint64(f, self->labels->str->n) ||
|
||||
!file_write_chars(f, (const char *)self->labels->str->a, self->labels->str->n) ||
|
||||
!double_matrix_write(self->weights, f)) {
|
||||
!file_write_chars(f, (const char *)self->labels->str->a, self->labels->str->n)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (self->weights_type == MATRIX_DENSE && !double_matrix_write(self->weights.dense, f)) {
|
||||
return false;
|
||||
} else if (self->weights_type == MATRIX_SPARSE && !sparse_matrix_write(self->weights.sparse, f)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
@@ -9,19 +9,24 @@
|
||||
#include "collections.h"
|
||||
#include "language_features.h"
|
||||
#include "logistic_regression.h"
|
||||
#include "matrix.h"
|
||||
#include "tokens.h"
|
||||
#include "sparse_matrix.h"
|
||||
#include "string_utils.h"
|
||||
#include "trie.h"
|
||||
|
||||
#define LANGUAGE_CLASSIFIER_FILENAME "language_classifier.dat"
|
||||
#define LANGUAGE_CLASSIFIER_COUNTRY_FILENAME "language_classifier_country.dat"
|
||||
|
||||
typedef struct language_classifier {
|
||||
size_t num_labels;
|
||||
size_t num_features;
|
||||
trie_t *features;
|
||||
cstring_array *labels;
|
||||
double_matrix_t *weights;
|
||||
matrix_type_t weights_type;
|
||||
union {
|
||||
double_matrix_t *dense;
|
||||
sparse_matrix_t *sparse;
|
||||
} weights;
|
||||
} language_classifier_t;
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user