| | from json import load |
| | from typing import Any, Dict, Optional |
| |
|
| | from numpy import array, expand_dims, float32, ndarray, transpose, zeros |
| | from PIL import Image |
| | from sentence_transformers import SentenceTransformer |
| | from tensorflow import constant |
| | from tensorflow.keras.models import load_model |
| | from transformers import TFConvNextV2Model |
| |
|
| | |
# Lookup from category ID to human-readable category name.
# Starts empty and is replaced by build_category_map().
CATEGORY_MAP: Dict[str, str] = dict()
# Category IDs that predict() pairs, index-for-index, with the classifier
# output row. Starts empty and is replaced by build_category_map().
CLASS_LABELS = list()
| |
|
| |
|
def build_category_map(categories_json_path: str) -> None:
    """
    Populate the module-level CATEGORY_MAP and CLASS_LABELS globals.

    CLASS_LABELS is always set to the fixed list of category IDs the
    classifiers output, even when the categories file is missing or cannot
    be parsed. Predictions can then still be labelled, falling back to the
    raw IDs.

    CATEGORY_MAP is rebuilt from the hierarchical categories JSON file. Every
    category ID encountered, including nested ``subCategories`` and ``path``
    entries, is mapped to its ``name``. If the file cannot be read,
    CATEGORY_MAP is left unchanged.

    Args:
        categories_json_path (str): Path to the categories.json file.
    """
    global CATEGORY_MAP, CLASS_LABELS

    # Category IDs in the order predict() zips them with the model output row.
    # NOTE(review): this order must match the label order used when the
    # classifiers were trained; verify against the training code.
    model_trained_ids = [
        "abcat0100000",
        "abcat0200000",
        "abcat0207000",
        "abcat0300000",
        "abcat0400000",
        "abcat0500000",
        "abcat0700000",
        "abcat0800000",
        "abcat0900000",
        "cat09000",
        "pcmcat128500050004",
        "pcmcat139900050002",
        "pcmcat242800050021",
        "pcmcat252700050006",
        "pcmcat312300050015",
        "pcmcat332000050000",
    ]
    # Assign the labels before touching the file so that the fallback promised
    # by the error messages below actually takes effect.
    CLASS_LABELS = model_trained_ids

    try:
        with open(categories_json_path, "r", encoding="utf-8") as f:
            categories_data = load(f)
    except FileNotFoundError:
        print(
            f"❌ Error: {categories_json_path} not found. Using hardcoded labels as fallback."
        )
        return
    except ValueError as err:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueError
        # subclasses; a corrupt file should not abort module import.
        print(
            f"❌ Error: could not parse {categories_json_path} ({err}). "
            "Using hardcoded labels as fallback."
        )
        return

    category_map: Dict[str, str] = {}

    def traverse_categories(categories):
        """Record id -> name for each category, its children, and its path entries."""
        for category in categories:
            category_map[category["id"]] = category["name"]
            if category.get("subCategories"):
                traverse_categories(category["subCategories"])
            # Path entries are recorded after the children, so they win on
            # duplicate IDs.
            for path_item in category.get("path") or ():
                category_map[path_item["id"]] = path_item["name"]

    traverse_categories(categories_data)

    CATEGORY_MAP = category_map
| |
|
| |
|
| | |
# Load the pre-trained text and image embedding models once, at import time.
# If either fails, both names are set to None so the module still imports;
# later calls that use a None model will raise.
print("🔬 Loading embedding models...")
try:
    text_embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
    image_feature_extractor = TFConvNextV2Model.from_pretrained(
        "facebook/convnextv2-tiny-22k-224"
    )
except Exception as e:
    # Top-level boundary: report and continue with the models disabled.
    print(f"❌ Error loading embedding models: {e}")
    text_embedding_model, image_feature_extractor = None, None
else:
    print("✅ Embedding models loaded successfully!")
| |
|
| | |
# Load the three trained Keras classifiers, one per prediction mode.
# If any of them fails, all three are set to None so the module still imports.
print("🔬 Loading classification models...")
try:
    text_model = load_model("./models/text_model")
    image_model = load_model("./models/image_model")
    multimodal_model = load_model("./models/multimodal_model")
except Exception as e:
    # Top-level boundary: report and continue with the classifiers disabled.
    print(f"❌ Error loading classification models: {e}")
    text_model, image_model, multimodal_model = None, None, None
else:
    print("✅ Classification models loaded successfully!")

