Rename SearchModel to SearchModelConfig DB model, Require Cross-Encoder

This commit is contained in:
Debanjum Singh Solanky 2023-11-15 17:12:54 -08:00
parent 0679b2a7bd
commit 08a057bdd5
6 changed files with 43 additions and 15 deletions

View file

@ -31,7 +31,7 @@ from database.models import (
GithubRepoConfig, GithubRepoConfig,
Conversation, Conversation,
ChatModelOptions, ChatModelOptions,
SearchModel, SearchModelConfig,
Subscription, Subscription,
UserConversationConfig, UserConversationConfig,
OpenAIProcessorConversationConfig, OpenAIProcessorConversationConfig,
@ -216,9 +216,9 @@ async def set_user_github_config(user: KhojUser, pat_token: str, repos: list):
def get_or_create_search_model(): def get_or_create_search_model():
search_model = SearchModel.objects.filter().first() search_model = SearchModelConfig.objects.filter().first()
if not search_model: if not search_model:
search_model = SearchModel.objects.create() search_model = SearchModelConfig.objects.create()
return search_model return search_model

View file

@ -8,7 +8,7 @@ from database.models import (
ChatModelOptions, ChatModelOptions,
OpenAIProcessorConversationConfig, OpenAIProcessorConversationConfig,
OfflineChatProcessorConversationConfig, OfflineChatProcessorConversationConfig,
SearchModel, SearchModelConfig,
Subscription, Subscription,
) )
@ -17,5 +17,5 @@ admin.site.register(KhojUser, UserAdmin)
admin.site.register(ChatModelOptions) admin.site.register(ChatModelOptions)
admin.site.register(OpenAIProcessorConversationConfig) admin.site.register(OpenAIProcessorConversationConfig)
admin.site.register(OfflineChatProcessorConversationConfig) admin.site.register(OfflineChatProcessorConversationConfig)
admin.site.register(SearchModel) admin.site.register(SearchModelConfig)
admin.site.register(Subscription) admin.site.register(Subscription)

View file

@ -0,0 +1,30 @@
# Generated by Django 4.2.5 on 2023-11-16 01:13
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("database", "0017_searchmodel"),
]
operations = [
migrations.CreateModel(
name="SearchModelConfig",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("name", models.CharField(default="default", max_length=200)),
("model_type", models.CharField(choices=[("text", "Text")], default="text", max_length=200)),
("bi_encoder", models.CharField(default="thenlper/gte-small", max_length=200)),
("cross_encoder", models.CharField(default="cross-encoder/ms-marco-MiniLM-L-6-v2", max_length=200)),
],
options={
"abstract": False,
},
),
migrations.DeleteModel(
name="SearchModel",
),
]

View file

@ -102,16 +102,14 @@ class LocalPlaintextConfig(BaseModel):
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE) user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
class SearchModel(BaseModel): class SearchModelConfig(BaseModel):
class ModelType(models.TextChoices): class ModelType(models.TextChoices):
TEXT = "text" TEXT = "text"
name = models.CharField(max_length=200, default="default") name = models.CharField(max_length=200, default="default")
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.TEXT) model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.TEXT)
bi_encoder = models.CharField(max_length=200, default="thenlper/gte-small") bi_encoder = models.CharField(max_length=200, default="thenlper/gte-small")
cross_encoder = models.CharField( cross_encoder = models.CharField(max_length=200, default="cross-encoder/ms-marco-MiniLM-L-6-v2")
max_length=200, default="cross-encoder/ms-marco-MiniLM-L-6-v2", null=True, blank=True
)
class OpenAIProcessorConversationConfig(BaseModel): class OpenAIProcessorConversationConfig(BaseModel):

View file

@ -64,7 +64,7 @@ from database.models import (
OpenAIProcessorConversationConfig, OpenAIProcessorConversationConfig,
OfflineChatProcessorConversationConfig, OfflineChatProcessorConversationConfig,
ChatModelOptions, ChatModelOptions,
SearchModel, SearchModelConfig,
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -87,12 +87,12 @@ def migrate_server_pg(args):
if "search-type" in raw_config and raw_config["search-type"]: if "search-type" in raw_config and raw_config["search-type"]:
if "asymmetric" in raw_config["search-type"]: if "asymmetric" in raw_config["search-type"]:
# Delete all existing search models # Delete all existing search models
SearchModel.objects.filter(model_type=SearchModel.ModelType.TEXT).delete() SearchModelConfig.objects.filter(model_type=SearchModelConfig.ModelType.TEXT).delete()
# Create new search model from existing Khoj YAML config # Create new search model from existing Khoj YAML config
asymmetric_search = raw_config["search-type"]["asymmetric"] asymmetric_search = raw_config["search-type"]["asymmetric"]
SearchModel.objects.create( SearchModelConfig.objects.create(
name="default", name="default",
model_type=SearchModel.ModelType.TEXT, model_type=SearchModelConfig.ModelType.TEXT,
bi_encoder=asymmetric_search.get("encoder"), bi_encoder=asymmetric_search.get("encoder"),
cross_encoder=asymmetric_search.get("cross-encoder"), cross_encoder=asymmetric_search.get("cross-encoder"),
) )

View file

@ -7,7 +7,7 @@ from database.models import (
ChatModelOptions, ChatModelOptions,
OfflineChatProcessorConversationConfig, OfflineChatProcessorConversationConfig,
OpenAIProcessorConversationConfig, OpenAIProcessorConversationConfig,
SearchModel, SearchModelConfig,
UserConversationConfig, UserConversationConfig,
Conversation, Conversation,
Subscription, Subscription,
@ -74,7 +74,7 @@ class ConversationFactory(factory.django.DjangoModelFactory):
class SearchModelFactory(factory.django.DjangoModelFactory): class SearchModelFactory(factory.django.DjangoModelFactory):
class Meta: class Meta:
model = SearchModel model = SearchModelConfig
name = "default" name = "default"
model_type = "text" model_type = "text"