@@ -1096,10 +1096,39 @@ class PyLoweringContext {
     return results;
   }
 
+  // Returns a mapping from HLO parameter IDs to their corresponding
+  // device-backed Tensors. This version only returns parameters that were
+  // explicitly allocated on device data, accessible via GetTensorParameterId().
+  // Unlike GetParameterIdTensorMapping(), it avoids transferring data from
+  // device to host, making it more efficient, especially for SPMD scenarios
+  // where data may be sharded.
+  std::unordered_map<int64_t, at::Tensor> GetDeviceParameterIdTensorMapping() {
+    // Find parameters in the lowering
+    const std::vector<torch::lazy::BackendDataPtr>& device_data =
+        lowering_ctx.GetParametersData();
+
+    // Create a mapping from parameter id to the tensor data
+    std::unordered_map<int64_t, at::Tensor> param_to_tensor;
+    param_to_tensor.reserve(device_data.size());
+
+    for (const auto& data : device_data) {
+      std::optional<int64_t> param_id = lowering_ctx.GetParameterId(data);
+      XLA_CHECK(param_id.has_value())
+          << "Parameter ID must exist for device data";
+
+      at::Tensor tensor =
+          bridge::AtenFromXlaTensor(torch_xla::XLATensor::Create(data));
+      param_to_tensor.emplace(param_id.value(), std::move(tensor));
+    }
+    return param_to_tensor;
+  }
+
   // Get the parameter identifier of a given tensor. If the tensor is not a
   // parameter this will always return -1. This is useful in conjunction with
   // GetParameterIdTensorMapping to identify which values can be baked into
-  // the graph and which values must remain parameters.
+  // the graph and which values must remain parameters. Note that every
+  // tensor returned by GetDeviceParameterIdTensorMapping() is a parameter
+  // with a valid parameter id.
   int64_t GetTensorParameterId(at::Tensor tensor) {
     // Convert tensor into the backing lazy node
     XLATensorPtr xtensor = bridge::GetXlaTensor(tensor);
@@ -1201,6 +1230,8 @@ void BuildLoweringContextSubmodule(py::module* m) {
       .def("hlo_json", &PyLoweringContext::GetHloJsonText)
       .def("parameter_id_tensor_mapping",
            &PyLoweringContext::GetParameterIdTensorMapping)
+      .def("device_parameter_id_tensor_mapping",
+           &PyLoweringContext::GetDeviceParameterIdTensorMapping)
       .def("tensor_parameter_id", &PyLoweringContext::GetTensorParameterId)
       .def("set_name_string", &PyLoweringContext::SetNameString)
       .def("get_name_string", &PyLoweringContext::GetNameString);
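
For illustration, a rough Python sketch of how the new binding might be exercised. Only the method names registered above (device_parameter_id_tensor_mapping, tensor_parameter_id, parameter_id_tensor_mapping) come from this diff; the module path torch_xla._XLAC.lowering.LoweringContext, the no-argument constructor, and the build() entry point are assumptions about the surrounding bindings, not part of this change.

# Hedged usage sketch (assumed entry points marked below), not from this diff.
import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
a = torch.randn(4, 4, device=device)
b = torch.randn(4, 4, device=device)
out = a @ b

ctx = torch_xla._XLAC.lowering.LoweringContext()  # assumed constructor/path
ctx.build([out])  # assumed builder entry point for the traced output

# New device-side mapping: wraps the existing device data handles without
# triggering a device-to-host transfer.
device_params = ctx.device_parameter_id_tensor_mapping()

# Cross-check a specific input: tensor_parameter_id returns -1 for
# non-parameters, otherwise an id that should appear in the mapping.
param_id = ctx.tensor_parameter_id(a)
if param_id != -1:
    assert param_id in device_params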