Skip to content

Commit e96dba2

Browse files
Merge branch 'streamlit' into master
2 parents 86b65fa + 379624b commit e96dba2

File tree

1 file changed

+349
-0
lines changed

1 file changed

+349
-0
lines changed

script.py

Lines changed: 349 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,349 @@
1+
"""
2+
Description: This is LangChain Coder a Streamlit app that uses LangChain to generate code and fix code using OpenAI's GPT-3.
3+
This can generate code in Python, C, C++ and Javascript.
4+
And can run and save the code generated locally.
5+
This is alternative to the OpenAI Code Interpreter Plugin.
6+
Today Date : 30-April-2023
7+
Author: HeavenHM
8+
"""
9+
10+
# Importing the libraries
11+
import tempfile
12+
import subprocess
13+
import traceback
14+
import sys
15+
import os
16+
from io import StringIO
17+
import streamlit as st
18+
from streamlit.components.v1 import html
19+
from langchain.llms import OpenAI
20+
from langchain.prompts import PromptTemplate
21+
from langchain.chains import LLMChain, SequentialChain
22+
from langchain.memory import ConversationBufferMemory
23+
import langchain.agents as lc_agents
24+
from langchain.llms import OpenAI
25+
import logging
26+
from langchain.llms import OpenAI as LangChainOpenAI
27+
import openai
28+
from dotenv import load_dotenv
29+
30+
# Load the environment variables
31+
#load_dotenv()
32+
#openai.api_key = os.getenv("OPENAI_API_KEY")
33+
34+
global generated_code
35+
global code_chain,sequential_chain
36+
37+
LANGUAGE_CODES = {
38+
'C': 'c',
39+
'C++': 'cpp',
40+
'Java': 'java',
41+
'Ruby': 'ruby',
42+
'Scala': 'scala',
43+
'C#': 'csharp',
44+
'Objective C': 'objc',
45+
'Swift': 'swift',
46+
'JavaScript': 'nodejs',
47+
'Kotlin': 'kotlin',
48+
'Python': 'python3',
49+
'GO Lang': 'go',
50+
}
51+
52+
# App title and description
53+
st.title("LangChain Coder - AI 🦜🔗")
54+
code_prompt = st.text_input("Enter a prompt to generate the code")
55+
code_language = st.selectbox("Select a language", list(LANGUAGE_CODES.keys()))
56+
57+
# Generate and Run Buttons
58+
button_generate = st.button("Generate Code")
59+
code_file = st.text_input("Enter file name:")
60+
button_save = st.button("Save Code")
61+
62+
compiler_mode = st.radio("Compiler Mode", ("Online", "Offline"))
63+
button_run = st.button("Run Code")
64+
65+
# Prompt Templates
66+
code_template = PromptTemplate(
67+
input_variables=['code_topic'],
68+
template='Write me code in ' +
69+
f'{code_language} language' + ' for {code_topic}'
70+
)
71+
72+
code_fix_template = PromptTemplate(
73+
input_variables=['code_topic'],
74+
template='Fix any error in the following code in ' +
75+
f'{code_language} language' + ' for {code_topic}'
76+
)
77+
78+
# Memory for the conversation
79+
memory = ConversationBufferMemory(
80+
input_key='code_topic', memory_key='chat_history')
81+
82+
# LLM Chains definition
83+
# Create an OpenAI LLM model
84+
def setup_llm_chain():
85+
global code_chain,sequential_chain
86+
open_ai_llm = OpenAI(temperature=0.7, max_tokens=1000)
87+
88+
# Create a chain that generates the code
89+
code_chain = LLMChain(llm=open_ai_llm, prompt=code_template,
90+
output_key='code', memory=memory, verbose=True)
91+
92+
# Create a chain that fixes the code
93+
code_fix_chain = LLMChain(llm=open_ai_llm, prompt=code_fix_template,
94+
output_key='code_fix', memory=memory, verbose=True)
95+
96+
# Create a sequential chain that combines the two chains above
97+
sequential_chain = SequentialChain(chains=[code_chain, code_fix_chain], input_variables=[
98+
'code_topic'], output_variables=['code', 'code_fix'])
99+
100+
101+
# Generate Dynamic HTML for JDoodle Compiler iFrame Embedding.
102+
def generate_dynamic_html(language, code_prompt):
103+
logger = logging.getLogger(__name__)
104+
logger.info("Generating dynamic HTML for language: %s", language)
105+
html_template = """
106+
<!DOCTYPE html>
107+
<html lang="en">
108+
<head>
109+
<meta charset="UTF-8">
110+
<title>Python App with JavaScript</title>
111+
</head>
112+
<body>
113+
<div data-pym-src='https://www.jdoodle.com/plugin' data-language="{language}"
114+
data-version-index="0" data-libs="">
115+
{script_code}
116+
</div>
117+
<script src="https://www.jdoodle.com/assets/jdoodle-pym.min.js" type="text/javascript"></script>
118+
</body>
119+
</html>
120+
""".format(language=LANGUAGE_CODES[language], script_code=code_prompt)
121+
return html_template
122+
123+
# Setup logging method.
124+
125+
126+
def setup_logging(log_file):
127+
logging.basicConfig(
128+
level=logging.INFO,
129+
format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
130+
datefmt="%H:%M:%S",
131+
filename=log_file, # Add this line to save logs to a file
132+
filemode='a', # Append logs to the file
133+
)
134+
135+
136+
# Setup logging
137+
log_file = __file__.replace(".py", ".log")
138+
setup_logging(log_file)
139+
140+
141+
# Create a class
142+
class PythonREPL:
143+
# Define the initialization method
144+
def __init__(self):
145+
pass
146+
147+
# Define the run method
148+
def run(self, command: str) -> str:
149+
# Store the current value of sys.stdout
150+
old_stdout = sys.stdout
151+
# Create a new StringIO object
152+
sys.stdout = mystdout = StringIO()
153+
# Try to execute the code
154+
try:
155+
# Execute the code
156+
exec(command, globals())
157+
sys.stdout = old_stdout
158+
output = mystdout.getvalue()
159+
# If an error occurs, print the error message
160+
except Exception as e:
161+
# Restore the original value of sys.stdout
162+
sys.stdout = old_stdout
163+
# Get the error message
164+
output = str(e)
165+
return output
166+
167+
# Define the Run query function
168+
169+
170+
def run_query(query, model_kwargs, max_iterations):
171+
# Create a LangChainOpenAI object
172+
llm = LangChainOpenAI(**model_kwargs)
173+
# Create the python REPL tool
174+
python_repl = lc_agents.Tool("Python REPL", PythonREPL(
175+
).run, "A Python shell. Use this to execute python commands.")
176+
# Create a list of tools
177+
tools = [python_repl]
178+
# Initialize the agent
179+
agent = lc_agents.initialize_agent(tools, llm, agent=lc_agents.AgentType.ZERO_SHOT_REACT_DESCRIPTION,
180+
model_kwargs=model_kwargs, verbose=True, max_iterations=max_iterations)
181+
# Run the agent
182+
response = agent.run(query)
183+
return response
184+
185+
# Define the Run code function
186+
187+
188+
def run_code(code, language):
189+
logger = logging.getLogger(__name__)
190+
logger.info(f"Running code: {code} in language: {language}")
191+
192+
if language == "Python":
193+
with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=True) as f:
194+
f.write(code)
195+
f.flush()
196+
197+
logger.info(f"Input file: {f.name}")
198+
output = subprocess.run(
199+
["python", f.name], capture_output=True, text=True)
200+
logger.info(f"Runner Output execution: {output.stdout + output.stderr}")
201+
return output.stdout + output.stderr
202+
203+
elif language == "C" or language == "C++":
204+
ext = ".c" if language == "C" else ".cpp"
205+
with tempfile.NamedTemporaryFile(mode="w", suffix=ext, delete=True) as src_file:
206+
src_file.write(code)
207+
src_file.flush()
208+
209+
logger.info(f"Input file: {src_file.name}")
210+
211+
with tempfile.NamedTemporaryFile(mode="w", suffix="", delete=True) as exec_file:
212+
compile_output = subprocess.run(
213+
["gcc" if language == "C" else "g++", "-o", exec_file.name, src_file.name], capture_output=True, text=True)
214+
215+
if compile_output.returncode != 0:
216+
return compile_output.stderr
217+
218+
logger.info(f"Output file: {exec_file.name}")
219+
run_output = subprocess.run(
220+
[exec_file.name], capture_output=True, text=True)
221+
logger.info(f"Runner Output execution: {run_output.stdout + run_output.stderr}")
222+
return run_output.stdout + run_output.stderr
223+
224+
elif language == "JavaScript":
225+
with tempfile.NamedTemporaryFile(mode="w", suffix=".js", delete=True) as f:
226+
f.write(code)
227+
f.flush()
228+
229+
logger.info(f"Input file: {f.name}")
230+
output = subprocess.run(
231+
["node", f.name], capture_output=True, text=True)
232+
logger.info(f"Runner Output execution: {output.stdout + output.stderr}")
233+
return output.stdout + output.stderr
234+
235+
else:
236+
return "Unsupported language."
237+
238+
# Generate the code
239+
240+
241+
def generate_code():
242+
logger = logging.getLogger(__name__)
243+
try:
244+
st.session_state.generated_code = code_chain.run(code_prompt)
245+
st.session_state.code_language = code_language
246+
st.code(st.session_state.generated_code,
247+
language=st.session_state.code_language.lower())
248+
249+
with st.expander('Message History'):
250+
st.info(memory.buffer)
251+
except Exception as e:
252+
st.write(traceback.format_exc())
253+
logger.error(f"Error in code generation: {traceback.format_exc()}")
254+
255+
# Save the code to a file
256+
257+
258+
def save_code():
259+
logger = logging.getLogger(__name__)
260+
try:
261+
file_name = code_file
262+
logger.info(f"Saving code to file: {file_name}")
263+
if file_name:
264+
with open(file_name, "w") as f:
265+
f.write(st.session_state.generated_code)
266+
st.success(f"Code saved to file {file_name}")
267+
logger.info(f"Code saved to file {file_name}")
268+
st.code(st.session_state.generated_code,
269+
language=st.session_state.code_language.lower())
270+
271+
except Exception as e:
272+
st.write(traceback.format_exc())
273+
logger.error(f"Error in code saving: {traceback.format_exc()}")
274+
275+
# Execute the code
276+
277+
278+
def execute_code(compiler_mode: str):
279+
logger = logging.getLogger(__name__)
280+
logger.info(f"Executing code: {st.session_state.generated_code} in language: {st.session_state.code_language} with Compiler Mode: {compiler_mode}")
281+
282+
try:
283+
if compiler_mode == "online":
284+
html_template = generate_dynamic_html(
285+
st.session_state.code_language, st.session_state.generated_code)
286+
html(html_template, width=720, height=800, scrolling=True)
287+
288+
else:
289+
output = run_code(st.session_state.generated_code,
290+
st.session_state.code_language)
291+
logger.info(f"Output execution: {output}")
292+
293+
if "error" in output.lower() or "exception" in output.lower() or "SyntaxError" in output.lower() or "NameError" in output.lower():
294+
295+
logger.error(f"Error in code execution: {output}")
296+
response = sequential_chain({'code_topic': st.session_state.generated_code})
297+
fixed_code = response['code_fix']
298+
st.code(fixed_code, language=st.session_state.code_language.lower())
299+
300+
with st.expander('Message History'):
301+
st.info(memory.buffer)
302+
logger.warning(f"Trying to run fixed code: {fixed_code}")
303+
output = run_code(fixed_code, st.session_state.code_language)
304+
logger.warning(f"Fixed code output: {output}")
305+
306+
st.code(st.session_state.generated_code,
307+
language=st.session_state.code_language.lower())
308+
st.write("Execution Output:")
309+
st.write(output)
310+
logger.info(f"Execution Output: {output}")
311+
312+
except Exception as e:
313+
st.write("Error in code execution:")
314+
# Output the stack trace
315+
st.write(traceback.format_exc())
316+
logger.error(f"Error in code execution: {traceback.format_exc()}")
317+
318+
319+
# Main method
320+
if __name__ == "__main__":
321+
322+
# Create Set API Key in settings.
323+
if st.expander("Settings"):
324+
api_key = st.text_input("OpenAI API Key", type="password")
325+
if api_key:
326+
openai.api_key = api_key
327+
os.environ["OPENAI_API_KEY"] = api_key
328+
st.success("API Key set successfully.")
329+
setup_llm_chain()
330+
331+
# Session state variables
332+
if "generated_code" not in st.session_state:
333+
st.session_state.generated_code = ""
334+
335+
if "code_language" not in st.session_state:
336+
st.session_state.code_language = ""
337+
338+
# Generate the code
339+
if button_generate and code_prompt:
340+
generate_code()
341+
342+
# Save the code to a file
343+
if button_save and st.session_state.generated_code:
344+
save_code()
345+
346+
# Execute the code
347+
if button_run and code_prompt:
348+
code_state_option = "online" if compiler_mode == "Online" else "offline"
349+
execute_code(code_state_option)

0 commit comments

Comments
 (0)