Publication Screening on Single Case Design Research¶


In [ ]:
# Record when this notebook run was executed (e.g. "February 28, 2026")
from datetime import datetime

formatted_date = datetime.now().strftime("%B %d, %Y")
print(formatted_date)
February 28, 2026

0. Configuring the System Environment¶

In [ ]:
# Check the Colab-assigned GPU model, driver/CUDA versions, and free GPU memory
!nvidia-smi
Sat Feb 28 01:35:26 2026       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.82.07              Driver Version: 580.82.07      CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA A100-SXM4-40GB          Off |   00000000:00:04.0 Off |                    0 |
| N/A   34C    P0             47W /  400W |       0MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+
In [ ]:
# Check total/available system RAM in human-readable units
!free -h
               total        used        free      shared  buff/cache   available
Mem:            83Gi       897Mi        78Gi       3.0Mi       3.7Gi        81Gi
Swap:             0B          0B          0B

1. Setting Up the Computing Environment¶

Install and load Python libraries.¶
In [ ]:
from google.colab import drive
drive.mount('/content/drive')

from google.colab import userdata
userdata.get('HF_TOKEN')

# Update this path to your project directory in Google Drive
%cd /content/drive/My\ Drive/Colab\ Notebooks/LLM/sped_biblio/screening_single_case
Mounted at /content/drive
/content/drive/My Drive/Colab Notebooks/LLM/sped_biblio/screening_single_case
In [ ]:
# Optional: recreate the pinned conda environment instead of the pip install below
# !pip install -q condacolab
# !conda env create -f python-venv-environment.yml
In [ ]:
!pip install -q pubmlp
In [ ]:
import warnings

# Silence noisy jupyter_client deprecation chatter emitted in Colab
warnings.filterwarnings('ignore', category=DeprecationWarning, module='jupyter_client')

# Third-party libraries
import pandas as pd
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data.sampler import RandomSampler, SequentialSampler
from transformers import AutoTokenizer

# Project library: model, preprocessing, training, and evaluation utilities
from pubmlp import (
    Config, PubMLP,
    split_data, preprocess_dataset, create_dataloader,
    train_evaluate_model, predict_model,
    get_predictions_and_labels, flag_uncertain,
    calculate_evaluation_metrics, plot_results,
    cross_validate, calibrate_model,
    get_device,
)

2. Screening Studies on Single Case Design Research¶

Prepare data for training, validation, and testing.¶
In [ ]:
warnings.filterwarnings('ignore', category=FutureWarning)

labeled_data = pd.read_excel("files/labeled_data.xlsx")

# Keep the record ID (UT), the two labels, and the metadata fields used as features
labeled_df = labeled_data[['UT', 'single_case', 'technology_use', 'SO', 'PY', 'AF', 'TI', 'AB', 'DE']].copy()

# Inspect missingness before cleaning (keywords/DE is expected to be sparse)
missing_values = labeled_df.isnull().sum()
print(missing_values)

# Clean: coerce publication year to numeric, drop rows with no usable year,
# and binarize the Yes/No labels. Chained (no inplace mutation) so the cell
# is idempotent on re-run.
df = (
    labeled_df
    .assign(PY=pd.to_numeric(labeled_df["PY"], errors='coerce'))
    .dropna(subset=["PY"])
    .assign(
        single_case=lambda d: d["single_case"].map({'Yes': 1, 'No': 0}),
        technology_use=lambda d: d["technology_use"].map({'Yes': 1, 'No': 0}),
    )
)

train_df, validation_df, test_df = split_data(df, random_state=42)

print(f"\nTraining: {len(train_df)}, Validation: {len(validation_df)}, Test: {len(test_df)}")
UT                  0
single_case         0
technology_use      0
SO                  0
PY                  0
AF                  0
TI                  0
AB                  0
DE                269
dtype: int64

Training: 1264, Validation: 158, Test: 158
Preprocess data and convert to dataloaders.¶
In [ ]:
warnings.filterwarnings('ignore', category=FutureWarning)

# Hyperparameters for the BERT-embedding + MLP classifier
config = Config(
    random_seed=2025,
    embedding_model='bert',
    batch_size=8,
    eval_batch_size=8,
    epochs=5,
    learning_rate=1e-5,
    mlp_hidden_size=16,
    dropout_rate=0.2,
    early_stopping_patience=3,
    n_hidden_layers=2,
)
config.set_random_seeds()

