Changes to LP solver: #25
toto6 commented Aug 29, 2017
- Allow modifying the maximum number of iterations
- Display an error message in the Python console if the solver encounters an issue
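
A minimal sketch of how a solver status could be surfaced in the Python console. The status codes and the `check_result` helper are assumptions made for illustration, not the codes actually returned by the underlying solver:

```python
import warnings

# Hypothetical status codes; the real solver's codes may differ.
OPTIMAL, INFEASIBLE, UNBOUNDED, MAX_ITER_REACHED = 0, 1, 2, 3

_MESSAGES = {
    INFEASIBLE: "Problem infeasible. Check the input distributions a and b.",
    UNBOUNDED: "Problem unbounded.",
    MAX_ITER_REACHED: "numItermax reached before optimality. "
                      "Try to increase numItermax.",
}


def check_result(result_code):
    """Emit a warning if the solver did not finish cleanly.

    Returns True only when the (hypothetical) code signals optimality.
    """
    message = _MESSAGES.get(result_code)
    if message is not None:
        warnings.warn(message, UserWarning)
    return result_code == OPTIMAL
```

Using `warnings.warn` rather than `print` lets callers filter or escalate the message with the standard `warnings` machinery.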
Thank you for your work, this will definitely make POT easier to use for large matrices.
But a few changes are missing. Since the recent merge of PR #22, OTDA is deprecated, because we wanted classes that are more compatible with scikit-learn.
You should also add this parameter to the `__init__`.
ot/da.py Outdated
     self.computed = False

    -def fit(self, xs, xt, ws=None, wt=None, norm=None):
    +def fit(self, xs, xt, ws=None, wt=None, norm=None, numItermax=10000):
The default value for numItermax should be 100000, because it handles larger matrices out of the box. Same for emd and emd2.
ot/lp/__init__.py Outdated
    -def emd(a, b, M):
    +def emd(a, b, M, numItermax=10000):
Set the default value to 100000.
ot/lp/__init__.py Outdated
     return emd_c(a, b, M, numItermax)

    -def emd2(a, b, M,processes=multiprocessing.cpu_count()):
    +def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=10000):
same here
Hi @aje, following @rflamary's remark about the OTDA object, you can update the class as follows:

    class EMDTransport(BaseTransport):
        """Domain Adaptation OT method based on Earth Mover's Distance

        Parameters
        ----------
        mapping : string, optional (default="barycentric")
            The kind of mapping to apply to transport samples from a domain
            into another one. If "barycentric", only the samples used to
            estimate the coupling can be transported from a domain to
            another one.
        metric : string, optional (default="sqeuclidean")
            The ground metric for the Wasserstein problem
        distribution : string, optional (default="uniform")
            The kind of distribution estimation to employ
        verbose : int, optional (default=0)
            Controls the verbosity of the optimization algorithm
        log : int, optional (default=0)
            Controls the logs of the optimization algorithm
        limit_max : float, optional (default=10)
            Controls the semi supervised mode. Transport between labeled
            source and target samples of different classes will exhibit an
            infinite cost (10 times the maximum value of the cost matrix)
        max_iter : int, float, optional (default=10000)
            The maximum number of iterations before stopping the
            optimization algorithm if it has not converged

        Attributes
        ----------
        coupling_ : array-like, shape (n_source_samples, n_target_samples)
            The optimal coupling

        References
        ----------
        .. [1] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
            "Optimal Transport for Domain Adaptation," in IEEE Transactions
            on Pattern Analysis and Machine Intelligence, vol.PP, no.99,
            pp.1-1
        """

        def __init__(self, metric="sqeuclidean",
                     distribution_estimation=distribution_estimation_uniform,
                     out_of_sample_map='ferradans', limit_max=10,
                     max_iter=10000):
            self.metric = metric
            self.limit_max = limit_max
            self.distribution_estimation = distribution_estimation
            self.out_of_sample_map = out_of_sample_map
            self.max_iter = max_iter

        def fit(self, Xs, ys=None, Xt=None, yt=None):
            """Build a coupling matrix from source and target sets of samples
            (Xs, ys) and (Xt, yt)

            Parameters
            ----------
            Xs : array-like, shape (n_source_samples, n_features)
                The training input samples.
            ys : array-like, shape (n_source_samples,)
                The class labels
            Xt : array-like, shape (n_target_samples, n_features)
                The training input samples.
            yt : array-like, shape (n_labeled_target_samples,)
                The class labels

            Returns
            -------
            self : object
                Returns self.
            """
            super(EMDTransport, self).fit(Xs, ys, Xt, yt)

            # coupling estimation
            self.coupling_ = emd(
                a=self.mu_s, b=self.mu_t, M=self.cost_,
                numItermax=self.max_iter
            )

            return self
OK with @Slasnista, just use max_iter=100000 to allow for larger matrices out of the box. I agree it is a better name, but I want to keep consistency with the functional solvers for the moment. I think we should clean up the function parameter names in 0.5 with deprecation warnings anyway.
OK, sorry @aje if I wasn't clear, but I want to keep numItermax for the emd and emd2 functions, as it is used in sinkhorn and sinkhorn2, for the moment. I'm OK with max_iter in the classes though, since we will switch to it in a future release.
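
For illustration, the planned rename could be handled with a small shim that accepts both names during the transition. This is only a sketch: the helper name `resolve_max_iter` and the exact warning text are assumptions, not part of the PR.

```python
import warnings


def resolve_max_iter(max_iter=None, numItermax=None, default=100000):
    """Hypothetical helper: pick the iteration cap from old and new names.

    `numItermax` is treated as the deprecated spelling; an explicit
    `max_iter` always wins when both are given.
    """
    if numItermax is not None:
        warnings.warn(
            "numItermax is deprecated, use max_iter instead",
            DeprecationWarning,
            stacklevel=2,
        )
        if max_iter is None:
            max_iter = numItermax
    return default if max_iter is None else max_iter
```

A functional solver could call this helper at the top of its body, so existing calls with `numItermax=` keep working while users are nudged toward `max_iter`.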
The test is failing due to a small naming error. I think when this is corrected we can do a quick merge.
Thank you @aje
A last comment about the normalization: it is a nice touch, but it should be handled slightly differently IMO.
ot/da.py Outdated
     return transp_Xt

    +def normalizeCost_(self, norm):
I like the normalization part but think this function should be independent from the class since it is of potential interest to any user.
I would define this function in utils.py with proper documentation. Then you can refer to this function in the doc of the class.

    def cost_normalization(C, norm=None):
        # norm=None means no normalization: return C unchanged
        if norm is None:
            return C
        if norm.lower() == "median":
            C /= float(np.median(C))
        elif norm.lower() == "max":
            C /= float(np.max(C))
        elif norm.lower() == "log":
            C = np.log(1 + C)
        elif norm.lower() == "loglog":
            C = np.log(1 + np.log(1 + C))
        return C
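
Note that the `median` and `max` branches of the suggested function divide in place, so the caller's matrix is modified. A minimal sketch of a non-mutating variant follows; the name `cost_normalization_copy` and the `ValueError` on unknown names are assumptions for illustration, not what the PR implements.

```python
import numpy as np


def cost_normalization_copy(C, norm=None):
    """Hypothetical variant that never modifies the input matrix."""
    C = np.array(C, dtype=float)  # np.array copies by default
    if norm is None:
        return C
    norm = norm.lower()
    if norm == "median":
        return C / np.median(C)
    if norm == "max":
        return C / np.max(C)
    if norm == "log":
        return np.log1p(C)  # log(1 + C)
    if norm == "loglog":
        return np.log1p(np.log1p(C))  # log(1 + log(1 + C))
    raise ValueError("unknown normalization: %r" % norm)


# Small usage example on a 2 x 2 cost matrix
C = np.array([[0.0, 1.0], [2.0, 4.0]])
C_max = cost_normalization_copy(C, "max")  # entries scaled into [0, 1]
```

Raising on an unknown name is a design choice; the suggested function instead returns `C` unchanged in that case.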
ot/da.py Outdated
     # pairwise distance
     self.cost_ = dist(Xs, Xt, metric=self.metric)
    +self.normalizeCost_(self.norm)
Use an external function, as discussed in the comment on `normalizeCost_`.
ot/da.py Outdated
     self.M = dist(xs, xt, metric=self.metric)
    -self.normalizeM(norm)
    +self.M = cost_normalization(self.M, norm)
- @aje: can you handle the case where one does not want to normalise the cost matrix?
- Should we do that with an `if` in the `.fit` or in the `cost_normalization` function?
If you give it None, which is the default, the function will already return C.
Seems OK to me.
I could add the check `if self.norm is not None:` but this also seems OK to me.
OK, could you add the parameter [...] or another default value if one normalization appears better than [...]
@aje I agree with @Slasnista, the norm parameter should be defined in the [...]

@Slasnista Let's keep it at None for the moment, since there is no rule about the best one and it sometimes has weird effects with regularized OT.
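
To make the outcome of this thread concrete, here is a sketch of a scikit-learn style estimator that stores `norm` at construction time with a default of `None` and only normalizes inside `fit` when asked. The class `TinyTransport` and its reduced set of normalizations are hypothetical and exist only for illustration.

```python
import numpy as np


class TinyTransport:
    """Hypothetical estimator showing where `norm` is stored and used."""

    def __init__(self, norm=None):
        # Hyperparameters are only stored here, as scikit-learn expects.
        self.norm = norm

    def fit(self, Xs, Xt):
        # Pairwise squared Euclidean cost between source and target samples
        diff = Xs[:, None, :] - Xt[None, :, :]
        self.cost_ = (diff ** 2).sum(axis=-1)

        # norm=None (the default) leaves the cost matrix untouched
        if self.norm is not None:
            if self.norm == "max":
                self.cost_ = self.cost_ / self.cost_.max()
            elif self.norm == "median":
                self.cost_ = self.cost_ / np.median(self.cost_)
            else:
                raise ValueError("unknown normalization: %r" % self.norm)
        return self
```

With this layout, `TinyTransport().fit(Xs, Xt)` keeps the raw costs, while `TinyTransport(norm="max").fit(Xs, Xt)` rescales them so the largest entry is 1.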
I'm OK with these changes.
@Slasnista could you do a second quick code review before we merge the PR?
OK from my side, everything seems good.