Skip to content

Commit 786852e

Browse files
authored
partial variables (langchain-ai#1308)
1 parent 72ef69d commit 786852e

File tree

8 files changed

+370
-14
lines changed

8 files changed

+370
-14
lines changed
Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"id": "9355a547",
6+
"metadata": {},
7+
"source": [
8+
"# Partial Prompt Templates\n",
9+
"\n",
10+
"A prompt template is a class with a `.format` method which takes in a key-value map and returns a string (a prompt) to pass to the language model. Like other methods, it can make sense to \"partial\" a prompt template - eg pass in a subset of the required values, as to create a new prompt template which expects only the remaining subset of values.\n",
11+
"\n",
12+
"LangChain supports this in two ways: we allow for partially formatted prompts (1) with string values, (2) with functions that return string values. These two different ways support different use cases. In the documentation below we go over the motivations for both use cases as well as how to do it in LangChain.\n",
13+
"\n",
14+
"## Partial With Strings\n",
15+
"\n",
16+
"One common use case for wanting to partial a prompt template is if you get some of the variables before others. For example, suppose you have a prompt template that requires two variables, `foo` and `baz`. If you get the `foo` value early on in the chain, but the `baz` value later, it can be annoying to wait until you have both variables in the same place to pass them to the prompt template. Instead, you can partial the prompt template with the `foo` value, and then pass the partialed prompt template along and just use that. Below is an example of doing this:"
17+
]
18+
},
19+
{
20+
"cell_type": "code",
21+
"execution_count": 1,
22+
"id": "643af5da",
23+
"metadata": {},
24+
"outputs": [],
25+
"source": [
26+
"from langchain.prompts import PromptTemplate"
27+
]
28+
},
29+
{
30+
"cell_type": "code",
31+
"execution_count": 2,
32+
"id": "4080d8d7",
33+
"metadata": {},
34+
"outputs": [
35+
{
36+
"name": "stdout",
37+
"output_type": "stream",
38+
"text": [
39+
"foobaz\n"
40+
]
41+
}
42+
],
43+
"source": [
44+
"prompt = PromptTemplate(template=\"{foo}{bar}\", input_variables=[\"foo\", \"bar\"])\n",
45+
"partial_prompt = prompt.partial(foo=\"foo\");\n",
46+
"print(partial_prompt.format(bar=\"baz\"))"
47+
]
48+
},
49+
{
50+
"cell_type": "markdown",
51+
"id": "9986766e",
52+
"metadata": {},
53+
"source": [
54+
"You can also just initialize the prompt with the partialed variables."
55+
]
56+
},
57+
{
58+
"cell_type": "code",
59+
"execution_count": 3,
60+
"id": "e2ce95b3",
61+
"metadata": {},
62+
"outputs": [
63+
{
64+
"name": "stdout",
65+
"output_type": "stream",
66+
"text": [
67+
"foobaz\n"
68+
]
69+
}
70+
],
71+
"source": [
72+
"prompt = PromptTemplate(template=\"{foo}{bar}\", input_variables=[\"bar\"], partial_variables={\"foo\": \"foo\"})\n",
73+
"print(prompt.format(bar=\"baz\"))"
74+
]
75+
},
76+
{
77+
"cell_type": "markdown",
78+
"id": "a9c66f83",
79+
"metadata": {},
80+
"source": [
81+
"## Partial With Functions\n",
82+
"\n",
83+
"The other common use is to partial with a function. The use case for this is when you have a variable you know that you always want to fetch in a common way. A prime example of this is with date or time. Imagine you have a prompt which you always want to have the current date. You can't hard code it in the prompt, and passing it along with the other input variables is a bit annoying. In this case, it's very handy to be able to partial the prompt with a function that always returns the current date."
84+
]
85+
},
86+
{
87+
"cell_type": "code",
88+
"execution_count": 4,
89+
"id": "d0712d8a",
90+
"metadata": {},
91+
"outputs": [],
92+
"source": [
93+
"from datetime import datetime\n",
94+
"\n",
95+
"def _get_datetime():\n",
96+
" now = datetime.now()\n",
97+
" return now.strftime(\"%m/%d/%Y, %H:%M:%S\")"
98+
]
99+
},
100+
{
101+
"cell_type": "code",
102+
"execution_count": 5,
103+
"id": "4cbcb666",
104+
"metadata": {},
105+
"outputs": [
106+
{
107+
"name": "stdout",
108+
"output_type": "stream",
109+
"text": [
110+
"Tell me a funny joke about the day 02/27/2023, 22:15:16\n"
111+
]
112+
}
113+
],
114+
"source": [
115+
"prompt = PromptTemplate(\n",
116+
" template=\"Tell me a {adjective} joke about the day {date}\", \n",
117+
" input_variables=[\"adjective\", \"date\"]\n",
118+
");\n",
119+
"partial_prompt = prompt.partial(date=_get_datetime)\n",
120+
"print(partial_prompt.format(adjective=\"funny\"))"
121+
]
122+
},
123+
{
124+
"cell_type": "markdown",
125+
"id": "ffed6811",
126+
"metadata": {},
127+
"source": [
128+
"You can also just initialize the prompt with the partialed variables, which often makes more sense in this workflow."
129+
]
130+
},
131+
{
132+
"cell_type": "code",
133+
"execution_count": 6,
134+
"id": "96285b25",
135+
"metadata": {},
136+
"outputs": [
137+
{
138+
"name": "stdout",
139+
"output_type": "stream",
140+
"text": [
141+
"Tell me a funny joke about the day 02/27/2023, 22:15:16\n"
142+
]
143+
}
144+
],
145+
"source": [
146+
"prompt = PromptTemplate(\n",
147+
" template=\"Tell me a {adjective} joke about the day {date}\", \n",
148+
" input_variables=[\"adjective\"],\n",
149+
" partial_variables={\"date\": _get_datetime}\n",
150+
");\n",
151+
"print(prompt.format(adjective=\"funny\"))"
152+
]
153+
},
154+
{
155+
"cell_type": "code",
156+
"execution_count": null,
157+
"id": "4bff16f7",
158+
"metadata": {},
159+
"outputs": [],
160+
"source": []
161+
}
162+
],
163+
"metadata": {
164+
"kernelspec": {
165+
"display_name": "Python 3 (ipykernel)",
166+
"language": "python",
167+
"name": "python3"
168+
},
169+
"language_info": {
170+
"codemirror_mode": {
171+
"name": "ipython",
172+
"version": 3
173+
},
174+
"file_extension": ".py",
175+
"mimetype": "text/x-python",
176+
"name": "python",
177+
"nbconvert_exporter": "python",
178+
"pygments_lexer": "ipython3",
179+
"version": "3.9.1"
180+
}
181+
},
182+
"nbformat": 4,
183+
"nbformat_minor": 5
184+
}

