import heapq

class Node:
    """A node of a Huffman tree.

    Leaves carry a ``symbol``; internal nodes carry ``left``/``right`` children
    and have ``symbol`` left as ``None``.
    """

    def __init__(self, symbol=None, freq=0, left=None, right=None):
        # Symbol stored at a leaf (None for internal nodes).
        self.symbol, self.freq = symbol, freq
        # Child subtrees; None for leaves.
        self.left, self.right = left, right

    def __lt__(self, other):
        """Order nodes by frequency so a min-heap yields the rarest node first."""
        return other.freq > self.freq

def build_huffman_tree(data: bytes):
    """Build a Huffman tree from the symbol frequencies in *data*.

    Leaves are created in first-occurrence order. The two lowest-frequency
    nodes are repeatedly merged until a single root remains.

    Returns:
        The root ``Node``, or ``None`` when *data* is empty.
    """
    # dict.fromkeys keeps first-occurrence order, which fixes heap tie-breaking.
    counts = dict.fromkeys(data, 0)
    for byte in data:
        counts[byte] += 1

    forest = [Node(symbol=sym, freq=cnt) for sym, cnt in counts.items()]
    if not forest:
        return None

    heapq.heapify(forest)
    while len(forest) >= 2:
        lo, hi = heapq.heappop(forest), heapq.heappop(forest)
        heapq.heappush(forest, Node(freq=lo.freq + hi.freq, left=lo, right=hi))
    return forest[0]

def build_code_table(node, prefix="", table=None):
    """Map every leaf symbol under *node* to its bit-string code.

    A left edge contributes "0" and a right edge contributes "1". A leaf
    reached with an empty path (a single-leaf tree) gets the code "0".

    Args:
        node: root of the (sub)tree to walk.
        prefix: bit string already accumulated above *node*.
        table: optional dict to fill in place; a new dict is created if None.

    Returns:
        The filled symbol -> code dict (the same object as *table* if given).
    """
    codes = {} if table is None else table
    pending = [(node, prefix)]
    while pending:
        current, path = pending.pop()
        if current.symbol is None:
            # Push right before left so the left subtree is visited first,
            # mirroring a recursive pre-order walk.
            pending.append((current.right, path + "1"))
            pending.append((current.left, path + "0"))
        else:
            codes[current.symbol] = path or "0"
    return codes

def serialize_tree(node):
    """Convert a tree into nested tuples.

    A leaf becomes ``(1, symbol)``; an internal node becomes
    ``(0, left_encoding, right_encoding)``.
    """
    if node.symbol is None:
        left_part = serialize_tree(node.left)
        right_part = serialize_tree(node.right)
        return (0, left_part, right_part)
    return (1, node.symbol)

def deserialize_tree(obj):
    """Rebuild a tree of ``Node`` objects from nested tuples.

    A tag of 1 marks a leaf ``(1, symbol)``; any other tag is treated as an
    internal node ``(tag, left_encoding, right_encoding)``.
    """
    tag = obj[0]
    if tag != 1:
        left_obj, right_obj = obj[1], obj[2]
        return Node(left=deserialize_tree(left_obj), right=deserialize_tree(right_obj))
    return Node(symbol=obj[1])

def huffman_encode(data: bytes):
    """Huffman-compress *data*.

    Returns:
        A ``(payload, tree_bytes)`` pair. *payload* holds the code bits packed
        MSB-first and zero-padded to a whole byte. *tree_bytes* is a pickle of
        ``(serialized_tree, padding_bit_count)``. Both are ``b""`` for empty
        input.
    """
    if not data:
        return b"", b""

    root = build_huffman_tree(data)
    codes = build_code_table(root)
    bits = "".join([codes[byte] for byte in data])

    # Number of zero bits needed to reach a multiple of 8 (0..7).
    pad = -len(bits) % 8
    bits += "0" * pad
    payload = bytes(int(bits[pos:pos + 8], 2) for pos in range(0, len(bits), 8))

    import pickle
    header = pickle.dumps((serialize_tree(root), pad))
    return payload, header

def huffman_decode(data: bytes, tree_bytes: bytes) -> bytes:
    """Decompress a ``(data, tree_bytes)`` pair produced by ``huffman_encode``.

    Args:
        data: packed code bits (MSB-first, zero-padded to a whole byte).
        tree_bytes: pickled ``(serialized_tree, padding_bit_count)`` header.

    Returns:
        The decoded bytes; ``b""`` if either argument is empty.

    Raises:
        pickle.UnpicklingError: if *tree_bytes* tries to load any global
            (class, function, ...). Headers written by ``huffman_encode``
            never do.
    """
    import io
    import pickle

    if not data or not tree_bytes:
        return b""

    class _NoGlobalsUnpickler(pickle.Unpickler):
        # The header is nested tuples of ints, so it never needs a global lookup.
        # Refusing every global blocks the arbitrary-code execution that plain
        # pickle.loads would allow on attacker-supplied tree_bytes.
        def find_class(self, module, name):
            raise pickle.UnpicklingError(f"global '{module}.{name}' is forbidden")

    tree, padding = _NoGlobalsUnpickler(io.BytesIO(tree_bytes)).load()
    root = deserialize_tree(tree)
    bitstr = "".join(f"{byte:08b}" for byte in data)
    if padding:
        bitstr = bitstr[:-padding]

    # With a single distinct symbol the root is itself a leaf, and the encoder
    # emits one "0" bit per symbol. Walking the children would dereference None,
    # so emit one symbol per remaining bit instead.
    if root.symbol is not None:
        return bytes([root.symbol]) * len(bitstr)

    result = []
    node = root
    for bit in bitstr:
        node = node.left if bit == "0" else node.right
        if node.symbol is not None:
            result.append(node.symbol)
            node = root
    return bytes(result)