device = get_device()
print(f"Device: {device}, Batch size: {config.batch_size}")

tokenizer = AutoTokenizer.from_pretrained(config.model_name)

# Which dataframe columns feed each input branch of the model
column_specifications = {
    "text_cols": ['AF', 'TI', 'AB', 'DE'],
    "categorical_cols": ['SO'],
    "numeric_cols": ['PY'],
    "label_col": "single_case",
}

numeric_transform = {'PY': 'min'}
label_col = column_specifications['label_col']

# Fit the transforms (vocabularies, numeric scaling) on the training split only
train_dataset, fitted = preprocess_dataset(
    train_df, tokenizer, device, column_specifications, numeric_transform,
    max_length=config.max_length,
)
train_dataloader = create_dataloader(train_dataset, RandomSampler, config.batch_size)

def _eval_loader(split_df):
    """Preprocess an evaluation split with the transforms fitted on the
    training data (no leakage) and wrap it in a sequential dataloader."""
    dataset, _ = preprocess_dataset(
        split_df, tokenizer, device, column_specifications, numeric_transform,
        max_length=config.max_length, fitted_transforms=fitted,
    )
    return create_dataloader(dataset, SequentialSampler, config.eval_batch_size)

validation_dataloader = _eval_loader(validation_df)
test_dataloader = _eval_loader(test_df)
Device: cuda, Batch size: 8
config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]
tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]
vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]
tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]
Train the model and perform validation.¶
In [ ]:
import logging
# Hide "weights newly initialized" warnings emitted while the BERT backbone loads
logging.getLogger("transformers").setLevel(logging.ERROR)

# Classifier built from the fitted preprocessing: text embeddings plus
# categorical-vocabulary and numeric inputs feeding an MLP head with a
# single-logit output (binary screening decision).
model = PubMLP(
    categorical_vocab_sizes=fitted.categorical_vocab_sizes,
    numeric_cols_num=1,
    mlp_hidden_size=config.mlp_hidden_size,
    output_size=1,
    dropout_rate=config.dropout_rate,
    embedding_model=config.embedding_model,
    model_name=config.model_name,
    n_hidden_layers=config.n_hidden_layers,
).to(device)

logging.getLogger("transformers").setLevel(logging.WARNING)

# Binary classification on raw logits; AdamW fine-tunes the whole stack
criterion = nn.BCEWithLogitsLoss()
optimizer = AdamW(model.parameters(), lr=config.learning_rate, eps=1e-8)

# Train with early stopping on validation loss; returns loss/accuracy
# histories (for plotting) plus the best checkpoint's state dict and epoch.
(train_losses, validation_losses,
 train_accuracies, validation_accuracies,
 test_accuracy, best_val_loss,
 best_model_state, best_epoch) = train_evaluate_model(
    model, train_dataloader, validation_dataloader, test_dataloader,
    optimizer, criterion, device, config.epochs,
    early_stopping_patience=config.early_stopping_patience,
    gradient_clip_norm=config.gradient_clip_norm,
    pos_weight=None,
    use_warmup=False,
)

# Persist the best checkpoint for the inference section below
torch.save(best_model_state, f"best_model_state_{label_col}.pth")
model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]
Loading weights:   0%|          | 0/199 [00:00<?, ?it/s]
Epoch 1/5: 100%|██████████| 158/158 [00:27<00:00,  5.82it/s, loss=0.524]
Epoch: 0001/0005 | Train Loss: 0.524 | Val Loss: 0.383 *** Best ***
Train Acc: 83.465% | Val Acc: 82.278% | 0.62 min
Epoch 2/5: 100%|██████████| 158/158 [00:25<00:00,  6.11it/s, loss=0.373]
Epoch: 0002/0005 | Train Loss: 0.373 | Val Loss: 0.290 *** Best ***
Train Acc: 95.965% | Val Acc: 95.570% | 1.21 min
Epoch 3/5: 100%|██████████| 158/158 [00:25<00:00,  6.11it/s, loss=0.309]
Epoch: 0003/0005 | Train Loss: 0.309 | Val Loss: 0.246 *** Best ***
Train Acc: 94.778% | Val Acc: 95.570% | 1.81 min
Epoch 4/5: 100%|██████████| 158/158 [00:25<00:00,  6.11it/s, loss=0.252]
Epoch: 0004/0005 | Train Loss: 0.252 | Val Loss: 0.210 *** Best ***
Train Acc: 97.310% | Val Acc: 95.570% | 2.41 min
Epoch 5/5: 100%|██████████| 158/158 [00:25<00:00,  6.12it/s, loss=0.223]
Epoch: 0005/0005 | Train Loss: 0.223 | Val Loss: 0.197 *** Best ***
Train Acc: 96.915% | Val Acc: 96.203% | 3.00 min

