-
Notifications
You must be signed in to change notification settings - Fork 102
Expand file tree
/
Copy pathtest_entity_risk.py
More file actions
357 lines (281 loc) · 14.3 KB
/
test_entity_risk.py
File metadata and controls
357 lines (281 loc) · 14.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
"""
Unit tests for the Risk-Based Alerting (RBA) entity-rollup engine.

These tests run without a real Redis: a tiny in-memory fake implements the
exact subset of redis.asyncio commands ``EntityRiskEngine`` calls. That keeps
the tests fast and deterministic: TTLs are accepted but never enforced, so no
test can flake on key expiry.

The tests gate three core properties the 2026 KPI bar depends on:

1. Score accumulates additively across alerts and severities.
2. Time-decay halves the score after one half-life so stale signal drops out.
3. Threshold promotion happens exactly once and is visible on the queue.

They also cover top-N queue ordering, ``stats()`` reporting, tenant isolation,
and the alert-to-incident ratio KPI itself.
"""
from __future__ import annotations
from datetime import datetime, timedelta
from typing import Any
from uuid import UUID, uuid4
import pytest
from app.models.alert import AlertSeverity, RawAlert
from app.services.entity_risk import EntityRiskEngine
# ---------------------------------------------------------------------------
# In-memory fake redis (only the commands the engine uses)
# ---------------------------------------------------------------------------
class _FakeRedis:
"""Minimal async fake covering hset/hgetall/expire/zadd/zrevrange/zcard/zremrangebyrank."""
def __init__(self) -> None:
self.hashes: dict[str, dict[bytes, bytes]] = {}
self.zsets: dict[str, dict[str, float]] = {}
async def hset(self, key: str, mapping: dict[str, Any]) -> int:
bucket = self.hashes.setdefault(key, {})
for k, v in mapping.items():
bucket[k.encode() if isinstance(k, str) else k] = v.encode() if isinstance(v, str) else v
return len(mapping)
async def hgetall(self, key: str) -> dict[bytes, bytes]:
return dict(self.hashes.get(key, {}))
async def expire(self, key: str, ttl: int) -> bool:
return key in self.hashes or key in self.zsets
async def zadd(self, key: str, mapping: dict[str, float]) -> int:
bucket = self.zsets.setdefault(key, {})
added = 0
for member, score in mapping.items():
if member not in bucket:
added += 1
bucket[member] = float(score)
return added
async def zrevrange(self, key: str, start: int, stop: int, withscores: bool = False) -> list[Any]:
bucket = self.zsets.get(key, {})
sorted_pairs = sorted(bucket.items(), key=lambda kv: kv[1], reverse=True)
# Redis ranges are inclusive on both ends.
if stop == -1:
sliced = sorted_pairs[start:]
else:
sliced = sorted_pairs[start : stop + 1]
if withscores:
return [(m.encode(), s) for m, s in sliced]
return [m.encode() for m, _ in sliced]
async def zcard(self, key: str) -> int:
return len(self.zsets.get(key, {}))
async def zremrangebyrank(self, key: str, start: int, stop: int) -> int:
bucket = self.zsets.get(key)
if not bucket:
return 0
sorted_pairs = sorted(bucket.items(), key=lambda kv: kv[1])
n = len(sorted_pairs)
# Redis allows negative indexes from the end.
s = start if start >= 0 else max(0, n + start)
e = stop if stop >= 0 else n + stop
if e < s:
return 0
to_remove = sorted_pairs[s : e + 1]
for member, _ in to_remove:
bucket.pop(member, None)
return len(to_remove)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
_TENANT = UUID("11111111-1111-1111-1111-111111111111")
def _alert(
    severity: AlertSeverity = AlertSeverity.MEDIUM,
    *,
    username: str | None = "alice",
    hostname: str | None = None,
    src_ip: str | None = None,
    risk_score: float = 0.0,
) -> RawAlert:
    """Build a ``RawAlert`` for ``_TENANT`` with a freshly generated id.

    Only the severity, the entity fields (user / host / source IP) and the
    upstream ``risk_score`` vary between tests; ``source`` and ``title`` are
    fixed placeholders.
    """
    fields: dict[str, Any] = {
        "id": uuid4(),
        "tenant_id": _TENANT,
        "source": "test",
        "title": "probe",
        "severity": severity,
        "username": username,
        "hostname": hostname,
        "src_ip": src_ip,
        "risk_score": risk_score,
    }
    return RawAlert(**fields)
