Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
217 changes: 217 additions & 0 deletions DESIGN_AIRBNB.md

Large diffs are not rendered by default.

287 changes: 287 additions & 0 deletions DESIGN_APPLE.md

Large diffs are not rendered by default.

289 changes: 289 additions & 0 deletions DESIGN_CLAUDE.md

Large diffs are not rendered by default.

376 changes: 376 additions & 0 deletions DESIGN_NOTION.md

Large diffs are not rendered by default.

16 changes: 14 additions & 2 deletions backend/apps/chat/api/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,16 @@

import orjson
import pandas as pd
from fastapi import APIRouter, HTTPException, Path
from fastapi import APIRouter, HTTPException, Path, Query
from fastapi.responses import StreamingResponse
from sqlalchemy import and_, select
from starlette.responses import JSONResponse

from apps.chat.curd.chat import delete_chat_with_user, get_chart_data_with_user, get_chat_predict_data_with_user, \
list_chats, get_chat_with_records, create_chat, rename_chat, \
delete_chat, get_chat_chart_data, get_chat_predict_data, get_chat_with_records_with_data, get_chat_record_by_id, \
format_json_data, format_json_list_data, get_chart_config, list_recent_questions, get_chat as get_chat_exec, \
format_json_data, format_json_list_data, get_chart_config, list_recent_questions, list_popular_questions, \
get_chat as get_chat_exec, \
rename_chat_with_user, get_chat_log_history, get_chart_data_with_user_live
from apps.chat.models.chat_model import CreateChat, ChatRecord, RenameChat, ChatQuestion, AxisObj, QuickCommand, \
ChatInfo, Chat, ChatFinishStep
Expand All @@ -34,6 +35,17 @@ async def chats(session: SessionDep, current_user: CurrentUser):
return list_chats(session, current_user)


@router.get("/popular_questions", summary=f"{PLACEHOLDER_PREFIX}popular_questions_workspace")
async def popular_questions(
    session: SessionDep, current_user: CurrentUser, limit: int = Query(8, ge=1, le=50)
):
    """Question-frequency ranking for the workspace (each chat's first placeholder record is excluded)."""
    # The CRUD helper does blocking DB work, so run it on a worker thread to
    # keep the event loop responsive; to_thread forwards kwargs unchanged.
    return await asyncio.to_thread(
        list_popular_questions, session=session, current_user=current_user, limit=limit
    )


@router.get("/{chart_id}", response_model=ChatInfo, summary=f"{PLACEHOLDER_PREFIX}get_chat")
async def get_chat(session: SessionDep, current_user: CurrentUser, chart_id: int, current_assistant: CurrentAssistant,
trans: Trans):
Expand Down
57 changes: 56 additions & 1 deletion backend/apps/chat/curd/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from common.utils.data_format import DataFormat
from common.utils.utils import extract_nested_json, SQLBotLogUtil

from apps.chat.utils.popular_questions_cluster import cluster_questions_for_datasource


