-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcreate_embeddings.py
More file actions
92 lines (68 loc) · 2.27 KB
/
create_embeddings.py
File metadata and controls
92 lines (68 loc) · 2.27 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
from pymilvus import (connections, Collection)
from sentence_transformers import SentenceTransformer
import csv
from pymilvus import (
utility, FieldSchema, CollectionSchema, DataType, Collection
)
from datasets import load_dataset
from dotenv import load_dotenv
import os
# Run dotenv loading first so the environment lookups below can see any
# values it provides.
load_dotenv()

# Milvus connection settings; each is None when its variable is not set.
MILVUS_URI = os.environ.get("MILVUS_URI")
MILVUS_TOKEN = os.environ.get("MILVUS_TOKEN")

# Name of the Milvus collection this script works on.
COLLECTION_NAME = "movies_db"
# Dimension declared for the 'embedding' vector field.
DIMENSION = 384
def csv_load(file):
    """Yield ``(row[1], row[7], row[8])`` for every usable row of a CSV file.

    The caller in this script unpacks each tuple as ``(title, plot, image)``.

    A row is skipped when it is too short to contain column 8 (this includes
    blank lines, which ``csv.reader`` reports as an empty list) or when any
    of the three selected columns is an empty string. No header handling is
    done here: a header row that passes those checks is yielded like any
    other row.

    Args:
        file: Path of the comma-separated file to read.

    Yields:
        Tuples of three non-empty strings taken from columns 1, 7 and 8.
    """
    # Highest column index read below; shorter rows cannot be unpacked.
    last_needed = 8
    # newline='' lets the csv module handle embedded newlines itself; the
    # encoding is pinned so the result does not depend on the platform locale.
    with open(file, newline='', encoding='utf-8') as f:
        for row in csv.reader(f, delimiter=','):
            if len(row) <= last_needed:
                continue
            record = (row[1], row[7], row[8])
            if '' in record:
                continue
            yield record
# Sentence-embedding model, created once at import time and reused by every
# call to generate_embeddings().
transformer = SentenceTransformer("all-MiniLM-L6-v2")


def generate_embeddings(data: list[str]):
    """Encode *data* with the module-level model and return a plain list.

    The object returned by ``transformer.encode`` is iterated once and its
    items are collected into a list (presumably one vector per input string;
    confirm against the sentence-transformers documentation).

    Args:
        data: Strings to embed.

    Returns:
        A list holding the items produced by iterating the encoder output.
    """
    return list(transformer.encode(data))
# Connect to Milvus with the URI and token read from the environment.
connections.connect(uri=MILVUS_URI, token=MILVUS_TOKEN)

# Always start from an empty collection: drop any existing one of this name.
if utility.has_collection(collection_name=COLLECTION_NAME):
    utility.drop_collection(collection_name=COLLECTION_NAME)

# Field definitions, in the order the collection stores them.
primary_key = FieldSchema(
    name='id',
    dtype=DataType.INT64,
    is_primary=True,
    auto_id=True,  # id is auto-increment
)
title_field = FieldSchema(name='title', dtype=DataType.VARCHAR, max_length=200)
vector_field = FieldSchema(
    name='embedding',
    dtype=DataType.FLOAT_VECTOR,
    dim=DIMENSION,
)
image_field = FieldSchema(name='image', dtype=DataType.VARCHAR, max_length=500)

fields = [primary_key, title_field, vector_field, image_field]
schema = CollectionSchema(fields=fields)
collection = Collection(name=COLLECTION_NAME, schema=schema)

# IVF_FLAT index on the vector field, using L2 distance and 1536 lists.
index_params = dict(
    metric_type='L2',
    index_type="IVF_FLAT",
    params=dict(nlist=1536),
)
collection.create_index(field_name="embedding", index_params=index_params)

# Load the collection once its index has been defined.
collection.load()
def insert_data(data):
    """Embed one batch and insert it into the module-level collection.

    Args:
        data: Three parallel lists. ``data[1]`` is passed to
            ``generate_embeddings``; ``data[0]`` and ``data[2]`` are inserted
            unchanged on either side of the resulting embeddings. The caller
            in this script fills them with titles, plots and images.
    """
    # Embed first, then lay the columns out as: data[0], embeddings, data[2].
    plot_vectors = generate_embeddings(data[1])
    columns = [data[0], plot_vectors, data[2]]
    collection.insert(columns)
# Number of rows accumulated before each insert_data() call.
BATCH_SIZE = 128

# Three parallel column buffers: titles, plots, images.
data_batch = [[], [], []]
# Rows read from the CSV so far.
count = 0

for title, plot, image in csv_load('data/plots.csv'):
    data_batch[0].append(title)
    data_batch[1].append(plot)
    data_batch[2].append(image)
    # Count the row before the progress message is printed, so the message
    # includes the rows that were just inserted.
    count += 1
    if len(data_batch[0]) >= BATCH_SIZE:
        insert_data(data_batch)
        data_batch = [[], [], []]
        print(f"\ninserted... {count} movies")

# Insert whatever is left when the row total is not a multiple of BATCH_SIZE.
if data_batch[0]:
    insert_data(data_batch)

collection.flush()