-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathenron-demo.py
More file actions
78 lines (64 loc) · 3.1 KB
/
enron-demo.py
File metadata and controls
78 lines (64 loc) · 3.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import json
import os
import palimpzest as pz
from palimpzest.core.lib.schemas import TextFile
class EnronDataset(pz.IterDataset):
    """Palimpzest dataset over a directory of Enron email text files.

    Each item exposes the file's ``filename`` and raw ``contents`` as input
    fields, plus ground-truth labels looked up by filename.

    Args:
        dir: Directory containing one email per file.
        labels_file: Optional JSON file mapping filename -> labels. Files that
            are absent from the mapping (or every file, when no labels file is
            given) get ``{}`` as their labels.
        split: ``"train"`` selects the first 50 directory entries; any other
            value selects entries 50 through 149.
    """

    def __init__(self, dir: str, labels_file: str | None = None, split: str = "test"):
        super().__init__(id=f"enron-{split}", schema=TextFile)
        # NOTE(review): os.listdir returns entries in arbitrary, filesystem-dependent
        # order, so which emails land in each split is not guaranteed across machines.
        # Code that needs the exact split should read ``self.filepaths`` rather than
        # re-listing the directory.
        filepaths = [os.path.join(dir, filename) for filename in os.listdir(dir)]
        self.filepaths = filepaths[:50] if split == "train" else filepaths[50:150]

        self.filename_to_labels = {}
        if labels_file:
            # JSON is UTF-8 by spec; don't depend on the locale's default encoding.
            with open(labels_file, encoding="utf-8") as f:
                self.filename_to_labels = json.load(f)

    def __len__(self) -> int:
        """Return the number of emails in this split."""
        return len(self.filepaths)

    def __getitem__(self, idx: int) -> dict:
        """Read email ``idx`` and return it as a palimpzest item.

        Returns:
            ``{"fields": {"filename": str, "contents": str}, "labels": ...}``,
            where labels come from the labels file (``{}`` if not found).
        """
        filepath = self.filepaths[idx]
        filename = os.path.basename(filepath)
        # Explicit UTF-8 keeps decoding platform-independent; undecodable bytes
        # become U+FFFD instead of aborting the whole pipeline on one bad email.
        with open(filepath, encoding="utf-8", errors="replace") as f:
            contents = f.read()
        return {
            "fields": {"filename": filename, "contents": contents},
            "labels": self.filename_to_labels.get(filename, {}),
        }
DATA_DIR = "testdata/enron-eval-medium"
LABELS_FILE = "testdata/enron-eval-medium-labels.json"


def precision_recall(predicted, targets) -> tuple[float, float]:
    """Return ``(precision, recall)`` of ``predicted`` against ``targets``.

    Both arguments are iterables of hashable ids and are compared as sets.
    Precision is 0.0 when nothing is predicted; recall is 0.0 when there are
    no targets.

    >>> precision_recall({"a", "b"}, {"b"})
    (0.5, 1.0)
    """
    predicted = set(predicted)
    targets = set(targets)
    true_positives = len(predicted & targets)
    precision = true_positives / len(predicted) if predicted else 0.0
    recall = true_positives / len(targets) if targets else 0.0
    return precision, recall


def main() -> None:
    """Run the Enron fraud-email demo and print results plus precision/recall."""
    # Labelled split, passed to the optimizer as ``train_dataset``.
    train_dataset = EnronDataset(dir=DATA_DIR, labels_file=LABELS_FILE, split="train")

    # Construct the plan over the test split; keep a handle on the dataset so the
    # evaluation below scores exactly the files that were queried.
    test_dataset = EnronDataset(dir=DATA_DIR, split="test")
    plan = test_dataset.sem_add_columns([
        {"name": "subject", "type": str, "desc": "The subject of the email"},
        {"name": "sender", "type": str, "desc": "The email address of the email's sender"},
    ])
    plan = plan.sem_filter('The email refers to a fraudulent scheme (i.e., "Raptor", "Deathstar", "Chewco", and/or "Fat Boy")')
    plan = plan.sem_filter("The email is not quoting from a news article or an article written by someone outside of Enron")

    # Execute the pz plan.
    config = pz.QueryProcessorConfig(
        policy=pz.MaxQuality(),
        execution_strategy="parallel",
        k=5,
        j=6,
        sample_budget=100,
        max_workers=20,
        progress=True,
    )
    output = plan.optimize_and_run(train_dataset=train_dataset, validator=pz.Validator(), config=config)

    # Build the dataframe once; it is used both for display and for scoring.
    output_df = output.to_df()
    print(output_df)

    # Score predictions against the labels of the test-split files only.
    with open(LABELS_FILE, encoding="utf-8") as f:
        filename_to_labels = json.load(f)
    test_filenames = {os.path.basename(path) for path in test_dataset.filepaths}
    target_filenames = {
        filename
        for filename, labels in filename_to_labels.items()
        if filename in test_filenames and labels != []
    }
    # NOTE(review): assumes to_df() returns a pandas DataFrame; an empty result may
    # carry no "filename" column, which would otherwise raise KeyError.
    if "filename" in output_df.columns:
        pred_filenames = set(output_df["filename"])
    else:
        pred_filenames = set()

    precision, recall = precision_recall(pred_filenames, target_filenames)
    print(f"PRECISION: {precision:.3f}")
    print(f"RECALL: {recall:.3f}")


if __name__ == "__main__":
    main()