|
1 | | -use num::Num; |
| 1 | +use num::{Num, Float}; |
2 | 2 | use std::ops::{Add, Sub, Mul, Div}; |
3 | 3 |
|
4 | 4 | use crate::shape; |
@@ -253,6 +253,16 @@ impl<T: Num + PartialOrd + Copy> Tensor<T> { |
253 | 253 | } |
254 | 254 | } |
255 | 255 |
|
| 256 | +impl<T: Float + PartialOrd + Copy> Tensor<T> { |
| 257 | + pub fn pow(&self, power: T) -> Tensor<T> { |
| 258 | + let mut result = Tensor::zeros(&self.shape); |
| 259 | + for i in 0..self.size() { |
| 260 | + result.data[i] = self.data[i].clone().powf(power); |
| 261 | + } |
| 262 | + result |
| 263 | + } |
| 264 | +} |
| 265 | + |
256 | 266 | // Element-wise Multiplication |
257 | 267 | impl<T: Num + PartialOrd + Copy> Mul<T> for Tensor<T> { |
258 | 268 | type Output = Tensor<T>; |
@@ -1205,5 +1215,35 @@ mod tests { |
1205 | 1215 | let display = tensor.display(); |
1206 | 1216 | assert_eq!(display, "[[[[1, 2],\n [3, 4]],\n\n [[5, 6],\n [7, 8]]],\n\n\n [[[9, 10],\n [11, 12]],\n\n [[13, 14],\n [15, 16]]]]"); |
1207 | 1217 | } |
| 1218 | + |
| 1219 | + #[test] |
| 1220 | + fn test_pow_tensor_square() { |
| 1221 | + let shape = shape![2, 2].unwrap(); |
| 1222 | + let data = vec![1.0, 2.0, 3.0, 4.0]; |
| 1223 | + let tensor = Tensor::new(&shape, &data).unwrap(); |
| 1224 | + let result = tensor.pow(2.0); |
| 1225 | + assert_eq!(result.shape(), &shape); |
| 1226 | + assert_eq!(result.data, DynamicStorage::new(vec![1.0, 4.0, 9.0, 16.0])); |
| 1227 | + } |
| 1228 | + |
| 1229 | + #[test] |
| 1230 | + fn test_pow_tensor_sqrt() { |
| 1231 | + let shape = shape![2, 2].unwrap(); |
| 1232 | + let data = vec![1.0, 4.0, 9.0, 16.0]; |
| 1233 | + let tensor = Tensor::new(&shape, &data).unwrap(); |
| 1234 | + let result = tensor.pow(0.5); |
| 1235 | + assert_eq!(result.shape(), &shape); |
| 1236 | + assert_eq!(result.data, DynamicStorage::new(vec![1.0, 2.0, 3.0, 4.0])); |
| 1237 | + } |
| 1238 | + |
| 1239 | + #[test] |
| 1240 | + fn test_pow_tensor_negative_exponent() { |
| 1241 | + let shape = shape![2, 2].unwrap(); |
| 1242 | + let data = vec![1.0, 2.0, 4.0, 8.0]; |
| 1243 | + let tensor = Tensor::new(&shape, &data).unwrap(); |
| 1244 | + let result = tensor.pow(-1.0); |
| 1245 | + assert_eq!(result.shape(), &shape); |
| 1246 | + assert_eq!(result.data, DynamicStorage::new(vec![1.0, 0.5, 0.25, 0.125])); |
| 1247 | + } |
1208 | 1248 | } |
1209 | 1249 |
|
0 commit comments