Skip to content

Commit d68403d

Browse files
c00w authored and pytorchmergebot committed
filelock: Make waitcounter variant to use (pytorch#139816)
Pull Request resolved: pytorch#139816 Approved by: https://github.com/ezyang
1 parent 6cb6e8d commit d68403d

File tree

8 files changed

+109
-14
lines changed

8 files changed

+109
-14
lines changed

test/test_utils_filelock.py

+53
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Owner(s): ["module: unknown"]
2+
import concurrent.futures
3+
import tempfile
4+
import time
5+
6+
from torch.testing._internal.common_utils import run_tests, skipIfWindows, TestCase
7+
from torch.utils._filelock import FileLock
8+
9+
10+
class TestFileLock(TestCase):
    """Smoke and mutual-exclusion tests for ``torch.utils._filelock.FileLock``."""

    def test_no_crash(self):
        """Acquiring and releasing an uncontended lock must not raise."""
        import os

        # A temporary directory (rather than ``mkstemp``) means no open file
        # descriptor and no stray file is left behind after the test.
        with tempfile.TemporaryDirectory() as tmpdir:
            with FileLock(os.path.join(tmpdir, "test.lock")):
                pass

    @skipIfWindows(
        msg="Windows doesn't support multiple files being opened at once easily"
    )
    def test_sequencing(self):
        """Ten threads append under the lock; their critical sections must not overlap."""
        import os

        with tempfile.NamedTemporaryFile() as ofd:
            # The lock file lives in its own temporary directory so that it is
            # removed when the test finishes.
            with tempfile.TemporaryDirectory() as lock_dir:
                p = ofd.name
                lock_path = os.path.join(lock_dir, "sequencing.lock")

                def test_thread(i):
                    with FileLock(lock_path):
                        # Monotonic clock: wall-clock adjustments must not be
                        # able to make two critical sections appear to overlap.
                        start = time.monotonic()
                        with open(p, "a") as fd:
                            fd.write(str(i))
                        end = time.monotonic()
                    return (start, end)

                with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
                    futures = [executor.submit(test_thread, i) for i in range(10)]
                    times = [f.result(60) for f in futures]

                with open(p) as fd:
                    # Every thread writes exactly one digit, so comparing the
                    # sorted contents also catches duplicated or lost writes.
                    self.assertEqual(sorted(fd.read()), [str(i) for i in range(10)])

                # Times should never intersect. The strict interval-overlap
                # test is symmetric, so each unordered pair is checked once.
                for i, (start, end) in enumerate(times):
                    for newstart, newend in times[i + 1 :]:
                        self.assertFalse(
                            newend > start and newstart < end,
                            msg=(
                                f"critical sections overlap: ({start}, {end}) "
                                f"and ({newstart}, {newend})"
                            ),
                        )
50+
51+
52+
if __name__ == "__main__":
53+
run_tests()

torch/_dynamo/pgo.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -659,7 +659,7 @@ def put_local_code_state(cache_key: str) -> None:
659659
lock_path = path + ".lock"
660660
# We /mostly/ don't need the lock but the tmp file could be clobbered
661661
# TODO: use a safe tempfile create to eliminate lock
662-
from filelock import FileLock
662+
from torch.utils._filelock import FileLock
663663

664664
os.makedirs(os.path.dirname(path), exist_ok=True)
665665

torch/_inductor/aoti_eager.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,9 @@ def aoti_eager_cache_dir(namespace: str, device: str) -> Path:
2020

2121

2222
def aoti_eager_op_conf_lock(op_func_name_with_overload: str) -> Any:
23-
from filelock import FileLock
24-
2523
# Avoid circular import
2624
from torch._inductor.codecache import get_lock_dir, LOCK_TIMEOUT
25+
from torch.utils._filelock import FileLock
2726

2827
op_conf_lock_file = f"{op_func_name_with_overload}.lock"
2928
lock_dir = get_lock_dir()

torch/_inductor/codecache.py

+9-8
Original file line numberDiff line numberDiff line change
@@ -1568,7 +1568,7 @@ def _compile_consts(consts: bytes, platform: str) -> str:
15681568
pos += rc
15691569
return consts_o
15701570

1571-
from filelock import FileLock
1571+
from torch.utils._filelock import FileLock
15721572

15731573
lock_dir = get_lock_dir()
15741574
lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
@@ -2003,7 +2003,7 @@ def load_async(
20032003
key, input_path = write(source_code, "cpp", extra=vec_isa_cmd)
20042004

20052005
if key not in cls.cache:
2006-
from filelock import FileLock
2006+
from torch.utils._filelock import FileLock
20072007

20082008
lock_path = os.path.join(get_lock_dir(), key + ".lock")
20092009
output_name, output_dir = get_name_and_dir_from_output_file_path(input_path)
@@ -2068,7 +2068,7 @@ def _worker_compile_cpp(
20682068
fb_input_path: str,
20692069
fb_output_path: str,
20702070
) -> None:
2071-
from filelock import FileLock
2071+
from torch.utils._filelock import FileLock
20722072

20732073
with FileLock(lock_path, timeout=LOCK_TIMEOUT):
20742074
binary_path = (
@@ -2646,10 +2646,11 @@ def build_standalone_runtime(cls) -> str:
26462646
afile = str(dirpath / "standalone_halide_runtime.a")
26472647
sofile = str(dirpath / libname)
26482648
if not os.path.exists(donefile):
2649-
import filelock
26502649
import halide as hl # type: ignore[import-untyped,import-not-found]
26512650

2652-
with filelock.FileLock(lockfile, LOCK_TIMEOUT):
2651+
from torch.utils._filelock import FileLock
2652+
2653+
with FileLock(lockfile, LOCK_TIMEOUT):
26532654
if not os.path.exists(donefile):
26542655
with open(hookfile, "w") as f:
26552656
if device_type == "cuda":
@@ -2680,7 +2681,7 @@ def build_standalone_runtime(cls) -> str:
26802681

26812682

26822683
def _worker_task_halide(lockfile: str, jobs: List[partial[Any]]) -> None:
2683-
from filelock import FileLock
2684+
from torch.utils._filelock import FileLock
26842685

26852686
try:
26862687
with FileLock(lockfile, LOCK_TIMEOUT):
@@ -3075,7 +3076,7 @@ def compile(
30753076
"""
30763077
key, input_path = cls.write(source_code, dst_file_ext)
30773078
if key not in cls.cache:
3078-
from filelock import FileLock
3079+
from torch.utils._filelock import FileLock
30793080

30803081
lock_dir = get_lock_dir()
30813082
lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
@@ -3166,7 +3167,7 @@ def compile(
31663167

31673168
key, input_path = cls.write(source_code, dst_file_ext)
31683169
if key not in cls.cache:
3169-
from filelock import FileLock
3170+
from torch.utils._filelock import FileLock
31703171

31713172
lock_dir = get_lock_dir()
31723173
lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)

torch/_inductor/cpp_builder.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def cpp_compiler_search(search: str) -> str:
7979
# Do not install GXX by default
8080
if not os.getenv("TORCH_INDUCTOR_INSTALL_GXX"):
8181
continue
82-
from filelock import FileLock
82+
from torch.utils._filelock import FileLock
8383

8484
lock_dir = get_lock_dir()
8585
lock = FileLock(

torch/_inductor/cpu_vec_isa.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def check_build(self, code: str) -> bool:
101101
"cpp",
102102
extra=_get_isa_dry_compile_fingerprint(self._arch_flags),
103103
)
104-
from filelock import FileLock
104+
from torch.utils._filelock import FileLock
105105

106106
lock_dir = get_lock_dir()
107107
lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)

torch/_inductor/select_algorithm.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,12 @@
2020
from unittest.mock import patch
2121

2222
import sympy
23-
from filelock import FileLock
2423

2524
import torch
2625
import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools
2726
from torch._dynamo.testing import rand_strided
2827
from torch._dynamo.utils import counters, dynamo_timed, identity, preserve_rng_state
28+
from torch.utils._filelock import FileLock
2929

3030
from . import config, ir
3131
from .autotune_process import (

torch/utils/_filelock.py

+42
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
from types import TracebackType
2+
from typing import Optional
3+
from typing_extensions import Self
4+
5+
from filelock import FileLock as base_FileLock
6+
7+
from torch.monitor import _WaitCounter
8+
9+
10+
class FileLock(base_FileLock):
    """
    This behaves like a normal file lock.

    However, it adds waitcounters for acquiring and releasing the filelock
    as well as for the critical region within it.

    pytorch.filelock.enter - While we're acquiring the filelock.
    pytorch.filelock.region - While we're holding the filelock and doing work.
    pytorch.filelock.exit - While we're releasing the filelock.

    Region guards are kept on a per-instance stack instead of in a single
    attribute, so that nested ``with`` blocks on the same lock object (the
    base filelock is documented as a recursive lock) each pair their own
    enter/exit. A guard is only created once the lock has actually been
    acquired, so a caller that is still waiting never touches the state of
    the current holder.
    """

    def __enter__(self) -> Self:
        """Acquire the lock, then start timing the critical region."""
        with _WaitCounter("pytorch.filelock.enter").guard():
            result = super().__enter__()
        # Start the region guard only after the lock is held; it is pushed
        # only once it has been entered successfully.
        region_guard = _WaitCounter("pytorch.filelock.region").guard()
        region_guard.__enter__()
        # Created lazily so the base class constructor need not be overridden.
        self.__dict__.setdefault("_region_guards", []).append(region_guard)
        return result

    def __exit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        """Stop timing the critical region, then release the lock."""
        region_guards = self.__dict__.get("_region_guards")
        try:
            if region_guards:
                # Close the innermost region that is still open.
                region_guards.pop().__exit__()
        finally:
            # Always release the lock, even if closing the region guard failed.
            with _WaitCounter("pytorch.filelock.exit").guard():
                # Returns nothing per
                # https://github.com/tox-dev/filelock/blob/57f488ff8fdc2193572efe102408fb63cfefe4e4/src/filelock/_api.py#L379
                super().__exit__(exc_type, exc_value, traceback)
        # Returns nothing per
        # https://github.com/pytorch/pytorch/blob/0f6bfc58a2cfb7a5c052bea618ab62becaf5c912/torch/csrc/monitor/python_init.cpp#L315
        return None

0 commit comments

Comments
 (0)