Loading best model from epoch 5
Test Accuracy: 96.203% | Best epoch 5 (val loss: 0.197)
Plot shows training losses, validation losses, training accuracies, validation accuracies, test accuracy, and the best validation loss over time.¶
In [ ]:
# Visualize learning curves; best_epoch marks where the best checkpoint was saved
plot_results(train_losses, validation_losses, train_accuracies, validation_accuracies, test_accuracy, best_val_loss, best_epoch=best_epoch)
No description has been provided for this image
Evaluation metrics: classification report, confusion matrix, and ROC curve.¶
In [ ]:
# Evaluate the restored best model on the held-out test split
test_predictions, test_probs, test_labels = get_predictions_and_labels(model, test_dataloader, device)

metrics = calculate_evaluation_metrics(
    test_labels, test_predictions, test_probs,
    output_dir='files', label_name=label_col, save_figures=True,
)

# Flag borderline probabilities for manual review. Pass the thresholds from
# config (rather than relying on defaults and a hardcoded "(0.3-0.7)" string)
# so this cell stays consistent with the prediction cell below.
uncertain_flags = flag_uncertain(test_probs, low=config.uncertainty_low, high=config.uncertainty_high)
print(f"\nUncertain predictions ({config.uncertainty_low}-{config.uncertainty_high}): {sum(uncertain_flags)} / {len(uncertain_flags)}")
Evaluating: 100%|██████████| 20/20 [00:01<00:00, 19.36it/s]
EVALUATION METRICS: SINGLE_CASE
              precision    recall  f1-score   support

     Exclude      1.000     0.739     0.850        23
     Include      0.957     1.000     0.978       135

    accuracy                          0.962       158
   macro avg      0.979     0.870     0.914       158
weighted avg      0.964     0.962     0.960       158

Key Metrics:
  accuracy: 0.962
  precision: 0.957
  recall: 1.000
  specificity: 0.739
  f1_score: 0.978
  roc_auc: 0.867

Confusion matrix saved: files/confusion_matrix_single_case.png
ROC curve saved: files/roc_curve_single_case.png

Uncertain predictions (0.3-0.7): 17 / 158

Cross-validation for reliable performance estimates.¶

In [ ]:
import logging
logging.getLogger("transformers").setLevel(logging.ERROR)

# Scope the sklearn warning filter to this cell with catch_warnings().
# The original warnings.resetwarnings() discards ALL filters, including the
# DeprecationWarning/FutureWarning suppressions installed by earlier cells.
with warnings.catch_warnings():
    warnings.filterwarnings('ignore', category=UserWarning, module='sklearn')

    # 5-fold CV over the full labeled set for a more reliable performance estimate
    cv_results = cross_validate(
        df, tokenizer, device, column_specifications, numeric_transform, config,
        numeric_cols_num=1, output_size=1, output_dir='files/cv',
    )

logging.getLogger("transformers").setLevel(logging.WARNING)

