diff --git a/src/tools/query.ts b/src/tools/query.ts new file mode 100644 index 0000000..56b524f --- /dev/null +++ b/src/tools/query.ts @@ -0,0 +1,48 @@ +import type { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js' +import * as z from 'zod' +import type { Manager } from '../db/manager.js' +import { clampRowLimit, formatDbError, MAX_ROW_LIMIT } from '../format.js' +import { errorMessage, fail, ok } from './respond.js' + +export const registerQueryTools = (server: McpServer, manager: Manager): void => { + server.registerTool( + 'execute_sql', + { + description: + 'Execute a single SQL statement on a named connection. SELECT returns columns/rows; ' + + 'DML returns rowCount (and lastInsertId for mysql). Use positional params: $1.. for postgres, ? for mysql.', + inputSchema: { + connection: z.string().describe('connection name, see list_connections'), + sql: z.string().min(1), + params: z + .array(z.union([z.string(), z.number(), z.boolean(), z.null()])) + .optional() + .describe('positional query parameters'), + database: z + .string() + .optional() + .describe('database to run against; defaults to the connection default'), + rowLimit: z.number().int().min(1).max(MAX_ROW_LIMIT).optional() + } + }, + async ({ connection, sql, params, database, rowLimit }) => { + let managed: Awaited> + try { + managed = await manager.get(connection) + } catch (error) { + return fail(errorMessage(error)) + } + try { + const result = await managed.driver.query({ + sql, + params: params ?? [], + database, + rowLimit: clampRowLimit(rowLimit) + }) + return ok(result) + } catch (error) { + return fail(formatDbError(managed.config.type, error)) + } + } + ) +} diff --git a/test/unit/tools/query.test.ts b/test/unit/tools/query.test.ts new file mode 100644 index 0000000..644eb69 --- /dev/null +++ b/test/unit/tools/query.test.ts @@ -0,0 +1,85 @@ +import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { registerQueryTools } from '../../../src/tools/query.js' +import { callTool, connectClient, type FakeManager, fakeManager } from '../helpers.js' + +describe('execute_sql', () => { + let manager: FakeManager + let query: ReturnType + let client: Awaited> + + beforeEach(async () => { + query = vi.fn(async () => ({ + columns: ['n'], + rows: [[1]], + rowCount: 1, + truncated: false + })) + manager = fakeManager() + manager.get.mockImplementation(async () => ({ + driver: { query }, + config: { name: 'c', type: 'postgres', host: 'h', user: 'u', readonly: false }, + source: 'store' + })) + const server = new McpServer({ name: 'test', version: '0.0.0' }) + registerQueryTools(server, manager) + client = await connectClient(server) + }) + + it('executes sql and returns the result payload', async () => { + const response = await callTool(client, 'execute_sql', { + connection: 'c', + sql: 'select 1 as n' + }) + expect(response.isError).toBe(false) + expect(response.json()).toEqual({ + columns: ['n'], + rows: [[1]], + rowCount: 1, + truncated: false + }) + expect(query).toHaveBeenCalledWith({ + sql: 'select 1 as n', + params: [], + database: undefined, + rowLimit: 100 + }) + }) + + it('passes params, database and clamped rowLimit through', async () => { + await callTool(client, 'execute_sql', { + connection: 'c', + sql: 'select $1', + params: [42], + database: 'other', + rowLimit: 5 + }) + expect(query).toHaveBeenCalledWith({ + sql: 'select $1', + params: [42], + database: 'other', + rowLimit: 5 + }) + }) + + it('formats driver errors with the engine prefix', async () => { + query.mockRejectedValue( + Object.assign(new Error('relation "x" does not exist'), { code: '42P01' }) + ) + const response = await callTool(client, 'execute_sql', { connection: 'c', sql: 'select' }) + expect(response.isError).toBe(true) + expect(response.text).toBe('[postgres 42P01] relation "x" does not exist') + }) + + it('reports unknown connections without engine prefix', async () => { + manager.get.mockRejectedValue( + new Error("connection 'nope' not found. Available connections: c") + ) + const response = await callTool(client, 'execute_sql', { + connection: 'nope', + sql: 'select 1' + }) + expect(response.isError).toBe(true) + expect(response.text).toContain("connection 'nope' not found") + }) +})