# Populate CATEGORY_MAP / CLASS_LABELS, which predict() uses to label outputs.
build_category_map("./data/raw/categories.json")
| |
|
| |
|
| | |
def get_text_embeddings(text: Optional[str]) -> ndarray:
    """
    Encode a piece of text into a single-row float32 embedding matrix.

    Args:
        text (Optional[str]): Text to embed. ``None``, ``""`` and
            whitespace-only strings are treated as "no text".

    Returns:
        np.ndarray: Array of shape (1, embedding_dim), where embedding_dim is
            reported by the sentence-transformer (384 per the original
            contract). When there is no text, the array is all zeros.
    """
    # Short-circuit keeps .strip() from being called on None.
    has_content = bool(text) and bool(text.strip())

    if has_content:
        # encode() receives a one-element batch, so the result has one row.
        return array(text_embedding_model.encode([text]), dtype=float32)

    width = text_embedding_model.get_sentence_embedding_dimension()
    return zeros((1, width), dtype=float32)
| |
|
| |
|
def get_image_embeddings(image_path: Optional[str]) -> ndarray:
    """
    Preprocess an image and embed it with the pre-trained ConvNeXt-V2 model.

    Args:
        image_path (Optional[str]): File path to the image. ``None`` or an
            empty string means "no image".

    Returns:
        np.ndarray: Array of shape (1, 768) holding the image embedding
            (the model's ``pooler_output``). When no image is provided, a
            float32 all-zero array of the same shape is returned.
    """
    # No image: return a zero placeholder so the multimodal model still gets a
    # correctly shaped input. 768 is assumed to equal the pooled width of
    # convnextv2-tiny; TODO confirm if the backbone changes.
    if not image_path:
        return zeros((1, 768), dtype=float32)

    # Image.open is lazy and keeps the file open. convert() materialises an
    # independent copy, so the handle can be closed right away instead of
    # waiting for garbage collection.
    with Image.open(image_path) as source:
        image = source.convert("RGB")

    image = image.resize((224, 224), Image.Resampling.LANCZOS)

    # (224, 224, 3) HWC pixels -> (1, 224, 224, 3) batch of one.
    image_array = expand_dims(array(image, dtype=float32), axis=0)

    # Channels-first (1, 3, 224, 224): the pixel_values layout the
    # Hugging Face TF vision models accept.
    image_array = transpose(image_array, (0, 3, 1, 2))

    # Scale pixel values to [0, 1].
    # NOTE(review): unlike the Hugging Face ConvNext image processor, no
    # ImageNet mean/std normalisation is applied here. The downstream
    # classifiers were presumably trained on embeddings from this exact
    # pipeline, so changing it would require retraining; confirm first.
    image_array = image_array / 255.0

    outputs = image_feature_extractor(constant(image_array))
    return outputs.pooler_output.numpy()
| |
|
| |
|
| | |
def predict(
    mode: str, text: Optional[str], image_path: Optional[str]
) -> Dict[str, Any]:
    """
    Predict the product category using the model for the selected mode.

    Only the embeddings the selected model consumes are computed:
    "Text Only" never touches the image, and "Image Only" never touches the
    text. An unknown mode returns immediately without any model work.

    Args:
        mode (str): Prediction mode, one of "Multimodal", "Text Only" or
            "Image Only".
        text (Optional[str]): Product description text.
        image_path (Optional[str]): File path to the product image.

    Returns:
        Dict[str, Any]: Mapping of category name (or raw category ID when no
            name is known) to its predicted probability as a Python float,
            ordered from most to least likely. Returns an empty dictionary if
            the mode is invalid.
    """
    if mode == "Multimodal":
        # List elements are evaluated left to right: text, then image.
        predictions = multimodal_model.predict(
            [get_text_embeddings(text), get_image_embeddings(image_path)]
        )
    elif mode == "Text Only":
        predictions = text_model.predict(get_text_embeddings(text))
    elif mode == "Image Only":
        predictions = image_model.predict(get_image_embeddings(image_path))
    else:
        return {}

    # A single sample was submitted, so row 0 holds its scores. Scores are
    # paired with CLASS_LABELS by position; zip stops at the shorter of the two.
    # IDs without a known name keep the raw ID. If two IDs share a name, the
    # later one wins.
    named_scores = {
        CATEGORY_MAP.get(class_id, class_id): float(probability)
        for class_id, probability in zip(CLASS_LABELS, predictions[0])
    }

    return dict(sorted(named_scores.items(), key=lambda item: item[1], reverse=True))
| |
|