6
6
"""
7
7
import argparse
8
8
import sys
9
+ import traceback
9
10
from datetime import datetime
10
11
from typing import TYPE_CHECKING
11
12
from typing import Iterator
23
24
from pyramid .paster import get_appsettings
24
25
from sqlalchemy import select
25
26
27
+ from riskmatrix .models import Asset
26
28
from riskmatrix .models import Organization
27
29
from riskmatrix .models import Risk
28
30
from riskmatrix .models import RiskAssessment
29
31
from riskmatrix .models import RiskCatalog
30
- from riskmatrix .models .asset import Asset
31
32
from riskmatrix .orm import Base
32
33
from riskmatrix .orm import get_engine
33
34
from riskmatrix .scripts .util import select_existing_organization
@@ -72,9 +73,7 @@ def get_or_create_asset(
72
73
Asset .name == asset_name
73
74
)
74
75
75
- asset = session .scalars (q ).one_or_none ()
76
-
77
- if asset :
76
+ if asset := session .scalars (q ).one_or_none ():
78
77
return asset
79
78
80
79
asset = Asset (asset_name , organization )
@@ -83,6 +82,44 @@ def get_or_create_asset(
83
82
return asset
84
83
85
84
85
+ def get_or_create_risk (
86
+ risk_name : str ,
87
+ catalog : RiskCatalog ,
88
+ session : 'Session'
89
+ ) -> Risk :
90
+
91
+ q = select (Risk ).where (
92
+ Risk .organization_id == catalog .organization .id ,
93
+ Risk .name == risk_name
94
+ )
95
+
96
+ if risk := session .scalars (q ).one_or_none ():
97
+ return risk
98
+
99
+ risk = Risk (risk_name , catalog )
100
+ session .add (risk )
101
+ return risk
102
+
103
+
104
+ def get_or_create_risk_assessment (
105
+ risk : Risk ,
106
+ asset : Asset ,
107
+ session : 'Session'
108
+ ) -> RiskAssessment :
109
+
110
+ q = select (RiskAssessment ).where (
111
+ RiskAssessment .risk_id == risk .id ,
112
+ RiskAssessment .asset_id == asset .id ,
113
+ )
114
+
115
+ if assessment := session .scalars (q ).one_or_none ():
116
+ return assessment
117
+
118
+ assessment = RiskAssessment (risk = risk , asset = asset )
119
+ session .add (assessment )
120
+ return assessment
121
+
122
+
86
123
def populate_catalog (
87
124
catalog : RiskCatalog ,
88
125
risks : 'Iterator[RiskDetails]' ,
@@ -94,17 +131,15 @@ def populate_catalog(
94
131
risk_details ['asset_name' ], catalog .organization , session
95
132
)
96
133
97
- risk = Risk (
98
- name = risk_details ['name' ],
99
- catalog = catalog ,
100
- description = risk_details ['desc' ],
101
- category = risk_details ['category' ]
134
+ risk = get_or_create_risk (
135
+ risk_details ['name' ], catalog , session
102
136
)
137
+ risk .category = risk_details ['category' ]
138
+ risk .description = risk_details ['desc' ]
103
139
104
- assessment = RiskAssessment (risk = risk , asset = asset )
140
+ assessment = get_or_create_risk_assessment (risk , asset , session )
105
141
assessment .likelihood = risk_details ['likelihood' ]
106
142
assessment .impact = risk_details ['impact' ]
107
- session .add (assessment )
108
143
109
144
110
145
def risks_from_excel (
@@ -126,7 +161,7 @@ def risks_from_excel(
126
161
# Anyway, actual riks rows will start after row #2.
127
162
start_after_row = 2
128
163
129
- iterator = sheet .iter_rows ( # type: ignore[union-attr,misc]
164
+ iterator = sheet .iter_rows (
130
165
values_only = True ,
131
166
min_row = start_after_row
132
167
)
@@ -194,14 +229,10 @@ def main(argv: list[str] = sys.argv) -> None:
194
229
dbsession
195
230
)
196
231
except sqlalchemy .exc .IntegrityError :
197
- # TODO: Risks and assets (and therefore also assessments) are
198
- # unique per organization, not catalog. Adding a risk from the
199
- # excel that is already present in this organization will fail.
200
- print (
201
- 'Organization already contains some risks from the Excel. '
202
- 'Abort!'
203
- )
232
+ print ('Failed to import excel, aborting.' )
233
+ print (traceback .format_exc ())
204
234
dbsession .rollback ()
235
+ sys .exit (1 )
205
236
else :
206
237
print (
207
238
f'Successfully populated risk catalog "{ catalog .name } " '
0 commit comments