print(f"\nMean F1: {cv_results['mean_metrics']['f1_score']:.3f} ± {cv_results['std_metrics']['f1_score']:.3f}")
print(f"Mean ROC AUC: {cv_results['mean_metrics']['roc_auc']:.3f} ± {cv_results['std_metrics']['roc_auc']:.3f}")
============================================================
Fold 1/5
============================================================
Train: 1264 | Val: 316
Loading weights:   0%|          | 0/199 [00:00<?, ?it/s]
Epoch 1/5: 100%|██████████| 158/158 [00:25<00:00,  6.10it/s, loss=0.209]
Epoch: 0001/0005 | Train Loss: 0.209 | Val Loss: 0.184 *** Best ***
Train Acc: 94.858% | Val Acc: 92.089% | 0.63 min
Epoch 2/5: 100%|██████████| 158/158 [00:25<00:00,  6.12it/s, loss=0.171]
Epoch: 0002/0005 | Train Loss: 0.171 | Val Loss: 0.153 *** Best ***
Train Acc: 96.123% | Val Acc: 94.304% | 1.26 min
Epoch 3/5: 100%|██████████| 158/158 [00:25<00:00,  6.12it/s, loss=0.146]
Epoch: 0003/0005 | Train Loss: 0.146 | Val Loss: 0.139 *** Best ***
Train Acc: 96.756% | Val Acc: 94.937% | 1.89 min
Epoch 4/5: 100%|██████████| 158/158 [00:25<00:00,  6.12it/s, loss=0.131]
Epoch: 0004/0005 | Train Loss: 0.131 | Val Loss: 0.130 *** Best ***
Train Acc: 96.915% | Val Acc: 94.937% | 2.52 min
Epoch 5/5: 100%|██████████| 158/158 [00:25<00:00,  6.12it/s, loss=0.129]
Epoch: 0005/0005 | Train Loss: 0.129 | Val Loss: 0.129 *** Best ***
Train Acc: 96.994% | Val Acc: 94.937% | 3.14 min

Loading best model from epoch 5
Best epoch 5 (val loss: 0.129)
Evaluating: 100%|██████████| 40/40 [00:01<00:00, 20.01it/s]
EVALUATION METRICS: FOLD_1
              precision    recall  f1-score   support

     Exclude      0.891     0.788     0.837        52
     Include      0.959     0.981     0.970       264

    accuracy                          0.949       316
   macro avg      0.925     0.885     0.903       316
weighted avg      0.948     0.949     0.948       316

Key Metrics:
  accuracy: 0.949
  precision: 0.959
  recall: 0.981
  specificity: 0.788
  f1_score: 0.970
  roc_auc: 0.922

Confusion matrix saved: files/cv/fold_1/confusion_matrix_fold_1.png
ROC curve saved: files/cv/fold_1/roc_curve_fold_1.png
Fold 1 — F1: 0.970 | Precision: 0.959 | Recall: 0.981

============================================================
Fold 2/5
============================================================
Train: 1264 | Val: 316
Loading weights:   0%|          | 0/199 [00:00<?, ?it/s]
Epoch 1/5: 100%|██████████| 158/158 [00:25<00:00,  6.10it/s, loss=0.221]
Epoch: 0001/0005 | Train Loss: 0.221 | Val Loss: 0.206 *** Best ***
Train Acc: 83.544% | Val Acc: 83.544% | 0.63 min
Epoch 2/5: 100%|██████████| 158/158 [00:25<00:00,  6.11it/s, loss=0.199]
Epoch: 0002/0005 | Train Loss: 0.199 | Val Loss: 0.189 *** Best ***
Train Acc: 84.494% | Val Acc: 85.443% | 1.26 min
Epoch 3/5: 100%|██████████| 158/158 [00:25<00:00,  6.12it/s, loss=0.18]
Epoch: 0003/0005 | Train Loss: 0.180 | Val Loss: 0.170 *** Best ***
Train Acc: 96.440% | Val Acc: 95.253% | 1.89 min
Epoch 4/5: 100%|██████████| 158/158 [00:25<00:00,  6.11it/s, loss=0.161]
Epoch: 0004/0005 | Train Loss: 0.161 | Val Loss: 0.159 *** Best ***
Train Acc: 96.756% | Val Acc: 95.253% | 2.52 min
Epoch 5/5: 100%|██████████| 158/158 [00:25<00:00,  6.12it/s, loss=0.151]
Epoch: 0005/0005 | Train Loss: 0.151 | Val Loss: 0.160
Train Acc: 96.915% | Val Acc: 95.570% | 3.15 min

Loading best model from epoch 4
Best epoch 4 (val loss: 0.159)
Evaluating: 100%|██████████| 40/40 [00:02<00:00, 19.85it/s]
EVALUATION METRICS: FOLD_2
              precision    recall  f1-score   support

     Exclude      0.894     0.808     0.848        52
     Include      0.963     0.981     0.972       264

    accuracy                          0.953       316
   macro avg      0.928     0.894     0.910       316
