22
33import sys
44import threading
5+ from queue import Queue
56
67from pilosa import Client , Schema
78from pilosa .imports import Column , FieldValue
# Fields to create in the index, in CSV column order.
# Each entry gives the field name and the options used to create it.
# "float_*" options describe a float field that is stored as a scaled int.
FIELDS = [
    dict(name="size", opts=dict(keys=True)),
    dict(name="color", opts=dict(keys=True)),
    dict(name="age", opts=dict(int_min=0, int_max=150)),
    dict(
        name="result",
        opts=dict(
            float_min=1.13106317,
            float_max=30.23959735,
            float_frac=8,  # number of fractional digits
        ),
    ),
]
1726# -----------------------------
# other settings
THREAD_COUNT = 0  # 0 = use the number of CPUs available to this process
VERBOSE = True
# ------------------------------

if not THREAD_COUNT:
    import os

    try:
        # Respects the CPU affinity mask of this process where supported.
        THREAD_COUNT = len(os.sched_getaffinity(0))
    except AttributeError:
        # sched_getaffinity does not exist on every platform; fall back to
        # the machine-wide CPU count (which may itself be unknown -> None).
        THREAD_COUNT = os.cpu_count() or 1
35+
1836
1937class MultiColumnBitIterator :
2038
2139 def __init__ (self ,
2240 file_obj , field ,
2341 column_index = 0 , row_index = 1 ,
24- has_header = True ):
42+ has_header = True ,
43+ float_frac = 0 ):
2544 self .file_obj = file_obj
2645 if has_header :
2746 # if there's a header skip it
2847 next (self .file_obj )
2948
3049 ci = column_index
3150 ri = row_index
32-
51+ float_mul = 10 ** float_frac
52+
def row_value(fs):
    """Parse the value column of a split CSV record into an int.

    ``fs`` is the list of stripped column strings for one line. When
    ``float_frac`` is non-zero the column is read as a float and scaled by
    ``float_mul`` (``10 ** float_frac``); otherwise it is read as an int.

    Returns the int value, or None when the record must be skipped because
    the column is missing, not numeric, or not finite.
    """
    try:
        if float_frac:
            # round() instead of int(): binary float error can leave the
            # scaled value just below the intended integer, e.g.
            # float("1.13") * 100 == 112.99999999999999, which int() would
            # truncate to 112.
            return round(float(fs[ri]) * float_mul)
        return int(fs[ri])
    except (ValueError, OverflowError, IndexError):
        # ValueError: not a number (or NaN); OverflowError: infinity;
        # IndexError: the line has too few columns (e.g. a blank line).
        return None
64+
def field_with_column_key(fs):
    """Build a FieldValue addressed by column key, or None to skip the record."""
    parsed = row_value(fs)
    # row_value() yields None for records whose value cannot be parsed
    return None if parsed is None else FieldValue(column_key=fs[ci], value=parsed)
70+
def field_with_column_id(fs):
    """Build a FieldValue addressed by integer column id, or None to skip the record."""
    parsed = row_value(fs)
    # row_value() yields None for records whose value cannot be parsed
    return None if parsed is None else FieldValue(column_id=int(fs[ci]), value=parsed)
