Skip to content

Commit 110495e

Browse files
committed
feat: add pow method
1 parent e5be9c6 commit 110495e

File tree

3 files changed

+81
-2
lines changed

3 files changed

+81
-2
lines changed

src/matrix.rs

+20-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::ops::{Add, Sub, Mul, Div, Deref, Index, IndexMut, DerefMut};
22

3-
use num::Num;
3+
use num::{Num, Float};
44
use crate::error::ShapeError;
55
use crate::shape;
66
use crate::coord;
@@ -90,6 +90,12 @@ impl<T: Num + PartialOrd + Copy> DynamicMatrix<T> {
9090
}
9191
}
9292

93+
impl<T: Float + PartialOrd + Copy> DynamicMatrix<T> {
94+
pub fn pow(&self, power: T) -> DynamicMatrix<T> {
95+
DynamicMatrix::from_tensor(self.tensor.pow(power)).unwrap()
96+
}
97+
}
98+
9399
// Scalar Addition
94100
impl<T: Num + PartialOrd + Copy> Add<T> for DynamicMatrix<T> {
95101
type Output = DynamicMatrix<T>;
@@ -606,4 +612,17 @@ mod tests {
606612
assert_eq!(matrix[coord![1, 0]], 3.0);
607613
assert_eq!(matrix[coord![1, 1]], 4.0);
608614
}
615+
616+
#[test]
617+
fn test_pow_matrix() {
618+
let shape = shape![2, 2].unwrap();
619+
let data = vec![2.0, 3.0, 4.0, 5.0];
620+
let matrix = DynamicMatrix::new(&shape, &data).unwrap();
621+
let result = matrix.pow(2.0);
622+
assert_eq!(result[coord![0, 0]], 4.0);
623+
assert_eq!(result[coord![0, 1]], 9.0);
624+
assert_eq!(result[coord![1, 0]], 16.0);
625+
assert_eq!(result[coord![1, 1]], 25.0);
626+
assert_eq!(result.shape(), &shape);
627+
}
609628
}

src/tensor.rs

+41-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use num::Num;
1+
use num::{Num, Float};
22
use std::ops::{Add, Sub, Mul, Div};
33

44
use crate::shape;
@@ -253,6 +253,16 @@ impl<T: Num + PartialOrd + Copy> Tensor<T> {
253253
}
254254
}
255255

256+
impl<T: Float + PartialOrd + Copy> Tensor<T> {
257+
pub fn pow(&self, power: T) -> Tensor<T> {
258+
let mut result = Tensor::zeros(&self.shape);
259+
for i in 0..self.size() {
260+
result.data[i] = self.data[i].clone().powf(power);
261+
}
262+
result
263+
}
264+
}
265+
256266
// Element-wise Multiplication
257267
impl<T: Num + PartialOrd + Copy> Mul<T> for Tensor<T> {
258268
type Output = Tensor<T>;
@@ -1205,5 +1215,35 @@ mod tests {
12051215
let display = tensor.display();
12061216
assert_eq!(display, "[[[[1, 2],\n [3, 4]],\n\n [[5, 6],\n [7, 8]]],\n\n\n [[[9, 10],\n [11, 12]],\n\n [[13, 14],\n [15, 16]]]]");
12071217
}
1218+
1219+
#[test]
1220+
fn test_pow_tensor_square() {
1221+
let shape = shape![2, 2].unwrap();
1222+
let data = vec![1.0, 2.0, 3.0, 4.0];
1223+
let tensor = Tensor::new(&shape, &data).unwrap();
1224+
let result = tensor.pow(2.0);
1225+
assert_eq!(result.shape(), &shape);
1226+
assert_eq!(result.data, DynamicStorage::new(vec![1.0, 4.0, 9.0, 16.0]));
1227+
}
1228+
1229+
#[test]
1230+
fn test_pow_tensor_sqrt() {
1231+
let shape = shape![2, 2].unwrap();
1232+
let data = vec![1.0, 4.0, 9.0, 16.0];
1233+
let tensor = Tensor::new(&shape, &data).unwrap();
1234+
let result = tensor.pow(0.5);
1235+
assert_eq!(result.shape(), &shape);
1236+
assert_eq!(result.data, DynamicStorage::new(vec![1.0, 2.0, 3.0, 4.0]));
1237+
}
1238+
1239+
#[test]
1240+
fn test_pow_tensor_negative_exponent() {
1241+
let shape = shape![2, 2].unwrap();
1242+
let data = vec![1.0, 2.0, 4.0, 8.0];
1243+
let tensor = Tensor::new(&shape, &data).unwrap();
1244+
let result = tensor.pow(-1.0);
1245+
assert_eq!(result.shape(), &shape);
1246+
assert_eq!(result.data, DynamicStorage::new(vec![1.0, 0.5, 0.25, 0.125]));
1247+
}
12081248
}
12091249

src/vector.rs

+20
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use std::ops::{Add, Sub, Mul, Div, Deref, Index, IndexMut, DerefMut};
22

33
use num::Num;
4+
use num::Float;
45
use crate::error::ShapeError;
56
use crate::shape;
67
use crate::coord;
@@ -84,6 +85,12 @@ impl<T: Num + PartialOrd + Copy> DynamicVector<T> {
8485
}
8586
}
8687

88+
impl<T: Float + PartialOrd + Copy> DynamicVector<T> {
89+
pub fn pow(&self, power: T) -> DynamicVector<T> {
90+
DynamicVector::from_tensor(self.tensor.pow(power)).unwrap()
91+
}
92+
}
93+
8794
// Scalar Addition
8895
impl<T: Num + PartialOrd + Copy> Add<T> for DynamicVector<T> {
8996
type Output = DynamicVector<T>;
@@ -585,4 +592,17 @@ mod tests {
585592
assert_eq!(result[3], 2.0);
586593
assert_eq!(result.shape(), &shape);
587594
}
595+
596+
#[test]
597+
fn test_pow_vector() {
598+
let shape = shape![4].unwrap();
599+
let data = vec![2.0, 3.0, 4.0, 5.0];
600+
let vector = DynamicVector::new(&data).unwrap();
601+
let result = vector.pow(2.0);
602+
assert_eq!(result[0], 4.0);
603+
assert_eq!(result[1], 9.0);
604+
assert_eq!(result[2], 16.0);
605+
assert_eq!(result[3], 25.0);
606+
assert_eq!(result.shape(), &shape);
607+
}
588608
}

0 commit comments

Comments
 (0)