Paper: Detecting and Correcting for Label Shift with Black Box Predictors (BBSE)



overview

Let's start with a flu example. In August, a hospital trains a model $f$ on that month's data, where the feature $\bm{x}$ is "whether the patient has a cough" and the predicted label $y$ is "whether the patient has the flu".

In the following months $f$ works well, but by February of the next year the hospital finds that $f$ predicts a sharply increased number of people as "having the flu". We know this is because winter is flu season, but an immediate question arises: can the $f$ trained on August data still predict effectively in February, and will the prior learned from the August data distort its judgments in February?

Formalizing the problem, both $p(y \mid \bm{x}) = p(\text{flu} \mid \text{cough})$ and $p(y) = p(\text{flu})$ have clearly changed, so previous work on "covariate shift" no longer applies.

Looking deeper, $p(\bm{x} \mid y) = p(\text{cough} \mid \text{flu})$ does not seem to have changed much. This motivates the "label shift" problem that the paper focuses on, which describes the following situation:

  • the label marginal distribution $p(y)$ changes, but the conditional distribution $p(\bm{x} \mid y)$ stays unchanged

The paper then proposes the "Black Box Shift Estimation (BBSE)" method, which uses a black-box predictor to estimate the changed $p(y)$. The only requirement is that the predictor's "confusion matrix" be invertible; the predictor itself may be biased, inaccurate, or uncalibrated.


problem setting

Source domain: a distribution $P$ over $\mathcal{X} \times \mathcal{Y}$, with labeled samples $D = \{(\bm{x}_i, y_i)\}_{i=1}^n$ and a black-box model $f: \mathcal{X} \rightarrow \mathcal{Y}$ trained on $D$.

Target domain: a distribution $Q$ over $\mathcal{X} \times \mathcal{Y}$, with unlabeled samples $X' = [\bm{x}_1'; \dots; \bm{x}_m']$.

Goal: detect whether a "label shift" has occurred from $P$ to $Q$; if so, retrain the model to adapt to the distribution $Q$.

Three assumptions

  • the "label shift / target shift" assumption:
    $p(\bm{x} \mid y) = q(\bm{x} \mid y) \quad \forall \bm{x} \in \mathcal{X}, y \in \mathcal{Y}$
  • $\forall y \in \mathcal{Y}$: if $q(y) > 0$, then $p(y) > 0$
  • the confusion matrix of $f$, $\mathbf{C}_p(f)$, is invertible, where the matrix is defined as:
    $\mathbf{C}_p(f) := p(f(\bm{x}), y) \in \mathbb{R}^{|\mathcal{Y}| \times |\mathcal{Y}|}$
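To make the confusion-matrix assumption concrete, here is a minimal NumPy sketch (the function name `confusion_matrix_joint` is mine, not from the paper) of estimating the empirical joint confusion matrix $\hat{\mathbf{C}}_p(f)$ from held-out labeled source data:

```python
import numpy as np

def confusion_matrix_joint(y_pred, y_true, n_classes):
    """Empirical joint confusion matrix: C[i, j] estimates p(f(x) = i, y = j)."""
    C = np.zeros((n_classes, n_classes))
    for yp, yt in zip(y_pred, y_true):
        C[yp, yt] += 1
    return C / len(y_true)  # normalize counts into a joint distribution
```

Note this is the *joint* matrix $p(\hat{y}, y)$, not the row-normalized $p(\hat{y} \mid y)$; both appear in the equations below.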

BBSE

The "Black Box Shift Estimation (BBSE)" method estimates $w(y) = q(y)/p(y)$, starting from the following identity:
$$\begin{aligned} q(\hat{y}) &= \sum_{y \in \mathcal{Y}} q(\hat{y} \mid y)\, q(y) \\ &= \sum_{y \in \mathcal{Y}} p(\hat{y} \mid y)\, q(y) = \sum_{y \in \mathcal{Y}} p(\hat{y}, y)\, \frac{q(y)}{p(y)} \end{aligned}$$

where $\hat{y} = f(\bm{x})$ is the pseudo-label given by $f$, and $q(\hat{y} \mid y) = p(\hat{y} \mid y)$ follows from the derivation below (summing over $\bm{x}$, not $y$):
$$\begin{aligned} q(\hat{y} \mid y) &= \sum_{\bm{x} \in \mathcal{X}} q(\hat{y} \mid \bm{x}, y)\, q(\bm{x} \mid y) = \sum_{\bm{x} \in \mathcal{X}} q(\hat{y} \mid \bm{x}, y)\, p(\bm{x} \mid y) \\ &= \sum_{\bm{x} \in \mathcal{X}} p_f(\hat{y} \mid \bm{x})\, p(\bm{x} \mid y) = \sum_{\bm{x} \in \mathcal{X}} p(\hat{y} \mid \bm{x}, y)\, p(\bm{x} \mid y) = p(\hat{y} \mid y) \end{aligned}$$

The key step uses $q(\bm{x} \mid y) = p(\bm{x} \mid y)$, which yields the following system of equations:
$$\mu_{\hat{y}} = \mathbf{C}_{\hat{y} \mid y}\, \mu_y = \mathbf{C}_{\hat{y}, y}\, \bm{w}, \qquad \hat{\bm{w}} = \hat{\mathbf{C}}_{\hat{y}, y}^{-1}\, \hat{\bm{\mu}}_{\hat{y}}, \qquad \hat{\bm{\mu}}_y = \operatorname{diag}\left(\hat{\bm{\nu}}_y\right) \hat{\bm{w}}$$
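Putting the equations above together, a minimal NumPy sketch of the BBSE weight estimator might look as follows (the function name `bbse_weights` is mine; the non-negativity clipping is a common practical touch, since the linear solve can return small negative entries):

```python
import numpy as np

def bbse_weights(y_pred_source, y_true_source, y_pred_target, n_classes):
    """Estimate w(y) = q(y)/p(y) by solving C_{y_hat,y} w = mu_{y_hat}."""
    n = len(y_true_source)
    # Empirical joint confusion matrix C[i, j] ~ p(f(x) = i, y = j)
    C = np.zeros((n_classes, n_classes))
    for yp, yt in zip(y_pred_source, y_true_source):
        C[yp, yt] += 1.0 / n
    # Empirical distribution of predicted labels on the target domain
    mu_hat = np.bincount(y_pred_target, minlength=n_classes) / len(y_pred_target)
    # Solve the linear system (assumes C is invertible, per the third assumption)
    w = np.linalg.solve(C, mu_hat)
    return np.clip(w, 0.0, None)  # negative entries are estimation noise
```

For a sanity check: if $f$ is a perfect predictor, $\mathbf{C}$ is diagonal with $p(y)$ on the diagonal, and the solve reduces to $\hat{w}(y) = \hat{q}(y)/\hat{p}(y)$ directly.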

The symbols are defined as follows. The core idea is still the formula at the beginning of this section; the extra notation is introduced for rigor, but the essence is unchanged.
[Figure: symbol definitions from the paper]

Theoretical guarantee

The first is a "consistency" guarantee:
[Figure: consistency theorem from the paper]
The second is an "error bound":
[Figure: error-bound theorem from the paper]
From the error bound above it follows that, when choosing a black-box model, the larger the minimum singular value of $\mathbf{C}_{\hat{y}, y}$, the better suited the model is.


Label-Shift Detection

Under the three assumptions above, $q(y) = p(y) \Leftrightarrow p(\hat{y}) = q(\hat{y})$, so we can apply a "two-sample test" to $p(\hat{y}) = q(\hat{y})$ for detection.
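As one concrete choice of two-sample test (the paper does not prescribe a specific one here), a Pearson chi-square statistic on the predicted-label counts can be sketched with NumPy alone; the helper name `label_shift_statistic` is hypothetical:

```python
import numpy as np

def label_shift_statistic(y_pred_source, y_pred_target, n_classes):
    """Pearson chi-square statistic comparing p(y_hat) and q(y_hat)."""
    n, m = len(y_pred_source), len(y_pred_target)
    cs = np.bincount(y_pred_source, minlength=n_classes)
    ct = np.bincount(y_pred_target, minlength=n_classes)
    pooled = (cs + ct) / (n + m)              # pooled distribution under H0
    exp_s, exp_t = n * pooled, m * pooled      # expected counts if no shift
    mask = pooled > 0                          # skip labels never predicted
    return (((cs - exp_s) ** 2 / exp_s)[mask].sum()
            + ((ct - exp_t) ** 2 / exp_t)[mask].sum())
```

Compare the statistic against the chi-square critical value with $|\mathcal{Y}| - 1$ degrees of freedom (e.g. about 3.84 for two classes at significance level 0.05) to decide whether to reject "no label shift".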


Fit the model to the new distribution

After computing $\hat{\bm{w}}$, simply retrain the model on the source dataset $D$ with "importance-weighted ERM"; the training objective is:
$$\mathcal{L} = \sum_{i=1}^n \hat{w}_i \cdot \ell\left(y_i, \bm{x}_i\right)$$
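A minimal sketch of this weighted objective, here instantiated as binary logistic regression trained by plain gradient descent (function name, hyperparameters, and the choice of model are all illustrative, not from the paper):

```python
import numpy as np

def weighted_erm_logreg(X, y, w_hat, lr=0.1, epochs=500):
    """Importance-weighted ERM: minimize sum_i w_hat[y_i] * logloss_i."""
    n, d = X.shape
    theta = np.zeros(d)
    sw = w_hat[y]                              # per-example weight w_hat(y_i)
    for _ in range(epochs):
        p = 1.0 / (1.0 + np.exp(-X @ theta))   # sigmoid predictions
        grad = X.T @ (sw * (p - y)) / n        # gradient of the weighted log-loss
        theta -= lr * grad
    return theta
```

Equivalently, with libraries that support it, the same effect is achieved by passing $\hat{w}(y_i)$ as per-sample weights to the training routine.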

The overall algorithm is as follows:
[Figure: the overall BBSE algorithm from the paper]


Testing whether the label-shift assumption holds

Use "kernel two-sample tests" to test whether the following holds:
$$\mathbb{E}_p[w(y)\, k(\phi(\bm{x}), \cdot)] = \mathbb{E}_q[k(\phi(\bm{x}), \cdot)]$$

That is, it reduces to computing the following MMD distance:
$$\left\| \frac{1}{n} \sum_{i=1}^n \left[ \hat{w}\left(y_i\right) k\left(\phi\left(\bm{x}_i\right), \cdot\right) \right] - \frac{1}{m} \sum_{j=1}^m k\left(\phi\left(\bm{x}_j^{\prime}\right), \cdot\right) \right\|_{\mathcal{H}}^2
$$
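Expanding the squared RKHS norm gives three kernel-sum terms, which can be sketched in NumPy with an RBF kernel (function names and the choice of `gamma` are illustrative; $\phi$ is taken as the identity here for simplicity):

```python
import numpy as np

def rbf_kernel(A, B, gamma=1.0):
    """Pairwise k(a, b) = exp(-gamma * ||a - b||^2)."""
    sq = (A ** 2).sum(1)[:, None] + (B ** 2).sum(1)[None, :] - 2 * A @ B.T
    return np.exp(-gamma * sq)

def weighted_mmd2(Xs, ws, Xt, gamma=1.0):
    """Squared MMD between w_hat-reweighted source features and target features."""
    n, m = len(Xs), len(Xt)
    Kss = rbf_kernel(Xs, Xs, gamma)
    Kst = rbf_kernel(Xs, Xt, gamma)
    Ktt = rbf_kernel(Xt, Xt, gamma)
    # ||(1/n) sum_i w_i k(x_i,.) - (1/m) sum_j k(x'_j,.)||^2 expanded
    return ((ws @ Kss @ ws) / n ** 2
            - 2 * (ws @ Kst).sum() / (n * m)
            + Ktt.sum() / m ** 2)
```

If the label-shift assumption holds and $\hat{\bm{w}}$ is accurate, this quantity should be close to zero; a permutation test can turn it into a formal hypothesis test.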


Origin blog.csdn.net/qq_41552508/article/details/127175875