Skip to content

Commit 40684fe

Browse files
committed
feat(gpt): support stream response model: davinci, curie, babbage and ada
1 parent a986e14 commit 40684fe

File tree

7 files changed

+115
-86
lines changed

7 files changed

+115
-86
lines changed

app/google-services.json

Lines changed: 0 additions & 39 deletions
This file was deleted.

app/src/main/java/com/chatgptlite/wanted/constants/Constants.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,5 @@ const val urlToImageAuthor = "https://avatars.githubusercontent.com/u/60530946?v
77
const val urlToAvatarGPT = "https://gptapk.com/wp-content/uploads/2023/02/chatgpt-icon.png"
88
const val urlToGithub = "https://github.com/lambiengcode"
99

10+
const val matchResultString = "\"text\":"
1011
const val matchResultTurboString = "\"content\":"
Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
11
package com.chatgptlite.wanted.data.api
22

3+
import com.chatgptlite.wanted.constants.textCompletionsEndpoint
34
import com.chatgptlite.wanted.constants.textCompletionsTurboEndpoint
45
import com.google.gson.JsonObject
56
import okhttp3.ResponseBody
67
import retrofit2.Call
78
import retrofit2.http.*
89

910
interface OpenAIApi {
10-
@POST(textCompletionsTurboEndpoint)
11+
@POST(textCompletionsEndpoint)
1112
@Streaming
1213
fun textCompletionsWithStream(@Body body: JsonObject): Call<ResponseBody>
14+
15+
@POST(textCompletionsTurboEndpoint)
16+
@Streaming
17+
fun textCompletionsTurboWithStream(@Body body: JsonObject): Call<ResponseBody>
1318
}

app/src/main/java/com/chatgptlite/wanted/data/remote/OpenAIRepositoryImpl.kt

Lines changed: 81 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package com.chatgptlite.wanted.data.remote
22

