Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions examples/plot_stochastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
#############################################################################
# COMPUTE TRANSPORTATION MATRIX FOR SEMI-DUAL PROBLEM
#############################################################################
print("------------SEMI-DUAL PROBLEM------------")
#############################################################################
# DISCRETE CASE
# DISCRETE CASE:
#
# Sample two discrete measures for the discrete case
# ---------------------------------------------
#
Expand Down Expand Up @@ -57,7 +57,8 @@
print(sag_pi)

#############################################################################
# SEMICONTINUOUS CASE
# SEMICONTINUOUS CASE:
#
# Sample one general measure a, one discrete measure b for the semicontinuous
# case
# ---------------------------------------------
Expand Down Expand Up @@ -139,9 +140,9 @@
#############################################################################
# COMPUTE TRANSPORTATION MATRIX FOR DUAL PROBLEM
#############################################################################
print("------------DUAL PROBLEM------------")
#############################################################################
# SEMICONTINUOUS CASE
# SEMICONTINUOUS CASE:
#
# Sample one general measure a, one discrete measure b for the semicontinuous
# case
# ---------------------------------------------
Expand Down
150 changes: 85 additions & 65 deletions ot/stochastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,15 @@

def coordinate_grad_semi_dual(b, M, reg, beta, i):
'''
Compute the coordinate gradient update for regularized discrete
distributions for (i, :)
Compute the coordinate gradient update for regularized discrete distributions for (i, :)

The function computes the gradient of the semi dual problem:

.. math::
\W_\varepsilon(a, b) = \max_\v \sum_i (\sum_j v_j * b_j
- \reg log(\sum_j exp((v_j - M_{i,j})/reg) * b_j)) * a_i
\max_v \sum_i (\sum_j v_j * b_j - reg * log(\sum_j exp((v_j - M_{i,j})/reg) * b_j)) * a_i

Where :

where :
- M is the (ns,nt) metric cost matrix
- v is a dual variable in R^J
- reg is the regularization term
Expand All @@ -34,15 +33,15 @@ def coordinate_grad_semi_dual(b, M, reg, beta, i):
Parameters
----------

b : np.ndarray(nt,),
b : np.ndarray(nt,)
target measure
M : np.ndarray(ns, nt),
M : np.ndarray(ns, nt)
cost matrix
reg : float number,
reg : float number
Regularization term > 0
v : np.ndarray(nt,),
optimization vector
i : number int,
v : np.ndarray(nt,)
dual variable
i : number int
picked number i

Returns
Expand Down Expand Up @@ -93,14 +92,19 @@ def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=None):

.. math::
\gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)

s.t. \gamma 1 = a
\gamma^T 1= b

\gamma^T 1 = b

\gamma \geq 0
where :

Where :

- M is the (ns,nt) metric cost matrix
- :math:`\Omega` is the entropic regularization term
:math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- :math:`\Omega` is the entropic regularization term with :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- a and b are source and target weights (sum to 1)

The algorithm used for solving the problem is the SAG algorithm
as proposed in [18]_ [alg.1]

Expand Down Expand Up @@ -173,33 +177,37 @@ def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=None):

def averaged_sgd_entropic_transport(a, b, M, reg, numItermax=300000, lr=None):
'''
Compute the ASGD algorithm to solve the regularized semi continuous measures
optimal transport max problem
Compute the ASGD algorithm to solve the regularized semi continuous measures optimal transport max problem

The function solves the following optimization problem:

.. math::
\gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)

s.t. \gamma 1 = a

\gamma^T 1= b

\gamma \geq 0
where :

Where :

- M is the (ns,nt) metric cost matrix
- :math:`\Omega` is the entropic regularization term
:math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- :math:`\Omega` is the entropic regularization term with :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- a and b are source and target weights (sum to 1)

The algorithm used for solving the problem is the ASGD algorithm
as proposed in [18]_ [alg.2]


Parameters
----------

b : np.ndarray(nt,),
b : np.ndarray(nt,)
target measure
M : np.ndarray(ns, nt),
M : np.ndarray(ns, nt)
cost matrix
reg : float number,
reg : float number
Regularization term > 0
numItermax : int number
number of iterations
Expand All @@ -211,7 +219,7 @@ def averaged_sgd_entropic_transport(a, b, M, reg, numItermax=300000, lr=None):
-------

ave_v : np.ndarray(nt,)
optimization vector
dual variable

Examples
--------
Expand Down Expand Up @@ -265,7 +273,8 @@ def c_transform_entropic(b, M, reg, beta):
.. math::
u_i = v^{c,reg}_i = -reg * log(\sum_j exp((v_j - M_{i,j})/reg) * b_j)

where :
Where :

- M is the (ns,nt) metric cost matrix
- u, v are dual variables in R^IxR^J
- reg is the regularization term
Expand All @@ -290,6 +299,7 @@ def c_transform_entropic(b, M, reg, beta):
-------

