Coverage for src / kdbxtool / parsing / context.py: 100%
83 statements
« prev ^ index » next coverage.py v7.12.0, created at 2026-01-20 19:19 +0000
« prev ^ index » next coverage.py v7.12.0, created at 2026-01-20 19:19 +0000
1"""Parsing and building context classes for binary data.
3This module provides structured helpers for reading and writing binary data
4with good error messages and type safety.
5"""
7from __future__ import annotations
9import struct
10from collections.abc import Iterator
11from contextlib import contextmanager
12from dataclasses import dataclass, field
13from typing import cast
15from kdbxtool.exceptions import CorruptedDataError
@dataclass
class ParseContext:
    """Stateful reader for binary data with error context tracking.

    Tracks the current offset and a path of nested scopes for error messages.
    When parsing fails, error messages include the full path to the failure point.

    Example:
        ctx = ParseContext(data)
        with ctx.scope("header"):
            magic = ctx.read(8, "magic")
            with ctx.scope("version"):
                major = ctx.read_u16("major")
        # With only 9 bytes of data the error would show:
        # "Unexpected EOF at header/version/major, offset 8, need 2 bytes, have 1"
    """

    data: bytes
    offset: int = 0
    # Stack of names pushed by scope(); joined with "/" in error messages.
    _path: list[str] = field(default_factory=list)

    def read(self, n: int, name: str = "") -> bytes:
        """Read n bytes from current position and advance the offset.

        Args:
            n: Number of bytes to read (must be non-negative)
            name: Optional name for error messages

        Returns:
            The bytes read

        Raises:
            ValueError: If n is negative
            CorruptedDataError: If not enough bytes available
        """
        if n < 0:
            # Without this guard a negative n slips past the EOF check below,
            # returns a bogus slice and moves the offset backwards.
            raise ValueError(
                f"Cannot read a negative number of bytes ({n}) "
                f"at {self._format_location(name)}"
            )
        available = len(self.data) - self.offset
        if n > available:
            raise CorruptedDataError(
                f"Unexpected EOF at {self._format_location(name)}, "
                f"offset {self.offset}, need {n} bytes, have {available}"
            )
        result = self.data[self.offset : self.offset + n]
        self.offset += n
        return result

    def read_u8(self, name: str = "") -> int:
        """Read unsigned 8-bit integer."""
        return self.read(1, name)[0]

    def read_u16(self, name: str = "") -> int:
        """Read unsigned 16-bit little-endian integer."""
        return cast(int, struct.unpack("<H", self.read(2, name))[0])

    def read_u32(self, name: str = "") -> int:
        """Read unsigned 32-bit little-endian integer."""
        return cast(int, struct.unpack("<I", self.read(4, name))[0])

    def read_u64(self, name: str = "") -> int:
        """Read unsigned 64-bit little-endian integer."""
        return cast(int, struct.unpack("<Q", self.read(8, name))[0])

    def read_bytes_prefixed(self, name: str = "") -> bytes:
        """Read length-prefixed bytes (4-byte little-endian length prefix).

        Args:
            name: Optional name for error messages

        Returns:
            The bytes read (not including the length prefix)

        Raises:
            CorruptedDataError: If the prefix or the payload is truncated
        """
        length = self.read_u32(f"{name}.length" if name else "length")
        return self.read(length, f"{name}.data" if name else "data")

    def peek(self, n: int) -> bytes:
        """Peek at next n bytes without advancing offset.

        Args:
            n: Number of bytes to peek (must be non-negative)

        Returns:
            The bytes (may be shorter if near end of data)

        Raises:
            ValueError: If n is negative
        """
        if n < 0:
            # A negative n would slice relative to the end of data and
            # return unrelated bytes.
            raise ValueError(f"Cannot peek a negative number of bytes ({n})")
        return self.data[self.offset : self.offset + n]

    def skip(self, n: int, name: str = "") -> None:
        """Skip n bytes.

        Args:
            n: Number of bytes to skip (must be non-negative)
            name: Optional name for error messages

        Raises:
            ValueError: If n is negative
            CorruptedDataError: If not enough bytes available
        """
        self.read(n, name)

    @contextmanager
    def scope(self, name: str) -> Iterator[None]:
        """Create a named scope for error context.

        The name is popped again even if the body raises.

        Args:
            name: Scope name to add to error path

        Example:
            with ctx.scope("inner_header"):
                field_type = ctx.read_u8("type")
        """
        self._path.append(name)
        try:
            yield
        finally:
            self._path.pop()

    @property
    def remaining(self) -> int:
        """Number of bytes remaining to read."""
        return len(self.data) - self.offset

    @property
    def exhausted(self) -> bool:
        """True if all bytes have been read."""
        return self.offset >= len(self.data)

    @property
    def position(self) -> int:
        """Current read position (alias for offset)."""
        return self.offset

    def _format_location(self, name: str = "") -> str:
        """Format current location (scope path plus optional name) for errors."""
        parts = self._path.copy()
        if name:
            parts.append(name)
        return "/".join(parts) if parts else "<root>"
@dataclass
class BuildContext:
    """Incremental writer that assembles a binary blob chunk by chunk.

    Every write appends one bytes chunk to an internal list; build()
    concatenates all chunks with a single join.

    Example:
        ctx = BuildContext()
        ctx.write(MAGIC_BYTES)
        ctx.write_u32(version)
        ctx.write_tlv(FIELD_TYPE, field_data)
        result = ctx.build()
    """

    # Chunks written so far, kept in write order.
    _parts: list[bytes] = field(default_factory=list)

    def write(self, data: bytes) -> None:
        """Append a chunk of raw bytes as-is."""
        self._parts.append(data)

    def write_u8(self, value: int) -> None:
        """Append one unsigned byte."""
        self.write(bytes([value]))

    def write_u16(self, value: int) -> None:
        """Append an unsigned 16-bit integer, little-endian."""
        self.write(struct.pack("<H", value))

    def write_u32(self, value: int) -> None:
        """Append an unsigned 32-bit integer, little-endian."""
        self.write(struct.pack("<I", value))

    def write_u64(self, value: int) -> None:
        """Append an unsigned 64-bit integer, little-endian."""
        self.write(struct.pack("<Q", value))

    def write_bytes_prefixed(self, data: bytes) -> None:
        """Append data preceded by its length encoded as a little-endian u32."""
        self.write_u32(len(data))
        self.write(data)

    def write_tlv(self, type_id: int, data: bytes, type_size: int = 1) -> None:
        """Append a Type-Length-Value record.

        The type is written as a u8 or little-endian u16 depending on
        type_size; the length is always a little-endian u32.

        Args:
            type_id: Field type identifier
            data: Field payload
            type_size: Width of the type field in bytes (1 for KDBX4, or 2)

        Raises:
            ValueError: If type_size is neither 1 nor 2 (nothing is written)
        """
        if type_size == 1:
            emit_type = self.write_u8
        elif type_size == 2:
            emit_type = self.write_u16
        else:
            raise ValueError(f"Unsupported type_size: {type_size}")
        emit_type(type_id)
        self.write_bytes_prefixed(data)

    def build(self) -> bytes:
        """Return every chunk written so far concatenated into one bytes object."""
        return b"".join(self._parts)

    @property
    def size(self) -> int:
        """Total number of bytes written so far."""
        return sum(map(len, self._parts))