import json
import re

from duckduckgo_search import DDGS
from huggingface_hub import hf_hub_download
from llama_cpp import Llama

# Imports used only by the commented-out transformers path below:
# import transformers
# from transformers import AutoModelForCausalLM, AutoTokenizer
# from torch import bfloat16

## Download the GGUF model
model_name = "TheBloke/Mixtral-8x7B-Instruct-v0.1-GGUF"
model_file = "mixtral-8x7b-instruct-v0.1.Q5_K_M.gguf"
model_path = "/Users/sij/AI/Models/TheBloke/Mixtral-8x7B-Instruct-v0.1-GGUF/mixtral-8x7b-instruct-v0.1.Q5_K_M.gguf"
model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"  # HF repo id, used only by the commented-out transformers path

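# If the GGUF file is not already on disk, it can be fetched from the Hub
# instead of hardcoding model_path (a sketch; the multi-GB download is cached
# after the first run):
# model_path = hf_hub_download(repo_id=model_name, filename=model_file)
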
# Initialize the model
# model = transformers.AutoModelForCausalLM.from_pretrained(
#     model_id,
#     trust_remote_code=True,
#     torch_dtype=bfloat16,
#     device_map='auto'
# )
# model.eval()

# Initialize the tokenizer
# tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)

# Define a text-generation pipeline
# generate_text = transformers.pipeline(
#     model=model, tokenizer=tokenizer,
#     return_full_text=False,
#     task="text-generation",
#     temperature=0.1,
#     top_p=0.15,
#     top_k=0,
#     max_new_tokens=512,
#     repetition_penalty=1.1
# )

## Note: the transformers tokenizer is not carried over; llama.cpp
## tokenizes internally, so no separate tokenizer is needed below.

# Initialize the llama.cpp model
llm = Llama(
    model_path=model_path,
    n_ctx=8000,      # Context length to use
    n_threads=8,     # Number of CPU threads to use
    n_gpu_layers=2   # Number of model layers to offload to GPU
)

## Generation kwargs
generation_kwargs = {
    "max_tokens": 20000,  # Upper bound on generation; effectively limited by n_ctx above
    "stop": ["</s>"],     # Mistral end-of-sequence marker
    "echo": True,         # Echo the prompt in the output (format_output below relies on this)
    "top_k": 1            # Greedy decoding: always take the highest-probability token; set > 1 for sampling
}

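# Quick smoke test of the raw completion API (illustrative; the reply text
# will vary from run to run):
# res = llm("<s> [INST] Say hello. [/INST]", max_tokens=32)
# print(res["choices"][0]["text"])
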
# Define a function to use a tool based on the action dictionary
def use_tool(action: dict):
    tool_name = action["tool_name"]
    if tool_name == "Calculator":
        # The model is prompted to emit Python that assigns its result to a
        # variable named `output`. exec() inside a function cannot create new
        # local variables, so run it in an explicit namespace and read
        # `output` back out of it.
        namespace = {}
        exec(action["input"], namespace)
        return f"Tool Output: {namespace['output']}"

    elif tool_name == "Search":
        contexts = []
        with DDGS() as ddgs:
            results = ddgs.text(
                action["input"],
                region="wt-wt", safesearch="on",
                max_results=3
            )
            for r in results:
                contexts.append(r['body'])
        info = "\n---\n".join(contexts)
        return f"Tool Output: {info}"

    elif tool_name == "Final Answer":
        return "Assistant: " + action["input"]

    else:
        return f"Tool Output: unknown tool '{tool_name}'"

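# For reference, the action dicts handled above look like (inputs are
# illustrative, not actual model output):
# {"tool_name": "Calculator", "input": "output = (512 ** 0.5) * 7"}
# {"tool_name": "Search", "input": "square root of 512"}
# {"tool_name": "Final Answer", "input": "The answer is ..."}
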
# Function to format the instruction prompt
def instruction_format(sys_message: str, query: str):
    return f'<s> [INST] {sys_message} [/INST]\nUser: {query}\nAssistant: ```json\n'

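# The prompt deliberately ends with an opening ```json fence, priming the
# model to reply with a fenced JSON action. For the query "2+2?" it renders as:
# <s> [INST] <sys_message> [/INST]
# User: 2+2?
# Assistant: ```json
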
# Function to parse the generated action string into a dictionary
def format_output(input_text: str, prefix: str):
    # Remove the prefix (the echoed prompt, ending in '```json\n') from input_text
    if input_text.startswith(prefix):
        # Cut off the prefix to isolate the JSON part
        trimmed_text = input_text[len(prefix):]
    else:
        print("Prefix not found at the beginning of input_text.")
        return None

    # Strip the closing code fence, if the model emitted one
    if trimmed_text.endswith('\n```'):
        json_str = trimmed_text[:-len('\n```')].strip()
    else:
        json_str = trimmed_text.strip()

    print(f"Trimmed: {json_str}")

    try:
        json_data = json.loads(json_str)
        print(f"Parsed JSON: {json_data}\n")
        return json_data
    except json.JSONDecodeError as e:
        print(f"Error parsing JSON: {e}")
        return None

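# Worked example (illustrative): with echo=True the generated text begins
# with the prompt, so stripping it leaves only the fenced JSON action:
# format_output('PROMPT{"tool_name": "Final Answer", "input": "4"}\n```', 'PROMPT')
# -> {'tool_name': 'Final Answer', 'input': '4'}
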
# Function to handle a single prompt, tool selection, and final action loop
def run(query: str):
    input_prompt = instruction_format(sys_msg, query)
    # res = generate_text(input_prompt)  # transformers path, kept for reference

    res = llm(input_prompt, **generation_kwargs)
    textthereof = res["choices"][0]["text"]  # includes the prompt, since echo=True
    action_dict = format_output(textthereof, input_prompt)
    if action_dict is None:
        # The model did not produce parseable JSON; surface the raw text instead
        return None, textthereof

    response = use_tool(action_dict)
    full_text = f"{textthereof}\n{response}"
    return response, full_text

# Example query
query = "Hi there, I'm stuck on a math problem, can you help? My question is what is the square root of 512 multiplied by 7?"

sys_msg = """[Your detailed system message or instructions here]"""  # You would replace this with your actual detailed instructions

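# A minimal sketch of what the system message needs to convey (illustrative,
# not the original author's prompt): name the three tools and demand a single
# fenced JSON action, e.g.:
# sys_msg = """You may use three tools: "Calculator", "Search", and "Final Answer".
# Always respond with one JSON object, {"tool_name": <tool>, "input": <input>},
# inside a ```json code fence. Calculator input must be Python code that
# assigns its result to a variable named `output`."""
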
# Running the example
out = run(query)

print(out[0])