Emergent generative agents
Revision | c85f6b41fc7a2feabf601e41537d7cefa0f4c4fa (tree) |
---|---|
Zeit | 2023-04-10 03:57:49 |
Autor | Corbin <cds@corb...> |
Committer | Corbin |
IRC stuff.
Multiple channels, timestamps, better formatting of log lines, log
storage, and small tweaks.
@@ -1,7 +1,11 @@ | ||
1 | 1 | #!/usr/bin/env nix-shell |
2 | 2 | #! nix-shell -i python3 -p python3Packages.irc python3Packages.transformers python3Packages.torch |
3 | 3 | |
4 | +from concurrent.futures import ThreadPoolExecutor | |
5 | +from datetime import datetime | |
4 | 6 | import json |
7 | +import os.path | |
8 | +import random | |
5 | 9 | import sys |
6 | 10 | |
7 | 11 | from irc.bot import SingleServerIRCBot |
@@ -52,58 +56,95 @@ Goals, working memory, notes: {allGoals} | ||
52 | 56 | def load_character(path): |
53 | 57 | with open(path) as handle: return json.load(handle) |
54 | 58 | |
55 | -characters = [load_character(arg) for arg in sys.argv[1:]] | |
56 | -prompts = [build_prompt(**character) for character in characters] | |
57 | -prompt, title = prompts.pop(0) | |
59 | +character = load_character(sys.argv[2]) | |
60 | +startingChannels = character.pop("startingChannels") | |
61 | +prompt, title = build_prompt(**character) | |
58 | 62 | |
59 | 63 | MAX_NEW_TOKENS = 128 |
60 | -gen = CamelidGen() | |
61 | -gen = HFGen(Flavor.GPTNeo, MAX_NEW_TOKENS) | |
64 | +if sys.argv[1] == "llama": | |
65 | + print("~ Initializing camelid adapter…") | |
66 | + gen = CamelidGen() | |
67 | +else: | |
68 | + print("~ Initializing GPT-Neo on HF…") | |
69 | + gen = HFGen(Flavor.GPTNeo, MAX_NEW_TOKENS) | |
62 | 70 | max_context_length = gen.contextLength() |
63 | 71 | |
72 | +executor = ThreadPoolExecutor(max_workers=1) | |
73 | + | |
64 | 74 | class Agent(SingleServerIRCBot): |
65 | - def __init__(self, host, title): | |
75 | + def __init__(self, host, title, startingChannels, logpath): | |
66 | 76 | super(Agent, self).__init__([(host, 6667)], title_to_nick(title), title) |
67 | - self.log = Log([]) | |
77 | + self.startingChannels = startingChannels | |
78 | + self.logpath = logpath | |
79 | + self.logs = {} | |
80 | + | |
81 | + def on_join(self, c, e): | |
82 | + channel = e.target | |
83 | + c.topic(channel) | |
84 | + try: | |
85 | + with open(os.path.join(self.logpath, channel + ".txt"), "r") as f: | |
86 | + self.logs[channel] = Log(f.read().strip().split("\n")) | |
87 | + except IOError: self.logs[channel] = Log([]) | |
68 | 88 | |
69 | - def on_join(self, c, e): c.topic(e.target) | |
70 | 89 | def on_currenttopic(self, c, e): |
71 | 90 | self.channels[e.arguments[0]].topic = e.arguments[1] |
72 | 91 | |
73 | - def on_welcome(self, c, e): c.join("#treehouse") | |
92 | + def on_welcome(self, c, e): | |
93 | + for channel in self.startingChannels: c.join(channel) | |
74 | 94 | |
75 | 95 | def on_pubmsg(self, c, e): |
76 | 96 | line = e.arguments[0] |
77 | - self.log.push(e.source.nick, line) | |
97 | + channel = e.target | |
98 | + log = self.logs[channel] | |
99 | + log.irc(datetime.now(), e.source.nick, line) | |
78 | 100 | # Dispatch in the style of |
79 | 101 | # https://github.com/jaraco/irc/blob/main/scripts/testbot.py |
80 | - if ":" in line: | |
81 | - prefix = lower(self.connection.get_nickname()) + ":" | |
82 | - if lower(line).startswith(prefix): | |
83 | - self.do_command(c, e, line[len(prefix):].strip()) | |
84 | - | |
85 | - def do_command(self, c, e, inst): | |
86 | - channel = e.target | |
87 | - nick = self.connection.get_nickname() | |
88 | - users = self.channels[channel].users() | |
89 | - prefix = nick + ":" | |
90 | - fullPrompt = prompt + self.chatPrompt(channel) | |
91 | - self.log.bumpCutoff(max_context_length, gen.countTokens, fullPrompt, prefix) | |
92 | - s = self.log.finishPrompt(fullPrompt, prefix) | |
93 | - print("~ log cutoff:", self.log.cutoff, | |
94 | - "prompt length (tokens):", gen.countTokens(s)) | |
95 | - # Hack: GPT-Neo tries to smuggle messages, so forbid it. | |
96 | - forbidden = [prefix] + ["." + user for user in users] | |
97 | - reply = parseLine(gen.complete(s), forbidden) | |
98 | - self.log.push(nick, reply) | |
99 | - c.privmsg(channel, reply) | |
102 | + nick = lower(self.connection.get_nickname()) | |
103 | + colon = nick + ":" | |
104 | + comma = nick + "," | |
105 | + lowered = lower(line) | |
106 | + if colon in lowered or comma in lowered: | |
107 | + self.do_command(c, e, line[len(colon):].strip()) | |
108 | + elif nick in lowered and random.random() <= 0.875: | |
109 | + self.generateReply(c, channel) | |
110 | + elif random.random() <= 0.125: self.generateReply(c, channel) | |
111 | + | |
112 | + def do_command(self, c, e, inst): self.generateReply(c, e.target) | |
100 | 113 | |
101 | 114 | def chatPrompt(self, channel): |
102 | - topic = getattr(self.channels[channel], "topic", "no topic") | |
115 | + c = self.channels[channel] | |
116 | + topic = getattr(c, "topic", "no topic") | |
117 | + users = ", ".join(c.users()) | |
103 | 118 | return f""" |
104 | 119 | IRC channel: {channel} |
105 | 120 | Channel topic: {topic} |
121 | +Channel users: {users} | |
106 | 122 | """ |
107 | 123 | |
108 | -agent = Agent("june.local", title) | |
109 | -agent.start() | |
124 | + def generateReply(self, c, channel): | |
125 | + log = self.logs[channel] | |
126 | + nick = self.connection.get_nickname() | |
127 | + users = self.channels[channel].users() | |
128 | + prefix = nick + ":" | |
129 | + fullPrompt = prompt + self.chatPrompt(channel) | |
130 | + log.bumpCutoff(max_context_length, gen.countTokens, fullPrompt, prefix) | |
131 | + s = log.finishPrompt(fullPrompt, prefix) | |
132 | + print("~ log cutoff:", log.cutoff, | |
133 | + "prompt length (tokens):", gen.countTokens(s)) | |
134 | + forbidden = [prefix] + ["." + user for user in users] | |
135 | + # NB: At this point, execution is kicked out to a thread. | |
136 | + def cb(completion): | |
137 | + reply = parseLine(completion.result(), forbidden) | |
138 | + log.irc(datetime.now(), nick, reply) | |
139 | + c.privmsg(channel, reply) | |
140 | + executor.submit(lambda: gen.complete(s)).add_done_callback(cb) | |
141 | + | |
142 | +logpath = sys.argv[3] | |
143 | +agent = Agent("june.local", title, startingChannels, logpath) | |
144 | +try: agent.start() | |
145 | +except KeyboardInterrupt: | |
146 | + print("~ Saving logs…") | |
147 | + for channel, log in agent.logs.items(): | |
148 | + with open(os.path.join(agent.logpath, channel + ".txt"), "w") as f: | |
149 | + f.write("\n".join(log.l)) | |
150 | +print("~ Quitting, bye!") |
@@ -22,6 +22,8 @@ class Log: | ||
22 | 22 | self.stamp += 1 |
23 | 23 | |
24 | 24 | def push(self, speaker, entry): self.raw(speaker + ": " + entry) |
25 | + def irc(self, t, speaker, entry): | |
26 | + self.raw(f"{t:%H:%M:%S} <{speaker}> {entry}") | |
25 | 27 | |
26 | 28 | def finishPrompt(self, s, prefix): |
27 | 29 | return self.finishPromptAtCutoff(self.cutoff, s, prefix) |
@@ -4,9 +4,10 @@ from common import Timer | ||
4 | 4 | |
5 | 5 | llama = [ |
6 | 6 | "/home/simpson/models/result/bin/llama-cpp", |
7 | - "-m", "/home/simpson/models/13b-4bit.bin", | |
8 | - "-f", "/dev/stdin", | |
9 | - "-c", "2048", | |
7 | + "-t", "3", | |
8 | + "-m", "/home/simpson/models/7b-4bit.bin", | |
9 | + "-f", "/dev/stdin", | |
10 | + "-c", "2048", | |
10 | 11 | ] |
11 | 12 | |
12 | 13 | class CamelidGen: |