44 "archive/tar"
55 "bytes"
66 "context"
7+ "crypto/sha1" //#nosec // Not used for cryptography.
8+ "encoding/hex"
79 "errors"
810 "fmt"
911 "io"
@@ -20,6 +22,7 @@ import (
2022 "github.com/klauspost/compress/zstd"
2123 "github.com/unrolled/secure"
2224 "golang.org/x/exp/slices"
25+ "golang.org/x/sync/errgroup"
2326 "golang.org/x/xerrors"
2427)
2528
@@ -439,12 +442,18 @@ func ExtractOrReadBinFS(dest string, siteFS fs.FS) (http.FileSystem, error) {
439442 return nil , err
440443 }
441444
442- n , err := extractBin (dest , archive )
445+ ok , err := verifyBinSha1IsCurrent (dest , siteFS )
443446 if err != nil {
444- return nil , xerrors .Errorf ("extract coder binaries failed: %w" , err )
447+ return nil , xerrors .Errorf ("verify coder binaries sha1 failed: %w" , err )
445448 }
446- if n == 0 {
447- return nil , xerrors .New ("no files were extracted from coder binaries archive" )
449+ if ! ok {
450+ n , err := extractBin (dest , archive )
451+ if err != nil {
452+ return nil , xerrors .Errorf ("extract coder binaries failed: %w" , err )
453+ }
454+ if n == 0 {
455+ return nil , xerrors .New ("no files were extracted from coder binaries archive" )
456+ }
448457 }
449458
450459 return dir , nil
@@ -461,6 +470,98 @@ func filterFiles(files []fs.DirEntry, names ...string) []fs.DirEntry {
461470 return filtered
462471}
463472
473+ // errHashMismatch is a sentinel error used in verifyBinSha1IsCurrent.
474+ var errHashMismatch = xerrors .New ("hash mismatch" )
475+
476+ func verifyBinSha1IsCurrent (dest string , siteFS fs.FS ) (ok bool , err error ) {
477+ b1 , err := fs .ReadFile (siteFS , "bin/coder.sha1" )
478+ if err != nil {
479+ return false , xerrors .Errorf ("read coder sha1 from embedded fs failed: %w" , err )
480+ }
481+ // Parse sha1 file.
482+ shaFiles := make (map [string ][]byte )
483+ for _ , line := range bytes .Split (bytes .TrimSpace (b1 ), []byte {'\n' }) {
484+ parts := bytes .Split (line , []byte {' ' , '*' })
485+ if len (parts ) != 2 {
486+ return false , xerrors .Errorf ("malformed sha1 file: %w" , err )
487+ }
488+ shaFiles [string (parts [1 ])] = parts [0 ]
489+ }
490+ if len (shaFiles ) == 0 {
491+ return false , xerrors .Errorf ("empty sha1 file: %w" , err )
492+ }
493+
494+ b2 , err := os .ReadFile (filepath .Join (dest , "coder.sha1" ))
495+ if err != nil {
496+ if xerrors .Is (err , fs .ErrNotExist ) {
497+ return false , nil
498+ }
499+ return false , xerrors .Errorf ("read coder sha1 failed: %w" , err )
500+ }
501+
502+ // Check shasum files for equality for early-exit.
503+ if ! bytes .Equal (b1 , b2 ) {
504+ return false , nil
505+ }
506+
507+ var eg errgroup.Group
508+ // Speed up startup by verifying files concurrently. Concurrency
509+ // is limited to save resources / early-exit. Early-exit speed
510+ // could be improved by using a context aware io.Reader and
511+ // passing the context from errgroup.WithContext.
512+ eg .SetLimit (3 )
513+
514+ // Verify the hash of each on-disk binary.
515+ for file , hash1 := range shaFiles {
516+ file := file
517+ hash1 := hash1
518+ eg .Go (func () error {
519+ hash2 , err := sha1HashFile (filepath .Join (dest , file ))
520+ if err != nil {
521+ if xerrors .Is (err , fs .ErrNotExist ) {
522+ return errHashMismatch
523+ }
524+ return xerrors .Errorf ("hash file failed: %w" , err )
525+ }
526+ if ! bytes .Equal (hash1 , hash2 ) {
527+ return errHashMismatch
528+ }
529+ return nil
530+ })
531+ }
532+ err = eg .Wait ()
533+ if err != nil {
534+ if xerrors .Is (err , errHashMismatch ) {
535+ return false , nil
536+ }
537+ return false , err
538+ }
539+
540+ return true , nil
541+ }
542+
543+ // sha1HashFile computes a SHA1 hash of the file, returning the hex
544+ // representation.
545+ func sha1HashFile (name string ) ([]byte , error ) {
546+ //#nosec // Not used for cryptography.
547+ hash := sha1 .New ()
548+ f , err := os .Open (name )
549+ if err != nil {
550+ return nil , err
551+ }
552+ defer f .Close ()
553+
554+ _ , err = io .Copy (hash , f )
555+ if err != nil {
556+ return nil , err
557+ }
558+
559+ b := make ([]byte , hash .Size ())
560+ hash .Sum (b [:0 ])
561+
562+ return []byte (hex .EncodeToString (b )), nil
563+ }
564+
464565func extractBin (dest string , r io.Reader ) (numExtraced int , err error ) {
465566 opts := []zstd.DOption {
466567 // Concurrency doesn't help us when decoding the tar and
0 commit comments