Skip to content

Commit 95a4b7b

Browse files
committed
createRun, retrieveRun, listRunSteps
1 parent bc70ece commit 95a4b7b

File tree

13 files changed

+344
-26
lines changed

13 files changed

+344
-26
lines changed

openai-client/src/main/scala/io/cequence/openaiscala/JsonFormats.scala

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import io.cequence.openaiscala.domain.response.ResponseFormat.{
1818
}
1919
import io.cequence.openaiscala.domain.response._
2020
import io.cequence.openaiscala.domain.{ThreadMessageFile, _}
21+
import StepDetail.{MessageCreation, ToolCalls}
2122
import io.cequence.wsclient.JsonUtil
2223
import io.cequence.wsclient.JsonUtil.{enumFormat, snakeEnumFormat}
2324
import io.cequence.wsclient.domain.EnumValue
@@ -822,6 +823,56 @@ object JsonFormats {
822823
implicit lazy val SubmitToolOutputsFormat: Format[SubmitToolOutputs] =
823824
Json.format[SubmitToolOutputs]
824825

825-
implicit lazy val runResponseFormat: Format[RunResponse] =
826-
Json.format[RunResponse]
826+
implicit lazy val runResponseFormat: Format[RunResponse] = Json.format[RunResponse]
827+
828+
implicit lazy val runStepLastErrorFormat: Format[RunStep.LastError] = {
829+
import RunStep.LastError._
830+
snakeEnumFormat[RunStep.LastError](ServerError, RateLimitExceeded)
831+
}
832+
implicit lazy val runStepFormat: Format[RunStep] = {
833+
implicit val jsonConfig: JsonConfiguration = JsonConfiguration(SnakeCase)
834+
Json.format[RunStep]
835+
}
836+
837+
implicit lazy val genericLastError: Format[GenericLastError[RunStep.LastError]] = {
838+
implicit lazy val genericLastErrorFormat: Format[GenericLastError[RunStep.LastError]] =
839+
Json.format[GenericLastError[RunStep.LastError]]
840+
genericLastErrorFormat
841+
}
842+
843+
implicit val messageCreationReads: Reads[MessageCreation] =
844+
(__ \ "message_creation" \ "message_id").read[String].map(MessageCreation)
845+
implicit val messageCreationWrites: Writes[MessageCreation] = Writes { messageCreation =>
846+
Json.obj("message_creation" -> Json.obj("message_id" -> messageCreation.messageId))
847+
}
848+
implicit val messageCreationFormat: Format[MessageCreation] =
849+
Format(messageCreationReads, messageCreationWrites)
850+
851+
implicit val toolCallsFormat: Format[ToolCalls] = Json.format[ToolCalls]
852+
853+
implicit val stepDetailFormat: Format[StepDetail] = {
854+
implicit val jsonConfig: JsonConfiguration = JsonConfiguration(SnakeCase)
855+
856+
implicit val stepDetailReads: Reads[StepDetail] = Reads[StepDetail] { json =>
857+
(json \ "type").as[String] match {
858+
case "message_creation" => messageCreationFormat.reads(json)
859+
case "tool_calls" => toolCallsFormat.reads(json)
860+
}
861+
}
862+
863+
implicit val stepDetailWrites: Writes[StepDetail] = Writes[StepDetail] {
864+
case mc: MessageCreation =>
865+
messageCreationFormat.writes(mc).as[JsObject] + ("type" -> JsString("MessageCreation"))
866+
case tc: ToolCalls =>
867+
toolCallsFormat.writes(tc).as[JsObject] + ("type" -> JsString("ToolCalls"))
868+
}
869+
870+
Format(stepDetailReads, stepDetailWrites)
871+
}
872+
873+
// implicit def genericLastErrorFormat[T <: EnumValue](
874+
// implicit format: Format[T]
875+
// ): Format[GenericLastError[T]] = {
876+
// snakeEnumFormat[GenericLastError[T]]
877+
// }
827878
}

openai-client/src/main/scala/io/cequence/openaiscala/service/impl/EndPoint.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,4 +104,7 @@ object Param {
104104
case object chunking_strategy extends Param
105105
case object filter extends Param
106106
case object max_prompt_tokens extends Param
107+
case object `object` extends Param
108+
case object assistant_id extends Param
109+
case object thread_id extends Param
107110
}

openai-client/src/main/scala/io/cequence/openaiscala/service/impl/OpenAIServiceImpl.scala

Lines changed: 60 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -65,28 +65,67 @@ private[service] trait OpenAIServiceImpl
6565
)
6666
}
6767