weighted avg      0.951     0.953     0.952       316

Key Metrics:
  accuracy: 0.953
  precision: 0.963
  recall: 0.981
  specificity: 0.808
  f1_score: 0.972
  roc_auc: 0.951

Confusion matrix saved: files/cv/fold_2/confusion_matrix_fold_2.png
ROC curve saved: files/cv/fold_2/roc_curve_fold_2.png
Fold 2 — F1: 0.972 | Precision: 0.963 | Recall: 0.981

============================================================
Fold 3/5
============================================================
Train: 1264 | Val: 316
Loading weights:   0%|          | 0/199 [00:00<?, ?it/s]
Epoch 1/5: 100%|██████████| 158/158 [00:25<00:00,  6.11it/s, loss=0.207]
Epoch: 0001/0005 | Train Loss: 0.207 | Val Loss: 0.188 *** Best ***
Train Acc: 94.620% | Val Acc: 94.937% | 0.63 min
Epoch 2/5: 100%|██████████| 158/158 [00:25<00:00,  6.11it/s, loss=0.183]
Epoch: 0002/0005 | Train Loss: 0.183 | Val Loss: 0.184 *** Best ***
Train Acc: 95.016% | Val Acc: 94.937% | 1.26 min
Epoch 3/5: 100%|██████████| 158/158 [00:25<00:00,  6.12it/s, loss=0.171]
Epoch: 0003/0005 | Train Loss: 0.171 | Val Loss: 0.174 *** Best ***
Train Acc: 96.994% | Val Acc: 95.886% | 1.89 min
Epoch 4/5: 100%|██████████| 158/158 [00:25<00:00,  6.12it/s, loss=0.161]
Epoch: 0004/0005 | Train Loss: 0.161 | Val Loss: 0.173 *** Best ***
Train Acc: 97.547% | Val Acc: 95.886% | 2.51 min
Epoch 5/5: 100%|██████████| 158/158 [00:25<00:00,  6.12it/s, loss=0.156]
Epoch: 0005/0005 | Train Loss: 0.156 | Val Loss: 0.176
Train Acc: 97.547% | Val Acc: 95.886% | 3.14 min

Loading best model from epoch 4
Best epoch 4 (val loss: 0.173)
Evaluating: 100%|██████████| 40/40 [00:01<00:00, 20.04it/s]
EVALUATION METRICS: FOLD_3
              precision    recall  f1-score   support

     Exclude      1.000     0.750     0.857        52
     Include      0.953     1.000     0.976       264

    accuracy                          0.959       316
   macro avg      0.977     0.875     0.917       316
weighted avg      0.961     0.959     0.956       316

Key Metrics:
  accuracy: 0.959
  precision: 0.953
  recall: 1.000
  specificity: 0.750
  f1_score: 0.976
  roc_auc: 0.952

Confusion matrix saved: files/cv/fold_3/confusion_matrix_fold_3.png
ROC curve saved: files/cv/fold_3/roc_curve_fold_3.png
Fold 3 — F1: 0.976 | Precision: 0.953 | Recall: 1.000

============================================================
Fold 4/5
============================================================
Train: 1264 | Val: 316
Loading weights:   0%|          | 0/199 [00:00<?, ?it/s]
Epoch 1/5: 100%|██████████| 158/158 [00:25<00:00,  6.11it/s, loss=0.216]
Epoch: 0001/0005 | Train Loss: 0.216 | Val Loss: 0.188 *** Best ***
Train Acc: 94.146% | Val Acc: 95.253% | 0.63 min
Epoch 2/5: 100%|██████████| 158/158 [00:25<00:00,  6.12it/s, loss=0.179]
Epoch: 0002/0005 | Train Loss: 0.179 | Val Loss: 0.162 *** Best ***
Train Acc: 95.649% | Val Acc: 95.253% | 1.26 min
Epoch 3/5: 100%|██████████| 158/158 [00:25<00:00,  6.11it/s, loss=0.158]
Epoch: 0003/0005 | Train Loss: 0.158 | Val Loss: 0.148 *** Best ***
Train Acc: 96.677% | Val Acc: 94.937% | 1.89 min
Epoch 4/5: 100%|██████████| 158/158 [00:25<00:00,  6.11it/s, loss=0.14]
Epoch: 0004/0005 | Train Loss: 0.140 | Val Loss: 0.142 *** Best ***
Train Acc: 96.677% | Val Acc: 95.253% | 2.51 min
Epoch 5/5: 100%|██████████| 158/158 [00:25<00:00,  6.12it/s, loss=0.132]
Epoch: 0005/0005 | Train Loss: 0.132 | Val Loss: 0.134 *** Best ***
Train Acc: 97.310% | Val Acc: 95.570% | 3.14 min

