Skip to content

Commit daae72b

Browse files
committed
Enhance AI query validation by adding support for query refinement and explanation handling
1 parent 8a7b1e1 commit daae72b

File tree

1 file changed

+154
-12
lines changed

1 file changed

+154
-12
lines changed

backend/src/entities/visualizations/panel-position/use-cases/generate-panel-position-with-ai.use.case.ts

Lines changed: 154 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1-
import { BadRequestException, Inject, Injectable, NotFoundException, Scope } from '@nestjs/common';
1+
import { BadRequestException, Inject, Injectable, Logger, NotFoundException, Scope } from '@nestjs/common';
22
import { getDataAccessObject } from '@rocketadmin/shared-code/dist/src/data-access-layer/shared/create-data-access-object.js';
33
import { ConnectionTypesEnum } from '@rocketadmin/shared-code/dist/src/shared/enums/connection-types-enum.js';
4+
import { IDataAccessObject } from '@rocketadmin/shared-code/dist/src/shared/interfaces/data-access-object.interface.js';
5+
import { IDataAccessObjectAgent } from '@rocketadmin/shared-code/dist/src/shared/interfaces/data-access-object-agent.interface.js';
46
import AbstractUseCase from '../../../../common/abstract-use.case.js';
57
import { IGlobalDatabaseContext } from '../../../../common/application/global-database-context.interface.js';
68
import { BaseType } from '../../../../common/data-injection.tokens.js';
@@ -36,11 +38,29 @@ interface AIGeneratedWidgetResponse {
3638
};
3739
}
3840

