module updates
This commit is contained in:
parent
a6289289ee
commit
f4b433fbc1
@ -61,6 +61,7 @@ google-auth-oauthlib
|
|||||||
#langchain[all]
|
#langchain[all]
|
||||||
langchain[llms]
|
langchain[llms]
|
||||||
langchain-community
|
langchain-community
|
||||||
|
langchain-classic
|
||||||
#langchain-cli
|
#langchain-cli
|
||||||
#langchain-core
|
#langchain-core
|
||||||
langchain-openai
|
langchain-openai
|
||||||
|
|||||||
@ -14,10 +14,10 @@ logging = logger.get_logger(
|
|||||||
log_format='default'
|
log_format='default'
|
||||||
)
|
)
|
||||||
from fastapi import APIRouter, HTTPException
|
from fastapi import APIRouter, HTTPException
|
||||||
from langchain.chains import GraphCypherQAChain
|
from langchain_classic.chains import GraphCypherQAChain
|
||||||
from langchain_community.graphs import Neo4jGraph
|
from langchain_community.graphs import Neo4jGraph
|
||||||
from langchain_community.chat_models import ChatOpenAI
|
from langchain_community.chat_models import ChatOpenAI
|
||||||
from langchain.prompts.prompt import PromptTemplate
|
from langchain_classic.prompts.prompt import PromptTemplate
|
||||||
from routers.llm.private.ollama.ollama_wrapper import OllamaWrapper
|
from routers.llm.private.ollama.ollama_wrapper import OllamaWrapper
|
||||||
from modules.database.tools.neontology.utils import get_node_types, get_rels_by_type
|
from modules.database.tools.neontology.utils import get_node_types, get_rels_by_type
|
||||||
from modules.database.tools.neontology.basenode import BaseNode
|
from modules.database.tools.neontology.basenode import BaseNode
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
from ollama import Client
|
from ollama import Client
|
||||||
from langchain_core.runnables.base import Runnable
|
from langchain_core.runnables.base import Runnable
|
||||||
from langchain.prompts.base import StringPromptValue
|
|
||||||
|
|
||||||
class OllamaWrapper(Runnable):
|
class OllamaWrapper(Runnable):
|
||||||
def __init__(self, host: str, model: str = "llama3.2:latest"):
|
def __init__(self, host: str, model: str = "llama3.2:latest"):
|
||||||
@ -9,8 +8,11 @@ class OllamaWrapper(Runnable):
|
|||||||
self.model = model
|
self.model = model
|
||||||
|
|
||||||
def invoke(self, prompt: Any, config: Dict[str, Any] = None, **kwargs: Any) -> str:
|
def invoke(self, prompt: Any, config: Dict[str, Any] = None, **kwargs: Any) -> str:
|
||||||
if isinstance(prompt, StringPromptValue):
|
# Handle different prompt types (StringPromptValue, str, etc.)
|
||||||
|
if hasattr(prompt, 'to_string'):
|
||||||
prompt = prompt.to_string()
|
prompt = prompt.to_string()
|
||||||
|
elif not isinstance(prompt, str):
|
||||||
|
prompt = str(prompt)
|
||||||
|
|
||||||
# Use the model from constructor, but allow override via kwargs
|
# Use the model from constructor, but allow override via kwargs
|
||||||
model_name = kwargs.get("model", self.model)
|
model_name = kwargs.get("model", self.model)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user