1
1
package audit
2
2
3
3
import (
4
+ "fmt"
4
5
"go/types"
6
+ "strings"
5
7
"testing"
6
8
7
9
"github.com/stretchr/testify/assert"
8
10
"github.com/stretchr/testify/require"
9
11
"golang.org/x/tools/go/packages"
12
+
13
+ "github.com/coder/coder/coderd/audit"
14
+ "github.com/coder/coder/coderd/database"
15
+ "github.com/coder/coder/coderd/util/slice"
10
16
)
11
17
12
18
// TestAuditableResources ensures that all auditable resources are included in
13
19
// the Auditable interface and vice versa.
20
+ //
21
+ //nolint:tparallel
14
22
func TestAuditableResources (t * testing.T ) {
15
23
t .Parallel ()
16
24
17
25
pkgs , err := packages .Load (& packages.Config {
18
- Mode : packages .NeedTypes ,
26
+ Mode : packages .NeedTypes | packages . NeedDeps ,
19
27
}, "../../coderd/audit" )
20
28
require .NoError (t , err )
21
29
@@ -37,13 +45,15 @@ func TestAuditableResources(t *testing.T) {
37
45
require .True (t , ok , "expected Auditable to be a union" )
38
46
39
47
found := make (map [string ]bool )
48
+ expectedList := make ([]string , 0 )
40
49
// Now we check we have all the resources in the AuditableResources
41
50
for i := 0 ; i < unionType .Len (); i ++ {
42
51
// All types come across like 'github.com/coder/coder/coderd/database.<type>'
43
52
typeName := unionType .Term (i ).Type ().String ()
44
53
_ , ok := AuditableResources [typeName ]
45
54
assert .True (t , ok , "missing resource %q from AuditableResources" , typeName )
46
55
found [typeName ] = true
56
+ expectedList = append (expectedList , typeName )
47
57
}
48
58
49
59
// Also check that all resources in the table are in the union. We could
@@ -52,4 +62,90 @@ func TestAuditableResources(t *testing.T) {
52
62
_ , ok := found [name ]
53
63
assert .True (t , ok , "extra resource %q found in AuditableResources" , name )
54
64
}
65
+
66
+ // Various functions that have switch statements to include all Auditable
67
+ // resources. Make sure we have all types supported.
68
+ // nolint:paralleltest
69
+ t .Run ("ResourceID" , func (t * testing.T ) {
70
+ // The function being tested, provided here to make it easier to find
71
+ _ = audit .ResourceID [database .APIKey ]
72
+ testAuditFunctionWithSwitch (t , auditPkg , "ResourceID" , expectedList )
73
+ })
74
+
75
+ // nolint:paralleltest
76
+ t .Run ("ResourceType" , func (t * testing.T ) {
77
+ // The function being tested, provided here to make it easier to find
78
+ _ = audit .ResourceType [database .APIKey ]
79
+ testAuditFunctionWithSwitch (t , auditPkg , "ResourceType" , expectedList )
80
+ })
81
+
82
+ // nolint:paralleltest
83
+ t .Run ("ResourceTarget" , func (t * testing.T ) {
84
+ // The function being tested, provided here to make it easier to find
85
+ _ = audit .ResourceTarget [database .APIKey ]
86
+ testAuditFunctionWithSwitch (t , auditPkg , "ResourceTarget" , expectedList )
87
+ })
88
+ }
89
+
90
+ // testAuditFunctionWithSwitch is a helper function to test that a function has
91
+ // a typed switch statement that includes all the types in expectedTypes.
92
+ func testAuditFunctionWithSwitch (t * testing.T , pkg * packages.Package , funcName string , expectedTypes []string ) {
93
+ t .Helper ()
94
+
95
+ f , ok := pkg .Types .Scope ().Lookup (funcName ).(* types.Func )
96
+ require .True (t , ok , fmt .Sprintf ("expected %s to be a function" , funcName ))
97
+ switchCases := findSwitchTypes (f )
98
+ for _ , expected := range expectedTypes {
99
+ if ! slice .Contains (switchCases , expected ) {
100
+ t .Errorf ("%s switch statement is missing type %q. Include it in the switch case block" , funcName , expected )
101
+ }
102
+ }
103
+ for _ , sc := range switchCases {
104
+ if ! slice .Contains (expectedTypes , sc ) {
105
+ t .Errorf ("%s switch statement has unexpected type %q. Remove it from the switch case block" , funcName , sc )
106
+ }
107
+ }
108
+ }
109
+
110
+ // findSwitchTypes is a helper function to find all types a switch statement in
111
+ // the function body of f has.
112
+ func findSwitchTypes (f * types.Func ) []string {
113
+ caseTypes := make ([]string , 0 )
114
+ switches := returnSwitchBlocks (f .Scope ())
115
+ for _ , sc := range switches {
116
+ scTypes := findCaseTypes (sc )
117
+ caseTypes = append (caseTypes , scTypes ... )
118
+ }
119
+ return caseTypes
120
+ }
121
+
122
+ func returnSwitchBlocks (sc * types.Scope ) []* types.Scope {
123
+ switches := make ([]* types.Scope , 0 )
124
+ for i := 0 ; i < sc .NumChildren (); i ++ {
125
+ child := sc .Child (i )
126
+ cStr := child .String ()
127
+ // This is the easiest way to tell if it is a switch statement.
128
+ if strings .Contains (cStr , "type switch scope" ) {
129
+ switches = append (switches , child )
130
+ }
131
+ }
132
+ return switches
133
+ }
134
+
135
+ // findCaseTypes returns all case types in a typed switch statement. Excluding
136
+ // the "Default:" case.
137
+ func findCaseTypes (sc * types.Scope ) []string {
138
+ caseTypes := make ([]string , 0 )
139
+ for i := 0 ; i < sc .NumChildren (); i ++ {
140
+ child := sc .Child (i )
141
+ for _ , name := range child .Names () {
142
+ obj := child .Lookup (name ).Type ()
143
+ typeName := obj .String ()
144
+ // Ignore the "Default:" case
145
+ if typeName != "any" {
146
+ caseTypes = append (caseTypes , typeName )
147
+ }
148
+ }
149
+ }
150
+ return caseTypes
55
151
}
0 commit comments