Loading best model from epoch 5
Best epoch 5 (val loss: 0.134)
Evaluating: 100%|██████████| 40/40 [00:01<00:00, 20.03it/s]
EVALUATION METRICS: FOLD_4
              precision    recall  f1-score   support

     Exclude      0.896     0.827     0.860        52
     Include      0.966     0.981     0.974       264

    accuracy                          0.956       316
   macro avg      0.931     0.904     0.917       316
weighted avg      0.955     0.956     0.955       316

Key Metrics:
  accuracy: 0.956
  precision: 0.966
  recall: 0.981
  specificity: 0.827
  f1_score: 0.974
  roc_auc: 0.944

Confusion matrix saved: files/cv/fold_4/confusion_matrix_fold_4.png
ROC curve saved: files/cv/fold_4/roc_curve_fold_4.png
Fold 4 — F1: 0.974 | Precision: 0.966 | Recall: 0.981

============================================================
Fold 5/5
============================================================
Train: 1264 | Val: 316
Loading weights:   0%|          | 0/199 [00:00<?, ?it/s]
Epoch 1/5: 100%|██████████| 158/158 [00:25<00:00,  6.11it/s, loss=0.21]
Epoch: 0001/0005 | Train Loss: 0.210 | Val Loss: 0.178 *** Best ***
Train Acc: 90.348% | Val Acc: 90.506% | 0.63 min
Epoch 2/5: 100%|██████████| 158/158 [00:25<00:00,  6.11it/s, loss=0.175]
Epoch: 0002/0005 | Train Loss: 0.175 | Val Loss: 0.155 *** Best ***
Train Acc: 93.987% | Val Acc: 93.354% | 1.26 min
Epoch 3/5: 100%|██████████| 158/158 [00:25<00:00,  6.12it/s, loss=0.159]
Epoch: 0003/0005 | Train Loss: 0.159 | Val Loss: 0.152 *** Best ***
Train Acc: 96.123% | Val Acc: 94.937% | 1.89 min
Epoch 4/5: 100%|██████████| 158/158 [00:25<00:00,  6.12it/s, loss=0.151]
Epoch: 0004/0005 | Train Loss: 0.151 | Val Loss: 0.141 *** Best ***
Train Acc: 96.598% | Val Acc: 94.304% | 2.52 min
Epoch 5/5: 100%|██████████| 158/158 [00:25<00:00,  6.12it/s, loss=0.144]
Epoch: 0005/0005 | Train Loss: 0.144 | Val Loss: 0.139 *** Best ***
Train Acc: 96.756% | Val Acc: 94.304% | 3.15 min

Loading best model from epoch 5
Best epoch 5 (val loss: 0.139)
Evaluating: 100%|██████████| 40/40 [00:02<00:00, 19.94it/s]
EVALUATION METRICS: FOLD_5
              precision    recall  f1-score   support

     Exclude      0.827     0.827     0.827        52
     Include      0.966     0.966     0.966       264

    accuracy                          0.943       316
   macro avg      0.896     0.896     0.896       316
weighted avg      0.943     0.943     0.943       316

Key Metrics:
  accuracy: 0.943
  precision: 0.966
  recall: 0.966
  specificity: 0.827
  f1_score: 0.966
  roc_auc: 0.950

Confusion matrix saved: files/cv/fold_5/confusion_matrix_fold_5.png
ROC curve saved: files/cv/fold_5/roc_curve_fold_5.png
Fold 5 — F1: 0.966 | Precision: 0.966 | Recall: 0.966

============================================================
Cross-Validation Summary (5 folds)
============================================================
accuracy: 0.952 ± 0.005
precision: 0.961 ± 0.005
recall: 0.982 ± 0.011
specificity: 0.800 ± 0.029
f1_score: 0.971 ± 0.003
roc_auc: 0.944 ± 0.011
Best fold: 3 (val acc: 95.886%)

