From 8fb44dbfb343d52f95fbdb41ca4aca3844565eab Mon Sep 17 00:00:00 2001 From: smartass Date: Thu, 11 Jun 2026 23:33:38 +0500 Subject: [PATCH] feat: add ssh tunnel with local tcp forward --- src/net/tunnel.ts | 118 +++++++++++++++++++++++++ test/unit/net/tunnel.test.ts | 165 +++++++++++++++++++++++++++++++++++ 2 files changed, 283 insertions(+) create mode 100644 src/net/tunnel.ts create mode 100644 test/unit/net/tunnel.test.ts diff --git a/src/net/tunnel.ts b/src/net/tunnel.ts new file mode 100644 index 0000000..9968c86 --- /dev/null +++ b/src/net/tunnel.ts @@ -0,0 +1,118 @@ +import { readFileSync } from 'node:fs' +import net from 'node:net' +import { homedir } from 'node:os' +import { Client, type ConnectConfig } from 'ssh2' +import type { SshConfig } from '../config/types.js' + +export type Tunnel = { + localHost: string + localPort: number + close: () => Promise +} + +export const expandHome = (path: string): string => { + if (path === '~' || path.startsWith('~/')) { + return homedir() + path.slice(1) + } + return path +} + +export const buildSshAuth = ( + ssh: SshConfig, + env: NodeJS.ProcessEnv = process.env +): Partial => { + if (ssh.privateKey) { + return { privateKey: ssh.privateKey, passphrase: ssh.passphrase } + } + if (ssh.privateKeyPath) { + return { + privateKey: readFileSync(expandHome(ssh.privateKeyPath), 'utf8'), + passphrase: ssh.passphrase + } + } + if (ssh.agent && env.SSH_AUTH_SOCK) { + return { agent: env.SSH_AUTH_SOCK } + } + if (ssh.password) { + return { password: ssh.password } + } + throw new Error( + 'ssh config needs one of: privateKey, privateKeyPath, agent (with SSH_AUTH_SOCK set), password' + ) +} + +export const openTunnel = async ( + ssh: SshConfig, + dstHost: string, + dstPort: number +): Promise => { + const client = new Client() + + await new Promise((resolve, reject) => { + client.once('ready', resolve) + client.once('error', reject) + client.connect({ + host: ssh.host, + port: ssh.port, + username: ssh.user, + ...buildSshAuth(ssh) + }) + }) + + // After 'ready' resolved, later ssh errors must not crash the process. + client.on('error', (error) => { + console.error('dbmole: ssh tunnel error: ' + error.message) + }) + + const server = net.createServer((socket) => { + client.forwardOut( + socket.remoteAddress ?? '127.0.0.1', + socket.remotePort ?? 0, + dstHost, + dstPort, + (error, stream) => { + if (error || !stream) { + socket.destroy() + return + } + socket.pipe(stream) + stream.pipe(socket) + stream.on('close', () => { + socket.destroy() + }) + socket.on('close', () => { + stream.destroy() + }) + stream.on('error', () => { + socket.destroy() + }) + socket.on('error', () => { + stream.destroy() + }) + } + ) + }) + + try { + await new Promise((resolve, reject) => { + server.once('error', reject) + server.listen(0, '127.0.0.1', resolve) + }) + } catch (error) { + client.end() + throw error + } + + const address = server.address() as net.AddressInfo + + const close = async () => { + await new Promise((resolve) => { + server.close(() => { + resolve() + }) + }) + client.end() + } + + return { localHost: '127.0.0.1', localPort: address.port, close } +} diff --git a/test/unit/net/tunnel.test.ts b/test/unit/net/tunnel.test.ts new file mode 100644 index 0000000..d8f396f --- /dev/null +++ b/test/unit/net/tunnel.test.ts @@ -0,0 +1,165 @@ +import { mkdtempSync, rmSync, writeFileSync } from 'node:fs' +import net from 'node:net' +import { tmpdir } from 'node:os' +import { join } from 'node:path' +import { PassThrough } from 'node:stream' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' + +const ssh2State = vi.hoisted(() => { + // vi.hoisted runs before static imports, so EventEmitter must be pulled in + // here via require (injected by Vitest's transform) rather than a top-level + // import, which is not yet initialised at hoist time. + const { EventEmitter } = require('node:events') as typeof import('node:events') + + const state = { + instances: [] as FakeSshClient[], + failConnect: false + } + + class FakeSshClient extends EventEmitter { + connectConfig: Record | undefined + forwardCalls: unknown[][] = [] + ended = false + + connect(config: Record) { + this.connectConfig = config + process.nextTick(() => { + if (state.failConnect) { + this.emit('error', new Error('auth failed')) + } else { + this.emit('ready') + } + }) + } + + forwardOut( + srcHost: string, + srcPort: number, + dstHost: string, + dstPort: number, + callback: (err: Error | null, stream?: PassThrough) => void + ) { + this.forwardCalls.push([srcHost, srcPort, dstHost, dstPort]) + callback(null, new PassThrough()) + } + + end() { + this.ended = true + } + } + + return { state, FakeSshClient } +}) + +vi.mock('ssh2', () => ({ + Client: class extends ssh2State.FakeSshClient { + constructor() { + super() + ssh2State.state.instances.push(this) + } + } +})) + +import type { SshConfig } from '../../../src/config/types.js' +import { buildSshAuth, expandHome, openTunnel } from '../../../src/net/tunnel.js' + +const sshConfig = (extra: Partial = {}): SshConfig => ({ + host: 'bastion', + port: 22, + user: 'tunnel', + ...extra +}) + +describe('expandHome', () => { + it('expands ~ prefix', () => { + expect(expandHome('~/.ssh/key')).not.toContain('~') + expect(expandHome('~/.ssh/key')).toMatch(/\.ssh\/key$/) + }) + + it('leaves absolute paths alone', () => { + expect(expandHome('/etc/key')).toBe('/etc/key') + }) +}) + +describe('buildSshAuth', () => { + let dir: string + + beforeEach(() => { + dir = mkdtempSync(join(tmpdir(), 'dbmole-ssh-')) + }) + + afterEach(() => { + rmSync(dir, { recursive: true, force: true }) + }) + + it('prefers inline privateKey', () => { + const auth = buildSshAuth(sshConfig({ privateKey: 'PEM', password: 'pw' })) + expect(auth).toEqual({ privateKey: 'PEM', passphrase: undefined }) + }) + + it('reads privateKeyPath from disk', () => { + const keyPath = join(dir, 'id_test') + writeFileSync(keyPath, 'FILEPEM') + const auth = buildSshAuth(sshConfig({ privateKeyPath: keyPath, passphrase: 'pp' })) + expect(auth).toEqual({ privateKey: 'FILEPEM', passphrase: 'pp' }) + }) + + it('uses agent when enabled and SSH_AUTH_SOCK set', () => { + const auth = buildSshAuth(sshConfig({ agent: true }), { SSH_AUTH_SOCK: '/tmp/agent.sock' }) + expect(auth).toEqual({ agent: '/tmp/agent.sock' }) + }) + + it('falls back to password', () => { + expect(buildSshAuth(sshConfig({ password: 'pw' }))).toEqual({ password: 'pw' }) + }) + + it('throws when no method available', () => { + expect(() => buildSshAuth(sshConfig(), {})).toThrow(/privateKey/) + }) +}) + +describe('openTunnel', () => { + beforeEach(() => { + ssh2State.state.instances.length = 0 + ssh2State.state.failConnect = false + }) + + it('forwards local TCP traffic through ssh forwardOut', async () => { + const tunnel = await openTunnel(sshConfig({ password: 'pw' }), 'db-host', 5432) + expect(tunnel.localHost).toBe('127.0.0.1') + expect(tunnel.localPort).toBeGreaterThan(0) + + const echoed = await new Promise((resolve, reject) => { + const socket = net.connect(tunnel.localPort, tunnel.localHost, () => { + socket.write('ping') + }) + socket.once('data', (chunk) => { + socket.end() + resolve(chunk.toString()) + }) + socket.once('error', reject) + }) + + // FakeSshClient hands back a PassThrough, so bytes written into the + // tunnel come straight back: pipe wiring works in both directions. + expect(echoed).toBe('ping') + const client = ssh2State.state.instances[0] + expect(client.forwardCalls[0]).toEqual(expect.arrayContaining(['db-host', 5432])) + expect(client.connectConfig).toMatchObject({ + host: 'bastion', + port: 22, + username: 'tunnel', + password: 'pw' + }) + + await tunnel.close() + expect(client.ended).toBe(true) + }) + + it('rejects when ssh connection fails', async () => { + ssh2State.state.failConnect = true + await expect(openTunnel(sshConfig({ password: 'pw' }), 'db', 5432)).rejects.toThrow( + 'auth failed' + ) + }) +})