Skip to content

Commit

Permalink
Support additional policy directories
Browse files Browse the repository at this point in the history
This allows multiple directories to contain qrexec policy, which allows
for transient policy that disappears on reboot.

Fixes: QubesOS/qubes-issues#8513
  • Loading branch information
DemiMarie committed Mar 14, 2024
1 parent c918563 commit 9d31ff7
Show file tree
Hide file tree
Showing 11 changed files with 149 additions and 83 deletions.
2 changes: 2 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ install-dom0: all-dom0
install -t $(DESTDIR)/etc/qubes/policy.d/include -m 664 policy.d/include/*
install -d $(DESTDIR)/lib/systemd/system -m 755
install -t $(DESTDIR)/lib/systemd/system -m 644 systemd/qubes-qrexec-policy-daemon.service
install -m 755 -d $(DESTDIR)/usr/lib/tmpfiles.d/
install -m 0644 -t $(DESTDIR)/usr/lib/tmpfiles.d/ systemd/qrexec.conf
.PHONY: install-dom0


Expand Down
2 changes: 2 additions & 0 deletions qrexec/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,12 @@
RPC_PATH = "/etc/qubes-rpc"
POLICY_AGENT_SOCKET_PATH = "/var/run/qubes/policy-agent.sock"
POLICYPATH = pathlib.Path("/etc/qubes/policy.d")
RUNTIME_POLICY_PATH = pathlib.Path("/run/qubes/policy.d")
POLICYSOCKET = pathlib.Path("/var/run/qubes/policy.sock")
POLICY_EVAL_SOCKET = pathlib.Path("/etc/qubes-rpc/policy.EvalSimple")
POLICY_GUI_SOCKET = pathlib.Path("/etc/qubes-rpc/policy.EvalGUI")
INCLUDEPATH = POLICYPATH / "include"
RUNTIME_INCLUDE_PATH = RUNTIME_POLICY_PATH / "include"
POLICYSUFFIX = ".policy"
POLICYPATH_OLD = pathlib.Path("/etc/qubes-rpc/policy")

Expand Down
69 changes: 55 additions & 14 deletions qrexec/policy/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
Sequence,
)

from .. import POLICYPATH, RPCNAME_ALLOWED_CHARSET, POLICYSUFFIX
from .. import POLICYPATH, RPCNAME_ALLOWED_CHARSET, POLICYSUFFIX, RUNTIME_POLICY_PATH
from ..utils import FullSystemInfo
from .. import exc
from ..exc import (
Expand Down Expand Up @@ -1790,22 +1790,54 @@ class AbstractFileSystemLoader(AbstractDirectoryLoader, AbstractFileLoader):
"""This class is used when policy is stored as regular files in a directory.
Args:
policy_path (pathlib.Path): Load this directory. Paths given to
``!include`` etc. directives are interpreted relative to this path.
policy_path: Load these directories. Paths given to
``!include`` etc. directives in a file are interpreted relative to
the path from which the file was loaded.
"""

def __init__(self, *, policy_path=POLICYPATH, **kwds):
super().__init__(**kwds)
self.policy_path = pathlib.Path(policy_path)

policy_path: Optional[pathlib.Path]
def __init__(
self,
*,
policy_path: Union[None, pathlib.PurePath, Iterable[pathlib.PurePath]]
) -> None:
super().__init__()
if policy_path is None:
iterable_policy_paths = [RUNTIME_POLICY_PATH, POLICYPATH]
elif isinstance(policy_path, pathlib.Path):
iterable_policy_paths = [policy_path]
elif isinstance(policy_path, list):
iterable_policy_paths = policy_path
else:
raise TypeError("unexpected type of policy path in AbstractFileSystemLoader.__init__!")
try:
self.load_policy_dir(self.policy_path)
self.load_policy_dirs(iterable_policy_paths)
except OSError as err:
raise AccessDenied(
"failed to load {} file: {!s}".format(err.filename, err)
) from err

def resolve_path(self, included_path):
self.policy_path = None

def load_policy_dirs(self, paths: Iterable[pathlib.PurePath]) -> None:
already_seen = set()
final_list = []
for path in paths:
for file_path in filter_filepaths(pathlib.Path(path).iterdir()):
basename = file_path.name
if basename not in already_seen:
already_seen.add(basename)
final_list.append(file_path)
final_list.sort(key=lambda x: x.name)
for file_path in final_list:
with file_path.open() as file:
self.policy_path = file_path.parent
try:
self.load_policy_file(file, file_path)
finally:
self.policy_path = None

def resolve_path(self, included_path: pathlib.PurePosixPath) -> pathlib.Path:
assert self.policy_path is not None, "Tried to resolve a path when not loading policy"
return (self.policy_path / included_path).resolve()


Expand Down Expand Up @@ -1840,12 +1872,21 @@ class ValidateParser(FilePolicy):
"""

def __init__(
self, *args, overrides: Dict[pathlib.Path, Optional[str]], **kwds
):
self,
*,
overrides: Dict[pathlib.Path, Optional[str]],
policy_path: Union[None, pathlib.PurePath, Iterable[pathlib.PurePath]] = None,
) -> None:
self.overrides = overrides
super().__init__(*args, **kwds)
super().__init__(policy_path=policy_path)

def load_policy_dir(self, dirpath):
def load_policy_dirs(self, paths: Iterable[pathlib.PurePath]) -> None:
assert len(paths) == 1
path, = paths
self.policy_path = path
self.load_policy_dir(path)

def load_policy_dir(self, dirpath: pathlib.Path) -> None:
for path in filter_filepaths(dirpath.iterdir()):
if path not in self.overrides:
with path.open() as file:
Expand Down
30 changes: 15 additions & 15 deletions qrexec/policy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,20 @@
import asyncio
import os.path
import pyinotify
from qrexec import POLICYPATH, POLICYPATH_OLD
from qrexec import POLICYPATH, POLICYPATH_OLD, RUNTIME_POLICY_PATH
from . import parser


class PolicyCache:
def __init__(self, path=POLICYPATH, use_legacy=True, lazy_load=False):
self.path = path
def __init__(
self, path=(RUNTIME_POLICY_PATH, POLICYPATH), use_legacy=True, lazy_load=False
) -> None:
self.paths = list(path)
self.outdated = lazy_load
if lazy_load:
self.policy = None
else:
self.policy = parser.FilePolicy(policy_path=self.path)
self.policy = parser.FilePolicy(policy_path=self.paths)

# default policy paths are listed manually, for compatibility with R4.0
# to be removed in Qubes 5.0
Expand All @@ -56,22 +58,20 @@ def initialize_watcher(self):
self.watch_manager, loop, default_proc_fun=PolicyWatcher(self)
)

if str(self.path) not in self.default_policy_paths and os.path.exists(
self.path
):
self.watches.append(
self.watch_manager.add_watch(
str(self.path), mask, rec=True, auto_add=True
for path in self.paths:
str_path = str(path)
if str_path not in self.default_policy_paths and os.path.exists(str_path):
self.watches.append(
self.watch_manager.add_watch(
str_path, mask, rec=True, auto_add=True
)
)
)

for path in self.default_policy_paths:
if not os.path.exists(path):
continue
self.watches.append(
self.watch_manager.add_watch(
str(path), mask, rec=True, auto_add=True
)
self.watch_manager.add_watch(str(path), mask, rec=True, auto_add=True)
)

def cleanup(self):
Expand All @@ -86,7 +86,7 @@ def cleanup(self):

def get_policy(self):
if self.outdated:
self.policy = parser.FilePolicy(policy_path=self.path)
self.policy = parser.FilePolicy(policy_path=self.paths)
self.outdated = False

return self.policy
Expand Down
2 changes: 1 addition & 1 deletion qrexec/tests/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def policy():
yield policy

assert mock_policy.mock_calls == [
mock.call(policy_path=PosixPath("/etc/qubes/policy.d"))
mock.call(policy_path=[PosixPath("/run/qubes/policy.d"), PosixPath("/etc/qubes/policy.d")]),
]


Expand Down
98 changes: 56 additions & 42 deletions qrexec/tests/policy_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,20 @@
import pytest
import unittest
import unittest.mock
import pathlib

from ..policy import utils


class TestPolicyCache:
@pytest.fixture
def tmp_paths(self, tmp_path: pathlib.Path) -> list[pathlib.Path]:
path1 = tmp_path / "path1"
path2 = tmp_path / "path2"
path1.mkdir()
path2.mkdir()
return [path1, path2]

@pytest.fixture
def mock_parser(self, monkeypatch):
mock_parser = unittest.mock.Mock()
Expand All @@ -37,58 +46,60 @@ def mock_parser(self, monkeypatch):
return mock_parser

def test_00_policy_init(self, tmp_path, mock_parser):
cache = utils.PolicyCache(tmp_path)
mock_parser.assert_called_once_with(policy_path=tmp_path)
cache = utils.PolicyCache([tmp_path])
mock_parser.assert_called_once_with(policy_path=[tmp_path])

@pytest.mark.asyncio
async def test_10_file_created(self, tmp_path, mock_parser):
cache = utils.PolicyCache(tmp_path)
cache.initialize_watcher()
async def test_10_file_created(self, tmp_paths, mock_parser):
for i in tmp_paths:
cache = utils.PolicyCache(tmp_paths)
cache.initialize_watcher()

assert not cache.outdated
assert not cache.outdated

file = tmp_path / "test"
file.write_text("test")
(i / "file").write_text("test")

await asyncio.sleep(1)
await asyncio.sleep(1)

assert cache.outdated
assert cache.outdated

@pytest.mark.asyncio
async def test_11_file_changed(self, tmp_path, mock_parser):
file = tmp_path / "test"
file.write_text("test")
async def test_11_file_changed(self, tmp_paths, mock_parser):
for i in tmp_paths:
file = i / "test"
file.write_text("test")

cache = utils.PolicyCache(tmp_path)
cache.initialize_watcher()
cache = utils.PolicyCache(tmp_paths)
cache.initialize_watcher()

assert not cache.outdated
assert not cache.outdated

file.write_text("new_content")
file.write_text("new_content")

await asyncio.sleep(1)
await asyncio.sleep(1)

assert cache.outdated
assert cache.outdated

@pytest.mark.asyncio
async def test_12_file_deleted(self, tmp_path, mock_parser):
file = tmp_path / "test"
file.write_text("test")
async def test_12_file_deleted(self, tmp_paths, mock_parser):
for i in tmp_paths:
file = i / "test"
file.write_text("test")

cache = utils.PolicyCache(tmp_path)
cache.initialize_watcher()
cache = utils.PolicyCache(tmp_paths)
cache.initialize_watcher()

assert not cache.outdated
assert not cache.outdated

os.remove(file)
os.remove(file)

await asyncio.sleep(1)
await asyncio.sleep(1)

assert cache.outdated
assert cache.outdated

@pytest.mark.asyncio
async def test_13_no_change(self, tmp_path, mock_parser):
cache = utils.PolicyCache(tmp_path)
async def test_13_no_change(self, tmp_paths, mock_parser):
cache = utils.PolicyCache(tmp_paths)
cache.initialize_watcher()

assert not cache.outdated
Expand All @@ -98,23 +109,26 @@ async def test_13_no_change(self, tmp_path, mock_parser):
assert not cache.outdated

@pytest.mark.asyncio
async def test_20_policy_updates(self, tmp_path, mock_parser):
cache = utils.PolicyCache(tmp_path)
cache.initialize_watcher()
async def test_20_policy_updates(self, tmp_paths, mock_parser):
count = 0
call = unittest.mock.call(policy_path=tmp_paths)

mock_parser.assert_called_once_with(policy_path=tmp_path)
for i in tmp_paths:
count += 2
cache = utils.PolicyCache(tmp_paths)
cache.initialize_watcher()

assert not cache.outdated
assert mock_parser.mock_calls == [call] * (count - 1)

file = tmp_path / "test"
file.write_text("test")
assert not cache.outdated

await asyncio.sleep(1)
file = i / "test"
file.write_text("test")

assert cache.outdated
await asyncio.sleep(1)

cache.get_policy()
assert cache.outdated

call = unittest.mock.call(policy_path=tmp_path)
cache.get_policy()

assert mock_parser.mock_calls == [call, call]
assert mock_parser.mock_calls == [call] * count
1 change: 0 additions & 1 deletion qrexec/tools/qrexec_policy_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
decisions."""

import itertools
import os
import argparse
import asyncio

Expand Down
7 changes: 5 additions & 2 deletions qrexec/tools/qrexec_policy_daemon.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,17 @@

from ..utils import sanitize_domain_name, get_system_info
from .qrexec_policy_exec import handle_request
from .. import POLICYPATH, POLICYSOCKET, POLICY_EVAL_SOCKET, POLICY_GUI_SOCKET
from .. import POLICYPATH, POLICYSOCKET, POLICY_EVAL_SOCKET, POLICY_GUI_SOCKET, RUNTIME_POLICY_PATH
from ..policy.utils import PolicyCache

argparser = argparse.ArgumentParser(description="Evaluate qrexec policy daemon")

argparser.add_argument(
"--policy-path",
type=pathlib.Path,
default=POLICYPATH,
default=[RUNTIME_POLICY_PATH, POLICYPATH],
help="Use alternative policy path",
action='append',
)
argparser.add_argument(
"--socket-path",
Expand Down Expand Up @@ -291,6 +292,8 @@ async def handle_qrexec_connection(

async def start_serving(args=None):
args = argparser.parse_args(args)
if len(args.policy_path) > 2:
args.policy_path = args.policy_path[2:]

logging.basicConfig(format="%(message)s")
log = logging.getLogger("policy")
Expand Down
Loading

0 comments on commit 9d31ff7

Please sign in to comment.