import { mkdtempSync, rmSync, writeFileSync } from 'node:fs'
import net from 'node:net'
import type { Socket } from 'node:net'
import { tmpdir } from 'node:os'
import { join } from 'node:path'
import { PassThrough } from 'node:stream'
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'

/**
 * Shared mutable state plus a fake ssh2 `Client`, built inside `vi.hoisted`
 * so the `vi.mock('ssh2')` factory below can reference it.
 */
const ssh2State = vi.hoisted(() => {
  // vi.hoisted runs before static imports, so EventEmitter must be pulled in
  // here via require (injected by Vitest's transform) rather than a top-level
  // import, which is not yet initialised at hoist time.
  const { EventEmitter } = require('node:events') as typeof import('node:events')

  // instances: every fake Client constructed by the mock, in creation order.
  // failConnect: when true, connect() emits 'error' instead of 'ready'.
  const state = { instances: [] as FakeSshClient[], failConnect: false }

  class FakeSshClient extends EventEmitter {
    connectConfig: Record<string, unknown> | undefined
    forwardCalls: unknown[][] = []
    ended = false

    /**
     * Records the config, then on the next tick emits 'ready', or emits
     * 'error' ('auth failed') when `state.failConnect` is set.
     */
    connect(config: Record<string, unknown>) {
      this.connectConfig = config
      process.nextTick(() => {
        if (state.failConnect) {
          this.emit('error', new Error('auth failed'))
        } else {
          this.emit('ready')
        }
      })
    }

    /**
     * Records the forward request and synchronously hands back a fresh
     * PassThrough, so anything written into the "remote" stream is echoed.
     */
    forwardOut(
      srcHost: string,
      srcPort: number,
      dstHost: string,
      dstPort: number,
      callback: (err: Error | null, stream?: PassThrough) => void
    ) {
      this.forwardCalls.push([srcHost, srcPort, dstHost, dstPort])
      callback(null, new PassThrough())
    }

    /** Marks the client as ended so tests can assert close() shut it down. */
    end() {
      this.ended = true
    }
  }

  return { state, FakeSshClient }
})

// Swap ssh2's Client for the fake; every construction is recorded in
// ssh2State.state.instances so tests can inspect the client openTunnel used.
vi.mock('ssh2', () => ({
  Client: class extends ssh2State.FakeSshClient {
    constructor() {
      super()
      ssh2State.state.instances.push(this)
    }
  }
}))

import type { SshConfig } from '../../../src/config/types.js'
import { buildSshAuth, expandHome, openTunnel } from '../../../src/net/tunnel.js'

/** Minimal valid SshConfig, with per-test overrides merged on top. */
const sshConfig = (extra: Partial<SshConfig> = {}): SshConfig => ({
  host: 'bastion',
  port: 22,
  user: 'tunnel',
  ...extra
})

type FakeClient = (typeof ssh2State.state.instances)[number]

/**
 * Returns the first fake ssh2 client constructed in the current test.
 * Throws a descriptive error (rather than a TypeError on `undefined`) when
 * openTunnel never constructed one.
 */
const firstClient = (): FakeClient => {
  const client = ssh2State.state.instances[0]
  if (!client) {
    throw new Error('openTunnel did not construct an ssh2 Client')
  }
  return client
}

describe('expandHome', () => {
  it('expands ~ prefix', () => {
    expect(expandHome('~/.ssh/key')).not.toContain('~')
    expect(expandHome('~/.ssh/key')).toMatch(/\.ssh\/key$/)
  })

  it('leaves absolute paths alone', () => {
    expect(expandHome('/etc/key')).toBe('/etc/key')
  })
})

