From e03940381b0e36e43cc8e906372b8f4c490b5f9b Mon Sep 17 00:00:00 2001 From: James Parrott <80779630+JamesParrott@users.noreply.github.com> Date: Mon, 22 Jun 2026 18:26:57 +0100 Subject: [PATCH] v3.1.0 Strict mode. Robuster support for unicode inc bug fixes. Make dbf string encoding robust (names and records) Replace strip_leading_whitespace and replace_ascii_spaces_with_underscores with strict Respect unicode code point boundaries when truncating field names Pass mypy - make repr of bytes explicit in f string Handle .dbf C & M fields with _encode_dbf_string & _decode_C_or_M_field Make dbf string encoding robust (names and records) Make en/decoding of dbf field name and "C" field values robuster. Make dbf record writing atomic --- README.md | 22 ++- changelog.txt | 25 ++- src/shapefile.py | 315 +++++++++++++++++++++++++++++--------- tests/hypothesis_tests.py | 25 +-- tests/run_benchmarks.py | 30 ++-- tests/test_shapefile.py | 73 +++++++-- 6 files changed, 379 insertions(+), 111 deletions(-) diff --git a/README.md b/README.md index 17092ada..ad06d551 100644 --- a/README.md +++ b/README.md @@ -93,12 +93,32 @@ part of your geospatial project. # Version Changes -## 3.0.14.dev +## 3.1.0 +### Unicode support made more robust and encoding bugs fixed + - Truncation of field names and text fields now respects unicode code point boundaries (fixes issues - + 416 and 148). + - Warnings raised if truncation, or replacing b" " with b"_" would corrupt encoded field names, - + both if they would either be undecodable, or would silently decode to incorrect data (warns users + if issue 421 applies). + - Correctly truncated field names are now stored in field instances, as would actually be seen in the file. + - New strict mode. Writer(strict=True) raises errors or refuse to create fields and text records with data that - + would be truncated or cannot be correctly decoded back again by PyShp, exactly as given by the user. 
+ - In strict mode, ascii spaces in encoded names are no longer replaced by ascii underscores at all + (workaround to avoid corrupting unicode field names - provides opt-in fix for issue 421). + - BREAKING. When reading .dbf files, trailing ascii spaces in text fields before a null terminator char (in the - + decoded string) are now removed (i.e. instead of .strip().rstrip('\x00') we now do: .rstrip("\x00").rstrip(" ")). + - BREAKING. Enclosing whitespace other than trailing ascii spaces (0x20) after null chars in text fields is now - + preserved, when reading .dbf files (fixes issue 418 - James feels this was a bug. Let him know if you think otherwise). + - BREAKING. Trailing null chars other than null terminators & null padding bytes, followed by whitespace other than + ascii spaces, are now preserved. + - Writing dbf records is now atomic. + ### ShpWriter.shape API Tweak (small breaking change). - Make ShpWriter.shape return shape length in bytes (the same as for offset) not in 16 bit words. ### Testing - Include NullShapes in shp round trip test. + - En/decoding of Dbf files and Fields round trips correctly. ## 3.0.13 diff --git a/changelog.txt b/changelog.txt index a18b5109..71560b96 100644 --- a/changelog.txt +++ b/changelog.txt @@ -1,4 +1,27 @@ -VERSION 3.0.14.dev +VERSION 3.1.0 + +2026-06-23 + Unicode support made more robust and encoding bugs fixed + * Truncation of field names and text fields now respects unicode code point boundaries (fixes issues + 416 and 148). + * Warnings are raised if truncation, or replacing b" " with b"_", would corrupt encoded field names, + whether they would become undecodable or would silently decode to incorrect data (warns users + if issue 421 applies). + * Correctly truncated field names are now stored in field instances, as would actually be seen in the file. + * New strict mode. 
Writer(strict=True) raises errors or refuses to create fields and text records with data that + would be truncated or cannot be correctly decoded back again by PyShp, exactly as given by the user. + * In strict mode, ascii spaces in encoded names are no longer replaced by ascii underscores at all + (workaround to avoid corrupting unicode field names - provides opt-in fix for issue 421). + * BREAKING. When reading .dbf files, trailing ascii spaces in text fields before a null terminator char (in the + decoded string) are now removed (i.e. instead of .strip().rstrip('\x00') we now do: .rstrip("\x00").rstrip(" ")). + * BREAKING. Enclosing whitespace other than trailing ascii spaces (0x20) after null chars in text fields is now + preserved, when reading .dbf files (fixes issue 418 - James feels this was a bug. Let him know if you think otherwise). + * BREAKING. Trailing null chars other than null terminators & null padding bytes, followed by whitespace other than + ascii spaces, are now preserved. + * Writing dbf records is now atomic. + Testing. + * En/decoding of Dbf files and Fields round trips correctly. + 2026-06-20 * API Tweak (small breaking change). 
Make ShpWriter.shape return shape length in bytes diff --git a/src/shapefile.py b/src/shapefile.py index 84f468db..d080e37a 100644 --- a/src/shapefile.py +++ b/src/shapefile.py @@ -8,7 +8,7 @@ from __future__ import annotations -__version__ = "3.0.14.dev" +__version__ = "3.1.0" import abc import array @@ -24,7 +24,7 @@ import zipfile from collections.abc import Container, Iterable, Iterator, Reversible, Sequence from contextlib import AbstractContextManager, ExitStack -from datetime import date +from datetime import date, datetime from os import PathLike from pathlib import Path from struct import Struct, calcsize, error, pack, unpack @@ -235,38 +235,110 @@ class PossibleDataLoss(Warning): pass -def _largest_valid_truncated_encoding( +class DbfStringDataLoss(ValueError): + pass + + +class Decoder(Protocol): + __name__: str + + def __call__( + self, + b: bytes, + encoding: str = "utf8", + encodingErrors: str = "strict", + strict: bool = False, + ) -> str: ... + + +def _encode_dbf_string( s: str, - max_bytes: int, - strict: bool, + size: int, + decode: Decoder, + pad_byte: bytes | None = None, encoding: str = "utf8", encodingErrors: str = "strict", + strict: bool = True, ) -> tuple[bytes, str]: + """Attempts to encoded s with the codec specified, + progressively truncating its code points until + the resulting bytes are less than size + (e.g. the dbf field length or field name size == 10). + If less, these bytes are then padded to length size. + + Replaces: s.encode(self.encoding, self.encodingErrors)[:size] + .ljust(size)) + in the legacy string encoding implementation: + """ N = len(s) + trimmed: str + encoded: bytes for i in reversed(range(0, N + 1)): trimmed = s[:i] encoded = trimmed.encode(encoding, encodingErrors) - if len(encoded) <= max_bytes: + + if len(encoded) <= size: if i <= N - 1: msg = ( f"Dropped {N - i} code points (e.g. characters)! 
" f"{s} was truncated to {trimmed} (discarding: {s[i:]}), " - f"in order to encode it under {max_bytes} bytes for the field or field name. " + f"in order to encode it under {size} bytes for the field or field name. " f"Used: {encoding=} and {encodingErrors=}. " ) if strict: - raise ValueError(f"Data loss. {strict=}.\n{msg}") - else: - warnings.warn( - msg, - category=PossibleDataLoss, - ) - return encoded, trimmed - raise ValueError( - f"Maximum truncation not sufficient to encode below {max_bytes=}. " - f"Could not encode first code point (e.g. character): {s[0]} " - f"to a short enough byte string, using {encoding=}, {encodingErrors=}" + raise DbfStringDataLoss(f"Data loss. {strict=}.\n{msg}") + warnings.warn( + msg, + category=PossibleDataLoss, + ) + break + else: # for loop did not break, len(encoded) <= size, + # e.g. encoding "" preppends a BOM bigger than size. + raise ValueError( + f"Maximum truncation not sufficient to encode below {size=}. " + f"Could not encode first code point (e.g. character): {s[0]} " + f"to a short enough byte string, using {encoding=}, {encodingErrors=}" + ) + + if len(encoded) < size and pad_byte is not None: + padded = encoded.ljust(size, pad_byte) + else: + padded = encoded + + decoded = decode( + b=padded, + encoding=encoding, + encodingErrors=encodingErrors, ) + if decoded != trimmed: + msg = f"Padded value: {padded!r} does not decode to {trimmed!r} using PyShp's decoder: {decode.__name__}" + if len(trimmed) < len(s): + msg = f"{msg} (trimmed, original string: {s}). 
" + if strict: + raise DbfStringDataLoss(msg) + warnings.warn( + msg, + category=PossibleDataLoss, + ) + + return padded, trimmed + + +def _decode_C_or_M_field( + b: bytes, + encoding: str = "utf8", + encodingErrors: str = "strict", + strict: bool = True, +) -> str: + retval = b.decode(encoding, encodingErrors).rstrip("\x00").rstrip(" ") + if retval.rstrip("\x00") != retval and strict: + msg = ( + f"More Trailing Null chars in: {b!r}" + " after removing trailing null chars and ascii spaces" + f", resulting in {retval!r}" + ) + warnings.warn(msg, category=PossibleDataLoss) + return retval class Field(NamedTuple): @@ -281,28 +353,60 @@ def get_struct(cls) -> Struct: # En/decoding the name as "<10sx" embeds the null terminator. return Struct("<10sxc4xBB14x") + @staticmethod + def decode_name( + b: bytes, + encoding: str = "utf8", + encodingErrors: str = "strict", + strict: bool = True, + ) -> str: + N = len(b) + decoded: str + num_trailing_null_bytes = N - len(b.rstrip(b"\x00")) + + # Test if we need to restore any of those null bytes to + # correctly decode the remaining bytes to a string. + for num_to_trim in reversed(range(num_trailing_null_bytes + 1)): + i = N - num_to_trim + trimmed = b[:i] + try: + decoded = trimmed.decode(encoding, encodingErrors) + except UnicodeDecodeError: + continue + if strict and num_to_trim < num_trailing_null_bytes: + warnings.warn( + f"Used {num_trailing_null_bytes - num_to_trim} null bytes " + f"from padding to decode {b!r} " + f"to: {decoded!r} ({encoding=}, {encodingErrors=}) ", + category=PossibleDataLoss, + ) + if not strict: + decoded = decoded.lstrip() + return decoded + + raise dbfFileException( + f"Could not decode field name: {b!r} using {encoding=} and {encodingErrors=}" + " no matter how many trailing null-bytes (if any) were used. 
" + ) + @classmethod def from_byte_stream( cls, b_io: ReadableBinStream, - strict: bool = False, encoding: str = "utf8", encodingErrors: str = "strict", - strip_leading_whitespace: bool = True, + strict: bool = False, ) -> Field: encoded_field_tuple: tuple[bytes, bytes, int, int] encoded_field_tuple = cls.get_struct().unpack(b_io.read(32)) encoded_name, encoded_type_char, size, decimal = encoded_field_tuple - encoded_name, __, ___ = encoded_name.partition(b"\x00") - name = encoded_name.decode(encoding, encodingErrors) - if strip_leading_whitespace: - name = name.lstrip() + name = cls.decode_name(encoded_name, encoding, encodingErrors) field_type = FIELD_TYPE_ALIASES[encoded_type_char] return cls.from_unchecked( - name, field_type, size, decimal, strict, encoding, encodingErrors + name, field_type, size, decimal, encoding, encodingErrors, strict ) @classmethod @@ -312,15 +416,15 @@ def from_unchecked( field_type: str | bytes | FieldTypeT = "C", size: int = 50, decimal: int = 0, - strict: bool = False, encoding: str = "utf8", encodingErrors: str = "strict", + strict: bool = False, ) -> Field: if "\x00" in name: msg = ( - "Field names should contain null characters " - "as null bytes are used to pad them in the header. " + "Field names should not contain null characters " + "as null bytes are used for padding in the header. " f"Got: {name=} " ) if strict: @@ -330,8 +434,8 @@ def from_unchecked( try: type_ = FIELD_TYPE_ALIASES[field_type] except KeyError: - raise ShapefileException( - f"field_type must be in {{FieldType.__members__}}. Got: {field_type=}. " + raise dbfFileException( + f"field_type must be in {FieldType.__members__}. Got: {field_type=}. " ) if type_ is FieldType.D: @@ -341,29 +445,71 @@ def from_unchecked( size = 1 decimal = 0 + # Only use the portion of the name that we are able to encode to + # 10 bytes or less. 
+ _encoded_name, trimmed_name = cls.trim_name_until_encodable( + name=str(name), + encoding=encoding, + encodingErrors=encodingErrors, + strict=strict, + ) + # A doctest in README.md previously passed in a string ('40') for size, # so explictly convert name to str, and size and decimal to ints. inst = cls( - name=str(name), field_type=type_, size=int(size), decimal=int(decimal) + name=trimmed_name, field_type=type_, size=int(size), decimal=int(decimal) ) + # Raise Exception or trigger warning early, before user adds more fields + # (fields are only written when first record added, and on close) inst.encode_field_descriptor( - strict=True, encoding=encoding, encodingErrors=encodingErrors + encoding=encoding, + encodingErrors=encodingErrors, + strict=strict, ) return inst + @classmethod + def trim_name_until_encodable( + cls, + name: str, + encoding: str = "utf8", + encodingErrors: str = "strict", + strict: bool = False, + ) -> tuple[bytes, str]: + return _encode_dbf_string( + s=name, + size=10, + decode=cls.decode_name, + pad_byte=b"\x00", + encoding=encoding, + encodingErrors=encodingErrors, + strict=strict, + ) + @functools.cache def encode_field_descriptor( self, - strict: bool = False, encoding: str = "utf8", encodingErrors: str = "strict", - replace_ascii_spaces_with_underscores: bool = True, + strict: bool = False, ) -> bytes: - encoded_name = self.name.encode(encoding, encodingErrors) - if replace_ascii_spaces_with_underscores: + # encoded_name = self.name.encode(encoding, encodingErrors) + # encoded_name = encoded_name[:10].ljust(10, b"\x00") + encoded_name, _trimmed_name = self.trim_name_until_encodable( + name=self.name, + encoding=encoding, + encodingErrors=encodingErrors, + strict=strict, + ) + if not strict and b" " in encoded_name: + warnings.warn( + "Replacing ascii spaces (0x20) with underscores " + f"in encoded bytes: {encoded_name!r}", + category=PossibleDataLoss, + ) encoded_name = encoded_name.replace(b" ", b"_") - encoded_name = 
encoded_name[:10].ljust(10, b"\x00") + encoded_field_type = self.field_type.encode("ascii") return self.get_struct().pack( encoded_name, @@ -2708,9 +2854,9 @@ def _dbfHeader(self) -> None: self.fields.append( Field.from_byte_stream( b_io=self.file, - strict=self.strict, encoding=self.encoding, encodingErrors=self.encodingErrors, + strict=self.strict, ) ) @@ -2734,6 +2880,11 @@ def _dbfHeader(self) -> None: self._fullRecStruct = recStruct self._fullRecLookup = recLookup + @property + def data_fields(self) -> list[Field]: + """All fields except the DeletionFlag.""" + return self.fields[1:] + def _record_fmt(self, fields: Container[str] | None = None) -> tuple[str, int]: """Calculates the format and size of a .dbf record. Optional 'fields' arg specifies which fieldnames to unpack and which to ignore. Note that this @@ -2781,11 +2932,10 @@ def _record_fields( # fetch relevant field info tuples fieldTuples = [] for fieldinfo in self.fields[1:]: - name = fieldinfo[0] - if name in unique_fields: + if fieldinfo.name in unique_fields: fieldTuples.append(fieldinfo) # store the field positions - recLookup = {f[0]: i for i, f in enumerate(fieldTuples)} + recLookup = {f.name: i for i, f in enumerate(fieldTuples)} else: # use all the dbf fields fieldTuples = self.fields[1:] # sans deletion flag @@ -2831,8 +2981,8 @@ def _record( for (__name, typ, __size, decimal), value in zip(fieldTuples, recordContents): if typ is FieldType.N or typ is FieldType.F: # numeric or float: number stored as a string, right justified, and padded with blanks to the width of the field. 
- value = value.split(b"\0")[0] - value = value.replace(b"*", b"") # QGIS NULL is all '*' chars + value, __, __ = value.partition(b"\x00") + value = value.strip(b"*") # QGIS NULL is all '*' chars if value == b"": value = None elif decimal: @@ -2866,13 +3016,13 @@ def _record( # but can check for all hex null-chars, all spaces, or all 0s (QGIS null) value = None else: + date_str = value.decode("ascii") try: # return as python date object - y, m, d = int(value[:4]), int(value[4:6]), int(value[6:8]) - value = date(y, m, d) + value = datetime.strptime(date_str, "%Y%m%d").date() except (TypeError, ValueError): - # if invalid date, just return as unicode string so user can decimalde - value = str(value.strip()) + # if invalid date, just return as unicode string so user can handle it. + value = date_str elif typ is FieldType.L: # logical: 1 byte - initialized to 0x20 (space) otherwise T or F. if value == b" ": @@ -2885,10 +3035,11 @@ def _record( else: value = None # unknown value is set to missing else: - value = value.decode(self.encoding, self.encodingErrors) - value = value.strip().rstrip( - "\x00" - ) # remove null-padding at end of strings + value = _decode_C_or_M_field( + value, + encoding=self.encoding, + encodingErrors=self.encodingErrors, + ) record.append(value) return _Record(recLookup, record, oid) @@ -3508,6 +3659,11 @@ def numRecords(self) -> int | None: def fields(self) -> list[Field]: return self.dbf_reader.fields + @property + def data_fields(self) -> list[Field]: + """All fields except the DeletionFlag.""" + return self.dbf_reader.data_fields + def record(self, i: int = 0, fields: list[str] | None = None) -> _Record | None: return self.dbf_reader.record(i, fields) @@ -3959,7 +4115,7 @@ def field( raise dbfFileException( f".dbf Shapefile Writer reached maximum number of fields: {self.max_num_fields}." 
) - field_ = Field.from_unchecked( + field = Field.from_unchecked( name=name, field_type=field_type, size=size, @@ -3968,7 +4124,7 @@ def field( encodingErrors=self.encodingErrors, strict=self.strict, ) - self.fields.append(field_) + self.fields.append(field) def _header(self) -> None: """Writes the dbf header and field descriptors.""" @@ -3978,17 +4134,17 @@ def _header(self) -> None: year, month, day = time.localtime()[:3] year -= 1900 # Get all fields, ignoring DeletionFlag if specified - fields = [field for field in self.fields if field[0] != "DeletionFlag"] + fields = [field for field in self.fields if field.name != "DeletionFlag"] # Ensure has at least one field if not fields: - raise ShapefileException( + raise dbfFileException( "Shapefile dbf file must contain at least one field." ) numRecs = self.recNum numFields = len(fields) headerLength = numFields * 32 + 33 if headerLength >= 65535: - raise ShapefileException( + raise dbfFileException( "Shapefile dbf header length exceeds maximum length." ) recordLength = sum(field.size for field in fields) + 1 @@ -4008,7 +4164,7 @@ def _header(self) -> None: for field in fields: f.write( field.encode_field_descriptor( - self.strict, self.encoding, self.encodingErrors + self.encoding, self.encodingErrors, self.strict ) ) @@ -4053,19 +4209,19 @@ def record( def _record(self, record: list[RecordValue]) -> None: """Writes the dbf records.""" - f = self.file + record_stream = io.BytesIO() # Temporary buffer to make record writing atomic. 
if self.recNum == 0: # first records, so all fields should be set # allowing us to write the dbf header # cannot change the fields after this point self._header() # first byte of the record is deletion flag, always disabled - f.write(b" ") + record_stream.write(b" ") # begin - self.recNum += 1 fields = ( field for field in self.fields if field[0] != "DeletionFlag" ) # ignore deletionflag field in case it was specified + for (fieldName, fieldType, size, deci), value in zip(fields, record): # write # fieldName, fieldType, size and deci were already checked @@ -4096,17 +4252,19 @@ def _record(self, record: list[RecordValue]) -> None: ) # caps the size if exceeds the field size elif fieldType == "D": # date: 8 bytes - date stored as a string in the format YYYYMMDD. + if isinstance(value, list) and len(value) == 3: + value = date(*value) if isinstance(value, date): - str_val = f"{value.year:04d}{value.month:02d}{value.day:02d}" - elif isinstance(value, list) and len(value) == 3: - str_val = f"{value[0]:04d}{value[1]:02d}{value[2]:02d}" + str_val = value.strftime("%Y%m%d") elif value in MISSING: str_val = "0" * 8 # QGIS NULL for date type elif isinstance(value, str) and len(value) == 8: pass # value is already a date string else: raise ShapefileException( - "Date values must be either a datetime.date object, a list, a YYYYMMDD string, or a missing value." + f"Could not read as date: {value}. " + "Date values must be either a datetime.date object, " + "a list, a YYYYMMDD string, or a missing value." ) elif fieldType == "L": # logical: 1 byte - initialized to 0x20 (space) otherwise T or F. 
@@ -4120,20 +4278,22 @@ def _record(self, record: list[RecordValue]) -> None: str_val = " " # unknown is set to space if str_val is None: - # Types C and M, and anything else, value is forced to string, - # encoded by the codec specified to the Writer (utf-8 by default), - # then the resulting bytes are padded and truncated to the length - # of the field - encoded = ( - str(value) - .encode(self.encoding, self.encodingErrors)[:size] - .ljust(size) + # Types C and M, and anything else, value is forced to string. + encoded, _trimmed = _encode_dbf_string( + s=str(value), + size=size, + decode=_decode_C_or_M_field, + pad_byte=b" ", + encoding=self.encoding, + encodingErrors=self.encodingErrors, + strict=self.strict, ) else: # str_val was given a not-None string value # under the checks for fieldTypes "N", "F", "D", or "L" above # Numeric, logical, and date numeric types are ascii already, but - # for Shapefile or dbf spec reasons + # for Shapefile or dbf spec reasons ( "All field data is ASCII" + # https://en.wikipedia.org/wiki/.dbf#Database_records ) # "should be default ascii encoding" encoded = str_val.encode("ascii", self.encodingErrors) @@ -4142,7 +4302,10 @@ def _record(self, record: list[RecordValue]) -> None: f"Shapefile Writer unable to pack incorrect sized {value=}" f" (encoded as {len(encoded)}B) into field '{fieldName}' ({size}B)." ) - f.write(encoded) + record_stream.write(encoded) + + self.file.write(record_stream.getvalue()) + self.recNum += 1 class _ShpShxHeaderWriter(_HasCheckedWriteableFile): @@ -4438,6 +4601,7 @@ def __init__( *, encoding: str = "utf-8", encodingErrors: str = "strict", + strict: bool = False, shp: WriteSeekableBinStream | None = None, shx: WriteSeekableBinStream | None = None, dbf: WriteSeekableBinStream | None = None, @@ -4469,7 +4633,7 @@ def __init__( raise TypeError( "Unused kwargs were silently ignored by previous versions of PyShp. 
" "Either specify target (first positional only arg), " - "or shp and/or dbf, possible plus shx" + "or shp and/or dbf, possibly plus shx" ) self._shp = target.with_suffix(".shp") self._shx = target.with_suffix(".shx") @@ -4494,6 +4658,7 @@ def __init__( dbf=self._dbf, encoding=encoding, encodingErrors=encodingErrors, + strict=strict, ) self.exit_stack.enter_context(self._dbf_writer) diff --git a/tests/hypothesis_tests.py b/tests/hypothesis_tests.py index 6e848ef9..d7424fb8 100644 --- a/tests/hypothesis_tests.py +++ b/tests/hypothesis_tests.py @@ -6,7 +6,7 @@ import string import pytest -from hypothesis import HealthCheck, given, settings +from hypothesis import HealthCheck, given, settings, reproduce_failure from hypothesis.strategies import ( builds, composite, # Preferably avoid. Shrinking composite strategies is slow. @@ -543,7 +543,7 @@ def dbf_fields(draw): text( alphabet=characters( codec="ascii", - exclude_characters=["\x00"], + exclude_categories=["Z", "C"] # Z - Whitespace, C - Control chars++ ), min_size=1, max_size=10, @@ -567,10 +567,10 @@ def test_dbf_Field_roundtrips( ) -> None: expected = shp.Field.from_unchecked(**field_kwargs) stream = io.BytesIO() - encoded = expected.encode_field_descriptor(replace_ascii_spaces_with_underscores=False) + encoded = expected.encode_field_descriptor(strict=True) stream.write(encoded) stream.seek(0) - actual = shp.Field.from_byte_stream(stream, strip_leading_whitespace=False) + actual = shp.Field.from_byte_stream(stream, strict=True) assert isinstance(actual, shp.Field) assert actual.name == expected.name assert actual[1:] == expected[1:] @@ -637,17 +637,24 @@ def dbf_fields_and_records( return fields, records - @pytest.mark.hypothesis @given(fields_and_records=dbf_fields_and_records()) def test_dbf_reader_writer_roundtrip(fields_and_records)-> None: fields, records = fields_and_records stream = io.BytesIO() - with shp.DbfWriter(dbf=stream) as dbf_w: + written_records = [] + with shp.DbfWriter(dbf=stream, strict=True) 
as dbf_w: for field in fields: dbf_w.field(**field) for record in records: - dbf_w.record(*record) + try: + dbf_w.record(*record) + except shp.DbfStringDataLoss: + pass + else: + written_records.append(record) + + stream.seek(0) with shp.DbfReader(dbf=stream) as r: actual_fields = iter(r.fields) @@ -656,7 +663,7 @@ def test_dbf_reader_writer_roundtrip(fields_and_records)-> None: actual_field_dict = f_r._asdict() for k in ("field_type", "size", "decimal"): assert actual_field_dict[k] == f_w[k], f"{k=}, {actual_field_dict[k]=}, {f_w[k]=}" - for exp_rec, actual_rec in itertools.zip_longest(records, r.records()): + for exp_rec, actual_rec in itertools.zip_longest(written_records, r.records()): for expected, actual, field in itertools.zip_longest(exp_rec, actual_rec, fields): field_type = field["field_type"] decimal = field["decimal"] @@ -667,6 +674,4 @@ def test_dbf_reader_writer_roundtrip(fields_and_records)-> None: actual = actual.strftime("%Y%m%d") elif field_type in ("N", "F") and decimal >= 1: expected = float(format(expected, f".{decimal}f")) - elif field_type == "C": - expected = expected.strip() assert actual == expected, f"{actual=}, {expected=}, {field_type=}, {type(actual)=}, {type(expected)=}" diff --git a/tests/run_benchmarks.py b/tests/run_benchmarks.py index 31e2265e..d36f21b4 100644 --- a/tests/run_benchmarks.py +++ b/tests/run_benchmarks.py @@ -10,7 +10,7 @@ from os import PathLike from pathlib import Path from tempfile import TemporaryFile as TempF -from typing import cast +from typing import cast, Iterable import shapefile @@ -55,14 +55,14 @@ def benchmark( def open_shapefile_with_PyShp(target: str | PathLike): with shapefile.Reader(target) as r: - fields[target] = r.fields + fields[target] = r.data_fields for shapeRecord in r.iterShapeRecords(): shapeRecords[target].append(shapeRecord) def write_shapefile_with_PyShp(target: str | PathLike): with TempF("wb") as shp, TempF("wb") as dbf, TempF("wb") as shx: - with shapefile.Writer(shp=shp, dbf=dbf, 
shx=shx) as w: # type: ignore [arg-type] + with shapefile.Writer(shp=shp, dbf=dbf, shx=shx, strict=False) as w: # type: ignore [arg-type] for field_info_tuple in fields[target]: w.field(*field_info_tuple) for shapeRecord in shapeRecords[target]: @@ -87,32 +87,32 @@ def write_shapefile_with_PyShp(target: str | PathLike): COLS_WIDTHS = (22, 10) -reader_benchmarks = [ - functools.partial( +reader_benchmarks = { + test_name : functools.partial( benchmark, name=f"Read {test_name}", func=functools.partial(open_shapefile_with_PyShp, target=target), col_widths=COLS_WIDTHS, ) for test_name, target in SHAPEFILES.items() -] +} # Require fields and shapeRecords to first have been populated # from data from previouly running the reader_benchmarks -writer_benchmarks = [ - functools.partial( +writer_benchmarks = { + test_name : functools.partial( benchmark, name=f"Write {test_name}", func=functools.partial(write_shapefile_with_PyShp, target=target), col_widths=COLS_WIDTHS, ) for test_name, target in SHAPEFILES.items() -] +} def run( run_count: int, - benchmarks: list[Callable[[], None]], + benchmarks: Iterable[Callable[[], None]], col_widths: tuple[int, int] = COLS_WIDTHS, ) -> None: col_head = ("parser", "exec time", "performance (more is better)") @@ -125,9 +125,11 @@ def run( run_count=run_count, ) - -if __name__ == "__main__": +def main(): print("Reader tests:") - run(1, reader_benchmarks) # type: ignore [arg-type] + run(1, reader_benchmarks.values()) # type: ignore [arg-type] print("\n\nWriter tests:") - run(1, writer_benchmarks) # type: ignore [arg-type] + run(1, writer_benchmarks.values()) # type: ignore [arg-type] + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/tests/test_shapefile.py b/tests/test_shapefile.py index 9a6baffa..c49ff705 100644 --- a/tests/test_shapefile.py +++ b/tests/test_shapefile.py @@ -3,6 +3,7 @@ """ import datetime +import io import json import os.path from pathlib import Path @@ -1561,25 +1562,39 @@ def 
test_reader_zip_polyylinez_no_m_itershaperecords(): pass -def test_write_field_name_limit(tmpdir): - """ - Abc... - """ +def test_write_field_name_below_limit(tmpdir): filename = tmpdir.join("test.shp").strpath - with shapefile.Writer(filename) as writer: + with shapefile.Writer(filename, strict=True) as writer: writer.field("a" * 5, "C") # many under length limit writer.field("a" * 9, "C") # 1 under length limit - writer.field("a" * 10, "C") # at length limit - writer.field("a" * 11, "C") # 1 over length limit - writer.field("a" * 20, "C") # many over limit with shapefile.Reader(filename) as reader: fields = reader.fields[1:] assert len(fields[0][0]) == 5 assert len(fields[1][0]) == 9 + +def test_write_field_names_above_limit_non_strict(tmpdir): + filename = tmpdir.join("test.shp").strpath + with shapefile.Writer(filename, strict=False) as w: + w.field("a" * 10, "C") # at length limit + for l in [11, 20]: # 1 over, and twice the limit + with pytest.warns(shapefile.PossibleDataLoss): + w.field("a" * l, "C") + + with shapefile.Reader(filename) as reader: + fields = reader.fields[1:] + assert len(fields[0][0]) == 10 + assert len(fields[1][0]) == 10 assert len(fields[2][0]) == 10 - assert len(fields[3][0]) == 10 - assert len(fields[4][0]) == 10 + +def test_write_field_names_above_limit_strict(tmpdir): + filename = tmpdir.join("test.shp").strpath + with shapefile.Writer(filename, strict=True) as writer: + writer.field("a" * 10, "C") # at length limit + for l in [11, 20]: # at 1 over length limitand twice the limit + with pytest.raises(ValueError): + writer.field("a" * l, "C") + def test_write_shp_only(tmpdir): @@ -2021,3 +2036,41 @@ def test_write_multipatch(tmpdir): w.record("house1") w.close() + +DATES = [datetime.date(*triple) for triple in [ + (2000,1,1), +]] + +@pytest.mark.parametrize("expected_date", DATES) +def test_round_trip_dbf_date_record(expected_date): + stream = io.BytesIO() + with shapefile.DbfWriter(dbf=stream) as dbf_w: + dbf_w.field("Date","D") + 
dbf_w.record(expected_date) + stream.seek(0) + with shapefile.DbfReader(dbf=stream) as dbf_r: + assert dbf_r.record(0)[0] == expected_date + + +LONG_FIELD_NAME_TESTS = [ + ("ÀÀÀÀ०", 8, "utf-8", "strict"), +] + +@pytest.mark.parametrize("name,encoded_len,codec,errors", LONG_FIELD_NAME_TESTS) +def test_encode_dbf_field_too_long_names(name,encoded_len,codec,errors): + stream = io.BytesIO() + with shapefile.DbfWriter( + stream, + encoding=codec, + encodingErrors=errors, + strict = False, + ) as w: + with pytest.warns(shapefile.PossibleDataLoss): + w.field(name=name) + field = w.fields[0] + assert name.startswith(field.name) + assert len(w.fields[0].name.encode(codec, errors)) == encoded_len + + stream.seek(0) + with shapefile.DbfReader(stream) as r: + assert r.fields[1].name == field.name \ No newline at end of file