| | import torch |
| | import logging |
| | import re |
| | from typing import Dict, List, Any |
| | from transformers import pipeline, AutoModelForQuestionAnswering, AutoTokenizer |
| |
|
| | |
# Configure root logging once at import time; handlers on Inference Endpoints
# read this module directly, so INFO-level output goes to the endpoint logs.
logging.basicConfig(level=logging.INFO)
# Module-level logger shared by EndpointHandler below.
logger = logging.getLogger(__name__)
| |
|
class EndpointHandler:
    """HuggingFace Inference Endpoints handler for RECCON emotional trigger extraction.

    Wraps an extractive question-answering model: for each (utterance, emotion)
    input pair it asks the model to locate the short span of the utterance that
    most strongly signals the emotion, then post-processes the candidate spans
    into up to three cleaned trigger phrases.
    """

    def __init__(self, path=""):
        """
        Initialize the RECCON emotional trigger extraction model using native transformers.

        Args:
            path: Path to the model directory (provided by HuggingFace
                Inference Endpoints). Empty/"." falls back to the current
                working directory.

        Raises:
            Exception: re-raised unchanged if the tokenizer/model fails to load.
        """
        logger.info("Initializing RECCON Trigger Extraction endpoint...")

        cuda_available = torch.cuda.is_available()
        if not cuda_available:
            logger.warning("GPU not detected. Running on CPU. Inference will be slower.")

        # transformers pipeline device convention: 0 = first GPU, -1 = CPU.
        self.device_id = 0 if cuda_available else -1

        # `path if path and path != "." else "."` simplifies to `path or "."`.
        model_path = path or "."
        logger.info(f"Loading model from {model_path}...")

        try:
            tokenizer = AutoTokenizer.from_pretrained(model_path)
            model, loading_info = AutoModelForQuestionAnswering.from_pretrained(
                model_path,
                output_loading_info=True
            )

            # Logged at WARNING so weight mismatches are visible even with
            # default endpoint log filtering.
            logger.warning("RECCON load info - missing_keys: %s", loading_info.get("missing_keys"))
            logger.warning("RECCON load info - unexpected_keys: %s", loading_info.get("unexpected_keys"))
            logger.warning("RECCON load info - error_msgs: %s", loading_info.get("error_msgs"))
            logger.warning("Loaded model class: %s", model.__class__.__name__)
            logger.warning("Loaded model name_or_path: %s", getattr(model.config, "_name_or_path", None))

            # top_k=20: collect many candidate spans; _clean_spans filters them.
            self.pipe = pipeline(
                "question-answering",
                model=model,
                tokenizer=tokenizer,
                device=self.device_id,
                top_k=20,
                handle_impossible_answer=False
            )
            logger.info("Model loaded successfully.")
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise

        # Question fed to the QA pipeline; {emotion} is filled per request.
        self.question_template = (
            "Extract the exact short phrase (<= 8 words) from the target "
            "utterance that most strongly signals the emotion {emotion}. "
            "Return only a substring of the target utterance."
        )

    @staticmethod
    def _is_good_span(ans: str) -> bool:
        """Reject empty, ultra-short, punctuation-only, or letterless answers."""
        if not ans:
            return False
        a = ans.strip()
        if len(a) < 3:
            return False
        if all(ch in ".,!?;:-—'\"()[]{}" for ch in a):
            return False
        if not any(ch.isalpha() for ch in a):
            return False
        return True

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process an inference request.

        Args:
            data: Request payload, either {"inputs": item-or-list} or the
                item/list itself; each item is a dict with "utterance" and
                "emotion" keys.

        Returns:
            One result dict per input item: {"utterance", "emotion",
            "triggers"}, plus an "error" key for invalid items or failures.
        """
        # BUGFIX: use .get instead of .pop so the caller's payload dict is
        # not mutated as a side effect of inference.
        inputs = data.get("inputs", data)

        # Accept a single item as well as a batch.
        if isinstance(inputs, dict):
            inputs = [inputs]

        if not inputs:
            return [{"error": "No inputs provided", "triggers": []}]

        # Robustness: a malformed (non-dict) entry would crash item.get();
        # normalize it to {} so it is reported as "missing utterance" instead.
        inputs = [item if isinstance(item, dict) else {} for item in inputs]

        pipeline_inputs = []
        valid_indices = []

        for i, item in enumerate(inputs):
            utterance = item.get("utterance", "").strip()
            emotion = item.get("emotion", "")

            if not utterance:
                logger.warning(f"Empty utterance at index {i}")
                continue

            question = self.question_template.format(emotion=emotion)
            pipeline_inputs.append({
                'question': question,
                'context': utterance
            })
            valid_indices.append(i)

        results = []

        if not pipeline_inputs:
            # Nothing valid to run; report a per-item error instead of failing.
            for item in inputs:
                results.append({
                    "utterance": item.get("utterance", ""),
                    "emotion": item.get("emotion", ""),
                    "error": "Missing or empty utterance",
                    "triggers": []
                })
            return results

        try:
            predictions = self.pipe(pipeline_inputs, batch_size=8)

            # Normalize the pipeline's shape-shifting output into a list with
            # one entry (dict, or list of top_k dicts) per pipeline input:
            #   single input, single answer -> dict
            #   single input, top_k answers -> flat list of dicts
            #   batch                       -> list of (dict | list of dicts)
            if isinstance(predictions, dict):
                predictions = [predictions]
            elif isinstance(predictions, list) and len(predictions) > 0 and isinstance(predictions[0], dict):
                if len(pipeline_inputs) == 1:
                    predictions = [predictions]

            logger.debug(f"Raw predictions: {predictions}")

            # O(1) membership test instead of scanning the list per item.
            valid_set = set(valid_indices)

            pred_idx = 0
            for i, item in enumerate(inputs):
                utterance = item.get("utterance", "").strip()
                emotion = item.get("emotion", "")

                if i not in valid_set:
                    results.append({
                        "utterance": utterance,
                        "emotion": emotion,
                        "error": "Missing or empty utterance",
                        "triggers": []
                    })
                else:
                    current_preds = predictions[pred_idx]

                    # A lone answer may still arrive as a bare dict.
                    if isinstance(current_preds, dict):
                        current_preds = [current_preds]

                    # BUGFIX: the original built 3-tuples
                    # (answer, score, 3) with a stray literal 3 instead of
                    # rounding the score to 3 decimal places.
                    logger.info(
                        "RECCON raw spans (answer, score): %s",
                        [(p.get("answer"), round(p.get("score", 0.0), 3)) for p in current_preds[:5]]
                    )

                    raw_answers = [p.get("answer", "") for p in current_preds]
                    raw_answers = [a for a in raw_answers if self._is_good_span(a)]
                    triggers = self._clean_spans(raw_answers, utterance)

                    results.append({
                        "utterance": utterance,
                        "emotion": emotion,
                        "triggers": triggers
                    })
                    pred_idx += 1

            logger.debug(f"Cleaned results: {results}")
            return results

        except Exception as e:
            logger.error(f"Model prediction failed: {e}")
            return [{
                "utterance": item.get("utterance", ""),
                "emotion": item.get("emotion", ""),
                "error": str(e),
                "triggers": []
            } for item in inputs]

    def _clean_spans(self, spans: List[str], target_text: str) -> List[str]:
        """
        Clean and filter extracted trigger spans.

        Keeps only spans that occur verbatim (case-insensitive) in
        target_text, drops stopword/overlong/degenerate spans, prefers short
        (1-3 token) candidates, removes substring-duplicates, and restores
        original casing from target_text. If everything is filtered out, a
        token-level fallback searches for the longest sub-phrase of any raw
        span that appears in target_text.

        Args:
            spans: Candidate answer strings from the QA model.
            target_text: The utterance the spans must come from.

        Returns:
            Up to 3 cleaned trigger phrases (the fallback yields at most 1).
        """
        target_text = target_text or ""
        target_lower = target_text.lower()

        def _norm(s: str) -> str:
            # Collapse whitespace and strip non-word edge characters.
            s = (s or "").strip().lower()
            s = re.sub(r"\s+", " ", s)
            s = re.sub(r"^[^\w]+|[^\w]+$", "", s)
            return s

        def _extract_from_target(target: str, phrase_lower: str) -> str:
            # Recover original casing by locating the phrase inside target.
            idx = target.lower().find(phrase_lower)
            if idx >= 0:
                return target[idx:idx + len(phrase_lower)]
            return phrase_lower

        # Single tokens that carry no emotional signal on their own.
        # (Duplicate "her" removed from the original literal; a set literal
        # deduplicates anyway, so behavior is unchanged.)
        STOP = {
            "a", "an", "the", "and", "or", "but", "so", "to", "of", "in", "on", "at",
            "with", "for", "from", "is", "am", "are", "was", "were", "be", "been",
            "being", "i", "you", "he", "she", "it", "we", "they", "my", "your", "his",
            "her", "their", "our", "me", "him", "them", "this", "that", "these",
            "those"
        }

        candidates = []
        for s in spans:
            s = (s or "").strip()
            if not s:
                continue
            s_norm = _norm(s)
            if not s_norm:
                continue
            # Must be a literal substring of the utterance.
            if target_text and s_norm not in target_lower:
                continue
            tokens = s_norm.split()
            if len(tokens) > 8 or len(s_norm) > 80:
                continue
            if len(tokens) == 1 and (tokens[0] in STOP or len(tokens[0]) <= 2):
                continue
            candidates.append({
                "norm": s_norm,
                "tokens": tokens,
                "tok_len": len(tokens),
                "char_len": len(s_norm)
            })

        # Prefer tight 1-3 token spans when any survive.
        short_candidates = [c for c in candidates if 1 <= c["tok_len"] <= 3]
        if short_candidates:
            candidates = short_candidates

        # Shortest first, then drop anything that is a substring of (or
        # contains) an already-kept span.
        candidates.sort(key=lambda x: (x["tok_len"], x["char_len"]), reverse=False)
        kept_norms = []
        for c in list(candidates):
            n = c["norm"]
            if any(n in kn or kn in n for kn in kept_norms):
                continue
            kept_norms.append(n)

        cleaned = [_extract_from_target(target_text, n) for n in kept_norms]

        # Fallback: longest sub-phrase of any raw span found token-wise in
        # the target. Returns at most one phrase.
        if not cleaned and spans:
            tt_tokens = target_lower.split()
            best = None
            for s in spans:
                words = [w for w in (s or '').lower().strip().split() if w]
                for L in range(min(8, len(words)), 0, -1):
                    for i in range(len(words) - L + 1):
                        phrase = words[i:i + L]
                        for j in range(len(tt_tokens) - L + 1):
                            if tt_tokens[j:j + L] == phrase:
                                cand = " ".join(phrase)
                                best = cand
                                break
                        if best:
                            break
                    if best:
                        break
                if best:
                    break
            if best:
                return [_extract_from_target(target_text, best)]

        return cleaned[:3]