@@ -17,109 +17,270 @@ package commands
17
17
import (
18
18
"archive/zip"
19
19
"encoding/json"
20
+ "fmt"
21
+ "github.com/lunasec-io/lunasec/tools/log4shell/scan"
20
22
"github.com/lunasec-io/lunasec/tools/log4shell/types"
21
23
"github.com/lunasec-io/lunasec/tools/log4shell/util"
22
24
"github.com/rs/zerolog/log"
23
25
"github.com/urfave/cli/v2"
24
- "io/fs"
25
26
"io/ioutil"
26
27
"os"
27
28
)
28
29
29
- func JavaArchivePatchCommand (c * cli.Context , globalBoolFlags map [string ]bool ) error {
30
- enableGlobalFlags (c , globalBoolFlags )
31
-
32
- findingsFile := c .String ("findings" )
30
+ func scanForFindings (
31
+ log4jLibraryHashes []byte ,
32
+ searchDirs []string ,
33
+ excludeDirs []string ,
34
+ noFollowSymlinks bool ,
35
+ ) (findings []types.Finding , err error ) {
36
+ var (
37
+ hashLookup types.VulnerableHashLookup
38
+ )
33
39
34
- findingsContent , err := ioutil . ReadFile ( findingsFile )
40
+ hashLookup , err = loadHashLookup ( log4jLibraryHashes , "" , false )
35
41
if err != nil {
36
- log .Error ().
37
- Err (err ).
38
- Str ("findings" , findingsFile ).
39
- Msg ("Unable to open and read findings file" )
40
- return err
42
+ return
41
43
}
42
44
43
- var findings types.FindingsOutput
44
- err = json .Unmarshal (findingsContent , & findings )
45
- if err != nil {
46
- log .Error ().
47
- Err (err ).
48
- Str ("findings" , findingsFile ).
49
- Msg ("Unable to unmarshal findings file" )
50
- return err
51
- }
45
+ processArchiveFile := scan .IdentifyPotentiallyVulnerableFiles (false , hashLookup )
46
+
47
+ scanner := scan .NewLog4jDirectoryScanner (
48
+ excludeDirs , false , noFollowSymlinks , processArchiveFile )
49
+
50
+ findings = scanner .Scan (searchDirs )
51
+ return
52
+ }
52
53
53
- for _ , finding := range findings .VulnerableLibraries {
54
- var file * os.File
54
+ func loadOrScanForFindings (
55
+ c * cli.Context ,
56
+ log4jLibraryHashes []byte ,
57
+ ) (findings []types.Finding , err error ) {
58
+ findingsFile := c .String ("findings" )
59
+ if findingsFile != "" {
60
+ var (
61
+ findingsContent []byte
62
+ findingsOutput types.FindingsOutput
63
+ )
55
64
56
- file , err = os . Open ( finding . Path )
65
+ findingsContent , err = ioutil . ReadFile ( findingsFile )
57
66
if err != nil {
58
- log .Warn ().
59
- Str ("path" , finding .Path ).
67
+ log .Error ().
60
68
Err (err ).
61
- Msg ("Unable to open findings archive" )
62
- return err
69
+ Str ("findings" , findingsFile ).
70
+ Msg ("Unable to open and read findings file" )
71
+ return
63
72
}
64
- defer file .Close ()
65
73
66
- info , _ := os .Stat (finding .Path )
74
+ err = json .Unmarshal (findingsContent , & findingsOutput )
75
+ if err != nil {
76
+ log .Error ().
77
+ Err (err ).
78
+ Str ("findings" , findingsFile ).
79
+ Msg ("Unable to unmarshal findings file" )
80
+ return
81
+ }
82
+ findings = findingsOutput .VulnerableLibraries
83
+ return
84
+ }
85
+
86
+ searchDirs := c .Args ().Slice ()
87
+
88
+ excludeDirs := c .StringSlice ("exclude" )
89
+ noFollowSymlinks := c .Bool ("no-follow-symlinks" )
90
+
91
+ log .Info ().
92
+ Strs ("searchDirs" , searchDirs ).
93
+ Strs ("excludeDirs" , excludeDirs ).
94
+ Msg ("Scanning directories for vulnerable Log4j libraries." )
95
+
96
+ return scanForFindings (log4jLibraryHashes , searchDirs , excludeDirs , noFollowSymlinks )
97
+ }
67
98
68
- var zipReader * zip.Reader
99
+ func askIfShouldSkipLibrary (msg string ) (shouldSkip , forcePatch bool ) {
100
+ var (
101
+ patchPromptResp string
102
+ )
69
103
70
- zipReader , err = zip .NewReader (file , info .Size ())
104
+ for {
105
+ fmt .Printf ("Are you sure you want to patch: %s? (y)es/(n)o/(a)ll: " , msg )
106
+ _ , err := fmt .Scan (& patchPromptResp )
71
107
if err != nil {
72
- log .Warn ().
73
- Str ("path" , finding .Path ).
108
+ log .Error ().
74
109
Err (err ).
75
- Msg ("Unable to open archive for patching " )
76
- return err
110
+ Msg ("Unable to process response. " )
111
+ return true , false
77
112
}
113
+ fmt .Println ()
78
114
79
- var zipFile fs.File
115
+ switch patchPromptResp {
116
+ case "y" :
117
+ shouldSkip = false
118
+ case "n" :
119
+ shouldSkip = true
120
+ case "a" :
121
+ forcePatch = true
122
+ default :
123
+ fmt .Printf ("Option %s is not valid, please enter 'y', 'n', or 'a'.\n " , patchPromptResp )
124
+ continue
125
+ }
126
+ break
127
+ }
128
+ return
129
+ }
80
130
81
- if finding .JndiLookupFileName == "" {
82
- log .Warn ().
131
+ func filterOutJndiLookupFromZip (
132
+ finding types.Finding ,
133
+ zipReader * zip.Reader ,
134
+ writer * zip.Writer ,
135
+ ) error {
136
+ for _ , member := range zipReader .File {
137
+ if member .Name == finding .JndiLookupFileName {
138
+ log .Debug ().
83
139
Str ("path" , finding .Path ).
84
- Err ( err ).
85
- Msg ("Finding does not have JndiLookup.class file to patch" )
140
+ Str ( "zipFilePath" , finding . JndiLookupFileName ).
141
+ Msg ("Found file to remove in order to patch log4j library. " )
86
142
continue
87
143
}
88
144
89
- zipFile , err = zipReader .Open (finding .JndiLookupFileName )
90
- if err != nil {
91
- log .Warn ().
92
- Str ("path" , finding .Path ).
93
- Str ("jndiLookupFileName" , finding .JndiLookupFileName ).
145
+ if err := writer .Copy (member ); err != nil {
146
+ log .Error ().
94
147
Err (err ).
95
- Msg ("Unable to open file from zip" )
148
+ Msg ("Error while copying zip file. " )
96
149
return err
97
150
}
151
+ }
152
+ return nil
153
+ }
98
154
99
- var zipFileHash string
155
+ func patchJavaArchive (finding types.Finding ) (err error ) {
156
+ var (
157
+ libraryFile * os.File
158
+ zipReader * zip.Reader
159
+ )
100
160
101
- zipFileHash , err = util .HexEncodedSha256FromReader (zipFile )
102
- if err != nil {
161
+ libraryFile , err = os .Open (finding .Path )
162
+ if err != nil {
163
+ log .Error ().
164
+ Str ("path" , finding .Path ).
165
+ Err (err ).
166
+ Msg ("Unable to open findings archive" )
167
+ return
168
+ }
169
+ defer libraryFile .Close ()
170
+
171
+ info , _ := os .Stat (finding .Path )
172
+
173
+ zipReader , err = zip .NewReader (libraryFile , info .Size ())
174
+ if err != nil {
175
+ log .Error ().
176
+ Str ("path" , finding .Path ).
177
+ Err (err ).
178
+ Msg ("Unable to open archive for patching" )
179
+ return
180
+ }
181
+
182
+ outZip , err := ioutil .TempFile (os .TempDir (), "*.zip" )
183
+ if err != nil {
184
+ log .Error ().
185
+ Str ("tmpDir" , os .TempDir ()).
186
+ Err (err ).
187
+ Msg ("Unable to create temporary libraryFile" )
188
+ return
189
+ }
190
+ defer os .Remove (outZip .Name ())
191
+
192
+ writer := zip .NewWriter (outZip )
193
+ defer writer .Close ()
194
+
195
+ err = filterOutJndiLookupFromZip (finding , zipReader , writer )
196
+ if err != nil {
197
+ return
198
+ }
199
+
200
+ writer .Close ()
201
+
202
+ if err = libraryFile .Close (); err != nil {
203
+ log .Error ().
204
+ Str ("outZipName" , outZip .Name ()).
205
+ Str ("libraryFileName" , finding .Path ).
206
+ Err (err ).
207
+ Msg ("Unable to close library file." )
208
+ return
209
+ }
210
+
211
+ if err = outZip .Close (); err != nil {
212
+ log .Error ().
213
+ Str ("outZipName" , outZip .Name ()).
214
+ Str ("libraryFileName" , finding .Path ).
215
+ Err (err ).
216
+ Msg ("Unable to close output zip." )
217
+ return
218
+ }
219
+
220
+ _ , err = util .CopyFile (outZip .Name (), finding .Path )
221
+ if err != nil {
222
+ log .Error ().
223
+ Str ("outZipName" , outZip .Name ()).
224
+ Str ("libraryFileName" , finding .Path ).
225
+ Err (err ).
226
+ Msg ("Unable to replace library file with patched library file." )
227
+ return
228
+ }
229
+ return
230
+ }
231
+
232
+ func JavaArchivePatchCommand (
233
+ c * cli.Context ,
234
+ globalBoolFlags map [string ]bool ,
235
+ log4jLibraryHashes []byte ,
236
+ ) error {
237
+ enableGlobalFlags (c , globalBoolFlags )
238
+
239
+ findings , err := loadOrScanForFindings (c , log4jLibraryHashes )
240
+ if err != nil {
241
+ return err
242
+ }
243
+
244
+ log .Info ().
245
+ Int ("findingsCount" , len (findings )).
246
+ Msg ("Patching found vulnerable Log4j libraries." )
247
+
248
+ forcePatch := c .Bool ("force-patch" )
249
+
250
+ var patchedLibraries []string
251
+
252
+ for _ , finding := range findings {
253
+ var (
254
+ shouldSkip bool
255
+ )
256
+
257
+ if finding .JndiLookupFileName == "" {
103
258
log .Warn ().
104
259
Str ("path" , finding .Path ).
105
260
Err (err ).
106
- Msg ("Unable to hash zip file" )
107
- return err
261
+ Msg ("Finding does not have JndiLookup.class file to patch " )
262
+ continue
108
263
}
109
264
110
- if zipFileHash != finding .JndiLookupHash {
111
- log .Warn ().
112
- Str ("path" , finding .Path ).
113
- Str ("hash" , finding .JndiLookupHash ).
114
- Err (err ).
115
- Msg ("Hashes do not match, not deleting" )
116
- return nil
265
+ if ! forcePatch {
266
+ shouldSkip , forcePatch = askIfShouldSkipLibrary (finding .Path )
267
+ if ! forcePatch && shouldSkip {
268
+ log .Info ().
269
+ Str ("findingPath" , finding .Path ).
270
+ Msg ("Skipping library for patching" )
271
+ continue
272
+ }
117
273
}
118
- log .Debug ().
119
- Str ("path" , finding .Path ).
120
- Str ("zipFilePath" , finding .JndiLookupFileName ).
121
- Msg ("Found file to remove" )
274
+
275
+ err = patchJavaArchive (finding )
276
+ if err != nil {
277
+ continue
278
+ }
279
+ patchedLibraries = append (patchedLibraries , finding .Path )
122
280
}
123
281
282
+ log .Info ().
283
+ Strs ("patchedLibraries" , patchedLibraries ).
284
+ Msg ("Successfully patched libraries." )
124
285
return nil
125
286
}
0 commit comments