Mean F1: 0.971 ± 0.003
Mean ROC AUC: 0.944 ± 0.011

Calibrate model probabilities via temperature scaling.¶
In [ ]:
# Learn a temperature parameter on the validation split (temperature scaling)
# so the probabilities used for screening below are better calibrated.
calibration = calibrate_model(model, validation_dataloader, device)
print(f"Temperature: {calibration.temperature:.3f}")
Temperature: 0.896
Filter rows where "single_case" is NA (not labeled).¶
In [ ]:
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning, module='huggingface_hub')

labeled_data = pd.read_excel("files/labeled_data.xlsx")
init_all_data = pd.read_csv("files/init_all_data.csv")

# Drop labeled records whose UT identifier never appears in the full corpus
missing_uts = set(labeled_data['UT']) - set(init_all_data['UT'])
if missing_uts:
    print(f"Removing {len(missing_uts)} rows from labeled_data; these UTs are not in init_all_data:\n{missing_uts}")
    labeled_data = labeled_data[labeled_data['UT'].isin(init_all_data['UT'])]

# Left-join the labels onto the corpus, keeping the corpus version of any
# field both frames share ("_y" duplicates are discarded, "_x" renamed back).
init_all_data = pd.merge(init_all_data, labeled_data, on="UT", how="left", suffixes=("", "_y"))
keep_cols = [col for col in init_all_data.columns if not col.endswith("_y")]
init_all_data = init_all_data[keep_cols]
init_all_data.columns = init_all_data.columns.str.replace("_x$", "", regex=True)

# Rows with no single_case label are the ones the model must screen
unlabeled_data = init_all_data[pd.isna(init_all_data["single_case"])]
unlabeled_data.to_excel("files/unlabeled_data.xlsx", index=False)
print(f"Unlabeled data: {len(unlabeled_data)} records saved to files/unlabeled_data.xlsx")

# Round-trip through the saved artifact so downstream cells read the same file
unlabeled_data = pd.read_excel("files/unlabeled_data.xlsx")
Unlabeled data: 4349 records saved to files/unlabeled_data.xlsx

In [ ]:
# Restrict to the model's feature columns and coerce the year to numeric
feature_cols = ['UT', 'single_case', 'technology_use', 'SO', 'PY', 'AF', 'TI', 'AB', 'DE']
unlabeled_df = unlabeled_data[feature_cols].copy()
unlabeled_df["PY"] = pd.to_numeric(unlabeled_df["PY"], errors="coerce")

# single_case / technology_use are entirely NA here by construction
missing_values = unlabeled_df.isnull().sum()
print(missing_values)
UT                   0
single_case       4349
technology_use    4349
SO                   0
PY                   0
AF                   0
TI                   0
AB                 129
DE                 912
dtype: int64

Load the BERT tokenizer.¶
In [ ]:
# Re-instantiate the tokenizer (lets the inference section run on its own)
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
Preprocess and create dataloaders.¶
In [ ]:
# Redefine the column specification so the inference section can run
# independently; it must match the training-time specification exactly,
# otherwise the fitted transforms below will not line up.
column_specifications = {
    "text_cols": ['AF', 'TI', 'AB', 'DE'],
    "categorical_cols": ['SO'],
    "numeric_cols": ['PY'],
    "label_col": "single_case",
}

numeric_transform = {'PY': 'min'}
label_col = column_specifications['label_col']

# Apply fitted transforms from training (no data leakage)
unlabeled_dataset, _ = preprocess_dataset(
    unlabeled_df, tokenizer, device, column_specifications, numeric_transform,
    max_length=config.max_length, fitted_transforms=fitted,
)
unlabeled_dataloader = create_dataloader(unlabeled_dataset, SequentialSampler, config.eval_batch_size)
Load and initialize the PubMLP model.¶
In [ ]:
import logging
# Hide "weights newly initialized" warnings while the BERT backbone loads
logging.getLogger("transformers").setLevel(logging.ERROR)

