diff --git a/src/db/mysql.ts b/src/db/mysql.ts new file mode 100644 index 0000000..0651f76 --- /dev/null +++ b/src/db/mysql.ts @@ -0,0 +1,247 @@ +import mysql from 'mysql2/promise' +import { normalizeCell, truncateRows } from '../format.js' +import { + type DatabaseInfo, + type Driver, + type DriverTarget, + type QueryArgs, + type QueryResult, + resolveDatabase, + type TableDescription, + type TableInfo +} from './driver.js' + +type RowsAndFields = [unknown, Array<{ name: string }> | undefined] + +const NO_DATABASE = '' + +export const createMysqlDriver = (target: DriverTarget): Driver => { + const pools = new Map() + const readonlyApplied = new WeakSet() + + const getPool = (database: string | null): mysql.Pool => { + const key = database ?? NO_DATABASE + const existing = pools.get(key) + if (existing) { + return existing + } + const pool = mysql.createPool({ + host: target.host, + port: target.port, + user: target.config.user, + password: target.config.password, + ...(database ? { database } : {}), + connectionLimit: 4, + multipleStatements: false + }) + pools.set(key, pool) + return pool + } + + const withConnection = async ( + database: string | null, + fn: (conn: mysql.PoolConnection) => Promise + ): Promise => { + const conn = await getPool(database).getConnection() + try { + if (target.config.readonly && !readonlyApplied.has(conn)) { + await conn.query('SET SESSION TRANSACTION READ ONLY') + readonlyApplied.add(conn) + } + return await fn(conn) + } finally { + conn.release() + } + } + + const runArray = async ( + database: string | null, + sql: string, + params: unknown[] = [] + ): Promise<{ columns: string[]; rows: unknown[][] }> => + withConnection(database, async (conn) => { + const [rows, fields] = (await conn.query({ + sql, + values: params, + rowsAsArray: true + })) as RowsAndFields + return { + columns: (fields ?? []).map((field) => field.name), + rows: (rows as unknown[][]) ?? [] + } + }) + + const query = async ({ + sql, + params = [], + database, + rowLimit + }: QueryArgs): Promise => { + const db = resolveDatabase(target.config, database) + return withConnection(db, async (conn) => { + const [result, fields] = (await conn.query({ + sql, + values: params, + rowsAsArray: true + })) as RowsAndFields + if (Array.isArray(result)) { + const { rows, truncated } = truncateRows(result as unknown[][], rowLimit) + return { + columns: (fields ?? []).map((field) => field.name), + rows: rows.map((row) => row.map(normalizeCell)), + rowCount: result.length, + truncated + } + } + const header = result as mysql.ResultSetHeader + return { + columns: [], + rows: [], + rowCount: header.affectedRows ?? 0, + truncated: false, + ...(header.insertId ? { lastInsertId: String(header.insertId) } : {}) + } + }) + } + + const listDatabases = async (): Promise => { + const result = await runArray( + null, + ` + SELECT s.schema_name, COALESCE(SUM(t.data_length + t.index_length), 0) + FROM information_schema.schemata s + LEFT JOIN information_schema.tables t ON t.table_schema = s.schema_name + WHERE s.schema_name NOT IN ('mysql', 'sys', 'performance_schema', 'information_schema') + GROUP BY s.schema_name + ORDER BY s.schema_name + ` + ) + return result.rows.map((row) => ({ + name: String(row[0]), + sizeBytes: row[1] === null ? null : Number(row[1]) + })) + } + + const effectiveSchema = (args: { database?: string; schema?: string }): string => + args.schema ?? resolveDatabase(target.config, args.database) + + const listTables = async (args: { + database?: string + schema?: string + }): Promise => { + const schema = effectiveSchema(args) + const result = await runArray( + null, + ` + SELECT table_schema, table_name, table_rows + FROM information_schema.tables + WHERE table_schema = ? AND table_type = 'BASE TABLE' + ORDER BY table_name + `, + [schema] + ) + return result.rows.map((row) => ({ + schema: String(row[0]), + name: String(row[1]), + rowEstimate: row[2] === null ? null : Number(row[2]) + })) + } + + const describeTable = async (args: { + table: string + database?: string + schema?: string + }): Promise => { + const schema = effectiveSchema(args) + + const columns = await runArray( + null, + ` + SELECT column_name, column_type, is_nullable, column_default + FROM information_schema.columns + WHERE table_schema = ? AND table_name = ? + ORDER BY ordinal_position + `, + [schema, args.table] + ) + if (columns.rows.length === 0) { + throw new Error("table '" + schema + '.' + args.table + "' not found") + } + + const indexRows = await runArray( + null, + ` + SELECT index_name, non_unique, column_name + FROM information_schema.statistics + WHERE table_schema = ? AND table_name = ? + ORDER BY index_name, seq_in_index + `, + [schema, args.table] + ) + + const fkRows = await runArray( + null, + ` + SELECT constraint_name, column_name, referenced_table_name, referenced_column_name + FROM information_schema.key_column_usage + WHERE table_schema = ? AND table_name = ? AND referenced_table_name IS NOT NULL + ORDER BY constraint_name, ordinal_position + `, + [schema, args.table] + ) + + const primaryKey: string[] = [] + const indexMap = new Map() + for (const row of indexRows.rows) { + const name = String(row[0]) + if (name === 'PRIMARY') { + primaryKey.push(String(row[2])) + continue + } + const entry = indexMap.get(name) ?? { unique: !Number(row[1]), columns: [] } + entry.columns.push(String(row[2])) + indexMap.set(name, entry) + } + + const fkMap = new Map< + string, + { columns: string[]; referencedTable: string; referencedColumns: string[] } + >() + for (const row of fkRows.rows) { + const name = String(row[0]) + const entry = fkMap.get(name) ?? { + columns: [], + referencedTable: String(row[2]), + referencedColumns: [] + } + entry.columns.push(String(row[1])) + entry.referencedColumns.push(String(row[3])) + fkMap.set(name, entry) + } + + return { + columns: columns.rows.map((row) => ({ + name: String(row[0]), + type: String(row[1]), + nullable: row[2] === 'YES', + default: row[3] === null ? null : String(row[3]) + })), + primaryKey, + indexes: [...indexMap.entries()].map(([name, entry]) => ({ name, ...entry })), + foreignKeys: [...fkMap.entries()].map(([name, entry]) => ({ name, ...entry })) + } + } + + const serverVersion = async (): Promise => { + const result = await runArray(null, 'SELECT VERSION()') + return String(result.rows[0]?.[0] ?? 'unknown') + } + + const dispose = async (): Promise => { + const all = [...pools.values()] + pools.clear() + await Promise.all(all.map((pool) => pool.end().catch(() => {}))) + } + + return { query, listDatabases, listTables, describeTable, serverVersion, dispose } +} diff --git a/test/unit/db/mysql.test.ts b/test/unit/db/mysql.test.ts new file mode 100644 index 0000000..a1819da --- /dev/null +++ b/test/unit/db/mysql.test.ts @@ -0,0 +1,174 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' + +const mysqlState = vi.hoisted(() => { + const state = { + pools: [] as FakeMysqlPool[], + nextResult: undefined as unknown, + nextFields: undefined as unknown + } + + class FakeConnection { + queries: unknown[] = [] + released = 0 + + async query(args: unknown) { + this.queries.push(args) + return [state.nextResult ?? [], state.nextFields ?? []] + } + + release() { + this.released += 1 + } + } + + class FakeMysqlPool { + options: Record + connection = new FakeConnection() + ended = false + + constructor(options: Record) { + this.options = options + state.pools.push(this) + } + + async getConnection() { + return this.connection + } + + async end() { + this.ended = true + } + } + + return { state, FakeMysqlPool } +}) + +vi.mock('mysql2/promise', () => ({ + default: { + createPool: (options: Record) => new mysqlState.FakeMysqlPool(options) + } +})) + +import type { ConnectionConfig } from '../../../src/config/types.js' +import { createMysqlDriver } from '../../../src/db/mysql.js' + +const config = (extra: Partial = {}): ConnectionConfig => ({ + name: 'my', + type: 'mysql', + host: 'real-host', + user: 'root', + password: 'pw', + database: 'main', + readonly: false, + ...extra +}) + +const target = (extra: Partial = {}) => ({ + config: config(extra), + host: '127.0.0.1', + port: 13306 +}) + +describe('createMysqlDriver', () => { + beforeEach(() => { + mysqlState.state.pools.length = 0 + mysqlState.state.nextResult = undefined + mysqlState.state.nextFields = undefined + }) + + it('creates pools with multipleStatements disabled', async () => { + const driver = createMysqlDriver(target()) + await driver.query({ sql: 'select 1', rowLimit: 10 }) + expect(mysqlState.state.pools[0].options).toMatchObject({ + host: '127.0.0.1', + port: 13306, + database: 'main', + connectionLimit: 4, + multipleStatements: false + }) + }) + + it('reuses pools per database', async () => { + const driver = createMysqlDriver(target()) + await driver.query({ sql: 'select 1', rowLimit: 10 }) + await driver.query({ sql: 'select 2', rowLimit: 10 }) + await driver.query({ sql: 'select 3', database: 'other', rowLimit: 10 }) + expect(mysqlState.state.pools).toHaveLength(2) + }) + + it('applies SET SESSION TRANSACTION READ ONLY once per connection when readonly', async () => { + const driver = createMysqlDriver(target({ readonly: true })) + await driver.query({ sql: 'select 1', rowLimit: 10 }) + await driver.query({ sql: 'select 2', rowLimit: 10 }) + const queries = mysqlState.state.pools[0].connection.queries + const readonlySets = queries.filter((q) => q === 'SET SESSION TRANSACTION READ ONLY') + expect(readonlySets).toHaveLength(1) + expect(queries[0]).toBe('SET SESSION TRANSACTION READ ONLY') + }) + + it('does not set readonly for writable connections', async () => { + const driver = createMysqlDriver(target()) + await driver.query({ sql: 'select 1', rowLimit: 10 }) + const queries = mysqlState.state.pools[0].connection.queries + expect(queries.some((q) => q === 'SET SESSION TRANSACTION READ ONLY')).toBe(false) + }) + + it('maps SELECT results with columns and normalized cells', async () => { + mysqlState.state.nextResult = [[1n, 'x']] + mysqlState.state.nextFields = [{ name: 'id' }, { name: 'label' }] + const driver = createMysqlDriver(target()) + const result = await driver.query({ sql: 'select * from t', rowLimit: 10 }) + expect(result).toEqual({ + columns: ['id', 'label'], + rows: [['1', 'x']], + rowCount: 1, + truncated: false + }) + }) + + it('maps DML results to rowCount and lastInsertId', async () => { + mysqlState.state.nextResult = { affectedRows: 3, insertId: 7 } + const driver = createMysqlDriver(target()) + const result = await driver.query({ sql: 'insert into t values (1)', rowLimit: 10 }) + expect(result).toEqual({ + columns: [], + rows: [], + rowCount: 3, + truncated: false, + lastInsertId: '7' + }) + }) + + it('omits lastInsertId when zero', async () => { + mysqlState.state.nextResult = { affectedRows: 1, insertId: 0 } + const driver = createMysqlDriver(target()) + const result = await driver.query({ sql: 'update t set a=1', rowLimit: 10 }) + expect(result.lastInsertId).toBeUndefined() + }) + + it('listTables resolves schema > database > connection default', async () => { + mysqlState.state.nextResult = [['main', 'users', 10]] + mysqlState.state.nextFields = [] + const driver = createMysqlDriver(target()) + await driver.listTables({}) + await driver.listTables({ database: 'db-param' }) + await driver.listTables({ schema: 'schema-param' }) + const calls = mysqlState.state.pools.flatMap((pool) => + pool.connection.queries.filter( + (q): q is { sql: string; values: unknown[] } => + typeof q === 'object' && q !== null && 'values' in (q as object) + ) + ) + expect(calls[0].values).toEqual(['main']) + expect(calls[1].values).toEqual(['db-param']) + expect(calls[2].values).toEqual(['schema-param']) + }) + + it('dispose ends all pools', async () => { + const driver = createMysqlDriver(target()) + await driver.query({ sql: 'select 1', rowLimit: 1 }) + await driver.listDatabases() + await driver.dispose() + expect(mysqlState.state.pools.every((pool) => pool.ended)).toBe(true) + }) +})