Using untyped storage ref instead of tensor ref #50

Merged 1 commit on Mar 28, 2024
Using untyped storage ref instead of tensor ref
sanketpurandare committed Mar 28, 2024
commit 8c0e36d6fb192eac9eed456497ab6697b5d3b23d
max_mem_tracker.py: 14 changes (8 additions, 6 deletions)
@@ -4,29 +4,31 @@
 from torch._subclasses.fake_tensor import FakeTensorMode
 from torch.utils.weak import WeakIdKeyDictionary
 import weakref
+import math

 # Track all the memory being used by Tensors.
 # Only max is tracked but others can be added.
 MEMORY_USE = WeakIdKeyDictionary()
 MEMORY_MAX = 0
 MEMORY_ID = 0
+# Minimum allocation size
+PYTORCH_MIN_ALLOCATE = 2**9

 def update_stats():
     global MEMORY_MAX
     curr_use = 0
     for k, v in MEMORY_USE.items():
-        curr_use += k.nelement() * k.element_size()
+        curr_use += math.ceil(k.size() * k.element_size()/PYTORCH_MIN_ALLOCATE) * PYTORCH_MIN_ALLOCATE

     if MEMORY_MAX < curr_use:
         MEMORY_MAX = curr_use

 # Should be called on every Tensor created
-def track(t):
+def track(t:torch.Tensor):
     def cb(_):
         update_stats()

-    wt = weakref.ref(t, cb)
-    MEMORY_USE[t] = wt
+    st = t.untyped_storage()
+    wt = weakref.ref(st, cb)
+    MEMORY_USE[st] = wt
     update_stats()

 # Use this Mode to call track on every Tensor being created by functions
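
The two changes go together. Several tensor objects can be views of one allocation, so weak-referencing the tensor double counts shared storage and fires the cleanup callback when a view dies even though the buffer is still alive; keying on the untyped storage counts each allocation exactly once and only expires when the storage itself is freed. The 2**9 = 512-byte rounding models the minimum block granularity of PyTorch's CUDA caching allocator, so e.g. a 400-byte buffer is accounted as 512 bytes. A minimal sketch of both effects, not part of the PR (the byte counts assume the default float32 dtype):

import math
import torch

PYTORCH_MIN_ALLOCATE = 2**9  # 512 bytes, the constant the commit introduces

base = torch.randn(100)   # 100 float32 elements -> 400 bytes of data
view = base.view(10, 10)  # new tensor object, same underlying storage

# Both tensor objects are backed by a single allocation.
assert base.untyped_storage().data_ptr() == view.untyped_storage().data_ptr()

# Keying by tensor would count this buffer twice (800 bytes);
# keying by untyped storage counts it once.
nbytes = view.untyped_storage().nbytes()  # 400
rounded = math.ceil(nbytes / PYTORCH_MIN_ALLOCATE) * PYTORCH_MIN_ALLOCATE
print(nbytes, rounded)  # 400 512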
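
The trailing comment refers to a dispatch mode defined outside this hunk. As a hedged sketch only, such a mode could call track() on every tensor an op returns roughly like the following; MemoryTrackingMode and its body are assumptions for illustration, not the file's actual implementation:

import torch
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map_only

class MemoryTrackingMode(TorchDispatchMode):  # hypothetical name
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        out = func(*args, **(kwargs or {}))
        # Register every tensor in the output with the tracker above.
        tree_map_only(torch.Tensor, track, out)
        return out

Used as a context manager (with MemoryTrackingMode(): ...), every factory call and op result inside the block would then flow through track(), which is what keeps MEMORY_USE and MEMORY_MAX current.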