diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 96919bc..f477ac4 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -50,3 +50,6 @@ jobs: run: | uv run --python ${{ matrix.python-version }} --extra dev ruff check --no-fix uv run --python ${{ matrix.python-version }} --extra dev ruff format --check + - name: Type check + if: matrix.python-version == '3.13' + run: uv run --python ${{ matrix.python-version }} --extra dev pyright diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index c8257df..b6b3ede 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -37,6 +37,9 @@ jobs: run: | uv run --python ${{ matrix.python-version }} --extra dev ruff check --no-fix uv run --python ${{ matrix.python-version }} --extra dev ruff format --check + - name: Type check + if: matrix.python-version == '3.13' + run: uv run --python ${{ matrix.python-version }} --extra dev pyright publish: name: Publish to PyPI diff --git a/mmdb_writer.py b/mmdb_writer/__init__.py similarity index 80% rename from mmdb_writer.py rename to mmdb_writer/__init__.py index caa0a2f..776c512 100644 --- a/mmdb_writer.py +++ b/mmdb_writer/__init__.py @@ -1,18 +1,19 @@ -__version__ = "0.2.6" +__version__ = "0.2.7" import logging import math import struct import time +from collections.abc import Callable, Iterator from decimal import Decimal from enum import IntEnum -from typing import Literal, Union +from typing import Any, Literal, Optional, Union, cast, overload from netaddr import IPNetwork, IPSet class MmdbBaseType: - def __init__(self, value): + def __init__(self, value: Any): self.value = value @@ -53,8 +54,8 @@ def __init__(self, value: int): MMDBType = Union[ - dict, - list, + dict[Any, Any], + list[Any], str, bytes, int, @@ -95,13 +96,16 @@ class MMDBTypeID(IntEnum): UINT32_MAX = 0xFFFFFFFF UINT64_MAX = 0xFFFFFFFFFFFFFFFF +# A child slot of a SearchTreeNode: another node, a leaf, or empty. +TreeChild = Union["SearchTreeNode", "SearchTreeLeaf", None] + class SearchTreeNode: - def __init__(self, left=None, right=None): - self.left = left - self.right = right + def __init__(self, left: "TreeChild" = None, right: "TreeChild" = None): + self.left: TreeChild = left + self.right: TreeChild = right - def get_or_create(self, item): + def get_or_create(self, item: int) -> "TreeChild": if item == 0: self.left = self.left or SearchTreeNode() return self.left @@ -109,13 +113,13 @@ def get_or_create(self, item): self.right = self.right or SearchTreeNode() return self.right - def __getitem__(self, item): + def __getitem__(self, item: int) -> "TreeChild": if item == 0: return self.left elif item == 1: return self.right - def __setitem__(self, key, value): + def __setitem__(self, key: int, value: "TreeChild") -> None: if key == 0: self.left = value elif key == 1: @@ -123,10 +127,10 @@ def __setitem__(self, key, value): class SearchTreeLeaf: - def __init__(self, value): + def __init__(self, value: MMDBType): self.value = value - def __repr__(self): + def __repr__(self) -> str: return f"SearchTreeLeaf(value={self.value})" __str__ = __repr__ @@ -157,16 +161,19 @@ def __repr__(self): class Encoder: def __init__( - self, cache=True, int_type: IntType = "auto", float_type: FloatType = "f64" + self, + cache: bool = True, + int_type: IntType = "auto", + float_type: FloatType = "f64", ): self.cache = cache self.int_type = int_type self.float_type = float_type - self.data_cache = {} - self.data_list = [] + self.data_cache: dict[Any, int] = {} + self.data_list: list[bytes] = [] self.data_pointer = 0 - self._python_type_id = { + self._python_type_id: dict[type, MMDBTypeID] = { float: MMDBTypeID.DOUBLE, bool: MMDBTypeID.BOOLEAN, list: MMDBTypeID.ARRAY, @@ -182,7 +189,7 @@ def __init__( MmdbU128: MMDBTypeID.UINT128, } - def _encode_pointer(self, value): + def _encode_pointer(self, value: int) -> bytes: pointer = value if pointer >= 134744064: res = struct.pack(">BI", 0x38, pointer) @@ -208,19 +215,19 @@ def _encode_pointer(self, value): return res - def _encode_utf8_string(self, value): + def _encode_utf8_string(self, value: str) -> bytes: encoded_value = value.encode("utf-8") res = self._make_header(MMDBTypeID.STRING, len(encoded_value)) res += encoded_value return res - def _encode_bytes(self, value): + def _encode_bytes(self, value: bytes) -> bytes: return self._make_header(MMDBTypeID.BYTES, len(value)) + value - def _encode_uint(self, type_id, max_len): + def _encode_uint(self, type_id: MMDBTypeID, max_len: int) -> Callable[[int], bytes]: value_max = 2 ** (max_len * 8) - def _encode_unsigned_value(value): + def _encode_unsigned_value(value: int) -> bytes: if value < 0 or value >= value_max: raise ValueError( f"encode uint{max_len * 8} fail: " @@ -234,7 +241,7 @@ def _encode_unsigned_value(value): return _encode_unsigned_value - def _encode_map(self, value): + def _encode_map(self, value: dict[Any, Any]) -> bytes: res = self._make_header(MMDBTypeID.MAP, len(value)) for k, v in list(value.items()): # Keys are always stored by value. @@ -242,26 +249,28 @@ def _encode_map(self, value): res += self.encode(v) return res - def _encode_array(self, value): + def _encode_array(self, value: list[Any]) -> bytes: res = self._make_header(MMDBTypeID.ARRAY, len(value)) for k in value: res += self.encode(k) return res - def _encode_boolean(self, value): + def _encode_boolean(self, value: Any) -> bytes: return self._make_header(MMDBTypeID.BOOLEAN, 1 if value else 0) - def _encode_pack_type(self, type_id, fmt): - def pack_type(value): + def _encode_pack_type( + self, type_id: MMDBTypeID, fmt: str + ) -> Callable[[Any], bytes]: + def pack_type(value: Any) -> bytes: res = struct.pack(fmt, value) return self._make_header(type_id, len(res)) + res return pack_type - _type_encoder = None + _type_encoder: Optional[dict[int, Callable[..., bytes]]] = None @property - def type_encoder(self): + def type_encoder(self) -> dict[int, Callable[..., bytes]]: if self._type_encoder is None: self._type_encoder = { MMDBTypeID.POINTER: self._encode_pointer, @@ -280,7 +289,7 @@ def type_encoder(self): } return self._type_encoder - def _make_header(self, type_id, length): + def _make_header(self, type_id: int, length: int) -> bytes: if length >= 16843036: raise Exception("length >= 16843036") @@ -315,7 +324,7 @@ def _make_header(self, type_id, length): return res + additional_length_bytes - def python_type_id(self, value): + def python_type_id(self, value: Any) -> MMDBTypeID: value_type = type(value) type_id = self._python_type_id.get(value_type) if type_id: @@ -355,14 +364,16 @@ def python_type_id(self, value): return MMDBTypeID.DOUBLE raise TypeError(f"unknown type {value_type}") - def _freeze(self, value): + def _freeze(self, value: Any) -> Any: if isinstance(value, dict): - return tuple((k, self._freeze(v)) for k, v in value.items()) + mapping: dict[Any, Any] = value + return tuple((k, self._freeze(v)) for k, v in mapping.items()) elif isinstance(value, list): - return tuple(self._freeze(v) for v in value) + seq: list[Any] = value + return tuple(self._freeze(v) for v in seq) return value - def encode_meta(self, meta): + def encode_meta(self, meta: dict[str, Any]) -> bytes: res = self._make_header(MMDBTypeID.MAP, len(meta)) meta_type = { "node_count": 6, @@ -378,7 +389,26 @@ def encode_meta(self, meta): res += self.encode(v, meta_type.get(k)) return res - def encode(self, value, type_id=None, return_offset=False): + @overload + def encode( + self, value: Any, type_id: Optional[int] = ..., *, return_offset: Literal[True] + ) -> int: ... + + @overload + def encode( + self, + value: Any, + type_id: Optional[int] = ..., + return_offset: Literal[False] = ..., + ) -> bytes: ... + + def encode( + self, + value: Any, + type_id: Optional[int] = None, + return_offset: bool = False, + ) -> Union[int, bytes]: + cache_key: Any = None if self.cache: cache_key = self._freeze(value) try: @@ -409,20 +439,22 @@ def encode(self, value, type_id=None, return_offset=False): class TreeWriter: - encoder_cls = Encoder + encoder_cls: type[Encoder] = Encoder def __init__( self, - tree: "SearchTreeNode", - meta: dict, + tree: SearchTreeNode, + meta: dict[str, Any], int_type: IntType = "auto", float_type: FloatType = "f64", ): - self._node_idx = {} - self._leaf_offset = {} - self._node_list = [] + self._node_idx: dict[int, int] = {} + self._leaf_offset: dict[int, int] = {} + self._node_list: list[SearchTreeNode] = [] self._node_counter = 0 self._record_size = 0 + self.record_size = 0 + self.data_offset: float = 0 self.tree = tree self.meta = meta @@ -432,21 +464,21 @@ def __init__( ) @property - def _data_list(self): + def _data_list(self) -> list[bytes]: return self.encoder.data_list @property - def _data_pointer(self): + def _data_pointer(self) -> int: return self.encoder.data_pointer + 16 - def _build_meta(self): + def _build_meta(self) -> dict[str, Any]: return { "node_count": self._node_counter, "record_size": self.record_size, **self.meta, } - def _adjust_record_size(self): + def _adjust_record_size(self) -> None: # Tree records should be large enough to contain either tree node index # or data offset. max_id = self._node_counter + self._data_pointer + 1 @@ -464,7 +496,7 @@ def _adjust_record_size(self): self.data_offset = self.record_size * 2 / 8 * self._node_counter - def _enumerate_nodes(self, node): + def _enumerate_nodes(self, node: TreeChild) -> None: if type(node) is SearchTreeNode: node_id = id(node) if node_id not in self._node_idx: @@ -483,7 +515,7 @@ def _enumerate_nodes(self, node): else: # == None return - def _calc_record_idx(self, node): + def _calc_record_idx(self, node: TreeChild) -> int: if node is None: return self._node_counter elif type(node) is SearchTreeNode: @@ -493,7 +525,7 @@ def _calc_record_idx(self, node): else: raise Exception("unexpected type") - def _cal_node_bytes(self, node) -> bytes: + def _cal_node_bytes(self, node: SearchTreeNode) -> bytes: left_idx = self._calc_record_idx(node.left) right_idx = self._calc_record_idx(node.right) @@ -522,7 +554,7 @@ def _cal_node_bytes(self, node) -> bytes: else: raise Exception("self.record_size > 32") - def write(self, fname): + def write(self, fname: str) -> None: self._enumerate_nodes(self.tree) self._adjust_record_size() @@ -539,18 +571,18 @@ def write(self, fname): f.write(self.encoder_cls(cache=False).encode_meta(self._build_meta())) -def bits_rstrip(n, length=None, keep=0): +def bits_rstrip(n: int, length: int = 0, keep: int = 0) -> Iterator[int]: return map(int, bin(n)[2:].rjust(length, "0")[:keep]) class MMDBWriter: def __init__( self, - ip_version=4, - database_type="GeoIP", - languages: list[str] = None, + ip_version: int = 4, + database_type: str = "GeoIP", + languages: Optional[list[str]] = None, description: Union[dict[str, str], str] = "GeoIP db", - ipv4_compatible=False, + ipv4_compatible: bool = False, int_type: IntType = "auto", float_type: FloatType = "f64", ): @@ -576,7 +608,7 @@ def __init__( if languages is None: languages = [] - self.description = description + self.description: Union[dict[str, str], str] = description self.database_type = database_type self.ip_version = ip_version self.languages = languages @@ -599,10 +631,10 @@ def __init__( if i not in self.description: raise ValueError("language {} must have description!") - self.int_type = int_type - self.float_type = float_type + self.int_type: IntType = int_type + self.float_type: FloatType = float_type - def insert_network(self, network: IPSet, content: MMDBType): + def insert_network(self, network: IPSet, content: MMDBType) -> None: """ Inserts a network into the MaxMind database. @@ -625,8 +657,8 @@ def insert_network(self, network: IPSet, content: MMDBType): leaf = SearchTreeLeaf(content) if not isinstance(network, IPSet): raise ValueError("network type should be netaddr.IPSet.") - network = network.iter_cidrs() - for cidr in network: + cidrs = network.iter_cidrs() + for cidr in cidrs: if self.ip_version == 4 and cidr.version == 6: raise ValueError( f"You inserted a IPv6 address {cidr} to an IPv4-only database." @@ -639,15 +671,16 @@ def insert_network(self, network: IPSet, content: MMDBType): "IPv4 address in IPv6 database as ::/96 format" ) cidr = cidr.ipv6(True) - node = self.tree - bits = list(bits_rstrip(cidr.value, self._bit_length, cidr.prefixlen)) - current_node = node - supernet_leaf = None # Tracks whether we are inserting into a subnet + bits = list( + bits_rstrip(cast(int, cidr.value), self._bit_length, cidr.prefixlen) + ) + current_node = self.tree + supernet_leaf: Optional[SearchTreeLeaf] = None # set when in a subnet for index, ip_bit in enumerate(bits[:-1]): previous_node = current_node - current_node = previous_node.get_or_create(ip_bit) + child = previous_node.get_or_create(ip_bit) - if isinstance(current_node, SearchTreeLeaf): + if isinstance(child, SearchTreeLeaf): current_cidr = IPNetwork( ( int( @@ -661,11 +694,15 @@ def insert_network(self, network: IPSet, content: MMDBType): ) logger.info( f"Inserting {cidr} ({content}) into subnet of " - f"{current_cidr} ({current_node.value})" + f"{current_cidr} ({child.value})" ) - supernet_leaf = current_node + supernet_leaf = child current_node = SearchTreeNode() previous_node[ip_bit] = current_node + else: + # ip_bit is always 0 or 1, so get_or_create returns a node here. + assert child is not None + current_node = child if supernet_leaf: next_bit = bits[index + 1] @@ -674,12 +711,12 @@ def insert_network(self, network: IPSet, content: MMDBType): current_node[1 - next_bit] = supernet_leaf current_node[bits[-1]] = leaf - def to_db_file(self, filename: str): + def to_db_file(self, filename: str) -> None: return TreeWriter( self.tree, self._build_meta(), self.int_type, self.float_type ).write(filename) - def _build_meta(self): + def _build_meta(self) -> dict[str, Any]: return { "ip_version": self.ip_version, "database_type": self.database_type, diff --git a/mmdb_writer/py.typed b/mmdb_writer/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/pyproject.toml b/pyproject.toml index 6f9dc32..0f4019e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,7 @@ test = [ ] dev = [ "ruff", + "pyright==1.1.409", ] [project.urls] @@ -50,7 +51,7 @@ Source = "https://github.com/vimt/MaxMind-DB-Writer-python" Tracker = "https://github.com/vimt/MaxMind-DB-Writer-python/issues" [tool.flit.sdist] -include = ["mmdb_writer.py"] +include = ["mmdb_writer/"] [tool.pytest.ini_options] testpaths = ["tests"] @@ -58,6 +59,11 @@ filterwarnings = [ "error", ] +[tool.pyright] +include = ["mmdb_writer"] +pythonVersion = "3.9" +typeCheckingMode = "standard" + [tool.ruff] line-length = 88 target-version = "py39"