Exploring frequency weights in logistic regression

R · Modelling

Frequency weights can be used in logistic regression to address class imbalance caused by sampling bias. In this post, we will use simulations to explore how to effectively choose weights and visually understand the benefits of this approach.

Published June 4, 2024

In logistic regression, we estimate a linear model for the log odds of observing an event at different levels of our predictor. These log odds can be transformed into probabilities, which then produce the characteristic S-curve.
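The transformation itself is the inverse logit, available in base R as plogis() (boot::inv.logit(), used later in this post, computes the same transformation):

plogis(0)                  # log odds of 0 = odds of 1:1 = probability 0.5
plogis(c(-4, -2, 0, 2, 4)) # evaluated over a grid, these values trace the S-curve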

In essence, it all comes down to odds. Odds of 1:1 in our underlying data-generating process mean that both events are equally likely to occur, each with a probability of 50%. Our fictional population for the simulation follows this setup: the probability of each event depends on a value drawn from a normal distribution, with means of 5 and 8 for the two groups and a standard deviation of 1 for both.

library(tidyverse)

a <- tibble(value = rnorm(4000, 5, 1), condition = "a")
b <- tibble(value = rnorm(4000, 8, 1), condition = "b")
df_ab <- bind_rows(a, b)

df_ab %>%
  ggplot(aes(x = value, y = condition, fill = condition)) +
  ggdist::stat_halfeye() +
  scale_x_continuous(breaks = seq(1, 12, 1)) +
  blog_theme() +
  guides(fill = "none")

Simulating imbalance

To get an idea of how frequency weights affect the logistic regression results, we will generate simulations from our fictitious population under three conditions. Each condition will consist of 400 samples, each containing N = 400 values drawn from our population.

  1. Balanced dataset: Each sample will have a 1:1 ratio of our two classes, reflecting the proportions in the population.

  2. Unbalanced dataset: Each sample will have a 1:4 ratio, representing an imbalance between the classes caused by some sort of sampling bias, where one group is underrepresented in the sample compared to the population.

  3. Weighted unbalanced dataset: Similar to the unbalanced dataset, but here each observation in the underrepresented class is given a frequency weight of 4 to counteract the imbalance (see the sketch after this list).
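As a quick check of the arithmetic behind conditions 2 and 3: a 1:4 ratio splits N = 400 into 80 and 320 observations, and a frequency weight of 4 on the smaller class restores an effective 1:1 balance.

a_n <- 400 / 5    # one part class a:   80 observations
b_n <- 400 - a_n  # four parts class b: 320 observations
a_n * 4 == b_n    # weighting class a by 4 yields 320 effective observations, i.e. balance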

For each condition, we will use the logistic_sim function to generate the samples and run a logistic regression on each of them.

We will then compute the median of the estimated parameters (intercept and slope) within each condition in order to obtain a robust measure of central tendency.

Finally, we will turn these median parameter estimates into a single logistic regression curve for each condition.

logistic_sim <- function(sample_n, a_n, b_n, a_mean = 5, b_mean = 8, a_std = 1, b_std = 1, weight = 1) {
  map(1:sample_n, ~ {
    # Create samples for conditions a and b
    a <- tibble(value = rnorm(a_n, a_mean, a_std), condition = 0)
    b <- tibble(value = rnorm(b_n, b_mean, b_std), condition = 1)
    df_ab <- bind_rows(a, b)

    # Attach frequency weights for the underrepresented class if requested
    if (weight != 1) {
      df_ab <- df_ab %>%
        mutate(weights = ifelse(condition == 0, weight, 1))
    }

    df_ab
  }) %>%
  imap(~ {
    if (weight != 1) {
      # Fit model with frequency weights
      model <- glm(condition ~ value, family = binomial, data = .x, weights = .x$weights)
    } else {
      # Fit model without weights
      model <- glm(condition ~ value, family = binomial, data = .x)
    }
    # Tidy and return the model coefficients with the iteration number
    broom::tidy(model) %>%
      select(term, estimate) %>%
      mutate(iteration = .y)
  }) %>%
  bind_rows() %>%
  pivot_wider(names_from = term, values_from = estimate) %>%
  rename(alpha = `(Intercept)`, beta = value)
}
doParallel::registerDoParallel(cores = parallel::detectCores())

balanced_estimates <- logistic_sim(400, a_n = 200, b_n = 200) %>%
  mutate(model = "balanced")
unbalanced_estimates <- logistic_sim(400, a_n = 80, b_n = 320) %>%
  mutate(model = "unbalanced")
weighted_estimates <- logistic_sim(400, a_n = 80, b_n = 320, weight = 4) %>%
  mutate(model = "weighted")

As the plot below shows, the unbalanced model underestimates the probability of the underrepresented group compared to the balanced model. The intercept is particularly affected, while the slope stays relatively similar; nevertheless, all estimated probabilities shift as a result.

Assuming that the balanced model is the closest estimate of our population we can get, the unbalanced model that used weights appears to be the better fit. So if we believe that our two events are equally likely in the population but our sample is biased, we may want to adjust the estimates using frequency weights.

bind_rows(balanced_estimates, unbalanced_estimates, weighted_estimates) %>%
  group_by(model) %>%
  summarise(alpha.median = median(alpha),
            beta.median = median(beta)) %>%
  mutate(x = list(seq(3, 9, .1))) %>%
  unnest(x) %>%
  mutate(y = boot::inv.logit(alpha.median + beta.median * x)) %>%
  ggplot(aes(x = x, y = y, col = model)) +
  geom_line() +
  xlab("value") +
  ylab("probability") +
  blog_theme()

Relationship between odds and weights

But how do we determine an appropriate frequency weight?

To explore this, we experiment with all possible combinations of odds and weights from 1 to 20 in our fictional population. Again, we use the logistic_sim function to estimate our median parameters under each combination.

sample_n <- 400
odds <- 1:20
a_n <- round(sample_n / (odds + 1))
b_n <- sample_n - a_n
weight <- list(1:20)

sim_models <- tibble(sample_n, odds, a_n, b_n, weight) %>%
  unnest(weight)

psych::headTail(sim_models)
  sample_n odds a_n b_n weight
1      400    1 200 200      1
2      400    1 200 200      2
3      400    1 200 200      3
4      400    1 200 200      4
5      ...  ... ... ...    ...
6      400   20  19 381     17
7      400   20  19 381     18
8      400   20  19 381     19
9      400   20  19 381     20

doParallel::registerDoParallel(cores = parallel::detectCores())

model_estimates <- sim_models %>%
  mutate(estimates = pmap(list(sample_n = sample_n, a_n = a_n, b_n = b_n, weight = weight), logistic_sim, .progress = TRUE))


Evaluating parameter estimates for different odds and weights

We can evaluate the performance of each model by looking at the median error at three values of our variable of interest: the two group means (5 and 8) and the midpoint between them (6.5). As the graph shows, at the critical value of 6.5 the median error decreases up to a certain weight threshold, after which it increases again. So even in such a simple example, weights only help up to a point and must be used with caution.

pred <- model_estimates %>%
  unnest(estimates) %>%
  group_by(odds, a_n, b_n, weight) %>%
  summarise(alpha.median = median(alpha),
            beta.median = median(beta)) %>%
  ungroup() %>%
  mutate(x = list(c(5, 6.5, 8))) %>%
  unnest(x) %>%
  mutate(y = boot::inv.logit(alpha.median + beta.median * x))

bench <- pred %>%
  filter(weight == 1, odds == 1) %>%
  select(x, y_bench = y)

pred_bench <- pred %>%
  left_join(bench, by = join_by(x)) %>%
  mutate(median_error = abs(y_bench - y)) 

pred_min <- pred_bench %>%
  filter(x == 6.5) %>%
  group_by(odds) %>%
  filter(median_error == min(median_error)) %>%
  select(odds, weight)

pred_bench %>%
  ggplot(aes(x = weight, y = median_error, col = factor(x))) +
  geom_line() +
  facet_wrap(~odds) +
  blog_theme() +
  geom_vline(data = pred_min, aes(xintercept = weight), size = 1, color = "gray", linetype = "dotted") +
  ylab("median error") +
  labs(col = "value")

Choosing weights

The weight that yields the most balanced estimates seems to be closely related to the odds ratio. Choosing the assumed imbalance, i.e. the odds, as the weight therefore appears most appropriate, at least in our simple setup and especially for smaller imbalances. Deviations at higher imbalances may also reflect the growing uncertainty of the estimates as the sample size of the minority group shrinks. As the odds increase, almost any weight produces more appropriate estimates than the unbalanced model in which we did not adjust for sampling bias.

no_weight_bench <- pred_bench %>%
  filter(weight == 1) %>%
  select(odds, x, median_error_nw = median_error)

pred_nh <- pred_bench %>%
  left_join(no_weight_bench, by = join_by(odds, x)) %>%
  mutate(improvement = median_error <= median_error_nw)

pred_min_error <- pred_bench %>%
  group_by(x, odds) %>%
  filter(median_error == min(median_error)) %>%
  select(x, odds, weight) %>%
  mutate(best_weight = weight) %>%
  ungroup()

pred_nh %>%
  left_join(pred_min_error, by = join_by(odds, weight, x)) %>%
  ggplot(aes(x = weight, y = odds)) +
  geom_tile(aes(fill = improvement), alpha = .5) +
  geom_line(aes(x = odds, y = odds), linetype = "dashed") +
  geom_point(aes(x = best_weight, y = odds)) +
  facet_wrap(~factor(x)) +
  blog_theme() +
  labs(fill = "improvement to\nunbalanced model")

Conclusion

Depending on the scenario, frequency weights can help to deal with sampling bias and with class imbalance that does not reflect the population, at least in a simple setup like the one we worked with here. However, they should be chosen with prior knowledge of the population, so that the sampling bias is not overcorrected. One approach might be to divide the odds observed in the sample by the odds expected in the population for the group affected by the sampling bias, and to use the result as the weight. If the class imbalance is natural, i.e. it reflects the population, it should be kept in the logistic regression unless you have good reasons not to.
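As a minimal sketch of that last idea (the helper name and the example odds are made up for illustration), the weight for the underrepresented class could be derived from the observed and the expected odds:

# Hypothetical helper: frequency weight for the underrepresented class,
# given the odds observed in the sample and those expected in the population
weight_from_odds <- function(observed_odds, expected_odds = 1) {
  observed_odds / expected_odds
}

weight_from_odds(4)                     # 1:4 sample, 1:1 population -> weight 4
weight_from_odds(4, expected_odds = 2)  # population itself is 1:2   -> weight 2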