|
| 1 | +""" |
| 2 | +Description: This is LangChain Coder a Streamlit app that uses LangChain to generate code and fix code using OpenAI's GPT-3. |
| 3 | +This can generate code in Python, C, C++ and Javascript. |
| 4 | +And can run and save the code generated locally. |
| 5 | +This is alternative to the OpenAI Code Interpreter Plugin. |
| 6 | +Today Date : 30-April-2023 |
| 7 | +Author: HeavenHM |
| 8 | +""" |
| 9 | + |
| 10 | +# Importing the libraries |
| 11 | +import tempfile |
| 12 | +import subprocess |
| 13 | +import traceback |
| 14 | +import sys |
| 15 | +import os |
| 16 | +from io import StringIO |
| 17 | +import streamlit as st |
| 18 | +from streamlit.components.v1 import html |
| 19 | +from langchain.llms import OpenAI |
| 20 | +from langchain.prompts import PromptTemplate |
| 21 | +from langchain.chains import LLMChain, SequentialChain |
| 22 | +from langchain.memory import ConversationBufferMemory |
| 23 | +import langchain.agents as lc_agents |
| 24 | +from langchain.llms import OpenAI |
| 25 | +import logging |
| 26 | +from langchain.llms import OpenAI as LangChainOpenAI |
| 27 | +import openai |
| 28 | +from dotenv import load_dotenv |
| 29 | + |
| 30 | +# Load the environment variables |
| 31 | +#load_dotenv() |
| 32 | +#openai.api_key = os.getenv("OPENAI_API_KEY") |
| 33 | + |
| 34 | +global generated_code |
| 35 | +global code_chain,sequential_chain |
| 36 | + |
| 37 | +LANGUAGE_CODES = { |
| 38 | + 'C': 'c', |
| 39 | + 'C++': 'cpp', |
| 40 | + 'Java': 'java', |
| 41 | + 'Ruby': 'ruby', |
| 42 | + 'Scala': 'scala', |
| 43 | + 'C#': 'csharp', |
| 44 | + 'Objective C': 'objc', |
| 45 | + 'Swift': 'swift', |
| 46 | + 'JavaScript': 'nodejs', |
| 47 | + 'Kotlin': 'kotlin', |
| 48 | + 'Python': 'python3', |
| 49 | + 'GO Lang': 'go', |
| 50 | +} |
| 51 | + |
| 52 | +# App title and description |
| 53 | +st.title("LangChain Coder - AI 🦜🔗") |
| 54 | +code_prompt = st.text_input("Enter a prompt to generate the code") |
| 55 | +code_language = st.selectbox("Select a language", list(LANGUAGE_CODES.keys())) |
| 56 | + |
| 57 | +# Generate and Run Buttons |
| 58 | +button_generate = st.button("Generate Code") |
| 59 | +code_file = st.text_input("Enter file name:") |
| 60 | +button_save = st.button("Save Code") |
| 61 | + |
| 62 | +compiler_mode = st.radio("Compiler Mode", ("Online", "Offline")) |
| 63 | +button_run = st.button("Run Code") |
| 64 | + |
| 65 | +# Prompt Templates |
| 66 | +code_template = PromptTemplate( |
| 67 | + input_variables=['code_topic'], |
| 68 | + template='Write me code in ' + |
| 69 | + f'{code_language} language' + ' for {code_topic}' |
| 70 | +) |
| 71 | + |
| 72 | +code_fix_template = PromptTemplate( |
| 73 | + input_variables=['code_topic'], |
| 74 | + template='Fix any error in the following code in ' + |
| 75 | + f'{code_language} language' + ' for {code_topic}' |
| 76 | +) |
| 77 | + |
| 78 | +# Memory for the conversation |
| 79 | +memory = ConversationBufferMemory( |
| 80 | + input_key='code_topic', memory_key='chat_history') |
| 81 | + |
| 82 | +# LLM Chains definition |
| 83 | +# Create an OpenAI LLM model |
| 84 | +def setup_llm_chain(): |
| 85 | + global code_chain,sequential_chain |
| 86 | + open_ai_llm = OpenAI(temperature=0.7, max_tokens=1000) |
| 87 | + |
| 88 | + # Create a chain that generates the code |
| 89 | + code_chain = LLMChain(llm=open_ai_llm, prompt=code_template, |
| 90 | + output_key='code', memory=memory, verbose=True) |
| 91 | + |
| 92 | + # Create a chain that fixes the code |
| 93 | + code_fix_chain = LLMChain(llm=open_ai_llm, prompt=code_fix_template, |
| 94 | + output_key='code_fix', memory=memory, verbose=True) |
| 95 | + |
| 96 | + # Create a sequential chain that combines the two chains above |
| 97 | + sequential_chain = SequentialChain(chains=[code_chain, code_fix_chain], input_variables=[ |
| 98 | + 'code_topic'], output_variables=['code', 'code_fix']) |
| 99 | + |
| 100 | + |
| 101 | +# Generate Dynamic HTML for JDoodle Compiler iFrame Embedding. |
| 102 | +def generate_dynamic_html(language, code_prompt): |
| 103 | + logger = logging.getLogger(__name__) |
| 104 | + logger.info("Generating dynamic HTML for language: %s", language) |
| 105 | + html_template = """ |
| 106 | + <!DOCTYPE html> |
| 107 | + <html lang="en"> |
| 108 | + <head> |
| 109 | + <meta charset="UTF-8"> |
| 110 | + <title>Python App with JavaScript</title> |
| 111 | + </head> |
| 112 | + <body> |
| 113 | + <div data-pym-src='https://www.jdoodle.com/plugin' data-language="{language}" |
| 114 | + data-version-index="0" data-libs=""> |
| 115 | + {script_code} |
| 116 | + </div> |
| 117 | + <script src="https://www.jdoodle.com/assets/jdoodle-pym.min.js" type="text/javascript"></script> |
| 118 | + </body> |
| 119 | + </html> |
| 120 | + """.format(language=LANGUAGE_CODES[language], script_code=code_prompt) |
| 121 | + return html_template |
| 122 | + |
| 123 | +# Setup logging method. |
| 124 | + |
| 125 | + |
| 126 | +def setup_logging(log_file): |
| 127 | + logging.basicConfig( |
| 128 | + level=logging.INFO, |
| 129 | + format="%(asctime)s [%(levelname)s] %(name)s - %(message)s", |
| 130 | + datefmt="%H:%M:%S", |
| 131 | + filename=log_file, # Add this line to save logs to a file |
| 132 | + filemode='a', # Append logs to the file |
| 133 | + ) |
| 134 | + |
| 135 | + |
| 136 | +# Setup logging |
| 137 | +log_file = __file__.replace(".py", ".log") |
| 138 | +setup_logging(log_file) |
| 139 | + |
| 140 | + |
| 141 | +# Create a class |
| 142 | +class PythonREPL: |
| 143 | + # Define the initialization method |
| 144 | + def __init__(self): |
| 145 | + pass |
| 146 | + |
| 147 | + # Define the run method |
| 148 | + def run(self, command: str) -> str: |
| 149 | + # Store the current value of sys.stdout |
| 150 | + old_stdout = sys.stdout |
| 151 | + # Create a new StringIO object |
| 152 | + sys.stdout = mystdout = StringIO() |
| 153 | + # Try to execute the code |
| 154 | + try: |
| 155 | + # Execute the code |
| 156 | + exec(command, globals()) |
| 157 | + sys.stdout = old_stdout |
| 158 | + output = mystdout.getvalue() |
| 159 | + # If an error occurs, print the error message |
| 160 | + except Exception as e: |
| 161 | + # Restore the original value of sys.stdout |
| 162 | + sys.stdout = old_stdout |
| 163 | + # Get the error message |
| 164 | + output = str(e) |
| 165 | + return output |
| 166 | + |
| 167 | +# Define the Run query function |
| 168 | + |
| 169 | + |
| 170 | +def run_query(query, model_kwargs, max_iterations): |
| 171 | + # Create a LangChainOpenAI object |
| 172 | + llm = LangChainOpenAI(**model_kwargs) |
| 173 | + # Create the python REPL tool |
| 174 | + python_repl = lc_agents.Tool("Python REPL", PythonREPL( |
| 175 | + ).run, "A Python shell. Use this to execute python commands.") |
| 176 | + # Create a list of tools |
| 177 | + tools = [python_repl] |
| 178 | + # Initialize the agent |
| 179 | + agent = lc_agents.initialize_agent(tools, llm, agent=lc_agents.AgentType.ZERO_SHOT_REACT_DESCRIPTION, |
| 180 | + model_kwargs=model_kwargs, verbose=True, max_iterations=max_iterations) |
| 181 | + # Run the agent |
| 182 | + response = agent.run(query) |
| 183 | + return response |
| 184 | + |
| 185 | +# Define the Run code function |
| 186 | + |
| 187 | + |
| 188 | +def run_code(code, language): |
| 189 | + logger = logging.getLogger(__name__) |
| 190 | + logger.info(f"Running code: {code} in language: {language}") |
| 191 | + |
| 192 | + if language == "Python": |
| 193 | + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=True) as f: |
| 194 | + f.write(code) |
| 195 | + f.flush() |
| 196 | + |
| 197 | + logger.info(f"Input file: {f.name}") |
| 198 | + output = subprocess.run( |
| 199 | + ["python", f.name], capture_output=True, text=True) |
| 200 | + logger.info(f"Runner Output execution: {output.stdout + output.stderr}") |
| 201 | + return output.stdout + output.stderr |
| 202 | + |
| 203 | + elif language == "C" or language == "C++": |
| 204 | + ext = ".c" if language == "C" else ".cpp" |
| 205 | + with tempfile.NamedTemporaryFile(mode="w", suffix=ext, delete=True) as src_file: |
| 206 | + src_file.write(code) |
| 207 | + src_file.flush() |
| 208 | + |
| 209 | + logger.info(f"Input file: {src_file.name}") |
| 210 | + |
| 211 | + with tempfile.NamedTemporaryFile(mode="w", suffix="", delete=True) as exec_file: |
| 212 | + compile_output = subprocess.run( |
| 213 | + ["gcc" if language == "C" else "g++", "-o", exec_file.name, src_file.name], capture_output=True, text=True) |
| 214 | + |
| 215 | + if compile_output.returncode != 0: |
| 216 | + return compile_output.stderr |
| 217 | + |
| 218 | + logger.info(f"Output file: {exec_file.name}") |
| 219 | + run_output = subprocess.run( |
| 220 | + [exec_file.name], capture_output=True, text=True) |
| 221 | + logger.info(f"Runner Output execution: {run_output.stdout + run_output.stderr}") |
| 222 | + return run_output.stdout + run_output.stderr |
| 223 | + |
| 224 | + elif language == "JavaScript": |
| 225 | + with tempfile.NamedTemporaryFile(mode="w", suffix=".js", delete=True) as f: |
| 226 | + f.write(code) |
| 227 | + f.flush() |
| 228 | + |
| 229 | + logger.info(f"Input file: {f.name}") |
| 230 | + output = subprocess.run( |
| 231 | + ["node", f.name], capture_output=True, text=True) |
| 232 | + logger.info(f"Runner Output execution: {output.stdout + output.stderr}") |
| 233 | + return output.stdout + output.stderr |
| 234 | + |
| 235 | + else: |
| 236 | + return "Unsupported language." |
| 237 | + |
| 238 | +# Generate the code |
| 239 | + |
| 240 | + |
| 241 | +def generate_code(): |
| 242 | + logger = logging.getLogger(__name__) |
| 243 | + try: |
| 244 | + st.session_state.generated_code = code_chain.run(code_prompt) |
| 245 | + st.session_state.code_language = code_language |
| 246 | + st.code(st.session_state.generated_code, |
| 247 | + language=st.session_state.code_language.lower()) |
| 248 | + |
| 249 | + with st.expander('Message History'): |
| 250 | + st.info(memory.buffer) |
| 251 | + except Exception as e: |
| 252 | + st.write(traceback.format_exc()) |
| 253 | + logger.error(f"Error in code generation: {traceback.format_exc()}") |
| 254 | + |
| 255 | +# Save the code to a file |
| 256 | + |
| 257 | + |
| 258 | +def save_code(): |
| 259 | + logger = logging.getLogger(__name__) |
| 260 | + try: |
| 261 | + file_name = code_file |
| 262 | + logger.info(f"Saving code to file: {file_name}") |
| 263 | + if file_name: |
| 264 | + with open(file_name, "w") as f: |
| 265 | + f.write(st.session_state.generated_code) |
| 266 | + st.success(f"Code saved to file {file_name}") |
| 267 | + logger.info(f"Code saved to file {file_name}") |
| 268 | + st.code(st.session_state.generated_code, |
| 269 | + language=st.session_state.code_language.lower()) |
| 270 | + |
| 271 | + except Exception as e: |
| 272 | + st.write(traceback.format_exc()) |
| 273 | + logger.error(f"Error in code saving: {traceback.format_exc()}") |
| 274 | + |
| 275 | +# Execute the code |
| 276 | + |
| 277 | + |
| 278 | +def execute_code(compiler_mode: str): |
| 279 | + logger = logging.getLogger(__name__) |
| 280 | + logger.info(f"Executing code: {st.session_state.generated_code} in language: {st.session_state.code_language} with Compiler Mode: {compiler_mode}") |
| 281 | + |
| 282 | + try: |
| 283 | + if compiler_mode == "online": |
| 284 | + html_template = generate_dynamic_html( |
| 285 | + st.session_state.code_language, st.session_state.generated_code) |
| 286 | + html(html_template, width=720, height=800, scrolling=True) |
| 287 | + |
| 288 | + else: |
| 289 | + output = run_code(st.session_state.generated_code, |
| 290 | + st.session_state.code_language) |
| 291 | + logger.info(f"Output execution: {output}") |
| 292 | + |
| 293 | + if "error" in output.lower() or "exception" in output.lower() or "SyntaxError" in output.lower() or "NameError" in output.lower(): |
| 294 | + |
| 295 | + logger.error(f"Error in code execution: {output}") |
| 296 | + response = sequential_chain({'code_topic': st.session_state.generated_code}) |
| 297 | + fixed_code = response['code_fix'] |
| 298 | + st.code(fixed_code, language=st.session_state.code_language.lower()) |
| 299 | + |
| 300 | + with st.expander('Message History'): |
| 301 | + st.info(memory.buffer) |
| 302 | + logger.warning(f"Trying to run fixed code: {fixed_code}") |
| 303 | + output = run_code(fixed_code, st.session_state.code_language) |
| 304 | + logger.warning(f"Fixed code output: {output}") |
| 305 | + |
| 306 | + st.code(st.session_state.generated_code, |
| 307 | + language=st.session_state.code_language.lower()) |
| 308 | + st.write("Execution Output:") |
| 309 | + st.write(output) |
| 310 | + logger.info(f"Execution Output: {output}") |
| 311 | + |
| 312 | + except Exception as e: |
| 313 | + st.write("Error in code execution:") |
| 314 | + # Output the stack trace |
| 315 | + st.write(traceback.format_exc()) |
| 316 | + logger.error(f"Error in code execution: {traceback.format_exc()}") |
| 317 | + |
| 318 | + |
| 319 | +# Main method |
| 320 | +if __name__ == "__main__": |
| 321 | + |
| 322 | + # Create Set API Key in settings. |
| 323 | + if st.expander("Settings"): |
| 324 | + api_key = st.text_input("OpenAI API Key", type="password") |
| 325 | + if api_key: |
| 326 | + openai.api_key = api_key |
| 327 | + os.environ["OPENAI_API_KEY"] = api_key |
| 328 | + st.success("API Key set successfully.") |
| 329 | + setup_llm_chain() |
| 330 | + |
| 331 | + # Session state variables |
| 332 | + if "generated_code" not in st.session_state: |
| 333 | + st.session_state.generated_code = "" |
| 334 | + |
| 335 | + if "code_language" not in st.session_state: |
| 336 | + st.session_state.code_language = "" |
| 337 | + |
| 338 | + # Generate the code |
| 339 | + if button_generate and code_prompt: |
| 340 | + generate_code() |
| 341 | + |
| 342 | + # Save the code to a file |
| 343 | + if button_save and st.session_state.generated_code: |
| 344 | + save_code() |
| 345 | + |
| 346 | + # Execute the code |
| 347 | + if button_run and code_prompt: |
| 348 | + code_state_option = "online" if compiler_mode == "Online" else "offline" |
| 349 | + execute_code(code_state_option) |
0 commit comments