Skip to content

Commit bc70ece

Browse files
committed
WIP: runs
1 parent e0c39a8 commit bc70ece

File tree

11 files changed

+241
-9
lines changed

11 files changed

+241
-9
lines changed

openai-client/src/main/scala/io/cequence/openaiscala/JsonFormats.scala

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -782,4 +782,46 @@ object JsonFormats {
782782

783783
Format(reads, writes)
784784
}
785+
786+
// Format for Run.Reason (the `incomplete_details` payload of a run).
// Fix: removed the local Map[String, String] format — Run.Reason has a single
// String field, so that implicit was dead code.
// NOTE(review): Run.Reason is an AnyVal wrapper; Json.format serializes it as
// {"value": ...}. If the API expects a bare string, Json.valueFormat[Run.Reason]
// should be used instead — confirm the wire format.
implicit lazy val runReasonFormat: Format[Run.Reason] =
  Json.format[Run.Reason]
791+
792+
// Serializes Run.LastErrorCode as the snake_case strings the API uses:
// server_error, rate_limit_exceeded, invalid_prompt.
implicit lazy val lastRunErrorCodeFormat: Format[Run.LastErrorCode] =
  snakeEnumFormat[Run.LastErrorCode](
    Run.LastErrorCode.ServerError,
    Run.LastErrorCode.RateLimitExceeded,
    Run.LastErrorCode.InvalidPrompt
  )
796+
797+
// Serializes Run.TruncationStrategyType as snake_case strings: auto, last_messages.
implicit lazy val truncationStrategyTypeFormat: Format[Run.TruncationStrategyType] =
  snakeEnumFormat[Run.TruncationStrategyType](
    Run.TruncationStrategyType.Auto,
    Run.TruncationStrategyType.LastMessages
  )
801+
802+
// Serializes RunStatus values as snake_case strings (e.g. "requires_action").
implicit lazy val RunStatusFormat: Format[RunStatus] = {
  import RunStatus._
  snakeEnumFormat(
    Queued, InProgress, RequiresAction, Cancelling, Cancelled,
    Failed, Completed, Incomplete, Expired
  )
}
// Format for the Run payload itself (field names already match the snake_case
// wire format, so the macro-derived format is used as-is).
// Fix: removed a stray bare `Run` expression statement left above this val —
// it referenced the companion object for no effect (dead code from editing).
implicit lazy val RunFormat: Format[Run] =
  Json.format[Run]
819+
820+
// Formats for the required-action payload, declared innermost-first so the
// dependency chain (ToolCall -> SubmitToolOutputs -> RequiredAction) reads
// top-down; all are lazy vals, so declaration order does not affect behavior.
implicit lazy val ToolCallFormat: Format[ToolCall] = Json.format[ToolCall]

implicit lazy val SubmitToolOutputsFormat: Format[SubmitToolOutputs] =
  Json.format[SubmitToolOutputs]

implicit lazy val RequiredActionFormat: Format[RequiredAction] =
  Json.format[RequiredAction]

// Top-level envelope returned by the run endpoints.
implicit lazy val runResponseFormat: Format[RunResponse] =
  Json.format[RunResponse]
785827
}

openai-client/src/main/scala/io/cequence/openaiscala/service/impl/EndPoint.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ object EndPoint {
2323
case object batches extends EndPoint
2424
case object assistants extends EndPoint
2525
case object vector_stores extends EndPoint
// Fix: this WIP commit deleted vector_store_files when adding runs, which would
// break the existing vector-store-files endpoints — restored alongside runs.
case object vector_store_files extends EndPoint("vector_stores/files")
case object runs extends EndPoint
2727
}
2828

2929
sealed trait Param extends EnumValue
@@ -103,4 +103,5 @@ object Param {
103103
case object completion_window extends Param
104104
case object chunking_strategy extends Param
105105
case object filter extends Param
106+
case object max_prompt_tokens extends Param
106107
}

openai-client/src/main/scala/io/cequence/openaiscala/service/impl/OpenAICoreServiceImpl.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ private[service] trait OpenAICoreServiceImpl
2121
extends OpenAICoreService
2222
with OpenAIChatCompletionServiceImpl
2323
with HandleOpenAIErrorCodes
24-
with CompletionBodyMaker {
24+
with CompletionBodyMaker
25+
with RunBodyMaker {
2526

2627
// override protected def handleErrorCodes(
2728
// httpCode: Int,

openai-client/src/main/scala/io/cequence/openaiscala/service/impl/OpenAIServiceImpl.scala

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,46 @@ private[service] trait OpenAIServiceImpl
6565
)
6666
}
6767

68+
/**
 * Creates a run via POST to the runs endpoint.
 *
 * @param run                the run holder — NOTE(review): currently unused when
 *                           building the request; confirm whether thread/assistant
 *                           ids should be taken from it
 * @param tools              tool specs exposed to the run
 * @param responseToolChoice optional name of the function the model is forced to
 *                           call; when None, "auto" applies if tools are present
 * @param settings           run-creation settings (model, sampling, limits, format)
 * @param stream             whether to request a streamed response
 * @return the deserialized run response
 */
