class WeakSet(MutableSet):
    """A set that holds weak references and identifies elements by hash only.

    Normal sets decide that two elements are the same by comparing both
    ``__hash__`` and ``__eq__``. For tensors, ``__eq__`` is slow, so this
    class only checks ``__hash__``. This is normally a bad idea, but because
    we control the hash for Tensors we can (hopefully) use it for gradient
    calculation.

    To avoid circular dependencies, elements are stored weakly so that the
    garbage collector can clean up objects that are referred to only by this
    set. A collected element disappears from the set on its own.

    Consequences of the hash-only design:
        * Two distinct objects with the same hash are treated as the same
          element; adding the second one replaces the first.
        * Elements must be hashable and weak-referenceable, because they are
          stored as values of a ``WeakValueDictionary``.

    Attributes:
        _dict: A ``WeakValueDictionary`` mapping ``hash(element)`` to the
            element itself.
    """

    def __init__(self, *args, **kwargs):
        """Initialize an empty WeakSet.

        Args:
            *args: Positional arguments forwarded unchanged to the parent
                initializer.
            **kwargs: Keyword arguments forwarded unchanged to the parent
                initializer.
        """
        super().__init__(*args, **kwargs)
        self._dict = WeakValueDictionary()

    def __contains__(self, x: Any) -> bool:
        """Check whether an element with the same hash as ``x`` is stored.

        Args:
            x: The element to look up. Only ``hash(x)`` is consulted;
                ``__eq__`` is never called.

        Returns:
            True if a live element is stored under ``hash(x)``, False
            otherwise.
        """
        return hash(x) in self._dict

    def __iter__(self):
        """Iterate over the elements that are still alive.

        Returns:
            An iterator over the values of the underlying
            ``WeakValueDictionary``.
        """
        return self._dict.values()

    def __len__(self) -> int:
        """Return the number of elements currently held.

        Returns:
            The size of the underlying ``WeakValueDictionary``.
        """
        return len(self._dict)

    def add(self, x: Any) -> None:
        """Add an element to the set.

        If an element with the same hash is already present it is replaced
        by ``x``.

        Args:
            x: The element to add.
        """
        self._dict[hash(x)] = x

    def discard(self, x: Any) -> None:
        """Remove the element stored under ``hash(x)`` if there is one.

        Does nothing when no such element is present.

        Args:
            x: The element to remove.
        """
        # A single pop-with-default hashes ``x`` once and cannot raise
        # KeyError if the weakly held entry vanishes between a membership
        # test and the deletion.
        self._dict.pop(hash(x), None)