Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
1ee81f9
Fix one-off error in combine_selectors
hombit Jul 11, 2025
2c0dad4
Initial re-impl of Cython code in Rust
hombit May 15, 2024
6deadb6
Fix selector's dtype
hombit May 15, 2024
72f328b
Rust re-impl of feature_delta_sum
hombit May 15, 2024
d18e3aa
leaf_offsets
hombit May 15, 2024
71ec7be
Make ABI3 optional
hombit May 15, 2024
b5d4890
Make selector's dtype a OnceCell
hombit May 21, 2024
4d3ccb7
More rust modules
hombit May 21, 2024
3667a79
Fix clippy lints
hombit May 21, 2024
1b3f0d7
Better optimized release module
hombit May 21, 2024
b82364a
Allocate output arrays with numpy
hombit May 22, 2024
6d43ff0
Change signature impl to match Cython
hombit May 22, 2024
b12e168
Do not create Rayon pool for n_jobs=1
hombit May 22, 2024
995d29d
Simgle/multi impls for transpose
hombit May 25, 2024
13cb048
Revert "Simgle/multi impls for transpose"
hombit May 25, 2024
7386934
Parallel batch_size
hombit May 27, 2024
0e55d87
#[allow(clippy::too_many_arguments)]
hombit May 28, 2024
bbd52a7
Run cargo update
hombit Sep 3, 2024
84b1581
Implement calc_apply
hombit May 9, 2025
8f7a598
Cargo clippy
hombit May 9, 2025
7ee9d9e
Rebane ext moduls to calc_trees
hombit May 9, 2025
b607188
pyO3 & numpy 0.22
hombit May 9, 2025
0db0ad4
pyO3 & numpy 0.23
hombit May 9, 2025
9035691
Update rust deps
hombit May 9, 2025
6cae3dd
Support free-threading CPython
hombit May 9, 2025
98f37ca
Fix rust tests
hombit May 9, 2025
a022f8e
clippy
hombit May 9, 2025
354260e
Set version via Cargo.toml
hombit May 9, 2025
cf2789c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 9, 2025
8c4f0b1
MutSlices::next(_back) changes
hombit May 16, 2025
d987cd7
Bump ABI3 to py3.10+
hombit Jul 1, 2025
ea54650
cargo update
hombit Jul 2, 2025
81d8f4f
Bumpy pyo3/numpy to 0.25
hombit Jul 2, 2025
0229b7b
WIP: devnet notebook
hombit May 13, 2024
339831c
Update notebook to use 200 realisations
hombit Jul 3, 2025
3d388cc
update devnet galzoo2
hombit Jul 10, 2025
1117237
Update devnet_datasets.ipynb
hombit Jul 10, 2025
a4d7049
Add (temp) gz2.py
hombit Jul 10, 2025
cd85457
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 10, 2025
6e740b5
Make mutable borrow of result arrays safe
hombit Jul 11, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
239 changes: 239 additions & 0 deletions docs/notebooks/devnet_datasets.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "ea4ae65a-d555-4b54-96f9-11eed006adc2",
"metadata": {},
"outputs": [],
"source": [
"# %pip install coniferest matplotlib pandas tqdm"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "3d9577061e9494ed",
"metadata": {
"ExecuteTime": {
"end_time": "2025-07-03T15:29:14.711984Z",
"start_time": "2025-07-03T15:29:13.632289Z"
},
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [],
"source": [
"from collections import defaultdict\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"from tqdm import tqdm\n",
"\n",
"from coniferest.aadforest import AADForest\n",
"from coniferest.datasets import Dataset, DevNetDataset\n",
"from coniferest.isoforest import IsolationForest\n",
"from coniferest.label import Label\n",
"from coniferest.pineforest import PineForest\n",
"from coniferest.session.oracle import OracleSession, create_oracle_session"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "initial_id",
"metadata": {
"ExecuteTime": {
"end_time": "2025-07-03T15:29:15.712748Z",
"start_time": "2025-07-03T15:29:15.707980Z"
}
},
"outputs": [],
"source": [
"class Compare:\n",
" models = {\n",
" 'Isolation Forest': IsolationForest,\n",
" 'AAD': AADForest,\n",
" 'Pine Forest': PineForest,\n",
" }\n",
"\n",
" def __init__(self, dataset: Dataset, *, iterations=100, n_jobs=-1, sampletrees_per_batch=1<<20):\n",
" self.model_kwargs = {\n",
" 'n_trees': 128,\n",
" 'sampletrees_per_batch': sampletrees_per_batch,\n",
" 'n_jobs': n_jobs,\n",
" }\n",
" self.session_kwargs = {\n",
" 'data': dataset.data,\n",
" 'labels': dataset.labels,\n",
" 'max_iterations': iterations,\n",
" }\n",
" self.results = {}\n",
" self.steps = np.arange(1, iterations + 1)\n",
" self.total_anomaly_fraction = np.mean(dataset.labels == Label.A)\n",
"\n",
" def get_sessions(self, random_seed):\n",
" model_kwargs = self.model_kwargs | {'random_seed': random_seed}\n",
"\n",
" return {\n",
" name: create_oracle_session(model=model(**model_kwargs), **self.session_kwargs)\n",
" for name, model in self.models.items()\n",
" }\n",
"\n",
" def run(self, random_seeds):\n",
" assert len(random_seeds) == len(set(random_seeds)), \"random seeds must be different\"\n",
" \n",
" results = defaultdict(dict)\n",
"\n",
" futures = []\n",
" for random_seed in tqdm(random_seeds):\n",
" sessions = self.get_sessions(random_seed)\n",
" for name, session in sessions.items():\n",
" session.run()\n",
" anomalies = np.cumsum(np.array(list(session.known_labels.values())) == Label.A)\n",
" results[name][random_seed] = anomalies\n",
"\n",
" self.results |= results\n",
" return self\n",
"\n",
" def plot(self, dataset_name: str, savefig=False):\n",
" plt.figure(figsize=(8, 6))\n",
" plt.title(f'Dataset: {dataset_name}')\n",
"\n",
" for name, anomalies_dict in self.results.items():\n",
" anomalies = np.stack(list(anomalies_dict.values()))\n",
" q5, median, q95 = np.quantile(anomalies, [0.05, 0.5, 0.95], axis=0)\n",
"\n",
" plt.plot(self.steps, median, alpha=0.75, label=name)\n",
" plt.fill_between(self.steps, q5, q95, alpha=0.5)\n",
"\n",
" plt.plot(self.steps, self.steps * self.total_anomaly_fraction, ls='--', color='grey',\n",
" label='Theoretical random')\n",
"\n",
" plt.xlabel('Iteration')\n",
" plt.ylabel('Number of anomalies')\n",
" plt.grid()\n",
" plt.legend()\n",
" if savefig:\n",
" plt.savefig(f'{dataset_name}.pdf')\n",
"\n",
" return self"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "929fd77b-3333-4937-90aa-d2804151d868",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/200 [00:00<?, ?it/s]"
]
}
],
"source": [
"import pickle\n",
"from pathlib import Path\n",
"\n",
"import pandas as pd\n",
"\n",
"class GalaxyZoo2Dataset(Dataset):\n",
" def __init__(self, path: Path, *, anomaly_class='Class6.1', anomaly_threshold=0.9):\n",
" astronomaly = pd.read_parquet(path / \"astronomaly.parquet\")\n",
" self.data = astronomaly.drop(columns=['GalaxyID', 'anomaly']).to_numpy().copy(order='C')\n",
" ids = astronomaly['GalaxyID'].to_numpy()\n",
"\n",
" solutions = pd.read_csv(path / \"training_solutions_rev1.csv\", index_col=\"GalaxyID\")\n",
" anomaly = solutions[anomaly_class][ids] >= anomaly_threshold\n",
" self.labels = np.full(anomaly.shape, Label.R)\n",
" self.labels[anomaly] = Label.A\n",
"\n",
"\n",
"seeds = range(12, 212)\n",
"\n",
"path = Path(\"/home/hombit/gz2\")\n",
"dataset_obj = GalaxyZoo2Dataset(path)\n",
"%time compare_zoo = Compare(dataset_obj, iterations=100, n_jobs=24, sampletrees_per_batch=1<<16).run(seeds)\n",
"compare_zoo.plot(\"Galaxy Zoo 2 (Anything odd? 90%)\", savefig=True)\n",
"with open(\"galaxyzoo2_compare.pickle\", \"wb\") as fh:\n",
" pickle.dump(compare_zoo, fh)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "71c337b3577915d5",
"metadata": {
"ExecuteTime": {
"end_time": "2025-07-03T15:35:53.300312Z",
"start_time": "2025-07-03T15:34:16.696646Z"
},
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [],
"source": [
"%%time\n",
"\n",
"import pickle\n",
"\n",
"from joblib.parallel import delayed, Parallel\n",
"\n",
"print(DevNetDataset.avialble_datasets)\n",
"\n",
"seeds = range(200)\n",
"compare_delayed = delayed(\n",
" lambda dataset: Compare(DevNetDataset(dataset), iterations=100, n_jobs=48, sampletrees_per_batch=1<<16).run(seeds),\n",
")\n",
"compare_ = Parallel(\n",
" n_jobs=len(DevNetDataset.avialble_datasets),\n",
")(compare_delayed(dataset) for dataset in DevNetDataset.avialble_datasets)\n",
"\n",
"for dataset, compare_obj in zip(DevNetDataset.avialble_datasets, compare_):\n",
" print(f\"Plot {dataset}\")\n",
" compare_obj.plot(dataset, savefig=True)\n",
"\n",
"for dataset, compare_obj in zip(DevNetDataset.avialble_datasets, compare_):\n",
" print(f\"Save Compare object for {dataset}\")\n",
" with open(f'{dataset}_compare.pickle', 'wb') as fh:\n",
" pickle.dump(compare_obj, fh)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cb3c8e56-306a-4bd7-a756-f8489deb1c22",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
107 changes: 107 additions & 0 deletions docs/notebooks/gz2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

from coniferest.aadforest import AADForest
from coniferest.datasets import Dataset, DevNetDataset
from coniferest.isoforest import IsolationForest
from coniferest.label import Label
from coniferest.pineforest import PineForest
from coniferest.session.oracle import OracleSession, create_oracle_session

class Compare:
models = {
'Isolation Forest': IsolationForest,
'AAD': AADForest,
'Pine Forest': PineForest,
}

def __init__(self, dataset: Dataset, *, iterations=100, n_jobs=-1, sampletrees_per_batch=1<<20):
self.model_kwargs = {
'n_trees': 128,
'sampletrees_per_batch': sampletrees_per_batch,
'n_jobs': n_jobs,
}
self.session_kwargs = {
'data': dataset.data,
'labels': dataset.labels,
'max_iterations': iterations,
}
self.results = {}
self.steps = np.arange(1, iterations + 1)
self.total_anomaly_fraction = np.mean(dataset.labels == Label.A)

def get_sessions(self, random_seed):
model_kwargs = self.model_kwargs | {'random_seed': random_seed}

return {
name: create_oracle_session(model=model(**model_kwargs), **self.session_kwargs)
for name, model in self.models.items()
}

def run(self, random_seeds):
assert len(random_seeds) == len(set(random_seeds)), "random seeds must be different"

results = defaultdict(dict)

futures = []
for random_seed in tqdm(random_seeds):
sessions = self.get_sessions(random_seed)
for name, session in sessions.items():
session.run()
anomalies = np.cumsum(np.array(list(session.known_labels.values())) == Label.A)
results[name][random_seed] = anomalies

self.results |= results
return self

def plot(self, dataset_name: str, savefig=False):
plt.figure(figsize=(8, 6))
plt.title(f'Dataset: {dataset_name}')

for name, anomalies_dict in self.results.items():
anomalies = np.stack(list(anomalies_dict.values()))
q5, median, q95 = np.quantile(anomalies, [0.05, 0.5, 0.95], axis=0)

plt.plot(self.steps, median, alpha=0.75, label=name)
plt.fill_between(self.steps, q5, q95, alpha=0.5)

plt.plot(self.steps, self.steps * self.total_anomaly_fraction, ls='--', color='grey',
label='Theoretical random')

plt.xlabel('Iteration')
plt.ylabel('Number of anomalies')
plt.grid()
plt.legend()
if savefig:
plt.savefig(f'{dataset_name}.pdf')

return self

import pickle
from pathlib import Path

import pandas as pd

class GalaxyZoo2Dataset(Dataset):
def __init__(self, path: Path, *, anomaly_class='Class6.1', anomaly_threshold=0.9):
astronomaly = pd.read_parquet(path / "astronomaly.parquet")
self.data = astronomaly.drop(columns=['GalaxyID', 'anomaly']).to_numpy().copy(order='C')
ids = astronomaly['GalaxyID'].to_numpy()

solutions = pd.read_csv(path / "training_solutions_rev1.csv", index_col="GalaxyID")
anomaly = solutions[anomaly_class][ids] >= anomaly_threshold
self.labels = np.full(anomaly.shape, Label.R)
self.labels[anomaly] = Label.A


seeds = range(12, 212)

path = Path("/home/hombit/gz2")
dataset_obj = GalaxyZoo2Dataset(path)
compare_zoo = Compare(dataset_obj, iterations=100, n_jobs=1, sampletrees_per_batch=1<<16).run(seeds)
compare_zoo.plot("Galaxy Zoo 2 (Anything odd? 90%)", savefig=True)
with open("galaxyzoo2_compare.pickle", "wb") as fh:
pickle.dump(compare_zoo, fh)
2 changes: 1 addition & 1 deletion rust/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion rust/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "coniferest"
version = "0.1.0"
version = "0.0.16"
edition = "2021"

[lib]
Expand Down
Loading