I see that the `get_adapters` function in `inference_base.py` uses `adapter['model'] = CoAdapter(w1 = 1, w2 = 1, w3 = 1).to(opt.device)`, so I made the following changes to the function:
```python
# CoAdapter, ExtraCondition and read_state_dict are the same objects already used in inference_base.py.
def get_adapters(opt, cond_type: ExtraCondition):
    adapter = {}
    cond_weight = getattr(opt, f'{cond_type.name}_weight', None)
    if cond_weight is None:
        cond_weight = getattr(opt, 'cond_weight')
    adapter['cond_weight'] = cond_weight

    adapter['model'] = CoAdapter(w1=1, w2=1, w3=1).to(opt.device)

    # Hard-coded T2I-Adapter checkpoints, one per CoAdapter sub-adapter
    # (replaces the original getattr(opt, f'{cond_type.name}_adapter_ckpt', None) lookup).
    ckpt_paths = {
        'pose_ada.': "F:/data_enhancement/HOIDiffusion-main/midas_models/t2iadapter_openpose_sd14v1.pth",
        'depth_ada.': "F:/data_enhancement/HOIDiffusion-main/midas_models/t2iadapter_depth_sd14v1.pth",
        'mask_ada.': "F:/data_enhancement/HOIDiffusion-main/midas_models/t2iadapter_seg_sd14v1.pth",
    }

    new_state_dict = {}
    for prefix, ckpt_path in ckpt_paths.items():
        state_dict = read_state_dict(ckpt_path)
        for k, v in state_dict.items():
            # Strip the 'adapter.' prefix if present.
            if k.startswith('adapter.'):
                k = k[len('adapter.'):]
            # If a key has no sub-adapter prefix, add it manually.
            if not k.startswith(prefix):
                k = prefix + k
            # Merge the pose / depth / mask entries into one state dict.
            new_state_dict[k] = v

    adapter['model'].load_state_dict(new_state_dict)
    return adapter
```
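Before calling `load_state_dict`, it may help to check which keys `CoAdapter` expects versus what the merged checkpoint actually provides, since the released T2I-Adapter checkpoints may not match this architecture. PyTorch's `Module.load_state_dict(..., strict=False)` reports `missing_keys` and `unexpected_keys` directly; the minimal sketch below does the same comparison on plain key lists so it runs without a model. The helpers `compare_state_dict_keys` and `count_by_prefix` are hypothetical, not part of the repository.

```python
from collections import Counter
from typing import Dict, Iterable, List, Tuple


def compare_state_dict_keys(expected_keys: Iterable[str],
                            loaded_keys: Iterable[str]) -> Tuple[List[str], List[str]]:
    """Return (missing, unexpected) keys, both sorted.

    missing:    keys the model expects but the checkpoint does not provide.
    unexpected: keys the checkpoint provides but the model does not have.
    """
    expected, loaded = set(expected_keys), set(loaded_keys)
    missing = sorted(expected - loaded)
    unexpected = sorted(loaded - expected)
    return missing, unexpected


def count_by_prefix(keys: Iterable[str]) -> Dict[str, int]:
    """Count keys per top-level submodule name (the text before the first '.')."""
    return dict(Counter(k.split('.', 1)[0] for k in keys))


if __name__ == "__main__":
    model_keys = ['pose_ada.conv_in.weight', 'depth_ada.conv_in.weight']
    ckpt_keys = ['pose_ada.conv_in.weight', 'depth_ada.body.0.weight']
    missing, unexpected = compare_state_dict_keys(model_keys, ckpt_keys)
    print("missing:", missing)
    print("unexpected:", unexpected)
    print("per prefix:", count_by_prefix(ckpt_keys))
```

Inside `get_adapters`, one would pass `adapter['model'].state_dict().keys()` and `new_state_dict.keys()`. A non-empty `missing` list, or per-prefix counts that differ between the two, suggests the checkpoint layout does not match that sub-adapter. Note that a key comparison does not catch tensor shape mismatches.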
The parameters I selected when debugging are as follows:
If I understand correctly, the condition models used in this script come from the T2I-Adapter checkpoints? The OpenPose keypoint arrangement may be different; please refer to the DexYCB keypoint layout. Besides, we used normal maps instead of depth maps for training, and the model was not trained on top of these released condition checkpoints, so it may not work when they are used directly.
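Since normal maps rather than depth maps were used for training, a user who only has depth may want a normal-map-style condition image. A common approximation derives surface normals from depth by finite differences, n ∝ (−∂z/∂x, −∂z/∂y, 1), normalized to unit length. This is a minimal sketch under assumptions: it ignores camera intrinsics, and the axis/sign convention and the RGB encoding are guesses that may not match the normal maps HOIDiffusion was trained on. `depth_to_normal` and `normal_to_uint8` are hypothetical helpers, not part of the repository, and producing a normal image does not change which adapter checkpoint is loaded.

```python
import numpy as np


def depth_to_normal(depth: np.ndarray) -> np.ndarray:
    """Approximate per-pixel unit normals (H x W x 3, components in [-1, 1]) from an H x W depth map.

    Assumption: x runs along columns, y along rows, and the camera looks along +z;
    flip signs if the target normal-map convention differs.
    """
    depth = depth.astype(np.float64)
    # np.gradient returns derivatives along axis 0 (rows, y) first, then axis 1 (columns, x).
    dz_dy, dz_dx = np.gradient(depth)
    normal = np.stack([-dz_dx, -dz_dy, np.ones_like(depth)], axis=-1)
    # The z component is 1, so the norm is always >= 1 (no division by zero).
    normal /= np.linalg.norm(normal, axis=-1, keepdims=True)
    return normal


def normal_to_uint8(normal: np.ndarray) -> np.ndarray:
    """Map components from [-1, 1] to [0, 255] for saving as an RGB image (assumed encoding)."""
    return ((normal + 1.0) * 0.5 * 255.0).round().astype(np.uint8)


if __name__ == "__main__":
    ramp = np.tile(np.arange(5.0), (3, 1))  # depth increases along x
    n = depth_to_normal(ramp)
    print(n[1, 2])  # roughly (-0.707, 0, 0.707)
    print(normal_to_uint8(n).shape)
```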
The results are shown below:
![ToyCar_0_2_6](https://private-user-images.githubusercontent.com/166021533/366065047-4241afe0-5596-48f7-8cd7-ed10fbd1e7d8.jpg?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3MzkwMTgzNDcsIm5iZiI6MTczOTAxODA0NywicGF0aCI6Ii8xNjYwMjE1MzMvMzY2MDY1MDQ3LTQyNDFhZmUwLTU1OTYtNDhmNy04Y2Q3LWVkMTBmYmQxZTdkOC5qcGc_WC1BbXotQWxnb3JpdGhtPUFXUzQtSE1BQy1TSEEyNTYmWC1BbXotQ3JlZGVudGlhbD1BS0lBVkNPRFlMU0E1M1BRSzRaQSUyRjIwMjUwMjA4JTJGdXMtZWFzdC0xJTJGczMlMkZhd3M0X3JlcXVlc3QmWC1BbXotRGF0ZT0yMDI1MDIwOFQxMjM0MDdaJlgtQW16LUV4cGlyZXM9MzAwJlgtQW16LVNpZ25hdHVyZT0yNWVjYTU2MTg5NDc1YTI2MDZmMmY1NTE1MmExMGZlNmMwMTZkZTk5ZjAwN2JhOGY3ZWViYzMwZWUyYWVkNTVlJlgtQW16LVNpZ25lZEhlYWRlcnM9aG9zdCJ9.FKcLlbQvYz9Cakg_m6NWUoWEba7k8zJ7fJZKb69Vh7A)
![ToyCar_0_2_4](https://private-user-images.githubusercontent.com/166021533/366065150-826dac4f-ccd3-48b8-92d4-2b4d3374e3e0.jpg?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3MzkwMTgzNDcsIm5iZiI6MTczOTAxODA0NywicGF0aCI6Ii8xNjYwMjE1MzMvMzY2MDY1MTUwLTgyNmRhYzRmLWNjZDMtNDhiOC05MmQ0LTJiNGQzMzc0ZTNlMC5qcGc_WC1BbXotQWxnb3JpdGhtPUFXUzQtSE1BQy1TSEEyNTYmWC1BbXotQ3JlZGVudGlhbD1BS0lBVkNPRFlMU0E1M1BRSzRaQSUyRjIwMjUwMjA4JTJGdXMtZWFzdC0xJTJGczMlMkZhd3M0X3JlcXVlc3QmWC1BbXotRGF0ZT0yMDI1MDIwOFQxMjM0MDdaJlgtQW16LUV4cGlyZXM9MzAwJlgtQW16LVNpZ25hdHVyZT0xNGY4ZTAyMjMwOTZmYmZhODIzOWFkZDE3ZTVlMTYwZmIyODg3ZWFjZGZlZDgyNzYyN2NlYjlmNjllZmZmZDBjJlgtQW16LVNpZ25lZEhlYWRlcnM9aG9zdCJ9.Q4xkau02Wfn0oA4Uos6Ik6xNp98tp8O5GPezOkfFYuY)
Please tell me how to solve the above problem. Thank you!