Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 120 additions & 21 deletions core/http/endpoints/openai/realtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,36 @@ func registerRealtime(application *application.Application, model string) func(c

case types.ConversationItemCreateEvent:
xlog.Debug("recv", "message", string(msg))
sendNotImplemented(c, "conversation.item.create")
// Add the item to the conversation
item := e.Item
// Ensure IDs are present
if item.User != nil && item.User.ID == "" {
item.User.ID = generateItemID()
}
if item.Assistant != nil && item.Assistant.ID == "" {
item.Assistant.ID = generateItemID()
}
if item.System != nil && item.System.ID == "" {
item.System.ID = generateItemID()
}
if item.FunctionCall != nil && item.FunctionCall.ID == "" {
item.FunctionCall.ID = generateItemID()
}
if item.FunctionCallOutput != nil && item.FunctionCallOutput.ID == "" {
item.FunctionCallOutput.ID = generateItemID()
}

conversation.Lock.Lock()
conversation.Items = append(conversation.Items, &item)
conversation.Lock.Unlock()

sendEvent(c, types.ConversationItemAddedEvent{
ServerEventBase: types.ServerEventBase{
EventID: e.EventID,
},
PreviousItemID: e.PreviousItemID,
Item: item,
})

case types.ConversationItemDeleteEvent:
sendError(c, "not_implemented", "Deleting items not implemented", "", "event_TODO")
Expand Down Expand Up @@ -444,7 +473,34 @@ func registerRealtime(application *application.Application, model string) func(c

case types.ResponseCreateEvent:
xlog.Debug("recv", "message", string(msg))
sendNotImplemented(c, "response.create")

// Handle optional items to add to context
if len(e.Response.Input) > 0 {
conversation.Lock.Lock()
for _, item := range e.Response.Input {
// Ensure IDs are present
if item.User != nil && item.User.ID == "" {
item.User.ID = generateItemID()
}
if item.Assistant != nil && item.Assistant.ID == "" {
item.Assistant.ID = generateItemID()
}
if item.System != nil && item.System.ID == "" {
item.System.ID = generateItemID()
}
if item.FunctionCall != nil && item.FunctionCall.ID == "" {
item.FunctionCall.ID = generateItemID()
}
if item.FunctionCallOutput != nil && item.FunctionCallOutput.ID == "" {
item.FunctionCallOutput.ID = generateItemID()
}

conversation.Items = append(conversation.Items, &item)
}
conversation.Lock.Unlock()
}

go triggerResponse(session, conversation, c, &e.Response)

case types.ResponseCancelEvent:
xlog.Debug("recv", "message", string(msg))
Expand Down Expand Up @@ -825,8 +881,7 @@ func runVAD(ctx context.Context, session *Session, adata []int16) ([]schema.VADS
func generateResponse(session *Session, utt []byte, transcript string, conv *Conversation, c *LockedWebsocket, mt int) {
xlog.Debug("Generating realtime response...")

config := session.ModelInterface.PredictConfig()

// Create user message item
item := types.MessageItemUnion{
User: &types.MessageItemUser{
ID: generateItemID(),
Expand All @@ -848,33 +903,73 @@ func generateResponse(session *Session, utt []byte, transcript string, conv *Con
Item: item,
})

triggerResponse(session, conv, c, nil)
}

func triggerResponse(session *Session, conv *Conversation, c *LockedWebsocket, overrides *types.ResponseCreateParams) {
config := session.ModelInterface.PredictConfig()

// Default values
tools := session.Tools
toolChoice := session.ToolChoice
instructions := session.Instructions
// Overrides
if overrides != nil {
if overrides.Tools != nil {
tools = overrides.Tools
}
if overrides.ToolChoice != nil {
toolChoice = overrides.ToolChoice
}
if overrides.Instructions != "" {
instructions = overrides.Instructions
}
}

var conversationHistory schema.Messages
conversationHistory = append(conversationHistory, schema.Message{
Role: string(types.MessageRoleSystem),
StringContent: session.Instructions,
Content: session.Instructions,
StringContent: instructions,
Content: instructions,
})

imgIndex := 0
conv.Lock.Lock()
for _, item := range conv.Items {
if item.User != nil {
msg := schema.Message{
Role: string(types.MessageRoleUser),
}
textContent := ""
nrOfImgsInMessage := 0
for _, content := range item.User.Content {
switch content.Type {
case types.MessageContentTypeInputText:
conversationHistory = append(conversationHistory, schema.Message{
Role: string(types.MessageRoleUser),
StringContent: content.Text,
Content: content.Text,
})
textContent += content.Text
case types.MessageContentTypeInputAudio:
conversationHistory = append(conversationHistory, schema.Message{
Role: string(types.MessageRoleUser),
StringContent: content.Transcript,
Content: content.Transcript,
StringAudios: []string{content.Audio},
})
textContent += content.Transcript
case types.MessageContentTypeInputImage:
msg.StringImages = append(msg.StringImages, content.ImageURL)
imgIndex++
nrOfImgsInMessage++
}
}
if nrOfImgsInMessage > 0 {
templated, err := templates.TemplateMultiModal(config.TemplateConfig.Multimodal, templates.MultiModalOptions{
TotalImages: imgIndex,
ImagesInMessage: nrOfImgsInMessage,
}, textContent)
if err != nil {
xlog.Warn("Failed to apply multimodal template", "error", err)
templated = textContent
}
msg.StringContent = templated
msg.Content = templated
} else {
msg.StringContent = textContent
msg.Content = textContent
}
conversationHistory = append(conversationHistory, msg)
} else if item.Assistant != nil {
for _, content := range item.Assistant.Content {
switch content.Type {
Expand Down Expand Up @@ -905,6 +1000,11 @@ func generateResponse(session *Session, utt []byte, transcript string, conv *Con
}
conv.Lock.Unlock()

var images []string
for _, m := range conversationHistory {
images = append(images, m.StringImages...)
}

responseID := generateUniqueID()
sendEvent(c, types.ResponseCreatedEvent{
ServerEventBase: types.ServerEventBase{},
Expand All @@ -915,15 +1015,15 @@ func generateResponse(session *Session, utt []byte, transcript string, conv *Con
},
})

predFunc, err := session.ModelInterface.Predict(context.TODO(), conversationHistory, nil, nil, nil, nil, session.Tools, session.ToolChoice, nil, nil, nil)
predFunc, err := session.ModelInterface.Predict(context.TODO(), conversationHistory, images, nil, nil, nil, tools, toolChoice, nil, nil, nil)
if err != nil {
sendError(c, "inference_failed", fmt.Sprintf("backend error: %v", err), "", item.Assistant.ID)
sendError(c, "inference_failed", fmt.Sprintf("backend error: %v", err), "", "") // item.Assistant.ID is unknown here
return
}

pred, err := predFunc()
if err != nil {
sendError(c, "prediction_failed", fmt.Sprintf("backend error: %v", err), "", item.Assistant.ID)
sendError(c, "prediction_failed", fmt.Sprintf("backend error: %v", err), "", "")
return
}

Expand Down Expand Up @@ -1171,7 +1271,6 @@ func generateResponse(session *Session, utt []byte, transcript string, conv *Con
Status: types.ResponseStatusCompleted,
},
})

}

// Helper functions to generate unique IDs
Expand Down
1 change: 1 addition & 0 deletions core/http/endpoints/openai/types/message_item.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ const (
MessageContentTypeTranscript MessageContentType = "transcript"
MessageContentTypeInputText MessageContentType = "input_text"
MessageContentTypeInputAudio MessageContentType = "input_audio"
MessageContentTypeInputImage MessageContentType = "input_image"
MessageContentTypeOutputText MessageContentType = "output_text"
MessageContentTypeOutputAudio MessageContentType = "output_audio"
)
Expand Down
Loading