mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-30 19:03:01 +01:00
Merge pull request #553 from khoj-ai/features/validation-errors
Update types of base config models for pydantic 2.0
This commit is contained in:
commit
ebdb423d3e
5 changed files with 48 additions and 34 deletions
|
@ -151,17 +151,20 @@ def truncate_messages(
|
||||||
)
|
)
|
||||||
|
|
||||||
system_message = messages.pop()
|
system_message = messages.pop()
|
||||||
|
assert type(system_message.content) == str
|
||||||
system_message_tokens = len(encoder.encode(system_message.content))
|
system_message_tokens = len(encoder.encode(system_message.content))
|
||||||
|
|
||||||
tokens = sum([len(encoder.encode(message.content)) for message in messages])
|
tokens = sum([len(encoder.encode(message.content)) for message in messages if type(message.content) == str])
|
||||||
while (tokens + system_message_tokens) > max_prompt_size and len(messages) > 1:
|
while (tokens + system_message_tokens) > max_prompt_size and len(messages) > 1:
|
||||||
messages.pop()
|
messages.pop()
|
||||||
tokens = sum([len(encoder.encode(message.content)) for message in messages])
|
assert type(system_message.content) == str
|
||||||
|
tokens = sum([len(encoder.encode(message.content)) for message in messages if type(message.content) == str])
|
||||||
|
|
||||||
# Truncate current message if still over max supported prompt size by model
|
# Truncate current message if still over max supported prompt size by model
|
||||||
if (tokens + system_message_tokens) > max_prompt_size:
|
if (tokens + system_message_tokens) > max_prompt_size:
|
||||||
current_message = "\n".join(messages[0].content.split("\n")[:-1])
|
assert type(system_message.content) == str
|
||||||
original_question = "\n".join(messages[0].content.split("\n")[-1:])
|
current_message = "\n".join(messages[0].content.split("\n")[:-1]) if type(messages[0].content) == str else ""
|
||||||
|
original_question = "\n".join(messages[0].content.split("\n")[-1:]) if type(messages[0].content) == str else ""
|
||||||
original_question_tokens = len(encoder.encode(original_question))
|
original_question_tokens = len(encoder.encode(original_question))
|
||||||
remaining_tokens = max_prompt_size - original_question_tokens - system_message_tokens
|
remaining_tokens = max_prompt_size - original_question_tokens - system_message_tokens
|
||||||
truncated_message = encoder.decode(encoder.encode(current_message)[:remaining_tokens]).strip()
|
truncated_message = encoder.decode(encoder.encode(current_message)[:remaining_tokens]).strip()
|
||||||
|
|
|
@ -296,7 +296,7 @@ async def get_all_filenames(
|
||||||
client=client,
|
client=client,
|
||||||
)
|
)
|
||||||
|
|
||||||
return await sync_to_async(list)(EntryAdapters.aget_all_filenames_by_source(user, content_source))
|
return await sync_to_async(list)(EntryAdapters.aget_all_filenames_by_source(user, content_source)) # type: ignore[call-arg]
|
||||||
|
|
||||||
|
|
||||||
@api.post("/config/data/conversation/model", status_code=200)
|
@api.post("/config/data/conversation/model", status_code=200)
|
||||||
|
|
|
@ -63,7 +63,7 @@ async def update(
|
||||||
request: Request,
|
request: Request,
|
||||||
files: list[UploadFile],
|
files: list[UploadFile],
|
||||||
force: bool = False,
|
force: bool = False,
|
||||||
t: Optional[Union[state.SearchType, str]] = None,
|
t: Optional[Union[state.SearchType, str]] = state.SearchType.All,
|
||||||
client: Optional[str] = None,
|
client: Optional[str] = None,
|
||||||
user_agent: Optional[str] = Header(None),
|
user_agent: Optional[str] = Header(None),
|
||||||
referer: Optional[str] = Header(None),
|
referer: Optional[str] = Header(None),
|
||||||
|
@ -182,13 +182,16 @@ def configure_content(
|
||||||
files: Optional[dict[str, dict[str, str]]],
|
files: Optional[dict[str, dict[str, str]]],
|
||||||
search_models: SearchModels,
|
search_models: SearchModels,
|
||||||
regenerate: bool = False,
|
regenerate: bool = False,
|
||||||
t: Optional[state.SearchType] = None,
|
t: Optional[state.SearchType] = state.SearchType.All,
|
||||||
full_corpus: bool = True,
|
full_corpus: bool = True,
|
||||||
user: KhojUser = None,
|
user: KhojUser = None,
|
||||||
) -> tuple[Optional[ContentIndex], bool]:
|
) -> tuple[Optional[ContentIndex], bool]:
|
||||||
content_index = ContentIndex()
|
content_index = ContentIndex()
|
||||||
|
|
||||||
success = True
|
success = True
|
||||||
|
if t is not None and t in [type.value for type in state.SearchType]:
|
||||||
|
t = state.SearchType(t)
|
||||||
|
|
||||||
if t is not None and not t.value in [type.value for type in state.SearchType]:
|
if t is not None and not t.value in [type.value for type in state.SearchType]:
|
||||||
logger.warning(f"🚨 Invalid search type: {t}")
|
logger.warning(f"🚨 Invalid search type: {t}")
|
||||||
return None, False
|
return None, False
|
||||||
|
@ -201,7 +204,7 @@ def configure_content(
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Initialize Org Notes Search
|
# Initialize Org Notes Search
|
||||||
if (search_type == None or search_type == state.SearchType.Org.value) and files["org"]:
|
if (search_type == state.SearchType.All.value or search_type == state.SearchType.Org.value) and files["org"]:
|
||||||
logger.info("🦄 Setting up search for orgmode notes")
|
logger.info("🦄 Setting up search for orgmode notes")
|
||||||
# Extract Entries, Generate Notes Embeddings
|
# Extract Entries, Generate Notes Embeddings
|
||||||
text_search.setup(
|
text_search.setup(
|
||||||
|
@ -217,7 +220,9 @@ def configure_content(
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Initialize Markdown Search
|
# Initialize Markdown Search
|
||||||
if (search_type == None or search_type == state.SearchType.Markdown.value) and files["markdown"]:
|
if (search_type == state.SearchType.All.value or search_type == state.SearchType.Markdown.value) and files[
|
||||||
|
"markdown"
|
||||||
|
]:
|
||||||
logger.info("💎 Setting up search for markdown notes")
|
logger.info("💎 Setting up search for markdown notes")
|
||||||
# Extract Entries, Generate Markdown Embeddings
|
# Extract Entries, Generate Markdown Embeddings
|
||||||
text_search.setup(
|
text_search.setup(
|
||||||
|
@ -234,7 +239,7 @@ def configure_content(
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Initialize PDF Search
|
# Initialize PDF Search
|
||||||
if (search_type == None or search_type == state.SearchType.Pdf.value) and files["pdf"]:
|
if (search_type == state.SearchType.All.value or search_type == state.SearchType.Pdf.value) and files["pdf"]:
|
||||||
logger.info("🖨️ Setting up search for pdf")
|
logger.info("🖨️ Setting up search for pdf")
|
||||||
# Extract Entries, Generate PDF Embeddings
|
# Extract Entries, Generate PDF Embeddings
|
||||||
text_search.setup(
|
text_search.setup(
|
||||||
|
@ -251,7 +256,9 @@ def configure_content(
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Initialize Plaintext Search
|
# Initialize Plaintext Search
|
||||||
if (search_type == None or search_type == state.SearchType.Plaintext.value) and files["plaintext"]:
|
if (search_type == state.SearchType.All.value or search_type == state.SearchType.Plaintext.value) and files[
|
||||||
|
"plaintext"
|
||||||
|
]:
|
||||||
logger.info("📄 Setting up search for plaintext")
|
logger.info("📄 Setting up search for plaintext")
|
||||||
# Extract Entries, Generate Plaintext Embeddings
|
# Extract Entries, Generate Plaintext Embeddings
|
||||||
text_search.setup(
|
text_search.setup(
|
||||||
|
@ -269,7 +276,7 @@ def configure_content(
|
||||||
try:
|
try:
|
||||||
# Initialize Image Search
|
# Initialize Image Search
|
||||||
if (
|
if (
|
||||||
(search_type == None or search_type == state.SearchType.Image.value)
|
(search_type == state.SearchType.All.value or search_type == state.SearchType.Image.value)
|
||||||
and content_config
|
and content_config
|
||||||
and content_config.image
|
and content_config.image
|
||||||
and search_models.image_search
|
and search_models.image_search
|
||||||
|
@ -286,7 +293,9 @@ def configure_content(
|
||||||
|
|
||||||
try:
|
try:
|
||||||
github_config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first()
|
github_config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first()
|
||||||
if (search_type == None or search_type == state.SearchType.Github.value) and github_config is not None:
|
if (
|
||||||
|
search_type == state.SearchType.All.value or search_type == state.SearchType.Github.value
|
||||||
|
) and github_config is not None:
|
||||||
logger.info("🐙 Setting up search for github")
|
logger.info("🐙 Setting up search for github")
|
||||||
# Extract Entries, Generate Github Embeddings
|
# Extract Entries, Generate Github Embeddings
|
||||||
text_search.setup(
|
text_search.setup(
|
||||||
|
@ -305,7 +314,9 @@ def configure_content(
|
||||||
try:
|
try:
|
||||||
# Initialize Notion Search
|
# Initialize Notion Search
|
||||||
notion_config = NotionConfig.objects.filter(user=user).first()
|
notion_config = NotionConfig.objects.filter(user=user).first()
|
||||||
if (search_type == None or search_type in state.SearchType.Notion.value) and notion_config:
|
if (
|
||||||
|
search_type == state.SearchType.All.value or search_type in state.SearchType.Notion.value
|
||||||
|
) and notion_config:
|
||||||
logger.info("🔌 Setting up search for notion")
|
logger.info("🔌 Setting up search for notion")
|
||||||
text_search.setup(
|
text_search.setup(
|
||||||
NotionToEntries,
|
NotionToEntries,
|
||||||
|
|
|
@ -229,7 +229,7 @@ def collate_results(hits, image_names, output_directory, image_files_url, count=
|
||||||
|
|
||||||
# Add the image metadata to the results
|
# Add the image metadata to the results
|
||||||
results += [
|
results += [
|
||||||
SearchResponse.parse_obj(
|
SearchResponse.model_validate(
|
||||||
{
|
{
|
||||||
"entry": f"{image_files_url}/{target_image_name}",
|
"entry": f"{image_files_url}/{target_image_name}",
|
||||||
"score": f"{hit['score']:.9f}",
|
"score": f"{hit['score']:.9f}",
|
||||||
|
@ -237,7 +237,7 @@ def collate_results(hits, image_names, output_directory, image_files_url, count=
|
||||||
"image_score": f"{hit['image_score']:.9f}",
|
"image_score": f"{hit['image_score']:.9f}",
|
||||||
"metadata_score": f"{hit['metadata_score']:.9f}",
|
"metadata_score": f"{hit['metadata_score']:.9f}",
|
||||||
},
|
},
|
||||||
"corpus_id": hit["corpus_id"],
|
"corpus_id": str(hit["corpus_id"]),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
|
@ -14,7 +14,7 @@ from khoj.utils.helpers import to_snake_case_from_dash
|
||||||
class ConfigBase(BaseModel):
|
class ConfigBase(BaseModel):
|
||||||
class Config:
|
class Config:
|
||||||
alias_generator = to_snake_case_from_dash
|
alias_generator = to_snake_case_from_dash
|
||||||
allow_population_by_field_name = True
|
populate_by_name = True
|
||||||
|
|
||||||
def __getitem__(self, item):
|
def __getitem__(self, item):
|
||||||
return getattr(self, item)
|
return getattr(self, item)
|
||||||
|
@ -29,8 +29,8 @@ class TextConfigBase(ConfigBase):
|
||||||
|
|
||||||
|
|
||||||
class TextContentConfig(ConfigBase):
|
class TextContentConfig(ConfigBase):
|
||||||
input_files: Optional[List[Path]]
|
input_files: Optional[List[Path]] = None
|
||||||
input_filter: Optional[List[str]]
|
input_filter: Optional[List[str]] = None
|
||||||
index_heading_entries: Optional[bool] = False
|
index_heading_entries: Optional[bool] = False
|
||||||
|
|
||||||
|
|
||||||
|
@ -50,31 +50,31 @@ class NotionContentConfig(ConfigBase):
|
||||||
|
|
||||||
|
|
||||||
class ImageContentConfig(ConfigBase):
|
class ImageContentConfig(ConfigBase):
|
||||||
input_directories: Optional[List[Path]]
|
input_directories: Optional[List[Path]] = None
|
||||||
input_filter: Optional[List[str]]
|
input_filter: Optional[List[str]] = None
|
||||||
embeddings_file: Path
|
embeddings_file: Path
|
||||||
use_xmp_metadata: bool
|
use_xmp_metadata: bool
|
||||||
batch_size: int
|
batch_size: int
|
||||||
|
|
||||||
|
|
||||||
class ContentConfig(ConfigBase):
|
class ContentConfig(ConfigBase):
|
||||||
org: Optional[TextContentConfig]
|
org: Optional[TextContentConfig] = None
|
||||||
image: Optional[ImageContentConfig]
|
image: Optional[ImageContentConfig] = None
|
||||||
markdown: Optional[TextContentConfig]
|
markdown: Optional[TextContentConfig] = None
|
||||||
pdf: Optional[TextContentConfig]
|
pdf: Optional[TextContentConfig] = None
|
||||||
plaintext: Optional[TextContentConfig]
|
plaintext: Optional[TextContentConfig] = None
|
||||||
github: Optional[GithubContentConfig]
|
github: Optional[GithubContentConfig] = None
|
||||||
notion: Optional[NotionContentConfig]
|
notion: Optional[NotionContentConfig] = None
|
||||||
|
|
||||||
|
|
||||||
class ImageSearchConfig(ConfigBase):
|
class ImageSearchConfig(ConfigBase):
|
||||||
encoder: str
|
encoder: str
|
||||||
encoder_type: Optional[str]
|
encoder_type: Optional[str] = None
|
||||||
model_directory: Optional[Path]
|
model_directory: Optional[Path] = None
|
||||||
|
|
||||||
|
|
||||||
class SearchConfig(ConfigBase):
|
class SearchConfig(ConfigBase):
|
||||||
image: Optional[ImageSearchConfig]
|
image: Optional[ImageSearchConfig] = None
|
||||||
|
|
||||||
|
|
||||||
class OpenAIProcessorConfig(ConfigBase):
|
class OpenAIProcessorConfig(ConfigBase):
|
||||||
|
@ -95,7 +95,7 @@ class ConversationProcessorConfig(ConfigBase):
|
||||||
|
|
||||||
|
|
||||||
class ProcessorConfig(ConfigBase):
|
class ProcessorConfig(ConfigBase):
|
||||||
conversation: Optional[ConversationProcessorConfig]
|
conversation: Optional[ConversationProcessorConfig] = None
|
||||||
|
|
||||||
|
|
||||||
class AppConfig(ConfigBase):
|
class AppConfig(ConfigBase):
|
||||||
|
@ -113,8 +113,8 @@ class FullConfig(ConfigBase):
|
||||||
class SearchResponse(ConfigBase):
|
class SearchResponse(ConfigBase):
|
||||||
entry: str
|
entry: str
|
||||||
score: float
|
score: float
|
||||||
cross_score: Optional[float]
|
cross_score: Optional[float] = None
|
||||||
additional: Optional[dict]
|
additional: Optional[dict] = None
|
||||||
corpus_id: str
|
corpus_id: str
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue