Source code for immuneML.ml_methods.pytorch_implementations.PyTorchLogisticRegression

import torch


[docs]class PyTorchLogisticRegression(torch.nn.Module): def __init__(self, in_features: int, zero_abundance_weight_init: bool): super().__init__() self.linear = torch.nn.Linear(in_features, 1, bias=True) with torch.no_grad(): self.linear.bias.zero_() self.linear.weight.normal_(mean=0, std=1/in_features) if zero_abundance_weight_init: self.linear.weight[:, -1].fill_(0.0)
[docs] def forward(self, x): return self.linear(x).squeeze()