76+
3377 # set the bit yielder
3478 if field .field_type == "int" :
3579 if field .index .keys :
36- self .yield_fun = lambda fs : FieldValue ( column_key = fs [ ci ], value = int ( fs [ ri ]))
80+ self .yield_fun = field_with_column_key
3781 else :
38- self .yield_fun = lambda fs : FieldValue ( column_id = int ( fs [ ci ]), value = int ( fs [ ri ]))
82+ self .yield_fun = field_with_column_id
3983 else :
4084 if field .index .keys :
4185 if field .keys :
@@ -58,29 +102,73 @@ def __call__(self):
58102 # split fields
59103 fs = [x .strip () for x in line .split ("," )]
60104 # return a bit
61- yield yield_fun (fs )
105+ bit = yield_fun (fs )
106+ if bit is not None :
107+ yield bit
108+
109+
def import_field(q, client, path):
    """Worker loop: import fields taken from ``q`` until a None sentinel arrives.

    Each queue item is a ``(field, row_index, float_frac)`` tuple. For every
    item the CSV file at ``path`` is re-opened and streamed through a
    MultiColumnBitIterator into ``client.import_field``.

    Every item taken from the queue, including the sentinel, is acknowledged
    with ``q.task_done()``, so a failed import can never leave ``q.join()``
    blocked. A failure is reported on stderr and the worker moves on to the
    next item instead of dying.
    """
    while True:
        item = q.get()
        if item is None:
            # sentinel: no more work for this worker
            q.task_done()
            break
        field, row_index, float_frac = item
        try:
            print("Importing field:", field.name)
            with open(path) as f:
                mcb = MultiColumnBitIterator(f,
                                             field,
                                             row_index=row_index,
                                             float_frac=float_frac)
                client.import_field(field, mcb())
        except Exception as exc:
            # Worker boundary: keep the thread alive so the remaining
            # queued fields still get imported.
            print("Import failed for field:", field.name, "-", repr(exc),
                  file=sys.stderr)
        finally:
            q.task_done()
67125
def _split_float_opts(opts):
    """Return ``(int_opts, float_frac)`` for one FIELDS ``opts`` mapping.

    Works on a copy so the module-level configuration is never modified.
    ``float_min`` / ``float_max`` are replaced by ``int_min`` / ``int_max``
    scaled by ``10 ** float_frac``; ``float_frac`` defaults to 0.
    """
    int_opts = dict(opts)
    float_frac = int_opts.pop("float_frac", 0)
    scale = 10 ** float_frac
    for float_key, int_key in (("float_min", "int_min"),
                               ("float_max", "int_max")):
        if float_key in int_opts:
            # round() rather than int(): the scaled bound can land a hair
            # below the intended integer because of binary float error.
            int_opts[int_key] = round(int_opts.pop(float_key) * scale)
    return int_opts, float_frac


def import_csv(pilosa_addr, path):
    """Create the schema described by FIELDS and import the CSV at ``path``.

    One queue item is created per field; worker threads running
    ``import_field`` pull items until they hit a None sentinel. The n-th
    field in FIELDS is read from CSV column n (column 0 is the column key/id).
    """
    # NOTE(review): very large socket timeout, presumably so long-running
    # imports are not cut off -- confirm the unit against the client docs.
    client = Client(pilosa_addr, socket_timeout=20000000)

    # create the schema
    schema = Schema()
    index = schema.index(INDEX_NAME, keys=INDEX_KEYS, track_existence=True)
    fields = []
    for spec in FIELDS:
        # float related options are converted to int options on a copy
        opts, float_frac = _split_float_opts(spec.get("opts", {}))
        fields.append((index.field(spec["name"], **opts), float_frac))

    client.sync_schema(schema)

    # Queue all work first, followed by one stop sentinel per worker. The
    # queue is FIFO, so every field is handed out before any sentinel.
    q = Queue()
    for i, (field, float_frac) in enumerate(fields):
        q.put((field, i + 1, float_frac))

    # no point in starting more workers than there are fields
    worker_count = min(THREAD_COUNT, len(fields))
    for _ in range(worker_count):
        q.put(None)

    threads = [
        threading.Thread(target=import_field, args=(q, client, path))
        for _ in range(worker_count)
    ]
    for t in threads:
        t.start()

    # Wait on the threads rather than on q.join(): a worker that dies
    # without acknowledging its item can then never block this call.
    for t in threads:
        t.join()
86174
@@ -92,6 +180,13 @@ def main():
92180
93181 pilosa_addr = sys .argv [1 ]
94182 path = sys .argv [2 ]
183+
184+ print ("Pilosa Address:" , pilosa_addr )
185+ print ("Thread Count :" , THREAD_COUNT )
186+ print ("CSV Path :" , path )
187+ print ("Verbose :" , VERBOSE )
188+ print ()
189+
95190 import_csv (pilosa_addr , path )
96191
97192if __name__ == "__main__" :
0 commit comments