Upload 2 files
Browse files
- inference.py (+1 −3)
- inference_chunk.py (+1 −3)
inference.py
CHANGED
|
@@ -44,9 +44,7 @@ def inference(args, device):
|
|
| 44 |
compress_factor = cfg['model_cfg']['compress_factor']
|
| 45 |
sampling_rate = cfg['stft_cfg']['sampling_rate']
|
| 46 |
|
| 47 |
-
SE_model = SEMamba(cfg).to(device)
|
| 48 |
-
state_dict = torch.load(args.checkpoint_file, map_location=device)
|
| 49 |
-
SE_model.load_state_dict(state_dict['generator'])
|
| 50 |
SE_model.eval()
|
| 51 |
|
| 52 |
os.makedirs(args.output_folder, exist_ok=True)
|
|
|
|
| 44 |
compress_factor = cfg['model_cfg']['compress_factor']
|
| 45 |
sampling_rate = cfg['stft_cfg']['sampling_rate']
|
| 46 |
|
| 47 |
+
SE_model = SEMamba.from_pretrained("nvidia/RE-USE", cfg=cfg).to(device)
|
|
|
|
|
|
|
| 48 |
SE_model.eval()
|
| 49 |
|
| 50 |
os.makedirs(args.output_folder, exist_ok=True)
|
inference_chunk.py
CHANGED
|
@@ -46,9 +46,7 @@ def inference(args, device):
|
|
| 46 |
compress_factor = cfg['model_cfg']['compress_factor']
|
| 47 |
sampling_rate = cfg['stft_cfg']['sampling_rate']
|
| 48 |
|
| 49 |
-
SE_model = SEMamba(cfg).to(device)
|
| 50 |
-
state_dict = torch.load(args.checkpoint_file, map_location=device)
|
| 51 |
-
SE_model.load_state_dict(state_dict['generator'])
|
| 52 |
SE_model.eval()
|
| 53 |
|
| 54 |
os.makedirs(args.output_folder, exist_ok=True)
|
|
|
|
| 46 |
compress_factor = cfg['model_cfg']['compress_factor']
|
| 47 |
sampling_rate = cfg['stft_cfg']['sampling_rate']
|
| 48 |
|
| 49 |
+
SE_model = SEMamba.from_pretrained("nvidia/RE-USE", cfg=cfg).to(device)
|
|
|
|
|
|
|
| 50 |
SE_model.eval()
|
| 51 |
|
| 52 |
os.makedirs(args.output_folder, exist_ok=True)
|