u : np.ndarray(ns,)
dual variable

Examples
--------
Expand Down Expand Up @@ -341,10 +351,11 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=None,
s.t. \gamma 1 = a
\gamma^T 1= b
\gamma \geq 0
where :

Where :

- M is the (ns,nt) metric cost matrix
- :math:`\Omega` is the entropic regularization term
:math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- :math:`\Omega` is the entropic regularization term with :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- a and b are source and target weights (sum to 1)
The algorithm used for solving the problem is either the SAG or the ASGD
algorithm, as proposed in [18]_
Expand All @@ -353,15 +364,15 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=None,
Parameters
----------

a : np.ndarray(ns,),
a : np.ndarray(ns,)
source measure
b : np.ndarray(nt,),
b : np.ndarray(nt,)
target measure
M : np.ndarray(ns, nt),
M : np.ndarray(ns, nt)
cost matrix
reg : float number,
reg : float number
Regularization term > 0
method : str,
method : str
used method (SAG or ASGD)
numItermax : int number
number of iterations
Expand Down Expand Up @@ -438,40 +449,40 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=None,
def batch_grad_dual(a, b, M, reg, alpha, beta, batch_size, batch_alpha,
batch_beta):
'''
Computes the partial gradient of F_\W_varepsilon
Computes the partial gradient of the dual optimal transport problem.

For each (i,j) in a batch of coordinates, the partial gradients are :

Compute the partial gradient of the dual problem:
.. math::
\partial_{u_i} F = u_i * b_s/l_{v} - \sum_{j \in B_v} exp((u_i + v_j - M_{i,j})/reg) * a_i * b_j

..math:
\forall i in batch_alpha,
grad_alpha_i = alpha_i * batch_size/len(beta) -
sum_{j in batch_beta} exp((alpha_i + beta_j - M_{i,j})/reg)
* a_i * b_j
\partial_{v_j} F = v_j * b_s/l_{u} - \sum_{i \in B_u} exp((u_i + v_j - M_{i,j})/reg) * a_i * b_j

Where :

\forall j in batch_alpha,
grad_beta_j = beta_j * batch_size/len(alpha) -
sum_{i in batch_alpha} exp((alpha_i + beta_j - M_{i,j})/reg)
* a_i * b_j
where :
- M is the (ns,nt) metric cost matrix
- alpha, beta are dual variables in R^IxR^J
- u, v are dual variables in R^IxR^J
- reg is the regularization term
- batch_alpha and batch_beta are lists of indices
- :math:`B_u` and :math:`B_v` are lists of indices
- :math:`b_s` is the size of the batches :math:`B_u` and :math:`B_v`
- :math:`l_u` and :math:`l_v` are the lengths of :math:`B_u` and :math:`B_v`
- a and b are source and target weights (sum to 1)


The algorithm used for solving the dual problem is the SGD algorithm
as proposed in [19]_ [alg.1]


Parameters
----------
a : np.ndarray(ns,),

a : np.ndarray(ns,)
source measure
b : np.ndarray(nt,),
b : np.ndarray(nt,)
target measure
M : np.ndarray(ns, nt),
M : np.ndarray(ns, nt)
cost matrix
reg : float number,
reg : float number
Regularization term > 0
alpha : np.ndarray(ns,)
dual variable
Expand Down Expand Up @@ -542,24 +553,29 @@ def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr):

.. math::
\gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)

s.t. \gamma 1 = a

\gamma^T 1= b

\gamma \geq 0
where :

Where :

- M is the (ns,nt) metric cost matrix
- :math:`\Omega` is the entropic regularization term
:math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- :math:`\Omega` is the entropic regularization term with :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- a and b are source and target weights (sum to 1)

Parameters
----------
a : np.ndarray(ns,),

a : np.ndarray(ns,)
source measure
b : np.ndarray(nt,),
b : np.ndarray(nt,)
target measure
M : np.ndarray(ns, nt),
M : np.ndarray(ns, nt)
cost matrix
reg : float number,
reg : float number
Regularization term > 0
batch_size : int number
size of the batch
Expand Down Expand Up @@ -633,25 +649,29 @@ def solve_dual_entropic(a, b, M, reg, batch_size, numItermax=10000, lr=1,

.. math::
\gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)

s.t. \gamma 1 = a

\gamma^T 1= b

\gamma \geq 0
where :

Where :

- M is the (ns,nt) metric cost matrix
- :math:`\Omega` is the entropic regularization term
:math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- :math:`\Omega` is the entropic regularization term with :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- a and b are source and target weights (sum to 1)

Parameters
----------

a : np.ndarray(ns,),
a : np.ndarray(ns,)
source measure
b : np.ndarray(nt,),
b : np.ndarray(nt,)
target measure
M : np.ndarray(ns, nt),
M : np.ndarray(ns, nt)
cost matrix
reg : float number,
reg : float number
Regularization term > 0
batch_size : int number
size of the batch
Expand Down