mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-27 17:35:07 +01:00
Store turn id with each chat message. Expose API to delete chat turn
Each chat turn is a user query, khoj response message pair
This commit is contained in:
parent
f64f5b3b6e
commit
ba15686682
4 changed files with 34 additions and 0 deletions
|
@ -1348,6 +1348,17 @@ class ConversationAdapters:
|
||||||
conversation.save()
|
conversation.save()
|
||||||
return conversation.file_filters
|
return conversation.file_filters
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def delete_message_by_turn_id(user: KhojUser, conversation_id: str, turn_id: str):
|
||||||
|
conversation = ConversationAdapters.get_conversation_by_user(user, conversation_id=conversation_id)
|
||||||
|
if not conversation:
|
||||||
|
return False
|
||||||
|
conversation_log = conversation.conversation_log
|
||||||
|
updated_log = [msg for msg in conversation_log if msg.get("turnId") != turn_id]
|
||||||
|
conversation.conversation_log = updated_log
|
||||||
|
conversation.save()
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
class FileObjectAdapters:
|
class FileObjectAdapters:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
|
@ -4,6 +4,7 @@ import math
|
||||||
import mimetypes
|
import mimetypes
|
||||||
import os
|
import os
|
||||||
import queue
|
import queue
|
||||||
|
import uuid
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
@ -232,12 +233,14 @@ def save_to_conversation_log(
|
||||||
train_of_thought: List[Any] = [],
|
train_of_thought: List[Any] = [],
|
||||||
):
|
):
|
||||||
user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
turn_id = tracer.get("mid") or str(uuid.uuid4())
|
||||||
updated_conversation = message_to_log(
|
updated_conversation = message_to_log(
|
||||||
user_message=q,
|
user_message=q,
|
||||||
chat_response=chat_response,
|
chat_response=chat_response,
|
||||||
user_message_metadata={
|
user_message_metadata={
|
||||||
"created": user_message_time,
|
"created": user_message_time,
|
||||||
"images": query_images,
|
"images": query_images,
|
||||||
|
"turnId": turn_id,
|
||||||
},
|
},
|
||||||
khoj_message_metadata={
|
khoj_message_metadata={
|
||||||
"context": compiled_references,
|
"context": compiled_references,
|
||||||
|
@ -246,6 +249,7 @@ def save_to_conversation_log(
|
||||||
"codeContext": code_results,
|
"codeContext": code_results,
|
||||||
"automationId": automation_id,
|
"automationId": automation_id,
|
||||||
"trainOfThought": train_of_thought,
|
"trainOfThought": train_of_thought,
|
||||||
|
"turnId": turn_id,
|
||||||
},
|
},
|
||||||
conversation_log=meta_log.get("chat", []),
|
conversation_log=meta_log.get("chat", []),
|
||||||
train_of_thought=train_of_thought,
|
train_of_thought=train_of_thought,
|
||||||
|
|
|
@ -38,6 +38,7 @@ from khoj.routers.helpers import (
|
||||||
ChatRequestBody,
|
ChatRequestBody,
|
||||||
CommonQueryParams,
|
CommonQueryParams,
|
||||||
ConversationCommandRateLimiter,
|
ConversationCommandRateLimiter,
|
||||||
|
DeleteMessageRequestBody,
|
||||||
agenerate_chat_response,
|
agenerate_chat_response,
|
||||||
aget_relevant_information_sources,
|
aget_relevant_information_sources,
|
||||||
aget_relevant_output_modes,
|
aget_relevant_output_modes,
|
||||||
|
@ -534,6 +535,19 @@ async def set_conversation_title(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@api_chat.delete("/conversation/message", response_class=Response)
|
||||||
|
@requires(["authenticated"])
|
||||||
|
def delete_message(request: Request, delete_request: DeleteMessageRequestBody) -> Response:
|
||||||
|
user = request.user.object
|
||||||
|
success = ConversationAdapters.delete_message_by_turn_id(
|
||||||
|
user, delete_request.conversation_id, delete_request.turn_id
|
||||||
|
)
|
||||||
|
if success:
|
||||||
|
return Response(content=json.dumps({"status": "ok"}), media_type="application/json", status_code=200)
|
||||||
|
else:
|
||||||
|
return Response(content=json.dumps({"status": "error", "message": "Message not found"}), status_code=404)
|
||||||
|
|
||||||
|
|
||||||
@api_chat.post("")
|
@api_chat.post("")
|
||||||
@requires(["authenticated"])
|
@requires(["authenticated"])
|
||||||
async def chat(
|
async def chat(
|
||||||
|
|
|
@ -1264,6 +1264,11 @@ class ChatRequestBody(BaseModel):
|
||||||
create_new: Optional[bool] = False
|
create_new: Optional[bool] = False
|
||||||
|
|
||||||
|
|
||||||
|
class DeleteMessageRequestBody(BaseModel):
|
||||||
|
conversation_id: str
|
||||||
|
turn_id: str
|
||||||
|
|
||||||
|
|
||||||
class ApiUserRateLimiter:
|
class ApiUserRateLimiter:
|
||||||
def __init__(self, requests: int, subscribed_requests: int, window: int, slug: str):
|
def __init__(self, requests: int, subscribed_requests: int, window: int, slug: str):
|
||||||
self.requests = requests
|
self.requests = requests
|
||||||
|
|
Loading…
Reference in a new issue