@@ -25,6 +25,14 @@ import (
25
25
"github.com/coder/serpent"
26
26
)
27
27
28
+ var (
29
+ // noAddr is the zero-value of netip.Addr, and is not a valid address. We use it to identify
30
+ // when the local address is not specified in port-forward flags.
31
+ noAddr netip.Addr
32
+ ipv6Loopback = netip .MustParseAddr ("::1" )
33
+ ipv4Loopback = netip .MustParseAddr ("127.0.0.1" )
34
+ )
35
+
28
36
func (r * RootCmd ) portForward () * serpent.Command {
29
37
var (
30
38
tcpForwards []string // <port>:<port>
@@ -122,7 +130,7 @@ func (r *RootCmd) portForward() *serpent.Command {
122
130
// Start all listeners.
123
131
var (
124
132
wg = new (sync.WaitGroup )
125
- listeners = make ([]net.Listener , len (specs ))
133
+ listeners = make ([]net.Listener , 0 , len (specs )* 2 )
126
134
closeAllListeners = func () {
127
135
logger .Debug (ctx , "closing all listeners" )
128
136
for _ , l := range listeners {
@@ -135,13 +143,26 @@ func (r *RootCmd) portForward() *serpent.Command {
135
143
)
136
144
defer closeAllListeners ()
137
145
138
- for i , spec := range specs {
146
+ for _ , spec := range specs {
147
+
148
+ if spec .listenHost == noAddr {
149
+ // first, opportunistically try to listen on IPv6
150
+ spec6 := spec
151
+ spec6 .listenHost = ipv6Loopback
152
+ l6 , err6 := listenAndPortForward (ctx , inv , conn , wg , spec6 , logger )
153
+ if err6 != nil {
154
+ logger .Info (ctx , "failed to opportunistically listen on IPv6" , slog .F ("spec" , spec ), slog .Error (err6 ))
155
+ } else {
156
+ listeners = append (listeners , l6 )
157
+ }
158
+ spec .listenHost = ipv4Loopback
159
+ }
139
160
l , err := listenAndPortForward (ctx , inv , conn , wg , spec , logger )
140
161
if err != nil {
141
162
logger .Error (ctx , "failed to listen" , slog .F ("spec" , spec ), slog .Error (err ))
142
163
return err
143
164
}
144
- listeners [ i ] = l
165
+ listeners = append ( listeners , l )
145
166
}
146
167
147
168
stopUpdating := client .UpdateWorkspaceUsageContext (ctx , workspace .ID )
@@ -206,12 +227,19 @@ func listenAndPortForward(
206
227
spec portForwardSpec ,
207
228
logger slog.Logger ,
208
229
) (net.Listener , error ) {
209
- logger = logger .With (slog .F ("network" , spec .listenNetwork ), slog .F ("address" , spec .listenAddress ))
210
- _ , _ = fmt .Fprintf (inv .Stderr , "Forwarding '%v://%v' locally to '%v://%v' in the workspace\n " , spec .listenNetwork , spec .listenAddress , spec .dialNetwork , spec .dialAddress )
230
+ logger = logger .With (
231
+ slog .F ("network" , spec .network ),
232
+ slog .F ("listen_host" , spec .listenHost ),
233
+ slog .F ("listen_port" , spec .listenPort ),
234
+ )
235
+ listenAddress := netip .AddrPortFrom (spec .listenHost , spec .listenPort )
236
+ dialAddress := fmt .Sprintf ("127.0.0.1:%d" , spec .dialPort )
237
+ _ , _ = fmt .Fprintf (inv .Stderr , "Forwarding '%s://%s' locally to '%s://%s' in the workspace\n " ,
238
+ spec .network , listenAddress , spec .network , dialAddress )
211
239
212
- l , err := inv .Net .Listen (spec .listenNetwork , spec . listenAddress )
240
+ l , err := inv .Net .Listen (spec .network , listenAddress . String () )
213
241
if err != nil {
214
- return nil , xerrors .Errorf ("listen '%v ://%v ': %w" , spec .listenNetwork , spec . listenAddress , err )
242
+ return nil , xerrors .Errorf ("listen '%s ://%s ': %w" , spec .network , listenAddress . String () , err )
215
243
}
216
244
logger .Debug (ctx , "listening" )
217
245
@@ -226,24 +254,31 @@ func listenAndPortForward(
226
254
logger .Debug (ctx , "listener closed" )
227
255
return
228
256
}
229
- _ , _ = fmt .Fprintf (inv .Stderr , "Error accepting connection from '%v://%v': %v\n " , spec .listenNetwork , spec .listenAddress , err )
257
+ _ , _ = fmt .Fprintf (inv .Stderr ,
258
+ "Error accepting connection from '%s://%s': %v\n " ,
259
+ spec .network , listenAddress .String (), err )
230
260
_ , _ = fmt .Fprintln (inv .Stderr , "Killing listener" )
231
261
return
232
262
}
233
- logger .Debug (ctx , "accepted connection" , slog .F ("remote_addr" , netConn .RemoteAddr ()))
263
+ logger .Debug (ctx , "accepted connection" ,
264
+ slog .F ("remote_addr" , netConn .RemoteAddr ()))
234
265
235
266
go func (netConn net.Conn ) {
236
267
defer netConn .Close ()
237
- remoteConn , err := conn .DialContext (ctx , spec .dialNetwork , spec . dialAddress )
268
+ remoteConn , err := conn .DialContext (ctx , spec .network , dialAddress )
238
269
if err != nil {
239
- _ , _ = fmt .Fprintf (inv .Stderr , "Failed to dial '%v://%v' in workspace: %s\n " , spec .dialNetwork , spec .dialAddress , err )
270
+ _ , _ = fmt .Fprintf (inv .Stderr ,
271
+ "Failed to dial '%s://%s' in workspace: %s\n " ,
272
+ spec .network , dialAddress , err )
240
273
return
241
274
}
242
275
defer remoteConn .Close ()
243
- logger .Debug (ctx , "dialed remote" , slog .F ("remote_addr" , netConn .RemoteAddr ()))
276
+ logger .Debug (ctx ,
277
+ "dialed remote" , slog .F ("remote_addr" , netConn .RemoteAddr ()))
244
278
245
279
agentssh .Bicopy (ctx , netConn , remoteConn )
246
- logger .Debug (ctx , "connection closing" , slog .F ("remote_addr" , netConn .RemoteAddr ()))
280
+ logger .Debug (ctx ,
281
+ "connection closing" , slog .F ("remote_addr" , netConn .RemoteAddr ()))
247
282
}(netConn )
248
283
}
249
284
}(spec )
@@ -252,58 +287,48 @@ func listenAndPortForward(
252
287
}
253
288
254
289
type portForwardSpec struct {
255
- listenNetwork string // tcp, udp
256
- listenAddress string // <ip>:<port> or path
257
-
258
- dialNetwork string // tcp, udp
259
- dialAddress string // <ip>:<port> or path
290
+ network string // tcp, udp
291
+ listenHost netip.Addr
292
+ listenPort , dialPort uint16
260
293
}
261
294
262
295
func parsePortForwards (tcpSpecs , udpSpecs []string ) ([]portForwardSpec , error ) {
263
296
specs := []portForwardSpec {}
264
297
265
298
for _ , specEntry := range tcpSpecs {
266
299
for _ , spec := range strings .Split (specEntry , "," ) {
267
- ports , err := parseSrcDestPorts (strings .TrimSpace (spec ))
300
+ pfSpecs , err := parseSrcDestPorts (strings .TrimSpace (spec ))
268
301
if err != nil {
269
302
return nil , xerrors .Errorf ("failed to parse TCP port-forward specification %q: %w" , spec , err )
270
303
}
271
304
272
- for _ , port := range ports {
273
- specs = append (specs , portForwardSpec {
274
- listenNetwork : "tcp" ,
275
- listenAddress : port .local .String (),
276
- dialNetwork : "tcp" ,
277
- dialAddress : port .remote .String (),
278
- })
305
+ for _ , pfSpec := range pfSpecs {
306
+ pfSpec .network = "tcp"
307
+ specs = append (specs , pfSpec )
279
308
}
280
309
}
281
310
}
282
311
283
312
for _ , specEntry := range udpSpecs {
284
313
for _ , spec := range strings .Split (specEntry , "," ) {
285
- ports , err := parseSrcDestPorts (strings .TrimSpace (spec ))
314
+ pfSpecs , err := parseSrcDestPorts (strings .TrimSpace (spec ))
286
315
if err != nil {
287
316
return nil , xerrors .Errorf ("failed to parse UDP port-forward specification %q: %w" , spec , err )
288
317
}
289
318
290
- for _ , port := range ports {
291
- specs = append (specs , portForwardSpec {
292
- listenNetwork : "udp" ,
293
- listenAddress : port .local .String (),
294
- dialNetwork : "udp" ,
295
- dialAddress : port .remote .String (),
296
- })
319
+ for _ , pfSpec := range pfSpecs {
320
+ pfSpec .network = "udp"
321
+ specs = append (specs , pfSpec )
297
322
}
298
323
}
299
324
}
300
325
301
326
// Check for duplicate entries.
302
327
locals := map [string ]struct {}{}
303
328
for _ , spec := range specs {
304
- localStr := fmt .Sprintf ("%v:%v " , spec .listenNetwork , spec .listenAddress )
329
+ localStr := fmt .Sprintf ("%s:%s:%d " , spec .network , spec .listenHost , spec . listenPort )
305
330
if _ , ok := locals [localStr ]; ok {
306
- return nil , xerrors .Errorf ("local %v %v is specified twice" , spec .listenNetwork , spec .listenAddress )
331
+ return nil , xerrors .Errorf ("local %s host:%s port:%d is specified twice" , spec .network , spec .listenHost , spec . listenPort )
307
332
}
308
333
locals [localStr ] = struct {}{}
309
334
}
@@ -323,10 +348,6 @@ func parsePort(in string) (uint16, error) {
323
348
return uint16 (port ), nil
324
349
}
325
350
326
- type parsedSrcDestPort struct {
327
- local , remote netip.AddrPort
328
- }
329
-
330
351
// specRegexp matches port specs. It handles all the following formats:
331
352
//
332
353
// 8000
@@ -347,21 +368,19 @@ type parsedSrcDestPort struct {
347
368
// 9: end or remote port range
348
369
var specRegexp = regexp .MustCompile (`^((\[[0-9a-fA-F:]+]|\d+\.\d+\.\d+\.\d+):)?(\d+)(-(\d+))?(:(\d+)(-(\d+))?)?$` )
349
370
350
- func parseSrcDestPorts (in string ) ([]parsedSrcDestPort , error ) {
351
- var (
352
- err error
353
- localAddr = netip .AddrFrom4 ([4 ]byte {127 , 0 , 0 , 1 })
354
- remoteAddr = netip .AddrFrom4 ([4 ]byte {127 , 0 , 0 , 1 })
355
- )
371
+ func parseSrcDestPorts (in string ) ([]portForwardSpec , error ) {
356
372
groups := specRegexp .FindStringSubmatch (in )
357
373
if len (groups ) == 0 {
358
374
return nil , xerrors .Errorf ("invalid port specification %q" , in )
359
375
}
376
+
377
+ var localAddr netip.Addr
360
378
if groups [2 ] != "" {
361
- localAddr , err = netip .ParseAddr (strings .Trim (groups [2 ], "[]" ))
379
+ parsedAddr , err : = netip .ParseAddr (strings .Trim (groups [2 ], "[]" ))
362
380
if err != nil {
363
381
return nil , xerrors .Errorf ("invalid IP address %q" , groups [2 ])
364
382
}
383
+ localAddr = parsedAddr
365
384
}
366
385
367
386
local , err := parsePortRange (groups [3 ], groups [5 ])
@@ -378,11 +397,12 @@ func parseSrcDestPorts(in string) ([]parsedSrcDestPort, error) {
378
397
if len (local ) != len (remote ) {
379
398
return nil , xerrors .Errorf ("port ranges must be the same length, got %d ports forwarded to %d ports" , len (local ), len (remote ))
380
399
}
381
- var out []parsedSrcDestPort
400
+ var out []portForwardSpec
382
401
for i := range local {
383
- out = append (out , parsedSrcDestPort {
384
- local : netip .AddrPortFrom (localAddr , local [i ]),
385
- remote : netip .AddrPortFrom (remoteAddr , remote [i ]),
402
+ out = append (out , portForwardSpec {
403
+ listenHost : localAddr ,
404
+ listenPort : local [i ],
405
+ dialPort : remote [i ],
386
406
})
387
407
}
388
408
return out , nil
0 commit comments