Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 8f1f141

Browse files
committed
fix: don't make session counts cumulative
This made for some weird tracking... we want the point-in-time number of counts!
1 parent 2ff1c6d commit 8f1f141

File tree

2 files changed

+66
-37
lines changed

2 files changed

+66
-37
lines changed

agent/agent.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -879,10 +879,13 @@ func (a *agent) handleSSHSession(session ssh.Session) (retErr error) {
879879
switch magicType {
880880
case MagicSSHSessionTypeVSCode:
881881
a.connCountVSCode.Add(1)
882+
defer a.connCountVSCode.Add(-1)
882883
case MagicSSHSessionTypeJetBrains:
883884
a.connCountJetBrains.Add(1)
885+
defer a.connCountJetBrains.Add(-1)
884886
case "":
885887
a.connCountSSHSession.Add(1)
888+
defer a.connCountSSHSession.Add(-1)
886889
default:
887890
a.logger.Warn(ctx, "invalid magic ssh session type specified", slog.F("type", magicType))
888891
}
@@ -986,6 +989,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
986989
defer conn.Close()
987990

988991
a.connCountReconnectingPTY.Add(1)
992+
defer a.connCountReconnectingPTY.Add(-1)
989993

990994
connectionID := uuid.NewString()
991995
logger = logger.With(slog.F("id", msg.ID), slog.F("connection_id", connectionID))
@@ -1194,8 +1198,7 @@ func (a *agent) startReportingConnectionStats(ctx context.Context) {
11941198
stats.TxPackets = a.statTxPackets.Add(int64(counts.TxPackets))
11951199
}
11961200

1197-
// Tailscale's connection stats are not cumulative, but it makes no sense to make
1198-
// ours temporary.
1201+
// The count of active sessions.
11991202
stats.SessionCountSSH = a.connCountSSHSession.Load()
12001203
stats.SessionCountVSCode = a.connCountVSCode.Load()
12011204
stats.SessionCountJetBrains = a.connCountJetBrains.Load()

agent/agent_test.go

Lines changed: 61 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -110,42 +110,68 @@ func TestAgent_Stats_ReconnectingPTY(t *testing.T) {
110110

111111
func TestAgent_Stats_Magic(t *testing.T) {
112112
t.Parallel()
113+
t.Run("StripsEnvironmentVariable", func(t *testing.T) {
114+
t.Parallel()
115+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
116+
defer cancel()
117+
conn, _, _, _ := setupAgent(t, agentsdk.Metadata{}, 0)
118+
sshClient, err := conn.SSHClient(ctx)
119+
require.NoError(t, err)
120+
defer sshClient.Close()
121+
session, err := sshClient.NewSession()
122+
require.NoError(t, err)
123+
session.Setenv(agent.MagicSSHSessionTypeEnvironmentVariable, agent.MagicSSHSessionTypeVSCode)
124+
defer session.Close()
113125

114-
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
115-
defer cancel()
116-
117-
conn, _, stats, _ := setupAgent(t, agentsdk.Metadata{}, 0)
118-
sshClient, err := conn.SSHClient(ctx)
119-
require.NoError(t, err)
120-
defer sshClient.Close()
121-
session, err := sshClient.NewSession()
122-
require.NoError(t, err)
123-
session.Setenv(agent.MagicSSHSessionTypeEnvironmentVariable, agent.MagicSSHSessionTypeVSCode)
124-
defer session.Close()
125-
126-
command := "sh -c 'echo $" + agent.MagicSSHSessionTypeEnvironmentVariable + "'"
127-
expected := ""
128-
if runtime.GOOS == "windows" {
129-
expected = "%" + agent.MagicSSHSessionTypeEnvironmentVariable + "%"
130-
command = "cmd.exe /c echo " + expected
131-
}
132-
output, err := session.Output(command)
133-
require.NoError(t, err)
134-
require.Equal(t, expected, strings.TrimSpace(string(output)))
135-
var s *agentsdk.Stats
136-
require.Eventuallyf(t, func() bool {
137-
var ok bool
138-
s, ok = <-stats
139-
return ok && s.ConnectionCount > 0 && s.RxBytes > 0 && s.TxBytes > 0 &&
140-
// Ensure that the connection didn't count as a "normal" SSH session.
141-
// This was a special one, so it should be labeled specially in the stats!
142-
s.SessionCountVSCode == 1 &&
143-
// Ensure that connection latency is being counted!
144-
// If it isn't, it's set to -1.
145-
s.ConnectionMedianLatencyMS >= 0
146-
}, testutil.WaitLong, testutil.IntervalFast,
147-
"never saw stats: %+v", s,
148-
)
126+
command := "sh -c 'echo $" + agent.MagicSSHSessionTypeEnvironmentVariable + "'"
127+
expected := ""
128+
if runtime.GOOS == "windows" {
129+
expected = "%" + agent.MagicSSHSessionTypeEnvironmentVariable + "%"
130+
command = "cmd.exe /c echo " + expected
131+
}
132+
output, err := session.Output(command)
133+
require.NoError(t, err)
134+
require.Equal(t, expected, strings.TrimSpace(string(output)))
135+
})
136+
t.Run("Tracks", func(t *testing.T) {
137+
t.Parallel()
138+
if runtime.GOOS == "window" {
139+
t.Skip("Sleeping for infinity doesn't work on Windows")
140+
}
141+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
142+
defer cancel()
143+
conn, _, stats, _ := setupAgent(t, agentsdk.Metadata{}, 0)
144+
sshClient, err := conn.SSHClient(ctx)
145+
require.NoError(t, err)
146+
defer sshClient.Close()
147+
session, err := sshClient.NewSession()
148+
require.NoError(t, err)
149+
session.Setenv(agent.MagicSSHSessionTypeEnvironmentVariable, agent.MagicSSHSessionTypeVSCode)
150+
defer session.Close()
151+
stdin, err := session.StdinPipe()
152+
require.NoError(t, err)
153+
err = session.Shell()
154+
require.NoError(t, err)
155+
var s *agentsdk.Stats
156+
require.Eventuallyf(t, func() bool {
157+
var ok bool
158+
s, ok = <-stats
159+
fmt.Printf("WE GOT STATS %+v\n", s)
160+
return ok && s.ConnectionCount > 0 && s.RxBytes > 0 && s.TxBytes > 0 &&
161+
// Ensure that the connection didn't count as a "normal" SSH session.
162+
// This was a special one, so it should be labeled specially in the stats!
163+
s.SessionCountVSCode == 1 &&
164+
// Ensure that connection latency is being counted!
165+
// If it isn't, it's set to -1.
166+
s.ConnectionMedianLatencyMS >= 0
167+
}, testutil.WaitLong, testutil.IntervalFast,
168+
"never saw stats: %+v", s,
169+
)
170+
// The shell will automatically exit if there is no stdin!
171+
_ = stdin.Close()
172+
err = session.Wait()
173+
require.NoError(t, err)
174+
})
149175
}
150176

151177
func TestAgent_SessionExec(t *testing.T) {

0 commit comments

Comments
 (0)