describe('buildSshAuth', () => {
  // Per-test scratch directory for key files read via privateKeyPath.
  let dir: string

  beforeEach(() => {
    dir = mkdtempSync(join(tmpdir(), 'dbmole-ssh-'))
  })

  afterEach(() => {
    rmSync(dir, { recursive: true, force: true })
  })

  it('prefers inline privateKey', () => {
    const auth = buildSshAuth(sshConfig({ privateKey: 'PEM', password: 'pw' }))
    expect(auth).toEqual({ privateKey: 'PEM', passphrase: undefined })
  })

  it('reads privateKeyPath from disk', () => {
    const keyPath = join(dir, 'id_test')
    writeFileSync(keyPath, 'FILEPEM')
    const auth = buildSshAuth(sshConfig({ privateKeyPath: keyPath, passphrase: 'pp' }))
    expect(auth).toEqual({ privateKey: 'FILEPEM', passphrase: 'pp' })
  })

  it('uses agent when enabled and SSH_AUTH_SOCK set', () => {
    const auth = buildSshAuth(sshConfig({ agent: true }), { SSH_AUTH_SOCK: '/tmp/agent.sock' })
    expect(auth).toEqual({ agent: '/tmp/agent.sock' })
  })

  it('falls back to password', () => {
    expect(buildSshAuth(sshConfig({ password: 'pw' }))).toEqual({ password: 'pw' })
  })

  it('throws when no method available', () => {
    expect(() => buildSshAuth(sshConfig(), {})).toThrow(/privateKey/)
  })
})

describe('openTunnel', () => {
  beforeEach(() => {
    ssh2State.state.instances.length = 0
    ssh2State.state.failConnect = false
  })

  it('forwards local TCP traffic through ssh forwardOut', async () => {
    const tunnel = await openTunnel(sshConfig({ password: 'pw' }), 'db-host', 5432)
    expect(tunnel.localHost).toBe('127.0.0.1')
    expect(tunnel.localPort).toBeGreaterThan(0)

    const echoed = await new Promise<string>((resolve, reject) => {
      const socket = net.connect(tunnel.localPort, tunnel.localHost, () => {
        socket.write('ping')
      })
      socket.once('data', (chunk) => {
        socket.end()
        resolve(chunk.toString())
      })
      socket.once('error', reject)
    })
    // FakeSshClient hands back a PassThrough, so bytes written into the
    // tunnel come straight back: pipe wiring works in both directions.
    expect(echoed).toBe('ping')

    const client = firstClient()
    expect(client.forwardCalls[0]).toEqual(expect.arrayContaining(['db-host', 5432]))
    expect(client.connectConfig).toMatchObject({
      host: 'bastion',
      port: 22,
      username: 'tunnel',
      password: 'pw'
    })

    await tunnel.close()
    expect(client.ended).toBe(true)
  })

  it('rejects when ssh connection fails', async () => {
    ssh2State.state.failConnect = true
    await expect(openTunnel(sshConfig({ password: 'pw' }), 'db', 5432)).rejects.toThrow(
      'auth failed'
    )
  })

  it('close() resolves even with an open forwarded socket', async () => {
    const tunnel = await openTunnel(sshConfig({ password: 'pw' }), 'db-host', 5432)
    const socket = await new Promise<Socket>((resolve, reject) => {
      const s = net.connect(tunnel.localPort, tunnel.localHost, () => {
        s.write('hold')
        resolve(s)
      })
      s.once('error', reject)
    })
    const clientClosed = new Promise<void>((resolve) => {
      if (socket.closed) {
        resolve()
        return
      }
      socket.once('close', () => resolve())
    })

    // Deliberately do NOT end the socket; close() must still resolve.
    await tunnel.close()

    // The forwarded socket on the server side is destroyed, which tears
    // down our client end too.
    await clientClosed
    expect(socket.destroyed).toBe(true)
  })

  it('reports isClosed false for a fresh tunnel', async () => {
    const tunnel = await openTunnel(sshConfig({ password: 'pw' }), 'db-host', 5432)
    expect(tunnel.isClosed()).toBe(false)
    await tunnel.close()
  })

  it('flips isClosed to true when the ssh client closes', async () => {
    const tunnel = await openTunnel(sshConfig({ password: 'pw' }), 'db-host', 5432)
    const client = firstClient()
    client.emit('close')
    expect(tunnel.isClosed()).toBe(true)
    await tunnel.close()
  })
})