import logging
import pickle
from pathlib import Path
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import StratifiedKFold
from glmnet import LogitNet
from immuneML.data_model.EncodedData import EncodedData
from immuneML.data_model.bnp_util import write_yaml
from immuneML.environment.Label import Label
from immuneML.ml_methods.classifiers.MLMethod import MLMethod
from immuneML.ml_methods.util.Util import Util
from immuneML.util.PathBuilder import PathBuilder
# Upper bound on the number of LBFGS iterations per regularisation-path step in the torch backend;
# LogRegressionCustomPenalty._fit_torch takes the minimum of this value and the user-supplied max_iter.
MAX_TORCH_ITERS = 200
class _TorchLogReg(BaseEstimator, ClassifierMixin):
    """sklearn-compatible logistic regression via PyTorch LBFGS with a per-feature penalty mask.

    The fitted objective is the mean binary cross-entropy of a single linear layer plus
    ``(1 / C) * (alpha * L1 + (1 - alpha) * 0.5 * L2)``. Both norms are computed on the weights
    only (the bias is never penalised) and skip every index in ``non_penalized_indices``.

    Args:
        C: inverse regularisation strength; values below 1e-8 are clamped to 1e-8.
        alpha: elastic-net mixing parameter (1 = lasso, 0 = ridge).
        non_penalized_indices: column indices whose weights are excluded from the penalty
            (any iterable of ints, including a numpy array; None or empty means all are penalised).
        max_iter: maximum number of LBFGS iterations performed by ``fit``.
        device: name of the torch device used for fitting and prediction.
    """

    def __init__(self, C=1.0, alpha=1.0, non_penalized_indices=None, max_iter=200, device='cpu'):
        # sklearn convention: the constructor only stores hyper-parameters, untouched.
        self.C = C
        self.alpha = alpha
        self.non_penalized_indices = non_penalized_indices
        self.max_iter = max_iter
        self.device = device

    @staticmethod
    def _to_dense(X):
        """Return X as a dense numpy array (objects exposing ``toarray`` are densified with it)."""
        return X.toarray() if hasattr(X, 'toarray') else np.asarray(X)

    def fit(self, X, y):
        """Fit the linear layer on (X, y) with one LBFGS run and return ``self``.

        ``y`` is passed to the binary cross-entropy loss as-is (after a float cast), so it is
        expected to hold 0/1 targets. Sets ``linear_`` and ``classes_``.
        """
        import torch
        import torch.nn as nn
        import torch.optim as optim

        X_arr = self._to_dense(X)
        y_arr = np.asarray(y)  # accept lists as well as arrays
        _, p = X_arr.shape
        dev = torch.device(self.device)
        X_t = torch.from_numpy(X_arr).float().to(dev)
        y_t = torch.from_numpy(y_arr.astype(float)).float().unsqueeze(1).to(dev)

        # 1.0 = penalised, 0.0 = free. Materialise the indices first: truth-testing a
        # multi-element numpy array directly would raise ValueError.
        mask = torch.ones(p, device=dev)
        free_indices = [] if self.non_penalized_indices is None else [int(i) for i in self.non_penalized_indices]
        if free_indices:
            mask[free_indices] = 0.0

        self.linear_ = nn.Linear(p, 1).to(dev)
        criterion = nn.BCEWithLogitsLoss()
        optimizer = optim.LBFGS(self.linear_.parameters(), lr=1.0,
                                max_iter=self.max_iter, line_search_fn='strong_wolfe')
        reg_scale = 1.0 / max(self.C, 1e-8)
        alpha = float(self.alpha)

        def closure():
            # LBFGS re-evaluates the objective several times per step, hence the closure.
            optimizer.zero_grad()
            loss = criterion(self.linear_(X_t), y_t)
            w = self.linear_.weight.squeeze()
            penalty = alpha * torch.sum(mask * torch.abs(w)) + \
                (1.0 - alpha) * 0.5 * torch.sum(mask * w ** 2)
            total = loss + reg_scale * penalty
            total.backward()
            return total

        optimizer.step(closure)
        self.classes_ = np.unique(y_arr)
        return self

    def predict_proba(self, X):
        """Return an (n_samples, 2) array with columns ``[1 - p, p]``, where p = sigmoid(linear_(X))."""
        import torch

        X_arr = self._to_dense(X)
        dev = torch.device(self.device)
        X_t = torch.from_numpy(X_arr.astype(np.float32)).to(dev)
        with torch.no_grad():
            probs = torch.sigmoid(self.linear_(X_t)).cpu().numpy().ravel()
        return np.column_stack([1.0 - probs, probs])

    def predict(self, X):
        """Return the entry of ``classes_`` selected by thresholding the second probability column at 0.5."""
        return self.classes_[(self.predict_proba(X)[:, 1] >= 0.5).astype(int)]
