From 0630191ddd54d5157f9981449d83b341608cd40d Mon Sep 17 00:00:00 2001 From: Yash Agarwal Date: Fri, 17 Jan 2025 12:26:34 -0800 Subject: [PATCH] fixed various type checking and mismatched type errors --- src/algos/MetaL2C.py | 2 +- src/algos/attack_add_noise.py | 2 +- src/algos/attack_bad_weights.py | 2 +- src/algos/attack_sign_flip.py | 2 +- src/algos/base_class.py | 2 +- src/algos/topologies/base.py | 2 +- src/algos/topologies/collections.py | 2 +- src/configs/algo_config.py | 2 +- src/configs/algo_config_test.py | 2 +- src/configs/malicious_config.py | 2 +- src/configs/non_iid_clients.py | 32 +++--- src/configs/sys_config.py | 4 +- src/configs/sys_config_test.py | 6 +- src/main_grpc.py | 2 +- src/utils/communication/grpc/main.py | 124 ++++++++++++------------ src/utils/{types.py => custom_types.py} | 0 src/utils/dropout_utils.py | 2 +- src/utils/log_utils.py | 2 +- src/utils/model_utils.py | 2 +- 19 files changed, 97 insertions(+), 97 deletions(-) rename src/utils/{types.py => custom_types.py} (100%) diff --git a/src/algos/MetaL2C.py b/src/algos/MetaL2C.py index beb267de..dbd133fe 100644 --- a/src/algos/MetaL2C.py +++ b/src/algos/MetaL2C.py @@ -398,7 +398,7 @@ def run_protocol(self) -> None: total_rounds = self.config["rounds"] stats = [] - avg_alpha = None + avg_alpha : Dict[str, torch.Tensor] = {} for cur_round in range(start_round, total_rounds): self.round = cur_round self.log_utils.log_console(f"Starting round {cur_round}") diff --git a/src/algos/attack_add_noise.py b/src/algos/attack_add_noise.py index 2835269c..01e5c323 100644 --- a/src/algos/attack_add_noise.py +++ b/src/algos/attack_add_noise.py @@ -19,7 +19,7 @@ from collections import OrderedDict from typing import Dict from torch import Tensor -from utils.types import ConfigType +from utils.custom_types import ConfigType class AddNoiseAttack: diff --git a/src/algos/attack_bad_weights.py b/src/algos/attack_bad_weights.py index 83e4dcb1..74f870aa 100644 --- a/src/algos/attack_bad_weights.py +++ b/src/algos/attack_bad_weights.py @@ -17,7 +17,7 @@ from collections import OrderedDict from typing import Dict from torch import Tensor -from utils.types import ConfigType +from utils.custom_types import ConfigType class BadWeightsAttack: diff --git a/src/algos/attack_sign_flip.py b/src/algos/attack_sign_flip.py index bb9ef0d4..497d27c0 100644 --- a/src/algos/attack_sign_flip.py +++ b/src/algos/attack_sign_flip.py @@ -17,7 +17,7 @@ from collections import OrderedDict from typing import Dict from torch import Tensor -from utils.types import ConfigType +from utils.custom_types import ConfigType class SignFlipAttack: diff --git a/src/algos/base_class.py b/src/algos/base_class.py index e967ade4..facff5d1 100644 --- a/src/algos/base_class.py +++ b/src/algos/base_class.py @@ -36,7 +36,7 @@ get_dset_balanced_communities, get_dset_communities, ) -from utils.types import ConfigType +from utils.custom_types import ConfigType from utils.dropout_utils import NodeDropout import torchvision.transforms as T # type: ignore diff --git a/src/algos/topologies/base.py b/src/algos/topologies/base.py index fdc86aef..1c4856e9 100644 --- a/src/algos/topologies/base.py +++ b/src/algos/topologies/base.py @@ -4,7 +4,7 @@ import numpy as np import networkx as nx -from utils.types import ConfigType +from utils.custom_types import ConfigType class BaseTopology(ABC): """ diff --git a/src/algos/topologies/collections.py b/src/algos/topologies/collections.py index 632f42b4..7d851f1e 100644 --- a/src/algos/topologies/collections.py +++ b/src/algos/topologies/collections.py @@ -1,5 +1,5 @@ from algos.topologies.base import BaseTopology -from utils.types import ConfigType +from utils.custom_types import ConfigType from math import ceil import networkx as nx diff --git a/src/configs/algo_config.py b/src/configs/algo_config.py index a2242c62..3c3b6df9 100644 --- a/src/configs/algo_config.py +++ b/src/configs/algo_config.py @@ -1,7 +1,7 @@ from typing import Dict, List from .malicious_config import malicious_config_list import random -from utils.types import ConfigType +from utils.custom_types import ConfigType def get_malicious_types(malicious_config_list: List[ConfigType]) -> Dict[str, str]: diff --git a/src/configs/algo_config_test.py b/src/configs/algo_config_test.py index bce2b882..d6b10246 100644 --- a/src/configs/algo_config_test.py +++ b/src/configs/algo_config_test.py @@ -1,4 +1,4 @@ -from utils.types import ConfigType +from utils.custom_types import ConfigType fedstatic: ConfigType = { # Collaboration setup diff --git a/src/configs/malicious_config.py b/src/configs/malicious_config.py index 15025d42..c469c01e 100644 --- a/src/configs/malicious_config.py +++ b/src/configs/malicious_config.py @@ -1,5 +1,5 @@ # Malicious Configuration -from utils.types import ConfigType +from utils.custom_types import ConfigType from typing import Dict # Weight Update Attacks diff --git a/src/configs/non_iid_clients.py b/src/configs/non_iid_clients.py index 126fbe02..3e172101 100644 --- a/src/configs/non_iid_clients.py +++ b/src/configs/non_iid_clients.py @@ -101,25 +101,25 @@ def get_domain_support( return support -CIFAR10_ROT_DMN: List[str] = ["r0", "r90", "r180", "r270"] +CIFAR10_ROT_DMN: List[str | int] = ["r0", "r90", "r180", "r270"] def get_cifar10_rot_support( - num_clients: int, domains: List[str] = CIFAR10_ROT_DMN + num_clients: int, domains: List[str | int] = CIFAR10_ROT_DMN ) -> Dict[str, str]: return get_domain_support(num_clients, "cifar10", domains) -DOMAINNET_DMN: List[str] = ["real", "sketch", "clipart"] +DOMAINNET_DMN: List[str | int] = ["real", "sketch", "clipart"] def get_domainnet_support( - num_clients: int, domains: List[str] = DOMAINNET_DMN + num_clients: int, domains: List[str | int] = DOMAINNET_DMN ) -> Dict[str, str]: return get_domain_support(num_clients, "domainnet", domains) -DOMAINNET_DMN_FULL: List[str] = [ +DOMAINNET_DMN_FULL: List[str | int] = [ "real", "sketch", "clipart", @@ -130,42 +130,42 @@ def get_domainnet_support( def get_domainnet_support_full( - num_clients: int, domains: List[str] = DOMAINNET_DMN_FULL + num_clients: int, domains: List[str | int] = DOMAINNET_DMN_FULL ) -> Dict[str, str]: return get_domain_support(num_clients, "domainnet", domains) -DOMAINNET_DMN_V2: List[str] = ["infograph", "quickdraw", "painting"] +DOMAINNET_DMN_V2: List[str | int] = ["infograph", "quickdraw", "painting"] def get_domainnet_support_v2( - num_clients: int, domains: List[str] = DOMAINNET_DMN_V2 + num_clients: int, domains: List[str | int] = DOMAINNET_DMN_V2 ) -> Dict[str, str]: return get_domain_support(num_clients, "domainnet", domains) -IWILDCAM_DMN: List[int] = list(range(1, 5)) # 245 possible +IWILDCAM_DMN: List[str | int] = list(range(1, 5)) # 245 possible def get_iwildcam_support( - num_clients: int, domains: List[int] = IWILDCAM_DMN + num_clients: int, domains: List[str | int] = IWILDCAM_DMN ) -> Dict[str, str]: return get_domain_support(num_clients, "wilds_iwildcam", domains) # 2 classes # 3 domains: 0:116'959, 3:132'052, 4:5'3425 in training set -CAMELYON17_DMN: List[int] = [0, 3, 4] # + 1, 2 in test set +CAMELYON17_DMN: List[str | int] = [0, 3, 4] # + 1, 2 in test set def get_camelyon17_support( - num_clients: int, domains: List[int] = CAMELYON17_DMN + num_clients: int, domains: List[str | int] = CAMELYON17_DMN ) -> Dict[str, str]: return get_domain_support(num_clients, "wilds_camelyon17", domains) # Issue every of the 1139 classes has only 1 sample per domain => how to create in domain test set ? -RXRX1_DMN = [ +RXRX1_DMN: List[str | int] = [ 0, 1, 2, @@ -175,15 +175,15 @@ def get_camelyon17_support( ] # in train set: 0-6, 11-26, 35-41, 46-48, in test set: 7-10, 27-34, 42-45, 49-50 -def get_rxrx1_support(num_clients, domains=RXRX1_DMN): +def get_rxrx1_support(num_clients: int, domains: List[str|int] =RXRX1_DMN): return get_domain_support(num_clients, "wilds_rxrx1", domains) # 0: 17'809, 1: 34'816, 2: 1'582, 3: 20'973, 4:1'641, 5: 42 -FMOW_DMN = [0, 1, 3] +FMOW_DMN : List[str|int] = [0, 1, 3] -def get_fmow_support(num_clients, domains=FMOW_DMN): +def get_fmow_support(num_clients: int, domains : List[str|int] =FMOW_DMN): return get_domain_support(num_clients, "wilds_fmow", domains) diff --git a/src/configs/sys_config.py b/src/configs/sys_config.py index 738f4c93..24f79f12 100644 --- a/src/configs/sys_config.py +++ b/src/configs/sys_config.py @@ -3,7 +3,7 @@ # is to simulate different real-world scenarios without changing the algorithm configuration. from typing import Dict, List, Literal, Optional import random -from utils.types import ConfigType +from utils.custom_types import ConfigType # from utils.config_utils import get_sliding_window_support, get_device_ids from .algo_config import ( @@ -334,7 +334,7 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE): # for swift or fedavgpush, just modify the algo_configs list # for swift, synchronous should preferable be False -gpu_ids = [2, 3, 5, 6] +gpu_ids : List[int | Literal['cpu']] = [2, 3, 5, 6] grpc_system_config: ConfigType = { "exp_id": "static", diff --git a/src/configs/sys_config_test.py b/src/configs/sys_config_test.py index 6e692146..5b497b73 100644 --- a/src/configs/sys_config_test.py +++ b/src/configs/sys_config_test.py @@ -1,9 +1,9 @@ from typing import Dict, List, Literal, Optional import random -from utils.types import ConfigType +from utils.custom_types import ConfigType from .algo_config_test import ( - traditional_fl, + #traditional_fl, fedstatic ) @@ -99,7 +99,7 @@ def get_algo_configs( for i in range(1, num_users + 1): dropout_dicts[f"node_{i}"] = dropout_dict -gpu_ids = [2, 3, 5, 6] +gpu_ids : List[int | Literal['cpu']] = [2, 3, 5, 6] grpc_system_config: ConfigType = { "exp_id": "static", diff --git a/src/main_grpc.py b/src/main_grpc.py index e3678431..01de1621 100644 --- a/src/main_grpc.py +++ b/src/main_grpc.py @@ -35,7 +35,7 @@ # Command for opening each process command_list: List[str] = ["python", "main.py", "-host", args.host] if args.dev == True: - command_list: List[str] = ["python", "main.py", "-b", "./configs/algo_config_test.py", "-s", "./configs/sys_config_test.py", "-host", args.host] + command_list = ["python", "main.py", "-b", "./configs/algo_config_test.py", "-s", "./configs/sys_config_test.py", "-host", args.host] # Start process for each user for i in range(args.n): diff --git a/src/utils/communication/grpc/main.py b/src/utils/communication/grpc/main.py index b850f0e3..b82280c2 100644 --- a/src/utils/communication/grpc/main.py +++ b/src/utils/communication/grpc/main.py @@ -9,7 +9,7 @@ import functools from typing import Any, Callable, Dict, List, OrderedDict, Union, TYPE_CHECKING, Set from urllib.parse import unquote -import grpc # type: ignore +import grpc # from torch import Tensor from utils.communication.grpc.grpc_utils import deserialize_model, serialize_model, serialize_message, deserialize_message import os @@ -38,15 +38,15 @@ # 4. Probably a good idea to move the Servicer class to a separate file # 5. Not needed for benchmarking but for the system to be robust, we need to implement timeouts and fault tolerance # 6. Peer_ids should be indexed by a unique identifier -# 7. Try to get rid of type: ignore as much as possible +# 7. Try to get rid of type:ignore as much as possible def is_port_available(port: int) -> bool: """ Check if a port is available for use. """ - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: # type: ignore - return s.connect_ex(("localhost", port)) != 0 # type: ignore + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: # + return s.connect_ex(("localhost", port)) != 0 # def get_port(rank: int, num_users: int) -> int: @@ -122,30 +122,30 @@ def register_self(self, obj: "BaseNode"): self.base_node = obj @update_communication_cost - def send_model(self, request: comm_pb2.Model, context) -> comm_pb2.Empty: # type: ignore + def send_model(self, request: comm_pb2.Model, context) -> comm_pb2.Empty: # deserialized_message = deserialize_message(request.buffer) - self.received_data.put(deserialized_message) # type: ignore - return comm_pb2.Empty() # type: ignore + self.received_data.put(deserialized_message) # + return comm_pb2.Empty() # @update_communication_cost def get_rank(self, request: comm_pb2.Empty, context: grpc.ServicerContext) -> comm_pb2.Rank | None: try: with self.lock: - peer = context.peer() # type: ignore + peer = context.peer() # # parse the hostname from peer - peer_str = parse_peer_address(peer) # type: ignore + peer_str = parse_peer_address(peer) # rank = len(self.peer_ids) # TODO: index the peer_ids by a unique identifier self.peer_ids[rank] = {"rank": rank, "port": 0, "ip": peer_str} rank = self.peer_ids[rank].get("rank", -1) # Default to -1 if not found - return comm_pb2.Rank(rank=rank) # type: ignore + return comm_pb2.Rank(rank=rank) # except Exception as e: - context.abort(grpc.StatusCode.INTERNAL, f"Error in get_rank: {str(e)}") # type: ignore + context.abort(grpc.StatusCode.INTERNAL, f"Error in get_rank: {str(e)}") # @update_communication_cost def get_model(self, request: comm_pb2.Empty, context: grpc.ServicerContext) -> comm_pb2.Model | None: if not self.base_node: - context.abort(grpc.StatusCode.INTERNAL, "Base node not registered") # type: ignore + context.abort(grpc.StatusCode.INTERNAL, "Base node not registered") # raise Exception("Base node not registered") with self.lock: message_to_send = self.base_node.get_model_weights() @@ -154,13 +154,13 @@ def get_model(self, request: comm_pb2.Empty, context: grpc.ServicerContext) -> c if "model" in message_to_send: del message_to_send["model"] message_to_send[EMPTY_MODEL_TAG] = True - model = comm_pb2.Model(buffer=serialize_message(message_to_send)) # type: ignore + model = comm_pb2.Model(buffer=serialize_message(message_to_send)) # return model @update_communication_cost def get_current_round(self, request: comm_pb2.Empty, context: grpc.ServicerContext) -> comm_pb2.Round | None: if not self.base_node: - context.abort(grpc.StatusCode.INTERNAL, "Base node not registered") # type: ignore + context.abort(grpc.StatusCode.INTERNAL, "Base node not registered") # raise Exception("Base node not registered") with self.lock: round = comm_pb2.Round(round=self.base_node.get_local_rounds()) @@ -172,33 +172,33 @@ def update_port( with self.lock: # FIXME: This is a security vulnerability because # any node can update the ip and port of any other node - self.peer_ids[request.rank.rank]["ip"] = request.ip # type: ignore - self.peer_ids[request.rank.rank]["port"] = request.port.port # type: ignore - return comm_pb2.Empty() # type: ignore + self.peer_ids[request.rank.rank]["ip"] = request.ip # + self.peer_ids[request.rank.rank]["port"] = request.port.port # + return comm_pb2.Empty() # - def send_peer_ids(self, request: comm_pb2.PeerIds, context) -> comm_pb2.Empty: # type: ignore + def send_peer_ids(self, request: comm_pb2.PeerIds, context) -> comm_pb2.Empty: # """ Used by the super node to update all peers with the peer_ids after achieving quorum. """ - peer_ids: comm_pb2.PeerIds = request.peer_ids # type: ignore - for rank in peer_ids: # type: ignore - peer_id_proto = peer_ids[rank] # type: ignore + peer_ids: comm_pb2.PeerIds = request.peer_ids # + for rank in peer_ids: # + peer_id_proto = peer_ids[rank] # peer_id_dict: Dict[str, Union[int, str]] = { - "rank": peer_id_proto.rank.rank, # type: ignore - "port": peer_id_proto.port.port, # type: ignore - "ip": peer_id_proto.ip, # type: ignore + "rank": peer_id_proto.rank.rank, # + "port": peer_id_proto.port.port, # + "ip": peer_id_proto.ip, # } self.peer_ids[rank] = peer_id_dict return comm_pb2.Empty() - def send_quorum(self, request, context) -> comm_pb2.Empty: # type: ignore - self.quorum.put(request.quorum) # type: ignore - return comm_pb2.Empty() # type: ignore + def send_quorum(self, request, context) -> comm_pb2.Empty: # + self.quorum.put(request.quorum) # + return comm_pb2.Empty() # - def send_finished(self, request, context) -> comm_pb2.Empty: # type: ignore - self.finished.put(request.rank) # type: ignore - return comm_pb2.Empty() # type: ignore + def send_finished(self, request, context) -> comm_pb2.Empty: # + self.finished.put(request.rank) # + return comm_pb2.Empty() # class GRPCCommunication(CommunicationInterface): def __init__(self, config: Dict[str, Dict[str, Any]]): @@ -211,7 +211,7 @@ def __init__(self, config: Dict[str, Dict[str, Any]]): # 2. Once a threshold number of peers have registered, the super node sets quorum to True # 3. The super node broadcasts the peer_ids to all peers # 4. The nodes will execute rest of the protocol in the same way as before - self.num_users: int = int(config["num_users"]) # type: ignore + self.num_users: int = int(config["num_users"]) # self.rank: int | None = config["comm"]["rank"] # TODO: Get rid of peer_ids now that we are passing [comm][host] self.super_node_host: str = config["comm"]["peer_ids"][0] @@ -237,7 +237,7 @@ def register_self(self, obj: "BaseNode"): self.servicer.register_self(obj) def recv_with_retries(self, host: str, callback: Callable[[comm_pb2_grpc.CommunicationServerStub], Any]) -> Any: - with grpc.insecure_channel(host, options=[ # type: ignore + with grpc.insecure_channel(host, options=[ # ('grpc.max_send_message_length', MAX_MESSAGE_LENGTH), ('grpc.max_receive_message_length', MAX_MESSAGE_LENGTH), ]) as channel: @@ -263,30 +263,30 @@ def register(self): 3. The node updates its port and sends the updated peer_ids to the super node """ def callback_fn(stub: comm_pb2_grpc.CommunicationServerStub) -> int: - rank_data = stub.get_rank(comm_pb2.Empty()) # type: ignore + rank_data = stub.get_rank(comm_pb2.Empty()) # with self.servicer.lock: self.servicer.communication_cost_received += rank_data.ByteSize() - return rank_data.rank # type: ignore + return rank_data.rank # self.rank = self.recv_with_retries(self.super_node_host, callback_fn) - self.port = get_port(self.rank, self.num_users) # type: ignore because we are setting it in the register method - rank = comm_pb2.Rank(rank=self.rank) # type: ignore + self.port = get_port(self.rank, self.num_users) # because we are setting it in the register method + rank = comm_pb2.Rank(rank=self.rank) # port = comm_pb2.Port(port=self.port) peer_id = comm_pb2.PeerId(rank=rank, port=port, ip=self.host) - with grpc.insecure_channel(self.super_node_host) as channel: # type: ignore + with grpc.insecure_channel(self.super_node_host) as channel: # stub = comm_pb2_grpc.CommunicationServerStub(channel) - stub.update_port(peer_id) # type: ignore + stub.update_port(peer_id) # def start_listener(self): - self.listener: grpc.Server = grpc.server( # type: ignore + self.listener: grpc.Server = grpc.server( # futures.ThreadPoolExecutor(max_workers=4), - options=[ # type: ignore + options=[ # ("grpc.max_send_message_length", 100 * 1024 * 1024), # 100MB ("grpc.max_receive_message_length", 100 * 1024 * 1024), # 100MB ], ) - comm_pb2_grpc.add_CommunicationServerServicer_to_server(self.servicer, self.listener) # type: ignore + comm_pb2_grpc.add_CommunicationServerServicer_to_server(self.servicer, self.listener) # address = f"{self.host}:{self.port}" self.listener.add_insecure_port(address) self.listener.start() @@ -297,8 +297,8 @@ def peer_ids_to_proto( ) -> Dict[int, comm_pb2.PeerId]: peer_ids_proto: Dict[int, comm_pb2.PeerId] = {} for peer_id in peer_ids: - rank = comm_pb2.Rank(rank=peer_ids[peer_id].get("rank")) # type: ignore - port = comm_pb2.Port(port=peer_ids[peer_id].get("port")) # type: ignore + rank = comm_pb2.Rank(rank=peer_ids[peer_id].get("rank")) # + port = comm_pb2.Port(port=peer_ids[peer_id].get("port")) # ip = str(peer_ids[peer_id].get("ip")) peer_ids_proto[peer_id] = comm_pb2.PeerId(rank=rank, port=port, ip=ip) return peer_ids_proto @@ -337,13 +337,13 @@ def initialize(self): port = self.servicer.peer_ids[peer_id].get("port") address = f"{host_ip}:{port}" print(f"Sending peer_ids to {address}") - with grpc.insecure_channel(address) as channel: # type: ignore + with grpc.insecure_channel(address) as channel: # stub = comm_pb2_grpc.CommunicationServerStub(channel) proto_msg = comm_pb2.PeerIds( peer_ids=self.peer_ids_to_proto(self.servicer.peer_ids) ) - stub.send_peer_ids(proto_msg) # type: ignore - # stub.send_quorum(comm_pb2.Quorum(quorum=True)) # type: ignore + stub.send_peer_ids(proto_msg) # + # stub.send_quorum(comm_pb2.Quorum(quorum=True)) # def send_quorum(self): """ Send the quorum status to all nodes after peer IDs are sent. """ @@ -351,25 +351,25 @@ def send_quorum(self): for peer_id in self.servicer.peer_ids: if not self.is_own_id(peer_id): host = self.get_host_from_rank(peer_id) - with grpc.insecure_channel(host) as channel: # type: ignore + with grpc.insecure_channel(host) as channel: # stub = comm_pb2_grpc.CommunicationServerStub(channel) - stub.send_quorum(comm_pb2.Quorum(quorum=True)) # type: ignore + stub.send_quorum(comm_pb2.Quorum(quorum=True)) # print(f"Quorum status sent to all nodes.") def get_host_from_rank(self, rank: int) -> str: for peer_id in self.servicer.peer_ids: if self.servicer.peer_ids[peer_id].get("rank") == rank: - return self.servicer.peer_ids[peer_id].get("ip") + ":" + str(self.servicer.peer_ids[peer_id].get("port")) # type: ignore + return self.servicer.peer_ids[peer_id].get("ip") + ":" + str(self.servicer.peer_ids[peer_id].get("port")) # raise Exception(f"Rank {rank} not found in peer_ids") def send_with_retries(self, dest_host: str, buffer: Any) -> Any: - with grpc.insecure_channel(dest_host) as channel: # type: ignore - stub = comm_pb2_grpc.CommunicationServerStub(channel) # type: ignore + with grpc.insecure_channel(dest_host) as channel: # + stub = comm_pb2_grpc.CommunicationServerStub(channel) # max_tries = 10 while max_tries > 0: try: - model = comm_pb2.Model(buffer=buffer) # type: ignore - stub.send_model(model) # type: ignore + model = comm_pb2.Model(buffer=buffer) # + stub.send_model(model) # except grpc.RpcError as e: print(f"RPC failed {10 - max_tries} times: {e}", "Retrying...") # sleep for a random time between 1 and 10 seconds @@ -410,10 +410,10 @@ def wait_until_rounds_match(self, id: int): raise Exception("Base node not registered") self_round = self.servicer.base_node.get_local_rounds() def callback_fn(stub: comm_pb2_grpc.CommunicationServerStub) -> int: - round = stub.get_current_round(comm_pb2.Empty()) # type: ignore + round = stub.get_current_round(comm_pb2.Empty()) # with self.servicer.lock: self.servicer.communication_cost_received += round.ByteSize() - return round.round # type: ignore + return round.round # while True: host = self.get_host_from_rank(id) @@ -444,10 +444,10 @@ def receive(self, node_ids: List[int]) -> List[Any]: self.wait_until_rounds_match(id) items: List[Any] = [] def callback_fn(stub: comm_pb2_grpc.CommunicationServerStub) -> OrderedDict[str, Tensor]: - model = stub.get_model(comm_pb2.Empty()) # type: ignore + model = stub.get_model(comm_pb2.Empty()) # with self.servicer.lock: self.servicer.communication_cost_received += model.ByteSize() - return deserialize_message(model.buffer) # type: ignore + return deserialize_message(model.buffer) # for id in node_ids: rank = self.get_host_from_rank(id) @@ -562,18 +562,18 @@ def finalize(self): for peer_id in self.servicer.peer_ids: if not self.is_own_id(peer_id): host = self.get_host_from_rank(peer_id) - with grpc.insecure_channel(host) as channel: # type: ignore + with grpc.insecure_channel(host) as channel: # stub = comm_pb2_grpc.CommunicationServerStub(channel) - stub.send_quorum(comm_pb2.Quorum(quorum=True)) # type: ignore + stub.send_quorum(comm_pb2.Quorum(quorum=True)) # else: # send finished to the super node - with grpc.insecure_channel(self.super_node_host) as channel: # type: ignore + with grpc.insecure_channel(self.super_node_host) as channel: # stub = comm_pb2_grpc.CommunicationServerStub(channel) - stub.send_finished(comm_pb2.Rank(rank=self.rank)) # type: ignore + stub.send_finished(comm_pb2.Rank(rank=self.rank)) # status = self.servicer.quorum.get() if not status: print("Quorum became false!") sys.exit(1) if self.listener: - self.listener.stop(0) # type: ignore + self.listener.stop(0) # print(f"Stopped server on port {self.port}") diff --git a/src/utils/types.py b/src/utils/custom_types.py similarity index 100% rename from src/utils/types.py rename to src/utils/custom_types.py diff --git a/src/utils/dropout_utils.py b/src/utils/dropout_utils.py index fa70db1a..750f879e 100644 --- a/src/utils/dropout_utils.py +++ b/src/utils/dropout_utils.py @@ -1,5 +1,5 @@ import random -from utils.types import ConfigType +from utils.custom_types import ConfigType class NodeDropout: """ diff --git a/src/utils/log_utils.py b/src/utils/log_utils.py index 39d4ee89..48ac7c23 100644 --- a/src/utils/log_utils.py +++ b/src/utils/log_utils.py @@ -15,7 +15,7 @@ from tensorboardX import SummaryWriter # type: ignore import numpy as np import pandas as pd -from utils.types import ConfigType +from utils.custom_types import ConfigType import json import matplotlib.pyplot as plt diff --git a/src/utils/model_utils.py b/src/utils/model_utils.py index 97c639eb..faf49954 100644 --- a/src/utils/model_utils.py +++ b/src/utils/model_utils.py @@ -11,7 +11,7 @@ import resnet_in import yolo -from utils.types import ConfigType +from utils.custom_types import ConfigType class ModelUtils: def __init__(self, device: torch.device, config: ConfigType) -> None: