From 7514aedb1e767f5208f793b4d2683aa36c5d8e8e Mon Sep 17 00:00:00 2001 From: smartass Date: Fri, 12 Jun 2026 00:10:55 +0500 Subject: [PATCH] feat: add schema inspection tools --- src/tools/schema.ts | 74 ++++++++++++++++++++++++++++ test/unit/tools/schema.test.ts | 90 ++++++++++++++++++++++++++++++++++ 2 files changed, 164 insertions(+) create mode 100644 src/tools/schema.ts create mode 100644 test/unit/tools/schema.test.ts diff --git a/src/tools/schema.ts b/src/tools/schema.ts new file mode 100644 index 0000000..dc58b1e --- /dev/null +++ b/src/tools/schema.ts @@ -0,0 +1,74 @@ +import type { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js' +import type { CallToolResult } from '@modelcontextprotocol/sdk/types.js' +import * as z from 'zod' +import type { ManagedConnection, Manager } from '../db/manager.js' +import { formatDbError } from '../format.js' +import { errorMessage, fail, ok } from './respond.js' + +const withManaged = async ( + manager: Manager, + connection: string, + fn: (managed: ManagedConnection) => Promise +): Promise => { + let managed: ManagedConnection + try { + managed = await manager.get(connection) + } catch (error) { + return fail(errorMessage(error)) + } + try { + return ok(await fn(managed)) + } catch (error) { + return fail(formatDbError(managed.config.type, error)) + } +} + +export const registerSchemaTools = (server: McpServer, manager: Manager): void => { + server.registerTool( + 'list_databases', + { + description: + 'List databases on a connection with their size in bytes (system databases hidden).', + inputSchema: { + connection: z.string().describe('connection name, see list_connections') + } + }, + async ({ connection }) => + withManaged(manager, connection, (managed) => managed.driver.listDatabases()) + ) + + server.registerTool( + 'list_tables', + { + description: + 'List tables with approximate row counts. For postgres, optionally filter by schema; for mysql, schema is the database.', + inputSchema: { + connection: z.string(), + database: z.string().optional(), + schema: z.string().optional() + } + }, + async ({ connection, database, schema }) => + withManaged(manager, connection, (managed) => + managed.driver.listTables({ database, schema }) + ) + ) + + server.registerTool( + 'describe_table', + { + description: + 'Describe a table: columns (type, nullable, default), primary key, indexes and foreign keys.', + inputSchema: { + connection: z.string(), + table: z.string(), + database: z.string().optional(), + schema: z.string().optional().describe('postgres schema, defaults to public') + } + }, + async ({ connection, table, database, schema }) => + withManaged(manager, connection, (managed) => + managed.driver.describeTable({ table, database, schema }) + ) + ) +} diff --git a/test/unit/tools/schema.test.ts b/test/unit/tools/schema.test.ts new file mode 100644 index 0000000..0c1759c --- /dev/null +++ b/test/unit/tools/schema.test.ts @@ -0,0 +1,90 @@ +import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { registerSchemaTools } from '../../../src/tools/schema.js' +import { callTool, connectClient, type FakeManager, fakeManager } from '../helpers.js' + +describe('schema tools', () => { + let driver: { + listDatabases: ReturnType + listTables: ReturnType + describeTable: ReturnType + } + let manager: FakeManager + let client: Awaited> + + beforeEach(async () => { + driver = { + listDatabases: vi.fn(async () => [{ name: 'app', sizeBytes: 1024 }]), + listTables: vi.fn(async () => [{ schema: 'public', name: 'users', rowEstimate: 5 }]), + describeTable: vi.fn(async () => ({ + columns: [{ name: 'id', type: 'integer', nullable: false, default: null }], + primaryKey: ['id'], + indexes: [], + foreignKeys: [] + })) + } + manager = fakeManager() + manager.get.mockImplementation(async () => ({ + driver, + config: { name: 'c', type: 'postgres', host: 'h', user: 'u', readonly: false }, + source: 'store' + })) + const server = new McpServer({ name: 'test', version: '0.0.0' }) + registerSchemaTools(server, manager) + client = await connectClient(server) + }) + + it('list_databases returns database info', async () => { + const response = await callTool(client, 'list_databases', { connection: 'c' }) + expect(response.isError).toBe(false) + expect(response.json()).toEqual([{ name: 'app', sizeBytes: 1024 }]) + }) + + it('list_tables forwards database and schema filters', async () => { + const response = await callTool(client, 'list_tables', { + connection: 'c', + database: 'db', + schema: 'public' + }) + expect(response.isError).toBe(false) + expect(driver.listTables).toHaveBeenCalledWith({ database: 'db', schema: 'public' }) + expect(response.json()).toEqual([{ schema: 'public', name: 'users', rowEstimate: 5 }]) + }) + + it('describe_table forwards args and returns the description', async () => { + const response = await callTool(client, 'describe_table', { + connection: 'c', + table: 'users' + }) + expect(response.isError).toBe(false) + expect(driver.describeTable).toHaveBeenCalledWith({ + table: 'users', + database: undefined, + schema: undefined + }) + expect((response.json() as { primaryKey: string[] }).primaryKey).toEqual(['id']) + }) + + it('formats driver errors', async () => { + driver.describeTable.mockRejectedValue(new Error("table 'public.ghost' not found")) + const response = await callTool(client, 'describe_table', { + connection: 'c', + table: 'ghost' + }) + expect(response.isError).toBe(true) + expect(response.text).toContain('not found') + }) + + it('reports manager errors per tool', async () => { + manager.get.mockRejectedValue(new Error('no such connection')) + for (const [tool, args] of [ + ['list_databases', { connection: 'x' }], + ['list_tables', { connection: 'x' }], + ['describe_table', { connection: 'x', table: 't' }] + ] as const) { + const response = await callTool(client, tool, args as Record) + expect(response.isError).toBe(true) + expect(response.text).toContain('no such connection') + } + }) +})