Skip to content

Commit 037b096

Browse files
committed
Added Example Code Snippets
1 parent 04c73ef commit 037b096

File tree

4 files changed

+123
-36
lines changed

4 files changed

+123
-36
lines changed

libs/geminiai.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,9 @@ def convert_generated_code(self, code, code_language):
240240
if gemini_completion:
241241
# Extracted code from the Gemini completion
242242
code = gemini_completion.text
243-
extracted_code = self.utils.extract_code(code)
243+
extracted_code = None
244+
if code:
245+
extracted_code = self.utils.extract_code(code)
244246

245247
# Check if the code or extracted code is not empty or null
246248
if not code or not extracted_code:

libs/tasks_parser.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import json
2+
from libs.logger import logger
3+
import random
4+
5+
class CodingTasksParser:
    """Parse coding tasks from the bundled JSON file.

    Expected file layout:
        {"coding_tasks": [{"task": ...,
                           "example": {"input": ..., "output": ...}}, ...]}
    """

    def __init__(self):
        """Initialize the parser and eagerly parse the tasks file.

        Raises whatever ``_parse`` raises (e.g. FileNotFoundError,
        json.JSONDecodeError) if the file is missing or malformed.
        """
        self.file_path = 'data/coding_tasks.json'
        self._data = None
        self._parse()

    def _parse(self):
        """Parse the JSON file and store the resulting dict in ``self._data``."""
        try:
            # Explicit encoding avoids platform-dependent default codecs.
            with open(self.file_path, 'r', encoding='utf-8') as f:
                self._data = json.load(f)
            logger.info(f'Successfully parsed file {self.file_path}')
        except Exception as exception:
            logger.error(f'Failed to parse file {self.file_path}: {exception}')
            raise

    def _get_tasks(self):
        """Return the list of all tasks.

        Raises:
            ValueError: if the file has not been parsed yet.
        """
        if self._data is None:
            raise ValueError('No data parsed yet')
        return self._data['coding_tasks']

    def _get_task(self, index):
        """Return the task at ``index``; logs and re-raises on a bad index.

        Raises:
            ValueError: if the file has not been parsed yet.
            IndexError: if ``index`` is out of range.
        """
        if self._data is None:
            raise ValueError('No data parsed yet')
        try:
            return self._data['coding_tasks'][index]
        except IndexError:
            logger.error(f'No task at index {index}')
            raise

    def get_random_task(self):
        """Return ``(task, example_input, example_output)`` for a random task.

        Raises:
            ValueError: with a clear message when the task list is empty
                (instead of the opaque ``ValueError`` that
                ``random.randint(0, -1)`` would have produced).
        """
        try:
            tasks = self._get_tasks()
            if not tasks:
                # Guard: randint(0, -1) on an empty list raises a confusing
                # "empty range" error; fail with an explicit message instead.
                raise ValueError('Task list is empty')
            task = self._get_task(random.randrange(len(tasks)))
            return task['task'], task['example']['input'], task['example']['output']
        except Exception as exception:
            logger.error(f'Failed to get random task: {exception}')
            raise

libs/utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,10 @@ def initialize_session_state():
6262
st.session_state.compiler_online_privacy_accepted = None
6363
if "compiler_api_privacy_accepted" not in st.session_state:
6464
st.session_state.compiler_api_privacy_accepted = None
65+
if "general_utils" not in st.session_state:
66+
st.session_state.general_utils = None
67+
if "tasks_parser" not in st.session_state:
68+
st.session_state.tasks_parser = None
6569

6670
# Initialize session state for Vertex AI
6771
if "vertexai" not in st.session_state:
@@ -162,12 +166,12 @@ def load_css(file_name):
162166
st.markdown(f'<style>{f.read()}</style>', unsafe_allow_html=True)
163167

