[math] Adding sparse dot sparse given a dense output matrix (suitable for the minibatch use case), fixing sparse dot vector

This commit is contained in:
Al
2016-01-11 13:55:54 -05:00
parent 736bc7c70d
commit 3260edcf18
2 changed files with 53 additions and 2 deletions

View File

@@ -113,7 +113,7 @@ void sparse_matrix_sort_indices(sparse_matrix_t *self) {
inline int sparse_matrix_dot_vector(sparse_matrix_t *self, double *vec, size_t n, double *result) {
if (n != self->m) return -1;
if (n != self->n) {printf("self->n=%zu, n=%zu\n", self->n, n); return -1; }
uint32_t row, row_start, row_len;
double val;
@@ -130,7 +130,7 @@ inline int sparse_matrix_dot_vector(sparse_matrix_t *self, double *vec, size_t n
}
int sparse_matrix_rows_dot_vector(sparse_matrix_t *self, uint32_t *rows, size_t m, double *vec, size_t n, double *result) {
if (m != n) return -1;
if (n != self->n) return -1;
uint32_t *indptr = self->indptr->a;
uint32_t *indices = self->indices->a;
@@ -265,6 +265,56 @@ int sparse_matrix_dot_dense(sparse_matrix_t *self, matrix_t *matrix, matrix_t *r
return 0;
}
int sparse_matrix_dot_sparse(sparse_matrix_t *self, sparse_matrix_t *other, matrix_t *result) {
if (self->n != other->m || self->m != result->m || other->n != result->n) {
return -1;
}
uint32_t *indptr = self->indptr->a;
uint32_t *indices = self->indices->a;
double *data = self->data->a;
size_t m1_rows = self->m;
size_t m1_cols = self->n;
size_t m2_rows = other->m;
size_t m2_cols = other->n;
uint32_t *other_indptr = other->indptr->a;
uint32_t *other_indices = other->indices->a;
double *other_data = other->data->a;
double *result_values = result->values;
uint32_t row, row_start, row_len;
sparse_matrix_foreach_row(self, row, row_start, row_len, {
for (uint32_t i = row_start; i < row_start + row_len; i++) {
uint32_t col = indices[i];
if (col >= m2_rows) { return -1; }
uint32_t m2_row_start = other_indptr[col];
uint32_t m2_row_end = other_indptr[col + 1];
double m1_data = data[i];
for (uint32_t j = m2_row_start; j < m2_row_end; j++) {
uint32_t m2_col = other_indices[j];
size_t result_index = row * m2_cols + m2_col;
double m2_data = other_data[j];
result_values[result_index] += m1_data * m2_data;
}
}
})
return 0;
}
sparse_matrix_t *sparse_matrix_read(FILE *f) {
sparse_matrix_t *sp = malloc(sizeof(sparse_matrix_t));
if (sp == NULL) return NULL;