mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-30 19:03:01 +01:00
Merge pull request #250 from khoj-ai/features/github-multi-repo-and-more
Support multiple Github repositories and support indexing of multiple file types
This commit is contained in:
commit
c0d35bafdd
12 changed files with 257 additions and 84 deletions
|
@ -84,12 +84,21 @@
|
||||||
window.onload = function () {
|
window.onload = function () {
|
||||||
fetch('/api/chat?client=web')
|
fetch('/api/chat?client=web')
|
||||||
.then(response => response.json())
|
.then(response => response.json())
|
||||||
.then(data => data.response)
|
.then(data => {
|
||||||
.then(chat_logs => {
|
if (data.detail) {
|
||||||
|
// If the server returns a 500 error with detail, render it as a message.
|
||||||
|
renderMessage(data.detail + " You can configure Khoj chat in your <a class='inline-chat-link' href='/config'>settings</a>.", "khoj");
|
||||||
|
}
|
||||||
|
return data.response;
|
||||||
|
})
|
||||||
|
.then(response => {
|
||||||
// Render conversation history, if any
|
// Render conversation history, if any
|
||||||
chat_logs.forEach(chat_log => {
|
response.forEach(chat_log => {
|
||||||
renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created));
|
renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created));
|
||||||
});
|
});
|
||||||
|
})
|
||||||
|
.catch(err => {
|
||||||
|
return;
|
||||||
});
|
});
|
||||||
|
|
||||||
// Set welcome message on load
|
// Set welcome message on load
|
||||||
|
@ -235,6 +244,12 @@
|
||||||
font-size: medium;
|
font-size: medium;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
a.inline-chat-link {
|
||||||
|
color: #475569;
|
||||||
|
text-decoration: none;
|
||||||
|
border-bottom: 1px dotted #475569;
|
||||||
|
}
|
||||||
|
|
||||||
@media (pointer: coarse), (hover: none) {
|
@media (pointer: coarse), (hover: none) {
|
||||||
abbr[title] {
|
abbr[title] {
|
||||||
position: relative;
|
position: relative;
|
||||||
|
|
|
@ -16,31 +16,25 @@
|
||||||
<input type="text" id="pat-token" name="pat" value="{{ current_config['pat_token'] }}">
|
<input type="text" id="pat-token" name="pat" value="{{ current_config['pat_token'] }}">
|
||||||
</td>
|
</td>
|
||||||
</tr>
|
</tr>
|
||||||
<tr>
|
|
||||||
<td>
|
|
||||||
<label for="repo-owner">Repository Owner</label>
|
|
||||||
</td>
|
|
||||||
<td>
|
|
||||||
<input type="text" id="repo-owner" name="repo_owner" value="{{ current_config['repo_owner'] }}">
|
|
||||||
</td>
|
|
||||||
</tr>
|
|
||||||
<tr>
|
|
||||||
<td>
|
|
||||||
<label for="repo-name">Repository Name</label>
|
|
||||||
</td>
|
|
||||||
<td>
|
|
||||||
<input type="text" id="repo-name" name="repo_name" value="{{ current_config['repo_name'] }}">
|
|
||||||
</td>
|
|
||||||
</tr>
|
|
||||||
<tr>
|
|
||||||
<td>
|
|
||||||
<label for="repo-branch">Repository Branch</label>
|
|
||||||
</td>
|
|
||||||
<td>
|
|
||||||
<input type="text" id="repo-branch" name="repo_branch" value="{{ current_config['repo_branch'] }}">
|
|
||||||
</td>
|
|
||||||
</tr>
|
|
||||||
</table>
|
</table>
|
||||||
|
<h4>Repositories</h4>
|
||||||
|
<div id="repositories" class="section-cards">
|
||||||
|
{% for repo in current_config['repos'] %}
|
||||||
|
<div class="card repo" id="repo-card-{{loop.index}}">
|
||||||
|
<label for="repo-owner">Repository Owner</label>
|
||||||
|
<input type="text" id="repo-owner-{{loop.index}}" name="repo_owner" value="{{ repo.owner }}">
|
||||||
|
<label for="repo-name">Repository Name</label>
|
||||||
|
<input type="text" id="repo-name-{{loop.index}}" name="repo_name" value="{{ repo.name}}">
|
||||||
|
<label for="repo-branch">Repository Branch</label>
|
||||||
|
<input type="text" id="repo-branch-{{loop.index}}" name="repo_branch" value="{{ repo.branch }}">
|
||||||
|
<button type="button"
|
||||||
|
class="remove-repo-button"
|
||||||
|
onclick="remove_repo({{loop.index}})"
|
||||||
|
id="remove-repo-button-{{loop.index}}">Remove Repository</button>
|
||||||
|
</div>
|
||||||
|
{% endfor %}
|
||||||
|
</div>
|
||||||
|
<button type="button" id="add-repository-button">Add Repository</button>
|
||||||
<h4>You probably don't need to edit these.</h4>
|
<h4>You probably don't need to edit these.</h4>
|
||||||
|
|
||||||
<table>
|
<table>
|
||||||
|
@ -68,16 +62,86 @@
|
||||||
</form>
|
</form>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
<style>
|
||||||
|
div.repo {
|
||||||
|
width: 100%;
|
||||||
|
height: 100%;
|
||||||
|
grid-template-rows: none;
|
||||||
|
}
|
||||||
|
div#repositories {
|
||||||
|
margin-bottom: 12px;
|
||||||
|
}
|
||||||
|
button.remove-repo-button {
|
||||||
|
background-color: gainsboro;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
<script>
|
<script>
|
||||||
|
const add_repo_button = document.getElementById("add-repository-button");
|
||||||
|
add_repo_button.addEventListener("click", function(event) {
|
||||||
|
event.preventDefault();
|
||||||
|
var repo = document.createElement("div");
|
||||||
|
repo.classList.add("card");
|
||||||
|
repo.classList.add("repo");
|
||||||
|
const id = Date.now();
|
||||||
|
repo.id = "repo-card-" + id;
|
||||||
|
repo.innerHTML = `
|
||||||
|
<label for="repo-owner">Repository Owner</label>
|
||||||
|
<input type="text" id="repo-owner" name="repo_owner">
|
||||||
|
<label for="repo-name">Repository Name</label>
|
||||||
|
<input type="text" id="repo-name" name="repo_name">
|
||||||
|
<label for="repo-branch">Repository Branch</label>
|
||||||
|
<input type="text" id="repo-branch" name="repo_branch">
|
||||||
|
<button type="button"
|
||||||
|
class="remove-repo-button"
|
||||||
|
onclick="remove_repo(${id})"
|
||||||
|
id="remove-repo-button-${id}">Remove Repository</button>
|
||||||
|
`;
|
||||||
|
document.getElementById("repositories").appendChild(repo);
|
||||||
|
})
|
||||||
|
|
||||||
|
function remove_repo(index) {
|
||||||
|
document.getElementById("repo-card-" + index).remove();
|
||||||
|
}
|
||||||
|
|
||||||
submit.addEventListener("click", function(event) {
|
submit.addEventListener("click", function(event) {
|
||||||
event.preventDefault();
|
event.preventDefault();
|
||||||
|
|
||||||
var compressed_jsonl = document.getElementById("compressed-jsonl").value;
|
const compressed_jsonl = document.getElementById("compressed-jsonl").value;
|
||||||
var embeddings_file = document.getElementById("embeddings-file").value;
|
const embeddings_file = document.getElementById("embeddings-file").value;
|
||||||
var pat_token = document.getElementById("pat-token").value;
|
const pat_token = document.getElementById("pat-token").value;
|
||||||
var repo_owner = document.getElementById("repo-owner").value;
|
|
||||||
var repo_name = document.getElementById("repo-name").value;
|
if (pat_token == "") {
|
||||||
var repo_branch = document.getElementById("repo-branch").value;
|
document.getElementById("success").innerHTML = "❌ Please enter a Personal Access Token.";
|
||||||
|
document.getElementById("success").style.display = "block";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
var cards = document.getElementById("repositories").getElementsByClassName("repo");
|
||||||
|
var repos = [];
|
||||||
|
|
||||||
|
for (var i = 0; i < cards.length; i++) {
|
||||||
|
var card = cards[i];
|
||||||
|
var owner = card.getElementsByTagName("input")[0].value;
|
||||||
|
var name = card.getElementsByTagName("input")[1].value;
|
||||||
|
var branch = card.getElementsByTagName("input")[2].value;
|
||||||
|
|
||||||
|
if (owner == "" || name == "" || branch == "") {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
repos.push({
|
||||||
|
"owner": owner,
|
||||||
|
"name": name,
|
||||||
|
"branch": branch,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if (repos.length == 0) {
|
||||||
|
document.getElementById("success").innerHTML = "❌ Please add at least one repository.";
|
||||||
|
document.getElementById("success").style.display = "block";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
|
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
|
||||||
fetch('/api/config/data/content_type/github', {
|
fetch('/api/config/data/content_type/github', {
|
||||||
|
@ -88,9 +152,7 @@
|
||||||
},
|
},
|
||||||
body: JSON.stringify({
|
body: JSON.stringify({
|
||||||
"pat_token": pat_token,
|
"pat_token": pat_token,
|
||||||
"repo_owner": repo_owner,
|
"repos": repos,
|
||||||
"repo_name": repo_name,
|
|
||||||
"repo_branch": repo_branch,
|
|
||||||
"compressed_jsonl": compressed_jsonl,
|
"compressed_jsonl": compressed_jsonl,
|
||||||
"embeddings_file": embeddings_file,
|
"embeddings_file": embeddings_file,
|
||||||
})
|
})
|
||||||
|
|
|
@ -57,6 +57,27 @@
|
||||||
}).join("\n") + `</div>`;
|
}).join("\n") + `</div>`;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function render_mutliple(query, data, type) {
|
||||||
|
let org_files = data.filter((item) => item.additional.file.endsWith(".org"));
|
||||||
|
let md_files = data.filter((item) => item.additional.file.endsWith(".md"));
|
||||||
|
let pdf_files = data.filter((item) => item.additional.file.endsWith(".pdf"));
|
||||||
|
|
||||||
|
let html = "";
|
||||||
|
if (org_files.length > 0) {
|
||||||
|
html += render_org(query, org_files, type);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (md_files.length > 0) {
|
||||||
|
html += render_markdown(query, md_files);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (pdf_files.length > 0) {
|
||||||
|
html += render_pdf(query, pdf_files);
|
||||||
|
}
|
||||||
|
|
||||||
|
return html;
|
||||||
|
}
|
||||||
|
|
||||||
function render_json(data, query, type) {
|
function render_json(data, query, type) {
|
||||||
if (type === "markdown") {
|
if (type === "markdown") {
|
||||||
return render_markdown(query, data);
|
return render_markdown(query, data);
|
||||||
|
@ -71,7 +92,7 @@
|
||||||
} else if (type === "pdf") {
|
} else if (type === "pdf") {
|
||||||
return render_pdf(query, data);
|
return render_pdf(query, data);
|
||||||
} else if (type == "github") {
|
} else if (type == "github") {
|
||||||
return render_markdown(query, data);
|
return render_mutliple(query, data, type);
|
||||||
} else {
|
} else {
|
||||||
return `<div id="results-plugin">`
|
return `<div id="results-plugin">`
|
||||||
+ data.map((item) => `<p>${item.entry}</p>`).join("\n")
|
+ data.map((item) => `<p>${item.entry}</p>`).join("\n")
|
||||||
|
|
|
@ -8,8 +8,9 @@ import requests
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.utils.helpers import timer
|
from khoj.utils.helpers import timer
|
||||||
from khoj.utils.rawconfig import Entry, GithubContentConfig
|
from khoj.utils.rawconfig import Entry, GithubContentConfig, GithubRepoConfig
|
||||||
from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
|
from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
|
||||||
|
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
|
||||||
from khoj.processor.text_to_jsonl import TextToJsonl
|
from khoj.processor.text_to_jsonl import TextToJsonl
|
||||||
from khoj.utils.jsonl import dump_jsonl, compress_jsonl_data
|
from khoj.utils.jsonl import dump_jsonl, compress_jsonl_data
|
||||||
|
|
||||||
|
@ -21,7 +22,6 @@ class GithubToJsonl(TextToJsonl):
|
||||||
def __init__(self, config: GithubContentConfig):
|
def __init__(self, config: GithubContentConfig):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.config = config
|
self.config = config
|
||||||
self.repo_url = f"https://api.github.com/repos/{self.config.repo_owner}/{self.config.repo_name}"
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def wait_for_rate_limit_reset(response, func, *args, **kwargs):
|
def wait_for_rate_limit_reset(response, func, *args, **kwargs):
|
||||||
|
@ -34,26 +34,43 @@ class GithubToJsonl(TextToJsonl):
|
||||||
return
|
return
|
||||||
|
|
||||||
def process(self, previous_entries=None):
|
def process(self, previous_entries=None):
|
||||||
|
current_entries = []
|
||||||
|
for repo in self.config.repos:
|
||||||
|
current_entries += self.process_repo(repo, previous_entries)
|
||||||
|
|
||||||
|
return self.update_entries_with_ids(current_entries, previous_entries)
|
||||||
|
|
||||||
|
def process_repo(self, repo: GithubRepoConfig, previous_entries=None):
|
||||||
|
repo_url = f"https://api.github.com/repos/{repo.owner}/{repo.name}"
|
||||||
|
repo_shorthand = f"{repo.owner}/{repo.name}"
|
||||||
|
logger.info(f"Processing github repo {repo_shorthand}")
|
||||||
with timer("Download markdown files from github repo", logger):
|
with timer("Download markdown files from github repo", logger):
|
||||||
try:
|
try:
|
||||||
docs = self.get_markdown_files()
|
markdown_files, org_files = self.get_files(repo_url, repo)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Unable to download github repo {self.config.repo_owner}/{self.config.repo_name}")
|
logger.error(f"Unable to download github repo {repo_shorthand}")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
logger.info(f"Found {len(docs)} documents in github repo {self.config.repo_owner}/{self.config.repo_name}")
|
logger.info(f"Found {len(markdown_files)} markdown files in github repo {repo_shorthand}")
|
||||||
|
logger.info(f"Found {len(org_files)} org files in github repo {repo_shorthand}")
|
||||||
|
|
||||||
with timer("Extract markdown entries from github repo", logger):
|
with timer(f"Extract markdown entries from github repo {repo_shorthand}", logger):
|
||||||
current_entries = MarkdownToJsonl.convert_markdown_entries_to_maps(
|
current_entries = MarkdownToJsonl.convert_markdown_entries_to_maps(
|
||||||
*GithubToJsonl.extract_markdown_entries(docs)
|
*GithubToJsonl.extract_markdown_entries(markdown_files)
|
||||||
)
|
)
|
||||||
|
|
||||||
with timer("Extract commit messages from github repo", logger):
|
with timer(f"Extract org entries from github repo {repo_shorthand}", logger):
|
||||||
current_entries += self.convert_commits_to_entries(self.get_commits())
|
current_entries += OrgToJsonl.convert_org_nodes_to_entries(*GithubToJsonl.extract_org_entries(org_files))
|
||||||
|
|
||||||
with timer("Split entries by max token size supported by model", logger):
|
with timer(f"Extract commit messages from github repo {repo_shorthand}", logger):
|
||||||
|
current_entries += self.convert_commits_to_entries(self.get_commits(repo_url), repo)
|
||||||
|
|
||||||
|
with timer(f"Split entries by max token size supported by model {repo_shorthand}", logger):
|
||||||
current_entries = TextToJsonl.split_entries_by_max_tokens(current_entries, max_tokens=256)
|
current_entries = TextToJsonl.split_entries_by_max_tokens(current_entries, max_tokens=256)
|
||||||
|
|
||||||
|
return current_entries
|
||||||
|
|
||||||
|
def update_entries_with_ids(self, current_entries, previous_entries):
|
||||||
# Identify, mark and merge any new entries with previous entries
|
# Identify, mark and merge any new entries with previous entries
|
||||||
with timer("Identify new or updated entries", logger):
|
with timer("Identify new or updated entries", logger):
|
||||||
if not previous_entries:
|
if not previous_entries:
|
||||||
|
@ -76,31 +93,40 @@ class GithubToJsonl(TextToJsonl):
|
||||||
|
|
||||||
return entries_with_ids
|
return entries_with_ids
|
||||||
|
|
||||||
def get_markdown_files(self):
|
def get_files(self, repo_url: str, repo: GithubRepoConfig):
|
||||||
# Get the contents of the repository
|
# Get the contents of the repository
|
||||||
repo_content_url = f"{self.repo_url}/git/trees/{self.config.repo_branch}"
|
repo_content_url = f"{repo_url}/git/trees/{repo.branch}"
|
||||||
headers = {"Authorization": f"token {self.config.pat_token}"}
|
headers = {"Authorization": f"token {self.config.pat_token}"}
|
||||||
params = {"recursive": "true"}
|
params = {"recursive": "true"}
|
||||||
response = requests.get(repo_content_url, headers=headers, params=params)
|
response = requests.get(repo_content_url, headers=headers, params=params)
|
||||||
contents = response.json()
|
contents = response.json()
|
||||||
|
|
||||||
# Wait for rate limit reset if needed
|
# Wait for rate limit reset if needed
|
||||||
result = self.wait_for_rate_limit_reset(response, self.get_markdown_files)
|
result = self.wait_for_rate_limit_reset(response, self.get_files)
|
||||||
if result is not None:
|
if result is not None:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# Extract markdown files from the repository
|
# Extract markdown files from the repository
|
||||||
markdown_files = []
|
markdown_files = []
|
||||||
|
org_files = []
|
||||||
for item in contents["tree"]:
|
for item in contents["tree"]:
|
||||||
# Find all markdown files in the repository
|
# Find all markdown files in the repository
|
||||||
if item["type"] == "blob" and item["path"].endswith(".md"):
|
if item["type"] == "blob" and item["path"].endswith(".md"):
|
||||||
# Create URL for each markdown file on Github
|
# Create URL for each markdown file on Github
|
||||||
url_path = f'https://github.com/{self.config.repo_owner}/{self.config.repo_name}/blob/{self.config.repo_branch}/{item["path"]}'
|
url_path = f'https://github.com/{repo.owner}/{repo.name}/blob/{repo.branch}/{item["path"]}'
|
||||||
|
|
||||||
# Add markdown file contents and URL to list
|
# Add markdown file contents and URL to list
|
||||||
markdown_files += [{"content": self.get_file_contents(item["url"]), "path": url_path}]
|
markdown_files += [{"content": self.get_file_contents(item["url"]), "path": url_path}]
|
||||||
|
|
||||||
return markdown_files
|
# Find all org files in the repository
|
||||||
|
elif item["type"] == "blob" and item["path"].endswith(".org"):
|
||||||
|
# Create URL for each org file on Github
|
||||||
|
url_path = f'https://github.com/{repo.owner}/{repo.name}/blob/{repo.branch}/{item["path"]}'
|
||||||
|
|
||||||
|
# Add org file contents and URL to list
|
||||||
|
org_files += [{"content": self.get_file_contents(item["url"]), "path": url_path}]
|
||||||
|
|
||||||
|
return markdown_files, org_files
|
||||||
|
|
||||||
def get_file_contents(self, file_url):
|
def get_file_contents(self, file_url):
|
||||||
# Get text from each markdown file
|
# Get text from each markdown file
|
||||||
|
@ -114,9 +140,9 @@ class GithubToJsonl(TextToJsonl):
|
||||||
|
|
||||||
return response.content.decode("utf-8")
|
return response.content.decode("utf-8")
|
||||||
|
|
||||||
def get_commits(self) -> List[Dict]:
|
def get_commits(self, repo_url: str) -> List[Dict]:
|
||||||
# Get commit messages from the repository using the Github API
|
# Get commit messages from the repository using the Github API
|
||||||
commits_url = f"{self.repo_url}/commits"
|
commits_url = f"{repo_url}/commits"
|
||||||
headers = {"Authorization": f"token {self.config.pat_token}"}
|
headers = {"Authorization": f"token {self.config.pat_token}"}
|
||||||
params = {"per_page": 100}
|
params = {"per_page": 100}
|
||||||
commits = []
|
commits = []
|
||||||
|
@ -140,10 +166,10 @@ class GithubToJsonl(TextToJsonl):
|
||||||
|
|
||||||
return commits
|
return commits
|
||||||
|
|
||||||
def convert_commits_to_entries(self, commits) -> List[Entry]:
|
def convert_commits_to_entries(self, commits, repo: GithubRepoConfig) -> List[Entry]:
|
||||||
entries: List[Entry] = []
|
entries: List[Entry] = []
|
||||||
for commit in commits:
|
for commit in commits:
|
||||||
compiled = f'Commit message from {self.config.repo_owner}/{self.config.repo_name}:\n{commit["content"]}'
|
compiled = f'Commit message from {repo.owner}/{repo.name}:\n{commit["content"]}'
|
||||||
entries.append(
|
entries.append(
|
||||||
Entry(
|
Entry(
|
||||||
compiled=compiled,
|
compiled=compiled,
|
||||||
|
@ -164,3 +190,14 @@ class GithubToJsonl(TextToJsonl):
|
||||||
doc["content"], doc["path"], entries, entry_to_file_map
|
doc["content"], doc["path"], entries, entry_to_file_map
|
||||||
)
|
)
|
||||||
return entries, dict(entry_to_file_map)
|
return entries, dict(entry_to_file_map)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def extract_org_entries(org_files):
|
||||||
|
entries = []
|
||||||
|
entry_to_file_map = []
|
||||||
|
|
||||||
|
for doc in org_files:
|
||||||
|
entries, entry_to_file_map = OrgToJsonl.process_single_org_file(
|
||||||
|
doc["content"], doc["path"], entries, entry_to_file_map
|
||||||
|
)
|
||||||
|
return entries, dict(entry_to_file_map)
|
||||||
|
|
|
@ -10,13 +10,17 @@ from khoj.processor.text_to_jsonl import TextToJsonl
|
||||||
from khoj.utils.helpers import get_absolute_path, is_none_or_empty, timer
|
from khoj.utils.helpers import get_absolute_path, is_none_or_empty, timer
|
||||||
from khoj.utils.constants import empty_escape_sequences
|
from khoj.utils.constants import empty_escape_sequences
|
||||||
from khoj.utils.jsonl import dump_jsonl, compress_jsonl_data
|
from khoj.utils.jsonl import dump_jsonl, compress_jsonl_data
|
||||||
from khoj.utils.rawconfig import Entry
|
from khoj.utils.rawconfig import Entry, TextContentConfig
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class MarkdownToJsonl(TextToJsonl):
|
class MarkdownToJsonl(TextToJsonl):
|
||||||
|
def __init__(self, config: TextContentConfig):
|
||||||
|
super().__init__(config)
|
||||||
|
self.config = config
|
||||||
|
|
||||||
# Define Functions
|
# Define Functions
|
||||||
def process(self, previous_entries=None):
|
def process(self, previous_entries=None):
|
||||||
# Extract required fields from config
|
# Extract required fields from config
|
||||||
|
|
|
@ -9,7 +9,7 @@ from khoj.processor.org_mode import orgnode
|
||||||
from khoj.processor.text_to_jsonl import TextToJsonl
|
from khoj.processor.text_to_jsonl import TextToJsonl
|
||||||
from khoj.utils.helpers import get_absolute_path, is_none_or_empty, timer
|
from khoj.utils.helpers import get_absolute_path, is_none_or_empty, timer
|
||||||
from khoj.utils.jsonl import dump_jsonl, compress_jsonl_data
|
from khoj.utils.jsonl import dump_jsonl, compress_jsonl_data
|
||||||
from khoj.utils.rawconfig import Entry
|
from khoj.utils.rawconfig import Entry, TextContentConfig
|
||||||
from khoj.utils import state
|
from khoj.utils import state
|
||||||
|
|
||||||
|
|
||||||
|
@ -17,6 +17,10 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class OrgToJsonl(TextToJsonl):
|
class OrgToJsonl(TextToJsonl):
|
||||||
|
def __init__(self, config: TextContentConfig):
|
||||||
|
super().__init__(config)
|
||||||
|
self.config = config
|
||||||
|
|
||||||
# Define Functions
|
# Define Functions
|
||||||
def process(self, previous_entries: List[Entry] = None):
|
def process(self, previous_entries: List[Entry] = None):
|
||||||
# Extract required fields from config
|
# Extract required fields from config
|
||||||
|
@ -96,12 +100,20 @@ class OrgToJsonl(TextToJsonl):
|
||||||
entries = []
|
entries = []
|
||||||
entry_to_file_map = []
|
entry_to_file_map = []
|
||||||
for org_file in org_files:
|
for org_file in org_files:
|
||||||
org_file_entries = orgnode.makelist(str(org_file))
|
org_file_entries = orgnode.makelist_with_filepath(str(org_file))
|
||||||
entry_to_file_map += zip(org_file_entries, [org_file] * len(org_file_entries))
|
entry_to_file_map += zip(org_file_entries, [org_file] * len(org_file_entries))
|
||||||
entries.extend(org_file_entries)
|
entries.extend(org_file_entries)
|
||||||
|
|
||||||
return entries, dict(entry_to_file_map)
|
return entries, dict(entry_to_file_map)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def process_single_org_file(org_content: str, org_file: str, entries: List, entry_to_file_map: List):
|
||||||
|
# Process single org file. The org parser assumes that the file is a single org file and reads it from a buffer. We'll split the raw conetnt of this file by new line to mimic the same behavior.
|
||||||
|
org_file_entries = orgnode.makelist(org_content.split("\n"), org_file)
|
||||||
|
entry_to_file_map += zip(org_file_entries, [org_file] * len(org_file_entries))
|
||||||
|
entries.extend(org_file_entries)
|
||||||
|
return entries, entry_to_file_map
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def convert_org_nodes_to_entries(
|
def convert_org_nodes_to_entries(
|
||||||
parsed_entries: List[orgnode.Orgnode], entry_to_file_map, index_heading_entries=False
|
parsed_entries: List[orgnode.Orgnode], entry_to_file_map, index_heading_entries=False
|
||||||
|
|
|
@ -53,14 +53,19 @@ def normalize_filename(filename):
|
||||||
return escaped_filename
|
return escaped_filename
|
||||||
|
|
||||||
|
|
||||||
def makelist(filename):
|
def makelist_with_filepath(filename):
|
||||||
|
f = open(filename, "r")
|
||||||
|
return makelist(f, filename)
|
||||||
|
|
||||||
|
|
||||||
|
def makelist(file, filename):
|
||||||
"""
|
"""
|
||||||
Read an org-mode file and return a list of Orgnode objects
|
Read an org-mode file and return a list of Orgnode objects
|
||||||
created from this file.
|
created from this file.
|
||||||
"""
|
"""
|
||||||
ctr = 0
|
ctr = 0
|
||||||
|
|
||||||
f = open(filename, "r")
|
f = file
|
||||||
|
|
||||||
todos = {
|
todos = {
|
||||||
"TODO": "",
|
"TODO": "",
|
||||||
|
|
|
@ -80,7 +80,12 @@ async def set_content_config_github_data(updated_config: GithubContentConfig):
|
||||||
if not state.config:
|
if not state.config:
|
||||||
state.config = FullConfig()
|
state.config = FullConfig()
|
||||||
state.config.search_type = SearchConfig.parse_obj(constants.default_config["search-type"])
|
state.config.search_type = SearchConfig.parse_obj(constants.default_config["search-type"])
|
||||||
state.config.content_type = ContentConfig(github=updated_config)
|
|
||||||
|
if not state.config.content_type:
|
||||||
|
state.config.content_type = ContentConfig(**{"github": updated_config})
|
||||||
|
else:
|
||||||
|
state.config.content_type.github = updated_config
|
||||||
|
|
||||||
try:
|
try:
|
||||||
save_config_to_file_updated_state()
|
save_config_to_file_updated_state()
|
||||||
return {"status": "ok"}
|
return {"status": "ok"}
|
||||||
|
@ -93,7 +98,12 @@ async def set_content_config_data(content_type: str, updated_config: TextContent
|
||||||
if not state.config:
|
if not state.config:
|
||||||
state.config = FullConfig()
|
state.config = FullConfig()
|
||||||
state.config.search_type = SearchConfig.parse_obj(constants.default_config["search-type"])
|
state.config.search_type = SearchConfig.parse_obj(constants.default_config["search-type"])
|
||||||
|
|
||||||
|
if not state.config.content_type:
|
||||||
state.config.content_type = ContentConfig(**{content_type: updated_config})
|
state.config.content_type = ContentConfig(**{content_type: updated_config})
|
||||||
|
else:
|
||||||
|
state.config.content_type[content_type] = updated_config
|
||||||
|
|
||||||
try:
|
try:
|
||||||
save_config_to_file_updated_state()
|
save_config_to_file_updated_state()
|
||||||
return {"status": "ok"}
|
return {"status": "ok"}
|
||||||
|
|
|
@ -49,9 +49,7 @@ default_config = {
|
||||||
},
|
},
|
||||||
"github": {
|
"github": {
|
||||||
"pat-token": None,
|
"pat-token": None,
|
||||||
"repo-name": None,
|
"repos": [],
|
||||||
"repo-owner": None,
|
|
||||||
"repo-branch": "master",
|
|
||||||
"compressed-jsonl": "~/.khoj/content/github/github.jsonl.gz",
|
"compressed-jsonl": "~/.khoj/content/github/github.jsonl.gz",
|
||||||
"embeddings-file": "~/.khoj/content/github/github_embeddings.pt",
|
"embeddings-file": "~/.khoj/content/github/github_embeddings.pt",
|
||||||
},
|
},
|
||||||
|
|
|
@ -41,11 +41,15 @@ class TextContentConfig(TextConfigBase):
|
||||||
return input_filter
|
return input_filter
|
||||||
|
|
||||||
|
|
||||||
|
class GithubRepoConfig(ConfigBase):
|
||||||
|
name: str
|
||||||
|
owner: str
|
||||||
|
branch: Optional[str] = "master"
|
||||||
|
|
||||||
|
|
||||||
class GithubContentConfig(TextConfigBase):
|
class GithubContentConfig(TextConfigBase):
|
||||||
pat_token: str
|
pat_token: str
|
||||||
repo_name: str
|
repos: List[GithubRepoConfig]
|
||||||
repo_owner: str
|
|
||||||
repo_branch: Optional[str] = "master"
|
|
||||||
|
|
||||||
|
|
||||||
class ImageContentConfig(ConfigBase):
|
class ImageContentConfig(ConfigBase):
|
||||||
|
|
|
@ -17,6 +17,7 @@ from khoj.utils.rawconfig import (
|
||||||
ProcessorConfig,
|
ProcessorConfig,
|
||||||
TextContentConfig,
|
TextContentConfig,
|
||||||
GithubContentConfig,
|
GithubContentConfig,
|
||||||
|
GithubRepoConfig,
|
||||||
ImageContentConfig,
|
ImageContentConfig,
|
||||||
SearchConfig,
|
SearchConfig,
|
||||||
TextSearchConfig,
|
TextSearchConfig,
|
||||||
|
@ -92,9 +93,13 @@ def content_config(tmp_path_factory, search_config: SearchConfig):
|
||||||
|
|
||||||
content_config.github = GithubContentConfig(
|
content_config.github = GithubContentConfig(
|
||||||
pat_token=os.getenv("GITHUB_PAT_TOKEN", ""),
|
pat_token=os.getenv("GITHUB_PAT_TOKEN", ""),
|
||||||
repo_name="lantern",
|
repos=[
|
||||||
repo_owner="khoj-ai",
|
GithubRepoConfig(
|
||||||
repo_branch="master",
|
owner="khoj-ai",
|
||||||
|
name="lantern",
|
||||||
|
branch="master",
|
||||||
|
)
|
||||||
|
],
|
||||||
compressed_jsonl=content_dir.joinpath("github.jsonl.gz"),
|
compressed_jsonl=content_dir.joinpath("github.jsonl.gz"),
|
||||||
embeddings_file=content_dir.joinpath("github_embeddings.pt"),
|
embeddings_file=content_dir.joinpath("github_embeddings.pt"),
|
||||||
)
|
)
|
||||||
|
|
|
@ -14,7 +14,7 @@ def test_parse_entry_with_no_headings(tmp_path):
|
||||||
orgfile = create_file(tmp_path, entry)
|
orgfile = create_file(tmp_path, entry)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
entries = orgnode.makelist(orgfile)
|
entries = orgnode.makelist_with_filepath(orgfile)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert len(entries) == 1
|
assert len(entries) == 1
|
||||||
|
@ -38,7 +38,7 @@ Body Line 1"""
|
||||||
orgfile = create_file(tmp_path, entry)
|
orgfile = create_file(tmp_path, entry)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
entries = orgnode.makelist(orgfile)
|
entries = orgnode.makelist_with_filepath(orgfile)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert len(entries) == 1
|
assert len(entries) == 1
|
||||||
|
@ -71,7 +71,7 @@ Body Line 2"""
|
||||||
orgfile = create_file(tmp_path, entry)
|
orgfile = create_file(tmp_path, entry)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
entries = orgnode.makelist(orgfile)
|
entries = orgnode.makelist_with_filepath(orgfile)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert len(entries) == 1
|
assert len(entries) == 1
|
||||||
|
@ -109,7 +109,7 @@ def test_render_entry_with_property_drawer_and_empty_body(tmp_path):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
parsed_entries = orgnode.makelist(orgfile)
|
parsed_entries = orgnode.makelist_with_filepath(orgfile)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert f"{parsed_entries[0]}" == expected_entry
|
assert f"{parsed_entries[0]}" == expected_entry
|
||||||
|
@ -131,7 +131,7 @@ Body Line 2
|
||||||
orgfile = create_file(tmp_path, entry)
|
orgfile = create_file(tmp_path, entry)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
entries = orgnode.makelist(orgfile)
|
entries = orgnode.makelist_with_filepath(orgfile)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
# SOURCE link rendered with Heading
|
# SOURCE link rendered with Heading
|
||||||
|
@ -155,7 +155,7 @@ Body Line 1"""
|
||||||
orgfile = create_file(tmp_path, entry, filename="test[1].org")
|
orgfile = create_file(tmp_path, entry, filename="test[1].org")
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
entries = orgnode.makelist(orgfile)
|
entries = orgnode.makelist_with_filepath(orgfile)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert len(entries) == 1
|
assert len(entries) == 1
|
||||||
|
@ -197,7 +197,7 @@ Body 2
|
||||||
orgfile = create_file(tmp_path, content)
|
orgfile = create_file(tmp_path, content)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
entries = orgnode.makelist(orgfile)
|
entries = orgnode.makelist_with_filepath(orgfile)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert len(entries) == 2
|
assert len(entries) == 2
|
||||||
|
@ -225,7 +225,7 @@ Body Line 1"""
|
||||||
orgfile = create_file(tmp_path, entry)
|
orgfile = create_file(tmp_path, entry)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
entries = orgnode.makelist(orgfile)
|
entries = orgnode.makelist_with_filepath(orgfile)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert len(entries) == 1
|
assert len(entries) == 1
|
||||||
|
@ -248,7 +248,7 @@ Body Line 1"""
|
||||||
orgfile = create_file(tmp_path, entry)
|
orgfile = create_file(tmp_path, entry)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
entries = orgnode.makelist(orgfile)
|
entries = orgnode.makelist_with_filepath(orgfile)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert len(entries) == 1
|
assert len(entries) == 1
|
||||||
|
@ -272,7 +272,7 @@ Body Line 1
|
||||||
orgfile = create_file(tmp_path, entry)
|
orgfile = create_file(tmp_path, entry)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
entries = orgnode.makelist(orgfile)
|
entries = orgnode.makelist_with_filepath(orgfile)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert len(entries) == 1
|
assert len(entries) == 1
|
||||||
|
@ -298,7 +298,7 @@ entry body
|
||||||
orgfile = create_file(tmp_path, body)
|
orgfile = create_file(tmp_path, body)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
entries = orgnode.makelist(orgfile)
|
entries = orgnode.makelist_with_filepath(orgfile)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert len(entries) == 2
|
assert len(entries) == 2
|
||||||
|
@ -320,7 +320,7 @@ entry body
|
||||||
orgfile = create_file(tmp_path, body)
|
orgfile = create_file(tmp_path, body)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
entries = orgnode.makelist(orgfile)
|
entries = orgnode.makelist_with_filepath(orgfile)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert len(entries) == 2
|
assert len(entries) == 2
|
||||||
|
|
Loading…
Reference in a new issue