Skip to content

Commit 84e157e

Browse files
authored
Feature/tensor details (#97)
<img width="1165" alt="Screenshot 2024-09-25 at 12 33 01 PM" src="https://github.com/user-attachments/assets/339f3eba-7b1d-4bfd-a804-3573705c0538"> Adds further details and analysis to tensors in the list view.
2 parents e054ab5 + 0a7fdc8 commit 84e157e

File tree

10 files changed

+304
-170
lines changed

10 files changed

+304
-170
lines changed

src/components/BufferDetails.tsx

Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
//
3+
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
4+
5+
import { Icon, Intent } from '@blueprintjs/core';
6+
import { IconNames } from '@blueprintjs/icons';
7+
import classNames from 'classnames';
8+
import { Link } from 'react-router-dom';
9+
import { Tensor } from '../model/Graph';
10+
import { OperationDescription, TensorData } from '../model/APIData';
11+
import { toHex } from '../functions/math';
12+
import ROUTES from '../definitions/routes';
13+
import { useNextBuffer } from '../hooks/useAPI';
14+
import 'styles/components/BufferDetails.scss';
15+
16+
interface BufferDetailsProps {
17+
tensor: TensorData;
18+
operations: OperationDescription[];
19+
queryKey: string;
20+
className?: string;
21+
}
22+
23+
interface ShardSpec {
24+
grid: string;
25+
shape: [number, number];
26+
orientation: string;
27+
halo: number;
28+
}
29+
30+
const HEADER_LABELS = {
31+
shard_spec: 'ShardSpec',
32+
memory_layout: 'MemoryLayout',
33+
grid: 'CoreRangeSet',
34+
shape: 'Shape',
35+
orientation: 'ShardOrientation',
36+
halo: 'Halo',
37+
};
38+
39+
function BufferDetails({ tensor, operations, queryKey, className }: BufferDetailsProps) {
40+
const { address, consumers, dtype, layout, shape } = tensor;
41+
const lastOperationId: number = tensor.consumers[tensor.consumers.length - 1];
42+
const deallocationOperationId = getDeallocationOperation(tensor, operations);
43+
const { data: buffer, isLoading } = useNextBuffer(address, consumers, queryKey);
44+
45+
return (
46+
<>
47+
<table className='ttnn-table analysis-table'>
48+
<tbody>
49+
<tr>
50+
<th>Last used</th>
51+
<td>
52+
<Link to={`${ROUTES.OPERATIONS}/${lastOperationId}`}>
53+
{lastOperationId}{' '}
54+
{operations.find((operation) => operation.id === lastOperationId)?.name}
55+
</Link>
56+
</td>
57+
</tr>
58+
59+
<tr>
60+
<th>Deallocation</th>
61+
<td>
62+
{isLoading ? 'Loading...' : undefined}
63+
64+
{buffer && !isLoading && deallocationOperationId ? (
65+
<div>
66+
Deallocation found in{' '}
67+
<Link to={`${ROUTES.OPERATIONS}/${deallocationOperationId}`}>
68+
{deallocationOperationId}{' '}
69+
{operations.find((operation) => operation.id === deallocationOperationId)?.name}
70+
</Link>
71+
<Icon
72+
className='deallocation-icon'
73+
icon={IconNames.TICK}
74+
intent={Intent.SUCCESS}
75+
/>
76+
</div>
77+
) : (
78+
<>
79+
Missing deallocation operation
80+
<Icon
81+
className='deallocation-icon'
82+
icon={IconNames.WARNING_SIGN}
83+
intent={Intent.WARNING}
84+
/>
85+
</>
86+
)}
87+
</td>
88+
</tr>
89+
90+
{buffer?.next_usage && address && !isLoading ? (
91+
<tr>
92+
<th>Next allocation</th>
93+
<td>
94+
<span>
95+
{toHex(address)} next allocated in{' '}
96+
<Link to={`${ROUTES.OPERATIONS}/${buffer.operation_id}`}>
97+
{buffer.operation_id}{' '}
98+
{operations.find((operation) => operation.id === buffer.operation_id)?.name}
99+
</Link>{' '}
100+
(+{buffer.next_usage} operations)
101+
</span>
102+
</td>
103+
</tr>
104+
) : null}
105+
</tbody>
106+
</table>
107+
108+
<table className={classNames('ttnn-table two-tone-rows buffer-table', className)}>
109+
<tbody>
110+
<tr>
111+
<th>Device Id</th>
112+
<td>{tensor.device_id ?? 'n/a'}</td>
113+
</tr>
114+
115+
<tr>
116+
<th>DataType</th>
117+
<td>{dtype}</td>
118+
</tr>
119+
120+
<tr>
121+
<th>Layout</th>
122+
<td>{layout}</td>
123+
</tr>
124+
125+
{tensor?.memory_config
126+
? Object.entries(parseMemoryConfig(tensor.memory_config)).map(([key, value]) => (
127+
<tr key={key}>
128+
{key === 'shard_spec' && value && typeof value !== 'string' ? (
129+
<>
130+
<th>{getHeaderLabel(key)}</th>
131+
<td>
132+
<table className='ttnn-table alt-two-tone-rows'>
133+
<tbody>
134+
{Object.entries(value as ShardSpec).map(
135+
([innerKey, innerValue]) => (
136+
<tr key={innerKey}>
137+
<th>
138+
{getHeaderLabel(
139+
innerKey as keyof typeof HEADER_LABELS,
140+
)}
141+
</th>
142+
<td>{innerValue}</td>
143+
</tr>
144+
),
145+
)}
146+
</tbody>
147+
</table>
148+
</td>
149+
</>
150+
) : (
151+
<>
152+
<th>{getHeaderLabel(key as keyof typeof HEADER_LABELS)}</th>
153+
<td>{value as string}</td>
154+
</>
155+
)}
156+
</tr>
157+
))
158+
: null}
159+
160+
<tr>
161+
<th>Shape</th>
162+
<td>{shape}</td>
163+
</tr>
164+
</tbody>
165+
</table>
166+
</>
167+
);
168+
}
169+
170+
function getDeallocationOperation(tensor: Tensor, operations: OperationDescription[]): number | undefined {
171+
// TODO: Maybe we can strengthen this logic to ensure we're looking at deallocations rather than just checking the name
172+
const matchingInputs = operations.filter(
173+
(operation) =>
174+
operation.name.includes('deallocate') && operation.inputs.find((input) => input.id === tensor.id),
175+
);
176+
177+
return matchingInputs.map((x) => x.id)[0];
178+
}
179+
180+
function parseMemoryConfig(string: string) {
181+
const regex = /MemoryConfig\((.*)\)$/;
182+
const match = string.match(regex);
183+
184+
if (match) {
185+
const capturedString = match[1];
186+
187+
const memoryLayoutPattern = /memory_layout=([A-Za-z_:]+)/;
188+
const shardSpecPattern =
189+
/shard_spec=ShardSpec\(grid=\{(\[.*?\])\},shape=\{(\d+),\s*(\d+)\},orientation=ShardOrientation::([A-Z_]+),halo=(\d+)\)/;
190+
191+
const memoryLayoutMatch = capturedString.match(memoryLayoutPattern);
192+
const shardSpecMatch = capturedString.match(shardSpecPattern);
193+
194+
const memoryLayout = memoryLayoutMatch ? memoryLayoutMatch[1] : null;
195+
const shardSpec = shardSpecMatch
196+
? {
197+
grid: shardSpecMatch[1],
198+
shape: [parseInt(shardSpecMatch[2], 10), parseInt(shardSpecMatch[3], 10)],
199+
orientation: shardSpecMatch[4],
200+
halo: parseInt(shardSpecMatch[5], 10),
201+
}
202+
: null;
203+
204+
return {
205+
memory_layout: memoryLayout,
206+
shard_spec: shardSpec || 'std::nullopt',
207+
};
208+
}
209+
210+
return string;
211+
}
212+
213+
function getHeaderLabel(key: keyof typeof HEADER_LABELS) {
214+
return HEADER_LABELS[key];
215+
}
216+
217+
export default BufferDetails;

src/components/BufferTable.tsx

Lines changed: 0 additions & 113 deletions
This file was deleted.

src/components/DeviceOperations.tsx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ interface DeviceOperationsData {
77
function DeviceOperations({ deviceOperations }: DeviceOperationsData) {
88
return (
99
<div>
10-
<table className='arguments-table'>
10+
<table className='ttnn-table two-tone-rows arguments-table'>
1111
<caption>Device Operations</caption>
1212

1313
<tbody>

src/components/OperationArguments.tsx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ function OperationArguments({ operationIndex, operation, onCollapseTensor }: Ope
2929
};
3030

3131
return (
32-
<table className='arguments-table has-vertical-headings'>
32+
<table className='ttnn-table two-tone-rows arguments-table has-vertical-headings'>
3333
<caption>Arguments</caption>
3434

3535
<tbody>

src/components/TensorList.tsx

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,15 @@ import LoadingSpinner from './LoadingSpinner';
1515
import { useOperationsList, useTensors } from '../hooks/useAPI';
1616
import ROUTES from '../definitions/routes';
1717
import { Tensor } from '../model/Graph';
18-
import { OperationDescription } from '../model/APIData';
18+
import { OperationDescription, TensorData } from '../model/APIData';
1919
import { BufferType, BufferTypeLabel } from '../model/BufferType';
2020
import Collapsible from './Collapsible';
21-
import BufferTable from './BufferTable';
2221
import { expandedTensorsAtom } from '../store/app';
2322
import ListItem from './ListItem';
2423
import '@blueprintjs/select/lib/css/blueprint-select.css';
2524
import 'styles/components/ListView.scss';
2625
import 'styles/components/TensorList.scss';
26+
import BufferDetails from './BufferDetails';
2727

2828
const PLACEHOLDER_ARRAY_SIZE = 10;
2929
const OPERATION_EL_HEIGHT = 39; // Height in px of each list item
@@ -38,7 +38,7 @@ const TensorList = () => {
3838
const [shouldCollapseAll, setShouldCollapseAll] = useState(false);
3939
const [filterQuery, setFilterQuery] = useState('');
4040
const [memoryLeakCount, setMemoryLeakCount] = useState(0);
41-
const [filteredTensorList, setFilteredTensorList] = useState<Tensor[]>([]);
41+
const [filteredTensorList, setFilteredTensorList] = useState<TensorData[]>([]);
4242
const [hasScrolledFromTop, setHasScrolledFromTop] = useState(false);
4343
const [hasScrolledToBottom, setHasScrolledToBottom] = useState(false);
4444
const [filterMemoryLeaks, setFilterMemoryLeaks] = useState(false);
@@ -302,12 +302,13 @@ const TensorList = () => {
302302
) : undefined
303303
}
304304
>
305-
<BufferTable
306-
className='buffer-data'
307-
tensor={tensor}
308-
operations={operations}
309-
queryKey={virtualRow.index.toString()}
310-
/>
305+
<div className='arguments-wrapper'>
306+
<BufferDetails
307+
tensor={tensor}
308+
operations={operations}
309+
queryKey={virtualRow.index.toString()}
310+
/>
311+
</div>
311312
</Collapsible>
312313
</li>
313314
);

0 commit comments

Comments
 (0)