docs/modules/prompts/how_to_guides.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ The user guide here shows more advanced workflows and how to use the library in
1717

1818
`Few Shot Prompt Examples <./examples/few_shot_examples.html>`_: Examples of Few Shot Prompt Templates.
1919

20+
`Partial Prompt Template <./examples/partial.html>`_: How to partial Prompt Templates.
21+
2022

2123

2224
.. toctree::

langchain/prompts/base.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
"""BasePrompt schema definition."""
2+
from __future__ import annotations
3+
24
import json
35
import re
46
from abc import ABC, abstractmethod
57
from pathlib import Path
6-
from typing import Any, Callable, Dict, List, Optional, Union
8+
from typing import Any, Callable, Dict, List, Mapping, Optional, Union
79

810
import yaml
9-
from pydantic import BaseModel, Extra, root_validator
11+
from pydantic import BaseModel, Extra, Field, root_validator
1012

1113
from langchain.formatting import formatter
1214

@@ -117,6 +119,9 @@ class BasePromptTemplate(BaseModel, ABC):
117119
"""A list of the names of the variables the prompt template expects."""
118120
output_parser: Optional[BaseOutputParser] = None
119121
"""How to parse the output of calling an LLM on this formatted prompt."""
122+
partial_variables: Mapping[str, Union[str, Callable[[], str]]] = Field(
123+
default_factory=dict
124+
)
120125

121126
class Config:
122127
"""Configuration for this pydantic object."""
@@ -132,8 +137,38 @@ def validate_variable_names(cls, values: Dict) -> Dict:
132137
"Cannot have an input variable named 'stop', as it is used internally,"
133138
" please rename."
134139
)
140+
if "stop" in values["partial_variables"]:
141+
raise ValueError(
142+
"Cannot have an partial variable named 'stop', as it is used "
143+
"internally, please rename."
144+
)
145+
146+
overall = set(values["input_variables"]).intersection(
147+
values["partial_variables"]
148+
)
149+
if overall:
150+
raise ValueError(
151+
f"Found overlapping input and partial variables: {overall}"
152+
)
135153
return values
136154

