Adding learnable coefficients for multi-objective losses?

I am running a multi-objective problem where I compute three losses and then sum them up. For each loss, I want to have a learnable coefficient (alpha, beta, and gamma, respectively) that will be optimized.

optimizer = AdamW(model.parameters(), lr=2e-5, eps=1e-8)

for batch in dl:
    result = model(batch)

    loss1 = loss_fn_1(result)
    loss2 = loss_fn_2(result)
    loss3 = loss_fn_3(result)

    # How to optimize alpha, beta, and gamma?
    loss = alpha*loss1 + beta*loss2 + gamma*loss3 

Specific questions:

  1. Should I even have coefficients alpha, beta, and gamma? The optimizer will minimize, so they’ll all go to 0.0, right?

  2. If having those coefficients is a good idea, how can I prevent them from going to 0.0? Someone told me to use regularization, but what does that mean in this case?

  3. How do I declare alpha, beta, and gamma to be learnable by AdamW?

  1. Yes

  2. Theoretically, we have to make a constraint like alpha+beta+gamma = 1. To turn this into an unconstrained optimization, we have to add a Lagrange multiplier term for the constraint equation, and that is probably the regularization your friend talked about, e.g. you add something like

lambda*(alpha + beta + gamma - 1)

to the loss function. I believe this complicates the problem even more, since finding a good value for lambda is difficult even theoretically.

2.5 Sorry, this does not answer your Q3, but I think the practical way is to treat alpha, beta, and gamma as hyperparameters and simply tune them via grid search.

In this case, simply split off part of your training set as a validation set and define a metric on it. The “validation metric” has to suit your problem (e.g. error, F1, Spearman correlation, or another); you can get ideas for metrics by finding Kaggle competitions that are similar to your problem and looking at the metrics they use.

Select the hyperparameters that optimize your validation metric.
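
For illustration, the grid search described above could be sketched roughly as follows. This is only a sketch under assumptions: the weights are taken to be non-negative and to sum to 1, `simplex_grid` and `grid_search` are names made up here, and `train_and_validate` is a hypothetical placeholder whose toy score exists only so the snippet runs. In practice it would train the model with the given weights and return your validation metric.

```python
from itertools import product


def simplex_grid(step=0.25):
    """Yield (alpha, beta, gamma) tuples with non-negative entries summing to 1."""
    n = round(1.0 / step)
    for i, j in product(range(n + 1), repeat=2):
        if i + j <= n:
            yield (i * step, j * step, (n - i - j) * step)


def train_and_validate(alpha, beta, gamma):
    """Hypothetical stand-in: train with these loss weights and return a
    validation metric where higher is better."""
    # Toy score for illustration only; replace with a real training run.
    return -((alpha - 0.5) ** 2 + (beta - 0.25) ** 2 + (gamma - 0.25) ** 2)


def grid_search(step=0.25):
    """Try every weight combination on the grid and keep the best one."""
    best_score, best_weights = float("-inf"), None
    for weights in simplex_grid(step):
        score = train_and_validate(*weights)
        if score > best_score:
            best_score, best_weights = score, weights
    return best_weights, best_score


best_weights, best_score = grid_search(step=0.25)
print(best_weights, best_score)
```

A finer `step` gives more candidates but also more training runs, so a coarse grid first and a finer one around the best point is a common compromise.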

Theoretically, we have to make a constraint like alpha+beta+gamma = 1

Thank you.

Last night I was thinking of doing

loss = alpha*loss1 + beta*loss2 + (1.0 - alpha - beta)*loss3 

which seems to be equivalent to what you wrote above.
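
On the still-open question 3, one possible way to make such weights learnable is sketched below. It is a sketch under assumptions, not the answer from this thread: it assumes PyTorch and `torch.optim.AdamW` (the original snippet does not show where `AdamW` comes from), it uses a softmax over three unconstrained values instead of `1.0 - alpha - beta` so that every weight stays positive, and `WeightedLoss`, the `nn.Linear` model, and the three toy losses are placeholders invented for the example.

```python
import torch
from torch import nn
from torch.optim import AdamW  # assumption: the PyTorch implementation of AdamW


class WeightedLoss(nn.Module):
    """Combine three losses with learnable weights that are positive and sum to 1."""

    def __init__(self):
        super().__init__()
        # Unconstrained values; zeros give equal weights (1/3 each) at the start.
        self.logits = nn.Parameter(torch.zeros(3))

    def forward(self, loss1, loss2, loss3):
        weights = torch.softmax(self.logits, dim=0)  # alpha, beta, gamma
        return (weights * torch.stack([loss1, loss2, loss3])).sum()


model = nn.Linear(4, 1)  # placeholder for the real model
weighting = WeightedLoss()

# Hand both the model parameters and the loss weights to the optimizer.
optimizer = AdamW(
    list(model.parameters()) + list(weighting.parameters()), lr=2e-5, eps=1e-8
)

# One training step with toy data and toy losses, for illustration only.
batch = torch.randn(8, 4)
result = model(batch)
loss1 = result.pow(2).mean()
loss2 = result.abs().mean()
loss3 = (result - 1.0).pow(2).mean()

loss = weighting(loss1, loss2, loss3)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Note that the sum-to-1 constraint alone does not make the learned weights meaningful: since the optimizer minimizes the weighted sum, it will tend to shift weight toward whichever loss is currently smallest, which is one more reason the grid search on a validation metric may be the more practical route.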