[language_classification] Simple accuracy-based test program for language classifier.
This commit is contained in:
78
src/language_classifier_test.c
Normal file
78
src/language_classifier_test.c
Normal file
@@ -0,0 +1,78 @@
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include "log/log.h"
|
||||
#include "address_dictionary.h"
|
||||
#include "language_classifier.h"
|
||||
#include "language_classifier_io.h"
|
||||
#include "string_utils.h"
|
||||
#include "trie_utils.h"
|
||||
|
||||
|
||||
double test_accuracy(char *filename) {
|
||||
language_classifier_data_set_t *data_set = language_classifier_data_set_init(filename);
|
||||
if (data_set == NULL) {
|
||||
log_error("Error creating data set\n");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
language_classifier_minibatch_t *minibatch;
|
||||
|
||||
uint32_t correct = 0;
|
||||
uint32_t total = 0;
|
||||
|
||||
language_classifier_t *classifier = get_language_classifier();
|
||||
trie_t *label_ids = trie_new_from_cstring_array(classifier->labels);
|
||||
|
||||
while (language_classifier_data_set_next(data_set)) {
|
||||
char *address = char_array_get_string(data_set->address);
|
||||
char *language = char_array_get_string(data_set->language);
|
||||
|
||||
uint32_t label_id;
|
||||
if (!trie_get_data(label_ids, language, &label_id)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
language_classifier_response_t *response = classify_languages(address);
|
||||
if (response == NULL || response->num_languages == 0) {
|
||||
printf("%s\tNULL\t%s\n", language, address);
|
||||
continue;
|
||||
}
|
||||
|
||||
char *top_lang = response->languages[0];
|
||||
|
||||
if (string_equals(top_lang, language)) {
|
||||
correct++;
|
||||
} else {
|
||||
printf("%s\t%s\t%s\n", language, top_lang, address);
|
||||
}
|
||||
|
||||
total++;
|
||||
|
||||
language_classifier_response_destroy(response);
|
||||
|
||||
}
|
||||
|
||||
log_info("total=%zu\n", total);
|
||||
|
||||
trie_destroy(label_ids);
|
||||
|
||||
return (double) correct / total;
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
if (argc < 3) {
|
||||
log_error("Usage: language_classifier_test dir filename\n");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
if (!language_classifier_module_setup(argv[1]) || !address_dictionary_module_setup(NULL)) {
|
||||
log_error("Error setting up classifier\n");
|
||||
}
|
||||
|
||||
double accuracy = test_accuracy(argv[2]);
|
||||
log_info("Done. Accuracy: %f\n", accuracy);
|
||||
}
|
||||
Reference in New Issue
Block a user