[math] Matrix method updates

This commit is contained in:
Al
2015-12-08 15:39:52 -05:00
parent 48ee665e71
commit d9d53ce17e
2 changed files with 38 additions and 7 deletions

View File

@@ -40,13 +40,24 @@ bool matrix_resize(matrix_t *self, size_t m, size_t n) {
return true;
}
matrix_t *matrix_copy(matrix_t *self) {
matrix_t *matrix_new_copy(matrix_t *self) {
matrix_t *cpy = matrix_new(self->m, self->n);
size_t num_values = self->m * self->n;
memcpy(cpy->values, self->values, num_values * sizeof(double));
double_array_copy(cpy->values, self->values, num_values);
return cpy;
}
bool matrix_copy(matrix_t *self, matrix_t *other) {
if (self->m != other->m || self->n != other->n) {
return false;
}
size_t num_values = self->n * self->n;
double_array_copy(other->values, self->values, num_values);
return true;
}
inline void matrix_init_values(matrix_t *self, double *values) {
size_t num_values = self->m * self->n;
memcpy(self->values, values, num_values * sizeof(double));
@@ -78,6 +89,12 @@ inline double matrix_get(matrix_t *self, size_t row_index, size_t col_index) {
return self->values->a[index];
}
inline double *matrix_get_row(matrix_t *self, size_t row_index) {
size_t index = row_index * self->n;
return self->values->a + index;
}
inline matrix_t *matrix_new_value(size_t m, size_t n, double value) {
matrix_t *matrix = matrix_new(m, n);
matrix_set(matrix, value);
@@ -119,6 +136,14 @@ inline void matrix_sub(matrix_t *self, double value) {
}
inline void matrix_log(matrix_t *self) {
double_array_log(self->values, self->m * self->n);
}
inline void matrix_exp(matrix_t *self) {
double_array_exp(self->values, self->m * self->n);
}
void matrix_dot_vector(matrix_t *self, double *vec, double *result) {
double *values = self->values;
size_t n = self->n;
@@ -129,9 +154,9 @@ void matrix_dot_vector(matrix_t *self, double *vec, double *result) {
}
}
int matrix_dot_matrix(matrix_t *m1, matrix_t *m2, matrix_t *result) {
bool matrix_dot_matrix(matrix_t *m1, matrix_t *m2, matrix_t *result) {
if (m1->n != m2->m || m1->m != result->m || m2->n != result->n) {
return -1;
return false;
}
size_t m1_rows = m1->m;
@@ -153,7 +178,7 @@ int matrix_dot_matrix(matrix_t *m1, matrix_t *m2, matrix_t *result) {
}
}
return 0;
return true;
}
matrix_t *matrix_read(FILE *f) {

View File

@@ -20,7 +20,8 @@ matrix_t *matrix_new_ones(size_t m, size_t n);
bool matrix_resize(matrix_t *self, size_t m, size_t n);
matrix_t *matrix_copy(matrix_t *self);
matrix_t *matrix_new_copy(matrix_t *self);
bool matrix_copy(matrix_t *self, matrix_t *other);
void matrix_init_values(matrix_t *self, double *values);
void matrix_set(matrix_t *self, double value);
@@ -29,13 +30,18 @@ void matrix_set_row(matrix_t *self, size_t index, double *row);
void matrix_set_scalar(matrix_t *self, size_t row_index, size_t col_index, double value);
double matrix_get(matrix_t *self, size_t row_index, size_t col_index);
double *matrix_get_row(matrix_t *self, size_t row_index);
void matrix_add(matrix_t *self, double value);
void matrix_sub(matrix_t *self, double value);
void matrix_mul(matrix_t *self, double value);
void matrix_div(matrix_t *self, double value);
void matrix_log(matrix_t *self);
void matrix_exp(matrix_t *self);
void matrix_dot_vector(matrix_t *self, double *vec, double *result);
int matrix_dot_matrix(matrix_t *m1, matrix_t *m2, matrix_t *result);
bool matrix_dot_matrix(matrix_t *m1, matrix_t *m2, matrix_t *result);
matrix_t *matrix_read(FILE *f);
bool matrix_write(matrix_t *self, FILE *f);