[math] Generic dense matrix implementation using BLAS calls for matrix-matrix multiplication if available
This commit is contained in:
@@ -13,7 +13,7 @@
|
||||
#include "sparse_matrix.h"
|
||||
#include "sparse_matrix_utils.h"
|
||||
|
||||
#define LANGUAGE_CLASSIFIER_FEATURE_COUNT_THRESHOLD 1.0
|
||||
#define LANGUAGE_CLASSIFIER_FEATURE_COUNT_THRESHOLD 5.0
|
||||
#define LANGUAGE_CLASSIFIER_LABEL_COUNT_THRESHOLD 100
|
||||
|
||||
#define LOG_BATCH_INTERVAL 10
|
||||
@@ -134,13 +134,13 @@ double compute_cv_accuracy(logistic_regression_trainer_t *trainer, char *filenam
|
||||
uint32_t correct = 0;
|
||||
uint32_t total = 0;
|
||||
|
||||
matrix_t *p_y = matrix_new_zeros(LANGUAGE_CLASSIFIER_DEFAULT_BATCH_SIZE, trainer->num_labels);
|
||||
double_matrix_t *p_y = double_matrix_new_zeros(LANGUAGE_CLASSIFIER_DEFAULT_BATCH_SIZE, trainer->num_labels);
|
||||
|
||||
while ((minibatch = language_classifier_data_set_get_minibatch(data_set, trainer->label_ids)) != NULL) {
|
||||
sparse_matrix_t *x = feature_matrix(trainer->feature_ids, minibatch->features);
|
||||
uint32_array *y = label_vector(trainer->label_ids, minibatch->labels);
|
||||
|
||||
matrix_resize(p_y, x->m, trainer->num_labels);
|
||||
double_matrix_resize(p_y, x->m, trainer->num_labels);
|
||||
|
||||
if (!logistic_regression_model_expectation(trainer->weights, x, p_y)) {
|
||||
log_error("Predict cv batch failed\n");
|
||||
@@ -149,7 +149,7 @@ double compute_cv_accuracy(logistic_regression_trainer_t *trainer, char *filenam
|
||||
|
||||
double *row;
|
||||
for (size_t i = 0; i < p_y->m; i++) {
|
||||
row = matrix_get_row(p_y, i);
|
||||
row = double_matrix_get_row(p_y, i);
|
||||
|
||||
int64_t predicted = double_array_argmax(row, p_y->n);
|
||||
if (predicted < 0) {
|
||||
@@ -171,7 +171,7 @@ double compute_cv_accuracy(logistic_regression_trainer_t *trainer, char *filenam
|
||||
}
|
||||
|
||||
language_classifier_data_set_destroy(data_set);
|
||||
matrix_destroy(p_y);
|
||||
double_matrix_destroy(p_y);
|
||||
|
||||
double accuracy = (double)correct / total;
|
||||
return accuracy;
|
||||
|
||||
Reference in New Issue
Block a user