
Commit a7efab6

Added support for Gemini AI model
1 parent b9a5c01 commit a7efab6

File tree

5 files changed: 261 additions, 38 deletions


.env

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+PALM_API_KEY="XXXXXXXXXX"
+GEMINI_API_KEY="XXXXXXXXXX"
+OPENAI_API_KEY="XXXXXXXXXX"
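For context, a minimal sketch of how these keys are typically read with python-dotenv; the loading code itself is not part of this commit, and the variable names simply mirror the .env entries above:

import os
from dotenv import load_dotenv

load_dotenv()  # read key=value pairs from .env into the process environment

palm_key = os.getenv("PALM_API_KEY")
gemini_key = os.getenv("GEMINI_API_KEY")
openai_key = os.getenv("OPENAI_API_KEY")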

libs/geminiai.py

Lines changed: 149 additions & 0 deletions
@@ -0,0 +1,149 @@
+
+import traceback
+import google.generativeai as genai
+from dotenv import load_dotenv
+from libs.logger import logger
+import streamlit as st
+
+class GeminiAI:
+    def __init__(self, api_key, model="gemini-pro", temperature=0.1, max_output_tokens=2048, mode="balanced"):
+        self.api_key = api_key
+        self.model = model
+        self.temperature = temperature
+        self.max_output_tokens = max_output_tokens
+        self.mode = mode
+        self.top_k = 20
+        self.top_p = 0.85
+        self._configure()
+
+        # Dynamically construct guidelines based on session state
+        self.guidelines_list = []
+
+        if st.session_state["coding_guidelines"]["modular_code"]:
+            self.guidelines_list.append("- Ensure the method is modular in its approach.")
+        if st.session_state["coding_guidelines"]["exception_handling"]:
+            self.guidelines_list.append("- Integrate robust exception handling.")
+        if st.session_state["coding_guidelines"]["error_handling"]:
+            self.guidelines_list.append("- Add error handling to each module.")
+        if st.session_state["coding_guidelines"]["efficient_code"]:
+            self.guidelines_list.append("- Optimize the code to ensure it runs efficiently.")
+        if st.session_state["coding_guidelines"]["robust_code"]:
+            self.guidelines_list.append("- Ensure the code is robust against potential issues.")
+        if st.session_state["coding_guidelines"]["naming_conventions"]:
+            self.guidelines_list.append("- Follow standard naming conventions.")
+
+        # Convert the list to a string
+        self.guidelines = "\n".join(self.guidelines_list)
+
+
+    def _configure(self):
+        try:
+            logger.info("Configuring Gemini AI Pro...")
+            genai.configure(api_key=self.api_key)
+            self.generation_config = {
+                "temperature": self.temperature,
+                "top_p": self.top_p,
+                "top_k": self.top_k,
+                "max_output_tokens": self.max_output_tokens
+            }
+            self.model = genai.GenerativeModel(model_name=self.model, generation_config=self.generation_config)
+            logger.info("Gemini AI Pro configured successfully.")
+        except Exception as exception:
+            logger.error(f"Error configuring Gemini AI Pro: {str(exception)}")
+            traceback.print_exc()
+
+    def _extract_code(self, code):
+        """
+        Extracts the code from the provided string.
+        If the string contains '```', it extracts the code between them.
+        Otherwise, it returns the original string.
+        """
+        try:
+            if '```' in code:
+                start = code.find('```') + len('```\n')
+                end = code.find('```', start)
+                # Skip the first line after ```
+                start = code.find('\n', start) + 1
+                extracted_code = code[start:end]
+                logger.info("Code extracted successfully.")
+                return extracted_code
+            else:
+                logger.info("No special characters found in the code. Returning the original code.")
+                return code
+        except Exception as exception:
+            logger.error(f"Error occurred while extracting code: {exception}")
+            return None
+
+    def generate_code(self, code_prompt, code_language):
+        """
+        Function to generate text using the Gemini API.
+        """
+        try:
+            # Define top_k, top_p and temperature based on the mode
+            if self.mode == "precise":
+                top_k = 40
+                top_p = 0.95
+                self.temperature = 0
+            elif self.mode == "balanced":
+                top_k = 20
+                top_p = 0.85
+                self.temperature = 0.3
+            elif self.mode == "creative":
+                top_k = 10
+                top_p = 0.75
+                self.temperature = 1
+            else:
+                raise ValueError("Invalid mode. Choose from 'precise', 'balanced', 'creative'.")
+
+            logger.info(f"Generating code with mode: {self.mode}, top_k: {top_k}, top_p: {top_p}")
+
+
+            # check for valid prompt and language
+            if not code_prompt or len(code_prompt) == 0:
+                st.toast("Error in code generation: Please enter a valid prompt.", icon="❌")
+                logger.error("Error in code generation: Please enter a valid prompt.")
+                return
+
+            logger.info(f"Generating code for prompt: {code_prompt} in language: {code_language}")
+            if code_prompt and len(code_prompt) > 0 and code_language and len(code_language) > 0:
+                logger.info(f"Generating code for prompt: {code_prompt} in language: {code_language}")
+
+                # Plain and Simple Coding Task Prompt
+                prompt = f"""
+                Task: You're an experienced developer. Your mission is to create a program for {code_prompt} in {code_language} that takes {st.session_state.code_input} as input.
+
+                Your goal is clear: Craft a solution that showcases your expertise as a coder and problem solver.
+
+                Ensure that the program's output contains only the code you've written, with no extraneous information.
+
+                Show your skills and solve this challenge with confidence!
+
+                And follow the proper coding guidelines and dont add comment unless instructed to do so.
+                {self.guidelines}
+                """
+
+                gemini_completion = self.model.generate_content(prompt)
+                logger.info("Text generation completed successfully.")
+
+                code = None
+                if gemini_completion:
+                    # extract the code from the gemini completion
+                    code = gemini_completion.text
+                logger.info(f"GeminiAI coder is initialized.")
+                logger.info(f"Generated code: {code[:100]}...")
+
+                if gemini_completion:
+                    # Extracted code from the gemini completion
+                    extracted_code = self._extract_code(code)
+
+                    # Check if the code or extracted code is not empty or null
+                    if not code or not extracted_code:
+                        raise Exception("Error: Generated code or extracted code is empty or null.")
+
+                    return extracted_code
+                else:
+                    raise Exception("Error in code generation: Please enter a valid code.")
+
+        except Exception as exception:
+            st.toast(f"Error in code generation: {exception}", icon="❌")
+            logger.error(f"Error in code generation: {traceback.format_exc()}")

libs/general_utils.py

Lines changed: 32 additions & 9 deletions
@@ -401,24 +401,47 @@ def calculate_code_generation_cost(self,string,price=0.0005):
         total_cost = cost * number_of_words
 
         # Return the cost, cost per whole string and total cost
-        return cost, cost_per_whole_string, total_cost
+        #return cost, cost_per_whole_string, total_cost
+
+        # Return the total cost
+        return total_cost
 
     def codey_generation_cost(self,string):
         codey_price = 0.0005
         return self.calculate_code_generation_cost(string,codey_price)
 
     def gpt_3_generation_cost(self,string):
-        chatgpt_price = 0.0002
-        return self.calculate_code_generation_cost(string,chatgpt_price)
+        model_price = 0.0002
+        return self.calculate_code_generation_cost(string,model_price)
 
     def gpt_3_5_turbo_generation_costself(self,string):
-        chatgpt_price = 0.0080
-        return self.calculate_code_generation_cost(string,chatgpt_price)
+        model_price = 0.0080
+        return self.calculate_code_generation_cost(string,model_price)
 
     def gpt_4_generation_cost(self,string):
-        chatgpt_price = 0.06
-        return self.calculate_code_generation_cost(string,chatgpt_price)
+        model_price = 0.06
+        return self.calculate_code_generation_cost(string,model_price)
 
     def gpt_text_davinci_generation_cost(self,string):
-        chatgpt_price = 0.0060
-        return self.calculate_code_generation_cost(string,chatgpt_price)
+        model_price = 0.0060
+        return self.calculate_code_generation_cost(string,model_price)
+
+    def palm_text_bison_generation_cost(self,string):
+        model_price = 0.00025
+        return self.calculate_code_generation_cost(string,model_price)
+
+    def palm_chat_bison_generation_cost(self,string):
+        model_price = 0.00025
+        return self.calculate_code_generation_cost(string,model_price)
+
+    def palm_embedding_gecko_generation_cost(self,string):
+        model_price = 0.0002
+        return self.calculate_code_generation_cost(string,model_price)
+
+    def gemini_pro_generation_cost(self,string):
+        model_price = 0.00025
+        return self.calculate_code_generation_cost(string,model_price)
+
+    def gemini_pro_vision_generation_cost(self,string):
+        model_price = 0.00025
+        return self.calculate_code_generation_cost(string,model_price)
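One side effect worth noting: calculate_code_generation_cost previously returned a (cost, cost_per_whole_string, total_cost) tuple and now returns only the total cost, so any caller that unpacked the tuple needs a one-line change. A hedged before/after sketch; the GeneralUtils class name and caller variables are illustrative assumptions, not shown in this diff:

from libs.general_utils import GeneralUtils  # assumed class name; not visible in this diff

utils = GeneralUtils()
generated_code = "print('hello world')"

# Before this commit the helpers returned three values:
#   cost, cost_per_whole_string, total_cost = utils.gpt_4_generation_cost(generated_code)
# After this commit they return a single number:
total_cost = utils.gpt_4_generation_cost(generated_code)
print(f"Estimated GPT-4 generation cost: ${total_cost:.4f}")

Returning just the total also keeps the per-model wrappers as two-line one-price methods.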

libs/palm_coder.py renamed to libs/palmai.py

Lines changed: 6 additions & 23 deletions
@@ -10,10 +10,6 @@
 import streamlit as st
 
 
-# Set up logging
-logging.basicConfig(filename='palm-coder.log', filemode='w', format='%(name)s - %(levelname)s - %(message)s')
-logger = logging.getLogger(__name__)
-
 class PalmAI:
     def __init__(self,api_key, model="text-bison-001", temperature=0.3, max_output_tokens=2048, mode="balanced"):
         """
@@ -87,8 +83,8 @@ def _extract_code(self, code):
             else:
                 logger.info("No special characters found in the code. Returning the original code.")
                 return code
-        except Exception as e:
-            logger.error(f"Error occurred while extracting code: {e}")
+        except Exception as exception:
+            logger.error(f"Error occurred while extracting code: {exception}")
             return None
 
     def _install_package(self, package_name):
@@ -164,19 +160,6 @@ def generate_code(self, code_prompt,code_language):
             And follow the proper coding guidelines and dont add comment unless instructed to do so.
             {self.guidelines}
             """
-
-            # If graph were requested.
-            if 'graph' in code_prompt.lower():
-                prompt += "\n" + "using Python use Matplotlib save the graph in file called 'graph.png'"
-
-            # if Chart were requested
-            if 'chart' in code_prompt.lower() or 'plot' in code_prompt.lower():
-                prompt += "\n" + "using Python use Plotly save the chart in file called 'chart.png'"
-
-            # if Table were requested
-            if 'table' in code_prompt.lower():
-                prompt += "\n" + "using Python use Pandas save the table in file called 'table.md'"
-
 
             palm_completion = palm.generate_text(
                 model=self.model,
@@ -210,8 +193,8 @@ def generate_code(self, code_prompt,code_language):
             else:
                 raise Exception("Error in code generation: Please enter a valid code.")
 
-        except Exception as e:
-            st.toast(f"Error in code generation: {e}", icon="❌")
+        except Exception as exception:
+            st.toast(f"Error in code generation: {exception}", icon="❌")
             logger.error(f"Error in code generation: {traceback.format_exc()}")
 
     def fix_generated_code(self, code, code_language, fix_instructions=""):
@@ -308,6 +291,6 @@ def fix_generated_code(self, code, code_language, fix_instructions=""):
             else:
                 st.toast("Error in code fixing: Please enter a valid code and language.", icon="❌")
                 logger.error("Error in code fixing: Please enter a valid code and language.")
-        except Exception as e:
-            st.toast(f"Error in code fixing: {e}", icon="❌")
+        except Exception as exception:
+            st.toast(f"Error in code fixing: {exception}", icon="❌")
             logger.error(f"Error in code fixing: {traceback.format_exc()}")

0 commit comments
