 from langchain.llms import VertexAI
 from libs.logger import logger
 import streamlit as st
-
+from google.oauth2 import service_account
+from langchain.prompts import ChatPromptTemplate
+
 class VertexAILangChain:
-    def __init__(self, project, location, model_name, max_output_tokens, temperature, credentials_file_path):
+    def __init__(self, project="", location="us-central1", model_name="code-bison", max_tokens=256, temperature: float = 0.3, credentials_file_path=None):
         self.project = project
         self.location = location
         self.model_name = model_name
-        self.max_output_tokens = max_output_tokens
+        self.max_tokens = max_tokens
         self.temperature = temperature
         self.credentials_file_path = credentials_file_path
         self.vertexai_llm = None
 
-    def load_model(self):
+    def load_model(self, model_name, max_tokens, temperature):
         try:
             logger.info(f"Loading model... with project: {self.project} and location: {self.location}")
-
             # Load service account credentials from the given file.
             credentials = service_account.Credentials.from_service_account_file(self.credentials_file_path)
 
-            logger.info(f"Trying to set Vertex model with parameters: {self.model_name}, {self.max_output_tokens}, {self.temperature}, {self.location}")
+            logger.info(f"Trying to set Vertex model with parameters: {model_name or self.model_name}, {max_tokens or self.max_tokens}, {temperature or self.temperature}, {self.location}")
             self.vertexai_llm = VertexAI(
-                model_name=self.model_name,
-                max_output_tokens=self.max_output_tokens,
-                temperature=self.temperature,
+                model_name=model_name or self.model_name,
+                max_output_tokens=max_tokens or self.max_tokens,
+                temperature=temperature or self.temperature,
                 verbose=True,
                 location=self.location,
                 credentials=credentials,
@@ -89,8 +89,17 @@ def generate_code(self, code_prompt, code_language):
             response = llm_chain.run({"code_prompt": code_prompt, "code_language": code_language})
             if response and len(response) > 0:
                 logger.info(f"Code generated successfully: {response}")
+
                 # Extract text inside code block
-                generated_code = re.search('```(.*)```', response, re.DOTALL).group(1)
+                if response.startswith("```") or response.endswith("```"):
+                    try:
+                        generated_code = re.search('```(.*)```', response, re.DOTALL).group(1)
+                    except AttributeError:
+                        generated_code = response
+                else:
+                    st.toast("Error extracting code", icon="❌")
+                    return response
+
                 if generated_code:
                     # Skip the language name in the first line.
                     response = generated_code.split("\n", 1)[1]
@@ -104,5 +113,58 @@ def generate_code(self, code_prompt, code_language):
             logger.error(f"Error generating code: {str(exception)} stack trace: {stack_trace}")
             st.toast(f"Error generating code: {str(exception)} stack trace: {stack_trace}", icon="❌")
 
+    def generate_code_completion(self, code_prompt, code_language):
+        try:
+            if not code_prompt or len(code_prompt) == 0:
+                logger.error("Code prompt is empty or null.")
+                st.error("Code generation cannot be performed as the code prompt is empty or null.")
+                return None
+
+            logger.info(f"Generating code completion with parameters: {code_prompt}, {code_language}")
+            template = "Complete the following {code_language} code: {code_prompt}"
+            prompt_obj = PromptTemplate(template=template, input_variables=["code_language", "code_prompt"])
+
+            max_tokens = st.session_state["vertexai"]["max_tokens"]
+            temperature = st.session_state["vertexai"]["temperature"]
+
+            # The Gecko model accepts at most 65 output tokens.
+            if max_tokens > 65:
+                max_tokens = 65
+                logger.info("Maximum number of tokens for model Gecko can't exceed 65. Setting max_tokens to 65.")
+                st.toast("Maximum number of tokens for model Gecko can't exceed 65. Setting max_tokens to 65.", icon="⚠️")
+
+            self.model_name = "code-gecko"  # The code completion model.
+            self.llm = VertexAI(model_name=self.model_name, max_output_tokens=max_tokens, temperature=temperature)
+            logger.info(f"Initialized VertexAI with model: {self.model_name}")
+            llm_chain = LLMChain(prompt=prompt_obj, llm=self.llm)
+            response = llm_chain.run({"code_prompt": code_prompt, "code_language": code_language})
+
+            if response:
+                logger.info(f"Code completion generated successfully: {response}")
+                return response
+            else:
+                logger.warning("No response received from LLMChain.")
+                return None
+        except Exception as e:
+            logger.error(f"Error generating code completion: {str(e)}")
+            raise
+
+    def set_temperature(self, temperature):
+        self.temperature = temperature
+        self.vertexai_llm.temperature = temperature
+        # Reload the model with the new temperature; all other settings stay the same.
+        self.load_model(self.model_name, self.max_tokens, self.temperature)
+
+    def set_max_tokens(self, max_tokens):
+        self.max_tokens = max_tokens
+        self.vertexai_llm.max_output_tokens = max_tokens
+        # Reload the model with the new max_output_tokens; all other settings stay the same.
+        self.load_model(self.model_name, self.max_tokens, self.temperature)
+
+    def set_model_name(self, model_name):
+        self.model_name = model_name
+        # Reload the model with the new model_name; all other settings stay the same.
+        self.load_model(self.model_name, self.max_tokens, self.temperature)
+
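The fence-handling branch added to `generate_code` can be exercised in isolation. Below is a minimal sketch that mirrors it outside the class; `extract_code` is an illustrative helper name, not part of this commit:

```python
import re

def extract_code(response: str) -> str:
    # Mirror of the new branch in generate_code: pull the body out of a
    # ```-fenced response, falling back to the raw text when re.search
    # finds no closing fence (calling .group() on None raises AttributeError).
    if response.startswith("```") or response.endswith("```"):
        try:
            generated_code = re.search('```(.*)```', response, re.DOTALL).group(1)
        except AttributeError:
            generated_code = response
    else:
        return response
    # Skip the language name on the first line, as generate_code does.
    return generated_code.split("\n", 1)[1]

print(extract_code("```python\nprint('hi')\n```"))  # -> print('hi')
```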
107169
108170
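For context, here is a minimal usage sketch of the class after this commit. It assumes valid service-account credentials and the `st.session_state["vertexai"]` keys that `generate_code_completion` reads; the project ID, key-file path, and import path below are placeholders, not taken from this repository:

```python
import streamlit as st
from vertexai_langchain import VertexAILangChain  # placeholder import path

# generate_code_completion reads max_tokens and temperature from here.
st.session_state["vertexai"] = {"max_tokens": 256, "temperature": 0.3}

llm = VertexAILangChain(
    project="my-gcp-project",                      # placeholder project ID
    credentials_file_path="service_account.json",  # placeholder key file
)
llm.load_model(llm.model_name, llm.max_tokens, llm.temperature)

completion = llm.generate_code_completion("def fibonacci(n):", "Python")
```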