90 lines
3.2 KiB
Python
90 lines
3.2 KiB
Python
# bot/recommender.py
|
|
|
|
import openai
|
|
import os
|
|
from dotenv import load_dotenv
|
|
from typing import List, Dict, Literal
|
|
from library_cache import LibraryCache
|
|
from config import OPENAI_API_KEY, OPENAI_MODEL
|
|
|
|
# Load variables from a local .env file into the process environment (python-dotenv).
# NOTE(review): this call runs after `from config import ...` above; if config reads
# environment variables at import time, those reads happen before this line — confirm.
load_dotenv()
|
|
|
|
# Set the API key on the openai module itself, using the value imported from config.
openai.api_key = OPENAI_API_KEY
|
|
|
|
# Media kinds accepted by Recommender.recommend.
MediaType = Literal["movie", "show"]


class Recommender:
    """Recommend titles through the OpenAI chat API and sort them by availability.

    The recommender builds a prompt from the user's watched titles, asks the
    model for similar titles, parses the reply, and splits the result into
    titles found in the library cache ("available") and the rest
    ("requestable").
    """

    # Upper bound on how many watched titles are embedded in the prompt.
    MAX_PROMPT_TITLES = 20
    # How many of the most frequent genres are mentioned in the prompt.
    TOP_GENRE_COUNT = 3

    def __init__(self, cache=None):
        """Create a recommender.

        Args:
            cache: Optional library cache to use. It must provide
                ``search(title, media_type)`` and a ``data`` iterable of dict
                entries. When omitted, a new ``LibraryCache()`` is created.
        """
        self.cache = LibraryCache() if cache is None else cache

    def recommend(
        self,
        watched_titles: List[str],
        media_type: MediaType = "movie",
        max_recs: int = 5,
    ) -> Dict[str, List[str]]:
        """Return recommendations split by whether the library already has them.

        Args:
            watched_titles: Titles the user has already watched.
            media_type: Kind of media to recommend.
            max_recs: Number of recommendations requested; the parsed result
                is truncated to this many titles.

        Returns:
            A dict with two lists: ``"available"`` (titles for which
            ``cache.search`` returned a truthy value) and ``"requestable"``
            (all other recommended titles).

        Raises:
            ValueError: If ``watched_titles`` is empty.
        """
        if not watched_titles:
            raise ValueError("No watched titles provided.")

        prompt = self.build_prompt(watched_titles, media_type, max_recs)
        response = self.query_openai(prompt)

        print("🧠 Prompt:", prompt)
        print("📥 Raw response:", response)

        # The model may repeat itself or return more than requested, so
        # de-duplicate first and then enforce the caller's limit.
        all_titles = self._dedupe(self.parse_titles(response))[: max(max_recs, 0)]
        print("📦 Parsed titles:", all_titles)

        available: List[str] = []
        requestable: List[str] = []
        for title in all_titles:
            if self.cache.search(title, media_type):
                available.append(title)
            else:
                requestable.append(title)

        return {
            "available": available,
            "requestable": requestable,
        }

    @staticmethod
    def _dedupe(titles: List[str]) -> List[str]:
        """Drop case-insensitive duplicates, keeping the first occurrence."""
        seen = set()
        unique = []
        for title in titles:
            key = title.casefold()
            if key not in seen:
                seen.add(key)
                unique.append(title)
        return unique

    def build_prompt(self, watched: List[str], media_type: str, max_recs: int) -> str:
        """Compose the user prompt sent to the model.

        Args:
            watched: Watched titles; only the first ``MAX_PROMPT_TITLES`` are
                listed in the prompt.
            media_type: ``"movie"`` selects the wording "movies"; any other
                value selects "TV shows".
            max_recs: Number of recommendations to ask for.

        Returns:
            The prompt text.
        """
        type_text = "movies" if media_type == "movie" else "TV shows"
        genre_summary = self.extract_common_genres(watched, media_type)
        listed = ", ".join(watched[: self.MAX_PROMPT_TITLES])

        return (
            f"A user has watched the following {type_text}: {listed}. "
            # Use the media-specific noun so movie prompts do not say "shows".
            f"These {type_text} are mostly {genre_summary}. "
            f"Recommend {max_recs} similar {type_text} based on theme and tone. "
            "Return only a plain comma-separated list of titles — no numbers, no explanations."
        )

    def extract_common_genres(self, watched: List[str], media_type: str) -> str:
        """Summarise the most frequent genres among the watched titles.

        Scans ``self.cache.data`` for entries whose ``"title"`` is in
        ``watched`` and whose ``"type"`` equals ``media_type``, counting each
        genre listed under ``"genres"``.

        Returns:
            Up to ``TOP_GENRE_COUNT`` genre names joined by ", ", most
            frequent first (ties keep first-seen order), or
            ``"varied genres"`` when nothing matched.
        """
        watched_lookup = set(watched)  # O(1) membership instead of a list scan
        genre_counts: Dict[str, int] = {}
        for item in self.cache.data:
            # .get() tolerates cache entries that lack a key instead of raising.
            if item.get("type") != media_type:
                continue
            if item.get("title") not in watched_lookup:
                continue
            for genre in item.get("genres") or []:
                genre_counts[genre] = genre_counts.get(genre, 0) + 1

        # sorted() is stable, so equally frequent genres stay in first-seen order.
        ranked = sorted(genre_counts, key=genre_counts.get, reverse=True)
        top_genres = ranked[: self.TOP_GENRE_COUNT]
        return ", ".join(top_genres) if top_genres else "varied genres"

    def query_openai(self, prompt: str) -> str:
        """Send ``prompt`` to the chat model and return the reply text.

        Best-effort: any exception is printed and an empty string is
        returned, so callers never see an API failure as an exception.
        """
        # NOTE(review): openai.ChatCompletion is the pre-1.0 SDK interface;
        # confirm the installed openai package version still provides it.
        try:
            response = openai.ChatCompletion.create(
                model=OPENAI_MODEL,
                messages=[
                    {"role": "system", "content": "You're a helpful and precise media recommender."},
                    {"role": "user", "content": prompt},
                ],
                temperature=0.4,
                max_tokens=150,
            )
            # Guard against a missing/None content so the return is always str.
            return response.choices[0].message.content or ""
        except Exception as exc:
            print("⚠️ OpenAI API error:", exc)
            return ""

    def parse_titles(self, response: str) -> List[str]:
        """Split a model reply into clean titles.

        Newlines and commas both act as separators. Leading list markers
        (bullets such as ``-``, ``•``, ``*`` and numbering such as ``1.`` or
        ``2)``) and surrounding double quotes are removed. Digits that belong
        to the title itself ("12 Angry Men", "300") are kept.

        Note that a title containing a comma is still split at the comma;
        the prompt asks for a comma-separated list.

        Args:
            response: Raw reply text; ``None`` or empty yields ``[]``.

        Returns:
            Non-empty titles in reply order.
        """
        import re  # local import: `re` is not in this module's import block

        if not response:
            return []

        # Optional bullet run, then optional "<digits>." / "<digits>)" that is
        # followed by whitespace or ends the chunk — never bare leading digits.
        list_marker = re.compile(r"^(?:[-•*]+\s*)?(?:\d+[.)](?:\s+|$))?")

        cleaned = []
        for chunk in response.replace("\n", ",").split(","):
            title = list_marker.sub("", chunk.strip()).strip()
            title = title.strip('"“”').strip()
            if title:
                cleaned.append(title)
        return cleaned
|