前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >「阿里」SCI:基于子空间学习的个体干预效果(ITE)估计方法

「阿里」SCI:基于子空间学习的个体干预效果(ITE)估计方法

作者头像
秋枫学习笔记
发布2023-01-30 15:13:43
3240
发布2023-01-30 15:13:43
举报
文章被收录于专栏:秋枫学习笔记秋枫学习笔记

关注我们,一起学习~

标题:SCI: Subspace Learning Based Counterfactual Inference for Individual Treatment Effect Estimation 地址:https://dl.acm.org/doi/10.1145/3459637.3482175 会议:CIKM 2021 公司:阿里

1. 导读

本文是一篇短文,针对因果效应预估方面提出的方法,现有的表征学习方法侧重于学习一个平衡的特征空间,而忽略了某些预测结果的信息。为了充分利用预测信息,本文提出了一种基于子空间学习的反事实推理(SCI)方法来估计个体因果效应(ITE)。

2. 方法

本文的思路是比较直观的,如图所示即将表征学习分成三个子模块,分别学习公共部分和干预的部分,然后将得到的输出进行拼接,结合相应的分布约束和重构损失来促进学习。总体损失函数如下,接下来分别介绍各个部分。

\mathcal{L}=\mathcal{L}_{f}+\alpha \mathcal{L}_{b}+\beta\left(\mathcal{L}_{\text {pesu }}^{\text {con }}+\mathcal{L}_{p e s u}^{\text {tre }}\right)+\rho \mathcal{L}_{\mathrm{HSIC}}+\gamma \mathcal{L}_{r e c}+\lambda\|W\|_{2}

2.1 对照子空间

这部分的目的是去学习对照组数据特有的信息,通过DNN

Z_{spc}^{con}=\Phi_{con}(X)

学习得到对照组数据的表征Z,然后有两个输出,一方面通过linear层预测对照的伪输出

\tilde{Y}_{spc}^{con}=(w_{spc}^{con})'Z_{spc}^{con}+b_{spc}^{con}

,这部分输出可以构建相应的损失函数,如下所示,

Y^F

为对应的标签,T为是否干预。这里只是单纯的采用对照组数据进行学习,存在样本偏差,因此输出为伪输出,当伪输出接近标签时,学到的表征Z也就更能反映对照组的信息。

\mathcal{L}_{p s e u}^{c o n}=\frac{1}{\sum_{i=1}^{N} \mathbb{I}\left(t_{i}=0\right)}\left\|\left(\mathrm{Y}^{F}-\tilde{\mathrm{Y}}_{s p c}^{c o n}\right) \cdot \operatorname{diag}(1-\mathrm{T})\right\|_{2}^{2} \text {, }

另一方面的输出,是将表征Z送入后续的模块进行拼接和预测。

2.2 干预子空间

和对照子空间同理,干预子空间也是采用干预组的数据进行相应的学习,同样可以构建类似的损失函数。

\mathcal{L}_{p s e u}^{tre}=\frac{1}{\sum_{i=1}^{N} \mathbb{I}\left(t_{i}=0\right)}\left\|\left(\mathrm{Y}^{F}-\tilde{\mathrm{Y}}_{s p c}^{tre}\right) \cdot \operatorname{diag}(\mathrm{T})\right\|_{2}^{2} \text {, }

2.3 公共子空间

前面两个子空间分别学习对照和干预特定的信息,公共子空间用于学习两者的公共信息,这部分采用对照和干预的所有数据,表征表示为

Z_{com}=\Phi_{com}(X)

,通过类似CFR中的IPM对其中的对照组和干预组的表征的分布进行约束,可以采用mmd,Wasserstein距离等方式构建分布约束损失

\mathcal{L}_b

2.4 子空间结合

公共子空间学习的表征可能不足以进行结果预测,而干预,对照子空间学习到的表征可能有限。为了克服使用单个子空间的不足,SCI从公共子空间和干预,对照子空间学习到的规范化表征拼接起来,然后利用HSIC(希尔伯特-施密特独立性准则)进行约束,损失函数如下。对消化损失函数的过程中,可以使表征X与干预T的关联性更弱。

\mathcal{L}_{\mathrm{HSIC}}=\mathrm{HSIC}\left(\mathrm{H}^{\text {con }}, \mathrm{T}\right)+\mathrm{HSIC}\left(\mathrm{H}^{\text {tre }}, \mathrm{T}\right)

2.5 重构和预测

为了使级联表征更有意义,SCI引入了解码器网络

\Psi_{con}

,

\Psi_{tre}

重建原始的输入数据:

\hat{X}^{con}=\Psi_{con}(H^{con})

,

\hat{X}^{tre}=\Psi_{tre}(H^{tre})

重构损失计算如下,

\mathcal{L}_{r e c}=\sum_{i=1}^{N}\left(\left(1-t_{i}\right)\left\|\mathbf{X}[:, i]-\hat{\mathbf{X}}^{\text {con }}[:, i]\right\|_{2}^{2}+t_{i}\left\|\mathbf{X}[:, i]-\hat{\mathbf{X}}^{\operatorname{tre}}[:, i]\right\|_{2}^{2}\right)

并且利用拼接后的数据进行相应的预测,构建预测值与真实值之间的损失函数,如下所示,

\begin{array}{r} \mathcal{L}_{f}=\frac{1}{\sum_{i=1}^{N} \mathbb{I}\left(t_{i}=0\right)}\left\|\left(\mathrm{Y}^{F}-\tilde{\mathrm{Y}}_{0}\right) \cdot \operatorname{diag}(1-\mathrm{T})\right\|_{2}^{2} \\ +\frac{1}{\sum_{i=1}^{N} \mathbb{I}\left(t_{i}=1\right)}\left\|\left(\mathrm{Y}^{F}-\tilde{\mathrm{Y}}_{1}\right) \cdot \operatorname{diag}(\mathrm{T})\right\|_{2}^{2} \end{array}

3. 结果

在真实数据集Jobs和生成的数据集上能够取得不错的效果。

本文参与?腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2022-10-15,如有侵权请联系?cloudcommunity@tencent.com 删除

本文分享自 秋枫学习笔记 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与?腾讯云自媒体分享计划? ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 2.1 对照子空间
  • 2.2 干预子空间
  • 2.3 公共子空间
  • 2.4 子空间结合
  • 2.5 重构和预测
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档
http://www.vxiaotou.com