41+
const MAX_FEEDBACK_ITERATIONS = 3;
42+
43+
const EXPLAIN_SUPPORTED_TYPES: ReadonlySet<ConnectionTypesEnum> = new Set([
44+
ConnectionTypesEnum.postgres,
45+
ConnectionTypesEnum.agent_postgres,
46+
ConnectionTypesEnum.mysql,
47+
ConnectionTypesEnum.agent_mysql,
48+
ConnectionTypesEnum.clickhouse,
49+
ConnectionTypesEnum.agent_clickhouse,
50+
]);
51+
52+
interface TableInfo {
53+
table_name: string;
54+
columns: Array<{ name: string; type: string; nullable: boolean }>;
55+
}
56+
3957
@Injectable({ scope: Scope.REQUEST })
4058
export class GeneratePanelPositionWithAiUseCase
4159
extends AbstractUseCase<GeneratePanelPositionWithAiDs, GeneratedPanelWithPositionDto>
4260
implements IGeneratePanelPositionWithAi
4361
{
62+
private readonly logger = new Logger(GeneratePanelPositionWithAiUseCase.name);
63+
4464
constructor(
4565
@Inject(BaseType.GLOBAL_DB_CONTEXT)
4666
protected _dbContext: IGlobalDatabaseContext,
@@ -83,10 +103,7 @@ export class GeneratePanelPositionWithAiUseCase
83103

84104
const dao = getDataAccessObject(foundConnection);
85105

86-
let tableInfo: {
87-
table_name: string;
88-
columns: Array<{ name: string; type: string; nullable: boolean }>;
89-
};
106+
let tableInfo: TableInfo;
90107

91108
try {
92109
const structure = await dao.getTableStructure(table_name, null);
@@ -116,15 +133,23 @@ export class GeneratePanelPositionWithAiUseCase
116133

117134
validateQuerySafety(generatedWidget.query_text, foundConnection.type as ConnectionTypesEnum);
118135

136+
const refinedWidget = await this.validateAndRefineQueryWithExplain(
137+
dao,
138+
generatedWidget,
139+
tableInfo,
140+
foundConnection.type as ConnectionTypesEnum,
141+
chart_description,
142+
);
143+
119144
return {
120-
name: name || generatedWidget.name,
121-
description: generatedWidget.description || null,
122-
widget_type: this.mapWidgetType(generatedWidget.widget_type),
123-
chart_type: generatedWidget.chart_type || null,
124-
widget_options: generatedWidget.widget_options
125-
? (generatedWidget.widget_options as unknown as Record<string, unknown>)
145+
name: name || refinedWidget.name,
146+
description: refinedWidget.description || null,
147+
widget_type: this.mapWidgetType(refinedWidget.widget_type),
148+
chart_type: refinedWidget.chart_type || null,
149+
widget_options: refinedWidget.widget_options
150+
? (refinedWidget.widget_options as unknown as Record<string, unknown>)
126151
: null,
127-
query_text: generatedWidget.query_text,
152+
query_text: refinedWidget.query_text,
128153
connection_id: connectionId,
129154
panel_position: {
130155
position_x: position_x ?? 0,
@@ -219,6 +244,123 @@ Respond ONLY with the JSON object, no additional text or explanation.`;
219244
}
220245
}
221246

247+
private async validateAndRefineQueryWithExplain(
248+
dao: IDataAccessObject | IDataAccessObjectAgent,
249+
generatedWidget: AIGeneratedWidgetResponse,
250+
tableInfo: TableInfo,
251+
connectionType: ConnectionTypesEnum,
252+
chartDescription: string,
253+
): Promise<AIGeneratedWidgetResponse> {
254+
if (!EXPLAIN_SUPPORTED_TYPES.has(connectionType)) {
255+
return generatedWidget;
256+
}
257+
258+
let currentQuery = generatedWidget.query_text;
259+
260+
for (let iteration = 0; iteration < MAX_FEEDBACK_ITERATIONS; iteration++) {
261+
const explainResult = await this.runExplainQuery(dao, currentQuery, tableInfo.table_name);
262+
263+
const correctionPrompt = this.buildQueryCorrectionPrompt(
264+
currentQuery,
265+
explainResult.success ? explainResult.result : explainResult.error,
266+
!explainResult.success,
267+
tableInfo,
268+
connectionType,
269+
chartDescription,
270+
);
271+
272+
const aiResponse = await this.aiCoreService.completeWithProvider(AIProviderType.BEDROCK, correctionPrompt, {
273+
temperature: 0.2,
274+
});
275+
276+
const correctedQuery = this.cleanQueryResponse(aiResponse);
277+
278+
validateQuerySafety(correctedQuery, connectionType);
279+
280+
if (this.normalizeWhitespace(correctedQuery) === this.normalizeWhitespace(currentQuery)) {
281+
this.logger.log(`Query accepted by AI without changes after EXPLAIN (iteration ${iteration + 1})`);
282+
break;
283+
}
284+
285+
this.logger.log(`Query corrected by AI after EXPLAIN (iteration ${iteration + 1})`);
286+
currentQuery = correctedQuery;
287+
288+
if (explainResult.success) {
289+
break;
290+
}
291+
}
292+
293+
return { ...generatedWidget, query_text: currentQuery };
294+
}
295+
296+
private async runExplainQuery(
297+
dao: IDataAccessObject | IDataAccessObjectAgent,
298+
query: string,
299+
tableName: string,
300+
): Promise<{ success: boolean; result?: string; error?: string }> {
301+
try {
302+
const explainQuery = `EXPLAIN ${query.replace(/;\s*$/, '')}`;
303+
const result = await (dao as IDataAccessObject).executeRawQuery(explainQuery, tableName);
304+
return { success: true, result: JSON.stringify(result, null, 2) };
305+
} catch (error) {
306+
return { success: false, error: error.message };
307+
}
308+
}
309+
310+
private buildQueryCorrectionPrompt(
311+
currentQuery: string,
312+
explainResultOrError: string,
313+
isError: boolean,
314+
tableInfo: TableInfo,
315+
connectionType: ConnectionTypesEnum,
316+
chartDescription: string,
317+
): string {
318+
const schemaDescription = `Table: ${tableInfo.table_name}\n Columns:\n${tableInfo.columns
319+
.map((col) => ` - ${col.name}: ${col.type}${col.nullable ? ' (nullable)' : ''}`)
320+
.join('\n')}`;
321+
322+
const feedbackSection = isError
323+
? `The query FAILED with the following error:\n${explainResultOrError}\n\nPlease fix the query to resolve this error.`
324+
: `The EXPLAIN output for the query is:\n${explainResultOrError}\n\nReview the execution plan. If the query has performance issues (full table scans on large datasets, inefficient joins, etc.), optimize it. If the query is already acceptable, return it unchanged.`;
325+
326+
return `You are a database query optimization assistant. A SQL query was generated and needs validation.
327+
328+
DATABASE TYPE: ${connectionType}
329+
330+
DATABASE SCHEMA:
331+
${schemaDescription}
332+
333+
ORIGINAL USER REQUEST:
334+
"${chartDescription}"
335+
336+
CURRENT QUERY:
337+
${currentQuery}
338+
339+
${feedbackSection}
340+
341+
IMPORTANT:
342+
- Preserve the same column aliases used in the original query.
343+
- Write valid ${connectionType} SQL syntax.
344+
- Return ONLY the SQL query, no explanations, no markdown, no JSON wrapping.`;
345+
}
346+
347+
private cleanQueryResponse(aiResponse: string): string {
348+
let cleaned = aiResponse.trim();
349+
if (cleaned.startsWith('```sql')) {
350+
cleaned = cleaned.slice(6);
351+
} else if (cleaned.startsWith('```')) {
352+
cleaned = cleaned.slice(3);
353+
}
354+
if (cleaned.endsWith('```')) {
355+
cleaned = cleaned.slice(0, -3);
356+
}
357+
return cleaned.trim();
358+
}
359+
360+
private normalizeWhitespace(query: string): string {
361+
return query.replace(/\s+/g, ' ').trim();
362+
}
363+
222364
private mapWidgetType(type: string): DashboardWidgetTypeEnum {
223365
switch (type) {
224366
case 'chart':

0 commit comments

Comments
 (0)