@pytest.fixture
def fake_redis() -> _FakeRedis:
    """Provide a fresh, empty in-memory Redis fake for each test."""
    store = _FakeRedis()
    return store
@pytest.fixture
def engine(fake_redis: _FakeRedis) -> EntityRiskEngine:
    """Provide an ``EntityRiskEngine`` wired to the per-test Redis fake."""
    risk_engine = EntityRiskEngine(fake_redis)  # type: ignore[arg-type]
    return risk_engine
# ---------------------------------------------------------------------------
# Score accumulation
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_observe_creates_record_for_each_entity(engine: EntityRiskEngine) -> None:
    """One alert naming a user, a host and a source IP yields a record per entity."""
    await engine.observe(_alert(username="alice", hostname="dc01", src_ip="10.0.0.1"))
    expected_entities = (("user", "alice"), ("host", "dc01"), ("src_ip", "10.0.0.1"))
    for entity_type, entity_value in expected_entities:
        record = await engine.get(_TENANT, entity_type, entity_value)
        assert record is not None
@pytest.mark.asyncio
async def test_score_accumulates_additively(engine: EntityRiskEngine) -> None:
    """Two medium alerts on the same user must double-up the score.

    The two-alert score is compared against a single-alert baseline on a
    different user. This keeps the additivity check independent of the
    configured severity weight, and it catches over-counting as well as
    under-counting (a lower bound alone would not).
    """
    await engine.observe(_alert(severity=AlertSeverity.MEDIUM, username="solo"))
    await engine.observe(_alert(severity=AlertSeverity.MEDIUM))
    await engine.observe(_alert(severity=AlertSeverity.MEDIUM))
    baseline = await engine.get(_TENANT, "user", "solo")
    rec = await engine.get(_TENANT, "user", "alice")
    assert baseline is not None
    assert rec is not None
    # Two medium alerts at base weight 8.0 (config default) ≈ 16.0 (within decay window).
    assert rec.score >= 15.0
    # Additive: two alerts are worth twice one alert (decay between the
    # observations is negligible at test speed).
    assert rec.score == pytest.approx(2.0 * baseline.score, rel=0.05)
    assert rec.alert_count == 2
@pytest.mark.asyncio
async def test_severity_weights_are_ordered(engine: EntityRiskEngine) -> None:
    """A single alert's contribution strictly increases with its severity."""
    severity_by_user = {
        "u-crit": AlertSeverity.CRITICAL,
        "u-high": AlertSeverity.HIGH,
        "u-med": AlertSeverity.MEDIUM,
        "u-low": AlertSeverity.LOW,
    }
    for user, sev in severity_by_user.items():
        await engine.observe(_alert(severity=sev, username=user))
    scores = []
    for user in severity_by_user:
        record = await engine.get(_TENANT, "user", user)
        scores.append(record.score)  # type: ignore[union-attr]
    s_crit, s_high, s_med, s_low = scores
    assert s_crit > s_high > s_med > s_low
@pytest.mark.asyncio
async def test_risk_score_amplifies_points(engine: EntityRiskEngine) -> None:
    """A full-confidence upstream risk_score doubles the points an alert adds."""
    for user, confidence in (("u-base", 0.0), ("u-boosted", 1.0)):
        await engine.observe(
            _alert(severity=AlertSeverity.HIGH, risk_score=confidence, username=user)
        )
    base_rec = await engine.get(_TENANT, "user", "u-base")
    boosted_rec = await engine.get(_TENANT, "user", "u-boosted")
    # The confidence factor is expected to be 2x at risk_score=1.0.
    expected = 2.0 * base_rec.score  # type: ignore[union-attr]
    assert boosted_rec.score == pytest.approx(expected, rel=0.1)  # type: ignore[union-attr]
# ---------------------------------------------------------------------------
# Time decay
# ---------------------------------------------------------------------------
def test_decay_halves_after_one_halflife(engine: EntityRiskEngine) -> None:
    """A score last seen exactly one half-life ago is worth half as much now."""
    # Fixed naive timestamp instead of datetime.utcnow(). It is deterministic,
    # avoids the utcnow() deprecation warning (Python 3.12+), and stays naive
    # like utcnow() so the engine receives the same kind of datetime.
    now = datetime(2026, 1, 1, 12, 0, 0)
    one_half_life_ago = now - timedelta(seconds=engine._half_life)  # type: ignore[attr-defined]
    decayed = engine._decay(40.0, one_half_life_ago, now)  # type: ignore[attr-defined]
    assert decayed == pytest.approx(20.0, rel=0.05)