def createRun(
  run: Run,
  tools: Seq[ToolSpec],
  responseToolChoice: Option[String] = None,
  settings: CreateRunSettings = DefaultSettings.CreateRun,
  stream: Boolean
): Future[RunResponse] = {
  val coreParams = createBodyParamsForRun(settings, stream)

  // Bug fix: responseToolChoice was accepted but never forwarded to toolParams,
  // so an explicit tool choice was silently ignored.
  val toolParam = toolParams(tools, responseToolChoice)

  execPOST(
    EndPoint.runs,
    bodyParams = coreParams ++ toolParam
  ).map(
    _.asSafe[RunResponse]
  )
}
86+
87+
/**
 * Builds the `tools` and `tool_choice` body parameters.
 *
 * @param tools              tool specs; NOTE(review): the match handles only
 *                           FunctionSpec — any other ToolSpec subtype would throw
 *                           a MatchError at runtime; confirm FunctionSpec is the
 *                           only variant expected here
 * @param responseToolChoice if set, forces the model to call the named function;
 *                           otherwise "auto" is used by default (if tools are present)
 */
private def toolParams(
  tools: Seq[ToolSpec],
  responseToolChoice: Option[String] = None
): Seq[(Param, Option[JsValue])] = {
  val toolJsons = tools.map { case tool: FunctionSpec =>
    Map("type" -> "function", "function" -> Json.toJson(tool))
  }

  // Fix: return directly instead of binding to a redundant `extraParams` local.
  jsonBodyParams(
    Param.tools -> Some(toolJsons),
    Param.tool_choice -> responseToolChoice.map(name =>
      Map(
        "type" -> "function",
        "function" -> Map("name" -> name)
      )
    )
  )
}
107+
68108
override def createChatToolCompletion(
69109
messages: Seq[BaseMessage],
70110
tools: Seq[ToolSpec],
@@ -74,12 +114,9 @@ private[service] trait OpenAIServiceImpl
74114
val coreParams =
75115
createBodyParamsForChatCompletion(messages, settings, stream = false)
76116

77-
val toolJsons = tools.map(
78-
_ match {
79-
case tool: FunctionSpec =>
80-
Map("type" -> "function", "function" -> Json.toJson(tool))
81-
}
82-
)
117+
val toolJsons: Seq[Map[String, Object]] = tools.map { case tool: FunctionSpec =>
118+
Map("type" -> "function", "function" -> Json.toJson(tool))
119+
}
83120

84121
val extraParams = jsonBodyParams(
85122
Param.tools -> Some(toolJsons),
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
package io.cequence.openaiscala.service.impl

import io.cequence.openaiscala.JsonFormats._
import io.cequence.openaiscala.domain.settings.CreateRunSettings
import io.cequence.wsclient.service.ws.WSRequestHelper
import play.api.libs.json.{JsValue, Json}

/**
 * Builds the JSON body parameters for the create-run endpoint.
 *
 * Fix: removed unused imports (BaseMessage, CreateChatCompletionSettings).
 */
trait RunBodyMaker {

  this: WSRequestHelper =>

  /**
   * Maps [[CreateRunSettings]] onto endpoint body params.
   *
   * NOTE(review): settings.metadata and settings.maxCompletionTokens are defined
   * on CreateRunSettings but never sent here — confirm whether they should be
   * included in the request body.
   *
   * @param settings run-creation settings
   * @param stream   whether the response should be streamed
   */
  protected def createBodyParamsForRun(
    settings: CreateRunSettings,
    stream: Boolean
  ): Seq[(Param, Option[JsValue])] =
    jsonBodyParams(
      Param.model -> Some(settings.model),
      Param.temperature -> settings.temperature,
      Param.top_p -> settings.topP,
      Param.stream -> Some(stream),
      Param.max_prompt_tokens -> settings.maxPromptTokens,
      Param.response_format -> settings.responseFormat.map(Json.toJson(_))
    )
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
package io.cequence.openaiscala.domain

import io.cequence.wsclient.domain.EnumValue

/**
 * A last-error report attached to an API resource: a typed error code plus the
 * server-provided human-readable message.
 *
 * @tparam T the enum of error codes for the owning resource
 */
case class GenericLastError[T <: EnumValue](
  code: T,
  message: String
)
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
package io.cequence.openaiscala.domain

import io.cequence.openaiscala.domain.Run.Reason
import io.cequence.openaiscala.domain.Run.TruncationStrategyType.Auto
import io.cequence.openaiscala.domain.response.UsageInfo
import io.cequence.wsclient.domain.SnakeCaseEnumValue

import java.util.Date

// Fix: removed unused imports (EnumValue, play JsonNaming.SnakeCase).

/** Lifecycle status of a run; serialized as snake_case (e.g. "in_progress"). */
sealed trait RunStatus extends SnakeCaseEnumValue

object RunStatus {
  case object Queued extends RunStatus
  case object InProgress extends RunStatus
  case object RequiresAction extends RunStatus
  case object Cancelling extends RunStatus
  case object Cancelled extends RunStatus
  case object Failed extends RunStatus
  case object Completed extends RunStatus
  case object Incomplete extends RunStatus
  case object Expired extends RunStatus
}

/** Action the caller must perform before the run can continue. */
case class RequiredAction(
  `type`: String,
  submitToolOutputs: SubmitToolOutputs
)

/** The tool calls whose outputs must be submitted back to the run. */
case class SubmitToolOutputs(
  toolCalls: Seq[ToolCall]
)

/** A single tool invocation requested by the run. */
case class ToolCall(
  id: String,
  `type`: String,
  function: FunctionCallSpec
)

/**
 * An assistant run.
 *
 * Field names are deliberately snake_case so the macro-derived JSON format
 * matches the wire format without a naming strategy.
 */
case class Run(
  id: String,
  `object`: String,
  created_at: Date,
  thread_id: String,
  assistant_id: String,
  status: RunStatus,
  required_action: Option[RequiredAction],
  // NOTE(review): the API's last_error carries both a code and a message;
  // consider GenericLastError[Run.LastErrorCode] here so the message is not
  // dropped — confirm against the wire format before changing.
  last_error: Option[Run.LastErrorCode],
  expires_at: Option[Date],
  started_at: Option[Date],
  cancelled_at: Option[Date],
  failed_at: Option[Date],
  completed_at: Option[Date],
  incomplete_details: Option[Reason],
  model: String,
  instructions: String,
  usage: Option[UsageInfo]
  // tool_choice: Either[String, Any], // Replace Any with the actual type when available
  // response_format: Either[String, Any] // Replace Any with the actual type when available
)

object Run {

  /** Controls how the thread is truncated prior to the run. */
  case class TruncationStrategy(
    `type`: TruncationStrategyType = Auto,
    lastMessages: Option[Int]
  )

  /** Serialized as snake_case: auto, last_messages. */
  sealed trait TruncationStrategyType extends SnakeCaseEnumValue

  object TruncationStrategyType {
    case object Auto extends TruncationStrategyType
    case object LastMessages extends TruncationStrategyType
  }

  /** Reason carried in incomplete_details when a run ends incomplete. */
  case class Reason(value: String) extends AnyVal

  /** Error codes: server_error, rate_limit_exceeded, or invalid_prompt. */
  sealed trait LastErrorCode extends SnakeCaseEnumValue

  object LastErrorCode {
    case object ServerError extends LastErrorCode
    case object RateLimitExceeded extends LastErrorCode
    case object InvalidPrompt extends LastErrorCode
  }
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
package io.cequence.openaiscala.domain.response

import io.cequence.openaiscala.domain.Run

/** Envelope wrapping the [[Run]] returned by the run endpoints. */
case class RunResponse(
  run: Run
)
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
package io.cequence.openaiscala.domain.settings

import io.cequence.openaiscala.domain.response.ResponseFormat

/**
 * Settings for creating an assistant run.
 *
 * Fix: removed the stray space before the parameter list and the trailing comma
 * after the last parameter (needs Scala 2.12.2+ and was inconsistent with the
 * rest of the codebase).
 *
 * @param model               model id to run with
 * @param metadata            arbitrary key/value metadata attached to the run
 * @param temperature         optional sampling temperature
 * @param topP                optional nucleus-sampling probability mass
 * @param maxPromptTokens     optional cap on prompt tokens
 * @param maxCompletionTokens optional cap on completion tokens
 * @param responseFormat      optional response format override
 */
case class CreateRunSettings(
  model: String,
  metadata: Map[String, String] = Map.empty,
  temperature: Option[Double] = None,
  topP: Option[Double] = None,
  maxPromptTokens: Option[Int] = None,
  maxCompletionTokens: Option[Int] = None,
  responseFormat: Option[ResponseFormat] = None
)

openai-core/src/main/scala/io/cequence/openaiscala/service/OpenAIServiceConsts.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,11 @@ trait OpenAIServiceConsts {
2222
max_tokens = Some(1000)
2323
)
2424

25+
// Default settings for run creation; mirrors the other per-endpoint defaults
// declared in this trait (all other CreateRunSettings fields keep their defaults).
val CreateRun = CreateRunSettings(
  model = ModelId.gpt_4_1106_preview,
  maxPromptTokens = Some(1000)
)
29+
2530
val CreateChatCompletion = CreateChatCompletionSettings(
2631
model = ModelId.gpt_3_5_turbo_1106,
2732
max_tokens = Some(1000)

0 commit comments

Comments
 (0)