# Rebuild the architecture with the same hyperparameters used for training;
# the saved state dict only loads cleanly into an identical architecture.
model = PubMLP(
    categorical_vocab_sizes=fitted.categorical_vocab_sizes,
    numeric_cols_num=1,
    mlp_hidden_size=config.mlp_hidden_size,
    output_size=1,
    dropout_rate=config.dropout_rate,
    embedding_model=config.embedding_model,
    model_name=config.model_name,
    n_hidden_layers=config.n_hidden_layers,
).to(device)

logging.getLogger("transformers").setLevel(logging.WARNING)

# Restore the best checkpoint saved by the training cell.
# NOTE(review): presumably predict_model switches the model to eval mode
# internally (dropout off); confirm in pubmlp, otherwise call model.eval().
model.load_state_dict(torch.load(f"best_model_state_{label_col}.pth", map_location=device))
Loading weights:   0%|          | 0/199 [00:00<?, ?it/s]
Out[ ]:
<All keys matched successfully>
Predict each publication.¶
In [ ]:
# Screen every unlabeled record with the calibrated model
predictions, probabilities = predict_model(model, unlabeled_dataloader, device, calibration=calibration)

# Flag borderline probabilities for manual review
uncertain = flag_uncertain(probabilities, low=config.uncertainty_low, high=config.uncertainty_high)
print(f"Uncertain predictions ({config.uncertainty_low}-{config.uncertainty_high}): {sum(uncertain)} / {len(uncertain)}")

predictions_data = pd.DataFrame({
    label_col: ['Yes' if p == 1 else 'No' for p in predictions],
    'probability': probabilities,
    'uncertain': uncertain,
})

# Fail loudly on a row-count mismatch: the original else-branch only printed
# a message, leaving `predicted_data` undefined and crashing with a NameError
# at the save step below.
if len(predictions_data) != len(unlabeled_df):
    raise ValueError(
        f"Prediction count ({len(predictions_data)}) does not match "
        f"unlabeled row count ({len(unlabeled_df)})."
    )

# Attach predictions, calibrated probabilities, and uncertainty flags
unlabeled_df[label_col] = predictions_data[label_col].values
unlabeled_df['probability'] = predictions_data['probability'].values
unlabeled_df['uncertain'] = predictions_data['uncertain'].values
predicted_data = unlabeled_df

predicted_file = f"files/predicted_data_{label_col}.xlsx"
predicted_data.to_excel(predicted_file, index=False)
print(f"Data saved to {predicted_file}")

screened_pubs = (predicted_data[label_col] == 'Yes').sum()
print(f"'Yes' predictions for {label_col}: {screened_pubs}")
Predicting: 100%|██████████| 544/544 [00:27<00:00, 19.78it/s]
Uncertain predictions (0.3-0.7): 1859 / 4349
Data saved to files/predicted_data_single_case.xlsx
'Yes' predictions for single_case: 2506
In [ ]:
# Attach the remaining corpus metadata to the predictions; the predictions
# frame wins when both frames share a field name ("_y" duplicates dropped).
all_data = pd.merge(predicted_data, init_all_data, on="UT", how="left", suffixes=("", "_y"))
all_data = all_data.drop(columns=[col for col in all_data.columns if col.endswith("_y")])
all_data.columns = all_data.columns.str.replace("_x$", "", regex=True)

all_data_file = f"files/all_data_{label_col}.xlsx"
all_data.to_excel(all_data_file, index=False)
print(f"Data saved to {all_data_file}")
Data saved to files/all_data_single_case.xlsx

In [ ]:
from nbconvert import HTMLExporter
import nbformat

notebook_path = 'index.ipynb'
html_exporter = HTMLExporter()

# Parse the saved notebook as a version-4 document
with open(notebook_path, 'r', encoding='utf-8') as nb_file:
    notebook = nbformat.read(nb_file, as_version=4)

# Colab can save widget metadata without a "state" entry, which makes
# HTMLExporter raise; patch in an empty state so the export succeeds.
widget_meta = notebook.metadata.get('widgets', {})
widget_mime = 'application/vnd.jupyter.widget-state+json'
if widget_mime in widget_meta and 'state' not in widget_meta[widget_mime]:
    widget_meta[widget_mime]['state'] = {}

html_output, _ = html_exporter.from_notebook_node(notebook)

# Write the rendered document next to the notebook
with open('index.html', 'w', encoding='utf-8') as html_file:
    html_file.write(html_output)