-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpredict.py
More file actions
78 lines (67 loc) · 2.66 KB
/
predict.py
File metadata and controls
78 lines (67 loc) · 2.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import os
import sys
import time
import toml
import argparse
import pandas as pd
from numpy import argmax
from pathlib import Path
from Bio.SeqIO.FastaIO import SimpleFastaParser
# Hide warnings: TensorFlow reads TF_CPP_MIN_LOG_LEVEL at import time ('3'
# suppresses its C++ INFO/WARNING/ERROR log lines), and any Python-level
# output written to sys.stderr while the heavy modules import is discarded.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
stderr = sys.stderr
_devnull = open(os.devnull, 'w')
sys.stderr = _devnull
try:
    from preprocessing import Preprocessor
    from tensorflow.keras.models import load_model
finally:
    # Always restore the real stderr, even if an import fails, so the
    # ImportError traceback is still visible; then release the devnull handle.
    sys.stderr = stderr
    _devnull.close()
# =============================================
# Command-line interface: all three inputs are mandatory.
parser = argparse.ArgumentParser(
    prog='predict.py',
    description="Predict using your trained model.",
)
_REQUIRED_ARGS = (
    (("-f", "--fasta"),
     "Input FASTA file containing the sequences to be classified."),
    (("-m", "--model"),
     "Path to the trained model."),
    (("-i", "--info_file"),
     "Path to the TOML file containing model information."),
)
for _flags, _help in _REQUIRED_ARGS:
    parser.add_argument(*_flags, required=True, help=_help)
args = parser.parse_args()
# =============================================
# Resolve input paths. The output CSV name combines the FASTA file's stem
# with a run timestamp (yyMMddHHmmss); it has no directory component, so it
# is written to the current working directory.
input_fasta = Path(args.fasta)
model_info = Path(args.info_file)
timestamp = time.strftime('%y%m%d%H%M%S')
output_table = f"{input_fasta.stem}_{timestamp}_prediction.csv"
# Load the trained Keras model from the path given on the command line.
model = load_model(args.model)
# =============================================
pp = Preprocessor()
# Parallel lists: FASTA record ids and their lower-cased sequences.
# (Lower-casing presumably matches what Preprocessor.seq_tokenizer expects
# — TODO confirm against preprocessing.py.)
fids = []
fsqs = []
with open(input_fasta) as fa:
    print("### Loading sequences ###")
    for fid, fsq in SimpleFastaParser(fa):
        fids.append(fid)
        fsqs.append(fsq.lower())
# =============================================
# TOML documents are UTF-8 by specification, so decode explicitly rather than
# relying on the platform's locale encoding. The handle gets its own name so
# `tmf` refers only to the parsed document.
with open(model_info, encoding='utf-8') as info_fh:
    tmf = toml.load(info_fh)
tmf_dict = tmf['model_info']
labels = tmf_dict['Labels']
# Maps each class index (position in `labels`) to its label name.
labels_dict = dict(enumerate(labels))
# Length every tokenized sequence is padded to before prediction.
seq_len = tmf_dict['Maximum_length']
# =============================================
if not fsqs:
    # An empty FASTA would otherwise fail with an opaque error inside the
    # padding step or model.predict; stop early with a clear message instead.
    sys.exit(f"No sequences found in {input_fasta}; nothing to predict.")
print("### Tokenizing and padding sequences ###\n")
fsqs = pp.zero_padder(list(map(pp.seq_tokenizer, fsqs)),
                      pad_len=seq_len)
# =============================================
print("### Starting prediction ###")
predictions = model.predict(fsqs)
# One score column per class is required to label the DataFrame columns;
# report a model/info-file mismatch clearly instead of as a pandas shape error.
if predictions.shape[1] != len(labels):
    sys.exit(f"Model produced {predictions.shape[1]} class scores but "
             f"{model_info} lists {len(labels)} labels.")
# Index of the highest-scoring class for each sequence (row).
pred_labels = argmax(predictions, axis=1)
df = pd.DataFrame(predictions, columns=labels).round(4)
df.insert(0, 'id', fids)
df.insert(1, 'prediction', [labels_dict[i] for i in pred_labels])
df.to_csv(output_table, index=False)
print(">>> Done!")
print(f">>> Results saved to {output_table}.")
# =============================================