Skip to content
3 changes: 3 additions & 0 deletions application/api/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@ class Question(BaseModel):
use_rag: bool = True
query_result: bool = True
intent_ner_recognition: bool = False
agent_cot: bool = False
profile_name: str = "shopping_guide"
explain_gen_process_flag: bool = False
gen_suggested_question: bool = False


class QuestionSocket(Question):
Expand Down
30 changes: 9 additions & 21 deletions application/api/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
from nlq.business.nlq_chain import NLQChain
from nlq.business.profile import ProfileManagement
from utils.database import get_db_url_dialect
from utils.llm import claude3_to_sql, create_vector_embedding_with_bedrock, \
retrieve_results_from_opensearch, get_query_intent, create_vector_embedding_with_sagemaker, \
from utils.llm import text_to_sql, get_query_intent, create_vector_embedding_with_sagemaker, \
sagemaker_to_sql, sagemaker_to_explain
from utils.opensearch import get_retrieve_opensearch
from .schemas import Question, Answer, Example, Option
from .exception_handler import BizException
from utils.constant import BEDROCK_MODEL_IDS
Expand Down Expand Up @@ -61,25 +61,13 @@ def __process_nlq_chain(question: Question) -> NLQChain:
question.keywords,
index_name=env_vars['data_sources'][selected_profile]['opensearch']['index_name'])
else:
records_with_embedding = create_vector_embedding_with_bedrock(
question.keywords,
index_name=env_vars['data_sources'][selected_profile]['opensearch']['index_name'])
# records_with_embedding = create_vector_embedding_with_bedrock(
# question.keywords,
# index_name=env_vars['data_sources'][selected_profile]['opensearch']['index_name'])
pass
logger.info(env_vars['data_sources'][selected_profile]['opensearch']['index_name'])
retrieve_result = retrieve_results_from_opensearch(
index_name=env_vars['data_sources'][selected_profile]['opensearch']['index_name'],
region_name=env_vars['data_sources'][selected_profile]['opensearch']['region_name'],
domain=env_vars['data_sources'][selected_profile]['opensearch']['domain'],
opensearch_user=env_vars['data_sources'][selected_profile]['opensearch'][
'opensearch_user'],
opensearch_password=env_vars['data_sources'][selected_profile]['opensearch'][
'opensearch_password'],
host=env_vars['data_sources'][selected_profile]['opensearch'][
'opensearch_host'],
port=env_vars['data_sources'][selected_profile]['opensearch'][
'opensearch_port'],
query_embedding=records_with_embedding['vector_field'],
top_k=3,
profile_name=selected_profile)

retrieve_result = get_retrieve_opensearch(env_vars, current_nlq_chain.get_question(), "query", selected_profile, 3, 0.5)
current_nlq_chain.set_retrieve_samples(retrieve_result)
except Exception as e:
logger.exception(e)
Expand Down Expand Up @@ -147,7 +135,7 @@ def get_result_from_llm(question: Question, current_nlq_chain: NLQChain, with_re
model_provider=None,
with_response_stream=with_response_stream,) # This does not support streaming
else:
response = claude3_to_sql(database_profile['tables_info'],
response = text_to_sql(database_profile['tables_info'],
database_profile['hints'],
question.keywords,
model_id=question.bedrock_model_id,
Expand Down
6 changes: 6 additions & 0 deletions application/nlq/business/nlq_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def __init__(self, profile):
self.generated_sql_response = ''
self.executed_result_df: pd.DataFrame | None = None
self.visualization_config_change: bool = False
self.sql = ''

def set_question(self, question):
if self.question != question:
Expand All @@ -40,7 +41,12 @@ def set_generated_sql_response(self, sql_response):
def get_generated_sql_response(self):
return self.generated_sql_response

def set_generated_sql(self, sql):
self.sql = sql

def get_generated_sql(self):
if self.sql != "":
return self.sql
sql = ""
try:
return self.generated_sql_response.split("<sql>")[1].split("</sql>")[0]
Expand Down
Loading