3+
import com.chatgptlite.wanted.constants.matchResultString
34
import com.chatgptlite.wanted.constants.matchResultTurboString
45
import com.chatgptlite.wanted.data.api.OpenAIApi
56
import com.chatgptlite.wanted.models.TextCompletionsParam
@@ -20,58 +21,94 @@ import javax.inject.Inject
2021
@Suppress("UNREACHABLE_CODE")
2122
class OpenAIRepositoryImpl @Inject constructor(
2223
private val openAIApi: OpenAIApi,
23-
): OpenAIRepository {
24-
override fun textCompletionsWithStream(params: TextCompletionsParam): Flow<String> = callbackFlow {
25-
withContext(Dispatchers.IO) {
26-
val response = openAIApi.textCompletionsWithStream(params.toJson()).execute()
27-
28-
if (response.isSuccessful) {
29-
val input = response.body()?.byteStream()?.bufferedReader() ?: throw Exception()
30-
try {
31-
while (true) {
32-
val line =
33-
withContext(Dispatchers.IO) {
24+
) : OpenAIRepository {
25+
override fun textCompletionsWithStream(params: TextCompletionsParam): Flow<String> =
26+
callbackFlow {
27+
withContext(Dispatchers.IO) {
28+
val response = (if (params.isTurbo) openAIApi.textCompletionsTurboWithStream(
29+
params.toJson()
30+
) else openAIApi.textCompletionsWithStream(params.toJson())).execute()
31+
32+
if (response.isSuccessful) {
33+
val input = response.body()?.byteStream()?.bufferedReader() ?: throw Exception()
34+
try {
35+
while (true) {
36+
val line = withContext(Dispatchers.IO) {
3437
input.readLine()
3538
} ?: continue
36-
if (line == "data: [DONE]") {
37-
close()
38-
} else if (line.startsWith("data:")) {
39-
try {
40-
// Handle & convert data -> emit to client
41-
val value = lookupDataFromResponseTurbo(line)
42-
trySend(value)
43-
} catch (e: Exception) {
44-
e.printStackTrace()
39+
if (line == "data: [DONE]") {
40+
close()
41+
} else if (line.startsWith("data:")) {
42+
try {
43+
// Handle & convert data -> emit to client
44+
val value =
45+
if (params.isTurbo) lookupDataFromResponseTurbo(line) else lookupDataFromResponse(
46+
line
47+
)
48+
49+
if (value.isNotEmpty()) {
50+
trySend(value)
51+
}
52+
} catch (e: Exception) {
53+
54+
e.printStackTrace()
55+
}
4556
}
4657
}
47-
}
48-
} catch (e: IOException) {
49-
throw Exception(e)
50-
} finally {
51-
withContext(Dispatchers.IO) {
52-
input.close()
53-
}
58+
} catch (e: IOException) {
59+
throw Exception(e)
60+
} finally {
61+
withContext(Dispatchers.IO) {
62+
input.close()
63+
}
5464

55-
awaitClose {
5665
close()
5766
}
58-
}
59-
} else {
60-
if (!response.isSuccessful) {
61-
var jsonObject: JSONObject? = null
62-
try {
63-
jsonObject = JSONObject(response.errorBody()!!.string())
64-
} catch (e: JSONException) {
65-
e.printStackTrace()
67+
} else {
68+
if (!response.isSuccessful) {
69+
var jsonObject: JSONObject? = null
70+
try {
71+
jsonObject = JSONObject(response.errorBody()!!.string())
72+
println(jsonObject)
73+
} catch (e: JSONException) {
74+
e.printStackTrace()
75+
}
6676
}
67-
}
68-
trySend("Failure! Try again.")
69-
awaitClose {
77+
trySend("Failure! Try again.")
7078
close()
7179
}
7280
}
81+
82+
close()
83+
}
84+
85+
private fun lookupDataFromResponse(jsonString: String): String {
86+
val splitsJsonString = jsonString.split("[{")
87+
88+
val indexOfResult: Int = splitsJsonString.indexOfLast {
89+
it.contains(matchResultString)
90+
}
91+
92+
val textSplits =
93+
if (indexOfResult == -1) listOf() else splitsJsonString[indexOfResult].split(",")
94+
95+
val indexOfText: Int = textSplits.indexOfLast {
96+
it.contains(matchResultString)
97+
}
98+
99+
if (indexOfText != -1) {
100+
try {
101+
val gson = Gson()
102+
val jsonObject =
103+
gson.fromJson("{${textSplits[indexOfText]}}", JsonObject::class.java)
104+
105+
return jsonObject.get("text").asString
106+
} catch (e: java.lang.Exception) {
107+
println(e.localizedMessage)
108+
}
73109
}
74-
close()
110+
111+
return ""
75112
}
76113

77114
private fun lookupDataFromResponseTurbo(jsonString: String): String {
@@ -81,7 +118,8 @@ class OpenAIRepositoryImpl @Inject constructor(
81118
it.contains(matchResultTurboString)
82119
}
83120

84-
val textSplits = if (indexOfResult == -1) listOf() else splitsJsonString[indexOfResult].split(",")
121+
val textSplits =
122+
if (indexOfResult == -1) listOf() else splitsJsonString[indexOfResult].split(",")
85123

86124
val indexOfText: Int = textSplits.indexOfLast {
87125
it.contains(matchResultTurboString)
@@ -90,7 +128,8 @@ class OpenAIRepositoryImpl @Inject constructor(
90128
if (indexOfText != -1) {
91129
try {
92130
val gson = Gson()
93-
val jsonObject = gson.fromJson("{${textSplits[indexOfText]}}", JsonObject::class.java)
131+
val jsonObject =
132+
gson.fromJson("{${textSplits[indexOfText]}}", JsonObject::class.java)
94133

95134
return jsonObject.getAsJsonObject("delta").get("content").asString
96135
} catch (e: java.lang.Exception) {

app/src/main/java/com/chatgptlite/wanted/models/TextCompletionsParam.kt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ data class TextCompletionsParam(
5151
result = 31 * result + messagesTurbo.hashCode()
5252
return result
5353
}
54+
55+
val isTurbo: Boolean
56+
get() = model == GPTModel.gpt35Turbo
5457
}
5558

5659
fun TextCompletionsParam.toJson(): JsonObject {

app/src/main/java/com/chatgptlite/wanted/ui/common/AppDrawer.kt

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,11 @@ private fun DrawerHeader() {
8484
.size(34.dp)
8585

8686
Row(verticalAlignment = CenterVertically, horizontalArrangement = Arrangement.SpaceBetween) {
87-
Row(modifier = Modifier
88-
.padding(16.dp)
89-
.weight(1f), verticalAlignment = CenterVertically) {
87+
Row(
88+
modifier = Modifier
89+
.padding(16.dp)
90+
.weight(1f), verticalAlignment = CenterVertically
91+
) {
9092
Image(
9193
painter = rememberAsyncImagePainter(urlToImageAppIcon),
9294
modifier = paddingSizeModifier.then(Modifier.clip(RoundedCornerShape(6.dp))),

app/src/main/java/com/chatgptlite/wanted/ui/conversations/ConversationViewModel.kt

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ class ConversationViewModel @Inject constructor(
7979
// Execute API OpenAI
8080
val flow: Flow<String> = openAIRepo.textCompletionsWithStream(
8181
TextCompletionsParam(
82+
promptText = getPrompt(_currentConversation.value),
8283
messagesTurbo = getMessagesParamsTurbo(_currentConversation.value)
8384
)
8485
)
@@ -124,13 +125,30 @@ class ConversationViewModel @Inject constructor(
124125
return messagesMap[conversationId]!!
125126
}
126127

128+
private fun getPrompt(conversationId: String): String {
129+
if (_messages.value[conversationId] == null) return ""
130+
131+
val messagesMap: HashMap<String, MutableList<MessageModel>> =
132+
_messages.value.clone() as HashMap<String, MutableList<MessageModel>>
133+
134+
var response: String = ""
135+
136+
for (message in messagesMap[conversationId]!!.reversed()) {
137+
response += """
138+
Human:${message.question.trim()}
139+
Bot:${if (message.answer == "Let me thinking...") "" else message.answer.trim()}"""
140+
}
141+
142+
return response
143+
}
144+
127145
private fun getMessagesParamsTurbo(conversationId: String): List<MessageTurbo> {
128146
if (_messages.value[conversationId] == null) return listOf()
129147

130148
val messagesMap: HashMap<String, MutableList<MessageModel>> =
131149
_messages.value.clone() as HashMap<String, MutableList<MessageModel>>
132150

133-
val response:MutableList<MessageTurbo> = mutableListOf()
151+
val response: MutableList<MessageTurbo> = mutableListOf()
134152

135153
for (message in messagesMap[conversationId]!!.reversed()) {
136154
response.add(MessageTurbo(content = message.question))

0 commit comments

Comments
 (0)