Coverage for src / kdbxtool / parsing / context.py: 100%

83 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2026-01-20 19:19 +0000

1"""Parsing and building context classes for binary data. 

2 

3This module provides structured helpers for reading and writing binary data 

4with good error messages and type safety. 

5""" 

6 

7from __future__ import annotations 

8 

9import struct 

10from collections.abc import Iterator 

11from contextlib import contextmanager 

12from dataclasses import dataclass, field 

13from typing import cast 

14 

15from kdbxtool.exceptions import CorruptedDataError 

16 

17 

@dataclass
class ParseContext:
    """Stateful reader for binary data with error context tracking.

    Tracks the current offset and a path of nested scopes for error messages.
    When parsing fails, error messages include the full path to the failure point.

    Example:
        ctx = ParseContext(data)
        with ctx.scope("header"):
            magic = ctx.read(8, "magic")
            with ctx.scope("version"):
                major = ctx.read_u16("major")
        # For a 9-byte ``data`` the error reads:
        # "Unexpected EOF at header/version/major, offset 8, need 2 bytes, have 1"
        # (the reported offset is the position *before* the failed read).
    """

    data: bytes
    offset: int = 0
    # Stack of scope names pushed/popped by scope(); joined with "/" in errors.
    _path: list[str] = field(default_factory=list)

    def read(self, n: int, name: str = "") -> bytes:
        """Read n bytes from current position and advance the offset.

        Args:
            n: Number of bytes to read (must be non-negative)
            name: Optional name for error messages

        Returns:
            The bytes read

        Raises:
            CorruptedDataError: If n is negative or not enough bytes available
        """
        # Without this guard a negative count slices backwards (e.g. data[0:-1])
        # and moves the offset before the start of the data.
        self._check_count(n, name)
        end = self.offset + n
        if end > len(self.data):
            location = self._format_location(name)
            raise CorruptedDataError(
                f"Unexpected EOF at {location}, offset {self.offset}, "
                f"need {n} bytes, have {len(self.data) - self.offset}"
            )
        result = self.data[self.offset : end]
        self.offset = end
        return result

    def read_u8(self, name: str = "") -> int:
        """Read unsigned 8-bit integer."""
        return self.read(1, name)[0]

    def read_u16(self, name: str = "") -> int:
        """Read unsigned 16-bit little-endian integer."""
        return cast(int, struct.unpack("<H", self.read(2, name))[0])

    def read_u32(self, name: str = "") -> int:
        """Read unsigned 32-bit little-endian integer."""
        return cast(int, struct.unpack("<I", self.read(4, name))[0])

    def read_u64(self, name: str = "") -> int:
        """Read unsigned 64-bit little-endian integer."""
        return cast(int, struct.unpack("<Q", self.read(8, name))[0])

    def read_bytes_prefixed(self, name: str = "") -> bytes:
        """Read length-prefixed bytes (4-byte little-endian length prefix).

        Args:
            name: Optional name for error messages

        Returns:
            The bytes read (not including the length prefix)

        Raises:
            CorruptedDataError: If the prefix or the payload runs past the end
        """
        length = self.read_u32(f"{name}.length" if name else "length")
        return self.read(length, f"{name}.data" if name else "data")

    def peek(self, n: int) -> bytes:
        """Peek at next n bytes without advancing offset.

        Args:
            n: Number of bytes to peek (must be non-negative)

        Returns:
            The bytes (may be shorter if near end of data)

        Raises:
            CorruptedDataError: If n is negative
        """
        self._check_count(n)
        return self.data[self.offset : self.offset + n]

    def skip(self, n: int, name: str = "") -> None:
        """Skip n bytes.

        Args:
            n: Number of bytes to skip (must be non-negative)
            name: Optional name for error messages

        Raises:
            CorruptedDataError: If n is negative or not enough bytes available
        """
        self.read(n, name)

    @contextmanager
    def scope(self, name: str) -> Iterator[None]:
        """Create a named scope for error context.

        The name is pushed onto the error path for the duration of the
        ``with`` block and popped again even if the block raises.

        Args:
            name: Scope name to add to error path

        Example:
            with ctx.scope("inner_header"):
                field_type = ctx.read_u8("type")
        """
        self._path.append(name)
        try:
            yield
        finally:
            self._path.pop()

    @property
    def remaining(self) -> int:
        """Number of bytes remaining to read."""
        return len(self.data) - self.offset

    @property
    def exhausted(self) -> bool:
        """True if all bytes have been read."""
        return self.offset >= len(self.data)

    @property
    def position(self) -> int:
        """Current read position (alias for offset)."""
        return self.offset

    def _check_count(self, n: int, name: str = "") -> None:
        """Raise CorruptedDataError for a negative byte count."""
        if n < 0:
            location = self._format_location(name)
            raise CorruptedDataError(
                f"Invalid negative byte count {n} at {location}, "
                f"offset {self.offset}"
            )

    def _format_location(self, name: str = "") -> str:
        """Format current location for error messages."""
        parts = self._path.copy()
        if name:
            parts.append(name)
        return "/".join(parts) if parts else "<root>"

