fix: _get_folder_size fn #471

Open · wants to merge 9 commits into main
3 changes: 3 additions & 0 deletions src/litdata/streaming/config.py
@@ -98,6 +98,9 @@ def __init__(

         self._skip_chunk_indexes_deletion: Optional[List[int]] = None
         self.zero_based_roi: Optional[List[Tuple[int, int]]] = None
+        self.filename_to_size_map: Dict[str, int] = {}
+        for cnk in _original_chunks:
+            self.filename_to_size_map[cnk["filename"]] = cnk["chunk_bytes"]

     def can_delete(self, chunk_index: int) -> bool:
         if self._skip_chunk_indexes_deletion is None:
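For context, the mapping added above is just a lookup from chunk filename to chunk size taken from the writer metadata. A minimal standalone sketch, assuming each chunk entry carries the `filename` and `chunk_bytes` keys shown in the diff (the sample values below are made up):

```python
from typing import Dict, List


def build_filename_to_size_map(chunks: List[dict]) -> Dict[str, int]:
    # Map each chunk's filename to its size in bytes, taken from the
    # writer-generated metadata rather than from os.stat on disk.
    return {chunk["filename"]: chunk["chunk_bytes"] for chunk in chunks}


# Hypothetical chunk metadata, shaped like the entries in _original_chunks:
chunks = [
    {"filename": "chunk-0-0.bin", "chunk_bytes": 40},
    {"filename": "chunk-0-1.bin", "chunk_bytes": 24},
]
assert build_filename_to_size_map(chunks) == {"chunk-0-0.bin": 40, "chunk-0-1.bin": 24}
```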
13 changes: 8 additions & 5 deletions src/litdata/streaming/reader.py
@@ -182,7 +182,10 @@ def _maybe_delete_chunks(self) -> None:
     def _can_delete_chunk(self) -> bool:
         if self._delete_chunks_when_processed:
             return self._pre_download_counter >= self._max_pre_download - 1
-        return self._max_cache_size is not None and _get_folder_size(self._parent_cache_dir) >= self._max_cache_size
+        return (
+            self._max_cache_size is not None
+            and _get_folder_size(self._parent_cache_dir, self._config) >= self._max_cache_size
Collaborator review comment: We shouldn't use the parent_dir anymore; otherwise, this would always be empty. Each StreamingDataset should take care of only its own cache.

+        )

     def _pre_load_chunk(self, chunk_index: int) -> None:
         chunk_filepath, _, _ = self._config[ChunkedIndex(index=-1, chunk_index=chunk_index)]
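To illustrate the collaborator comment above about parent_dir, a small standalone sketch of a size check scoped to one dataset's own cache directory. The names here (`dataset_cache_dir`, `filename_to_size_map` passed as a plain dict) are illustrative assumptions, not litdata API; the real fix would live inside the reader class.

```python
import os
from typing import Dict


def dataset_cache_size(dataset_cache_dir: str, filename_to_size_map: Dict[str, int]) -> int:
    """Sum chunk sizes for one dataset's own cache directory only.

    Walking the shared parent directory would mix in chunks written by other
    StreamingDataset instances, so the eviction check is scoped to this
    dataset's directory and sizes come from the metadata map, not os.stat.
    """
    return sum(
        filename_to_size_map.get(filename, 0)
        for filename in os.listdir(dataset_cache_dir)
    )
```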
@@ -432,17 +435,17 @@ def __del__(self) -> None:
         self._prepare_thread = None


-def _get_folder_size(path: str) -> int:
+def _get_folder_size(path: str, config: ChunksConfig) -> int:
     """Collect the size of each files within a folder.

     This method is robust to file deletion races

     """
     size = 0
-    for dirpath, _, filenames in os.walk(str(path)):
-        for filename in filenames:
+    for filename in os.listdir(os.path.join(path, "cache_dir")):
Collaborator review comment: We shouldn't add cache_dir there.

+        if filename in config.filename_to_size_map:
             with contextlib.suppress(FileNotFoundError):
-                size += os.stat(os.path.join(dirpath, filename)).st_size
+                size += config.filename_to_size_map[filename]
Collaborator review comment: Let's check whether the filename is in the map; otherwise, this would fail the thread. Normally the files should be there, but we should print a warning if one isn't.

     return size
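Putting the review comments on this hunk together, one possible shape for the helper is sketched below. This is only an illustration of the suggestions, not the merged implementation: the standalone signature, logger, and warning text are assumptions.

```python
import logging
import os

logger = logging.getLogger(__name__)


def _get_folder_size_sketch(path: str, filename_to_size_map: dict) -> int:
    """Estimate the cache folder size from chunk metadata.

    Sizes come from the filename-to-size map rather than os.stat, which keeps
    the check robust to files being deleted concurrently.
    """
    size = 0
    for filename in os.listdir(path):  # no hard-coded "cache_dir" sub-folder
        if filename in filename_to_size_map:
            size += filename_to_size_map[filename]
        else:
            # Unknown entries (e.g. lock or index files) are skipped with a
            # warning instead of raising and killing the download thread.
            logger.warning("%s was not found in the chunk size map; skipping it.", filename)
    return size
```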


32 changes: 17 additions & 15 deletions tests/streaming/test_reader.py
@@ -2,13 +2,11 @@
 import shutil
 from time import sleep

-import numpy as np

 from litdata.streaming import reader
 from litdata.streaming.cache import Cache
 from litdata.streaming.config import ChunkedIndex
 from litdata.streaming.item_loader import PyTreeLoader
-from litdata.streaming.reader import _END_TOKEN, PrepareChunksThread, _get_folder_size
+from litdata.streaming.reader import _END_TOKEN, PrepareChunksThread
 from litdata.streaming.resolver import Dir
 from litdata.utilities.env import _DistributedEnv
 from tests.streaming.utils import filter_lock_files, get_lock_files
@@ -18,6 +16,7 @@ def test_reader_chunk_removal(tmpdir):
     cache_dir = os.path.join(tmpdir, "cache_dir")
     remote_dir = os.path.join(tmpdir, "remote_dir")
     os.makedirs(cache_dir, exist_ok=True)
+    # we don't care about the max cache size here (so very large number)
     cache = Cache(input_dir=Dir(path=cache_dir, url=remote_dir), chunk_size=2, max_cache_size=28020)

     for i in range(25):
@@ -37,12 +36,18 @@
     assert len(filter_lock_files(os.listdir(cache_dir))) == 14
     assert len(get_lock_files(os.listdir(cache_dir))) == 0

-    cache = Cache(input_dir=Dir(path=cache_dir, url=remote_dir), chunk_size=2, max_cache_size=2800)
+    # Let's test if cache actually respects the max cache size
+    # each chunk is 40 bytes if it has 2 items
+    # a chunk with only 1 item is 24 bytes (values determined by checking actual chunk sizes)
+    cache = Cache(input_dir=Dir(path=cache_dir, url=remote_dir), chunk_size=2, max_cache_size=90)

     shutil.rmtree(cache_dir)
     os.makedirs(cache_dir, exist_ok=True)

     for i in range(25):
+        # we expect at max 3 files to be present (2 chunks and 1 index file)
+        # why 2 chunks? Bcoz max cache size is 90 bytes and each chunk is 40 bytes or 24 bytes (1 item)
+        # So any additional chunk will go over the max cache size
         assert len(filter_lock_files(os.listdir(cache_dir))) <= 3
         index = ChunkedIndex(*cache._get_chunk_index_from_index(i), is_last_index=i == 24)
         assert cache[index] == i
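A quick check of the arithmetic behind the new comments (the byte sizes are the ones quoted in the test, not re-measured here):

```python
# With max_cache_size=90 and full chunks of 40 bytes each:
max_cache_size = 90
full_chunk_bytes = 40
assert 2 * full_chunk_bytes <= max_cache_size  # two full chunks (80 bytes) fit under the cap
assert 3 * full_chunk_bytes > max_cache_size   # a third full chunk would exceed it, so it gets evicted
```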
@@ -54,6 +59,7 @@ def test_reader_chunk_removal_compressed(tmpdir):
     cache_dir = os.path.join(tmpdir, "cache_dir")
     remote_dir = os.path.join(tmpdir, "remote_dir")
     os.makedirs(cache_dir, exist_ok=True)
+    # we don't care about the max cache size here (so very large number)
     cache = Cache(input_dir=Dir(path=cache_dir, url=remote_dir), chunk_size=2, max_cache_size=28020, compression="zstd")

     for i in range(25):
@@ -72,29 +78,25 @@

     assert len(filter_lock_files(os.listdir(cache_dir))) == 14
     assert len(get_lock_files(os.listdir(cache_dir))) == 0

-    cache = Cache(input_dir=Dir(path=cache_dir, url=remote_dir), chunk_size=2, max_cache_size=2800, compression="zstd")
+    # Let's test if cache actually respects the max cache size
+    # each chunk is 40 bytes if it has 2 items
+    # a chunk with only 1 item is 24 bytes (values determined by checking actual chunk sizes)
+    cache = Cache(input_dir=Dir(path=cache_dir, url=remote_dir), chunk_size=2, max_cache_size=90, compression="zstd")

     shutil.rmtree(cache_dir)
     os.makedirs(cache_dir, exist_ok=True)

     for i in range(25):
+        # we expect at max 3 files to be present (2 chunks and 1 index file)
+        # why 2 chunks? Bcoz max cache size is 90 bytes and each chunk is 40 bytes or 24 bytes (1 item)
+        # So any additional chunk will go over the max cache size
         assert len(filter_lock_files(os.listdir(cache_dir))) <= 3
         index = ChunkedIndex(*cache._get_chunk_index_from_index(i), is_last_index=i == 24)
         assert cache[index] == i

     assert len(filter_lock_files(os.listdir(cache_dir))) in [2, 3]


-def test_get_folder_size(tmpdir):
-    array = np.zeros((10, 10))
-
-    np.save(os.path.join(tmpdir, "array_1.npy"), array)
-    np.save(os.path.join(tmpdir, "array_2.npy"), array)
-
-    assert _get_folder_size(tmpdir) == 928 * 2


 def test_prepare_chunks_thread_eviction(tmpdir, monkeypatch):
     monkeypatch.setattr(reader, "_LONG_DEFAULT_TIMEOUT", 0.1)
