From 52baa59f97b32f2854eebb38d650c799d3138503 Mon Sep 17 00:00:00 2001 From: smartass Date: Fri, 12 Jun 2026 00:01:50 +0500 Subject: [PATCH] feat: add manager with hash cache, tunnel rebuild --- src/db/manager.ts | 108 +++++++++++++++++++ test/unit/db/manager.test.ts | 194 +++++++++++++++++++++++++++++++++++ 2 files changed, 302 insertions(+) create mode 100644 src/db/manager.ts create mode 100644 test/unit/db/manager.test.ts diff --git a/src/db/manager.ts b/src/db/manager.ts new file mode 100644 index 0000000..fa21c8e --- /dev/null +++ b/src/db/manager.ts @@ -0,0 +1,108 @@ +import type { Registry } from '../config/registry.js' +import { type ConnectionConfig, type ConnectionSource, defaultPort } from '../config/types.js' +import { openTunnel, type Tunnel } from '../net/tunnel.js' +import type { Driver, DriverTarget } from './driver.js' +import { createMysqlDriver } from './mysql.js' +import { createPostgresDriver } from './postgres.js' + +type Entry = { + driver: Driver + tunnel?: Tunnel +} + +type CacheSlot = { + hash: string + promise: Promise +} + +export type ManagerDeps = { + createDriver?: (target: DriverTarget) => Driver + createTunnel?: typeof openTunnel +} + +export type ManagedConnection = { + driver: Driver + config: ConnectionConfig + source: ConnectionSource +} + +export type Manager = { + get: (name: string) => Promise + invalidate: (name: string) => Promise + disposeAll: () => Promise +} + +const defaultCreateDriver = (target: DriverTarget): Driver => + target.config.type === 'postgres' ? createPostgresDriver(target) : createMysqlDriver(target) + +export const createManager = (registry: Registry, deps: ManagerDeps = {}): Manager => { + const createDriver = deps.createDriver ?? defaultCreateDriver + const createTunnel = deps.createTunnel ?? openTunnel + const cache = new Map() + + const disposeEntry = async (entry: Entry): Promise => { + // Order matters: pg/mysql2 pool.end() waits for in-flight queries to + // finish, so the driver drains BEFORE the tunnel closes. New work + // lands on the rebuilt entry, not this one. + await entry.driver.dispose().catch(() => {}) + await entry.tunnel?.close().catch(() => {}) + } + + const disposeSlot = (slot: CacheSlot): void => { + slot.promise.then(disposeEntry).catch(() => {}) + } + + const build = async (config: ConnectionConfig): Promise => { + const port = config.port ?? defaultPort(config.type) + const tunnel = config.ssh ? await createTunnel(config.ssh, config.host, port) : undefined + const target: DriverTarget = { + config, + host: tunnel ? tunnel.localHost : config.host, + port: tunnel ? tunnel.localPort : port + } + return { driver: createDriver(target), tunnel } + } + + const get = async (name: string): Promise => { + const resolved = registry.resolve(name) + const cached = cache.get(name) + if (cached && cached.hash === resolved.hash) { + const entry = await cached.promise + if (!entry.tunnel?.isClosed()) { + return { driver: entry.driver, config: resolved.config, source: resolved.source } + } + } + if (cached) { + cache.delete(name) + disposeSlot(cached) + } + const slot: CacheSlot = { hash: resolved.hash, promise: build(resolved.config) } + cache.set(name, slot) + try { + const entry = await slot.promise + return { driver: entry.driver, config: resolved.config, source: resolved.source } + } catch (error) { + if (cache.get(name) === slot) { + cache.delete(name) + } + throw error + } + } + + const invalidate = async (name: string): Promise => { + const slot = cache.get(name) + if (!slot) { + return + } + cache.delete(name) + await slot.promise.then(disposeEntry).catch(() => {}) + } + + const disposeAll = async (): Promise => { + const slots = [...cache.values()] + cache.clear() + await Promise.all(slots.map((slot) => slot.promise.then(disposeEntry).catch(() => {}))) + } + + return { get, invalidate, disposeAll } +} diff --git a/test/unit/db/manager.test.ts b/test/unit/db/manager.test.ts new file mode 100644 index 0000000..635ba34 --- /dev/null +++ b/test/unit/db/manager.test.ts @@ -0,0 +1,194 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' +import type { Registry } from '../../../src/config/registry.js' +import type { ConnectionConfig, ResolvedConnection, SshConfig } from '../../../src/config/types.js' +import type { Driver, DriverTarget } from '../../../src/db/driver.js' +import { createManager } from '../../../src/db/manager.js' +import type { Tunnel } from '../../../src/net/tunnel.js' + +const config = (extra: Partial = {}): ConnectionConfig => ({ + name: 'c', + type: 'postgres', + host: 'db-host', + user: 'u', + readonly: false, + ...extra +}) + +const fakeDriver = (): Driver & { disposed: boolean } => { + const driver = { + disposed: false, + query: vi.fn(), + listDatabases: vi.fn(), + listTables: vi.fn(), + describeTable: vi.fn(), + serverVersion: vi.fn(), + dispose: vi.fn(async () => { + driver.disposed = true + }) + } + return driver +} + +type FakeTunnel = Tunnel & { closed: boolean } + +describe('createManager', () => { + let resolved: ResolvedConnection + let registry: Registry + let drivers: Array + let targets: DriverTarget[] + let tunnels: FakeTunnel[] + let createDriver: (target: DriverTarget) => Driver + let createTunnel: ReturnType + + beforeEach(() => { + resolved = { config: config(), source: 'store', hash: 'h1' } + registry = { + list: vi.fn(() => [resolved]), + resolve: vi.fn(() => resolved), + add: vi.fn(), + update: vi.fn(), + remove: vi.fn() + } + drivers = [] + targets = [] + tunnels = [] + createDriver = (target: DriverTarget) => { + targets.push(target) + const driver = fakeDriver() + drivers.push(driver) + return driver + } + createTunnel = vi.fn(async () => { + const tunnel: FakeTunnel = { + localHost: '127.0.0.1', + localPort: 54321, + closed: false, + isClosed: () => tunnel.closed, + close: async () => { + tunnel.closed = true + } + } + tunnels.push(tunnel) + return tunnel + }) + }) + + const sshExtra = (): { ssh: SshConfig } => ({ + ssh: { host: 'bastion', port: 22, user: 'root', password: 'x' } + }) + + it('builds lazily and caches by hash', async () => { + const manager = createManager(registry, { createDriver, createTunnel }) + const first = await manager.get('c') + const second = await manager.get('c') + expect(first.driver).toBe(second.driver) + expect(drivers).toHaveLength(1) + expect(targets[0]).toMatchObject({ host: 'db-host', port: 5432 }) + expect(createTunnel).not.toHaveBeenCalled() + }) + + it('rebuilds and disposes the old driver when the hash changes', async () => { + const manager = createManager(registry, { createDriver, createTunnel }) + const first = await manager.get('c') + resolved = { config: config({ port: 9999 }), source: 'store', hash: 'h2' } + const second = await manager.get('c') + expect(second.driver).not.toBe(first.driver) + expect(drivers[0].disposed).toBe(true) + expect(targets[1]).toMatchObject({ port: 9999 }) + }) + + it('routes through the tunnel endpoint when ssh is configured', async () => { + resolved = { config: config(sshExtra()), source: 'store', hash: 'ssh1' } + const manager = createManager(registry, { createDriver, createTunnel }) + await manager.get('c') + expect(createTunnel).toHaveBeenCalledWith( + expect.objectContaining({ host: 'bastion' }), + 'db-host', + 5432 + ) + expect(targets[0]).toMatchObject({ host: '127.0.0.1', port: 54321 }) + }) + + it('uses the engine default port for mysql', async () => { + resolved = { config: config({ type: 'mysql' }), source: 'store', hash: 'm1' } + const manager = createManager(registry, { createDriver, createTunnel }) + await manager.get('c') + expect(targets[0]).toMatchObject({ port: 3306 }) + }) + + it('rebuilds when the tunnel reports closed', async () => { + resolved = { config: config(sshExtra()), source: 'store', hash: 'ssh1' } + const manager = createManager(registry, { createDriver, createTunnel }) + const first = await manager.get('c') + tunnels[0].closed = true + const second = await manager.get('c') + expect(second.driver).not.toBe(first.driver) + expect(drivers[0].disposed).toBe(true) + expect(tunnels).toHaveLength(2) + expect(second.driver).toBe(drivers[1]) + }) + + it('deduplicates concurrent builds', async () => { + const manager = createManager(registry, { createDriver, createTunnel }) + const [a, b] = await Promise.all([manager.get('c'), manager.get('c')]) + expect(a.driver).toBe(b.driver) + expect(drivers).toHaveLength(1) + }) + + it('clears the cache when a build fails, allowing retry', async () => { + const failing = vi + .fn() + .mockRejectedValueOnce(new Error('tunnel down')) + .mockImplementation(createTunnel) + resolved = { config: config(sshExtra()), source: 'store', hash: 'ssh1' } + const manager = createManager(registry, { createDriver, createTunnel: failing }) + await expect(manager.get('c')).rejects.toThrow('tunnel down') + const retried = await manager.get('c') + expect(retried.driver).toBe(drivers[0]) + }) + + it('invalidate disposes driver and tunnel', async () => { + resolved = { config: config(sshExtra()), source: 'store', hash: 'ssh1' } + const manager = createManager(registry, { createDriver, createTunnel }) + await manager.get('c') + await manager.invalidate('c') + expect(drivers[0].disposed).toBe(true) + expect(tunnels[0].closed).toBe(true) + }) + + it('invalidate of an unknown name is a no-op', async () => { + const manager = createManager(registry, { createDriver, createTunnel }) + await expect(manager.invalidate('ghost')).resolves.toBeUndefined() + }) + + it('invalidate during an in-flight build disposes the entry once built', async () => { + let release: () => void = () => {} + const gate = new Promise((resolve) => { + release = resolve + }) + const gatedTunnel = vi.fn(async (ssh: SshConfig, host: string, port: number) => { + await gate + return createTunnel(ssh, host, port) + }) + resolved = { config: config(sshExtra()), source: 'store', hash: 'ssh1' } + const manager = createManager(registry, { createDriver, createTunnel: gatedTunnel }) + const pending = manager.get('c') + const invalidated = manager.invalidate('c') + release() + await pending + await invalidated + expect(drivers[0].disposed).toBe(true) + expect(tunnels[0].closed).toBe(true) + const rebuilt = await manager.get('c') + expect(rebuilt.driver).toBe(drivers[1]) + }) + + it('disposeAll disposes everything', async () => { + const manager = createManager(registry, { createDriver, createTunnel }) + await manager.get('c') + await manager.disposeAll() + expect(drivers[0].disposed).toBe(true) + const again = await manager.get('c') + expect(again.driver).toBe(drivers[1]) + }) +})