Feat pytorch calibrator #190
Draft
JemmaLDaniel wants to merge 25 commits into main
…ation and config arguments
…ing training history recording fix: convert ProbabilityCalibrator OmegaConf objects into plain Python types
…vo backwards compatibility
…s-compatibility' into feat-pytorch-calibrator
…formatting from TrainingHistory.plot()
Summary
Replaces the scikit-learn `MLPClassifier`-based calibrator with a custom PyTorch neural network, enabling larger and more customisable models for pre-trained calibration of InstaNovo predictions. Introduces a two-phase training pipeline (compute features to disk, then train from Parquet) to support training on large-scale datasets (20M+ spectra across 50+ projects) without loading everything into RAM.

Key changes
PyTorch calibrator & training pipeline
- `ProbabilityCalibrator` now wraps a `CalibratorNetwork` (`nn.Module`) with configurable hidden dims, dropout, and training hyperparameters. Automatic GPU detection with CPU fallback.
- `model.safetensors` + `config.json` (architecture, hyperparameters, resolved feature configs via `get_config()`). Pickle and sklearn model support is dropped.
- `TrainingHistory` dataclass records epoch-level train/val losses and accuracies, with JSON persistence and plotting.
- `winnow train` supports pre-computed features via `features_path`/`val_features_path` (single file or directory of Parquets), or a single-phase flow with automatic `validation_fraction` split (with data leakage warning).
- `FeatureDataset`: PyTorch `Dataset` wrapper that loads from numpy arrays or Parquet files. `from_parquet()` supports directories for multi-project concatenation.
- `resolve_data_path`: Resolves local paths or downloads from HuggingFace Hub, enabling HF-hosted datasets and models.
- `get_config()` converts `DictConfig`/`ListConfig` to plain Python types before JSON serialisation.
- `tqdm` progress bar.

Per-experiment iRT regression (merged from `fix-per-experiment-irt-regression`)
The following changes were developed on the `fix-per-experiment-irt-regression` branch and merged into this branch:
- `RetentionTimeFeature` now trains a `LinearRegression` per `experiment_name` (replacing the single global `MLPRegressor`), with configurable `min_train_points` threshold.
- `RetentionTimeFeature` regressors saved/loaded via safetensors (keyed by experiment name) instead of pickle, with `save_regressors()`/`load_regressors()` API and CLI support via `irt_regressor_output_path`/`irt_regressor_path`.
- `prosit_intensity_model_name` → `intensity_model_name`, `check_valid_chimeric_prosit_prediction` → `check_valid_chimeric_prediction`.

Prediction column remapping (merged from `feat-prediction-column-remapping-for-instanovo-backwards-compatibility`)
The following changes were developed on a separate branch and merged into this branch:
- `InstaNovoDatasetLoader` accepts a `column_mapping` dict for backwards compatibility with older InstaNovo versions that use different CSV column headers.
- `residue_remapping`: `InstaNovoDatasetLoader.residue_remapping` is now optional (defaults to `None`).

Data pipeline & build
- `compute-features` supports folders of per-experiment files, with `experiment_name` detection from DataFrame columns or file basenames.
- `metadata_output_path` (full CSV for EDA) and `training_matrix_output_path` (lean numeric Parquet for training).
- `torch`, `safetensors`, and `polars` added as direct dependencies in `pyproject.toml`.
- `winnow.utils` package: Added with `paths.py` (`resolve_data_path`).
- `docs`, `docs-serve`, `clean-docs`, and `check-build` targets.
- `docs/api/calibration.md`, `docs/cli.md`, and `docs/configuration.md` to reflect all changes.

Files changed (24 files, +2652/−761)
- `winnow/calibration/calibrator.py`, `winnow/calibration/calibration_features.py`
- `winnow/datasets/feature_dataset.py`, `winnow/utils/paths.py`
- `winnow/datasets/data_loaders.py`
- `winnow/scripts/main.py`
- `calibrator.yaml`, `train.yaml`, `predict.yaml`, `compute_features.yaml`, `data_loader/instanovo.yaml`
- `docs/api/calibration.md`, `docs/cli.md`, `docs/configuration.md`
- `test_calibrator.py`, `test_calibration_features.py`, `test_data_loaders.py`, `test_feature_dataset.py`, `test_paths.py`
- `pyproject.toml`, `requirements.txt`, `uv.lock`, `Makefile`, `.gitignore`

Test plan
- `test_end_to_end_fit_predict`: full pipeline (add features → compute → fit → predict)
- `test_save_load_roundtrip`: save, load, predict on new data (verifies importlib feature reconstruction)
- `test_save_load_weights_and_normalization_match`: weights and feature_mean/std survive roundtrip exactly
- `test_early_stopping_triggers`: patience exhausted with inverted val labels
- `test_get_config_converts_omegaconf_to_plain`: DictConfig/ListConfig → plain Python, JSON-serialisable
- `test_fit_from_features_returns_history` / `test_fit_from_features_with_validation`: two-phase training path
- `test_save_and_load_regressors_safetensors`: iRT regressor roundtrip via safetensors
- `test_prepare_per_experiment` / `test_prepare_skips_preloaded_experiments`: per-experiment iRT regression
- `test_hf_download_fallback` / `test_hf_download_failure_raises_with_context`: mocked HF download path
- `test_dataloader_integration`: FeatureDataset works with PyTorch DataLoader
- `test_from_parquet_directory_preserves_values`: multi-file Parquet loading preserves values
- `test_default_column_mapping` / `test_custom_column_mapping_*` / `test_process_predictions_with_custom_column_mapping`: prediction column remapping
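
For reviewers who want a mental model of the early-stopping and history-persistence behaviour that the test plan exercises, here is a minimal standard-library sketch. It is not the code in this PR: `HistorySketch` and `run_with_early_stopping` are hypothetical names, the real `TrainingHistory` also tracks accuracies and plotting, and the actual test drives the stop with inverted validation labels, whereas this sketch simply feeds in validation losses that keep rising.

```python
import json
from dataclasses import asdict, dataclass, field


@dataclass
class HistorySketch:
    """Hypothetical stand-in for an epoch-level training history."""

    train_loss: list[float] = field(default_factory=list)
    val_loss: list[float] = field(default_factory=list)

    def to_json(self) -> str:
        # Plain lists of floats serialise directly to JSON.
        return json.dumps(asdict(self))

    @classmethod
    def from_json(cls, text: str) -> "HistorySketch":
        return cls(**json.loads(text))


def run_with_early_stopping(epoch_losses, patience):
    """Record (train_loss, val_loss) pairs until patience is exhausted.

    Returns the history and the 1-based epoch at which the loop stopped.
    """
    history = HistorySketch()
    best_val = float("inf")
    epochs_without_improvement = 0
    stopped_at = 0
    for epoch, (train_loss, val_loss) in enumerate(epoch_losses, start=1):
        history.train_loss.append(train_loss)
        history.val_loss.append(val_loss)
        stopped_at = epoch
        if val_loss < best_val:
            # New best validation loss: reset the patience counter.
            best_val = val_loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break
    return history, stopped_at


# Training loss keeps falling while validation loss rises after epoch 1,
# so with patience=2 the loop stops before consuming all five epochs.
losses = [(0.9, 0.5), (0.7, 0.6), (0.5, 0.7), (0.4, 0.8), (0.3, 0.9)]
history, stopped_at = run_with_early_stopping(losses, patience=2)

# JSON round-trip, analogous in spirit to persisting a training history.
restored = HistorySketch.from_json(history.to_json())
```

Keeping the history as plain lists of floats is what makes the JSON round-trip trivial in this sketch; whether the PR's dataclass is laid out the same way is not stated in the description.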