204 lines
6.7 KiB
TypeScript
204 lines
6.7 KiB
TypeScript
import { mkdtempSync, rmSync, writeFileSync } from 'node:fs'
|
|
import net from 'node:net'
|
|
import { tmpdir } from 'node:os'
|
|
import { join } from 'node:path'
|
|
import { PassThrough } from 'node:stream'
|
|
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
|
|
|
const ssh2State = vi.hoisted(() => {
  // vi.hoisted runs before static imports, so EventEmitter must be pulled in
  // here via require (injected by Vitest's transform) rather than a top-level
  // import, which is not yet initialised at hoist time. PassThrough can stay a
  // top-level import because it is only referenced inside forwardOut, which is
  // not invoked until a test actually runs.
  const { EventEmitter } = require('node:events') as typeof import('node:events')

  /** Knobs and observations shared between the fake ssh2 module and the tests. */
  const state: { instances: FakeSshClient[]; failConnect: boolean } = {
    // Every fake Client constructed by the mocked module, in creation order.
    instances: [],
    // When true, the next connect() emits 'error' instead of 'ready'.
    failConnect: false
  }

  /**
   * Stand-in for ssh2's Client. It records the config passed to connect() and
   * every forwardOut() call, and flags whether end() was called.
   */
  class FakeSshClient extends EventEmitter {
    connectConfig: Record<string, unknown> | undefined
    forwardCalls: unknown[][] = []
    ended = false

    /** Stores the config, then on the next tick emits 'ready' (or 'error' when state.failConnect is set). */
    connect(config: Record<string, unknown>) {
      this.connectConfig = config
      process.nextTick(() => {
        if (!state.failConnect) {
          this.emit('ready')
          return
        }
        this.emit('error', new Error('auth failed'))
      })
    }

    /**
     * Records the four forwarding arguments and synchronously hands back a
     * PassThrough. Anything written into it is read straight back out, so it
     * acts as an echo server.
     */
    forwardOut(
      srcHost: string,
      srcPort: number,
      dstHost: string,
      dstPort: number,
      callback: (err: Error | null, stream?: PassThrough) => void
    ) {
      this.forwardCalls.push([srcHost, srcPort, dstHost, dstPort])
      const echoStream = new PassThrough()
      callback(null, echoStream)
    }

    /** Marks the client as ended; does not emit any events. */
    end() {
      this.ended = true
    }
  }

  return { state, FakeSshClient }
})
|
|
|
|
vi.mock('ssh2', () => {
  // Replace ssh2's Client with the fake and register every constructed
  // instance, so tests can inspect the client that openTunnel created.
  class Client extends ssh2State.FakeSshClient {
    constructor() {
      super()
      ssh2State.state.instances.push(this)
    }
  }
  return { Client }
})
|
|
|
|
import type { SshConfig } from '../../../src/config/types.js'
|
|
import { buildSshAuth, expandHome, openTunnel } from '../../../src/net/tunnel.js'
|
|
|
|
/**
 * Builds an SshConfig for user "tunnel" on bastion:22.
 * Any fields in `extra` override (or add to) those defaults.
 */
