# Heterogeneous Treatment Effect Estimation: Function Approximation

## Introduction

I’ve come to think of Causal Inference as *the study of complex, only partially
observed systems*. If we only observe a part of a system, how does that impact
our ability to learn about it? It might seem like if we can’t completely
observe a system, we can’t learn much of anything about it, because how do we
know the parts we don’t observe don’t impact the conclusions we draw? But it
turns out, under certain circumstances, we really can learn at least certain
aspects of the system, and that’s what Causal Inference is all about.

In this sense, Causal Inference is different from other types of Machine Learning. In Deep Learning, we observe the entire system; it’s just that the system is complex, so we need a lot of data to learn about it. Causal Inference is kind of like Deep Learning, but without all the data Deep Learning would need to work well.

In a not-so-recent
post, I wrote about interpreting supervised learning as an approach to
function approximation. In this post I want to discuss what a “causal effect”
really is, and how we go about estimating it. Suppose we want to learn about
some unknown function $f(t, x, z),$ where $t$ and $x$ are observed but $z$ is
unobserved. Specifically, we want to figure out how $f$ depends on $t,$ which
I’ll call the *treatment*. Perhaps $f$ is the GDP of a country. Clearly this is
too complex to ever understand completely, but we’d like to understand how it
is influenced by a specific input, say the corporate tax rate.

We might wish to estimate $\partial f / \partial t$ for example. Of course, $f$
might not be differentiable with respect to $t$, or perhaps $t$ only takes on
discrete values, such as 0 and 1. In that case we would want to compare $f(0,
x, z)$ and $f(1, x, z)$. Slightly more generally, if there are two values of
$t$ that are of interest, say $t_1$ and $t_0$ we may wish to estimate $\xi(x,
z) = f(t_1, x, z) - f(t_0, x, z).$ This last formulation can stand in for
either of the previous quantities, since the partial derivative is simply a
suitable limit of differences. So in this post I will focus on estimating
$\xi(x, z),$ which I will call the *treatment effect,* or the *causal effect*.

Is $\xi$ really causal? We have it really beaten into our heads that correlation
does not imply causation, and we can only draw causal conclusions under certain
circumstances, such as random assignment. But physics draws causal conclusions
all the time: if I drop a ball and it falls to the ground, it’s *because of*
gravity. But I can’t exactly A/B test gravity, so can I really know that
gravity *caused* the ball to fall? Of course we can, but why is physics
different? It’s because in physics, we actually know what $f$ is. We actually
know the equation governing gravity, whereas we don’t know the equation
governing economic productivity. Physical systems are simple and
self-contained, and the economy is complex and interacts with all of human
existence. That’s the difference. The simple formula we have provided for the
causal effect is the correct one; it’s just that when we don’t know $f$, and
when the system is complex and only partially observed, we struggle to learn
anything about it.

Notably, there is nothing random about $f$ or $\xi$, in keeping with my belief
that randomness does not exist except in quantum mechanics. There’s nothing
random about the economy, it’s just really complicated. But as I discussed in
that previous post, we *introduce* randomness (technically *pseudo-randomness*,
but that’s close enough for statistics) in order to approximate a partially
observed system. That is, random sampling is an artificial device we humans can
use to learn about a system. There are several non-random problems in
mathematics where introducing randomness can help solve the problem, such as
evaluating complicated integrals with Monte Carlo methods, and Machine Learning
is in the same vein. But probability is so central to solving these problems we
often forget probability is not necessarily part of the problem itself, just
the way that we solve it.

## Approximation to the Treatment Effect

Now, there are a few situations where it’s actually really easy to learn about the relationship between $f$ and $t.$ When we observe all the relevant factors and have a suitably large (and representative) dataset, we can just fit a deep neural network and base any conclusions on that. (I think Deep Learning folks sometimes consider Causal Inference to be trivial, because they don’t appreciate the impact of not observing all the relevant information, or they assume we can just impute the missing data.) Or, when we have the ability to randomly assign the treatment, as in a typical A/B test, and $\xi(x, z)$ is a constant, we can estimate it as the difference in mean response between the two groups.

But when we observe all the relevant factors, or can at least impute them,
that’s not really a *partially observed system*. And if the treatment effect is
constant, that’s not an especially complicated system. What we’re interested in
here is the case of *heterogeneous treatment effects with unobserved factors*.
And in general, if we don’t observe $z$, then we can’t estimate $\xi(x, z).$ In
an A/B test, we end up estimating the *average* treatment effect, but we can do
better. We can try to estimate a particular approximation to $\xi(x, z)$ that
is a function of $x$ alone. Define
$$
\hat{\xi}^\ast(x) = \underset{\zeta \in \mathcal{H}}{\operatorname{arg\,min}} \int_{\mathcal{X}} \int_{\mathcal{Z}} D(\zeta(x), \xi(x, z)) \cdot w(x, z) \,dz \,dx, \quad (1)
$$
where $\mathcal{X} \times \mathcal{Z}$ is the domain of $\xi,$ $D$ is some
distance or loss function, $w$ is some non-negative weighting function
satisfying $\int_{\mathcal{X}} \int_{\mathcal{Z}} w(x, z) \, dz \, dx = 1$ and
$\mathcal{H}$ is some family of functions with support on $\mathcal{X}.$

Now, $\hat{\xi}^\ast(x)$ is by definition the best approximation to $\xi(x, z)$ in
the family $\mathcal{H},$ where best is interpreted relative to the loss
function $D$ and the weighting function $w.$ We will show that under certain
circumstances we can calculate $\hat{\xi}^\ast(x)$ in terms of approximations to
$f.$ Of course, if it turns out that the treatment effect does not actually
depend on $z$, then $\hat{\xi}^\ast(x)$ *is* the treatment effect, not just an
approximation.

## The Calculus of Variations

We will need some results from the Calculus of Variations. Let
$$
\hat{f}^\ast(t, x) = \underset{g \in \mathcal{G}}{\operatorname{arg\,min}} \int_{\mathcal{T}}
\int_{\mathcal{X}} \int_{\mathcal{Z}} D(g(t, x), f(t, x, z)) \cdot w(t, x, z)
\,dz \,dx \, dt.
$$
We first suppose that $\mathcal{G}$ includes all functions with support on
$\mathcal{T} \times \mathcal{X}.$ In this case, the calculus of variations
allows us to calculate $\hat{f}^\ast.$ Define
$$
L(t, x, g) = \int_{\mathcal{Z}} D(g, f(t, x, z)) \cdot w(t, x, z) \, dz.
$$
Then to calculate $\hat{f}^\ast(t, x) = \operatorname{arg\,min}_g
\int_{\mathcal{T}} \int_{\mathcal{X}} L(t, x, g) \, dx \, dt,$ we
differentiate $L$ with respect to $g$ as if it were a variable and set the
result equal to zero, just as when minimizing any other function:
$$
\frac{\partial L}{\partial g} = \int_{\mathcal{Z}} \frac{\partial D}{\partial g} \cdot w(t, x, z) \, dz = 0.
$$
For example, when $D(g, f) = (g - f)^2,$ then
$$
\begin{align}
0 &= \int_{\mathcal{Z}} 2 \cdot (\hat{f}^\ast(t, x) - f(t, x, z)) \cdot w(t, x, z) \, dz \\\
&= \hat{f}^\ast(t, x) \int_{\mathcal{Z}} w(t, x, z) \, dz - \int_{\mathcal{Z}} f(t, x, z) \cdot w(t, x, z) \, dz \\\
\Rightarrow \hat{f}^\ast(t, x) &= \frac{\int_{\mathcal{Z}} f(t, x, z) \cdot w(t, x, z) \, dz}{\int_{\mathcal{Z}} w(t, x, z) \, dz}, \quad (2)
\end{align}
$$
that is, $\hat{f}^\ast$ is a weighted average of $f.$ It can be shown that when
$f$ takes on binary values, as in a classification problem, and $D(g, f) = -(f
\cdot \log(g) + (1 - f) \cdot \log(1 - g)),$ we actually wind up with the same
formula for $\hat{f}^\ast,$ but I’ll leave this exercise to the reader. I’ll
call this loss the *logistic loss*, since it is the (negative log-likelihood)
loss function in logistic regression.
Recall that $\hat{f}^\ast$ is the best approximation to $f$ not depending on
$z,$ so it makes sense that it would be the average value, integrating over
$z.$
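The logistic-loss exercise above is easy to check numerically. The sketch below (all numbers invented) discretizes $z$ to a small grid, minimizes the weighted logistic loss over a scalar $g$, and confirms the minimizer matches the weighted average of Equation (2):

```python
import numpy as np
from scipy.optimize import minimize_scalar

# Toy check: with binary f-values and the logistic (cross-entropy) loss,
# the minimizer is the same weighted average as under squared loss.
rng = np.random.default_rng(0)
f_vals = rng.integers(0, 2, size=8).astype(float)   # binary f(t, x, z) over a z grid
w_vals = rng.uniform(0.1, 1.0, size=8)              # weights w(t, x, z)

weighted_avg = np.sum(f_vals * w_vals) / np.sum(w_vals)  # Equation (2)

def logistic_loss(g):
    # Negative log-likelihood, summed over the z grid with weights.
    return -np.sum((f_vals * np.log(g) + (1 - f_vals) * np.log(1 - g)) * w_vals)

g_hat = minimize_scalar(logistic_loss, bounds=(1e-6, 1 - 1e-6), method="bounded").x
print(np.isclose(g_hat, weighted_avg, atol=1e-4))  # True
```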

Next we derive an important property of the *residual function*, $\epsilon(t,
x, z) := f(t, x, z) - \hat{f}^\ast(t, x):$
$$
\begin{align}
\int_{\mathcal{Z}} \epsilon(t, x, z) \cdot w(t, x, z) \, dz &= \int_{\mathcal{Z}} (f(t, x, z) - \hat{f}^\ast(t, x)) \cdot w(t, x, z) \, dz \\\
&= \int_{\mathcal{Z}} f(t, x, z) \cdot w(t, x, z) \, dz \\\
&\phantom{=} \hspace{20pt} - \hat{f}^\ast(t, x) \cdot \int_{\mathcal{Z}} w(t, x, z) \, dz \\\
&= 0 \textrm{ for all } t, x
\end{align}
$$
under Equation (2), applicable for squared loss as well as the logistic loss.
In words, the residual function integrates to zero. This is the key result we
need for the next section.
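Here is a quick discrete sanity check of that key result, with invented values and a small $z$ grid standing in for the integral: once $\hat{f}^\ast$ is the weighted average of Equation (2), the weighted residuals sum to zero at any fixed $(t, x)$.

```python
import numpy as np

rng = np.random.default_rng(1)
f_vals = rng.normal(size=7)             # f(t, x, z) for fixed (t, x)
w_vals = rng.uniform(0.1, 1.0, size=7)  # w(t, x, z)

f_star = np.sum(f_vals * w_vals) / np.sum(w_vals)  # Equation (2)
eps = f_vals - f_star                               # residual function epsilon

print(np.isclose(np.sum(eps * w_vals), 0.0))  # True
```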

We have placed no restrictions on the class $\mathcal{G},$ but it is straightforward to show that when Equation (2) holds, if $f$ is linear in $t$ and $x$ then so is $\hat{f}^\ast.$ In this case, we can restrict $\mathcal{G}$ to the space of linear functions and still arrive at the same formula for $\hat{f}^\ast,$ and the residual function still integrates to zero. If $f$ is nonlinear, however, the residual function will not necessarily integrate to zero.

## A Simple Relationship

Now return to $\hat{\xi}^\ast,$ as defined in Equation (1). Notably, the weight
function in the definition of $\hat{\xi}^\ast$ depends on $x$ and $z,$ but in
the previous section it also depended on $t.$ We require a *consistency
condition*, $w^\prime(t, x, z) := \psi(t, x) \cdot w(x, z),$ where
$\psi(t, x) > 0,$ with $w^\prime$ standing in for the $w$ used in the last section. This
requirement may seem odd, but we will have more to say about its interpretation
below.

Suppose $D(\zeta, \xi) = (\zeta - \xi)^2.$ Then
$$
\begin{align}
D(\zeta(x), \xi(x, z)) &= (\zeta(x) - \xi(x, z))^2 \\\
&= (\zeta(x) - (f(t_1, x, z) - f(t_0, x, z)))^2 \\\
&= (\zeta(x) - (\hat{f}^\ast(t_1, x) + \epsilon(t_1, x, z) - \hat{f}^\ast(t_0, x) - \epsilon(t_0, x, z)))^2 \\\
&= (\zeta(x) - (\hat{f}^\ast(t_1, x) - \hat{f}^\ast(t_0, x)) - (\epsilon(t_1, x, z) - \epsilon(t_0, x, z)))^2 \\\
&= (\zeta(x) - (\hat{f}^\ast(t_1, x) - \hat{f}^\ast(t_0, x)))^2 \\\
&\phantom{=} - 2 \cdot (\zeta(x) - (\hat{f}^\ast(t_1, x) - \hat{f}^\ast(t_0, x))) \cdot (\epsilon(t_1, x, z) - \epsilon(t_0, x, z)) \\\
&\phantom{=} + (\epsilon(t_1, x, z) - \epsilon(t_0, x, z))^2.
\end{align}
$$

Thus,
$$
\begin{align}
\hat{\xi}^\ast(x) &= \underset{\zeta}{\operatorname{arg\,min}} \left\{
\int_{\mathcal{X}} \int_{\mathcal{Z}} D(\zeta(x), \xi(x, z)) \cdot w(x, z)
\,dz \,dx \right\} \\\
&= \underset{\zeta}{\operatorname{arg\,min}} \left\{ \int_{\mathcal{X}} (\zeta(x) - (\hat{f}^\ast(t_1, x) - \hat{f}^\ast(t_0, x)))^2 \, \int_{\mathcal{Z}} w(x, z) \, dz \, dx \right. \\\
&\phantom{= \operatorname{arg \,min} \{} - 2 \int_{\mathcal{X}} (\zeta(x) - (\hat{f}^\ast(t_1, x) - \hat{f}^\ast(t_0, x))) \\\
&\phantom{= \operatorname{arg \,min} \{}\hspace{30pt} \times \int_{\mathcal{Z}} (\epsilon(t_1, x, z) - \epsilon(t_0, x, z)) \cdot w(x, z) \, dz \, dx \\\
&\phantom{= \operatorname{arg \,min} \{} + \left. \int_{\mathcal{X}} \int_{\mathcal{Z}} (\epsilon(t_1, x, z) - \epsilon(t_0, x, z))^2 \cdot w(x, z) \, dz \, dx \right\}. \quad (3)
\end{align}
$$
The third term in this expression does not depend on $\zeta$ and thus does not
affect the solution. And the second term in this expression is zero since
$$
\begin{align}
\int_{\mathcal{Z}} (\epsilon(t_1, x, z) - \epsilon(t_0, x, z)) \cdot w(x, z) \, dz &= \frac{1}{\psi(t_1, x)} \int_{\mathcal{Z}} \epsilon(t_1, x, z) \cdot w^\prime(t_1, x, z) \, dz \\\
&\phantom{=} - \frac{1}{\psi(t_0, x)} \int_{\mathcal{Z}} \epsilon(t_0, x, z) \cdot w^\prime(t_0, x, z) \, dz \\\
&= 0,
\end{align}
$$
by the consistency condition on the weights, and the result from the last
section that the residuals integrate to zero. That leaves
$$
\hat{\xi}^\ast(x) = \underset{\zeta}{\operatorname{arg\,min}} \left\{ \int_{\mathcal{X}} (\zeta(x) - (\hat{f}^\ast(t_1, x) - \hat{f}^\ast(t_0, x)))^2 \, \int_{\mathcal{Z}} w(x, z) \, dz \, dx \right\}.
$$
The objective is an integral of a non-negative function but is zero when the
first term is zero. Thus, the solution is attained when
$$
\hat{\xi}^\ast(x) = \hat{f}^\ast(t_1, x) - \hat{f}^\ast(t_0, x).
$$
In words, *the approximation to the treatment effect is equal to the difference
in approximations of the function*.
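This relationship is easy to verify on a toy discrete system. In the sketch below, $t \in \{0, 1\}$, $z$ lives on a small grid, and the weight factors as $w^\prime(t, x, z) = \psi(t, x) \cdot w(x, z)$, satisfying the consistency condition; the functional forms and numbers are invented for illustration.

```python
import numpy as np

rng = np.random.default_rng(2)
n_z = 6
z_grid = np.arange(n_z)
x = 0.5  # fix a single x for the check

def f(t, x, z):
    # A hypothetical partially observed system with a z-dependent effect of t.
    return np.sin(x + z) + t * (1.0 + 0.3 * z)

w_xz = rng.uniform(0.1, 1.0, size=n_z)  # w(x, z) over the z grid
psi = {0: 0.7, 1: 1.3}                  # psi(t, x) > 0

def f_hat_star(t):
    # Equation (2) with weights w'(t, x, z) = psi(t, x) * w(x, z);
    # psi cancels in the ratio, leaving the w-weighted average over z.
    w_prime = psi[t] * w_xz
    return np.sum(f(t, x, z_grid) * w_prime) / np.sum(w_prime)

# Best z-free approximation to xi(x, z) = f(1, x, z) - f(0, x, z) under
# squared loss and weight w(x, z) is the w-weighted average of xi.
xi_vals = f(1, x, z_grid) - f(0, x, z_grid)
xi_star = np.sum(xi_vals * w_xz) / np.sum(w_xz)

print(np.isclose(xi_star, f_hat_star(1) - f_hat_star(0)))  # True
```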

## Commentary

This result is so simple it seems like we shouldn’t have had to derive it. We
*defined* $\xi(x, z) = f(t_1, x, z) - f(t_0, x, z)$, and we *demonstrated*
that $\hat{\xi}^\ast(x) = \hat{f}^\ast(t_1, x) - \hat{f}^\ast(t_0, x)$. This
seems obvious, but it really isn’t! And it’s only true under certain
circumstances.

Here’s why it matters: we wish to calculate an approximation to the treatment effect, $\xi(x, z),$ that depends on $x$ alone. However, in practice, we never actually observe the treatment effect directly; we only observe the function $f.$ What our result demonstrates is that if we have the ability to approximate $f,$ we can simply calculate $\hat{\xi}^\ast$ in terms of this approximation. In other words, there is a simple and intuitive relationship between the approximation to the treatment effect and the approximation to the function, $f$.

This *only* works under special circumstances:

- The loss function on the treatment effect approximation is squared error.
- The loss function on the function $f$ is squared error or the logistic loss.
- The family of functions used to approximate $f$ either satisfies a universal
function representation property, as deep neural networks do, **or** $f$ itself
is linear.
- The weight function satisfies $w^\prime(t, x, z) = \psi(t, x) \cdot w(x, z),$
with $\psi(t, x) > 0.$

Of these, the last seems most in need of discussion. I’ll rewrite the
consistency condition as $w^\prime(t, x, z) / w(x, z) = \psi(t, x)$ (when
$w(x, z) > 0$).
Recall from the last post
that $w^\prime$ has the interpretation of a probability distribution that we
sample from when generating the dataset used to fit a model to $f.$ We can
think of $w(x, z)$ as being the weight function reflecting where we want
$\hat{\xi}^\ast(x)$ to best approximate $\xi(x, z).$ And in order to estimate
$\hat{\xi}^\ast(x)$ we sample from the domain of $f,$ with a sampling
distribution given by $w^\prime(t, x, z).$ But in order to get a valid
estimate, we need $w^\prime$ to be related to $w$ in a specific way: the ratio
$w^\prime / w$ could be a constant, or it could depend on the treatment, or it
could depend on both the treatment and the observed covariates $x$, but it
cannot depend on the unobserved covariates $z$. (I think this requirement is
analogous to the *unconfoundedness assumption* in Causal Inference.) Only
special sampling strategies satisfy this property!

In the last paragraph we *decided* on a $w$ that we cared about, and then
*designed* a corresponding $w^\prime$. But in an observational study it’s the
opposite: nature *hands us* a $w^\prime$, and we can only hope that it can be
factored appropriately. Otherwise, the approximation to the treatment effect is
not necessarily equal to the difference in the approximations to $f.$

But especially in an experimental setting, this provides a recipe for
estimating heterogeneous treatment effects. Fit models $\hat{f}_1(x) =
\hat{f}^\ast(t_1, x)$ and $\hat{f}_0(x) = \hat{f}^\ast(t_0, x)$, then
calculate $\hat{\xi}^\ast(x) = \hat{f}_1(x) - \hat{f}_0(x)$. In the
literature on heterogeneous treatment effect estimation (also known as *uplift
modeling*), this is called the *two model approach*. It turns out there are
better approaches, but the two model approach is simple and intuitive.
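A minimal sketch of the two model approach on synthetic data, assuming a randomized (unconfounded) treatment; the data-generating process and the heterogeneous effect $1 + x^{(1)}$ are invented for illustration, and any off-the-shelf regressor would do:

```python
import numpy as np
from sklearn.ensemble import GradientBoostingRegressor

rng = np.random.default_rng(3)
n = 5000
X = rng.uniform(-1, 1, size=(n, 2))
t = rng.integers(0, 2, size=n)        # randomized treatment assignment
tau = 1.0 + X[:, 0]                   # true heterogeneous effect (invented)
y = np.sin(3 * X[:, 1]) + t * tau + rng.normal(scale=0.1, size=n)

# Two model approach: fit one model per treatment arm, then difference
# the predictions to estimate the treatment effect as a function of x.
m1 = GradientBoostingRegressor().fit(X[t == 1], y[t == 1])
m0 = GradientBoostingRegressor().fit(X[t == 0], y[t == 0])
xi_hat = m1.predict(X) - m0.predict(X)

print(np.corrcoef(xi_hat, tau)[0, 1])  # should be close to 1
```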

As a parting thought, suppose we wanted to draw some insights about $\xi(x, z)$. That is, suppose we are interested in how the treatment effect depends on one of the observed covariates, say $x^{(1)}$. I’ll rewrite the treatment effect as $\xi(x^{(1)}, x^\prime, z),$ where $x^\prime$ just denotes all the other observed covariates. Then we might want to estimate $\partial \xi / \partial x^{(1)}$ or $\xi(x_1^{(1)}, x^\prime, z) - \xi(x_0^{(1)}, x^\prime, z),$ where $x_1^{(1)}$ and $x_0^{(1)}$ denote two values of interest. Then we’re actually in exactly the same position as we were before: we want to learn about some function that depends on unobserved factors, and we can only do this under the same conditions outlined above. Perhaps most importantly, the sampling strategy has to be unconfounded with respect to $x^{(1)}$. And just because the sampling strategy is unconfounded with respect to $t$, doesn’t mean it is unconfounded with respect to $x^{(1)}$.

So when estimating heterogeneous treatment effects, we need to be careful about
interpreting the resulting model. While we can certainly *predict* the
treatment effect for any set of covariates $x$, we cannot say that differing
treatment effects are *because* of particular covariate values, except under
special circumstances.

## Summary

In this post, I described Causal Inference as the study of complex, only partially observed systems. I defined the treatment effect as a comparison between two values of a treatment, keeping all other factors constant.

The treatment effect itself may depend on unobserved factors, but under certain circumstances we can calculate an approximation to the treatment effect in terms of approximations to the system or function itself. The most important requirement is a factorization property of the sampling distribution used to approximate the system that is related to the unconfoundedness assumption in the Causal Inference literature. In general, this can only be guaranteed in the context of a controlled experiment.

Finally, I provided some commentary on the challenges of drawing causal
conclusions about the treatment effects themselves. While we can *predict* the
treatment effect whenever the treatment assignment is unconfounded, we also
need a covariate of interest to be unconfounded in order to draw causal
conclusions about its effect on the treatment effect.

## Further Reading

The Calculus of Variations plays a central role in Classical Mechanics, and
that’s where I learned about it. V.I. Arnold’s book, *Mathematical Methods of
Classical Mechanics*, is a gem. *Calculus of Variations* by I. M. Gelfand and
S. V. Fomin is a more general-purpose reference.

One of my colleagues, Huigang Chen, is a primary contributor to CausalML, a Python package for uplift modeling. The algorithms implemented are much fancier than the two-model approach I describe above! Their GitHub page contains a list of references for folks who want to learn more about Uplift Modeling.