|
1 | | -import { BadRequestException, Inject, Injectable, NotFoundException, Scope } from '@nestjs/common'; |
| 1 | +import { BadRequestException, Inject, Injectable, Logger, NotFoundException, Scope } from '@nestjs/common'; |
2 | 2 | import { getDataAccessObject } from '@rocketadmin/shared-code/dist/src/data-access-layer/shared/create-data-access-object.js'; |
3 | 3 | import { ConnectionTypesEnum } from '@rocketadmin/shared-code/dist/src/shared/enums/connection-types-enum.js'; |
| 4 | +import { IDataAccessObject } from '@rocketadmin/shared-code/dist/src/shared/interfaces/data-access-object.interface.js'; |
| 5 | +import { IDataAccessObjectAgent } from '@rocketadmin/shared-code/dist/src/shared/interfaces/data-access-object-agent.interface.js'; |
4 | 6 | import AbstractUseCase from '../../../../common/abstract-use.case.js'; |
5 | 7 | import { IGlobalDatabaseContext } from '../../../../common/application/global-database-context.interface.js'; |
6 | 8 | import { BaseType } from '../../../../common/data-injection.tokens.js'; |
@@ -36,11 +38,29 @@ interface AIGeneratedWidgetResponse { |
36 | 38 | }; |
37 | 39 | } |
38 | 40 |
|
| 41 | +const MAX_FEEDBACK_ITERATIONS = 3; |
| 42 | + |
| 43 | +const EXPLAIN_SUPPORTED_TYPES: ReadonlySet<ConnectionTypesEnum> = new Set([ |
| 44 | + ConnectionTypesEnum.postgres, |
| 45 | + ConnectionTypesEnum.agent_postgres, |
| 46 | + ConnectionTypesEnum.mysql, |
| 47 | + ConnectionTypesEnum.agent_mysql, |
| 48 | + ConnectionTypesEnum.clickhouse, |
| 49 | + ConnectionTypesEnum.agent_clickhouse, |
| 50 | +]); |
| 51 | + |
| 52 | +interface TableInfo { |
| 53 | + table_name: string; |
| 54 | + columns: Array<{ name: string; type: string; nullable: boolean }>; |
| 55 | +} |
| 56 | + |
39 | 57 | @Injectable({ scope: Scope.REQUEST }) |
40 | 58 | export class GeneratePanelPositionWithAiUseCase |
41 | 59 | extends AbstractUseCase<GeneratePanelPositionWithAiDs, GeneratedPanelWithPositionDto> |
42 | 60 | implements IGeneratePanelPositionWithAi |
43 | 61 | { |
| 62 | + private readonly logger = new Logger(GeneratePanelPositionWithAiUseCase.name); |
| 63 | + |
44 | 64 | constructor( |
45 | 65 | @Inject(BaseType.GLOBAL_DB_CONTEXT) |
46 | 66 | protected _dbContext: IGlobalDatabaseContext, |
@@ -83,10 +103,7 @@ export class GeneratePanelPositionWithAiUseCase |
83 | 103 |
|
84 | 104 | const dao = getDataAccessObject(foundConnection); |
85 | 105 |
|
86 | | - let tableInfo: { |
87 | | - table_name: string; |
88 | | - columns: Array<{ name: string; type: string; nullable: boolean }>; |
89 | | - }; |
| 106 | + let tableInfo: TableInfo; |
90 | 107 |
|
91 | 108 | try { |
92 | 109 | const structure = await dao.getTableStructure(table_name, null); |
@@ -116,15 +133,23 @@ export class GeneratePanelPositionWithAiUseCase |
116 | 133 |
|
117 | 134 | validateQuerySafety(generatedWidget.query_text, foundConnection.type as ConnectionTypesEnum); |
118 | 135 |
|
| 136 | + const refinedWidget = await this.validateAndRefineQueryWithExplain( |
| 137 | + dao, |
| 138 | + generatedWidget, |
| 139 | + tableInfo, |
| 140 | + foundConnection.type as ConnectionTypesEnum, |
| 141 | + chart_description, |
| 142 | + ); |
| 143 | + |
119 | 144 | return { |
120 | | - name: name || generatedWidget.name, |
121 | | - description: generatedWidget.description || null, |
122 | | - widget_type: this.mapWidgetType(generatedWidget.widget_type), |
123 | | - chart_type: generatedWidget.chart_type || null, |
124 | | - widget_options: generatedWidget.widget_options |
125 | | - ? (generatedWidget.widget_options as unknown as Record<string, unknown>) |
| 145 | + name: name || refinedWidget.name, |
| 146 | + description: refinedWidget.description || null, |
| 147 | + widget_type: this.mapWidgetType(refinedWidget.widget_type), |
| 148 | + chart_type: refinedWidget.chart_type || null, |
| 149 | + widget_options: refinedWidget.widget_options |
| 150 | + ? (refinedWidget.widget_options as unknown as Record<string, unknown>) |
126 | 151 | : null, |
127 | | - query_text: generatedWidget.query_text, |
| 152 | + query_text: refinedWidget.query_text, |
128 | 153 | connection_id: connectionId, |
129 | 154 | panel_position: { |
130 | 155 | position_x: position_x ?? 0, |
@@ -219,6 +244,118 @@ Respond ONLY with the JSON object, no additional text or explanation.`; |
219 | 244 | } |
220 | 245 | } |
221 | 246 |
|
| 247 | + private async validateAndRefineQueryWithExplain( |
| 248 | + dao: IDataAccessObject | IDataAccessObjectAgent, |
| 249 | + generatedWidget: AIGeneratedWidgetResponse, |
| 250 | + tableInfo: TableInfo, |
| 251 | + connectionType: ConnectionTypesEnum, |
| 252 | + chartDescription: string, |
| 253 | + ): Promise<AIGeneratedWidgetResponse> { |
| 254 | + if (!EXPLAIN_SUPPORTED_TYPES.has(connectionType)) { |
| 255 | + return generatedWidget; |
| 256 | + } |
| 257 | + |
| 258 | + let currentQuery = generatedWidget.query_text; |
| 259 | + |
| 260 | + for (let iteration = 0; iteration < MAX_FEEDBACK_ITERATIONS; iteration++) { |
| 261 | + const explainResult = await this.runExplainQuery(dao, currentQuery, tableInfo.table_name); |
| 262 | + |
| 263 | + const correctionPrompt = this.buildQueryCorrectionPrompt( |
| 264 | + currentQuery, |
| 265 | + explainResult.success ? explainResult.result : explainResult.error, |
| 266 | + !explainResult.success, |
| 267 | + tableInfo, |
| 268 | + connectionType, |
| 269 | + chartDescription, |
| 270 | + ); |
| 271 | + |
| 272 | + const aiResponse = await this.aiCoreService.completeWithProvider(AIProviderType.BEDROCK, correctionPrompt, { |
| 273 | + temperature: 0.2, |
| 274 | + }); |
| 275 | + |
| 276 | + const correctedQuery = this.cleanQueryResponse(aiResponse); |
| 277 | + |
| 278 | + validateQuerySafety(correctedQuery, connectionType); |
| 279 | + |
| 280 | + if (this.normalizeWhitespace(correctedQuery) === this.normalizeWhitespace(currentQuery)) { |
| 281 | + this.logger.log(`Query accepted by AI without changes after EXPLAIN (iteration ${iteration + 1})`); |
| 282 | + break; |
| 283 | + } |
| 284 | + |
| 285 | + this.logger.log(`Query corrected by AI after EXPLAIN (iteration ${iteration + 1})`); |
| 286 | + currentQuery = correctedQuery; |
| 287 | + |
| 288 | + if (explainResult.success) { |
| 289 | + break; |
| 290 | + } |
| 291 | + } |
| 292 | + |
| 293 | + return { ...generatedWidget, query_text: currentQuery }; |
| 294 | + } |
| 295 | + |
| 296 | + private async runExplainQuery( |
| 297 | + dao: IDataAccessObject | IDataAccessObjectAgent, |
| 298 | + query: string, |
| 299 | + tableName: string, |
| 300 | + ): Promise<{ success: boolean; result?: string; error?: string }> { |
| 301 | + try { |
| 302 | + const explainQuery = `EXPLAIN ${query.replace(/;\s*$/, '')}`; |
| 303 | + const result = await (dao as IDataAccessObject).executeRawQuery(explainQuery, tableName); |
| 304 | + return { success: true, result: JSON.stringify(result, null, 2) }; |
| 305 | + } catch (error) { |
| 306 | + return { success: false, error: error.message }; |
| 307 | + } |
| 308 | + } |
| 309 | + |
| 310 | + private buildQueryCorrectionPrompt( |
| 311 | + currentQuery: string, |
| 312 | + explainResultOrError: string, |
| 313 | + isError: boolean, |
| 314 | + tableInfo: TableInfo, |
| 315 | + connectionType: ConnectionTypesEnum, |
| 316 | + chartDescription: string, |
| 317 | + ): string { |
| 318 | + const schemaDescription = `Table: ${tableInfo.table_name}\n Columns:\n${tableInfo.columns |
| 319 | + .map((col) => ` - ${col.name}: ${col.type}${col.nullable ? ' (nullable)' : ''}`) |
| 320 | + .join('\n')}`; |
| 321 | + |
| 322 | + const feedbackSection = isError |
| 323 | + ? `The query FAILED with the following error:\n${explainResultOrError}\n\nPlease fix the query to resolve this error.` |
| 324 | + : `The EXPLAIN output for the query is:\n${explainResultOrError}\n\nReview the execution plan. If the query has performance issues (full table scans on large datasets, inefficient joins, etc.), optimize it. If the query is already acceptable, return it unchanged.`; |
| 325 | + |
| 326 | + return `You are a database query optimization assistant. A SQL query was generated and needs validation. |
| 327 | +
|
| 328 | +DATABASE TYPE: ${connectionType} |
| 329 | +
|
| 330 | +DATABASE SCHEMA: |
| 331 | +${schemaDescription} |
| 332 | +
|
| 333 | +ORIGINAL USER REQUEST: |
| 334 | +"${chartDescription}" |
| 335 | +
|
| 336 | +CURRENT QUERY: |
| 337 | +${currentQuery} |
| 338 | +
|
| 339 | +${feedbackSection} |
| 340 | +
|
| 341 | +IMPORTANT: |
| 342 | +- Preserve the same column aliases used in the original query. |
| 343 | +- Write valid ${connectionType} SQL syntax. |
| 344 | +- Return ONLY the SQL query, no explanations, no markdown, no JSON wrapping.`; |
| 345 | + } |
| 346 | + |
| 347 | + private cleanQueryResponse(aiResponse: string): string { |
| 348 | + return aiResponse |
| 349 | + .trim() |
| 350 | + .replace(/^```[a-zA-Z]*\n?/, '') |
| 351 | + .replace(/```\s*$/, '') |
| 352 | + .trim(); |
| 353 | + } |
| 354 | + |
| 355 | + private normalizeWhitespace(query: string): string { |
| 356 | + return query.replace(/\s+/g, ' ').trim(); |
| 357 | + } |
| 358 | + |
222 | 359 | private mapWidgetType(type: string): DashboardWidgetTypeEnum { |
223 | 360 | switch (type) { |
224 | 361 | case 'chart': |
|
0 commit comments