def _compute_lambda_sequence(X, y, alpha, n_lambda, min_lambda_ratio=None):
"""Log-spaced lambda sequence matching glmnet's automatic path for logistic regression.
At beta=0 the gradient of the log-loss is X^T(y - 0.5) / n, so lambda_max is the
smallest lambda that shrinks all penalised coefficients to zero.
"""
X_arr = X.toarray() if hasattr(X, 'toarray') else np.asarray(X)
n, p = X_arr.shape
grad = np.abs(X_arr.T @ (y - y.mean())) / n
lambda_max = grad.max() / max(float(alpha), 1e-3)
if min_lambda_ratio is None:
min_lambda_ratio = 1e-4 if n >= p else 1e-2
lambda_min = lambda_max * min_lambda_ratio
return np.exp(np.linspace(np.log(lambda_max), np.log(lambda_min), n_lambda))
def _lbfgs_step(linear, X_t, y_t, mask, lam, alpha, max_iter):
"""One warm-started LBFGS step along the regularisation path for a single lambda value."""
import torch.nn as nn
import torch.optim as optim
import torch
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.LBFGS(linear.parameters(), lr=1.0,
max_iter=max_iter, line_search_fn='strong_wolfe')
_lam, _alpha = float(lam), float(alpha)
def closure():
optimizer.zero_grad()
loss = criterion(linear(X_t), y_t)
w = linear.weight.squeeze()
penalty = _alpha * torch.sum(mask * torch.abs(w)) + \
(1.0 - _alpha) * 0.5 * torch.sum(mask * w ** 2)
total = loss + _lam * penalty
total.backward()
return total
optimizer.step(closure)
[docs]
class LogRegressionCustomPenalty(MLMethod):
    """
    Logistic Regression with custom penalty factors for specific features.

    **Specification arguments**:

    - non_penalized_features (list): List of feature names that should not be penalized.

    - non_penalized_encodings (list): List of encoding names (class names) whose features should not be penalized. This
      parameter can be used only in combination with CompositeEncoder. None of the features from the specified encodings
      will be penalized. If both non_penalized_features and non_penalized_encodings are provided, the union of the two
      will be used.

    - backend (str): 'glmnet' (default) or 'torch'. Both backends use the same regularisation parameters below.
      The torch backend uses PyTorch LBFGS with a warm-started regularisation path and stratified K-fold CV
      to select lambda, replicating glmnet's path-fitting strategy.

    - alpha (float): Elastic net mixing parameter. 0 = ridge, 1 = lasso. Default 1.

    - n_lambda (int): Number of lambda values in the regularisation path. Default 100.

    - min_lambda_ratio (float): Ratio of the smallest to largest lambda. Default None (auto: 1e-4 if n>=p, else 1e-2).
      Note: glmnet names this parameter min_lambda_ratio as well.

    - n_splits (int): Cross-validation folds for lambda selection. Default 3.

    - scoring (str): Scoring metric for lambda selection with the glmnet backend. Default None (glmnet's own default).
      The torch backend always selects lambda by ROC AUC and does not read this parameter.

    - random_state (int): Random seed for CV fold splitting.

    - max_iter (int): Maximum solver iterations. For glmnet this is coordinate descent passes (default 100000);
      for torch this is LBFGS iterations per warm-started path step (capped at 200 internally — warm starting
      means each step already starts near the solution and needs very few iterations).

    - device (str): torch device name used by the torch backend (e.g. 'cpu'); 'cpu' is used when not set.

    Additional keyword arguments are forwarded to LogitNet when using the glmnet backend (e.g. standardize,
    fit_intercept, cut_point, lambda_path). They are ignored by the torch backend.

    **YAML specification:**

    .. code-block:: yaml

        ml_methods:
            custom_log_reg:
                LogRegressionCustomPenalty:
                    backend: torch  # 'glmnet' (default) or 'torch'
                    alpha: 1  # 1 for lasso, 0 for ridge
                    n_lambda: 100
                    n_splits: 3
                    non_penalized_features: []
                    non_penalized_encodings: ['Metadata']
                    random_state: 42

    """

    def __init__(self, non_penalized_features: list = None, name: str = None, label: Label = None,
                 non_penalized_encodings: list = None, backend: str = None,
                 alpha: float = None, n_lambda: int = None, min_lambda_ratio: float = None,
                 n_splits: int = None, scoring: str = None, random_state: int = None,
                 max_iter: int = None, device: str = None, **kwargs):
        """Store the hyper-parameters and normalise encoder names.

        Raises:
            ImportError: if ``backend == 'torch'`` and PyTorch cannot be imported.
        """
        super().__init__(name=name, label=label)
        self.non_penalized_features = non_penalized_features if non_penalized_features is not None else []
        # Build a new list so the caller's list is not modified; names lacking the
        # 'Encoder' substring get it appended (e.g. 'Metadata' -> 'MetadataEncoder').
        self.non_penalized_encodings = [
            encoding if 'Encoder' in encoding else encoding + 'Encoder'
            for encoding in (non_penalized_encodings if non_penalized_encodings is not None else [])
        ]
        self.backend = backend
        self.alpha = alpha
        self.n_lambda = n_lambda
        self.min_lambda_ratio = min_lambda_ratio
        self.n_splits = n_splits
        self.scoring = scoring
        self.random_state = random_state
        self.max_iter = max_iter
        self.device = device
        self.kwargs = kwargs  # forwarded to LogitNet for glmnet backend

        if backend == 'torch':
            # Fail at construction time rather than in the middle of training.
            try:
                import torch  # noqa: F401
            except ImportError as e:
                raise ImportError("LogRegressionCustomPenalty: PyTorch is required for the 'torch' backend. "
                                  "Please install it to use this option.") from e

        self.model = None
        self.feature_names = None

    def _resolve_non_penalized_features(self, encoded_data: EncodedData):
        """Extend ``non_penalized_features`` with all features produced by the non-penalized encoders.

        Only applies when the data comes from a CompositeEncoder and at least one encoder name was
        given; the result is the union of the explicit feature names and the features whose
        ``encoder`` annotation matches.
        """
        if encoded_data.encoding == 'CompositeEncoder' and self.non_penalized_encodings:
            annotations = encoded_data.feature_annotations
            from_encodings = annotations[annotations['encoder'].isin(self.non_penalized_encodings)]['feature'].tolist()
            self.non_penalized_features = list(set(from_encodings) | set(self.non_penalized_features))
            logging.info(f"{self.__class__.__name__}: inferred non-penalized features: {self.non_penalized_features}")

    def _fit(self, encoded_data: EncodedData, cores_for_training: int = 2):
        """Fit the model with the configured backend.

        Raises:
            ValueError: if ``backend`` is neither 'glmnet' nor 'torch'.
        """
        X = encoded_data.examples
        y = Util.map_to_new_class_values(encoded_data.labels[self.label.name], self.class_mapping)
        self.feature_names = encoded_data.feature_names
        self._resolve_non_penalized_features(encoded_data)

        # Set lookup keeps this linear in the number of features.
        free_features = set(self.non_penalized_features)
        non_penalized_indices = [i for i, f in enumerate(self.feature_names) if f in free_features]

        # 1.0 = fully penalised, 0.0 = not penalised (passed to glmnet as relative penalties).
        penalty_factor = np.ones(X.shape[1])
        if non_penalized_indices:
            penalty_factor[non_penalized_indices] = 0.0

        if self.backend == 'glmnet':
            glmnet_params = dict(
                alpha=self.alpha, n_lambda=self.n_lambda, n_splits=self.n_splits,
                random_state=self.random_state, max_iter=self.max_iter,
                n_jobs=cores_for_training, **self.kwargs
            )
            # Optional settings are only passed when set so that LogitNet's own defaults apply otherwise.
            if self.min_lambda_ratio is not None:
                glmnet_params['min_lambda_ratio'] = self.min_lambda_ratio
            if self.scoring is not None:
                glmnet_params['scoring'] = self.scoring
            self.model = LogitNet(**glmnet_params)
            self.model.fit(X, y, relative_penalties=penalty_factor)
        elif self.backend == 'torch':
            self._fit_torch(X, y, non_penalized_indices, cores_for_training)
        else:
            raise ValueError(f"Unknown backend '{self.backend}'. Use 'glmnet' or 'torch'.")

    def _fit_torch(self, X, y, non_penalized_indices, n_jobs):
        """Select lambda by stratified K-fold CV over a warm-started path, then refit on all data.

        For every fold a single linear layer is carried along the decreasing lambda sequence and
        scored by ROC AUC on the held-out part (0.5 is used when AUC is undefined, e.g. a
        single-class validation fold). The lambda with the best summed score is selected, and the
        final model is obtained by following the path on the full data up to that lambda.
        ``n_jobs`` is accepted for symmetry with the glmnet backend and is not used.
        """
        # Imported here so the module can be loaded (and the glmnet backend used) without PyTorch.
        import torch
        import torch.nn as nn

        # With warm starting each step starts near the solution, so few LBFGS iterations suffice.
        # Cap at MAX_TORCH_ITERS; the glmnet default of 100000 is for coordinate descent, not quasi-Newton.
        lbfgs_max_iter = MAX_TORCH_ITERS if self.max_iter is None else min(self.max_iter, MAX_TORCH_ITERS)
        device_name = 'cpu' if self.device is None else self.device
        alpha = float(self.alpha)
        dev = torch.device(device_name)

        y = np.asarray(y)
        X_dense = (X.toarray() if hasattr(X, 'toarray') else np.asarray(X)).astype(np.float32)
        lambdas = _compute_lambda_sequence(X_dense, y, alpha, self.n_lambda, self.min_lambda_ratio)

        n, p = X_dense.shape
        mask = torch.ones(p, device=dev)
        if non_penalized_indices:
            mask[list(non_penalized_indices)] = 0.0

        # K-fold CV with a warm-started path per fold to score each lambda
        cv_scores = np.zeros(self.n_lambda)
        skf = StratifiedKFold(n_splits=self.n_splits, shuffle=True, random_state=self.random_state)
        for train_idx, val_idx in skf.split(X_dense, y):
            X_tr = torch.from_numpy(X_dense[train_idx]).to(dev)
            y_tr = torch.from_numpy(y[train_idx].astype(np.float32)).unsqueeze(1).to(dev)
            X_val = torch.from_numpy(X_dense[val_idx]).to(dev)
            y_val = y[val_idx]

            linear = nn.Linear(p, 1).to(dev)
            for i, lam in enumerate(lambdas):
                _lbfgs_step(linear, X_tr, y_tr, mask, lam, alpha, lbfgs_max_iter)
                with torch.no_grad():
                    probs = torch.sigmoid(linear(X_val)).cpu().numpy().ravel()
                try:
                    cv_scores[i] += roc_auc_score(y_val, probs)
                except ValueError:
                    # AUC is undefined for this fold; count it as chance level.
                    cv_scores[i] += 0.5

        best_idx = int(np.argmax(cv_scores / self.n_splits))

        # Refit on full data along the path up to the selected lambda
        X_t = torch.from_numpy(X_dense).to(dev)
        y_t = torch.from_numpy(y.astype(np.float32)).unsqueeze(1).to(dev)
        linear_final = nn.Linear(p, 1).to(dev)
        for lam in lambdas[:best_idx + 1]:
            _lbfgs_step(linear_final, X_t, y_t, mask, lam, alpha, lbfgs_max_iter)

        # Wrap the fitted layer in the sklearn-style estimator; C records the selected 1 / lambda.
        self.model = _TorchLogReg(C=float(1.0 / lambdas[best_idx]), alpha=alpha,
                                  non_penalized_indices=non_penalized_indices,
                                  max_iter=lbfgs_max_iter, device=device_name)
        self.model.linear_ = linear_final
        self.model.classes_ = np.unique(y)

    def _predict(self, encoded_data: EncodedData):
        """Return ``{label name: predicted classes}`` mapped back to the original class values."""
        predictions = self.model.predict(encoded_data.examples)
        return {self.label.name: Util.map_to_old_class_values(np.array(predictions), self.class_mapping)}

    def _predict_proba(self, encoded_data: EncodedData):
        """Return ``{label name: {class name: probability column}}`` for the encoded examples."""
        class_names = Util.map_to_old_class_values(self.model.classes_, self.class_mapping)
        probabilities = self.model.predict_proba(encoded_data.examples)
        return {self.label.name: {class_name: probabilities[:, i] for i, class_name in enumerate(class_names)}}

    def store(self, path: Path):
        """Write all instance attributes to ``model.yaml`` and pickle the fitted model to ``model.pkl``."""
        PathBuilder.build(path)
        # NOTE(review): vars(self) includes the fitted model object — confirm write_yaml handles it.
        write_yaml(path / 'model.yaml', vars(self))
        with open(path / 'model.pkl', 'wb') as f:
            pickle.dump({
                'model': self.model,
                'non_penalized_features': self.non_penalized_features,
                'feature_names': self.feature_names
            }, f)

    def load(self, path: Path):
        """Restore the model, non-penalized features and feature names from ``model.pkl``."""
        # SECURITY: pickle.load executes arbitrary code; only load model files from trusted sources.
        with open(path / 'model.pkl', 'rb') as f:
            stored = pickle.load(f)
        self.model = stored['model']
        self.non_penalized_features = stored['non_penalized_features']
        self.feature_names = stored['feature_names']

    def get_params(self, for_refitting=False) -> dict:
        """Return the main hyper-parameters plus backend-specific fitted values.

        Reads attributes of the fitted model, so it must be called after fitting or loading.
        ``for_refitting`` is accepted but not used here.
        """
        params = {
            'alpha': self.alpha,
            'non_penalized_features': self.non_penalized_features,
            'backend': self.backend,
        }
        if self.backend == 'glmnet':
            params.update({
                'lambda_best': self.model.lambda_best_,
                'lambda_max': self.model.lambda_max_,
                'random_state': self.model.random_state,
            })
        else:
            params['C'] = self.model.C
        return params

    def can_predict_proba(self) -> bool:
        """Both backends expose class probabilities."""
        return True

    def can_fit_with_example_weights(self) -> bool:
        """Example weights are not used by ``_fit``."""
        return False

    def get_compatible_encoders(self):
        """Return the encoder classes this method accepts (CompositeEncoder only)."""
        # Imported lazily, as in the original code.
        from immuneML.encodings.composite_encoding.CompositeEncoder import CompositeEncoder
        return [CompositeEncoder]