const sshConfig = (extra: Partial<SshConfig> = {}): SshConfig => {
  const defaults = { host: 'bastion', port: 22, user: 'tunnel' }
  return { ...defaults, ...extra }
}
|
|
|
|
describe('expandHome', () => {
  it('expands ~ prefix', () => {
    // The leading ~ must be replaced and the rest of the path kept.
    const expanded = expandHome('~/.ssh/key')
    expect(expanded).not.toContain('~')
    expect(expanded).toMatch(/\.ssh\/key$/)
  })

  it('leaves absolute paths alone', () => {
    const absolute = '/etc/key'
    expect(expandHome(absolute)).toBe(absolute)
  })
})
|
|
|
|
describe('buildSshAuth', () => {
  // Every case passes an explicit env as the second argument. The agent case
  // shows this is where SSH_AUTH_SOCK is read from. Not relying on
  // buildSshAuth's default env keeps the results independent of whatever
  // agent socket is set on the machine running the suite.
  let dir: string

  beforeEach(() => {
    // Fresh scratch directory per test for on-disk key files.
    dir = mkdtempSync(join(tmpdir(), 'dbmole-ssh-'))
  })

  afterEach(() => {
    rmSync(dir, { recursive: true, force: true })
  })

  it('prefers inline privateKey', () => {
    const auth = buildSshAuth(sshConfig({ privateKey: 'PEM', password: 'pw' }), {})
    expect(auth).toEqual({ privateKey: 'PEM', passphrase: undefined })
  })

  it('reads privateKeyPath from disk', () => {
    const keyPath = join(dir, 'id_test')
    writeFileSync(keyPath, 'FILEPEM')
    const auth = buildSshAuth(sshConfig({ privateKeyPath: keyPath, passphrase: 'pp' }), {})
    expect(auth).toEqual({ privateKey: 'FILEPEM', passphrase: 'pp' })
  })

  it('uses agent when enabled and SSH_AUTH_SOCK set', () => {
    const auth = buildSshAuth(sshConfig({ agent: true }), { SSH_AUTH_SOCK: '/tmp/agent.sock' })
    expect(auth).toEqual({ agent: '/tmp/agent.sock' })
  })

  it('falls back to password', () => {
    expect(buildSshAuth(sshConfig({ password: 'pw' }), {})).toEqual({ password: 'pw' })
  })

  it('throws when no method available', () => {
    expect(() => buildSshAuth(sshConfig(), {})).toThrow(/privateKey/)
  })
})
|
|
|
|
describe('openTunnel', () => {
  beforeEach(() => {
    // Reset the fake ssh2 module so each test sees only its own clients.
    ssh2State.state.instances.length = 0
    ssh2State.state.failConnect = false
  })

  /**
   * Returns the first fake ssh2 Client constructed since the last reset.
   * If none was constructed, it fails with a descriptive error rather than a
   * TypeError from dereferencing undefined.
   */
  const firstClient = () => {
    const [client] = ssh2State.state.instances
    if (client === undefined) {
      throw new Error('expected openTunnel to construct an ssh2 Client')
    }
    return client
  }

  it('forwards local TCP traffic through ssh forwardOut', async () => {
    const tunnel = await openTunnel(sshConfig({ password: 'pw' }), 'db-host', 5432)
    const client = firstClient()
    // Close the tunnel even when an assertion fails, so a failing test does not
    // leave a listening local server and an un-ended client behind.
    try {
      expect(tunnel.localHost).toBe('127.0.0.1')
      expect(tunnel.localPort).toBeGreaterThan(0)

      const echoed = await new Promise<string>((resolve, reject) => {
        const socket = net.connect(tunnel.localPort, tunnel.localHost, () => {
          socket.write('ping')
        })
        socket.once('data', (chunk) => {
          socket.end()
          resolve(chunk.toString())
        })
        socket.once('error', reject)
      })

      // FakeSshClient hands back a PassThrough, so bytes written into the
      // tunnel come straight back: pipe wiring works in both directions.
      expect(echoed).toBe('ping')
      // The fake records forwardOut(srcHost, srcPort, dstHost, dstPort). Pin the
      // destination positions, so swapped src/dst arguments would be caught.
      expect(client.forwardCalls[0]?.slice(2)).toEqual(['db-host', 5432])
      expect(client.connectConfig).toMatchObject({
        host: 'bastion',
        port: 22,
        username: 'tunnel',
        password: 'pw'
      })
    } finally {
      await tunnel.close()
    }
    expect(client.ended).toBe(true)
  })

  it('rejects when ssh connection fails', async () => {
    ssh2State.state.failConnect = true
    await expect(openTunnel(sshConfig({ password: 'pw' }), 'db', 5432)).rejects.toThrow(
      'auth failed'
    )
  })

  it('close() resolves even with an open forwarded socket', async () => {
    const tunnel = await openTunnel(sshConfig({ password: 'pw' }), 'db-host', 5432)
    const socket = await new Promise<net.Socket>((resolve, reject) => {
      const s = net.connect(tunnel.localPort, tunnel.localHost, () => {
        s.write('hold')
        resolve(s)
      })
      s.once('error', reject)
    })
    const clientClosed = new Promise<void>((resolve) => {
      if (socket.closed) {
        resolve()
        return
      }
      socket.once('close', () => resolve())
    })
    // Deliberately do NOT end the socket; close() must still resolve.
    await tunnel.close()
    // The forwarded socket on the server side is destroyed, which tears
    // down our client end too.
    await clientClosed
    expect(socket.destroyed).toBe(true)
  })

  it('reports isClosed false for a fresh tunnel', async () => {
    const tunnel = await openTunnel(sshConfig({ password: 'pw' }), 'db-host', 5432)
    try {
      expect(tunnel.isClosed()).toBe(false)
    } finally {
      await tunnel.close()
    }
  })

  it('flips isClosed to true when the ssh client closes', async () => {
    const tunnel = await openTunnel(sshConfig({ password: 'pw' }), 'db-host', 5432)
    try {
      // Simulate the underlying ssh connection dropping.
      firstClient().emit('close')
      expect(tunnel.isClosed()).toBe(true)
    } finally {
      await tunnel.close()
    }
  })
})
|