application/docker-compose.yml (2 changes: 1 addition & 1 deletion)
@@ -63,7 +63,7 @@ services:
# Specify the image and version
image: mysql:8.0
ports:
- "127.0.0.1:3306:3306"
- "3306:3306"
restart: always
environment:
# Configure the root password
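Note on the port mapping change above: "127.0.0.1:3306:3306" publishes the container's MySQL port only on the Docker host's loopback interface, while "3306:3306" publishes it on all host interfaces, so the database also becomes reachable from other machines on the network. The short Python sketch below only illustrates connecting from the host itself; PyMySQL and the credentials are assumptions for illustration, and whichever client the application actually uses behaves the same with respect to the published port.

import pymysql  # assumed example driver, not necessarily the one used by this repository

# With a loopback-only mapping ("127.0.0.1:3306:3306"), this connection works from the
# Docker host, but the published port is not reachable from other machines.
conn = pymysql.connect(host="127.0.0.1", port=3306,
                       user="root", password="example-password",  # illustrative credentials
                       database="mysql")
with conn.cursor() as cur:
    cur.execute("SELECT VERSION()")
    print(cur.fetchone())
conn.close()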
application/pages/1_🌍_Generative_BI_Playground.py (17 changes: 10 additions & 7 deletions)
@@ -11,7 +11,7 @@
from nlq.business.profile import ProfileManagement
from nlq.business.suggested_question import SuggestedQuestionManagement as sqm
from nlq.business.vector_store import VectorStore
- from utils.llm import text_to_sql, get_query_intent, generate_suggested_question, get_agent_cot_task, agent_data_analyse, \
+ from utils.llm import text_to_sql, get_query_intent, generate_suggested_question, get_agent_cot_task, data_analyse_tool, \
knowledge_search
from utils.constant import PROFILE_QUESTION_TABLE_NAME, ACTIVE_PROMPT_NAME, DEFAULT_PROMPT_NAME
from utils.navigation import make_sidebar
@@ -420,13 +420,13 @@ def main():
with st.expander(
f'Query Retrieve : {len(normal_search_result.retrieve_result)}, NER Retrieve : {len(normal_search_result.entity_slot_retrieve)}'):
examples = {}
examples["query_retrieve"] = []
for example in normal_search_result.retrieve_result:
examples["query_retrieve"] = []
examples["query_retrieve"].append({'Score': example['_score'],
'Question': example['_source']['text'],
'Answer': example['_source']['sql'].strip()})
examples["ner_retrieve"] = []
for example in normal_search_result.entity_slot_retrieve:
examples["ner_retrieve"] = []
examples["ner_retrieve"].append({'Score': example['_score'],
'Question': example['_source']['entity'],
'Answer': example['_source']['comment'].strip()})
@@ -475,7 +475,11 @@ def main():
if search_intent_result["status_code"] == 500:
with st.expander("The SQL Error Info"):
st.markdown(search_intent_result["error_info"])

+ else:
+     if search_intent_result["data"] is not None and len(search_intent_result["data"]) > 0:
+         search_intent_analyse_result = data_analyse_tool(model_type, search_box,
+             search_intent_result["data"].to_json(orient='records', force_ascii=False), "query")
+         st.markdown(search_intent_analyse_result)
st.session_state.current_sql_result[selected_profile] = search_intent_result["data"]

elif agent_intent_flag:
@@ -488,9 +492,8 @@
orient='records')
filter_deep_dive_sql_result.append(agent_search_result[i])

- agent_data_analyse_result = agent_data_analyse(model_type, search_box,
-     json.dumps(
-         filter_deep_dive_sql_result))
+ agent_data_analyse_result = data_analyse_tool(model_type, search_box,
+     json.dumps(filter_deep_dive_sql_result, ensure_ascii=False), "agent")
logger.info("agent_data_analyse_result")
logger.info(agent_data_analyse_result)
st.session_state.messages[selected_profile].append(
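For context on the new "query" branch above: search_intent_result["data"] appears to be a pandas DataFrame (it is serialized with to_json), and to_json(orient='records', force_ascii=False) turns it into a compact JSON array of row objects before it is handed to data_analyse_tool; force_ascii=False keeps non-ASCII values (for example Chinese strings) readable instead of escaping them as \uXXXX. A minimal illustration with an invented result set:

import pandas as pd

# Invented rows standing in for search_intent_result["data"]
df = pd.DataFrame([{"city": "北京", "sales": 120}, {"city": "上海", "sales": 95}])

payload = df.to_json(orient="records", force_ascii=False)
print(payload)  # [{"city":"北京","sales":120},{"city":"上海","sales":95}]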
application/utils/llm.py (9 changes: 6 additions & 3 deletions)
@@ -6,7 +6,7 @@
from utils import opensearch
from utils.prompt import POSTGRES_DIALECT_PROMPT_CLAUDE3, MYSQL_DIALECT_PROMPT_CLAUDE3, \
DEFAULT_DIALECT_PROMPT, SEARCH_INTENT_PROMPT_CLAUDE3, CLAUDE3_DATA_ANALYSE_SYSTEM_PROMPT, \
- CLAUDE3_DATA_ANALYSE_USER_PROMPT, AWS_REDSHIFT_DIALECT_PROMPT_CLAUDE3
+ CLAUDE3_AGENT_DATA_ANALYSE_USER_PROMPT, AWS_REDSHIFT_DIALECT_PROMPT_CLAUDE3, CLAUDE3_QUERY_DATA_ANALYSE_USER_PROMPT
import os
import logging
from langchain_core.output_parsers import JsonOutputParser
@@ -339,11 +339,14 @@ def get_agent_cot_task(model_id, search_box, ddl, agent_cot_example=None):
return default_agent_cot_task


- def agent_data_analyse(model_id, search_box, sql_data):
+ def data_analyse_tool(model_id, search_box, sql_data, search_type):
try:
system_prompt = CLAUDE3_DATA_ANALYSE_SYSTEM_PROMPT
max_tokens = 2048
- user_prompt = CLAUDE3_DATA_ANALYSE_USER_PROMPT.format(question=search_box, data=sql_data)
+ if search_type == "agent":
+     user_prompt = CLAUDE3_AGENT_DATA_ANALYSE_USER_PROMPT.format(question=search_box, data=sql_data)
+ else:
+     user_prompt = CLAUDE3_QUERY_DATA_ANALYSE_USER_PROMPT.format(question=search_box, data=sql_data)
user_message = {"role": "user", "content": user_prompt}
messages = [user_message]
response = invoke_model_claude3(model_id, system_prompt, messages, max_tokens)
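A hedged test sketch for the new search_type switch in data_analyse_tool. It is not part of the PR: it assumes the application directory is on the import path and that importing utils.llm succeeds in your environment, and it stubs invoke_model_claude3 so no Bedrock call is made; the asserted phrases come from the two prompt templates shown in this diff.

from unittest import mock

from utils import llm  # assumption: run with application/ on sys.path


def test_data_analyse_tool_picks_prompt_by_search_type():
    captured = []

    def fake_invoke(model_id, system_prompt, messages, max_tokens):
        # Record the user prompt instead of calling Bedrock.
        captured.append(messages[0]["content"])
        return {}

    with mock.patch.object(llm, "invoke_model_claude3", side_effect=fake_invoke):
        llm.data_analyse_tool("model-id", "question", "[]", "agent")
        llm.data_analyse_tool("model-id", "question", "[]", "query")

    assert "analyze the data provided" in captured[0]        # agent prompt wording
    assert "describe it in natural language" in captured[1]  # query prompt wording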
application/utils/prompt.py (17 changes: 16 additions & 1 deletion)
@@ -150,7 +150,7 @@
You are a data analysis expert in the retail industry
"""

CLAUDE3_DATA_ANALYSE_USER_PROMPT = """
CLAUDE3_AGENT_DATA_ANALYSE_USER_PROMPT = """
As a professional data analyst, you are now asked a question by a user, and you need to analyze the data provided.

<instructions>
@@ -167,3 +167,18 @@

Think step by step。
"""

+ CLAUDE3_QUERY_DATA_ANALYSE_USER_PROMPT = """
+ 
+ Your task is to analyze the given data and describe it in natural language.
+ 
+ <instructions>
+ - Transform the data into natural language, including as much of the key data as possible
+ - Output only the final result; there is no need to show the intermediate analysis process
+ </instructions>
+ 
+ The user question is: {question}
+ 
+ The data is: {data}
+ 
+ """