Unverified Commit d0184b8f authored by DepFA's avatar DepFA Committed by GitHub

change json tensor key name

parent 5d12ec82
...@@ -19,15 +19,15 @@ import modules.textual_inversion.dataset ...@@ -19,15 +19,15 @@ import modules.textual_inversion.dataset
class EmbeddingEncoder(json.JSONEncoder): class EmbeddingEncoder(json.JSONEncoder):
def default(self, obj): def default(self, obj):
if isinstance(obj, torch.Tensor): if isinstance(obj, torch.Tensor):
return {'EMBEDDINGTENSOR':obj.cpu().detach().numpy().tolist()} return {'TORCHTENSOR':obj.cpu().detach().numpy().tolist()}
return json.JSONEncoder.default(self, o) return json.JSONEncoder.default(self, o)
class EmbeddingDecoder(json.JSONDecoder): class EmbeddingDecoder(json.JSONDecoder):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs) json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs)
def object_hook(self, d): def object_hook(self, d):
if 'EMBEDDINGTENSOR' in d: if 'TORCHTENSOR' in d:
return torch.from_numpy(np.array(d['EMBEDDINGTENSOR'])) return torch.from_numpy(np.array(d['TORCHTENSOR']))
return d return d
def embeddingToB64(data): def embeddingToB64(data):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment