Only use colors from the palette file

This commit is contained in:
Andrew Murray 2022-10-06 11:43:20 +11:00
parent 681f77a293
commit 4642e02571
2 changed files with 16 additions and 34 deletions

View File

@ -29,17 +29,8 @@ def test_get_palette():
palette, mode = palette_file.getpalette() palette, mode = palette_file.getpalette()
# Assert # Assert
assert mode == "RGB" expected_palette = b""
for color in (
def test_palette__has_correct_color_indexes():
# Arrange
with open("Tests/images/custom_gimp_palette.gpl", "rb") as fp:
palette_file = GimpPaletteFile(fp)
palette, mode = palette_file.getpalette()
colors_in_test_palette = [
(0, 0, 0), (0, 0, 0),
(65, 38, 30), (65, 38, 30),
(103, 62, 49), (103, 62, 49),
@ -48,15 +39,7 @@ def test_palette__has_correct_color_indexes():
(208, 127, 100), (208, 127, 100),
(151, 144, 142), (151, 144, 142),
(221, 207, 199), (221, 207, 199),
] ):
expected_palette += bytes(color)
for i, color in enumerate(colors_in_test_palette): assert palette == expected_palette
assert tuple(palette[i * 3 : i * 3 + 3]) == color assert mode == "RGB"
def test_palette_counts_number_of_colors_in_file():
# Arrange
with open("Tests/images/custom_gimp_palette.gpl", "rb") as fp:
palette_file = GimpPaletteFile(fp)
assert palette_file.n_colors == 8

View File

@ -26,28 +26,27 @@ class GimpPaletteFile:
def __init__(self, fp): def __init__(self, fp):
palette = bytearray(b"".join([o8(i) * 3 for i in range(256)]))
if fp.readline()[:12] != b"GIMP Palette": if fp.readline()[:12] != b"GIMP Palette":
raise SyntaxError("not a GIMP palette file") raise SyntaxError("not a GIMP palette file")
index = 0 self.palette = b""
for s in fp: while len(self.palette) < 768:
s = fp.readline()
if not s:
break
# skip fields and comment lines # skip fields and comment lines
if re.match(rb"\w+:|#", s): if re.match(rb"\w+:|#", s):
continue continue
if len(s) > 100: if len(s) > 100:
raise SyntaxError("bad palette file") raise SyntaxError("bad palette file")
v = tuple(map(int, s.split()[:3])) v = s.split()
if len(v) < 3: if len(v) < 3:
raise ValueError("bad palette entry") raise ValueError("bad palette entry")
for i in range(3):
palette[index * 3 : index * 3 + 3] = v self.palette += o8(int(v[i]))
index += 1
self.palette = bytes(palette)
self.n_colors = index
def getpalette(self): def getpalette(self):