API Reference#
Configuration#
Config#
Dataclass holding all training and model hyperparameters.
| Parameter | Default | Description |
|---|---|---|
| **Training** | | |
| | | Random seed for reproducibility |
| | | Training batch size |
| | | Evaluation/inference batch size |
| | | Maximum training epochs |
| | | AdamW learning rate |
| | | Epochs without val loss improvement before stopping |
| **Architecture** | | |
| | | Dropout probability in MLP layers |
| | | Hidden layer width |
| | | Number of hidden layers in MLP |
| | | Max token length for tokenizer |
| **Optimization** | | |
| | | Max gradient norm for clipping |
| | | Linear warmup steps |
| **Embedding** | | |
| | | Encoder type |
| | | HuggingFace model ID (auto-resolved) |
| **Uncertainty** | | |
| | | Below this probability, flag as uncertain |
| | | Above this probability, flag as uncertain |
| **Cross-validation** | | |
| | | Number of stratified CV folds |
| | | Post-hoc calibration method |
| **Active learning** | | |
| | | Query strategy |
| | | Records per AL iteration |
| | | Initial random sample fraction |
| **Categorical encoding** | | |
| | | Minimum count for a categorical value to get its own embedding |
| **Class weighting** | | |
| **SAFE stopping** | | |
| | | Consecutive irrelevant records before stopping |
| | | Minimum fraction screened before stopping allowed |
| | | Random sample fraction for recall estimation |
| | | Switch model during screening phases |
Preset Configurations#
| Preset | Description |
|---|---|
| | Balanced defaults for general use |
| | Fewer epochs, larger batch size for quick experiments |
| | More epochs, lower learning rate for production |
| | Human-in-the-loop screening settings |
| | Domain-specific presets (science, medicine, general, modernbert) |
sentence_transformer_models#
Set of models that use the SentenceTransformer encoder (frozen, no fine-tuning).
Data#
preprocess_dataset#
Tokenize text columns and encode categorical/numeric features into a dataset. Pass `fitted_transforms=None` to fit from training data, or a `FittedTransforms` object to reuse on val/test sets. Returns `(CustomDataset, FittedTransforms)`.
column_specifications#
Dictionary specifying which DataFrame columns to use.
| Key | Value | Example |
|---|---|---|
| | List of text column names | |
| | List of categorical column names | |
| | List of numeric column names | |
| | String (single-label) or list (multi-label) | |
Numeric Transforms#
| Transform | Description |
|---|---|
| | Subtract minimum value |
| | Divide by maximum value |
| | Subtract mean |
| | Quantile transform to normal distribution |
| | RobustScaler (median + IQR), better for outliers |
| | log1p then quantile transform, for skewed features |
load_data#
Load data from CSV or Excel files.
split_data#
Split a DataFrame into train, validation, and test sets.
FittedTransforms#
Stores fitted parameters from training data for reuse on val/test sets. Serialize with `to_dict()` and restore with `FittedTransforms.from_dict(d)`.
CustomDataset#
Dataset holding tokenized text, categorical/numeric features, and labels.
create_dataloader#
Create DataLoader with custom collate function.
collate_fn#
Custom collate to handle text lists in batches.
Model#
PubMLP#
Multi-layer perceptron that combines transformer embeddings with categorical and numeric features.
| Parameter | Description |
|---|---|
| | List of vocab sizes for `nn.Embedding` per categorical column |
| | 1 for single-label, N for multi-label |
Training & Evaluation#
train_evaluate_model#
Full training loop with validation, early stopping, and test evaluation. Returns `(train_losses, val_losses, train_accs, val_accs, test_acc, best_val_loss, best_model_state, best_epoch)`.
calculate_loss#
Average loss across all batches.
calculate_accuracy#
Accuracy (%) across all batches; multi-label returns average per-label accuracy.
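For the multi-label case, "average per-label accuracy" can be illustrated with a small standalone sketch (plain lists stand in for the library's batched tensors; `per_label_accuracy` is a hypothetical name, not the library's API):

```python
def per_label_accuracy(preds, labels):
    """Average accuracy across labels for multi-label predictions.

    preds/labels: lists of equal-length 0/1 lists, one inner list per sample.
    """
    n_labels = len(labels[0])
    accs = []
    for j in range(n_labels):
        correct = sum(1 for p, y in zip(preds, labels) if p[j] == y[j])
        accs.append(correct / len(labels))
    return 100.0 * sum(accs) / n_labels  # percentage, matching calculate_accuracy

preds  = [[1, 0], [0, 1], [1, 1]]
labels = [[1, 0], [0, 0], [1, 1]]
print(per_label_accuracy(preds, labels))  # label 0: 3/3 correct, label 1: 2/3
```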
calculate_pos_weight#
Compute `pos_weight` from the label distribution (`neg_count / pos_count` per label).
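The computation can be sketched in a few lines (a standalone illustration, not the library's code; the fallback to 1.0 when a label has no positives is an assumption):

```python
def pos_weight(labels):
    """neg_count / pos_count per label, usable as BCEWithLogitsLoss pos_weight.

    labels: list of 0/1 lists, one inner list per sample.
    """
    n_labels = len(labels[0])
    weights = []
    for j in range(n_labels):
        pos = sum(y[j] for y in labels)
        neg = len(labels) - pos
        weights.append(neg / pos if pos else 1.0)  # assumed fallback for empty labels
    return weights

labels = [[1, 0], [0, 0], [1, 0], [0, 1]]
print(pos_weight(labels))  # [1.0, 3.0]
```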
calculate_evaluation_metrics#
Compute classification report, confusion matrix, and ROC-AUC. Single label: returns `accuracy`, `precision`, `recall`, `specificity`, `f1_score`, `roc_auc`. Multi-label: returns `per_label` metrics, `macro_f1`, `hamming_loss`.
calculate_wss_at_recall#
WSS@recall (Cohen et al., 2006): fraction of screening effort saved at target recall.
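The underlying formula is simple enough to show directly. In this sketch (counts are hypothetical), `tn` is the number of irrelevant records the ranking lets a screener skip and `fn` the relevant records missed at the target recall:

```python
def wss(tn, fn, n, recall=0.95):
    """Work Saved over Sampling at a recall level (Cohen et al., 2006):
    fraction of records that need not be screened, minus the (1 - recall)
    saving random sampling would already achieve."""
    return (tn + fn) / n - (1.0 - recall)

# 1000 records; the ranking lets us skip 700 true negatives at the cost of 5 misses
print(wss(tn=700, fn=5, n=1000, recall=0.95))  # ≈ 0.655
```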
calculate_ndcg#
NDCG via `sklearn.metrics.ndcg_score` (Järvelin & Kekäläinen, 2002).
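The library delegates to scikit-learn; the quantity being computed can be sketched in plain Python (relevances and scores below are illustrative):

```python
import math

def dcg(relevances):
    """Discounted cumulative gain: sum of rel_i / log2(i + 2) over rank i."""
    return sum(rel / math.log2(i + 2) for i, rel in enumerate(relevances))

def ndcg(y_true, y_score):
    """DCG of the predicted ordering divided by DCG of the ideal ordering."""
    order = sorted(range(len(y_score)), key=lambda i: y_score[i], reverse=True)
    ideal = sorted(y_true, reverse=True)
    return dcg([y_true[i] for i in order]) / dcg(ideal)

print(ndcg([3, 0, 2], [0.9, 0.1, 0.5]))  # 1.0: the ranking matches the ideal ordering
```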
cross_validate#
Stratified K-fold cross-validation with per-fold metrics.
plot_results#
Plot training/validation loss and accuracy curves.
TemperatureScaling#
Post-hoc temperature scaling for model calibration.
calibrate_model#
Fit temperature scaling on validation data.
collect_logits#
Collect raw logits from a trained model.
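The calibration workflow above (collect validation logits, fit a temperature, rescale) can be sketched standalone. This toy binary version uses a grid search where real implementations typically use LBFGS, and all numbers are illustrative:

```python
import math

def nll(logits, labels, T):
    """Mean negative log-likelihood of binary logits rescaled by temperature T."""
    total = 0.0
    for z, y in zip(logits, labels):
        p = 1.0 / (1.0 + math.exp(-z / T))
        p = min(max(p, 1e-12), 1.0 - 1e-12)
        total += -(y * math.log(p) + (1 - y) * math.log(1.0 - p))
    return total / len(logits)

def fit_temperature(logits, labels):
    """Grid-search the temperature that minimizes validation NLL."""
    grid = [t / 100 for t in range(50, 501)]  # T in [0.5, 5.0]
    return min(grid, key=lambda T: nll(logits, labels, T))

# Overconfident validation logits: the wrong prediction (3.5 with label 0) is far too confident
logits = [4.0, 3.5, -4.0, 3.8, -3.6, 0.5]
labels = [1, 0, 0, 1, 0, 1]
T = fit_temperature(logits, labels)
print(T > 1.0)  # True: dividing logits by T > 1 softens the overconfidence
```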
Prediction#
predict_model#
Run inference and return predictions and probabilities. Single label: flat lists. Multi-label: list of lists.
get_predictions_and_labels#
Get predictions, probabilities, and true labels from a labeled dataloader.
flag_uncertain#
Flag predictions with probabilities in an uncertain range for human review. Multi-label: returns list of lists of bools.
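A minimal sketch of the idea (function signature, threshold names, and defaults here are illustrative, not the library's):

```python
def flag_uncertain(probs, lower=0.25, upper=0.75):
    """Flag predictions whose probability falls inside the uncertain band,
    i.e. too far from both confident classes to trust without human review."""
    return [lower <= p <= upper for p in probs]

print(flag_uncertain([0.05, 0.5, 0.72, 0.97]))  # [False, True, True, False]
```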
Screening#
regex_screen#
Screen a dataset using regex patterns with optional semantic similarity scoring.
```python
from pubmlp import regex_screen

results = regex_screen("records.csv", inclusion_patterns=["intervention", "randomized"])
```
extract_window_evidence#
Extract word windows around regex matches.
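The idea can be sketched standalone (matching word by word for simplicity, so multi-word patterns behave differently from the library's implementation; the example text and names are illustrative):

```python
import re

def window_evidence(text, pattern, window=5):
    """Return ±window-word snippets around each word matching `pattern`."""
    words = text.split()
    snippets = []
    for i, word in enumerate(words):
        if re.search(pattern, word, re.IGNORECASE):
            lo, hi = max(0, i - window), min(len(words), i + window + 1)
            snippets.append(" ".join(words[lo:hi]))
    return snippets

text = "We ran a randomized controlled trial of the reading intervention in schools"
print(window_evidence(text, r"randomi[sz]ed", window=2))
# ['ran a randomized controlled trial']
```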
extract_sentence_evidence#
Extract complete sentences containing regex matches.
extract_all_evidence#
Extract evidence from specified fields in a DataFrame row.
format_evidence_display#
Format evidence list as ‘field: text; field: text; …’.
calculate_semantic_scores#
Calculate cosine similarity between evidence texts and criterion description.
generate_descriptions#
Draft criterion descriptions from regex pattern terms. Extracts literal terms from each pattern and composes a natural language description. The user should review and refine each description before use.
```python
from pubmlp import generate_descriptions

patterns = {'math': {'pattern': r'\b(algebra|geometry)\w*\b'}}
drafts = generate_descriptions(patterns, domain='K-12 education')
# drafts['math']['description'] → "In K-12 education, the study addresses math..."
# drafts['math']['source'] → 'generated'
```
confirm_descriptions#
Validate that all criteria have non-empty descriptions and return confirmed patterns. Optionally saves to JSON for reproducibility. Raises ValueError if any description is empty.
```python
from pubmlp import confirm_descriptions

# After user edits drafts['math']['description']
confirmed = confirm_descriptions(drafts, save_path='confirmed.json')
# Returns dict ready for regex_screen() and score_full_text()
```
score_full_text#
Score full record text (title + abstract) against each criterion description via cosine similarity. Scores all records, including those with no regex match. Adds a `{criterion}_semantic_full` column per criterion.
```python
from pubmlp import score_full_text

df = score_full_text(df, confirmed_patterns, fields=['title', 'abstract'])
# df['math_semantic_full'] → cosine similarity for every record
```
compare_screening_configs#
Run regex_screen with multiple configurations (different descriptions, window sizes, units) and return a summary DataFrame comparing match counts, semantic score distributions, and overlap.
```python
from pubmlp import compare_screening_configs

comparison = compare_screening_configs('data.xlsx', {
    'specific': {'inclusion_patterns': patterns_v1},
    'broad': {'inclusion_patterns': patterns_v2, 'unit': 'window', 'window_size': 10},
})
# Returns DataFrame: config, criterion, n_matched, match_pct, semantic medians
```
create_stratified_sample#
Create a stratified random sample with regex pattern highlights for human coding.
save_sample_excel#
Save sample to Excel with conditional formatting for review.
apply_conditional_formatting#
Apply conditional formatting to Excel coding sheet (headers green, pattern counts yellow).
count_pattern_matches#
Count regex matches in text (case-insensitive).
highlight_pattern_matches#
Return up to 3 matched snippets with context for visual inspection.
Active Learning#
safe_stratified_split#
Stratified train/val split with random fallback when rare classes prevent stratification.
```python
from pubmlp import safe_stratified_split

train_idx, val_idx = safe_stratified_split(X, y, test_size=0.2, random_state=42)
```
select_query_batch#
Select the most uncertain samples for human review.
create_review_batch#
Create a review batch DataFrame with model probability and prediction columns.
compare_reviewers#
Compute inter-rater agreement (kappa + agreement rate) between model and human.
merge_human_labels#
Merge human decisions from review batch back into the main DataFrame.
ALState#
Dataclass tracking active learning iteration state.
simulate_al#
Offline AL simulation using ground-truth labels; `model_fn(train_df, unlabeled_df)` returns probabilities.
rank_by_hybrid_max_uncertainty#
95% max-relevance + 5% uncertainty ranking strategy.
rank_by_hybrid_max_random#
95% max-relevance + 5% random ranking strategy.
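One plausible reading of the 95/5 split in these two strategies (the exact mechanics are an assumption): fill most of the batch by descending relevance score and the remainder by uncertainty, or, for the random variant, by a random draw. The small example below uses a 60/40 split so the effect is visible in a batch of three:

```python
import math

def hybrid_rank(probs, k, frac_max=0.95):
    """Sketch of a hybrid query strategy: ~frac_max of the batch by maximum
    relevance probability, the rest by uncertainty (closest to 0.5)."""
    n_max = math.ceil(frac_max * k)
    by_relevance = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    picked = by_relevance[:n_max]
    rest = set(range(len(probs))) - set(picked)
    by_uncertainty = sorted(rest, key=lambda i: abs(probs[i] - 0.5))
    return picked + by_uncertainty[: k - n_max]

probs = [0.95, 0.10, 0.48, 0.80, 0.60]
print(hybrid_rank(probs, k=3, frac_max=0.6))
# [0, 3, 2]: last slot goes to the most uncertain record, not the next-highest probability
```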
Stopping Rules#
recall_target_test#
Statistical stopping test for recall-based screening. Returns stop decision, recall lower bound, and maximum missed relevant records.
should_stop#
Evaluate whether screening can stop based on SAFE criterion.
update_stopping_state#
Update stopping state counters after a human screening decision.
estimate_recall#
Wilson score lower bound estimate of recall.
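The Wilson score lower bound can be computed standalone (a sketch; the sample counts below are illustrative):

```python
import math

def wilson_lower(successes, n, z=1.96):
    """Wilson score lower bound for a proportion (95% confidence by default).

    With `successes` relevant records found among `n` sampled relevant
    records, this bounds recall from below."""
    if n == 0:
        return 0.0
    phat = successes / n
    denom = 1 + z * z / n
    centre = phat + z * z / (2 * n)
    margin = z * math.sqrt(phat * (1 - phat) / n + z * z / (4 * n * n))
    return (centre - margin) / denom

lb = wilson_lower(47, 50)
print(round(lb, 3))  # ≈ 0.838, noticeably below the point estimate 0.94
```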
generate_stopping_report#
Generate a summary report of stopping criteria.
calculate_wss#
Calculate Work Saved over Sampling at a given recall level.
transition_phase#
Advance phase based on screening progress.
StoppingState#
Dataclass tracking stopping-rule state across iterations.
Audit#
AuditTrail#
Record and persist screening decisions for reproducibility.
AuditEntry#
Single audit log entry dataclass.
summarize_human_decisions#
Summarize human reviewer decisions from an audit trail.
generate_prisma_report#
Generate a PRISMA-style flow diagram report.
interpret_kappa#
Interpret Cohen’s kappa agreement level.
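A sketch using the commonly cited Landis & Koch (1977) bands (the library's exact labels and cutoffs may differ):

```python
def interpret_kappa(kappa):
    """Map Cohen's kappa to the Landis & Koch agreement bands."""
    if kappa < 0:
        return "poor"
    if kappa <= 0.20:
        return "slight"
    if kappa <= 0.40:
        return "fair"
    if kappa <= 0.60:
        return "moderate"
    if kappa <= 0.80:
        return "substantial"
    return "almost perfect"

print(interpret_kappa(0.72))  # substantial
```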
Utilities#
get_device#
Return the best available PyTorch device (CUDA or CPU).
auto_batch_size#
Suggest a batch size based on available GPU memory.
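A hypothetical heuristic of the kind such a helper might use (all constants and the per-sample estimate are illustrative; the real function inspects the actual device via PyTorch):

```python
def suggest_batch_size(free_gib, per_sample_mib=48, floor=8, ceiling=256):
    """Largest power of two whose estimated memory footprint fits the budget."""
    budget = free_gib * 1024 / per_sample_mib  # rough samples-in-memory estimate
    size = floor
    while size * 2 <= min(budget, ceiling):
        size *= 2
    return size

print(suggest_batch_size(8.0))  # 128: ~170 samples fit, next power of two below that
print(suggest_batch_size(2.0))  # 32: tighter budget
```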
unpack_batch#
Move batch tensors to device and return unpacked components.