68-
def createRun(
69-
run: Run,
68+
override def createRun(
69+
threadId: String,
70+
assistantId: AssistantId,
71+
instructions: Option[String],
7072
tools: Seq[ToolSpec],
7173
responseToolChoice: Option[String] = None,
7274
settings: CreateRunSettings = DefaultSettings.CreateRun,
7375
stream: Boolean
74-
): Future[RunResponse] = {
76+
): Future[Run] = {
7577
val coreParams = createBodyParamsForRun(settings, stream)
7678

77-
val toolParam = toolParams(tools)
79+
val toolParam = toolParams(tools, responseToolChoice)
80+
81+
val runParams = jsonBodyParams(
82+
// Param.`object` -> Some("thread.run"),
83+
// Param.thread_id -> Some(threadId),
84+
Param.assistant_id -> Some(assistantId.id),
85+
Param.instructions -> Some(instructions)
86+
)
87+
88+
(coreParams ++ toolParam ++ runParams).foreach((x: (Param, Option[JsValue])) =>
89+
println(x._1.toString + " -> " + x._2.toString)
90+
)
7891

7992
execPOST(
80-
EndPoint.runs,
81-
bodyParams = coreParams ++ toolParam
93+
EndPoint.threads,
94+
Some(s"$threadId/runs"),
95+
bodyParams = coreParams ++ toolParam ++ runParams
8296
).map(
83-
_.asSafe[RunResponse]
97+
_.asSafe[Run]
8498
)
8599
}
86100

101+
override def retrieveRun(
102+
threadId: String,
103+
runId: String
104+
): Future[Option[Run]] =
105+
execGETWithStatus(
106+
EndPoint.threads,
107+
Some(s"$threadId/runs/$runId")
108+
).map { response =>
109+
handleNotFoundAndError(response).map(_.asSafe[Run])
110+
}
111+
112+
override def listRunSteps(
113+
threadId: String,
114+
runId: String,
115+
pagination: Pagination,
116+
order: Option[SortOrder]
117+
): Future[Seq[RunStep]] =
118+
execGET(
119+
EndPoint.threads,
120+
Some(s"$threadId/runs/$runId/steps"),
121+
params = paginationParams(pagination) :+ Param.order -> order
122+
).map { response =>
123+
readAttribute(response, "data").asSafeArray[RunStep]
124+
}
125+
87126
private def toolParams(
88127
tools: Seq[ToolSpec],
89-
responseToolChoice: Option[String] = None
128+
responseToolChoice: Option[String]
90129
): Seq[(Param, Option[JsValue])] = {
91130
val toolJsons = tools.map { case tool: FunctionSpec =>
92131
Map("type" -> "function", "function" -> Json.toJson(tool))
@@ -594,21 +633,24 @@ private[service] trait OpenAIServiceImpl
594633
messages: Seq[ThreadMessage],
595634
toolResources: Seq[AssistantToolResource] = Nil,
596635
metadata: Map[String, String]
597-
): Future[Thread] =
636+
): Future[Thread] = {
637+
val params = jsonBodyParams(
638+
Param.messages -> (
639+
if (messages.nonEmpty)
640+
Some(messages.map(Json.toJson(_)(threadMessageFormat)))
641+
else None
642+
),
643+
Param.metadata -> (if (metadata.nonEmpty) Some(metadata) else None),
644+
Param.tool_resources -> (if (toolResources.nonEmpty) Some(toolResources) else None)
645+
)
646+
params.foreach(println)
598647
execPOST(
599648
EndPoint.threads,
600-
bodyParams = jsonBodyParams(
601-
Param.messages -> (
602-
if (messages.nonEmpty)
603-
Some(messages.map(Json.toJson(_)(threadMessageFormat)))
604-
else None
605-
),
606-
Param.metadata -> (if (metadata.nonEmpty) Some(metadata) else None),
607-
Param.tool_resources -> (if (toolResources.nonEmpty) Some(toolResources) else None)
608-
)
649+
bodyParams = params
609650
).map(
610651
_.asSafe[Thread]
611652
)
653+
}
612654

613655
override def retrieveThread(
614656
threadId: String

openai-core/src/main/scala/io/cequence/openaiscala/domain/Run.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ case class Run(
4141
id: String,
4242
`object`: String,
4343
created_at: Date,
44-
thread_id: String,
45-
assistant_id: String,
44+
thread_id: String, // path
45+
assistant_id: String, // param
4646
status: RunStatus,
4747
required_action: Option[RequiredAction],
4848
last_error: Option[Run.LastErrorCode],
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
package io.cequence.openaiscala.domain
2+
3+
import io.cequence.openaiscala.domain.response.UsageInfo
4+
import io.cequence.wsclient.domain.EnumValue
5+
6+
import java.util.Date
7+
8+
case class RunStep(
9+
id: String,
10+
`object`: String,
11+
createdAt: Date,
12+
assistantId: String,
13+
threadId: String,
14+
runId: String,
15+
`type`: String,
16+
status: String,
17+
stepDetails: Option[StepDetail],
18+
lastError: Option[
19+
GenericLastError[RunStep.LastError]
20+
],
21+
expiredAt: Option[Date],
22+
cancelledAt: Option[Date],
23+
failedAt: Option[Date],
24+
completedAt: Option[Date],
25+
metadata: Option[Map[String, String]],
26+
usage: Option[UsageInfo]
27+
)
28+
29+
object RunStep {
30+
31+
sealed trait LastError extends EnumValue
32+
object LastError {
33+
case object ServerError extends LastError
34+
case object RateLimitExceeded extends LastError
35+
}
36+
37+
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
package io.cequence.openaiscala.domain
2+
3+
sealed trait StepDetail
4+
5+
object StepDetail {
6+
case class MessageCreation(messageId: String) extends StepDetail
7+
case class ToolCalls(messages: BaseMessage) extends StepDetail
8+
}

openai-core/src/main/scala/io/cequence/openaiscala/domain/Thread.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,15 @@ case class ThreadMessage(
2828

2929
// The role of the entity that is creating the message.
3030
// Currently, only "user" is supported.
31-
role: ChatRole = ChatRole.User,
31+
role: ChatRole = ChatRole.User
3232

3333
// A list of File IDs that the message should use.
3434
// There can be a maximum of 10 files attached to a message.
3535
// Useful for tools like retrieval and code_interpreter that can access and use files.
36-
file_ids: Seq[String] = Nil,
36+
// file_ids: Seq[String] = Nil,
3737

3838
// Set of 16 key-value pairs that can be attached to an object.
3939
// This can be useful for storing additional information about the object in a structured format.
4040
// Keys can be a maximum of 64 characters long and values can be a maximum of 512 characters long.
41-
metadata: Map[String, String] = Map()
41+
// metadata: Map[String, String] = Map()
4242
)

openai-core/src/main/scala/io/cequence/openaiscala/service/OpenAIService.scala

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import io.cequence.openaiscala.domain.Batch._
66
import io.cequence.openaiscala.domain.response._
77
import io.cequence.openaiscala.domain.settings._
88
import io.cequence.openaiscala.domain.{
9+
AssistantId,
910
AssistantTool,
1011
AssistantToolResource,
1112
Attachment,
@@ -14,6 +15,8 @@ import io.cequence.openaiscala.domain.{
1415
ChunkingStrategy,
1516
FunctionSpec,
1617
Pagination,
18+
Run,
19+
RunStep,
1720
SortOrder,
1821
Thread,
1922
ThreadFullMessage,
@@ -99,6 +102,28 @@ trait OpenAIService extends OpenAICoreService {
99102
settings: CreateChatCompletionSettings = DefaultSettings.CreateChatFunCompletion
100103
): Future[ChatFunCompletionResponse]
101104

105+
def createRun(
106+
threadId: String,
107+
assistantId: AssistantId,
108+
instructions: Option[String],
109+
tools: Seq[ToolSpec],
110+
responseToolChoice: Option[String] = None,
111+
settings: CreateRunSettings = DefaultSettings.CreateRun,
112+
stream: Boolean
113+
): Future[Run]
114+
115+
def retrieveRun(
116+
threadId: String,
117+
runId: String
118+
): Future[Option[Run]]
119+
120+
def listRunSteps(
121+
threadId: String,
122+
runId: String,
123+
pagination: Pagination = Pagination.default,
124+
order: Option[SortOrder] = None
125+
): Future[Seq[RunStep]]
126+
102127
/**
103128
* Creates a model response for the given chat conversation expecting a tool call.
104129
*

openai-core/src/main/scala/io/cequence/openaiscala/service/adapter/OpenAIServiceWrapper.scala

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,42 @@ trait OpenAIServiceWrapper
5252
)
5353
)
5454

55+
def createRun(
56+
threadId: String,
57+
assistantId: AssistantId,
58+
instructions: Option[String],
59+
tools: Seq[ToolSpec],
60+
responseToolChoice: Option[String] = None,
61+
settings: CreateRunSettings = DefaultSettings.CreateRun,
62+
stream: Boolean
63+
): Future[Run] = wrap(
64+
_.createRun(
65+
threadId,
66+
assistantId,
67+
instructions,
68+
tools,
69+
responseToolChoice,
70+
settings,
71+
stream
72+
)
73+
)
74+
75+
def retrieveRun(
76+
threadId: String,
77+
runId: String
78+
): Future[Option[Run]] = wrap(
79+
_.retrieveRun(threadId, runId)
80+
)
81+
82+
override def listRunSteps(
83+
threadId: String,
84+
runId: String,
85+
pagination: Pagination,
86+
order: Option[SortOrder]
87+
): Future[Seq[RunStep]] = wrap(
88+
_.listRunSteps(threadId, runId, pagination, order)
89+
)
90+
5591
override def createChatToolCompletion(
5692
messages: Seq[BaseMessage],
5793
tools: Seq[ToolSpec],

0 commit comments

Comments
 (0)