From d9d53ce17ed8528d7959fa798c02b5e2a25a5693 Mon Sep 17 00:00:00 2001 From: Al Date: Tue, 8 Dec 2015 15:39:52 -0500 Subject: [PATCH] [math] Matrix method updates --- src/matrix.c | 35 ++++++++++++++++++++++++++++++----- src/matrix.h | 10 ++++++++-- 2 files changed, 38 insertions(+), 7 deletions(-) diff --git a/src/matrix.c b/src/matrix.c index 928f5720..5814660c 100644 --- a/src/matrix.c +++ b/src/matrix.c @@ -40,13 +40,24 @@ bool matrix_resize(matrix_t *self, size_t m, size_t n) { return true; } -matrix_t *matrix_copy(matrix_t *self) { +matrix_t *matrix_new_copy(matrix_t *self) { matrix_t *cpy = matrix_new(self->m, self->n); size_t num_values = self->m * self->n; - memcpy(cpy->values, self->values, num_values * sizeof(double)); + double_array_copy(cpy->values, self->values, num_values); + return cpy; } +bool matrix_copy(matrix_t *self, matrix_t *other) { + if (self->m != other->m || self->n != other->n) { + return false; + } + size_t num_values = self->n * self->n; + + double_array_copy(other->values, self->values, num_values); + return true; +} + inline void matrix_init_values(matrix_t *self, double *values) { size_t num_values = self->m * self->n; memcpy(self->values, values, num_values * sizeof(double)); @@ -78,6 +89,12 @@ inline double matrix_get(matrix_t *self, size_t row_index, size_t col_index) { return self->values->a[index]; } +inline double *matrix_get_row(matrix_t *self, size_t row_index) { + size_t index = row_index * self->n; + return self->values->a + index; +} + + inline matrix_t *matrix_new_value(size_t m, size_t n, double value) { matrix_t *matrix = matrix_new(m, n); matrix_set(matrix, value); @@ -119,6 +136,14 @@ inline void matrix_sub(matrix_t *self, double value) { } +inline void matrix_log(matrix_t *self) { + double_array_log(self->values, self->m * self->n); +} + +inline void matrix_exp(matrix_t *self) { + double_array_exp(self->values, self->m * self->n); +} + void matrix_dot_vector(matrix_t *self, double *vec, double *result) { double *values = self->values; size_t n = self->n; @@ -129,9 +154,9 @@ void matrix_dot_vector(matrix_t *self, double *vec, double *result) { } } -int matrix_dot_matrix(matrix_t *m1, matrix_t *m2, matrix_t *result) { +bool matrix_dot_matrix(matrix_t *m1, matrix_t *m2, matrix_t *result) { if (m1->n != m2->m || m1->m != result->m || m2->n != result->n) { - return -1; + return false; } size_t m1_rows = m1->m; @@ -153,7 +178,7 @@ int matrix_dot_matrix(matrix_t *m1, matrix_t *m2, matrix_t *result) { } } - return 0; + return true; } matrix_t *matrix_read(FILE *f) { diff --git a/src/matrix.h b/src/matrix.h index eb622cbf..6fa79507 100644 --- a/src/matrix.h +++ b/src/matrix.h @@ -20,7 +20,8 @@ matrix_t *matrix_new_ones(size_t m, size_t n); bool matrix_resize(matrix_t *self, size_t m, size_t n); -matrix_t *matrix_copy(matrix_t *self); +matrix_t *matrix_new_copy(matrix_t *self); +bool matrix_copy(matrix_t *self, matrix_t *other); void matrix_init_values(matrix_t *self, double *values); void matrix_set(matrix_t *self, double value); @@ -29,13 +30,18 @@ void matrix_set_row(matrix_t *self, size_t index, double *row); void matrix_set_scalar(matrix_t *self, size_t row_index, size_t col_index, double value); double matrix_get(matrix_t *self, size_t row_index, size_t col_index); +double *matrix_get_row(matrix_t *self, size_t row_index); void matrix_add(matrix_t *self, double value); void matrix_sub(matrix_t *self, double value); void matrix_mul(matrix_t *self, double value); void matrix_div(matrix_t *self, double value); + +void matrix_log(matrix_t *self); +void matrix_exp(matrix_t *self); + void matrix_dot_vector(matrix_t *self, double *vec, double *result); -int matrix_dot_matrix(matrix_t *m1, matrix_t *m2, matrix_t *result); +bool matrix_dot_matrix(matrix_t *m1, matrix_t *m2, matrix_t *result); matrix_t *matrix_read(FILE *f); bool matrix_write(matrix_t *self, FILE *f);