Skip to content

Commit 709386a

Browse files
authored
Customizable AI templates (#11884)
* Start to work * Fix runtime problems and checkers * Merge with main * Fix from code review * Update from merge * Fix compiler errors * Fix from code review
1 parent aaff6f5 commit 709386a

19 files changed

+386
-141
lines changed

build.gradle

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,7 @@ dependencies {
351351
exclude group: 'org.jetbrains.kotlin'
352352
}
353353

354-
354+
implementation 'org.apache.velocity:velocity-engine-core:2.3'
355355
implementation platform('ai.djl:bom:0.30.0')
356356
implementation 'ai.djl:api'
357357
implementation 'ai.djl.huggingface:tokenizers'

src/main/java/module-info.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@
160160
uses ai.djl.repository.RepositoryFactory;
161161
uses ai.djl.repository.zoo.ZooProvider;
162162
uses dev.langchain4j.spi.prompt.PromptTemplateFactory;
163+
requires velocity.engine.core;
163164
// endregion
164165

165166
// region: Lucene

src/main/java/org/jabref/gui/preferences/ai/AiTab.fxml

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616
<?import com.dlsc.unitfx.IntegerInputField?>
1717
<?import org.controlsfx.control.SearchableComboBox?>
1818
<?import org.controlsfx.control.textfield.CustomPasswordField?>
19+
<?import javafx.scene.control.TabPane?>
20+
<?import javafx.scene.control.Tab?>
21+
<?import javafx.scene.control.TextArea?>
1922
<fx:root
2023
spacing="10.0"
2124
type="VBox"
@@ -162,10 +165,6 @@
162165
</children>
163166
</HBox>
164167

165-
<ResizableTextArea
166-
fx:id="instructionTextArea"
167-
wrapText="true"/>
168-
169168
<GridPane hgap="10" vgap="10">
170169
<columnConstraints>
171170
<ColumnConstraints hgrow="ALWAYS" percentWidth="50" />
@@ -235,5 +234,37 @@
235234
glyph="REFRESH"/>
236235
</graphic>
237236
</Button>
237+
238+
<HBox alignment="BASELINE_CENTER">
239+
<Label styleClass="sectionHeader"
240+
text="%Templates"
241+
maxWidth="Infinity"
242+
HBox.hgrow="ALWAYS"/>
243+
<Button fx:id="templatesHelp"
244+
prefWidth="20.0"/>
245+
</HBox>
246+
247+
<TabPane>
248+
<Tab text="%System message for chatting" closable="false">
249+
<TextArea fx:id="systemMessageTextArea"/>
250+
</Tab>
251+
<Tab text="User message for chatting" closable="false">
252+
<TextArea fx:id="userMessageTextArea"/>
253+
</Tab>
254+
<Tab text="Completion text for summarization of a chunk" closable="false">
255+
<TextArea fx:id="summarizationChunkTextArea"/>
256+
</Tab>
257+
<Tab text="Completion text for summarization of several chunks" closable="false">
258+
<TextArea fx:id="summarizationCombineTextArea"/>
259+
</Tab>
260+
</TabPane>
261+
262+
<Button onAction="#onResetTemplatesButtonClick"
263+
text="%Reset templates to default">
264+
<graphic>
265+
<JabRefIconView
266+
glyph="REFRESH"/>
267+
</graphic>
268+
</Button>
238269
</children>
239270
</fx:root>

src/main/java/org/jabref/gui/preferences/ai/AiTab.java

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import javafx.scene.control.Button;
88
import javafx.scene.control.CheckBox;
99
import javafx.scene.control.ComboBox;
10+
import javafx.scene.control.TextArea;
1011
import javafx.scene.control.TextField;
1112

1213
import org.jabref.gui.actions.ActionFactory;
@@ -15,13 +16,13 @@
1516
import org.jabref.gui.preferences.AbstractPreferenceTabView;
1617
import org.jabref.gui.preferences.PreferencesTab;
1718
import org.jabref.gui.util.ViewModelListCellFactory;
19+
import org.jabref.logic.ai.templates.AiTemplate;
1820
import org.jabref.logic.help.HelpFile;
1921
import org.jabref.logic.l10n.Localization;
2022
import org.jabref.model.ai.AiProvider;
2123
import org.jabref.model.ai.EmbeddingModel;
2224

2325
import com.airhacks.afterburner.views.ViewLoader;
24-
import com.dlsc.gemsfx.ResizableTextArea;
2526
import com.dlsc.unitfx.IntegerInputField;
2627
import de.saxsys.mvvmfx.utils.validation.visualization.ControlsFxVisualizer;
2728
import org.controlsfx.control.SearchableComboBox;
@@ -43,16 +44,21 @@ public class AiTab extends AbstractPreferenceTabView<AiTabViewModel> implements
4344

4445
@FXML private TextField apiBaseUrlTextField;
4546
@FXML private SearchableComboBox<EmbeddingModel> embeddingModelComboBox;
46-
@FXML private ResizableTextArea instructionTextArea;
4747
@FXML private TextField temperatureTextField;
4848
@FXML private IntegerInputField contextWindowSizeTextField;
4949
@FXML private IntegerInputField documentSplitterChunkSizeTextField;
5050
@FXML private IntegerInputField documentSplitterOverlapSizeTextField;
5151
@FXML private IntegerInputField ragMaxResultsCountTextField;
5252
@FXML private TextField ragMinScoreTextField;
5353

54+
@FXML private TextArea systemMessageTextArea;
55+
@FXML private TextArea userMessageTextArea;
56+
@FXML private TextArea summarizationChunkTextArea;
57+
@FXML private TextArea summarizationCombineTextArea;
58+
5459
@FXML private Button generalSettingsHelp;
5560
@FXML private Button expertSettingsHelp;
61+
@FXML private Button templatesHelp;
5662

5763
private final ControlsFxVisualizer visualizer = new ControlsFxVisualizer();
5864

@@ -74,14 +80,14 @@ public void initialize() {
7480
new ViewModelListCellFactory<AiProvider>()
7581
.withText(AiProvider::toString)
7682
.install(aiProviderComboBox);
77-
aiProviderComboBox.setItems(viewModel.aiProvidersProperty());
83+
aiProviderComboBox.itemsProperty().bind(viewModel.aiProvidersProperty());
7884
aiProviderComboBox.valueProperty().bindBidirectional(viewModel.selectedAiProviderProperty());
7985
aiProviderComboBox.disableProperty().bind(viewModel.disableBasicSettingsProperty());
8086

8187
new ViewModelListCellFactory<String>()
8288
.withText(text -> text)
8389
.install(chatModelComboBox);
84-
chatModelComboBox.setItems(viewModel.chatModelsProperty());
90+
chatModelComboBox.itemsProperty().bind(viewModel.chatModelsProperty());
8591
chatModelComboBox.valueProperty().bindBidirectional(viewModel.selectedChatModelProperty());
8692
chatModelComboBox.disableProperty().bind(viewModel.disableBasicSettingsProperty());
8793

@@ -123,9 +129,6 @@ public void initialize() {
123129
apiBaseUrlTextField.setDisable(newValue || viewModel.disableExpertSettingsProperty().get())
124130
);
125131

126-
instructionTextArea.textProperty().bindBidirectional(viewModel.instructionProperty());
127-
instructionTextArea.disableProperty().bind(viewModel.disableExpertSettingsProperty());
128-
129132
// bindBidirectional doesn't work well with number input fields ({@link IntegerInputField}, {@link DoubleInputField}),
130133
// so they are expanded into `addListener` calls.
131134

@@ -180,7 +183,6 @@ public void initialize() {
180183
visualizer.initVisualization(viewModel.getChatModelValidationStatus(), chatModelComboBox);
181184
visualizer.initVisualization(viewModel.getApiBaseUrlValidationStatus(), apiBaseUrlTextField);
182185
visualizer.initVisualization(viewModel.getEmbeddingModelValidationStatus(), embeddingModelComboBox);
183-
visualizer.initVisualization(viewModel.getSystemMessageValidationStatus(), instructionTextArea);
184186
visualizer.initVisualization(viewModel.getTemperatureTypeValidationStatus(), temperatureTextField);
185187
visualizer.initVisualization(viewModel.getTemperatureRangeValidationStatus(), temperatureTextField);
186188
visualizer.initVisualization(viewModel.getMessageWindowSizeValidationStatus(), contextWindowSizeTextField);
@@ -191,9 +193,15 @@ public void initialize() {
191193
visualizer.initVisualization(viewModel.getRagMinScoreRangeValidationStatus(), ragMinScoreTextField);
192194
});
193195

196+
systemMessageTextArea.textProperty().bindBidirectional(viewModel.getTemplateSources().get(AiTemplate.CHATTING_SYSTEM_MESSAGE));
197+
userMessageTextArea.textProperty().bindBidirectional(viewModel.getTemplateSources().get(AiTemplate.CHATTING_USER_MESSAGE));
198+
summarizationChunkTextArea.textProperty().bindBidirectional(viewModel.getTemplateSources().get(AiTemplate.SUMMARIZATION_CHUNK));
199+
summarizationCombineTextArea.textProperty().bindBidirectional(viewModel.getTemplateSources().get(AiTemplate.SUMMARIZATION_COMBINE));
200+
194201
ActionFactory actionFactory = new ActionFactory();
195202
actionFactory.configureIconButton(StandardActions.HELP, new HelpAction(HelpFile.AI_GENERAL_SETTINGS, dialogService, preferences.getExternalApplicationsPreferences()), generalSettingsHelp);
196203
actionFactory.configureIconButton(StandardActions.HELP, new HelpAction(HelpFile.AI_EXPERT_SETTINGS, dialogService, preferences.getExternalApplicationsPreferences()), expertSettingsHelp);
204+
actionFactory.configureIconButton(StandardActions.HELP, new HelpAction(HelpFile.AI_TEMPLATES, dialogService, preferences.getExternalApplicationsPreferences()), templatesHelp);
197205
}
198206

199207
@Override
@@ -206,6 +214,11 @@ private void onResetExpertSettingsButtonClick() {
206214
viewModel.resetExpertSettings();
207215
}
208216

217+
@FXML
218+
private void onResetTemplatesButtonClick() {
219+
viewModel.resetTemplates();
220+
}
221+
209222
public ReadOnlyBooleanProperty aiEnabledProperty() {
210223
return enableAi.selectedProperty();
211224
}

src/main/java/org/jabref/gui/preferences/ai/AiTabViewModel.java

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
package org.jabref.gui.preferences.ai;
22

3+
import java.util.Arrays;
34
import java.util.List;
45
import java.util.Locale;
6+
import java.util.Map;
57
import java.util.Objects;
68

79
import javafx.beans.property.BooleanProperty;
@@ -20,6 +22,7 @@
2022
import org.jabref.gui.preferences.PreferenceTabViewModel;
2123
import org.jabref.logic.ai.AiDefaultPreferences;
2224
import org.jabref.logic.ai.AiPreferences;
25+
import org.jabref.logic.ai.templates.AiTemplate;
2326
import org.jabref.logic.l10n.Localization;
2427
import org.jabref.logic.preferences.CliPreferences;
2528
import org.jabref.logic.util.LocalizedNumbers;
@@ -79,7 +82,13 @@ public class AiTabViewModel implements PreferenceTabViewModel {
7982
private final StringProperty huggingFaceApiBaseUrl = new SimpleStringProperty();
8083
private final StringProperty gpt4AllApiBaseUrl = new SimpleStringProperty();
8184

82-
private final StringProperty instruction = new SimpleStringProperty();
85+
private final Map<AiTemplate, StringProperty> templateSources = Map.of(
86+
AiTemplate.CHATTING_SYSTEM_MESSAGE, new SimpleStringProperty(),
87+
AiTemplate.CHATTING_USER_MESSAGE, new SimpleStringProperty(),
88+
AiTemplate.SUMMARIZATION_CHUNK, new SimpleStringProperty(),
89+
AiTemplate.SUMMARIZATION_COMBINE, new SimpleStringProperty()
90+
);
91+
8392
private final StringProperty temperature = new SimpleStringProperty();
8493
private final IntegerProperty contextWindowSize = new SimpleIntegerProperty();
8594
private final IntegerProperty documentSplitterChunkSize = new SimpleIntegerProperty();
@@ -96,7 +105,6 @@ public class AiTabViewModel implements PreferenceTabViewModel {
96105
private final Validator chatModelValidator;
97106
private final Validator apiBaseUrlValidator;
98107
private final Validator embeddingModelValidator;
99-
private final Validator instructionValidator;
100108
private final Validator temperatureTypeValidator;
101109
private final Validator temperatureRangeValidator;
102110
private final Validator contextWindowSizeValidator;
@@ -242,11 +250,6 @@ public AiTabViewModel(CliPreferences preferences) {
242250
Objects::nonNull,
243251
ValidationMessage.error(Localization.lang("Embedding model has to be provided")));
244252

245-
this.instructionValidator = new FunctionBasedValidator<>(
246-
instruction,
247-
message -> !StringUtil.isBlank(message),
248-
ValidationMessage.error(Localization.lang("The instruction has to be provided")));
249-
250253
this.temperatureTypeValidator = new FunctionBasedValidator<>(
251254
temperature,
252255
temp -> LocalizedNumbers.stringToDouble(temp).isPresent(),
@@ -318,7 +321,10 @@ public void setValues() {
318321
customizeExpertSettings.setValue(aiPreferences.getCustomizeExpertSettings());
319322

320323
selectedEmbeddingModel.setValue(aiPreferences.getEmbeddingModel());
321-
instruction.setValue(aiPreferences.getInstruction());
324+
325+
Arrays.stream(AiTemplate.values()).forEach(template ->
326+
templateSources.get(template).set(aiPreferences.getTemplate(template)));
327+
322328
temperature.setValue(LocalizedNumbers.doubleToString(aiPreferences.getTemperature()));
323329
contextWindowSize.setValue(aiPreferences.getContextWindowSize());
324330
documentSplitterChunkSize.setValue(aiPreferences.getDocumentSplitterChunkSize());
@@ -359,7 +365,9 @@ public void storeSettings() {
359365
aiPreferences.setHuggingFaceApiBaseUrl(huggingFaceApiBaseUrl.get() == null ? "" : huggingFaceApiBaseUrl.get());
360366
aiPreferences.setGpt4AllApiBaseUrl(gpt4AllApiBaseUrl.get() == null ? "" : gpt4AllApiBaseUrl.get());
361367

362-
aiPreferences.setInstruction(instruction.get());
368+
Arrays.stream(AiTemplate.values()).forEach(template ->
369+
aiPreferences.setTemplate(template, templateSources.get(template).get()));
370+
363371
// We already check the correctness of temperature and RAG minimum score in validators, so we don't need to check it here.
364372
aiPreferences.setTemperature(LocalizedNumbers.stringToDouble(oldLocale, temperature.get()).get());
365373
aiPreferences.setContextWindowSize(contextWindowSize.get());
@@ -373,8 +381,6 @@ public void resetExpertSettings() {
373381
String resetApiBaseUrl = selectedAiProvider.get().getApiUrl();
374382
currentApiBaseUrl.set(resetApiBaseUrl);
375383

376-
instruction.set(AiDefaultPreferences.SYSTEM_MESSAGE);
377-
378384
contextWindowSize.set(AiDefaultPreferences.getContextWindowSize(selectedAiProvider.get(), currentChatModel.get()));
379385

380386
temperature.set(LocalizedNumbers.doubleToString(AiDefaultPreferences.TEMPERATURE));
@@ -384,6 +390,11 @@ public void resetExpertSettings() {
384390
ragMinScore.set(LocalizedNumbers.doubleToString(AiDefaultPreferences.RAG_MIN_SCORE));
385391
}
386392

393+
public void resetTemplates() {
394+
Arrays.stream(AiTemplate.values()).forEach(template ->
395+
templateSources.get(template).set(AiDefaultPreferences.TEMPLATES.get(template)));
396+
}
397+
387398
@Override
388399
public boolean validateSettings() {
389400
if (enableAi.get()) {
@@ -410,7 +421,6 @@ public boolean validateExpertSettings() {
410421
List<Validator> validators = List.of(
411422
apiBaseUrlValidator,
412423
embeddingModelValidator,
413-
instructionValidator,
414424
temperatureTypeValidator,
415425
temperatureRangeValidator,
416426
contextWindowSizeValidator,
@@ -484,8 +494,8 @@ public BooleanProperty disableApiBaseUrlProperty() {
484494
return disableApiBaseUrl;
485495
}
486496

487-
public StringProperty instructionProperty() {
488-
return instruction;
497+
public Map<AiTemplate, StringProperty> getTemplateSources() {
498+
return templateSources;
489499
}
490500

491501
public StringProperty temperatureProperty() {
@@ -536,10 +546,6 @@ public ValidationStatus getEmbeddingModelValidationStatus() {
536546
return embeddingModelValidator.getValidationStatus();
537547
}
538548

539-
public ValidationStatus getSystemMessageValidationStatus() {
540-
return instructionValidator.getValidationStatus();
541-
}
542-
543549
public ValidationStatus getTemperatureTypeValidationStatus() {
544550
return temperatureTypeValidator.getValidationStatus();
545551
}

src/main/java/org/jabref/logic/ai/AiDefaultPreferences.java

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import java.util.List;
55
import java.util.Map;
66

7+
import org.jabref.logic.ai.templates.AiTemplate;
78
import org.jabref.model.ai.AiProvider;
89
import org.jabref.model.ai.EmbeddingModel;
910

@@ -80,6 +81,46 @@ public String toString() {
8081

8182
public static final int FALLBACK_CONTEXT_WINDOW_SIZE = 8196;
8283

84+
public static final Map<AiTemplate, String> TEMPLATES = Map.of(
85+
AiTemplate.CHATTING_SYSTEM_MESSAGE, """
86+
You are an AI assistant that analyses research papers. You answer questions about papers.
87+
You will be supplied with the necessary information. The supplied information will contain mentions of papers in form '@citationKey'.
88+
Whenever you refer to a paper, use its citation key in the same form with @ symbol. Whenever you find relevant information, always use the citation key.
89+
90+
Here are the papers you are analyzing:
91+
#foreach( $entry in $entries )
92+
${CanonicalBibEntry.getCanonicalRepresentation($entry)}
93+
#end""",
94+
95+
AiTemplate.CHATTING_USER_MESSAGE, """
96+
$message
97+
98+
Here is some relevant information for you:
99+
#foreach( $excerpt in $excerpts )
100+
${excerpt.citationKey()}:
101+
${excerpt.text()}
102+
#end""",
103+
104+
AiTemplate.SUMMARIZATION_CHUNK, """
105+
Please provide an overview of the following text. It is a part of a scientific paper.
106+
The summary should include the main objectives, methodologies used, key findings, and conclusions.
107+
Mention any significant experiments, data, or discussions presented in the paper.
108+
109+
DOCUMENT:
110+
$document
111+
112+
OVERVIEW:""",
113+
114+
AiTemplate.SUMMARIZATION_COMBINE, """
115+
You have written an overview of a scientific paper. You have been collecting notes from various parts
116+
of the paper. Now your task is to combine all of the notes in one structured message.
117+
118+
SUMMARIES:
119+
$summaries
120+
121+
FINAL OVERVIEW:"""
122+
);
123+
83124
public static List<String> getAvailableModels(AiProvider aiProvider) {
84125
return Arrays.stream(AiDefaultPreferences.PredefinedChatModel.values())
85126
.filter(model -> model.getAiProvider() == aiProvider)

0 commit comments

Comments
 (0)