From 96404fe0d3e5d68d0cdf399d6a3150d54e56abaa Mon Sep 17 00:00:00 2001 From: smartass Date: Fri, 12 Jun 2026 00:06:30 +0500 Subject: [PATCH] feat: add connection management tools --- src/tools/connections.ts | 137 ++++++++++++++++++++++++ test/unit/tools/connections.test.ts | 157 ++++++++++++++++++++++++++++ 2 files changed, 294 insertions(+) create mode 100644 src/tools/connections.ts create mode 100644 test/unit/tools/connections.test.ts diff --git a/src/tools/connections.ts b/src/tools/connections.ts new file mode 100644 index 0000000..4e81b63 --- /dev/null +++ b/src/tools/connections.ts @@ -0,0 +1,137 @@ +import type { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js' +import * as z from 'zod' +import type { Registry } from '../config/registry.js' +import { + connectionConfigSchema, + defaultPort, + type ResolvedConnection, + sshConfigSchema +} from '../config/types.js' +import type { Manager } from '../db/manager.js' +import { formatDbError } from '../format.js' +import { errorMessage, fail, ok } from './respond.js' + +const patchSchema = z + .object({ + name: z + .string() + .regex(/^[a-zA-Z0-9_-]+$/) + .optional(), + type: z.enum(['postgres', 'mysql']).optional(), + host: z.string().min(1).optional(), + port: z.number().int().positive().nullable().optional(), + user: z.string().min(1).optional(), + password: z.string().nullable().optional(), + database: z.string().nullable().optional(), + readonly: z.boolean().optional(), + ssh: sshConfigSchema.nullable().optional() + }) + .strict() + +const publicView = (resolved: ResolvedConnection) => ({ + name: resolved.config.name, + type: resolved.config.type, + host: resolved.config.host, + port: resolved.config.port ?? defaultPort(resolved.config.type), + database: resolved.config.database ?? null, + readonly: resolved.config.readonly, + source: resolved.source, + ssh: Boolean(resolved.config.ssh) +}) + +export const registerConnectionTools = ( + server: McpServer, + registry: Registry, + manager: Manager +): void => { + server.registerTool( + 'list_connections', + { + description: + 'List configured database connections with their source layer (env/config/store). Secrets are omitted.' + }, + async () => ok(registry.list().map(publicView)) + ) + + server.registerTool( + 'add_connection', + { + description: + 'Add a named postgres/mysql connection to the persistent store. Supports optional ssh tunnel config and a readonly flag.', + inputSchema: connectionConfigSchema.shape + }, + async (config) => { + try { + return ok({ added: publicView(registry.add(connectionConfigSchema.parse(config))) }) + } catch (error) { + return fail(errorMessage(error)) + } + } + ) + + server.registerTool( + 'update_connection', + { + description: + 'Patch a stored connection by name. Set a field to null to remove it (e.g. "ssh": null). Only store-sourced connections can be edited.', + inputSchema: { + name: z.string(), + patch: patchSchema + } + }, + async ({ name, patch }) => { + try { + const updated = registry.update(name, patch) + await manager.invalidate(name) + return ok({ updated: publicView(updated) }) + } catch (error) { + return fail(errorMessage(error)) + } + } + ) + + server.registerTool( + 'remove_connection', + { + description: 'Remove a stored connection by name and drop its cached pools/tunnel.', + inputSchema: { name: z.string() } + }, + async ({ name }) => { + try { + registry.remove(name) + await manager.invalidate(name) + return ok({ removed: name }) + } catch (error) { + return fail(errorMessage(error)) + } + } + ) + + server.registerTool( + 'test_connection', + { + description: + 'Verify a connection works end to end (including ssh tunnel if configured). Reports server version and latency.', + inputSchema: { name: z.string() } + }, + async ({ name }) => { + let managed: Awaited> + try { + managed = await manager.get(name) + } catch (error) { + return fail(errorMessage(error)) + } + const started = performance.now() + try { + const version = await managed.driver.serverVersion() + return ok({ + ok: true, + version, + latencyMs: Math.round(performance.now() - started) + }) + } catch (error) { + return fail(formatDbError(managed.config.type, error)) + } + } + ) +} diff --git a/test/unit/tools/connections.test.ts b/test/unit/tools/connections.test.ts new file mode 100644 index 0000000..635d88c --- /dev/null +++ b/test/unit/tools/connections.test.ts @@ -0,0 +1,157 @@ +import { mkdtempSync, rmSync } from 'node:fs' +import { tmpdir } from 'node:os' +import { join } from 'node:path' +import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import { createRegistry } from '../../../src/config/registry.js' +import { writeStore } from '../../../src/config/store.js' +import type { ConnectionConfig } from '../../../src/config/types.js' +import { registerConnectionTools } from '../../../src/tools/connections.js' +import { callTool, connectClient, type FakeManager, fakeManager } from '../helpers.js' + +const conn = (name: string, extra: Partial = {}): ConnectionConfig => ({ + name, + type: 'postgres', + host: 'localhost', + user: 'postgres', + readonly: false, + ...extra +}) + +describe('connection tools', () => { + let dir: string + let storePath: string + let manager: FakeManager + let client: Awaited> + + beforeEach(async () => { + dir = mkdtempSync(join(tmpdir(), 'dbmole-tools-')) + storePath = join(dir, 'connections.json') + writeStore(storePath, [conn('existing', { password: 'secret', database: 'app' })]) + const registry = createRegistry({ storePath, env: {} }) + manager = fakeManager() + const server = new McpServer({ name: 'test', version: '0.0.0' }) + registerConnectionTools(server, registry, manager) + client = await connectClient(server) + }) + + afterEach(() => { + rmSync(dir, { recursive: true, force: true }) + }) + + it('list_connections returns public view without secrets', async () => { + const response = await callTool(client, 'list_connections') + expect(response.isError).toBe(false) + const list = response.json() as Array> + expect(list).toEqual([ + { + name: 'existing', + type: 'postgres', + host: 'localhost', + port: 5432, + database: 'app', + readonly: false, + source: 'store', + ssh: false + } + ]) + expect(response.text).not.toContain('secret') + }) + + it('add_connection persists and reports the new connection', async () => { + const response = await callTool(client, 'add_connection', { + name: 'fresh', + type: 'mysql', + host: 'db', + user: 'root' + }) + expect(response.isError).toBe(false) + const list = await callTool(client, 'list_connections') + expect((list.json() as Array<{ name: string }>).map((c) => c.name)).toEqual([ + 'existing', + 'fresh' + ]) + }) + + it('add_connection rejects duplicates with isError', async () => { + const response = await callTool(client, 'add_connection', { + name: 'existing', + type: 'postgres', + host: 'x', + user: 'u' + }) + expect(response.isError).toBe(true) + expect(response.text).toContain('already exists') + }) + + it('update_connection patches and invalidates the manager cache', async () => { + const response = await callTool(client, 'update_connection', { + name: 'existing', + patch: { host: 'new-host', database: null } + }) + expect(response.isError).toBe(false) + expect(manager.invalidate).toHaveBeenCalledWith('existing') + const list = await callTool(client, 'list_connections') + const updated = (list.json() as Array>)[0] + expect(updated.host).toBe('new-host') + expect(updated.database).toBeNull() + }) + + it('update_connection surfaces validation errors', async () => { + const response = await callTool(client, 'update_connection', { + name: 'existing', + patch: { type: 'oracle' } + }) + expect(response.isError).toBe(true) + }) + + it('remove_connection deletes and invalidates', async () => { + const response = await callTool(client, 'remove_connection', { name: 'existing' }) + expect(response.isError).toBe(false) + expect(manager.invalidate).toHaveBeenCalledWith('existing') + const list = await callTool(client, 'list_connections') + expect(list.json()).toEqual([]) + }) + + it('remove_connection of unknown name reports available connections', async () => { + const response = await callTool(client, 'remove_connection', { name: 'ghost' }) + expect(response.isError).toBe(true) + expect(response.text).toContain('existing') + }) + + it('test_connection reports version and latency', async () => { + manager.get.mockResolvedValue({ + driver: { serverVersion: vi.fn(async () => '17.2') }, + config: conn('existing'), + source: 'store' + }) + const response = await callTool(client, 'test_connection', { name: 'existing' }) + expect(response.isError).toBe(false) + const payload = response.json() as Record + expect(payload.ok).toBe(true) + expect(payload.version).toBe('17.2') + expect(typeof payload.latencyMs).toBe('number') + }) + + it('test_connection formats driver failures', async () => { + manager.get.mockResolvedValue({ + driver: { + serverVersion: vi.fn(async () => { + throw Object.assign(new Error('connect ECONNREFUSED'), { code: 'ECONNREFUSED' }) + }) + }, + config: conn('existing'), + source: 'store' + }) + const response = await callTool(client, 'test_connection', { name: 'existing' }) + expect(response.isError).toBe(true) + expect(response.text).toContain('ECONNREFUSED') + }) + + it('test_connection reports manager failures (e.g. tunnel)', async () => { + manager.get.mockRejectedValue(new Error('ssh auth failed')) + const response = await callTool(client, 'test_connection', { name: 'existing' }) + expect(response.isError).toBe(true) + expect(response.text).toContain('ssh auth failed') + }) +})