API Reference#

Configuration#

Config#

Dataclass holding all training and model hyperparameters.

| Parameter | Default | Description |
|---|---|---|
| **Training** | | |
| `random_seed` | `42` | Random seed for reproducibility |
| `batch_size` | `16` | Training batch size |
| `eval_batch_size` | `32` | Evaluation/inference batch size |
| `epochs` | `10` | Maximum training epochs |
| `learning_rate` | `2e-5` | AdamW learning rate |
| `early_stopping_patience` | `3` | Epochs without validation-loss improvement before stopping |
| **Architecture** | | |
| `dropout_rate` | `0.2` | Dropout probability in MLP layers |
| `mlp_hidden_size` | `64` | Hidden layer width |
| `n_hidden_layers` | `1` | Number of hidden layers in MLP |
| `max_length` | `512` | Maximum token length for the tokenizer |
| **Optimization** | | |
| `gradient_clip_norm` | `1.0` | Maximum gradient norm for clipping |
| `warmup_steps` | `0` | Linear warmup steps |
| **Embedding** | | |
| `embedding_model` | `'bert'` | Encoder type (`bert`, `scibert`, `pubmedbert`, `modernbert`, `sentence-transformer`, `bge-small`) |
| `model_name` | `None` | HuggingFace model ID (auto-resolved from `embedding_model` if `None`) |
| `pooling_strategy` | `'auto'` | `'auto'`, `'pooler'`, or `'mean'` |
| **Uncertainty** | | |
| `uncertainty_low` | `0.3` | Lower bound of the uncertain probability range |
| `uncertainty_high` | `0.7` | Upper bound of the uncertain probability range; predictions with probabilities between the two bounds are flagged as uncertain |
| **Cross-validation** | | |
| `n_folds` | `5` | Number of stratified CV folds |
| `calibration_method` | `'temperature'` | Post-hoc calibration method |
| **Active learning** | | |
| `al_query_strategy` | `'uncertainty'` | Query strategy (`uncertainty`, `random`, `max_relevance`, `hybrid_max_uncertainty`, `hybrid_max_random`) |
| `al_batch_size` | `20` | Records per AL iteration |
| `al_initial_sample_pct` | `0.1` | Initial random sample fraction |
| **Categorical encoding** | | |
| `rare_threshold` | `5` | Minimum count for a categorical value to get its own embedding |
| **Class weighting** | | |
| `pos_weight` | `'auto'` | `'auto'` computes from training labels; `None` disables |
| **SAFE stopping** | | |
| `safe_consecutive_irrelevant` | `50` | Consecutive irrelevant records before stopping |
| `safe_min_screened_pct` | `0.5` | Minimum fraction screened before stopping is allowed |
| `safe_random_sample_pct` | `0.1` | Random sample fraction for recall estimation |
| `safe_switch_model` | `False` | Switch model during screening phases |
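
A minimal sketch of how a dataclass config like this is typically constructed and overridden. The `Config` class below is a hypothetical stand-in mirroring a few fields from the table, not the library's actual class:

```python
from dataclasses import dataclass, replace

# Hypothetical stand-in mirroring a few Config fields from the table above.
@dataclass
class Config:
    random_seed: int = 42
    batch_size: int = 16
    learning_rate: float = 2e-5
    uncertainty_low: float = 0.3
    uncertainty_high: float = 0.7

base = Config()
# dataclasses.replace returns a copy with selected fields overridden.
quick = replace(base, batch_size=64, learning_rate=5e-5)
```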

Preset Configurations#

| Preset | Description |
|---|---|
| `default_config` | Balanced defaults for general use |
| `fast_config` | Fewer epochs, larger batch size for quick experiments |
| `robust_config` | More epochs, lower learning rate for production |
| `hitl_config` | Human-in-the-loop screening settings |
| `domain_configs` | Domain-specific presets (science, medicine, general, modernbert) |

sentence_transformer_models#

Set of models that use the SentenceTransformer encoder (frozen, no fine-tuning).

Data#

preprocess_dataset#

Tokenize text columns and encode categorical/numeric features into a dataset. Pass fitted_transforms=None to fit from training data, or a FittedTransforms object to reuse on val/test sets. Returns (CustomDataset, FittedTransforms).

column_specifications#

Dictionary specifying which DataFrame columns to use.

| Key | Value | Example |
|---|---|---|
| `text_cols` | List of text column names | `["TI", "AB"]` |
| `categorical_cols` | List of categorical column names | `["SO"]` |
| `numeric_cols` | List of numeric column names | `["PY"]` |
| `label_col` | String (single-label) or list (multi-label) | `"label"` or `["label_1", "label_2"]` |

Numeric Transforms#

| Transform | Description |
|---|---|
| `min` | Subtract minimum value |
| `max` | Divide by maximum value |
| `mean` | Subtract mean |
| `quantile` | Quantile transform to normal distribution |
| `robust` | RobustScaler (median + IQR), better for outliers |
| `log1p` | log1p then quantile transform, for skewed features |
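
To illustrate why the `robust` option helps with outliers, here is what median/IQR scaling does using scikit-learn's `RobustScaler` directly (a sketch of the underlying transform with toy data, not the library's own code):

```python
import numpy as np
from sklearn.preprocessing import RobustScaler

# One feature with a large outlier; mean-based scaling would be dominated by it.
x = np.array([[1.0], [2.0], [3.0], [100.0]])
scaled = RobustScaler().fit_transform(x)  # (x - median) / IQR

# The median maps to 0 regardless of the outlier.
print(float(np.median(scaled)))  # → 0.0
```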

load_data#

Load data from CSV or Excel files.

split_data#

Split a DataFrame into train, validation, and test sets.

FittedTransforms#

Stores fitted parameters from training data for reuse on val/test sets. Serialize with to_dict() and restore with FittedTransforms.from_dict(d).

CustomDataset#

Dataset holding tokenized text, categorical/numeric features, and labels.

create_dataloader#

Create DataLoader with custom collate function.

collate_fn#

Custom collate to handle text lists in batches.

Model#

PubMLP#

Multi-layer perceptron that combines transformer embeddings with categorical and numeric features.

| Parameter | Description |
|---|---|
| `categorical_vocab_sizes` | List of vocab sizes for `nn.Embedding` per categorical column |
| `output_size` | 1 for single-label, N for multi-label |
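
Conceptually, the MLP input is assembled by concatenating the transformer text embedding, one learned embedding per categorical column, and the raw numeric features into a single vector. An illustrative numpy sketch with made-up dimensions (not the model's actual forward pass):

```python
import numpy as np

rng = np.random.default_rng(0)

text_emb = rng.normal(size=768)        # transformer sentence embedding
cat_table = rng.normal(size=(50, 8))   # nn.Embedding-style lookup table: vocab 50, dim 8
cat_emb = cat_table[7]                 # embedding row for categorical value index 7
numeric = np.array([0.25])             # one scaled numeric feature

mlp_input = np.concatenate([text_emb, cat_emb, numeric])
print(mlp_input.shape)  # → (777,)
```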

Training & Evaluation#

train_evaluate_model#

Full training loop with validation, early stopping, and test evaluation. Returns (train_losses, val_losses, train_accs, val_accs, test_acc, best_val_loss, best_model_state, best_epoch).

calculate_loss#

Average loss across all batches.

calculate_accuracy#

Accuracy (%) across all batches; multi-label returns average per-label accuracy.

calculate_pos_weight#

Compute pos_weight from label distribution (neg_count / pos_count per label).
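
The `neg_count / pos_count` rule stated above can be sketched directly:

```python
# Per-label pos_weight = negatives / positives, as used with
# BCEWithLogitsLoss-style losses to upweight the rare positive class.
labels = [1, 0, 0, 0, 1, 0, 0, 0]  # 2 positives, 6 negatives
pos = sum(labels)
neg = len(labels) - pos
pos_weight = neg / pos
print(pos_weight)  # → 3.0
```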

calculate_evaluation_metrics#

Compute classification report, confusion matrix, and ROC-AUC. Single label: returns accuracy, precision, recall, specificity, f1_score, roc_auc. Multi-label: returns per_label metrics, macro_f1, hamming_loss.

calculate_wss_at_recall#

WSS@recall (Cohen et al., 2006): fraction of screening effort saved at target recall.
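
The formula from Cohen et al. (2006) is WSS@R = (TN + FN) / N − (1 − R). A self-contained sketch with toy numbers (not the library's implementation):

```python
def wss_at_recall(tn, fn, n_total, recall):
    """Work Saved over Sampling at a target recall (Cohen et al., 2006)."""
    return (tn + fn) / n_total - (1.0 - recall)

# Example: 10,000 records; the ranking lets us skip 8,400 true negatives
# (plus 100 missed positives) while still achieving 95% recall.
print(round(wss_at_recall(tn=8400, fn=100, n_total=10000, recall=0.95), 4))  # → 0.8
```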

calculate_ndcg#

NDCG via sklearn.metrics.ndcg_score (Järvelin & Kekäläinen, 2002).
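
A direct example of the underlying scikit-learn call (documented API; the toy numbers are illustrative):

```python
from sklearn.metrics import ndcg_score

true_relevance = [[1, 0, 0]]       # only the first record is relevant
model_scores = [[0.9, 0.2, 0.1]]   # model ranks the relevant record first
print(ndcg_score(true_relevance, model_scores))  # → 1.0
```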

cross_validate#

Stratified K-fold cross-validation with per-fold metrics.

plot_results#

Plot training/validation loss and accuracy curves.

TemperatureScaling#

Post-hoc temperature scaling for model calibration.
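
Temperature scaling divides the raw logits by a learned scalar T before the sigmoid/softmax; T > 1 softens overconfident probabilities toward 0.5. A sketch of the math only (not the class's implementation):

```python
import math

def calibrated_prob(logit, temperature):
    # Divide the raw logit by T, then apply the sigmoid.
    return 1.0 / (1.0 + math.exp(-logit / temperature))

raw = calibrated_prob(2.0, temperature=1.0)   # ≈ 0.881 (overconfident)
soft = calibrated_prob(2.0, temperature=2.0)  # ≈ 0.731 (pulled toward 0.5)
```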

calibrate_model#

Fit temperature scaling on validation data.

collect_logits#

Collect raw logits from a trained model.

Prediction#

predict_model#

Run inference and return predictions and probabilities. Single label: flat lists. Multi-label: list of lists.

get_predictions_and_labels#

Get predictions, probabilities, and true labels from a labeled dataloader.

flag_uncertain#

Flag predictions with probabilities in an uncertain range for human review. Multi-label: returns list of lists of bools.
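
The single-label flagging rule can be sketched as follows, assuming the range semantics of the Config bounds `uncertainty_low` and `uncertainty_high`:

```python
def in_uncertain_range(probs, low=0.3, high=0.7):
    # Flag predictions whose probability falls strictly between the bounds.
    return [low < p < high for p in probs]

print(in_uncertain_range([0.05, 0.45, 0.65, 0.95]))  # → [False, True, True, False]
```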

Screening#

regex_screen#

Screen a dataset using regex patterns with optional semantic similarity scoring.

```python
from pubmlp import regex_screen

results = regex_screen("records.csv", inclusion_patterns=["intervention", "randomized"])
```

extract_window_evidence#

Extract word windows around regex matches.

extract_sentence_evidence#

Extract complete sentences containing regex matches.

extract_all_evidence#

Extract evidence from specified fields in a DataFrame row.

format_evidence_display#

Format evidence list as `field: text; field: text; …`.

calculate_semantic_scores#

Calculate cosine similarity between evidence texts and criterion description.

generate_descriptions#

Draft criterion descriptions from regex pattern terms. Extracts literal terms from each pattern and composes a natural language description. The user should review and refine each description before use.

```python
from pubmlp import generate_descriptions

patterns = {'math': {'pattern': r'\b(algebra|geometry)\w*\b'}}
drafts = generate_descriptions(patterns, domain='K-12 education')
# drafts['math']['description'] → "In K-12 education, the study addresses math..."
# drafts['math']['source'] → 'generated'
```

confirm_descriptions#

Validate that all criteria have non-empty descriptions and return confirmed patterns. Optionally saves to JSON for reproducibility. Raises ValueError if any description is empty.

```python
from pubmlp import confirm_descriptions

# After user edits drafts['math']['description']
confirmed = confirm_descriptions(drafts, save_path='confirmed.json')
# Returns dict ready for regex_screen() and score_full_text()
```

score_full_text#

Score full record text (title + abstract) against each criterion description via cosine similarity. Scores all records including those with no regex match. Adds {criterion}_semantic_full column per criterion.

```python
from pubmlp import score_full_text

df = score_full_text(df, confirmed_patterns, fields=['title', 'abstract'])
# df['math_semantic_full'] → cosine similarity for every record
```

compare_screening_configs#

Run regex_screen with multiple configurations (different descriptions, window sizes, units) and return a summary DataFrame comparing match counts, semantic score distributions, and overlap.

```python
from pubmlp import compare_screening_configs

comparison = compare_screening_configs('data.xlsx', {
    'specific': {'inclusion_patterns': patterns_v1},
    'broad': {'inclusion_patterns': patterns_v2, 'unit': 'window', 'window_size': 10},
})
# Returns DataFrame: config, criterion, n_matched, match_pct, semantic medians
```

create_stratified_sample#

Create a stratified random sample with regex pattern highlights for human coding.

save_sample_excel#

Save sample to Excel with conditional formatting for review.

apply_conditional_formatting#

Apply conditional formatting to Excel coding sheet (headers green, pattern counts yellow).

count_pattern_matches#

Count regex matches in text (case-insensitive).

highlight_pattern_matches#

Return up to 3 matched snippets with context for visual inspection.

Active Learning#

safe_stratified_split#

Stratified train/val split with random fallback when rare classes prevent stratification.

```python
from pubmlp import safe_stratified_split

train_idx, val_idx = safe_stratified_split(X, y, test_size=0.2, random_state=42)
```

select_query_batch#

Select the most uncertain samples for human review.

create_review_batch#

Create a review batch DataFrame with model probability and prediction columns.

compare_reviewers#

Compute inter-rater agreement (kappa + agreement rate) between model and human.

merge_human_labels#

Merge human decisions from review batch back into the main DataFrame.

ALState#

Dataclass tracking active learning iteration state.

simulate_al#

Offline AL simulation using ground truth labels; model_fn(train_df, unlabeled_df) returns probabilities.

rank_by_hybrid_max_uncertainty#

95% max-relevance + 5% uncertainty ranking strategy.

rank_by_hybrid_max_random#

95% max-relevance + 5% random ranking strategy.
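
A sketch of the 95/5 split behind these hybrid strategies (assumed mechanics; the library's actual tie-breaking and sampling may differ):

```python
import numpy as np

def hybrid_max_uncertainty(probs, batch_size, exploit_frac=0.95):
    """Pick a query batch: ~95% highest-relevance records, rest most-uncertain."""
    probs = np.asarray(probs)
    n_exploit = int(round(batch_size * exploit_frac))
    by_relevance = np.argsort(-probs)                 # highest probability first
    by_uncertainty = np.argsort(np.abs(probs - 0.5))  # closest to 0.5 first
    picked = [int(i) for i in by_relevance[:n_exploit]]
    for idx in by_uncertainty:                        # fill remainder by uncertainty
        if len(picked) == batch_size:
            break
        if int(idx) not in picked:
            picked.append(int(idx))
    return picked
```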

Stopping Rules#

recall_target_test#

Statistical stopping test for recall-based screening. Returns stop decision, recall lower bound, and maximum missed relevant records.

should_stop#

Evaluate whether screening can stop based on SAFE criterion.

update_stopping_state#

Update stopping state counters after a human screening decision.

estimate_recall#

Wilson score lower bound estimate of recall.
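
The standard Wilson score lower bound for a binomial proportion, sketched directly (not the library's code):

```python
import math

def wilson_lower_bound(successes, n, z=1.96):
    """Lower bound of the Wilson score interval for a binomial proportion."""
    if n == 0:
        return 0.0
    p = successes / n
    denom = 1 + z**2 / n
    center = p + z**2 / (2 * n)
    margin = z * math.sqrt(p * (1 - p) / n + z**2 / (4 * n**2))
    return (center - margin) / denom

# 10 of 10 relevant records found in the random sample: recall is probably
# high, but the small sample keeps the lower bound well below 1.
print(round(wilson_lower_bound(10, 10), 3))  # → 0.722
```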

generate_stopping_report#

Generate a summary report of stopping criteria.

calculate_wss#

Calculate Work Saved over Sampling at a given recall level.

transition_phase#

Advance phase based on screening progress.

StoppingState#

Dataclass tracking stopping-rule state across iterations.

Audit#

AuditTrail#

Record and persist screening decisions for reproducibility.

AuditEntry#

Single audit log entry dataclass.

summarize_human_decisions#

Summarize human reviewer decisions from an audit trail.

generate_prisma_report#

Generate a PRISMA-style flow diagram report.

interpret_kappa#

Interpret Cohen’s kappa agreement level.
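
A sketch using the conventional Landis & Koch bands such an interpreter typically applies (the library's exact labels and cut-points may differ):

```python
def interpret_kappa(kappa):
    """Map Cohen's kappa to the conventional Landis & Koch agreement labels."""
    if kappa < 0:
        return "poor"
    if kappa <= 0.20:
        return "slight"
    if kappa <= 0.40:
        return "fair"
    if kappa <= 0.60:
        return "moderate"
    if kappa <= 0.80:
        return "substantial"
    return "almost perfect"

print(interpret_kappa(0.72))  # → substantial
```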

Utilities#

get_device#

Return the best available PyTorch device (CUDA or CPU).

auto_batch_size#

Suggest a batch size based on available GPU memory.

unpack_batch#

Move batch tensors to device and return unpacked components.