@@ -60,7 +60,7 @@ type Options struct {
60
60
}
61
61
62
62
// NewConn constructs a new Wireguard server that will accept connections from the addresses provided.
63
- func NewConn (options * Options ) (* Conn , error ) {
63
+ func NewConn (options * Options ) (conn * Conn , err error ) {
64
64
if options == nil {
65
65
options = & Options {}
66
66
}
@@ -123,6 +123,11 @@ func NewConn(options *Options) (*Conn, error) {
123
123
if err != nil {
124
124
return nil , xerrors .Errorf ("create wireguard link monitor: %w" , err )
125
125
}
126
+ defer func () {
127
+ if err != nil {
128
+ wireguardMonitor .Close ()
129
+ }
130
+ }()
126
131
127
132
dialer := & tsdial.Dialer {
128
133
Logf : Logger (options .Logger ),
@@ -134,6 +139,11 @@ func NewConn(options *Options) (*Conn, error) {
134
139
if err != nil {
135
140
return nil , xerrors .Errorf ("create wgengine: %w" , err )
136
141
}
142
+ defer func () {
143
+ if err != nil {
144
+ wireguardEngine .Close ()
145
+ }
146
+ }()
137
147
dialer .UseNetstackForIP = func (ip netip.Addr ) bool {
138
148
_ , ok := wireguardEngine .PeerForIP (ip )
139
149
return ok
@@ -166,10 +176,6 @@ func NewConn(options *Options) (*Conn, error) {
166
176
return netStack .DialContextTCP (ctx , dst )
167
177
}
168
178
netStack .ProcessLocalIPs = true
169
- err = netStack .Start (nil )
170
- if err != nil {
171
- return nil , xerrors .Errorf ("start netstack: %w" , err )
172
- }
173
179
wireguardEngine = wgengine .NewWatchdog (wireguardEngine )
174
180
wireguardEngine .SetDERPMap (options .DERPMap )
175
181
netMapCopy := * netMap
@@ -203,6 +209,11 @@ func NewConn(options *Options) (*Conn, error) {
203
209
},
204
210
wireguardEngine : wireguardEngine ,
205
211
}
212
+ defer func () {
213
+ if err != nil {
214
+ _ = server .Close ()
215
+ }
216
+ }()
206
217
wireguardEngine .SetStatusCallback (func (s * wgengine.Status , err error ) {
207
218
server .logger .Debug (context .Background (), "wireguard status" , slog .F ("status" , s ), slog .F ("err" , err ))
208
219
if err != nil {
@@ -236,6 +247,12 @@ func NewConn(options *Options) (*Conn, error) {
236
247
server .sendNode ()
237
248
})
238
249
netStack .ForwardTCPIn = server .forwardTCP
250
+
251
+ err = netStack .Start (nil )
252
+ if err != nil {
253
+ return nil , xerrors .Errorf ("start netstack: %w" , err )
254
+ }
255
+
239
256
return server , nil
240
257
}
241
258
@@ -519,22 +536,35 @@ func (c *Conn) Close() error {
519
536
default :
520
537
}
521
538
close (c .closed )
522
- for _ , l := range c .listeners {
523
- _ = l .closeNoLock ()
524
- }
525
539
c .mutex .Unlock ()
526
- c .dialCancel ()
527
- _ = c .dialer .Close ()
528
- _ = c .magicConn .Close ()
540
+
541
+ var wg sync.WaitGroup
542
+ defer wg .Wait ()
543
+
544
+ if c .trafficStats != nil {
545
+ wg .Add (1 )
546
+ go func () {
547
+ defer wg .Done ()
548
+ ctx , cancel := context .WithTimeout (context .Background (), 5 * time .Second )
549
+ defer cancel ()
550
+ _ = c .trafficStats .Shutdown (ctx )
551
+ }()
552
+ }
553
+
529
554
_ = c .netStack .Close ()
555
+ c .dialCancel ()
530
556
_ = c .wireguardMonitor .Close ()
531
- _ = c .tunDevice .Close ()
557
+ _ = c .dialer .Close ()
558
+ // Stops internals, e.g. tunDevice, magicConn and dnsManager.
532
559
c .wireguardEngine .Close ()
533
- if c . trafficStats != nil {
534
- ctx , cancel := context . WithTimeout ( context . Background (), 5 * time . Second )
535
- defer cancel ()
536
- _ = c . trafficStats . Shutdown ( ctx )
560
+
561
+ c . mutex . Lock ( )
562
+ for _ , l := range c . listeners {
563
+ _ = l . closeNoLock ( )
537
564
}
565
+ c .listeners = nil
566
+ c .mutex .Unlock ()
567
+
538
568
return nil
539
569
}
540
570
@@ -714,16 +744,25 @@ func (c *Conn) forwardTCPToLocal(conn net.Conn, port uint16) {
714
744
func (c * Conn ) SetConnStatsCallback (maxPeriod time.Duration , maxConns int , dump func (start , end time.Time , virtual , physical map [netlogtype.Connection ]netlogtype.Counts )) {
715
745
connStats := connstats .NewStatistics (maxPeriod , maxConns , dump )
716
746
747
+ shutdown := func (s * connstats.Statistics ) {
748
+ ctx , cancel := context .WithTimeout (context .Background (), 5 * time .Second )
749
+ defer cancel ()
750
+ _ = s .Shutdown (ctx )
751
+ }
752
+
717
753
c .mutex .Lock ()
754
+ if c .isClosed () {
755
+ c .mutex .Unlock ()
756
+ shutdown (connStats )
757
+ return
758
+ }
718
759
old := c .trafficStats
719
760
c .trafficStats = connStats
720
761
c .mutex .Unlock ()
721
762
722
763
// Make sure to shutdown the old callback.
723
764
if old != nil {
724
- ctx , cancel := context .WithTimeout (context .Background (), 5 * time .Second )
725
- defer cancel ()
726
- _ = old .Shutdown (ctx )
765
+ shutdown (old )
727
766
}
728
767
729
768
c .tunDevice .SetStatistics (connStats )
@@ -776,6 +815,7 @@ func (a addr) String() string { return a.ln.addr }
776
815
// Logger converts the Tailscale logging function to use slog.
777
816
func Logger (logger slog.Logger ) tslogger.Logf {
778
817
return tslogger .Logf (func (format string , args ... any ) {
818
+ slog .Helper ()
779
819
logger .Debug (context .Background (), fmt .Sprintf (format , args ... ))
780
820
})
781
821
}
0 commit comments