Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
thomaspatzke committed Mar 8, 2025
1 parent b81d7dc commit 1c34629
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 86 deletions.
87 changes: 54 additions & 33 deletions sigma/collection.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from dataclasses import dataclass, field
from dataclasses import InitVar, dataclass, field
from functools import reduce
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Optional, Union, IO, cast
Expand All @@ -24,26 +24,54 @@
class SigmaCollection:
"""Collection of Sigma rules"""

rules: List[Union[SigmaRule, SigmaCorrelationRule, SigmaFilter]]
init_rules: InitVar[List[Union[SigmaRule, SigmaCorrelationRule, SigmaFilter]]]
errors: List[SigmaError] = field(default_factory=list)
ids_to_rules: Dict[UUID, SigmaRuleBase] = field(
collect_filters: InitVar[bool] = False
rules: List[Union[SigmaRule, SigmaCorrelationRule]] = field(default_factory=list)
filters: List[SigmaFilter] = field(default_factory=list)
ids_to_rules: Dict[UUID, Union[SigmaRule, SigmaCorrelationRule]] = field(
init=False, repr=False, hash=False, compare=False
)
names_to_rules: Dict[str, SigmaRuleBase] = field(
names_to_rules: Dict[str, Union[SigmaRule, SigmaCorrelationRule]] = field(
init=False, repr=False, hash=False, compare=False
)

def __post_init__(self) -> None:
def __post_init__(
self,
init_rules: List[Union[SigmaRule, SigmaCorrelationRule, SigmaFilter]],
collect_filters: bool,
) -> None:
"""
Map rule identifiers to rules and resolve rule references in correlation rules.
"""
self.ids_to_rules = {}
self.names_to_rules = {}
for rule in self.rules:
if rule.id is not None:
self.ids_to_rules[rule.id] = rule
if rule.name is not None:
self.names_to_rules[rule.name] = rule
for rule in init_rules:
if isinstance(rule, (SigmaRule, SigmaCorrelationRule)):
self.rules.append(rule)
if rule.id is not None:
self.ids_to_rules[rule.id] = rule
if rule.name is not None:
self.names_to_rules[rule.name] = rule
elif isinstance(rule, SigmaFilter):
self.filters.append(rule)
else:
raise TypeError(f"Object of type { type(rule) } not supported in SigmaCollection")
if self.filters and not collect_filters:
self.apply_filters(self.filters)

def apply_filters(self, filters: List[SigmaFilter]) -> None:
"""
Apply the given filters to every rule of the collection and replace self.rules with the
resulting list.

Each rule is folded through all filters in the order given. Only SigmaRule instances are
passed to a filter's apply_on_rule(); any other rule object in self.rules (e.g. a
correlation rule) is carried through unchanged.

NOTE(review): ids_to_rules and names_to_rules are not rebuilt here. If apply_on_rule()
returns a new object instead of the one passed in, those lookups would keep referring to
the unfiltered rules -- confirm against SigmaFilter.apply_on_rule.

:param filters: Filters to apply; an empty list leaves every rule as it is.
"""
self.rules = [
# reduce() threads one rule through all filters: the result of a filter is the input of
# the next one, starting with the rule itself.
reduce(
lambda r, f: f.apply_on_rule(r) if isinstance(r, SigmaRule) else r,
filters,
rule,
)
for rule in self.rules
]

def resolve_rule_references(self) -> None:
"""
Expand All @@ -57,24 +85,6 @@ def resolve_rule_references(self) -> None:
if isinstance(rule, SigmaCorrelationRule):
rule.resolve_rule_references(self)

# Extract all filters from the rules
filters: List[SigmaFilter] = [rule for rule in self.rules if isinstance(rule, SigmaFilter)]
self.rules = [rule for rule in self.rules if not isinstance(rule, SigmaFilter)]

# Apply filters on each rule and replace the rule with the filtered rule
self.rules = (
[
reduce(
lambda r, f: f.apply_on_rule(r) if isinstance(r, SigmaRule) else r,
filters,
rule,
)
for rule in self.rules
]
if filters
else self.rules
)

# Sort rules by reference order
self.rules = list(sorted(self.rules))

Expand All @@ -84,12 +94,15 @@ def from_dicts(
rules: List[NestedDict],
collect_errors: bool = False,
source: Optional[SigmaRuleLocation] = None,
collect_filters: bool = False,
) -> "SigmaCollection":
"""
Generate a rule collection from list of dicts containing parsed YAML content.
If the collect_errors parameter is set, exceptions are not raised while parsing but collected
in the errors property individually for each Sigma rule and the whole SigmaCollection.
If collect_filters is set, filters are only collected in the collection but not yet applied to the rules.
"""
errors: List[SigmaError] = []
parsed_rules: List[Union[SigmaRule, SigmaCorrelationRule, SigmaFilter]] = list()
Expand Down Expand Up @@ -149,22 +162,27 @@ def from_dicts(
else:
raise exception

return cls(parsed_rules, errors)
return cls(parsed_rules, errors, collect_filters)

@classmethod
def from_yaml(
cls,
yaml_str: Union[bytes, str, IO[Any]],
collect_errors: bool = False,
source: Optional[SigmaRuleLocation] = None,
collect_filters: bool = False,
) -> "SigmaCollection":
"""
Generate a rule collection from a string containing one or multiple YAML documents.
If the collect_errors parameter is set, exceptions are not raised while parsing but collected
in the errors property individually for each Sigma rule and the whole SigmaCollection.
If collect_filters is set, filters are only collected in the collection but not yet applied to the rules.
"""
return cls.from_dicts(list(yaml.safe_load_all(yaml_str)), collect_errors, source)
return cls.from_dicts(
list(yaml.safe_load_all(yaml_str)), collect_errors, source, collect_filters
)

@classmethod
def resolve_paths(
Expand Down Expand Up @@ -231,7 +249,8 @@ def load_ruleset(
sigma_collection = SigmaCollection.from_yaml(
result_path.open(encoding="utf-8"),
collect_errors,
SigmaRuleLocation(result_path),
collect_filters=True,
source=SigmaRuleLocation(result_path),
)
if (
on_load is not None
Expand All @@ -249,7 +268,9 @@ def load_ruleset(
def merge(cls, collections: Iterable["SigmaCollection"]) -> "SigmaCollection":
"""Merge multiple SigmaCollection objects into one and return it."""
return cls(
rules=[rule for collection in collections for rule in collection.rules],
init_rules=[
rule for collection in collections for rule in collection.rules + collection.filters
],
errors=[error for collection in collections for error in collection.errors],
)

Expand All @@ -267,7 +288,7 @@ def __iter__(self) -> Iterable[SigmaRuleBase]:
def __len__(self) -> int:
# Length is taken from self.rules only; entries of self.filters are not included.
return len(self.rules)

def __getitem__(self, i: Union[int, str, UUID]) -> SigmaRuleBase:
def __getitem__(self, i: Union[int, str, UUID]) -> Union[SigmaRule, SigmaCorrelationRule]:
try:
if isinstance(i, int): # Index by position
return self.rules[i]
Expand Down
21 changes: 10 additions & 11 deletions sigma/conversion/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,18 +182,17 @@ def convert(
processing.
"""
rule_collection.resolve_rule_references()
queries = []
for rule in rule_collection.rules:
if isinstance(rule, SigmaRule):
for query in self.convert_rule(rule, output_format or self.default_format):
queries.append(query)
elif isinstance(rule, SigmaCorrelationRule):
for query in self.convert_correlation_rule(
queries = [
query
for rule in rule_collection.rules
for query in (
self.convert_rule(rule, output_format or self.default_format)
if isinstance(rule, SigmaRule)
else self.convert_correlation_rule(
rule, output_format or self.default_format, correlation_method
):
queries.append(query)
else:
raise TypeError(f"Unexpected rule type: {type(rule)}")
)
)
]
return self.finalize(queries, output_format or self.default_format)

def convert_rule(self, rule: SigmaRule, output_format: Optional[str] = None) -> List[Any]:
Expand Down
4 changes: 2 additions & 2 deletions sigma/conversion/deferred.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class DeferredQueryExpression(ParentChainMixin, ABC):
conversion_state: "sigma.backends.state.ConversionState"
negated: bool = field(init=False, default=False)

def __post_init__(self):
def __post_init__(self) -> None:
"""Deferred expression automatically adds itself to conversion state."""
self.conversion_state.add_deferred_expression(self)

Expand Down Expand Up @@ -69,7 +69,7 @@ class DeferredTextQueryExpression(DeferredQueryExpression):
operators: ClassVar[Dict[bool, str]]
default_field: ClassVar[Optional[str]]

def __post_init__(self):
def __post_init__(self) -> None:
super().__post_init__()
if self.field is None and self.default_field is not None:
self.field = self.default_field
Expand Down
40 changes: 22 additions & 18 deletions sigma/correlations.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import Dict, List, Literal, Optional
from typing import Any, Dict, Iterator, List, Literal, Optional, Set, Union, Iterable

import sigma.exceptions as sigma_exceptions
from sigma.exceptions import SigmaRuleLocation, SigmaTimespanError
from sigma.processing.tracking import ProcessingItemTrackingMixin
Expand Down Expand Up @@ -32,9 +33,9 @@ class SigmaRuleReference:
"""

reference: str
rule: SigmaRule = field(init=False, repr=False, compare=False)
rule: Union[SigmaRule, "SigmaCorrelationRule"] = field(init=False, repr=False, compare=False)

def resolve(self, rule_collection: "sigma.collection.SigmaCollection"):
def resolve(self, rule_collection: "sigma.collection.SigmaCollection") -> None:
"""
Resolves the reference to the actual Sigma rule.
Expand All @@ -52,7 +53,7 @@ class SigmaCorrelationConditionOperator(Enum):
EQ = auto()

@classmethod
def operators(cls):
def operators(cls) -> Set[str]:
return {op.name.lower() for op in cls}


Expand All @@ -65,7 +66,7 @@ class SigmaCorrelationCondition:

@classmethod
def from_dict(
cls, d: dict, source: Optional[SigmaRuleLocation] = None
cls, d: Dict[str, Any], source: Optional[SigmaRuleLocation] = None
) -> "SigmaCorrelationCondition":
d_keys = frozenset(d.keys())
ops = frozenset(SigmaCorrelationConditionOperator.operators())
Expand All @@ -81,7 +82,6 @@ def from_dict(
)

# Condition operator and count
cond_op = None
for (
op
) in (
Expand All @@ -105,7 +105,7 @@ def from_dict(

return cls(op=cond_op, count=cond_count, fieldref=cond_field, source=source)

def to_dict(self) -> dict:
def to_dict(self) -> Dict[str, Any]:
if not self.fieldref:
return {self.op.name.lower(): self.count}
return {self.op.name.lower(): self.count, "field": self.fieldref}
Expand All @@ -118,7 +118,7 @@ class SigmaCorrelationTimespan:
count: int = field(init=False)
unit: str = field(init=False)

def __post_init__(self):
def __post_init__(self) -> None:
"""
Parses a string representing a time span and stores the equivalent number of seconds.
Expand Down Expand Up @@ -156,7 +156,7 @@ class SigmaCorrelationFieldAlias:
alias: str
mapping: Dict[SigmaRuleReference, str]

def resolve_rule_references(self, rule_collection: "sigma.collection.SigmaCollection"):
def resolve_rule_references(self, rule_collection: "sigma.collection.SigmaCollection") -> None:
"""
Resolves all rule references in the mapping property to actual Sigma rules.
Expand All @@ -171,14 +171,14 @@ def resolve_rule_references(self, rule_collection: "sigma.collection.SigmaCollec
class SigmaCorrelationFieldAliases:
aliases: Dict[str, SigmaCorrelationFieldAlias] = field(default_factory=dict)

def __iter__(self):
def __iter__(self) -> Iterator[SigmaCorrelationFieldAlias]:
return iter(self.aliases.values())

def __len__(self):
def __len__(self) -> int:
return len(self.aliases)

@classmethod
def from_dict(cls, d: dict):
def from_dict(cls, d: Dict[str, Any]) -> "SigmaCorrelationFieldAliases":
aliases = {}
for alias, mapping in d.items():
if not isinstance(mapping, dict):
Expand All @@ -196,15 +196,15 @@ def from_dict(cls, d: dict):

return cls(aliases=aliases)

def to_dict(self) -> dict:
def to_dict(self) -> Dict[str, Any]:
return {
alias: {
rule_ref.reference: field_name for rule_ref, field_name in alias_def.mapping.items()
}
for alias, alias_def in self.aliases.items()
}

def resolve_rule_references(self, rule_collection: "sigma.collection.SigmaCollection"):
def resolve_rule_references(self, rule_collection: "sigma.collection.SigmaCollection") -> None:
"""
Resolves all rule references in the aliases property to actual Sigma rules.
Expand All @@ -217,16 +217,20 @@ def resolve_rule_references(self, rule_collection: "sigma.collection.SigmaCollec

@dataclass
class SigmaCorrelationRule(SigmaRuleBase, ProcessingItemTrackingMixin):
type: SigmaCorrelationType = None
type: SigmaCorrelationType = SigmaCorrelationType.EVENT_COUNT
rules: List[SigmaRuleReference] = field(default_factory=list)
generate: bool = field(default=False)
timespan: SigmaCorrelationTimespan = field(default_factory=SigmaCorrelationTimespan)
timespan: SigmaCorrelationTimespan = field(
default_factory=lambda: SigmaCorrelationTimespan("1m")
)
group_by: Optional[List[str]] = None
aliases: SigmaCorrelationFieldAliases = field(default_factory=SigmaCorrelationFieldAliases)
condition: SigmaCorrelationCondition = None
condition: SigmaCorrelationCondition = field(
default_factory=lambda: SigmaCorrelationCondition(SigmaCorrelationConditionOperator.GTE, 1)
)
source: Optional[SigmaRuleLocation] = field(default=None, compare=False)

def __post_init__(self):
def __post_init__(self) -> None:
super().__post_init__()
if (
self.type not in {SigmaCorrelationType.TEMPORAL, SigmaCorrelationType.TEMPORAL_ORDERED}
Expand Down
Loading

0 comments on commit 1c34629

Please sign in to comment.