164168
def display_code_editor(font_size, tab_size, theme, keybinding, show_gutter, show_print_margin, wrap, auto_update, readonly, language):
165-
if st.session_state.generated_code and st.session_state.compiler_mode == "Offline":
169+
if st.session_state.generated_code and st.session_state.compiler_mode in ["Offline", "API"]:
166170
st.session_state.generated_code = st_ace(
167171
language=language.lower(),
168172
theme=theme,
169173
keybinding=keybinding,
170-
height=400,
174+
height=600,
171175
value=st.session_state.generated_code,
172176
font_size=font_size,
173177
tab_size=tab_size,

script.py

Lines changed: 66 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import os
1616
from libs.geminiai import GeminiAI
1717
from libs.palmai import PalmAI
18+
from libs.tasks_parser import CodingTasksParser
1819
import streamlit as st
1920
from libs.vertexai_langchain import VertexAILangChain
2021
from libs.general_utils import GeneralUtils
@@ -24,7 +25,7 @@
2425
from libs.utils import *
2526
from streamlit_ace import st_ace
2627

27-
general_utils = None
28+
st.session_state.general_utils = None
2829

2930
def main():
3031

@@ -45,7 +46,8 @@ def main():
4546

4647
# Initialize classes
4748
code_language = st.session_state.get("code_language", "Python")
48-
general_utils = GeneralUtils()
49+
st.session_state.general_utils = GeneralUtils()
50+
st.session_state.tasks_parser = CodingTasksParser()
4951

5052
# Streamlit UI
5153
st.markdown("<h1 style='text-align: center; color: black;'>LangChain Coder - AI - v1.7 🦜🔗</h1>", unsafe_allow_html=True)
@@ -83,7 +85,7 @@ def main():
8385
logs_data = file.read()
8486
# download the logs
8587
file_format = "text/plain"
86-
st.session_state.download_link = general_utils.generate_download_link(logs_data, logs_filename, file_format,True)
88+
st.session_state.download_link = st.session_state.general_utils.generate_download_link(logs_data, logs_filename, file_format,True)
8789

8890
# Setting options for Open AI
8991
api_key = None
@@ -141,7 +143,7 @@ def main():
141143
if st.session_state.uploaded_file:
142144
logger.info(f"Vertex AI File credentials file '{st.session_state.uploaded_file.name}' initialized state {st.session_state.vertex_ai_loaded}")
143145
# Save the temporary uploaded file and delete it after 60 seconds due to security reasons. (Credentials file is deleted after 60 seconds)
144-
file_path = general_utils.save_uploaded_file_temp(st.session_state.uploaded_file) # Save the uploaded file
146+
file_path = st.session_state.general_utils.save_uploaded_file_temp(st.session_state.uploaded_file) # Save the uploaded file
145147
if file_path:
146148
credentials_file_path = file_path
147149
else:
@@ -244,23 +246,33 @@ def main():
244246
logger.error(f"Error loading Gemini AI: {str(exception)}")
245247

246248
# UI Elements - Main Page
247-
vertex_model_selected = st.session_state["vertexai"]["model_name"]
248-
if vertex_model_selected == "code-bison":
249-
placeholder = "Enter your prompt for code generation."
250-
elif vertex_model_selected == "code-gecko":
251-
placeholder = "Enter your code for code completion."
249+
if st.session_state.ai_option == "Vertex AI":
250+
vertex_model_selected = st.session_state["vertexai"]["model_name"]
251+
if vertex_model_selected == "code-bison":
252+
placeholder = "Enter your prompt for code generation."
253+
elif vertex_model_selected == "code-gecko":
254+
placeholder = "Enter your code for code completion."
252255
else:
253-
placeholder = "Enter your prompt for code generation."
254-
st.error(f"Invalid Vertex AI model selected: {vertex_model_selected}")
255-
256-
# Input box for entering the prompt
257-
st.session_state.code_prompt = st.text_area("Enter Prompt", height=200, placeholder=placeholder,label_visibility='hidden')
256+
if st.session_state.code_prompt:
257+
placeholder = st.session_state.code_prompt
258+
else:
259+
placeholder = "Enter your prompt for code generation."
258260

261+
# Input box for entering the prompt
262+
st.session_state.code_prompt = st.text_area(
263+
"Enter Prompt",
264+
value=st.session_state.code_prompt if 'code_prompt' in st.session_state else "",
265+
height=130,
266+
placeholder="Enter your prompt for code generation." if 'code_prompt' not in st.session_state else "",
267+
label_visibility='hidden'
268+
)
269+
270+
# Settings for input and output options.
259271
with st.expander("Input Options"):
260272
with st.container():
261273
st.session_state.code_input = st.text_input("Input (Stdin)", placeholder="Input (Stdin)", label_visibility='collapsed',value=st.session_state.code_input)
262274
st.session_state.code_output = st.text_input("Output (Stdout)", placeholder="Output (Stdout)", label_visibility='collapsed',value=st.session_state.code_output)
263-
st.session_state.code_fix_instructions = st.text_input("Fix instructions", placeholder="Fix instructions", label_visibility='collapsed',value=st.session_state.code_fix_instructions)
275+
st.session_state.code_fix_instructions = st.text_input("Debug instructions", placeholder="Debug instructions", label_visibility='collapsed',value=st.session_state.code_fix_instructions)
264276

265277
# Set the input and output to None if the input and output is empty
266278
if st.session_state.code_input and st.session_state.code_output:
@@ -275,10 +287,10 @@ def main():
275287
else:
276288
logger.info(f"Stdout: {st.session_state.code_output}")
277289

278-
290+
# Buttons for generating, saving, running and debugging the code
279291
with st.form('code_controls_form'):
280292
# Create columns for alignment
281-
file_name_col, save_code_col,generate_code_col,run_code_col,debug_code_col,convert_code_col = st.columns(6)
293+
file_name_col, save_code_col,generate_code_col,run_code_col,debug_code_col,convert_code_col,example_code_col = st.columns(7)
282294

283295
# Input Box (for entering the file name) in the first column
284296
with file_name_col:
@@ -289,7 +301,7 @@ def main():
289301
download_code_submitted = st.form_submit_button("Download")
290302
if download_code_submitted:
291303
file_format = "text/plain"
292-
st.session_state.download_link = general_utils.generate_download_link(st.session_state.generated_code, code_file,file_format,True)
304+
st.session_state.download_link = st.session_state.general_utils.generate_download_link(st.session_state.generated_code, code_file,file_format,True)
293305

294306
# Generate Code button in the third column
295307
with generate_code_col:
@@ -363,8 +375,12 @@ def main():
363375
ai_llm_selected = st.session_state.openai_langchain
364376

365377
if not st.session_state.code_fix_instructions:
366-
st.toast("Missing fix instructions", icon="❌")
367-
logger.warning("Missing fix instructions")
378+
st.toast("Missing Debug instructions", icon="❌")
379+
logger.warning("Missing Debug instructions")
380+
381+
if not st.session_state.stderr and st.session_state.code_fix_instructions:
382+
st.session_state.stderr = st.session_state.code_fix_instructions
383+
logger.info("Setting Stderr from input to Debug instructions.")
368384

369385
logger.info(f"Fixing code with instructions: {st.session_state.code_fix_instructions}")
370386
st.session_state.generated_code = ai_llm_selected.fix_generated_code(st.session_state.generated_code, st.session_state.code_language,st.session_state.code_fix_instructions)
@@ -394,11 +410,23 @@ def main():
394410
privacy_accepted = st.session_state.get(f'compiler_{st.session_state.compiler_mode.lower()}_privacy_accepted', False)
395411

396412
if privacy_accepted:
397-
st.session_state.output = general_utils.execute_code(st.session_state.compiler_mode)
413+
st.session_state.output = st.session_state.general_utils.execute_code(st.session_state.compiler_mode)
398414
else:
399415
st.toast(f"You didn't accept the privacy policy for {st.session_state.compiler_mode} compiler.", icon="❌")
400416
logger.error(f"You didn't accept the privacy policy for {st.session_state.compiler_mode} compiler.")
401417

418+
# Example Code button in the fifth column
419+
with example_code_col:
420+
example_submitted = st.form_submit_button("Example")
421+
if example_submitted:
422+
task_name, task_input, task_output = st.session_state.tasks_parser.get_random_task()
423+
st.session_state.code_prompt = "Task = '" + str(task_name) + "'\nInput = '" + str(task_input) + "'\nOutput = '" + str(task_output) + "'"
424+
st.session_state.code_input = task_input
425+
st.session_state.code_output = task_output
426+
logger.info(f"Example code loaded successfully. Task name: {task_name}, Task input: {task_input}, Task output: {task_output}")
427+
st.toast(f"Example code loaded successfully. Task name: {task_name}, Task input: {task_input}, Task output: {task_output}", icon="✅")
428+
st.rerun()
429+
402430
# Show the privacy policy for compilers.
403431
handle_privacy_policy(st.session_state.compiler_mode)
404432

@@ -435,44 +463,49 @@ def main():
435463
# Display the code output
436464
if st.session_state.output:
437465
st.markdown("### Output")
438-
st.code(st.session_state.output, language=st.session_state.code_language.lower())
439-
466+
#st.toast(f"Compiler mode selected '{st.session_state.compiler_mode}'", icon="✅")
467+
if (st.session_state.compiler_mode.lower() in ["offline", "api"]):
468+
if "https://www.jdoodle.com/plugin" in st.session_state.output:
469+
pass
470+
else:
471+
st.code(st.session_state.output, language=st.session_state.code_language.lower())
472+
440473
# Display the price of the generated code.
441474
if st.session_state.generated_code and st.session_state.display_cost:
442475
if st.session_state.ai_option == "Open AI":
443476
selected_model = st.session_state["openai"]["model_name"]
444477
if selected_model == "gpt-3":
445-
cost, cost_per_whole_string, total_cost = general_utils.gpt_3_generation_cost(st.session_state.generated_code)
478+
cost, cost_per_whole_string, total_cost = st.session_state.general_utils.gpt_3_generation_cost(st.session_state.generated_code)
446479
st.table([["Cost/1K Token", f"{cost} USD"], ["Cost/Whole String", f"{cost_per_whole_string} USD"], ["Total Cost", f"{total_cost} USD"]])
447480
elif selected_model == "gpt-4":
448-
cost, cost_per_whole_string, total_cost = general_utils.gpt_4_generation_cost(st.session_state.generated_code)
481+
cost, cost_per_whole_string, total_cost = st.session_state.general_utils.gpt_4_generation_cost(st.session_state.generated_code)
449482
st.table([["Cost/1K Token", f"{cost} USD"], ["Cost/Whole String", f"{cost_per_whole_string} USD"], ["Total Cost", f"{total_cost} USD"]])
450483
elif selected_model == "text-davinci-003":
451-
cost, cost_per_whole_string, total_cost = general_utils.gpt_text_davinci_generation_cost(st.session_state.generated_code)
484+
cost, cost_per_whole_string, total_cost = st.session_state.general_utils.gpt_text_davinci_generation_cost(st.session_state.generated_code)
452485
st.table([["Cost/1K Token", f"{cost} USD"], ["Cost/Whole String", f"{cost_per_whole_string} USD"], ["Total Cost", f"{total_cost} USD"]])
453486

454487
elif st.session_state.ai_option == "Vertex AI":
455488
selected_model = st.session_state["vertexai"]["model_name"]
456489
if selected_model == "code-bison" or selected_model == "code-gecko":
457-
cost, cost_per_whole_string, total_cost = general_utils.codey_generation_cost(st.session_state.generated_code)
490+
cost, cost_per_whole_string, total_cost = st.session_state.general_utils.codey_generation_cost(st.session_state.generated_code)
458491
st.table([["Cost/1K Token", f"{cost} USD"], ["Cost/Whole String", f"{cost_per_whole_string} USD"], ["Total Cost", f"{total_cost} USD"]])
459492

460493
elif st.session_state.ai_option == "Palm AI":
461494
selected_model = st.session_state["palm"]["model_name"]
462495
if selected_model == "text-bison-001":
463496
cost = 0.00025 # Cost per 1K input characters for online requests
464497
cost_per_whole_string = 0.0005 # Cost per 1K output characters for online requests
465-
total_cost = general_utils.palm_text_bison_generation_cost(st.session_state.generated_code)
498+
total_cost = st.session_state.general_utils.palm_text_bison_generation_cost(st.session_state.generated_code)
466499
st.table([["Cost/1K Token", f"{cost} USD"], ["Cost/Whole String", f"{cost_per_whole_string} USD"], ["Total Cost", f"{total_cost} USD"]])
467500
elif selected_model == "chat-bison-001":
468501
cost = 0.00025 # Cost per 1K input characters for online requests
469502
cost_per_whole_string = 0.0005 # Cost per 1K output characters for online requests
470-
total_cost = general_utils.palm_chat_bison_generation_cost(st.session_state.generated_code)
503+
total_cost = st.session_state.general_utils.palm_chat_bison_generation_cost(st.session_state.generated_code)
471504
st.table([["Cost/1K Token", f"{cost} USD"], ["Cost/Whole String", f"{cost_per_whole_string} USD"], ["Total Cost", f"{total_cost} USD"]])
472505
elif selected_model == "embedding-gecko-001":
473506
cost = 0.0002 # Cost per 1K characters input for generating embeddings using text as an input
474507
cost_per_whole_string = 0.0002 # Assuming the same cost for output characters
475-
total_cost = general_utils.palm_embedding_gecko_generation_cost(st.session_state.generated_code)
508+
total_cost = st.session_state.general_utils.palm_embedding_gecko_generation_cost(st.session_state.generated_code)
476509
st.table([["Cost/1K Token", f"{cost} USD"], ["Cost/Whole String", f"{cost_per_whole_string} USD"], ["Total Cost", f"{total_cost} USD"]])
477510

478511
elif st.session_state.ai_option == "Gemini AI":
@@ -481,15 +514,15 @@ def main():
481514
if selected_model == "gemini-pro":
482515
cost_per_input_char = 0.00025 # Cost per 1K input characters for online requests
483516
cost_per_output_char = 0.0005 # Cost per 1K output characters for online requests
484-
total_cost = general_utils.gemini_pro_generation_cost(st.session_state.generated_code)
517+
total_cost = st.session_state.general_utils.gemini_pro_generation_cost(st.session_state.generated_code)
485518
st.table([["Cost/1K Input Token", f"{cost_per_input_char} USD"], ["Cost/1K Output Token", f"{cost_per_output_char} USD"], ["Total Cost", f"{total_cost} USD"]])
486519

487520
elif selected_model == "gemini-pro-vision":
488521
cost_per_image = 0.0025 # Cost per image for online requests
489522
cost_per_second = 0.002 # Cost per second for online requests
490523
cost_per_input_char = 0.00025 # Cost per 1K input characters for online requests
491524
cost_per_output_char = 0.0005 # Cost per 1K output characters for online requests
492-
total_cost = general_utils.gemini_pro_vision_generation_cost(st.session_state.generated_code)
525+
total_cost = st.session_state.general_utils.gemini_pro_vision_generation_cost(st.session_state.generated_code)
493526
st.table([["Cost/Image", f"{cost_per_image} USD"], ["Cost/Second", f"{cost_per_second} USD"], ["Cost/1K Input Token", f"{cost_per_input_char} USD"], ["Cost/1K Output Token", f"{cost_per_output_char} USD"], ["Total Cost", f"{total_cost} USD"]])
494527

495528
# Expander for coding guidelines
@@ -510,7 +543,7 @@ def main():
510543
"Robust Code",
511544
"Memory efficiency",
512545
"Speed efficiency",
513-
"Standard Naming conventions"
546+
"Standard Naming"
514547
]
515548

516549
for guideline in guidelines:

0 commit comments

Comments
 (0)