Skip to content
Snippets Groups Projects
Commit 38d65435 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Build] support batched verification i/o, one at a time

parent df9aa4a4
No related branches found
No related tags found
No related merge requests found
......@@ -121,44 +121,80 @@ def verify_step(
verify_out_dir = cfg.output_dir + "/verification_output"
intermediate_models_dir = cfg.output_dir + "/intermediate_models"
os.makedirs(verify_out_dir, exist_ok=True)
(in_npy, exp_out_npy) = cfg._resolve_verification_io_pair()
if need_parent:
assert (
cfg.save_intermediate_models
), "Enable save_intermediate_models for verification"
parent_model_fn = intermediate_models_dir + "/dataflow_parent.onnx"
child_model_fn = intermediate_models_dir + "/verify_%s.onnx" % step_name
model.save(child_model_fn)
out_tensor_name = ModelWrapper(parent_model_fn).graph.output[0].name
out_dict = execute_parent(
parent_model_fn, child_model_fn, in_npy, return_full_ctx=True
)
out_npy = out_dict[out_tensor_name]
else:
inp_tensor_name = model.graph.input[0].name
out_tensor_name = model.graph.output[0].name
inp_dict = {inp_tensor_name: in_npy}
if rtlsim_pre_hook is not None:
out_dict = rtlsim_exec(model, inp_dict, pre_hook=rtlsim_pre_hook)
(in_npy_all, exp_out_npy_all) = cfg._resolve_verification_io_pair()
bsize_in = in_npy_all.shape[0]
bsize_out = exp_out_npy_all.shape[0]
assert bsize_in == bsize_out, "Batch sizes don't match for verification IO pair"
all_res = True
for b in range(bsize_in):
in_npy = np.expand_dims(in_npy_all[b], axis=0)
exp_out_npy = np.expand_dims(exp_out_npy_all[b], axis=0)
if need_parent:
assert (
cfg.save_intermediate_models
), "Enable save_intermediate_models for verification"
parent_model_fn = intermediate_models_dir + "/dataflow_parent.onnx"
child_model_fn = intermediate_models_dir + "/verify_%s.onnx" % step_name
model.save(child_model_fn)
parent_model = ModelWrapper(parent_model_fn)
out_tensor_name = parent_model.graph.output[0].name
exp_ishape = parent_model.get_tensor_shape(parent_model.graph.input[0].name)
if in_npy.shape != exp_ishape:
print(
"Verification input has shape %s while model expects %s"
% (str(in_npy.shape), str(exp_ishape))
)
print("Attempting to force model shape on verification input")
in_npy = in_npy.reshape(exp_ishape)
out_dict = execute_parent(
parent_model_fn, child_model_fn, in_npy, return_full_ctx=True
)
out_npy = out_dict[out_tensor_name]
else:
out_dict = execute_onnx(model, inp_dict, True)
out_npy = out_dict[out_tensor_name]
res = np.isclose(exp_out_npy, out_npy, atol=1e-3).all()
res_to_str = {True: "SUCCESS", False: "FAIL"}
res_str = res_to_str[res]
if cfg.verify_save_full_context:
verification_output_fn = verify_out_dir + "/verify_%s_%s.npz" % (
step_name,
res_str,
)
np.savez(verification_output_fn, **out_dict)
else:
verification_output_fn = verify_out_dir + "/verify_%s_%s.npy" % (
step_name,
res_str,
)
np.save(verification_output_fn, out_npy)
print("Verification for %s : %s" % (step_name, res_str))
inp_tensor_name = model.graph.input[0].name
out_tensor_name = model.graph.output[0].name
exp_ishape = model.get_tensor_shape(inp_tensor_name)
if in_npy.shape != exp_ishape:
print(
"Verification input has shape %s while model expects %s"
% (str(in_npy.shape), str(exp_ishape))
)
print("Attempting to force model shape on verification input")
in_npy = in_npy.reshape(exp_ishape)
inp_dict = {inp_tensor_name: in_npy}
if rtlsim_pre_hook is not None:
out_dict = rtlsim_exec(model, inp_dict, pre_hook=rtlsim_pre_hook)
else:
out_dict = execute_onnx(model, inp_dict, True)
out_npy = out_dict[out_tensor_name]
exp_oshape = exp_out_npy.shape
if out_npy.shape != exp_oshape:
print(
"Verification output has shape %s while model produces %s"
% (str(exp_oshape), str(out_npy.shape))
)
print("Attempting to force model shape on verification output")
out_npy = out_npy.reshape(exp_oshape)
res = np.isclose(exp_out_npy, out_npy, atol=1e-3).all()
all_res = all_res and res
res_to_str = {True: "SUCCESS", False: "FAIL"}
res_str = res_to_str[res]
if cfg.verify_save_full_context:
verification_output_fn = verify_out_dir + "/verify_%s_%d_%s.npz" % (
step_name,
b,
res_str,
)
np.savez(verification_output_fn, **out_dict)
else:
verification_output_fn = verify_out_dir + "/verify_%s_%d_%s.npy" % (
step_name,
b,
res_str,
)
np.save(verification_output_fn, out_npy)
print("Verification for %s : %s" % (step_name, res_to_str[all_res]))
def prepare_for_stitched_ip_rtlsim(verify_model, cfg):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment