szuweifu commited on
Commit
2cad78e
·
verified ·
1 Parent(s): 4ffb980

Upload 2 files

Browse files
Files changed (2) hide show
  1. inference.py +1 -3
  2. inference_chunk.py +1 -3
inference.py CHANGED
@@ -44,9 +44,7 @@ def inference(args, device):
44
  compress_factor = cfg['model_cfg']['compress_factor']
45
  sampling_rate = cfg['stft_cfg']['sampling_rate']
46
 
47
- SE_model = SEMamba(cfg).to(device)
48
- state_dict = torch.load(args.checkpoint_file, map_location=device)
49
- SE_model.load_state_dict(state_dict['generator'])
50
  SE_model.eval()
51
 
52
  os.makedirs(args.output_folder, exist_ok=True)
 
44
  compress_factor = cfg['model_cfg']['compress_factor']
45
  sampling_rate = cfg['stft_cfg']['sampling_rate']
46
 
47
+ SE_model = SEMamba.from_pretrained("nvidia/RE-USE", cfg=cfg).to(device)
 
 
48
  SE_model.eval()
49
 
50
  os.makedirs(args.output_folder, exist_ok=True)
inference_chunk.py CHANGED
@@ -46,9 +46,7 @@ def inference(args, device):
46
  compress_factor = cfg['model_cfg']['compress_factor']
47
  sampling_rate = cfg['stft_cfg']['sampling_rate']
48
 
49
- SE_model = SEMamba(cfg).to(device)
50
- state_dict = torch.load(args.checkpoint_file, map_location=device)
51
- SE_model.load_state_dict(state_dict['generator'])
52
  SE_model.eval()
53
 
54
  os.makedirs(args.output_folder, exist_ok=True)
 
46
  compress_factor = cfg['model_cfg']['compress_factor']
47
  sampling_rate = cfg['stft_cfg']['sampling_rate']
48
 
49
+ SE_model = SEMamba.from_pretrained("nvidia/RE-USE", cfg=cfg).to(device)
 
 
50
  SE_model.eval()
51
 
52
  os.makedirs(args.output_folder, exist_ok=True)