168 changes: 114 additions & 54 deletions nodes.py
@@ -5,7 +5,7 @@
from utils.crawl_github_files import crawl_github_files
from utils.call_llm import call_llm
from utils.crawl_local_files import crawl_local_files

from utils.tools import batch_chunks, length_of_tokens, MAX_TOKENS, SPLIT_TOKENS

# Helper to get content for specific file indices
def get_content_for_indices(files_data, indices):
@@ -78,7 +78,9 @@ def exec(self, prep_res):
        return files_list

    def post(self, shared, prep_res, exec_res):
        shared["files"] = exec_res  # List of (path, content) tuples, e.g. ('browserbase\\cli.js', "#!/usr/bin/env node\nimport './dist/program.js';")




class IdentifyAbstractions(Node):
@@ -92,22 +94,32 @@ def prep(self, shared):
        # Helper to create context from files, respecting limits (basic example)
        def create_llm_context(files_data):
            context = ""
            context_list = []
            file_total_info = []
            file_info = []  # Store tuples of (index, path)
            tokens_num = 0
            for i, (path, content) in enumerate(files_data):
                entry = f"--- File Index {i}: {path} ---\n{content}\n\n"
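                # Chunk by tokens: flush the current context and start a new one
                # once the running total would exceed SPLIT_TOKENS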
                tokens_num += length_of_tokens(entry)
                if tokens_num < SPLIT_TOKENS:
                    context += entry
                    file_info.append((i, path))
                else:
                    file_total_info.append(file_info)
                    context_list.append(context)
                    file_info = [(i, path)]
                    tokens_num = length_of_tokens(entry)
                    context = entry
            if context:
                context_list.append(context)
                file_total_info.append(file_info)

            return context_list, file_total_info  # Parallel lists: one context chunk plus its (index, path) pairs per batch

        context_list, file_info = create_llm_context(files_data)
        return (
            context_list,
            file_info,
            len(files_data),
            project_name,
            language,
@@ -117,27 +129,29 @@ def create_llm_context(files_data):

    def exec(self, prep_res):
        (
            context_list,
            file_info,
            file_count,
            project_name,
            language,
            use_cache,
            max_abstraction_num,
        ) = prep_res
        print(f"Identifying abstractions using LLM...")

        # Add language instruction and hints only if not English
        language_instruction = ""
        name_lang_hint = ""
        desc_lang_hint = ""
        if language.lower() != "english":
            language_instruction = f"IMPORTANT: Generate the `name` and `description` for each abstraction in **{language.capitalize()}** language. Do NOT use English for these fields.\n\n"
            name_lang_hint = f" (value in {language.capitalize()})"
            desc_lang_hint = f" (value in {language.capitalize()})"

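        # One LLM call per context chunk; the fenced YAML from each response is
        # concatenated and merged below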
        response_result = ""
        for i, context in enumerate(context_list):
            file_listing_for_prompt = "\n".join([f"- {idx} # {path}" for idx, path in file_info[i]])
            prompt = f"""
For the project `{project_name}`:

Codebase Context:
@@ -173,12 +187,22 @@ def exec(self, prep_res):
- 5 # path/to/another.js
# ... up to {max_abstraction_num} abstractions
```"""
            response = call_llm(prompt, use_cache=(use_cache and self.cur_retry == 0))  # Use cache only if enabled and not retrying
            response_result += response + "\n"

        # --- Validation ---
        # Keep only the fenced YAML from each concatenated response; text outside
        # the ```yaml fences would otherwise corrupt the parse
        yaml_str = ""
        for part in response_result.strip().split("```yaml")[1:]:
            fenced = part.split("```")[0].strip()
            if fenced:
                yaml_str += fenced + "\n"

        abstractions = yaml.safe_load(yaml_str)
        if isinstance(abstractions, dict):
            # If the model wrapped the list in a mapping, flatten the values back into one list
            res = []
            for value in abstractions.values():
                res.extend(value)
            abstractions = res
        if not isinstance(abstractions, list):
            raise ValueError("LLM Output is not a list")

@@ -251,35 +275,64 @@ def prep(self, shared):
        num_abstractions = len(abstractions)

        # Create context with abstraction names, indices, descriptions, and relevant file snippets
        context_list = []
        context_header = "Identified Abstractions:\n"
        context = context_header
        token_nums = 0
        all_relevant_indices_list = []

        all_relevant_indices = set()
        abstraction_info_for_prompt_list = []
        abstraction_info_for_prompt = []
        for i, abstr in enumerate(abstractions):
            # Use 'files' which contains indices directly
            file_indices_str = ", ".join(map(str, abstr["files"]))
            # Abstraction name and description might be translated already
            info_line = f"- Index {i}: {abstr['name']} (Relevant file indices: [{file_indices_str}])\n  Description: {abstr['description']}"
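            # Chunk the abstraction listing so each prompt stays well under the split budget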
            token_nums += length_of_tokens(info_line)
            if token_nums < SPLIT_TOKENS * 0.5:
                context += info_line + "\n"
                abstraction_info_for_prompt.append(
                    f"{i} # {abstr['name']}"
                )  # Use potentially translated name here too
                all_relevant_indices.update(abstr["files"])
            else:
                context_list.append(context)
                abstraction_info_for_prompt_list.append(abstraction_info_for_prompt)
                all_relevant_indices_list.append(all_relevant_indices)
                context = context_header + info_line + "\n"
                abstraction_info_for_prompt = [f"{i} # {abstr['name']}"]
                all_relevant_indices = set(abstr["files"])  # set([abstr["files"]]) would raise TypeError on a list
                token_nums = length_of_tokens(info_line)  # restart the running count for the new chunk
        if abstraction_info_for_prompt:
            context_list.append(context)
            abstraction_info_for_prompt_list.append(abstraction_info_for_prompt)
            all_relevant_indices_list.append(all_relevant_indices)


        for context_index, context in enumerate(context_list):
            context += "\nRelevant File Snippets (Referenced by Index and Path):\n"
            relevant_files_content_map = get_content_for_indices(
                files_data, sorted(list(all_relevant_indices_list[context_index]))  # was indexed with context_list, a TypeError
            )
            file_context_str = ""
            token_nums = 0

            # Add snippets while they fit the snippet budget; drop whatever remains
            for idx_path, content in relevant_files_content_map.items():
                entry = f"--- File: {idx_path} ---\n{content}\n\n"
                token_nums += length_of_tokens(entry)
                if token_nums < SPLIT_TOKENS * 0.3:
                    file_context_str += entry
                else:
                    break
            if file_context_str:
                context += file_context_str
            context_list[context_index] = context

        return (
            context_list,
            abstraction_info_for_prompt_list,
            num_abstractions,  # Pass the actual count
            project_name,
            language,
@@ -288,8 +341,8 @@ def prep(self, shared):

    def exec(self, prep_res):
        (
            context_list,
            abstraction_info_listing,
            num_abstractions,  # Receive the actual count
            project_name,
            language,
@@ -305,8 +358,10 @@ def exec(self, prep_res):
            language_instruction = f"IMPORTANT: Generate the `summary` and relationship `label` fields in **{language.capitalize()}** language. Do NOT use English for these fields.\n\n"
            lang_hint = f" (in {language.capitalize()})"
            list_lang_note = f" (Names might be in {language.capitalize()})"  # Note for the input list

        response_result = ""
        for c_index, context in enumerate(context_list):
            abstraction_listing = "\n".join(abstraction_info_listing[c_index])
            prompt = f"""
Based on the following abstractions and relevant code snippets from the project `{project_name}`:

List of Abstraction Indices and Names{list_lang_note}:
@@ -344,10 +399,15 @@ def exec(self, prep_res):

Now, provide the YAML output:
"""
            response = call_llm(prompt, use_cache=(use_cache and self.cur_retry == 0))  # Use cache only if enabled and not retrying
            response_result += response + "\n"

        # Parse each fenced YAML block separately; simply concatenating the mapping
        # documents would leave duplicate keys, and yaml.safe_load would silently
        # keep only the last chunk's summary and relationships
        summary_parts = []
        merged_relationships = []
        for part in response_result.strip().split("```yaml")[1:]:
            block = part.split("```")[0].strip()
            if not block:
                continue
            data = yaml.safe_load(block)
            if isinstance(data, dict):
                summary_parts.append(str(data.get("summary", "")).strip())
                merged_relationships.extend(data.get("relationships", []))
        relationships_data = {"summary": "\n".join(summary_parts), "relationships": merged_relationships}

        if not isinstance(relationships_data, dict) or not all(
@@ -395,7 +455,7 @@ def exec(self, prep_res):
            except (ValueError, TypeError):
                raise ValueError(f"Could not parse indices from relationship: {rel}")

print("Generated project summary and relationship details.")
print(f"Generated project summary and relationship details {len(validated_relationships)}.")
        return {
            "summary": relationships_data["summary"],  # Potentially translated summary
            "details": validated_relationships,  # Store validated, index-based relationships with potentially translated labels
@@ -488,7 +548,7 @@ def exec(self, prep_res):
"""
        response = call_llm(prompt, use_cache=(use_cache and self.cur_retry == 0))  # Use cache only if enabled and not retrying


        yaml_str = response.strip().split("```yaml")[1].split("```")[0].strip()
        ordered_indices_raw = yaml.safe_load(yaml_str)

39 changes: 39 additions & 0 deletions utils/tools.py
@@ -0,0 +1,39 @@
import tiktoken
from typing import List, Dict, Iterable


MAX_TOKENS = 16000
SPLIT_TOKENS = MAX_TOKENS * 0.5

def length_of_tokens(prompt, encoding_name: str = "cl100k_base"):
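    # Token count of `prompt` under the given tiktoken encoding; disallowed_special=()
    # lets text that merely looks like a special token be encoded as ordinary text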
    enc = tiktoken.get_encoding(encoding_name)
    all_tokens = enc.encode(prompt, disallowed_special=())
    return len(all_tokens)


def batch_chunks(chunks: List[str], prompt: str = "", params_name: str = "chunk_list_content", extra_vars: Dict[str, str] | None = None) -> Iterable[List[str]]:
"""
将 chunks 切分成若干批,使得 `prompt + batch` 的 token
不超过 BATCH_TARGET。
"""
    batch: List[str] = []
    cur_tokens = 0
    extra_vars = extra_vars or {}
    # Cost of the prompt template itself, with the chunk slot left empty
    param_format = {**extra_vars, params_name: ""}
    prompt_tokens = length_of_tokens(
        prompt.format(**param_format)
    )

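    # Greedy batching: emit the current batch as soon as adding the next chunk
    # would push prompt + batch past SPLIT_TOKENS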
    for chunk in chunks:
        chunk_tokens = length_of_tokens(chunk)
        if (
            cur_tokens + chunk_tokens + prompt_tokens > SPLIT_TOKENS
            and batch
        ):
            yield batch
            batch = []
            cur_tokens = 0
        batch.append(chunk)
        cur_tokens += chunk_tokens
    if batch:
        yield batch
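
A minimal usage sketch of the two helpers (illustrative only: the template text, chunk contents, and per-batch LLM call are invented here, not part of the PR; assumes tiktoken is installed and utils/tools.py is importable):

from utils.tools import batch_chunks, length_of_tokens

# Template whose {chunk_list_content} slot is filled once per batch
template = "Summarize these files:\n{chunk_list_content}"
chunks = [f"--- File {i} ---\nprint({i})" for i in range(40)]

for batch in batch_chunks(chunks, prompt=template, params_name="chunk_list_content"):
    filled = template.format(chunk_list_content="\n".join(batch))
    # Each filled prompt stays near or below SPLIT_TOKENS (the "\n" join
    # separators are not counted by batch_chunks, so this is approximate)
    print(len(batch), length_of_tokens(filled))
    # call_llm(filled) would go here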