def get_chat_record_by_id(session: SessionDep, record_id: int):
record: ChatRecord | None = None
Expand Down Expand Up @@ -66,6 +68,58 @@ def list_recent_questions(session: SessionDep, current_user: CurrentUser, dataso
return [record[0] for record in chat_records] if chat_records else []


def list_popular_questions(session: SessionDep, current_user: CurrentUser, limit: int = 8) -> List[Dict[str, Any]]:
    """Rank popular questions per datasource, merging semantically similar phrasings.

    Raw (datasource, question, count) tuples are aggregated literally in SQL,
    then merged per datasource by ``cluster_questions_for_datasource`` (a
    semantic / near-duplicate merge, not a plain literal group-by).

    Returns up to ``limit`` dicts with keys ``datasource_id``,
    ``datasource_name``, ``question`` and ``count``, sorted by count
    descending, then datasource name.
    """
    # NOTE(review): the endpoint summary says "workspace", but the query also
    # filters Chat.create_by == current_user.id (this user's chats only) —
    # confirm which scope is intended.
    oid = current_user.oid if current_user.oid is not None else 1  # default workspace fallback
    limit = min(max(limit, 1), 50)  # clamp defensively even though the API validates

    cnt = func.count(ChatRecord.id).label('cnt')
    rows = (
        session.query(Chat.datasource, ChatRecord.question, cnt)
        .join(Chat, ChatRecord.chat_id == Chat.id)
        .filter(
            Chat.oid == oid,
            Chat.create_by == current_user.id,
            Chat.datasource.isnot(None),
            ChatRecord.question.isnot(None),
            ChatRecord.question != '',
            # Skip the auto-created first record of every chat.
            ChatRecord.first_chat.isnot(True),
        )
        .group_by(Chat.datasource, ChatRecord.question)
        .order_by(desc(cnt))
        .limit(400)  # cap rows fed to the clusterer; highest counts survive the cut
        .all()
    )

    grouped: Dict[Any, List[tuple]] = {}
    for ds_id, question, raw_count in rows:
        grouped.setdefault(ds_id, []).append((question, int(raw_count)))

    wanted_ids = [ds_id for ds_id in grouped if ds_id is not None]
    names: Dict[Any, str] = {}
    if wanted_ids:
        name_rows = session.query(CoreDatasource.id, CoreDatasource.name).filter(
            CoreDatasource.id.in_(wanted_ids),
            CoreDatasource.oid == oid,
        ).all()
        names = {row[0]: row[1] for row in name_rows}

    results: List[Dict[str, Any]] = []
    for ds_id, counted in grouped.items():
        if ds_id is None:
            continue
        for question, total in cluster_questions_for_datasource(counted):
            results.append(
                {
                    'datasource_id': int(ds_id),
                    'datasource_name': names.get(ds_id) or '',
                    'question': question,
                    'count': total,
                }
            )

    results.sort(key=lambda entry: (-entry['count'], entry.get('datasource_name') or ''))
    return results[:limit]


def rename_chat_with_user(session: SessionDep, current_user: CurrentUser, rename_object: RenameChat) -> str:
chat = session.get(Chat, rename_object.id)
if not chat:
Expand Down Expand Up @@ -211,7 +265,7 @@ def format_json_list_data(origin_data: list[dict]):
if len(decimal_str) > 15:
value = str(value)
_row[key] = value
data.append(_row)
data.append(DataFormat.normalize_qualified_sql_column_keys(_row))

return data

Expand Down Expand Up @@ -253,6 +307,7 @@ def get_chart_data_ds(session: SessionDep,ds_id,sql):
else:
result = exec_sql(ds=datasource,sql=sql, origin_column=False)
_data = DataFormat.convert_large_numbers_in_object_array(result.get('data'))
_data = DataFormat.normalize_qualified_sql_column_keys_in_object_array(_data)
json_result['data'] = _data
return json_result
except Exception as e:
Expand Down
3 changes: 2 additions & 1 deletion backend/apps/chat/models/chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ class AiModelQuestion(BaseModel):
custom_prompt: str = ""
error_msg: str = ""
regenerate_record_id: Optional[int] = None
sample_data: str = ""

def sql_sys_question(self, db_type: Union[str, DB], enable_query_limit: bool = True):
templates: dict[str, str] = {}
Expand Down Expand Up @@ -256,7 +257,7 @@ def sql_sys_question(self, db_type: Union[str, DB], enable_query_limit: bool = T
example_answer_1=_example_answer_1,
example_answer_2=_example_answer_2,
example_answer_3=_example_answer_3)
templates['schema'] = _base_template['generate_basic_info'].format(engine=self.engine, schema=self.db_schema)
templates['schema'] = _base_template['generate_basic_info'].format(engine=self.engine, schema=self.db_schema, sample_data=self.sample_data)

if self.terminologies:
templates['terminologies'] = _base_template['generate_terminologies_info'].format(
Expand Down
17 changes: 16 additions & 1 deletion backend/apps/chat/task/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from apps.chat.models.chat_model import ChatQuestion, ChatRecord, Chat, RenameChat, ChatLog, OperationEnum, \
ChatFinishStep, AxisObj, SystemPromptMessage, HumanPromptMessage, AIPromptMessage
from apps.data_training.curd.data_training import get_training_template
from apps.datasource.crud.datasource import get_table_schema
from apps.datasource.crud.datasource import get_table_schema, get_tables_sample_data
from apps.datasource.crud.permission import get_row_permission_filters, is_normal_user
from apps.datasource.embedding.ds_embedding import get_ds_embedding
from apps.datasource.models.datasource import CoreDatasource
Expand Down Expand Up @@ -384,6 +384,13 @@ def choose_table_schema(self, _session: Session):
ds=self.ds,
question=self.chat_question.question)

# Get sample data for all tables
if not self.out_ds_instance:
self.chat_question.sample_data = get_tables_sample_data(
session=_session,
current_user=self.current_user,
ds=self.ds)

self.current_logs[OperationEnum.CHOOSE_TABLE] = end_log(session=_session,
log=self.current_logs[OperationEnum.CHOOSE_TABLE],
full_message=self.chat_question.db_schema)
Expand Down Expand Up @@ -505,6 +512,13 @@ def generate_recommend_questions_task(self, _session: Session):
question=self.chat_question.question,
embedding=False)

