[FATE Federated Learning] How to obtain the output of a federated model for non-classification and non-regression tasks?

In general, a component's output data is obtained through the FATE pipeline API: get_component('name').get_output_data().
However, in the current 1.x releases of FATE, this only works when the output is in the classification or regression format.
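To see why this matters, here is an illustrative sketch (not FATE's actual implementation) of the classification-style table that get_output_data() returns; the exact column names may vary across 1.x versions:

```python
import pandas as pd

# Rough approximation of a FATE 1.x classification prediction table.
df = pd.DataFrame({
    "id": [0, 1, 2],
    "label": [1, 0, 1],
    "predict_result": [1, 0, 0],
    "predict_score": [0.91, 0.12, 0.47],
    "predict_detail": [
        '{"0": 0.09, "1": 0.91}',
        '{"0": 0.88, "1": 0.12}',
        '{"0": 0.53, "1": 0.47}',
    ],
})

print(df.shape)  # one scalar score per sample
```

A per-sample image tensor or a token embedding of shape [n, 768] has no natural place in this one-score-per-row schema, which is why such outputs cannot be fetched through this API.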

If the model output is an image, text, a token embedding, etc., it cannot be obtained this way.

After discussing this with the FATE community, they confirmed that this method cannot be used for such outputs. The workaround they suggested is to save the output directly inside the predict function of a custom trainer, instead of retrieving it through the API above.

For now, this workaround is the only option.

How to customize a trainer is covered in the official documentation.
The original predict code in the trainer is shown below; just save the model's predictions directly at the marked spot:

    def _predict(self, dataset: Dataset):

        pred_result = []

        # switch to eval mode
        dataset.eval()
        self.model.eval()

        labels = []
        # save the prediction directly here
        # (`images` comes from the dataset; the batching code is omitted in the original snippet)
        pred = self.model(images)
        torch.save(pred, './xxxx')  # note: torch.save takes the object first, then the path

        length = len(dataset.get_sample_ids())
        ret_rs = torch.rand(length, 1)
        ret_label = torch.rand(length, 1).int()

        return dataset.get_sample_ids(), ret_rs, ret_label

    def predict(self, dataset: Dataset):

        ids, ret_rs, ret_label = self._predict(dataset)

        if self.fed_mode:
            return self.format_predict_result(
                ids, ret_rs, ret_label, task_type=self.task_type)
        else:
            return ret_rs, ret_label

In the code above, I return some fake data, because if the returned data does not match the expected format, FATE Board reports an error and the job cannot proceed to the next step. The fake return values are only there to satisfy the format check; they are never used.
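Once the patched trainer has saved the raw predictions, they can be loaded back after the job finishes. A minimal sketch of the save/load round trip (the path "./preds.pt" and the embedding shape are made up for illustration; use whatever path your trainer actually writes to):

```python
import torch

# Stand-in for the model output saved inside _predict,
# e.g. token embeddings of shape [n_samples, hidden_dim].
pred = torch.rand(4, 768)
torch.save(pred, "./preds.pt")   # object first, then path

# Later, outside the FATE job, load the raw predictions back:
restored = torch.load("./preds.pt")
print(restored.shape)
```

This sidesteps Fateboard entirely: the tensor goes straight to disk on the machine running the trainer, so its shape and dtype are unconstrained.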

Origin blog.csdn.net/Yonggie/article/details/131476482