diff --git a/src/Asn1Decode.sol b/src/Asn1Decode.sol index 97cd4e4..f6cb620 100644 --- a/src/Asn1Decode.sol +++ b/src/Asn1Decode.sol @@ -58,6 +58,11 @@ library LibAsn1Ptr { } library Asn1Decode { + error InvalidAsn1Length(); + error InvalidAsn1Type(); + error InvalidAsn1Value(); + error UnsupportedAsn1Tag(); + using LibAsn1Ptr for Asn1Ptr; using LibBytes for bytes; @@ -97,7 +102,7 @@ library Asn1Decode { * @return A pointer to the first child node */ function firstChildOf(bytes memory der, Asn1Ptr ptr) internal pure returns (Asn1Ptr) { - require(der[ptr.header()] & 0x20 == 0x20, "Not a constructed type"); + if (der[ptr.header()] & 0x20 != 0x20) revert InvalidAsn1Type(); return readNodeLength(der, ptr.content()); } @@ -108,9 +113,9 @@ library Asn1Decode { * @return A pointer to a bitstring */ function bitstring(bytes memory der, Asn1Ptr ptr) internal pure returns (Asn1Ptr) { - require(der[ptr.header()] == 0x03, "Not type BIT STRING"); + if (der[ptr.header()] != 0x03) revert InvalidAsn1Type(); // Only 00 padded bitstr can be converted to bytestr! - require(der[ptr.content()] == 0x00, "Non-0-padded BIT STRING"); + if (der[ptr.content()] != 0x00) revert InvalidAsn1Value(); return LibAsn1Ptr.toAsn1Ptr(ptr.header(), ptr.content() + 1, ptr.length() - 1); } @@ -122,22 +127,22 @@ library Asn1Decode { * significant byte, so X.509 bit masks are stable across multi-octet encodings. */ function bitstringUintAt(bytes memory der, Asn1Ptr ptr) internal pure returns (uint256) { - require(der[ptr.header()] == 0x03, "Not type BIT STRING"); - require(ptr.length() > 0, "invalid BIT STRING length"); + if (der[ptr.header()] != 0x03) revert InvalidAsn1Type(); + if (ptr.length() == 0) revert InvalidAsn1Length(); uint256 unusedBits = uint8(der[ptr.content()]); - require(unusedBits <= 7, "invalid BIT STRING padding"); + if (unusedBits > 7) revert InvalidAsn1Value(); uint256 len = ptr.length() - 1; - require(len <= 32, "BIT STRING too long"); + if (len > 32) revert InvalidAsn1Length(); if (len == 0) { - require(unusedBits == 0, "invalid BIT STRING padding"); + if (unusedBits != 0) revert InvalidAsn1Value(); return 0; } if (unusedBits != 0) { uint8 unusedMask = uint8((uint256(1) << unusedBits) - 1); - require(uint8(der[ptr.content() + len]) & unusedMask == 0, "Non-zero unused BIT STRING bits"); + if (uint8(der[ptr.content() + len]) & unusedMask != 0) revert InvalidAsn1Value(); } uint256 value; @@ -154,7 +159,7 @@ library Asn1Decode { * @return A pointer to an octet string */ function octetString(bytes memory der, Asn1Ptr ptr) internal pure returns (Asn1Ptr) { - require(der[ptr.header()] == 0x04, "Not type OCTET STRING"); + if (der[ptr.header()] != 0x04) revert InvalidAsn1Type(); return readNodeLength(der, ptr.content()); } @@ -165,8 +170,8 @@ library Asn1Decode { * @return Uint value of node */ function uintAt(bytes memory der, Asn1Ptr ptr) internal pure returns (uint256) { - require(der[ptr.header()] == 0x02, "Not type INTEGER"); - require(der[ptr.content()] & 0x80 == 0, "Not positive"); + if (der[ptr.header()] != 0x02) revert InvalidAsn1Type(); + if (der[ptr.content()] & 0x80 != 0) revert InvalidAsn1Value(); uint256 len = ptr.length(); return uint256(readBytesN(der, ptr.content(), len) >> (32 - len) * 8); } @@ -178,8 +183,8 @@ library Asn1Decode { * @return 384-bit uint encoded in uint128 and uint256 */ function uint384At(bytes memory der, Asn1Ptr ptr) internal pure returns (uint128, uint256) { - require(der[ptr.header()] == 0x02, "Not type INTEGER"); - require(der[ptr.content()] & 0x80 == 0, "Not positive"); + if (der[ptr.header()] != 0x02) revert InvalidAsn1Type(); + if (der[ptr.content()] & 0x80 != 0) revert InvalidAsn1Value(); uint256 valueLength = ptr.length(); uint256 start = ptr.content(); if (der[start] == 0) { @@ -205,12 +210,12 @@ library Asn1Decode { uint256 length = ptr.length(); // content validation: - require((_type == 0x17 && length == 13) || (_type == 0x18 && length == 15), "Invalid TIMESTAMP"); - require(der[offset + length - 1] == 0x5A, "TIMESTAMP must be UTC"); // 0x5A == 'Z' + if ((_type != 0x17 || length != 13) && (_type != 0x18 || length != 15)) revert InvalidAsn1Value(); + if (der[offset + length - 1] != 0x5A) revert InvalidAsn1Value(); // 0x5A == 'Z' for (uint256 i = 0; i < length - 1; i++) { // all other characters must be digits between 0 and 9 uint8 v = uint8(der[offset + i]); - require(48 <= v && v <= 57, "Invalid character in TIMESTAMP"); + if (v < 48 || v > 57) revert InvalidAsn1Value(); } uint16 _years; @@ -231,7 +236,8 @@ library Asn1Decode { } function readNodeLength(bytes memory der, uint256 ix) private pure returns (Asn1Ptr) { - require(der[ix] & 0x1f != 0x1f, "ASN.1 tags longer than 1-byte are not supported"); + if (ix + 1 >= der.length) revert InvalidAsn1Length(); + if (der[ix] & 0x1f == 0x1f) revert UnsupportedAsn1Tag(); uint256 length; uint256 ixFirstContentByte; if ((der[ix + 1] & 0x80) == 0) { @@ -239,6 +245,10 @@ library Asn1Decode { ixFirstContentByte = ix + 2; } else { uint8 lengthbytesLength = uint8(der[ix + 1] & 0x7F); + if (lengthbytesLength == 0 || lengthbytesLength > 32 || ix + 2 + lengthbytesLength > der.length) { + revert InvalidAsn1Length(); + } + if (der[ix + 2] == 0) revert InvalidAsn1Length(); if (lengthbytesLength == 1) { length = uint8(der[ix + 2]); } else if (lengthbytesLength == 2) { @@ -247,8 +257,10 @@ library Asn1Decode { length = uint256(readBytesN(der, ix + 2, lengthbytesLength) >> (32 - lengthbytesLength) * 8); require(length <= 2 ** 64 - 1); // bound to max uint64 to be safe } + if (length < 128) revert InvalidAsn1Length(); ixFirstContentByte = ix + 2 + lengthbytesLength; } + if (ixFirstContentByte + length > der.length) revert InvalidAsn1Length(); return LibAsn1Ptr.toAsn1Ptr(ix, ixFirstContentByte, length); } diff --git a/test/Asn1Decode.t.sol b/test/Asn1Decode.t.sol index ae44910..8eb667b 100644 --- a/test/Asn1Decode.t.sol +++ b/test/Asn1Decode.t.sol @@ -47,7 +47,7 @@ contract Asn1DecodeTest is Test { // --- readNodeLength / tag handling --- function test_root_multiByteTag_reverts() public { - vm.expectRevert("ASN.1 tags longer than 1-byte are not supported"); + vm.expectRevert(Asn1Decode.UnsupportedAsn1Tag.selector); h.rootLength(hex"1f00"); // low tag bits 0x1f == high-tag-number form } @@ -62,6 +62,28 @@ contract Asn1DecodeTest is Test { h.rootLength(hex"0289ffffffffffffffffff"); // INTEGER, 9 length bytes all 0xff } + function test_root_indefiniteLength_reverts() public { + vm.expectRevert(Asn1Decode.InvalidAsn1Length.selector); + h.rootLength(hex"0480"); // DER requires definite lengths + } + + function test_root_longFormForShortLength_reverts() public { + vm.expectRevert(Asn1Decode.InvalidAsn1Length.selector); + h.rootLength(hex"04810100"); // length 1 must use short form 0x01 + } + + function test_root_longFormLeadingZero_reverts() public { + vm.expectRevert(Asn1Decode.InvalidAsn1Length.selector); + h.rootLength(hex"04820080"); // length 128 must be 0x81 0x80, not 0x82 0x00 0x80 + } + + function test_root_canonicalLongFormLength() public view { + bytes memory der = abi.encodePacked(bytes3(0x048180), new bytes(128)); + + assertEq(h.rootLength(der), 128); + assertEq(h.rootContent(der), 3); + } + // --- uintAt --- function test_uintAt_value() public view { @@ -69,12 +91,12 @@ contract Asn1DecodeTest is Test { } function test_uintAt_notInteger_reverts() public { - vm.expectRevert("Not type INTEGER"); + vm.expectRevert(Asn1Decode.InvalidAsn1Type.selector); h.uintAtRoot(hex"0401ff"); // OCTET STRING, not INTEGER } function test_uintAt_negative_reverts() public { - vm.expectRevert("Not positive"); + vm.expectRevert(Asn1Decode.InvalidAsn1Value.selector); h.uintAtRoot(hex"020180"); // high bit set } @@ -98,23 +120,23 @@ contract Asn1DecodeTest is Test { function test_timestamp_wrongType_reverts() public { bytes memory der = abi.encodePacked(hex"160d", bytes("700101000000Z")); // type 0x16 - vm.expectRevert("Invalid TIMESTAMP"); + vm.expectRevert(Asn1Decode.InvalidAsn1Value.selector); h.timestampAtRoot(der); } function test_timestamp_wrongLength_reverts() public { bytes memory der = abi.encodePacked(hex"170c", bytes("70010100000Z")); // UTCTime, length 12 - vm.expectRevert("Invalid TIMESTAMP"); + vm.expectRevert(Asn1Decode.InvalidAsn1Value.selector); h.timestampAtRoot(der); } function test_timestamp_missingZ_reverts() public { - vm.expectRevert("TIMESTAMP must be UTC"); + vm.expectRevert(Asn1Decode.InvalidAsn1Value.selector); h.timestampAtRoot(_utcTime("700101000000X")); } function test_timestamp_nonDigit_reverts() public { - vm.expectRevert("Invalid character in TIMESTAMP"); + vm.expectRevert(Asn1Decode.InvalidAsn1Value.selector); h.timestampAtRoot(_utcTime("7A0101000000Z")); } @@ -126,12 +148,12 @@ contract Asn1DecodeTest is Test { } function test_bitstring_notBitString_reverts() public { - vm.expectRevert("Not type BIT STRING"); + vm.expectRevert(Asn1Decode.InvalidAsn1Type.selector); h.bitstringContent(hex"0401ff"); } function test_bitstring_nonZeroPadded_reverts() public { - vm.expectRevert("Non-0-padded BIT STRING"); + vm.expectRevert(Asn1Decode.InvalidAsn1Value.selector); h.bitstringContent(hex"03020100"); // pad byte is 0x01, not 0x00 } @@ -156,24 +178,24 @@ contract Asn1DecodeTest is Test { } function test_bitstringUintAt_nonZeroUnusedBits_reverts() public { - vm.expectRevert("Non-zero unused BIT STRING bits"); + vm.expectRevert(Asn1Decode.InvalidAsn1Value.selector); h.bitstringUintAtRoot(hex"03030700ff"); } function test_bitstringUintAt_invalidUnusedBits_reverts() public { - vm.expectRevert("invalid BIT STRING padding"); + vm.expectRevert(Asn1Decode.InvalidAsn1Value.selector); h.bitstringUintAtRoot(hex"03020880"); } function test_bitstringUintAt_missingUnusedBits_reverts() public { - vm.expectRevert("invalid BIT STRING length"); + vm.expectRevert(Asn1Decode.InvalidAsn1Length.selector); h.bitstringUintAtRoot(hex"0300"); } // --- firstChildOf --- function test_firstChildOf_notConstructed_reverts() public { - vm.expectRevert("Not a constructed type"); + vm.expectRevert(Asn1Decode.InvalidAsn1Type.selector); h.firstChildHeader(hex"0401ff"); // OCTET STRING is primitive, not constructed } diff --git a/test/CertManager.t.sol b/test/CertManager.t.sol index cc40cfd..91b25ad 100644 --- a/test/CertManager.t.sol +++ b/test/CertManager.t.sol @@ -81,7 +81,7 @@ contract CertManagerTest is Test { } function test_BasicConstraintsRejectsOutOfBoundsChild() public { - vm.expectRevert("basicConstraints out of bounds"); + vm.expectRevert(Asn1Decode.InvalidAsn1Length.selector); certManagerHarness.verifyBasicConstraints(hex"3003020200", false); } diff --git a/test/hinted/HintedNitroAttestation.t.sol b/test/hinted/HintedNitroAttestation.t.sol index 3606bbd..5e3c827 100644 --- a/test/hinted/HintedNitroAttestation.t.sol +++ b/test/hinted/HintedNitroAttestation.t.sol @@ -483,6 +483,19 @@ contract HintedNitroAttestationTest is Test { certManager.verifyCACertWithHints(abi.encodePacked(caCert, bytes1(0x00)), parentHash, ""); } + function test_HintedCACertRejectsNonCanonicalOuterLength() public { + bytes memory attestation = _repairMissingPublicKeyBytes(_decodeBase64(_realAttestationB64())); + (bytes memory attestationTbs,) = validator.decodeAttestationTbs(attestation); + NitroValidator.Ptrs memory ptrs = parser.parseAttestation(attestationTbs); + (bytes memory caCert, bytes32 parentHash,) = _firstNonRootCA(attestationTbs, ptrs); + bytes memory nonCanonicalCert = _nonCanonicalOuterSequenceLength(caCert); + bytes32 nonCanonicalHash = keccak256(nonCanonicalCert); + + vm.expectRevert(Asn1Decode.InvalidAsn1Length.selector); + certManager.verifyCACertWithHints(nonCanonicalCert, parentHash, ""); + assertEq(certManager.loadVerified(nonCanonicalHash).pubKey.length, 0, "non-canonical cert must not cache"); + } + function test_HintedTrailingRootBytesCannotPoisonParentCache() public { bytes memory attestation = _repairMissingPublicKeyBytes(_decodeBase64(_realAttestationB64())); (bytes memory attestationTbs,) = validator.decodeAttestationTbs(attestation); @@ -1266,6 +1279,21 @@ contract HintedNitroAttestationTest is Test { output[3] = bytes1(uint8(length)); } + function _nonCanonicalOuterSequenceLength(bytes memory der) internal pure returns (bytes memory output) { + require(der.length >= 4 && der[0] == 0x30 && der[1] == 0x82, "test: expected long sequence"); + + output = new bytes(der.length + 1); + output[0] = der[0]; + output[1] = 0x83; + output[2] = 0x00; + output[3] = der[2]; + output[4] = der[3]; + + for (uint256 i = 4; i < der.length; ++i) { + output[i + 1] = der[i]; + } + } + function _repairMissingPublicKeyBytes(bytes memory attestation) internal pure returns (bytes memory repaired) { // The pasted Base64 sample is missing "ic_" in the CBOR key // "public_key", but the key length and outer COSE payload length still