|
| 1 | +// SPDX-License-Identifier: Apache-2.0 |
| 2 | +// |
| 3 | +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. |
| 4 | + |
| 5 | +import { Icon, Intent } from '@blueprintjs/core'; |
| 6 | +import { IconNames } from '@blueprintjs/icons'; |
| 7 | +import classNames from 'classnames'; |
| 8 | +import { Link } from 'react-router-dom'; |
| 9 | +import { Tensor } from '../model/Graph'; |
| 10 | +import { OperationDescription, TensorData } from '../model/APIData'; |
| 11 | +import { toHex } from '../functions/math'; |
| 12 | +import ROUTES from '../definitions/routes'; |
| 13 | +import { useNextBuffer } from '../hooks/useAPI'; |
| 14 | +import 'styles/components/BufferDetails.scss'; |
| 15 | + |
| 16 | +interface BufferDetailsProps { |
| 17 | + tensor: TensorData; |
| 18 | + operations: OperationDescription[]; |
| 19 | + queryKey: string; |
| 20 | + className?: string; |
| 21 | +} |
| 22 | + |
| 23 | +interface ShardSpec { |
| 24 | + grid: string; |
| 25 | + shape: [number, number]; |
| 26 | + orientation: string; |
| 27 | + halo: number; |
| 28 | +} |
| 29 | + |
| 30 | +const HEADER_LABELS = { |
| 31 | + shard_spec: 'ShardSpec', |
| 32 | + memory_layout: 'MemoryLayout', |
| 33 | + grid: 'CoreRangeSet', |
| 34 | + shape: 'Shape', |
| 35 | + orientation: 'ShardOrientation', |
| 36 | + halo: 'Halo', |
| 37 | +}; |
| 38 | + |
| 39 | +function BufferDetails({ tensor, operations, queryKey, className }: BufferDetailsProps) { |
| 40 | + const { address, consumers, dtype, layout, shape } = tensor; |
| 41 | + const lastOperationId: number = tensor.consumers[tensor.consumers.length - 1]; |
| 42 | + const deallocationOperationId = getDeallocationOperation(tensor, operations); |
| 43 | + const { data: buffer, isLoading } = useNextBuffer(address, consumers, queryKey); |
| 44 | + |
| 45 | + return ( |
| 46 | + <> |
| 47 | + <table className='ttnn-table analysis-table'> |
| 48 | + <tbody> |
| 49 | + <tr> |
| 50 | + <th>Last used</th> |
| 51 | + <td> |
| 52 | + <Link to={`${ROUTES.OPERATIONS}/${lastOperationId}`}> |
| 53 | + {lastOperationId}{' '} |
| 54 | + {operations.find((operation) => operation.id === lastOperationId)?.name} |
| 55 | + </Link> |
| 56 | + </td> |
| 57 | + </tr> |
| 58 | + |
| 59 | + <tr> |
| 60 | + <th>Deallocation</th> |
| 61 | + <td> |
| 62 | + {isLoading ? 'Loading...' : undefined} |
| 63 | + |
| 64 | + {buffer && !isLoading && deallocationOperationId ? ( |
| 65 | + <div> |
| 66 | + Deallocation found in{' '} |
| 67 | + <Link to={`${ROUTES.OPERATIONS}/${deallocationOperationId}`}> |
| 68 | + {deallocationOperationId}{' '} |
| 69 | + {operations.find((operation) => operation.id === deallocationOperationId)?.name} |
| 70 | + </Link> |
| 71 | + <Icon |
| 72 | + className='deallocation-icon' |
| 73 | + icon={IconNames.TICK} |
| 74 | + intent={Intent.SUCCESS} |
| 75 | + /> |
| 76 | + </div> |
| 77 | + ) : ( |
| 78 | + <> |
| 79 | + Missing deallocation operation |
| 80 | + <Icon |
| 81 | + className='deallocation-icon' |
| 82 | + icon={IconNames.WARNING_SIGN} |
| 83 | + intent={Intent.WARNING} |
| 84 | + /> |
| 85 | + </> |
| 86 | + )} |
| 87 | + </td> |
| 88 | + </tr> |
| 89 | + |
| 90 | + {buffer?.next_usage && address && !isLoading ? ( |
| 91 | + <tr> |
| 92 | + <th>Next allocation</th> |
| 93 | + <td> |
| 94 | + <span> |
| 95 | + {toHex(address)} next allocated in{' '} |
| 96 | + <Link to={`${ROUTES.OPERATIONS}/${buffer.operation_id}`}> |
| 97 | + {buffer.operation_id}{' '} |
| 98 | + {operations.find((operation) => operation.id === buffer.operation_id)?.name} |
| 99 | + </Link>{' '} |
| 100 | + (+{buffer.next_usage} operations) |
| 101 | + </span> |
| 102 | + </td> |
| 103 | + </tr> |
| 104 | + ) : null} |
| 105 | + </tbody> |
| 106 | + </table> |
| 107 | + |
| 108 | + <table className={classNames('ttnn-table two-tone-rows buffer-table', className)}> |
| 109 | + <tbody> |
| 110 | + <tr> |
| 111 | + <th>Device Id</th> |
| 112 | + <td>{tensor.device_id ?? 'n/a'}</td> |
| 113 | + </tr> |
| 114 | + |
| 115 | + <tr> |
| 116 | + <th>DataType</th> |
| 117 | + <td>{dtype}</td> |
| 118 | + </tr> |
| 119 | + |
| 120 | + <tr> |
| 121 | + <th>Layout</th> |
| 122 | + <td>{layout}</td> |
| 123 | + </tr> |
| 124 | + |
| 125 | + {tensor?.memory_config |
| 126 | + ? Object.entries(parseMemoryConfig(tensor.memory_config)).map(([key, value]) => ( |
| 127 | + <tr key={key}> |
| 128 | + {key === 'shard_spec' && value && typeof value !== 'string' ? ( |
| 129 | + <> |
| 130 | + <th>{getHeaderLabel(key)}</th> |
| 131 | + <td> |
| 132 | + <table className='ttnn-table alt-two-tone-rows'> |
| 133 | + <tbody> |
| 134 | + {Object.entries(value as ShardSpec).map( |
| 135 | + ([innerKey, innerValue]) => ( |
| 136 | + <tr key={innerKey}> |
| 137 | + <th> |
| 138 | + {getHeaderLabel( |
| 139 | + innerKey as keyof typeof HEADER_LABELS, |
| 140 | + )} |
| 141 | + </th> |
| 142 | + <td>{innerValue}</td> |
| 143 | + </tr> |
| 144 | + ), |
| 145 | + )} |
| 146 | + </tbody> |
| 147 | + </table> |
| 148 | + </td> |
| 149 | + </> |
| 150 | + ) : ( |
| 151 | + <> |
| 152 | + <th>{getHeaderLabel(key as keyof typeof HEADER_LABELS)}</th> |
| 153 | + <td>{value as string}</td> |
| 154 | + </> |
| 155 | + )} |
| 156 | + </tr> |
| 157 | + )) |
| 158 | + : null} |
| 159 | + |
| 160 | + <tr> |
| 161 | + <th>Shape</th> |
| 162 | + <td>{shape}</td> |
| 163 | + </tr> |
| 164 | + </tbody> |
| 165 | + </table> |
| 166 | + </> |
| 167 | + ); |
| 168 | +} |
| 169 | + |
| 170 | +function getDeallocationOperation(tensor: Tensor, operations: OperationDescription[]): number | undefined { |
| 171 | + // TODO: Maybe we can strengthen this logic to ensure we're looking at deallocations rather than just checking the name |
| 172 | + const matchingInputs = operations.filter( |
| 173 | + (operation) => |
| 174 | + operation.name.includes('deallocate') && operation.inputs.find((input) => input.id === tensor.id), |
| 175 | + ); |
| 176 | + |
| 177 | + return matchingInputs.map((x) => x.id)[0]; |
| 178 | +} |
| 179 | + |
| 180 | +function parseMemoryConfig(string: string) { |
| 181 | + const regex = /MemoryConfig\((.*)\)$/; |
| 182 | + const match = string.match(regex); |
| 183 | + |
| 184 | + if (match) { |
| 185 | + const capturedString = match[1]; |
| 186 | + |
| 187 | + const memoryLayoutPattern = /memory_layout=([A-Za-z_:]+)/; |
| 188 | + const shardSpecPattern = |
| 189 | + /shard_spec=ShardSpec\(grid=\{(\[.*?\])\},shape=\{(\d+),\s*(\d+)\},orientation=ShardOrientation::([A-Z_]+),halo=(\d+)\)/; |
| 190 | + |
| 191 | + const memoryLayoutMatch = capturedString.match(memoryLayoutPattern); |
| 192 | + const shardSpecMatch = capturedString.match(shardSpecPattern); |
| 193 | + |
| 194 | + const memoryLayout = memoryLayoutMatch ? memoryLayoutMatch[1] : null; |
| 195 | + const shardSpec = shardSpecMatch |
| 196 | + ? { |
| 197 | + grid: shardSpecMatch[1], |
| 198 | + shape: [parseInt(shardSpecMatch[2], 10), parseInt(shardSpecMatch[3], 10)], |
| 199 | + orientation: shardSpecMatch[4], |
| 200 | + halo: parseInt(shardSpecMatch[5], 10), |
| 201 | + } |
| 202 | + : null; |
| 203 | + |
| 204 | + return { |
| 205 | + memory_layout: memoryLayout, |
| 206 | + shard_spec: shardSpec || 'std::nullopt', |
| 207 | + }; |
| 208 | + } |
| 209 | + |
| 210 | + return string; |
| 211 | +} |
| 212 | + |
| 213 | +function getHeaderLabel(key: keyof typeof HEADER_LABELS) { |
| 214 | + return HEADER_LABELS[key]; |
| 215 | +} |
| 216 | + |
| 217 | +export default BufferDetails; |
0 commit comments