15
15
import os
16
16
from libs .geminiai import GeminiAI
17
17
from libs .palmai import PalmAI
18
+ from libs .tasks_parser import CodingTasksParser
18
19
import streamlit as st
19
20
from libs .vertexai_langchain import VertexAILangChain
20
21
from libs .general_utils import GeneralUtils
24
25
from libs .utils import *
25
26
from streamlit_ace import st_ace
26
27
27
- general_utils = None
28
+ st . session_state . general_utils = None
28
29
29
30
def main ():
30
31
@@ -45,7 +46,8 @@ def main():
45
46
46
47
# Initialize classes
47
48
code_language = st .session_state .get ("code_language" , "Python" )
48
- general_utils = GeneralUtils ()
49
+ st .session_state .general_utils = GeneralUtils ()
50
+ st .session_state .tasks_parser = CodingTasksParser ()
49
51
50
52
# Streamlit UI
51
53
st .markdown ("<h1 style='text-align: center; color: black;'>LangChain Coder - AI - v1.7 🦜🔗</h1>" , unsafe_allow_html = True )
@@ -83,7 +85,7 @@ def main():
83
85
logs_data = file .read ()
84
86
# download the logs
85
87
file_format = "text/plain"
86
- st .session_state .download_link = general_utils .generate_download_link (logs_data , logs_filename , file_format ,True )
88
+ st .session_state .download_link = st . session_state . general_utils .generate_download_link (logs_data , logs_filename , file_format ,True )
87
89
88
90
# Setting options for Open AI
89
91
api_key = None
@@ -141,7 +143,7 @@ def main():
141
143
if st .session_state .uploaded_file :
142
144
logger .info (f"Vertex AI File credentials file '{ st .session_state .uploaded_file .name } ' initialized state { st .session_state .vertex_ai_loaded } " )
143
145
# Save the temorary uploaded file and delete it after 60 seconds due to security reasons. (Credentials file is deleted after 60 seconds)
144
- file_path = general_utils .save_uploaded_file_temp (st .session_state .uploaded_file ) # Save the uploaded file
146
+ file_path = st . session_state . general_utils .save_uploaded_file_temp (st .session_state .uploaded_file ) # Save the uploaded file
145
147
if file_path :
146
148
credentials_file_path = file_path
147
149
else :
@@ -244,23 +246,33 @@ def main():
244
246
logger .error (f"Error loading Gemini AI: { str (exception )} " )
245
247
246
248
# UI Elements - Main Page
247
- vertex_model_selected = st .session_state ["vertexai" ]["model_name" ]
248
- if vertex_model_selected == "code-bison" :
249
- placeholder = "Enter your prompt for code generation."
250
- elif vertex_model_selected == "code-gecko" :
251
- placeholder = "Enter your code for code completion."
249
+ if st .session_state .ai_option == "Vertex AI" :
250
+ vertex_model_selected = st .session_state ["vertexai" ]["model_name" ]
251
+ if vertex_model_selected == "code-bison" :
252
+ placeholder = "Enter your prompt for code generation."
253
+ elif vertex_model_selected == "code-gecko" :
254
+ placeholder = "Enter your code for code completion."
252
255
else :
253
- placeholder = "Enter your prompt for code generation."
254
- st .error (f"Invalid Vertex AI model selected: { vertex_model_selected } " )
255
-
256
- # Input box for entering the prompt
257
- st .session_state .code_prompt = st .text_area ("Enter Prompt" , height = 200 , placeholder = placeholder ,label_visibility = 'hidden' )
256
+ if st .session_state .code_prompt :
257
+ placeholder = st .session_state .code_prompt
258
+ else :
259
+ placeholder = "Enter your prompt for code generation."
258
260
261
+ # Input box for entering the prompt
262
+ st .session_state .code_prompt = st .text_area (
263
+ "Enter Prompt" ,
264
+ value = st .session_state .code_prompt if 'code_prompt' in st .session_state else "" ,
265
+ height = 130 ,
266
+ placeholder = "Enter your prompt for code generation." if 'code_prompt' not in st .session_state else "" ,
267
+ label_visibility = 'hidden'
268
+ )
269
+
270
+ # Settings for input and output options.
259
271
with st .expander ("Input Options" ):
260
272
with st .container ():
261
273
st .session_state .code_input = st .text_input ("Input (Stdin)" , placeholder = "Input (Stdin)" , label_visibility = 'collapsed' ,value = st .session_state .code_input )
262
274
st .session_state .code_output = st .text_input ("Output (Stdout)" , placeholder = "Output (Stdout)" , label_visibility = 'collapsed' ,value = st .session_state .code_output )
263
- st .session_state .code_fix_instructions = st .text_input ("Fix instructions" , placeholder = "Fix instructions" , label_visibility = 'collapsed' ,value = st .session_state .code_fix_instructions )
275
+ st .session_state .code_fix_instructions = st .text_input ("Debug instructions" , placeholder = "Debug instructions" , label_visibility = 'collapsed' ,value = st .session_state .code_fix_instructions )
264
276
265
277
# Set the input and output to None if the input and output is empty
266
278
if st .session_state .code_input and st .session_state .code_output :
@@ -275,10 +287,10 @@ def main():
275
287
else :
276
288
logger .info (f"Stdout: { st .session_state .code_output } " )
277
289
278
-
290
+ # Buttons for generating, saving, running and debugging the code
279
291
with st .form ('code_controls_form' ):
280
292
# Create columns for alignment
281
- file_name_col , save_code_col ,generate_code_col ,run_code_col ,debug_code_col ,convert_code_col = st .columns (6 )
293
+ file_name_col , save_code_col ,generate_code_col ,run_code_col ,debug_code_col ,convert_code_col , example_code_col = st .columns (7 )
282
294
283
295
# Input Box (for entering the file name) in the first column
284
296
with file_name_col :
@@ -289,7 +301,7 @@ def main():
289
301
download_code_submitted = st .form_submit_button ("Download" )
290
302
if download_code_submitted :
291
303
file_format = "text/plain"
292
- st .session_state .download_link = general_utils .generate_download_link (st .session_state .generated_code , code_file ,file_format ,True )
304
+ st .session_state .download_link = st . session_state . general_utils .generate_download_link (st .session_state .generated_code , code_file ,file_format ,True )
293
305
294
306
# Generate Code button in the third column
295
307
with generate_code_col :
@@ -363,8 +375,12 @@ def main():
363
375
ai_llm_selected = st .session_state .openai_langchain
364
376
365
377
if not st .session_state .code_fix_instructions :
366
- st .toast ("Missing fix instructions" , icon = "❌" )
367
- logger .warning ("Missing fix instructions" )
378
+ st .toast ("Missing Debug instructions" , icon = "❌" )
379
+ logger .warning ("Missing Debug instructions" )
380
+
381
+ if not st .session_state .stderr and st .session_state .code_fix_instructions :
382
+ st .session_state .stderr = st .session_state .code_fix_instructions
383
+ logger .info ("Setting Stderr from input to Debug instructions." )
368
384
369
385
logger .info (f"Fixing code with instructions: { st .session_state .code_fix_instructions } " )
370
386
st .session_state .generated_code = ai_llm_selected .fix_generated_code (st .session_state .generated_code , st .session_state .code_language ,st .session_state .code_fix_instructions )
@@ -394,11 +410,23 @@ def main():
394
410
privacy_accepted = st .session_state .get (f'compiler_{ st .session_state .compiler_mode .lower ()} _privacy_accepted' , False )
395
411
396
412
if privacy_accepted :
397
- st .session_state .output = general_utils .execute_code (st .session_state .compiler_mode )
413
+ st .session_state .output = st . session_state . general_utils .execute_code (st .session_state .compiler_mode )
398
414
else :
399
415
st .toast (f"You didn't accept the privacy policy for { st .session_state .compiler_mode } compiler." , icon = "❌" )
400
416
logger .error (f"You didn't accept the privacy policy for { st .session_state .compiler_mode } compiler." )
401
417
418
+ # Example Code button in the fifth column
419
+ with example_code_col :
420
+ example_submitted = st .form_submit_button ("Example" )
421
+ if example_submitted :
422
+ task_name , task_input , task_output = st .session_state .tasks_parser .get_random_task ()
423
+ st .session_state .code_prompt = "Task = '" + str (task_name ) + "'\n Input = '" + str (task_input ) + "'\n Output = '" + str (task_output ) + "'"
424
+ st .session_state .code_input = task_input
425
+ st .session_state .code_output = task_output
426
+ logger .info (f"Example code loaded successfully. Task name: { task_name } , Task input: { task_input } , Task output: { task_output } " )
427
+ st .toast (f"Example code loaded successfully. Task name: { task_name } , Task input: { task_input } , Task output: { task_output } " , icon = "✅" )
428
+ st .rerun ()
429
+
402
430
# Show the privacy policy for compilers.
403
431
handle_privacy_policy (st .session_state .compiler_mode )
404
432
@@ -435,44 +463,49 @@ def main():
435
463
# Display the code output
436
464
if st .session_state .output :
437
465
st .markdown ("### Output" )
438
- st .code (st .session_state .output , language = st .session_state .code_language .lower ())
439
-
466
+ #st.toast(f"Compiler mode selected '{st.session_state.compiler_mode}'", icon="✅")
467
+ if (st .session_state .compiler_mode .lower () in ["offline" , "api" ]):
468
+ if "https://www.jdoodle.com/plugin" in st .session_state .output :
469
+ pass
470
+ else :
471
+ st .code (st .session_state .output , language = st .session_state .code_language .lower ())
472
+
440
473
# Display the price of the generated code.
441
474
if st .session_state .generated_code and st .session_state .display_cost :
442
475
if st .session_state .ai_option == "Open AI" :
443
476
selected_model = st .session_state ["openai" ]["model_name" ]
444
477
if selected_model == "gpt-3" :
445
- cost , cost_per_whole_string , total_cost = general_utils .gpt_3_generation_cost (st .session_state .generated_code )
478
+ cost , cost_per_whole_string , total_cost = st . session_state . general_utils .gpt_3_generation_cost (st .session_state .generated_code )
446
479
st .table ([["Cost/1K Token" , f"{ cost } USD" ], ["Cost/Whole String" , f"{ cost_per_whole_string } USD" ], ["Total Cost" , f"{ total_cost } USD" ]])
447
480
elif selected_model == "gpt-4" :
448
- cost , cost_per_whole_string , total_cost = general_utils .gpt_4_generation_cost (st .session_state .generated_code )
481
+ cost , cost_per_whole_string , total_cost = st . session_state . general_utils .gpt_4_generation_cost (st .session_state .generated_code )
449
482
st .table ([["Cost/1K Token" , f"{ cost } USD" ], ["Cost/Whole String" , f"{ cost_per_whole_string } USD" ], ["Total Cost" , f"{ total_cost } USD" ]])
450
483
elif selected_model == "text-davinci-003" :
451
- cost , cost_per_whole_string , total_cost = general_utils .gpt_text_davinci_generation_cost (st .session_state .generated_code )
484
+ cost , cost_per_whole_string , total_cost = st . session_state . general_utils .gpt_text_davinci_generation_cost (st .session_state .generated_code )
452
485
st .table ([["Cost/1K Token" , f"{ cost } USD" ], ["Cost/Whole String" , f"{ cost_per_whole_string } USD" ], ["Total Cost" , f"{ total_cost } USD" ]])
453
486
454
487
elif st .session_state .ai_option == "Vertex AI" :
455
488
selected_model = st .session_state ["vertexai" ]["model_name" ]
456
489
if selected_model == "code-bison" or selected_model == "code-gecko" :
457
- cost , cost_per_whole_string , total_cost = general_utils .codey_generation_cost (st .session_state .generated_code )
490
+ cost , cost_per_whole_string , total_cost = st . session_state . general_utils .codey_generation_cost (st .session_state .generated_code )
458
491
st .table ([["Cost/1K Token" , f"{ cost } USD" ], ["Cost/Whole String" , f"{ cost_per_whole_string } USD" ], ["Total Cost" , f"{ total_cost } USD" ]])
459
492
460
493
elif st .session_state .ai_option == "Palm AI" :
461
494
selected_model = st .session_state ["palm" ]["model_name" ]
462
495
if selected_model == "text-bison-001" :
463
496
cost = 0.00025 # Cost per 1K input characters for online requests
464
497
cost_per_whole_string = 0.0005 # Cost per 1K output characters for online requests
465
- total_cost = general_utils .palm_text_bison_generation_cost (st .session_state .generated_code )
498
+ total_cost = st . session_state . general_utils .palm_text_bison_generation_cost (st .session_state .generated_code )
466
499
st .table ([["Cost/1K Token" , f"{ cost } USD" ], ["Cost/Whole String" , f"{ cost_per_whole_string } USD" ], ["Total Cost" , f"{ total_cost } USD" ]])
467
500
elif selected_model == "chat-bison-001" :
468
501
cost = 0.00025 # Cost per 1K input characters for online requests
469
502
cost_per_whole_string = 0.0005 # Cost per 1K output characters for online requests
470
- total_cost = general_utils .palm_chat_bison_generation_cost (st .session_state .generated_code )
503
+ total_cost = st . session_state . general_utils .palm_chat_bison_generation_cost (st .session_state .generated_code )
471
504
st .table ([["Cost/1K Token" , f"{ cost } USD" ], ["Cost/Whole String" , f"{ cost_per_whole_string } USD" ], ["Total Cost" , f"{ total_cost } USD" ]])
472
505
elif selected_model == "embedding-gecko-001" :
473
506
cost = 0.0002 # Cost per 1K characters input for generating embeddings using text as an input
474
507
cost_per_whole_string = 0.0002 # Assuming the same cost for output characters
475
- total_cost = general_utils .palm_embedding_gecko_generation_cost (st .session_state .generated_code )
508
+ total_cost = st . session_state . general_utils .palm_embedding_gecko_generation_cost (st .session_state .generated_code )
476
509
st .table ([["Cost/1K Token" , f"{ cost } USD" ], ["Cost/Whole String" , f"{ cost_per_whole_string } USD" ], ["Total Cost" , f"{ total_cost } USD" ]])
477
510
478
511
elif st .session_state .ai_option == "Gemini AI" :
@@ -481,15 +514,15 @@ def main():
481
514
if selected_model == "gemini-pro" :
482
515
cost_per_input_char = 0.00025 # Cost per 1K input characters for online requests
483
516
cost_per_output_char = 0.0005 # Cost per 1K output characters for online requests
484
- total_cost = general_utils .gemini_pro_generation_cost (st .session_state .generated_code )
517
+ total_cost = st . session_state . general_utils .gemini_pro_generation_cost (st .session_state .generated_code )
485
518
st .table ([["Cost/1K Input Token" , f"{ cost_per_input_char } USD" ], ["Cost/1K Output Token" , f"{ cost_per_output_char } USD" ], ["Total Cost" , f"{ total_cost } USD" ]])
486
519
487
520
elif selected_model == "gemini-pro-vision" :
488
521
cost_per_image = 0.0025 # Cost per image for online requests
489
522
cost_per_second = 0.002 # Cost per second for online requests
490
523
cost_per_input_char = 0.00025 # Cost per 1K input characters for online requests
491
524
cost_per_output_char = 0.0005 # Cost per 1K output characters for online requests
492
- total_cost = general_utils .gemini_pro_vision_generation_cost (st .session_state .generated_code )
525
+ total_cost = st . session_state . general_utils .gemini_pro_vision_generation_cost (st .session_state .generated_code )
493
526
st .table ([["Cost/Image" , f"{ cost_per_image } USD" ], ["Cost/Second" , f"{ cost_per_second } USD" ], ["Cost/1K Input Token" , f"{ cost_per_input_char } USD" ], ["Cost/1K Output Token" , f"{ cost_per_output_char } USD" ], ["Total Cost" , f"{ total_cost } USD" ]])
494
527
495
528
# Expander for coding guidelines
@@ -510,7 +543,7 @@ def main():
510
543
"Robust Code" ,
511
544
"Memory efficiency" ,
512
545
"Speed efficiency" ,
513
- "Standard Naming conventions "
546
+ "Standard Naming"
514
547
]
515
548
516
549
for guideline in guidelines :
0 commit comments