feat: add execute_sql tool
This commit is contained in:
@@ -0,0 +1,48 @@
|
|||||||
|
import type { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js'
|
||||||
|
import * as z from 'zod'
|
||||||
|
import type { Manager } from '../db/manager.js'
|
||||||
|
import { clampRowLimit, formatDbError, MAX_ROW_LIMIT } from '../format.js'
|
||||||
|
import { errorMessage, fail, ok } from './respond.js'
|
||||||
|
|
||||||
|
export const registerQueryTools = (server: McpServer, manager: Manager): void => {
|
||||||
|
server.registerTool(
|
||||||
|
'execute_sql',
|
||||||
|
{
|
||||||
|
description:
|
||||||
|
'Execute a single SQL statement on a named connection. SELECT returns columns/rows; ' +
|
||||||
|
'DML returns rowCount (and lastInsertId for mysql). Use positional params: $1.. for postgres, ? for mysql.',
|
||||||
|
inputSchema: {
|
||||||
|
connection: z.string().describe('connection name, see list_connections'),
|
||||||
|
sql: z.string().min(1),
|
||||||
|
params: z
|
||||||
|
.array(z.union([z.string(), z.number(), z.boolean(), z.null()]))
|
||||||
|
.optional()
|
||||||
|
.describe('positional query parameters'),
|
||||||
|
database: z
|
||||||
|
.string()
|
||||||
|
.optional()
|
||||||
|
.describe('database to run against; defaults to the connection default'),
|
||||||
|
rowLimit: z.number().int().min(1).max(MAX_ROW_LIMIT).optional()
|
||||||
|
}
|
||||||
|
},
|
||||||
|
async ({ connection, sql, params, database, rowLimit }) => {
|
||||||
|
let managed: Awaited<ReturnType<Manager['get']>>
|
||||||
|
try {
|
||||||
|
managed = await manager.get(connection)
|
||||||
|
} catch (error) {
|
||||||
|
return fail(errorMessage(error))
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
const result = await managed.driver.query({
|
||||||
|
sql,
|
||||||
|
params: params ?? [],
|
||||||
|
database,
|
||||||
|
rowLimit: clampRowLimit(rowLimit)
|
||||||
|
})
|
||||||
|
return ok(result)
|
||||||
|
} catch (error) {
|
||||||
|
return fail(formatDbError(managed.config.type, error))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
}
|
||||||
@@ -0,0 +1,85 @@
|
|||||||
|
import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js'
|
||||||
|
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||||
|
import { registerQueryTools } from '../../../src/tools/query.js'
|
||||||
|
import { callTool, connectClient, type FakeManager, fakeManager } from '../helpers.js'
|
||||||
|
|
||||||
|
describe('execute_sql', () => {
|
||||||
|
let manager: FakeManager
|
||||||
|
let query: ReturnType<typeof vi.fn>
|
||||||
|
let client: Awaited<ReturnType<typeof connectClient>>
|
||||||
|
|
||||||
|
beforeEach(async () => {
|
||||||
|
query = vi.fn(async () => ({
|
||||||
|
columns: ['n'],
|
||||||
|
rows: [[1]],
|
||||||
|
rowCount: 1,
|
||||||
|
truncated: false
|
||||||
|
}))
|
||||||
|
manager = fakeManager()
|
||||||
|
manager.get.mockImplementation(async () => ({
|
||||||
|
driver: { query },
|
||||||
|
config: { name: 'c', type: 'postgres', host: 'h', user: 'u', readonly: false },
|
||||||
|
source: 'store'
|
||||||
|
}))
|
||||||
|
const server = new McpServer({ name: 'test', version: '0.0.0' })
|
||||||
|
registerQueryTools(server, manager)
|
||||||
|
client = await connectClient(server)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('executes sql and returns the result payload', async () => {
|
||||||
|
const response = await callTool(client, 'execute_sql', {
|
||||||
|
connection: 'c',
|
||||||
|
sql: 'select 1 as n'
|
||||||
|
})
|
||||||
|
expect(response.isError).toBe(false)
|
||||||
|
expect(response.json()).toEqual({
|
||||||
|
columns: ['n'],
|
||||||
|
rows: [[1]],
|
||||||
|
rowCount: 1,
|
||||||
|
truncated: false
|
||||||
|
})
|
||||||
|
expect(query).toHaveBeenCalledWith({
|
||||||
|
sql: 'select 1 as n',
|
||||||
|
params: [],
|
||||||
|
database: undefined,
|
||||||
|
rowLimit: 100
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('passes params, database and clamped rowLimit through', async () => {
|
||||||
|
await callTool(client, 'execute_sql', {
|
||||||
|
connection: 'c',
|
||||||
|
sql: 'select $1',
|
||||||
|
params: [42],
|
||||||
|
database: 'other',
|
||||||
|
rowLimit: 5
|
||||||
|
})
|
||||||
|
expect(query).toHaveBeenCalledWith({
|
||||||
|
sql: 'select $1',
|
||||||
|
params: [42],
|
||||||
|
database: 'other',
|
||||||
|
rowLimit: 5
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('formats driver errors with the engine prefix', async () => {
|
||||||
|
query.mockRejectedValue(
|
||||||
|
Object.assign(new Error('relation "x" does not exist'), { code: '42P01' })
|
||||||
|
)
|
||||||
|
const response = await callTool(client, 'execute_sql', { connection: 'c', sql: 'select' })
|
||||||
|
expect(response.isError).toBe(true)
|
||||||
|
expect(response.text).toBe('[postgres 42P01] relation "x" does not exist')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('reports unknown connections without engine prefix', async () => {
|
||||||
|
manager.get.mockRejectedValue(
|
||||||
|
new Error("connection 'nope' not found. Available connections: c")
|
||||||
|
)
|
||||||
|
const response = await callTool(client, 'execute_sql', {
|
||||||
|
connection: 'nope',
|
||||||
|
sql: 'select 1'
|
||||||
|
})
|
||||||
|
expect(response.isError).toBe(true)
|
||||||
|
expect(response.text).toContain("connection 'nope' not found")
|
||||||
|
})
|
||||||
|
})
|
||||||
Reference in New Issue
Block a user