Skip to content

Commit

Permalink
add cachix support
Browse files Browse the repository at this point in the history
  • Loading branch information
Mic92 committed May 2, 2024
1 parent bfe8a69 commit fe9b0aa
Showing 1 changed file with 139 additions and 11 deletions.
150 changes: 139 additions & 11 deletions nix_fast_build/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from collections.abc import AsyncIterator
from contextlib import AsyncExitStack, asynccontextmanager
from dataclasses import dataclass, field
from pathlib import Path
from tempfile import TemporaryDirectory
from types import TracebackType
from typing import IO, Any, NoReturn, TypeVar
Expand Down Expand Up @@ -73,6 +74,8 @@ class Options:
no_link: bool = False
out_link: str = "result"

cachix_cache: str | None = None

@property
def remote_url(self) -> None | str:
if self.remote is None:
Expand Down Expand Up @@ -150,6 +153,11 @@ async def parse_args(args: list[str]) -> Options:
metavar=("name", "value"),
default=[],
)
parser.add_argument(
"--cachix-cache",
help="Cachix cache to upload to",
default=None,
)
parser.add_argument(
"--no-nom",
help="Don't use nix-output-monitor to print build output (default: false)",
Expand Down Expand Up @@ -272,6 +280,7 @@ async def parse_args(args: list[str]) -> Options:
eval_max_memory_size=a.eval_max_memory_size,
eval_workers=a.eval_workers,
copy_to=a.copy_to,
cachix_cache=a.cachix_cache,
no_link=a.no_link,
out_link=a.out_link,
)
Expand Down Expand Up @@ -397,14 +406,14 @@ async def ensure_stop(
proc.send_signal(signal_no)
try:
await asyncio.wait_for(proc.wait(), timeout=timeout)
except asyncio.TimeoutError:
except TimeoutError:
print(f"Failed to stop process {shlex.join(cmd)}. Killing it.")
proc.kill()
await proc.wait()


@asynccontextmanager
async def remote_temp_dir(opts: Options) -> AsyncIterator[str]:
async def remote_temp_dir(opts: Options) -> AsyncIterator[Path]:
assert opts.remote
ssh_cmd = ["ssh", opts.remote, *opts.remote_ssh_options, "--"]
cmd = [*ssh_cmd, "mktemp", "-d"]
Expand All @@ -418,7 +427,7 @@ async def remote_temp_dir(opts: Options) -> AsyncIterator[str]:
f"Failed to create temporary directory on remote machine {opts.remote}: {rc}"
)
try:
yield tempdir
yield Path(tempdir)
finally:
cmd = [*ssh_cmd, "rm", "-rf", tempdir]
logger.info("run %s", shlex.join(cmd))
Expand All @@ -427,16 +436,11 @@ async def remote_temp_dir(opts: Options) -> AsyncIterator[str]:


@asynccontextmanager
async def nix_eval_jobs(stack: AsyncExitStack, opts: Options) -> AsyncIterator[Process]:
if opts.remote:
gc_root_dir = await stack.enter_async_context(remote_temp_dir(opts))
else:
gc_root_dir = stack.enter_context(TemporaryDirectory())

