| """Utility functions for Chiluka.""" |
|
|
| import torch |
| from munch import Munch |
|
|
|
|
def length_to_mask(lengths):
    """Convert per-sequence lengths to a boolean padding mask.

    Args:
        lengths: 1-D tensor with one length per batch entry.

    Returns:
        Boolean tensor of shape ``(lengths.shape[0], lengths.max())``, on the
        same device as ``lengths``. An entry is ``True`` where
        ``index + 1 > length`` (a position past the end of that sequence)
        and ``False`` for positions inside the sequence. An empty
        ``lengths`` tensor yields an empty ``(0, 0)`` mask.
    """
    # ``Tensor.max()`` raises on an empty tensor, so handle a zero-size
    # batch explicitly instead of crashing.
    if lengths.numel() == 0:
        return torch.zeros(
            (lengths.shape[0], 0), dtype=torch.bool, device=lengths.device
        )

    # Build the position index directly on the device of ``lengths`` so no
    # host-to-device copy is needed; ``type_as`` then only aligns the dtype.
    positions = torch.arange(lengths.max(), device=lengths.device).type_as(lengths)

    # Broadcasting (1, max_len) against (batch, 1) gives (batch, max_len).
    return torch.gt(positions.unsqueeze(0) + 1, lengths.unsqueeze(1))
|
|
|
|
def recursive_munch(d):
    """Return ``d`` with every nested dict turned into a ``Munch``.

    Lists are rebuilt as new lists with each element converted, dicts become
    ``Munch`` objects whose values are converted recursively, and any other
    value is returned unchanged.
    """
    # Sequences: convert element by element into a fresh list.
    if isinstance(d, list):
        return list(map(recursive_munch, d))

    # Leaves (anything that is neither list nor dict) pass through untouched.
    if not isinstance(d, dict):
        return d

    # Mappings: convert the values first, then wrap the result in a Munch.
    return Munch({key: recursive_munch(value) for key, value in d.items()})
|
|