Skip to content

Commit 733b009

Browse files
committed
tot package
1 parent 7382f24 commit 733b009

33 files changed

+1579
-1502
lines changed

MANIFEST.in

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
include src/tot/data/24/24.csv
2+
include src/tot/data/crosswords/mini0505_0_100_5.json
3+
include src/tot/data/crosswords/mini0505.json
4+
include src/tot/data/text/data_100_random_text.txt
File renamed without changes.
File renamed without changes.

pyproject.toml

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
[build-system]
2+
requires = ["setuptools >= 61.0.0"]
3+
build-backend = "setuptools.build_meta"
4+
5+
[project]
6+
name = "tot"
7+
version = "0.1.0"
8+
description = 'Official Implementation of "Tree of Thoughts: Deliberate Problem Solving with Large Language Models"'
9+
readme = "README.md"
10+
requires-python = ">= 3.7"
11+
authors = [{ name = "Shunyu Yao", email = "shunyuyao.cs@gmail.com" }]
12+
license = { text = "MIT License" }
13+
keywords = ["tree-search", "large-language-models", "llm", "prompting", "tree-of-thoughts"]
14+
classifiers = [
15+
"License :: OSI Approved :: MIT License",
16+
"Programming Language :: Python :: 3",
17+
"Programming Language :: Python :: 3.7",
18+
"Programming Language :: Python :: 3.8",
19+
"Programming Language :: Python :: 3.9",
20+
"Programming Language :: Python :: 3.10",
21+
"Programming Language :: Python :: 3.11",
22+
'Intended Audience :: Science/Research',
23+
'Topic :: Scientific/Engineering :: Artificial Intelligence',
24+
]
25+
dynamic=["dependencies"]
26+
27+
28+
[tool.setuptools.dynamic]
29+
dependencies = {file = ["requirements.txt"]}
30+
31+
[tool.setuptools.packages.find]
32+
where = ["src"] # list of folders that contain the packages (["."] by default)
33+
34+
[project.urls]
35+
Homepage = "https://github.com/princeton-nlp/tree-of-thought-llm"

readme.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,15 @@
44
<details>
55
<summary>Note: https://github.com/kyegomez/tree-of-thoughts is NOT the correct implementation to replicate paper results. </summary>
66