async def nix_eval_jobs(tmp_dir: Path, opts: Options) -> AsyncIterator[Process]:
args = [
"nix-eval-jobs",
"--gc-roots-dir",
gc_root_dir,
str(tmp_dir / "gcroots"),
"--force-recurse",
"--max-memory-size",
str(opts.eval_max_memory_size),
Expand Down Expand Up @@ -480,6 +484,56 @@ async def nix_output_monitor(pipe: Pipe, opts: Options) -> AsyncIterator[Process
print("\033[?25h")


@asynccontextmanager
async def run_cachix_daemon(
exit_stack: AsyncExitStack, tmp_dir: Path, cachix_cache: str, opts: Options
) -> AsyncIterator[Path]:
sock_path = tmp_dir / "cachix.sock"
cmd = maybe_remote(
[
*nix_shell(["nixpkgs#cachix"]),
"cachix",
"daemon",
"run",
"--socket",
str(sock_path),
cachix_cache,
],
opts,
)
proc = await asyncio.create_subprocess_exec(*cmd)
try:
await exit_stack.enter_async_context(ensure_stop(proc, cmd))
while True:
if sock_path.exists():
break
await asyncio.sleep(0.1)
yield sock_path
finally:
await run_cachix_daemon_stop(exit_stack, sock_path, opts)


async def run_cachix_daemon_stop(
exit_stack: AsyncExitStack, sock_path: Path | None, opts: Options
) -> int:
if sock_path is None:
return 0
cmd = maybe_remote(
[
*nix_shell(["nixpkgs#cachix"]),
"cachix",
"daemon",
"stop",
"--socket",
str(sock_path),
],
opts,
)
proc = await asyncio.create_subprocess_exec(*cmd)
await exit_stack.enter_async_context(ensure_stop(proc, cmd))
return await proc.wait()


@dataclass
class Build:
attr: str
Expand Down Expand Up @@ -529,6 +583,27 @@ async def upload(self, exit_stack: AsyncExitStack, opts: Options) -> int:
await exit_stack.enter_async_context(ensure_stop(proc, cmd))
return await proc.wait()

async def upload_cachix(
self, cachix_socket_path: Path | None, opts: Options
) -> int:
if cachix_socket_path is None:
return 0
cmd = maybe_remote(
[
*nix_shell(["nixpkgs#cachix"]),
"cachix",
"daemon",
"push",
"--socket",
str(cachix_socket_path),
*list(self.outputs.values()),
],
opts,
)
logger.debug("run %s", shlex.join(cmd))
proc = await asyncio.create_subprocess_exec(*cmd)
return await proc.wait()

async def download(self, exit_stack: AsyncExitStack, opts: Options) -> int:
if not opts.remote_url or not opts.download:
return 0
Expand Down Expand Up @@ -573,6 +648,10 @@ class DownloadFailure(Failure):
pass


class CachixFailure(Failure):
pass


T = TypeVar("T")


Expand Down Expand Up @@ -663,6 +742,7 @@ async def run_builds(
build_output: IO,
build_queue: QueueWithContext[Job | StopTask],
upload_queue: QueueWithContext[Build | StopTask],
cachix_queue: QueueWithContext[Build | StopTask],
download_queue: QueueWithContext[Build | StopTask],
failures: list[Failure],
opts: Options,
Expand All @@ -684,6 +764,7 @@ async def run_builds(
if rc == 0:
upload_queue.put_nowait(build)
download_queue.put_nowait(build)
cachix_queue.put_nowait(build)
else:
failures.append(BuildFailure(build.attr, f"build exited with {rc}"))

Expand All @@ -704,6 +785,24 @@ async def run_uploads(
failures.append(UploadFailure(build.attr, f"upload exited with {rc}"))


async def run_cachix_upload(
cachix_queue: QueueWithContext[Build | StopTask],
cachix_socket_path: Path | None,
failures: list[Failure],
opts: Options,
) -> int:
while True:
async with cachix_queue.get_context() as build:
if isinstance(build, StopTask):
logger.debug("finish cachix upload task")
return 0
rc = await build.upload_cachix(cachix_socket_path, opts)
if rc != 0:
failures.append(
UploadFailure(build.attr, f"cachix upload exited with {rc}")
)


async def run_downloads(
stack: AsyncExitStack,
download_queue: QueueWithContext[Build | StopTask],
Expand Down Expand Up @@ -744,14 +843,26 @@ async def report_progress(


async def run(stack: AsyncExitStack, opts: Options) -> int:
eval_proc = await stack.enter_async_context(nix_eval_jobs(stack, opts))
if opts.remote:
tmp_dir = await stack.enter_async_context(remote_temp_dir(opts))
else:
tmp_dir = Path(stack.enter_context(TemporaryDirectory()))

eval_proc = await stack.enter_async_context(nix_eval_jobs(tmp_dir, opts))
pipe: Pipe | None = None
output_monitor: Process | None = None
if opts.nom:
pipe = stack.enter_context(Pipe())
output_monitor = await stack.enter_async_context(nix_output_monitor(pipe, opts))

cachix_socket_path: Path | None = None
if opts.cachix_cache:
cachix_socket_path = await stack.enter_async_context(
run_cachix_daemon(stack, tmp_dir, opts.cachix_cache, opts)
)
failures: defaultdict[type, list[Failure]] = defaultdict(list)
build_queue: QueueWithContext[Job | StopTask] = QueueWithContext()
cachix_queue: QueueWithContext[Build | StopTask] = QueueWithContext()
upload_queue: QueueWithContext[Build | StopTask] = QueueWithContext()
download_queue: QueueWithContext[Build | StopTask] = QueueWithContext()

Expand All @@ -775,6 +886,7 @@ async def run(stack: AsyncExitStack, opts: Options) -> int:
build_output,
build_queue,
upload_queue,
cachix_queue,
download_queue,
failures[BuildFailure],
opts,
Expand All @@ -788,6 +900,17 @@ async def run(stack: AsyncExitStack, opts: Options) -> int:
name=f"upload-{i}",
)
)
tasks.append(
tg.create_task(
run_cachix_upload(
cachix_queue,
cachix_socket_path,
failures[CachixFailure],
opts,
),
name=f"cachix-{i}",
)
)
tasks.append(
tg.create_task(
run_downloads(
Expand Down Expand Up @@ -817,6 +940,11 @@ async def run(stack: AsyncExitStack, opts: Options) -> int:
upload_queue.put_nowait(StopTask())
await upload_queue.join()

logger.debug("Uploads finished, waiting for cachix uploads to finish...")
for _ in range(opts.max_jobs):
cachix_queue.put_nowait(StopTask())
await cachix_queue.join()

logger.debug("Uploads finished, waiting for downloads to finish...")
for _ in range(opts.max_jobs):
download_queue.put_nowait(StopTask())
Expand Down

0 comments on commit fe9b0aa

Please sign in to comment.