Continual learning is the idea of learning from a non-static dataset. Usually you collect a dataset once and train the model on it; in continual learning you update the model again and again as new data becomes available.

So instead of retraining from scratch every time the dataset grows, you update the model starting from the previous one you had.

The dream of Continual Learning is to:

  • Learn from a stream of data;
  • Without storing old samples;
  • Without increasing the memory footprint;
  • With the ability to add new tasks and/or labels at any time;
  • Improving further over time.

The setting is non-i.i.d., so gradient descent might not find the optimal solution and may overfit the current training data.
We also have to assume that the relationship between input and output doesn't change: the input distribution or the output distribution may shift, but the mapping itself stays the same. The main challenge is to avoid forgetting.

A typical CL setup is a classification scenario with a small dataset: you first learn a task from scratch, and once you perform well on it you switch to a new batch of data with a new task / new classes.

The naive way of doing it is to just keep training on each new batch, but unfortunately it doesn't really work: performance on the previous tasks gets worse as more and more tasks are introduced.

We have catastrophic forgetting.
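To make this concrete, here is a minimal sketch of the sequential-training protocol (the toy tasks, model, and hyperparameters are my assumptions): a single shared linear classifier is trained on task 1, then on task 2, and after each phase it is evaluated on every task seen so far. Accuracy on task 1 collapses after training on task 2.

```python
import numpy as np

rng = np.random.default_rng(0)

def make_task(offset):
    """Toy binary task: two Gaussian blobs, both shifted by `offset`."""
    X = np.vstack([rng.normal(-1 + offset, 0.5, (100, 2)),
                   rng.normal(+1 + offset, 0.5, (100, 2))])
    y = np.array([0] * 100 + [1] * 100)
    return X, y

def train(w, b, X, y, lr=0.1, epochs=200):
    """Full-batch logistic-regression gradient descent on the shared parameters."""
    for _ in range(epochs):
        z = np.clip(X @ w + b, -30, 30)      # logits (clipped for stability)
        p = 1 / (1 + np.exp(-z))             # sigmoid
        g = p - y                            # gradient of log-loss w.r.t. logits
        w = w - lr * X.T @ g / len(y)
        b = b - lr * g.mean()
    return w, b

def accuracy(w, b, X, y):
    return float((((X @ w + b) > 0) == y).mean())

tasks = [make_task(0.0), make_task(4.0)]     # task 2: same rule, shifted inputs
w, b = np.zeros(2), 0.0
results = {}
for t, (X, y) in enumerate(tasks, 1):
    w, b = train(w, b, X, y)                 # keep training the same parameters
    for s in range(t):                       # evaluate every task seen so far
        results[(t, s + 1)] = accuracy(w, b, *tasks[s])
        print(f"after task {t}: accuracy on task {s + 1} = {results[(t, s + 1)]:.2f}")
```

After phase 2 the decision boundary has moved to fit task 2's shifted inputs, so most of task 1's positive class is now misclassified: that drop is the catastrophic forgetting.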

Typical CL setup

Three scenarios: task-incremental, domain-incremental, or class-incremental learning.

  • Task-IL: solve all tasks so far, task-ID provided at test time;
  • Domain-IL: solve all tasks so far, task-ID not provided;
  • Class-IL: solve all tasks so far, and infer the task-ID.

Class-IL is the most used setting and the hardest one.

Continual Learning methods are divided into three categories:

  • Regularization-based methods;
  • Parameter-isolation methods;
  • Replay-based methods.

The loss has two parts. The second part is the risk on the new data we are currently processing, which we can compute empirically; this term is there to learn new stuff. The first part is the risk on the data of past tasks, which we can no longer compute exactly and must approximate; this term is there to avoid forgetting.
We need the model to be both plastic and stable.
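Written out as a formula (the notation here is mine, a sketch of the decomposition described above):

```latex
% Two-term continual-learning objective at task t (notation is mine):
% a past-data term that must be approximated, plus a current-data term
% that is computed empirically.
\mathcal{L}(\theta)
  = \underbrace{\mathbb{E}_{(x,y)\sim\mathcal{D}_{1:t-1}}\bigl[\ell(f_\theta(x), y)\bigr]}_{\text{past tasks: approximated (stability)}}
  \;+\;
  \underbrace{\mathbb{E}_{(x,y)\sim\mathcal{D}_{t}}\bigl[\ell(f_\theta(x), y)\bigr]}_{\text{current data: empirical (plasticity)}}
```

The three families of methods differ in how they approximate the first term.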

Regularization-based methods

e.g. EWC, MAS, SI

The penalty anchors the parameters to θ*, the previous optimal parameters (the minimum found on the old tasks): L(θ) = L_new(θ) + (λ/2) Σᵢ Ωᵢ (θᵢ − θ*ᵢ)², where Ωᵢ estimates how important parameter i was for the old tasks (the Fisher information, in EWC).

An alternative is to regularize not the parameters but the output: we compute the output of the old model and of the new model on the same inputs, and push them to be the same (a distillation-style loss).
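A minimal sketch of the quadratic parameter penalty (the importance weights, λ, and all values below are toy assumptions; EWC would estimate Ω from the Fisher information):

```python
import numpy as np

def reg_loss(theta, theta_star, omega, new_task_loss, lam=1.0):
    """Regularization-based CL loss: the new-task loss plus a quadratic
    penalty anchoring each parameter to its previous optimum theta_star,
    weighted by its importance omega for the old tasks."""
    penalty = 0.5 * lam * np.sum(omega * (theta - theta_star) ** 2)
    return new_task_loss + penalty

theta_star = np.array([1.0, -2.0])   # optimum found on the old tasks
omega = np.array([10.0, 0.1])        # first parameter matters much more
theta = np.array([1.5, -1.0])        # candidate parameters for the new task

print(reg_loss(theta, theta_star, omega, new_task_loss=0.3))
```

Note how moving the important parameter (Ω = 10) by 0.5 costs far more than moving the unimportant one (Ω = 0.1) by 1.0: that asymmetry is what lets the model stay plastic where it can and stable where it must.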

Replay methods

The loss on past tasks is computed over a small buffer of stored samples. The problem would be that the model could over-fit on this small buffer, but in practice it's not a big deal. Replay methods have also been criticized because they relax the constraint of not storing old samples.
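A sketch of such a buffer using reservoir sampling (a common choice for streams; the class and names are mine): every sample seen so far ends up stored with equal probability, without knowing the stream length in advance.

```python
import random

class ReplayBuffer:
    """Fixed-size buffer filled by reservoir sampling over a data stream."""
    def __init__(self, capacity, seed=0):
        self.capacity = capacity
        self.data = []
        self.n_seen = 0
        self.rng = random.Random(seed)

    def add(self, sample):
        self.n_seen += 1
        if len(self.data) < self.capacity:
            self.data.append(sample)
        else:
            # replace a stored sample with probability capacity / n_seen
            j = self.rng.randrange(self.n_seen)
            if j < self.capacity:
                self.data[j] = sample

    def sample(self, k):
        """Draw a mini-batch of stored samples to mix with the current task's batch."""
        return self.rng.sample(self.data, min(k, len(self.data)))

buf = ReplayBuffer(capacity=5)
for x in range(1000):            # a stream of 1000 samples, only 5 are kept
    buf.add(x)
print(sorted(buf.data))
```

During training on task t, each update would use the current batch plus `buf.sample(k)`, so the empirical part of the loss keeps touching (a proxy for) the old tasks.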

Parameter Isolation methods

The idea is to freeze the parameters of the network and add new parameters for each new task. Since at test time you need to know which part of the network to use, this works well only in task-incremental settings.
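A minimal sketch of the idea with one frozen shared backbone and one linear head per task (all names, shapes, and the random features are my toy assumptions):

```python
import numpy as np

class MultiHeadModel:
    """Frozen shared backbone + one linear head per task.
    Adding a task adds parameters; old heads are never touched."""
    def __init__(self, in_dim, feat_dim, seed=0):
        self.rng = np.random.default_rng(seed)
        self.W_backbone = self.rng.normal(size=(in_dim, feat_dim))  # frozen
        self.feat_dim = feat_dim
        self.heads = {}                                             # task_id -> head

    def add_task(self, task_id, n_classes):
        """Parameter isolation: new task, new (trainable) parameters."""
        self.heads[task_id] = self.rng.normal(size=(self.feat_dim, n_classes))

    def predict(self, x, task_id):
        # task-incremental: the task-ID tells us which head to use
        feats = np.maximum(x @ self.W_backbone, 0)                  # ReLU features
        return int(np.argmax(feats @ self.heads[task_id]))

model = MultiHeadModel(in_dim=4, feat_dim=8)
model.add_task("task1", n_classes=2)
model.add_task("task2", n_classes=3)
x = np.ones(4)
print(model.predict(x, "task1"), model.predict(x, "task2"))
```

Because `predict` needs `task_id` to pick a head, this sketch cannot run in the class-IL setting, which is exactly the limitation noted above.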

Good representation

Consider the two parts of the network:

  • the feature extractor (backbone, representation learner);
  • the classification head.

We want to avoid forgetting: is there less forgetting at the representation level? Maybe there is a mismatch between the classification head and the representation learner, which keeps learning new features from the new data. The representation layer holds a lot of shared knowledge that is used across tasks, but task-specific information might get forgotten.

Continual learning with foundation models

Model merging: since these models give you a good starting point, you can average the weights of the different adaptations you have and get good results from that.
For example, you can have different LoRA adapters, one per task, and average them to get a single model adapted to do different things.
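A sketch of the averaging step across per-task adapters, with each adapter represented as a plain dict of NumPy arrays (the adapter contents are toy values of mine; real LoRA merging works on the low-rank A/B matrices, but the averaging itself looks like this):

```python
import numpy as np

def average_adapters(adapters):
    """Uniformly average per-task adapter weights, parameter name by name.
    Assumes all adapters share the same keys and shapes."""
    return {k: np.mean([a[k] for a in adapters], axis=0)
            for k in adapters[0]}

# One (toy) adapter per task, same parameter names and shapes.
task1 = {"layer1": np.array([[1.0, 2.0]]), "layer2": np.array([0.0])}
task2 = {"layer1": np.array([[3.0, 4.0]]), "layer2": np.array([2.0])}

merged = average_adapters([task1, task2])
print(merged["layer1"], merged["layer2"])   # [[2. 3.]] [1.]
```

The merged dict would then be loaded as a single adapter on top of the frozen foundation model, trading a little per-task accuracy for one model that covers all tasks.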

The dynamics of forgetting

Each sample pulls the model in a different direction; the model gets updated in the direction most of the samples pull towards.

While the model trains on one task, the others lose efficacy and performance. As soon as you switch tasks there is a big drop.

We can track which knowledge is kept and which is lost (i.e., which samples keep getting classified correctly).

Examples that are learned quickly tend to be remembered, while examples learned slowly tend to be forgotten.

A learning-speed score can be used to measure which samples are learned fast vs. slow:

intuitively, how many epochs are needed to learn a sample.
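One way to formalize such a score (my formulation, a sketch): record per sample whether it was classified correctly at the end of each epoch, and take the first epoch after which it stays correct for the rest of training.

```python
def learning_speed(correct_per_epoch):
    """correct_per_epoch: list of booleans, one per epoch, saying whether
    the sample was classified correctly at the end of that epoch.
    Returns the first epoch index from which the sample stays correct,
    or None if it was never stably learned."""
    for e in range(len(correct_per_epoch)):
        if all(correct_per_epoch[e:]):
            return e
    return None

fast = [False, True, True, True, True]    # stable from epoch 1
slow = [False, False, True, False, True]  # flip-flops; stable only from epoch 4
never = [False, False, False, False, False]

print(learning_speed(fast), learning_speed(slow), learning_speed(never))
```

A low score marks a fast-learned sample, which per the observation above tends to be remembered across task switches; high or missing scores mark the samples most at risk of being forgotten.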