diff --git a/pytm/__init__.py b/pytm/__init__.py index f1b7e43a..726955e0 100644 --- a/pytm/__init__.py +++ b/pytm/__init__.py @@ -1,4 +1,4 @@ -__all__ = ['Element', 'Server', 'ExternalEntity', 'Datastore', 'Actor', 'Process', 'SetOfProcesses', 'Dataflow', 'Boundary', 'TM', 'Lambda', 'Threat'] +__all__ = ['Element', 'Server', 'ExternalEntity', 'Datastore', 'Actor', 'Process', 'SetOfProcesses', 'Dataflow', 'Boundary', 'TM', 'Action', 'Lambda', 'Threat'] -from .pytm import Element, Server, ExternalEntity, Dataflow, Datastore, Actor, Process, SetOfProcesses, Boundary, TM, Lambda, Threat +from .pytm import Element, Server, ExternalEntity, Dataflow, Datastore, Actor, Process, SetOfProcesses, Boundary, TM, Action, Lambda, Threat diff --git a/pytm/pytm.py b/pytm/pytm.py index 0897a188..b27eca4a 100644 --- a/pytm/pytm.py +++ b/pytm/pytm.py @@ -5,10 +5,11 @@ import random import sys import uuid -import sys from collections import defaultdict from collections.abc import Iterable +from enum import Enum from hashlib import sha224 +from itertools import combinations from os.path import dirname from textwrap import wrap from weakref import WeakKeyDictionary @@ -112,6 +113,20 @@ def __set__(self, instance, value): super().__set__(instance, list(value)) +class varAction(var): + + def __set__(self, instance, value): + if not isinstance(value, Action): + raise ValueError("expecting an Action, got a {}".format(type(value))) + super().__set__(instance, value) + + +class Action(Enum): + NO_ACTION = 'NO_ACTION' + RESTRICT = 'RESTRICT' + IGNORE = 'IGNORE' + + def _setColor(element): if element.inScope is True: return "black" @@ -320,6 +335,7 @@ class TM(): _BagOfBoundaries = [] _threatsExcluded = [] _sf = None + _duplicate_ignored_attrs = "name", "note", "order", "response", "responseTo" name = varString("", required=True, doc="Model name") description = varString("", required=True, doc="Model description") threatsFile = varString(dirname(__file__) + "/threatlib/threats.json", @@ -329,6 +345,7 @@ class TM(): mergeResponses = varBool(False, doc="Merge response edges in DFDs") ignoreUnused = varBool(False, doc="Ignore elements not used in any Dataflow") findings = varFindings([], doc="threats found for elements of this model") + onDuplicates = varAction(Action.NO_ACTION, doc="How to handle duplicate Dataflow with same properties, except name and notes") def __init__(self, name, **kwargs): for key, value in kwargs.items(): @@ -375,6 +392,7 @@ def check(self): if self.description is None: raise ValueError("Every threat model should have at least a brief description of the system being modeled.") TM._BagOfFlows = _match_responses(_sort(TM._BagOfFlows, self.isOrdered)) + self._check_duplicates(TM._BagOfFlows) _apply_defaults(TM._BagOfFlows) if self.ignoreUnused: TM._BagOfElements, TM._BagOfBoundaries = _get_elements_and_boundaries(TM._BagOfFlows) @@ -384,6 +402,32 @@ def check(self): # cannot rely on user defined order if assets are re-used in multiple models TM._BagOfElements = _sort_elem(TM._BagOfElements) + def _check_duplicates(self, flows): + if self.onDuplicates == Action.NO_ACTION: + return + + index = defaultdict(list) + for e in flows: + key = (e.source, e.sink) + index[key].append(e) + + for flows in index.values(): + for left, right in combinations(flows, 2): + left_attrs = left._attr_values() + right_attrs = right._attr_values() + for a in self._duplicate_ignored_attrs: + del left_attrs[a], right_attrs[a] + if left_attrs != right_attrs: + continue + if self.onDuplicates == Action.IGNORE: + right._is_drawn = True + continue + + raise ValueError( + "Duplicate Dataflow found between {} and {}: " + "{} is same as {}".format(left.source, left.sink, left, right,) + ) + def dfd(self): print("digraph tm {\n\tgraph [\n\tfontname = Arial;\n\tfontsize = 14;\n\t]") print("\tnode [\n\tfontname = Arial;\n\tfontsize = 14;\n\trankdir = lr;\n\t]") @@ -560,6 +604,21 @@ def inside(self, *boundaries): return False + def _attr_values(self): + klass = self.__class__ + result = {} + for i in dir(klass): + if i.startswith("_") or callable(getattr(klass, i)): + continue + attr = getattr(klass, i, {}) + if isinstance(attr, var): + value = attr.data.get(self, attr.default) + else: + value = getattr(self, i) + result[i] = value + return result + + class Lambda(Element): """A lambda function running in a Function-as-a-Service (FaaS) environment""" diff --git a/tests/test_pytmfunc.py b/tests/test_pytmfunc.py index c628e5a0..be14ee8a 100644 --- a/tests/test_pytmfunc.py +++ b/tests/test_pytmfunc.py @@ -4,12 +4,13 @@ import json import os import random +import re import unittest from contextlib import contextmanager from os.path import dirname from io import StringIO -from pytm import (TM, Actor, Boundary, Dataflow, Datastore, ExternalEntity, +from pytm import (TM, Action, Actor, Boundary, Dataflow, Datastore, ExternalEntity, Lambda, Process, Server, Threat) @@ -113,6 +114,58 @@ def test_dfd(self): self.maxDiff = None self.assertEqual(output, expected) + def test_dfd_duplicates_ignore(self): + dir_path = os.path.dirname(os.path.realpath(__file__)) + with open(os.path.join(dir_path, 'dfd.dot')) as x: + expected = x.read().strip() + + random.seed(0) + + TM.reset() + tm = TM("my test tm", description="aaa", onDuplicates=Action.IGNORE) + internet = Boundary("Internet") + server_db = Boundary("Server/DB") + user = Actor("User", inBoundary=internet) + web = Server("Web Server") + db = Datastore("SQL Database", inBoundary=server_db) + + Dataflow(user, web, "User enters comments (*)") + Dataflow(user, web, "User views comments") + Dataflow(web, db, "Insert query with comments") + Dataflow(web, db, "Select query") + Dataflow(db, web, "Retrieve comments") + Dataflow(web, user, "Show comments (*)") + + tm.check() + with captured_output() as (out, err): + tm.dfd() + + output = out.getvalue().strip() + self.maxDiff = None + self.assertEqual(output, expected) + + def test_dfd_duplicates_raise(self): + random.seed(0) + + TM.reset() + tm = TM("my test tm", description="aaa", onDuplicates=Action.RESTRICT) + internet = Boundary("Internet") + server_db = Boundary("Server/DB") + user = Actor("User", inBoundary=internet) + web = Server("Web Server") + db = Datastore("SQL Database", inBoundary=server_db) + + Dataflow(user, web, "User enters comments (*)") + Dataflow(user, web, "User views comments") + Dataflow(web, db, "Insert query with comments") + Dataflow(web, db, "Select query") + Dataflow(db, web, "Retrieve comments") + Dataflow(web, user, "Show comments (*)") + + e = re.escape("Duplicate Dataflow found between Actor(User) and Server(Web Server): Dataflow(User enters comments (*)) is same as Dataflow(User views comments)") + with self.assertRaisesRegex(ValueError, e): + tm.check() + def test_resolve(self): random.seed(0) @@ -146,7 +199,6 @@ def test_resolve(self): self.assertListEqual([f.id for f in resp.findings], ["Dataflow"]) - class Testpytm(unittest.TestCase): # Test for all the threats in threats.py - test Threat.apply() function