[parser] adding polymorphic (as much as C does polymorphism) model type for the parser to allow it to handle either the greedy averaged perceptron or a CRF. During training, saving, and loading, we use a different filename for a parser trained with a CRF, which is still backward-compatible with models previously trained in parser-data. Making necessary modifications to address_parser.c, address_parser_train.c, and address_parser_test.c. Also adding an option in address_parser_test to print individual errors in addition to the confusion matrix.
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
#include "address_parser.h"
|
||||
#include "address_parser_io.h"
|
||||
#include "address_dictionary.h"
|
||||
#include "averaged_perceptron_trainer.h"
|
||||
#include "collections.h"
|
||||
#include "constants.h"
|
||||
#include "file_utils.h"
|
||||
@@ -10,8 +9,6 @@
|
||||
|
||||
#include "log/log.h"
|
||||
|
||||
//#define ADDRESS_PARSER_TEST_PRINT_ERRORS
|
||||
|
||||
typedef struct address_parser_test_results {
|
||||
size_t num_errors;
|
||||
size_t num_predictions;
|
||||
@@ -21,19 +18,44 @@ typedef struct address_parser_test_results {
|
||||
} address_parser_test_results_t;
|
||||
|
||||
|
||||
uint32_t get_class_index(address_parser_t *parser, char *name) {
|
||||
static uint32_t address_parser_num_classes(address_parser_t *parser) {
|
||||
if (parser->model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON) {
|
||||
averaged_perceptron_t *ap = parser->model.ap;
|
||||
return parser->model.ap->num_classes;
|
||||
} else if (parser->model_type == ADDRESS_PARSER_TYPE_CRF) {
|
||||
return parser->model.crf->num_classes;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
static cstring_array *address_parser_class_strings(address_parser_t *parser) {
|
||||
if (parser->model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON) {
|
||||
cstring_array *classes = parser->model.ap->classes;
|
||||
return parser->model.ap->classes;
|
||||
} else if (parser->model_type == ADDRESS_PARSER_TYPE_CRF) {
|
||||
return parser->model.crf->classes;
|
||||
}
|
||||
return NULL;
|
||||
}
|
||||
|
||||
static uint32_t address_parser_get_class_index(address_parser_t *parser, char *name) {
|
||||
uint32_t i;
|
||||
char *str;
|
||||
|
||||
cstring_array_foreach(parser->model->classes, i, str, {
|
||||
if (strcmp(name, str) == 0) {
|
||||
return i;
|
||||
}
|
||||
})
|
||||
cstring_array *classes = address_parser_class_strings(parser);
|
||||
uint32_t num_classes = address_parser_num_classes(parser);
|
||||
if (classes != NULL) {
|
||||
cstring_array_foreach(classes, i, str, {
|
||||
if (strcmp(name, str) == 0) {
|
||||
return i;
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
return parser->model->num_classes;
|
||||
return num_classes;
|
||||
}
|
||||
|
||||
|
||||
#define EMPTY_ADDRESS_PARSER_TEST_RESULT (address_parser_test_results_t){0, 0, 0, 0, NULL}
|
||||
|
||||
bool address_parser_test(address_parser_t *parser, char *filename, address_parser_test_results_t *result, bool print_errors) {
|
||||
@@ -42,7 +64,7 @@ bool address_parser_test(address_parser_t *parser, char *filename, address_parse
|
||||
return NULL;
|
||||
}
|
||||
|
||||
uint32_t num_classes = parser->model->num_classes;
|
||||
uint32_t num_classes = address_parser_num_classes(parser);
|
||||
|
||||
result->confusion = calloc(num_classes * num_classes, sizeof(uint32_t));
|
||||
|
||||
@@ -80,7 +102,9 @@ bool address_parser_test(address_parser_t *parser, char *filename, address_parse
|
||||
|
||||
size_t starting_errors = result->num_errors;
|
||||
|
||||
if (averaged_perceptron_tagger_predict(parser->model, parser, context, context->features, context->prev_tag_features, context->prev2_tag_features, token_labels, &address_parser_features, data_set->tokenized_str, false)) {
|
||||
bool prediction_success = address_parser_predict(parser, context, token_labels, &address_parser_features, data_set->tokenized_str);
|
||||
|
||||
if (prediction_success) {
|
||||
uint32_t i;
|
||||
char *predicted;
|
||||
cstring_array_foreach(token_labels, i, predicted, {
|
||||
@@ -89,8 +113,8 @@ bool address_parser_test(address_parser_t *parser, char *filename, address_parse
|
||||
if (strcmp(predicted, truth) != 0) {
|
||||
result->num_errors++;
|
||||
|
||||
uint32_t predicted_index = get_class_index(parser, predicted);
|
||||
uint32_t truth_index = get_class_index(parser, truth);
|
||||
uint32_t predicted_index = address_parser_get_class_index(parser, predicted);
|
||||
uint32_t truth_index = address_parser_get_class_index(parser, truth);
|
||||
|
||||
result->confusion[predicted_index * num_classes + truth_index]++;
|
||||
|
||||
@@ -102,6 +126,10 @@ bool address_parser_test(address_parser_t *parser, char *filename, address_parse
|
||||
|
||||
})
|
||||
|
||||
} else {
|
||||
log_error("Error in prediction\n");
|
||||
tokenized_string_destroy(data_set->tokenized_str);
|
||||
break;
|
||||
}
|
||||
|
||||
if (result->num_errors > starting_errors) {
|
||||
@@ -184,6 +212,12 @@ int main(int argc, char **argv) {
|
||||
|
||||
address_parser_t *parser = get_address_parser();
|
||||
|
||||
if (parser->model_type == ADDRESS_PARSER_TYPE_GREEDY_AVERAGED_PERCEPTRON) {
|
||||
printf("averaged perceptron parser\n");
|
||||
} else if (parser->model_type == ADDRESS_PARSER_TYPE_CRF) {
|
||||
printf("crf parser\n");
|
||||
}
|
||||
|
||||
address_parser_test_results_t results = EMPTY_ADDRESS_PARSER_TEST_RESULT;
|
||||
|
||||
if (!address_parser_test(parser, filename, &results, print_errors)) {
|
||||
@@ -196,7 +230,7 @@ int main(int argc, char **argv) {
|
||||
|
||||
|
||||
printf("Confusion matrix:\n\n");
|
||||
uint32_t num_classes = parser->model->num_classes;
|
||||
uint32_t num_classes = address_parser_num_classes(parser);
|
||||
|
||||
size_t *confusion_sorted = uint32_array_argsort(results.confusion, num_classes * num_classes);
|
||||
|
||||
@@ -211,8 +245,9 @@ int main(int argc, char **argv) {
|
||||
if (i == j) continue;
|
||||
|
||||
if (class_errors > 0) {
|
||||
char *predicted = cstring_array_get_string(parser->model->classes, i);
|
||||
char *truth = cstring_array_get_string(parser->model->classes, j);
|
||||
cstring_array *classes = address_parser_class_strings(parser);
|
||||
char *predicted = cstring_array_get_string(classes, i);
|
||||
char *truth = cstring_array_get_string(classes, j);
|
||||
|
||||
printf("(%s, %s): %d\n", predicted, truth, class_errors);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user