# Get sample data for all tables
if not self.out_ds_instance:
self.chat_question.sample_data = get_tables_sample_data(
session=_session,
current_user=self.current_user,
ds=self.ds)

guess_msg: List[Union[BaseMessage, dict[str, Any]]] = []
guess_msg.append(SystemPromptMessage(content=self.chat_question.guess_sys_question(self.articles_number)))

Expand Down Expand Up @@ -1304,6 +1318,7 @@ def run_task(self, in_chat: bool = True, stream: bool = True,
'count': len(result.get('data'))})

_data = DataFormat.convert_large_numbers_in_object_array(result.get('data'))
_data = DataFormat.normalize_qualified_sql_column_keys_in_object_array(_data)
result["data"] = _data

self.save_sql_data(session=_session, data_obj=result)
Expand Down
129 changes: 129 additions & 0 deletions backend/apps/chat/utils/popular_questions_cluster.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
"""
Popular questions: aggregate per datasource and merge semantically similar
questions within each datasource (not a purely literal group-by).

1. Intent bucket: Chinese phrasings asking about tables / data overview are merged
   into a single topic (see META_OVERVIEW_PATTERN).
2. Vector clustering: remaining questions are merged by cosine similarity using a
   local Chinese embedding model (optional; falls back on failure).
3. Fallback: normalization + difflib merging of near-identical literals.
"""

from __future__ import annotations

import re
from difflib import SequenceMatcher
from typing import Any, Dict, List, Tuple

import numpy as np

# Bucket "meta information" questions into one topic: how many tables, which
# tables, per-table data volume, what data is available, schema, row counts, etc.
# (derived from user-provided examples).
META_OVERVIEW_PATTERN = re.compile(
    r"(几张表|哪些表|多少张表|有多少表|表.*数据量|数据量.*表|分别.*数据量|数据量.*多大|"
    r"哪些数据|有什么数据|有哪些数据|什么数据|库表|schema|多少条数据|统计.*表|表的.*数量)",
    re.IGNORECASE,
)


def normalize_question(s: str) -> str:
    """Reduce a question to a canonical key for near-duplicate comparison.

    Strips outer whitespace, removes every internal whitespace character
    (including full-width U+3000 spaces), and trims trailing CJK/ASCII
    sentence punctuation so trivially different phrasings compare equal.
    Falsy input yields the empty string.
    """
    if not s:
        return ""
    collapsed = re.sub(r"[\s\u3000]+", "", s.strip())
    return re.sub(r"[。..!?!?;;,、]+$", "", collapsed)


def _split_meta_overview(
    weighted: List[Tuple[str, int]],
) -> Tuple[List[Tuple[str, int]], List[Tuple[str, int]]]:
    """Separate schema/data-overview style questions from the rest.

    Questions matching META_OVERVIEW_PATTERN are collapsed into a single
    (representative, total_count) entry, where the representative is the
    highest-count phrasing. Returns (merged_meta_entries, remaining_items);
    the first list is empty when nothing matched.
    """
    overview: List[Tuple[str, int]] = []
    others: List[Tuple[str, int]] = []
    for question, count in weighted:
        bucket = overview if META_OVERVIEW_PATTERN.search(question) else others
        bucket.append((question, count))

    merged: List[Tuple[str, int]] = []
    if overview:
        representative = max(overview, key=lambda item: item[1])[0]
        merged.append((representative, sum(count for _, count in overview)))
    return merged, others


