Skip to content

Commit 87c2d59

Browse files
authored
Merge pull request #2 from Rust-Scientific-Computing/1-implement-tensor
Implement Tensor struct
2 parents 15047f2 + 0df55fe commit 87c2d59

11 files changed

+3174
-15
lines changed

Cargo.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@ version = "0.1.0"
44
edition = "2021"
55
description = "A tensor library for scientific computing in Rust"
66
license = "MIT"
7-
license-file = "LICENSE"
87
homepage = "https://github.com/Rust-Scientific-Computing/feotensor"
98
repository = "https://github.com/Rust-Scientific-Computing/feotensor"
109

1110
[dependencies]
11+
itertools = "0.13.0"
12+
num = "0.4.3"

src/axes.rs

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
pub type Axes = Vec<usize>;

src/coordinate.rs

+131
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
use std::fmt;
2+
use std::ops::{Index, IndexMut};
3+
4+
use crate::error::ShapeError;
5+
6+
#[derive(Debug, Clone, PartialEq)]
7+
pub struct Coordinate {
8+
indices: Vec<usize>,
9+
}
10+
11+
impl Coordinate {
12+
pub fn new(indices: Vec<usize>) -> Result<Self, ShapeError> {
13+
if indices.is_empty() {
14+
return Err(ShapeError::new("Coordinate cannot be empty"));
15+
}
16+
Ok(Self { indices })
17+
}
18+
19+
pub fn order(&self) -> usize {
20+
self.indices.len()
21+
}
22+
23+
pub fn iter(&self) -> std::slice::Iter<'_, usize> {
24+
self.indices.iter()
25+
}
26+
27+
pub fn insert(&self, index: usize, axis: usize) -> Self {
28+
let mut new_indices = self.indices.clone();
29+
new_indices.insert(index, axis);
30+
Self {
31+
indices: new_indices,
32+
}
33+
}
34+
}
35+
36+
impl Index<usize> for Coordinate {
37+
type Output = usize;
38+
39+
fn index(&self, index: usize) -> &Self::Output {
40+
&self.indices[index]
41+
}
42+
}
43+
44+
impl IndexMut<usize> for Coordinate {
45+
fn index_mut(&mut self, index: usize) -> &mut Self::Output {
46+
&mut self.indices[index]
47+
}
48+
}
49+
50+
impl fmt::Display for Coordinate {
51+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
52+
use itertools::Itertools;
53+
let idxs = self.indices.iter().map(|&x| format!("{}", x)).join(", ");
54+
write!(f, "({})", idxs)
55+
}
56+
}
57+
58+
#[macro_export]
59+
macro_rules! coord {
60+
($($index:expr),*) => {
61+
{
62+
use $crate::coordinate::Coordinate;
63+
Coordinate::new(vec![$($index),*])
64+
}
65+
};
66+
67+
($index:expr; $count:expr) => {
68+
{
69+
use $crate::coordinate::Coordinate;
70+
Coordinate::new(vec![$index; $count])
71+
}
72+
};
73+
}
74+
75+
#[cfg(test)]
76+
mod tests {
77+
use super::*;
78+
79+
#[test]
80+
fn test_order() {
81+
let coord = coord![1, 2, 3].unwrap();
82+
assert_eq!(coord.order(), 3);
83+
}
84+
85+
#[test]
86+
fn test_iter() {
87+
let coord = coord![1, 2, 3].unwrap();
88+
let mut iter = coord.iter();
89+
assert_eq!(iter.next(), Some(&1));
90+
assert_eq!(iter.next(), Some(&2));
91+
assert_eq!(iter.next(), Some(&3));
92+
assert_eq!(iter.next(), None);
93+
}
94+
95+
#[test]
96+
fn test_insert() {
97+
let coord = coord![1, 2, 3].unwrap();
98+
let new_coord = coord.insert(1, 4);
99+
assert_eq!(new_coord, coord![1, 4, 2, 3].unwrap());
100+
}
101+
102+
#[test]
103+
fn test_index() {
104+
let coord = coord![1, 2, 3].unwrap();
105+
assert_eq!(coord[0], 1);
106+
assert_eq!(coord[1], 2);
107+
assert_eq!(coord[2], 3);
108+
}
109+
110+
#[test]
111+
fn test_index_mut() {
112+
let mut coord = coord![1, 2, 3].unwrap();
113+
coord[1] = 4;
114+
assert_eq!(coord[1], 4);
115+
}
116+
117+
#[test]
118+
fn test_display() {
119+
let coord = coord![1, 2, 3].unwrap();
120+
assert_eq!(format!("{}", coord), "(1, 2, 3)");
121+
}
122+
123+
#[test]
124+
fn test_coord_macro() {
125+
let coord = coord![1, 2, 3].unwrap();
126+
assert_eq!(coord, Coordinate::new(vec![1, 2, 3]).unwrap());
127+
128+
let coord_repeated = coord![1; 3].unwrap();
129+
assert_eq!(coord_repeated, Coordinate::new(vec![1, 1, 1]).unwrap());
130+
}
131+
}

