mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-12-18 18:47:11 +00:00
Initial commit of a functional but not yet elegant prototype for this concept
This commit is contained in:
parent
9368699b2c
commit
d91935c880
15 changed files with 455 additions and 150 deletions
|
@ -1,3 +1,4 @@
|
||||||
|
import { AttachedFileText } from "../components/chatInputArea/chatInputArea";
|
||||||
import {
|
import {
|
||||||
CodeContext,
|
CodeContext,
|
||||||
Context,
|
Context,
|
||||||
|
@ -16,6 +17,12 @@ export interface MessageMetadata {
|
||||||
turnId: string;
|
turnId: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface GeneratedAssetsData {
|
||||||
|
images: string[];
|
||||||
|
excalidrawDiagram: string;
|
||||||
|
files: AttachedFileText[];
|
||||||
|
}
|
||||||
|
|
||||||
export interface ResponseWithIntent {
|
export interface ResponseWithIntent {
|
||||||
intentType: string;
|
intentType: string;
|
||||||
response: string;
|
response: string;
|
||||||
|
@ -84,6 +91,8 @@ export function processMessageChunk(
|
||||||
|
|
||||||
if (!currentMessage || !chunk || !chunk.type) return { context, onlineContext, codeContext };
|
if (!currentMessage || !chunk || !chunk.type) return { context, onlineContext, codeContext };
|
||||||
|
|
||||||
|
console.log(`chunk type: ${chunk.type}`);
|
||||||
|
|
||||||
if (chunk.type === "status") {
|
if (chunk.type === "status") {
|
||||||
console.log(`status: ${chunk.data}`);
|
console.log(`status: ${chunk.data}`);
|
||||||
const statusMessage = chunk.data as string;
|
const statusMessage = chunk.data as string;
|
||||||
|
@ -98,6 +107,20 @@ export function processMessageChunk(
|
||||||
} else if (chunk.type === "metadata") {
|
} else if (chunk.type === "metadata") {
|
||||||
const messageMetadata = chunk.data as MessageMetadata;
|
const messageMetadata = chunk.data as MessageMetadata;
|
||||||
currentMessage.turnId = messageMetadata.turnId;
|
currentMessage.turnId = messageMetadata.turnId;
|
||||||
|
} else if (chunk.type === "generated_assets") {
|
||||||
|
const generatedAssets = chunk.data as GeneratedAssetsData;
|
||||||
|
|
||||||
|
if (generatedAssets.images) {
|
||||||
|
currentMessage.generatedImages = generatedAssets.images;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (generatedAssets.excalidrawDiagram) {
|
||||||
|
currentMessage.generatedExcalidrawDiagram = generatedAssets.excalidrawDiagram;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (generatedAssets.files) {
|
||||||
|
currentMessage.generatedFiles = generatedAssets.files;
|
||||||
|
}
|
||||||
} else if (chunk.type === "message") {
|
} else if (chunk.type === "message") {
|
||||||
const chunkData = chunk.data;
|
const chunkData = chunk.data;
|
||||||
// Here, handle if the response is a JSON response with an image, but the intentType is excalidraw
|
// Here, handle if the response is a JSON response with an image, but the intentType is excalidraw
|
||||||
|
|
|
@ -54,6 +54,12 @@ function TrainOfThoughtComponent(props: TrainOfThoughtComponentProps) {
|
||||||
const lastIndex = props.trainOfThought.length - 1;
|
const lastIndex = props.trainOfThought.length - 1;
|
||||||
const [collapsed, setCollapsed] = useState(props.completed);
|
const [collapsed, setCollapsed] = useState(props.completed);
|
||||||
|
|
||||||
|
// useEffect(() => {
|
||||||
|
// if (props.completed) {
|
||||||
|
// setCollapsed(true);
|
||||||
|
// }
|
||||||
|
// }), [props.completed];
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div
|
<div
|
||||||
className={`${!collapsed ? styles.trainOfThought + " shadow-sm" : ""}`}
|
className={`${!collapsed ? styles.trainOfThought + " shadow-sm" : ""}`}
|
||||||
|
@ -410,6 +416,9 @@ export default function ChatHistory(props: ChatHistoryProps) {
|
||||||
"inferred-queries": message.inferredQueries || [],
|
"inferred-queries": message.inferredQueries || [],
|
||||||
},
|
},
|
||||||
conversationId: props.conversationId,
|
conversationId: props.conversationId,
|
||||||
|
images: message.generatedImages,
|
||||||
|
queryFiles: message.generatedFiles,
|
||||||
|
excalidrawDiagram: message.generatedExcalidrawDiagram,
|
||||||
turnId: messageTurnId,
|
turnId: messageTurnId,
|
||||||
}}
|
}}
|
||||||
conversationId={props.conversationId}
|
conversationId={props.conversationId}
|
||||||
|
|
|
@ -163,6 +163,7 @@ export interface SingleChatMessage {
|
||||||
conversationId: string;
|
conversationId: string;
|
||||||
turnId?: string;
|
turnId?: string;
|
||||||
queryFiles?: AttachedFileText[];
|
queryFiles?: AttachedFileText[];
|
||||||
|
excalidrawDiagram?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface StreamMessage {
|
export interface StreamMessage {
|
||||||
|
@ -180,6 +181,10 @@ export interface StreamMessage {
|
||||||
inferredQueries?: string[];
|
inferredQueries?: string[];
|
||||||
turnId?: string;
|
turnId?: string;
|
||||||
queryFiles?: AttachedFileText[];
|
queryFiles?: AttachedFileText[];
|
||||||
|
excalidrawDiagram?: string;
|
||||||
|
generatedFiles?: AttachedFileText[];
|
||||||
|
generatedImages?: string[];
|
||||||
|
generatedExcalidrawDiagram?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface ChatHistoryData {
|
export interface ChatHistoryData {
|
||||||
|
@ -264,6 +269,9 @@ interface ChatMessageProps {
|
||||||
onDeleteMessage: (turnId?: string) => void;
|
onDeleteMessage: (turnId?: string) => void;
|
||||||
conversationId: string;
|
conversationId: string;
|
||||||
turnId?: string;
|
turnId?: string;
|
||||||
|
generatedImage?: string;
|
||||||
|
excalidrawDiagram?: string;
|
||||||
|
generatedFiles?: AttachedFileText[];
|
||||||
}
|
}
|
||||||
|
|
||||||
interface TrainOfThoughtProps {
|
interface TrainOfThoughtProps {
|
||||||
|
@ -394,6 +402,10 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
|
||||||
setExcalidrawData(props.chatMessage.message);
|
setExcalidrawData(props.chatMessage.message);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (props.chatMessage.excalidrawDiagram) {
|
||||||
|
setExcalidrawData(props.chatMessage.excalidrawDiagram);
|
||||||
|
}
|
||||||
|
|
||||||
// Replace LaTeX delimiters with placeholders
|
// Replace LaTeX delimiters with placeholders
|
||||||
message = message
|
message = message
|
||||||
.replace(/\\\(/g, "LEFTPAREN")
|
.replace(/\\\(/g, "LEFTPAREN")
|
||||||
|
|
|
@ -1,7 +1,9 @@
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
from random import choice
|
from random import choice
|
||||||
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
from django.contrib.auth.models import AbstractUser
|
from django.contrib.auth.models import AbstractUser
|
||||||
from django.contrib.postgres.fields import ArrayField
|
from django.contrib.postgres.fields import ArrayField
|
||||||
|
@ -11,9 +13,109 @@ from django.db.models.signals import pre_save
|
||||||
from django.dispatch import receiver
|
from django.dispatch import receiver
|
||||||
from pgvector.django import VectorField
|
from pgvector.django import VectorField
|
||||||
from phonenumber_field.modelfields import PhoneNumberField
|
from phonenumber_field.modelfields import PhoneNumberField
|
||||||
|
from pydantic import BaseModel as PydanticBaseModel
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class BaseModel(models.Model):
|
# Pydantic models for type Chat Message validation
|
||||||
|
class Context(PydanticBaseModel):
|
||||||
|
compiled: str
|
||||||
|
file: str
|
||||||
|
|
||||||
|
|
||||||
|
class CodeContextFile(PydanticBaseModel):
|
||||||
|
filename: str
|
||||||
|
b64_data: str
|
||||||
|
|
||||||
|
|
||||||
|
class CodeContextResult(PydanticBaseModel):
|
||||||
|
success: bool
|
||||||
|
output_files: List[CodeContextFile]
|
||||||
|
std_out: str
|
||||||
|
std_err: str
|
||||||
|
code_runtime: int
|
||||||
|
|
||||||
|
|
||||||
|
class CodeContextData(PydanticBaseModel):
|
||||||
|
code: str
|
||||||
|
result: CodeContextResult
|
||||||
|
|
||||||
|
|
||||||
|
class WebPage(PydanticBaseModel):
|
||||||
|
link: str
|
||||||
|
query: Optional[str] = None
|
||||||
|
snippet: str
|
||||||
|
|
||||||
|
|
||||||
|
class AnswerBox(PydanticBaseModel):
|
||||||
|
link: str
|
||||||
|
snippet: str
|
||||||
|
title: str
|
||||||
|
snippetHighlighted: List[str]
|
||||||
|
|
||||||
|
|
||||||
|
class PeopleAlsoAsk(PydanticBaseModel):
|
||||||
|
link: str
|
||||||
|
question: str
|
||||||
|
snippet: str
|
||||||
|
title: str
|
||||||
|
|
||||||
|
|
||||||
|
class KnowledgeGraph(PydanticBaseModel):
|
||||||
|
attributes: Dict[str, str]
|
||||||
|
description: str
|
||||||
|
descriptionLink: str
|
||||||
|
descriptionSource: str
|
||||||
|
imageUrl: str
|
||||||
|
title: str
|
||||||
|
type: str
|
||||||
|
|
||||||
|
|
||||||
|
class OrganicContext(PydanticBaseModel):
|
||||||
|
snippet: str
|
||||||
|
title: str
|
||||||
|
link: str
|
||||||
|
|
||||||
|
|
||||||
|
class OnlineContext(PydanticBaseModel):
|
||||||
|
webpages: Optional[Union[WebPage, List[WebPage]]] = None
|
||||||
|
answerBox: Optional[AnswerBox] = None
|
||||||
|
peopleAlsoAsk: Optional[List[PeopleAlsoAsk]] = None
|
||||||
|
knowledgeGraph: Optional[KnowledgeGraph] = None
|
||||||
|
organicContext: Optional[List[OrganicContext]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class Intent(PydanticBaseModel):
|
||||||
|
type: str
|
||||||
|
query: str
|
||||||
|
memory_type: str = Field(alias="memory-type")
|
||||||
|
inferred_queries: Optional[List[str]] = Field(default=None, alias="inferred-queries")
|
||||||
|
|
||||||
|
|
||||||
|
class TrainOfThought(PydanticBaseModel):
|
||||||
|
type: str
|
||||||
|
data: str
|
||||||
|
|
||||||
|
|
||||||
|
class ChatMessage(PydanticBaseModel):
|
||||||
|
message: str
|
||||||
|
trainOfThought: List[TrainOfThought] = []
|
||||||
|
context: List[Context] = []
|
||||||
|
onlineContext: Dict[str, OnlineContext] = {}
|
||||||
|
codeContext: Dict[str, CodeContextData] = {}
|
||||||
|
created: str
|
||||||
|
images: Optional[List[str]] = None
|
||||||
|
queryFiles: Optional[List[Dict]] = None
|
||||||
|
excalidrawDiagram: Optional[str] = None
|
||||||
|
by: str
|
||||||
|
turnId: Optional[str]
|
||||||
|
intent: Optional[Intent] = None
|
||||||
|
automationId: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class DbBaseModel(models.Model):
|
||||||
created_at = models.DateTimeField(auto_now_add=True)
|
created_at = models.DateTimeField(auto_now_add=True)
|
||||||
updated_at = models.DateTimeField(auto_now=True)
|
updated_at = models.DateTimeField(auto_now=True)
|
||||||
|
|
||||||
|
@ -21,7 +123,7 @@ class BaseModel(models.Model):
|
||||||
abstract = True
|
abstract = True
|
||||||
|
|
||||||
|
|
||||||
class ClientApplication(BaseModel):
|
class ClientApplication(DbBaseModel):
|
||||||
name = models.CharField(max_length=200)
|
name = models.CharField(max_length=200)
|
||||||
client_id = models.CharField(max_length=200)
|
client_id = models.CharField(max_length=200)
|
||||||
client_secret = models.CharField(max_length=200)
|
client_secret = models.CharField(max_length=200)
|
||||||
|
@ -67,7 +169,7 @@ class KhojApiUser(models.Model):
|
||||||
accessed_at = models.DateTimeField(null=True, default=None)
|
accessed_at = models.DateTimeField(null=True, default=None)
|
||||||
|
|
||||||
|
|
||||||
class Subscription(BaseModel):
|
class Subscription(DbBaseModel):
|
||||||
class Type(models.TextChoices):
|
class Type(models.TextChoices):
|
||||||
TRIAL = "trial"
|
TRIAL = "trial"
|
||||||
STANDARD = "standard"
|
STANDARD = "standard"
|
||||||
|
@ -79,13 +181,13 @@ class Subscription(BaseModel):
|
||||||
enabled_trial_at = models.DateTimeField(null=True, default=None, blank=True)
|
enabled_trial_at = models.DateTimeField(null=True, default=None, blank=True)
|
||||||
|
|
||||||
|
|
||||||
class OpenAIProcessorConversationConfig(BaseModel):
|
class OpenAIProcessorConversationConfig(DbBaseModel):
|
||||||
name = models.CharField(max_length=200)
|
name = models.CharField(max_length=200)
|
||||||
api_key = models.CharField(max_length=200)
|
api_key = models.CharField(max_length=200)
|
||||||
api_base_url = models.URLField(max_length=200, default=None, blank=True, null=True)
|
api_base_url = models.URLField(max_length=200, default=None, blank=True, null=True)
|
||||||
|
|
||||||
|
|
||||||
class ChatModelOptions(BaseModel):
|
class ChatModelOptions(DbBaseModel):
|
||||||
class ModelType(models.TextChoices):
|
class ModelType(models.TextChoices):
|
||||||
OPENAI = "openai"
|
OPENAI = "openai"
|
||||||
OFFLINE = "offline"
|
OFFLINE = "offline"
|
||||||
|
@ -103,12 +205,12 @@ class ChatModelOptions(BaseModel):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class VoiceModelOption(BaseModel):
|
class VoiceModelOption(DbBaseModel):
|
||||||
model_id = models.CharField(max_length=200)
|
model_id = models.CharField(max_length=200)
|
||||||
name = models.CharField(max_length=200)
|
name = models.CharField(max_length=200)
|
||||||
|
|
||||||
|
|
||||||
class Agent(BaseModel):
|
class Agent(DbBaseModel):
|
||||||
class StyleColorTypes(models.TextChoices):
|
class StyleColorTypes(models.TextChoices):
|
||||||
BLUE = "blue"
|
BLUE = "blue"
|
||||||
GREEN = "green"
|
GREEN = "green"
|
||||||
|
@ -208,7 +310,7 @@ class Agent(BaseModel):
|
||||||
super().save(*args, **kwargs)
|
super().save(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class ProcessLock(BaseModel):
|
class ProcessLock(DbBaseModel):
|
||||||
class Operation(models.TextChoices):
|
class Operation(models.TextChoices):
|
||||||
INDEX_CONTENT = "index_content"
|
INDEX_CONTENT = "index_content"
|
||||||
SCHEDULED_JOB = "scheduled_job"
|
SCHEDULED_JOB = "scheduled_job"
|
||||||
|
@ -231,24 +333,24 @@ def verify_agent(sender, instance, **kwargs):
|
||||||
raise ValidationError(f"A private Agent with the name {instance.name} already exists.")
|
raise ValidationError(f"A private Agent with the name {instance.name} already exists.")
|
||||||
|
|
||||||
|
|
||||||
class NotionConfig(BaseModel):
|
class NotionConfig(DbBaseModel):
|
||||||
token = models.CharField(max_length=200)
|
token = models.CharField(max_length=200)
|
||||||
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
||||||
|
|
||||||
|
|
||||||
class GithubConfig(BaseModel):
|
class GithubConfig(DbBaseModel):
|
||||||
pat_token = models.CharField(max_length=200)
|
pat_token = models.CharField(max_length=200)
|
||||||
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
||||||
|
|
||||||
|
|
||||||
class GithubRepoConfig(BaseModel):
|
class GithubRepoConfig(DbBaseModel):
|
||||||
name = models.CharField(max_length=200)
|
name = models.CharField(max_length=200)
|
||||||
owner = models.CharField(max_length=200)
|
owner = models.CharField(max_length=200)
|
||||||
branch = models.CharField(max_length=200)
|
branch = models.CharField(max_length=200)
|
||||||
github_config = models.ForeignKey(GithubConfig, on_delete=models.CASCADE, related_name="githubrepoconfig")
|
github_config = models.ForeignKey(GithubConfig, on_delete=models.CASCADE, related_name="githubrepoconfig")
|
||||||
|
|
||||||
|
|
||||||
class WebScraper(BaseModel):
|
class WebScraper(DbBaseModel):
|
||||||
class WebScraperType(models.TextChoices):
|
class WebScraperType(models.TextChoices):
|
||||||
FIRECRAWL = "Firecrawl"
|
FIRECRAWL = "Firecrawl"
|
||||||
OLOSTEP = "Olostep"
|
OLOSTEP = "Olostep"
|
||||||
|
@ -321,7 +423,7 @@ class WebScraper(BaseModel):
|
||||||
super().save(*args, **kwargs)
|
super().save(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class ServerChatSettings(BaseModel):
|
class ServerChatSettings(DbBaseModel):
|
||||||
chat_default = models.ForeignKey(
|
chat_default = models.ForeignKey(
|
||||||
ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_default"
|
ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_default"
|
||||||
)
|
)
|
||||||
|
@ -333,35 +435,35 @@ class ServerChatSettings(BaseModel):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class LocalOrgConfig(BaseModel):
|
class LocalOrgConfig(DbBaseModel):
|
||||||
input_files = models.JSONField(default=list, null=True)
|
input_files = models.JSONField(default=list, null=True)
|
||||||
input_filter = models.JSONField(default=list, null=True)
|
input_filter = models.JSONField(default=list, null=True)
|
||||||
index_heading_entries = models.BooleanField(default=False)
|
index_heading_entries = models.BooleanField(default=False)
|
||||||
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
||||||
|
|
||||||
|
|
||||||
class LocalMarkdownConfig(BaseModel):
|
class LocalMarkdownConfig(DbBaseModel):
|
||||||
input_files = models.JSONField(default=list, null=True)
|
input_files = models.JSONField(default=list, null=True)
|
||||||
input_filter = models.JSONField(default=list, null=True)
|
input_filter = models.JSONField(default=list, null=True)
|
||||||
index_heading_entries = models.BooleanField(default=False)
|
index_heading_entries = models.BooleanField(default=False)
|
||||||
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
||||||
|
|
||||||
|
|
||||||
class LocalPdfConfig(BaseModel):
|
class LocalPdfConfig(DbBaseModel):
|
||||||
input_files = models.JSONField(default=list, null=True)
|
input_files = models.JSONField(default=list, null=True)
|
||||||
input_filter = models.JSONField(default=list, null=True)
|
input_filter = models.JSONField(default=list, null=True)
|
||||||
index_heading_entries = models.BooleanField(default=False)
|
index_heading_entries = models.BooleanField(default=False)
|
||||||
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
||||||
|
|
||||||
|
|
||||||
class LocalPlaintextConfig(BaseModel):
|
class LocalPlaintextConfig(DbBaseModel):
|
||||||
input_files = models.JSONField(default=list, null=True)
|
input_files = models.JSONField(default=list, null=True)
|
||||||
input_filter = models.JSONField(default=list, null=True)
|
input_filter = models.JSONField(default=list, null=True)
|
||||||
index_heading_entries = models.BooleanField(default=False)
|
index_heading_entries = models.BooleanField(default=False)
|
||||||
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
||||||
|
|
||||||
|
|
||||||
class SearchModelConfig(BaseModel):
|
class SearchModelConfig(DbBaseModel):
|
||||||
class ModelType(models.TextChoices):
|
class ModelType(models.TextChoices):
|
||||||
TEXT = "text"
|
TEXT = "text"
|
||||||
|
|
||||||
|
@ -393,7 +495,7 @@ class SearchModelConfig(BaseModel):
|
||||||
bi_encoder_confidence_threshold = models.FloatField(default=0.18)
|
bi_encoder_confidence_threshold = models.FloatField(default=0.18)
|
||||||
|
|
||||||
|
|
||||||
class TextToImageModelConfig(BaseModel):
|
class TextToImageModelConfig(DbBaseModel):
|
||||||
class ModelType(models.TextChoices):
|
class ModelType(models.TextChoices):
|
||||||
OPENAI = "openai"
|
OPENAI = "openai"
|
||||||
STABILITYAI = "stability-ai"
|
STABILITYAI = "stability-ai"
|
||||||
|
@ -430,7 +532,7 @@ class TextToImageModelConfig(BaseModel):
|
||||||
super().save(*args, **kwargs)
|
super().save(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class SpeechToTextModelOptions(BaseModel):
|
class SpeechToTextModelOptions(DbBaseModel):
|
||||||
class ModelType(models.TextChoices):
|
class ModelType(models.TextChoices):
|
||||||
OPENAI = "openai"
|
OPENAI = "openai"
|
||||||
OFFLINE = "offline"
|
OFFLINE = "offline"
|
||||||
|
@ -439,22 +541,22 @@ class SpeechToTextModelOptions(BaseModel):
|
||||||
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
|
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
|
||||||
|
|
||||||
|
|
||||||
class UserConversationConfig(BaseModel):
|
class UserConversationConfig(DbBaseModel):
|
||||||
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
|
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
|
||||||
setting = models.ForeignKey(ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
setting = models.ForeignKey(ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
||||||
|
|
||||||
|
|
||||||
class UserVoiceModelConfig(BaseModel):
|
class UserVoiceModelConfig(DbBaseModel):
|
||||||
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
|
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
|
||||||
setting = models.ForeignKey(VoiceModelOption, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
setting = models.ForeignKey(VoiceModelOption, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
||||||
|
|
||||||
|
|
||||||
class UserTextToImageModelConfig(BaseModel):
|
class UserTextToImageModelConfig(DbBaseModel):
|
||||||
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
|
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
|
||||||
setting = models.ForeignKey(TextToImageModelConfig, on_delete=models.CASCADE)
|
setting = models.ForeignKey(TextToImageModelConfig, on_delete=models.CASCADE)
|
||||||
|
|
||||||
|
|
||||||
class Conversation(BaseModel):
|
class Conversation(DbBaseModel):
|
||||||
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
||||||
conversation_log = models.JSONField(default=dict)
|
conversation_log = models.JSONField(default=dict)
|
||||||
client = models.ForeignKey(ClientApplication, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
client = models.ForeignKey(ClientApplication, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
||||||
|
@ -468,8 +570,39 @@ class Conversation(BaseModel):
|
||||||
file_filters = models.JSONField(default=list)
|
file_filters = models.JSONField(default=list)
|
||||||
id = models.UUIDField(default=uuid.uuid4, editable=False, unique=True, primary_key=True, db_index=True)
|
id = models.UUIDField(default=uuid.uuid4, editable=False, unique=True, primary_key=True, db_index=True)
|
||||||
|
|
||||||
|
def clean(self):
|
||||||
|
# Validate conversation_log structure
|
||||||
|
try:
|
||||||
|
messages = self.conversation_log.get("chat", [])
|
||||||
|
for msg in messages:
|
||||||
|
ChatMessage.model_validate(msg)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValidationError(f"Invalid conversation_log format: {str(e)}")
|
||||||
|
|
||||||
class PublicConversation(BaseModel):
|
def save(self, *args, **kwargs):
|
||||||
|
self.clean()
|
||||||
|
super().save(*args, **kwargs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def messages(self) -> List[ChatMessage]:
|
||||||
|
"""Type-hinted accessor for conversation messages"""
|
||||||
|
validated_messages = []
|
||||||
|
for msg in self.conversation_log.get("chat", []):
|
||||||
|
try:
|
||||||
|
# Clean up inferred queries if they contain None
|
||||||
|
if msg.get("intent") and msg["intent"].get("inferred-queries"):
|
||||||
|
msg["intent"]["inferred-queries"] = [
|
||||||
|
q for q in msg["intent"]["inferred-queries"] if q is not None and isinstance(q, str)
|
||||||
|
]
|
||||||
|
msg["message"] = str(msg.get("message", ""))
|
||||||
|
validated_messages.append(ChatMessage.model_validate(msg))
|
||||||
|
except ValidationError as e:
|
||||||
|
logger.warning(f"Skipping invalid message in conversation: {e}")
|
||||||
|
continue
|
||||||
|
return validated_messages
|
||||||
|
|
||||||
|
|
||||||
|
class PublicConversation(DbBaseModel):
|
||||||
source_owner = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
source_owner = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
||||||
conversation_log = models.JSONField(default=dict)
|
conversation_log = models.JSONField(default=dict)
|
||||||
slug = models.CharField(max_length=200, default=None, null=True, blank=True)
|
slug = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||||
|
@ -499,12 +632,12 @@ def verify_public_conversation(sender, instance, **kwargs):
|
||||||
instance.slug = slug
|
instance.slug = slug
|
||||||
|
|
||||||
|
|
||||||
class ReflectiveQuestion(BaseModel):
|
class ReflectiveQuestion(DbBaseModel):
|
||||||
question = models.CharField(max_length=500)
|
question = models.CharField(max_length=500)
|
||||||
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
||||||
|
|
||||||
|
|
||||||
class Entry(BaseModel):
|
class Entry(DbBaseModel):
|
||||||
class EntryType(models.TextChoices):
|
class EntryType(models.TextChoices):
|
||||||
IMAGE = "image"
|
IMAGE = "image"
|
||||||
PDF = "pdf"
|
PDF = "pdf"
|
||||||
|
@ -541,7 +674,7 @@ class Entry(BaseModel):
|
||||||
raise ValidationError("An Entry cannot be associated with both a user and an agent.")
|
raise ValidationError("An Entry cannot be associated with both a user and an agent.")
|
||||||
|
|
||||||
|
|
||||||
class FileObject(BaseModel):
|
class FileObject(DbBaseModel):
|
||||||
# Same as Entry but raw will be a much larger string
|
# Same as Entry but raw will be a much larger string
|
||||||
file_name = models.CharField(max_length=400, default=None, null=True, blank=True)
|
file_name = models.CharField(max_length=400, default=None, null=True, blank=True)
|
||||||
raw_text = models.TextField()
|
raw_text = models.TextField()
|
||||||
|
@ -549,7 +682,7 @@ class FileObject(BaseModel):
|
||||||
agent = models.ForeignKey(Agent, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
agent = models.ForeignKey(Agent, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
||||||
|
|
||||||
|
|
||||||
class EntryDates(BaseModel):
|
class EntryDates(DbBaseModel):
|
||||||
date = models.DateField()
|
date = models.DateField()
|
||||||
entry = models.ForeignKey(Entry, on_delete=models.CASCADE, related_name="embeddings_dates")
|
entry = models.ForeignKey(Entry, on_delete=models.CASCADE, related_name="embeddings_dates")
|
||||||
|
|
||||||
|
@ -559,12 +692,12 @@ class EntryDates(BaseModel):
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class UserRequests(BaseModel):
|
class UserRequests(DbBaseModel):
|
||||||
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
||||||
slug = models.CharField(max_length=200)
|
slug = models.CharField(max_length=200)
|
||||||
|
|
||||||
|
|
||||||
class DataStore(BaseModel):
|
class DataStore(DbBaseModel):
|
||||||
key = models.CharField(max_length=200, unique=True)
|
key = models.CharField(max_length=200, unique=True)
|
||||||
value = models.JSONField(default=dict)
|
value = models.JSONField(default=dict)
|
||||||
private = models.BooleanField(default=False)
|
private = models.BooleanField(default=False)
|
||||||
|
|
|
@ -2,7 +2,7 @@ import json
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Dict, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
from langchain.schema import ChatMessage
|
from langchain.schema import ChatMessage
|
||||||
|
|
||||||
|
@ -158,6 +158,10 @@ def converse_anthropic(
|
||||||
query_images: Optional[list[str]] = None,
|
query_images: Optional[list[str]] = None,
|
||||||
vision_available: bool = False,
|
vision_available: bool = False,
|
||||||
query_files: str = None,
|
query_files: str = None,
|
||||||
|
generated_images: Optional[list[str]] = None,
|
||||||
|
generated_files: List[str] = None,
|
||||||
|
generated_excalidraw_diagram: Optional[str] = None,
|
||||||
|
additional_context: Optional[str] = None,
|
||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
@ -218,6 +222,10 @@ def converse_anthropic(
|
||||||
vision_enabled=vision_available,
|
vision_enabled=vision_available,
|
||||||
model_type=ChatModelOptions.ModelType.ANTHROPIC,
|
model_type=ChatModelOptions.ModelType.ANTHROPIC,
|
||||||
query_files=query_files,
|
query_files=query_files,
|
||||||
|
generated_excalidraw_diagram=generated_excalidraw_diagram,
|
||||||
|
generated_files=generated_files,
|
||||||
|
generated_images=generated_images,
|
||||||
|
additional_program_context=additional_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
messages, system_prompt = format_messages_for_anthropic(messages, system_prompt)
|
messages, system_prompt = format_messages_for_anthropic(messages, system_prompt)
|
||||||
|
|
|
@ -2,7 +2,7 @@ import json
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Dict, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
from langchain.schema import ChatMessage
|
from langchain.schema import ChatMessage
|
||||||
|
|
||||||
|
@ -168,6 +168,10 @@ def converse_gemini(
|
||||||
query_images: Optional[list[str]] = None,
|
query_images: Optional[list[str]] = None,
|
||||||
vision_available: bool = False,
|
vision_available: bool = False,
|
||||||
query_files: str = None,
|
query_files: str = None,
|
||||||
|
generated_images: Optional[list[str]] = None,
|
||||||
|
generated_files: List[str] = None,
|
||||||
|
generated_excalidraw_diagram: Optional[str] = None,
|
||||||
|
additional_context: List[str] = None,
|
||||||
tracer={},
|
tracer={},
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
@ -229,6 +233,10 @@ def converse_gemini(
|
||||||
vision_enabled=vision_available,
|
vision_enabled=vision_available,
|
||||||
model_type=ChatModelOptions.ModelType.GOOGLE,
|
model_type=ChatModelOptions.ModelType.GOOGLE,
|
||||||
query_files=query_files,
|
query_files=query_files,
|
||||||
|
generated_excalidraw_diagram=generated_excalidraw_diagram,
|
||||||
|
generated_files=generated_files,
|
||||||
|
generated_images=generated_images,
|
||||||
|
additional_program_context=additional_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
messages, system_prompt = format_messages_for_gemini(messages, system_prompt)
|
messages, system_prompt = format_messages_for_gemini(messages, system_prompt)
|
||||||
|
|
|
@ -162,6 +162,8 @@ def converse_offline(
|
||||||
user_name: str = None,
|
user_name: str = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
query_files: str = None,
|
query_files: str = None,
|
||||||
|
generated_files: List[str] = None,
|
||||||
|
additional_context: List[str] = None,
|
||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
) -> Union[ThreadedGenerator, Iterator[str]]:
|
) -> Union[ThreadedGenerator, Iterator[str]]:
|
||||||
"""
|
"""
|
||||||
|
@ -229,6 +231,8 @@ def converse_offline(
|
||||||
tokenizer_name=tokenizer_name,
|
tokenizer_name=tokenizer_name,
|
||||||
model_type=ChatModelOptions.ModelType.OFFLINE,
|
model_type=ChatModelOptions.ModelType.OFFLINE,
|
||||||
query_files=query_files,
|
query_files=query_files,
|
||||||
|
generated_files=generated_files,
|
||||||
|
additional_program_context=additional_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f"Conversation Context for {model}: {messages_to_print(messages)}")
|
logger.debug(f"Conversation Context for {model}: {messages_to_print(messages)}")
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Dict, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
from langchain.schema import ChatMessage
|
from langchain.schema import ChatMessage
|
||||||
|
|
||||||
|
@ -157,6 +157,10 @@ def converse(
|
||||||
query_images: Optional[list[str]] = None,
|
query_images: Optional[list[str]] = None,
|
||||||
vision_available: bool = False,
|
vision_available: bool = False,
|
||||||
query_files: str = None,
|
query_files: str = None,
|
||||||
|
generated_images: Optional[list[str]] = None,
|
||||||
|
generated_files: List[str] = None,
|
||||||
|
generated_excalidraw_diagram: Optional[str] = None,
|
||||||
|
additional_context: List[str] = None,
|
||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
@ -219,6 +223,10 @@ def converse(
|
||||||
vision_enabled=vision_available,
|
vision_enabled=vision_available,
|
||||||
model_type=ChatModelOptions.ModelType.OPENAI,
|
model_type=ChatModelOptions.ModelType.OPENAI,
|
||||||
query_files=query_files,
|
query_files=query_files,
|
||||||
|
generated_excalidraw_diagram=generated_excalidraw_diagram,
|
||||||
|
generated_files=generated_files,
|
||||||
|
generated_images=generated_images,
|
||||||
|
additional_program_context=additional_context,
|
||||||
)
|
)
|
||||||
logger.debug(f"Conversation Context for GPT: {messages_to_print(messages)}")
|
logger.debug(f"Conversation Context for GPT: {messages_to_print(messages)}")
|
||||||
|
|
||||||
|
|
|
@ -180,6 +180,20 @@ Improved Prompt:
|
||||||
""".strip()
|
""".strip()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
generated_image_attachment = PromptTemplate.from_template(
|
||||||
|
f"""
|
||||||
|
Here is the image you generated based on my query. You can follow-up with a general response to my query. Limit to 1-2 sentences.
|
||||||
|
""".strip()
|
||||||
|
)
|
||||||
|
|
||||||
|
generated_diagram_attachment = PromptTemplate.from_template(
|
||||||
|
f"""
|
||||||
|
The AI has successfully created a diagram based on the user's query and handled the request. Good job!
|
||||||
|
|
||||||
|
AI can follow-up with a general response or summary. Limit to 1-2 sentences.
|
||||||
|
""".strip()
|
||||||
|
)
|
||||||
|
|
||||||
## Diagram Generation
|
## Diagram Generation
|
||||||
## --
|
## --
|
||||||
|
|
||||||
|
@ -1031,6 +1045,13 @@ A:
|
||||||
""".strip()
|
""".strip()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
additional_program_context = PromptTemplate.from_template(
|
||||||
|
"""
|
||||||
|
Here's some additional context about what happened while I was executing this query:
|
||||||
|
{context}
|
||||||
|
""".strip()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
personality_prompt_safety_expert_lax = PromptTemplate.from_template(
|
personality_prompt_safety_expert_lax = PromptTemplate.from_template(
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -155,6 +155,9 @@ def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="A
|
||||||
elif chat["by"] == "khoj" and ("text-to-image" in chat["intent"].get("type")):
|
elif chat["by"] == "khoj" and ("text-to-image" in chat["intent"].get("type")):
|
||||||
chat_history += f"User: {chat['intent']['query']}\n"
|
chat_history += f"User: {chat['intent']['query']}\n"
|
||||||
chat_history += f"{agent_name}: [generated image redacted for space]\n"
|
chat_history += f"{agent_name}: [generated image redacted for space]\n"
|
||||||
|
elif chat["by"] == "khoj" and chat.get("images"):
|
||||||
|
chat_history += f"User: {chat['intent']['query']}\n"
|
||||||
|
chat_history += f"{agent_name}: [generated image redacted for space]\n"
|
||||||
elif chat["by"] == "khoj" and ("excalidraw" in chat["intent"].get("type")):
|
elif chat["by"] == "khoj" and ("excalidraw" in chat["intent"].get("type")):
|
||||||
chat_history += f"User: {chat['intent']['query']}\n"
|
chat_history += f"User: {chat['intent']['query']}\n"
|
||||||
chat_history += f"{agent_name}: {chat['intent']['inferred-queries'][0]}\n"
|
chat_history += f"{agent_name}: {chat['intent']['inferred-queries'][0]}\n"
|
||||||
|
@ -211,6 +214,7 @@ class ChatEvent(Enum):
|
||||||
END_LLM_RESPONSE = "end_llm_response"
|
END_LLM_RESPONSE = "end_llm_response"
|
||||||
MESSAGE = "message"
|
MESSAGE = "message"
|
||||||
REFERENCES = "references"
|
REFERENCES = "references"
|
||||||
|
GENERATED_ASSETS = "generated_assets"
|
||||||
STATUS = "status"
|
STATUS = "status"
|
||||||
METADATA = "metadata"
|
METADATA = "metadata"
|
||||||
USAGE = "usage"
|
USAGE = "usage"
|
||||||
|
@ -223,7 +227,6 @@ def message_to_log(
|
||||||
user_message_metadata={},
|
user_message_metadata={},
|
||||||
khoj_message_metadata={},
|
khoj_message_metadata={},
|
||||||
conversation_log=[],
|
conversation_log=[],
|
||||||
train_of_thought=[],
|
|
||||||
):
|
):
|
||||||
"""Create json logs from messages, metadata for conversation log"""
|
"""Create json logs from messages, metadata for conversation log"""
|
||||||
default_khoj_message_metadata = {
|
default_khoj_message_metadata = {
|
||||||
|
@ -232,6 +235,10 @@ def message_to_log(
|
||||||
}
|
}
|
||||||
khoj_response_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
khoj_response_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
|
||||||
|
# Filter out any fields that are set to None
|
||||||
|
user_message_metadata = {k: v for k, v in user_message_metadata.items() if v is not None}
|
||||||
|
khoj_message_metadata = {k: v for k, v in khoj_message_metadata.items() if v is not None}
|
||||||
|
|
||||||
# Create json log from Human's message
|
# Create json log from Human's message
|
||||||
human_log = merge_dicts({"message": user_message, "by": "you"}, user_message_metadata)
|
human_log = merge_dicts({"message": user_message, "by": "you"}, user_message_metadata)
|
||||||
|
|
||||||
|
@ -259,6 +266,9 @@ def save_to_conversation_log(
|
||||||
automation_id: str = None,
|
automation_id: str = None,
|
||||||
query_images: List[str] = None,
|
query_images: List[str] = None,
|
||||||
raw_query_files: List[FileAttachment] = [],
|
raw_query_files: List[FileAttachment] = [],
|
||||||
|
generated_images: List[str] = [],
|
||||||
|
raw_generated_files: List[FileAttachment] = [],
|
||||||
|
generated_excalidraw_diagram: str = None,
|
||||||
train_of_thought: List[Any] = [],
|
train_of_thought: List[Any] = [],
|
||||||
tracer: Dict[str, Any] = {},
|
tracer: Dict[str, Any] = {},
|
||||||
):
|
):
|
||||||
|
@ -281,9 +291,11 @@ def save_to_conversation_log(
|
||||||
"automationId": automation_id,
|
"automationId": automation_id,
|
||||||
"trainOfThought": train_of_thought,
|
"trainOfThought": train_of_thought,
|
||||||
"turnId": turn_id,
|
"turnId": turn_id,
|
||||||
|
"images": generated_images,
|
||||||
|
"queryFiles": [file.model_dump(mode="json") for file in raw_generated_files],
|
||||||
|
"excalidrawDiagram": str(generated_excalidraw_diagram),
|
||||||
},
|
},
|
||||||
conversation_log=meta_log.get("chat", []),
|
conversation_log=meta_log.get("chat", []),
|
||||||
train_of_thought=train_of_thought,
|
|
||||||
)
|
)
|
||||||
ConversationAdapters.save_conversation(
|
ConversationAdapters.save_conversation(
|
||||||
user,
|
user,
|
||||||
|
@ -307,7 +319,7 @@ Khoj: "{inferred_queries if ("text-to-image" in intent_type) else chat_response}
|
||||||
|
|
||||||
|
|
||||||
def construct_structured_message(
|
def construct_structured_message(
|
||||||
message: str, images: list[str], model_type: str, vision_enabled: bool, attached_file_context: str
|
message: str, images: list[str], model_type: str, vision_enabled: bool, attached_file_context: str = None
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Format messages into appropriate multimedia format for supported chat model types
|
Format messages into appropriate multimedia format for supported chat model types
|
||||||
|
@ -363,6 +375,10 @@ def generate_chatml_messages_with_context(
|
||||||
model_type="",
|
model_type="",
|
||||||
context_message="",
|
context_message="",
|
||||||
query_files: str = None,
|
query_files: str = None,
|
||||||
|
generated_images: Optional[list[str]] = None,
|
||||||
|
generated_files: List[FileAttachment] = None,
|
||||||
|
generated_excalidraw_diagram: str = None,
|
||||||
|
additional_program_context: List[str] = [],
|
||||||
):
|
):
|
||||||
"""Generate chat messages with appropriate context from previous conversation to send to the chat model"""
|
"""Generate chat messages with appropriate context from previous conversation to send to the chat model"""
|
||||||
# Set max prompt size from user config or based on pre-configured for model and machine specs
|
# Set max prompt size from user config or based on pre-configured for model and machine specs
|
||||||
|
@ -382,6 +398,7 @@ def generate_chatml_messages_with_context(
|
||||||
message_attached_files = ""
|
message_attached_files = ""
|
||||||
|
|
||||||
chat_message = chat.get("message")
|
chat_message = chat.get("message")
|
||||||
|
role = "user" if chat["by"] == "you" else "assistant"
|
||||||
|
|
||||||
if chat["by"] == "khoj" and "excalidraw" in chat["intent"].get("type", ""):
|
if chat["by"] == "khoj" and "excalidraw" in chat["intent"].get("type", ""):
|
||||||
chat_message = chat["intent"].get("inferred-queries")[0]
|
chat_message = chat["intent"].get("inferred-queries")[0]
|
||||||
|
@ -402,7 +419,7 @@ def generate_chatml_messages_with_context(
|
||||||
query_files_dict[file["name"]] = file["content"]
|
query_files_dict[file["name"]] = file["content"]
|
||||||
|
|
||||||
message_attached_files = gather_raw_query_files(query_files_dict)
|
message_attached_files = gather_raw_query_files(query_files_dict)
|
||||||
chatml_messages.append(ChatMessage(content=message_attached_files, role="user"))
|
chatml_messages.append(ChatMessage(content=message_attached_files, role=role))
|
||||||
|
|
||||||
if not is_none_or_empty(chat.get("onlineContext")):
|
if not is_none_or_empty(chat.get("onlineContext")):
|
||||||
message_context += f"{prompts.online_search_conversation.format(online_results=chat.get('onlineContext'))}"
|
message_context += f"{prompts.online_search_conversation.format(online_results=chat.get('onlineContext'))}"
|
||||||
|
@ -411,9 +428,18 @@ def generate_chatml_messages_with_context(
|
||||||
reconstructed_context_message = ChatMessage(content=message_context, role="user")
|
reconstructed_context_message = ChatMessage(content=message_context, role="user")
|
||||||
chatml_messages.insert(0, reconstructed_context_message)
|
chatml_messages.insert(0, reconstructed_context_message)
|
||||||
|
|
||||||
role = "user" if chat["by"] == "you" else "assistant"
|
if chat.get("images") and role == "assistant":
|
||||||
|
# Issue: the assistant role cannot accept an image as a message content, so send it in a separate user message.
|
||||||
|
file_attachment_message = construct_structured_message(
|
||||||
|
message=prompts.generated_image_attachment.format(),
|
||||||
|
images=chat.get("images"),
|
||||||
|
model_type=model_type,
|
||||||
|
vision_enabled=vision_enabled,
|
||||||
|
)
|
||||||
|
chatml_messages.append(ChatMessage(content=file_attachment_message, role="user"))
|
||||||
|
|
||||||
message_content = construct_structured_message(
|
message_content = construct_structured_message(
|
||||||
chat_message, chat.get("images"), model_type, vision_enabled, attached_file_context=query_files
|
chat_message, chat.get("images") if role == "user" else [], model_type, vision_enabled
|
||||||
)
|
)
|
||||||
|
|
||||||
reconstructed_message = ChatMessage(content=message_content, role=role)
|
reconstructed_message = ChatMessage(content=message_content, role=role)
|
||||||
|
@ -423,6 +449,7 @@ def generate_chatml_messages_with_context(
|
||||||
break
|
break
|
||||||
|
|
||||||
messages = []
|
messages = []
|
||||||
|
|
||||||
if not is_none_or_empty(user_message):
|
if not is_none_or_empty(user_message):
|
||||||
messages.append(
|
messages.append(
|
||||||
ChatMessage(
|
ChatMessage(
|
||||||
|
@ -435,6 +462,31 @@ def generate_chatml_messages_with_context(
|
||||||
if not is_none_or_empty(context_message):
|
if not is_none_or_empty(context_message):
|
||||||
messages.append(ChatMessage(content=context_message, role="user"))
|
messages.append(ChatMessage(content=context_message, role="user"))
|
||||||
|
|
||||||
|
if generated_images:
|
||||||
|
messages.append(
|
||||||
|
ChatMessage(
|
||||||
|
content=construct_structured_message(
|
||||||
|
prompts.generated_image_attachment.format(), generated_images, model_type, vision_enabled
|
||||||
|
),
|
||||||
|
role="user",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if generated_files:
|
||||||
|
message_attached_files = gather_raw_query_files({file.name: file.content for file in generated_files})
|
||||||
|
messages.append(ChatMessage(content=message_attached_files, role="assistant"))
|
||||||
|
|
||||||
|
if generated_excalidraw_diagram:
|
||||||
|
messages.append(ChatMessage(content=prompts.generated_diagram_attachment.format(), role="assistant"))
|
||||||
|
|
||||||
|
if additional_program_context:
|
||||||
|
messages.append(
|
||||||
|
ChatMessage(
|
||||||
|
content=prompts.additional_program_context.format(context="\n".join(additional_program_context)),
|
||||||
|
role="assistant",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if len(chatml_messages) > 0:
|
if len(chatml_messages) > 0:
|
||||||
messages += chatml_messages
|
messages += chatml_messages
|
||||||
|
|
||||||
|
|
|
@ -12,7 +12,7 @@ from khoj.database.models import Agent, KhojUser, TextToImageModelConfig
|
||||||
from khoj.routers.helpers import ChatEvent, generate_better_image_prompt
|
from khoj.routers.helpers import ChatEvent, generate_better_image_prompt
|
||||||
from khoj.routers.storage import upload_image
|
from khoj.routers.storage import upload_image
|
||||||
from khoj.utils import state
|
from khoj.utils import state
|
||||||
from khoj.utils.helpers import ImageIntentType, convert_image_to_webp, timer
|
from khoj.utils.helpers import convert_image_to_webp, timer
|
||||||
from khoj.utils.rawconfig import LocationData
|
from khoj.utils.rawconfig import LocationData
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -34,14 +34,13 @@ async def text_to_image(
|
||||||
status_code = 200
|
status_code = 200
|
||||||
image = None
|
image = None
|
||||||
image_url = None
|
image_url = None
|
||||||
intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
|
|
||||||
|
|
||||||
text_to_image_config = await ConversationAdapters.aget_user_text_to_image_model(user)
|
text_to_image_config = await ConversationAdapters.aget_user_text_to_image_model(user)
|
||||||
if not text_to_image_config:
|
if not text_to_image_config:
|
||||||
# If the user has not configured a text to image model, return an unsupported on server error
|
# If the user has not configured a text to image model, return an unsupported on server error
|
||||||
status_code = 501
|
status_code = 501
|
||||||
message = "Failed to generate image. Setup image generation on the server."
|
message = "Failed to generate image. Setup image generation on the server."
|
||||||
yield image_url or image, status_code, message, intent_type.value
|
yield image_url or image, status_code, message
|
||||||
return
|
return
|
||||||
|
|
||||||
text2image_model = text_to_image_config.model_name
|
text2image_model = text_to_image_config.model_name
|
||||||
|
@ -53,6 +52,9 @@ async def text_to_image(
|
||||||
elif chat["by"] == "khoj" and "text-to-image" in chat["intent"].get("type"):
|
elif chat["by"] == "khoj" and "text-to-image" in chat["intent"].get("type"):
|
||||||
chat_history += f"Q: Prompt: {chat['intent']['query']}\n"
|
chat_history += f"Q: Prompt: {chat['intent']['query']}\n"
|
||||||
chat_history += f"A: Improved Prompt: {chat['intent']['inferred-queries'][0]}\n"
|
chat_history += f"A: Improved Prompt: {chat['intent']['inferred-queries'][0]}\n"
|
||||||
|
elif chat["by"] == "khoj" and chat.get("images"):
|
||||||
|
chat_history += f"Q: {chat['intent']['query']}\n"
|
||||||
|
chat_history += f"A: Improved Prompt: {chat['intent']['inferred-queries'][0]}\n"
|
||||||
|
|
||||||
if send_status_func:
|
if send_status_func:
|
||||||
async for event in send_status_func("**Enhancing the Painting Prompt**"):
|
async for event in send_status_func("**Enhancing the Painting Prompt**"):
|
||||||
|
@ -92,31 +94,29 @@ async def text_to_image(
|
||||||
logger.error(f"Image Generation blocked by OpenAI: {e}")
|
logger.error(f"Image Generation blocked by OpenAI: {e}")
|
||||||
status_code = e.status_code # type: ignore
|
status_code = e.status_code # type: ignore
|
||||||
message = f"Image generation blocked by OpenAI due to policy violation" # type: ignore
|
message = f"Image generation blocked by OpenAI due to policy violation" # type: ignore
|
||||||
yield image_url or image, status_code, message, intent_type.value
|
yield image_url or image, status_code, message
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
logger.error(f"Image Generation failed with {e}", exc_info=True)
|
logger.error(f"Image Generation failed with {e}", exc_info=True)
|
||||||
message = f"Image generation failed using OpenAI" # type: ignore
|
message = f"Image generation failed using OpenAI" # type: ignore
|
||||||
status_code = e.status_code # type: ignore
|
status_code = e.status_code # type: ignore
|
||||||
yield image_url or image, status_code, message, intent_type.value
|
yield image_url or image, status_code, message
|
||||||
return
|
return
|
||||||
except requests.RequestException as e:
|
except requests.RequestException as e:
|
||||||
logger.error(f"Image Generation failed with {e}", exc_info=True)
|
logger.error(f"Image Generation failed with {e}", exc_info=True)
|
||||||
message = f"Image generation using {text2image_model} via {text_to_image_config.model_type} failed due to a network error."
|
message = f"Image generation using {text2image_model} via {text_to_image_config.model_type} failed due to a network error."
|
||||||
status_code = 502
|
status_code = 502
|
||||||
yield image_url or image, status_code, message, intent_type.value
|
yield image_url or image, status_code, message
|
||||||
return
|
return
|
||||||
|
|
||||||
# Decide how to store the generated image
|
# Decide how to store the generated image
|
||||||
with timer("Upload image to S3", logger):
|
with timer("Upload image to S3", logger):
|
||||||
image_url = upload_image(webp_image_bytes, user.uuid)
|
image_url = upload_image(webp_image_bytes, user.uuid)
|
||||||
if image_url:
|
|
||||||
intent_type = ImageIntentType.TEXT_TO_IMAGE2
|
if not image_url:
|
||||||
else:
|
|
||||||
intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
|
|
||||||
image = base64.b64encode(webp_image_bytes).decode("utf-8")
|
image = base64.b64encode(webp_image_bytes).decode("utf-8")
|
||||||
|
|
||||||
yield image_url or image, status_code, image_prompt, intent_type.value
|
yield image_url or image, status_code, image_prompt
|
||||||
|
|
||||||
|
|
||||||
def generate_image_with_openai(
|
def generate_image_with_openai(
|
||||||
|
|
|
@ -77,6 +77,7 @@ from khoj.utils.helpers import (
|
||||||
)
|
)
|
||||||
from khoj.utils.rawconfig import (
|
from khoj.utils.rawconfig import (
|
||||||
ChatRequestBody,
|
ChatRequestBody,
|
||||||
|
FileAttachment,
|
||||||
FileFilterRequest,
|
FileFilterRequest,
|
||||||
FilesFilterRequest,
|
FilesFilterRequest,
|
||||||
LocationData,
|
LocationData,
|
||||||
|
@ -771,6 +772,11 @@ async def chat(
|
||||||
file_filters = conversation.file_filters if conversation and conversation.file_filters else []
|
file_filters = conversation.file_filters if conversation and conversation.file_filters else []
|
||||||
attached_file_context = gather_raw_query_files(query_files)
|
attached_file_context = gather_raw_query_files(query_files)
|
||||||
|
|
||||||
|
generated_images: List[str] = []
|
||||||
|
generated_files: List[FileAttachment] = []
|
||||||
|
generated_excalidraw_diagram: str = None
|
||||||
|
additional_context_for_llm_response: List[str] = []
|
||||||
|
|
||||||
if conversation_commands == [ConversationCommand.Default] or is_automated_task:
|
if conversation_commands == [ConversationCommand.Default] or is_automated_task:
|
||||||
chosen_io = await aget_data_sources_and_output_format(
|
chosen_io = await aget_data_sources_and_output_format(
|
||||||
q,
|
q,
|
||||||
|
@ -876,21 +882,17 @@ async def chat(
|
||||||
async for result in send_llm_response(response, tracer.get("usage")):
|
async for result in send_llm_response(response, tracer.get("usage")):
|
||||||
yield result
|
yield result
|
||||||
|
|
||||||
await sync_to_async(save_to_conversation_log)(
|
summarized_document = FileAttachment(
|
||||||
q,
|
name="Summarized Document",
|
||||||
response_log,
|
content=response_log,
|
||||||
user,
|
type="text/plain",
|
||||||
meta_log,
|
size=len(response_log.encode("utf-8")),
|
||||||
user_message_time,
|
|
||||||
intent_type="summarize",
|
|
||||||
client_application=request.user.client_app,
|
|
||||||
conversation_id=conversation_id,
|
|
||||||
query_images=uploaded_images,
|
|
||||||
train_of_thought=train_of_thought,
|
|
||||||
raw_query_files=raw_query_files,
|
|
||||||
tracer=tracer,
|
|
||||||
)
|
)
|
||||||
return
|
|
||||||
|
async for result in send_event(ChatEvent.GENERATED_ASSETS, {"files": [summarized_document.model_dump()]}):
|
||||||
|
yield result
|
||||||
|
|
||||||
|
generated_files.append(summarized_document)
|
||||||
|
|
||||||
custom_filters = []
|
custom_filters = []
|
||||||
if conversation_commands == [ConversationCommand.Help]:
|
if conversation_commands == [ConversationCommand.Help]:
|
||||||
|
@ -1079,6 +1081,7 @@ async def chat(
|
||||||
async for result in send_event(ChatEvent.STATUS, f"**Ran code snippets**: {len(code_results)}"):
|
async for result in send_event(ChatEvent.STATUS, f"**Ran code snippets**: {len(code_results)}"):
|
||||||
yield result
|
yield result
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
|
additional_context_for_llm_response.append(f"Failed to run code")
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Failed to use code tool: {e}. Attempting to respond without code results",
|
f"Failed to use code tool: {e}. Attempting to respond without code results",
|
||||||
exc_info=True,
|
exc_info=True,
|
||||||
|
@ -1116,51 +1119,36 @@ async def chat(
|
||||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
yield result[ChatEvent.STATUS]
|
yield result[ChatEvent.STATUS]
|
||||||
else:
|
else:
|
||||||
generated_image, status_code, improved_image_prompt, intent_type = result
|
generated_image, status_code, improved_image_prompt = result
|
||||||
|
|
||||||
|
inferred_queries.append(improved_image_prompt)
|
||||||
if generated_image is None or status_code != 200:
|
if generated_image is None or status_code != 200:
|
||||||
content_obj = {
|
additional_context_for_llm_response.append(f"Failed to generate image with {improved_image_prompt}")
|
||||||
"content-type": "application/json",
|
async for result in send_event(ChatEvent.STATUS, f"Failed to generate image"):
|
||||||
"intentType": intent_type,
|
|
||||||
"detail": improved_image_prompt,
|
|
||||||
"image": None,
|
|
||||||
}
|
|
||||||
async for result in send_llm_response(json.dumps(content_obj), tracer.get("usage")):
|
|
||||||
yield result
|
yield result
|
||||||
return
|
else:
|
||||||
|
generated_images.append(generated_image)
|
||||||
|
# content_obj = {
|
||||||
|
# "intentType": intent_type,
|
||||||
|
# "inferredQueries": [improved_image_prompt],
|
||||||
|
# "image": generated_image,
|
||||||
|
# }
|
||||||
|
# async for result in send_llm_response(json.dumps(content_obj), tracer.get("usage")):
|
||||||
|
# yield result
|
||||||
|
# return
|
||||||
|
|
||||||
await sync_to_async(save_to_conversation_log)(
|
async for result in send_event(
|
||||||
q,
|
ChatEvent.GENERATED_ASSETS,
|
||||||
generated_image,
|
{
|
||||||
user,
|
"images": [generated_image],
|
||||||
meta_log,
|
},
|
||||||
user_message_time,
|
):
|
||||||
intent_type=intent_type,
|
yield result
|
||||||
inferred_queries=[improved_image_prompt],
|
|
||||||
client_application=request.user.client_app,
|
|
||||||
conversation_id=conversation_id,
|
|
||||||
compiled_references=compiled_references,
|
|
||||||
online_results=online_results,
|
|
||||||
code_results=code_results,
|
|
||||||
query_images=uploaded_images,
|
|
||||||
train_of_thought=train_of_thought,
|
|
||||||
raw_query_files=raw_query_files,
|
|
||||||
tracer=tracer,
|
|
||||||
)
|
|
||||||
content_obj = {
|
|
||||||
"intentType": intent_type,
|
|
||||||
"inferredQueries": [improved_image_prompt],
|
|
||||||
"image": generated_image,
|
|
||||||
}
|
|
||||||
async for result in send_llm_response(json.dumps(content_obj), tracer.get("usage")):
|
|
||||||
yield result
|
|
||||||
return
|
|
||||||
|
|
||||||
if ConversationCommand.Diagram in conversation_commands:
|
if ConversationCommand.Diagram in conversation_commands:
|
||||||
async for result in send_event(ChatEvent.STATUS, f"Creating diagram"):
|
async for result in send_event(ChatEvent.STATUS, f"Creating diagram"):
|
||||||
yield result
|
yield result
|
||||||
|
|
||||||
intent_type = "excalidraw"
|
|
||||||
inferred_queries = []
|
inferred_queries = []
|
||||||
diagram_description = ""
|
diagram_description = ""
|
||||||
|
|
||||||
|
@ -1184,62 +1172,59 @@ async def chat(
|
||||||
if better_diagram_description_prompt and excalidraw_diagram_description:
|
if better_diagram_description_prompt and excalidraw_diagram_description:
|
||||||
inferred_queries.append(better_diagram_description_prompt)
|
inferred_queries.append(better_diagram_description_prompt)
|
||||||
diagram_description = excalidraw_diagram_description
|
diagram_description = excalidraw_diagram_description
|
||||||
|
|
||||||
|
generated_excalidraw_diagram = diagram_description
|
||||||
|
|
||||||
|
async for result in send_event(
|
||||||
|
ChatEvent.GENERATED_ASSETS,
|
||||||
|
{
|
||||||
|
"excalidrawDiagram": excalidraw_diagram_description,
|
||||||
|
},
|
||||||
|
):
|
||||||
|
yield result
|
||||||
else:
|
else:
|
||||||
error_message = "Failed to generate diagram. Please try again later."
|
error_message = "Failed to generate diagram. Please try again later."
|
||||||
async for result in send_llm_response(error_message, tracer.get("usage")):
|
additional_context_for_llm_response.append(
|
||||||
|
f"AI attempted to programmatically generate a diagram but failed due to a program issue. Generally, it is able to do so, but encountered a system issue this time. AI can suggest text description or rendering of the diagram or user can try again with a simpler prompt."
|
||||||
|
)
|
||||||
|
|
||||||
|
async for result in send_event(ChatEvent.STATUS, error_message):
|
||||||
yield result
|
yield result
|
||||||
|
|
||||||
await sync_to_async(save_to_conversation_log)(
|
# content_obj = {
|
||||||
q,
|
# "intentType": intent_type,
|
||||||
error_message,
|
# "inferredQueries": inferred_queries,
|
||||||
user,
|
# "image": diagram_description,
|
||||||
meta_log,
|
# }
|
||||||
user_message_time,
|
|
||||||
inferred_queries=[better_diagram_description_prompt],
|
|
||||||
client_application=request.user.client_app,
|
|
||||||
conversation_id=conversation_id,
|
|
||||||
compiled_references=compiled_references,
|
|
||||||
online_results=online_results,
|
|
||||||
code_results=code_results,
|
|
||||||
query_images=uploaded_images,
|
|
||||||
train_of_thought=train_of_thought,
|
|
||||||
raw_query_files=raw_query_files,
|
|
||||||
tracer=tracer,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
content_obj = {
|
# await sync_to_async(save_to_conversation_log)(
|
||||||
"intentType": intent_type,
|
# q,
|
||||||
"inferredQueries": inferred_queries,
|
# excalidraw_diagram_description,
|
||||||
"image": diagram_description,
|
# user,
|
||||||
}
|
# meta_log,
|
||||||
|
# user_message_time,
|
||||||
|
# intent_type="excalidraw",
|
||||||
|
# inferred_queries=[better_diagram_description_prompt],
|
||||||
|
# client_application=request.user.client_app,
|
||||||
|
# conversation_id=conversation_id,
|
||||||
|
# compiled_references=compiled_references,
|
||||||
|
# online_results=online_results,
|
||||||
|
# code_results=code_results,
|
||||||
|
# query_images=uploaded_images,
|
||||||
|
# train_of_thought=train_of_thought,
|
||||||
|
# raw_query_files=raw_query_files,
|
||||||
|
# generated_images=generated_images,
|
||||||
|
# tracer=tracer,
|
||||||
|
# )
|
||||||
|
|
||||||
await sync_to_async(save_to_conversation_log)(
|
# async for result in send_llm_response(json.dumps(content_obj), tracer.get("usage")):
|
||||||
q,
|
# yield result
|
||||||
excalidraw_diagram_description,
|
# return
|
||||||
user,
|
|
||||||
meta_log,
|
|
||||||
user_message_time,
|
|
||||||
intent_type="excalidraw",
|
|
||||||
inferred_queries=[better_diagram_description_prompt],
|
|
||||||
client_application=request.user.client_app,
|
|
||||||
conversation_id=conversation_id,
|
|
||||||
compiled_references=compiled_references,
|
|
||||||
online_results=online_results,
|
|
||||||
code_results=code_results,
|
|
||||||
query_images=uploaded_images,
|
|
||||||
train_of_thought=train_of_thought,
|
|
||||||
raw_query_files=raw_query_files,
|
|
||||||
tracer=tracer,
|
|
||||||
)
|
|
||||||
|
|
||||||
async for result in send_llm_response(json.dumps(content_obj), tracer.get("usage")):
|
|
||||||
yield result
|
|
||||||
return
|
|
||||||
|
|
||||||
## Generate Text Output
|
## Generate Text Output
|
||||||
async for result in send_event(ChatEvent.STATUS, f"**Generating a well-informed response**"):
|
async for result in send_event(ChatEvent.STATUS, f"**Generating a well-informed response**"):
|
||||||
yield result
|
yield result
|
||||||
|
|
||||||
llm_response, chat_metadata = await agenerate_chat_response(
|
llm_response, chat_metadata = await agenerate_chat_response(
|
||||||
defiltered_query,
|
defiltered_query,
|
||||||
meta_log,
|
meta_log,
|
||||||
|
@ -1259,6 +1244,10 @@ async def chat(
|
||||||
train_of_thought,
|
train_of_thought,
|
||||||
attached_file_context,
|
attached_file_context,
|
||||||
raw_query_files,
|
raw_query_files,
|
||||||
|
generated_images,
|
||||||
|
generated_files,
|
||||||
|
generated_excalidraw_diagram,
|
||||||
|
additional_context_for_llm_response,
|
||||||
tracer,
|
tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -1184,6 +1184,10 @@ def generate_chat_response(
|
||||||
train_of_thought: List[Any] = [],
|
train_of_thought: List[Any] = [],
|
||||||
query_files: str = None,
|
query_files: str = None,
|
||||||
raw_query_files: List[FileAttachment] = None,
|
raw_query_files: List[FileAttachment] = None,
|
||||||
|
generated_images: List[str] = None,
|
||||||
|
raw_generated_files: List[FileAttachment] = [],
|
||||||
|
generated_excalidraw_diagram: str = None,
|
||||||
|
additional_context: List[str] = [],
|
||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
|
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
|
||||||
# Initialize Variables
|
# Initialize Variables
|
||||||
|
@ -1207,6 +1211,9 @@ def generate_chat_response(
|
||||||
query_images=query_images,
|
query_images=query_images,
|
||||||
train_of_thought=train_of_thought,
|
train_of_thought=train_of_thought,
|
||||||
raw_query_files=raw_query_files,
|
raw_query_files=raw_query_files,
|
||||||
|
generated_images=generated_images,
|
||||||
|
raw_generated_files=raw_generated_files,
|
||||||
|
generated_excalidraw_diagram=generated_excalidraw_diagram,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1242,6 +1249,7 @@ def generate_chat_response(
|
||||||
user_name=user_name,
|
user_name=user_name,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
query_files=query_files,
|
query_files=query_files,
|
||||||
|
generated_files=raw_generated_files,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1268,6 +1276,10 @@ def generate_chat_response(
|
||||||
agent=agent,
|
agent=agent,
|
||||||
vision_available=vision_available,
|
vision_available=vision_available,
|
||||||
query_files=query_files,
|
query_files=query_files,
|
||||||
|
generated_files=raw_generated_files,
|
||||||
|
generated_images=generated_images,
|
||||||
|
generated_excalidraw_diagram=generated_excalidraw_diagram,
|
||||||
|
additional_context=additional_context,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1291,6 +1303,10 @@ def generate_chat_response(
|
||||||
agent=agent,
|
agent=agent,
|
||||||
vision_available=vision_available,
|
vision_available=vision_available,
|
||||||
query_files=query_files,
|
query_files=query_files,
|
||||||
|
generated_files=raw_generated_files,
|
||||||
|
generated_images=generated_images,
|
||||||
|
generated_excalidraw_diagram=generated_excalidraw_diagram,
|
||||||
|
additional_context=additional_context,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
|
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
|
||||||
|
@ -1313,6 +1329,10 @@ def generate_chat_response(
|
||||||
query_images=query_images,
|
query_images=query_images,
|
||||||
vision_available=vision_available,
|
vision_available=vision_available,
|
||||||
query_files=query_files,
|
query_files=query_files,
|
||||||
|
generated_files=raw_generated_files,
|
||||||
|
generated_images=generated_images,
|
||||||
|
generated_excalidraw_diagram=generated_excalidraw_diagram,
|
||||||
|
additional_context=additional_context,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1784,6 +1804,9 @@ class MessageProcessor:
|
||||||
self.references = {}
|
self.references = {}
|
||||||
self.usage = {}
|
self.usage = {}
|
||||||
self.raw_response = ""
|
self.raw_response = ""
|
||||||
|
self.generated_images = []
|
||||||
|
self.generated_files = []
|
||||||
|
self.generated_excalidraw_diagrams = []
|
||||||
|
|
||||||
def convert_message_chunk_to_json(self, raw_chunk: str) -> Dict[str, Any]:
|
def convert_message_chunk_to_json(self, raw_chunk: str) -> Dict[str, Any]:
|
||||||
if raw_chunk.startswith("{") and raw_chunk.endswith("}"):
|
if raw_chunk.startswith("{") and raw_chunk.endswith("}"):
|
||||||
|
@ -1822,6 +1845,16 @@ class MessageProcessor:
|
||||||
self.raw_response += chunk_data
|
self.raw_response += chunk_data
|
||||||
else:
|
else:
|
||||||
self.raw_response += chunk_data
|
self.raw_response += chunk_data
|
||||||
|
elif chunk_type == ChatEvent.GENERATED_ASSETS:
|
||||||
|
chunk_data = chunk["data"]
|
||||||
|
if isinstance(chunk_data, dict):
|
||||||
|
for key in chunk_data:
|
||||||
|
if key == "images":
|
||||||
|
self.generated_images = chunk_data[key]
|
||||||
|
elif key == "files":
|
||||||
|
self.generated_files = chunk_data[key]
|
||||||
|
elif key == "excalidraw_diagrams":
|
||||||
|
self.generated_excalidraw_diagrams = chunk_data[key]
|
||||||
|
|
||||||
def handle_json_response(self, json_data: Dict[str, str]) -> str | Dict[str, str]:
|
def handle_json_response(self, json_data: Dict[str, str]) -> str | Dict[str, str]:
|
||||||
if "image" in json_data or "details" in json_data:
|
if "image" in json_data or "details" in json_data:
|
||||||
|
@ -1852,7 +1885,14 @@ async def read_chat_stream(response_iterator: AsyncGenerator[str, None]) -> Dict
|
||||||
if buffer:
|
if buffer:
|
||||||
processor.process_message_chunk(buffer)
|
processor.process_message_chunk(buffer)
|
||||||
|
|
||||||
return {"response": processor.raw_response, "references": processor.references, "usage": processor.usage}
|
return {
|
||||||
|
"response": processor.raw_response,
|
||||||
|
"references": processor.references,
|
||||||
|
"usage": processor.usage,
|
||||||
|
"images": processor.generated_images,
|
||||||
|
"files": processor.generated_files,
|
||||||
|
"excalidraw_diagrams": processor.generated_excalidraw_diagrams,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False):
|
def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False):
|
||||||
|
|
|
@ -22,7 +22,6 @@ from khoj.processor.conversation.offline.chat_model import (
|
||||||
filter_questions,
|
filter_questions,
|
||||||
)
|
)
|
||||||
from khoj.processor.conversation.offline.utils import download_model
|
from khoj.processor.conversation.offline.utils import download_model
|
||||||
from khoj.processor.conversation.utils import message_to_log
|
|
||||||
from khoj.utils.constants import default_offline_chat_models
|
from khoj.utils.constants import default_offline_chat_models
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -6,7 +6,6 @@ from freezegun import freeze_time
|
||||||
|
|
||||||
from khoj.database.models import Agent, Entry, KhojUser
|
from khoj.database.models import Agent, Entry, KhojUser
|
||||||
from khoj.processor.conversation import prompts
|
from khoj.processor.conversation import prompts
|
||||||
from khoj.processor.conversation.utils import message_to_log
|
|
||||||
from tests.helpers import ConversationFactory, generate_chat_history, get_chat_api_key
|
from tests.helpers import ConversationFactory, generate_chat_history, get_chat_api_key
|
||||||
|
|
||||||
# Initialize variables for tests
|
# Initialize variables for tests
|
||||||
|
|
Loading…
Reference in a new issue