diff --git a/src/db/manager.ts b/src/db/manager.ts index fa21c8e..ae8d826 100644 --- a/src/db/manager.ts +++ b/src/db/manager.ts @@ -55,37 +55,61 @@ export const createManager = (registry: Registry, deps: ManagerDeps = {}): Manag const build = async (config: ConnectionConfig): Promise => { const port = config.port ?? defaultPort(config.type) const tunnel = config.ssh ? await createTunnel(config.ssh, config.host, port) : undefined - const target: DriverTarget = { - config, - host: tunnel ? tunnel.localHost : config.host, - port: tunnel ? tunnel.localPort : port + try { + const target: DriverTarget = { + config, + host: tunnel ? tunnel.localHost : config.host, + port: tunnel ? tunnel.localPort : port + } + return { driver: createDriver(target), tunnel } + } catch (error) { + await tunnel?.close().catch(() => {}) + throw error } - return { driver: createDriver(target), tunnel } } const get = async (name: string): Promise => { - const resolved = registry.resolve(name) - const cached = cache.get(name) - if (cached && cached.hash === resolved.hash) { - const entry = await cached.promise - if (!entry.tunnel?.isClosed()) { - return { driver: entry.driver, config: resolved.config, source: resolved.source } + for (;;) { + const resolved = registry.resolve(name) + const cached = cache.get(name) + if (cached && cached.hash === resolved.hash) { + const entry = await cached.promise + // Recheck ownership: a concurrent invalidate/disposeAll/hash-change + // may have replaced or deleted this slot while we awaited it. + const stillOwner = cache.get(name) === cached + if (stillOwner && !entry.tunnel?.isClosed()) { + return { + driver: entry.driver, + config: resolved.config, + source: resolved.source + } + } + if (stillOwner) { + cache.delete(name) + disposeSlot(cached) + } + continue } - } - if (cached) { - cache.delete(name) - disposeSlot(cached) - } - const slot: CacheSlot = { hash: resolved.hash, promise: build(resolved.config) } - cache.set(name, slot) - try { - const entry = await slot.promise - return { driver: entry.driver, config: resolved.config, source: resolved.source } - } catch (error) { - if (cache.get(name) === slot) { + if (cached) { cache.delete(name) + disposeSlot(cached) + } + const slot: CacheSlot = { hash: resolved.hash, promise: build(resolved.config) } + cache.set(name, slot) + try { + const entry = await slot.promise + if (cache.get(name) !== slot) { + // Lost ownership while building; the slot's disposal is handled + // by whoever replaced it. Retry to get a live entry. + continue + } + return { driver: entry.driver, config: resolved.config, source: resolved.source } + } catch (error) { + if (cache.get(name) === slot) { + cache.delete(name) + } + throw error } - throw error } } diff --git a/src/index.ts b/src/index.ts index aa2f23c..1e99f30 100644 --- a/src/index.ts +++ b/src/index.ts @@ -17,6 +17,8 @@ const main = async () => { return } shuttingDown = true + // Close the transport first so no new request can create pools mid-teardown. + await server.close().catch(() => {}) await manager.disposeAll().catch(() => {}) process.exit(0) } diff --git a/src/tools/connections.ts b/src/tools/connections.ts index 4e81b63..9b40f40 100644 --- a/src/tools/connections.ts +++ b/src/tools/connections.ts @@ -58,7 +58,9 @@ export const registerConnectionTools = ( { description: 'Add a named postgres/mysql connection to the persistent store. Supports optional ssh tunnel config and a readonly flag.', - inputSchema: connectionConfigSchema.shape + // Pass the strict object schema (not .shape) so the SDK rejects unknown + // keys at the boundary instead of silently stripping them. + inputSchema: connectionConfigSchema }, async (config) => { try { @@ -111,17 +113,18 @@ export const registerConnectionTools = ( 'test_connection', { description: - 'Verify a connection works end to end (including ssh tunnel if configured). Reports server version and latency.', + 'Verify a connection works end to end (including ssh tunnel if configured). Reports server version and connect+query latency.', inputSchema: { name: z.string() } }, async ({ name }) => { + // Start the clock before get() so cold tunnel/pool establishment counts. + const started = performance.now() let managed: Awaited> try { managed = await manager.get(name) } catch (error) { return fail(errorMessage(error)) } - const started = performance.now() try { const version = await managed.driver.serverVersion() return ok({ diff --git a/test/unit/db/manager.test.ts b/test/unit/db/manager.test.ts index 635ba34..7da2edb 100644 --- a/test/unit/db/manager.test.ts +++ b/test/unit/db/manager.test.ts @@ -175,12 +175,54 @@ describe('createManager', () => { const pending = manager.get('c') const invalidated = manager.invalidate('c') release() - await pending await invalidated + // The first build is disposed by invalidate; the pending get loses + // ownership and retries, resolving to a freshly rebuilt driver. + expect((await pending).driver).toBe(drivers[1]) expect(drivers[0].disposed).toBe(true) expect(tunnels[0].closed).toBe(true) - const rebuilt = await manager.get('c') - expect(rebuilt.driver).toBe(drivers[1]) + }) + + it('get racing disposeAll does not return a disposed entry', async () => { + let release: () => void = () => {} + const gate = new Promise((resolve) => { + release = resolve + }) + const gatedTunnel = vi.fn(async (ssh: SshConfig, host: string, port: number) => { + await gate + return createTunnel(ssh, host, port) + }) + resolved = { config: config(sshExtra()), source: 'store', hash: 'ssh1' } + const manager = createManager(registry, { createDriver, createTunnel: gatedTunnel }) + const pending = manager.get('c') + const disposed = manager.disposeAll() + release() + await disposed + // The gated build is disposed; the pending get retries and returns a + // fresh entry that was never disposed. + const managed = await pending + expect(managed.driver).toBe(drivers[1]) + expect(drivers[0].disposed).toBe(true) + expect(drivers[1].disposed).toBe(false) + }) + + it('closes the tunnel when createDriver throws and a retry succeeds', async () => { + const throwingDriver = vi + .fn<(target: DriverTarget) => Driver>() + .mockImplementationOnce(() => { + throw new Error('driver boom') + }) + .mockImplementation(createDriver) + resolved = { config: config(sshExtra()), source: 'store', hash: 'ssh1' } + const manager = createManager(registry, { + createDriver: throwingDriver, + createTunnel + }) + await expect(manager.get('c')).rejects.toThrow('driver boom') + expect(tunnels[0].closed).toBe(true) + const retried = await manager.get('c') + expect(retried.driver).toBe(drivers[0]) + expect(tunnels[1].closed).toBe(false) }) it('disposeAll disposes everything', async () => { diff --git a/test/unit/tools/connections.test.ts b/test/unit/tools/connections.test.ts index 635d88c..a66686d 100644 --- a/test/unit/tools/connections.test.ts +++ b/test/unit/tools/connections.test.ts @@ -73,6 +73,20 @@ describe('connection tools', () => { ]) }) + it('add_connection rejects unknown fields', async () => { + const response = await callTool(client, 'add_connection', { + name: 'typo', + type: 'postgres', + host: 'h', + user: 'u', + readOnly: true + }) + expect(response.isError).toBe(true) + const list = await callTool(client, 'list_connections') + const names = (list.json() as Array<{ name: string }>).map((c) => c.name) + expect(names).not.toContain('typo') + }) + it('add_connection rejects duplicates with isError', async () => { const response = await callTool(client, 'add_connection', { name: 'existing',