-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtinyllama_example.py
More file actions
55 lines (47 loc) · 1.73 KB
/
tinyllama_example.py
File metadata and controls
55 lines (47 loc) · 1.73 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import logging
import sys
from pathlib import Path
# The line below is only necessary if you are using RAGGA from a cloned repository
sys.path.append(str(Path(__file__).parent / "src"))
from ragga import (
Config,
Encoder,
Generator,
MarkdownDataset,
VectorDatabase,
WebSearchRetriever,
)
from ragga.crafting.prompt import TinyLlamaChatPrompt
def output_model_response_stream(model_response: str) -> None:
    """Emit one streamed model chunk to stdout immediately (no newline, no buffering)."""
    # print() with end="" and flush=True is equivalent to
    # sys.stdout.write(...) followed by sys.stdout.flush().
    print(model_response, end="", flush=True)
if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)

    logging.debug("loading config...")
    conf = Config(config_path=Path(__file__).parent / "config.yaml")

    logging.info("loading dataset, this will take a while the first time...")
    dataset = MarkdownDataset(conf)

    logging.info("loading encoder...")
    encoder = Encoder(conf)

    logging.debug("loading faiss db...")
    faiss_db = VectorDatabase(conf, encoder)
    if not faiss_db.loaded_from_disk:
        # Only (re)index the dataset when no persisted index was found on disk.
        faiss_db.documents = dataset.documents

    logging.debug("loading generator and model...")
    prompt = TinyLlamaChatPrompt(conf)

    logging.info("loading chatbot...")
    generator = Generator(conf, prompt, faiss_db, websearch=WebSearchRetriever)
    # Tokens are printed incrementally by the subscribed callback as they arrive.
    generator.subscribe(output_model_response_stream)
    logging.info("chatbot ready!")

    # Simple REPL: type "exit"/"quit", press Ctrl-C, or close stdin to leave.
    while True:
        try:
            user_input = input("You: ")
        except (KeyboardInterrupt, EOFError):
            # Ctrl-C or Ctrl-D (EOF on stdin) ends the session cleanly
            # instead of crashing with a traceback.
            break
        if user_input.lower() in ("exit", "quit"):
            break
        logging.debug("user input: %s", user_input)
        try:
            # Drain the stream; display happens via the subscribed callback.
            # "".join avoids the quadratic cost of += concatenation.
            full_response = "".join(generator.get_answer_stream(user_input))
        except KeyboardInterrupt:
            # Allow Ctrl-C mid-generation to stop the chat loop too.
            break