
Commit 4b2eab5

alexliang committed:
Merge branch 'dev/v0.7.0' of https://github.com/FedML-AI/FedML into dev/v0.7.0
2 parents: c5f208c + 7a4364e

File tree: 8 files changed, 63 additions (+) and 20 deletions (-)


android/fedmlsdk/src/main/java/ai/fedml/edge/FedEdgeImpl.java

Lines changed: 1 addition & 1 deletion

@@ -63,7 +63,7 @@ public void handleMessage(Message msg) {
         } else if (MSG_TRAIN_ACCURACY == msg.what) {
             Bundle bundle = msg.getData();
             if (onTrainProgressListener != null && bundle != null) {
-                onTrainProgressListener.onEpochLoss(msg.arg1, msg.arg2,
+                onTrainProgressListener.onEpochAccuracy(msg.arg1, msg.arg2,
                         bundle.getFloat(TRAIN_ACCURACY, 0));
             }
         }

python/examples/cross_silo/mqtt_s3_fedavg_defense_mnist_lr_example/config/foolsgold/fedml_config.yaml

Lines changed: 9 additions & 13 deletions

@@ -11,11 +11,11 @@ environment_args:
 data_args:
   dataset: "cifar10"
   data_cache_dir: ~/fedml_data
-  partition_method: "hetero"
+  partition_method: "homo"
   partition_alpha: 0.5
 
 model_args:
-  model: "lr"
+  model: "resnet56"
   model_file_cache_folder: "./model_file_cache" # will be filled by the server automatically
   global_model_file_path: "./model_file_cache/global_model.pt"
 
@@ -24,8 +24,8 @@ train_args:
   # for CLI running, this can be None; in MLOps deployment, `client_id_list` will be replaced with real-time selected devices
   client_id_list:
   # for FoolsGold Defense, if use_memory is true, then client_num_in_total should be equal to client_number_per_round
-  client_num_in_total: 1000
-  client_num_per_round: 2
+  client_num_in_total: 8
+  client_num_per_round: 8
   comm_round: 10
   epochs: 1
   batch_size: 10
@@ -37,16 +37,13 @@ validation_args:
   frequency_of_the_test: 1
 
 device_args:
-  worker_num: 2
-  using_gpu: false
-  gpu_mapping_file: config/gpu_mapping.yaml
-  gpu_mapping_key: mapping_config3_11
+  worker_num: 8
+  using_gpu: true
+  gpu_mapping_file: config/foolsgold/gpu_mapping.yaml
+  gpu_mapping_key: mapping_default
 
 comm_args:
-  backend: "MQTT_S3"
-  mqtt_config_path:
-  s3_config_path:
-  grpc_ipconfig_path: ./config/grpc_ipconfig.csv
+  backend: "MPI"
 
 tracking_args:
   # the default log path is at ~/fedml-client/fedml/logs/ and ~/fedml-server/fedml/logs/
@@ -64,4 +61,3 @@ attack_args:
 defense_args:
   enable_defense: true
   defense_type: foolsgold
-  use_memory: true
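
The comments in this config tie several of the numbers together (client_num_in_total, client_num_per_round, worker_num). Below is a minimal sanity-check sketch, assuming PyYAML is installed and the file lives at the path used by the launcher scripts in this commit; the helper name and the checks it enforces are illustrative assumptions drawn from the config's own comments, not part of FedML.

import yaml

def check_foolsgold_config(path="config/foolsgold/fedml_config.yaml"):
    """Hypothetical helper: load the config and check the constraints its comments describe."""
    with open(path) as f:
        cfg = yaml.safe_load(f)
    train = cfg["train_args"]
    device = cfg["device_args"]
    defense = cfg["defense_args"]
    # "if use_memory is true, then client_num_in_total should be equal to client_num_per_round"
    if defense.get("use_memory", False):
        assert train["client_num_in_total"] == train["client_num_per_round"]
    # assumption: worker_num is expected to match client_num_per_round (both are 8 above)
    assert device["worker_num"] == train["client_num_per_round"]
    return cfg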
Lines changed: 3 additions & 0 deletions

@@ -0,0 +1,3 @@
+# this is used for 4 clients and 1 server training within a single machine which has 8 GPUs, but you hope to skip the GPU device ID.
+mapping_default:
+  host1: [3, 2, 2, 2] # assume we only have 4 GPUs
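
This three-line file is presumably the config/foolsgold/gpu_mapping.yaml referenced by gpu_mapping_file and gpu_mapping_key in the config above. On the usual reading of such mapping files, each list entry says how many processes are placed on that GPU index, so [3, 2, 2, 2] provides 3 + 2 + 2 + 2 = 9 process slots across 4 GPUs, enough for 1 server plus 8 clients. A small sketch of that expansion follows (an illustration, not FedML's actual loader).

import yaml

def expand_gpu_mapping(path="config/foolsgold/gpu_mapping.yaml", key="mapping_default"):
    """Illustrative only: expand a mapping like `host1: [3, 2, 2, 2]` into one GPU id per process,
    under the assumption that entry i counts the processes placed on GPU i."""
    with open(path) as f:
        mapping = yaml.safe_load(f)[key]
    process_gpu_ids = []
    for host, per_gpu_counts in mapping.items():
        for gpu_id, count in enumerate(per_gpu_counts):
            process_gpu_ids.extend([gpu_id] * count)
    return process_gpu_ids

print(expand_gpu_mapping())  # expected: [0, 0, 0, 1, 1, 2, 2, 3, 3]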
Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
 #!/usr/bin/env bash
 RANK=$1
 RUN_ID=$2
-python3 torch_client.py --cf config/krum/fedml_config.yaml --rank $RANK --role client --run_id $RUN_ID
+python3 torch_client.py --cf config/foolsgold/fedml_config.yaml --rank $RANK --role client --run_id $RUN_ID
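
Usage sketch (the script's own path is not shown in this extract): it takes a client rank and a run id as positional arguments, so matching client_num_per_round: 8 in the config above means launching it once per rank with a shared run id, each process reading config/foolsgold/fedml_config.yaml.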
Lines changed: 12 additions & 0 deletions

@@ -0,0 +1,12 @@
+#!/usr/bin/env bash
+
+WORKER_NUM=$1
+
+PROCESS_NUM=`expr $WORKER_NUM + 1`
+echo $PROCESS_NUM
+
+hostname > mpi_host_file
+
+mpirun -np $PROCESS_NUM \
+-hostfile mpi_host_file \
+python torch_mpi.py --cf config/foolsgold/fedml_config.yaml
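
Worth noting: because of `expr $WORKER_NUM + 1`, running this launcher with WORKER_NUM=8 (matching worker_num: 8 in the config) starts 9 MPI processes, presumably one server rank plus eight client ranks, which is exactly the 3 + 2 + 2 + 2 = 9 process slots that the mapping_default entry above provides on 4 GPUs.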
Lines changed: 31 additions & 0 deletions

@@ -0,0 +1,31 @@
+import logging
+
+import fedml
+from fedml import FedMLRunner
+from fedml.model.cv.resnet import resnet56
+
+
+def create_model():
+    # please download the pre-trained weight file from
+    # https://github.com/chenyaofo/pytorch-cifar-models/releases/download/resnet/cifar10_resnet44-2a3cabcb.pt
+    pre_trained_model_path = "./config/resnet56_on_cifar10.pth"
+    model = resnet56(10, pretrained=True, path=pre_trained_model_path)
+    logging.info("load pretrained model successfully")
+    return model
+
+
+if __name__ == "__main__":
+    args = fedml.init()
+
+    # init device
+    device = fedml.device.get_device(args)
+
+    # load data
+    dataset, output_dim = fedml.data.load(args)
+
+    # load model
+    model = create_model()
+
+    # start training
+    fedml_runner = FedMLRunner(args, device, dataset, model)
+    fedml_runner.run()
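
This new file is presumably the torch_mpi.py entry point invoked by the MPI launcher above: it builds a 10-class resnet56 from a local weight file (matching model: "resnet56" in the config) and hands args, device, dataset, and model to FedMLRunner, which drives the training run.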

python/fedml/core/security/defense/foolsgold_defense.py

Lines changed: 1 addition & 0 deletions

@@ -15,6 +15,7 @@
 
 class FoolsGoldDefense(BaseDefenseMethod):
     def __init__(self, config):
+        super().__init__(config)
         self.config = config
         self.memory = None
 
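The one-line change above forwards the constructor argument to BaseDefenseMethod. A minimal, self-contained illustration of why that call matters follows; the stand-in classes are not FedML's actual ones.

class BaseDefense:
    def __init__(self, config):
        self.config = config
        self.initialized = True  # base-class bookkeeping

class FoolsGoldLike(BaseDefense):
    def __init__(self, config):
        # any state the base class sets up in its own __init__ is silently
        # skipped unless the subclass forwards the call with super()
        super().__init__(config)
        self.memory = None

print(FoolsGoldLike({"defense_type": "foolsgold"}).initialized)  # True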
python/fedml/core/security/defense/three_sigma_defense.py

Lines changed: 5 additions & 5 deletions

@@ -162,9 +162,9 @@ def _get_importance_feature(self, raw_client_grad_list):
 
             # Get last key-value tuple
             (weight_name, importance_feature) = list(grads.items())[-2]
-            print(importance_feature)
+            # print(importance_feature)
             feature_len = np.array(
-                importance_feature.data.detach().numpy().shape
+                importance_feature.cpu().data.detach().numpy().shape
             ).prod()
             feature_vector = np.reshape(
                 importance_feature.cpu().data.detach().numpy(), feature_len
@@ -195,14 +195,14 @@ def fools_gold_score(feature_vec_list):
         alpha[alpha <= 0.0] = 1e-15
 
         # Rescale so that max value is alpha
-        print(np.max(alpha))
+        # print(np.max(alpha))
         alpha = alpha / np.max(alpha)
         alpha[(alpha == 1.0)] = 0.999999
 
         # Logit function
         alpha = np.log(alpha / (1 - alpha)) + 0.5
-        alpha[(np.isinf(alpha) + alpha > 1)] = 1
-        alpha[(alpha < 0)] = 0
+        # alpha[(np.isinf(alpha) + alpha > 1)] = 1
+        # alpha[(alpha < 0)] = 0
 
         print("alpha = {}".format(alpha))
 
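As a quick numeric walk-through of the rescale-and-logit step in fools_gold_score above (illustrative values only, not taken from a real run):

import numpy as np

alpha = np.array([0.2, 0.5, 1.0])          # illustrative per-client scores
alpha[alpha <= 0.0] = 1e-15                # guard against log(0) later
alpha = alpha / np.max(alpha)              # rescale so the maximum is 1
alpha[alpha == 1.0] = 0.999999             # keep the logit finite
alpha = np.log(alpha / (1 - alpha)) + 0.5  # logit function, shifted by 0.5
print(alpha)                               # approximately [-0.886, 0.5, 14.316]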