Skip to content

Commit 4856d18

Browse files
committed
Add code fixing functionality to OpenAI
1 parent 078795f commit 4856d18

File tree

2 files changed

+94
-2
lines changed

2 files changed

+94
-2
lines changed

libs/openai_langchain.py

Lines changed: 91 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,4 +139,94 @@ def generate_code(self,code_prompt,code_language):
139139
logger.error("Error in code generation: Please enter a valid prompt and language.")
140140
except Exception as e:
141141
st.toast(f"Error in code generation: {e}", icon="❌")
142-
logger.error(f"Error in code generation: {traceback.format_exc()}")
142+
logger.error(f"Error in code generation: {traceback.format_exc()}")
143+
144+
def fix_generated_code(self, code_snippet, code_language, fix_instructions=""):
145+
"""
146+
Function to fix the generated code using the palm API.
147+
"""
148+
try:
149+
# Check for valid code
150+
if not code_snippet or len(code_snippet) == 0:
151+
logger.error("Error in code fixing: Please enter a valid code.")
152+
return
153+
154+
logger.info(f"Fixing code")
155+
if code_snippet and len(code_snippet) > 0:
156+
logger.info(f"Fixing code {code_snippet[:100]}... in language {code_language} and error is {st.session_state.stderr}")
157+
158+
# Improved instructions template
159+
template = f"""
160+
Task: Correct the code snippet provided below in the {code_language} programming language, following the given instructions {fix_instructions}
161+
162+
{code_snippet}
163+
164+
Instructions for Fixing:
165+
1. Identify and rectify any syntax errors, logical issues, or bugs in the code.
166+
2. Ensure that the code produces the desired output.
167+
3. Comment on each line where you make changes, explaining the nature of the fix.
168+
4. Verify that the corrected code is displayed in the output.
169+
170+
Please make sure that the fixed code is included in the output, along with comments detailing the modifications made.
171+
"""
172+
173+
# If there was an error in the previous execution, include it in the prompt
174+
if st.session_state.stderr:
175+
logger.info(f"Error in previous execution: {st.session_state.stderr}")
176+
st.toast(f"Error in previous execution: {st.session_state.stderr}", icon="❌")
177+
template += f"\nFix the following error: {st.session_state.output}"
178+
179+
# Check if the error indicates a missing or unavailable module
180+
error_message = st.session_state.output.lower() # Convert to lowercase for case-insensitive matching
181+
182+
else:
183+
st.toast("No error in previous execution.", icon="✅")
184+
return code_snippet
185+
186+
# Prompt Templates
187+
code_template = template
188+
189+
# LLM Chains definition
190+
# Create a chain that fixed the code
191+
fix_generated_template = PromptTemplate(
192+
input_variables=['code_prompt', 'code_language'],
193+
template=code_template
194+
)
195+
196+
fix_generated_chain = LLMChain(
197+
llm=self.lite_llm,
198+
prompt=fix_generated_template,
199+
output_key='fixed_code',
200+
memory=self.memory,
201+
verbose=True
202+
)
203+
204+
# Prepare the input for the chain
205+
input_data = {
206+
'code_prompt': code_snippet,
207+
'code_language': code_language
208+
}
209+
210+
# Run the chain
211+
output = fix_generated_chain.run(input_data)
212+
213+
logger.info("Text generation completed successfully.")
214+
215+
if output:
216+
# Extracted code from the palm completion
217+
fixed_code = output['code_fix']
218+
extracted_code = self.utils.extract_code(fixed_code)
219+
220+
# Check if the code or extracted code is not empty or null
221+
if not code_snippet or not extracted_code:
222+
raise Exception("Error: Generated code or extracted code is empty or null.")
223+
else:
224+
return extracted_code
225+
else:
226+
raise Exception("Error in code fixing: Please enter a valid code.")
227+
else:
228+
st.toast("Error in code fixing: Please enter a valid code and language.", icon="❌")
229+
logger.error("Error in code fixing: Please enter a valid code and language.")
230+
except Exception as exception:
231+
st.toast(f"Error in code fixing: {exception}", icon="❌")
232+
logger.error(f"Error in code fixing: {traceback.format_exc()}")

script.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -433,8 +433,10 @@ def main():
433433
ai_llm_selected = st.session_state.palm_langchain
434434
elif st.session_state.ai_option == "Gemini AI":
435435
ai_llm_selected = st.session_state.gemini_langchain
436+
elif st.session_state.ai_option == "Open AI":
437+
ai_llm_selected = st.session_state.openai_langchain
436438

437-
if len(st.session_state.code_fix_instructions) == 0:
439+
if not st.session_state.code_fix_instructions:
438440
st.toast("Missing fix instructions", icon="❌")
439441
logger.warning("Missing fix instructions")
440442

0 commit comments

Comments
 (0)