def test_decay_floor_zeros_dust(engine: EntityRiskEngine) -> None:
    """After 20 half-lives the residual score is floored to exactly zero."""
    # Fixed naive timestamp instead of datetime.utcnow(): deterministic and
    # free of the Python 3.12+ deprecation warning.
    now = datetime(2026, 1, 1, 12, 0, 0)
    long_ago = now - timedelta(seconds=engine._half_life * 20)  # type: ignore[attr-defined]
    assert engine._decay(40.0, long_ago, now) == 0.0  # type: ignore[attr-defined]
def test_decay_no_negative_or_future_time(engine: EntityRiskEngine) -> None:
    """A ``prior_seen`` in the future neither amplifies nor negates the score."""
    # Fixed naive timestamp instead of datetime.utcnow(): deterministic and
    # free of the Python 3.12+ deprecation warning.
    now = datetime(2026, 1, 1, 12, 0, 0)
    future = now + timedelta(seconds=120)
    decayed = engine._decay(40.0, future, now)  # type: ignore[attr-defined]
    # Future "prior_seen" must not amplify the score ...
    assert decayed <= 40.0
    # ... and, as the test name promises, must not drive it negative either.
    assert decayed >= 0.0
# ---------------------------------------------------------------------------
# Promotion
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_promotion_fires_when_threshold_crossed(engine: EntityRiskEngine) -> None:
    """Enough CRITICAL alerts on one user push it over the promotion threshold.

    With the config defaults (40 points per CRITICAL, threshold 80), three
    alerts give 3 × 40 = 120 points.
    """
    critical_burst = [_alert(severity=AlertSeverity.CRITICAL) for _ in range(3)]
    for alert in critical_burst:
        await engine.observe(alert)
    record = await engine.get(_TENANT, "user", "alice")
    assert record is not None
    assert record.promoted_at is not None
@pytest.mark.asyncio
async def test_promotion_is_idempotent(engine: EntityRiskEngine) -> None:
    """Crossing the threshold a second time must not re-promote.

    The first promotion is asserted explicitly. Without that, a regression
    that never promotes would leave ``promoted_at`` as ``None`` both times,
    and the final equality check would pass vacuously.
    """
    for _ in range(3):
        await engine.observe(_alert(severity=AlertSeverity.CRITICAL))
    rec_first = await engine.get(_TENANT, "user", "alice")
    assert rec_first is not None and rec_first.promoted_at is not None
    promoted_at_first = rec_first.promoted_at
    # Add another contributing alert to an already-promoted entity.
    await engine.observe(_alert(severity=AlertSeverity.CRITICAL))
    rec_second = await engine.get(_TENANT, "user", "alice")
    assert rec_second is not None
    assert rec_second.promoted_at == promoted_at_first
@pytest.mark.asyncio
async def test_no_promotion_below_threshold(engine: EntityRiskEngine) -> None:
    """A single LOW alert stays below the threshold and must not promote."""
    low_alert = _alert(severity=AlertSeverity.LOW)
    await engine.observe(low_alert)
    record = await engine.get(_TENANT, "user", "alice")
    assert record is not None
    assert record.promoted_at is None
# ---------------------------------------------------------------------------
# Top-N queue + stats
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_top_entities_orders_by_score(engine: EntityRiskEngine) -> None:
    """The queue lists entities by descending score: CRITICAL, MEDIUM, LOW."""
    await engine.observe(_alert(severity=AlertSeverity.LOW, username="quiet"))
    await engine.observe(_alert(severity=AlertSeverity.CRITICAL, username="loud"))
    await engine.observe(_alert(severity=AlertSeverity.MEDIUM, username="middle"))
    top = await engine.top_entities(_TENANT, limit=10)
    names = [r.entity_value for r in top if r.entity_type == "user"]
    # Pin the whole order. Checking only the first and last entries misses a
    # misplaced middle entry, and indexing an empty list raises IndexError
    # instead of a readable assertion failure.
    assert names == ["loud", "middle", "quiet"]
@pytest.mark.asyncio
async def test_promoted_only_filters_queue(engine: EntityRiskEngine) -> None:
    """``promoted_only=True`` drops entities that never crossed the threshold."""
    await engine.observe(_alert(severity=AlertSeverity.LOW, username="quiet"))
    loud_alerts = [_alert(severity=AlertSeverity.CRITICAL, username="loud") for _ in range(3)]
    for alert in loud_alerts:
        await engine.observe(alert)
    promoted = await engine.top_entities(_TENANT, limit=10, promoted_only=True)
    promoted_names = {record.entity_value for record in promoted}
    assert promoted_names == {"loud"}