7-
In fact, people have reported that his code [cannot](https://github.com/kyegomez/tree-of-thoughts/issues/52) [properly](https://github.com/kyegomez/tree-of-thoughts/issues/41) [run](https://github.com/kyegomez/tree-of-thoughts/issues/60), and is [probably automatically generated by ChatGPT](fake.png), and [kyegomez has done so for other popular ML methods](https://twitter.com/qbitium/status/1663954096741814272), while intentionally refusing to link to official implementations for his own interests (see https://github.com/kyegomez/tree-of-thoughts/issues/54, https://github.com/kyegomez/tree-of-thoughts/issues/55, https://github.com/kyegomez/tree-of-thoughts/issues/56).
7+
In fact, people have reported that his code [cannot](https://github.com/kyegomez/tree-of-thoughts/issues/52) [properly](https://github.com/kyegomez/tree-of-thoughts/issues/41) [run](https://github.com/kyegomez/tree-of-thoughts/issues/60), and is [probably automatically generated by ChatGPT](pics/fake.png), and [kyegomez has done so for other popular ML methods](https://twitter.com/qbitium/status/1663954096741814272), while intentionally refusing to link to official implementations for his own interests (see https://github.com/kyegomez/tree-of-thoughts/issues/54, https://github.com/kyegomez/tree-of-thoughts/issues/55, https://github.com/kyegomez/tree-of-thoughts/issues/56).
88
Unfortunately, Google/Github searches go to kyegomez's malicious repo by default as it has more stars. **Please DE-STAR his repo and STAR this to help other people avoid being misled, thanks!**
99
</details>
1010

1111

1212

1313

1414

15-
![teaser](teaser.png)
15+
![teaser](pics/teaser.png)
1616

1717
Official implementation for paper [Tree of Thoughts: Deliberate Problem Solving with Large Language Models](https://arxiv.org/abs/2305.10601) with code, prompts, model outputs.
1818
Also check [its tweet thread](https://twitter.com/ShunyuYao12/status/1659357547474681857) in 1min.

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@ sympy==1.12
1616
tqdm==4.65.0
1717
urllib3==2.0.2
1818
yarl==1.9.2
19+
pandas==2.0.3

run.py

Lines changed: 7 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -1,108 +1,18 @@
11
import os
22
import json
3-
import itertools
43
import argparse
5-
import numpy as np
6-
from functools import partial
7-
from models import gpt, gpt_usage
8-
from tasks import get_task
94

10-
def get_value(task, x, y, n_evaluate_sample, cache_value=True):
11-
value_prompt = task.value_prompt_wrap(x, y)
12-
if cache_value and value_prompt in task.value_cache:
13-
return task.value_cache[value_prompt]
14-
value_outputs = gpt(value_prompt, n=n_evaluate_sample, stop=None)
15-
value = task.value_outputs_unwrap(x, y, value_outputs)
16-
if cache_value:
17-
task.value_cache[value_prompt] = value
18-
return value
19-
20-
def get_values(task, x, ys, n_evaluate_sample, cache_value=True):
21-
values = []
22-
local_value_cache = {}
23-
for y in ys: # each partial output
24-
if y in local_value_cache: # avoid duplicate candidates
25-
value = 0
26-
else:
27-
value = get_value(task, x, y, n_evaluate_sample, cache_value=cache_value)
28-
local_value_cache[y] = value
29-
values.append(value)
30-
return values
31-
32-
def get_votes(task, x, ys, n_evaluate_sample):
33-
vote_prompt = task.vote_prompt_wrap(x, ys)
34-
vote_outputs = gpt(vote_prompt, n=n_evaluate_sample, stop=None)
35-
values = task.vote_outputs_unwrap(vote_outputs, len(ys))
36-
return values
37-
38-
def get_proposals(task, x, y):
39-
propose_prompt = task.propose_prompt_wrap(x, y)
40-
proposals = gpt(propose_prompt, n=1, stop=None)[0].split('\n')
41-
return [y + _ + '\n' for _ in proposals]
42-
43-
def get_samples(task, x, y, n_generate_sample, prompt_sample, stop):
44-
if prompt_sample == 'standard':
45-
prompt = task.standard_prompt_wrap(x, y)
46-
elif prompt_sample == 'cot':
47-
prompt = task.cot_prompt_wrap(x, y)
48-
else:
49-
raise ValueError(f'prompt_sample {prompt_sample} not recognized')
50-
samples = gpt(prompt, n=n_generate_sample, stop=stop)
51-
return [y + _ for _ in samples]
52-
53-
def solve(args, task, idx, to_print=True):
54-
print(gpt)
55-
x = task.get_input(idx) # input
56-
ys = [''] # current output candidates
57-
infos = []
58-
for step in range(task.steps):
59-
# generation
60-
if args.method_generate == 'sample':
61-
new_ys = [get_samples(task, x, y, args.n_generate_sample, prompt_sample=args.prompt_sample, stop=task.stops[step]) for y in ys]
62-
elif args.method_generate == 'propose':
63-
new_ys = [get_proposals(task, x, y) for y in ys]
64-
new_ys = list(itertools.chain(*new_ys))
65-
ids = list(range(len(new_ys)))
66-
# evaluation
67-
if args.method_evaluate == 'vote':
68-
values = get_votes(task, x, new_ys, args.n_evaluate_sample)
69-
elif args.method_evaluate == 'value':
70-
values = get_values(task, x, new_ys, args.n_evaluate_sample)
71-
72-
# selection
73-
if args.method_select == 'sample':
74-
ps = np.array(values) / sum(values)
75-
select_ids = np.random.choice(ids, size=args.n_select_sample, p=ps).tolist()
76-
elif args.method_select == 'greedy':
77-
select_ids = sorted(ids, key=lambda x: values[x], reverse=True)[:args.n_select_sample]
78-
select_new_ys = [new_ys[select_id] for select_id in select_ids]
79-
80-
# log
81-
if to_print:
82-
sorted_new_ys, sorted_values = zip(*sorted(zip(new_ys, values), key=lambda x: x[1], reverse=True))
83-
print(f'-- new_ys --: {sorted_new_ys}\n-- sol values --: {sorted_values}\n-- choices --: {select_new_ys}\n')
84-
85-
infos.append({'step': step, 'x': x, 'ys': ys, 'new_ys': new_ys, 'values': values, 'select_new_ys': select_new_ys})
86-
ys = select_new_ys
87-
88-
if to_print:
89-
print(ys)
90-
return ys, {'steps': infos}
91-
92-
def naive_solve(args, task, idx, to_print=True):
93-
x = task.get_input(idx) # input
94-
ys = get_samples(task, x, '', args.n_generate_sample, args.prompt_sample, stop=None)
95-
return ys, {}
5+
from tot.tasks import get_task
6+
from tot.methods.bfs import solve, naive_solve
7+
from tot.models import gpt_usage
968

979
def run(args):
98-
task = get_task(args.task, args.task_file_path)
10+
task = get_task(args.task)
9911
logs, cnt_avg, cnt_any = [], 0, 0
100-
global gpt
101-
gpt = partial(gpt, model=args.backend, temperature=args.temperature)
10212
if args.naive_run:
103-
file = f'logs/{args.task}/{args.backend}_{args.temperature}_naive_{args.prompt_sample}_sample_{args.n_generate_sample}_start{args.task_start_index}_end{args.task_end_index}.json'
13+
file = f'./logs/{args.task}/{args.backend}_{args.temperature}_naive_{args.prompt_sample}_sample_{args.n_generate_sample}_start{args.task_start_index}_end{args.task_end_index}.json'
10414
else:
105-
file = f'logs/{args.task}/{args.backend}_{args.temperature}_{args.method_generate}{args.n_generate_sample}_{args.method_evaluate}{args.n_evaluate_sample}_{args.method_select}{args.n_select_sample}_start{args.task_start_index}_end{args.task_end_index}.json'
15+
file = f'./logs/{args.task}/{args.backend}_{args.temperature}_{args.method_generate}{args.n_generate_sample}_{args.method_evaluate}{args.n_evaluate_sample}_{args.method_select}{args.n_select_sample}_start{args.task_start_index}_end{args.task_end_index}.json'
10616
os.makedirs(os.path.dirname(file), exist_ok=True)
10717

10818
for i in range(args.task_start_index, args.task_end_index):
@@ -136,7 +46,6 @@ def parse_args():
13646
args.add_argument('--temperature', type=float, default=0.7)
13747

13848
args.add_argument('--task', type=str, required=True, choices=['game24', 'text', 'crosswords'])
139-
args.add_argument('--task_file_path', type=str, required=True)
14049
args.add_argument('--task_start_index', type=int, default=900)
14150
args.add_argument('--task_end_index', type=int, default=1000)
14251

@@ -145,7 +54,7 @@ def parse_args():
14554

14655
args.add_argument('--method_generate', type=str, choices=['sample', 'propose'])
14756
args.add_argument('--method_evaluate', type=str, choices=['value', 'vote'])
148-
args.add_argument('--method_select', type=str, choices=['sample', 'greedy'])
57+
args.add_argument('--method_select', type=str, choices=['sample', 'greedy'], default='greedy')
14958
args.add_argument('--n_generate_sample', type=int, default=1) # only thing needed if naive_run
15059
args.add_argument('--n_evaluate_sample', type=int, default=1)
15160
args.add_argument('--n_select_sample', type=int, default=1)

scripts/crosswords/cot_sampling.sh

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
python run.py \
22
--task crosswords \
3-
--task_file_path mini0505_0_100_5.json \
43
--task_start_index 0 \
54
--task_end_index 20 \
65
--naive_run \

scripts/crosswords/search_crosswords-dfs.ipynb

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
"metadata": {},
1515
"outputs": [],
1616
"source": [
17-
"cd ../.."
17+
"cd .."
1818
]
1919
},
2020
{
@@ -24,9 +24,9 @@
2424
"outputs": [],
2525
"source": [
2626
"import json\n",
27-
"from prompts.crosswords import propose_prompt, value_prompt\n",
28-
"from models import gpt\n",
29-
"from tasks.crosswords import MiniCrosswordsEnv\n",
27+
"from tot.prompts.crosswords import propose_prompt, value_prompt\n",
28+
"from tot.models import gpt\n",
29+
"from tot.tasks.crosswords import MiniCrosswordsEnv\n",
3030
"\n",
3131
"env = MiniCrosswordsEnv()"
3232
]
@@ -61,7 +61,7 @@
6161
"source": [
6262
"import re\n",
6363
"import copy\n",
64-
"from models import gpt\n",
64+
"from tot.models import gpt\n",
6565
"\n",
6666
"def parse_line(input_str):\n",
6767
" # regular expression pattern to match the input string format\n",

scripts/crosswords/standard_sampling.sh

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
python run.py \
22
--task crosswords \
3-
--task_file_path mini0505_0_100_5.json \
43
--task_start_index 0 \
54
--task_end_index 20 \
65
--naive_run \

0 commit comments

Comments
 (0)