src/error.rs

+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
use std::fmt;
2+
3+
#[derive(Debug, Clone)]
4+
pub struct ShapeError {
5+
reason: String,
6+
}
7+
8+
impl ShapeError {
9+
pub fn new(reason: &str) -> Self {
10+
ShapeError {
11+
reason: reason.to_string(),
12+
}
13+
}
14+
}
15+
16+
impl fmt::Display for ShapeError {
17+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
18+
write!(f, "ShapeError: {}", self.reason)
19+
}
20+
}

src/iter.rs

+88
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
use crate::coord;
2+
use crate::coordinate::Coordinate;
3+
use crate::shape::Shape;
4+
use std::cmp::max;
5+
6+
pub struct IndexIterator {
7+
shape: Shape,
8+
current: Coordinate,
9+
done: bool,
10+
}
11+
12+
impl IndexIterator {
13+
pub fn new(shape: &Shape) -> Self {
14+
// (shape.order() == 0) => `next` returns None before `current` is used
15+
let current = coord![0; max(shape.order(), 1)].unwrap();
16+
IndexIterator {
17+
shape: shape.clone(),
18+
current,
19+
done: false,
20+
}
21+
}
22+
}
23+
24+
impl Iterator for IndexIterator {
25+
type Item = Coordinate;
26+
27+
fn next(&mut self) -> Option<Self::Item> {
28+
if self.done || self.shape.order() == 0 {
29+
return None;
30+
}
31+
32+
let result = self.current.clone();
33+
34+
for i in (0..self.shape.order()).rev() {
35+
if self.current[i] + 1 < self.shape[i] {
36+
self.current[i] += 1;
37+
break;
38+
} else {
39+
self.current[i] = 0;
40+
if i == 0 {
41+
self.done = true;
42+
}
43+
}
44+
}
45+
46+
Some(result)
47+
}
48+
}
49+
50+
#[cfg(test)]
51+
mod tests {
52+
use super::*;
53+
use crate::shape;
54+
55+
#[test]
56+
fn test_index_iterator() {
57+
let shape = shape![2, 3].unwrap();
58+
let mut iter = IndexIterator::new(&shape);
59+
60+
assert_eq!(iter.next(), Some(coord![0, 0].unwrap()));
61+
assert_eq!(iter.next(), Some(coord![0, 1].unwrap()));
62+
assert_eq!(iter.next(), Some(coord![0, 2].unwrap()));
63+
assert_eq!(iter.next(), Some(coord![1, 0].unwrap()));
64+
assert_eq!(iter.next(), Some(coord![1, 1].unwrap()));
65+
assert_eq!(iter.next(), Some(coord![1, 2].unwrap()));
66+
assert_eq!(iter.next(), None);
67+
}
68+
69+
#[test]
70+
fn test_index_iterator_single_dimension() {
71+
let shape = shape![4].unwrap();
72+
let mut iter = IndexIterator::new(&shape);
73+
74+
assert_eq!(iter.next(), Some(coord![0].unwrap()));
75+
assert_eq!(iter.next(), Some(coord![1].unwrap()));
76+
assert_eq!(iter.next(), Some(coord![2].unwrap()));
77+
assert_eq!(iter.next(), Some(coord![3].unwrap()));
78+
assert_eq!(iter.next(), None);
79+
}
80+
81+
#[test]
82+
fn test_index_iterator_empty_tensor() {
83+
let shape = shape![].unwrap();
84+
let mut iter = IndexIterator::new(&shape);
85+
86+
assert_eq!(iter.next(), None);
87+
}
88+
}

src/lib.rs

+9-14
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,9 @@
1-
pub fn add(left: usize, right: usize) -> usize {
2-
left + right
3-
}
4-
5-
#[cfg(test)]
6-
mod tests {
7-
use super::*;
8-
9-
#[test]
10-
fn it_works() {
11-
let result = add(2, 2);
12-
assert_eq!(result, 4);
13-
}
14-
}
1+
pub mod axes;
2+
pub mod coordinate;
3+
pub mod error;
4+
pub mod iter;
5+
pub mod matrix;
6+
pub mod shape;
7+
pub mod storage;
8+
pub mod tensor;
9+
pub mod vector;

0 commit comments

Comments
 (0)