using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using LLM.Editor.Client;
using LLM.Editor.Data;
using UnityEditor;
using UnityEngine;

namespace LLM.Editor.Core
{
    /// <summary>
    /// Finds previously logged interactions whose prompts are most similar to a new
    /// user prompt, ranked by cosine similarity between prompt embeddings.
    /// </summary>
    // NOTE(review): no static constructor is defined here, so [InitializeOnLoad]
    // currently triggers no initialization code for this class.
    [InitializeOnLoad]
    public static class MemoryRetriever
    {
        /// <summary>
        /// Embeds <paramref name="newUserPrompt"/> and returns up to <paramref name="topK"/>
        /// logged records whose stored prompt embeddings are most similar to it.
        /// </summary>
        /// <param name="newUserPrompt">Prompt to find related memories for.</param>
        /// <param name="cancellationToken">Passed through to the embedding request.</param>
        /// <param name="topK">Maximum number of records to return; non-positive values yield an empty list
        /// without issuing an embedding request.</param>
        /// <returns>
        /// Records ordered by descending similarity (ties keep their logged order). Never null: an empty
        /// list is returned when no API client is available, the embedding could not be produced
        /// (including cancellation), or no record has an embedding of matching dimension.
        /// </returns>
        public static async Task<List<InteractionRecord>> GetRelevantMemories(string newUserPrompt, CancellationToken cancellationToken, int topK = 3)
        {
            // Take(n) with n <= 0 yields nothing, so avoid a wasted network round-trip.
            if (topK <= 0) return new List<InteractionRecord>();

            var apiClient = ApiClientFactory.GetClient();
            if (apiClient == null)
            {
                Debug.LogError("[MemoryRetriever] Could not get API client to generate embedding.");
                return new List<InteractionRecord>();
            }

            var newUserEmbedding = await EmbeddingHelper.GetEmbedding(newUserPrompt, apiClient.GetAuthToken, cancellationToken);
            if (newUserEmbedding == null)
            {
                // This can happen if the request was cancelled, which is not an error.
                if (!cancellationToken.IsCancellationRequested)
                    Debug.LogError("[MemoryRetriever] Could not generate embedding for the new prompt.");
                return new List<InteractionRecord>();
            }

            var allRecords = MemoryLogger.GetRecords();
            if (allRecords == null) return new List<InteractionRecord>();

            var scoredRecords = new List<(InteractionRecord record, float score)>();
            var mismatchedCount = 0;
            foreach (var record in allRecords)
            {
                if (record.PromptEmbedding == null || record.PromptEmbedding.Length == 0) continue;

                // Records embedded with a different dimension cannot be compared; exclude them
                // instead of ranking them with a fabricated score of 0.
                if (!TryCosineSimilarity(newUserEmbedding, record.PromptEmbedding, out var score))
                {
                    mismatchedCount++;
                    continue;
                }

                scoredRecords.Add((record, score));
            }

            if (mismatchedCount > 0)
            {
                Debug.LogWarning($"[MemoryRetriever] Skipped {mismatchedCount} record(s) whose embedding dimension does not match the new prompt's embedding.");
            }

            return scoredRecords
                .OrderByDescending(t => t.score)
                .Take(topK)
                .Select(t => t.record)
                .ToList();
        }

        /// <summary>
        /// Computes the cosine similarity of two vectors.
        /// </summary>
        /// <param name="v1">First vector.</param>
        /// <param name="v2">Second vector.</param>
        /// <param name="similarity">The similarity in [-1, 1]; 0 if either vector has zero magnitude
        /// or if the dimensions differ.</param>
        /// <returns><c>false</c> when the vectors have different dimensions; otherwise <c>true</c>.</returns>
        private static bool TryCosineSimilarity(IReadOnlyList<float> v1, IReadOnlyList<float> v2, out float similarity)
        {
            similarity = 0f;
            if (v1.Count != v2.Count) return false;

            // Accumulate in double to limit rounding error over long vectors.
            double dotProduct = 0.0, normSq1 = 0.0, normSq2 = 0.0;
            for (var i = 0; i < v1.Count; i++)
            {
                double a = v1[i];
                double b = v2[i];
                dotProduct += a * b;
                normSq1 += a * a;
                normSq2 += b * b;
            }

            // A zero vector has no direction; report it as unrelated rather than dividing by zero.
            if (normSq1 == 0.0 || normSq2 == 0.0) return true;

            similarity = (float)(dotProduct / (System.Math.Sqrt(normSq1) * System.Math.Sqrt(normSq2)));
            return true;
        }
    }
}