Skip to content

Commit aca283a

Browse files
authored
Merge pull request babakcode#2 from zmhu/master
Fix a bug and add a feature.
2 parents 61044a5 + 8b5af86 commit aca283a

File tree

3 files changed

+118
-4
lines changed

3 files changed

+118
-4
lines changed

lib/src/implement/gemini_implement.dart

Lines changed: 89 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,80 @@ class GeminiImpl implements GeminiInterface {
8686
return GeminiModel.jsonToList(response.data['models']);
8787
}
8888

89+
@override
90+
Stream<Candidates> streamChat(
91+
List<Content> chats, {
92+
String? modelName,
93+
List<SafetySetting>? safetySettings,
94+
GenerationConfig? generationConfig,
95+
}) async* {
96+
Gemini.instance.typeProvider?.clear();
97+
98+
final response = await api.post(
99+
'${modelName ?? Constants.defaultModel}:streamGenerateContent',
100+
isStreamResponse: true,
101+
data: {'contents': chats.map((e) => e.toJson()).toList()},
102+
generationConfig: generationConfig,
103+
safetySettings: safetySettings,
104+
);
105+
106+
Gemini.instance.typeProvider?.loading = false;
107+
// return GeminiResponse.fromJson(response.data).candidates?.lastOrNull;
108+
if (response.statusCode == 200) {
109+
final ResponseBody rb = response.data;
110+
int index = 0;
111+
String modelStr = '';
112+
List<int> cacheUnits = [];
113+
List<int> list = [];
114+
115+
await for (final itemList in rb.stream) {
116+
list = cacheUnits + itemList;
117+
118+
cacheUnits.clear();
119+
120+
String res = "";
121+
try {
122+
res = utf8.decode(list);
123+
} catch (e) {
124+
print("error: $e");
125+
cacheUnits = list;
126+
continue;
127+
}
128+
129+
res = res.trim();
130+
131+
if (index == 0 && res.startsWith("[")) {
132+
res = res.replaceFirst('[', '');
133+
}
134+
if (res.startsWith(',')) {
135+
res = res.replaceFirst(',', '');
136+
}
137+
if (res.endsWith(']')) {
138+
res = res.substring(0, res.length - 1);
139+
}
140+
141+
res = res.trim();
142+
143+
for (final line in splitter.convert(res)) {
144+
if (modelStr == '' && line == ',') {
145+
continue;
146+
}
147+
modelStr += line;
148+
try {
149+
final candidate = Candidates.fromJson(
150+
(jsonDecode(modelStr)['candidates'] as List?)?.firstOrNull);
151+
yield candidate;
152+
Gemini.instance.typeProvider?.add(candidate.output);
153+
modelStr = '';
154+
} catch (e) {
155+
continue;
156+
}
157+
}
158+
index++;
159+
}
160+
}
161+
}
162+
89163
@override
90164
Stream<Candidates> streamGenerateContent(String text,
91165
{List<Uint8List>? images,
@@ -121,9 +195,22 @@ class GeminiImpl implements GeminiInterface {
121195
final ResponseBody rb = response.data;
122196
int index = 0;
123197
String modelStr = '';
198+
List<int> cacheUnits = [];
199+
List<int> list = [];
124200

125-
await for (final list in rb.stream) {
126-
String res = utf8.decode(list);
201+
await for (final itemList in rb.stream) {
202+
list = cacheUnits + itemList;
203+
204+
cacheUnits.clear();
205+
206+
String res = "";
207+
try {
208+
res = utf8.decode(list);
209+
} catch (e) {
210+
print("error: $e");
211+
cacheUnits = list;
212+
continue;
213+
}
127214

128215
res = res.trim();
129216

lib/src/init.dart

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,15 @@ class Gemini implements GeminiInterface {
3232
/// to see request progress
3333
static bool enableDebugging = false;
3434

35+
// String? baseURL;
36+
3537
/// private constructor [Gemini._]
3638
Gemini._(
3739
{
3840
/// [apiKey] is required property
3941
required String apiKey,
42+
String? baseURL,
43+
Map<String, dynamic>? headers,
4044

4145
/// theses properties are optional
4246
List<SafetySetting>? safetySettings,
@@ -45,9 +49,9 @@ class Gemini implements GeminiInterface {
4549
: _impl = GeminiImpl(
4650
api: GeminiService(
4751
Dio(BaseOptions(
48-
baseUrl:
49-
'${Constants.baseUrl}${version ?? Constants.defaultVersion}/',
52+
baseUrl: '${baseURL ?? Constants.baseUrl}${version ?? Constants.defaultVersion}/',
5053
contentType: 'application/json',
54+
headers: headers,
5155
)),
5256
apiKey: apiKey),
5357
safetySettings: safetySettings,
@@ -67,6 +71,8 @@ class Gemini implements GeminiInterface {
6771
/// singleton initialize [Gemini.init]
6872
factory Gemini.init(
6973
{required String apiKey,
74+
String? baseURL,
75+
Map<String, dynamic>? headers,
7076
List<SafetySetting>? safetySettings,
7177
GenerationConfig? generationConfig,
7278
bool? enableDebugging,
@@ -78,6 +84,8 @@ class Gemini implements GeminiInterface {
7884
_firstInit = false;
7985
instance = Gemini._(
8086
apiKey: apiKey,
87+
baseURL: baseURL,
88+
headers: headers,
8189
safetySettings: safetySettings,
8290
generationConfig: generationConfig,
8391
version: version);
@@ -98,6 +106,18 @@ class Gemini implements GeminiInterface {
98106
safetySettings: safetySettings,
99107
modelName: modelName);
100108

109+
@override
110+
Stream<Candidates> streamChat(
111+
List<Content> chats, {
112+
String? modelName,
113+
List<SafetySetting>? safetySettings,
114+
GenerationConfig? generationConfig,
115+
}) =>
116+
_impl.streamChat(chats,
117+
modelName: modelName,
118+
safetySettings: safetySettings,
119+
generationConfig: generationConfig);
120+
101121
/// [countTokens] When using long prompts, it might be useful to count tokens
102122
/// before sending any content to the model.
103123
/// * not implemented yet

lib/src/repository/gemini_interface.dart

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,13 @@ abstract class GeminiInterface {
6969
GenerationConfig? generationConfig,
7070
});
7171

72+
Stream<Candidates> streamChat(
73+
List<Content> chats, {
74+
String? modelName,
75+
List<SafetySetting>? safetySettings,
76+
GenerationConfig? generationConfig,
77+
});
78+
7279
/// [chat] or `Multi-turn conversations`
7380
/// Using Gemini, you can build freeform conversations across multiple turns.
7481
Future<Candidates?> chat(

0 commit comments

Comments
 (0)