Skip to content

Commit bb2b3cf

Browse files
committed
feat: Add support for chatStream
1 parent 83f9f16 commit bb2b3cf

File tree

3 files changed

+105
-2
lines changed

3 files changed

+105
-2
lines changed

lib/src/implement/gemini_implement.dart

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,80 @@ class GeminiImpl implements GeminiInterface {
8686
return GeminiModel.jsonToList(response.data['models']);
8787
}
8888

89+
@override
90+
Stream<Candidates> streamChat(
91+
List<Content> chats, {
92+
String? modelName,
93+
List<SafetySetting>? safetySettings,
94+
GenerationConfig? generationConfig,
95+
}) async* {
96+
Gemini.instance.typeProvider?.clear();
97+
98+
final response = await api.post(
99+
'${modelName ?? Constants.defaultModel}:streamGenerateContent',
100+
isStreamResponse: true,
101+
data: {'contents': chats.map((e) => e.toJson()).toList()},
102+
generationConfig: generationConfig,
103+
safetySettings: safetySettings,
104+
);
105+
106+
Gemini.instance.typeProvider?.loading = false;
107+
// return GeminiResponse.fromJson(response.data).candidates?.lastOrNull;
108+
if (response.statusCode == 200) {
109+
final ResponseBody rb = response.data;
110+
int index = 0;
111+
String modelStr = '';
112+
List<int> cacheUnits = [];
113+
List<int> list = [];
114+
115+
await for (final itemList in rb.stream) {
116+
list = cacheUnits + itemList;
117+
118+
cacheUnits.clear();
119+
120+
String res = "";
121+
try {
122+
res = utf8.decode(list);
123+
} catch (e) {
124+
print("error: $e");
125+
cacheUnits = list;
126+
continue;
127+
}
128+
129+
res = res.trim();
130+
131+
if (index == 0 && res.startsWith("[")) {
132+
res = res.replaceFirst('[', '');
133+
}
134+
if (res.startsWith(',')) {
135+
res = res.replaceFirst(',', '');
136+
}
137+
if (res.endsWith(']')) {
138+
res = res.substring(0, res.length - 1);
139+
}
140+
141+
res = res.trim();
142+
143+
for (final line in splitter.convert(res)) {
144+
if (modelStr == '' && line == ',') {
145+
continue;
146+
}
147+
modelStr += line;
148+
try {
149+
final candidate = Candidates.fromJson(
150+
(jsonDecode(modelStr)['candidates'] as List?)?.firstOrNull);
151+
yield candidate;
152+
Gemini.instance.typeProvider?.add(candidate.output);
153+
modelStr = '';
154+
} catch (e) {
155+
continue;
156+
}
157+
}
158+
index++;
159+
}
160+
}
161+
}
162+
89163
@override
90164
Stream<Candidates> streamGenerateContent(String text,
91165
{List<Uint8List>? images,

lib/src/init.dart

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,15 @@ class Gemini implements GeminiInterface {
3232
/// to see request progress
3333
static bool enableDebugging = false;
3434

35+
// String? baseURL;
36+
3537
/// private constructor [Gemini._]
3638
Gemini._(
3739
{
3840
/// [apiKey] is required property
3941
required String apiKey,
42+
String? baseURL,
43+
Map<String, dynamic>? headers,
4044

4145
/// theses properties are optional
4246
List<SafetySetting>? safetySettings,
@@ -45,9 +49,11 @@ class Gemini implements GeminiInterface {
4549
: _impl = GeminiImpl(
4650
api: GeminiService(
4751
Dio(BaseOptions(
48-
baseUrl:
49-
'${Constants.baseUrl}${version ?? Constants.defaultVersion}/',
52+
baseUrl: baseURL != null
53+
? "$baseURL${version ?? Constants.defaultVersion}/"
54+
: '${Constants.baseUrl}${version ?? Constants.defaultVersion}/',
5055
contentType: 'application/json',
56+
headers: headers,
5157
)),
5258
apiKey: apiKey),
5359
safetySettings: safetySettings,
@@ -67,6 +73,8 @@ class Gemini implements GeminiInterface {
6773
/// singleton initialize [Gemini.init]
6874
factory Gemini.init(
6975
{required String apiKey,
76+
String? baseURL,
77+
Map<String, dynamic>? headers,
7078
List<SafetySetting>? safetySettings,
7179
GenerationConfig? generationConfig,
7280
bool? enableDebugging,
@@ -78,6 +86,8 @@ class Gemini implements GeminiInterface {
7886
_firstInit = false;
7987
instance = Gemini._(
8088
apiKey: apiKey,
89+
baseURL: baseURL,
90+
headers: headers,
8191
safetySettings: safetySettings,
8292
generationConfig: generationConfig,
8393
version: version);
@@ -98,6 +108,18 @@ class Gemini implements GeminiInterface {
98108
safetySettings: safetySettings,
99109
modelName: modelName);
100110

111+
@override
112+
Stream<Candidates> streamChat(
113+
List<Content> chats, {
114+
String? modelName,
115+
List<SafetySetting>? safetySettings,
116+
GenerationConfig? generationConfig,
117+
}) =>
118+
_impl.streamChat(chats,
119+
modelName: modelName,
120+
safetySettings: safetySettings,
121+
generationConfig: generationConfig);
122+
101123
/// [countTokens] When using long prompts, it might be useful to count tokens
102124
/// before sending any content to the model.
103125
/// * not implemented yet

lib/src/repository/gemini_interface.dart

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,13 @@ abstract class GeminiInterface {
6969
GenerationConfig? generationConfig,
7070
});
7171

72+
Stream<Candidates> streamChat(
73+
List<Content> chats, {
74+
String? modelName,
75+
List<SafetySetting>? safetySettings,
76+
GenerationConfig? generationConfig,
77+
});
78+
7279
/// [chat] or `Multi-turn conversations`
7380
/// Using Gemini, you can build freeform conversations across multiple turns.
7481
Future<Candidates?> chat(

0 commit comments

Comments
 (0)