def _merge_difflib(weighted: List[Tuple[str, int]], threshold: float = 0.78) -> List[Tuple[str, int]]:
    """Greedy fallback clustering based on normalized literal similarity.

    Items are visited in descending count order; each question joins the
    existing cluster whose anchor text is most similar (SequenceMatcher
    ratio >= threshold, strictly improving ties kept first), otherwise it
    seeds a new cluster. Returns one (representative, total_count) pair per
    cluster; the representative is the highest-count member.
    """
    if not weighted:
        return []
    clusters: List[Dict[str, Any]] = []
    for question, count in sorted(weighted, key=lambda item: -item[1]):
        key = normalize_question(question)
        chosen = None
        top_ratio = 0.0
        for cluster in clusters:
            ratio = SequenceMatcher(None, key, cluster["norm"]).ratio()
            if ratio >= threshold and ratio > top_ratio:
                top_ratio = ratio
                chosen = cluster
        if chosen is None:
            clusters.append({"rep": question, "count": count, "norm": key, "max_w": count})
        else:
            chosen["count"] += count
            if count > chosen.get("max_w", 0):
                chosen["rep"] = question
                chosen["max_w"] = count
    return [(cluster["rep"], int(cluster["count"])) for cluster in clusters]


def _merge_embedding(weighted: List[Tuple[str, int]], threshold: float = 0.76) -> List[Tuple[str, int]]:
    """Cluster questions by cosine similarity of locally computed embeddings.

    Embeds every question, unit-normalises the vectors, unions any pair whose
    cosine similarity reaches ``threshold``, and emits one
    (representative, total_count) pair per connected component — the
    representative being the highest-count phrasing. Any failure (model
    unavailable, runtime error) falls back to the literal difflib merge.
    """
    if len(weighted) <= 1:
        return weighted
    try:
        from apps.ai_model.embedding import EmbeddingModelCache

        questions = [item[0] for item in weighted]
        vectors = np.array(
            EmbeddingModelCache.get_model().embed_documents(questions),
            dtype=np.float32,
        )
        # Unit-normalise (epsilon guards zero vectors) so the matrix product
        # below is cosine similarity.
        vectors = vectors / (np.linalg.norm(vectors, axis=1, keepdims=True) + 1e-9)

        size = len(weighted)
        roots = list(range(size))

        def find(idx: int) -> int:
            # Union-find lookup with path halving.
            while roots[idx] != idx:
                roots[idx] = roots[roots[idx]]
                idx = roots[idx]
            return idx

        similarity = vectors @ vectors.T
        for a in range(size):
            for b in range(a + 1, size):
                if float(similarity[a, b]) >= threshold:
                    root_a, root_b = find(a), find(b)
                    if root_a != root_b:
                        roots[root_b] = root_a

        components: Dict[int, List[int]] = {}
        for idx in range(size):
            components.setdefault(find(idx), []).append(idx)

        merged: List[Tuple[str, int]] = []
        for members in components.values():
            representative = max((weighted[i] for i in members), key=lambda item: item[1])[0]
            merged.append((representative, int(sum(weighted[i][1] for i in members))))
        return merged
    except Exception:
        # Best-effort: clustering must never break the popular-questions API.
        return _merge_difflib(weighted, threshold=0.78)


def cluster_questions_for_datasource(weighted: List[Tuple[str, int]]) -> List[Tuple[str, int]]:
    """Merge (question, count) pairs for one datasource into representative topics.

    Overview/meta questions are combined into a single entry first; the
    remainder is merged by embedding similarity (difflib fallback on failure).
    Returns (representative_question, total_count) pairs.
    """
    if not weighted:
        return []
    overview_entries, remaining = _split_meta_overview(weighted)
    if not remaining:
        return overview_entries
    return overview_entries + _merge_embedding(remaining)
1 change: 1 addition & 0 deletions backend/apps/datasource/api/datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ def inner():
try:
return preview(session, current_user, id, data)
except Exception as e:
SQLBotLogUtil.error(f"Preview failed: {e}, try another way")
ds = session.query(CoreDatasource).filter(CoreDatasource.id == id).first()
# check ds status
status = check_status(session, trans, ds, True)
Expand Down
Loading