155+
def partial(self, **kwargs: Union[str, Callable[[], str]]) -> BasePromptTemplate:
156+
"""Return a partial of the prompt template."""
157+
prompt_dict = self.__dict__.copy()
158+
prompt_dict["input_variables"] = list(
159+
set(self.input_variables).difference(kwargs)
160+
)
161+
prompt_dict["partial_variables"] = {**self.partial_variables, **kwargs}
162+
return type(self)(**prompt_dict)
163+
164+
def _merge_partial_and_user_variables(self, **kwargs: Any) -> Dict[str, Any]:
165+
# Get partial params:
166+
partial_kwargs = {
167+
k: v if isinstance(v, str) else v()
168+
for k, v in self.partial_variables.items()
169+
}
170+
return {**partial_kwargs, **kwargs}
171+
137172
@abstractmethod
138173
def format(self, **kwargs: Any) -> str:
139174
"""Format the prompt with the inputs.
@@ -173,6 +208,8 @@ def save(self, file_path: Union[Path, str]) -> None:
173208
174209
prompt.save(file_path="path/prompt.yaml")
175210
"""
211+
if self.partial_variables:
212+
raise ValueError("Cannot save prompt with partial variables.")
176213
# Convert file to Path object.
177214
if isinstance(file_path, str):
178215
save_path = Path(file_path)

langchain/prompts/few_shot.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def template_is_valid(cls, values: Dict) -> Dict:
6868
check_valid_template(
6969
values["prefix"] + values["suffix"],
7070
values["template_format"],
71-
values["input_variables"],
71+
values["input_variables"] + list(values["partial_variables"]),
7272
)
7373
return values
7474

@@ -101,6 +101,7 @@ def format(self, **kwargs: Any) -> str:
101101
102102
prompt.format(variable1="foo")
103103
"""
104+
kwargs = self._merge_partial_and_user_variables(**kwargs)
104105
# Get the examples to use.
105106
examples = self._get_examples(**kwargs)
106107
# Format the examples.
@@ -110,6 +111,7 @@ def format(self, **kwargs: Any) -> str:
110111
# Create the overall template.
111112
pieces = [self.prefix, *example_strings, self.suffix]
112113
template = self.example_separator.join([piece for piece in pieces if piece])
114+
113115
# Format the template with the input variables.
114116
return DEFAULT_FORMATTER_MAPPING[self.template_format](template, **kwargs)
115117

langchain/prompts/few_shot_with_templates.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -60,16 +60,18 @@ def check_examples_and_selector(cls, values: Dict) -> Dict:
6060
@root_validator()
6161
def template_is_valid(cls, values: Dict) -> Dict:
6262
"""Check that prefix, suffix and input variables are consistent."""
63-
input_variables = values["input_variables"]
64-
expected_input_variables = set(values["suffix"].input_variables)
65-
if values["prefix"] is not None:
66-
expected_input_variables |= set(values["prefix"].input_variables)
67-
missing_vars = expected_input_variables.difference(input_variables)
68-
if missing_vars:
69-
raise ValueError(
70-
f"Got input_variables={input_variables}, but based on prefix/suffix "
71-
f"expected {expected_input_variables}"
72-
)
63+
if values["validate_template"]:
64+
input_variables = values["input_variables"]
65+
expected_input_variables = set(values["suffix"].input_variables)
66+
expected_input_variables |= set(values["partial_variables"])
67+
if values["prefix"] is not None:
68+
expected_input_variables |= set(values["prefix"].input_variables)
69+
missing_vars = expected_input_variables.difference(input_variables)
70+
if missing_vars:
71+
raise ValueError(
72+
f"Got input_variables={input_variables}, but based on "
73+
f"prefix/suffix expected {expected_input_variables}"
74+
)
7375
return values
7476

7577
class Config:
@@ -101,6 +103,7 @@ def format(self, **kwargs: Any) -> str:
101103
102104
prompt.format(variable1="foo")
103105
"""
106+
kwargs = self._merge_partial_and_user_variables(**kwargs)
104107
# Get the examples to use.
105108
examples = self._get_examples(**kwargs)
106109
# Format the examples.

langchain/prompts/prompt.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,14 +60,16 @@ def format(self, **kwargs: Any) -> str:
6060
6161
prompt.format(variable1="foo")
6262
"""
63+
kwargs = self._merge_partial_and_user_variables(**kwargs)
6364
return DEFAULT_FORMATTER_MAPPING[self.template_format](self.template, **kwargs)
6465

6566
@root_validator()
6667
def template_is_valid(cls, values: Dict) -> Dict:
6768
"""Check that template and input variables are consistent."""
6869
if values["validate_template"]:
70+
all_inputs = values["input_variables"] + list(values["partial_variables"])
6971
check_valid_template(
70-
values["template"], values["template_format"], values["input_variables"]
72+
values["template"], values["template_format"], all_inputs
7173
)
7274
return values
7375

0 commit comments

Comments
 (0)