@pytest.mark.asyncio
async def test_stats_reports_bands(engine: EntityRiskEngine) -> None:
    """stats() counts tracked and promoted entities and places each one in a score band."""
    for sev, user in ((AlertSeverity.LOW, "low-1"), (AlertSeverity.MEDIUM, "med-1")):
        await engine.observe(_alert(severity=sev, username=user))
    for _ in range(3):
        await engine.observe(_alert(severity=AlertSeverity.CRITICAL, username="crit-1"))
    stats = await engine.stats(_TENANT)
    tracked = stats["tracked_entities"]
    assert tracked >= 3
    assert stats["promoted_entities"] >= 1
    assert stats["threshold"] == engine.threshold
    band_total = sum(stats["score_bands"].values())
    assert band_total == tracked
# ---------------------------------------------------------------------------
# 2026 KPI bar — alert-to-incident ratio ≥ 50:1
#
# This is the headline number RBA exists to deliver. The scenario simulates a
# realistic noisy day: 500 alerts of mixed severity hitting a small set of
# entities (a campaign of related high / medium / low signal against 4 users).
# RBA must collapse those 500 alerts into ≤ 10 entity-incidents, with at most
# one promoted entity per targeted user. If this regresses we want CI to fail.
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_alert_to_incident_ratio_meets_2026_bar(engine: EntityRiskEngine) -> None:
    """500 alerts clustered on 4 entities → ≤ 10 entity-incidents (≥ 50:1).

    This is the headline number RBA exists to deliver: the published 2026
    KPI bar of ≥ 50 alerts per promoted entity-incident. The scenario models
    a real noisy day, with a focused campaign of 500 mixed-severity alerts
    hitting a small pool of users. RBA must collapse those into ≤ 10
    entity-incidents.
    """
    users = ["alice", "bob", "carol", "dave"]
    severities = [
        AlertSeverity.HIGH,
        AlertSeverity.MEDIUM,
        AlertSeverity.MEDIUM,
        AlertSeverity.LOW,
        AlertSeverity.LOW,
    ]
    total_alerts = 500
    for i in range(total_alerts):
        # Round-robin over the 4 users, giving 125 contributing alerts per
        # user. The severity cycle has length 5, which is coprime with 4, so
        # every user sees the full severity mix rather than one fixed severity.
        await engine.observe(
            _alert(
                severity=severities[i % len(severities)],
                username=users[i % len(users)],
            )
        )
    promoted = await engine.top_entities(_TENANT, limit=50, promoted_only=True)
    promoted_count = len({(r.entity_type, r.entity_value) for r in promoted})
    # Every promoted entity must be a real cluster, not a single noisy alert.
    # This is the published "50:1" bar: contributing alerts divided by
    # entity-incidents must be at least 50.
    assert promoted_count > 0, "RBA must promote at least one entity-incident"
    ratio = total_alerts / promoted_count
    assert ratio >= 50.0, (
        f"alert-to-incident ratio {ratio:.1f}:1 fails the 2026 KPI bar of ≥ 50:1; "
        f"got {promoted_count} entity-incidents from {total_alerts} alerts"
    )
# ---------------------------------------------------------------------------
# Tenant isolation
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_tenants_are_isolated(engine: EntityRiskEngine) -> None:
    """The same username under two tenants yields separate records and queues."""
    other_tenant = UUID("22222222-2222-2222-2222-222222222222")
    a = _alert(severity=AlertSeverity.HIGH, username="alice")
    a_other = RawAlert(**{**a.model_dump(), "tenant_id": other_tenant, "id": uuid4()})
    await engine.observe(a)
    await engine.observe(a_other)
    rec_a = await engine.get(_TENANT, "user", "alice")
    rec_b = await engine.get(other_tenant, "user", "alice")
    assert rec_a is not None and rec_b is not None
    assert rec_a.tenant_id != rec_b.tenant_id
    queue_a = await engine.top_entities(_TENANT)
    queue_b = await engine.top_entities(other_tenant)
    # Each tenant's queue must actually contain its own "alice". all() over an
    # empty queue is vacuously True, so without these checks a top_entities()
    # that returned [] would pass the isolation assertions below.
    assert any(r.entity_value == "alice" for r in queue_a)
    assert any(r.entity_value == "alice" for r in queue_b)
    assert all(r.tenant_id == str(_TENANT) for r in queue_a)
    assert all(r.tenant_id == str(other_tenant) for r in queue_b)