Skip to content

Commit 42918c6

Browse files
committed
refactor: use context manager for safetensors file handling and fix model prefix normalization in add_tensors_from_sf
1 parent e6eb7a6 commit 42918c6

1 file changed

Lines changed: 27 additions & 27 deletions

File tree

convert.py

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -118,35 +118,35 @@ def add_tensors_from_sf(w, sf_path, tag, model_type):
118118
for name in names:
119119
info = meta[name]
120120

121-
# normalize: some upstream checkpoints omit the "model." prefix
122-
if model_type == "lm" and not name.startswith("model."):
123-
name = "model." + name
124-
125-
dtype_str = info["dtype"]
126-
shape = info["shape"]
127-
off0, off1 = info["data_offsets"]
128-
nbytes = off1 - off0
129-
130-
f.seek(hdr_size + off0)
131-
raw = f.read(nbytes)
132-
133-
if dtype_str == "BF16":
134-
arr = np.frombuffer(raw, dtype=np.uint16).reshape(shape)
135-
w.add_tensor(name, arr, raw_dtype=BF16)
136-
elif dtype_str == "F16":
137-
arr = np.frombuffer(raw, dtype=np.float16).reshape(shape)
138-
w.add_tensor(name, arr)
139-
elif dtype_str == "F32":
140-
arr = np.frombuffer(raw, dtype=np.float32).reshape(shape)
141-
w.add_tensor(name, arr)
142-
else:
143-
log(tag, " skip %s: dtype %s" % (name, dtype_str))
144-
continue
121+
# normalize: some upstream checkpoints omit the "model." prefix
122+
if model_type == "lm" and not name.startswith("model."):
123+
name = "model." + name
124+
125+
dtype_str = info["dtype"]
126+
shape = info["shape"]
127+
off0, off1 = info["data_offsets"]
128+
nbytes = off1 - off0
129+
130+
f.seek(hdr_size + off0)
131+
raw = f.read(nbytes)
132+
133+
if dtype_str == "BF16":
134+
arr = np.frombuffer(raw, dtype=np.uint16).reshape(shape)
135+
w.add_tensor(name, arr, raw_dtype=BF16)
136+
elif dtype_str == "F16":
137+
arr = np.frombuffer(raw, dtype=np.float16).reshape(shape)
138+
w.add_tensor(name, arr)
139+
elif dtype_str == "F32":
140+
arr = np.frombuffer(raw, dtype=np.float32).reshape(shape)
141+
w.add_tensor(name, arr)
142+
else:
143+
log(tag, " skip %s: dtype %s" % (name, dtype_str))
144+
continue
145145

146-
count += 1
147-
total += nbytes
146+
count += 1
147+
total += nbytes
148148

149-
return count, total
149+
return count, total
150150

151151
# silence_latent.pt reader (replaces pt2bin C++ tool)
152152
# PyTorch .pt is a ZIP with entry "*/data/0" containing f32 [64, 15000]

0 commit comments

Comments (0)