150 

151 

@dataclass
class BuildContext:
    """Stateful writer for building binary data.

    Accumulates byte chunks in a list and joins them once in ``build()``.
    Because joining is deferred, each chunk is stored as immutable ``bytes``:
    mutating a ``bytearray`` (or the buffer behind a ``memoryview``) after
    passing it to a write method does not change the built output.

    Example:
        ctx = BuildContext()
        ctx.write(MAGIC_BYTES)
        ctx.write_u32(version)
        ctx.write_tlv(FIELD_TYPE, field_data)
        result = ctx.build()
    """

    _parts: list[bytes] = field(default_factory=list)

    @staticmethod
    def _snapshot(data: bytes) -> bytes:
        """Return data as immutable bytes, copying only non-bytes buffers."""
        if isinstance(data, bytes):
            # bytes objects are already immutable; store them without copying.
            return data
        # memoryview() rejects non-buffer objects (int, str) with TypeError;
        # bytes(int) would otherwise silently produce zero-filled output.
        return bytes(memoryview(data))

    def write(self, data: bytes) -> None:
        """Write raw bytes.

        Args:
            data: A bytes-like object; non-``bytes`` buffers are copied.

        Raises:
            TypeError: If data does not support the buffer protocol.
        """
        self._parts.append(self._snapshot(data))

    def write_u8(self, value: int) -> None:
        """Write unsigned 8-bit integer."""
        self._parts.append(bytes([value]))

    def write_u16(self, value: int) -> None:
        """Write unsigned 16-bit little-endian integer."""
        self._parts.append(struct.pack("<H", value))

    def write_u32(self, value: int) -> None:
        """Write unsigned 32-bit little-endian integer."""
        self._parts.append(struct.pack("<I", value))

    def write_u64(self, value: int) -> None:
        """Write unsigned 64-bit little-endian integer."""
        self._parts.append(struct.pack("<Q", value))

    def write_bytes_prefixed(self, data: bytes) -> None:
        """Write length-prefixed bytes (4-byte little-endian length prefix).

        The prefix is the byte length of the snapshot, so it is correct even
        for buffers whose items are wider than one byte.
        """
        payload = self._snapshot(data)
        self.write_u32(len(payload))
        self.write(payload)

    def write_tlv(self, type_id: int, data: bytes, type_size: int = 1) -> None:
        """Write Type-Length-Value field.

        Layout: type_id (type_size bytes, little-endian), a 4-byte
        little-endian length, then the data itself.

        Args:
            type_id: Field type identifier
            data: Field data
            type_size: Size of type field in bytes (1 or 2)

        Raises:
            ValueError: If type_size is not 1 or 2
            TypeError: If data does not support the buffer protocol
        """
        # Validate everything before writing so a failure leaves no partial field.
        if type_size not in (1, 2):
            raise ValueError(f"Unsupported type_size: {type_size}")
        payload = self._snapshot(data)
        if type_size == 1:
            self.write_u8(type_id)
        else:
            self.write_u16(type_id)
        self.write_bytes_prefixed(payload)

    def build(self) -> bytes:
        """Join all accumulated bytes and return result."""
        return b"".join(self._parts)

    @property
    def size(self) -> int:
        """Total size of accumulated bytes."""
        return sum(len(p) for p in self._parts)