Skip to content

Commit 8cb9dd6

Browse files
author
Will Kurt
authored
Merge pull request #3 from dottxt-ai/smol-world
Smol world (ADV-163) Added some final cleanup and am merging now
2 parents bbb2f1d + e55eb9f commit 8cb9dd6

File tree

7 files changed

+358
-0
lines changed

7 files changed

+358
-0
lines changed

its-a-smol-world/README.md

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# The Bunny B1: Powered by SmolLM2
2+
3+
This is a demo to celebrate the release of [the `SmolLM2-1.7B` model](https://huggingface.co/collections/HuggingFaceTB/smollm2-6723884218bcda64b34d7db9) from Hugging Face 🤗!
4+
5+
Ever wanted a natural language interface to local apps? The Bunny B1 demonstrates how to combine the power of SmolLM2 with structured generation using [Outlines](https://github.com/dottxt-ai/outlines) to map natural language requests to application calls, even on smaller devices.
6+
7+
Here's a look at the demo in action:
8+
9+
![Bunny B1](./demo.gif)
10+
11+
## Setting up the environment
12+
13+
```bash
14+
python3 -m venv .venv
15+
source .venv/bin/activate
16+
pip install -r requirements.txt
17+
```
18+
19+
## Running the demo
20+
21+
To start the demo, run the following command:
22+
23+
```bash
24+
python3 ./src/app.py
25+
```
26+
27+
The demo provides an interface for natural language interaction with a mobile device. You can provide natural language commands and the model will choose one of the following actions:
28+
29+
- Send a text message
30+
- Order a food delivery
31+
- Order a ride
32+
- Get the weather
33+
34+
To add a new function you can edit `functions.json` and follow the pattern you'll find in the examples.
35+
36+
## Good Test Examples
37+
38+
"I'd like to order two coffees from starbucks"
39+
40+
"I need a ride to SEATAC terminal A"
41+
42+
"What's the weather in san francisco today?"
43+
44+
"Text Remi and tell him the project is looking good"
45+
46+
## Customizing
47+
48+
The `constants.py` file allows you to customize the model, device, and torch tensor type. This demo was created on a Mac so the default device is `mps`. You can swap this out for `cuda` if you'd like.

its-a-smol-world/demo.gif

5.7 MB
Loading

its-a-smol-world/requirements.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
outlines==0.1.0
2+
transformers
3+
torch
4+
accelerate>=0.26.0

its-a-smol-world/src/__init__.py

Whitespace-only changes.

its-a-smol-world/src/app.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import time
2+
import itertools
3+
import threading
4+
import sys
5+
import argparse
6+
from smol_mind import SmolMind, load_functions
7+
from constants import MODEL_NAME
8+
9+
def spinner(stop_event):
    """Animate a small console spinner until *stop_event* is set.

    Each frame is written, flushed, then backspaced over, at roughly
    10 frames per second.
    """
    frames = itertools.cycle(['-', '/', '|', '\\'])
    while not stop_event.is_set():
        sys.stdout.write(next(frames))
        sys.stdout.flush()
        sys.stdout.write('\b')
        time.sleep(0.1)
16+
17+
def main():
    """CLI entry point: parse flags, load SmolMind, and run the request loop."""
    # Command-line flags.
    arg_parser = argparse.ArgumentParser(description="SmolMind CLI")
    arg_parser.add_argument('-d', '--debug', action='store_true', help='Enable debug mode')
    arg_parser.add_argument('-i', '--instruct', action='store_true', help='Enable instruct mode (disables continue mode)')
    opts = arg_parser.parse_args()

    print("loading SmolMind...")
    fns = load_functions("./src/functions.json")
    bot = SmolMind(fns, model_name=MODEL_NAME, debug=opts.debug, instruct=opts.instruct)
    if opts.debug:
        print("Using model:", bot.model_name)
        print("Debug mode:", "Enabled" if opts.debug else "Disabled")
        print("Instruct mode:", "Enabled" if opts.instruct else "Disabled")
    print("Welcome to the Bunny B1! What do you need?")

    while True:
        request = input("> ")
        if request.lower() in ["exit", "quit"]:
            print("Goodbye!")
            break

        # Show a spinner on a daemon thread while the model generates.
        done = threading.Event()
        anim = threading.Thread(target=spinner, args=(done,))
        anim.daemon = True
        anim.start()

        result = bot.get_function_call(request)

        # Stop the spinner and erase its last frame before printing.
        done.set()
        anim.join()
        sys.stdout.write(' \b')

        print(result)

if __name__ == "__main__":
    main()
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
{
2+
"functions": [
3+
{
4+
"name": "send_text",
5+
"description": "Send a text message to a contact",
6+
"parameters": {
7+
"type": "dict",
8+
"properties": {
9+
"to": {
10+
"type": "string",
11+
"description": "The name of the contact to send the text to."
12+
},
13+
"message": {
14+
"type": "string",
15+
"description": "The message to send to the contact."
16+
}
17+
},
18+
"required": ["to", "message"]
19+
}
20+
},
21+
{
22+
"name": "order_food",
23+
"description": "Order food from a restaurant",
24+
"parameters": {
25+
"type": "dict",
26+
"properties": {
27+
"restaurant": {
28+
"type": "string",
29+
"description": "The name of the restaurant to order from."
30+
},
31+
"item": {
32+
"type": "string",
33+
"description": "The name of the item to order."
34+
},
35+
"quantity": {
36+
"type": "integer",
37+
"description": "The quantity of the item to order."
38+
}
39+
},
40+
"required": ["restaurant", "item", "quantity"]
41+
}
42+
},
43+
{
44+
"name": "order_ride",
45+
"description": "Order a ride from a ride sharing service",
46+
"parameters": {
47+
"type": "dict",
48+
"properties": {
49+
"dest": {
50+
"type": "string",
51+
"description": "The destination of the ride."
52+
}
53+
},
54+
"required": ["dest"]
55+
}
56+
},
57+
{
58+
"name": "get_weather",
59+
"description": "Get the weather for a city",
60+
"parameters": {
61+
"type": "dict",
62+
"properties": {
63+
"city": {
64+
"type": "string",
65+
"description": "The city to get the weather for."
66+
}
67+
},
68+
"required": ["city"]
69+
}
70+
}
71+
]
72+
}

its-a-smol-world/src/smol_mind.py

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
import json
2+
from textwrap import dedent
3+
import outlines
4+
from outlines.samplers import greedy
5+
from transformers import AutoTokenizer, logging
6+
import warnings
7+
from constants import MODEL_NAME, DEVICE, T_TYPE
8+
9+
10+
logging.set_verbosity_error()
11+
12+
def format_functions(functions):
    """Render a list of function specs as a human-readable prompt section.

    Each spec becomes a "name: description" line followed by one
    "- arg: description" line per declared parameter property (when present).
    Sections are separated by a blank line.
    """
    sections = []
    for spec in functions:
        lines = [f"{spec['name']}: {spec['description']}\n"]
        if 'parameters' in spec and 'properties' in spec['parameters']:
            for name, meta in spec['parameters']['properties'].items():
                detail = meta.get('description', 'No description provided')
                lines.append(f"- {name}: {detail}\n")
        sections.append("".join(lines))
    return "\n".join(sections)
22+
23+
# System prompt for function calling. NOTE: the few-shot examples must match
# the schemas in functions.json exactly, because decoding is constrained by a
# regex built from those schemas — a mismatched example misleads the model.
SYSTEM_PROMPT_FOR_CHAT_MODEL = dedent("""
You are an expert designed to call the correct function to solve a problem based on the user's request.
The functions available (with required parameters) to you are:
{functions}

You will be given a user prompt and you need to decide which function to call.
You will then need to format the function call correctly and return it in the correct format.
The format for the function call is:
[func1(params_name=params_value)]
NO other text MUST be included.

For example:
Request: I want to order a cheese pizza from Pizza Hut.
Response: [order_food(restaurant="Pizza Hut", item="cheese pizza", quantity=1)]

Request: Is it raining in NY.
Response: [get_weather(city="New York")]

Request: I need a ride to SFO.
Response: [order_ride(dest="SFO")]

Request: I want to send a text to John saying Hello.
Response: [send_text(to="John", message="Hello!")]
""")


# Canned assistant acknowledgement used by the chat-template (instruct) prompt.
ASSISTANT_PROMPT_FOR_CHAT_MODEL = dedent("""
I understand and will only return the function call in the correct format.
"""
)
# Wrapper for the user's request, shared by both prompt styles.
USER_PROMPT_FOR_CHAT_MODEL = dedent("""
Request: {user_prompt}.
""")
56+
57+
def continue_prompt(question, functions, tokenizer):
    """Build a plain-text (base-model) prompt for *question*.

    *tokenizer* is unused here; it is accepted so this function is
    call-compatible with instruct_prompt.
    """
    system = SYSTEM_PROMPT_FOR_CHAT_MODEL.format(functions=format_functions(functions))
    user = USER_PROMPT_FOR_CHAT_MODEL.format(user_prompt=question)
    return system + "\n\n" + user
62+
63+
def instruct_prompt(question, functions, tokenizer):
    """Build a chat-templated prompt for instruct-tuned models.

    Applies the tokenizer's chat template to a system/assistant/user exchange
    and returns the untokenized prompt string.
    """
    conversation = [
        {"role": "user", "content": SYSTEM_PROMPT_FOR_CHAT_MODEL.format(functions=format_functions(functions))},
        {"role": "assistant", "content": ASSISTANT_PROMPT_FOR_CHAT_MODEL },
        {"role": "user", "content": USER_PROMPT_FOR_CHAT_MODEL.format(user_prompt=question)},
    ]
    return tokenizer.apply_chat_template(conversation, tokenize=False)
71+
72+
# Regex fragments for JSON-style scalar values, used to constrain generation.
INTEGER = r"(-)?(0|[1-9][0-9]*)"
STRING_INNER = r'([^"\\\x00-\x1F\x7F-\x9F]|\\["\\])'
# We'll limit this to just a max of 42 characters
STRING = f'"{STRING_INNER}{{1,42}}"'
# A float requires a fractional part (i.e. 1 is not a float but 1.0 is) and may
# carry an exponent. The exponent sign is optional so that e.g. "1.0e5" matches
# as well as "1.0e+5" (the original pattern required an explicit sign).
FLOAT = rf"({INTEGER})(\.[0-9]+)([eE][+-]?[0-9]+)?"
BOOLEAN = r"(true|false)"
NULL = r"null"

# Maps schema scalar type names to their regex fragments; "any" falls back to
# the bounded string pattern.
simple_type_map = {
    "string": STRING,
    "any": STRING,
    "integer": INTEGER,
    "number": FLOAT,
    "float": FLOAT,
    "boolean": BOOLEAN,
    "null": NULL,
}
90+
91+
def build_dict_regex(props):
    """Build a regex matching a JSON object with exactly these properties, in order."""
    fields = []
    for key in props:
        fields.append(f'"{key}": ' + type_to_regex(props[key]))
    return r"\{" + ", ".join(fields) + r"\}"
97+
98+
def type_to_regex(arg_meta):
    """Return a regex fragment matching a value of the schema type in *arg_meta*.

    Handles scalar types via simple_type_map, "dict"/"object" (recursing over
    "properties"), and "array"/"tuple" (1 to 9 comma-separated items).

    Raises:
        ValueError: if a dict/object schema lacks a "properties" key.
        KeyError: if the type name is not a known scalar type.
    """
    arg_type = arg_meta["type"]
    # "object" is the standard JSON-schema spelling; normalize it to "dict".
    if arg_type == "object":
        arg_type = "dict"
    if arg_type == "dict":
        try:
            result = build_dict_regex(arg_meta["properties"])
        except KeyError:
            # Previously this *returned* the error message, which was then
            # silently embedded downstream as a broken regex; fail loudly.
            raise ValueError("Definition does not contain 'properties' value.")
    elif arg_type in ["array", "tuple"]:
        pattern = type_to_regex(arg_meta["items"])
        result = r"\[(" + pattern + ", ){0,8}" + pattern + r"\]"
    else:
        result = simple_type_map[arg_type]
    return result
113+
114+
# (Removed a leftover module-level debugging call to type_to_regex that ran on
# every import and discarded its result.)
118+
119+
def build_standard_fc_regex(function_data):
    """Build a regex matching one call to this function: [name(arg=value, ...)].

    Required parameters appear in schema order as "arg=<value>"; optional
    parameters follow, each wrapped in its own optional ", arg=<value>" group.
    """
    props = function_data["parameters"]["properties"]
    required = function_data["parameters"]["required"]

    required_parts = []
    optional_parts = []
    for arg in props:
        value_re = type_to_regex(props[arg])
        if arg in required:
            required_parts.append(f"{arg}=" + value_re)
        else:
            optional_parts.append(f"(, {arg}=" + value_re + r")?")

    return (
        r"\[" + function_data["name"] + r"\("
        + ", ".join(required_parts)
        + "".join(optional_parts)
        + r"\)]"
    )
139+
140+
def multi_function_fc_regex(fs):
    """Build a regex matching a call to any one of the functions in *fs*."""
    alternatives = [rf"({build_standard_fc_regex(f)})" for f in fs]
    return "|".join(alternatives)
145+
146+
def load_functions(path):
    """Load and return the "functions" list from the JSON spec file at *path*."""
    with open(path, "r") as spec_file:
        return json.load(spec_file)['functions']
149+
150+
class SmolMind:
    """Regex-constrained function-calling wrapper around a local LLM.

    Builds a single regex covering every supplied function spec and uses
    Outlines to force generations to match it, so the output is always a
    well-formed function call.
    """

    def __init__(self, functions, model_name=MODEL_NAME, instruct=True, debug=False):
        self.model_name = model_name
        self.instruct = instruct
        self.debug = debug
        self.functions = functions
        # One alternation over all function signatures constrains decoding.
        self.fc_regex = multi_function_fc_regex(functions)
        self.model = outlines.models.transformers(
            model_name,
            device=DEVICE,
            model_kwargs={
                "trust_remote_code": True,
                "torch_dtype": T_TYPE,
            })
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Greedy sampling keeps the structured output deterministic.
        self.generator = outlines.generate.regex(self.model, self.fc_regex, sampler=greedy())

    def get_function_call(self, user_prompt):
        """Return the model's function-call text for *user_prompt*."""
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            # Instruct mode uses the chat template; continue mode is raw text.
            build = instruct_prompt if self.instruct else continue_prompt
            prompt = build(user_prompt, self.functions, self.tokenizer)
            response = self.generator(prompt)
            if self.debug:
                print(f"functions: {self.functions}")
                print(f"prompt: {prompt}")
            return response

0 commit comments

Comments
 (0)