Ollama agents ()

* add generic LMStudio agent support
"works" with non-tool-callable LLMs; highly dependent on system specs

* add comments

* enable few-shot prompting per function for OSS models

* Add Agent support for Ollama models

* improve json parsing for ollama text responses
Timothy Carambat 2024-05-07 18:06:31 -07:00 committed by GitHub
parent 1b4559f57f
commit 331d3741c9
10 changed files with 185 additions and 42 deletions

View file

@@ -5,8 +5,8 @@ import { AVAILABLE_LLM_PROVIDERS } from "@/pages/GeneralSettings/LLMPreference";
 import { CaretUpDown, Gauge, MagnifyingGlass, X } from "@phosphor-icons/react";
 import AgentModelSelection from "../AgentModelSelection";
 
-const ENABLED_PROVIDERS = ["openai", "anthropic", "lmstudio"];
-const WARN_PERFORMANCE = ["lmstudio"];
+const ENABLED_PROVIDERS = ["openai", "anthropic", "lmstudio", "ollama"];
+const WARN_PERFORMANCE = ["lmstudio", "ollama"];
 
 const LLM_DEFAULT = {
   name: "Please make a selection",

View file

@@ -46,6 +46,7 @@
     "dotenv": "^16.0.3",
     "express": "^4.18.2",
     "express-ws": "^5.0.2",
+    "extract-json-from-string": "^1.0.1",
     "extract-zip": "^2.0.1",
     "graphql": "^16.7.1",
     "joi": "^17.11.0",
@@ -59,6 +60,7 @@
     "multer": "^1.4.5-lts.1",
     "node-html-markdown": "^1.3.0",
     "node-llama-cpp": "^2.8.0",
+    "ollama": "^0.5.0",
     "openai": "4.38.5",
     "pinecone-client": "^1.1.0",
     "pluralize": "^8.0.0",

View file

@@ -741,6 +741,8 @@ ${this.getHistory({ to: route.to })
         return new Providers.AnthropicProvider({ model: config.model });
       case "lmstudio":
         return new Providers.LMStudioProvider({});
+      case "ollama":
+        return new Providers.OllamaProvider({ model: config.model });
       default:
         throw new Error(

View file

@@ -102,48 +102,34 @@ ${JSON.stringify(def.parameters.properties, null, 4)}\n`;
     return { valid: true, reason: null };
   }
 
-  async functionCall(messages, functions) {
+  async functionCall(messages, functions, chatCb = null) {
     const history = [...messages].filter((msg) =>
      ["user", "assistant"].includes(msg.role)
    );
    if (history[history.length - 1].role !== "user") return null;
-    const response = await this.client.chat.completions
-      .create({
-        model: this.model,
-        temperature: 0,
-        messages: [
-          {
-            content: `You are a program which picks the most optimal function and parameters to call.
-DO NOT HAVE TO PICK A FUNCTION IF IT WILL NOT HELP ANSWER OR FULFILL THE USER'S QUERY.
-When a function is selection, respond in JSON with no additional text.
-When there is no relevant function to call - return with a regular chat text response.
-Your task is to pick a **single** function that we will use to call, if any seem useful or relevant for the user query.
-
-All JSON responses should have two keys.
-'name': this is the name of the function name to call. eg: 'web-scraper', 'rag-memory', etc..
-'arguments': this is an object with the function properties to invoke the function.
-DO NOT INCLUDE ANY OTHER KEYS IN JSON RESPONSES.
-
-Here are the available tools you can use an examples of a query and response so you can understand how each one works.
-${this.showcaseFunctions(functions)}
-
-Now pick a function if there is an appropriate one to use given the last user message and the given conversation so far.`,
-            role: "system",
-          },
-          ...history,
-        ],
-      })
-      .then((result) => {
-        if (!result.hasOwnProperty("choices"))
-          throw new Error("LMStudio chat: No results!");
-        if (result.choices.length === 0)
-          throw new Error("LMStudio chat: No results length!");
-        return result.choices[0].message.content;
-      })
-      .catch((_) => {
-        return null;
-      });
+    const response = await chatCb({
+      messages: [
+        {
+          content: `You are a program which picks the most optimal function and parameters to call.
+DO NOT HAVE TO PICK A FUNCTION IF IT WILL NOT HELP ANSWER OR FULFILL THE USER'S QUERY.
+When a function is selection, respond in JSON with no additional text.
+When there is no relevant function to call - return with a regular chat text response.
+Your task is to pick a **single** function that we will use to call, if any seem useful or relevant for the user query.
+
+All JSON responses should have two keys.
+'name': this is the name of the function name to call. eg: 'web-scraper', 'rag-memory', etc..
+'arguments': this is an object with the function properties to invoke the function.
+DO NOT INCLUDE ANY OTHER KEYS IN JSON RESPONSES.
+
+Here are the available tools you can use an examples of a query and response so you can understand how each one works.
+${this.showcaseFunctions(functions)}
+
+Now pick a function if there is an appropriate one to use given the last user message and the given conversation so far.`,
+          role: "system",
+        },
+        ...history,
+      ],
+    });
     const call = safeJsonParse(response, null);
     if (call === null) return { toolCall: null, text: response }; // failed to parse, so must be text.
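
Note: with this change the untooled prompt is provider-agnostic, and the model is expected to reply either with a bare two-key JSON object or with ordinary chat text. A minimal sketch of the two reply shapes the parser has to handle (the "web-scraper" name comes from the prompt's own examples; the "url" argument is illustrative):

// Tool-call reply: exactly the two keys the system prompt asks for.
// safeJsonParse() returns the object and the caller runs the tool.
const toolCallReply = '{"name": "web-scraper", "arguments": {"url": "https://example.com"}}';

// Plain-text reply: safeJsonParse() falls back to null, so functionCall()
// returns { toolCall: null, text } and the text is treated as a normal chat answer.
const chatReply = "No tool is needed to answer that question.";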

View file

@@ -1,9 +1,11 @@
 const OpenAIProvider = require("./openai.js");
 const AnthropicProvider = require("./anthropic.js");
 const LMStudioProvider = require("./lmstudio.js");
+const OllamaProvider = require("./ollama.js");
 
 module.exports = {
   OpenAIProvider,
   AnthropicProvider,
   LMStudioProvider,
+  OllamaProvider,
 };

View file

@@ -27,6 +27,25 @@ class LMStudioProvider extends InheritMultiple([Provider, UnTooled]) {
     return this._client;
   }
 
+  async #handleFunctionCallChat({ messages = [] }) {
+    return await this.client.chat.completions
+      .create({
+        model: this.model,
+        temperature: 0,
+        messages,
+      })
+      .then((result) => {
+        if (!result.hasOwnProperty("choices"))
+          throw new Error("LMStudio chat: No results!");
+        if (result.choices.length === 0)
+          throw new Error("LMStudio chat: No results length!");
+        return result.choices[0].message.content;
+      })
+      .catch((_) => {
+        return null;
+      });
+  }
+
   /**
    * Create a completion based on the received messages.
    *
@@ -38,7 +57,11 @@ class LMStudioProvider extends InheritMultiple([Provider, UnTooled]) {
     try {
       let completion;
       if (functions.length > 0) {
-        const { toolCall, text } = await this.functionCall(messages, functions);
+        const { toolCall, text } = await this.functionCall(
+          messages,
+          functions,
+          this.#handleFunctionCallChat.bind(this)
+        );
 
         if (toolCall !== null) {
           this.providerLog(`Valid tool call found - running ${toolCall.name}.`);

View file

@@ -0,0 +1,107 @@
const Provider = require("./ai-provider.js");
const InheritMultiple = require("./helpers/classes.js");
const UnTooled = require("./helpers/untooled.js");
const { Ollama } = require("ollama");
/**
* The provider for the Ollama provider.
*/
class OllamaProvider extends InheritMultiple([Provider, UnTooled]) {
model;
constructor(config = {}) {
const {
// options = {},
model = null,
} = config;
super();
this._client = new Ollama({ host: process.env.OLLAMA_BASE_PATH });
this.model = model;
this.verbose = true;
}
get client() {
return this._client;
}
async #handleFunctionCallChat({ messages = [] }) {
const response = await this.client.chat({
model: this.model,
messages,
options: {
temperature: 0,
},
});
return response?.message?.content || null;
}
/**
* Create a completion based on the received messages.
*
* @param messages A list of messages to send to the API.
* @param functions
* @returns The completion.
*/
async complete(messages, functions = null) {
try {
let completion;
if (functions.length > 0) {
const { toolCall, text } = await this.functionCall(
messages,
functions,
this.#handleFunctionCallChat.bind(this)
);
if (toolCall !== null) {
this.providerLog(`Valid tool call found - running ${toolCall.name}.`);
this.deduplicator.trackRun(toolCall.name, toolCall.arguments);
return {
result: null,
functionCall: {
name: toolCall.name,
arguments: toolCall.arguments,
},
cost: 0,
};
}
completion = { content: text };
}
if (!completion?.content) {
this.providerLog(
"Will assume chat completion without tool call inputs."
);
const response = await this.client.chat({
model: this.model,
messages: this.cleanMsgs(messages),
options: {
use_mlock: true,
temperature: 0.5,
},
});
completion = response.message;
}
return {
result: completion.content,
cost: 0,
};
} catch (error) {
throw error;
}
}
/**
* Get the cost of the completion.
*
* @param _usage The completion to get the cost for.
* @returns The cost of the completion.
* Stubbed since LMStudio has no cost basis.
*/
getCost(_usage) {
return 0;
}
}
module.exports = OllamaProvider;

View file

@@ -79,7 +79,11 @@ class AgentHandler {
         break;
       case "lmstudio":
         if (!process.env.LMSTUDIO_BASE_PATH)
-          throw new Error("LMStudio bash path must be provided to use agents.");
+          throw new Error("LMStudio base path must be provided to use agents.");
         break;
+      case "ollama":
+        if (!process.env.OLLAMA_BASE_PATH)
+          throw new Error("Ollama base path must be provided to use agents.");
+        break;
       default:
         throw new Error("No provider found to power agent cluster.");
@@ -94,6 +98,8 @@ class AgentHandler {
         return "claude-3-sonnet-20240229";
       case "lmstudio":
         return "server-default";
+      case "ollama":
+        return "llama3:latest";
       default:
         return "unknown";
     }
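
For local testing it can help to confirm that the host behind OLLAMA_BASE_PATH actually serves the default agent model before enabling agents. A rough sketch using the same ollama client the new provider relies on (the helper name and the http://127.0.0.1:11434 default are illustrative, not part of this commit):

const { Ollama } = require("ollama");

// Hypothetical pre-flight check: list local models and confirm the
// agent default ("llama3:latest") has already been pulled.
async function ollamaAgentReady() {
  const host = process.env.OLLAMA_BASE_PATH || "http://127.0.0.1:11434";
  const client = new Ollama({ host });
  const { models } = await client.list();
  return models.some((m) => m.name === "llama3:latest");
}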

View file

@@ -4,6 +4,7 @@ process.env.NODE_ENV === "development"
 const JWT = require("jsonwebtoken");
 const { User } = require("../../models/user");
 const { jsonrepair } = require("jsonrepair");
+const extract = require("extract-json-from-string");
 
 function reqBody(request) {
   return typeof request.body === "string"
@@ -67,8 +68,6 @@ function safeJsonParse(jsonString, fallback = null) {
     return JSON.parse(jsonString);
   } catch {}
 
-  // If the jsonString does not look like an Obj or Array, dont attempt
-  // to repair it.
   if (jsonString?.startsWith("[") || jsonString?.startsWith("{")) {
     try {
       const repairedJson = jsonrepair(jsonString);
@@ -76,6 +75,10 @@
     } catch {}
   }
 
+  try {
+    return extract(jsonString)[0];
+  } catch {}
+
   return fallback;
 }
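
The extract-json-from-string fallback is what lets replies from chatty OSS models still resolve to a tool call: it scans a string for embedded JSON and returns an array of parsed objects. A small illustration of the behavior safeJsonParse now leans on (the surrounding prose in the reply is made up; "rag-memory" comes from the prompt's examples):

const extract = require("extract-json-from-string");

// A model that wraps its tool call in prose still parses:
const reply = 'Sure! Here is the call: {"name": "rag-memory", "arguments": {"action": "search"}} hope that helps.';
const [call] = extract(reply); // -> { name: "rag-memory", arguments: { action: "search" } }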

View file

@@ -2678,6 +2678,11 @@ extract-files@^9.0.0:
   resolved "https://registry.yarnpkg.com/extract-files/-/extract-files-9.0.0.tgz#8a7744f2437f81f5ed3250ed9f1550de902fe54a"
   integrity sha512-CvdFfHkC95B4bBBk36hcEmvdR2awOdhhVUYH6S/zrVj3477zven/fJMYg7121h4T1xHZC+tetUpubpAhxwI7hQ==
 
+extract-json-from-string@^1.0.1:
+  version "1.0.1"
+  resolved "https://registry.yarnpkg.com/extract-json-from-string/-/extract-json-from-string-1.0.1.tgz#5001f17e6c905826dcd5989564e130959de60c96"
+  integrity sha512-xfQOSFYbELVs9QVkKsV9FZAjlAmXQ2SLR6FpfFX1kpn4QAvaGBJlrnVOblMLwrLPYc26H+q9qxo6JTd4E7AwgQ==
+
 extract-zip@^2.0.1:
   version "2.0.1"
   resolved "https://registry.yarnpkg.com/extract-zip/-/extract-zip-2.0.1.tgz#663dca56fe46df890d5f131ef4a06d22bb8ba13a"
@@ -4560,6 +4565,13 @@ octokit@^3.1.0:
     "@octokit/request-error" "^5.0.0"
     "@octokit/types" "^12.0.0"
 
+ollama@^0.5.0:
+  version "0.5.0"
+  resolved "https://registry.yarnpkg.com/ollama/-/ollama-0.5.0.tgz#cb9bc709d4d3278c9f484f751b0d9b98b06f4859"
+  integrity sha512-CRtRzsho210EGdK52GrUMohA2pU+7NbgEaBG3DcYeRmvQthDO7E2LHOkLlUUeaYUlNmEd8icbjC02ug9meSYnw==
+  dependencies:
+    whatwg-fetch "^3.6.20"
+
 on-finished@2.4.1:
   version "2.4.1"
   resolved "https://registry.yarnpkg.com/on-finished/-/on-finished-2.4.1.tgz#58c8c44116e54845ad57f14ab10b03533184ac3f"
@@ -5980,7 +5992,7 @@ webidl-conversions@^3.0.0:
   resolved "https://registry.yarnpkg.com/webidl-conversions/-/webidl-conversions-3.0.1.tgz#24534275e2a7bc6be7bc86611cc16ae0a5654871"
   integrity sha512-2JAn3z8AR6rjK8Sm8orRC0h/bcl/DqL7tRPdGZ4I1CjdF+EaMLmYxBHyXuKL849eucPFhvBoxMsflfOb8kxaeQ==
 
-whatwg-fetch@^3.4.1:
+whatwg-fetch@^3.4.1, whatwg-fetch@^3.6.20:
   version "3.6.20"
   resolved "https://registry.yarnpkg.com/whatwg-fetch/-/whatwg-fetch-3.6.20.tgz#580ce6d791facec91d37c72890995a0b48d31c70"
   integrity sha512-EqhiFU6daOA8kpjOWTL0olhVOF3i7OrFzSYiGsEMB8GcXS+RrzauAERX65xMeNWVqxA6HXH2m69Z9LaKKdisfg==